pk_wrap: set proper PSA alg in RSA wrappers based on padding mode set in RSA context

Signed-off-by: Valerio Setti <valerio.setti@nordicsemi.no>
diff --git a/library/pk_wrap.c b/library/pk_wrap.c
index 69e1baf..b472cfb 100644
--- a/library/pk_wrap.c
+++ b/library/pk_wrap.c
@@ -74,8 +74,7 @@
     int key_len;
     unsigned char buf[MBEDTLS_PK_RSA_PUB_DER_MAX_BYTES];
     unsigned char *p = buf + sizeof(buf);
-    psa_algorithm_t psa_alg_md =
-        PSA_ALG_RSA_PKCS1V15_SIGN(mbedtls_md_psa_alg_from_type(md_alg));
+    psa_algorithm_t psa_alg_md;
     size_t rsa_len = mbedtls_rsa_get_len(rsa);
 
 #if SIZE_MAX > UINT_MAX
@@ -84,6 +83,12 @@
     }
 #endif
 
+    if (mbedtls_rsa_get_padding_mode(rsa) == MBEDTLS_RSA_PKCS_V21) {
+        psa_alg_md = PSA_ALG_RSA_PSS(mbedtls_md_psa_alg_from_type(md_alg));
+    } else {
+        psa_alg_md = PSA_ALG_RSA_PKCS1V15_SIGN(mbedtls_md_psa_alg_from_type(md_alg));
+    }
+
     if (sig_len < rsa_len) {
         return MBEDTLS_ERR_RSA_VERIFY_FAILED;
     }
@@ -235,10 +240,14 @@
     if (psa_md_alg == 0) {
         return MBEDTLS_ERR_PK_BAD_INPUT_DATA;
     }
+    psa_algorithm_t psa_alg;
+    if (mbedtls_rsa_get_padding_mode(mbedtls_pk_rsa(*pk)) == MBEDTLS_RSA_PKCS_V21) {
+        psa_alg = PSA_ALG_RSA_PSS(psa_md_alg);
+    } else {
+        psa_alg = PSA_ALG_RSA_PKCS1V15_SIGN(psa_md_alg);
+    }
 
-    return mbedtls_pk_psa_rsa_sign_ext(PSA_ALG_RSA_PKCS1V15_SIGN(
-                                           psa_md_alg),
-                                       pk->pk_ctx, hash, hash_len,
+    return mbedtls_pk_psa_rsa_sign_ext(psa_alg, pk->pk_ctx, hash, hash_len,
                                        sig, sig_size, sig_len);
 }
 #else /* MBEDTLS_USE_PSA_CRYPTO */
@@ -276,6 +285,7 @@
     int ret = MBEDTLS_ERR_ERROR_CORRUPTION_DETECTED;
     psa_key_attributes_t attributes = PSA_KEY_ATTRIBUTES_INIT;
     mbedtls_svc_key_id_t key_id = MBEDTLS_SVC_KEY_ID_INIT;
+    psa_algorithm_t psa_md_alg, decrypt_alg;
     psa_status_t status;
     int key_len;
     unsigned char buf[MBEDTLS_PK_RSA_PRV_DER_MAX_BYTES];
@@ -284,12 +294,6 @@
     ((void) f_rng);
     ((void) p_rng);
 
-#if !defined(MBEDTLS_RSA_ALT)
-    if (rsa->padding != MBEDTLS_RSA_PKCS_V15) {
-        return MBEDTLS_ERR_RSA_INVALID_PADDING;
-    }
-#endif /* !MBEDTLS_RSA_ALT */
-
     if (ilen != mbedtls_rsa_get_len(rsa)) {
         return MBEDTLS_ERR_RSA_BAD_INPUT_DATA;
     }
@@ -301,7 +305,13 @@
 
     psa_set_key_type(&attributes, PSA_KEY_TYPE_RSA_KEY_PAIR);
     psa_set_key_usage_flags(&attributes, PSA_KEY_USAGE_DECRYPT);
-    psa_set_key_algorithm(&attributes, PSA_ALG_RSA_PKCS1V15_CRYPT);
+    if (mbedtls_rsa_get_padding_mode(rsa) == MBEDTLS_RSA_PKCS_V21) {
+        psa_md_alg = mbedtls_md_psa_alg_from_type(mbedtls_rsa_get_md_alg(rsa));
+        decrypt_alg = PSA_ALG_RSA_OAEP(psa_md_alg);
+    } else {
+        decrypt_alg = PSA_ALG_RSA_PKCS1V15_CRYPT;
+    }
+    psa_set_key_algorithm(&attributes, decrypt_alg);
 
     status = psa_import_key(&attributes,
                             buf + sizeof(buf) - key_len, key_len,
@@ -311,7 +321,7 @@
         goto cleanup;
     }
 
-    status = psa_asymmetric_decrypt(key_id, PSA_ALG_RSA_PKCS1V15_CRYPT,
+    status = psa_asymmetric_decrypt(key_id, decrypt_alg,
                                     input, ilen,
                                     NULL, 0,
                                     output, osize, olen);
@@ -358,6 +368,7 @@
     int ret = MBEDTLS_ERR_ERROR_CORRUPTION_DETECTED;
     psa_key_attributes_t attributes = PSA_KEY_ATTRIBUTES_INIT;
     mbedtls_svc_key_id_t key_id = MBEDTLS_SVC_KEY_ID_INIT;
+    psa_algorithm_t psa_md_alg;
     psa_status_t status;
     int key_len;
     unsigned char buf[MBEDTLS_PK_RSA_PUB_DER_MAX_BYTES];
@@ -366,12 +377,6 @@
     ((void) f_rng);
     ((void) p_rng);
 
-#if !defined(MBEDTLS_RSA_ALT)
-    if (rsa->padding != MBEDTLS_RSA_PKCS_V15) {
-        return MBEDTLS_ERR_RSA_INVALID_PADDING;
-    }
-#endif
-
     if (mbedtls_rsa_get_len(rsa) > osize) {
         return MBEDTLS_ERR_RSA_OUTPUT_TOO_LARGE;
     }
@@ -382,7 +387,12 @@
     }
 
     psa_set_key_usage_flags(&attributes, PSA_KEY_USAGE_ENCRYPT);
-    psa_set_key_algorithm(&attributes, PSA_ALG_RSA_PKCS1V15_CRYPT);
+    if (mbedtls_rsa_get_padding_mode(rsa) == MBEDTLS_RSA_PKCS_V21) {
+        psa_md_alg = mbedtls_md_psa_alg_from_type(mbedtls_rsa_get_md_alg(rsa));
+        psa_set_key_algorithm(&attributes, PSA_ALG_RSA_OAEP(psa_md_alg));
+    } else {
+        psa_set_key_algorithm(&attributes, PSA_ALG_RSA_PKCS1V15_CRYPT);
+    }
     psa_set_key_type(&attributes, PSA_KEY_TYPE_RSA_PUBLIC_KEY);
 
     status = psa_import_key(&attributes,