Adapt J-PAKE built-in impl to use user/peer

Signed-off-by: Przemek Stekiel <przemyslaw.stekiel@mobica.com>
diff --git a/include/psa/crypto_builtin_composites.h b/include/psa/crypto_builtin_composites.h
index 932c503..acda242 100644
--- a/include/psa/crypto_builtin_composites.h
+++ b/include/psa/crypto_builtin_composites.h
@@ -199,7 +199,7 @@
     uint8_t *MBEDTLS_PRIVATE(password);
     size_t MBEDTLS_PRIVATE(password_len);
 #if defined(MBEDTLS_PSA_BUILTIN_ALG_JPAKE)
-    uint8_t MBEDTLS_PRIVATE(role);
+    mbedtls_ecjpake_role MBEDTLS_PRIVATE(role);
     uint8_t MBEDTLS_PRIVATE(buffer[MBEDTLS_PSA_JPAKE_BUFFER_SIZE]);
     size_t MBEDTLS_PRIVATE(buffer_length);
     size_t MBEDTLS_PRIVATE(buffer_offset);
diff --git a/library/psa_crypto_pake.c b/library/psa_crypto_pake.c
index a537184..97aafb4 100644
--- a/library/psa_crypto_pake.c
+++ b/library/psa_crypto_pake.c
@@ -168,13 +168,11 @@
 static psa_status_t psa_pake_ecjpake_setup(mbedtls_psa_pake_operation_t *operation)
 {
     int ret = MBEDTLS_ERR_ERROR_CORRUPTION_DETECTED;
-    mbedtls_ecjpake_role role = (operation->role == PSA_PAKE_ROLE_CLIENT) ?
-                                MBEDTLS_ECJPAKE_CLIENT : MBEDTLS_ECJPAKE_SERVER;
 
     mbedtls_ecjpake_init(&operation->ctx.jpake);
 
     ret = mbedtls_ecjpake_setup(&operation->ctx.jpake,
-                                role,
+                                operation->role,
                                 MBEDTLS_MD_SHA256,
                                 MBEDTLS_ECP_DP_SECP256R1,
                                 operation->password,
@@ -190,21 +188,30 @@
 }
 #endif
 
+/* The only two JPAKE user/peer identifiers supported in built-in implementation. */
+static const uint8_t jpake_server_id[] = { 's', 'e', 'r', 'v', 'e', 'r' };
+static const uint8_t jpake_client_id[] = { 'c', 'l', 'i', 'e', 'n', 't' };
+
 psa_status_t mbedtls_psa_pake_setup(mbedtls_psa_pake_operation_t *operation,
                                     const psa_crypto_driver_pake_inputs_t *inputs)
 {
     psa_status_t status = PSA_ERROR_CORRUPTION_DETECTED;
-    size_t password_len = 0;
-    psa_pake_role_t role = PSA_PAKE_ROLE_NONE;
+    size_t user_len = 0, peer_len = 0, password_len = 0;
+    uint8_t *peer = NULL, *user = NULL;
+    size_t actual_user_len = 0, actual_peer_len = 0, actual_password_len = 0;
     psa_pake_cipher_suite_t cipher_suite = psa_pake_cipher_suite_init();
-    size_t actual_password_len = 0;
 
     status = psa_crypto_driver_pake_get_password_len(inputs, &password_len);
     if (status != PSA_SUCCESS) {
         return status;
     }
 
-    status = psa_crypto_driver_pake_get_role(inputs, &role);
+    psa_crypto_driver_pake_get_user_len(inputs, &user_len);
+    if (status != PSA_SUCCESS) {
+        return status;
+    }
+
+    psa_crypto_driver_pake_get_peer_len(inputs, &peer_len);
     if (status != PSA_SUCCESS) {
         return status;
     }
@@ -216,7 +223,20 @@
 
     operation->password = mbedtls_calloc(1, password_len);
     if (operation->password == NULL) {
-        return PSA_ERROR_INSUFFICIENT_MEMORY;
+        status = PSA_ERROR_INSUFFICIENT_MEMORY;
+        goto error;
+    }
+
+    user = mbedtls_calloc(1, user_len);
+    if (user == NULL) {
+        status = PSA_ERROR_INSUFFICIENT_MEMORY;
+        goto error;
+    }
+
+    peer = mbedtls_calloc(1, peer_len);
+    if (peer == NULL) {
+        status = PSA_ERROR_INSUFFICIENT_MEMORY;
+        goto error;
     }
 
     status = psa_crypto_driver_pake_get_password(inputs, operation->password,
@@ -225,6 +245,18 @@
         goto error;
     }
 
+    status = psa_crypto_driver_pake_get_user(inputs, user,
+                                             user_len, &actual_user_len);
+    if (status != PSA_SUCCESS) {
+        goto error;
+    }
+
+    status = psa_crypto_driver_pake_get_peer(inputs, peer,
+                                             peer_len, &actual_peer_len);
+    if (status != PSA_SUCCESS) {
+        goto error;
+    }
+
     operation->password_len = actual_password_len;
     operation->alg = cipher_suite.algorithm;
 
@@ -238,7 +270,27 @@
             goto error;
         }
 
-        operation->role = role;
+        const size_t user_peer_len = sizeof(jpake_client_id); // client and server have the same length
+        if (actual_user_len != user_peer_len ||
+            actual_peer_len != user_peer_len) {
+            status = PSA_ERROR_NOT_SUPPORTED;
+            goto error;
+        }
+
+        if (memcmp(user, jpake_client_id, actual_user_len) == 0 &&
+            memcmp(peer, jpake_server_id, actual_peer_len) == 0) {
+            operation->role = MBEDTLS_ECJPAKE_CLIENT;
+        } else
+        if (memcmp(user, jpake_server_id, actual_user_len) == 0 &&
+            memcmp(peer, jpake_client_id, actual_peer_len) == 0) {
+            operation->role = MBEDTLS_ECJPAKE_SERVER;
+        } else {
+            status = PSA_ERROR_NOT_SUPPORTED;
+            goto error;
+        }
+
+        /* Role has been set, release user/peer buffers. */
+        mbedtls_free(user); mbedtls_free(peer);
 
         operation->buffer_length = 0;
         operation->buffer_offset = 0;
@@ -257,6 +309,7 @@
     { status = PSA_ERROR_NOT_SUPPORTED; }
 
 error:
+    mbedtls_free(user); mbedtls_free(peer);
     /* In case of failure of the setup of a multipart operation, the PSA driver interface
      * specifies that the core does not call any other driver entry point thus does not
      * call mbedtls_psa_pake_abort(). Therefore call it here to do the needed clean
@@ -332,7 +385,7 @@
          * information is already available.
          */
         if (step == PSA_JPAKE_X2S_STEP_KEY_SHARE &&
-            operation->role == PSA_PAKE_ROLE_SERVER) {
+            operation->role == MBEDTLS_ECJPAKE_SERVER) {
             /* Skip ECParameters, with is 3 bytes (RFC 8422) */
             operation->buffer_offset += 3;
         }
@@ -423,7 +476,7 @@
          * we're a client.
          */
         if (step == PSA_JPAKE_X4S_STEP_KEY_SHARE &&
-            operation->role == PSA_PAKE_ROLE_CLIENT) {
+            operation->role == MBEDTLS_ECJPAKE_CLIENT) {
             /* We only support secp256r1. */
             /* This is the ECParameters structure defined by RFC 8422. */
             unsigned char ecparameters[3] = {
@@ -541,7 +594,7 @@
 
 #if defined(MBEDTLS_PSA_BUILTIN_ALG_JPAKE)
     if (operation->alg == PSA_ALG_JPAKE) {
-        operation->role = PSA_PAKE_ROLE_NONE;
+        operation->role = 0;
         mbedtls_platform_zeroize(operation->buffer, sizeof(operation->buffer));
         operation->buffer_length = 0;
         operation->buffer_offset = 0;