From a118431a695e5442ffb2e027d1548aadb3b776a5 Mon Sep 17 00:00:00 2001 From: jameschapman19 Date: Wed, 30 Aug 2023 16:08:33 +0000 Subject: [PATCH] Format code with black --- cca_zoo/linear/_gradient/_base.py | 7 ++++++- cca_zoo/linear/_gradient/_ey.py | 4 +++- cca_zoo/linear/_gradient/_stochasticpls.py | 2 +- 3 files changed, 10 insertions(+), 3 deletions(-) diff --git a/cca_zoo/linear/_gradient/_base.py b/cca_zoo/linear/_gradient/_base.py index c18b0586..6c4d3769 100644 --- a/cca_zoo/linear/_gradient/_base.py +++ b/cca_zoo/linear/_gradient/_base.py @@ -93,7 +93,12 @@ def _fit(self, views: Iterable[np.ndarray]): ) if self.batch_size is None: # if the batch size is None, put views on the device - self.batch={"views":[view.to(trainer._accelerator_connector._accelerator_flag) for view in train_dataset.views]} + self.batch = { + "views": [ + view.to(trainer._accelerator_connector._accelerator_flag) + for view in train_dataset.views + ] + } trainer.fit(self, train_dataloader, val_dataloader) # return the weights from the module. They will need to be changed from torch tensors to numpy arrays weights = [weight.detach().cpu().numpy() for weight in self.torch_weights] diff --git a/cca_zoo/linear/_gradient/_ey.py b/cca_zoo/linear/_gradient/_ey.py index abdf3ef9..289829f1 100644 --- a/cca_zoo/linear/_gradient/_ey.py +++ b/cca_zoo/linear/_gradient/_ey.py @@ -38,7 +38,9 @@ def _compute_loss(self, batch) -> dict: else: if len(self.batch_queue) < 5: self.batch_queue.append(batch) - return {"loss": torch.tensor(0, requires_grad=True, dtype=torch.float32)} + return { + "loss": torch.tensor(0, requires_grad=True, dtype=torch.float32) + } else: random_batch = self._get_random_batch() loss = self.loss(batch["views"], random_batch["views"]) diff --git a/cca_zoo/linear/_gradient/_stochasticpls.py b/cca_zoo/linear/_gradient/_stochasticpls.py index 3ca9a660..ee158d52 100644 --- a/cca_zoo/linear/_gradient/_stochasticpls.py +++ b/cca_zoo/linear/_gradient/_stochasticpls.py @@ -10,7 +10,7 @@ def _more_tags(self): def training_step(self, batch, batch_idx): if batch is None: - batch=dict(("views",self.data)) + batch = dict(("views", self.data)) for weight in self.torch_weights: weight.data = self._orth(weight) scores = self(batch["views"])