Don't immediately flush datagram after preparing a record

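This commit finally enables datagram packing by modifying the
record preparation function mbedtls_ssl_write_record() so that it
no longer always calls mbedtls_ssl_flush_output(). A new force_flush
parameter selects the behaviour: with SSL_FORCE_FLUSH the output is
flushed right after the record is written, while with
SSL_DONT_FORCE_FLUSH a DTLS record stays buffered until the current
datagram has no room left for another record. Handshake flight
transmission uses SSL_DONT_FORCE_FLUSH and flushes once the whole
flight has been written.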
diff --git a/library/ssl_tls.c b/library/ssl_tls.c
index ad071a9..878495b 100644
--- a/library/ssl_tls.c
+++ b/library/ssl_tls.c
@@ -100,6 +100,10 @@
                                      mbedtls_ssl_transform *transform );
 static void ssl_update_in_pointers( mbedtls_ssl_context *ssl,
                                     mbedtls_ssl_transform *transform );
+
+#define SSL_DONT_FORCE_FLUSH 0 /* write_record: output may stay buffered */
+#define SSL_FORCE_FLUSH      1 /* write_record: flush after the record   */
+
 #if defined(MBEDTLS_SSL_PROTO_DTLS)
 
 static uint16_t ssl_get_maximum_datagram_size( mbedtls_ssl_context const *ssl )
@@ -112,6 +116,69 @@
     return( MBEDTLS_SSL_OUT_BUFFER_LEN );
 }
 
+/*
+ * Return the number of bytes still free in the datagram being assembled:
+ * the size reported by ssl_get_maximum_datagram_size() minus the
+ * ssl->out_left bytes already buffered for output. Returns
+ * MBEDTLS_ERR_SSL_INTERNAL_ERROR if more than that is already buffered.
+ */
+static int ssl_get_remaining_space_in_datagram( mbedtls_ssl_context const *ssl )
+{
+    size_t   const bytes_written = ssl->out_left;
+    uint16_t const mtu           = ssl_get_maximum_datagram_size( ssl );
+
+    /* Double-check that the write-index hasn't gone
+     * past what we can transmit in a single datagram. */
+    if( bytes_written > (size_t) mtu )
+    {
+        /* Should never happen... */
+        return( MBEDTLS_ERR_SSL_INTERNAL_ERROR );
+    }
+
+    return( (int) ( mtu - bytes_written ) );
+}
+
+/*
+ * Return the largest record payload that still fits into the current
+ * datagram: the free space minus mbedtls_ssl_get_record_expansion(),
+ * capped at MBEDTLS_SSL_MAX_CONTENT_LEN and, if MBEDTLS_SSL_MAX_FRAGMENT_LENGTH
+ * is enabled, at mbedtls_ssl_get_max_frag_len(). Returns 0 if no record
+ * with a non-empty payload fits, or a negative error code from the helpers.
+ */
+static int ssl_get_remaining_payload_in_datagram( mbedtls_ssl_context const *ssl )
+{
+    int ret;
+    size_t remaining, expansion;
+    size_t max_len = MBEDTLS_SSL_MAX_CONTENT_LEN;
+
+#if defined(MBEDTLS_SSL_MAX_FRAGMENT_LENGTH)
+    const size_t mfl = mbedtls_ssl_get_max_frag_len( ssl );
+
+    if( max_len > mfl )
+        max_len = mfl;
+#endif
+
+    ret = ssl_get_remaining_space_in_datagram( ssl );
+    if( ret < 0 )
+        return( ret );
+    remaining = (size_t) ret;
+
+    /* Per-record overhead that does not count towards the payload */
+    ret = mbedtls_ssl_get_record_expansion( ssl );
+    if( ret < 0 )
+        return( ret );
+    expansion = (size_t) ret;
+
+    if( remaining <= expansion )
+        return( 0 );
+
+    remaining -= expansion;
+    if( remaining >= max_len )
+        remaining = max_len;
+
+    return( (int) remaining );
+}
+
 /*
  * Double the retransmit timeout value, within the allowed range,
  * returning -1 if the maximum value has already been reached.
@@ -2857,20 +2910,9 @@
  */
 int mbedtls_ssl_flight_transmit( mbedtls_ssl_context *ssl )
 {
-    const int ret_payload = mbedtls_ssl_get_max_out_record_payload( ssl );
-    const size_t max_record_payload = (size_t) ret_payload;
-    /* DTLS handshake headers are 12 bytes */
-    const size_t max_hs_fragment_len = max_record_payload - 12;
-
+    int ret;
     MBEDTLS_SSL_DEBUG_MSG( 2, ( "=> mbedtls_ssl_flight_transmit" ) );
 
-    if( ret_payload < 0 )
-    {
-        MBEDTLS_SSL_DEBUG_RET( 1, "mbedtls_ssl_get_max_out_record_payload",
-                                  ret_payload );
-        return( ret_payload );
-    }
-
     if( ssl->handshake->retransmit_state != MBEDTLS_SSL_RETRANS_SENDING )
     {
         MBEDTLS_SSL_DEBUG_MSG( 2, ( "initialise flight transmission" ) );
@@ -2884,22 +2926,39 @@
 
     while( ssl->handshake->cur_msg != NULL )
     {
-        int ret;
+        size_t max_frag_len;
         const mbedtls_ssl_flight_item * const cur = ssl->handshake->cur_msg;
+
         /* Swap epochs before sending Finished: we can't do it after
          * sending ChangeCipherSpec, in case write returns WANT_READ.
          * Must be done before copying, may change out_msg pointer */
         if( cur->type == MBEDTLS_SSL_MSG_HANDSHAKE &&
-            cur->p[0] == MBEDTLS_SSL_HS_FINISHED )
+            cur->p[0] == MBEDTLS_SSL_HS_FINISHED   &&
+            ssl->handshake->cur_msg_p == ( cur->p + 12 ) )
         {
+            MBEDTLS_SSL_DEBUG_MSG( 2, ( "swap epochs to send finished message" ) );
             ssl_swap_epochs( ssl );
         }
 
+        /* Record payload bytes still available in the current datagram */
+        ret = ssl_get_remaining_payload_in_datagram( ssl );
+        if( ret < 0 )
+            return( ret );
+        max_frag_len = (size_t) ret;
+
         /* CCS is copied as is, while HS messages may need fragmentation */
         if( cur->type == MBEDTLS_SSL_MSG_CHANGE_CIPHER_SPEC )
         {
+            if( max_frag_len == 0 )
+            {
+                if( ( ret = mbedtls_ssl_flush_output( ssl ) ) != 0 )
+                    return( ret );
+
+                continue;
+            }
+
             memcpy( ssl->out_msg, cur->p, cur->len );
-            ssl->out_msglen = cur->len;
+            ssl->out_msglen  = cur->len;
             ssl->out_msgtype = cur->type;
 
             /* Update position inside current message */
@@ -2911,14 +2969,39 @@
             const size_t hs_len = cur->len - 12;
             const size_t frag_off = p - ( cur->p + 12 );
             const size_t rem_len = hs_len - frag_off;
-            const size_t frag_len = rem_len > max_hs_fragment_len
-                                  ? max_hs_fragment_len : rem_len;
+            size_t cur_hs_frag_len, max_hs_frag_len;
 
-            if( frag_off == 0 && frag_len != hs_len )
+            /* A fragment needs the 12-byte handshake header plus at least
+             * one byte of body (a header-only fragment is only useful for
+             * an empty message). If that does not fit, flush the current
+             * datagram and retry this message in the next one. */
+            if( ( max_frag_len < 12 ) || ( max_frag_len == 12 && hs_len != 0 ) )
+            {
+                /* Undo the epoch swap done at the top of the loop, which
+                 * only happens for the first fragment of Finished; the
+                 * next iteration redoes it under the same condition. */
+                if( cur->type == MBEDTLS_SSL_MSG_HANDSHAKE &&
+                    cur->p[0] == MBEDTLS_SSL_HS_FINISHED   &&
+                    ssl->handshake->cur_msg_p == ( cur->p + 12 ) )
+                {
+                    ssl_swap_epochs( ssl );
+                }
+
+                if( ( ret = mbedtls_ssl_flush_output( ssl ) ) != 0 )
+                    return( ret );
+
+                continue;
+            }
+            max_hs_frag_len = max_frag_len - 12;
+
+            cur_hs_frag_len = rem_len > max_hs_frag_len ?
+                max_hs_frag_len : rem_len;
+
+            if( frag_off == 0 && cur_hs_frag_len != hs_len )
             {
                 MBEDTLS_SSL_DEBUG_MSG( 2, ( "fragmenting handshake message (%u > %u)",
-                                            (unsigned) hs_len,
-                                            (unsigned) max_hs_fragment_len ) );
+                                            (unsigned) hs_len,
+                                            (unsigned) max_hs_frag_len ) );
             }
 
             /* Messages are stored with handshake headers as if not fragmented,
@@ -2930,19 +3005,19 @@
             ssl->out_msg[7] = ( ( frag_off >>  8 ) & 0xff );
             ssl->out_msg[8] = ( ( frag_off       ) & 0xff );
 
-            ssl->out_msg[ 9] = ( ( frag_len >> 16 ) & 0xff );
-            ssl->out_msg[10] = ( ( frag_len >>  8 ) & 0xff );
-            ssl->out_msg[11] = ( ( frag_len       ) & 0xff );
+            ssl->out_msg[ 9] = ( ( cur_hs_frag_len >> 16 ) & 0xff );
+            ssl->out_msg[10] = ( ( cur_hs_frag_len >>  8 ) & 0xff );
+            ssl->out_msg[11] = ( ( cur_hs_frag_len       ) & 0xff );
 
             MBEDTLS_SSL_DEBUG_BUF( 3, "handshake header", ssl->out_msg, 12 );
 
-            /* Copy the handshake message content and set records fields */
-            memcpy( ssl->out_msg + 12, p, frag_len );
-            ssl->out_msglen = frag_len + 12;
+            /* Copy the handshake message content and set records fields */
+            memcpy( ssl->out_msg + 12, p, cur_hs_frag_len );
+            ssl->out_msglen = cur_hs_frag_len + 12;
             ssl->out_msgtype = cur->type;
 
             /* Update position inside current message */
-            ssl->handshake->cur_msg_p += frag_len;
+            ssl->handshake->cur_msg_p += cur_hs_frag_len;
         }
 
         /* If done with the current message move to the next one if any */
@@ -2961,13 +3036,18 @@
         }
 
         /* Actually send the message out */
-        if( ( ret = mbedtls_ssl_write_record( ssl ) ) != 0 )
+        if( ( ret = mbedtls_ssl_write_record( ssl,
+                                              SSL_DONT_FORCE_FLUSH ) ) != 0 )
         {
             MBEDTLS_SSL_DEBUG_RET( 1, "mbedtls_ssl_write_record", ret );
             return( ret );
         }
     }
 
+    /* Send whatever is still buffered from the last datagram of the flight */
+    if( ( ret = mbedtls_ssl_flush_output( ssl ) ) != 0 )
+        return( ret );
+
     /* Update state and set timer */
     if( ssl->state == MBEDTLS_SSL_HANDSHAKE_OVER )
         ssl->handshake->retransmit_state = MBEDTLS_SSL_RETRANS_FINISHED;
@@ -3158,7 +3237,7 @@
     else
 #endif
     {
-        if( ( ret = mbedtls_ssl_write_record( ssl ) ) != 0 )
+        if( ( ret = mbedtls_ssl_write_record( ssl, SSL_FORCE_FLUSH ) ) != 0 )
         {
             MBEDTLS_SSL_DEBUG_RET( 1, "ssl_write_record", ret );
             return( ret );
@@ -3182,10 +3261,14 @@
  *  - ssl->out_msglen: length of the record content (excl headers)
  *  - ssl->out_msg: record content
+ *  - force_flush: SSL_FORCE_FLUSH flushes the output right after the record;
+ *    with SSL_DONT_FORCE_FLUSH the record may stay buffered (for DTLS it is
+ *    flushed once the current datagram has no room for another record)
  */
-int mbedtls_ssl_write_record( mbedtls_ssl_context *ssl )
+int mbedtls_ssl_write_record( mbedtls_ssl_context *ssl, uint8_t force_flush )
 {
     int ret, done = 0;
     size_t len = ssl->out_msglen;
+    uint8_t flush = force_flush;
 
     MBEDTLS_SSL_DEBUG_MSG( 2, ( "=> write record" ) );
 
@@ -3288,7 +3368,34 @@
         }
     }
 
-    if( ( ret = mbedtls_ssl_flush_output( ssl ) ) != 0 )
+#if defined(MBEDTLS_SSL_PROTO_DTLS)
+    if( ssl->conf->transport == MBEDTLS_SSL_TRANSPORT_DATAGRAM )
+    {
+        /* Flush as soon as no record with a non-empty payload fits into
+         * the current datagram; otherwise keep the data buffered so that
+         * subsequent records can be packed into the same datagram. */
+        size_t remaining;
+        const int ret_payload = ssl_get_remaining_payload_in_datagram( ssl );
+        if( ret_payload < 0 )
+        {
+            MBEDTLS_SSL_DEBUG_RET( 1, "ssl_get_remaining_payload_in_datagram",
+                                   ret_payload );
+            return( ret_payload );
+        }
+
+        remaining = (size_t) ret_payload;
+        if( remaining == 0 )
+            flush = SSL_FORCE_FLUSH;
+        else
+        {
+            MBEDTLS_SSL_DEBUG_MSG( 2, ( "Still %u bytes available in current datagram",
+                                        (unsigned) remaining ) );
+        }
+    }
+#endif /* MBEDTLS_SSL_PROTO_DTLS */
+
+    if( ( flush == SSL_FORCE_FLUSH ) &&
+        ( ret = mbedtls_ssl_flush_output( ssl ) ) != 0 )
     {
         MBEDTLS_SSL_DEBUG_RET( 1, "mbedtls_ssl_flush_output", ret );
         return( ret );
@@ -4570,7 +4664,8 @@
     ssl->out_msg[0] = level;
     ssl->out_msg[1] = message;
 
-    if( ( ret = mbedtls_ssl_write_record( ssl ) ) != 0 )
+    /* SSL_FORCE_FLUSH: send the alert now instead of leaving it buffered */
+    if( ( ret = mbedtls_ssl_write_record( ssl, SSL_FORCE_FLUSH ) ) != 0 )
     {
         MBEDTLS_SSL_DEBUG_RET( 1, "mbedtls_ssl_write_record", ret );
         return( ret );
@@ -7815,7 +7909,8 @@
         ssl->out_msgtype = MBEDTLS_SSL_MSG_APPLICATION_DATA;
         memcpy( ssl->out_msg, buf, len );
 
-        if( ( ret = mbedtls_ssl_write_record( ssl ) ) != 0 )
+        /* SSL_FORCE_FLUSH: flush right after this record, no packing */
+        if( ( ret = mbedtls_ssl_write_record( ssl, SSL_FORCE_FLUSH ) ) != 0 )
         {
             MBEDTLS_SSL_DEBUG_RET( 1, "mbedtls_ssl_write_record", ret );
             return( ret );