Merge pull request #3512 from gilles-peskine-arm/ecp-alloc-202007

Reduce the number of allocations in ECP operations
diff --git a/library/bignum.c b/library/bignum.c
index e74a1ad..56d7dbe 100644
--- a/library/bignum.c
+++ b/library/bignum.c
@@ -1339,29 +1339,32 @@
 /**
  * Helper for mbedtls_mpi subtraction.
  *
- * Calculate d - s where d and s have the same size.
+ * Calculate l - r where l and r have the same size.
  * This function operates modulo (2^ciL)^n and returns the carry
- * (1 if there was a wraparound, i.e. if `d < s`, and 0 otherwise).
+ * (1 if there was a wraparound, i.e. if `l < r`, and 0 otherwise).
  *
- * \param n             Number of limbs of \p d and \p s.
- * \param[in,out] d     On input, the left operand.
- *                      On output, the result of the subtraction:
- * \param[in] s         The right operand.
+ * d may be aliased to l or r.
  *
- * \return              1 if `d < s`.
- *                      0 if `d >= s`.
+ * \param n             Number of limbs of \p d, \p l and \p r.
+ * \param[out] d        The result of the subtraction.
+ * \param[in] l         The left operand.
+ * \param[in] r         The right operand.
+ *
+ * \return              1 if `l < r`.
+ *                      0 if `l >= r`.
  */
 static mbedtls_mpi_uint mpi_sub_hlp( size_t n,
                                      mbedtls_mpi_uint *d,
-                                     const mbedtls_mpi_uint *s )
+                                     const mbedtls_mpi_uint *l,
+                                     const mbedtls_mpi_uint *r )
 {
     size_t i;
-    mbedtls_mpi_uint c, z;
+    mbedtls_mpi_uint c = 0, t, z;
 
-    for( i = c = 0; i < n; i++, s++, d++ )
+    for( i = 0; i < n; i++ )
     {
-        z = ( *d <  c );     *d -=  c;
-        c = ( *d < *s ) + z; *d -= *s;
+        z = ( l[i] <  c );    t = l[i] - c;
+        c = ( t < r[i] ) + z; d[i] = t - r[i];
     }
 
     return( c );
@@ -1372,7 +1375,6 @@
  */
 int mbedtls_mpi_sub_abs( mbedtls_mpi *X, const mbedtls_mpi *A, const mbedtls_mpi *B )
 {
-    mbedtls_mpi TB;
     int ret = MBEDTLS_ERR_ERROR_CORRUPTION_DETECTED;
     size_t n;
     mbedtls_mpi_uint carry;
@@ -1380,24 +1382,6 @@
     MPI_VALIDATE_RET( A != NULL );
     MPI_VALIDATE_RET( B != NULL );
 
-    mbedtls_mpi_init( &TB );
-
-    if( X == B )
-    {
-        MBEDTLS_MPI_CHK( mbedtls_mpi_copy( &TB, B ) );
-        B = &TB;
-    }
-
-    if( X != A )
-        MBEDTLS_MPI_CHK( mbedtls_mpi_copy( X, A ) );
-
-    /*
-     * X should always be positive as a result of unsigned subtractions.
-     */
-    X->s = 1;
-
-    ret = 0;
-
     for( n = B->n; n > 0; n-- )
         if( B->p[n - 1] != 0 )
             break;
@@ -1408,7 +1392,17 @@
         goto cleanup;
     }
 
-    carry = mpi_sub_hlp( n, X->p, B->p );
+    MBEDTLS_MPI_CHK( mbedtls_mpi_grow( X, A->n ) );
+
+    /* Set the high limbs of X to match A. Don't touch the lower limbs
+     * because X might be aliased to B, and we must not overwrite the
+     * significant digits of B. */
+    if( A->n > n )
+        memcpy( X->p + n, A->p + n, ( A->n - n ) * ciL );
+    if( X->n > A->n )
+        memset( X->p + A->n, 0, ( X->n - A->n ) * ciL );
+
+    carry = mpi_sub_hlp( n, X->p, A->p, B->p );
     if( carry != 0 )
     {
         /* Propagate the carry to the first nonzero limb of X. */
@@ -1424,10 +1418,10 @@
         --X->p[n];
     }
 
+    /* X should always be positive as a result of unsigned subtractions. */
+    X->s = 1;
+
 cleanup:
-
-    mbedtls_mpi_free( &TB );
-
     return( ret );
 }
 
@@ -1537,8 +1531,21 @@
     return( mbedtls_mpi_sub_mpi( X, A, &_B ) );
 }
 
-/*
- * Helper for mbedtls_mpi multiplication
+/** Helper for mbedtls_mpi multiplication.
+ *
+ * Add \p b * \p s to \p d.
+ *
+ * \param i             The number of limbs of \p s.
+ * \param[in] s         A bignum to multiply, of size \p i.
+ *                      It may overlap with \p d, but only if
+ *                      \p d <= \p s.
+ *                      Its leading limb must not be \c 0.
+ * \param[in,out] d     The bignum to add to.
+ *                      It must be sufficiently large to store the
+ *                      result of the multiplication. This means
+ *                      \p i + 1 limbs if \p d[\p i - 1] started as 0 and \p b
+ *                      is not known a priori.
+ * \param b             A scalar to multiply.
  */
 static
 #if defined(__APPLE__) && defined(__arm__)
@@ -1548,7 +1555,10 @@
  */
 __attribute__ ((noinline))
 #endif
-void mpi_mul_hlp( size_t i, mbedtls_mpi_uint *s, mbedtls_mpi_uint *d, mbedtls_mpi_uint b )
+void mpi_mul_hlp( size_t i,
+                  const mbedtls_mpi_uint *s,
+                  mbedtls_mpi_uint *d,
+                  mbedtls_mpi_uint b )
 {
     mbedtls_mpi_uint c = 0, t = 0;
 
@@ -1603,10 +1613,10 @@
 
     t++;
 
-    do {
+    while( c != 0 )
+    {
         *d += c; c = ( *d < c ); d++;
     }
-    while( c != 0 );
 }
 
 /*
@@ -1654,17 +1664,38 @@
  */
 int mbedtls_mpi_mul_int( mbedtls_mpi *X, const mbedtls_mpi *A, mbedtls_mpi_uint b )
 {
-    mbedtls_mpi _B;
-    mbedtls_mpi_uint p[1];
     MPI_VALIDATE_RET( X != NULL );
     MPI_VALIDATE_RET( A != NULL );
 
-    _B.s = 1;
-    _B.n = 1;
-    _B.p = p;
-    p[0] = b;
+    /* mpi_mul_hlp can't deal with a leading 0. */
+    size_t n = A->n;
+    while( n > 0 && A->p[n - 1] == 0 )
+        --n;
 
-    return( mbedtls_mpi_mul_mpi( X, A, &_B ) );
+    /* The general method below doesn't work if n==0 or b==0. By chance
+     * calculating the result is trivial in those cases. */
+    if( b == 0 || n == 0 )
+    {
+        mbedtls_mpi_lset( X, 0 );
+        return( 0 );
+    }
+
+    /* Calculate A*b as A + A*(b-1) to take advantage of mpi_mul_hlp */
+    int ret = MBEDTLS_ERR_ERROR_CORRUPTION_DETECTED;
+    /* In general, A * b requires 1 limb more than A. If
+     * A->p[n - 1] * b / b == A->p[n - 1], then A * b fits in the same
+     * number of limbs as A and the call to grow() is not required since
+     * copy() will take care of the growth if needed. However, experimentally,
+     * making the call to grow() unconditional causes slightly fewer
+     * calls to calloc() in ECP code, presumably because it reuses the
+     * same mpi for a while and this way the mpi is more likely to directly
+     * grow to its final size. */
+    MBEDTLS_MPI_CHK( mbedtls_mpi_grow( X, n + 1 ) );
+    MBEDTLS_MPI_CHK( mbedtls_mpi_copy( X, A ) );
+    mpi_mul_hlp( n, A->p, X->p, b - 1 );
+
+cleanup:
+    return( ret );
 }
 
 /*
@@ -1805,7 +1836,7 @@
 
     MBEDTLS_MPI_CHK( mbedtls_mpi_grow( &Z, A->n + 2 ) );
     MBEDTLS_MPI_CHK( mbedtls_mpi_lset( &Z,  0 ) );
-    MBEDTLS_MPI_CHK( mbedtls_mpi_grow( &T1, 2 ) );
+    MBEDTLS_MPI_CHK( mbedtls_mpi_grow( &T1, A->n + 2 ) );
 
     k = mbedtls_mpi_bitlen( &Y ) % biL;
     if( k < biL - 1 )
@@ -2071,7 +2102,7 @@
      * do the calculation without using conditional tests. */
     /* Set d to d0 + (2^biL)^n - N where d0 is the current value of d. */
     d[n] += 1;
-    d[n] -= mpi_sub_hlp( n, d, N->p );
+    d[n] -= mpi_sub_hlp( n, d, d, N->p );
     /* If d0 < N then d < (2^biL)^n
      * so d[n] == 0 and we want to keep A as it is.
      * If d0 >= N then d >= (2^biL)^n, and d <= (2^biL)^n + N < 2 * (2^biL)^n
diff --git a/library/ecp_curves.c b/library/ecp_curves.c
index 839fb5e..962d5af 100644
--- a/library/ecp_curves.c
+++ b/library/ecp_curves.c
@@ -1000,25 +1000,20 @@
 #define ADD( j )    add32( &cur, A( j ), &c );
 #define SUB( j )    sub32( &cur, A( j ), &c );
 
+#define ciL    (sizeof(mbedtls_mpi_uint))         /* chars in limb  */
+#define biL    (ciL << 3)                         /* bits  in limb  */
+
 /*
  * Helpers for the main 'loop'
- * (see fix_negative for the motivation of C)
  */
 #define INIT( b )                                                       \
-    int ret = MBEDTLS_ERR_ERROR_CORRUPTION_DETECTED;                                                            \
+    int ret = MBEDTLS_ERR_ERROR_CORRUPTION_DETECTED;                    \
     signed char c = 0, cc;                                              \
     uint32_t cur;                                                       \
     size_t i = 0, bits = (b);                                           \
-    mbedtls_mpi C;                                                      \
-    mbedtls_mpi_uint Cp[ (b) / 8 / sizeof( mbedtls_mpi_uint) + 1 ];     \
-                                                                        \
-    C.s = 1;                                                            \
-    C.n = (b) / 8 / sizeof( mbedtls_mpi_uint) + 1;                      \
-    C.p = Cp;                                                           \
-    memset( Cp, 0, C.n * sizeof( mbedtls_mpi_uint ) );                  \
-                                                                        \
-    MBEDTLS_MPI_CHK( mbedtls_mpi_grow( N, (b) * 2 / 8 /                 \
-                                       sizeof( mbedtls_mpi_uint ) ) );  \
+    /* N is the size of the product of two b-bit numbers, plus one */   \
+    /* limb for fix_negative */                                         \
+    MBEDTLS_MPI_CHK( mbedtls_mpi_grow( N, ( b ) * 2 / biL + 1 ) );      \
     LOAD32;
 
 #define NEXT                    \
@@ -1033,33 +1028,32 @@
     STORE32; i++;                               \
     cur = c > 0 ? c : 0; STORE32;               \
     cur = 0; while( ++i < MAX32 ) { STORE32; }  \
-    if( c < 0 ) MBEDTLS_MPI_CHK( fix_negative( N, c, &C, bits ) );
+    if( c < 0 ) fix_negative( N, c, bits );
 
 /*
  * If the result is negative, we get it in the form
  * c * 2^(bits + 32) + N, with c negative and N positive shorter than 'bits'
  */
-static inline int fix_negative( mbedtls_mpi *N, signed char c, mbedtls_mpi *C, size_t bits )
+static inline void fix_negative( mbedtls_mpi *N, signed char c, size_t bits )
 {
-    int ret = MBEDTLS_ERR_ERROR_CORRUPTION_DETECTED;
+    size_t i;
 
-    /* C = - c * 2^(bits + 32) */
-#if !defined(MBEDTLS_HAVE_INT64)
-    ((void) bits);
-#else
-    if( bits == 224 )
-        C->p[ C->n - 1 ] = ((mbedtls_mpi_uint) -c) << 32;
-    else
-#endif
-        C->p[ C->n - 1 ] = (mbedtls_mpi_uint) -c;
-
-    /* N = - ( C - N ) */
-    MBEDTLS_MPI_CHK( mbedtls_mpi_sub_abs( N, C, N ) );
+    /* Set N := N - 2^bits. NOTE(review): no borrow propagates if N->p[0] == 0. */
+    --N->p[0];
+    for( i = 0; i <= bits / 8 / sizeof( mbedtls_mpi_uint ); i++ )
+    {
+        N->p[i] = ~(mbedtls_mpi_uint)0 - N->p[i];
+    }
     N->s = -1;
 
-cleanup:
-
-    return( ret );
+    /* Add |c| * 2^bits to the absolute value. Since c and N are
+     * negative, this adds c * 2^bits. */
+    mbedtls_mpi_uint msw = (mbedtls_mpi_uint) -c;
+#if defined(MBEDTLS_HAVE_INT64)
+    if( bits == 224 )
+        msw <<= 32;
+#endif
+    N->p[bits / 8 / sizeof( mbedtls_mpi_uint)] += msw;
 }
 
 #if defined(MBEDTLS_ECP_DP_SECP224R1_ENABLED)
diff --git a/programs/test/benchmark.c b/programs/test/benchmark.c
index 251cbb6..9c5911b 100644
--- a/programs/test/benchmark.c
+++ b/programs/test/benchmark.c
@@ -266,6 +266,21 @@
 #define ecp_clear_precomputed( g )
 #endif
 
+#if defined(MBEDTLS_ECP_C)
+static int set_ecp_curve( const char *string, mbedtls_ecp_curve_info *curve )
+{
+    const mbedtls_ecp_curve_info *found =
+        mbedtls_ecp_curve_info_from_name( string );
+    if( found != NULL )
+    {
+        *curve = *found;
+        return( 1 );
+    }
+    else
+        return( 0 );
+}
+#endif
+
 unsigned char buf[BUFSIZE];
 
 typedef struct {
@@ -289,6 +304,17 @@
 #if defined(MBEDTLS_MEMORY_BUFFER_ALLOC_C)
     unsigned char alloc_buf[HEAP_SIZE] = { 0 };
 #endif
+#if defined(MBEDTLS_ECP_C)
+    mbedtls_ecp_curve_info single_curve[2] = {
+        { MBEDTLS_ECP_DP_NONE, 0, 0, NULL },
+        { MBEDTLS_ECP_DP_NONE, 0, 0, NULL },
+    };
+    const mbedtls_ecp_curve_info *curve_list = mbedtls_ecp_curve_list( );
+#endif
+
+#if defined(MBEDTLS_ECP_C)
+    (void) curve_list; /* Unused in some configurations where no benchmark uses ECC */
+#endif
 
     if( argc <= 1 )
     {
@@ -356,6 +382,10 @@
                 todo.ecdsa = 1;
             else if( strcmp( argv[i], "ecdh" ) == 0 )
                 todo.ecdh = 1;
+#if defined(MBEDTLS_ECP_C)
+            else if( set_ecp_curve( argv[i], single_curve ) )
+                curve_list = single_curve;
+#endif
             else
             {
                 mbedtls_printf( "Unrecognized option: %s\n", argv[i] );
@@ -845,7 +875,7 @@
 
         memset( buf, 0x2A, sizeof( buf ) );
 
-        for( curve_info = mbedtls_ecp_curve_list();
+        for( curve_info = curve_list;
              curve_info->grp_id != MBEDTLS_ECP_DP_NONE;
              curve_info++ )
         {
@@ -867,7 +897,7 @@
             mbedtls_ecdsa_free( &ecdsa );
         }
 
-        for( curve_info = mbedtls_ecp_curve_list();
+        for( curve_info = curve_list;
              curve_info->grp_id != MBEDTLS_ECP_DP_NONE;
              curve_info++ )
         {
@@ -911,8 +941,23 @@
         };
         const mbedtls_ecp_curve_info *curve_info;
         size_t olen;
+        const mbedtls_ecp_curve_info *selected_montgomery_curve_list =
+            montgomery_curve_list;
 
-        for( curve_info = mbedtls_ecp_curve_list();
+        if( curve_list == (const mbedtls_ecp_curve_info*) &single_curve )
+        {
+            mbedtls_ecp_group grp;
+            mbedtls_ecp_group_init( &grp );
+            if( mbedtls_ecp_group_load( &grp, curve_list->grp_id ) != 0 )
+                mbedtls_exit( 1 );
+            if( mbedtls_ecp_get_type( &grp ) == MBEDTLS_ECP_TYPE_MONTGOMERY )
+                selected_montgomery_curve_list = single_curve;
+            else /* empty list */
+                selected_montgomery_curve_list = single_curve + 1;
+            mbedtls_ecp_group_free( &grp );
+        }
+
+        for( curve_info = curve_list;
              curve_info->grp_id != MBEDTLS_ECP_DP_NONE;
              curve_info++ )
         {
@@ -938,7 +983,7 @@
         }
 
         /* Montgomery curves need to be handled separately */
-        for ( curve_info = montgomery_curve_list;
+        for ( curve_info = selected_montgomery_curve_list;
               curve_info->grp_id != MBEDTLS_ECP_DP_NONE;
               curve_info++ )
         {
@@ -960,7 +1005,7 @@
             mbedtls_mpi_free( &z );
         }
 
-        for( curve_info = mbedtls_ecp_curve_list();
+        for( curve_info = curve_list;
              curve_info->grp_id != MBEDTLS_ECP_DP_NONE;
              curve_info++ )
         {
@@ -986,7 +1031,7 @@
         }
 
         /* Montgomery curves need to be handled separately */
-        for ( curve_info = montgomery_curve_list;
+        for ( curve_info = selected_montgomery_curve_list;
               curve_info->grp_id != MBEDTLS_ECP_DP_NONE;
               curve_info++)
         {
@@ -1015,7 +1060,6 @@
     {
         mbedtls_ecdh_context ecdh_srv, ecdh_cli;
         unsigned char buf_srv[BUFSIZE], buf_cli[BUFSIZE];
-        const mbedtls_ecp_curve_info * curve_list = mbedtls_ecp_curve_list();
         const mbedtls_ecp_curve_info *curve_info;
         size_t olen;