Refactor aead setup functions into single function

Move common encrypt / decrypt code into common function, and roll in
previously refactored setup checks function, as this is now the only
place it is called.

Signed-off-by: Paul Elliott <paul.elliott@arm.com>
diff --git a/library/psa_crypto.c b/library/psa_crypto.c
index e40e370..1566a45 100644
--- a/library/psa_crypto.c
+++ b/library/psa_crypto.c
@@ -3401,26 +3401,82 @@
     return PSA_ALG_AEAD_WITH_DEFAULT_LENGTH_TAG( alg );
 }
 
-static psa_status_t psa_aead_setup_checks( psa_aead_operation_t *operation,
-                                           psa_algorithm_t alg )
+/* Set the key for a multipart authenticated operation. */
+static psa_status_t psa_aead_setup( psa_aead_operation_t *operation,
+                                    mbedtls_svc_key_id_t key,
+                                    psa_algorithm_t alg )
 {
+    psa_status_t status = PSA_ERROR_CORRUPTION_DETECTED;
+    psa_status_t unlock_status = PSA_ERROR_CORRUPTION_DETECTED;
+    psa_key_slot_t *slot = NULL;
+    psa_key_usage_t key_usage = 0;
+
     if( !PSA_ALG_IS_AEAD( alg ) || PSA_ALG_IS_WILDCARD( alg ) )
     {
-        return( PSA_ERROR_INVALID_ARGUMENT );
+        status = PSA_ERROR_INVALID_ARGUMENT;
+        goto exit;
     }
 
     if( operation->id != 0 )
     {
-        return( PSA_ERROR_BAD_STATE );
+        status = PSA_ERROR_BAD_STATE;
+        goto exit;
     }
 
     if( operation->nonce_set || operation->lengths_set ||
         operation->ad_started || operation->body_started )
     {
-        return( PSA_ERROR_BAD_STATE );
+        status = PSA_ERROR_BAD_STATE;
+        goto exit;
     }
 
-    return( PSA_SUCCESS );
+    if( operation->is_encrypt )
+        key_usage = PSA_KEY_USAGE_ENCRYPT;
+    else
+        key_usage = PSA_KEY_USAGE_DECRYPT;
+
+    status = psa_get_and_lock_key_slot_with_policy( key, &slot, key_usage,
+                                                    alg );
+
+    if( status != PSA_SUCCESS )
+        goto exit;
+
+    psa_key_attributes_t attributes = {
+        .core = slot->attr
+    };
+
+    if( operation->is_encrypt )
+        status = psa_driver_wrapper_aead_encrypt_setup( operation,
+                                                        &attributes,
+                                                        slot->key.data,
+                                                        slot->key.bytes,
+                                                        alg );
+    else
+        status = psa_driver_wrapper_aead_decrypt_setup( operation,
+                                                        &attributes,
+                                                        slot->key.data,
+                                                        slot->key.bytes,
+                                                        alg );
+
+
+    if( status != PSA_SUCCESS )
+        goto exit;
+
+    operation->key_type = psa_get_key_type( &attributes );
+
+exit:
+
+    unlock_status = psa_unlock_key_slot( slot );
+
+    if( status == PSA_SUCCESS )
+    {
+        status = unlock_status;
+        operation->alg = psa_aead_get_base_algorithm( alg );
+    }
+    else
+        psa_aead_abort( operation );
+
+    return( status );
 }
 
 /* Set the key for a multipart authenticated encryption operation. */
@@ -3428,48 +3484,9 @@
                                      mbedtls_svc_key_id_t key,
                                      psa_algorithm_t alg )
 {
-    psa_status_t status = PSA_ERROR_CORRUPTION_DETECTED;
-    psa_status_t unlock_status = PSA_ERROR_CORRUPTION_DETECTED;
-    psa_key_slot_t *slot = NULL;
+    operation->is_encrypt = 1;
 
-    status = psa_aead_setup_checks( operation, alg );
-
-    if( status != PSA_SUCCESS )
-        goto exit;
-
-    status = psa_get_and_lock_key_slot_with_policy(
-                 key, &slot, PSA_KEY_USAGE_ENCRYPT, alg );
-
-    if( status != PSA_SUCCESS )
-        goto exit;
-
-    psa_key_attributes_t attributes = {
-      .core = slot->attr
-    };
-
-    status = psa_driver_wrapper_aead_encrypt_setup( operation,
-                                                    &attributes, slot->key.data,
-                                                    slot->key.bytes, alg );
-
-    if( status != PSA_SUCCESS )
-        goto exit;
-
-    operation->key_type = psa_get_key_type( &attributes );
-
-exit:
-
-    unlock_status = psa_unlock_key_slot( slot );
-
-    if( status == PSA_SUCCESS )
-    {
-        status = unlock_status;
-        operation->alg = psa_aead_get_base_algorithm( alg );
-        operation->is_encrypt = 1;
-    }
-    else
-        psa_aead_abort( operation );
-
-    return( status );
+    return( psa_aead_setup( operation, key, alg ) );
 }
 
 /* Set the key for a multipart authenticated decryption operation. */
@@ -3477,48 +3494,9 @@
                                      mbedtls_svc_key_id_t key,
                                      psa_algorithm_t alg )
 {
-    psa_status_t status = PSA_ERROR_CORRUPTION_DETECTED;
-    psa_status_t unlock_status = PSA_ERROR_CORRUPTION_DETECTED;
-    psa_key_slot_t *slot = NULL;
+    operation->is_encrypt = 0;
 
-    status = psa_aead_setup_checks( operation, alg );
-
-    if( status != PSA_SUCCESS )
-        goto exit;
-
-    status = psa_get_and_lock_key_slot_with_policy(
-                 key, &slot, PSA_KEY_USAGE_DECRYPT, alg );
-
-    if( status != PSA_SUCCESS )
-        goto exit;
-
-    psa_key_attributes_t attributes = {
-      .core = slot->attr
-    };
-
-    status = psa_driver_wrapper_aead_decrypt_setup( operation,
-                                                    &attributes, slot->key.data,
-                                                    slot->key.bytes, alg );
-
-    if( status != PSA_SUCCESS )
-        goto exit;
-
-    operation->key_type = psa_get_key_type( &attributes );
-
-exit:
-
-    unlock_status = psa_unlock_key_slot( slot );
-
-    if( status == PSA_SUCCESS )
-    {
-        status = unlock_status;
-        operation->alg = psa_aead_get_base_algorithm( alg );
-        operation->is_encrypt = 0;
-    }
-    else
-        psa_aead_abort( operation );
-
-    return( status );
+    return( psa_aead_setup( operation, key, alg ) );
 }
 
 /* Generate a random nonce / IV for multipart AEAD operation */