Move JPAKE state machine logic from driver to core

- Add `alg` and `computation_stage` to `psa_pake_operation_s`.
  Now that the logic has moved to the core, the core needs to know `alg`.
  `computation_stage` is a structure holding a union of the computation
  stages of the supported PAKE algorithms.
- Move the J-PAKE operation logic from the driver to the core. This requires changing the driver entry points for `psa_pake_output`/`psa_pake_input` to take an additional `computation_stage` parameter, so that the driver can check the current computation stage and act on it. I'm not sure whether this approach is correct. J-PAKE drivers no longer use the `step` parameter, but I think it should stay, because other PAKE algorithms may need it.
- Remove a test that seems redundant, since we cannot be sure that the operation is aborted after a failure.

Signed-off-by: Przemek Stekiel <przemyslaw.stekiel@mobica.com>
diff --git a/library/psa_crypto.c b/library/psa_crypto.c
index 273d248..66ecc06 100644
--- a/library/psa_crypto.c
+++ b/library/psa_crypto.c
@@ -7180,11 +7180,14 @@
     psa_pake_operation_t *operation,
     const psa_pake_cipher_suite_t *cipher_suite)
 {
+    psa_jpake_computation_stage_t *computation_stage =
+        &operation->computation_stage.data.jpake_computation_stage;
+
     if (operation->stage != PSA_PAKE_OPERATION_STAGE_COLLECT_INPUTS) {
         return PSA_ERROR_BAD_STATE;
     }
 
-    if (operation->data.inputs.alg != PSA_ALG_NONE) {
+    if (operation->alg != PSA_ALG_NONE) {
         return PSA_ERROR_BAD_STATE;
     }
 
@@ -7198,9 +7201,16 @@
 
     memset(&operation->data.inputs, 0, sizeof(operation->data.inputs));
 
-    operation->data.inputs.alg = cipher_suite->algorithm;
+    operation->alg = cipher_suite->algorithm;
     operation->data.inputs.cipher_suite = *cipher_suite;
 
+    if (operation->alg == PSA_ALG_JPAKE) {
+        computation_stage->state = PSA_PAKE_STATE_SETUP;
+        computation_stage->sequence = PSA_PAKE_SEQ_INVALID;
+        computation_stage->input_step = PSA_PAKE_STEP_X1_X2;
+        computation_stage->output_step = PSA_PAKE_STEP_X1_X2;
+    }
+
     return PSA_SUCCESS;
 }
 
@@ -7216,7 +7226,7 @@
         return PSA_ERROR_BAD_STATE;
     }
 
-    if (operation->data.inputs.alg == PSA_ALG_NONE) {
+    if (operation->alg == PSA_ALG_NONE) {
         return PSA_ERROR_BAD_STATE;
     }
 
@@ -7241,7 +7251,8 @@
 
     operation->data.inputs.password = mbedtls_calloc(1, slot->key.bytes);
     if (operation->data.inputs.password == NULL) {
-        return PSA_ERROR_INSUFFICIENT_MEMORY;
+        status = PSA_ERROR_INSUFFICIENT_MEMORY;
+        goto error;
     }
 
     memcpy(operation->data.inputs.password, slot->key.data, slot->key.bytes);
@@ -7264,7 +7275,7 @@
         return PSA_ERROR_BAD_STATE;
     }
 
-    if (operation->data.inputs.alg == PSA_ALG_NONE) {
+    if (operation->alg == PSA_ALG_NONE) {
         return PSA_ERROR_BAD_STATE;
     }
 
@@ -7286,7 +7297,7 @@
         return PSA_ERROR_BAD_STATE;
     }
 
-    if (operation->data.inputs.alg == PSA_ALG_NONE) {
+    if (operation->alg == PSA_ALG_NONE) {
         return PSA_ERROR_BAD_STATE;
     }
 
@@ -7305,7 +7316,7 @@
         return PSA_ERROR_BAD_STATE;
     }
 
-    if (operation->data.inputs.alg == PSA_ALG_NONE) {
+    if (operation->alg == PSA_ALG_NONE) {
         return PSA_ERROR_BAD_STATE;
     }
 
@@ -7322,6 +7333,98 @@
     return PSA_SUCCESS;
 }
 
+static psa_status_t psa_jpake_output_prologue(
+    psa_pake_operation_t *operation,
+    psa_pake_step_t step)
+{
+    psa_jpake_computation_stage_t *computation_stage =
+        &operation->computation_stage.data.jpake_computation_stage;
+
+    if (computation_stage->state == PSA_PAKE_STATE_INVALID) {
+        return PSA_ERROR_BAD_STATE;
+    }
+
+    if (step != PSA_PAKE_STEP_KEY_SHARE &&
+        step != PSA_PAKE_STEP_ZK_PUBLIC &&
+        step != PSA_PAKE_STEP_ZK_PROOF) {
+        return PSA_ERROR_INVALID_ARGUMENT;
+    }
+
+    if (computation_stage->state != PSA_PAKE_STATE_READY &&
+        computation_stage->state != PSA_PAKE_OUTPUT_X1_X2 &&
+        computation_stage->state != PSA_PAKE_OUTPUT_X2S) {
+        return PSA_ERROR_BAD_STATE;
+    }
+
+    if (computation_stage->state == PSA_PAKE_STATE_READY) {
+        if (step != PSA_PAKE_STEP_KEY_SHARE) {
+            return PSA_ERROR_BAD_STATE;
+        }
+
+        switch (computation_stage->output_step) {
+            case PSA_PAKE_STEP_X1_X2:
+                computation_stage->state = PSA_PAKE_OUTPUT_X1_X2;
+                break;
+            case PSA_PAKE_STEP_X2S:
+                computation_stage->state = PSA_PAKE_OUTPUT_X2S;
+                break;
+            default:
+                return PSA_ERROR_BAD_STATE;
+        }
+
+        computation_stage->sequence = PSA_PAKE_X1_STEP_KEY_SHARE;
+    }
+
+    /* Check if step matches current sequence */
+    switch (computation_stage->sequence) {
+        case PSA_PAKE_X1_STEP_KEY_SHARE:
+        case PSA_PAKE_X2_STEP_KEY_SHARE:
+            if (step != PSA_PAKE_STEP_KEY_SHARE) {
+                return PSA_ERROR_BAD_STATE;
+            }
+            break;
+
+        case PSA_PAKE_X1_STEP_ZK_PUBLIC:
+        case PSA_PAKE_X2_STEP_ZK_PUBLIC:
+            if (step != PSA_PAKE_STEP_ZK_PUBLIC) {
+                return PSA_ERROR_BAD_STATE;
+            }
+            break;
+
+        case PSA_PAKE_X1_STEP_ZK_PROOF:
+        case PSA_PAKE_X2_STEP_ZK_PROOF:
+            if (step != PSA_PAKE_STEP_ZK_PROOF) {
+                return PSA_ERROR_BAD_STATE;
+            }
+            break;
+
+        default:
+            return PSA_ERROR_BAD_STATE;
+    }
+
+    return PSA_SUCCESS;
+}
+
+static psa_status_t psa_jpake_output_epilogue(
+    psa_pake_operation_t *operation)
+{
+    psa_jpake_computation_stage_t *computation_stage =
+        &operation->computation_stage.data.jpake_computation_stage;
+
+    if ((computation_stage->state == PSA_PAKE_OUTPUT_X1_X2 &&
+         computation_stage->sequence == PSA_PAKE_X2_STEP_ZK_PROOF) ||
+        (computation_stage->state == PSA_PAKE_OUTPUT_X2S &&
+         computation_stage->sequence == PSA_PAKE_X1_STEP_ZK_PROOF)) {
+        computation_stage->state = PSA_PAKE_STATE_READY;
+        computation_stage->output_step++;
+        computation_stage->sequence = PSA_PAKE_SEQ_INVALID;
+    } else {
+        computation_stage->sequence++;
+    }
+
+    return PSA_SUCCESS;
+}
+
 psa_status_t psa_pake_output(
     psa_pake_operation_t *operation,
     psa_pake_step_t step,
@@ -7330,9 +7433,11 @@
     size_t *output_length)
 {
     psa_status_t status = PSA_ERROR_CORRUPTION_DETECTED;
+    psa_jpake_computation_stage_t *computation_stage =
+        &operation->computation_stage.data.jpake_computation_stage;
 
     if (operation->stage == PSA_PAKE_OPERATION_STAGE_COLLECT_INPUTS) {
-        if (operation->data.inputs.alg == PSA_ALG_NONE ||
+        if (operation->alg == PSA_ALG_NONE ||
             operation->data.inputs.password_len == 0 ||
             operation->data.inputs.role == PSA_PAKE_ROLE_NONE) {
             return PSA_ERROR_BAD_STATE;
@@ -7343,6 +7448,12 @@
 
         if (status == PSA_SUCCESS) {
             operation->stage = PSA_PAKE_OPERATION_STAGE_COMPUTATION;
+            if (operation->alg == PSA_ALG_JPAKE) {
+                computation_stage->state = PSA_PAKE_STATE_READY;
+                computation_stage->sequence = PSA_PAKE_SEQ_INVALID;
+                computation_stage->input_step = PSA_PAKE_STEP_X1_X2;
+                computation_stage->output_step = PSA_PAKE_STEP_X1_X2;
+            }
         } else {
             return status;
         }
@@ -7360,10 +7471,140 @@
         return PSA_ERROR_INVALID_ARGUMENT;
     }
 
-    return psa_driver_wrapper_pake_output(operation, step, output,
-                                          output_size, output_length);
+    switch (operation->alg) {
+        case PSA_ALG_JPAKE:
+            status = psa_jpake_output_prologue(operation, step);
+            if (status != PSA_SUCCESS) {
+                return status;
+            }
+            break;
+        default:
+            return PSA_ERROR_NOT_SUPPORTED;
+    }
+
+    status = psa_driver_wrapper_pake_output(operation, step,
+                                            &operation->computation_stage,
+                                            output, output_size, output_length);
+
+    if (status != PSA_SUCCESS) {
+        return status;
+    }
+
+    switch (operation->alg) {
+        case PSA_ALG_JPAKE:
+            status = psa_jpake_output_epilogue(operation);
+            if (status != PSA_SUCCESS) {
+                return status;
+            }
+            break;
+        default:
+            return PSA_ERROR_NOT_SUPPORTED;
+    }
+
+    return status;
 }
 
+static psa_status_t psa_jpake_input_prologue(
+    psa_pake_operation_t *operation,
+    psa_pake_step_t step,
+    size_t input_length)
+{
+    psa_jpake_computation_stage_t *computation_stage =
+        &operation->computation_stage.data.jpake_computation_stage;
+
+    if (computation_stage->state == PSA_PAKE_STATE_INVALID) {
+        return PSA_ERROR_BAD_STATE;
+    }
+
+    if (step != PSA_PAKE_STEP_KEY_SHARE &&
+        step != PSA_PAKE_STEP_ZK_PUBLIC &&
+        step != PSA_PAKE_STEP_ZK_PROOF) {
+        return PSA_ERROR_INVALID_ARGUMENT;
+    }
+
+    const psa_pake_primitive_t prim = PSA_PAKE_PRIMITIVE(
+        PSA_PAKE_PRIMITIVE_TYPE_ECC, PSA_ECC_FAMILY_SECP_R1, 256);
+    if (input_length > (size_t) PSA_PAKE_INPUT_SIZE(PSA_ALG_JPAKE, prim, step)) {
+        return PSA_ERROR_INVALID_ARGUMENT;
+    }
+
+    if (computation_stage->state != PSA_PAKE_STATE_READY &&
+        computation_stage->state != PSA_PAKE_INPUT_X1_X2 &&
+        computation_stage->state != PSA_PAKE_INPUT_X4S) {
+        return PSA_ERROR_BAD_STATE;
+    }
+
+    if (computation_stage->state == PSA_PAKE_STATE_READY) {
+        if (step != PSA_PAKE_STEP_KEY_SHARE) {
+            return PSA_ERROR_BAD_STATE;
+        }
+
+        switch (computation_stage->input_step) {
+            case PSA_PAKE_STEP_X1_X2:
+                computation_stage->state = PSA_PAKE_INPUT_X1_X2;
+                break;
+            case PSA_PAKE_STEP_X2S:
+                computation_stage->state = PSA_PAKE_INPUT_X4S;
+                break;
+            default:
+                return PSA_ERROR_BAD_STATE;
+        }
+
+        computation_stage->sequence = PSA_PAKE_X1_STEP_KEY_SHARE;
+    }
+
+    /* Check if step matches current sequence */
+    switch (computation_stage->sequence) {
+        case PSA_PAKE_X1_STEP_KEY_SHARE:
+        case PSA_PAKE_X2_STEP_KEY_SHARE:
+            if (step != PSA_PAKE_STEP_KEY_SHARE) {
+                return PSA_ERROR_BAD_STATE;
+            }
+            break;
+
+        case PSA_PAKE_X1_STEP_ZK_PUBLIC:
+        case PSA_PAKE_X2_STEP_ZK_PUBLIC:
+            if (step != PSA_PAKE_STEP_ZK_PUBLIC) {
+                return PSA_ERROR_BAD_STATE;
+            }
+            break;
+
+        case PSA_PAKE_X1_STEP_ZK_PROOF:
+        case PSA_PAKE_X2_STEP_ZK_PROOF:
+            if (step != PSA_PAKE_STEP_ZK_PROOF) {
+                return PSA_ERROR_BAD_STATE;
+            }
+            break;
+
+        default:
+            return PSA_ERROR_BAD_STATE;
+    }
+
+    return PSA_SUCCESS;
+}
+
+
+static psa_status_t psa_jpake_input_epilogue(
+    psa_pake_operation_t *operation)
+{
+    psa_jpake_computation_stage_t *computation_stage =
+        &operation->computation_stage.data.jpake_computation_stage;
+
+    if ((computation_stage->state == PSA_PAKE_INPUT_X1_X2 &&
+         computation_stage->sequence == PSA_PAKE_X2_STEP_ZK_PROOF) ||
+        (computation_stage->state == PSA_PAKE_INPUT_X4S &&
+         computation_stage->sequence == PSA_PAKE_X1_STEP_ZK_PROOF)) {
+        computation_stage->state = PSA_PAKE_STATE_READY;
+        computation_stage->input_step++;
+        computation_stage->sequence = PSA_PAKE_SEQ_INVALID;
+    } else {
+        computation_stage->sequence++;
+    }
+
+    return PSA_SUCCESS;
+}
+
+
 psa_status_t psa_pake_input(
     psa_pake_operation_t *operation,
     psa_pake_step_t step,
@@ -7371,9 +7612,11 @@
     size_t input_length)
 {
     psa_status_t status = PSA_ERROR_CORRUPTION_DETECTED;
+    psa_jpake_computation_stage_t *computation_stage =
+        &operation->computation_stage.data.jpake_computation_stage;
 
     if (operation->stage == PSA_PAKE_OPERATION_STAGE_COLLECT_INPUTS) {
-        if (operation->data.inputs.alg == PSA_ALG_NONE ||
+        if (operation->alg == PSA_ALG_NONE ||
             operation->data.inputs.password_len == 0 ||
             operation->data.inputs.role == PSA_PAKE_ROLE_NONE) {
             return PSA_ERROR_BAD_STATE;
@@ -7384,6 +7627,12 @@
 
         if (status == PSA_SUCCESS) {
             operation->stage = PSA_PAKE_OPERATION_STAGE_COMPUTATION;
+            if (operation->alg == PSA_ALG_JPAKE) {
+                computation_stage->state = PSA_PAKE_STATE_READY;
+                computation_stage->sequence = PSA_PAKE_SEQ_INVALID;
+                computation_stage->input_step = PSA_PAKE_STEP_X1_X2;
+                computation_stage->output_step = PSA_PAKE_STEP_X1_X2;
+            }
         } else {
             return status;
         }
@@ -7401,8 +7650,37 @@
         return PSA_ERROR_INVALID_ARGUMENT;
     }
 
-    return psa_driver_wrapper_pake_input(operation, step, input,
-                                         input_length);
+    switch (operation->alg) {
+        case PSA_ALG_JPAKE:
+            status = psa_jpake_input_prologue(operation, step, input_length);
+            if (status != PSA_SUCCESS) {
+                return status;
+            }
+            break;
+        default:
+            return PSA_ERROR_NOT_SUPPORTED;
+    }
+
+    status = psa_driver_wrapper_pake_input(operation, step,
+                                           &operation->computation_stage,
+                                           input, input_length);
+
+    if (status != PSA_SUCCESS) {
+        return status;
+    }
+
+    switch (operation->alg) {
+        case PSA_ALG_JPAKE:
+            status = psa_jpake_input_epilogue(operation);
+            if (status != PSA_SUCCESS) {
+                return status;
+            }
+            break;
+        default:
+            return PSA_ERROR_NOT_SUPPORTED;
+    }
+
+    return status;
 }
 
 psa_status_t psa_pake_get_implicit_key(
@@ -7412,11 +7690,20 @@
     psa_status_t status = MBEDTLS_ERR_ERROR_CORRUPTION_DETECTED;
     uint8_t shared_key[MBEDTLS_PSA_PAKE_BUFFER_SIZE];
     size_t shared_key_len = 0;
+    psa_jpake_computation_stage_t *computation_stage =
+        &operation->computation_stage.data.jpake_computation_stage;
 
     if (operation->id == 0) {
         return PSA_ERROR_BAD_STATE;
     }
 
+    if (operation->alg == PSA_ALG_JPAKE) {
+        if (computation_stage->input_step != PSA_PAKE_STEP_DERIVE ||
+            computation_stage->output_step != PSA_PAKE_STEP_DERIVE) {
+            return PSA_ERROR_BAD_STATE;
+        }
+    }
+
     status = psa_driver_wrapper_pake_get_implicit_key(operation,
                                                       shared_key,
                                                       &shared_key_len);
@@ -7436,18 +7723,40 @@
 
     mbedtls_platform_zeroize(shared_key, MBEDTLS_PSA_PAKE_BUFFER_SIZE);
 
+    psa_pake_abort(operation);
+
     return status;
 }
 
 psa_status_t psa_pake_abort(
     psa_pake_operation_t *operation)
 {
+    psa_jpake_computation_stage_t *computation_stage =
+        &operation->computation_stage.data.jpake_computation_stage;
+
+    /* Invalidate the J-PAKE state machine in every stage, so that no stale
+     * computation state (e.g. PSA_PAKE_STATE_SETUP) survives the abort. */
+    if (operation->alg == PSA_ALG_JPAKE) {
+        computation_stage->input_step = PSA_PAKE_STEP_INVALID;
+        computation_stage->output_step = PSA_PAKE_STEP_INVALID;
+        computation_stage->state = PSA_PAKE_STATE_INVALID;
+        computation_stage->sequence = PSA_PAKE_SEQ_INVALID;
+    }
+
     /* If we are in collecting inputs stage clear inputs. */
     if (operation->stage == PSA_PAKE_OPERATION_STAGE_COLLECT_INPUTS) {
+        /* The password is secret: wipe it before releasing the buffer. */
+        if (operation->data.inputs.password != NULL) {
+            mbedtls_platform_zeroize(operation->data.inputs.password,
+                                     operation->data.inputs.password_len);
+        }
         mbedtls_free(operation->data.inputs.password);
         memset(&operation->data.inputs, 0, sizeof(psa_crypto_driver_pake_inputs_t));
+        /* alg is no longer part of data.inputs, so the memset above does not
+         * clear it; reset it so that psa_pake_setup() can be called again. */
+        operation->alg = PSA_ALG_NONE;
         return PSA_SUCCESS;
     }
 
     return psa_driver_wrapper_pake_abort(operation);
 }