diff --git a/include/mbedtls/ssl.h b/include/mbedtls/ssl.h index 39bea79092..a6ee9a4487 100644 --- a/include/mbedtls/ssl.h +++ b/include/mbedtls/ssl.h @@ -1312,6 +1312,9 @@ struct mbedtls_ssl_session { #if defined(MBEDTLS_SSL_EARLY_DATA) uint32_t MBEDTLS_PRIVATE(max_early_data_size); /*!< maximum amount of early data in tickets */ +#if defined(MBEDTLS_SSL_ALPN) && defined(MBEDTLS_SSL_SRV_C) + char *alpn; /*!< ALPN negotiated in the session */ +#endif #endif #if defined(MBEDTLS_SSL_ENCRYPT_THEN_MAC) diff --git a/library/ssl_tls.c b/library/ssl_tls.c index 681ccab441..d7d26ab063 100644 --- a/library/ssl_tls.c +++ b/library/ssl_tls.c @@ -3735,7 +3735,25 @@ static int ssl_tls12_session_load(mbedtls_ssl_session *session, #if defined(MBEDTLS_SSL_PROTO_TLS1_3) /* Serialization of TLS 1.3 sessions: * - * For more detail, see the description of ssl_session_save(). + * struct { + * opaque hostname<0..2^16-1>; + * uint64 ticket_reception_time; + * uint32 ticket_lifetime; + * opaque ticket<1..2^16-1>; + * } ClientOnlyData; + * + * struct { + * uint32 ticket_age_add; + * uint8 ticket_flags; + * opaque resumption_key<0..255>; + * uint32 max_early_data_size; + * uint16 record_size_limit; + * select ( endpoint ) { + * case client: ClientOnlyData; + * case server: uint64 ticket_creation_time; + * }; + * } serialized_session_tls13; + * */ #if defined(MBEDTLS_SSL_SESSION_TICKETS) MBEDTLS_CHECK_RETURN_CRITICAL @@ -3750,9 +3768,16 @@ static int ssl_tls13_session_save(const mbedtls_ssl_session *session, size_t hostname_len = (session->hostname == NULL) ? 0 : strlen(session->hostname) + 1; #endif + +#if defined(MBEDTLS_SSL_SRV_C) && \ + defined(MBEDTLS_SSL_EARLY_DATA) && defined(MBEDTLS_SSL_ALPN) + const uint8_t alpn_len = (session->alpn == NULL) ? + 0 : (uint8_t) strlen(session->alpn) + 1; +#endif size_t needed = 4 /* ticket_age_add */ + 1 /* ticket_flags */ + 1; /* resumption_key length */ + *olen = 0; if (session->resumption_key_len > MBEDTLS_SSL_TLS1_3_TICKET_RESUMPTION_KEY_LEN) { @@ -3771,6 +3796,15 @@ static int ssl_tls13_session_save(const mbedtls_ssl_session *session, needed += 8; /* ticket_creation_time or ticket_reception_time */ #endif +#if defined(MBEDTLS_SSL_SRV_C) + if (session->endpoint == MBEDTLS_SSL_IS_SERVER) { +#if defined(MBEDTLS_SSL_EARLY_DATA) && defined(MBEDTLS_SSL_ALPN) + needed += 1 /* alpn_len */ + + alpn_len; /* alpn */ +#endif + } +#endif /* MBEDTLS_SSL_SRV_C */ + #if defined(MBEDTLS_SSL_CLI_C) if (session->endpoint == MBEDTLS_SSL_IS_CLIENT) { #if defined(MBEDTLS_SSL_SERVER_NAME_INDICATION) @@ -3813,13 +3847,24 @@ static int ssl_tls13_session_save(const mbedtls_ssl_session *session, p += 2; #endif /* MBEDTLS_SSL_RECORD_SIZE_LIMIT */ -#if defined(MBEDTLS_HAVE_TIME) && defined(MBEDTLS_SSL_SRV_C) +#if defined(MBEDTLS_SSL_SRV_C) if (session->endpoint == MBEDTLS_SSL_IS_SERVER) { +#if defined(MBEDTLS_HAVE_TIME) MBEDTLS_PUT_UINT64_BE((uint64_t) session->ticket_creation_time, p, 0); p += 8; - } #endif /* MBEDTLS_HAVE_TIME */ +#if defined(MBEDTLS_SSL_EARLY_DATA) && defined(MBEDTLS_SSL_ALPN) + *p++ = alpn_len; + if (alpn_len > 0) { + /* save chosen alpn */ + memcpy(p, session->alpn, alpn_len); + p += alpn_len; + } +#endif /* MBEDTLS_SSL_EARLY_DATA && MBEDTLS_SSL_ALPN */ + } +#endif /* MBEDTLS_SSL_SRV_C */ + #if defined(MBEDTLS_SSL_CLI_C) if (session->endpoint == MBEDTLS_SSL_IS_CLIENT) { #if defined(MBEDTLS_SSL_SERVER_NAME_INDICATION) @@ -3894,16 +3939,39 @@ static int ssl_tls13_session_load(mbedtls_ssl_session *session, p += 2; #endif /* MBEDTLS_SSL_RECORD_SIZE_LIMIT */ -#if defined(MBEDTLS_HAVE_TIME) && defined(MBEDTLS_SSL_SRV_C) +#if defined(MBEDTLS_SSL_SRV_C) if (session->endpoint == MBEDTLS_SSL_IS_SERVER) { +#if defined(MBEDTLS_HAVE_TIME) if (end - p < 8) { return MBEDTLS_ERR_SSL_BAD_INPUT_DATA; } session->ticket_creation_time = MBEDTLS_GET_UINT64_BE(p, 0); p += 8; - } #endif /* MBEDTLS_HAVE_TIME */ +#if defined(MBEDTLS_SSL_EARLY_DATA) && defined(MBEDTLS_SSL_ALPN) + uint8_t alpn_len; + + if (end - p < 1) { + return MBEDTLS_ERR_SSL_BAD_INPUT_DATA; + } + alpn_len = *p++; + + if (end - p < alpn_len) { + return MBEDTLS_ERR_SSL_BAD_INPUT_DATA; + } + if (alpn_len > 0) { + session->alpn = mbedtls_calloc(alpn_len, sizeof(char)); + if (session->alpn == NULL) { + return MBEDTLS_ERR_SSL_ALLOC_FAILED; + } + memcpy(session->alpn, p, alpn_len); + p += alpn_len; + } +#endif /* MBEDTLS_SSL_EARLY_DATA && MBEDTLS_SSL_ALPN */ + } +#endif /* MBEDTLS_SSL_SRV_C */ + #if defined(MBEDTLS_SSL_CLI_C) if (session->endpoint == MBEDTLS_SSL_IS_CLIENT) { #if defined(MBEDTLS_SSL_SERVER_NAME_INDICATION) @@ -4848,6 +4916,10 @@ void mbedtls_ssl_session_free(mbedtls_ssl_session *session) #if defined(MBEDTLS_SSL_PROTO_TLS1_3) && \ defined(MBEDTLS_SSL_SERVER_NAME_INDICATION) mbedtls_free(session->hostname); +#endif +#if defined(MBEDTLS_SSL_EARLY_DATA) && defined(MBEDTLS_SSL_ALPN) && \ + defined(MBEDTLS_SSL_SRV_C) + mbedtls_free(session->alpn); #endif mbedtls_free(session->ticket); #endif diff --git a/library/ssl_tls13_server.c b/library/ssl_tls13_server.c index 887c5c6c8f..291d64500d 100644 --- a/library/ssl_tls13_server.c +++ b/library/ssl_tls13_server.c @@ -467,7 +467,17 @@ static int ssl_tls13_session_copy_ticket(mbedtls_ssl_session *dst, #if defined(MBEDTLS_SSL_EARLY_DATA) dst->max_early_data_size = src->max_early_data_size; -#endif + +#if defined(MBEDTLS_SSL_ALPN) + if (src->alpn != NULL) { + dst->alpn = mbedtls_calloc(strlen(src->alpn) + 1, sizeof(char)); + if (dst->alpn == NULL) { + return MBEDTLS_ERR_SSL_ALLOC_FAILED; + } + memcpy(dst->alpn, src->alpn, strlen(src->alpn) + 1); + } +#endif /* MBEDTLS_SSL_ALPN */ +#endif /* MBEDTLS_SSL_EARLY_DATA*/ return 0; } @@ -3137,6 +3147,16 @@ static int ssl_tls13_prepare_new_session_ticket(mbedtls_ssl_context *ssl, MBEDTLS_SSL_PRINT_TICKET_FLAGS(4, session->ticket_flags); +#if defined(MBEDTLS_SSL_EARLY_DATA) && defined(MBEDTLS_SSL_ALPN) + if (ssl->alpn_chosen != NULL) { + session->alpn = mbedtls_calloc(strlen(ssl->alpn_chosen) + 1, sizeof(char)); + if (session->alpn == NULL) { + return MBEDTLS_ERR_SSL_ALLOC_FAILED; + } + memcpy(session->alpn, ssl->alpn_chosen, strlen(ssl->alpn_chosen) + 1); + } +#endif + /* Generate ticket_age_add */ if ((ret = ssl->conf->f_rng(ssl->conf->p_rng, (unsigned char *) &session->ticket_age_add, diff --git a/tests/src/test_helpers/ssl_helpers.c b/tests/src/test_helpers/ssl_helpers.c index 56e03f1090..89c1bbf522 100644 --- a/tests/src/test_helpers/ssl_helpers.c +++ b/tests/src/test_helpers/ssl_helpers.c @@ -1793,7 +1793,14 @@ int mbedtls_test_ssl_tls13_populate_session(mbedtls_ssl_session *session, #if defined(MBEDTLS_SSL_EARLY_DATA) session->max_early_data_size = 0x87654321; -#endif +#if defined(MBEDTLS_SSL_ALPN) && defined(MBEDTLS_SSL_SRV_C) + session->alpn = mbedtls_calloc(strlen("ALPNExample")+1, sizeof(char)); + if (session->alpn == NULL) { + return -1; + } + strcpy(session->alpn, "ALPNExample"); +#endif /* MBEDTLS_SSL_ALPN && MBEDTLS_SSL_SRV_C */ +#endif /* MBEDTLS_SSL_EARLY_DATA */ #if defined(MBEDTLS_HAVE_TIME) && defined(MBEDTLS_SSL_SRV_C) if (session->endpoint == MBEDTLS_SSL_IS_SERVER) { diff --git a/tests/suites/test_suite_ssl.function b/tests/suites/test_suite_ssl.function index 8cf2105a52..da07f2c62f 100644 --- a/tests/suites/test_suite_ssl.function +++ b/tests/suites/test_suite_ssl.function @@ -2104,6 +2104,15 @@ void ssl_serialize_session_save_load(int ticket_len, char *crt_file, #if defined(MBEDTLS_SSL_EARLY_DATA) TEST_ASSERT( original.max_early_data_size == restored.max_early_data_size); +#if defined(MBEDTLS_SSL_ALPN) && defined(MBEDTLS_SSL_SRV_C) + if (endpoint_type == MBEDTLS_SSL_IS_SERVER) { + TEST_ASSERT(original.alpn != NULL); + TEST_ASSERT(restored.alpn != NULL); + TEST_ASSERT(memcmp(original.alpn, + restored.alpn, + strlen(original.alpn)) == 0); + } +#endif #endif #if defined(MBEDTLS_SSL_SESSION_TICKETS) && defined(MBEDTLS_SSL_CLI_C)