Make RSA-PSS verification use PSA with MBEDTLS_USE_PSA_CRYPTO

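Duplicate an RSA-PSS verification test case with a different expected
error code. With MBEDTLS_USE_PSA_CRYPTO, errors are translated to PSA
status codes and back, so the error reported differs from the legacy path.
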
Signed-off-by: Andrzej Kurek <andrzej.kurek@arm.com>
diff --git a/library/pk.c b/library/pk.c
index e364520..4eff8e5 100644
--- a/library/pk.c
+++ b/library/pk.c
@@ -347,22 +347,76 @@
     if( ! mbedtls_pk_can_do( ctx, type ) )
         return( MBEDTLS_ERR_PK_TYPE_MISMATCH );
 
-    if( type == MBEDTLS_PK_RSASSA_PSS )
+    if( type != MBEDTLS_PK_RSASSA_PSS )
     {
+        /* General case: no options */
+        if( options != NULL )
+            return( MBEDTLS_ERR_PK_BAD_INPUT_DATA );
+
+        return( mbedtls_pk_verify( ctx, md_alg, hash, hash_len, sig, sig_len ) );
+    }
+
 #if defined(MBEDTLS_RSA_C) && defined(MBEDTLS_PKCS1_V21)
-        int ret = MBEDTLS_ERR_ERROR_CORRUPTION_DETECTED;
-        const mbedtls_pk_rsassa_pss_options *pss_opts;
+    int ret = MBEDTLS_ERR_ERROR_CORRUPTION_DETECTED;
+    const mbedtls_pk_rsassa_pss_options *pss_opts;
 
 #if SIZE_MAX > UINT_MAX
-        if( md_alg == MBEDTLS_MD_NONE && UINT_MAX < hash_len )
-            return( MBEDTLS_ERR_PK_BAD_INPUT_DATA );
+    if( md_alg == MBEDTLS_MD_NONE && UINT_MAX < hash_len )
+        return( MBEDTLS_ERR_PK_BAD_INPUT_DATA );
 #endif /* SIZE_MAX > UINT_MAX */
 
-        if( options == NULL )
-            return( MBEDTLS_ERR_PK_BAD_INPUT_DATA );
+    if( options == NULL )
+        return( MBEDTLS_ERR_PK_BAD_INPUT_DATA );
 
-        pss_opts = (const mbedtls_pk_rsassa_pss_options *) options;
+    pss_opts = (const mbedtls_pk_rsassa_pss_options *) options;
 
+#if defined(MBEDTLS_USE_PSA_CRYPTO)
+    psa_status_t status;
+    if( pss_opts->mgf1_hash_id == md_alg &&
+        ( (size_t) pss_opts->expected_salt_len == hash_len ||
+            pss_opts->expected_salt_len  == MBEDTLS_RSA_SALT_LEN_ANY ) )
+    {
+        /* see RSA_PUB_DER_MAX_BYTES in pkwrite.c */
+        unsigned char buf[ 38 + 2 * MBEDTLS_MPI_MAX_SIZE ];
+        unsigned char *p;
+        int key_len;
+        psa_algorithm_t psa_md_alg = mbedtls_psa_translate_md( md_alg );
+        mbedtls_svc_key_id_t key_id = MBEDTLS_SVC_KEY_ID_INIT;
+        psa_key_attributes_t attributes = PSA_KEY_ATTRIBUTES_INIT;
+        psa_algorithm_t psa_sig_md =
+            ( pss_opts->expected_salt_len == MBEDTLS_RSA_SALT_LEN_ANY ?
+                                 PSA_ALG_RSA_PSS_ANY_SALT(psa_md_alg) :
+                                 PSA_ALG_RSA_PSS(psa_md_alg) );
+        p = buf + sizeof( buf );
+        key_len = mbedtls_pk_write_pubkey( &p, buf, ctx );
+
+        if( key_len < 0 )
+            return( key_len );
+
+        psa_set_key_type( &attributes, PSA_KEY_TYPE_RSA_PUBLIC_KEY );
+        psa_set_key_usage_flags( &attributes, PSA_KEY_USAGE_VERIFY_HASH );
+        psa_set_key_algorithm( &attributes, psa_sig_md );
+
+        status = psa_import_key( &attributes,
+                                 buf + sizeof( buf ) - key_len, key_len,
+                                 &key_id );
+        if( status != PSA_SUCCESS )
+        {
+            psa_destroy_key( key_id );
+            return( mbedtls_psa_err_translate_pk( status ) );
+        }
+
+        status = psa_verify_hash( key_id, psa_sig_md, hash,
+                                  hash_len, sig, sig_len );
+        psa_destroy_key( key_id );
+
+        return( status == PSA_ERROR_INVALID_SIGNATURE?
+                              MBEDTLS_ERR_RSA_VERIFY_FAILED :
+                              mbedtls_psa_err_translate_pk( status ) );
+    }
+    else
+#endif
+    {
         if( sig_len < mbedtls_pk_get_len( ctx ) )
             return( MBEDTLS_ERR_RSA_VERIFY_FAILED );
 
@@ -376,18 +430,11 @@
 
         if( sig_len > mbedtls_pk_get_len( ctx ) )
             return( MBEDTLS_ERR_PK_SIG_LEN_MISMATCH );
-
-        return( 0 );
-#else
-        return( MBEDTLS_ERR_PK_FEATURE_UNAVAILABLE );
-#endif /* MBEDTLS_RSA_C && MBEDTLS_PKCS1_V21 */
     }
-
-    /* General case: no options */
-    if( options != NULL )
-        return( MBEDTLS_ERR_PK_BAD_INPUT_DATA );
-
-    return( mbedtls_pk_verify( ctx, md_alg, hash, hash_len, sig, sig_len ) );
+    return( 0 );
+#else
+    return( MBEDTLS_ERR_PK_FEATURE_UNAVAILABLE );
+#endif /* MBEDTLS_RSA_C && MBEDTLS_PKCS1_V21 */
 }
 
 /*