Merge pull request #6195 from superna9999/6149-driver-only-hashes-ec-j-pake

Driver-only hashes: EC J-PAKE
diff --git a/include/mbedtls/check_config.h b/include/mbedtls/check_config.h
index e0633f5..165bb09 100644
--- a/include/mbedtls/check_config.h
+++ b/include/mbedtls/check_config.h
@@ -108,7 +108,8 @@
 #endif
 
 #if defined(MBEDTLS_ECJPAKE_C) &&           \
-    ( !defined(MBEDTLS_ECP_C) || !defined(MBEDTLS_MD_C) )
+    ( !defined(MBEDTLS_ECP_C) ||            \
+      ( !defined(MBEDTLS_MD_C) && !defined(MBEDTLS_PSA_CRYPTO_C) ) )
 #error "MBEDTLS_ECJPAKE_C defined, but not all prerequisites"
 #endif
 
diff --git a/include/mbedtls/ecjpake.h b/include/mbedtls/ecjpake.h
index 7853a6a..ffdea05 100644
--- a/include/mbedtls/ecjpake.h
+++ b/include/mbedtls/ecjpake.h
@@ -70,7 +70,7 @@
  */
 typedef struct mbedtls_ecjpake_context
 {
-    const mbedtls_md_info_t *MBEDTLS_PRIVATE(md_info);   /**< Hash to use                    */
+    mbedtls_md_type_t MBEDTLS_PRIVATE(md_type);          /**< Hash to use                    */
     mbedtls_ecp_group MBEDTLS_PRIVATE(grp);              /**< Elliptic curve                 */
     mbedtls_ecjpake_role MBEDTLS_PRIVATE(role);          /**< Are we client or server?       */
     int MBEDTLS_PRIVATE(point_format);                   /**< Format for point export        */
diff --git a/include/mbedtls/mbedtls_config.h b/include/mbedtls/mbedtls_config.h
index a970de7..e1821f7 100644
--- a/include/mbedtls/mbedtls_config.h
+++ b/include/mbedtls/mbedtls_config.h
@@ -2330,6 +2330,9 @@
  *      ECJPAKE
  *
- * Requires: MBEDTLS_ECP_C, MBEDTLS_MD_C
+ * Requires: MBEDTLS_ECP_C and either MBEDTLS_MD_C or MBEDTLS_PSA_CRYPTO_C
+ *
+ * \warning If building without MBEDTLS_MD_C, you must call psa_crypto_init()
+ * before doing any EC J-PAKE operations.
  */
 #define MBEDTLS_ECJPAKE_C
 
diff --git a/library/ecjpake.c b/library/ecjpake.c
index c591924..7447354 100644
--- a/library/ecjpake.c
+++ b/library/ecjpake.c
@@ -30,6 +30,15 @@
 #include "mbedtls/platform_util.h"
 #include "mbedtls/error.h"
 
+/* We use MD first if it's available (for compatibility reasons)
+ * and "fall back" to PSA otherwise (which needs psa_crypto_init()). */
+#if !defined(MBEDTLS_MD_C)
+#include "psa/crypto.h"
+#include "mbedtls/psa_util.h"
+#endif /* !MBEDTLS_MD_C */
+
+#include "hash_info.h"
+
 #include <string.h>
 
 #if !defined(MBEDTLS_ECJPAKE_ALT)
@@ -45,12 +54,34 @@
 #define ID_MINE     ( ecjpake_id[ ctx->role ] )
 #define ID_PEER     ( ecjpake_id[ 1 - ctx->role ] )
 
+/**
+ * Hash input[0..ilen) with md_type into output, via MD if MBEDTLS_MD_C is
+ * enabled, else via PSA (psa_crypto_init() must have been called first).
+ * output must have room for mbedtls_hash_info_get_size( md_type ) bytes.
+ */
+static int mbedtls_ecjpake_compute_hash( mbedtls_md_type_t md_type,
+                                         const unsigned char *input,
+                                         size_t ilen,
+                                         unsigned char *output )
+{
+#if defined(MBEDTLS_MD_C)
+    return( mbedtls_md( mbedtls_md_info_from_type( md_type ),
+                        input, ilen, output ) );
+#else
+    psa_algorithm_t alg = mbedtls_psa_translate_md( md_type );
+    size_t out_len;
+    psa_status_t status = psa_hash_compute( alg, input, ilen, output,
+                                            PSA_HASH_LENGTH( alg ), &out_len );
+    return( mbedtls_md_error_from_psa( status ) );
+#endif /* MBEDTLS_MD_C */
+}
+
 /*
  * Initialize context
  */
 void mbedtls_ecjpake_init( mbedtls_ecjpake_context *ctx )
 {
-    ctx->md_info = NULL;
+    ctx->md_type = MBEDTLS_MD_NONE;
     mbedtls_ecp_group_init( &ctx->grp );
     ctx->point_format = MBEDTLS_ECP_PF_UNCOMPRESSED;
 
@@ -73,7 +104,7 @@
     if( ctx == NULL )
         return;
 
-    ctx->md_info = NULL;
+    ctx->md_type = MBEDTLS_MD_NONE;
     mbedtls_ecp_group_free( &ctx->grp );
 
     mbedtls_ecp_point_free( &ctx->Xm1 );
@@ -104,8 +135,15 @@
 
     ctx->role = role;
 
-    if( ( ctx->md_info = mbedtls_md_info_from_type( hash ) ) == NULL )
+#if defined(MBEDTLS_MD_C)
+    if( mbedtls_md_info_from_type( hash ) == NULL )
         return( MBEDTLS_ERR_MD_FEATURE_UNAVAILABLE );
+#else /* !MBEDTLS_MD_C: an unsupported hash translates to PSA_ALG_NONE */
+    if( mbedtls_psa_translate_md( hash ) == PSA_ALG_NONE )
+        return( MBEDTLS_ERR_MD_FEATURE_UNAVAILABLE );
+#endif /* MBEDTLS_MD_C */
+
+    ctx->md_type = hash;
 
     MBEDTLS_MPI_CHK( mbedtls_ecp_group_load( &ctx->grp, curve ) );
 
@@ -137,7 +175,7 @@
  */
 int mbedtls_ecjpake_check( const mbedtls_ecjpake_context *ctx )
 {
-    if( ctx->md_info == NULL ||
+    if( ctx->md_type == MBEDTLS_MD_NONE ||
         ctx->grp.id == MBEDTLS_ECP_DP_NONE ||
         ctx->s.p == NULL )
     {
@@ -184,7 +222,7 @@
 /*
  * Compute hash for ZKP (7.4.2.2.2.1)
  */
-static int ecjpake_hash( const mbedtls_md_info_t *md_info,
+static int ecjpake_hash( const mbedtls_md_type_t md_type,
                          const mbedtls_ecp_group *grp,
                          const int pf,
                          const mbedtls_ecp_point *G,
@@ -218,11 +256,12 @@
     p += id_len;
 
     /* Compute hash */
-    MBEDTLS_MPI_CHK( mbedtls_md( md_info, buf, p - buf, hash ) );
+    MBEDTLS_MPI_CHK( mbedtls_ecjpake_compute_hash( md_type,
+                                                   buf, p - buf, hash ) );
 
     /* Turn it into an integer mod n */
     MBEDTLS_MPI_CHK( mbedtls_mpi_read_binary( h, hash,
-                                        mbedtls_md_get_size( md_info ) ) );
+                                    mbedtls_hash_info_get_size( md_type ) ) );
     MBEDTLS_MPI_CHK( mbedtls_mpi_mod_mpi( h, h, &grp->N ) );
 
 cleanup:
@@ -232,7 +271,7 @@
 /*
  * Parse a ECShnorrZKP (7.4.2.2.2) and verify it (7.4.2.3.3)
  */
-static int ecjpake_zkp_read( const mbedtls_md_info_t *md_info,
+static int ecjpake_zkp_read( const mbedtls_md_type_t md_type,
                              const mbedtls_ecp_group *grp,
                              const int pf,
                              const mbedtls_ecp_point *G,
@@ -282,7 +321,7 @@
     /*
      * Verification
      */
-    MBEDTLS_MPI_CHK( ecjpake_hash( md_info, grp, pf, G, &V, X, id, &h ) );
+    MBEDTLS_MPI_CHK( ecjpake_hash( md_type, grp, pf, G, &V, X, id, &h ) );
     MBEDTLS_MPI_CHK( mbedtls_ecp_muladd( (mbedtls_ecp_group *) grp,
                      &VV, &h, X, &r, G ) );
 
@@ -304,7 +343,7 @@
 /*
  * Generate ZKP (7.4.2.3.2) and write it as ECSchnorrZKP (7.4.2.2.2)
  */
-static int ecjpake_zkp_write( const mbedtls_md_info_t *md_info,
+static int ecjpake_zkp_write( const mbedtls_md_type_t md_type,
                               const mbedtls_ecp_group *grp,
                               const int pf,
                               const mbedtls_ecp_point *G,
@@ -332,7 +371,7 @@
     /* Compute signature */
     MBEDTLS_MPI_CHK( mbedtls_ecp_gen_keypair_base( (mbedtls_ecp_group *) grp,
                                                    G, &v, &V, f_rng, p_rng ) );
-    MBEDTLS_MPI_CHK( ecjpake_hash( md_info, grp, pf, G, &V, X, id, &h ) );
+    MBEDTLS_MPI_CHK( ecjpake_hash( md_type, grp, pf, G, &V, X, id, &h ) );
     MBEDTLS_MPI_CHK( mbedtls_mpi_mul_mpi( &h, &h, x ) ); /* x*h */
     MBEDTLS_MPI_CHK( mbedtls_mpi_sub_mpi( &h, &v, &h ) ); /* v - x*h */
     MBEDTLS_MPI_CHK( mbedtls_mpi_mod_mpi( &h, &h, &grp->N ) ); /* r */
@@ -365,7 +404,7 @@
  * Parse a ECJPAKEKeyKP (7.4.2.2.1) and check proof
  * Output: verified public key X
  */
-static int ecjpake_kkp_read( const mbedtls_md_info_t *md_info,
+static int ecjpake_kkp_read( const mbedtls_md_type_t md_type,
                              const mbedtls_ecp_group *grp,
                              const int pf,
                              const mbedtls_ecp_point *G,
@@ -392,7 +431,7 @@
         goto cleanup;
     }
 
-    MBEDTLS_MPI_CHK( ecjpake_zkp_read( md_info, grp, pf, G, X, id, p, end ) );
+    MBEDTLS_MPI_CHK( ecjpake_zkp_read( md_type, grp, pf, G, X, id, p, end ) );
 
 cleanup:
     return( ret );
@@ -402,7 +441,7 @@
  * Generate an ECJPAKEKeyKP
  * Output: the serialized structure, plus private/public key pair
  */
-static int ecjpake_kkp_write( const mbedtls_md_info_t *md_info,
+static int ecjpake_kkp_write( const mbedtls_md_type_t md_type,
                               const mbedtls_ecp_group *grp,
                               const int pf,
                               const mbedtls_ecp_point *G,
@@ -428,7 +467,7 @@
     *p += len;
 
     /* Generate and write proof */
-    MBEDTLS_MPI_CHK( ecjpake_zkp_write( md_info, grp, pf, G, x, X, id,
+    MBEDTLS_MPI_CHK( ecjpake_zkp_write( md_type, grp, pf, G, x, X, id,
                                         p, end, f_rng, p_rng ) );
 
 cleanup:
@@ -439,7 +478,7 @@
  * Read a ECJPAKEKeyKPPairList (7.4.2.3) and check proofs
  * Outputs: verified peer public keys Xa, Xb
  */
-static int ecjpake_kkpp_read( const mbedtls_md_info_t *md_info,
+static int ecjpake_kkpp_read( const mbedtls_md_type_t md_type,
                               const mbedtls_ecp_group *grp,
                               const int pf,
                               const mbedtls_ecp_point *G,
@@ -458,8 +497,8 @@
      *     ECJPAKEKeyKP ecjpake_key_kp_pair_list[2];
      * } ECJPAKEKeyKPPairList;
      */
-    MBEDTLS_MPI_CHK( ecjpake_kkp_read( md_info, grp, pf, G, Xa, id, &p, end ) );
-    MBEDTLS_MPI_CHK( ecjpake_kkp_read( md_info, grp, pf, G, Xb, id, &p, end ) );
+    MBEDTLS_MPI_CHK( ecjpake_kkp_read( md_type, grp, pf, G, Xa, id, &p, end ) );
+    MBEDTLS_MPI_CHK( ecjpake_kkp_read( md_type, grp, pf, G, Xb, id, &p, end ) );
 
     if( p != end )
         ret = MBEDTLS_ERR_ECP_BAD_INPUT_DATA;
@@ -472,7 +511,7 @@
  * Generate a ECJPAKEKeyKPPairList
  * Outputs: the serialized structure, plus two private/public key pairs
  */
-static int ecjpake_kkpp_write( const mbedtls_md_info_t *md_info,
+static int ecjpake_kkpp_write( const mbedtls_md_type_t md_type,
                                const mbedtls_ecp_group *grp,
                                const int pf,
                                const mbedtls_ecp_point *G,
@@ -491,9 +530,9 @@
     unsigned char *p = buf;
     const unsigned char *end = buf + len;
 
-    MBEDTLS_MPI_CHK( ecjpake_kkp_write( md_info, grp, pf, G, xm1, Xa, id,
+    MBEDTLS_MPI_CHK( ecjpake_kkp_write( md_type, grp, pf, G, xm1, Xa, id,
                 &p, end, f_rng, p_rng ) );
-    MBEDTLS_MPI_CHK( ecjpake_kkp_write( md_info, grp, pf, G, xm2, Xb, id,
+    MBEDTLS_MPI_CHK( ecjpake_kkp_write( md_type, grp, pf, G, xm2, Xb, id,
                 &p, end, f_rng, p_rng ) );
 
     *olen = p - buf;
@@ -509,7 +548,7 @@
                                     const unsigned char *buf,
                                     size_t len )
 {
-    return( ecjpake_kkpp_read( ctx->md_info, &ctx->grp, ctx->point_format,
+    return( ecjpake_kkpp_read( ctx->md_type, &ctx->grp, ctx->point_format,
                                &ctx->grp.G,
                                &ctx->Xp1, &ctx->Xp2, ID_PEER,
                                buf, len ) );
@@ -523,7 +562,7 @@
                             int (*f_rng)(void *, unsigned char *, size_t),
                             void *p_rng )
 {
-    return( ecjpake_kkpp_write( ctx->md_info, &ctx->grp, ctx->point_format,
+    return( ecjpake_kkpp_write( ctx->md_type, &ctx->grp, ctx->point_format,
                                 &ctx->grp.G,
                                 &ctx->xm1, &ctx->Xm1, &ctx->xm2, &ctx->Xm2,
                                 ID_MINE, buf, len, olen, f_rng, p_rng ) );
@@ -593,7 +632,7 @@
         }
     }
 
-    MBEDTLS_MPI_CHK( ecjpake_kkp_read( ctx->md_info, &ctx->grp,
+    MBEDTLS_MPI_CHK( ecjpake_kkp_read( ctx->md_type, &ctx->grp,
                             ctx->point_format,
                             &G, &ctx->Xp, ID_PEER, &p, end ) );
 
@@ -703,7 +742,7 @@
                      ctx->point_format, &ec_len, p, end - p ) );
     p += ec_len;
 
-    MBEDTLS_MPI_CHK( ecjpake_zkp_write( ctx->md_info, &ctx->grp,
+    MBEDTLS_MPI_CHK( ecjpake_zkp_write( ctx->md_type, &ctx->grp,
                                         ctx->point_format,
                                         &G, &xm, &Xm, ID_MINE,
                                         &p, end, f_rng, p_rng ) );
@@ -732,7 +771,7 @@
     unsigned char kx[MBEDTLS_ECP_MAX_BYTES];
     size_t x_bytes;
 
-    *olen = mbedtls_md_get_size( ctx->md_info );
+    *olen = mbedtls_hash_info_get_size( ctx->md_type );
     if( len < *olen )
         return( MBEDTLS_ERR_ECP_BUFFER_TOO_SMALL );
 
@@ -758,7 +797,8 @@
     /* PMS = SHA-256( K.X ) */
     x_bytes = ( ctx->grp.pbits + 7 ) / 8;
     MBEDTLS_MPI_CHK( mbedtls_mpi_write_binary( &K.X, kx, x_bytes ) );
-    MBEDTLS_MPI_CHK( mbedtls_md( ctx->md_info, kx, x_bytes, buf ) );
+    MBEDTLS_MPI_CHK( mbedtls_ecjpake_compute_hash( ctx->md_type,
+                                                   kx, x_bytes, buf ) );
 
 cleanup:
     mbedtls_ecp_point_free( &K );
diff --git a/tests/scripts/all.sh b/tests/scripts/all.sh
index 5393ef4..0356139 100755
--- a/tests/scripts/all.sh
+++ b/tests/scripts/all.sh
@@ -1208,7 +1208,6 @@
     scripts/config.py crypto_full
     scripts/config.py unset MBEDTLS_MD_C
     # Direct dependencies
-    scripts/config.py unset MBEDTLS_ECJPAKE_C
     scripts/config.py unset MBEDTLS_HKDF_C
     scripts/config.py unset MBEDTLS_HMAC_DRBG_C
     scripts/config.py unset MBEDTLS_PKCS5_C
@@ -1869,7 +1868,6 @@
     # Also unset MD_C and things that depend on it;
     # see component_test_crypto_full_no_md.
     scripts/config.py unset MBEDTLS_MD_C
-    scripts/config.py unset MBEDTLS_ECJPAKE_C
     scripts/config.py unset MBEDTLS_HKDF_C
     scripts/config.py unset MBEDTLS_HMAC_DRBG_C
     scripts/config.py unset MBEDTLS_PKCS5_C
diff --git a/tests/suites/test_suite_ecjpake.function b/tests/suites/test_suite_ecjpake.function
index e8aaa6c..449b368 100644
--- a/tests/suites/test_suite_ecjpake.function
+++ b/tests/suites/test_suite_ecjpake.function
@@ -1,7 +1,8 @@
 /* BEGIN_HEADER */
 #include "mbedtls/ecjpake.h"
+#include "legacy_or_psa.h"
 
-#if defined(MBEDTLS_ECP_DP_SECP256R1_ENABLED) && defined(MBEDTLS_SHA256_C)
+#if defined(MBEDTLS_ECP_DP_SECP256R1_ENABLED) && defined(MBEDTLS_HAS_ALG_SHA_256_VIA_MD_OR_PSA)
 static const unsigned char ecjpake_test_x1[] = {
     0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0a, 0x0b, 0x0c,
     0x0d, 0x0e, 0x0f, 0x10, 0x11, 0x12, 0x13, 0x14, 0x15, 0x16, 0x17, 0x18,
@@ -90,7 +91,7 @@
 }
 
 #define ADD_SIZE( x )   x, sizeof( x )
-#endif /* MBEDTLS_ECP_DP_SECP256R1_ENABLED && MBEDTLS_SHA256_C */
+#endif /* MBEDTLS_ECP_DP_SECP256R1_ENABLED && MBEDTLS_HAS_ALG_SHA_256_VIA_MD_OR_PSA */
 /* END_HEADER */
 
 /* BEGIN_DEPENDENCIES
@@ -126,7 +127,7 @@
 }
 /* END_CASE */
 
-/* BEGIN_CASE depends_on:MBEDTLS_ECP_DP_SECP256R1_ENABLED:MBEDTLS_SHA256_C */
+/* BEGIN_CASE depends_on:MBEDTLS_ECP_DP_SECP256R1_ENABLED:MBEDTLS_HAS_ALG_SHA_256_VIA_MD_OR_PSA */
 void read_bad_md( data_t *msg )
 {
     mbedtls_ecjpake_context corrupt_ctx;
@@ -137,17 +138,17 @@
     mbedtls_ecjpake_init( &corrupt_ctx );
     TEST_ASSERT( mbedtls_ecjpake_setup( &corrupt_ctx, any_role,
                  MBEDTLS_MD_SHA256, MBEDTLS_ECP_DP_SECP256R1, pw, pw_len ) == 0 );
-    corrupt_ctx.md_info = NULL;
+    corrupt_ctx.md_type = MBEDTLS_MD_NONE;
 
-    TEST_ASSERT( mbedtls_ecjpake_read_round_one( &corrupt_ctx, msg->x,
-                 msg->len ) == MBEDTLS_ERR_MD_BAD_INPUT_DATA );
+    TEST_EQUAL( mbedtls_ecjpake_read_round_one( &corrupt_ctx, msg->x,
+                 msg->len ), MBEDTLS_ERR_MD_BAD_INPUT_DATA );
 
 exit:
     mbedtls_ecjpake_free( &corrupt_ctx );
 }
 /* END_CASE */
 
-/* BEGIN_CASE depends_on:MBEDTLS_ECP_DP_SECP256R1_ENABLED:MBEDTLS_SHA256_C */
+/* BEGIN_CASE depends_on:MBEDTLS_ECP_DP_SECP256R1_ENABLED:MBEDTLS_HAS_ALG_SHA_256_VIA_MD_OR_PSA */
 void read_round_one( int role, data_t * msg, int ref_ret )
 {
     mbedtls_ecjpake_context ctx;
@@ -166,7 +167,7 @@
 }
 /* END_CASE */
 
-/* BEGIN_CASE depends_on:MBEDTLS_ECP_DP_SECP256R1_ENABLED:MBEDTLS_SHA256_C */
+/* BEGIN_CASE depends_on:MBEDTLS_ECP_DP_SECP256R1_ENABLED:MBEDTLS_HAS_ALG_SHA_256_VIA_MD_OR_PSA */
 void read_round_two_cli( data_t * msg, int ref_ret )
 {
     mbedtls_ecjpake_context ctx;
@@ -191,7 +192,7 @@
 }
 /* END_CASE */
 
-/* BEGIN_CASE depends_on:MBEDTLS_ECP_DP_SECP256R1_ENABLED:MBEDTLS_SHA256_C */
+/* BEGIN_CASE depends_on:MBEDTLS_ECP_DP_SECP256R1_ENABLED:MBEDTLS_HAS_ALG_SHA_256_VIA_MD_OR_PSA */
 void read_round_two_srv( data_t * msg, int ref_ret )
 {
     mbedtls_ecjpake_context ctx;