From fa0048d3d39e5dbcfdbaa4e54380d64bb27420e7 Mon Sep 17 00:00:00 2001 From: Weihua Hu Date: Wed, 11 Oct 2023 20:30:28 -0700 Subject: [PATCH] add fixed split --- test/utils/test_split.py | 21 +++++++++++++++ torch_frame/data/dataset.py | 14 +++++----- torch_frame/datasets/fake.py | 7 ++--- torch_frame/datasets/torch_frame_datasets.py | 24 ++++++++++++++--- torch_frame/datasets/yandex.py | 3 ++- torch_frame/utils/__init__.py | 2 ++ torch_frame/utils/concat.py | 2 +- torch_frame/utils/split.py | 28 ++++++++++++++++++++ 8 files changed, 87 insertions(+), 14 deletions(-) create mode 100644 test/utils/test_split.py create mode 100644 torch_frame/utils/split.py diff --git a/test/utils/test_split.py b/test/utils/test_split.py new file mode 100644 index 00000000..889a0de3 --- /dev/null +++ b/test/utils/test_split.py @@ -0,0 +1,21 @@ +import numpy as np + +from torch_frame.utils.split import SPLIT_TO_NUM, generate_random_split + + +def test_generate_random_split(): + num_data = 20 + train_ratio = 0.8 + val_ratio = 0.1 + test_ratio = 0.1 + + split = generate_random_split(num_data, seed=42, train_ratio=train_ratio, + val_ratio=val_ratio) + assert (split == SPLIT_TO_NUM['train']).sum() == int(num_data * + train_ratio) + assert (split == SPLIT_TO_NUM['val']).sum() == int(num_data * val_ratio) + assert (split == SPLIT_TO_NUM['test']).sum() == int(num_data * test_ratio) + assert np.allclose( + split, + np.array([0, 1, 0, 0, 0, 0, 0, 0, 2, 1, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0]), + ) diff --git a/torch_frame/data/dataset.py b/torch_frame/data/dataset.py index 97bbacda..f02a50bc 100644 --- a/torch_frame/data/dataset.py +++ b/torch_frame/data/dataset.py @@ -25,6 +25,7 @@ IndexSelectType, TaskType, ) +from torch_frame.utils.split import SPLIT_TO_NUM def requires_pre_materialization(func): @@ -152,8 +153,8 @@ class Dataset(ABC): target_col (str, optional): The column used as target. (default: :obj:`None`) split_col (str, optional): The column that stores the pre-defined split - information. The column should only contain 'train', 'val', or - 'test'. (default: :obj:`None`). + information. The column should only contain :obj:`0`, :obj:`1`, or + :obj:`2`. (default: :obj:`None`). text_embedder_cfg (TextEmbedderConfig, optional): A text embedder config specifying :obj:`text_embedder` that maps sentences into PyTorch embeddings and :obj:`batch_size` that specifies the @@ -179,10 +180,10 @@ def __init__( raise ValueError( f"col_to_stype should not contain the split_col " f"({col_to_stype}).") - if not set(df[split_col]).issubset({'train', 'val', 'test'}): + if not set(df[split_col]).issubset(set(SPLIT_TO_NUM.values())): raise ValueError( - "split_col must only contain either 'train', 'val', or " - "'test'.") + f"split_col must only contain {set(SPLIT_TO_NUM.values())}" + ) self.split_col = split_col self.col_to_stype = col_to_stype @@ -419,7 +420,8 @@ def get_split_dataset(self, split: str) -> 'Dataset': if split not in ['train', 'val', 'test']: raise ValueError(f"The split named {split} is not available. " f"Needs to either 'train', 'val', or 'test'.") - indices = self.df.index[self.df[self.split_col] == split].tolist() + indices = self.df.index[self.df[self.split_col] == + SPLIT_TO_NUM[split]].tolist() return self[indices] @property diff --git a/torch_frame/datasets/fake.py b/torch_frame/datasets/fake.py index dd59109e..5c9a5e41 100644 --- a/torch_frame/datasets/fake.py +++ b/torch_frame/datasets/fake.py @@ -7,6 +7,7 @@ from torch_frame import stype from torch_frame.config.text_embedder import TextEmbedderConfig from torch_frame.typing import TaskType +from torch_frame.utils.split import SPLIT_TO_NUM class FakeDataset(torch_frame.data.Dataset): @@ -83,9 +84,9 @@ def __init__( if num_rows < 3: raise ValueError("Dataframe needs at least 3 rows to include" " each of train, val and test split.") - split = ['train'] * num_rows - split[1] = 'val' - split[2] = 'test' + split = [SPLIT_TO_NUM['train']] * num_rows + split[1] = SPLIT_TO_NUM['val'] + split[2] = SPLIT_TO_NUM['test'] df['split'] = split super().__init__( df, diff --git a/torch_frame/datasets/torch_frame_datasets.py b/torch_frame/datasets/torch_frame_datasets.py index 13f9d1b5..8cb510d8 100644 --- a/torch_frame/datasets/torch_frame_datasets.py +++ b/torch_frame/datasets/torch_frame_datasets.py @@ -2,6 +2,9 @@ import torch_frame from torch_frame.typing import TaskType +from torch_frame.utils import generate_random_split + +SPLIT_COL = 'split' class DataFrameBenchmark(torch_frame.data.Dataset): @@ -206,7 +209,14 @@ def datasets_available(cls, task_type: TaskType, def num_datasets_available(cls, task_type: TaskType, scale: str): return len(cls.datasets_available(task_type, scale)) - def __init__(self, root: str, task_type: TaskType, scale: str, idx: int): + def __init__( + self, + root: str, + task_type: TaskType, + scale: str, + idx: int, + split_random_state: int = 42, + ): self.root = root self._task_type = task_type self.scale = scale @@ -224,6 +234,14 @@ def __init__(self, root: str, task_type: TaskType, scale: str, idx: int): **kwargs) self.cls_str = str(dataset) + # Add split col + df = dataset.df + if SPLIT_COL in df.columns: + df.drop(columns=[SPLIT_COL], in_place=True) + df[SPLIT_COL] = generate_random_split(length=len(df), + seed=split_random_state, + train_ratio=0.8, val_ratio=0.1) + # check the scale if dataset.num_rows < 5000: assert False @@ -234,8 +252,8 @@ def __init__(self, root: str, task_type: TaskType, scale: str, idx: int): else: assert scale == "large" - super().__init__(df=dataset.df, col_to_stype=dataset.col_to_stype, - target_col=dataset.target_col) + super().__init__(df=df, col_to_stype=dataset.col_to_stype, + target_col=dataset.target_col, split_col=SPLIT_COL) del dataset def __repr__(self) -> str: diff --git a/torch_frame/datasets/yandex.py b/torch_frame/datasets/yandex.py index 63fbf4a7..c8b8fb74 100644 --- a/torch_frame/datasets/yandex.py +++ b/torch_frame/datasets/yandex.py @@ -6,6 +6,7 @@ import pandas as pd import torch_frame +from torch_frame.utils.split import SPLIT_TO_NUM def load_numpy_dict(path: str) -> Dict[str, np.ndarray]: @@ -89,7 +90,7 @@ def get_df_and_col_to_stype( df[n_col] = df[n_col].astype('float64') df['label'] = labels # Stores the split information in "split" column. - df['split'] = split + df['split'] = SPLIT_TO_NUM[split] dataframes.append(df) df = pd.concat(dataframes, ignore_index=True) diff --git a/torch_frame/utils/__init__.py b/torch_frame/utils/__init__.py index ff0d083e..3ab31254 100644 --- a/torch_frame/utils/__init__.py +++ b/torch_frame/utils/__init__.py @@ -1,8 +1,10 @@ from .io import save, load from .concat import cat +from .split import generate_random_split __all__ = functions = [ 'save', 'load', 'cat', + 'generate_random_split', ] diff --git a/torch_frame/utils/concat.py b/torch_frame/utils/concat.py index 683d0cad..3b6df956 100644 --- a/torch_frame/utils/concat.py +++ b/torch_frame/utils/concat.py @@ -5,7 +5,7 @@ from torch import Tensor import torch_frame -from torch_frame import TensorFrame +from torch_frame.data.tensor_frame import TensorFrame def cat(tf_list: List[TensorFrame], along: str) -> TensorFrame: diff --git a/torch_frame/utils/split.py b/torch_frame/utils/split.py new file mode 100644 index 00000000..db224a74 --- /dev/null +++ b/torch_frame/utils/split.py @@ -0,0 +1,28 @@ +import numpy as np + +SPLIT_TO_NUM = {'train': 0, 'val': 1, 'test': 2} + + +def generate_random_split(length: int, seed: int, train_ratio: float = 0.8, + val_ratio: float = 0.1) -> np.ndarray: + r"""Generate a list of random split assignments of the specified length. + The elements are either :obj:`0`, :obj:`1`, or :obj:`2`, representing + train, val, test, respectively. Note that this relies on the fact that + numpy's shuffle is consistent across versions, which has been historically + the case.""" + assert train_ratio + val_ratio < 1 + assert train_ratio > 0 + assert train_ratio > 0 + train_num = int(length * train_ratio) + val_num = int(length * val_ratio) + test_num = length - train_num - val_num + + arr = np.concatenate([ + np.full(train_num, SPLIT_TO_NUM['train']), + np.full(val_num, SPLIT_TO_NUM['val']), + np.full(test_num, SPLIT_TO_NUM['test']) + ]) + np.random.seed(seed) + np.random.shuffle(arr) + + return arr