include: "Common.snake"

configfile: "config.yaml"

import os.path

# Workflow start hook: make sure the scratch directory `tmp` exists
# (`mkdir -p` does not fail if it is already there). Several rules below
# write their intermediate files into `tmp/`.
onstart:
    shell("mkdir -p tmp")

def final_stage(w):
    """Select the final target(s) of the workflow from the enabled stages.

    The stages are checked from the latest to the earliest one:
    reassembly, then propagation, then the preliminary binning.

    w -- wildcards object passed by Snakemake to input functions (unused).

    Returns a list of two targets when reassembly is enabled, otherwise a
    single target path as a string.
    """
    def stage_enabled(stage):
        return config[stage]["enabled"]

    if stage_enabled("reassembly"):
        # Stop after the bin choosing
        targets = ["propagation.done", "binning/bins_total.prof"]
    elif stage_enabled("propagation"):
        # Stop on the propagation
        targets = "propagation.done"
    else:
        # Stop on the preliminary binning
        targets = "binning/{binner}/binning.done".format(binner=BINNER)
    return targets

# Default target: requests whatever final_stage() selects for the stages
# enabled in the configuration.
rule all:
    input:
        final_stage
    message:
        "Dataset of {SAMPLE_COUNT} samples from {IN} has been processed."

# ---- Assembly ----------------------------------------------------------------

# Assemble with MegaHIT
# The left/right read lists of a group are comma-joined and passed to MegaHIT
# via -1/-2; the resulting final.contigs.fa is copied to the rule output.
# The output directory is removed first (`rm -rf`) before MegaHIT is started.
rule megahit:
    input:   left=left_reads, right=right_reads
    output:  "assembly/megahit/{group}.fasta"
    params:  left=lambda w: ",".join(left_reads(w)),
             right=lambda w: ",".join(right_reads(w)),
             dir="assembly/megahit/{group}"
    threads: THREADS
    # The log must carry the same wildcards as the output: otherwise Snakemake
    # rejects the rule, and jobs for different groups would share one log file.
    # It is a sibling of {params.dir}, so the `rm -rf` below does not remove it.
    log:     "assembly/megahit/{group}.log"
    message: "Assembling {wildcards.group} with MegaHIT"
    shell:   "rm -rf {params.dir} &&"
             " {SOFT}/megahit/megahit -1 {params.left} -2 {params.right}"
             " -t {threads} -o {params.dir} >{log} 2>&1 &&"
             " cp {params.dir}/final.contigs.fa {output}"

# Assemble with SPAdes
# Runs spades.py in --meta mode on all left/right reads of a group (each read
# file gets its own -1/-2 flag) and copies scaffolds.fasta to the rule output.
# --only-assembler is added when is_fastq() is false for the group.
# --save-gp is passed because prop_binning later reads saves from this
# output directory.
rule spades:
    input:
        left=left_reads,
        right=right_reads
    output:
        "assembly/spades/{group}.fasta"
    params:
        left=lambda w: " ".join(expand("-1 {r}", r=left_reads(w))),
        right=lambda w: " ".join(expand("-2 {r}", r=right_reads(w))),
        dir="assembly/spades/{group}",
        bh=lambda w: "" if is_fastq(w) else "--only-assembler"
    log:
        "assembly/{group}.log"
    threads:
        THREADS
    message:
        "Assembling {wildcards.group} with metaSPAdes"
    shell:
        "{ASSEMBLER_DIR}/spades.py {params.bh} --meta -m 400 -t {threads} "
        "{params.left} {params.right} "
        "--save-gp -o {params.dir} >{log} 2>&1 "
        "&& cp {params.dir}/scaffolds.fasta {output}"

# Copy the per-group assembly produced by the configured ASSEMBLER into
# assembly/full/.
# The wildcard constraint limits the rule to sample/group names so that it
# does not claim other files in that directory. It accepts "sample",
# "sample<N>" and "group<N>": the numbered sample names are the ones
# split_contigs accepts as well, and the bare "sample" is kept for backward
# compatibility. A raw string is used because "\d" is not a valid escape
# sequence in an ordinary Python string literal.
rule copy_contigs:
    input:   "assembly/{}/{{group}}.fasta".format(ASSEMBLER)
    output:  r"assembly/full/{group,(sample\d*|group\d+)}.fasta"
    shell:   "cp {input} {output}"

# Cut the per-group assembly of the configured ASSEMBLER into splits of
# SPLIT_LENGTH bp with cut_fasta.py and store them in assembly/splits/.
# The wildcard constraint limits the rule to names of the form "sample<N>" or
# "group<N>". The pattern is a raw string because "\d" is not a valid escape
# sequence in an ordinary Python string literal; its value is unchanged.
rule split_contigs:
    input:   "assembly/{}/{{group}}.fasta".format(ASSEMBLER)
    output:  r"assembly/splits/{group,(sample|group)\d+}.fasta"
    message: "Cutting {wildcards.group} into {SPLIT_LENGTH} bp splits"
    shell:   "{SCRIPTS}/cut_fasta.py -c {SPLIT_LENGTH} -o 0 -m {input} > {output}"

#---- Generating profiles/depths -----------------------------------------------

# MetaBAT way

# Build a bowtie2 index for assembly/{frags}/all.fasta. The index files use
# the prefix profile/jgi/index_{frags}; the touched index.done flag is the
# tracked output of the rule.
rule bowtie_index:
    input:
        "assembly/{frags}/all.fasta"
    output:
        "profile/jgi/{frags}/index.done"
    log:
        "profile/jgi/{frags}/bowtie-build.log"
    message:
        "Building bowtie index"
    shell:
        "bowtie2-build {input} profile/jgi/index_{wildcards.frags} >{log} 2>&1 "
        "&& touch {output}"

# Align the reads of one sample against the bowtie2 index of {frags} and
# convert the alignment stream to BAM with samtools. Only bowtie2's stderr
# goes to the log; its stdout is the data piped into samtools.
rule align:
    input:
        left=left_sample_reads,
        right=right_sample_reads,
        index="profile/jgi/{frags}/index.done"
    output:
        "profile/jgi/{frags}/{sample}.bam"
    log:
        "profile/jgi/{frags}/bowtie-{sample}.log"
    threads:
        THREADS
    message:
        "Aligning {wildcards.sample} with bowtie"
    shell:
        "bowtie2 -x profile/jgi/index_{wildcards.frags} -p {threads} "
        "-1 {input.left} -2 {input.right} 2>{log} "
        "| samtools view -bS - > {output}"

# Summarize the BAM files of all SAMPLES for one {frags} set into a single
# depth table with MetaBAT's jgi_summarize_bam_contig_depths.
rule depth:
    input:
        expand("profile/jgi/{{frags}}/{sample}.bam", sample=SAMPLES)
    output:
        "profile/jgi/{frags}/depth_metabat.txt"
    log:
        "profile/jgi/{frags}/depths.log"
    message:
        "Calculating contig depths"
    shell:
        "{SOFT}/metabat/jgi_summarize_bam_contig_depths "
        "--outputDepth {output} {input} >{log} 2>&1"

# Convert the depth table of the splits into the CONCOCT input profile.
# The awk program skips the first line (NR > 1) and, for every other line,
# prints field 1 and every even-numbered field starting from field 4,
# separated by tabs. A newline is emitted instead of a tab when the printed
# field is the last or the second-to-last one of the line.
# NOTE(review): presumably the skipped odd fields are the per-sample variance
# columns of the jgi depth format -- confirm against the tool's output.
# Doubled braces are Snakemake escapes for literal awk braces.
rule concoct_depth:
    input:   "profile/jgi/splits/depth_metabat.txt"
    output:  "binning/concoct/profiles_jgi.in"
    message: "Converting depth file into CONCOCT format"
    shell:   "awk 'NR > 1 {{for(x=1;x<=NF;x++) if(x == 1 || (x >= 4 && x % 2 == 0)) printf \"%s\", $x (x == NF || x == (NF-1) ? \"\\n\":\"\\t\")}}' {input} > {output}"

# Our way
# Count k-mers of one sample with KMC.
# The read paths are written to a description file (left reads on the first
# line, right reads on the second), which is passed to kmc as @file. A
# per-sample scratch directory is created before the run and removed only
# after a successful run. The reads are given as params, not as inputs.
rule kmc:
    output:
        temp("tmp/{sample}.kmc_pre"),
        temp("tmp/{sample}.kmc_suf")
    params:
        min_mult=2,
        tmp="tmp/{sample}_kmc",
        out="tmp/{sample}",
        desc="profile/{sample}.desc",
        left=left_sample_reads,
        right=right_sample_reads,
        format=lambda w: "-fq" if is_fastq(w) else "-fa"
    log:
        "profile/kmc_{sample}.log"
    threads:
        THREADS
    message:
        "Running kmc for {wildcards.sample}"
    shell:
        "mkdir -p {params.tmp}\n"
        "echo '{params.left}\n{params.right}' > {params.desc}\n"
        "{SOFT}/kmc {params.format} -k{PROFILE_K} -t{threads} -ci{params.min_mult} "
        "-cs65535 @{params.desc} {params.out} {params.tmp} >{log} 2>&1 "
        "&& rm -rf {params.tmp}"

# Gather the k-mer multiplicities of all samples from the KMC databases in
# tmp/ into profile/mts/kmers.kmm, then remove the tmp/*.sorted files.
rule multiplicities:
    input:   expand("tmp/{sample}.kmc_pre", sample=SAMPLES), expand("tmp/{sample}.kmc_suf", sample=SAMPLES)
    output:  "profile/mts/kmers.kmm"
    params:  kmc_files=" ".join(expand("tmp/{sample}", sample=SAMPLES)), out="profile/mts/kmers"
    # Without a `threads` directive {threads} always expands to 1, so the
    # counter would run single-threaded regardless of the configured THREADS.
    threads: THREADS
    log:     "profile/mts/kmers.log"
    message: "Gathering {PROFILE_K}-mer multiplicities from all samples"
    # `rm -f` keeps the cleanup best-effort: a plain `rm` fails the whole rule
    # when the glob matches nothing, although the result is already complete.
    shell:   "{BIN}/kmer_multiplicity_counter -n {SAMPLE_COUNT} -k {PROFILE_K} -s 2"
             " -f tmp -t {threads} -o {params.out} >{log} 2>&1 && "
             "rm -f tmp/*.sorted"

# Compute the abundance profile of the splits of one group from the gathered
# k-mer multiplicities. The .kmm file is declared as an input for dependency
# tracking; the tool itself is given the prefix profile/mts/kmers.
rule abundancies:
    input:
        contigs="assembly/splits/{group}.fasta",
        mpl="profile/mts/kmers.kmm"
    output:
        "profile/mts/{group}.tsv"
    log:
        "profile/mts/{group}.log"
    message:
        "Counting contig abundancies for {wildcards.group}"
    shell:
        "{BIN}/contig_abundance_counter -k {PROFILE_K} -w tmp -c {input.contigs} "
        "-n {SAMPLE_COUNT} -m profile/mts/kmers -o {output} "
        "-l {MIN_CONTIG_LENGTH} >{log} 2>&1"

# Concatenate the per-group profiles into one table, prefixing every line
# with "<group>-", where <group> is the profile file name without directory
# and extension.
rule combine_profiles:
    input:
        expand("profile/mts/{group}.tsv", group=GROUPS)
    output:
        "profile/mts/all.tsv"
    message:
        "Combine all profiles"
    run:
        # Start from a clean file because the loop below only appends.
        shell("rm -f {output}")
        for sample_ann in input:
            sample = os.path.splitext(os.path.basename(sample_ann))[0]
            shell("sed -e 's/^/{sample}-/' {sample_ann} >> {output}")

# Convert the combined profile table into the input format of the requested
# binner with make_input.py.
# NOTE(review): the positional param (space-joined group names) is not used
# by the shell command; kept in case it is relied on elsewhere -- confirm.
rule binning_pre:
    input:   "profile/mts/all.tsv"
    output:  "binning/{binner}/profiles_mts.in"
    params:  " ".join(list(GROUPS.keys()))
    # The log must carry the same wildcards as the output: otherwise Snakemake
    # rejects the rule, and jobs for different binners would share one log.
    log:     "binning/{binner}/input.log"
    message: "Preparing input for {wildcards.binner}"
    shell:   "{SCRIPTS}/make_input.py -n {SAMPLE_COUNT} -t {wildcards.binner}"
             " -o {output} {input} >{log}"

# Keep only the splits that have a profile: the first column of the binner
# input (minus its first line) is written to tmp/names_tmp.txt and used as
# the name list for contig_name_filter.py.
rule filter_contigs:
    input:
        contigs="assembly/splits/all.fasta",
        profile="binning/{}/profiles_mts.in".format(BINNER)
    output:
        contigs="assembly/splits/all_filtered.fasta"
    message:
        "Leave contigs that have profile information"
    shell:
        "cut -f1 < {input.profile} > tmp/names_tmp.txt "
        "&& sed -i '1d' tmp/names_tmp.txt "
        "&& {SCRIPTS}/contig_name_filter.py {input.contigs} tmp/names_tmp.txt {output.contigs}"

#---- Binning ------------------------------------------------------------------

# Binning with Canopy
# Runs cc.bin on the profiles of the configured PROFILER; the bin profiles it
# writes are copied to bins.prof with every "CAG" replaced by "BIN".
rule canopy:
    input:
        "binning/canopy/profiles_{}.in".format(PROFILER)
    output:
        out="binning/canopy/binning.out",
        prof="binning/canopy/bins.prof",
        flag=touch("binning/canopy/binning.done")
    log:
        "binning/canopy.log"
    threads:
        THREADS
    message:
        "Running canopy clustering"
    shell:
        "{SOFT}/cc.bin --filter_max_dominant_obs 1 -n {threads} "
        "-i {input} -o {output.out} -c binning/canopy/canopy_bins.prof >{log} 2>&1 "
        "&& sed 's/CAG/BIN/g' binning/canopy/canopy_bins.prof >{output.prof}"

# Binning with CONCOCT
# Activates the py27 environment (with `set +u` around the activation), runs
# concoct on the splits and their profiles, and copies the clustering table
# for contigs above MIN_CONTIG_LENGTH to the rule output.
rule concoct:
    input:
        contigs="assembly/splits/all.fasta",
        profiles="binning/concoct/profiles_{}.in".format(PROFILER)
    output:
        "binning/concoct/binning.out"
    params:
        max_clusters=40,
        out="binning/concoct"
    log:
        "binning/concoct.log"
    threads:
        THREADS
    message:
        "Running CONCOCT clustering"
    shell:
        "set +u; source activate py27; set -u\n"
        "concoct -c {params.max_clusters} --composition_file {input.contigs} "
        "--coverage_file {input.profiles} --length_threshold {MIN_CONTIG_LENGTH} "
        "-b {params.out} >{log} 2>&1 "
        "&& cp binning/concoct/clustering_gt{MIN_CONTIG_LENGTH}.csv {output}"

# Split all.fasta into per-bin files in binning/bins according to all.ann;
# the touched binning.done flag marks the CONCOCT binning as finished.
rule extract_bins:
    input:
        "assembly/splits/all.fasta",
        "binning/annotation/all.ann"
    output:
        touch("binning/concoct/binning.done")
    message:
        "Extracting CONCOCT bins"
    shell:
        "mkdir -p binning/bins "
        "&& {SCRIPTS}/split_bins.py {input} binning/bins"

# Binning with MetaBAT
# Runs metabat on the full contigs and their jgi depth table, using
# MIN_CONTIG_LENGTH for both -m and --minContigByCorr and the output prefix
# given in params. Only stdout goes to the log; stderr is not redirected.
# Afterwards:
#  * the file at the {params} path is copied to binning.out with every tab
#    replaced by a comma ("\t" is a literal tab inside the Python string);
#  * every binning/metabat/*.fa file is moved to binning/bins/<name>.fasta,
#    where <name> is the file name without directory and extension.
# Doubled braces are Snakemake escapes for bash's ${...} expansions.
rule metabat:
    input:   contigs="assembly/full/all.fasta", profiles="profile/jgi/full/depth_metabat.txt"
    output:  flag=touch("binning/metabat/binning.done"),
             out="binning/metabat/binning.out"
    threads: THREADS
    params:  "binning/metabat/cluster"
    log:     "binning/metabat.log"
    message: "Running MetaBAT clustering"
    shell:   "{SOFT}/metabat/metabat -t {threads} -m {MIN_CONTIG_LENGTH} "
             " --minContigByCorr {MIN_CONTIG_LENGTH} --saveCls"
             " -i {input.contigs} -a {input.profiles}"
             " -o {params} > {log} && "
             "sed 's/\t/,/g' {params} > {output.out} && mkdir -p binning/bins && "
             "for file in binning/metabat/*.fa ; do bin=${{file##*/}}; mv $file binning/bins/${{bin%.*}}.fasta; done"

# Binning with MAXBIN2
# Runs run_MaxBin.pl on the splits and their profiles (stdout only goes to
# the log), then converts the results under the output prefix into the
# binning table with make_maxbincsv.py.
rule maxbin:
    input:
        contigs="assembly/splits/all.fasta",
        profiles="binning/maxbin/profiles_{}.in".format(PROFILER)
    output:
        "binning/maxbin/binning.out"
    params:
        out="binning/maxbin/cluster"
    log:
        "binning/maxbin.log"
    threads:
        THREADS
    message:
        "Running MaxBin2 clustering"
    shell:
        "perl {SOFT}/MaxBin/run_MaxBin.pl -thread {threads} "
        "-min_contig_length {MIN_CONTIG_LENGTH} "
        " -contig {input.contigs} -abund {input.profiles} "
        "-out {params.out} > {log}"
        "&& {SCRIPTS}/make_maxbincsv.py -o {output} {params.out}"

# Binning with GATTACA
# Environment setup used for the py27 environment activated below:
#   conda create -n py27 python=2.7.9 numpy scipy scikit-learn anaconda
#   conda install -c bioconda pysam=0.11.2.2
# Clusters the filtered splits with the dirichlet algorithm and writes the
# clusters straight to the rule output.
rule gattaca:
    input:
        contigs="assembly/splits/all_filtered.fasta",
        profiles="binning/gattaca/profiles_{}.in".format(PROFILER)
    output:
        "binning/gattaca/binning.out"
    log:
        "binning/gattaca.log"
    threads:
        THREADS
    message:
        "Running GATTACA clustering"
    shell:
        "set +u; source activate py27; set -u\n"
        "python {SOFT}/gattaca/src/python/gattaca.py cluster --contigs {input.contigs} "
        "--coverage {input.profiles} --algorithm dirichlet --clusters {output} >{log} 2>&1"

# Binning with BinSanity
# Runs Binsanity-lc on the filtered splits and their profiles (stdout only
# goes to the log), converts the KMEAN-BINS result directory into the binning
# table with clusters2csv.py and moves Binsanity-log.txt from the working
# directory into binning/.
# NOTE(review): the Binsanity-lc script is referenced by an absolute path
# inside a user's home directory, unlike the other tools, which are located
# via {SOFT}; this rule will only work where that path exists.
rule binsanity:
    input:   contigs="assembly/splits/all_filtered.fasta", profiles="binning/binsanity/profiles_{}.in".format(PROFILER)
    output:  "binning/binsanity/binning.out"
    threads: THREADS
    log:     "binning/binsanity.log"
    message: "Running BinSanity clustering"
    shell:   "python2 /home/tdvorkina/binsanity/src/BinSanity/test-scripts/Binsanity-lc "
             " -f ./ -l {input.contigs} -c {input.profiles} -o binning/binsanity/BINSANITY-RESULTS > {log}  && "
             "{SCRIPTS}/clusters2csv.py binning/binsanity/BINSANITY-RESULTS/KMEAN-BINS {output} && mv Binsanity-log.txt binning/ "

# Postprocessing
# Derive the per-bin profiles from the combined contig profiles of the
# configured PROFILER and the unified binning of the configured BINNER.
rule bin_profiles:
    input:
        "profile/{}/all.tsv".format(PROFILER),
        "binning/{}/unified_binning.tsv".format(BINNER)
    output:
        "binning/{}/bins.prof".format(BINNER)
    message:
        "Deriving bin profiles"
    shell:
        "{SCRIPTS}/bin_profiles.py {input} > {output}"

# When BINNER is "canopy", both canopy and bin_profiles declare
# binning/canopy/bins.prof as an output; prefer the canopy rule in that case.
ruleorder: canopy > bin_profiles

# Convert the raw output of the configured BINNER into the unified binning
# table and keep a copy as binning/binning.tsv (additional table for stats).
rule binning_format:
    input:
        "binning/{}/binning.out".format(BINNER)
    output:
        "binning/{}/unified_binning.tsv".format(BINNER)
    message:
        "Making unified binning results"
    shell:
        "{SCRIPTS}/convert_output.py -t {BINNER} "
        "-o {output} {input} &&"
        "cp {output} binning/binning.tsv"

# Split the unified binning table into one annotation file per group in the
# propagator format: each line is "<contig>\t<bin> <bin> ...".
# Input lines are "<group>-<contig>\t<bin>"; a name without the "<group>-"
# prefix is attributed to "group1".
rule annotate:
    input:   "binning/{}/unified_binning.tsv".format(BINNER)
    output:  expand("binning/annotation/{sample}.ann", sample=GROUPS)
    params:  "binning/annotation/"
    message: "Preparing raw annotations"
    run:
        #Load the whole annotation: {sample: {contig: [bins]}}
        # Every declared group is seeded up front so that its .ann file is
        # written (possibly empty) even when none of its contigs was binned;
        # otherwise Snakemake reports the declared output as missing.
        samples_annotation = {group: dict() for group in GROUPS}
        with open(input[0]) as input_file:
            for line in input_file:
                if not line.strip():
                    continue #Tolerate blank lines, e.g. at the end of the file
                annotation_str = line.split("\t", 1)
                bin_id = annotation_str[1].strip()
                sample_contig = annotation_str[0].split('-', 1)
                if len(sample_contig) > 1:
                    sample, contig = sample_contig
                else: #Backward compatibility with old alternative pipeline runs
                    sample, contig = "group1", sample_contig[0]
                annotation = samples_annotation.setdefault(sample, dict())
                annotation.setdefault(contig, []).append(bin_id)

        #Serialize it in the propagator format
        for sample, annotation in samples_annotation.items():
            with open(os.path.join(params[0], sample + ".ann"), "w") as sample_out:
                for contig, bins in annotation.items():
                    print(contig, "\t", " ".join(bins), sep="", file=sample_out)


#---- Post-clustering pipeline -------------------------------------------------

# Propagation stage
# Path of the saved graph pack of the required assembly stage, relative to a
# SPAdes output directory, for k = ASSEMBLY_K. prop_binning joins it with the
# group's assembly/spades/{group}/ directory.
SAVES = "K{0}/saves/01_before_repeat_resolution/graph_pack".format(ASSEMBLY_K)

# Propagate the annotation of one group and bin its reads with prop_binning.
# The tool gets the SPAdes saves of the group, its contigs and splits, the
# raw annotation, the filtered bins, the group's sample names and the reads;
# it writes the propagated annotation and the edges, using `binning` as its
# output directory (-o).
rule prop_binning:
    input:
        contigs="assembly/spades/{group}.fasta",
        splits="assembly/splits/{group}.fasta",
        ann="binning/annotation/{group}.ann",
        left=left_reads,
        right=right_reads,
        bins="binning/{}/filtered_bins.tsv".format(BINNER)
    output:
        ann="propagation/annotation/{group}.ann",
        edges="propagation/edges/{group}.fasta"
    params:
        saves=os.path.join("assembly/spades/{group}/", SAVES),
        samples=lambda wildcards: " ".join(GROUPS[wildcards.group])
    log:
        "binning/{group}.log"
    message:
        "Propagating annotation & binning reads for {wildcards.group}"
    shell:
        "{BIN}/prop_binning -k {ASSEMBLY_K} -s {params.saves} -c {input.contigs} -b {input.bins} "
        "-n {params.samples} -l {input.left} -r {input.right} -t {MIN_CONTIG_LENGTH} "
        "-a {input.ann} -f {input.splits} -o binning -p {output.ann} -e {output.edges} >{log} 2>&1"

# Aggregation rule: touches propagation.done once the propagated annotations
# of all GROUPS exist.
rule prop_all:
    input:
        expand("propagation/annotation/{group}.ann", group=GROUPS)
    output:
        touch("propagation.done")
    message:
        "Finished propagation of all annotations."

# Filter the unified binning with choose_bins.py; its stdout is the filtered
# table consumed by prop_binning and choose_samples.
rule choose_bins:
    input:   "binning/{}/unified_binning.tsv".format(BINNER)
    output:  "binning/{}/filtered_bins.tsv".format(BINNER)
    # stderr goes to a separate log: merging it into stdout (2>&1) would write
    # any warning or traceback of the script into the output table itself.
    log:     "binning/choose_bins.log"
    message: "Filter small bins"
    shell:   "{SCRIPTS}/choose_bins.py {input} >{output} 2>{log}"

# Choose the bins for reassembly and the samples for each of them with
# choose_samples.py. binning/*.info files and the binning/excluded directory
# are removed before the script runs.
rule choose_samples:
    input:
        "binning/{}/bins.prof".format(BINNER),
        "binning/{}/filtered_bins.tsv".format(BINNER)
    output:
        "binning/bins_total.prof"
    log:
        "binning/choose_samples.log"
    message:
        "Choosing bins for reassembly and samples for them"
    shell:
        "rm -f binning/*.info "
        "&& rm -rf binning/excluded "
        "&& {SCRIPTS}/choose_samples.py {input} {output} binning >{log} 2>&1"
