Implement ccm_auth_decrypt()
diff --git a/library/ccm.c b/library/ccm.c
index 399b9fe..fafa238 100644
--- a/library/ccm.c
+++ b/library/ccm.c
@@ -146,9 +146,6 @@
if( add_len > 0xFF00 )
return( POLARSSL_ERR_CCM_BAD_INPUT );
- if( mode != CCM_ENCRYPT )
- return( POLARSSL_ERR_CCM_BAD_INPUT ); /* Not implemented yet */
-
/*
* First block B_0:
* 0 .. 0 flags
@@ -238,18 +235,34 @@
dst = output;
/*
- * Authenticate and encrypt message
+ * Authenticate and crypt message
+ * The only difference between encryption and decryption is
+ * the respective order of authentication and {en,de}cryption
*/
while( len_left > 16 )
{
- memcpy( b, src, 16 );
- UPDATE_CBC_MAC( b );
+ if( mode == CCM_ENCRYPT )
+ {
+ memcpy( b, src, 16 );
+ UPDATE_CBC_MAC( b );
+ }
+
CTR_CRYPT_BLOCK( dst, src );
+ if( mode == CCM_DECRYPT )
+ {
+ memcpy( b, dst, 16 );
+ UPDATE_CBC_MAC( b );
+ }
+
dst += 16;
src += 16;
len_left -= 16;
+ /*
+ * Increment counter.
+ * No need to check for overflow thanks to the length check above.
+ */
for( i = 0; i < q; i++ )
if( ++ctr[15-i] != 0 )
break;
@@ -260,10 +273,12 @@
unsigned char mask[16];
size_t olen;
- memset( b, 0, 16 );
- memcpy( b, src, len_left );
-
- UPDATE_CBC_MAC( b );
+ if( mode == CCM_ENCRYPT )
+ {
+ memset( b, 0, 16 );
+ memcpy( b, src, len_left );
+ UPDATE_CBC_MAC( b );
+ }
if( ( ret = cipher_update( &ctx->cipher_ctx, ctr, 16,
mask, &olen ) ) != 0 )
@@ -273,10 +288,17 @@
for( i = 0; i < len_left; i++ )
dst[i] = src[i] ^ mask[i];
+
+ if( mode == CCM_DECRYPT )
+ {
+ memset( b, 0, 16 );
+ memcpy( b, dst, len_left );
+ UPDATE_CBC_MAC( b );
+ }
}
/*
- * Authentication: reset counter and encrypt internal tag
+ * Authentication: reset counter and {en,de}crypt internal tag
*/
for( i = 0; i < q; i++ )
ctr[15-i] = 0;
@@ -410,10 +432,10 @@
if( verbose != 0 )
polarssl_printf( " CCM-AES #%u: ", (unsigned int) i + 1 );
- ret = ccm_encrypt_and_tag( &ctx, msg_len[i],
- iv, iv_len[i], ad, add_len[i],
- msg, out,
- out + msg_len[i], tag_len[i] );
+ ret = ccm_encrypt_and_tag( &ctx, msg_len[i],
+ iv, iv_len[i], ad, add_len[i],
+ msg, out,
+ out + msg_len[i], tag_len[i] );
if( ret != 0 ||
memcmp( out, res[i], msg_len[i] + tag_len[i] ) != 0 )
@@ -424,6 +446,20 @@
return( 1 );
}
+ ret = ccm_auth_decrypt( &ctx, msg_len[i],
+ iv, iv_len[i], ad, add_len[i],
+ res[i], out,
+ res[i] + msg_len[i], tag_len[i] );
+
+ if( ret != 0 ||
+ memcmp( out, msg, msg_len[i] ) != 0 )
+ {
+ if( verbose != 0 )
+ polarssl_printf( "failed\n" );
+
+ return( 1 );
+ }
+
if( verbose != 0 )
polarssl_printf( "passed\n" );
}