mbedtls_pk_import_into_psa: implement and test

Implement mbedtls_pk_import_into_psa() for all PK types except RSA_ALT.
This covers importing a key pair, importing a public key, and importing
the public part of a key pair.

Test mbedtls_pk_import_into_psa() with the output of
mbedtls_pk_get_psa_attributes(). Also unit-test mbedtls_pk_import_into_psa()
on its own to get extra coverage, mostly for negative cases.

Signed-off-by: Gilles Peskine <Gilles.Peskine@arm.com>
diff --git a/library/pk.c b/library/pk.c
index 1b481e1..f623af9 100644
--- a/library/pk.c
+++ b/library/pk.c
@@ -18,10 +18,8 @@
 
 #if defined(MBEDTLS_RSA_C)
 #include "mbedtls/rsa.h"
-#if defined(MBEDTLS_PKCS1_V21) && !defined(MBEDTLS_USE_PSA_CRYPTO)
 #include "rsa_internal.h"
 #endif
-#endif
 #if defined(MBEDTLS_PK_HAVE_ECC_KEYS)
 #include "mbedtls/ecp.h"
 #endif
@@ -579,6 +577,314 @@
 
     return 0;
 }
+
+#if defined(MBEDTLS_PK_USE_PSA_EC_DATA) || defined(MBEDTLS_USE_PSA_CRYPTO)
+/* Create *new_key_id with the given attributes from the key material of
+ * old_key_id, by exporting it to a stack buffer and importing it again.
+ * The exported bytes are wiped after the import attempt.
+ * Returns the PSA status of the first step that fails, or PSA_SUCCESS. */
+static psa_status_t export_import_into_psa(mbedtls_svc_key_id_t old_key_id,
+                                           const psa_key_attributes_t *attributes,
+                                           mbedtls_svc_key_id_t *new_key_id)
+{
+    unsigned char key_buffer[PSA_EXPORT_KEY_PAIR_MAX_SIZE];
+    size_t key_length = 0;
+    psa_status_t status = psa_export_key(old_key_id,
+                                         key_buffer, sizeof(key_buffer),
+                                         &key_length);
+    if (status != PSA_SUCCESS) {
+        return status;
+    }
+    status = psa_import_key(attributes, key_buffer, key_length, new_key_id);
+    mbedtls_platform_zeroize(key_buffer, key_length);
+    return status;
+}
+
+/* Create *new_key_id with the given attributes from the PSA key old_key_id.
+ * psa_copy_key() is tried first; if it fails with PSA_ERROR_NOT_PERMITTED
+ * or PSA_ERROR_INVALID_ARGUMENT, fall back to export+import, but only if
+ * the key types match. Returns 0 on success or an Mbed TLS error code. */
+static int copy_into_psa(mbedtls_svc_key_id_t old_key_id,
+                         const psa_key_attributes_t *attributes,
+                         mbedtls_svc_key_id_t *new_key_id)
+{
+    /* Normally, we prefer copying: it's more efficient and works even
+     * for non-exportable keys. */
+    psa_status_t status = psa_copy_key(old_key_id, attributes, new_key_id);
+    if (status == PSA_ERROR_NOT_PERMITTED /*missing COPY usage*/ ||
+        status == PSA_ERROR_INVALID_ARGUMENT /*incompatible policy*/) {
+        /* There are edge cases where copying won't work, but export+import
+         * might:
+         * - If the old key does not allow PSA_KEY_USAGE_COPY.
+         * - If the old key's usage does not allow what attributes wants.
+         *   Because the key was intended for use in the pk module, and may
+         *   have had a policy chosen solely for what pk needs rather than
+         *   based on a detailed understanding of PSA policies, we are a bit
+         *   more liberal than psa_copy_key() here.
+         */
+        /* Here we need to check that the types match, otherwise we risk
+         * importing nonsensical data. */
+        psa_key_attributes_t old_attributes = PSA_KEY_ATTRIBUTES_INIT;
+        status = psa_get_key_attributes(old_key_id, &old_attributes);
+        if (status != PSA_SUCCESS) {
+            return MBEDTLS_ERR_PK_BAD_INPUT_DATA;
+        }
+        psa_key_type_t old_type = psa_get_key_type(&old_attributes);
+        psa_reset_key_attributes(&old_attributes);
+        if (old_type != psa_get_key_type(attributes)) {
+            return MBEDTLS_ERR_PK_TYPE_MISMATCH;
+        }
+        status = export_import_into_psa(old_key_id, attributes, new_key_id);
+    }
+    return PSA_PK_TO_MBEDTLS_ERR(status);
+}
+#endif /* MBEDTLS_PK_USE_PSA_EC_DATA || MBEDTLS_USE_PSA_CRYPTO */
+
+/* Import the key pair held by pk into PSA with the given attributes.
+ * For RSA and ECC contexts, a requested type other than the matching
+ * key-pair type (for ECC, of the same curve family), or an ECC context
+ * without a private key, gives MBEDTLS_ERR_PK_TYPE_MISMATCH.
+ * Returns 0 on success or an Mbed TLS error code. */
+static int import_pair_into_psa(const mbedtls_pk_context *pk,
+                                const psa_key_attributes_t *attributes,
+                                mbedtls_svc_key_id_t *key_id)
+{
+    switch (mbedtls_pk_get_type(pk)) {
+#if defined(MBEDTLS_RSA_C)
+        case MBEDTLS_PK_RSA:
+        {
+            if (psa_get_key_type(attributes) != PSA_KEY_TYPE_RSA_KEY_PAIR) {
+                return MBEDTLS_ERR_PK_TYPE_MISMATCH;
+            }
+            unsigned char key_buffer[
+                PSA_KEY_EXPORT_RSA_KEY_PAIR_MAX_SIZE(PSA_VENDOR_RSA_MAX_KEY_BITS)];
+            unsigned char *const key_end = key_buffer + sizeof(key_buffer);
+            unsigned char *key_data = key_end;
+            int ret = mbedtls_rsa_write_key(mbedtls_pk_rsa(*pk),
+                                            key_buffer, &key_data);
+            if (ret < 0) {
+                /* The key is written backwards from key_end and the write
+                 * may fail part-way (e.g. for a key larger than the buffer),
+                 * leaving some private components behind: wipe them. */
+                mbedtls_platform_zeroize(key_buffer, sizeof(key_buffer));
+                return ret;
+            }
+            size_t key_length = key_end - key_data;
+            ret = PSA_PK_TO_MBEDTLS_ERR(psa_import_key(attributes,
+                                                       key_data, key_length,
+                                                       key_id));
+            mbedtls_platform_zeroize(key_data, key_length);
+            return ret;
+        }
+#endif /* MBEDTLS_RSA_C */
+
+#if defined(MBEDTLS_PK_HAVE_ECC_KEYS)
+        case MBEDTLS_PK_ECKEY:
+        case MBEDTLS_PK_ECKEY_DH:
+        case MBEDTLS_PK_ECDSA:
+        {
+            /* We need to check the curve family, otherwise the import could
+             * succeed with nonsensical data.
+             * We don't check the bit-size: it's optional in attributes,
+             * and if it's specified, psa_import_key() will know from the key
+             * data length and will check that the bit-size matches. */
+            psa_key_type_t to_type = psa_get_key_type(attributes);
+#if defined(MBEDTLS_PK_USE_PSA_EC_DATA)
+            psa_ecc_family_t from_family = pk->ec_family;
+#else /* MBEDTLS_PK_USE_PSA_EC_DATA */
+            /* We're only reading the key, but mbedtls_ecp_write_key()
+             * is missing a const annotation on its key parameter, so
+             * we need the non-const accessor here. */
+            mbedtls_ecp_keypair *ec = mbedtls_pk_ec_rw(*pk);
+            size_t from_bits = 0;
+            psa_ecc_family_t from_family = mbedtls_ecc_group_to_psa(ec->grp.id,
+                                                                    &from_bits);
+#endif /* MBEDTLS_PK_USE_PSA_EC_DATA */
+            if (to_type != PSA_KEY_TYPE_ECC_KEY_PAIR(from_family)) {
+                return MBEDTLS_ERR_PK_TYPE_MISMATCH;
+            }
+
+#if defined(MBEDTLS_PK_USE_PSA_EC_DATA)
+            if (mbedtls_svc_key_id_is_null(pk->priv_id)) {
+                /* We have a public key and want a key pair. */
+                return MBEDTLS_ERR_PK_TYPE_MISMATCH;
+            }
+            return copy_into_psa(pk->priv_id, attributes, key_id);
+#else /* MBEDTLS_PK_USE_PSA_EC_DATA */
+            if (ec->d.n == 0) {
+                /* Private key not set. Assume the input is a public key only.
+                 * (The other possibility is that it's an incomplete object
+                 * where the group is set but neither the public key nor
+                 * the private key. This is not possible through ecp.h
+                 * functions, so we don't bother reporting a more suitable
+                 * error in that case.) */
+                return MBEDTLS_ERR_PK_TYPE_MISMATCH;
+            }
+            unsigned char key_buffer[PSA_BITS_TO_BYTES(PSA_VENDOR_ECC_MAX_CURVE_BITS)];
+            /* Pass exactly key_length to mbedtls_ecp_write_key(): it writes
+             * Weierstrass keys big-endian at the end of the buffer but
+             * Montgomery keys little-endian at the start, so with a larger
+             * buffer a Montgomery key would not be where we read it. */
+            size_t key_length = PSA_BITS_TO_BYTES(ec->grp.nbits);
+            if (key_length > sizeof(key_buffer)) {
+                /* The curve is larger than PSA_VENDOR_ECC_MAX_CURVE_BITS. */
+                return MBEDTLS_ERR_PK_FEATURE_UNAVAILABLE;
+            }
+            int ret = mbedtls_ecp_write_key(ec, key_buffer, key_length);
+            if (ret < 0) {
+                return ret;
+            }
+            ret = PSA_PK_TO_MBEDTLS_ERR(psa_import_key(attributes,
+                                                       key_buffer, key_length,
+                                                       key_id));
+            mbedtls_platform_zeroize(key_buffer, key_length);
+            return ret;
+#endif /* MBEDTLS_PK_USE_PSA_EC_DATA */
+        }
+#endif /* MBEDTLS_PK_HAVE_ECC_KEYS */
+
+#if defined(MBEDTLS_USE_PSA_CRYPTO)
+        case MBEDTLS_PK_OPAQUE:
+            return copy_into_psa(pk->priv_id, attributes, key_id);
+#endif /* MBEDTLS_USE_PSA_CRYPTO */
+
+        default:
+            return MBEDTLS_ERR_PK_BAD_INPUT_DATA;
+    }
+}
+
+/* Import the public key of pk (which may hold a public key or a key pair)
+ * into PSA with the given attributes. For RSA, ECC and opaque contexts, a
+ * requested type other than the matching public-key type (for ECC, of the
+ * same curve family) gives MBEDTLS_ERR_PK_TYPE_MISMATCH.
+ * Returns 0 on success or an Mbed TLS error code. */
+static int import_public_into_psa(const mbedtls_pk_context *pk,
+                                  const psa_key_attributes_t *attributes,
+                                  mbedtls_svc_key_id_t *key_id)
+{
+    psa_key_type_t psa_type = psa_get_key_type(attributes);
+
+#if defined(MBEDTLS_RSA_C) ||                                           \
+    (defined(MBEDTLS_PK_HAVE_ECC_KEYS) && !defined(MBEDTLS_PK_USE_PSA_EC_DATA)) || \
+    defined(MBEDTLS_USE_PSA_CRYPTO)
+    unsigned char key_buffer[PSA_EXPORT_PUBLIC_KEY_MAX_SIZE];
+#endif
+    unsigned char *key_data = NULL;
+    size_t key_length = 0;
+
+    switch (mbedtls_pk_get_type(pk)) {
+#if defined(MBEDTLS_RSA_C)
+        case MBEDTLS_PK_RSA:
+        {
+            if (psa_type != PSA_KEY_TYPE_RSA_PUBLIC_KEY) {
+                return MBEDTLS_ERR_PK_TYPE_MISMATCH;
+            }
+            unsigned char *const key_end = key_buffer + sizeof(key_buffer);
+            key_data = key_end;
+            int ret = mbedtls_rsa_write_pubkey(mbedtls_pk_rsa(*pk),
+                                               key_buffer, &key_data);
+            if (ret < 0) {
+                return ret;
+            }
+            key_length = (size_t) ret;
+            break;
+        }
+#endif /*MBEDTLS_RSA_C */
+
+#if defined(MBEDTLS_PK_HAVE_ECC_KEYS)
+        case MBEDTLS_PK_ECKEY:
+        case MBEDTLS_PK_ECKEY_DH:
+        case MBEDTLS_PK_ECDSA:
+        {
+            /* We need to check the curve family, otherwise the import could
+             * succeed with nonsensical data.
+             * We don't check the bit-size: it's optional in attributes,
+             * and if it's specified, psa_import_key() will know from the key
+             * data length and will check that the bit-size matches. */
+#if defined(MBEDTLS_PK_USE_PSA_EC_DATA)
+            if (psa_type != PSA_KEY_TYPE_ECC_PUBLIC_KEY(pk->ec_family)) {
+                return MBEDTLS_ERR_PK_TYPE_MISMATCH;
+            }
+            key_data = (unsigned char *) pk->pub_raw;
+            key_length = pk->pub_raw_len;
+#else /* MBEDTLS_PK_USE_PSA_EC_DATA */
+            const mbedtls_ecp_keypair *ec = mbedtls_pk_ec_ro(*pk);
+            size_t from_bits = 0;
+            psa_ecc_family_t from_family = mbedtls_ecc_group_to_psa(ec->grp.id,
+                                                                    &from_bits);
+            if (psa_type != PSA_KEY_TYPE_ECC_PUBLIC_KEY(from_family)) {
+                return MBEDTLS_ERR_PK_TYPE_MISMATCH;
+            }
+            int ret = mbedtls_ecp_write_public_key(
+                ec, MBEDTLS_ECP_PF_UNCOMPRESSED,
+                &key_length, key_buffer, sizeof(key_buffer));
+            if (ret < 0) {
+                return ret;
+            }
+            key_data = key_buffer;
+#endif /* MBEDTLS_PK_USE_PSA_EC_DATA */
+            break;
+        }
+#endif /* MBEDTLS_PK_HAVE_ECC_KEYS */
+
+#if defined(MBEDTLS_USE_PSA_CRYPTO)
+        case MBEDTLS_PK_OPAQUE:
+        {
+            psa_key_attributes_t old_attributes = PSA_KEY_ATTRIBUTES_INIT;
+            psa_status_t status =
+                psa_get_key_attributes(pk->priv_id, &old_attributes);
+            if (status != PSA_SUCCESS) {
+                return MBEDTLS_ERR_PK_BAD_INPUT_DATA;
+            }
+            psa_key_type_t old_type = psa_get_key_type(&old_attributes);
+            psa_reset_key_attributes(&old_attributes);
+            if (psa_type != PSA_KEY_TYPE_PUBLIC_KEY_OF_KEY_PAIR(old_type)) {
+                return MBEDTLS_ERR_PK_TYPE_MISMATCH;
+            }
+            status = psa_export_public_key(pk->priv_id,
+                                           key_buffer, sizeof(key_buffer),
+                                           &key_length);
+            if (status != PSA_SUCCESS) {
+                return PSA_PK_TO_MBEDTLS_ERR(status);
+            }
+            key_data = key_buffer;
+            break;
+        }
+#endif /* MBEDTLS_USE_PSA_CRYPTO */
+
+        default:
+            return MBEDTLS_ERR_PK_BAD_INPUT_DATA;
+    }
+
+    return PSA_PK_TO_MBEDTLS_ERR(psa_import_key(attributes,
+                                                key_data, key_length,
+                                                key_id));
+}
+
+/* RSA_ALT contexts are rejected with MBEDTLS_ERR_PK_FEATURE_UNAVAILABLE.
+ * Otherwise, a public-key type in attributes selects import_public_into_psa()
+ * and any other type selects import_pair_into_psa(). */
+int mbedtls_pk_import_into_psa(const mbedtls_pk_context *pk,
+                               const psa_key_attributes_t *attributes,
+                               mbedtls_svc_key_id_t *key_id)
+{
+    /* Set the output immediately so that it won't contain garbage even
+     * if we error out before calling psa_import_key(). */
+    *key_id = MBEDTLS_SVC_KEY_ID_INIT;
+
+#if defined(MBEDTLS_PK_RSA_ALT_SUPPORT)
+    if (mbedtls_pk_get_type(pk) == MBEDTLS_PK_RSA_ALT) {
+        return MBEDTLS_ERR_PK_FEATURE_UNAVAILABLE;
+    }
+#endif /* MBEDTLS_PK_RSA_ALT_SUPPORT */
+
+    int want_public = PSA_KEY_TYPE_IS_PUBLIC_KEY(psa_get_key_type(attributes));
+    if (want_public) {
+        return import_public_into_psa(pk, attributes, key_id);
+    } else {
+        return import_pair_into_psa(pk, attributes, key_id);
+    }
+}
 #endif /* MBEDTLS_PSA_CRYPTO_C */
 
 /*