Merge pull request #6233 from tom-cosgrove-arm/issue-6226-core-mul

Bignum: extract core_mul from the prototype
diff --git a/library/bignum.c b/library/bignum.c
index d3a1b00..2421c1a 100644
--- a/library/bignum.c
+++ b/library/bignum.c
@@ -1136,7 +1136,8 @@
     MPI_VALIDATE_RET(A != NULL);
     MPI_VALIDATE_RET(B != NULL);
 
-    mbedtls_mpi_init(&TA); mbedtls_mpi_init(&TB);
+    mbedtls_mpi_init(&TA);
+    mbedtls_mpi_init(&TB);
 
     if (X == A) {
         MBEDTLS_MPI_CHK(mbedtls_mpi_copy(&TA, A)); A = &TA;
@@ -1166,13 +1167,7 @@
     MBEDTLS_MPI_CHK(mbedtls_mpi_grow(X, i + j));
     MBEDTLS_MPI_CHK(mbedtls_mpi_lset(X, 0));
 
-    for (size_t k = 0; k < j; k++) {
-        /* We know that there cannot be any carry-out since we're
-         * iterating from bottom to top. */
-        (void) mbedtls_mpi_core_mla(X->p + k, i + 1,
-                                    A->p, i,
-                                    B->p[k]);
-    }
+    mbedtls_mpi_core_mul(X->p, A->p, i, B->p, j);
 
     /* If the result is 0, we don't shortcut the operation, which reduces
      * but does not eliminate side channels leaking the zero-ness. We do
diff --git a/library/bignum_core.c b/library/bignum_core.c
index e50f043..1ba4142 100644
--- a/library/bignum_core.c
+++ b/library/bignum_core.c
@@ -448,6 +448,23 @@
     return c;
 }
 
+void mbedtls_mpi_core_mul(mbedtls_mpi_uint *X,
+                          const mbedtls_mpi_uint *A, size_t A_limbs,
+                          const mbedtls_mpi_uint *B, size_t B_limbs)
+{
+    memset(X, 0, (A_limbs + B_limbs) * ciL);
+
+    /* Schoolbook method: step i accumulates A * B[i] at X + i. */
+    for (size_t i = 0; i < B_limbs; i++) {
+        /* We know that there cannot be any carry-out since we're
+         * iterating from bottom to top: before step i, X holds A times
+         * the low i limbs of B, which fits in A_limbs + i limbs, so the
+         * sum after adding A * B[i] still fits in the A_limbs + 1 limbs
+         * starting at X + i. */
+        (void) mbedtls_mpi_core_mla(X + i, A_limbs + 1, A, A_limbs, B[i]);
+    }
+}
+
 /*
  * Fast Montgomery initialization (thanks to Tom St Denis).
  */
diff --git a/library/bignum_core.h b/library/bignum_core.h
index 05bc923..7a0311a 100644
--- a/library/bignum_core.h
+++ b/library/bignum_core.h
@@ -399,6 +399,26 @@
                                       mbedtls_mpi_uint b);
 
 /**
+ * \brief Perform a known-size multiplication
+ *
+ * \p X must not be aliased to any of the inputs for this function.
+ * \p A may be aliased to \p B.
+ *
+ * \param[out] X     The pointer to the (little-endian) array to receive
+ *                   the product of \p A and \p B.
+ *                   This must be of length \p A_limbs + \p B_limbs.
+ * \param[in] A      The pointer to the (little-endian) array
+ *                   representing the first factor.
+ * \param A_limbs    The number of limbs in \p A.
+ * \param[in] B      The pointer to the (little-endian) array
+ *                   representing the second factor.
+ * \param B_limbs    The number of limbs in \p B.
+ */
+void mbedtls_mpi_core_mul(mbedtls_mpi_uint *X,
+                          const mbedtls_mpi_uint *A, size_t A_limbs,
+                          const mbedtls_mpi_uint *B, size_t B_limbs);
+
+/**
  * \brief Calculate initialisation value for fast Montgomery modular
  *        multiplication
  *
diff --git a/scripts/mbedtls_dev/bignum_common.py b/scripts/mbedtls_dev/bignum_common.py
index 5319ec6..aa2cd25 100644
--- a/scripts/mbedtls_dev/bignum_common.py
+++ b/scripts/mbedtls_dev/bignum_common.py
@@ -68,7 +68,8 @@
 
 def limbs_mpi(val: int, bits_in_limb: int) -> int:
-    """Return the number of limbs required to store value."""
-    return (val.bit_length() + bits_in_limb - 1) // bits_in_limb
+    """Return the number of limbs required to store val; zero needs one limb."""
+    bit_length = max(val.bit_length(), 1)
+    return (bit_length + bits_in_limb - 1) // bits_in_limb
 
 def combination_pairs(values: List[T]) -> List[Tuple[T, T]]:
     """Return all pair combinations from input values."""
diff --git a/scripts/mbedtls_dev/bignum_core.py b/scripts/mbedtls_dev/bignum_core.py
index 24d37cb..e914ae7 100644
--- a/scripts/mbedtls_dev/bignum_core.py
+++ b/scripts/mbedtls_dev/bignum_core.py
@@ -230,6 +230,39 @@
                 yield cur_op.create_test_case()
 
 
+class BignumCoreMul(BignumCoreTarget, bignum_common.OperationCommon):
+    """Test cases for bignum core multiplication.
+
+    The expected product is zero-padded to the combined limb count of both
+    operands, because mbedtls_mpi_core_mul() fills A_limbs + B_limbs limbs.
+    """
+    count = 0
+    input_style = "arch_split"
+    symbol = "*"
+    test_function = "mpi_core_mul"
+    test_name = "mbedtls_mpi_core_mul"
+    arity = 2
+    unique_combinations_only = True
+
+    def format_arg(self, val: str) -> str:
+        # Operands are emitted exactly as given, without padding.
+        return val
+
+    def format_result(self, res: int) -> str:
+        total_limbs = sum(
+            bignum_common.limbs_mpi(operand, self.bits_in_limb)
+            for operand in (self.int_a, self.int_b)
+        )
+        width = bignum_common.hex_digits_for_limb(total_limbs,
+                                                  self.bits_in_limb)
+        padded = self.format_arg(format(res, 'x')).zfill(width)
+        return bignum_common.quote_str(padded)
+
+    def result(self) -> List[str]:
+        product = self.int_a * self.int_b
+        return [self.format_result(product)]
+
+
 class BignumCoreMontmul(BignumCoreTarget, test_data_generation.BaseTest):
     """Test cases for Montgomery multiplication."""
     count = 0
diff --git a/tests/suites/test_suite_bignum_core.function b/tests/suites/test_suite_bignum_core.function
index 408eb0b..2f87ea9 100644
--- a/tests/suites/test_suite_bignum_core.function
+++ b/tests/suites/test_suite_bignum_core.function
@@ -1057,6 +1057,85 @@
 }
 /* END_CASE */
 
+/*
+ * Test mbedtls_mpi_core_mul() against a precomputed product.
+ *
+ * input_A, input_B: the two factors, as hex strings.
+ * result:           the expected product, as a hex string sized to
+ *                   exactly A_limbs + B_limbs limbs.
+ *
+ * Also checks that the inputs are left unmodified, and then either
+ * the aliased call (when A and B are equal) or the swapped call B * A.
+ */
+/* BEGIN_CASE */
+void mpi_core_mul(char *input_A,
+                  char *input_B,
+                  char *result)
+{
+    mbedtls_mpi_uint *A      = NULL;
+    mbedtls_mpi_uint *A_orig = NULL;
+    mbedtls_mpi_uint *B      = NULL;
+    mbedtls_mpi_uint *B_orig = NULL;
+    mbedtls_mpi_uint *R      = NULL;
+    mbedtls_mpi_uint *X      = NULL;
+    size_t A_limbs, B_limbs, R_limbs;
+
+    TEST_EQUAL(mbedtls_test_read_mpi_core(&A, &A_limbs, input_A), 0);
+    TEST_EQUAL(mbedtls_test_read_mpi_core(&B, &B_limbs, input_B), 0);
+    TEST_EQUAL(mbedtls_test_read_mpi_core(&R, &R_limbs, result), 0);
+
+    /* mbedtls_mpi_core_mul() writes A_limbs + B_limbs limbs, so the
+     * expected value must have been generated at exactly that size. */
+    TEST_EQUAL(R_limbs, A_limbs + B_limbs);
+
+    const size_t X_limbs = A_limbs + B_limbs;
+    const size_t X_bytes = X_limbs * sizeof(mbedtls_mpi_uint);
+    ASSERT_ALLOC(X, X_limbs);
+
+    /* Snapshot the inputs so that any modification can be detected. */
+    const size_t A_bytes = A_limbs * sizeof(mbedtls_mpi_uint);
+    ASSERT_ALLOC(A_orig, A_limbs);
+    memcpy(A_orig, A, A_bytes);
+
+    const size_t B_bytes = B_limbs * sizeof(mbedtls_mpi_uint);
+    ASSERT_ALLOC(B_orig, B_limbs);
+    memcpy(B_orig, B, B_bytes);
+
+    /* Set result to something that is unlikely to be correct */
+    memset(X, '!', X_bytes);
+
+    /* 1. X = A * B - result should be correct, A and B unchanged */
+    mbedtls_mpi_core_mul(X, A, A_limbs, B, B_limbs);
+    ASSERT_COMPARE(X, X_bytes, R, X_bytes);
+    ASSERT_COMPARE(A, A_bytes, A_orig, A_bytes);
+    ASSERT_COMPARE(B, B_bytes, B_orig, B_bytes);
+
+    /* 2. A == B: alias A and B - result should be correct, A and B unchanged */
+    if (A_bytes == B_bytes && memcmp(A, B, A_bytes) == 0) {
+        memset(X, '!', X_bytes);
+        mbedtls_mpi_core_mul(X, A, A_limbs, A, A_limbs);
+        ASSERT_COMPARE(X, X_bytes, R, X_bytes);
+        ASSERT_COMPARE(A, A_bytes, A_orig, A_bytes);
+    }
+    /* 3. X = B * A - result should be correct, A and B unchanged */
+    else {
+        memset(X, '!', X_bytes);
+        mbedtls_mpi_core_mul(X, B, B_limbs, A, A_limbs);
+        ASSERT_COMPARE(X, X_bytes, R, X_bytes);
+        ASSERT_COMPARE(A, A_bytes, A_orig, A_bytes);
+        ASSERT_COMPARE(B, B_bytes, B_orig, B_bytes);
+    }
+
+exit:
+    mbedtls_free(A);
+    mbedtls_free(A_orig);
+    mbedtls_free(B);
+    mbedtls_free(B_orig);
+    mbedtls_free(R);
+    mbedtls_free(X);
+}
+/* END_CASE */
+
 /* BEGIN MERGE SLOT 1 */
 
 /* BEGIN_CASE */