Add tests for mod_mul

Signed-off-by: Gabor Mezei <gabor.mezei@arm.com>
diff --git a/tests/suites/test_suite_bignum_mod.function b/tests/suites/test_suite_bignum_mod.function
index 79f5134..7c407a8 100644
--- a/tests/suites/test_suite_bignum_mod.function
+++ b/tests/suites/test_suite_bignum_mod.function
@@ -2,6 +2,7 @@
 #include "mbedtls/bignum.h"
 #include "mbedtls/entropy.h"
 #include "bignum_mod.h"
+#include "bignum_mod_raw.h"
 #include "constant_time_internal.h"
 #include "test/constant_flow.h"
 
@@ -102,6 +103,191 @@
 
 /* BEGIN MERGE SLOT 2 */
 
+/* BEGIN_CASE */
+void mpi_mod_mul( char * input_A,
+                  char * input_B,
+                  char * input_N,
+                  char * result )
+{
+    /* Check A * B mod N via mbedtls_mpi_mod_mul() in Montgomery form against
+     * R, with X aliased to A or B, and A to B when the inputs are equal. */
+    mbedtls_mpi_uint *A = NULL;
+    mbedtls_mpi_uint *B = NULL;
+    mbedtls_mpi_uint *N = NULL;
+    mbedtls_mpi_uint *X = NULL;
+    mbedtls_mpi_uint *R = NULL;
+    size_t limbs_A;
+    size_t limbs_B;
+    size_t limbs_N;
+    size_t limbs_R;
+
+    /* Initialized here so exit never touches indeterminate residues. */
+    mbedtls_mpi_mod_residue rA = { NULL, 0 };
+    mbedtls_mpi_mod_residue rB = { NULL, 0 };
+    mbedtls_mpi_mod_residue rX = { NULL, 0 };
+    mbedtls_mpi_mod_modulus m;
+    mbedtls_mpi_mod_modulus_init( &m );
+
+    TEST_EQUAL( mbedtls_test_read_mpi_core( &A, &limbs_A, input_A ), 0 );
+    TEST_EQUAL( mbedtls_test_read_mpi_core( &B, &limbs_B, input_B ), 0 );
+    TEST_EQUAL( mbedtls_test_read_mpi_core( &N, &limbs_N, input_N ), 0 );
+    TEST_EQUAL( mbedtls_test_read_mpi_core( &R, &limbs_R, result  ), 0 );
+
+    const size_t limbs = limbs_N;
+    const size_t bytes = limbs * sizeof( mbedtls_mpi_uint );
+
+    /* All operands must be exactly as wide as the modulus. */
+    TEST_EQUAL( limbs_A, limbs );
+    TEST_EQUAL( limbs_B, limbs );
+    TEST_EQUAL( limbs_R, limbs );
+
+    ASSERT_ALLOC( X, limbs );
+
+    TEST_EQUAL( mbedtls_mpi_mod_modulus_setup(
+                        &m, N, limbs,
+                        MBEDTLS_MPI_MOD_REP_MONTGOMERY ), 0 );
+
+    TEST_EQUAL( mbedtls_mpi_mod_residue_setup( &rA, &m, A, limbs ), 0 );
+    TEST_EQUAL( mbedtls_mpi_mod_residue_setup( &rB, &m, B, limbs ), 0 );
+    TEST_EQUAL( mbedtls_mpi_mod_residue_setup( &rX, &m, X, limbs ), 0 );
+
+    /* Inputs go to Montgomery form; each product is converted back. */
+    TEST_EQUAL( mbedtls_mpi_mod_raw_to_mont_rep( rA.p, &m ), 0 );
+    TEST_EQUAL( mbedtls_mpi_mod_raw_to_mont_rep( rB.p, &m ), 0 );
+
+    TEST_EQUAL( mbedtls_mpi_mod_mul( &rX, &rA, &rB, &m ), 0 );
+    TEST_EQUAL( mbedtls_mpi_mod_raw_from_mont_rep( rX.p, &m ), 0 );
+    ASSERT_COMPARE( rX.p, bytes, R, bytes );
+
+    /* alias X to A */
+    memcpy( rX.p, rA.p, bytes );
+    TEST_EQUAL( mbedtls_mpi_mod_mul( &rX, &rX, &rB, &m ), 0 );
+    TEST_EQUAL( mbedtls_mpi_mod_raw_from_mont_rep( rX.p, &m ), 0 );
+    ASSERT_COMPARE( rX.p, bytes, R, bytes );
+
+    /* alias X to B */
+    memcpy( rX.p, rB.p, bytes );
+    TEST_EQUAL( mbedtls_mpi_mod_mul( &rX, &rA, &rX, &m ), 0 );
+    TEST_EQUAL( mbedtls_mpi_mod_raw_from_mont_rep( rX.p, &m ), 0 );
+    ASSERT_COMPARE( rX.p, bytes, R, bytes );
+
+    if( memcmp( rA.p, rB.p, bytes ) == 0 )
+    {
+        /* A == B: alias A and B */
+        TEST_EQUAL( mbedtls_mpi_mod_mul( &rX, &rA, &rA, &m ), 0 );
+        TEST_EQUAL( mbedtls_mpi_mod_raw_from_mont_rep( rX.p, &m ), 0 );
+        ASSERT_COMPARE( rX.p, bytes, R, bytes );
+
+        /* X, A, B all aliased together */
+        memcpy( rX.p, rA.p, bytes );
+        TEST_EQUAL( mbedtls_mpi_mod_mul( &rX, &rX, &rX, &m ), 0 );
+        TEST_EQUAL( mbedtls_mpi_mod_raw_from_mont_rep( rX.p, &m ), 0 );
+        ASSERT_COMPARE( rX.p, bytes, R, bytes );
+    }
+    else
+    {
+        /* A != B: the product must also hold as B * A */
+        TEST_EQUAL( mbedtls_mpi_mod_mul( &rX, &rB, &rA, &m ), 0 );
+        TEST_EQUAL( mbedtls_mpi_mod_raw_from_mont_rep( rX.p, &m ), 0 );
+        ASSERT_COMPARE( rX.p, bytes, R, bytes );
+
+        /* B * A: alias X to A */
+        memcpy( rX.p, rA.p, bytes );
+        TEST_EQUAL( mbedtls_mpi_mod_mul( &rX, &rB, &rX, &m ), 0 );
+        TEST_EQUAL( mbedtls_mpi_mod_raw_from_mont_rep( rX.p, &m ), 0 );
+        ASSERT_COMPARE( rX.p, bytes, R, bytes );
+
+        /* B * A: alias X to B */
+        memcpy( rX.p, rB.p, bytes );
+        TEST_EQUAL( mbedtls_mpi_mod_mul( &rX, &rX, &rA, &m ), 0 );
+        TEST_EQUAL( mbedtls_mpi_mod_raw_from_mont_rep( rX.p, &m ), 0 );
+        ASSERT_COMPARE( rX.p, bytes, R, bytes );
+    }
+
+exit:
+    mbedtls_mpi_mod_residue_release( &rA );
+    mbedtls_mpi_mod_residue_release( &rB );
+    mbedtls_mpi_mod_residue_release( &rX );
+    mbedtls_mpi_mod_modulus_free( &m );
+    mbedtls_free( A );
+    mbedtls_free( B );
+    mbedtls_free( N );
+    mbedtls_free( X );
+    mbedtls_free( R );
+}
+/* END_CASE */
+
+/* BEGIN_CASE */
+void mpi_mod_mul_neg( char * input_A,
+                      char * input_B,
+                      char * input_N,
+                      char * result,
+                      int exp_ret )
+{
+    /* mbedtls_mpi_mod_mul() must return exp_ret for operands whose limb
+     * counts come as-is from the inputs, and BAD_INPUT_DATA for a modulus
+     * that was never set up. `result` only supplies the size of X. */
+    mbedtls_mpi_uint *A = NULL;
+    mbedtls_mpi_uint *B = NULL;
+    mbedtls_mpi_uint *N = NULL;
+    mbedtls_mpi_uint *X = NULL;
+    mbedtls_mpi_uint *R = NULL;
+    size_t limbs_A = 0;
+    size_t limbs_B = 0;
+    size_t limbs_N = 0;
+    size_t limbs_X = 0;
+
+    mbedtls_mpi_mod_residue rA = { NULL, 0 };
+    mbedtls_mpi_mod_residue rB = { NULL, 0 };
+    mbedtls_mpi_mod_residue rX = { NULL, 0 };
+
+    mbedtls_mpi_mod_modulus m;
+    mbedtls_mpi_mod_modulus_init( &m );
+    mbedtls_mpi_mod_modulus fake_m;
+    mbedtls_mpi_mod_modulus_init( &fake_m );
+
+    TEST_EQUAL( mbedtls_test_read_mpi_core( &A, &limbs_A, input_A ), 0 );
+    TEST_EQUAL( mbedtls_test_read_mpi_core( &B, &limbs_B, input_B ), 0 );
+    TEST_EQUAL( mbedtls_test_read_mpi_core( &N, &limbs_N, input_N ), 0 );
+    TEST_EQUAL( mbedtls_test_read_mpi_core( &R, &limbs_X, result  ), 0 );
+
+    ASSERT_ALLOC( X, limbs_X );
+
+    TEST_EQUAL( mbedtls_mpi_mod_modulus_setup(
+                    &m, N, limbs_N,
+                    MBEDTLS_MPI_MOD_REP_MONTGOMERY ), 0 );
+
+    /* Wrap each buffer with its real size. mbedtls_mpi_mod_residue_setup()
+     * wants m.limbs limbs and mbedtls_mpi_mod_raw_to_mont_rep() takes no
+     * length, so either would run past an operand shorter than N. The
+     * values are never compared here, only the return codes. */
+    rA.p = A;
+    rA.limbs = limbs_A;
+    rB.p = B;
+    rB.limbs = limbs_B;
+    rX.p = X;
+    rX.limbs = limbs_X;
+
+    TEST_EQUAL( mbedtls_mpi_mod_mul( &rX, &rA, &rB, &m ), exp_ret );
+
+    /* A modulus that was only initialized, never set up, must be rejected */
+    TEST_EQUAL( mbedtls_mpi_mod_mul( &rX, &rA, &rB, &fake_m ),
+                MBEDTLS_ERR_MPI_BAD_INPUT_DATA );
+
+exit:
+    mbedtls_mpi_mod_residue_release( &rA );
+    mbedtls_mpi_mod_residue_release( &rB );
+    mbedtls_mpi_mod_residue_release( &rX );
+    mbedtls_mpi_mod_modulus_free( &m );
+    mbedtls_mpi_mod_modulus_free( &fake_m );
+    mbedtls_free( A );
+    mbedtls_free( B );
+    mbedtls_free( N );
+    mbedtls_free( X );
+    mbedtls_free( R );
+}
+/* END_CASE */
+
 /* END MERGE SLOT 2 */
 
 /* BEGIN MERGE SLOT 3 */