ASN.1: test that we can parse what we can write

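In the asn1_write tests, when there is a parsing function corresponding to
the write function, call it on the output and check that it parses back the
value that was written. To make this possible, free the previous output
buffer at the start of each step instead of at the end of the step that
produced it, so the written data is still available for parsing.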

Signed-off-by: Gilles Peskine <Gilles.Peskine@arm.com>
diff --git a/tests/suites/test_suite_asn1write.function b/tests/suites/test_suite_asn1write.function
index 2b5697a..daa3830 100644
--- a/tests/suites/test_suite_asn1write.function
+++ b/tests/suites/test_suite_asn1write.function
@@ -16,6 +16,8 @@
 int generic_write_start_step( generic_write_data_t *data )
 {
     mbedtls_test_set_step( data->size );
+    mbedtls_free( data->output );
+    data->output = NULL;
     ASSERT_ALLOC( data->output, data->size == 0 ? 1 : data->size );
     data->end = data->output + data->size;
     data->p = data->end;
@@ -45,8 +47,6 @@
     ok = 1;
 
 exit:
-    mbedtls_free( data->output );
-    data->output = NULL;
     return( ok );
 }
 
@@ -70,6 +70,7 @@
         ret = mbedtls_asn1_write_null( &data.p, data.start );
         if( ! generic_write_finish_step( &data, expected, ret ) )
             goto exit;
+        /* There's no parsing function for NULL. */
     }
 
 exit:
@@ -90,6 +91,14 @@
         ret = mbedtls_asn1_write_bool( &data.p, data.start, val );
         if( ! generic_write_finish_step( &data, expected, ret ) )
             goto exit;
+#if defined(MBEDTLS_ASN1_PARSE_C)
+        if( ret >= 0 )
+        {
+            int read = 0xdeadbeef;
+            TEST_EQUAL( mbedtls_asn1_get_bool( &data.p, data.end, &read ), 0 );
+            TEST_EQUAL( val, read );
+        }
+#endif /* MBEDTLS_ASN1_PARSE_C */
     }
 
 exit:
@@ -110,6 +119,14 @@
         ret = mbedtls_asn1_write_int( &data.p, data.start, val );
         if( ! generic_write_finish_step( &data, expected, ret ) )
             goto exit;
+#if defined(MBEDTLS_ASN1_PARSE_C)
+        if( ret >= 0 )
+        {
+            int read = 0xdeadbeef;
+            TEST_EQUAL( mbedtls_asn1_get_int( &data.p, data.end, &read ), 0 );
+            TEST_EQUAL( val, read );
+        }
+#endif /* MBEDTLS_ASN1_PARSE_C */
     }
 
 exit:
@@ -131,6 +148,14 @@
         ret = mbedtls_asn1_write_enum( &data.p, data.start, val );
         if( ! generic_write_finish_step( &data, expected, ret ) )
             goto exit;
+#if defined(MBEDTLS_ASN1_PARSE_C)
+        if( ret >= 0 )
+        {
+            int read = 0xdeadbeef;
+            TEST_EQUAL( mbedtls_asn1_get_enum( &data.p, data.end, &read ), 0 );
+            TEST_EQUAL( val, read );
+        }
+#endif /* MBEDTLS_ASN1_PARSE_C */
     }
 
 exit:
@@ -142,10 +167,11 @@
 void mbedtls_asn1_write_mpi( data_t *val, data_t *expected )
 {
     generic_write_data_t data = { NULL, NULL, NULL, NULL, 0 };
-    mbedtls_mpi mpi;
+    mbedtls_mpi mpi, read;
     int ret;
 
     mbedtls_mpi_init( &mpi );
+    mbedtls_mpi_init( &read );
     TEST_ASSERT( mbedtls_mpi_read_binary( &mpi, val->x, val->len ) == 0 );
 
     for( data.size = 0; data.size <= expected->len + 1; data.size++ )
@@ -155,12 +181,21 @@
         ret = mbedtls_asn1_write_mpi( &data.p, data.start, &mpi );
         if( ! generic_write_finish_step( &data, expected, ret ) )
             goto exit;
+#if defined(MBEDTLS_ASN1_PARSE_C)
+        if( ret >= 0 )
+        {
+            TEST_EQUAL( mbedtls_asn1_get_mpi( &data.p, data.end, &read ), 0 );
+            TEST_EQUAL( 0, mbedtls_mpi_cmp_mpi( &mpi, &read ) );
+        }
+#endif /* MBEDTLS_ASN1_PARSE_C */
+        /* Skip some intermediate lengths, they're boring. */
         if( expected->len > 10 && data.size == 8 )
             data.size = expected->len - 2;
     }
 
 exit:
     mbedtls_mpi_free( &mpi );
+    mbedtls_mpi_free( &read );
     mbedtls_free( data.output );
 }
 /* END_CASE */
@@ -208,6 +243,8 @@
         }
         if( ! generic_write_finish_step( &data, expected, ret ) )
             goto exit;
+        /* There's no parsing function for octet or character strings. */
+        /* Skip some intermediate lengths, they're boring. */
         if( expected->len > 10 && data.size == 8 )
             data.size = expected->len - 2;
     }
@@ -224,6 +261,9 @@
 {
     generic_write_data_t data = { NULL, NULL, NULL, NULL, 0 };
     int ret;
+#if defined(MBEDTLS_ASN1_PARSE_C)
+    unsigned char *buf_complete = NULL;
+#endif /* MBEDTLS_ASN1_PARSE_C */
 
     for( data.size = 0; data.size <= expected->len + 1; data.size++ )
     {
@@ -240,10 +280,69 @@
             ret -= par_len;
         if( ! generic_write_finish_step( &data, expected, ret ) )
             goto exit;
+
+#if defined(MBEDTLS_ASN1_PARSE_C)
+        /* Only do a parse-back test if the parameters aren't too large for
+         * a small-heap environment. The boundary is somewhat arbitrary. */
+        if( ret >= 0 && par_len <= 1234 )
+        {
+            mbedtls_asn1_buf alg = {0, 0, NULL};
+            mbedtls_asn1_buf params = {0, 0, NULL};
+            /* The writing function doesn't write the parameters unless
+             * they're null: it only takes their length as input. But the
+             * parsing function requires the parameters to be present.
+             * Thus make up parameters. */
+            size_t data_len = data.end - data.p;
+            size_t len_complete = data_len + par_len;
+            unsigned char expected_params_tag;
+            size_t expected_params_len;
+            ASSERT_ALLOC( buf_complete, len_complete );
+            unsigned char *end_complete = buf_complete + len_complete;
+            memcpy( buf_complete, data.p, data_len );
+            if( par_len == 0 )
+            {
+                /* mbedtls_asn1_write_algorithm_identifier() wrote a NULL */
+                expected_params_tag = 0x05;
+                expected_params_len = 0;
+            }
+            else if( par_len >= 2 && par_len < 2 + 128 )
+            {
+                /* Write an OCTET STRING with a short length encoding */
+                expected_params_tag = buf_complete[data_len] = 0x04;
+                expected_params_len = par_len - 2;
+                buf_complete[data_len + 1] = (unsigned char) expected_params_len;
+            }
+            else if( par_len >= 4 + 128 && par_len < 3 + 256 * 256 )
+            {
+                /* Write an OCTET STRING with a two-byte length encoding */
+                expected_params_tag = buf_complete[data_len] = 0x04;
+                expected_params_len = par_len - 4;
+                buf_complete[data_len + 1] = 0x82;
+                buf_complete[data_len + 2] = (unsigned char) ( expected_params_len >> 8 );
+                buf_complete[data_len + 3] = (unsigned char) ( expected_params_len );
+            }
+            else
+            {
+                TEST_ASSERT( ! "Bad test data: invalid length of ASN.1 element" );
+            }
+            unsigned char *p = buf_complete;
+            TEST_EQUAL( mbedtls_asn1_get_alg( &p, end_complete,
+                                              &alg, &params ), 0 );
+            TEST_EQUAL( alg.tag, MBEDTLS_ASN1_OID );
+            ASSERT_COMPARE( alg.p, alg.len, oid->x, oid->len );
+            TEST_EQUAL( params.tag, expected_params_tag );
+            TEST_EQUAL( params.len, expected_params_len );
+            mbedtls_free( buf_complete );
+            buf_complete = NULL;
+        }
+#endif /* MBEDTLS_ASN1_PARSE_C */
     }
 
 exit:
     mbedtls_free( data.output );
+#if defined(MBEDTLS_ASN1_PARSE_C)
+    mbedtls_free( buf_complete );
+#endif /* MBEDTLS_ASN1_PARSE_C */
 }
 /* END_CASE */
 
@@ -308,6 +407,34 @@
                    const unsigned char *buf, size_t bits ) =
         ( is_named ? mbedtls_asn1_write_named_bitstring :
           mbedtls_asn1_write_bitstring );
+#if defined(MBEDTLS_ASN1_PARSE_C)
+    unsigned char *masked_bitstring = NULL;
+#endif /* MBEDTLS_ASN1_PARSE_C */
+
+    /* The API expects `bitstring->x` to contain `bits` bits. */
+    size_t byte_length = ( bits + 7 ) / 8;
+    TEST_ASSERT( bitstring->len >= byte_length );
+
+#if defined(MBEDTLS_ASN1_PARSE_C)
+    ASSERT_ALLOC( masked_bitstring, byte_length );
+    memcpy( masked_bitstring, bitstring->x, byte_length );
+    if( bits % 8 != 0 )
+        masked_bitstring[byte_length - 1] &= ~( 0xff >> ( bits % 8 ) );
+    size_t value_bits = bits;
+    if( is_named )
+    {
+        /* In a named bit string, all trailing 0 bits are removed. */
+        while( byte_length > 0 && masked_bitstring[byte_length - 1] == 0 )
+            --byte_length;
+        value_bits = 8 * byte_length;
+        if( byte_length > 0 )
+        {
+            unsigned char last_byte = masked_bitstring[byte_length - 1];
+            for( unsigned b = 1; b < 0xff && ( last_byte & b ) == 0; b <<= 1 )
+                --value_bits;
+        }
+    }
+#endif /* MBEDTLS_ASN1_PARSE_C */
 
     for( data.size = 0; data.size <= expected->len + 1; data.size++ )
     {
@@ -316,10 +443,24 @@
         ret = ( *func )( &data.p, data.start, bitstring->x, bits );
         if( ! generic_write_finish_step( &data, expected, ret ) )
             goto exit;
+#if defined(MBEDTLS_ASN1_PARSE_C)
+        if( ret >= 0 )
+        {
+            mbedtls_asn1_bitstring read = {0, 0, NULL};
+            TEST_EQUAL( mbedtls_asn1_get_bitstring( &data.p, data.end,
+                                                    &read ), 0 );
+            ASSERT_COMPARE( read.p, read.len,
+                            masked_bitstring, byte_length );
+            TEST_EQUAL( read.unused_bits, 8 * byte_length - value_bits );
+        }
+#endif /* MBEDTLS_ASN1_PARSE_C */
     }
 
 exit:
     mbedtls_free( data.output );
+#if defined(MBEDTLS_ASN1_PARSE_C)
+    mbedtls_free( masked_bitstring );
+#endif /* MBEDTLS_ASN1_PARSE_C */
 }
 /* END_CASE */