Skip to content

Commit

Permalink
[debug] Remove hard-coded datetime and interpolation functions in data generator (#22)
Browse files Browse the repository at this point in the history

* [improve] remove all datetime(2022, 10, 1); separate data/model/inference config

* [fix] standardization.json and warning in DataGenerator

* [debug] correct Radar Reflectivity key and value

* [feat] use Kth hour prediction to train the model

* [feat] Add linear decay learning rate schedule

* delete image_lat and image_lon in DataConfig

* remove cached 202501 configs

* inference onnx
  • Loading branch information
Chia-Tung authored Jan 3, 2025
1 parent 2f71f4e commit cc6077b
Show file tree
Hide file tree
Showing 38 changed files with 377 additions and 223 deletions.
File renamed without changes.
102 changes: 102 additions & 0 deletions assets/standardization_partial.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,102 @@
{
"Geopotential Height@200 Hpa": {
"mean": 12505.2807025529,
"std": 31.23690801193252
},
"Geopotential Height@300 Hpa": {
"mean": 9732.684060083873,
"std": 23.59395278782757
},
"Geopotential Height@500 Hpa": {
"mean": 5881.078627356256,
"std": 21.227364772148206
},
"Geopotential Height@700 Hpa": {
"mean": 3147.417053029739,
"std": 22.679428052389323
},
"Geopotential Height@850 Hpa": {
"mean": 1499.3848574888973,
"std": 24.940354463433184
},
"Geopotential Height@925 Hpa": {
"mean": 767.0191895649674,
"std": 28.92858339148864
},
"Temperature@200 Hpa": {
"mean": 222.47866570404634,
"std": 1.1873095673931153
},
"Temperature@300 Hpa": {
"mean": 244.69536708875444,
"std": 1.3255060441101292
},
"Temperature@500 Hpa": {
"mean": 269.15684662006595,
"std": 0.9910247889826362
},
"Temperature@700 Hpa": {
"mean": 284.26493983305636,
"std": 1.2131348201022787
},
"Temperature@850 Hpa": {
"mean": 292.08385257839444,
"std": 2.465933917252086
},
"Temperature@925 Hpa": {
"mean": 295.3969013762513,
"std": 2.8345488795075937
},
"U-wind@200 Hpa": {
"mean": -1.4728492556729618,
"std": 10.455350215208668
},
"U-wind@300 Hpa": {
"mean": 0.5694262664279642,
"std": 8.352128746790838
},
"U-wind@500 Hpa": {
"mean": 1.7419202405172896,
"std": 6.741210946180685
},
"U-wind@700 Hpa": {
"mean": 1.3884807588227728,
"std": 5.97533398398843
},
"U-wind@850 Hpa": {
"mean": -0.026072261389558333,
"std": 6.141238748292276
},
"U-wind@925 Hpa": {
"mean": -0.8825579715061802,
"std": 6.335889760655964
},
"V-wind@200 Hpa": {
"mean": -5.301883747876605,
"std": 7.104234188182886
},
"V-wind@300 Hpa": {
"mean": -1.873841695051257,
"std": 5.12916839232868
},
"V-wind@500 Hpa": {
"mean": 0.6418335554148797,
"std": 4.328289945964971
},
"V-wind@700 Hpa": {
"mean": 1.6541866714064417,
"std": 4.679441343010596
},
"V-wind@850 Hpa": {
"mean": 1.6800889374926513,
"std": 5.321468450675873
},
"V-wind@925 Hpa": {
"mean": 1.3371695255277345,
"std": 6.41552949382295
},
"Radar Reflectivity@NoRule": {
"mean": 2.6260936703541637,
"std": 7.044576417439349
}
}
8 changes: 1 addition & 7 deletions config/data/rwrf.yaml → config/data/rwrf_202409.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,7 @@ format: "%Y-%m-%d %H:%M"
time_interval:
hours: 1
data_shape: [450, 450]
image_shape: [336, 336]
image_lat: [20.4, 27.1] # not used
image_lon: [117.52, 124.22] # not used
image_res: 0.02 # not used
image_shape: [224, 224]
train_data:
Z:
- Hpa200
Expand All @@ -21,20 +18,17 @@ train_data:
- Hpa700
- Hpa850
- Hpa925
- Meter2
U:
- Hpa200
- Hpa500
- Hpa700
- Hpa850
- Hpa925
- Meter10
V:
- Hpa200
- Hpa500
- Hpa700
- Hpa850
- Hpa925
- Meter10
Radar:
- NoRule
3 changes: 0 additions & 3 deletions config/data/rwrf_dense.yaml → config/data/rwrf_202412.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,6 @@ time_interval:
hours: 1
data_shape: [450, 450]
image_shape: [336, 336]
image_lat: [20.4, 27.1] # not used
image_lon: [117.52, 124.22] # not used
image_res: 0.02 # not used
train_data:
Z:
- Hpa100
Expand Down
2 changes: 1 addition & 1 deletion config/inference/pangu_rwrf_ckpt.yaml
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
infer_type: "ckpt"
best_ckpt: "./checkpoints/Pangu_241222_034119-epoch=026-val_loss_epoch=0.1527.ckpt"
best_ckpt: "./checkpoints/Pangu_241229_120914-epoch=162-val_loss_epoch=0.1453.ckpt"
is_bdy_swap: True
output_itv:
hours: 6
4 changes: 2 additions & 2 deletions config/inference/pangu_rwrf_onnx.yaml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
infer_type: "onnx"
onnx_path: "./export/Pangu_model_241222.onnx"
onnx_path: "./export/Pangu_model_241229.onnx"
is_bdy_swap: True
output_itv:
hours: 3
hours: 1
gpu_id: 0
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,11 @@ optim_config:
name: AdamW
args:
lr: 1e-5
warmup_epochs: 3
lr_schedule:
name: linear_decay
args:
warmup_epochs: 3
last_epoch: -1
# lightning.Trainer
num_gpus: null
strategy: "auto"
Expand Down
File renamed without changes.
File renamed without changes.
14 changes: 14 additions & 0 deletions config/model/pangu_rwrf_202409.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
model_name: Pangu
# patch embedding
patch_size: [1, 2, 2]
smoothing_kernel_size: 5
segmented_smooth_boundary_width: null
# earth layer
depths: [2, 6]
# earth block
max_drop_path_ratio: 0.2
# earth attn 3d
heads: [6, 12]
embed_dim: 192
dropout_rate: 0
window_size: [3, 4, 4]
File renamed without changes.
2 changes: 1 addition & 1 deletion config/plot/pangu_rwrf.yaml
Original file line number Diff line number Diff line change
@@ -1 +1 @@
figure_columns: 5
figure_columns: 7
6 changes: 3 additions & 3 deletions config/predict.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,9 @@ lightning:
workers: 4

defaults:
- data: rwrf_dense
- lightning: pangu_rwrf
- model: pangu_rwrf
- data: rwrf_202412
- lightning: pangu_rwrf_202409
- model: pangu_rwrf_202412
- inference: pangu_rwrf_onnx
- plot: pangu_rwrf
- _self_
Expand Down
6 changes: 3 additions & 3 deletions config/train_diffusion.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,8 @@ hydra:
level: DEBUG

defaults:
- data: rwrf
- lightning: diffusion_rwrf
- model: diffusion_rwrf
- data: rwrf_202409
- lightning: diffusion_rwrf_202409
- model: diffusion_rwrf_202409
- _self_
- override hydra/job_logging: default
6 changes: 3 additions & 3 deletions config/train_diffusion_radar.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,8 @@ model:
hidden_dim: 32

defaults:
- data: rwrf
- lightning: diffusion_rwrf
- model: diffusion_rwrf
- data: rwrf_202409
- lightning: diffusion_rwrf_202409
- model: diffusion_rwrf_202409
- _self_
- override hydra/job_logging: default
6 changes: 3 additions & 3 deletions config/train_pangu.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,8 @@ hydra:
# verbose: [src.managers.data_manager]

defaults:
- data: rwrf_dense
- lightning: pangu_rwrf
- model: pangu_rwrf
- data: rwrf_202412
- lightning: pangu_rwrf_202409
- model: pangu_rwrf_202412
- _self_
- override hydra/job_logging: default # default/disabled/custom
15 changes: 7 additions & 8 deletions predict.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,14 +5,13 @@
import hydra
import matplotlib.pyplot as plt
import numpy as np
import xarray as xr
from matplotlib.figure import Figure
from omegaconf import DictConfig, OmegaConf
from tqdm import tqdm, trange

from inference import InferenceBase
from src.const import DATA_PATH, FIGURE_PATH
from src.utils import DataCompose, DataGenerator, get_var_from_wrfout_nc
from src.utils import DataCompose, DataGenerator, read_cwa_ncfile
from visual import *


Expand All @@ -37,8 +36,9 @@ def main(cfg: DictConfig) -> None:
# Prepare lat/lon
data_gnrt: DataGenerator = infer_machine.data_manager.data_gnrt
dc_lat, dc_lon = DataCompose.from_config({"Lat": ["NoRule"], "Lon": ["NoRule"]})
lat = data_gnrt.yield_data(datetime(2022, 10, 1, 0), dc_lat)
lon = data_gnrt.yield_data(datetime(2022, 10, 1, 0), dc_lon)
start_t = datetime.strptime(cfg.data.start_time, cfg.data.format)
lat = data_gnrt.yield_data(start_t, dc_lat)
lon = data_gnrt.yield_data(start_t, dc_lon)

# Prepare painter
u_compose, v_compose = DataCompose.from_config({"U": ["Hpa850"], "V": ["Hpa850"]})
Expand Down Expand Up @@ -77,12 +77,11 @@ def main(cfg: DictConfig) -> None:
u_tmp, v_tmp = [], []
for i in trange(infer_machine.showcase_length, desc="Get RWRF data"):
curr_time = eval_case + infer_machine.output_itv * i
filename = (
filename = Path(
f"{rwrf_dir}/wrfout_d01_{curr_time.strftime('%Y-%m-%d_%H')}_interp"
)
dataset = xr.open_dataset(filename)
u = get_var_from_wrfout_nc(dataset, u_compose)[57:-57, 57:-57] # (336, 336)
v = get_var_from_wrfout_nc(dataset, v_compose)[57:-57, 57:-57] # (336, 336)
u = read_cwa_ncfile(filename, u_compose)[57:-57, 57:-57] # (336, 336)
v = read_cwa_ncfile(filename, v_compose)[57:-57, 57:-57] # (336, 336)

u_tmp.append(u)
v_tmp.append(v)
Expand Down
21 changes: 11 additions & 10 deletions src/const.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,10 @@
LAND_SEA_MASK_PATH = "./assets/constant_masks/land_sea_mask_2km.npy"
TOPOGRAPHY_MASK_PATH = "./assets/constant_masks/topography_mask_2km.npy"
COUNTY_SHP_PATH = "./assets/town_shp/COUNTY_MOI_1090820.shp"
STANDARDIZATION_PATH = "./assets/standardization.json"
STANDARDIZATION_PATH = "./assets/standardization_complete.json"
DATA_PATH = "/work/dong1128/rwrf_data/"
FIGURE_PATH = "./gallery/"
DATA_CONFIG_PATH = "./config/data/rwrf_dense.yaml"
DATA_CONFIG_PATH = "./config/data/rwrf_202412.yaml"

# Radar color bar
DBZ_LV = np.arange(0, 66, 1)
Expand Down Expand Up @@ -270,27 +270,28 @@
datetime(2021, 6, 4), # ATS
datetime(2022, 6, 24), # ATS, observe graupel in Taipei
datetime(2022, 8, 25), # ATS
],
"three_days": [
datetime(2021, 8, 7), # South-western flow + Tropical Depression
datetime(2021, 8, 8), # South-western flow
datetime(2023, 4, 20), # cold front
],
"three_days": [
"five_days": [
# == harsh northward turning == #
# datetime(2022, 9, 3), # TC HINNAMNOR
datetime(2022, 9, 12), # TC MUIFA
# datetime(2021, 7, 23), # TC IN-FA
# == north-eastern wind accompanied == #
datetime(2022, 10, 16), # TC NESAT
# datetime(2022, 10, 16), # TC NESAT
# datetime(2022, 10, 31), # TC NALGAE
# == pass by northern Taiwan == #
datetime(2020, 8, 3), # TC HAGUPI
],
"five_days": [
# datetime(2020, 8, 3), # TC HAGUPI
# == pass by eastern Taiwan == #
datetime(2023, 7, 26), # TC DOKSURI
# datetime(2023, 7, 26), # TC DOKSURI
# == landing == #
# datetime(2023, 9, 3), # TC HAIKUI
datetime(2024, 7, 24), # TC GAEMI
# datetime(2024, 10, 31), # TC Kong-rey
# datetime(2024, 7, 24), # TC GAEMI
datetime(2024, 10, 31), # TC Kong-rey
],
"seven_days": [
# == landing == #
Expand Down
6 changes: 5 additions & 1 deletion src/datasets/custom_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ def __init__(
sampling_rate: int,
init_time_list: list[datetime],
data_list: list[DataCompose],
use_Kth_hour_pred: int | None,
is_train_or_valid: bool,
):
super().__init__()
Expand All @@ -32,6 +33,7 @@ def __init__(
self._sr = sampling_rate
self._init_time_list = init_time_list
self._data_list = data_list
self.use_Kth_hour_pred = use_Kth_hour_pred
self._is_train_or_valid = is_train_or_valid

if Path(STANDARDIZATION_PATH).exists():
Expand Down Expand Up @@ -91,7 +93,9 @@ def _get_variables_from_dt(self, dt: datetime) -> dict[str, np.ndarray]:
pre_output = defaultdict(list)
# via traversing data_list, the levels/vars are in the same order as the
# order in `config/data/data_config.yaml`
data_dict = self._data_gnrt.yield_data(dt, self._data_list)
data_dict = self._data_gnrt.yield_data(
dt, self._data_list, use_Kth_hour_pred=self.use_Kth_hour_pred
)
for var_level_str, data in data_dict.items():
if var_level_str in self.stat_dict and var_level_str not in [
"Cloud Water Mixing Ratio@100 Hpa",
Expand Down
Loading

0 comments on commit cc6077b

Please sign in to comment.