diff --git a/test/data/lightning/test_datamodule.py b/test/data/lightning/test_datamodule.py index 6093f5f4e53c..8f66e05918ff 100644 --- a/test/data/lightning/test_datamodule.py +++ b/test/data/lightning/test_datamodule.py @@ -18,12 +18,12 @@ MyFeatureStore, MyGraphStore, get_random_edge_index, + has_package, onlyCUDA, onlyFullTest, onlyNeighborSampler, onlyOnline, withPackage, - has_package, ) try: @@ -116,19 +116,16 @@ def expect_rank_zero_user_warning(match: str): assert 'shuffle' not in datamodule.kwargs old_x = train_dataset._data.x.clone() new_datamodule_repr = has_package('pytorch_lightning>=2.5.0') - datamodule_repr = ( - '{Train dataloader: size=50}\n' - '{Validation dataloader: size=30}\n' - '{Test dataloader: size=10}\n' - '{Predict dataloader: size=98}' - if new_datamodule_repr else - 'LightningDataset(train_dataset=MUTAG(50), ' - 'val_dataset=MUTAG(30), ' - 'test_dataset=MUTAG(10), ' - 'pred_dataset=MUTAG(98), batch_size=5, ' - 'num_workers=3, pin_memory=True, ' - 'persistent_workers=True)' - ) + datamodule_repr = ('{Train dataloader: size=50}\n' + '{Validation dataloader: size=30}\n' + '{Test dataloader: size=10}\n' + '{Predict dataloader: size=98}' if new_datamodule_repr + else 'LightningDataset(train_dataset=MUTAG(50), ' + 'val_dataset=MUTAG(30), ' + 'test_dataset=MUTAG(10), ' + 'pred_dataset=MUTAG(98), batch_size=5, ' + 'num_workers=3, pin_memory=True, ' + 'persistent_workers=True)') assert str(datamodule) == datamodule_repr trainer.fit(model, datamodule) @@ -144,17 +141,14 @@ def expect_rank_zero_user_warning(match: str): log_every_n_steps=1) datamodule = LightningDataset(train_dataset, batch_size=5) - datamodule_repr = ( - '{Train dataloader: size=50}\n' - '{Validation dataloader: None}\n' - '{Test dataloader: None}\n{' - 'Predict dataloader: None}' - if new_datamodule_repr else - 'LightningDataset(train_dataset=MUTAG(50), ' - 'batch_size=5, num_workers=0, ' - 'pin_memory=True, ' - 'persistent_workers=False)' - ) + datamodule_repr = ('{Train dataloader: size=50}\n' + '{Validation dataloader: None}\n' + '{Test dataloader: None}\n{' + 'Predict dataloader: None}' if new_datamodule_repr + else 'LightningDataset(train_dataset=MUTAG(50), ' + 'batch_size=5, num_workers=0, ' + 'pin_memory=True, ' + 'persistent_workers=False)') assert str(datamodule) == datamodule_repr with expect_rank_zero_user_warning("defined a `validation_step`"): @@ -256,14 +250,12 @@ def test_lightning_node_data(get_dataset, strategy_type, loader): '{Train dataloader: ' + f'size={140 if flag else 1}' + '}\n' '{Validation dataloader: ' + f'size={500 if flag else 1}' + '}\n' '{Test dataloader: ' + f'size={1000 if flag else 1}' + '}\n' - '{Predict dataloader: ' + f'size={2708 if flag else 1}' + '}' - if new_datamodule_repr else - f'LightningNodeData(data={data_repr}, ' + '{Predict dataloader: ' + f'size={2708 if flag else 1}' + + '}' if new_datamodule_repr else f'LightningNodeData(data={data_repr}, ' f'loader={loader}, batch_size={batch_size}, ' f'num_workers={num_workers}, {kwargs_repr}' f'pin_memory={flag}, ' - f'persistent_workers={flag})' - ) + f'persistent_workers={flag})') assert str(datamodule) == datamodule_repr trainer.fit(model, datamodule)