Unify multipart cipher operation tester functions

Signed-off-by: gabor-mezei-arm <gabor.mezei@arm.com>
diff --git a/tests/suites/test_suite_psa_crypto.function b/tests/suites/test_suite_psa_crypto.function
index a79ae43..d7e7b06 100644
--- a/tests/suites/test_suite_psa_crypto.function
+++ b/tests/suites/test_suite_psa_crypto.function
@@ -2589,94 +2589,19 @@
 /* END_CASE */
 
 /* BEGIN_CASE */
-void cipher_encrypt_simple_multipart( int alg_arg,
-                                      int key_type_arg,
-                                      data_t *key_data,
-                                      data_t *iv,
-                                      data_t *input,
-                                      data_t *expected_output,
-                                      int expected_status_arg )
-{
-    mbedtls_svc_key_id_t key = MBEDTLS_SVC_KEY_ID_INIT;
-    psa_status_t status;
-    psa_key_type_t key_type = key_type_arg;
-    psa_algorithm_t alg = alg_arg;
-    psa_status_t expected_status = expected_status_arg;
-    unsigned char *output = NULL;
-    size_t output_buffer_size = 0;
-    size_t function_output_length = 0;
-    size_t total_output_length = 0;
-    psa_cipher_operation_t operation = PSA_CIPHER_OPERATION_INIT;
-    psa_key_attributes_t attributes = PSA_KEY_ATTRIBUTES_INIT;
-
-    PSA_ASSERT( psa_crypto_init( ) );
-
-    psa_set_key_usage_flags( &attributes, PSA_KEY_USAGE_ENCRYPT );
-    psa_set_key_algorithm( &attributes, alg );
-    psa_set_key_type( &attributes, key_type );
-
-    PSA_ASSERT( psa_import_key( &attributes, key_data->x, key_data->len,
-                                &key ) );
-
-    PSA_ASSERT( psa_cipher_encrypt_setup( &operation, key, alg ) );
-
-    if( iv->len > 0 )
-    {
-        PSA_ASSERT( psa_cipher_set_iv( &operation, iv->x, iv->len ) );
-    }
-
-    output_buffer_size = PSA_CIPHER_UPDATE_OUTPUT_SIZE( key_type, alg, input->len ) +
-                         PSA_CIPHER_FINISH_OUTPUT_SIZE( key_type, alg );
-    ASSERT_ALLOC( output, output_buffer_size );
-
-    PSA_ASSERT( psa_cipher_update( &operation,
-                                   input->x, input->len,
-                                   output, output_buffer_size,
-                                   &function_output_length ) );
-    TEST_ASSERT( function_output_length <=
-                 PSA_CIPHER_UPDATE_OUTPUT_SIZE( key_type, alg, input->len ) );
-    TEST_ASSERT( function_output_length <=
-                 PSA_CIPHER_UPDATE_OUTPUT_MAX_SIZE( input->len ) );
-    total_output_length += function_output_length;
-
-    status = psa_cipher_finish( &operation,
-                                ( output_buffer_size == 0 ? NULL :
-                                  output + total_output_length ),
-                                output_buffer_size - total_output_length,
-                                &function_output_length );
-    TEST_ASSERT( function_output_length <=
-                 PSA_CIPHER_FINISH_OUTPUT_SIZE( key_type, alg ) );
-    TEST_ASSERT( function_output_length <=
-                 PSA_CIPHER_FINISH_OUTPUT_MAX_SIZE );
-    total_output_length += function_output_length;
-
-    TEST_EQUAL( status, expected_status );
-    if( expected_status == PSA_SUCCESS )
-    {
-        PSA_ASSERT( psa_cipher_abort( &operation ) );
-        ASSERT_COMPARE( expected_output->x, expected_output->len,
-                        output, total_output_length );
-    }
-
-exit:
-    psa_cipher_abort( &operation );
-    mbedtls_free( output );
-    psa_destroy_key( key );
-    PSA_DONE( );
-}
-/* END_CASE */
-
-/* BEGIN_CASE */
 void cipher_encrypt_multipart( int alg_arg, int key_type_arg,
                                data_t *key_data, data_t *iv,
                                data_t *input,
                                int first_part_size_arg,
                                int output1_length_arg, int output2_length_arg,
-                               data_t *expected_output )
+                               data_t *expected_output,
+                               int expected_status_arg )
 {
     mbedtls_svc_key_id_t key = MBEDTLS_SVC_KEY_ID_INIT;
     psa_key_type_t key_type = key_type_arg;
     psa_algorithm_t alg = alg_arg;
+    psa_status_t status;
+    psa_status_t expected_status = expected_status_arg;
     size_t first_part_size = first_part_size_arg;
     size_t output1_length = output1_length_arg;
     size_t output2_length = output2_length_arg;
@@ -2718,36 +2643,44 @@
                  PSA_CIPHER_UPDATE_OUTPUT_MAX_SIZE( first_part_size) );
     total_output_length += function_output_length;
 
-    PSA_ASSERT( psa_cipher_update( &operation,
-                                   input->x + first_part_size,
-                                   input->len - first_part_size,
-                                   ( output_buffer_size == 0 ? NULL :
-                                     output + total_output_length ),
-                                   output_buffer_size - total_output_length,
-                                   &function_output_length ) );
-    TEST_ASSERT( function_output_length == output2_length );
-    TEST_ASSERT( function_output_length <=
-                 PSA_CIPHER_UPDATE_OUTPUT_SIZE( key_type,
-                                                alg,
-                                                input->len - first_part_size ) );
-    TEST_ASSERT( function_output_length <=
-                 PSA_CIPHER_UPDATE_OUTPUT_MAX_SIZE( input->len ) );
-    total_output_length += function_output_length;
+    if( first_part_size < input->len )
+    {
+        PSA_ASSERT( psa_cipher_update( &operation,
+                                       input->x + first_part_size,
+                                       input->len - first_part_size,
+                                       ( output_buffer_size == 0 ? NULL :
+                                         output + total_output_length ),
+                                       output_buffer_size - total_output_length,
+                                       &function_output_length ) );
+        TEST_ASSERT( function_output_length == output2_length );
+        TEST_ASSERT( function_output_length <=
+                     PSA_CIPHER_UPDATE_OUTPUT_SIZE( key_type,
+                                                    alg,
+                                                    input->len - first_part_size ) );
+        TEST_ASSERT( function_output_length <=
+                     PSA_CIPHER_UPDATE_OUTPUT_MAX_SIZE( input->len ) );
+        total_output_length += function_output_length;
+    }
 
-    PSA_ASSERT( psa_cipher_finish( &operation,
-                                   ( output_buffer_size == 0 ? NULL :
-                                     output + total_output_length ),
-                                   output_buffer_size - total_output_length,
-                                   &function_output_length ) );
+    status = psa_cipher_finish( &operation,
+                                ( output_buffer_size == 0 ? NULL :
+                                  output + total_output_length ),
+                                output_buffer_size - total_output_length,
+                                &function_output_length );
     TEST_ASSERT( function_output_length <=
                  PSA_CIPHER_FINISH_OUTPUT_SIZE( key_type, alg ) );
     TEST_ASSERT( function_output_length <=
                  PSA_CIPHER_FINISH_OUTPUT_MAX_SIZE );
     total_output_length += function_output_length;
-    PSA_ASSERT( psa_cipher_abort( &operation ) );
+    TEST_EQUAL( status, expected_status );
 
-    ASSERT_COMPARE( expected_output->x, expected_output->len,
-                    output, total_output_length );
+    if( expected_status == PSA_SUCCESS )
+    {
+        PSA_ASSERT( psa_cipher_abort( &operation ) );
+
+        ASSERT_COMPARE( expected_output->x, expected_output->len,
+                        output, total_output_length );
+    }
 
 exit:
     psa_cipher_abort( &operation );
@@ -2763,11 +2696,14 @@
                                data_t *input,
                                int first_part_size_arg,
                                int output1_length_arg, int output2_length_arg,
-                               data_t *expected_output )
+                               data_t *expected_output,
+                               int expected_status_arg )
 {
     mbedtls_svc_key_id_t key = MBEDTLS_SVC_KEY_ID_INIT;
     psa_key_type_t key_type = key_type_arg;
     psa_algorithm_t alg = alg_arg;
+    psa_status_t status;
+    psa_status_t expected_status = expected_status_arg;
     size_t first_part_size = first_part_size_arg;
     size_t output1_length = output1_length_arg;
     size_t output2_length = output2_length_arg;
@@ -2810,96 +2746,25 @@
                  PSA_CIPHER_UPDATE_OUTPUT_MAX_SIZE( first_part_size ) );
     total_output_length += function_output_length;
 
-    PSA_ASSERT( psa_cipher_update( &operation,
-                                   input->x + first_part_size,
-                                   input->len - first_part_size,
-                                   ( output_buffer_size == 0 ? NULL :
-                                     output + total_output_length ),
-                                   output_buffer_size - total_output_length,
-                                   &function_output_length ) );
-    TEST_ASSERT( function_output_length == output2_length );
-    TEST_ASSERT( function_output_length <=
-                 PSA_CIPHER_UPDATE_OUTPUT_SIZE( key_type,
-                                                alg,
-                                                input->len - first_part_size ) );
-    TEST_ASSERT( function_output_length <=
-                 PSA_CIPHER_UPDATE_OUTPUT_MAX_SIZE( input->len ) );
-    total_output_length += function_output_length;
-
-    PSA_ASSERT( psa_cipher_finish( &operation,
-                                   ( output_buffer_size == 0 ? NULL :
-                                     output + total_output_length ),
-                                   output_buffer_size - total_output_length,
-                                   &function_output_length ) );
-    TEST_ASSERT( function_output_length <=
-                 PSA_CIPHER_FINISH_OUTPUT_SIZE( key_type, alg ) );
-    TEST_ASSERT( function_output_length <=
-                 PSA_CIPHER_FINISH_OUTPUT_MAX_SIZE );
-    total_output_length += function_output_length;
-    PSA_ASSERT( psa_cipher_abort( &operation ) );
-
-    ASSERT_COMPARE( expected_output->x, expected_output->len,
-                    output, total_output_length );
-
-exit:
-    psa_cipher_abort( &operation );
-    mbedtls_free( output );
-    psa_destroy_key( key );
-    PSA_DONE( );
-}
-/* END_CASE */
-
-/* BEGIN_CASE */
-void cipher_decrypt_simple_multipart( int alg_arg,
-                                      int key_type_arg,
-                                      data_t *key_data,
-                                      data_t *iv,
-                                      data_t *input,
-                                      data_t *expected_output,
-                                      int expected_status_arg )
-{
-    mbedtls_svc_key_id_t key = MBEDTLS_SVC_KEY_ID_INIT;
-    psa_status_t status;
-    psa_key_type_t key_type = key_type_arg;
-    psa_algorithm_t alg = alg_arg;
-    psa_status_t expected_status = expected_status_arg;
-    unsigned char *output = NULL;
-    size_t output_buffer_size = 0;
-    size_t function_output_length = 0;
-    size_t total_output_length = 0;
-    psa_cipher_operation_t operation = PSA_CIPHER_OPERATION_INIT;
-    psa_key_attributes_t attributes = PSA_KEY_ATTRIBUTES_INIT;
-
-    PSA_ASSERT( psa_crypto_init( ) );
-
-    psa_set_key_usage_flags( &attributes, PSA_KEY_USAGE_DECRYPT );
-    psa_set_key_algorithm( &attributes, alg );
-    psa_set_key_type( &attributes, key_type );
-
-    PSA_ASSERT( psa_import_key( &attributes, key_data->x, key_data->len,
-                                &key ) );
-
-    PSA_ASSERT( psa_cipher_decrypt_setup( &operation, key, alg ) );
-
-    if( iv->len > 0 )
+    if( first_part_size < input->len )
     {
-        PSA_ASSERT( psa_cipher_set_iv( &operation, iv->x, iv->len ) );
+        PSA_ASSERT( psa_cipher_update( &operation,
+                                       input->x + first_part_size,
+                                       input->len - first_part_size,
+                                       ( output_buffer_size == 0 ? NULL :
+                                         output + total_output_length ),
+                                       output_buffer_size - total_output_length,
+                                       &function_output_length ) );
+        TEST_ASSERT( function_output_length == output2_length );
+        TEST_ASSERT( function_output_length <=
+                     PSA_CIPHER_UPDATE_OUTPUT_SIZE( key_type,
+                                                    alg,
+                                                    input->len - first_part_size ) );
+        TEST_ASSERT( function_output_length <=
+                     PSA_CIPHER_UPDATE_OUTPUT_MAX_SIZE( input->len ) );
+        total_output_length += function_output_length;
     }
 
-    output_buffer_size = PSA_CIPHER_UPDATE_OUTPUT_SIZE( key_type, alg, input->len ) +
-                         PSA_CIPHER_FINISH_OUTPUT_SIZE( key_type, alg );
-    ASSERT_ALLOC( output, output_buffer_size );
-
-    PSA_ASSERT( psa_cipher_update( &operation,
-                                   input->x, input->len,
-                                   output, output_buffer_size,
-                                   &function_output_length ) );
-    TEST_ASSERT( function_output_length <=
-                 PSA_CIPHER_UPDATE_OUTPUT_SIZE( key_type, alg, input->len ) );
-    TEST_ASSERT( function_output_length <=
-                 PSA_CIPHER_UPDATE_OUTPUT_MAX_SIZE( input->len ) );
-    total_output_length += function_output_length;
-
     status = psa_cipher_finish( &operation,
                                 ( output_buffer_size == 0 ? NULL :
                                   output + total_output_length ),
@@ -2915,6 +2780,7 @@
     if( expected_status == PSA_SUCCESS )
     {
         PSA_ASSERT( psa_cipher_abort( &operation ) );
+
         ASSERT_COMPARE( expected_output->x, expected_output->len,
                         output, total_output_length );
     }