Merge pull request #324 from gilles-peskine-arm/psa-test_psa_constant_names-refactor_and_ka
test_psa_constant_names: support key agreement, better code structure
diff --git a/tests/scripts/test_psa_constant_names.py b/tests/scripts/test_psa_constant_names.py
index 724f8d9..8931987 100755
--- a/tests/scripts/test_psa_constant_names.py
+++ b/tests/scripts/test_psa_constant_names.py
@@ -8,6 +8,7 @@
"""
import argparse
+from collections import namedtuple
import itertools
import os
import platform
@@ -60,12 +61,15 @@
from exc_value
class Inputs:
+ # pylint: disable=too-many-instance-attributes
"""Accumulate information about macros to test.
+
This includes macro names as well as information about their arguments
when applicable.
"""
def __init__(self):
+ self.all_declared = set()
# Sets of names per type
self.statuses = set(['PSA_SUCCESS'])
self.algorithms = set(['0xffffffff'])
@@ -86,11 +90,30 @@
self.table_by_prefix = {
'ERROR': self.statuses,
'ALG': self.algorithms,
- 'CURVE': self.ecc_curves,
- 'GROUP': self.dh_groups,
+ 'ECC_CURVE': self.ecc_curves,
+ 'DH_GROUP': self.dh_groups,
'KEY_TYPE': self.key_types,
'KEY_USAGE': self.key_usage_flags,
}
+ # Test functions
+ self.table_by_test_function = {
+ # Any function ending in _algorithm also gets added to
+ # self.algorithms.
+ 'key_type': [self.key_types],
+ 'ecc_key_types': [self.ecc_curves],
+ 'dh_key_types': [self.dh_groups],
+ 'hash_algorithm': [self.hash_algorithms],
+ 'mac_algorithm': [self.mac_algorithms],
+ 'cipher_algorithm': [],
+ 'hmac_algorithm': [self.mac_algorithms],
+ 'aead_algorithm': [self.aead_algorithms],
+ 'key_derivation_algorithm': [self.kdf_algorithms],
+ 'key_agreement_algorithm': [self.ka_algorithms],
+ 'asymmetric_signature_algorithm': [],
+ 'asymmetric_signature_wildcard': [self.algorithms],
+ 'asymmetric_encryption_algorithm': [],
+ 'other_algorithm': [],
+ }
# macro name -> list of argument names
self.argspecs = {}
# argument name -> list of values
@@ -99,8 +122,20 @@
'tag_length': ['1', '63'],
}
+ def get_names(self, type_word):
+ """Return the set of known names of values of the given type."""
+ return {
+ 'status': self.statuses,
+ 'algorithm': self.algorithms,
+ 'ecc_curve': self.ecc_curves,
+ 'dh_group': self.dh_groups,
+ 'key_type': self.key_types,
+ 'key_usage': self.key_usage_flags,
+ }[type_word]
+
def gather_arguments(self):
"""Populate the list of values for macro arguments.
+
Call this after parsing all the inputs.
"""
self.arguments_for['hash_alg'] = sorted(self.hash_algorithms)
@@ -118,6 +153,7 @@
def distribute_arguments(self, name):
"""Generate macro calls with each tested argument set.
+
If name is a macro without arguments, just yield "name".
If name is a macro with arguments, yield a series of
"name(arg1,...,argN)" where each argument takes each possible
@@ -145,6 +181,9 @@
except BaseException as e:
raise Exception('distribute_arguments({})'.format(name)) from e
+ def generate_expressions(self, names):
+ return itertools.chain(*map(self.distribute_arguments, names))
+
_argument_split_re = re.compile(r' *, *')
@classmethod
def _argument_split(cls, arguments):
@@ -154,7 +193,7 @@
# Groups: 1=macro name, 2=type, 3=argument list (optional).
_header_line_re = \
re.compile(r'#define +' +
- r'(PSA_((?:KEY_)?[A-Z]+)_\w+)' +
+ r'(PSA_((?:(?:DH|ECC|KEY)_)?[A-Z]+)_\w+)' +
r'(?:\(([^\n()]*)\))?')
# Regex of macro names to exclude.
_excluded_name_re = re.compile(r'_(?:GET|IS|OF)_|_(?:BASE|FLAG|MASK)\Z')
@@ -167,10 +206,6 @@
# Auxiliary macro whose name doesn't fit the usual patterns for
# auxiliary macros.
'PSA_ALG_AEAD_WITH_DEFAULT_TAG_LENGTH_CASE',
- # PSA_ALG_ECDH and PSA_ALG_FFDH are excluded for now as the script
- # currently doesn't support them.
- 'PSA_ALG_ECDH',
- 'PSA_ALG_FFDH',
# Deprecated aliases.
'PSA_ERROR_UNKNOWN_ERROR',
'PSA_ERROR_OCCUPIED_SLOT',
@@ -184,6 +219,7 @@
if not m:
return
name = m.group(1)
+ self.all_declared.add(name)
if re.search(self._excluded_name_re, name) or \
name in self._excluded_names:
return
@@ -200,26 +236,34 @@
for line in lines:
self.parse_header_line(line)
+ _macro_identifier_re = r'[A-Z]\w+'
+ def generate_undeclared_names(self, expr):
+ for name in re.findall(self._macro_identifier_re, expr):
+ if name not in self.all_declared:
+ yield name
+
+ def accept_test_case_line(self, function, argument):
+ #pylint: disable=unused-argument
+ undeclared = list(self.generate_undeclared_names(argument))
+ if undeclared:
+ raise Exception('Undeclared names in test case', undeclared)
+ return True
+
def add_test_case_line(self, function, argument):
"""Parse a test case data line, looking for algorithm metadata tests."""
+ sets = []
if function.endswith('_algorithm'):
- # As above, ECDH and FFDH algorithms are excluded for now.
- # Support for them will be added in the future.
- if 'ECDH' in argument or 'FFDH' in argument:
- return
- self.algorithms.add(argument)
- if function == 'hash_algorithm':
- self.hash_algorithms.add(argument)
- elif function in ['mac_algorithm', 'hmac_algorithm']:
- self.mac_algorithms.add(argument)
- elif function == 'aead_algorithm':
- self.aead_algorithms.add(argument)
- elif function == 'key_type':
- self.key_types.add(argument)
- elif function == 'ecc_key_types':
- self.ecc_curves.add(argument)
- elif function == 'dh_key_types':
- self.dh_groups.add(argument)
+ sets.append(self.algorithms)
+ if function == 'key_agreement_algorithm' and \
+ argument.startswith('PSA_ALG_KEY_AGREEMENT('):
+ # We only want *raw* key agreement algorithms as such, so
+ # exclude ones that are already chained with a KDF.
+ # Keep the expression as one to test as an algorithm.
+ function = 'other_algorithm'
+ sets += self.table_by_test_function[function]
+ if self.accept_test_case_line(function, argument):
+ for s in sets:
+ s.add(argument)
# Regex matching a *.data line containing a test function call and
# its arguments. The actual definition is partly positional, but this
@@ -233,9 +277,9 @@
if m:
self.add_test_case_line(m.group(1), m.group(2))
-def gather_inputs(headers, test_suites):
+def gather_inputs(headers, test_suites, inputs_class=Inputs):
"""Read the list of inputs to test psa_constant_names with."""
- inputs = Inputs()
+ inputs = inputs_class()
for header in headers:
inputs.parse_header(header)
for test_cases in test_suites:
@@ -252,8 +296,10 @@
except OSError:
pass
-def run_c(options, type_word, names):
- """Generate and run a program to print out numerical values for names."""
+def run_c(type_word, expressions, include_path=None, keep_c=False):
+ """Generate and run a program to print out numerical values for expressions."""
+ if include_path is None:
+ include_path = []
if type_word == 'status':
cast_to = 'long'
printf_format = '%ld'
@@ -278,18 +324,18 @@
int main(void)
{
''')
- for name in names:
+ for expr in expressions:
c_file.write(' printf("{}\\n", ({}) {});\n'
- .format(printf_format, cast_to, name))
+ .format(printf_format, cast_to, expr))
c_file.write(''' return 0;
}
''')
c_file.close()
cc = os.getenv('CC', 'cc')
subprocess.check_call([cc] +
- ['-I' + dir for dir in options.include] +
+ ['-I' + dir for dir in include_path] +
['-o', exe_name, c_name])
- if options.keep_c:
+ if keep_c:
sys.stderr.write('List of {} tests kept at {}\n'
.format(type_word, c_name))
else:
@@ -302,76 +348,101 @@
NORMALIZE_STRIP_RE = re.compile(r'\s+')
def normalize(expr):
"""Normalize the C expression so as not to care about trivial differences.
+
Currently "trivial differences" means whitespace.
"""
- expr = re.sub(NORMALIZE_STRIP_RE, '', expr, len(expr))
- return expr.strip().split('\n')
+ return re.sub(NORMALIZE_STRIP_RE, '', expr)
-def do_test(options, inputs, type_word, names):
- """Test psa_constant_names for the specified type.
- Run program on names.
- Use inputs to figure out what arguments to pass to macros that
- take arguments.
+def collect_values(inputs, type_word, include_path=None, keep_c=False):
+ """Generate expressions using known macro names and calculate their values.
+
+ Return a list of pairs of (expr, value) where expr is an expression and
+ value is a string representation of its integer value.
"""
- names = sorted(itertools.chain(*map(inputs.distribute_arguments, names)))
- values = run_c(options, type_word, names)
- output = subprocess.check_output([options.program, type_word] + values)
- outputs = output.decode('ascii').strip().split('\n')
- errors = [(type_word, name, value, output)
- for (name, value, output) in zip(names, values, outputs)
- if normalize(name) != normalize(output)]
- return len(names), errors
+ names = inputs.get_names(type_word)
+ expressions = sorted(inputs.generate_expressions(names))
+ values = run_c(type_word, expressions,
+ include_path=include_path, keep_c=keep_c)
+ return expressions, values
-def report_errors(errors):
- """Describe each case where the output is not as expected."""
- for type_word, name, value, output in errors:
- print('For {} "{}", got "{}" (value: {})'
- .format(type_word, name, output, value))
+class Tests:
+ """An object representing tests and their results."""
-def run_tests(options, inputs):
- """Run psa_constant_names on all the gathered inputs.
- Return a tuple (count, errors) where count is the total number of inputs
- that were tested and errors is the list of cases where the output was
- not as expected.
- """
- count = 0
- errors = []
- for type_word, names in [('status', inputs.statuses),
- ('algorithm', inputs.algorithms),
- ('ecc_curve', inputs.ecc_curves),
- ('dh_group', inputs.dh_groups),
- ('key_type', inputs.key_types),
- ('key_usage', inputs.key_usage_flags)]:
- c, e = do_test(options, inputs, type_word, names)
- count += c
- errors += e
- return count, errors
+ Error = namedtuple('Error',
+ ['type', 'expression', 'value', 'output'])
+
+ def __init__(self, options):
+ self.options = options
+ self.count = 0
+ self.errors = []
+
+ def run_one(self, inputs, type_word):
+ """Test psa_constant_names for the specified type.
+
+ Run the program on the names for this type.
+ Use the inputs to figure out what arguments to pass to macros that
+ take arguments.
+ """
+ expressions, values = collect_values(inputs, type_word,
+ include_path=self.options.include,
+ keep_c=self.options.keep_c)
+ output = subprocess.check_output([self.options.program, type_word] +
+ values)
+ outputs = output.decode('ascii').strip().split('\n')
+ self.count += len(expressions)
+ for expr, value, output in zip(expressions, values, outputs):
+ if normalize(expr) != normalize(output):
+ self.errors.append(self.Error(type=type_word,
+ expression=expr,
+ value=value,
+ output=output))
+
+ def run_all(self, inputs):
+ """Run psa_constant_names on all the gathered inputs."""
+ for type_word in ['status', 'algorithm', 'ecc_curve', 'dh_group',
+ 'key_type', 'key_usage']:
+ self.run_one(inputs, type_word)
+
+ def report(self, out):
+ """Describe each case where the output is not as expected.
+
+ Write the errors to ``out``.
+ Also write a total.
+ """
+ for error in self.errors:
+ out.write('For {} "{}", got "{}" (value: {})\n'
+ .format(error.type, error.expression,
+ error.output, error.value))
+ out.write('{} test cases'.format(self.count))
+ if self.errors:
+ out.write(', {} FAIL\n'.format(len(self.errors)))
+ else:
+ out.write(' PASS\n')
+
+HEADERS = ['psa/crypto.h', 'psa/crypto_extra.h', 'psa/crypto_values.h']
+TEST_SUITES = ['tests/suites/test_suite_psa_crypto_metadata.data']
def main():
parser = argparse.ArgumentParser(description=globals()['__doc__'])
parser.add_argument('--include', '-I',
action='append', default=['include'],
help='Directory for header files')
- parser.add_argument('--program',
- default='programs/psa/psa_constant_names',
- help='Program to test')
parser.add_argument('--keep-c',
action='store_true', dest='keep_c', default=False,
help='Keep the intermediate C file')
parser.add_argument('--no-keep-c',
action='store_false', dest='keep_c',
help='Don\'t keep the intermediate C file (default)')
+ parser.add_argument('--program',
+ default='programs/psa/psa_constant_names',
+ help='Program to test')
options = parser.parse_args()
- headers = [os.path.join(options.include[0], 'psa', h)
- for h in ['crypto.h', 'crypto_extra.h', 'crypto_values.h']]
- test_suites = ['tests/suites/test_suite_psa_crypto_metadata.data']
- inputs = gather_inputs(headers, test_suites)
- count, errors = run_tests(options, inputs)
- report_errors(errors)
- if errors == []:
- print('{} test cases PASS'.format(count))
- else:
- print('{} test cases, {} FAIL'.format(count, len(errors)))
+ headers = [os.path.join(options.include[0], h) for h in HEADERS]
+ inputs = gather_inputs(headers, TEST_SUITES)
+ tests = Tests(options)
+ tests.run_all(inputs)
+ tests.report(sys.stdout)
+ if tests.errors:
exit(1)
if __name__ == '__main__':
diff --git a/tests/suites/test_suite_psa_crypto_metadata.data b/tests/suites/test_suite_psa_crypto_metadata.data
index e989895..9cdee03 100644
--- a/tests/suites/test_suite_psa_crypto_metadata.data
+++ b/tests/suites/test_suite_psa_crypto_metadata.data
@@ -262,6 +262,26 @@
depends_on:MBEDTLS_SHA256_C
key_derivation_algorithm:PSA_ALG_HKDF( PSA_ALG_SHA_256 ):ALG_IS_HKDF
+Key derivation: HKDF using SHA-384
+depends_on:MBEDTLS_SHA512_C
+key_derivation_algorithm:PSA_ALG_HKDF( PSA_ALG_SHA_384 ):ALG_IS_HKDF
+
+Key derivation: TLS 1.2 PRF using SHA-256
+depends_on:MBEDTLS_SHA256_C
+key_derivation_algorithm:PSA_ALG_TLS12_PRF( PSA_ALG_SHA_256 ):ALG_IS_TLS12_PRF
+
+Key derivation: TLS 1.2 PRF using SHA-384
+depends_on:MBEDTLS_SHA512_C
+key_derivation_algorithm:PSA_ALG_TLS12_PRF( PSA_ALG_SHA_384 ):ALG_IS_TLS12_PRF
+
+Key derivation: TLS 1.2 PSK-to-MS using SHA-256
+depends_on:MBEDTLS_SHA256_C
+key_derivation_algorithm:PSA_ALG_TLS12_PSK_TO_MS( PSA_ALG_SHA_256 ):ALG_IS_TLS12_PSK_TO_MS
+
+Key derivation: TLS 1.2 PSK-to-MS using SHA-384
+depends_on:MBEDTLS_SHA512_C
+key_derivation_algorithm:PSA_ALG_TLS12_PSK_TO_MS( PSA_ALG_SHA_384 ):ALG_IS_TLS12_PSK_TO_MS
+
Key agreement: FFDH, raw output
depends_on:MBEDTLS_DHM_C
key_agreement_algorithm:PSA_ALG_FFDH:ALG_IS_FFDH | ALG_IS_RAW_KEY_AGREEMENT:PSA_ALG_FFDH:PSA_ALG_CATEGORY_KEY_DERIVATION
@@ -270,6 +290,10 @@
depends_on:MBEDTLS_DHM_C
key_agreement_algorithm:PSA_ALG_KEY_AGREEMENT( PSA_ALG_FFDH, PSA_ALG_HKDF( PSA_ALG_SHA_256 ) ):ALG_IS_FFDH:PSA_ALG_FFDH:PSA_ALG_HKDF( PSA_ALG_SHA_256 )
+Key agreement: FFDH, HKDF using SHA-384
+depends_on:MBEDTLS_DHM_C
+key_agreement_algorithm:PSA_ALG_KEY_AGREEMENT( PSA_ALG_FFDH, PSA_ALG_HKDF( PSA_ALG_SHA_384 ) ):ALG_IS_FFDH:PSA_ALG_FFDH:PSA_ALG_HKDF( PSA_ALG_SHA_384 )
+
Key agreement: ECDH, raw output
depends_on:MBEDTLS_ECDH_C
key_agreement_algorithm:PSA_ALG_ECDH:ALG_IS_ECDH | ALG_IS_RAW_KEY_AGREEMENT:PSA_ALG_ECDH:PSA_ALG_CATEGORY_KEY_DERIVATION
@@ -278,6 +302,10 @@
depends_on:MBEDTLS_ECDH_C
key_agreement_algorithm:PSA_ALG_KEY_AGREEMENT( PSA_ALG_ECDH, PSA_ALG_HKDF( PSA_ALG_SHA_256 ) ):ALG_IS_ECDH:PSA_ALG_ECDH:PSA_ALG_HKDF( PSA_ALG_SHA_256 )
+Key agreement: ECDH, HKDF using SHA-384
+depends_on:MBEDTLS_ECDH_C
+key_agreement_algorithm:PSA_ALG_KEY_AGREEMENT( PSA_ALG_ECDH, PSA_ALG_HKDF( PSA_ALG_SHA_384 ) ):ALG_IS_ECDH:PSA_ALG_ECDH:PSA_ALG_HKDF( PSA_ALG_SHA_384 )
+
Key type: raw data
key_type:PSA_KEY_TYPE_RAW_DATA:KEY_TYPE_IS_UNSTRUCTURED
diff --git a/tests/suites/test_suite_psa_crypto_metadata.function b/tests/suites/test_suite_psa_crypto_metadata.function
index a9f1b39..3a9347e 100644
--- a/tests/suites/test_suite_psa_crypto_metadata.function
+++ b/tests/suites/test_suite_psa_crypto_metadata.function
@@ -37,6 +37,8 @@
#define ALG_IS_WILDCARD ( 1u << 19 )
#define ALG_IS_RAW_KEY_AGREEMENT ( 1u << 20 )
#define ALG_IS_AEAD_ON_BLOCK_CIPHER ( 1u << 21 )
+#define ALG_IS_TLS12_PRF ( 1u << 22 )
+#define ALG_IS_TLS12_PSK_TO_MS ( 1u << 23 )
/* Flags for key type classification macros. There is a flag for every
* key type classification macro PSA_KEY_TYPE_IS_xxx except for some that