imgtool: Add generic ECDSA TLV support

Update imgtool to support the new generic ECDSA signature TLV
(ECDSASIG, 0x25) and ECDSA keys on the P-384 curve, which are
hashed and signed with SHA-384 (new SHA384 hash TLV, 0x11).

Signed-off-by: Roland Mikhel <roland.mikhel@arm.com>
Change-Id: I9b1887610cc5d0e7cde90f47999fcdf3500ef51c
diff --git a/scripts/imgtool/image.py b/scripts/imgtool/image.py
index 495416b..8da49b9 100644
--- a/scripts/imgtool/image.py
+++ b/scripts/imgtool/image.py
@@ -62,10 +62,12 @@
         'KEYHASH': 0x01,
         'PUBKEY': 0x02,
         'SHA256': 0x10,
+        'SHA384': 0x11,
         'RSA2048': 0x20,
         'ECDSA256': 0x22,
         'RSA3072': 0x23,
         'ED25519': 0x24,
+        'ECDSASIG': 0x25,
         'ENCRSA2048': 0x30,
         'ENCKW': 0x31,
         'ENCEC256': 0x32,
@@ -94,10 +96,12 @@
                     INVALID_SIGNATURE
                     """)
 
+
 def align_up(num, align):
     assert (align & (align - 1) == 0) and align != 0
     return (num + (align - 1)) & ~(align - 1)
 
+
 class TLV():
     def __init__(self, endian, magic=TLV_INFO_MAGIC):
         self.magic = magic
@@ -116,7 +120,7 @@
             if not TLV_VENDOR_RES_MIN <= kind <= TLV_VENDOR_RES_MAX:
                 msg = "Invalid custom TLV type value '0x{:04x}', allowed " \
                       "value should be between 0x{:04x} and 0x{:04x}".format(
-                      kind, TLV_VENDOR_RES_MIN, TLV_VENDOR_RES_MAX)
+                        kind, TLV_VENDOR_RES_MIN, TLV_VENDOR_RES_MAX)
                 raise click.UsageError(msg)
             buf = struct.pack(e + 'HH', kind, len(payload))
         else:
@@ -250,11 +254,13 @@
                                                   self.enctlv_len)
                 trailer_addr = (self.base_addr + self.slot_size) - trailer_size
                 if self.confirm and not self.overwrite_only:
-                    magic_align_size = align_up(len(self.boot_magic), self.max_align)
+                    magic_align_size = align_up(len(self.boot_magic),
+                                                self.max_align)
                     image_ok_idx = -(magic_align_size + self.max_align)
                     flag = bytearray([self.erased_val] * self.max_align)
-                    flag[0] = 0x01 # image_ok = 0x01
-                    h.puts(trailer_addr + trailer_size + image_ok_idx, bytes(flag))
+                    flag[0] = 0x01  # image_ok = 0x01
+                    h.puts(trailer_addr + trailer_size + image_ok_idx,
+                           bytes(flag))
                 h.puts(trailer_addr + (trailer_size - len(self.boot_magic)),
                        bytes(self.boot_magic))
             h.tofile(path, 'hex')
@@ -311,20 +317,31 @@
         return cipherkey, ciphermac, pubk
 
     def create(self, key, public_key_format, enckey, dependencies=None,
-               sw_type=None, custom_tlvs=None, encrypt_keylen=128, clear=False, fixed_sig=None, pub_key=None, vector_to_sign=None):
+               sw_type=None, custom_tlvs=None, encrypt_keylen=128, clear=False,
+               fixed_sig=None, pub_key=None, vector_to_sign=None,
+               use_legacy_tlv=False):
         self.enckey = enckey
 
+        # Check what hashing algorithm should be used
+        if (key is not None and isinstance(key, ecdsa.ECDSA384P1) or
+                pub_key is not None and isinstance(pub_key,
+                                                   ecdsa.ECDSA384P1Public)):
+            hash_algorithm = hashlib.sha384
+            hash_tlv = "SHA384"
+        else:
+            hash_algorithm = hashlib.sha256
+            hash_tlv = "SHA256"
         # Calculate the hash of the public key
         if key is not None:
             pub = key.get_public_bytes()
-            sha = hashlib.sha256()
+            sha = hash_algorithm()
             sha.update(pub)
             pubbytes = sha.digest()
         elif pub_key is not None:
             if hasattr(pub_key, 'sign'):
                 print(os.path.basename(__file__) + ": sign the payload")
             pub = pub_key.get_public_bytes()
-            sha = hashlib.sha256()
+            sha = hash_algorithm()
             sha.update(pub)
             pubbytes = sha.digest()
         else:
@@ -354,11 +371,11 @@
             # before it is even calculated. For this reason the script fills
             # this field with zeros and the bootloader will insert the right
             # value later.
-            digest = bytes(hashlib.sha256().digest_size)
+            digest = bytes(hash_algorithm().digest_size)
 
             # Create CBOR encoded boot record
             boot_record = create_sw_component_data(sw_type, image_version,
-                                                   "SHA256", digest,
+                                                   hash_tlv, digest,
                                                    pubbytes)
 
             protected_tlv_size += TLV_SIZE + len(boot_record)
@@ -435,11 +452,10 @@
 
         # Note that ecdsa wants to do the hashing itself, which means
         # we get to hash it twice.
-        sha = hashlib.sha256()
+        sha = hash_algorithm()
         sha.update(self.payload)
         digest = sha.digest()
-
-        tlv.add('SHA256', digest)
+        tlv.add(hash_tlv, digest)
 
         if vector_to_sign == 'payload':
             # Stop amending data to the image
@@ -458,8 +474,9 @@
                 tlv.add('PUBKEY', pub)
 
             if key is not None and fixed_sig is None:
-                # `sign` expects the full image payload (sha256 done internally),
-                # while `sign_digest` expects only the digest of the payload
+                # `sign` expects the full image payload (hashing done
+                # internally), while `sign_digest` expects only the digest
+                # of the payload
 
                 if hasattr(key, 'sign'):
                     print(os.path.basename(__file__) + ": sign the payload")
@@ -551,17 +568,18 @@
                )  # }
         assert struct.calcsize(fmt) == IMAGE_HEADER_SIZE
         header = struct.pack(fmt,
-                IMAGE_MAGIC,
-                self.rom_fixed or self.load_addr,
-                self.header_size,
-                protected_tlv_size,  # TLV Info header + Protected TLVs
-                len(self.payload) - self.header_size,  # ImageSz
-                flags,
-                self.version.major,
-                self.version.minor or 0,
-                self.version.revision or 0,
-                self.version.build or 0,
-                0)  # Pad1
+                             IMAGE_MAGIC,
+                             self.rom_fixed or self.load_addr,
+                             self.header_size,
+                             protected_tlv_size,  # TLV Info header +
+                                                  # Protected TLVs
+                             len(self.payload) - self.header_size,  # ImageSz
+                             flags,
+                             self.version.major,
+                             self.version.minor or 0,
+                             self.version.revision or 0,
+                             self.version.build or 0,
+                             0)  # Pad1
         self.payload = bytearray(self.payload)
         self.payload[:len(header)] = header
 
@@ -627,7 +645,13 @@
         if magic != TLV_INFO_MAGIC:
             return VerifyResult.INVALID_TLV_INFO_MAGIC, None, None
 
-        sha = hashlib.sha256()
+        if isinstance(key, ecdsa.ECDSA384P1Public):
+            sha = hashlib.sha384()
+            hash_tlv = "SHA384"
+        else:
+            sha = hashlib.sha256()
+            hash_tlv = "SHA256"
+
         prot_tlv_size = tlv_off
         sha.update(b[:prot_tlv_size])
         digest = sha.digest()
@@ -637,7 +661,7 @@
         while tlv_off < tlv_end:
             tlv = b[tlv_off:tlv_off+TLV_SIZE]
             tlv_type, _, tlv_len = struct.unpack('BBH', tlv)
-            if tlv_type == TLV_VALUES["SHA256"]:
+            if tlv_type == TLV_VALUES[hash_tlv]:
                 off = tlv_off + TLV_SIZE
                 if digest == b[off:off+tlv_len]:
                     if key is None:
diff --git a/scripts/imgtool/keys/__init__.py b/scripts/imgtool/keys/__init__.py
index dfd101d..ed2fed5 100644
--- a/scripts/imgtool/keys/__init__.py
+++ b/scripts/imgtool/keys/__init__.py
@@ -1,4 +1,5 @@
 # Copyright 2017 Linaro Limited
+# Copyright 2023 Arm Limited
 #
 # SPDX-License-Identifier: Apache-2.0
 #
@@ -30,7 +31,8 @@
     X25519PrivateKey, X25519PublicKey)
 
 from .rsa import RSA, RSAPublic, RSAUsageError, RSA_KEY_SIZES
-from .ecdsa import ECDSA256P1, ECDSA256P1Public, ECDSAUsageError
+from .ecdsa import (ECDSA256P1, ECDSA256P1Public,
+                    ECDSA384P1, ECDSA384P1Public, ECDSAUsageError)
 from .ed25519 import Ed25519, Ed25519Public, Ed25519UsageError
 from .x25519 import X25519, X25519Public, X25519UsageError
 
@@ -42,7 +44,8 @@
 
 
 def load(path, passwd=None):
-    """Try loading a key from the given path.  Returns None if the password wasn't specified."""
+    """Try loading a key from the given path.
+      Returns None if the password wasn't specified."""
     with open(path, 'rb') as f:
         raw_pem = f.read()
     try:
@@ -73,17 +76,23 @@
             raise Exception("Unsupported RSA key size: " + pk.key_size)
         return RSAPublic(pk)
     elif isinstance(pk, EllipticCurvePrivateKey):
-        if pk.curve.name != 'secp256r1':
+        if pk.curve.name not in ('secp256r1', 'secp384r1'):
             raise Exception("Unsupported EC curve: " + pk.curve.name)
-        if pk.key_size != 256:
+        if pk.key_size not in (256, 384):
             raise Exception("Unsupported EC size: " + pk.key_size)
-        return ECDSA256P1(pk)
+        if pk.curve.name == 'secp256r1':
+            return ECDSA256P1(pk)
+        elif pk.curve.name == 'secp384r1':
+            return ECDSA384P1(pk)
     elif isinstance(pk, EllipticCurvePublicKey):
-        if pk.curve.name != 'secp256r1':
+        if pk.curve.name not in ('secp256r1', 'secp384r1'):
             raise Exception("Unsupported EC curve: " + pk.curve.name)
-        if pk.key_size != 256:
+        if pk.key_size not in (256, 384):
             raise Exception("Unsupported EC size: " + pk.key_size)
-        return ECDSA256P1Public(pk)
+        if pk.curve.name == 'secp256r1':
+            return ECDSA256P1Public(pk)
+        elif pk.curve.name == 'secp384r1':
+            return ECDSA384P1Public(pk)
     elif isinstance(pk, Ed25519PrivateKey):
         return Ed25519(pk)
     elif isinstance(pk, Ed25519PublicKey):
diff --git a/scripts/imgtool/keys/ecdsa.py b/scripts/imgtool/keys/ecdsa.py
index addceb2..b70153c 100644
--- a/scripts/imgtool/keys/ecdsa.py
+++ b/scripts/imgtool/keys/ecdsa.py
@@ -4,11 +4,12 @@
 
 # SPDX-License-Identifier: Apache-2.0
 import os.path
+import hashlib
 
 from cryptography.hazmat.backends import default_backend
 from cryptography.hazmat.primitives import serialization
 from cryptography.hazmat.primitives.asymmetric import ec
-from cryptography.hazmat.primitives.hashes import SHA256
+from cryptography.hazmat.primitives.hashes import SHA256, SHA384
 
 from .general import KeyClass
 from .privatebytes import PrivateBytesMixin
@@ -18,18 +19,18 @@
     pass
 
 
-class ECDSA256P1Public(KeyClass):
+class ECDSAPublicKey(KeyClass):
+    """
+    Wrapper around an ECDSA public key.
+    """
     def __init__(self, key):
         self.key = key
 
-    def shortname(self):
-        return "ecdsa"
-
     def _unsupported(self, name):
         raise ECDSAUsageError("Operation {} requires private key".format(name))
 
     def _get_public(self):
-        return self.key
+        return self.key.public_key()
 
     def get_public_bytes(self):
         # The key is embedded into MBUboot in "SubjectPublicKeyInfo" format
@@ -56,55 +57,13 @@
         with open(path, 'wb') as f:
             f.write(pem)
 
-    def sig_type(self):
-        return "ECDSA256_SHA256"
 
-    def sig_tlv(self):
-        return "ECDSA256"
-
-    def sig_len(self):
-        # Early versions of MCUboot (< v1.5.0) required ECDSA
-        # signatures to be padded to 72 bytes.  Because the DER
-        # encoding is done with signed integers, the size of the
-        # signature will vary depending on whether the high bit is set
-        # in each value.  This padding was done in a
-        # not-easily-reversible way (by just adding zeros).
-        #
-        # The signing code no longer requires this padding, and newer
-        # versions of MCUboot don't require it.  But, continue to
-        # return the total length so that the padding can be done if
-        # requested.
-        return 72
-
-    def verify(self, signature, payload):
-        # strip possible paddings added during sign
-        signature = signature[:signature[1] + 2]
-        k = self.key
-        if isinstance(self.key, ec.EllipticCurvePrivateKey):
-            k = self.key.public_key()
-        return k.verify(signature=signature, data=payload,
-                        signature_algorithm=ec.ECDSA(SHA256()))
-
-
-class ECDSA256P1(ECDSA256P1Public, PrivateBytesMixin):
+class ECDSAPrivateKey(PrivateBytesMixin):
     """
     Wrapper around an ECDSA private key.
     """
-
     def __init__(self, key):
-        """key should be an instance of EllipticCurvePrivateKey"""
         self.key = key
-        self.pad_sig = False
-
-    @staticmethod
-    def generate():
-        pk = ec.generate_private_key(
-                ec.SECP256R1(),
-                backend=default_backend())
-        return ECDSA256P1(pk)
-
-    def _get_public(self):
-        return self.key.public_key()
 
     def _build_minimal_ecdsa_privkey(self, der, format):
         '''
@@ -154,13 +113,14 @@
         'pkcs8': serialization.PrivateFormat.PKCS8,
         'openssl': serialization.PrivateFormat.TraditionalOpenSSL
     }
-    _DEFAULT_FORMAT='pkcs8'
+    _DEFAULT_FORMAT = 'pkcs8'
 
     def get_private_bytes(self, minimal, format):
-        format, priv = self._get_private_bytes(minimal, format, ECDSAUsageError)
+        format, priv = self._get_private_bytes(minimal,
+                                               format, ECDSAUsageError)
         if minimal:
-            priv = self._build_minimal_ecdsa_privkey(priv,
-                                                     self._VALID_FORMATS[format])
+            priv = self._build_minimal_ecdsa_privkey(
+                priv, self._VALID_FORMATS[format])
         return priv
 
     def export_private(self, path, passwd=None):
@@ -177,6 +137,64 @@
         with open(path, 'wb') as f:
             f.write(pem)
 
+
+class ECDSA256P1Public(ECDSAPublicKey):
+    """
+    Wrapper around an ECDSA (p256) public key.
+    """
+    def __init__(self, key):
+        super().__init__(key)
+        self.key = key
+
+    def shortname(self):
+        return "ecdsa"
+
+    def sig_type(self):
+        return "ECDSA256_SHA256"
+
+    def sig_tlv(self):
+        return "ECDSASIG"
+
+    def sig_len(self):
+        # Early versions of MCUboot (< v1.5.0) required ECDSA
+        # signatures to be padded to 72 bytes.  Because the DER
+        # encoding is done with signed integers, the size of the
+        # signature will vary depending on whether the high bit is set
+        # in each value.  This padding was done in a
+        # not-easily-reversible way (by just adding zeros).
+        #
+        # The signing code no longer requires this padding, and newer
+        # versions of MCUboot don't require it.  But, continue to
+        # return the total length so that the padding can be done if
+        # requested.
+        return 72
+
+    def verify(self, signature, payload):
+        # strip possible paddings added during sign
+        signature = signature[:signature[1] + 2]
+        k = self.key
+        if isinstance(self.key, ec.EllipticCurvePrivateKey):
+            k = self.key.public_key()
+        return k.verify(signature=signature, data=payload,
+                        signature_algorithm=ec.ECDSA(SHA256()))
+
+
+class ECDSA256P1(ECDSA256P1Public, ECDSAPrivateKey):
+    """
+    Wrapper around an ECDSA (p256) private key.
+    """
+    def __init__(self, key):
+        super().__init__(key)
+        self.key = key
+        self.pad_sig = False
+
+    @staticmethod
+    def generate():
+        pk = ec.generate_private_key(
+                ec.SECP256R1(),
+                backend=default_backend())
+        return ECDSA256P1(pk)
+
     def raw_sign(self, payload):
         """Return the actual signature"""
         return self.key.sign(
@@ -191,3 +209,78 @@
             return sig
         else:
             return sig
+
+
+class ECDSA384P1Public(ECDSAPublicKey):
+    """
+    Wrapper around an ECDSA (p384) public key.
+    """
+    def __init__(self, key):
+        super().__init__(key)
+        self.key = key
+
+    def shortname(self):
+        return "ecdsap384"
+
+    def sig_type(self):
+        return "ECDSA384_SHA384"
+
+    def sig_tlv(self):
+        return "ECDSASIG"
+
+    def sig_len(self):
+        # Early versions of MCUboot (< v1.5.0) required ECDSA
+        # signatures to be padded to a fixed length.  Because the DER
+        # encoding is done with signed integers, the size of the
+        # signature will vary depending on whether the high bit is set
+        # in each value.  This padding was done in a
+        # not-easily-reversible way (by just adding zeros).
+        #
+        # The signing code no longer requires this padding, and newer
+        # versions of MCUboot don't require it.  But, continue to
+        # return the total length so that the padding can be done if
+        # requested.
+        return 103
+
+    def verify(self, signature, payload):
+        # strip possible paddings added during sign
+        signature = signature[:signature[1] + 2]
+        k = self.key
+        if isinstance(self.key, ec.EllipticCurvePrivateKey):
+            k = self.key.public_key()
+        return k.verify(signature=signature, data=payload,
+                        signature_algorithm=ec.ECDSA(SHA384()))
+
+
+class ECDSA384P1(ECDSA384P1Public, ECDSAPrivateKey):
+    """
+    Wrapper around an ECDSA (p384) private key.
+    """
+
+    def __init__(self, key):
+        """key should be an instance of EllipticCurvePrivateKey"""
+        super().__init__(key)
+        self.key = key
+        self.pad_sig = False
+
+    @staticmethod
+    def generate():
+        pk = ec.generate_private_key(
+                ec.SECP384R1(),
+                backend=default_backend())
+        return ECDSA384P1(pk)
+
+    def raw_sign(self, payload):
+        """Return the actual signature"""
+        return self.key.sign(
+                data=payload,
+                signature_algorithm=ec.ECDSA(SHA384()))
+
+    def sign(self, payload):
+        sig = self.raw_sign(payload)
+        if self.pad_sig:
+            # To make fixed length, pad with one or two zeros.
+            sig += b'\000' * (self.sig_len() - len(sig))
+            return sig
+        else:
+            return sig
diff --git a/scripts/imgtool/main.py b/scripts/imgtool/main.py
index 2df06e1..eba557f 100755
--- a/scripts/imgtool/main.py
+++ b/scripts/imgtool/main.py
@@ -48,6 +48,10 @@
     keys.ECDSA256P1.generate().export_private(keyfile, passwd=passwd)
 
 
+def gen_ecdsa_p384(keyfile, passwd):
+    keys.ECDSA384P1.generate().export_private(keyfile, passwd=passwd)
+
+
 def gen_ed25519(keyfile, passwd):
     keys.Ed25519.generate().export_private(path=keyfile, passwd=passwd)
 
@@ -62,6 +66,7 @@
     'rsa-2048':   gen_rsa2048,
     'rsa-3072':   gen_rsa3072,
     'ecdsa-p256': gen_ecdsa_p256,
+    'ecdsa-p384': gen_ecdsa_p384,
     'ed25519':    gen_ed25519,
     'x25519':     gen_x25519,
 }
@@ -183,7 +188,7 @@
     elif ret == image.VerifyResult.INVALID_TLV_INFO_MAGIC:
         print("Invalid TLV info magic; is this an MCUboot image?")
     elif ret == image.VerifyResult.INVALID_HASH:
-        print("Image has an invalid sha256 digest")
+        print("Image has an invalid hash")
     elif ret == image.VerifyResult.INVALID_SIGNATURE:
         print("No signature found for the given key")
     else:
@@ -384,6 +389,8 @@
     if enckey and key:
         if ((isinstance(key, keys.ECDSA256P1) and
              not isinstance(enckey, keys.ECDSA256P1Public))
+           or (isinstance(key, keys.ECDSA384P1) and
+               not isinstance(enckey, keys.ECDSA384P1Public))
                 or (isinstance(key, keys.RSA) and
                     not isinstance(enckey, keys.RSAPublic))):
             # FIXME