Merge pull request #6742 from gabor-mezei-arm/6022_bignum_mod_raw_mul

Bignum: Implement fixed width raw modular multiplication
diff --git a/library/bignum_mod_raw.c b/library/bignum_mod_raw.c
index c98a1c1..c0877bd 100644
--- a/library/bignum_mod_raw.c
+++ b/library/bignum_mod_raw.c
@@ -120,6 +120,16 @@
     (void) mbedtls_mpi_core_add_if( X, N->p, N->limbs, (unsigned) c );
 }
 
+void mbedtls_mpi_mod_raw_mul( mbedtls_mpi_uint *X,
+                              const mbedtls_mpi_uint *A,
+                              const mbedtls_mpi_uint *B,
+                              const mbedtls_mpi_mod_modulus *N,
+                              mbedtls_mpi_uint *T )
+{
+    mbedtls_mpi_core_montmul( X, A, B, N->limbs, N->p, N->limbs,
+                              N->rep.mont.mm, T );
+}
+
 /* END MERGE SLOT 2 */
 
 /* BEGIN MERGE SLOT 3 */
diff --git a/library/bignum_mod_raw.h b/library/bignum_mod_raw.h
index e6237b3..380f30b 100644
--- a/library/bignum_mod_raw.h
+++ b/library/bignum_mod_raw.h
@@ -215,6 +215,41 @@
                               const mbedtls_mpi_uint *B,
                               const mbedtls_mpi_mod_modulus *N );
 
+/** \brief  Multiply two MPIs, returning the residue modulo the specified
+ *          modulus.
+ *
+ * \note Currently handles only the case when `N->int_rep` is
+ * MBEDTLS_MPI_MOD_REP_MONTGOMERY.
+ *
+ * The size of the operation is determined by \p N. \p A, \p B and \p X must
+ * all be associated with the modulus \p N and must all have the same number
+ * of limbs as \p N.
+ *
+ * \p X may be aliased to \p A or \p B, or even both, but may not overlap
+ * either otherwise. They may not alias \p N (since they must be in canonical
+ * form, they cannot == \p N).
+ *
+ * \param[out] X        The address of the result MPI. Must have the same
+ *                      number of limbs as \p N.
+ *                      On successful completion, \p X contains the result of
+ *                      the multiplication `A * B * R^-1` mod N where
+ *                      `R = 2^(biL * N->limbs)`.
+ * \param[in]  A        The address of the first MPI.
+ * \param[in]  B        The address of the second MPI.
+ * \param[in]  N        The address of the modulus. Used to perform a modulo
+ *                      operation on the result of the multiplication.
+ * \param[in,out] T     Temporary storage of size at least 2 * N->limbs + 1
+ *                      limbs. Its initial content is unused and
+ *                      its final content is indeterminate.
+ *                      It must not alias or otherwise overlap any of the
+ *                      other parameters.
+ */
+void mbedtls_mpi_mod_raw_mul( mbedtls_mpi_uint *X,
+                              const mbedtls_mpi_uint *A,
+                              const mbedtls_mpi_uint *B,
+                              const mbedtls_mpi_mod_modulus *N,
+                              mbedtls_mpi_uint *T );
+
 /* END MERGE SLOT 2 */
 
 /* BEGIN MERGE SLOT 3 */
diff --git a/scripts/mbedtls_dev/bignum_mod_raw.py b/scripts/mbedtls_dev/bignum_mod_raw.py
index 6fc4c91..296e2d2 100644
--- a/scripts/mbedtls_dev/bignum_mod_raw.py
+++ b/scripts/mbedtls_dev/bignum_mod_raw.py
@@ -50,6 +50,25 @@
         result = (self.int_a - self.int_b) % self.int_n
         return [self.format_result(result)]
 
+class BignumModRawMul(bignum_common.ModOperationCommon,
+                      BignumModRawTarget):
+    """Test cases for bignum mpi_mod_raw_mul()."""
+    symbol = "*"
+    test_function = "mpi_mod_raw_mul"
+    test_name = "mbedtls_mpi_mod_raw_mul"
+    input_style = "arch_split"
+    arity = 2
+
+    def arguments(self) -> List[str]:
+        return [self.format_result(self.to_montgomery(self.int_a)),
+                self.format_result(self.to_montgomery(self.int_b)),
+                bignum_common.quote_str(self.arg_n)
+               ] + self.result()
+
+    def result(self) -> List[str]:
+        result = (self.int_a * self.int_b) % self.int_n
+        return [self.format_result(self.to_montgomery(result))]
+
 # END MERGE SLOT 2
 
 # BEGIN MERGE SLOT 3
diff --git a/tests/suites/test_suite_bignum_mod_raw.function b/tests/suites/test_suite_bignum_mod_raw.function
index 83e1f54..b1d1f77 100644
--- a/tests/suites/test_suite_bignum_mod_raw.function
+++ b/tests/suites/test_suite_bignum_mod_raw.function
@@ -345,6 +345,101 @@
 }
 /* END_CASE */
 
+/* BEGIN_CASE */
+void mpi_mod_raw_mul( char * input_A,
+                      char * input_B,
+                      char * input_N,
+                      char * result )
+{
+    mbedtls_mpi_uint *A = NULL;
+    mbedtls_mpi_uint *B = NULL;
+    mbedtls_mpi_uint *N = NULL;
+    mbedtls_mpi_uint *X = NULL;
+    mbedtls_mpi_uint *R = NULL;
+    mbedtls_mpi_uint *T = NULL;
+    size_t limbs_A;
+    size_t limbs_B;
+    size_t limbs_N;
+    size_t limbs_R;
+
+    mbedtls_mpi_mod_modulus m;
+    mbedtls_mpi_mod_modulus_init( &m );
+
+    TEST_EQUAL( mbedtls_test_read_mpi_core( &A, &limbs_A, input_A ), 0 );
+    TEST_EQUAL( mbedtls_test_read_mpi_core( &B, &limbs_B, input_B ), 0 );
+    TEST_EQUAL( mbedtls_test_read_mpi_core( &N, &limbs_N, input_N ), 0 );
+    TEST_EQUAL( mbedtls_test_read_mpi_core( &R, &limbs_R, result  ), 0 );
+
+    const size_t limbs = limbs_N;
+    const size_t bytes = limbs * sizeof( mbedtls_mpi_uint );
+
+    TEST_EQUAL( limbs_A, limbs );
+    TEST_EQUAL( limbs_B, limbs );
+    TEST_EQUAL( limbs_R, limbs );
+
+    ASSERT_ALLOC( X, limbs );
+
+    TEST_EQUAL( mbedtls_mpi_mod_modulus_setup(
+                        &m, N, limbs,
+                        MBEDTLS_MPI_MOD_REP_MONTGOMERY ), 0 );
+
+    const size_t limbs_T = limbs * 2 + 1;
+    ASSERT_ALLOC( T, limbs_T );
+
+    mbedtls_mpi_mod_raw_mul( X, A, B, &m, T );
+    ASSERT_COMPARE( X, bytes, R, bytes );
+
+    /* alias X to A */
+    memcpy( X, A, bytes );
+    mbedtls_mpi_mod_raw_mul( X, X, B, &m, T );
+    ASSERT_COMPARE( X, bytes, R, bytes );
+
+    /* alias X to B */
+    memcpy( X, B, bytes );
+    mbedtls_mpi_mod_raw_mul( X, A, X, &m, T );
+    ASSERT_COMPARE( X, bytes, R, bytes );
+
+    /* A == B: alias A and B */
+    if( memcmp( A, B, bytes ) == 0 )
+    {
+        mbedtls_mpi_mod_raw_mul( X, A, A, &m, T );
+        ASSERT_COMPARE( X, bytes, R, bytes );
+
+        /* X, A, B all aliased together */
+        memcpy( X, A, bytes );
+        mbedtls_mpi_mod_raw_mul( X, X, X, &m, T );
+        ASSERT_COMPARE( X, bytes, R, bytes );
+    }
+
+    /* A != B: test B * A */
+    else
+    {
+        mbedtls_mpi_mod_raw_mul( X, B, A, &m, T );
+        ASSERT_COMPARE( X, bytes, R, bytes );
+
+        /* B * A: alias X to A */
+        memcpy( X, A, bytes );
+        mbedtls_mpi_mod_raw_mul( X, B, X, &m, T );
+        ASSERT_COMPARE( X, bytes, R, bytes );
+
+        /* B * A: alias X to B */
+        memcpy( X, B, bytes );
+        mbedtls_mpi_mod_raw_mul( X, X, A, &m, T );
+        ASSERT_COMPARE( X, bytes, R, bytes );
+    }
+
+exit:
+    mbedtls_free( A );
+    mbedtls_free( B );
+    mbedtls_free( X );
+    mbedtls_free( R );
+    mbedtls_free( T );
+
+    mbedtls_mpi_mod_modulus_free( &m );
+    mbedtls_free( N );
+}
+/* END_CASE */
+
 /* END MERGE SLOT 2 */
 
 /* BEGIN MERGE SLOT 3 */