main.py
import os

import hydra
import wandb
from hydra.utils import instantiate
from omegaconf import DictConfig, OmegaConf
from torch import nn


@hydra.main(config_path="conf", config_name="config")
def run(cfg: DictConfig):
    # print run info (note: Hydra changes the working directory per run by default)
    print(OmegaConf.to_yaml(cfg))
    print("Working directory: {}".format(os.getcwd()))

    # create datamodule
    datamodule = instantiate(cfg.datamodule, data_dir=cfg.data_dir)

    # create network
    network: nn.Sequential = instantiate(
        cfg.network, in_channels=datamodule.dims[0], n_classes=datamodule.num_classes
    )
    # create model
    # note: _recursive_ is set to False below to avoid creating optimizers here,
    # since they are created internally in the configure_optimizers() method
    # of the lightning module.
    model = instantiate(
        cfg.model,
        _recursive_=False,
        datamodule=datamodule,
        network=network,
        hparams=cfg.model,
        full_config=cfg,
        network_hparams=cfg.network,
    )
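    # For reference, a hypothetical sketch of the matching "model" entry in
    # conf/ that the instantiate() call above could consume (the real group
    # names and _target_ path are defined by this repo's config files):
    #
    #   model:
    #     _target_: project.models.LitModel   # hypothetical import path
    #     optimizer:
    #       _target_: torch.optim.Adam        # built later by configure_optimizers()
    #       lr: 1e-3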
    # create trainer
    if cfg.debug:
        # define trainer debug behavior
        cfg.trainer.max_epochs = 1
        cfg.trainer.accelerator = None
        cfg.trainer.gpus = 1
        cfg.trainer.logger = None  # disable wandb logging
        cfg.trainer.enable_checkpointing = False
        cfg.trainer.profiler = "simple"
    trainer = instantiate(cfg.trainer)
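    # A hypothetical "trainer" entry that the instantiate() call above could
    # consume (illustrative values only; this script targets the Lightning 1.x
    # API, where the Trainer still accepts the `gpus` argument):
    #
    #   trainer:
    #     _target_: pytorch_lightning.Trainer
    #     max_epochs: 100
    #     gpus: 1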
    # run experiment
    trainer.fit(model, datamodule=datamodule)

    # run on test set
    test_results = trainer.test(model, datamodule=datamodule, verbose=True)

    # finish the wandb run (if any) before printing results
    if wandb.run:
        wandb.finish()

    # display experiment results
    top1_accuracy: float = test_results[0]["test/accuracy"]
    top5_accuracy: float = test_results[0]["test/top5_accuracy"]
    print(f"Test top1 accuracy: {top1_accuracy:.2%}")
    print(f"Test top5 accuracy: {top5_accuracy:.2%}")
    return top1_accuracy, top5_accuracy


if __name__ == "__main__":
    run()
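
# Usage (a sketch; `debug` and `trainer.max_epochs` are fields this script
# references, while conf/config.yaml supplies the remaining defaults):
#
#   python main.py                         # run with defaults from conf/config.yaml
#   python main.py debug=true              # quick single-epoch debug run
#   python main.py trainer.max_epochs=20   # override any config value from the CLI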