Move InputsForTest to macro_collector.py

This is useful for generating PSA tests for more than just constant names.

Signed-off-by: Gilles Peskine <Gilles.Peskine@arm.com>
diff --git a/scripts/mbedtls_dev/macro_collector.py b/scripts/mbedtls_dev/macro_collector.py
index a2192ba..0c9a9a5 100644
--- a/scripts/mbedtls_dev/macro_collector.py
+++ b/scripts/mbedtls_dev/macro_collector.py
@@ -18,7 +18,55 @@
 
 import itertools
 import re
-from typing import Dict, Iterable, Iterator, List, Set
+from typing import Dict, Iterable, Iterator, List, Optional, Pattern, Set, Tuple, Union
+
+
+class ReadFileLineException(Exception):
+    def __init__(self, filename: str, line_number: Union[int, str]) -> None:
+        message = 'in {} at {}'.format(filename, line_number)
+        super(ReadFileLineException, self).__init__(message)
+        self.filename = filename
+        self.line_number = line_number
+
+
+class read_file_lines:
+    # Dear Pylint, conventionally, a context manager class name is lowercase.
+    # pylint: disable=invalid-name,too-few-public-methods
+    """Context manager to read a text file line by line.
+
+    ```
+    with read_file_lines(filename) as lines:
+        for line in lines:
+            process(line)
+    ```
+    is equivalent to
+    ```
+    with open(filename, 'r') as input_file:
+        for line in input_file:
+            process(line)
+    ```
+    except that if process(line) raises an exception, then the read_file_lines
+    snippet annotates the exception with the file name and line number.
+    """
+    def __init__(self, filename: str, binary: bool = False) -> None:
+        self.filename = filename
+        self.line_number = 'entry' #type: Union[int, str]
+        self.generator = None #type: Optional[Iterable[Tuple[int, str]]]
+        self.binary = binary
+    def __enter__(self) -> 'read_file_lines':
+        self.generator = enumerate(open(self.filename,
+                                        'rb' if self.binary else 'r'))
+        return self
+    def __iter__(self) -> Iterator[str]:
+        assert self.generator is not None
+        for line_number, content in self.generator:
+            self.line_number = line_number
+            yield content
+        self.line_number = 'exit'
+    def __exit__(self, exc_type, exc_value, exc_traceback) -> None:
+        if exc_type is not None:
+            raise ReadFileLineException(self.filename, self.line_number) \
+                from exc_value
 
 
 class PSAMacroEnumerator:
@@ -251,3 +299,168 @@
                 m = re.search(self._continued_line_re, line)
             line = re.sub(self._nonascii_re, rb'', line).decode('ascii')
             self.read_line(line)
+
+
+class InputsForTest(PSAMacroEnumerator):
+    # pylint: disable=too-many-instance-attributes
+    """Accumulate information about macros to test.
+
+    This includes macro names as well as information about their arguments
+    when applicable.
+    """
+
+    def __init__(self) -> None:
+        super().__init__()
+        self.all_declared = set() #type: Set[str]
+        # Sets of names per type
+        self.statuses.add('PSA_SUCCESS')
+        self.algorithms.add('0xffffffff')
+        self.ecc_curves.add('0xff')
+        self.dh_groups.add('0xff')
+        self.key_types.add('0xffff')
+        self.key_usage_flags.add('0x80000000')
+
+        # Hard-coded values for unknown algorithms
+        #
+        # These have to have values that are correct for their respective
+        # PSA_ALG_IS_xxx macros, but are also not currently assigned and are
+        # not likely to be assigned in the near future.
+        self.hash_algorithms.add('0x020000fe') # 0x020000ff is PSA_ALG_ANY_HASH
+        self.mac_algorithms.add('0x03007fff')
+        self.ka_algorithms.add('0x09fc0000')
+        self.kdf_algorithms.add('0x080000ff')
+        # For AEAD algorithms, the only variability is over the tag length,
+        # and this only applies to known algorithms, so don't test an
+        # unknown algorithm.
+
+        # Identifier prefixes
+        self.table_by_prefix = {
+            'ERROR': self.statuses,
+            'ALG': self.algorithms,
+            'ECC_CURVE': self.ecc_curves,
+            'DH_GROUP': self.dh_groups,
+            'KEY_TYPE': self.key_types,
+            'KEY_USAGE': self.key_usage_flags,
+        } #type: Dict[str, Set[str]]
+        # Test functions
+        self.table_by_test_function = {
+            # Any function ending in _algorithm also gets added to
+            # self.algorithms.
+            'key_type': [self.key_types],
+            'block_cipher_key_type': [self.key_types],
+            'stream_cipher_key_type': [self.key_types],
+            'ecc_key_family': [self.ecc_curves],
+            'ecc_key_types': [self.ecc_curves],
+            'dh_key_family': [self.dh_groups],
+            'dh_key_types': [self.dh_groups],
+            'hash_algorithm': [self.hash_algorithms],
+            'mac_algorithm': [self.mac_algorithms],
+            'cipher_algorithm': [],
+            'hmac_algorithm': [self.mac_algorithms],
+            'aead_algorithm': [self.aead_algorithms],
+            'key_derivation_algorithm': [self.kdf_algorithms],
+            'key_agreement_algorithm': [self.ka_algorithms],
+            'asymmetric_signature_algorithm': [],
+            'asymmetric_signature_wildcard': [self.algorithms],
+            'asymmetric_encryption_algorithm': [],
+            'other_algorithm': [],
+        } #type: Dict[str, List[Set[str]]]
+        self.arguments_for['mac_length'] += ['1', '63']
+        self.arguments_for['min_mac_length'] += ['1', '63']
+        self.arguments_for['tag_length'] += ['1', '63']
+        self.arguments_for['min_tag_length'] += ['1', '63']
+
+    def get_names(self, type_word: str) -> Set[str]:
+        """Return the set of known names of values of the given type."""
+        return {
+            'status': self.statuses,
+            'algorithm': self.algorithms,
+            'ecc_curve': self.ecc_curves,
+            'dh_group': self.dh_groups,
+            'key_type': self.key_types,
+            'key_usage': self.key_usage_flags,
+        }[type_word]
+
+    # Regex for interesting header lines.
+    # Groups: 1=macro name, 2=type, 3=argument list (optional).
+    _header_line_re = \
+        re.compile(r'#define +' +
+                   r'(PSA_((?:(?:DH|ECC|KEY)_)?[A-Z]+)_\w+)' +
+                   r'(?:\(([^\n()]*)\))?')
+    # Regex of macro names to exclude.
+    _excluded_name_re = re.compile(r'_(?:GET|IS|OF)_|_(?:BASE|FLAG|MASK)\Z')
+    # Additional excluded macros.
+    _excluded_names = set([
+        # Macros that provide an alternative way to build the same
+        # algorithm as another macro.
+        'PSA_ALG_AEAD_WITH_DEFAULT_LENGTH_TAG',
+        'PSA_ALG_FULL_LENGTH_MAC',
+        # Auxiliary macro whose name doesn't fit the usual patterns for
+        # auxiliary macros.
+        'PSA_ALG_AEAD_WITH_DEFAULT_LENGTH_TAG_CASE',
+    ])
+    def parse_header_line(self, line: str) -> None:
+        """Parse a C header line, looking for "#define PSA_xxx"."""
+        m = re.match(self._header_line_re, line)
+        if not m:
+            return
+        name = m.group(1)
+        self.all_declared.add(name)
+        if re.search(self._excluded_name_re, name) or \
+           name in self._excluded_names:
+            return
+        dest = self.table_by_prefix.get(m.group(2))
+        if dest is None:
+            return
+        dest.add(name)
+        if m.group(3):
+            self.argspecs[name] = self._argument_split(m.group(3))
+
+    _nonascii_re = re.compile(rb'[^\x00-\x7f]+') #type: Pattern
+    def parse_header(self, filename: str) -> None:
+        """Parse a C header file, looking for "#define PSA_xxx"."""
+        with read_file_lines(filename, binary=True) as lines:
+            for line in lines:
+                line = re.sub(self._nonascii_re, rb'', line).decode('ascii')
+                self.parse_header_line(line)
+
+    _macro_identifier_re = re.compile(r'[A-Z]\w+')
+    def generate_undeclared_names(self, expr: str) -> Iterable[str]:
+        for name in re.findall(self._macro_identifier_re, expr):
+            if name not in self.all_declared:
+                yield name
+
+    def accept_test_case_line(self, function: str, argument: str) -> bool:
+        #pylint: disable=unused-argument
+        undeclared = list(self.generate_undeclared_names(argument))
+        if undeclared:
+            raise Exception('Undeclared names in test case', undeclared)
+        return True
+
+    def add_test_case_line(self, function: str, argument: str) -> None:
+        """Parse a test case data line, looking for algorithm metadata tests."""
+        sets = []
+        if function.endswith('_algorithm'):
+            sets.append(self.algorithms)
+            if function == 'key_agreement_algorithm' and \
+               argument.startswith('PSA_ALG_KEY_AGREEMENT('):
+                # We only want *raw* key agreement algorithms as such, so
+                # exclude ones that are already chained with a KDF.
+                # Keep the expression as one to test as an algorithm.
+                function = 'other_algorithm'
+        sets += self.table_by_test_function[function]
+        if self.accept_test_case_line(function, argument):
+            for s in sets:
+                s.add(argument)
+
+    # Regex matching a *.data line containing a test function call and
+    # its arguments. The actual definition is partly positional, but this
+    # regex is good enough in practice.
+    _test_case_line_re = re.compile(r'(?!depends_on:)(\w+):([^\n :][^:\n]*)')
+    def parse_test_cases(self, filename: str) -> None:
+        """Parse a test case file (*.data), looking for algorithm metadata tests."""
+        with read_file_lines(filename) as lines:
+            for line in lines:
+                m = re.match(self._test_case_line_re, line)
+                if m:
+                    self.add_test_case_line(m.group(1), m.group(2))