Merge the aesce_setkey_enc_128/192/256 functions into a single aesce_setkey_enc

Signed-off-by: Jerry Yu <jerry.h.yu@arm.com>
diff --git a/library/aesce.c b/library/aesce.c
index 29a4ce0..b4ebdad 100644
--- a/library/aesce.c
+++ b/library/aesce.c
@@ -160,9 +160,6 @@
 
 }
 
-static uint8_t const rcon[] = { 0x01, 0x02, 0x04, 0x08, 0x10,
-                                0x20, 0x40, 0x80, 0x1b, 0x36 };
-
 static inline uint32_t aes_rot_word(uint32_t word)
 {
     return (word << (32 - 8)) | (word >> 8);
@@ -180,75 +177,47 @@
 }
 
 /*
- * Key expansion, 128-bit case
+ * Key expansion for 128-, 192- and 256-bit keys
  */
-static void aesce_setkey_enc_128(unsigned char *rk,
-                                 const unsigned char *key)
+static void aesce_setkey_enc(unsigned char *rk,
+                             const unsigned char *key,
+                             const size_t key_bit_length)
 {
     uint32_t *rki;
     uint32_t *rko;
     uint32_t *rk_u32 = (uint32_t *) rk;
+    const uint32_t key_len_in_words = key_bit_length / 32;
+    const uint32_t key_len_in_bytes = key_bit_length / 8;
+    static uint8_t const rcon[] = { 0x01, 0x02, 0x04, 0x08, 0x10,
+                                    0x20, 0x40, 0x80, 0x1b, 0x36 };
+    const uint32_t rounds =
+        key_bit_length == 128 ? sizeof(rcon) : key_bit_length == 192 ? 8 : 7;
 
-    memcpy(rk, key, (128 / 8));
+    memcpy(rk, key, key_len_in_bytes);
 
-    for (size_t i = 0; i < sizeof(rcon); i++) {
-        rki = rk_u32 + i * (128 / 32);
-        rko = rki + (128 / 32);
-        rko[0] = aes_rot_word(aes_sub_word(rki[(128 / 32) - 1])) ^ rcon[i] ^ rki[0];
+    for (size_t i = 0; i < rounds; i++) {
+        rki = rk_u32 + i * key_len_in_words;
+        rko = rki + key_len_in_words;
+        rko[0] = aes_rot_word(aes_sub_word(rki[key_len_in_words - 1]));
+        rko[0] ^= rcon[i] ^ rki[0];
         rko[1] = rko[0] ^ rki[1];
         rko[2] = rko[1] ^ rki[2];
         rko[3] = rko[2] ^ rki[3];
-    }
-}
-
-/*
- * Key expansion, 192-bit case
- */
-static void aesce_setkey_enc_192(unsigned char *rk,
-                                 const unsigned char *key)
-{
-    uint32_t *rki;
-    uint32_t *rko;
-    uint32_t *rk_u32 = (uint32_t *) rk;
-    memcpy(rk, key, (192 / 8));
-
-    for (size_t i = 0; i < 8; i++) {
-        rki = rk_u32 + i * (192 / 32);
-        rko = rki + (192 / 32);
-        rko[0] = aes_rot_word(aes_sub_word(rki[(192 / 32) - 1])) ^ rcon[i] ^ rki[0];
-        rko[1] = rko[0] ^ rki[1];
-        rko[2] = rko[1] ^ rki[2];
-        rko[3] = rko[2] ^ rki[3];
-        if (i < 7) {
-            rko[4] = rko[3] ^ rki[4];
-            rko[5] = rko[4] ^ rki[5];
-        }
-    }
-}
-
-/*
- * Key expansion, 256-bit case
- */
-static void aesce_setkey_enc_256(unsigned char *rk,
-                                 const unsigned char *key)
-{
-    uint32_t *rki;
-    uint32_t *rko;
-    uint32_t *rk_u32 = (uint32_t *) rk;
-    memcpy(rk, key, (256 / 8));
-
-    for (size_t i = 0; i < 7; i++) {
-        rki = rk_u32 + i * (256 / 32);
-        rko = rki + (256 / 32);
-        rko[0] = aes_rot_word(aes_sub_word(rki[(256 / 32) - 1])) ^ rcon[i] ^ rki[0];
-        rko[1] = rko[0] ^ rki[1];
-        rko[2] = rko[1] ^ rki[2];
-        rko[3] = rko[2] ^ rki[3];
-        if (i < 6) {
-            rko[4] = aes_sub_word(rko[3]) ^ rki[4];
-            rko[5] = rko[4] ^ rki[5];
-            rko[6] = rko[5] ^ rki[6];
-            rko[7] = rko[6] ^ rki[7];
+        switch (key_bit_length) {
+            case 192:
+                if (i < 7) {
+                    rko[4] = rko[3] ^ rki[4];
+                    rko[5] = rko[4] ^ rki[5];
+                }
+                break;
+            case 256:
+                if (i < 6) {
+                    rko[4] = aes_sub_word(rko[3]) ^ rki[4];
+                    rko[5] = rko[4] ^ rki[5];
+                    rko[6] = rko[5] ^ rki[6];
+                    rko[7] = rko[6] ^ rki[7];
+                }
+                break;
         }
     }
 }
@@ -261,9 +230,10 @@
                              size_t bits)
 {
     switch (bits) {
-        case 128: aesce_setkey_enc_128(rk, key); break;
-        case 192: aesce_setkey_enc_192(rk, key); break;
-        case 256: aesce_setkey_enc_256(rk, key); break;
+        case 128:
+        case 192:
+        case 256:
+            aesce_setkey_enc(rk, key, bits); break;
         default: return MBEDTLS_ERR_AES_INVALID_KEY_LENGTH;
     }