psa_generate_key_ext: RSA: support custom public exponent

Signed-off-by: Gilles Peskine <Gilles.Peskine@arm.com>
diff --git a/ChangeLog.d/psa_generate_key_ext.txt b/ChangeLog.d/psa_generate_key_ext.txt
new file mode 100644
index 0000000..8340f01
--- /dev/null
+++ b/ChangeLog.d/psa_generate_key_ext.txt
@@ -0,0 +1,3 @@
+Features
+   * The new function psa_generate_key_ext() allows generating an RSA
+     key pair with a custom public exponent.
diff --git a/library/psa_crypto.c b/library/psa_crypto.c
index 6263be9..98823de 100644
--- a/library/psa_crypto.c
+++ b/library/psa_crypto.c
@@ -7501,11 +7501,16 @@
 
 psa_status_t psa_generate_key_internal(
     const psa_key_attributes_t *attributes,
+    const psa_key_generation_method_t *method, size_t method_length,
     uint8_t *key_buffer, size_t key_buffer_size, size_t *key_buffer_length)
 {
     psa_status_t status = PSA_ERROR_CORRUPTION_DETECTED;
     psa_key_type_t type = attributes->core.type;
 
+    /* Only used for RSA */
+    (void) method;
+    (void) method_length;
+
     if ((attributes->domain_parameters == NULL) &&
         (attributes->domain_parameters_size != 0)) {
         return PSA_ERROR_INVALID_ARGUMENT;
@@ -7526,7 +7531,17 @@
 
 #if defined(MBEDTLS_PSA_BUILTIN_KEY_TYPE_RSA_KEY_PAIR_GENERATE)
     if (type == PSA_KEY_TYPE_RSA_KEY_PAIR) {
-        return mbedtls_psa_rsa_generate_key(attributes,
+        /* Hack: if the method specifies a non-default e, pass it
+         * via the domain parameters. TODO: refactor this code so
+         * that mbedtls_psa_rsa_generate_key() gets e via a new
+         * parameter instead. */
+        psa_key_attributes_t override_attributes = *attributes;
+        if (method_length > sizeof(*method)) {
+            override_attributes.domain_parameters_size =
+                method_length - offsetof(psa_key_generation_method_t, data);
+            override_attributes.domain_parameters = (uint8_t *) &method->data;
+        }
+        return mbedtls_psa_rsa_generate_key(&override_attributes,
                                             key_buffer,
                                             key_buffer_size,
                                             key_buffer_length);
@@ -7584,6 +7599,14 @@
     if (method_length < sizeof(*method)) {
         return PSA_ERROR_INVALID_ARGUMENT;
     }
+
+#if defined(PSA_WANT_KEY_TYPE_RSA_KEY_PAIR_GENERATE)
+    if (attributes->core.type == PSA_KEY_TYPE_RSA_KEY_PAIR) {
+        if (method->flags != 0) {
+            return PSA_ERROR_INVALID_ARGUMENT;
+        }
+    } else
+#endif
     if (!psa_key_generation_method_is_default(method, method_length)) {
         return PSA_ERROR_INVALID_ARGUMENT;
     }
@@ -7625,8 +7648,9 @@
     }
 
     status = psa_driver_wrapper_generate_key(attributes,
-                                             slot->key.data, slot->key.bytes, &slot->key.bytes);
-
+                                             method, method_length,
+                                             slot->key.data, slot->key.bytes,
+                                             &slot->key.bytes);
     if (status != PSA_SUCCESS) {
         psa_remove_key_data_from_memory(slot);
     }
diff --git a/library/psa_crypto_core.h b/library/psa_crypto_core.h
index dc376d7..99aea9d 100644
--- a/library/psa_crypto_core.h
+++ b/library/psa_crypto_core.h
@@ -403,6 +403,10 @@
  *       entry point.
  *
  * \param[in]  attributes         The attributes for the key to generate.
+ * \param[in]  method             The generation method from
+ *                                psa_generate_key_ext().
+ *                                This can be \c NULL if \p method_length is 0.
+ * \param      method_length      The size of \p method in bytes.
  * \param[out] key_buffer         Buffer where the key data is to be written.
  * \param[in]  key_buffer_size    Size of \p key_buffer in bytes.
  * \param[out] key_buffer_length  On success, the number of bytes written in
@@ -417,6 +421,8 @@
  *         The size of \p key_buffer is too small.
  */
 psa_status_t psa_generate_key_internal(const psa_key_attributes_t *attributes,
+                                       const psa_key_generation_method_t *method,
+                                       size_t method_length,
                                        uint8_t *key_buffer,
                                        size_t key_buffer_size,
                                        size_t *key_buffer_length);
diff --git a/scripts/data_files/driver_templates/psa_crypto_driver_wrappers.h.jinja b/scripts/data_files/driver_templates/psa_crypto_driver_wrappers.h.jinja
index 924b08c..e6d827b 100644
--- a/scripts/data_files/driver_templates/psa_crypto_driver_wrappers.h.jinja
+++ b/scripts/data_files/driver_templates/psa_crypto_driver_wrappers.h.jinja
@@ -731,12 +731,16 @@
 
 static inline psa_status_t psa_driver_wrapper_generate_key(
     const psa_key_attributes_t *attributes,
+    const psa_key_generation_method_t *method, size_t method_length,
     uint8_t *key_buffer, size_t key_buffer_size, size_t *key_buffer_length )
 {
     psa_status_t status = PSA_ERROR_CORRUPTION_DETECTED;
     psa_key_location_t location =
         PSA_KEY_LIFETIME_GET_LOCATION(attributes->core.lifetime);
 
+    /* TODO: if method is non-default, we need a driver that supports
+     * passing a method. */
+
     /* Try dynamically-registered SE interface first */
 #if defined(MBEDTLS_PSA_CRYPTO_SE_C)
     const psa_drv_se_t *drv;
@@ -793,7 +797,8 @@
 
             /* Software fallback */
             status = psa_generate_key_internal(
-                attributes, key_buffer, key_buffer_size, key_buffer_length );
+                attributes, method, method_length,
+                key_buffer, key_buffer_size, key_buffer_length );
             break;
 
         /* Add cases for opaque driver here */
diff --git a/tests/suites/test_suite_psa_crypto.data b/tests/suites/test_suite_psa_crypto.data
index bc6ca1b..2d0cf2c 100644
--- a/tests/suites/test_suite_psa_crypto.data
+++ b/tests/suites/test_suite_psa_crypto.data
@@ -7439,22 +7439,22 @@
 depends_on:PSA_WANT_ALG_ECDH:PSA_WANT_KEY_TYPE_ECC_KEY_PAIR_GENERATE:PSA_WANT_ECC_MONTGOMERY_448
 generate_key:PSA_KEY_TYPE_ECC_KEY_PAIR(PSA_ECC_FAMILY_MONTGOMERY):448:PSA_KEY_USAGE_EXPORT | PSA_KEY_USAGE_DERIVE:PSA_ALG_ECDH:PSA_SUCCESS:0
 
-PSA generate key: RSA, default e
+PSA generate key: RSA, domain parameters: default e
 generate_key_rsa:PSA_VENDOR_RSA_GENERATE_MIN_KEY_BITS:"":PSA_SUCCESS
 
-PSA generate key: RSA, e=3
+PSA generate key: RSA, domain parameters: e=3
 generate_key_rsa:PSA_VENDOR_RSA_GENERATE_MIN_KEY_BITS:"03":PSA_SUCCESS
 
-PSA generate key: RSA, e=65537
+PSA generate key: RSA, domain parameters: e=65537
 generate_key_rsa:PSA_VENDOR_RSA_GENERATE_MIN_KEY_BITS:"010001":PSA_SUCCESS
 
-PSA generate key: RSA, e=513
+PSA generate key: RSA, domain parameters: e=513
 generate_key_rsa:PSA_VENDOR_RSA_GENERATE_MIN_KEY_BITS:"0201":PSA_SUCCESS
 
-PSA generate key: RSA, e=1
+PSA generate key: RSA, domain parameters: e=1
 generate_key_rsa:PSA_VENDOR_RSA_GENERATE_MIN_KEY_BITS:"01":PSA_ERROR_INVALID_ARGUMENT
 
-PSA generate key: RSA, e=2
+PSA generate key: RSA, domain parameters: e=2
 generate_key_rsa:PSA_VENDOR_RSA_GENERATE_MIN_KEY_BITS:"02":PSA_ERROR_INVALID_ARGUMENT
 
 PSA generate key: FFDH, 2048 bits, good
@@ -7482,13 +7482,68 @@
 generate_key:PSA_KEY_TYPE_DH_KEY_PAIR(PSA_DH_FAMILY_RFC7919):1024:PSA_KEY_USAGE_EXPORT:PSA_ALG_FFDH:PSA_ERROR_NOT_SUPPORTED:0
 
 PSA generate key ext: RSA, null method
-generate_key_ext:PSA_KEY_TYPE_RSA_KEY_PAIR:PSA_VENDOR_RSA_GENERATE_MIN_KEY_BITS:PSA_KEY_USAGE_ENCRYPT | PSA_KEY_USAGE_DECRYPT:PSA_ALG_RSA_PKCS1V15_CRYPT:-offsetof(psa_key_generation_method_t, data):"":PSA_ERROR_INVALID_ARGUMENT
+depends_on:PSA_WANT_KEY_TYPE_RSA_KEY_PAIR_GENERATE
+generate_key_ext:PSA_KEY_TYPE_RSA_KEY_PAIR:PSA_VENDOR_RSA_GENERATE_MIN_KEY_BITS:PSA_KEY_USAGE_ENCRYPT | PSA_KEY_USAGE_DECRYPT:0:-offsetof(psa_key_generation_method_t, data):"":PSA_ERROR_INVALID_ARGUMENT
 
 PSA generate key ext: RSA, method too short by 1
-generate_key_ext:PSA_KEY_TYPE_RSA_KEY_PAIR:PSA_VENDOR_RSA_GENERATE_MIN_KEY_BITS:PSA_KEY_USAGE_ENCRYPT | PSA_KEY_USAGE_DECRYPT:PSA_ALG_RSA_PKCS1V15_CRYPT:-1:"":PSA_ERROR_INVALID_ARGUMENT
+depends_on:PSA_WANT_KEY_TYPE_RSA_KEY_PAIR_GENERATE
+generate_key_ext:PSA_KEY_TYPE_RSA_KEY_PAIR:PSA_VENDOR_RSA_GENERATE_MIN_KEY_BITS:PSA_KEY_USAGE_ENCRYPT | PSA_KEY_USAGE_DECRYPT:0:-1:"":PSA_ERROR_INVALID_ARGUMENT
 
 PSA generate key ext: RSA, method.flags=1
-generate_key_ext:PSA_KEY_TYPE_RSA_KEY_PAIR:PSA_VENDOR_RSA_GENERATE_MIN_KEY_BITS:PSA_KEY_USAGE_ENCRYPT | PSA_KEY_USAGE_DECRYPT:PSA_ALG_RSA_PKCS1V15_CRYPT:1:"":PSA_ERROR_INVALID_ARGUMENT
+depends_on:PSA_WANT_KEY_TYPE_RSA_KEY_PAIR_GENERATE
+generate_key_ext:PSA_KEY_TYPE_RSA_KEY_PAIR:PSA_VENDOR_RSA_GENERATE_MIN_KEY_BITS:PSA_KEY_USAGE_ENCRYPT | PSA_KEY_USAGE_DECRYPT:0:1:"":PSA_ERROR_INVALID_ARGUMENT
+
+PSA generate key ext: RSA, empty e
+depends_on:PSA_WANT_KEY_TYPE_RSA_KEY_PAIR_GENERATE:PSA_WANT_ALG_RSA_PKCS1V15_CRYPT
+generate_key_ext:PSA_KEY_TYPE_RSA_KEY_PAIR:PSA_VENDOR_RSA_GENERATE_MIN_KEY_BITS:PSA_KEY_USAGE_ENCRYPT | PSA_KEY_USAGE_DECRYPT:PSA_ALG_RSA_PKCS1V15_CRYPT:0:"":PSA_SUCCESS
+
+PSA generate key ext: RSA, e=3
+depends_on:PSA_WANT_KEY_TYPE_RSA_KEY_PAIR_GENERATE:PSA_WANT_ALG_RSA_PKCS1V15_CRYPT
+generate_key_ext:PSA_KEY_TYPE_RSA_KEY_PAIR:PSA_VENDOR_RSA_GENERATE_MIN_KEY_BITS:PSA_KEY_USAGE_ENCRYPT | PSA_KEY_USAGE_DECRYPT:PSA_ALG_RSA_PKCS1V15_CRYPT:0:"03":PSA_SUCCESS
+
+PSA generate key ext: RSA, e=3 with leading zeros
+depends_on:PSA_WANT_KEY_TYPE_RSA_KEY_PAIR_GENERATE:PSA_WANT_ALG_RSA_PKCS1V15_CRYPT
+generate_key_ext:PSA_KEY_TYPE_RSA_KEY_PAIR:PSA_VENDOR_RSA_GENERATE_MIN_KEY_BITS:PSA_KEY_USAGE_ENCRYPT | PSA_KEY_USAGE_DECRYPT:PSA_ALG_RSA_PKCS1V15_CRYPT:0:"000003":PSA_SUCCESS
+
+# TODO: currently errors with NOT_SUPPORTED because e is converted to an int
+# and the conversion errors out if there are too many digits without checking
+# for leading zeros. This is a very minor bug. Re-enable this test when this
+# bug is fixed.
+#PSA generate key ext: RSA, e=3 with many leading zeros
+#depends_on:PSA_WANT_KEY_TYPE_RSA_KEY_PAIR_GENERATE:PSA_WANT_ALG_RSA_PKCS1V15_CRYPT
+#generate_key_ext:PSA_KEY_TYPE_RSA_KEY_PAIR:PSA_VENDOR_RSA_GENERATE_MIN_KEY_BITS:PSA_KEY_USAGE_ENCRYPT | PSA_KEY_USAGE_DECRYPT:PSA_ALG_RSA_PKCS1V15_CRYPT:0:"0000000000000000000000000000000003":PSA_SUCCESS
+
+PSA generate key ext: RSA, e=513
+depends_on:PSA_WANT_KEY_TYPE_RSA_KEY_PAIR_GENERATE:PSA_WANT_ALG_RSA_PKCS1V15_CRYPT
+generate_key_ext:PSA_KEY_TYPE_RSA_KEY_PAIR:PSA_VENDOR_RSA_GENERATE_MIN_KEY_BITS:PSA_KEY_USAGE_ENCRYPT | PSA_KEY_USAGE_DECRYPT:PSA_ALG_RSA_PKCS1V15_CRYPT:0:"0201":PSA_SUCCESS
+
+PSA generate key ext: RSA, e=65537
+depends_on:PSA_WANT_KEY_TYPE_RSA_KEY_PAIR_GENERATE:PSA_WANT_ALG_RSA_PKCS1V15_CRYPT
+generate_key_ext:PSA_KEY_TYPE_RSA_KEY_PAIR:PSA_VENDOR_RSA_GENERATE_MIN_KEY_BITS:PSA_KEY_USAGE_ENCRYPT | PSA_KEY_USAGE_DECRYPT:PSA_ALG_RSA_PKCS1V15_CRYPT:0:"010001":PSA_SUCCESS
+
+PSA generate key ext: RSA, e=2^31-1
+depends_on:PSA_WANT_KEY_TYPE_RSA_KEY_PAIR_GENERATE:PSA_WANT_ALG_RSA_PKCS1V15_CRYPT:INT_MAX>=0x7fffffff
+generate_key_ext:PSA_KEY_TYPE_RSA_KEY_PAIR:PSA_VENDOR_RSA_GENERATE_MIN_KEY_BITS:PSA_KEY_USAGE_ENCRYPT | PSA_KEY_USAGE_DECRYPT:PSA_ALG_RSA_PKCS1V15_CRYPT:0:"7fffffff":PSA_SUCCESS
+
+PSA generate key ext: RSA, e=2^31+3 (too large for built-in RSA)
+depends_on:PSA_WANT_KEY_TYPE_RSA_KEY_PAIR_GENERATE:!MBEDTLS_PSA_ACCEL_KEY_TYPE_RSA_KEY_PAIR_GENERATE:INT_MAX<=0x7fffffff
+generate_key_ext:PSA_KEY_TYPE_RSA_KEY_PAIR:PSA_VENDOR_RSA_GENERATE_MIN_KEY_BITS:PSA_KEY_USAGE_ENCRYPT | PSA_KEY_USAGE_DECRYPT:0:0:"80000003":PSA_ERROR_NOT_SUPPORTED
+
+PSA generate key ext: RSA, e=2^64+3 (too large for built-in RSA)
+depends_on:PSA_WANT_KEY_TYPE_RSA_KEY_PAIR_GENERATE:!MBEDTLS_PSA_ACCEL_KEY_TYPE_RSA_KEY_PAIR_GENERATE:INT_MAX<=0xffffffffffffffff
+generate_key_ext:PSA_KEY_TYPE_RSA_KEY_PAIR:PSA_VENDOR_RSA_GENERATE_MIN_KEY_BITS:PSA_KEY_USAGE_ENCRYPT | PSA_KEY_USAGE_DECRYPT:0:0:"010000000000000003":PSA_ERROR_NOT_SUPPORTED
+
+PSA generate key ext: RSA, e=1
+depends_on:PSA_WANT_KEY_TYPE_RSA_KEY_PAIR_GENERATE
+generate_key_ext:PSA_KEY_TYPE_RSA_KEY_PAIR:PSA_VENDOR_RSA_GENERATE_MIN_KEY_BITS:PSA_KEY_USAGE_ENCRYPT | PSA_KEY_USAGE_DECRYPT:0:0:"01":PSA_ERROR_INVALID_ARGUMENT
+
+PSA generate key ext: RSA, e=0
+depends_on:PSA_WANT_KEY_TYPE_RSA_KEY_PAIR_GENERATE
+generate_key_ext:PSA_KEY_TYPE_RSA_KEY_PAIR:PSA_VENDOR_RSA_GENERATE_MIN_KEY_BITS:PSA_KEY_USAGE_ENCRYPT | PSA_KEY_USAGE_DECRYPT:0:0:"00":PSA_ERROR_INVALID_ARGUMENT
+
+PSA generate key ext: RSA, e=2
+depends_on:PSA_WANT_KEY_TYPE_RSA_KEY_PAIR_GENERATE
+generate_key_ext:PSA_KEY_TYPE_RSA_KEY_PAIR:PSA_VENDOR_RSA_GENERATE_MIN_KEY_BITS:PSA_KEY_USAGE_ENCRYPT | PSA_KEY_USAGE_DECRYPT:0:0:"02":PSA_ERROR_INVALID_ARGUMENT
 
 PSA generate key ext: ECC, null method
 depends_on:PSA_WANT_KEY_TYPE_ECC_KEY_PAIR_GENERATE:PSA_WANT_ECC_SECP_R1_256:PSA_WANT_ALG_ECDH
diff --git a/tests/suites/test_suite_psa_crypto.function b/tests/suites/test_suite_psa_crypto.function
index 05b49bd..34baa1b 100644
--- a/tests/suites/test_suite_psa_crypto.function
+++ b/tests/suites/test_suite_psa_crypto.function
@@ -1293,7 +1293,13 @@
         TEST_EQUAL(p[1], 0);
         TEST_EQUAL(p[2], 1);
     } else {
-        TEST_MEMORY_COMPARE(p, len, e_arg->x, e_arg->len);
+        const uint8_t *expected = e_arg->x;
+        size_t expected_len = e_arg->len;
+        while (expected_len > 0 && *expected == 0) {
+            ++expected;
+            --expected_len;
+        }
+        TEST_MEMORY_COMPARE(p, len, expected, expected_len);
     }
     ok = 1;
 
@@ -9960,6 +9966,12 @@
     TEST_EQUAL(psa_get_key_type(&got_attributes), type);
     TEST_EQUAL(psa_get_key_bits(&got_attributes), bits);
 
+#if defined(PSA_WANT_KEY_TYPE_RSA_KEY_PAIR_GENERATE)
+    if (type == PSA_KEY_TYPE_RSA_KEY_PAIR) {
+        TEST_ASSERT(rsa_test_e(key, bits, method_data));
+    }
+#endif
+
     /* Do something with the key according to its type and permitted usage. */
     if (!mbedtls_test_psa_exercise_key(key, usage, alg)) {
         goto exit;