Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

CorrDiff bugfixes #756

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion examples/generative/corrdiff/conf/dataset/gefs_hrrr.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@

type: gefs_hrrr
data_path: /data
stats_path: modulus/examples/generative/corrdiff/stats.json
stats_path: /data/stats.json
output_variables: ["u10m", "v10m", "t2m", "precip", "cat_snow", "cat_ice", "cat_freez", "cat_rain", "cat_none"]
prob_variables: ["cat_snow", "cat_ice", "cat_freez", "cat_rain"]
input_surface_variables: ["u10m", "v10m", "t2m", "q2m", "sp", "msl", "precipitable_water"]
Expand Down
2 changes: 1 addition & 1 deletion examples/generative/corrdiff/datasets/gefs_hrrr.py
Original file line number Diff line number Diff line change
Expand Up @@ -562,7 +562,7 @@ def get_prob_channel_index(self):
"""
Get prob_channel_index list one more dimension
"""
return self.prob_channel_index + [len(self.output_variables)]
return self.prob_channel_index + [len(self.output_variables) - 1]

def input_channels(self):
    """Describe the input channels: one ChannelMetadata per input variable, in order."""
    return [ChannelMetadata(name=variable) for variable in self.input_variables]
Expand Down
20 changes: 12 additions & 8 deletions examples/generative/corrdiff/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,9 @@ def main(cfg: DictConfig) -> None:

if cfg.model.name == "lt_aware_ce_regression":
prob_channels = dataset.get_prob_channel_index()
else:
prob_channels = None

# Parse the patch shape
if (
cfg.model.name == "patched_diffusion"
Expand Down Expand Up @@ -314,19 +317,20 @@ def main(cfg: DictConfig) -> None:
img_clean = img_clean.to(dist.device).to(torch.float32).contiguous()
img_lr = img_lr.to(dist.device).to(torch.float32).contiguous()
labels = labels.to(dist.device).contiguous()
loss_fn_kwargs = {
"net": model,
"img_clean": img_clean,
"img_lr": img_lr,
"labels": labels,
"augment_pipe": None,
}
if lead_time_label:
lead_time_label = lead_time_label[0].to(dist.device).contiguous()
loss_fn_kwargs.update({"lead_time_label": lead_time_label})
else:
lead_time_label = None
with torch.autocast("cuda", dtype=amp_dtype, enabled=enable_amp):
loss = loss_fn(
net=model,
img_clean=img_clean,
img_lr=img_lr,
labels=labels,
lead_time_label=lead_time_label,
augment_pipe=None,
)
loss = loss_fn(**loss_fn_kwargs)
loss = loss.sum() / batch_size_per_gpu
loss_accum += loss / num_accumulation_rounds
loss.backward()
Expand Down
80 changes: 55 additions & 25 deletions modulus/metrics/diffusion/loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -527,14 +527,23 @@ def __call__(
].expand(b, -1, -1, -1)

# form residual
y_mean = self.unet(
torch.zeros_like(y, device=img_clean.device),
y_lr_res,
sigma,
labels,
lead_time_label=lead_time_label,
augment_labels=augment_labels,
)
if lead_time_label is not None:
y_mean = self.unet(
torch.zeros_like(y, device=img_clean.device),
y_lr_res,
sigma,
labels,
lead_time_label=lead_time_label,
augment_labels=augment_labels,
)
else:
y_mean = self.unet(
torch.zeros_like(y, device=img_clean.device),
y_lr_res,
sigma,
labels,
augment_labels=augment_labels,
)

y = y - y_mean

Expand Down Expand Up @@ -617,15 +626,26 @@ def __call__(
y = y_new
y_lr = y_lr_new
latent = y + torch.randn_like(y) * sigma
D_yn = net(
latent,
y_lr,
sigma,
labels,
global_index=global_index,
lead_time_label=lead_time_label,
augment_labels=augment_labels,
)

if lead_time_label is not None:
D_yn = net(
latent,
y_lr,
sigma,
labels,
global_index=global_index,
lead_time_label=lead_time_label,
augment_labels=augment_labels,
)
else:
D_yn = net(
latent,
y_lr,
sigma,
labels,
global_index=global_index,
augment_labels=augment_labels,
)
loss = weight * ((D_yn - y) ** 2)

return loss
Expand Down Expand Up @@ -865,14 +885,24 @@ def __call__(
y_lr = y_tot[:, img_clean.shape[1] :, :, :]

input = torch.zeros_like(y, device=img_clean.device)
D_yn = net(
input,
y_lr,
sigma,
labels,
lead_time_label=lead_time_label,
augment_labels=augment_labels,
)

if lead_time_label is not None:
D_yn = net(
input,
y_lr,
sigma,
labels,
lead_time_label=lead_time_label,
augment_labels=augment_labels,
)
else:
D_yn = net(
input,
y_lr,
sigma,
labels,
augment_labels=augment_labels,
)
loss1 = weight * ((D_yn[:, scalar_channels] - y[:, scalar_channels]) ** 2)
loss2 = (
weight
Expand Down