imgtool: Fix verify command for ecdsa-p384 signed images

Previously the hash algorithm defaulted to SHA256 when no key was
provided; it is now selected from the hash TLV found in the image.
Verification now also checks that the provided key matches the hash
TLV type in the image; the new VerifyResult.KEY_MISMATCH value
reports a mismatch.
Also includes multiple styling fixes, import reordering and
exception handling changes.

Signed-off-by: Rustam Ismayilov <rustam.ismayilov@arm.com>
Change-Id: I61a588de5b39678707c0179f4edaa411ceb67c8e
diff --git a/scripts/imgtool/image.py b/scripts/imgtool/image.py
index 3de8357..a30d53b 100644
--- a/scripts/imgtool/image.py
+++ b/scripts/imgtool/image.py
@@ -1,6 +1,6 @@
 # Copyright 2018 Nordic Semiconductor ASA
 # Copyright 2017-2020 Linaro Limited
-# Copyright 2019-2023 Arm Limited
+# Copyright 2019-2024 Arm Limited
 #
 # SPDX-License-Identifier: Apache-2.0
 #
@@ -20,23 +20,25 @@
 Image signing and management.
 """
 
-from . import version as versmod
-from .boot_record import create_sw_component_data
-import click
-from enum import Enum
-from intelhex import IntelHex
 import hashlib
-import struct
 import os.path
-from .keys import rsa, ecdsa, x25519
+import struct
+from enum import Enum
+
+import click
+from cryptography.exceptions import InvalidSignature
+from cryptography.hazmat.backends import default_backend
+from cryptography.hazmat.primitives import hashes, hmac
 from cryptography.hazmat.primitives.asymmetric import ec, padding
 from cryptography.hazmat.primitives.asymmetric.x25519 import X25519PrivateKey
 from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes
 from cryptography.hazmat.primitives.kdf.hkdf import HKDF
 from cryptography.hazmat.primitives.serialization import Encoding, PublicFormat
-from cryptography.hazmat.backends import default_backend
-from cryptography.hazmat.primitives import hashes, hmac
-from cryptography.exceptions import InvalidSignature
+from intelhex import IntelHex
+
+from . import version as versmod, keys
+from .boot_record import create_sw_component_data
+from .keys import rsa, ecdsa, x25519
 
 IMAGE_MAGIC = 0x96f3b83d
 IMAGE_HEADER_SIZE = 32
@@ -90,10 +92,8 @@
 }
 
 VerifyResult = Enum('VerifyResult',
-                    """
-                    OK INVALID_MAGIC INVALID_TLV_INFO_MAGIC INVALID_HASH
-                    INVALID_SIGNATURE
-                    """)
+                    ['OK', 'INVALID_MAGIC', 'INVALID_TLV_INFO_MAGIC', 'INVALID_HASH', 'INVALID_SIGNATURE',
+                     'KEY_MISMATCH'])
 
 
 def align_up(num, align):
@@ -135,7 +135,24 @@
         return header + bytes(self.buf)
 
 
-class Image():
+def get_digest(tlv_type, hash_region):
+    """Hash hash_region with the algorithm selected by a SHA256/SHA384 TLV type."""
+    if tlv_type == TLV_VALUES["SHA384"]:
+        return hashlib.sha384(hash_region).digest()
+    if tlv_type == TLV_VALUES["SHA256"]:
+        return hashlib.sha256(hash_region).digest()
+    # Any other TLV type names no hash algorithm; reject it explicitly.
+    raise ValueError("Unsupported hash TLV type: 0x{:02x}".format(tlv_type))
+
+
+def tlv_matches_key_type(tlv_type, key):
+    """Check tlv_type suits key: SHA384 for P-384 keys, else SHA256; any if no key."""
+    # isinstance rather than type() == so public P-384 keys also match SHA384.
+    is_p384 = isinstance(key, (ecdsa.ECDSA384P1, ecdsa.ECDSA384P1Public))
+    return key is None or tlv_type == TLV_VALUES["SHA384" if is_p384 else "SHA256"]
+
+
+class Image:
 
     def __init__(self, version=None, header_size=IMAGE_HEADER_SIZE,
                  pad_header=False, pad=False, confirm=False, align=1,
@@ -178,9 +195,9 @@
             msb = (self.max_align & 0xff00) >> 8
             align = bytes([msb, lsb]) if self.endian == "big" else bytes([lsb, msb])
             self.boot_magic = align + bytes([0x2d, 0xe1,
-                                            0x5d, 0x29, 0x41, 0x0b,
-                                            0x8d, 0x77, 0x67, 0x9c,
-                                            0x11, 0x0f, 0x1f, 0x8a, ])
+                                             0x5d, 0x29, 0x41, 0x0b,
+                                             0x8d, 0x77, 0x67, 0x9c,
+                                             0x11, 0x0f, 0x1f, 0x8a, ])
 
         if security_counter == 'auto':
             # Security counter has not been explicitly provided,
@@ -321,9 +338,8 @@
         self.enckey = enckey
 
         # Check what hashing algorithm should be used
-        if (key is not None and isinstance(key, ecdsa.ECDSA384P1) or
-                pub_key is not None and isinstance(pub_key,
-                                                   ecdsa.ECDSA384P1Public)):
+        if (key and isinstance(key, ecdsa.ECDSA384P1)
+                or pub_key and isinstance(pub_key, ecdsa.ECDSA384P1Public)):
             hash_algorithm = hashlib.sha384
             hash_tlv = "SHA384"
         else:
@@ -430,13 +446,13 @@
             if dependencies is not None:
                 for i in range(dependencies_num):
                     payload = struct.pack(
-                                    e + 'B3x'+'BBHI',
-                                    int(dependencies[DEP_IMAGES_KEY][i]),
-                                    dependencies[DEP_VERSIONS_KEY][i].major,
-                                    dependencies[DEP_VERSIONS_KEY][i].minor,
-                                    dependencies[DEP_VERSIONS_KEY][i].revision,
-                                    dependencies[DEP_VERSIONS_KEY][i].build
-                                    )
+                        e + 'B3x' + 'BBHI',
+                        int(dependencies[DEP_IMAGES_KEY][i]),
+                        dependencies[DEP_VERSIONS_KEY][i].major,
+                        dependencies[DEP_VERSIONS_KEY][i].minor,
+                        dependencies[DEP_VERSIONS_KEY][i].revision,
+                        dependencies[DEP_VERSIONS_KEY][i].build
+                    )
                     prot_tlv.add('DEPENDENCY', payload)
 
             if custom_tlvs is not None:
@@ -640,42 +656,37 @@
             return VerifyResult.INVALID_MAGIC, None, None
 
         tlv_off = header_size + img_size
-        tlv_info = b[tlv_off:tlv_off+TLV_INFO_SIZE]
+        tlv_info = b[tlv_off:tlv_off + TLV_INFO_SIZE]
         magic, tlv_tot = struct.unpack('HH', tlv_info)
         if magic == TLV_PROT_INFO_MAGIC:
             tlv_off += tlv_tot
-            tlv_info = b[tlv_off:tlv_off+TLV_INFO_SIZE]
+            tlv_info = b[tlv_off:tlv_off + TLV_INFO_SIZE]
             magic, tlv_tot = struct.unpack('HH', tlv_info)
 
         if magic != TLV_INFO_MAGIC:
             return VerifyResult.INVALID_TLV_INFO_MAGIC, None, None
 
-        if isinstance(key, ecdsa.ECDSA384P1Public):
-            sha = hashlib.sha384()
-            hash_tlv = "SHA384"
-        else:
-            sha = hashlib.sha256()
-            hash_tlv = "SHA256"
-
         prot_tlv_size = tlv_off
-        sha.update(b[:prot_tlv_size])
-        digest = sha.digest()
-
+        hash_region = b[:prot_tlv_size]
+        digest = None
         tlv_end = tlv_off + tlv_tot
         tlv_off += TLV_INFO_SIZE  # skip tlv info
         while tlv_off < tlv_end:
-            tlv = b[tlv_off:tlv_off+TLV_SIZE]
+            tlv = b[tlv_off:tlv_off + TLV_SIZE]
             tlv_type, _, tlv_len = struct.unpack('BBH', tlv)
-            if tlv_type == TLV_VALUES[hash_tlv]:
+            if tlv_type == TLV_VALUES["SHA256"] or tlv_type == TLV_VALUES["SHA384"]:
+                if not tlv_matches_key_type(tlv_type, key):
+                    return VerifyResult.KEY_MISMATCH, None, None
                 off = tlv_off + TLV_SIZE
-                if digest == b[off:off+tlv_len]:
+                digest = get_digest(tlv_type, hash_region)
+                if digest == b[off:off + tlv_len]:
                     if key is None:
                         return VerifyResult.OK, version, digest
                 else:
                     return VerifyResult.INVALID_HASH, None, None
             elif key is not None and tlv_type == TLV_VALUES[key.sig_tlv()]:
                 off = tlv_off + TLV_SIZE
-                tlv_sig = b[off:off+tlv_len]
+                tlv_sig = b[off:off + tlv_len]
                 payload = b[:prot_tlv_size]
                 try:
                     if hasattr(key, 'verify'):