New function mbedtls_pk_copy_public_from_psa

Document and implement mbedtls_pk_copy_public_from_psa() to export the
public key of a PSA key into PK.

Unit-test it alongside mbedtls_pk_copy_from_psa().

Signed-off-by: Gilles Peskine <Gilles.Peskine@arm.com>
diff --git a/ChangeLog.d/8709.txt b/ChangeLog.d/8709.txt
index a9b2751..e0bea44 100644
--- a/ChangeLog.d/8709.txt
+++ b/ChangeLog.d/8709.txt
@@ -1,3 +1,4 @@
 Features
-   * The new function mbedtls_pk_copy_from_psa() provides a way to set up a PK
-     context with the same content as a PSA key.
+   * The new functions mbedtls_pk_copy_from_psa() and
+     mbedtls_pk_copy_public_from_psa() provide ways to set up a PK context
+     with the same content as a PSA key.
diff --git a/include/mbedtls/pk.h b/include/mbedtls/pk.h
index c1fb605..d2e8674 100644
--- a/include/mbedtls/pk.h
+++ b/include/mbedtls/pk.h
@@ -426,6 +426,39 @@
  *                  parameters are not correct.
  */
 int mbedtls_pk_copy_from_psa(mbedtls_svc_key_id_t key_id, mbedtls_pk_context *pk);
+
+/**
+ * \brief           Create a PK context for the public key of a PSA key.
+ *
+ *                  The key must be an RSA or ECC key. It can be either a
+ *                  public key or a key pair, and only the public key is copied.
+ *                  The resulting PK object will be a transparent type:
+ *                  - #MBEDTLS_PK_RSA for RSA keys or
+ *                  - #MBEDTLS_PK_ECKEY for EC keys.
+ *
+ *                  Once this functions returns the PK object will be completely
+ *                  independent from the original PSA key that it was generated
+ *                  from.
+ *                  Calling mbedtls_pk_verify() or
+ *                  mbedtls_pk_encrypt() on the resulting
+ *                  PK context will perform the corresponding algorithm for that
+ *                  PK context type.
+ *
+ *                  For an RSA key, the output PK context will allow both
+ *                  encrypt and verify regardless of the original key's policy.
+ *                  The original key's policy determines the output key's padding
+ *                  mode: PCKS1 v2.1 is set if the PSA key policy is OAEP or PSS,
+ *                  otherwise PKCS1 v1.5 is set.
+ *
+ * \param key_id    The key identifier of the key stored in PSA.
+ * \param pk        The PK context that will be filled. It must be initialized,
+ *                  but not set up.
+ *
+ * \return          0 on success.
+ * \return          MBEDTLS_ERR_PK_BAD_INPUT_DATA in case the provided input
+ *                  parameters are not correct.
+ */
+int mbedtls_pk_copy_public_from_psa(mbedtls_svc_key_id_t key_id, mbedtls_pk_context *pk);
 #endif /* MBEDTLS_PSA_CRYPTO_C */
 
 #if defined(MBEDTLS_PK_RSA_ALT_SUPPORT)
diff --git a/library/pk.c b/library/pk.c
index 4345ea2..7bc1da8 100644
--- a/library/pk.c
+++ b/library/pk.c
@@ -1379,7 +1379,9 @@
 }
 
 #if defined(MBEDTLS_PSA_CRYPTO_C)
-int mbedtls_pk_copy_from_psa(mbedtls_svc_key_id_t key_id, mbedtls_pk_context *pk)
+static int copy_from_psa(mbedtls_svc_key_id_t key_id,
+                         mbedtls_pk_context *pk,
+                         int public_only)
 {
     psa_status_t status;
     psa_key_attributes_t key_attr = PSA_KEY_ATTRIBUTES_INIT;
@@ -1400,13 +1402,20 @@
         return MBEDTLS_ERR_PK_BAD_INPUT_DATA;
     }
 
-    status = psa_export_key(key_id, exp_key, sizeof(exp_key), &exp_key_len);
+    if (public_only) {
+        status = psa_export_public_key(key_id, exp_key, sizeof(exp_key), &exp_key_len);
+    } else {
+        status = psa_export_key(key_id, exp_key, sizeof(exp_key), &exp_key_len);
+    }
     if (status != PSA_SUCCESS) {
         ret = PSA_PK_TO_MBEDTLS_ERR(status);
         goto exit;
     }
 
     key_type = psa_get_key_type(&key_attr);
+    if (public_only) {
+        key_type = PSA_KEY_TYPE_PUBLIC_KEY_OF_KEY_PAIR(key_type);
+    }
     key_bits = psa_get_key_bits(&key_attr);
     alg_type = psa_get_key_algorithm(&key_attr);
 
@@ -1485,6 +1494,19 @@
 
     return ret;
 }
+
+
+int mbedtls_pk_copy_from_psa(mbedtls_svc_key_id_t key_id,
+                             mbedtls_pk_context *pk)
+{
+    return copy_from_psa(key_id, pk, 0);
+}
+
+int mbedtls_pk_copy_public_from_psa(mbedtls_svc_key_id_t key_id,
+                                    mbedtls_pk_context *pk)
+{
+    return copy_from_psa(key_id, pk, 1);
+}
 #endif /* MBEDTLS_PSA_CRYPTO_C */
 
 #endif /* MBEDTLS_PK_C */
diff --git a/tests/suites/test_suite_pk.function b/tests/suites/test_suite_pk.function
index 75dc4ac..d955ab6 100644
--- a/tests/suites/test_suite_pk.function
+++ b/tests/suites/test_suite_pk.function
@@ -322,6 +322,83 @@
     expected_usage |= PSA_KEY_USAGE_EXPORT | PSA_KEY_USAGE_COPY;
     return expected_usage;
 }
+
+#define RSA_WRITE_PUBKEY_MAX_SIZE                                       \
+    PSA_KEY_EXPORT_RSA_PUBLIC_KEY_MAX_SIZE(PSA_VENDOR_RSA_MAX_KEY_BITS)
+#define ECP_WRITE_PUBKEY_MAX_SIZE                                       \
+    PSA_KEY_EXPORT_ECC_PUBLIC_KEY_MAX_SIZE(PSA_VENDOR_ECC_MAX_CURVE_BITS)
+static int pk_public_same(const mbedtls_pk_context *pk1,
+                          const mbedtls_pk_context *pk2)
+{
+    int ok = 0;
+
+    mbedtls_pk_type_t type = mbedtls_pk_get_type(pk1);
+    TEST_EQUAL(type, mbedtls_pk_get_type(pk2));
+
+    switch (type) {
+#if defined(MBEDTLS_RSA_C)
+        case MBEDTLS_PK_RSA:
+        {
+            const mbedtls_rsa_context *rsa1 = mbedtls_pk_rsa(*pk1);
+            const mbedtls_rsa_context *rsa2 = mbedtls_pk_rsa(*pk2);
+            TEST_EQUAL(mbedtls_rsa_get_padding_mode(rsa1),
+                       mbedtls_rsa_get_padding_mode(rsa2));
+            TEST_EQUAL(mbedtls_rsa_get_md_alg(rsa1),
+                       mbedtls_rsa_get_md_alg(rsa2));
+            unsigned char buf1[RSA_WRITE_PUBKEY_MAX_SIZE];
+            unsigned char *p1 = buf1 + sizeof(buf1);
+            int len1 = mbedtls_rsa_write_pubkey(rsa1, buf1, &p1);
+            TEST_LE_U(0, len1);
+            unsigned char buf2[RSA_WRITE_PUBKEY_MAX_SIZE];
+            unsigned char *p2 = buf2 + sizeof(buf2);
+            int len2 = mbedtls_rsa_write_pubkey(rsa2, buf2, &p2);
+            TEST_LE_U(0, len2);
+            TEST_MEMORY_COMPARE(p1, len1, p2, len2);
+            break;
+        }
+#endif /* MBEDTLS_RSA_C */
+
+#if defined(MBEDTLS_PK_HAVE_ECC_KEYS)
+        case MBEDTLS_PK_ECKEY:
+        case MBEDTLS_PK_ECKEY_DH:
+        case MBEDTLS_PK_ECDSA:
+        {
+#if defined(MBEDTLS_PK_USE_PSA_EC_DATA)
+            TEST_MEMORY_COMPARE(pk1->pub_raw, pk1->pub_raw_len,
+                                pk2->pub_raw, pk2->pub_raw_len);
+            TEST_EQUAL(pk1->ec_family, pk2->ec_family);
+            TEST_EQUAL(pk1->ec_bits, pk2->ec_bits);
+
+#else /* MBEDTLS_PK_USE_PSA_EC_DATA */
+            const mbedtls_ecp_keypair *ec1 = mbedtls_pk_ec_ro(*pk1);
+            const mbedtls_ecp_keypair *ec2 = mbedtls_pk_ec_ro(*pk2);
+            TEST_EQUAL(mbedtls_ecp_keypair_get_group_id(ec1),
+                       mbedtls_ecp_keypair_get_group_id(ec2));
+            unsigned char buf1[ECP_WRITE_PUBKEY_MAX_SIZE];
+            size_t len1 = 99999991;
+            TEST_EQUAL(mbedtls_ecp_write_public_key(
+                           ec1, MBEDTLS_ECP_PF_UNCOMPRESSED,
+                           &len1, buf1, sizeof(buf1)), 0);
+            unsigned char buf2[ECP_WRITE_PUBKEY_MAX_SIZE];
+            size_t len2 = 99999992;
+            TEST_EQUAL(mbedtls_ecp_write_public_key(
+                           ec2, MBEDTLS_ECP_PF_UNCOMPRESSED,
+                           &len2, buf2, sizeof(buf2)), 0);
+            TEST_MEMORY_COMPARE(buf1, len1, buf2, len2);
+#endif /* MBEDTLS_PK_USE_PSA_EC_DATA */
+        }
+        break;
+#endif /* MBEDTLS_PK_HAVE_ECC_KEYS */
+
+        default:
+            TEST_FAIL("Unsupported pk type in pk_public_same");
+    }
+
+    ok = 1;
+
+exit:
+    return ok;
+}
 #endif /* MBEDTLS_PSA_CRYPTO_C */
 
 #if defined(MBEDTLS_RSA_C)
@@ -2322,16 +2399,21 @@
     /* Null pk pointer. */
     TEST_EQUAL(mbedtls_pk_copy_from_psa(key_id, NULL),
                MBEDTLS_ERR_PK_BAD_INPUT_DATA);
+    TEST_EQUAL(mbedtls_pk_copy_public_from_psa(key_id, NULL),
+               MBEDTLS_ERR_PK_BAD_INPUT_DATA);
 
     /* Invalid key ID. */
     TEST_EQUAL(mbedtls_pk_copy_from_psa(mbedtls_svc_key_id_make(0, 0), &pk_ctx),
                MBEDTLS_ERR_PK_BAD_INPUT_DATA);
+    TEST_EQUAL(mbedtls_pk_copy_public_from_psa(mbedtls_svc_key_id_make(0, 0), &pk_ctx),
+               MBEDTLS_ERR_PK_BAD_INPUT_DATA);
 
 #if defined(PSA_WANT_KEY_TYPE_DH_KEY_PAIR_GENERATE)
     /* Generate a key type that is not handled by the PK module. */
     PSA_ASSERT(pk_psa_genkey_generic(PSA_KEY_TYPE_DH_KEY_PAIR(PSA_DH_FAMILY_RFC7919), 2048,
                                      PSA_KEY_USAGE_EXPORT, PSA_ALG_NONE, &key_id));
     TEST_EQUAL(mbedtls_pk_copy_from_psa(key_id, &pk_ctx), MBEDTLS_ERR_PK_BAD_INPUT_DATA);
+    TEST_EQUAL(mbedtls_pk_copy_public_from_psa(key_id, &pk_ctx), MBEDTLS_ERR_PK_BAD_INPUT_DATA);
     psa_destroy_key(key_id);
 #endif /* PSA_WANT_KEY_TYPE_DH_KEY_PAIR_GENERATE */
 
@@ -2382,7 +2464,7 @@
     psa_algorithm_t key_alg = key_alg_arg;
     psa_key_usage_t key_usage = PSA_KEY_USAGE_SIGN_HASH | PSA_KEY_USAGE_VERIFY_HASH |
                                 PSA_KEY_USAGE_EXPORT | PSA_KEY_USAGE_COPY;
-    mbedtls_pk_context pk_priv, pk_pub;
+    mbedtls_pk_context pk_priv, pk_priv_copy_public, pk_pub, pk_pub_copy_public;
     mbedtls_svc_key_id_t priv_key_id = MBEDTLS_SVC_KEY_ID_INIT;
     mbedtls_svc_key_id_t pub_key_id = MBEDTLS_SVC_KEY_ID_INIT;
     unsigned char *in_buf = NULL;
@@ -2392,7 +2474,9 @@
     size_t out_buf_len, out_buf2_len;
 
     mbedtls_pk_init(&pk_priv);
+    mbedtls_pk_init(&pk_priv_copy_public);
     mbedtls_pk_init(&pk_pub);
+    mbedtls_pk_init(&pk_pub_copy_public);
     PSA_INIT();
 
     if (key_type == PSA_KEY_TYPE_RSA_KEY_PAIR) {
@@ -2404,9 +2488,11 @@
                                  key_type, key_usage, key_alg, &priv_key_id));
     pub_key_id = psa_pub_key_from_priv(priv_key_id);
 
-    /* Create 2 PK contexts starting from the PSA keys we just created. */
+    /* Create 4 PK contexts starting from the PSA keys we just created. */
     TEST_EQUAL(mbedtls_pk_copy_from_psa(priv_key_id, &pk_priv), 0);
+    TEST_EQUAL(mbedtls_pk_copy_public_from_psa(pub_key_id, &pk_priv_copy_public), 0);
     TEST_EQUAL(mbedtls_pk_copy_from_psa(pub_key_id, &pk_pub), 0);
+    TEST_EQUAL(mbedtls_pk_copy_public_from_psa(pub_key_id, &pk_pub_copy_public), 0);
 
     /* Destoy both PSA keys to prove that generated PK contexts are independent
      * from them. */
@@ -2534,10 +2620,19 @@
         }
     }
 
+    /* Test that the keys from mbedtls_pk_copy_public_from_psa() are identical
+     * to the public key from mbedtls_pk_copy_from_psa(). */
+    mbedtls_test_set_step(1);
+    TEST_ASSERT(pk_public_same(&pk_pub, &pk_priv_copy_public));
+    mbedtls_test_set_step(2);
+    TEST_ASSERT(pk_public_same(&pk_pub, &pk_pub_copy_public));
+
 exit:
     mbedtls_free(in_buf);
     mbedtls_pk_free(&pk_priv);
+    mbedtls_pk_free(&pk_priv_copy_public);
     mbedtls_pk_free(&pk_pub);
+    mbedtls_pk_free(&pk_pub_copy_public);
     psa_destroy_key(priv_key_id);
     psa_destroy_key(pub_key_id);
     PSA_DONE();