Skip to content

Commit

Permalink
add fixed split
Browse files Browse the repository at this point in the history
  • Loading branch information
weihua916 committed Oct 12, 2023
1 parent f40a360 commit fa0048d
Show file tree
Hide file tree
Showing 8 changed files with 87 additions and 14 deletions.
21 changes: 21 additions & 0 deletions test/utils/test_split.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
import numpy as np

from torch_frame.utils.split import SPLIT_TO_NUM, generate_random_split


def test_generate_random_split():
num_data = 20
train_ratio = 0.8
val_ratio = 0.1
test_ratio = 0.1

split = generate_random_split(num_data, seed=42, train_ratio=train_ratio,
val_ratio=val_ratio)
assert (split == SPLIT_TO_NUM['train']).sum() == int(num_data *
train_ratio)
assert (split == SPLIT_TO_NUM['val']).sum() == int(num_data * val_ratio)
assert (split == SPLIT_TO_NUM['test']).sum() == int(num_data * test_ratio)
assert np.allclose(
split,
np.array([0, 1, 0, 0, 0, 0, 0, 0, 2, 1, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0]),
)
14 changes: 8 additions & 6 deletions torch_frame/data/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
IndexSelectType,
TaskType,
)
from torch_frame.utils.split import SPLIT_TO_NUM


def requires_pre_materialization(func):
Expand Down Expand Up @@ -152,8 +153,8 @@ class Dataset(ABC):
target_col (str, optional): The column used as target.
(default: :obj:`None`)
split_col (str, optional): The column that stores the pre-defined split
information. The column should only contain 'train', 'val', or
'test'. (default: :obj:`None`).
information. The column should only contain :obj:`0`, :obj:`1`, or
:obj:`2`. (default: :obj:`None`).
text_embedder_cfg (TextEmbedderConfig, optional): A text embedder
config specifying :obj:`text_embedder` that maps sentences into
PyTorch embeddings and :obj:`batch_size` that specifies the
Expand All @@ -179,10 +180,10 @@ def __init__(
raise ValueError(
f"col_to_stype should not contain the split_col "
f"({col_to_stype}).")
if not set(df[split_col]).issubset({'train', 'val', 'test'}):
if not set(df[split_col]).issubset(set(SPLIT_TO_NUM.values())):
raise ValueError(
"split_col must only contain either 'train', 'val', or "
"'test'.")
f"split_col must only contain {set(SPLIT_TO_NUM.values())}"
)
self.split_col = split_col
self.col_to_stype = col_to_stype

Expand Down Expand Up @@ -419,7 +420,8 @@ def get_split_dataset(self, split: str) -> 'Dataset':
if split not in ['train', 'val', 'test']:
raise ValueError(f"The split named {split} is not available. "
f"Needs to either 'train', 'val', or 'test'.")
indices = self.df.index[self.df[self.split_col] == split].tolist()
indices = self.df.index[self.df[self.split_col] ==
SPLIT_TO_NUM[split]].tolist()
return self[indices]

@property
Expand Down
7 changes: 4 additions & 3 deletions torch_frame/datasets/fake.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from torch_frame import stype
from torch_frame.config.text_embedder import TextEmbedderConfig
from torch_frame.typing import TaskType
from torch_frame.utils.split import SPLIT_TO_NUM


class FakeDataset(torch_frame.data.Dataset):
Expand Down Expand Up @@ -83,9 +84,9 @@ def __init__(
if num_rows < 3:
raise ValueError("Dataframe needs at least 3 rows to include"
" each of train, val and test split.")
split = ['train'] * num_rows
split[1] = 'val'
split[2] = 'test'
split = [SPLIT_TO_NUM['train']] * num_rows
split[1] = SPLIT_TO_NUM['val']
split[2] = SPLIT_TO_NUM['test']
df['split'] = split
super().__init__(
df,
Expand Down
24 changes: 21 additions & 3 deletions torch_frame/datasets/torch_frame_datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,9 @@

import torch_frame
from torch_frame.typing import TaskType
from torch_frame.utils import generate_random_split

SPLIT_COL = 'split'


class DataFrameBenchmark(torch_frame.data.Dataset):
Expand Down Expand Up @@ -206,7 +209,14 @@ def datasets_available(cls, task_type: TaskType,
def num_datasets_available(cls, task_type: TaskType, scale: str):
return len(cls.datasets_available(task_type, scale))

def __init__(self, root: str, task_type: TaskType, scale: str, idx: int):
def __init__(
self,
root: str,
task_type: TaskType,
scale: str,
idx: int,
split_random_state: int = 42,
):
self.root = root
self._task_type = task_type
self.scale = scale
Expand All @@ -224,6 +234,14 @@ def __init__(self, root: str, task_type: TaskType, scale: str, idx: int):
**kwargs)
self.cls_str = str(dataset)

# Add split col
df = dataset.df
if SPLIT_COL in df.columns:
df.drop(columns=[SPLIT_COL], in_place=True)
df[SPLIT_COL] = generate_random_split(length=len(df),
seed=split_random_state,
train_ratio=0.8, val_ratio=0.1)

# check the scale
if dataset.num_rows < 5000:
assert False
Expand All @@ -234,8 +252,8 @@ def __init__(self, root: str, task_type: TaskType, scale: str, idx: int):
else:
assert scale == "large"

super().__init__(df=dataset.df, col_to_stype=dataset.col_to_stype,
target_col=dataset.target_col)
super().__init__(df=df, col_to_stype=dataset.col_to_stype,
target_col=dataset.target_col, split_col=SPLIT_COL)
del dataset

def __repr__(self) -> str:
Expand Down
3 changes: 2 additions & 1 deletion torch_frame/datasets/yandex.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import pandas as pd

import torch_frame
from torch_frame.utils.split import SPLIT_TO_NUM


def load_numpy_dict(path: str) -> Dict[str, np.ndarray]:
Expand Down Expand Up @@ -89,7 +90,7 @@ def get_df_and_col_to_stype(
df[n_col] = df[n_col].astype('float64')
df['label'] = labels
# Stores the split information in "split" column.
df['split'] = split
df['split'] = SPLIT_TO_NUM[split]
dataframes.append(df)

df = pd.concat(dataframes, ignore_index=True)
Expand Down
2 changes: 2 additions & 0 deletions torch_frame/utils/__init__.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
from .io import save, load
from .concat import cat
from .split import generate_random_split

__all__ = functions = [
'save',
'load',
'cat',
'generate_random_split',
]
2 changes: 1 addition & 1 deletion torch_frame/utils/concat.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from torch import Tensor

import torch_frame
from torch_frame import TensorFrame
from torch_frame.data.tensor_frame import TensorFrame


def cat(tf_list: List[TensorFrame], along: str) -> TensorFrame:
Expand Down
28 changes: 28 additions & 0 deletions torch_frame/utils/split.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
import numpy as np

SPLIT_TO_NUM = {'train': 0, 'val': 1, 'test': 2}


def generate_random_split(length: int, seed: int, train_ratio: float = 0.8,
val_ratio: float = 0.1) -> np.ndarray:
r"""Generate a list of random split assignments of the specified length.
The elements are either :obj:`0`, :obj:`1`, or :obj:`2`, representing
train, val, test, respectively. Note that this relies on the fact that
numpy's shuffle is consistent across versions, which has been historically
the case."""
assert train_ratio + val_ratio < 1
assert train_ratio > 0
assert train_ratio > 0
train_num = int(length * train_ratio)
val_num = int(length * val_ratio)
test_num = length - train_num - val_num

arr = np.concatenate([
np.full(train_num, SPLIT_TO_NUM['train']),
np.full(val_num, SPLIT_TO_NUM['val']),
np.full(test_num, SPLIT_TO_NUM['test'])
])
np.random.seed(seed)
np.random.shuffle(arr)

return arr

0 comments on commit fa0048d

Please sign in to comment.