pk: improve sign, check_pair and wrap_as_opaque functions with new format

Signed-off-by: Valerio Setti <valerio.setti@nordicsemi.no>
diff --git a/library/pk.c b/library/pk.c
index 77012e1..cccadb1 100644
--- a/library/pk.c
+++ b/library/pk.c
@@ -912,24 +912,34 @@
 #else /* !MBEDTLS_ECP_LIGHT && !MBEDTLS_RSA_C */
 #if defined(MBEDTLS_ECP_LIGHT)
     if (mbedtls_pk_get_type(pk) == MBEDTLS_PK_ECKEY) {
-        mbedtls_ecp_keypair *ec;
         unsigned char d[MBEDTLS_ECP_MAX_BYTES];
         size_t d_len;
         psa_ecc_family_t curve_id;
         psa_key_attributes_t attributes = PSA_KEY_ATTRIBUTES_INIT;
         psa_key_type_t key_type;
         size_t bits;
-        int ret = MBEDTLS_ERR_ERROR_CORRUPTION_DETECTED;
         psa_status_t status;
 
         /* export the private key material in the format PSA wants */
-        ec = mbedtls_pk_ec_rw(*pk);
+#if defined(MBEDTLS_PK_USE_PSA_EC_DATA)
+        status = psa_export_key(pk->priv_id, d, sizeof(d), &d_len);
+        if (status != PSA_SUCCESS) {
+            return psa_pk_status_to_mbedtls(status);
+        }
+
+        curve_id = pk->ec_family;
+        bits = pk->ec_bits;
+#else /* MBEDTLS_PK_USE_PSA_EC_DATA */
+        mbedtls_ecp_keypair *ec = mbedtls_pk_ec_rw(*pk);
+        int ret = MBEDTLS_ERR_ERROR_CORRUPTION_DETECTED;
+
         d_len = PSA_BITS_TO_BYTES(ec->grp.nbits);
         if ((ret = mbedtls_ecp_write_key(ec, d, d_len)) != 0) {
             return ret;
         }
 
         curve_id = mbedtls_ecc_group_to_psa(ec->grp.id, &bits);
+#endif /* MBEDTLS_PK_USE_PSA_EC_DATA */
         key_type = PSA_KEY_TYPE_ECC_KEY_PAIR(curve_id);
 
         /* prepare the key attributes */
diff --git a/library/pk_wrap.c b/library/pk_wrap.c
index 7f5e751..f3a44ae 100644
--- a/library/pk_wrap.c
+++ b/library/pk_wrap.c
@@ -925,12 +925,9 @@
                            unsigned char *sig, size_t sig_size, size_t *sig_len,
                            int (*f_rng)(void *, unsigned char *, size_t), void *p_rng)
 {
-    mbedtls_ecp_keypair *ctx = pk->pk_ctx;
     int ret = MBEDTLS_ERR_ERROR_CORRUPTION_DETECTED;
-    psa_key_attributes_t attributes = PSA_KEY_ATTRIBUTES_INIT;
     mbedtls_svc_key_id_t key_id = MBEDTLS_SVC_KEY_ID_INIT;
     psa_status_t status;
-    unsigned char buf[MBEDTLS_PSA_MAX_EC_KEY_PAIR_LENGTH];
 #if defined(MBEDTLS_ECDSA_DETERMINISTIC)
     psa_algorithm_t psa_sig_md =
         PSA_ALG_DETERMINISTIC_ECDSA(mbedtls_hash_info_psa_from_md(md_alg));
@@ -938,10 +935,17 @@
     psa_algorithm_t psa_sig_md =
         PSA_ALG_ECDSA(mbedtls_hash_info_psa_from_md(md_alg));
 #endif
+#if defined(MBEDTLS_PK_USE_PSA_EC_DATA)
+    psa_ecc_family_t curve = pk->ec_family;
+#else /* MBEDTLS_PK_USE_PSA_EC_DATA */
+    mbedtls_ecp_keypair *ctx = pk->pk_ctx;
+    psa_key_attributes_t attributes = PSA_KEY_ATTRIBUTES_INIT;
+    unsigned char buf[MBEDTLS_PSA_MAX_EC_KEY_PAIR_LENGTH];
     size_t curve_bits;
     psa_ecc_family_t curve =
         mbedtls_ecc_group_to_psa(ctx->grp.id, &curve_bits);
     size_t key_len = PSA_BITS_TO_BYTES(curve_bits);
+#endif /* MBEDTLS_PK_USE_PSA_EC_DATA */
 
     /* PSA has its own RNG */
     ((void) f_rng);
@@ -951,6 +955,12 @@
         return MBEDTLS_ERR_PK_BAD_INPUT_DATA;
     }
 
+#if defined(MBEDTLS_PK_USE_PSA_EC_DATA)
+    if (MBEDTLS_SVC_KEY_ID_GET_KEY_ID(pk->priv_id) == PSA_KEY_ID_NULL) {
+        return MBEDTLS_ERR_PK_BAD_INPUT_DATA;
+    }
+    key_id = pk->priv_id;
+#else
     if (key_len > sizeof(buf)) {
         return MBEDTLS_ERR_ERROR_CORRUPTION_DETECTED;
     }
@@ -970,6 +980,7 @@
         ret = PSA_PK_TO_MBEDTLS_ERR(status);
         goto cleanup;
     }
+#endif /* MBEDTLS_PK_USE_PSA_EC_DATA */
 
     status = psa_sign_hash(key_id, psa_sig_md, hash, hash_len,
                            sig, sig_size, sig_len);
@@ -981,8 +992,11 @@
     ret = pk_ecdsa_sig_asn1_from_psa(sig, sig_len, sig_size);
 
 cleanup:
+
+#if !defined(MBEDTLS_PK_USE_PSA_EC_DATA)
     mbedtls_platform_zeroize(buf, sizeof(buf));
     status = psa_destroy_key(key_id);
+#endif /* MBEDTLS_PK_USE_PSA_EC_DATA */
     if (ret == 0 && status != PSA_SUCCESS) {
         ret = PSA_PK_TO_MBEDTLS_ERR(status);
     }
@@ -1123,24 +1137,19 @@
 static int eckey_check_pair_psa(mbedtls_pk_context *pub, mbedtls_pk_context *prv)
 {
     psa_status_t status, destruction_status;
-    psa_key_attributes_t key_attr = PSA_KEY_ATTRIBUTES_INIT;
     int ret = MBEDTLS_ERR_ERROR_CORRUPTION_DETECTED;
-    /* We are using MBEDTLS_PSA_MAX_EC_PUBKEY_LENGTH for the size of this
-     * buffer because it will be used to hold the private key at first and
-     * then its public part (but not at the same time). */
     uint8_t prv_key_buf[MBEDTLS_PSA_MAX_EC_PUBKEY_LENGTH];
     size_t prv_key_len;
-    mbedtls_svc_key_id_t key_id = MBEDTLS_SVC_KEY_ID_INIT;
 #if defined(MBEDTLS_PK_USE_PSA_EC_DATA)
-    const psa_ecc_family_t curve = prv->ec_family;
-    const size_t curve_bits = prv->ec_bits;
+    mbedtls_svc_key_id_t key_id = prv->priv_id;
 #else /* !MBEDTLS_PK_USE_PSA_EC_DATA */
+    mbedtls_svc_key_id_t key_id = MBEDTLS_SVC_KEY_ID_INIT;
+    psa_key_attributes_t key_attr = PSA_KEY_ATTRIBUTES_INIT;
     uint8_t pub_key_buf[MBEDTLS_PSA_MAX_EC_PUBKEY_LENGTH];
     size_t pub_key_len;
     size_t curve_bits;
     const psa_ecc_family_t curve =
         mbedtls_ecc_group_to_psa(mbedtls_pk_ec_ro(*prv)->grp.id, &curve_bits);
-#endif /* !MBEDTLS_PK_USE_PSA_EC_DATA */
     const size_t curve_bytes = PSA_BITS_TO_BYTES(curve_bits);
 
     if (curve == 0) {
@@ -1163,6 +1172,7 @@
     }
 
     mbedtls_platform_zeroize(prv_key_buf, sizeof(prv_key_buf));
+#endif /* !MBEDTLS_PK_USE_PSA_EC_DATA */
 
     status = psa_export_public_key(key_id, prv_key_buf, sizeof(prv_key_buf),
                                    &prv_key_len);