Split up the RSA PKCS#1 encrypt, decrypt, sign and verify functions

Split rsa_pkcs1_encrypt() into rsa_rsaes_oaep_encrypt() and
rsa_rsaes_pkcs1_v15_encrypt()
Split rsa_pkcs1_decrypt() into rsa_rsaes_oaep_decrypt() and
rsa_rsaes_pkcs1_v15_decrypt()
Split rsa_pkcs1_sign() into rsa_rsassa_pss_sign() and
rsa_rsassa_pkcs1_v15_sign()
Split rsa_pkcs1_verify() into rsa_rsassa_pss_verify() and
rsa_rsassa_pkcs1_v15_verify()

The original functions remain as generic wrappers that dispatch to these functions based on ctx->padding.
diff --git a/library/rsa.c b/library/rsa.c
index ee6ca01..d41928f 100644
--- a/library/rsa.c
+++ b/library/rsa.c
@@ -361,6 +361,138 @@
 }
 #endif
 
+#if defined(POLARSSL_PKCS1_V21)
+/*
+ * Implementation of the PKCS#1 v2.1 RSAES-OAEP-ENCRYPT function (empty label)
+ */
+int rsa_rsaes_oaep_encrypt( rsa_context *ctx,
+                            int (*f_rng)(void *, unsigned char *, size_t),
+                            void *p_rng,
+                            int mode, size_t ilen,
+                            const unsigned char *input,
+                            unsigned char *output )
+{
+    size_t olen;
+    int ret;
+    unsigned char *p = output;
+    unsigned int hlen;
+    const md_info_t *md_info;
+    md_context_t md_ctx;
+
+    if( ctx->padding != RSA_PKCS_V21 || f_rng == NULL )
+        return( POLARSSL_ERR_RSA_BAD_INPUT_DATA );
+
+    md_info = md_info_from_type( ctx->hash_id );
+    if( md_info == NULL )
+        return( POLARSSL_ERR_RSA_BAD_INPUT_DATA );
+
+    olen = ctx->len;
+    hlen = md_get_size( md_info );
+
+    // EM = 0x00 || maskedSeed (hlen) || maskedDB (olen - hlen - 1), so the
+    // key must hold 2 * hlen + 2 bytes of overhead. Test ilen against the
+    // remaining room instead of computing ilen + 2 * hlen + 2, which can wrap.
+    if( olen < 2 * hlen + 2 || ilen > olen - 2 * hlen - 2 )
+        return( POLARSSL_ERR_RSA_BAD_INPUT_DATA );
+
+    memset( output, 0, olen );
+
+    *p++ = 0;
+
+    // Generate a random octet string seed
+    if( ( ret = f_rng( p_rng, p, hlen ) ) != 0 )
+        return( POLARSSL_ERR_RSA_RNG_FAILED + ret );
+
+    p += hlen;
+
+    // Construct DB = lHash || PS (zeros, already set by memset) || 0x01 || M
+    md( md_info, p, 0, p );
+    p += hlen;
+    p += olen - 2 * hlen - 2 - ilen;
+    *p++ = 1;
+    memcpy( p, input, ilen );
+
+    if( ( ret = md_init_ctx( &md_ctx, md_info ) ) != 0 )
+        return( ret );
+
+    // maskedDB: Apply dbMask to DB
+    mgf_mask( output + hlen + 1, olen - hlen - 1, output + 1, hlen,
+               &md_ctx );
+
+    // maskedSeed: Apply seedMask to seed
+    mgf_mask( output + 1, hlen, output + hlen + 1, olen - hlen - 1,
+               &md_ctx );
+
+    md_free_ctx( &md_ctx );
+
+    // mode selects which key half performs the raw RSA operation
+    return( ( mode == RSA_PUBLIC )
+            ? rsa_public(  ctx, output, output )
+            : rsa_private( ctx, output, output ) );
+}
+#endif /* POLARSSL_PKCS1_V21 */
+
+/*
+ * Implementation of the PKCS#1 v2.1 RSAES-PKCS1-V1_5-ENCRYPT function
+ */
+int rsa_rsaes_pkcs1_v15_encrypt( rsa_context *ctx,
+                                 int (*f_rng)(void *, unsigned char *, size_t),
+                                 void *p_rng,
+                                 int mode, size_t ilen,
+                                 const unsigned char *input,
+                                 unsigned char *output )
+{
+    size_t nb_pad, olen;
+    int ret;
+    unsigned char *p = output;
+
+    if( ctx->padding != RSA_PKCS_V15 || f_rng == NULL )
+        return( POLARSSL_ERR_RSA_BAD_INPUT_DATA );
+
+    olen = ctx->len;
+
+    // 0x00 || type || >= 8 padding bytes || 0x00 needs 11 bytes; test ilen
+    // against the remaining room so ilen + 11 cannot wrap around
+    if( olen < 11 || ilen > olen - 11 )
+        return( POLARSSL_ERR_RSA_BAD_INPUT_DATA );
+
+    nb_pad = olen - 3 - ilen;
+    *p++ = 0;
+    if( mode == RSA_PUBLIC )
+    {
+        // RSA_CRYPT padding: nb_pad random non-zero bytes
+        *p++ = RSA_CRYPT;
+        while( nb_pad-- > 0 )
+        {
+            int rng_dl = 100;
+
+            do {
+                ret = f_rng( p_rng, p, 1 );
+            } while( *p == 0 && --rng_dl && ret == 0 );
+
+            // RNG failed: error return, or 100 zero bytes in a row
+            if( rng_dl == 0 || ret != 0 )
+                return( POLARSSL_ERR_RSA_RNG_FAILED + ret );
+
+            p++;
+        }
+    }
+    else
+    {
+        // RSA_SIGN padding: nb_pad bytes of 0xFF
+        *p++ = RSA_SIGN;
+        while( nb_pad-- > 0 )
+            *p++ = 0xFF;
+    }
+
+    *p++ = 0;
+    memcpy( p, input, ilen );
+
+    return( ( mode == RSA_PUBLIC )
+            ? rsa_public(  ctx, output, output )
+            : rsa_private( ctx, output, output ) );
+}
+
 /*
  * Add the message padding, then do an RSA operation
  */
@@ -371,139 +503,44 @@
                        const unsigned char *input,
                        unsigned char *output )
 {
-    size_t nb_pad, olen;
-    int ret;
-    unsigned char *p = output;
-#if defined(POLARSSL_PKCS1_V21)
-    unsigned int hlen;
-    const md_info_t *md_info;
-    md_context_t md_ctx;
-#endif
-
-    olen = ctx->len;
-
-    if( f_rng == NULL )
-        return( POLARSSL_ERR_RSA_BAD_INPUT_DATA );
-
     switch( ctx->padding )
     {
         case RSA_PKCS_V15:
+            return rsa_rsaes_pkcs1_v15_encrypt( ctx, f_rng, p_rng, mode, ilen,
+                                                input, output );
 
-            if( olen < ilen + 11 )
-                return( POLARSSL_ERR_RSA_BAD_INPUT_DATA );
-
-            nb_pad = olen - 3 - ilen;
-
-            *p++ = 0;
-            if( mode == RSA_PUBLIC )
-            {
-                *p++ = RSA_CRYPT;
-
-                while( nb_pad-- > 0 )
-                {
-                    int rng_dl = 100;
-
-                    do {
-                        ret = f_rng( p_rng, p, 1 );
-                    } while( *p == 0 && --rng_dl && ret == 0 );
-
-                    // Check if RNG failed to generate data
-                    //
-                    if( rng_dl == 0 || ret != 0)
-                        return POLARSSL_ERR_RSA_RNG_FAILED + ret;
-
-                    p++;
-                }
-            }
-            else
-            {
-                *p++ = RSA_SIGN;
-
-                while( nb_pad-- > 0 )
-                    *p++ = 0xFF;
-            }
-
-            *p++ = 0;
-            memcpy( p, input, ilen );
-            break;
-        
 #if defined(POLARSSL_PKCS1_V21)
         case RSA_PKCS_V21:
-
-            md_info = md_info_from_type( ctx->hash_id );
-            if( md_info == NULL )
-                return( POLARSSL_ERR_RSA_BAD_INPUT_DATA );
-
-            hlen = md_get_size( md_info );
-
-            if( olen < ilen + 2 * hlen + 2 || f_rng == NULL )
-                return( POLARSSL_ERR_RSA_BAD_INPUT_DATA );
-
-            memset( output, 0, olen );
-
-            *p++ = 0;
-
-            // Generate a random octet string seed
-            //
-            if( ( ret = f_rng( p_rng, p, hlen ) ) != 0 )
-                return( POLARSSL_ERR_RSA_RNG_FAILED + ret );
-
-            p += hlen;
-
-            // Construct DB
-            //
-            md( md_info, p, 0, p );
-            p += hlen;
-            p += olen - 2 * hlen - 2 - ilen;
-            *p++ = 1;
-            memcpy( p, input, ilen ); 
-
-            md_init_ctx( &md_ctx, md_info );
-
-            // maskedDB: Apply dbMask to DB
-            //
-            mgf_mask( output + hlen + 1, olen - hlen - 1, output + 1, hlen,  
-                       &md_ctx );
-
-            // maskedSeed: Apply seedMask to seed
-            //
-            mgf_mask( output + 1, hlen, output + hlen + 1, olen - hlen - 1,  
-                       &md_ctx );
-
-            md_free_ctx( &md_ctx );
-            break;
+            return rsa_rsaes_oaep_encrypt( ctx, f_rng, p_rng, mode, ilen,
+                                           input, output );
 #endif
 
         default:
-
             return( POLARSSL_ERR_RSA_INVALID_PADDING );
     }
-
-    return( ( mode == RSA_PUBLIC )
-            ? rsa_public(  ctx, output, output )
-            : rsa_private( ctx, output, output ) );
 }
 
+#if defined(POLARSSL_PKCS1_V21)
 /*
- * Do an RSA operation, then remove the message padding
+ * Implementation of the PKCS#1 v2.1 RSAES-OAEP-DECRYPT function
  */
-int rsa_pkcs1_decrypt( rsa_context *ctx,
-                       int mode, size_t *olen,
-                       const unsigned char *input,
-                       unsigned char *output,
-                       size_t output_max_len)
+int rsa_rsaes_oaep_decrypt( rsa_context *ctx,
+                            int mode, size_t *olen,
+                            const unsigned char *input,
+                            unsigned char *output,
+                            size_t output_max_len )
 {
     int ret;
     size_t ilen;
     unsigned char *p;
-    unsigned char bt;
     unsigned char buf[POLARSSL_MPI_MAX_SIZE];
-#if defined(POLARSSL_PKCS1_V21)
     unsigned char lhash[POLARSSL_MD_MAX_SIZE];
     unsigned int hlen;
     const md_info_t *md_info;
     md_context_t md_ctx;
-#endif
+
+    if( ctx->padding != RSA_PKCS_V21 )
+        return( POLARSSL_ERR_RSA_BAD_INPUT_DATA );
 
     ilen = ctx->len;
 
@@ -519,96 +556,121 @@
 
     p = buf;
 
-    switch( ctx->padding )
+    if( *p++ != 0 )
+        return( POLARSSL_ERR_RSA_INVALID_PADDING );
+
+    md_info = md_info_from_type( ctx->hash_id );
+    if( md_info == NULL )
+        return( POLARSSL_ERR_RSA_BAD_INPUT_DATA );
+
+    hlen = md_get_size( md_info );
+
+    // maskedSeed is hlen bytes and DB holds at least lHash || 0x01, so the
+    // key needs 2 * hlen + 2 bytes; otherwise ilen - hlen - 1 below can wrap
+    if( ilen < 2 * hlen + 2 )
+        return( POLARSSL_ERR_RSA_BAD_INPUT_DATA );
+
+    if( ( ret = md_init_ctx( &md_ctx, md_info ) ) != 0 )
+        return( ret );
+
+    // Generate lHash, the hash of the empty label
+    md( md_info, lhash, 0, lhash );
+
+    // seed: Apply seedMask to maskedSeed
+    mgf_mask( buf + 1, hlen, buf + hlen + 1, ilen - hlen - 1,
+               &md_ctx );
+
+    // DB: Apply dbMask to maskedDB
+    mgf_mask( buf + hlen + 1, ilen - hlen - 1, buf + 1, hlen,
+               &md_ctx );
+
+    p += hlen;
+    md_free_ctx( &md_ctx );
+
+    // Check validity: DB must start with lHash
+    if( memcmp( lhash, p, hlen ) != 0 )
+        return( POLARSSL_ERR_RSA_INVALID_PADDING );
+
+    p += hlen;
+
+    // Skip PS; test the bound first so buf[ilen] is never read
+    while( p < buf + ilen && *p == 0 )
+        p++;
+
+    if( p == buf + ilen || *p++ != 0x01 )
+        return( POLARSSL_ERR_RSA_INVALID_PADDING );
+
+    if( ilen - (p - buf) > output_max_len )
+        return( POLARSSL_ERR_RSA_OUTPUT_TOO_LARGE );
+
+    *olen = ilen - (p - buf);
+    memcpy( output, p, *olen );
+
+    return( 0 );
+}
+#endif /* POLARSSL_PKCS1_V21 */
+
+/*
+ * Implementation of the PKCS#1 v2.1 RSAES-PKCS1-V1_5-DECRYPT function
+ */
+int rsa_rsaes_pkcs1_v15_decrypt( rsa_context *ctx,
+                                 int mode, size_t *olen,
+                                 const unsigned char *input,
+                                 unsigned char *output,
+                                 size_t output_max_len)
+{
+    int ret;
+    size_t ilen;
+    unsigned char *p;
+    unsigned char bt;
+    unsigned char buf[POLARSSL_MPI_MAX_SIZE];
+
+    if( ctx->padding != RSA_PKCS_V15 )
+        return( POLARSSL_ERR_RSA_BAD_INPUT_DATA );
+
+    ilen = ctx->len;
+
+    if( ilen < 16 || ilen > sizeof( buf ) )
+        return( POLARSSL_ERR_RSA_BAD_INPUT_DATA );
+
+    ret = ( mode == RSA_PUBLIC )
+          ? rsa_public(  ctx, input, buf )
+          : rsa_private( ctx, input, buf );
+
+    if( ret != 0 )
+        return( ret );
+
+    p = buf;
+    // Layout 0x00 || bt || PS || 0x00 || M; PS >= 8 bytes => separator index >= 10
+    if( *p++ != 0 )
+        return( POLARSSL_ERR_RSA_INVALID_PADDING );
+
+    bt = *p++;
+    if( ( bt != RSA_CRYPT && mode == RSA_PRIVATE ) ||
+        ( bt != RSA_SIGN && mode == RSA_PUBLIC ) )
     {
-        case RSA_PKCS_V15:
+        return( POLARSSL_ERR_RSA_INVALID_PADDING );
+    }
 
-            if( *p++ != 0 )
-                return( POLARSSL_ERR_RSA_INVALID_PADDING );
-            
-            bt = *p++;
-            if( ( bt != RSA_CRYPT && mode == RSA_PRIVATE ) ||
-                ( bt != RSA_SIGN && mode == RSA_PUBLIC ) )
-            {
-                return( POLARSSL_ERR_RSA_INVALID_PADDING );
-            }
+    if( bt == RSA_CRYPT )
+    {
+        while( *p != 0 && p < buf + ilen - 1 )
+            p++;
 
-            if( bt == RSA_CRYPT )
-            {
-                while( *p != 0 && p < buf + ilen - 1 )
-                    p++;
-
-                if( *p != 0 || p >= buf + ilen - 1 )
-                    return( POLARSSL_ERR_RSA_INVALID_PADDING );
-
-                p++;
-            }
-            else
-            {
-                while( *p == 0xFF && p < buf + ilen - 1 )
-                    p++;
-
-                if( *p != 0 || p >= buf + ilen - 1 )
-                    return( POLARSSL_ERR_RSA_INVALID_PADDING );
-
-                p++;
-            }
-
-            break;
-
-#if defined(POLARSSL_PKCS1_V21)
-        case RSA_PKCS_V21:
-            
-            if( *p++ != 0 )
-                return( POLARSSL_ERR_RSA_INVALID_PADDING );
-
-            md_info = md_info_from_type( ctx->hash_id );
-            if( md_info == NULL )
-                return( POLARSSL_ERR_RSA_BAD_INPUT_DATA );
-                
-            hlen = md_get_size( md_info );
-
-            md_init_ctx( &md_ctx, md_info );
-            
-            // Generate lHash
-            //
-            md( md_info, lhash, 0, lhash );
-
-            // seed: Apply seedMask to maskedSeed
-            //
-            mgf_mask( buf + 1, hlen, buf + hlen + 1, ilen - hlen - 1,
-                       &md_ctx );
-
-            // DB: Apply dbMask to maskedDB
-            //
-            mgf_mask( buf + hlen + 1, ilen - hlen - 1, buf + 1, hlen,  
-                       &md_ctx );
-
-            p += hlen;
-            md_free_ctx( &md_ctx );
-
-            // Check validity
-            //
-            if( memcmp( lhash, p, hlen ) != 0 )
-                return( POLARSSL_ERR_RSA_INVALID_PADDING );
-
-            p += hlen;
-
-            while( *p == 0 && p < buf + ilen )
-                p++;
-
-            if( p == buf + ilen )
-                return( POLARSSL_ERR_RSA_INVALID_PADDING );
-
-            if( *p++ != 0x01 )
-                return( POLARSSL_ERR_RSA_INVALID_PADDING );
-
-            break;
-#endif
-
-        default:
-
+        if( *p != 0 || p >= buf + ilen - 1 || p < buf + 10 )
             return( POLARSSL_ERR_RSA_INVALID_PADDING );
+
+        p++;
+    }
+    else
+    {
+        while( *p == 0xFF && p < buf + ilen - 1 )
+            p++;
+        // Separator must follow the 0xFF run at index >= 10 (PS >= 8 bytes)
+        if( *p != 0 || p >= buf + ilen - 1 || p < buf + 10 )
+            return( POLARSSL_ERR_RSA_INVALID_PADDING );
+
+        p++;
     }
 
     if (ilen - (p - buf) > output_max_len)
@@ -621,6 +683,273 @@
 }
 
 /*
+ * Do an RSA operation, then remove the message padding (dispatch on ctx->padding)
+ */
+int rsa_pkcs1_decrypt( rsa_context *ctx,
+                       int mode, size_t *olen,
+                       const unsigned char *input,
+                       unsigned char *output,
+                       size_t output_max_len)
+{
+    switch( ctx->padding )
+    {
+        case RSA_PKCS_V15:
+            return rsa_rsaes_pkcs1_v15_decrypt( ctx, mode, olen, input, output,
+                                                output_max_len );
+
+#if defined(POLARSSL_PKCS1_V21)
+        case RSA_PKCS_V21:
+            return rsa_rsaes_oaep_decrypt( ctx, mode, olen, input, output,
+                                           output_max_len );
+#endif
+
+        default:
+            return( POLARSSL_ERR_RSA_INVALID_PADDING );
+    }
+}
+
+#if defined(POLARSSL_PKCS1_V21)
+/*
+ * Implementation of the PKCS#1 v2.1 RSASSA-PSS-SIGN function, using a salt
+ * as long as the ctx->hash_id digest; mode selects the RSA key operation.
+ */
+int rsa_rsassa_pss_sign( rsa_context *ctx,
+                         int (*f_rng)(void *, unsigned char *, size_t),
+                         void *p_rng,
+                         int mode,
+                         int hash_id,
+                         unsigned int hashlen,
+                         const unsigned char *hash,
+                         unsigned char *sig )
+{
+    size_t olen;
+    unsigned char *p = sig;
+    unsigned char salt[POLARSSL_MD_MAX_SIZE];
+    unsigned int slen, hlen, offset = 0;
+    int ret;
+    size_t msb;
+    const md_info_t *md_info;
+    md_context_t md_ctx;
+
+    if( ctx->padding != RSA_PKCS_V21 || f_rng == NULL )
+        return( POLARSSL_ERR_RSA_BAD_INPUT_DATA );
+
+    olen = ctx->len;
+
+    // The hashlen argument is ignored: it is derived from hash_id
+    switch( hash_id )
+    {
+        case SIG_RSA_MD2:
+        case SIG_RSA_MD4:
+        case SIG_RSA_MD5:
+            hashlen = 16;
+            break;
+
+        case SIG_RSA_SHA1:
+            hashlen = 20;
+            break;
+
+        case SIG_RSA_SHA224:
+            hashlen = 28;
+            break;
+
+        case SIG_RSA_SHA256:
+            hashlen = 32;
+            break;
+
+        case SIG_RSA_SHA384:
+            hashlen = 48;
+            break;
+
+        case SIG_RSA_SHA512:
+            hashlen = 64;
+            break;
+
+        default:
+            return( POLARSSL_ERR_RSA_BAD_INPUT_DATA );
+    }
+
+    md_info = md_info_from_type( ctx->hash_id );
+    if( md_info == NULL )
+        return( POLARSSL_ERR_RSA_BAD_INPUT_DATA );
+
+    hlen = md_get_size( md_info );
+    slen = hlen;
+
+    if( olen < hlen + slen + 2 )
+        return( POLARSSL_ERR_RSA_BAD_INPUT_DATA );
+
+    memset( sig, 0, olen );
+
+    // Generate salt of length slen
+    if( ( ret = f_rng( p_rng, salt, slen ) ) != 0 )
+        return( POLARSSL_ERR_RSA_RNG_FAILED + ret );
+
+    // Note: EMSA-PSS encoding is over the length of N - 1 bits
+    msb = mpi_msb( &ctx->N ) - 1;
+
+    // DB ends in 0x01 || salt, placed directly in front of where H goes
+    p += olen - hlen * 2 - 2;
+    *p++ = 0x01;
+    memcpy( p, salt, slen );
+    p += slen;
+
+    if( ( ret = md_init_ctx( &md_ctx, md_info ) ) != 0 )
+        return( ret );
+
+    // Generate H = Hash( M' ), M' = 8 zero bytes || hash || salt; the zero
+    // bytes are read from sig at p, which the memset above left cleared
+    md_starts( &md_ctx );
+    md_update( &md_ctx, p, 8 );
+    md_update( &md_ctx, hash, hashlen );
+    md_update( &md_ctx, salt, slen );
+    md_finish( &md_ctx, p );
+
+    // Compensate for boundary condition when applying mask
+    if( msb % 8 == 0 )
+        offset = 1;
+
+    // maskedDB: Apply dbMask to DB
+    mgf_mask( sig + offset, olen - hlen - 1 - offset, p, hlen, &md_ctx );
+
+    md_free_ctx( &md_ctx );
+
+    // Clear the leftmost (olen * 8 - msb) bits of the encoded message
+    sig[0] &= 0xFF >> ( olen * 8 - msb );
+
+    // Append the 0xBC trailer right after H
+    p += hlen;
+    *p++ = 0xBC;
+
+    return( ( mode == RSA_PUBLIC )
+            ? rsa_public(  ctx, sig, sig )
+            : rsa_private( ctx, sig, sig ) );
+}
+#endif /* POLARSSL_PKCS1_V21 */
+
+/*
+ * Implementation of the PKCS#1 v2.1 RSASSA-PKCS1-V1_5-SIGN function:
+ * sig = 0x00 || RSA_SIGN || nb_pad x 0xFF || 0x00 || T, where T is the raw
+ * hash (SIG_RSA_RAW) or an ASN1_HASH_* prefix followed by the hash; the
+ * RSA operation is then selected by mode.
+ */
+int rsa_rsassa_pkcs1_v15_sign( rsa_context *ctx,
+                               int mode,
+                               int hash_id,
+                               unsigned int hashlen,
+                               const unsigned char *hash,
+                               unsigned char *sig )
+{
+    size_t nb_pad, olen;
+    unsigned char *p = sig;
+
+    if( ctx->padding != RSA_PKCS_V15 )
+        return( POLARSSL_ERR_RSA_BAD_INPUT_DATA );
+
+    olen = ctx->len;
+    // Size PS (nb_pad) so the 3 framing bytes, PS and T fill olen bytes
+    switch( hash_id )
+    {
+        case SIG_RSA_RAW:
+            nb_pad = olen - 3 - hashlen;
+            break;
+
+        case SIG_RSA_MD2:
+        case SIG_RSA_MD4:
+        case SIG_RSA_MD5:
+            nb_pad = olen - 3 - 34; /* 18-byte prefix + 16-byte hash */
+            break;
+
+        case SIG_RSA_SHA1:
+            nb_pad = olen - 3 - 35; /* 15-byte prefix + 20-byte hash */
+            break;
+
+        case SIG_RSA_SHA224:
+            nb_pad = olen - 3 - 47; /* 19-byte prefix + 28-byte hash */
+            break;
+
+        case SIG_RSA_SHA256:
+            nb_pad = olen - 3 - 51; /* 19-byte prefix + 32-byte hash */
+            break;
+
+        case SIG_RSA_SHA384:
+            nb_pad = olen - 3 - 67; /* 19-byte prefix + 48-byte hash */
+            break;
+
+        case SIG_RSA_SHA512:
+            nb_pad = olen - 3 - 83; /* 19-byte prefix + 64-byte hash */
+            break;
+
+
+        default:
+            return( POLARSSL_ERR_RSA_BAD_INPUT_DATA );
+    }
+    // nb_pad > olen catches a wrapped-around (negative) size_t subtraction
+    if( ( nb_pad < 8 ) || ( nb_pad > olen ) )
+        return( POLARSSL_ERR_RSA_BAD_INPUT_DATA );
+
+    *p++ = 0;
+    *p++ = RSA_SIGN;
+    memset( p, 0xFF, nb_pad );
+    p += nb_pad;
+    *p++ = 0;
+    // Append T; the p[...] writes patch the copied prefix per algorithm
+    switch( hash_id )
+    {
+        case SIG_RSA_RAW:
+            memcpy( p, hash, hashlen );
+            break;
+
+        case SIG_RSA_MD2:
+            memcpy( p, ASN1_HASH_MDX, 18 );
+            memcpy( p + 18, hash, 16 );
+            p[13] = 2; break;
+
+        case SIG_RSA_MD4:
+            memcpy( p, ASN1_HASH_MDX, 18 );
+            memcpy( p + 18, hash, 16 );
+            p[13] = 4; break;
+
+        case SIG_RSA_MD5:
+            memcpy( p, ASN1_HASH_MDX, 18 );
+            memcpy( p + 18, hash, 16 );
+            p[13] = 5; break;
+
+        case SIG_RSA_SHA1:
+            memcpy( p, ASN1_HASH_SHA1, 15 );
+            memcpy( p + 15, hash, 20 );
+            break;
+
+        case SIG_RSA_SHA224:
+            memcpy( p, ASN1_HASH_SHA2X, 19 );
+            memcpy( p + 19, hash, 28 );
+            p[1] += 28; p[14] = 4; p[18] += 28; break;
+
+        case SIG_RSA_SHA256:
+            memcpy( p, ASN1_HASH_SHA2X, 19 );
+            memcpy( p + 19, hash, 32 );
+            p[1] += 32; p[14] = 1; p[18] += 32; break;
+
+        case SIG_RSA_SHA384:
+            memcpy( p, ASN1_HASH_SHA2X, 19 );
+            memcpy( p + 19, hash, 48 );
+            p[1] += 48; p[14] = 2; p[18] += 48; break;
+
+        case SIG_RSA_SHA512:
+            memcpy( p, ASN1_HASH_SHA2X, 19 );
+            memcpy( p + 19, hash, 64 );
+            p[1] += 64; p[14] = 3; p[18] += 64; break;
+
+        default:
+            return( POLARSSL_ERR_RSA_BAD_INPUT_DATA );
+    }
+
+    return( ( mode == RSA_PUBLIC )
+            ? rsa_public(  ctx, sig, sig )
+            : rsa_private( ctx, sig, sig ) );
+}
+
+/*
  * Do an RSA operation to sign the message digest
  */
 int rsa_pkcs1_sign( rsa_context *ctx,
@@ -632,250 +961,54 @@
                     const unsigned char *hash,
                     unsigned char *sig )
 {
-    size_t nb_pad, olen;
-    unsigned char *p = sig;
-#if defined(POLARSSL_PKCS1_V21)
-    unsigned char salt[POLARSSL_MD_MAX_SIZE];
-    unsigned int slen, hlen, offset = 0;
-    int ret;
-    size_t msb;
-    const md_info_t *md_info;
-    md_context_t md_ctx;
-#else
-    (void) f_rng;
-    (void) p_rng;
-#endif
-
-    olen = ctx->len;
-
+#if !defined(POLARSSL_PKCS1_V21)
+    // f_rng and p_rng are only consumed by the PSS branch compiled out here
+    (void) f_rng;
+    (void) p_rng;
+#endif
+
     switch( ctx->padding )
     {
         case RSA_PKCS_V15:
-
-            switch( hash_id )
-            {
-                case SIG_RSA_RAW:
-                    nb_pad = olen - 3 - hashlen;
-                    break;
-
-                case SIG_RSA_MD2:
-                case SIG_RSA_MD4:
-                case SIG_RSA_MD5:
-                    nb_pad = olen - 3 - 34;
-                    break;
-
-                case SIG_RSA_SHA1:
-                    nb_pad = olen - 3 - 35;
-                    break;
-
-                case SIG_RSA_SHA224:
-                    nb_pad = olen - 3 - 47;
-                    break;
-
-                case SIG_RSA_SHA256:
-                    nb_pad = olen - 3 - 51;
-                    break;
-
-                case SIG_RSA_SHA384:
-                    nb_pad = olen - 3 - 67;
-                    break;
-
-                case SIG_RSA_SHA512:
-                    nb_pad = olen - 3 - 83;
-                    break;
-
-
-                default:
-                    return( POLARSSL_ERR_RSA_BAD_INPUT_DATA );
-            }
-
-            if( ( nb_pad < 8 ) || ( nb_pad > olen ) )
-                return( POLARSSL_ERR_RSA_BAD_INPUT_DATA );
-
-            *p++ = 0;
-            *p++ = RSA_SIGN;
-            memset( p, 0xFF, nb_pad );
-            p += nb_pad;
-            *p++ = 0;
-
-            switch( hash_id )
-            {
-                case SIG_RSA_RAW:
-                    memcpy( p, hash, hashlen );
-                    break;
-
-                case SIG_RSA_MD2:
-                    memcpy( p, ASN1_HASH_MDX, 18 );
-                    memcpy( p + 18, hash, 16 );
-                    p[13] = 2; break;
-
-                case SIG_RSA_MD4:
-                    memcpy( p, ASN1_HASH_MDX, 18 );
-                    memcpy( p + 18, hash, 16 );
-                    p[13] = 4; break;
-
-                case SIG_RSA_MD5:
-                    memcpy( p, ASN1_HASH_MDX, 18 );
-                    memcpy( p + 18, hash, 16 );
-                    p[13] = 5; break;
-
-                case SIG_RSA_SHA1:
-                    memcpy( p, ASN1_HASH_SHA1, 15 );
-                    memcpy( p + 15, hash, 20 );
-                    break;
-
-                case SIG_RSA_SHA224:
-                    memcpy( p, ASN1_HASH_SHA2X, 19 );
-                    memcpy( p + 19, hash, 28 );
-                    p[1] += 28; p[14] = 4; p[18] += 28; break;
-
-                case SIG_RSA_SHA256:
-                    memcpy( p, ASN1_HASH_SHA2X, 19 );
-                    memcpy( p + 19, hash, 32 );
-                    p[1] += 32; p[14] = 1; p[18] += 32; break;
-
-                case SIG_RSA_SHA384:
-                    memcpy( p, ASN1_HASH_SHA2X, 19 );
-                    memcpy( p + 19, hash, 48 );
-                    p[1] += 48; p[14] = 2; p[18] += 48; break;
-
-                case SIG_RSA_SHA512:
-                    memcpy( p, ASN1_HASH_SHA2X, 19 );
-                    memcpy( p + 19, hash, 64 );
-                    p[1] += 64; p[14] = 3; p[18] += 64; break;
-
-                default:
-                    return( POLARSSL_ERR_RSA_BAD_INPUT_DATA );
-            }
-
-            break;
+            return rsa_rsassa_pkcs1_v15_sign( ctx, mode, hash_id,
+                                              hashlen, hash, sig );
 
 #if defined(POLARSSL_PKCS1_V21)
         case RSA_PKCS_V21:
-
-            if( f_rng == NULL )
-                return( POLARSSL_ERR_RSA_BAD_INPUT_DATA );
-
-            switch( hash_id )
-            {
-                case SIG_RSA_MD2:
-                case SIG_RSA_MD4:
-                case SIG_RSA_MD5:
-                    hashlen = 16;
-                    break;
-
-                case SIG_RSA_SHA1:
-                    hashlen = 20;
-                    break;
-
-                case SIG_RSA_SHA224:
-                    hashlen = 28;
-                    break;
-
-                case SIG_RSA_SHA256:
-                    hashlen = 32;
-                    break;
-
-                case SIG_RSA_SHA384:
-                    hashlen = 48;
-                    break;
-
-                case SIG_RSA_SHA512:
-                    hashlen = 64;
-                    break;
-
-                default:
-                    return( POLARSSL_ERR_RSA_BAD_INPUT_DATA );
-            }
-
-            md_info = md_info_from_type( ctx->hash_id );
-            if( md_info == NULL )
-                return( POLARSSL_ERR_RSA_BAD_INPUT_DATA );
-                
-            hlen = md_get_size( md_info );
-            slen = hlen;
-
-            if( olen < hlen + slen + 2 )
-                return( POLARSSL_ERR_RSA_BAD_INPUT_DATA );
-
-            memset( sig, 0, olen );
-
-            msb = mpi_msb( &ctx->N ) - 1;
-
-            // Generate salt of length slen
-            //
-            if( ( ret = f_rng( p_rng, salt, slen ) ) != 0 )
-                return( POLARSSL_ERR_RSA_RNG_FAILED + ret );
-
-            // Note: EMSA-PSS encoding is over the length of N - 1 bits
-            //
-            msb = mpi_msb( &ctx->N ) - 1;
-            p += olen - hlen * 2 - 2;
-            *p++ = 0x01;
-            memcpy( p, salt, slen );
-            p += slen;
-
-            md_init_ctx( &md_ctx, md_info );
-
-            // Generate H = Hash( M' )
-            //
-            md_starts( &md_ctx );
-            md_update( &md_ctx, p, 8 );
-            md_update( &md_ctx, hash, hashlen );
-            md_update( &md_ctx, salt, slen );
-            md_finish( &md_ctx, p );
-
-            // Compensate for boundary condition when applying mask
-            //
-            if( msb % 8 == 0 )
-                offset = 1;
-
-            // maskedDB: Apply dbMask to DB
-            //
-            mgf_mask( sig + offset, olen - hlen - 1 - offset, p, hlen, &md_ctx );
-
-            md_free_ctx( &md_ctx );
-
-            msb = mpi_msb( &ctx->N ) - 1;
-            sig[0] &= 0xFF >> ( olen * 8 - msb );
-
-            p += hlen;
-            *p++ = 0xBC;
-            break;
+            return rsa_rsassa_pss_sign( ctx, f_rng, p_rng, mode, hash_id,
+                                        hashlen, hash, sig );
 #endif
 
         default:
-
             return( POLARSSL_ERR_RSA_INVALID_PADDING );
     }
-
-    return( ( mode == RSA_PUBLIC )
-            ? rsa_public(  ctx, sig, sig )
-            : rsa_private( ctx, sig, sig ) );
 }
 
+#if defined(POLARSSL_PKCS1_V21)
 /*
- * Do an RSA operation and check the message digest
+ * Implementation of the PKCS#1 v2.1 RSASSA-PSS-VERIFY function
  */
-int rsa_pkcs1_verify( rsa_context *ctx,
-                      int mode,
-                      int hash_id,
-                      unsigned int hashlen,
-                      const unsigned char *hash,
-                      unsigned char *sig )
+int rsa_rsassa_pss_verify( rsa_context *ctx,
+                           int mode,
+                           int hash_id,
+                           unsigned int hashlen,
+                           const unsigned char *hash,
+                           unsigned char *sig )
 {
     int ret;
-    size_t len, siglen;
-    unsigned char *p, c;
+    size_t siglen;
+    unsigned char *p;
     unsigned char buf[POLARSSL_MPI_MAX_SIZE];
-#if defined(POLARSSL_PKCS1_V21)
     unsigned char result[POLARSSL_MD_MAX_SIZE];
     unsigned char zeros[8];
     unsigned int hlen;
     size_t slen, msb;
     const md_info_t *md_info;
     md_context_t md_ctx;
-#endif
+
+    if( ctx->padding != RSA_PKCS_V21 )
+        return( POLARSSL_ERR_RSA_BAD_INPUT_DATA );
+
     siglen = ctx->len;
 
     if( siglen < 16 || siglen > sizeof( buf ) )
@@ -890,189 +1017,235 @@
 
     p = buf;
 
-    switch( ctx->padding )
+    if( buf[siglen - 1] != 0xBC )
+        return( POLARSSL_ERR_RSA_INVALID_PADDING );
+
+    switch( hash_id )
     {
-        case RSA_PKCS_V15:
-
-            if( *p++ != 0 || *p++ != RSA_SIGN )
-                return( POLARSSL_ERR_RSA_INVALID_PADDING );
-
-            while( *p != 0 )
-            {
-                if( p >= buf + siglen - 1 || *p != 0xFF )
-                    return( POLARSSL_ERR_RSA_INVALID_PADDING );
-                p++;
-            }
-            p++;
-
-            len = siglen - ( p - buf );
-
-            if( len == 33 && hash_id == SIG_RSA_SHA1 )
-            {
-                if( memcmp( p, ASN1_HASH_SHA1_ALT, 13 ) == 0 &&
-                        memcmp( p + 13, hash, 20 ) == 0 )
-                    return( 0 );
-                else
-                    return( POLARSSL_ERR_RSA_VERIFY_FAILED );
-            }
-            if( len == 34 )
-            {
-                c = p[13];
-                p[13] = 0;
-
-                if( memcmp( p, ASN1_HASH_MDX, 18 ) != 0 )
-                    return( POLARSSL_ERR_RSA_VERIFY_FAILED );
-
-                if( ( c == 2 && hash_id == SIG_RSA_MD2 ) ||
-                        ( c == 4 && hash_id == SIG_RSA_MD4 ) ||
-                        ( c == 5 && hash_id == SIG_RSA_MD5 ) )
-                {
-                    if( memcmp( p + 18, hash, 16 ) == 0 ) 
-                        return( 0 );
-                    else
-                        return( POLARSSL_ERR_RSA_VERIFY_FAILED );
-                }
-            }
-
-            if( len == 35 && hash_id == SIG_RSA_SHA1 )
-            {
-                if( memcmp( p, ASN1_HASH_SHA1, 15 ) == 0 &&
-                        memcmp( p + 15, hash, 20 ) == 0 )
-                    return( 0 );
-                else
-                    return( POLARSSL_ERR_RSA_VERIFY_FAILED );
-            }
-            if( ( len == 19 + 28 && p[14] == 4 && hash_id == SIG_RSA_SHA224 ) ||
-                    ( len == 19 + 32 && p[14] == 1 && hash_id == SIG_RSA_SHA256 ) ||
-                    ( len == 19 + 48 && p[14] == 2 && hash_id == SIG_RSA_SHA384 ) ||
-                    ( len == 19 + 64 && p[14] == 3 && hash_id == SIG_RSA_SHA512 ) )
-            {
-                c = p[1] - 17;
-                p[1] = 17;
-                p[14] = 0;
-
-                if( p[18] == c &&
-                        memcmp( p, ASN1_HASH_SHA2X, 18 ) == 0 &&
-                        memcmp( p + 19, hash, c ) == 0 )
-                    return( 0 );
-                else
-                    return( POLARSSL_ERR_RSA_VERIFY_FAILED );
-            }
-
-            if( len == hashlen && hash_id == SIG_RSA_RAW )
-            {
-                if( memcmp( p, hash, hashlen ) == 0 )
-                    return( 0 );
-                else
-                    return( POLARSSL_ERR_RSA_VERIFY_FAILED );
-            }
-
+        case SIG_RSA_MD2:
+        case SIG_RSA_MD4:
+        case SIG_RSA_MD5:
+            hashlen = 16;
             break;
 
-#if defined(POLARSSL_PKCS1_V21)
-        case RSA_PKCS_V21:
-            
-            if( buf[siglen - 1] != 0xBC )
-                return( POLARSSL_ERR_RSA_INVALID_PADDING );
+        case SIG_RSA_SHA1:
+            hashlen = 20;
+            break;
 
-            switch( hash_id )
-            {
-                case SIG_RSA_MD2:
-                case SIG_RSA_MD4:
-                case SIG_RSA_MD5:
-                    hashlen = 16;
-                    break;
+        case SIG_RSA_SHA224:
+            hashlen = 28;
+            break;
 
-                case SIG_RSA_SHA1:
-                    hashlen = 20;
-                    break;
+        case SIG_RSA_SHA256:
+            hashlen = 32;
+            break;
 
-                case SIG_RSA_SHA224:
-                    hashlen = 28;
-                    break;
+        case SIG_RSA_SHA384:
+            hashlen = 48;
+            break;
 
-                case SIG_RSA_SHA256:
-                    hashlen = 32;
-                    break;
+        case SIG_RSA_SHA512:
+            hashlen = 64;
+            break;
 
-                case SIG_RSA_SHA384:
-                    hashlen = 48;
-                    break;
+        default:
+            return( POLARSSL_ERR_RSA_BAD_INPUT_DATA );
+    }
 
-                case SIG_RSA_SHA512:
-                    hashlen = 64;
-                    break;
+    md_info = md_info_from_type( ctx->hash_id );
+    if( md_info == NULL )
+        return( POLARSSL_ERR_RSA_BAD_INPUT_DATA );
 
-                default:
-                    return( POLARSSL_ERR_RSA_BAD_INPUT_DATA );
-            }
+    hlen = md_get_size( md_info );
+    slen = siglen - hlen - 1;
 
-            md_info = md_info_from_type( ctx->hash_id );
-            if( md_info == NULL )
-                return( POLARSSL_ERR_RSA_BAD_INPUT_DATA );
-                
-            hlen = md_get_size( md_info );
-            slen = siglen - hlen - 1;
+    memset( zeros, 0, 8 );
 
-            memset( zeros, 0, 8 );
+    // Note: EMSA-PSS verification is over the length of N - 1 bits
+    //
+    msb = mpi_msb( &ctx->N ) - 1;
 
-            // Note: EMSA-PSS verification is over the length of N - 1 bits
-            //
-            msb = mpi_msb( &ctx->N ) - 1;
+    // The leading 8 * siglen - msb bits must be zero; check before skipping buf[0]
+    if( buf[0] >> ( 8 - siglen * 8 + msb ) )
+        return( POLARSSL_ERR_RSA_BAD_INPUT_DATA );
+    // Compensate for boundary condition when applying mask
+    if( msb % 8 == 0 )
+    {
+        p++;
+        siglen -= 1;
+    }
 
-            // Compensate for boundary condition when applying mask
-            //
-            if( msb % 8 == 0 )
-            {
-                p++;
-                siglen -= 1;
-            }
-            if( buf[0] >> ( 8 - siglen * 8 + msb ) )
-                return( POLARSSL_ERR_RSA_BAD_INPUT_DATA );
+    md_init_ctx( &md_ctx, md_info );
 
-            md_init_ctx( &md_ctx, md_info );
+    mgf_mask( p, siglen - hlen - 1, p + siglen - hlen - 1, hlen, &md_ctx );
 
-            mgf_mask( p, siglen - hlen - 1, p + siglen - hlen - 1, hlen, &md_ctx );
+    buf[0] &= 0xFF >> ( siglen * 8 - msb );
 
-            buf[0] &= 0xFF >> ( siglen * 8 - msb );
+    while( p < buf + siglen && *p == 0 )
+        p++;
 
-            while( *p == 0 && p < buf + siglen )
-                p++;
+    if( p == buf + siglen ||
+        *p++ != 0x01 )
+    {
+        md_free_ctx( &md_ctx );
+        return( POLARSSL_ERR_RSA_INVALID_PADDING );
+    }
 
-            if( p == buf + siglen ||
-                *p++ != 0x01 )
-            {
-                md_free_ctx( &md_ctx );
-                return( POLARSSL_ERR_RSA_INVALID_PADDING );
-            }
+    slen -= p - buf;
 
-            slen -= p - buf;
+    // Generate H = Hash( M' )
+    //
+    md_starts( &md_ctx );
+    md_update( &md_ctx, zeros, 8 );
+    md_update( &md_ctx, hash, hashlen );
+    md_update( &md_ctx, p, slen );
+    md_finish( &md_ctx, result );
 
-            // Generate H = Hash( M' )
-            //
-            md_starts( &md_ctx );
-            md_update( &md_ctx, zeros, 8 );
-            md_update( &md_ctx, hash, hashlen );
-            md_update( &md_ctx, p, slen );
-            md_finish( &md_ctx, result );
+    md_free_ctx( &md_ctx );
 
-            md_free_ctx( &md_ctx );
+    if( memcmp( p + slen, result, hlen ) == 0 )
+        return( 0 );
+    else
+        return( POLARSSL_ERR_RSA_VERIFY_FAILED );
+}
+#endif /* POLARSSL_PKCS1_V21 */
 
-            if( memcmp( p + slen, result, hlen ) == 0 )
+/*
+ * Implementation of the PKCS#1 v2.1 RSASSA-PKCS1-v1_5-VERIFY function
+ */
+int rsa_rsassa_pkcs1_v15_verify( rsa_context *ctx,
+                                 int mode,
+                                 int hash_id,
+                                 unsigned int hashlen,
+                                 const unsigned char *hash,
+                                 unsigned char *sig )
+{
+    int ret;
+    size_t len, siglen;
+    unsigned char *p, c;
+    unsigned char buf[POLARSSL_MPI_MAX_SIZE];
+
+    if( ctx->padding != RSA_PKCS_V15 )
+        return( POLARSSL_ERR_RSA_BAD_INPUT_DATA );
+
+    siglen = ctx->len;
+
+    if( siglen < 16 || siglen > sizeof( buf ) )
+        return( POLARSSL_ERR_RSA_BAD_INPUT_DATA );
+
+    ret = ( mode == RSA_PUBLIC )
+          ? rsa_public(  ctx, sig, buf )
+          : rsa_private( ctx, sig, buf );
+
+    if( ret != 0 )
+        return( ret );
+
+    p = buf;
+
+    if( *p++ != 0 || *p++ != RSA_SIGN )
+        return( POLARSSL_ERR_RSA_INVALID_PADDING );
+
+    while( *p != 0 )
+    {
+        if( p >= buf + siglen - 1 || *p != 0xFF )
+            return( POLARSSL_ERR_RSA_INVALID_PADDING );
+        p++;
+    }
+    p++;
+
+    len = siglen - ( p - buf );
+
+    if( len == 33 && hash_id == SIG_RSA_SHA1 )
+    {
+        if( memcmp( p, ASN1_HASH_SHA1_ALT, 13 ) == 0 &&
+                memcmp( p + 13, hash, 20 ) == 0 )
+            return( 0 );
+        else
+            return( POLARSSL_ERR_RSA_VERIFY_FAILED );
+    }
+    if( len == 34 )
+    {
+        c = p[13];
+        p[13] = 0;
+
+        if( memcmp( p, ASN1_HASH_MDX, 18 ) != 0 )
+            return( POLARSSL_ERR_RSA_VERIFY_FAILED );
+
+        if( ( c == 2 && hash_id == SIG_RSA_MD2 ) ||
+                ( c == 4 && hash_id == SIG_RSA_MD4 ) ||
+                ( c == 5 && hash_id == SIG_RSA_MD5 ) )
+        {
+            if( memcmp( p + 18, hash, 16 ) == 0 )
                 return( 0 );
             else
                 return( POLARSSL_ERR_RSA_VERIFY_FAILED );
-#endif
+        }
+    }
 
-        default:
+    if( len == 35 && hash_id == SIG_RSA_SHA1 )
+    {
+        if( memcmp( p, ASN1_HASH_SHA1, 15 ) == 0 &&
+                memcmp( p + 15, hash, 20 ) == 0 )
+            return( 0 );
+        else
+            return( POLARSSL_ERR_RSA_VERIFY_FAILED );
+    }
+    if( ( len == 19 + 28 && p[14] == 4 && hash_id == SIG_RSA_SHA224 ) ||
+            ( len == 19 + 32 && p[14] == 1 && hash_id == SIG_RSA_SHA256 ) ||
+            ( len == 19 + 48 && p[14] == 2 && hash_id == SIG_RSA_SHA384 ) ||
+            ( len == 19 + 64 && p[14] == 3 && hash_id == SIG_RSA_SHA512 ) )
+    {
+        c = p[1] - 17; /* digest length claimed by the DigestInfo */
+        p[1] = 17;
+        p[14] = 0;
 
-            return( POLARSSL_ERR_RSA_INVALID_PADDING );
+        if( p[18] == c && (size_t) c == len - 19 &&
+                memcmp( p, ASN1_HASH_SHA2X, 18 ) == 0 &&
+                memcmp( p + 19, hash, c ) == 0 )
+            return( 0 );
+        else
+            return( POLARSSL_ERR_RSA_VERIFY_FAILED );
+    }
+
+    if( len == hashlen && hash_id == SIG_RSA_RAW )
+    {
+        if( memcmp( p, hash, hashlen ) == 0 )
+            return( 0 );
+        else
+            return( POLARSSL_ERR_RSA_VERIFY_FAILED );
     }
 
     return( POLARSSL_ERR_RSA_INVALID_PADDING );
 }
 
 /*
+ * Do an RSA operation and check the message digest (dispatches on ctx->padding)
+ */
+int rsa_pkcs1_verify( rsa_context *ctx,
+                      int mode,
+                      int hash_id,
+                      unsigned int hashlen,
+                      const unsigned char *hash,
+                      unsigned char *sig )
+{
+    switch( ctx->padding )
+    {
+        case RSA_PKCS_V15:
+            return( rsa_rsassa_pkcs1_v15_verify( ctx, mode, hash_id,
+                                                 hashlen, hash, sig ) );
+
+#if defined(POLARSSL_PKCS1_V21)
+        case RSA_PKCS_V21:
+            return( rsa_rsassa_pss_verify( ctx, mode, hash_id,
+                                           hashlen, hash, sig ) );
+#endif
+
+        default:
+            return( POLARSSL_ERR_RSA_INVALID_PADDING );
+    }
+}
+
+/*
  * Free the components of an RSA key
  */
 void rsa_free( rsa_context *ctx )