Add DataFrameBenchmark (#107)
For standardized benchmarking.
Categorizes the existing datasets into 9 categories, combining three task types
(binary classification, multi-class classification, regression) with three
scales (small, medium, large).

Usage is as follows:
```python
dataset = DataFrameBenchmark(root, task_type=TaskType.BINARY_CLASSIFICATION,
                             scale='medium', idx=2)
# Get fixed split
train_dataset = dataset.get_split_dataset('train')
val_dataset = dataset.get_split_dataset('val')
test_dataset = dataset.get_split_dataset('test')
```
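
The datasets behind each `(task_type, scale)` category can also be listed programmatically via the `datasets_available` method exercised in the test below. A minimal sketch (the returned name/kwargs pairs are assumed from the test assertions):

```python
from torch_frame.datasets import DataFrameBenchmark
from torch_frame.typing import TaskType

# List the (dataset class name, keyword arguments) pairs registered for a
# given task type and scale; `idx` in the constructor indexes into this list.
datasets = DataFrameBenchmark.datasets_available(
    task_type=TaskType.BINARY_CLASSIFICATION, scale='small')
print(datasets[1])  # ('Mushroom', {}) per the test assertions below
```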

Dataset documentation is available
[here](https://pyg-team-pytorch-frame--107.com.readthedocs.build/en/107/generated/torch_frame.datasets.DataFrameBenchmark.html#torch_frame.datasets.DataFrameBenchmark).
weihua916 authored Oct 12, 2023
1 parent e2da6e5 commit 55ae907
Showing 12 changed files with 1,034 additions and 18 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -7,6 +7,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Added

- Added `DataFrameBenchmark` ([#107](https://github.com/pyg-team/pytorch-frame/pull/107)).
- Added stats to datasets documentation ([#101](https://github.com/pyg-team/pytorch-frame/pull/101)).
- Add concat and equal ops for `TensorFrame` ([#100](https://github.com/pyg-team/pytorch-frame/pull/100)).
- Use ROC-AUC for binary classification in GBDT ([#98](https://github.com/pyg-team/pytorch-frame/pull/98)).
134 changes: 134 additions & 0 deletions test/datasets/test_data_frame_benchmark.py
@@ -0,0 +1,134 @@
import pytest

from torch_frame.datasets import DataFrameBenchmark
from torch_frame.typing import TaskType


@pytest.mark.parametrize('scale', ["small", "medium", "large"])
@pytest.mark.parametrize('task_type', [
TaskType.BINARY_CLASSIFICATION, TaskType.MULTICLASS_CLASSIFICATION,
TaskType.REGRESSION
])
def test_data_frame_benchmark_match(task_type, scale):
    # Make sure the (task_type, scale, idx) triple maps to a fixed underlying
    # dataset. New datasets can be appended, but the existing mapping needs to
    # be preserved.
datasets = DataFrameBenchmark.datasets_available(task_type=task_type,
scale=scale)
if task_type == TaskType.BINARY_CLASSIFICATION:
if scale == 'small':
assert datasets[0] == ('AdultCensusIncome', {})
assert datasets[1] == ('Mushroom', {})
assert datasets[2] == ('BankMarketing', {})
assert datasets[3] == ('TabularBenchmark', {
'name': 'MagicTelescope'
})
assert datasets[4] == ('TabularBenchmark', {
'name': 'bank-marketing'
})
assert datasets[5] == ('TabularBenchmark', {'name': 'california'})
assert datasets[6] == ('TabularBenchmark', {'name': 'credit'})
assert datasets[7] == ('TabularBenchmark', {
'name': 'default-of-credit-card-clients'
})
assert datasets[8] == ('TabularBenchmark', {'name': 'electricity'})
assert datasets[9] == ('TabularBenchmark', {
'name': 'eye_movements'
})
assert datasets[10] == ('TabularBenchmark', {'name': 'heloc'})
assert datasets[11] == ('TabularBenchmark', {'name': 'house_16H'})
assert datasets[12] == ('TabularBenchmark', {'name': 'pol'})
assert datasets[13] == ('Yandex', {'name': 'adult'})
elif scale == 'medium':
assert datasets[0] == ('Dota2', {})
assert datasets[1] == ('KDDCensusIncome', {})
assert datasets[2] == ('TabularBenchmark', {
'name': 'Diabetes130US'
})
assert datasets[3] == ('TabularBenchmark', {'name': 'MiniBooNE'})
assert datasets[4] == ('TabularBenchmark', {'name': 'albert'})
assert datasets[5] == ('TabularBenchmark', {'name': 'covertype'})
assert datasets[6] == ('TabularBenchmark', {'name': 'jannis'})
assert datasets[7] == ('TabularBenchmark', {'name': 'road-safety'})
assert datasets[8] == ('Yandex', {'name': 'higgs_small'})
elif scale == 'large':
assert datasets[0] == ('TabularBenchmark', {'name': 'Higgs'})
elif task_type == TaskType.MULTICLASS_CLASSIFICATION:
if scale == 'small':
assert len(datasets) == 0
elif scale == 'medium':
assert datasets[0] == ('Yandex', {'name': 'aloi'})
assert datasets[1] == ('Yandex', {'name': 'helena'})
assert datasets[2] == ('Yandex', {'name': 'jannis'})
elif scale == 'large':
assert datasets[0] == ('ForestCoverType', {})
assert datasets[1] == ('PokerHand', {})
assert datasets[2] == ('Yandex', {'name': 'covtype'})
elif task_type == TaskType.REGRESSION:
if scale == 'small':
assert datasets[0] == ('TabularBenchmark', {
'name': 'Bike_Sharing_Demand'
})
assert datasets[1] == ('TabularBenchmark', {
'name': 'Brazilian_houses'
})
assert datasets[2] == ('TabularBenchmark', {'name': 'cpu_act'})
assert datasets[3] == ('TabularBenchmark', {'name': 'elevators'})
assert datasets[4] == ('TabularBenchmark', {'name': 'house_sales'})
assert datasets[5] == ('TabularBenchmark', {'name': 'houses'})
assert datasets[6] == ('TabularBenchmark', {'name': 'sulfur'})
assert datasets[7] == ('TabularBenchmark', {
'name': 'superconduct'
})
assert datasets[8] == ('TabularBenchmark', {'name': 'topo_2_1'})
assert datasets[9] == ('TabularBenchmark', {
'name': 'visualizing_soil'
})
assert datasets[10] == ('TabularBenchmark', {
'name': 'wine_quality'
})
assert datasets[11] == ('TabularBenchmark', {'name': 'yprop_4_1'})
assert datasets[12] == ('Yandex', {'name': 'california_housing'})
elif scale == 'medium':
assert datasets[0] == ('TabularBenchmark', {
'name': 'Allstate_Claims_Severity'
})
assert datasets[1] == ('TabularBenchmark', {
'name': 'SGEMM_GPU_kernel_performance'
})
assert datasets[2] == ('TabularBenchmark', {'name': 'diamonds'})
assert datasets[3] == ('TabularBenchmark', {
'name': 'medical_charges'
})
assert datasets[4] == ('TabularBenchmark', {
'name': 'particulate-matter-ukair-2017'
})
assert datasets[5] == ('TabularBenchmark', {
'name': 'seattlecrime6'
})
elif scale == 'large':
assert datasets[0] == ('TabularBenchmark', {
'name': 'Airlines_DepDelay_1M'
})
assert datasets[1] == ('TabularBenchmark', {
'name': 'delays_zurich_transport'
})
assert datasets[2] == ('TabularBenchmark', {
'name': 'nyc-taxi-green-dec-2016'
})
assert datasets[3] == ('Yandex', {'name': 'microsoft'})
assert datasets[4] == ('Yandex', {'name': 'yahoo'})
assert datasets[5] == ('Yandex', {'name': 'year'})


def test_data_frame_benchmark_object(tmp_path):
dataset = DataFrameBenchmark(tmp_path, TaskType.BINARY_CLASSIFICATION,
'small', 1)
assert str(dataset) == ("DataFrameBenchmark(\n"
" task_type=binary_classification,\n"
" scale=small,\n"
" idx=1,\n"
" cls=Mushroom()\n"
")")
assert dataset.num_rows == 8124
dataset.materialize()
21 changes: 21 additions & 0 deletions test/utils/test_split.py
@@ -0,0 +1,21 @@
import numpy as np

from torch_frame.utils.split import SPLIT_TO_NUM, generate_random_split


def test_generate_random_split():
num_data = 20
train_ratio = 0.8
val_ratio = 0.1
test_ratio = 0.1

split = generate_random_split(num_data, seed=42, train_ratio=train_ratio,
val_ratio=val_ratio)
assert (split == SPLIT_TO_NUM['train']).sum() == int(num_data *
train_ratio)
assert (split == SPLIT_TO_NUM['val']).sum() == int(num_data * val_ratio)
assert (split == SPLIT_TO_NUM['test']).sum() == int(num_data * test_ratio)
assert np.allclose(
split,
np.array([0, 1, 0, 0, 0, 0, 0, 0, 2, 1, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0]),
)
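
For reference, `SPLIT_TO_NUM` encodes split names as integers, which is what the `split_col` change in `dataset.py` below checks against. A presumed sketch of the mapping, inferred from the test above (16 zeros, 2 ones, and 2 twos for a 0.8/0.1/0.1 split of 20 rows); the authoritative definition lives in `torch_frame/utils/split.py`:

```python
# Presumed mapping, inferred from the test counts above rather than copied
# from the source: 'train' -> 0, 'val' -> 1, 'test' -> 2.
SPLIT_TO_NUM = {'train': 0, 'val': 1, 'test': 2}
```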
22 changes: 15 additions & 7 deletions torch_frame/data/dataset.py
@@ -25,6 +25,7 @@
IndexSelectType,
TaskType,
)
from torch_frame.utils.split import SPLIT_TO_NUM


def requires_pre_materialization(func):
@@ -121,7 +122,8 @@ def __call__(
df: DataFrame,
device: Optional[torch.device] = None,
) -> TensorFrame:
r"""Convert a given dataframe into :obj:`TensorFrame`."""
r"""Convert a given :obj:`DataFrame` object into :obj:`TensorFrame`
object."""

xs_dict: Dict[torch_frame.stype, List[Tensor]] = defaultdict(list)

@@ -152,8 +154,8 @@ class Dataset(ABC):
target_col (str, optional): The column used as target.
(default: :obj:`None`)
split_col (str, optional): The column that stores the pre-defined split
information. The column should only contain 'train', 'val', or
'test'. (default: :obj:`None`).
information. The column should only contain :obj:`0`, :obj:`1`, or
:obj:`2`. (default: :obj:`None`).
text_embedder_cfg (TextEmbedderConfig, optional): A text embedder
config specifying :obj:`text_embedder` that maps sentences into
PyTorch embeddings and :obj:`batch_size` that specifies the
@@ -179,10 +181,10 @@ def __init__(
raise ValueError(
f"col_to_stype should not contain the split_col "
f"({col_to_stype}).")
if not set(df[split_col]).issubset({'train', 'val', 'test'}):
if not set(df[split_col]).issubset(set(SPLIT_TO_NUM.values())):
raise ValueError(
"split_col must only contain either 'train', 'val', or "
"'test'.")
f"split_col must only contain {set(SPLIT_TO_NUM.values())}"
)
self.split_col = split_col
self.col_to_stype = col_to_stype

@@ -256,6 +258,11 @@ def task_type(self) -> TaskType:
else:
raise ValueError("Task type cannot be inferred.")

@property
def num_rows(self):
r"""Number of rows."""
return len(self.df)

@property
@requires_post_materialization
def num_classes(self) -> int:
@@ -414,7 +421,8 @@ def get_split_dataset(self, split: str) -> 'Dataset':
if split not in ['train', 'val', 'test']:
raise ValueError(f"The split named {split} is not available. "
f"Needs to either 'train', 'val', or 'test'.")
indices = self.df.index[self.df[self.split_col] == split].tolist()
indices = self.df.index[self.df[self.split_col] ==
SPLIT_TO_NUM[split]].tolist()
return self[indices]

@property
2 changes: 2 additions & 0 deletions torch_frame/datasets/__init__.py
@@ -12,6 +12,7 @@
from .dota2 import Dota2
from .kdd_census_income import KDDCensusIncome
from .multimodal_text_benchmark import MultimodalTextBenchmark
from .data_frame_benchmark import DataFrameBenchmark

real_world_datasets = [
'Titanic',
@@ -25,6 +26,7 @@
'Yandex',
'KDDCensusIncome',
'MultimodalTextBenchmark',
'DataFrameBenchmark',
]

synthetic_datasets = [