pkwrap: use the PSA-format EC public key in the eckey wrappers

Signed-off-by: Valerio Setti <valerio.setti@nordicsemi.no>
diff --git a/library/pk_wrap.c b/library/pk_wrap.c
index 0e5e120..32d697a 100644
--- a/library/pk_wrap.c
+++ b/library/pk_wrap.c
@@ -23,6 +23,7 @@
 
 #if defined(MBEDTLS_PK_C)
 #include "pk_wrap.h"
+#include "pk_internal.h"
 #include "mbedtls/error.h"
 
 /* Even if RSA not activated, for the sake of RSA-alt */
@@ -653,8 +654,12 @@
 
 static size_t eckey_get_bitlen(mbedtls_pk_context *pk)
 {
+#if defined(MBEDTLS_PK_USE_PSA_EC_DATA)
+    return pk->ec_bits; /* curve size in bits, kept directly in the PK context */
+#else /* MBEDTLS_PK_USE_PSA_EC_DATA */
     mbedtls_ecp_keypair *ecp = (mbedtls_ecp_keypair *) pk->pk_ctx;
     return ecp->grp.pbits;
+#endif /* MBEDTLS_PK_USE_PSA_EC_DATA */
 }
 
 #if defined(MBEDTLS_PK_CAN_ECDSA_VERIFY)
@@ -724,11 +729,20 @@
                              const unsigned char *hash, size_t hash_len,
                              const unsigned char *sig, size_t sig_len)
 {
-    mbedtls_ecp_keypair *ctx = pk->pk_ctx;
     int ret = MBEDTLS_ERR_ERROR_CORRUPTION_DETECTED;
     psa_key_attributes_t attributes = PSA_KEY_ATTRIBUTES_INIT;
     mbedtls_svc_key_id_t key_id = MBEDTLS_SVC_KEY_ID_INIT;
     psa_status_t status;
+    unsigned char *p;
+    psa_algorithm_t psa_sig_md = PSA_ALG_ECDSA_ANY;
+    size_t signature_len;
+    ((void) md_alg);
+#if defined(MBEDTLS_PK_USE_PSA_EC_DATA)
+    unsigned char buf[PSA_VENDOR_ECDSA_SIGNATURE_MAX_SIZE]; /* signature only */
+    psa_ecc_family_t curve = pk->ec_family;
+    size_t curve_bits = pk->ec_bits;
+#else /* MBEDTLS_PK_USE_PSA_EC_DATA */
+    mbedtls_ecp_keypair *ctx = pk->pk_ctx;
     size_t key_len;
     /* This buffer will initially contain the public key and then the signature
      * but at different points in time. For all curves except secp224k1, which
@@ -736,13 +750,10 @@
      * (header byte + 2 numbers, while the signature is only 2 numbers),
      * so use that as the buffer size. */
     unsigned char buf[MBEDTLS_PSA_MAX_EC_PUBKEY_LENGTH];
-    unsigned char *p;
-    psa_algorithm_t psa_sig_md = PSA_ALG_ECDSA_ANY;
     size_t curve_bits;
     psa_ecc_family_t curve =
         mbedtls_ecc_group_to_psa(ctx->grp.id, &curve_bits);
-    const size_t signature_part_size = (ctx->grp.nbits + 7) / 8;
-    ((void) md_alg);
+#endif /* MBEDTLS_PK_USE_PSA_EC_DATA */
 
     if (curve == 0) {
         return MBEDTLS_ERR_PK_BAD_INPUT_DATA;
@@ -752,6 +763,11 @@
     psa_set_key_usage_flags(&attributes, PSA_KEY_USAGE_VERIFY_HASH);
     psa_set_key_algorithm(&attributes, psa_sig_md);
 
+#if defined(MBEDTLS_PK_USE_PSA_EC_DATA)
+    status = psa_import_key(&attributes,
+                            pk->pub_raw, pk->pub_raw_len,
+                            &key_id);
+#else /* MBEDTLS_PK_USE_PSA_EC_DATA */
     ret = mbedtls_ecp_point_write_binary(&ctx->grp, &ctx->Q,
                                          MBEDTLS_ECP_PF_UNCOMPRESSED,
                                          &key_len, buf, sizeof(buf));
@@ -762,27 +778,30 @@
     status = psa_import_key(&attributes,
                             buf, key_len,
                             &key_id);
+#endif /* MBEDTLS_PK_USE_PSA_EC_DATA */
     if (status != PSA_SUCCESS) {
         ret = PSA_PK_TO_MBEDTLS_ERR(status);
         goto cleanup;
     }
 
-    /* We don't need the exported key anymore and can
-     * reuse its buffer for signature extraction. */
-    if (2 * signature_part_size > sizeof(buf)) {
+    signature_len = PSA_ECDSA_SIGNATURE_SIZE(curve_bits);
+    if (signature_len > sizeof(buf)) {
         ret = MBEDTLS_ERR_PK_BAD_INPUT_DATA;
         goto cleanup;
     }
 
     p = (unsigned char *) sig;
+    /* extract_ecdsa_sig's last parameter is the size of
+     * each integer to be parsed, i.e. half the size of
+     * the raw signature that psa_verify_hash() expects. */
     if ((ret = extract_ecdsa_sig(&p, sig + sig_len, buf,
-                                 signature_part_size)) != 0) {
+                                 signature_len / 2)) != 0) {
         goto cleanup;
     }
 
     status = psa_verify_hash(key_id, psa_sig_md,
                              hash, hash_len,
-                             buf, 2 * signature_part_size);
+                             buf, signature_len);
     if (status != PSA_SUCCESS) {
         ret = PSA_PK_ECDSA_TO_MBEDTLS_ERR(status);
         goto cleanup;
@@ -1112,26 +1131,30 @@
 {
     psa_status_t status, destruction_status;
     psa_key_attributes_t key_attr = PSA_KEY_ATTRIBUTES_INIT;
-    mbedtls_ecp_keypair *prv_ctx = prv->pk_ctx;
-    mbedtls_ecp_keypair *pub_ctx = pub->pk_ctx;
     int ret = MBEDTLS_ERR_ERROR_CORRUPTION_DETECTED;
     /* We are using MBEDTLS_PSA_MAX_EC_PUBKEY_LENGTH for the size of this
      * buffer because it will be used to hold the private key at first and
      * then its public part (but not at the same time). */
     uint8_t prv_key_buf[MBEDTLS_PSA_MAX_EC_PUBKEY_LENGTH];
     size_t prv_key_len;
+    mbedtls_svc_key_id_t key_id = MBEDTLS_SVC_KEY_ID_INIT;
+#if defined(MBEDTLS_PK_USE_PSA_EC_DATA)
+    const psa_ecc_family_t curve = prv->ec_family;
+    const size_t curve_bytes = PSA_BITS_TO_BYTES(prv->ec_bits);
+#else /* !MBEDTLS_PK_USE_PSA_EC_DATA */
     uint8_t pub_key_buf[MBEDTLS_PSA_MAX_EC_PUBKEY_LENGTH];
     size_t pub_key_len;
-    mbedtls_svc_key_id_t key_id = MBEDTLS_SVC_KEY_ID_INIT;
     size_t curve_bits;
     const psa_ecc_family_t curve =
-        mbedtls_ecc_group_to_psa(prv_ctx->grp.id, &curve_bits);
+        mbedtls_ecc_group_to_psa(mbedtls_pk_ec_ro(*prv)->grp.id, &curve_bits);
     const size_t curve_bytes = PSA_BITS_TO_BYTES(curve_bits);
+#endif /* !MBEDTLS_PK_USE_PSA_EC_DATA */
 
     psa_set_key_type(&key_attr, PSA_KEY_TYPE_ECC_KEY_PAIR(curve));
     psa_set_key_usage_flags(&key_attr, PSA_KEY_USAGE_EXPORT);
 
-    ret = mbedtls_mpi_write_binary(&prv_ctx->d, prv_key_buf, curve_bytes);
+    ret = mbedtls_mpi_write_binary(&mbedtls_pk_ec_ro(*prv)->d,
+                                   prv_key_buf, curve_bytes);
     if (ret != 0) {
         return ret;
     }
@@ -1154,7 +1177,13 @@
         return PSA_PK_TO_MBEDTLS_ERR(destruction_status);
     }
 
-    ret = mbedtls_ecp_point_write_binary(&pub_ctx->grp, &pub_ctx->Q,
+#if defined(MBEDTLS_PK_USE_PSA_EC_DATA)
+    if (prv_key_len != pub->pub_raw_len || memcmp(prv_key_buf, pub->pub_raw, prv_key_len) != 0) {
+        return MBEDTLS_ERR_PK_BAD_INPUT_DATA;
+    }
+#else /* !MBEDTLS_PK_USE_PSA_EC_DATA */
+    ret = mbedtls_ecp_point_write_binary(&mbedtls_pk_ec_ro(*pub)->grp,
+                                         &mbedtls_pk_ec_ro(*pub)->Q,
                                          MBEDTLS_ECP_PF_UNCOMPRESSED,
                                          &pub_key_len, pub_key_buf,
                                          sizeof(pub_key_buf));
@@ -1165,6 +1194,7 @@
     if (memcmp(prv_key_buf, pub_key_buf, curve_bytes) != 0) {
         return MBEDTLS_ERR_PK_BAD_INPUT_DATA;
     }
+#endif /* !MBEDTLS_PK_USE_PSA_EC_DATA */
 
     return 0;
 }
@@ -1206,10 +1236,16 @@
 
 static void eckey_debug(mbedtls_pk_context *pk, mbedtls_pk_debug_item *items)
 {
+#if defined(MBEDTLS_PK_USE_PSA_EC_DATA)
+    items->type = MBEDTLS_PK_DEBUG_PSA_EC;
+    items->name = "eckey.Q";
+    items->value = pk; /* the PK context itself, not an mbedtls_ecp_point */
+#else /* MBEDTLS_PK_USE_PSA_EC_DATA */
     mbedtls_ecp_keypair *ecp = (mbedtls_ecp_keypair *) pk->pk_ctx;
     items->type = MBEDTLS_PK_DEBUG_ECP;
     items->name = "eckey.Q";
     items->value = &(ecp->Q);
+#endif /* MBEDTLS_PK_USE_PSA_EC_DATA */
 }
 
 const mbedtls_pk_info_t mbedtls_eckey_info = {