Skip to content

Commit

Permalink
[pre-commit.ci] auto fixes from pre-commit.com hooks
Browse files Browse the repository at this point in the history
for more information, see https://pre-commit.ci
  • Loading branch information
pre-commit-ci[bot] committed Jan 11, 2025
1 parent 5ba0010 commit b308b85
Showing 1 changed file with 22 additions and 30 deletions.
52 changes: 22 additions & 30 deletions test/data/lightning/test_datamodule.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,12 @@
MyFeatureStore,
MyGraphStore,
get_random_edge_index,
has_package,
onlyCUDA,
onlyFullTest,
onlyNeighborSampler,
onlyOnline,
withPackage,
has_package,
)

try:
Expand Down Expand Up @@ -116,19 +116,16 @@ def expect_rank_zero_user_warning(match: str):
assert 'shuffle' not in datamodule.kwargs
old_x = train_dataset._data.x.clone()
new_datamodule_repr = has_package('pytorch_lightning>=2.5.0')
datamodule_repr = (
'{Train dataloader: size=50}\n'
'{Validation dataloader: size=30}\n'
'{Test dataloader: size=10}\n'
'{Predict dataloader: size=98}'
if new_datamodule_repr else
'LightningDataset(train_dataset=MUTAG(50), '
'val_dataset=MUTAG(30), '
'test_dataset=MUTAG(10), '
'pred_dataset=MUTAG(98), batch_size=5, '
'num_workers=3, pin_memory=True, '
'persistent_workers=True)'
)
datamodule_repr = ('{Train dataloader: size=50}\n'
'{Validation dataloader: size=30}\n'
'{Test dataloader: size=10}\n'
'{Predict dataloader: size=98}' if new_datamodule_repr
else 'LightningDataset(train_dataset=MUTAG(50), '
'val_dataset=MUTAG(30), '
'test_dataset=MUTAG(10), '
'pred_dataset=MUTAG(98), batch_size=5, '
'num_workers=3, pin_memory=True, '
'persistent_workers=True)')
assert str(datamodule) == datamodule_repr

trainer.fit(model, datamodule)
Expand All @@ -144,17 +141,14 @@ def expect_rank_zero_user_warning(match: str):
log_every_n_steps=1)

datamodule = LightningDataset(train_dataset, batch_size=5)
datamodule_repr = (
'{Train dataloader: size=50}\n'
'{Validation dataloader: None}\n'
'{Test dataloader: None}\n{'
'Predict dataloader: None}'
if new_datamodule_repr else
'LightningDataset(train_dataset=MUTAG(50), '
'batch_size=5, num_workers=0, '
'pin_memory=True, '
'persistent_workers=False)'
)
datamodule_repr = ('{Train dataloader: size=50}\n'
'{Validation dataloader: None}\n'
'{Test dataloader: None}\n{'
'Predict dataloader: None}' if new_datamodule_repr
else 'LightningDataset(train_dataset=MUTAG(50), '
'batch_size=5, num_workers=0, '
'pin_memory=True, '
'persistent_workers=False)')
assert str(datamodule) == datamodule_repr

with expect_rank_zero_user_warning("defined a `validation_step`"):
Expand Down Expand Up @@ -256,14 +250,12 @@ def test_lightning_node_data(get_dataset, strategy_type, loader):
'{Train dataloader: ' + f'size={140 if flag else 1}' + '}\n'
'{Validation dataloader: ' + f'size={500 if flag else 1}' + '}\n'
'{Test dataloader: ' + f'size={1000 if flag else 1}' + '}\n'
'{Predict dataloader: ' + f'size={2708 if flag else 1}' + '}'
if new_datamodule_repr else
f'LightningNodeData(data={data_repr}, '
'{Predict dataloader: ' + f'size={2708 if flag else 1}' +
'}' if new_datamodule_repr else f'LightningNodeData(data={data_repr}, '
f'loader={loader}, batch_size={batch_size}, '
f'num_workers={num_workers}, {kwargs_repr}'
f'pin_memory={flag}, '
f'persistent_workers={flag})'
)
f'persistent_workers={flag})')
assert str(datamodule) == datamodule_repr

trainer.fit(model, datamodule)
Expand Down

0 comments on commit b308b85

Please sign in to comment.