Split generate_tests to reduce code complexity
The previous implementation mixed test case generation with the
recursive generation calls. A separate method is added to
generate test cases for the current class's test function. This reduces
the need for subclasses to override generate_tests().
Signed-off-by: Werner Lewis <werner.lewis@arm.com>
diff --git a/scripts/mbedtls_dev/test_generation.py b/scripts/mbedtls_dev/test_generation.py
index b825df0..aeb551d 100644
--- a/scripts/mbedtls_dev/test_generation.py
+++ b/scripts/mbedtls_dev/test_generation.py
@@ -25,7 +25,7 @@
import re
from abc import ABCMeta, abstractmethod
-from typing import Callable, Dict, Iterable, List, Type, TypeVar
+from typing import Callable, Dict, Iterable, Iterator, List, Type, TypeVar
from mbedtls_dev import build_tree
from mbedtls_dev import test_case
@@ -91,16 +91,31 @@
return tc
@classmethod
- def generate_tests(cls):
- """Generate test cases for the target subclasses.
+ @abstractmethod
+ def generate_function_tests(cls) -> Iterator[test_case.TestCase]:
+ """Generate test cases for the test function.
- During generation, each class will iterate over any subclasses, calling
- this method in each.
- In abstract classes, no tests will be generated, as there is no
- function to generate tests for.
- In classes which do implement a test function, this should be overridden
- and a means to use `create_test_case()` should be added.
+ This will be called in classes where `test_function` is set.
+ Implementations should yield TestCase objects, by creating instances
+ of the class with appropriate input data, and then calling
+ `create_test_case()` on each.
"""
+ raise NotImplementedError
+
+ @classmethod
+ def generate_tests(cls) -> Iterator[test_case.TestCase]:
+ """Generate test cases for the class and its subclasses.
+
+ In classes with `test_function` set, `generate_function_tests()` is
+ called to generate test cases first.
+ In all classes, this method will iterate over its subclasses, and
+ yield from `generate_tests()` in each.
+
+ Calling this method on a class X will yield test cases from X and all
+ classes derived from X.
+ """
+ if cls.test_function:
+ yield from cls.generate_function_tests()
for subclass in sorted(cls.__subclasses__(), key=lambda c: c.__name__):
yield from subclass.generate_tests()