diff --git a/library/lmots.c b/library/lmots.c index c51cb41ece..404aa80da6 100644 --- a/library/lmots.c +++ b/library/lmots.c @@ -401,8 +401,11 @@ int mbedtls_lmots_import_public_key(mbedtls_lmots_public_t *ctx, return MBEDTLS_ERR_LMS_BAD_INPUT_DATA; } - ctx->params.type = (mbedtls_lmots_algorithm_type_t) - MBEDTLS_GET_UINT32_BE(key, MBEDTLS_LMOTS_SIG_TYPE_OFFSET); + uint32_t type = MBEDTLS_GET_UINT32_BE(key, MBEDTLS_LMOTS_SIG_TYPE_OFFSET); + if (type != (uint32_t) MBEDTLS_LMOTS_SHA256_N32_W8) { + return MBEDTLS_ERR_LMS_BAD_INPUT_DATA; + } + ctx->params.type = (mbedtls_lmots_algorithm_type_t) type; if (key_len != MBEDTLS_LMOTS_PUBLIC_KEY_LEN(ctx->params.type)) { return MBEDTLS_ERR_LMS_BAD_INPUT_DATA; diff --git a/library/lms.c b/library/lms.c index 4bdfd434ad..8d1840cabf 100644 --- a/library/lms.c +++ b/library/lms.c @@ -239,29 +239,25 @@ void mbedtls_lms_public_free(mbedtls_lms_public_t *ctx) int mbedtls_lms_import_public_key(mbedtls_lms_public_t *ctx, const unsigned char *key, size_t key_size) { - mbedtls_lms_algorithm_type_t type; - mbedtls_lmots_algorithm_type_t otstype; - if (key_size < 4) { return MBEDTLS_ERR_LMS_BAD_INPUT_DATA; } - type = (mbedtls_lms_algorithm_type_t) MBEDTLS_GET_UINT32_BE(key, PUBLIC_KEY_TYPE_OFFSET); - if (type != MBEDTLS_LMS_SHA256_M32_H10) { + uint32_t type = MBEDTLS_GET_UINT32_BE(key, PUBLIC_KEY_TYPE_OFFSET); + if (type != (uint32_t) MBEDTLS_LMS_SHA256_M32_H10) { return MBEDTLS_ERR_LMS_BAD_INPUT_DATA; } - ctx->params.type = type; + ctx->params.type = (mbedtls_lms_algorithm_type_t) type; if (key_size != MBEDTLS_LMS_PUBLIC_KEY_LEN(ctx->params.type)) { return MBEDTLS_ERR_LMS_BAD_INPUT_DATA; } - otstype = (mbedtls_lmots_algorithm_type_t) - MBEDTLS_GET_UINT32_BE(key, PUBLIC_KEY_OTSTYPE_OFFSET); - if (otstype != MBEDTLS_LMOTS_SHA256_N32_W8) { + uint32_t otstype = MBEDTLS_GET_UINT32_BE(key, PUBLIC_KEY_OTSTYPE_OFFSET); + if (otstype != (uint32_t) MBEDTLS_LMOTS_SHA256_N32_W8) { return MBEDTLS_ERR_LMS_BAD_INPUT_DATA; } - ctx->params.otstype = otstype; + ctx->params.otstype = (mbedtls_lmots_algorithm_type_t) otstype; memcpy(ctx->params.I_key_identifier, key + PUBLIC_KEY_I_KEY_ID_OFFSET,