-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathexperiment.py
422 lines (341 loc) · 16.5 KB
/
experiment.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
# Copyright 2022 The VDM Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from abc import ABC, abstractmethod
import functools
import os
from typing import Any, Tuple
from absl import logging
import chex
from clu import periodic_actions
from clu import parameter_overview
from clu import metric_writers
from clu import checkpoint
from flax.core.frozen_dict import unfreeze, FrozenDict
import flax.jax_utils as flax_utils
import flax
from jax._src.random import PRNGKey
import jax
import jax.numpy as jnp
import ml_collections
import numpy as np
import optax
from optax._src import base
import tensorflow as tf
from tqdm import tqdm
import io
import train_state
import utils as utils
import dataset as dataset
class Experiment(ABC):
    """Boilerplate for training and evaluating VDM models.

    Subclasses implement `get_model_and_params`, `sample_fn`, `loss_fn` and
    `likelihood_fn`. This base class builds the data iterators, the train
    state, optimizer and learning-rate schedule, optionally warm-starts from a
    checkpoint, and provides the pmapped train/eval steps plus the
    `train_and_evaluate`, `evaluate` and `sample` drivers.
    """

    def __init__(self, config: ml_collections.ConfigDict):
        """Set up data, model, train state and the distributed step functions.

        Args:
          config: experiment configuration. Reads `training.seed`, the
            `optimizer` section (via `get_optimizer` / `get_lr_schedule`) and
            the optional `ckpt_restore_dir`.

        Raises:
          FileNotFoundError: if `ckpt_restore_dir` is set but contains no
            checkpoint to restore from.
        """
        self.config = config

        # Set seed before initializing model.
        seed = config.training.seed
        self.rng = utils.with_verbosity("ERROR", lambda: jax.random.PRNGKey(seed))

        # Initialize dataset.
        logging.warning("=== Initializing dataset ===")
        self.rng, data_rng = jax.random.split(self.rng)
        self.train_iter, self.eval_iter = dataset.create_dataset(config, data_rng)

        # Initialize model.
        logging.warning("=== Initializing model ===")
        self.rng, model_rng = jax.random.split(self.rng)
        self.model, params = self.get_model_and_params(model_rng)
        parameter_overview.log_parameter_overview(params)

        # Initialize train state. The optimizer *factory* is passed (not an
        # optimizer instance); presumably TrainState calls it with the learning
        # rate — confirm in train_state.py.
        logging.info("=== Initializing train state ===")
        self.state = train_state.TrainState.create(
            apply_fn=self.model.apply, variables=params, optax_optimizer=self.get_optimizer
        )
        self.lr_schedule = self.get_lr_schedule()

        # Optionally warm-start from another run. The string "None" is the
        # "not set" sentinel. restore_partial only overwrites entries whose keys
        # already exist in the freshly created state.
        ckpt_restore_dir = self.config.get("ckpt_restore_dir", "None")
        if ckpt_restore_dir != "None":
            ckpt_restore = checkpoint.Checkpoint(ckpt_restore_dir)
            checkpoint_to_restore = ckpt_restore.get_latest_checkpoint_to_restore_from()
            # Explicit check rather than `assert`, which is stripped under `python -O`.
            if not checkpoint_to_restore:
                raise FileNotFoundError(
                    f"No checkpoint to restore from in ckpt_restore_dir={ckpt_restore_dir!r}"
                )
            state_restore_dict = ckpt_restore.restore_dict(checkpoint_to_restore)
            self.state = restore_partial(self.state, state_restore_dict)
            del state_restore_dict, ckpt_restore, checkpoint_to_restore

        # Initialize train/eval step.
        logging.info("=== Initializing train/eval step ===")
        self.rng, train_rng = jax.random.split(self.rng)
        # On each device, `lax.scan` runs train_step once per slice along the
        # leading axis of the per-device batch; the result is pmapped with axis
        # name "batch".
        self.p_train_step = functools.partial(self.train_step, train_rng)
        self.p_train_step = functools.partial(jax.lax.scan, self.p_train_step)
        self.p_train_step = jax.pmap(self.p_train_step, "batch")

        self.rng, eval_rng, sample_rng = jax.random.split(self.rng, 3)
        self.p_eval_step = functools.partial(self.eval_step, eval_rng)
        self.p_eval_step = jax.pmap(self.p_eval_step, "batch")

        # Note: building the sampler consumes one batch from eval_iter.
        self.p_sample = self._make_sampler(sample_rng)
        logging.info("=== Done with Experiment.__init__ ===")

    @staticmethod
    def _merge_leading_axes(fn, *args, **kwargs):
        """Call `fn`, then merge the first two axes of the returned samples.

        `fn` must return `(samples, nfe)` where `samples` is expected to be
        rank 5; the samples come back with shape
        `(shape[0] * shape[1], shape[2], shape[3], shape[4])`. The leading
        axes are presumably (device, per-device batch) — confirm against the
        `sample_fn` implementations.
        """
        samples, nfe = fn(*args, **kwargs)
        shape = samples.shape
        samples = samples.reshape((-1, shape[2], shape[3], shape[4]))
        return samples, nfe

    def _make_sampler(self, rng):
        """Return a callable `(**kwargs) -> (samples, nfe)` bound to `rng`.

        The callable forwards its keyword arguments (e.g. `params=`) to
        `sample_fn`, together with `dummy_inputs` taken from the "images" of
        one freshly drawn `self.eval_iter` batch, and merges the two leading
        axes of the resulting samples.
        """
        sample = functools.partial(
            self.sample_fn,
            dummy_inputs=next(self.eval_iter)["images"],
            rng=rng,
        )
        return functools.partial(self._merge_leading_axes, fn=sample)

    def get_lr_schedule(self):
        """Return the learning-rate schedule as an optax schedule function.

        Linear warm-up from 1e-8 to `optimizer.learning_rate` over
        `training.num_steps_lr_warmup` steps. If `optimizer.lr_decay` is set,
        this is followed by a linear decay back to 1e-8 that ends at
        `training.num_steps_train`; otherwise the warm-up schedule is used alone.
        """
        learning_rate = self.config.optimizer.learning_rate
        config_train = self.config.training
        warmup_fn = optax.linear_schedule(
            init_value=1e-8, end_value=learning_rate, transition_steps=config_train.num_steps_lr_warmup
        )
        if self.config.optimizer.lr_decay:
            decay_fn = optax.linear_schedule(
                init_value=learning_rate,
                end_value=1e-8,
                transition_steps=config_train.num_steps_train - config_train.num_steps_lr_warmup,
            )
            schedule_fn = optax.join_schedules(
                schedules=[warmup_fn, decay_fn], boundaries=[config_train.num_steps_lr_warmup]
            )
        else:
            schedule_fn = warmup_fn
        return schedule_fn

    def get_optimizer(self, lr: float) -> base.GradientTransformation:
        """Get an optax optimizer. Can be overridden.

        Builds AdamW (`optimizer.name == "adamw"`, extra kwargs from
        `optimizer.args`) whose weight-decay mask excludes parameters named
        "bias" or "pos_embedding" and the "scale" of LayerNorm_0/LayerNorm_1.
        If `optimizer.gradient_clip_norm` is present, global-norm clipping is
        chained in front. If `optimizer.gradient_accumulation > 1`, the result
        is wrapped in `optax.MultiSteps`.

        Raises:
          ValueError: if `optimizer.name` is not supported.
        """
        config = self.config.optimizer

        def decay_mask_fn(params):
            # True where weight decay should be applied.
            flat_params = flax.traverse_util.flatten_dict(unfreeze(params))
            flat_mask = {
                path: (
                    path[-1] != "pos_embedding"
                    and path[-1] != "bias"
                    and path[-2:] not in [("LayerNorm_0", "scale"), ("LayerNorm_1", "scale")]
                )
                for path in flat_params
            }
            return FrozenDict(flax.traverse_util.unflatten_dict(flat_mask))

        if config.name == "adamw":
            optimizer = optax.adamw(
                learning_rate=lr,
                mask=decay_mask_fn,
                **config.args,
            )
            if hasattr(config, "gradient_clip_norm"):
                clip = optax.clip_by_global_norm(config.gradient_clip_norm)
                optimizer = optax.chain(clip, optimizer)
        else:
            raise ValueError(f"Unknown optimizer: {config.name!r}")

        if config.gradient_accumulation > 1:
            optimizer = optax.MultiSteps(optimizer, every_k_schedule=config.gradient_accumulation)
        return optimizer

    @abstractmethod
    def get_model_and_params(self, rng: PRNGKey):
        """Return the model and initialized parameters."""
        ...

    @abstractmethod
    def sample_fn(self, *, dummy_inputs, rng, params) -> chex.Array:
        """Generate a batch of samples in [0, 255]."""
        ...

    @abstractmethod
    def loss_fn(self, params, batch, rng, is_train) -> Tuple[float, Any]:
        """Loss function and metrics."""
        ...

    @abstractmethod
    def likelihood_fn(self, *, inputs, rng, params) -> float:
        """Likelihood function."""
        ...

    def train_and_evaluate(self, workdir: str):
        """Train until `training.num_steps_train` with periodic eval and checkpoints.

        Resumes from the latest checkpoint under `<workdir>/checkpoints` when
        one exists. Each loop iteration performs `training.substeps` optimizer
        steps via the scanned/pmapped `p_train_step`. Training scalars are
        written every `training.steps_per_logging` steps; eval (plus a sample
        grid) runs every `training.steps_per_eval` steps and at step 1000; a
        checkpoint is saved every `training.steps_per_save` steps. All three
        also run after the final iteration.
        """
        logging.warning("=== Experiment.train_and_evaluate() ===")
        logging.info("Workdir: %s", workdir)
        config = self.config.training
        logging.info("num_steps_train=%d", config.num_steps_train)

        # Get train state
        state = self.state

        # Set up checkpointing of the model and the input pipeline.
        checkpoint_dir = os.path.join(workdir, "checkpoints")
        ckpt = checkpoint.MultihostCheckpoint(checkpoint_dir, max_to_keep=5)
        checkpoint_to_restore = ckpt.get_latest_checkpoint_to_restore_from()
        if checkpoint_to_restore:
            state = ckpt.restore_or_initialize(state, checkpoint_to_restore)
        initial_step = int(state.step)

        # Distribute training: replicate the state onto every local device.
        state = flax_utils.replicate(state)

        # Create logger/writer
        writer = utils.create_custom_writer(workdir, jax.process_index())
        # NOTE(review): logging_writer is never used below; kept only in case
        # constructing it has side effects — remove if it does not.
        logging_writer = utils.CustomLoggingWriter()
        if initial_step == 0:
            writer.write_hparams(dict(self.config))

        hooks = []
        report_progress = periodic_actions.ReportProgress(
            num_train_steps=config.num_steps_train, writer=writer
        )
        if jax.process_index() == 0:
            hooks += [report_progress]
            if config.profile:
                hooks += [periodic_actions.Profile(num_profile_steps=5, logdir=workdir)]

        step = initial_step
        substeps = config.substeps

        def avg_over_substeps(x):
            # Training metrics are stacked along the scanned substep axis.
            assert x.shape[0] == substeps
            return float(x.mean(axis=0))

        with metric_writers.ensure_flushes(writer):
            logging.info("=== Start of training ===")
            # `step` is the number of optimizer steps completed so far; it
            # advances by `substeps` per loop iteration.
            while step < config.num_steps_train:
                is_last_step = step + substeps >= config.num_steps_train
                # One training step
                with jax.profiler.StepTraceAnnotation("train", step_num=step):
                    batch = jax.tree_util.tree_map(jnp.asarray, next(self.train_iter))
                    state, _train_metrics = self.p_train_step(state, batch)

                # Quick indication that training is happening.
                logging.log_first_n(logging.WARNING, "Finished training step %d.", 3, step)
                for h in hooks:
                    h(step)

                new_step = int(state.step[0])
                assert new_step == step + substeps
                step = new_step

                if step % config.steps_per_logging == 0 or is_last_step:
                    logging.debug("=== Writing scalars ===")
                    metrics = flax_utils.unreplicate(_train_metrics["scalars"])
                    metrics = jax.tree_util.tree_map(avg_over_substeps, metrics)
                    writer.write_scalars(step, metrics)

                # step == 1000 is a hard-coded extra (early) evaluation point.
                if step % config.steps_per_eval == 0 or is_last_step or step == 1000:
                    logging.debug("=== Running eval ===")
                    with report_progress.timed("eval"):
                        eval_metrics = []
                        for eval_step in range(config.num_steps_eval):
                            batch = next(self.eval_iter)
                            batch = jax.tree_util.tree_map(jnp.asarray, batch)
                            eval_out = self.p_eval_step(
                                state.ema_params, batch, flax_utils.replicate(eval_step)
                            )
                            eval_metrics.append(eval_out["scalars"])

                        # average over eval metrics
                        eval_metrics = utils.get_metrics(eval_metrics)
                        eval_metrics = jax.tree_util.tree_map(jnp.mean, eval_metrics)
                        writer.write_scalars(step, eval_metrics)

                        # Log the image grids from the last eval batch plus a
                        # freshly sampled grid from the EMA parameters.
                        images = flax_utils.unreplicate(eval_out)["images"]
                        samples, _ = self.p_sample(params=flax_utils.unreplicate(state.ema_params))
                        samples = utils.generate_image_grids(samples)[None, :, :, :]
                        images["samples"] = samples.astype(np.uint8)
                        writer.write_images(step, images)

                if step % config.steps_per_save == 0 or is_last_step:
                    with report_progress.timed("checkpoint"):
                        ckpt.save(flax_utils.unreplicate(state))

    def evaluate(self, logdir, checkpoint_dir):
        """Perform one evaluation of the EMA parameters stored in `checkpoint_dir`.

        Runs `training.num_steps_eval` eval batches, writes the averaged eval
        scalars and a grid of samples under `<logdir>/eval`, and prints running
        means of the SDE (`eval_bpd` from `loss_fn`) and ODE (`likelihood_fn`)
        likelihood estimates.
        """
        logging.info("=== Experiment.evaluate() ===")
        ckpt = checkpoint.Checkpoint(checkpoint_dir)
        state_dict = ckpt.restore_dict()
        params = flax.core.FrozenDict(state_dict["ema_params"])
        step = int(state_dict["step"])

        # Distribute evaluation: replicate params onto every local device.
        params_rep = flax_utils.replicate(params)

        eval_logdir = os.path.join(logdir, "eval")
        tf.io.gfile.makedirs(eval_logdir)
        writer = metric_writers.create_default_writer(eval_logdir, just_logging=jax.process_index() > 0)

        # The default writer may buffer writes asynchronously; make sure
        # everything is flushed before returning (as train_and_evaluate does).
        with metric_writers.ensure_flushes(writer):
            eval_metrics = []
            ode_bpds = []
            sde_bpds = []
            for eval_step in tqdm(range(self.config.training.num_steps_eval)):
                batch = next(self.eval_iter)
                batch = jax.tree_util.tree_map(jnp.asarray, batch)
                metrics = self.p_eval_step(params_rep, batch, flax_utils.replicate(eval_step))
                # One fresh rng key per local device for the likelihood computation.
                self.rng, *step_rng = jax.random.split(self.rng, jax.local_device_count() + 1)
                step_rng = jnp.asarray(step_rng)
                ode_bpd = self.likelihood_fn(inputs=batch, rng=step_rng, params=params)
                ode_bpds.append(jnp.mean(ode_bpd.reshape((-1,))))
                sde_bpds.append(jnp.mean(metrics["scalars"]["eval_bpd"].reshape((-1,))))
                eval_metrics.append(metrics["scalars"])
                # Running means so far.
                print("SDE likelihood:", np.mean(sde_bpds))
                print("ODE likelihood:", np.mean(ode_bpds))

            # average over eval metrics
            eval_metrics = utils.get_metrics(eval_metrics)
            eval_metrics = jax.tree_util.tree_map(jnp.mean, eval_metrics)
            writer.write_scalars(step, eval_metrics)
            print("ode_bpd", np.mean(ode_bpds))

            # sample a batch of images
            samples, _ = self.p_sample(params=params)
            samples = utils.generate_image_grids(samples)[None, :, :, :]
            samples = {"samples": samples.astype(np.uint8)}
            writer.write_images(step, samples)

    def sample(self, logdir, checkpoint_dir, num_batches: int = 200):
        """Draw `num_batches` batches of samples from the stored EMA params.

        Each batch is cast to uint8 and saved as a compressed `.npz` (key
        "samples") at `<logdir>/samples/samples_{i}.npz`. The mean of the
        `nfe` values returned by `sample_fn` is printed at the end.
        """
        logging.info("=== Experiment.sample() ===")
        ckpt = checkpoint.Checkpoint(checkpoint_dir)
        state_dict = ckpt.restore_dict()
        params = flax.core.FrozenDict(state_dict["ema_params"])
        eval_logdir = os.path.join(logdir, "samples")
        tf.io.gfile.makedirs(eval_logdir)

        rng = self.rng
        nfes = []
        for i in tqdm(range(num_batches)):
            rng, sample_rng = jax.random.split(rng)
            # Each sampler draws its own dummy_inputs batch from eval_iter.
            samples, nfe = self._make_sampler(sample_rng)(params=params)
            nfes.append(nfe)
            samples = samples.astype(np.uint8)
            # Serialize first so a failure does not leave an empty file behind.
            io_buffer = io.BytesIO()
            np.savez_compressed(io_buffer, samples=samples)
            with tf.io.gfile.GFile(os.path.join(eval_logdir, f"samples_{i}.npz"), "wb") as fout:
                fout.write(io_buffer.getvalue())
        print("avg nfe", np.mean(nfes))

    def train_step(self, base_rng, state, batch):
        """One optimizer step on this device's slice of the batch.

        Runs under `pmap` (axis "batch") and `lax.scan`. The rng is folded with
        the device index and the current step. Gradients and scalar metrics are
        averaged across devices. `optimizer.ema_rate` is passed to
        `apply_gradients` only when `(step + 1)` is a multiple of
        `optimizer.gradient_accumulation`, otherwise 1.0 is passed (presumably
        leaving the EMA unchanged — confirm in train_state.py).

        Returns:
          `(new_state, metrics)`; scalar metric keys are prefixed with
          "train_" and each image entry is converted to an image grid.
        """
        rng = jax.random.fold_in(base_rng, jax.lax.axis_index("batch"))
        rng = jax.random.fold_in(rng, state.step)

        grad_fn = jax.value_and_grad(self.loss_fn, has_aux=True)
        (_, metrics), grads = grad_fn(state.params, batch, rng=rng, is_train=True)
        grads = jax.lax.pmean(grads, "batch")

        learning_rate = self.lr_schedule(state.step)
        new_state = state.apply_gradients(
            grads=grads,
            lr=learning_rate,
            ema_rate=jax.lax.select(
                (state.step + 1) % self.config.optimizer.gradient_accumulation == 0, self.config.optimizer.ema_rate, 1.0
            ),
        )

        metrics["scalars"] = jax.tree_util.tree_map(lambda x: jax.lax.pmean(x, axis_name="batch"), metrics["scalars"])
        metrics["scalars"] = {"train_" + k: v for (k, v) in metrics["scalars"].items()}
        metrics["images"] = jax.tree_util.tree_map(
            lambda x: utils.generate_image_grids(x)[None, :, :, :], metrics["images"]
        )
        return new_state, metrics

    def eval_step(self, base_rng, params, batch, eval_step=0):
        """Compute eval metrics for this device's batch (runs under `pmap`, axis "batch").

        The rng is folded with the device index and `eval_step`. Scalar
        metrics are averaged across devices and prefixed with "eval_", and
        each image entry is converted to an image grid.
        """
        rng = jax.random.fold_in(base_rng, jax.lax.axis_index("batch"))
        rng = jax.random.fold_in(rng, eval_step)

        _, metrics = self.loss_fn(params, batch, rng=rng, is_train=False)

        # summarize metrics
        metrics["scalars"] = jax.lax.pmean(metrics["scalars"], axis_name="batch")
        metrics["scalars"] = {"eval_" + k: v for (k, v) in metrics["scalars"].items()}
        metrics["images"] = jax.tree_util.tree_map(
            lambda x: utils.generate_image_grids(x)[None, :, :, :], metrics["images"]
        )
        return metrics
def copy_dict(dict1, dict2):
    """Recursively overwrite values in `dict1` with matching entries of `dict2`.

    Only keys already present in `dict1` are visited: keys that appear only in
    `dict2` are ignored, and keys missing from `dict2` keep their `dict1`
    value. Nested dicts are updated in place and `dict1` is returned. When
    `dict1` is not a dict (a leaf), `dict2` is returned as its replacement.

    Raises:
      TypeError: if the two structures disagree at some path, i.e. a leaf in
        `dict1` faces a dict in `dict2`, or a dict in `dict1` faces a
        non-mapping in `dict2`.
    """
    from collections.abc import Mapping

    if not isinstance(dict1, dict):
        # Leaf: the value from dict2 wins. Explicit check instead of `assert`
        # so the structural validation also runs under `python -O`.
        if isinstance(dict2, dict):
            raise TypeError("copy_dict: cannot replace a leaf value with a dict")
        return dict2
    if not isinstance(dict2, Mapping):
        # Previously this fell through to `key in dict2` on a non-mapping,
        # which silently skipped the merge (or raised an unrelated error).
        raise TypeError(f"copy_dict: cannot merge a {type(dict2).__name__} into a dict")
    for key in dict1:
        if key in dict2:
            dict1[key] = copy_dict(dict1[key], dict2[key])
    return dict1
def restore_partial(state, state_restore_dict):
    """Rebuild `state` with values taken from `state_restore_dict` where keys match.

    `state` is converted to a flax state dict. Every entry whose key path also
    exists in `state_restore_dict` is replaced via `copy_dict`, and entries
    only present in `state_restore_dict` are ignored. The merged dict is then
    converted back with `flax.serialization.from_state_dict`, using `state` as
    the target structure.
    """
    merged_state_dict = copy_dict(
        flax.serialization.to_state_dict(state),
        state_restore_dict,
    )
    return flax.serialization.from_state_dict(state, merged_state_dict)