Skip to content

Commit

Permalink
Internal change.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 708635543
  • Loading branch information
mjanusz authored and copybara-github committed Dec 21, 2024
1 parent 4bfc8e8 commit 3fc09dc
Showing 1 changed file with 4 additions and 2 deletions.
6 changes: 4 additions & 2 deletions ffn/jax/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -338,15 +338,17 @@ def _get_tf_writer(writers) -> metric_writers.SummaryWriter | None:
# pylint:enable=protected-access


def _get_ocp_args(train_iter: DataIterator) -> DataIterator:
def _get_ocp_args(
train_iter: DataIterator, restore: bool = True
) -> DataIterator:
if isinstance(train_iter, tf.data.Iterator):
return DatasetArgs(train_iter)


def _make_ckpt_args(state, train_iter: DataIterator) -> ocp.args.CheckpointArgs:
return ocp.args.Composite(
train_state=ocp.args.StandardSave(state),
train_iter=_get_ocp_args(train_iter),
train_iter=_get_ocp_args(train_iter, restore=False),
)


Expand Down

0 comments on commit 3fc09dc

Please sign in to comment.