-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathrun_scANVI_CD8T.py
executable file
·155 lines (120 loc) · 4.63 KB
/
run_scANVI_CD8T.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
import scanpy as sc
import anndata
import warnings
# Wrappers to run scVI and scANVI (scVI is trained first because scANVI is initialised from the trained scVI model)
def scvi(adata, batch, dims, hvg=None, return_model=False, max_epochs=None):
    """Train an scVI model on raw counts and store its latent embedding.

    Adapted from scib package version 1.1.3, based on scvi-tools version
    >=0.16.0 (available through `conda
    <https://docs.scvi-tools.org/en/stable/installation.html>`_)

    .. note::
        scVI expects only non-normalized (count) data on highly variable genes!

    :param adata: ``anndata`` object with raw counts in
        ``adata.layers["counts"]``
    :param batch: batch key in ``adata.obs``
    :param dims: number of latent dimensions to use for integration
    :param hvg: list of highly variable genes to subset to. If ``None``, the
        full dataset will be used
    :param return_model: if true, return the trained model instead of
        ``adata``
    :param max_epochs: number of training epochs. If ``None``, no
        ``max_epochs`` is passed to ``train`` and the choice is left to
        scvi-tools
    :return: ``adata`` with the embedding stored in ``adata.obsm["X_emb"]``,
        or the trained model when ``return_model`` is true
    :raises TypeError: if ``adata`` has no ``counts`` layer
    :raises ImportError: if scvi-tools is not installed
    """
    # Validate the input first: the check is cheap and does not depend on
    # scvi-tools being importable.
    if "counts" not in adata.layers:
        raise TypeError(
            "Adata does not contain a `counts` layer in `adata.layers[`counts`]`"
        )

    try:
        from scvi.model import SCVI
    except ModuleNotFoundError as e:
        raise ImportError(
            "scvi-tools is required to run scVI but could not be imported"
        ) from e

    # Per the original note, the scVI tutorials (scanpy_pbmc3k and
    # harmonization) use 30 latent dimensions; here the caller decides.
    n_latent = dims

    # Defaults from SCVI github tutorials scanpy_pbmc3k and harmonization
    n_hidden = 128
    n_layers = 2

    # Work on a copy so that values added during setup_anndata are not
    # returned in adata. Only one copy is made: either the full object or
    # the HVG subset.
    if hvg is None:
        net_adata = adata.copy()
    else:
        net_adata = adata[:, hvg].copy()

    SCVI.setup_anndata(net_adata, layer="counts", batch_key=batch)

    vae = SCVI(
        net_adata,
        gene_likelihood="nb",
        n_layers=n_layers,
        n_latent=n_latent,
        n_hidden=n_hidden,
    )

    # train_size=1.0: no cells are held out for validation.
    train_kwargs = {"train_size": 1.0}
    if max_epochs is not None:
        train_kwargs["max_epochs"] = max_epochs
    vae.train(**train_kwargs)

    # The embedding is written to the caller's adata even when the model is
    # returned instead.
    adata.obsm["X_emb"] = vae.get_latent_representation()

    if return_model:
        return vae
    return adata
def scanvi(adata, batch, dims, labels, hvg=None, max_epochs=None):
    """Train scANVI, initialised from scVI, and store its latent embedding.

    Adapted from scib package version 1.1.3, based on scvi-tools version
    >=0.16.0 (available through `conda
    <https://docs.scvi-tools.org/en/stable/installation.html>`_)

    .. note::
        Use non-normalized (count) data for scANVI!

    :param adata: ``anndata`` object with raw counts in
        ``adata.layers["counts"]``
    :param batch: batch key in ``adata.obs``
    :param dims: number of latent dimensions to use for integration
    :param labels: label key in ``adata.obs``; ``"unknown"`` is passed to
        scANVI as the unlabeled category
    :param hvg: list of highly variable genes to subset to. If ``None``, the
        full dataset will be used
    :param max_epochs: number of epochs used for both the scVI and the scANVI
        training. If ``None``, the scVI epochs are derived from the number of
        cells (at most 400) and scANVI is trained for 2 epochs
    :return: ``adata`` with the scANVI embedding stored in
        ``adata.obsm["X_emb"]``
    :raises ImportError: if scvi-tools is not installed
    """
    try:
        from scvi.model import SCANVI
    except ModuleNotFoundError as e:
        raise ImportError(
            "scvi-tools is required to run scANVI but could not be imported"
        ) from e

    if max_epochs is None:
        # Heuristic: 400 epochs for up to 20000 cells, fewer for larger data.
        n_epochs_scVI = int(min(round((20000 / adata.n_obs) * 400), 400))
        # NOTE(review): the previous expression,
        # min(max(2, round(n_epochs_scVI / 3.0)), 2), always evaluates to 2,
        # so the value is written out explicitly. scib clamps this to the
        # range [2, 10]; confirm whether 2 epochs is intended.
        n_epochs_scANVI = 2
    else:
        n_epochs_scVI = max_epochs
        n_epochs_scANVI = max_epochs

    # STEP 1: RUN scVI to initialize scANVI
    vae = scvi(
        adata, batch, dims, hvg, return_model=True, max_epochs=n_epochs_scVI
    )

    scanvae = SCANVI.from_scvi_model(
        scvi_model=vae,
        labels_key=labels,
        unlabeled_category="unknown",
    )

    # STEP 2: RUN scANVI
    scanvae.train(max_epochs=n_epochs_scANVI, train_size=1.0)

    # Overwrites the scVI embedding written by scvi() above.
    adata.obsm["X_emb"] = scanvae.get_latent_representation()
    return adata
def calc_hvg(adata, nhvg, batch):
    """Return the names of the genes flagged as highly variable.

    Genes are selected with scanpy's ``seurat_v3`` flavor on the ``counts``
    layer, using ``batch`` as the batch key. Any existing
    ``adata.var['highly_variable']`` column is deleted from ``adata`` first.

    :param adata: ``anndata`` object with a ``counts`` layer
    :param nhvg: number of top genes requested from scanpy
    :param batch: batch key in ``adata.obs``
    :return: list of the index labels of the rows flagged ``highly_variable``
    """
    # remove HVG if already precomputed
    if 'highly_variable' in adata.var:
        del adata.var['highly_variable']

    hvg_options = dict(
        flavor="seurat_v3",
        n_top_genes=nhvg,
        layer="counts",
        batch_key=batch,
        subset=False,
        inplace=False,
    )
    # inplace=False: the per-gene table is returned instead of being written
    # into adata.var.
    hvg_table = sc.pp.highly_variable_genes(adata, **hvg_options)

    is_hvg = hvg_table["highly_variable"].eq(True)
    selected = hvg_table.loc[is_hvg]
    return list(selected.index)
if __name__ == '__main__':
    import os

    # Run configuration.
    file = "cache/merged.h5ad"
    batch = "SampleLabel"
    labels = "scGate_multi"
    dims = 50
    nhvg = 800
    outPath = "out/scANVI_integrated_CD8.h5ad"

    # Create the output directory up front so that a missing folder is not
    # discovered only after the training has finished.
    os.makedirs(os.path.dirname(outPath), exist_ok=True)

    adata = anndata.read_h5ad(file)

    # The wrappers read raw counts from layers["counts"]; here they are taken
    # from adata.raw.
    if adata.raw is None:
        raise ValueError(
            f"{file} has no `.raw` attribute to take the raw counts from"
        )
    adata.layers["counts"] = adata.raw.X.copy()

    # hvg = list(adata.var.index)
    hvg = calc_hvg(adata, nhvg, batch)
    print(hvg[0:10])

    integrated = scanvi(adata, batch, dims, labels, hvg)

    # Neighbours, clusters and UMAP are computed on the scANVI embedding.
    sc.pp.neighbors(integrated, use_rep="X_emb")
    sc.tl.leiden(integrated)
    sc.tl.umap(integrated)

    # Drop all layers (including the counts layer added above) before writing.
    del integrated.layers

    sc.write(outPath, integrated)