From 55ae9070d4de4c7155c2c8f211013c991c36430c Mon Sep 17 00:00:00 2001 From: Weihua Hu Date: Thu, 12 Oct 2023 12:32:32 -0700 Subject: [PATCH] Add `DataFrameBenchmark` (#107) For standardized benchmark. Categorize existing datasets into 9 categories with different task types (binary cls, multi-class cls, regression) and different scales (small, medium, large). Usage is as follows: ```python dataset = DataFrameBenchmark(root, task_type=TaskType.BINARY_CLASSIFICATION, scale = "medium", idx = 2) # Get fixed split train_dataset = dataset.get_split_dataset('train') val_dataset = dataset.get_split_dataset('val') test_dataset = dataset.get_split_dataset('test') ``` Dataset documentation [here](https://pyg-team-pytorch-frame--107.com.readthedocs.build/en/107/generated/torch_frame.datasets.DataFrameBenchmark.html#torch_frame.datasets.DataFrameBenchmark). --- CHANGELOG.md | 1 + test/datasets/test_data_frame_benchmark.py | 134 ++++ test/utils/test_split.py | 21 + torch_frame/data/dataset.py | 22 +- torch_frame/datasets/__init__.py | 2 + torch_frame/datasets/data_frame_benchmark.py | 788 +++++++++++++++++++ torch_frame/datasets/fake.py | 7 +- torch_frame/datasets/tabular_benchmark.py | 12 + torch_frame/datasets/yandex.py | 32 +- torch_frame/utils/__init__.py | 2 + torch_frame/utils/concat.py | 2 +- torch_frame/utils/split.py | 29 + 12 files changed, 1034 insertions(+), 18 deletions(-) create mode 100644 test/datasets/test_data_frame_benchmark.py create mode 100644 test/utils/test_split.py create mode 100644 torch_frame/datasets/data_frame_benchmark.py create mode 100644 torch_frame/utils/split.py diff --git a/CHANGELOG.md b/CHANGELOG.md index f489876b..2ff3d2f7 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,6 +7,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). ### Added +- Added `DataFrameBenchmark` ([#107](https://github.com/pyg-team/pytorch-frame/pull/107)). 
- Added stats to datasets documentation ([#101](https://github.com/pyg-team/pytorch-frame/pull/101)). - Add concat and equal ops for `TensorFrame` ([#100](https://github.com/pyg-team/pytorch-frame/pull/100)). - Use ROC-AUC for binary classification in GBDT ([#98](https://github.com/pyg-team/pytorch-frame/pull/98)). diff --git a/test/datasets/test_data_frame_benchmark.py b/test/datasets/test_data_frame_benchmark.py new file mode 100644 index 00000000..668ac3ae --- /dev/null +++ b/test/datasets/test_data_frame_benchmark.py @@ -0,0 +1,134 @@ +import pytest + +from torch_frame.datasets import DataFrameBenchmark +from torch_frame.typing import TaskType + + +@pytest.mark.parametrize('scale', ["small", "medium", "large"]) +@pytest.mark.parametrize('task_type', [ + TaskType.BINARY_CLASSIFICATION, TaskType.MULTICLASS_CLASSIFICATION, + TaskType.REGRESSION +]) +def test_data_frame_benchmark_match(task_type, scale): + # Make sure task_type, scale, idx triple map to the fixed underlying + # dataset. New dataset can be appeneded, but the existing mapping needes to + # be preserved. 
+ datasets = DataFrameBenchmark.datasets_available(task_type=task_type, + scale=scale) + if task_type == TaskType.BINARY_CLASSIFICATION: + if scale == 'small': + assert datasets[0] == ('AdultCensusIncome', {}) + assert datasets[1] == ('Mushroom', {}) + assert datasets[2] == ('BankMarketing', {}) + assert datasets[3] == ('TabularBenchmark', { + 'name': 'MagicTelescope' + }) + assert datasets[4] == ('TabularBenchmark', { + 'name': 'bank-marketing' + }) + assert datasets[5] == ('TabularBenchmark', {'name': 'california'}) + assert datasets[6] == ('TabularBenchmark', {'name': 'credit'}) + assert datasets[7] == ('TabularBenchmark', { + 'name': 'default-of-credit-card-clients' + }) + assert datasets[8] == ('TabularBenchmark', {'name': 'electricity'}) + assert datasets[9] == ('TabularBenchmark', { + 'name': 'eye_movements' + }) + assert datasets[10] == ('TabularBenchmark', {'name': 'heloc'}) + assert datasets[11] == ('TabularBenchmark', {'name': 'house_16H'}) + assert datasets[12] == ('TabularBenchmark', {'name': 'pol'}) + assert datasets[13] == ('Yandex', {'name': 'adult'}) + elif scale == 'medium': + assert datasets[0] == ('Dota2', {}) + assert datasets[1] == ('KDDCensusIncome', {}) + assert datasets[2] == ('TabularBenchmark', { + 'name': 'Diabetes130US' + }) + assert datasets[3] == ('TabularBenchmark', {'name': 'MiniBooNE'}) + assert datasets[4] == ('TabularBenchmark', {'name': 'albert'}) + assert datasets[5] == ('TabularBenchmark', {'name': 'covertype'}) + assert datasets[6] == ('TabularBenchmark', {'name': 'jannis'}) + assert datasets[7] == ('TabularBenchmark', {'name': 'road-safety'}) + assert datasets[8] == ('Yandex', {'name': 'higgs_small'}) + elif scale == 'large': + assert datasets[0] == ('TabularBenchmark', {'name': 'Higgs'}) + elif task_type == TaskType.MULTICLASS_CLASSIFICATION: + if scale == 'small': + assert len(datasets) == 0 + elif scale == 'medium': + assert datasets[0] == ('Yandex', {'name': 'aloi'}) + assert datasets[1] == ('Yandex', {'name': 'helena'}) 
+ assert datasets[2] == ('Yandex', {'name': 'jannis'}) + elif scale == 'large': + assert datasets[0] == ('ForestCoverType', {}) + assert datasets[1] == ('PokerHand', {}) + assert datasets[2] == ('Yandex', {'name': 'covtype'}) + elif task_type == TaskType.REGRESSION: + if scale == 'small': + assert datasets[0] == ('TabularBenchmark', { + 'name': 'Bike_Sharing_Demand' + }) + assert datasets[1] == ('TabularBenchmark', { + 'name': 'Brazilian_houses' + }) + assert datasets[2] == ('TabularBenchmark', {'name': 'cpu_act'}) + assert datasets[3] == ('TabularBenchmark', {'name': 'elevators'}) + assert datasets[4] == ('TabularBenchmark', {'name': 'house_sales'}) + assert datasets[5] == ('TabularBenchmark', {'name': 'houses'}) + assert datasets[6] == ('TabularBenchmark', {'name': 'sulfur'}) + assert datasets[7] == ('TabularBenchmark', { + 'name': 'superconduct' + }) + assert datasets[8] == ('TabularBenchmark', {'name': 'topo_2_1'}) + assert datasets[9] == ('TabularBenchmark', { + 'name': 'visualizing_soil' + }) + assert datasets[10] == ('TabularBenchmark', { + 'name': 'wine_quality' + }) + assert datasets[11] == ('TabularBenchmark', {'name': 'yprop_4_1'}) + assert datasets[12] == ('Yandex', {'name': 'california_housing'}) + elif scale == 'medium': + assert datasets[0] == ('TabularBenchmark', { + 'name': 'Allstate_Claims_Severity' + }) + assert datasets[1] == ('TabularBenchmark', { + 'name': 'SGEMM_GPU_kernel_performance' + }) + assert datasets[2] == ('TabularBenchmark', {'name': 'diamonds'}) + assert datasets[3] == ('TabularBenchmark', { + 'name': 'medical_charges' + }) + assert datasets[4] == ('TabularBenchmark', { + 'name': 'particulate-matter-ukair-2017' + }) + assert datasets[5] == ('TabularBenchmark', { + 'name': 'seattlecrime6' + }) + elif scale == 'large': + assert datasets[0] == ('TabularBenchmark', { + 'name': 'Airlines_DepDelay_1M' + }) + assert datasets[1] == ('TabularBenchmark', { + 'name': 'delays_zurich_transport' + }) + assert datasets[2] == ('TabularBenchmark', 
{ + 'name': 'nyc-taxi-green-dec-2016' + }) + assert datasets[3] == ('Yandex', {'name': 'microsoft'}) + assert datasets[4] == ('Yandex', {'name': 'yahoo'}) + assert datasets[5] == ('Yandex', {'name': 'year'}) + + +def test_data_frame_benchmark_object(tmp_path): + dataset = DataFrameBenchmark(tmp_path, TaskType.BINARY_CLASSIFICATION, + 'small', 1) + assert str(dataset) == ("DataFrameBenchmark(\n" + " task_type=binary_classification,\n" + " scale=small,\n" + " idx=1,\n" + " cls=Mushroom()\n" + ")") + assert dataset.num_rows == 8124 + dataset.materialize() diff --git a/test/utils/test_split.py b/test/utils/test_split.py new file mode 100644 index 00000000..889a0de3 --- /dev/null +++ b/test/utils/test_split.py @@ -0,0 +1,21 @@ +import numpy as np + +from torch_frame.utils.split import SPLIT_TO_NUM, generate_random_split + + +def test_generate_random_split(): + num_data = 20 + train_ratio = 0.8 + val_ratio = 0.1 + test_ratio = 0.1 + + split = generate_random_split(num_data, seed=42, train_ratio=train_ratio, + val_ratio=val_ratio) + assert (split == SPLIT_TO_NUM['train']).sum() == int(num_data * + train_ratio) + assert (split == SPLIT_TO_NUM['val']).sum() == int(num_data * val_ratio) + assert (split == SPLIT_TO_NUM['test']).sum() == int(num_data * test_ratio) + assert np.allclose( + split, + np.array([0, 1, 0, 0, 0, 0, 0, 0, 2, 1, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0]), + ) diff --git a/torch_frame/data/dataset.py b/torch_frame/data/dataset.py index 1f73c189..a88574f6 100644 --- a/torch_frame/data/dataset.py +++ b/torch_frame/data/dataset.py @@ -25,6 +25,7 @@ IndexSelectType, TaskType, ) +from torch_frame.utils.split import SPLIT_TO_NUM def requires_pre_materialization(func): @@ -121,7 +122,8 @@ def __call__( df: DataFrame, device: Optional[torch.device] = None, ) -> TensorFrame: - r"""Convert a given dataframe into :obj:`TensorFrame`.""" + r"""Convert a given :obj:`DataFrame` object into :obj:`TensorFrame` + object.""" xs_dict: Dict[torch_frame.stype, List[Tensor]] = 
defaultdict(list) @@ -152,8 +154,8 @@ class Dataset(ABC): target_col (str, optional): The column used as target. (default: :obj:`None`) split_col (str, optional): The column that stores the pre-defined split - information. The column should only contain 'train', 'val', or - 'test'. (default: :obj:`None`). + information. The column should only contain :obj:`0`, :obj:`1`, or + :obj:`2`. (default: :obj:`None`). text_embedder_cfg (TextEmbedderConfig, optional): A text embedder config specifying :obj:`text_embedder` that maps sentences into PyTorch embeddings and :obj:`batch_size` that specifies the @@ -179,10 +181,10 @@ def __init__( raise ValueError( f"col_to_stype should not contain the split_col " f"({col_to_stype}).") - if not set(df[split_col]).issubset({'train', 'val', 'test'}): + if not set(df[split_col]).issubset(set(SPLIT_TO_NUM.values())): raise ValueError( - "split_col must only contain either 'train', 'val', or " - "'test'.") + f"split_col must only contain {set(SPLIT_TO_NUM.values())}" + ) self.split_col = split_col self.col_to_stype = col_to_stype @@ -256,6 +258,11 @@ def task_type(self) -> TaskType: else: raise ValueError("Task type cannot be inferred.") + @property + def num_rows(self): + r"""Number of rows.""" + return len(self.df) + @property @requires_post_materialization def num_classes(self) -> int: @@ -414,7 +421,8 @@ def get_split_dataset(self, split: str) -> 'Dataset': if split not in ['train', 'val', 'test']: raise ValueError(f"The split named {split} is not available. 
" f"Needs to either 'train', 'val', or 'test'.") - indices = self.df.index[self.df[self.split_col] == split].tolist() + indices = self.df.index[self.df[self.split_col] == + SPLIT_TO_NUM[split]].tolist() return self[indices] @property diff --git a/torch_frame/datasets/__init__.py b/torch_frame/datasets/__init__.py index 893a4fff..6ac0fe36 100644 --- a/torch_frame/datasets/__init__.py +++ b/torch_frame/datasets/__init__.py @@ -12,6 +12,7 @@ from .dota2 import Dota2 from .kdd_census_income import KDDCensusIncome from .multimodal_text_benchmark import MultimodalTextBenchmark +from .data_frame_benchmark import DataFrameBenchmark real_world_datasets = [ 'Titanic', @@ -25,6 +26,7 @@ 'Yandex', 'KDDCensusIncome', 'MultimodalTextBenchmark', + 'DataFrameBenchmark', ] synthetic_datasets = [ diff --git a/torch_frame/datasets/data_frame_benchmark.py b/torch_frame/datasets/data_frame_benchmark.py new file mode 100644 index 00000000..4b6fe6b8 --- /dev/null +++ b/torch_frame/datasets/data_frame_benchmark.py @@ -0,0 +1,788 @@ +from typing import Any, Dict, List, Tuple + +import pandas as pd + +import torch_frame +from torch_frame.typing import TaskType +from torch_frame.utils import generate_random_split + +SPLIT_COL = 'split' + + +class DataFrameBenchmark(torch_frame.data.Dataset): + r"""A collection of standardized datasets for tabular learning, covering + categorical and numerical features. The datasets are categorized according + to their task types and scales. + + Args: + root (str): Root directory. + task_type (TaskType): The task type. Either + :obj:`TaskType.BINARY_CLASSIFICATION`, + :obj:`TaskType.MULTICLASS_CLASSIFICATION`, or + :obj:`TaskType.REGRESSION` + scale (str): The scale of the dataset. :obj:`small` means 5K to 50K + rows. :obj:`medium` means 50K to 500K rows. :obj:`large` means more + than 500K rows. + idx (int): The integer + + **STATS:** + + .. 
list-table:: + :widths: 20 10 10 10 10 10 20 20 10 + :header-rows: 1 + + * - Task + - Scale + - Idx + - #rows + - #cols (numerical) + - #cols (categorical) + - #classes + - Class object + - Missing value ratio + * - binary_classification + - small + - 0 + - 32,561 + - 4 + - 8 + - 2 + - AdultCensusIncome() + - 0.0% + * - binary_classification + - small + - 1 + - 8,124 + - 0 + - 22 + - 2 + - Mushroom() + - 0.0% + * - binary_classification + - small + - 2 + - 45,211 + - 7 + - 9 + - 2 + - BankMarketing() + - 0.0% + * - binary_classification + - small + - 3 + - 13,376 + - 10 + - 0 + - 2 + - TabularBenchmark(name='MagicTelescope') + - 0.0% + * - binary_classification + - small + - 4 + - 10,578 + - 7 + - 0 + - 2 + - TabularBenchmark(name='bank-marketing') + - 0.0% + * - binary_classification + - small + - 5 + - 20,634 + - 8 + - 0 + - 2 + - TabularBenchmark(name='california') + - 0.0% + * - binary_classification + - small + - 6 + - 16,714 + - 10 + - 0 + - 2 + - TabularBenchmark(name='credit') + - 0.0% + * - binary_classification + - small + - 7 + - 13,272 + - 20 + - 1 + - 2 + - TabularBenchmark(name='default-of-credit-card-clients') + - 0.0% + * - binary_classification + - small + - 8 + - 38,474 + - 7 + - 1 + - 2 + - TabularBenchmark(name='electricity') + - 0.0% + * - binary_classification + - small + - 9 + - 7,608 + - 18 + - 5 + - 2 + - TabularBenchmark(name='eye_movements') + - 0.0% + * - binary_classification + - small + - 10 + - 10,000 + - 22 + - 0 + - 2 + - TabularBenchmark(name='heloc') + - 0.0% + * - binary_classification + - small + - 11 + - 13,488 + - 16 + - 0 + - 2 + - TabularBenchmark(name='house_16H') + - 0.0% + * - binary_classification + - small + - 12 + - 10,082 + - 26 + - 0 + - 2 + - TabularBenchmark(name='pol') + - 0.0% + * - binary_classification + - small + - 13 + - 48,842 + - 6 + - 8 + - 2 + - Yandex(name='adult') + - 0.0% + * - binary_classification + - medium + - 0 + - 92,650 + - 0 + - 116 + - 2 + - Dota2() + - 0.0% + * - binary_classification + - 
medium + - 1 + - 199,523 + - 7 + - 34 + - 2 + - KDDCensusIncome() + - 0.0% + * - binary_classification + - medium + - 2 + - 71,090 + - 7 + - 0 + - 2 + - TabularBenchmark(name='Diabetes130US') + - 0.0% + * - binary_classification + - medium + - 3 + - 72,998 + - 50 + - 0 + - 2 + - TabularBenchmark(name='MiniBooNE') + - 0.0% + * - binary_classification + - medium + - 4 + - 58,252 + - 23 + - 8 + - 2 + - TabularBenchmark(name='albert') + - 0.0% + * - binary_classification + - medium + - 5 + - 423,680 + - 10 + - 44 + - 2 + - TabularBenchmark(name='covertype') + - 0.0% + * - binary_classification + - medium + - 6 + - 57,580 + - 54 + - 0 + - 2 + - TabularBenchmark(name='jannis') + - 0.0% + * - binary_classification + - medium + - 7 + - 111,762 + - 24 + - 8 + - 2 + - TabularBenchmark(name='road-safety') + - 0.0% + * - binary_classification + - medium + - 8 + - 98,050 + - 28 + - 0 + - 2 + - Yandex(name='higgs_small') + - 0.0% + * - binary_classification + - large + - 0 + - 940,160 + - 24 + - 0 + - 2 + - TabularBenchmark(name='Higgs') + - 0.0% + * - multiclass_classification + - medium + - 0 + - 108,000 + - 128 + - 0 + - 1,000 + - Yandex(name='aloi') + - 0.0% + * - multiclass_classification + - medium + - 1 + - 65,196 + - 27 + - 0 + - 100 + - Yandex(name='helena') + - 0.0% + * - multiclass_classification + - medium + - 2 + - 83,733 + - 54 + - 0 + - 4 + - Yandex(name='jannis') + - 0.0% + * - multiclass_classification + - large + - 0 + - 581,012 + - 10 + - 44 + - 7 + - ForestCoverType() + - 0.0% + * - multiclass_classification + - large + - 1 + - 1,025,010 + - 5 + - 5 + - 10 + - PokerHand() + - 0.0% + * - multiclass_classification + - large + - 2 + - 581,012 + - 54 + - 0 + - 7 + - Yandex(name='covtype') + - 0.0% + * - regression + - small + - 0 + - 17,379 + - 6 + - 5 + - 1 + - TabularBenchmark(name='Bike_Sharing_Demand') + - 0.0% + * - regression + - small + - 1 + - 10,692 + - 7 + - 4 + - 1 + - TabularBenchmark(name='Brazilian_houses') + - 0.0% + * - regression + - small + - 2 
+ - 8,192 + - 21 + - 0 + - 1 + - TabularBenchmark(name='cpu_act') + - 0.0% + * - regression + - small + - 3 + - 16,599 + - 16 + - 0 + - 1 + - TabularBenchmark(name='elevators') + - 0.0% + * - regression + - small + - 4 + - 21,613 + - 15 + - 2 + - 1 + - TabularBenchmark(name='house_sales') + - 0.0% + * - regression + - small + - 5 + - 20,640 + - 8 + - 0 + - 1 + - TabularBenchmark(name='houses') + - 0.0% + * - regression + - small + - 6 + - 10,081 + - 6 + - 0 + - 1 + - TabularBenchmark(name='sulfur') + - 0.0% + * - regression + - small + - 7 + - 21,263 + - 79 + - 0 + - 1 + - TabularBenchmark(name='superconduct') + - 0.0% + * - regression + - small + - 8 + - 8,885 + - 252 + - 3 + - 1 + - TabularBenchmark(name='topo_2_1') + - 0.0% + * - regression + - small + - 9 + - 8,641 + - 3 + - 1 + - 1 + - TabularBenchmark(name='visualizing_soil') + - 0.0% + * - regression + - small + - 10 + - 6,497 + - 11 + - 0 + - 1 + - TabularBenchmark(name='wine_quality') + - 0.0% + * - regression + - small + - 11 + - 8,885 + - 42 + - 0 + - 1 + - TabularBenchmark(name='yprop_4_1') + - 0.0% + * - regression + - small + - 12 + - 20,640 + - 8 + - 0 + - 1 + - Yandex(name='california_housing') + - 0.0% + * - regression + - medium + - 0 + - 188,318 + - 25 + - 99 + - 1 + - TabularBenchmark(name='Allstate_Claims_Severity') + - 0.0% + * - regression + - medium + - 1 + - 241,600 + - 3 + - 6 + - 1 + - TabularBenchmark(name='SGEMM_GPU_kernel_performance') + - 0.0% + * - regression + - medium + - 2 + - 53,940 + - 6 + - 3 + - 1 + - TabularBenchmark(name='diamonds') + - 0.0% + * - regression + - medium + - 3 + - 163,065 + - 3 + - 0 + - 1 + - TabularBenchmark(name='medical_charges') + - 0.0% + * - regression + - medium + - 4 + - 394,299 + - 4 + - 2 + - 1 + - TabularBenchmark(name='particulate-matter-ukair-2017') + - 0.0% + * - regression + - medium + - 5 + - 52,031 + - 3 + - 1 + - 1 + - TabularBenchmark(name='seattlecrime6') + - 0.0% + * - regression + - large + - 0 + - 1,000,000 + - 5 + - 0 + - 1 + - 
TabularBenchmark(name='Airlines_DepDelay_1M') + - 0.0% + * - regression + - large + - 1 + - 5,465,575 + - 8 + - 0 + - 1 + - TabularBenchmark(name='delays_zurich_transport') + - 0.0% + * - regression + - large + - 2 + - 581,835 + - 9 + - 0 + - 1 + - TabularBenchmark(name='nyc-taxi-green-dec-2016') + - 0.0% + * - regression + - large + - 3 + - 1,200,192 + - 136 + - 0 + - 1 + - Yandex(name='microsoft') + - 0.0% + * - regression + - large + - 4 + - 709,877 + - 699 + - 0 + - 1 + - Yandex(name='yahoo') + - 0.0% + * - regression + - large + - 5 + - 515,345 + - 90 + - 0 + - 1 + - Yandex(name='year') + - 0.0% + """ + dataset_categorization_dict = { + 'binary_classification': { + 'small': [ + ('AdultCensusIncome', {}), + ('Mushroom', {}), + ('BankMarketing', {}), + ('TabularBenchmark', { + 'name': 'MagicTelescope' + }), + ('TabularBenchmark', { + 'name': 'bank-marketing' + }), + ('TabularBenchmark', { + 'name': 'california' + }), + ('TabularBenchmark', { + 'name': 'credit' + }), + ('TabularBenchmark', { + 'name': 'default-of-credit-card-clients' + }), + ('TabularBenchmark', { + 'name': 'electricity' + }), + ('TabularBenchmark', { + 'name': 'eye_movements' + }), + ('TabularBenchmark', { + 'name': 'heloc' + }), + ('TabularBenchmark', { + 'name': 'house_16H' + }), + ('TabularBenchmark', { + 'name': 'pol' + }), + ('Yandex', { + 'name': 'adult' + }), + ], + 'medium': [ + ('Dota2', {}), + ('KDDCensusIncome', {}), + ('TabularBenchmark', { + 'name': 'Diabetes130US' + }), + ('TabularBenchmark', { + 'name': 'MiniBooNE' + }), + ('TabularBenchmark', { + 'name': 'albert' + }), + ('TabularBenchmark', { + 'name': 'covertype' + }), + ('TabularBenchmark', { + 'name': 'jannis' + }), + ('TabularBenchmark', { + 'name': 'road-safety' + }), + ('Yandex', { + 'name': 'higgs_small' + }), + ], + 'large': [ + ('TabularBenchmark', { + 'name': 'Higgs' + }), + ] + }, + 'multiclass_classification': { + 'small': [], + 'medium': [ + ('Yandex', { + 'name': 'aloi' + }), + ('Yandex', { + 'name': 'helena' + }), 
+ ('Yandex', { + 'name': 'jannis' + }), + ], + 'large': [ + ('ForestCoverType', {}), + ('PokerHand', {}), + ('Yandex', { + 'name': 'covtype' + }), + ] + }, + 'regression': { + 'small': [ + ('TabularBenchmark', { + 'name': 'Bike_Sharing_Demand' + }), + ('TabularBenchmark', { + 'name': 'Brazilian_houses' + }), + ('TabularBenchmark', { + 'name': 'cpu_act' + }), + ('TabularBenchmark', { + 'name': 'elevators' + }), + ('TabularBenchmark', { + 'name': 'house_sales' + }), + ('TabularBenchmark', { + 'name': 'houses' + }), + ('TabularBenchmark', { + 'name': 'sulfur' + }), + ('TabularBenchmark', { + 'name': 'superconduct' + }), + ('TabularBenchmark', { + 'name': 'topo_2_1' + }), + ('TabularBenchmark', { + 'name': 'visualizing_soil' + }), + ('TabularBenchmark', { + 'name': 'wine_quality' + }), + ('TabularBenchmark', { + 'name': 'yprop_4_1' + }), + ('Yandex', { + 'name': 'california_housing' + }), + ], + 'medium': [ + ('TabularBenchmark', { + 'name': 'Allstate_Claims_Severity' + }), + ('TabularBenchmark', { + 'name': 'SGEMM_GPU_kernel_performance' + }), + ('TabularBenchmark', { + 'name': 'diamonds' + }), + ('TabularBenchmark', { + 'name': 'medical_charges' + }), + ('TabularBenchmark', { + 'name': 'particulate-matter-ukair-2017' + }), + ('TabularBenchmark', { + 'name': 'seattlecrime6' + }), + ], + 'large': [ + ('TabularBenchmark', { + 'name': 'Airlines_DepDelay_1M' + }), + ('TabularBenchmark', { + 'name': 'delays_zurich_transport' + }), + ('TabularBenchmark', { + 'name': 'nyc-taxi-green-dec-2016' + }), + ('Yandex', { + 'name': 'microsoft' + }), + ('Yandex', { + 'name': 'yahoo' + }), + ('Yandex', { + 'name': 'year' + }), + ] + } + } + + @classmethod + def datasets_available(cls, task_type: TaskType, + scale: str) -> List[Tuple[str, Dict[str, Any]]]: + return cls.dataset_categorization_dict[task_type.value][scale] + + @classmethod + def num_datasets_available(cls, task_type: TaskType, scale: str): + return len(cls.datasets_available(task_type, scale)) + + def __init__( + self, + 
root: str, + task_type: TaskType, + scale: str, + idx: int, + split_random_state: int = 42, + ): + self.root = root + self._task_type = task_type + self.scale = scale + self.idx = idx + + datasets = self.datasets_available(task_type, scale) + if idx >= len(datasets): + raise ValueError( + f"The idx needs to be smaller than {len(datasets)}, which is " + f"the number of available datasets for task_type: " + f"{task_type.value} and scale: {scale} (got idx: {idx}).") + + class_name, kwargs = self.datasets_available(task_type, scale)[idx] + dataset = getattr(torch_frame.datasets, class_name)(root=root, + **kwargs) + self.cls_str = str(dataset) + + # Add split col + df = dataset.df + if SPLIT_COL in df.columns: + df.drop(columns=[SPLIT_COL], inplace=True) + split_df = pd.DataFrame({ + SPLIT_COL: + generate_random_split(length=len(df), seed=split_random_state, + train_ratio=0.8, val_ratio=0.1) + }) + df = pd.concat([df, split_df], axis=1) + + # check the scale + if dataset.num_rows < 5000: + assert False + elif dataset.num_rows < 50000: + assert scale == "small" + elif dataset.num_rows < 500000: + assert scale == "medium" + else: + assert scale == "large" + + super().__init__(df=df, col_to_stype=dataset.col_to_stype, + target_col=dataset.target_col, split_col=SPLIT_COL) + del dataset + + def __repr__(self) -> str: + return (f'{self.__class__.__name__}(\n' + f' task_type={self._task_type.value},\n' + f' scale={self.scale},\n' + f' idx={self.idx},\n' + f' cls={self.cls_str}\n' + f')') + + def materialize(self): + super().materialize() + if self.task_type != self._task_type: + raise RuntimeError(f"task type does not match. 
It should be " + f"{self.task_type.value} but specified as " + f"{self._task_type.value}.") diff --git a/torch_frame/datasets/fake.py b/torch_frame/datasets/fake.py index dd59109e..5c9a5e41 100644 --- a/torch_frame/datasets/fake.py +++ b/torch_frame/datasets/fake.py @@ -7,6 +7,7 @@ from torch_frame import stype from torch_frame.config.text_embedder import TextEmbedderConfig from torch_frame.typing import TaskType +from torch_frame.utils.split import SPLIT_TO_NUM class FakeDataset(torch_frame.data.Dataset): @@ -83,9 +84,9 @@ def __init__( if num_rows < 3: raise ValueError("Dataframe needs at least 3 rows to include" " each of train, val and test split.") - split = ['train'] * num_rows - split[1] = 'val' - split[2] = 'test' + split = [SPLIT_TO_NUM['train']] * num_rows + split[1] = SPLIT_TO_NUM['val'] + split[2] = SPLIT_TO_NUM['test'] df['split'] = split super().__init__( df, diff --git a/torch_frame/datasets/tabular_benchmark.py b/torch_frame/datasets/tabular_benchmark.py index 5d42ad7f..81fca118 100644 --- a/torch_frame/datasets/tabular_benchmark.py +++ b/torch_frame/datasets/tabular_benchmark.py @@ -1,4 +1,5 @@ import os +from typing import List import pandas as pd from pandas.api.types import is_numeric_dtype @@ -386,7 +387,15 @@ class TabularBenchmark(torch_frame.data.Dataset): # Dedicated URLs for large datasets base_url_large = 'https://huggingface.co/datasets/inria-soda/tabular-benchmark/resolve/main/' # noqa + @classmethod + @property + def name_list(cls) -> List[str]: + r"List of dataset names available." + return sorted(list(cls.name_to_task_category.keys())) + def __init__(self, root: str, name: str): + self.root = root + self.name = name if name not in self.name_to_task_category: raise ValueError( f"The given dataset name ('{name}') is not available. 
It " @@ -423,3 +432,6 @@ def __init__(self, root: str, name: str): else: col_to_stype[col] = torch_frame.categorical super().__init__(df, col_to_stype, target_col=target_col) + + def __repr__(self) -> str: + return (f"{self.__class__.__name__}(name='{self.name}')") diff --git a/torch_frame/datasets/yandex.py b/torch_frame/datasets/yandex.py index ed5c5e37..5cab52ca 100644 --- a/torch_frame/datasets/yandex.py +++ b/torch_frame/datasets/yandex.py @@ -6,6 +6,10 @@ import pandas as pd import torch_frame +from torch_frame.utils.split import SPLIT_TO_NUM + +SPLIT_COL = 'split_col' +TARGET_COL = 'target_col' def load_numpy_dict(path: str) -> Dict[str, np.ndarray]: @@ -87,9 +91,13 @@ def get_df_and_col_to_stype( if numerical_features is not None: for n_col in n_col_names: df[n_col] = df[n_col].astype('float64') - df['label'] = labels - # Stores the split information in "split" column. - df['split'] = split + label_split_df = pd.DataFrame({ + TARGET_COL: + labels, + SPLIT_COL: + np.full((len(df), ), fill_value=SPLIT_TO_NUM[split]) + }) + df = pd.concat([df, label_split_df], axis=1) dataframes.append(df) df = pd.concat(dataframes, ignore_index=True) @@ -193,6 +201,13 @@ class Yandex(torch_frame.data.Dataset): } regression_datasets = {'california_housing', 'microsoft', 'yahoo', 'year'} + @classmethod + @property + def name_list(cls) -> List[str]: + r"List of dataset names available." 
+ return sorted( + list(cls.classification_datasets) + list(cls.regression_datasets)) + def __init__(self, root: str, name: str): assert name in self.classification_datasets | self.regression_datasets self.root = root @@ -201,8 +216,11 @@ def __init__(self, root: str, name: str): root) df, col_to_stype = get_df_and_col_to_stype(path) if name in self.regression_datasets: - col_to_stype['label'] = torch_frame.numerical + col_to_stype[TARGET_COL] = torch_frame.numerical else: - col_to_stype['label'] = torch_frame.categorical - super().__init__(df, col_to_stype, target_col='label', - split_col='split') + col_to_stype[TARGET_COL] = torch_frame.categorical + super().__init__(df, col_to_stype, target_col=TARGET_COL, + split_col=SPLIT_COL) + + def __repr__(self) -> str: + return (f"{self.__class__.__name__}(name='{self.name}')") diff --git a/torch_frame/utils/__init__.py b/torch_frame/utils/__init__.py index ff0d083e..3ab31254 100644 --- a/torch_frame/utils/__init__.py +++ b/torch_frame/utils/__init__.py @@ -1,8 +1,10 @@ from .io import save, load from .concat import cat +from .split import generate_random_split __all__ = functions = [ 'save', 'load', 'cat', + 'generate_random_split', ] diff --git a/torch_frame/utils/concat.py b/torch_frame/utils/concat.py index 683d0cad..3b6df956 100644 --- a/torch_frame/utils/concat.py +++ b/torch_frame/utils/concat.py @@ -5,7 +5,7 @@ from torch import Tensor import torch_frame -from torch_frame import TensorFrame +from torch_frame.data.tensor_frame import TensorFrame def cat(tf_list: List[TensorFrame], along: str) -> TensorFrame: diff --git a/torch_frame/utils/split.py b/torch_frame/utils/split.py new file mode 100644 index 00000000..cd9dd9db --- /dev/null +++ b/torch_frame/utils/split.py @@ -0,0 +1,29 @@ +import numpy as np + +# Mapping split name to integer. 
SPLIT_TO_NUM = {'train': 0, 'val': 1, 'test': 2}


def generate_random_split(length: int, seed: int, train_ratio: float = 0.8,
                          val_ratio: float = 0.1) -> np.ndarray:
    r"""Generate an array of random split assignments of the specified
    length. The elements are either :obj:`0`, :obj:`1`, or :obj:`2`,
    representing train, val, test, respectively (see :obj:`SPLIT_TO_NUM`).

    The shuffle is drawn from a private :class:`numpy.random.RandomState`
    seeded with :obj:`seed`, so calling this function does not reseed or
    advance NumPy's global random state. The produced permutation is the same
    as the one obtained via :obj:`np.random.seed(seed)` followed by
    :obj:`np.random.shuffle`, i.e. the legacy seeded stream.

    Args:
        length (int): Total number of rows to assign to splits.
        seed (int): Seed of the random permutation.
        train_ratio (float, optional): Fraction of rows assigned to train.
            The train count is :obj:`int(length * train_ratio)`.
            (default: :obj:`0.8`)
        val_ratio (float, optional): Fraction of rows assigned to val.
            The val count is :obj:`int(length * val_ratio)`. All remaining
            rows are assigned to test. (default: :obj:`0.1`)

    Returns:
        np.ndarray: A 1-dimensional integer array of size :obj:`length`.
    """
    assert train_ratio + val_ratio < 1, (
        "train_ratio + val_ratio must be smaller than 1")
    assert train_ratio > 0, "train_ratio must be positive"
    assert val_ratio > 0, "val_ratio must be positive"
    train_num = int(length * train_ratio)
    val_num = int(length * val_ratio)
    # Whatever is lost to integer truncation above ends up in the test split.
    test_num = length - train_num - val_num

    arr = np.concatenate([
        np.full(train_num, SPLIT_TO_NUM['train']),
        np.full(val_num, SPLIT_TO_NUM['val']),
        np.full(test_num, SPLIT_TO_NUM['test'])
    ])
    # Shuffle with a local legacy generator rather than `np.random.seed`:
    # seeding the global generator would silently make every later
    # `np.random.*` call in the caller's program deterministic.
    np.random.RandomState(seed).shuffle(arr)

    return arr