Simplify the final reduction in mpi_montmul

There was some confusion during review about when A->p[n] could be
nonzero. In fact, there is no need to set A->p[n]: only the
intermediate result d might need to extend to n+1 limbs, not the final
result A. So never access A->p[n]. Rework the explanation of the
calculation in a way that should be easier to follow.
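The reworked final reduction can be illustrated with a minimal standalone sketch, assuming 32-bit limbs. The names limb_t, sub_limbs, cond_assign and final_reduce are hypothetical stand-ins loosely modelled on mbedtls_mpi_uint, mpi_sub_hlp and mpi_safe_cond_assign. They are not the library's API. Like the patched code, the sketch writes only the n least significant limbs of A. The only limb allowed to spill into position n is the intermediate value d.

```c
#include <stddef.h>
#include <stdint.h>
#include <string.h>

/* Hypothetical 32-bit limb type, standing in for mbedtls_mpi_uint. */
typedef uint32_t limb_t;

/* Hypothetical analogue of mpi_sub_hlp: d[0..n-1] -= s[0..n-1],
 * returning the final borrow (0 or 1). The borrow is taken from the
 * top bit of a 64-bit difference rather than from a comparison. */
static limb_t sub_limbs(size_t n, limb_t *d, const limb_t *s)
{
    limb_t borrow = 0;
    for (size_t i = 0; i < n; i++) {
        uint64_t diff = (uint64_t) d[i] - s[i] - borrow;
        d[i] = (limb_t) diff;            /* low 32 bits of the difference */
        borrow = (limb_t) (diff >> 63);  /* 1 iff the difference was negative */
    }
    return borrow;
}

/* Hypothetical analogue of mpi_safe_cond_assign: X = Y if assign == 1,
 * X unchanged if assign == 0, selected with a mask instead of a branch. */
static void cond_assign(size_t n, limb_t *X, const limb_t *Y,
                        unsigned char assign)
{
    limb_t mask = (limb_t) 0 - (limb_t) assign;  /* all ones or all zeros */
    for (size_t i = 0; i < n; i++)
        X[i] = (X[i] & ~mask) | (Y[i] & mask);
}

/* Final reduction: d holds n+1 limbs with value d0 < 2*N, and N holds
 * n limbs. On return A[0..n-1] = d0 mod N. Only A[0..n-1] is written;
 * d is clobbered. */
static void final_reduce(size_t n, limb_t *A, limb_t *d, const limb_t *N)
{
    memcpy(A, d, n * sizeof(limb_t));  /* A = d0 if d0 < N */
    d[n] += 1;                         /* d = d0 + (2^32)^n ...         */
    d[n] -= sub_limbs(n, d, N);        /* ... - N                       */
    /* Now d[n] == 1 iff d0 >= N, and then the low n limbs of d hold
     * d0 - N; otherwise d[n] == 0 and A keeps the copy of d0. */
    cond_assign(n, A, d, (unsigned char) d[n]);
}
```

For example, with n = 1 and N = 0xFFFFFFF1, an input d = {0x11, 1} (that is, N + 0x20) leaves A[0] == 0x20, while d = {5, 0} leaves A[0] == 5. Writing the source without branches does not by itself guarantee constant-time machine code; that depends on the compiler and target.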

Signed-off-by: Gilles Peskine <Gilles.Peskine@arm.com>
diff --git a/library/bignum.c b/library/bignum.c
index 26784f9..65f20b3 100644
--- a/library/bignum.c
+++ b/library/bignum.c
@@ -1730,8 +1730,8 @@
 /** Montgomery multiplication: A = A * B * R^-1 mod N  (HAC 14.36)
  *
  * \param[in,out]   A   One of the numbers to multiply.
- *                      It must have at least one more limb than N
- *                      (A->n >= N->n + 1).
+ *                      It must have at least as many limbs as N
+ *                      (A->n >= N->n), and any limbs beyond N->n are ignored.
  *                      On successful completion, A contains the result of
  *                      the multiplication A * B * R^-1 mod N where
  *                      R = (2^ciL)^n.
@@ -1775,18 +1775,25 @@
         *d++ = u0; d[n + 1] = 0;
     }
 
-    memcpy( A->p, d, ( n + 1 ) * ciL );
+    /* At this point, d is either the desired result or the desired result
+     * plus N. We now conditionally subtract N without revealing, through
+     * side channels, whether the subtraction was performed. */
 
-    /* If A >= N then A -= N. Do the subtraction unconditionally to prevent
-     * timing attacks. */
-    /* Set d to A + (2^biL)^n - N. */
+    /* Copy the n least significant limbs of d to A, so that
+     * A = d if d < N (recall that N has n limbs). */
+    memcpy( A->p, d, n * ciL );
+    /* If d >= N then we want to set A to d - N. To prevent timing attacks,
+     * do the calculation without using conditional tests. */
+    /* Set d to d0 + (2^biL)^n - N where d0 is the current value of d. */
     d[n] += 1;
     d[n] -= mpi_sub_hlp( n, d, N->p );
-    /* Now d - (2^biL)^n = A - N so d >= (2^biL)^n iff A >= N.
-     * So we want to copy the result of the subtraction iff d->p[n] != 0.
-     * Note that d->p[n] is either 0 or 1 since A - N <= N <= (2^biL)^n. */
-    mpi_safe_cond_assign( n + 1, A->p, d, (unsigned char) d[n] );
-    A->p[n] = 0;
+    /* If d0 < N then d < (2^biL)^n
+     * so d[n] == 0 and we want to keep A as it is.
+     * If d0 >= N then d >= (2^biL)^n, and d <= (2^biL)^n + N < 2 * (2^biL)^n
+     * so d[n] == 1 and we want to set A to the result of the subtraction
+     * which is d - (2^biL)^n, i.e. the n least significant limbs of d.
+     * This exactly corresponds to a conditional assignment. */
+    mpi_safe_cond_assign( n, A->p, d, (unsigned char) d[n] );
 }
 
 /*