Restructure cipher context object to contain driver switch

Once an operation has been 'accepted' by a driver, the remainder is bound
to the same driver, since driver-specific context structs cannot be shared.
This provides a pretty good gate mechanism for the fallback logic, too.

Signed-off-by: Steven Cooreman <steven.cooreman@silabs.com>
diff --git a/include/psa/crypto_struct.h b/include/psa/crypto_struct.h
index 0ea8073..3ff3f93 100644
--- a/include/psa/crypto_struct.h
+++ b/include/psa/crypto_struct.h
@@ -158,6 +158,7 @@
     unsigned int key_set : 1;
     unsigned int iv_required : 1;
     unsigned int iv_set : 1;
+    unsigned int accelerator_set : 1;
     uint8_t iv_size;
     uint8_t block_size;
     union
@@ -173,7 +174,7 @@
     } ctx;
 };
 
-#define PSA_CIPHER_OPERATION_INIT {0, 0, 0, 0, 0, 0, {0}}
+#define PSA_CIPHER_OPERATION_INIT {0, 0, 0, 0, 0, 0, 0, {0}}
 static inline struct psa_cipher_operation_s psa_cipher_operation_init( void )
 {
     const struct psa_cipher_operation_s v = PSA_CIPHER_OPERATION_INIT;
diff --git a/library/psa_crypto.c b/library/psa_crypto.c
index 6acf498..7eb9568 100644
--- a/library/psa_crypto.c
+++ b/library/psa_crypto.c
@@ -4059,9 +4059,9 @@
     {
         operation->iv_required = 1;
     }
+    operation->accelerator_set = 0;
     operation->iv_size = 0;
     operation->block_size = 0;
-    mbedtls_cipher_init( &operation->ctx.cipher );
     return( PSA_SUCCESS );
 }
 
@@ -4083,6 +4083,14 @@
     if( status != PSA_SUCCESS )
         goto exit;
 
+    /* A context must be freshly initialized before it can be set up. */
+    if( operation->alg != 0 )
+        return( PSA_ERROR_BAD_STATE );
+
+    status = psa_cipher_init( operation, alg );
+    if( status != PSA_SUCCESS )
+        return( status );
+
     /* Try doing this through a driver before using software fallback */
     if( cipher_operation == MBEDTLS_ENCRYPT )
         status = psa_driver_wrapper_cipher_encrypt_setup( operation,
@@ -4093,18 +4101,19 @@
                                                           slot,
                                                           alg );
 
-    if( status != PSA_ERROR_NOT_SUPPORTED )
-        goto exit;
-
-    /* A context must be freshly initialized before it can be set up. */
-    if( operation->alg != 0 )
+    if( status == PSA_SUCCESS )
     {
-        return( PSA_ERROR_BAD_STATE );
+        operation->accelerator_set = 1;
+        operation->key_set = 1;
     }
 
-    status = psa_cipher_init( operation, alg );
-    if( status != PSA_SUCCESS )
-        return( status );
+    if( status != PSA_ERROR_NOT_SUPPORTED ||
+        psa_key_lifetime_is_external( slot->attr.lifetime ) )
+        goto exit;
+
+    /* Proceed with initializing mbed TLS cipher context if no accelerator is
+     * available for the given algorithm & key. */
+    mbedtls_cipher_init( &operation->ctx.cipher );
 
     status = psa_get_transparent_key( handle, &slot, usage, alg);
     if( status != PSA_SUCCESS )
@@ -4206,7 +4215,14 @@
 {
     psa_status_t status;
     int ret = MBEDTLS_ERR_ERROR_CORRUPTION_DETECTED;
-    if( operation->iv_set || ! operation->iv_required )
+
+    if( operation->accelerator_set == 1 )
+        return( psa_driver_wrapper_cipher_generate_iv( operation,
+                                                       iv,
+                                                       iv_size,
+                                                       iv_length ) );
+
+    if( operation->iv_set || ! operation->iv_required || ! operation->key_set )
     {
         return( PSA_ERROR_BAD_STATE );
     }
@@ -4238,7 +4254,13 @@
 {
     psa_status_t status;
     int ret = MBEDTLS_ERR_ERROR_CORRUPTION_DETECTED;
-    if( operation->iv_set || ! operation->iv_required )
+
+    if( operation->accelerator_set == 1 )
+        return( psa_driver_wrapper_cipher_set_iv( operation,
+                                                  iv,
+                                                  iv_length ) );
+
+    if( operation->iv_set || ! operation->iv_required || ! operation->key_set )
     {
         return( PSA_ERROR_BAD_STATE );
     }
@@ -4355,7 +4377,15 @@
     psa_status_t status = PSA_ERROR_CORRUPTION_DETECTED;
     size_t expected_output_size;
 
-    if( operation->alg == 0 )
+    if( operation->accelerator_set == 1 )
+        return( psa_driver_wrapper_cipher_update( operation,
+                                               input,
+                                               input_length,
+                                               output,
+                                               output_size,
+                                               output_length ) );
+
+    if( operation->alg == 0 || ! operation->key_set )
     {
         return( PSA_ERROR_BAD_STATE );
     }
@@ -4414,6 +4444,12 @@
     int cipher_ret = MBEDTLS_ERR_CIPHER_FEATURE_UNAVAILABLE;
     uint8_t temp_output_buffer[MBEDTLS_MAX_BLOCK_LENGTH];
 
+    if( operation->accelerator_set == 1 )
+        return( psa_driver_wrapper_cipher_finish( operation,
+                                                  output,
+                                                  output_size,
+                                                  output_length ) );
+
     if( ! operation->key_set )
     {
         return( PSA_ERROR_BAD_STATE );
@@ -4483,11 +4519,15 @@
     if( ! PSA_ALG_IS_CIPHER( operation->alg ) )
         return( PSA_ERROR_BAD_STATE );
 
-    mbedtls_cipher_free( &operation->ctx.cipher );
+    if( operation->accelerator_set == 1 )
+        psa_driver_wrapper_cipher_abort( operation );
+    else
+        mbedtls_cipher_free( &operation->ctx.cipher );
 
     operation->alg = 0;
     operation->key_set = 0;
     operation->iv_set = 0;
+    operation->accelerator_set = 0;
     operation->iv_size = 0;
     operation->block_size = 0;
     operation->iv_required = 0;
diff --git a/tests/src/drivers/cipher.c b/tests/src/drivers/cipher.c
index 0f059a0..9db5061 100644
--- a/tests/src/drivers/cipher.c
+++ b/tests/src/drivers/cipher.c
@@ -40,6 +40,8 @@
 void *test_driver_cipher_forced_output = NULL;
 size_t test_driver_cipher_forced_output_length = 0;
 
+/* Test driver, if not explicitly setup, returns 'PSA_ERROR_NOT_SUPPORTED' by default,
+ * causing regular test suites to pass since the core will go into fallback mode. */
 psa_status_t test_transparent_cipher_status = PSA_ERROR_NOT_SUPPORTED;
 unsigned long test_transparent_cipher_hit = 0;