- Abstracted checksum updating during handshake

diff --git a/library/ssl_tls.c b/library/ssl_tls.c
index 9a962b2..88c6e55 100644
--- a/library/ssl_tls.c
+++ b/library/ssl_tls.c
@@ -195,6 +195,16 @@
     return( 0 );
 }
 
+static void ssl_update_checksum_start(ssl_context *, unsigned char *, size_t);
+static void ssl_update_checksum_md5sha1(ssl_context *, unsigned char *, size_t);
+static void ssl_update_checksum_sha256(ssl_context *, unsigned char *, size_t);
+static void ssl_update_checksum_sha384(ssl_context *, unsigned char *, size_t);
+
+static void ssl_calc_verify_ssl(ssl_context *,unsigned char *);
+static void ssl_calc_verify_tls(ssl_context *,unsigned char *);
+static void ssl_calc_verify_tls_sha256(ssl_context *,unsigned char *);
+static void ssl_calc_verify_tls_sha384(ssl_context *,unsigned char *);
+
 static void ssl_calc_finished_ssl(ssl_context *,unsigned char *,int);
 static void ssl_calc_finished_tls(ssl_context *,unsigned char *,int);
 static void ssl_calc_finished_tls_sha256(ssl_context *,unsigned char *,int);
@@ -221,22 +231,26 @@
     if( ssl->minor_ver == SSL_MINOR_VERSION_0 )
     {
         ssl->tls_prf = tls1_prf;
+        ssl->calc_verify = ssl_calc_verify_ssl;
         ssl->calc_finished = ssl_calc_finished_ssl;
     }
     else if( ssl->minor_ver < SSL_MINOR_VERSION_3 )
     {
         ssl->tls_prf = tls1_prf;
+        ssl->calc_verify = ssl_calc_verify_tls;
         ssl->calc_finished = ssl_calc_finished_tls;
     }
     else if( ssl->session->ciphersuite == SSL_RSA_AES_256_GCM_SHA384 ||
             ssl->session->ciphersuite == SSL_EDH_RSA_AES_256_GCM_SHA384 )
     {
         ssl->tls_prf = tls_prf_sha384;
+        ssl->calc_verify = ssl_calc_verify_tls_sha384;
         ssl->calc_finished = ssl_calc_finished_tls_sha384;
     }
     else
     {
         ssl->tls_prf = tls_prf_sha256;
+        ssl->calc_verify = ssl_calc_verify_tls_sha256;
         ssl->calc_finished = ssl_calc_finished_tls_sha256;
     }
 
@@ -602,61 +616,91 @@
     return( 0 );
 }
 
-void ssl_calc_verify( ssl_context *ssl, unsigned char hash[48] )
+static void ssl_calc_verify_ssl( ssl_context *ssl, unsigned char hash[36] )
 {
     md5_context md5;
     sha1_context sha1;
-    sha2_context sha2;
-    sha4_context sha4;
     unsigned char pad_1[48];
     unsigned char pad_2[48];
 
-    SSL_DEBUG_MSG( 2, ( "=> calc verify" ) );
+    SSL_DEBUG_MSG( 2, ( "=> calc verify ssl" ) );
 
-    memcpy( &md5 , &ssl->fin_md5 , sizeof(  md5_context ) );
-    memcpy( &sha1, &ssl->fin_sha1, sizeof( sha1_context ) );
-    memcpy( &sha2, &ssl->fin_sha2, sizeof( sha2_context ) );
-    memcpy( &sha4, &ssl->fin_sha4, sizeof( sha4_context ) );
+    memcpy( &md5 , (md5_context *) ssl->ctx_checksum, sizeof(md5_context) );
+    memcpy( &sha1, (sha1_context *) ( ssl->ctx_checksum + sizeof(md5_context) ),
+            sizeof( sha1_context ) );
 
-    if( ssl->minor_ver == SSL_MINOR_VERSION_0 )
-    {
-        memset( pad_1, 0x36, 48 );
-        memset( pad_2, 0x5C, 48 );
+    memset( pad_1, 0x36, 48 );
+    memset( pad_2, 0x5C, 48 );
 
-        md5_update( &md5, ssl->session->master, 48 );
-        md5_update( &md5, pad_1, 48 );
-        md5_finish( &md5, hash );
+    md5_update( &md5, ssl->session->master, 48 );
+    md5_update( &md5, pad_1, 48 );
+    md5_finish( &md5, hash );
 
-        md5_starts( &md5 );
-        md5_update( &md5, ssl->session->master, 48 );
-        md5_update( &md5, pad_2, 48 );
-        md5_update( &md5, hash,  16 );
-        md5_finish( &md5, hash );
-        
-        sha1_update( &sha1, ssl->session->master, 48 );
-        sha1_update( &sha1, pad_1, 40 );
-        sha1_finish( &sha1, hash + 16 );
+    md5_starts( &md5 );
+    md5_update( &md5, ssl->session->master, 48 );
+    md5_update( &md5, pad_2, 48 );
+    md5_update( &md5, hash,  16 );
+    md5_finish( &md5, hash );
 
-        sha1_starts( &sha1 );
-        sha1_update( &sha1, ssl->session->master, 48 );
-        sha1_update( &sha1, pad_2, 40 );
-        sha1_update( &sha1, hash + 16, 20 );
-        sha1_finish( &sha1, hash + 16 );
-    }
-    else if( ssl->minor_ver != SSL_MINOR_VERSION_3 ) /* TLSv1 */
-    {
-         md5_finish( &md5,  hash );
-        sha1_finish( &sha1, hash + 16 );
-    }
-    else if( ssl->session->ciphersuite == SSL_RSA_AES_256_GCM_SHA384 ||
-             ssl->session->ciphersuite == SSL_EDH_RSA_AES_256_GCM_SHA384 )
-    {
-        sha4_finish( &sha4, hash );
-    }
-    else
-    {
-        sha2_finish( &sha2, hash );
-    }
+    sha1_update( &sha1, ssl->session->master, 48 );
+    sha1_update( &sha1, pad_1, 40 );
+    sha1_finish( &sha1, hash + 16 );
+    /* Outer pass: SHA-1 over master, 40 bytes of pad_2, then the inner digest. */
+    sha1_starts( &sha1 );
+    sha1_update( &sha1, ssl->session->master, 48 );
+    sha1_update( &sha1, pad_2, 40 );
+    sha1_update( &sha1, hash + 16, 20 );
+    sha1_finish( &sha1, hash + 16 );
+
+    SSL_DEBUG_BUF( 3, "calculated verify result", hash, 36 );
+    SSL_DEBUG_MSG( 2, ( "<= calc verify" ) );
+
+    return;
+}
+
+static void ssl_calc_verify_tls( ssl_context *ssl, unsigned char hash[36] )
+{
+    md5_context md5;
+    sha1_context sha1;
+    /* Verify hash for 0 < minor_ver < SSL_MINOR_VERSION_3: MD5 then SHA-1. */
+    SSL_DEBUG_MSG( 2, ( "=> calc verify tls" ) );
+    /* Work on copies so the contexts in ssl->ctx_checksum are not finalized. */
+    memcpy( &md5 , (md5_context *) ssl->ctx_checksum, sizeof(md5_context) );
+    memcpy( &sha1, (sha1_context *) ( ssl->ctx_checksum + sizeof(md5_context) ),
+            sizeof( sha1_context ) );
+    /* hash[0..15] receives the MD5 digest, hash[16..35] the SHA-1 digest. */
+    md5_finish( &md5,  hash );
+    sha1_finish( &sha1, hash + 16 );
+
+    SSL_DEBUG_BUF( 3, "calculated verify result", hash, 36 );
+    SSL_DEBUG_MSG( 2, ( "<= calc verify" ) );
+
+    return;
+}
+
+static void ssl_calc_verify_tls_sha256( ssl_context *ssl, unsigned char hash[32] )
+{
+    sha2_context sha2;
+
+    SSL_DEBUG_MSG( 2, ( "=> calc verify sha256" ) );
+    /* Finish a copy; the sha2_context in ssl->ctx_checksum is left as it was. */
+    memcpy( &sha2 , (sha2_context *) ssl->ctx_checksum, sizeof(sha2_context) );
+    sha2_finish( &sha2, hash );
+
+    SSL_DEBUG_BUF( 3, "calculated verify result", hash, 32 );
+    SSL_DEBUG_MSG( 2, ( "<= calc verify" ) );
+
+    return;
+}
+
+void ssl_calc_verify_tls_sha384( ssl_context *ssl, unsigned char hash[48] )
+{
+    sha4_context sha4;
+
+    SSL_DEBUG_MSG( 2, ( "=> calc verify sha384" ) );
+
+    memcpy( &sha4 , (sha4_context *) ssl->ctx_checksum, sizeof(sha4_context) );
+    sha4_finish( &sha4, hash );
 
     SSL_DEBUG_BUF( 3, "calculated verify result", hash, 48 );
     SSL_DEBUG_MSG( 2, ( "<= calc verify" ) );
@@ -1395,10 +1439,7 @@
         ssl->out_msg[2] = (unsigned char)( ( len - 4 ) >>  8 );
         ssl->out_msg[3] = (unsigned char)( ( len - 4 )       );
 
-         md5_update( &ssl->fin_md5 , ssl->out_msg, len );
-        sha1_update( &ssl->fin_sha1, ssl->out_msg, len );
-        sha2_update( &ssl->fin_sha2, ssl->out_msg, len );
-        sha4_update( &ssl->fin_sha4, ssl->out_msg, len );
+        ssl->update_checksum( ssl, ssl->out_msg, len );
     }
 
     if( ssl->do_crypt != 0 )
@@ -1471,10 +1512,7 @@
             return( POLARSSL_ERR_SSL_INVALID_RECORD );
         }
 
-         md5_update( &ssl->fin_md5 , ssl->in_msg, ssl->in_hslen );
-        sha1_update( &ssl->fin_sha1, ssl->in_msg, ssl->in_hslen );
-        sha2_update( &ssl->fin_sha2, ssl->in_msg, ssl->in_hslen );
-        sha4_update( &ssl->fin_sha4, ssl->in_msg, ssl->in_hslen );
+        ssl->update_checksum( ssl, ssl->in_msg, ssl->in_hslen );
 
         return( 0 );
     }
@@ -1618,10 +1656,7 @@
             return( POLARSSL_ERR_SSL_INVALID_RECORD );
         }
 
-         md5_update( &ssl->fin_md5 , ssl->in_msg, ssl->in_hslen );
-        sha1_update( &ssl->fin_sha1, ssl->in_msg, ssl->in_hslen );
-        sha2_update( &ssl->fin_sha2, ssl->in_msg, ssl->in_hslen );
-        sha4_update( &ssl->fin_sha4, ssl->in_msg, ssl->in_hslen );
+        ssl->update_checksum( ssl, ssl->in_msg, ssl->in_hslen );
     }
 
     if( ssl->in_msgtype == SSL_MSG_ALERT )
@@ -1990,6 +2025,62 @@
     return( 0 );
 }
 
+void ssl_kickstart_checksum( ssl_context *ssl, int ciphersuite,
+                             unsigned char *input_buf, size_t len )
+{   /* Start the checksum selected by minor_ver / ciphersuite. */
+    if( ssl->minor_ver < SSL_MINOR_VERSION_3 )
+    {
+        md5_starts( (md5_context *) ssl->ctx_checksum );
+        sha1_starts( (sha1_context *) ( ssl->ctx_checksum +
+                                        sizeof(md5_context) ) );
+        /* MD5 context sits at the start of ctx_checksum, SHA-1 right after. */
+        ssl->update_checksum = ssl_update_checksum_md5sha1;
+    }
+    else if ( ciphersuite == SSL_RSA_AES_256_GCM_SHA384 ||
+              ciphersuite == SSL_EDH_RSA_AES_256_GCM_SHA384 )
+    {
+        sha4_starts( (sha4_context *) ssl->ctx_checksum, 1 );
+        ssl->update_checksum = ssl_update_checksum_sha384;
+    }
+    else
+    {
+        sha2_starts( (sha2_context *) ssl->ctx_checksum, 0 );
+        ssl->update_checksum = ssl_update_checksum_sha256;
+    }
+    /* Client side hashes the current out_msg contents first, then input_buf. */
+    if( ssl->endpoint == SSL_IS_CLIENT )
+        ssl->update_checksum( ssl, ssl->out_msg, ssl->out_msglen );
+    ssl->update_checksum( ssl, input_buf, len );
+}
+    
+static void ssl_update_checksum_start( ssl_context *ssl, unsigned char *buf,
+                                       size_t len )
+{
+    ((void) ssl);   /* No-op placeholder: nothing is hashed; the casts only  */
+    ((void) buf);   /* mark the parameters as used. ssl_kickstart_checksum() */
+    ((void) len);   /* replaces it with a real update routine.               */
+}
+
+static void ssl_update_checksum_md5sha1( ssl_context *ssl, unsigned char *buf,
+                                         size_t len )
+{   /* Feed buf[0..len) to both the MD5 and the SHA-1 running context. */
+    md5_update( (md5_context *) ssl->ctx_checksum, buf, len );
+    sha1_update( (sha1_context *) ( ssl->ctx_checksum + sizeof(md5_context) ),
+                 buf, len );   /* SHA-1 context follows the MD5 context */
+}
+
+static void ssl_update_checksum_sha256( ssl_context *ssl, unsigned char *buf,
+                                        size_t len )
+{   /* Feed buf[0..len) to the sha2_context stored in ssl->ctx_checksum. */
+    sha2_update( (sha2_context *) ssl->ctx_checksum, buf, len );
+}
+
+static void ssl_update_checksum_sha384( ssl_context *ssl, unsigned char *buf,
+                                        size_t len )
+{   /* Feed buf[0..len) to the sha4_context stored in ssl->ctx_checksum. */
+    sha4_update( (sha4_context *) ssl->ctx_checksum, buf, len );
+}
+
 static void ssl_calc_finished_ssl(
                 ssl_context *ssl, unsigned char *buf, int from )
 {
@@ -2003,8 +2094,9 @@
 
     SSL_DEBUG_MSG( 2, ( "=> calc  finished ssl" ) );
 
-    memcpy( &md5 , &ssl->fin_md5 , sizeof(  md5_context ) );
-    memcpy( &sha1, &ssl->fin_sha1, sizeof( sha1_context ) );
+    memcpy( &md5 , (md5_context *) ssl->ctx_checksum, sizeof(md5_context) );
+    memcpy( &sha1, (sha1_context *) ( ssl->ctx_checksum + sizeof(md5_context) ),
+            sizeof( sha1_context ) );
 
     /*
      * SSLv3:
@@ -2073,8 +2165,9 @@
 
     SSL_DEBUG_MSG( 2, ( "=> calc  finished tls" ) );
 
-    memcpy( &md5 , &ssl->fin_md5 , sizeof(  md5_context ) );
-    memcpy( &sha1, &ssl->fin_sha1, sizeof( sha1_context ) );
+    memcpy( &md5 , (md5_context *) ssl->ctx_checksum, sizeof(md5_context) );
+    memcpy( &sha1, (sha1_context *) ( ssl->ctx_checksum + sizeof(md5_context) ),
+            sizeof( sha1_context ) );
 
     /*
      * TLSv1:
@@ -2116,9 +2209,9 @@
     sha2_context sha2;
     unsigned char padbuf[32];
 
-    SSL_DEBUG_MSG( 2, ( "=> calc  finished tls 1.2" ) );
+    SSL_DEBUG_MSG( 2, ( "=> calc  finished tls sha256" ) );
 
-    memcpy( &sha2, &ssl->fin_sha2, sizeof( sha2_context ) );
+    memcpy( &sha2 , (sha2_context *) ssl->ctx_checksum, sizeof(sha2_context) );
 
     /*
      * TLSv1.2:
@@ -2155,9 +2248,9 @@
     sha4_context sha4;
     unsigned char padbuf[48];
 
-    SSL_DEBUG_MSG( 2, ( "=> calc  finished tls 1.2" ) );
+    SSL_DEBUG_MSG( 2, ( "=> calc  finished tls sha384" ) );
 
-    memcpy( &sha4, &ssl->fin_sha4, sizeof( sha4_context ) );
+    memcpy( &sha4 , (sha4_context *) ssl->ctx_checksum, sizeof(sha4_context) );
 
     /*
      * TLSv1.2:
@@ -2320,10 +2413,7 @@
     ssl->hostname = NULL;
     ssl->hostname_len = 0;
 
-     md5_starts( &ssl->fin_md5  );
-    sha1_starts( &ssl->fin_sha1 );
-    sha2_starts( &ssl->fin_sha2, 0 );
-    sha4_starts( &ssl->fin_sha4, 1 );
+    ssl->update_checksum = ssl_update_checksum_start;
 
     return( 0 );
 }
@@ -2367,10 +2457,7 @@
     memset( ssl->ctx_enc, 0, 128 );
     memset( ssl->ctx_dec, 0, 128 );
 
-     md5_starts( &ssl->fin_md5  );
-    sha1_starts( &ssl->fin_sha1 );
-    sha2_starts( &ssl->fin_sha2, 0 );
-    sha4_starts( &ssl->fin_sha4, 1 );
+    ssl->update_checksum = ssl_update_checksum_start;
 }
 
 /*