Merge pull request #263 from hanno-arm/asn1_traversal_api

Introduce ASN.1 SEQUENCE traversal API
diff --git a/include/mbedtls/asn1.h b/include/mbedtls/asn1.h
index 1c6683f..33b3004 100644
--- a/include/mbedtls/asn1.h
+++ b/include/mbedtls/asn1.h
@@ -90,6 +90,18 @@
 #define MBEDTLS_ASN1_CONSTRUCTED             0x20
 #define MBEDTLS_ASN1_CONTEXT_SPECIFIC        0x80
 
+/* Slightly smaller way to check if tag is a string tag
+ * compared to canonical implementation. */
+#define MBEDTLS_ASN1_IS_STRING_TAG( tag )                                     \
+    ( ( tag ) < 32u && (                                                      \
+        ( ( 1u << ( tag ) ) & ( ( 1u << MBEDTLS_ASN1_BMP_STRING )       |     \
+                                ( 1u << MBEDTLS_ASN1_UTF8_STRING )      |     \
+                                ( 1u << MBEDTLS_ASN1_T61_STRING )       |     \
+                                ( 1u << MBEDTLS_ASN1_IA5_STRING )       |     \
+                                ( 1u << MBEDTLS_ASN1_UNIVERSAL_STRING ) |     \
+                                ( 1u << MBEDTLS_ASN1_PRINTABLE_STRING ) |     \
+                                ( 1u << MBEDTLS_ASN1_BIT_STRING ) ) ) != 0 ) )
+
 /*
  * Bit masks for each of the components of an ASN.1 tag as specified in
  * ITU X.690 (08/2015), section 8.1 "General rules for encoding",
@@ -120,6 +132,10 @@
         ( ( MBEDTLS_OID_SIZE(oid_str) != (oid_buf)->len ) ||                \
           memcmp( (oid_str), (oid_buf)->p, (oid_buf)->len) != 0 )
 
+#define MBEDTLS_OID_CMP_RAW(oid_str, oid_buf, oid_buf_len)              \
+        ( ( MBEDTLS_OID_SIZE(oid_str) != (oid_buf_len) ) ||             \
+          memcmp( (oid_str), (oid_buf), (oid_buf_len) ) != 0 )
+
 #ifdef __cplusplus
 extern "C" {
 #endif
@@ -327,6 +343,9 @@
  * \brief       Parses and splits an ASN.1 "SEQUENCE OF <tag>".
  *              Updates the pointer to immediately behind the full sequence tag.
  *
+ * This function allocates memory for the sequence elements. You can free
+ * the allocated memory with mbedtls_asn1_sequence_free().
+ *
  * \note        On error, this function may return a partial list in \p cur.
  *              You must set `cur->next = NULL` before calling this function!
  *              Otherwise it is impossible to distinguish a previously non-null
@@ -360,14 +379,133 @@
  * \return      0 if successful.
  * \return      #MBEDTLS_ERR_ASN1_LENGTH_MISMATCH if the input contains
  *              extra data after a valid SEQUENCE OF \p tag.
+ * \return      #MBEDTLS_ERR_ASN1_UNEXPECTED_TAG if the input starts with
+ *              an ASN.1 SEQUENCE in which an element has a tag that
+ *              is different from \p tag.
  * \return      #MBEDTLS_ERR_ASN1_ALLOC_FAILED if a memory allocation failed.
  * \return      An ASN.1 error code if the input does not start with
- *              a valid ASN.1 BIT STRING.
+ *              a valid ASN.1 SEQUENCE.
  */
 int mbedtls_asn1_get_sequence_of( unsigned char **p,
                                   const unsigned char *end,
                                   mbedtls_asn1_sequence *cur,
                                   int tag );
+/**
+ * \brief          Free a heap-allocated linked list representation of
+ *                 an ASN.1 sequence, including the first element.
+ *
+ * There are two common ways to manage the memory used for the representation
+ * of a parsed ASN.1 sequence:
+ * - Allocate a head node `mbedtls_asn1_sequence *head` with mbedtls_calloc().
+ *   Pass this node as the `cur` argument to mbedtls_asn1_get_sequence_of().
+ *   When you have finished processing the sequence,
+ *   call mbedtls_asn1_sequence_free() on `head`.
+ * - Allocate a head node `mbedtls_asn1_sequence *head` in any manner,
+ *   for example on the stack. Make sure that `head->next == NULL`.
+ *   Pass `head` as the `cur` argument to mbedtls_asn1_get_sequence_of().
+ *   When you have finished processing the sequence,
+ *   call mbedtls_asn1_sequence_free() on `head->next`,
+ *   then free `head` itself in the appropriate manner.
+ *
+ * \param seq      The address of the first sequence component. This may
+ *                 be \c NULL, in which case this function returns
+ *                 immediately.
+ */
+void mbedtls_asn1_sequence_free( mbedtls_asn1_sequence *seq );
+
+/**
+ * \brief                Traverse an ASN.1 SEQUENCE container and
+ *                       call a callback for each entry.
+ *
+ * This function checks that the input is a SEQUENCE of elements that
+ * each have a "must" tag, and calls a callback function on the elements
+ * that have a "may" tag.
+ *
+ * For example, to validate that the input is a SEQUENCE of `tag1` and call
+ * `cb` on each element, use
+ * ```
+ * mbedtls_asn1_traverse_sequence_of(&p, end, 0xff, tag1, 0, 0, cb, ctx);
+ * ```
+ *
+ * To validate that the input is a SEQUENCE of ANY and call `cb` on
+ * each element, use
+ * ```
+ * mbedtls_asn1_traverse_sequence_of(&p, end, 0, 0, 0, 0, cb, ctx);
+ * ```
+ *
+ * To validate that the input is a SEQUENCE of CHOICE {NULL, OCTET STRING}
+ * and call `cb` on each element that is an OCTET STRING, use
+ * ```
+ * mbedtls_asn1_traverse_sequence_of(&p, end, 0xfe, 0x04, 0xff, 0x04, cb, ctx);
+ * ```
+ *
+ * The callback is called on the elements with a "may" tag from left to
+ * right. If the input is not a valid SEQUENCE of elements with a "must" tag,
+ * the callback is called on the elements up to the leftmost point where
+ * the input is invalid.
+ *
+ * \warning              This function is still experimental and may change
+ *                       at any time.
+ *
+ * \param p              The address of the pointer to the beginning of
+ *                       the ASN.1 SEQUENCE header. This is updated to
+ *                       point to the end of the ASN.1 SEQUENCE container
+ *                       on a successful invocation.
+ * \param end            The end of the ASN.1 SEQUENCE container.
+ * \param tag_must_mask  A mask to be applied to the ASN.1 tags found within
+ *                       the SEQUENCE before comparing to \p tag_must_val.
+ * \param tag_must_val   The required value of each ASN.1 tag found in the
+ *                       SEQUENCE, after masking with \p tag_must_mask.
+ *                       Mismatching tags lead to an error.
+ *                       For example, a value of \c 0 for both \p tag_must_mask
+ *                       and \p tag_must_val means that every tag is allowed,
+ *                       while a value of \c 0xFF for \p tag_must_mask means
+ *                       that \p tag_must_val is the only allowed tag.
+ * \param tag_may_mask   A mask to be applied to the ASN.1 tags found within
+ *                       the SEQUENCE before comparing to \p tag_may_val.
+ * \param tag_may_val    The desired value of each ASN.1 tag found in the
+ *                       SEQUENCE, after masking with \p tag_may_mask.
+ *                       Mismatching tags will be silently ignored.
+ *                       For example, a value of \c 0 for \p tag_may_mask and
+ *                       \p tag_may_val means that any tag will be considered,
+ *                       while a value of \c 0xFF for \p tag_may_mask means
+ *                       that all tags with value different from \p tag_may_val
+ *                       will be ignored.
+ * \param cb             The callback to trigger for each component
+ *                       in the ASN.1 SEQUENCE that matches \p tag_may_val.
+ *                       The callback function is called with the following
+ *                       parameters:
+ *                       - \p ctx.
+ *                       - The tag of the current element.
+ *                       - A pointer to the start of the current element's
+ *                         content inside the input.
+ *                       - The length of the content of the current element.
+ *                       If the callback returns a non-zero value,
+ *                       the function stops immediately,
+ *                       forwarding the callback's return value.
+ * \param ctx            The context to be passed to the callback \p cb.
+ *
+ * \return               \c 0 if the entire ASN.1 SEQUENCE was
+ *                       traversed without parsing or callback errors.
+ * \return               #MBEDTLS_ERR_ASN1_LENGTH_MISMATCH if the input
+ *                       contains extra data after a valid SEQUENCE
+ *                       of elements with an accepted tag.
+ * \return               #MBEDTLS_ERR_ASN1_UNEXPECTED_TAG if the input starts
+ *                       with an ASN.1 SEQUENCE in which an element has a tag
+ *                       that is not accepted.
+ * \return               An ASN.1 error code if the input does not start with
+ *                       a valid ASN.1 SEQUENCE.
+ * \return               A non-zero error code forwarded from the callback
+ *                       \p cb in case the latter returns a non-zero value.
+ */
+int mbedtls_asn1_traverse_sequence_of(
+    unsigned char **p,
+    const unsigned char *end,
+    unsigned char tag_must_mask, unsigned char tag_must_val,
+    unsigned char tag_may_mask, unsigned char tag_may_val,
+    int (*cb)( void *ctx, int tag,
+               unsigned char* start, size_t len ),
+    void *ctx );
 
 #if defined(MBEDTLS_BIGNUM_C)
 /**
diff --git a/library/asn1parse.c b/library/asn1parse.c
index e7e4d13..34c6607 100644
--- a/library/asn1parse.c
+++ b/library/asn1parse.c
@@ -248,6 +248,58 @@
 }
 
 /*
+ * Traverse an ASN.1 "SEQUENCE OF <tag>"
+ * and call a callback for each entry found.
+ */
+int mbedtls_asn1_traverse_sequence_of(
+    unsigned char **p,
+    const unsigned char *end,
+    unsigned char tag_must_mask, unsigned char tag_must_val,
+    unsigned char tag_may_mask, unsigned char tag_may_val,
+    int (*cb)( void *ctx, int tag,
+               unsigned char *start, size_t len ),
+    void *ctx )
+{
+    int ret;
+    size_t len;
+
+    /* Get main sequence tag */
+    if( ( ret = mbedtls_asn1_get_tag( p, end, &len,
+            MBEDTLS_ASN1_CONSTRUCTED | MBEDTLS_ASN1_SEQUENCE ) ) != 0 )
+    {
+        return( ret );
+    }
+
+    if( *p + len != end )
+        return( MBEDTLS_ERR_ASN1_LENGTH_MISMATCH );
+
+    while( *p < end )
+    {
+        unsigned char const tag = *(*p)++;
+
+        if( ( tag & tag_must_mask ) != tag_must_val )
+            return( MBEDTLS_ERR_ASN1_UNEXPECTED_TAG );
+
+        if( ( ret = mbedtls_asn1_get_len( p, end, &len ) ) != 0 )
+            return( ret );
+
+        if( ( tag & tag_may_mask ) == tag_may_val )
+        {
+            if( cb != NULL )
+            {
+                ret = cb( ctx, tag, *p, len );
+                if( ret != 0 )
+                    return( ret );
+            }
+        }
+
+        *p += len;
+    }
+
+    return( 0 );
+}
+
+/*
  * Get a bit string without unused bits
  */
 int mbedtls_asn1_get_bitstring_null( unsigned char **p, const unsigned char *end,
@@ -269,7 +321,51 @@
     return( 0 );
 }
 
+void mbedtls_asn1_sequence_free( mbedtls_asn1_sequence *seq )
+{
+    while( seq != NULL )
+    {
+        mbedtls_asn1_sequence *next = seq->next;
+        mbedtls_platform_zeroize( seq, sizeof( *seq ) );
+        mbedtls_free( seq );
+        seq = next;
+    }
+}
 
+typedef struct
+{
+    int tag;
+    mbedtls_asn1_sequence *cur;
+} asn1_get_sequence_of_cb_ctx_t;
+
+static int asn1_get_sequence_of_cb( void *ctx,
+                                    int tag,
+                                    unsigned char *start,
+                                    size_t len )
+{
+    asn1_get_sequence_of_cb_ctx_t *cb_ctx =
+        (asn1_get_sequence_of_cb_ctx_t *) ctx;
+    mbedtls_asn1_sequence *cur =
+        cb_ctx->cur;
+
+    if( cur->buf.p != NULL )
+    {
+        cur->next =
+            mbedtls_calloc( 1, sizeof( mbedtls_asn1_sequence ) );
+
+        if( cur->next == NULL )
+            return( MBEDTLS_ERR_ASN1_ALLOC_FAILED );
+
+        cur = cur->next;
+    }
+
+    cur->buf.p = start;
+    cur->buf.len = len;
+    cur->buf.tag = tag;
+
+    cb_ctx->cur = cur;
+    return( 0 );
+}
 
 /*
  *  Parses and splits an ASN.1 "SEQUENCE OF <tag>"
@@ -279,49 +375,11 @@
                           mbedtls_asn1_sequence *cur,
                           int tag)
 {
-    int ret = MBEDTLS_ERR_ERROR_CORRUPTION_DETECTED;
-    size_t len;
-    mbedtls_asn1_buf *buf;
-
-    /* Get main sequence tag */
-    if( ( ret = mbedtls_asn1_get_tag( p, end, &len,
-            MBEDTLS_ASN1_CONSTRUCTED | MBEDTLS_ASN1_SEQUENCE ) ) != 0 )
-        return( ret );
-
-    if( *p + len != end )
-        return( MBEDTLS_ERR_ASN1_LENGTH_MISMATCH );
-
-    while( *p < end )
-    {
-        buf = &(cur->buf);
-        buf->tag = **p;
-
-        if( ( ret = mbedtls_asn1_get_tag( p, end, &buf->len, tag ) ) != 0 )
-            return( ret );
-
-        buf->p = *p;
-        *p += buf->len;
-
-        /* Allocate and assign next pointer */
-        if( *p < end )
-        {
-            cur->next = (mbedtls_asn1_sequence*)mbedtls_calloc( 1,
-                                            sizeof( mbedtls_asn1_sequence ) );
-
-            if( cur->next == NULL )
-                return( MBEDTLS_ERR_ASN1_ALLOC_FAILED );
-
-            cur = cur->next;
-        }
-    }
-
-    /* Set final sequence entry's next pointer to NULL */
-    cur->next = NULL;
-
-    if( *p != end )
-        return( MBEDTLS_ERR_ASN1_LENGTH_MISMATCH );
-
-    return( 0 );
+    asn1_get_sequence_of_cb_ctx_t cb_ctx = { tag, cur };
+    memset( cur, 0, sizeof( mbedtls_asn1_sequence ) );
+    return( mbedtls_asn1_traverse_sequence_of(
+                p, end, 0xFF, tag, 0, 0,
+                asn1_get_sequence_of_cb, &cb_ctx ) );
 }
 
 int mbedtls_asn1_get_alg( unsigned char **p,
diff --git a/tests/suites/test_suite_asn1parse.data b/tests/suites/test_suite_asn1parse.data
index e26f93a..6a66ee9 100644
--- a/tests/suites/test_suite_asn1parse.data
+++ b/tests/suites/test_suite_asn1parse.data
@@ -481,6 +481,60 @@
 Not a SEQUENCE (not SEQUENCE)
 get_sequence_of:"3100":0x04:"":MBEDTLS_ERR_ASN1_UNEXPECTED_TAG
 
+Traverse empty SEQUENCE
+traverse_sequence_of:"3000":0:0:0:0:"":0
+
+Traverse empty SEQUENCE plus trailing garbage
+traverse_sequence_of:"30007e":0:0:0:0:"":MBEDTLS_ERR_ASN1_LENGTH_MISMATCH
+
+Traverse SEQUENCE of INTEGER: 1 INTEGER
+traverse_sequence_of:"30050203123456":0xff:0x02:0:0:"4,0x02,3":0
+
+Traverse SEQUENCE of INTEGER: 2 INTEGERs
+traverse_sequence_of:"30080203123456020178":0xff:0x02:0:0:"4,0x02,3,9,0x02,1":0
+
+Traverse SEQUENCE of INTEGER: INTEGER, NULL
+traverse_sequence_of:"300702031234560500":0xff:0x02:0:0:"4,0x02,3":MBEDTLS_ERR_ASN1_UNEXPECTED_TAG
+
+Traverse SEQUENCE of INTEGER: NULL, INTEGER
+traverse_sequence_of:"300705000203123456":0xff:0x02:0:0:"":MBEDTLS_ERR_ASN1_UNEXPECTED_TAG
+
+Traverse SEQUENCE of ANY: NULL, INTEGER
+traverse_sequence_of:"300705000203123456":0:0:0:0:"4,0x05,0,6,0x02,3":0
+
+Traverse SEQUENCE of ANY, skip non-INTEGER: INTEGER, NULL
+traverse_sequence_of:"300702031234560500":0:0:0xff:0x02:"4,0x02,3":0
+
+Traverse SEQUENCE of ANY, skip non-INTEGER: NULL, INTEGER
+traverse_sequence_of:"300705000203123456":0:0:0xff:0x02:"6,0x02,3":0
+
+Traverse SEQUENCE of INTEGER, skip everything
+traverse_sequence_of:"30080203123456020178":0xff:0x02:0:1:"":0
+
+Traverse SEQUENCE of {NULL, OCTET STRING}, skip NULL: OS, NULL
+traverse_sequence_of:"300704031234560500":0xfe:0x04:0xff:0x04:"4,0x04,3":0
+
+Traverse SEQUENCE of {NULL, OCTET STRING}, skip NULL: NULL, OS
+traverse_sequence_of:"300705000403123456":0xfe:0x04:0xff:0x04:"6,0x04,3":0
+
+Traverse SEQUENCE of {NULL, OCTET STRING}, skip everything
+traverse_sequence_of:"300705000403123456":0xfe:0x04:0:1:"":0
+
+Traverse SEQUENCE of INTEGER, stop at 0: NULL
+traverse_sequence_of:"30020500":0xff:0x02:0:0:"":MBEDTLS_ERR_ASN1_UNEXPECTED_TAG
+
+Traverse SEQUENCE of INTEGER, stop at 0: INTEGER
+traverse_sequence_of:"30050203123456":0xff:0x02:0:0:"":RET_TRAVERSE_STOP
+
+Traverse SEQUENCE of INTEGER, stop at 0: INTEGER, NULL
+traverse_sequence_of:"300702031234560500":0xff:0x02:0:0:"":RET_TRAVERSE_STOP
+
+Traverse SEQUENCE of INTEGER, stop at 1: INTEGER, NULL
+traverse_sequence_of:"300702031234560500":0xff:0x02:0:0:"4,0x02,3":MBEDTLS_ERR_ASN1_UNEXPECTED_TAG
+
+Traverse SEQUENCE of INTEGER, stop at 1: INTEGER, INTEGER
+traverse_sequence_of:"30080203123456020178":0xff:0x02:0:0:"4,0x02,3":RET_TRAVERSE_STOP
+
 AlgorithmIdentifier, no params
 get_alg:"300506034f4944":4:3:0:0:0:7:0
 
diff --git a/tests/suites/test_suite_asn1parse.function b/tests/suites/test_suite_asn1parse.function
index f07fd40..3419f03 100644
--- a/tests/suites/test_suite_asn1parse.function
+++ b/tests/suites/test_suite_asn1parse.function
@@ -170,6 +170,53 @@
     return( 0 );
 }
 
+typedef struct
+{
+    const unsigned char *input_start;
+    const char *description;
+} traverse_state_t;
+
+/* Value returned by traverse_callback if description runs out. */
+#define RET_TRAVERSE_STOP 1
+/* Value returned by traverse_callback if description has an invalid format
+ * (see traverse_sequence_of). */
+#define RET_TRAVERSE_ERROR 2
+
+
+static int traverse_callback( void *ctx, int tag,
+                              unsigned char *content, size_t len )
+{
+    traverse_state_t *state = ctx;
+    size_t offset;
+    const char *rest = state->description;
+    unsigned long n;
+
+    TEST_ASSERT( content > state->input_start );
+    offset = content - state->input_start;
+    test_set_step( offset );
+
+    if( *rest == 0 )
+        return( RET_TRAVERSE_STOP );
+    n = strtoul( rest, (char **) &rest, 0 );
+    TEST_EQUAL( n, offset );
+    TEST_EQUAL( *rest, ',' );
+    ++rest;
+    n = strtoul( rest, (char **) &rest, 0 );
+    TEST_EQUAL( n, (unsigned) tag );
+    TEST_EQUAL( *rest, ',' );
+    ++rest;
+    n = strtoul( rest, (char **) &rest, 0 );
+    TEST_EQUAL( n, len );
+    if( *rest == ',' )
+        ++rest;
+
+    state->description = rest;
+    return( 0 );
+
+exit:
+    return( RET_TRAVERSE_ERROR );
+}
+
 /* END_HEADER */
 
 /* BEGIN_DEPENDENCIES
@@ -507,8 +554,15 @@
                       const char *description,
                       int expected_result )
 {
+    /* The description string is a comma-separated list of integers.
+     * For each element in the SEQUENCE in input, description contains
+     * two integers: the offset of the element (offset from the start
+     * of input to the tag of the element) and the length of the
+     * element's contents.
+     * "offset1,length1,..." */
+
     mbedtls_asn1_sequence head = { { 0, 0, NULL }, NULL };
-    mbedtls_asn1_sequence *cur, *next;
+    mbedtls_asn1_sequence *cur;
     unsigned char *p = input->x;
     const char *rest = description;
     unsigned long n;
@@ -549,13 +603,36 @@
     }
 
 exit:
-    cur = head.next;
-    while( cur != NULL )
-    {
-        next = cur->next;
-        mbedtls_free( cur );
-        cur = next;
-    }
+    mbedtls_asn1_sequence_free( head.next );
+}
+/* END_CASE */
+
+/* BEGIN_CASE */
+void traverse_sequence_of( const data_t *input,
+                           int tag_must_mask, int tag_must_val,
+                           int tag_may_mask, int tag_may_val,
+                           const char *description,
+                           int expected_result )
+{
+    /* The description string is a comma-separated list of integers.
+     * For each element in the SEQUENCE in input, description contains
+     * three integers: the offset of the element's content (offset from
+     * the start of input to the content of the element), the element's tag,
+     * and the length of the element's contents.
+     * "offset1,tag1,length1,..." */
+
+    unsigned char *p = input->x;
+    traverse_state_t traverse_state = {input->x, description};
+    int ret;
+
+    ret = mbedtls_asn1_traverse_sequence_of( &p, input->x + input->len,
+                                             (uint8_t) tag_must_mask, (uint8_t) tag_must_val,
+                                             (uint8_t) tag_may_mask, (uint8_t) tag_may_val,
+                                             traverse_callback, &traverse_state );
+    if( ret == RET_TRAVERSE_ERROR )
+        goto exit;
+    TEST_EQUAL( ret, expected_result );
+    TEST_EQUAL( *traverse_state.description, 0 );
 }
 /* END_CASE */