PSA: Add support for HKDF-Extract and HKDF-Expand algs

Signed-off-by: Przemek Stekiel <przemyslaw.stekiel@mobica.com>
diff --git a/library/psa_crypto.c b/library/psa_crypto.c
index d58923d..0783697 100644
--- a/library/psa_crypto.c
+++ b/library/psa_crypto.c
@@ -4295,7 +4295,8 @@
     }
     else
 #if defined(MBEDTLS_PSA_BUILTIN_ALG_HKDF)
-    if( PSA_ALG_IS_HKDF( kdf_alg ) )
+    if( PSA_ALG_IS_HKDF( kdf_alg ) || PSA_ALG_IS_HKDF_EXTRACT( kdf_alg ) ||
+        PSA_ALG_IS_HKDF_EXPAND( kdf_alg ) )
     {
         mbedtls_free( operation->ctx.hkdf.info );
         status = psa_mac_abort( &operation->ctx.hkdf.hmac );
@@ -4379,15 +4380,17 @@
 /* Read some bytes from an HKDF-based operation. This performs a chunk
  * of the expand phase of the HKDF algorithm. */
 static psa_status_t psa_key_derivation_hkdf_read( psa_hkdf_key_derivation_t *hkdf,
-                                                  psa_algorithm_t hash_alg,
+                                                  psa_algorithm_t kdf_alg,
                                                   uint8_t *output,
                                                   size_t output_length )
 {
+    psa_algorithm_t hash_alg = PSA_ALG_HKDF_GET_HASH( kdf_alg );
     uint8_t hash_length = PSA_HASH_LENGTH( hash_alg );
     size_t hmac_output_length;
     psa_status_t status;
 
-    if( hkdf->state < HKDF_STATE_KEYED || ! hkdf->info_set )
+    if( hkdf->state < HKDF_STATE_KEYED ||
+        ( ! hkdf->info_set && ! PSA_ALG_IS_HKDF_EXTRACT( kdf_alg ) ) )
         return( PSA_ERROR_BAD_STATE );
     hkdf->state = HKDF_STATE_OUTPUT;
 
@@ -4411,40 +4414,49 @@
         if( hkdf->block_number == 0xff )
             return( PSA_ERROR_BAD_STATE );
 
+
+        if( PSA_ALG_IS_HKDF_EXTRACT( kdf_alg ) && hkdf->block_number == 0 )
+        {
+            memcpy( hkdf->output_block, hkdf->prk, hash_length );
+        }
+
         /* We need a new block */
         ++hkdf->block_number;
         hkdf->offset_in_block = 0;
 
-        status = psa_key_derivation_start_hmac( &hkdf->hmac,
-                                                hash_alg,
-                                                hkdf->prk,
-                                                hash_length );
-        if( status != PSA_SUCCESS )
-            return( status );
-
-        if( hkdf->block_number != 1 )
+        if( ! PSA_ALG_IS_HKDF_EXTRACT( kdf_alg ) )
         {
+            status = psa_key_derivation_start_hmac( &hkdf->hmac,
+                                                    hash_alg,
+                                                    hkdf->prk,
+                                                    hash_length );
+            if( status != PSA_SUCCESS )
+                return( status );
+
+            if( hkdf->block_number != 1 )
+            {
+                status = psa_mac_update( &hkdf->hmac,
+                                        hkdf->output_block,
+                                        hash_length );
+                if( status != PSA_SUCCESS )
+                    return( status );
+            }
             status = psa_mac_update( &hkdf->hmac,
-                                     hkdf->output_block,
-                                     hash_length );
+                                    hkdf->info,
+                                    hkdf->info_length );
+            if( status != PSA_SUCCESS )
+                return( status );
+            status = psa_mac_update( &hkdf->hmac,
+                                    &hkdf->block_number, 1 );
+            if( status != PSA_SUCCESS )
+                return( status );
+            status = psa_mac_sign_finish( &hkdf->hmac,
+                                        hkdf->output_block,
+                                        sizeof( hkdf->output_block ),
+                                        &hmac_output_length );
             if( status != PSA_SUCCESS )
                 return( status );
         }
-        status = psa_mac_update( &hkdf->hmac,
-                                 hkdf->info,
-                                 hkdf->info_length );
-        if( status != PSA_SUCCESS )
-            return( status );
-        status = psa_mac_update( &hkdf->hmac,
-                                 &hkdf->block_number, 1 );
-        if( status != PSA_SUCCESS )
-            return( status );
-        status = psa_mac_sign_finish( &hkdf->hmac,
-                                      hkdf->output_block,
-                                      sizeof( hkdf->output_block ),
-                                      &hmac_output_length );
-        if( status != PSA_SUCCESS )
-            return( status );
     }
 
     return( PSA_SUCCESS );
@@ -4650,10 +4662,10 @@
     operation->capacity -= output_length;
 
 #if defined(MBEDTLS_PSA_BUILTIN_ALG_HKDF)
-    if( PSA_ALG_IS_HKDF( kdf_alg ) )
+    if( PSA_ALG_IS_HKDF( kdf_alg ) || PSA_ALG_IS_HKDF_EXPAND( kdf_alg ) ||
+        PSA_ALG_IS_HKDF_EXTRACT( kdf_alg ) )
     {
-        psa_algorithm_t hash_alg = PSA_ALG_HKDF_GET_HASH( kdf_alg );
-        status = psa_key_derivation_hkdf_read( &operation->ctx.hkdf, hash_alg,
+        status = psa_key_derivation_hkdf_read( &operation->ctx.hkdf, kdf_alg,
                                           output, output_length );
     }
     else
@@ -5043,7 +5055,8 @@
 static int is_kdf_alg_supported( psa_algorithm_t kdf_alg )
 {
 #if defined(MBEDTLS_PSA_BUILTIN_ALG_HKDF)
-    if( PSA_ALG_IS_HKDF( kdf_alg ) )
+    if( PSA_ALG_IS_HKDF( kdf_alg ) || PSA_ALG_IS_HKDF_EXTRACT( kdf_alg ) ||
+        PSA_ALG_IS_HKDF_EXPAND( kdf_alg ) )
         return( 1 );
 #endif
 #if defined(MBEDTLS_PSA_BUILTIN_ALG_TLS12_PRF)
@@ -5097,8 +5110,10 @@
     {
         return( PSA_ERROR_NOT_SUPPORTED );
     }
-
-    operation->capacity = 255 * hash_size;
+    if( PSA_ALG_IS_HKDF_EXTRACT( kdf_alg ) )
+        operation->capacity = hash_size;
+    else
+        operation->capacity = 255 * hash_size;
     return( PSA_SUCCESS );
 }
 
@@ -5154,15 +5169,18 @@
 
 #if defined(MBEDTLS_PSA_BUILTIN_ALG_HKDF)
 static psa_status_t psa_hkdf_input( psa_hkdf_key_derivation_t *hkdf,
-                                    psa_algorithm_t hash_alg,
+                                    psa_algorithm_t kdf_alg,
                                     psa_key_derivation_step_t step,
                                     const uint8_t *data,
                                     size_t data_length )
 {
+    psa_algorithm_t hash_alg = PSA_ALG_HKDF_GET_HASH( kdf_alg );
     psa_status_t status;
     switch( step )
     {
         case PSA_KEY_DERIVATION_INPUT_SALT:
+            if( PSA_ALG_IS_HKDF_EXPAND( kdf_alg ) )
+                return( PSA_ERROR_INVALID_ARGUMENT );
             if( hkdf->state != HKDF_STATE_INIT )
                 return( PSA_ERROR_BAD_STATE );
             else
@@ -5177,32 +5195,48 @@
             }
         case PSA_KEY_DERIVATION_INPUT_SECRET:
             /* If no salt was provided, use an empty salt. */
-            if( hkdf->state == HKDF_STATE_INIT )
+            if( PSA_ALG_IS_HKDF_EXPAND( kdf_alg ) )
             {
-                status = psa_key_derivation_start_hmac( &hkdf->hmac,
-                                                        hash_alg,
-                                                        NULL, 0 );
+                if( hkdf->state != HKDF_STATE_INIT )
+                    return( PSA_ERROR_BAD_STATE );
+
+                if( data_length > sizeof( hkdf->prk ) )
+                    return( PSA_ERROR_INVALID_ARGUMENT );
+
+                memcpy( hkdf->prk, data, data_length );
+            }
+            else
+            {
+                if( hkdf->state == HKDF_STATE_INIT )
+                {
+                    status = psa_key_derivation_start_hmac( &hkdf->hmac,
+                                                            hash_alg,
+                                                            NULL, 0 );
+                    if( status != PSA_SUCCESS )
+                        return( status );
+                    hkdf->state = HKDF_STATE_STARTED;
+                }
+                if( hkdf->state != HKDF_STATE_STARTED )
+                    return( PSA_ERROR_BAD_STATE );
+                status = psa_mac_update( &hkdf->hmac,
+                                        data, data_length );
                 if( status != PSA_SUCCESS )
                     return( status );
-                hkdf->state = HKDF_STATE_STARTED;
+                status = psa_mac_sign_finish( &hkdf->hmac,
+                                            hkdf->prk,
+                                            sizeof( hkdf->prk ),
+                                            &data_length );
+                if( status != PSA_SUCCESS )
+                    return( status );
             }
-            if( hkdf->state != HKDF_STATE_STARTED )
-                return( PSA_ERROR_BAD_STATE );
-            status = psa_mac_update( &hkdf->hmac,
-                                     data, data_length );
-            if( status != PSA_SUCCESS )
-                return( status );
-            status = psa_mac_sign_finish( &hkdf->hmac,
-                                          hkdf->prk,
-                                          sizeof( hkdf->prk ),
-                                          &data_length );
-            if( status != PSA_SUCCESS )
-                return( status );
+
             hkdf->offset_in_block = PSA_HASH_LENGTH( hash_alg );
             hkdf->block_number = 0;
             hkdf->state = HKDF_STATE_KEYED;
             return( PSA_SUCCESS );
         case PSA_KEY_DERIVATION_INPUT_INFO:
+            if( PSA_ALG_IS_HKDF_EXTRACT( kdf_alg ) )
+                return( PSA_ERROR_INVALID_ARGUMENT );
             if( hkdf->state == HKDF_STATE_OUTPUT )
                 return( PSA_ERROR_BAD_STATE );
             if( hkdf->info_set )
@@ -5488,10 +5522,10 @@
         goto exit;
 
 #if defined(MBEDTLS_PSA_BUILTIN_ALG_HKDF)
-    if( PSA_ALG_IS_HKDF( kdf_alg ) )
+    if( PSA_ALG_IS_HKDF( kdf_alg ) || PSA_ALG_IS_HKDF_EXTRACT( kdf_alg ) ||
+        PSA_ALG_IS_HKDF_EXPAND( kdf_alg ) )
     {
-        status = psa_hkdf_input( &operation->ctx.hkdf,
-                                 PSA_ALG_HKDF_GET_HASH( kdf_alg ),
+        status = psa_hkdf_input( &operation->ctx.hkdf, kdf_alg,
                                  step, data, data_length );
     }
     else