From 0674960697d75dac4eeb18b5f4a31a70a3bb4e69 Mon Sep 17 00:00:00 2001 From: Nan Xiao Date: Wed, 25 Dec 2024 23:52:41 -0500 Subject: [PATCH] Use pytest temporary file fixture --- tests/test_fit_disk.py | 16 ++++------ tests/test_utils.py | 67 +++++++++++++++++++----------------------- 2 files changed, 37 insertions(+), 46 deletions(-) diff --git a/tests/test_fit_disk.py b/tests/test_fit_disk.py index fc58403..7dd3287 100644 --- a/tests/test_fit_disk.py +++ b/tests/test_fit_disk.py @@ -1,6 +1,3 @@ -from pathlib import Path -import tempfile - import pytest import torch import numpy as np @@ -16,16 +13,15 @@ @pytest.fixture -def sample_data(): +def sample_data(tmp_path): """Generate sample data and return both tensor and file path.""" - with tempfile.TemporaryDirectory() as tmp_dir: - set_random_seed(42) - X, _, _ = generate_synthetic_data(n=N_DOCS, m=N_TERMS, k=N_TOPICS) + set_random_seed(42) + X, _, _ = generate_synthetic_data(n=N_DOCS, m=N_TERMS, k=N_TOPICS) - file_path = Path(tmp_dir) / "test_data.npy" - np.save(file_path, X.cpu().numpy()) + file_path = tmp_path / "test_data.npy" + np.save(file_path, X.cpu().numpy()) - yield X, file_path + return X, file_path def test_disk_dataset_reproducibility(sample_data): diff --git a/tests/test_utils.py b/tests/test_utils.py index 23e9871..77dcdd1 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -1,6 +1,3 @@ -from pathlib import Path -import tempfile - import pytest import torch import numpy as np @@ -148,26 +145,25 @@ def test_numpy_disk_dataset_from_array(): assert torch.allclose(item, torch.tensor(data[i], dtype=torch.float32)) -def test_numpy_disk_dataset_from_file(): +def test_numpy_disk_dataset_from_file(tmp_path): """Test NumpyDiskDataset with .npy file input.""" - with tempfile.TemporaryDirectory() as tmp_dir: - data = np.random.rand(10, 5).astype(np.float32) - file_path = Path(tmp_dir) / "test_data.npy" - np.save(file_path, data) + data = np.random.rand(10, 5).astype(np.float32) + file_path = tmp_path / "test_data.npy" + np.save(file_path, data) - dataset = NumpyDiskDataset(file_path) + dataset = NumpyDiskDataset(file_path) - # Test basic properties - assert len(dataset) == 10 - assert dataset.num_terms == 5 - assert dataset.shape == (10, 5) + # Test basic properties + assert len(dataset) == 10 + assert dataset.num_terms == 5 + assert dataset.shape == (10, 5) - # Test data access - for i in range(len(dataset)): - item = dataset[i] - assert isinstance(item, torch.Tensor) - assert item.shape == (5,) - assert torch.allclose(item, torch.tensor(data[i], dtype=torch.float32)) + # Test data access + for i in range(len(dataset)): + item = dataset[i] + assert isinstance(item, torch.Tensor) + assert item.shape == (5,) + assert torch.allclose(item, torch.tensor(data[i], dtype=torch.float32)) def test_numpy_disk_dataset_with_indices(): @@ -196,21 +192,20 @@ def test_numpy_disk_dataset_file_not_found(): NumpyDiskDataset("non_existent_file.npy") -def test_numpy_disk_dataset_memory_efficiency(): +def test_numpy_disk_dataset_memory_efficiency(tmp_path): """Test that NumpyDiskDataset uses memory mapping efficiently.""" - with tempfile.TemporaryDirectory() as tmp_dir: - shape = (1000, 500) # 500K elements - data = np.random.rand(*shape).astype(np.float32) - file_path = Path(tmp_dir) / "large_data.npy" - np.save(file_path, data) - - dataset = NumpyDiskDataset(file_path) - - # Access data in random order - indices = np.random.permutation(shape[0])[:100] # Sample 100 random rows - for idx in indices: - item = dataset[idx] - assert torch.allclose(item, torch.tensor(data[idx], dtype=torch.float32)) - - # Memory mapping should be initialized only after first access - assert dataset.mmap_data is not None + shape = (1000, 500) # 500K elements + data = np.random.rand(*shape).astype(np.float32) + file_path = tmp_path / "large_data.npy" + np.save(file_path, data) + + dataset = NumpyDiskDataset(file_path) + + # Access data in random order + indices = np.random.permutation(shape[0])[:100] # Sample 100 random rows + for idx in indices: + item = dataset[idx] + assert torch.allclose(item, torch.tensor(data[idx], dtype=torch.float32)) + + # Memory mapping should be initialized only after first access + assert dataset.mmap_data is not None