diff --git a/library/ssl_tls.c b/library/ssl_tls.c index a62d4e1962..d8fbd77b91 100644 --- a/library/ssl_tls.c +++ b/library/ssl_tls.c @@ -8942,36 +8942,43 @@ static int mbedtls_ssl_tls12_export_keying_material(const mbedtls_ssl_context *s const int use_context) { int ret = 0; - size_t prf_input_len = use_context ? 64 + 2 + context_len : 64; unsigned char *prf_input = NULL; - if (use_context && context_len >= (1 << 16)) { - ret = MBEDTLS_ERR_SSL_BAD_INPUT_DATA; - goto exit; - } - - prf_input = mbedtls_calloc(prf_input_len, sizeof(unsigned char)); - if (prf_input == NULL) { - ret = MBEDTLS_ERR_SSL_ALLOC_FAILED; - goto exit; - } - /* The input to the PRF is client_random, then server_random. * If a context is provided, this is then followed by the context length * as a 16-bit big-endian integer, and then the context itself. */ - memcpy(prf_input, ssl->transform->randbytes + 32, 32); - memcpy(prf_input + 32, ssl->transform->randbytes, 32); + const size_t randbytes_len = MBEDTLS_CLIENT_HELLO_RANDOM_LEN + MBEDTLS_SERVER_HELLO_RANDOM_LEN; + size_t prf_input_len = randbytes_len; if (use_context) { - prf_input[64] = (unsigned char) ((context_len >> 8) & 0xff); - prf_input[65] = (unsigned char) (context_len & 0xff); - memcpy(prf_input + 66, context, context_len); + if (context_len > UINT16_MAX) { + return MBEDTLS_ERR_SSL_BAD_INPUT_DATA; + } + + /* This does not overflow a 32-bit size_t because the current value of + * prf_input_len is 64 (length of client_random + server_random) and + * context_len fits into two bytes (checked above). */ + prf_input_len += sizeof(uint16_t) + context_len; } - ret = tls_prf_generic(hash_alg, ssl->session->master, 48, + + prf_input = mbedtls_calloc(prf_input_len, sizeof(unsigned char)); + if (prf_input == NULL) { + return MBEDTLS_ERR_SSL_ALLOC_FAILED; + } + + memcpy(prf_input, + ssl->transform->randbytes + MBEDTLS_SERVER_HELLO_RANDOM_LEN, + MBEDTLS_CLIENT_HELLO_RANDOM_LEN); + memcpy(prf_input + MBEDTLS_CLIENT_HELLO_RANDOM_LEN, + ssl->transform->randbytes, + MBEDTLS_SERVER_HELLO_RANDOM_LEN); + if (use_context) { + MBEDTLS_PUT_UINT16_BE(context_len, prf_input, randbytes_len); + memcpy(prf_input + randbytes_len + sizeof(uint16_t), context, context_len); + } + ret = tls_prf_generic(hash_alg, ssl->session->master, sizeof(ssl->session->master), label, label_len, prf_input, prf_input_len, out, key_len); - -exit: mbedtls_free(prf_input); return ret; } @@ -8991,7 +8998,11 @@ static int mbedtls_ssl_tls13_export_keying_material(mbedtls_ssl_context *ssl, const size_t hash_len = PSA_HASH_LENGTH(hash_alg); const unsigned char *secret = ssl->session->app_secrets.exporter_master_secret; - if (key_len > 0xffff || label_len > 250) { + /* Check that the label and key_len fit into the HkdfLabel struct as defined + * in RFC 8446, Section 7.1. key_len must fit into an uint16 and the label + * must be at most 250 bytes long. (The struct allows up to 256 bytes for + * the label, but it is prefixed with "tls13 ".) */ + if (key_len > UINT16_MAX || label_len > 250) { return MBEDTLS_ERR_SSL_BAD_INPUT_DATA; }