Add RSA-3072 support to imgtool

Signed-off-by: Fabio Utzig <utzig@apache.org>
diff --git a/scripts/imgtool/image.py b/scripts/imgtool/image.py
index c36f802..ad156a1 100644
--- a/scripts/imgtool/image.py
+++ b/scripts/imgtool/image.py
@@ -49,6 +49,7 @@
         'RSA2048': 0x20,
         'ECDSA224': 0x21,
         'ECDSA256': 0x22,
+        'RSA3072': 0x23,
         'ENCRSA2048': 0x30,
         'ENCKW128': 0x31,
         'DEPENDENCY': 0x40
diff --git a/scripts/imgtool/keys/__init__.py b/scripts/imgtool/keys/__init__.py
index da5b083..b92f871 100644
--- a/scripts/imgtool/keys/__init__.py
+++ b/scripts/imgtool/keys/__init__.py
@@ -21,7 +21,7 @@
 from cryptography.hazmat.primitives.asymmetric.rsa import RSAPrivateKey, RSAPublicKey
 from cryptography.hazmat.primitives.asymmetric.ec import EllipticCurvePrivateKey, EllipticCurvePublicKey
 
-from .rsa import RSA2048, RSA2048Public, RSAUsageError
+from .rsa import RSA, RSAPublic, RSAUsageError, RSA_KEY_SIZES
 from .ecdsa import ECDSA256P1, ECDSA256P1Public, ECDSAUsageError
 
 class PasswordRequired(Exception):
@@ -53,13 +53,13 @@
                 backend=default_backend())
 
     if isinstance(pk, RSAPrivateKey):
-        if pk.key_size != 2048:
-            raise Exception("Unsupported RSA key size: " + pk.key_size)
-        return RSA2048(pk)
+        if pk.key_size not in RSA_KEY_SIZES:
+            raise Exception("Unsupported RSA key size: " + str(pk.key_size))
+        return RSA(pk)
     elif isinstance(pk, RSAPublicKey):
-        if pk.key_size != 2048:
-            raise Exception("Unsupported RSA key size: " + pk.key_size)
-        return RSA2048Public(pk)
+        if pk.key_size not in RSA_KEY_SIZES:
+            raise Exception("Unsupported RSA key size: " + str(pk.key_size))
+        return RSAPublic(pk)
     elif isinstance(pk, EllipticCurvePrivateKey):
         if pk.curve.name != 'secp256r1':
             raise Exception("Unsupported EC curve: " + pk.curve.name)
diff --git a/scripts/imgtool/keys/rsa.py b/scripts/imgtool/keys/rsa.py
index 4ddbfc6..94af064 100644
--- a/scripts/imgtool/keys/rsa.py
+++ b/scripts/imgtool/keys/rsa.py
@@ -10,14 +10,23 @@
 
 from .general import KeyClass
 
+
+# Sizes that bootutil will recognize
+RSA_KEY_SIZES = [2048, 3072]
+
+
 class RSAUsageError(Exception):
     pass
 
-class RSA2048Public(KeyClass):
+
+class RSAPublic(KeyClass):
     """The public key can only do a few operations"""
     def __init__(self, key):
         self.key = key
 
+    def key_size(self):
+        return self.key.key_size
+
     def shortname(self):
         return "rsa"
 
@@ -45,17 +54,18 @@
             f.write(pem)
 
     def sig_type(self):
-        return "PKCS1_PSS_RSA2048_SHA256"
+        return "PKCS1_PSS_RSA{}_SHA256".format(self.key_size())
 
     def sig_tlv(self):
-        return "RSA2048"
+        return "RSA{}".format(self.key_size())
 
     def sig_len(self):
-        return 256
+        return self.key_size() // 8  # bytes; floor division keeps it an int
 
-class RSA2048(RSA2048Public):
+
+class RSA(RSAPublic):
     """
-    Wrapper around an 2048-bit RSA key, with imgtool support.
+    Wrapper around an RSA key, with imgtool support.
     """
 
     def __init__(self, key):
@@ -63,18 +73,22 @@
         self.key = key
 
     @staticmethod
-    def generate():
+    def generate(key_size=2048):
+        if key_size not in RSA_KEY_SIZES:
+            raise RSAUsageError("Key size {} is not supported by MCUboot"
+                                .format(key_size))
         pk = rsa.generate_private_key(
                 public_exponent=65537,
-                key_size=2048,
+                key_size=key_size,
                 backend=default_backend())
-        return RSA2048(pk)
+        return RSA(pk)
 
     def _get_public(self):
         return self.key.public_key()
 
     def export_private(self, path, passwd=None):
-        """Write the private key to the given file, protecting it with the optional password."""
+        """Write the private key to the given file, protecting it with the
+        optional password."""
         if passwd is None:
             enc = serialization.NoEncryption()
         else:
diff --git a/scripts/imgtool/keys/rsa_test.py b/scripts/imgtool/keys/rsa_test.py
index 8151878..b01635d 100644
--- a/scripts/imgtool/keys/rsa_test.py
+++ b/scripts/imgtool/keys/rsa_test.py
@@ -13,9 +13,12 @@
 from cryptography.hazmat.primitives.hashes import SHA256
 
 # Setup sys path so 'imgtool' is in it.
-sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '../..')))
+sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__),
+                                                '../..')))
 
-from imgtool.keys import load, RSA2048, RSAUsageError
+from imgtool.keys import load, RSA, RSAUsageError
+from imgtool.keys.rsa import RSA_KEY_SIZES
+
 
 class KeyGeneration(unittest.TestCase):
 
@@ -29,74 +32,84 @@
         self.test_dir.cleanup()
 
     def test_keygen(self):
-        name1 = self.tname("keygen.pem")
-        k = RSA2048.generate()
-        k.export_private(name1, b'secret')
+        # Try generating an RSA key with an unsupported size
+        with self.assertRaises(RSAUsageError):
+            RSA.generate(key_size=1024)
 
-        # Try loading the key without a password.
-        self.assertIsNone(load(name1))
+        for key_size in RSA_KEY_SIZES:
+            name1 = self.tname("keygen.pem")
+            k = RSA.generate(key_size=key_size)
+            k.export_private(name1, b'secret')
 
-        k2 = load(name1, b'secret')
+            # Try loading the key without a password.
+            self.assertIsNone(load(name1))
 
-        pubname = self.tname('keygen-pub.pem')
-        k2.export_public(pubname)
-        pk2 = load(pubname)
+            k2 = load(name1, b'secret')
 
-        # We should be able to export the public key from the loaded
-        # public key, but not the private key.
-        pk2.export_public(self.tname('keygen-pub2.pem'))
-        self.assertRaises(RSAUsageError, pk2.export_private, self.tname('keygen-priv2.pem'))
+            pubname = self.tname('keygen-pub.pem')
+            k2.export_public(pubname)
+            pk2 = load(pubname)
+
+            # We should be able to export the public key from the loaded
+            # public key, but not the private key.
+            pk2.export_public(self.tname('keygen-pub2.pem'))
+            self.assertRaises(RSAUsageError, pk2.export_private,
+                              self.tname('keygen-priv2.pem'))
 
     def test_emit(self):
         """Basic sanity check on the code emitters."""
-        k = RSA2048.generate()
+        for key_size in RSA_KEY_SIZES:
+            k = RSA.generate(key_size=key_size)
 
-        ccode = io.StringIO()
-        k.emit_c(ccode)
-        self.assertIn("rsa_pub_key", ccode.getvalue())
-        self.assertIn("rsa_pub_key_len", ccode.getvalue())
+            ccode = io.StringIO()
+            k.emit_c(ccode)
+            self.assertIn("rsa_pub_key", ccode.getvalue())
+            self.assertIn("rsa_pub_key_len", ccode.getvalue())
 
-        rustcode = io.StringIO()
-        k.emit_rust(rustcode)
-        self.assertIn("RSA_PUB_KEY", rustcode.getvalue())
+            rustcode = io.StringIO()
+            k.emit_rust(rustcode)
+            self.assertIn("RSA_PUB_KEY", rustcode.getvalue())
 
     def test_emit_pub(self):
         """Basic sanity check on the code emitters, from public key."""
         pubname = self.tname("public.pem")
-        k = RSA2048.generate()
-        k.export_public(pubname)
+        for key_size in RSA_KEY_SIZES:
+            k = RSA.generate(key_size=key_size)
+            k.export_public(pubname)
 
-        k2 = load(pubname)
+            k2 = load(pubname)
 
-        ccode = io.StringIO()
-        k2.emit_c(ccode)
-        self.assertIn("rsa_pub_key", ccode.getvalue())
-        self.assertIn("rsa_pub_key_len", ccode.getvalue())
+            ccode = io.StringIO()
+            k2.emit_c(ccode)
+            self.assertIn("rsa_pub_key", ccode.getvalue())
+            self.assertIn("rsa_pub_key_len", ccode.getvalue())
 
-        rustcode = io.StringIO()
-        k2.emit_rust(rustcode)
-        self.assertIn("RSA_PUB_KEY", rustcode.getvalue())
+            rustcode = io.StringIO()
+            k2.emit_rust(rustcode)
+            self.assertIn("RSA_PUB_KEY", rustcode.getvalue())
 
     def test_sig(self):
-        k = RSA2048.generate()
-        buf = b'This is the message'
-        sig = k.sign(buf)
+        for key_size in RSA_KEY_SIZES:
+            k = RSA.generate(key_size=key_size)
+            buf = b'This is the message'
+            sig = k.sign(buf)
 
-        # The code doesn't have any verification, so verify this
-        # manually.
-        k.key.public_key().verify(
+            # The code doesn't have any verification, so verify this
+            # manually.
+            k.key.public_key().verify(
                 signature=sig,
                 data=buf,
                 padding=PSS(mgf=MGF1(SHA256()), salt_length=32),
                 algorithm=SHA256())
 
-        # Modify the message to make sure the signature fails.
-        self.assertRaises(InvalidSignature,
-                k.key.public_key().verify,
-                signature=sig,
-                data=b'This is thE message',
-                padding=PSS(mgf=MGF1(SHA256()), salt_length=32),
-                algorithm=SHA256())
+            # Modify the message to make sure the signature fails.
+            self.assertRaises(InvalidSignature,
+                              k.key.public_key().verify,
+                              signature=sig,
+                              data=b'This is thE message',
+                              padding=PSS(mgf=MGF1(SHA256()), salt_length=32),
+                              algorithm=SHA256())
+
 
 if __name__ == '__main__':
     unittest.main()
diff --git a/scripts/imgtool/main.py b/scripts/imgtool/main.py
index f26592d..cb204b0 100755
--- a/scripts/imgtool/main.py
+++ b/scripts/imgtool/main.py
@@ -24,7 +24,12 @@
 
 
 def gen_rsa2048(keyfile, passwd):
-    keys.RSA2048.generate().export_private(path=keyfile, passwd=passwd)
+    keys.RSA.generate().export_private(path=keyfile, passwd=passwd)
+
+
+def gen_rsa3072(keyfile, passwd):
+    keys.RSA.generate(key_size=3072).export_private(path=keyfile,
+                                                    passwd=passwd)
 
 
 def gen_ecdsa_p256(keyfile, passwd):
@@ -38,6 +43,7 @@
 valid_langs = ['c', 'rust']
 keygens = {
     'rsa-2048':   gen_rsa2048,
+    'rsa-3072':   gen_rsa3072,
     'ecdsa-p256': gen_ecdsa_p256,
     'ecdsa-p224': gen_ecdsa_p224,
 }
@@ -184,9 +190,9 @@
     key = load_key(key) if key else None
     enckey = load_key(encrypt) if encrypt else None
     if enckey:
-        if not isinstance(enckey, (keys.RSA2048, keys.RSA2048Public)):
+        if not isinstance(enckey, (keys.RSA, keys.RSAPublic)):
             raise Exception("Encryption only available with RSA key")
-        if key and not isinstance(key, keys.RSA2048):
+        if key and not isinstance(key, keys.RSA):
             raise Exception("Signing only available with private RSA key")
     img.create(key, enckey, dependencies)
     img.save(outfile)