modified: myjupyterlab.sh
[GalaxyCodeBases.git] / python / salus / cmplatform / nfig1.py
blob4b8bac18d84e404f848a383760eaca8a0ad38eb9
1 #!/usr/bin/env python3
2 import sys
3 import os
4 from typing import NamedTuple
# Platforms being compared. main() treats index 0 and index 1 as the two
# sides of every comparison, so the order here matters.
PlatformTuple = ('Illumina', 'Salus')

# Per-sample configuration, selected by sample id (sys.argv[1]).
#   sid        - sample id; also used in output file names
#   sub        - human-readable subject, used in plot titles / PDF metadata
#   type       - data source type label
#   fltPct     - filtering percentile (not read by the code in this file)
#   prefix     - root data directory
#   suffixOut  - per-platform output sub-directory
#   suffixMtx  - matrix directory name
#   platforms  - per-platform directory component
#   pattern    - keys whose values are os.path.join'ed into the matrix path;
#                'platformV' and 'suffixOutV' are filled in per platform by main()
SamplesDict = {
    'mbrain': {
        'sid': 'mbrain',
        'sub': 'Mouse Brain Spatial',
        'type': 'visium',
        'fltPct': 99.5,
        'prefix': '/share/result/spatial/data/BoAo_sp',
        'suffixOut': dict.fromkeys(PlatformTuple, "outs"),
        'suffixMtx': 'filtered_feature_bc_matrix',
        'platforms': {PlatformTuple[0]: 'illumina', PlatformTuple[1]: 'salus'},
        'pattern': ('prefix', 'platformV', 'sid', 'suffixOutV', 'suffixMtx'),
    },
    'mkidney': {
        'sid': 'mkidney',
        'sub': 'Mouse Kidney Spatial',
        'type': 'visium',
        'fltPct': 99.5,
        'prefix': '/share/result/spatial/data/BoAo_sp',
        'suffixOut': dict.fromkeys(PlatformTuple, "outs"),
        'suffixMtx': 'filtered_feature_bc_matrix',
        'platforms': {PlatformTuple[0]: 'illumina', PlatformTuple[1]: 'salus'},
        'pattern': ('prefix', 'platformV', 'sid', 'suffixOutV', 'suffixMtx'),
    },
    'human': {
        'sid': 'human',
        'sub': 'Human Single Cell',
        'type': 'mobivision',
        'fltPct': 85,
        'prefix': '/share/result/spatial/data/MoZhuo_sc/FX20230913',
        'suffixOut': {
            PlatformTuple[0]: 'out/R22045213-220914-LYY-S11-R03-220914-LYY-S11-R03_combined_outs',
            PlatformTuple[1]: 'out_subset/20221124-LYY-S09-R03_AGGCAGAA_fastq_outs',
        },
        'suffixMtx': 'filtered_cell_gene_matrix',
        # NOTE(review): 'sailu' (not 'salus') is a path component here;
        # confirm it matches the on-disk directory name before changing it.
        'platforms': {PlatformTuple[0]: 'illumina', PlatformTuple[1]: 'sailu'},
        # No 'sid' component: the human data lives directly under the platform dir.
        'pattern': ('prefix', 'platformV', 'suffixOutV', 'suffixMtx'),
    },
}
# Default sample id; replaced by the first command-line argument when the
# file is run as a script. Parsed before the heavy imports below so a bad id
# fails fast.
thisID = 'mbrain'
if __name__ == "__main__":
    if len(sys.argv) > 1:
        thisID = sys.argv[1]
    if thisID not in SamplesDict:
        print(f"[x]sid can only be {list(SamplesDict)}", file=sys.stderr)
        # sys.exit, not the site-provided exit(): the latter is intended for
        # the interactive shell and is missing when Python runs with -S.
        sys.exit(1)
    print(sys.argv, file=sys.stderr)
    print(f"[i]{thisID}")
    sys.stdout.flush()
    # Module-level configuration read (and extended) by main() and doDropOutPlot().
    nfoDict = SamplesDict[thisID]
import matplotlib
# Choose the mplcairo backend before pyplot is imported
# (main() later switches to the 'pdf' backend for saving).
matplotlib.use("module://mplcairo.base")
from matplotlib import pyplot as plt
import mplcairo

# Shared plot defaults for every figure this script produces.
plt.rcParams.update({
    'figure.figsize': (6.0, 6.0),  # default size of plots
    'figure.dpi': 300,
    'savefig.bbox': 'tight',
    'savefig.transparent': True,
})
# Font family/weight; size stays at the default ('size': 22 was tried and disabled).
font = dict(family='STIX Two Text', weight='normal')
matplotlib.rc('font', **font)
69 import numpy as np
70 import pandas as pd
71 import fast_matrix_market
72 import anndata as ad
73 import leidenalg
74 import scanpy as sc
# Let scanpy use all available cores. Set through the public settings object
# rather than by assigning over the private ScanpyConfig class attribute,
# which replaced the validated property instead of setting its value.
sc.settings.n_jobs = -1
76 #import squidpy as sq
77 import seaborn as sns
78 import scipy
79 import pynndescent
import warnings
from copy import deepcopy

# Suppress every warning category for the remainder of the run.
warnings.simplefilter('ignore')
class scDatItem(NamedTuple):
    """Per-platform dataset bundle built by main()."""

    name: str  # platform label (an entry of PlatformTuple in main())
    rawDat: ad.AnnData  # copy taken before QC metrics and filtering
    annDat: ad.AnnData  # data after QC metrics and cell/gene filtering

    def __repr__(self) -> str:
        """Summarise the platform and both matrix shapes (barcodes x genes)."""
        raw_shape = self.rawDat.shape
        kept_shape = self.annDat.shape
        return '[sc:{}, BC*Gene: Raw={}, Filtered={}]'.format(self.name, raw_shape, kept_shape)
def main() -> None:
    """Build the platform-comparison figures and tables for the sample in ``nfoDict``.

    For each platform in ``PlatformTuple``:

    * read the matrix whose path is built from ``nfoDict['pattern']``;
    * keep an unfiltered copy, also written to ``<sid>_<platform>.raw.h5ad``;
    * compute QC metrics and drop barcodes/genes with no counts.

    It then writes into the current directory:

    * 1D - genes vs UMIs per barcode, both platforms overlaid;
    * 1E - UMIs per barcode, platform vs platform (shared barcodes only);
    * 1G - Venn diagram of detected genes plus per-set gene CSVs;
    * 2A - histogram of per-barcode Pearson correlation across platforms;
    * 2B - PC1/PC2 of each platform computed on the shared genes;
    * 1F - dropout tables/plots per platform and for both combined
      (via ``doDropOutPlot``).

    Relies on the module-level ``nfoDict`` set in the ``__main__`` section and
    adds ``platformK``/``platformV``/``suffixOutV`` to it while reading.
    """
    plt.switch_backend('pdf')
    scDat = []
    print("[i]Start.", file=sys.stderr)
    for platform in PlatformTuple:
        # Fill the per-platform placeholders referenced by nfoDict['pattern'].
        nfoDict['platformK'] = platform
        nfoDict['platformV'] = nfoDict['platforms'][platform]
        nfoDict['suffixOutV'] = nfoDict['suffixOut'][platform]
        mtxPath = os.path.join(*[nfoDict[v] for v in nfoDict['pattern']])
        print(f"[i]Reading {mtxPath}", file=sys.stderr)
        adata = sc.read_10x_mtx(mtxPath, var_names='gene_symbols', make_unique=True, gex_only=True)
        adata.var_names_make_unique()
        adata.var['mt'] = adata.var_names.str.startswith('MT-') | adata.var_names.str.startswith('mt-')
        # Snapshot before QC so the unfiltered data survives the filtering below.
        rdata = deepcopy(adata)
        sc.pp.calculate_qc_metrics(adata, qc_vars=['mt'], percent_top=None, log1p=True, inplace=True)
        sc.pp.filter_cells(adata, min_genes=1)
        sc.pp.filter_genes(adata, min_cells=1)
        scDat.append(scDatItem(platform, rdata, adata))
        rdata.write_h5ad(f"{nfoDict['sid']}_{platform}.raw.h5ad", compression='lzf')
    print("\n".join(map(str, scDat)))
    nameX, nameY = scDat[0].name, scDat[1].name

    with pd.option_context("mode.copy_on_write", True):
        obsmbi = scDat[0].annDat.obs[['n_genes_by_counts', 'total_counts']].copy(deep=False)
        obsmbs = scDat[1].annDat.obs[['n_genes_by_counts', 'total_counts']].copy(deep=False)
        # Zeros and infinities are treated as missing and dropped.
        p1df = pd.concat(
            [obsmbi.assign(Platform=nameX), obsmbs.assign(Platform=nameY)], ignore_index=True
        ).replace([np.inf, -np.inf, 0], np.nan).dropna()
        # Inner join on the obs index (barcodes): only barcodes seen on both platforms.
        p2df = obsmbi.join(
            obsmbs, lsuffix='_' + nameX, rsuffix='_' + nameY, how='inner'
        ).replace([np.inf, -np.inf, 0], np.nan).dropna()
        p3tuple = (frozenset(scDat[0].annDat.var_names), frozenset(scDat[1].annDat.var_names))

    metapdf = {'Subject': f"{nfoDict['sub']} Data", 'Author': 'HU Xuesong'}

    print("[i]Begin fig A. 1D", file=sys.stderr)
    custom_params = {"axes.spines.right": False, "axes.spines.top": False}
    sns.set_theme(style="ticks", rc=custom_params, font="STIX Two Text")
    figA = sns.JointGrid(data=p1df, x="total_counts", y="n_genes_by_counts", hue='Platform', dropna=True)
    figA.plot_joint(sns.scatterplot, s=12.7, alpha=.6)
    figA.plot_marginals(sns.histplot, kde=False, alpha=.618)
    figA.figure.suptitle(f"Gene to UMI plot - {nfoDict['sub']}")
    figA.set_axis_labels(xlabel='UMIs per Barcode', ylabel='Genes per Barcode')
    figA.savefig(f"1D_{nfoDict['sid']}.pdf", metadata={**metapdf, 'Title': 'Gene to UMI plot'})
    plt.close(figA.figure)

    print("[i]Begin fig B. 1E", file=sys.stderr)
    # Column names follow the join suffixes above instead of hard-coded platform names.
    figB = sns.JointGrid(data=p2df, x=f"total_counts_{nameX}", y=f"total_counts_{nameY}", dropna=True)
    figB.plot_joint(sns.scatterplot, s=12.7, alpha=.6)
    figB.plot_marginals(sns.histplot, kde=True, alpha=.618)
    figB.figure.suptitle(f"UMI per Barcode Counts Comparing - {nfoDict['sub']}")
    figB.set_axis_labels(xlabel=f'UMI Counts from {nameX}', ylabel=f'UMI Counts from {nameY}')
    figB.savefig(f"1E_{nfoDict['sid']}.pdf", metadata={**metapdf, 'Title': 'UMI per Barcode Counts Comparing'})
    plt.close(figB.figure)

    print("[i]Begin fig . 1G", file=sys.stderr)
    from matplotlib_venn import venn2
    plt.figure(figsize=(4, 4))
    plt.title(f"Genes Venn diagram - {nfoDict['sub']}")
    p3intersection = p3tuple[0] & p3tuple[1]
    p3veen = (p3tuple[0] - p3intersection, p3tuple[1] - p3intersection, p3intersection)
    # Select rows with boolean masks instead of .loc[<frozenset>]: set
    # iteration order depends on per-process string-hash randomization, so
    # the CSV row order changed from run to run. Masks keep the var order.
    varX = scDat[0].annDat.var
    varY = scDat[1].annDat.var
    sharedX = varX.index.isin(list(p3intersection))
    GenesA = varX[~sharedX]
    GenesB = varY[~varY.index.isin(list(p3intersection))]
    GenesC = varX[sharedX]
    venn2(subsets=tuple(map(len, p3veen)), set_labels=(nameX, nameY))
    plt.savefig(f"1G_Genes_{nfoDict['sid']}.pdf", metadata={**metapdf, 'Title': 'Venn of Genes'})
    plt.close()
    GenesA.to_csv(f"1G_Genes_{nfoDict['sid']}_{nameX}_only.csv", encoding='utf-8')
    GenesB.to_csv(f"1G_Genes_{nfoDict['sid']}_{nameY}_only.csv", encoding='utf-8')
    GenesC.to_csv(f"1G_Genes_{nfoDict['sid']}_intersection.csv.zst", encoding='utf-8',
                  compression={'method': 'zstd', 'level': 9, 'write_checksum': True})

    print("[i]Begin fig C. 2A", file=sys.stderr)
    # https://www.kaggle.com/code/lizabogdan/top-correlated-genes?scriptVersionId=109838203&cellId=21
    p4xdf = scDat[0].annDat.to_df()
    p4ydf = scDat[1].annDat.to_df()
    # Row-wise (per-barcode) Pearson correlation; corrwith aligns both frames
    # on index and columns first.
    p4corr = p4xdf.corrwith(p4ydf, axis=1).dropna()
    del p4xdf, p4ydf  # dense copies of both matrices; release before PCA
    plt.figure(figsize=(6, 4))
    plt.title(f"Pearson correlation - {nfoDict['sub']}")
    sns.histplot(p4corr, stat='count', binwidth=0.01)
    plt.savefig(f"2A_Correlation_{nfoDict['sid']}.pdf", metadata={**metapdf, 'Title': 'Pearson correlation'})
    plt.close()

    print("[i]Begin fig D. 2B", file=sys.stderr)
    var_names = scDat[0].annDat.var_names.intersection(scDat[1].annDat.var_names)
    xadata = scDat[0].annDat[:, var_names]
    yadata = scDat[1].annDat[:, var_names]
    # NOTE: getOBSMdf runs PCA on each object separately, so the two PC bases differ.
    xdf = getOBSMdf(xadata)
    ydf = getOBSMdf(yadata)
    p4df = pd.concat(
        [xdf.assign(Platform=nameX), ydf.assign(Platform=nameY)], ignore_index=True
    ).replace([np.inf, -np.inf, 0], np.nan).dropna()
    figD = sns.JointGrid(data=p4df, x="P1", y="P2", hue='Platform', dropna=True)
    figD.plot_joint(sns.scatterplot, s=12.7, alpha=.6)
    figD.plot_marginals(sns.histplot, kde=True, alpha=.618)
    figD.figure.suptitle(f"PCA - {nfoDict['sub']}")
    figD.set_axis_labels(xlabel='PC1', ylabel='PC2')
    figD.savefig(f"2B_rawPCA_{nfoDict['sid']}.pdf", metadata={**metapdf, 'Title': 'PCA'})
    plt.close(figD.figure)

    import scvi
    # Dropout analysis on the unfiltered data: each platform alone, then both.
    for IDlist in ([0], [1], [0, 1]):
        rawList = [scDat[i].rawDat for i in IDlist]
        dataIDs = [scDat[i].name for i in IDlist]
        if len(rawList) == 1:
            # This is scDat[i].rawDat itself, so the calls below annotate it in place.
            adata = rawList[0]
            dataID = dataIDs[0]
        else:
            adata = ad.concat(rawList, label='Platform', keys=PlatformTuple, index_unique='-')
            dataID = 'Both'
        print(f"[i]Begin Tab 1. 1F Dropout rates - {dataID}. With scvi {scvi.__version__}", file=sys.stderr)
        adata.var['mt'] = adata.var_names.str.startswith('MT-') | adata.var_names.str.startswith('mt-')
        sc.pp.calculate_qc_metrics(adata, qc_vars=['mt'], percent_top=None, log1p=True, inplace=True)
        if dataID == 'Both':
            scvi.data.poisson_gene_selection(adata, n_top_genes=8000, n_samples=10000, batch_key='Platform')
        else:
            scvi.data.poisson_gene_selection(adata, n_top_genes=8000, n_samples=10000)
        doDropOutPlot(dataID, adata)
        adata = None
def doDropOutPlot(dataID, adata) -> None:
    """Write gene dropout tables and plots for ``adata``.

    Reads ``adata.var`` columns ``highly_variable``, ``observed_fraction_zeros``,
    ``expected_fraction_zeros``, ``prob_zero_enrichment`` and
    ``prob_zero_enrichment_rank``. main() runs ``scvi.data.poisson_gene_selection``
    right before calling this, presumably to produce them. Adds
    ``adata.var['mean_']`` (per-gene mean of ``adata.X``).

    Args:
        dataID: label used in titles and file names (a platform name or 'Both').
        adata: AnnData after gene selection; modified in place.

    Writes (names include ``nfoDict['sid']`` and ``dataID``):
        1F_GenesDropout_*.csv.zst       - adata.var sorted by prob_zero_enrichment_rank (descending)
        1F_GenesM3DropSelected_*.pdf    - mean vs observed fraction of zeros
        1F_GenesDropoutHist_*.pdf       - histogram of observed fraction of zeros by highly_variable
    """
    # X.mean(axis=0) is a 1xN np.matrix for scipy sparse matrices but a 1-D
    # array for dense ndarrays / sparse arrays, where indexing [0] would pick a
    # single scalar. asarray().ravel() gives a length-N vector in every case.
    adata.var['mean_'] = np.asarray(adata.X.mean(axis=0)).ravel()
    GenesM = adata.var.sort_values(by='prob_zero_enrichment_rank', ascending=False)
    GenesM.to_csv(f"1F_GenesDropout_{nfoDict['sid']}_{dataID}_PlatformAsBatch.csv.zst", encoding='utf-8',
                  compression={'method': 'zstd', 'level': 9, 'write_checksum': True})

    print(f"[i]Begin Fig 1. 1F GenesM3DropSelected (added) - {dataID}", file=sys.stderr)
    highly_variable_df = adata.var.query('highly_variable')
    fig, ax = plt.subplots(figsize=(10, 6))
    # All genes, colored by prob_zero_enrichment.
    sns.scatterplot(x='mean_', y='observed_fraction_zeros', hue='prob_zero_enrichment', data=adata.var,
                    palette='viridis', legend='brief', ax=ax)
    # Expected fraction of zeros as a reference curve.
    sns.lineplot(x='mean_', y='expected_fraction_zeros', data=adata.var, color='r',
                 label='Expected Fraction Zeros', ax=ax)
    # Highlight highly variable genes.
    sns.scatterplot(x='mean_', y='observed_fraction_zeros', data=highly_variable_df, color='pink',
                    marker='.', s=5, alpha=0.5, ax=ax)
    # Bounding box of the highly variable genes, labelled at its two corners;
    # skipped when no gene is selected (the box would be all NaN).
    if not highly_variable_df.empty:
        box_coords = highly_variable_df.agg({'mean_': ['min', 'max'], 'observed_fraction_zeros': ['min', 'max']})
        x0, y0 = box_coords.at['min', 'mean_'], box_coords.at['min', 'observed_fraction_zeros']
        x1, y1 = box_coords.at['max', 'mean_'], box_coords.at['max', 'observed_fraction_zeros']
        ax.add_patch(plt.Rectangle((x0, y0), x1 - x0, y1 - y0,
                                   fill=None, edgecolor='blue', linewidth=2, alpha=0.5))
        fmt = '.4f'
        # Padded label box so the text does not overlap the rectangle edge.
        bbox_props = dict(boxstyle="round,pad=0.3", fc="white", ec="white", lw=1, alpha=0.62)
        for mean_val, obs_frac_val in ((x0, y0), (x1, y1)):
            ax.text(mean_val, obs_frac_val, f'({mean_val:{fmt}},{obs_frac_val:{fmt}})', bbox=bbox_props)
    ax.set_xscale('log')
    ax.set_title(f'Mean vs Observed Fraction Zeros - {nfoDict["sub"]} {dataID}')
    # seaborn sets the point face colors directly and leaves no data array on
    # the collection, so a colorbar built from it falls back to a 0-1 scale.
    # Build a mappable with the same colormap over the hue's actual range.
    hue_vals = adata.var['prob_zero_enrichment']
    norm = matplotlib.colors.Normalize(vmin=hue_vals.min(), vmax=hue_vals.max())
    mappable = matplotlib.cm.ScalarMappable(norm=norm, cmap='viridis')
    mappable.set_array([])
    cbar = fig.colorbar(mappable, ax=ax, orientation='vertical', pad=0.1)
    cbar.set_label('Prob Zero Enrichment')
    plt.savefig(f"1F_GenesM3DropSelected_{nfoDict['sid']}_{dataID}_PlatformAsBatch.pdf",
                metadata={'Title': 'scvi.data.poisson_gene_selection', 'Subject': f"{nfoDict['sub']} Data",
                          'Author': 'HU Xuesong'})
    plt.close('all')

    print(f"[i]Begin Fig 1. 1F GenesDropoutHist (added) - {dataID}", file=sys.stderr)
    plt.figure(figsize=(6, 4))
    plt.title(f"Gene DropRatio Histogram - {nfoDict['sub']} {dataID}")
    histplot = sns.histplot(adata.var, x='observed_fraction_zeros', bins=30, kde=False,
                            hue='highly_variable', multiple="dodge", shrink=.8)
    # Bars in the second palette color, presumably the highly_variable=True
    # level, set the y-limit. The limit is left alone when none match, which
    # previously crashed on max([]).
    second_color = sns.color_palette()[1]
    bars_heights = [p.get_height() for p in histplot.patches if p.get_facecolor()[:3] == second_color]
    if bars_heights:
        plt.ylim(0, max(bars_heights) * 1.1)  # Adjust the margin as needed
    plt.savefig(f"1F_GenesDropoutHist_{nfoDict['sid']}_{dataID}_PlatformAsBatch.pdf",
                metadata={'Title': 'Gene DropRatio Histogram', 'Subject': f"{nfoDict['sub']} Data",
                          'Author': 'HU Xuesong'})
    plt.close('all')
def getOBSMdf(anndata, obsmkey='X_pca', n_components=2) -> pd.DataFrame:
    """Return the leading columns of ``anndata.obsm[obsmkey]`` as a DataFrame.

    If ``obsmkey`` is ``'X_pca'`` and not present yet, ``sc.tl.pca(anndata,
    zero_center=True)`` is run first, which stores the result on ``anndata``.

    Args:
        anndata: object with ``obsm`` (mapping) and ``obs_names``.
        obsmkey: embedding key in ``anndata.obsm``.
        n_components: number of leading columns to keep (default 2).

    Returns:
        DataFrame indexed by ``obs_names`` with columns ``P1``, ``P2``, ...

    Raises:
        KeyError: ``obsmkey`` is missing and is not ``'X_pca'``.
    """
    if obsmkey not in anndata.obsm:
        if obsmkey != 'X_pca':
            raise KeyError(f"{obsmkey!r} not found in obsm and cannot be computed here")
        sc.tl.pca(anndata, zero_center=True)
    # np.asarray also accepts obsm entries stored as DataFrames.
    data = np.asarray(anndata.obsm[obsmkey])[:, :n_components]
    return pd.DataFrame(
        data=data,
        index=list(anndata.obs_names[: data.shape[0]]),
        columns=[f'P{i + 1}' for i in range(data.shape[1])],
    )
if __name__ == "__main__":
    # Typical batch run over every configured sample:
    #   time (./nfig1.py human; ./nfig1.py mbrain ; ./nfig1.py mkidney ) | tee nplot.log
    main()
# Environment setup notes (shell commands, not Python):
#   [disabled] pip install -U --force-reinstall lightning
#   pip3 install git+https://github.com/matplotlib/mplcairo
#   pip3 install matplotlib_venn
# salus:
#   micromamba install "scvi-tools[version='>1']"  # Changing lightning-2.1.3-pyhd8ed1ab_1 ==> lightning-2.0.9.post0-pyhd8ed1ab_0