Divide the PAKE operation into two phases: collecting inputs and computation.
Functions that only set inputs do not have driver entry points.
Signed-off-by: Przemek Stekiel <przemyslaw.stekiel@mobica.com>
diff --git a/library/psa_crypto.c b/library/psa_crypto.c
index 18aa18b..4742c3c 100644
--- a/library/psa_crypto.c
+++ b/library/psa_crypto.c
@@ -7180,7 +7180,29 @@
psa_pake_operation_t *operation,
const psa_pake_cipher_suite_t *cipher_suite)
{
- return psa_driver_wrapper_pake_setup(operation, cipher_suite);
+ if (operation->stage != PSA_PAKE_OPERATION_STAGE_COLLECT_INPUTS) {
+ return PSA_ERROR_BAD_STATE;
+ }
+
+ if (operation->data.inputs.alg != PSA_ALG_NONE) {
+ return PSA_ERROR_BAD_STATE;
+ }
+
+ if (cipher_suite == NULL ||
+ PSA_ALG_IS_PAKE(cipher_suite->algorithm) == 0 ||
+ (cipher_suite->type != PSA_PAKE_PRIMITIVE_TYPE_ECC &&
+ cipher_suite->type != PSA_PAKE_PRIMITIVE_TYPE_DH) ||
+ PSA_ALG_IS_HASH(cipher_suite->hash) == 0) {
+ return PSA_ERROR_INVALID_ARGUMENT;
+ }
+
+ ;
+ memset(&operation->data.inputs, 0, sizeof(operation->data.inputs));
+
+ operation->data.inputs.alg = cipher_suite->algorithm;
+ operation->data.inputs.cipher_suite = *cipher_suite;
+
+ return PSA_SUCCESS;
}
psa_status_t psa_pake_set_password_key(
@@ -7191,7 +7213,11 @@
psa_status_t unlock_status = PSA_ERROR_CORRUPTION_DETECTED;
psa_key_slot_t *slot = NULL;
- if (operation->id == 0) {
+ if (operation->stage != PSA_PAKE_OPERATION_STAGE_COLLECT_INPUTS) {
+ return PSA_ERROR_BAD_STATE;
+ }
+
+ if (operation->data.inputs.alg == PSA_ALG_NONE) {
return PSA_ERROR_BAD_STATE;
}
@@ -7206,9 +7232,29 @@
.core = slot->attr
};
- status = psa_driver_wrapper_pake_set_password_key(&attributes, operation,
- slot->key.data, slot->key.bytes);
+ psa_key_type_t type = psa_get_key_type(&attributes);
+ psa_key_usage_t usage = psa_get_key_usage_flags(&attributes);
+ if (type != PSA_KEY_TYPE_PASSWORD &&
+ type != PSA_KEY_TYPE_PASSWORD_HASH) {
+ status = PSA_ERROR_INVALID_ARGUMENT;
+ goto error;
+ }
+
+ if ((usage & PSA_KEY_USAGE_DERIVE) == 0) {
+ status = PSA_ERROR_NOT_PERMITTED;
+ goto error;
+ }
+
+ operation->data.inputs.password = mbedtls_calloc(1, slot->key.bytes);
+ if (operation->data.inputs.password == NULL) {
+ return PSA_ERROR_INSUFFICIENT_MEMORY;
+ }
+
+ memcpy(operation->data.inputs.password, slot->key.data, slot->key.bytes);
+ operation->data.inputs.password_len = slot->key.bytes;
+ operation->data.inputs.key_lifetime = attributes.core.lifetime;
+error:
unlock_status = psa_unlock_key_slot(slot);
return (status == PSA_SUCCESS) ? unlock_status : status;
@@ -7219,16 +7265,21 @@
const uint8_t *user_id,
size_t user_id_len)
{
- if (operation->id == 0) {
+ (void) user_id;
+
+ if (operation->stage != PSA_PAKE_OPERATION_STAGE_COLLECT_INPUTS) {
return PSA_ERROR_BAD_STATE;
}
- if (user_id_len == 0 || user_id == NULL) {
+ if (operation->data.inputs.alg == PSA_ALG_NONE) {
+ return PSA_ERROR_BAD_STATE;
+ }
+
+ if (user_id_len == 0) {
return PSA_ERROR_INVALID_ARGUMENT;
}
- return psa_driver_wrapper_pake_set_user(operation, user_id,
- user_id_len);
+ return PSA_ERROR_NOT_SUPPORTED;
}
psa_status_t psa_pake_set_peer(
@@ -7236,23 +7287,32 @@
const uint8_t *peer_id,
size_t peer_id_len)
{
- if (operation->id == 0) {
+ (void) peer_id;
+
+ if (operation->stage != PSA_PAKE_OPERATION_STAGE_COLLECT_INPUTS) {
return PSA_ERROR_BAD_STATE;
}
- if (peer_id_len == 0 || peer_id == NULL) {
+ if (operation->data.inputs.alg == PSA_ALG_NONE) {
+ return PSA_ERROR_BAD_STATE;
+ }
+
+ if (peer_id_len == 0) {
return PSA_ERROR_INVALID_ARGUMENT;
}
- return psa_driver_wrapper_pake_set_peer(operation, peer_id,
- peer_id_len);
+ return PSA_ERROR_NOT_SUPPORTED;
}
psa_status_t psa_pake_set_role(
psa_pake_operation_t *operation,
psa_pake_role_t role)
{
- if (operation->id == 0) {
+ if (operation->stage != PSA_PAKE_OPERATION_STAGE_COLLECT_INPUTS) {
+ return PSA_ERROR_BAD_STATE;
+ }
+
+ if (operation->data.inputs.alg == PSA_ALG_NONE) {
return PSA_ERROR_BAD_STATE;
}
@@ -7264,7 +7324,9 @@
return PSA_ERROR_INVALID_ARGUMENT;
}
- return psa_driver_wrapper_pake_set_role(operation, role);
+ operation->data.inputs.role = role;
+
+ return PSA_SUCCESS;
}
psa_status_t psa_pake_output(
@@ -7274,11 +7336,34 @@
size_t output_size,
size_t *output_length)
{
+ psa_status_t status = PSA_ERROR_CORRUPTION_DETECTED;
+
+ if (operation->stage == PSA_PAKE_OPERATION_STAGE_COLLECT_INPUTS) {
+ if (operation->data.inputs.alg == PSA_ALG_NONE ||
+ operation->data.inputs.password_len == 0 ||
+ operation->data.inputs.role == PSA_PAKE_ROLE_NONE) {
+ return PSA_ERROR_BAD_STATE;
+ }
+
+ status = psa_driver_wrapper_pake_setup(operation,
+ &operation->data.inputs);
+
+ if (status == PSA_SUCCESS) {
+ operation->stage = PSA_PAKE_OPERATION_STAGE_COMPUTATION;
+ } else {
+ return status;
+ }
+ }
+
+ if (operation->stage != PSA_PAKE_OPERATION_STAGE_COMPUTATION) {
+ return PSA_ERROR_BAD_STATE;
+ }
+
if (operation->id == 0) {
return PSA_ERROR_BAD_STATE;
}
- if (output == NULL || output_size == 0 || output_length == NULL) {
+ if (output == NULL || output_size == 0) {
return PSA_ERROR_INVALID_ARGUMENT;
}
@@ -7292,6 +7377,29 @@
const uint8_t *input,
size_t input_length)
{
+ psa_status_t status = PSA_ERROR_CORRUPTION_DETECTED;
+
+ if (operation->stage == PSA_PAKE_OPERATION_STAGE_COLLECT_INPUTS) {
+ if (operation->data.inputs.alg == PSA_ALG_NONE ||
+ operation->data.inputs.password_len == 0 ||
+ operation->data.inputs.role == PSA_PAKE_ROLE_NONE) {
+ return PSA_ERROR_BAD_STATE;
+ }
+
+ status = psa_driver_wrapper_pake_setup(operation,
+ &operation->data.inputs);
+
+ if (status == PSA_SUCCESS) {
+ operation->stage = PSA_PAKE_OPERATION_STAGE_COMPUTATION;
+ } else {
+ return status;
+ }
+ }
+
+ if (operation->stage != PSA_PAKE_OPERATION_STAGE_COMPUTATION) {
+ return PSA_ERROR_BAD_STATE;
+ }
+
if (operation->id == 0) {
return PSA_ERROR_BAD_STATE;
}
@@ -7341,8 +7449,10 @@
psa_status_t psa_pake_abort(
psa_pake_operation_t *operation)
{
- /* Aborting a non-active operation is allowed */
- if (operation->id == 0) {
+ /* If we are in collecting inputs stage clear inputs. */
+ if (operation->stage == PSA_PAKE_OPERATION_STAGE_COLLECT_INPUTS) {
+ mbedtls_free(operation->data.inputs.password);
+ memset(&operation->data.inputs, 0, sizeof(psa_crypto_driver_pake_inputs_t));
return PSA_SUCCESS;
}