psa: Rework unauthenticated cipher support in transparent test driver

Make use of the mbedtls_psa_cipher_xyz() functions to
simplify the transparent test driver code and to
extend the set of algorithms it supports to all the
cipher algorithms supported by the Mbed TLS library.

Signed-off-by: Ronald Cron <ronald.cron@arm.com>
diff --git a/tests/src/drivers/cipher.c b/tests/src/drivers/cipher.c
index fa7c6a9..6a205b4 100644
--- a/tests/src/drivers/cipher.c
+++ b/tests/src/drivers/cipher.c
@@ -26,6 +26,7 @@
 
 #if defined(MBEDTLS_PSA_CRYPTO_DRIVERS) && defined(PSA_CRYPTO_DRIVER_TEST)
 #include "psa/crypto.h"
+#include "psa_crypto_cipher.h"
 #include "psa_crypto_core.h"
 #include "mbedtls/cipher.h"
 
@@ -204,79 +205,28 @@
             output, output_size, output_length) );
 }
 
-static psa_status_t test_transparent_cipher_setup(
-    mbedtls_operation_t direction,
-    test_transparent_cipher_operation_t *operation,
-    const psa_key_attributes_t *attributes,
-    const uint8_t *key, size_t key_length,
-    psa_algorithm_t alg)
-{
-    const mbedtls_cipher_info_t *cipher_info = NULL;
-    int ret = 0;
-
-    test_driver_cipher_hooks.hits++;
-
-    if( operation->alg != 0 )
-        return( PSA_ERROR_BAD_STATE );
-
-    /* Wiping the entire struct here, instead of member-by-member. This is useful
-     * for the test suite, since it gives a chance of catching memory corruption
-     * errors should the core not have allocated (enough) memory for our context
-     * struct. */
-    memset( operation, 0, sizeof( *operation ) );
-
-    /* Allow overriding return value for testing purposes */
-    if( test_driver_cipher_hooks.forced_status != PSA_SUCCESS )
-        return( test_driver_cipher_hooks.forced_status );
-
-    /* Test driver supports AES-CTR only, to verify operation calls. */
-    if( alg != PSA_ALG_CTR ||
-        psa_get_key_type( attributes ) != PSA_KEY_TYPE_AES )
-        return( PSA_ERROR_NOT_SUPPORTED );
-
-    operation->alg = alg;
-    operation->iv_size = 16;
-
-    cipher_info = mbedtls_cipher_info_from_values( MBEDTLS_CIPHER_ID_AES,
-                                                   key_length * 8,
-                                                   MBEDTLS_MODE_CTR );
-    if( cipher_info == NULL )
-        return( PSA_ERROR_NOT_SUPPORTED );
-
-    mbedtls_cipher_init( &operation->cipher );
-    ret = mbedtls_cipher_setup( &operation->cipher, cipher_info );
-    if( ret != 0 ) {
-        mbedtls_cipher_free( &operation->cipher );
-        return( mbedtls_to_psa_error( ret ) );
-    }
-
-    ret = mbedtls_cipher_setkey( &operation->cipher,
-                                 key,
-                                 key_length * 8, direction );
-    if( ret != 0 ) {
-        mbedtls_cipher_free( &operation->cipher );
-        return( mbedtls_to_psa_error( ret ) );
-    }
-
-    operation->iv_set = 0;
-    operation->iv_required = 1;
-    operation->key_set = 1;
-
-    return( test_driver_cipher_hooks.forced_status );
-}
-
 psa_status_t test_transparent_cipher_encrypt_setup(
     test_transparent_cipher_operation_t *operation,
     const psa_key_attributes_t *attributes,
     const uint8_t *key, size_t key_length,
     psa_algorithm_t alg)
 {
-    return ( test_transparent_cipher_setup( MBEDTLS_ENCRYPT,
-                                            operation,
-                                            attributes,
-                                            key,
-                                            key_length,
-                                            alg ) );
+    test_driver_cipher_hooks.hits++;
+
+    /* Wiping the entire struct here, instead of member-by-member. This is
+     * useful for the test suite, since it gives a chance of catching memory
+     * corruption errors should the core not have allocated (enough) memory for
+     * our context struct. */
+    memset( operation, 0, sizeof( *operation ) );
+
+    if( test_driver_cipher_hooks.forced_status != PSA_SUCCESS )
+        return( test_driver_cipher_hooks.forced_status );
+
+    return ( mbedtls_psa_cipher_encrypt_setup( operation,
+                                               attributes,
+                                               key,
+                                               key_length,
+                                               alg ) );
 }
 
 psa_status_t test_transparent_cipher_decrypt_setup(
@@ -285,12 +235,22 @@
     const uint8_t *key, size_t key_length,
     psa_algorithm_t alg)
 {
-    return ( test_transparent_cipher_setup( MBEDTLS_DECRYPT,
-                                            operation,
-                                            attributes,
-                                            key,
-                                            key_length,
-                                            alg ) );
+    test_driver_cipher_hooks.hits++;
+
+    /* Wiping the entire struct here, instead of member-by-member. This is
+     * useful for the test suite, since it gives a chance of catching memory
+     * corruption errors should the core not have allocated (enough) memory for
+     * our context struct. */
+    memset( operation, 0, sizeof( *operation ) );
+
+    if( test_driver_cipher_hooks.forced_status != PSA_SUCCESS )
+        return( test_driver_cipher_hooks.forced_status );
+
+    return ( mbedtls_psa_cipher_decrypt_setup( operation,
+                                               attributes,
+                                               key,
+                                               key_length,
+                                               alg ) );
 }
 
 psa_status_t test_transparent_cipher_abort(
@@ -300,18 +254,16 @@
 
     if( operation->alg == 0 )
         return( PSA_SUCCESS );
-    if( operation->alg != PSA_ALG_CTR )
-        return( PSA_ERROR_BAD_STATE );
 
-    mbedtls_cipher_free( &operation->cipher );
+    mbedtls_psa_cipher_abort( operation );
 
-    /* Wiping the entire struct here, instead of member-by-member. This is useful
-     * for the test suite, since it gives a chance of catching memory corruption
-     * errors should the core not have allocated (enough) memory for our context
-     * struct. */
+    /* Wiping the entire struct here, instead of member-by-member. This is
+     * useful for the test suite, since it gives a chance of catching memory
+     * corruption errors should the core not have allocated (enough) memory for
+     * our context struct. */
     memset( operation, 0, sizeof( *operation ) );
 
-    return( PSA_SUCCESS );
+    return( test_driver_cipher_hooks.forced_status );
 }
 
 psa_status_t test_transparent_cipher_generate_iv(
@@ -320,35 +272,15 @@
     size_t iv_size,
     size_t *iv_length)
 {
-    psa_status_t status;
-    mbedtls_test_rnd_pseudo_info rnd_info;
-    memset( &rnd_info, 0x5A, sizeof( mbedtls_test_rnd_pseudo_info ) );
-
     test_driver_cipher_hooks.hits++;
 
     if( test_driver_cipher_hooks.forced_status != PSA_SUCCESS )
         return( test_driver_cipher_hooks.forced_status );
 
-    if( operation->alg != PSA_ALG_CTR )
-        return( PSA_ERROR_BAD_STATE );
-
-    if( operation->iv_set || ! operation->iv_required )
-        return( PSA_ERROR_BAD_STATE );
-
-    if( iv_size < operation->iv_size )
-        return( PSA_ERROR_BUFFER_TOO_SMALL );
-
-    status = mbedtls_to_psa_error(
-        mbedtls_test_rnd_pseudo_rand( &rnd_info,
-                                      iv,
-                                      operation->iv_size ) );
-    if( status != PSA_SUCCESS )
-        return( status );
-
-    *iv_length = operation->iv_size;
-    status = test_transparent_cipher_set_iv( operation, iv, *iv_length );
-
-    return( status );
+    return( mbedtls_psa_cipher_generate_iv( operation,
+                                            iv,
+                                            iv_size,
+                                            iv_length ) );
 }
 
 psa_status_t test_transparent_cipher_set_iv(
@@ -356,29 +288,14 @@
     const uint8_t *iv,
     size_t iv_length)
 {
-    psa_status_t status;
-
     test_driver_cipher_hooks.hits++;
 
     if( test_driver_cipher_hooks.forced_status != PSA_SUCCESS )
         return( test_driver_cipher_hooks.forced_status );
 
-    if( operation->alg != PSA_ALG_CTR )
-        return( PSA_ERROR_BAD_STATE );
-
-    if( operation->iv_set || ! operation->iv_required )
-        return( PSA_ERROR_BAD_STATE );
-
-    if( iv_length != operation->iv_size )
-        return( PSA_ERROR_INVALID_ARGUMENT );
-
-    status = mbedtls_to_psa_error(
-        mbedtls_cipher_set_iv( &operation->cipher, iv, iv_length ) );
-
-    if( status == PSA_SUCCESS )
-        operation->iv_set = 1;
-
-    return( status );
+    return( mbedtls_psa_cipher_set_iv( operation,
+                                       iv,
+                                       iv_length ) );
 }
 
 psa_status_t test_transparent_cipher_update(
@@ -389,27 +306,8 @@
     size_t output_size,
     size_t *output_length)
 {
-    psa_status_t status;
-
     test_driver_cipher_hooks.hits++;
 
-    if( test_driver_cipher_hooks.forced_status != PSA_SUCCESS )
-        return( test_driver_cipher_hooks.forced_status );
-
-    if( operation->alg != PSA_ALG_CTR )
-        return( PSA_ERROR_BAD_STATE );
-
-    /* CTR is a stream cipher, so data in and out are always the same size */
-    if( output_size < input_length )
-        return( PSA_ERROR_BUFFER_TOO_SMALL );
-
-    status = mbedtls_to_psa_error(
-        mbedtls_cipher_update( &operation->cipher, input,
-                               input_length, output, output_length ) );
-
-    if( status != PSA_SUCCESS )
-        return status;
-
     if( test_driver_cipher_hooks.forced_output != NULL )
     {
         if( output_size < test_driver_cipher_hooks.forced_output_length )
@@ -419,9 +317,17 @@
                 test_driver_cipher_hooks.forced_output,
                 test_driver_cipher_hooks.forced_output_length );
         *output_length = test_driver_cipher_hooks.forced_output_length;
+
+        return( test_driver_cipher_hooks.forced_status );
     }
 
-    return( test_driver_cipher_hooks.forced_status );
+    if( test_driver_cipher_hooks.forced_status != PSA_SUCCESS )
+        return( test_driver_cipher_hooks.forced_status );
+
+    return( mbedtls_psa_cipher_update( operation,
+                                       input, input_length,
+                                       output, output_size,
+                                       output_length ) );
 }
 
 psa_status_t test_transparent_cipher_finish(
@@ -430,41 +336,8 @@
     size_t output_size,
     size_t *output_length)
 {
-    psa_status_t status = PSA_ERROR_GENERIC_ERROR;
-    uint8_t temp_output_buffer[MBEDTLS_MAX_BLOCK_LENGTH];
-
     test_driver_cipher_hooks.hits++;
 
-    if( test_driver_cipher_hooks.forced_status != PSA_SUCCESS )
-        return( test_driver_cipher_hooks.forced_status );
-
-    if( operation->alg != PSA_ALG_CTR )
-        return( PSA_ERROR_BAD_STATE );
-
-    if( ! operation->key_set )
-        return( PSA_ERROR_BAD_STATE );
-
-    if( operation->iv_required && ! operation->iv_set )
-        return( PSA_ERROR_BAD_STATE );
-
-    status = mbedtls_to_psa_error(
-        mbedtls_cipher_finish( &operation->cipher,
-                               temp_output_buffer,
-                               output_length ) );
-
-    mbedtls_cipher_free( &operation->cipher );
-
-    if( status != PSA_SUCCESS )
-        return( status );
-
-    if( *output_length == 0 )
-        ; /* Nothing to copy. Note that output may be NULL in this case. */
-    else if( output_size >= *output_length )
-        memcpy( output, temp_output_buffer, *output_length );
-    else
-        return( PSA_ERROR_BUFFER_TOO_SMALL );
-
-
     if( test_driver_cipher_hooks.forced_output != NULL )
     {
         if( output_size < test_driver_cipher_hooks.forced_output_length )
@@ -474,9 +347,16 @@
                 test_driver_cipher_hooks.forced_output,
                 test_driver_cipher_hooks.forced_output_length );
         *output_length = test_driver_cipher_hooks.forced_output_length;
+
+        return( test_driver_cipher_hooks.forced_status );
     }
 
-    return( test_driver_cipher_hooks.forced_status );
+    if( test_driver_cipher_hooks.forced_status != PSA_SUCCESS )
+        return( test_driver_cipher_hooks.forced_status );
+
+    return( mbedtls_psa_cipher_finish( operation,
+                                       output, output_size,
+                                       output_length ) );
 }
 
 /*