rsa: introduce rsa_internal_rsassa_pss_sign_no_mode_check()

And use it in the non-PSA version of mbedtls_pk_sign_ext()
to bypass checks that didn't succeed when used by TLS 1.3.

That is because in the failing scenarios the padding of
the RSA context is not set to PKCS_V21.

See the discussion on PR #7930 for more details.

Signed-off-by: Tomi Fontanilles <129057597+tomi-font@users.noreply.github.com>
diff --git a/library/pk.c b/library/pk.c
index 344d29f..929af3c 100644
--- a/library/pk.c
+++ b/library/pk.c
@@ -18,6 +18,9 @@
 
 #if defined(MBEDTLS_RSA_C)
 #include "mbedtls/rsa.h"
+#if defined(MBEDTLS_PKCS1_V21) && !defined(MBEDTLS_USE_PSA_CRYPTO)
+#include "rsa_internal.h"
+#endif
 #endif
 #if defined(MBEDTLS_PK_HAVE_ECC_KEYS)
 #include "mbedtls/ecp.h"
@@ -728,8 +731,8 @@
 
     mbedtls_rsa_context *const rsa_ctx = mbedtls_pk_rsa(*ctx);
 
-    const int ret = mbedtls_rsa_rsassa_pss_sign(rsa_ctx, f_rng, p_rng, md_alg,
-                                                (unsigned int) hash_len, hash, sig);
+    const int ret = mbedtls_rsa_rsassa_pss_sign_no_mode_check(rsa_ctx, f_rng, p_rng, md_alg,
+                                                              (unsigned int) hash_len, hash, sig);
     if (ret == 0) {
         *sig_len = rsa_ctx->len;
     }
diff --git a/library/rsa.c b/library/rsa.c
index 1bf5d13..2b9f85b 100644
--- a/library/rsa.c
+++ b/library/rsa.c
@@ -29,6 +29,7 @@
 
 #include "mbedtls/rsa.h"
 #include "rsa_alt_helpers.h"
+#include "rsa_internal.h"
 #include "mbedtls/oid.h"
 #include "mbedtls/platform_util.h"
 #include "mbedtls/error.h"
@@ -1712,14 +1713,14 @@
 }
 
 #if defined(MBEDTLS_PKCS1_V21)
-static int rsa_rsassa_pss_sign(mbedtls_rsa_context *ctx,
-                               int (*f_rng)(void *, unsigned char *, size_t),
-                               void *p_rng,
-                               mbedtls_md_type_t md_alg,
-                               unsigned int hashlen,
-                               const unsigned char *hash,
-                               int saltlen,
-                               unsigned char *sig)
+static int rsa_rsassa_pss_sign_no_mode_check(mbedtls_rsa_context *ctx,
+                                             int (*f_rng)(void *, unsigned char *, size_t),
+                                             void *p_rng,
+                                             mbedtls_md_type_t md_alg,
+                                             unsigned int hashlen,
+                                             const unsigned char *hash,
+                                             int saltlen,
+                                             unsigned char *sig)
 {
     size_t olen;
     unsigned char *p = sig;
@@ -1727,15 +1728,12 @@
     size_t slen, min_slen, hlen, offset = 0;
     int ret = MBEDTLS_ERR_ERROR_CORRUPTION_DETECTED;
     size_t msb;
+    mbedtls_md_type_t hash_id;
 
     if ((md_alg != MBEDTLS_MD_NONE || hashlen != 0) && hash == NULL) {
         return MBEDTLS_ERR_RSA_BAD_INPUT_DATA;
     }
 
-    if (ctx->padding != MBEDTLS_RSA_PKCS_V21) {
-        return MBEDTLS_ERR_RSA_BAD_INPUT_DATA;
-    }
-
     if (f_rng == NULL) {
         return MBEDTLS_ERR_RSA_BAD_INPUT_DATA;
     }
@@ -1754,7 +1752,11 @@
         }
     }
 
-    hlen = mbedtls_md_get_size_from_type((mbedtls_md_type_t) ctx->hash_id);
+    hash_id = (mbedtls_md_type_t) ctx->hash_id;
+    if (hash_id == MBEDTLS_MD_NONE) {
+        hash_id = md_alg;
+    }
+    hlen = mbedtls_md_get_size_from_type(hash_id);
     if (hlen == 0) {
         return MBEDTLS_ERR_RSA_BAD_INPUT_DATA;
     }
@@ -1797,7 +1799,7 @@
     p += slen;
 
     /* Generate H = Hash( M' ) */
-    ret = hash_mprime(hash, hashlen, salt, slen, p, (mbedtls_md_type_t) ctx->hash_id);
+    ret = hash_mprime(hash, hashlen, salt, slen, p, hash_id);
     if (ret != 0) {
         return ret;
     }
@@ -1808,8 +1810,7 @@
     }
 
     /* maskedDB: Apply dbMask to DB */
-    ret = mgf_mask(sig + offset, olen - hlen - 1 - offset, p, hlen,
-                   (mbedtls_md_type_t) ctx->hash_id);
+    ret = mgf_mask(sig + offset, olen - hlen - 1 - offset, p, hlen, hash_id);
     if (ret != 0) {
         return ret;
     }
@@ -1823,6 +1824,37 @@
     return mbedtls_rsa_private(ctx, f_rng, p_rng, sig, sig);
 }
 
+static int rsa_rsassa_pss_sign(mbedtls_rsa_context *ctx,
+                               int (*f_rng)(void *, unsigned char *, size_t),
+                               void *p_rng,
+                               mbedtls_md_type_t md_alg,
+                               unsigned int hashlen,
+                               const unsigned char *hash,
+                               int saltlen,
+                               unsigned char *sig)
+{
+    if (ctx->padding != MBEDTLS_RSA_PKCS_V21) {
+        return MBEDTLS_ERR_RSA_BAD_INPUT_DATA;
+    }
+    if (ctx->hash_id == MBEDTLS_MD_NONE) {
+        return MBEDTLS_ERR_RSA_BAD_INPUT_DATA;
+    }
+    return rsa_rsassa_pss_sign_no_mode_check(ctx, f_rng, p_rng, md_alg, hashlen, hash, saltlen,
+                                             sig);
+}
+
+int mbedtls_rsa_rsassa_pss_sign_no_mode_check(mbedtls_rsa_context *ctx,
+                                              int (*f_rng)(void *, unsigned char *, size_t),
+                                              void *p_rng,
+                                              mbedtls_md_type_t md_alg,
+                                              unsigned int hashlen,
+                                              const unsigned char *hash,
+                                              unsigned char *sig)
+{
+    return rsa_rsassa_pss_sign_no_mode_check(ctx, f_rng, p_rng, md_alg,
+                                             hashlen, hash, MBEDTLS_RSA_SALT_LEN_ANY, sig);
+}
+
 /*
  * Implementation of the PKCS#1 v2.1 RSASSA-PSS-SIGN function with
  * the option to pass in the salt length.
@@ -1840,7 +1872,6 @@
                                hashlen, hash, saltlen, sig);
 }
 
-
 /*
  * Implementation of the PKCS#1 v2.1 RSASSA-PSS-SIGN function
  */
diff --git a/library/rsa_alt_helpers.h b/library/rsa_alt_helpers.h
index ca0840b..052b024 100644
--- a/library/rsa_alt_helpers.h
+++ b/library/rsa_alt_helpers.h
@@ -37,11 +37,9 @@
 /*
  *  Copyright The Mbed TLS Contributors
  *  SPDX-License-Identifier: Apache-2.0 OR GPL-2.0-or-later
- *
  */
-
-#ifndef MBEDTLS_RSA_INTERNAL_H
-#define MBEDTLS_RSA_INTERNAL_H
+#ifndef MBEDTLS_RSA_ALT_HELPERS_H
+#define MBEDTLS_RSA_ALT_HELPERS_H
 
 #include "mbedtls/build_info.h"
 
diff --git a/library/rsa_internal.h b/library/rsa_internal.h
new file mode 100644
index 0000000..4081ac6
--- /dev/null
+++ b/library/rsa_internal.h
@@ -0,0 +1,42 @@
+/**
+ * \file rsa_internal.h
+ *
+ * \brief Internal-only RSA public-key cryptosystem API.
+ *
+ * This file declares RSA-related functions that are to be used
+ * only from within the Mbed TLS library itself.
+ *
+ */
+/*
+ *  Copyright The Mbed TLS Contributors
+ *  SPDX-License-Identifier: Apache-2.0 OR GPL-2.0-or-later
+ */
+#ifndef MBEDTLS_RSA_INTERNAL_H
+#define MBEDTLS_RSA_INTERNAL_H
+
+#include "mbedtls/rsa.h"
+
+#if defined(MBEDTLS_PKCS1_V21)
+/**
+ * \brief This function is analogue to \c mbedtls_rsa_rsassa_pss_sign().
+ *        The only difference between them is that this function is more flexible
+ *        on the parameters of \p ctx that are set with \c mbedtls_rsa_set_padding().
+ *
+ * \note  Compared to its counterpart, this function:
+ *        - does not check the padding setting of \p ctx.
+ *        - allows the hash_id of \p ctx to be MBEDTLS_MD_NONE,
+ *          in which case it uses \p md_alg as the hash_id.
+ *
+ * \note  Refer to \c mbedtls_rsa_rsassa_pss_sign() for a description
+ *        of the functioning and parameters of this function.
+ */
+int mbedtls_rsa_rsassa_pss_sign_no_mode_check(mbedtls_rsa_context *ctx,
+                                              int (*f_rng)(void *, unsigned char *, size_t),
+                                              void *p_rng,
+                                              mbedtls_md_type_t md_alg,
+                                              unsigned int hashlen,
+                                              const unsigned char *hash,
+                                              unsigned char *sig);
+#endif /* MBEDTLS_PKCS1_V21 */
+
+#endif /* rsa_internal.h */