forked from kekeblom/mmstereo
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathcallbacks.py
32 lines (22 loc) · 1.01 KB
/
callbacks.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
# Copyright 2021 Toyota Research Institute. All rights reserved.
import os
from pytorch_lightning.callbacks import Callback
from pytorch_lightning.utilities import rank_zero_only
class OnnxExport(Callback):
    """Callback that exports the model to ONNX at the end of each validation epoch.

    The exported file is always written to ``<output_dir>/model.onnx``, so each
    export targets the same path as the previous one.

    Args:
        output_dir: Directory the ONNX file is written into (typically the
            checkpoint directory).
        note: Stored on the instance as ``self.note``; it is not read anywhere
            in this class.
    """

    def __init__(self, output_dir, note):
        super().__init__()
        self.output_dir = output_dir
        self.note = note

    @rank_zero_only
    def on_validation_epoch_end(self, trainer, pl_module):
        """Export ``pl_module`` to ``<output_dir>/model.onnx``.

        Args:
            trainer: The running trainer (unused here).
            pl_module: The module being validated; it must provide an
                ``export_onnx(path)`` method, which performs the actual export.
        """
        # The output directory may not have been created yet when validation
        # first finishes, so make sure it exists before exporting. An empty
        # ``output_dir`` resolves to the current working directory and needs
        # no creation.
        if self.output_dir:
            os.makedirs(self.output_dir, exist_ok=True)
        path = os.path.join(self.output_dir, "model.onnx")
        pl_module.export_onnx(path)
class TorchscriptExport(Callback):
    """Callback that exports the model to Torchscript at the end of each validation epoch.

    The exported file is always written to ``<output_dir>/model.pt``, so each
    export targets the same path as the previous one.

    Args:
        output_dir: Directory the Torchscript file is written into (typically
            the checkpoint directory).
        note: Stored on the instance as ``self.note``; it is not read anywhere
            in this class.
    """

    def __init__(self, output_dir, note):
        super().__init__()
        self.output_dir = output_dir
        self.note = note

    @rank_zero_only
    def on_validation_epoch_end(self, trainer, pl_module):
        """Export ``pl_module`` to ``<output_dir>/model.pt``.

        Args:
            trainer: The running trainer (unused here).
            pl_module: The module being validated; it must provide an
                ``export_torchscript(path)`` method, which performs the actual
                export.
        """
        # The output directory may not have been created yet when validation
        # first finishes, so make sure it exists before exporting. An empty
        # ``output_dir`` resolves to the current working directory and needs
        # no creation.
        if self.output_dir:
            os.makedirs(self.output_dir, exist_ok=True)
        path = os.path.join(self.output_dir, "model.pt")
        pl_module.export_torchscript(path)