-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathutils.py
89 lines (72 loc) · 1.84 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
r"""CIFAR experiment helpers"""
import os
from jax import Array
from pathlib import Path
from typing import *
# isort: split
from priors.common import *
from priors.data import *
from priors.diffusion import *
from priors.image import *
from priors.nn import *
from priors.optim import *
# Base directory for this module: $SCRATCH/priors/cifar when the SCRATCH
# environment variable is set, otherwise the current working directory.
# It is created eagerly at import time.
if 'SCRATCH' not in os.environ:
    PATH = Path('.')
else:
    SCRATCH = os.environ['SCRATCH']
    PATH = Path(SCRATCH) / 'priors/cifar'

PATH.mkdir(parents=True, exist_ok=True)
def measure(A: Array, x: Array) -> Array:
    r"""Applies the measurement ``A`` to a flattened image ``x``.

    ``x`` is reshaped into a 32x32 image with ``unflatten``, multiplied
    element-wise (with broadcasting) by ``A``, then flattened again.

    Arguments:
        A: The measurement array; presumably a mask laid out like the image
            returned by ``unflatten`` (TODO confirm against callers).
        x: A flat image vector.

    Returns:
        The flattened product ``A * unflatten(x)``.
    """
    image = unflatten(x, width=32, height=32)
    observed = A * image
    return flatten(observed)
def sample(
    model: nn.Module,
    y: Array,
    A: Array,
    key: Array,
    shard: bool = False,
    *,
    cov_y: float = 1e-3**2,
    **kwargs,
) -> Array:
    r"""Samples 32x32 images from ``model`` conditioned on an observation ``y``.

    The observation is flattened and handed to ``sample_any`` together with
    the linear operator ``measure(A, .)``; the flat result is reshaped back
    into 32x32 images.

    Arguments:
        model: The denoiser model passed to ``sample_any``.
        y: The observation, in image layout (it is flattened here).
        A: The measurement array bound into ``measure``.
        key: A PRNG key forwarded to ``sample_any``.
        shard: Whether to ``distribute`` ``y`` and ``A`` and sample sharded.
        cov_y: Value forwarded as ``sample_any``'s ``cov_y``; presumably the
            observation-noise variance (default is a standard deviation of
            1e-3 squared). Previously hard-coded, so callers could not
            override it through ``kwargs`` without a duplicate-keyword error.
        kwargs: Extra keyword arguments forwarded to ``sample_any``.

    Returns:
        The samples, unflattened to 32x32 images.
    """
    if shard:
        y, A = distribute((y, A))

    # Flatten once; it supplies both the sample shape and the observation.
    y_flat = flatten(y)

    x = sample_any(
        model=model,
        shape=y_flat.shape,
        shard=shard,
        A=inox.Partial(measure, A),
        y=y_flat,
        cov_y=cov_y,
        key=key,
        **kwargs,
    )

    return unflatten(x, 32, 32)
def make_model(
    key: Array,
    hid_channels: Sequence[int] = (64, 128, 256),
    hid_blocks: Sequence[int] = (3, 3, 3),
    kernel_size: Sequence[int] = (3, 3),
    emb_features: int = 256,
    heads: Optional[Dict[int, int]] = None,
    dropout: Optional[float] = None,
    **absorb,
) -> Denoiser:
    r"""Builds a ``Denoiser`` around a ``FlatUNet`` for 3-channel images.

    Arguments:
        key: A PRNG key used to initialize the network.
        hid_channels: Hidden channels, forwarded to ``FlatUNet``.
        hid_blocks: Hidden blocks, forwarded to ``FlatUNet``.
        kernel_size: Convolution kernel size, forwarded to ``FlatUNet``.
        emb_features: Embedding features, given to both the network and the
            ``Denoiser``.
        heads: Forwarded to ``FlatUNet``; presumably maps a level index to a
            number of attention heads (TODO confirm in ``UNet``). ``None``
            selects the default ``{2: 1}``.
        dropout: Dropout rate forwarded to ``FlatUNet`` (``None`` is passed
            through unchanged).
        absorb: Extra keyword arguments, accepted and ignored.

    Returns:
        The assembled ``Denoiser``.
    """
    if heads is None:
        # Build a fresh dict per call: a dict literal as the default value
        # would be one object shared by every model built with defaults.
        heads = {2: 1}

    return Denoiser(
        network=FlatUNet(
            in_channels=3,
            out_channels=3,
            hid_channels=hid_channels,
            hid_blocks=hid_blocks,
            kernel_size=kernel_size,
            emb_features=emb_features,
            heads=heads,
            dropout=dropout,
            key=key,
        ),
        emb_features=emb_features,
    )
class FlatUNet(UNet):
    r"""``UNet`` variant whose inputs and outputs are flat vectors.

    Each call reshapes the input into a 32x32 image, runs the parent
    ``UNet``, and flattens the result.
    """

    def __call__(self, x: Array, t: Array, key: Array = None) -> Array:
        r"""Runs the parent ``UNet`` on ``x`` reshaped to 32x32.

        Arguments:
            x: A flat input vector.
            t: Forwarded unchanged to ``UNet.__call__``.
            key: An optional PRNG key forwarded to ``UNet.__call__``.

        Returns:
            The parent network's output, flattened.
        """
        image = unflatten(x, width=32, height=32)
        out = super().__call__(image, t, key)
        return flatten(out)