modified: pixi.toml
[GalaxyCodeBases.git] / BioInfo / BS-Seq / bwa-meth / bwameth.py
blobc1e7f60f50679655fac4b6c05cf8a2fbe6c91e90
1 #!/usr/bin/env python
2 """
Map bisulfite-converted reads to an in-silico converted genome using bwa mem.
A command to this program such as:

    python bwameth.py --reference ref.fa A.fq B.fq

gets converted to:

    bwa mem -pCMR ref.fa.bwameth.c2t '<python bwameth.py c2t A.fq B.fq'

so that A.fq has C's converted to T's and B.fq has G's converted to A's,
and both are streamed directly to the aligner without a temporary file.
The output is SAM written to stdout, with the original read sequences and
reference contig names restored.
"""
16 from __future__ import print_function
17 import tempfile
18 import sys
19 import os
20 import os.path as op
21 import argparse
22 from subprocess import check_call
23 from operator import itemgetter
24 from itertools import groupby, repeat, chain
25 import re
# Python 2/3 compatibility shims used below:
#   izip      -- lazy pairwise zip (itertools.izip on Python 2, zip on 3)
#   maketrans -- translation-table builder for .translate()
try:
    from itertools import izip  # only exists on Python 2
except ImportError:  # python3
    izip = zip
    maketrans = str.maketrans
else:
    import string
    maketrans = string.maketrans
34 from toolshed import nopen, reader, is_newer_b
# Version string reported by `--version` and in the bwa-meth @PG header line.
__version__ = "0.2.0"
def checkX(cmd):
    """
    Raise an Exception unless `cmd` is an executable regular file on $PATH.

    cmd: bare program name, e.g. 'bwa'.
    Raises: Exception("executable for '<cmd>' not found").
    """
    # os.pathsep is ":" on POSIX and ";" on Windows; an unset PATH is treated
    # as empty instead of raising KeyError.
    for p in os.environ.get('PATH', '').split(os.pathsep):
        candidate = op.join(p, cmd)
        # directories also pass os.access(X_OK) (they are searchable), so
        # require a regular file as well.
        if op.isfile(candidate) and os.access(candidate, os.X_OK):
            return
    raise Exception("executable for '%s' not found" % cmd)
# Fail fast at import time if a required external tool is not on $PATH.
# NOTE(review): samtools does not appear to be invoked anywhere else in this
# script -- confirm before dropping the check.
for _tool in ('samtools', 'bwa'):
    checkX(_tool)
del _tool
class BWAMethException(Exception):
    """Error raised by bwa-meth itself, e.g. aligning before `index` was run."""
def comp(s, _comp=maketrans('ATCG', 'TAGC')):
    """
    Complement upper-case DNA bases in `s` (A<->T, C<->G) without reversing.

    Every other character (N, lower case, ...) is returned unchanged.
    """
    complemented = s.translate(_comp)
    return complemented
def wrap(text, width=100):  # much faster than textwrap
    """
    Yield consecutive `width`-character slices of `text`.

    The last slice may be shorter. No word splitting is done, which is
    what fixed-width FASTA lines need.
    """
    try:
        starts = xrange(0, len(text), width)  # Python 2: no list built
    except NameError:  # Python 3
        starts = range(0, len(text), width)
    for begin in starts:
        yield text[begin:begin + width]
def run(cmd):
    """
    Run the shell command `cmd` through toolshed's nopen and drain its output.

    A leading "|" on `cmd` is optional. The output lines are discarded.
    Presumably nopen raises when the command fails -- TODO confirm.
    """
    pipe_spec = "|" + cmd.lstrip("|")
    for _ in nopen(pipe_spec):
        pass
def fasta_iter(fasta_name):
    """
    Yield (header, sequence) pairs from a FASTA file opened with nopen.

    header: text after ">" with surrounding whitespace stripped.
    sequence: the stripped lines up to the next header, joined and
        upper-cased. A header with no sequence lines yields "".
    Lines before the first header are ignored.
    """
    fh = nopen(fasta_name)
    header, chunks = None, []
    try:
        # A simple line scanner. It avoids calling next() inside this
        # generator, which becomes RuntimeError on Python 3.7+ (PEP 479) when
        # the last header has no sequence. It also keeps consecutive headers
        # as separate records.
        for line in fh:
            if line.startswith(">"):
                if header is not None:
                    yield header, "".join(chunks).upper()
                header, chunks = line[1:].strip(), []
            elif header is not None:
                chunks.append(line.strip())
        if header is not None:
            yield header, "".join(chunks).upper()
    finally:
        if fh is not sys.stdin and hasattr(fh, "close"):
            fh.close()
def convert_reads(fq1s, fq2s, out=sys.stdout):
    """
    Write in-silico bisulfite-converted FASTQ records to `out`.

    fq1s, fq2s: comma-separated lists of FASTQ files, matched by position.
        An fq2s entry of "NA" means the matching fq1s file is single-end.
    out: writable text stream (default stdout).

    Read 1 gets C->T, read 2 gets G->A. With two files the mates are written
    interleaved (read 1, then read 2). See _convert_fastq_record for the
    per-record rewrite.

    Returns 0. Calls sys.exit(1) on malformed input.
    """
    fq1_list = fq1s.split(",")
    fq2_list = fq2s.split(",")
    if len(fq1_list) != len(fq2_list):
        # zip() below would silently drop the unmatched files.
        sys.stderr.write("ERROR!!! got %d read-1 file(s) but %d read-2 "
                         "file(s)\n" % (len(fq1_list), len(fq2_list)))
        sys.exit(1)

    # counted over all inputs, so the warning below reflects every file.
    lt80 = 0
    for fq1, fq2 in zip(fq1_list, fq2_list):
        sys.stderr.write("converting reads in %s,%s\n" % (fq1, fq2))
        handles = [nopen(fq1)]
        if fq2 != "NA":
            handles.append(nopen(fq2))
            q2_iter = izip(*[handles[1]] * 4)
        else:
            sys.stderr.write("WARNING: running bwameth in single-end mode\n")
            q2_iter = repeat((None, None, None, None))
        # the same iterator repeated 4 times makes izip yield 4-line records.
        q1_iter = izip(*[handles[0]] * 4)

        try:
            for pair in izip(q1_iter, q2_iter):
                for read_i, record in enumerate(pair):
                    if record[0] is None:  # single-end: no mate record
                        continue
                    lt80 += _convert_fastq_record(record, read_i, out)
        finally:
            for fh in handles:
                if fh is not sys.stdin and hasattr(fh, "close"):
                    fh.close()
        out.flush()

    if lt80 > 50:
        sys.stderr.write("WARNING: %i reads with length < 80\n" % lt80)
        sys.stderr.write(" : this program is designed for long reads\n")
    return 0


def _convert_fastq_record(record, read_i, out):
    """
    Write one converted FASTQ record to `out`; return 1 if the read is < 80 bp.

    record: the (name, seq, plus, qual) lines as read from the file.
    read_i: 0 for read 1 (C->T conversion), 1 for read 2 (G->A).

    The name is cut at the first space and any "_R1"/"_R2" or "/1"/"/2"
    suffix is removed. The upper-cased original sequence is appended to the
    name line as a "YS:Z:" tag, and the conversion as a "YC:Z:" tag.
    """
    name, seq, _, qual = record
    name = name.rstrip("\r\n").split(" ")[0]
    if not name.startswith("@"):
        sys.stderr.write("""ERROR!!!!
ERROR!!! FASTQ conversion failed
ERROR!!! expecting FASTQ 4-tuples, but found a record %s that doesn't start with "@"
""" % name)
        sys.exit(1)
    # drop mate suffixes so both mates end up with the same name.
    if name.endswith(("_R1", "_R2")):
        name = name[:-3]
    elif name.endswith(("/1", "/2")):
        name = name[:-2]

    # strip "\r" too, so CRLF input doesn't leak into SEQ or the YS:Z tag.
    seq = seq.upper().rstrip("\r\n")
    char_a, char_b = ['CT', 'GA'][read_i]
    # keep original sequence as name.
    name = " ".join((name,
                     "YS:Z:" + seq +
                     "\tYC:Z:" + char_a + char_b + '\n'))
    # always terminate the quality line, even when the file lacks a final
    # newline; otherwise the next record would be glued onto it.
    out.write("".join((name, seq.replace(char_a, char_b), "\n+\n",
                       qual.rstrip("\r\n"), "\n")))
    return 1 if len(seq) < 80 else 0
def convert_fasta(ref_fasta, just_name=False):
    """
    Write the in-silico converted reference for `ref_fasta` and return its path.

    The output path is ref_fasta + ".bwameth.c2t". Each input record is
    written twice, in 100-column lines:
      ">r<header>" with every G replaced by A, and
      ">f<header>" with every C replaced by T.

    just_name: only return the output path, without converting.
    Conversion is skipped when is_newer_b(ref_fasta, out_fa) reports the
    output as already up to date. On any failure (including Ctrl-C) the
    partial output file is removed and the error re-raised.
    """
    out_fa = ref_fasta + ".bwameth.c2t"
    if just_name:
        return out_fa
    msg = "c2t in %s to %s" % (ref_fasta, out_fa)
    if is_newer_b(ref_fasta, out_fa):
        sys.stderr.write("already converted: %s\n" % msg)
        return out_fa
    sys.stderr.write("converting %s\n" % msg)
    try:
        with open(out_fa, "w") as fh:
            for header, seq in fasta_iter(ref_fasta):
                # (an earlier, disabled variant converted only non-CpG Gs/Cs)
                ########### Reverse ######################
                fh.write(">r%s\n" % header)
                for line in wrap(seq.replace("G", "A")):
                    fh.write(line + '\n')
                ########### Forward ######################
                fh.write(">f%s\n" % header)
                for line in wrap(seq.replace("C", "T")):
                    fh.write(line + '\n')
    except BaseException:
        # Remove partial output so it is not mistaken for a finished
        # conversion. Only unlink if it exists; otherwise a failed open()
        # would be masked by an unlink error.
        if op.exists(out_fa):
            os.unlink(out_fa)
        raise
    return out_fa
def bwa_index(fa):
    """
    Build a bwa index of `fa` (bwtsw algorithm) unless one is already current.

    The index is considered current when is_newer_b(fa, (fa.amb, fa.sa)) is
    true. If indexing fails (including Ctrl-C), fa.amb is removed so the
    check above fails on the next run, and the error is re-raised.
    Returns None.
    """
    if is_newer_b(fa, (fa + '.amb', fa + '.sa')):
        return
    sys.stderr.write("indexing: %s\n" % fa)
    amb = fa + ".amb"
    try:
        run("bwa index -a bwtsw %s" % fa)
    except BaseException:
        if op.exists(amb):
            os.unlink(amb)
        raise
class Bam(object):
    """
    A single SAM alignment line, split into fields.

    The first 11 values fill the mandatory SAM columns, stored as: read,
    flag, chrom, pos, mapq, cigar, chrom_mate, pos_mate, tlen, seq, qual.
    Any further fields (optional tags) are kept as strings in `other`.
    `flag`, `pos` and `tlen` are converted to int; the rest stay as given.
    """
    __slots__ = ('read', 'flag', 'chrom', 'pos', 'mapq', 'cigar',
                 'chrom_mate', 'pos_mate', 'tlen', 'seq', 'qual', 'other')

    # CIGAR operations that consume reference bases (SAM spec). "P" (padding)
    # consumes neither query nor reference, so it is not counted.
    _REF_CONSUMING_OPS = ("M", "D", "N", "=", "X")

    def __init__(self, args):
        for a, v in zip(self.__slots__[:11], args):
            setattr(self, a, v)
        self.other = args[11:]
        self.flag = int(self.flag)
        self.pos = int(self.pos)
        # int(float(...)) also accepts a TLEN written like "12.0".
        self.tlen = int(float(self.tlen))

    def __repr__(self):
        return "Bam({chr}:{start}:{read})".format(chr=self.chrom,
                                                  start=self.pos,
                                                  read=self.read)

    def __str__(self):
        """Return the record as one tab-separated SAM line (no newline)."""
        # Join everything at once so an empty `other` leaves no trailing tab.
        fields = [str(getattr(self, s)) for s in self.__slots__[:11]]
        fields.extend(self.other)
        return "\t".join(fields)

    def is_first_read(self):
        return bool(self.flag & 0x40)

    def is_second_read(self):
        return bool(self.flag & 0x80)

    def is_plus_read(self):
        return not (self.flag & 0x10)

    def is_minus_read(self):
        return bool(self.flag & 0x10)

    def is_mapped(self):
        return not (self.flag & 0x4)

    def cigs(self):
        """
        Yield (length, op) pairs for the CIGAR string.

        An unavailable CIGAR ("*") yields the single pair (0, None).
        """
        if self.cigar == "*":
            yield (0, None)
            # plain return: `raise StopIteration` inside a generator becomes
            # RuntimeError on Python 3.7+ (PEP 479).
            return
        cig_iter = groupby(self.cigar, lambda c: c.isdigit())
        for _, digits in cig_iter:
            yield int("".join(digits)), "".join(next(cig_iter)[1])

    def cig_len(self):
        """Number of reference bases covered by the alignment."""
        return sum(n for n, op in self.cigs()
                   if op in self._REF_CONSUMING_OPS)

    def left_shift(self):
        """Total length of H (hard-clip) ops before the first M op."""
        left = 0
        for n, cig in self.cigs():
            if cig == "M": break
            if cig == "H":
                left += n
        return left

    def right_shift(self):
        """
        Negative total length of H ops after the last M op, or None if 0.

        The result can be used directly as a slice end.
        """
        right = 0
        for n, cig in reversed(list(self.cigs())):
            if cig == "M": break
            if cig == "H":
                right += n
        return -right or None

    @property
    def original_seq(self):
        """Value of the YS:Z tag in `other` (the pre-conversion read)."""
        try:
            return next(x for x in self.other if x.startswith("YS:Z:"))[5:]
        except StopIteration:
            # tag missing: dump context for debugging, then re-raise.
            sys.stderr.write(repr(self.other) + "\n")
            sys.stderr.write(self.read + "\n")
            raise

    @property
    def ga_ct(self):
        """All YC:Z tags in `other`."""
        return [x for x in self.other if x.startswith("YC:Z:")]

    def longest_match(self, patt=re.compile(r"\d+M")):
        """Length of the longest M run in the CIGAR; 0 if there is none."""
        return max([int(x[:-1]) for x in patt.findall(self.cigar)] or [0])
def rname(fq1, fq2=""):
    """
    Derive a read-group name from the FASTQ file name(s).

    Only the first entry of each comma-separated list is used. Each name is
    reduced to its basename without its last extension, then without a
    trailing ".fastq", then without one of ".fq", ".r1", ".r2".
    For a pair, the characters at positions where both names agree are kept.
    With no second file, the single name is used.
    Falls back to 'bm' when the result is empty.
    """
    def name(f):
        n = op.basename(op.splitext(f)[0])
        if n.endswith('.fastq'): n = n[:-6]
        if n.endswith(('.fq', '.r1', '.r2')): n = n[:-3]
        return n

    first1, first2 = fq1.split(",")[0], fq2.split(",")[0]
    if not first2:
        # Single-end: zipping against an empty name would always give 'bm'.
        return name(first1) or 'bm'
    n1, n2 = name(first1), name(first2)
    return "".join(a for a, b in zip(n1, n2) if a == b) or 'bm'
def bwa_mem(fa, mfq, extra_args, threads=1, rg=None,
            paired=True, set_as_failed=None):
    """
    Align reads to the converted reference with bwa mem and post-process.

    fa: original reference fasta; its ".bwameth.c2t" copy must be indexed.
    mfq: reads argument handed to bwa verbatim (e.g. from convert_fqs).
    extra_args: string of extra bwa mem options, inserted verbatim.
    threads: value for bwa -t.
    rg: read-group ID or full "@RG..." line. A bare ID becomes
        "@RG\\tID:<rg>\\tSM:<rg>". When None, no -R option is passed.
    paired: add "-U 100 -p" (interleaved pairs, unpaired penalty).
    set_as_failed: passed through to as_bam.

    Raises BWAMethException if the converted reference is not indexed.
    """
    conv_fa = convert_fasta(fa, just_name=True)
    if not is_newer_b(conv_fa, (conv_fa + '.amb', conv_fa + '.sa')):
        raise BWAMethException("first run bwameth.py index %s" % fa)

    # penalize clipping and unpaired. lower penalty on mismatches (-B)
    cmd = "|bwa mem -T 40 -B 2 -L 10 -CM "
    if paired:
        cmd += "-U 100 -p "
    if rg is not None:
        if not rg.startswith('@RG'):
            rg = '@RG\\tID:{rg}\\tSM:{rg}'.format(rg=rg)
        # bwa -R takes backslash-t escaped separators (as the --read-group
        # help says). Newer bwa rejects a line containing literal tabs, so
        # escape any that a caller passed in.
        cmd += "-R '%s' " % rg.replace("\t", "\\t")
    cmd += "-t {threads} {extra_args} {conv_fa} {mfq}".format(
        threads=threads, extra_args=extra_args, conv_fa=conv_fa, mfq=mfq)
    sys.stderr.write("running: %s\n" % cmd.lstrip("|"))
    as_bam(cmd, fa, set_as_failed)
def as_bam(pfile, fa, set_as_failed=None):
    """
    Post-process SAM from the converted-reference alignment.

    Despite the name, the output is SAM, written to stdout.

    pfile: either a file or a |process to generate sam output
    fa: the reference fasta (not used in this function)
    set_as_failed: None, 'f', or 'r'. If 'f' or 'r', reads mapping to that
        strand are given the sam flag of a failed QC alignment (0x200).

    Header lines go through handle_header. Records are grouped by read name
    and passed through handle_reads. Raises Exception("bad or empty fastqs")
    when the input contains no alignment records.
    """
    sam_lines = nopen(pfile)
    first_record = None
    for line in sam_lines:
        if line[0] != "@":
            first_record = line
            break
        handle_header(line)
    if first_record is None:
        sys.stderr.flush()
        raise Exception("bad or empty fastqs")

    split_lines = (rec.rstrip().split("\t")
                   for rec in chain([first_record], sam_lines))
    # consecutive lines sharing a QNAME form one read / read-pair group.
    for _qname, group in groupby(split_lines, itemgetter(0)):
        alns = [Bam(fields) for fields in group]
        for aln in handle_reads(alns, set_as_failed):
            sys.stdout.write(str(aln) + '\n')
def handle_header(line, out=sys.stdout):
    """
    Rewrite one SAM header line from the converted reference and write it to `out`.

    @SQ: the converted reference holds each contig twice, prefixed "f" and
        "r". "r" entries are dropped and the "f" prefix is removed from SN.
        Any other fields on the line are kept.
    @PG: replaced by a bwa-meth @PG line recording sys.argv.
    Any other line is written unchanged (trailing whitespace stripped).
    """
    toks = line.rstrip().split("\t")
    if toks[0].startswith("@SQ"):
        # e.g. @SQ SN:fchr11 LN:122082543 [AH:* ...]
        for i in range(1, len(toks)):
            if toks[i].startswith("SN:"):
                # Take everything after "SN:". Contig names may themselves
                # contain ":" (e.g. HLA-A*01:01:01:01).
                chrom = toks[i][3:]
                # we have f and r, only print out f
                if chrom.startswith('r'):
                    return
                toks[i] = "SN:" + chrom[1:]
                break
    elif toks[0].startswith("@PG"):
        toks = ["@PG\tID:bwa-meth\tPN:bwa-meth\tVN:%s\tCL:\"%s\"" % (
            __version__,
            " ".join(x.replace("\t", "\\t") for x in sys.argv))]
    out.write("\t".join(toks) + "\n")
def handle_reads(alns, set_as_failed):
    """
    Map the alignments of one read (or read pair) back to the original
    reference and read sequences. Modifies the Bam objects in place and
    returns the same list.

    For each alignment:
      - drop the YS:Z tag and restore the original read into SEQ. Unmapped
        reads get it as-is. Mapped reads get it trimmed by the CIGAR's hard
        clips, and reverse-complemented when on the minus strand.
      - remove the one-character "f"/"r" prefix from RNAME (and from RNEXT
        when it names a contig). Mapped reads get a YD:Z:f / YD:Z:r tag.
      - flag a mapped read as QC-failed (0x200) if it hit the strand named by
        `set_as_failed`.
      - flag a mapped read as QC-failed if its longest M run is shorter than
        44% of the read. Such reads also lose the proper-pair flag (0x2) and
        get MAPQ capped at 1.
    If any alignment in `alns` ends up with 0x200, all of them get 0x200 and
    lose 0x2.
    """
    for aln in alns:
        orig_seq = aln.original_seq
        assert len(aln.seq) == len(aln.qual), aln.read
        # don't need this any more.
        aln.other = [x for x in aln.other if not x.startswith('YS:Z')]

        # First letter of chrom is 'f' or 'r'. Remove exactly that one
        # character: lstrip('fr') would also eat leading f/r letters that
        # belong to the contig name (e.g. "ffrag1" -> "ag1").
        direction = aln.chrom[0]
        if direction in ('f', 'r'):
            aln.chrom = aln.chrom[1:]

        if not aln.is_mapped():
            aln.seq = orig_seq
            continue

        assert direction in 'fr', (direction, aln)
        aln.other.append('YD:Z:' + direction)

        if set_as_failed == direction:
            aln.flag |= 0x200

        # Heuristic: if the longest match is shorter than 44% of the read,
        # mark it failed QC, clear its proper-pair flag and cap MAPQ at 1.
        # The QC failure is spread to the whole group after the loop.
        if aln.longest_match() < (len(orig_seq) * 0.44):
            aln.flag |= 0x200  # fail qc
            aln.flag &= (~0x2)  # un-pair
            aln.mapq = min(int(aln.mapq), 1)

        mate_direction = aln.chrom_mate[0]
        if mate_direction not in "*=":
            aln.chrom_mate = aln.chrom_mate[1:]

        # Adjust the original seq to the cigar: drop hard-clipped bases, and
        # reverse-complement it for minus-strand alignments.
        l, r = aln.left_shift(), aln.right_shift()
        if aln.is_plus_read():
            aln.seq = orig_seq[l:r]
        else:
            aln.seq = comp(orig_seq[::-1][l:r])

    # If any member failed QC, flag every member as failed QC and not
    # properly paired (their mapped status is left unchanged).
    if any(aln.flag & 0x200 for aln in alns):
        for aln in alns:
            aln.flag |= 0x200
            aln.flag &= (~0x2)
    return alns
def cnvs_main(args):
    """
    Call CNVs from BS-Seq BAMs with the R package cn.mops.

    args: the command-line arguments after "cnvs":
      --regions  optional tab-delimited file whose first three columns are
                 chrom, start, end ("NA", the default, means whole genome)
      bams       one or more BAM paths

    Writes an R script to a temporary file, runs it with Rscript (which needs
    cn.mops and snow installed), and prints its tab-separated CNV table to
    stdout.
    """
    # description=, not the first positional argument (which is `prog`).
    p = argparse.ArgumentParser(
        description="calculate CNVs from BS-Seq bams")
    p.add_argument("--regions", help="optional target regions", default='NA')
    p.add_argument("bams", nargs="+")

    a = p.parse_args(args)
    r_script = """
options(stringsAsFactors=FALSE)
suppressPackageStartupMessages(library(cn.mops))
suppressPackageStartupMessages(library(snow))
args = commandArgs(TRUE)
regions = args[1]
bams = args[2:length(args)]
n = length(bams)
# the Python side passes the literal string "NA" when no regions are given,
# which is.na() alone does not detect.
if(is.na(regions) || regions == "NA"){
    bam_counts = getReadCountsFromBAM(bams, parallel=min(n, 4), mode="paired")
    res = cn.mops(bam_counts, parallel=min(n, 4), priorImpact=20)
} else {
    segments = read.delim(regions, header=FALSE)
    gr = GRanges(segments[,1], IRanges(segments[,2], segments[,3]))
    bam_counts = getSegmentReadCountsFromBAM(bams, GR=gr, mode="paired", parallel=min(n, 4))
    res = exomecn.mops(bam_counts, parallel=min(n, 4), priorImpact=20)
}
res = calcIntegerCopyNumbers(res)

df = as.data.frame(cnvs(res))
write.table(df, row.names=FALSE, quote=FALSE, sep="\t")
"""
    # Text mode: the default "w+b" would reject a str on Python 3.
    with tempfile.NamedTemporaryFile(mode="w", delete=True) as rfh:
        rfh.write(r_script + '\n')
        rfh.flush()
        for d in reader('|Rscript {rs_name} {regions} {bams}'.format(
                rs_name=rfh.name, regions=a.regions, bams=" ".join(a.bams)),
                header=False):
            print("\t".join(d))
def convert_fqs(fqs):
    """
    Build the quoted reads argument that makes bwa stream converted reads.

    fqs: one or two comma-separated FASTQ lists (read 1 [, read 2]).
    Returns "'<PYTHON THIS_SCRIPT c2t R1 R2'". Without a second list, one
    "NA" per read-1 file is supplied (single-end).
    """
    if len(fqs) > 1:
        second = fqs[1]
    else:
        second = ",".join("NA" for _ in fqs[0].split(","))
    return "'<%s %s c2t %s %s'" % (sys.executable, __file__, fqs[0], second)
def main(args=sys.argv[1:]):
    """
    Command-line entry point.

    Sub-commands (first argument):
      index REF.fa   convert REF.fa and build the bwa index of the copy
      c2t R1 R2      stream converted reads (the command convert_fqs hands
                     to bwa)
      cnvs ...       call CNVs (see cnvs_main)
    Otherwise align the given FASTQs against --reference and write SAM to
    stdout. Options not recognized here are passed through to bwa mem.
    """
    if len(args) > 0 and args[0] == "index":
        if len(args) != 2:
            # an assert would vanish under `python -O`; exit with a message.
            sys.exit("must specify fasta as 2nd argument")
        sys.exit(bwa_index(convert_fasta(args[1])))

    if len(args) > 0 and args[0] == "c2t":
        sys.exit(convert_reads(args[1], args[2]))

    if len(args) > 0 and args[0] == "cnvs":
        sys.exit(cnvs_main(args[1:]))

    # __doc__ here is the module docstring. Passed positionally it would
    # become the program name, so give it as the description instead.
    p = argparse.ArgumentParser(
        description=__doc__,
        formatter_class=argparse.RawDescriptionHelpFormatter)
    p.add_argument("--reference", help="reference fasta", required=True)
    p.add_argument("-t", "--threads", type=int, default=6)
    p.add_argument("--read-group", help="read-group to add to bam in same"
                   " format as to bwa: '@RG\\tID:foo\\tSM:bar'")
    p.add_argument('--set-as-failed', help="flag alignments to this strand"
                   " as not passing QC (0x200). Targeted BS-Seq libraries are often"
                   " to a single strand, so we can flag them as QC failures. Note"
                   " f == OT, r == OB. Likely, this will be 'f' as we will expect"
                   " reads to align to the original-bottom (OB) strand and will flag"
                   " as failed those aligning to the forward, or original top (OT).",
                   default=None, choices=('f', 'r'))
    p.add_argument('--version', action='version',
                   version='bwa-meth.py {}'.format(__version__))

    p.add_argument("fastqs", nargs="+", help="bs-seq fastqs to align. Run"
                   " multiple sets separated by commas, e.g. ... a_R1.fastq,b_R1.fastq"
                   " a_R2.fastq,b_R2.fastq note that the order must be maintained.")

    args, pass_through_args = p.parse_known_args(args)
    if len(args.fastqs) > 2:
        # rname()/convert_fqs() only handle one or two lists; fail cleanly
        # instead of with a TypeError later.
        p.error("expected one (single-end) or two (paired-end) fastq "
                "arguments, got %d; join multiple files with commas"
                % len(args.fastqs))

    # Read 1 is converted C => T and read 2 G => A on the fly; bwa reads the
    # conversion command's output directly (no temporary files).
    conv_fqs_cmd = convert_fqs(args.fastqs)

    bwa_mem(args.reference, conv_fqs_cmd, ' '.join(map(str, pass_through_args)),
            threads=args.threads,
            rg=args.read_group or rname(*args.fastqs),
            paired=len(args.fastqs) == 2,
            set_as_failed=args.set_as_failed)
if __name__ == "__main__":
    # Executed as a script (not imported): dispatch the CLI arguments.
    main(args=sys.argv[1:])