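- Added support for PKCS#1 v2.1 encoding and thus support for the RSAES-OAEP and RSASSA-PSS operations (enabled by POLARSSL_PKCS1_V21); note that rsa_pkcs1_sign() now takes f_rng and p_rng arguments (an RNG is required for RSASSA-PSS salt generation)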
diff --git a/library/rsa.c b/library/rsa.c
index 77404fc..ec44d84 100644
--- a/library/rsa.c
+++ b/library/rsa.c
@@ -34,6 +34,7 @@
#if defined(POLARSSL_RSA_C)
#include "polarssl/rsa.h"
+#include "polarssl/md.h"
#include <stdlib.h>
#include <string.h>
@@ -291,6 +292,55 @@
return( 0 );
}
+#if defined(POLARSSL_PKCS1_V21)
+/**
+ * Generate and apply the MGF1 operation (from PKCS#1 v2.1) to a buffer.
+ *
+ * @param dst buffer to mask (XORed in place)
+ * @param dlen length of destination buffer
+ * @param src seed for the mask generation (mgfSeed); only read
+ * @param slen length of the seed buffer
+ * @param md_ctx message digest context to use; the digest length
+ * is taken from md_ctx->md_info via md_get_size()
+ */
+static void mgf_mask( unsigned char *dst, int dlen, const unsigned char *src,
+ int slen, md_context_t *md_ctx )
+{
+ unsigned char mask[POLARSSL_MD_MAX_SIZE];
+ unsigned char counter[4];
+ unsigned char *p;
+ int i, use_len, hlen;
+
+ memset( mask, 0, POLARSSL_MD_MAX_SIZE );
+ memset( counter, 0, 4 );
+
+ hlen = md_get_size( md_ctx->md_info );
+
+ // XOR successive Hash( src || counter ) blocks into dst
+ // (this helper produces both the seed mask and the DB mask)
+ p = dst;
+
+ while( dlen > 0 )
+ {
+ use_len = hlen;
+ if( dlen < hlen )
+ use_len = dlen;
+
+ md_starts( md_ctx );
+ md_update( md_ctx, src, slen );
+ md_update( md_ctx, counter, 4 );
+ md_finish( md_ctx, mask );
+
+ for( i = 0; i < use_len; ++i )
+ *p++ ^= mask[i];
+
+ // NOTE(review): only the low counter byte advances; correct while dlen <= 256 * hlen
+ counter[3]++;
+ dlen -= use_len;
+ }
+}
+#endif
+
/*
* Add the message padding, then do an RSA operation
*/
@@ -303,14 +353,22 @@
{
int nb_pad, olen;
unsigned char *p = output;
+#if defined(POLARSSL_PKCS1_V21)
+ const md_info_t *md_info;
+ md_context_t md_ctx;
+ int i, hlen;
+#endif
olen = ctx->len;
+ if( f_rng == NULL )
+ return( POLARSSL_ERR_RSA_BAD_INPUT_DATA );
+
switch( ctx->padding )
{
case RSA_PKCS_V15:
- if( ilen < 0 || olen < ilen + 11 || f_rng == NULL )
+ if( ilen < 0 || olen < ilen + 11 )
return( POLARSSL_ERR_RSA_BAD_INPUT_DATA );
nb_pad = olen - 3 - ilen;
@@ -336,6 +394,50 @@
*p++ = 0;
memcpy( p, input, ilen );
break;
+
+#if defined(POLARSSL_PKCS1_V21)
+ case RSA_PKCS_V21:
+
+ md_info = md_info_from_type( ctx->hash_id );
+ if( md_info == NULL )
+ return( POLARSSL_ERR_RSA_BAD_INPUT_DATA );
+
+ hlen = md_get_size( md_info );
+
+ if( ilen < 0 || olen < ilen + 2 * hlen + 2 || f_rng == NULL )
+ return( POLARSSL_ERR_RSA_BAD_INPUT_DATA );
+
+ memset( output, 0, olen );
+ memset( &md_ctx, 0, sizeof( md_context_t ) );
+
+ md_init_ctx( &md_ctx, md_info );
+
+ *p++ = 0;
+
+ // Generate a random octet string seed
+ //
+ for( i = 0; i < hlen; ++i )
+ *p++ = (unsigned char) f_rng( p_rng );
+
+ // Construct DB
+ //
+ md( md_info, p, 0, p );
+ p += hlen;
+ p += olen - 2 * hlen - 2 - ilen;
+ *p++ = 1;
+ memcpy( p, input, ilen );
+
+ // maskedDB: Apply dbMask to DB
+ //
+ mgf_mask( output + hlen + 1, olen - hlen - 1, output + 1, hlen,
+ &md_ctx );
+
+ // maskedSeed: Apply seedMask to seed
+ //
+ mgf_mask( output + 1, hlen, output + hlen + 1, olen - hlen - 1,
+ &md_ctx );
+ break;
+#endif
default:
@@ -359,6 +461,12 @@
int ret, ilen;
unsigned char *p;
unsigned char buf[1024];
+#if defined(POLARSSL_PKCS1_V21)
+ unsigned char lhash[POLARSSL_MD_MAX_SIZE];
+ const md_info_t *md_info;
+ md_context_t md_ctx;
+ int hlen;
+#endif
ilen = ctx->len;
@@ -390,6 +498,56 @@
p++;
break;
+#if defined(POLARSSL_PKCS1_V21)
+ case RSA_PKCS_V21:
+
+ if( *p++ != 0 )
+ return( POLARSSL_ERR_RSA_INVALID_PADDING );
+
+ md_info = md_info_from_type( ctx->hash_id );
+ if( md_info == NULL )
+ return( POLARSSL_ERR_RSA_BAD_INPUT_DATA );
+
+ hlen = md_get_size( md_info );
+ memset( &md_ctx, 0, sizeof( md_context_t ) );
+
+ md_init_ctx( &md_ctx, md_info );
+
+ // Generate lHash
+ //
+ md( md_info, lhash, 0, lhash );
+
+ // seed: Apply seedMask to maskedSeed
+ //
+ mgf_mask( buf + 1, hlen, buf + hlen + 1, ilen - hlen - 1,
+ &md_ctx );
+
+ // DB: Apply dbMask to maskedDB
+ //
+ mgf_mask( buf + hlen + 1, ilen - hlen - 1, buf + 1, hlen,
+ &md_ctx );
+
+ p += hlen;
+
+ // Check validity
+ //
+ if( memcmp( lhash, p, hlen ) != 0 )
+ return( POLARSSL_ERR_RSA_INVALID_PADDING );
+
+ p += hlen;
+
+ while( *p == 0 && p < buf + ilen )
+ p++;
+
+ if( p == buf + ilen )
+ return( POLARSSL_ERR_RSA_INVALID_PADDING );
+
+ if( *p++ != 0x01 )
+ return( POLARSSL_ERR_RSA_INVALID_PADDING );
+
+ break;
+#endif
+
default:
return( POLARSSL_ERR_RSA_INVALID_PADDING );
@@ -408,6 +566,8 @@
* Do an RSA operation to sign the message digest
*/
int rsa_pkcs1_sign( rsa_context *ctx,
+ int (*f_rng)(void *),
+ void *p_rng,
int mode,
int hash_id,
int hashlen,
@@ -416,6 +576,15 @@
{
int nb_pad, olen;
unsigned char *p = sig;
+#if defined(POLARSSL_PKCS1_V21)
+ unsigned char salt[POLARSSL_MD_MAX_SIZE];
+ const md_info_t *md_info;
+ md_context_t md_ctx;
+ int i, hlen, msb, offset = 0;
+#else
+ (void) f_rng;
+ (void) p_rng;
+#endif
olen = ctx->len;
@@ -468,63 +637,152 @@
memset( p, 0xFF, nb_pad );
p += nb_pad;
*p++ = 0;
+
+ switch( hash_id )
+ {
+ case SIG_RSA_RAW:
+ memcpy( p, hash, hashlen );
+ break;
+
+ case SIG_RSA_MD2:
+ memcpy( p, ASN1_HASH_MDX, 18 );
+ memcpy( p + 18, hash, 16 );
+ p[13] = 2; break;
+
+ case SIG_RSA_MD4:
+ memcpy( p, ASN1_HASH_MDX, 18 );
+ memcpy( p + 18, hash, 16 );
+ p[13] = 4; break;
+
+ case SIG_RSA_MD5:
+ memcpy( p, ASN1_HASH_MDX, 18 );
+ memcpy( p + 18, hash, 16 );
+ p[13] = 5; break;
+
+ case SIG_RSA_SHA1:
+ memcpy( p, ASN1_HASH_SHA1, 15 );
+ memcpy( p + 15, hash, 20 );
+ break;
+
+ case SIG_RSA_SHA224:
+ memcpy( p, ASN1_HASH_SHA2X, 19 );
+ memcpy( p + 19, hash, 28 );
+ p[1] += 28; p[14] = 4; p[18] += 28; break;
+
+ case SIG_RSA_SHA256:
+ memcpy( p, ASN1_HASH_SHA2X, 19 );
+ memcpy( p + 19, hash, 32 );
+ p[1] += 32; p[14] = 1; p[18] += 32; break;
+
+ case SIG_RSA_SHA384:
+ memcpy( p, ASN1_HASH_SHA2X, 19 );
+ memcpy( p + 19, hash, 48 );
+ p[1] += 48; p[14] = 2; p[18] += 48; break;
+
+ case SIG_RSA_SHA512:
+ memcpy( p, ASN1_HASH_SHA2X, 19 );
+ memcpy( p + 19, hash, 64 );
+ p[1] += 64; p[14] = 3; p[18] += 64; break;
+
+ default:
+ return( POLARSSL_ERR_RSA_BAD_INPUT_DATA );
+ }
+
break;
+#if defined(POLARSSL_PKCS1_V21)
+ case RSA_PKCS_V21:
+
+ if( f_rng == NULL )
+ return( POLARSSL_ERR_RSA_BAD_INPUT_DATA );
+
+ switch( hash_id )
+ {
+ case SIG_RSA_MD2:
+ case SIG_RSA_MD4:
+ case SIG_RSA_MD5:
+ hashlen = 16;
+ break;
+
+ case SIG_RSA_SHA1:
+ hashlen = 20;
+ break;
+
+ case SIG_RSA_SHA224:
+ hashlen = 28;
+ break;
+
+ case SIG_RSA_SHA256:
+ hashlen = 32;
+ break;
+
+ case SIG_RSA_SHA384:
+ hashlen = 48;
+ break;
+
+ case SIG_RSA_SHA512:
+ hashlen = 64;
+ break;
+
+ default:
+ return( POLARSSL_ERR_RSA_BAD_INPUT_DATA );
+ }
+
+ md_info = md_info_from_type( ctx->hash_id );
+ if( md_info == NULL )
+ return( POLARSSL_ERR_RSA_BAD_INPUT_DATA );
+
+ hlen = md_get_size( md_info );
+ memset( sig, 0, olen );
+ memset( &md_ctx, 0, sizeof( md_context_t ) );
+
+ md_init_ctx( &md_ctx, md_info );
+
+ msb = mpi_msb( &ctx->N ) - 1;
+
+ // Generate salt of length hlen
+ //
+ for( i = 0; i < hlen; ++i )
+ salt[i] = (unsigned char) f_rng( p_rng );
+
+ // Note: EMSA-PSS encoding is over the length of N - 1 bits
+ //
+ msb = mpi_msb( &ctx->N ) - 1;
+ p += olen - hlen * 2 - 2;
+ *p++ = 0x01;
+ memcpy( p, salt, hlen );
+ p += hlen;
+
+ // Generate H = Hash( M' )
+ //
+ md_starts( &md_ctx );
+ md_update( &md_ctx, p, 8 );
+ md_update( &md_ctx, hash, hashlen );
+ md_update( &md_ctx, salt, hlen );
+ md_finish( &md_ctx, p );
+
+ // Compensate for boundary condition when applying mask
+ //
+ if( msb % 8 == 0 )
+ offset = 1;
+
+ // maskedDB: Apply dbMask to DB
+ //
+ mgf_mask( sig + offset, olen - hlen - 1 - offset, p, hlen, &md_ctx );
+
+ msb = mpi_msb( &ctx->N ) - 1;
+ sig[0] &= 0xFF >> ( olen * 8 - msb );
+
+ p += hlen;
+ *p++ = 0xBC;
+ break;
+#endif
+
default:
return( POLARSSL_ERR_RSA_INVALID_PADDING );
}
- switch( hash_id )
- {
- case SIG_RSA_RAW:
- memcpy( p, hash, hashlen );
- break;
-
- case SIG_RSA_MD2:
- memcpy( p, ASN1_HASH_MDX, 18 );
- memcpy( p + 18, hash, 16 );
- p[13] = 2; break;
-
- case SIG_RSA_MD4:
- memcpy( p, ASN1_HASH_MDX, 18 );
- memcpy( p + 18, hash, 16 );
- p[13] = 4; break;
-
- case SIG_RSA_MD5:
- memcpy( p, ASN1_HASH_MDX, 18 );
- memcpy( p + 18, hash, 16 );
- p[13] = 5; break;
-
- case SIG_RSA_SHA1:
- memcpy( p, ASN1_HASH_SHA1, 15 );
- memcpy( p + 15, hash, 20 );
- break;
-
- case SIG_RSA_SHA224:
- memcpy( p, ASN1_HASH_SHA2X, 19 );
- memcpy( p + 19, hash, 28 );
- p[1] += 28; p[14] = 4; p[18] += 28; break;
-
- case SIG_RSA_SHA256:
- memcpy( p, ASN1_HASH_SHA2X, 19 );
- memcpy( p + 19, hash, 32 );
- p[1] += 32; p[14] = 1; p[18] += 32; break;
-
- case SIG_RSA_SHA384:
- memcpy( p, ASN1_HASH_SHA2X, 19 );
- memcpy( p + 19, hash, 48 );
- p[1] += 48; p[14] = 2; p[18] += 48; break;
-
- case SIG_RSA_SHA512:
- memcpy( p, ASN1_HASH_SHA2X, 19 );
- memcpy( p + 19, hash, 64 );
- p[1] += 64; p[14] = 3; p[18] += 64; break;
-
- default:
- return( POLARSSL_ERR_RSA_BAD_INPUT_DATA );
- }
-
return( ( mode == RSA_PUBLIC )
? rsa_public( ctx, sig, sig )
: rsa_private( ctx, sig, sig ) );
@@ -543,7 +801,12 @@
int ret, len, siglen;
unsigned char *p, c;
unsigned char buf[1024];
-
+#if defined(POLARSSL_PKCS1_V21)
+ unsigned char zeros[8];
+ const md_info_t *md_info;
+ md_context_t md_ctx;
+ int hlen, msb;
+#endif
siglen = ctx->len;
if( siglen < 16 || siglen > (int) sizeof( buf ) )
@@ -572,67 +835,158 @@
p++;
}
p++;
+
+ len = siglen - (int)( p - buf );
+
+ if( len == 34 )
+ {
+ c = p[13];
+ p[13] = 0;
+
+ if( memcmp( p, ASN1_HASH_MDX, 18 ) != 0 )
+ return( POLARSSL_ERR_RSA_VERIFY_FAILED );
+
+ if( ( c == 2 && hash_id == SIG_RSA_MD2 ) ||
+ ( c == 4 && hash_id == SIG_RSA_MD4 ) ||
+ ( c == 5 && hash_id == SIG_RSA_MD5 ) )
+ {
+ if( memcmp( p + 18, hash, 16 ) == 0 )
+ return( 0 );
+ else
+ return( POLARSSL_ERR_RSA_VERIFY_FAILED );
+ }
+ }
+
+ if( len == 35 && hash_id == SIG_RSA_SHA1 )
+ {
+ if( memcmp( p, ASN1_HASH_SHA1, 15 ) == 0 &&
+ memcmp( p + 15, hash, 20 ) == 0 )
+ return( 0 );
+ else
+ return( POLARSSL_ERR_RSA_VERIFY_FAILED );
+ }
+ if( ( len == 19 + 28 && p[14] == 4 && hash_id == SIG_RSA_SHA224 ) ||
+ ( len == 19 + 32 && p[14] == 1 && hash_id == SIG_RSA_SHA256 ) ||
+ ( len == 19 + 48 && p[14] == 2 && hash_id == SIG_RSA_SHA384 ) ||
+ ( len == 19 + 64 && p[14] == 3 && hash_id == SIG_RSA_SHA512 ) )
+ {
+ c = p[1] - 17;
+ p[1] = 17;
+ p[14] = 0;
+
+ if( p[18] == c &&
+ memcmp( p, ASN1_HASH_SHA2X, 18 ) == 0 &&
+ memcmp( p + 19, hash, c ) == 0 )
+ return( 0 );
+ else
+ return( POLARSSL_ERR_RSA_VERIFY_FAILED );
+ }
+
+ if( len == hashlen && hash_id == SIG_RSA_RAW )
+ {
+ if( memcmp( p, hash, hashlen ) == 0 )
+ return( 0 );
+ else
+ return( POLARSSL_ERR_RSA_VERIFY_FAILED );
+ }
+
break;
+#if defined(POLARSSL_PKCS1_V21)
+ case RSA_PKCS_V21:
+
+ if( buf[siglen - 1] != 0xBC )
+ return( POLARSSL_ERR_RSA_INVALID_PADDING );
+
+ switch( hash_id )
+ {
+ case SIG_RSA_MD2:
+ case SIG_RSA_MD4:
+ case SIG_RSA_MD5:
+ hashlen = 16;
+ break;
+
+ case SIG_RSA_SHA1:
+ hashlen = 20;
+ break;
+
+ case SIG_RSA_SHA224:
+ hashlen = 28;
+ break;
+
+ case SIG_RSA_SHA256:
+ hashlen = 32;
+ break;
+
+ case SIG_RSA_SHA384:
+ hashlen = 48;
+ break;
+
+ case SIG_RSA_SHA512:
+ hashlen = 64;
+ break;
+
+ default:
+ return( POLARSSL_ERR_RSA_BAD_INPUT_DATA );
+ }
+
+ md_info = md_info_from_type( ctx->hash_id );
+ if( md_info == NULL )
+ return( POLARSSL_ERR_RSA_BAD_INPUT_DATA );
+
+ hlen = md_get_size( md_info );
+ memset( &md_ctx, 0, sizeof( md_context_t ) );
+ memset( zeros, 0, 8 );
+
+ md_init_ctx( &md_ctx, md_info );
+
+ // Note: EMSA-PSS verification is over the length of N - 1 bits
+ //
+ msb = mpi_msb( &ctx->N ) - 1;
+
+ // Compensate for boundary condition when applying mask
+ //
+ if( msb % 8 == 0 )
+ {
+ p++;
+ siglen -= 1;
+ }
+ if( buf[0] >> ( 8 - siglen * 8 + msb ) )
+ return( POLARSSL_ERR_RSA_BAD_INPUT_DATA );
+
+ mgf_mask( p, siglen - hlen - 1, p + siglen - hlen - 1, hlen, &md_ctx );
+
+ buf[0] &= 0xFF >> ( siglen * 8 - msb );
+
+ while( *p == 0 && p < buf + siglen )
+ p++;
+
+ if( p == buf + siglen )
+ return( POLARSSL_ERR_RSA_INVALID_PADDING );
+
+ if( *p++ != 0x01 )
+ return( POLARSSL_ERR_RSA_INVALID_PADDING );
+
+ // Generate H = Hash( M' )
+ //
+ md_starts( &md_ctx );
+ md_update( &md_ctx, zeros, 8 );
+ md_update( &md_ctx, hash, hashlen );
+ md_update( &md_ctx, p, hlen );
+ md_finish( &md_ctx, p );
+
+ if( memcmp( p, p + hlen, hlen ) == 0 )
+ return( 0 );
+ else
+ return( POLARSSL_ERR_RSA_VERIFY_FAILED );
+ break;
+#endif
+
default:
return( POLARSSL_ERR_RSA_INVALID_PADDING );
}
- len = siglen - (int)( p - buf );
-
- if( len == 34 )
- {
- c = p[13];
- p[13] = 0;
-
- if( memcmp( p, ASN1_HASH_MDX, 18 ) != 0 )
- return( POLARSSL_ERR_RSA_VERIFY_FAILED );
-
- if( ( c == 2 && hash_id == SIG_RSA_MD2 ) ||
- ( c == 4 && hash_id == SIG_RSA_MD4 ) ||
- ( c == 5 && hash_id == SIG_RSA_MD5 ) )
- {
- if( memcmp( p + 18, hash, 16 ) == 0 )
- return( 0 );
- else
- return( POLARSSL_ERR_RSA_VERIFY_FAILED );
- }
- }
-
- if( len == 35 && hash_id == SIG_RSA_SHA1 )
- {
- if( memcmp( p, ASN1_HASH_SHA1, 15 ) == 0 &&
- memcmp( p + 15, hash, 20 ) == 0 )
- return( 0 );
- else
- return( POLARSSL_ERR_RSA_VERIFY_FAILED );
- }
- if( ( len == 19 + 28 && p[14] == 4 && hash_id == SIG_RSA_SHA224 ) ||
- ( len == 19 + 32 && p[14] == 1 && hash_id == SIG_RSA_SHA256 ) ||
- ( len == 19 + 48 && p[14] == 2 && hash_id == SIG_RSA_SHA384 ) ||
- ( len == 19 + 64 && p[14] == 3 && hash_id == SIG_RSA_SHA512 ) )
- {
- c = p[1] - 17;
- p[1] = 17;
- p[14] = 0;
-
- if( p[18] == c &&
- memcmp( p, ASN1_HASH_SHA2X, 18 ) == 0 &&
- memcmp( p + 19, hash, c ) == 0 )
- return( 0 );
- else
- return( POLARSSL_ERR_RSA_VERIFY_FAILED );
- }
-
- if( len == hashlen && hash_id == SIG_RSA_RAW )
- {
- if( memcmp( p, hash, hashlen ) == 0 )
- return( 0 );
- else
- return( POLARSSL_ERR_RSA_VERIFY_FAILED );
- }
-
return( POLARSSL_ERR_RSA_INVALID_PADDING );
}
@@ -789,7 +1143,7 @@
sha1( rsa_plaintext, PT_LEN, sha1sum );
- if( rsa_pkcs1_sign( &rsa, RSA_PRIVATE, SIG_RSA_SHA1, 20,
+ if( rsa_pkcs1_sign( &rsa, NULL, NULL, RSA_PRIVATE, SIG_RSA_SHA1, 20,
sha1sum, rsa_ciphertext ) != 0 )
{
if( verbose != 0 )