Make RSA-PSS verification use PSA with MBEDTLS_USE_PSA_CRYPTO

Duplicate a test case but with a different expected error
due to error translation to and from PSA.
Signed-off-by: Andrzej Kurek <andrzej.kurek@arm.com>
diff --git a/library/pk.c b/library/pk.c
index e364520..4eff8e5 100644
--- a/library/pk.c
+++ b/library/pk.c
@@ -347,22 +347,76 @@
     if( ! mbedtls_pk_can_do( ctx, type ) )
         return( MBEDTLS_ERR_PK_TYPE_MISMATCH );
 
-    if( type == MBEDTLS_PK_RSASSA_PSS )
+    if( type != MBEDTLS_PK_RSASSA_PSS )
     {
+        /* General case: no options */
+        if( options != NULL )
+            return( MBEDTLS_ERR_PK_BAD_INPUT_DATA );
+
+        return( mbedtls_pk_verify( ctx, md_alg, hash, hash_len, sig, sig_len ) );
+    }
+
 #if defined(MBEDTLS_RSA_C) && defined(MBEDTLS_PKCS1_V21)
-        int ret = MBEDTLS_ERR_ERROR_CORRUPTION_DETECTED;
-        const mbedtls_pk_rsassa_pss_options *pss_opts;
+    int ret = MBEDTLS_ERR_ERROR_CORRUPTION_DETECTED;
+    const mbedtls_pk_rsassa_pss_options *pss_opts;
 
 #if SIZE_MAX > UINT_MAX
-        if( md_alg == MBEDTLS_MD_NONE && UINT_MAX < hash_len )
-            return( MBEDTLS_ERR_PK_BAD_INPUT_DATA );
+    if( md_alg == MBEDTLS_MD_NONE && UINT_MAX < hash_len )
+        return( MBEDTLS_ERR_PK_BAD_INPUT_DATA );
 #endif /* SIZE_MAX > UINT_MAX */
 
-        if( options == NULL )
-            return( MBEDTLS_ERR_PK_BAD_INPUT_DATA );
+    if( options == NULL )
+        return( MBEDTLS_ERR_PK_BAD_INPUT_DATA );
 
-        pss_opts = (const mbedtls_pk_rsassa_pss_options *) options;
+    pss_opts = (const mbedtls_pk_rsassa_pss_options *) options;
 
+#if defined(MBEDTLS_USE_PSA_CRYPTO)
+    psa_status_t status;
+    if( pss_opts->mgf1_hash_id == md_alg &&
+        ( (size_t) pss_opts->expected_salt_len == hash_len ||
+            pss_opts->expected_salt_len  == MBEDTLS_RSA_SALT_LEN_ANY ) )
+    {
+        /* see RSA_PUB_DER_MAX_BYTES in pkwrite.c */
+        unsigned char buf[ 38 + 2 * MBEDTLS_MPI_MAX_SIZE ];
+        unsigned char *p;
+        int key_len;
+        psa_algorithm_t psa_md_alg = mbedtls_psa_translate_md( md_alg );
+        mbedtls_svc_key_id_t key_id = MBEDTLS_SVC_KEY_ID_INIT;
+        psa_key_attributes_t attributes = PSA_KEY_ATTRIBUTES_INIT;
+        psa_algorithm_t psa_sig_md =
+            ( pss_opts->expected_salt_len == MBEDTLS_RSA_SALT_LEN_ANY ?
+                                 PSA_ALG_RSA_PSS_ANY_SALT(psa_md_alg) :
+                                 PSA_ALG_RSA_PSS(psa_md_alg) );
+        p = buf + sizeof( buf );
+        key_len = mbedtls_pk_write_pubkey( &p, buf, ctx );
+
+        if( key_len < 0 )
+            return( key_len );
+
+        psa_set_key_type( &attributes, PSA_KEY_TYPE_RSA_PUBLIC_KEY );
+        psa_set_key_usage_flags( &attributes, PSA_KEY_USAGE_VERIFY_HASH );
+        psa_set_key_algorithm( &attributes, psa_sig_md );
+
+        status = psa_import_key( &attributes,
+                                 buf + sizeof( buf ) - key_len, key_len,
+                                 &key_id );
+        if( status != PSA_SUCCESS )
+        {
+            psa_destroy_key( key_id );
+            return( mbedtls_psa_err_translate_pk( status ) );
+        }
+
+        status = psa_verify_hash( key_id, psa_sig_md, hash,
+                                  hash_len, sig, sig_len );
+        psa_destroy_key( key_id );
+
+        return( status == PSA_ERROR_INVALID_SIGNATURE?
+                              MBEDTLS_ERR_RSA_VERIFY_FAILED :
+                              mbedtls_psa_err_translate_pk( status ) );
+    }
+    else
+#endif
+    {
         if( sig_len < mbedtls_pk_get_len( ctx ) )
             return( MBEDTLS_ERR_RSA_VERIFY_FAILED );
 
@@ -376,18 +430,11 @@
 
         if( sig_len > mbedtls_pk_get_len( ctx ) )
             return( MBEDTLS_ERR_PK_SIG_LEN_MISMATCH );
-
-        return( 0 );
-#else
-        return( MBEDTLS_ERR_PK_FEATURE_UNAVAILABLE );
-#endif /* MBEDTLS_RSA_C && MBEDTLS_PKCS1_V21 */
     }
-
-    /* General case: no options */
-    if( options != NULL )
-        return( MBEDTLS_ERR_PK_BAD_INPUT_DATA );
-
-    return( mbedtls_pk_verify( ctx, md_alg, hash, hash_len, sig, sig_len ) );
+    return( 0 );
+#else
+    return( MBEDTLS_ERR_PK_FEATURE_UNAVAILABLE );
+#endif /* MBEDTLS_RSA_C && MBEDTLS_PKCS1_V21 */
 }
 
 /*
diff --git a/tests/suites/test_suite_pk.data b/tests/suites/test_suite_pk.data
index 5eb145d..9fff671 100644
--- a/tests/suites/test_suite_pk.data
+++ b/tests/suites/test_suite_pk.data
@@ -180,8 +180,12 @@
 depends_on:MBEDTLS_PKCS1_V21:MBEDTLS_SHA256_C
 pk_rsa_verify_ext_test_vec:"54657374206d657373616765":MBEDTLS_MD_SHA256:1024:16:"00dd118a9f99bab068ca2aea3b6a6d5997ed4ec954e40deecea07da01eaae80ec2bb1340db8a128e891324a5c5f5fad8f590d7c8cacbc5fe931dafda1223735279461abaa0572b761631b3a8afe7389b088b63993a0a25ee45d21858bab9931aedd4589a631b37fcf714089f856549f359326dd1e0e86dde52ed66b4a90bda4095":16:"010001":"0d2bdb0456a3d651d5bd48a4204493898f72cf1aaddd71387cc058bc3f4c235ea6be4010fd61b28e1fbb275462b53775c04be9022d38b6a2e0387dddba86a3f8554d2858044a59fddbd594753fc056fe33c8daddb85dc70d164690b1182209ff84824e0be10e35c379f2f378bf176a9f7cb94d95e44d90276a298c8810f741c9":MBEDTLS_PK_RSASSA_PSS:MBEDTLS_MD_SHA256:94:0
 
+Verify ext RSA #5 using PSA (PKCS1 v2.1, wrong salt_len)
+depends_on:MBEDTLS_PKCS1_V21:MBEDTLS_SHA256_C:MBEDTLS_USE_PSA_CRYPTO
+pk_rsa_verify_ext_test_vec:"54657374206d657373616765":MBEDTLS_MD_SHA256:1024:16:"00dd118a9f99bab068ca2aea3b6a6d5997ed4ec954e40deecea07da01eaae80ec2bb1340db8a128e891324a5c5f5fad8f590d7c8cacbc5fe931dafda1223735279461abaa0572b761631b3a8afe7389b088b63993a0a25ee45d21858bab9931aedd4589a631b37fcf714089f856549f359326dd1e0e86dde52ed66b4a90bda4095":16:"010001":"0d2bdb0456a3d651d5bd48a4204493898f72cf1aaddd71387cc058bc3f4c235ea6be4010fd61b28e1fbb275462b53775c04be9022d38b6a2e0387dddba86a3f8554d2858044a59fddbd594753fc056fe33c8daddb85dc70d164690b1182209ff84824e0be10e35c379f2f378bf176a9f7cb94d95e44d90276a298c8810f741c9":MBEDTLS_PK_RSASSA_PSS:MBEDTLS_MD_SHA256:32:MBEDTLS_ERR_RSA_VERIFY_FAILED
+
 Verify ext RSA #5 (PKCS1 v2.1, wrong salt_len)
-depends_on:MBEDTLS_PKCS1_V21:MBEDTLS_SHA256_C
+depends_on:MBEDTLS_PKCS1_V21:MBEDTLS_SHA256_C:!MBEDTLS_USE_PSA_CRYPTO
 pk_rsa_verify_ext_test_vec:"54657374206d657373616765":MBEDTLS_MD_SHA256:1024:16:"00dd118a9f99bab068ca2aea3b6a6d5997ed4ec954e40deecea07da01eaae80ec2bb1340db8a128e891324a5c5f5fad8f590d7c8cacbc5fe931dafda1223735279461abaa0572b761631b3a8afe7389b088b63993a0a25ee45d21858bab9931aedd4589a631b37fcf714089f856549f359326dd1e0e86dde52ed66b4a90bda4095":16:"010001":"0d2bdb0456a3d651d5bd48a4204493898f72cf1aaddd71387cc058bc3f4c235ea6be4010fd61b28e1fbb275462b53775c04be9022d38b6a2e0387dddba86a3f8554d2858044a59fddbd594753fc056fe33c8daddb85dc70d164690b1182209ff84824e0be10e35c379f2f378bf176a9f7cb94d95e44d90276a298c8810f741c9":MBEDTLS_PK_RSASSA_PSS:MBEDTLS_MD_SHA256:32:MBEDTLS_ERR_RSA_INVALID_PADDING
 
 Verify ext RSA #6 (PKCS1 v2.1, MGF1 alg != MSG hash alg)
diff --git a/tests/suites/test_suite_pk.function b/tests/suites/test_suite_pk.function
index 56cc45b..edd5e66 100644
--- a/tests/suites/test_suite_pk.function
+++ b/tests/suites/test_suite_pk.function
@@ -438,6 +438,9 @@
     void *options;
     size_t hash_len;
 
+#if defined(MBEDTLS_USE_PSA_CRYPTO)
+    psa_crypto_init();
+#endif
     mbedtls_pk_init( &pk );
 
     memset( hash_result, 0x00, sizeof( hash_result ) );
@@ -481,6 +484,10 @@
 
 exit:
     mbedtls_pk_free( &pk );
+#if defined(MBEDTLS_USE_PSA_CRYPTO)
+    mbedtls_test_psa_purge_key_storage();
+    mbedtls_psa_crypto_free();
+#endif
 }
 /* END_CASE */