Separate general and token specific code
Change-Id: I36f19eb2583398c2badcf7b078684412ed99922b
Signed-off-by: Mate Toth-Pal <mate.toth-pal@arm.com>
diff --git a/iat-verifier/dev_scripts/generate-sample-iat.py b/iat-verifier/dev_scripts/generate-sample-iat.py
index cc0fdc8..7350575 100755
--- a/iat-verifier/dev_scripts/generate-sample-iat.py
+++ b/iat-verifier/dev_scripts/generate-sample-iat.py
@@ -15,11 +15,11 @@
from iatverifier.util import sign_eat
-from iatverifier.verifiers import InstanceIdClaim, ImplementationIdClaim, ChallengeClaim
-from iatverifier.verifiers import ClientIdClaim, SecurityLifecycleClaim, ProfileIdClaim
-from iatverifier.verifiers import BootSeedClaim, SWComponentsClaim, SWComponentTypeClaim
-from iatverifier.verifiers import SignerIdClaim, SwComponentVersionClaim
-from iatverifier.verifiers import MeasurementValueClaim, MeasurementDescriptionClaim
+from iatverifier.psa_iot_profile1_token_claims import InstanceIdClaim, ImplementationIdClaim, ChallengeClaim
+from iatverifier.psa_iot_profile1_token_claims import ClientIdClaim, SecurityLifecycleClaim, ProfileIdClaim
+from iatverifier.psa_iot_profile1_token_claims import BootSeedClaim, SWComponentsClaim, SWComponentTypeClaim
+from iatverifier.psa_iot_profile1_token_claims import SignerIdClaim, SwComponentVersionClaim
+from iatverifier.psa_iot_profile1_token_claims import MeasurementValueClaim, MeasurementDescriptionClaim
from iatverifier.psa_iot_profile1_token_verifier import PSAIoTProfile1TokenVerifier
# First byte indicates "GUID"
diff --git a/iat-verifier/iatverifier/attest_token_verifier.py b/iat-verifier/iatverifier/attest_token_verifier.py
new file mode 100644
index 0000000..e89dc6b
--- /dev/null
+++ b/iat-verifier/iatverifier/attest_token_verifier.py
@@ -0,0 +1,206 @@
+# -----------------------------------------------------------------------------
+# Copyright (c) 2019-2022, Arm Limited. All rights reserved.
+#
+# SPDX-License-Identifier: BSD-3-Clause
+#
+# -----------------------------------------------------------------------------
+
+import logging
+
+import cbor2
+
+logger = logging.getLogger('iat-verifiers')
+
+class AttestationClaim:
+ MANDATORY = 0
+ RECOMMENDED = 1
+ OPTIONAL = 2
+
+ def __init__(self, verifier, necessity=MANDATORY):
+ self.config = verifier.config
+ self.verifier = verifier
+ self.necessity = necessity
+ self.verify_count = 0
+
+ def verify(self, value):
+ raise NotImplementedError
+
+ def get_claim_key(self=None):
+ raise NotImplementedError
+
+ def get_claim_name(self=None):
+ raise NotImplementedError
+
+ def get_contained_claim_key_list(self):
+ return {}
+
+ def decode(self, value):
+ if self.is_utf_8():
+ try:
+ return value.decode()
+ except UnicodeDecodeError as e:
+ msg = 'Error decodeing value for "{}": {}'
+ self.verifier.error(msg.format(self.get_claim_name(), e))
+ return str(value)[2:-1]
+ else: # not a UTF-8 value, i.e. a bytestring
+ return value
+
+ def add_tokens_to_dict(self, token, value):
+ entry_name = self.get_claim_name()
+ if isinstance(value, bytes):
+ value = self.decode(value)
+ token[entry_name] = value
+
+ def claim_found(self):
+ return self.verify_count>0
+
+ def _check_type(self, name, value, expected_type):
+ if not isinstance(value, expected_type):
+ msg = 'Invalid {}: must be a(n) {}: found {}'
+ self.verifier.error(msg.format(name, expected_type, type(value)))
+ return False
+ return True
+
+ def _validate_bytestring_length_equals(self, value, name, expected_len):
+ self._check_type(name, value, bytes)
+
+ value_len = len(value)
+ if value_len != expected_len:
+ msg = 'Invalid {} length: must be exactly {} bytes, found {} bytes'
+ self.verifier.error(msg.format(name, expected_len, value_len))
+
+ def _validate_bytestring_length_is_at_least(self, value, name, minimal_length):
+ self._check_type(name, value, bytes)
+
+ value_len = len(value)
+ if value_len < minimal_length:
+ msg = 'Invalid {} length: must be at least {} bytes, found {} bytes'
+ self.verifier.error(msg.format(name, minimal_length, value_len))
+
+ @staticmethod
+ def parse_raw(raw_value):
+ return raw_value
+
+ @staticmethod
+ def get_formatted_value(value):
+ return value
+
+ def is_utf_8(self):
+ return False
+
+ def check_cross_claim_requirements(self):
+ pass
+
+
+class NonVerifiedClaim(AttestationClaim):
+ def verify(self, value):
+ self.verify_count += 1
+
+ def get_claim_key(self=None):
+ raise NotImplementedError
+
+ def get_claim_name(self=None):
+ raise NotImplementedError
+
+
+class VerifierConfiguration:
+ def __init__(self, keep_going=False, strict=False):
+ self.keep_going=keep_going
+ self.strict=strict
+
+class AttestationTokenVerifier:
+
+ all_known_claims = {}
+
+ SIGN_METHOD_SIGN1 = "sign"
+ SIGN_METHOD_MAC0 = "mac"
+ SIGN_METHOD_RAW = "raw"
+
+ COSE_ALG_ES256="ES256"
+ COSE_ALG_ES384="ES384"
+ COSE_ALG_ES512="ES512"
+ COSE_ALG_HS256_64="HS256/64"
+ COSE_ALG_HS256="HS256"
+ COSE_ALG_HS384="HS384"
+ COSE_ALG_HS512="HS512"
+
+ def __init__(self, method, cose_alg, configuration=None):
+ self.method = method
+ self.cose_alg = cose_alg
+ self.config = configuration if configuration is not None else VerifierConfiguration()
+ self.claims = []
+
+ self.seen_errors = False
+
+ def add_claims(self, claims):
+ for claim in claims:
+ key = claim.get_claim_key()
+ if key not in AttestationTokenVerifier.all_known_claims:
+ AttestationTokenVerifier.all_known_claims[key] = claim.__class__
+
+ AttestationTokenVerifier.all_known_claims.update(claim.get_contained_claim_key_list())
+ self.claims.extend(claims)
+
+ def check_cross_claim_requirements(self):
+ pass
+
+ def decode_and_validate_iat(self, encoded_iat):
+ try:
+ raw_token = cbor2.loads(encoded_iat)
+ except Exception as e:
+ msg = 'Invalid CBOR: {}'
+ raise ValueError(msg.format(e))
+
+ claims = {v.get_claim_key(): v for v in self.claims}
+
+ token = {}
+ while not hasattr(raw_token, 'items'):
+ # TODO: token map is not a map. We are assuming that it is a tag
+ raw_token = raw_token.value
+ for entry in raw_token.keys():
+ value = raw_token[entry]
+
+ try:
+ claim = claims[entry]
+ except KeyError:
+ if self.config.strict:
+ self.error('Invalid IAT claim: {}'.format(entry))
+ token[entry] = value
+ continue
+
+ claim.verify(value)
+ claim.add_tokens_to_dict(token, value)
+
+ # Check claims' necessity
+ for claim in claims.values():
+ if not claim.claim_found():
+ if claim.necessity==AttestationClaim.MANDATORY:
+ msg = 'Invalid IAT: missing MANDATORY claim "{}"'
+ self.error(msg.format(claim.get_claim_name()))
+ elif claim.necessity==AttestationClaim.RECOMMENDED:
+ msg = 'Missing RECOMMENDED claim "{}"'
+ self.warning(msg.format(claim.get_claim_name()))
+
+ claim.check_cross_claim_requirements()
+
+ self.check_cross_claim_requirements()
+
+ return token
+
+
+ def get_wrapping_tag(self=None):
+ """The value of the tag that the token is wrapped in.
+
+ The function should return None if the token is not wrapped.
+ """
+ return None
+
+ def error(self, message):
+ self.seen_errors = True
+ if self.config.keep_going:
+ logger.error(message)
+ else:
+ raise ValueError(message)
+
+ def warning(self, message):
+ logger.warning(message)
diff --git a/iat-verifier/iatverifier/verifiers.py b/iat-verifier/iatverifier/psa_iot_profile1_token_claims.py
similarity index 62%
rename from iat-verifier/iatverifier/verifiers.py
rename to iat-verifier/iatverifier/psa_iot_profile1_token_claims.py
index 904e760..b4c6a4f 100644
--- a/iat-verifier/iatverifier/verifiers.py
+++ b/iat-verifier/iatverifier/psa_iot_profile1_token_claims.py
@@ -5,12 +5,9 @@
#
# -----------------------------------------------------------------------------
-import logging
import string
-import cbor2
-
-logger = logging.getLogger('iat-verifiers')
+from iatverifier.attest_token_verifier import AttestationClaim, NonVerifiedClaim
# IAT custom claims
ARM_RANGE = -75000
@@ -18,90 +15,6 @@
# SW component IDs
SW_COMPONENT_RANGE = 0
-class AttestationClaim:
- MANDATORY = 0
- RECOMMENDED = 1
- OPTIONAL = 2
-
- def __init__(self, verifier, necessity=MANDATORY):
- self.config = verifier.config
- self.verifier = verifier
- self.necessity = necessity
- self.verify_count = 0
-
- def verify(self, value):
- raise NotImplementedError
-
- def get_claim_key(self=None):
- raise NotImplementedError
-
- def get_claim_name(self=None):
- raise NotImplementedError
-
- def get_contained_claim_key_list(self):
- return {}
-
- def decode(self, value):
- if self.is_utf_8():
- try:
- return value.decode()
- except UnicodeDecodeError as e:
- msg = 'Error decodeing value for "{}": {}'
- self.verifier.error(msg.format(self.get_claim_name(), e))
- return str(value)[2:-1]
- else: # not a UTF-8 value, i.e. a bytestring
- return value
-
- def add_tokens_to_dict(self, token, value):
- entry_name = self.get_claim_name()
- if isinstance(value, bytes):
- value = self.decode(value)
- token[entry_name] = value
-
- def claim_found(self):
- return self.verify_count>0
-
- def _check_type(self, name, value, expected_type):
- if not isinstance(value, expected_type):
- msg = 'Invalid {}: must be a(n) {}: found {}'
- self.verifier.error(msg.format(name, expected_type, type(value)))
- return False
- return True
-
- def _validate_bytestring_length_equals(self, value, name, expected_len):
- self._check_type(name, value, bytes)
-
- value_len = len(value)
- if value_len != expected_len:
- msg = 'Invalid {} length: must be exactly {} bytes, found {} bytes'
- self.verifier.error(msg.format(name, expected_len, value_len))
-
- def _validate_bytestring_length_is_at_least(self, value, name, minimal_length):
- self._check_type(name, value, bytes)
-
- value_len = len(value)
- if value_len < minimal_length:
- msg = 'Invalid {} length: must be at least {} bytes, found {} bytes'
- self.verifier.error(msg.format(name, minimal_length, value_len))
-
- @staticmethod
- def parse_raw(raw_value):
- return raw_value
-
- @staticmethod
- def get_formatted_value(value):
- return value
-
- def is_utf_8(self):
- return False
-
- def check_cross_claim_requirements(self):
- pass
-
-
-# ----------------------------------------------------------------------------
-# Validation classes
-#
class InstanceIdClaim(AttestationClaim):
def __init__(self, verifier, expected_len, necessity=AttestationClaim.MANDATORY):
super().__init__(verifier, necessity)
@@ -141,17 +54,6 @@
self.verify_count += 1
-class NonVerifiedClaim(AttestationClaim):
- def verify(self, value):
- self.verify_count += 1
-
- def get_claim_key(self=None):
- raise NotImplementedError
-
- def get_claim_name(self=None):
- raise NotImplementedError
-
-
class ImplementationIdClaim(NonVerifiedClaim):
def get_claim_key(self=None):
return ARM_RANGE - 3
@@ -433,108 +335,3 @@
def is_utf_8(self):
return True
-
-
-# ----------------------------------------------------------------------------
-
-class VerifierConfiguration:
- def __init__(self, keep_going=False, strict=False):
- self.keep_going=keep_going
- self.strict=strict
-
-class AttestationTokenVerifier:
-
- all_known_claims = {}
-
- SIGN_METHOD_SIGN1 = "sign"
- SIGN_METHOD_MAC0 = "mac"
- SIGN_METHOD_RAW = "raw"
-
- COSE_ALG_ES256="ES256"
- COSE_ALG_ES384="ES384"
- COSE_ALG_ES512="ES512"
- COSE_ALG_HS256_64="HS256/64"
- COSE_ALG_HS256="HS256"
- COSE_ALG_HS384="HS384"
- COSE_ALG_HS512="HS512"
-
- def __init__(self, method, cose_alg, configuration=None):
- self.method = method
- self.cose_alg = cose_alg
- self.config = configuration if configuration is not None else VerifierConfiguration()
- self.claims = []
-
- self.seen_errors = False
-
- def add_claims(self, claims):
- for claim in claims:
- key = claim.get_claim_key()
- if key not in AttestationTokenVerifier.all_known_claims:
- AttestationTokenVerifier.all_known_claims[key] = claim.__class__
-
- AttestationTokenVerifier.all_known_claims.update(claim.get_contained_claim_key_list())
- self.claims.extend(claims)
-
- def check_cross_claim_requirements(self):
- pass
-
- def decode_and_validate_iat(self, encoded_iat):
- try:
- raw_token = cbor2.loads(encoded_iat)
- except Exception as e:
- msg = 'Invalid CBOR: {}'
- raise ValueError(msg.format(e))
-
- claims = {v.get_claim_key(): v for v in self.claims}
-
- token = {}
- while not hasattr(raw_token, 'items'):
- # TODO: token map is not a map. We are assuming that it is a tag
- raw_token = raw_token.value
- for entry in raw_token.keys():
- value = raw_token[entry]
-
- try:
- claim = claims[entry]
- except KeyError:
- if self.config.strict:
- self.error('Invalid IAT claim: {}'.format(entry))
- token[entry] = value
- continue
-
- claim.verify(value)
- claim.add_tokens_to_dict(token, value)
-
- # Check claims' necessity
- for claim in claims.values():
- if not claim.claim_found():
- if claim.necessity==AttestationClaim.MANDATORY:
- msg = 'Invalid IAT: missing MANDATORY claim "{}"'
- self.error(msg.format(claim.get_claim_name()))
- elif claim.necessity==AttestationClaim.RECOMMENDED:
- msg = 'Missing RECOMMENDED claim "{}"'
- self.warning(msg.format(claim.get_claim_name()))
-
- claim.check_cross_claim_requirements()
-
- self.check_cross_claim_requirements()
-
- return token
-
-
- def get_wrapping_tag(self=None):
- """The value of the tag that the token is wrapped in.
-
- The function should return None if the token is not wrapped.
- """
- return None
-
- def error(self, message):
- self.seen_errors = True
- if self.config.keep_going:
- logger.error(message)
- else:
- raise ValueError(message)
-
- def warning(self, message):
- logger.warning(message)
diff --git a/iat-verifier/iatverifier/psa_iot_profile1_token_verifier.py b/iat-verifier/iatverifier/psa_iot_profile1_token_verifier.py
index 1d0e773..e9a7366 100644
--- a/iat-verifier/iatverifier/psa_iot_profile1_token_verifier.py
+++ b/iat-verifier/iatverifier/psa_iot_profile1_token_verifier.py
@@ -5,14 +5,14 @@
#
# -----------------------------------------------------------------------------
-from iatverifier.verifiers import AttestationTokenVerifier as Verifier
-from iatverifier.verifiers import AttestationClaim as Claim
-from iatverifier.verifiers import ProfileIdClaim, ClientIdClaim, SecurityLifecycleClaim
-from iatverifier.verifiers import ImplementationIdClaim, BootSeedClaim, HardwareVersionClaim
-from iatverifier.verifiers import NoMeasurementsClaim, ChallengeClaim
-from iatverifier.verifiers import InstanceIdClaim, VerificationServiceClaim, SWComponentsClaim
-from iatverifier.verifiers import SWComponentTypeClaim, SwComponentVersionClaim
-from iatverifier.verifiers import MeasurementValueClaim, MeasurementDescriptionClaim, SignerIdClaim
+from iatverifier.attest_token_verifier import AttestationTokenVerifier as Verifier
+from iatverifier.attest_token_verifier import AttestationClaim as Claim
+from iatverifier.psa_iot_profile1_token_claims import ProfileIdClaim, ClientIdClaim, SecurityLifecycleClaim
+from iatverifier.psa_iot_profile1_token_claims import ImplementationIdClaim, BootSeedClaim, HardwareVersionClaim
+from iatverifier.psa_iot_profile1_token_claims import NoMeasurementsClaim, ChallengeClaim
+from iatverifier.psa_iot_profile1_token_claims import InstanceIdClaim, VerificationServiceClaim, SWComponentsClaim
+from iatverifier.psa_iot_profile1_token_claims import SWComponentTypeClaim, SwComponentVersionClaim
+from iatverifier.psa_iot_profile1_token_claims import MeasurementValueClaim, MeasurementDescriptionClaim, SignerIdClaim
class PSAIoTProfile1TokenVerifier(Verifier):
@staticmethod
diff --git a/iat-verifier/iatverifier/util.py b/iat-verifier/iatverifier/util.py
index 9418721..39af2e3 100644
--- a/iat-verifier/iatverifier/util.py
+++ b/iat-verifier/iatverifier/util.py
@@ -16,7 +16,7 @@
from pycose.attributes import CoseAttrs
from pycose.sign1message import Sign1Message
from pycose.mac0message import Mac0Message
-from iatverifier.verifiers import AttestationTokenVerifier
+from iatverifier.attest_token_verifier import AttestationTokenVerifier
from cbor2 import CBORTag
_logger = logging.getLogger("util")
diff --git a/iat-verifier/scripts/check_iat b/iat-verifier/scripts/check_iat
index b91702a..c391393 100755
--- a/iat-verifier/scripts/check_iat
+++ b/iat-verifier/scripts/check_iat
@@ -12,7 +12,7 @@
from iatverifier.util import extract_iat_from_cose, recursive_bytes_to_strings
from iatverifier.psa_iot_profile1_token_verifier import PSAIoTProfile1TokenVerifier
-from iatverifier.verifiers import VerifierConfiguration, AttestationTokenVerifier
+from iatverifier.attest_token_verifier import VerifierConfiguration, AttestationTokenVerifier
logger = logging.getLogger('iat-verify')
diff --git a/iat-verifier/scripts/compile_token b/iat-verifier/scripts/compile_token
index 1272ad4..8fac1fc 100755
--- a/iat-verifier/scripts/compile_token
+++ b/iat-verifier/scripts/compile_token
@@ -14,7 +14,7 @@
from ecdsa import SigningKey
from iatverifier.util import read_token_map, convert_map_to_token
from iatverifier.psa_iot_profile1_token_verifier import PSAIoTProfile1TokenVerifier
-from iatverifier.verifiers import AttestationTokenVerifier
+from iatverifier.attest_token_verifier import AttestationTokenVerifier
if __name__ == '__main__':
diff --git a/iat-verifier/tests/test_verifier.py b/iat-verifier/tests/test_verifier.py
index 7824ea6..a7e4c92 100644
--- a/iat-verifier/tests/test_verifier.py
+++ b/iat-verifier/tests/test_verifier.py
@@ -11,7 +11,7 @@
from iatverifier.psa_iot_profile1_token_verifier import PSAIoTProfile1TokenVerifier
from iatverifier.util import convert_map_to_token_files, extract_iat_from_cose
-from iatverifier.verifiers import VerifierConfiguration
+from iatverifier.attest_token_verifier import VerifierConfiguration
THIS_DIR = os.path.dirname(__file__)