Merge pull request #7546 from mpg/align-psa-md-identifiers

Align psa md identifiers
diff --git a/include/mbedtls/md.h b/include/mbedtls/md.h
index f717618..5831e12 100644
--- a/include/mbedtls/md.h
+++ b/include/mbedtls/md.h
@@ -146,19 +146,22 @@
  *            stronger message digests instead.
  *
  */
+/* Note: each value equals the PSA_ALG_HASH_MASK bits of the corresponding
+ * PSA_ALG_xxx hash identifier, so that MD <-> PSA conversion is a bitwise op.
+ * This is tested by md_to_from_psa() in test_suite_md. */
 typedef enum {
     MBEDTLS_MD_NONE=0,    /**< None. */
-    MBEDTLS_MD_MD5,       /**< The MD5 message digest. */
-    MBEDTLS_MD_SHA1,      /**< The SHA-1 message digest. */
-    MBEDTLS_MD_SHA224,    /**< The SHA-224 message digest. */
-    MBEDTLS_MD_SHA256,    /**< The SHA-256 message digest. */
-    MBEDTLS_MD_SHA384,    /**< The SHA-384 message digest. */
-    MBEDTLS_MD_SHA512,    /**< The SHA-512 message digest. */
-    MBEDTLS_MD_RIPEMD160, /**< The RIPEMD-160 message digest. */
-    MBEDTLS_MD_SHA3_224,    /**< The SHA3-224 message digest. */
-    MBEDTLS_MD_SHA3_256,    /**< The SHA3-256 message digest. */
-    MBEDTLS_MD_SHA3_384,    /**< The SHA3-384 message digest. */
-    MBEDTLS_MD_SHA3_512,    /**< The SHA3-512 message digest. */
+    MBEDTLS_MD_MD5=0x03,       /**< The MD5 message digest (PSA_ALG_MD5). */
+    MBEDTLS_MD_RIPEMD160=0x04, /**< The RIPEMD-160 message digest (PSA_ALG_RIPEMD160). */
+    MBEDTLS_MD_SHA1=0x05,      /**< The SHA-1 message digest (PSA_ALG_SHA_1). */
+    MBEDTLS_MD_SHA224=0x08,    /**< The SHA-224 message digest (PSA_ALG_SHA_224). */
+    MBEDTLS_MD_SHA256=0x09,    /**< The SHA-256 message digest (PSA_ALG_SHA_256). */
+    MBEDTLS_MD_SHA384=0x0a,    /**< The SHA-384 message digest (PSA_ALG_SHA_384). */
+    MBEDTLS_MD_SHA512=0x0b,    /**< The SHA-512 message digest (PSA_ALG_SHA_512). */
+    MBEDTLS_MD_SHA3_224=0x10,  /**< The SHA3-224 message digest (PSA_ALG_SHA3_224). */
+    MBEDTLS_MD_SHA3_256=0x11,  /**< The SHA3-256 message digest (PSA_ALG_SHA3_256). */
+    MBEDTLS_MD_SHA3_384=0x12,  /**< The SHA3-384 message digest (PSA_ALG_SHA3_384). */
+    MBEDTLS_MD_SHA3_512=0x13,  /**< The SHA3-512 message digest (PSA_ALG_SHA3_512). */
 } mbedtls_md_type_t;
 
 /* Note: this should always be >= PSA_HASH_MAX_SIZE
diff --git a/library/md.c b/library/md.c
index 964d4bd..3589d63 100644
--- a/library/md.c
+++ b/library/md.c
@@ -786,78 +786,6 @@
 }
 
 #if defined(MBEDTLS_PSA_CRYPTO_C)
-psa_algorithm_t mbedtls_md_psa_alg_from_type(mbedtls_md_type_t md_type)
-{
-    switch (md_type) {
-#if defined(MBEDTLS_MD_CAN_MD5)
-        case MBEDTLS_MD_MD5:
-            return PSA_ALG_MD5;
-#endif
-#if defined(MBEDTLS_MD_CAN_RIPEMD160)
-        case MBEDTLS_MD_RIPEMD160:
-            return PSA_ALG_RIPEMD160;
-#endif
-#if defined(MBEDTLS_MD_CAN_SHA1)
-        case MBEDTLS_MD_SHA1:
-            return PSA_ALG_SHA_1;
-#endif
-#if defined(MBEDTLS_MD_CAN_SHA224)
-        case MBEDTLS_MD_SHA224:
-            return PSA_ALG_SHA_224;
-#endif
-#if defined(MBEDTLS_MD_CAN_SHA256)
-        case MBEDTLS_MD_SHA256:
-            return PSA_ALG_SHA_256;
-#endif
-#if defined(MBEDTLS_MD_CAN_SHA384)
-        case MBEDTLS_MD_SHA384:
-            return PSA_ALG_SHA_384;
-#endif
-#if defined(MBEDTLS_MD_CAN_SHA512)
-        case MBEDTLS_MD_SHA512:
-            return PSA_ALG_SHA_512;
-#endif
-        default:
-            return PSA_ALG_NONE;
-    }
-}
-
-mbedtls_md_type_t mbedtls_md_type_from_psa_alg(psa_algorithm_t psa_alg)
-{
-    switch (psa_alg) {
-#if defined(MBEDTLS_MD_CAN_MD5)
-        case PSA_ALG_MD5:
-            return MBEDTLS_MD_MD5;
-#endif
-#if defined(MBEDTLS_MD_CAN_RIPEMD160)
-        case PSA_ALG_RIPEMD160:
-            return MBEDTLS_MD_RIPEMD160;
-#endif
-#if defined(MBEDTLS_MD_CAN_SHA1)
-        case PSA_ALG_SHA_1:
-            return MBEDTLS_MD_SHA1;
-#endif
-#if defined(MBEDTLS_MD_CAN_SHA224)
-        case PSA_ALG_SHA_224:
-            return MBEDTLS_MD_SHA224;
-#endif
-#if defined(MBEDTLS_MD_CAN_SHA256)
-        case PSA_ALG_SHA_256:
-            return MBEDTLS_MD_SHA256;
-#endif
-#if defined(MBEDTLS_MD_CAN_SHA384)
-        case PSA_ALG_SHA_384:
-            return MBEDTLS_MD_SHA384;
-#endif
-#if defined(MBEDTLS_MD_CAN_SHA512)
-        case PSA_ALG_SHA_512:
-            return MBEDTLS_MD_SHA512;
-#endif
-        default:
-            return MBEDTLS_MD_NONE;
-    }
-}
-
 int mbedtls_md_error_from_psa(psa_status_t status)
 {
     return PSA_TO_MBEDTLS_ERR_LIST(status, psa_to_md_errors,
diff --git a/library/md_psa.h b/library/md_psa.h
index 6645c83..8e00bb1 100644
--- a/library/md_psa.h
+++ b/library/md_psa.h
@@ -31,12 +31,21 @@
  * \brief           This function returns the PSA algorithm identifier
  *                  associated with the given digest type.
  *
- * \param md_type   The type of digest to search for.
+ * \param md_type   The type of digest to convert. Must not be \c MBEDTLS_MD_NONE.
  *
- * \return          The PSA algorithm identifier associated with \p md_type.
- * \return          PSA_ALG_NONE if the algorithm is not supported.
+ * \warning         If \p md_type is \c MBEDTLS_MD_NONE, this function will
+ *                  not return \c PSA_ALG_NONE, but an invalid algorithm.
+ *
+ * \warning         This function does not check if the algorithm is
+ *                  supported: it just ORs \p md_type with PSA_ALG_CATEGORY_HASH.
+ *
+ * \return          The PSA algorithm identifier associated with \p md_type,
+ *                  regardless of whether it is supported or not.
  */
-psa_algorithm_t mbedtls_md_psa_alg_from_type(mbedtls_md_type_t md_type);
+static inline psa_algorithm_t mbedtls_md_psa_alg_from_type(mbedtls_md_type_t md_type)
+{
+    return PSA_ALG_CATEGORY_HASH | (psa_algorithm_t) md_type;
+}
 
 /**
  * \brief           This function returns the given digest type
@@ -44,10 +53,16 @@
  *
  * \param psa_alg   The PSA algorithm identifier to search for.
  *
- * \return          The MD type associated with \p psa_alg.
- * \return          MBEDTLS_MD_NONE if the algorithm is not supported.
+ * \warning         This function does not check that \p psa_alg is a hash
+ *                  algorithm, nor that it is supported: it just keeps the
+ *                  bits of \p psa_alg selected by PSA_ALG_HASH_MASK.
+ *
+ * \return          The MD type associated with \p psa_alg, supported or not.
  */
-mbedtls_md_type_t mbedtls_md_type_from_psa_alg(psa_algorithm_t psa_alg);
+static inline mbedtls_md_type_t mbedtls_md_type_from_psa_alg(psa_algorithm_t psa_alg)
+{
+    return (mbedtls_md_type_t) (psa_alg & PSA_ALG_HASH_MASK);
+}
 
 /** Convert PSA status to MD error code.
  *
diff --git a/library/psa_crypto_rsa.c b/library/psa_crypto_rsa.c
index 4e11b36..30d4c04 100644
--- a/library/psa_crypto_rsa.c
+++ b/library/psa_crypto_rsa.c
@@ -529,6 +529,12 @@
     psa_algorithm_t hash_alg = PSA_ALG_RSA_OAEP_GET_HASH(alg);
     mbedtls_md_type_t md_alg = mbedtls_md_type_from_psa_alg(hash_alg);
 
+    /* Just to get the error status right, as rsa_set_padding() doesn't
+     * distinguish between "bad RSA algorithm" and "unknown hash". */
+    if (mbedtls_md_info_from_type(md_alg) == NULL) {
+        return PSA_ERROR_NOT_SUPPORTED;
+    }
+
     return mbedtls_rsa_set_padding(rsa, MBEDTLS_RSA_PKCS_V21, md_alg);
 }
 #endif /* defined(MBEDTLS_PSA_BUILTIN_ALG_RSA_OAEP) */
diff --git a/library/ssl_tls.c b/library/ssl_tls.c
index f75bc20..6e15493 100644
--- a/library/ssl_tls.c
+++ b/library/ssl_tls.c
@@ -8807,11 +8807,17 @@
             MBEDTLS_SSL_TLS12_SIG_ALG_FROM_SIG_AND_HASH_ALG(
                 received_sig_algs[i]);
 
+        mbedtls_md_type_t md_alg =
+            mbedtls_ssl_md_alg_from_hash((unsigned char) hash_alg_received);
+        if (md_alg == MBEDTLS_MD_NONE) {
+            continue;
+        }
+
         if (sig_alg == sig_alg_received) {
 #if defined(MBEDTLS_USE_PSA_CRYPTO)
             if (ssl->handshake->key_cert && ssl->handshake->key_cert->key) {
                 psa_algorithm_t psa_hash_alg =
-                    mbedtls_md_psa_alg_from_type(hash_alg_received);
+                    mbedtls_md_psa_alg_from_type(md_alg);
 
                 if (sig_alg_received == MBEDTLS_SSL_SIG_ECDSA &&
                     !mbedtls_pk_can_do_ext(ssl->handshake->key_cert->key,
diff --git a/tests/scripts/all.sh b/tests/scripts/all.sh
index a747c9d..26cf4c6 100755
--- a/tests/scripts/all.sh
+++ b/tests/scripts/all.sh
@@ -4271,6 +4271,10 @@
 
     msg "size: ${ARM_NONE_EABI_GCC_PREFIX}gcc -mthumb -mcpu=cortex-m0plus -Os, baremetal_size"
     ${ARM_NONE_EABI_GCC_PREFIX}size -t library/*.o
+    for lib in library/*.a; do
+        echo "$lib:"
+        ${ARM_NONE_EABI_GCC_PREFIX}size -t $lib | grep TOTALS
+    done
 }
 
 component_build_arm_none_eabi_gcc_no_udbl_division () {
diff --git a/tests/suites/test_suite_md.data b/tests/suites/test_suite_md.data
index 0b0afee..9b39e9f 100644
--- a/tests/suites/test_suite_md.data
+++ b/tests/suites/test_suite_md.data
@@ -2,6 +2,9 @@
 MD list
 mbedtls_md_list:
 
+MD <-> PSA conversion
+md_to_from_psa:
+
 MD NULL/uninitialised arguments
 md_null_args:
 
diff --git a/tests/suites/test_suite_md.function b/tests/suites/test_suite_md.function
index ac9516a..e3f0e15 100644
--- a/tests/suites/test_suite_md.function
+++ b/tests/suites/test_suite_md.function
@@ -1,5 +1,10 @@
 /* BEGIN_HEADER */
 #include "mbedtls/md.h"
+#include "md_psa.h"
+/* Check that md and psa convert to each other in both directions. */
+#define MD_PSA(md, psa) \
+    do { TEST_EQUAL(mbedtls_md_psa_alg_from_type(md), psa); \
+         TEST_EQUAL(mbedtls_md_type_from_psa_alg(psa), md); } while (0)
 /* END_HEADER */
 
 /* BEGIN_DEPENDENCIES
@@ -36,6 +41,27 @@
 }
 /* END_CASE */
 
+/* BEGIN_CASE depends_on:MBEDTLS_PSA_CRYPTO_C */
+void md_to_from_psa()
+{
+    /* mbedtls_md_psa_alg_from_type() and mbedtls_md_type_from_psa_alg() are
+     * plain bitwise operations relying on aligned values; check they stay so. */
+    MD_PSA(MBEDTLS_MD_MD5, PSA_ALG_MD5);
+    MD_PSA(MBEDTLS_MD_RIPEMD160, PSA_ALG_RIPEMD160);
+    MD_PSA(MBEDTLS_MD_SHA1, PSA_ALG_SHA_1);
+    MD_PSA(MBEDTLS_MD_SHA224, PSA_ALG_SHA_224);
+    MD_PSA(MBEDTLS_MD_SHA256, PSA_ALG_SHA_256);
+    MD_PSA(MBEDTLS_MD_SHA384, PSA_ALG_SHA_384);
+    MD_PSA(MBEDTLS_MD_SHA512, PSA_ALG_SHA_512);
+    MD_PSA(MBEDTLS_MD_SHA3_224, PSA_ALG_SHA3_224);
+    MD_PSA(MBEDTLS_MD_SHA3_256, PSA_ALG_SHA3_256);
+    MD_PSA(MBEDTLS_MD_SHA3_384, PSA_ALG_SHA3_384);
+    MD_PSA(MBEDTLS_MD_SHA3_512, PSA_ALG_SHA3_512);
+
+    /* Don't test NONE<->NONE: MBEDTLS_MD_NONE does not map to PSA_ALG_NONE. */
+}
+/* END_CASE */
+
 /* BEGIN_CASE */
 void md_null_args()
 {