Unify sanity checks for RSA private and public keys
diff --git a/library/rsa.c b/library/rsa.c
index efc1489..493cd1c 100644
--- a/library/rsa.c
+++ b/library/rsa.c
@@ -139,7 +139,7 @@
     uint16_t attempt;  /* Number of current attempt  */
     uint16_t iter;     /* Number of squares computed in the current attempt */
 
-    uint16_t order;       /* Order of 2 in DE - 1 */
+    uint16_t order;    /* Order of 2 in DE - 1 */
 
     mbedtls_mpi T;  /* Holds largest odd divisor of DE - 1 */
     mbedtls_mpi K;  /* During factorization attempts, stores a random integer
@@ -601,6 +601,89 @@
     return( 0 );
 }
 
+/*
+ * Check that the context fields are set so that the RSA primitives can run
+ * without error (private-key fields too if is_priv != 0). Returns 0 or
+ * MBEDTLS_ERR_RSA_BAD_INPUT_DATA; does *not* check parameter consistency.
+ */
+static int rsa_check_context( mbedtls_rsa_context const *ctx, int is_priv )
+{
+    if( ctx->len != mbedtls_mpi_size( &ctx->N ) )
+        return( MBEDTLS_ERR_RSA_BAD_INPUT_DATA );
+
+    /*
+     * 1. Modular exponentiation needs positive, odd moduli.
+     */
+
+    /* Modular exponentiation wrt. N is always used for
+     * RSA public key operations. */
+    if( mbedtls_mpi_cmp_int( &ctx->N, 0 ) <= 0 ||
+        mbedtls_mpi_get_bit( &ctx->N, 0 ) == 0  )
+    {
+        return( MBEDTLS_ERR_RSA_BAD_INPUT_DATA );
+    }
+
+#if !defined(MBEDTLS_RSA_NO_CRT)
+    /* Modular exponentiation for P and Q is only
+     * used for private key operations and if CRT
+     * is used. */
+    if( is_priv &&
+        ( mbedtls_mpi_cmp_int( &ctx->P, 0 ) <= 0 ||
+          mbedtls_mpi_get_bit( &ctx->P, 0 ) == 0 ||
+          mbedtls_mpi_cmp_int( &ctx->Q, 0 ) <= 0 ||
+          mbedtls_mpi_get_bit( &ctx->Q, 0 ) == 0  ) )
+    {
+        return( MBEDTLS_ERR_RSA_BAD_INPUT_DATA );
+    }
+#endif /* !MBEDTLS_RSA_NO_CRT */
+
+    /*
+     * 2. Exponents must be positive
+     */
+
+    /* Always need E for public key operations */
+    if( mbedtls_mpi_cmp_int( &ctx->E, 0 ) <= 0 )
+        return( MBEDTLS_ERR_RSA_BAD_INPUT_DATA );
+
+#if defined(MBEDTLS_RSA_NO_CRT)
+    /* Without CRT, private key operations use D as (unblinded)
+     * exponent; with CRT, they use DP and DQ instead. */
+    if( is_priv && mbedtls_mpi_cmp_int( &ctx->D, 0 ) <= 0 )
+        return( MBEDTLS_ERR_RSA_BAD_INPUT_DATA );
+#else
+    if( is_priv &&
+        ( mbedtls_mpi_cmp_int( &ctx->DP, 0 ) <= 0 ||
+          mbedtls_mpi_cmp_int( &ctx->DQ, 0 ) <= 0  ) )
+    {
+        return( MBEDTLS_ERR_RSA_BAD_INPUT_DATA );
+    }
+#endif /* MBEDTLS_RSA_NO_CRT */
+
+    /* Blinding shouldn't make exponents negative either,
+     * so check that P, Q >= 1 if that hasn't yet been
+     * done as part of 1 (which only runs with CRT). */
+#if defined(MBEDTLS_RSA_NO_CRT)
+    if( is_priv &&
+        ( mbedtls_mpi_cmp_int( &ctx->P, 0 ) <= 0 ||
+          mbedtls_mpi_cmp_int( &ctx->Q, 0 ) <= 0 ) )
+    {
+        return( MBEDTLS_ERR_RSA_BAD_INPUT_DATA );
+    }
+#endif /* MBEDTLS_RSA_NO_CRT */
+
+    /* It wouldn't lead to an error if it wasn't satisfied,
+     * but check for QP >= 1 nonetheless (CRT only). */
+#if !defined(MBEDTLS_RSA_NO_CRT)
+    if( is_priv &&
+        mbedtls_mpi_cmp_int( &ctx->QP, 0 ) <= 0 )
+    {
+        return( MBEDTLS_ERR_RSA_BAD_INPUT_DATA );
+    }
+#endif /* !MBEDTLS_RSA_NO_CRT */
+
+    return( 0 );
+}
+
 int mbedtls_rsa_complete( mbedtls_rsa_context *ctx )
 {
     int ret = 0;
@@ -686,21 +769,10 @@
 #endif /* MBEDTLS_RSA_NO_CRT */
 
     /*
-     * Step 3: Basic sanity check
+     * Step 3: Basic sanity checks
      */
 
-    if( is_priv )
-    {
-        if( ( ret = mbedtls_rsa_check_privkey( ctx ) ) != 0 )
-            return( ret );
-    }
-    else
-    {
-        if( ( ret = mbedtls_rsa_check_pubkey( ctx ) ) != 0 )
-            return( ret );
-    }
-
-    return( 0 );
+    return( rsa_check_context( ctx, is_priv ) );
 }
 
 int mbedtls_rsa_export_raw( const mbedtls_rsa_context *ctx,
@@ -960,20 +1032,8 @@
  */
 int mbedtls_rsa_check_pubkey( const mbedtls_rsa_context *ctx )
 {
-    if( mbedtls_mpi_cmp_int( &ctx->N, 0 ) == 0 ||
-        mbedtls_mpi_cmp_int( &ctx->E, 0 ) == 0 )
-    {
+    if( rsa_check_context( ctx, 0 /* public */ ) != 0 )
         return( MBEDTLS_ERR_RSA_KEY_CHECK_FAILED );
-    }
-
-    if( ctx->len != mbedtls_mpi_size( &ctx->N ) )
-        return( MBEDTLS_ERR_RSA_KEY_CHECK_FAILED );
-
-    if( mbedtls_mpi_get_bit( &ctx->N, 0 ) == 0 ||
-        mbedtls_mpi_get_bit( &ctx->E, 0 ) == 0 )
-    {
-        return( MBEDTLS_ERR_RSA_KEY_CHECK_FAILED );
-    }
 
     if( mbedtls_mpi_bitlen( &ctx->N ) < 128 ||
         mbedtls_mpi_bitlen( &ctx->N ) > MBEDTLS_MPI_MAX_BITS )
@@ -981,7 +1041,8 @@
         return( MBEDTLS_ERR_RSA_KEY_CHECK_FAILED );
     }
 
-    if( mbedtls_mpi_bitlen( &ctx->E ) < 2 ||
+    if( mbedtls_mpi_get_bit( &ctx->E, 0 ) == 0 ||
+        mbedtls_mpi_bitlen( &ctx->E )     < 2  ||
         mbedtls_mpi_cmp_mpi( &ctx->E, &ctx->N ) >= 0 )
     {
         return( MBEDTLS_ERR_RSA_KEY_CHECK_FAILED );
@@ -991,18 +1052,22 @@
 }
 
 /*
- * Check a private RSA key
+ * Check for the consistency of all fields in an RSA private key context
  */
 int mbedtls_rsa_check_privkey( const mbedtls_rsa_context *ctx )
 {
-    if( mbedtls_rsa_check_pubkey( ctx ) != 0 )
+    if( mbedtls_rsa_check_pubkey( ctx ) != 0 ||
+        rsa_check_context( ctx, 1 /* private */ ) != 0 )
+    {
         return( MBEDTLS_ERR_RSA_KEY_CHECK_FAILED );
+    }
 
     if( mbedtls_rsa_validate_params( &ctx->N, &ctx->P, &ctx->Q,
                                      &ctx->D, &ctx->E, NULL, NULL ) != 0 )
     {
         return( MBEDTLS_ERR_RSA_KEY_CHECK_FAILED );
     }
+
 #if !defined(MBEDTLS_RSA_NO_CRT)
     else if( mbedtls_rsa_validate_crt( &ctx->P, &ctx->Q, &ctx->D,
                                        &ctx->DP, &ctx->DQ, &ctx->QP ) != 0 )
@@ -1046,6 +1111,9 @@
     size_t olen;
     mbedtls_mpi T;
 
+    if( rsa_check_context( ctx, 0 /* public */ ) )
+        return( MBEDTLS_ERR_RSA_BAD_INPUT_DATA );
+
     mbedtls_mpi_init( &T );
 
 #if defined(MBEDTLS_THREADING_C)
@@ -1162,24 +1230,8 @@
     mbedtls_mpi *DQ = &ctx->DQ;
 #endif
 
-    /* Sanity-check that all relevant fields are at least set,
-     * but don't perform a full keycheck. */
-    if( mbedtls_mpi_cmp_int( &ctx->N, 0 ) == 0 ||
-        mbedtls_mpi_cmp_int( &ctx->P, 0 ) == 0 ||
-        mbedtls_mpi_cmp_int( &ctx->Q, 0 ) == 0 ||
-        mbedtls_mpi_cmp_int( &ctx->D, 0 ) == 0 ||
-        mbedtls_mpi_cmp_int( &ctx->E, 0 ) == 0 )
-    {
+    if( rsa_check_context( ctx, 1 /* private */ ) != 0 )
         return( MBEDTLS_ERR_RSA_BAD_INPUT_DATA );
-    }
-#if !defined(MBEDTLS_RSA_NO_CRT)
-    if( mbedtls_mpi_cmp_int( &ctx->DP, 0 ) == 0 ||
-        mbedtls_mpi_cmp_int( &ctx->DQ, 0 ) == 0 ||
-        mbedtls_mpi_cmp_int( &ctx->QP, 0 ) == 0 )
-    {
-        return( MBEDTLS_ERR_RSA_BAD_INPUT_DATA );
-    }
-#endif /* MBEDTLS_RSA_NO_CRT */
 
     mbedtls_mpi_init( &T ); mbedtls_mpi_init( &T1 ); mbedtls_mpi_init( &T2 );
     mbedtls_mpi_init( &P1 ); mbedtls_mpi_init( &Q1 ); mbedtls_mpi_init( &R );