Move JPAKE state machine logic from driver to core
- Add `alg` and `computation_stage` to `psa_pake_operation_s`.
Now that the logic has moved to the core, information about `alg` is required there.
`computation_stage` is a structure that provides a union of computation stages for PAKE algorithms.
- Move the JPAKE operation logic from the driver to the core. This requires changing the driver entry points for the `psa_pake_output`/`psa_pake_input` functions by adding a `computation_stage` parameter. I'm not sure if this solution is correct. The driver can now check the current computation stage and act accordingly. For JPAKE drivers the `step` parameter is no longer used, but I think it needs to stay as it might be needed for other PAKE algorithms.
- Remove a test that seems to be redundant, as we can't be sure that the operation is aborted after a failure.
Signed-off-by: Przemek Stekiel <przemyslaw.stekiel@mobica.com>
diff --git a/library/psa_crypto_pake.c b/library/psa_crypto_pake.c
index 3a710dc..3d5b57d 100644
--- a/library/psa_crypto_pake.c
+++ b/library/psa_crypto_pake.c
@@ -79,23 +79,6 @@
* psa_pake_abort()
*/
-enum psa_pake_step {
- PSA_PAKE_STEP_INVALID = 0,
- PSA_PAKE_STEP_X1_X2 = 1,
- PSA_PAKE_STEP_X2S = 2,
- PSA_PAKE_STEP_DERIVE = 3,
-};
-
-enum psa_pake_state {
- PSA_PAKE_STATE_INVALID = 0,
- PSA_PAKE_STATE_SETUP = 1,
- PSA_PAKE_STATE_READY = 2,
- PSA_PAKE_OUTPUT_X1_X2 = 3,
- PSA_PAKE_OUTPUT_X2S = 4,
- PSA_PAKE_INPUT_X1_X2 = 5,
- PSA_PAKE_INPUT_X4S = 6,
-};
-
/*
* The first PAKE step shares the same sequences of the second PAKE step
* but with a second set of KEY_SHARE/ZK_PUBLIC/ZK_PROOF outputs/inputs.
@@ -157,16 +140,6 @@
* psa_pake_get_implicit_key()
* => Input & Output Step = PSA_PAKE_STEP_INVALID
*/
-enum psa_pake_sequence {
- PSA_PAKE_SEQ_INVALID = 0,
- PSA_PAKE_X1_STEP_KEY_SHARE = 1, /* also X2S & X4S KEY_SHARE */
- PSA_PAKE_X1_STEP_ZK_PUBLIC = 2, /* also X2S & X4S ZK_PUBLIC */
- PSA_PAKE_X1_STEP_ZK_PROOF = 3, /* also X2S & X4S ZK_PROOF */
- PSA_PAKE_X2_STEP_KEY_SHARE = 4,
- PSA_PAKE_X2_STEP_ZK_PUBLIC = 5,
- PSA_PAKE_X2_STEP_ZK_PROOF = 6,
- PSA_PAKE_SEQ_END = 7,
-};
#if defined(MBEDTLS_PSA_BUILTIN_ALG_JPAKE)
static psa_status_t mbedtls_ecjpake_to_psa_error(int ret)
@@ -190,65 +163,6 @@
}
#endif
-#if defined(MBEDTLS_PSA_BUILTIN_PAKE)
-psa_status_t mbedtls_psa_pake_setup(mbedtls_psa_pake_operation_t *operation,
- const psa_crypto_driver_pake_inputs_t *inputs)
-{
- psa_status_t status = PSA_ERROR_CORRUPTION_DETECTED;
-
- uint8_t *password = inputs->password;
- size_t password_len = inputs->password_len;
- psa_pake_role_t role = inputs->role;
- psa_pake_cipher_suite_t cipher_suite = inputs->cipher_suite;
-
- memset(operation, 0, sizeof(mbedtls_psa_pake_operation_t));
-
-#if defined(MBEDTLS_PSA_BUILTIN_ALG_JPAKE)
- if (cipher_suite.algorithm == PSA_ALG_JPAKE) {
- if (cipher_suite.type != PSA_PAKE_PRIMITIVE_TYPE_ECC ||
- cipher_suite.family != PSA_ECC_FAMILY_SECP_R1 ||
- cipher_suite.bits != 256 ||
- cipher_suite.hash != PSA_ALG_SHA_256) {
- status = PSA_ERROR_NOT_SUPPORTED;
- goto error;
- }
-
- if (role != PSA_PAKE_ROLE_CLIENT &&
- role != PSA_PAKE_ROLE_SERVER) {
- status = PSA_ERROR_NOT_SUPPORTED;
- goto error;
- }
-
- mbedtls_ecjpake_init(&operation->ctx.pake);
-
- operation->state = PSA_PAKE_STATE_SETUP;
- operation->sequence = PSA_PAKE_SEQ_INVALID;
- operation->input_step = PSA_PAKE_STEP_X1_X2;
- operation->output_step = PSA_PAKE_STEP_X1_X2;
- operation->password_len = password_len;
- operation->password = password;
- operation->role = role;
- operation->alg = cipher_suite.algorithm;
-
- mbedtls_platform_zeroize(operation->buffer, MBEDTLS_PSA_PAKE_BUFFER_SIZE);
- operation->buffer_length = 0;
- operation->buffer_offset = 0;
-
- return PSA_SUCCESS;
- } else
-#else
- (void) operation;
- (void) inputs;
-#endif
- { status = PSA_ERROR_NOT_SUPPORTED; }
-
-error:
- mbedtls_free(password);
- mbedtls_psa_pake_abort(operation);
- return status;
-}
-
-
#if defined(MBEDTLS_PSA_BUILTIN_ALG_JPAKE)
static psa_status_t psa_pake_ecjpake_setup(mbedtls_psa_pake_operation_t *operation)
{
@@ -283,31 +197,84 @@
return mbedtls_ecjpake_to_psa_error(ret);
}
- operation->state = PSA_PAKE_STATE_READY;
-
return PSA_SUCCESS;
}
+
+psa_status_t mbedtls_psa_pake_setup(mbedtls_psa_pake_operation_t *operation,
+ const psa_crypto_driver_pake_inputs_t *inputs)
+{
+ psa_status_t status = PSA_ERROR_CORRUPTION_DETECTED;
+
+ uint8_t *password = inputs->password;
+ size_t password_len = inputs->password_len;
+ psa_pake_role_t role = inputs->role;
+ psa_pake_cipher_suite_t cipher_suite = inputs->cipher_suite;
+
+ memset(operation, 0, sizeof(mbedtls_psa_pake_operation_t));
+
+#if defined(MBEDTLS_PSA_BUILTIN_ALG_JPAKE)
+ if (cipher_suite.algorithm == PSA_ALG_JPAKE) {
+ if (cipher_suite.type != PSA_PAKE_PRIMITIVE_TYPE_ECC ||
+ cipher_suite.family != PSA_ECC_FAMILY_SECP_R1 ||
+ cipher_suite.bits != 256 ||
+ cipher_suite.hash != PSA_ALG_SHA_256) {
+ status = PSA_ERROR_NOT_SUPPORTED;
+ goto error;
+ }
+
+ if (role != PSA_PAKE_ROLE_CLIENT &&
+ role != PSA_PAKE_ROLE_SERVER) {
+ status = PSA_ERROR_NOT_SUPPORTED;
+ goto error;
+ }
+
+ mbedtls_ecjpake_init(&operation->ctx.pake);
+
+ operation->password_len = password_len;
+ operation->password = password;
+ operation->role = role;
+ operation->alg = cipher_suite.algorithm;
+
+ mbedtls_platform_zeroize(operation->buffer, MBEDTLS_PSA_PAKE_BUFFER_SIZE);
+ operation->buffer_length = 0;
+ operation->buffer_offset = 0;
+
+ status = psa_pake_ecjpake_setup(operation);
+
+ if (status != PSA_SUCCESS) {
+ goto error;
+ }
+
+ return PSA_SUCCESS;
+ } else
+#else
+ (void) operation;
+ (void) inputs;
#endif
+ { status = PSA_ERROR_NOT_SUPPORTED; }
+
+error:
+ mbedtls_free(password);
+ mbedtls_psa_pake_abort(operation);
+ return status;
+}
static psa_status_t mbedtls_psa_pake_output_internal(
mbedtls_psa_pake_operation_t *operation,
psa_pake_step_t step,
+ const psa_pake_computation_stage_t *computation_stage,
uint8_t *output,
size_t output_size,
size_t *output_length)
{
int ret = MBEDTLS_ERR_ERROR_CORRUPTION_DETECTED;
- psa_status_t status = PSA_ERROR_CORRUPTION_DETECTED;
size_t length;
+ (void) step;
if (operation->alg == PSA_ALG_NONE) {
return PSA_ERROR_BAD_STATE;
}
- if (operation->state == PSA_PAKE_STATE_INVALID) {
- return PSA_ERROR_BAD_STATE;
- }
-
#if defined(MBEDTLS_PSA_BUILTIN_ALG_JPAKE)
/*
* The PSA CRYPTO PAKE and MbedTLS JPAKE API have a different
@@ -324,74 +291,12 @@
* to return the right parts on each step.
*/
if (operation->alg == PSA_ALG_JPAKE) {
- if (step != PSA_PAKE_STEP_KEY_SHARE &&
- step != PSA_PAKE_STEP_ZK_PUBLIC &&
- step != PSA_PAKE_STEP_ZK_PROOF) {
- return PSA_ERROR_INVALID_ARGUMENT;
- }
-
- if (operation->state == PSA_PAKE_STATE_SETUP) {
- status = psa_pake_ecjpake_setup(operation);
- if (status != PSA_SUCCESS) {
- return status;
- }
- }
-
- if (operation->state != PSA_PAKE_STATE_READY &&
- operation->state != PSA_PAKE_OUTPUT_X1_X2 &&
- operation->state != PSA_PAKE_OUTPUT_X2S) {
- return PSA_ERROR_BAD_STATE;
- }
-
- if (operation->state == PSA_PAKE_STATE_READY) {
- if (step != PSA_PAKE_STEP_KEY_SHARE) {
- return PSA_ERROR_BAD_STATE;
- }
-
- switch (operation->output_step) {
- case PSA_PAKE_STEP_X1_X2:
- operation->state = PSA_PAKE_OUTPUT_X1_X2;
- break;
- case PSA_PAKE_STEP_X2S:
- operation->state = PSA_PAKE_OUTPUT_X2S;
- break;
- default:
- return PSA_ERROR_BAD_STATE;
- }
-
- operation->sequence = PSA_PAKE_X1_STEP_KEY_SHARE;
- }
-
- /* Check if step matches current sequence */
- switch (operation->sequence) {
- case PSA_PAKE_X1_STEP_KEY_SHARE:
- case PSA_PAKE_X2_STEP_KEY_SHARE:
- if (step != PSA_PAKE_STEP_KEY_SHARE) {
- return PSA_ERROR_BAD_STATE;
- }
- break;
-
- case PSA_PAKE_X1_STEP_ZK_PUBLIC:
- case PSA_PAKE_X2_STEP_ZK_PUBLIC:
- if (step != PSA_PAKE_STEP_ZK_PUBLIC) {
- return PSA_ERROR_BAD_STATE;
- }
- break;
-
- case PSA_PAKE_X1_STEP_ZK_PROOF:
- case PSA_PAKE_X2_STEP_ZK_PROOF:
- if (step != PSA_PAKE_STEP_ZK_PROOF) {
- return PSA_ERROR_BAD_STATE;
- }
- break;
-
- default:
- return PSA_ERROR_BAD_STATE;
- }
+ const psa_jpake_computation_stage_t *jpake_computation_stage =
+ &computation_stage->data.jpake_computation_stage;
/* Initialize & write round on KEY_SHARE sequences */
- if (operation->state == PSA_PAKE_OUTPUT_X1_X2 &&
- operation->sequence == PSA_PAKE_X1_STEP_KEY_SHARE) {
+ if (jpake_computation_stage->state == PSA_PAKE_OUTPUT_X1_X2 &&
+ jpake_computation_stage->sequence == PSA_PAKE_X1_STEP_KEY_SHARE) {
ret = mbedtls_ecjpake_write_round_one(&operation->ctx.pake,
operation->buffer,
MBEDTLS_PSA_PAKE_BUFFER_SIZE,
@@ -403,8 +308,8 @@
}
operation->buffer_offset = 0;
- } else if (operation->state == PSA_PAKE_OUTPUT_X2S &&
- operation->sequence == PSA_PAKE_X1_STEP_KEY_SHARE) {
+ } else if (jpake_computation_stage->state == PSA_PAKE_OUTPUT_X2S &&
+ jpake_computation_stage->sequence == PSA_PAKE_X1_STEP_KEY_SHARE) {
ret = mbedtls_ecjpake_write_round_two(&operation->ctx.pake,
operation->buffer,
MBEDTLS_PSA_PAKE_BUFFER_SIZE,
@@ -429,8 +334,8 @@
* output with a length byte, even less a curve identifier, as that
* information is already available.
*/
- if (operation->state == PSA_PAKE_OUTPUT_X2S &&
- operation->sequence == PSA_PAKE_X1_STEP_KEY_SHARE &&
+ if (jpake_computation_stage->state == PSA_PAKE_OUTPUT_X2S &&
+ jpake_computation_stage->sequence == PSA_PAKE_X1_STEP_KEY_SHARE &&
operation->role == PSA_PAKE_ROLE_SERVER) {
/* Skip ECParameters, with is 3 bytes (RFC 8422) */
operation->buffer_offset += 3;
@@ -456,25 +361,20 @@
operation->buffer_offset += length;
/* Reset buffer after ZK_PROOF sequence */
- if ((operation->state == PSA_PAKE_OUTPUT_X1_X2 &&
- operation->sequence == PSA_PAKE_X2_STEP_ZK_PROOF) ||
- (operation->state == PSA_PAKE_OUTPUT_X2S &&
- operation->sequence == PSA_PAKE_X1_STEP_ZK_PROOF)) {
+ if ((jpake_computation_stage->state == PSA_PAKE_OUTPUT_X1_X2 &&
+ jpake_computation_stage->sequence == PSA_PAKE_X2_STEP_ZK_PROOF) ||
+ (jpake_computation_stage->state == PSA_PAKE_OUTPUT_X2S &&
+ jpake_computation_stage->sequence == PSA_PAKE_X1_STEP_ZK_PROOF)) {
mbedtls_platform_zeroize(operation->buffer, MBEDTLS_PSA_PAKE_BUFFER_SIZE);
operation->buffer_length = 0;
operation->buffer_offset = 0;
-
- operation->state = PSA_PAKE_STATE_READY;
- operation->output_step++;
- operation->sequence = PSA_PAKE_SEQ_INVALID;
- } else {
- operation->sequence++;
}
return PSA_SUCCESS;
} else
#else
(void) step;
+ (void) computation_stage;
(void) output;
(void) output_size;
(void) output_length;
@@ -484,12 +384,13 @@
psa_status_t mbedtls_psa_pake_output(mbedtls_psa_pake_operation_t *operation,
psa_pake_step_t step,
+ const psa_pake_computation_stage_t *computation_stage,
uint8_t *output,
size_t output_size,
size_t *output_length)
{
psa_status_t status = mbedtls_psa_pake_output_internal(
- operation, step, output, output_size, output_length);
+ operation, step, computation_stage, output, output_size, output_length);
if (status != PSA_SUCCESS) {
mbedtls_psa_pake_abort(operation);
@@ -501,20 +402,16 @@
static psa_status_t mbedtls_psa_pake_input_internal(
mbedtls_psa_pake_operation_t *operation,
psa_pake_step_t step,
+ const psa_pake_computation_stage_t *computation_stage,
const uint8_t *input,
size_t input_length)
{
int ret = MBEDTLS_ERR_ERROR_CORRUPTION_DETECTED;
- psa_status_t status = PSA_ERROR_CORRUPTION_DETECTED;
-
+ (void) step;
if (operation->alg == PSA_ALG_NONE) {
return PSA_ERROR_BAD_STATE;
}
- if (operation->state == PSA_PAKE_STATE_INVALID) {
- return PSA_ERROR_BAD_STATE;
- }
-
#if defined(MBEDTLS_PSA_BUILTIN_ALG_JPAKE)
/*
* The PSA CRYPTO PAKE and MbedTLS JPAKE API have a different
@@ -532,77 +429,8 @@
* This causes any input error to be only detected on the last step.
*/
if (operation->alg == PSA_ALG_JPAKE) {
- if (step != PSA_PAKE_STEP_KEY_SHARE &&
- step != PSA_PAKE_STEP_ZK_PUBLIC &&
- step != PSA_PAKE_STEP_ZK_PROOF) {
- return PSA_ERROR_INVALID_ARGUMENT;
- }
-
- const psa_pake_primitive_t prim = PSA_PAKE_PRIMITIVE(
- PSA_PAKE_PRIMITIVE_TYPE_ECC, PSA_ECC_FAMILY_SECP_R1, 256);
- if (input_length > (size_t) PSA_PAKE_INPUT_SIZE(PSA_ALG_JPAKE, prim, step)) {
- return PSA_ERROR_INVALID_ARGUMENT;
- }
-
- if (operation->state == PSA_PAKE_STATE_SETUP) {
- status = psa_pake_ecjpake_setup(operation);
- if (status != PSA_SUCCESS) {
- return status;
- }
- }
-
- if (operation->state != PSA_PAKE_STATE_READY &&
- operation->state != PSA_PAKE_INPUT_X1_X2 &&
- operation->state != PSA_PAKE_INPUT_X4S) {
- return PSA_ERROR_BAD_STATE;
- }
-
- if (operation->state == PSA_PAKE_STATE_READY) {
- if (step != PSA_PAKE_STEP_KEY_SHARE) {
- return PSA_ERROR_BAD_STATE;
- }
-
- switch (operation->input_step) {
- case PSA_PAKE_STEP_X1_X2:
- operation->state = PSA_PAKE_INPUT_X1_X2;
- break;
- case PSA_PAKE_STEP_X2S:
- operation->state = PSA_PAKE_INPUT_X4S;
- break;
- default:
- return PSA_ERROR_BAD_STATE;
- }
-
- operation->sequence = PSA_PAKE_X1_STEP_KEY_SHARE;
- }
-
- /* Check if step matches current sequence */
- switch (operation->sequence) {
- case PSA_PAKE_X1_STEP_KEY_SHARE:
- case PSA_PAKE_X2_STEP_KEY_SHARE:
- if (step != PSA_PAKE_STEP_KEY_SHARE) {
- return PSA_ERROR_BAD_STATE;
- }
- break;
-
- case PSA_PAKE_X1_STEP_ZK_PUBLIC:
- case PSA_PAKE_X2_STEP_ZK_PUBLIC:
- if (step != PSA_PAKE_STEP_ZK_PUBLIC) {
- return PSA_ERROR_BAD_STATE;
- }
- break;
-
- case PSA_PAKE_X1_STEP_ZK_PROOF:
- case PSA_PAKE_X2_STEP_ZK_PROOF:
- if (step != PSA_PAKE_STEP_ZK_PROOF) {
- return PSA_ERROR_BAD_STATE;
- }
- break;
-
- default:
- return PSA_ERROR_BAD_STATE;
- }
-
+ const psa_jpake_computation_stage_t *jpake_computation_stage =
+ &computation_stage->data.jpake_computation_stage;
/*
* Copy input to local buffer and format it as the Mbed TLS API
* expects, i.e. as defined by draft-cragie-tls-ecjpake-01 section 7.
@@ -612,8 +440,8 @@
* ECParameters structure - which means we have to prepend that when
* we're a client.
*/
- if (operation->state == PSA_PAKE_INPUT_X4S &&
- operation->sequence == PSA_PAKE_X1_STEP_KEY_SHARE &&
+ if (jpake_computation_stage->state == PSA_PAKE_INPUT_X4S &&
+ jpake_computation_stage->sequence == PSA_PAKE_X1_STEP_KEY_SHARE &&
operation->role == PSA_PAKE_ROLE_CLIENT) {
/* We only support secp256r1. */
/* This is the ECParameters structure defined by RFC 8422. */
@@ -636,8 +464,8 @@
operation->buffer_length += input_length;
/* Load buffer at each last round ZK_PROOF */
- if (operation->state == PSA_PAKE_INPUT_X1_X2 &&
- operation->sequence == PSA_PAKE_X2_STEP_ZK_PROOF) {
+ if (jpake_computation_stage->state == PSA_PAKE_INPUT_X1_X2 &&
+ jpake_computation_stage->sequence == PSA_PAKE_X2_STEP_ZK_PROOF) {
ret = mbedtls_ecjpake_read_round_one(&operation->ctx.pake,
operation->buffer,
operation->buffer_length);
@@ -648,8 +476,8 @@
if (ret != 0) {
return mbedtls_ecjpake_to_psa_error(ret);
}
- } else if (operation->state == PSA_PAKE_INPUT_X4S &&
- operation->sequence == PSA_PAKE_X1_STEP_ZK_PROOF) {
+ } else if (jpake_computation_stage->state == PSA_PAKE_INPUT_X4S &&
+ jpake_computation_stage->sequence == PSA_PAKE_X1_STEP_ZK_PROOF) {
ret = mbedtls_ecjpake_read_round_two(&operation->ctx.pake,
operation->buffer,
operation->buffer_length);
@@ -662,21 +490,11 @@
}
}
- if ((operation->state == PSA_PAKE_INPUT_X1_X2 &&
- operation->sequence == PSA_PAKE_X2_STEP_ZK_PROOF) ||
- (operation->state == PSA_PAKE_INPUT_X4S &&
- operation->sequence == PSA_PAKE_X1_STEP_ZK_PROOF)) {
- operation->state = PSA_PAKE_STATE_READY;
- operation->input_step++;
- operation->sequence = PSA_PAKE_SEQ_INVALID;
- } else {
- operation->sequence++;
- }
-
return PSA_SUCCESS;
} else
#else
(void) step;
+ (void) computation_stage;
(void) input;
(void) input_length;
#endif
@@ -685,11 +503,12 @@
psa_status_t mbedtls_psa_pake_input(mbedtls_psa_pake_operation_t *operation,
psa_pake_step_t step,
+ const psa_pake_computation_stage_t *computation_stage,
const uint8_t *input,
size_t input_length)
{
psa_status_t status = mbedtls_psa_pake_input_internal(
- operation, step, input, input_length);
+ operation, step, computation_stage, input, input_length);
if (status != PSA_SUCCESS) {
mbedtls_psa_pake_abort(operation);
@@ -703,18 +522,11 @@
uint8_t *output, size_t *output_size)
{
int ret = MBEDTLS_ERR_ERROR_CORRUPTION_DETECTED;
- psa_status_t status = PSA_ERROR_CORRUPTION_DETECTED;
if (operation->alg == PSA_ALG_NONE) {
return PSA_ERROR_BAD_STATE;
}
- if (operation->input_step != PSA_PAKE_STEP_DERIVE ||
- operation->output_step != PSA_PAKE_STEP_DERIVE) {
- status = PSA_ERROR_BAD_STATE;
- goto error;
- }
-
#if defined(MBEDTLS_PSA_BUILTIN_ALG_JPAKE)
if (operation->alg == PSA_ALG_JPAKE) {
ret = mbedtls_ecjpake_write_shared_key(&operation->ctx.pake,
@@ -740,12 +552,7 @@
#else
(void) output;
#endif
- { status = PSA_ERROR_NOT_SUPPORTED; }
-
-error:
- mbedtls_psa_pake_abort(operation);
-
- return status;
+ { return PSA_ERROR_NOT_SUPPORTED; }
}
psa_status_t mbedtls_psa_pake_abort(mbedtls_psa_pake_operation_t *operation)
@@ -757,8 +564,6 @@
#if defined(MBEDTLS_PSA_BUILTIN_ALG_JPAKE)
if (operation->alg == PSA_ALG_JPAKE) {
- operation->input_step = PSA_PAKE_STEP_INVALID;
- operation->output_step = PSA_PAKE_STEP_INVALID;
if (operation->password_len > 0) {
mbedtls_platform_zeroize(operation->password, operation->password_len);
}
@@ -774,8 +579,6 @@
#endif
operation->alg = PSA_ALG_NONE;
- operation->state = PSA_PAKE_STATE_INVALID;
- operation->sequence = PSA_PAKE_SEQ_INVALID;
return PSA_SUCCESS;
}