Merge pull request #6446 from yanesca/add_split_arch_tests_to_bignum_core

Add split-architecture (32-bit and 64-bit limb) tests to bignum core
diff --git a/scripts/mbedtls_dev/bignum_core.py b/scripts/mbedtls_dev/bignum_core.py
index 3652ac2..2e64195 100644
--- a/scripts/mbedtls_dev/bignum_core.py
+++ b/scripts/mbedtls_dev/bignum_core.py
@@ -61,8 +61,8 @@
         generated to provide some context to the test case.
         """
         if not self.case_description:
-            self.case_description = "{} {} {}".format(
-                self.arg_a, self.symbol, self.arg_b
+            self.case_description = "{:x} {} {:x}".format(
+                self.int_a, self.symbol, self.int_b
             )
         return super().description()
 
@@ -72,8 +72,38 @@
             yield cls(a_value, b_value).create_test_case()
 
 
+class BignumCoreOperationArchSplit(BignumCoreOperation):
+    #pylint: disable=abstract-method
+    """Common features for bignum core operations where the result depends on
+    the limb size."""
 
-class BignumCoreAddIf(BignumCoreOperation):
+    def __init__(self, val_a: str, val_b: str, bits_in_limb: int) -> None:
+        super().__init__(val_a, val_b)
+        bound_val = max(self.int_a, self.int_b)
+        self.bits_in_limb = bits_in_limb
+        self.bound = bignum_common.bound_mpi(bound_val, self.bits_in_limb)
+        limbs = bignum_common.limbs_mpi(bound_val, self.bits_in_limb)
+        byte_len = limbs * self.bits_in_limb // 8
+        self.hex_digits = 2 * byte_len
+        if self.bits_in_limb == 32:
+            self.dependencies = ["MBEDTLS_HAVE_INT32"]
+        elif self.bits_in_limb == 64:
+            self.dependencies = ["MBEDTLS_HAVE_INT64"]
+        else:
+            raise ValueError("Invalid number of bits in limb!")
+        self.arg_a = self.arg_a.zfill(self.hex_digits)
+        self.arg_b = self.arg_b.zfill(self.hex_digits)
+
+    def pad_to_limbs(self, val) -> str:
+        return "{:x}".format(val).zfill(self.hex_digits)
+
+    @classmethod
+    def generate_function_tests(cls) -> Iterator[test_case.TestCase]:
+        for a_value, b_value in cls.get_value_pairs():
+            yield cls(a_value, b_value, 32).create_test_case()
+            yield cls(a_value, b_value, 64).create_test_case()
+
+class BignumCoreAddIf(BignumCoreOperationArchSplit):
     """Test cases for bignum core add if."""
     count = 0
     symbol = "+"
@@ -81,19 +111,14 @@
     test_name = "mbedtls_mpi_core_add_if"
 
     def result(self) -> List[str]:
-        tmp = self.int_a + self.int_b
-        bound_val = max(self.int_a, self.int_b)
-        bound_4 = bignum_common.bound_mpi(bound_val, 32)
-        bound_8 = bignum_common.bound_mpi(bound_val, 64)
-        carry_4, remainder_4 = divmod(tmp, bound_4)
-        carry_8, remainder_8 = divmod(tmp, bound_8)
-        return [
-            "\"{:x}\"".format(remainder_4),
-            str(carry_4),
-            "\"{:x}\"".format(remainder_8),
-            str(carry_8)
-        ]
+        result = self.int_a + self.int_b
 
+        carry, result = divmod(result, self.bound)
+
+        return [
+            bignum_common.quote_str(self.pad_to_limbs(result)),
+            str(carry)
+        ]
 
 class BignumCoreSub(BignumCoreOperation):
     """Test cases for bignum core sub."""
diff --git a/tests/suites/test_suite_bignum_core.function b/tests/suites/test_suite_bignum_core.function
index de8b7f1..9803587 100644
--- a/tests/suites/test_suite_bignum_core.function
+++ b/tests/suites/test_suite_bignum_core.function
@@ -340,118 +340,75 @@
 
 /* BEGIN_CASE */
 void mpi_core_add_if( char * input_A, char * input_B,
-                      char * input_S4, int carry4,
-                      char * input_S8, int carry8 )
+                      char * input_S, int carry )
 {
-    mbedtls_mpi S4, S8, A, B;
-    mbedtls_mpi_uint *a = NULL; /* first value to add */
-    mbedtls_mpi_uint *b = NULL; /* second value to add */
-    mbedtls_mpi_uint *sum = NULL;
-    mbedtls_mpi_uint *d = NULL; /* destination - the in/out first operand */
+    mbedtls_mpi_uint *A = NULL; /* first value to add */
+    size_t A_limbs;
+    mbedtls_mpi_uint *B = NULL; /* second value to add */
+    size_t B_limbs;
+    mbedtls_mpi_uint *S = NULL; /* expected result */
+    size_t S_limbs;
+    mbedtls_mpi_uint *X = NULL; /* destination - the in/out first operand */
+    size_t X_limbs;
 
-    mbedtls_mpi_init( &A );
-    mbedtls_mpi_init( &B );
-    mbedtls_mpi_init( &S4 );
-    mbedtls_mpi_init( &S8 );
+    TEST_EQUAL( 0, mbedtls_test_read_mpi_core( &A, &A_limbs, input_A ) );
+    TEST_EQUAL( 0, mbedtls_test_read_mpi_core( &B, &B_limbs, input_B ) );
+    TEST_EQUAL( 0, mbedtls_test_read_mpi_core( &S, &S_limbs, input_S ) );
+    X_limbs = S_limbs;
+    ASSERT_ALLOC( X, X_limbs );
 
-    TEST_EQUAL( 0, mbedtls_test_read_mpi( &A, input_A ) );
-    TEST_EQUAL( 0, mbedtls_test_read_mpi( &B, input_B ) );
-    TEST_EQUAL( 0, mbedtls_test_read_mpi( &S4, input_S4 ) );
-    TEST_EQUAL( 0, mbedtls_test_read_mpi( &S8, input_S8 ) );
+    /* add_if expects all operands to be the same length */
+    TEST_EQUAL( A_limbs, B_limbs );
+    TEST_EQUAL( A_limbs, S_limbs );
+    size_t limbs = A_limbs;
+    size_t bytes = limbs * sizeof( *A );
 
-    /* We only need to work with one of (S4, carry4) or (S8, carry8) depending
-     * on sizeof(mbedtls_mpi_uint)
-     */
-    mbedtls_mpi *X = ( sizeof(mbedtls_mpi_uint) == 4 ) ? &S4 : &S8;
-    mbedtls_mpi_uint carry = ( sizeof(mbedtls_mpi_uint) == 4 ) ? carry4 : carry8;
+    /* The test cases have A <= B to avoid repetition, so we test A + B then,
+     * if A != B, B + A. If A == B, we can test when A and B are aliased */
 
-    /* All of the inputs are +ve (or zero) */
-    TEST_EQUAL( 1, A.s );
-    TEST_EQUAL( 1, B.s );
-    TEST_EQUAL( 1, X->s );
+    /* A + B */
 
-    /* Test cases are such that A <= B, so #limbs should be <= */
-    TEST_LE_U( A.n, B.n );
-    TEST_LE_U( X->n, B.n );
-
-    /* Now let's get arrays of mbedtls_mpi_uints, rather than MPI structures */
-
-    /* mbedtls_mpi_core_add_if() uses input arrays of mbedtls_mpi_uints which
-     * must be the same size. The MPIs we've read in will only have arrays
-     * large enough for the number they represent. Therefore we create new
-     * raw arrays of mbedtls_mpi_uints and populate them from the MPIs we've
-     * just read in.
-     *
-     * We generated test data such that B was always >= A, so that's how many
-     * limbs each of these need.
-     */
-    size_t limbs = B.n;
-    size_t bytes = limbs * sizeof(mbedtls_mpi_uint);
-
-    /* ASSERT_ALLOC() uses calloc() under the hood, so these do get zeroed */
-    ASSERT_ALLOC( a, bytes );
-    ASSERT_ALLOC( b, bytes );
-    ASSERT_ALLOC( sum, bytes );
-    ASSERT_ALLOC( d, bytes );
-
-    /* Populate the arrays. As the mbedtls_mpi_uint[]s in mbedtls_mpis (and as
-     * processed by mbedtls_mpi_core_add_if()) are little endian, we can just
-     * copy what we have as long as MSBs are 0 (which they are from ASSERT_ALLOC())
-     */
-    memcpy( a, A.p, A.n * sizeof(mbedtls_mpi_uint) );
-    memcpy( b, B.p, B.n * sizeof(mbedtls_mpi_uint) );
-    memcpy( sum, X->p, X->n * sizeof(mbedtls_mpi_uint) );
-
-    /* The test cases have a <= b to avoid repetition, so we test a + b then,
-     * if a != b, b + a. If a == b, we can test when a and b are aliased */
-
-    /* a + b */
-
-    /* cond = 0 => d unchanged, no carry */
-    memcpy( d, a, bytes );
-    TEST_EQUAL( 0, mbedtls_mpi_core_add_if( d, b, limbs, 0 ) );
-    ASSERT_COMPARE( d, bytes, a, bytes );
+    /* cond = 0 => X unchanged, no carry */
+    memcpy( X, A, bytes );
+    TEST_EQUAL( 0, mbedtls_mpi_core_add_if( X, B, limbs, 0 ) );
+    ASSERT_COMPARE( X, bytes, A, bytes );
 
     /* cond = 1 => correct result and carry */
-    TEST_EQUAL( carry, mbedtls_mpi_core_add_if( d, b, limbs, 1 ) );
-    ASSERT_COMPARE( d, bytes, sum, bytes );
+    TEST_EQUAL( carry, mbedtls_mpi_core_add_if( X, B, limbs, 1 ) );
+    ASSERT_COMPARE( X, bytes, S, bytes );
 
-    if ( A.n == B.n && memcmp( A.p, B.p, bytes ) == 0 )
+    if ( memcmp( A, B, bytes ) == 0 )
     {
-        /* a == b, so test where a and b are aliased */
+        /* A == B, so test where A and B are aliased */
 
-        /* cond = 0 => d unchanged, no carry */
-        TEST_EQUAL( 0, mbedtls_mpi_core_add_if( b, b, limbs, 0 ) );
-        ASSERT_COMPARE( b, bytes, B.p, bytes );
+        /* cond = 0 => X unchanged, no carry */
+        memcpy( X, B, bytes );
+        TEST_EQUAL( 0, mbedtls_mpi_core_add_if( X, X, limbs, 0 ) );
+        ASSERT_COMPARE( X, bytes, B, bytes );
 
         /* cond = 1 => correct result and carry */
-        TEST_EQUAL( carry, mbedtls_mpi_core_add_if( b, b, limbs, 1 ) );
-        ASSERT_COMPARE( b, bytes, sum, bytes );
+        TEST_EQUAL( carry, mbedtls_mpi_core_add_if( X, X, limbs, 1 ) );
+        ASSERT_COMPARE( X, bytes, S, bytes );
     }
     else
     {
-        /* a != b, so test b + a */
+        /* A != B, so test B + A */
 
-        /* cond = 0 => d unchanged, no carry */
+        /* cond = 0 => X unchanged, no carry */
-        memcpy( d, b, bytes );
-        TEST_EQUAL( 0, mbedtls_mpi_core_add_if( d, a, limbs, 0 ) );
-        ASSERT_COMPARE( d, bytes, b, bytes );
+        memcpy( X, B, bytes );
+        TEST_EQUAL( 0, mbedtls_mpi_core_add_if( X, A, limbs, 0 ) );
+        ASSERT_COMPARE( X, bytes, B, bytes );
 
         /* cond = 1 => correct result and carry */
-        TEST_EQUAL( carry, mbedtls_mpi_core_add_if( d, a, limbs, 1 ) );
-        ASSERT_COMPARE( d, bytes, sum, bytes );
+        TEST_EQUAL( carry, mbedtls_mpi_core_add_if( X, A, limbs, 1 ) );
+        ASSERT_COMPARE( X, bytes, S, bytes );
     }
 
 exit:
-    mbedtls_free( a );
-    mbedtls_free( b );
-    mbedtls_free( sum );
-    mbedtls_free( d );
-
-    mbedtls_mpi_free( &S4 );
-    mbedtls_mpi_free( &S8 );
-    mbedtls_mpi_free( &A );
-    mbedtls_mpi_free( &B );
+    mbedtls_free( A );
+    mbedtls_free( B );
+    mbedtls_free( S );
+    mbedtls_free( X );
 }
 /* END_CASE */