modified: Makefile
[GalaxyCodeBases.git] / python / salus / dfcorr / cell2loc.py
blob191181179b4463ad82d1b2ebd8080776cc1c958e
#!/usr/bin/env python3
"""Run cell2location deconvolution on one spatial AnnData file.

Loads a pre-trained cell2location RegressionModel (reference cell-type
signatures), trains a Cell2location model on the spatial dataset, saves the
fitted model and the annotated AnnData, and renders QC plots plus one
abundance map per cell type into a single PDF report.
"""

import os
import sys
# GPU / NCCL / CUDA-allocator settings are placed in the environment BEFORE
# `torch` is imported below, so they are in effect when torch initialises
# CUDA. Keep this ordering.
os.environ['CUDA_VISIBLE_DEVICES'] = '0'
os.environ['NCCL_P2P_LEVEL'] = 'SYS'
os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'expandable_segments:True'
import copy

import torch
# Allow reduced-precision (faster) float32 matmuls; see torch docs for 'high'.
torch.set_float32_matmul_precision('high')

import scanpy as sc
sc.logging.print_header()  # log scanpy/dependency versions for reproducibility
import numpy as np
import matplotlib as mpl
import matplotlib.pyplot as plt
from matplotlib.backends.backend_pdf import PdfPages
mpl.rcdefaults()
mpl.rc('ps', fonttype=42, papersize='figure')
mpl.rc('pdf', fonttype=42, compression=9)  # pdf.fonttype: 3 = Type 3, 42 = TrueType (editable text)
mpl.rc('figure', figsize=(129/25.4, 129/25.4), dpi=600)  # 129 mm square figures; alternative: autolayout=True
mpl.rc('savefig', dpi='figure')  # alternative: bbox='tight'

import pandas as pd
from scipy.sparse import csr_matrix
import pathlib  # NOTE(review): sys and pathlib appear unused in this script

import scvi
scvi.settings.seed = 42           # fixed seed for reproducible training
scvi.settings.dl_num_workers = 16  # data-loader worker processes
scvi.settings.num_threads = 8      # CPU threads used by torch
import cell2location
# Input/output locations: `fn` names the spatial dataset inside `tp` (all
# outputs are written there too); `prefix` holds the pre-trained reference
# regression model and its single-cell AnnData.
fn = 'adata31'
tp = '/share/result/spatial/test_huxs/prj/cmpmethod/testicles2/c2ltest'
prefix = '/share/result/spatial/test_huxs/prj/humlung/refdat/GSE112393'

# Load the reference AnnData and the regression model trained on it, then
# export its posterior so the per-cluster signatures are attached to adata_ref.
adata_file = f'{prefix}/GSE112393_model.sc.h5ad'
adata_ref = sc.read_h5ad(adata_file)
mod = cell2location.models.RegressionModel.load(f'{prefix}/GSE112393_model', adata_ref)
# Keep the AnnData returned by export_posterior (as done for the spatial model
# later) rather than relying on the argument being mutated in place.
adata_ref = mod.export_posterior(
    adata_ref, sample_kwargs={'num_samples': 1000, 'batch_size': 2500, 'accelerator': 'gpu'}
)

# Reference signatures: one row per gene, one column per factor (cell type).
# They are read from .varm when present, otherwise from columns of .var.
factor_names = adata_ref.uns['mod']['factor_names']
signature_cols = [f'means_per_cluster_mu_fg_{i}' for i in factor_names]
if 'means_per_cluster_mu_fg' in adata_ref.varm:
    inf_aver = adata_ref.varm['means_per_cluster_mu_fg'][signature_cols].copy()
else:
    inf_aver = adata_ref.var[signature_cols].copy()
inf_aver.columns = factor_names
# Replacement for plt.show(), swapped in temporarily while library code plots.
def hijacked_show(*args, **kwargs):
    """Write the active figure to the PDF report instead of displaying it.

    Any positional/keyword arguments a caller passes to ``plt.show()`` are
    accepted and ignored. Relies on the module-level ``pdf`` (a PdfPages
    object) existing by the time this is called.
    """
    # Persist the figure the caller just finished drawing...
    pdf.savefig()
    # ...drop it, and give subsequent drawing a fresh, empty figure.
    plt.close()
    plt.figure()
# Styling shared by the per-cell-type abundance maps.
sqrt2val = 1.41422  # ~sqrt(2); passed as the `size=` spot-size factor to sc.pl.spatial
mynacolor = (1, 1, 1, 0)  # fully transparent white, used for NA values
mycmap = mpl.colormaps.get_cmap('nipy_spectral')  # viridis is the default colormap for imshow
# Removed here: a duplicate `import copy`, the unused `maxArray` constant and
# the unused `mod0 = copy.copy(mod)`, which kept the reference RegressionModel
# alive for the rest of the run after `mod` is rebound to the spatial model.
# Load the spatial dataset to deconvolve. (No untouched copy is kept: the
# previous `adata_ori = adata_vis.copy()` was never used and doubled memory.)
spf = f'{tp}/{fn}.h5ad'
adata_vis = sc.read_h5ad(spf)

# Single batch label (used as batch_key for setup_anndata) and a SYMBOL
# column mirroring the gene names.
adata_vis.obs['sample'] = 'Testicles40'
adata_vis.var['SYMBOL'] = adata_vis.var_names

# Drop spots with fewer than 9 detected genes and genes seen in fewer than
# 3 spots; store counts as a CSR sparse matrix.
sc.pp.filter_cells(adata_vis, min_genes=9)
sc.pp.filter_genes(adata_vis, min_cells=3)
adata_vis.X = csr_matrix(adata_vis.X)

# Restrict the spatial data and the reference signatures to their shared
# genes; indexing both by the same sorted array keeps rows aligned.
intersect = np.intersect1d(adata_vis.var_names, inf_aver.index)
adata_vis = adata_vis[:, intersect].copy()
inf_aver = inf_aver.loc[intersect, :].copy()
# Register the spatial AnnData with 'sample' as the batch column, then build
# the Cell2location model on top of the reference signatures.
cell2location.models.Cell2location.setup_anndata(adata=adata_vis, batch_key="sample")

c2l_hyperparams = dict(
    # expected average cell abundance per location: a tissue-dependent
    # hyper-prior that can be estimated from paired histology
    N_cells_per_location=5,
    # controls normalisation of within-experiment variation in RNA detection
    detection_alpha=200,
)
mod = cell2location.models.Cell2location(adata_vis, cell_state_df=inf_aver, **c2l_hyperparams)
mod.view_anndata_setup()
# Open the PDF report; all subsequent pages are written into `pdf`.
pdf = PdfPages(f"{tp}/{fn}_cell2loc.pdf")

# One rasterised QC map per (obs column, page title).
for qc_key, qc_title in (('total_counts', "nCount_Spatial"),
                         ('n_genes_by_counts', "nFeature_Spatial")):
    plt.figure()
    ax = plt.gca()
    ax.set_rasterized(True)
    # show=False: never hand the figure to plt.show() -- with an interactive
    # backend that may display/block and lose the figure before it is saved.
    sc.pl.spatial(adata_vis, color=qc_key, spot_size=40, scale_factor=1,
                  title=qc_title, ax=ax, show=False)
    pdf.savefig()
    plt.close()
# Full-data training: batch_size=None trains on all data at once and
# train_size=1 uses every location, since abundance must be estimated for
# all of them.
n_train_epochs = 30000  # e.g. 1000 for a quick test run
mod.train(
    max_epochs=n_train_epochs,
    batch_size=None,
    train_size=1,
    accelerator="gpu",
)

# Training-loss page; plot_history's first argument is the starting
# iteration (iter_start), so the first 1000 iterations are omitted.
plt.figure()
mod.plot_history(1000)
plt.legend(labels=['full data training'])
pdf.savefig()
plt.close()
# Sample the posterior for all spots in a single batch; the summaries (e.g.
# obsm['q05_cell_abundance_w_sf'], used below) are attached to adata_vis.
posterior_kwargs = {'num_samples': 1000, 'batch_size': mod.adata.n_obs, 'accelerator': 'gpu'}
adata_vis = mod.export_posterior(adata_vis, sample_kwargs=posterior_kwargs)

# Persist the trained model and the annotated spatial AnnData.
mod.save(f'{tp}/{fn}_cell2loc_model', overwrite=True)
adata_file = f"{tp}/{fn}_cell2loc_model_sp.h5ad"
adata_vis.write(adata_file)
# Route plt.show() to hijacked_show while mod.plot_QC() runs, so any figure
# it "shows" (presumably it calls plt.show() internally) becomes a PDF page.
_original_show = plt.show
plt.show = hijacked_show
try:
    plt.figure()
    mod.plot_QC()
    # NOTE(review): whatever is drawn after plot_QC's last plt.show() call is
    # discarded by this close rather than saved -- confirm that is intended.
    plt.close()
finally:
    # Always restore the real plt.show(), even if plot_QC raises.
    plt.show = _original_show
# Minimal uns['spatial'] entry for library 'Testicles40': no real tissue
# image, only a 1x1 black RGB placeholder (shape (1, 1, 3)) under 'hires'.
# NOTE(review): presumably mod.adata is the same object as adata_vis, so this
# also annotates adata_vis (after it was already written to disk) -- confirm.
placeholder_hires = np.array([0, 0, 0], ndmin=3)
mod.adata.uns['spatial'] = {
    'Testicles40': {
        'images': {'hires': placeholder_hires},
    },
}
# Disabled per-batch spatial QC page (the placeholder above appears to be for it):
# fig = mod.plot_spatial_QC_across_batches()
# pdf.savefig()
# plt.close()
# Copy the q05 (5% posterior quantile) cell abundances into one obs column per
# cell type so sc.pl.spatial can colour spots by them.
cell_types = adata_vis.uns['mod']['factor_names']
adata_vis.obs[cell_types] = adata_vis.obsm['q05_cell_abundance_w_sf']
torch.cuda.empty_cache()

print('[i]Plotting CellTypes:', end=' ', flush=True)
try:
    for onefactor in cell_types:
        print(f'[{onefactor}]', end=' ', flush=True)
        # Black axes background; NA spots use a fully transparent colour
        # (mynacolor), so they show as this background.
        with mpl.rc_context({'axes.facecolor': 'black'}):
            plt.figure()
            ax = plt.gca()
            ax.set_rasterized(True)
            sc.pl.spatial(
                adata_vis,
                color=onefactor,
                img_key=None,
                spot_size=90,
                cmap=mycmap, na_color=mynacolor, size=sqrt2val, scale_factor=1,
                vmin=0, vmax='p99.2',  # cap the colour scale at the 99.2th percentile
                title=f'{onefactor}',
                ax=ax,
                show=False,  # keep the figure open for pdf.savefig()
            )
            pdf.savefig()
            plt.close()
finally:
    # Finalise the report even if a plot fails, so earlier pages are kept.
    pdf.close()
print('.\n[i]Done.', flush=True)