Move JPAKE state machine logic from driver to core
- Add `alg` and `computation_stage` to `psa_pake_operation_s`.
Now that the logic has moved to the core, the core needs to know the `alg`.
`computation_stage` is a structure that provides a union of computation stages for pake algorithms.
- Move the JPAKE operation logic from the driver to the core. This requires changing the driver entry points for the `psa_pake_output`/`psa_pake_input` functions by adding a `computation_stage` parameter. I'm not sure if this solution is correct. The driver can now check the current computation stage and perform some action. JPAKE drivers no longer use the `step` parameter, but I think it needs to stay, as other PAKE algorithms might need it.
- Remove a test that seems redundant, as we can't be sure that the operation is aborted after a failure.
Signed-off-by: Przemek Stekiel <przemyslaw.stekiel@mobica.com>
diff --git a/library/psa_crypto.c b/library/psa_crypto.c
index 273d248..66ecc06 100644
--- a/library/psa_crypto.c
+++ b/library/psa_crypto.c
@@ -7180,11 +7180,14 @@
psa_pake_operation_t *operation,
const psa_pake_cipher_suite_t *cipher_suite)
{
+ psa_jpake_computation_stage_t *computation_stage =
+ &operation->computation_stage.data.jpake_computation_stage;
+
if (operation->stage != PSA_PAKE_OPERATION_STAGE_COLLECT_INPUTS) {
return PSA_ERROR_BAD_STATE;
}
- if (operation->data.inputs.alg != PSA_ALG_NONE) {
+ if (operation->alg != PSA_ALG_NONE) {
return PSA_ERROR_BAD_STATE;
}
@@ -7198,9 +7201,16 @@
memset(&operation->data.inputs, 0, sizeof(operation->data.inputs));
- operation->data.inputs.alg = cipher_suite->algorithm;
+ operation->alg = cipher_suite->algorithm;
operation->data.inputs.cipher_suite = *cipher_suite;
+ if (operation->alg == PSA_ALG_JPAKE) {
+ computation_stage->state = PSA_PAKE_STATE_SETUP;
+ computation_stage->sequence = PSA_PAKE_SEQ_INVALID;
+ computation_stage->input_step = PSA_PAKE_STEP_X1_X2;
+ computation_stage->output_step = PSA_PAKE_STEP_X1_X2;
+ }
+
return PSA_SUCCESS;
}
@@ -7216,7 +7226,7 @@
return PSA_ERROR_BAD_STATE;
}
- if (operation->data.inputs.alg == PSA_ALG_NONE) {
+ if (operation->alg == PSA_ALG_NONE) {
return PSA_ERROR_BAD_STATE;
}
@@ -7241,7 +7251,8 @@
operation->data.inputs.password = mbedtls_calloc(1, slot->key.bytes);
if (operation->data.inputs.password == NULL) {
- return PSA_ERROR_INSUFFICIENT_MEMORY;
+ status = PSA_ERROR_INSUFFICIENT_MEMORY;
+ goto error;
}
memcpy(operation->data.inputs.password, slot->key.data, slot->key.bytes);
@@ -7264,7 +7275,7 @@
return PSA_ERROR_BAD_STATE;
}
- if (operation->data.inputs.alg == PSA_ALG_NONE) {
+ if (operation->alg == PSA_ALG_NONE) {
return PSA_ERROR_BAD_STATE;
}
@@ -7286,7 +7297,7 @@
return PSA_ERROR_BAD_STATE;
}
- if (operation->data.inputs.alg == PSA_ALG_NONE) {
+ if (operation->alg == PSA_ALG_NONE) {
return PSA_ERROR_BAD_STATE;
}
@@ -7305,7 +7316,7 @@
return PSA_ERROR_BAD_STATE;
}
- if (operation->data.inputs.alg == PSA_ALG_NONE) {
+ if (operation->alg == PSA_ALG_NONE) {
return PSA_ERROR_BAD_STATE;
}
@@ -7322,6 +7333,98 @@
return PSA_SUCCESS;
}
+static psa_status_t psa_jpake_output_prologue(
+ psa_pake_operation_t *operation,
+ psa_pake_step_t step)
+{
+ psa_jpake_computation_stage_t *computation_stage =
+ &operation->computation_stage.data.jpake_computation_stage;
+
+ if (computation_stage->state == PSA_PAKE_STATE_INVALID) {
+ return PSA_ERROR_BAD_STATE;
+ }
+
+ if (step != PSA_PAKE_STEP_KEY_SHARE &&
+ step != PSA_PAKE_STEP_ZK_PUBLIC &&
+ step != PSA_PAKE_STEP_ZK_PROOF) {
+ return PSA_ERROR_INVALID_ARGUMENT;
+ }
+
+ if (computation_stage->state != PSA_PAKE_STATE_READY &&
+ computation_stage->state != PSA_PAKE_OUTPUT_X1_X2 &&
+ computation_stage->state != PSA_PAKE_OUTPUT_X2S) {
+ return PSA_ERROR_BAD_STATE;
+ }
+
+ if (computation_stage->state == PSA_PAKE_STATE_READY) {
+ if (step != PSA_PAKE_STEP_KEY_SHARE) {
+ return PSA_ERROR_BAD_STATE;
+ }
+
+ switch (computation_stage->output_step) {
+ case PSA_PAKE_STEP_X1_X2:
+ computation_stage->state = PSA_PAKE_OUTPUT_X1_X2;
+ break;
+ case PSA_PAKE_STEP_X2S:
+ computation_stage->state = PSA_PAKE_OUTPUT_X2S;
+ break;
+ default:
+ return PSA_ERROR_BAD_STATE;
+ }
+
+ computation_stage->sequence = PSA_PAKE_X1_STEP_KEY_SHARE;
+ }
+
+ /* Check if step matches current sequence */
+ switch (computation_stage->sequence) {
+ case PSA_PAKE_X1_STEP_KEY_SHARE:
+ case PSA_PAKE_X2_STEP_KEY_SHARE:
+ if (step != PSA_PAKE_STEP_KEY_SHARE) {
+ return PSA_ERROR_BAD_STATE;
+ }
+ break;
+
+ case PSA_PAKE_X1_STEP_ZK_PUBLIC:
+ case PSA_PAKE_X2_STEP_ZK_PUBLIC:
+ if (step != PSA_PAKE_STEP_ZK_PUBLIC) {
+ return PSA_ERROR_BAD_STATE;
+ }
+ break;
+
+ case PSA_PAKE_X1_STEP_ZK_PROOF:
+ case PSA_PAKE_X2_STEP_ZK_PROOF:
+ if (step != PSA_PAKE_STEP_ZK_PROOF) {
+ return PSA_ERROR_BAD_STATE;
+ }
+ break;
+
+ default:
+ return PSA_ERROR_BAD_STATE;
+ }
+
+ return PSA_SUCCESS;
+}
+
+static psa_status_t psa_jpake_output_epilogue(
+ psa_pake_operation_t *operation)
+{
+ psa_jpake_computation_stage_t *computation_stage =
+ &operation->computation_stage.data.jpake_computation_stage;
+
+ if ((computation_stage->state == PSA_PAKE_OUTPUT_X1_X2 &&
+ computation_stage->sequence == PSA_PAKE_X2_STEP_ZK_PROOF) ||
+ (computation_stage->state == PSA_PAKE_OUTPUT_X2S &&
+ computation_stage->sequence == PSA_PAKE_X1_STEP_ZK_PROOF)) {
+ computation_stage->state = PSA_PAKE_STATE_READY;
+ computation_stage->output_step++;
+ computation_stage->sequence = PSA_PAKE_SEQ_INVALID;
+ } else {
+ computation_stage->sequence++;
+ }
+
+ return PSA_SUCCESS;
+}
+
psa_status_t psa_pake_output(
psa_pake_operation_t *operation,
psa_pake_step_t step,
@@ -7330,9 +7433,11 @@
size_t *output_length)
{
psa_status_t status = PSA_ERROR_CORRUPTION_DETECTED;
+ psa_jpake_computation_stage_t *computation_stage =
+ &operation->computation_stage.data.jpake_computation_stage;
if (operation->stage == PSA_PAKE_OPERATION_STAGE_COLLECT_INPUTS) {
- if (operation->data.inputs.alg == PSA_ALG_NONE ||
+ if (operation->alg == PSA_ALG_NONE ||
operation->data.inputs.password_len == 0 ||
operation->data.inputs.role == PSA_PAKE_ROLE_NONE) {
return PSA_ERROR_BAD_STATE;
@@ -7343,6 +7448,12 @@
if (status == PSA_SUCCESS) {
operation->stage = PSA_PAKE_OPERATION_STAGE_COMPUTATION;
+ if (operation->alg == PSA_ALG_JPAKE) {
+ computation_stage->state = PSA_PAKE_STATE_READY;
+ computation_stage->sequence = PSA_PAKE_SEQ_INVALID;
+ computation_stage->input_step = PSA_PAKE_STEP_X1_X2;
+ computation_stage->output_step = PSA_PAKE_STEP_X1_X2;
+ }
} else {
return status;
}
@@ -7360,10 +7471,140 @@
return PSA_ERROR_INVALID_ARGUMENT;
}
- return psa_driver_wrapper_pake_output(operation, step, output,
- output_size, output_length);
+ switch (operation->alg) {
+ case PSA_ALG_JPAKE:
+ status = psa_jpake_output_prologue(operation, step);
+ if (status != PSA_SUCCESS) {
+ return status;
+ }
+ break;
+ default:
+ return PSA_ERROR_NOT_SUPPORTED;
+ }
+
+ status = psa_driver_wrapper_pake_output(operation, step,
+ &operation->computation_stage,
+ output, output_size, output_length);
+
+ if (status != PSA_SUCCESS) {
+ return status;
+ }
+
+ switch (operation->alg) {
+ case PSA_ALG_JPAKE:
+ status = psa_jpake_output_epilogue(operation);
+ if (status != PSA_SUCCESS) {
+ return status;
+ }
+ break;
+ default:
+ return PSA_ERROR_NOT_SUPPORTED;
+ }
+
+ return status;
}
+static psa_status_t psa_jpake_input_prologue(
+ psa_pake_operation_t *operation,
+ psa_pake_step_t step,
+ size_t input_length)
+{
+ psa_jpake_computation_stage_t *computation_stage =
+ &operation->computation_stage.data.jpake_computation_stage;
+
+ if (computation_stage->state == PSA_PAKE_STATE_INVALID) {
+ return PSA_ERROR_BAD_STATE;
+ }
+
+ if (step != PSA_PAKE_STEP_KEY_SHARE &&
+ step != PSA_PAKE_STEP_ZK_PUBLIC &&
+ step != PSA_PAKE_STEP_ZK_PROOF) {
+ return PSA_ERROR_INVALID_ARGUMENT;
+ }
+
+ const psa_pake_primitive_t prim = PSA_PAKE_PRIMITIVE(
+ PSA_PAKE_PRIMITIVE_TYPE_ECC, PSA_ECC_FAMILY_SECP_R1, 256);
+ if (input_length > (size_t) PSA_PAKE_INPUT_SIZE(PSA_ALG_JPAKE, prim, step)) {
+ return PSA_ERROR_INVALID_ARGUMENT;
+ }
+
+ if (computation_stage->state != PSA_PAKE_STATE_READY &&
+ computation_stage->state != PSA_PAKE_INPUT_X1_X2 &&
+ computation_stage->state != PSA_PAKE_INPUT_X4S) {
+ return PSA_ERROR_BAD_STATE;
+ }
+
+ if (computation_stage->state == PSA_PAKE_STATE_READY) {
+ if (step != PSA_PAKE_STEP_KEY_SHARE) {
+ return PSA_ERROR_BAD_STATE;
+ }
+
+ switch (computation_stage->input_step) {
+ case PSA_PAKE_STEP_X1_X2:
+ computation_stage->state = PSA_PAKE_INPUT_X1_X2;
+ break;
+ case PSA_PAKE_STEP_X2S:
+ computation_stage->state = PSA_PAKE_INPUT_X4S;
+ break;
+ default:
+ return PSA_ERROR_BAD_STATE;
+ }
+
+ computation_stage->sequence = PSA_PAKE_X1_STEP_KEY_SHARE;
+ }
+
+ /* Check if step matches current sequence */
+ switch (computation_stage->sequence) {
+ case PSA_PAKE_X1_STEP_KEY_SHARE:
+ case PSA_PAKE_X2_STEP_KEY_SHARE:
+ if (step != PSA_PAKE_STEP_KEY_SHARE) {
+ return PSA_ERROR_BAD_STATE;
+ }
+ break;
+
+ case PSA_PAKE_X1_STEP_ZK_PUBLIC:
+ case PSA_PAKE_X2_STEP_ZK_PUBLIC:
+ if (step != PSA_PAKE_STEP_ZK_PUBLIC) {
+ return PSA_ERROR_BAD_STATE;
+ }
+ break;
+
+ case PSA_PAKE_X1_STEP_ZK_PROOF:
+ case PSA_PAKE_X2_STEP_ZK_PROOF:
+ if (step != PSA_PAKE_STEP_ZK_PROOF) {
+ return PSA_ERROR_BAD_STATE;
+ }
+ break;
+
+ default:
+ return PSA_ERROR_BAD_STATE;
+ }
+
+ return PSA_SUCCESS;
+}
+
+
+static psa_status_t psa_jpake_input_epilogue(
+ psa_pake_operation_t *operation)
+{
+ psa_jpake_computation_stage_t *computation_stage =
+ &operation->computation_stage.data.jpake_computation_stage;
+
+ if ((computation_stage->state == PSA_PAKE_INPUT_X1_X2 &&
+ computation_stage->sequence == PSA_PAKE_X2_STEP_ZK_PROOF) ||
+ (computation_stage->state == PSA_PAKE_INPUT_X4S &&
+ computation_stage->sequence == PSA_PAKE_X1_STEP_ZK_PROOF)) {
+ computation_stage->state = PSA_PAKE_STATE_READY;
+ computation_stage->input_step++;
+ computation_stage->sequence = PSA_PAKE_SEQ_INVALID;
+ } else {
+ computation_stage->sequence++;
+ }
+
+ return PSA_SUCCESS;
+}
+
+
psa_status_t psa_pake_input(
psa_pake_operation_t *operation,
psa_pake_step_t step,
@@ -7371,9 +7612,11 @@
size_t input_length)
{
psa_status_t status = PSA_ERROR_CORRUPTION_DETECTED;
+ psa_jpake_computation_stage_t *computation_stage =
+ &operation->computation_stage.data.jpake_computation_stage;
if (operation->stage == PSA_PAKE_OPERATION_STAGE_COLLECT_INPUTS) {
- if (operation->data.inputs.alg == PSA_ALG_NONE ||
+ if (operation->alg == PSA_ALG_NONE ||
operation->data.inputs.password_len == 0 ||
operation->data.inputs.role == PSA_PAKE_ROLE_NONE) {
return PSA_ERROR_BAD_STATE;
@@ -7384,6 +7627,12 @@
if (status == PSA_SUCCESS) {
operation->stage = PSA_PAKE_OPERATION_STAGE_COMPUTATION;
+ if (operation->alg == PSA_ALG_JPAKE) {
+ computation_stage->state = PSA_PAKE_STATE_READY;
+ computation_stage->sequence = PSA_PAKE_SEQ_INVALID;
+ computation_stage->input_step = PSA_PAKE_STEP_X1_X2;
+ computation_stage->output_step = PSA_PAKE_STEP_X1_X2;
+ }
} else {
return status;
}
@@ -7401,8 +7650,37 @@
return PSA_ERROR_INVALID_ARGUMENT;
}
- return psa_driver_wrapper_pake_input(operation, step, input,
- input_length);
+ switch (operation->alg) {
+ case PSA_ALG_JPAKE:
+ status = psa_jpake_input_prologue(operation, step, input_length);
+ if (status != PSA_SUCCESS) {
+ return status;
+ }
+ break;
+ default:
+ return PSA_ERROR_NOT_SUPPORTED;
+ }
+
+ status = psa_driver_wrapper_pake_input(operation, step,
+ &operation->computation_stage,
+ input, input_length);
+
+ if (status != PSA_SUCCESS) {
+ return status;
+ }
+
+ switch (operation->alg) {
+ case PSA_ALG_JPAKE:
+ status = psa_jpake_input_epilogue(operation);
+ if (status != PSA_SUCCESS) {
+ return status;
+ }
+ break;
+ default:
+ return PSA_ERROR_NOT_SUPPORTED;
+ }
+
+ return status;
}
psa_status_t psa_pake_get_implicit_key(
@@ -7412,11 +7690,20 @@
psa_status_t status = MBEDTLS_ERR_ERROR_CORRUPTION_DETECTED;
uint8_t shared_key[MBEDTLS_PSA_PAKE_BUFFER_SIZE];
size_t shared_key_len = 0;
+ psa_jpake_computation_stage_t *computation_stage =
+ &operation->computation_stage.data.jpake_computation_stage;
if (operation->id == 0) {
return PSA_ERROR_BAD_STATE;
}
+ if (operation->alg == PSA_ALG_JPAKE) {
+ if (computation_stage->input_step != PSA_PAKE_STEP_DERIVE ||
+ computation_stage->output_step != PSA_PAKE_STEP_DERIVE) {
+ return PSA_ERROR_BAD_STATE;
+ }
+ }
+
status = psa_driver_wrapper_pake_get_implicit_key(operation,
shared_key,
&shared_key_len);
@@ -7436,18 +7723,29 @@
mbedtls_platform_zeroize(shared_key, MBEDTLS_PSA_PAKE_BUFFER_SIZE);
+ psa_pake_abort(operation);
+
return status;
}
psa_status_t psa_pake_abort(
psa_pake_operation_t *operation)
{
+ psa_jpake_computation_stage_t *computation_stage =
+ &operation->computation_stage.data.jpake_computation_stage;
+
/* If we are in collecting inputs stage clear inputs. */
if (operation->stage == PSA_PAKE_OPERATION_STAGE_COLLECT_INPUTS) {
mbedtls_free(operation->data.inputs.password);
memset(&operation->data.inputs, 0, sizeof(psa_crypto_driver_pake_inputs_t));
return PSA_SUCCESS;
}
+ if (operation->alg == PSA_ALG_JPAKE) {
+ computation_stage->input_step = PSA_PAKE_STEP_INVALID;
+ computation_stage->output_step = PSA_PAKE_STEP_INVALID;
+ computation_stage->state = PSA_PAKE_STATE_INVALID;
+ computation_stage->sequence = PSA_PAKE_SEQ_INVALID;
+ }
return psa_driver_wrapper_pake_abort(operation);
}