Test cipher_auth_{en,de}crypt_ext()

Signed-off-by: Manuel Pégourié-Gonnard <manuel.pegourie-gonnard@arm.com>
diff --git a/tests/suites/test_suite_cipher.function b/tests/suites/test_suite_cipher.function
index a40bfb5..7ea1a14 100644
--- a/tests/suites/test_suite_cipher.function
+++ b/tests/suites/test_suite_cipher.function
@@ -1013,9 +1013,15 @@
      * of AEAD decryption and AEAD encryption. Check that
      * this results in the expected plaintext, and that
      * decryption and encryption are inverse to one another.
+     *
+     * Do that twice:
+     * - once with legacy functions auth_decrypt/auth_encrypt
+     * - once with new functions auth_decrypt_ext/auth_encrypt_ext
+     * This allows testing both without duplicating test cases.
      */
 
     int ret;
+    int using_nist_kw, using_nist_kw_padding;
     unsigned char output[300];        /* Temporary buffer for results of
                                        * encryption and decryption. */
     unsigned char *output_tag = NULL; /* Temporary buffer for tag in the
@@ -1027,6 +1033,13 @@
     unsigned char *tmp_tag    = NULL;
     unsigned char *tmp_cipher = NULL;
 
+    unsigned char *cipher_plus_tag = NULL;
+    size_t cipher_plus_tag_len;
+    unsigned char *decrypt_buf = NULL;
+    size_t decrypt_buf_len = 0;
+    unsigned char *encrypt_buf = NULL;
+    size_t encrypt_buf_len = 0;
+
     mbedtls_cipher_init( &ctx );
     memset( output, 0xFF, sizeof( output ) );
 
@@ -1039,30 +1052,163 @@
 #endif
 
     /*
+     * Are we using NIST_KW? If so, is it the padding variant (KWP)?
+     */
+    using_nist_kw_padding = cipher_id == MBEDTLS_CIPHER_AES_128_KWP ||
+                            cipher_id == MBEDTLS_CIPHER_AES_192_KWP ||
+                            cipher_id == MBEDTLS_CIPHER_AES_256_KWP;
+    using_nist_kw = cipher_id == MBEDTLS_CIPHER_AES_128_KW ||
+                    cipher_id == MBEDTLS_CIPHER_AES_192_KW ||
+                    cipher_id == MBEDTLS_CIPHER_AES_256_KW ||
+                    using_nist_kw_padding;
+
+    /*
      * Prepare context for decryption
      */
     cipher_reset_key( &ctx, cipher_id, use_psa, tag->len, key,
                       MBEDTLS_DECRYPT );
 
     /*
-     * Prepare buffers/pointers for decryption
+     * Prepare buffer for decryption
+     * (the new API expects the tag appended to the ciphertext)
+     */
+    cipher_plus_tag_len = cipher->len + tag->len;
+    ASSERT_ALLOC( cipher_plus_tag, cipher_plus_tag_len );
+    memcpy( cipher_plus_tag, cipher->x, cipher->len );
+    memcpy( cipher_plus_tag + cipher->len, tag->x, tag->len );
+
+    /*
+     * Compute the minimum output buffer length allowed by the documentation
+     */
+    if( using_nist_kw )
+        decrypt_buf_len = cipher_plus_tag_len - 8;
+    else
+        decrypt_buf_len = cipher_plus_tag_len - tag->len;
+
+
+    /*
+     * Try decrypting to a buffer that's 1B too small
+     */
+    if( decrypt_buf_len != 0 )
+    {
+        ASSERT_ALLOC( decrypt_buf, decrypt_buf_len - 1 );
+
+        outlen = 0;
+        ret = mbedtls_cipher_auth_decrypt_ext( &ctx, iv->x, iv->len,
+                ad->x, ad->len, cipher_plus_tag, cipher_plus_tag_len,
+                decrypt_buf, decrypt_buf_len - 1, &outlen, tag->len );
+        TEST_ASSERT( ret == MBEDTLS_ERR_CIPHER_BAD_INPUT_DATA );
+
+        mbedtls_free( decrypt_buf );
+        decrypt_buf = NULL;
+    }
+
+    /*
+     * Authenticate and decrypt, and check result
+     */
+    ASSERT_ALLOC( decrypt_buf, decrypt_buf_len );
+
+    outlen = 0;
+    ret = mbedtls_cipher_auth_decrypt_ext( &ctx, iv->x, iv->len,
+            ad->x, ad->len, cipher_plus_tag, cipher_plus_tag_len,
+            decrypt_buf, decrypt_buf_len, &outlen, tag->len );
+
+    if( strcmp( result, "FAIL" ) == 0 )
+    {
+        TEST_ASSERT( ret == MBEDTLS_ERR_CIPHER_AUTH_FAILED );
+    }
+    else
+    {
+        TEST_ASSERT( ret == 0 );
+
+        TEST_ASSERT( outlen == clear->len );
+        if( clear->len != 0 )
+            TEST_ASSERT( memcmp( decrypt_buf, clear->x, clear->len ) == 0 );
+    }
+
+    /* Free this, but keep cipher_plus_tag: it is reused below with PSA */
+    mbedtls_free( decrypt_buf );
+    decrypt_buf = NULL;
+
+    /*
+     * Encrypt back if test data was authentic
+     */
+    if( strcmp( result, "FAIL" ) != 0 )
+    {
+        /* Prepare context for encryption */
+        cipher_reset_key( &ctx, cipher_id, use_psa, tag->len, key,
+                          MBEDTLS_ENCRYPT );
+
+        /*
+         * Compute the minimum output buffer size allowed by the documentation
+         */
+        if( using_nist_kw )
+        {
+            encrypt_buf_len = clear->len + 8;
+            if( using_nist_kw_padding && encrypt_buf_len % 8 != 0 )
+                encrypt_buf_len += 8 - encrypt_buf_len % 8;
+        }
+        else
+        {
+            encrypt_buf_len = clear->len + tag->len;
+        }
+
+        /*
+         * Try encrypting with an output buffer that's 1B too small
+         */
+        ASSERT_ALLOC( encrypt_buf, encrypt_buf_len - 1 );
+
+        outlen = 0;
+        ret = mbedtls_cipher_auth_encrypt_ext( &ctx, iv->x, iv->len,
+                ad->x, ad->len, clear->x, clear->len,
+                encrypt_buf, encrypt_buf_len - 1, &outlen, tag->len );
+        TEST_ASSERT( ret != 0 );
+
+        mbedtls_free( encrypt_buf );
+        encrypt_buf = NULL;
+
+        /*
+         * Encrypt and check the result
+         */
+        ASSERT_ALLOC( encrypt_buf, encrypt_buf_len );
+
+        outlen = 0;
+        ret = mbedtls_cipher_auth_encrypt_ext( &ctx, iv->x, iv->len,
+                ad->x, ad->len, clear->x, clear->len,
+                encrypt_buf, encrypt_buf_len, &outlen, tag->len );
+        TEST_ASSERT( ret == 0 );
+
+        TEST_ASSERT( outlen == cipher->len + tag->len );
+        TEST_ASSERT( memcmp( encrypt_buf, cipher->x, cipher->len ) == 0 );
+        TEST_ASSERT( memcmp( encrypt_buf + cipher->len,
+                             tag->x, tag->len ) == 0 );
+
+        mbedtls_free( encrypt_buf );
+        encrypt_buf = NULL;
+    }
+
+    /*
+     * Prepare context for decryption
+     */
+    cipher_reset_key( &ctx, cipher_id, use_psa, tag->len, key,
+                      MBEDTLS_DECRYPT );
+
+    /*
+     * Prepare pointers for decryption
      */
 #if defined(MBEDTLS_USE_PSA_CRYPTO)
     if( use_psa == 1 )
     {
-        /* PSA requires that the tag immediately follows the ciphertext. */
-        tmp_cipher = mbedtls_calloc( 1, cipher->len + tag->len );
-        TEST_ASSERT( tmp_cipher != NULL );
+        /* PSA requires that the tag immediately follows the ciphertext.
+         * Fortunately, we already have that from testing the new API. */
+        tmp_cipher = cipher_plus_tag;
         tmp_tag = tmp_cipher + cipher->len;
-
-        memcpy( tmp_cipher, cipher->x, cipher->len );
-        memcpy( tmp_tag, tag->x, tag->len );
     }
     else
 #endif /* MBEDTLS_USE_PSA_CRYPTO */
     {
-        tmp_tag = tag->x;
         tmp_cipher = cipher->x;
+        tmp_tag = tag->x;
     }
 
     /*
@@ -1118,13 +1264,13 @@
 exit:
 
     mbedtls_cipher_free( &ctx );
+    mbedtls_free( decrypt_buf );
+    mbedtls_free( encrypt_buf );
+    mbedtls_free( cipher_plus_tag );
 
 #if defined(MBEDTLS_USE_PSA_CRYPTO)
     if( use_psa == 1 )
-    {
-        mbedtls_free( tmp_cipher );
         PSA_DONE( );
-    }
 #endif /* MBEDTLS_USE_PSA_CRYPTO */
 }
 /* END_CASE */