Separate code parsing and name checking into two classes

Signed-off-by: Yuto Takano <yuto.takano@arm.com>
diff --git a/tests/scripts/check_names.py b/tests/scripts/check_names.py
index 9a7f391..9577014 100755
--- a/tests/scripts/check_names.py
+++ b/tests/scripts/check_names.py
@@ -20,16 +20,26 @@
 are consistent with the house style and are also self-consistent. It only runs
 on Linux and macOS since it depends on nm.
 
-The script performs the following checks:
+It contains two major Python classes, CodeParser and NameChecker. They both have
+a comprehensive "run-all" function (comprehensive_parse() and perform_checks())
+but the individual functions can also be used for specific needs.
+
+CodeParser makes heavy use of regular expressions to parse the code, and is
+dependent on the current code formatting. Many Python C parser libraries require
+preprocessed C code, which means no macro parsing. Compiler tools are also not
+very helpful when we want the exact location in the original source (which
+becomes impossible when e.g. comments are stripped).
+
+NameChecker performs the following checks:
 
 - All exported and available symbols in the library object files, are explicitly
   declared in the header files. This uses the nm command.
 - All macros, constants, and identifiers (function names, struct names, etc)
-  follow the required pattern.
+  follow the required regex pattern.
 - Typo checking: All words that begin with MBED exist as macros or constants.
 
-Returns 0 on success, 1 on test failure, and 2 if there is a script error or a
-subprocess error. Must be run from Mbed TLS root.
+The script returns 0 on success, 1 on test failure, and 2 if there is a script
+error. It must be run from Mbed TLS root.
 """
 
 import argparse
@@ -168,16 +178,15 @@
             .format(self.match.filename, self.match.pos[0], self.match.name)
         ) + "\n" + str(self.match)
 
-class NameCheck():
+class CodeParser():
     """
-    Representation of the core name checking operation performed by this script.
-    Shares a common logger, and a shared return code.
+    Class for retrieving files and parsing the code. This can be used
+    independently of the checks that NameChecker performs, for example for
+    list_internal_identifiers.py.
     """
-    def __init__(self, verbose=False):
-        self.log = None
+    def __init__(self, log):
+        self.log = log
         self.check_repo_path()
-        self.return_code = 0
-        self.setup_logger(verbose)
 
         # Memo for storing "glob expression": set(filepaths)
         self.files = {}
@@ -185,9 +194,6 @@
         # Globally excluded filenames
         self.excluded_files = ["**/bn_mul", "**/compat-2.x.h"]
 
-        # Will contain the parse result after a comprehensive parse
-        self.parse_result = {}
-
     @staticmethod
     def check_repo_path():
         """
@@ -197,71 +203,12 @@
         if not all(os.path.isdir(d) for d in ["include", "library", "tests"]):
             raise Exception("This script must be run from Mbed TLS root")
 
-    def set_return_code(self, return_code):
-        if return_code > self.return_code:
-            self.log.debug("Setting new return code to {}".format(return_code))
-            self.return_code = return_code
-
-    def setup_logger(self, verbose=False):
+    def comprehensive_parse(self):
         """
-        Set up a logger and set the change the default logging level from
-        WARNING to INFO. Loggers are better than print statements since their
-        verbosity can be controlled.
-        """
-        self.log = logging.getLogger()
-        if verbose:
-            self.log.setLevel(logging.DEBUG)
-        else:
-            self.log.setLevel(logging.INFO)
-        self.log.addHandler(logging.StreamHandler())
+        Comprehensive ("default") function to call each parsing function and
+        retrieve various elements of the code, together with the source location.
 
-    def get_files(self, include_wildcards, exclude_wildcards):
-        """
-        Get all files that match any of the UNIX-style wildcards. While the
-        check_names script is designed only for use on UNIX/macOS (due to nm),
-        this function alone would work fine on Windows even with forward slashes
-        in the wildcard.
-
-        Args:
-        * include_wildcards: a List of shell-style wildcards to match filepaths.
-        * exclude_wildacrds: a List of shell-style wildcards to exclude.
-
-        Returns a List of relative filepaths.
-        """
-        accumulator = set()
-
-        # exclude_wildcards may be None. Also, consider the global exclusions.
-        exclude_wildcards = (exclude_wildcards or []) + self.excluded_files
-
-        # Perform set union on the glob results. Memoise individual sets.
-        for include_wildcard in include_wildcards:
-            if include_wildcard not in self.files:
-                self.files[include_wildcard] = set(glob.glob(
-                    include_wildcard,
-                    recursive=True
-                ))
-
-            accumulator = accumulator.union(self.files[include_wildcard])
-
-        # Perform set difference to exclude. Also use the same memo since their
-        # behaviour is pretty much identical and it can benefit from the cache.
-        for exclude_wildcard in exclude_wildcards:
-            if exclude_wildcard not in self.files:
-                self.files[exclude_wildcard] = set(glob.glob(
-                    exclude_wildcard,
-                    recursive=True
-                ))
-
-            accumulator = accumulator.difference(self.files[exclude_wildcard])
-
-        return list(accumulator)
-
-    def parse_names_in_source(self):
-        """
-        Comprehensive function to call each parsing function and retrieve
-        various elements of the code, together with their source location.
-        Puts the parsed values in the internal variable self.parse_result, so
-        they can be used from perform_checks().
+        Returns a dict of parsed item key to the corresponding List of Matches.
         """
         self.log.info("Parsing source code...")
         self.log.debug(
@@ -315,8 +262,7 @@
         self.log.debug("  {} Enum Constants".format(len(enum_consts)))
         self.log.debug("  {} Identifiers".format(len(identifiers)))
         self.log.debug("  {} Exported Symbols".format(len(symbols)))
-        self.log.info("Analysing...")
-        self.parse_result = {
+        return {
             "macros": actual_macros,
             "enum_consts": enum_consts,
             "identifiers": identifiers,
@@ -324,6 +270,47 @@
             "mbed_words": mbed_words
         }
 
+    def get_files(self, include_wildcards, exclude_wildcards):
+        """
+        Get all files that match any of the UNIX-style wildcards. While the
+        check_names script is designed only for use on UNIX/macOS (due to nm),
+        this function alone would work fine on Windows even with forward slashes
+        in the wildcard.
+
+        Args:
+        * include_wildcards: a List of shell-style wildcards to match filepaths.
+        * exclude_wildcards: a List of shell-style wildcards to exclude.
+
+        Returns a List of relative filepaths.
+        """
+        accumulator = set()
+
+        # exclude_wildcards may be None. Also, consider the global exclusions.
+        exclude_wildcards = (exclude_wildcards or []) + self.excluded_files
+
+        # Perform set union on the glob results. Memoise individual sets.
+        for include_wildcard in include_wildcards:
+            if include_wildcard not in self.files:
+                self.files[include_wildcard] = set(glob.glob(
+                    include_wildcard,
+                    recursive=True
+                ))
+
+            accumulator = accumulator.union(self.files[include_wildcard])
+
+        # Perform set difference to exclude. Also use the same memo since their
+        # behaviour is pretty much identical and it can benefit from the cache.
+        for exclude_wildcard in exclude_wildcards:
+            if exclude_wildcard not in self.files:
+                self.files[exclude_wildcard] = set(glob.glob(
+                    exclude_wildcard,
+                    recursive=True
+                ))
+
+            accumulator = accumulator.difference(self.files[exclude_wildcard])
+
+        return list(accumulator)
+
     def parse_macros(self, include, exclude=None):
         """
         Parse all macros defined by #define preprocessor directives.
@@ -456,11 +443,11 @@
             # Match " something(a" or " *something(a". Functions.
             # Assumptions:
             # - function definition from return type to one of its arguments is
-            #   all on one line (enforced by the previous_line concat below)
+            #   all on one line
             # - function definition line only contains alphanumeric, asterisk,
             #   underscore, and open bracket
             r".* \**(\w+) *\( *\w|"
-            # Match "(*something)(". Flexible with spaces.
+            # Match "(*something)(".
             r".*\( *\* *(\w+) *\) *\(|"
             # Match names of named data structures.
             r"(?:typedef +)?(?:struct|union|enum) +(\w+)(?: *{)?$|"
@@ -485,7 +472,7 @@
         for header_file in files:
             with open(header_file, "r", encoding="utf-8") as header:
                 in_block_comment = False
-                # The previous line varibale is used for concatenating lines
+                # The previous line variable is used for concatenating lines
                 # when identifiers are formatted and spread across multiple.
                 previous_line = ""
 
@@ -596,7 +583,6 @@
             )
         except subprocess.CalledProcessError as error:
             self.log.debug(error.output)
-            self.set_return_code(2)
             raise error
         finally:
             # Put back the original config regardless of there being errors.
@@ -614,7 +600,7 @@
         Does not return the position data since it is of no use.
 
         Args:
-        * object_files: a List of compiled object files to search through.
+        * object_files: a List of compiled object filepaths to search through.
 
         Returns a List of unique symbols defined and used in any of the object
         files.
@@ -646,18 +632,24 @@
 
         return symbols
 
+class NameChecker():
+    """
+    Representation of the core name checking operation performed by this script.
+    """
+    def __init__(self, parse_result, log):
+        self.parse_result = parse_result
+        self.log = log
+
     def perform_checks(self, quiet=False):
         """
-        Perform each check in order, output its PASS/FAIL status. Maintain an
-        overall test status, and output that at the end.
-        Assumes parse_names_in_source() was called before this.
+        A comprehensive checker that performs each check in order, and outputs
+        a final verdict. Returns 0 if every check passes, and 1 otherwise.
 
         Args:
         * quiet: whether to hide detailed problem explanation.
         """
         self.log.info("=============")
         problems = 0
-
         problems += self.check_symbols_declared_in_header(quiet)
 
         pattern_checks = [
@@ -677,8 +669,10 @@
                 self.log.info("Remove --quiet to see explanations.")
             else:
                 self.log.info("Use --quiet for minimal output.")
+            return 1
         else:
             self.log.info("PASS")
+            return 0
 
     def check_symbols_declared_in_header(self, quiet):
         """
@@ -782,7 +776,6 @@
         * problems: a List of encountered Problems
         """
         if problems:
-            self.set_return_code(1)
             self.log.info("{}: FAIL\n".format(name))
             for problem in problems:
                 problem.quiet = quiet
@@ -792,8 +785,8 @@
 
 def main():
     """
-    Perform argument parsing, and create an instance of NameCheck to begin the
-    core operation.
+    Perform argument parsing, and create an instance of CodeParser and
+    NameChecker to begin the core operation.
     """
     parser = argparse.ArgumentParser(
         formatter_class=argparse.RawDescriptionHelpFormatter,
@@ -816,14 +809,22 @@
 
     args = parser.parse_args()
 
+    # Configure the global logger, which is then passed to the classes below
+    log = logging.getLogger()
+    log.setLevel(logging.DEBUG if args.verbose else logging.INFO)
+    log.addHandler(logging.StreamHandler())
+
     try:
-        name_check = NameCheck(verbose=args.verbose)
-        name_check.parse_names_in_source()
-        name_check.perform_checks(quiet=args.quiet)
-        sys.exit(name_check.return_code)
+        code_parser = CodeParser(log)
+        parse_result = code_parser.comprehensive_parse()
     except Exception: # pylint: disable=broad-except
         traceback.print_exc()
         sys.exit(2)
 
+    name_checker = NameChecker(parse_result, log)
+    return_code = name_checker.perform_checks(quiet=args.quiet)
+
+    sys.exit(return_code)
+
 if __name__ == "__main__":
     main()