Combine core PAKE computation stage (step, sequence, state) into a single driver step

Signed-off-by: Przemek Stekiel <przemyslaw.stekiel@mobica.com>
diff --git a/library/psa_crypto_pake.c b/library/psa_crypto_pake.c
index a238147..da10cdd 100644
--- a/library/psa_crypto_pake.c
+++ b/library/psa_crypto_pake.c
@@ -266,8 +266,7 @@
 
 static psa_status_t mbedtls_psa_pake_output_internal(
     mbedtls_psa_pake_operation_t *operation,
-    psa_pake_step_t step,
-    const psa_pake_computation_stage_t *computation_stage,
+    psa_pake_driver_step_t step,
     uint8_t *output,
     size_t output_size,
     size_t *output_length)
@@ -292,12 +291,8 @@
      * to return the right parts on each step.
      */
     if (operation->alg == PSA_ALG_JPAKE) {
-        const psa_jpake_computation_stage_t *jpake_computation_stage =
-            &computation_stage->data.jpake_computation_stage;
-
         /* Initialize & write round on KEY_SHARE sequences */
-        if (jpake_computation_stage->state == PSA_PAKE_OUTPUT_X1_X2 &&
-            jpake_computation_stage->sequence == PSA_PAKE_X1_STEP_KEY_SHARE) {
+        if (step == PSA_JPAKE_X1_STEP_KEY_SHARE) {
             ret = mbedtls_ecjpake_write_round_one(&operation->ctx.pake,
                                                   operation->buffer,
                                                   MBEDTLS_PSA_PAKE_BUFFER_SIZE,
@@ -309,8 +304,7 @@
             }
 
             operation->buffer_offset = 0;
-        } else if (jpake_computation_stage->state == PSA_PAKE_OUTPUT_X2S &&
-                   jpake_computation_stage->sequence == PSA_PAKE_X1_STEP_KEY_SHARE) {
+        } else if (step == PSA_JPAKE_X2S_STEP_KEY_SHARE) {
             ret = mbedtls_ecjpake_write_round_two(&operation->ctx.pake,
                                                   operation->buffer,
                                                   MBEDTLS_PSA_PAKE_BUFFER_SIZE,
@@ -335,8 +329,7 @@
          * output with a length byte, even less a curve identifier, as that
          * information is already available.
          */
-        if (jpake_computation_stage->state == PSA_PAKE_OUTPUT_X2S &&
-            jpake_computation_stage->sequence == PSA_PAKE_X1_STEP_KEY_SHARE &&
+        if (step == PSA_JPAKE_X2S_STEP_KEY_SHARE &&
             operation->role == PSA_PAKE_ROLE_SERVER) {
             /* Skip ECParameters, with is 3 bytes (RFC 8422) */
             operation->buffer_offset += 3;
@@ -362,10 +355,8 @@
         operation->buffer_offset += length;
 
         /* Reset buffer after ZK_PROOF sequence */
-        if ((jpake_computation_stage->state == PSA_PAKE_OUTPUT_X1_X2 &&
-             jpake_computation_stage->sequence == PSA_PAKE_X2_STEP_ZK_PROOF) ||
-            (jpake_computation_stage->state == PSA_PAKE_OUTPUT_X2S &&
-             jpake_computation_stage->sequence == PSA_PAKE_X1_STEP_ZK_PROOF)) {
+        if ((step == PSA_JPAKE_X2_STEP_ZK_PROOF) ||
+            (step == PSA_JPAKE_X2S_STEP_ZK_PROOF)) {
             mbedtls_platform_zeroize(operation->buffer, MBEDTLS_PSA_PAKE_BUFFER_SIZE);
             operation->buffer_length = 0;
             operation->buffer_offset = 0;
@@ -375,7 +366,6 @@
     } else
 #else
     (void) step;
-    (void) computation_stage;
     (void) output;
     (void) output_size;
     (void) output_length;
@@ -384,14 +374,13 @@
 }
 
 psa_status_t mbedtls_psa_pake_output(mbedtls_psa_pake_operation_t *operation,
-                                     psa_pake_step_t step,
-                                     const psa_pake_computation_stage_t *computation_stage,
+                                     psa_pake_driver_step_t step,
                                      uint8_t *output,
                                      size_t output_size,
                                      size_t *output_length)
 {
     psa_status_t status = mbedtls_psa_pake_output_internal(
-        operation, step, computation_stage, output, output_size, output_length);
+        operation, step, output, output_size, output_length);
 
     if (status != PSA_SUCCESS) {
         mbedtls_psa_pake_abort(operation);
@@ -402,8 +391,7 @@
 
 static psa_status_t mbedtls_psa_pake_input_internal(
     mbedtls_psa_pake_operation_t *operation,
-    psa_pake_step_t step,
-    const psa_pake_computation_stage_t *computation_stage,
+    psa_pake_driver_step_t step,
     const uint8_t *input,
     size_t input_length)
 {
@@ -427,8 +415,6 @@
      * This causes any input error to be only detected on the last step.
      */
     if (operation->alg == PSA_ALG_JPAKE) {
-        const psa_jpake_computation_stage_t *jpake_computation_stage =
-            &computation_stage->data.jpake_computation_stage;
         /*
          * Copy input to local buffer and format it as the Mbed TLS API
          * expects, i.e. as defined by draft-cragie-tls-ecjpake-01 section 7.
@@ -438,8 +424,7 @@
          * ECParameters structure - which means we have to prepend that when
          * we're a client.
          */
-        if (jpake_computation_stage->state == PSA_PAKE_INPUT_X4S &&
-            jpake_computation_stage->sequence == PSA_PAKE_X1_STEP_KEY_SHARE &&
+        if (step == PSA_JPAKE_X4S_STEP_KEY_SHARE &&
             operation->role == PSA_PAKE_ROLE_CLIENT) {
             /* We only support secp256r1. */
             /* This is the ECParameters structure defined by RFC 8422. */
@@ -462,8 +447,7 @@
         operation->buffer_length += input_length;
 
         /* Load buffer at each last round ZK_PROOF */
-        if (jpake_computation_stage->state == PSA_PAKE_INPUT_X1_X2 &&
-            jpake_computation_stage->sequence == PSA_PAKE_X2_STEP_ZK_PROOF) {
+        if (step == PSA_JPAKE_X2_STEP_ZK_PROOF) {
             ret = mbedtls_ecjpake_read_round_one(&operation->ctx.pake,
                                                  operation->buffer,
                                                  operation->buffer_length);
@@ -474,8 +458,7 @@
             if (ret != 0) {
                 return mbedtls_ecjpake_to_psa_error(ret);
             }
-        } else if (jpake_computation_stage->state == PSA_PAKE_INPUT_X4S &&
-                   jpake_computation_stage->sequence == PSA_PAKE_X1_STEP_ZK_PROOF) {
+        } else if (step == PSA_JPAKE_X4S_STEP_ZK_PROOF) {
             ret = mbedtls_ecjpake_read_round_two(&operation->ctx.pake,
                                                  operation->buffer,
                                                  operation->buffer_length);
@@ -492,7 +475,6 @@
     } else
 #else
     (void) step;
-    (void) computation_stage;
     (void) input;
     (void) input_length;
 #endif
@@ -500,13 +482,12 @@
 }
 
 psa_status_t mbedtls_psa_pake_input(mbedtls_psa_pake_operation_t *operation,
-                                    psa_pake_step_t step,
-                                    const psa_pake_computation_stage_t *computation_stage,
+                                    psa_pake_driver_step_t step,
                                     const uint8_t *input,
                                     size_t input_length)
 {
     psa_status_t status = mbedtls_psa_pake_input_internal(
-        operation, step, computation_stage, input, input_length);
+        operation, step, input, input_length);
 
     if (status != PSA_SUCCESS) {
         mbedtls_psa_pake_abort(operation);