blob: 4adbb07ab296c9facb9f115acc026f43c5de110c [file] [log] [blame]
"""Knowledge about the PSA key store as implemented in Mbed TLS.

Note that if you need to make a change that affects how keys are
stored, this may indicate that the key store is changing in a
backward-incompatible way! Think carefully about backward compatibility
before changing how test data is constructed or validated.
"""
8
# Copyright The Mbed TLS Contributors
# SPDX-License-Identifier: Apache-2.0 OR GPL-2.0-or-later
#
Gilles Peskinee0094482021-02-17 14:34:37 +010012
13import re
14import struct
Gilles Peskine23523962021-03-10 01:32:38 +010015from typing import Dict, List, Optional, Set, Union
Gilles Peskinee0094482021-02-17 14:34:37 +010016import unittest
17
Gilles Peskine239765a2022-09-16 22:35:18 +020018from . import c_build_helper
Gilles Peskine23523962021-03-10 01:32:38 +010019
Gilles Peskinee0094482021-02-17 14:34:37 +010020
class Expr:
    """Representation of a C expression with a known or knowable numerical value.

    An `Expr` is built either from a Python integer, whose value is known
    immediately, or from a string containing a C expression. Every string
    passed to the constructor is registered (in normalized form) in the
    class-wide `unknown_values` set. The first call to `value()` on an
    expression that is not in `value_cache` evaluates all registered
    expressions in one batch (see `update_cache`).
    """

    def __init__(self, content: Union[int, str]):
        if isinstance(content, int):
            # Render as 0x-prefixed hexadecimal, zero-padded to 4 digits,
            # or to 8 digits when the value does not fit in 16 bits.
            digits = 8 if content > 0xffff else 4
            self.string = '{0:#0{1}x}'.format(content, digits + 2)
            self.value_if_known = content #type: Optional[int]
        else:
            self.string = content
            # Register the expression so that the next cache update
            # evaluates it together with all other pending expressions.
            self.unknown_values.add(self.normalize(content))
            self.value_if_known = None

    # The two attributes below are class attributes, shared by all instances.
    value_cache = {} #type: Dict[str, int]
    """Cache of known values of expressions."""

    unknown_values = set() #type: Set[str]
    """Expressions whose values are not present in `value_cache` yet."""

    def update_cache(self) -> None:
        """Update `value_cache` for expressions registered in `unknown_values`.

        All pending expressions are passed (sorted) to
        `c_build_helper.get_c_expression_values`, to be evaluated as
        `unsigned long` in a C context that includes `<psa/crypto.h>` with
        `include` on the include path. Each returned string is parsed with
        `int(v, 0)`. Afterwards `unknown_values` is empty.
        """
        expressions = sorted(self.unknown_values)
        values = c_build_helper.get_c_expression_values(
            'unsigned long', '%lu',
            expressions,
            header="""
            #include <psa/crypto.h>
            """,
            include_path=['include']) #type: List[str]
        for e, v in zip(expressions, values):
            self.value_cache[e] = int(v, 0)
        self.unknown_values.clear()

    @staticmethod
    def normalize(string: str) -> str:
        """Put the given C expression in a canonical form.

        This function is only intended to give correct results for the
        relatively simple kind of C expression typically used with this
        module. It currently just removes all whitespace.
        """
        return re.sub(r'\s+', r'', string)

    def value(self) -> int:
        """Return the numerical value of the expression.

        Integers and plain unsuffixed C integer literals are evaluated
        directly in Python. This covers decimal, octal with a leading 0, and
        hexadecimal with a 0x/0X prefix. Any other expression is looked up in
        `value_cache`, which is updated first if needed. The result is
        remembered in `value_if_known`.
        """
        if self.value_if_known is None:
            if re.match(r'([1-9][0-9]*|0|0x[0-9a-f]+)\Z', self.string, re.I):
                self.value_if_known = int(self.string, 0)
            elif re.match(r'0[0-7]+\Z', self.string):
                # In C a leading 0 introduces an octal literal, but Python's
                # int(s, 0) rejects such strings (e.g. '010') with ValueError.
                self.value_if_known = int(self.string, 8)
            else:
                normalized = self.normalize(self.string)
                if normalized not in self.value_cache:
                    self.update_cache()
                self.value_if_known = self.value_cache[normalized]
        return self.value_if_known
74
# Accepted inputs for `as_expr`: an int, a string holding a C expression,
# or an already-built `Expr`.
Exprable = Union[str, int, Expr]
"""Something that can be converted to a C expression with a known numerical value."""
77
def as_expr(thing: Exprable) -> Expr:
    """Coerce `thing` into an `Expr` object.

    An existing `Expr` is returned as-is (the same object, not a copy).
    An integer or a string containing a C expression is wrapped in a
    freshly constructed `Expr`.
    """
    return thing if isinstance(thing, Expr) else Expr(thing)
89
90
class Key:
    """A PSA crypto key object together with its encoding in the key store.

    The lifetime, type, usage and algorithm attributes may be given as
    integers, C expression strings or `Expr` objects; they are all stored as
    `Expr` objects and only converted to numbers when the encoding is built.
    """

    LATEST_VERSION = 0
    """The latest version of the storage format."""

    def __init__(self, *,
                 version: Optional[int] = None,
                 id: Optional[int] = None, #pylint: disable=redefined-builtin
                 lifetime: Exprable = 'PSA_KEY_LIFETIME_PERSISTENT',
                 type: Exprable, #pylint: disable=redefined-builtin
                 bits: int,
                 usage: Exprable, alg: Exprable, alg2: Exprable,
                 material: bytes #pylint: disable=used-before-assignment
                ) -> None:
        if version is None:
            version = self.LATEST_VERSION
        self.version = version
        # `id` is kept on the object but is not part of the encoding below.
        self.id = id #pylint: disable=invalid-name #type: Optional[int]
        self.lifetime = as_expr(lifetime) #type: Expr
        self.type = as_expr(type) #type: Expr
        self.bits = bits #type: int
        self.usage = as_expr(usage) #type: Expr
        self.alg = as_expr(alg) #type: Expr
        self.alg2 = as_expr(alg2) #type: Expr
        self.material = material #type: bytes

    MAGIC = b'PSA\000KEY\000'

    @staticmethod
    def pack(
            fmt: str,
            *args: Union[int, Expr]
    ) -> bytes: #pylint: disable=used-before-assignment
        """Encode the arguments as a byte string following `fmt`.

        This behaves like `struct.pack`, except that:
        * Integers always use standard sizes and little-endian byte order,
          so `fmt` must not start with a byte-order prefix.
        * `Expr` objects may be passed wherever an integer is expected.
        * Only integer-valued format elements are supported.
        """
        # '<' selects little-endian byte order with standard sizes.
        full_fmt = '<' + fmt
        numbers = tuple(arg.value() if isinstance(arg, Expr) else arg
                        for arg in args)
        return struct.pack(full_fmt, *numbers)

    def bytes(self) -> bytes:
        """Return the storage representation of the key as a byte string.

        This is the content of the PSA storage file. When PSA storage is
        implemented over stdio files, any wrapping added by the
        PSA-storage-over-stdio-file implementation is not included.

        Note that if you need to make a change in this function,
        this may indicate that the key store is changing in a
        backward-incompatible way! Think carefully about backward
        compatibility before making any change here.
        """
        header = self.MAGIC + self.pack('L', self.version)
        if self.version != 0:
            raise NotImplementedError
        # Version 0 layout after the header: lifetime (32 bits), type
        # (16 bits), bits (16 bits), usage, alg and alg2 (32 bits each),
        # then the 32-bit material length followed by the material.
        attributes = self.pack('LHHLLL',
                               self.lifetime, self.type, self.bits,
                               self.usage, self.alg, self.alg2)
        material = self.pack('L', len(self.material)) + self.material
        return b''.join((header, attributes, material))

    def hex(self) -> str:
        """Return the storage representation of the key as lowercase hex.

        This is `self.bytes()` rendered as a hexadecimal string.
        """
        encoded = self.bytes()
        return encoded.hex()

    def location_value(self) -> int:
        """Return the location encoded in the key's lifetime.

        This is the numerical lifetime with its low 8 bits shifted out.
        """
        lifetime = self.lifetime.value()
        return lifetime >> 8
169
Gilles Peskinee0094482021-02-17 14:34:37 +0100170
class TestKey(unittest.TestCase):
    # pylint: disable=line-too-long
    """A few smoke tests for the functionality of the `Key` class."""

    def test_numerical(self):
        key = Key(version=0,
                  id=1, lifetime=0x00000001,
                  type=0x2400, bits=128,
                  usage=0x00000300, alg=0x05500200, alg2=0x04c01000,
                  material=b'@ABCDEFGHIJKLMNO')
        expected_hex = '505341004b45590000000000010000000024800000030000000250050010c00410000000404142434445464748494a4b4c4d4e4f'
        self.assertEqual(key.bytes(), bytes.fromhex(expected_hex))
        self.assertEqual(key.hex(), expected_hex)

    def test_names(self):
        # Symbolic names are evaluated through the C compiler, so this test
        # needs the PSA headers to be available.
        length = 0xfff8 // 8 # PSA_MAX_KEY_BITS in bytes
        key = Key(version=0,
                  id=1, lifetime='PSA_KEY_LIFETIME_PERSISTENT',
                  type='PSA_KEY_TYPE_RAW_DATA', bits=length*8,
                  usage=0, alg=0, alg2=0,
                  material=b'\x00' * length)
        expected_hex = '505341004b45590000000000010000000110f8ff000000000000000000000000ff1f0000' + '00' * length
        self.assertEqual(key.bytes(), bytes.fromhex(expected_hex))
        self.assertEqual(key.hex(), expected_hex)

    def test_defaults(self):
        key = Key(type=0x1001, bits=8,
                  usage=0, alg=0, alg2=0,
                  material=b'\x2a')
        expected_hex = '505341004b455900000000000100000001100800000000000000000000000000010000002a'
        self.assertEqual(key.bytes(), bytes.fromhex(expected_hex))
        self.assertEqual(key.hex(), expected_hex)

    def test_location_value(self):
        # The location is the lifetime with its low 8 bits shifted out.
        key = Key(lifetime=0x00abcd01,
                  type=0x1001, bits=8,
                  usage=0, alg=0, alg2=0,
                  material=b'\x2a')
        self.assertEqual(key.location_value(), 0xabcd)