Make RNG parameters mandatory in DHM functions
Signed-off-by: Manuel Pégourié-Gonnard <manuel.pegourie-gonnard@arm.com>
diff --git a/include/mbedtls/dhm.h b/include/mbedtls/dhm.h
index e8c8a82..850813e 100644
--- a/include/mbedtls/dhm.h
+++ b/include/mbedtls/dhm.h
@@ -279,10 +279,10 @@
* \param output_size The size of the destination buffer. This must be at
* least the size of \c ctx->len (the size of \c P).
* \param olen On exit, holds the actual number of Bytes written.
- * \param f_rng The RNG function, for blinding purposes. This may
- * b \c NULL if blinding isn't needed.
- * \param p_rng The RNG context. This may be \c NULL if \p f_rng
- * doesn't need a context argument.
+ * \param f_rng The RNG function. Must not be \c NULL. Used for
+ * blinding.
+ * \param p_rng The RNG context to be passed to \p f_rng. This may be
+ * \c NULL if \p f_rng doesn't need a context parameter.
*
* \return \c 0 on success.
* \return An \c MBEDTLS_ERR_DHM_XXX error code on failure.
diff --git a/library/dhm.c b/library/dhm.c
index e88f3a2..29ce755 100644
--- a/library/dhm.c
+++ b/library/dhm.c
@@ -444,6 +444,9 @@
DHM_VALIDATE_RET( output != NULL );
DHM_VALIDATE_RET( olen != NULL );
+ if( f_rng == NULL )
+ return( MBEDTLS_ERR_DHM_BAD_INPUT_DATA );
+
if( output_size < mbedtls_dhm_get_len( ctx ) )
return( MBEDTLS_ERR_DHM_BAD_INPUT_DATA );
@@ -453,25 +456,17 @@
mbedtls_mpi_init( &GYb );
/* Blind peer's value */
- if( f_rng != NULL )
- {
- MBEDTLS_MPI_CHK( dhm_update_blinding( ctx, f_rng, p_rng ) );
- MBEDTLS_MPI_CHK( mbedtls_mpi_mul_mpi( &GYb, &ctx->GY, &ctx->Vi ) );
- MBEDTLS_MPI_CHK( mbedtls_mpi_mod_mpi( &GYb, &GYb, &ctx->P ) );
- }
- else
- MBEDTLS_MPI_CHK( mbedtls_mpi_copy( &GYb, &ctx->GY ) );
+ MBEDTLS_MPI_CHK( dhm_update_blinding( ctx, f_rng, p_rng ) );
+ MBEDTLS_MPI_CHK( mbedtls_mpi_mul_mpi( &GYb, &ctx->GY, &ctx->Vi ) );
+ MBEDTLS_MPI_CHK( mbedtls_mpi_mod_mpi( &GYb, &GYb, &ctx->P ) );
/* Do modular exponentiation */
MBEDTLS_MPI_CHK( mbedtls_mpi_exp_mod( &ctx->K, &GYb, &ctx->X,
&ctx->P, &ctx->RP ) );
/* Unblind secret value */
- if( f_rng != NULL )
- {
- MBEDTLS_MPI_CHK( mbedtls_mpi_mul_mpi( &ctx->K, &ctx->K, &ctx->Vf ) );
- MBEDTLS_MPI_CHK( mbedtls_mpi_mod_mpi( &ctx->K, &ctx->K, &ctx->P ) );
- }
+ MBEDTLS_MPI_CHK( mbedtls_mpi_mul_mpi( &ctx->K, &ctx->K, &ctx->Vf ) );
+ MBEDTLS_MPI_CHK( mbedtls_mpi_mod_mpi( &ctx->K, &ctx->K, &ctx->P ) );
/* Output the secret without any leading zero byte. This is mandatory
* for TLS per RFC 5246 §8.1.2. */
diff --git a/tests/suites/test_suite_dhm.function b/tests/suites/test_suite_dhm.function
index 62e634a..5286bc7 100644
--- a/tests/suites/test_suite_dhm.function
+++ b/tests/suites/test_suite_dhm.function
@@ -150,7 +150,10 @@
&sec_srv_len,
&mbedtls_test_rnd_pseudo_rand,
&rnd_info ) == 0 );
- TEST_ASSERT( mbedtls_dhm_calc_secret( &ctx_cli, sec_cli, sizeof( sec_cli ), &sec_cli_len, NULL, NULL ) == 0 );
+ TEST_ASSERT( mbedtls_dhm_calc_secret( &ctx_cli, sec_cli, sizeof( sec_cli ),
+ &sec_cli_len,
+ &mbedtls_test_rnd_pseudo_rand,
+ &rnd_info ) == 0 );
TEST_ASSERT( sec_srv_len == sec_cli_len );
TEST_ASSERT( sec_srv_len != 0 );
@@ -206,7 +209,10 @@
&sec_srv_len,
&mbedtls_test_rnd_pseudo_rand,
&rnd_info ) == 0 );
- TEST_ASSERT( mbedtls_dhm_calc_secret( &ctx_cli, sec_cli, sizeof( sec_cli ), &sec_cli_len, NULL, NULL ) == 0 );
+ TEST_ASSERT( mbedtls_dhm_calc_secret( &ctx_cli, sec_cli, sizeof( sec_cli ),
+ &sec_cli_len,
+ &mbedtls_test_rnd_pseudo_rand,
+ &rnd_info ) == 0 );
TEST_ASSERT( sec_srv_len == sec_cli_len );
TEST_ASSERT( sec_srv_len != 0 );