Enable multiple calls to mbedtls_gcm_update_ad.

Signed-off-by: Mateusz Starzyk <mateusz.starzyk@mobica.com>
diff --git a/docs/3.0-migration-guide.d/gcm-multipart.md b/docs/3.0-migration-guide.d/gcm-multipart.md
index 98e9fad..ebc6397 100644
--- a/docs/3.0-migration-guide.d/gcm-multipart.md
+++ b/docs/3.0-migration-guide.d/gcm-multipart.md
@@ -6,7 +6,6 @@
 Applications using one-shot GCM or using GCM via the `mbedtls_cipher_xxx` or `psa_aead_xxx` interfaces do not require any changes.
 
 * `mbedtls_gcm_starts()` now only sets the mode and the nonce (IV). Call the new function `mbedtls_gcm_update_ad()` to pass the associated data.
-* The current implementation has a limitation that `mbedtls_gcm_update_ad()` may only be called once. This limitation will be lifted shortly; watch https://github.com/ARMmbed/mbedtls/issues/4351 for updates.
 * `mbedtls_gcm_update()` now takes an extra parameter to indicate the actual output length. In Mbed TLS 2.x, applications had to pass inputs consisting of whole 16-byte blocks except for the last block (this limitation has been lifted). In this case:
     * As long as the input remains block-aligned, the output length is exactly the input length, as before.
     * If the length of the last input is not a multiple of 16, alternative implementations may return the last partial block in the call to `mbedtls_gcm_finish()` instead of returning it in the last call to `mbedtls_gcm_update()`.
diff --git a/include/mbedtls/gcm.h b/include/mbedtls/gcm.h
index c8e384a..f3c3035 100644
--- a/include/mbedtls/gcm.h
+++ b/include/mbedtls/gcm.h
@@ -246,11 +246,6 @@
  *                  you do not need to call this function. You may not
  *                  call this function after calling mbedtls_cipher_update().
  *
- * \note            This function may only be called once per operation:
- *                  you must pass the whole associated data in a single
- *                  call. This limitation will be lifted in a future version
- *                  of Mbed TLS.
- *
  * \param ctx       The GCM context. This must have been started with
  *                  mbedtls_gcm_starts() and must not have yet received
  *                  any input with mbedtls_gcm_update().
diff --git a/library/gcm.c b/library/gcm.c
index 2bd9071..23b6ebb 100644
--- a/library/gcm.c
+++ b/library/gcm.c
@@ -337,7 +337,7 @@
                            const unsigned char *add, size_t add_len )
 {
     const unsigned char *p;
-    size_t use_len, i;
+    size_t use_len, i, offset;
 
     GCM_VALIDATE_RET( add_len == 0 || add != NULL );
 
@@ -345,15 +345,31 @@
     if( (uint64_t) add_len >> 61 != 0 )
         return( MBEDTLS_ERR_GCM_BAD_INPUT );
 
-    /* Calling update_ad multiple times is not yet supported */
-    if( ctx->add_len != 0 )
-        return( MBEDTLS_ERR_GCM_BAD_INPUT );
-
-    ctx->add_len = add_len;
+    offset = ctx->add_len % 16;
     p = add;
-    while( add_len > 0 )
+
+    if (offset)
     {
-        use_len = ( add_len < 16 ) ? add_len : 16;
+        use_len = 16 - offset;
+        if( use_len > add_len )
+            use_len = add_len;
+
+        for (i = 0; i < use_len; i++)
+            ctx->buf[i+offset] ^= p[i];
+
+        if( offset + use_len == 16 )
+            gcm_mult( ctx, ctx->buf, ctx->buf );
+
+        ctx->add_len += use_len;
+        add_len -= use_len;
+        p += use_len;
+    }
+
+    ctx->add_len += add_len;
+
+    while( add_len >= 16 )
+    {
+        use_len = 16;
 
         for( i = 0; i < use_len; i++ )
             ctx->buf[i] ^= p[i];
@@ -364,6 +380,12 @@
         p += use_len;
     }
 
+    if ( add_len > 0 )
+    {
+        for( i = 0; i < add_len; i++ )
+            ctx->buf[i] ^= p[i];
+    }
+
     return( 0 );
 }
 
@@ -442,6 +464,11 @@
         return( MBEDTLS_ERR_GCM_BAD_INPUT );
     }
 
+    if ( ( ctx->len == 0 ) && ( ctx->add_len % 16 ) )
+    {
+        gcm_mult( ctx, ctx->buf, ctx->buf );
+    }
+
     offset = ctx->len % 16;
     if( offset != 0 )
     {
@@ -507,6 +534,11 @@
     orig_len = ctx->len * 8;
     orig_add_len = ctx->add_len * 8;
 
+    if ( ( ctx->len == 0 ) && ( ctx->add_len % 16 ) )
+    {
+        gcm_mult( ctx, ctx->buf, ctx->buf );
+    }
+
     if( tag_len > 16 || tag_len < 4 )
         return( MBEDTLS_ERR_GCM_BAD_INPUT );