Change J-PAKE internal state machine

Keep track of the J-PAKE internal state in a more intuitive way.
Specifically, replace the existing state, sequence and step fields with 5 new fields:

* The round of J-PAKE we are currently in: FIRST, SECOND or FINISHED
* The 'mode' we are currently working in, INPUT or OUTPUT
* The number of inputs so far this round
* The number of outputs so far this round
* The PAKE step we are expecting, KEY_SHARE, ZK_PUBLIC or ZK_PROOF

This should make the state-transition code easier to read.

Signed-off-by: David Horstmann <david.horstmann@arm.com>
diff --git a/include/psa/crypto_extra.h b/include/psa/crypto_extra.h
index 5529dd1..a3351a6 100644
--- a/include/psa/crypto_extra.h
+++ b/include/psa/crypto_extra.h
@@ -2028,14 +2028,33 @@
     PSA_JPAKE_X4S_STEP_ZK_PROOF   = 12  /* Round 2: input Schnorr NIZKP proof for the X4S key (from peer) */
 } psa_crypto_driver_pake_step_t;
 
+typedef enum psa_jpake_round { /* The J-PAKE round an operation is currently on */
+    FIRST = 0,     /* First round of the exchange */
+    SECOND = 1,    /* Second round of the exchange */
+    FINISHED = 2   /* Presumably: both rounds completed (TODO confirm against the state code) */
+} psa_jpake_round_t; /* NOTE(review): unprefixed enumerators in a public header */
+
+typedef enum psa_jpake_io_mode { /* Whether we are currently inputting or outputting */
+    INPUT = 0,  /* NOTE(review): Windows <winuser.h> also declares a type named INPUT */
+    OUTPUT = 1  /* NOTE(review): consider PSA_JPAKE_-prefixed names in this public header */
+} psa_jpake_io_mode_t;
 
 struct psa_jpake_computation_stage_s {
-    psa_jpake_state_t MBEDTLS_PRIVATE(state);
-    psa_jpake_sequence_t MBEDTLS_PRIVATE(sequence);
-    psa_jpake_step_t MBEDTLS_PRIVATE(input_step);
-    psa_jpake_step_t MBEDTLS_PRIVATE(output_step);
+    /* Current J-PAKE round: FIRST, SECOND or FINISHED */
+    psa_jpake_round_t MBEDTLS_PRIVATE(round);
+    /* Current direction ('mode') of the exchange: INPUT or OUTPUT */
+    psa_jpake_io_mode_t MBEDTLS_PRIVATE(mode);
+    /* Inputs so far this round; presumably checked against PSA_JPAKE_EXPECTED_INPUTS */
+    uint8_t MBEDTLS_PRIVATE(inputs);
+    /* Outputs so far this round; presumably checked against PSA_JPAKE_EXPECTED_OUTPUTS */
+    uint8_t MBEDTLS_PRIVATE(outputs);
+    /* The next expected PAKE step: KEY_SHARE, ZK_PUBLIC or ZK_PROOF */
+    psa_pake_step_t MBEDTLS_PRIVATE(step);
 };
 
+/* Inputs/outputs expected per round: 2 in FIRST, 1 in SECOND, none once FINISHED */
+#define PSA_JPAKE_EXPECTED_INPUTS(round) ((round) == FINISHED ? 0 : ((round) == FIRST ? 2 : 1))
+#define PSA_JPAKE_EXPECTED_OUTPUTS(round) ((round) == FINISHED ? 0 : ((round) == FIRST ? 2 : 1))
 struct psa_pake_operation_s {
     /** Unique ID indicating which driver got assigned to do the
      * operation. Since driver contexts are driver-specific, swapping