Skip to content

Commit

Permalink
cleanups
Browse files: browse the repository at this point in the history
  • Loading branch information
kashif committed Dec 13, 2024
1 parent 8d5f4b8 commit e9ceffa
Showing 1 changed file with 44 additions and 22 deletions.
66 changes: 44 additions & 22 deletions src/gluonts/torch/model/seg_diff/module.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -227,7 +227,7 @@ def __init__(self, cond_dim: int, out_dim: int, h: int):
self.sampler = Sampler(self.transport)

def compute_loss(
self, x_0: torch.Tensor, x_1: torch.Tensor, cond: torch.Tensor
self, x_1: torch.Tensor, cond: torch.Tensor
) -> torch.Tensor:
"""Compute flow matching loss."""
# Use transport training loss
Expand All @@ -238,7 +238,6 @@ def compute_loss(
)
return terms["loss"].mean()


# def sample(
# self,
# x_init: torch.Tensor,
Expand All @@ -250,36 +249,36 @@ def compute_loss(
# ) -> torch.Tensor:
# """
# Generate samples using the ODE solver.

# Args:
# x_init: Initial noise tensor
# cond: Conditioning tensor
# method: ODE solver method ('dopri5', 'euler', 'heun', etc.)
# step_size: Step size for fixed-step solvers
# return_intermediates: Whether to return intermediate states
# time_grid: Optional time points for sampling. If None, uses default grid

# Returns:
# Generated samples
# """
# if time_grid is None:
# time_grid = torch.linspace(0, 1.0, int(1.0/step_size) + 1, device=x_init.device)

# # Get ODE sampler with specified method
# ode_sampler = self.sampler.sample_ode(
# sampling_method=method,
# num_steps=len(time_grid) if method in ['euler', 'heun'] else 50,
# atol=1e-5,
# rtol=1e-5,
# )

# # Sample using the velocity model
# samples = ode_sampler(
# x_init,
# model=self.velocity_model,
# cond=cond,
# )

# if return_intermediates:
# return samples
# return samples[-1] # Return only final state if intermediates not requested
Expand All @@ -297,7 +296,7 @@ def sample(
) -> torch.Tensor:
"""
Generate samples using the SDE solver.
Args:
x_init: Initial noise tensor
cond: Conditioning tensor
Expand All @@ -307,12 +306,14 @@ def sample(
time_grid: Optional time points for sampling. If None, uses default grid
diffusion_form: Form of diffusion coefficient ('linear', 'constant', 'SBDM', etc.)
diffusion_norm: Scale of the diffusion coefficient
Returns:
Generated samples
"""
num_steps = int(1.0/step_size) + 1 if time_grid is None else len(time_grid)

num_steps = (
int(1.0 / step_size) + 1 if time_grid is None else len(time_grid)
)

# Get SDE sampler with specified method
sde_sampler = self.sampler.sample_sde(
sampling_method=method,
Expand All @@ -322,17 +323,31 @@ def sample(
last_step_size=step_size,
num_steps=num_steps,
)

# Sample using the velocity model
samples = sde_sampler(
x_init,
model=lambda x, t, **kwargs: self.velocity_model(x, t, kwargs["cond"]),
cond=cond
model=self.velocity_model,
cond=cond,
)


# ode_sampler = self.sampler.sample_ode(
# sampling_method=method,
# num_steps=num_steps,
# diffusion_form=diffusion_form,
# diffusion_norm=diffusion_norm,
# )

# samples = ode_sampler(
# x_init,
# model=self.velocity_model,
# cond=cond,
# )

if return_intermediates:
return samples
return samples[-1] # Return only final state if intermediates not requested
# Return only final state if intermediates not requested
return samples[-1]


class SegDiffModel(nn.Module):
Expand Down Expand Up @@ -535,11 +550,10 @@ def loss(

# Flow matching loss
x_1 = target[:, 1:, :] # Target patches
x_0 = torch.randn_like(x_1) # Random noise source distribution
# x_0 = torch.randn_like(x_1) # Random noise source distribution
# x_0 = target[:, :-1, :]

return self.flow.compute_loss(
x_0=x_0, x_1=x_1, cond=flow_cond[:, :-1, :]
)
return self.flow.compute_loss(x_1=x_1, cond=flow_cond[:, :-1, :])

def forward(
self,
Expand Down Expand Up @@ -569,6 +583,11 @@ def forward(
device=past_target.device,
)

# # the very last patch from past_target
# x = (
# (self.patch(past_target)[:, -1, :] - loc) / scale
# ).repeat_interleave(num_parallel_samples, dim=0)

time_grid = torch.linspace(0, 1.0, self.n_steps + 1, device=x.device)
# Get last decoder output and repeat for parallel samples
last_cond = flow_cond[:, -1, :].repeat_interleave(
Expand All @@ -582,7 +601,7 @@ def forward(
method="Heun",
step_size=0.05,
return_intermediates=False,
time_grid=time_grid
time_grid=time_grid,
)

# Reshape and scale the samples
Expand Down Expand Up @@ -638,6 +657,9 @@ def forward(
device=past_target.device,
)

# # the very last patch from past_target
# x = (future_samples[-1].view(batch_size*num_parallel_samples, -1) - loc) / scale

# Use updated conditioning from decoder
last_cond = flow_cond[:, -1, :]

Expand All @@ -655,7 +677,7 @@ def forward(
method="Heun",
step_size=0.05,
return_intermediates=False,
time_grid=time_grid
time_grid=time_grid,
)

# Scale and store the samples
Expand Down

0 comments on commit e9ceffa

Please sign in to comment.