diff --git a/cca_zoo/linear/_gradient/_ey.py b/cca_zoo/linear/_gradient/_ey.py index 068ced7e..d97e7ae2 100644 --- a/cca_zoo/linear/_gradient/_ey.py +++ b/cca_zoo/linear/_gradient/_ey.py @@ -64,9 +64,15 @@ def loss(self, views, independent_views=None, **kwargs): } def get_dataset(self, views: Iterable[np.ndarray], validation_views=None): - dataset = DoubleNumpyDataset(views) if self.batch_size else FullBatchDataset(views) + dataset = ( + DoubleNumpyDataset(views) if self.batch_size else FullBatchDataset(views) + ) if validation_views is not None: - val_dataset = DoubleNumpyDataset(validation_views) if self.batch_size else FullBatchDataset(validation_views) + val_dataset = ( + DoubleNumpyDataset(validation_views) + if self.batch_size + else FullBatchDataset(validation_views) + ) else: val_dataset = None return dataset, val_dataset