Commit
add cfg
kashif committed Jan 3, 2025
1 parent 60b567e commit c8685f6
Showing 1 changed file with 41 additions and 11 deletions: src/gluonts/torch/model/seg_diff/module.py
@@ -119,9 +119,14 @@ def __init__(
hidden_dim: int,
time_embed_dim: int = 8,
act_fn_name: str = "gelu",
cond_drop_prob: float = 0.2,  # probability of dropping the conditioning during training
cfg_scale: float = 1.5,  # classifier-free guidance scale applied at inference
):
super().__init__()

act_fn = ACT2FN[act_fn_name]
self.cond_drop_prob = cond_drop_prob
self.cfg_scale = cfg_scale

# Time embedding network
self.time_embed = nn.Sequential(
@@ -171,12 +176,33 @@ def forward(self, x: torch.Tensor, t: torch.Tensor, cond: torch.Tensor):
# Get time embeddings
t_embed = self.time_embed(t)

# Process conditioning
# Get conditioning features
cond_features = self.cond_net(cond)

# Concatenate all inputs and compute velocity
inputs = torch.cat([x, t_embed, cond_features], dim=-1)
return self.net(inputs)
if self.training:
# Drop conditioning features with probability cond_drop_prob
mask = (
torch.rand(
cond_features.shape[:2], device=cond_features.device
)
> self.cond_drop_prob
)
cond_features = cond_features * mask.unsqueeze(-1)
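# Zeroed positions act as the null conditioning that the inference branch below reproduces with torch.zeros_like.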

# Concatenate all inputs and compute conditioned velocity
inputs = torch.cat([x, t_embed, cond_features], dim=-1)
return self.net(inputs)
else:
# Compute null velocity
null_cond_features = torch.zeros_like(cond_features)
null_inputs = torch.cat([x, t_embed, null_cond_features], dim=-1)
null_output = self.net(null_inputs)

# Concatenate all inputs and compute conditioned velocity
inputs = torch.cat([x, t_embed, cond_features], dim=-1)
cond_output = self.net(inputs)
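
# Classifier-free guidance: extrapolate from the unconditional (null) velocity
# towards the conditional one; cfg_scale == 1.0 recovers the conditional output,
# larger values strengthen the conditioning signal.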

return null_output + self.cfg_scale * (cond_output - null_output)


class Flow(nn.Module):
@@ -193,10 +219,8 @@ def __init__(
self.velocity_model = VelocityModel(
cond_dim, feat_dim, hidden_dim, act_fn_name=act_fn_name
)

# Flow matching components
self.prob_path = CondOTProbPath()
self.solver = ODESolver(self.velocity_model)

def compute_loss(
self, x_0: torch.Tensor, x_1: torch.Tensor, cond: torch.Tensor
@@ -614,7 +638,9 @@ def log_prob(
x_1 = target[:, 1:, :] # Target patches
cond = flow_cond[:, :-1, :]

_, exact_log_p = self.flow.solver.compute_likelihood(
solver = ODESolver(self.flow.velocity_model)

_, exact_log_p = solver.compute_likelihood(
x_1=x_1.reshape(-1, self.patch_len),
cond=cond.reshape(-1, self.d_model),
method="midpoint",
@@ -624,6 +650,7 @@
)
return -exact_log_p.mean()

@torch.inference_mode()
def forward(
self,
past_target: torch.Tensor,
@@ -663,11 +690,14 @@ def forward(
num_parallel_samples, dim=0
)

# ODE solver wrapping the velocity model
solver = ODESolver(self.flow.velocity_model)

# Evolve the samples through time using the flow
x = self.flow.solver.sample(
x = solver.sample(
x_init=x,
cond=last_cond,
method="dopri8",
method="midpoint",
step_size=0.05,
return_intermediates=False,
time_grid=T,
@@ -740,7 +770,7 @@ def forward(
# t_end=time_steps[i + 1],
# cond=last_cond,
# )
x = self.flow.sample(
x = solver.sample(
x_init=x,
cond=last_cond,
method="midpoint",
@@ -749,7 +779,7 @@
time_grid=T,
)

# Scale and store the samples
# Scale and store the samples
next_sample = x.view(
batch_size, num_parallel_samples, self.patch_len
) * scale.view(batch_size, num_parallel_samples, -1) + loc.view(

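For reference, a minimal standalone sketch of the classifier-free guidance step this commit adds to VelocityModel.forward; the helper name cfg_velocity and the tensor names v_cond / v_null are illustrative, not from the repository:

import torch

def cfg_velocity(v_cond: torch.Tensor, v_null: torch.Tensor, cfg_scale: float = 1.5) -> torch.Tensor:
    # Extrapolate from the unconditional (null) velocity towards the conditional one.
    # cfg_scale == 1.0 recovers the purely conditional prediction; larger values
    # strengthen the conditioning signal.
    return v_null + cfg_scale * (v_cond - v_null)

# Example: with cfg_scale = 1.5 the guided velocity equals 1.5 * v_cond - 0.5 * v_null.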