Cleanup the code

Signed-off-by: Przemek Stekiel <przemyslaw.stekiel@mobica.com>
diff --git a/library/psa_crypto.c b/library/psa_crypto.c
index 0bb751b..f7b0270 100644
--- a/library/psa_crypto.c
+++ b/library/psa_crypto.c
@@ -898,7 +898,7 @@
     psa_algorithm_t alg)
 {
     psa_status_t status = PSA_ERROR_CORRUPTION_DETECTED;
-    psa_key_slot_t *slot;
+    psa_key_slot_t *slot = NULL;
 
     status = psa_get_and_lock_key_slot(key, p_slot);
     if (status != PSA_SUCCESS) {
@@ -7180,9 +7180,6 @@
     psa_pake_operation_t *operation,
     const psa_pake_cipher_suite_t *cipher_suite)
 {
-    psa_jpake_computation_stage_t *computation_stage =
-        &operation->computation_stage.data.jpake_computation_stage;
-
     if (operation->stage != PSA_PAKE_OPERATION_STAGE_COLLECT_INPUTS) {
         return PSA_ERROR_BAD_STATE;
     }
@@ -7205,6 +7202,9 @@
     operation->data.inputs.cipher_suite = *cipher_suite;
 
     if (operation->alg == PSA_ALG_JPAKE) {
+        psa_jpake_computation_stage_t *computation_stage =
+            &operation->computation_stage.data.jpake_computation_stage;
+
         computation_stage->state = PSA_PAKE_STATE_SETUP;
         computation_stage->sequence = PSA_PAKE_SEQ_INVALID;
         computation_stage->input_step = PSA_PAKE_STEP_X1_X2;
@@ -7260,7 +7260,6 @@
     operation->data.inputs.key_lifetime = attributes.core.lifetime;
 error:
     unlock_status = psa_unlock_key_slot(slot);
-
     return (status == PSA_SUCCESS) ? unlock_status : status;
 }
 
@@ -7603,7 +7602,6 @@
     return PSA_SUCCESS;
 }
 
-
 static psa_status_t psa_jpake_input_epilogue(
     psa_pake_operation_t *operation)
 {
@@ -7624,7 +7622,6 @@
     return PSA_SUCCESS;
 }
 
-
 psa_status_t psa_pake_input(
     psa_pake_operation_t *operation,
     psa_pake_step_t step,
@@ -7733,27 +7730,38 @@
 psa_status_t psa_pake_abort(
     psa_pake_operation_t *operation)
 {
-    psa_jpake_computation_stage_t *computation_stage =
-        &operation->computation_stage.data.jpake_computation_stage;
+    psa_status_t status = MBEDTLS_ERR_ERROR_CORRUPTION_DETECTED;
 
-    /* If we are in collecting inputs stage clear inputs. */
-    if (operation->stage == PSA_PAKE_OPERATION_STAGE_COLLECT_INPUTS) {
-        if (operation->data.inputs.password_len > 0) {
-            mbedtls_platform_zeroize(operation->data.inputs.password,
-                                     operation->data.inputs.password_len);
-            mbedtls_free(operation->data.inputs.password);
+    if (operation->id != 0) {
+        status = psa_driver_wrapper_pake_abort(operation);
+        if (status != PSA_SUCCESS) {
+            return status;
         }
-        memset(&operation->data.inputs, 0, sizeof(psa_crypto_driver_pake_inputs_t));
-        return PSA_SUCCESS;
     }
+
+    if (operation->data.inputs.password_len > 0) {
+        mbedtls_platform_zeroize(operation->data.inputs.password,
+                                 operation->data.inputs.password_len);
+        mbedtls_free(operation->data.inputs.password);
+    }
+
+    memset(&operation->data, 0, sizeof(operation->data));
+
     if (operation->alg == PSA_ALG_JPAKE) {
+        psa_jpake_computation_stage_t *computation_stage =
+            &operation->computation_stage.data.jpake_computation_stage;
+
         computation_stage->input_step = PSA_PAKE_STEP_INVALID;
         computation_stage->output_step = PSA_PAKE_STEP_INVALID;
         computation_stage->state = PSA_PAKE_STATE_INVALID;
         computation_stage->sequence = PSA_PAKE_SEQ_INVALID;
     }
 
-    return psa_driver_wrapper_pake_abort(operation);
+    operation->alg = PSA_ALG_NONE;
+    operation->stage = PSA_PAKE_OPERATION_STAGE_COLLECT_INPUTS;
+    operation->id = 0;
+
+    return PSA_SUCCESS;
 }
 
 #endif /* MBEDTLS_PSA_CRYPTO_C */
diff --git a/library/psa_crypto_pake.c b/library/psa_crypto_pake.c
index 01998a6..a238147 100644
--- a/library/psa_crypto_pake.c
+++ b/library/psa_crypto_pake.c
@@ -274,11 +274,7 @@
 {
     int ret = MBEDTLS_ERR_ERROR_CORRUPTION_DETECTED;
     size_t length;
-    (void) step;
-
-    if (operation->alg == PSA_ALG_NONE) {
-        return PSA_ERROR_BAD_STATE;
-    }
+    (void) step; // Unused parameter
 
 #if defined(MBEDTLS_PSA_BUILTIN_ALG_JPAKE)
     /*
@@ -412,10 +408,7 @@
     size_t input_length)
 {
     int ret = MBEDTLS_ERR_ERROR_CORRUPTION_DETECTED;
-    (void) step;
-    if (operation->alg == PSA_ALG_NONE) {
-        return PSA_ERROR_BAD_STATE;
-    }
+    (void) step; // Unused parameter
 
 #if defined(MBEDTLS_PSA_BUILTIN_ALG_JPAKE)
     /*
@@ -528,10 +521,6 @@
 {
     int ret = MBEDTLS_ERR_ERROR_CORRUPTION_DETECTED;
 
-    if (operation->alg == PSA_ALG_NONE) {
-        return PSA_ERROR_BAD_STATE;
-    }
-
 #if defined(MBEDTLS_PSA_BUILTIN_ALG_JPAKE)
     if (operation->alg == PSA_ALG_JPAKE) {
         ret = mbedtls_ecjpake_write_shared_key(&operation->ctx.pake,
@@ -562,10 +551,6 @@
 
 psa_status_t mbedtls_psa_pake_abort(mbedtls_psa_pake_operation_t *operation)
 {
-    if (operation->alg == PSA_ALG_NONE) {
-        return PSA_SUCCESS;
-    }
-
 #if defined(MBEDTLS_PSA_BUILTIN_ALG_JPAKE)
 
     if (operation->alg == PSA_ALG_JPAKE) {
diff --git a/tests/scripts/all.sh b/tests/scripts/all.sh
index e757674..98060d7 100755
--- a/tests/scripts/all.sh
+++ b/tests/scripts/all.sh
@@ -2524,7 +2524,7 @@
 }
 
 component_test_psa_crypto_config_accel_pake_no_fallback () {
-    msg "test: MBEDTLS_PSA_CRYPTO_CONFIG with accelerated PAKE"
+    msg "test: MBEDTLS_PSA_CRYPTO_CONFIG with accelerated PAKE - no fallback"
 
     # Start with full
     scripts/config.py full
@@ -2550,7 +2550,7 @@
     loc_accel_flags="$loc_accel_flags $( echo "$loc_accel_list" | sed 's/[^ ]* */-DMBEDTLS_PSA_ACCEL_&/g' )"
     make CFLAGS="$ASAN_CFLAGS -Werror -I../tests/include -I../tests -I../../tests -DPSA_CRYPTO_DRIVER_TEST -DMBEDTLS_TEST_LIBTESTDRIVER1 $loc_accel_flags" LDFLAGS="-ltestdriver1 $ASAN_CFLAGS"
 
-    msg "test: MBEDTLS_PSA_CRYPTO_CONFIG with accelerated PAKE"
+    msg "test: MBEDTLS_PSA_CRYPTO_CONFIG with accelerated PAKE - no fallback"
     make test
 }
 
diff --git a/tests/suites/test_suite_psa_crypto_driver_wrappers.function b/tests/suites/test_suite_psa_crypto_driver_wrappers.function
index 0f376ef..f718349 100644
--- a/tests/suites/test_suite_psa_crypto_driver_wrappers.function
+++ b/tests/suites/test_suite_psa_crypto_driver_wrappers.function
@@ -3248,7 +3248,7 @@
     ecjpake_do_round(alg, primitive_arg, &server, &client,
                      client_input_first, 2);
 
-    /* After get the key is obtained operation is aborted.
+    /* After the key is obtained operation is aborted.
        Adapt counter of expected hits. */
     if (pake_in_driver) {
         pake_expected_hit_count++;
@@ -3258,7 +3258,7 @@
     TEST_EQUAL(mbedtls_test_driver_pake_hooks.hits,
                pake_in_driver ? pake_expected_hit_count++ : pake_expected_hit_count);
 
-    /* After get the key is obtained operation is aborted.
+    /* After the key is obtained operation is aborted.
        Adapt counter of expected hits. */
     if (pake_in_driver) {
         pake_expected_hit_count++;