Split multipart AEAD contexts into two parts

Split the context into the data required by the internal implementation and
the data required by the driver implementation, leaving the remaining data
for the PSA layer.

Signed-off-by: Paul Elliott <paul.elliott@arm.com>
diff --git a/library/psa_crypto_aead.c b/library/psa_crypto_aead.c
index f5b4dc5..8f8b74e 100644
--- a/library/psa_crypto_aead.c
+++ b/library/psa_crypto_aead.c
@@ -20,7 +20,6 @@
 
 #include "common.h"
 
-
 #if defined(MBEDTLS_PSA_CRYPTO_C)
 
 #include "psa_crypto_aead.h"
@@ -55,7 +54,7 @@
 
 
 static psa_status_t psa_aead_setup(
-    psa_aead_operation_t *operation,
+    mbedtls_psa_aead_operation_t *operation,
     const psa_key_attributes_t *attributes,
     const uint8_t *key_buffer,
     psa_algorithm_t alg )
@@ -66,12 +65,6 @@
     mbedtls_cipher_id_t cipher_id;
     size_t full_tag_length = 0;
 
-    if( operation->key_set || operation->nonce_set ||
-        operation->ad_started || operation->body_started )
-    {
-        return( PSA_ERROR_BAD_STATE );
-    }
-
     key_bits = attributes->core.bits;
 
     cipher_info = mbedtls_cipher_info_from_psa( alg,
@@ -146,12 +139,12 @@
         > full_tag_length )
         return( PSA_ERROR_INVALID_ARGUMENT );
 
-    operation->tag_length = PSA_AEAD_TAG_LENGTH( attributes->core.type,
+    operation->key_type = psa_get_key_type( attributes );
+
+    operation->tag_length = PSA_AEAD_TAG_LENGTH( operation->key_type,
                                                  key_bits,
                                                  alg );
 
-    operation->key_set = 1;
-
     return( PSA_SUCCESS );
 }
 
@@ -165,7 +158,7 @@
     uint8_t *ciphertext, size_t ciphertext_size, size_t *ciphertext_length )
 {
     psa_status_t status = PSA_ERROR_CORRUPTION_DETECTED;
-    psa_aead_operation_t operation = PSA_AEAD_OPERATION_INIT;
+    mbedtls_psa_aead_operation_t operation = MBEDTLS_PSA_AEAD_OPERATION_INIT;
     uint8_t *tag;
     (void) key_buffer_size;
 
@@ -275,7 +268,7 @@
     uint8_t *plaintext, size_t plaintext_size, size_t *plaintext_length )
 {
     psa_status_t status = PSA_ERROR_CORRUPTION_DETECTED;
-    psa_aead_operation_t operation = PSA_AEAD_OPERATION_INIT;
+    mbedtls_psa_aead_operation_t operation = MBEDTLS_PSA_AEAD_OPERATION_INIT;
     const uint8_t *tag = NULL;
     (void) key_buffer_size;
 
@@ -354,7 +347,8 @@
 
 /* Set the key and algorithm for a multipart authenticated encryption
  * operation. */
-psa_status_t mbedtls_psa_aead_encrypt_setup( psa_aead_operation_t *operation,
+psa_status_t mbedtls_psa_aead_encrypt_setup( mbedtls_psa_aead_operation_t
+                                                                    *operation,
                                              const psa_key_attributes_t
                                                                     *attributes,
                                              const uint8_t *key_buffer,
@@ -377,7 +371,8 @@
 
 /* Set the key and algorithm for a multipart authenticated decryption
  * operation. */
-psa_status_t mbedtls_psa_aead_decrypt_setup( psa_aead_operation_t *operation,
+psa_status_t mbedtls_psa_aead_decrypt_setup( mbedtls_psa_aead_operation_t
+                                                                    *operation,
                                              const psa_key_attributes_t
                                                                     *attributes,
                                              const uint8_t *key_buffer,
@@ -399,7 +394,8 @@
 }
 
 /* Set a nonce for the multipart AEAD operation*/
-psa_status_t mbedtls_psa_aead_set_nonce( psa_aead_operation_t *operation,
+psa_status_t mbedtls_psa_aead_set_nonce( mbedtls_psa_aead_operation_t
+                                                                    *operation,
                                          const uint8_t *nonce,
                                          size_t nonce_length )
 {
@@ -454,15 +450,11 @@
         return ( PSA_ERROR_NOT_SUPPORTED );
     }
 
-    if( status == PSA_SUCCESS )
-    {
-        operation->nonce_set = 1;
-    }
-
     return( status );
 }
  /* Declare the lengths of the message and additional data for AEAD. */
-psa_status_t mbedtls_psa_aead_set_lengths( psa_aead_operation_t *operation,
+psa_status_t mbedtls_psa_aead_set_lengths( mbedtls_psa_aead_operation_t
+                                                                    *operation,
                                            size_t ad_length,
                                            size_t plaintext_length )
 {
@@ -512,7 +504,8 @@
 }
 
 /* Pass additional data to an active multipart AEAD operation. */
-psa_status_t mbedtls_psa_aead_update_ad( psa_aead_operation_t *operation,
+psa_status_t mbedtls_psa_aead_update_ad( mbedtls_psa_aead_operation_t
+                                                                    *operation,
                                          const uint8_t *input,
                                          size_t input_length )
 {
@@ -611,7 +604,7 @@
 
 /* Encrypt or decrypt a message fragment in an active multipart AEAD
  * operation.*/
-psa_status_t mbedtls_psa_aead_update( psa_aead_operation_t *operation,
+psa_status_t mbedtls_psa_aead_update( mbedtls_psa_aead_operation_t *operation,
                                       const uint8_t *input,
                                       size_t input_length,
                                       uint8_t *output,
@@ -786,7 +779,7 @@
 
 /* Common checks for both mbedtls_psa_aead_finish() and
    mbedtls_psa_aead_verify() */
-static psa_status_t mbedtls_psa_aead_finish_checks( psa_aead_operation_t
+static psa_status_t mbedtls_psa_aead_finish_checks( mbedtls_psa_aead_operation_t
                                                                     *operation,
                                                     size_t output_size,
                                                     size_t tag_size )
@@ -828,7 +821,7 @@
 }
 
 /* Finish encrypting a message in a multipart AEAD operation. */
-psa_status_t mbedtls_psa_aead_finish( psa_aead_operation_t *operation,
+psa_status_t mbedtls_psa_aead_finish( mbedtls_psa_aead_operation_t *operation,
                                       uint8_t *ciphertext,
                                       size_t ciphertext_size,
                                       size_t *ciphertext_length,
@@ -903,7 +896,7 @@
 
 /* Finish authenticating and decrypting a message in a multipart AEAD
  * operation.*/
-psa_status_t mbedtls_psa_aead_verify( psa_aead_operation_t *operation,
+psa_status_t mbedtls_psa_aead_verify( mbedtls_psa_aead_operation_t *operation,
                                       uint8_t *plaintext,
                                       size_t plaintext_size,
                                       size_t *plaintext_length,
@@ -1033,7 +1026,7 @@
 }
 
 /* Abort an AEAD operation */
-psa_status_t mbedtls_psa_aead_abort( psa_aead_operation_t *operation )
+psa_status_t mbedtls_psa_aead_abort( mbedtls_psa_aead_operation_t *operation )
 {
     switch( operation->alg )
     {
@@ -1054,6 +1047,11 @@
 #endif /* MBEDTLS_PSA_BUILTIN_ALG_CHACHA20_POLY1305 */
     }
 
+    operation->lengths_set = 0;
+    operation->is_encrypt = 0;
+    operation->ad_started = 0;
+    operation->body_started = 0;
+
     mbedtls_free(operation->ad_buffer);
     operation->ad_buffer = NULL;
     operation->ad_length = 0;