-
Notifications
You must be signed in to change notification settings - Fork 39
/
Copy pathpredict.py
113 lines (89 loc) · 3.49 KB
/
predict.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
# Prediction interface for Cog ⚙️
# https://github.com/replicate/cog/blob/main/docs/python.md
import os
import numpy as np
import yaml
import torch
from cog import BasePredictor, Input, Path
from utils import *
from models import instructir
from text.models import LanguageModel, LMHead
# Turn off the tokenizer thread-pool parallelism switch before any model is
# loaded (this variable is presumably read by the Hugging Face `tokenizers`
# library used by LanguageModel — confirm).
os.environ.update(TOKENIZERS_PARALLELISM="false")
class Predictor(BasePredictor):
    """Cog predictor that restores an image following a natural-language instruction."""

    def setup(self) -> None:
        """Load the model into memory to make running multiple predictions efficient.

        Builds the image restoration network from ``configs/eval5d.yml``, loads its
        checkpoint, creates the language model and the LM head, loads the head's
        checkpoint, and puts both checkpointed modules in evaluation mode.
        Everything is kept on the CPU.
        """
        LM_MODEL = "models/lm_instructir-7d.pt"
        MODEL_NAME = "models/im_instructir-7d.pt"
        device = torch.device("cpu")

        with open("configs/eval5d.yml", "r") as f:
            config = yaml.safe_load(f)
        cfg = dict2namespace(config)

        torch.backends.cudnn.deterministic = True

        # Image restoration network, conditioned on a text embedding of size textdim.
        self.model = instructir.create_model(
            input_channels=cfg.model.in_ch,
            width=cfg.model.width,
            enc_blks=cfg.model.enc_blks,
            middle_blk_num=cfg.model.middle_blk_num,
            dec_blks=cfg.model.dec_blks,
            txtdim=cfg.model.textdim,
        )
        self.model = self.model.to(device)
        print("IMAGE MODEL CKPT:", MODEL_NAME)
        self.model.load_state_dict(
            torch.load(MODEL_NAME, map_location="cpu"), strict=True
        )
        # Inference only: make sure training-mode layers (dropout, norm stats) are off.
        self.model.eval()

        # Initialize the LanguageModel class
        LMODEL = cfg.llm.model
        self.language_model = LanguageModel(model=LMODEL)

        # Head projecting the language-model embedding to the image model's text space.
        self.lm_head = LMHead(
            embedding_dim=cfg.llm.model_dim,
            hidden_dim=cfg.llm.embd_dim,
            num_classes=cfg.llm.nclasses,
        )
        print("LMHEAD MODEL CKPT:", LM_MODEL)
        self.lm_head.load_state_dict(
            torch.load(LM_MODEL, map_location="cpu"), strict=True
        )
        self.lm_head.eval()
        print("Loaded weights!")

    def predict(
        self,
        image: Path = Input(description="Input image."),
        prompt: str = Input(description="Input prompt."),
        seed: int = Input(
            description="Random seed. Leave blank to randomize the seed", default=None
        ),
    ) -> Path:
        """Run a single prediction on the model.

        Seeds the RNGs (a random 16-bit seed when ``seed`` is not given), restores
        ``image`` following ``prompt``, writes the result to ``/tmp/out.png`` and
        returns that path.
        """
        if seed is None:
            seed = int.from_bytes(os.urandom(2), "big")
        print(f"Using seed: {seed}")
        seed_everything(SEED=seed)

        # The model runs on the CPU; CUDA memory bookkeeping is only meaningful
        # (and reset_peak_memory_stats only succeeds) when a CUDA device exists.
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
            torch.cuda.reset_peak_memory_stats()

        image = load_img(str(image))
        out_image = process_img(
            image, prompt, self.language_model, self.model, self.lm_head
        )

        out_path = "/tmp/out.png"
        saveImage(out_path, out_image)
        return Path(out_path)
def process_img(image, prompt, language_model, model, lm_head):
    """
    Given an image and a prompt, run InstructIR to restore the image following the human prompt.

    image: RGB image as a numpy array of shape (H, W, C), normalized to [0, 1].
    prompt: plain python string containing the restoration instruction.
    language_model: callable mapping ``prompt`` to a language-model embedding.
    model: restoration network, called as ``model(image_tensor, text_embedding)``.
    lm_head: callable mapping the language-model embedding to
        ``(text_embedding, predicted_degradation)``.

    Returns the restored image as a numpy array of shape (H, W, C), clipped to [0, 1].
    """
    # Pure inference: disable autograd so no computation graph is built and
    # intermediate activations are released as soon as they are used.
    with torch.no_grad():
        # (H, W, C) numpy array -> (1, C, H, W) float tensor.
        y = torch.Tensor(image).permute(2, 0, 1).unsqueeze(0)

        # Get the text embedding; the predicted degradation class is not used here.
        lm_embd = language_model(prompt)
        text_embd, _deg_pred = lm_head(lm_embd)

        # Forward pass: Paper Figure 2
        x_hat = model(y, text_embd)

    # Take the single batch item and convert (C, H, W) -> (H, W, C) numpy array.
    restored_img = x_hat[0].permute(1, 2, 0).cpu().detach().numpy()
    return np.clip(restored_img, 0.0, 1.0)