Merge pull request #4559 from gilles-peskine-arm/psa-storage-format-test-algorithms-3.0

PSA storage format test: algorithms
diff --git a/include/mbedtls/config_psa.h b/include/mbedtls/config_psa.h
index 2032a36..f5db94e 100644
--- a/include/mbedtls/config_psa.h
+++ b/include/mbedtls/config_psa.h
@@ -38,6 +38,30 @@
 extern "C" {
 #endif
 
+
+
+/****************************************************************/
+/* De facto synonyms */
+/****************************************************************/
+
+#if defined(PSA_WANT_ALG_ECDSA_ANY) && !defined(PSA_WANT_ALG_ECDSA)
+#define PSA_WANT_ALG_ECDSA PSA_WANT_ALG_ECDSA_ANY
+#elif !defined(PSA_WANT_ALG_ECDSA_ANY) && defined(PSA_WANT_ALG_ECDSA)
+#define PSA_WANT_ALG_ECDSA_ANY PSA_WANT_ALG_ECDSA
+#endif
+
+#if defined(PSA_WANT_ALG_RSA_PKCS1V15_SIGN_RAW) && !defined(PSA_WANT_ALG_RSA_PKCS1V15_SIGN)
+#define PSA_WANT_ALG_RSA_PKCS1V15_SIGN PSA_WANT_ALG_RSA_PKCS1V15_SIGN_RAW
+#elif !defined(PSA_WANT_ALG_RSA_PKCS1V15_SIGN_RAW) && defined(PSA_WANT_ALG_RSA_PKCS1V15_SIGN)
+#define PSA_WANT_ALG_RSA_PKCS1V15_SIGN_RAW PSA_WANT_ALG_RSA_PKCS1V15_SIGN
+#endif
+
+
+
+/****************************************************************/
+/* Require built-in implementations based on PSA requirements */
+/****************************************************************/
+
 #if defined(MBEDTLS_PSA_CRYPTO_CONFIG)
 
 #if defined(PSA_WANT_ALG_DETERMINISTIC_ECDSA)
@@ -497,6 +521,12 @@
 #endif /* !MBEDTLS_PSA_ACCEL_ECC_SECP_K1_256 */
 #endif /* PSA_WANT_ECC_SECP_K1_256 */
 
+
+
+/****************************************************************/
+/* Infer PSA requirements from Mbed TLS capabilities */
+/****************************************************************/
+
 #else /* MBEDTLS_PSA_CRYPTO_CONFIG */
 
 /*
@@ -522,6 +552,7 @@
 #if defined(MBEDTLS_ECDSA_C)
 #define MBEDTLS_PSA_BUILTIN_ALG_ECDSA 1
 #define PSA_WANT_ALG_ECDSA 1
+#define PSA_WANT_ALG_ECDSA_ANY 1
 
 // Only add in DETERMINISTIC support if ECDSA is also enabled
 #if defined(MBEDTLS_ECDSA_DETERMINISTIC)
@@ -586,6 +617,7 @@
 #define PSA_WANT_ALG_RSA_PKCS1V15_CRYPT 1
 #define MBEDTLS_PSA_BUILTIN_ALG_RSA_PKCS1V15_SIGN 1
 #define PSA_WANT_ALG_RSA_PKCS1V15_SIGN 1
+#define PSA_WANT_ALG_RSA_PKCS1V15_SIGN_RAW 1
 #endif /* MBEDTLSS_PKCS1_V15 */
 #if defined(MBEDTLS_PKCS1_V21)
 #define MBEDTLS_PSA_BUILTIN_ALG_RSA_OAEP 1
diff --git a/scripts/mbedtls_dev/crypto_knowledge.py b/scripts/mbedtls_dev/crypto_knowledge.py
index aa52790..94a97e7 100644
--- a/scripts/mbedtls_dev/crypto_knowledge.py
+++ b/scripts/mbedtls_dev/crypto_knowledge.py
@@ -33,7 +33,7 @@
         `name` is a string 'PSA_KEY_TYPE_xxx' which is the name of a PSA key
         type macro. For key types that take arguments, the arguments can
         be passed either through the optional argument `params` or by
-        passing an expression of the form 'PSA_KEY_TYPE_xxx(param1, param2)'
+        passing an expression of the form 'PSA_KEY_TYPE_xxx(param1, ...)'
         in `name` as a string.
         """
 
@@ -48,7 +48,7 @@
                 m = re.match(r'(\w+)\s*\((.*)\)\Z', self.name)
                 assert m is not None
                 self.name = m.group(1)
-                params = ','.split(m.group(2))
+                params = m.group(2).split(',')
         self.params = (None if params is None else
                        [param.strip() for param in params])
         """The parameters of the key type, if there are any.
diff --git a/scripts/mbedtls_dev/macro_collector.py b/scripts/mbedtls_dev/macro_collector.py
index a2192ba..0e76435 100644
--- a/scripts/mbedtls_dev/macro_collector.py
+++ b/scripts/mbedtls_dev/macro_collector.py
@@ -18,7 +18,61 @@
 
 import itertools
 import re
-from typing import Dict, Iterable, Iterator, List, Set
+from typing import Dict, IO, Iterable, Iterator, List, Optional, Pattern, Set, Tuple, Union
+
+
+class ReadFileLineException(Exception):
+    """Wrapper recording the file name and line number where an error occurred."""
+    def __init__(self, filename: str, line_number: Union[int, str]) -> None:
+        message = 'in {} at {}'.format(filename, line_number)
+        super(ReadFileLineException, self).__init__(message)
+        self.filename = filename
+        self.line_number = line_number
+
+
+class read_file_lines:
+    # Dear Pylint, conventionally, a context manager class name is lowercase.
+    # pylint: disable=invalid-name,too-few-public-methods
+    """Context manager to read a text file line by line.
+
+    ```
+    with read_file_lines(filename) as lines:
+        for line in lines:
+            process(line)
+    ```
+    is equivalent to
+    ```
+    with open(filename, 'r') as input_file:
+        for line in input_file:
+            process(line)
+    ```
+    except that if process(line) raises an exception, then the read_file_lines
+    snippet annotates the exception with the file name and 1-based line number.
+    """
+    def __init__(self, filename: str, binary: bool = False) -> None:
+        self.filename = filename
+        self.line_number = 'entry' #type: Union[int, str]
+        self.generator = None #type: Optional[Iterable[Tuple[int, str]]]
+        self.file = None #type: Optional[IO]
+        self.binary = binary
+    def __enter__(self) -> 'read_file_lines':
+        self.file = open(self.filename, 'rb' if self.binary else 'r')
+        # Count from 1 so that reported locations match what editors show.
+        self.generator = enumerate(self.file, 1)
+        return self
+    def __iter__(self) -> Iterator[str]:
+        assert self.generator is not None
+        for line_number, content in self.generator:
+            self.line_number = line_number
+            yield content
+        self.line_number = 'exit'
+    def __exit__(self, exc_type, exc_value, exc_traceback) -> None:
+        # Release the file handle even when an exception is propagating.
+        if self.file is not None:
+            self.file.close()
+        if exc_type is not None:
+            raise ReadFileLineException(self.filename, self.line_number) \
+                from exc_value
 
 
 class PSAMacroEnumerator:
@@ -57,6 +111,20 @@
             'tag_length': [],
             'min_tag_length': [],
         } #type: Dict[str, List[str]]
+        # Whether to include intermediate macros in enumerations. Intermediate
+        # macros serve as category headers and are not valid values of their
+        # type. See `is_internal_name`.
+        # Always false in this class, may be set to true in derived classes.
+        self.include_intermediate = False
+
+    def is_internal_name(self, name: str) -> bool:
+        """Whether this is an internal macro. Internal macros will be skipped."""
+        if not self.include_intermediate:
+            if name.endswith('_BASE') or name.endswith('_NONE'):
+                return True
+            if '_CATEGORY_' in name:
+                return True
+        return name.endswith('_FLAG') or name.endswith('_MASK')
 
     def gather_arguments(self) -> None:
         """Populate the list of values for macro arguments.
@@ -73,7 +141,11 @@
 
     @staticmethod
     def _format_arguments(name: str, arguments: Iterable[str]) -> str:
-        """Format a macro call with arguments.."""
+        """Format a macro call with arguments.
+
+        The resulting format is consistent with
+        `InputsForTest.normalize_argument`.
+        """
         return name + '(' + ', '.join(arguments) + ')'
 
     _argument_split_re = re.compile(r' *, *')
@@ -111,6 +183,15 @@
         except BaseException as e:
             raise Exception('distribute_arguments({})'.format(name)) from e
 
+    def distribute_arguments_without_duplicates(
+            self, seen: Set[str], name: str
+    ) -> Iterator[str]:
+        """Same as `distribute_arguments`, but don't repeat seen results."""
+        for result in self.distribute_arguments(name):
+            if result not in seen:
+                seen.add(result)
+                yield result
+
     def generate_expressions(self, names: Iterable[str]) -> Iterator[str]:
         """Generate expressions covering values constructed from the given names.
 
@@ -123,7 +204,11 @@
         * ``macros.generate_expressions(macros.key_types)`` generates all
           key types.
         """
-        return itertools.chain(*map(self.distribute_arguments, names))
+        seen = set() #type: Set[str]
+        return itertools.chain(*(
+            self.distribute_arguments_without_duplicates(seen, name)
+            for name in names
+        ))
 
 
 class PSAMacroCollector(PSAMacroEnumerator):
@@ -144,15 +229,6 @@
         self.key_types_from_group = {} #type: Dict[str, str]
         self.algorithms_from_hash = {} #type: Dict[str, str]
 
-    def is_internal_name(self, name: str) -> bool:
-        """Whether this is an internal macro. Internal macros will be skipped."""
-        if not self.include_intermediate:
-            if name.endswith('_BASE') or name.endswith('_NONE'):
-                return True
-            if '_CATEGORY_' in name:
-                return True
-        return name.endswith('_FLAG') or name.endswith('_MASK')
-
     def record_algorithm_subtype(self, name: str, expansion: str) -> None:
         """Record the subtype of an algorithm constructor.
 
@@ -251,3 +327,179 @@
                 m = re.search(self._continued_line_re, line)
             line = re.sub(self._nonascii_re, rb'', line).decode('ascii')
             self.read_line(line)
+
+
+class InputsForTest(PSAMacroEnumerator):
+    # pylint: disable=too-many-instance-attributes
+    """Accumulate information about macros to test.
+
+    This includes macro names as well as information about their arguments
+    when applicable.
+    """
+
+    def __init__(self) -> None:
+        super().__init__()
+        self.all_declared = set() #type: Set[str]
+        # Identifier prefixes
+        self.table_by_prefix = {
+            'ERROR': self.statuses,
+            'ALG': self.algorithms,
+            'ECC_CURVE': self.ecc_curves,
+            'DH_GROUP': self.dh_groups,
+            'KEY_TYPE': self.key_types,
+            'KEY_USAGE': self.key_usage_flags,
+        } #type: Dict[str, Set[str]]
+        # Test functions
+        self.table_by_test_function = {
+            # Any function ending in _algorithm also gets added to
+            # self.algorithms.
+            'key_type': [self.key_types],
+            'block_cipher_key_type': [self.key_types],
+            'stream_cipher_key_type': [self.key_types],
+            'ecc_key_family': [self.ecc_curves],
+            'ecc_key_types': [self.ecc_curves],
+            'dh_key_family': [self.dh_groups],
+            'dh_key_types': [self.dh_groups],
+            'hash_algorithm': [self.hash_algorithms],
+            'mac_algorithm': [self.mac_algorithms],
+            'cipher_algorithm': [],
+            'hmac_algorithm': [self.mac_algorithms],
+            'aead_algorithm': [self.aead_algorithms],
+            'key_derivation_algorithm': [self.kdf_algorithms],
+            'key_agreement_algorithm': [self.ka_algorithms],
+            'asymmetric_signature_algorithm': [],
+            'asymmetric_signature_wildcard': [self.algorithms],
+            'asymmetric_encryption_algorithm': [],
+            'other_algorithm': [],
+        } #type: Dict[str, List[Set[str]]]
+        self.arguments_for['mac_length'] += ['1', '63']
+        self.arguments_for['min_mac_length'] += ['1', '63']
+        self.arguments_for['tag_length'] += ['1', '63']
+        self.arguments_for['min_tag_length'] += ['1', '63']
+
+    def add_numerical_values(self) -> None:
+        """Add numerical values that are not supported to the known identifiers."""
+        # Sets of names per type
+        self.algorithms.add('0xffffffff')
+        self.ecc_curves.add('0xff')
+        self.dh_groups.add('0xff')
+        self.key_types.add('0xffff')
+        self.key_usage_flags.add('0x80000000')
+
+        # Hard-coded values for unknown algorithms
+        #
+        # These have to have values that are correct for their respective
+        # PSA_ALG_IS_xxx macros, but are also not currently assigned and are
+        # not likely to be assigned in the near future.
+        self.hash_algorithms.add('0x020000fe') # 0x020000ff is PSA_ALG_ANY_HASH
+        self.mac_algorithms.add('0x03007fff')
+        self.ka_algorithms.add('0x09fc0000')
+        self.kdf_algorithms.add('0x080000ff')
+        # For AEAD algorithms, the only variability is over the tag length,
+        # and this only applies to known algorithms, so don't test an
+        # unknown algorithm.
+
+    def get_names(self, type_word: str) -> Set[str]:
+        """Return the set of known names of values of the given type."""
+        return {
+            'status': self.statuses,
+            'algorithm': self.algorithms,
+            'ecc_curve': self.ecc_curves,
+            'dh_group': self.dh_groups,
+            'key_type': self.key_types,
+            'key_usage': self.key_usage_flags,
+        }[type_word]
+
+    # Regex for interesting header lines.
+    # Groups: 1=macro name, 2=type, 3=argument list (optional).
+    _header_line_re = \
+        re.compile(r'#define +' +
+                   r'(PSA_((?:(?:DH|ECC|KEY)_)?[A-Z]+)_\w+)' +
+                   r'(?:\(([^\n()]*)\))?')
+    # Regex of macro names to exclude.
+    _excluded_name_re = re.compile(r'_(?:GET|IS|OF)_|_(?:BASE|FLAG|MASK)\Z')
+    # Additional excluded macros.
+    _excluded_names = set([
+        # Macros that provide an alternative way to build the same
+        # algorithm as another macro.
+        'PSA_ALG_AEAD_WITH_DEFAULT_LENGTH_TAG',
+        'PSA_ALG_FULL_LENGTH_MAC',
+        # Auxiliary macro whose name doesn't fit the usual patterns for
+        # auxiliary macros.
+        'PSA_ALG_AEAD_WITH_DEFAULT_LENGTH_TAG_CASE',
+    ])
+    def parse_header_line(self, line: str) -> None:
+        """Parse a C header line, looking for "#define PSA_xxx"."""
+        m = re.match(self._header_line_re, line)
+        if not m:
+            return
+        name = m.group(1)
+        self.all_declared.add(name)
+        if re.search(self._excluded_name_re, name) or \
+           name in self._excluded_names or \
+           self.is_internal_name(name):
+            return
+        dest = self.table_by_prefix.get(m.group(2))
+        if dest is None:
+            return
+        dest.add(name)
+        if m.group(3):
+            self.argspecs[name] = self._argument_split(m.group(3))
+
+    _nonascii_re = re.compile(rb'[^\x00-\x7f]+') #type: Pattern
+    def parse_header(self, filename: str) -> None:
+        """Parse a C header file, looking for "#define PSA_xxx"."""
+        with read_file_lines(filename, binary=True) as lines:
+            for line in lines:
+                line = re.sub(self._nonascii_re, rb'', line).decode('ascii')
+                self.parse_header_line(line)
+
+    _macro_identifier_re = re.compile(r'[A-Z]\w+')
+    def generate_undeclared_names(self, expr: str) -> Iterator[str]:
+        for name in re.findall(self._macro_identifier_re, expr):
+            if name not in self.all_declared:
+                yield name
+
+    def accept_test_case_line(self, function: str, argument: str) -> bool:
+        #pylint: disable=unused-argument
+        undeclared = list(self.generate_undeclared_names(argument))
+        if undeclared:
+            raise Exception('Undeclared names in test case', undeclared)
+        return True
+
+    @staticmethod
+    def normalize_argument(argument: str) -> str:
+        """Normalize whitespace in the given C expression.
+
+        The result uses the same whitespace as
+        `PSAMacroEnumerator.distribute_arguments`.
+        """
+        return re.sub(r',', r', ', re.sub(r' +', r'', argument))
+
+    def add_test_case_line(self, function: str, argument: str) -> None:
+        """Parse a test case data line, looking for algorithm metadata tests."""
+        sets = []
+        if function.endswith('_algorithm'):
+            sets.append(self.algorithms)
+            if function == 'key_agreement_algorithm' and \
+               argument.startswith('PSA_ALG_KEY_AGREEMENT('):
+                # We only want *raw* key agreement algorithms as such, so
+                # exclude ones that are already chained with a KDF.
+                # Keep the expression as one to test as an algorithm.
+                function = 'other_algorithm'
+        sets += self.table_by_test_function[function]
+        if self.accept_test_case_line(function, argument):
+            for s in sets:
+                s.add(self.normalize_argument(argument))
+
+    # Regex matching a *.data line containing a test function call and
+    # its arguments. The actual definition is partly positional, but this
+    # regex is good enough in practice.
+    _test_case_line_re = re.compile(r'(?!depends_on:)(\w+):([^\n :][^:\n]*)')
+    def parse_test_cases(self, filename: str) -> None:
+        """Parse a test case file (*.data), looking for algorithm metadata tests."""
+        with read_file_lines(filename) as lines:
+            for line in lines:
+                m = re.match(self._test_case_line_re, line)
+                if m:
+                    self.add_test_case_line(m.group(1), m.group(2))
diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt
index a9c9cf3..7898004 100644
--- a/tests/CMakeLists.txt
+++ b/tests/CMakeLists.txt
@@ -151,6 +151,8 @@
 add_test_suite(psa_crypto_se_driver_hal_mocks)
 add_test_suite(psa_crypto_slot_management)
 add_test_suite(psa_crypto_storage_format psa_crypto_storage_format.misc)
+add_test_suite(psa_crypto_storage_format psa_crypto_storage_format.current)
+add_test_suite(psa_crypto_storage_format psa_crypto_storage_format.v0)
 add_test_suite(psa_its)
 add_test_suite(random)
 add_test_suite(rsa)
diff --git a/tests/scripts/check-generated-files.sh b/tests/scripts/check-generated-files.sh
index b480837..a2c285f 100755
--- a/tests/scripts/check-generated-files.sh
+++ b/tests/scripts/check-generated-files.sh
@@ -44,23 +44,28 @@
     UPDATE='y'
 fi
 
+# check SCRIPT FILENAME[...]
+# check SCRIPT DIRECTORY
+# Run SCRIPT and check that it does not modify any of the specified files.
+# In the first form, there can be any number of FILENAMEs, which must be
+# regular files.
+# In the second form, there must be a single DIRECTORY, standing for the
+# list of files in the directory. Running SCRIPT must not modify any file
+# in the directory and must not add or remove files either.
+# If $UPDATE is empty, abort with an error status if a file is modified.
 check()
 {
     SCRIPT=$1
-    TO_CHECK=$2
-    PATTERN=""
-    FILES=""
+    shift
 
-    if [ -d $TO_CHECK ]; then
-        rm -f "$TO_CHECK"/*.bak
-        for FILE in $TO_CHECK/*; do
-            FILES="$FILE $FILES"
-        done
-    else
-        FILES=$TO_CHECK
+    directory=
+    if [ -d "$1" ]; then
+        directory="$1"
+        rm -f "$directory"/*.bak
+        set -- "$1"/*
     fi
 
-    for FILE in $FILES; do
+    for FILE in "$@"; do
         if [ -e "$FILE" ]; then
             cp "$FILE" "$FILE.bak"
         else
@@ -68,37 +73,32 @@
         fi
     done
 
-    $SCRIPT
+    "$SCRIPT"
 
     # Compare the script output to the old files and remove backups
-    for FILE in $FILES; do
-        if ! diff $FILE $FILE.bak >/dev/null 2>&1; then
+    for FILE in "$@"; do
+        if ! diff "$FILE" "$FILE.bak" >/dev/null 2>&1; then
             echo "'$FILE' was either modified or deleted by '$SCRIPT'"
             if [ -z "$UPDATE" ]; then
                 exit 1
             fi
         fi
         if [ -z "$UPDATE" ]; then
-            mv $FILE.bak $FILE
+            mv "$FILE.bak" "$FILE"
         else
             rm -f "$FILE.bak"
         fi
-
-        if [ -d $TO_CHECK ]; then
-            # Create a grep regular expression that we can check against the
-            # directory contents to test whether new files have been created
-            if [ -z $PATTERN ]; then
-                PATTERN="$(basename $FILE)"
-            else
-                PATTERN="$PATTERN\|$(basename $FILE)"
-            fi
-        fi
     done
 
-    if [ -d $TO_CHECK ]; then
+    if [ -n "$directory" ]; then
+        old_list="$*"
+        set -- "$directory"/*
+        new_list="$*"
         # Check if there are any new files
-        if ls -1 $TO_CHECK | grep -v "$PATTERN" >/dev/null 2>&1; then
-            echo "Files were created by '$SCRIPT'"
+        if [ "$old_list" != "$new_list" ]; then
+            echo "Files were deleted or created by '$SCRIPT'"
+            echo "Before: $old_list"
+            echo "After: $new_list"
             if [ -z "$UPDATE" ]; then
                 exit 1
             fi
diff --git a/tests/scripts/generate_psa_tests.py b/tests/scripts/generate_psa_tests.py
index 30f82db..8c53414 100755
--- a/tests/scripts/generate_psa_tests.py
+++ b/tests/scripts/generate_psa_tests.py
@@ -60,6 +60,14 @@
     """
     return [finish_family_dependency(dep, bits) for dep in dependencies]
 
+SYMBOLS_WITHOUT_DEPENDENCY = frozenset([
+    'PSA_ALG_AEAD_WITH_AT_LEAST_THIS_LENGTH_TAG', # modifier, only in policies
+    'PSA_ALG_AEAD_WITH_SHORTENED_TAG', # modifier
+    'PSA_ALG_ANY_HASH', # only in policies
+    'PSA_ALG_AT_LEAST_THIS_LENGTH_MAC', # modifier, only in policies
+    'PSA_ALG_KEY_AGREEMENT', # chaining
+    'PSA_ALG_TRUNCATED_MAC', # modifier
+])
 def automatic_dependencies(*expressions: str) -> List[str]:
     """Infer dependencies of a test case by looking for PSA_xxx symbols.
 
@@ -70,6 +78,7 @@
     used = set()
     for expr in expressions:
         used.update(re.findall(r'PSA_(?:ALG|ECC_FAMILY|KEY_TYPE)_\w+', expr))
+    used.difference_update(SYMBOLS_WITHOUT_DEPENDENCY)
     return sorted(psa_want_symbol(name) for name in used)
 
 # A temporary hack: at the time of writing, not all dependency symbols
@@ -100,24 +109,27 @@
 
     @staticmethod
     def remove_unwanted_macros(
-            constructors: macro_collector.PSAMacroCollector
+            constructors: macro_collector.PSAMacroEnumerator
     ) -> None:
-        # Mbed TLS doesn't support DSA. Don't attempt to generate any related
-        # test case.
+        # Mbed TLS doesn't support finite-field DH yet and will not support
+        # finite-field DSA. Don't attempt to generate any related test case.
+        constructors.key_types.discard('PSA_KEY_TYPE_DH_KEY_PAIR')
+        constructors.key_types.discard('PSA_KEY_TYPE_DH_PUBLIC_KEY')
         constructors.key_types.discard('PSA_KEY_TYPE_DSA_KEY_PAIR')
         constructors.key_types.discard('PSA_KEY_TYPE_DSA_PUBLIC_KEY')
-        constructors.algorithms_from_hash.pop('PSA_ALG_DSA', None)
-        constructors.algorithms_from_hash.pop('PSA_ALG_DETERMINISTIC_DSA', None)
 
-    def read_psa_interface(self) -> macro_collector.PSAMacroCollector:
+    def read_psa_interface(self) -> macro_collector.PSAMacroEnumerator:
         """Return the list of known key types, algorithms, etc."""
-        constructors = macro_collector.PSAMacroCollector()
+        constructors = macro_collector.InputsForTest()
         header_file_names = ['include/psa/crypto_values.h',
                              'include/psa/crypto_extra.h']
+        test_suites = ['tests/suites/test_suite_psa_crypto_metadata.data']
         for header_file_name in header_file_names:
-            with open(header_file_name, 'rb') as header_file:
-                constructors.read_file(header_file)
+            constructors.parse_header(header_file_name)
+        for test_cases in test_suites:
+            constructors.parse_test_cases(test_cases)
         self.remove_unwanted_macros(constructors)
+        constructors.gather_arguments()
         return constructors
 
 
@@ -199,14 +211,18 @@
             )
             # To be added: derive
 
+    ECC_KEY_TYPES = ('PSA_KEY_TYPE_ECC_KEY_PAIR',
+                     'PSA_KEY_TYPE_ECC_PUBLIC_KEY')
+
     def test_cases_for_not_supported(self) -> Iterator[test_case.TestCase]:
         """Generate test cases that exercise the creation of keys of unsupported types."""
         for key_type in sorted(self.constructors.key_types):
+            if key_type in self.ECC_KEY_TYPES:
+                continue
             kt = crypto_knowledge.KeyType(key_type)
             yield from self.test_cases_for_key_type_not_supported(kt)
         for curve_family in sorted(self.constructors.ecc_curves):
-            for constr in ('PSA_KEY_TYPE_ECC_KEY_PAIR',
-                           'PSA_KEY_TYPE_ECC_PUBLIC_KEY'):
+            for constr in self.ECC_KEY_TYPES:
                 kt = crypto_knowledge.KeyType(constr, [curve_family])
                 yield from self.test_cases_for_key_type_not_supported(
                     kt, param_descr='type')
@@ -260,13 +276,17 @@
         if self.forward:
             extra_arguments = []
         else:
+            flags = []
             # Some test keys have the RAW_DATA type and attributes that don't
             # necessarily make sense. We do this to validate numerical
             # encodings of the attributes.
             # Raw data keys have no useful exercise anyway so there is no
             # loss of test coverage.
-            exercise = key.type.string != 'PSA_KEY_TYPE_RAW_DATA'
-            extra_arguments = ['1' if exercise else '0']
+            if key.type.string != 'PSA_KEY_TYPE_RAW_DATA':
+                flags.append('TEST_FLAG_EXERCISE')
+            if 'READ_ONLY' in key.lifetime.string:
+                flags.append('TEST_FLAG_READ_ONLY')
+            extra_arguments = [' | '.join(flags) if flags else '0']
         tc.set_arguments([key.lifetime.string,
                           key.type.string, str(key.bits),
                           key.usage.string, key.alg.string, key.alg2.string,
@@ -335,23 +355,17 @@
 
     def all_keys_for_types(self) -> Iterator[StorageKey]:
         """Generate test keys covering key types and their representations."""
-        for key_type in sorted(self.constructors.key_types):
+        key_types = sorted(self.constructors.key_types)
+        for key_type in self.constructors.generate_expressions(key_types):
             yield from self.keys_for_type(key_type)
-        for key_type in sorted(self.constructors.key_types_from_curve):
-            for curve in sorted(self.constructors.ecc_curves):
-                yield from self.keys_for_type(key_type, [curve])
-        ## Diffie-Hellman (FFDH) is not supported yet, either in
-        ## crypto_knowledge.py or in Mbed TLS.
-        # for key_type in sorted(self.constructors.key_types_from_group):
-        #     for group in sorted(self.constructors.dh_groups):
-        #         yield from self.keys_for_type(key_type, [group])
 
     def keys_for_algorithm(self, alg: str) -> Iterator[StorageKey]:
         """Generate test keys for the specified algorithm."""
         # For now, we don't have information on the compatibility of key
         # types and algorithms. So we just test the encoding of algorithms,
         # and not that operations can be performed with them.
-        descr = alg
+        descr = re.sub(r'PSA_ALG_', r'', alg)
+        descr = re.sub(r',', r', ', re.sub(r' +', r'', descr))
         usage = 'PSA_KEY_USAGE_EXPORT'
         key1 = StorageKey(version=self.version,
                           id=1, lifetime=0x00000001,
@@ -370,17 +384,21 @@
 
     def all_keys_for_algorithms(self) -> Iterator[StorageKey]:
         """Generate test keys covering algorithm encodings."""
-        for alg in sorted(self.constructors.algorithms):
+        algorithms = sorted(self.constructors.algorithms)
+        for alg in self.constructors.generate_expressions(algorithms):
             yield from self.keys_for_algorithm(alg)
-        # To do: algorithm constructors with parameters
 
     def all_test_cases(self) -> Iterator[test_case.TestCase]:
         """Generate all storage format test cases."""
-        for key in self.all_keys_for_usage_flags():
-            yield self.make_test_case(key)
-        for key in self.all_keys_for_types():
-            yield self.make_test_case(key)
-        for key in self.all_keys_for_algorithms():
+        # First build a list of all keys, then construct all the corresponding
+        # test cases. This allows all required information to be obtained in
+        # one go, which is a significant performance gain as the information
+        # includes numerical values obtained by compiling a C program.
+        keys = [] #type: List[StorageKey]
+        keys += self.all_keys_for_usage_flags()
+        keys += self.all_keys_for_types()
+        keys += self.all_keys_for_algorithms()
+        for key in keys:
             yield self.make_test_case(key)
         # To do: vary id, lifetime
 
diff --git a/tests/scripts/test_psa_constant_names.py b/tests/scripts/test_psa_constant_names.py
index b3fdb8d..07c8ab2 100755
--- a/tests/scripts/test_psa_constant_names.py
+++ b/tests/scripts/test_psa_constant_names.py
@@ -28,231 +28,30 @@
 import re
 import subprocess
 import sys
+from typing import Iterable, List, Optional, Tuple
 
 import scripts_path # pylint: disable=unused-import
 from mbedtls_dev import c_build_helper
-from mbedtls_dev import macro_collector
+from mbedtls_dev.macro_collector import InputsForTest, PSAMacroEnumerator
+from mbedtls_dev import typing_util
 
-class ReadFileLineException(Exception):
-    def __init__(self, filename, line_number):
-        message = 'in {} at {}'.format(filename, line_number)
-        super(ReadFileLineException, self).__init__(message)
-        self.filename = filename
-        self.line_number = line_number
-
-class read_file_lines:
-    # Dear Pylint, conventionally, a context manager class name is lowercase.
-    # pylint: disable=invalid-name,too-few-public-methods
-    """Context manager to read a text file line by line.
-
-    ```
-    with read_file_lines(filename) as lines:
-        for line in lines:
-            process(line)
-    ```
-    is equivalent to
-    ```
-    with open(filename, 'r') as input_file:
-        for line in input_file:
-            process(line)
-    ```
-    except that if process(line) raises an exception, then the read_file_lines
-    snippet annotates the exception with the file name and line number.
-    """
-    def __init__(self, filename, binary=False):
-        self.filename = filename
-        self.line_number = 'entry'
-        self.generator = None
-        self.binary = binary
-    def __enter__(self):
-        self.generator = enumerate(open(self.filename,
-                                        'rb' if self.binary else 'r'))
-        return self
-    def __iter__(self):
-        for line_number, content in self.generator:
-            self.line_number = line_number
-            yield content
-        self.line_number = 'exit'
-    def __exit__(self, exc_type, exc_value, exc_traceback):
-        if exc_type is not None:
-            raise ReadFileLineException(self.filename, self.line_number) \
-                from exc_value
-
-class InputsForTest(macro_collector.PSAMacroEnumerator):
-    # pylint: disable=too-many-instance-attributes
-    """Accumulate information about macros to test.
-
-    This includes macro names as well as information about their arguments
-    when applicable.
-    """
-
-    def __init__(self):
-        super().__init__()
-        self.all_declared = set()
-        # Sets of names per type
-        self.statuses.add('PSA_SUCCESS')
-        self.algorithms.add('0xffffffff')
-        self.ecc_curves.add('0xff')
-        self.dh_groups.add('0xff')
-        self.key_types.add('0xffff')
-        self.key_usage_flags.add('0x80000000')
-
-        # Hard-coded values for unknown algorithms
-        #
-        # These have to have values that are correct for their respective
-        # PSA_ALG_IS_xxx macros, but are also not currently assigned and are
-        # not likely to be assigned in the near future.
-        self.hash_algorithms.add('0x020000fe') # 0x020000ff is PSA_ALG_ANY_HASH
-        self.mac_algorithms.add('0x03007fff')
-        self.ka_algorithms.add('0x09fc0000')
-        self.kdf_algorithms.add('0x080000ff')
-        # For AEAD algorithms, the only variability is over the tag length,
-        # and this only applies to known algorithms, so don't test an
-        # unknown algorithm.
-
-        # Identifier prefixes
-        self.table_by_prefix = {
-            'ERROR': self.statuses,
-            'ALG': self.algorithms,
-            'ECC_CURVE': self.ecc_curves,
-            'DH_GROUP': self.dh_groups,
-            'KEY_TYPE': self.key_types,
-            'KEY_USAGE': self.key_usage_flags,
-        }
-        # Test functions
-        self.table_by_test_function = {
-            # Any function ending in _algorithm also gets added to
-            # self.algorithms.
-            'key_type': [self.key_types],
-            'block_cipher_key_type': [self.key_types],
-            'stream_cipher_key_type': [self.key_types],
-            'ecc_key_family': [self.ecc_curves],
-            'ecc_key_types': [self.ecc_curves],
-            'dh_key_family': [self.dh_groups],
-            'dh_key_types': [self.dh_groups],
-            'hash_algorithm': [self.hash_algorithms],
-            'mac_algorithm': [self.mac_algorithms],
-            'cipher_algorithm': [],
-            'hmac_algorithm': [self.mac_algorithms],
-            'aead_algorithm': [self.aead_algorithms],
-            'key_derivation_algorithm': [self.kdf_algorithms],
-            'key_agreement_algorithm': [self.ka_algorithms],
-            'asymmetric_signature_algorithm': [],
-            'asymmetric_signature_wildcard': [self.algorithms],
-            'asymmetric_encryption_algorithm': [],
-            'other_algorithm': [],
-        }
-        self.arguments_for['mac_length'] += ['1', '63']
-        self.arguments_for['min_mac_length'] += ['1', '63']
-        self.arguments_for['tag_length'] += ['1', '63']
-        self.arguments_for['min_tag_length'] += ['1', '63']
-
-    def get_names(self, type_word):
-        """Return the set of known names of values of the given type."""
-        return {
-            'status': self.statuses,
-            'algorithm': self.algorithms,
-            'ecc_curve': self.ecc_curves,
-            'dh_group': self.dh_groups,
-            'key_type': self.key_types,
-            'key_usage': self.key_usage_flags,
-        }[type_word]
-
-    # Regex for interesting header lines.
-    # Groups: 1=macro name, 2=type, 3=argument list (optional).
-    _header_line_re = \
-        re.compile(r'#define +' +
-                   r'(PSA_((?:(?:DH|ECC|KEY)_)?[A-Z]+)_\w+)' +
-                   r'(?:\(([^\n()]*)\))?')
-    # Regex of macro names to exclude.
-    _excluded_name_re = re.compile(r'_(?:GET|IS|OF)_|_(?:BASE|FLAG|MASK)\Z')
-    # Additional excluded macros.
-    _excluded_names = set([
-        # Macros that provide an alternative way to build the same
-        # algorithm as another macro.
-        'PSA_ALG_AEAD_WITH_DEFAULT_LENGTH_TAG',
-        'PSA_ALG_FULL_LENGTH_MAC',
-        # Auxiliary macro whose name doesn't fit the usual patterns for
-        # auxiliary macros.
-        'PSA_ALG_AEAD_WITH_DEFAULT_LENGTH_TAG_CASE',
-    ])
-    def parse_header_line(self, line):
-        """Parse a C header line, looking for "#define PSA_xxx"."""
-        m = re.match(self._header_line_re, line)
-        if not m:
-            return
-        name = m.group(1)
-        self.all_declared.add(name)
-        if re.search(self._excluded_name_re, name) or \
-           name in self._excluded_names:
-            return
-        dest = self.table_by_prefix.get(m.group(2))
-        if dest is None:
-            return
-        dest.add(name)
-        if m.group(3):
-            self.argspecs[name] = self._argument_split(m.group(3))
-
-    _nonascii_re = re.compile(rb'[^\x00-\x7f]+')
-    def parse_header(self, filename):
-        """Parse a C header file, looking for "#define PSA_xxx"."""
-        with read_file_lines(filename, binary=True) as lines:
-            for line in lines:
-                line = re.sub(self._nonascii_re, rb'', line).decode('ascii')
-                self.parse_header_line(line)
-
-    _macro_identifier_re = re.compile(r'[A-Z]\w+')
-    def generate_undeclared_names(self, expr):
-        for name in re.findall(self._macro_identifier_re, expr):
-            if name not in self.all_declared:
-                yield name
-
-    def accept_test_case_line(self, function, argument):
-        #pylint: disable=unused-argument
-        undeclared = list(self.generate_undeclared_names(argument))
-        if undeclared:
-            raise Exception('Undeclared names in test case', undeclared)
-        return True
-
-    def add_test_case_line(self, function, argument):
-        """Parse a test case data line, looking for algorithm metadata tests."""
-        sets = []
-        if function.endswith('_algorithm'):
-            sets.append(self.algorithms)
-            if function == 'key_agreement_algorithm' and \
-               argument.startswith('PSA_ALG_KEY_AGREEMENT('):
-                # We only want *raw* key agreement algorithms as such, so
-                # exclude ones that are already chained with a KDF.
-                # Keep the expression as one to test as an algorithm.
-                function = 'other_algorithm'
-        sets += self.table_by_test_function[function]
-        if self.accept_test_case_line(function, argument):
-            for s in sets:
-                s.add(argument)
-
-    # Regex matching a *.data line containing a test function call and
-    # its arguments. The actual definition is partly positional, but this
-    # regex is good enough in practice.
-    _test_case_line_re = re.compile(r'(?!depends_on:)(\w+):([^\n :][^:\n]*)')
-    def parse_test_cases(self, filename):
-        """Parse a test case file (*.data), looking for algorithm metadata tests."""
-        with read_file_lines(filename) as lines:
-            for line in lines:
-                m = re.match(self._test_case_line_re, line)
-                if m:
-                    self.add_test_case_line(m.group(1), m.group(2))
-
-def gather_inputs(headers, test_suites, inputs_class=InputsForTest):
+def gather_inputs(headers: Iterable[str],
+                  test_suites: Iterable[str],
+                  inputs_class=InputsForTest) -> PSAMacroEnumerator:
     """Read the list of inputs to test psa_constant_names with."""
     inputs = inputs_class()
     for header in headers:
         inputs.parse_header(header)
     for test_cases in test_suites:
         inputs.parse_test_cases(test_cases)
+    inputs.add_numerical_values()
     inputs.gather_arguments()
     return inputs
 
-def run_c(type_word, expressions, include_path=None, keep_c=False):
+def run_c(type_word: str,
+          expressions: Iterable[str],
+          include_path: Optional[str] = None,
+          keep_c: bool = False) -> List[str]:
     """Generate and run a program to print out numerical values of C expressions."""
     if type_word == 'status':
         cast_to = 'long'
@@ -271,14 +70,17 @@
     )
 
 NORMALIZE_STRIP_RE = re.compile(r'\s+')
-def normalize(expr):
+def normalize(expr: str) -> str:
     """Normalize the C expression so as not to care about trivial differences.
 
     Currently "trivial differences" means whitespace.
     """
     return re.sub(NORMALIZE_STRIP_RE, '', expr)
 
-def collect_values(inputs, type_word, include_path=None, keep_c=False):
+def collect_values(inputs: InputsForTest,
+                   type_word: str,
+                   include_path: Optional[str] = None,
+                   keep_c: bool = False) -> Tuple[List[str], List[str]]:
     """Generate expressions using known macro names and calculate their values.
 
     Return a list of pairs of (expr, value) where expr is an expression and
@@ -296,12 +98,12 @@
     Error = namedtuple('Error',
                        ['type', 'expression', 'value', 'output'])
 
-    def __init__(self, options):
+    def __init__(self, options) -> None:
         self.options = options
         self.count = 0
-        self.errors = []
+        self.errors = [] #type: List[Tests.Error]
 
-    def run_one(self, inputs, type_word):
+    def run_one(self, inputs: InputsForTest, type_word: str) -> None:
         """Test psa_constant_names for the specified type.
 
         Run the program on the names for this type.
@@ -311,9 +113,10 @@
         expressions, values = collect_values(inputs, type_word,
                                              include_path=self.options.include,
                                              keep_c=self.options.keep_c)
-        output = subprocess.check_output([self.options.program, type_word] +
-                                         values)
-        outputs = output.decode('ascii').strip().split('\n')
+        output_bytes = subprocess.check_output([self.options.program,
+                                                type_word] + values)
+        output = output_bytes.decode('ascii')
+        outputs = output.strip().split('\n')
         self.count += len(expressions)
         for expr, value, output in zip(expressions, values, outputs):
             if self.options.show:
@@ -324,13 +127,13 @@
                                               value=value,
                                               output=output))
 
-    def run_all(self, inputs):
+    def run_all(self, inputs: InputsForTest) -> None:
         """Run psa_constant_names on all the gathered inputs."""
         for type_word in ['status', 'algorithm', 'ecc_curve', 'dh_group',
                           'key_type', 'key_usage']:
             self.run_one(inputs, type_word)
 
-    def report(self, out):
+    def report(self, out: typing_util.Writable) -> None:
         """Describe each case where the output is not as expected.
 
         Write the errors to ``out``.
@@ -365,7 +168,7 @@
                         help='Program to test')
     parser.add_argument('--show',
                         action='store_true',
-                        help='Keep the intermediate C file')
+                        help='Show tested values on stdout')
     parser.add_argument('--no-show',
                         action='store_false', dest='show',
                         help='Don\'t show tested values (default)')
diff --git a/tests/suites/test_suite_psa_crypto_storage_format.function b/tests/suites/test_suite_psa_crypto_storage_format.function
index 76cfe57..34d63a7 100644
--- a/tests/suites/test_suite_psa_crypto_storage_format.function
+++ b/tests/suites/test_suite_psa_crypto_storage_format.function
@@ -7,6 +7,8 @@
 
 #include <psa_crypto_its.h>
 
+#define TEST_FLAG_EXERCISE      0x00000001
+
 /** Write a key with the given attributes and key material to storage.
  * Test that it has the expected representation.
  *
@@ -67,7 +69,7 @@
                           const data_t *expected_material,
                           psa_storage_uid_t uid,
                           const data_t *representation,
-                          int exercise )
+                          int flags )
 {
     psa_key_attributes_t actual_attributes = PSA_KEY_ATTRIBUTES_INIT;
     mbedtls_svc_key_id_t key_id = psa_get_key_id( expected_attributes );
@@ -105,7 +107,7 @@
                         exported_material, length );
     }
 
-    if( exercise )
+    if( flags & TEST_FLAG_EXERCISE )
     {
         TEST_ASSERT( mbedtls_test_psa_exercise_key(
                          key_id,
@@ -183,7 +185,7 @@
 void key_storage_read( int lifetime_arg, int type_arg, int bits_arg,
                        int usage_arg, int alg_arg, int alg2_arg,
                        data_t *material,
-                       data_t *representation, int exercise )
+                       data_t *representation, int flags )
 {
     /* Backward compatibility: read a key in the format of a past version
      * and check that this version can use it. */
@@ -213,7 +215,7 @@
      * guarantees backward compatibility with keys that were stored by
      * past versions of Mbed TLS. */
     TEST_ASSERT( test_read_key( &attributes, material,
-                                uid, representation, exercise ) );
+                                uid, representation, flags ) );
 
 exit:
     psa_reset_key_attributes( &attributes );