test PSA key derivation: add positive and negative cases for mixed-psk

Mixed-PSK-to-MS test vectors were generated using the python-tls library:
https://github.com/python-tls/tls

Steps to generate test vectors:
1. git clone git@github.com:python-tls/tls.git
2. cd tls
3. python3 setup.py build
4. sudo python3 setup.py install
5. Use the Python script below to generate the master secret (see its docstring for usage details):

"""
Script to derive the TLS 1.2 master secret (MS) using the mixed-PSK-to-MS algorithm.

The script can be used to generate expected results for the mixed-PSK-to-MS tests.

The script uses the python-tls library:
https://github.com/python-tls/tls

Example usage:
derive_ms.py <secret> <other_secret> <seed> <label> <hash>
derive_ms.py 01020304 ce2fa604b6a3e08fc42eda74ab647adace1168b199ed178dbaae12521d68271d7df56eb56c55878034cf01bd887ba4d7 5bc0b19b4a8b24b07afe7ec65c471e94a7d518fcef06c3574315255c52afe21b5bc0b19b872b9b26508458f03603744d575f463a11ae7f1b090c012606fd3e9f 6d617374657220736563726574 SHA256

secret       : 01020304
other_secret : ce2fa604b6a3e08fc42eda74ab647adace1168b199ed178dbaae12521d68271d7df56eb56c55878034cf01bd887ba4d7
pms          : 0030ce2fa604b6a3e08fc42eda74ab647adace1168b199ed178dbaae12521d68271d7df56eb56c55878034cf01bd887ba4d7000401020304
seed         : 5bc0b19b4a8b24b07afe7ec65c471e94a7d518fcef06c3574315255c52afe21b5bc0b19b872b9b26508458f03603744d575f463a11ae7f1b090c012606fd3e9f
label        : 6d617374657220736563726574
output       : 168fecea35190f9df34c042f24ecaa5e7825337f2cd82719464df5462f16aae84cb38a65c0d612ca9273f998ad32c05b
"""
from cryptography.hazmat.primitives import hashes
from tls._common.prf import prf
import os
import sys

def build_pms(other_secret: bytes, secret: bytes) -> bytes:
    """Assemble the mixed-PSK premaster secret.

    The result is ``len(other_secret) || other_secret || len(secret) || secret``,
    where each length is a 2-byte big-endian unsigned integer. This mirrors the
    premaster-secret layout from RFC 4279 section 2.
    """
    pieces = []
    for chunk in (other_secret, secret):
        # Length prefix first, then the chunk itself.
        pieces.append(len(chunk).to_bytes(2, byteorder='big'))
        pieces.append(chunk)
    return b''.join(pieces)

def derive_ms(secret: bytes, other_secret: bytes, seed: bytes, label: bytes, hash: hashes.HashAlgorithm) -> bytes:
    """Derive a 48-byte master secret from ``secret`` and ``other_secret``.

    The premaster secret is assembled by ``build_pms`` (other_secret first,
    then secret). It is then expanded with python-tls's ``prf`` using
    ``label``, ``seed`` and the hash algorithm ``hash``.

    The parameter keeps the name ``hash``, even though it shadows the builtin,
    so that existing keyword-argument callers keep working.
    """
    master_secret_length = 48  # TLS 1.2 master secrets are 48 bytes long
    premaster_secret = build_pms(other_secret, secret)
    return prf(premaster_secret, label, seed, hash, master_secret_length)

def main():
    """Command-line entry point: derive and print a mixed-PSK master secret.

    Usage: ``<secret> <other_secret> <seed> <label> <hash>``. The first four
    arguments are hex strings; ``<hash>`` must be ``SHA256`` or ``SHA384``.
    On success, prints the decoded inputs, the assembled premaster secret and
    the derived master secret, each hex-encoded.

    On invalid input, writes an error message to stderr and exits with a
    non-zero status via ``sys.exit``. The original printed to stdout and
    exited 0, so calling scripts could not detect failure.
    """
    args = sys.argv[1:]
    if len(args) != 5:
        sys.exit("Invalid number of arguments. Expected: <secret> <other_secret> <seed> <label> <hash>")

    secret_hex, other_secret_hex, seed_hex, label_hex, hash_name = args
    if hash_name not in ('SHA256', 'SHA384'):
        sys.exit("Invalid hash algorithm. Expected: SHA256 or SHA384")

    # Decode each hex argument separately so the error names the bad one
    # instead of surfacing a bare ValueError traceback.
    decoded = []
    for arg_name, arg_value in (('secret', secret_hex),
                                ('other_secret', other_secret_hex),
                                ('seed', seed_hex),
                                ('label', label_hex)):
        try:
            decoded.append(bytes.fromhex(arg_value))
        except ValueError as exc:
            sys.exit("Invalid hex value for <{}>: {}".format(arg_name, exc))
    secret, other_secret, seed, label = decoded

    hash_func = hashes.SHA384() if hash_name == 'SHA384' else hashes.SHA256()
    pms = build_pms(other_secret, secret)

    actual_output = derive_ms(secret, other_secret, seed, label, hash_func)

    print('secret       : ' + secret.hex())
    print('other_secret : ' + other_secret.hex())
    print('pms          : ' + pms.hex())
    print('seed         : ' + seed.hex())
    print('label        : ' + label.hex())
    print('output       : ' + actual_output.hex())

if __name__ == '__main__':
    # Run the command-line interface only when executed as a script,
    # not when this module is imported.
    main()

Signed-off-by: Przemek Stekiel <przemyslaw.stekiel@mobica.com>
diff --git a/tests/suites/test_suite_psa_crypto.function b/tests/suites/test_suite_psa_crypto.function
index da49631..6562d10 100644
--- a/tests/suites/test_suite_psa_crypto.function
+++ b/tests/suites/test_suite_psa_crypto.function
@@ -6897,21 +6897,27 @@
 
 /* BEGIN_CASE */
 void derive_output( int alg_arg,
-                    int step1_arg, data_t *input1,
-                    int step2_arg, data_t *input2,
-                    int step3_arg, data_t *input3,
-                    int step4_arg, data_t *input4,
+                    int step1_arg, data_t *input1, int expected_status_arg1,
+                    int step2_arg, data_t *input2, int expected_status_arg2,
+                    int step3_arg, data_t *input3, int expected_status_arg3,
+                    int step4_arg, data_t *input4, int expected_status_arg4,
+                    data_t *key_agreement_peer_key,
                     int requested_capacity_arg,
                     data_t *expected_output1,
-                    data_t *expected_output2 )
+                    data_t *expected_output2,
+                    int other_key_input_type,
+                    int key_input_type,
+                    int derive_type )
 {
     psa_algorithm_t alg = alg_arg;
     psa_key_derivation_step_t steps[] = {step1_arg, step2_arg, step3_arg, step4_arg};
     data_t *inputs[] = {input1, input2, input3, input4};
-    mbedtls_svc_key_id_t keys[] = { MBEDTLS_SVC_KEY_ID_INIT,
-                                    MBEDTLS_SVC_KEY_ID_INIT,
-                                    MBEDTLS_SVC_KEY_ID_INIT,
-                                    MBEDTLS_SVC_KEY_ID_INIT};
+    mbedtls_svc_key_id_t keys[] = {MBEDTLS_SVC_KEY_ID_INIT,
+                                   MBEDTLS_SVC_KEY_ID_INIT,
+                                   MBEDTLS_SVC_KEY_ID_INIT,
+                                   MBEDTLS_SVC_KEY_ID_INIT};
+    psa_status_t statuses[] = {expected_status_arg1, expected_status_arg2,
+                               expected_status_arg3, expected_status_arg4};
     size_t requested_capacity = requested_capacity_arg;
     psa_key_derivation_operation_t operation = PSA_KEY_DERIVATION_OPERATION_INIT;
     uint8_t *expected_outputs[2] =
@@ -6922,7 +6928,10 @@
     uint8_t *output_buffer = NULL;
     size_t expected_capacity;
     size_t current_capacity;
-    psa_key_attributes_t attributes = PSA_KEY_ATTRIBUTES_INIT;
+    psa_key_attributes_t attributes1 = PSA_KEY_ATTRIBUTES_INIT;
+    psa_key_attributes_t attributes2 = PSA_KEY_ATTRIBUTES_INIT;
+    psa_key_attributes_t attributes3 = PSA_KEY_ATTRIBUTES_INIT;
+    psa_key_attributes_t attributes4 = PSA_KEY_ATTRIBUTES_INIT;
     psa_status_t status;
     size_t i;
 
@@ -6936,10 +6945,6 @@
     ASSERT_ALLOC( output_buffer, output_buffer_size );
     PSA_ASSERT( psa_crypto_init( ) );
 
-    psa_set_key_usage_flags( &attributes, PSA_KEY_USAGE_DERIVE );
-    psa_set_key_algorithm( &attributes, alg );
-    psa_set_key_type( &attributes, PSA_KEY_TYPE_DERIVE );
-
     /* Extraction phase. */
     PSA_ASSERT( psa_key_derivation_setup( &operation, alg ) );
     PSA_ASSERT( psa_key_derivation_set_capacity( &operation,
@@ -6951,19 +6956,107 @@
             case 0:
                 break;
             case PSA_KEY_DERIVATION_INPUT_SECRET:
-                PSA_ASSERT( psa_import_key( &attributes,
-                                            inputs[i]->x, inputs[i]->len,
-                                            &keys[i] ) );
-
-                if ( PSA_ALG_IS_TLS12_PSK_TO_MS( alg ) )
+                switch ( key_input_type )
                 {
-                    PSA_ASSERT( psa_get_key_attributes( keys[i], &attributes ) );
-                    TEST_ASSERT( PSA_BITS_TO_BYTES( psa_get_key_bits( &attributes ) ) <=
-                                 PSA_TLS12_PSK_TO_MS_PSK_MAX_SIZE );
+                    case 0: // input bytes
+                        PSA_ASSERT( psa_key_derivation_input_bytes(
+                                        &operation, steps[i],
+                                        inputs[i]->x, inputs[i]->len ) );
+                        break;
+                    case 1: // input key
+                        psa_set_key_usage_flags( &attributes1, PSA_KEY_USAGE_DERIVE );
+                        psa_set_key_algorithm( &attributes1, alg );
+                        psa_set_key_type( &attributes1, PSA_KEY_TYPE_DERIVE );
+
+                        PSA_ASSERT( psa_import_key( &attributes1,
+                                                    inputs[i]->x, inputs[i]->len,
+                                                    &keys[i] ) );
+
+                        if ( PSA_ALG_IS_TLS12_PSK_TO_MS( alg ) )
+                        {
+                            PSA_ASSERT( psa_get_key_attributes( keys[i], &attributes1 ) );
+                            TEST_ASSERT( PSA_BITS_TO_BYTES( psa_get_key_bits( &attributes1 ) ) <=
+                                        PSA_TLS12_PSK_TO_MS_PSK_MAX_SIZE );
+                        }
+
+                        PSA_ASSERT( psa_key_derivation_input_key(
+                                        &operation, steps[i], keys[i] ) );
+                        break;
+                    default:
+                        TEST_ASSERT( ! "default case not supported" );
+                        break;
+                }
+                break;
+            case PSA_KEY_DERIVATION_INPUT_OTHER_SECRET:
+                switch ( other_key_input_type )
+                {
+                    case 0: // input bytes
+                        TEST_EQUAL( psa_key_derivation_input_bytes(
+                            &operation, steps[i],
+                            inputs[i]->x, inputs[i]->len ), statuses[i] );
+                        break;
+                    case 1: // input key
+                        psa_set_key_usage_flags( &attributes2, PSA_KEY_USAGE_DERIVE );
+                        psa_set_key_algorithm( &attributes2, alg );
+                        psa_set_key_type( &attributes2, PSA_KEY_TYPE_DERIVE );
+
+                        // other secret of type RAW_DATA passed with input_key
+                        if ( statuses[i] == PSA_ERROR_INVALID_ARGUMENT )
+                            psa_set_key_type( &attributes2, PSA_KEY_TYPE_RAW_DATA );
+
+                        PSA_ASSERT( psa_import_key( &attributes2,
+                            inputs[i]->x, inputs[i]->len,
+                            &keys[i] ) );
+
+                        TEST_EQUAL( psa_key_derivation_input_key(
+                                        &operation, steps[i], keys[i] ), statuses[i] );
+                        break;
+                    case 2: // key agreement
+                        psa_set_key_usage_flags( &attributes3, PSA_KEY_USAGE_DERIVE );
+                        psa_set_key_algorithm( &attributes3, alg );
+                        psa_set_key_type( &attributes3, PSA_KEY_TYPE_ECC_KEY_PAIR(PSA_ECC_FAMILY_SECP_R1) );
+
+                        PSA_ASSERT( psa_import_key( &attributes3,
+                            inputs[i]->x, inputs[i]->len,
+                            &keys[i] ) );
+
+                        TEST_EQUAL( psa_key_derivation_key_agreement(
+                                        &operation,
+                                        PSA_KEY_DERIVATION_INPUT_OTHER_SECRET,
+                                        keys[i], key_agreement_peer_key->x,
+                                        key_agreement_peer_key->len ), statuses[i] );
+                        break;
+                    case 3: // raw key agreement
+                        psa_set_key_usage_flags( &attributes3, PSA_KEY_USAGE_DERIVE );
+                        psa_set_key_algorithm( &attributes3, PSA_ALG_ECDH );
+                        psa_set_key_type( &attributes3, PSA_KEY_TYPE_ECC_KEY_PAIR(PSA_ECC_FAMILY_SECP_R1) );
+
+                        uint8_t key_agreement_output[32];
+                        size_t key_agreement_output_length;
+
+                        PSA_ASSERT( psa_import_key( &attributes3,
+                            inputs[i]->x, inputs[i]->len,
+                            &keys[i] ) );
+
+                        PSA_ASSERT( psa_raw_key_agreement(
+                                        PSA_ALG_ECDH,
+                                        keys[i], key_agreement_peer_key->x,
+                                        key_agreement_peer_key->len, key_agreement_output,
+                                        sizeof(key_agreement_output), &key_agreement_output_length ) );
+
+                        TEST_ASSERT( key_agreement_output_length == 32 );
+
+                        TEST_EQUAL( psa_key_derivation_input_bytes(
+                            &operation, steps[i],
+                            key_agreement_output, key_agreement_output_length ), statuses[i] );
+                        break;
+                    default:
+                        TEST_ASSERT( ! "default case not supported" );
+                        break;
                 }
 
-                PSA_ASSERT( psa_key_derivation_input_key(
-                                &operation, steps[i], keys[i] ) );
+                if ( statuses[i] != PSA_SUCCESS )
+                    goto exit;
                 break;
             default:
                 PSA_ASSERT( psa_key_derivation_input_bytes(
@@ -6978,37 +7071,54 @@
     TEST_EQUAL( current_capacity, requested_capacity );
     expected_capacity = requested_capacity;
 
-    /* Expansion phase. */
-    for( i = 0; i < ARRAY_LENGTH( expected_outputs ); i++ )
+    if( derive_type == 1 ) // output key
     {
-        /* Read some bytes. */
-        status = psa_key_derivation_output_bytes( &operation,
-                                                  output_buffer, output_sizes[i] );
-        if( expected_capacity == 0 && output_sizes[i] == 0 )
+        /* Test that output key derivation is not permitted when secret is
+         * passed using input bytes and other secret is passed using input key. */
+        mbedtls_svc_key_id_t derived_key = MBEDTLS_SVC_KEY_ID_INIT;
+
+        psa_set_key_usage_flags( &attributes4, PSA_KEY_USAGE_EXPORT );
+        psa_set_key_algorithm( &attributes4, alg );
+        psa_set_key_type( &attributes4, PSA_KEY_TYPE_DERIVE );
+        psa_set_key_bits( &attributes4, 48 );
+
+        TEST_EQUAL( psa_key_derivation_output_key( &attributes4, &operation,
+                                        &derived_key ), PSA_ERROR_NOT_PERMITTED );
+    }
+    else // output bytes
+    {
+        /* Expansion phase. */
+        for( i = 0; i < ARRAY_LENGTH( expected_outputs ); i++ )
         {
-            /* Reading 0 bytes when 0 bytes are available can go either way. */
-            TEST_ASSERT( status == PSA_SUCCESS ||
-                         status == PSA_ERROR_INSUFFICIENT_DATA );
-            continue;
+            /* Read some bytes. */
+            status = psa_key_derivation_output_bytes( &operation,
+                                                    output_buffer, output_sizes[i] );
+            if( expected_capacity == 0 && output_sizes[i] == 0 )
+            {
+                /* Reading 0 bytes when 0 bytes are available can go either way. */
+                TEST_ASSERT( status == PSA_SUCCESS ||
+                            status == PSA_ERROR_INSUFFICIENT_DATA );
+                continue;
+            }
+            else if( expected_capacity == 0 ||
+                    output_sizes[i] > expected_capacity )
+            {
+                /* Capacity exceeded. */
+                TEST_EQUAL( status, PSA_ERROR_INSUFFICIENT_DATA );
+                expected_capacity = 0;
+                continue;
+            }
+            /* Success. Check the read data. */
+            PSA_ASSERT( status );
+            if( output_sizes[i] != 0 )
+                ASSERT_COMPARE( output_buffer, output_sizes[i],
+                                expected_outputs[i], output_sizes[i] );
+            /* Check the operation status. */
+            expected_capacity -= output_sizes[i];
+            PSA_ASSERT( psa_key_derivation_get_capacity( &operation,
+                                                        &current_capacity ) );
+            TEST_EQUAL( expected_capacity, current_capacity );
         }
-        else if( expected_capacity == 0 ||
-                 output_sizes[i] > expected_capacity )
-        {
-            /* Capacity exceeded. */
-            TEST_EQUAL( status, PSA_ERROR_INSUFFICIENT_DATA );
-            expected_capacity = 0;
-            continue;
-        }
-        /* Success. Check the read data. */
-        PSA_ASSERT( status );
-        if( output_sizes[i] != 0 )
-            ASSERT_COMPARE( output_buffer, output_sizes[i],
-                            expected_outputs[i], output_sizes[i] );
-        /* Check the operation status. */
-        expected_capacity -= output_sizes[i];
-        PSA_ASSERT( psa_key_derivation_get_capacity( &operation,
-                                                     &current_capacity ) );
-        TEST_EQUAL( expected_capacity, current_capacity );
     }
     PSA_ASSERT( psa_key_derivation_abort( &operation ) );