diff --git a/include/polarssl/pk.h b/include/polarssl/pk.h
index a39fadf..da13136 100644
--- a/include/polarssl/pk.h
+++ b/include/polarssl/pk.h
@@ -148,31 +148,38 @@
 } pk_context;
 
 /**
+ * \brief           Return information associated with the given PK type
+ *
+ * \param type      PK type to search for.
+ *
+ * \return          The PK info associated with the type or NULL if not found.
+ */
+const pk_info_t *pk_info_from_type( pk_type_t pk_type );
+
+/**
  * \brief           Initialize a pk_context (as NONE)
  */
 void pk_init( pk_context *ctx );
 
 /**
+ * \brief           Initialize a PK context with the information given
+ *                  and allocates the type-specific PK subcontext.
+ *
+ * \param ctx       Context to initialize. Must be empty (type NONE).
+ * \param info      Information to use
+ *
+ * \return          0 on success,
+ *                  POLARSSL_ERR_PK_BAD_INPUT_DATA on invalid input,
+ *                  POLARSSL_ERR_PK_MALLOC_FAILED on allocation failure.
+ */
+int pk_init_ctx( pk_context *ctx, const pk_info_t *info );
+
+/**
  * \brief           Free a pk_context
  */
 void pk_free( pk_context *ctx );
 
 /**
- * \brief           Set a pk_context to a given type
- *
- * \param ctx       Context to initialize
- * \param type      Type of key
- *
- * \note            Once the type of a key has been set, it cannot be reset.
- *                  If you want to do so, you need to use pk_free() first.
- *
- * \return          O on success,
- *                  POLARSSL_ERR_PK_MALLOC_FAILED on memory allocation fail,
- *                  POLARSSL_ERR_PK_TYPE_MISMATCH on attempts to reset type.
- */
-int pk_set_type( pk_context *ctx, pk_type_t type );
-
-/**
  * \brief           Get the size in bits of the underlying key
  *
  * \param ctx       Context to use
diff --git a/library/pk.c b/library/pk.c
index 61544eb..4c16de8 100644
--- a/library/pk.c
+++ b/library/pk.c
@@ -67,7 +67,7 @@
 /*
  * Get pk_info structure from type
  */
-static const pk_info_t * pk_info_from_type( pk_type_t pk_type )
+const pk_info_t * pk_info_from_type( pk_type_t pk_type )
 {
     switch( pk_type ) {
 #if defined(POLARSSL_RSA_C)
@@ -90,21 +90,11 @@
 }
 
 /*
- * Set a pk_context to a given type
+ * Initialise context
  */
-int pk_set_type( pk_context *ctx, pk_type_t type )
+int pk_init_ctx( pk_context *ctx, const pk_info_t *info )
 {
-    const pk_info_t *info;
-
-    if( ctx->pk_info != NULL )
-    {
-        if( ctx->pk_info->type == type )
-            return 0;
-
-        return( POLARSSL_ERR_PK_TYPE_MISMATCH );
-    }
-
-    if( ( info = pk_info_from_type( type ) ) == NULL )
+    if( ctx == NULL || info == NULL || ctx->pk_info != NULL )
         return( POLARSSL_ERR_PK_BAD_INPUT_DATA );
 
     if( ( ctx->pk_ctx = info->ctx_alloc_func() ) == NULL )
diff --git a/library/x509parse.c b/library/x509parse.c
index e080174..4da4e75 100644
--- a/library/x509parse.c
+++ b/library/x509parse.c
@@ -570,6 +570,7 @@
     size_t len;
     x509_buf alg_params;
     pk_type_t pk_alg = POLARSSL_PK_NONE;
+    const pk_info_t *pk_info;
 
     if( ( ret = asn1_get_tag( p, end, &len,
                     ASN1_CONSTRUCTED | ASN1_SEQUENCE ) ) != 0 )
@@ -589,7 +590,10 @@
         return( POLARSSL_ERR_X509_CERT_INVALID_PUBKEY +
                 POLARSSL_ERR_ASN1_LENGTH_MISMATCH );
 
-    if( ( ret = pk_set_type( pk, pk_alg ) ) != 0 )
+    if( ( pk_info = pk_info_from_type( pk_alg ) ) == NULL )
+        return( POLARSSL_ERR_X509_UNKNOWN_PK_ALG );
+
+    if( ( ret = pk_init_ctx( pk, pk_info ) ) != 0 )
         return( ret );
 
 #if defined(POLARSSL_RSA_C)
@@ -2142,10 +2146,12 @@
     pk_context pk;
 
     pk_init( &pk );
-    pk_set_type( &pk, POLARSSL_PK_RSA );
 
     ret = x509parse_keyfile( &pk, path, pwd );
 
+    if( ret == 0 && ! pk_can_do( &pk, POLARSSL_PK_RSA ) )
+        ret = POLARSSL_ERR_PK_TYPE_MISMATCH;
+
     if( ret == 0 )
         rsa_copy( rsa, pk_rsa( pk ) );
     else
@@ -2165,10 +2171,12 @@
     pk_context pk;
 
     pk_init( &pk );
-    pk_set_type( &pk, POLARSSL_PK_RSA );
 
     ret = x509parse_public_keyfile( &pk, path );
 
+    if( ret == 0 && ! pk_can_do( &pk, POLARSSL_PK_RSA ) )
+        ret = POLARSSL_ERR_PK_TYPE_MISMATCH;
+
     if( ret == 0 )
         rsa_copy( rsa, pk_rsa( pk ) );
     else
@@ -2380,6 +2388,7 @@
     unsigned char *p = (unsigned char *) key;
     unsigned char *end = p + keylen;
     pk_type_t pk_alg = POLARSSL_PK_NONE;
+    const pk_info_t *pk_info;
 
     /*
      * This function parses the PrivatKeyInfo object (PKCS#8 v1.2 = RFC 5208)
@@ -2421,7 +2430,10 @@
         return( POLARSSL_ERR_X509_KEY_INVALID_FORMAT +
                 POLARSSL_ERR_ASN1_OUT_OF_DATA );
 
-    if( ( ret = pk_set_type( pk, pk_alg ) ) != 0 )
+    if( ( pk_info = pk_info_from_type( pk_alg ) ) == NULL )
+        return( POLARSSL_ERR_X509_UNKNOWN_PK_ALG );
+
+    if( ( ret = pk_init_ctx( pk, pk_info ) ) != 0 )
         return( ret );
 
 #if defined(POLARSSL_RSA_C)
@@ -2568,6 +2580,7 @@
                    const unsigned char *pwd, size_t pwdlen )
 {
     int ret;
+    const pk_info_t *pk_info;
 
 #if defined(POLARSSL_PEM_C)
     size_t len;
@@ -2582,7 +2595,10 @@
                            key, pwd, pwdlen, &len );
     if( ret == 0 )
     {
-        if( ( ret = pk_set_type( pk, POLARSSL_PK_RSA             ) ) != 0 ||
+        if( ( pk_info = pk_info_from_type( POLARSSL_PK_RSA ) ) == NULL )
+            return( POLARSSL_ERR_X509_UNKNOWN_PK_ALG );
+
+        if( ( ret = pk_init_ctx( pk, pk_info                     ) ) != 0 ||
             ( ret = x509parse_key_pkcs1_der( pk_rsa( *pk ),
                                              pem.buf, pem.buflen ) ) != 0 )
         {
@@ -2607,7 +2623,10 @@
                            key, pwd, pwdlen, &len );
     if( ret == 0 )
     {
-        if( ( ret = pk_set_type( pk, POLARSSL_PK_ECKEY          ) ) != 0 ||
+        if( ( pk_info = pk_info_from_type( POLARSSL_PK_ECKEY ) ) == NULL )
+            return( POLARSSL_ERR_X509_UNKNOWN_PK_ALG );
+
+        if( ( ret = pk_init_ctx( pk, pk_info                    ) ) != 0 ||
             ( ret = x509parse_key_sec1_der( pk_ec( *pk ),
                                             pem.buf, pem.buflen ) ) != 0 )
         {
@@ -2692,7 +2711,10 @@
     pk_free( pk );
 
 #if defined(POLARSSL_RSA_C)
-    if( ( ret = pk_set_type( pk, POLARSSL_PK_RSA                    ) ) == 0 &&
+    if( ( pk_info = pk_info_from_type( POLARSSL_PK_RSA ) ) == NULL )
+        return( POLARSSL_ERR_X509_UNKNOWN_PK_ALG );
+
+    if( ( ret = pk_init_ctx( pk, pk_info                            ) ) != 0 ||
         ( ret = x509parse_key_pkcs1_der( pk_rsa( *pk ), key, keylen ) ) == 0 )
     {
         return( 0 );
@@ -2702,7 +2724,10 @@
 #endif /* POLARSSL_RSA_C */
 
 #if defined(POLARSSL_ECP_C)
-    if( ( ret = pk_set_type( pk, POLARSSL_PK_ECKEY                ) ) == 0 &&
+    if( ( pk_info = pk_info_from_type( POLARSSL_PK_ECKEY ) ) == NULL )
+        return( POLARSSL_ERR_X509_UNKNOWN_PK_ALG );
+
+    if( ( ret = pk_init_ctx( pk, pk_info                          ) ) != 0 ||
         ( ret = x509parse_key_sec1_der( pk_ec( *pk ), key, keylen ) ) == 0 )
     {
         return( 0 );
@@ -2769,10 +2794,12 @@
     pk_context pk;
 
     pk_init( &pk );
-    pk_set_type( &pk, POLARSSL_PK_RSA );
 
     ret = x509parse_key( &pk, key, keylen, pwd, pwdlen );
 
+    if( ret == 0 && ! pk_can_do( &pk, POLARSSL_PK_RSA ) )
+        ret = POLARSSL_ERR_PK_TYPE_MISMATCH;
+
     if( ret == 0 )
         rsa_copy( rsa, pk_rsa( pk ) );
     else
@@ -2793,10 +2820,12 @@
     pk_context pk;
 
     pk_init( &pk );
-    pk_set_type( &pk, POLARSSL_PK_RSA );
 
     ret = x509parse_public_key( &pk, key, keylen );
 
+    if( ret == 0 && ! pk_can_do( &pk, POLARSSL_PK_RSA ) )
+        ret = POLARSSL_ERR_PK_TYPE_MISMATCH;
+
     if( ret == 0 )
         rsa_copy( rsa, pk_rsa( pk ) );
     else
