gen_prime: ensure X = 2 mod 3 -> 2.5x speedup
diff --git a/include/polarssl/bignum.h b/include/polarssl/bignum.h
index 9bed027..b63a242 100644
--- a/include/polarssl/bignum.h
+++ b/include/polarssl/bignum.h
@@ -58,7 +58,7 @@
 #define POLARSSL_ERR_MPI_NOT_ACCEPTABLE                    -0x000E  /**< The input arguments are not acceptable. */
 #define POLARSSL_ERR_MPI_MALLOC_FAILED                     -0x0010  /**< Memory allocation failed. */
 
-#define MPI_CHK(f) if( ( ret = f ) != 0 ) goto cleanup
+#define MPI_CHK(f) do { if( ( ret = f ) != 0 ) goto cleanup; } while( 0 )
 
 /*
  * Maximum size MPIs are allowed to grow to in number of limbs.
diff --git a/library/bignum.c b/library/bignum.c
index 7fac5fa..c81ee5b 100644
--- a/library/bignum.c
+++ b/library/bignum.c
@@ -1923,6 +1923,7 @@
 {
     int ret;
     size_t k, n;
+    t_uint r;
     mpi Y;
 
     if( nbits < 3 || nbits > POLARSSL_MPI_MAX_BITS )
@@ -1952,7 +1953,19 @@
     }
     else
     {
-        MPI_CHK( mpi_sub_int( &Y, X, 1 ) );
+        /*
+         * An necessary condition for Y and X = 2Y + 1 to be prime
+         * is X = 2 mod 3 (which is equivalent to Y = 2 mod 3).
+         * Make sure it is satisfied, while keeping X = 3 mod 4
+         */
+        MPI_CHK( mpi_mod_int( &r, X, 3 ) );
+        if( r == 0 )
+            MPI_CHK( mpi_add_int( X, X, 8 ) );
+        else if( r == 1 )
+            MPI_CHK( mpi_add_int( X, X, 4 ) );
+
+        /* Set Y = (X-1) / 2, which is X / 2 because X is odd */
+        MPI_CHK( mpi_copy( &Y, X ) );
         MPI_CHK( mpi_shift_r( &Y, 1 ) );
 
         while( 1 )
@@ -1969,9 +1982,13 @@
             if( ret != POLARSSL_ERR_MPI_NOT_ACCEPTABLE )
                 goto cleanup;
 
-            /* Keep X = 3 mod 4 */
-            MPI_CHK( mpi_add_int(  X,  X, 4 ) );
-            MPI_CHK( mpi_add_int( &Y, &Y, 2 ) );
+            /*
+             * Next candidates. We want to preserve
+             * Y = 1 mod 2 and Y = 2 mod 3 (eq X = 3 mod 4 and X = 2 mod 3)
+             * so up Y by 6 and X by 12.
+             */
+            MPI_CHK( mpi_add_int(  X,  X, 12 ) );
+            MPI_CHK( mpi_add_int( &Y, &Y, 6  ) );
         }
     }