Restrict input parameter size for ecp_mod_p521_raw

The imput mpi parameter must have twice as many limbs as the modulus.

Signed-off-by: Gabor Mezei <gabor.mezei@arm.com>
diff --git a/library/ecp_curves.c b/library/ecp_curves.c
index 49182a4..85d634a 100644
--- a/library/ecp_curves.c
+++ b/library/ecp_curves.c
@@ -5203,7 +5203,7 @@
 static int ecp_mod_p521(mbedtls_mpi *N)
 {
     int ret = MBEDTLS_ERR_ERROR_CORRUPTION_DETECTED;
-    size_t expected_width = 2 * ((521 + biL - 1) / biL);
+    size_t expected_width = 2 * P521_WIDTH;
     MBEDTLS_MPI_CHK(mbedtls_mpi_grow(N, expected_width));
     ret = mbedtls_ecp_mod_p521_raw(N->p, expected_width);
 cleanup:
@@ -5215,41 +5215,34 @@
 {
     mbedtls_mpi_uint carry = 0;
 
-    if (X_limbs > 2 * P521_WIDTH - 1) {
-        X_limbs = 2 * P521_WIDTH - 1;
-    }
-    if (X_limbs < P521_WIDTH) {
-        return 0;
+    if (X_limbs != 2 * P521_WIDTH || X[2 * P521_WIDTH - 1] != 0) {
+       return MBEDTLS_ERR_ECP_BAD_INPUT_DATA;
     }
 
     /* Step 1: Reduction to P521_WIDTH limbs */
-    if (X_limbs > P521_WIDTH) {
-        /* Helper references for bottom part of X */
-        mbedtls_mpi_uint *X0 = X;
-        size_t X0_limbs = P521_WIDTH;
-        /* Helper references for top part of X */
-        mbedtls_mpi_uint *X1 = X + X0_limbs;
-        size_t X1_limbs = X_limbs - X0_limbs;
-
-        /* Split X as X0 + 2^P521_WIDTH X1 and compute X0 + 2^(biL - 9) X1.
-         * (We are using that 2^P521_WIDTH = 2^(512 + biL) and that
-         * 2^(512 + biL) X1 = 2^(biL - 9) X1 mod P521.)
-         * The high order limb of the result will be held in carry and the rest
-         * in X0 (that is the result will be represented as
-         * 2^P521_WIDTH carry + X0).
-         *
-         * Also, note that the resulting carry is either 0 or 1:
-         * X0 < 2^P521_WIDTH = 2^(512 + biL) and X1 < 2^(P521_WIDTH-biL) = 2^512
-         * therefore
-         * X0 + 2^(biL - 9) X1 < 2^(512 + biL) + 2^(512 + biL - 9)
-         * which in turn is less than 2 * 2^(512 + biL).
-         */
-        mbedtls_mpi_uint shift = ((mbedtls_mpi_uint) 1u) << (biL - 9);
-        carry = mbedtls_mpi_core_mla(X0, X0_limbs, X1, X1_limbs, shift);
-
-        /* Set X to X0 (by clearing the top part). */
-        memset(X1, 0, X1_limbs * sizeof(mbedtls_mpi_uint));
-    }
+    /* Helper references for bottom part of X */
+    mbedtls_mpi_uint *X0 = X;
+    size_t X0_limbs = P521_WIDTH;
+    /* Helper references for top part of X */
+    mbedtls_mpi_uint *X1 = X + X0_limbs;
+    size_t X1_limbs = X_limbs - X0_limbs;
+    /* Split X as X0 + 2^P521_WIDTH X1 and compute X0 + 2^(biL - 9) X1.
+     * (We are using that 2^P521_WIDTH = 2^(512 + biL) and that
+     * 2^(512 + biL) X1 = 2^(biL - 9) X1 mod P521.)
+     * The high order limb of the result will be held in carry and the rest
+     * in X0 (that is the result will be represented as
+     * 2^P521_WIDTH carry + X0).
+     *
+     * Also, note that the resulting carry is either 0 or 1:
+     * X0 < 2^P521_WIDTH = 2^(512 + biL) and X1 < 2^(P521_WIDTH-biL) = 2^512
+     * therefore
+     * X0 + 2^(biL - 9) X1 < 2^(512 + biL) + 2^(512 + biL - 9)
+     * which in turn is less than 2 * 2^(512 + biL).
+     */
+    mbedtls_mpi_uint shift = ((mbedtls_mpi_uint) 1u) << (biL - 9);
+    carry = mbedtls_mpi_core_mla(X0, X0_limbs, X1, X1_limbs, shift);
+    /* Set X to X0 (by clearing the top part). */
+    memset(X1, 0, X1_limbs * sizeof(mbedtls_mpi_uint));
 
     /* Step 2: Reduction modulo P521
      *
@@ -5267,14 +5260,9 @@
      * carrying out the addition. */
     mbedtls_mpi_uint *addend_arr = X + P521_WIDTH;
     addend_arr[0] = addend;
-    /* The unused part of X is P521_WIDTH - 1 limbs in size and only that
-     * size can be used for addition. Due to the addend fit in a limb
-     * the limbs other the first in the helper array are only used for
-     * propagating the carry. By adding the carry of the P521_WIDTH - 1 limb
-     * addition to the last limb of X makes the addition of X and the addend
-     * complete. */
-    carry = mbedtls_mpi_core_add(X, X, addend_arr, P521_WIDTH - 1);
-    X[P521_WIDTH - 1] += carry;
+    (void) mbedtls_mpi_core_add(X, X, addend_arr, P521_WIDTH);
+    /* Both addends were less than P521 therefore X < 2 * P521. (This also means
+     * that the result fit in P521_WIDTH limbs and there won't be any carry.) */
 
     /* Clear the reused part of X. */
     addend_arr[0] = 0;
diff --git a/scripts/mbedtls_dev/ecp.py b/scripts/mbedtls_dev/ecp.py
index fa70ded..d436d0a 100644
--- a/scripts/mbedtls_dev/ecp.py
+++ b/scripts/mbedtls_dev/ecp.py
@@ -81,7 +81,7 @@
     """Test cases for ecp quasi_reduction()."""
     test_function = "ecp_mod_p521_raw"
     test_name = "ecp_mod_p521_raw"
-    input_style = "arch_split"
+    input_style = "fixed"
     arity = 1
 
     moduli = [("01ffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff"
@@ -156,9 +156,8 @@
 
     @property
     def arg_a(self) -> str:
-        # Number of limbs: 2 * N - 1
-        hex_digits = bignum_common.hex_digits_for_limb(2 * self.limbs - 1, self.bits_in_limb)
-        return super().format_arg('{:x}'.format(self.int_a)).zfill(hex_digits)
+        # Number of limbs: 2 * N
+        return super().format_arg('{:x}'.format(self.int_a)).zfill(2 * self.hex_digits)
 
     def result(self) -> List[str]:
         result = self.int_a % self.int_n
diff --git a/tests/suites/test_suite_ecp.function b/tests/suites/test_suite_ecp.function
index a0042ed..212dfcb 100644
--- a/tests/suites/test_suite_ecp.function
+++ b/tests/suites/test_suite_ecp.function
@@ -1367,7 +1367,7 @@
     size_t limbs = limbs_N;
     size_t bytes = limbs * sizeof(mbedtls_mpi_uint);
 
-    TEST_EQUAL(limbs_X, 2 * limbs - 1);
+    TEST_EQUAL(limbs_X, 2 * limbs);
     TEST_EQUAL(limbs_res, limbs);
 
     TEST_EQUAL(mbedtls_mpi_mod_modulus_setup(