diff --git a/CHANGELOG.md b/CHANGELOG.md index 2ff3d2f7..43c30dfb 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,6 +7,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). ### Added +- Fixed the `MultimodalTextBenchmark`. ([#117](https://github.com/pyg-team/pytorch-frame/pull/117)) - Added `DataFrameBenchmark` ([#107](https://github.com/pyg-team/pytorch-frame/pull/107)). - Added stats to datasets documentation ([#101](https://github.com/pyg-team/pytorch-frame/pull/101)). - Add concat and equal ops for `TensorFrame` ([#100](https://github.com/pyg-team/pytorch-frame/pull/100)). diff --git a/examples/fttransformer_text.py b/examples/fttransformer_text.py index 2b3e5642..b13ef4fb 100644 --- a/examples/fttransformer_text.py +++ b/examples/fttransformer_text.py @@ -22,6 +22,12 @@ # Text embedded: # ============== wine_reviews =============== # Best Val Acc: 0.7946, Best Test Acc: 0.7878 +# ===== product_sentiment_machine_hack ====== +# Best Val Acc: 0.9334, Best Test Acc: 0.8814 +# ========== data_scientist_salary ========== +# Best Val Acc: 0.5355, Best Test Acc: 0.4582 +# ======== jigsaw_unintended_bias100K ======= +# Best Val Acc: 0.9543, Best Test Acc: 0.9511 class PretrainedTextEncoder: @@ -63,9 +69,12 @@ def __call__(self, sentences: List[str]) -> Tensor: is_classification = dataset.task_type.is_classification -train_dataset = dataset.get_split_dataset('train')[:0.9] -val_dataset = dataset.get_split_dataset('train')[0.9:] +train_dataset = dataset.get_split_dataset('train') +val_dataset = dataset.get_split_dataset('val') test_dataset = dataset.get_split_dataset('test') +if val_dataset.tensor_frame.num_rows == 0: + train_dataset = dataset.get_split_dataset('train')[:0.9] + val_dataset = dataset.get_split_dataset('train')[0.9:] # Set up data loaders train_tensor_frame = train_dataset.tensor_frame.to(device) @@ -82,9 +91,14 @@ def __call__(self, sentences: List[str]) -> Tensor: stype.text_embedded: LinearEmbeddingEncoder(in_channels=768) } +if is_classification: + output_channels = dataset.num_classes +else: + output_channels = 1 + model = FTTransformer( channels=args.channels, - out_channels=dataset.num_classes, + out_channels=output_channels, num_layers=args.num_layers, col_stats=dataset.col_stats, col_names_dict=train_tensor_frame.col_names_dict, diff --git a/torch_frame/datasets/multimodal_text_benchmark.py b/torch_frame/datasets/multimodal_text_benchmark.py index 6e582357..a664396b 100644 --- a/torch_frame/datasets/multimodal_text_benchmark.py +++ b/torch_frame/datasets/multimodal_text_benchmark.py @@ -6,6 +6,7 @@ import torch_frame from torch_frame.config.text_embedder import TextEmbedderConfig +from torch_frame.utils.split import SPLIT_TO_NUM class MultimodalTextBenchmark(torch_frame.data.Dataset): @@ -20,8 +21,6 @@ class MultimodalTextBenchmark(torch_frame.data.Dataset): **STATS:** - (TODO) To be added by zecheng. - .. list-table:: :widths: 20 10 10 10 10 10 20 10 :header-rows: 1 @@ -34,6 +33,94 @@ class MultimodalTextBenchmark(torch_frame.data.Dataset): - #classes - Task - Missing value ratio + * - product_sentiment_machine_hack + - 6,364 + - 0 + - 1 + - 1 + - 4 + - multiclass_classification + - 0.0% + * - jigsaw_unintended_bias100K + - 125,000 + - 29 + - 0 + - 1 + - 2 + - binary_classification + - 41.4% + * - news_channel + - 25,355 + - 14 + - 0 + - 1 + - 6 + - multiclass_classification + - 0.0% + * - wine_reviews + - 105,154 + - 2 + - 2 + - 1 + - 30 + - multiclass_classification + - 1.0% + * - fake_job_postings2 + - 15,907 + - 0 + - 3 + - 2 + - 2 + - binary_classification + - 23.8% + * - google_qa_answer_type_reason_explanation + - 6,079 + - 0 + - 1 + - 3 + - 1 + - regression + - 0.0% + * - google_qa_question_type_reason_explanation + - 6,079 + - 0 + - 1 + - 3 + - 1 + - regression + - 0.0% + * - bookprice_prediction + - 6,237 + - 2 + - 3 + - 3 + - 1 + - regression + - 1.7% + * - jc_penney_products + - 13,575 + - 2 + - 1 + - 2 + - 1 + - regression + - 13.7% + * - women_clothing_review + - 23,486 + - 1 + - 3 + - 2 + - 1 + - regression + - 1.8% + * - news_popularity2 + - 30,009 + - 3 + - 0 + - 1 + - 1 + - regression + - 0.0% """ base_url = 'https://automl-mm-bench.s3.amazonaws.com' @@ -106,8 +193,8 @@ class MultimodalTextBenchmark(torch_frame.data.Dataset): } _dataset_splits = { - 'product_sentiment_machine_hack': ['train', 'dev', 'test'], - 'data_scientist_salary': ['train', 'test', 'competition'], + 'product_sentiment_machine_hack': ['train', 'dev'], + 'data_scientist_salary': ['train', 'test'], 'melbourne_airbnb': ['train', 'test'], 'news_channel': ['train', 'test'], 'wine_reviews': ['train', 'test'], @@ -115,9 +202,9 @@ class MultimodalTextBenchmark(torch_frame.data.Dataset): 'fake_job_postings2': ['train', 'test'], 'kick_starter_funding': ['train', 'test'], 'jigsaw_unintended_bias100K': ['train', 'test'], - 'google_qa_answer_type_reason_explanation': ['train', 'dev', 'test'], - 'google_qa_question_type_reason_explanation': ['train', 'dev', 'test'], - 'bookprice_prediction': ['train', 'test', 'competition'], + 'google_qa_answer_type_reason_explanation': ['train', 'dev'], + 'google_qa_question_type_reason_explanation': ['train', 'dev'], + 'bookprice_prediction': ['train', 'test'], 'jc_penney_products': ['train', 'test'], 'women_clothing_review': ['train', 'test'], 'ae_price_prediction': ['train', 'test'], @@ -365,7 +452,7 @@ def __init__(self, root: str, name: str, splits = ['train', 'val', 'test'] if len( self._dataset_splits[self.name]) == 3 else ['train', 'test'] for split_df, split in zip(dfs, splits): - split_df['split'] = split + split_df['split'] = SPLIT_TO_NUM[split] df = pd.concat(dfs, ignore_index=True)