Merge pull request #6449 from gilles-peskine-arm/bignum-core-shift_r

Bignum core: shift_r
diff --git a/library/bignum.c b/library/bignum.c
index 1c7f919..58cd2f7 100644
--- a/library/bignum.c
+++ b/library/bignum.c
@@ -771,42 +771,16 @@
  */
 int mbedtls_mpi_shift_r( mbedtls_mpi *X, size_t count )
 {
-    size_t i, v0, v1;
-    mbedtls_mpi_uint r0 = 0, r1;
     MPI_VALIDATE_RET( X != NULL );
-
-    v0 = count /  biL;
-    v1 = count & (biL - 1);
-
-    if( v0 > X->n || ( v0 == X->n && v1 > 0 ) )
-        return mbedtls_mpi_lset( X, 0 );
-
-    /*
-     * shift by count / limb_size
-     */
-    if( v0 > 0 )
-    {
-        for( i = 0; i < X->n - v0; i++ )
-            X->p[i] = X->p[i + v0];
-
-        for( ; i < X->n; i++ )
-            X->p[i] = 0;
-    }
-
-    /*
-     * shift by count % limb_size
-     */
-    if( v1 > 0 )
-    {
-        for( i = X->n; i > 0; i-- )
-        {
-            r1 = X->p[i - 1] << (biL - v1);
-            X->p[i - 1] >>= v1;
-            X->p[i - 1] |= r0;
-            r0 = r1;
-        }
-    }
-
+    if( X->n != 0 )
+        mbedtls_mpi_core_shift_r( X->p, X->n, count );
+
+    /* Sign-magnitude shifting truncates towards zero. If every significant
+     * bit of a negative X has been shifted out, the result is zero, which
+     * must be represented with a positive sign rather than as "-0". */
+    if( X->s < 0 && mbedtls_mpi_bitlen( X ) == 0 )
+        X->s = 1;
+
     return( 0 );
 }
 
diff --git a/library/bignum_core.c b/library/bignum_core.c
index 89fd404..0083729 100644
--- a/library/bignum_core.c
+++ b/library/bignum_core.c
@@ -316,6 +316,52 @@
     return( 0 );
 }
 
+
+
+void mbedtls_mpi_core_shift_r( mbedtls_mpi_uint *X, size_t limbs,
+                               size_t count )
+{
+    /* Decompose the shift into a number of whole limbs plus a number of
+     * bits strictly less than biL. */
+    size_t limb_shift = count / biL;
+    size_t bit_shift = count % biL;
+    size_t kept, idx;
+
+    /* Every significant bit is shifted out: the result is zero. This also
+     * covers limbs == 0 combined with a nonzero count. */
+    if( limb_shift > limbs || ( limb_shift == limbs && bit_shift != 0 ) )
+    {
+        memset( X, 0, limbs * ciL );
+        return;
+    }
+
+    /* Move the surviving limbs down (limb 0 is the least significant) and
+     * clear the vacated top limbs. Source and destination ranges may
+     * overlap, hence memmove rather than memcpy. */
+    if( limb_shift != 0 )
+    {
+        kept = limbs - limb_shift;
+        memmove( X, X + limb_shift, kept * ciL );
+        memset( X + kept, 0, limb_shift * ciL );
+    }
+
+    /* Shift the remaining bits: each limb keeps its own high bits and
+     * receives the low bits of the limb above it; the top limb receives
+     * zeros. This branch is only reached when limbs >= 1 (see the early
+     * return above), so X[limbs - 1] is in bounds. */
+    if( bit_shift != 0 )
+    {
+        for( idx = 0; idx + 1 < limbs; idx++ )
+        {
+            X[idx] = ( X[idx] >> bit_shift ) |
+                     ( X[idx + 1] << ( biL - bit_shift ) );
+        }
+        X[limbs - 1] >>= bit_shift;
+    }
+}
+
+
+
 mbedtls_mpi_uint mbedtls_mpi_core_add_if( mbedtls_mpi_uint *X,
                                           const mbedtls_mpi_uint *A,
                                           size_t limbs,
diff --git a/library/bignum_core.h b/library/bignum_core.h
index 196736d..56a3bf8 100644
--- a/library/bignum_core.h
+++ b/library/bignum_core.h
@@ -262,6 +262,21 @@
                                unsigned char *output,
                                size_t output_length );
 
+/** \brief              Shift an MPI right in place by \p count bits.
+ *
+ *                      Limb 0 of \p X is the least significant. On return,
+ *                      \p X holds its former value divided by 2^\p count and
+ *                      rounded down; shifting by at least the bit size of
+ *                      \p X is valid and leaves every limb equal to 0. The
+ *                      execution time depends on \p count and \p limbs.
+ *
+ * \param[in,out] X     The number to shift, an array of \p limbs limbs.
+ * \param limbs         The number of limbs of \p X. This must be at least 1.
+ * \param count         The number of bits to shift by.
+ */
+void mbedtls_mpi_core_shift_r( mbedtls_mpi_uint *X, size_t limbs,
+                               size_t count );
+
 /**
  * \brief Conditional addition of two fixed-size large unsigned integers,
  *        returning the carry.
diff --git a/scripts/mbedtls_dev/bignum_core.py b/scripts/mbedtls_dev/bignum_core.py
index 2e64195..e46364b 100644
--- a/scripts/mbedtls_dev/bignum_core.py
+++ b/scripts/mbedtls_dev/bignum_core.py
@@ -29,6 +29,47 @@
     target_basename = 'test_suite_bignum_core.generated'
 
 
+class BignumCoreShiftR(BignumCoreTarget, metaclass=ABCMeta):
+    """Test cases for mbedtls_mpi_core_shift_r()."""
+    count = 0
+    test_function = "mpi_core_shift_r"
+    test_name = "Core shift right"
+
+    DATA = [  # (input hex, description, shift counts to test)
+        ('00', '0', [0, 1, 8]),
+        ('01', '1', [0, 1, 2, 8, 64]),
+        ('dee5ca1a7ef10a75', '64-bit',
+         list(range(11)) + [31, 32, 33, 63, 64, 65, 71, 72]),
+        ('002e7ab0070ad57001', '[leading 0 limb]',
+         [0, 1, 8, 63, 64]),
+        ('a1055eb0bb1efa1150ff', '80-bit',
+         [0, 1, 8, 63, 64, 65, 72, 79, 80, 81, 88, 128, 129, 136]),
+        ('020100000000000000001011121314151617', '138-bit',
+         [0, 1, 8, 9, 16, 72, 73, 136, 137, 138, 144]),
+    ]
+
+    def __init__(self, input_hex: str, descr: str, count: int) -> None:
+        self.input_hex = input_hex
+        self.number_description = descr
+        self.shift_count = count
+        self.result = bignum_common.hex_to_int(input_hex) >> count  # expected value
+
+    def arguments(self) -> List[str]:
+        return ['"{}"'.format(self.input_hex),
+                str(self.shift_count),
+                '"{:0{}x}"'.format(self.result, len(self.input_hex))]  # input width
+
+    def description(self) -> str:
+        return 'Core shift {} >> {}'.format(self.number_description,
+                                            self.shift_count)
+
+    @classmethod
+    def generate_function_tests(cls) -> Iterator[test_case.TestCase]:
+        for input_hex, descr, counts in cls.DATA:  # one case per shift count
+            for count in counts:
+                yield cls(input_hex, descr, count).create_test_case()
+
+
 class BignumCoreOperation(bignum_common.OperationCommon, BignumCoreTarget, metaclass=ABCMeta):
     #pylint: disable=abstract-method
     """Common features for bignum core operations."""
diff --git a/tests/suites/test_suite_bignum_core.function b/tests/suites/test_suite_bignum_core.function
index 9803587..fb5fe3a 100644
--- a/tests/suites/test_suite_bignum_core.function
+++ b/tests/suites/test_suite_bignum_core.function
@@ -339,6 +339,26 @@
 /* END_CASE */
 
 /* BEGIN_CASE */
+void mpi_core_shift_r( char *input, int count, char *result )
+{
+    mbedtls_mpi_uint *X = NULL;     /* shifted in place */
+    mbedtls_mpi_uint *Y = NULL;     /* expected result */
+    size_t limbs = 0, expected_limbs = 0;
+
+    TEST_EQUAL( 0, mbedtls_test_read_mpi_core( &X, &limbs, input ) );
+    TEST_EQUAL( 0, mbedtls_test_read_mpi_core( &Y, &expected_limbs, result ) );
+    TEST_EQUAL( limbs, expected_limbs );
+
+    mbedtls_mpi_core_shift_r( X, limbs, count );
+    ASSERT_COMPARE( X, limbs * ciL, Y, expected_limbs * ciL );
+
+exit:
+    mbedtls_free( X );
+    mbedtls_free( Y );
+}
+/* END_CASE */
+
+/* BEGIN_CASE */
 void mpi_core_add_if( char * input_A, char * input_B,
                       char * input_S, int carry )
 {