blob: bae99383dc53ef65652114b70fb60bdbf408c5d0 [file] [log] [blame]
"""Knowledge about the PSA key store as implemented in Mbed TLS.

Note that if you need to make a change that affects how keys are
stored, this may indicate that the key store is changing in a
backward-incompatible way! Think carefully about backward compatibility
before changing how test data is constructed or validated.
"""
8
9# Copyright The Mbed TLS Contributors
10# SPDX-License-Identifier: Apache-2.0
11#
12# Licensed under the Apache License, Version 2.0 (the "License"); you may
13# not use this file except in compliance with the License.
14# You may obtain a copy of the License at
15#
16# http://www.apache.org/licenses/LICENSE-2.0
17#
18# Unless required by applicable law or agreed to in writing, software
19# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
20# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
21# See the License for the specific language governing permissions and
22# limitations under the License.
23
24import re
25import struct
Gilles Peskine23523962021-03-10 01:32:38 +010026from typing import Dict, List, Optional, Set, Union
Gilles Peskinee0094482021-02-17 14:34:37 +010027import unittest
28
Gilles Peskine15997bd2022-09-16 22:35:18 +020029from . import c_build_helper
Gilles Peskine23523962021-03-10 01:32:38 +010030
Gilles Peskinee0094482021-02-17 14:34:37 +010031
class Expr:
    """Representation of a C expression with a known or knowable numerical value.

    An `Expr` is built either from a Python integer, whose value is known
    immediately, or from a string containing a C expression. String
    expressions are evaluated lazily:
    * each new string expression is registered in the class-wide
      `unknown_values` set;
    * the first `value()` call that needs an uncached expression evaluates
      all pending expressions in one batch via `c_build_helper`.
    """

    def __init__(self, content: Union[int, str]):
        if isinstance(content, int):
            # Render with at least 4 hex digits, or 8 when the value does not
            # fit in 16 bits (the '+ 2' accounts for the '0x' prefix).
            digits = 8 if content > 0xffff else 4
            self.string = '{0:#0{1}x}'.format(content, digits + 2)
            self.value_if_known = content #type: Optional[int]
        else:
            self.string = content
            # Queue the expression so that the next cache update evaluates it.
            self.unknown_values.add(self.normalize(content))
            self.value_if_known = None

    value_cache = {} #type: Dict[str, int]
    """Cache of known values of expressions."""

    unknown_values = set() #type: Set[str]
    """Expressions whose values are not present in `value_cache` yet."""

    def update_cache(self) -> None:
        """Update `value_cache` for expressions registered in `unknown_values`.

        All pending expressions are evaluated in a single batch by
        `c_build_helper.get_c_expression_values`:
        * each expression is printed as an `unsigned long` with `%lu`;
        * `<psa/crypto.h>` is included;
        * the include path is 'include' (presumably relative to the working
          directory, i.e. the Mbed TLS source root; confirm with callers).

        Afterwards `unknown_values` is emptied.
        """
        expressions = sorted(self.unknown_values)
        values = c_build_helper.get_c_expression_values(
            'unsigned long', '%lu',
            expressions,
            header="""
            #include <psa/crypto.h>
            """,
            include_path=['include']) #type: List[str]
        for e, v in zip(expressions, values):
            self.value_cache[e] = int(v, 0)
        self.unknown_values.clear()

    @staticmethod
    def normalize(string: str) -> str:
        """Put the given C expression in a canonical form.

        This function is only intended to give correct results for the
        relatively simple kind of C expression typically used with this
        module.
        """
        return re.sub(r'\s+', r'', string)

    @staticmethod
    def _literal_value(string: str) -> Optional[int]:
        """Return the value of a plain decimal or hexadecimal literal, else None.

        Only literals that Python's `int(..., 0)` parses with the same meaning
        as C are handled here. A decimal-looking literal with a leading zero
        such as '010' is octal in C but rejected by `int(..., 0)`, so it
        yields None and is left to the C compiler.
        """
        if not re.match(r'([0-9]+|0x[0-9a-f]+)\Z', string, re.I):
            return None
        try:
            return int(string, 0)
        except ValueError:
            return None

    def value(self) -> int:
        """Return the numerical value of the expression.

        Simple decimal or hexadecimal literals are parsed directly in Python.
        Anything else is evaluated through the C compiler (see
        `update_cache`). The result is memoized in `value_if_known`.
        """
        if self.value_if_known is None:
            self.value_if_known = self._literal_value(self.string)
        if self.value_if_known is None:
            normalized = self.normalize(self.string)
            if normalized not in self.value_cache:
                self.update_cache()
            self.value_if_known = self.value_cache[normalized]
        return self.value_if_known
85
# Type alias for the values accepted by `as_expr`: a Python integer, a
# string holding a C expression, or an already-built `Expr` object.
Exprable = Union[str, int, Expr]
"""Something that can be converted to a C expression with a known numerical value."""
88
def as_expr(thing: Exprable) -> Expr:
    """Coerce `thing` into an `Expr` object.

    An `Expr` passed in is returned as-is (the same object, not a copy).
    An integer, or a string containing a C expression, is wrapped in a
    freshly constructed `Expr`.
    """
    return thing if isinstance(thing, Expr) else Expr(thing)
100
101
class Key:
    """Representation of a PSA crypto key object and its storage encoding.

    Integer-like attributes (lifetime, type, usage, algorithms) are kept as
    `Expr` objects so they may be given either numerically or as C
    expressions; `bytes()` produces the stored representation.
    """

    LATEST_VERSION = 0
    """The latest version of the storage format."""

    # NOTE: `__init__` and `pack` must stay above the `bytes` method below:
    # their `bytes` annotations are evaluated when each `def` runs, so they
    # refer to the builtin type only while the method is not yet defined.

    def __init__(self, *,
                 version: Optional[int] = None,
                 id: Optional[int] = None, #pylint: disable=redefined-builtin
                 lifetime: Exprable = 'PSA_KEY_LIFETIME_PERSISTENT',
                 type: Exprable, #pylint: disable=redefined-builtin
                 bits: int,
                 usage: Exprable, alg: Exprable, alg2: Exprable,
                 material: bytes #pylint: disable=used-before-assignment
                 ) -> None:
        if version is None:
            version = self.LATEST_VERSION
        self.version = version
        self.id = id #pylint: disable=invalid-name #type: Optional[int]
        self.lifetime = as_expr(lifetime) #type: Expr
        self.type = as_expr(type) #type: Expr
        self.bits = bits #type: int
        self.usage = as_expr(usage) #type: Expr
        self.alg = as_expr(alg) #type: Expr
        self.alg2 = as_expr(alg2) #type: Expr
        self.material = material #type: bytes

    MAGIC = b'PSA\000KEY\000'

    @staticmethod
    def pack(
            fmt: str,
            *args: Union[int, Expr]
    ) -> bytes: #pylint: disable=used-before-assignment
        """Serialize integers (or `Expr` objects) according to a struct format.

        Works like `struct.pack` with these differences:
        * The encoding always uses standard sizes and little-endian order,
          so `fmt` must not carry its own byte-order prefix.
        * Any argument may be an `Expr`, which is replaced by its `value()`.
        * Only integer-valued elements are supported.
        """
        numbers = [arg.value() if isinstance(arg, Expr) else arg
                   for arg in args]
        # '<' = little-endian byte order, standard sizes, no alignment padding.
        return struct.pack('<' + fmt, *numbers)

    def bytes(self) -> bytes:
        """Encode the key exactly as the PSA key store holds it.

        The result is the content of the PSA storage file itself. When PSA
        storage sits on top of stdio files, any extra wrapping added by that
        storage-over-file layer is not part of this value.

        Note that if you need to make a change in this function,
        this may indicate that the key store is changing in a
        backward-incompatible way! Think carefully about backward
        compatibility before making any change here.
        """
        # The header is packed before the version check, as before, so an
        # out-of-range version still surfaces as a struct error first.
        header = self.MAGIC + self.pack('L', self.version)
        if self.version != 0:
            raise NotImplementedError
        attributes = self.pack('LHHLLL',
                               self.lifetime, self.type, self.bits,
                               self.usage, self.alg, self.alg2)
        length_prefix = self.pack('L', len(self.material))
        return header + attributes + length_prefix + self.material

    def hex(self) -> str:
        """Return the stored representation of the key as a hex string.

        This is simply `self.bytes()` rendered in hexadecimal.
        """
        encoded = self.bytes()
        return encoded.hex()

    def location_value(self) -> int:
        """Return the location part of the key's lifetime as a number.

        The location is taken to be everything above the low 8 bits of the
        numerical lifetime value.
        """
        lifetime = self.lifetime.value()
        return lifetime >> 8
180
Gilles Peskinee0094482021-02-17 14:34:37 +0100181
class TestKey(unittest.TestCase):
    # pylint: disable=line-too-long
    """A few smoke tests for the functionality of the `Key` class."""

    def _check_encoding(self, key: Key, expected_hex: str) -> None:
        """Assert that `key` encodes to `expected_hex`, both as bytes and as hex."""
        self.assertEqual(key.bytes(), bytes.fromhex(expected_hex))
        self.assertEqual(key.hex(), expected_hex)

    def test_numerical(self):
        # Every attribute is given as a plain integer.
        key = Key(version=0,
                  id=1, lifetime=0x00000001,
                  type=0x2400, bits=128,
                  usage=0x00000300, alg=0x05500200, alg2=0x04c01000,
                  material=b'@ABCDEFGHIJKLMNO')
        self._check_encoding(
            key,
            '505341004b45590000000000010000000024800000030000000250050010c00410000000404142434445464748494a4b4c4d4e4f')

    def test_names(self):
        # Attributes given as C macro names; key size at PSA_MAX_KEY_BITS.
        max_key_bytes = 0xfff8 // 8 # PSA_MAX_KEY_BITS in bytes
        key = Key(version=0,
                  id=1, lifetime='PSA_KEY_LIFETIME_PERSISTENT',
                  type='PSA_KEY_TYPE_RAW_DATA', bits=max_key_bytes * 8,
                  usage=0, alg=0, alg2=0,
                  material=b'\x00' * max_key_bytes)
        self._check_encoding(
            key,
            '505341004b45590000000000010000000110f8ff000000000000000000000000ff1f0000' + '00' * max_key_bytes)

    def test_defaults(self):
        # Omit version, id and lifetime to exercise their default values.
        key = Key(type=0x1001, bits=8,
                  usage=0, alg=0, alg2=0,
                  material=b'\x2a')
        self._check_encoding(
            key,
            '505341004b455900000000000100000001100800000000000000000000000000010000002a')