diff --git a/library/psa_crypto.c b/library/psa_crypto.c index 36c5f75665..9e6c2efc39 100644 --- a/library/psa_crypto.c +++ b/library/psa_crypto.c @@ -72,10 +72,6 @@ #include "mbedtls/sha512.h" #include "md_psa.h" -#if defined(MBEDTLS_TEST_HOOKS) -#include "test/memory.h" -#endif - #if defined(MBEDTLS_PSA_BUILTIN_ALG_HKDF) || \ defined(MBEDTLS_PSA_BUILTIN_ALG_HKDF_EXTRACT) || \ defined(MBEDTLS_PSA_BUILTIN_ALG_HKDF_EXPAND) @@ -8451,6 +8447,13 @@ psa_status_t psa_pake_abort( } #endif /* PSA_WANT_ALG_SOME_PAKE */ +/* Memory copying test hooks */ +#if defined(MBEDTLS_TEST_HOOKS) +void (*psa_input_pre_copy_hook)(const uint8_t *input, size_t input_len) = NULL; +void (*psa_input_post_copy_hook)(const uint8_t *input, size_t input_len) = NULL; +void (*psa_output_pre_copy_hook)(const uint8_t *output, size_t output_len) = NULL; +void (*psa_output_post_copy_hook)(const uint8_t *output, size_t output_len) = NULL; +#endif /** Copy from an input buffer to a local copy. * @@ -8473,7 +8476,9 @@ psa_status_t psa_crypto_copy_input(const uint8_t *input, size_t input_len, } #if defined(MBEDTLS_TEST_HOOKS) - MBEDTLS_TEST_MEMORY_UNPOISON(input, input_len); + if (psa_input_pre_copy_hook != NULL) { + psa_input_pre_copy_hook(input, input_len); + } #endif if (input_len > 0) { @@ -8481,7 +8486,9 @@ psa_status_t psa_crypto_copy_input(const uint8_t *input, size_t input_len, } #if defined(MBEDTLS_TEST_HOOKS) - MBEDTLS_TEST_MEMORY_POISON(input, input_len); + if (psa_input_post_copy_hook != NULL) { + psa_input_post_copy_hook(input, input_len); + } #endif return PSA_SUCCESS; @@ -8508,7 +8515,9 @@ psa_status_t psa_crypto_copy_output(const uint8_t *output_copy, size_t output_co } #if defined(MBEDTLS_TEST_HOOKS) - MBEDTLS_TEST_MEMORY_UNPOISON(output, output_len); + if (psa_output_pre_copy_hook != NULL) { + psa_output_pre_copy_hook(output, output_len); + } #endif if (output_copy_len > 0) { @@ -8516,7 +8525,9 @@ psa_status_t psa_crypto_copy_output(const uint8_t *output_copy, size_t output_co } #if defined(MBEDTLS_TEST_HOOKS) - MBEDTLS_TEST_MEMORY_POISON(output, output_len); + if (psa_output_post_copy_hook != NULL) { + psa_output_post_copy_hook(output, output_len); + } #endif return PSA_SUCCESS; diff --git a/library/psa_crypto_invasive.h b/library/psa_crypto_invasive.h index 6a1181f882..51c90c64a4 100644 --- a/library/psa_crypto_invasive.h +++ b/library/psa_crypto_invasive.h @@ -79,6 +79,14 @@ psa_status_t psa_crypto_copy_input(const uint8_t *input, size_t input_len, psa_status_t psa_crypto_copy_output(const uint8_t *output_copy, size_t output_copy_len, uint8_t *output, size_t output_len); +/* + * Test hooks to use for memory unpoisoning/poisoning in copy functions. + */ +extern void (*psa_input_pre_copy_hook)(const uint8_t *input, size_t input_len); +extern void (*psa_input_post_copy_hook)(const uint8_t *input, size_t input_len); +extern void (*psa_output_pre_copy_hook)(const uint8_t *output, size_t output_len); +extern void (*psa_output_post_copy_hook)(const uint8_t *output, size_t output_len); + #endif /* MBEDTLS_TEST_HOOKS && MBEDTLS_PSA_CRYPTO_C */ #endif /* PSA_CRYPTO_INVASIVE_H */ diff --git a/tests/include/test/psa_memory_poisoning_wrappers.h b/tests/include/test/psa_memory_poisoning_wrappers.h index 08234b4948..e1642d2c17 100644 --- a/tests/include/test/psa_memory_poisoning_wrappers.h +++ b/tests/include/test/psa_memory_poisoning_wrappers.h @@ -2,6 +2,26 @@ #include "test/memory.h" +#include "psa_crypto_invasive.h" + +#if defined(MBEDTLS_TEST_MEMORY_CAN_POISON) + +static void setup_test_hooks(void) +{ + psa_input_pre_copy_hook = mbedtls_test_memory_unpoison; + psa_input_post_copy_hook = mbedtls_test_memory_poison; + psa_output_pre_copy_hook = mbedtls_test_memory_unpoison; + psa_output_post_copy_hook = mbedtls_test_memory_poison; +} + +static void teardown_test_hooks(void) +{ + psa_input_pre_copy_hook = NULL; + psa_input_post_copy_hook = NULL; + psa_output_pre_copy_hook = NULL; + psa_output_post_copy_hook = NULL; +} + psa_status_t wrap_psa_cipher_encrypt(mbedtls_svc_key_id_t key, psa_algorithm_t alg, const uint8_t *input, @@ -10,6 +30,7 @@ psa_status_t wrap_psa_cipher_encrypt(mbedtls_svc_key_id_t key, size_t output_size, size_t *output_length) { + setup_test_hooks(); MBEDTLS_TEST_MEMORY_POISON(input, input_length); MBEDTLS_TEST_MEMORY_POISON(output, output_size); psa_status_t status = psa_cipher_encrypt(key, @@ -21,7 +42,10 @@ psa_status_t wrap_psa_cipher_encrypt(mbedtls_svc_key_id_t key, output_length); MBEDTLS_TEST_MEMORY_UNPOISON(input, input_length); MBEDTLS_TEST_MEMORY_UNPOISON(output, output_size); + teardown_test_hooks(); return status; } #define psa_cipher_encrypt(...) wrap_psa_cipher_encrypt(__VA_ARGS__) + +#endif /* MBEDTLS_TEST_MEMORY_CAN_POISON */