Merge remote-tracking branch 'public/pr/1380' into development-proposed

* public/pr/1380:
  Update ChangeLog for #1380
  Generate RSA keys according to FIPS 186-4
  Generate primes according to FIPS 186-4
  Avoid small private exponents during RSA key generation
diff --git a/ChangeLog b/ChangeLog
index 9ee82c6..ae8d86f 100644
--- a/ChangeLog
+++ b/ChangeLog
@@ -42,7 +42,7 @@
      mnacamura.
    * Fix parsing of PKCS#8 encoded Elliptic Curve keys. Previously Mbed TLS was
      unable to parse keys with only the optional parameters field of the
-     ECPrivateKey structure. Found by jethrogb, fixed in #1379.
+     ECPrivateKey structure. Found by Jethro Beekman, fixed in #1379.
    * Return plaintext data sooner on unpadded CBC decryption, as stated in
      the mbedtls_cipher_update() documentation. Contributed by Andy Leiserson.
    * Fix overriding and ignoring return values when parsing and writing to
@@ -93,6 +93,8 @@
    * Improve robustness of mbedtls_ssl_derive_keys against the use of
      HMAC functions with non-HMAC ciphersuites. Independently contributed
      by Jiayuan Chen in #1377. Fixes #1437.
+   * Improve security of RSA key generation by including criteria from FIPS
+     186-4. Contributed by Jethro Beekman. #1380
 
 = mbed TLS 2.8.0 branch released 2018-03-16
 
diff --git a/library/bignum.c b/library/bignum.c
index 47bf1ef..f58af78 100644
--- a/library/bignum.c
+++ b/library/bignum.c
@@ -2194,12 +2194,23 @@
 
 /*
  * Prime number generation
+ *
+ * If dh_flag is 0 and nbits is at least 1024, then the procedure
+ * follows the RSA probably-prime generation method of FIPS 186-4.
+ * NB. FIPS 186-4 only allows the specific bit lengths of 1024 and 1536.
  */
 int mbedtls_mpi_gen_prime( mbedtls_mpi *X, size_t nbits, int dh_flag,
                    int (*f_rng)(void *, unsigned char *, size_t),
                    void *p_rng )
 {
-    int ret;
+#ifdef MBEDTLS_HAVE_INT64
+// ceil(2^63.5)
+#define CEIL_MAXUINT_DIV_SQRT2 0xb504f333f9de6485ULL
+#else
+// ceil(2^31.5)
+#define CEIL_MAXUINT_DIV_SQRT2 0xb504f334U
+#endif
+    int ret = MBEDTLS_ERR_MPI_NOT_ACCEPTABLE;
     size_t k, n;
     mbedtls_mpi_uint r;
     mbedtls_mpi Y;
@@ -2211,69 +2222,66 @@
 
     n = BITS_TO_LIMBS( nbits );
 
-    MBEDTLS_MPI_CHK( mbedtls_mpi_fill_random( X, n * ciL, f_rng, p_rng ) );
-
-    k = mbedtls_mpi_bitlen( X );
-    if( k > nbits ) MBEDTLS_MPI_CHK( mbedtls_mpi_shift_r( X, k - nbits + 1 ) );
-
-    mbedtls_mpi_set_bit( X, nbits-1, 1 );
-
-    X->p[0] |= 1;
-
-    if( dh_flag == 0 )
+    while( 1 )
     {
-        while( ( ret = mbedtls_mpi_is_prime( X, f_rng, p_rng ) ) != 0 )
+        MBEDTLS_MPI_CHK( mbedtls_mpi_fill_random( X, n * ciL, f_rng, p_rng ) );
+        /* make sure generated number is at least (nbits-1)+0.5 bits (FIPS 186-4 §B.3.3 steps 4.4, 5.5) */
+        if( X->p[n-1] < CEIL_MAXUINT_DIV_SQRT2 ) continue;
+
+        k = n * biL;
+        if( k > nbits ) MBEDTLS_MPI_CHK( mbedtls_mpi_shift_r( X, k - nbits ) );
+        X->p[0] |= 1;
+
+        if( dh_flag == 0 )
         {
+            ret = mbedtls_mpi_is_prime( X, f_rng, p_rng );
+
             if( ret != MBEDTLS_ERR_MPI_NOT_ACCEPTABLE )
                 goto cleanup;
-
-            MBEDTLS_MPI_CHK( mbedtls_mpi_add_int( X, X, 2 ) );
         }
-    }
-    else
-    {
-        /*
-         * An necessary condition for Y and X = 2Y + 1 to be prime
-         * is X = 2 mod 3 (which is equivalent to Y = 2 mod 3).
-         * Make sure it is satisfied, while keeping X = 3 mod 4
-         */
-
-        X->p[0] |= 2;
-
-        MBEDTLS_MPI_CHK( mbedtls_mpi_mod_int( &r, X, 3 ) );
-        if( r == 0 )
-            MBEDTLS_MPI_CHK( mbedtls_mpi_add_int( X, X, 8 ) );
-        else if( r == 1 )
-            MBEDTLS_MPI_CHK( mbedtls_mpi_add_int( X, X, 4 ) );
-
-        /* Set Y = (X-1) / 2, which is X / 2 because X is odd */
-        MBEDTLS_MPI_CHK( mbedtls_mpi_copy( &Y, X ) );
-        MBEDTLS_MPI_CHK( mbedtls_mpi_shift_r( &Y, 1 ) );
-
-        while( 1 )
+        else
         {
             /*
-             * First, check small factors for X and Y
-             * before doing Miller-Rabin on any of them
+             * A necessary condition for Y and X = 2Y + 1 to be prime
+             * is X = 2 mod 3 (which is equivalent to Y = 2 mod 3).
+             * Make sure it is satisfied, while keeping X = 3 mod 4
              */
-            if( ( ret = mpi_check_small_factors(  X         ) ) == 0 &&
-                ( ret = mpi_check_small_factors( &Y         ) ) == 0 &&
-                ( ret = mpi_miller_rabin(  X, f_rng, p_rng  ) ) == 0 &&
-                ( ret = mpi_miller_rabin( &Y, f_rng, p_rng  ) ) == 0 )
+
+            X->p[0] |= 2;
+
+            MBEDTLS_MPI_CHK( mbedtls_mpi_mod_int( &r, X, 3 ) );
+            if( r == 0 )
+                MBEDTLS_MPI_CHK( mbedtls_mpi_add_int( X, X, 8 ) );
+            else if( r == 1 )
+                MBEDTLS_MPI_CHK( mbedtls_mpi_add_int( X, X, 4 ) );
+
+            /* Set Y = (X-1) / 2, which is X / 2 because X is odd */
+            MBEDTLS_MPI_CHK( mbedtls_mpi_copy( &Y, X ) );
+            MBEDTLS_MPI_CHK( mbedtls_mpi_shift_r( &Y, 1 ) );
+
+            while( 1 )
             {
-                break;
+                /*
+                 * First, check small factors for X and Y
+                 * before doing Miller-Rabin on any of them
+                 */
+                if( ( ret = mpi_check_small_factors(  X         ) ) == 0 &&
+                    ( ret = mpi_check_small_factors( &Y         ) ) == 0 &&
+                    ( ret = mpi_miller_rabin(  X, f_rng, p_rng  ) ) == 0 &&
+                    ( ret = mpi_miller_rabin( &Y, f_rng, p_rng  ) ) == 0 )
+                    goto cleanup;
+
+                if( ret != MBEDTLS_ERR_MPI_NOT_ACCEPTABLE )
+                    goto cleanup;
+
+                /*
+                 * Next candidates. We want to preserve Y = (X-1) / 2 and
+                 * Y = 1 mod 2 and Y = 2 mod 3 (eq X = 3 mod 4 and X = 2 mod 3)
+                 * so up Y by 6 and X by 12.
+                 */
+                MBEDTLS_MPI_CHK( mbedtls_mpi_add_int(  X,  X, 12 ) );
+                MBEDTLS_MPI_CHK( mbedtls_mpi_add_int( &Y, &Y, 6  ) );
             }
-
-            if( ret != MBEDTLS_ERR_MPI_NOT_ACCEPTABLE )
-                goto cleanup;
-
-            /*
-             * Next candidates. We want to preserve Y = (X-1) / 2 and
-             * Y = 1 mod 2 and Y = 2 mod 3 (eq X = 3 mod 4 and X = 2 mod 3)
-             * so up Y by 6 and X by 12.
-             */
-            MBEDTLS_MPI_CHK( mbedtls_mpi_add_int(  X,  X, 12 ) );
-            MBEDTLS_MPI_CHK( mbedtls_mpi_add_int( &Y, &Y, 6  ) );
         }
     }
 
diff --git a/library/rsa.c b/library/rsa.c
index 2185040..729e1f7 100644
--- a/library/rsa.c
+++ b/library/rsa.c
@@ -495,6 +495,9 @@
 
 /*
  * Generate an RSA keypair
+ *
+ * This generation method follows the RSA key pair generation procedure of
+ * FIPS 186-4 if 2^16 < exponent < 2^256 and nbits = 2048 or nbits = 3072.
  */
 int mbedtls_rsa_gen_key( mbedtls_rsa_context *ctx,
                  int (*f_rng)(void *, unsigned char *, size_t),
@@ -502,7 +505,7 @@
                  unsigned int nbits, int exponent )
 {
     int ret;
-    mbedtls_mpi H, G;
+    mbedtls_mpi H, G, L;
 
     if( f_rng == NULL || nbits < 128 || exponent < 3 )
         return( MBEDTLS_ERR_RSA_BAD_INPUT_DATA );
@@ -512,10 +515,13 @@
 
     mbedtls_mpi_init( &H );
     mbedtls_mpi_init( &G );
+    mbedtls_mpi_init( &L );
 
     /*
      * find primes P and Q with Q < P so that:
-     * GCD( E, (P-1)*(Q-1) ) == 1
+     * 1.  |P-Q| > 2^( nbits / 2 - 100 )
+     * 2.  GCD( E, (P-1)*(Q-1) ) == 1
+     * 3.  E^-1 mod LCM(P-1, Q-1) > 2^( nbits / 2 )
      */
     MBEDTLS_MPI_CHK( mbedtls_mpi_lset( &ctx->E, exponent ) );
 
@@ -527,40 +533,51 @@
         MBEDTLS_MPI_CHK( mbedtls_mpi_gen_prime( &ctx->Q, nbits >> 1, 0,
                                                 f_rng, p_rng ) );
 
-        if( mbedtls_mpi_cmp_mpi( &ctx->P, &ctx->Q ) == 0 )
+        /* make sure the difference between p and q is not too small (FIPS 186-4 §B.3.3 step 5.4) */
+        MBEDTLS_MPI_CHK( mbedtls_mpi_sub_mpi( &H, &ctx->P, &ctx->Q ) );
+        if( mbedtls_mpi_bitlen( &H ) <= ( ( nbits >= 200 ) ? ( ( nbits >> 1 ) - 99 ) : 0 ) )
             continue;
 
-        MBEDTLS_MPI_CHK( mbedtls_mpi_mul_mpi( &ctx->N, &ctx->P, &ctx->Q ) );
-        if( mbedtls_mpi_bitlen( &ctx->N ) != nbits )
-            continue;
-
-        if( mbedtls_mpi_cmp_mpi( &ctx->P, &ctx->Q ) < 0 )
+        /* not required by any standards, but some users rely on the fact that P > Q */
+        if( H.s < 0 )
             mbedtls_mpi_swap( &ctx->P, &ctx->Q );
 
         /* Temporarily replace P,Q by P-1, Q-1 */
         MBEDTLS_MPI_CHK( mbedtls_mpi_sub_int( &ctx->P, &ctx->P, 1 ) );
         MBEDTLS_MPI_CHK( mbedtls_mpi_sub_int( &ctx->Q, &ctx->Q, 1 ) );
         MBEDTLS_MPI_CHK( mbedtls_mpi_mul_mpi( &H, &ctx->P, &ctx->Q ) );
+
+        /* check GCD( E, (P-1)*(Q-1) ) == 1 (FIPS 186-4 §B.3.1 criterion 2(a)) */
         MBEDTLS_MPI_CHK( mbedtls_mpi_gcd( &G, &ctx->E, &H  ) );
+        if( mbedtls_mpi_cmp_int( &G, 1 ) != 0 )
+            continue;
+
+        /* compute smallest possible D = E^-1 mod LCM(P-1, Q-1) (FIPS 186-4 §B.3.1 criterion 3(b)) */
+        MBEDTLS_MPI_CHK( mbedtls_mpi_gcd( &G, &ctx->P, &ctx->Q ) );
+        MBEDTLS_MPI_CHK( mbedtls_mpi_div_mpi( &L, NULL, &H, &G ) );
+        MBEDTLS_MPI_CHK( mbedtls_mpi_inv_mod( &ctx->D, &ctx->E, &L ) );
+
+        if( mbedtls_mpi_bitlen( &ctx->D ) <= ( ( nbits + 1 ) / 2 ) ) // (FIPS 186-4 §B.3.1 criterion 3(a))
+            continue;
+
+        break;
     }
-    while( mbedtls_mpi_cmp_int( &G, 1 ) != 0 );
+    while( 1 );
 
     /* Restore P,Q */
     MBEDTLS_MPI_CHK( mbedtls_mpi_add_int( &ctx->P,  &ctx->P, 1 ) );
     MBEDTLS_MPI_CHK( mbedtls_mpi_add_int( &ctx->Q,  &ctx->Q, 1 ) );
 
+    MBEDTLS_MPI_CHK( mbedtls_mpi_mul_mpi( &ctx->N, &ctx->P, &ctx->Q ) );
+
     ctx->len = mbedtls_mpi_size( &ctx->N );
 
+#if !defined(MBEDTLS_RSA_NO_CRT)
     /*
-     * D  = E^-1 mod ((P-1)*(Q-1))
      * DP = D mod (P - 1)
      * DQ = D mod (Q - 1)
      * QP = Q^-1 mod P
      */
-
-    MBEDTLS_MPI_CHK( mbedtls_mpi_inv_mod( &ctx->D, &ctx->E, &H  ) );
-
-#if !defined(MBEDTLS_RSA_NO_CRT)
     MBEDTLS_MPI_CHK( mbedtls_rsa_deduce_crt( &ctx->P, &ctx->Q, &ctx->D,
                                              &ctx->DP, &ctx->DQ, &ctx->QP ) );
 #endif /* MBEDTLS_RSA_NO_CRT */
@@ -572,6 +589,7 @@
 
     mbedtls_mpi_free( &H );
     mbedtls_mpi_free( &G );
+    mbedtls_mpi_free( &L );
 
     if( ret != 0 )
     {
diff --git a/tests/suites/test_suite_mpi.data b/tests/suites/test_suite_mpi.data
index 17cf350..2a2cfce 100644
--- a/tests/suites/test_suite_mpi.data
+++ b/tests/suites/test_suite_mpi.data
@@ -688,6 +688,18 @@
 depends_on:MBEDTLS_GENPRIME
 mbedtls_mpi_gen_prime:3:0:0
 
+Test mbedtls_mpi_gen_prime (corner case limb size -1 bits)
+depends_on:MBEDTLS_GENPRIME
+mbedtls_mpi_gen_prime:63:0:0
+
+Test mbedtls_mpi_gen_prime (corner case limb size)
+depends_on:MBEDTLS_GENPRIME
+mbedtls_mpi_gen_prime:64:0:0
+
+Test mbedtls_mpi_gen_prime (corner case limb size +1 bits)
+depends_on:MBEDTLS_GENPRIME
+mbedtls_mpi_gen_prime:65:0:0
+
 Test mbedtls_mpi_gen_prime (Larger)
 depends_on:MBEDTLS_GENPRIME
 mbedtls_mpi_gen_prime:128:0:0