Use mbedtls_mpi_core_montmul() in mpi_montmul()

Signed-off-by: Tom Cosgrove <tom.cosgrove@arm.com>
diff --git a/library/bignum.c b/library/bignum.c
index 44b2c87..48bc77c 100644
--- a/library/bignum.c
+++ b/library/bignum.c
@@ -1553,65 +1553,39 @@
     *mm = mbedtls_mpi_montg_init( N->p[0] );
 }
 
-/* This would be static, but is tested */
-void mbedtls_mpi_montmul( mbedtls_mpi *A, const mbedtls_mpi *B,
-                          const mbedtls_mpi *N, mbedtls_mpi_uint mm,
-                          const mbedtls_mpi *T )
+/** Montgomery multiplication: A = A * B * R^-1 mod N  (HAC 14.36)
+ *
+ * \param[in,out]   A   One of the numbers to multiply.
+ *                      It must have at least as many limbs as N
+ *                      (A->n >= N->n), and any limbs beyond n are ignored.
+ *                      On successful completion, A contains the result of
+ *                      the multiplication A * B * R^-1 mod N where
+ *                      R = (2^ciL)^n.
+ * \param[in]       B   One of the numbers to multiply.
+ *                      It must be nonzero and must not have more limbs than N
+ *                      (B->n <= N->n).
+ * \param[in]       N   The modulo. N must be odd.
+ * \param           mm  The value calculated by `mpi_montg_init(&mm, N)`.
+ *                      This is -N^-1 mod 2^ciL.
+ * \param[in,out]   T   A bignum for temporary storage.
+ *                      It must be at least twice the limb size of N plus 1
+ *                      (T->n >= 2 * N->n + 1).
+ *                      Its initial content is unused and
+ *                      its final content is indeterminate.
+ *                      Note that unlike the usual convention in the library
+ *                      for `const mbedtls_mpi*`, the content of T can change.
+ */
+static void mpi_montmul( mbedtls_mpi *A, const mbedtls_mpi *B,
+                         const mbedtls_mpi *N, mbedtls_mpi_uint mm,
+                         const mbedtls_mpi *T )
 {
-    size_t n, m;
-    mbedtls_mpi_uint *d;
-
-    memset( T->p, 0, T->n * ciL );
-
-    d = T->p;
-    n = N->n;
-    m = ( B->n < n ) ? B->n : n;
-
-    for( size_t i = 0; i < n; i++ )
-    {
-        mbedtls_mpi_uint u0, u1;
-
-        /*
-         * T = (T + u0*B + u1*N) / 2^biL
-         */
-        u0 = A->p[i];
-        u1 = ( d[0] + u0 * B->p[0] ) * mm;
-
-        (void) mbedtls_mpi_core_mla( d, n + 2,
-                                     B->p, m,
-                                     u0 );
-        (void) mbedtls_mpi_core_mla( d, n + 2,
-                                     N->p, n,
-                                     u1 );
-        d++;
-    }
-
-    /* At this point, d is either the desired result or the desired result
-     * plus N. We now potentially subtract N, avoiding leaking whether the
-     * subtraction is performed through side channels. */
-
-    /* Copy the n least significant limbs of d to A, so that
-     * A = d if d < N (recall that N has n limbs). */
-    memcpy( A->p, d, n * ciL );
-    /* If d >= N then we want to set A to d - N. To prevent timing attacks,
-     * do the calculation without using conditional tests. */
-    /* Set d to d0 + (2^biL)^n - N where d0 is the current value of d. */
-    d[n] += 1;
-    d[n] -= mbedtls_mpi_core_sub( d, d, N->p, n );
-    /* If d0 < N then d < (2^biL)^n
-     * so d[n] == 0 and we want to keep A as it is.
-     * If d0 >= N then d >= (2^biL)^n, and d <= (2^biL)^n + N < 2 * (2^biL)^n
-     * so d[n] == 1 and we want to set A to the result of the subtraction
-     * which is d - (2^biL)^n, i.e. the n least significant limbs of d.
-     * This exactly corresponds to a conditional assignment. */
-    mbedtls_ct_mpi_uint_cond_assign( n, A->p, d, (unsigned char) d[n] );
+    mbedtls_mpi_core_montmul( A->p, A->p, B->p, B->n, N->p, N->n, mm, T->p );
 }
 
 /*
  * Montgomery reduction: A = A * R^-1 mod N
  *
- * See the doc for mbedtls_mpi_montmul() regarding constraints and guarantees on
- * the parameters.
+ * See mpi_montmul() regarding constraints and guarantees on the parameters.
  */
 static void mpi_montred( mbedtls_mpi *A, const mbedtls_mpi *N,
                          mbedtls_mpi_uint mm, const mbedtls_mpi *T )
@@ -1622,7 +1596,7 @@
     U.n = U.s = (int) z;
     U.p = &z;
 
-    mbedtls_mpi_montmul( A, &U, N, mm, T );
+    mpi_montmul( A, &U, N, mm, T );
 }
 
 /**
@@ -1704,7 +1678,7 @@
 #endif
 
     j = N->n + 1;
-    /* All W[i] and X must have at least N->n limbs for the mbedtls_mpi_montmul()
+    /* All W[i] and X must have at least N->n limbs for the mpi_montmul()
      * and mpi_montred() calls later. Here we ensure that W[1] and X are
      * large enough, and later we'll grow other W[i] to the same length.
      * They must not be shrunk midway through this function!
@@ -1747,7 +1721,7 @@
         MBEDTLS_MPI_CHK( mbedtls_mpi_mod_mpi( &W[1], A, N ) );
         /* This should be a no-op because W[1] is already that large before
          * mbedtls_mpi_mod_mpi(), but it's necessary to avoid an overflow
-         * in mbedtls_mpi_montmul() below, so let's make sure. */
+         * in mpi_montmul() below, so let's make sure. */
         MBEDTLS_MPI_CHK( mbedtls_mpi_grow( &W[1], N->n + 1 ) );
     }
     else
@@ -1755,7 +1729,7 @@
 
     /* Note that this is safe because W[1] always has at least N->n limbs
      * (it grew above and was preserved by mbedtls_mpi_copy()). */
-    mbedtls_mpi_montmul( &W[1], &RR, N, mm, &T );
+    mpi_montmul( &W[1], &RR, N, mm, &T );
 
     /*
      * X = R^2 * R^-1 mod N = R mod N
@@ -1774,7 +1748,7 @@
         MBEDTLS_MPI_CHK( mbedtls_mpi_copy( &W[j], &W[1]    ) );
 
         for( i = 0; i < wsize - 1; i++ )
-            mbedtls_mpi_montmul( &W[j], &W[j], N, mm, &T );
+            mpi_montmul( &W[j], &W[j], N, mm, &T );
 
         /*
          * W[i] = W[i - 1] * W[1]
@@ -1784,7 +1758,7 @@
             MBEDTLS_MPI_CHK( mbedtls_mpi_grow( &W[i], N->n + 1 ) );
             MBEDTLS_MPI_CHK( mbedtls_mpi_copy( &W[i], &W[i - 1] ) );
 
-            mbedtls_mpi_montmul( &W[i], &W[1], N, mm, &T );
+            mpi_montmul( &W[i], &W[1], N, mm, &T );
         }
     }
 
@@ -1821,7 +1795,7 @@
             /*
              * out of window, square X
              */
-            mbedtls_mpi_montmul( X, X, N, mm, &T );
+            mpi_montmul( X, X, N, mm, &T );
             continue;
         }
 
@@ -1839,13 +1813,13 @@
              * X = X^wsize R^-1 mod N
              */
             for( i = 0; i < wsize; i++ )
-                mbedtls_mpi_montmul( X, X, N, mm, &T );
+                mpi_montmul( X, X, N, mm, &T );
 
             /*
              * X = X * W[wbits] R^-1 mod N
              */
             MBEDTLS_MPI_CHK( mpi_select( &WW, W, (size_t) 1 << wsize, wbits ) );
-            mbedtls_mpi_montmul( X, &WW, N, mm, &T );
+            mpi_montmul( X, &WW, N, mm, &T );
 
             state--;
             nbits = 0;
@@ -1858,12 +1832,12 @@
      */
     for( i = 0; i < nbits; i++ )
     {
-        mbedtls_mpi_montmul( X, X, N, mm, &T );
+        mpi_montmul( X, X, N, mm, &T );
 
         wbits <<= 1;
 
         if( ( wbits & ( one << wsize ) ) != 0 )
-            mbedtls_mpi_montmul( X, &W[1], N, mm, &T );
+            mpi_montmul( X, &W[1], N, mm, &T );
     }
 
     /*
diff --git a/library/bignum_core.h b/library/bignum_core.h
index ca45480..02ac55d 100644
--- a/library/bignum_core.h
+++ b/library/bignum_core.h
@@ -263,32 +263,4 @@
                                           size_t n,
                                           unsigned cond );
 
-/** Montgomery multiplication: A = A * B * R^-1 mod N  (HAC 14.36)
- *
- * This would be static, but is tested.
- *
- * \param[in,out]   A   One of the numbers to multiply.
- *                      It must have at least as many limbs as N
- *                      (A->n >= N->n), and any limbs beyond n are ignored.
- *                      On successful completion, A contains the result of
- *                      the multiplication A * B * R^-1 mod N where
- *                      R = (2^ciL)^n.
- * \param[in]       B   One of the numbers to multiply.
- *                      It must be nonzero and must not have more limbs than N
- *                      (B->n <= N->n).
- * \param[in]       N   The modulo. N must be odd.
- * \param           mm  The value calculated by `mpi_montg_init(&mm, N)`.
- *                      This is -N^-1 mod 2^ciL.
- * \param[in,out]   T   A bignum for temporary storage.
- *                      It must be at least twice the limb size of N plus 1
- *                      (T->n >= 2 * N->n + 1).
- *                      Its initial content is unused and
- *                      its final content is indeterminate.
- *                      Note that unlike the usual convention in the library
- *                      for `const mbedtls_mpi*`, the content of T can change.
- */
-void mbedtls_mpi_montmul( mbedtls_mpi *A, const mbedtls_mpi *B,
-                          const mbedtls_mpi *N, mbedtls_mpi_uint mm,
-                          const mbedtls_mpi *T );
-
 #endif /* MBEDTLS_BIGNUM_CORE_H */
diff --git a/tests/suites/test_suite_mpi.function b/tests/suites/test_suite_mpi.function
index d9109ec..bf1212a 100644
--- a/tests/suites/test_suite_mpi.function
+++ b/tests/suites/test_suite_mpi.function
@@ -2035,7 +2035,7 @@
                                char * input_X4,
                                char * input_X8 )
 {
-    mbedtls_mpi A, B, N, X4, X8, T, CA;
+    mbedtls_mpi A, B, N, X4, X8, T, R;
 
     mbedtls_mpi_init( &A );
     mbedtls_mpi_init( &B );
@@ -2043,7 +2043,7 @@
     mbedtls_mpi_init( &X4 );    /* expected result, sizeof(mbedtls_mpi_uint) == 4 */
     mbedtls_mpi_init( &X8 );    /* expected result, sizeof(mbedtls_mpi_uint) == 8 */
     mbedtls_mpi_init( &T );
-    mbedtls_mpi_init( &CA );    /* copy of A */
+    mbedtls_mpi_init( &R );     /* for the result */
 
     TEST_EQUAL( mbedtls_test_read_mpi( &A, input_A ), 0 );
     TEST_EQUAL( mbedtls_test_read_mpi( &B, input_B ), 0 );
@@ -2076,24 +2076,10 @@
     /* Calculate the Montgomery constant (this is unit tested separately) */
     mbedtls_mpi_uint mm = mbedtls_mpi_montg_init( N.p[0] );
 
-    TEST_EQUAL( mbedtls_mpi_copy( &CA, &A ), 0 );       /* take a copy */
-    TEST_EQUAL( mbedtls_mpi_grow( &CA, limbs_AN ), 0 ); /* ensure it's got the right number of limbs */
+    TEST_EQUAL( mbedtls_mpi_grow( &R, limbs_AN ), 0 ); /* ensure it's got the right number of limbs */
 
-    mbedtls_mpi_montmul( &A, &B, &N, mm, &T );
-    TEST_EQUAL( A.s, 1 );               /* ensure still positive */
-
-    /* Could use mbedtls_mpi_cmp_mpi(), but this gives finer detail if not the same */
-    TEST_EQUAL( A.n, X->n );
-    TEST_EQUAL( memcmp( A.p, X->p, A.n * sizeof(mbedtls_mpi_uint) ), 0 );
-
-    /* First overwrite A so we ensure mbedtls_mpi_core_montmul() does something */
-    memset( A.p, 0xAA, A.n * sizeof(mbedtls_mpi_uint) );
-
-    /* Now test the new function: use the copy CA we took earlier of A as the
-     * LHS, and use A as the destination
-     */
-    mbedtls_mpi_core_montmul( A.p, CA.p, B.p, B.n, N.p, N.n, mm, T.p );
-    TEST_EQUAL( memcmp( A.p, X->p, A.n * sizeof(mbedtls_mpi_uint) ), 0 );
+    mbedtls_mpi_core_montmul( R.p, A.p, B.p, B.n, N.p, N.n, mm, T.p );
+    TEST_EQUAL( memcmp( R.p, X->p, N.n * sizeof(mbedtls_mpi_uint) ), 0 );
 
 exit:
     mbedtls_mpi_free( &A );
@@ -2102,7 +2088,7 @@
     mbedtls_mpi_free( &X4 );
     mbedtls_mpi_free( &X8 );
     mbedtls_mpi_free( &T );
-    mbedtls_mpi_free( &CA );
+    mbedtls_mpi_free( &R );
 }
 /* END_CASE */