From 6a874e062be9e6c7cf8ff19d8aa96d8a77e043b1 Mon Sep 17 00:00:00 2001 From: Connectomics Team Date: Thu, 17 Oct 2024 10:16:00 -0700 Subject: [PATCH] Update TfGrainCheckpointHandler. PiperOrigin-RevId: 686958575 --- ffn/jax/train.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ffn/jax/train.py b/ffn/jax/train.py index 698e9af..4149747 100644 --- a/ffn/jax/train.py +++ b/ffn/jax/train.py @@ -432,7 +432,7 @@ def train_and_evaluate( train_state_path, args=ocp.args.StandardRestore(state) ) checkpointed_state['train_iter'] = iter_handler.restore( - train_iter_path, args + train_iter_path, args=args ) logging.info('Initializing training from %r', config.init_from_cpoint) elif latest_step is not None: