diff --git a/cca_zoo/deep/_base.py b/cca_zoo/deep/_base.py index e9031f3d..dc8a3950 100644 --- a/cca_zoo/deep/_base.py +++ b/cca_zoo/deep/_base.py @@ -58,7 +58,13 @@ def training_step(self, batch: Dict[str, Any], batch_idx: int) -> torch.Tensor: loss = self.loss(batch["views"]) for k, v in loss.items(): # Use f-string instead of concatenation - self.log(f"train/{k}", v, prog_bar=True, on_epoch=True, batch_size=batch["views"][0].shape[0]) + self.log( + f"train/{k}", + v, + prog_bar=True, + on_epoch=True, + batch_size=batch["views"][0].shape[0], + ) return loss["objective"] def validation_step(self, batch: Dict[str, Any], batch_idx: int) -> torch.Tensor: @@ -66,7 +72,9 @@ def validation_step(self, batch: Dict[str, Any], batch_idx: int) -> torch.Tensor loss = self.loss(batch["views"]) for k, v in loss.items(): # Use f-string instead of concatenation - self.log(f"val/{k}", v, on_epoch=True, batch_size=batch["views"][0].shape[0]) + self.log( + f"val/{k}", v, on_epoch=True, batch_size=batch["views"][0].shape[0] + ) return loss["objective"] def test_step(self, batch: Dict[str, Any], batch_idx: int) -> torch.Tensor: @@ -74,7 +82,9 @@ def test_step(self, batch: Dict[str, Any], batch_idx: int) -> torch.Tensor: loss = self.loss(batch["views"]) for k, v in loss.items(): # Use f-string instead of concatenation - self.log(f"test/{k}", v, on_epoch=True, batch_size=batch["views"][0].shape[0]) + self.log( + f"test/{k}", v, on_epoch=True, batch_size=batch["views"][0].shape[0] + ) return loss["objective"] @torch.no_grad()