fix(save_segment): Adds segment len check the same as bootloader does

This commit is contained in:
Konstantin Kondrashov
2024-11-05 16:45:27 +02:00
parent 7681ec0b7e
commit a6bceb7207

View File

@@ -270,13 +270,24 @@ class BaseFirmwareImage(object):
) )
return segment_data return segment_data
def save_segment(self, f, segment, checksum=None): def save_segment(self, f, segment, checksum=None, segment_name=None):
""" """
Save the next segment to the image file, Save the next segment to the image file,
return next checksum value if provided return next checksum value if provided
""" """
segment_data = self.maybe_patch_segment_data(f, segment.data) segment_data = self.maybe_patch_segment_data(f, segment.data)
f.write(struct.pack("<II", segment.addr, len(segment_data))) segment_len = len(segment_data)
segment_name = segment_name if segment_name is not None else ""
if segment_len & 3:
raise FatalError(
f"Invalid {segment_name} segment length {segment_len:#x}. It has to be multiple of 4."
)
SIXTEEN_MB = 0x1000000
if segment_len >= SIXTEEN_MB:
raise FatalError(
f"Invalid {segment_name} segment length {segment_len:#x}. The 16 MB limit has been exceeded."
)
f.write(struct.pack("<II", segment.addr, segment_len))
f.write(segment_data) f.write(segment_data)
if checksum is not None: if checksum is not None:
return ESPLoader.checksum(segment_data, checksum) return ESPLoader.checksum(segment_data, checksum)
@@ -293,7 +304,8 @@ class BaseFirmwareImage(object):
segment_len_remainder = segment_end_pos % self.IROM_ALIGN segment_len_remainder = segment_end_pos % self.IROM_ALIGN
if segment_len_remainder < 0x24: if segment_len_remainder < 0x24:
segment.data += b"\x00" * (0x24 - segment_len_remainder) segment.data += b"\x00" * (0x24 - segment_len_remainder)
return self.save_segment(f, segment, checksum) segment_name = getattr(segment, "name", None)
return self.save_segment(f, segment, checksum, segment_name)
def read_checksum(self, f): def read_checksum(self, f):
"""Return ESPLoader checksum from end of just-read image""" """Return ESPLoader checksum from end of just-read image"""
@@ -748,7 +760,7 @@ class ESP32FirmwareImage(BaseFirmwareImage):
# and checksum (ROM bootloader will only care for RAM segments and its # and checksum (ROM bootloader will only care for RAM segments and its
# correct checksums) # correct checksums)
for segment in ram_segments: for segment in ram_segments:
checksum = self.save_segment(f, segment, checksum) checksum = self.save_segment(f, segment, checksum, segment.name)
total_segments += 1 total_segments += 1
self.append_checksum(f, checksum) self.append_checksum(f, checksum)
@@ -769,7 +781,7 @@ class ESP32FirmwareImage(BaseFirmwareImage):
pad_len -= self.ROM_LOADER.BOOTLOADER_FLASH_OFFSET pad_len -= self.ROM_LOADER.BOOTLOADER_FLASH_OFFSET
pad_segment = ImageSegment(0, b"\x00" * pad_len, f.tell()) pad_segment = ImageSegment(0, b"\x00" * pad_len, f.tell())
self.save_segment(f, pad_segment) self.save_segment(f, pad_segment, None, segment.name)
total_segments += 1 total_segments += 1
# check the alignment # check the alignment
assert (f.tell() + 8 + self.ROM_LOADER.BOOTLOADER_FLASH_OFFSET) % ( assert (f.tell() + 8 + self.ROM_LOADER.BOOTLOADER_FLASH_OFFSET) % (
@@ -793,7 +805,9 @@ class ESP32FirmwareImage(BaseFirmwareImage):
ram_segments.pop(0) ram_segments.pop(0)
else: else:
pad_segment = ImageSegment(0, b"\x00" * pad_len, f.tell()) pad_segment = ImageSegment(0, b"\x00" * pad_len, f.tell())
checksum = self.save_segment(f, pad_segment, checksum) checksum = self.save_segment(
f, pad_segment, checksum, segment.name
)
total_segments += 1 total_segments += 1
else: else:
# write the flash segment # write the flash segment
@@ -806,7 +820,7 @@ class ESP32FirmwareImage(BaseFirmwareImage):
# flash segments all written, so write any remaining RAM segments # flash segments all written, so write any remaining RAM segments
for segment in ram_segments: for segment in ram_segments:
checksum = self.save_segment(f, segment, checksum) checksum = self.save_segment(f, segment, checksum, segment.name)
total_segments += 1 total_segments += 1
if self.secure_pad: if self.secure_pad: