PKCS#1 v2.1 now builds with PSA if no MD_C

Test coverage not there yet, as the entire test_suite_pkcs1_v21 is
skipped so far - dependencies to be adjusted in a future commit.

Signed-off-by: Manuel Pégourié-Gonnard <manuel.pegourie-gonnard@arm.com>
diff --git a/include/mbedtls/check_config.h b/include/mbedtls/check_config.h
index 55fca55..719ffb7 100644
--- a/include/mbedtls/check_config.h
+++ b/include/mbedtls/check_config.h
@@ -162,7 +162,8 @@
 #error "MBEDTLS_PKCS12_C defined, but not all prerequisites"
 #endif
 
-#if defined(MBEDTLS_PKCS1_V21) && !defined(MBEDTLS_MD_C)
+#if defined(MBEDTLS_PKCS1_V21) && \
+    !( defined(MBEDTLS_MD_C) || defined(MBEDTLS_PSA_CRYPTO_C) )
 #error "MBEDTLS_PKCS1_V21 defined, but not all prerequisites"
 #endif
 
diff --git a/include/mbedtls/config_psa.h b/include/mbedtls/config_psa.h
index 2e9e451..fbfcdc3 100644
--- a/include/mbedtls/config_psa.h
+++ b/include/mbedtls/config_psa.h
@@ -158,7 +158,6 @@
 #define MBEDTLS_BIGNUM_C
 #define MBEDTLS_OID_C
 #define MBEDTLS_PKCS1_V21
-#define MBEDTLS_MD_C
 #endif /* !MBEDTLS_PSA_ACCEL_ALG_RSA_OAEP */
 #endif /* PSA_WANT_ALG_RSA_OAEP */
 
@@ -189,7 +188,6 @@
 #define MBEDTLS_BIGNUM_C
 #define MBEDTLS_OID_C
 #define MBEDTLS_PKCS1_V21
-#define MBEDTLS_MD_C
 #endif /* !MBEDTLS_PSA_ACCEL_ALG_RSA_PSS */
 #endif /* PSA_WANT_ALG_RSA_PSS */
 
diff --git a/include/mbedtls/mbedtls_config.h b/include/mbedtls/mbedtls_config.h
index e96a797..d7662ef 100644
--- a/include/mbedtls/mbedtls_config.h
+++ b/include/mbedtls/mbedtls_config.h
@@ -1141,7 +1141,10 @@
  *
  * Enable support for PKCS#1 v2.1 encoding.
  *
- * Requires: MBEDTLS_MD_C, MBEDTLS_RSA_C
+ * Requires: MBEDTLS_RSA_C and (MBEDTLS_MD_C or MBEDTLS_PSA_CRYPTO_C).
+ *
+ * \warning If building without MBEDTLS_MD_C, you must call psa_crypto_init()
+ * before doing any PKCS#1 v2.1 operation.
  *
  * This enables support for RSAES-OAEP and RSASSA-PSS operations.
  */
diff --git a/library/rsa.c b/library/rsa.c
index d879a30..5e7cbcc 100644
--- a/library/rsa.c
+++ b/library/rsa.c
@@ -54,6 +54,18 @@
 #include <stdlib.h>
 #endif
 
+/* We use MD first if it's available (for compatibility reasons)
+ * and "fall back" to PSA otherwise (which needs psa_crypto_init()). */
+#if defined(MBEDTLS_PKCS1_V21)
+#if defined(MBEDTLS_MD_C)
+#define HASH_MAX_SIZE   MBEDTLS_MD_MAX_SIZE
+#else /* MBEDTLS_MD_C */
+#include "psa/crypto.h"
+#include "mbedtls/psa_util.h"
+#define HASH_MAX_SIZE   PSA_HASH_MAX_SIZE
+#endif /* MBEDTLS_MD_C */
+#endif /* MBEDTLS_PKCS1_V21 */
+
 #if defined(MBEDTLS_PLATFORM_C)
 #include "mbedtls/platform.h"
 #else
@@ -1086,6 +1098,25 @@
 }
 
 #if defined(MBEDTLS_PKCS1_V21)
+#if !defined(MBEDTLS_MD_C)
+static int ret_from_status( psa_status_t status )
+{
+    switch( status )
+    {
+        case PSA_SUCCESS:
+            return( 0 );
+        case PSA_ERROR_NOT_SUPPORTED:
+            return( MBEDTLS_ERR_MD_FEATURE_UNAVAILABLE );
+        case PSA_ERROR_INVALID_ARGUMENT:
+            return( MBEDTLS_ERR_MD_BAD_INPUT_DATA );
+        case PSA_ERROR_INSUFFICIENT_MEMORY:
+            return( MBEDTLS_ERR_MD_ALLOC_FAILED );
+        default:
+            return( MBEDTLS_ERR_PLATFORM_HW_ACCEL_FAILED );
+    }
+}
+#endif /* !MBEDTLS_MD_C */
+
 /**
  * Generate and apply the MGF1 operation (from PKCS#1 v2.1) to a buffer.
  *
@@ -1098,19 +1129,17 @@
 static int mgf_mask( unsigned char *dst, size_t dlen, unsigned char *src,
                       size_t slen, mbedtls_md_type_t md_alg )
 {
-    const mbedtls_md_info_t *md_info;
-    mbedtls_md_context_t md_ctx;
-    unsigned char mask[MBEDTLS_MD_MAX_SIZE];
     unsigned char counter[4];
     unsigned char *p;
     unsigned int hlen;
     size_t i, use_len;
+    unsigned char mask[HASH_MAX_SIZE];
+#if defined(MBEDTLS_MD_C)
     int ret = 0;
+    const mbedtls_md_info_t *md_info;
+    mbedtls_md_context_t md_ctx;
 
     mbedtls_md_init( &md_ctx );
-    memset( mask, 0, MBEDTLS_MD_MAX_SIZE );
-    memset( counter, 0, 4 );
-
     md_info = mbedtls_md_info_from_type( md_alg );
     if( md_info == NULL )
         return( MBEDTLS_ERR_RSA_BAD_INPUT_DATA );
@@ -1120,6 +1149,17 @@
         goto exit;
 
     hlen = mbedtls_md_get_size( md_info );
+#else
+    psa_hash_operation_t op = PSA_HASH_OPERATION_INIT;
+    psa_algorithm_t alg = mbedtls_psa_translate_md( md_alg );
+    psa_status_t status = PSA_SUCCESS;
+    size_t out_len;
+
+    hlen = PSA_HASH_LENGTH( alg );
+#endif
+
+    memset( mask, 0, sizeof( mask ) );
+    memset( counter, 0, 4 );
 
     /* Generate and apply dbMask */
     p = dst;
@@ -1130,6 +1170,7 @@
         if( dlen < hlen )
             use_len = dlen;
 
+#if defined(MBEDTLS_MD_C)
         if( ( ret = mbedtls_md_starts( &md_ctx ) ) != 0 )
             goto exit;
         if( ( ret = mbedtls_md_update( &md_ctx, src, slen ) ) != 0 )
@@ -1138,6 +1179,17 @@
             goto exit;
         if( ( ret = mbedtls_md_finish( &md_ctx, mask ) ) != 0 )
             goto exit;
+#else
+        if( ( status = psa_hash_setup( &op, alg ) ) != PSA_SUCCESS )
+            goto exit;
+        if( ( status = psa_hash_update( &op, src, slen ) ) != PSA_SUCCESS )
+            goto exit;
+        if( ( status = psa_hash_update( &op, counter, 4 ) ) != PSA_SUCCESS )
+            goto exit;
+        status = psa_hash_finish( &op, mask, sizeof( mask ), &out_len );
+        if( status != PSA_SUCCESS )
+            goto exit;
+#endif
 
         for( i = 0; i < use_len; ++i )
             *p++ ^= mask[i];
@@ -1148,10 +1200,16 @@
     }
 
 exit:
-    mbedtls_md_free( &md_ctx );
     mbedtls_platform_zeroize( mask, sizeof( mask ) );
+#if defined(MBEDTLS_MD_C)
+    mbedtls_md_free( &md_ctx );
 
     return( ret );
+#else
+    psa_hash_abort( &op );
+
+    return( ret_from_status( status ) );
+#endif
 }
 
 /**
@@ -1169,6 +1227,8 @@
                         unsigned char *out, mbedtls_md_type_t md_alg )
 {
     const unsigned char zeros[8] = { 0, 0, 0, 0, 0, 0, 0, 0 };
+
+#if defined(MBEDTLS_MD_C)
     mbedtls_md_context_t md_ctx;
     int ret;
 
@@ -1194,6 +1254,30 @@
     mbedtls_md_free( &md_ctx );
 
     return( ret );
+#else
+    psa_hash_operation_t op = PSA_HASH_OPERATION_INIT;
+    psa_algorithm_t alg = mbedtls_psa_translate_md( md_alg );
+    psa_status_t status = PSA_SUCCESS;
+    size_t out_size = PSA_HASH_LENGTH( alg );
+    size_t out_len;
+
+    if( ( status = psa_hash_setup( &op, alg ) ) != PSA_SUCCESS )
+        goto exit;
+    if( ( status = psa_hash_update( &op, zeros, sizeof( zeros ) ) ) != PSA_SUCCESS )
+        goto exit;
+    if( ( status = psa_hash_update( &op, hash, hlen ) ) != PSA_SUCCESS )
+        goto exit;
+    if( ( status = psa_hash_update( &op, salt, slen ) ) != PSA_SUCCESS )
+        goto exit;
+    status = psa_hash_finish( &op, out, out_size, &out_len );
+    if( status != PSA_SUCCESS )
+        goto exit;
+
+exit:
+    psa_hash_abort( &op );
+
+    return( ret_from_status( status ) );
+#endif /* !MBEDTLS_MD_C */
 }
 
 /**
@@ -1208,6 +1292,7 @@
                          const unsigned char *input, size_t ilen,
                          unsigned char *output )
 {
+#if defined(MBEDTLS_MD_C)
     const mbedtls_md_info_t *md_info;
 
     md_info = mbedtls_md_info_from_type( md_alg );
@@ -1215,6 +1300,16 @@
         return( MBEDTLS_ERR_RSA_BAD_INPUT_DATA );
 
     return( mbedtls_md( md_info, input, ilen, output ) );
+#else
+    psa_algorithm_t alg = mbedtls_psa_translate_md( md_alg );
+    psa_status_t status;
+    size_t out_size = PSA_HASH_LENGTH( alg );
+    size_t out_len;
+
+    status = psa_hash_compute( alg, input, ilen, output, out_size, &out_len );
+
+    return( ret_from_status( status ) );
+#endif /* !MBEDTLS_MD_C */
 }
 #endif /* MBEDTLS_PKCS1_V21 */
 
@@ -1393,7 +1488,7 @@
     size_t ilen, i, pad_len;
     unsigned char *p, bad, pad_done;
     unsigned char buf[MBEDTLS_MPI_MAX_SIZE];
-    unsigned char lhash[MBEDTLS_MD_MAX_SIZE];
+    unsigned char lhash[HASH_MAX_SIZE];
     unsigned int hlen;
 
     RSA_VALIDATE_RET( ctx != NULL );
@@ -1988,7 +2083,7 @@
     size_t siglen;
     unsigned char *p;
     unsigned char *hash_start;
-    unsigned char result[MBEDTLS_MD_MAX_SIZE];
+    unsigned char result[HASH_MAX_SIZE];
     unsigned int hlen;
     size_t observed_salt_len, msb;
     unsigned char buf[MBEDTLS_MPI_MAX_SIZE] = {0};
diff --git a/tests/scripts/all.sh b/tests/scripts/all.sh
index f70dcd9..04a1cd6 100755
--- a/tests/scripts/all.sh
+++ b/tests/scripts/all.sh
@@ -1211,7 +1211,6 @@
     scripts/config.py unset MBEDTLS_ECJPAKE_C
     scripts/config.py unset MBEDTLS_HKDF_C
     scripts/config.py unset MBEDTLS_HMAC_DRBG_C
-    scripts/config.py unset MBEDTLS_PKCS1_V21
     scripts/config.py unset MBEDTLS_PKCS5_C
     scripts/config.py unset MBEDTLS_PKCS12_C
     # Indirect dependencies
@@ -1870,10 +1869,6 @@
     scripts/config.py unset MBEDTLS_ECJPAKE_C
     scripts/config.py unset MBEDTLS_HKDF_C
     scripts/config.py unset MBEDTLS_HMAC_DRBG_C
-    scripts/config.py unset MBEDTLS_PKCS1_V21
-    scripts/config.py unset MBEDTLS_X509_RSASSA_PSS_SUPPORT
-    scripts/config.py -f include/psa/crypto_config.h unset PSA_WANT_ALG_RSA_PSS
-    scripts/config.py -f include/psa/crypto_config.h unset PSA_WANT_ALG_RSA_OAEP
     scripts/config.py unset MBEDTLS_PKCS5_C
     scripts/config.py unset MBEDTLS_PKCS12_C
     scripts/config.py unset MBEDTLS_ECDSA_DETERMINISTIC