Merge pull request #9229 from gabor-mezei-arm/9158_config.py_use_crypto_config

Adapt config.py to configuration file split
diff --git a/framework b/framework
index e8b4ae9..331565b 160000
--- a/framework
+++ b/framework
@@ -1 +1 @@
-Subproject commit e8b4ae9bc4bf7e643ee46bf8ff4ef613be2de86f
+Subproject commit 331565b041f794df2da76394b3b0039abce30355
diff --git a/scripts/config.py b/scripts/config.py
index 7c32db1..150078a 100755
--- a/scripts/config.py
+++ b/scripts/config.py
@@ -19,6 +19,8 @@
 import os
 import re
 
+from abc import ABCMeta
+
 class Setting:
     """Representation of one Mbed TLS mbedtls_config.h setting.
 
@@ -30,12 +32,13 @@
       present in mbedtls_config.h but commented out.
     * section: the name of the section that contains this symbol.
     """
-    # pylint: disable=too-few-public-methods
-    def __init__(self, active, name, value='', section=None):
+    # pylint: disable=too-few-public-methods, too-many-arguments
+    def __init__(self, active, name, value='', section=None, configfile=None):
         self.active = active
         self.name = name
         self.value = value
         self.section = section
+        self.configfile = configfile
 
 class Config:
     """Representation of the Mbed TLS configuration.
@@ -54,6 +57,7 @@
       name to become set.
     """
 
+    # pylint: disable=unused-argument
     def __init__(self):
         self.settings = {}
 
@@ -125,7 +129,13 @@
         """
         if name not in self.settings:
             return
-        self.settings[name].active = False
+
+        setting = self.settings[name]
+        # Check if modifying the config file
+        if setting.configfile and setting.active:
+            setting.configfile.modified = True
+
+        setting.active = False
 
     def adapt(self, adapter):
         """Run adapter on each known symbol and (de)activate it accordingly.
@@ -138,8 +148,12 @@
         otherwise unset `name` (i.e. make it known but inactive).
         """
         for setting in self.settings.values():
+            is_active = setting.active
             setting.active = adapter(setting.name, setting.active,
                                      setting.section)
+            # Check if modifying the config file
+            if setting.configfile and setting.active != is_active:
+                setting.configfile.modified = True
 
     def change_matching(self, regexs, enable):
         """Change all symbols matching one of the regexs to the desired state."""
@@ -148,11 +162,18 @@
         regex = re.compile('|'.join(regexs))
         for setting in self.settings.values():
             if regex.search(setting.name):
+                # Check if modifying the config file
+                if setting.configfile and setting.active != enable:
+                    setting.configfile.modified = True
                 setting.active = enable
 
 def is_full_section(section):
-    """Is this section affected by "config.py full" and friends?"""
-    return section.endswith('support') or section.endswith('modules')
+    """Is this section affected by "config.py full" and friends?
+
+    In a config file that does not use sections, the whole config file
+    is a single section whose name is None, and the whole file is affected.
+    """
+    return section is None or section.endswith('support') or section.endswith('modules')
 
 def realfull_adapter(_name, active, section):
     """Activate all symbols found in the global and boolean feature sections.
@@ -168,6 +189,26 @@
         return active
     return True
 
+PSA_UNSUPPORTED_FEATURE = frozenset([
+    'PSA_WANT_ALG_CBC_MAC',
+    'PSA_WANT_ALG_XTS',
+    'PSA_WANT_KEY_TYPE_RSA_KEY_PAIR_DERIVE',
+    'PSA_WANT_KEY_TYPE_DH_KEY_PAIR_DERIVE'
+])
+
+PSA_DEPRECATED_FEATURE = frozenset([
+    'PSA_WANT_KEY_TYPE_ECC_KEY_PAIR',
+    'PSA_WANT_KEY_TYPE_RSA_KEY_PAIR'
+])
+
+PSA_UNSTABLE_FEATURE = frozenset([
+    'PSA_WANT_ECC_SECP_K1_224'
+])
+
+EXCLUDE_FROM_CRYPTO = PSA_UNSUPPORTED_FEATURE | \
+                      PSA_DEPRECATED_FEATURE | \
+                      PSA_UNSTABLE_FEATURE
+
 # The goal of the full configuration is to have everything that can be tested
 # together. This includes deprecated or insecure options. It excludes:
 # * Options that require additional build dependencies or unusual hardware.
@@ -210,6 +251,9 @@
     'MBEDTLS_TEST_CONSTANT_FLOW_MEMSAN', # build dependency (clang+memsan)
     'MBEDTLS_TEST_CONSTANT_FLOW_VALGRIND', # build dependency (valgrind headers)
     'MBEDTLS_X509_REMOVE_INFO', # removes a feature
+    *PSA_UNSUPPORTED_FEATURE,
+    *PSA_DEPRECATED_FEATURE,
+    *PSA_UNSTABLE_FEATURE
 ])
 
 def is_seamless_alt(name):
@@ -316,6 +360,8 @@
             'MBEDTLS_PKCS7_C', # part of libmbedx509
     ]:
         return False
+    if name in EXCLUDE_FROM_CRYPTO:
+        return False
     return True
 
 def crypto_adapter(adapter):
@@ -334,6 +380,7 @@
 
 DEPRECATED = frozenset([
     'MBEDTLS_PSA_CRYPTO_SE_C',
+    *PSA_DEPRECATED_FEATURE
 ])
 def no_deprecated_adapter(adapter):
     """Modify an adapter to disable deprecated symbols.
@@ -368,43 +415,25 @@
         return adapter(name, active, section)
     return continuation
 
-class ConfigFile(Config):
-    """Representation of the Mbed TLS configuration read for a file.
+class ConfigFile(metaclass=ABCMeta):
+    """Representation of a configuration file."""
 
-    See the documentation of the `Config` class for methods to query
-    and modify the configuration.
-    """
-
-    _path_in_tree = 'include/mbedtls/mbedtls_config.h'
-    default_path = [_path_in_tree,
-                    os.path.join(os.path.dirname(__file__),
-                                 os.pardir,
-                                 _path_in_tree),
-                    os.path.join(os.path.dirname(os.path.abspath(os.path.dirname(__file__))),
-                                 _path_in_tree)]
-
-    def __init__(self, filename=None):
-        """Read the Mbed TLS configuration file."""
+    def __init__(self, default_path, name, filename=None):
+        """Check if the config file exists."""
         if filename is None:
-            for candidate in self.default_path:
+            for candidate in default_path:
                 if os.path.lexists(candidate):
                     filename = candidate
                     break
             else:
-                raise Exception('Mbed TLS configuration file not found',
-                                self.default_path)
-        super().__init__()
-        self.filename = filename
-        self.inclusion_guard = None
-        self.current_section = 'header'
-        with open(filename, 'r', encoding='utf-8') as file:
-            self.templates = [self._parse_line(line) for line in file]
-        self.current_section = None
+                raise FileNotFoundError(f'{name} configuration file not found: '
+                                        f'{filename if filename else default_path}')
 
-    def set(self, name, value=None):
-        if name not in self.settings:
-            self.templates.append((name, '', '#define ' + name + ' '))
-        super().set(name, value)
+        self.filename = filename
+        self.templates = []
+        self.current_section = None
+        self.inclusion_guard = None
+        self.modified = False
 
     _define_line_regexp = (r'(?P<indentation>\s*)' +
                            r'(?P<commented_out>(//\s*)?)' +
@@ -420,39 +449,57 @@
                                                 _ifndef_line_regexp,
                                                 _section_line_regexp]))
     def _parse_line(self, line):
-        """Parse a line in mbedtls_config.h and return the corresponding template."""
+        """Parse a line in the config file, save the templates representing the lines
+           and return the corresponding setting element.
+        """
+
         line = line.rstrip('\r\n')
         m = re.match(self._config_line_regexp, line)
         if m is None:
-            return line
+            self.templates.append(line)
+            return None
         elif m.group('section'):
             self.current_section = m.group('section')
-            return line
+            self.templates.append(line)
+            return None
         elif m.group('inclusion_guard') and self.inclusion_guard is None:
             self.inclusion_guard = m.group('inclusion_guard')
-            return line
+            self.templates.append(line)
+            return None
         else:
             active = not m.group('commented_out')
             name = m.group('name')
             value = m.group('value')
             if name == self.inclusion_guard and value == '':
                 # The file double-inclusion guard is not an option.
-                return line
+                self.templates.append(line)
+                return None
             template = (name,
                         m.group('indentation'),
                         m.group('define') + name +
                         m.group('arguments') + m.group('separator'))
-            self.settings[name] = Setting(active, name, value,
-                                          self.current_section)
-            return template
+            self.templates.append(template)
 
-    def _format_template(self, name, indent, middle):
-        """Build a line for mbedtls_config.h for the given setting.
+            return (active, name, value, self.current_section)
+
+    def parse_file(self):
+        """Parse the whole file and return the settings."""
+
+        with open(self.filename, 'r', encoding='utf-8') as file:
+            for line in file:
+                setting = self._parse_line(line)
+                if setting is not None:
+                    yield setting
+        self.current_section = None
+
+    #pylint: disable=no-self-use
+    def _format_template(self, setting, indent, middle):
+        """Build a line for the config file for the given setting.
 
         The line has the form "<indent>#define <name> <value>"
         where <middle> is "#define <name> ".
         """
-        setting = self.settings[name]
+
         value = setting.value
         if value is None:
             value = ''
@@ -470,26 +517,230 @@
                         middle,
                         value]).rstrip()
 
-    def write_to_stream(self, output):
+    def write_to_stream(self, settings, output):
         """Write the whole configuration to output."""
+
         for template in self.templates:
             if isinstance(template, str):
                 line = template
             else:
-                line = self._format_template(*template)
+                name, indent, middle = template
+                line = self._format_template(settings[name], indent, middle)
             output.write(line + '\n')
 
+    def write(self, settings, filename=None):
+        """Write the whole configuration to the file it was read from.
+
+        If filename is specified, write to this file instead.
+        """
+
+        if filename is None:
+            filename = self.filename
+
+        # Not modified so no need to write to the file
+        if not self.modified and filename == self.filename:
+            return
+
+        with open(filename, 'w', encoding='utf-8') as output:
+            self.write_to_stream(settings, output)
+
+class MbedTLSConfigFile(ConfigFile):
+    """Representation of an Mbed TLS configuration file."""
+
+    _path_in_tree = 'include/mbedtls/mbedtls_config.h'
+    default_path = [_path_in_tree,
+                    os.path.join(os.path.dirname(__file__),
+                                 os.pardir,
+                                 _path_in_tree),
+                    os.path.join(os.path.dirname(os.path.abspath(os.path.dirname(__file__))),
+                                 _path_in_tree)]
+
+    def __init__(self, filename=None):
+        super().__init__(self.default_path, 'Mbed TLS', filename)
+        self.current_section = 'header'
+
+class CryptoConfigFile(ConfigFile):
+    """Representation of a Crypto configuration file."""
+
+    # Temporary, while Mbed TLS does not just rely on the TF-PSA-Crypto
+    # build system to build its crypto library. When it does, the
+    # condition can just be removed.
+    _path_in_tree = 'include/psa/crypto_config.h' \
+                    if os.path.isfile('include/psa/crypto_config.h') else \
+                    'tf-psa-crypto/include/psa/crypto_config.h'
+    default_path = [_path_in_tree,
+                    os.path.join(os.path.dirname(__file__),
+                                 os.pardir,
+                                 _path_in_tree),
+                    os.path.join(os.path.dirname(os.path.abspath(os.path.dirname(__file__))),
+                                 _path_in_tree)]
+
+    def __init__(self, filename=None):
+        super().__init__(self.default_path, 'Crypto', filename)
+
+class MbedTLSConfig(Config):
+    """Representation of the Mbed TLS configuration.
+
+    See the documentation of the `Config` class for methods to query
+    and modify the configuration.
+    """
+
+    def __init__(self, filename=None):
+        """Read the Mbed TLS configuration file."""
+
+        super().__init__()
+        self.configfile = MbedTLSConfigFile(filename)
+        self.settings.update({name: Setting(active, name, value, section, self.configfile)
+                              for (active, name, value, section)
+                              in self.configfile.parse_file()})
+
+    def set(self, name, value=None):
+        """Set name to the given value and make it active."""
+
+        if name not in self.settings:
+            self.configfile.templates.append((name, '', '#define ' + name + ' '))
+
+        super().set(name, value)
+
     def write(self, filename=None):
         """Write the whole configuration to the file it was read from.
 
         If filename is specified, write to this file instead.
         """
-        if filename is None:
-            filename = self.filename
-        with open(filename, 'w', encoding='utf-8') as output:
-            self.write_to_stream(output)
+
+        self.configfile.write(self.settings, filename)
+
+    def filename(self):
+        """Get the name of the config file."""
+
+        return self.configfile.filename
+
+class CryptoConfig(Config):
+    """Representation of the PSA crypto configuration.
+
+    See the documentation of the `Config` class for methods to query
+    and modify the configuration.
+    """
+
+    def __init__(self, filename=None):
+        """Read the PSA crypto configuration file."""
+
+        super().__init__()
+        self.configfile = CryptoConfigFile(filename)
+        self.settings.update({name: Setting(active, name, value, section, self.configfile)
+                              for (active, name, value, section)
+                              in self.configfile.parse_file()})
+
+    def set(self, name, value='1'):
+        """Set name to the given value and make it active."""
+
+        if name in PSA_UNSUPPORTED_FEATURE:
+            raise ValueError(f'Feature is unsupported: \'{name}\'')
+        if name in PSA_UNSTABLE_FEATURE:
+            raise ValueError(f'Feature is unstable: \'{name}\'')
+
+        if name not in self.settings:
+            self.configfile.templates.append((name, '', '#define ' + name + ' '))
+
+        super().set(name, value)
+
+    def write(self, filename=None):
+        """Write the whole configuration to the file it was read from.
+
+        If filename is specified, write to this file instead.
+        """
+
+        self.configfile.write(self.settings, filename)
+
+    def filename(self):
+        """Get the name of the config file."""
+
+        return self.configfile.filename
+
+class CombinedConfig(Config):
+    """Representation of the combined Mbed TLS and PSA crypto configuration.
+
+    See the documentation of the `Config` class for methods to query
+    and modify the configuration.
+    """
+
+    def __init__(self, *configs):
+        super().__init__()
+        for config in configs:
+            if isinstance(config, MbedTLSConfigFile):
+                self.mbedtls_configfile = config
+            elif isinstance(config, CryptoConfigFile):
+                self.crypto_configfile = config
+            else:
+                raise ValueError(f'Invalid configfile: {config}')
+
+        self.settings.update({name: Setting(active, name, value, section, configfile)
+                              for configfile in [self.mbedtls_configfile, self.crypto_configfile]
+                              for (active, name, value, section) in configfile.parse_file()})
+
+    _crypto_regexp = re.compile(r'^PSA_.*')  # crypto config symbols start with PSA_
+    def _get_configfile(self, name):
+        """Return the config file that defines `name`, or that a new `name` belongs in."""
+
+        if name in self.settings:
+            return self.settings[name].configfile
+        elif re.match(self._crypto_regexp, name):
+            return self.crypto_configfile
+        else:
+            return self.mbedtls_configfile
+
+    def __setitem__(self, name, value):
+        super().__setitem__(name, value)
+        self.settings[name].configfile.modified = True
+
+    def set(self, name, value=None):
+        """Set name to the given value and make it active."""
+
+        configfile = self._get_configfile(name)
+
+        if configfile == self.crypto_configfile:
+            if name in PSA_UNSUPPORTED_FEATURE:
+                raise ValueError(f'Feature is unsupported: \'{name}\'')
+            if name in PSA_UNSTABLE_FEATURE:
+                raise ValueError(f'Feature is unstable: \'{name}\'')
+
+            # The default value in the crypto config is '1'
+            if not value:
+                value = '1'
+
+        if name in self.settings:
+            setting = self.settings[name]
+            if not setting.active or (value is not None and setting.value != value):
+                configfile.modified = True
+        else:
+            configfile.templates.append((name, '', '#define ' + name + ' '))
+            configfile.modified = True
+
+        super().set(name, value)
+
+    def write(self, mbedtls_file=None, crypto_file=None):
+        """Write the whole configuration to the file it was read from.
+
+        If mbedtls_file or crypto_file is specified, write the specific configuration
+        to the corresponding file instead.
+        """
+
+        self.mbedtls_configfile.write(self.settings, mbedtls_file)
+        self.crypto_configfile.write(self.settings, crypto_file)
+
+    def filename(self, name=None):
+        """Get the names of the config files.
+
+        If 'name' is specified return the name of the config file where it is defined.
+        """
+
+        if not name:
+            return [config.filename for config in [self.mbedtls_configfile, self.crypto_configfile]]
+
+        return self._get_configfile(name).filename
 
 if __name__ == '__main__':
+    #pylint: disable=too-many-statements
     def main():
         """Command line mbedtls_config.h manipulation tool."""
         parser = argparse.ArgumentParser(description="""
@@ -498,7 +749,11 @@
         parser.add_argument('--file', '-f',
                             help="""File to read (and modify if requested).
                             Default: {}.
-                            """.format(ConfigFile.default_path))
+                            """.format(MbedTLSConfigFile.default_path))
+        parser.add_argument('--cryptofile', '-c',
+                            help="""Crypto file to read (and modify if requested).
+                            Default: {}.
+                            """.format(CryptoConfigFile.default_path))
         parser.add_argument('--force', '-o',
                             action='store_true',
                             help="""For the set command, if SYMBOL is not
@@ -576,7 +831,7 @@
                     excluding X.509 and TLS.""")
 
         args = parser.parse_args()
-        config = ConfigFile(args.file)
+        config = CombinedConfig(MbedTLSConfigFile(args.file), CryptoConfigFile(args.cryptofile))
         if args.command is None:
             parser.print_help()
             return 1
@@ -590,7 +845,7 @@
             if not args.force and args.symbol not in config.settings:
                 sys.stderr.write("A #define for the symbol {} "
                                  "was not found in {}\n"
-                                 .format(args.symbol, config.filename))
+                                 .format(args.symbol, config.filename(args.symbol)))
                 return 1
             config.set(args.symbol, value=args.value)
         elif args.command == 'set-all':
diff --git a/tests/scripts/depends.py b/tests/scripts/depends.py
index fa17e13..5098099 100755
--- a/tests/scripts/depends.py
+++ b/tests/scripts/depends.py
@@ -541,7 +541,7 @@
                             default=True)
         options = parser.parse_args()
         os.chdir(options.directory)
-        conf = config.ConfigFile(options.config)
+        conf = config.MbedTLSConfig(options.config)
         domain_data = DomainData(options, conf)
 
         if options.tasks is True: