test_suite_pk: add test cases for RSA keys (sign/verify & encrypt/decrypt)

Signed-off-by: Gilles Peskine <Gilles.Peskine@arm.com>
Signed-off-by: Valerio Setti <valerio.setti@nordicsemi.no>
diff --git a/tests/suites/test_suite_pk.function b/tests/suites/test_suite_pk.function
index ff843cb..946c52f 100644
--- a/tests/suites/test_suite_pk.function
+++ b/tests/suites/test_suite_pk.function
@@ -800,9 +800,9 @@
 /* END_CASE */
 
 /* BEGIN_CASE depends_on:MBEDTLS_RSA_C */
-void pk_rsa_verify_test_vec(data_t *message_str, int digest, int mod,
-                            char *input_N, char *input_E,
-                            data_t *result_str, int result)
+void pk_rsa_verify_test_vec(data_t *message_str, int padding, int digest,
+                            int mod, char *input_N, char *input_E,
+                            data_t *result_str, int expected_result)
 {
     mbedtls_rsa_context *rsa;
     mbedtls_pk_context pk;
@@ -817,28 +817,54 @@
 #endif
 
     mbedtls_pk_init(&pk);
-    USE_PSA_INIT();
+    MD_OR_USE_PSA_INIT();
 
     TEST_ASSERT(mbedtls_pk_setup(&pk, mbedtls_pk_info_from_type(MBEDTLS_PK_RSA)) == 0);
     rsa = mbedtls_pk_rsa(pk);
 
     rsa->len = (mod + 7) / 8;
+    if (padding >= 0) {
+        TEST_EQUAL(mbedtls_rsa_set_padding(rsa, padding, MBEDTLS_MD_NONE), 0);
+    }
+
     TEST_ASSERT(mbedtls_test_read_mpi(&rsa->N, input_N) == 0);
     TEST_ASSERT(mbedtls_test_read_mpi(&rsa->E, input_E) == 0);
 
-    TEST_ASSERT(mbedtls_pk_verify(&pk, digest, message_str->x, 0,
-                                  result_str->x, mbedtls_pk_get_len(&pk)) == result);
+    int actual_result;
+    actual_result = mbedtls_pk_verify(&pk, digest, message_str->x, 0,
+                                      result_str->x, mbedtls_pk_get_len(&pk));
+#if !defined(MBEDTLS_USE_PSA_CRYPTO)
+    if (actual_result == MBEDTLS_ERR_RSA_INVALID_PADDING &&
+        expected_result == MBEDTLS_ERR_RSA_VERIFY_FAILED) {
+        /* Tolerate INVALID_PADDING error for an invalid signature with
+         * the legacy API (but not with PSA). */
+    } else
+#endif
+    {
+        TEST_EQUAL(actual_result, expected_result);
+    }
 
-    TEST_ASSERT(mbedtls_pk_verify_restartable(&pk, digest, message_str->x, 0,
-                                              result_str->x, mbedtls_pk_get_len(
-                                                  &pk), rs_ctx) == result);
+    actual_result = mbedtls_pk_verify_restartable(&pk, digest, message_str->x, 0,
+                                                  result_str->x,
+                                                  mbedtls_pk_get_len(&pk),
+                                                  rs_ctx);
+#if !defined(MBEDTLS_USE_PSA_CRYPTO)
+    if (actual_result == MBEDTLS_ERR_RSA_INVALID_PADDING &&
+        expected_result == MBEDTLS_ERR_RSA_VERIFY_FAILED) {
+        /* Tolerate INVALID_PADDING error for an invalid signature with
+         * the legacy API (but not with PSA). */
+    } else
+#endif
+    {
+        TEST_EQUAL(actual_result, expected_result);
+    }
 
 exit:
 #if defined(MBEDTLS_ECDSA_C) && defined(MBEDTLS_ECP_RESTARTABLE)
     mbedtls_pk_restart_free(rs_ctx);
 #endif
     mbedtls_pk_free(&pk);
-    USE_PSA_DONE();
+    MD_OR_USE_PSA_DONE();
 }
 /* END_CASE */
 
@@ -1027,7 +1053,8 @@
 /* END_CASE */
 
 /* BEGIN_CASE depends_on:MBEDTLS_MD_CAN_SHA256 */
-void pk_sign_verify(int type, int curve_or_keybits, int sign_ret, int verify_ret)
+void pk_sign_verify(int type, int curve_or_keybits, int rsa_padding, int rsa_md_alg,
+                    int sign_ret, int verify_ret)
 {
     mbedtls_pk_context pk;
     size_t sig_len;
@@ -1055,6 +1082,17 @@
     TEST_ASSERT(mbedtls_pk_setup(&pk, mbedtls_pk_info_from_type(type)) == 0);
     TEST_ASSERT(pk_genkey(&pk, curve_or_keybits) == 0);
 
+#if defined(MBEDTLS_RSA_C)
+    if (type == MBEDTLS_PK_RSA) {
+        /* Apply the padding mode and hash algorithm supplied by the test
+         * data (rsa_padding, rsa_md_alg) to the RSA key before signing. */
+        TEST_ASSERT(mbedtls_rsa_set_padding(mbedtls_pk_rsa(pk), rsa_padding, rsa_md_alg) == 0);
+    }
+#else
+    (void) rsa_padding;
+    (void) rsa_md_alg;
+#endif /* MBEDTLS_RSA_C */
+
     TEST_ASSERT(mbedtls_pk_sign_restartable(&pk, MBEDTLS_MD_SHA256,
                                             hash, hash_len,
                                             sig, sizeof(sig), &sig_len,
@@ -1194,7 +1232,7 @@
 /* END_CASE */
 
 /* BEGIN_CASE depends_on:MBEDTLS_RSA_C */
-void pk_rsa_decrypt_test_vec(data_t *cipher, int mod,
+void pk_rsa_decrypt_test_vec(data_t *cipher, int mod, int padding, int md_alg,
                              char *input_P, char *input_Q,
                              char *input_N, char *input_E,
                              data_t *clear, int ret)
@@ -1209,7 +1247,7 @@
     mbedtls_pk_init(&pk);
     mbedtls_mpi_init(&N); mbedtls_mpi_init(&P);
     mbedtls_mpi_init(&Q); mbedtls_mpi_init(&E);
-    USE_PSA_INIT();
+    MD_OR_USE_PSA_INIT();
 
     memset(&rnd_info,  0, sizeof(mbedtls_test_rnd_pseudo_info));
 
@@ -1231,6 +1269,11 @@
     TEST_EQUAL(mbedtls_pk_get_bitlen(&pk), mod);
     TEST_EQUAL(mbedtls_pk_get_len(&pk), (mod + 7) / 8);
 
+    /* set padding mode */
+    if (padding >= 0) {
+        TEST_EQUAL(mbedtls_rsa_set_padding(rsa, padding, md_alg), 0);
+    }
+
     /* decryption test */
     memset(output, 0, sizeof(output));
     olen = 0;
@@ -1246,7 +1289,7 @@
     mbedtls_mpi_free(&N); mbedtls_mpi_free(&P);
     mbedtls_mpi_free(&Q); mbedtls_mpi_free(&E);
     mbedtls_pk_free(&pk);
-    USE_PSA_DONE();
+    MD_OR_USE_PSA_DONE();
 }
 /* END_CASE */