Separate general and token specific code

Change-Id: I36f19eb2583398c2badcf7b078684412ed99922b
Signed-off-by: Mate Toth-Pal <mate.toth-pal@arm.com>
diff --git a/iat-verifier/dev_scripts/generate-sample-iat.py b/iat-verifier/dev_scripts/generate-sample-iat.py
index cc0fdc8..7350575 100755
--- a/iat-verifier/dev_scripts/generate-sample-iat.py
+++ b/iat-verifier/dev_scripts/generate-sample-iat.py
@@ -15,11 +15,11 @@
 
 from iatverifier.util import sign_eat
 
-from iatverifier.verifiers import InstanceIdClaim, ImplementationIdClaim, ChallengeClaim
-from iatverifier.verifiers import ClientIdClaim, SecurityLifecycleClaim, ProfileIdClaim
-from iatverifier.verifiers import BootSeedClaim, SWComponentsClaim, SWComponentTypeClaim
-from iatverifier.verifiers import SignerIdClaim, SwComponentVersionClaim
-from iatverifier.verifiers import MeasurementValueClaim, MeasurementDescriptionClaim
+from iatverifier.psa_iot_profile1_token_claims import InstanceIdClaim, ImplementationIdClaim, ChallengeClaim
+from iatverifier.psa_iot_profile1_token_claims import ClientIdClaim, SecurityLifecycleClaim, ProfileIdClaim
+from iatverifier.psa_iot_profile1_token_claims import BootSeedClaim, SWComponentsClaim, SWComponentTypeClaim
+from iatverifier.psa_iot_profile1_token_claims import SignerIdClaim, SwComponentVersionClaim
+from iatverifier.psa_iot_profile1_token_claims import MeasurementValueClaim, MeasurementDescriptionClaim
 from iatverifier.psa_iot_profile1_token_verifier import PSAIoTProfile1TokenVerifier
 
 # First byte indicates "GUID"
diff --git a/iat-verifier/iatverifier/attest_token_verifier.py b/iat-verifier/iatverifier/attest_token_verifier.py
new file mode 100644
index 0000000..e89dc6b
--- /dev/null
+++ b/iat-verifier/iatverifier/attest_token_verifier.py
@@ -0,0 +1,206 @@
+# -----------------------------------------------------------------------------
+# Copyright (c) 2019-2022, Arm Limited. All rights reserved.
+#
+# SPDX-License-Identifier: BSD-3-Clause
+#
+# -----------------------------------------------------------------------------
+
+import logging
+
+import cbor2
+
+logger = logging.getLogger('iat-verifiers')
+
+class AttestationClaim:
+    MANDATORY = 0
+    RECOMMENDED = 1
+    OPTIONAL = 2
+
+    def __init__(self, verifier, necessity=MANDATORY):
+        self.config = verifier.config
+        self.verifier = verifier
+        self.necessity = necessity
+        self.verify_count = 0
+
+    def verify(self, value):
+        raise NotImplementedError
+
+    def get_claim_key(self=None):
+        raise NotImplementedError
+
+    def get_claim_name(self=None):
+        raise NotImplementedError
+
+    def get_contained_claim_key_list(self):
+        return {}
+
+    def decode(self, value):
+        if self.is_utf_8():
+            try:
+                return value.decode()
+            except UnicodeDecodeError as e:
+                msg = 'Error decodeing value for "{}": {}'
+                self.verifier.error(msg.format(self.get_claim_name(), e))
+                return str(value)[2:-1]
+        else:  # not a UTF-8 value, i.e. a bytestring
+            return value
+
+    def add_tokens_to_dict(self, token, value):
+        entry_name = self.get_claim_name()
+        if isinstance(value, bytes):
+            value = self.decode(value)
+        token[entry_name] = value
+
+    def claim_found(self):
+        return self.verify_count>0
+
+    def _check_type(self, name, value, expected_type):
+        if not isinstance(value, expected_type):
+            msg = 'Invalid {}: must be a(n) {}: found {}'
+            self.verifier.error(msg.format(name, expected_type, type(value)))
+            return False
+        return True
+
+    def _validate_bytestring_length_equals(self, value, name, expected_len):
+        self._check_type(name, value, bytes)
+
+        value_len = len(value)
+        if value_len != expected_len:
+            msg = 'Invalid {} length: must be exactly {} bytes, found {} bytes'
+            self.verifier.error(msg.format(name, expected_len, value_len))
+
+    def _validate_bytestring_length_is_at_least(self, value, name, minimal_length):
+        self._check_type(name, value, bytes)
+
+        value_len = len(value)
+        if value_len < minimal_length:
+            msg = 'Invalid {} length: must be at least {} bytes, found {} bytes'
+            self.verifier.error(msg.format(name, minimal_length, value_len))
+
+    @staticmethod
+    def parse_raw(raw_value):
+        return raw_value
+
+    @staticmethod
+    def get_formatted_value(value):
+        return value
+
+    def is_utf_8(self):
+        return False
+
+    def check_cross_claim_requirements(self):
+        pass
+
+
+class NonVerifiedClaim(AttestationClaim):
+    def verify(self, value):
+        self.verify_count += 1
+
+    def get_claim_key(self=None):
+        raise NotImplementedError
+
+    def get_claim_name(self=None):
+        raise NotImplementedError
+
+
+class VerifierConfiguration:
+    def __init__(self, keep_going=False, strict=False):
+        self.keep_going=keep_going
+        self.strict=strict
+
+class AttestationTokenVerifier:
+
+    all_known_claims = {}
+
+    SIGN_METHOD_SIGN1 = "sign"
+    SIGN_METHOD_MAC0 = "mac"
+    SIGN_METHOD_RAW = "raw"
+
+    COSE_ALG_ES256="ES256"
+    COSE_ALG_ES384="ES384"
+    COSE_ALG_ES512="ES512"
+    COSE_ALG_HS256_64="HS256/64"
+    COSE_ALG_HS256="HS256"
+    COSE_ALG_HS384="HS384"
+    COSE_ALG_HS512="HS512"
+
+    def __init__(self, method, cose_alg, configuration=None):
+        self.method = method
+        self.cose_alg = cose_alg
+        self.config = configuration if configuration is not None else VerifierConfiguration()
+        self.claims = []
+
+        self.seen_errors = False
+
+    def add_claims(self, claims):
+        for claim in claims:
+            key = claim.get_claim_key()
+            if key not in AttestationTokenVerifier.all_known_claims:
+                AttestationTokenVerifier.all_known_claims[key] = claim.__class__
+
+            AttestationTokenVerifier.all_known_claims.update(claim.get_contained_claim_key_list())
+        self.claims.extend(claims)
+
+    def check_cross_claim_requirements(self):
+        pass
+
+    def decode_and_validate_iat(self, encoded_iat):
+        try:
+            raw_token = cbor2.loads(encoded_iat)
+        except Exception as e:
+            msg = 'Invalid CBOR: {}'
+            raise ValueError(msg.format(e))
+
+        claims = {v.get_claim_key(): v for v in self.claims}
+
+        token = {}
+        while not hasattr(raw_token, 'items'):
+            # TODO: token map is not a map. We are assuming that it is a tag
+            raw_token = raw_token.value
+        for entry in raw_token.keys():
+            value = raw_token[entry]
+
+            try:
+                claim = claims[entry]
+            except KeyError:
+                if self.config.strict:
+                    self.error('Invalid IAT claim: {}'.format(entry))
+                token[entry] = value
+                continue
+
+            claim.verify(value)
+            claim.add_tokens_to_dict(token, value)
+
+        # Check claims' necessity
+        for claim in claims.values():
+            if not claim.claim_found():
+                if claim.necessity==AttestationClaim.MANDATORY:
+                    msg = 'Invalid IAT: missing MANDATORY claim "{}"'
+                    self.error(msg.format(claim.get_claim_name()))
+                elif claim.necessity==AttestationClaim.RECOMMENDED:
+                    msg = 'Missing RECOMMENDED claim "{}"'
+                    self.warning(msg.format(claim.get_claim_name()))
+
+            claim.check_cross_claim_requirements()
+
+        self.check_cross_claim_requirements()
+
+        return token
+
+
+    def get_wrapping_tag(self=None):
+        """The value of the tag that the token is wrapped in.
+
+        The function should return None if the token is not wrapped.
+        """
+        return None
+
+    def error(self, message):
+        self.seen_errors = True
+        if self.config.keep_going:
+            logger.error(message)
+        else:
+            raise ValueError(message)
+
+    def warning(self, message):
+        logger.warning(message)
diff --git a/iat-verifier/iatverifier/verifiers.py b/iat-verifier/iatverifier/psa_iot_profile1_token_claims.py
similarity index 62%
rename from iat-verifier/iatverifier/verifiers.py
rename to iat-verifier/iatverifier/psa_iot_profile1_token_claims.py
index 904e760..b4c6a4f 100644
--- a/iat-verifier/iatverifier/verifiers.py
+++ b/iat-verifier/iatverifier/psa_iot_profile1_token_claims.py
@@ -5,12 +5,9 @@
 #
 # -----------------------------------------------------------------------------
 
-import logging
 import string
 
-import cbor2
-
-logger = logging.getLogger('iat-verifiers')
+from iatverifier.attest_token_verifier import AttestationClaim, NonVerifiedClaim
 
 # IAT custom claims
 ARM_RANGE = -75000
@@ -18,90 +15,6 @@
 # SW component IDs
 SW_COMPONENT_RANGE = 0
 
-class AttestationClaim:
-    MANDATORY = 0
-    RECOMMENDED = 1
-    OPTIONAL = 2
-
-    def __init__(self, verifier, necessity=MANDATORY):
-        self.config = verifier.config
-        self.verifier = verifier
-        self.necessity = necessity
-        self.verify_count = 0
-
-    def verify(self, value):
-        raise NotImplementedError
-
-    def get_claim_key(self=None):
-        raise NotImplementedError
-
-    def get_claim_name(self=None):
-        raise NotImplementedError
-
-    def get_contained_claim_key_list(self):
-        return {}
-
-    def decode(self, value):
-        if self.is_utf_8():
-            try:
-                return value.decode()
-            except UnicodeDecodeError as e:
-                msg = 'Error decodeing value for "{}": {}'
-                self.verifier.error(msg.format(self.get_claim_name(), e))
-                return str(value)[2:-1]
-        else:  # not a UTF-8 value, i.e. a bytestring
-            return value
-
-    def add_tokens_to_dict(self, token, value):
-        entry_name = self.get_claim_name()
-        if isinstance(value, bytes):
-            value = self.decode(value)
-        token[entry_name] = value
-
-    def claim_found(self):
-        return self.verify_count>0
-
-    def _check_type(self, name, value, expected_type):
-        if not isinstance(value, expected_type):
-            msg = 'Invalid {}: must be a(n) {}: found {}'
-            self.verifier.error(msg.format(name, expected_type, type(value)))
-            return False
-        return True
-
-    def _validate_bytestring_length_equals(self, value, name, expected_len):
-        self._check_type(name, value, bytes)
-
-        value_len = len(value)
-        if value_len != expected_len:
-            msg = 'Invalid {} length: must be exactly {} bytes, found {} bytes'
-            self.verifier.error(msg.format(name, expected_len, value_len))
-
-    def _validate_bytestring_length_is_at_least(self, value, name, minimal_length):
-        self._check_type(name, value, bytes)
-
-        value_len = len(value)
-        if value_len < minimal_length:
-            msg = 'Invalid {} length: must be at least {} bytes, found {} bytes'
-            self.verifier.error(msg.format(name, minimal_length, value_len))
-
-    @staticmethod
-    def parse_raw(raw_value):
-        return raw_value
-
-    @staticmethod
-    def get_formatted_value(value):
-        return value
-
-    def is_utf_8(self):
-        return False
-
-    def check_cross_claim_requirements(self):
-        pass
-
-
-# ----------------------------------------------------------------------------
-# Validation classes
-#
 class InstanceIdClaim(AttestationClaim):
     def __init__(self, verifier, expected_len, necessity=AttestationClaim.MANDATORY):
         super().__init__(verifier, necessity)
@@ -141,17 +54,6 @@
         self.verify_count += 1
 
 
-class NonVerifiedClaim(AttestationClaim):
-    def verify(self, value):
-        self.verify_count += 1
-
-    def get_claim_key(self=None):
-        raise NotImplementedError
-
-    def get_claim_name(self=None):
-        raise NotImplementedError
-
-
 class ImplementationIdClaim(NonVerifiedClaim):
     def get_claim_key(self=None):
         return ARM_RANGE - 3
@@ -433,108 +335,3 @@
 
     def is_utf_8(self):
         return True
-
-
-# ----------------------------------------------------------------------------
-
-class VerifierConfiguration:
-    def __init__(self, keep_going=False, strict=False):
-        self.keep_going=keep_going
-        self.strict=strict
-
-class AttestationTokenVerifier:
-
-    all_known_claims = {}
-
-    SIGN_METHOD_SIGN1 = "sign"
-    SIGN_METHOD_MAC0 = "mac"
-    SIGN_METHOD_RAW = "raw"
-
-    COSE_ALG_ES256="ES256"
-    COSE_ALG_ES384="ES384"
-    COSE_ALG_ES512="ES512"
-    COSE_ALG_HS256_64="HS256/64"
-    COSE_ALG_HS256="HS256"
-    COSE_ALG_HS384="HS384"
-    COSE_ALG_HS512="HS512"
-
-    def __init__(self, method, cose_alg, configuration=None):
-        self.method = method
-        self.cose_alg = cose_alg
-        self.config = configuration if configuration is not None else VerifierConfiguration()
-        self.claims = []
-
-        self.seen_errors = False
-
-    def add_claims(self, claims):
-        for claim in claims:
-            key = claim.get_claim_key()
-            if key not in AttestationTokenVerifier.all_known_claims:
-                AttestationTokenVerifier.all_known_claims[key] = claim.__class__
-
-            AttestationTokenVerifier.all_known_claims.update(claim.get_contained_claim_key_list())
-        self.claims.extend(claims)
-
-    def check_cross_claim_requirements(self):
-        pass
-
-    def decode_and_validate_iat(self, encoded_iat):
-        try:
-            raw_token = cbor2.loads(encoded_iat)
-        except Exception as e:
-            msg = 'Invalid CBOR: {}'
-            raise ValueError(msg.format(e))
-
-        claims = {v.get_claim_key(): v for v in self.claims}
-
-        token = {}
-        while not hasattr(raw_token, 'items'):
-            # TODO: token map is not a map. We are assuming that it is a tag
-            raw_token = raw_token.value
-        for entry in raw_token.keys():
-            value = raw_token[entry]
-
-            try:
-                claim = claims[entry]
-            except KeyError:
-                if self.config.strict:
-                    self.error('Invalid IAT claim: {}'.format(entry))
-                token[entry] = value
-                continue
-
-            claim.verify(value)
-            claim.add_tokens_to_dict(token, value)
-
-        # Check claims' necessity
-        for claim in claims.values():
-            if not claim.claim_found():
-                if claim.necessity==AttestationClaim.MANDATORY:
-                    msg = 'Invalid IAT: missing MANDATORY claim "{}"'
-                    self.error(msg.format(claim.get_claim_name()))
-                elif claim.necessity==AttestationClaim.RECOMMENDED:
-                    msg = 'Missing RECOMMENDED claim "{}"'
-                    self.warning(msg.format(claim.get_claim_name()))
-
-            claim.check_cross_claim_requirements()
-
-        self.check_cross_claim_requirements()
-
-        return token
-
-
-    def get_wrapping_tag(self=None):
-        """The value of the tag that the token is wrapped in.
-
-        The function should return None if the token is not wrapped.
-        """
-        return None
-
-    def error(self, message):
-        self.seen_errors = True
-        if self.config.keep_going:
-            logger.error(message)
-        else:
-            raise ValueError(message)
-
-    def warning(self, message):
-        logger.warning(message)
diff --git a/iat-verifier/iatverifier/psa_iot_profile1_token_verifier.py b/iat-verifier/iatverifier/psa_iot_profile1_token_verifier.py
index 1d0e773..e9a7366 100644
--- a/iat-verifier/iatverifier/psa_iot_profile1_token_verifier.py
+++ b/iat-verifier/iatverifier/psa_iot_profile1_token_verifier.py
@@ -5,14 +5,14 @@
 #
 # -----------------------------------------------------------------------------
 
-from iatverifier.verifiers import AttestationTokenVerifier as Verifier
-from iatverifier.verifiers import AttestationClaim as Claim
-from iatverifier.verifiers import ProfileIdClaim, ClientIdClaim, SecurityLifecycleClaim
-from iatverifier.verifiers import ImplementationIdClaim, BootSeedClaim, HardwareVersionClaim
-from iatverifier.verifiers import NoMeasurementsClaim, ChallengeClaim
-from iatverifier.verifiers import InstanceIdClaim, VerificationServiceClaim, SWComponentsClaim
-from iatverifier.verifiers import SWComponentTypeClaim, SwComponentVersionClaim
-from iatverifier.verifiers import MeasurementValueClaim, MeasurementDescriptionClaim, SignerIdClaim
+from iatverifier.attest_token_verifier import AttestationTokenVerifier as Verifier
+from iatverifier.attest_token_verifier import AttestationClaim as Claim
+from iatverifier.psa_iot_profile1_token_claims import ProfileIdClaim, ClientIdClaim, SecurityLifecycleClaim
+from iatverifier.psa_iot_profile1_token_claims import ImplementationIdClaim, BootSeedClaim, HardwareVersionClaim
+from iatverifier.psa_iot_profile1_token_claims import NoMeasurementsClaim, ChallengeClaim
+from iatverifier.psa_iot_profile1_token_claims import InstanceIdClaim, VerificationServiceClaim, SWComponentsClaim
+from iatverifier.psa_iot_profile1_token_claims import SWComponentTypeClaim, SwComponentVersionClaim
+from iatverifier.psa_iot_profile1_token_claims import MeasurementValueClaim, MeasurementDescriptionClaim, SignerIdClaim
 
 class PSAIoTProfile1TokenVerifier(Verifier):
     @staticmethod
diff --git a/iat-verifier/iatverifier/util.py b/iat-verifier/iatverifier/util.py
index 9418721..39af2e3 100644
--- a/iat-verifier/iatverifier/util.py
+++ b/iat-verifier/iatverifier/util.py
@@ -16,7 +16,7 @@
 from pycose.attributes import CoseAttrs
 from pycose.sign1message import Sign1Message
 from pycose.mac0message import Mac0Message
-from iatverifier.verifiers import AttestationTokenVerifier
+from iatverifier.attest_token_verifier import AttestationTokenVerifier
 from cbor2 import CBORTag
 
 _logger = logging.getLogger("util")
diff --git a/iat-verifier/scripts/check_iat b/iat-verifier/scripts/check_iat
index b91702a..c391393 100755
--- a/iat-verifier/scripts/check_iat
+++ b/iat-verifier/scripts/check_iat
@@ -12,7 +12,7 @@
 
 from iatverifier.util import extract_iat_from_cose, recursive_bytes_to_strings
 from iatverifier.psa_iot_profile1_token_verifier import PSAIoTProfile1TokenVerifier
-from iatverifier.verifiers import VerifierConfiguration, AttestationTokenVerifier
+from iatverifier.attest_token_verifier import VerifierConfiguration, AttestationTokenVerifier
 
 logger = logging.getLogger('iat-verify')
 
diff --git a/iat-verifier/scripts/compile_token b/iat-verifier/scripts/compile_token
index 1272ad4..8fac1fc 100755
--- a/iat-verifier/scripts/compile_token
+++ b/iat-verifier/scripts/compile_token
@@ -14,7 +14,7 @@
 from ecdsa import SigningKey
 from iatverifier.util import read_token_map, convert_map_to_token
 from iatverifier.psa_iot_profile1_token_verifier import PSAIoTProfile1TokenVerifier
-from iatverifier.verifiers import AttestationTokenVerifier
+from iatverifier.attest_token_verifier import AttestationTokenVerifier
 
 
 if __name__ == '__main__':
diff --git a/iat-verifier/tests/test_verifier.py b/iat-verifier/tests/test_verifier.py
index 7824ea6..a7e4c92 100644
--- a/iat-verifier/tests/test_verifier.py
+++ b/iat-verifier/tests/test_verifier.py
@@ -11,7 +11,7 @@
 
 from iatverifier.psa_iot_profile1_token_verifier import PSAIoTProfile1TokenVerifier
 from iatverifier.util import convert_map_to_token_files, extract_iat_from_cose
-from iatverifier.verifiers import VerifierConfiguration
+from iatverifier.attest_token_verifier import VerifierConfiguration
 
 
 THIS_DIR = os.path.dirname(__file__)