Skip to content

Commit

Permalink
feat(write_flash): retry flashing if chip disconnects
Browse files Browse the repository at this point in the history
  • Loading branch information
peterdragun committed May 15, 2024
1 parent 1deb1c6 commit a15089a
Show file tree
Hide file tree
Showing 2 changed files with 85 additions and 48 deletions.
130 changes: 82 additions & 48 deletions esptool/cmds.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
import itertools

from intelhex import IntelHex
from serial import SerialException

from .bin_image import ELFFile, ImageSegment, LoadFirmwareImage
from .bin_image import (
Expand Down Expand Up @@ -579,56 +580,89 @@ def write_flash(esp, args):
if compress:
uncimage = image
image = zlib.compress(uncimage, 9)
# Decompress the compressed binary a block at a time,
# to dynamically calculate the timeout based on the real write size
decompress = zlib.decompressobj()
blocks = esp.flash_defl_begin(uncsize, len(image), address)
else:
blocks = esp.flash_begin(uncsize, address, begin_rom_encrypted=encrypted)
argfile.seek(0) # in case we need it again
seq = 0
bytes_sent = 0 # bytes sent on wire
bytes_written = 0 # bytes written to flash
t = time.time()

timeout = DEFAULT_TIMEOUT

while len(image) > 0:
print_overwrite(
"Writing at 0x%08x... (%d %%)"
% (address + bytes_written, 100 * (seq + 1) // blocks)
)
sys.stdout.flush()
block = image[0 : esp.FLASH_WRITE_SIZE]
if compress:
# feeding each compressed block into the decompressor lets us
# see block-by-block how much will be written
block_uncompressed = len(decompress.decompress(block))
bytes_written += block_uncompressed
block_timeout = max(
DEFAULT_TIMEOUT,
timeout_per_mb(ERASE_WRITE_TIMEOUT_PER_MB, block_uncompressed),
)
if not esp.IS_STUB:
timeout = (
block_timeout # ROM code writes block to flash before ACKing
original_image = image # Save the whole image in case retry is needed
# Try again if reconnect was successful
for attempt in range(1, esp.WRITE_FLASH_ATTEMPTS + 1):
try:
if compress:
# Decompress the compressed binary a block at a time,
# to dynamically calculate the timeout based on the real write size
decompress = zlib.decompressobj()
blocks = esp.flash_defl_begin(uncsize, len(image), address)
else:
blocks = esp.flash_begin(
uncsize, address, begin_rom_encrypted=encrypted
)
esp.flash_defl_block(block, seq, timeout=timeout)
if esp.IS_STUB:
# Stub ACKs when block is received,
# then writes to flash while receiving the block after it
timeout = block_timeout
else:
# Pad the last block
block = block + b"\xff" * (esp.FLASH_WRITE_SIZE - len(block))
if encrypted:
esp.flash_encrypt_block(block, seq)
argfile.seek(0) # in case we need it again
seq = 0
bytes_sent = 0 # bytes sent on wire
bytes_written = 0 # bytes written to flash
t = time.time()

timeout = DEFAULT_TIMEOUT

while len(image) > 0:
print_overwrite(
"Writing at 0x%08x... (%d %%)"
% (address + bytes_written, 100 * (seq + 1) // blocks)
)
sys.stdout.flush()
block = image[0 : esp.FLASH_WRITE_SIZE]
if compress:
# feeding each compressed block into the decompressor lets us
# see block-by-block how much will be written
block_uncompressed = len(decompress.decompress(block))
bytes_written += block_uncompressed
block_timeout = max(
DEFAULT_TIMEOUT,
timeout_per_mb(
ERASE_WRITE_TIMEOUT_PER_MB, block_uncompressed
),
)
if not esp.IS_STUB:
timeout = block_timeout # ROM code writes block to flash before ACKing
esp.flash_defl_block(block, seq, timeout=timeout)
if esp.IS_STUB:
# Stub ACKs when block is received,
# then writes to flash while receiving the block after it
timeout = block_timeout
else:
# Pad the last block
block = block + b"\xff" * (esp.FLASH_WRITE_SIZE - len(block))
if encrypted:
esp.flash_encrypt_block(block, seq)
else:
esp.flash_block(block, seq)
bytes_written += len(block)
bytes_sent += len(block)
image = image[esp.FLASH_WRITE_SIZE :]
seq += 1
break
except SerialException:
if attempt == esp.WRITE_FLASH_ATTEMPTS or encrypted:
# Already retried once or encrypted mode is disabled because of security reasons
raise
print("\nLost connection, retrying...")
esp._port.close()
print("Waiting for the chip to reconnect", end="")
for _ in range(DEFAULT_CONNECT_ATTEMPTS):
try:
time.sleep(1)
esp._port.open()
print() # Print new line which was suppressed by print(".")
esp.connect()
if esp.IS_STUB:
# Hack to bypass the stub overwrite check
esp.IS_STUB = False
# Reflash stub because chip was reset
esp = esp.run_stub()
image = original_image
break
except SerialException:
print(".", end="")
sys.stdout.flush()
else:
esp.flash_block(block, seq)
bytes_written += len(block)
bytes_sent += len(block)
image = image[esp.FLASH_WRITE_SIZE :]
seq += 1
raise # Reconnect limit reached

if esp.IS_STUB:
# Stub only writes each block to flash after 'ack'ing the receive,
Expand Down
3 changes: 3 additions & 0 deletions esptool/loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -278,6 +278,9 @@ class ESPLoader(object):
# Chip IDs that are no longer supported by esptool
UNSUPPORTED_CHIPS = {6: "ESP32-S3(beta 3)"}

# Number of attempts to write flash data
WRITE_FLASH_ATTEMPTS = 2

def __init__(self, port=DEFAULT_PORT, baud=ESP_ROM_BAUD, trace_enabled=False):
"""Base constructor for ESPLoader bootloader interaction
Expand Down

0 comments on commit a15089a

Please sign in to comment.