Make RNG parameters mandatory in DHM functions

Signed-off-by: Manuel Pégourié-Gonnard <manuel.pegourie-gonnard@arm.com>
diff --git a/include/mbedtls/dhm.h b/include/mbedtls/dhm.h
index e8c8a82..850813e 100644
--- a/include/mbedtls/dhm.h
+++ b/include/mbedtls/dhm.h
@@ -279,10 +279,10 @@
  * \param output_size   The size of the destination buffer. This must be at
  *                      least the size of \c ctx->len (the size of \c P).
  * \param olen          On exit, holds the actual number of Bytes written.
- * \param f_rng         The RNG function, for blinding purposes. This may
- *                      b \c NULL if blinding isn't needed.
- * \param p_rng         The RNG context. This may be \c NULL if \p f_rng
- *                      doesn't need a context argument.
+ * \param f_rng         The RNG function. Must not be \c NULL. Used for
+ *                      blinding.
+ * \param p_rng         The RNG context to be passed to \p f_rng. This may be
+ *                      \c NULL if \p f_rng doesn't need a context parameter.
  *
  * \return              \c 0 on success.
  * \return              An \c MBEDTLS_ERR_DHM_XXX error code on failure.
diff --git a/library/dhm.c b/library/dhm.c
index e88f3a2..29ce755 100644
--- a/library/dhm.c
+++ b/library/dhm.c
@@ -444,6 +444,9 @@
     DHM_VALIDATE_RET( output != NULL );
     DHM_VALIDATE_RET( olen != NULL );
 
+    if( f_rng == NULL )
+        return( MBEDTLS_ERR_DHM_BAD_INPUT_DATA );
+
     if( output_size < mbedtls_dhm_get_len( ctx ) )
         return( MBEDTLS_ERR_DHM_BAD_INPUT_DATA );
 
@@ -453,25 +456,17 @@
     mbedtls_mpi_init( &GYb );
 
     /* Blind peer's value */
-    if( f_rng != NULL )
-    {
-        MBEDTLS_MPI_CHK( dhm_update_blinding( ctx, f_rng, p_rng ) );
-        MBEDTLS_MPI_CHK( mbedtls_mpi_mul_mpi( &GYb, &ctx->GY, &ctx->Vi ) );
-        MBEDTLS_MPI_CHK( mbedtls_mpi_mod_mpi( &GYb, &GYb, &ctx->P ) );
-    }
-    else
-        MBEDTLS_MPI_CHK( mbedtls_mpi_copy( &GYb, &ctx->GY ) );
+    MBEDTLS_MPI_CHK( dhm_update_blinding( ctx, f_rng, p_rng ) );
+    MBEDTLS_MPI_CHK( mbedtls_mpi_mul_mpi( &GYb, &ctx->GY, &ctx->Vi ) );
+    MBEDTLS_MPI_CHK( mbedtls_mpi_mod_mpi( &GYb, &GYb, &ctx->P ) );
 
     /* Do modular exponentiation */
     MBEDTLS_MPI_CHK( mbedtls_mpi_exp_mod( &ctx->K, &GYb, &ctx->X,
                           &ctx->P, &ctx->RP ) );
 
     /* Unblind secret value */
-    if( f_rng != NULL )
-    {
-        MBEDTLS_MPI_CHK( mbedtls_mpi_mul_mpi( &ctx->K, &ctx->K, &ctx->Vf ) );
-        MBEDTLS_MPI_CHK( mbedtls_mpi_mod_mpi( &ctx->K, &ctx->K, &ctx->P ) );
-    }
+    MBEDTLS_MPI_CHK( mbedtls_mpi_mul_mpi( &ctx->K, &ctx->K, &ctx->Vf ) );
+    MBEDTLS_MPI_CHK( mbedtls_mpi_mod_mpi( &ctx->K, &ctx->K, &ctx->P ) );
 
     /* Output the secret without any leading zero byte. This is mandatory
      * for TLS per RFC 5246 §8.1.2. */
diff --git a/tests/suites/test_suite_dhm.function b/tests/suites/test_suite_dhm.function
index 62e634a..5286bc7 100644
--- a/tests/suites/test_suite_dhm.function
+++ b/tests/suites/test_suite_dhm.function
@@ -150,7 +150,10 @@
                                           &sec_srv_len,
                                           &mbedtls_test_rnd_pseudo_rand,
                                           &rnd_info ) == 0 );
-    TEST_ASSERT( mbedtls_dhm_calc_secret( &ctx_cli, sec_cli, sizeof( sec_cli ), &sec_cli_len, NULL, NULL ) == 0 );
+    TEST_ASSERT( mbedtls_dhm_calc_secret( &ctx_cli, sec_cli, sizeof( sec_cli ),
+                                          &sec_cli_len,
+                                          &mbedtls_test_rnd_pseudo_rand,
+                                          &rnd_info ) == 0 );
 
     TEST_ASSERT( sec_srv_len == sec_cli_len );
     TEST_ASSERT( sec_srv_len != 0 );
@@ -206,7 +209,10 @@
                                           &sec_srv_len,
                                           &mbedtls_test_rnd_pseudo_rand,
                                           &rnd_info ) == 0 );
-    TEST_ASSERT( mbedtls_dhm_calc_secret( &ctx_cli, sec_cli, sizeof( sec_cli ), &sec_cli_len, NULL, NULL ) == 0 );
+    TEST_ASSERT( mbedtls_dhm_calc_secret( &ctx_cli, sec_cli, sizeof( sec_cli ),
+                                          &sec_cli_len,
+                                          &mbedtls_test_rnd_pseudo_rand,
+                                          &rnd_info ) == 0 );
 
     TEST_ASSERT( sec_srv_len == sec_cli_len );
     TEST_ASSERT( sec_srv_len != 0 );