Add option for COSE protected header

Add option to generate and check the protected header of the COSE
envelope.

Change-Id: I5d298c5a5bb90ba32443c731d75400169c06de1c
Signed-off-by: Mate Toth-Pal <mate.toth-pal@arm.com>
diff --git a/iat-verifier/dev_scripts/generate-sample-iat.py b/iat-verifier/dev_scripts/generate-sample-iat.py
index 5c9d35b..cc0fdc8 100755
--- a/iat-verifier/dev_scripts/generate-sample-iat.py
+++ b/iat-verifier/dev_scripts/generate-sample-iat.py
@@ -89,7 +89,7 @@
     sk = SigningKey.from_pem(open(keyfile, 'rb').read())
     token = cbor2.dumps(token_map)
     verifier = PSAIoTProfile1TokenVerifier.get_verifier()
-    signed_token = sign_eat(token, verifier, sk)
+    signed_token = sign_eat(token, verifier, add_p_header=False, key=sk)
 
     with open(outfile, 'wb') as wfh:
         wfh.write(signed_token)
diff --git a/iat-verifier/iatverifier/util.py b/iat-verifier/iatverifier/util.py
index 1025355..9418721 100644
--- a/iat-verifier/iatverifier/util.py
+++ b/iat-verifier/iatverifier/util.py
@@ -13,6 +13,7 @@
 import cbor2
 import yaml
 from ecdsa import SigningKey, VerifyingKey
+from pycose.attributes import CoseAttrs
 from pycose.sign1message import Sign1Message
 from pycose.mac0message import Mac0Message
 from iatverifier.verifiers import AttestationTokenVerifier
@@ -20,8 +21,11 @@
 
 _logger = logging.getLogger("util")
 
-def sign_eat(token, verifier, key=None):
-    signed_msg = Sign1Message()
+def sign_eat(token, verifier, *, add_p_header, key=None):
+    protected_header = CoseAttrs()
+    if add_p_header and key:
+        protected_header['alg'] = verifier.cose_alg
+    signed_msg = Sign1Message(p_header=protected_header)
     signed_msg.payload = token
     if key:
         signed_msg.key = key
@@ -29,13 +33,16 @@
     return signed_msg.encode()
 
 
-def hmac_eat(token, verifier, key=None):
-    hmac_msg = Mac0Message(payload=token, key=key)
+def hmac_eat(token, verifier, *, add_p_header, key=None):
+    protected_header = CoseAttrs()
+    if add_p_header and key:
+        protected_header['alg'] = verifier.cose_alg
+    hmac_msg = Mac0Message(payload=token, key=key, p_header=protected_header)
     hmac_msg.compute_auth_tag(alg=verifier.cose_alg)
     return hmac_msg.encode()
 
 
-def convert_map_to_token_files(mapfile, keyfile, verifier, outfile):
+def convert_map_to_token_files(mapfile, keyfile, verifier, outfile, add_p_header):
     token_map = read_token_map(mapfile)
 
     if verifier.method == 'sign':
@@ -46,10 +53,10 @@
             signing_key = fh.read()
 
     with open(outfile, 'wb') as wfh:
-        convert_map_to_token(token_map, signing_key, verifier, wfh)
+        convert_map_to_token(token_map, signing_key, verifier, wfh, add_p_header)
 
 
-def convert_map_to_token(token_map, signing_key, verifier, wfh):
+def convert_map_to_token(token_map, signing_key, verifier, wfh, add_p_header):
     wrapping_tag = verifier.get_wrapping_tag()
     if wrapping_tag is not None:
         token = cbor2.dumps(CBORTag(wrapping_tag, token_map))
@@ -59,9 +66,9 @@
     if verifier.method == AttestationTokenVerifier.SIGN_METHOD_RAW:
         signed_token = token
     elif verifier.method == AttestationTokenVerifier.SIGN_METHOD_SIGN1:
-        signed_token = sign_eat(token, verifier, signing_key)
+        signed_token = sign_eat(token, verifier, add_p_header=add_p_header, key=signing_key)
     elif verifier.method == AttestationTokenVerifier.SIGN_METHOD_MAC0:
-        signed_token = hmac_eat(token, verifier, signing_key)
+        signed_token = hmac_eat(token, verifier, add_p_header=add_p_header, key=signing_key)
     else:
         err_msg = 'Unexpected method "{}"; must be one of: raw, sign, mac'
         raise ValueError(err_msg.format(method))
@@ -70,7 +77,7 @@
 
 
 def convert_token_to_map(raw_data, verifier):
-    payload = get_cose_payload(raw_data, verifier)
+    payload = get_cose_payload(raw_data, verifier, check_p_header=False)
     token_map = cbor2.loads(payload)
     return _relabel_keys(token_map)
 
@@ -85,29 +92,38 @@
     return _parse_raw_token(raw)
 
 
-def extract_iat_from_cose(keyfile, tokenfile, verifier):
+def extract_iat_from_cose(keyfile, tokenfile, verifier, check_p_header):
     key = read_keyfile(keyfile, verifier.method)
 
     try:
         with open(tokenfile, 'rb') as wfh:
-            return get_cose_payload(wfh.read(), verifier, key)
+            return get_cose_payload(wfh.read(), verifier, check_p_header=check_p_header, key=key)
     except Exception as e:
         msg = 'Bad COSE file "{}": {}'
         raise ValueError(msg.format(tokenfile, e))
 
 
-def get_cose_payload(cose, verifier, key=None):
+def get_cose_payload(cose, verifier, *, check_p_header, key=None):
     if verifier.method == AttestationTokenVerifier.SIGN_METHOD_SIGN1:
-        return get_cose_sign1_payload(cose, verifier, key)
+        return get_cose_sign1_payload(cose, verifier, check_p_header=check_p_header, key=key)
     if verifier.method == AttestationTokenVerifier.SIGN_METHOD_MAC0:
-        return get_cose_mac0_pyload(cose, verifier, key)
+        return get_cose_mac0_payload(cose, verifier, check_p_header=check_p_header, key=key)
     err_msg = 'Unexpected method "{}"; must be one of: sign, mac'
-    raise ValueError(err_msg.format(method))
+    raise ValueError(err_msg.format(verifier.method))
 
+def parse_protected_header(msg, alg):
+    try:
+        msg_alg = msg.protected_header['alg']
+    except KeyError:
+        raise ValueError('Missing alg from protected header (expected {})'.format(alg))
+    if alg != msg_alg:
+        raise ValueError('Unexpected alg in protected header (expected {} instead of {})'.format(alg, msg_alg))
 
-def get_cose_sign1_payload(cose, verifier, key=None):
+def get_cose_sign1_payload(cose, verifier, *, check_p_header, key=None):
     msg = Sign1Message.decode(cose)
     if key:
+        if check_p_header:
+            parse_protected_header(msg, verifier.cose_alg)
         msg.key = key
         msg.signature = msg.signers
         try:
@@ -117,9 +133,11 @@
     return msg.payload
 
 
-def get_cose_mac0_pyload(cose, verifier, key=None):
+def get_cose_mac0_payload(cose, verifier, *, check_p_header, key=None):
     msg = Mac0Message.decode(cose)
     if key:
+        if check_p_header:
+            parse_protected_header(msg, verifier.cose_alg)
         msg.key = key
         try:
             msg.verify_auth_tag(alg=verifier.cose_alg)
diff --git a/iat-verifier/iatverifier/verify.py b/iat-verifier/iatverifier/verify.py
index 97c8078..ee5c7dc 100644
--- a/iat-verifier/iatverifier/verify.py
+++ b/iat-verifier/iatverifier/verify.py
@@ -49,6 +49,10 @@
                         help='''
                         Report failure if unknown claim is encountered.
                         ''')
+    parser.add_argument('-c', '--check-protected-header', action='store_true',
+                        help='''
+                        Check the presence and content of COSE protected header.
+                        ''')
     parser.add_argument('-m', '--method', choices=['sign', 'mac'], default='sign',
                         help='''
                         Specify how this token is wrapped -- whether Sign1Message or
@@ -70,7 +74,7 @@
         verifier.cose_alg = AttestationTokenVerifier.COSE_ALG_HS256
 
     try:
-        raw_iat = extract_iat_from_cose(args.keyfile, args.tokenfile, verifier)
+        raw_iat = extract_iat_from_cose(args.keyfile, args.tokenfile, verifier, args.check_protected_header)
         if args.keyfile:
             print('Signature OK')
     except ValueError as e:
diff --git a/iat-verifier/scripts/compile_token b/iat-verifier/scripts/compile_token
index 22fbe23..1272ad4 100755
--- a/iat-verifier/scripts/compile_token
+++ b/iat-verifier/scripts/compile_token
@@ -34,6 +34,10 @@
                         sign the token. If this is not specified, the token will be
                         unsigned.''')
     group = parser.add_mutually_exclusive_group()
+    parser.add_argument('-a', '--add-protected-header', action='store_true',
+                        help='''
+                        Add protected header to the COSE wrapper.
+                        ''')
     group.add_argument('-r', '--raw', action='store_true',
                        help='''Generate raw CBOR and do not create a signature
                        or COSE wrapper.''')
@@ -75,7 +79,7 @@
 
     if args.outfile:
         with open(args.outfile, 'wb') as wfh:
-            convert_map_to_token(token_map, signing_key, verifier, wfh)
+            convert_map_to_token(token_map, signing_key, verifier, wfh, args.add_protected_header)
     else:
         with os.fdopen(sys.stdout.fileno(), 'wb') as wfh:
-            convert_map_to_token(token_map, signing_key, verifier, wfh)
+            convert_map_to_token(token_map, signing_key, verifier, wfh, args.add_protected_header)
diff --git a/iat-verifier/tests/test_verifier.py b/iat-verifier/tests/test_verifier.py
index 5e15895..7725039 100644
--- a/iat-verifier/tests/test_verifier.py
+++ b/iat-verifier/tests/test_verifier.py
@@ -25,13 +25,13 @@
     source_path = os.path.join(DATA_DIR, source_name)
     fd, dest_path = tempfile.mkstemp()
     os.close(fd)
-    convert_map_to_token_files(source_path, keyfile, verifier, dest_path)
+    convert_map_to_token_files(source_path, keyfile, verifier, dest_path, True)
     return dest_path
 
 
 def read_iat(filename, keyfile, verifier):
     filepath = os.path.join(DATA_DIR, filename)
-    raw_iat = extract_iat_from_cose(keyfile, filepath, verifier)
+    raw_iat = extract_iat_from_cose(keyfile, filepath, verifier, True)
     return verifier.decode_and_validate_iat(raw_iat)
 
 
@@ -50,10 +50,10 @@
         good_sig = create_token('valid-iat.yaml', KEYFILE, verifier)
         bad_sig = create_token('valid-iat.yaml', KEYFILE_ALT, verifier)
 
-        raw_iat = extract_iat_from_cose(KEYFILE, good_sig, verifier)
+        raw_iat = extract_iat_from_cose(KEYFILE, good_sig, verifier, True)
 
         with self.assertRaises(ValueError) as cm:
-            raw_iat = extract_iat_from_cose(KEYFILE, bad_sig, verifier)
+            raw_iat = extract_iat_from_cose(KEYFILE, bad_sig, verifier, True)
 
         self.assertIn('Bad signature', cm.exception.args[0])