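Add explicit counter in DTLS record header

For DTLS the record counter is carried inside the record header, so the
in/out buffer layout (hdr, ctr, len, iv and msg pointers) is now set up in
ssl_set_transport() for each transport instead of being hard-coded for TLS
in ssl_init(). Hard-coded header lengths of 5 are replaced by ssl_hdr_len().

The outgoing counter is now incremented in ssl_flush_output() once a record
has been fully sent, rather than in ssl_encrypt_buf(). The incoming counter
is only maintained for stream transport (TLS) for now. Wrap handling of the
outgoing counter is not yet adapted for DTLS (see TODO in ssl_flush_output()).

ssl_init() now frees both I/O buffers if either allocation fails.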
diff --git a/library/ssl_tls.c b/library/ssl_tls.c
index 58fb306..5ad244d 100644
--- a/library/ssl_tls.c
+++ b/library/ssl_tls.c
@@ -1284,18 +1284,6 @@
         return( POLARSSL_ERR_SSL_INTERNAL_ERROR );
     }
 
-    // TODO: adapt for DTLS (start from i = 6)
-    for( i = 8; i > 0; i-- )
-        if( ++ssl->out_ctr[i - 1] != 0 )
-            break;
-
-    /* The loops goes to its end iff the counter is wrapping */
-    if( i == 0 )
-    {
-        SSL_DEBUG_MSG( 1, ( "outgoing message counter would wrap" ) );
-        return( POLARSSL_ERR_SSL_COUNTER_WRAPPING );
-    }
-
     SSL_DEBUG_MSG( 2, ( "<= encrypt buf" ) );
 
     return( 0 );
@@ -1702,16 +1690,19 @@
     else
         ssl->nb_zero = 0;
 
-    // TODO: DTLS: i = 6
-    for( i = 8; i > 0; i-- )
-        if( ++ssl->in_ctr[i - 1] != 0 )
-            break;
-
-    /* The loops goes to its end iff the counter is wrapping */
-    if( i == 0 )
+    /* For DTLS we don't maintain our own incoming counter (for now) */
+    if( ssl->transport == SSL_TRANSPORT_STREAM )
     {
-        SSL_DEBUG_MSG( 1, ( "incoming message counter would wrap" ) );
-        return( POLARSSL_ERR_SSL_COUNTER_WRAPPING );
+        for( i = 8; i > 0; i-- )
+            if( ++ssl->in_ctr[i - 1] != 0 )
+                break;
+
+        /* The loop goes to its end iff the counter is wrapping */
+        if( i == 0 )
+        {
+            SSL_DEBUG_MSG( 1, ( "incoming message counter would wrap" ) );
+            return( POLARSSL_ERR_SSL_COUNTER_WRAPPING );
+        }
     }
 
     SSL_DEBUG_MSG( 2, ( "<= decrypt buf" ) );
@@ -1860,17 +1851,25 @@
  */
 int ssl_flush_output( ssl_context *ssl )
 {
-    int ret;
+    int ret, i;
     unsigned char *buf;
 
     SSL_DEBUG_MSG( 2, ( "=> flush output" ) );
 
+    /* Avoid incrementing counter if data is flushed */
+    if( ssl->out_left == 0 )
+    {
+        SSL_DEBUG_MSG( 2, ( "<= flush output" ) );
+        return( 0 );
+    }
+
     while( ssl->out_left > 0 )
     {
         SSL_DEBUG_MSG( 2, ( "message length: %d, out_left: %d",
-                       5 + ssl->out_msglen, ssl->out_left ) );
+                       ssl_hdr_len( ssl ) + ssl->out_msglen, ssl->out_left ) );
 
-        buf = ssl->out_hdr + 5 + ssl->out_msglen - ssl->out_left;
+        buf = ssl->out_hdr + ssl_hdr_len( ssl ) +
+              ssl->out_msglen - ssl->out_left;
         ret = ssl->f_send( ssl->p_send, buf, ssl->out_left );
 
         SSL_DEBUG_RET( 2, "ssl->f_send", ret );
@@ -1881,6 +1880,18 @@
         ssl->out_left -= ret;
     }
 
+    // TODO: adapt for DTLS (start from i = 6)
+    for( i = 8; i > 0; i-- )
+        if( ++ssl->out_ctr[i - 1] != 0 )
+            break;
+
+    /* The loop goes to its end iff the counter is wrapping */
+    if( i == 0 )
+    {
+        SSL_DEBUG_MSG( 1, ( "outgoing message counter would wrap" ) );
+        return( POLARSSL_ERR_SSL_COUNTER_WRAPPING );
+    }
+
     SSL_DEBUG_MSG( 2, ( "<= flush output" ) );
 
     return( 0 );
@@ -1958,7 +1969,7 @@
             ssl->out_len[1] = (unsigned char)( len      );
         }
 
-        ssl->out_left = 5 + ssl->out_msglen;
+        ssl->out_left = ssl_hdr_len( ssl ) + ssl->out_msglen;
 
         SSL_DEBUG_MSG( 3, ( "output record: msgtype = %d, "
                             "version = [%d:%d], msglen = %d",
@@ -1966,7 +1977,7 @@
                      ( ssl->out_len[0] << 8 ) | ssl->out_len[1] ) );
 
         SSL_DEBUG_BUF( 4, "output record sent to network",
-                       ssl->out_hdr, 5 + ssl->out_msglen );
+                       ssl->out_hdr, ssl_hdr_len( ssl ) + ssl->out_msglen );
     }
 
     if( ( ret = ssl_flush_output( ssl ) ) != 0 )
@@ -2028,7 +2039,7 @@
     /*
      * Read the record header and validate it
      */
-    if( ( ret = ssl_fetch_input( ssl, 5 ) ) != 0 )
+    if( ( ret = ssl_fetch_input( ssl, ssl_hdr_len( ssl ) ) ) != 0 )
     {
         SSL_DEBUG_RET( 1, "ssl_fetch_input", ret );
         return( ret );
@@ -2110,14 +2121,15 @@
     /*
      * Read and optionally decrypt the message contents
      */
-    if( ( ret = ssl_fetch_input( ssl, 5 + ssl->in_msglen ) ) != 0 )
+    if( ( ret = ssl_fetch_input( ssl,
+                                 ssl_hdr_len( ssl ) + ssl->in_msglen ) ) != 0 )
     {
         SSL_DEBUG_RET( 1, "ssl_fetch_input", ret );
         return( ret );
     }
 
     SSL_DEBUG_BUF( 4, "input record from network",
-                   ssl->in_hdr, 5 + ssl->in_msglen );
+                   ssl->in_hdr, ssl_hdr_len( ssl ) + ssl->in_msglen );
 
 #if defined(POLARSSL_SSL_HW_RECORD_ACCEL)
     if( ssl_hw_record_read != NULL )
@@ -3417,39 +3429,27 @@
 #endif
 
     /*
-     * Prepare base structures (assume TLS for now)
+     * Prepare base structures
      */
     ssl->in_buf = (unsigned char *) polarssl_malloc( len );
-    ssl->in_ctr = ssl->in_buf;
-    ssl->in_hdr = ssl->in_buf +  8;
-    ssl->in_len = ssl->in_buf + 11;
-    ssl->in_iv  = ssl->in_buf + 13;
-    ssl->in_msg = ssl->in_buf + 13;
-
-    if( ssl->in_buf == NULL )
-    {
-        SSL_DEBUG_MSG( 1, ( "malloc(%d bytes) failed", len ) );
-        return( POLARSSL_ERR_SSL_MALLOC_FAILED );
-    }
-
     ssl->out_buf = (unsigned char *) polarssl_malloc( len );
-    ssl->out_ctr = ssl->out_buf;
-    ssl->out_hdr = ssl->out_buf +  8;
-    ssl->out_len = ssl->out_buf + 11;
-    ssl->out_iv  = ssl->out_buf + 13;
-    ssl->out_msg = ssl->out_buf + 13;
 
-    if( ssl->out_buf == NULL )
+    if( ssl->in_buf == NULL || ssl->out_buf == NULL )
     {
         SSL_DEBUG_MSG( 1, ( "malloc(%d bytes) failed", len ) );
         polarssl_free( ssl->in_buf );
+        polarssl_free( ssl->out_buf );
         ssl->in_buf = NULL;
+        ssl->out_buf = NULL;
         return( POLARSSL_ERR_SSL_MALLOC_FAILED );
     }
 
     memset( ssl-> in_buf, 0, SSL_BUFFER_LEN );
     memset( ssl->out_buf, 0, SSL_BUFFER_LEN );
 
+    /* No error is possible, SSL_TRANSPORT_STREAM always valid */
+    (void) ssl_set_transport( ssl, SSL_TRANSPORT_STREAM );
+
 #if defined(POLARSSL_SSL_SESSION_TICKETS)
     ssl->ticket_lifetime = SSL_DEFAULT_TICKET_LIFETIME;
 #endif
@@ -3617,6 +3617,18 @@
     {
         ssl->transport = transport;
 
+        ssl->out_hdr = ssl->out_buf;
+        ssl->out_ctr = ssl->out_buf +  3;
+        ssl->out_len = ssl->out_buf + 11;
+        ssl->out_iv  = ssl->out_buf + 13;
+        ssl->out_msg = ssl->out_buf + 13;
+
+        ssl->in_hdr = ssl->in_buf;
+        ssl->in_ctr = ssl->in_buf +  3;
+        ssl->in_len = ssl->in_buf + 11;
+        ssl->in_iv  = ssl->in_buf + 13;
+        ssl->in_msg = ssl->in_buf + 13;
+
         /* DTLS starts with TLS1.1 */
         if( ssl->min_minor_ver < SSL_MINOR_VERSION_2 )
             ssl->min_minor_ver = SSL_MINOR_VERSION_2;
@@ -3631,6 +3643,19 @@
     if( transport == SSL_TRANSPORT_STREAM )
     {
         ssl->transport = transport;
+
+        ssl->out_ctr = ssl->out_buf;
+        ssl->out_hdr = ssl->out_buf +  8;
+        ssl->out_len = ssl->out_buf + 11;
+        ssl->out_iv  = ssl->out_buf + 13;
+        ssl->out_msg = ssl->out_buf + 13;
+
+        ssl->in_ctr = ssl->in_buf;
+        ssl->in_hdr = ssl->in_buf +  8;
+        ssl->in_len = ssl->in_buf + 11;
+        ssl->in_iv  = ssl->in_buf + 13;
+        ssl->in_msg = ssl->in_buf + 13;
+
         return( 0 );
     }