Fix mbedtls_pk_get_bitlen() for RSA with non-byte-aligned sizes
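Add non-regression tests. Update some test functions so that they no longer
assume that byte_length == bit_length / 8: the byte length is the bit length
rounded up to a whole number of bytes, i.e. (bit_length + 7) / 8.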
Signed-off-by: Gilles Peskine <Gilles.Peskine@arm.com>
diff --git a/tests/suites/test_suite_pk.function b/tests/suites/test_suite_pk.function
index 2574307..681de0f 100644
--- a/tests/suites/test_suite_pk.function
+++ b/tests/suites/test_suite_pk.function
@@ -427,7 +427,7 @@
TEST_ASSERT(strcmp(mbedtls_pk_get_name(&pk), name) == 0);
TEST_ASSERT(mbedtls_pk_get_bitlen(&pk) == bitlen);
- TEST_ASSERT(mbedtls_pk_get_len(&pk) == bitlen / 8);
+ TEST_ASSERT(mbedtls_pk_get_len(&pk) == (bitlen + 7) / 8);
if (key_is_rsa) {
TEST_ASSERT(mbedtls_pk_can_do(&pk, MBEDTLS_PK_ECKEY) == 0);
@@ -822,7 +822,7 @@
TEST_ASSERT(mbedtls_pk_setup(&pk, mbedtls_pk_info_from_type(MBEDTLS_PK_RSA)) == 0);
rsa = mbedtls_pk_rsa(pk);
- rsa->len = mod / 8;
+ rsa->len = (mod + 7) / 8;
TEST_ASSERT(mbedtls_test_read_mpi(&rsa->N, input_N) == 0);
TEST_ASSERT(mbedtls_test_read_mpi(&rsa->E, input_E) == 0);
@@ -862,7 +862,7 @@
TEST_ASSERT(mbedtls_pk_setup(&pk, mbedtls_pk_info_from_type(MBEDTLS_PK_RSA)) == 0);
rsa = mbedtls_pk_rsa(pk);
- rsa->len = mod / 8;
+ rsa->len = (mod + 7) / 8;
TEST_ASSERT(mbedtls_test_read_mpi(&rsa->N, input_N) == 0);
TEST_ASSERT(mbedtls_test_read_mpi(&rsa->E, input_E) == 0);
@@ -1143,7 +1143,7 @@
rsa = mbedtls_pk_rsa(pk);
/* load public key */
- rsa->len = mod / 8;
+ rsa->len = (mod + 7) / 8;
TEST_ASSERT(mbedtls_test_read_mpi(&rsa->N, input_N) == 0);
TEST_ASSERT(mbedtls_test_read_mpi(&rsa->E, input_E) == 0);
@@ -1169,9 +1169,12 @@
TEST_ASSERT(mbedtls_test_read_mpi(&P, input_P) == 0);
TEST_ASSERT(mbedtls_test_read_mpi(&Q, input_Q) == 0);
TEST_ASSERT(mbedtls_rsa_import(rsa, &N, &P, &Q, NULL, &E) == 0);
- TEST_ASSERT(mbedtls_rsa_get_len(rsa) == (size_t) (mod / 8));
+ TEST_EQUAL(mbedtls_rsa_get_len(rsa), (mod + 7) / 8);
TEST_ASSERT(mbedtls_rsa_complete(rsa) == 0);
+ TEST_EQUAL(mbedtls_pk_get_len(&pk), (mod + 7) / 8);
+ TEST_EQUAL(mbedtls_pk_get_bitlen(&pk), mod);
+
memset(result, 0, sizeof(result));
rlen = 0;
TEST_ASSERT(mbedtls_pk_decrypt(&pk, output, olen,
@@ -1222,9 +1225,12 @@
TEST_ASSERT(mbedtls_test_read_mpi(&P, input_P) == 0);
TEST_ASSERT(mbedtls_test_read_mpi(&Q, input_Q) == 0);
TEST_ASSERT(mbedtls_rsa_import(rsa, &N, &P, &Q, NULL, &E) == 0);
- TEST_ASSERT(mbedtls_rsa_get_len(rsa) == (size_t) (mod / 8));
+ TEST_EQUAL(mbedtls_rsa_get_len(rsa), (mod + 7) / 8);
TEST_ASSERT(mbedtls_rsa_complete(rsa) == 0);
+ TEST_EQUAL(mbedtls_pk_get_bitlen(&pk), mod);
+ TEST_EQUAL(mbedtls_pk_get_len(&pk), (mod + 7) / 8);
+
/* decryption test */
memset(output, 0, sizeof(output));
olen = 0;
@@ -1278,7 +1284,7 @@
TEST_EQUAL(mbedtls_test_read_mpi(&P, input_P), 0);
TEST_EQUAL(mbedtls_test_read_mpi(&Q, input_Q), 0);
TEST_EQUAL(mbedtls_rsa_import(rsa, &N, &P, &Q, NULL, &E), 0);
- TEST_EQUAL(mbedtls_rsa_get_len(rsa), (size_t) (mod / 8));
+ TEST_EQUAL(mbedtls_rsa_get_len(rsa), (mod + 7) / 8);
TEST_EQUAL(mbedtls_rsa_complete(rsa), 0);
/* Turn PK context into an opaque one. */
@@ -1287,6 +1293,8 @@
PSA_KEY_USAGE_DECRYPT,
PSA_ALG_NONE), 0);
+ TEST_EQUAL(mbedtls_pk_get_bitlen(&pk), mod);
+
/* decryption test */
memset(output, 0, sizeof(output));
olen = 0;