Test PSA-based CCM cipher operations
diff --git a/tests/suites/test_suite_cipher.function b/tests/suites/test_suite_cipher.function
index da9dfa1..ada7347 100644
--- a/tests/suites/test_suite_cipher.function
+++ b/tests/suites/test_suite_cipher.function
@@ -542,33 +542,78 @@
 /* BEGIN_CASE depends_on:MBEDTLS_CIPHER_MODE_AEAD */
 void auth_crypt_tv( int cipher_id, data_t * key, data_t * iv,
                     data_t * ad, data_t * cipher, data_t * tag,
-                    char * result, data_t * clear )
+                    char * result, data_t * clear, int use_psa )
 {
+    /* Takes an AEAD ciphertext + tag and performs a pair
+     * of AEAD decryption and AEAD encryption. It checks that
+     * this results in the expected plaintext, and that
+     * decryption and encryption are inverse to one another.
+     *
+     * If use_psa is 1 and MBEDTLS_USE_PSA_CRYPTO is defined, the
+     * context is set up with mbedtls_cipher_setup_psa() and the
+     * ciphertext and tag are copied into one contiguous buffer;
+     * otherwise mbedtls_cipher_setup() is used. */
+
     int ret;
-    unsigned char output[267]; /* above + 2 (overwrite check) */
-    unsigned char my_tag[20];
+    unsigned char output[300];        /* Temporary buffer for results of
+                                       * encryption and decryption. */
+    unsigned char *output_tag = NULL; /* Temporary buffer for tag in the
+                                       * encryption step. */
+
     mbedtls_cipher_context_t ctx;
     size_t outlen;
 
+    unsigned char *tmp_tag    = NULL;
+    unsigned char *tmp_cipher = NULL;
+
     mbedtls_cipher_init( &ctx );
-
     memset( output, 0xFF, sizeof( output ) );
-    memset( my_tag, 0xFF, sizeof( my_tag ) );
-
 
     /* Prepare context */
-    TEST_ASSERT( 0 == mbedtls_cipher_setup( &ctx,
-                                       mbedtls_cipher_info_from_type( cipher_id ) ) );
-    TEST_ASSERT( 0 == mbedtls_cipher_setkey( &ctx, key->x, 8 * key->len, MBEDTLS_DECRYPT ) );
+#if !defined(MBEDTLS_USE_PSA_CRYPTO)
+    (void) use_psa;
+#else
+    if( use_psa == 1 )
+    {
+        /* PSA requires that the tag immediately follows the ciphertext. */
+        tmp_cipher = mbedtls_calloc( 1, cipher->len + tag->len );
+        TEST_ASSERT( tmp_cipher != NULL );
+        tmp_tag = tmp_cipher + cipher->len;
+
+        memcpy( tmp_cipher, cipher->x, cipher->len );
+        memcpy( tmp_tag, tag->x, tag->len );
+
+        TEST_ASSERT( 0 == mbedtls_cipher_setup_psa( &ctx,
+                              mbedtls_cipher_info_from_type( cipher_id ),
+                              tag->len ) );
+    }
+    else
+#endif
+    {
+        tmp_tag = tag->x;
+        tmp_cipher = cipher->x;
+        TEST_ASSERT( 0 == mbedtls_cipher_setup( &ctx,
+                              mbedtls_cipher_info_from_type( cipher_id ) ) );
+    }
+
+    TEST_ASSERT( 0 == mbedtls_cipher_setkey( &ctx, key->x, 8 * key->len,
+                                             MBEDTLS_DECRYPT ) );
 
+    /* Sanity check that we don't use overly long inputs. Keep at least
+     * one spare byte after the plaintext to detect overwrites. */
+    TEST_ASSERT( sizeof( output ) > cipher->len );
+
     /* decode buffer and check tag->x */
-    ret = mbedtls_cipher_auth_decrypt( &ctx, iv->x, iv->len, ad->x, ad->len,
-                               cipher->x, cipher->len, output, &outlen,
-                               tag->x, tag->len );
+    ret = mbedtls_cipher_auth_decrypt( &ctx, iv->x, iv->len, ad->x, ad->len,
+                               tmp_cipher, cipher->len, output, &outlen,
+                               tmp_tag, tag->len );
 
     /* make sure we didn't overwrite */
-    TEST_ASSERT( output[outlen + 0] == 0xFF );
-    TEST_ASSERT( output[outlen + 1] == 0xFF );
+    if( ret == 0 )
+    {
+        TEST_ASSERT( outlen < sizeof( output ) );
+        TEST_ASSERT( output[outlen] == 0xFF );
+    }
 
     /* make sure the message is rejected if it should be */
     if( strcmp( result, "FAIL" ) == 0 )
@@ -587,23 +632,34 @@
     memset( output, 0xFF, sizeof( output ) );
     outlen = 0;
 
+    /* Sanity check that we don't use overly long inputs. Keep at least
+     * one spare byte after the tag to detect overwrites. */
+    TEST_ASSERT( sizeof( output ) > clear->len + tag->len );
+
+    output_tag = output + clear->len;
     ret = mbedtls_cipher_auth_encrypt( &ctx, iv->x, iv->len, ad->x, ad->len,
                                clear->x, clear->len, output, &outlen,
-                               my_tag, tag->len );
+                               output_tag, tag->len );
     TEST_ASSERT( ret == 0 );
 
     TEST_ASSERT( outlen == clear->len );
-    TEST_ASSERT( memcmp( output, cipher->x, clear->len ) == 0 );
-    TEST_ASSERT( memcmp( my_tag, tag->x, tag->len ) == 0 );
+    TEST_ASSERT( memcmp( output, cipher->x, cipher->len ) == 0 );
+    TEST_ASSERT( memcmp( output_tag, tag->x, tag->len ) == 0 );
 
     /* make sure we didn't overwrite */
-    TEST_ASSERT( output[outlen + 0] == 0xFF );
-    TEST_ASSERT( output[outlen + 1] == 0xFF );
-    TEST_ASSERT( my_tag[tag->len + 0] == 0xFF );
-    TEST_ASSERT( my_tag[tag->len + 1] == 0xFF );
-
+    TEST_ASSERT( output_tag[tag->len] == 0xFF );
 
 exit:
+
+#if defined(MBEDTLS_USE_PSA_CRYPTO)
+    /* tmp_cipher is only heap-allocated on the PSA path; otherwise
+     * it aliases cipher->x and must not be freed. */
+    if( use_psa == 1 )
+    {
+        mbedtls_free( tmp_cipher );
+    }
+#endif /* MBEDTLS_USE_PSA_CRYPTO */
+
     mbedtls_cipher_free( &ctx );
 }
 /* END_CASE */
@@ -675,12 +731,12 @@
     if( use_psa == 1 )
     {
         TEST_ASSERT( 0 == mbedtls_cipher_setup_psa( &ctx,
-                               mbedtls_cipher_info_from_type( cipher_id ) ) );
+                              mbedtls_cipher_info_from_type( cipher_id ), 0 ) );
     }
     else
 #endif /* MBEDTLS_USE_PSA_CRYPTO */
     TEST_ASSERT( 0 == mbedtls_cipher_setup( &ctx,
-                               mbedtls_cipher_info_from_type( cipher_id ) ) );
+                              mbedtls_cipher_info_from_type( cipher_id ) ) );
 
     key_len = unhexify( key, hex_key );
     inputlen =  unhexify( input, hex_input );