Move PSAMacroEnumerator to macro_collector
It's useful for more than test_psa_constant_names.
Signed-off-by: Gilles Peskine <Gilles.Peskine@arm.com>
diff --git a/scripts/mbedtls_dev/macro_collector.py b/scripts/mbedtls_dev/macro_collector.py
index 7ebd8f7..c9e6ec3 100644
--- a/scripts/mbedtls_dev/macro_collector.py
+++ b/scripts/mbedtls_dev/macro_collector.py
@@ -16,8 +16,115 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+import itertools
import re
-from typing import Dict, Set
+from typing import Dict, Iterable, Iterator, List, Set
+
+
class PSAMacroEnumerator:
    """Information about constructors of various PSA Crypto types.

    This includes macro names as well as information about their arguments
    when applicable.

    This class only provides ways to enumerate expressions that evaluate to
    values of the covered types. Derived classes are expected to populate
    the set of known constructors of each kind, as well as populate
    `self.arguments_for` for arguments that are not of a kind that is
    enumerated here.
    """

    def __init__(self) -> None:
        """Set up an empty set of known constructor macros.
        """
        self.statuses = set() #type: Set[str]
        self.algorithms = set() #type: Set[str]
        self.ecc_curves = set() #type: Set[str]
        self.dh_groups = set() #type: Set[str]
        self.key_types = set() #type: Set[str]
        self.key_usage_flags = set() #type: Set[str]
        self.hash_algorithms = set() #type: Set[str]
        self.mac_algorithms = set() #type: Set[str]
        self.ka_algorithms = set() #type: Set[str]
        self.kdf_algorithms = set() #type: Set[str]
        self.aead_algorithms = set() #type: Set[str]
        # macro name -> list of argument names
        self.argspecs = {} #type: Dict[str, List[str]]
        # argument name -> list of values
        # These lists start empty; derived classes are expected to fill them.
        self.arguments_for = {
            'mac_length': [],
            'min_mac_length': [],
            'tag_length': [],
            'min_tag_length': [],
        } #type: Dict[str, List[str]]

    def gather_arguments(self) -> None:
        """Populate the list of values for macro arguments.

        Call this after parsing all the inputs. Each argument kind that is
        enumerated by this class gets the sorted list of known constructors
        of that kind.
        """
        self.arguments_for['hash_alg'] = sorted(self.hash_algorithms)
        self.arguments_for['mac_alg'] = sorted(self.mac_algorithms)
        self.arguments_for['ka_alg'] = sorted(self.ka_algorithms)
        self.arguments_for['kdf_alg'] = sorted(self.kdf_algorithms)
        self.arguments_for['aead_alg'] = sorted(self.aead_algorithms)
        self.arguments_for['curve'] = sorted(self.ecc_curves)
        self.arguments_for['group'] = sorted(self.dh_groups)

    @staticmethod
    def _format_arguments(name: str, arguments: Iterable[str]) -> str:
        """Format a macro call with arguments."""
        return name + '(' + ', '.join(arguments) + ')'

    _argument_split_re = re.compile(r' *, *')
    @classmethod
    def _argument_split(cls, arguments: str) -> List[str]:
        """Split a comma-separated argument list, ignoring spaces around commas."""
        return cls._argument_split_re.split(arguments)

    def distribute_arguments(self, name: str) -> Iterator[str]:
        """Generate macro calls with each tested argument set.

        If name is a macro without arguments, just yield "name".
        If name is a macro with an empty argument list, yield "name()".
        If name is a macro with arguments, yield a series of
        "name(arg1, ..., argN)" where each argument takes each possible
        value at least once. The first call uses the first value of every
        argument; each subsequent call varies exactly one argument while the
        others keep their first value.

        Raises Exception (chained to the original error) if an argument name
        has no entry in `self.arguments_for` or its value list is empty.
        """
        try:
            if name not in self.argspecs:
                yield name
                return
            argspec = self.argspecs[name]
            if not argspec:
                yield name + '()'
                return
            argument_lists = [self.arguments_for[arg] for arg in argspec]
            arguments = [values[0] for values in argument_lists]
            yield self._format_arguments(name, arguments)
            for i, values in enumerate(argument_lists):
                for value in values[1:]:
                    arguments[i] = value
                    yield self._format_arguments(name, arguments)
                # Restore this position's own first value before varying the
                # next argument, so each call differs from the base in one slot.
                arguments[i] = values[0]
        # Catch Exception, not BaseException: GeneratorExit (raised when the
        # consumer closes the generator early) and KeyboardInterrupt must
        # propagate unchanged.
        except Exception as e:
            raise Exception('distribute_arguments({})'.format(name)) from e

    def generate_expressions(self, names: Iterable[str]) -> Iterator[str]:
        """Generate expressions covering values constructed from the given names.

        `names` can be any iterable collection of macro names.

        For example:
        * ``generate_expressions(['PSA_ALG_CMAC', 'PSA_ALG_HMAC'])``
          generates ``'PSA_ALG_CMAC'`` as well as ``'PSA_ALG_HMAC(h)'`` for
          every known hash algorithm ``h``.
        * ``macros.generate_expressions(macros.key_types)`` generates all
          key types.
        """
        return itertools.chain(*map(self.distribute_arguments, names))
+
class PSAMacroCollector:
"""Collect PSA crypto macro definitions from C header files.