Merge pull request #324 from gilles-peskine-arm/psa-test_psa_constant_names-refactor_and_ka

test_psa_constant_names: support key agreement, better code structure
diff --git a/tests/scripts/test_psa_constant_names.py b/tests/scripts/test_psa_constant_names.py
index 724f8d9..8931987 100755
--- a/tests/scripts/test_psa_constant_names.py
+++ b/tests/scripts/test_psa_constant_names.py
@@ -8,6 +8,7 @@
 """
 
 import argparse
+from collections import namedtuple
 import itertools
 import os
 import platform
@@ -60,12 +61,15 @@
                 from exc_value
 
 class Inputs:
+    # pylint: disable=too-many-instance-attributes
     """Accumulate information about macros to test.
+
     This includes macro names as well as information about their arguments
     when applicable.
     """
 
     def __init__(self):
+        self.all_declared = set()
         # Sets of names per type
         self.statuses = set(['PSA_SUCCESS'])
         self.algorithms = set(['0xffffffff'])
@@ -86,11 +90,30 @@
         self.table_by_prefix = {
             'ERROR': self.statuses,
             'ALG': self.algorithms,
-            'CURVE': self.ecc_curves,
-            'GROUP': self.dh_groups,
+            'ECC_CURVE': self.ecc_curves,
+            'DH_GROUP': self.dh_groups,
             'KEY_TYPE': self.key_types,
             'KEY_USAGE': self.key_usage_flags,
         }
+        # Test functions
+        self.table_by_test_function = {
+            # The argument of any function ending in _algorithm also gets
+            # added to self.algorithms.
+            'key_type': [self.key_types],
+            'ecc_key_types': [self.ecc_curves],
+            'dh_key_types': [self.dh_groups],
+            'hash_algorithm': [self.hash_algorithms],
+            'mac_algorithm': [self.mac_algorithms],
+            'cipher_algorithm': [],
+            'hmac_algorithm': [self.mac_algorithms],
+            'aead_algorithm': [self.aead_algorithms],
+            'key_derivation_algorithm': [self.kdf_algorithms],
+            'key_agreement_algorithm': [self.ka_algorithms],
+            'asymmetric_signature_algorithm': [],
+            'asymmetric_signature_wildcard': [self.algorithms],
+            'asymmetric_encryption_algorithm': [],
+            'other_algorithm': [],
+        }
         # macro name -> list of argument names
         self.argspecs = {}
         # argument name -> list of values
@@ -99,8 +122,20 @@
             'tag_length': ['1', '63'],
         }
 
+    def get_names(self, type_word):
+        """Return the set of known names of values of the given type."""
+        return {
+            'status': self.statuses,
+            'algorithm': self.algorithms,
+            'ecc_curve': self.ecc_curves,
+            'dh_group': self.dh_groups,
+            'key_type': self.key_types,
+            'key_usage': self.key_usage_flags,
+        }[type_word]
+
     def gather_arguments(self):
         """Populate the list of values for macro arguments.
+
         Call this after parsing all the inputs.
         """
         self.arguments_for['hash_alg'] = sorted(self.hash_algorithms)
@@ -118,6 +153,7 @@
 
     def distribute_arguments(self, name):
         """Generate macro calls with each tested argument set.
+
         If name is a macro without arguments, just yield "name".
         If name is a macro with arguments, yield a series of
         "name(arg1,...,argN)" where each argument takes each possible
@@ -145,6 +181,9 @@
         except BaseException as e:
             raise Exception('distribute_arguments({})'.format(name)) from e
 
+    def generate_expressions(self, names):
+        return itertools.chain(*map(self.distribute_arguments, names))
+
     _argument_split_re = re.compile(r' *, *')
     @classmethod
     def _argument_split(cls, arguments):
@@ -154,7 +193,7 @@
     # Groups: 1=macro name, 2=type, 3=argument list (optional).
     _header_line_re = \
         re.compile(r'#define +' +
-                   r'(PSA_((?:KEY_)?[A-Z]+)_\w+)' +
+                   r'(PSA_((?:(?:DH|ECC|KEY)_)?[A-Z]+)_\w+)' +
                    r'(?:\(([^\n()]*)\))?')
     # Regex of macro names to exclude.
     _excluded_name_re = re.compile(r'_(?:GET|IS|OF)_|_(?:BASE|FLAG|MASK)\Z')
@@ -167,10 +206,6 @@
         # Auxiliary macro whose name doesn't fit the usual patterns for
         # auxiliary macros.
         'PSA_ALG_AEAD_WITH_DEFAULT_TAG_LENGTH_CASE',
-        # PSA_ALG_ECDH and PSA_ALG_FFDH are excluded for now as the script
-        # currently doesn't support them.
-        'PSA_ALG_ECDH',
-        'PSA_ALG_FFDH',
         # Deprecated aliases.
         'PSA_ERROR_UNKNOWN_ERROR',
         'PSA_ERROR_OCCUPIED_SLOT',
@@ -184,6 +219,7 @@
         if not m:
             return
         name = m.group(1)
+        self.all_declared.add(name)
         if re.search(self._excluded_name_re, name) or \
            name in self._excluded_names:
             return
@@ -200,26 +236,34 @@
             for line in lines:
                 self.parse_header_line(line)
 
+    _macro_identifier_re = r'[A-Z]\w+'
+    def generate_undeclared_names(self, expr):
+        for name in re.findall(self._macro_identifier_re, expr):
+            if name not in self.all_declared:
+                yield name
+
+    def accept_test_case_line(self, function, argument):
+        # pylint: disable=unused-argument
+        undeclared = list(self.generate_undeclared_names(argument))
+        if undeclared:
+            raise Exception('Undeclared names in test case', undeclared)
+        return True
+
     def add_test_case_line(self, function, argument):
         """Parse a test case data line, looking for algorithm metadata tests."""
+        sets = []
         if function.endswith('_algorithm'):
-            # As above, ECDH and FFDH algorithms are excluded for now.
-            # Support for them will be added in the future.
-            if 'ECDH' in argument or 'FFDH' in argument:
-                return
-            self.algorithms.add(argument)
-            if function == 'hash_algorithm':
-                self.hash_algorithms.add(argument)
-            elif function in ['mac_algorithm', 'hmac_algorithm']:
-                self.mac_algorithms.add(argument)
-            elif function == 'aead_algorithm':
-                self.aead_algorithms.add(argument)
-        elif function == 'key_type':
-            self.key_types.add(argument)
-        elif function == 'ecc_key_types':
-            self.ecc_curves.add(argument)
-        elif function == 'dh_key_types':
-            self.dh_groups.add(argument)
+            sets.append(self.algorithms)
+            if function == 'key_agreement_algorithm' and \
+               argument.startswith('PSA_ALG_KEY_AGREEMENT('):
+                # Only *raw* key agreement algorithms belong in ka_algorithms,
+                # so exclude ones that are already chained with a KDF.
+                # The chained expression is still tested as a plain algorithm.
+                function = 'other_algorithm'
+        sets += self.table_by_test_function[function]
+        if self.accept_test_case_line(function, argument):
+            for s in sets:
+                s.add(argument)
 
     # Regex matching a *.data line containing a test function call and
     # its arguments. The actual definition is partly positional, but this
@@ -233,9 +277,9 @@
                 if m:
                     self.add_test_case_line(m.group(1), m.group(2))
 
-def gather_inputs(headers, test_suites):
+def gather_inputs(headers, test_suites, inputs_class=Inputs):
     """Read the list of inputs to test psa_constant_names with."""
-    inputs = Inputs()
+    inputs = inputs_class()
     for header in headers:
         inputs.parse_header(header)
     for test_cases in test_suites:
@@ -252,8 +296,10 @@
     except OSError:
         pass
 
-def run_c(options, type_word, names):
-    """Generate and run a program to print out numerical values for names."""
+def run_c(type_word, expressions, include_path=None, keep_c=False):
+    """Generate and run a program to print out numerical values for expressions."""
+    if include_path is None:
+        include_path = []
     if type_word == 'status':
         cast_to = 'long'
         printf_format = '%ld'
@@ -278,18 +324,18 @@
 int main(void)
 {
 ''')
-        for name in names:
+        for expr in expressions:
             c_file.write('    printf("{}\\n", ({}) {});\n'
-                         .format(printf_format, cast_to, name))
+                         .format(printf_format, cast_to, expr))
         c_file.write('''    return 0;
 }
 ''')
         c_file.close()
         cc = os.getenv('CC', 'cc')
         subprocess.check_call([cc] +
-                              ['-I' + dir for dir in options.include] +
+                              ['-I' + dir for dir in include_path] +
                               ['-o', exe_name, c_name])
-        if options.keep_c:
+        if keep_c:
             sys.stderr.write('List of {} tests kept at {}\n'
                              .format(type_word, c_name))
         else:
@@ -302,76 +348,101 @@
 NORMALIZE_STRIP_RE = re.compile(r'\s+')
 def normalize(expr):
     """Normalize the C expression so as not to care about trivial differences.
+
     Currently "trivial differences" means whitespace.
     """
-    expr = re.sub(NORMALIZE_STRIP_RE, '', expr, len(expr))
-    return expr.strip().split('\n')
+    return re.sub(NORMALIZE_STRIP_RE, '', expr)
 
-def do_test(options, inputs, type_word, names):
-    """Test psa_constant_names for the specified type.
-    Run program on names.
-    Use inputs to figure out what arguments to pass to macros that
-    take arguments.
+def collect_values(inputs, type_word, include_path=None, keep_c=False):
+    """Generate expressions using known macro names and calculate their values.
+
+    Return a pair (expressions, values): the sorted list of expressions and
+    the list of string representations of their integer values.
     """
-    names = sorted(itertools.chain(*map(inputs.distribute_arguments, names)))
-    values = run_c(options, type_word, names)
-    output = subprocess.check_output([options.program, type_word] + values)
-    outputs = output.decode('ascii').strip().split('\n')
-    errors = [(type_word, name, value, output)
-              for (name, value, output) in zip(names, values, outputs)
-              if normalize(name) != normalize(output)]
-    return len(names), errors
+    names = inputs.get_names(type_word)
+    expressions = sorted(inputs.generate_expressions(names))
+    values = run_c(type_word, expressions,
+                   include_path=include_path, keep_c=keep_c)
+    return expressions, values
 
-def report_errors(errors):
-    """Describe each case where the output is not as expected."""
-    for type_word, name, value, output in errors:
-        print('For {} "{}", got "{}" (value: {})'
-              .format(type_word, name, output, value))
+class Tests:
+    """An object representing tests and their results."""
 
-def run_tests(options, inputs):
-    """Run psa_constant_names on all the gathered inputs.
-    Return a tuple (count, errors) where count is the total number of inputs
-    that were tested and errors is the list of cases where the output was
-    not as expected.
-    """
-    count = 0
-    errors = []
-    for type_word, names in [('status', inputs.statuses),
-                             ('algorithm', inputs.algorithms),
-                             ('ecc_curve', inputs.ecc_curves),
-                             ('dh_group', inputs.dh_groups),
-                             ('key_type', inputs.key_types),
-                             ('key_usage', inputs.key_usage_flags)]:
-        c, e = do_test(options, inputs, type_word, names)
-        count += c
-        errors += e
-    return count, errors
+    Error = namedtuple('Error',
+                       ['type', 'expression', 'value', 'output'])
+
+    def __init__(self, options):
+        self.options = options
+        self.count = 0
+        self.errors = []
+
+    def run_one(self, inputs, type_word):
+        """Test psa_constant_names for the specified type.
+
+        Run the program on the names for this type.
+        Use the inputs to figure out what arguments to pass to macros that
+        take arguments.
+        """
+        expressions, values = collect_values(inputs, type_word,
+                                             include_path=self.options.include,
+                                             keep_c=self.options.keep_c)
+        output = subprocess.check_output([self.options.program, type_word] +
+                                         values)
+        outputs = output.decode('ascii').strip().split('\n')
+        self.count += len(expressions)
+        for expr, value, output in zip(expressions, values, outputs):
+            if normalize(expr) != normalize(output):
+                self.errors.append(self.Error(type=type_word,
+                                              expression=expr,
+                                              value=value,
+                                              output=output))
+
+    def run_all(self, inputs):
+        """Run psa_constant_names on all the gathered inputs."""
+        for type_word in ['status', 'algorithm', 'ecc_curve', 'dh_group',
+                          'key_type', 'key_usage']:
+            self.run_one(inputs, type_word)
+
+    def report(self, out):
+        """Describe each case where the output is not as expected.
+
+        Write the errors to ``out``.
+        Also write a total.
+        """
+        for error in self.errors:
+            out.write('For {} "{}", got "{}" (value: {})\n'
+                      .format(error.type, error.expression,
+                              error.output, error.value))
+        out.write('{} test cases'.format(self.count))
+        if self.errors:
+            out.write(', {} FAIL\n'.format(len(self.errors)))
+        else:
+            out.write(' PASS\n')
+
+HEADERS = ['psa/crypto.h', 'psa/crypto_extra.h', 'psa/crypto_values.h']
+TEST_SUITES = ['tests/suites/test_suite_psa_crypto_metadata.data']
 
 def main():
     parser = argparse.ArgumentParser(description=globals()['__doc__'])
     parser.add_argument('--include', '-I',
                         action='append', default=['include'],
                         help='Directory for header files')
-    parser.add_argument('--program',
-                        default='programs/psa/psa_constant_names',
-                        help='Program to test')
     parser.add_argument('--keep-c',
                         action='store_true', dest='keep_c', default=False,
                         help='Keep the intermediate C file')
     parser.add_argument('--no-keep-c',
                         action='store_false', dest='keep_c',
                         help='Don\'t keep the intermediate C file (default)')
+    parser.add_argument('--program',
+                        default='programs/psa/psa_constant_names',
+                        help='Program to test')
     options = parser.parse_args()
-    headers = [os.path.join(options.include[0], 'psa', h)
-               for h in ['crypto.h', 'crypto_extra.h', 'crypto_values.h']]
-    test_suites = ['tests/suites/test_suite_psa_crypto_metadata.data']
-    inputs = gather_inputs(headers, test_suites)
-    count, errors = run_tests(options, inputs)
-    report_errors(errors)
-    if errors == []:
-        print('{} test cases PASS'.format(count))
-    else:
-        print('{} test cases, {} FAIL'.format(count, len(errors)))
+    headers = [os.path.join(options.include[0], h) for h in HEADERS]
+    inputs = gather_inputs(headers, TEST_SUITES)
+    tests = Tests(options)
+    tests.run_all(inputs)
+    tests.report(sys.stdout)
+    if tests.errors:
         exit(1)
 
 if __name__ == '__main__':
diff --git a/tests/suites/test_suite_psa_crypto_metadata.data b/tests/suites/test_suite_psa_crypto_metadata.data
index e989895..9cdee03 100644
--- a/tests/suites/test_suite_psa_crypto_metadata.data
+++ b/tests/suites/test_suite_psa_crypto_metadata.data
@@ -262,6 +262,26 @@
 depends_on:MBEDTLS_SHA256_C
 key_derivation_algorithm:PSA_ALG_HKDF( PSA_ALG_SHA_256 ):ALG_IS_HKDF
 
+Key derivation: HKDF using SHA-384
+depends_on:MBEDTLS_SHA512_C
+key_derivation_algorithm:PSA_ALG_HKDF( PSA_ALG_SHA_384 ):ALG_IS_HKDF
+
+Key derivation: TLS 1.2 PRF using SHA-256
+depends_on:MBEDTLS_SHA256_C
+key_derivation_algorithm:PSA_ALG_TLS12_PRF( PSA_ALG_SHA_256 ):ALG_IS_TLS12_PRF
+
+Key derivation: TLS 1.2 PRF using SHA-384
+depends_on:MBEDTLS_SHA512_C
+key_derivation_algorithm:PSA_ALG_TLS12_PRF( PSA_ALG_SHA_384 ):ALG_IS_TLS12_PRF
+
+Key derivation: TLS 1.2 PSK-to-MS using SHA-256
+depends_on:MBEDTLS_SHA256_C
+key_derivation_algorithm:PSA_ALG_TLS12_PSK_TO_MS( PSA_ALG_SHA_256 ):ALG_IS_TLS12_PSK_TO_MS
+
+Key derivation: TLS 1.2 PSK-to-MS using SHA-384
+depends_on:MBEDTLS_SHA512_C
+key_derivation_algorithm:PSA_ALG_TLS12_PSK_TO_MS( PSA_ALG_SHA_384 ):ALG_IS_TLS12_PSK_TO_MS
+
 Key agreement: FFDH, raw output
 depends_on:MBEDTLS_DHM_C
 key_agreement_algorithm:PSA_ALG_FFDH:ALG_IS_FFDH | ALG_IS_RAW_KEY_AGREEMENT:PSA_ALG_FFDH:PSA_ALG_CATEGORY_KEY_DERIVATION
@@ -270,6 +290,10 @@
 depends_on:MBEDTLS_DHM_C
 key_agreement_algorithm:PSA_ALG_KEY_AGREEMENT( PSA_ALG_FFDH, PSA_ALG_HKDF( PSA_ALG_SHA_256 ) ):ALG_IS_FFDH:PSA_ALG_FFDH:PSA_ALG_HKDF( PSA_ALG_SHA_256 )
 
+Key agreement: FFDH, HKDF using SHA-384
+depends_on:MBEDTLS_DHM_C
+key_agreement_algorithm:PSA_ALG_KEY_AGREEMENT( PSA_ALG_FFDH, PSA_ALG_HKDF( PSA_ALG_SHA_384 ) ):ALG_IS_FFDH:PSA_ALG_FFDH:PSA_ALG_HKDF( PSA_ALG_SHA_384 )
+
 Key agreement: ECDH, raw output
 depends_on:MBEDTLS_ECDH_C
 key_agreement_algorithm:PSA_ALG_ECDH:ALG_IS_ECDH | ALG_IS_RAW_KEY_AGREEMENT:PSA_ALG_ECDH:PSA_ALG_CATEGORY_KEY_DERIVATION
@@ -278,6 +302,10 @@
 depends_on:MBEDTLS_ECDH_C
 key_agreement_algorithm:PSA_ALG_KEY_AGREEMENT( PSA_ALG_ECDH, PSA_ALG_HKDF( PSA_ALG_SHA_256 ) ):ALG_IS_ECDH:PSA_ALG_ECDH:PSA_ALG_HKDF( PSA_ALG_SHA_256 )
 
+Key agreement: ECDH, HKDF using SHA-384
+depends_on:MBEDTLS_ECDH_C
+key_agreement_algorithm:PSA_ALG_KEY_AGREEMENT( PSA_ALG_ECDH, PSA_ALG_HKDF( PSA_ALG_SHA_384 ) ):ALG_IS_ECDH:PSA_ALG_ECDH:PSA_ALG_HKDF( PSA_ALG_SHA_384 )
+
 Key type: raw data
 key_type:PSA_KEY_TYPE_RAW_DATA:KEY_TYPE_IS_UNSTRUCTURED
 
diff --git a/tests/suites/test_suite_psa_crypto_metadata.function b/tests/suites/test_suite_psa_crypto_metadata.function
index a9f1b39..3a9347e 100644
--- a/tests/suites/test_suite_psa_crypto_metadata.function
+++ b/tests/suites/test_suite_psa_crypto_metadata.function
@@ -37,6 +37,8 @@
 #define ALG_IS_WILDCARD                 ( 1u << 19 )
 #define ALG_IS_RAW_KEY_AGREEMENT        ( 1u << 20 )
 #define ALG_IS_AEAD_ON_BLOCK_CIPHER     ( 1u << 21 )
+#define ALG_IS_TLS12_PRF                ( 1u << 22 )
+#define ALG_IS_TLS12_PSK_TO_MS          ( 1u << 23 )
 
 /* Flags for key type classification macros. There is a flag for every
  * key type classification macro PSA_KEY_TYPE_IS_xxx except for some that