Bignum test: remove type restriction
The type of the special case list depends on the arity and the subclass. Remove
the type restriction to make defining special case lists more flexible and natural.
Signed-off-by: Janos Follath <janos.follath@arm.com>
diff --git a/scripts/mbedtls_dev/bignum_common.py b/scripts/mbedtls_dev/bignum_common.py
index 7d7170d..ed321d7 100644
--- a/scripts/mbedtls_dev/bignum_common.py
+++ b/scripts/mbedtls_dev/bignum_common.py
@@ -15,7 +15,8 @@
# limitations under the License.
from abc import abstractmethod
-from typing import Iterator, List, Tuple, TypeVar
+from typing import Iterator, List, Tuple, TypeVar, Any
+from itertools import chain
from . import test_case
from . import test_data_generation
@@ -90,7 +91,7 @@
"""
symbol = ""
input_values = [] # type: List[str]
- input_cases = [] # type: List[Tuple[str, str]]
+ input_cases = [] # type: List[Any]
unique_combinations_only = True
input_styles = ["variable", "arch_split"] # type: List[str]
input_style = "variable" # type: str
@@ -200,7 +201,6 @@
for a in cls.input_values
for b in cls.input_values
)
- yield from cls.input_cases
@classmethod
def generate_function_tests(cls) -> Iterator[test_case.TestCase]:
@@ -212,14 +212,20 @@
test_objects = (cls(a, b, bits_in_limb=bil)
for a, b in cls.get_value_pairs()
for bil in cls.limb_sizes)
+ special_cases = (cls(*args, bits_in_limb=bil) # type: ignore
+ for args in cls.input_cases
+ for bil in cls.limb_sizes)
else:
test_objects = (cls(a, b)
for a, b in cls.get_value_pairs())
+ special_cases = (cls(*args) for args in cls.input_cases)
yield from (valid_test_object.create_test_case()
for valid_test_object in filter(
lambda test_object: test_object.is_valid,
- test_objects
- ))
+ chain(test_objects, special_cases)
+ )
+ )
+
class ModOperationCommon(OperationCommon):
diff --git a/scripts/mbedtls_dev/bignum_core.py b/scripts/mbedtls_dev/bignum_core.py
index 48390b9..1bfc652 100644
--- a/scripts/mbedtls_dev/bignum_core.py
+++ b/scripts/mbedtls_dev/bignum_core.py
@@ -244,6 +244,16 @@
]
@classmethod
+ def get_value_pairs(cls) -> Iterator[Tuple[str, str]]:
+ """Generator to yield pairs of inputs.
+
+ Combinations are first generated from all input values, and then
+ specific cases provided.
+ """
+ yield from super().get_value_pairs()
+ yield from cls.input_cases
+
+ @classmethod
def generate_function_tests(cls) -> Iterator[test_case.TestCase]:
"""Override for additional scalar input."""
for a_value, b_value in cls.get_value_pairs():