Fix configuration for accelerated jpake

Signed-off-by: Przemek Stekiel <przemyslaw.stekiel@mobica.com>
diff --git a/include/mbedtls/config_psa.h b/include/mbedtls/config_psa.h
index 48b2d32..f7de6d1 100644
--- a/include/mbedtls/config_psa.h
+++ b/include/mbedtls/config_psa.h
@@ -147,12 +147,15 @@
 #endif
 
 #if defined(PSA_WANT_ALG_JPAKE)
+#if !defined(MBEDTLS_PSA_ACCEL_ALG_JPAKE)
 #define MBEDTLS_PSA_BUILTIN_PAKE 1
 #define MBEDTLS_PSA_BUILTIN_ALG_JPAKE 1
 #define MBEDTLS_ECP_DP_SECP256R1_ENABLED
 #define MBEDTLS_BIGNUM_C
 #define MBEDTLS_ECP_C
 #define MBEDTLS_ECJPAKE_C
+#define MBEDTLS_SHA256_C
+#endif /* MBEDTLS_PSA_ACCEL_ALG_JPAKE */
 #endif /* PSA_WANT_ALG_JPAKE */
 
 #if defined(PSA_WANT_ALG_RIPEMD160) && !defined(MBEDTLS_PSA_ACCEL_ALG_RIPEMD160)
diff --git a/include/psa/crypto_extra.h b/include/psa/crypto_extra.h
index 5f86c3f..5cf5615 100644
--- a/include/psa/crypto_extra.h
+++ b/include/psa/crypto_extra.h
@@ -2042,7 +2042,7 @@
     /* Holds computation stage of the PAKE algorithms. */
     union {
         uint8_t MBEDTLS_PRIVATE(dummy);
-#if defined(MBEDTLS_PSA_BUILTIN_ALG_JPAKE)
+#if defined(PSA_WANT_ALG_JPAKE)
         psa_jpake_computation_stage_t MBEDTLS_PRIVATE(jpake);
 #endif
     } MBEDTLS_PRIVATE(computation_stage);
diff --git a/library/psa_crypto.c b/library/psa_crypto.c
index 1611fc9..d7eeead 100644
--- a/library/psa_crypto.c
+++ b/library/psa_crypto.c
@@ -7255,7 +7255,7 @@
     operation->alg = cipher_suite->algorithm;
     operation->data.inputs.cipher_suite = *cipher_suite;
 
-#if defined(MBEDTLS_PSA_BUILTIN_ALG_JPAKE)
+#if defined(PSA_WANT_ALG_JPAKE)
     if (operation->alg == PSA_ALG_JPAKE) {
         psa_jpake_computation_stage_t *computation_stage =
             &operation->computation_stage.jpake;
@@ -7405,7 +7405,7 @@
 }
 
 /* Auxiliary function to convert core computation stage(step, sequence, state) to single driver step. */
-#if defined(MBEDTLS_PSA_BUILTIN_ALG_JPAKE)
+#if defined(PSA_WANT_ALG_JPAKE)
 static psa_crypto_driver_pake_step_t convert_jpake_computation_stage_to_driver_step(
     psa_jpake_computation_stage_t *stage)
 {
@@ -7499,7 +7499,7 @@
     mbedtls_free(inputs.password);
 
     if (status == PSA_SUCCESS) {
-#if defined(MBEDTLS_PSA_BUILTIN_ALG_JPAKE)
+#if defined(PSA_WANT_ALG_JPAKE)
         if (operation->alg == PSA_ALG_JPAKE) {
             operation->stage = PSA_PAKE_OPERATION_STAGE_COMPUTATION;
             psa_jpake_computation_stage_t *computation_stage =
@@ -7517,7 +7517,7 @@
     return status;
 }
 
-#if defined(MBEDTLS_PSA_BUILTIN_ALG_JPAKE)
+#if defined(PSA_WANT_ALG_JPAKE)
 static psa_status_t psa_jpake_output_prologue(
     psa_pake_operation_t *operation,
     psa_pake_step_t step)
@@ -7639,7 +7639,7 @@
     }
 
     switch (operation->alg) {
-#if defined(MBEDTLS_PSA_BUILTIN_ALG_JPAKE)
+#if defined(PSA_WANT_ALG_JPAKE)
         case PSA_ALG_JPAKE:
             status = psa_jpake_output_prologue(operation, step);
             if (status != PSA_SUCCESS) {
@@ -7653,7 +7653,7 @@
             goto exit;
     }
 
-#if defined(MBEDTLS_PSA_BUILTIN_ALG_JPAKE)
+#if defined(PSA_WANT_ALG_JPAKE)
     status = psa_driver_wrapper_pake_output(operation,
                                             convert_jpake_computation_stage_to_driver_step(
                                                 &operation->computation_stage.jpake),
@@ -7670,7 +7670,7 @@
     }
 
     switch (operation->alg) {
-#if defined(MBEDTLS_PSA_BUILTIN_ALG_JPAKE)
+#if defined(PSA_WANT_ALG_JPAKE)
         case PSA_ALG_JPAKE:
             status = psa_jpake_output_epilogue(operation);
             if (status != PSA_SUCCESS) {
@@ -7689,7 +7689,7 @@
     return status;
 }
 
-#if defined(MBEDTLS_PSA_BUILTIN_ALG_JPAKE)
+#if defined(PSA_WANT_ALG_JPAKE)
 static psa_status_t psa_jpake_input_prologue(
     psa_pake_operation_t *operation,
     psa_pake_step_t step,
@@ -7816,7 +7816,7 @@
     }
 
     switch (operation->alg) {
-#if defined(MBEDTLS_PSA_BUILTIN_ALG_JPAKE)
+#if defined(PSA_WANT_ALG_JPAKE)
         case PSA_ALG_JPAKE:
             status = psa_jpake_input_prologue(operation, step, input_length);
             if (status != PSA_SUCCESS) {
@@ -7830,7 +7830,7 @@
             goto exit;
     }
 
-#if defined(MBEDTLS_PSA_BUILTIN_ALG_JPAKE)
+#if defined(PSA_WANT_ALG_JPAKE)
     status = psa_driver_wrapper_pake_input(operation,
                                            convert_jpake_computation_stage_to_driver_step(
                                                &operation->computation_stage.jpake),
@@ -7846,7 +7846,7 @@
     }
 
     switch (operation->alg) {
-#if defined(MBEDTLS_PSA_BUILTIN_ALG_JPAKE)
+#if defined(PSA_WANT_ALG_JPAKE)
         case PSA_ALG_JPAKE:
             status = psa_jpake_input_epilogue(operation);
             if (status != PSA_SUCCESS) {
@@ -7879,7 +7879,7 @@
         goto exit;
     }
 
-#if defined(MBEDTLS_PSA_BUILTIN_ALG_JPAKE)
+#if defined(PSA_WANT_ALG_JPAKE)
     if (operation->alg == PSA_ALG_JPAKE) {
         psa_jpake_computation_stage_t *computation_stage =
             &operation->computation_stage.jpake;
diff --git a/tests/include/test/drivers/crypto_config_test_driver_extension.h b/tests/include/test/drivers/crypto_config_test_driver_extension.h
index 393d632..26c432c 100644
--- a/tests/include/test/drivers/crypto_config_test_driver_extension.h
+++ b/tests/include/test/drivers/crypto_config_test_driver_extension.h
@@ -158,6 +158,14 @@
 #endif
 #endif
 
+#if defined(PSA_WANT_ALG_JPAKE)
+#if defined(MBEDTLS_PSA_ACCEL_ALG_JPAKE)
+#undef MBEDTLS_PSA_ACCEL_ALG_JPAKE
+#else
+#define MBEDTLS_PSA_ACCEL_ALG_JPAKE 1
+#endif
+#endif
+
 #if defined(PSA_WANT_KEY_TYPE_AES)
 #if defined(MBEDTLS_PSA_ACCEL_KEY_TYPE_AES)
 #undef MBEDTLS_PSA_ACCEL_KEY_TYPE_AES
diff --git a/tests/scripts/all.sh b/tests/scripts/all.sh
index a2c0cb7..f20a7dc 100755
--- a/tests/scripts/all.sh
+++ b/tests/scripts/all.sh
@@ -2500,7 +2500,7 @@
     make test
 }
 
-component_test_psa_crypto_config_accel_pake () {
+component_test_psa_crypto_config_accel_pake() {
     msg "test: MBEDTLS_PSA_CRYPTO_CONFIG with accelerated PAKE"
 
     # Start with full
@@ -2518,44 +2518,8 @@
     scripts/config.py set MBEDTLS_PSA_CRYPTO_DRIVERS
     scripts/config.py set MBEDTLS_PSA_CRYPTO_CONFIG
 
-    scripts/config.py unset MBEDTLS_ECJPAKE_C
-
-    # Dynamic secure element support is a deprecated feature and needs to be disabled here.
-    # This is done to have the same form of psa_key_attributes_s for libdriver and library.
-    scripts/config.py unset MBEDTLS_PSA_CRYPTO_SE_C
-
-    loc_accel_flags="$loc_accel_flags $( echo "$loc_accel_list" | sed 's/[^ ]* */-DMBEDTLS_PSA_ACCEL_&/g' )"
-    make CFLAGS="$ASAN_CFLAGS -Werror -I../tests/include -I../tests -I../../tests -DPSA_CRYPTO_DRIVER_TEST -DMBEDTLS_TEST_LIBTESTDRIVER1 $loc_accel_flags" LDFLAGS="-ltestdriver1 $ASAN_CFLAGS"
-
-    msg "test: ssl-opt.sh, MBEDTLS_PSA_CRYPTO_CONFIG with accelerated PAKE"
-    tests/ssl-opt.sh -f "ECJPAKE"
-
-    msg "test: MBEDTLS_PSA_CRYPTO_CONFIG with accelerated PAKE"
-    make test
-}
-
-component_test_psa_crypto_config_accel_pake_no_fallback () {
-    msg "test: MBEDTLS_PSA_CRYPTO_CONFIG with accelerated PAKE - no fallback"
-
-    # Start with full
-    scripts/config.py full
-
-    # Disable ALG_STREAM_CIPHER and ALG_ECB_NO_PADDING to avoid having
-    # partial support for cipher operations in the driver test library.
-    scripts/config.py -f include/psa/crypto_config.h unset PSA_WANT_ALG_STREAM_CIPHER
-    scripts/config.py -f include/psa/crypto_config.h unset PSA_WANT_ALG_ECB_NO_PADDING
-
-    loc_accel_list="ALG_JPAKE"
-    loc_accel_flags=$( echo "$loc_accel_list" | sed 's/[^ ]* */-DLIBTESTDRIVER1_MBEDTLS_PSA_ACCEL_&/g' )
-    make -C tests libtestdriver1.a CFLAGS="$ASAN_CFLAGS $loc_accel_flags" LDFLAGS="$ASAN_CFLAGS"
-
-    scripts/config.py set MBEDTLS_PSA_CRYPTO_DRIVERS
-    scripts/config.py set MBEDTLS_PSA_CRYPTO_CONFIG
-
-    scripts/config.py unset MBEDTLS_ECJPAKE_C
-
     # Make build-in fallback not available
-    scripts/config.py -f include/psa/crypto_config.h unset PSA_WANT_ALG_JPAKE
+    scripts/config.py unset MBEDTLS_ECJPAKE_C
     scripts/config.py unset MBEDTLS_KEY_EXCHANGE_ECJPAKE_ENABLED
 
     # Dynamic secure element support is a deprecated feature and needs to be disabled here.
@@ -2565,7 +2529,9 @@
     loc_accel_flags="$loc_accel_flags $( echo "$loc_accel_list" | sed 's/[^ ]* */-DMBEDTLS_PSA_ACCEL_&/g' )"
     make CFLAGS="$ASAN_CFLAGS -Werror -I../tests/include -I../tests -I../../tests -DPSA_CRYPTO_DRIVER_TEST -DMBEDTLS_TEST_LIBTESTDRIVER1 $loc_accel_flags" LDFLAGS="-ltestdriver1 $ASAN_CFLAGS"
 
-    msg "test: MBEDTLS_PSA_CRYPTO_CONFIG with accelerated PAKE - no fallback"
+    not grep mbedtls_ecjpake_init library/ecjpake.o
+
+    msg "test: MBEDTLS_PSA_CRYPTO_CONFIG with accelerated PAKE"
     make test
 }
 
diff --git a/tests/suites/test_suite_psa_crypto_driver_wrappers.function b/tests/suites/test_suite_psa_crypto_driver_wrappers.function
index 6522fe5..8a4c007 100644
--- a/tests/suites/test_suite_psa_crypto_driver_wrappers.function
+++ b/tests/suites/test_suite_psa_crypto_driver_wrappers.function
@@ -2976,7 +2976,7 @@
 }
 /* END_CASE */
 
-/* BEGIN_CASE depends_on:PSA_WANT_ALG_JPAKE:PSA_WANT_KEY_TYPE_ECC_KEY_PAIR:PSA_WANT_ECC_SECP_R1_256:PSA_WANT_ALG_SHA_256 */
+/* BEGIN_CASE depends_on:PSA_WANT_ALG_JPAKE */
 void pake_operations(data_t *pw_data, int forced_status_setup_arg, int forced_status_arg,
                      data_t *forced_output, int expected_status_arg,
                      int fut)
diff --git a/tests/suites/test_suite_psa_crypto_pake.function b/tests/suites/test_suite_psa_crypto_pake.function
index f094eb9..2bed45a 100644
--- a/tests/suites/test_suite_psa_crypto_pake.function
+++ b/tests/suites/test_suite_psa_crypto_pake.function
@@ -909,7 +909,7 @@
 }
 /* END_CASE */
 
-/* BEGIN_CASE depends_on:PSA_WANT_ALG_JPAKE:PSA_ALG_SHA_256 */
+/* BEGIN_CASE depends_on:PSA_WANT_ALG_JPAKE */
 void pake_input_getters_password()
 {
     psa_pake_cipher_suite_t cipher_suite = psa_pake_cipher_suite_init();
@@ -975,7 +975,7 @@
 }
 /* END_CASE */
 
-/* BEGIN_CASE depends_on:PSA_WANT_ALG_JPAKE:PSA_ALG_SHA_256 */
+/* BEGIN_CASE depends_on:PSA_WANT_ALG_JPAKE */
 void pake_input_getters_cipher_suite()
 {
     psa_pake_cipher_suite_t cipher_suite = psa_pake_cipher_suite_init();
@@ -1008,7 +1008,7 @@
 }
 /* END_CASE */
 
-/* BEGIN_CASE depends_on:PSA_WANT_ALG_JPAKE:PSA_ALG_SHA_256 */
+/* BEGIN_CASE depends_on:PSA_WANT_ALG_JPAKE */
 void pake_input_getters_role()
 {
     psa_pake_cipher_suite_t cipher_suite = psa_pake_cipher_suite_init();