Commit

sample from the solver directly
kashif committed Dec 13, 2024
1 parent 469fde3 commit a36c562
Showing 1 changed file with 60 additions and 66 deletions.
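In short, this commit drops the hand-rolled Flow.step helper (which integrated one [t_start, t_end] interval per loop iteration) and instead calls the flow_matching ODESolver.sample once over the full time grid, letting the solver do the sub-stepping. The snippet below is a minimal sketch of that call pattern and is not part of the diff: it assumes the flow_matching package is installed, and ToyVelocity, the tensor shapes, and the dynamics are made up purely for illustration (the real VelocityModel conditions on transformer features).

import torch
from torch import nn
from flow_matching.solver import ODESolver


class ToyVelocity(nn.Module):
    # Stand-in velocity field with the same call signature as VelocityModel:
    # v(x, t, cond). The dynamics are arbitrary and only serve the example.
    def forward(self, x: torch.Tensor, t: torch.Tensor, cond: torch.Tensor) -> torch.Tensor:
        return cond - x


solver = ODESolver(velocity_model=ToyVelocity())

x_init = torch.randn(4, 8)  # noise drawn from the source distribution
cond = torch.zeros(4, 8)  # hypothetical conditioning tensor
time_grid = torch.linspace(0.0, 1.0, 11)  # integrate from t=0 to t=1

# One call replaces the per-interval loop; extra keyword arguments
# (here `cond`) are forwarded to the velocity model at each solver step.
x_1 = solver.sample(
    time_grid=time_grid,
    x_init=x_init,
    method="midpoint",
    step_size=0.05,
    return_intermediates=False,
    cond=cond,
)
print(x_1.shape)  # torch.Size([4, 8])

Because the solver handles the stepping internally, the step size stays fixed (0.05 in the diff) rather than being tied to the grid spacing as in the removed Flow.step.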
126 changes: 60 additions & 66 deletions src/gluonts/torch/model/seg_diff/module.py
@@ -24,7 +24,13 @@
from gluonts.torch.scaler import StdScaler, MeanScaler, NOPScaler
from gluonts.torch.util import take_last, unsqueeze_expand, weighted_average
from gluonts.torch.model.simple_feedforward import make_linear_layer
-from flow_matching.path.scheduler import CondOTScheduler, CosineScheduler, PolynomialConvexScheduler, VPScheduler, LinearVPScheduler
+from flow_matching.path.scheduler import (
+    CondOTScheduler,
+    CosineScheduler,
+    PolynomialConvexScheduler,
+    VPScheduler,
+    LinearVPScheduler,
+)
from flow_matching.path import AffineProbPath
from flow_matching.solver import ODESolver

@@ -110,23 +116,24 @@ def forward(self, x: torch.Tensor):
return self.layer_norm(out)
return out


class VelocityModel(nn.Module):
-    def __init__(self, cond_dim: int, out_dim: int, h: int, time_embed_dim: int = 8):
+    def __init__(
+        self, cond_dim: int, out_dim: int, h: int, time_embed_dim: int = 8
+    ):
super().__init__()
# Time embedding network
self.time_embed = nn.Sequential(
nn.Linear(1, time_embed_dim),
nn.GELU(),
nn.Linear(time_embed_dim, time_embed_dim),
)

# Conditioning network for better feature extraction
self.cond_net = nn.Sequential(
-            nn.Linear(cond_dim, h),
-            nn.GELU(),
-            nn.Linear(h, h)
+            nn.Linear(cond_dim, h), nn.GELU(), nn.Linear(h, h)
)

# Main velocity network with skip connections
self.net = nn.Sequential(
nn.Linear(out_dim + h + time_embed_dim, h),
@@ -136,12 +143,12 @@ def __init__(self, cond_dim: int, out_dim: int, h: int, time_embed_dim: int = 8)
nn.Dropout(0.1), # Add some regularization
nn.Linear(h, h),
nn.GELU(),
-            nn.Linear(h, out_dim)
+            nn.Linear(h, out_dim),
)

# Initialize weights for better gradient flow
self.apply(self._init_weights)

def _init_weights(self, module):
if isinstance(module, nn.Linear):
torch.nn.init.xavier_uniform_(module.weight)
@@ -152,75 +159,51 @@ def forward(self, x: torch.Tensor, t: torch.Tensor, cond: torch.Tensor):
# Handle different time tensor shapes
if t.ndim == 0: # scalar time
t = t.view(1)

# Expand t to match batch dimensions of x
while t.ndim < x.ndim:
t = t.unsqueeze(-1)
t = t.expand(*x.shape[:-1], 1)

# Get time embeddings
t_embed = self.time_embed(t)

# Process conditioning
cond_features = self.cond_net(cond)

# Concatenate all inputs and compute velocity
inputs = torch.cat([x, t_embed, cond_features], dim=-1)
return self.net(inputs)


class Flow(nn.Module):
def __init__(self, cond_dim: int, out_dim: int, h: int):
super().__init__()

# Define MLP for velocity field
self.velocity_model = VelocityModel(cond_dim, out_dim, h)

# Flow matching components
scheduler = CondOTScheduler()
self.prob_path = AffineProbPath(scheduler=scheduler)
self.solver = ODESolver(self.velocity_model)

-    def compute_loss(self, x_0: torch.Tensor, x_1: torch.Tensor, cond: torch.Tensor) -> torch.Tensor:
+    def compute_loss(
+        self, x_0: torch.Tensor, x_1: torch.Tensor, cond: torch.Tensor
+    ) -> torch.Tensor:
"""Compute flow matching loss."""
# Sample time uniformly
t = torch.rand(x_0.shape[0], x_0.shape[1], device=x_0.device)

# Get path sample from probability path with scheduler outputs
-        path_sample = self.prob_path.sample(
-            t=t,
-            x_0=x_0,
-            x_1=x_1,
-        )
+        path_sample = self.prob_path.sample(t=t, x_0=x_0, x_1=x_1)

# Get velocity field prediction
v_t = self.velocity_model(path_sample.x_t, path_sample.t, cond)

# Flow matching loss
return F.mse_loss(v_t, path_sample.dx_t)

-    @torch.inference_mode()
-    def step(
-        self,
-        x_t: torch.Tensor,
-        t_start: float,
-        t_end: float,
-        cond: torch.Tensor,
-    ) -> torch.Tensor:
-        """Performs one step of the flow matching process using ODE solver."""
-        # Create time grid for integration
-        T = torch.tensor([t_start, t_end], device=x_t.device)
-
-        # Solve ODE
-        sol = self.solver.sample(
-            x_init=x_t,
-            time_grid=T,
-            method='midpoint',
-            step_size=t_end - t_start,
-            cond=cond  # Pass conditioning as context
-        )
-
-        return sol


class SegDiffModel(nn.Module):
"""
@@ -423,8 +406,10 @@ def loss(
# Flow matching loss
x_1 = target[:, 1:, :] # Target patches
x_0 = torch.randn_like(x_1) # Random noise source distribution

-        return self.flow.compute_loss(x_0=x_0, x_1=x_1, cond=flow_cond[:, :-1, :])
+        return self.flow.compute_loss(
+            x_0=x_0, x_1=x_1, cond=flow_cond[:, :-1, :]
+        )

def forward(
self,
@@ -455,22 +440,23 @@ def forward(
)

# Setup time steps for flow
-        time_steps = torch.linspace(0, 1.0, self.n_steps + 1, device=x.device)
+        time_grid = torch.linspace(0, 1.0, self.n_steps + 1, device=x.device)

# Get last decoder output and repeat for parallel samples
last_cond = flow_cond[:, -1, :].repeat_interleave(
num_parallel_samples, dim=0
)

# Evolve the samples through time using the flow
-        for i in range(self.n_steps):
-            x = self.flow.step(
-                x_t=x,
-                t_start=time_steps[i],
-                t_end=time_steps[i + 1],
-                cond=last_cond,
-            )
+        x = self.flow.solver.sample(
+            time_grid=time_grid,
+            x_init=x,
+            method="midpoint",
+            step_size=0.05,
+            return_intermediates=False,
+            cond=last_cond,
+        )

# Reshape and scale the samples
next_sample = x.view(
@@ -529,13 +515,21 @@ def forward(
last_cond = flow_cond[:, -1, :]

# Evolve the new samples
-        for i in range(self.n_steps):
-            x = self.flow.step(
-                x_t=x,
-                t_start=time_steps[i],
-                t_end=time_steps[i + 1],
-                cond=last_cond,
-            )
+        # for i in range(self.n_steps):
+        #     x = self.flow.step(
+        #         x_t=x,
+        #         t_start=time_steps[i],
+        #         t_end=time_steps[i + 1],
+        #         cond=last_cond,
+        #     )
+        x = self.flow.solver.sample(
+            time_grid=time_grid,
+            x_init=x,
+            method="midpoint",
+            step_size=0.05,
+            return_intermediates=False,
+            cond=last_cond,
+        )

# Scale and store the samples
next_sample = x.view(
