-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathalpha.py
42 lines (36 loc) · 1.37 KB
/
alpha.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
import torch
from torchvision import models
from ..image import ImageBatch
from ..objectives import (ChannelObjective, MeanOpacityObjective,
TVRegularizerObjective)
from ..renderer import RendererBuilder
from ..transforms import presets
from ..transforms.fft import TFMSIFFT
from ..transforms.unit_space.TFMSTrainingToUnitSpace import \
TFMSTrainingToUnitSpace
def _fft_to_unit_space():
    """Build a fresh IFFT -> training-to-unit-space transform pipeline."""
    # A new Sequential is returned on every call so each objective owns its own
    # transform modules, exactly as when they were constructed inline.
    return torch.nn.Sequential(
        TFMSIFFT(),
        TFMSTrainingToUnitSpace(),
    )


def alpha(device="cuda:0", numberOfFrames=500):
    """Render a ResNet-18 channel visualization with an alpha channel.

    An image batch is generated with the ``alpha=True`` data-space, training
    and drawing transform presets. A renderer (with live preview) then
    optimizes it against a product of three objectives:

    * channel 15 of ``resnet18.layer3[1].conv2``,
    * ``1 - MeanOpacityObjective`` over the image mapped through
      ``TFMSIFFT`` then ``TFMSTrainingToUnitSpace``,
    * ``1 - TVRegularizerObjective`` over the same mapping.

    Args:
        device: Torch device the model and the image batch are placed on.
        numberOfFrames: Number of frames passed to ``renderer.render``.

    Returns:
        The ``.data`` attribute of the renderer's drawable image batch after
        rendering. Its exact shape and channel layout are defined by
        ``ImageBatch`` and the draw transforms (not visible here).
    """
    # The network is only read, never trained: eval() makes BatchNorm use its
    # stored running statistics instead of statistics of the generated images
    # (and stops it from updating them). Place it on the same device as the
    # image batch below.
    model = models.resnet18(pretrained=True).to(device).eval()

    base_objective = ChannelObjective(lambda m: m.layer3[1].conv2, channel=15)
    # NOTE(review): the IFFT -> unit-space pipeline presumably maps the
    # FFT-parameterized training image back to pixel space before the
    # regularizers are evaluated; confirm against the transform definitions.
    opacity = MeanOpacityObjective(_fft_to_unit_space())
    lowfreq = TVRegularizerObjective(_fft_to_unit_space())
    objective = base_objective * (1.0 - opacity) * (1.0 - lowfreq)

    imageBatch = ImageBatch.generate(
        data_space_transform=presets.dataspaceTFMS(alpha=True)
    ).to(device)

    renderer = (
        RendererBuilder()
        .imageBatch(imageBatch)
        .model(model)
        .objective(objective)
        .trainTFMS(presets.trainTFMS(alpha=True))
        .drawTFMS(presets.drawTFMS(alpha=True))
        .withLivePreview()
        .build()
    )
    renderer.render(numberOfFrames)
    return renderer.drawableImageBatch().data