Adapt ssl_fetch_input() for UDP
diff --git a/library/ssl_tls.c b/library/ssl_tls.c
index e009b97..e44ffa6 100644
--- a/library/ssl_tls.c
+++ b/library/ssl_tls.c
@@ -1821,6 +1821,13 @@
/*
* Fill the input message buffer
+ *
+ * If we return 0, it is guaranteed that (at least) nb_want bytes are
+ * available (from this read and/or a previous one). Otherwise, an error code
+ * is returned (possibly EOF or WANT_READ).
+ *
+ * Set ssl->in_left to 0 before calling to start a new record. Apart from
+ * this, ssl->in_left is an internal variable and should never be read.
*/
int ssl_fetch_input( ssl_context *ssl, size_t nb_want )
{
@@ -1829,19 +1836,40 @@
SSL_DEBUG_MSG( 2, ( "=> fetch input" ) );
- if( nb_want > SSL_BUFFER_LEN - 8 )
+ if( nb_want > SSL_BUFFER_LEN - (size_t)( ssl->in_hdr - ssl->in_buf ) )
{
SSL_DEBUG_MSG( 1, ( "requesting more data than fits" ) );
return( POLARSSL_ERR_SSL_BAD_INPUT_DATA );
}
- while( ssl->in_left < nb_want )
+#if defined(POLARSSL_SSL_PROTO_DTLS)
+ if( ssl->transport == SSL_TRANSPORT_DATAGRAM )
{
- len = nb_want - ssl->in_left;
- ret = ssl->f_recv( ssl->p_recv, ssl->in_hdr + ssl->in_left, len );
-
SSL_DEBUG_MSG( 2, ( "in_left: %d, nb_want: %d",
ssl->in_left, nb_want ) );
+
+ /*
+ * With UDP, we must always read a full datagram.
+ * Just remember how much we read and avoid reading again if we
+ * already have enough data.
+ */
+ if( nb_want <= ssl->in_left)
+ return( 0 );
+
+ /*
+ * A record can't be split across datagrams. If we need to read but
+ * are not at the beginning of a new record, the caller did something
+ * wrong.
+ */
+ if( ssl->in_left != 0 )
+ {
+ SSL_DEBUG_MSG( 1, ( "should never happen" ) );
+ return( POLARSSL_ERR_SSL_INTERNAL_ERROR );
+ }
+
+ len = SSL_BUFFER_LEN - ( ssl->in_hdr - ssl->in_buf );
+ ret = ssl->f_recv( ssl->p_recv, ssl->in_hdr, len );
+
SSL_DEBUG_RET( 2, "ssl->f_recv", ret );
if( ret == 0 )
@@ -1850,7 +1878,28 @@
if( ret < 0 )
return( ret );
- ssl->in_left += ret;
+ ssl->in_left = ret;
+ }
+ else
+#endif
+ {
+ while( ssl->in_left < nb_want )
+ {
+ len = nb_want - ssl->in_left;
+ ret = ssl->f_recv( ssl->p_recv, ssl->in_hdr + ssl->in_left, len );
+
+ SSL_DEBUG_MSG( 2, ( "in_left: %d, nb_want: %d",
+ ssl->in_left, nb_want ) );
+ SSL_DEBUG_RET( 2, "ssl->f_recv", ret );
+
+ if( ret == 0 )
+ return( POLARSSL_ERR_SSL_CONN_EOF );
+
+ if( ret < 0 )
+ return( ret );
+
+ ssl->in_left += ret;
+ }
}
SSL_DEBUG_MSG( 2, ( "<= fetch input" ) );
@@ -2140,7 +2189,8 @@
}
/* Sanity check (outer boundaries) */
- if( ssl->in_msglen < 1 || ssl->in_msglen > SSL_BUFFER_LEN - 13 )
+ if( ssl->in_msglen < 1 ||
+ ssl->in_msglen > SSL_BUFFER_LEN - (size_t)( ssl->in_msg - ssl->in_buf ) )
{
SSL_DEBUG_MSG( 1, ( "bad message length" ) );
return( POLARSSL_ERR_SSL_INVALID_RECORD );