Skip to content

Commit

Permalink
Use pytest temporary file fixture
Browse files Browse the repository at this point in the history
  • Loading branch information
nanxstats committed Dec 26, 2024
1 parent a5577b4 commit 0674960
Show file tree
Hide file tree
Showing 2 changed files with 37 additions and 46 deletions.
16 changes: 6 additions & 10 deletions tests/test_fit_disk.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,3 @@
from pathlib import Path
import tempfile

import pytest
import torch
import numpy as np
Expand All @@ -16,16 +13,15 @@


@pytest.fixture
def sample_data():
def sample_data(tmp_path):
"""Generate sample data and return both tensor and file path."""
with tempfile.TemporaryDirectory() as tmp_dir:
set_random_seed(42)
X, _, _ = generate_synthetic_data(n=N_DOCS, m=N_TERMS, k=N_TOPICS)
set_random_seed(42)
X, _, _ = generate_synthetic_data(n=N_DOCS, m=N_TERMS, k=N_TOPICS)

file_path = Path(tmp_dir) / "test_data.npy"
np.save(file_path, X.cpu().numpy())
file_path = tmp_path / "test_data.npy"
np.save(file_path, X.cpu().numpy())

yield X, file_path
return X, file_path


def test_disk_dataset_reproducibility(sample_data):
Expand Down
67 changes: 31 additions & 36 deletions tests/test_utils.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,3 @@
from pathlib import Path
import tempfile

import pytest
import torch
import numpy as np
Expand Down Expand Up @@ -148,26 +145,25 @@ def test_numpy_disk_dataset_from_array():
assert torch.allclose(item, torch.tensor(data[i], dtype=torch.float32))


def test_numpy_disk_dataset_from_file():
def test_numpy_disk_dataset_from_file(tmp_path):
"""Test NumpyDiskDataset with .npy file input."""
with tempfile.TemporaryDirectory() as tmp_dir:
data = np.random.rand(10, 5).astype(np.float32)
file_path = Path(tmp_dir) / "test_data.npy"
np.save(file_path, data)
data = np.random.rand(10, 5).astype(np.float32)
file_path = tmp_path / "test_data.npy"
np.save(file_path, data)

dataset = NumpyDiskDataset(file_path)
dataset = NumpyDiskDataset(file_path)

# Test basic properties
assert len(dataset) == 10
assert dataset.num_terms == 5
assert dataset.shape == (10, 5)
# Test basic properties
assert len(dataset) == 10
assert dataset.num_terms == 5
assert dataset.shape == (10, 5)

# Test data access
for i in range(len(dataset)):
item = dataset[i]
assert isinstance(item, torch.Tensor)
assert item.shape == (5,)
assert torch.allclose(item, torch.tensor(data[i], dtype=torch.float32))
# Test data access
for i in range(len(dataset)):
item = dataset[i]
assert isinstance(item, torch.Tensor)
assert item.shape == (5,)
assert torch.allclose(item, torch.tensor(data[i], dtype=torch.float32))


def test_numpy_disk_dataset_with_indices():
Expand Down Expand Up @@ -196,21 +192,20 @@ def test_numpy_disk_dataset_file_not_found():
NumpyDiskDataset("non_existent_file.npy")


def test_numpy_disk_dataset_memory_efficiency():
def test_numpy_disk_dataset_memory_efficiency(tmp_path):
"""Test that NumpyDiskDataset uses memory mapping efficiently."""
with tempfile.TemporaryDirectory() as tmp_dir:
shape = (1000, 500) # 500K elements
data = np.random.rand(*shape).astype(np.float32)
file_path = Path(tmp_dir) / "large_data.npy"
np.save(file_path, data)

dataset = NumpyDiskDataset(file_path)

# Access data in random order
indices = np.random.permutation(shape[0])[:100] # Sample 100 random rows
for idx in indices:
item = dataset[idx]
assert torch.allclose(item, torch.tensor(data[idx], dtype=torch.float32))

# Memory mapping should be initialized only after first access
assert dataset.mmap_data is not None
shape = (1000, 500) # 500K elements
data = np.random.rand(*shape).astype(np.float32)
file_path = tmp_path / "large_data.npy"
np.save(file_path, data)

dataset = NumpyDiskDataset(file_path)

# Access data in random order
indices = np.random.permutation(shape[0])[:100] # Sample 100 random rows
for idx in indices:
item = dataset[idx]
assert torch.allclose(item, torch.tensor(data[idx], dtype=torch.float32))

# Memory mapping should be initialized only after first access
assert dataset.mmap_data is not None

0 comments on commit 0674960

Please sign in to comment.