Make key agreement the secret input for key derivation

* Documentation
* Proof-of-concept implementation
* Updates to the tests (work in progress)
diff --git a/include/psa/crypto.h b/include/psa/crypto.h
index f1731f6..2217b95 100644
--- a/include/psa/crypto.h
+++ b/include/psa/crypto.h
@@ -2284,19 +2284,24 @@
                                 size_t label_length,
                                 size_t capacity);
 
-/** Set up a key agreement operation.
+/** Perform a key agreement and use the shared secret as input to a key
+ * derivation.
  *
  * A key agreement algorithm takes two inputs: a private key \p private_key
  * a public key \p peer_key.
- * The result of this function is a byte generator which can
- * be used to produce keys and other cryptographic material.
+ * The result of this function is passed as input to a key derivation.
+ * The output of this key derivation can be extracted by reading from the
+ * resulting generator to produce keys and other cryptographic material.
  *
- * The resulting generator always has the maximum capacity permitted by
- * the algorithm.
- *
- * \param[in,out] generator       The generator object to set up. It must have
- *                                been initialized as per the documentation for
- *                                #psa_crypto_generator_t and not yet in use.
+ * \param[in,out] generator       The generator object to use. It must
+ *                                have been set up with
+ *                                psa_key_derivation_setup() with a
+ *                                key agreement algorithm
+ *                                (\c PSA_ALG_XXX value such that
+ *                                #PSA_ALG_IS_KEY_AGREEMENT(\p alg) is true).
+ *                                The generator must be ready for an
+ *                                input of the type given by \p step.
+ * \param step                    Which step the input data is for.
  * \param private_key             Handle to the private key to use.
  * \param[in] peer_key            Public key of the peer. It must be
  *                                in the same format that psa_import_key()
@@ -2304,9 +2309,6 @@
  *                                keys are documented in the documentation
  *                                of psa_export_public_key().
  * \param peer_key_length         Size of \p peer_key in bytes.
- * \param alg                     The key agreement algorithm to compute
- *                                (\c PSA_ALG_XXX value such that
- *                                #PSA_ALG_IS_KEY_AGREEMENT(\p alg) is true).
  *
  * \retval #PSA_SUCCESS
  *         Success.
@@ -2325,10 +2327,10 @@
  * \retval #PSA_ERROR_TAMPERING_DETECTED
  */
 psa_status_t psa_key_agreement(psa_crypto_generator_t *generator,
+                               psa_key_derivation_step_t step,
                                psa_key_handle_t private_key,
                                const uint8_t *peer_key,
-                               size_t peer_key_length,
-                               psa_algorithm_t alg);
+                               size_t peer_key_length);
 
 /**@}*/
 
diff --git a/library/psa_crypto.c b/library/psa_crypto.c
index 6269fba..d616c14 100644
--- a/library/psa_crypto.c
+++ b/library/psa_crypto.c
@@ -3318,17 +3318,28 @@
 #define HKDF_STATE_KEYED 2 /* got key */
 #define HKDF_STATE_OUTPUT 3 /* output started */
 
+static psa_algorithm_t psa_generator_get_kdf_alg(
+    const psa_crypto_generator_t *generator )
+{
+    if ( PSA_ALG_IS_KEY_AGREEMENT( generator->alg ) )
+        return( PSA_ALG_KEY_AGREEMENT_GET_KDF( generator->alg ) );
+    else
+        return( generator->alg );
+}
+
+
 psa_status_t psa_generator_abort( psa_crypto_generator_t *generator )
 {
     psa_status_t status = PSA_SUCCESS;
-    if( generator->alg == 0 )
+    psa_algorithm_t kdf_alg = psa_generator_get_kdf_alg( generator );
+    if( kdf_alg == 0 )
     {
         /* The object has (apparently) been initialized but it is not
          * in use. It's ok to call abort on such an object, and there's
          * nothing to do. */
     }
     else
-    if( generator->alg == PSA_ALG_SELECT_RAW )
+    if( kdf_alg == PSA_ALG_SELECT_RAW )
     {
         if( generator->ctx.buffer.data != NULL )
         {
@@ -3339,14 +3350,14 @@
     }
     else
 #if defined(MBEDTLS_MD_C)
-    if( PSA_ALG_IS_HKDF( generator->alg ) )
+    if( PSA_ALG_IS_HKDF( kdf_alg ) )
     {
         mbedtls_free( generator->ctx.hkdf.info );
         status = psa_hmac_abort_internal( &generator->ctx.hkdf.hmac );
     }
-    else if( PSA_ALG_IS_TLS12_PRF( generator->alg ) ||
+    else if( PSA_ALG_IS_TLS12_PRF( kdf_alg ) ||
              /* TLS-1.2 PSK-to-MS KDF uses the same generator as TLS-1.2 PRF */
-             PSA_ALG_IS_TLS12_PSK_TO_MS( generator->alg ) )
+             PSA_ALG_IS_TLS12_PSK_TO_MS( kdf_alg ) )
     {
         if( generator->ctx.tls12_prf.key != NULL )
         {
@@ -3617,6 +3628,7 @@
                                  size_t output_length )
 {
     psa_status_t status;
+    psa_algorithm_t kdf_alg = psa_generator_get_kdf_alg( generator );
 
     if( output_length > generator->capacity )
     {
@@ -3627,7 +3639,7 @@
         goto exit;
     }
     if( output_length == 0 &&
-        generator->capacity == 0 && generator->alg == 0 )
+        generator->capacity == 0 && kdf_alg == 0 )
     {
         /* Edge case: this is a blank or finished generator, and 0
          * bytes were requested. The right error in this case could
@@ -3639,7 +3651,7 @@
     }
     generator->capacity -= output_length;
 
-    if( generator->alg == PSA_ALG_SELECT_RAW )
+    if( kdf_alg == PSA_ALG_SELECT_RAW )
     {
         /* Initially, the capacity of a selection generator is always
          * the size of the buffer, i.e. `generator->ctx.buffer.size`,
@@ -3657,17 +3669,17 @@
     }
     else
 #if defined(MBEDTLS_MD_C)
-    if( PSA_ALG_IS_HKDF( generator->alg ) )
+    if( PSA_ALG_IS_HKDF( kdf_alg ) )
     {
-        psa_algorithm_t hash_alg = PSA_ALG_HKDF_GET_HASH( generator->alg );
+        psa_algorithm_t hash_alg = PSA_ALG_HKDF_GET_HASH( kdf_alg );
         status = psa_generator_hkdf_read( &generator->ctx.hkdf, hash_alg,
                                           output, output_length );
     }
-    else if( PSA_ALG_IS_TLS12_PRF( generator->alg ) ||
-             PSA_ALG_IS_TLS12_PSK_TO_MS( generator->alg ) )
+    else if( PSA_ALG_IS_TLS12_PRF( kdf_alg ) ||
+             PSA_ALG_IS_TLS12_PSK_TO_MS( kdf_alg ) )
     {
         status = psa_generator_tls12_prf_read( &generator->ctx.tls12_prf,
-                                               generator->alg, output,
+                                               kdf_alg, output,
                                                output_length );
     }
     else
@@ -4019,38 +4031,66 @@
     return( status );
 }
 
-psa_status_t psa_key_derivation_setup( psa_crypto_generator_t *generator,
-                                       psa_algorithm_t alg )
+static psa_status_t psa_key_derivation_setup_kdf(
+    psa_crypto_generator_t *generator,
+    psa_algorithm_t kdf_alg )
 {
-    if( generator->alg != 0 )
-        return( PSA_ERROR_BAD_STATE );
-    /* Make sure that alg is a supported key derivation algorithm.
-     * Key agreement algorithms and key selection algorithms are not
-     * supported by this function. */
+    /* Make sure that kdf_alg is a supported key derivation algorithm. */
 #if defined(MBEDTLS_MD_C)
-    if( PSA_ALG_IS_HKDF( alg ) ||
-        PSA_ALG_IS_TLS12_PRF( alg ) ||
-        PSA_ALG_IS_TLS12_PSK_TO_MS( alg ) )
+    if( PSA_ALG_IS_HKDF( kdf_alg ) ||
+        PSA_ALG_IS_TLS12_PRF( kdf_alg ) ||
+        PSA_ALG_IS_TLS12_PSK_TO_MS( kdf_alg ) )
     {
-        psa_algorithm_t hash_alg = PSA_ALG_HKDF_GET_HASH( alg );
+        psa_algorithm_t hash_alg = PSA_ALG_HKDF_GET_HASH( kdf_alg );
         size_t hash_size = PSA_HASH_SIZE( hash_alg );
         if( hash_size == 0 )
             return( PSA_ERROR_NOT_SUPPORTED );
-        if( ( PSA_ALG_IS_TLS12_PRF( alg ) ||
-              PSA_ALG_IS_TLS12_PSK_TO_MS( alg ) ) &&
+        if( ( PSA_ALG_IS_TLS12_PRF( kdf_alg ) ||
+              PSA_ALG_IS_TLS12_PSK_TO_MS( kdf_alg ) ) &&
             ! ( hash_alg == PSA_ALG_SHA_256 && hash_alg == PSA_ALG_SHA_384 ) )
         {
             return( PSA_ERROR_NOT_SUPPORTED );
         }
         generator->capacity = 255 * hash_size;
+        return( PSA_SUCCESS );
     }
 #endif /* MBEDTLS_MD_C */
-    else if( PSA_ALG_IS_KEY_DERIVATION( alg ) )
+    else
         return( PSA_ERROR_NOT_SUPPORTED );
+}
+
+psa_status_t psa_key_derivation_setup( psa_crypto_generator_t *generator,
+                                       psa_algorithm_t alg )
+{
+    psa_status_t status;
+
+    if( generator->alg != 0 )
+        return( PSA_ERROR_BAD_STATE );
+
+    if( PSA_ALG_IS_KEY_AGREEMENT( alg ) )
+    {
+        psa_algorithm_t kdf_alg = PSA_ALG_KEY_AGREEMENT_GET_KDF( alg );
+        if( kdf_alg == PSA_ALG_SELECT_RAW )
+        {
+            /* It's too early to set the generator's capacity since it
+             * depends on the key size for the key agreement. */
+            status = PSA_SUCCESS;
+        }
+        else
+        {
+            status = psa_key_derivation_setup_kdf( generator, kdf_alg );
+        }
+    }
+    else if( PSA_ALG_IS_KEY_DERIVATION( alg ) )
+    {
+        status = psa_key_derivation_setup_kdf( generator, alg );
+    }
     else
         return( PSA_ERROR_INVALID_ARGUMENT );
-    generator->alg = alg;
-    return( PSA_SUCCESS );
+
+    if( status == PSA_SUCCESS )
+        generator->alg = alg;
+    return( status );
 }
 
 #if defined(MBEDTLS_MD_C)
@@ -4135,27 +4175,40 @@
     size_t data_length )
 {
     psa_status_t status;
+    psa_algorithm_t kdf_alg = psa_generator_get_kdf_alg( generator );
 
+    if( kdf_alg == PSA_ALG_SELECT_RAW )
+    {
+        if( generator->capacity != 0 )
+            return( PSA_ERROR_INVALID_ARGUMENT );
+        generator->ctx.buffer.data = mbedtls_calloc( 1, data_length );
+        if( generator->ctx.buffer.data == NULL )
+            return( PSA_ERROR_INSUFFICIENT_MEMORY );
+        memcpy( generator->ctx.buffer.data, data, data_length );
+        generator->ctx.buffer.size = data_length;
+        generator->capacity = data_length;
+        status = PSA_SUCCESS;
+    }
+    else
 #if defined(MBEDTLS_MD_C)
-    if( PSA_ALG_IS_HKDF( generator->alg ) )
+    if( PSA_ALG_IS_HKDF( kdf_alg ) )
     {
         status = psa_hkdf_input( &generator->ctx.hkdf,
-                                 PSA_ALG_HKDF_GET_HASH( generator->alg ),
+                                 PSA_ALG_HKDF_GET_HASH( kdf_alg ),
                                  step, data, data_length );
     }
+    else
 #endif /* MBEDTLS_MD_C */
-
 #if defined(MBEDTLS_MD_C)
     /* TLS-1.2 PRF and TLS-1.2 PSK-to-MS are very similar, so share code. */
-    else if( PSA_ALG_IS_TLS12_PRF( generator->alg ) ||
-             PSA_ALG_IS_TLS12_PSK_TO_MS( generator->alg ) )
+    if( PSA_ALG_IS_TLS12_PRF( kdf_alg ) ||
+             PSA_ALG_IS_TLS12_PSK_TO_MS( kdf_alg ) )
     {
         // TODO
         status = PSA_ERROR_NOT_SUPPORTED;
     }
     else
 #endif /* MBEDTLS_MD_C */
-
     {
         /* This can't happen unless the generator object was not initialized */
         return( PSA_ERROR_BAD_STATE );
@@ -4277,10 +4330,10 @@
  * to potentially free embedded data structures and wipe confidential data.
  */
 static psa_status_t psa_key_agreement_internal( psa_crypto_generator_t *generator,
+                                                psa_key_derivation_step_t step,
                                                 psa_key_slot_t *private_key,
                                                 const uint8_t *peer_key,
-                                                size_t peer_key_length,
-                                                psa_algorithm_t alg )
+                                                size_t peer_key_length )
 {
     psa_status_t status;
     uint8_t shared_secret[PSA_KEY_AGREEMENT_MAX_SHARED_SECRET_SIZE];
@@ -4288,7 +4341,7 @@
 
     /* Step 1: run the secret agreement algorithm to generate the shared
      * secret. */
-    switch( PSA_ALG_KEY_AGREEMENT_GET_BASE( alg ) )
+    switch( PSA_ALG_KEY_AGREEMENT_GET_BASE( generator->alg ) )
     {
 #if defined(MBEDTLS_ECDH_C)
         case PSA_ALG_ECDH_BASE:
@@ -4312,34 +4365,31 @@
 
     /* Step 2: set up the key derivation to generate key material from
      * the shared secret. */
-    status = psa_key_derivation_internal( generator,
-                                          shared_secret, shared_secret_length,
-                                          PSA_ALG_KEY_AGREEMENT_GET_KDF( alg ),
-                                          NULL, 0, NULL, 0,
-                                          PSA_GENERATOR_UNBRIDLED_CAPACITY );
+    status = psa_key_derivation_input_raw( generator, step,
+                                           shared_secret, shared_secret_length );
+
 exit:
     mbedtls_platform_zeroize( shared_secret, shared_secret_length );
     return( status );
 }
 
 psa_status_t psa_key_agreement( psa_crypto_generator_t *generator,
+                                psa_key_derivation_step_t step,
                                 psa_key_handle_t private_key,
                                 const uint8_t *peer_key,
-                                size_t peer_key_length,
-                                psa_algorithm_t alg )
+                                size_t peer_key_length )
 {
     psa_key_slot_t *slot;
     psa_status_t status;
-    if( ! PSA_ALG_IS_KEY_AGREEMENT( alg ) )
+    if( ! PSA_ALG_IS_KEY_AGREEMENT( generator->alg ) )
         return( PSA_ERROR_INVALID_ARGUMENT );
     status = psa_get_key_from_slot( private_key, &slot,
-                                    PSA_KEY_USAGE_DERIVE, alg );
+                                    PSA_KEY_USAGE_DERIVE, generator->alg );
     if( status != PSA_SUCCESS )
         return( status );
-    status = psa_key_agreement_internal( generator,
+    status = psa_key_agreement_internal( generator, step,
                                          slot,
-                                         peer_key, peer_key_length,
-                                         alg );
+                                         peer_key, peer_key_length );
     if( status != PSA_SUCCESS )
         psa_generator_abort( generator );
     return( status );
diff --git a/tests/suites/test_suite_psa_crypto.function b/tests/suites/test_suite_psa_crypto.function
index 9b8e01c..f90a7b3 100644
--- a/tests/suites/test_suite_psa_crypto.function
+++ b/tests/suites/test_suite_psa_crypto.function
@@ -405,8 +405,7 @@
 /* We need two keys to exercise key agreement. Exercise the
  * private key against its own public key. */
 static psa_status_t key_agreement_with_self( psa_crypto_generator_t *generator,
-                                             psa_key_handle_t handle,
-                                             psa_algorithm_t alg )
+                                             psa_key_handle_t handle )
 {
     psa_key_type_t private_key_type;
     psa_key_type_t public_key_type;
@@ -428,9 +427,8 @@
                                        public_key, public_key_length,
                                        &public_key_length ) );
 
-    status = psa_key_agreement( generator, handle,
-                                public_key, public_key_length,
-                                alg );
+    status = psa_key_agreement( generator, PSA_KDF_STEP_SECRET, handle,
+                                public_key, public_key_length );
 exit:
     mbedtls_free( public_key );
     return( status );
@@ -448,7 +446,8 @@
     {
         /* We need two keys to exercise key agreement. Exercise the
          * private key against its own public key. */
-        PSA_ASSERT( key_agreement_with_self( &generator, handle, alg ) );
+        PSA_ASSERT( psa_key_derivation_setup( &generator, alg ) );
+        PSA_ASSERT( key_agreement_with_self( &generator, handle ) );
         PSA_ASSERT( psa_generator_read( &generator,
                                         output,
                                         sizeof( output ) ) );
@@ -1791,7 +1790,8 @@
     PSA_ASSERT( psa_import_key( handle, key_type,
                                 key_data->x, key_data->len ) );
 
-    status = key_agreement_with_self( &generator, handle, exercise_alg );
+    PSA_ASSERT( psa_key_derivation_setup( &generator, exercise_alg ) );
+    status = key_agreement_with_self( &generator, handle );
 
     if( policy_alg == exercise_alg &&
         ( policy_usage & PSA_KEY_USAGE_DERIVE ) != 0 )
@@ -3848,10 +3848,10 @@
                                 our_key_data->x,
                                 our_key_data->len ) );
 
-    TEST_EQUAL( psa_key_agreement( &generator,
+    PSA_ASSERT( psa_key_derivation_setup( &generator, alg ) );
+    TEST_EQUAL( psa_key_agreement( &generator, PSA_KDF_STEP_SECRET,
                                    our_key,
-                                   peer_key_data->x, peer_key_data->len,
-                                   alg ),
+                                   peer_key_data->x, peer_key_data->len ),
                 expected_status_arg );
 
 exit:
@@ -3887,10 +3887,10 @@
                                 our_key_data->x,
                                 our_key_data->len ) );
 
-    PSA_ASSERT( psa_key_agreement( &generator,
+    PSA_ASSERT( psa_key_derivation_setup( &generator, alg ) );
+    PSA_ASSERT( psa_key_agreement( &generator, PSA_KDF_STEP_SECRET,
                                    our_key,
-                                   peer_key_data->x, peer_key_data->len,
-                                   alg ) );
+                                   peer_key_data->x, peer_key_data->len ) );
 
     /* Test the advertized capacity. */
     PSA_ASSERT( psa_get_generator_capacity(
@@ -3944,10 +3944,10 @@
                                 our_key_data->x,
                                 our_key_data->len ) );
 
-    PSA_ASSERT( psa_key_agreement( &generator,
+    PSA_ASSERT( psa_key_derivation_setup( &generator, alg ) );
+    PSA_ASSERT( psa_key_agreement( &generator, PSA_KDF_STEP_SECRET,
                                    our_key,
-                                   peer_key_data->x, peer_key_data->len,
-                                   alg ) );
+                                   peer_key_data->x, peer_key_data->len ) );
 
     PSA_ASSERT( psa_generator_read( &generator,
                                     actual_output,