Move JPAKE rounds into a common function, add reordering and error injection

Signed-off-by: Neil Armstrong <narmstrong@baylibre.com>
diff --git a/tests/suites/test_suite_psa_crypto.data b/tests/suites/test_suite_psa_crypto.data
index e571e51..fef475a 100644
--- a/tests/suites/test_suite_psa_crypto.data
+++ b/tests/suites/test_suite_psa_crypto.data
@@ -6496,4 +6496,8 @@
 
 PSA PAKE: ecjpake rounds
 depends_on:PSA_WANT_KEY_TYPE_ECC_KEY_PAIR:PSA_WANT_ECC_SECP_R1_256:PSA_WANT_ALG_SHA_256
-ecjpake_rounds:PSA_ALG_JPAKE:PSA_PAKE_PRIMITIVE(PSA_PAKE_PRIMITIVE_TYPE_ECC, PSA_ECC_FAMILY_SECP_R1, 256):PSA_ALG_SHA_256:PSA_ALG_TLS12_PSK_TO_MS(PSA_ALG_SHA_256):"abcdef"
+ecjpake_rounds:PSA_ALG_JPAKE:PSA_PAKE_PRIMITIVE(PSA_PAKE_PRIMITIVE_TYPE_ECC, PSA_ECC_FAMILY_SECP_R1, 256):PSA_ALG_SHA_256:PSA_ALG_TLS12_PSK_TO_MS(PSA_ALG_SHA_256):"abcdef":0
+
+PSA PAKE: ecjpake rounds, client input first
+depends_on:PSA_WANT_KEY_TYPE_ECC_KEY_PAIR:PSA_WANT_ECC_SECP_R1_256:PSA_WANT_ALG_SHA_256
+ecjpake_rounds:PSA_ALG_JPAKE:PSA_PAKE_PRIMITIVE(PSA_PAKE_PRIMITIVE_TYPE_ECC, PSA_ECC_FAMILY_SECP_R1, 256):PSA_ALG_SHA_256:PSA_ALG_TLS12_PSK_TO_MS(PSA_ALG_SHA_256):"abcdef":1
diff --git a/tests/suites/test_suite_psa_crypto.function b/tests/suites/test_suite_psa_crypto.function
index 6d4f2a8..cf7ea7d 100644
--- a/tests/suites/test_suite_psa_crypto.function
+++ b/tests/suites/test_suite_psa_crypto.function
@@ -705,6 +705,296 @@
     return( test_ok );
 }
 
+static int ecjpake_do_round( psa_algorithm_t alg, unsigned int primitive,
+                             psa_pake_operation_t *server,
+                             psa_pake_operation_t *client,
+                             int client_input_first,
+                             int round, int inject_error )
+{
+    unsigned char *buffer0 = NULL, *buffer1 = NULL;
+    size_t buffer_length = (
+        PSA_PAKE_OUTPUT_SIZE(alg, primitive, PSA_PAKE_STEP_KEY_SHARE) +
+        PSA_PAKE_OUTPUT_SIZE(alg, primitive, PSA_PAKE_STEP_ZK_PUBLIC) +
+        PSA_PAKE_OUTPUT_SIZE(alg, primitive, PSA_PAKE_STEP_ZK_PROOF)) * 2;
+    size_t buffer0_off = 0;
+    size_t buffer1_off = 0;
+    size_t s_g1_len, s_g2_len, s_a_len;
+    size_t s_g1_off, s_g2_off, s_a_off;
+    size_t s_x1_pk_len, s_x2_pk_len, s_x2s_pk_len;
+    size_t s_x1_pk_off, s_x2_pk_off, s_x2s_pk_off;
+    size_t s_x1_pr_len, s_x2_pr_len, s_x2s_pr_len;
+    size_t s_x1_pr_off, s_x2_pr_off, s_x2s_pr_off;
+    size_t c_g1_len, c_g2_len, c_a_len;
+    size_t c_g1_off, c_g2_off, c_a_off;
+    size_t c_x1_pk_len, c_x2_pk_len, c_x2s_pk_len;
+    size_t c_x1_pk_off, c_x2_pk_off, c_x2s_pk_off;
+    size_t c_x1_pr_len, c_x2_pr_len, c_x2s_pr_len;
+    size_t c_x1_pr_off, c_x2_pr_off, c_x2s_pr_off;
+    psa_status_t expected_status = PSA_SUCCESS;
+    int ret;
+
+    ASSERT_ALLOC( buffer0, buffer_length );
+    ASSERT_ALLOC( buffer1, buffer_length );
+
+    switch( round )
+    {
+        case 1:
+            /* Server first round Output */
+            PSA_ASSERT( psa_pake_output( server, PSA_PAKE_STEP_KEY_SHARE,
+                                         buffer0 + buffer0_off,
+                                         512 - buffer0_off, &s_g1_len ) );
+            s_g1_off = buffer0_off;
+            buffer0_off += s_g1_len;
+            PSA_ASSERT( psa_pake_output( server, PSA_PAKE_STEP_ZK_PUBLIC,
+                                         buffer0 + buffer0_off,
+                                         512 - buffer0_off, &s_x1_pk_len ) );
+            s_x1_pk_off = buffer0_off;
+            buffer0_off += s_x1_pk_len;
+            PSA_ASSERT( psa_pake_output( server, PSA_PAKE_STEP_ZK_PROOF,
+                                         buffer0 + buffer0_off,
+                                         512 - buffer0_off, &s_x1_pr_len ) );
+            s_x1_pr_off = buffer0_off;
+            buffer0_off += s_x1_pr_len;
+            PSA_ASSERT( psa_pake_output( server, PSA_PAKE_STEP_KEY_SHARE,
+                                         buffer0 + buffer0_off,
+                                         512 - buffer0_off, &s_g2_len ) );
+            s_g2_off = buffer0_off;
+            buffer0_off += s_g2_len;
+            PSA_ASSERT( psa_pake_output( server, PSA_PAKE_STEP_ZK_PUBLIC,
+                                         buffer0 + buffer0_off,
+                                         512 - buffer0_off, &s_x2_pk_len ) );
+            s_x2_pk_off = buffer0_off;
+            buffer0_off += s_x2_pk_len;
+            PSA_ASSERT( psa_pake_output( server, PSA_PAKE_STEP_ZK_PROOF,
+                                         buffer0 + buffer0_off,
+                                         512 - buffer0_off, &s_x2_pr_len ) );
+            s_x2_pr_off = buffer0_off;
+            buffer0_off += s_x2_pr_len;
+
+            if( inject_error == 1 )
+            {
+                buffer0[s_x1_pk_off + 12] >>= 4;
+                buffer0[s_x2_pk_off + 7] <<= 4;
+                expected_status = PSA_ERROR_DATA_INVALID;
+            }
+
+            if( client_input_first == 1 )
+            {
+                /* Client first round Input */
+                PSA_ASSERT( psa_pake_input( client, PSA_PAKE_STEP_KEY_SHARE,
+                                            buffer0 + s_g1_off, s_g1_len ) );
+                PSA_ASSERT( psa_pake_input( client, PSA_PAKE_STEP_ZK_PUBLIC,
+                                            buffer0 + s_x1_pk_off,
+                                            s_x1_pk_len ) );
+                PSA_ASSERT( psa_pake_input( client, PSA_PAKE_STEP_ZK_PROOF,
+                                            buffer0 + s_x1_pr_off,
+                                            s_x1_pr_len ) );
+                PSA_ASSERT( psa_pake_input( client, PSA_PAKE_STEP_KEY_SHARE,
+                                            buffer0 + s_g2_off,
+                                            s_g2_len ) );
+                PSA_ASSERT( psa_pake_input( client, PSA_PAKE_STEP_ZK_PUBLIC,
+                                            buffer0 + s_x2_pk_off,
+                                            s_x2_pk_len ) );
+                TEST_EQUAL( psa_pake_input( client, PSA_PAKE_STEP_ZK_PROOF,
+                                            buffer0 + s_x2_pr_off,
+                                            s_x2_pr_len ),
+                            expected_status );
+
+                if( inject_error == 1 )
+                {
+                    ret = 1;
+                    goto exit;
+                }
+            }
+
+            /* Client first round Output */
+            PSA_ASSERT( psa_pake_output( client, PSA_PAKE_STEP_KEY_SHARE,
+                                         buffer1 + buffer1_off,
+                                         512 - buffer1_off, &c_g1_len ) );
+            c_g1_off = buffer1_off;
+            buffer1_off += c_g1_len;
+            PSA_ASSERT( psa_pake_output( client, PSA_PAKE_STEP_ZK_PUBLIC,
+                                         buffer1 + buffer1_off,
+                                         512 - buffer1_off, &c_x1_pk_len ) );
+            c_x1_pk_off = buffer1_off;
+            buffer1_off += c_x1_pk_len;
+            PSA_ASSERT( psa_pake_output( client, PSA_PAKE_STEP_ZK_PROOF,
+                                         buffer1 + buffer1_off,
+                                         512 - buffer1_off, &c_x1_pr_len ) );
+            c_x1_pr_off = buffer1_off;
+            buffer1_off += c_x1_pr_len;
+            PSA_ASSERT( psa_pake_output( client, PSA_PAKE_STEP_KEY_SHARE,
+                                         buffer1 + buffer1_off,
+                                         512 - buffer1_off, &c_g2_len ) );
+            c_g2_off = buffer1_off;
+            buffer1_off += c_g2_len;
+            PSA_ASSERT( psa_pake_output( client, PSA_PAKE_STEP_ZK_PUBLIC,
+                                         buffer1 + buffer1_off,
+                                         512 - buffer1_off, &c_x2_pk_len ) );
+            c_x2_pk_off = buffer1_off;
+            buffer1_off += c_x2_pk_len;
+            PSA_ASSERT( psa_pake_output( client, PSA_PAKE_STEP_ZK_PROOF,
+                                         buffer1 + buffer1_off,
+                                         512 - buffer1_off, &c_x2_pr_len ) );
+            c_x2_pr_off = buffer1_off;
+            buffer1_off += c_x2_pr_len;
+
+            if( client_input_first == 0 )
+            {
+                /* Client first round Input */
+                PSA_ASSERT( psa_pake_input( client, PSA_PAKE_STEP_KEY_SHARE,
+                                            buffer0 + s_g1_off, s_g1_len ) );
+                PSA_ASSERT( psa_pake_input( client, PSA_PAKE_STEP_ZK_PUBLIC,
+                                            buffer0 + s_x1_pk_off,
+                                            s_x1_pk_len ) );
+                PSA_ASSERT( psa_pake_input( client, PSA_PAKE_STEP_ZK_PROOF,
+                                            buffer0 + s_x1_pr_off,
+                                            s_x1_pr_len ) );
+                PSA_ASSERT( psa_pake_input( client, PSA_PAKE_STEP_KEY_SHARE,
+                                            buffer0 + s_g2_off,
+                                            s_g2_len ) );
+                PSA_ASSERT( psa_pake_input( client, PSA_PAKE_STEP_ZK_PUBLIC,
+                                            buffer0 + s_x2_pk_off,
+                                            s_x2_pk_len ) );
+                TEST_EQUAL( psa_pake_input( client, PSA_PAKE_STEP_ZK_PROOF,
+                                            buffer0 + s_x2_pr_off,
+                                            s_x2_pr_len ),
+                            expected_status );
+
+                if( inject_error == 1 )
+                    break;
+            }
+
+            if( inject_error == 2 )
+            {
+                buffer1[c_x1_pk_off + 12] >>= 4;
+                buffer1[c_x2_pk_off + 7] <<= 4;
+                expected_status = PSA_ERROR_DATA_INVALID;
+            }
+
+            /* Server first round Input */
+            PSA_ASSERT( psa_pake_input( server, PSA_PAKE_STEP_KEY_SHARE,
+                                        buffer1 + c_g1_off, c_g1_len ) );
+            PSA_ASSERT( psa_pake_input( server, PSA_PAKE_STEP_ZK_PUBLIC,
+                                        buffer1 + c_x1_pk_off, c_x1_pk_len ) );
+            PSA_ASSERT( psa_pake_input( server, PSA_PAKE_STEP_ZK_PROOF,
+                                        buffer1 + c_x1_pr_off, c_x1_pr_len ) );
+            PSA_ASSERT( psa_pake_input( server, PSA_PAKE_STEP_KEY_SHARE,
+                                        buffer1 + c_g2_off, c_g2_len ) );
+            PSA_ASSERT( psa_pake_input( server, PSA_PAKE_STEP_ZK_PUBLIC,
+                                        buffer1 + c_x2_pk_off, c_x2_pk_len ) );
+            TEST_EQUAL( psa_pake_input( server, PSA_PAKE_STEP_ZK_PROOF,
+                                        buffer1 + c_x2_pr_off, c_x2_pr_len ),
+                        expected_status );
+
+            break;
+
+        case 2:
+            /* Server second round Output */
+            buffer0_off = 0;
+
+            PSA_ASSERT( psa_pake_output( server, PSA_PAKE_STEP_KEY_SHARE,
+                                         buffer0 + buffer0_off,
+                                         512 - buffer0_off, &s_a_len ) );
+            s_a_off = buffer0_off;
+            buffer0_off += s_a_len;
+            PSA_ASSERT( psa_pake_output( server, PSA_PAKE_STEP_ZK_PUBLIC,
+                                         buffer0 + buffer0_off,
+                                         512 - buffer0_off, &s_x2s_pk_len ) );
+            s_x2s_pk_off = buffer0_off;
+            buffer0_off += s_x2s_pk_len;
+            PSA_ASSERT( psa_pake_output( server, PSA_PAKE_STEP_ZK_PROOF,
+                                         buffer0 + buffer0_off,
+                                         512 - buffer0_off, &s_x2s_pr_len ) );
+            s_x2s_pr_off = buffer0_off;
+            buffer0_off += s_x2s_pr_len;
+
+            if( inject_error == 3 )
+            {
+                buffer0[s_x2s_pk_off + 12] >>= 4;
+                expected_status = PSA_ERROR_DATA_INVALID;
+            }
+
+            if( client_input_first == 1 )
+            {
+                /* Client second round Input */
+                PSA_ASSERT( psa_pake_input( client, PSA_PAKE_STEP_KEY_SHARE,
+                                            buffer0 + s_a_off, s_a_len ) );
+                PSA_ASSERT( psa_pake_input( client, PSA_PAKE_STEP_ZK_PUBLIC,
+                                            buffer0 + s_x2s_pk_off,
+                                            s_x2s_pk_len ) );
+                TEST_EQUAL( psa_pake_input( client, PSA_PAKE_STEP_ZK_PROOF,
+                                            buffer0 + s_x2s_pr_off,
+                                            s_x2s_pr_len ),
+                            expected_status );
+
+                if( inject_error == 3 )
+                    break;
+            }
+
+            /* Client second round Output */
+            buffer1_off = 0;
+
+            PSA_ASSERT( psa_pake_output( client, PSA_PAKE_STEP_KEY_SHARE,
+                                         buffer1 + buffer1_off,
+                                         512 - buffer1_off, &c_a_len ) );
+            c_a_off = buffer1_off;
+            buffer1_off += c_a_len;
+            PSA_ASSERT( psa_pake_output( client, PSA_PAKE_STEP_ZK_PUBLIC,
+                                         buffer1 + buffer1_off,
+                                         512 - buffer1_off, &c_x2s_pk_len ) );
+            c_x2s_pk_off = buffer1_off;
+            buffer1_off += c_x2s_pk_len;
+            PSA_ASSERT( psa_pake_output( client, PSA_PAKE_STEP_ZK_PROOF,
+                                         buffer1 + buffer1_off,
+                                         512 - buffer1_off, &c_x2s_pr_len ) );
+            c_x2s_pr_off = buffer1_off;
+            buffer1_off += c_x2s_pr_len;
+
+            if( client_input_first == 0 )
+            {
+                /* Client second round Input */
+                PSA_ASSERT( psa_pake_input( client, PSA_PAKE_STEP_KEY_SHARE,
+                                            buffer0 + s_a_off, s_a_len ) );
+                PSA_ASSERT( psa_pake_input( client, PSA_PAKE_STEP_ZK_PUBLIC,
+                                            buffer0 + s_x2s_pk_off,
+                                            s_x2s_pk_len ) );
+                TEST_EQUAL( psa_pake_input( client, PSA_PAKE_STEP_ZK_PROOF,
+                                            buffer0 + s_x2s_pr_off,
+                                            s_x2s_pr_len ),
+                            expected_status );
+
+                if( inject_error == 3 )
+                    break;
+            }
+
+            if( inject_error == 4 )
+            {
+                buffer1[c_x2s_pk_off + 12] >>= 4;
+                expected_status = PSA_ERROR_DATA_INVALID;
+            }
+
+            /* Server second round Input */
+            PSA_ASSERT( psa_pake_input( server, PSA_PAKE_STEP_KEY_SHARE,
+                                        buffer1 + c_a_off, c_a_len ) );
+            PSA_ASSERT( psa_pake_input( server, PSA_PAKE_STEP_ZK_PUBLIC,
+                                        buffer1 + c_x2s_pk_off, c_x2s_pk_len ) );
+            TEST_EQUAL( psa_pake_input( server, PSA_PAKE_STEP_ZK_PROOF,
+                                        buffer1 + c_x2s_pr_off, c_x2s_pr_len ),
+                        expected_status );
+
+            break;
+
+    }
+
+    ret = 1;
+
+exit:
+    mbedtls_free( buffer0 );
+    mbedtls_free( buffer1 );
+    return( ret );
+}
+
 /* END_HEADER */
 
 /* BEGIN_DEPENDENCIES
@@ -8267,7 +8557,8 @@
 
 /* BEGIN_CASE depends_on:PSA_WANT_ALG_JPAKE */
 void ecjpake_rounds( int alg_arg, int primitive_arg, int hash_arg,
-                     int derive_alg_arg, data_t *pw_data )
+                     int derive_alg_arg, data_t *pw_data,
+                     int client_input_first )
 {
     psa_pake_cipher_suite_t cipher_suite = psa_pake_cipher_suite_init();
     psa_pake_operation_t server = psa_pake_operation_init();
@@ -8281,31 +8572,9 @@
                             PSA_KEY_DERIVATION_OPERATION_INIT;
     psa_key_derivation_operation_t client_derive =
                             PSA_KEY_DERIVATION_OPERATION_INIT;
-    unsigned char *buffer0 = NULL, *buffer1 = NULL;
-    size_t buffer_length = (
-        PSA_PAKE_OUTPUT_SIZE(alg, primitive_arg, PSA_PAKE_STEP_KEY_SHARE) +
-        PSA_PAKE_OUTPUT_SIZE(alg, primitive_arg, PSA_PAKE_STEP_ZK_PUBLIC) +
-        PSA_PAKE_OUTPUT_SIZE(alg, primitive_arg, PSA_PAKE_STEP_ZK_PROOF)) * 2;
-    size_t buffer0_off = 0;
-    size_t buffer1_off = 0;
-    size_t s_g1_len, s_g2_len, s_a_len;
-    size_t s_g1_off, s_g2_off, s_a_off;
-    size_t s_x1_pk_len, s_x2_pk_len, s_x2s_pk_len;
-    size_t s_x1_pk_off, s_x2_pk_off, s_x2s_pk_off;
-    size_t s_x1_pr_len, s_x2_pr_len, s_x2s_pr_len;
-    size_t s_x1_pr_off, s_x2_pr_off, s_x2s_pr_off;
-    size_t c_g1_len, c_g2_len, c_a_len;
-    size_t c_g1_off, c_g2_off, c_a_off;
-    size_t c_x1_pk_len, c_x2_pk_len, c_x2s_pk_len;
-    size_t c_x1_pk_off, c_x2_pk_off, c_x2s_pk_off;
-    size_t c_x1_pr_len, c_x2_pr_len, c_x2s_pr_len;
-    size_t c_x1_pr_off, c_x2_pr_off, c_x2s_pr_off;
 
     PSA_INIT( );
 
-    ASSERT_ALLOC( buffer0, buffer_length );
-    ASSERT_ALLOC( buffer1, buffer_length );
-
     psa_set_key_usage_flags( &attributes, PSA_KEY_USAGE_DERIVE );
     psa_set_key_algorithm( &attributes, alg );
     psa_set_key_type( &attributes, PSA_KEY_TYPE_PASSWORD );
@@ -8345,169 +8614,18 @@
     TEST_EQUAL( psa_pake_get_implicit_key( &client, &client_derive ),
                 PSA_ERROR_BAD_STATE );
 
-    /* Server first round Output */
-    PSA_ASSERT( psa_pake_output( &server, PSA_PAKE_STEP_KEY_SHARE,
-                                 buffer0 + buffer0_off,
-                                 512 - buffer0_off, &s_g1_len ) );
-    s_g1_off = buffer0_off;
-    buffer0_off += s_g1_len;
-    PSA_ASSERT( psa_pake_output( &server, PSA_PAKE_STEP_ZK_PUBLIC,
-                                 buffer0 + buffer0_off,
-                                 512 - buffer0_off, &s_x1_pk_len ) );
-    s_x1_pk_off = buffer0_off;
-    buffer0_off += s_x1_pk_len;
-    PSA_ASSERT( psa_pake_output( &server, PSA_PAKE_STEP_ZK_PROOF,
-                                 buffer0 + buffer0_off,
-                                 512 - buffer0_off, &s_x1_pr_len ) );
-    s_x1_pr_off = buffer0_off;
-    buffer0_off += s_x1_pr_len;
-    PSA_ASSERT( psa_pake_output( &server, PSA_PAKE_STEP_KEY_SHARE,
-                                 buffer0 + buffer0_off,
-                                 512 - buffer0_off, &s_g2_len ) );
-    s_g2_off = buffer0_off;
-    buffer0_off += s_g2_len;
-    PSA_ASSERT( psa_pake_output( &server, PSA_PAKE_STEP_ZK_PUBLIC,
-                                 buffer0 + buffer0_off,
-                                 512 - buffer0_off, &s_x2_pk_len ) );
-    s_x2_pk_off = buffer0_off;
-    buffer0_off += s_x2_pk_len;
-    PSA_ASSERT( psa_pake_output( &server, PSA_PAKE_STEP_ZK_PROOF,
-                                 buffer0 + buffer0_off,
-                                 512 - buffer0_off, &s_x2_pr_len ) );
-    s_x2_pr_off = buffer0_off;
-    buffer0_off += s_x2_pr_len;
-
-    /* Client first round Output */
-    PSA_ASSERT( psa_pake_output( &client, PSA_PAKE_STEP_KEY_SHARE,
-                                 buffer1 + buffer1_off,
-                                 512 - buffer1_off, &c_g1_len ) );
-    c_g1_off = buffer1_off;
-    buffer1_off += c_g1_len;
-    PSA_ASSERT( psa_pake_output( &client, PSA_PAKE_STEP_ZK_PUBLIC,
-                                 buffer1 + buffer1_off,
-                                 512 - buffer1_off, &c_x1_pk_len ) );
-    c_x1_pk_off = buffer1_off;
-    buffer1_off += c_x1_pk_len;
-    PSA_ASSERT( psa_pake_output( &client, PSA_PAKE_STEP_ZK_PROOF,
-                                 buffer1 + buffer1_off,
-                                 512 - buffer1_off, &c_x1_pr_len ) );
-    c_x1_pr_off = buffer1_off;
-    buffer1_off += c_x1_pr_len;
-    PSA_ASSERT( psa_pake_output( &client, PSA_PAKE_STEP_KEY_SHARE,
-                                 buffer1 + buffer1_off,
-                                 512 - buffer1_off, &c_g2_len ) );
-    c_g2_off = buffer1_off;
-    buffer1_off += c_g2_len;
-    PSA_ASSERT( psa_pake_output( &client, PSA_PAKE_STEP_ZK_PUBLIC,
-                                 buffer1 + buffer1_off,
-                                 512 - buffer1_off, &c_x2_pk_len ) );
-    c_x2_pk_off = buffer1_off;
-    buffer1_off += c_x2_pk_len;
-    PSA_ASSERT( psa_pake_output( &client, PSA_PAKE_STEP_ZK_PROOF,
-                                 buffer1 + buffer1_off,
-                                 512 - buffer1_off, &c_x2_pr_len ) );
-    c_x2_pr_off = buffer1_off;
-    buffer1_off += c_x2_pr_len;
+    /* First round */
+    TEST_EQUAL( ecjpake_do_round( alg, primitive_arg, &server, &client,
+                                  client_input_first, 1, 0 ), 1 );
 
     TEST_EQUAL( psa_pake_get_implicit_key( &server, &server_derive ),
                 PSA_ERROR_BAD_STATE );
     TEST_EQUAL( psa_pake_get_implicit_key( &client, &client_derive ),
                 PSA_ERROR_BAD_STATE );
 
-    /* Client first round Input */
-    PSA_ASSERT( psa_pake_input( &client, PSA_PAKE_STEP_KEY_SHARE,
-                                buffer0 + s_g1_off, s_g1_len ) );
-    PSA_ASSERT( psa_pake_input( &client, PSA_PAKE_STEP_ZK_PUBLIC,
-                                buffer0 + s_x1_pk_off, s_x1_pk_len ) );
-    PSA_ASSERT( psa_pake_input( &client, PSA_PAKE_STEP_ZK_PROOF,
-                                buffer0 + s_x1_pr_off, s_x1_pr_len ) );
-    PSA_ASSERT( psa_pake_input( &client, PSA_PAKE_STEP_KEY_SHARE,
-                                buffer0 + s_g2_off, s_g2_len ) );
-    PSA_ASSERT( psa_pake_input( &client, PSA_PAKE_STEP_ZK_PUBLIC,
-                                buffer0 + s_x2_pk_off, s_x2_pk_len ) );
-    PSA_ASSERT( psa_pake_input( &client, PSA_PAKE_STEP_ZK_PROOF,
-                                buffer0 + s_x2_pr_off, s_x2_pr_len ) );
-
-    /* Server first round Input */
-    PSA_ASSERT( psa_pake_input( &server, PSA_PAKE_STEP_KEY_SHARE,
-                                buffer1 + c_g1_off, c_g1_len ) );
-    PSA_ASSERT( psa_pake_input( &server, PSA_PAKE_STEP_ZK_PUBLIC,
-                                buffer1 + c_x1_pk_off, c_x1_pk_len ) );
-    PSA_ASSERT( psa_pake_input( &server, PSA_PAKE_STEP_ZK_PROOF,
-                                buffer1 + c_x1_pr_off, c_x1_pr_len ) );
-    PSA_ASSERT( psa_pake_input( &server, PSA_PAKE_STEP_KEY_SHARE,
-                                buffer1 + c_g2_off, c_g2_len ) );
-    PSA_ASSERT( psa_pake_input( &server, PSA_PAKE_STEP_ZK_PUBLIC,
-                                buffer1 + c_x2_pk_off, c_x2_pk_len ) );
-    PSA_ASSERT( psa_pake_input( &server, PSA_PAKE_STEP_ZK_PROOF,
-                                buffer1 + c_x2_pr_off, c_x2_pr_len ) );
-
-    TEST_EQUAL( psa_pake_get_implicit_key( &server, &server_derive ),
-                PSA_ERROR_BAD_STATE );
-    TEST_EQUAL( psa_pake_get_implicit_key( &client, &client_derive ),
-                PSA_ERROR_BAD_STATE );
-
-    /* Server second round Output */
-    buffer0_off = 0;
-
-    PSA_ASSERT( psa_pake_output( &server, PSA_PAKE_STEP_KEY_SHARE,
-                                 buffer0 + buffer0_off,
-                                 512 - buffer0_off, &s_a_len ) );
-    s_a_off = buffer0_off;
-    buffer0_off += s_a_len;
-    PSA_ASSERT( psa_pake_output( &server, PSA_PAKE_STEP_ZK_PUBLIC,
-                                 buffer0 + buffer0_off,
-                                 512 - buffer0_off, &s_x2s_pk_len ) );
-    s_x2s_pk_off = buffer0_off;
-    buffer0_off += s_x2s_pk_len;
-    PSA_ASSERT( psa_pake_output( &server, PSA_PAKE_STEP_ZK_PROOF,
-                                 buffer0 + buffer0_off,
-                                 512 - buffer0_off, &s_x2s_pr_len ) );
-    s_x2s_pr_off = buffer0_off;
-    buffer0_off += s_x2s_pr_len;
-
-    /* Client second round Output */
-    buffer1_off = 0;
-
-    PSA_ASSERT( psa_pake_output( &client, PSA_PAKE_STEP_KEY_SHARE,
-                                 buffer1 + buffer1_off,
-                                 512 - buffer1_off, &c_a_len ) );
-    c_a_off = buffer1_off;
-    buffer1_off += c_a_len;
-    PSA_ASSERT( psa_pake_output( &client, PSA_PAKE_STEP_ZK_PUBLIC,
-                                 buffer1 + buffer1_off,
-                                 512 - buffer1_off, &c_x2s_pk_len ) );
-    c_x2s_pk_off = buffer1_off;
-    buffer1_off += c_x2s_pk_len;
-    PSA_ASSERT( psa_pake_output( &client, PSA_PAKE_STEP_ZK_PROOF,
-                                 buffer1 + buffer1_off,
-                                 512 - buffer1_off, &c_x2s_pr_len ) );
-    c_x2s_pr_off = buffer1_off;
-    buffer1_off += c_x2s_pr_len;
-
-    TEST_EQUAL( psa_pake_get_implicit_key( &server, &server_derive ),
-                PSA_ERROR_BAD_STATE );
-    TEST_EQUAL( psa_pake_get_implicit_key( &client, &client_derive ),
-                PSA_ERROR_BAD_STATE );
-
-    /* Client second round Input */
-    PSA_ASSERT( psa_pake_input( &client, PSA_PAKE_STEP_KEY_SHARE,
-                                buffer0 + s_a_off, s_a_len ) );
-    PSA_ASSERT( psa_pake_input( &client, PSA_PAKE_STEP_ZK_PUBLIC,
-                                buffer0 + s_x2s_pk_off, s_x2s_pk_len ) );
-    PSA_ASSERT( psa_pake_input( &client, PSA_PAKE_STEP_ZK_PROOF,
-                                buffer0 + s_x2s_pr_off, s_x2s_pr_len ) );
-
-    TEST_EQUAL( psa_pake_get_implicit_key( &server, &server_derive ),
-                PSA_ERROR_BAD_STATE );
-
-    /* Server second round Input */
-    PSA_ASSERT( psa_pake_input( &server, PSA_PAKE_STEP_KEY_SHARE,
-                                buffer1 + c_a_off, c_a_len ) );
-    PSA_ASSERT( psa_pake_input( &server, PSA_PAKE_STEP_ZK_PUBLIC,
-                                buffer1 + c_x2s_pk_off, c_x2s_pk_len ) );
-    PSA_ASSERT( psa_pake_input( &server, PSA_PAKE_STEP_ZK_PROOF,
-                                buffer1 + c_x2s_pr_off, c_x2s_pr_len ) );
+    /* Second round */
+    TEST_EQUAL( ecjpake_do_round( alg, primitive_arg, &server, &client,
+                                  client_input_first, 2, 0 ), 1 );
 
     PSA_ASSERT( psa_pake_get_implicit_key( &server, &server_derive ) );
     PSA_ASSERT( psa_pake_get_implicit_key( &client, &client_derive ) );
@@ -8518,8 +8636,6 @@
     psa_destroy_key( key );
     psa_pake_abort( &server );
     psa_pake_abort( &client );
-    mbedtls_free( buffer0 );
-    mbedtls_free( buffer1 );
     PSA_DONE( );
 }
 /* END_CASE */