Adapt ssl_fetch_input() for UDP
diff --git a/library/ssl_tls.c b/library/ssl_tls.c
index e009b97..e44ffa6 100644
--- a/library/ssl_tls.c
+++ b/library/ssl_tls.c
@@ -1821,6 +1821,13 @@
 
 /*
  * Fill the input message buffer
+ *
+ * If we return 0, it is guaranteed that (at least) nb_want bytes are
+ * available (from this read and/or a previous one). Otherwise, an error code
+ * is returned (possibly EOF or WANT_READ).
+ *
+ * Set ssl->in_left to 0 before calling to start a new record. Apart from
+ * this, ssl->in_left is an internal variable and should never be read.
  */
 int ssl_fetch_input( ssl_context *ssl, size_t nb_want )
 {
@@ -1829,19 +1836,40 @@
 
     SSL_DEBUG_MSG( 2, ( "=> fetch input" ) );
 
-    if( nb_want > SSL_BUFFER_LEN - 8 )
+    if( nb_want > SSL_BUFFER_LEN - (size_t)( ssl->in_hdr - ssl->in_buf ) )
     {
         SSL_DEBUG_MSG( 1, ( "requesting more data than fits" ) );
         return( POLARSSL_ERR_SSL_BAD_INPUT_DATA );
     }
 
-    while( ssl->in_left < nb_want )
+#if defined(POLARSSL_SSL_PROTO_DTLS)
+    if( ssl->transport == SSL_TRANSPORT_DATAGRAM )
     {
-        len = nb_want - ssl->in_left;
-        ret = ssl->f_recv( ssl->p_recv, ssl->in_hdr + ssl->in_left, len );
-
         SSL_DEBUG_MSG( 2, ( "in_left: %d, nb_want: %d",
                        ssl->in_left, nb_want ) );
+
+        /*
+         * With UDP, we must always read a full datagram.
+         * Just remember how much we read and avoid reading again if we
+         * already have enough data.
+         */
+        if( nb_want <= ssl->in_left)
+            return( 0 );
+
+        /*
+         * A record can't be split across datagrams. If we need to read but
+         * are not at the beginning of a new record, the caller did something
+         * wrong.
+         */
+        if( ssl->in_left != 0 )
+        {
+            SSL_DEBUG_MSG( 1, ( "should never happen" ) );
+            return( POLARSSL_ERR_SSL_INTERNAL_ERROR );
+        }
+
+        len = SSL_BUFFER_LEN - ( ssl->in_hdr - ssl->in_buf );
+        ret = ssl->f_recv( ssl->p_recv, ssl->in_hdr, len );
+
         SSL_DEBUG_RET( 2, "ssl->f_recv", ret );
 
         if( ret == 0 )
@@ -1850,7 +1878,28 @@
         if( ret < 0 )
             return( ret );
 
-        ssl->in_left += ret;
+        ssl->in_left = ret;
+    }
+    else
+#endif
+    {
+        while( ssl->in_left < nb_want )
+        {
+            len = nb_want - ssl->in_left;
+            ret = ssl->f_recv( ssl->p_recv, ssl->in_hdr + ssl->in_left, len );
+
+            SSL_DEBUG_MSG( 2, ( "in_left: %d, nb_want: %d",
+                           ssl->in_left, nb_want ) );
+            SSL_DEBUG_RET( 2, "ssl->f_recv", ret );
+
+            if( ret == 0 )
+                return( POLARSSL_ERR_SSL_CONN_EOF );
+
+            if( ret < 0 )
+                return( ret );
+
+            ssl->in_left += ret;
+        }
     }
 
     SSL_DEBUG_MSG( 2, ( "<= fetch input" ) );
@@ -2140,7 +2189,8 @@
     }
 
     /* Sanity check (outer boundaries) */
-    if( ssl->in_msglen < 1 || ssl->in_msglen > SSL_BUFFER_LEN - 13 )
+    if( ssl->in_msglen < 1 ||
+        ssl->in_msglen > SSL_BUFFER_LEN - (size_t)( ssl->in_msg - ssl->in_buf ) )
     {
         SSL_DEBUG_MSG( 1, ( "bad message length" ) );
         return( POLARSSL_ERR_SSL_INVALID_RECORD );