imgtool: fix getpriv format handling for all key types

A previous change allowed the `getpriv` command to dump ec256 keys in
both the openssl and pkcs8 formats. That change did not update the other
key types, whose `get_private_bytes()` did not accept the new `format`
argument, so the command failed with RSA, X25519, etc. keys.

This commit generalizes how the `format` parameter is passed, so each
key type decides which formats its private key can be dumped in and
which one is used by default. The `getpriv --format` option therefore
no longer has a fixed default.

Fixes #1529

Signed-off-by: Fabio Utzig <utzig@apache.org>
diff --git a/scripts/imgtool/keys/ecdsa.py b/scripts/imgtool/keys/ecdsa.py
index 79e4bb8..addceb2 100644
--- a/scripts/imgtool/keys/ecdsa.py
+++ b/scripts/imgtool/keys/ecdsa.py
@@ -11,6 +11,7 @@
 from cryptography.hazmat.primitives.hashes import SHA256
 
 from .general import KeyClass
+from .privatebytes import PrivateBytesMixin
 
 
 class ECDSAUsageError(Exception):
@@ -41,7 +42,7 @@
                 encoding=serialization.Encoding.PEM,
                 format=serialization.PublicFormat.SubjectPublicKeyInfo)
 
-    def get_private_bytes(self, minimal):
+    def get_private_bytes(self, minimal, format):  # format is ignored
         self._unsupported('get_private_bytes')
 
     def export_private(self, path, passwd=None):
@@ -85,7 +86,7 @@
                         signature_algorithm=ec.ECDSA(SHA256()))
 
 
-class ECDSA256P1(ECDSA256P1Public):
+class ECDSA256P1(ECDSA256P1Public, PrivateBytesMixin):
     """
     Wrapper around an ECDSA private key.
     """
@@ -149,16 +150,18 @@
 
         return b
 
+    _VALID_FORMATS = {  # format name -> serialization.PrivateFormat
+        'pkcs8': serialization.PrivateFormat.PKCS8,
+        'openssl': serialization.PrivateFormat.TraditionalOpenSSL
+    }
+    _DEFAULT_FORMAT = 'pkcs8'  # used when no format is requested
+
     def get_private_bytes(self, minimal, format):
-        formats = {'pkcs8': serialization.PrivateFormat.PKCS8,
-                   'openssl': serialization.PrivateFormat.TraditionalOpenSSL
-                   }
-        priv = self.key.private_bytes(
-                encoding=serialization.Encoding.DER,
-                format=formats[format],
-                encryption_algorithm=serialization.NoEncryption())
+        format, priv = self._get_private_bytes(minimal, format,
+                                               ECDSAUsageError)
         if minimal:
-            priv = self._build_minimal_ecdsa_privkey(priv, formats[format])
+            priv = self._build_minimal_ecdsa_privkey(priv,
+                                                     self._VALID_FORMATS[format])
         return priv
 
     def export_private(self, path, passwd=None):
diff --git a/scripts/imgtool/keys/ed25519.py b/scripts/imgtool/keys/ed25519.py
index b6367e7..6ca7b10 100644
--- a/scripts/imgtool/keys/ed25519.py
+++ b/scripts/imgtool/keys/ed25519.py
@@ -34,7 +34,7 @@
                 encoding=serialization.Encoding.DER,
                 format=serialization.PublicFormat.SubjectPublicKeyInfo)
 
-    def get_private_bytes(self, minimal):
+    def get_private_bytes(self, minimal, format):  # format is ignored
         self._unsupported('get_private_bytes')
 
     def export_private(self, path, passwd=None):
@@ -75,7 +75,7 @@
     def _get_public(self):
         return self.key.public_key()
 
-    def get_private_bytes(self, minimal):
+    def get_private_bytes(self, minimal, format):  # format is ignored
         raise Ed25519UsageError("Operation not supported with {} keys".format(
             self.shortname()))
 
diff --git a/scripts/imgtool/keys/privatebytes.py b/scripts/imgtool/keys/privatebytes.py
new file mode 100644
index 0000000..8027ac8
--- /dev/null
+++ b/scripts/imgtool/keys/privatebytes.py
@@ -0,0 +1,35 @@
+# SPDX-License-Identifier: Apache-2.0
+
+from cryptography.hazmat.primitives import serialization
+
+
+class PrivateBytesMixin:
+    """Shared private-key serialization for imgtool key classes.
+
+    A class using this mixin must provide ``key`` (an object with a
+    ``private_bytes()`` method), ``shortname()``, and two class
+    attributes: ``_VALID_FORMATS``, mapping each accepted format name
+    to a ``serialization.PrivateFormat``, and ``_DEFAULT_FORMAT``, the
+    name used when no format is requested.
+    """
+
+    def _get_private_bytes(self, minimal, format, exclass):
+        """Return the private key as unencrypted DER.
+
+        ``minimal`` is accepted but not used here; callers that support
+        it post-process the returned bytes themselves. A ``format`` of
+        None selects ``_DEFAULT_FORMAT``. Raises ``exclass`` when the
+        format is not a key of ``_VALID_FORMATS``.
+
+        Returns a ``(format, der_bytes)`` tuple so callers can reuse the
+        resolved format name.
+        """
+        if format is None:
+            format = self._DEFAULT_FORMAT
+        if format not in self._VALID_FORMATS:
+            raise exclass("{} does not support {}".format(
+                self.shortname(), format))
+        return format, self.key.private_bytes(
+                encoding=serialization.Encoding.DER,
+                format=self._VALID_FORMATS[format],
+                encryption_algorithm=serialization.NoEncryption())
diff --git a/scripts/imgtool/keys/rsa.py b/scripts/imgtool/keys/rsa.py
index d51d36d..d4793c5 100644
--- a/scripts/imgtool/keys/rsa.py
+++ b/scripts/imgtool/keys/rsa.py
@@ -11,6 +11,7 @@
 from cryptography.hazmat.primitives.hashes import SHA256
 
 from .general import KeyClass
+from .privatebytes import PrivateBytesMixin
 
 
 # Sizes that bootutil will recognize
@@ -49,7 +50,7 @@
                 encoding=serialization.Encoding.PEM,
                 format=serialization.PublicFormat.SubjectPublicKeyInfo)
 
-    def get_private_bytes(self, minimal):
+    def get_private_bytes(self, minimal, format):  # format is ignored
         self._unsupported('get_private_bytes')
 
     def export_private(self, path, passwd=None):
@@ -81,7 +82,7 @@
                         algorithm=SHA256())
 
 
-class RSA(RSAPublic):
+class RSA(RSAPublic, PrivateBytesMixin):
     """
     Wrapper around an RSA key, with imgtool support.
     """
@@ -138,11 +139,13 @@
         b[3] = (off - 4) & 0xff
         return b[:off]
 
-    def get_private_bytes(self, minimal):
-        priv = self.key.private_bytes(
-                encoding=serialization.Encoding.DER,
-                format=serialization.PrivateFormat.TraditionalOpenSSL,
-                encryption_algorithm=serialization.NoEncryption())
+    _VALID_FORMATS = {  # format name -> serialization.PrivateFormat
+        'openssl': serialization.PrivateFormat.TraditionalOpenSSL
+    }
+    _DEFAULT_FORMAT = 'openssl'  # used when no format is requested
+
+    def get_private_bytes(self, minimal, format):
+        _, priv = self._get_private_bytes(minimal, format, RSAUsageError)
         if minimal:
             priv = self._build_minimal_rsa_privkey(priv)
         return priv
diff --git a/scripts/imgtool/keys/x25519.py b/scripts/imgtool/keys/x25519.py
index 1e0aadb..a99cf18 100644
--- a/scripts/imgtool/keys/x25519.py
+++ b/scripts/imgtool/keys/x25519.py
@@ -9,6 +9,7 @@
 from cryptography.hazmat.primitives.asymmetric import x25519
 
 from .general import KeyClass
+from .privatebytes import PrivateBytesMixin
 
 
 class X25519UsageError(Exception):
@@ -39,7 +40,7 @@
                 encoding=serialization.Encoding.PEM,
                 format=serialization.PublicFormat.SubjectPublicKeyInfo)
 
-    def get_private_bytes(self, minimal):
+    def get_private_bytes(self, minimal, format):  # format is ignored
         self._unsupported('get_private_bytes')
 
     def export_private(self, path, passwd=None):
@@ -63,7 +64,7 @@
         return 32
 
 
-class X25519(X25519Public):
+class X25519(X25519Public, PrivateBytesMixin):
     """
     Wrapper around an X25519 private key.
     """
@@ -80,11 +81,15 @@
     def _get_public(self):
         return self.key.public_key()
 
-    def get_private_bytes(self, minimal):
-        return self.key.private_bytes(
-            encoding=serialization.Encoding.DER,
-            format=serialization.PrivateFormat.PKCS8,
-            encryption_algorithm=serialization.NoEncryption())
+    _VALID_FORMATS = {  # format name -> serialization.PrivateFormat
+        'pkcs8': serialization.PrivateFormat.PKCS8
+    }
+    _DEFAULT_FORMAT = 'pkcs8'  # used when no format is requested
+
+    def get_private_bytes(self, minimal, format):
+        _, priv = self._get_private_bytes(minimal, format,
+                                          X25519UsageError)
+        return priv  # minimal has no effect for X25519
 
     def export_private(self, path, passwd=None):
         """
diff --git a/scripts/imgtool/main.py b/scripts/imgtool/main.py
index e485a1a..a246bea 100755
--- a/scripts/imgtool/main.py
+++ b/scripts/imgtool/main.py
@@ -157,8 +157,8 @@
 @click.option('-k', '--key', metavar='filename', required=True)
 @click.option('-f', '--format',
               type=click.Choice(valid_formats),
-              help='Valid formats: {}'.format(', '.join(valid_formats)),
-              default='pkcs8')
+              help='Valid formats: {} (default depends on the key type)'
+                   .format(', '.join(valid_formats)))
 @click.command(help='Dump private key from keypair')
 def getpriv(key, minimal, format):
     key = load_key(key)