Handle simple copy import/export before driver dispatch

Signed-off-by: Przemek Stekiel <przemyslaw.stekiel@mobica.com>
diff --git a/library/psa_crypto.c b/library/psa_crypto.c
index 7b6f05b..242eb85 100644
--- a/library/psa_crypto.c
+++ b/library/psa_crypto.c
@@ -640,23 +640,6 @@
 
         return PSA_SUCCESS;
     } else if (PSA_KEY_TYPE_IS_ASYMMETRIC(type)) {
-#if defined(MBEDTLS_PSA_BUILTIN_KEY_TYPE_DH_KEY_PAIR) || \
-        defined(MBEDTLS_PSA_BUILTIN_KEY_TYPE_DH_PUBLIC_KEY)
-        if (PSA_KEY_TYPE_IS_DH(type)) {
-            if (psa_is_dh_key_size_valid(PSA_BYTES_TO_BITS(data_length)) == 0) {
-                return PSA_ERROR_INVALID_ARGUMENT;
-            }
-
-            /* Copy the key material. */
-            memcpy(key_buffer, data, data_length);
-            *key_buffer_length = data_length;
-            *bits = PSA_BYTES_TO_BITS(data_length);
-            (void) key_buffer_size;
-
-            return PSA_SUCCESS;
-        }
-#endif /* defined(MBEDTLS_PSA_BUILTIN_KEY_TYPE_DH_KEY_PAIR) ||
-        * defined(MBEDTLS_PSA_BUILTIN_KEY_TYPE_DH_PUBLIC_KEY) */
 #if defined(MBEDTLS_PSA_BUILTIN_KEY_TYPE_ECC_KEY_PAIR) || \
         defined(MBEDTLS_PSA_BUILTIN_KEY_TYPE_ECC_PUBLIC_KEY)
         if (PSA_KEY_TYPE_IS_ECC(type)) {
@@ -1426,14 +1409,7 @@
 {
     psa_key_type_t type = attributes->core.type;
 
-    if (PSA_KEY_TYPE_IS_PUBLIC_KEY(type) &&
-        (PSA_KEY_TYPE_IS_RSA(type) || PSA_KEY_TYPE_IS_ECC(type) ||
-         PSA_KEY_TYPE_IS_DH(type))) {
-        /* Exporting public -> public */
-        return psa_export_key_buffer_internal(
-            key_buffer, key_buffer_size,
-            data, data_size, data_length);
-    } else if (PSA_KEY_TYPE_IS_RSA(type)) {
+    if (PSA_KEY_TYPE_IS_RSA(type)) {
 #if defined(MBEDTLS_PSA_BUILTIN_KEY_TYPE_RSA_KEY_PAIR) || \
         defined(MBEDTLS_PSA_BUILTIN_KEY_TYPE_RSA_PUBLIC_KEY)
         return mbedtls_psa_rsa_export_public_key(attributes,
@@ -1514,9 +1490,23 @@
     psa_key_attributes_t attributes = {
         .core = slot->attr
     };
-    status = psa_driver_wrapper_export_public_key(
-        &attributes, slot->key.data, slot->key.bytes,
-        data, data_size, data_length);
+
+    psa_key_location_t location = PSA_KEY_LIFETIME_GET_LOCATION(
+        psa_get_key_lifetime(&attributes));
+
+    if (location == PSA_KEY_LOCATION_LOCAL_STORAGE &&
+        PSA_KEY_TYPE_IS_PUBLIC_KEY(slot->attr.type) &&
+        (PSA_KEY_TYPE_IS_RSA(slot->attr.type) || PSA_KEY_TYPE_IS_ECC(slot->attr.type) ||
+         PSA_KEY_TYPE_IS_DH(slot->attr.type))) {
+        /* Local public key: public -> public export is a plain buffer copy. */
+        status = psa_export_key_buffer_internal(
+            slot->key.data, slot->key.bytes,
+            data, data_size, data_length);
+    } else {
+        status = psa_driver_wrapper_export_public_key(
+            &attributes, slot->key.data, slot->key.bytes,
+            data, data_size, data_length);
+    }
 
 exit:
     unlock_status = psa_unlock_key_slot(slot);
@@ -2011,12 +2001,39 @@
         }
     }
 
-    bits = slot->attr.bits;
-    status = psa_driver_wrapper_import_key(attributes,
-                                           data, data_length,
-                                           slot->key.data,
-                                           slot->key.bytes,
-                                           &slot->key.bytes, &bits);
+    if (PSA_KEY_TYPE_IS_ASYMMETRIC(attributes->core.type) &&
+        PSA_KEY_TYPE_IS_DH(attributes->core.type) &&
+        PSA_KEY_LIFETIME_GET_LOCATION(psa_get_key_lifetime(attributes)) ==
+        PSA_KEY_LOCATION_LOCAL_STORAGE) {
+        /* A DH key in local storage is kept as its raw bytes, so the core
+         * validates and copies it itself. Keys in other locations still go
+         * through the driver wrapper (export applies the same restriction). */
+        if (psa_is_dh_key_size_valid(PSA_BYTES_TO_BITS(data_length)) == 0) {
+            status = PSA_ERROR_INVALID_ARGUMENT;
+            goto exit;
+        }
+
+        /* Defensive: never copy past the key buffer allocated for the slot. */
+        if (data_length > slot->key.bytes) {
+            status = PSA_ERROR_BUFFER_TOO_SMALL;
+            goto exit;
+        }
+
+        /* Copy the key material and record its actual length. */
+        memcpy(slot->key.data, data, data_length);
+        slot->key.bytes = data_length;
+        bits = PSA_BYTES_TO_BITS(data_length);
+
+        status = PSA_SUCCESS;
+    } else {
+        bits = slot->attr.bits;
+        status = psa_driver_wrapper_import_key(attributes,
+                                               data, data_length,
+                                               slot->key.data,
+                                               slot->key.bytes,
+                                               &slot->key.bytes, &bits);
+    }
+
     if (status != PSA_SUCCESS) {
         goto exit;
     }
@@ -5831,11 +5836,37 @@
         goto exit;
     }
 
-    status = psa_driver_wrapper_import_key(&attributes,
-                                           data, bytes,
-                                           slot->key.data,
-                                           slot->key.bytes,
-                                           &slot->key.bytes, &bits);
+    if (PSA_KEY_TYPE_IS_ASYMMETRIC(attributes.core.type) &&
+        PSA_KEY_TYPE_IS_DH(attributes.core.type) &&
+        PSA_KEY_LIFETIME_GET_LOCATION(psa_get_key_lifetime(&attributes)) ==
+        PSA_KEY_LOCATION_LOCAL_STORAGE) {
+        /* A DH key in local storage is kept as its raw bytes, so the core
+         * validates and copies it itself. Keys in other locations still go
+         * through the driver wrapper (export applies the same restriction). */
+        if (psa_is_dh_key_size_valid(PSA_BYTES_TO_BITS(bytes)) == 0) {
+            status = PSA_ERROR_INVALID_ARGUMENT;
+            goto exit;
+        }
+
+        /* Defensive: never copy past the key buffer allocated for the slot. */
+        if (bytes > slot->key.bytes) {
+            status = PSA_ERROR_BUFFER_TOO_SMALL;
+            goto exit;
+        }
+
+        /* Copy the key material and record its actual length. */
+        memcpy(slot->key.data, data, bytes);
+        slot->key.bytes = bytes;
+        bits = PSA_BYTES_TO_BITS(bytes);
+
+        status = PSA_SUCCESS;
+    } else {
+        status = psa_driver_wrapper_import_key(&attributes,
+                                               data, bytes,
+                                               slot->key.data,
+                                               slot->key.bytes,
+                                               &slot->key.bytes, &bits);
+    }
     if (bits != slot->attr.bits) {
         status = PSA_ERROR_INVALID_ARGUMENT;
     }