- Added support for PKCS#1 v2.1 encoding, and thus for the RSAES-OAEP and RSASSA-PSS operations (enabled by POLARSSL_PKCS1_V21)


diff --git a/library/rsa.c b/library/rsa.c
index 77404fc..ec44d84 100644
--- a/library/rsa.c
+++ b/library/rsa.c
@@ -34,6 +34,7 @@
 #if defined(POLARSSL_RSA_C)
 
 #include "polarssl/rsa.h"
+#include "polarssl/md.h"
 
 #include <stdlib.h>
 #include <string.h>
@@ -291,6 +292,55 @@
     return( 0 );
 }
 
+#if defined(POLARSSL_PKCS1_V21)
+/**
+ * Generate the MGF1 mask (PKCS#1 v2.1, B.2.1) from src and XOR it into dst.
+ *
+ * @param dst       buffer to mask (modified in place)
+ * @param dlen      number of bytes of dst to mask
+ * @param src       seed for the mask generation
+ * @param slen      length of the seed
+ * @param md_ctx    initialised digest context; its digest size is the block
+ *                  length and the context is restarted for every block
+ */
+static void mgf_mask( unsigned char *dst, int dlen, unsigned char *src, int slen,
+                       md_context_t *md_ctx )
+{
+    unsigned char mask[POLARSSL_MD_MAX_SIZE];
+    unsigned char counter[4];
+    unsigned char *p;
+    int i, use_len, hlen;
+
+    memset( mask, 0, POLARSSL_MD_MAX_SIZE );
+    memset( counter, 0, 4 );
+
+    hlen = md_ctx->md_info->size;
+
+    // mask = Hash( src || C(0) ) || Hash( src || C(1) ) || ..., truncated
+    //
+    p = dst;
+
+    while( dlen > 0 )
+    {
+        use_len = hlen;
+        if( dlen < hlen )
+            use_len = dlen;
+
+        md_starts( md_ctx );
+        md_update( md_ctx, src, slen );
+        md_update( md_ctx, counter, 4 );
+        md_finish( md_ctx, mask );
+
+        for( i = 0; i < use_len; ++i )
+            *p++ ^= mask[i];
+
+        // C is a 4-octet big-endian counter: increment with carry
+        for( i = 3; i >= 0 && ++counter[i] == 0; --i ) ;
+
+        dlen -= use_len;
+    }
+}
+#endif
+
 /*
  * Add the message padding, then do an RSA operation
  */
@@ -303,14 +353,22 @@
 {
     int nb_pad, olen;
     unsigned char *p = output;
+#if defined(POLARSSL_PKCS1_V21)
+    const md_info_t *md_info;
+    md_context_t md_ctx;
+    int i, hlen;
+#endif
 
     olen = ctx->len;
 
+    if( f_rng == NULL )
+        return( POLARSSL_ERR_RSA_BAD_INPUT_DATA );
+
     switch( ctx->padding )
     {
         case RSA_PKCS_V15:
 
-            if( ilen < 0 || olen < ilen + 11 || f_rng == NULL )
+            if( ilen < 0 || olen < ilen + 11 )
                 return( POLARSSL_ERR_RSA_BAD_INPUT_DATA );
 
             nb_pad = olen - 3 - ilen;
@@ -336,6 +394,50 @@
             *p++ = 0;
             memcpy( p, input, ilen );
             break;
+        
+#if defined(POLARSSL_PKCS1_V21)
+        case RSA_PKCS_V21:
+
+            md_info = md_info_from_type( ctx->hash_id );
+            if( md_info == NULL )
+                return( POLARSSL_ERR_RSA_BAD_INPUT_DATA );
+
+            hlen = md_get_size( md_info );
+
+            if( ilen < 0 || olen < ilen + 2 * hlen + 2 || f_rng == NULL )
+                return( POLARSSL_ERR_RSA_BAD_INPUT_DATA );
+
+            memset( output, 0, olen );
+            memset( &md_ctx, 0, sizeof( md_context_t ) );
+
+            md_init_ctx( &md_ctx, md_info );
+
+            *p++ = 0;
+
+            // Generate a random octet string seed
+            //
+            for( i = 0; i < hlen; ++i )
+                *p++ = (unsigned char) f_rng( p_rng ); 
+
+            // Construct DB
+            //
+            md( md_info, p, 0, p );
+            p += hlen;
+            p += olen - 2 * hlen - 2 - ilen;
+            *p++ = 1;
+            memcpy( p, input, ilen ); 
+
+            // maskedDB: Apply dbMask to DB
+            //
+            mgf_mask( output + hlen + 1, olen - hlen - 1, output + 1, hlen,  
+                       &md_ctx );
+
+            // maskedSeed: Apply seedMask to seed
+            //
+            mgf_mask( output + 1, hlen, output + hlen + 1, olen - hlen - 1,  
+                       &md_ctx );
+            break;
+#endif
 
         default:
 
@@ -359,6 +461,12 @@
     int ret, ilen;
     unsigned char *p;
     unsigned char buf[1024];
+#if defined(POLARSSL_PKCS1_V21)
+    unsigned char lhash[POLARSSL_MD_MAX_SIZE];
+    const md_info_t *md_info;
+    md_context_t md_ctx;
+    int hlen;
+#endif
 
     ilen = ctx->len;
 
@@ -390,6 +498,56 @@
             p++;
             break;
 
+#if defined(POLARSSL_PKCS1_V21)
+        case RSA_PKCS_V21:
+            
+            if( *p++ != 0 )
+                return( POLARSSL_ERR_RSA_INVALID_PADDING );
+
+            md_info = md_info_from_type( ctx->hash_id );
+            if( md_info == NULL )
+                return( POLARSSL_ERR_RSA_BAD_INPUT_DATA );
+                
+            hlen = md_get_size( md_info );
+            memset( &md_ctx, 0, sizeof( md_context_t ) );
+
+            md_init_ctx( &md_ctx, md_info );
+            
+            // Generate lHash
+            //
+            md( md_info, lhash, 0, lhash );
+
+            // seed: Apply seedMask to maskedSeed
+            //
+            mgf_mask( buf + 1, hlen, buf + hlen + 1, ilen - hlen - 1,
+                       &md_ctx );
+
+            // DB: Apply dbMask to maskedDB
+            //
+            mgf_mask( buf + hlen + 1, ilen - hlen - 1, buf + 1, hlen,  
+                       &md_ctx );
+
+            p += hlen;
+
+            // Check validity
+            //
+            if( memcmp( lhash, p, hlen ) != 0 )
+                return( POLARSSL_ERR_RSA_INVALID_PADDING );
+
+            p += hlen;
+
+            while( *p == 0 && p < buf + ilen )
+                p++;
+
+            if( p == buf + ilen )
+                return( POLARSSL_ERR_RSA_INVALID_PADDING );
+
+            if( *p++ != 0x01 )
+                return( POLARSSL_ERR_RSA_INVALID_PADDING );
+
+            break;
+#endif
+
         default:
 
             return( POLARSSL_ERR_RSA_INVALID_PADDING );
@@ -408,6 +566,8 @@
  * Do an RSA operation to sign the message digest
  */
 int rsa_pkcs1_sign( rsa_context *ctx,
+                    int (*f_rng)(void *),
+                    void *p_rng,
                     int mode,
                     int hash_id,
                     int hashlen,
@@ -416,6 +576,15 @@
 {
     int nb_pad, olen;
     unsigned char *p = sig;
+#if defined(POLARSSL_PKCS1_V21)
+    unsigned char salt[POLARSSL_MD_MAX_SIZE];
+    const md_info_t *md_info;
+    md_context_t md_ctx;
+    int i, hlen, msb, offset = 0;
+#else
+    (void) f_rng;
+    (void) p_rng;
+#endif
 
     olen = ctx->len;
 
@@ -468,63 +637,152 @@
             memset( p, 0xFF, nb_pad );
             p += nb_pad;
             *p++ = 0;
+
+            switch( hash_id )
+            {
+                case SIG_RSA_RAW:
+                    memcpy( p, hash, hashlen );
+                    break;
+
+                case SIG_RSA_MD2:
+                    memcpy( p, ASN1_HASH_MDX, 18 );
+                    memcpy( p + 18, hash, 16 );
+                    p[13] = 2; break;
+
+                case SIG_RSA_MD4:
+                    memcpy( p, ASN1_HASH_MDX, 18 );
+                    memcpy( p + 18, hash, 16 );
+                    p[13] = 4; break;
+
+                case SIG_RSA_MD5:
+                    memcpy( p, ASN1_HASH_MDX, 18 );
+                    memcpy( p + 18, hash, 16 );
+                    p[13] = 5; break;
+
+                case SIG_RSA_SHA1:
+                    memcpy( p, ASN1_HASH_SHA1, 15 );
+                    memcpy( p + 15, hash, 20 );
+                    break;
+
+                case SIG_RSA_SHA224:
+                    memcpy( p, ASN1_HASH_SHA2X, 19 );
+                    memcpy( p + 19, hash, 28 );
+                    p[1] += 28; p[14] = 4; p[18] += 28; break;
+
+                case SIG_RSA_SHA256:
+                    memcpy( p, ASN1_HASH_SHA2X, 19 );
+                    memcpy( p + 19, hash, 32 );
+                    p[1] += 32; p[14] = 1; p[18] += 32; break;
+
+                case SIG_RSA_SHA384:
+                    memcpy( p, ASN1_HASH_SHA2X, 19 );
+                    memcpy( p + 19, hash, 48 );
+                    p[1] += 48; p[14] = 2; p[18] += 48; break;
+
+                case SIG_RSA_SHA512:
+                    memcpy( p, ASN1_HASH_SHA2X, 19 );
+                    memcpy( p + 19, hash, 64 );
+                    p[1] += 64; p[14] = 3; p[18] += 64; break;
+
+                default:
+                    return( POLARSSL_ERR_RSA_BAD_INPUT_DATA );
+            }
+
             break;
 
+#if defined(POLARSSL_PKCS1_V21)
+        case RSA_PKCS_V21:
+
+            if( f_rng == NULL )
+                return( POLARSSL_ERR_RSA_BAD_INPUT_DATA );
+
+            switch( hash_id )
+            {
+                case SIG_RSA_MD2:
+                case SIG_RSA_MD4:
+                case SIG_RSA_MD5:
+                    hashlen = 16;
+                    break;
+
+                case SIG_RSA_SHA1:
+                    hashlen = 20;
+                    break;
+
+                case SIG_RSA_SHA224:
+                    hashlen = 28;
+                    break;
+
+                case SIG_RSA_SHA256:
+                    hashlen = 32;
+                    break;
+
+                case SIG_RSA_SHA384:
+                    hashlen = 48;
+                    break;
+
+                case SIG_RSA_SHA512:
+                    hashlen = 64;
+                    break;
+
+                default:
+                    return( POLARSSL_ERR_RSA_BAD_INPUT_DATA );
+            }
+
+            md_info = md_info_from_type( ctx->hash_id );
+            if( md_info == NULL )
+                return( POLARSSL_ERR_RSA_BAD_INPUT_DATA );
+                
+            hlen = md_get_size( md_info );
+            memset( sig, 0, olen );
+            memset( &md_ctx, 0, sizeof( md_context_t ) );
+
+            md_init_ctx( &md_ctx, md_info );
+
+            msb = mpi_msb( &ctx->N ) - 1;
+
+            // Generate salt of length hlen
+            //
+            for( i = 0; i < hlen; ++i )
+                salt[i] = (unsigned char) f_rng( p_rng ); 
+
+            // Note: EMSA-PSS encoding is over the length of N - 1 bits
+            //
+            msb = mpi_msb( &ctx->N ) - 1;
+            p += olen - hlen * 2 - 2;
+            *p++ = 0x01;
+            memcpy( p, salt, hlen );
+            p += hlen;
+
+            // Generate H = Hash( M' )
+            //
+            md_starts( &md_ctx );
+            md_update( &md_ctx, p, 8 );
+            md_update( &md_ctx, hash, hashlen );
+            md_update( &md_ctx, salt, hlen );
+            md_finish( &md_ctx, p );
+
+            // Compensate for boundary condition when applying mask
+            //
+            if( msb % 8 == 0 )
+                offset = 1;
+
+            // maskedDB: Apply dbMask to DB
+            //
+            mgf_mask( sig + offset, olen - hlen - 1 - offset, p, hlen, &md_ctx );
+
+            msb = mpi_msb( &ctx->N ) - 1;
+            sig[0] &= 0xFF >> ( olen * 8 - msb );
+
+            p += hlen;
+            *p++ = 0xBC;
+            break;
+#endif
+
         default:
 
             return( POLARSSL_ERR_RSA_INVALID_PADDING );
     }
 
-    switch( hash_id )
-    {
-        case SIG_RSA_RAW:
-            memcpy( p, hash, hashlen );
-            break;
-
-        case SIG_RSA_MD2:
-            memcpy( p, ASN1_HASH_MDX, 18 );
-            memcpy( p + 18, hash, 16 );
-            p[13] = 2; break;
-
-        case SIG_RSA_MD4:
-            memcpy( p, ASN1_HASH_MDX, 18 );
-            memcpy( p + 18, hash, 16 );
-            p[13] = 4; break;
-
-        case SIG_RSA_MD5:
-            memcpy( p, ASN1_HASH_MDX, 18 );
-            memcpy( p + 18, hash, 16 );
-            p[13] = 5; break;
-
-        case SIG_RSA_SHA1:
-            memcpy( p, ASN1_HASH_SHA1, 15 );
-            memcpy( p + 15, hash, 20 );
-            break;
-
-        case SIG_RSA_SHA224:
-            memcpy( p, ASN1_HASH_SHA2X, 19 );
-            memcpy( p + 19, hash, 28 );
-            p[1] += 28; p[14] = 4; p[18] += 28; break;
-
-        case SIG_RSA_SHA256:
-            memcpy( p, ASN1_HASH_SHA2X, 19 );
-            memcpy( p + 19, hash, 32 );
-            p[1] += 32; p[14] = 1; p[18] += 32; break;
-
-        case SIG_RSA_SHA384:
-            memcpy( p, ASN1_HASH_SHA2X, 19 );
-            memcpy( p + 19, hash, 48 );
-            p[1] += 48; p[14] = 2; p[18] += 48; break;
-
-        case SIG_RSA_SHA512:
-            memcpy( p, ASN1_HASH_SHA2X, 19 );
-            memcpy( p + 19, hash, 64 );
-            p[1] += 64; p[14] = 3; p[18] += 64; break;
-
-        default:
-            return( POLARSSL_ERR_RSA_BAD_INPUT_DATA );
-    }
-
     return( ( mode == RSA_PUBLIC )
             ? rsa_public(  ctx, sig, sig )
             : rsa_private( ctx, sig, sig ) );
@@ -543,7 +801,12 @@
     int ret, len, siglen;
     unsigned char *p, c;
     unsigned char buf[1024];
-
+#if defined(POLARSSL_PKCS1_V21)
+    unsigned char zeros[8];
+    const md_info_t *md_info;
+    md_context_t md_ctx;
+    int hlen, msb;
+#endif
     siglen = ctx->len;
 
     if( siglen < 16 || siglen > (int) sizeof( buf ) )
@@ -572,67 +835,158 @@
                 p++;
             }
             p++;
+
+            len = siglen - (int)( p - buf );
+
+            if( len == 34 )
+            {
+                c = p[13];
+                p[13] = 0;
+
+                if( memcmp( p, ASN1_HASH_MDX, 18 ) != 0 )
+                    return( POLARSSL_ERR_RSA_VERIFY_FAILED );
+
+                if( ( c == 2 && hash_id == SIG_RSA_MD2 ) ||
+                        ( c == 4 && hash_id == SIG_RSA_MD4 ) ||
+                        ( c == 5 && hash_id == SIG_RSA_MD5 ) )
+                {
+                    if( memcmp( p + 18, hash, 16 ) == 0 ) 
+                        return( 0 );
+                    else
+                        return( POLARSSL_ERR_RSA_VERIFY_FAILED );
+                }
+            }
+
+            if( len == 35 && hash_id == SIG_RSA_SHA1 )
+            {
+                if( memcmp( p, ASN1_HASH_SHA1, 15 ) == 0 &&
+                        memcmp( p + 15, hash, 20 ) == 0 )
+                    return( 0 );
+                else
+                    return( POLARSSL_ERR_RSA_VERIFY_FAILED );
+            }
+            if( ( len == 19 + 28 && p[14] == 4 && hash_id == SIG_RSA_SHA224 ) ||
+                    ( len == 19 + 32 && p[14] == 1 && hash_id == SIG_RSA_SHA256 ) ||
+                    ( len == 19 + 48 && p[14] == 2 && hash_id == SIG_RSA_SHA384 ) ||
+                    ( len == 19 + 64 && p[14] == 3 && hash_id == SIG_RSA_SHA512 ) )
+            {
+                c = p[1] - 17;
+                p[1] = 17;
+                p[14] = 0;
+
+                if( p[18] == c &&
+                        memcmp( p, ASN1_HASH_SHA2X, 18 ) == 0 &&
+                        memcmp( p + 19, hash, c ) == 0 )
+                    return( 0 );
+                else
+                    return( POLARSSL_ERR_RSA_VERIFY_FAILED );
+            }
+
+            if( len == hashlen && hash_id == SIG_RSA_RAW )
+            {
+                if( memcmp( p, hash, hashlen ) == 0 )
+                    return( 0 );
+                else
+                    return( POLARSSL_ERR_RSA_VERIFY_FAILED );
+            }
+
             break;
 
+#if defined(POLARSSL_PKCS1_V21)
+        case RSA_PKCS_V21:
+            
+            if( buf[siglen - 1] != 0xBC )
+                return( POLARSSL_ERR_RSA_INVALID_PADDING );
+
+            switch( hash_id )
+            {
+                case SIG_RSA_MD2:
+                case SIG_RSA_MD4:
+                case SIG_RSA_MD5:
+                    hashlen = 16;
+                    break;
+
+                case SIG_RSA_SHA1:
+                    hashlen = 20;
+                    break;
+
+                case SIG_RSA_SHA224:
+                    hashlen = 28;
+                    break;
+
+                case SIG_RSA_SHA256:
+                    hashlen = 32;
+                    break;
+
+                case SIG_RSA_SHA384:
+                    hashlen = 48;
+                    break;
+
+                case SIG_RSA_SHA512:
+                    hashlen = 64;
+                    break;
+
+                default:
+                    return( POLARSSL_ERR_RSA_BAD_INPUT_DATA );
+            }
+
+            md_info = md_info_from_type( ctx->hash_id );
+            if( md_info == NULL )
+                return( POLARSSL_ERR_RSA_BAD_INPUT_DATA );
+                
+            hlen = md_get_size( md_info );
+            memset( &md_ctx, 0, sizeof( md_context_t ) );
+            memset( zeros, 0, 8 );
+
+            md_init_ctx( &md_ctx, md_info );
+
+            // Note: EMSA-PSS verification is over the length of N - 1 bits
+            //
+            msb = mpi_msb( &ctx->N ) - 1;
+
+            // Compensate for boundary condition when applying mask
+            //
+            if( msb % 8 == 0 )
+            {
+                p++;
+                siglen -= 1;
+            }
+            if( buf[0] >> ( 8 - siglen * 8 + msb ) )
+                return( POLARSSL_ERR_RSA_BAD_INPUT_DATA );
+
+            mgf_mask( p, siglen - hlen - 1, p + siglen - hlen - 1, hlen, &md_ctx );
+
+            buf[0] &= 0xFF >> ( siglen * 8 - msb );
+
+            while( *p == 0 && p < buf + siglen )
+                p++;
+
+            if( p == buf + siglen )
+                return( POLARSSL_ERR_RSA_INVALID_PADDING );
+
+            if( *p++ != 0x01 )
+                return( POLARSSL_ERR_RSA_INVALID_PADDING );
+
+            // Generate H = Hash( M' )
+            //
+            md_starts( &md_ctx );
+            md_update( &md_ctx, zeros, 8 );
+            md_update( &md_ctx, hash, hashlen );
+            md_update( &md_ctx, p, hlen );
+            md_finish( &md_ctx, p );
+
+            if( memcmp( p, p + hlen, hlen ) == 0 )
+                return( 0 );
+            else
+                return( POLARSSL_ERR_RSA_VERIFY_FAILED );
+            break;
+#endif
+
         default:
 
             return( POLARSSL_ERR_RSA_INVALID_PADDING );
     }
 
-    len = siglen - (int)( p - buf );
-
-    if( len == 34 )
-    {
-        c = p[13];
-        p[13] = 0;
-
-        if( memcmp( p, ASN1_HASH_MDX, 18 ) != 0 )
-            return( POLARSSL_ERR_RSA_VERIFY_FAILED );
-
-        if( ( c == 2 && hash_id == SIG_RSA_MD2 ) ||
-            ( c == 4 && hash_id == SIG_RSA_MD4 ) ||
-            ( c == 5 && hash_id == SIG_RSA_MD5 ) )
-        {
-            if( memcmp( p + 18, hash, 16 ) == 0 ) 
-                return( 0 );
-            else
-                return( POLARSSL_ERR_RSA_VERIFY_FAILED );
-        }
-    }
-
-    if( len == 35 && hash_id == SIG_RSA_SHA1 )
-    {
-        if( memcmp( p, ASN1_HASH_SHA1, 15 ) == 0 &&
-            memcmp( p + 15, hash, 20 ) == 0 )
-            return( 0 );
-        else
-            return( POLARSSL_ERR_RSA_VERIFY_FAILED );
-    }
-    if( ( len == 19 + 28 && p[14] == 4 && hash_id == SIG_RSA_SHA224 ) ||
-        ( len == 19 + 32 && p[14] == 1 && hash_id == SIG_RSA_SHA256 ) ||
-        ( len == 19 + 48 && p[14] == 2 && hash_id == SIG_RSA_SHA384 ) ||
-        ( len == 19 + 64 && p[14] == 3 && hash_id == SIG_RSA_SHA512 ) )
-    {
-    	c = p[1] - 17;
-        p[1] = 17;
-        p[14] = 0;
-
-        if( p[18] == c &&
-                memcmp( p, ASN1_HASH_SHA2X, 18 ) == 0 &&
-                memcmp( p + 19, hash, c ) == 0 )
-            return( 0 );
-        else
-            return( POLARSSL_ERR_RSA_VERIFY_FAILED );
-    }
-
-    if( len == hashlen && hash_id == SIG_RSA_RAW )
-    {
-        if( memcmp( p, hash, hashlen ) == 0 )
-            return( 0 );
-        else
-            return( POLARSSL_ERR_RSA_VERIFY_FAILED );
-    }
-
     return( POLARSSL_ERR_RSA_INVALID_PADDING );
 }
 
@@ -789,7 +1143,7 @@
 
     sha1( rsa_plaintext, PT_LEN, sha1sum );
 
-    if( rsa_pkcs1_sign( &rsa, RSA_PRIVATE, SIG_RSA_SHA1, 20,
+    if( rsa_pkcs1_sign( &rsa, NULL, NULL, RSA_PRIVATE, SIG_RSA_SHA1, 20,
                         sha1sum, rsa_ciphertext ) != 0 )
     {
         if( verbose != 0 )