Three round solution

Attempt to fix the failing test by handling overflow with three rounds of
reduction, instead of the previous solution of subtracting the modulus.
Also optimise out the shifts by using memcpy / memmove instead. Remove the
final subtraction that returned the canonical result, as a canonical
result is not required here.

Signed-off-by: Paul Elliott <paul.elliott@arm.com>
diff --git a/library/ecp_curves.c b/library/ecp_curves.c
index a4b89be..2e377a0 100644
--- a/library/ecp_curves.c
+++ b/library/ecp_curves.c
@@ -5452,8 +5452,9 @@
 
 /* Number of limbs fully occupied by 2^224 (max), and limbs used by it (min) */
 #define DIV_ROUND_UP(X, Y) (((X) + (Y) -1) / (Y))
-#define P224_WIDTH_MIN   (28 / sizeof(mbedtls_mpi_uint))
-#define P224_WIDTH_MAX   DIV_ROUND_UP(28, sizeof(mbedtls_mpi_uint))
+#define P224_SIZE        (224 / 8)
+#define P224_WIDTH_MIN   (P224_SIZE / sizeof(mbedtls_mpi_uint))
+#define P224_WIDTH_MAX   DIV_ROUND_UP(P224_SIZE, sizeof(mbedtls_mpi_uint))
 #define P224_UNUSED_BITS ((P224_WIDTH_MAX * sizeof(mbedtls_mpi_uint) * 8) - 224)
 
 static int ecp_mod_p448(mbedtls_mpi *N)
@@ -5486,7 +5487,7 @@
 MBEDTLS_STATIC_TESTABLE
 int mbedtls_ecp_mod_p448(mbedtls_mpi_uint *X, size_t X_limbs)
 {
-    size_t i;
+    size_t i, round;
     int ret = MBEDTLS_ERR_ERROR_CORRUPTION_DETECTED;
 
     if (X_limbs <= P448_WIDTH) {
@@ -5494,20 +5495,18 @@
     }
 
     size_t M_limbs = X_limbs - (P448_WIDTH);
-    const size_t Q_limbs = M_limbs;
 
     if (M_limbs > P448_WIDTH) {
         /* Shouldn't be called with X larger than 2^896! */
         return MBEDTLS_ERR_ECP_BAD_INPUT_DATA;
     }
 
-    /* Extra limb for carry below. */
+    /* Both M and Q require an extra limb to catch carries. */
     M_limbs++;
 
+    const size_t Q_limbs = M_limbs;
     mbedtls_mpi_uint *M = NULL;
     mbedtls_mpi_uint *Q = NULL;
-    const mbedtls_mpi_uint *P = (mbedtls_mpi_uint *) curve448_p;
-    const size_t P_limbs = CHARS_TO_LIMBS(sizeof(curve448_p));
 
     M = mbedtls_calloc(M_limbs, ciL);
 
@@ -5536,49 +5535,67 @@
      * added in, not returned as carry. */
     (void) mbedtls_mpi_core_add(X, X, M, M_limbs);
 
-    /* Deal with carry bit from add by subtracting P if necessary. */
-    if (X[P448_WIDTH] != 0) {
-        mbedtls_mpi_core_sub(X, X, P, P_limbs);
-    }
+    /* Q = B1 = M >> 224 */
+    memcpy(Q, (char *) M + P224_SIZE, P224_SIZE);
+    memset((char *) Q + P224_SIZE, 0, P224_SIZE);
 
-    /* Q = B1 */
-    memcpy(Q, M, (Q_limbs * ciL));
-    mbedtls_mpi_core_shift_r(Q, Q_limbs, 224);
-
-    /* X = X + Q = (A0 + A1) + B1 */
-    /* No carry here - only max 224 bits */
+    /* X = X + Q = (A0 + A1) + B1
+     * Oversize Q catches potential carry here when X is already max 448 bits.
+     */
     (void) mbedtls_mpi_core_add(X, X, Q, Q_limbs);
 
     /* M = B0 */
     if (sizeof(mbedtls_mpi_uint) > 4) {
         M[P224_WIDTH_MIN] &= ((mbedtls_mpi_uint)-1) >> (P224_UNUSED_BITS);
     }
-    for (i = P224_WIDTH_MAX; i < M_limbs; ++i) {
-        M[i] = 0;
-    }
     memset(M + P224_WIDTH_MAX, 0, ((M_limbs - P224_WIDTH_MAX) * ciL));
 
     /* M = M + Q = B0 + B1 */
     (void) mbedtls_mpi_core_add(M, M, Q, Q_limbs);
 
     /* M = (B0 + B1) * 2^224 */
-    /* Shifted carry bit from the addition fits in oversize M */
-    mbedtls_mpi_core_shift_l(M, M_limbs, 224);
-
+    /* Shifted carry bit from the addition fits in oversize M. */
+    memmove((char *) M + P224_SIZE, M, P224_SIZE + sizeof(mbedtls_mpi_uint));
+    memset(M, 0, P224_SIZE);
 
     /* X = X + M = (A0 + A1 + B1) + (B0 + B1) * 2^224 */
     (void) mbedtls_mpi_core_add(X, X, M, M_limbs);
 
-    /* Deal with carry bit by subtracting P if necessary. */
-    if (X[P448_WIDTH] != 0) {
-        mbedtls_mpi_core_sub(X, X, P, P_limbs);
-    }
+    /* In the second and third rounds A1 and B0 have at most 1 non-zero limb and
+     * B1=0.
+     * Using this we need to calculate:
+     * A0 + A1 + B1 + (B0 + B1) * 2^224 = A0 + A1 + B0 * 2^224. */
+    for (round = 0; round < 2; ++round) {
 
-    /* Returned result should be 0 < X < P. Although we have controlled bit
-     * width, we may still have a result which is greater than P. Subtract P
-     * if this is the case. */
-    if (mbedtls_mpi_core_lt_ct(P, X, P_limbs)) {
-        mbedtls_mpi_core_sub(X, X, P, P_limbs);
+        /* Q = A1 */
+        memset(Q, 0, (Q_limbs * ciL));
+        memcpy(Q, X + P448_WIDTH, ((Q_limbs - 1) * ciL));
+
+        /* X = A0 */
+        memset(X + P448_WIDTH, 0, ((M_limbs - 1) * ciL));
+
+        /* M = B0 */
+        memcpy(M, Q, (Q_limbs * ciL));
+        M[M_limbs - 1] = 0;
+
+        if (sizeof(mbedtls_mpi_uint) > 4) {
+            M[P224_WIDTH_MIN] &= ((mbedtls_mpi_uint) -1) >> (P224_UNUSED_BITS);
+        }
+
+        /* M = B0 * 2^224
+         * Oversize M once again takes any carry. */
+        memmove((char *) M + P224_SIZE, M, P224_SIZE +
+        sizeof(mbedtls_mpi_uint)); memset(M, 0, P224_SIZE);
+
+        /* M = A1 + B0 * 2^224
+         * No need to call mbedtls_mpi_core_add() as both bignums
+         * should be all zero except one non-colliding limb each. */
+        for (i = 0; i < (M_limbs - 1); ++i) {
+            M[i] = M[i] + Q[i];
+        }
+
+        /* X = A0 + (A1 + B0 * 2^224) */
+        (void) mbedtls_mpi_core_add(X, X, M, M_limbs);
     }
 
     ret = 0;