- RSASSA-PSS verification now properly handles salt lengths other than hlen (the hash output length)

diff --git a/library/rsa.c b/library/rsa.c
index ec44d84..0a6c490 100644
--- a/library/rsa.c
+++ b/library/rsa.c
@@ -580,7 +580,7 @@
     unsigned char salt[POLARSSL_MD_MAX_SIZE];
     const md_info_t *md_info;
     md_context_t md_ctx;
-    int i, hlen, msb, offset = 0;
+    int i, slen, hlen, msb, offset = 0;
 #else
     (void) f_rng;
     (void) p_rng;
@@ -733,6 +733,8 @@
                 return( POLARSSL_ERR_RSA_BAD_INPUT_DATA );
                 
             hlen = md_get_size( md_info );
+            slen = hlen;
+
             memset( sig, 0, olen );
             memset( &md_ctx, 0, sizeof( md_context_t ) );
 
@@ -740,9 +742,9 @@
 
             msb = mpi_msb( &ctx->N ) - 1;
 
-            // Generate salt of length hlen
+            // Generate salt of length slen
             //
-            for( i = 0; i < hlen; ++i )
+            for( i = 0; i < slen; ++i )
                 salt[i] = (unsigned char) f_rng( p_rng ); 
 
             // Note: EMSA-PSS encoding is over the length of N - 1 bits
@@ -750,15 +752,15 @@
             msb = mpi_msb( &ctx->N ) - 1;
             p += olen - hlen * 2 - 2;
             *p++ = 0x01;
-            memcpy( p, salt, hlen );
-            p += hlen;
+            memcpy( p, salt, slen );
+            p += slen;
 
             // Generate H = Hash( M' )
             //
             md_starts( &md_ctx );
             md_update( &md_ctx, p, 8 );
             md_update( &md_ctx, hash, hashlen );
-            md_update( &md_ctx, salt, hlen );
+            md_update( &md_ctx, salt, slen );
             md_finish( &md_ctx, p );
 
             // Compensate for boundary condition when applying mask
@@ -805,7 +807,7 @@
     unsigned char zeros[8];
     const md_info_t *md_info;
     md_context_t md_ctx;
-    int hlen, msb;
+    int slen, hlen, msb;
 #endif
     siglen = ctx->len;
 
@@ -935,6 +937,8 @@
                 return( POLARSSL_ERR_RSA_BAD_INPUT_DATA );
                 
             hlen = md_get_size( md_info );
+            slen = siglen - hlen - 1;
+
             memset( &md_ctx, 0, sizeof( md_context_t ) );
             memset( zeros, 0, 8 );
 
@@ -967,15 +971,17 @@
             if( *p++ != 0x01 )
                 return( POLARSSL_ERR_RSA_INVALID_PADDING );
 
+            slen -= p - buf;
+
             // Generate H = Hash( M' )
             //
             md_starts( &md_ctx );
             md_update( &md_ctx, zeros, 8 );
             md_update( &md_ctx, hash, hashlen );
-            md_update( &md_ctx, p, hlen );
+            md_update( &md_ctx, p, slen );
             md_finish( &md_ctx, p );
 
-            if( memcmp( p, p + hlen, hlen ) == 0 )
+            if( memcmp( p, p + slen, hlen ) == 0 )
                 return( 0 );
             else
                 return( POLARSSL_ERR_RSA_VERIFY_FAILED );
diff --git a/library/x509parse.c b/library/x509parse.c
index e330743..df671ef 100644
--- a/library/x509parse.c
+++ b/library/x509parse.c
@@ -1827,8 +1827,6 @@
 #endif
     end = p + keylen;
 
-    memset( rsa, 0, sizeof( rsa_context ) );
-
     /*
      *  RSAPrivateKey ::= SEQUENCE {
      *      version           Version,
@@ -1942,6 +1940,116 @@
 }
 
 /*
+ * Parse a public RSA key (PEM-wrapped or raw ASN.1) into rsa->N, rsa->E and rsa->len; returns 0 on success
+ */
+int x509parse_public_key( rsa_context *rsa, const unsigned char *key, int keylen )
+{
+    int ret, len;
+    unsigned char *p, *end;
+    x509_buf alg_oid; /* handed to x509_get_pubkey; not read again in this function */
+#if defined(POLARSSL_PEM_C)
+    pem_context pem;
+
+    pem_init( &pem );
+    ret = pem_read_buffer( &pem,
+            "-----BEGIN PUBLIC KEY-----",
+            "-----END PUBLIC KEY-----",
+            key, NULL, 0, &len ); /* NULL, 0: presumably "no password" - confirm against pem_read_buffer */
+
+    if( ret == 0 )
+    {
+        /*
+         * Was PEM encoded: continue with the buffer held in pem (pem.buf / pem.buflen) instead of key
+         */
+        keylen = pem.buflen;
+    }
+    else if( ret != POLARSSL_ERR_PEM_NO_HEADER_PRESENT )
+    {   /* any other PEM error is returned as-is; "no header" falls through and key is parsed unchanged */
+        pem_free( &pem );
+        return( ret );
+    }
+
+    p = ( ret == 0 ) ? pem.buf : (unsigned char *) key;
+#else
+    p = (unsigned char *) key;
+#endif
+    end = p + keylen; /* one past the last byte of the data being parsed */
+
+    /*
+     *  PublicKeyInfo ::= SEQUENCE {
+     *    algorithm       AlgorithmIdentifier,
+     *    PublicKey       BIT STRING
+     *  }
+     *
+     *  AlgorithmIdentifier ::= SEQUENCE {
+     *    algorithm       OBJECT IDENTIFIER,
+     *    parameters      ANY DEFINED BY algorithm OPTIONAL
+     *  }
+     *
+     *  RSAPublicKey ::= SEQUENCE {
+     *      modulus           INTEGER,  -- n
+     *      publicExponent    INTEGER   -- e
+     *  }
+     */
+
+    if( ( ret = asn1_get_tag( &p, end, &len,
+                    ASN1_CONSTRUCTED | ASN1_SEQUENCE ) ) != 0 )
+    {   /* asn1_get_tag did not accept the outer constructed SEQUENCE */
+#if defined(POLARSSL_PEM_C)
+        pem_free( &pem );
+#endif
+        rsa_free( rsa );
+        return( POLARSSL_ERR_X509_CERT_INVALID_FORMAT | ret );
+    }
+
+    if( ( ret = x509_get_pubkey( &p, end, &alg_oid, &rsa->N, &rsa->E ) ) != 0 )
+    {   /* x509_get_pubkey rejected the remaining data; N and E are passed as its outputs */
+#if defined(POLARSSL_PEM_C)
+        pem_free( &pem );
+#endif
+        rsa_free( rsa );
+        return( POLARSSL_ERR_X509_KEY_INVALID_FORMAT | ret );
+    }
+
+    if( ( ret = rsa_check_pubkey( rsa ) ) != 0 )
+    {   /* the parsed values failed rsa_check_pubkey; its error code is returned unchanged */
+#if defined(POLARSSL_PEM_C)
+        pem_free( &pem );
+#endif
+        rsa_free( rsa );
+        return( ret );
+    }
+
+    rsa->len = mpi_size( &rsa->N ); /* size of the modulus; presumably in bytes - confirm against mpi_size */
+
+#if defined(POLARSSL_PEM_C)
+    pem_free( &pem );
+#endif
+
+    return( 0 );
+}
+
+/*
+ * Load the file at path and parse it as a public RSA key into rsa; returns 1 if loading fails, else the parse result
+ */
+int x509parse_public_keyfile( rsa_context *rsa, const char *path )
+{
+    int ret;
+    size_t n; /* filled in by load_file and used as the input length below */
+    unsigned char *buf;
+
+    if ( load_file( path, &buf, &n ) )
+        return( 1 ); /* generic failure value; rsa has not been touched at this point */
+
+    ret = x509parse_public_key( rsa, buf, (int) n );
+
+    memset( buf, 0, n + 1 ); /* zero the file copy before freeing; n + 1 assumes load_file allocates one extra byte - confirm */
+    free( buf );
+
+    return( ret );
+}
+
+/*
  * Parse DHM parameters
  */
 int x509parse_dhm( dhm_context *dhm, const unsigned char *dhmin, int dhminlen )