Add type annotations

Prepare to move InputsForTest to macro_collector.py.

Signed-off-by: Gilles Peskine <Gilles.Peskine@arm.com>
diff --git a/tests/scripts/test_psa_constant_names.py b/tests/scripts/test_psa_constant_names.py
index 15b83d8..237a344 100755
--- a/tests/scripts/test_psa_constant_names.py
+++ b/tests/scripts/test_psa_constant_names.py
@@ -28,13 +28,15 @@
 import re
 import subprocess
 import sys
+from typing import Dict, Iterable, Iterator, List, Optional, Pattern, Set, Tuple, Union
 
 import scripts_path # pylint: disable=unused-import
 from mbedtls_dev import c_build_helper
-from mbedtls_dev import macro_collector
+from mbedtls_dev.macro_collector import PSAMacroEnumerator
+from mbedtls_dev import typing_util
 
 class ReadFileLineException(Exception):
-    def __init__(self, filename, line_number):
+    def __init__(self, filename: str, line_number: Union[int, str]) -> None:
         message = 'in {} at {}'.format(filename, line_number)
         super(ReadFileLineException, self).__init__(message)
         self.filename = filename
@@ -59,36 +61,37 @@
     except that if process(line) raises an exception, then the read_file_lines
     snippet annotates the exception with the file name and line number.
     """
-    def __init__(self, filename, binary=False):
+    def __init__(self, filename: str, binary: bool = False) -> None:
         self.filename = filename
-        self.line_number = 'entry'
-        self.generator = None
+        self.line_number = 'entry' #type: Union[int, str]
+        self.generator = None #type: Optional[Iterable[Tuple[int, str]]]
         self.binary = binary
-    def __enter__(self):
+    def __enter__(self) -> 'read_file_lines':
         self.generator = enumerate(open(self.filename,
                                         'rb' if self.binary else 'r'))
         return self
-    def __iter__(self):
+    def __iter__(self) -> Iterator[str]:
+        assert self.generator is not None
         for line_number, content in self.generator:
             self.line_number = line_number
             yield content
         self.line_number = 'exit'
-    def __exit__(self, exc_type, exc_value, exc_traceback):
+    def __exit__(self, exc_type, exc_value, exc_traceback) -> None:
         if exc_type is not None:
             raise ReadFileLineException(self.filename, self.line_number) \
                 from exc_value
 
-class InputsForTest(macro_collector.PSAMacroEnumerator):
+class InputsForTest(PSAMacroEnumerator):
     # pylint: disable=too-many-instance-attributes
     """Accumulate information about macros to test.
-
+
     This includes macro names as well as information about their arguments
     when applicable.
     """
 
-    def __init__(self):
+    def __init__(self) -> None:
         super().__init__()
-        self.all_declared = set()
+        self.all_declared = set() #type: Set[str]
         # Sets of names per type
         self.statuses.add('PSA_SUCCESS')
         self.algorithms.add('0xffffffff')
@@ -118,7 +121,7 @@
             'DH_GROUP': self.dh_groups,
             'KEY_TYPE': self.key_types,
             'KEY_USAGE': self.key_usage_flags,
-        }
+        } #type: Dict[str, Set[str]]
         # Test functions
         self.table_by_test_function = {
             # Any function ending in _algorithm also gets added to
@@ -141,13 +144,13 @@
             'asymmetric_signature_wildcard': [self.algorithms],
             'asymmetric_encryption_algorithm': [],
             'other_algorithm': [],
-        }
+        } #type: Dict[str, List[Set[str]]]
         self.arguments_for['mac_length'] += ['1', '63']
         self.arguments_for['min_mac_length'] += ['1', '63']
         self.arguments_for['tag_length'] += ['1', '63']
         self.arguments_for['min_tag_length'] += ['1', '63']
 
-    def get_names(self, type_word):
+    def get_names(self, type_word: str) -> Set[str]:
         """Return the set of known names of values of the given type."""
         return {
             'status': self.statuses,
@@ -176,7 +179,7 @@
         # auxiliary macros.
         'PSA_ALG_AEAD_WITH_DEFAULT_LENGTH_TAG_CASE',
     ])
-    def parse_header_line(self, line):
+    def parse_header_line(self, line: str) -> None:
         """Parse a C header line, looking for "#define PSA_xxx"."""
         m = re.match(self._header_line_re, line)
         if not m:
@@ -193,8 +196,8 @@
         if m.group(3):
             self.argspecs[name] = self._argument_split(m.group(3))
 
-    _nonascii_re = re.compile(rb'[^\x00-\x7f]+')
-    def parse_header(self, filename):
+    _nonascii_re = re.compile(rb'[^\x00-\x7f]+') #type: Pattern
+    def parse_header(self, filename: str) -> None:
         """Parse a C header file, looking for "#define PSA_xxx"."""
         with read_file_lines(filename, binary=True) as lines:
             for line in lines:
@@ -202,19 +205,19 @@
                 self.parse_header_line(line)
 
     _macro_identifier_re = re.compile(r'[A-Z]\w+')
-    def generate_undeclared_names(self, expr):
+    def generate_undeclared_names(self, expr: str) -> Iterable[str]:
         for name in re.findall(self._macro_identifier_re, expr):
             if name not in self.all_declared:
                 yield name
 
-    def accept_test_case_line(self, function, argument):
+    def accept_test_case_line(self, function: str, argument: str) -> bool:
         #pylint: disable=unused-argument
         undeclared = list(self.generate_undeclared_names(argument))
         if undeclared:
             raise Exception('Undeclared names in test case', undeclared)
         return True
 
-    def add_test_case_line(self, function, argument):
+    def add_test_case_line(self, function: str, argument: str) -> None:
         """Parse a test case data line, looking for algorithm metadata tests."""
         sets = []
         if function.endswith('_algorithm'):
@@ -234,7 +237,7 @@
     # its arguments. The actual definition is partly positional, but this
     # regex is good enough in practice.
     _test_case_line_re = re.compile(r'(?!depends_on:)(\w+):([^\n :][^:\n]*)')
-    def parse_test_cases(self, filename):
+    def parse_test_cases(self, filename: str) -> None:
         """Parse a test case file (*.data), looking for algorithm metadata tests."""
         with read_file_lines(filename) as lines:
             for line in lines:
@@ -242,7 +245,9 @@
                 if m:
                     self.add_test_case_line(m.group(1), m.group(2))
 
-def gather_inputs(headers, test_suites, inputs_class=InputsForTest):
+def gather_inputs(headers: Iterable[str],
+                  test_suites: Iterable[str],
+                  inputs_class=InputsForTest) -> PSAMacroEnumerator:
     """Read the list of inputs to test psa_constant_names with."""
     inputs = inputs_class()
     for header in headers:
@@ -252,7 +257,10 @@
     inputs.gather_arguments()
     return inputs
 
-def run_c(type_word, expressions, include_path=None, keep_c=False):
+def run_c(type_word: str,
+          expressions: Iterable[str],
+          include_path: Optional[str] = None,
+          keep_c: bool = False) -> List[str]:
     """Generate and run a program to print out numerical values of C expressions."""
     if type_word == 'status':
         cast_to = 'long'
@@ -271,14 +279,17 @@
     )
 
 NORMALIZE_STRIP_RE = re.compile(r'\s+')
-def normalize(expr):
+def normalize(expr: str) -> str:
     """Normalize the C expression so as not to care about trivial differences.
 
     Currently "trivial differences" means whitespace.
     """
     return re.sub(NORMALIZE_STRIP_RE, '', expr)
 
-def collect_values(inputs, type_word, include_path=None, keep_c=False):
+def collect_values(inputs: InputsForTest,
+                   type_word: str,
+                   include_path: Optional[str] = None,
+                   keep_c: bool = False) -> Tuple[List[str], List[str]]:
     """Generate expressions using known macro names and calculate their values.
 
     Return a list of pairs of (expr, value) where expr is an expression and
@@ -296,12 +307,12 @@
     Error = namedtuple('Error',
                        ['type', 'expression', 'value', 'output'])
 
-    def __init__(self, options):
+    def __init__(self, options) -> None:
         self.options = options
         self.count = 0
-        self.errors = []
+        self.errors = [] #type: List[Tests.Error]
 
-    def run_one(self, inputs, type_word):
+    def run_one(self, inputs: InputsForTest, type_word: str) -> None:
         """Test psa_constant_names for the specified type.
 
         Run the program on the names for this type.
@@ -311,9 +322,10 @@
         expressions, values = collect_values(inputs, type_word,
                                              include_path=self.options.include,
                                              keep_c=self.options.keep_c)
-        output = subprocess.check_output([self.options.program, type_word] +
-                                         values)
-        outputs = output.decode('ascii').strip().split('\n')
+        output_bytes = subprocess.check_output([self.options.program,
+                                                type_word] + values)
+        output = output_bytes.decode('ascii')
+        outputs = output.strip().split('\n')
         self.count += len(expressions)
         for expr, value, output in zip(expressions, values, outputs):
             if self.options.show:
@@ -324,13 +336,13 @@
                                               value=value,
                                               output=output))
 
-    def run_all(self, inputs):
+    def run_all(self, inputs: InputsForTest) -> None:
         """Run psa_constant_names on all the gathered inputs."""
         for type_word in ['status', 'algorithm', 'ecc_curve', 'dh_group',
                           'key_type', 'key_usage']:
             self.run_one(inputs, type_word)
 
-    def report(self, out):
+    def report(self, out: typing_util.Writable) -> None:
         """Describe each case where the output is not as expected.
 
         Write the errors to ``out``.