test_suite_pk: add some initial testing for mbedtls_pk_copy_from_psa()

Signed-off-by: Valerio Setti <valerio.setti@nordicsemi.no>
diff --git a/tests/suites/test_suite_pk.function b/tests/suites/test_suite_pk.function
index 3d75ad0..3710e3d 100644
--- a/tests/suites/test_suite_pk.function
+++ b/tests/suites/test_suite_pk.function
@@ -1,5 +1,6 @@
 /* BEGIN_HEADER */
 #include "mbedtls/pk.h"
+#include "mbedtls/psa_util.h"
 #include "pk_internal.h"
 
 /* For error codes */
@@ -425,7 +426,86 @@
 }
 #endif
 
-#if defined(MBEDTLS_USE_PSA_CRYPTO)
+#if defined(MBEDTLS_PSA_CRYPTO_CLIENT)
+/*
+ * Create a PSA public key from an existing PSA key pair.
+ *
+ * The public part of priv_id is exported and imported again into a new key
+ * slot with the given type, usage flags, algorithm and size in bits.
+ *
+ * Returns the identifier of the new public key. PSA failures are recorded as
+ * test failures through PSA_ASSERT; if the export fails,
+ * MBEDTLS_SVC_KEY_ID_INIT is returned. The caller must destroy the key.
+ */
+mbedtls_svc_key_id_t pk_psa_pub_key_from_priv(mbedtls_svc_key_id_t priv_id,
+                                              psa_key_type_t type, psa_key_usage_t usage,
+                                              psa_algorithm_t alg, size_t bits)
+{
+    psa_key_attributes_t attributes = PSA_KEY_ATTRIBUTES_INIT;
+    unsigned char pub_key_buf[PSA_EXPORT_PUBLIC_KEY_MAX_SIZE];
+    size_t pub_key_len;
+    mbedtls_svc_key_id_t pub_key = MBEDTLS_SVC_KEY_ID_INIT;
+
+    PSA_ASSERT(psa_export_public_key(priv_id, pub_key_buf, sizeof(pub_key_buf), &pub_key_len));
+
+    psa_set_key_usage_flags(&attributes, usage);
+    psa_set_key_algorithm(&attributes, alg);
+    psa_set_key_type(&attributes, type);
+    psa_set_key_bits(&attributes, bits);
+
+    PSA_ASSERT(psa_import_key(&attributes, pub_key_buf, pub_key_len, &pub_key));
+
+exit:
+    return pub_key;
+}
+
+/*
+ * Import key material into a new PSA key.
+ *
+ * key_data/key_len must be in the format expected by psa_import_key();
+ * type, usage, alg and bits define the attributes of the new key. *key is
+ * reset to MBEDTLS_SVC_KEY_ID_INIT first and then receives the new key
+ * identifier. Returns the status of psa_import_key().
+ */
+psa_status_t pk_psa_import_key(const unsigned char *key_data, size_t key_len,
+                               psa_key_type_t type, psa_key_usage_t usage,
+                               psa_algorithm_t alg, size_t bits,
+                               mbedtls_svc_key_id_t *key)
+{
+    psa_key_attributes_t attributes = PSA_KEY_ATTRIBUTES_INIT;
+
+    *key = MBEDTLS_SVC_KEY_ID_INIT;
+
+    psa_set_key_usage_flags(&attributes, usage);
+    psa_set_key_algorithm(&attributes, alg);
+    psa_set_key_type(&attributes, type);
+    psa_set_key_bits(&attributes, bits);
+
+    return psa_import_key(&attributes, key_data, key_len, key);
+}
+
+/*
+ * Generate a new PSA key with the given type, usage flags, algorithm and size.
+ *
+ * *key is reset to MBEDTLS_SVC_KEY_ID_INIT first and then receives the new
+ * key identifier. Returns the status of psa_generate_key(); callers are
+ * expected to check it, e.g. with PSA_ASSERT().
+ */
+psa_status_t pk_psa_genkey_generic(psa_key_type_t type, psa_key_usage_t usage,
+                                   psa_algorithm_t alg, size_t bits,
+                                   mbedtls_svc_key_id_t *key)
+{
+    psa_key_attributes_t attributes = PSA_KEY_ATTRIBUTES_INIT;
+
+    *key = MBEDTLS_SVC_KEY_ID_INIT;
+
+    psa_set_key_usage_flags(&attributes, usage);
+    psa_set_key_algorithm(&attributes, alg);
+    psa_set_key_type(&attributes, type);
+    psa_set_key_bits(&attributes, bits);
+
+    return psa_generate_key(&attributes, key);
+}
 
 /*
  * Generate an ECC key using PSA and return the key identifier of that key,
@@ -434,19 +493,16 @@
  */
 mbedtls_svc_key_id_t pk_psa_genkey_ecc(void)
 {
     mbedtls_svc_key_id_t key = MBEDTLS_SVC_KEY_ID_INIT;
-    psa_key_attributes_t attributes = PSA_KEY_ATTRIBUTES_INIT;
-    const psa_key_type_t type =
-        PSA_KEY_TYPE_ECC_KEY_PAIR(PSA_ECC_FAMILY_SECP_R1);
-    const size_t bits = 256;
 
-    psa_set_key_usage_flags(&attributes, PSA_KEY_USAGE_SIGN_HASH);
-    psa_set_key_algorithm(&attributes, PSA_ALG_ECDSA(PSA_ALG_SHA_256));
-    psa_set_key_type(&attributes, type);
-    psa_set_key_bits(&attributes, bits);
-    PSA_ASSERT(psa_generate_key(&attributes, &key));
+    /* Assert on the helper's status so that a generation failure fails the
+     * calling test instead of being silently ignored. */
+    PSA_ASSERT(pk_psa_genkey_generic(PSA_KEY_TYPE_ECC_KEY_PAIR(PSA_ECC_FAMILY_SECP_R1),
+                                     PSA_KEY_USAGE_SIGN_HASH,
+                                     PSA_ALG_ECDSA(PSA_ALG_SHA_256),
+                                     256, &key));
 
 exit:
     return key;
 }
 
@@ -456,21 +508,17 @@
  */
 mbedtls_svc_key_id_t pk_psa_genkey_rsa(void)
 {
     mbedtls_svc_key_id_t key = MBEDTLS_SVC_KEY_ID_INIT;
-    psa_key_attributes_t attributes = PSA_KEY_ATTRIBUTES_INIT;
-    const psa_key_type_t type = PSA_KEY_TYPE_RSA_KEY_PAIR;
-    const size_t bits = 1024;
 
-    psa_set_key_usage_flags(&attributes, PSA_KEY_USAGE_SIGN_HASH);
-    psa_set_key_algorithm(&attributes, PSA_ALG_RSA_PKCS1V15_SIGN_RAW);
-    psa_set_key_type(&attributes, type);
-    psa_set_key_bits(&attributes, bits);
-    PSA_ASSERT(psa_generate_key(&attributes, &key));
+    /* Assert on the helper's status so that a generation failure fails the
+     * calling test instead of being silently ignored. */
+    PSA_ASSERT(pk_psa_genkey_generic(PSA_KEY_TYPE_RSA_KEY_PAIR, PSA_KEY_USAGE_SIGN_HASH,
+                                     PSA_ALG_RSA_PKCS1V15_SIGN_RAW, 1024, &key));
 
 exit:
     return key;
 }
-#endif /* MBEDTLS_USE_PSA_CRYPTO */
+#endif /* MBEDTLS_PSA_CRYPTO_CLIENT */
 /* END_HEADER */
 
 /* BEGIN_DEPENDENCIES
@@ -2199,3 +2244,218 @@
     PSA_DONE();
 }
 /* END_CASE */
+
+/* BEGIN_CASE depends_on:MBEDTLS_PSA_CRYPTO_CLIENT*/
+void pk_copy_from_psa_fail(void)
+{
+    mbedtls_pk_context pk_ctx;
+    mbedtls_svc_key_id_t key_id = MBEDTLS_SVC_KEY_ID_INIT;
+#if defined(MBEDTLS_USE_PSA_CRYPTO) && defined(MBEDTLS_RSA_C) && \
+    defined(MBEDTLS_MD_CAN_SHA256) && defined(MBEDTLS_PKCS1_V21)
+    /* Zero-filled so that the encryption attempt below never reads
+     * uninitialized memory. Only SHA256 is used here. */
+    unsigned char in_buf[32] = { 0 };
+    unsigned char out_buf[256]; /* Only 2048 RSA bit size is used here. */
+    size_t out_buf_len;
+#endif /* MBEDTLS_USE_PSA_CRYPTO && MBEDTLS_RSA_C &&
+          MBEDTLS_MD_CAN_SHA256 && MBEDTLS_PKCS1_V21 */
+
+    mbedtls_pk_init(&pk_ctx);
+    PSA_INIT();
+
+    /* Null pk pointer. */
+    TEST_EQUAL(mbedtls_pk_copy_from_psa(key_id, NULL),
+               MBEDTLS_ERR_PK_BAD_INPUT_DATA);
+
+    /* Invalid key ID. */
+    TEST_EQUAL(mbedtls_pk_copy_from_psa(mbedtls_svc_key_id_make(0, 0), &pk_ctx),
+               MBEDTLS_ERR_PK_BAD_INPUT_DATA);
+
+    /* Each key below is destroyed right after being checked and key_id is
+     * reset, so the cleanup at "exit" only ever destroys a live key. */
+
+#if defined(PSA_WANT_KEY_TYPE_DH_KEY_PAIR_GENERATE)
+    /* Generate a key type that is not handled by the PK module. */
+    PSA_ASSERT(pk_psa_genkey_generic(PSA_KEY_TYPE_DH_KEY_PAIR(PSA_DH_FAMILY_RFC7919),
+                                     PSA_KEY_USAGE_EXPORT, PSA_ALG_NONE, 2048, &key_id));
+    TEST_EQUAL(mbedtls_pk_copy_from_psa(key_id, &pk_ctx), MBEDTLS_ERR_PK_BAD_INPUT_DATA);
+    psa_destroy_key(key_id);
+    key_id = MBEDTLS_SVC_KEY_ID_INIT;
+#endif /* PSA_WANT_KEY_TYPE_DH_KEY_PAIR_GENERATE */
+
+#if defined(MBEDTLS_PK_HAVE_ECC_KEYS) && defined(PSA_WANT_ECC_SECP_R1_256) && \
+    defined(PSA_WANT_KEY_TYPE_ECC_KEY_PAIR_GENERATE)
+    /* Generate an EC key which cannot be exported. */
+    PSA_ASSERT(pk_psa_genkey_generic(PSA_KEY_TYPE_ECC_KEY_PAIR(PSA_ECC_FAMILY_SECP_R1),
+                                     0, PSA_ALG_NONE, 256, &key_id));
+    TEST_EQUAL(mbedtls_pk_copy_from_psa(key_id, &pk_ctx), MBEDTLS_ERR_PK_BAD_INPUT_DATA);
+    psa_destroy_key(key_id);
+    key_id = MBEDTLS_SVC_KEY_ID_INIT;
+
+    /* Use valid exportable EC key with wrong algorithm. */
+    PSA_ASSERT(pk_psa_genkey_generic(PSA_KEY_TYPE_ECC_KEY_PAIR(PSA_ECC_FAMILY_SECP_R1),
+                                     PSA_KEY_USAGE_EXPORT | PSA_KEY_USAGE_SIGN_HASH,
+                                     PSA_ALG_ECDH, 256, &key_id));
+    TEST_EQUAL(mbedtls_pk_copy_from_psa(key_id, &pk_ctx), MBEDTLS_ERR_PK_BAD_INPUT_DATA);
+    psa_destroy_key(key_id);
+    key_id = MBEDTLS_SVC_KEY_ID_INIT;
+#endif /* MBEDTLS_PK_HAVE_ECC_KEYS && PSA_WANT_ECC_SECP_R1_256 &&
+          PSA_WANT_KEY_TYPE_ECC_KEY_PAIR_GENERATE */
+
+#if defined(MBEDTLS_RSA_C) && defined(PSA_WANT_KEY_TYPE_RSA_KEY_PAIR_GENERATE)
+    /* Use valid exportable RSA key with wrong algorithm. */
+    PSA_ASSERT(pk_psa_genkey_generic(PSA_KEY_TYPE_RSA_KEY_PAIR,
+                                     PSA_KEY_USAGE_EXPORT | PSA_KEY_USAGE_SIGN_HASH,
+                                     PSA_ALG_CMAC, 2048, &key_id));
+    TEST_EQUAL(mbedtls_pk_copy_from_psa(key_id, &pk_ctx), MBEDTLS_ERR_PK_BAD_INPUT_DATA);
+    psa_destroy_key(key_id);
+    key_id = MBEDTLS_SVC_KEY_ID_INIT;
+
+#if defined(MBEDTLS_USE_PSA_CRYPTO) && defined(MBEDTLS_MD_CAN_SHA256) && defined(MBEDTLS_PKCS1_V21)
+    /* An RSA key whose policy is OAEP (PKCS#1 v2.1) can be copied, but
+     * encrypting with the resulting PK context is expected to be rejected
+     * with MBEDTLS_ERR_RSA_INVALID_PADDING. */
+    PSA_ASSERT(pk_psa_genkey_generic(PSA_KEY_TYPE_RSA_KEY_PAIR,
+                                     PSA_KEY_USAGE_EXPORT | PSA_KEY_USAGE_ENCRYPT,
+                                     PSA_ALG_RSA_OAEP(PSA_ALG_SHA_256), 2048, &key_id));
+    TEST_EQUAL(mbedtls_pk_copy_from_psa(key_id, &pk_ctx), 0);
+    TEST_EQUAL(mbedtls_pk_encrypt(&pk_ctx, in_buf, sizeof(in_buf),
+                                  out_buf, &out_buf_len, sizeof(out_buf),
+                                  mbedtls_test_rnd_std_rand, NULL),
+               MBEDTLS_ERR_RSA_INVALID_PADDING);
+    psa_destroy_key(key_id);
+    key_id = MBEDTLS_SVC_KEY_ID_INIT;
+#endif /* MBEDTLS_USE_PSA_CRYPTO && MBEDTLS_MD_CAN_SHA256 && MBEDTLS_PKCS1_V21*/
+#endif /* MBEDTLS_RSA_C && PSA_WANT_KEY_TYPE_RSA_KEY_PAIR_GENERATE */
+
+exit:
+    mbedtls_pk_free(&pk_ctx);
+    psa_destroy_key(key_id);
+    PSA_DONE();
+}
+/* END_CASE */
+
+/* BEGIN_CASE depends_on:MBEDTLS_PSA_CRYPTO_CLIENT*/
+void pk_copy_from_psa_success(data_t *priv_key_data, int key_type_arg,
+                              int key_bits_arg, int key_usage_arg,
+                              int key_alg_arg)
+{
+    psa_key_type_t key_type = key_type_arg;
+    psa_key_type_t pub_key_type = PSA_KEY_TYPE_NONE;
+    size_t key_bits = key_bits_arg;
+    psa_key_usage_t key_usage = key_usage_arg;
+    psa_algorithm_t key_alg = key_alg_arg;
+    mbedtls_pk_context pk_ctx, pk_ctx2;
+    mbedtls_svc_key_id_t priv_key_id = MBEDTLS_SVC_KEY_ID_INIT;
+    mbedtls_svc_key_id_t pub_key_id = MBEDTLS_SVC_KEY_ID_INIT;
+    mbedtls_md_type_t md_for_test = MBEDTLS_MD_SHA256; /* Default */
+    unsigned char *in_buf = NULL;
+    size_t in_buf_len;
+    unsigned char out_buf[MBEDTLS_PK_SIGNATURE_MAX_SIZE];
+    unsigned char out_buf2[MBEDTLS_PK_SIGNATURE_MAX_SIZE];
+    size_t out_buf_len, out_buf2_len;
+    int test_encryption;
+
+    /* Initialize everything released at "exit" before the first statement
+     * that can jump there, so the cleanup never frees uninitialized data. */
+    mbedtls_pk_init(&pk_ctx);
+    mbedtls_pk_init(&pk_ctx2);
+    PSA_INIT();
+
+    /* Get the MD type to be used for the tests below from the provided key
+     * policy, keeping SHA-256 when the policy names no specific hash. */
+    if ((PSA_ALG_GET_HASH(key_alg) != 0) &&
+        (PSA_ALG_GET_HASH(key_alg) != PSA_ALG_ANY_HASH)) {
+        md_for_test = mbedtls_md_type_from_psa_alg(key_alg);
+    }
+
+    /* A zero size means the test data uses a hash that is not supported:
+     * fail early rather than allocating and filling an empty buffer. */
+    in_buf_len = mbedtls_md_get_size_from_type(md_for_test);
+    TEST_ASSERT(in_buf_len != 0);
+    TEST_CALLOC(in_buf, in_buf_len);
+    memset(in_buf, 0x1, in_buf_len);
+
+    /* Import the private key in PSA and create a PK context from it. */
+    PSA_ASSERT(pk_psa_import_key(priv_key_data->x, priv_key_data->len,
+                                 key_type, key_usage, key_alg, key_bits, &priv_key_id));
+    TEST_EQUAL(mbedtls_pk_copy_from_psa(priv_key_id, &pk_ctx), 0);
+
+    /* Starting from the private key above, create another PSA slot for the public
+     * one and create a new PK context from it. */
+    if (PSA_KEY_TYPE_IS_ECC_KEY_PAIR(key_type)) {
+        pub_key_type = PSA_KEY_TYPE_ECC_PUBLIC_KEY(PSA_KEY_TYPE_ECC_GET_FAMILY(key_type));
+    } else if (key_type == PSA_KEY_TYPE_RSA_KEY_PAIR) {
+        pub_key_type = PSA_KEY_TYPE_RSA_PUBLIC_KEY;
+    } else {
+        TEST_FAIL("Key type can only be EC or RSA key pair");
+    }
+
+    /* Generate a 2nd PK context using only the public key derived from its private
+     * counterpart imported above. */
+    pub_key_id = pk_psa_pub_key_from_priv(priv_key_id, pub_key_type, key_usage, key_alg, key_bits);
+    TEST_EQUAL(mbedtls_pk_copy_from_psa(pub_key_id, &pk_ctx2), 0);
+
+    /* Test sign/verify with the following pattern:
+     * - Sign using the PK context generated from the private key.
+     * - Verify from the same PK context used for signature.
+     * - Verify with the PK context generated using public key.
+     */
+    if (PSA_ALG_IS_RSA_OAEP(key_alg) || PSA_ALG_IS_RSA_PSS(key_alg)) {
+        mbedtls_pk_rsassa_pss_options pss_opt = {
+            .mgf1_hash_id = md_for_test,
+            .expected_salt_len = MBEDTLS_RSA_SALT_LEN_ANY,
+        };
+
+        TEST_EQUAL(mbedtls_pk_sign_ext(MBEDTLS_PK_RSASSA_PSS, &pk_ctx, md_for_test,
+                                       in_buf, in_buf_len,
+                                       out_buf, sizeof(out_buf), &out_buf_len,
+                                       mbedtls_test_rnd_std_rand, NULL), 0);
+        TEST_EQUAL(mbedtls_pk_verify_ext(MBEDTLS_PK_RSASSA_PSS, &pss_opt,
+                                         &pk_ctx, md_for_test, in_buf, in_buf_len,
+                                         out_buf, out_buf_len), 0);
+        TEST_EQUAL(mbedtls_pk_verify_ext(MBEDTLS_PK_RSASSA_PSS, &pss_opt,
+                                         &pk_ctx2, md_for_test, in_buf, in_buf_len,
+                                         out_buf, out_buf_len), 0);
+    } else {
+        TEST_EQUAL(mbedtls_pk_sign(&pk_ctx, md_for_test, in_buf, in_buf_len,
+                                   out_buf, sizeof(out_buf), &out_buf_len,
+                                   mbedtls_test_rnd_std_rand, NULL), 0);
+        TEST_EQUAL(mbedtls_pk_verify(&pk_ctx, md_for_test, in_buf, in_buf_len,
+                                     out_buf, out_buf_len), 0);
+        TEST_EQUAL(mbedtls_pk_verify(&pk_ctx2, md_for_test, in_buf, in_buf_len,
+                                     out_buf, out_buf_len), 0);
+    }
+
+    /* Decide whether encryption/decryption is exercised for this algorithm. */
+#if defined(MBEDTLS_USE_PSA_CRYPTO)
+    test_encryption = (PSA_ALG_IS_RSA_PKCS1V15_SIGN(key_alg) ||
+                       (key_alg == PSA_ALG_RSA_PKCS1V15_CRYPT));
+#else
+    test_encryption = ((PSA_ALG_GET_HASH(key_alg) != 0) &&
+                       (PSA_ALG_GET_HASH(key_alg) != PSA_ALG_ANY_HASH));
+#endif
+
+    /* In case of RSA key pair try also encryption/decryption. */
+    if ((key_type == PSA_KEY_TYPE_RSA_KEY_PAIR) && test_encryption) {
+        /* Encrypt with the 2nd PK context (public key only). */
+        TEST_EQUAL(mbedtls_pk_encrypt(&pk_ctx2, in_buf, in_buf_len,
+                                      out_buf, &out_buf_len, sizeof(out_buf),
+                                      mbedtls_test_rnd_std_rand, NULL), 0);
+
+        /* Decrypt with 1st PK context and compare with original data. */
+        TEST_EQUAL(mbedtls_pk_decrypt(&pk_ctx, out_buf, out_buf_len,
+                                      out_buf2, &out_buf2_len, sizeof(out_buf2),
+                                      mbedtls_test_rnd_std_rand, NULL), 0);
+        TEST_MEMORY_COMPARE(in_buf, in_buf_len, out_buf2, out_buf2_len);
+    }
+
+exit:
+    mbedtls_free(in_buf);
+    mbedtls_pk_free(&pk_ctx);
+    mbedtls_pk_free(&pk_ctx2);
+    psa_destroy_key(priv_key_id);
+    psa_destroy_key(pub_key_id);
+    PSA_DONE();
+}
+/* END_CASE */