Add ALPN extension to the server side

CustomizedGitHooks: yes
Change-Id: I6fe1516963e7b5727710872ee91fea7fc51d2776
Signed-off-by: XiaokangQian <xiaokang.qian@arm.com>
diff --git a/library/ssl_misc.h b/library/ssl_misc.h
index b1f0c90..20e94dc 100644
--- a/library/ssl_misc.h
+++ b/library/ssl_misc.h
@@ -2280,4 +2280,16 @@
                                            mbedtls_pk_context *own_key,
                                            uint16_t *algorithm );
 
+#if defined(MBEDTLS_SSL_ALPN)
+/* Parse the ALPN extension data in [buf, end); may set ssl->alpn_chosen. */
+int mbedtls_ssl_parse_alpn_ext( mbedtls_ssl_context *ssl,
+                                const unsigned char *buf,
+                                const unsigned char *end );
+/* Write the ALPN extension into [buf, end); *olen receives its length. */
+int mbedtls_ssl_write_alpn_ext( mbedtls_ssl_context *ssl,
+                                unsigned char *buf,
+                                unsigned char *end,
+                                size_t *olen );
+#endif /* MBEDTLS_SSL_ALPN */
+
 #endif /* ssl_misc.h */
diff --git a/library/ssl_tls.c b/library/ssl_tls.c
index 8332461..6fa169f 100644
--- a/library/ssl_tls.c
+++ b/library/ssl_tls.c
@@ -8285,4 +8285,124 @@
 }
 #endif /* MBEDTLS_SSL_SERVER_NAME_INDICATION */
 
+#if defined(MBEDTLS_SSL_ALPN)
+/*
+ * Parse the ALPN extension data in [buf, end) and store in ssl->alpn_chosen
+ * the first entry of ssl->conf->alpn_list that the peer also offers.
+ *
+ * Returns 0 on a match or when no ALPN list is configured,
+ * MBEDTLS_ERR_SSL_ILLEGAL_PARAMETER if the peer sent an empty name and
+ * MBEDTLS_ERR_SSL_NO_APPLICATION_PROTOCOL if nothing matches. Length fields
+ * are bounds-checked with MBEDTLS_SSL_CHK_BUF_READ_PTR.
+ */
+int mbedtls_ssl_parse_alpn_ext( mbedtls_ssl_context *ssl,
+                                const unsigned char *buf,
+                                const unsigned char *end )
+{
+    const unsigned char *p = buf;
+    size_t list_len;
+    const unsigned char *list_end;
+
+    const unsigned char *cur_alpn;
+    size_t cur_alpn_len;
+
+    /* If ALPN not configured, just ignore the extension */
+    if( ssl->conf->alpn_list == NULL )
+        return( 0 );
+
+    /*
+     * opaque ProtocolName<1..2^8-1>;
+     *
+     * struct {
+     *     ProtocolName protocol_name_list<2..2^16-1>
+     * } ProtocolNameList;
+     */
+
+    /* Min length is 2 ( list_len ) + 1 ( name_len ) + 1 ( name ) */
+    MBEDTLS_SSL_CHK_BUF_READ_PTR( p, end, 4 );
+
+    list_len = MBEDTLS_GET_UINT16_BE( p, 0 );
+    p += 2;
+    MBEDTLS_SSL_CHK_BUF_READ_PTR( p, end, list_len );
+
+    /* Honour the declared list length: bytes after the list are not names */
+    list_end = p + list_len;
+
+    /* Validate peer's list (lengths) */
+    for( cur_alpn = p; cur_alpn != list_end; cur_alpn += cur_alpn_len )
+    {
+        cur_alpn_len = *cur_alpn++;
+        MBEDTLS_SSL_CHK_BUF_READ_PTR( cur_alpn, list_end, cur_alpn_len );
+
+        /* Empty names are forbidden (ProtocolName<1..2^8-1>) */
+        if( cur_alpn_len == 0 )
+        {
+            MBEDTLS_SSL_PEND_FATAL_ALERT(
+                    MBEDTLS_SSL_ALERT_MSG_ILLEGAL_PARAMETER,
+                    MBEDTLS_ERR_SSL_ILLEGAL_PARAMETER );
+            return( MBEDTLS_ERR_SSL_ILLEGAL_PARAMETER );
+        }
+    }
+
+    /* Use our order of preference */
+    for( const char **alpn = ssl->conf->alpn_list; *alpn != NULL; alpn++ )
+    {
+        size_t const alpn_len = strlen( *alpn );
+        for( cur_alpn = p; cur_alpn != list_end; cur_alpn += cur_alpn_len )
+        {
+            cur_alpn_len = *cur_alpn++;
+
+            if( cur_alpn_len == alpn_len &&
+                memcmp( cur_alpn, *alpn, alpn_len ) == 0 )
+            {
+                ssl->alpn_chosen = *alpn;
+                return( 0 );
+            }
+        }
+    }
+
+    /* If we get here, no match was found */
+    MBEDTLS_SSL_PEND_FATAL_ALERT(
+            MBEDTLS_SSL_ALERT_MSG_NO_APPLICATION_PROTOCOL,
+            MBEDTLS_ERR_SSL_NO_APPLICATION_PROTOCOL );
+    return( MBEDTLS_ERR_SSL_NO_APPLICATION_PROTOCOL );
+}
+
+/* Write the ALPN extension carrying ssl->alpn_chosen into [buf, end).
+ * *olen receives the number of bytes written, 0 if no protocol was chosen. */
+int mbedtls_ssl_write_alpn_ext( mbedtls_ssl_context *ssl,
+                                unsigned char *buf,
+                                unsigned char *end,
+                                size_t *olen )
+{
+    size_t name_len;
+
+    *olen = 0;
+
+    /* No protocol was selected, so there is nothing to send */
+    if( ssl->alpn_chosen == NULL )
+        return( 0 );
+
+    name_len = strlen( ssl->alpn_chosen );
+    MBEDTLS_SSL_CHK_BUF_PTR( buf, end, 7 + name_len );
+
+    MBEDTLS_SSL_DEBUG_MSG( 3, ( "server side, adding alpn extension" ) );
+    /*
+     * 0 . 1    ext identifier
+     * 2 . 3    ext length
+     * 4 . 5    protocol list length
+     * 6 . 6    protocol name length
+     * 7 . 7+n  protocol name
+     */
+    MBEDTLS_PUT_UINT16_BE( MBEDTLS_TLS_EXT_ALPN, buf, 0 );
+    MBEDTLS_PUT_UINT16_BE( name_len + 3, buf, 2 );
+    MBEDTLS_PUT_UINT16_BE( name_len + 1, buf, 4 );
+    buf[6] = MBEDTLS_BYTE_0( name_len );
+    memcpy( buf + 7, ssl->alpn_chosen, name_len );
+
+    *olen = 7 + name_len;
+    return( 0 );
+}
+#endif /* MBEDTLS_SSL_ALPN */
+
 #endif /* MBEDTLS_SSL_TLS_C */
diff --git a/library/ssl_tls12_server.c b/library/ssl_tls12_server.c
index e92014c..21e5cda 100644
--- a/library/ssl_tls12_server.c
+++ b/library/ssl_tls12_server.c
@@ -528,94 +528,6 @@
 }
 #endif /* MBEDTLS_SSL_SESSION_TICKETS */
 
-#if defined(MBEDTLS_SSL_ALPN)
-static int ssl_parse_alpn_ext( mbedtls_ssl_context *ssl,
-                               const unsigned char *buf, size_t len )
-{
-    size_t list_len, cur_len, ours_len;
-    const unsigned char *theirs, *start, *end;
-    const char **ours;
-
-    /* If ALPN not configured, just ignore the extension */
-    if( ssl->conf->alpn_list == NULL )
-        return( 0 );
-
-    /*
-     * opaque ProtocolName<1..2^8-1>;
-     *
-     * struct {
-     *     ProtocolName protocol_name_list<2..2^16-1>
-     * } ProtocolNameList;
-     */
-
-    /* Min length is 2 (list_len) + 1 (name_len) + 1 (name) */
-    if( len < 4 )
-    {
-        mbedtls_ssl_send_alert_message( ssl, MBEDTLS_SSL_ALERT_LEVEL_FATAL,
-                                        MBEDTLS_SSL_ALERT_MSG_DECODE_ERROR );
-        return( MBEDTLS_ERR_SSL_DECODE_ERROR );
-    }
-
-    list_len = ( buf[0] << 8 ) | buf[1];
-    if( list_len != len - 2 )
-    {
-        mbedtls_ssl_send_alert_message( ssl, MBEDTLS_SSL_ALERT_LEVEL_FATAL,
-                                        MBEDTLS_SSL_ALERT_MSG_DECODE_ERROR );
-        return( MBEDTLS_ERR_SSL_DECODE_ERROR );
-    }
-
-    /*
-     * Validate peer's list (lengths)
-     */
-    start = buf + 2;
-    end = buf + len;
-    for( theirs = start; theirs != end; theirs += cur_len )
-    {
-        cur_len = *theirs++;
-
-        /* Current identifier must fit in list */
-        if( cur_len > (size_t)( end - theirs ) )
-        {
-            mbedtls_ssl_send_alert_message( ssl, MBEDTLS_SSL_ALERT_LEVEL_FATAL,
-                                            MBEDTLS_SSL_ALERT_MSG_DECODE_ERROR );
-            return( MBEDTLS_ERR_SSL_DECODE_ERROR );
-        }
-
-        /* Empty strings MUST NOT be included */
-        if( cur_len == 0 )
-        {
-            mbedtls_ssl_send_alert_message( ssl, MBEDTLS_SSL_ALERT_LEVEL_FATAL,
-                                            MBEDTLS_SSL_ALERT_MSG_ILLEGAL_PARAMETER );
-            return( MBEDTLS_ERR_SSL_ILLEGAL_PARAMETER );
-        }
-    }
-
-    /*
-     * Use our order of preference
-     */
-    for( ours = ssl->conf->alpn_list; *ours != NULL; ours++ )
-    {
-        ours_len = strlen( *ours );
-        for( theirs = start; theirs != end; theirs += cur_len )
-        {
-            cur_len = *theirs++;
-
-            if( cur_len == ours_len &&
-                memcmp( theirs, *ours, cur_len ) == 0 )
-            {
-                ssl->alpn_chosen = *ours;
-                return( 0 );
-            }
-        }
-    }
-
-    /* If we get there, no match was found */
-    mbedtls_ssl_send_alert_message( ssl, MBEDTLS_SSL_ALERT_LEVEL_FATAL,
-                            MBEDTLS_SSL_ALERT_MSG_NO_APPLICATION_PROTOCOL );
-    return( MBEDTLS_ERR_SSL_NO_APPLICATION_PROTOCOL );
-}
-#endif /* MBEDTLS_SSL_ALPN */
-
 #if defined(MBEDTLS_SSL_DTLS_SRTP)
 static int ssl_parse_use_srtp_ext( mbedtls_ssl_context *ssl,
                                    const unsigned char *buf,
@@ -1524,7 +1436,8 @@
             case MBEDTLS_TLS_EXT_ALPN:
                 MBEDTLS_SSL_DEBUG_MSG( 3, ( "found alpn extension" ) );
 
-                ret = ssl_parse_alpn_ext( ssl, ext + 4, ext_size );
+                ret = mbedtls_ssl_parse_alpn_ext( ssl, ext + 4,
+                                                  ext + 4 + ext_size );
                 if( ret != 0 )
                     return( ret );
                 break;
@@ -2040,39 +1953,6 @@
 }
 #endif /* MBEDTLS_KEY_EXCHANGE_ECJPAKE_ENABLED */
 
-#if defined(MBEDTLS_SSL_ALPN )
-static void ssl_write_alpn_ext( mbedtls_ssl_context *ssl,
-                                unsigned char *buf, size_t *olen )
-{
-    if( ssl->alpn_chosen == NULL )
-    {
-        *olen = 0;
-        return;
-    }
-
-    MBEDTLS_SSL_DEBUG_MSG( 3, ( "server hello, adding alpn extension" ) );
-
-    /*
-     * 0 . 1    ext identifier
-     * 2 . 3    ext length
-     * 4 . 5    protocol list length
-     * 6 . 6    protocol name length
-     * 7 . 7+n  protocol name
-     */
-    MBEDTLS_PUT_UINT16_BE( MBEDTLS_TLS_EXT_ALPN, buf, 0);
-
-    *olen = 7 + strlen( ssl->alpn_chosen );
-
-    MBEDTLS_PUT_UINT16_BE( *olen - 4, buf, 2 );
-
-    MBEDTLS_PUT_UINT16_BE( *olen - 6, buf, 4 );
-
-    buf[6] = MBEDTLS_BYTE_0( *olen - 7 );
-
-    memcpy( buf + 7, ssl->alpn_chosen, *olen - 7 );
-}
-#endif /* MBEDTLS_ECDH_C || MBEDTLS_ECDSA_C */
-
 #if defined(MBEDTLS_SSL_DTLS_SRTP ) && defined(MBEDTLS_SSL_PROTO_DTLS)
 static void ssl_write_use_srtp_ext( mbedtls_ssl_context *ssl,
                                     unsigned char *buf,
@@ -2446,7 +2326,8 @@
 #endif
 
 #if defined(MBEDTLS_SSL_ALPN)
-    ssl_write_alpn_ext( ssl, p + 2 + ext_len, &olen );
+    unsigned char *end = buf + MBEDTLS_SSL_OUT_CONTENT_LEN - 4;
+    mbedtls_ssl_write_alpn_ext( ssl, p + 2 + ext_len, end, &olen );
     ext_len += olen;
 #endif
 
diff --git a/library/ssl_tls13_server.c b/library/ssl_tls13_server.c
index 5be338d..2ee67bf 100644
--- a/library/ssl_tls13_server.c
+++ b/library/ssl_tls13_server.c
@@ -303,6 +303,13 @@
                 & MBEDTLS_SSL_EXT_SERVERNAME ) > 0 ) ?
                 "TRUE" : "FALSE" ) );
 #endif /* MBEDTLS_SSL_SERVER_NAME_INDICATION */
+#if defined ( MBEDTLS_SSL_ALPN )
+    MBEDTLS_SSL_DEBUG_MSG( 3,
+            ( "- ALPN_EXTENSION   ( %s )",
+            ( ( ssl->handshake->extensions_present
+                & MBEDTLS_SSL_EXT_ALPN ) > 0 ) ?
+                "TRUE" : "FALSE" ) );
+#endif /* MBEDTLS_SSL_ALPN */
 }
 #endif /* MBEDTLS_DEBUG_C */
 
@@ -731,6 +738,21 @@
                 ssl->handshake->extensions_present |= MBEDTLS_SSL_EXT_SUPPORTED_VERSIONS;
                 break;
 
+#if defined(MBEDTLS_SSL_ALPN)
+            case MBEDTLS_TLS_EXT_ALPN:
+                MBEDTLS_SSL_DEBUG_MSG( 3, ( "found alpn extension" ) );
+
+                ret = mbedtls_ssl_parse_alpn_ext( ssl, p, extension_data_end );
+                if( ret != 0 )
+                {
+                    MBEDTLS_SSL_DEBUG_RET(
+                            1, ( "mbedtls_ssl_parse_alpn_ext" ), ret );
+                    return( ret );
+                }
+                ssl->handshake->extensions_present |= MBEDTLS_SSL_EXT_ALPN;
+                break;
+#endif /* MBEDTLS_SSL_ALPN */
+
 #if defined(MBEDTLS_KEY_EXCHANGE_WITH_CERT_ENABLED)
             case MBEDTLS_TLS_EXT_SIG_ALG:
                 MBEDTLS_SSL_DEBUG_MSG( 3, ( "found signature_algorithms extension" ) );
@@ -1361,9 +1383,11 @@
                                                       unsigned char *end,
                                                       size_t *out_len )
 {
+    int ret = MBEDTLS_ERR_ERROR_CORRUPTION_DETECTED;
     unsigned char *p = buf;
     size_t extensions_len = 0;
     unsigned char *p_extensions_len;
+    size_t output_len;
 
     *out_len = 0;
 
@@ -1372,6 +1396,15 @@
     p += 2;
 
     ((void) ssl);
+    ((void) ret);
+    ((void) output_len);
+
+#if defined(MBEDTLS_SSL_ALPN)
+    ret = mbedtls_ssl_write_alpn_ext( ssl, p, end, &output_len );
+    if( ret != 0 )
+        return( ret );
+    p  += output_len;
+#endif /* MBEDTLS_SSL_ALPN */
 
     extensions_len = ( p - p_extensions_len ) - 2;
     MBEDTLS_PUT_UINT16_BE( extensions_len, p_extensions_len, 0 );