#!/usr/bin/env python3

__version__ = '1.4.2'
__author__ = 'David Heller'

import sys
import os
import re
import pickle
import gzip
import logging
import pysam

from time import strftime, localtime

from svim.SVIM_input_parsing import parse_arguments, guess_file_type, read_file_list
from svim.SVIM_alignment import run_alignment
from svim.SVIM_COLLECT import analyze_alignment_file_coordsorted, analyze_alignment_file_querysorted
from svim.SVIM_CLUSTER import cluster_sv_signatures, write_signature_clusters_bed, write_signature_clusters_vcf
from svim.SVIM_COMBINE import combine_clusters, write_candidates, write_final_vcf
from svim.SVIM_genotyping import genotype
from svim.SVIM_plot import plot_sv_lengths, plot_sv_alleles


def main():
    """Run the SVIM pipeline from the command line.

    The steps are executed in this order:

    1. COLLECT: obtain SV signatures, either by aligning reads first
       ('reads' mode) or from a given BAM file ('alignment' mode).
    2. CLUSTER: cluster the signatures and write the clusters as BED and VCF.
    3. COMBINE: combine the clusters into SV candidates.
    4. GENOTYPE: genotype the candidates unless ``options.skip_genotyping``
       is set (it is forced on when the input BAM is queryname-sorted or
       lacks a valid index).

    Finally the candidates and the final VCF are written and plots are drawn.
    All messages go to the console and to a time-stamped log file in the
    working directory.

    Returns:
        None when the pipeline ran to completion, 1 when it was aborted
        because of unusable input. The value is meant to be passed to
        ``sys.exit`` so that failed runs yield a non-zero exit status.
    """
    # Fetch command-line options
    options = parse_arguments(program_version=__version__)

    if not options.sub:
        print("Please choose one of the two modes ('reads' or 'alignment'). See --help for more information.")
        return 1

    # Set up logging
    log_formatter = logging.Formatter("%(asctime)s [%(levelname)-7.7s]  %(message)s")
    root_logger = logging.getLogger()
    default_log_level = logging.DEBUG if options.verbose else logging.INFO
    root_logger.setLevel(default_log_level)

    # Create working dir if it does not exist. exist_ok avoids the race between
    # checking for the directory and creating it.
    os.makedirs(options.working_dir, exist_ok=True)

    # Create log file
    log_path = "{0}/SVIM_{1}.log".format(options.working_dir, strftime("%y%m%d_%H%M%S", localtime()))
    file_handler = logging.FileHandler(log_path, mode="w")
    file_handler.setFormatter(log_formatter)
    root_logger.addHandler(file_handler)

    console_handler = logging.StreamHandler()
    console_handler.setFormatter(log_formatter)
    root_logger.addHandler(console_handler)

    logging.info("****************** Start SVIM, version {0} ******************".format(__version__))
    logging.info("CMD: python3 {0}".format(" ".join(sys.argv)))
    logging.info("WORKING DIR: {0}".format(os.path.abspath(options.working_dir)))
    for arg in vars(options):
        logging.info("PARAMETER: {0}, VALUE: {1}".format(arg, getattr(options, arg)))

    logging.info("****************** STEP 1: COLLECT ******************")
    if options.sub == 'reads':
        logging.info("MODE: reads")
        logging.info("INPUT: {0}".format(os.path.abspath(options.reads)))
        logging.info("GENOME: {0}".format(os.path.abspath(options.genome)))
        reads_type = guess_file_type(options.reads)
        if reads_type == "unknown":
            return 1
        elif reads_type == "list":
            # List of read files
            sv_signatures = []
            # Translocation signatures from other SV classes are stored separately for --all_bnd option
            translocation_signatures_all_bnds = []
            # NOTE(review): after the loop, aln_file refers to the alignments of the
            # LAST file only, and that handle is what genotyping and the final VCF
            # (references/lengths) use below. Confirm this is intended.
            aln_file = None
            for index, file_path in enumerate(read_file_list(options.reads)):
                logging.info("Starting processing of file {0} from the list..".format(index))
                reads_type = guess_file_type(file_path)
                if reads_type == "unknown" or reads_type == "list":
                    logging.error("Cannot process file {0} from the list (detected type: {1}). Aborting.".format(file_path, reads_type))
                    return 1
                bam_path = run_alignment(options.working_dir, options.genome, file_path, reads_type, options.cores, options.aligner, options.nanopore)
                aln_file = pysam.AlignmentFile(bam_path)
                sigs, trans_sigs = analyze_alignment_file_coordsorted(aln_file, options)
                sv_signatures.extend(sigs)
                translocation_signatures_all_bnds.extend(trans_sigs)
            if aln_file is None:
                # Without this guard an empty list would only fail later with an
                # UnboundLocalError when aln_file is first used.
                logging.error("The given list of read files does not contain any files. Aborting.")
                return 1
        else:
            # Single read file
            bam_path = run_alignment(options.working_dir, options.genome, options.reads, reads_type, options.cores, options.aligner, options.nanopore)
            aln_file = pysam.AlignmentFile(bam_path)
            sv_signatures, translocation_signatures_all_bnds = analyze_alignment_file_coordsorted(aln_file, options)
    elif options.sub == 'alignment':
        logging.info("MODE: alignment")
        logging.info("INPUT: {0}".format(os.path.abspath(options.bam_file)))
        aln_file = pysam.AlignmentFile(options.bam_file)
        # Only the header lookup is guarded, so that a KeyError raised somewhere
        # inside the signature analysis is not misreported as a missing sort order.
        try:
            sort_order = aln_file.header["HD"]["SO"]
        except KeyError:
            logging.error("Is the given input BAM file sorted? It does not contain a sorting order in its header line.")
            return 1
        if sort_order == "coordinate":
            try:
                aln_file.check_index()
            except ValueError:
                logging.warning("Input BAM file is missing a valid index. Please generate with 'samtools index'. Continuing without genotyping for now..")
                options.skip_genotyping = True
            except AttributeError:
                logging.warning("pysam's .check_index raised an Attribute error. Something is wrong with the input BAM file.")
                return 1
            sv_signatures, translocation_signatures_all_bnds = analyze_alignment_file_coordsorted(aln_file, options)
        elif sort_order == "queryname":
            sv_signatures, translocation_signatures_all_bnds = analyze_alignment_file_querysorted(aln_file, options)
            logging.warning("Skipping genotyping because it requires a coordinate-sorted input BAM file and an index. The given file, however, is queryname-sorted according to its header line.")
            options.skip_genotyping = True
        else:
            logging.error("Input BAM file needs to be coordinate-sorted or queryname-sorted. The given file, however, is unsorted according to its header line.")
            return 1
    else:
        # Defensive: without signatures there is nothing to cluster.
        logging.error("Unknown mode '{0}'. Please choose one of the two modes ('reads' or 'alignment').".format(options.sub))
        return 1

    def count_signatures(sv_type):
        """Return the number of collected signatures of the given type."""
        return sum(1 for ev in sv_signatures if ev.type == sv_type)

    logging.info("Found {0} signatures for deleted regions.".format(count_signatures("DEL")))
    logging.info("Found {0} signatures for inserted regions.".format(count_signatures("INS")))
    logging.info("Found {0} signatures for inverted regions.".format(count_signatures("INV")))
    logging.info("Found {0} signatures for tandem duplicated regions.".format(count_signatures("DUP_TAN")))
    logging.info("Found {0} signatures for translocation breakpoints.".format(count_signatures("BND")))
    if options.all_bnds:
        logging.info("Found {0} signatures for translocation breakpoints from other SV classes (DEL, INV, DUP).".format(len(translocation_signatures_all_bnds)))
    logging.info("Found {0} signatures for inserted regions with detected region of origin.".format(count_signatures("DUP_INT")))

    logging.info("****************** STEP 2: CLUSTER ******************")
    signature_clusters = cluster_sv_signatures(sv_signatures, options)
    if options.all_bnds:
        # Silence INFO/DEBUG output of the second clustering run; the regular
        # level is restored even if clustering fails.
        root_logger.setLevel(logging.WARNING)
        try:
            translocation_signature_clusters_all_bnds = cluster_sv_signatures(translocation_signatures_all_bnds, options)
        finally:
            root_logger.setLevel(default_log_level)

    # Write SV signature clusters
    logging.info("Finished clustering. Writing signature clusters..")
    if options.all_bnds:
        # Only the last element (index 5) of the extra clustering result is used;
        # it is appended to the regular clusters at the same index.
        clusters_to_write = (signature_clusters[0],
                             signature_clusters[1],
                             signature_clusters[2],
                             signature_clusters[3],
                             signature_clusters[4],
                             signature_clusters[5] + translocation_signature_clusters_all_bnds[5])
    else:
        clusters_to_write = signature_clusters
    write_signature_clusters_bed(options.working_dir, clusters_to_write)
    write_signature_clusters_vcf(options.working_dir, clusters_to_write, __version__)

    logging.info("****************** STEP 3: COMBINE ******************")
    deletion_candidates, inversion_candidates, int_duplication_candidates, tan_dup_candidates, novel_insertion_candidates, breakend_candidates = combine_clusters(signature_clusters, options)
    if options.all_bnds:
        root_logger.setLevel(logging.WARNING)
        try:
            breakend_candidates_all_bnds = combine_clusters(translocation_signature_clusters_all_bnds, options)[5]
        finally:
            root_logger.setLevel(default_log_level)

    if not options.skip_genotyping:
        logging.info("****************** STEP 4: GENOTYPE ******************")
        logging.info("Genotyping deletions..")
        genotype(deletion_candidates, aln_file, "DEL", options)
        logging.info("Genotyping inversions..")
        genotype(inversion_candidates, aln_file, "INV", options)
        logging.info("Genotyping novel insertions..")
        genotype(novel_insertion_candidates, aln_file, "INS", options)
        logging.info("Genotyping interspersed duplications..")
        genotype(int_duplication_candidates, aln_file, "DUP_INT", options)

    # Write SV candidates
    logging.info("Write SV candidates..")
    logging.info("Final deletion candidates: {0}".format(len(deletion_candidates)))
    logging.info("Final inversion candidates: {0}".format(len(inversion_candidates)))
    logging.info("Final interspersed duplication candidates: {0}".format(len(int_duplication_candidates)))
    logging.info("Final tandem duplication candidates: {0}".format(len(tan_dup_candidates)))
    logging.info("Final novel insertion candidates: {0}".format(len(novel_insertion_candidates)))
    logging.info("Final breakend candidates: {0}".format(len(breakend_candidates)))
    types_to_output = [entry.strip() for entry in options.types.split(",")]
    if options.all_bnds:
        logging.info("Final breakend candidates from other SV classes (DEL, INV, DUP): {0}".format(len(breakend_candidates_all_bnds)))
        final_breakend_candidates = breakend_candidates + breakend_candidates_all_bnds
    else:
        final_breakend_candidates = breakend_candidates
    write_candidates(options.working_dir, (int_duplication_candidates, inversion_candidates, tan_dup_candidates, deletion_candidates, novel_insertion_candidates, final_breakend_candidates))
    write_final_vcf(int_duplication_candidates,
                    inversion_candidates,
                    tan_dup_candidates,
                    deletion_candidates,
                    novel_insertion_candidates,
                    final_breakend_candidates,
                    __version__,
                    aln_file.references,
                    aln_file.lengths,
                    types_to_output,
                    options)
    logging.info("Draw plots..")
    plot_sv_lengths(deletion_candidates, inversion_candidates, int_duplication_candidates, tan_dup_candidates, novel_insertion_candidates, options)
    if not options.skip_genotyping:
        plot_sv_alleles(deletion_candidates + inversion_candidates + int_duplication_candidates + novel_insertion_candidates, options)
    logging.info("Done.")

if __name__ == "__main__":
    # Script entry point: run the pipeline and hand its result to sys.exit.
    try:
        exit_status = main()
    except Exception as e:
        # Log the error with its traceback and make the process exit with a
        # non-zero status, so that calling shells and workflow managers can
        # detect the failure (previously the script fell through and exited 0).
        logging.error(e, exc_info=True)
        exit_status = 1
    sys.exit(exit_status)
