Move AEAD length checks to PSA core

Signed-off-by: Paul Elliott <paul.elliott@arm.com>
diff --git a/include/psa/crypto_builtin_primitives.h b/include/psa/crypto_builtin_primitives.h
index e3903bc..b28e0d7 100644
--- a/include/psa/crypto_builtin_primitives.h
+++ b/include/psa/crypto_builtin_primitives.h
@@ -130,7 +130,6 @@
     psa_algorithm_t alg;
     psa_key_type_t key_type;
 
-    unsigned int lengths_set : 1;
     unsigned int is_encrypt : 1;
     unsigned int ad_started : 1;
     unsigned int body_started : 1;
@@ -138,9 +137,6 @@
     uint8_t tag_length;
     uint8_t nonce_length;
 
-    size_t ad_remaining;
-    size_t body_remaining;
-
     /* Buffers for AD/data - only required until CCM gets proper multipart
        support. */
     uint8_t *ad_buffer;
@@ -172,7 +168,7 @@
 
 } mbedtls_psa_aead_operation_t;
 
-#define MBEDTLS_PSA_AEAD_OPERATION_INIT {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, {0}, {0}}
+#define MBEDTLS_PSA_AEAD_OPERATION_INIT {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, {0}, {0}}
 
 /*
  * BEYOND THIS POINT, TEST DRIVER DECLARATIONS ONLY.
diff --git a/include/psa/crypto_struct.h b/include/psa/crypto_struct.h
index 36503f9..0f74c54 100644
--- a/include/psa/crypto_struct.h
+++ b/include/psa/crypto_struct.h
@@ -165,6 +165,9 @@
     psa_algorithm_t alg;
     psa_key_type_t key_type;
 
+    size_t ad_remaining;
+    size_t body_remaining;
+
     unsigned int nonce_set : 1;
     unsigned int lengths_set : 1;
     unsigned int ad_started : 1;
@@ -173,7 +176,7 @@
     psa_driver_aead_context_t ctx;
 };
 
-#define PSA_AEAD_OPERATION_INIT {0, 0, 0, 0, 0, 0, 0, {0}}
+#define PSA_AEAD_OPERATION_INIT {0, 0, 0, 0, 0, 0, 0, 0, 0, {0}}
 static inline struct psa_aead_operation_s psa_aead_operation_init( void )
 {
     const struct psa_aead_operation_s v = PSA_AEAD_OPERATION_INIT;
diff --git a/library/psa_crypto.c b/library/psa_crypto.c
index e97cbaf..c53020a 100644
--- a/library/psa_crypto.c
+++ b/library/psa_crypto.c
@@ -3467,7 +3467,11 @@
 exit:
 
     if( status == PSA_SUCCESS )
+    {
+        operation->ad_remaining = ad_length;
+        operation->body_remaining = plaintext_length;
         operation->lengths_set = 1;
+    }
     else
         psa_aead_abort( operation );
 
@@ -3492,6 +3496,17 @@
         goto exit;
     }
 
+    if( operation->lengths_set )
+    {
+        if ( operation->ad_remaining < input_length )
+        {
+            status = PSA_ERROR_INVALID_ARGUMENT;
+            goto exit;
+        }
+
+        operation->ad_remaining -= input_length;
+    }
+
     status = psa_driver_wrapper_aead_update_ad( operation, input,
                                                 input_length );
 
@@ -3530,6 +3545,26 @@
         goto exit;
     }
 
+    if( operation->lengths_set )
+    {
+        /* Additional data length was supplied, but not all the additional
+           data was supplied.*/
+        if( operation->ad_remaining != 0 )
+        {
+            status = PSA_ERROR_INVALID_ARGUMENT;
+            goto exit;
+        }
+
+        /* Too much data provided. */
+        if( operation->body_remaining < input_length )
+        {
+            status = PSA_ERROR_INVALID_ARGUMENT;
+            goto exit;
+        }
+
+        operation->body_remaining -= input_length;
+    }
+
     status = psa_driver_wrapper_aead_update( operation, input, input_length,
                                              output, output_size,
                                              output_length );
@@ -3571,6 +3606,13 @@
         goto exit;
     }
 
+    if( operation->lengths_set && (operation->ad_remaining != 0 ||
+                                   operation->body_remaining != 0 ) )
+    {
+        status = PSA_ERROR_BAD_STATE;
+        goto exit;
+    }
+
     status = psa_driver_wrapper_aead_finish( operation, ciphertext,
                                              ciphertext_size,
                                              ciphertext_length,
@@ -3609,6 +3651,13 @@
         goto exit;
     }
 
+    if( operation->lengths_set && (operation->ad_remaining != 0 ||
+                                   operation->body_remaining != 0 ) )
+    {
+        status = PSA_ERROR_BAD_STATE;
+        goto exit;
+    }
+
     status = psa_driver_wrapper_aead_verify( operation, plaintext,
                                              plaintext_size,
                                              plaintext_length,
diff --git a/library/psa_crypto_aead.c b/library/psa_crypto_aead.c
index 0daa303..bbfc927 100644
--- a/library/psa_crypto_aead.c
+++ b/library/psa_crypto_aead.c
@@ -481,10 +481,6 @@
         return ( PSA_ERROR_NOT_SUPPORTED );
     }
 
-    operation->ad_remaining = ad_length;
-    operation->body_remaining = plaintext_length;
-    operation->lengths_set = 1;
-
     return ( PSA_SUCCESS );
 }
 
@@ -496,14 +492,6 @@
 {
     psa_status_t status = PSA_ERROR_CORRUPTION_DETECTED;
 
-    if( operation->lengths_set )
-    {
-        if ( operation->ad_remaining < input_length )
-            return( PSA_ERROR_INVALID_ARGUMENT );
-
-        operation->ad_remaining -= input_length;
-    }
-
 #if defined(MBEDTLS_PSA_BUILTIN_ALG_GCM)
     if( operation->alg == PSA_ALG_GCM )
     {
@@ -590,20 +578,6 @@
                                         input_length ) > output_size )
         return ( PSA_ERROR_BUFFER_TOO_SMALL );
 
-    if( operation->lengths_set)
-    {
-        /* Additional data length was supplied, but not all the additional
-           data was supplied.*/
-        if( operation->ad_remaining != 0 )
-            return ( PSA_ERROR_INVALID_ARGUMENT );
-
-        /* Too much data provided. */
-        if( operation->body_remaining < input_length )
-            return ( PSA_ERROR_INVALID_ARGUMENT );
-
-        operation->body_remaining -= input_length;
-    }
-
 #if defined(MBEDTLS_PSA_BUILTIN_ALG_GCM)
     if( operation->alg == PSA_ALG_GCM )
     {
@@ -725,10 +699,6 @@
 {
     size_t finish_output_size;
 
-    if( operation->lengths_set )
-        if( operation->ad_remaining != 0 || operation->body_remaining != 0 )
-            return( PSA_ERROR_BAD_STATE );
-
     if( tag_size < operation->tag_length )
         return ( PSA_ERROR_BUFFER_TOO_SMALL );
 
@@ -934,7 +904,6 @@
 #endif /* MBEDTLS_PSA_BUILTIN_ALG_CHACHA20_POLY1305 */
     }
 
-    operation->lengths_set = 0;
     operation->is_encrypt = 0;
     operation->ad_started = 0;
     operation->body_started = 0;