Fix PSA_ALG_RSA_PSS verification accepting an arbitrary salt length

Verification with PSA_ALG_RSA_PSS now accepts only the salt length that
signing with PSA_ALG_RSA_PSS produces, as documented.
PSA_ALG_RSA_PSS_ANY_SALT still accepts any salt length.

Fixes #4946.

Signed-off-by: Gilles Peskine <Gilles.Peskine@arm.com>
diff --git a/library/psa_crypto_rsa.c b/library/psa_crypto_rsa.c
index 80b9985..de17e5d 100644
--- a/library/psa_crypto_rsa.c
+++ b/library/psa_crypto_rsa.c
@@ -447,6 +447,27 @@
     return( status );
 }
 
+#if defined(BUILTIN_ALG_RSA_PSS)
+/* Return the salt length that PSS verification must require for alg:
+ * MBEDTLS_RSA_SALT_LEN_ANY for the "any salt" variant, otherwise the
+ * length computed below (largest salt that fits, capped at hash_length). */
+static int rsa_pss_expected_salt_len( psa_algorithm_t alg,
+                                      const mbedtls_rsa_context *rsa,
+                                      size_t hash_length )
+{
+    if( PSA_ALG_IS_RSA_PSS_ANY_SALT( alg ) )
+        return( MBEDTLS_RSA_SALT_LEN_ANY );
+    /* Otherwise: standard salt length, i.e. largest possible salt length
+     * up to the hash length, given key length = hash + salt + 2 bytes. */
+    int klen = (int) mbedtls_rsa_get_len( rsa ); // known to fit
+    int hlen = (int) hash_length; // known to fit
+    int room = klen - 2 - hlen;
+    if( room < 0 )
+        return( 0 ); // there is no valid signature in this case anyway
+    return( room > hlen ? hlen : room );
+}
+#endif
+
 static psa_status_t rsa_verify_hash(
     const psa_key_attributes_t *attributes,
     const uint8_t *key_buffer, size_t key_buffer_size,
@@ -494,15 +515,18 @@
 #if defined(BUILTIN_ALG_RSA_PSS)
     if( PSA_ALG_IS_RSA_PSS( alg ) )
     {
+        int slen = rsa_pss_expected_salt_len( alg, rsa, hash_length );
         mbedtls_rsa_set_padding( rsa, MBEDTLS_RSA_PKCS_V21, md_alg );
-        ret = mbedtls_rsa_rsassa_pss_verify( rsa,
-                                             mbedtls_psa_get_random,
-                                             MBEDTLS_PSA_RANDOM_STATE,
-                                             MBEDTLS_RSA_PUBLIC,
-                                             md_alg,
-                                             (unsigned int) hash_length,
-                                             hash,
-                                             signature );
+        ret = mbedtls_rsa_rsassa_pss_verify_ext( rsa,
+                                                 mbedtls_psa_get_random,
+                                                 MBEDTLS_PSA_RANDOM_STATE,
+                                                 MBEDTLS_RSA_PUBLIC,
+                                                 md_alg,
+                                                 (unsigned int) hash_length,
+                                                 hash,
+                                                 md_alg,
+                                                 slen,
+                                                 signature );
     }
     else
 #endif /* BUILTIN_ALG_RSA_PSS */