Move InputsForTest to macro_collector.py

This is useful to generate PSA tests for more than constant names.

Signed-off-by: Gilles Peskine <Gilles.Peskine@arm.com>
diff --git a/scripts/mbedtls_dev/macro_collector.py b/scripts/mbedtls_dev/macro_collector.py
index a2192ba..0c9a9a5 100644
--- a/scripts/mbedtls_dev/macro_collector.py
+++ b/scripts/mbedtls_dev/macro_collector.py
@@ -18,7 +18,55 @@
 
 import itertools
 import re
-from typing import Dict, Iterable, Iterator, List, Set
+from typing import Dict, Iterable, Iterator, List, Optional, Pattern, Set, Tuple, Union
+
+
+class ReadFileLineException(Exception):
+    def __init__(self, filename: str, line_number: Union[int, str]) -> None:
+        message = 'in {} at {}'.format(filename, line_number)
+        super(ReadFileLineException, self).__init__(message)
+        self.filename = filename
+        self.line_number = line_number
+
+
+class read_file_lines:
+    # Dear Pylint, conventionally, a context manager class name is lowercase.
+    # pylint: disable=invalid-name,too-few-public-methods
+    """Context manager to read a text file line by line.
+
+    ```
+    with read_file_lines(filename) as lines:
+        for line in lines:
+            process(line)
+    ```
+    is equivalent to
+    ```
+    with open(filename, 'r') as input_file:
+        for line in input_file:
+            process(line)
+    ```
+    except that if process(line) raises an exception, then the read_file_lines
+    snippet annotates the exception with the file name and line number.
+    """
+    def __init__(self, filename: str, binary: bool = False) -> None:
+        self.filename = filename
+        self.line_number = 'entry' #type: Union[int, str]
+        self.generator = None #type: Optional[Iterable[Tuple[int, str]]]
+        self.binary = binary
+    def __enter__(self) -> 'read_file_lines':
+        self.generator = enumerate(open(self.filename,
+                                        'rb' if self.binary else 'r'))
+        return self
+    def __iter__(self) -> Iterator[str]:
+        assert self.generator is not None
+        for line_number, content in self.generator:
+            self.line_number = line_number
+            yield content
+        self.line_number = 'exit'
+    def __exit__(self, exc_type, exc_value, exc_traceback) -> None:
+        if exc_type is not None:
+            raise ReadFileLineException(self.filename, self.line_number) \
+                from exc_value
 
 
 class PSAMacroEnumerator:
@@ -251,3 +299,168 @@
                 m = re.search(self._continued_line_re, line)
             line = re.sub(self._nonascii_re, rb'', line).decode('ascii')
             self.read_line(line)
+
+
+class InputsForTest(PSAMacroEnumerator):
+    # pylint: disable=too-many-instance-attributes
+    """Accumulate information about macros to test.
+enumerate
+    This includes macro names as well as information about their arguments
+    when applicable.
+    """
+
+    def __init__(self) -> None:
+        super().__init__()
+        self.all_declared = set() #type: Set[str]
+        # Sets of names per type
+        self.statuses.add('PSA_SUCCESS')
+        self.algorithms.add('0xffffffff')
+        self.ecc_curves.add('0xff')
+        self.dh_groups.add('0xff')
+        self.key_types.add('0xffff')
+        self.key_usage_flags.add('0x80000000')
+
+        # Hard-coded values for unknown algorithms
+        #
+        # These have to have values that are correct for their respective
+        # PSA_ALG_IS_xxx macros, but are also not currently assigned and are
+        # not likely to be assigned in the near future.
+        self.hash_algorithms.add('0x020000fe') # 0x020000ff is PSA_ALG_ANY_HASH
+        self.mac_algorithms.add('0x03007fff')
+        self.ka_algorithms.add('0x09fc0000')
+        self.kdf_algorithms.add('0x080000ff')
+        # For AEAD algorithms, the only variability is over the tag length,
+        # and this only applies to known algorithms, so don't test an
+        # unknown algorithm.
+
+        # Identifier prefixes
+        self.table_by_prefix = {
+            'ERROR': self.statuses,
+            'ALG': self.algorithms,
+            'ECC_CURVE': self.ecc_curves,
+            'DH_GROUP': self.dh_groups,
+            'KEY_TYPE': self.key_types,
+            'KEY_USAGE': self.key_usage_flags,
+        } #type: Dict[str, Set[str]]
+        # Test functions
+        self.table_by_test_function = {
+            # Any function ending in _algorithm also gets added to
+            # self.algorithms.
+            'key_type': [self.key_types],
+            'block_cipher_key_type': [self.key_types],
+            'stream_cipher_key_type': [self.key_types],
+            'ecc_key_family': [self.ecc_curves],
+            'ecc_key_types': [self.ecc_curves],
+            'dh_key_family': [self.dh_groups],
+            'dh_key_types': [self.dh_groups],
+            'hash_algorithm': [self.hash_algorithms],
+            'mac_algorithm': [self.mac_algorithms],
+            'cipher_algorithm': [],
+            'hmac_algorithm': [self.mac_algorithms],
+            'aead_algorithm': [self.aead_algorithms],
+            'key_derivation_algorithm': [self.kdf_algorithms],
+            'key_agreement_algorithm': [self.ka_algorithms],
+            'asymmetric_signature_algorithm': [],
+            'asymmetric_signature_wildcard': [self.algorithms],
+            'asymmetric_encryption_algorithm': [],
+            'other_algorithm': [],
+        } #type: Dict[str, List[Set[str]]]
+        self.arguments_for['mac_length'] += ['1', '63']
+        self.arguments_for['min_mac_length'] += ['1', '63']
+        self.arguments_for['tag_length'] += ['1', '63']
+        self.arguments_for['min_tag_length'] += ['1', '63']
+
+    def get_names(self, type_word: str) -> Set[str]:
+        """Return the set of known names of values of the given type."""
+        return {
+            'status': self.statuses,
+            'algorithm': self.algorithms,
+            'ecc_curve': self.ecc_curves,
+            'dh_group': self.dh_groups,
+            'key_type': self.key_types,
+            'key_usage': self.key_usage_flags,
+        }[type_word]
+
+    # Regex for interesting header lines.
+    # Groups: 1=macro name, 2=type, 3=argument list (optional).
+    _header_line_re = \
+        re.compile(r'#define +' +
+                   r'(PSA_((?:(?:DH|ECC|KEY)_)?[A-Z]+)_\w+)' +
+                   r'(?:\(([^\n()]*)\))?')
+    # Regex of macro names to exclude.
+    _excluded_name_re = re.compile(r'_(?:GET|IS|OF)_|_(?:BASE|FLAG|MASK)\Z')
+    # Additional excluded macros.
+    _excluded_names = set([
+        # Macros that provide an alternative way to build the same
+        # algorithm as another macro.
+        'PSA_ALG_AEAD_WITH_DEFAULT_LENGTH_TAG',
+        'PSA_ALG_FULL_LENGTH_MAC',
+        # Auxiliary macro whose name doesn't fit the usual patterns for
+        # auxiliary macros.
+        'PSA_ALG_AEAD_WITH_DEFAULT_LENGTH_TAG_CASE',
+    ])
+    def parse_header_line(self, line: str) -> None:
+        """Parse a C header line, looking for "#define PSA_xxx"."""
+        m = re.match(self._header_line_re, line)
+        if not m:
+            return
+        name = m.group(1)
+        self.all_declared.add(name)
+        if re.search(self._excluded_name_re, name) or \
+           name in self._excluded_names:
+            return
+        dest = self.table_by_prefix.get(m.group(2))
+        if dest is None:
+            return
+        dest.add(name)
+        if m.group(3):
+            self.argspecs[name] = self._argument_split(m.group(3))
+
+    _nonascii_re = re.compile(rb'[^\x00-\x7f]+') #type: Pattern
+    def parse_header(self, filename: str) -> None:
+        """Parse a C header file, looking for "#define PSA_xxx"."""
+        with read_file_lines(filename, binary=True) as lines:
+            for line in lines:
+                line = re.sub(self._nonascii_re, rb'', line).decode('ascii')
+                self.parse_header_line(line)
+
+    _macro_identifier_re = re.compile(r'[A-Z]\w+')
+    def generate_undeclared_names(self, expr: str) -> Iterable[str]:
+        for name in re.findall(self._macro_identifier_re, expr):
+            if name not in self.all_declared:
+                yield name
+
+    def accept_test_case_line(self, function: str, argument: str) -> bool:
+        #pylint: disable=unused-argument
+        undeclared = list(self.generate_undeclared_names(argument))
+        if undeclared:
+            raise Exception('Undeclared names in test case', undeclared)
+        return True
+
+    def add_test_case_line(self, function: str, argument: str) -> None:
+        """Parse a test case data line, looking for algorithm metadata tests."""
+        sets = []
+        if function.endswith('_algorithm'):
+            sets.append(self.algorithms)
+            if function == 'key_agreement_algorithm' and \
+               argument.startswith('PSA_ALG_KEY_AGREEMENT('):
+                # We only want *raw* key agreement algorithms as such, so
+                # exclude ones that are already chained with a KDF.
+                # Keep the expression as one to test as an algorithm.
+                function = 'other_algorithm'
+        sets += self.table_by_test_function[function]
+        if self.accept_test_case_line(function, argument):
+            for s in sets:
+                s.add(argument)
+
+    # Regex matching a *.data line containing a test function call and
+    # its arguments. The actual definition is partly positional, but this
+    # regex is good enough in practice.
+    _test_case_line_re = re.compile(r'(?!depends_on:)(\w+):([^\n :][^:\n]*)')
+    def parse_test_cases(self, filename: str) -> None:
+        """Parse a test case file (*.data), looking for algorithm metadata tests."""
+        with read_file_lines(filename) as lines:
+            for line in lines:
+                m = re.match(self._test_case_line_re, line)
+                if m:
+                    self.add_test_case_line(m.group(1), m.group(2))
diff --git a/tests/scripts/test_psa_constant_names.py b/tests/scripts/test_psa_constant_names.py
index 237a344..c6f2305 100755
--- a/tests/scripts/test_psa_constant_names.py
+++ b/tests/scripts/test_psa_constant_names.py
@@ -28,223 +28,13 @@
 import re
 import subprocess
 import sys
-from typing import Dict, Iterable, Iterator, List, Optional, Pattern, Set, Tuple, Union
+from typing import Iterable, List, Optional, Tuple
 
 import scripts_path # pylint: disable=unused-import
 from mbedtls_dev import c_build_helper
-from mbedtls_dev.macro_collector import PSAMacroEnumerator
+from mbedtls_dev.macro_collector import InputsForTest, PSAMacroEnumerator
 from mbedtls_dev import typing_util
 
-class ReadFileLineException(Exception):
-    def __init__(self, filename: str, line_number: Union[int, str]) -> None:
-        message = 'in {} at {}'.format(filename, line_number)
-        super(ReadFileLineException, self).__init__(message)
-        self.filename = filename
-        self.line_number = line_number
-
-class read_file_lines:
-    # Dear Pylint, conventionally, a context manager class name is lowercase.
-    # pylint: disable=invalid-name,too-few-public-methods
-    """Context manager to read a text file line by line.
-
-    ```
-    with read_file_lines(filename) as lines:
-        for line in lines:
-            process(line)
-    ```
-    is equivalent to
-    ```
-    with open(filename, 'r') as input_file:
-        for line in input_file:
-            process(line)
-    ```
-    except that if process(line) raises an exception, then the read_file_lines
-    snippet annotates the exception with the file name and line number.
-    """
-    def __init__(self, filename: str, binary: bool = False) -> None:
-        self.filename = filename
-        self.line_number = 'entry' #type: Union[int, str]
-        self.generator = None #type: Optional[Iterable[Tuple[int, str]]]
-        self.binary = binary
-    def __enter__(self) -> 'read_file_lines':
-        self.generator = enumerate(open(self.filename,
-                                        'rb' if self.binary else 'r'))
-        return self
-    def __iter__(self) -> Iterator[str]:
-        assert self.generator is not None
-        for line_number, content in self.generator:
-            self.line_number = line_number
-            yield content
-        self.line_number = 'exit'
-    def __exit__(self, exc_type, exc_value, exc_traceback) -> None:
-        if exc_type is not None:
-            raise ReadFileLineException(self.filename, self.line_number) \
-                from exc_value
-
-class InputsForTest(PSAMacroEnumerator):
-    # pylint: disable=too-many-instance-attributes
-    """Accumulate information about macros to test.
-enumerate
-    This includes macro names as well as information about their arguments
-    when applicable.
-    """
-
-    def __init__(self) -> None:
-        super().__init__()
-        self.all_declared = set() #type: Set[str]
-        # Sets of names per type
-        self.statuses.add('PSA_SUCCESS')
-        self.algorithms.add('0xffffffff')
-        self.ecc_curves.add('0xff')
-        self.dh_groups.add('0xff')
-        self.key_types.add('0xffff')
-        self.key_usage_flags.add('0x80000000')
-
-        # Hard-coded values for unknown algorithms
-        #
-        # These have to have values that are correct for their respective
-        # PSA_ALG_IS_xxx macros, but are also not currently assigned and are
-        # not likely to be assigned in the near future.
-        self.hash_algorithms.add('0x020000fe') # 0x020000ff is PSA_ALG_ANY_HASH
-        self.mac_algorithms.add('0x03007fff')
-        self.ka_algorithms.add('0x09fc0000')
-        self.kdf_algorithms.add('0x080000ff')
-        # For AEAD algorithms, the only variability is over the tag length,
-        # and this only applies to known algorithms, so don't test an
-        # unknown algorithm.
-
-        # Identifier prefixes
-        self.table_by_prefix = {
-            'ERROR': self.statuses,
-            'ALG': self.algorithms,
-            'ECC_CURVE': self.ecc_curves,
-            'DH_GROUP': self.dh_groups,
-            'KEY_TYPE': self.key_types,
-            'KEY_USAGE': self.key_usage_flags,
-        } #type: Dict[str, Set[str]]
-        # Test functions
-        self.table_by_test_function = {
-            # Any function ending in _algorithm also gets added to
-            # self.algorithms.
-            'key_type': [self.key_types],
-            'block_cipher_key_type': [self.key_types],
-            'stream_cipher_key_type': [self.key_types],
-            'ecc_key_family': [self.ecc_curves],
-            'ecc_key_types': [self.ecc_curves],
-            'dh_key_family': [self.dh_groups],
-            'dh_key_types': [self.dh_groups],
-            'hash_algorithm': [self.hash_algorithms],
-            'mac_algorithm': [self.mac_algorithms],
-            'cipher_algorithm': [],
-            'hmac_algorithm': [self.mac_algorithms],
-            'aead_algorithm': [self.aead_algorithms],
-            'key_derivation_algorithm': [self.kdf_algorithms],
-            'key_agreement_algorithm': [self.ka_algorithms],
-            'asymmetric_signature_algorithm': [],
-            'asymmetric_signature_wildcard': [self.algorithms],
-            'asymmetric_encryption_algorithm': [],
-            'other_algorithm': [],
-        } #type: Dict[str, List[Set[str]]]
-        self.arguments_for['mac_length'] += ['1', '63']
-        self.arguments_for['min_mac_length'] += ['1', '63']
-        self.arguments_for['tag_length'] += ['1', '63']
-        self.arguments_for['min_tag_length'] += ['1', '63']
-
-    def get_names(self, type_word: str) -> Set[str]:
-        """Return the set of known names of values of the given type."""
-        return {
-            'status': self.statuses,
-            'algorithm': self.algorithms,
-            'ecc_curve': self.ecc_curves,
-            'dh_group': self.dh_groups,
-            'key_type': self.key_types,
-            'key_usage': self.key_usage_flags,
-        }[type_word]
-
-    # Regex for interesting header lines.
-    # Groups: 1=macro name, 2=type, 3=argument list (optional).
-    _header_line_re = \
-        re.compile(r'#define +' +
-                   r'(PSA_((?:(?:DH|ECC|KEY)_)?[A-Z]+)_\w+)' +
-                   r'(?:\(([^\n()]*)\))?')
-    # Regex of macro names to exclude.
-    _excluded_name_re = re.compile(r'_(?:GET|IS|OF)_|_(?:BASE|FLAG|MASK)\Z')
-    # Additional excluded macros.
-    _excluded_names = set([
-        # Macros that provide an alternative way to build the same
-        # algorithm as another macro.
-        'PSA_ALG_AEAD_WITH_DEFAULT_LENGTH_TAG',
-        'PSA_ALG_FULL_LENGTH_MAC',
-        # Auxiliary macro whose name doesn't fit the usual patterns for
-        # auxiliary macros.
-        'PSA_ALG_AEAD_WITH_DEFAULT_LENGTH_TAG_CASE',
-    ])
-    def parse_header_line(self, line: str) -> None:
-        """Parse a C header line, looking for "#define PSA_xxx"."""
-        m = re.match(self._header_line_re, line)
-        if not m:
-            return
-        name = m.group(1)
-        self.all_declared.add(name)
-        if re.search(self._excluded_name_re, name) or \
-           name in self._excluded_names:
-            return
-        dest = self.table_by_prefix.get(m.group(2))
-        if dest is None:
-            return
-        dest.add(name)
-        if m.group(3):
-            self.argspecs[name] = self._argument_split(m.group(3))
-
-    _nonascii_re = re.compile(rb'[^\x00-\x7f]+') #type: Pattern
-    def parse_header(self, filename: str) -> None:
-        """Parse a C header file, looking for "#define PSA_xxx"."""
-        with read_file_lines(filename, binary=True) as lines:
-            for line in lines:
-                line = re.sub(self._nonascii_re, rb'', line).decode('ascii')
-                self.parse_header_line(line)
-
-    _macro_identifier_re = re.compile(r'[A-Z]\w+')
-    def generate_undeclared_names(self, expr: str) -> Iterable[str]:
-        for name in re.findall(self._macro_identifier_re, expr):
-            if name not in self.all_declared:
-                yield name
-
-    def accept_test_case_line(self, function: str, argument: str) -> bool:
-        #pylint: disable=unused-argument
-        undeclared = list(self.generate_undeclared_names(argument))
-        if undeclared:
-            raise Exception('Undeclared names in test case', undeclared)
-        return True
-
-    def add_test_case_line(self, function: str, argument: str) -> None:
-        """Parse a test case data line, looking for algorithm metadata tests."""
-        sets = []
-        if function.endswith('_algorithm'):
-            sets.append(self.algorithms)
-            if function == 'key_agreement_algorithm' and \
-               argument.startswith('PSA_ALG_KEY_AGREEMENT('):
-                # We only want *raw* key agreement algorithms as such, so
-                # exclude ones that are already chained with a KDF.
-                # Keep the expression as one to test as an algorithm.
-                function = 'other_algorithm'
-        sets += self.table_by_test_function[function]
-        if self.accept_test_case_line(function, argument):
-            for s in sets:
-                s.add(argument)
-
-    # Regex matching a *.data line containing a test function call and
-    # its arguments. The actual definition is partly positional, but this
-    # regex is good enough in practice.
-    _test_case_line_re = re.compile(r'(?!depends_on:)(\w+):([^\n :][^:\n]*)')
-    def parse_test_cases(self, filename: str) -> None:
-        """Parse a test case file (*.data), looking for algorithm metadata tests."""
-        with read_file_lines(filename) as lines:
-            for line in lines:
-                m = re.match(self._test_case_line_re, line)
-                if m:
-                    self.add_test_case_line(m.group(1), m.group(2))
-
 def gather_inputs(headers: Iterable[str],
                   test_suites: Iterable[str],
                   inputs_class=InputsForTest) -> PSAMacroEnumerator: