tf_fuzz: add new crypto key generation model

* Add a new model of what it means for a crypto policy to be valid. This
  fixes the simulation of psa_generate_key() calls, which TF-Fuzz
  currently cannot accurately predict.

  The PSA Crypto properties modelled are whether an algorithm and key
  type are compatible, whether a key type/ algorithm are disabled, and
  what key sizes are allowed.

  Although the PSA Crypto specification has additional, more complicated
  requirements for policies and keys, this commit only implements what
  is necessary for predicting the outcome of psa_generate_key(), and not
  requirements that make a key useless but still valid or involve
  features not yet in TF-M or MbedTLS.

* Improve the way in which policies are filled in, and allow policies to
  be updated at simulation time using the value of a named policy asset.
  Information about other assets and calls is not accessible during
  parse time when the key policy is usually filled in. However, this is
  required for the improved simulation of psa_generate_key calls. This
  is because policy information for generate key calls come from a named
  policy created earlier in the test file.

* Add valid flag to set-policy calls, allowing the creation of a random
  valid policy. For example, see demo/36.test.

* Add demo/36.test. This test generates a policy with a (roughly) even
  chance of being valid or invalid and then tries to generate a key
  using it.

  Running this test a large number of times (~300) succeeds with the
  changes in this commit, showing that TF-Fuzz can now accurately
  predict the outcome of psa_generate_key calls.

Change-Id: Ia40ff893db50b8d2c579d975aa23341b7aab004d
Signed-off-by: Nik Dewally <Nik.Dewally@arm.com>
diff --git a/tf_fuzz/tfz-cpp/crypto_model/crypto_model.cpp b/tf_fuzz/tfz-cpp/crypto_model/crypto_model.cpp
new file mode 100644
index 0000000..e6b8caf
--- /dev/null
+++ b/tf_fuzz/tfz-cpp/crypto_model/crypto_model.cpp
@@ -0,0 +1,456 @@
+/* Copyright (c) 2024 Arm Limited. All Rights Reserved.
+ *
+ * SPDX-License-Identifier: BSD-3-Clause
+ */
+
+#include "crypto_model.hpp"
+#include <algorithm>
+#include <cmath>
+#include <map>
+#include <numeric>
+#include <stdexcept>
+
+using namespace crypto_model;
+using namespace crypto_model::internal;
+using namespace std;
+
+std::map<std::string,crypto_model::key_type> key_types;
+std::map<std::string,crypto_model::algorithm> algorithms;
+std::vector<std::string> hash_algorithms;
+
+// From PSA_VENDOR_RSA_MAX_KEY_BITS
+const uint RSA_KEY_MAX_SIZE = 4096;
+
+// From PSA_VENDOR_RSA_MIN_KEY_BITS
+const uint RSA_KEY_MIN_SIZE = 1024;
+
+// From PSA_MAX_KEY_BITS
+const uint MAX_KEY_SIZE = 0xfff8;
+
+// Size cannot be 0
+const uint MIN_KEY_SIZE = 1;
+
+const std::vector<std::pair<string, string>> key_type_allowed_with_algorithm = {
+    {"PSA_KEY_TYPE_AES","PSA_ALG_CBC_MAC"},
+    {"PSA_KEY_TYPE_AES","PSA_ALG_CMAC"},
+    {"PSA_KEY_TYPE_AES","PSA_ALG_CTR"},
+    {"PSA_KEY_TYPE_AES","PSA_ALG_CFB"},
+    {"PSA_KEY_TYPE_AES","PSA_ALG_OFB"},
+    {"PSA_KEY_TYPE_AES","PSA_ALG_CBC_NO_PADDING"},
+    {"PSA_KEY_TYPE_AES","PSA_ALG_CBC_PKCS7"},
+    {"PSA_KEY_TYPE_AES","PSA_ALG_ECB_NO_PADDING"},
+    {"PSA_KEY_TYPE_AES","PSA_ALG_CCM"},
+    {"PSA_KEY_TYPE_AES","PSA_ALG_GCM"},
+
+    {"PSA_KEY_TYPE_ARC4","PSA_ALG_STREAM_CIPHER"},
+
+    {"PSA_KEY_TYPE_ARIA","PSA_ALG_CBC_MAC"},
+    {"PSA_KEY_TYPE_ARIA","PSA_ALG_CMAC"},
+    {"PSA_KEY_TYPE_ARIA","PSA_ALG_CTR"},
+    {"PSA_KEY_TYPE_ARIA","PSA_ALG_CFB"},
+    {"PSA_KEY_TYPE_ARIA","PSA_ALG_OFB"},
+    {"PSA_KEY_TYPE_ARIA","PSA_ALG_CBC_NO_PADDING"},
+    {"PSA_KEY_TYPE_ARIA","PSA_ALG_CBC_PKCS7"},
+    {"PSA_KEY_TYPE_ARIA","PSA_ALG_ECB_NO_PADDING"},
+    {"PSA_KEY_TYPE_ARIA","PSA_ALG_CCM"},
+    {"PSA_KEY_TYPE_ARIA","PSA_ALG_GCM"},
+
+    {"PSA_KEY_TYPE_CAMELLIA","PSA_ALG_CBC_MAC"},
+    {"PSA_KEY_TYPE_CAMELLIA","PSA_ALG_CMAC"},
+    {"PSA_KEY_TYPE_CAMELLIA","PSA_ALG_CTR"},
+    {"PSA_KEY_TYPE_CAMELLIA","PSA_ALG_CFB"},
+    {"PSA_KEY_TYPE_CAMELLIA","PSA_ALG_OFB"},
+    {"PSA_KEY_TYPE_CAMELLIA","PSA_ALG_CBC_NO_PADDING"},
+    {"PSA_KEY_TYPE_CAMELLIA","PSA_ALG_CBC_PKCS7"},
+    {"PSA_KEY_TYPE_CAMELLIA","PSA_ALG_ECB_NO_PADDING"},
+    {"PSA_KEY_TYPE_CAMELLIA","PSA_ALG_CCM"},
+    {"PSA_KEY_TYPE_CAMELLIA","PSA_ALG_GCM"},
+
+    {"PSA_KEY_TYPE_CHACHA20","PSA_ALG_STREAM_CIPHER"},
+    {"PSA_KEY_TYPE_CHACHA20","PSA_ALG_CHACHA20_POLY1305"},
+
+    {"PSA_KEY_TYPE_DES","PSA_ALG_CMAC"},
+    {"PSA_KEY_TYPE_DES","PSA_ALG_CTR"},
+    {"PSA_KEY_TYPE_DES","PSA_ALG_CFB"},
+    {"PSA_KEY_TYPE_DES","PSA_ALG_OFB"},
+    {"PSA_KEY_TYPE_DES","PSA_ALG_CBC_NO_PADDING"},
+    {"PSA_KEY_TYPE_DES","PSA_ALG_CBC_PKCS7"},
+    {"PSA_KEY_TYPE_DES","PSA_ALG_ECB_NO_PADDING"},
+
+    {"PSA_KEY_TYPE_HMAC","PSA_ALG_HMAC"},
+
+    {"PSA_KEY_TYPE_DERIVE","PSA_ALG_HKDF"},
+    {"PSA_KEY_TYPE_DERIVE","PSA_ALG_TLS12_PRF"},
+    {"PSA_KEY_TYPE_DERIVE","PSA_ALG_TLS12_PSK_TO_MS"},
+
+    {"PSA_KEY_TYPE_PASSWORD_HASH","PSA_ALG_PBKDF2_HMAC"},
+    {"PSA_KEY_TYPE_PASSWORD_HASH","PSA_ALG_PBKDF2_AES_CMAC_PRF_128"},
+
+    {"PSA_KEY_TYPE_PASSWORD","PSA_ALG_PBKDF2_HMAC"},
+    {"PSA_KEY_TYPE_PASSWORD","PSA_ALG_PBKDF2_AES_CMAC_PRF_128"},
+
+    {"PSA_KEY_TYPE_PEPPER","PSA_ALG_PBKDF2_HMAC"},
+    {"PSA_KEY_TYPE_PEPPER","PSA_ALG_PBKDF2_AES_CMAC_PRF_128"},
+
+    {"PSA_KEY_TYPE_RAW_DATA","PSA_ALG_HKDF"},
+    {"PSA_KEY_TYPE_RAW_DATA","PSA_ALG_TLS12_PRF"},
+    {"PSA_KEY_TYPE_RAW_DATA","PSA_ALG_TLS12_PSK_TO_MS"},
+
+    {"PSA_KEY_TYPE_RSA_KEY_PAIR","PSA_ALG_RSA_OAEP"},
+    {"PSA_KEY_TYPE_RSA_KEY_PAIR","PSA_ALG_RSA_PKCS1V15_CRYPT"},
+    {"PSA_KEY_TYPE_RSA_KEY_PAIR","PSA_ALG_RSA_PKCS1V15_SIGN"},
+    {"PSA_KEY_TYPE_RSA_KEY_PAIR","PSA_ALG_RSA_PKCS1V15_SIGN_RAW"},
+
+    {"PSA_KEY_TYPE_RSA_PUBLIC_KEY","PSA_ALG_RSA_OAEP"},
+    {"PSA_KEY_TYPE_RSA_PUBLIC_KEY","PSA_ALG_RSA_PKCS1V15_CRYPT"},
+    {"PSA_KEY_TYPE_RSA_PUBLIC_KEY","PSA_ALG_RSA_PKCS1V15_SIGN"},
+    {"PSA_KEY_TYPE_RSA_PUBLIC_KEY","PSA_ALG_RSA_PKCS1V15_SIGN_RAW"},
+};
+
+// See tf-m/lib/ext/mbed-crypto/mbed_crypto_config
+const std::vector<string> disabled_key_types = {
+    "PSA_KEY_TYPE_ARIA",
+    "PSA_KEY_TYPE_ARC4",
+    "PSA_KEY_TYPE_CAMELLIA",
+    "PSA_KEY_TYPE_CHACHA20",
+    "PSA_KEY_TYPE_DES"
+};
+
+// See tf-m/lib/ext/mbed-crypto/mbed_crypto_config
+const std::vector<string> disabled_algorithms = {
+    "PSA_ALG_CBC_MAC",
+    "PSA_ALG_CHACHA20_POLY1305",
+    "PSA_WANT_ALG_ECB_NO_PADDING",
+    "PSA_ALG_MD5",
+    "PSA_ALG_OFB",
+    "PSA_ALG_SHA_1",
+    "PSA_ALG_PBKDF2_HMAC",
+    "PSA_ALG_RIPEMD160",
+    "PSA_ALG_STREAM_CIPHER",
+};
+
+// for readability
+const bool doesnt_require_hash = false;
+const bool requires_hash = true;
+const bool hash_alg = true;
+const bool not_hash_alg = false;
+
+void crypto_model::init_crypto_model(void) {
+
+    // Key type definitions.
+
+    define_key_type("PSA_KEY_TYPE_AES", {128, 192, 256}, MAX_KEY_SIZE,
+                    MIN_KEY_SIZE);
+    define_key_type("PSA_KEY_TYPE_ARC4", {},40,2048);
+    define_key_type("PSA_KEY_TYPE_CAMELLIA", {128, 192, 256}, MAX_KEY_SIZE,
+                    MIN_KEY_SIZE);
+    define_key_type("PSA_KEY_TYPE_CHACHA20", {256}, 256, 256);
+    define_key_type("PSA_KEY_TYPE_DES", {64, 128, 192}, MAX_KEY_SIZE,
+                    MIN_KEY_SIZE);
+    define_key_type("PSA_KEY_TYPE_DERIVE", {}, MAX_KEY_SIZE, MIN_KEY_SIZE);
+    define_key_type("PSA_KEY_TYPE_HMAC", {}, MAX_KEY_SIZE, MIN_KEY_SIZE);
+    define_key_type("PSA_KEY_TYPE_RAW_DATA", {}, MAX_KEY_SIZE, MIN_KEY_SIZE);
+    define_key_type("PSA_KEY_TYPE_PASSWORD", {}, MAX_KEY_SIZE, MIN_KEY_SIZE);
+    define_key_type("PSA_KEY_TYPE_PASSWORD_HASH", {}, MAX_KEY_SIZE,
+                    MIN_KEY_SIZE);
+
+    define_key_type("PSA_KEY_TYPE_PEPPER", {}, MAX_KEY_SIZE, MIN_KEY_SIZE);
+    define_key_type("PSA_KEY_TYPE_RSA_KEY_PAIR", {}, RSA_KEY_MAX_SIZE,
+                    RSA_KEY_MIN_SIZE);
+
+    define_key_type("PSA_KEY_TYPE_RSA_PUBLIC_KEY", {}, RSA_KEY_MAX_SIZE,
+                    RSA_KEY_MIN_SIZE);
+
+    // Algorithm definitions
+    define_algorithm("PSA_ALG_CCM",not_hash_alg,doesnt_require_hash);
+    define_algorithm("PSA_ALG_GCM",not_hash_alg,doesnt_require_hash);
+    define_algorithm("PSA_ALG_CHACHA20_POLY1305",not_hash_alg,doesnt_require_hash);
+    define_algorithm("PSA_ALG_RSA_PKCS1V15_CRYPT",not_hash_alg,doesnt_require_hash);
+    define_algorithm("PSA_ALG_RSA_OAEP", not_hash_alg, requires_hash);
+    define_algorithm("PSA_ALG_RSA_PKCS1V15_SIGN", not_hash_alg, requires_hash);
+    define_algorithm("PSA_ALG_RSA_PKCS1V15_SIGN_RAW", not_hash_alg, doesnt_require_hash);
+    define_algorithm("PSA_ALG_STREAM_CIPHER", not_hash_alg, doesnt_require_hash);
+    define_algorithm("PSA_ALG_CTR", not_hash_alg, doesnt_require_hash);
+    define_algorithm("PSA_ALG_CCM_STAR_NO_TAG", not_hash_alg, doesnt_require_hash);
+    define_algorithm("PSA_ALG_CFB", not_hash_alg, doesnt_require_hash);
+    define_algorithm("PSA_ALG_OFB", not_hash_alg, doesnt_require_hash);
+    define_algorithm("PSA_ALG_ECB_NO_PADDING", not_hash_alg, doesnt_require_hash);
+    define_algorithm("PSA_ALG_CBC_NO_PADDING", not_hash_alg, doesnt_require_hash);
+    define_algorithm("PSA_ALG_CBC_PKCS7", not_hash_alg, doesnt_require_hash);
+
+    define_algorithm("PSA_ALG_MD5",hash_alg,doesnt_require_hash);
+    define_algorithm("PSA_ALG_RIPEMD160",hash_alg,doesnt_require_hash);
+    define_algorithm("PSA_ALG_SHA_1",hash_alg,doesnt_require_hash);
+    define_algorithm("PSA_ALG_SHA_224",hash_alg,doesnt_require_hash);
+    define_algorithm("PSA_ALG_SHA_256",hash_alg,doesnt_require_hash);
+    define_algorithm("PSA_ALG_SHA_384",hash_alg,doesnt_require_hash);
+    define_algorithm("PSA_ALG_SHA_512",hash_alg,doesnt_require_hash);
+    define_algorithm("PSA_ALG_SHA_512_224",hash_alg,doesnt_require_hash);
+    define_algorithm("PSA_ALG_SHA_512_256",hash_alg,doesnt_require_hash);
+    define_algorithm("PSA_ALG_SHA3_224",hash_alg,doesnt_require_hash);
+    define_algorithm("PSA_ALG_SHA3_256",hash_alg,doesnt_require_hash);
+    define_algorithm("PSA_ALG_SHA3_384",hash_alg,doesnt_require_hash);
+    define_algorithm("PSA_ALG_SHA3_512",hash_alg,doesnt_require_hash);
+    define_algorithm("PSA_ALG_SHA3_512",hash_alg,doesnt_require_hash);
+
+    define_algorithm("PSA_ALG_FFDH",not_hash_alg,doesnt_require_hash);
+    define_algorithm("PSA_ALG_ECDH",not_hash_alg,doesnt_require_hash);
+
+    define_algorithm("PSA_ALG_HKDF",not_hash_alg,requires_hash);
+    define_algorithm("PSA_ALG_TLS12_PRF",not_hash_alg,requires_hash);
+    define_algorithm("PSA_ALG_TLS12_PSK_TO_MS",not_hash_alg,requires_hash);
+    define_algorithm("PSA_ALG_HKDF_EXTRACT",not_hash_alg,requires_hash);
+    define_algorithm("PSA_ALG_HKDF_EXPAND",not_hash_alg,requires_hash);
+    define_algorithm("PSA_ALG_TLS12_ECJPAKE_TO_PMS",not_hash_alg,doesnt_require_hash);
+    define_algorithm("PSA_ALG_PBKDF2_HMAC",not_hash_alg,requires_hash);
+    define_algorithm("PSA_ALG_PBKDF2_AES_CMAC_PRF_128",not_hash_alg,doesnt_require_hash);
+
+    define_algorithm("PSA_ALG_HMAC",not_hash_alg,requires_hash);
+    define_algorithm("PSA_ALG_CBC_MAC",not_hash_alg,doesnt_require_hash);
+    define_algorithm("PSA_ALG_CMAC",not_hash_alg,doesnt_require_hash);
+
+    // Load the data tables into the right places
+    for (auto &kt: disabled_key_types) {
+        key_types[kt].enabled = false;
+    }
+
+    for (auto &alg: disabled_algorithms) {
+        algorithms[alg].enabled = false;
+    }
+
+    for (auto pair: key_type_allowed_with_algorithm) {
+        key_type &kt = key_types[pair.first];
+        algorithm &alg = algorithms[pair.second];
+
+        if (!kt.is_enabled() || !alg.is_enabled()) {
+            continue;
+        }
+
+        alg.allowed_key_types.push_back(kt.get_string());
+        kt.allowed_algorithms.push_back(alg.get_string());
+    }
+
+}
+
+void crypto_model::internal::define_key_type(
+    std::string name, std::vector<uint> allowed_key_sizes_bits,
+    uint max_key_size_bits, uint min_key_size_bits) {
+
+    key_type kt;
+    kt.name = name;
+    kt.allowed_key_sizes_bits = allowed_key_sizes_bits;
+    kt.max_key_size_bits = max_key_size_bits;
+    kt.min_key_size_bits = min_key_size_bits;
+
+    key_types[name] = kt;
+}
+
+void crypto_model::internal::define_algorithm(std::string name,
+                                              bool is_hash_algorithm,
+                                              bool requires_hash) {
+    algorithm a;
+    a.name = name;
+    a.is_hash_algorithm_flag = is_hash_algorithm;
+    a.requires_hash_flag = requires_hash;
+
+    if (is_hash_algorithm) {
+        hash_algorithms.push_back(name);
+    }
+
+    algorithms[name] = a;
+}
+
+namespace crypto_model {
+
+uint get_random_key_size() {
+    int lower = std::floor(MIN_KEY_SIZE/ 8);
+    int higher = std::floor(MAX_KEY_SIZE/8);
+    return ((rand() % (higher - lower)) + lower) * 8;
+}
+
+algorithm& get_algorithm(std::string name) {
+
+    // turn hash algs of form PSA_ALG_FOO(PSA_ALG_BAR) into PSA_ALG_FOO.
+    int index = name.find("(");
+    if (index != string::npos) {
+        name = name.substr(index);
+    }
+
+    return algorithms[name];
+}
+
+
+key_type& get_key_type(std::string name) {
+    return key_types[name];
+}
+
+algorithm& get_random_algorithm() {
+    // to randomise:
+    //  get random index.
+    //  if it is a disabled alg, repeat
+    std::vector<int> v(algorithms.size());
+    std::iota(v.begin(),v.end(),0);
+    std::random_shuffle(v.begin(),v.end());
+    auto it = algorithms.begin();
+    for (int i:v) {
+        auto it = algorithms.begin();
+        std::advance(it,i);
+        algorithm& alg = it->second;
+        if (alg.is_enabled()) {
+            return alg;
+        }
+    }
+    throw std::logic_error("No enabled algorithms");
+}
+
+key_type& get_random_key_type() {
+    // to randomise:
+    //  get random index.
+    //  if it is a disabled key type, repeat
+    std::vector<int> v(key_types.size());
+    std::iota(v.begin(),v.end(),0);
+    std::random_shuffle(v.begin(),v.end());
+    auto it = key_types.begin();
+    for (int i:v) {
+        auto it = key_types.begin();
+        std::advance(it,i);
+        key_type& kt = it->second;
+        if (kt.is_enabled()) {
+            return kt;
+        }
+    }
+    throw std::logic_error("No enabled key types");
+}
+
+algorithm& get_random_hash_algorithm() {
+    std::random_shuffle(hash_algorithms.begin(),hash_algorithms.end());
+    for (int i=0;i<hash_algorithms.size();i++) {
+        if (algorithms[hash_algorithms[i]].is_enabled()) {
+            return algorithms[hash_algorithms[i]];
+        }
+    }
+
+    throw std::logic_error("No enabled hash algorithms");
+}
+
+algorithm::algorithm() {
+
+};
+
+algorithm::~algorithm() {
+
+};
+
+string algorithm::get_string() {
+    return name;
+};
+
+bool algorithm::is_enabled() {
+    return enabled;
+}
+
+std::string algorithm::get_string_with_hash() {
+    string out = name;
+    if (requires_hash()) {
+        algorithm& hash_alg = get_random_hash_algorithm();
+        out += "(";
+        out += hash_alg.name;
+        out += ")";
+    }
+
+    return out;
+}
+
+bool algorithm::valid_for_key_type(key_type& kt) {
+    if (std::find(allowed_key_types.begin(),allowed_key_types.end(),kt.get_string()) != allowed_key_types.end()) {
+        return true;
+    }
+    return false;
+}
+
+bool algorithm::requires_hash() {
+    return requires_hash_flag;
+}
+
+bool algorithm::is_hash_algorithm() {
+    return is_hash_algorithm_flag;
+}
+
+key_type& algorithm::random_valid_key_type() {
+    if (allowed_key_types.size() == 0) {
+        throw std::logic_error("No allowed key types for algorithm");
+    }
+    std::random_shuffle(allowed_key_types.begin(),allowed_key_types.end());
+    return key_types[allowed_key_types[0]];
+
+}
+key_type::key_type() {
+
+};
+
+key_type::~key_type() {
+
+};
+
+string key_type::get_string() {
+    return name;
+}
+
+bool key_type::is_enabled() {
+    return enabled;
+}
+
+algorithm& key_type::random_allowed_algorithm() {
+    if (allowed_algorithms.size() == 0) {
+        throw std::logic_error("Key type has no allowed algorithms");
+    }
+    std::random_shuffle(allowed_algorithms.begin(),allowed_algorithms.end());
+    return algorithms[allowed_algorithms[0]];
+
+}
+
+bool key_type::is_allowed_algorithm(algorithm& algorithm) {
+    if (std::find(allowed_algorithms.begin(),allowed_algorithms.end(),algorithm.get_string()) != allowed_algorithms.end()) {
+        return true;
+    }
+    return false;
+}
+
+bool key_type::is_valid_key_size(uint size) {
+    // (MbedTLS): size is always byte aligned
+    if (size % 8 != 0) {
+        return false;
+    }
+
+    if (size > max_key_size_bits) {
+        return false;
+    }
+
+    if (size < min_key_size_bits) {
+        return false;
+    }
+
+    // some keys only allow a fixed set of values
+    if (!allowed_key_sizes_bits.empty()) {
+        if (std::find(allowed_key_sizes_bits.begin(),
+                      allowed_key_sizes_bits.end(),
+                      size) == allowed_key_sizes_bits.end()) {
+            return false;
+        }
+    }
+    return true;
+}
+
+uint key_type::get_random_valid_key_size() {
+    if (!allowed_key_sizes_bits.empty()) {
+        std::random_shuffle(allowed_key_sizes_bits.begin(),allowed_key_sizes_bits.end());
+        return allowed_key_sizes_bits[0];
+    }
+
+    int lower = floor(min_key_size_bits/ 8);
+    int higher = floor(max_key_size_bits/8);
+
+    return ((rand() % (higher - lower)) + lower) * 8;
+}
+
+} // namespace crypto_model