Skip to content

Commit

Permalink
Added force_reload option to Dataset and InMemoryDataset [3/n] (#…
Browse files Browse the repository at this point in the history
…8436)

Co-authored-by: rusty1s <[email protected]>
  • Loading branch information
EdisonLeeeee and rusty1s authored Nov 24, 2023
1 parent 5954bb9 commit 9adb8d0
Show file tree
Hide file tree
Showing 53 changed files with 431 additions and 112 deletions.
2 changes: 1 addition & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Added support for device conversions of `InMemoryDataset` ([#8402](https://github.com/pyg-team/pytorch_geometric/pull/8402))
- Added support for edge-level temporal sampling in `NeighborLoader` and `LinkNeighborLoader` ([#8372](https://github.com/pyg-team/pytorch_geometric/pull/8372))
- Added support for `torch.compile` in `ModuleDict` and `ParameterDict` ([#8363](https://github.com/pyg-team/pytorch_geometric/pull/8363))
- Added `force_reload` option to `Dataset` and `InMemoryDataset` to reload datasets ([#8352](https://github.com/pyg-team/pytorch_geometric/pull/8352), [#8357](https://github.com/pyg-team/pytorch_geometric/pull/8357))
- Added `force_reload` option to `Dataset` and `InMemoryDataset` to reload datasets ([#8352](https://github.com/pyg-team/pytorch_geometric/pull/8352), [#8357](https://github.com/pyg-team/pytorch_geometric/pull/8357), [#8436](https://github.com/pyg-team/pytorch_geometric/pull/8436))
- Added support for `torch.compile` in `MultiAggregation` ([#8345](https://github.com/pyg-team/pytorch_geometric/pull/8345))
- Added support for `torch.compile` in `HeteroConv` ([#8344](https://github.com/pyg-team/pytorch_geometric/pull/8344))
- Added support for weighted `sparse_cross_entropy` ([#8340](https://github.com/pyg-team/pytorch_geometric/pull/8340))
Expand Down
6 changes: 5 additions & 1 deletion torch_geometric/datasets/ba2motif_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,8 @@ class BA2MotifDataset(InMemoryDataset):
an :obj:`torch_geometric.data.Data` object and returns a
transformed version. The data object will be transformed before
being saved to disk. (default: :obj:`None`)
force_reload (bool, optional): Whether to re-process the dataset.
(default: :obj:`False`)
**STATS:**
Expand All @@ -83,8 +85,10 @@ def __init__(
root: str,
transform: Optional[Callable] = None,
pre_transform: Optional[Callable] = None,
force_reload: bool = False,
):
super().__init__(root, transform, pre_transform)
super().__init__(root, transform, pre_transform,
force_reload=force_reload)
self.load(self.processed_paths[0])

def raw_file_names(self) -> str:
Expand Down
16 changes: 12 additions & 4 deletions torch_geometric/datasets/bitcoin_otc.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,8 @@ class BitcoinOTC(InMemoryDataset):
an :obj:`torch_geometric.data.Data` object and returns a
transformed version. The data object will be transformed before
being saved to disk. (default: :obj:`None`)
force_reload (bool, optional): Whether to re-process the dataset.
(default: :obj:`False`)
**STATS:**
Expand All @@ -52,11 +54,17 @@ class BitcoinOTC(InMemoryDataset):

url = 'https://snap.stanford.edu/data/soc-sign-bitcoinotc.csv.gz'

def __init__(self, root: str, edge_window_size: int = 10,
transform: Optional[Callable] = None,
pre_transform: Optional[Callable] = None):
def __init__(
self,
root: str,
edge_window_size: int = 10,
transform: Optional[Callable] = None,
pre_transform: Optional[Callable] = None,
force_reload: bool = False,
):
self.edge_window_size = edge_window_size
super().__init__(root, transform, pre_transform)
super().__init__(root, transform, pre_transform,
force_reload=force_reload)
self.load(self.processed_paths[0])

@property
Expand Down
6 changes: 5 additions & 1 deletion torch_geometric/datasets/citation_full.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,8 @@ class CitationFull(InMemoryDataset):
being saved to disk. (default: :obj:`None`)
to_undirected (bool, optional): Whether the original graph is
converted to an undirected one. (default: :obj:`True`)
force_reload (bool, optional): Whether to re-process the dataset.
(default: :obj:`False`)
**STATS:**
Expand Down Expand Up @@ -75,11 +77,13 @@ def __init__(
transform: Optional[Callable] = None,
pre_transform: Optional[Callable] = None,
to_undirected: bool = True,
force_reload: bool = False,
):
self.name = name.lower()
self.to_undirected = to_undirected
assert self.name in ['cora', 'cora_ml', 'citeseer', 'dblp', 'pubmed']
super().__init__(root, transform, pre_transform)
super().__init__(root, transform, pre_transform,
force_reload=force_reload)
self.load(self.processed_paths[0])

@property
Expand Down
6 changes: 5 additions & 1 deletion torch_geometric/datasets/coauthor.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@ class Coauthor(InMemoryDataset):
an :obj:`torch_geometric.data.Data` object and returns a
transformed version. The data object will be transformed before
being saved to disk. (default: :obj:`None`)
force_reload (bool, optional): Whether to re-process the dataset.
(default: :obj:`False`)
**STATS:**
Expand Down Expand Up @@ -57,10 +59,12 @@ def __init__(
name: str,
transform: Optional[Callable] = None,
pre_transform: Optional[Callable] = None,
force_reload: bool = False,
):
assert name.lower() in ['cs', 'physics']
self.name = 'CS' if name.lower() == 'cs' else 'Physics'
super().__init__(root, transform, pre_transform)
super().__init__(root, transform, pre_transform,
force_reload=force_reload)
self.load(self.processed_paths[0])

@property
Expand Down
14 changes: 11 additions & 3 deletions torch_geometric/datasets/dblp.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,8 @@ class DBLP(InMemoryDataset):
an :obj:`torch_geometric.data.HeteroData` object and returns a
transformed version. The data object will be transformed before
being saved to disk. (default: :obj:`None`)
force_reload (bool, optional): Whether to re-process the dataset.
(default: :obj:`False`)
**STATS:**
Expand Down Expand Up @@ -80,9 +82,15 @@ class DBLP(InMemoryDataset):

url = 'https://www.dropbox.com/s/yh4grpeks87ugr2/DBLP_processed.zip?dl=1'

def __init__(self, root: str, transform: Optional[Callable] = None,
pre_transform: Optional[Callable] = None):
super().__init__(root, transform, pre_transform)
def __init__(
self,
root: str,
transform: Optional[Callable] = None,
pre_transform: Optional[Callable] = None,
force_reload: bool = False,
):
super().__init__(root, transform, pre_transform,
force_reload=force_reload)
self.load(self.processed_paths[0], data_cls=HeteroData)

@property
Expand Down
16 changes: 12 additions & 4 deletions torch_geometric/datasets/dbp15k.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,16 +36,24 @@ class DBP15K(InMemoryDataset):
an :obj:`torch_geometric.data.Data` object and returns a
transformed version. The data object will be transformed before
being saved to disk. (default: :obj:`None`)
force_reload (bool, optional): Whether to re-process the dataset.
(default: :obj:`False`)
"""
url = 'https://docs.google.com/uc?export=download&id={}&confirm=t'
file_id = '1ggYlYf2_kTyi7oF9g07oTNn3VDhjl7so'

def __init__(self, root: str, pair: str,
transform: Optional[Callable] = None,
pre_transform: Optional[Callable] = None):
def __init__(
self,
root: str,
pair: str,
transform: Optional[Callable] = None,
pre_transform: Optional[Callable] = None,
force_reload: bool = False,
):
assert pair in ['en_zh', 'en_fr', 'en_ja', 'zh_en', 'fr_en', 'ja_en']
self.pair = pair
super().__init__(root, transform, pre_transform)
super().__init__(root, transform, pre_transform,
force_reload=force_reload)
self.load(self.processed_paths[0])

@property
Expand Down
14 changes: 11 additions & 3 deletions torch_geometric/datasets/deezer_europe.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,13 +24,21 @@ class DeezerEurope(InMemoryDataset):
an :obj:`torch_geometric.data.Data` object and returns a
transformed version. The data object will be transformed before
being saved to disk. (default: :obj:`None`)
force_reload (bool, optional): Whether to re-process the dataset.
(default: :obj:`False`)
"""

url = 'https://graphmining.ai/datasets/ptg/deezer_europe.npz'

def __init__(self, root: str, transform: Optional[Callable] = None,
pre_transform: Optional[Callable] = None):
super().__init__(root, transform, pre_transform)
def __init__(
self,
root: str,
transform: Optional[Callable] = None,
pre_transform: Optional[Callable] = None,
force_reload: bool = False,
):
super().__init__(root, transform, pre_transform,
force_reload=force_reload)
self.load(self.processed_paths[0])

@property
Expand Down
14 changes: 11 additions & 3 deletions torch_geometric/datasets/dgraph.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,8 @@ class DGraphFin(InMemoryDataset):
an :obj:`torch_geometric.data.Data` object and returns a
transformed version. The data object will be transformed before
being saved to disk. (default: :obj:`None`)
force_reload (bool, optional): Whether to re-process the dataset.
(default: :obj:`False`)
**STATS:**
Expand All @@ -51,9 +53,15 @@ class DGraphFin(InMemoryDataset):

url = "https://dgraph.xinye.com"

def __init__(self, root: str, transform: Optional[Callable] = None,
pre_transform: Optional[Callable] = None):
super().__init__(root, transform, pre_transform)
def __init__(
self,
root: str,
transform: Optional[Callable] = None,
pre_transform: Optional[Callable] = None,
force_reload: bool = False,
):
super().__init__(root, transform, pre_transform,
force_reload=force_reload)
self.load(self.processed_paths[0])

def download(self):
Expand Down
14 changes: 11 additions & 3 deletions torch_geometric/datasets/elliptic.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,8 @@ class EllipticBitcoinDataset(InMemoryDataset):
an :obj:`torch_geometric.data.Data` object and returns a
transformed version. The data object will be transformed before
being saved to disk. (default: :obj:`None`)
force_reload (bool, optional): Whether to re-process the dataset.
(default: :obj:`False`)
**STATS:**
Expand All @@ -55,9 +57,15 @@ class EllipticBitcoinDataset(InMemoryDataset):
"""
url = 'https://data.pyg.org/datasets/elliptic'

def __init__(self, root: str, transform: Optional[Callable] = None,
pre_transform: Optional[Callable] = None):
super().__init__(root, transform, pre_transform)
def __init__(
self,
root: str,
transform: Optional[Callable] = None,
pre_transform: Optional[Callable] = None,
force_reload: bool = False,
):
super().__init__(root, transform, pre_transform,
force_reload=force_reload)
self.load(self.processed_paths[0])

@property
Expand Down
6 changes: 5 additions & 1 deletion torch_geometric/datasets/elliptic_temporal.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,8 @@ class EllipticBitcoinTemporalDataset(EllipticBitcoinDataset):
an :obj:`torch_geometric.data.Data` object and returns a
transformed version. The data object will be transformed before
being saved to disk. (default: :obj:`None`)
force_reload (bool, optional): Whether to re-process the dataset.
(default: :obj:`False`)
**STATS:**
Expand All @@ -58,12 +60,14 @@ def __init__(
t: int,
transform: Optional[Callable] = None,
pre_transform: Optional[Callable] = None,
force_reload: bool = False,
):
if t < 1 or t > 49:
raise ValueError("'t' needs to be between 1 and 49")

self.t = t
super().__init__(root, transform, pre_transform)
super().__init__(root, transform, pre_transform,
force_reload=force_reload)

@property
def processed_file_names(self) -> str:
Expand Down
14 changes: 11 additions & 3 deletions torch_geometric/datasets/email_eu_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,16 +30,24 @@ class EmailEUCore(InMemoryDataset):
an :obj:`torch_geometric.data.Data` object and returns a
transformed version. The data object will be transformed before
being saved to disk. (default: :obj:`None`)
force_reload (bool, optional): Whether to re-process the dataset.
(default: :obj:`False`)
"""

urls = [
'https://snap.stanford.edu/data/email-Eu-core.txt.gz',
'https://snap.stanford.edu/data/email-Eu-core-department-labels.txt.gz'
]

def __init__(self, root: str, transform: Optional[Callable] = None,
pre_transform: Optional[Callable] = None):
super().__init__(root, transform, pre_transform)
def __init__(
self,
root: str,
transform: Optional[Callable] = None,
pre_transform: Optional[Callable] = None,
force_reload: bool = False,
):
super().__init__(root, transform, pre_transform,
force_reload=force_reload)
self.load(self.processed_paths[0])

@property
Expand Down
17 changes: 13 additions & 4 deletions torch_geometric/datasets/entities.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,8 @@ class Entities(InMemoryDataset):
an :obj:`torch_geometric.data.Data` object and returns a
transformed version. The data object will be transformed before
being saved to disk. (default: :obj:`None`)
force_reload (bool, optional): Whether to re-process the dataset.
(default: :obj:`False`)
**STATS:**
Expand Down Expand Up @@ -73,13 +75,20 @@ class Entities(InMemoryDataset):

url = 'https://data.dgl.ai/dataset/{}.tgz'

def __init__(self, root: str, name: str, hetero: bool = False,
transform: Optional[Callable] = None,
pre_transform: Optional[Callable] = None):
def __init__(
self,
root: str,
name: str,
hetero: bool = False,
transform: Optional[Callable] = None,
pre_transform: Optional[Callable] = None,
force_reload: bool = False,
):
self.name = name.lower()
self.hetero = hetero
assert self.name in ['aifb', 'am', 'mutag', 'bgs']
super().__init__(root, transform, pre_transform)
super().__init__(root, transform, pre_transform,
force_reload=force_reload)
self.load(
self.processed_paths[0],
data_cls=HeteroData if hetero else Data,
Expand Down
14 changes: 11 additions & 3 deletions torch_geometric/datasets/facebook.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,13 +23,21 @@ class FacebookPagePage(InMemoryDataset):
an :obj:`torch_geometric.data.Data` object and returns a
transformed version. The data object will be transformed before
being saved to disk. (default: :obj:`None`)
force_reload (bool, optional): Whether to re-process the dataset.
(default: :obj:`False`)
"""

url = 'https://graphmining.ai/datasets/ptg/facebook.npz'

def __init__(self, root: str, transform: Optional[Callable] = None,
pre_transform: Optional[Callable] = None):
super().__init__(root, transform, pre_transform)
def __init__(
self,
root: str,
transform: Optional[Callable] = None,
pre_transform: Optional[Callable] = None,
force_reload: bool = False,
):
super().__init__(root, transform, pre_transform,
force_reload=force_reload)
self.load(self.processed_paths[0])

@property
Expand Down
14 changes: 11 additions & 3 deletions torch_geometric/datasets/flickr.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@ class Flickr(InMemoryDataset):
an :obj:`torch_geometric.data.Data` object and returns a
transformed version. The data object will be transformed before
being saved to disk. (default: :obj:`None`)
force_reload (bool, optional): Whether to re-process the dataset.
(default: :obj:`False`)
**STATS:**
Expand All @@ -48,9 +50,15 @@ class Flickr(InMemoryDataset):
class_map_id = '1uxIkbtg5drHTsKt-PAsZZ4_yJmgFmle9'
role_id = '1htXCtuktuCW8TR8KiKfrFDAxUgekQoV7'

def __init__(self, root: str, transform: Optional[Callable] = None,
pre_transform: Optional[Callable] = None):
super().__init__(root, transform, pre_transform)
def __init__(
self,
root: str,
transform: Optional[Callable] = None,
pre_transform: Optional[Callable] = None,
force_reload: bool = False,
):
super().__init__(root, transform, pre_transform,
force_reload=force_reload)
self.load(self.processed_paths[0])

@property
Expand Down
Loading

0 comments on commit 9adb8d0

Please sign in to comment.