Merge pull request #26 from nanxstats/dataset
Add `Dataset` input support for `fit_model()`
nanxstats authored Dec 26, 2024
2 parents aa75155 + 0674960 commit 6c12999
Showing 8 changed files with 312 additions and 41 deletions.
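In short, `fit_model()` now accepts either an in-memory tensor or a PyTorch `Dataset` such as the new `NumpyDiskDataset`. A minimal sketch of the two calling patterns (the toy matrix, the `counts.npy` file name, and the `k`/epoch values below are placeholders, not part of this commit):

```python
import torch
from tinytopics.fit import fit_model
from tinytopics.utils import NumpyDiskDataset

X = torch.rand(100, 50)  # toy document-term matrix: 100 docs, 50 terms

# Existing path: pass the document-term matrix as an in-memory tensor.
model, losses = fit_model(X, k=5, num_epochs=10)

# New path: pass any Dataset, e.g. a .npy matrix on disk read via memory mapping
# (assumes counts.npy already exists on disk).
model_disk, losses_disk = fit_model(NumpyDiskDataset("counts.npy"), k=5, num_epochs=10)
```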
1 change: 1 addition & 0 deletions docs/reference/utils.md
@@ -3,6 +3,7 @@
::: tinytopics.utils
    options:
      members:
+       - NumpyDiskDataset
        - set_random_seed
        - generate_synthetic_data
        - align_topics
1 change: 1 addition & 0 deletions src/tinytopics/__init__.py
@@ -16,6 +16,7 @@
    generate_synthetic_data,
    align_topics,
    sort_documents,
+   NumpyDiskDataset,
)
from .colors import pal_tinytopics, scale_color_tinytopics
from .plot import plot_loss, plot_structure, plot_top_terms
17 changes: 10 additions & 7 deletions src/tinytopics/colors.py
@@ -34,17 +34,20 @@ def pal_tinytopics(
    Args:
        format: Returned color format. Options are:
-           `hex`: Hex strings (default).
-           `rgb`: Array of RGB values.
-           `lab`: Array of CIELAB values.
+           - `hex`: Hex strings (default).
+           - `rgb`: Array of RGB values.
+           - `lab`: Array of CIELAB values.

    Returns:
-       - If `format='hex'`, returns a list of hex color strings.
-       - If `format='rgb'`, returns an Nx3 numpy array of RGB values.
-       - If `format='lab'`, returns an Nx3 numpy array of CIELAB values.
+       Colors in the requested format:
+       - If `format='hex'`, returns a list of hex color strings.
+       - If `format='rgb'`, returns an Nx3 numpy array of RGB values.
+       - If `format='lab'`, returns an Nx3 numpy array of CIELAB values.

    Raises:
-       ValueError: If format is not 'hex', 'rgb', or 'lab'.
+       ValueError: If format is not `'hex'`, `'rgb'`, or `'lab'`.
    """
    TINYTOPICS_10_COLORS: Sequence[str] = (
        "#4269D0",  # Blue
87 changes: 62 additions & 25 deletions src/tinytopics/fit.py
@@ -5,6 +5,7 @@
from torch import Tensor
from torch.optim import AdamW
from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts
+from torch.utils.data import DataLoader, Dataset
from tqdm import tqdm

from .models import NeuralPoissonNMF
@@ -27,8 +28,27 @@ def poisson_nmf_loss(X: Tensor, X_reconstructed: Tensor) -> Tensor:
    ).sum()


class IndexTrackingDataset(Dataset):
    """Dataset wrapper that tracks indices through shuffling"""

    def __init__(self, dataset: Dataset | Tensor) -> None:
        self.dataset = dataset
        self.shape: tuple[int, int] = (
            dataset.shape
            if hasattr(dataset, "shape")
            else (len(dataset), dataset[0].shape[0])
        )
        self.is_tensor: bool = isinstance(dataset, torch.Tensor)

    def __len__(self) -> int:
        return len(self.dataset)

    def __getitem__(self, idx: int) -> tuple[Tensor, Tensor]:
        return self.dataset[idx], torch.tensor(idx)


def fit_model(
-   X: Tensor,
+   X: Tensor | Dataset,
    k: int,
    num_epochs: int = 200,
    batch_size: int = 16,
@@ -40,29 +60,47 @@ def fit_model(
    device: torch.device | None = None,
) -> Tuple[NeuralPoissonNMF, Sequence[float]]:
    """
-   Fit topic model using sum-to-one constrained neural Poisson NMF,
-   optimized with AdamW and a cosine annealing with warm restarts scheduler.
+   Fit topic model using sum-to-one constrained neural Poisson NMF.
+   Supports both in-memory tensors and custom datasets.

    Args:
-       X: Document-term matrix.
+       X: Input data, can be:
+           - `torch.Tensor`: In-memory document-term matrix.
+           - `Dataset`: Custom dataset implementation.
+             For example, see `NumpyDiskDataset`.
        k: Number of topics.
-       num_epochs: Number of training epochs. Default is 200.
-       batch_size: Number of documents per batch. Default is 16.
-       base_lr: Minimum learning rate after annealing. Default is 0.01.
-       max_lr: Starting maximum learning rate. Default is 0.05.
-       T_0: Number of epochs until the first restart. Default is 20.
-       T_mult: Factor by which the restart interval increases after each restart. Default is 1.
-       weight_decay: Weight decay for the AdamW optimizer. Default is 1e-5.
-       device: Device to run the training on. Defaults to CUDA if available, otherwise CPU.
+       num_epochs: Number of training epochs.
+       batch_size: Number of documents per batch.
+       base_lr: Minimum learning rate after annealing.
+       max_lr: Starting maximum learning rate.
+       T_0: Number of epochs until first restart.
+       T_mult: Factor increasing restart interval.
+       weight_decay: Weight decay for AdamW optimizer.
+       device: Device to run training on.

    Returns:
-       A tuple containing:
-           - The trained NeuralPoissonNMF model
-           - List of training losses for each epoch
+       Tuple containing:
+           - Trained NeuralPoissonNMF model.
+           - List of training losses per epoch.
    """
    device = device or torch.device("cuda" if torch.cuda.is_available() else "cpu")
-   X = X.to(device)
-   n, m = X.shape

+   # Handle different input types
+   if isinstance(X, Dataset):
+       base_dataset = X
+       n = len(X)
+       m = X.num_terms if hasattr(X, "num_terms") else X[0].shape[0]
+   else:  # torch.Tensor
+       X = X.to(device)
+       n, m = X.shape
+       base_dataset = X  # Pass tensor directly

+   # Wrap dataset to track indices
+   dataset = IndexTrackingDataset(base_dataset)
+   dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

    model = NeuralPoissonNMF(n=n, m=m, k=k, device=device)
    optimizer = AdamW(model.parameters(), lr=max_lr, weight_decay=weight_decay)
@@ -71,25 +109,24 @@
    )

    losses: Sequence[float] = []
-   num_batches: int = n // batch_size
+   num_batches: int = len(dataloader)

    with tqdm(total=num_epochs, desc="Training Progress") as pbar:
        for epoch in range(num_epochs):
-           permutation = torch.randperm(n, device=device)
            epoch_loss: float = 0.0

-           for i in range(num_batches):
-               indices = permutation[i * batch_size : (i + 1) * batch_size]
-               batch_X = X[indices, :]
+           for batch_i, (batch_X, batch_indices) in enumerate(dataloader):
+               batch_X = batch_X.to(device)
+               batch_indices = batch_indices.to(device)

                optimizer.zero_grad()
-               X_reconstructed = model(indices)
+               X_reconstructed = model(batch_indices)
                loss = poisson_nmf_loss(batch_X, X_reconstructed)
                loss.backward()

                optimizer.step()

                # Update per batch for cosine annealing with restarts
-               scheduler.step(epoch + i / num_batches)
+               scheduler.step(epoch + batch_i / num_batches)

                epoch_loss += loss.item()

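Why the `IndexTrackingDataset` wrapper: the training loop above passes `batch_indices` to the model, so each shuffled batch coming out of the `DataLoader` must carry the original row indices of its documents, not just the rows themselves. A minimal sketch of that mechanism in isolation (the toy tensor and batch size are placeholders):

```python
import torch
from torch.utils.data import DataLoader

from tinytopics.fit import IndexTrackingDataset

X = torch.rand(8, 5)  # toy document-term matrix: 8 docs, 5 terms
loader = DataLoader(IndexTrackingDataset(X), batch_size=4, shuffle=True)

for batch_X, batch_indices in loader:
    # batch_indices records which original documents landed in this batch,
    # which is what the model uses to look up their document-topic weights.
    assert torch.equal(batch_X, X[batch_indices])
```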
16 changes: 8 additions & 8 deletions src/tinytopics/plot.py
@@ -22,9 +22,9 @@ def plot_loss(
    Args:
        losses: List of loss values for each epoch.
-       figsize: Plot size. Default is (10, 8).
-       dpi: Plot resolution. Default is 300.
-       title: Plot title. Default is "Loss curve".
+       figsize: Plot size.
+       dpi: Plot resolution.
+       title: Plot title.
        color_palette: Custom color palette.
        output_file: File path to save the plot. If None, displays the plot.
    """
@@ -63,8 +63,8 @@ def plot_structure(
    Args:
        L_matrix: Document-topic distribution matrix.
        normalize_rows: If True, normalizes each row of L_matrix to sum to 1.
-       figsize: Plot size. Default is (12, 6).
-       dpi: Plot resolution. Default is 300.
+       figsize: Plot size.
+       dpi: Plot resolution.
        title: Plot title.
        color_palette: Custom color palette.
        output_file: File path to save the plot. If None, displays the plot.
@@ -122,10 +122,10 @@ def plot_top_terms(
    Args:
        F_matrix: Topic-term distribution matrix.
-       n_top_terms: Number of top terms to display per topic. Default is 10.
+       n_top_terms: Number of top terms to display per topic.
        term_names: List of term names corresponding to indices.
-       figsize: Plot size. Default is (10, 8).
-       dpi: Plot resolution. Default is 300.
+       figsize: Plot size.
+       dpi: Plot resolution.
        title: Plot title.
        color_palette: Custom color palette.
        nrows: Number of rows in the subplot grid.
57 changes: 56 additions & 1 deletion src/tinytopics/utils.py
@@ -1,9 +1,11 @@
from typing import Tuple
from collections.abc import Sequence, MutableMapping
from collections import defaultdict
+from pathlib import Path

import torch
import numpy as np
+from torch.utils.data import Dataset
from scipy.optimize import linear_sum_assignment
from tqdm import tqdm

@@ -100,7 +102,7 @@ def sort_documents(L_matrix: np.ndarray) -> Sequence[int]:
        L_matrix (np.ndarray): Document-topic distribution matrix.

    Returns:
-       Sequence[int]: Indices of documents sorted by dominant topics.
+       Indices of documents sorted by dominant topics.
    """
    n, k = L_matrix.shape
    L_normalized = L_matrix / L_matrix.sum(axis=1, keepdims=True)
@@ -129,3 +131,56 @@ def sort_topic_groups(grouped_docs: MutableMapping[int, list]) -> Sequence[int]:
    doc_info = get_document_info()
    grouped_docs = group_by_topic(doc_info)
    return sort_topic_groups(grouped_docs)


class NumpyDiskDataset(Dataset):
    """
    A PyTorch Dataset class for loading document-term matrices from disk.

    The dataset can be initialized with either a path to a `.npy` file or
    a NumPy array. When a file path is provided, the data is accessed
    lazily using memory mapping, which is useful for handling large datasets
    that do not fit entirely in (CPU) memory.
    """

    def __init__(
        self, data: str | Path | np.ndarray, indices: Sequence[int] | None = None
    ) -> None:
        """
        Args:
            data: Either path to `.npy` file (str or Path) or numpy array.
            indices: Optional sequence of indices to use as valid indices.
        """
        if isinstance(data, (str, Path)):
            data_path = Path(data)
            if not data_path.exists():
                raise FileNotFoundError(f"Data file not found: {data_path}")
            # Get shape without loading full array
            self.shape: tuple[int, int] = tuple(np.load(data_path, mmap_mode="r").shape)
            self.data_path: Path = data_path
            self.mmap_data: np.ndarray | None = None
        else:
            self.shape: tuple[int, int] = data.shape
            self.data_path: None = None
            self.data: np.ndarray = data

        self.indices: Sequence[int] = indices or range(self.shape[0])

    def __len__(self) -> int:
        return len(self.indices)

    def __getitem__(self, idx: int) -> torch.Tensor:
        real_idx = self.indices[idx]

        if self.data_path is not None:
            # Load mmap data lazily
            if self.mmap_data is None:
                self.mmap_data = np.load(self.data_path, mmap_mode="r")
            return torch.tensor(self.mmap_data[real_idx], dtype=torch.float32)
        else:
            return torch.tensor(self.data[real_idx], dtype=torch.float32)

    @property
    def num_terms(self) -> int:
        """Return vocabulary size (number of columns)."""
        return self.shape[1]
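And a small usage sketch of `NumpyDiskDataset` on its own (the array size, file name, and `k`/epoch values are arbitrary): the matrix is written to `.npy` once, then rows are read lazily through NumPy memory mapping, so the full matrix never needs to be resident in CPU memory.

```python
import numpy as np

from tinytopics.fit import fit_model
from tinytopics.utils import NumpyDiskDataset

# Write a synthetic document-term matrix to disk once.
rng = np.random.default_rng(0)
counts = rng.poisson(1.0, size=(1_000, 500)).astype(np.float32)
np.save("counts.npy", counts)

# Rows are loaded on demand via np.load(..., mmap_mode="r").
dataset = NumpyDiskDataset("counts.npy")
print(len(dataset), dataset.num_terms)  # 1000 500

model, losses = fit_model(dataset, k=10, num_epochs=20)
```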
88 changes: 88 additions & 0 deletions tests/test_fit_disk.py
@@ -0,0 +1,88 @@
import pytest
import torch
import numpy as np

from tinytopics.utils import set_random_seed, generate_synthetic_data, NumpyDiskDataset
from tinytopics.fit import fit_model

# Test data dimensions
N_DOCS = 50
N_TERMS = 100
N_TOPICS = 5
N_EPOCHS = 3


@pytest.fixture
def sample_data(tmp_path):
    """Generate sample data and return both tensor and file path."""
    set_random_seed(42)
    X, _, _ = generate_synthetic_data(n=N_DOCS, m=N_TERMS, k=N_TOPICS)

    file_path = tmp_path / "test_data.npy"
    np.save(file_path, X.cpu().numpy())

    return X, file_path


def test_disk_dataset_reproducibility(sample_data):
    """Test that training with same disk dataset and seed gives identical results."""
    X, file_path = sample_data
    dataset = NumpyDiskDataset(file_path)

    set_random_seed(42)
    model1, losses1 = fit_model(dataset, k=N_TOPICS, num_epochs=N_EPOCHS)

    set_random_seed(42)
    model2, losses2 = fit_model(dataset, k=N_TOPICS, num_epochs=N_EPOCHS)

    assert np.allclose(losses1, losses2)
    assert torch.allclose(model1.get_normalized_L(), model2.get_normalized_L())
    assert torch.allclose(model1.get_normalized_F(), model2.get_normalized_F())


def test_disk_dataset_different_seeds(sample_data):
    """Test that training with same disk dataset but different seeds gives different results."""
    _, file_path = sample_data
    dataset = NumpyDiskDataset(file_path)

    set_random_seed(42)
    model1, losses1 = fit_model(dataset, k=N_TOPICS, num_epochs=N_EPOCHS)

    set_random_seed(43)
    model2, losses2 = fit_model(dataset, k=N_TOPICS, num_epochs=N_EPOCHS)

    assert not np.allclose(losses1, losses2)
    assert not torch.allclose(model1.get_normalized_L(), model2.get_normalized_L())
    assert not torch.allclose(model1.get_normalized_F(), model2.get_normalized_F())


def test_tensor_vs_disk_same_seed(sample_data):
    """Test that training with tensor and disk dataset gives identical results with same seed."""
    X, file_path = sample_data
    dataset = NumpyDiskDataset(file_path)

    set_random_seed(42)
    model1, losses1 = fit_model(X, k=N_TOPICS, num_epochs=N_EPOCHS)

    set_random_seed(42)
    model2, losses2 = fit_model(dataset, k=N_TOPICS, num_epochs=N_EPOCHS)

    assert np.allclose(losses1, losses2)
    assert torch.allclose(model1.get_normalized_L(), model2.get_normalized_L())
    assert torch.allclose(model1.get_normalized_F(), model2.get_normalized_F())


def test_tensor_vs_disk_different_seeds(sample_data):
    """Test that training with tensor and disk dataset gives different results with different seeds."""
    X, file_path = sample_data
    dataset = NumpyDiskDataset(file_path)

    set_random_seed(42)
    model1, losses1 = fit_model(X, k=N_TOPICS, num_epochs=N_EPOCHS)

    set_random_seed(43)
    model2, losses2 = fit_model(dataset, k=N_TOPICS, num_epochs=N_EPOCHS)

    assert not np.allclose(losses1, losses2)
    assert not torch.allclose(model1.get_normalized_L(), model2.get_normalized_L())
    assert not torch.allclose(model1.get_normalized_F(), model2.get_normalized_F())
