Divide the PAKE operation into two phases: collecting inputs and computation.

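Functions that only set inputs no longer have driver entry points; the collected inputs are handed to the driver on the first psa_pake_output() or psa_pake_input() call.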

Signed-off-by: Przemek Stekiel <przemyslaw.stekiel@mobica.com>
diff --git a/library/psa_crypto.c b/library/psa_crypto.c
index 18aa18b..4742c3c 100644
--- a/library/psa_crypto.c
+++ b/library/psa_crypto.c
@@ -7180,7 +7180,29 @@
     psa_pake_operation_t *operation,
     const psa_pake_cipher_suite_t *cipher_suite)
 {
-    return psa_driver_wrapper_pake_setup(operation, cipher_suite);
+    if (operation->stage != PSA_PAKE_OPERATION_STAGE_COLLECT_INPUTS) {
+        return PSA_ERROR_BAD_STATE;
+    }
+
+    if (operation->data.inputs.alg != PSA_ALG_NONE) {
+        return PSA_ERROR_BAD_STATE;
+    }
+
+    if (cipher_suite == NULL ||
+        PSA_ALG_IS_PAKE(cipher_suite->algorithm) == 0 ||
+        (cipher_suite->type != PSA_PAKE_PRIMITIVE_TYPE_ECC &&
+         cipher_suite->type != PSA_PAKE_PRIMITIVE_TYPE_DH) ||
+        PSA_ALG_IS_HASH(cipher_suite->hash) == 0) {
+        return PSA_ERROR_INVALID_ARGUMENT;
+    }
+
+    /* Only record the cipher suite here; no driver entry point is called. */
+    memset(&operation->data.inputs, 0, sizeof(operation->data.inputs));
+
+    operation->data.inputs.alg = cipher_suite->algorithm;
+    operation->data.inputs.cipher_suite = *cipher_suite;
+
+    return PSA_SUCCESS;
 }
 
 psa_status_t psa_pake_set_password_key(
@@ -7191,7 +7213,11 @@
     psa_status_t unlock_status = PSA_ERROR_CORRUPTION_DETECTED;
     psa_key_slot_t *slot = NULL;
 
-    if (operation->id == 0) {
+    /* Inputs may only be set after setup and before computation starts; a
+     * second password would overwrite (and leak) the first heap copy. */
+    if (operation->stage != PSA_PAKE_OPERATION_STAGE_COLLECT_INPUTS ||
+        operation->data.inputs.alg == PSA_ALG_NONE ||
+        operation->data.inputs.password != NULL) {
         return PSA_ERROR_BAD_STATE;
     }
 
@@ -7206,9 +7232,29 @@
         .core = slot->attr
     };
 
-    status = psa_driver_wrapper_pake_set_password_key(&attributes, operation,
-                                                      slot->key.data, slot->key.bytes);
+    psa_key_type_t type = psa_get_key_type(&attributes);
+    psa_key_usage_t usage = psa_get_key_usage_flags(&attributes);
 
+    if (type != PSA_KEY_TYPE_PASSWORD &&
+        type != PSA_KEY_TYPE_PASSWORD_HASH) {
+        status = PSA_ERROR_INVALID_ARGUMENT;
+        goto error;
+    }
+
+    if ((usage & PSA_KEY_USAGE_DERIVE) == 0) {
+        status = PSA_ERROR_NOT_PERMITTED;
+        goto error;
+    }
+
+    operation->data.inputs.password = mbedtls_calloc(1, slot->key.bytes);
+    if (operation->data.inputs.password == NULL) {
+        status = PSA_ERROR_INSUFFICIENT_MEMORY; /* slot must still be unlocked */
+        goto error;
+    }
+    memcpy(operation->data.inputs.password, slot->key.data, slot->key.bytes);
+    operation->data.inputs.password_len = slot->key.bytes;
+    operation->data.inputs.key_lifetime = attributes.core.lifetime;
+error:
     unlock_status = psa_unlock_key_slot(slot);
 
     return (status == PSA_SUCCESS) ? unlock_status : status;
@@ -7219,16 +7265,21 @@
     const uint8_t *user_id,
     size_t user_id_len)
 {
-    if (operation->id == 0) {
+    (void) user_id; /* only the length is checked; the call ends in NOT_SUPPORTED */
+
+    if (operation->stage != PSA_PAKE_OPERATION_STAGE_COLLECT_INPUTS) {
         return PSA_ERROR_BAD_STATE;
     }
 
-    if (user_id_len == 0 || user_id == NULL) {
+    if (operation->data.inputs.alg == PSA_ALG_NONE) {
+        return PSA_ERROR_BAD_STATE;
+    }
+    /* An empty ID is reported as INVALID_ARGUMENT before NOT_SUPPORTED. */
+    if (user_id_len == 0) {
         return PSA_ERROR_INVALID_ARGUMENT;
     }
 
-    return psa_driver_wrapper_pake_set_user(operation, user_id,
-                                            user_id_len);
+    return PSA_ERROR_NOT_SUPPORTED;
 }
 
 psa_status_t psa_pake_set_peer(
@@ -7236,23 +7287,32 @@
     const uint8_t *peer_id,
     size_t peer_id_len)
 {
-    if (operation->id == 0) {
+    (void) peer_id; /* only the length is checked; the call ends in NOT_SUPPORTED */
+
+    if (operation->stage != PSA_PAKE_OPERATION_STAGE_COLLECT_INPUTS) {
         return PSA_ERROR_BAD_STATE;
     }
 
-    if (peer_id_len == 0 || peer_id == NULL) {
+    if (operation->data.inputs.alg == PSA_ALG_NONE) {
+        return PSA_ERROR_BAD_STATE;
+    }
+    /* An empty ID is reported as INVALID_ARGUMENT before NOT_SUPPORTED. */
+    if (peer_id_len == 0) {
         return PSA_ERROR_INVALID_ARGUMENT;
     }
 
-    return psa_driver_wrapper_pake_set_peer(operation, peer_id,
-                                            peer_id_len);
+    return PSA_ERROR_NOT_SUPPORTED;
 }
 
 psa_status_t psa_pake_set_role(
     psa_pake_operation_t *operation,
     psa_pake_role_t role)
 {
-    if (operation->id == 0) {
+    if (operation->stage != PSA_PAKE_OPERATION_STAGE_COLLECT_INPUTS) {
+        return PSA_ERROR_BAD_STATE;
+    }
+    /* A role can only be chosen after psa_pake_setup() recorded an algorithm. */
+    if (operation->data.inputs.alg == PSA_ALG_NONE) {
         return PSA_ERROR_BAD_STATE;
     }
 
@@ -7264,7 +7324,9 @@
         return PSA_ERROR_INVALID_ARGUMENT;
     }
 
-    return psa_driver_wrapper_pake_set_role(operation, role);
+    operation->data.inputs.role = role;
+    /* The role is only recorded here; no driver entry point is called. */
+    return PSA_SUCCESS;
 }
 
 psa_status_t psa_pake_output(
@@ -7274,11 +7336,34 @@
     size_t output_size,
     size_t *output_length)
 {
+    psa_status_t status = PSA_ERROR_CORRUPTION_DETECTED;
+    /* While still collecting inputs, validate them and set up the driver first. */
+    if (operation->stage == PSA_PAKE_OPERATION_STAGE_COLLECT_INPUTS) {
+        if (operation->data.inputs.alg == PSA_ALG_NONE ||
+            operation->data.inputs.password_len == 0 ||
+            operation->data.inputs.role == PSA_PAKE_ROLE_NONE) {
+            return PSA_ERROR_BAD_STATE;
+        }
+
+        status = psa_driver_wrapper_pake_setup(operation,
+                                               &operation->data.inputs);
+        /* NOTE(review): confirm who frees inputs.password once setup succeeds. */
+        if (status == PSA_SUCCESS) {
+            operation->stage = PSA_PAKE_OPERATION_STAGE_COMPUTATION;
+        } else { /* on failure the stage is left unchanged */
+            return status;
+        }
+    }
+
+    if (operation->stage != PSA_PAKE_OPERATION_STAGE_COMPUTATION) {
+        return PSA_ERROR_BAD_STATE;
+    }
+    /* NOTE(review): output_length is no longer NULL-checked below; confirm callers always pass one. */
     if (operation->id == 0) {
         return PSA_ERROR_BAD_STATE;
     }
 
-    if (output == NULL || output_size == 0 || output_length == NULL) {
+    if (output == NULL || output_size == 0) {
         return PSA_ERROR_INVALID_ARGUMENT;
     }
 
@@ -7292,6 +7377,29 @@
     const uint8_t *input,
     size_t input_length)
 {
+    psa_status_t status = PSA_ERROR_CORRUPTION_DETECTED;
+    /* While still collecting inputs, validate them and set up the driver first. */
+    if (operation->stage == PSA_PAKE_OPERATION_STAGE_COLLECT_INPUTS) {
+        if (operation->data.inputs.alg == PSA_ALG_NONE ||
+            operation->data.inputs.password_len == 0 ||
+            operation->data.inputs.role == PSA_PAKE_ROLE_NONE) {
+            return PSA_ERROR_BAD_STATE;
+        }
+
+        status = psa_driver_wrapper_pake_setup(operation,
+                                               &operation->data.inputs);
+        /* NOTE(review): confirm who frees inputs.password once setup succeeds. */
+        if (status == PSA_SUCCESS) {
+            operation->stage = PSA_PAKE_OPERATION_STAGE_COMPUTATION;
+        } else { /* on failure the stage is left unchanged */
+            return status;
+        }
+    }
+    /* NOTE(review): the completion logic above duplicates psa_pake_output(); consider a static helper. */
+    if (operation->stage != PSA_PAKE_OPERATION_STAGE_COMPUTATION) {
+        return PSA_ERROR_BAD_STATE;
+    }
+
     if (operation->id == 0) {
         return PSA_ERROR_BAD_STATE;
     }
@@ -7341,8 +7449,10 @@
 psa_status_t psa_pake_abort(
     psa_pake_operation_t *operation)
 {
-    /* Aborting a non-active operation is allowed */
-    if (operation->id == 0) {
+    /* If we are in collecting inputs stage clear inputs. */
+    if (operation->stage == PSA_PAKE_OPERATION_STAGE_COLLECT_INPUTS) {
+        mbedtls_free(operation->data.inputs.password);
+        memset(&operation->data.inputs, 0, sizeof(psa_crypto_driver_pake_inputs_t));
         return PSA_SUCCESS;
     }