Merge pull request #6138 from Zaya-dyno/validation_remove_change_key_agree

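Remove the DHM_VALIDATE*/RSA_VALIDATE* parameter-validation macros from the DHM (key agreement) and RSA modules. The RSA sign/verify functions keep an explicit check that returns MBEDTLS_ERR_RSA_BAD_INPUT_DATA when a hash is required but NULL; the patch adds tests covering that check.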
diff --git a/library/dhm.c b/library/dhm.c
index 1e95bda..1ba5339 100644
--- a/library/dhm.c
+++ b/library/dhm.c
@@ -55,11 +55,6 @@
 
 #if !defined(MBEDTLS_DHM_ALT)
 
-#define DHM_VALIDATE_RET( cond )    \
-    MBEDTLS_INTERNAL_VALIDATE_RET( cond, MBEDTLS_ERR_DHM_BAD_INPUT_DATA )
-#define DHM_VALIDATE( cond )        \
-    MBEDTLS_INTERNAL_VALIDATE( cond )
-
 /*
  * helper to validate the mbedtls_mpi size and import it
  */
@@ -120,7 +115,6 @@
 
 void mbedtls_dhm_init( mbedtls_dhm_context *ctx )
 {
-    DHM_VALIDATE( ctx != NULL );
     memset( ctx, 0, sizeof( mbedtls_dhm_context ) );
 }
 
@@ -173,9 +167,6 @@
                      const unsigned char *end )
 {
     int ret = MBEDTLS_ERR_ERROR_CORRUPTION_DETECTED;
-    DHM_VALIDATE_RET( ctx != NULL );
-    DHM_VALIDATE_RET( p != NULL && *p != NULL );
-    DHM_VALIDATE_RET( end != NULL );
 
     if( ( ret = dhm_read_bignum( &ctx->P,  p, end ) ) != 0 ||
         ( ret = dhm_read_bignum( &ctx->G,  p, end ) ) != 0 ||
@@ -252,10 +243,6 @@
     int ret;
     size_t n1, n2, n3;
     unsigned char *p;
-    DHM_VALIDATE_RET( ctx != NULL );
-    DHM_VALIDATE_RET( output != NULL );
-    DHM_VALIDATE_RET( olen != NULL );
-    DHM_VALIDATE_RET( f_rng != NULL );
 
     ret = dhm_make_common( ctx, x_size, f_rng, p_rng );
     if( ret != 0 )
@@ -300,9 +287,6 @@
                            const mbedtls_mpi *G )
 {
     int ret = MBEDTLS_ERR_ERROR_CORRUPTION_DETECTED;
-    DHM_VALIDATE_RET( ctx != NULL );
-    DHM_VALIDATE_RET( P != NULL );
-    DHM_VALIDATE_RET( G != NULL );
 
     if( ( ret = mbedtls_mpi_copy( &ctx->P, P ) ) != 0 ||
         ( ret = mbedtls_mpi_copy( &ctx->G, G ) ) != 0 )
@@ -320,8 +304,6 @@
                      const unsigned char *input, size_t ilen )
 {
     int ret = MBEDTLS_ERR_ERROR_CORRUPTION_DETECTED;
-    DHM_VALIDATE_RET( ctx != NULL );
-    DHM_VALIDATE_RET( input != NULL );
 
     if( ilen < 1 || ilen > mbedtls_dhm_get_len( ctx ) )
         return( MBEDTLS_ERR_DHM_BAD_INPUT_DATA );
@@ -341,9 +323,6 @@
                      void *p_rng )
 {
     int ret;
-    DHM_VALIDATE_RET( ctx != NULL );
-    DHM_VALIDATE_RET( output != NULL );
-    DHM_VALIDATE_RET( f_rng != NULL );
 
     if( olen < 1 || olen > mbedtls_dhm_get_len( ctx ) )
         return( MBEDTLS_ERR_DHM_BAD_INPUT_DATA );
@@ -440,9 +419,6 @@
 {
     int ret = MBEDTLS_ERR_ERROR_CORRUPTION_DETECTED;
     mbedtls_mpi GYb;
-    DHM_VALIDATE_RET( ctx != NULL );
-    DHM_VALIDATE_RET( output != NULL );
-    DHM_VALIDATE_RET( olen != NULL );
 
     if( f_rng == NULL )
         return( MBEDTLS_ERR_DHM_BAD_INPUT_DATA );
@@ -518,9 +494,6 @@
     mbedtls_pem_context pem;
 #endif /* MBEDTLS_PEM_PARSE_C */
 
-    DHM_VALIDATE_RET( dhm != NULL );
-    DHM_VALIDATE_RET( dhmin != NULL );
-
 #if defined(MBEDTLS_PEM_PARSE_C)
     mbedtls_pem_init( &pem );
 
@@ -667,8 +640,6 @@
     int ret = MBEDTLS_ERR_ERROR_CORRUPTION_DETECTED;
     size_t n;
     unsigned char *buf;
-    DHM_VALIDATE_RET( dhm != NULL );
-    DHM_VALIDATE_RET( path != NULL );
 
     if( ( ret = load_file( path, &buf, &n ) ) != 0 )
         return( ret );
diff --git a/library/rsa.c b/library/rsa.c
index e597555..4df240a 100644
--- a/library/rsa.c
+++ b/library/rsa.c
@@ -74,19 +74,12 @@
 
 #if !defined(MBEDTLS_RSA_ALT)
 
-/* Parameter validation macros */
-#define RSA_VALIDATE_RET( cond )                                       \
-    MBEDTLS_INTERNAL_VALIDATE_RET( cond, MBEDTLS_ERR_RSA_BAD_INPUT_DATA )
-#define RSA_VALIDATE( cond )                                           \
-    MBEDTLS_INTERNAL_VALIDATE( cond )
-
 int mbedtls_rsa_import( mbedtls_rsa_context *ctx,
                         const mbedtls_mpi *N,
                         const mbedtls_mpi *P, const mbedtls_mpi *Q,
                         const mbedtls_mpi *D, const mbedtls_mpi *E )
 {
     int ret = MBEDTLS_ERR_ERROR_CORRUPTION_DETECTED;
-    RSA_VALIDATE_RET( ctx != NULL );
 
     if( ( N != NULL && ( ret = mbedtls_mpi_copy( &ctx->N, N ) ) != 0 ) ||
         ( P != NULL && ( ret = mbedtls_mpi_copy( &ctx->P, P ) ) != 0 ) ||
@@ -111,7 +104,6 @@
                             unsigned char const *E, size_t E_len )
 {
     int ret = 0;
-    RSA_VALIDATE_RET( ctx != NULL );
 
     if( N != NULL )
     {
@@ -241,8 +233,6 @@
 #endif
     int n_missing, pq_missing, d_missing, is_pub, is_priv;
 
-    RSA_VALIDATE_RET( ctx != NULL );
-
     have_N = ( mbedtls_mpi_cmp_int( &ctx->N, 0 ) != 0 );
     have_P = ( mbedtls_mpi_cmp_int( &ctx->P, 0 ) != 0 );
     have_Q = ( mbedtls_mpi_cmp_int( &ctx->Q, 0 ) != 0 );
@@ -345,7 +335,6 @@
 {
     int ret = 0;
     int is_priv;
-    RSA_VALIDATE_RET( ctx != NULL );
 
     /* Check if key is private or public */
     is_priv =
@@ -390,7 +379,6 @@
 {
     int ret = MBEDTLS_ERR_ERROR_CORRUPTION_DETECTED;
     int is_priv;
-    RSA_VALIDATE_RET( ctx != NULL );
 
     /* Check if key is private or public */
     is_priv =
@@ -434,7 +422,6 @@
 {
     int ret = MBEDTLS_ERR_ERROR_CORRUPTION_DETECTED;
     int is_priv;
-    RSA_VALIDATE_RET( ctx != NULL );
 
     /* Check if key is private or public */
     is_priv =
@@ -471,8 +458,6 @@
  */
 void mbedtls_rsa_init( mbedtls_rsa_context *ctx )
 {
-    RSA_VALIDATE( ctx != NULL );
-
     memset( ctx, 0, sizeof( mbedtls_rsa_context ) );
 
     ctx->padding = MBEDTLS_RSA_PKCS_V15;
@@ -549,8 +534,6 @@
     int ret = MBEDTLS_ERR_ERROR_CORRUPTION_DETECTED;
     mbedtls_mpi H, G, L;
     int prime_quality = 0;
-    RSA_VALIDATE_RET( ctx != NULL );
-    RSA_VALIDATE_RET( f_rng != NULL );
 
     /*
      * If the modulus is 1024 bit long or shorter, then the security strength of
@@ -663,8 +646,6 @@
  */
 int mbedtls_rsa_check_pubkey( const mbedtls_rsa_context *ctx )
 {
-    RSA_VALIDATE_RET( ctx != NULL );
-
     if( rsa_check_context( ctx, 0 /* public */, 0 /* no blinding */ ) != 0 )
         return( MBEDTLS_ERR_RSA_KEY_CHECK_FAILED );
 
@@ -688,8 +669,6 @@
  */
 int mbedtls_rsa_check_privkey( const mbedtls_rsa_context *ctx )
 {
-    RSA_VALIDATE_RET( ctx != NULL );
-
     if( mbedtls_rsa_check_pubkey( ctx ) != 0 ||
         rsa_check_context( ctx, 1 /* private */, 1 /* blinding */ ) != 0 )
     {
@@ -719,9 +698,6 @@
 int mbedtls_rsa_check_pub_priv( const mbedtls_rsa_context *pub,
                                 const mbedtls_rsa_context *prv )
 {
-    RSA_VALIDATE_RET( pub != NULL );
-    RSA_VALIDATE_RET( prv != NULL );
-
     if( mbedtls_rsa_check_pubkey( pub )  != 0 ||
         mbedtls_rsa_check_privkey( prv ) != 0 )
     {
@@ -747,9 +723,6 @@
     int ret = MBEDTLS_ERR_ERROR_CORRUPTION_DETECTED;
     size_t olen;
     mbedtls_mpi T;
-    RSA_VALIDATE_RET( ctx != NULL );
-    RSA_VALIDATE_RET( input != NULL );
-    RSA_VALIDATE_RET( output != NULL );
 
     if( rsa_check_context( ctx, 0 /* public */, 0 /* no blinding */ ) )
         return( MBEDTLS_ERR_RSA_BAD_INPUT_DATA );
@@ -917,10 +890,6 @@
      * checked result; should be the same in the end. */
     mbedtls_mpi I, C;
 
-    RSA_VALIDATE_RET( ctx != NULL );
-    RSA_VALIDATE_RET( input  != NULL );
-    RSA_VALIDATE_RET( output != NULL );
-
     if( f_rng == NULL )
         return( MBEDTLS_ERR_RSA_BAD_INPUT_DATA );
 
@@ -1308,11 +1277,6 @@
     unsigned char *p = output;
     unsigned int hlen;
 
-    RSA_VALIDATE_RET( ctx != NULL );
-    RSA_VALIDATE_RET( output != NULL );
-    RSA_VALIDATE_RET( ilen == 0 || input != NULL );
-    RSA_VALIDATE_RET( label_len == 0 || label != NULL );
-
     if( f_rng == NULL )
         return( MBEDTLS_ERR_RSA_BAD_INPUT_DATA );
 
@@ -1374,10 +1338,6 @@
     int ret = MBEDTLS_ERR_ERROR_CORRUPTION_DETECTED;
     unsigned char *p = output;
 
-    RSA_VALIDATE_RET( ctx != NULL );
-    RSA_VALIDATE_RET( output != NULL );
-    RSA_VALIDATE_RET( ilen == 0 || input != NULL );
-
     olen = ctx->len;
 
     /* first comparison checks for overflow */
@@ -1426,10 +1386,6 @@
                        const unsigned char *input,
                        unsigned char *output )
 {
-    RSA_VALIDATE_RET( ctx != NULL );
-    RSA_VALIDATE_RET( output != NULL );
-    RSA_VALIDATE_RET( ilen == 0 || input != NULL );
-
     switch( ctx->padding )
     {
 #if defined(MBEDTLS_PKCS1_V15)
@@ -1469,12 +1425,6 @@
     unsigned char lhash[MBEDTLS_HASH_MAX_SIZE];
     unsigned int hlen;
 
-    RSA_VALIDATE_RET( ctx != NULL );
-    RSA_VALIDATE_RET( output_max_len == 0 || output != NULL );
-    RSA_VALIDATE_RET( label_len == 0 || label != NULL );
-    RSA_VALIDATE_RET( input != NULL );
-    RSA_VALIDATE_RET( olen != NULL );
-
     /*
      * Parameters sanity checks
      */
@@ -1595,11 +1545,6 @@
     size_t ilen;
     unsigned char buf[MBEDTLS_MPI_MAX_SIZE];
 
-    RSA_VALIDATE_RET( ctx != NULL );
-    RSA_VALIDATE_RET( output_max_len == 0 || output != NULL );
-    RSA_VALIDATE_RET( input != NULL );
-    RSA_VALIDATE_RET( olen != NULL );
-
     ilen = ctx->len;
 
     if( ctx->padding != MBEDTLS_RSA_PKCS_V15 )
@@ -1634,11 +1579,6 @@
                        unsigned char *output,
                        size_t output_max_len)
 {
-    RSA_VALIDATE_RET( ctx != NULL );
-    RSA_VALIDATE_RET( output_max_len == 0 || output != NULL );
-    RSA_VALIDATE_RET( input != NULL );
-    RSA_VALIDATE_RET( olen != NULL );
-
     switch( ctx->padding )
     {
 #if defined(MBEDTLS_PKCS1_V15)
@@ -1676,11 +1616,8 @@
     int ret = MBEDTLS_ERR_ERROR_CORRUPTION_DETECTED;
     size_t msb;
 
-    RSA_VALIDATE_RET( ctx != NULL );
-    RSA_VALIDATE_RET( ( md_alg  == MBEDTLS_MD_NONE &&
-                        hashlen == 0 ) ||
-                      hash != NULL );
-    RSA_VALIDATE_RET( sig != NULL );
+    if( ( md_alg != MBEDTLS_MD_NONE || hashlen != 0 ) && hash == NULL )
+        return MBEDTLS_ERR_RSA_BAD_INPUT_DATA;
 
     if( ctx->padding != MBEDTLS_RSA_PKCS_V21 )
         return( MBEDTLS_ERR_RSA_BAD_INPUT_DATA );
@@ -1952,11 +1889,8 @@
     int ret = MBEDTLS_ERR_ERROR_CORRUPTION_DETECTED;
     unsigned char *sig_try = NULL, *verif = NULL;
 
-    RSA_VALIDATE_RET( ctx != NULL );
-    RSA_VALIDATE_RET( ( md_alg  == MBEDTLS_MD_NONE &&
-                        hashlen == 0 ) ||
-                      hash != NULL );
-    RSA_VALIDATE_RET( sig != NULL );
+    if( ( md_alg != MBEDTLS_MD_NONE || hashlen != 0 ) && hash == NULL )
+        return MBEDTLS_ERR_RSA_BAD_INPUT_DATA;
 
     if( ctx->padding != MBEDTLS_RSA_PKCS_V15 )
         return( MBEDTLS_ERR_RSA_BAD_INPUT_DATA );
@@ -2020,11 +1954,8 @@
                     const unsigned char *hash,
                     unsigned char *sig )
 {
-    RSA_VALIDATE_RET( ctx != NULL );
-    RSA_VALIDATE_RET( ( md_alg  == MBEDTLS_MD_NONE &&
-                        hashlen == 0 ) ||
-                      hash != NULL );
-    RSA_VALIDATE_RET( sig != NULL );
+    if( ( md_alg != MBEDTLS_MD_NONE || hashlen != 0 ) && hash == NULL )
+        return MBEDTLS_ERR_RSA_BAD_INPUT_DATA;
 
     switch( ctx->padding )
     {
@@ -2066,11 +1997,8 @@
     size_t observed_salt_len, msb;
     unsigned char buf[MBEDTLS_MPI_MAX_SIZE] = {0};
 
-    RSA_VALIDATE_RET( ctx != NULL );
-    RSA_VALIDATE_RET( sig != NULL );
-    RSA_VALIDATE_RET( ( md_alg  == MBEDTLS_MD_NONE &&
-                        hashlen == 0 ) ||
-                      hash != NULL );
+    if( ( md_alg != MBEDTLS_MD_NONE || hashlen != 0 ) && hash == NULL )
+        return MBEDTLS_ERR_RSA_BAD_INPUT_DATA;
 
     siglen = ctx->len;
 
@@ -2165,11 +2093,8 @@
                            const unsigned char *sig )
 {
     mbedtls_md_type_t mgf1_hash_id;
-    RSA_VALIDATE_RET( ctx != NULL );
-    RSA_VALIDATE_RET( sig != NULL );
-    RSA_VALIDATE_RET( ( md_alg  == MBEDTLS_MD_NONE &&
-                        hashlen == 0 ) ||
-                      hash != NULL );
+    if( ( md_alg != MBEDTLS_MD_NONE || hashlen != 0 ) && hash == NULL )
+        return MBEDTLS_ERR_RSA_BAD_INPUT_DATA;
 
     mgf1_hash_id = ( ctx->hash_id != MBEDTLS_MD_NONE )
                              ? (mbedtls_md_type_t) ctx->hash_id
@@ -2198,11 +2123,8 @@
     size_t sig_len;
     unsigned char *encoded = NULL, *encoded_expected = NULL;
 
-    RSA_VALIDATE_RET( ctx != NULL );
-    RSA_VALIDATE_RET( sig != NULL );
-    RSA_VALIDATE_RET( ( md_alg  == MBEDTLS_MD_NONE &&
-                        hashlen == 0 ) ||
-                      hash != NULL );
+    if( ( md_alg != MBEDTLS_MD_NONE || hashlen != 0 ) && hash == NULL )
+        return MBEDTLS_ERR_RSA_BAD_INPUT_DATA;
 
     sig_len = ctx->len;
 
@@ -2267,11 +2189,8 @@
                       const unsigned char *hash,
                       const unsigned char *sig )
 {
-    RSA_VALIDATE_RET( ctx != NULL );
-    RSA_VALIDATE_RET( sig != NULL );
-    RSA_VALIDATE_RET( ( md_alg  == MBEDTLS_MD_NONE &&
-                        hashlen == 0 ) ||
-                      hash != NULL );
+    if( ( md_alg != MBEDTLS_MD_NONE || hashlen != 0 ) && hash == NULL )
+        return MBEDTLS_ERR_RSA_BAD_INPUT_DATA;
 
     switch( ctx->padding )
     {
@@ -2298,8 +2217,6 @@
 int mbedtls_rsa_copy( mbedtls_rsa_context *dst, const mbedtls_rsa_context *src )
 {
     int ret = MBEDTLS_ERR_ERROR_CORRUPTION_DETECTED;
-    RSA_VALIDATE_RET( dst != NULL );
-    RSA_VALIDATE_RET( src != NULL );
 
     dst->len = src->len;
 
diff --git a/tests/suites/test_suite_rsa.function b/tests/suites/test_suite_rsa.function
index a866d43..65731ed 100644
--- a/tests/suites/test_suite_rsa.function
+++ b/tests/suites/test_suite_rsa.function
@@ -16,6 +16,8 @@
     mbedtls_rsa_context ctx;
     const int invalid_padding = 42;
     const int invalid_hash_id = 0xff;
+    unsigned char buf[] = {0x00,0x01,0x02,0x03,0x04,0x05};
+    size_t buf_len = sizeof( buf );
 
     mbedtls_rsa_init( &ctx );
 
@@ -29,6 +31,28 @@
                                          invalid_hash_id ),
                 MBEDTLS_ERR_RSA_INVALID_PADDING );
 
+    TEST_EQUAL( mbedtls_rsa_pkcs1_sign(&ctx, NULL,
+                                       NULL, MBEDTLS_MD_NONE,
+                                       buf_len,
+                                       NULL, buf),
+                MBEDTLS_ERR_RSA_BAD_INPUT_DATA );
+
+    TEST_EQUAL( mbedtls_rsa_pkcs1_sign(&ctx, NULL,
+                                       NULL, MBEDTLS_MD_SHA256,
+                                       0,
+                                       NULL, buf),
+                MBEDTLS_ERR_RSA_BAD_INPUT_DATA );
+
+    TEST_EQUAL( mbedtls_rsa_pkcs1_verify(&ctx, MBEDTLS_MD_NONE,
+                                         buf_len,
+                                         NULL, buf),
+                MBEDTLS_ERR_RSA_BAD_INPUT_DATA );
+
+    TEST_EQUAL( mbedtls_rsa_pkcs1_verify(&ctx, MBEDTLS_MD_SHA256,
+                                         0,
+                                         NULL, buf),
+                MBEDTLS_ERR_RSA_BAD_INPUT_DATA );
+
 #if !defined(MBEDTLS_PKCS1_V15)
     TEST_EQUAL( mbedtls_rsa_set_padding( &ctx,
                                          MBEDTLS_RSA_PKCS_V15,
@@ -36,6 +60,32 @@
                 MBEDTLS_ERR_RSA_INVALID_PADDING );
 #endif
 
+#if defined(MBEDTLS_PKCS1_V15)
+    TEST_EQUAL( mbedtls_rsa_rsassa_pkcs1_v15_sign(&ctx, NULL,
+                                              NULL, MBEDTLS_MD_NONE,
+                                              buf_len,
+                                              NULL, buf),
+                MBEDTLS_ERR_RSA_BAD_INPUT_DATA );
+
+    TEST_EQUAL( mbedtls_rsa_rsassa_pkcs1_v15_sign(&ctx, NULL,
+                                              NULL, MBEDTLS_MD_SHA256,
+                                              0,
+                                              NULL, buf),
+                MBEDTLS_ERR_RSA_BAD_INPUT_DATA );
+
+    TEST_EQUAL( mbedtls_rsa_rsassa_pkcs1_v15_verify(&ctx, MBEDTLS_MD_NONE,
+                                                buf_len,
+                                                NULL, buf),
+                MBEDTLS_ERR_RSA_BAD_INPUT_DATA );
+
+    TEST_EQUAL( mbedtls_rsa_rsassa_pkcs1_v15_verify(&ctx, MBEDTLS_MD_SHA256,
+                                                0,
+                                                NULL, buf),
+                MBEDTLS_ERR_RSA_BAD_INPUT_DATA );
+
+
+#endif
+
 #if !defined(MBEDTLS_PKCS1_V21)
     TEST_EQUAL( mbedtls_rsa_set_padding( &ctx,
                                          MBEDTLS_RSA_PKCS_V21,
@@ -43,6 +93,42 @@
                 MBEDTLS_ERR_RSA_INVALID_PADDING );
 #endif
 
+#if defined(MBEDTLS_PKCS1_V21)
+    TEST_EQUAL( mbedtls_rsa_rsassa_pss_sign_ext(&ctx, NULL, NULL,
+                                    MBEDTLS_MD_NONE, buf_len,
+                                    NULL, buf_len,
+                                    buf ),
+                MBEDTLS_ERR_RSA_BAD_INPUT_DATA );
+
+    TEST_EQUAL( mbedtls_rsa_rsassa_pss_sign_ext(&ctx, NULL, NULL,
+                                    MBEDTLS_MD_SHA256, 0,
+                                    NULL, buf_len,
+                                    buf ),
+                MBEDTLS_ERR_RSA_BAD_INPUT_DATA );
+
+    TEST_EQUAL( mbedtls_rsa_rsassa_pss_verify_ext(&ctx, MBEDTLS_MD_NONE,
+                                                  buf_len, NULL,
+                                                  MBEDTLS_MD_NONE,
+                                                  buf_len, buf),
+                MBEDTLS_ERR_RSA_BAD_INPUT_DATA );
+
+    TEST_EQUAL( mbedtls_rsa_rsassa_pss_verify_ext(&ctx, MBEDTLS_MD_SHA256,
+                                                  0, NULL,
+                                                  MBEDTLS_MD_NONE,
+                                                  buf_len, buf),
+                MBEDTLS_ERR_RSA_BAD_INPUT_DATA );
+
+    TEST_EQUAL( mbedtls_rsa_rsassa_pss_verify(&ctx, MBEDTLS_MD_NONE,
+                                              buf_len,
+                                              NULL, buf),
+                MBEDTLS_ERR_RSA_BAD_INPUT_DATA );
+
+    TEST_EQUAL( mbedtls_rsa_rsassa_pss_verify(&ctx, MBEDTLS_MD_SHA256,
+                                              0,
+                                              NULL, buf),
+                MBEDTLS_ERR_RSA_BAD_INPUT_DATA );
+#endif
+
 exit:
     mbedtls_rsa_free( &ctx );
 }