From 22d814b714c24b26418da45b3924e24aea8dc895 Mon Sep 17 00:00:00 2001 From: Neal Pan Date: Tue, 19 Nov 2024 13:37:03 -0800 Subject: [PATCH 1/2] Fix arg parse error induced by introduction of lead time in corrdiff --- examples/generative/corrdiff/train.py | 20 ++++--- modulus/metrics/diffusion/loss.py | 80 ++++++++++++++++++--------- 2 files changed, 67 insertions(+), 33 deletions(-) diff --git a/examples/generative/corrdiff/train.py b/examples/generative/corrdiff/train.py index 7edc3da08..c3fcd4faf 100644 --- a/examples/generative/corrdiff/train.py +++ b/examples/generative/corrdiff/train.py @@ -108,6 +108,9 @@ def main(cfg: DictConfig) -> None: if cfg.model.name == "lt_aware_ce_regression": prob_channels = dataset.get_prob_channel_index() + else: + prob_channels = None + # Parse the patch shape if ( cfg.model.name == "patched_diffusion" @@ -314,19 +317,20 @@ def main(cfg: DictConfig) -> None: img_clean = img_clean.to(dist.device).to(torch.float32).contiguous() img_lr = img_lr.to(dist.device).to(torch.float32).contiguous() labels = labels.to(dist.device).contiguous() + loss_fn_kwargs = { + "net": model, + "img_clean": img_clean, + "img_lr": img_lr, + "labels": labels, + "augment_pipe": None, + } if lead_time_label: lead_time_label = lead_time_label[0].to(dist.device).contiguous() + loss_fn_kwargs.update({"lead_time_label": lead_time_label}) else: lead_time_label = None with torch.autocast("cuda", dtype=amp_dtype, enabled=enable_amp): - loss = loss_fn( - net=model, - img_clean=img_clean, - img_lr=img_lr, - labels=labels, - lead_time_label=lead_time_label, - augment_pipe=None, - ) + loss = loss_fn(**loss_fn_kwargs) loss = loss.sum() / batch_size_per_gpu loss_accum += loss / num_accumulation_rounds loss.backward() diff --git a/modulus/metrics/diffusion/loss.py b/modulus/metrics/diffusion/loss.py index 12166eb50..ded29a1c5 100644 --- a/modulus/metrics/diffusion/loss.py +++ b/modulus/metrics/diffusion/loss.py @@ -518,14 +518,23 @@ def __call__( ].expand(b, -1, -1, -1) # form residual - y_mean = self.unet( - torch.zeros_like(y, device=img_clean.device), - y_lr_res, - sigma, - labels, - lead_time_label=lead_time_label, - augment_labels=augment_labels, - ) + if lead_time_label: + y_mean = self.unet( + torch.zeros_like(y, device=img_clean.device), + y_lr_res, + sigma, + labels, + lead_time_label=lead_time_label, + augment_labels=augment_labels, + ) + else: + y_mean = self.unet( + torch.zeros_like(y, device=img_clean.device), + y_lr_res, + sigma, + labels, + augment_labels=augment_labels, + ) y = y - y_mean @@ -608,15 +617,26 @@ def __call__( y = y_new y_lr = y_lr_new latent = y + torch.randn_like(y) * sigma - D_yn = net( - latent, - y_lr, - sigma, - labels, - global_index=global_index, - lead_time_label=lead_time_label, - augment_labels=augment_labels, - ) + + if lead_time_label: + D_yn = net( + latent, + y_lr, + sigma, + labels, + global_index=global_index, + lead_time_label=lead_time_label, + augment_labels=augment_labels, + ) + else: + D_yn = net( + latent, + y_lr, + sigma, + labels, + global_index=global_index, + augment_labels=augment_labels, + ) loss = weight * ((D_yn - y) ** 2) return loss @@ -856,14 +876,24 @@ def __call__( y_lr = y_tot[:, img_clean.shape[1] :, :, :] input = torch.zeros_like(y, device=img_clean.device) - D_yn = net( - input, - y_lr, - sigma, - labels, - lead_time_label=lead_time_label, - augment_labels=augment_labels, - ) + + if lead_time_label: + D_yn = net( + input, + y_lr, + sigma, + labels, + lead_time_label=lead_time_label, + augment_labels=augment_labels, + ) + else: + D_yn = net( + input, + y_lr, + sigma, + labels, + augment_labels=augment_labels, + ) loss1 = weight * ((D_yn[:, scalar_channels] - y[:, scalar_channels]) ** 2) loss2 = ( weight From 4ef3384907194eccbb6a6ed5e6bf6c27b22ec2aa Mon Sep 17 00:00:00 2001 From: Peter Harrington Date: Fri, 10 Jan 2025 17:50:10 -0800 Subject: [PATCH 2/2] fix lead time label and gefs datapipe bug --- examples/generative/corrdiff/conf/dataset/gefs_hrrr.yaml | 2 +- examples/generative/corrdiff/datasets/gefs_hrrr.py | 2 +- modulus/metrics/diffusion/loss.py | 6 +++--- 3 files changed, 5 insertions(+), 5 deletions(-) diff --git a/examples/generative/corrdiff/conf/dataset/gefs_hrrr.yaml b/examples/generative/corrdiff/conf/dataset/gefs_hrrr.yaml index f7fed742d..3b67d77bb 100644 --- a/examples/generative/corrdiff/conf/dataset/gefs_hrrr.yaml +++ b/examples/generative/corrdiff/conf/dataset/gefs_hrrr.yaml @@ -16,7 +16,7 @@ type: gefs_hrrr data_path: /data -stats_path: modulus/examples/generative/corrdiff/stats.json +stats_path: /data/stats.json output_variables: ["u10m", "v10m", "t2m", "precip", "cat_snow", "cat_ice", "cat_freez", "cat_rain", "cat_none"] prob_variables: ["cat_snow", "cat_ice", "cat_freez", "cat_rain"] input_surface_variables: ["u10m", "v10m", "t2m", "q2m", "sp", "msl", "precipitable_water"] diff --git a/examples/generative/corrdiff/datasets/gefs_hrrr.py b/examples/generative/corrdiff/datasets/gefs_hrrr.py index 4eb71eab2..50a85ccff 100644 --- a/examples/generative/corrdiff/datasets/gefs_hrrr.py +++ b/examples/generative/corrdiff/datasets/gefs_hrrr.py @@ -562,7 +562,7 @@ def get_prob_channel_index(self): """ Get prob_channel_index list one more dimension """ - return self.prob_channel_index + [len(self.output_variables)] + return self.prob_channel_index + [len(self.output_variables) - 1] def input_channels(self): return [ChannelMetadata(name=n) for n in self.input_variables] diff --git a/modulus/metrics/diffusion/loss.py b/modulus/metrics/diffusion/loss.py index f1d770b68..18dde13b5 100644 --- a/modulus/metrics/diffusion/loss.py +++ b/modulus/metrics/diffusion/loss.py @@ -527,7 +527,7 @@ def __call__( ].expand(b, -1, -1, -1) # form residual - if lead_time_label: + if lead_time_label is not None: y_mean = self.unet( torch.zeros_like(y, device=img_clean.device), y_lr_res, @@ -627,7 +627,7 @@ def __call__( y_lr = y_lr_new latent = y + torch.randn_like(y) * sigma - if lead_time_label: + if lead_time_label is not None: D_yn = net( latent, y_lr, @@ -886,7 +886,7 @@ def __call__( input = torch.zeros_like(y, device=img_clean.device) - if lead_time_label: + if lead_time_label is not None: D_yn = net( input, y_lr,