diff --git a/src/gluonts/torch/model/seg_diff/module.py b/src/gluonts/torch/model/seg_diff/module.py index bfb3e0a99a..c49e2e8332 100644 --- a/src/gluonts/torch/model/seg_diff/module.py +++ b/src/gluonts/torch/model/seg_diff/module.py @@ -119,9 +119,14 @@ def __init__( hidden_dim: int, time_embed_dim: int = 8, act_fn_name: str = "gelu", + cond_drop_prob: float = 0.2, # Add conditioning dropout probability + cfg_scale: float = 1.5, ): super().__init__() + act_fn = ACT2FN[act_fn_name] + self.cond_drop_prob = cond_drop_prob + self.cfg_scale = cfg_scale # Time embedding network self.time_embed = nn.Sequential( @@ -171,12 +176,33 @@ def forward(self, x: torch.Tensor, t: torch.Tensor, cond: torch.Tensor): # Get time embeddings t_embed = self.time_embed(t) - # Process conditioning + # Get conditioning features cond_features = self.cond_net(cond) - # Concatenate all inputs and compute velocity - inputs = torch.cat([x, t_embed, cond_features], dim=-1) - return self.net(inputs) + if self.training: + # Drop conditioning features with probability cond_drop_prob + mask = ( + torch.rand( + cond_features.shape[:2], device=cond_features.device + ) + > self.cond_drop_prob + ) + cond_features = cond_features * mask.unsqueeze(-1) + + # Concatenate all inputs and compute conditioned velocity + inputs = torch.cat([x, t_embed, cond_features], dim=-1) + return self.net(inputs) + else: + # Compute null velocity + null_cond_features = torch.zeros_like(cond_features) + null_inputs = torch.cat([x, t_embed, null_cond_features], dim=-1) + null_output = self.net(null_inputs) + + # Concatenate all inputs and compute conditioned velocity + inputs = torch.cat([x, t_embed, cond_features], dim=-1) + cond_output = self.net(inputs) + + return null_output + self.cfg_scale * (cond_output - null_output) class Flow(nn.Module): @@ -193,10 +219,8 @@ def __init__( self.velocity_model = VelocityModel( cond_dim, feat_dim, hidden_dim, act_fn_name=act_fn_name ) - # Flow matching components self.prob_path = 
CondOTProbPath() - self.solver = ODESolver(self.velocity_model) def compute_loss( self, x_0: torch.Tensor, x_1: torch.Tensor, cond: torch.Tensor @@ -614,7 +638,9 @@ def log_prob( x_1 = target[:, 1:, :] # Target patches cond = flow_cond[:, :-1, :] - _, exact_log_p = self.flow.solver.compute_likelihood( + solver = ODESolver(self.flow.velocity_model) + + _, exact_log_p = solver.compute_likelihood( x_1=x_1.reshape(-1, self.patch_len), cond=cond.reshape(-1, self.d_model), method="midpoint", @@ -624,6 +650,7 @@ def log_prob( ) return -exact_log_p.mean() + @torch.inference_mode() def forward( self, past_target: torch.Tensor, @@ -663,11 +690,14 @@ def forward( num_parallel_samples, dim=0 ) + # solver + solver = ODESolver(self.flow.velocity_model) + # Evolve the samples through time using the flow - x = self.flow.solver.sample( + x = solver.sample( x_init=x, cond=last_cond, - method="dopri8", + method="midpoint", step_size=0.05, return_intermediates=False, time_grid=T, @@ -740,7 +770,7 @@ def forward( # t_end=time_steps[i + 1], # cond=last_cond, # ) - x = self.flow.sample( + x = solver.sample( x_init=x, cond=last_cond, method="midpoint", @@ -749,7 +779,7 @@ def forward( time_grid=T, ) - # Scale and store the samples + # Scale and store the samples next_sample = x.view( batch_size, num_parallel_samples, self.patch_len ) * scale.view(batch_size, num_parallel_samples, -1) + loc.view(