Fix I/O format of PSA EC J-PAKE for compliance

The format used by the mbedtls_ecjpake_xxx() APIs and that defined by
the PSA Crypto PAKE extension are quite different; the former is
tailored to the needs of TLS while the later is quite generic and plain.
Previously we only addressed some part of this impedance mismatch: the
different number of I/O rounds, but failed to address the part where the
legacy API adds some extras (length bytes, ECParameters) that shouldn't
be present in the PSA Crypto version. See comments in the code.

Add some length testing as well; would have caught the issue.

Signed-off-by: Manuel Pégourié-Gonnard <manuel.pegourie-gonnard@arm.com>
diff --git a/include/psa/crypto_extra.h b/include/psa/crypto_extra.h
index 6c2e06e..ef9d138 100644
--- a/include/psa/crypto_extra.h
+++ b/include/psa/crypto_extra.h
@@ -1765,9 +1765,9 @@
       primitive == PSA_PAKE_PRIMITIVE(PSA_PAKE_PRIMITIVE_TYPE_ECC,      \
                                       PSA_ECC_FAMILY_SECP_R1, 256) ?    \
       (                                                                 \
-        output_step == PSA_PAKE_STEP_KEY_SHARE ? 69 :                   \
-        output_step == PSA_PAKE_STEP_ZK_PUBLIC ? 66 :                   \
-        33                                                              \
+        output_step == PSA_PAKE_STEP_KEY_SHARE ? 65 :                   \
+        output_step == PSA_PAKE_STEP_ZK_PUBLIC ? 65 :                   \
+        32                                                              \
       ) :                                                               \
       0 )
 
@@ -1795,9 +1795,9 @@
       primitive == PSA_PAKE_PRIMITIVE(PSA_PAKE_PRIMITIVE_TYPE_ECC,      \
                                       PSA_ECC_FAMILY_SECP_R1, 256) ?    \
       (                                                                 \
-        input_step == PSA_PAKE_STEP_KEY_SHARE ? 69 :                    \
-        input_step == PSA_PAKE_STEP_ZK_PUBLIC ? 66 :                    \
-        33                                                              \
+        input_step == PSA_PAKE_STEP_KEY_SHARE ? 65 :                    \
+        input_step == PSA_PAKE_STEP_ZK_PUBLIC ? 65 :                    \
+        32                                                              \
       ) :                                                               \
       0 )
 
@@ -1808,7 +1808,7 @@
  *
  * See also #PSA_PAKE_OUTPUT_SIZE(\p alg, \p primitive, \p step).
  */
-#define PSA_PAKE_OUTPUT_MAX_SIZE 69
+#define PSA_PAKE_OUTPUT_MAX_SIZE 65
 
 /** Input buffer size for psa_pake_input() for any of the supported PAKE
  * algorithm and primitive suites and input step.
@@ -1817,7 +1817,7 @@
  *
  * See also #PSA_PAKE_INPUT_SIZE(\p alg, \p primitive, \p step).
  */
-#define PSA_PAKE_INPUT_MAX_SIZE 69
+#define PSA_PAKE_INPUT_MAX_SIZE 65
 
 /** Returns a suitable initializer for a PAKE cipher suite object of type
  * psa_pake_cipher_suite_t.
@@ -1906,7 +1906,10 @@
 
 #if defined(MBEDTLS_PSA_BUILTIN_ALG_JPAKE)
 #include <mbedtls/ecjpake.h>
-#define PSA_PAKE_BUFFER_SIZE ( ( 69 + 66 + 33 ) * 2 )
+/* Note: the format for mbedtls_ecjpake_read/write function has an extra
+ * length byte for each step, plus an extra 3 bytes for ECParameters in the
+ * server's 2nd round. */
+#define PSA_PAKE_BUFFER_SIZE ( ( 3 + 1 + 65 + 1 + 65 + 1 + 32 ) * 2 )
 #endif
 
 struct psa_pake_operation_s
diff --git a/library/psa_crypto_pake.c b/library/psa_crypto_pake.c
index 10d3e4a..df091bc 100644
--- a/library/psa_crypto_pake.c
+++ b/library/psa_crypto_pake.c
@@ -522,44 +522,29 @@
         }
 
         /*
-         * Steps sequences are stored as:
-         * struct {
-         *     opaque point <1..2^8-1>;
-         * } ECPoint;
+         * mbedtls_ecjpake_write_round_xxx() outputs thing in the format
+         * defined by draft-cragie-tls-ecjpake-01 section 7. The summary is
+         * that the data for each step is prepended with a length byte, and
+         * then they're concatenated. Additionally, the server's second round
+         * output is prepended with a 3-bytes ECParameters structure.
          *
-         * Where byte 0 stores the ECPoint curve point length.
-         *
-         * The sequence length is equal to:
-         * - data length extracted from byte 0
-         * - byte 0 size (1)
+         * In PSA, we output each step separately, and don't prepend the
+         * output with a length byte, even less a curve identifier, as that
+         * information is already available.
          */
         if( operation->state == PSA_PAKE_OUTPUT_X2S &&
-            operation->sequence == PSA_PAKE_X1_STEP_KEY_SHARE )
+            operation->sequence == PSA_PAKE_X1_STEP_KEY_SHARE &&
+            operation->role == PSA_PAKE_ROLE_SERVER )
         {
-            if( operation->role == PSA_PAKE_ROLE_SERVER )
-                /*
-                 * The X2S KEY SHARE Server steps sequence is stored as:
-                 * struct {
-                 *     ECPoint X;
-                 *    opaque r <1..2^8-1>;
-                 * } ECSchnorrZKP;
-                 *
-                 * And MbedTLS uses a 3 bytes Ephemeral public key ECPoint,
-                 * so byte 3 stores the r Schnorr signature length.
-                 *
-                 * The sequence length is equal to:
-                 * - curve storage size (3)
-                 * - data length extracted from byte 3
-                 * - byte 3 size (1)
-                 */
-                length = 3 + operation->buffer[3] + 1;
-            else
-                length = operation->buffer[0] + 1;
+            /* Skip ECParameters, with is 3 bytes (RFC 8422) */
+            operation->buffer_offset += 3;
         }
-        else
-            length = operation->buffer[operation->buffer_offset] + 1;
 
-        if( length > operation->buffer_length )
+        /* Read the length byte then move past it to the data */
+        length = operation->buffer[operation->buffer_offset];
+        operation->buffer_offset += 1;
+
+        if( operation->buffer_offset + length > operation->buffer_length )
             return( PSA_ERROR_DATA_CORRUPT );
 
         if( output_size < length )
@@ -569,7 +554,7 @@
         }
 
         memcpy( output,
-                operation->buffer +  operation->buffer_offset,
+                operation->buffer + operation->buffer_offset,
                 length );
         *output_length = length;
 
@@ -709,7 +694,35 @@
                 return( PSA_ERROR_BAD_STATE );
         }
 
-        /* Copy input to local buffer */
+        /*
+         * Copy input to local buffer and format it as the Mbed TLS API
+         * expects, i.e. as defined by draft-cragie-tls-ecjpake-01 section 7.
+         * The summary is that the data for each step is prepended with a
+         * length byte, and then they're concatenated. Additionally, the
+         * server's second round output is prepended with a 3-bytes
+         * ECParameters structure - which means we have to prepend that when
+         * we're a client.
+         */
+        if( operation->state == PSA_PAKE_INPUT_X4S &&
+            operation->sequence == PSA_PAKE_X1_STEP_KEY_SHARE &&
+            operation->role == PSA_PAKE_ROLE_CLIENT )
+        {
+            /* We only support secp256r1. */
+            /* This is the ECParameters structure defined by RFC 8422. */
+            unsigned char ecparameters[3] = {
+                3, /* named_curve */
+                0, 23 /* secp256r1 */
+            };
+            memcpy( operation->buffer + operation->buffer_length,
+                    ecparameters, sizeof( ecparameters ) );
+            operation->buffer_length += sizeof( ecparameters );
+        }
+
+        /* Write the length byte */
+        operation->buffer[operation->buffer_length] = input_length;
+        operation->buffer_length += 1;
+
+        /* Finally copy the data */
         memcpy( operation->buffer + operation->buffer_length,
                 input, input_length );
         operation->buffer_length += input_length;
diff --git a/tests/suites/test_suite_psa_crypto.data b/tests/suites/test_suite_psa_crypto.data
index f2478be..91fced8 100644
--- a/tests/suites/test_suite_psa_crypto.data
+++ b/tests/suites/test_suite_psa_crypto.data
@@ -6594,3 +6594,7 @@
 PSA PAKE: ecjpake inject input errors, second round server, client input first
 depends_on:PSA_WANT_KEY_TYPE_ECC_KEY_PAIR:PSA_WANT_ECC_SECP_R1_256:PSA_WANT_ALG_SHA_256
 ecjpake_rounds_inject:PSA_ALG_JPAKE:PSA_PAKE_PRIMITIVE(PSA_PAKE_PRIMITIVE_TYPE_ECC, PSA_ECC_FAMILY_SECP_R1, 256):PSA_ALG_SHA_256:1:4:"abcdef"
+
+PSA PAKE: ecjpake size macros
+depends_on:PSA_WANT_KEY_TYPE_ECC_KEY_PAIR:PSA_WANT_ECC_SECP_R1_256
+ecjpake_size_macros:
diff --git a/tests/suites/test_suite_psa_crypto.function b/tests/suites/test_suite_psa_crypto.function
index fa237d3..1b144df 100644
--- a/tests/suites/test_suite_psa_crypto.function
+++ b/tests/suites/test_suite_psa_crypto.function
@@ -717,6 +717,15 @@
         PSA_PAKE_OUTPUT_SIZE(alg, primitive, PSA_PAKE_STEP_KEY_SHARE) +
         PSA_PAKE_OUTPUT_SIZE(alg, primitive, PSA_PAKE_STEP_ZK_PUBLIC) +
         PSA_PAKE_OUTPUT_SIZE(alg, primitive, PSA_PAKE_STEP_ZK_PROOF)) * 2;
+    /* The output should be exactly this size according to the spec */
+    const size_t expected_size_key_share =
+        PSA_PAKE_OUTPUT_SIZE(alg, primitive, PSA_PAKE_STEP_KEY_SHARE);
+    /* The output should be exactly this size according to the spec */
+    const size_t expected_size_zk_public =
+        PSA_PAKE_OUTPUT_SIZE(alg, primitive, PSA_PAKE_STEP_ZK_PUBLIC);
+    /* The output can be smaller: the spec allows stripping leading zeroes */
+    const size_t max_expected_size_zk_proof =
+        PSA_PAKE_OUTPUT_SIZE(alg, primitive, PSA_PAKE_STEP_ZK_PROOF);
     size_t buffer0_off = 0;
     size_t buffer1_off = 0;
     size_t s_g1_len, s_g2_len, s_a_len;
@@ -744,31 +753,37 @@
             PSA_ASSERT( psa_pake_output( server, PSA_PAKE_STEP_KEY_SHARE,
                                          buffer0 + buffer0_off,
                                          512 - buffer0_off, &s_g1_len ) );
+            TEST_EQUAL( s_g1_len, expected_size_key_share );
             s_g1_off = buffer0_off;
             buffer0_off += s_g1_len;
             PSA_ASSERT( psa_pake_output( server, PSA_PAKE_STEP_ZK_PUBLIC,
                                          buffer0 + buffer0_off,
                                          512 - buffer0_off, &s_x1_pk_len ) );
+            TEST_EQUAL( s_x1_pk_len, expected_size_zk_public );
             s_x1_pk_off = buffer0_off;
             buffer0_off += s_x1_pk_len;
             PSA_ASSERT( psa_pake_output( server, PSA_PAKE_STEP_ZK_PROOF,
                                          buffer0 + buffer0_off,
                                          512 - buffer0_off, &s_x1_pr_len ) );
+            TEST_LE_U( s_x1_pr_len, max_expected_size_zk_proof );
             s_x1_pr_off = buffer0_off;
             buffer0_off += s_x1_pr_len;
             PSA_ASSERT( psa_pake_output( server, PSA_PAKE_STEP_KEY_SHARE,
                                          buffer0 + buffer0_off,
                                          512 - buffer0_off, &s_g2_len ) );
+            TEST_EQUAL( s_g2_len, expected_size_key_share );
             s_g2_off = buffer0_off;
             buffer0_off += s_g2_len;
             PSA_ASSERT( psa_pake_output( server, PSA_PAKE_STEP_ZK_PUBLIC,
                                          buffer0 + buffer0_off,
                                          512 - buffer0_off, &s_x2_pk_len ) );
+            TEST_EQUAL( s_x2_pk_len, expected_size_zk_public );
             s_x2_pk_off = buffer0_off;
             buffer0_off += s_x2_pk_len;
             PSA_ASSERT( psa_pake_output( server, PSA_PAKE_STEP_ZK_PROOF,
                                          buffer0 + buffer0_off,
                                          512 - buffer0_off, &s_x2_pr_len ) );
+            TEST_LE_U( s_x2_pr_len, max_expected_size_zk_proof );
             s_x2_pr_off = buffer0_off;
             buffer0_off += s_x2_pr_len;
 
@@ -876,31 +891,37 @@
             PSA_ASSERT( psa_pake_output( client, PSA_PAKE_STEP_KEY_SHARE,
                                          buffer1 + buffer1_off,
                                          512 - buffer1_off, &c_g1_len ) );
+            TEST_EQUAL( c_g1_len, expected_size_key_share );
             c_g1_off = buffer1_off;
             buffer1_off += c_g1_len;
             PSA_ASSERT( psa_pake_output( client, PSA_PAKE_STEP_ZK_PUBLIC,
                                          buffer1 + buffer1_off,
                                          512 - buffer1_off, &c_x1_pk_len ) );
+            TEST_EQUAL( c_x1_pk_len, expected_size_zk_public );
             c_x1_pk_off = buffer1_off;
             buffer1_off += c_x1_pk_len;
             PSA_ASSERT( psa_pake_output( client, PSA_PAKE_STEP_ZK_PROOF,
                                          buffer1 + buffer1_off,
                                          512 - buffer1_off, &c_x1_pr_len ) );
+            TEST_LE_U( c_x1_pr_len, max_expected_size_zk_proof );
             c_x1_pr_off = buffer1_off;
             buffer1_off += c_x1_pr_len;
             PSA_ASSERT( psa_pake_output( client, PSA_PAKE_STEP_KEY_SHARE,
                                          buffer1 + buffer1_off,
                                          512 - buffer1_off, &c_g2_len ) );
+            TEST_EQUAL( c_g2_len, expected_size_key_share );
             c_g2_off = buffer1_off;
             buffer1_off += c_g2_len;
             PSA_ASSERT( psa_pake_output( client, PSA_PAKE_STEP_ZK_PUBLIC,
                                          buffer1 + buffer1_off,
                                          512 - buffer1_off, &c_x2_pk_len ) );
+            TEST_EQUAL( c_x2_pk_len, expected_size_zk_public );
             c_x2_pk_off = buffer1_off;
             buffer1_off += c_x2_pk_len;
             PSA_ASSERT( psa_pake_output( client, PSA_PAKE_STEP_ZK_PROOF,
                                          buffer1 + buffer1_off,
                                          512 - buffer1_off, &c_x2_pr_len ) );
+            TEST_LE_U( c_x2_pr_len, max_expected_size_zk_proof );
             c_x2_pr_off = buffer1_off;
             buffer1_off += c_x2_pr_len;
 
@@ -1082,16 +1103,19 @@
             PSA_ASSERT( psa_pake_output( server, PSA_PAKE_STEP_KEY_SHARE,
                                          buffer0 + buffer0_off,
                                          512 - buffer0_off, &s_a_len ) );
+            TEST_EQUAL( s_a_len, expected_size_key_share );
             s_a_off = buffer0_off;
             buffer0_off += s_a_len;
             PSA_ASSERT( psa_pake_output( server, PSA_PAKE_STEP_ZK_PUBLIC,
                                          buffer0 + buffer0_off,
                                          512 - buffer0_off, &s_x2s_pk_len ) );
+            TEST_EQUAL( s_x2s_pk_len, expected_size_zk_public );
             s_x2s_pk_off = buffer0_off;
             buffer0_off += s_x2s_pk_len;
             PSA_ASSERT( psa_pake_output( server, PSA_PAKE_STEP_ZK_PROOF,
                                          buffer0 + buffer0_off,
                                          512 - buffer0_off, &s_x2s_pr_len ) );
+            TEST_LE_U( s_x2s_pr_len, max_expected_size_zk_proof );
             s_x2s_pr_off = buffer0_off;
             buffer0_off += s_x2s_pr_len;
 
@@ -1153,16 +1177,19 @@
             PSA_ASSERT( psa_pake_output( client, PSA_PAKE_STEP_KEY_SHARE,
                                          buffer1 + buffer1_off,
                                          512 - buffer1_off, &c_a_len ) );
+            TEST_EQUAL( c_a_len, expected_size_key_share );
             c_a_off = buffer1_off;
             buffer1_off += c_a_len;
             PSA_ASSERT( psa_pake_output( client, PSA_PAKE_STEP_ZK_PUBLIC,
                                          buffer1 + buffer1_off,
                                          512 - buffer1_off, &c_x2s_pk_len ) );
+            TEST_EQUAL( c_x2s_pk_len, expected_size_zk_public );
             c_x2s_pk_off = buffer1_off;
             buffer1_off += c_x2s_pk_len;
             PSA_ASSERT( psa_pake_output( client, PSA_PAKE_STEP_ZK_PROOF,
                                          buffer1 + buffer1_off,
                                          512 - buffer1_off, &c_x2s_pr_len ) );
+            TEST_LE_U( c_x2s_pr_len, max_expected_size_zk_proof );
             c_x2s_pr_off = buffer1_off;
             buffer1_off += c_x2s_pr_len;
 
@@ -9008,3 +9035,47 @@
     PSA_DONE( );
 }
 /* END_CASE */
+
+/* BEGIN_CASE */
+void ecjpake_size_macros( )
+{
+    const psa_algorithm_t alg = PSA_ALG_JPAKE;
+    const size_t bits = 256;
+    const psa_pake_primitive_t prim = PSA_PAKE_PRIMITIVE(
+            PSA_PAKE_PRIMITIVE_TYPE_ECC, PSA_ECC_FAMILY_SECP_R1, bits );
+    const psa_key_type_t key_type = PSA_KEY_TYPE_ECC_KEY_PAIR(
+            PSA_ECC_FAMILY_SECP_R1 );
+
+    // https://armmbed.github.io/mbed-crypto/1.1_PAKE_Extension.0-bet.0/html/pake.html#pake-step-types
+    /* The output for KEY_SHARE and ZK_PUBLIC is the same as a public key */
+    TEST_EQUAL( PSA_PAKE_OUTPUT_SIZE(alg, prim, PSA_PAKE_STEP_KEY_SHARE),
+                PSA_EXPORT_PUBLIC_KEY_OUTPUT_SIZE( key_type, bits ) );
+    TEST_EQUAL( PSA_PAKE_OUTPUT_SIZE(alg, prim, PSA_PAKE_STEP_ZK_PUBLIC),
+                PSA_EXPORT_PUBLIC_KEY_OUTPUT_SIZE( key_type, bits ) );
+    /* The output for ZK_PROOF is the same bitsize as the curve */
+    TEST_EQUAL( PSA_PAKE_OUTPUT_SIZE(alg, prim, PSA_PAKE_STEP_ZK_PROOF),
+                PSA_BITS_TO_BYTES( bits ) );
+
+    /* Input sizes are the same as output sizes */
+    TEST_EQUAL( PSA_PAKE_OUTPUT_SIZE(alg, prim, PSA_PAKE_STEP_KEY_SHARE),
+                PSA_PAKE_INPUT_SIZE(alg, prim, PSA_PAKE_STEP_KEY_SHARE) );
+    TEST_EQUAL( PSA_PAKE_OUTPUT_SIZE(alg, prim, PSA_PAKE_STEP_ZK_PUBLIC),
+                PSA_PAKE_INPUT_SIZE(alg, prim, PSA_PAKE_STEP_ZK_PUBLIC) );
+    TEST_EQUAL( PSA_PAKE_OUTPUT_SIZE(alg, prim, PSA_PAKE_STEP_ZK_PROOF),
+                PSA_PAKE_INPUT_SIZE(alg, prim, PSA_PAKE_STEP_ZK_PROOF) );
+
+    /* These inequalities will always hold even when other PAKEs are added */
+    TEST_LE_U( PSA_PAKE_OUTPUT_SIZE(alg, prim, PSA_PAKE_STEP_KEY_SHARE),
+               PSA_PAKE_OUTPUT_MAX_SIZE );
+    TEST_LE_U( PSA_PAKE_OUTPUT_SIZE(alg, prim, PSA_PAKE_STEP_ZK_PUBLIC),
+               PSA_PAKE_OUTPUT_MAX_SIZE );
+    TEST_LE_U( PSA_PAKE_OUTPUT_SIZE(alg, prim, PSA_PAKE_STEP_ZK_PROOF),
+               PSA_PAKE_OUTPUT_MAX_SIZE );
+    TEST_LE_U( PSA_PAKE_INPUT_SIZE(alg, prim, PSA_PAKE_STEP_KEY_SHARE),
+               PSA_PAKE_INPUT_MAX_SIZE );
+    TEST_LE_U( PSA_PAKE_INPUT_SIZE(alg, prim, PSA_PAKE_STEP_ZK_PUBLIC),
+               PSA_PAKE_INPUT_MAX_SIZE );
+    TEST_LE_U( PSA_PAKE_INPUT_SIZE(alg, prim, PSA_PAKE_STEP_ZK_PROOF),
+               PSA_PAKE_INPUT_MAX_SIZE );
+}
+/* END_CASE */