Skip to content

Commit

Permalink
Merge remote-tracking branch 'origin/main'
Browse files Browse the repository at this point in the history
  • Loading branch information
jameschapman19 committed Aug 30, 2023
2 parents 5a19f94 + a118431 commit a49dc6d
Show file tree
Hide file tree
Showing 3 changed files with 10 additions and 3 deletions.
7 changes: 6 additions & 1 deletion cca_zoo/linear/_gradient/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,12 @@ def _fit(self, views: Iterable[np.ndarray]):
)
if self.batch_size is None:
# if the batch size is None, put views on the device
self.batch={"views":[view.to(trainer._accelerator_connector._accelerator_flag) for view in train_dataset.views]}
self.batch = {
"views": [
view.to(trainer._accelerator_connector._accelerator_flag)
for view in train_dataset.views
]
}
trainer.fit(self, train_dataloader, val_dataloader)
# return the weights from the module. They will need to be changed from torch tensors to numpy arrays
weights = [weight.detach().cpu().numpy() for weight in self.torch_weights]
Expand Down
4 changes: 3 additions & 1 deletion cca_zoo/linear/_gradient/_ey.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,9 @@ def _compute_loss(self, batch) -> dict:
else:
if len(self.batch_queue) < 5:
self.batch_queue.append(batch)
return {"loss": torch.tensor(0, requires_grad=True, dtype=torch.float32)}
return {
"loss": torch.tensor(0, requires_grad=True, dtype=torch.float32)
}
else:
random_batch = self._get_random_batch()
loss = self.loss(batch["views"], random_batch["views"])
Expand Down
2 changes: 1 addition & 1 deletion cca_zoo/linear/_gradient/_stochasticpls.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ def _more_tags(self):

def training_step(self, batch, batch_idx):
if batch is None:
batch=dict(("views",self.data))
batch = dict(("views", self.data))
for weight in self.torch_weights:
weight.data = self._orth(weight)
scores = self(batch["views"])
Expand Down

0 comments on commit a49dc6d

Please sign in to comment.