blob: 9c3c72319a225ac5b5d5ea19c9988c5b38a62092 [file] [log] [blame]
Gilles Peskinee7c44552021-01-25 21:40:45 +01001"""Collect macro definitions from header files.
2"""
3
4# Copyright The Mbed TLS Contributors
5# SPDX-License-Identifier: Apache-2.0
6#
7# Licensed under the Apache License, Version 2.0 (the "License"); you may
8# not use this file except in compliance with the License.
9# You may obtain a copy of the License at
10#
11# http://www.apache.org/licenses/LICENSE-2.0
12#
13# Unless required by applicable law or agreed to in writing, software
14# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
15# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
16# See the License for the specific language governing permissions and
17# limitations under the License.
18
import itertools
import re
from typing import Dict, IO, Iterable, Iterator, List, Optional, Pattern, Set, Tuple, Union
22
23
class ReadFileLineException(Exception):
    """Report a failure that happened while processing one line of a file.

    The message reads ``in FILENAME at LINE_NUMBER``. The file name and the
    line number are also kept as attributes for programmatic inspection.
    """

    def __init__(self, filename: str, line_number: Union[int, str]) -> None:
        # Keep the location available to callers, not just in the message.
        self.filename = filename
        self.line_number = line_number
        super().__init__('in {} at {}'.format(filename, line_number))
30
31
class read_file_lines:
    # Dear Pylint, conventionally, a context manager class name is lowercase.
    # pylint: disable=invalid-name,too-few-public-methods
    """Context manager to read a text file line by line.

    ```
    with read_file_lines(filename) as lines:
        for line in lines:
            process(line)
    ```
    is equivalent to
    ```
    with open(filename, 'r') as input_file:
        for line in input_file:
            process(line)
    ```
    except that if process(line) raises an exception, then the read_file_lines
    snippet annotates the exception with the file name and line number.

    If ``binary`` is true, the file is opened in mode ``'rb'`` and the lines
    are ``bytes`` rather than ``str``.

    The line number reported in the exception is the 0-based index of the
    line being processed. It is ``'entry'`` if no line has been read yet and
    ``'exit'`` once all lines have been read.
    The file is closed when the ``with`` block exits, whether or not an
    exception occurred.
    """
    def __init__(self, filename: str, binary: bool = False) -> None:
        self.filename = filename
        # Underlying file object, opened by __enter__ and closed by __exit__.
        self.file = None #type: Optional[IO]
        self.line_number = 'entry' #type: Union[int, str]
        self.generator = None #type: Optional[Iterable[Tuple[int, str]]]
        self.binary = binary

    def __enter__(self) -> 'read_file_lines':
        self.file = open(self.filename, 'rb' if self.binary else 'r')
        self.generator = enumerate(self.file)
        return self

    def __iter__(self) -> Iterator[str]:
        assert self.generator is not None
        for line_number, content in self.generator:
            # Track the current position so __exit__ can report it.
            self.line_number = line_number
            yield content
        self.line_number = 'exit'

    def __exit__(self, exc_type, exc_value, exc_traceback) -> None:
        # Always release the file handle; previously it was never closed.
        if self.file is not None:
            self.file.close()
        if exc_type is not None:
            raise ReadFileLineException(self.filename, self.line_number) \
                from exc_value
Gilles Peskine22fcf1b2021-03-10 01:02:39 +010070
71
class PSAMacroEnumerator:
    """Information about constructors of various PSA Crypto types.

    This includes macro names as well as information about their arguments
    when applicable.

    This class only provides ways to enumerate expressions that evaluate to
    values of the covered types. Derived classes are expected to populate
    the set of known constructors of each kind, as well as populate
    `self.arguments_for` for arguments that are not of a kind that is
    enumerated here.
    """

    def __init__(self) -> None:
        """Set up an empty set of known constructor macros.
        """
        self.statuses = set() #type: Set[str]
        self.algorithms = set() #type: Set[str]
        self.ecc_curves = set() #type: Set[str]
        self.dh_groups = set() #type: Set[str]
        self.key_types = set() #type: Set[str]
        self.key_usage_flags = set() #type: Set[str]
        self.hash_algorithms = set() #type: Set[str]
        self.mac_algorithms = set() #type: Set[str]
        self.ka_algorithms = set() #type: Set[str]
        self.kdf_algorithms = set() #type: Set[str]
        self.aead_algorithms = set() #type: Set[str]
        # macro name -> list of argument names
        self.argspecs = {} #type: Dict[str, List[str]]
        # argument name -> list of values
        self.arguments_for = {
            'mac_length': [],
            'min_mac_length': [],
            'tag_length': [],
            'min_tag_length': [],
        } #type: Dict[str, List[str]]
        # When true, *_BASE, *_NONE and *_CATEGORY_* macros are not
        # considered internal (see is_internal_name).
        self.include_intermediate = False

    def is_internal_name(self, name: str) -> bool:
        """Whether this is an internal macro. Internal macros will be skipped.

        Names ending in ``_FLAG`` or ``_MASK`` are always internal. Names
        ending in ``_BASE`` or ``_NONE``, or containing ``_CATEGORY_``, are
        internal unless ``self.include_intermediate`` is true.
        """
        if not self.include_intermediate:
            if name.endswith('_BASE') or name.endswith('_NONE'):
                return True
            if '_CATEGORY_' in name:
                return True
        return name.endswith('_FLAG') or name.endswith('_MASK')

    def gather_arguments(self) -> None:
        """Populate the list of values for macro arguments.

        Call this after parsing all the inputs.
        """
        self.arguments_for['hash_alg'] = sorted(self.hash_algorithms)
        self.arguments_for['mac_alg'] = sorted(self.mac_algorithms)
        self.arguments_for['ka_alg'] = sorted(self.ka_algorithms)
        self.arguments_for['kdf_alg'] = sorted(self.kdf_algorithms)
        self.arguments_for['aead_alg'] = sorted(self.aead_algorithms)
        self.arguments_for['curve'] = sorted(self.ecc_curves)
        self.arguments_for['group'] = sorted(self.dh_groups)

    @staticmethod
    def _format_arguments(name: str, arguments: Iterable[str]) -> str:
        """Format a macro call with arguments.

        The resulting format is consistent with
        `InputsForTest.normalize_argument`.
        """
        return name + '(' + ', '.join(arguments) + ')'

    _argument_split_re = re.compile(r' *, *')
    @classmethod
    def _argument_split(cls, arguments: str) -> List[str]:
        """Split a macro argument list on commas, ignoring spaces around them."""
        return re.split(cls._argument_split_re, arguments)

    def distribute_arguments(self, name: str) -> Iterator[str]:
        """Generate macro calls with each tested argument set.

        If name is a macro without arguments, just yield "name".
        If name is a macro with arguments, yield a series of
        "name(arg1,...,argN)" where each argument takes each possible
        value at least once.

        The first call uses the first known value of every argument. Each
        subsequent call varies exactly one argument, while all the others
        keep their first value.

        Raises Exception (chained to the original error) if an argument has
        no entry in `self.arguments_for` or an empty list of values.
        """
        try:
            if name not in self.argspecs:
                yield name
                return
            argspec = self.argspecs[name]
            if argspec == []:
                yield name + '()'
                return
            argument_lists = [self.arguments_for[arg] for arg in argspec]
            arguments = [values[0] for values in argument_lists]
            yield self._format_arguments(name, arguments)
            for i, values in enumerate(argument_lists):
                for value in values[1:]:
                    arguments[i] = value
                    yield self._format_arguments(name, arguments)
                # Restore *this* argument's default before varying the next
                # one. (Resetting to argument_lists[0][0] here was a bug that
                # leaked the first argument's value into later calls.)
                arguments[i] = values[0]
        except Exception as e:
            # Catch Exception rather than BaseException so that GeneratorExit
            # (early close of this generator) and KeyboardInterrupt propagate
            # unchanged instead of being turned into an ordinary Exception.
            raise Exception('distribute_arguments({})'.format(name)) from e

    def generate_expressions(self, names: Iterable[str]) -> Iterator[str]:
        """Generate expressions covering values constructed from the given names.

        `names` can be any iterable collection of macro names.

        For example:
        * ``generate_expressions(['PSA_ALG_CMAC', 'PSA_ALG_HMAC'])``
          generates ``'PSA_ALG_CMAC'`` as well as ``'PSA_ALG_HMAC(h)'`` for
          every known hash algorithm ``h``.
        * ``macros.generate_expressions(macros.key_types)`` generates all
          key types.
        """
        return itertools.chain(*map(self.distribute_arguments, names))
189
Gilles Peskinee7c44552021-01-25 21:40:45 +0100190
class PSAMacroCollector(PSAMacroEnumerator):
    """Collect PSA crypto macro definitions from C header files.
    """

    def __init__(self, include_intermediate: bool = False) -> None:
        """Set up an object to collect PSA macro definitions.

        Call the read_file method of the constructed object on each header file.

        * include_intermediate: if true, include intermediate macros such as
          PSA_XXX_BASE that do not designate semantic values.
        """
        super().__init__()
        self.include_intermediate = include_intermediate
        # PSA_KEY_TYPE_xxx(curve) constructor -> PSA_KEY_TYPE_IS_xxx tester.
        self.key_types_from_curve = {} #type: Dict[str, str]
        # PSA_KEY_TYPE_xxx(group) constructor -> PSA_KEY_TYPE_IS_xxx tester.
        self.key_types_from_group = {} #type: Dict[str, str]
        # PSA_ALG_xxx(hash_alg) constructor -> PSA_ALG_IS_xxx tester.
        self.algorithms_from_hash = {} #type: Dict[str, str]

    def record_algorithm_subtype(self, name: str, expansion: str) -> None:
        """Record the subtype of an algorithm constructor.

        Given a ``PSA_ALG_xxx`` macro name and its expansion, if the algorithm
        is of a subtype that is tracked in its own set, add it to the relevant
        set. Apart from the two name checks, the subtype is guessed from the
        hexadecimal constant that appears in the expansion.
        """
        # This code is very ad hoc and fragile. It should be replaced by
        # something more robust.
        # NOTE(review): re.match anchors at the start of `name`, so the two
        # name-based checks below never match names starting with 'PSA_ALG_'
        # (the only names read_line passes here). Kept as-is to preserve
        # behavior; confirm what was intended.
        if re.match(r'MAC(?:_|\Z)', name):
            self.mac_algorithms.add(name)
        elif re.match(r'KDF(?:_|\Z)', name):
            self.kdf_algorithms.add(name)
        elif re.search(r'0x020000[0-9A-Fa-f]{2}', expansion):
            self.hash_algorithms.add(name)
        elif re.search(r'0x03[0-9A-Fa-f]{6}', expansion):
            self.mac_algorithms.add(name)
        elif re.search(r'0x05[0-9A-Fa-f]{6}', expansion):
            self.aead_algorithms.add(name)
        elif re.search(r'0x09[0-9A-Fa-f]{2}0000', expansion):
            self.ka_algorithms.add(name)
        elif re.search(r'0x08[0-9A-Fa-f]{6}', expansion):
            self.kdf_algorithms.add(name)

    # "#define" followed by a macro name with either no parameters
    # or a single parameter and a non-empty expansion.
    # Grab the macro name in group 1, the parameter name if any in group 2
    # and the expansion in group 3.
    _define_directive_re = re.compile(r'\s*#\s*define\s+(\w+)' +
                                      r'(?:\s+|\((\w+)\)\s*)' +
                                      r'(.+)')
    _deprecated_definition_re = re.compile(r'\s*MBEDTLS_DEPRECATED')

    def read_line(self, line: str) -> None:
        """Parse a C header line and record the PSA identifier it defines if any.
        This function analyzes lines that start with "#define PSA_"
        (up to non-significant whitespace) and skips all non-matching lines.

        The argument name of any single-parameter macro is recorded in
        `self.argspecs` before the deprecation and internal-name filters
        are applied.
        """
        # pylint: disable=too-many-branches
        m = re.match(self._define_directive_re, line)
        if not m:
            return
        name, parameter, expansion = m.groups()
        # Strip C and C++ comments from the expansion.
        expansion = re.sub(r'/\*.*?\*/|//.*', r' ', expansion)
        if parameter:
            self.argspecs[name] = [parameter]
        if re.match(self._deprecated_definition_re, expansion):
            # Skip deprecated values, which are assumed to be
            # backward compatibility aliases that share
            # numerical values with non-deprecated values.
            return
        if self.is_internal_name(name):
            # Macro only to build actual values
            return
        elif (name.startswith('PSA_ERROR_') or name == 'PSA_SUCCESS') \
           and not parameter:
            self.statuses.add(name)
        elif name.startswith('PSA_KEY_TYPE_') and not parameter:
            self.key_types.add(name)
        elif name.startswith('PSA_KEY_TYPE_') and parameter == 'curve':
            # 13 == len('PSA_KEY_TYPE_'): insert 'IS_' after the prefix.
            self.key_types_from_curve[name] = name[:13] + 'IS_' + name[13:]
        elif name.startswith('PSA_KEY_TYPE_') and parameter == 'group':
            self.key_types_from_group[name] = name[:13] + 'IS_' + name[13:]
        elif name.startswith('PSA_ECC_FAMILY_') and not parameter:
            self.ecc_curves.add(name)
        elif name.startswith('PSA_DH_FAMILY_') and not parameter:
            self.dh_groups.add(name)
        elif name.startswith('PSA_ALG_') and not parameter:
            if name in ['PSA_ALG_ECDSA_BASE',
                        'PSA_ALG_RSA_PKCS1V15_SIGN_BASE']:
                # Ad hoc skipping of duplicate names for some numerical values
                return
            self.algorithms.add(name)
            self.record_algorithm_subtype(name, expansion)
        elif name.startswith('PSA_ALG_') and parameter == 'hash_alg':
            # 8 == len('PSA_ALG_'): insert the tester infix after the prefix.
            if name in ['PSA_ALG_DSA', 'PSA_ALG_ECDSA']:
                # A naming irregularity
                tester = name[:8] + 'IS_RANDOMIZED_' + name[8:]
            else:
                tester = name[:8] + 'IS_' + name[8:]
            self.algorithms_from_hash[name] = tester
        elif name.startswith('PSA_KEY_USAGE_') and not parameter:
            self.key_usage_flags.add(name)
        else:
            # Other macro without parameter
            return

    _nonascii_re = re.compile(rb'[^\x00-\x7f]+')
    _continued_line_re = re.compile(rb'\\\r?\n\Z')
    def read_file(self, header_file: Iterable[bytes]) -> None:
        """Parse every line of a C header given as lines of bytes.

        `header_file` is typically a file opened in binary mode, but any
        iterable of ``bytes`` lines is accepted. Backslash-newline
        continuations are joined into one logical line. Non-ASCII bytes are
        dropped before the line is decoded as ASCII and passed to read_line.
        """
        lines = iter(header_file)
        for line in lines:
            m = re.search(self._continued_line_re, line)
            while m:
                # A continuation on the last line has nothing to join: treat
                # it as followed by an empty line rather than letting a bare
                # StopIteration escape.
                cont = next(lines, b'')
                line = line[:m.start(0)] + cont
                m = re.search(self._continued_line_re, line)
            text = re.sub(self._nonascii_re, rb'', line).decode('ascii')
            self.read_line(text)
306 self.read_line(line)
Gilles Peskineb4edff92021-03-30 19:09:05 +0200307
308
class InputsForTest(PSAMacroEnumerator):
    # pylint: disable=too-many-instance-attributes
    """Accumulate information about macros to test.

    This includes macro names as well as information about their arguments
    when applicable.
    """

    def __init__(self) -> None:
        super().__init__()
        # Every PSA_xxx macro name seen in a #define, including excluded ones.
        self.all_declared = set() #type: Set[str]
        # Identifier prefixes
        self.table_by_prefix = {
            'ERROR': self.statuses,
            'ALG': self.algorithms,
            'ECC_CURVE': self.ecc_curves,
            'DH_GROUP': self.dh_groups,
            'KEY_TYPE': self.key_types,
            'KEY_USAGE': self.key_usage_flags,
        } #type: Dict[str, Set[str]]
        # Test functions
        self.table_by_test_function = {
            # Any function ending in _algorithm also gets added to
            # self.algorithms.
            'key_type': [self.key_types],
            'block_cipher_key_type': [self.key_types],
            'stream_cipher_key_type': [self.key_types],
            'ecc_key_family': [self.ecc_curves],
            'ecc_key_types': [self.ecc_curves],
            'dh_key_family': [self.dh_groups],
            'dh_key_types': [self.dh_groups],
            'hash_algorithm': [self.hash_algorithms],
            'mac_algorithm': [self.mac_algorithms],
            'cipher_algorithm': [],
            'hmac_algorithm': [self.mac_algorithms],
            'aead_algorithm': [self.aead_algorithms],
            'key_derivation_algorithm': [self.kdf_algorithms],
            'key_agreement_algorithm': [self.ka_algorithms],
            'asymmetric_signature_algorithm': [],
            'asymmetric_signature_wildcard': [self.algorithms],
            'asymmetric_encryption_algorithm': [],
            'other_algorithm': [],
        } #type: Dict[str, List[Set[str]]]
        self.arguments_for['mac_length'] += ['1', '63']
        self.arguments_for['min_mac_length'] += ['1', '63']
        self.arguments_for['tag_length'] += ['1', '63']
        self.arguments_for['min_tag_length'] += ['1', '63']

    def add_numerical_values(self) -> None:
        """Add numerical values that are not supported to the known identifiers."""
        # Sets of names per type
        self.algorithms.add('0xffffffff')
        self.ecc_curves.add('0xff')
        self.dh_groups.add('0xff')
        self.key_types.add('0xffff')
        self.key_usage_flags.add('0x80000000')

        # Hard-coded values for unknown algorithms
        #
        # These have to have values that are correct for their respective
        # PSA_ALG_IS_xxx macros, but are also not currently assigned and are
        # not likely to be assigned in the near future.
        self.hash_algorithms.add('0x020000fe') # 0x020000ff is PSA_ALG_ANY_HASH
        self.mac_algorithms.add('0x03007fff')
        self.ka_algorithms.add('0x09fc0000')
        self.kdf_algorithms.add('0x080000ff')
        # For AEAD algorithms, the only variability is over the tag length,
        # and this only applies to known algorithms, so don't test an
        # unknown algorithm.

    def get_names(self, type_word: str) -> Set[str]:
        """Return the set of known names of values of the given type.

        Raises KeyError if `type_word` is not one of 'status', 'algorithm',
        'ecc_curve', 'dh_group', 'key_type' or 'key_usage'.
        """
        return {
            'status': self.statuses,
            'algorithm': self.algorithms,
            'ecc_curve': self.ecc_curves,
            'dh_group': self.dh_groups,
            'key_type': self.key_types,
            'key_usage': self.key_usage_flags,
        }[type_word]

    # Regex for interesting header lines.
    # Groups: 1=macro name, 2=type, 3=argument list (optional).
    _header_line_re = \
        re.compile(r'#define +' +
                   r'(PSA_((?:(?:DH|ECC|KEY)_)?[A-Z]+)_\w+)' +
                   r'(?:\(([^\n()]*)\))?')
    # Regex of macro names to exclude.
    _excluded_name_re = re.compile(r'_(?:GET|IS|OF)_|_(?:BASE|FLAG|MASK)\Z')
    # Additional excluded macros.
    _excluded_names = {
        # Macros that provide an alternative way to build the same
        # algorithm as another macro.
        'PSA_ALG_AEAD_WITH_DEFAULT_LENGTH_TAG',
        'PSA_ALG_FULL_LENGTH_MAC',
        # Auxiliary macro whose name doesn't fit the usual patterns for
        # auxiliary macros.
        'PSA_ALG_AEAD_WITH_DEFAULT_LENGTH_TAG_CASE',
    }
    def parse_header_line(self, line: str) -> None:
        """Parse a C header line, looking for "#define PSA_xxx".

        Every matching name is added to `self.all_declared`. Unless it is
        excluded, it is also added to the set for its prefix (per
        `self.table_by_prefix`), and its argument names, if any, go into
        `self.argspecs`.
        """
        m = re.match(self._header_line_re, line)
        if not m:
            return
        name = m.group(1)
        self.all_declared.add(name)
        if re.search(self._excluded_name_re, name) or \
           name in self._excluded_names or \
           self.is_internal_name(name):
            return
        dest = self.table_by_prefix.get(m.group(2))
        if dest is None:
            return
        dest.add(name)
        if m.group(3):
            self.argspecs[name] = self._argument_split(m.group(3))

    _nonascii_re = re.compile(rb'[^\x00-\x7f]+') #type: Pattern
    def parse_header(self, filename: str) -> None:
        """Parse a C header file, looking for "#define PSA_xxx".

        Non-ASCII bytes are dropped before each line is decoded as ASCII.
        """
        with read_file_lines(filename, binary=True) as lines:
            for line in lines:
                line = re.sub(self._nonascii_re, rb'', line).decode('ascii')
                self.parse_header_line(line)

    _macro_identifier_re = re.compile(r'[A-Z]\w+')
    def generate_undeclared_names(self, expr: str) -> Iterator[str]:
        """Yield identifiers in `expr` that were never seen in a header #define."""
        for name in re.findall(self._macro_identifier_re, expr):
            if name not in self.all_declared:
                yield name

    def accept_test_case_line(self, function: str, argument: str) -> bool:
        """Check that a test case argument only uses declared names.

        Return True if it does. Raise Exception listing the undeclared
        names otherwise. `function` is not used here.
        """
        #pylint: disable=unused-argument
        undeclared = list(self.generate_undeclared_names(argument))
        if undeclared:
            raise Exception('Undeclared names in test case', undeclared)
        return True

    @staticmethod
    def normalize_argument(argument: str) -> str:
        """Normalize whitespace in the given C expression.

        All spaces are removed, then exactly one space is inserted after
        each comma. The result uses the same whitespace as
        `PSAMacroEnumerator.distribute_arguments`.
        """
        return re.sub(r',', r', ', re.sub(r' +', r'', argument))

    def add_test_case_line(self, function: str, argument: str) -> None:
        """Parse a test case data line, looking for algorithm metadata tests.

        Raises KeyError if `function` is not a key of
        `self.table_by_test_function`.
        """
        sets = []
        if function.endswith('_algorithm'):
            sets.append(self.algorithms)
            if function == 'key_agreement_algorithm' and \
               argument.startswith('PSA_ALG_KEY_AGREEMENT('):
                # We only want *raw* key agreement algorithms as such, so
                # exclude ones that are already chained with a KDF.
                # Keep the expression as one to test as an algorithm.
                function = 'other_algorithm'
        sets += self.table_by_test_function[function]
        if self.accept_test_case_line(function, argument):
            for s in sets:
                s.add(self.normalize_argument(argument))

    # Regex matching a *.data line containing a test function call and
    # its arguments. The actual definition is partly positional, but this
    # regex is good enough in practice.
    _test_case_line_re = re.compile(r'(?!depends_on:)(\w+):([^\n :][^:\n]*)')
    def parse_test_cases(self, filename: str) -> None:
        """Parse a test case file (*.data), looking for algorithm metadata tests."""
        with read_file_lines(filename) as lines:
            for line in lines:
                m = re.match(self._test_case_line_re, line)
                if m:
                    self.add_test_case_line(m.group(1), m.group(2))
482 self.add_test_case_line(m.group(1), m.group(2))