Move the common PAKE parameter checks out of the wrapper

Signed-off-by: Neil Armstrong <narmstrong@baylibre.com>
diff --git a/library/psa_crypto.c b/library/psa_crypto.c
index 2cd4ee7..3494ae7 100644
--- a/library/psa_crypto.c
+++ b/library/psa_crypto.c
@@ -7168,6 +7168,19 @@
     psa_pake_operation_t *operation,
     const psa_pake_cipher_suite_t *cipher_suite)
 {
+    /* A context must be freshly initialized before it can be set up. */
+    if (operation->alg != PSA_ALG_NONE) {
+        return PSA_ERROR_BAD_STATE;
+    }
+
+    if (cipher_suite == NULL ||
+        PSA_ALG_IS_PAKE(cipher_suite->algorithm) == 0 ||
+        (cipher_suite->type != PSA_PAKE_PRIMITIVE_TYPE_ECC &&
+         cipher_suite->type != PSA_PAKE_PRIMITIVE_TYPE_DH) ||
+        PSA_ALG_IS_HASH(cipher_suite->hash) == 0) {
+        return PSA_ERROR_INVALID_ARGUMENT;
+    }
+
     return psa_driver_wrapper_pake_setup(operation, cipher_suite);
 }
 
@@ -7175,6 +7188,34 @@
     psa_pake_operation_t *operation,
     mbedtls_svc_key_id_t password)
 {
+    psa_status_t status = PSA_ERROR_CORRUPTION_DETECTED;
+    psa_key_attributes_t attributes = psa_key_attributes_init();
+    psa_key_type_t type;
+    psa_key_usage_t usage;
+
+    if (operation->alg == PSA_ALG_NONE) {
+        return PSA_ERROR_BAD_STATE;
+    }
+
+    status = psa_get_key_attributes(password, &attributes);
+    if (status != PSA_SUCCESS) {
+        return status;
+    }
+
+    type = psa_get_key_type(&attributes);
+    usage = psa_get_key_usage_flags(&attributes);
+
+    psa_reset_key_attributes(&attributes);
+
+    if (type != PSA_KEY_TYPE_PASSWORD &&
+        type != PSA_KEY_TYPE_PASSWORD_HASH) {
+        return PSA_ERROR_INVALID_ARGUMENT;
+    }
+
+    if ((usage & PSA_KEY_USAGE_DERIVE) == 0) {
+        return PSA_ERROR_NOT_PERMITTED;
+    }
+
     return psa_driver_wrapper_pake_set_password_key(operation, password);
 }
 
@@ -7183,6 +7224,14 @@
     const uint8_t *user_id,
     size_t user_id_len)
 {
+    if (operation->alg == PSA_ALG_NONE) {
+        return PSA_ERROR_BAD_STATE;
+    }
+
+    if (user_id_len == 0 || user_id == NULL) {
+        return PSA_ERROR_INVALID_ARGUMENT;
+    }
+
     return psa_driver_wrapper_pake_set_user(operation, user_id,
                                             user_id_len);
 }
@@ -7192,6 +7241,14 @@
     const uint8_t *peer_id,
     size_t peer_id_len)
 {
+    if (operation->alg == PSA_ALG_NONE) {
+        return PSA_ERROR_BAD_STATE;
+    }
+
+    if (peer_id_len == 0 || peer_id == NULL) {
+        return PSA_ERROR_INVALID_ARGUMENT;
+    }
+
     return psa_driver_wrapper_pake_set_peer(operation, peer_id,
                                             peer_id_len);
 }
@@ -7200,6 +7257,18 @@
     psa_pake_operation_t *operation,
     psa_pake_role_t role)
 {
+    if (operation->alg == PSA_ALG_NONE) {
+        return PSA_ERROR_BAD_STATE;
+    }
+
+    if (role != PSA_PAKE_ROLE_NONE &&
+        role != PSA_PAKE_ROLE_FIRST &&
+        role != PSA_PAKE_ROLE_SECOND &&
+        role != PSA_PAKE_ROLE_CLIENT &&
+        role != PSA_PAKE_ROLE_SERVER) {
+        return PSA_ERROR_INVALID_ARGUMENT;
+    }
+
     return psa_driver_wrapper_pake_set_role(operation, role);
 }
 
@@ -7210,6 +7279,14 @@
     size_t output_size,
     size_t *output_length)
 {
+    if (operation->alg == PSA_ALG_NONE) {
+        return PSA_ERROR_BAD_STATE;
+    }
+
+    if (output == NULL || output_size == 0 || output_length == NULL) {
+        return PSA_ERROR_INVALID_ARGUMENT;
+    }
+
     return psa_driver_wrapper_pake_output(operation, step, output,
                                           output_size, output_length);
 }
@@ -7220,6 +7297,14 @@
     const uint8_t *input,
     size_t input_length)
 {
+    if (operation->alg == PSA_ALG_NONE) {
+        return PSA_ERROR_BAD_STATE;
+    }
+
+    if (input == NULL || input_length == 0) {
+        return PSA_ERROR_INVALID_ARGUMENT;
+    }
+
     return psa_driver_wrapper_pake_input(operation, step, input,
                                          input_length);
 }
@@ -7228,12 +7313,20 @@
     psa_pake_operation_t *operation,
     psa_key_derivation_operation_t *output)
 {
+    if (operation->alg == PSA_ALG_NONE) {
+        return PSA_ERROR_BAD_STATE;
+    }
+
     return psa_driver_wrapper_pake_get_implicit_key(operation, output);
 }
 
 psa_status_t psa_pake_abort(
     psa_pake_operation_t *operation)
 {
+    if (operation->alg == PSA_ALG_NONE) {
+        return PSA_SUCCESS; /* operation not set up: aborting it is a no-op */
+    }
+    /* Operation was set up: delegate the abort to the driver wrapper. */
     return psa_driver_wrapper_pake_abort(operation);
 }
 #endif /* MBEDTLS_PSA_BUILTIN_PAKE */