-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathdata.py
154 lines (133 loc) · 5.75 KB
/
data.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
import os
import random
from typing import Any, Dict, List, Optional, Tuple

from PIL import Image
from torch.utils.data import Dataset
from torchvision import transforms
class ImageNetValidationDataset(Dataset):
    """
    A PyTorch Dataset for loading ImageNet validation data.

    The validation directory is read as one sub-directory per class, where
    each sub-directory name is a WordNet ID and contains that class's images.

    Attributes:
        data_list (List[Tuple[str, str]]): List of (image_path, wnid) tuples,
            where wnid is the name of the class directory
        transform: Callable applied to each loaded image, or None
        class_mapping (Dict[str, str]): Mapping from WordNet IDs to class names
        wnids (List[str]): WordNet IDs, in mapping-file order
        class_names (List[str]): Class names, in mapping-file order
    """

    def __init__(
        self,
        validation_path: str,
        class_mapping_file: str,
        num_classes: Optional[int] = None,
        transform: Any = None,
        fraction_per_class: float = 1.0,
        seed: int = 42,
    ):
        """
        Initialize the dataset.

        Args:
            validation_path (str): Path to ImageNet validation set
            class_mapping_file (str): Path to synset words mapping file
            num_classes (int): Number of classes to sample; None (the
                default) keeps every class
            transform: Optional transforms to apply to images; the string
                'default' selects get_default_transforms()
            fraction_per_class (float): Fraction of files to keep per class
                (default: 1.0, i.e. all files)
            seed (int): Random seed for class and file sampling
        """
        if transform == "default":
            transform = get_default_transforms()
        self.transform = transform
        self.wnids, self.class_names = self._load_class_mapping(class_mapping_file)
        self.class_mapping: Dict[str, str] = dict(zip(self.wnids, self.class_names))
        self.data_list = self._create_data_list(
            validation_path, num_classes, seed, fraction_per_class
        )

    def _load_class_mapping(self, mapping_file: str) -> Tuple[List[str], List[str]]:
        """
        Load mapping between WordNet IDs and class names.

        Each usable line has the form ``<wnid> <name>[, <synonym>...]``; only
        the first comma-separated name is kept. Lines without a space-separated
        second field are ignored.

        Returns:
            Tuple[List[str], List[str]]: Parallel lists (wnids, class_names).
        """
        wnids = []
        class_names = []
        # Explicit encoding so parsing does not depend on the platform locale.
        with open(mapping_file, 'r', encoding='utf-8') as f:
            for line in f:
                parts = line.strip().split(' ', 1)
                if len(parts) == 2:
                    wnid, class_name = parts
                    class_names.append(class_name.split(', ')[0].strip())
                    wnids.append(wnid)
        return wnids, class_names

    def _create_data_list(
        self,
        validation_path: str,
        num_classes: Optional[int],
        seed: int,
        fraction_per_class: float = 1.0,
    ) -> List[Tuple[str, str]]:
        """
        Create list of (image_path, wnid) tuples with optional subsampling.

        Directory listings are sorted before sampling so that a given seed
        selects the same classes and files regardless of filesystem order.

        Args:
            validation_path (str): Path to the validation dataset.
            num_classes (int): Number of classes to sample. If None (or 0),
                include all classes.
            seed (int): Random seed for deterministic sampling.
            fraction_per_class (float): Fraction of files to include per class
                (default: 1.0, i.e., all files). At least one file is kept
                for every non-empty class.

        Returns:
            List[Tuple[str, str]]: List of (image_path, wnid) tuples.

        Raises:
            ValueError: If num_classes is negative or exceeds the number of
                class directories found.
        """
        # A private generator keeps the sampling reproducible without
        # reseeding the process-wide `random` state as a side effect.
        rng = random.Random(seed)

        # Only directories are classes; stray files in the root are ignored.
        class_dirs = sorted(
            entry
            for entry in os.listdir(validation_path)
            if os.path.isdir(os.path.join(validation_path, entry))
        )

        if num_classes:
            if not 0 < num_classes <= len(class_dirs):
                raise ValueError(
                    f"num_classes must be between 1 and {len(class_dirs)}, "
                    f"got {num_classes}"
                )
            sampled_classes = rng.sample(class_dirs, num_classes)
        else:
            sampled_classes = class_dirs

        data_list = []
        for class_name in sampled_classes:
            class_path = os.path.join(validation_path, class_name)
            all_files = sorted(os.listdir(class_path))
            if not all_files:
                # Nothing to sample from; sampling 1 of 0 would raise.
                continue

            if fraction_per_class < 1.0:
                num_files_to_sample = max(1, int(len(all_files) * fraction_per_class))
                print(f"Subsampling {fraction_per_class} of files from class {class_name}")
                print("Before:", len(all_files), "After:", num_files_to_sample)
                sampled_files = rng.sample(all_files, num_files_to_sample)
            else:
                sampled_files = all_files  # Use all files if fraction is 1.0

            data_list.extend(
                (os.path.join(class_path, file_name), class_name)
                for file_name in sampled_files
            )
        return data_list

    def __len__(self) -> int:
        """Return the total number of images."""
        return len(self.data_list)

    def __getitem__(self, idx: int) -> Tuple[Any, str, str]:
        """
        Get an image, its class name and its path.

        Args:
            idx (int): Index of the data item

        Returns:
            tuple: (image, class_name, img_path) where image is the RGB image
                after `transform` (if any) has been applied, class_name is
                the human-readable class name and img_path is the file the
                image was read from

        Raises:
            ValueError: If the item's WordNet ID is missing from the class
                mapping.
        """
        img_path, wnid = self.data_list[idx]
        # The context manager closes the underlying file once the pixel data
        # has been copied out by convert().
        with Image.open(img_path) as raw_image:
            image = raw_image.convert('RGB')
        if self.transform:
            image = self.transform(image)
        try:
            # Dict lookup instead of a linear scan over self.wnids per item.
            class_name = self.class_mapping[wnid]
        except KeyError:
            raise ValueError(f"{wnid!r} is not in the class mapping") from None
        return image, class_name, img_path
def get_default_transforms():
    """
    Build the default ImageNet preprocessing pipeline.

    Returns:
        A ``transforms.Compose`` chaining Resize(256), CenterCrop(224),
        ToTensor() and Normalize() with the per-channel mean/std listed below.
    """
    channel_mean = [0.485, 0.456, 0.406]
    channel_std = [0.229, 0.224, 0.225]
    pipeline = [
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=channel_mean, std=channel_std),
    ]
    return transforms.Compose(pipeline)
if __name__ == "__main__":
    # Manual smoke test: load the first validation sample from a
    # site-specific ImageNet copy and report what came back.
    VALIDATION_PATH = "/common/datasets/ImageNet_ILSVRC2012/val"
    CLASS_MAPPING_FILE = "/common/datasets/ImageNet_ILSVRC2012/synset_words.txt"

    dataset = ImageNetValidationDataset(
        validation_path=VALIDATION_PATH,
        class_mapping_file=CLASS_MAPPING_FILE,
        transform=get_default_transforms(),
    )

    # __getitem__ returns (image, class_name, img_path); unpacking into only
    # two names raised "too many values to unpack".
    image, class_name, _ = dataset[0]
    print(f"Image shape: {image.shape}")
    print(f"Class name: {class_name}")