Key derivation: allow both keys and direct inputs

Allow a direct input as the SECRET input step in a key derivation, in
addition to allowing DERIVE keys. This makes it easier for
applications to run a key derivation where the "secret" input is
obtained from somewhere else. This makes it possible for the "secret"
input to be empty (keys cannot be empty), which some protocols do (for
example the IV derivation in EAP-TLS).

Conversely, allow a RAW_DATA key as the INFO/LABEL/SALT/SEED input to a key
derivation, in addition to allowing direct inputs. This doesn't
improve security, but removes a step when a personalization parameter
is stored in the key store, and allows this personalization parameter
to remain opaque.

Add test cases that explore step/key-type-and-keyhood combinations.
diff --git a/include/psa/crypto.h b/include/psa/crypto.h
index 9c303cb..ddc86cd 100644
--- a/include/psa/crypto.h
+++ b/include/psa/crypto.h
@@ -3298,7 +3298,8 @@
  * \retval #PSA_ERROR_INVALID_ARGUMENT
  *         \c step is not compatible with the operation's algorithm.
  * \retval #PSA_ERROR_INVALID_ARGUMENT
- *         \c step does not allow key inputs.
+ *         \c step does not allow key inputs of the given type
+ *         or does not allow key inputs at all.
  * \retval #PSA_ERROR_INSUFFICIENT_MEMORY
  * \retval #PSA_ERROR_COMMUNICATION_FAILURE
  * \retval #PSA_ERROR_HARDWARE_FAILURE
@@ -3368,6 +3369,8 @@
  *         \c private_key.
  * \retval #PSA_ERROR_NOT_SUPPORTED
  *         \c alg is not supported or is not a key derivation algorithm.
+ * \retval #PSA_ERROR_INVALID_ARGUMENT
+ *         \c step does not allow an input resulting from a key agreement.
  * \retval #PSA_ERROR_INSUFFICIENT_MEMORY
  * \retval #PSA_ERROR_COMMUNICATION_FAILURE
  * \retval #PSA_ERROR_HARDWARE_FAILURE
diff --git a/include/psa/crypto_values.h b/include/psa/crypto_values.h
index b3e0940..57d0651 100644
--- a/include/psa/crypto_values.h
+++ b/include/psa/crypto_values.h
@@ -1618,31 +1618,39 @@
 
 /** A secret input for key derivation.
  *
- * This must be a key of type #PSA_KEY_TYPE_DERIVE.
+ * This should be a key of type #PSA_KEY_TYPE_DERIVE
+ * (passed to psa_key_derivation_input_key())
+ * or the shared secret resulting from a key agreement
+ * (obtained via psa_key_derivation_key_agreement()).
+ * It can also be a direct input (passed to key_derivation_input_bytes()).
  */
 #define PSA_KEY_DERIVATION_INPUT_SECRET     ((psa_key_derivation_step_t)0x0101)
 
 /** A label for key derivation.
  *
- * This must be a direct input.
+ * This should be a direct input.
+ * It can also be a key of type #PSA_KEY_TYPE_RAW_DATA.
  */
 #define PSA_KEY_DERIVATION_INPUT_LABEL      ((psa_key_derivation_step_t)0x0201)
 
 /** A salt for key derivation.
  *
- * This must be a direct input.
+ * This should be a direct input.
+ * It can also be a key of type #PSA_KEY_TYPE_RAW_DATA.
  */
 #define PSA_KEY_DERIVATION_INPUT_SALT       ((psa_key_derivation_step_t)0x0202)
 
 /** An information string for key derivation.
  *
- * This must be a direct input.
+ * This should be a direct input.
+ * It can also be a key of type #PSA_KEY_TYPE_RAW_DATA.
  */
 #define PSA_KEY_DERIVATION_INPUT_INFO       ((psa_key_derivation_step_t)0x0203)
 
 /** A seed for key derivation.
  *
- * This must be a direct input.
+ * This should be a direct input.
+ * It can also be a key of type #PSA_KEY_TYPE_RAW_DATA.
  */
 #define PSA_KEY_DERIVATION_INPUT_SEED       ((psa_key_derivation_step_t)0x0204)
 
diff --git a/library/psa_crypto.c b/library/psa_crypto.c
index fe737d2..1494593 100644
--- a/library/psa_crypto.c
+++ b/library/psa_crypto.c
@@ -5076,13 +5076,38 @@
 }
 #endif /* MBEDTLS_MD_C */
 
+static int psa_key_derivation_check_input_type(
+    psa_key_derivation_step_t step,
+    psa_key_type_t key_type )
+{
+    switch( step )
+    {
+        case PSA_KEY_DERIVATION_INPUT_SECRET:
+            if( key_type == PSA_KEY_TYPE_DERIVE || key_type == 0 )
+                return( PSA_SUCCESS );
+            break;
+        case PSA_KEY_DERIVATION_INPUT_LABEL:
+        case PSA_KEY_DERIVATION_INPUT_SALT:
+        case PSA_KEY_DERIVATION_INPUT_INFO:
+        case PSA_KEY_DERIVATION_INPUT_SEED:
+            if( key_type == PSA_KEY_TYPE_RAW_DATA || key_type == 0 )
+                return( PSA_SUCCESS );
+            break;
+    }
+    return( PSA_ERROR_INVALID_ARGUMENT );
+}
+
 static psa_status_t psa_key_derivation_input_internal(
     psa_key_derivation_operation_t *operation,
     psa_key_derivation_step_t step,
+    psa_key_type_t key_type,
     const uint8_t *data,
     size_t data_length )
 {
-    psa_status_t status;
+    psa_status_t status = psa_key_derivation_check_input_type( step, key_type );
+    if( status != PSA_SUCCESS )
+        goto exit;
+
     psa_algorithm_t kdf_alg = psa_key_derivation_get_kdf_alg( operation );
 
 #if defined(MBEDTLS_MD_C)
@@ -5111,6 +5136,7 @@
         return( PSA_ERROR_BAD_STATE );
     }
 
+exit:
     if( status != PSA_SUCCESS )
         psa_key_derivation_abort( operation );
     return( status );
@@ -5122,10 +5148,7 @@
     const uint8_t *data,
     size_t data_length )
 {
-    if( step == PSA_KEY_DERIVATION_INPUT_SECRET )
-        return( PSA_ERROR_INVALID_ARGUMENT );
-
-    return( psa_key_derivation_input_internal( operation, step,
+    return( psa_key_derivation_input_internal( operation, step, 0,
                                                data, data_length ) );
 }
 
@@ -5141,18 +5164,8 @@
                                       operation->alg );
     if( status != PSA_SUCCESS )
         return( status );
-    if( slot->attr.type != PSA_KEY_TYPE_DERIVE )
-        return( PSA_ERROR_INVALID_ARGUMENT );
-    /* Don't allow a key to be used as an input that is usually public.
-     * This is debatable. It's ok from a cryptographic perspective to
-     * use secret material as an input that is usually public. However
-     * the material should be dedicated to a particular input step,
-     * otherwise this may allow the key to be used in an unintended way
-     * and leak values derived from the key. So be conservative. */
-    if( step != PSA_KEY_DERIVATION_INPUT_SECRET )
-        return( PSA_ERROR_INVALID_ARGUMENT );
     return( psa_key_derivation_input_internal( operation,
-                                               step,
+                                               step, slot->attr.type,
                                                slot->data.raw.data,
                                                slot->data.raw.bytes ) );
 }
@@ -5265,8 +5278,10 @@
         goto exit;
 
     /* Step 2: set up the key derivation to generate key material from
-     * the shared secret. */
+     * the shared secret. A shared secret is permitted wherever a key
+     * of type DERIVE is permitted. */
     status = psa_key_derivation_input_internal( operation, step,
+                                                PSA_KEY_TYPE_DERIVE,
                                                 shared_secret,
                                                 shared_secret_length );
 
diff --git a/tests/suites/test_suite_psa_crypto.data b/tests/suites/test_suite_psa_crypto.data
index cf95698..bfa3c1d 100644
--- a/tests/suites/test_suite_psa_crypto.data
+++ b/tests/suites/test_suite_psa_crypto.data
@@ -1900,6 +1900,30 @@
 depends_on:MBEDTLS_MD_C:MBEDTLS_SHA256_C
 derive_input:PSA_ALG_HKDF(PSA_ALG_SHA_256):PSA_KEY_DERIVATION_INPUT_SALT:0:"":PSA_KEY_DERIVATION_INPUT_SECRET:PSA_KEY_TYPE_RAW_DATA:"0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b":PSA_KEY_DERIVATION_INPUT_INFO:0:"":PSA_SUCCESS:PSA_ERROR_INVALID_ARGUMENT:PSA_SUCCESS
 
+PSA key derivation: HKDF-SHA-256, direct secret
+depends_on:MBEDTLS_MD_C:MBEDTLS_SHA256_C
+derive_input:PSA_ALG_HKDF(PSA_ALG_SHA_256):PSA_KEY_DERIVATION_INPUT_SALT:0:"":PSA_KEY_DERIVATION_INPUT_SECRET:0:"0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b":PSA_KEY_DERIVATION_INPUT_INFO:0:"":PSA_SUCCESS:PSA_SUCCESS:PSA_SUCCESS
+
+PSA key derivation: HKDF-SHA-256, direct empty secret
+depends_on:MBEDTLS_MD_C:MBEDTLS_SHA256_C
+derive_input:PSA_ALG_HKDF(PSA_ALG_SHA_256):PSA_KEY_DERIVATION_INPUT_SALT:0:"":PSA_KEY_DERIVATION_INPUT_SECRET:0:"":PSA_KEY_DERIVATION_INPUT_INFO:0:"":PSA_SUCCESS:PSA_SUCCESS:PSA_SUCCESS
+
+PSA key derivation: HKDF-SHA-256, RAW_DATA key as salt
+depends_on:MBEDTLS_MD_C:MBEDTLS_SHA256_C
+derive_input:PSA_ALG_HKDF(PSA_ALG_SHA_256):PSA_KEY_DERIVATION_INPUT_SALT:PSA_KEY_TYPE_RAW_DATA:"412073616c74":PSA_KEY_DERIVATION_INPUT_SECRET:PSA_KEY_TYPE_DERIVE:"0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b":PSA_KEY_DERIVATION_INPUT_INFO:0:"":PSA_SUCCESS:PSA_SUCCESS:PSA_SUCCESS
+
+PSA key derivation: HKDF-SHA-256, RAW_DATA key as info
+depends_on:MBEDTLS_MD_C:MBEDTLS_SHA256_C
+derive_input:PSA_ALG_HKDF(PSA_ALG_SHA_256):PSA_KEY_DERIVATION_INPUT_SALT:0:"":PSA_KEY_DERIVATION_INPUT_SECRET:PSA_KEY_TYPE_DERIVE:"0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b":PSA_KEY_DERIVATION_INPUT_INFO:PSA_KEY_TYPE_RAW_DATA:"4120696e666f":PSA_SUCCESS:PSA_SUCCESS:PSA_SUCCESS
+
+PSA key derivation: HKDF-SHA-256, DERIVE key as salt
+depends_on:MBEDTLS_MD_C:MBEDTLS_SHA256_C
+derive_input:PSA_ALG_HKDF(PSA_ALG_SHA_256):PSA_KEY_DERIVATION_INPUT_SALT:PSA_KEY_TYPE_DERIVE:"412073616c74":PSA_KEY_DERIVATION_INPUT_SECRET:PSA_KEY_TYPE_DERIVE:"0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b":PSA_KEY_DERIVATION_INPUT_INFO:0:"":PSA_ERROR_INVALID_ARGUMENT:PSA_ERROR_BAD_STATE:PSA_ERROR_BAD_STATE
+
+PSA key derivation: HKDF-SHA-256, DERIVE key as info
+depends_on:MBEDTLS_MD_C:MBEDTLS_SHA256_C
+derive_input:PSA_ALG_HKDF(PSA_ALG_SHA_256):PSA_KEY_DERIVATION_INPUT_SALT:0:"":PSA_KEY_DERIVATION_INPUT_SECRET:PSA_KEY_TYPE_DERIVE:"0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b":PSA_KEY_DERIVATION_INPUT_INFO:PSA_KEY_TYPE_DERIVE:"4120696e666f":PSA_SUCCESS:PSA_SUCCESS:PSA_ERROR_INVALID_ARGUMENT
+
 PSA key derivation: TLS 1.2 PRF SHA-256, good case
 depends_on:MBEDTLS_MD_C:MBEDTLS_SHA256_C
 derive_input:PSA_ALG_TLS12_PRF(PSA_ALG_SHA_256):PSA_KEY_DERIVATION_INPUT_SEED:0:"":PSA_KEY_DERIVATION_INPUT_SECRET:PSA_KEY_TYPE_DERIVE:"0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b":PSA_KEY_DERIVATION_INPUT_LABEL:0:"":PSA_SUCCESS:PSA_SUCCESS:PSA_SUCCESS
@@ -1928,6 +1952,30 @@
 depends_on:MBEDTLS_MD_C:MBEDTLS_SHA256_C
 derive_input:PSA_ALG_TLS12_PRF(PSA_ALG_SHA_256):PSA_KEY_DERIVATION_INPUT_SEED:0:"":PSA_KEY_DERIVATION_INPUT_SECRET:PSA_KEY_TYPE_RAW_DATA:"0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b":PSA_KEY_DERIVATION_INPUT_LABEL:0:"":PSA_SUCCESS:PSA_ERROR_INVALID_ARGUMENT:PSA_ERROR_BAD_STATE
 
+PSA key derivation: TLS 1.2 PRF SHA-256, direct secret
+depends_on:MBEDTLS_MD_C:MBEDTLS_SHA256_C
+derive_input:PSA_ALG_TLS12_PRF(PSA_ALG_SHA_256):PSA_KEY_DERIVATION_INPUT_SEED:0:"":PSA_KEY_DERIVATION_INPUT_SECRET:0:"0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b":PSA_KEY_DERIVATION_INPUT_LABEL:0:"":PSA_SUCCESS:PSA_SUCCESS:PSA_SUCCESS
+
+PSA key derivation: TLS 1.2 PRF SHA-256, direct empty secret
+depends_on:MBEDTLS_MD_C:MBEDTLS_SHA256_C
+derive_input:PSA_ALG_TLS12_PRF(PSA_ALG_SHA_256):PSA_KEY_DERIVATION_INPUT_SEED:0:"":PSA_KEY_DERIVATION_INPUT_SECRET:0:"":PSA_KEY_DERIVATION_INPUT_LABEL:0:"":PSA_SUCCESS:PSA_SUCCESS:PSA_SUCCESS
+
+PSA key derivation: TLS 1.2 PRF SHA-256, RAW_DATA key as seed
+depends_on:MBEDTLS_MD_C:MBEDTLS_SHA256_C
+derive_input:PSA_ALG_TLS12_PRF(PSA_ALG_SHA_256):PSA_KEY_DERIVATION_INPUT_SEED:PSA_KEY_TYPE_RAW_DATA:"612073656564":PSA_KEY_DERIVATION_INPUT_SECRET:PSA_KEY_TYPE_DERIVE:"0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b":PSA_KEY_DERIVATION_INPUT_LABEL:0:"":PSA_SUCCESS:PSA_SUCCESS:PSA_SUCCESS
+
+PSA key derivation: TLS 1.2 PRF SHA-256, RAW_DATA key as label
+depends_on:MBEDTLS_MD_C:MBEDTLS_SHA256_C
+derive_input:PSA_ALG_TLS12_PRF(PSA_ALG_SHA_256):PSA_KEY_DERIVATION_INPUT_SEED:0:"":PSA_KEY_DERIVATION_INPUT_SECRET:PSA_KEY_TYPE_DERIVE:"0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b":PSA_KEY_DERIVATION_INPUT_LABEL:PSA_KEY_TYPE_RAW_DATA:"61206c6162656c":PSA_SUCCESS:PSA_SUCCESS:PSA_SUCCESS
+
+PSA key derivation: TLS 1.2 PRF SHA-256, DERIVE key as seed
+depends_on:MBEDTLS_MD_C:MBEDTLS_SHA256_C
+derive_input:PSA_ALG_TLS12_PRF(PSA_ALG_SHA_256):PSA_KEY_DERIVATION_INPUT_SEED:PSA_KEY_TYPE_DERIVE:"612073656564":PSA_KEY_DERIVATION_INPUT_SECRET:PSA_KEY_TYPE_DERIVE:"0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b":PSA_KEY_DERIVATION_INPUT_LABEL:0:"":PSA_ERROR_INVALID_ARGUMENT:PSA_ERROR_BAD_STATE:PSA_ERROR_BAD_STATE
+
+PSA key derivation: TLS 1.2 PRF SHA-256, DERIVE key as label
+depends_on:MBEDTLS_MD_C:MBEDTLS_SHA256_C
+derive_input:PSA_ALG_TLS12_PRF(PSA_ALG_SHA_256):PSA_KEY_DERIVATION_INPUT_SEED:0:"":PSA_KEY_DERIVATION_INPUT_SECRET:PSA_KEY_TYPE_DERIVE:"0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b":PSA_KEY_DERIVATION_INPUT_LABEL:PSA_KEY_TYPE_DERIVE:"61206c6162656c":PSA_SUCCESS:PSA_SUCCESS:PSA_ERROR_INVALID_ARGUMENT
+
 PSA key derivation: TLS 1.2 PSK-to-MS, SHA-256, PSK too long (160 Bytes)
 depends_on:MBEDTLS_MD_C:MBEDTLS_SHA256_C
 derive_input:PSA_ALG_TLS12_PSK_TO_MS(PSA_ALG_SHA_256):PSA_KEY_DERIVATION_INPUT_SEED:0:"":PSA_KEY_DERIVATION_INPUT_SECRET:PSA_KEY_TYPE_DERIVE:"01020304050607080102030405060708010203040506070801020304050607080102030405060708010203040506070801020304050607080102030405060708010203040506070801020304050607080102030405060708010203040506070801020304050607080102030405060708010203040506070801020304050607080102030405060708010203040506070801020304050607080102030405060708":PSA_KEY_DERIVATION_INPUT_LABEL:0:"":PSA_SUCCESS:PSA_ERROR_INVALID_ARGUMENT:PSA_ERROR_BAD_STATE