Fix possible issues in testing and implementation of psa_key_agreement()

Signed-off-by: Waleed Elmelegy <waleed.elmelegy@arm.com>
diff --git a/tf-psa-crypto/core/psa_crypto.c b/tf-psa-crypto/core/psa_crypto.c
index e987625..fd05128 100644
--- a/tf-psa-crypto/core/psa_crypto.c
+++ b/tf-psa-crypto/core/psa_crypto.c
@@ -7711,14 +7711,8 @@
     uint8_t shared_secret[PSA_RAW_KEY_AGREEMENT_OUTPUT_MAX_SIZE];
     size_t shared_secret_len;
     psa_key_type_t key_type;
-    size_t key_size = PSA_RAW_KEY_AGREEMENT_OUTPUT_MAX_SIZE;
-    psa_algorithm_t key_alg;
 
-#if !defined(MBEDTLS_PSA_CRYPTO_KEY_ID_ENCODES_OWNER)
-    *key = PSA_KEY_ID_NULL;
-#else
-    key->key_id = PSA_KEY_ID_NULL;
-#endif
+    *key = MBEDTLS_SVC_KEY_ID_INIT;
 
     key_type = psa_get_key_type(attributes);
     if (key_type != PSA_KEY_TYPE_DERIVE && key_type != PSA_KEY_TYPE_RAW_DATA
@@ -7726,31 +7720,15 @@
         return PSA_ERROR_INVALID_ARGUMENT;
     }
 
-    key_alg = psa_get_key_algorithm(attributes);
-    if (key_alg != PSA_ALG_ECDH && key_alg != PSA_ALG_FFDH) {
-        return PSA_ERROR_INVALID_ARGUMENT;
-    }
-
-    if (psa_get_key_bits(attributes) != 0) {
-        key_size = PSA_BITS_TO_BYTES(psa_get_key_bits(attributes));
-    }
-
     status = psa_raw_key_agreement(alg, private_key, peer_key, peer_key_length, shared_secret,
-                                   key_size, &shared_secret_len);
+                                   sizeof(shared_secret), &shared_secret_len);
 
-    if (status == PSA_SUCCESS) {
-
-        psa_key_attributes_t shared_secret_attributes = PSA_KEY_ATTRIBUTES_INIT;
-        psa_set_key_type(&shared_secret_attributes, key_type);
-        psa_set_key_usage_flags(&shared_secret_attributes, psa_get_key_usage_flags(attributes));
-        psa_set_key_algorithm(&shared_secret_attributes, key_alg);
-        psa_set_key_lifetime(&shared_secret_attributes, psa_get_key_lifetime(attributes));
-        psa_set_key_bits(&shared_secret_attributes, shared_secret_len * 8);
-
-        status = psa_import_key(&shared_secret_attributes, shared_secret,
-                                shared_secret_len, key);
+    if (status != PSA_SUCCESS) {
+        return status;
     }
 
+    status = psa_import_key(attributes, shared_secret, shared_secret_len, key);
+
     return status;
 }
 
diff --git a/tf-psa-crypto/tests/suites/test_suite_psa_crypto.function b/tf-psa-crypto/tests/suites/test_suite_psa_crypto.function
index db2ac33..cee73b0 100644
--- a/tf-psa-crypto/tests/suites/test_suite_psa_crypto.function
+++ b/tf-psa-crypto/tests/suites/test_suite_psa_crypto.function
@@ -9733,6 +9733,7 @@
     size_t key_bits;
     mbedtls_svc_key_id_t shared_secret_id = MBEDTLS_SVC_KEY_ID_INIT;
     psa_key_attributes_t shared_secret_attributes = PSA_KEY_ATTRIBUTES_INIT;
+    psa_key_attributes_t output_attributes;
 
     PSA_ASSERT(psa_crypto_init());
 
@@ -9761,12 +9762,11 @@
     TEST_MEMORY_COMPARE(output, output_length,
                         expected_output->x, expected_output->len);
 
-    mbedtls_platform_zeroize(output, expected_output->len);
+    memset(output, 0, expected_output->len);
     output_length = 0;
 
     psa_set_key_type(&shared_secret_attributes, PSA_KEY_TYPE_DERIVE);
     psa_set_key_usage_flags(&shared_secret_attributes, PSA_KEY_USAGE_DERIVE | PSA_KEY_USAGE_EXPORT);
-    psa_set_key_algorithm(&shared_secret_attributes, PSA_ALG_ECDH);
 
     PSA_ASSERT(psa_key_agreement(our_key, peer_key_data->x, peer_key_data->len,
                                  alg, &shared_secret_attributes, &shared_secret_id));
@@ -9776,6 +9776,14 @@
     TEST_MEMORY_COMPARE(output, output_length,
                         expected_output->x, expected_output->len);
 
+    PSA_ASSERT(psa_get_key_attributes(shared_secret_id, &output_attributes));
+
+    TEST_EQUAL(PSA_BITS_TO_BYTES(psa_get_key_bits(&output_attributes)),
+               expected_output->len);
+    TEST_EQUAL(psa_get_key_type(&output_attributes), PSA_KEY_TYPE_DERIVE);
+    TEST_EQUAL(psa_get_key_usage_flags(&output_attributes),
+               PSA_KEY_USAGE_DERIVE | PSA_KEY_USAGE_EXPORT);
+
     mbedtls_free(output);
     output = NULL;
     output_length = ~0;
@@ -9791,18 +9799,6 @@
     TEST_MEMORY_COMPARE(output, output_length,
                         expected_output->x, expected_output->len);
 
-    mbedtls_platform_zeroize(output, expected_output->len + 1);
-    output_length = 0;
-
-    psa_set_key_bits(&shared_secret_attributes, (expected_output->len + 1) * 8);
-    PSA_ASSERT(psa_key_agreement(our_key, peer_key_data->x, peer_key_data->len,
-                                 alg, &shared_secret_attributes, &shared_secret_id));
-
-    PSA_ASSERT(psa_export_key(shared_secret_id, output, expected_output->len + 1, &output_length));
-
-    TEST_MEMORY_COMPARE(output, output_length,
-                        expected_output->x, expected_output->len);
-
     mbedtls_free(output);
     output = NULL;
     output_length = ~0;
@@ -9819,20 +9815,6 @@
     /* Not required by the spec, but good robustness */
     TEST_LE_U(output_length, expected_output->len - 1);
 
-    mbedtls_platform_zeroize(output, expected_output->len - 1);
-    output_length = 0;
-
-    psa_set_key_bits(&shared_secret_attributes, (expected_output->len - 1) * 8);
-    TEST_EQUAL(psa_key_agreement(our_key, peer_key_data->x, peer_key_data->len,
-                                 alg, &shared_secret_attributes, &shared_secret_id),
-               PSA_ERROR_BUFFER_TOO_SMALL);
-
-#if !defined(MBEDTLS_PSA_CRYPTO_KEY_ID_ENCODES_OWNER)
-    TEST_EQUAL(shared_secret_id, PSA_KEY_ID_NULL);
-#else
-    TEST_EQUAL(shared_secret_id.key_id, PSA_KEY_ID_NULL);
-#endif
-
     mbedtls_free(output);
     output = NULL;