Skip to content

Commit

Permalink
Update the multimodal text data (#117)
Browse files Browse the repository at this point in the history
- Added multimodal data stats information, some datasets' performance,
and fixed the dataset (split, NAs in the test set)
- Notice that the stats now is only for datasets that only have
categorical, numerical and text data (those having temporal and
multi-categorical data are not included)

![image](https://github.com/pyg-team/pytorch-frame/assets/21955420/423d1cb2-1de0-4ba6-9ff4-1c70f22cafb9)
  • Loading branch information
zechengz authored Oct 16, 2023
1 parent 76cfa75 commit 759ce26
Show file tree
Hide file tree
Showing 3 changed files with 113 additions and 11 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Added

- Fixed the `MultimodalTextBenchmark`. ([#117](https://github.com/pyg-team/pytorch-frame/pull/117))
- Added `DataFrameBenchmark` ([#107](https://github.com/pyg-team/pytorch-frame/pull/107)).
- Added stats to datasets documentation ([#101](https://github.com/pyg-team/pytorch-frame/pull/101)).
- Add concat and equal ops for `TensorFrame` ([#100](https://github.com/pyg-team/pytorch-frame/pull/100)).
Expand Down
20 changes: 17 additions & 3 deletions examples/fttransformer_text.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,12 @@
# Text embedded:
# ============== wine_reviews ===============
# Best Val Acc: 0.7946, Best Test Acc: 0.7878
# ===== product_sentiment_machine_hack ======
# Best Val Acc: 0.9334, Best Test Acc: 0.8814
# ========== data_scientist_salary ==========
# Best Val Acc: 0.5355, Best Test Acc: 0.4582
# ======== jigsaw_unintended_bias100K =======
# Best Val Acc: 0.9543, Best Test Acc: 0.9511


class PretrainedTextEncoder:
Expand Down Expand Up @@ -63,9 +69,12 @@ def __call__(self, sentences: List[str]) -> Tensor:

is_classification = dataset.task_type.is_classification

train_dataset = dataset.get_split_dataset('train')[:0.9]
val_dataset = dataset.get_split_dataset('train')[0.9:]
train_dataset = dataset.get_split_dataset('train')
val_dataset = dataset.get_split_dataset('val')
test_dataset = dataset.get_split_dataset('test')
if val_dataset.tensor_frame.num_rows == 0:
train_dataset = dataset.get_split_dataset('train')[:0.9]
val_dataset = dataset.get_split_dataset('train')[0.9:]

# Set up data loaders
train_tensor_frame = train_dataset.tensor_frame.to(device)
Expand All @@ -82,9 +91,14 @@ def __call__(self, sentences: List[str]) -> Tensor:
stype.text_embedded: LinearEmbeddingEncoder(in_channels=768)
}

if is_classification:
output_channels = dataset.num_classes
else:
output_channels = 1

model = FTTransformer(
channels=args.channels,
out_channels=dataset.num_classes,
out_channels=output_channels,
num_layers=args.num_layers,
col_stats=dataset.col_stats,
col_names_dict=train_tensor_frame.col_names_dict,
Expand Down
103 changes: 95 additions & 8 deletions torch_frame/datasets/multimodal_text_benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

import torch_frame
from torch_frame.config.text_embedder import TextEmbedderConfig
from torch_frame.utils.split import SPLIT_TO_NUM


class MultimodalTextBenchmark(torch_frame.data.Dataset):
Expand All @@ -20,8 +21,6 @@ class MultimodalTextBenchmark(torch_frame.data.Dataset):
**STATS:**
(TODO) To be added by zecheng.
.. list-table::
:widths: 20 10 10 10 10 10 20 10
:header-rows: 1
Expand All @@ -34,6 +33,94 @@ class MultimodalTextBenchmark(torch_frame.data.Dataset):
- #classes
- Task
- Missing value ratio
* - product_sentiment_machine_hack
- 6,364
- 0
- 1
- 1
- 4
- multiclass_classification
- 0.0%
* - jigsaw_unintended_bias100K
- 125,000
- 29
- 0
- 1
- 2
- binary_classification
- 41.4%
* - news_channel
- 25,355
- 14
- 0
- 1
- 6
- multiclass_classification
- 0.0%
* - wine_reviews
- 105,154
- 2
- 2
- 1
- 30
- multiclass_classification
- 1.0%
* - fake_job_postings2
- 15,907
- 0
- 3
- 2
- 2
- binary_classification
- 23.8%
* - google_qa_answer_type_reason_explanation
- 6,079
- 0
- 1
- 3
- 1
- regression
- 0.0%
* - google_qa_question_type_reason_explanation
- 6,079
- 0
- 1
- 3
- 1
- regression
- 0.0%
* - bookprice_prediction
- 6,237
- 2
- 3
- 3
- 1
- regression
- 1.7%
* - jc_penney_products
- 13,575
- 2
- 1
- 2
- 1
- regression
- 13.7%
* - women_clothing_review
- 23,486
- 1
- 3
- 2
- 1
- regression
- 1.8%
* - news_popularity2
- 30,009
- 3
- 0
- 1
- 1
- regression
- 0.0%
"""
base_url = 'https://automl-mm-bench.s3.amazonaws.com'

Expand Down Expand Up @@ -106,18 +193,18 @@ class MultimodalTextBenchmark(torch_frame.data.Dataset):
}

_dataset_splits = {
'product_sentiment_machine_hack': ['train', 'dev', 'test'],
'data_scientist_salary': ['train', 'test', 'competition'],
'product_sentiment_machine_hack': ['train', 'dev'],
'data_scientist_salary': ['train', 'test'],
'melbourne_airbnb': ['train', 'test'],
'news_channel': ['train', 'test'],
'wine_reviews': ['train', 'test'],
'imdb_genre_prediction': ['train', 'test'],
'fake_job_postings2': ['train', 'test'],
'kick_starter_funding': ['train', 'test'],
'jigsaw_unintended_bias100K': ['train', 'test'],
'google_qa_answer_type_reason_explanation': ['train', 'dev', 'test'],
'google_qa_question_type_reason_explanation': ['train', 'dev', 'test'],
'bookprice_prediction': ['train', 'test', 'competition'],
'google_qa_answer_type_reason_explanation': ['train', 'dev'],
'google_qa_question_type_reason_explanation': ['train', 'dev'],
'bookprice_prediction': ['train', 'test'],
'jc_penney_products': ['train', 'test'],
'women_clothing_review': ['train', 'test'],
'ae_price_prediction': ['train', 'test'],
Expand Down Expand Up @@ -365,7 +452,7 @@ def __init__(self, root: str, name: str,
splits = ['train', 'val', 'test'] if len(
self._dataset_splits[self.name]) == 3 else ['train', 'test']
for split_df, split in zip(dfs, splits):
split_df['split'] = split
split_df['split'] = SPLIT_TO_NUM[split]

df = pd.concat(dfs, ignore_index=True)

Expand Down

0 comments on commit 759ce26

Please sign in to comment.