Implement psa_generate_key_custom

Implement `psa_generate_key_custom()` and
`psa_key_derivation_output_key_custom()`. These functions replace
`psa_generate_key_ext()` and `psa_key_derivation_output_key_ext()`.
They have the same functionality, but a slightly different interface:
the `ext` functions use a structure with a flexible array member to pass
variable-length data, while the `custom` functions use a separate parameter.

Keep the `ext` functions for backward compatibility with Mbed TLS 3.6.0,
but make them thin wrappers around the new `custom` functions.

Duplicate the test code and data. The test cases have to be duplicated
anyway, and the test functions are individually more readable this way.

Signed-off-by: Gilles Peskine <Gilles.Peskine@arm.com>
diff --git a/library/psa_crypto.c b/library/psa_crypto.c
index 8100afc..a5b8557 100644
--- a/library/psa_crypto.c
+++ b/library/psa_crypto.c
@@ -6412,27 +6412,28 @@
     return status;
 }
 
-static const psa_key_production_parameters_t default_production_parameters =
+static const psa_custom_key_parameters_t default_custom_production =
     PSA_KEY_PRODUCTION_PARAMETERS_INIT;
 
 int psa_key_production_parameters_are_default(
-    const psa_key_production_parameters_t *params,
-    size_t params_data_length)
+    const psa_custom_key_parameters_t *custom,
+    size_t custom_data_length)
 {
-    if (params->flags != 0) {
+    if (custom->flags != 0) {
         return 0;
     }
-    if (params_data_length != 0) {
+    if (custom_data_length != 0) {
         return 0;
     }
     return 1;
 }
 
-psa_status_t psa_key_derivation_output_key_ext(
+psa_status_t psa_key_derivation_output_key_custom(
     const psa_key_attributes_t *attributes,
     psa_key_derivation_operation_t *operation,
-    const psa_key_production_parameters_t *params,
-    size_t params_data_length,
+    const psa_custom_key_parameters_t *custom,
+    const uint8_t *custom_data,
+    size_t custom_data_length,
     mbedtls_svc_key_id_t *key)
 {
     psa_status_t status;
@@ -6447,7 +6448,8 @@
         return PSA_ERROR_INVALID_ARGUMENT;
     }
 
-    if (!psa_key_production_parameters_are_default(params, params_data_length)) {
+    (void) custom_data;         /* We only accept 0-length data */
+    if (!psa_key_production_parameters_are_default(custom, custom_data_length)) {
         return PSA_ERROR_INVALID_ARGUMENT;
     }
 
@@ -6482,14 +6484,29 @@
     return status;
 }
 
+psa_status_t psa_key_derivation_output_key_ext(
+    const psa_key_attributes_t *attributes,
+    psa_key_derivation_operation_t *operation,
+    const psa_key_production_parameters_t *params,
+    size_t params_data_length,
+    mbedtls_svc_key_id_t *key)
+{
+    return psa_key_derivation_output_key_custom(
+        attributes, operation,
+        (const psa_custom_key_parameters_t *) params,
+        params->data, params_data_length,
+        key);
+}
+
 psa_status_t psa_key_derivation_output_key(
     const psa_key_attributes_t *attributes,
     psa_key_derivation_operation_t *operation,
     mbedtls_svc_key_id_t *key)
 {
-    return psa_key_derivation_output_key_ext(attributes, operation,
-                                             &default_production_parameters, 0,
-                                             key);
+    return psa_key_derivation_output_key_custom(attributes, operation,
+                                                &default_custom_production,
+                                                NULL, 0,
+                                                key);
 }
 
 
@@ -7863,15 +7880,18 @@
 
 psa_status_t psa_generate_key_internal(
     const psa_key_attributes_t *attributes,
-    const psa_key_production_parameters_t *params, size_t params_data_length,
+    const psa_custom_key_parameters_t *custom,
+    const uint8_t *custom_data,
+    size_t custom_data_length,
     uint8_t *key_buffer, size_t key_buffer_size, size_t *key_buffer_length)
 {
     psa_status_t status = PSA_ERROR_CORRUPTION_DETECTED;
     psa_key_type_t type = attributes->type;
 
     /* Only used for RSA */
-    (void) params;
-    (void) params_data_length;
+    (void) custom;
+    (void) custom_data;
+    (void) custom_data_length;
 
     if (key_type_is_raw_bytes(type)) {
         status = psa_generate_random_internal(key_buffer, key_buffer_size);
@@ -7889,7 +7909,7 @@
 #if defined(MBEDTLS_PSA_BUILTIN_KEY_TYPE_RSA_KEY_PAIR_GENERATE)
     if (type == PSA_KEY_TYPE_RSA_KEY_PAIR) {
         return mbedtls_psa_rsa_generate_key(attributes,
-                                            params, params_data_length,
+                                            custom_data, custom_data_length,
                                             key_buffer,
                                             key_buffer_size,
                                             key_buffer_length);
@@ -7921,10 +7941,11 @@
     return PSA_SUCCESS;
 }
 
-psa_status_t psa_generate_key_ext(const psa_key_attributes_t *attributes,
-                                  const psa_key_production_parameters_t *params,
-                                  size_t params_data_length,
-                                  mbedtls_svc_key_id_t *key)
+psa_status_t psa_generate_key_custom(const psa_key_attributes_t *attributes,
+                                     const psa_custom_key_parameters_t *custom,
+                                     const uint8_t *custom_data,
+                                     size_t custom_data_length,
+                                     mbedtls_svc_key_id_t *key)
 {
     psa_status_t status;
     psa_key_slot_t *slot = NULL;
@@ -7946,12 +7967,12 @@
 
 #if defined(PSA_WANT_KEY_TYPE_RSA_KEY_PAIR_GENERATE)
     if (attributes->type == PSA_KEY_TYPE_RSA_KEY_PAIR) {
-        if (params->flags != 0) {
+        if (custom->flags != 0) {
             return PSA_ERROR_INVALID_ARGUMENT;
         }
     } else
 #endif
-    if (!psa_key_production_parameters_are_default(params, params_data_length)) {
+    if (!psa_key_production_parameters_are_default(custom, custom_data_length)) {
         return PSA_ERROR_INVALID_ARGUMENT;
     }
 
@@ -7992,7 +8013,8 @@
     }
 
     status = psa_driver_wrapper_generate_key(attributes,
-                                             params, params_data_length,
+                                             custom,
+                                             custom_data, custom_data_length,
                                              slot->key.data, slot->key.bytes,
                                              &slot->key.bytes);
     if (status != PSA_SUCCESS) {
@@ -8010,12 +8032,25 @@
     return status;
 }
 
+psa_status_t psa_generate_key_ext(const psa_key_attributes_t *attributes,
+                                  const psa_key_production_parameters_t *params,
+                                  size_t params_data_length,
+                                  mbedtls_svc_key_id_t *key)
+{
+    return psa_generate_key_custom(
+        attributes,
+        (const psa_custom_key_parameters_t *) params,
+        params->data, params_data_length,
+        key);
+}
+
 psa_status_t psa_generate_key(const psa_key_attributes_t *attributes,
                               mbedtls_svc_key_id_t *key)
 {
-    return psa_generate_key_ext(attributes,
-                                &default_production_parameters, 0,
-                                key);
+    return psa_generate_key_custom(attributes,
+                                   &default_custom_production,
+                                   NULL, 0,
+                                   key);
 }
 
 /****************************************************************/
diff --git a/library/psa_crypto_core.h b/library/psa_crypto_core.h
index 9462d2e..45320aa 100644
--- a/library/psa_crypto_core.h
+++ b/library/psa_crypto_core.h
@@ -343,17 +343,18 @@
     const uint8_t *key_buffer, size_t key_buffer_size,
     uint8_t *data, size_t data_size, size_t *data_length);
 
-/** Whether a key production parameters structure is the default.
+/** Whether a key custom production parameters structure is the default.
  *
- * Calls to a key generation driver with non-default production parameters
+ * Calls to a key generation driver with non-default custom production parameters
  * require a driver supporting custom production parameters.
  *
- * \param[in] params            The key production parameters to check.
- * \param params_data_length    Size of `params->data` in bytes.
+ * \param[in] custom            The key custom production parameters to check.
+ * \param custom_data_length    Size of the associated variable-length data
+ *                              in bytes.
  */
 int psa_key_production_parameters_are_default(
-    const psa_key_production_parameters_t *params,
-    size_t params_data_length);
+    const psa_custom_key_parameters_t *custom,
+    size_t custom_data_length);
 
 /**
  * \brief Generate a key.
@@ -362,9 +363,10 @@
  *       entry point.
  *
  * \param[in]  attributes         The attributes for the key to generate.
- * \param[in]  params             The production parameters from
- *                                psa_generate_key_ext().
- * \param      params_data_length The size of `params->data` in bytes.
+ * \param[in] custom              Custom parameters for the key generation.
+ * \param[in] custom_data         Variable-length data associated with \c custom.
+ * \param custom_data_length
+ *                                Length of `custom_data` in bytes.
  * \param[out] key_buffer         Buffer where the key data is to be written.
  * \param[in]  key_buffer_size    Size of \p key_buffer in bytes.
  * \param[out] key_buffer_length  On success, the number of bytes written in
@@ -379,8 +381,9 @@
  *         The size of \p key_buffer is too small.
  */
 psa_status_t psa_generate_key_internal(const psa_key_attributes_t *attributes,
-                                       const psa_key_production_parameters_t *params,
-                                       size_t params_data_length,
+                                       const psa_custom_key_parameters_t *custom,
+                                       const uint8_t *custom_data,
+                                       size_t custom_data_length,
                                        uint8_t *key_buffer,
                                        size_t key_buffer_size,
                                        size_t *key_buffer_length);
diff --git a/library/psa_crypto_rsa.c b/library/psa_crypto_rsa.c
index 2f613b3..f8e36d8 100644
--- a/library/psa_crypto_rsa.c
+++ b/library/psa_crypto_rsa.c
@@ -241,7 +241,7 @@
 
 psa_status_t mbedtls_psa_rsa_generate_key(
     const psa_key_attributes_t *attributes,
-    const psa_key_production_parameters_t *params, size_t params_data_length,
+    const uint8_t *custom_data, size_t custom_data_length,
     uint8_t *key_buffer, size_t key_buffer_size, size_t *key_buffer_length)
 {
     psa_status_t status;
@@ -249,8 +249,8 @@
     int ret = MBEDTLS_ERR_ERROR_CORRUPTION_DETECTED;
     int exponent = 65537;
 
-    if (params_data_length != 0) {
-        status = psa_rsa_read_exponent(params->data, params_data_length,
+    if (custom_data_length != 0) {
+        status = psa_rsa_read_exponent(custom_data, custom_data_length,
                                        &exponent);
         if (status != PSA_SUCCESS) {
             return status;
diff --git a/library/psa_crypto_rsa.h b/library/psa_crypto_rsa.h
index ffeef26..134844b 100644
--- a/library/psa_crypto_rsa.h
+++ b/library/psa_crypto_rsa.h
@@ -105,17 +105,11 @@
 /**
  * \brief Generate an RSA key.
  *
- * \note The signature of the function is that of a PSA driver generate_key
- *       entry point.
- *
  * \param[in]  attributes         The attributes for the RSA key to generate.
- * \param[in]  params             Production parameters for the key
- *                                generation. This function only uses
- *                                `params->data`,
- *                                which contains the public exponent.
+ * \param[in]  custom             The public exponent to use.
  *                                This can be a null pointer if
  *                                \c params_data_length is 0.
- * \param params_data_length      Length of `params->data` in bytes.
+ * \param custom_data_length      Length of \p custom_data in bytes.
  *                                This can be 0, in which case the
  *                                public exponent will be 65537.
  * \param[out] key_buffer         Buffer where the key data is to be written.
@@ -132,7 +126,7 @@
  */
 psa_status_t mbedtls_psa_rsa_generate_key(
     const psa_key_attributes_t *attributes,
-    const psa_key_production_parameters_t *params, size_t params_data_length,
+    const uint8_t *custom, size_t custom_data_length,
     uint8_t *key_buffer, size_t key_buffer_size, size_t *key_buffer_length);
 
 /** Sign an already-calculated hash with an RSA private key.