pk: pass pk_context pointer to wrappers instead of a void pointer

Signed-off-by: valerio <valerio.setti@nordicsemi.no>
diff --git a/library/pk_wrap.c b/library/pk_wrap.c
index 57bfdca..6c9f97b 100644
--- a/library/pk_wrap.c
+++ b/library/pk_wrap.c
@@ -191,18 +191,18 @@
            type == MBEDTLS_PK_RSASSA_PSS;
 }
 
-static size_t rsa_get_bitlen(const void *ctx)
+static size_t rsa_get_bitlen(mbedtls_pk_context *pk)
 {
-    const mbedtls_rsa_context *rsa = (const mbedtls_rsa_context *) ctx;
+    const mbedtls_rsa_context *rsa = (const mbedtls_rsa_context *) pk->pk_ctx;
     return 8 * mbedtls_rsa_get_len(rsa);
 }
 
 #if defined(MBEDTLS_USE_PSA_CRYPTO)
-static int rsa_verify_wrap(void *ctx, mbedtls_md_type_t md_alg,
+static int rsa_verify_wrap(mbedtls_pk_context *pk, mbedtls_md_type_t md_alg,
                            const unsigned char *hash, size_t hash_len,
                            const unsigned char *sig, size_t sig_len)
 {
-    mbedtls_rsa_context *rsa = (mbedtls_rsa_context *) ctx;
+    mbedtls_rsa_context *rsa = (mbedtls_rsa_context *) pk->pk_ctx;
     int ret = MBEDTLS_ERR_ERROR_CORRUPTION_DETECTED;
     psa_key_attributes_t attributes = PSA_KEY_ATTRIBUTES_INIT;
     mbedtls_svc_key_id_t key_id = MBEDTLS_SVC_KEY_ID_INIT;
@@ -225,7 +225,7 @@
     /* mbedtls_pk_write_pubkey_der() expects a full PK context;
      * re-construct one to make it happy */
     key.pk_info = &mbedtls_rsa_info;
-    key.pk_ctx = ctx;
+    key.pk_ctx = rsa;
     key_len = mbedtls_pk_write_pubkey_der(&key, buf, sizeof(buf));
     if (key_len <= 0) {
         return MBEDTLS_ERR_PK_BAD_INPUT_DATA;
@@ -260,12 +260,12 @@
     return ret;
 }
 #else
-static int rsa_verify_wrap(void *ctx, mbedtls_md_type_t md_alg,
+static int rsa_verify_wrap(mbedtls_pk_context *pk, mbedtls_md_type_t md_alg,
                            const unsigned char *hash, size_t hash_len,
                            const unsigned char *sig, size_t sig_len)
 {
     int ret = MBEDTLS_ERR_ERROR_CORRUPTION_DETECTED;
-    mbedtls_rsa_context *rsa = (mbedtls_rsa_context *) ctx;
+    mbedtls_rsa_context *rsa = (mbedtls_rsa_context *) pk->pk_ctx;
     size_t rsa_len = mbedtls_rsa_get_len(rsa);
 
     if (md_alg == MBEDTLS_MD_NONE && UINT_MAX < hash_len) {
@@ -354,7 +354,7 @@
 #endif /* MBEDTLS_PSA_CRYPTO_C */
 
 #if defined(MBEDTLS_USE_PSA_CRYPTO)
-static int rsa_sign_wrap(void *ctx, mbedtls_md_type_t md_alg,
+static int rsa_sign_wrap(mbedtls_pk_context *pk, mbedtls_md_type_t md_alg,
                          const unsigned char *hash, size_t hash_len,
                          unsigned char *sig, size_t sig_size, size_t *sig_len,
                          int (*f_rng)(void *, unsigned char *, size_t), void *p_rng)
@@ -370,16 +370,16 @@
 
     return mbedtls_pk_psa_rsa_sign_ext(PSA_ALG_RSA_PKCS1V15_SIGN(
                                            psa_md_alg),
-                                       ctx, hash, hash_len,
+                                       pk->pk_ctx, hash, hash_len,
                                        sig, sig_size, sig_len);
 }
 #else
-static int rsa_sign_wrap(void *ctx, mbedtls_md_type_t md_alg,
+static int rsa_sign_wrap(mbedtls_pk_context *pk, mbedtls_md_type_t md_alg,
                          const unsigned char *hash, size_t hash_len,
                          unsigned char *sig, size_t sig_size, size_t *sig_len,
                          int (*f_rng)(void *, unsigned char *, size_t), void *p_rng)
 {
-    mbedtls_rsa_context *rsa = (mbedtls_rsa_context *) ctx;
+    mbedtls_rsa_context *rsa = (mbedtls_rsa_context *) pk->pk_ctx;
 
     if (md_alg == MBEDTLS_MD_NONE && UINT_MAX < hash_len) {
         return MBEDTLS_ERR_PK_BAD_INPUT_DATA;
@@ -397,12 +397,12 @@
 #endif
 
 #if defined(MBEDTLS_USE_PSA_CRYPTO)
-static int rsa_decrypt_wrap(void *ctx,
+static int rsa_decrypt_wrap(mbedtls_pk_context *pk,
                             const unsigned char *input, size_t ilen,
                             unsigned char *output, size_t *olen, size_t osize,
                             int (*f_rng)(void *, unsigned char *, size_t), void *p_rng)
 {
-    mbedtls_rsa_context *rsa = (mbedtls_rsa_context *) ctx;
+    mbedtls_rsa_context *rsa = (mbedtls_rsa_context *) pk->pk_ctx;
     int ret = MBEDTLS_ERR_ERROR_CORRUPTION_DETECTED;
     psa_key_attributes_t attributes = PSA_KEY_ATTRIBUTES_INIT;
     mbedtls_svc_key_id_t key_id = MBEDTLS_SVC_KEY_ID_INIT;
@@ -427,7 +427,7 @@
     /* mbedtls_pk_write_key_der() expects a full PK context;
      * re-construct one to make it happy */
     key.pk_info = &mbedtls_rsa_info;
-    key.pk_ctx = ctx;
+    key.pk_ctx = rsa;
     key_len = mbedtls_pk_write_key_der(&key, buf, sizeof(buf));
     if (key_len <= 0) {
         return MBEDTLS_ERR_PK_BAD_INPUT_DATA;
@@ -466,12 +466,12 @@
     return ret;
 }
 #else
-static int rsa_decrypt_wrap(void *ctx,
+static int rsa_decrypt_wrap(mbedtls_pk_context *pk,
                             const unsigned char *input, size_t ilen,
                             unsigned char *output, size_t *olen, size_t osize,
                             int (*f_rng)(void *, unsigned char *, size_t), void *p_rng)
 {
-    mbedtls_rsa_context *rsa = (mbedtls_rsa_context *) ctx;
+    mbedtls_rsa_context *rsa = (mbedtls_rsa_context *) pk->pk_ctx;
 
     if (ilen != mbedtls_rsa_get_len(rsa)) {
         return MBEDTLS_ERR_RSA_BAD_INPUT_DATA;
@@ -483,12 +483,12 @@
 #endif
 
 #if defined(MBEDTLS_USE_PSA_CRYPTO)
-static int rsa_encrypt_wrap(void *ctx,
+static int rsa_encrypt_wrap(mbedtls_pk_context *pk,
                             const unsigned char *input, size_t ilen,
                             unsigned char *output, size_t *olen, size_t osize,
                             int (*f_rng)(void *, unsigned char *, size_t), void *p_rng)
 {
-    mbedtls_rsa_context *rsa = (mbedtls_rsa_context *) ctx;
+    mbedtls_rsa_context *rsa = (mbedtls_rsa_context *) pk->pk_ctx;
     int ret = MBEDTLS_ERR_ERROR_CORRUPTION_DETECTED;
     psa_key_attributes_t attributes = PSA_KEY_ATTRIBUTES_INIT;
     mbedtls_svc_key_id_t key_id = MBEDTLS_SVC_KEY_ID_INIT;
@@ -513,7 +513,7 @@
     /* mbedtls_pk_write_pubkey_der() expects a full PK context;
      * re-construct one to make it happy */
     key.pk_info = &mbedtls_rsa_info;
-    key.pk_ctx = ctx;
+    key.pk_ctx = rsa;
     key_len = mbedtls_pk_write_pubkey_der(&key, buf, sizeof(buf));
     if (key_len <= 0) {
         return MBEDTLS_ERR_PK_BAD_INPUT_DATA;
@@ -551,12 +551,12 @@
     return ret;
 }
 #else
-static int rsa_encrypt_wrap(void *ctx,
+static int rsa_encrypt_wrap(mbedtls_pk_context *pk,
                             const unsigned char *input, size_t ilen,
                             unsigned char *output, size_t *olen, size_t osize,
                             int (*f_rng)(void *, unsigned char *, size_t), void *p_rng)
 {
-    mbedtls_rsa_context *rsa = (mbedtls_rsa_context *) ctx;
+    mbedtls_rsa_context *rsa = (mbedtls_rsa_context *) pk->pk_ctx;
     *olen = mbedtls_rsa_get_len(rsa);
 
     if (*olen > osize) {
@@ -568,14 +568,14 @@
 }
 #endif
 
-static int rsa_check_pair_wrap(const void *pub, const void *prv,
+static int rsa_check_pair_wrap(mbedtls_pk_context *pub, mbedtls_pk_context *prv,
                                int (*f_rng)(void *, unsigned char *, size_t),
                                void *p_rng)
 {
     (void) f_rng;
     (void) p_rng;
-    return mbedtls_rsa_check_pub_priv((const mbedtls_rsa_context *) pub,
-                                      (const mbedtls_rsa_context *) prv);
+    return mbedtls_rsa_check_pub_priv((const mbedtls_rsa_context *) pub->pk_ctx,
+                                      (const mbedtls_rsa_context *) prv->pk_ctx);
 }
 
 static void *rsa_alloc_wrap(void)
@@ -595,22 +595,24 @@
     mbedtls_free(ctx);
 }
 
-static void rsa_debug(const void *ctx, mbedtls_pk_debug_item *items)
+static void rsa_debug(mbedtls_pk_context *pk, mbedtls_pk_debug_item *items)
 {
 #if defined(MBEDTLS_RSA_ALT)
     /* Not supported */
-    (void) ctx;
+    (void) pk;
     (void) items;
 #else
+    mbedtls_rsa_context *rsa = (mbedtls_rsa_context *) pk->pk_ctx;
+
     items->type = MBEDTLS_PK_DEBUG_MPI;
     items->name = "rsa.N";
-    items->value = &(((mbedtls_rsa_context *) ctx)->N);
+    items->value = &(rsa->N);
 
     items++;
 
     items->type = MBEDTLS_PK_DEBUG_MPI;
     items->name = "rsa.E";
-    items->value = &(((mbedtls_rsa_context *) ctx)->E);
+    items->value = &(rsa->E);
 #endif
 }
 
@@ -649,9 +651,10 @@
            type == MBEDTLS_PK_ECDSA;
 }
 
-static size_t eckey_get_bitlen(const void *ctx)
+static size_t eckey_get_bitlen(mbedtls_pk_context *pk)
 {
-    return ((mbedtls_ecp_keypair *) ctx)->grp.pbits;
+    mbedtls_ecp_keypair *ecp = (mbedtls_ecp_keypair *) pk->pk_ctx;
+    return ecp->grp.pbits;
 }
 
 #if defined(MBEDTLS_PK_CAN_ECDSA_VERIFY)
@@ -716,11 +719,12 @@
     return 0;
 }
 
-static int ecdsa_verify_wrap(void *ctx_arg, mbedtls_md_type_t md_alg,
+static int ecdsa_verify_wrap(mbedtls_pk_context *pk,
+                             mbedtls_md_type_t md_alg,
                              const unsigned char *hash, size_t hash_len,
                              const unsigned char *sig, size_t sig_len)
 {
-    mbedtls_ecp_keypair *ctx = ctx_arg;
+    mbedtls_ecp_keypair *ctx = pk->pk_ctx;
     int ret = MBEDTLS_ERR_ERROR_CORRUPTION_DETECTED;
     psa_key_attributes_t attributes = PSA_KEY_ATTRIBUTES_INIT;
     mbedtls_svc_key_id_t key_id = MBEDTLS_SVC_KEY_ID_INIT;
@@ -799,14 +803,14 @@
     return ret;
 }
 #else /* MBEDTLS_USE_PSA_CRYPTO */
-static int ecdsa_verify_wrap(void *ctx, mbedtls_md_type_t md_alg,
+static int ecdsa_verify_wrap(mbedtls_pk_context *pk, mbedtls_md_type_t md_alg,
                              const unsigned char *hash, size_t hash_len,
                              const unsigned char *sig, size_t sig_len)
 {
     int ret = MBEDTLS_ERR_ERROR_CORRUPTION_DETECTED;
     ((void) md_alg);
 
-    ret = mbedtls_ecdsa_read_signature((mbedtls_ecdsa_context *) ctx,
+    ret = mbedtls_ecdsa_read_signature((mbedtls_ecdsa_context *) pk->pk_ctx,
                                        hash, hash_len, sig, sig_len);
 
     if (ret == MBEDTLS_ERR_ECP_SIG_LEN_MISMATCH) {
@@ -904,12 +908,12 @@
     return 0;
 }
 
-static int ecdsa_sign_wrap(void *ctx_arg, mbedtls_md_type_t md_alg,
+static int ecdsa_sign_wrap(mbedtls_pk_context *pk, mbedtls_md_type_t md_alg,
                            const unsigned char *hash, size_t hash_len,
                            unsigned char *sig, size_t sig_size, size_t *sig_len,
                            int (*f_rng)(void *, unsigned char *, size_t), void *p_rng)
 {
-    mbedtls_ecp_keypair *ctx = ctx_arg;
+    mbedtls_ecp_keypair *ctx = pk->pk_ctx;
     int ret = MBEDTLS_ERR_ERROR_CORRUPTION_DETECTED;
     psa_key_attributes_t attributes = PSA_KEY_ATTRIBUTES_INIT;
     mbedtls_svc_key_id_t key_id = MBEDTLS_SVC_KEY_ID_INIT;
@@ -974,12 +978,12 @@
     return ret;
 }
 #else /* MBEDTLS_USE_PSA_CRYPTO */
-static int ecdsa_sign_wrap(void *ctx, mbedtls_md_type_t md_alg,
+static int ecdsa_sign_wrap(mbedtls_pk_context *pk, mbedtls_md_type_t md_alg,
                            const unsigned char *hash, size_t hash_len,
                            unsigned char *sig, size_t sig_size, size_t *sig_len,
                            int (*f_rng)(void *, unsigned char *, size_t), void *p_rng)
 {
-    return mbedtls_ecdsa_write_signature((mbedtls_ecdsa_context *) ctx,
+    return mbedtls_ecdsa_write_signature((mbedtls_ecdsa_context *) pk->pk_ctx,
                                          md_alg, hash, hash_len,
                                          sig, sig_size, sig_len,
                                          f_rng, p_rng);
@@ -989,12 +993,12 @@
 
 #if defined(MBEDTLS_ECDSA_C) && defined(MBEDTLS_ECP_RESTARTABLE)
 /* Forward declarations */
-static int ecdsa_verify_rs_wrap(void *ctx, mbedtls_md_type_t md_alg,
+static int ecdsa_verify_rs_wrap(mbedtls_pk_context *ctx, mbedtls_md_type_t md_alg,
                                 const unsigned char *hash, size_t hash_len,
                                 const unsigned char *sig, size_t sig_len,
                                 void *rs_ctx);
 
-static int ecdsa_sign_rs_wrap(void *ctx, mbedtls_md_type_t md_alg,
+static int ecdsa_sign_rs_wrap(mbedtls_pk_context *ctx, mbedtls_md_type_t md_alg,
                               const unsigned char *hash, size_t hash_len,
                               unsigned char *sig, size_t sig_size, size_t *sig_len,
                               int (*f_rng)(void *, unsigned char *, size_t), void *p_rng,
@@ -1041,7 +1045,7 @@
     mbedtls_free(ctx);
 }
 
-static int eckey_verify_rs_wrap(void *ctx, mbedtls_md_type_t md_alg,
+static int eckey_verify_rs_wrap(mbedtls_pk_context *pk, mbedtls_md_type_t md_alg,
                                 const unsigned char *hash, size_t hash_len,
                                 const unsigned char *sig, size_t sig_len,
                                 void *rs_ctx)
@@ -1056,10 +1060,10 @@
 
     /* set up our own sub-context if needed (that is, on first run) */
     if (rs->ecdsa_ctx.grp.pbits == 0) {
-        MBEDTLS_MPI_CHK(mbedtls_ecdsa_from_keypair(&rs->ecdsa_ctx, ctx));
+        MBEDTLS_MPI_CHK(mbedtls_ecdsa_from_keypair(&rs->ecdsa_ctx, pk->pk_ctx));
     }
 
-    MBEDTLS_MPI_CHK(ecdsa_verify_rs_wrap(&rs->ecdsa_ctx,
+    MBEDTLS_MPI_CHK(ecdsa_verify_rs_wrap(pk,
                                          md_alg, hash, hash_len,
                                          sig, sig_len, &rs->ecdsa_rs));
 
@@ -1067,7 +1071,7 @@
     return ret;
 }
 
-static int eckey_sign_rs_wrap(void *ctx, mbedtls_md_type_t md_alg,
+static int eckey_sign_rs_wrap(mbedtls_pk_context *pk, mbedtls_md_type_t md_alg,
                               const unsigned char *hash, size_t hash_len,
                               unsigned char *sig, size_t sig_size, size_t *sig_len,
                               int (*f_rng)(void *, unsigned char *, size_t), void *p_rng,
@@ -1083,10 +1087,10 @@
 
     /* set up our own sub-context if needed (that is, on first run) */
     if (rs->ecdsa_ctx.grp.pbits == 0) {
-        MBEDTLS_MPI_CHK(mbedtls_ecdsa_from_keypair(&rs->ecdsa_ctx, ctx));
+        MBEDTLS_MPI_CHK(mbedtls_ecdsa_from_keypair(&rs->ecdsa_ctx, pk->pk_ctx));
     }
 
-    MBEDTLS_MPI_CHK(ecdsa_sign_rs_wrap(&rs->ecdsa_ctx, md_alg,
+    MBEDTLS_MPI_CHK(ecdsa_sign_rs_wrap(pk, md_alg,
                                        hash, hash_len, sig, sig_size, sig_len,
                                        f_rng, p_rng, &rs->ecdsa_rs));
 
@@ -1104,13 +1108,12 @@
  * - write the raw content of public key "pub" to a local buffer
  * - compare the two buffers
  */
-static int eckey_check_pair_psa(const mbedtls_ecp_keypair *pub,
-                                const mbedtls_ecp_keypair *prv)
+static int eckey_check_pair_psa(mbedtls_pk_context *pub, mbedtls_pk_context *prv)
 {
     psa_status_t status, destruction_status;
     psa_key_attributes_t key_attr = PSA_KEY_ATTRIBUTES_INIT;
-    mbedtls_ecp_keypair *prv_ctx = (mbedtls_ecp_keypair *) prv;
-    mbedtls_ecp_keypair *pub_ctx = (mbedtls_ecp_keypair *) pub;
+    mbedtls_ecp_keypair *prv_ctx = prv->pk_ctx;
+    mbedtls_ecp_keypair *pub_ctx = pub->pk_ctx;
     int ret = MBEDTLS_ERR_ERROR_CORRUPTION_DETECTED;
     /* We are using MBEDTLS_PSA_MAX_EC_PUBKEY_LENGTH for the size of this
      * buffer because it will be used to hold the private key at first and
@@ -1167,7 +1170,7 @@
 }
 #endif /* MBEDTLS_USE_PSA_CRYPTO */
 
-static int eckey_check_pair(const void *pub, const void *prv,
+static int eckey_check_pair(mbedtls_pk_context *pub, mbedtls_pk_context *prv,
                             int (*f_rng)(void *, unsigned char *, size_t),
                             void *p_rng)
 {
@@ -1176,8 +1179,8 @@
     (void) p_rng;
     return eckey_check_pair_psa(pub, prv);
 #elif defined(MBEDTLS_ECP_C)
-    return mbedtls_ecp_check_pub_priv((const mbedtls_ecp_keypair *) pub,
-                                      (const mbedtls_ecp_keypair *) prv,
+    return mbedtls_ecp_check_pub_priv((const mbedtls_ecp_keypair *) pub->pk_ctx,
+                                      (const mbedtls_ecp_keypair *) prv->pk_ctx,
                                       f_rng, p_rng);
 #else
     return MBEDTLS_ERR_PK_FEATURE_UNAVAILABLE;
@@ -1201,11 +1204,12 @@
     mbedtls_free(ctx);
 }
 
-static void eckey_debug(const void *ctx, mbedtls_pk_debug_item *items)
+static void eckey_debug(mbedtls_pk_context *pk, mbedtls_pk_debug_item *items)
 {
+    mbedtls_ecp_keypair *ecp = (mbedtls_ecp_keypair *) pk->pk_ctx;
     items->type = MBEDTLS_PK_DEBUG_ECP;
     items->name = "eckey.Q";
-    items->value = &(((mbedtls_ecp_keypair *) ctx)->Q);
+    items->value = &(ecp->Q);
 }
 
 const mbedtls_pk_info_t mbedtls_eckey_info = {
@@ -1279,7 +1283,7 @@
 }
 
 #if defined(MBEDTLS_ECDSA_C) && defined(MBEDTLS_ECP_RESTARTABLE)
-static int ecdsa_verify_rs_wrap(void *ctx, mbedtls_md_type_t md_alg,
+static int ecdsa_verify_rs_wrap(mbedtls_pk_context *pk, mbedtls_md_type_t md_alg,
                                 const unsigned char *hash, size_t hash_len,
                                 const unsigned char *sig, size_t sig_len,
                                 void *rs_ctx)
@@ -1288,7 +1292,7 @@
     ((void) md_alg);
 
     ret = mbedtls_ecdsa_read_signature_restartable(
-        (mbedtls_ecdsa_context *) ctx,
+        (mbedtls_ecdsa_context *) pk->pk_ctx,
         hash, hash_len, sig, sig_len,
         (mbedtls_ecdsa_restart_ctx *) rs_ctx);
 
@@ -1299,14 +1303,14 @@
     return ret;
 }
 
-static int ecdsa_sign_rs_wrap(void *ctx, mbedtls_md_type_t md_alg,
+static int ecdsa_sign_rs_wrap(mbedtls_pk_context *pk, mbedtls_md_type_t md_alg,
                               const unsigned char *hash, size_t hash_len,
                               unsigned char *sig, size_t sig_size, size_t *sig_len,
                               int (*f_rng)(void *, unsigned char *, size_t), void *p_rng,
                               void *rs_ctx)
 {
     return mbedtls_ecdsa_write_signature_restartable(
-        (mbedtls_ecdsa_context *) ctx,
+        (mbedtls_ecdsa_context *) pk->pk_ctx,
         md_alg, hash, hash_len, sig, sig_size, sig_len, f_rng, p_rng,
         (mbedtls_ecdsa_restart_ctx *) rs_ctx);
 
@@ -1372,19 +1376,19 @@
     return type == MBEDTLS_PK_RSA;
 }
 
-static size_t rsa_alt_get_bitlen(const void *ctx)
+static size_t rsa_alt_get_bitlen(mbedtls_pk_context *pk)
 {
-    const mbedtls_rsa_alt_context *rsa_alt = (const mbedtls_rsa_alt_context *) ctx;
+    const mbedtls_rsa_alt_context *rsa_alt = pk->pk_ctx;
 
     return 8 * rsa_alt->key_len_func(rsa_alt->key);
 }
 
-static int rsa_alt_sign_wrap(void *ctx, mbedtls_md_type_t md_alg,
+static int rsa_alt_sign_wrap(mbedtls_pk_context *pk, mbedtls_md_type_t md_alg,
                              const unsigned char *hash, size_t hash_len,
                              unsigned char *sig, size_t sig_size, size_t *sig_len,
                              int (*f_rng)(void *, unsigned char *, size_t), void *p_rng)
 {
-    mbedtls_rsa_alt_context *rsa_alt = (mbedtls_rsa_alt_context *) ctx;
+    mbedtls_rsa_alt_context *rsa_alt = pk->pk_ctx;
 
     if (UINT_MAX < hash_len) {
         return MBEDTLS_ERR_PK_BAD_INPUT_DATA;
@@ -1402,12 +1406,12 @@
                               md_alg, (unsigned int) hash_len, hash, sig);
 }
 
-static int rsa_alt_decrypt_wrap(void *ctx,
+static int rsa_alt_decrypt_wrap(mbedtls_pk_context *pk,
                                 const unsigned char *input, size_t ilen,
                                 unsigned char *output, size_t *olen, size_t osize,
                                 int (*f_rng)(void *, unsigned char *, size_t), void *p_rng)
 {
-    mbedtls_rsa_alt_context *rsa_alt = (mbedtls_rsa_alt_context *) ctx;
+    mbedtls_rsa_alt_context *rsa_alt = pk->pk_ctx;
 
     ((void) f_rng);
     ((void) p_rng);
@@ -1421,7 +1425,7 @@
 }
 
 #if defined(MBEDTLS_RSA_C)
-static int rsa_alt_check_pair(const void *pub, const void *prv,
+static int rsa_alt_check_pair(mbedtls_pk_context *pub, mbedtls_pk_context *prv,
                               int (*f_rng)(void *, unsigned char *, size_t),
                               void *p_rng)
 {
@@ -1436,14 +1440,14 @@
 
     memset(hash, 0x2a, sizeof(hash));
 
-    if ((ret = rsa_alt_sign_wrap((void *) prv, MBEDTLS_MD_NONE,
+    if ((ret = rsa_alt_sign_wrap(prv, MBEDTLS_MD_NONE,
                                  hash, sizeof(hash),
                                  sig, sizeof(sig), &sig_len,
                                  f_rng, p_rng)) != 0) {
         return ret;
     }
 
-    if (rsa_verify_wrap((void *) pub, MBEDTLS_MD_NONE,
+    if (rsa_verify_wrap(pub, MBEDTLS_MD_NONE,
                         hash, sizeof(hash), sig, sig_len) != 0) {
         return MBEDTLS_ERR_RSA_KEY_CHECK_FAILED;
     }
@@ -1515,9 +1519,9 @@
     mbedtls_free(ctx);
 }
 
-static size_t pk_opaque_get_bitlen(const void *ctx)
+static size_t pk_opaque_get_bitlen(mbedtls_pk_context *pk)
 {
-    const mbedtls_svc_key_id_t *key = (const mbedtls_svc_key_id_t *) ctx;
+    const mbedtls_svc_key_id_t *key = pk->pk_ctx;
     size_t bits;
     psa_key_attributes_t attributes = PSA_KEY_ATTRIBUTES_INIT;
 
@@ -1542,13 +1546,13 @@
            type == MBEDTLS_PK_RSASSA_PSS;
 }
 
-static int pk_opaque_sign_wrap(void *ctx, mbedtls_md_type_t md_alg,
+static int pk_opaque_sign_wrap(mbedtls_pk_context *pk, mbedtls_md_type_t md_alg,
                                const unsigned char *hash, size_t hash_len,
                                unsigned char *sig, size_t sig_size, size_t *sig_len,
                                int (*f_rng)(void *, unsigned char *, size_t), void *p_rng)
 {
 #if !defined(MBEDTLS_PK_CAN_ECDSA_SIGN) && !defined(MBEDTLS_RSA_C)
-    ((void) ctx);
+    ((void) pk);
     ((void) md_alg);
     ((void) hash);
     ((void) hash_len);
@@ -1559,7 +1563,7 @@
     ((void) p_rng);
     return MBEDTLS_ERR_PK_FEATURE_UNAVAILABLE;
 #else /* !MBEDTLS_PK_CAN_ECDSA_SIGN && !MBEDTLS_RSA_C */
-    const mbedtls_svc_key_id_t *key = (const mbedtls_svc_key_id_t *) ctx;
+    const mbedtls_svc_key_id_t *key = pk->pk_ctx;
     psa_key_attributes_t attributes = PSA_KEY_ATTRIBUTES_INIT;
     psa_algorithm_t alg;
     psa_key_type_t type;
@@ -1641,12 +1645,12 @@
 };
 
 #if defined(PSA_WANT_KEY_TYPE_RSA_KEY_PAIR)
-static int pk_opaque_rsa_decrypt(void *ctx,
+static int pk_opaque_rsa_decrypt(mbedtls_pk_context *pk,
                                  const unsigned char *input, size_t ilen,
                                  unsigned char *output, size_t *olen, size_t osize,
                                  int (*f_rng)(void *, unsigned char *, size_t), void *p_rng)
 {
-    const mbedtls_svc_key_id_t *key = (const mbedtls_svc_key_id_t *) ctx;
+    const mbedtls_svc_key_id_t *key = pk->pk_ctx;
     psa_status_t status;
 
     /* PSA has its own RNG */