scripts: imgtool: compression

Add LZMA2 compression support to imgtool.
Python's lzma library cannot compress with the required parameters when
using the "ALONE" container format, so imgtool compresses to a raw stream,
calculates the 2 header bytes itself and prepends them to the payload.

Signed-off-by: Mateusz Michalek <mateusz.michalek@nordicsemi.no>
diff --git a/scripts/imgtool/image.py b/scripts/imgtool/image.py
index 1f9149b..690f0b7 100644
--- a/scripts/imgtool/image.py
+++ b/scripts/imgtool/image.py
@@ -20,7 +20,15 @@
 Image signing and management.
 """
 
+from . import version as versmod
+from .boot_record import create_sw_component_data
+import click
+import copy
+from enum import Enum
+import array
+from intelhex import IntelHex
 import hashlib
+import array
 import os.path
 import struct
 from enum import Enum
@@ -60,6 +68,8 @@
         'NON_BOOTABLE':          0x0000010,
         'RAM_LOAD':              0x0000020,
         'ROM_FIXED':             0x0000100,
+        'COMPRESSED_LZMA1':      0x0000200,
+        'COMPRESSED_LZMA2':      0x0000400,
 }
 
 TLV_VALUES = {
@@ -80,6 +90,9 @@
         'DEPENDENCY': 0x40,
         'SEC_CNT': 0x50,
         'BOOT_RECORD': 0x60,
+        'DECOMP_SIZE': 0x70,
+        'DECOMP_SHA': 0x71,
+        'DECOMP_SIGNATURE': 0x72,
 }
 
 TLV_SIZE = 4
@@ -238,6 +251,9 @@
         if load_addr and rom_fixed:
             raise click.UsageError("Can not set rom_fixed and load_addr at the same time")
 
+        self.image_hash = None
+        self.image_size = None
+        self.signature = None
         self.version = version or versmod.decode_version("0")
         self.header_size = header_size
         self.pad_header = pad_header
@@ -253,6 +269,7 @@
         self.rom_fixed = rom_fixed
         self.erased_val = 0xff if erased_val is None else int(erased_val, 0)
         self.payload = []
+        self.infile_data = []
         self.enckey = None
         self.save_enctlv = save_enctlv
         self.enctlv_len = 0
@@ -307,13 +324,31 @@
         try:
             if ext == INTEL_HEX_EXT:
                 ih = IntelHex(path)
-                self.payload = ih.tobinarray()
+                self.infile_data = ih.tobinarray()
+                self.payload = copy.copy(self.infile_data)
                 self.base_addr = ih.minaddr()
             else:
                 with open(path, 'rb') as f:
-                    self.payload = f.read()
+                    self.infile_data = f.read()
+                    self.payload = copy.copy(self.infile_data)
         except FileNotFoundError:
             raise click.UsageError("Input file not found")
+        self.image_size = len(self.payload)
+
+        # Add the image header if needed.
+        if self.pad_header and self.header_size > 0:
+            if self.base_addr:
+                # Adjust base_addr for new header
+                self.base_addr -= self.header_size
+            self.payload = bytes([self.erased_val] * self.header_size) + \
+                self.payload
+
+        self.check_header()
+
+    def load_compressed(self, data, compression_header):
+        """Load an image from buffer"""
+        self.payload = compression_header + data
+        self.image_size = len(self.payload)
 
         # Add the image header if needed.
         if self.pad_header and self.header_size > 0:
@@ -408,7 +443,8 @@
         return cipherkey, ciphermac, pubk
 
     def create(self, key, public_key_format, enckey, dependencies=None,
-               sw_type=None, custom_tlvs=None, encrypt_keylen=128, clear=False,
+               sw_type=None, custom_tlvs=None, compression_tlvs=None,
+               compression_type=None, encrypt_keylen=128, clear=False,
                fixed_sig=None, pub_key=None, vector_to_sign=None, user_sha='auto'):
         self.enckey = enckey
 
@@ -471,6 +507,9 @@
             dependencies_num = len(dependencies[DEP_IMAGES_KEY])
             protected_tlv_size += (dependencies_num * 16)
 
+        if compression_tlvs is not None:
+            for value in compression_tlvs.values():
+                protected_tlv_size += TLV_SIZE + len(value)
         if custom_tlvs is not None:
             for value in custom_tlvs.values():
                 protected_tlv_size += TLV_SIZE + len(value)
@@ -492,11 +531,15 @@
                 else:
                     self.payload.extend(pad)
 
+        compression_flags = 0x0
+        if compression_tlvs is not None:
+            if compression_type == "lzma2":
+                compression_flags = IMAGE_F['COMPRESSED_LZMA2']
         # This adds the header to the payload as well
         if encrypt_keylen == 256:
-            self.add_header(enckey, protected_tlv_size, 256)
+            self.add_header(enckey, protected_tlv_size, compression_flags, 256)
         else:
-            self.add_header(enckey, protected_tlv_size)
+            self.add_header(enckey, protected_tlv_size, compression_flags)
 
         prot_tlv = TLV(self.endian, TLV_PROT_INFO_MAGIC)
 
@@ -526,6 +569,9 @@
                     )
                     prot_tlv.add('DEPENDENCY', payload)
 
+            if compression_tlvs is not None:
+                for tag, value in compression_tlvs.items():
+                    prot_tlv.add(tag, value)
             if custom_tlvs is not None:
                 for tag, value in custom_tlvs.items():
                     prot_tlv.add(tag, value)
@@ -544,6 +590,7 @@
         digest = sha.digest()
         message = digest;
         tlv.add(hash_tlv, digest)
+        self.image_hash = digest
 
         if vector_to_sign == 'payload':
             # Stop amending data to the image
@@ -623,10 +670,16 @@
 
         self.check_trailer()
 
+    def get_struct_endian(self):
+        return STRUCT_ENDIAN_DICT[self.endian]
+
     def get_signature(self):
         return self.signature
 
-    def add_header(self, enckey, protected_tlv_size, aes_length=128):
+    def get_infile_data(self):
+        return self.infile_data
+
+    def add_header(self, enckey, protected_tlv_size, compression_flags, aes_length=128):
         """Install the image header."""
 
         flags = 0
@@ -664,7 +717,7 @@
                              protected_tlv_size,  # TLV Info header +
                                                   # Protected TLVs
                              len(self.payload) - self.header_size,  # ImageSz
-                             flags,
+                             flags | compression_flags,
                              self.version.major,
                              self.version.minor or 0,
                              self.version.revision or 0,
diff --git a/scripts/imgtool/main.py b/scripts/imgtool/main.py
index 848fd31..03bb565 100755
--- a/scripts/imgtool/main.py
+++ b/scripts/imgtool/main.py
@@ -22,6 +22,10 @@
 import getpass
 import imgtool.keys as keys
 import sys
+import struct
+import os
+import lzma
+import hashlib
 import base64
 from imgtool import image, imgtool_version
 from imgtool.version import decode_version
@@ -29,6 +33,13 @@
 from .keys import (
     RSAUsageError, ECDSAUsageError, Ed25519UsageError, X25519UsageError)
 
+comp_default_dictsize=131072
+comp_default_pb=2
+comp_default_lc=3
+comp_default_lp=1
+comp_default_preset=9
+
+
 MIN_PYTHON_VERSION = (3, 6)
 if sys.version_info < MIN_PYTHON_VERSION:
     sys.exit("Python %s.%s or newer is required by imgtool."
@@ -300,6 +311,14 @@
         dependencies[image.DEP_VERSIONS_KEY] = versions
         return dependencies
 
+def create_lzma2_header(dictsize, pb, lc, lp):
+    # Byte 0: smallest dict-size code i (0..39) whose size covers dictsize;
+    # byte 1: packed (pb * 5 + lp) * 9 + lc. Fail loudly if no code fits.
+    for i in range(40):
+        if dictsize <= ((2 | (i & 1)) << (i // 2 + 11)):
+            return bytearray([i, (pb * 5 + lp) * 9 + lc])
+    raise ValueError(f"unsupported LZMA2 dictionary size: {dictsize}")
+
 
 class BasedIntParamType(click.ParamType):
     name = 'integer'
@@ -343,6 +362,11 @@
               type=click.Choice(['128', '256']),
               help='When encrypting the image using AES, select a 128 bit or '
                    '256 bit key len.')
+@click.option('--compression', default='disabled',
+              type=click.Choice(['disabled', 'lzma2']),
+              help='Enable image compression using specified type. '
+                   'Will fall back without image compression automatically '
+                   'if the compression increases the image size.')
 @click.option('-c', '--clear', required=False, is_flag=True, default=False,
               help='Output a non-encrypted image with encryption capabilities,'
                    'so it can be installed in the primary slot, and encrypted '
@@ -414,10 +438,11 @@
                .hex extension, otherwise binary format is used''')
 def sign(key, public_key_format, align, version, pad_sig, header_size,
          pad_header, slot_size, pad, confirm, max_sectors, overwrite_only,
-         endian, encrypt_keylen, encrypt, infile, outfile, dependencies,
-         load_addr, hex_addr, erased_val, save_enctlv, security_counter,
-         boot_record, custom_tlv, rom_fixed, max_align, clear, fix_sig,
-         fix_sig_pubkey, sig_out, user_sha, vector_to_sign, non_bootable):
+         endian, encrypt_keylen, encrypt, compression, infile, outfile,
+         dependencies, load_addr, hex_addr, erased_val, save_enctlv,
+         security_counter, boot_record, custom_tlv, rom_fixed, max_align,
+         clear, fix_sig, fix_sig_pubkey, sig_out, user_sha, vector_to_sign,
+         non_bootable):
 
     if confirm:
         # Confirmed but non-padded images don't make much sense, because
@@ -431,6 +456,7 @@
                       erased_val=erased_val, save_enctlv=save_enctlv,
                       security_counter=security_counter, max_align=max_align,
                       non_bootable=non_bootable)
+    compression_tlvs = {}
     img.load(infile)
     key = load_key(key) if key else None
     enckey = load_key(encrypt) if encrypt else None
@@ -484,10 +510,49 @@
         }
 
     img.create(key, public_key_format, enckey, dependencies, boot_record,
-               custom_tlvs, int(encrypt_keylen), clear, baked_signature,
-               pub_key, vector_to_sign, user_sha)
-    img.save(outfile, hex_addr)
+               custom_tlvs, compression_tlvs, None, int(encrypt_keylen),
+               clear, baked_signature, pub_key, vector_to_sign, user_sha)
 
+    if compression == "lzma2":
+        compressed_img = image.Image(version=decode_version(version),
+                  header_size=header_size, pad_header=pad_header,
+                  pad=pad, confirm=confirm, align=int(align),
+                  slot_size=slot_size, max_sectors=max_sectors,
+                  overwrite_only=overwrite_only, endian=endian,
+                  load_addr=load_addr, rom_fixed=rom_fixed,
+                  erased_val=erased_val, save_enctlv=save_enctlv,
+                  security_counter=security_counter, max_align=max_align)
+        compression_filters = [
+            {"id": lzma.FILTER_LZMA2, "preset": comp_default_preset,
+                "dict_size": comp_default_dictsize, "pb": comp_default_pb,
+                "lp": comp_default_lp, "lc": comp_default_lc}
+        ]
+        compressed_data = lzma.compress(img.get_infile_data(),filters=compression_filters,
+            format=lzma.FORMAT_RAW)
+        uncompressed_size = len(img.get_infile_data())
+        compressed_size = len(compressed_data)
+        print(f"compressed image size: {compressed_size} bytes")
+        print(f"original image size: {uncompressed_size} bytes")
+        compression_tlvs["DECOMP_SIZE"] = struct.pack(
+            img.get_struct_endian() + 'L', img.image_size)
+        compression_tlvs["DECOMP_SHA"] = img.image_hash
+        compression_tlvs_size = len(compression_tlvs["DECOMP_SIZE"])
+        compression_tlvs_size += len(compression_tlvs["DECOMP_SHA"])
+        if img.get_signature():
+            compression_tlvs["DECOMP_SIGNATURE"] = img.get_signature()
+            compression_tlvs_size += len(compression_tlvs["DECOMP_SIGNATURE"])
+        if (compressed_size + compression_tlvs_size) < uncompressed_size:
+            compression_header = create_lzma2_header(
+                dictsize = comp_default_dictsize, pb = comp_default_pb,
+                lc = comp_default_lc, lp = comp_default_lp)
+            compressed_img.load_compressed(compressed_data, compression_header)
+            compressed_img.base_addr = img.base_addr
+            compressed_img.create(key, public_key_format, enckey,
+               dependencies, boot_record, custom_tlvs, compression_tlvs,
+               compression, int(encrypt_keylen), clear, baked_signature,
+               pub_key, vector_to_sign, user_sha)
+            img = compressed_img
+    img.save(outfile, hex_addr)
     if sig_out is not None:
         new_signature = img.get_signature()
         save_signature(sig_out, new_signature)