Merge pull request #5701 from hanno-arm/mpi_mul_hlp

Make size of output in mpi_mul_hlp() explicit
diff --git a/library/bignum.c b/library/bignum.c
index 288f859..6f634b5 100644
--- a/library/bignum.c
+++ b/library/bignum.c
@@ -38,6 +38,7 @@
 #if defined(MBEDTLS_BIGNUM_C)
 
 #include "mbedtls/bignum.h"
+#include "bignum_internal.h"
 #include "bn_mul.h"
 #include "mbedtls/platform_util.h"
 #include "mbedtls/error.h"
@@ -1369,53 +1370,29 @@
     return( mbedtls_mpi_sub_mpi( X, A, &B ) );
 }
 
-/** Helper for mbedtls_mpi multiplication.
- *
- * Add \p b * \p s to \p d.
- *
- * \param i             The number of limbs of \p s.
- * \param[in] s         A bignum to multiply, of size \p i.
- *                      It may overlap with \p d, but only if
- *                      \p d <= \p s.
- *                      Its leading limb must not be \c 0.
- * \param[in,out] d     The bignum to add to.
- *                      It must be sufficiently large to store the
- *                      result of the multiplication. This means
- *                      \p i + 1 limbs if \p d[\p i - 1] started as 0 and \p b
- *                      is not known a priori.
- * \param b             A scalar to multiply.
- */
-static
-#if defined(__APPLE__) && defined(__arm__)
-/*
- * Apple LLVM version 4.2 (clang-425.0.24) (based on LLVM 3.2svn)
- * appears to need this to prevent bad ARM code generation at -O3.
- */
-__attribute__ ((noinline))
-#endif
-void mpi_mul_hlp( size_t i,
-                  const mbedtls_mpi_uint *s,
-                  mbedtls_mpi_uint *d,
-                  mbedtls_mpi_uint b )
+mbedtls_mpi_uint mbedtls_mpi_core_mla( mbedtls_mpi_uint *d, size_t d_len,
+                                       const mbedtls_mpi_uint *s, size_t s_len,
+                                       mbedtls_mpi_uint b )
 {
-    mbedtls_mpi_uint c = 0, t = 0;
+    mbedtls_mpi_uint c = 0; /* carry */
+    size_t excess_len = d_len - s_len; /* requires d_len >= s_len */
 
 #if defined(MULADDC_HUIT)
-    for( ; i >= 8; i -= 8 )
+    for( ; s_len >= 8; s_len -= 8 )
     {
         MULADDC_INIT
         MULADDC_HUIT
         MULADDC_STOP
     }
 
-    for( ; i > 0; i-- )
+    for( ; s_len > 0; s_len-- )
     {
         MULADDC_INIT
         MULADDC_CORE
         MULADDC_STOP
     }
 #else /* MULADDC_HUIT */
-    for( ; i >= 16; i -= 16 )
+    for( ; s_len >= 16; s_len -= 16 )
     {
         MULADDC_INIT
         MULADDC_CORE   MULADDC_CORE
@@ -1430,7 +1407,7 @@
         MULADDC_STOP
     }
 
-    for( ; i >= 8; i -= 8 )
+    for( ; s_len >= 8; s_len -= 8 )
     {
         MULADDC_INIT
         MULADDC_CORE   MULADDC_CORE
@@ -1441,7 +1418,7 @@
         MULADDC_STOP
     }
 
-    for( ; i > 0; i-- )
+    for( ; s_len > 0; s_len-- )
     {
         MULADDC_INIT
         MULADDC_CORE
@@ -1449,12 +1426,12 @@
     }
 #endif /* MULADDC_HUIT */
 
-    t++;
-
-    while( c != 0 )
+    while( excess_len-- ) /* no early exit once c becomes 0 */
     {
         *d += c; c = ( *d < c ); d++;
     }
+
+    return( c );
 }
 
 /*
@@ -1490,8 +1467,14 @@
     MBEDTLS_MPI_CHK( mbedtls_mpi_grow( X, i + j ) );
     MBEDTLS_MPI_CHK( mbedtls_mpi_lset( X, 0 ) );
 
-    for( ; j > 0; j-- )
-        mpi_mul_hlp( i, A->p, X->p + j - 1, B->p[j - 1] );
+    for( size_t k = 0; k < j; k++ )
+    {
+        /* No carry-out: before step k, X < 2^(biL * (i + k)), so adding
+         * A * B->p[k] at limb k still fits in limbs k .. k + i. */
+        (void) mbedtls_mpi_core_mla( X->p + k, i + 1,
+                                     A->p, i,
+                                     B->p[k] );
+    }
 
     /* If the result is 0, we don't shortcut the operation, which reduces
      * but does not eliminate side channels leaking the zero-ness. We do
@@ -1517,19 +1500,15 @@
     MPI_VALIDATE_RET( X != NULL );
     MPI_VALIDATE_RET( A != NULL );
 
-    /* mpi_mul_hlp can't deal with a leading 0. */
     size_t n = A->n;
     while( n > 0 && A->p[n - 1] == 0 )
         --n;
 
-    /* The general method below doesn't work if n==0 or b==0. By chance
-     * calculating the result is trivial in those cases. */
+    /* The general method below doesn't work if b==0 (b - 1 wraps). */
     if( b == 0 || n == 0 )
-    {
         return( mbedtls_mpi_lset( X, 0 ) );
-    }
 
-    /* Calculate A*b as A + A*(b-1) to take advantage of mpi_mul_hlp */
+    /* Calculate A*b as A + A*(b-1) to take advantage of mbedtls_mpi_core_mla */
     int ret = MBEDTLS_ERR_ERROR_CORRUPTION_DETECTED;
     /* In general, A * b requires 1 limb more than b. If
      * A->p[n - 1] * b / b == A->p[n - 1], then A * b fits in the same
@@ -1538,10 +1517,13 @@
      * making the call to grow() unconditional causes slightly fewer
      * calls to calloc() in ECP code, presumably because it reuses the
      * same mpi for a while and this way the mpi is more likely to directly
-     * grow to its final size. */
+     * grow to its final size.
+     *
+     * Note that calculating A*b as 0 + A*b doesn't work as-is because
+     * A,X can be the same. */
     MBEDTLS_MPI_CHK( mbedtls_mpi_grow( X, n + 1 ) );
     MBEDTLS_MPI_CHK( mbedtls_mpi_copy( X, A ) );
-    mpi_mul_hlp( n, A->p, X->p, b - 1 );
+    (void) mbedtls_mpi_core_mla( X->p, X->n, A->p, n, b - 1 );
 
 cleanup:
     return( ret );
@@ -1907,8 +1889,8 @@
  * \param           mm  The value calculated by `mpi_montg_init(&mm, N)`.
  *                      This is -N^-1 mod 2^ciL.
  * \param[in,out]   T   A bignum for temporary storage.
- *                      It must be at least twice the limb size of N plus 2
- *                      (T->n >= 2 * (N->n + 1)).
+ *                      It must be at least twice the limb size of N plus 1
+ *                      (T->n >= 2 * N->n + 1).
  *                      Its initial content is unused and
  *                      its final content is indeterminate.
  *                      Note that unlike the usual convention in the library
@@ -1917,8 +1899,8 @@
 static void mpi_montmul( mbedtls_mpi *A, const mbedtls_mpi *B, const mbedtls_mpi *N, mbedtls_mpi_uint mm,
                          const mbedtls_mpi *T )
 {
-    size_t i, n, m;
-    mbedtls_mpi_uint u0, u1, *d;
+    size_t n, m;
+    mbedtls_mpi_uint *d;
 
     memset( T->p, 0, T->n * ciL );
 
@@ -1926,18 +1908,23 @@
     n = N->n;
     m = ( B->n < n ) ? B->n : n;
 
-    for( i = 0; i < n; i++ )
+    for( size_t i = 0; i < n; i++ )
     {
+        mbedtls_mpi_uint u0, u1;
+
         /*
          * T = (T + u0*B + u1*N) / 2^biL
          */
         u0 = A->p[i];
         u1 = ( d[0] + u0 * B->p[0] ) * mm;
 
-        mpi_mul_hlp( m, B->p, d, u0 );
-        mpi_mul_hlp( n, N->p, d, u1 );
-
-        d++; d[n + 1] = 0;
+        (void) mbedtls_mpi_core_mla( d, n + 2,
+                                     B->p, m,
+                                     u0 );
+        (void) mbedtls_mpi_core_mla( d, n + 2,
+                                     N->p, n,
+                                     u1 );
+        d++; /* the "/ 2^biL": move the accumulation window up one limb */
     }
 
     /* At this point, d is either the desired result or the desired result
diff --git a/library/bignum_internal.h b/library/bignum_internal.h
new file mode 100644
index 0000000..8677dcf
--- /dev/null
+++ b/library/bignum_internal.h
@@ -0,0 +1,54 @@
+/**
+ *  Internal bignum functions
+ *
+ *  Copyright The Mbed TLS Contributors
+ *  SPDX-License-Identifier: Apache-2.0
+ *
+ *  Licensed under the Apache License, Version 2.0 (the "License"); you may
+ *  not use this file except in compliance with the License.
+ *  You may obtain a copy of the License at
+ *
+ *  http://www.apache.org/licenses/LICENSE-2.0
+ *
+ *  Unless required by applicable law or agreed to in writing, software
+ *  distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+ *  WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ *  See the License for the specific language governing permissions and
+ *  limitations under the License.
+ */
+
+#ifndef MBEDTLS_BIGNUM_INTERNAL_H
+#define MBEDTLS_BIGNUM_INTERNAL_H
+
+#include "common.h"
+
+#if defined(MBEDTLS_BIGNUM_C)
+#include "mbedtls/bignum.h"
+#endif
+
+/** Perform a known-size multiply accumulate operation.
+ *
+ * Add \p b * \p s to \p d.
+ *
+ * Beyond the low \p s_len limbs, the carry is propagated through all
+ * \p d_len - \p s_len remaining limbs of \p d, without stopping early
+ * once it becomes 0. Any carry out of the top limb of \p d is returned.
+ *
+ * \param[in,out] d     The pointer to the (little-endian) array
+ *                      representing the bignum to accumulate onto.
+ * \param d_len         The number of limbs of \p d. This must be
+ *                      at least \p s_len.
+ * \param[in] s         The pointer to the (little-endian) array
+ *                      representing the bignum to multiply with.
+ *                      This may be the same as \p d. Otherwise,
+ *                      it must be disjoint from \p d.
+ * \param s_len         The number of limbs of \p s.
+ * \param b             A scalar to multiply with.
+ *
+ * \return              The carry out of the top limb of \p d.
+ */
+mbedtls_mpi_uint mbedtls_mpi_core_mla( mbedtls_mpi_uint *d, size_t d_len,
+                                       const mbedtls_mpi_uint *s, size_t s_len,
+                                       mbedtls_mpi_uint b );
+
+#endif /* MBEDTLS_BIGNUM_INTERNAL_H */
diff --git a/library/bn_mul.h b/library/bn_mul.h
index b71ddd8..aa1183f 100644
--- a/library/bn_mul.h
+++ b/library/bn_mul.h
@@ -99,6 +99,7 @@
 #if defined(__i386__) && defined(__OPTIMIZE__)
 
 #define MULADDC_INIT                        \
+    { mbedtls_mpi_uint t;                   \
     asm(                                    \
         "movl   %%ebx, %0           \n\t"   \
         "movl   %5, %%esi           \n\t"   \
@@ -190,7 +191,7 @@
         : "=m" (t), "=m" (c), "=m" (d), "=m" (s)        \
         : "m" (t), "m" (s), "m" (d), "m" (c), "m" (b)   \
         : "eax", "ebx", "ecx", "edx", "esi", "edi"      \
-    );
+    ); }
 
 #else
 
@@ -202,7 +203,7 @@
         : "=m" (t), "=m" (c), "=m" (d), "=m" (s)        \
         : "m" (t), "m" (s), "m" (d), "m" (c), "m" (b)   \
         : "eax", "ebx", "ecx", "edx", "esi", "edi"      \
-    );
+    ); }
 #endif /* SSE2 */
 #endif /* i386 */
 
diff --git a/library/ecp_curves.c b/library/ecp_curves.c
index 421a067..6b8ff5c 100644
--- a/library/ecp_curves.c
+++ b/library/ecp_curves.c
@@ -26,6 +26,7 @@
 #include "mbedtls/error.h"
 
 #include "bn_mul.h"
+#include "bignum_internal.h"
 #include "ecp_invasive.h"
 
 #include <string.h>
@@ -5213,40 +5214,30 @@
 
 /*
  * Fast quasi-reduction modulo p255 = 2^255 - 19
- * Write N as A0 + 2^255 A1, return A0 + 19 * A1
+ * Write N as A0 + 2^256 A1, return A0 + 38 * A1
  */
 static int ecp_mod_p255( mbedtls_mpi *N )
 {
-    int ret = MBEDTLS_ERR_ERROR_CORRUPTION_DETECTED;
-    size_t i;
-    mbedtls_mpi M;
-    mbedtls_mpi_uint Mp[P255_WIDTH + 2];
+    mbedtls_mpi_uint Mp[P255_WIDTH];
 
-    if( N->n < P255_WIDTH )
+    if( N->n <= P255_WIDTH )
         return( 0 );
-
-    /* M = A1 */
-    M.s = 1;
-    M.n = N->n - ( P255_WIDTH - 1 );
-    if( M.n > P255_WIDTH + 1 )
+    /* Top part of N; only formed once N->n > P255_WIDTH is known */
+    mbedtls_mpi_uint * const NT_p = N->p + P255_WIDTH;
+    const size_t NT_n = N->n - P255_WIDTH;
+    if( NT_n > P255_WIDTH )
         return( MBEDTLS_ERR_ECP_BAD_INPUT_DATA );
-    M.p = Mp;
-    memset( Mp, 0, sizeof Mp );
-    memcpy( Mp, N->p + P255_WIDTH - 1, M.n * sizeof( mbedtls_mpi_uint ) );
-    MBEDTLS_MPI_CHK( mbedtls_mpi_shift_r( &M, 255 % ( 8 * sizeof( mbedtls_mpi_uint ) ) ) );
-    M.n++; /* Make room for multiplication by 19 */
 
-    /* N = A0 */
-    MBEDTLS_MPI_CHK( mbedtls_mpi_set_bit( N, 255, 0 ) );
-    for( i = P255_WIDTH; i < N->n; i++ )
-        N->p[i] = 0;
+    /* Move A1 into Mp and clear it from N, leaving N = A0 */
+    memcpy( Mp,   NT_p, sizeof( mbedtls_mpi_uint ) * NT_n );
+    memset( NT_p, 0,    sizeof( mbedtls_mpi_uint ) * NT_n );
 
-    /* N = A0 + 19 * A1 */
-    MBEDTLS_MPI_CHK( mbedtls_mpi_mul_int( &M, &M, 19 ) );
-    MBEDTLS_MPI_CHK( mbedtls_mpi_add_abs( N, N, &M ) );
+    /* N = A0 + 38 * A1, since 2^256 == 2 * 19 == 38 (mod p255) */
+    (void) mbedtls_mpi_core_mla( N->p, P255_WIDTH + 1,
+                                 Mp, NT_n,
+                                 38 );
 
-cleanup:
-    return( ret );
+    return( 0 );
 }
 #endif /* MBEDTLS_ECP_DP_CURVE25519_ENABLED */