-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathinference.py
106 lines (80 loc) · 3.08 KB
/
inference.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
import argparse
import torch
from aicsimageio.writers.ome_tiff_writer import OmeTiffWriter
from aicsimageio.writers.ome_zarr_writer import OmeZarrWriter
from aicsimageio import types
from src.data.aicszarr import zarr_to_input
from src.model import UNet
from src.utils.utils import get_device
from src.utils.parsers import add_inference_parser_arguments
def load_generator(ckp_path, map_location="cpu"):
    """Rebuild the UNet generator stored in a training checkpoint.

    Args:
        ckp_path: Path to a checkpoint readable by ``torch.load``. It must
            contain ``'gen_hyperparams'`` (with ``ndim``, ``depth``,
            ``mult_chan`` and ``lr_g``), ``'target_channels'`` and
            ``'gen_state_dict'``.
        map_location: Forwarded to ``torch.load``. Defaults to ``"cpu"`` so
            a checkpoint saved on a GPU can also be opened on a CPU-only
            host; the weights are copied into a freshly built model, which
            callers move with ``.to(device)`` afterwards.

    Returns:
        ``(generator, target_channels)``: the UNet with the checkpoint
        weights loaded, and the ``target_channels`` entry of the checkpoint
        (its length sets the number of output channels).

    Raises:
        ValueError: If ``ckp_path`` does not exist.
    """
    try:
        gen_ckp = torch.load(ckp_path, map_location=map_location)
    except FileNotFoundError as err:
        # Raise (not return) so a missing file fails here with a clear
        # message instead of as a tuple-unpacking error in the caller.
        # Other load failures (corrupt file, permissions) propagate as-is
        # rather than being mislabelled "not found".
        raise ValueError(f"Model file {ckp_path} not found") from err

    gen_hp = gen_ckp['gen_hyperparams']
    ndim = gen_hp['ndim']
    depth = gen_hp['depth']
    mult_chan = gen_hp['mult_chan']
    lr_g = gen_hp['lr_g']
    target_channels = gen_ckp['target_channels']

    # Rebuild the architecture from the stored hyperparameters, then load
    # the trained weights into it.
    Gen = UNet(
        ndim=ndim,
        activation_fn=torch.nn.LeakyReLU,
        # NOTE(review): presumably unpacked into LeakyReLU as
        # (negative_slope, inplace) — confirm against UNet.
        activation_kwargs=(lr_g, True),
        depth=depth,
        n_in_channels=1,  # the generator takes a single source channel
        out_channels=len(target_channels),
        mult_chan=mult_chan,
    )
    Gen.load_state_dict(gen_ckp['gen_state_dict'])
    return Gen, target_channels
def save_output(output, output_format, output_path, target_channels):
    """Write a model prediction to disk as OME-TIFF or OME-Zarr.

    Args:
        output: Prediction tensor with the channel axis first. 4-D tensors
            are written as ``CZYX`` and 3-D tensors as ``CYX``.
        output_format: ``'tiff'`` or ``'zarr'``.
        output_path: Destination path without extension; ``.ome.tiff`` or
            ``.ome.zarr`` is appended.
        target_channels: Channel names, one per output channel.

    Raises:
        ValueError: If ``output_format`` is not supported, or ``output`` is
            neither 3-D nor 4-D.
    """
    # Reject bad formats before paying for the device-to-host copy.
    if output_format not in ('tiff', 'zarr'):
        raise ValueError(f"Output format {output_format} not supported")

    data = output.cpu().numpy()
    if data.ndim not in (3, 4):
        raise ValueError(
            f"Expected a CYX or CZYX prediction, got shape {data.shape}")
    # Channel axis first, then whichever spatial axes are present, so 2-D
    # generators (ndim=2) are written correctly too.
    dim_order = 'CZYX' if data.ndim == 4 else 'CYX'

    if output_format == 'tiff':
        OmeTiffWriter.save(
            data,
            f"{output_path}.ome.tiff",
            dim_order=dim_order,
            channel_names=target_channels,
        )
        return

    # Display colours as packed integers. NOTE(review): presumably read as
    # 0xRRGGBB (black, blue, green, red, yellow, cyan, orange) — confirm
    # against the writer; a black first channel may render invisible.
    palette = [
        0x000000, 0x0000FF, 0x00FF00, 0xFF0000, 0xFFFF00, 0x00FFFF, 0xFF8000,
    ]
    # Exactly one colour per channel: cycle the palette for >7 channels and
    # truncate it for fewer.
    n_channels = data.shape[0]
    channel_colors = [palette[i % len(palette)] for i in range(n_channels)]

    writer = OmeZarrWriter(f"{output_path}.ome.zarr")
    writer.write_image(
        data,
        dimension_order=dim_order,
        image_name=f'{output_path}',
        channel_names=target_channels,
        # NOTE(review): fixed isotropic voxel size, presumably in microns —
        # confirm it matches the acquisition.
        physical_pixel_sizes=types.PhysicalPixelSizes(
            X=0.108, Y=0.108, Z=0.108),
        channel_colors=channel_colors,
    )
def inference(img_path, ckp_path, output_path, output_format, section, source, device):
    """Run the checkpointed generator on one image section.

    Args:
        img_path: Input image path, read with ``zarr_to_input``.
        ckp_path: Generator checkpoint path, read with ``load_generator``.
        output_path: Destination passed to ``save_output`` (no extension).
        output_format: ``'tiff'``/``'zarr'`` to save the result, or ``None``
            to skip writing anything.
        section: Forwarded to ``zarr_to_input`` as ``section``.
        source: Forwarded to ``zarr_to_input`` as ``src_channel``.
        device: Device the model is moved to and the input is loaded onto.

    Returns:
        The model output with its leading batch axis squeezed; it is not
        copied back to the CPU.
    """
    model, channel_names = load_generator(ckp_path)
    # Switch to eval mode alongside the device move; nn.Module.to/.eval
    # both return the module itself.
    model = model.to(device).eval()
    batch = zarr_to_input(img_path, section=section,
                          src_channel=source, device=device)

    with torch.inference_mode():
        prediction = model(batch).squeeze(0)

    if output_format is None:
        return prediction
    save_output(prediction, output_format, output_path, channel_names)
    return prediction
def main() -> None:
    """Command-line entry point: parse the CLI options and run ``inference``.

    The compute device is resolved with ``get_device(0)`` before the
    arguments are parsed. The parsed ``img_path``, ``model``, ``output``,
    ``output_format``, ``section`` and ``source`` options are then forwarded
    positionally to ``inference`` together with that device.
    """
    device = get_device(0)
    cli = argparse.ArgumentParser(description="Inference script for the model")
    add_inference_parser_arguments(cli)
    opts = cli.parse_args()
    inference(
        opts.img_path,
        opts.model,
        opts.output,
        opts.output_format,
        opts.section,
        opts.source,
        device,
    )
if __name__ == "__main__":
    # Start an inference run only when executed as a script; importing the
    # module (e.g. to call `inference` directly) does not trigger one.
    main()