Skip to content

Commit

Permalink
cast tensors to float32 automatically
Browse files Browse the repository at this point in the history
  • Loading branch information
cdedonno committed Jun 13, 2023
1 parent ef6e85b commit 8a2fd6c
Show file tree
Hide file tree
Showing 3 changed files with 7 additions and 3 deletions.
4 changes: 4 additions & 0 deletions .vscode/settings.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
{
"python.linting.pylintEnabled": true,
"python.linting.enabled": true
}
4 changes: 2 additions & 2 deletions scarches/dataset/scpoli/anndata.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ def __init__(self,
self.cell_type_keys = cell_type_keys
self.cell_type_encoder = cell_type_encoder
self._is_sparse = sparse.issparse(adata.X)
self.data = adata.X if self._is_sparse else torch.tensor(adata.X)
self.data = adata.X if self._is_sparse else torch.tensor(adata.X, dtype=torch.float32)

size_factors = np.ravel(adata.X.sum(1))

Expand Down Expand Up @@ -80,7 +80,7 @@ def __getitem__(self, index):
outputs = dict()

if self._is_sparse:
x = torch.tensor(np.squeeze(self.data[index].toarray()))
x = torch.tensor(np.squeeze(self.data[index].toarray()), dtype=torch.float32)
else:
x = self.data[index]
outputs["x"] = x
Expand Down
2 changes: 1 addition & 1 deletion scarches/models/scpoli/scpoli_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -344,7 +344,7 @@ def get_latent(
c = torch.tensor(label_tensor, device=device).T
if sparse.issparse(x):
x = x.A
x = torch.tensor(x, device=device)
x = torch.tensor(x, device=device, dtype=torch.float32)

latents = []
# batch the latent transformation process
Expand Down

0 comments on commit 8a2fd6c

Please sign in to comment.