#!/usr/bin/env python3

import argparse
import bz2
import ftplib
import glob
import gzip
import hashlib
import inspect
import itertools
import logging
import lzma
import math
import os
import re
import shutil
import subprocess
import sys
import tarfile
import tempfile
import threading
import urllib
import urllib.error
import urllib.parse
import urllib.request

# Process-wide logger and this script's own path. Both start as None and are
# presumably assigned during start-up elsewhere in this file (TODO confirm).
LOG = None
SCRIPT_PATHNAME = None

# Remote hosts that reference data is fetched from.
NCBI_SERVER = "ftp.ncbi.nlm.nih.gov"
GREENGENES_SERVER = "greengenes.microbio.me"
SILVA_SERVER = "ftp.arb-silva.de"

# Wrapper option name -> short flag understood by the kraken2 binaries.
# Only flags that a given binary advertises are forwarded (see
# wrapper_args_to_binary_args). The "mininum_*" spellings are kept as-is:
# they presumably must match the option names produced by the wrapper's
# argument parser -- TODO confirm before correcting them.
WRAPPER_ARGS_TO_BIN_ARGS = dict(
    block_size="-B",
    classified_out="-C",
    confidence="-T",
    fast_build="-F",
    interleaved="-S",
    kmer_len="-k",
    memory_mapping="-M",
    minimizer_len="-l",
    minimum_bits_for_taxid="-r",
    mininum_base_quality="-Q",
    mininum_hit_groups="-g",
    output="-O",
    paired="-P",
    protein="-X",
    quick="-q",
    report="-R",
    report_minimizer_data="-K",
    report_zero_counts="-z",
    skip_counts="-s",
    sub_block_size="-b",
    threads="-p",
    unclassified_out="-U",
    use_mpa_style="-m",
    use_names="-n",
)


class FTP:
    """Resumable FTP downloader built on :class:`ftplib.FTP`.

    Logs in anonymously, switches to binary mode (``TYPE I``), remembers the
    current remote directory so :meth:`reconnect` can return to it, and
    resumes partially downloaded files from their current local size.
    """

    def __init__(self, server, timeout=600):
        # Socket timeout (seconds) applied to the initial connection and to
        # every connection re-established by reconnect().
        self.timeout = timeout
        self.connect(server)
        self.pwd = "/"
        self.server = server

    def _progress_bar(self, f, remote_size):
        """Return a ``retrbinary`` callback that writes blocks to *f*.

        Each block is appended to *f* and a progress bar, based on *f*'s
        position relative to *remote_size*, is emitted at debug level.
        """
        pb = ProgressBar(remote_size, f.tell())

        def inner(block):
            f.write(block)
            size_on_disk = f.tell()
            pb.progress(size_on_disk)
            LOG.debug(
                "{:s} {: >10s}\r".format(
                    pb.get_bar(), format_bytes(size_on_disk)
                )
            )

        return inner

    def download(self, remote_dir, filepaths):
        """Download *filepaths* (relative to *remote_dir*) to the same local paths.

        *filepaths* may be a single path or a list. A file whose local size
        already equals the remote size is skipped. A local file larger than
        the remote one is truncated and re-downloaded. Otherwise the transfer
        resumes from the current local size. On any error other than
        KeyboardInterrupt the connection is re-established and the transfer
        retried; KeyboardInterrupt closes the session and exits with status 1.
        """
        if isinstance(filepaths, str):
            filepaths = [filepaths]
        number_of_files = len(filepaths)
        self.cwd(remote_dir)
        for index, filepath in enumerate(filepaths):
            mode = "ab"
            local_size = 0
            remote_size = self.size(filepath)
            if os.path.exists(filepath):
                local_size = os.stat(filepath).st_size
            else:
                if os.path.basename(filepath) != filepath:
                    os.makedirs(os.path.dirname(filepath), exist_ok=True)
            if local_size == remote_size:
                LOG.info(
                    "Already downloaded {:s}\n".format(get_abs_path(filepath))
                )
                continue
            if local_size > remote_size:
                # Local copy is corrupt/stale: start over instead of resuming.
                mode = "wb"
            url_components = urllib.parse.SplitResult(
                "ftp", self.server, os.path.join(remote_dir, filepath), "", ""
            )
            url = urllib.parse.urlunsplit(url_components)
            if number_of_files == 1:
                LOG.info("Downloading {:s}\n".format(url))
            else:
                LOG.info(
                    "[{:d}/{:d}] Downloading {:s}\n".format(
                        index + 1, number_of_files, url
                    )
                )
            with open(filepath, mode) as f:
                while True:
                    try:
                        cb = self._progress_bar(f, remote_size)
                        # REST offset = bytes already on disk, so an
                        # interrupted transfer resumes where it stopped.
                        self.ftp.retrbinary(
                            "RETR " + filepath, cb, rest=f.tell())
                        break
                    except KeyboardInterrupt:
                        f.flush()
                        self.close()
                        sys.exit(1)
                    except Exception:
                        f.flush()
                        self.reconnect()
                        self.cwd(remote_dir)
                        continue
            absolute_path = get_abs_path(filepath)
            local_filename, local_dirname = os.path.basename(
                absolute_path
            ), os.path.dirname(absolute_path)
            clear_console_line()
            LOG.info(
                "Saved {:s} to {:s}\n".format(local_filename, local_dirname)
            )

    def cwd(self, remote_pathname):
        """Change the remote directory and remember it for reconnect()."""
        self.ftp.cwd(remote_pathname)
        self.pwd = remote_pathname

    def size(self, filepath):
        """Return the remote size of *filepath*, reconnecting on 4xx errors."""
        size = 0
        while True:
            try:
                size = self.ftp.size(filepath)
                break
            except ftplib.error_temp:
                self.reconnect()
                continue
        return size

    def exists(self, filepath):
        """Return True if ``SIZE`` succeeds for *filepath* on the server.

        A permanent "file unavailable" reply (code 550, or a message saying
        "No such file or directory") means the file is missing and yields
        False. Any other permanent error is re-raised.
        """
        try:
            self.size(filepath)
        except ftplib.error_perm as e:
            message = str(e)
            if (
                message.startswith("550")
                or "No such file or directory" in message
            ):
                return False
            raise
        return True

    def connect(self, server):
        """Open an anonymous binary-mode session to *server*."""
        self.ftp = ftplib.FTP(server, timeout=self.timeout)
        self.ftp.login()
        self.ftp.sendcmd("TYPE I")

    def reconnect(self):
        """Drop the current connection and reopen it in the same directory."""
        host = self.ftp.host
        self.ftp.close()
        self.connect(host)
        self.ftp.cwd(self.pwd)

    def host(self):
        """Return the host name of the current connection."""
        return self.ftp.host

    def close(self):
        """Politely end the session (``QUIT``)."""
        self.ftp.quit()


class ProgressBar:
    """Fixed-width text progress bar rendered like `` 42% [=====>------]``.

    *stop* is the total amount of work and *current* the starting amount
    (e.g. bytes already on disk for a resumed download). A non-positive
    *stop* is treated as an unknown total: the bar stays empty and reports
    0% instead of dividing by zero.
    """

    def __init__(self, stop, current=0, width=30):
        self.stop = stop
        self.width = width
        self.current = current
        self.bar = list("-" * self.width)
        self.step = stop / self.width if stop > 0 else 0
        self.last_index = 0
        if self.current > 0:
            # Draw the bar up to the starting position.
            self.progress(self.current)

    def progress(self, amount=0, relative=False):
        """Advance to *amount* (or by *amount* when *relative*) and redraw."""
        if relative:
            self.current += amount
        else:
            self.current = amount
        if self.stop > 0 and self.current > self.stop:
            self.current = self.stop
        index = self._calculate_index()
        # Extend the arrow one cell at a time from the last drawn position.
        for i in range(self.last_index, index):
            if i == 0:
                self.bar[i] = ">"
            else:
                self.bar[i - 1], self.bar[i] = "=", ">"
        self.last_index = index

    def get_bar(self):
        """Return the percentage and bar as a single string."""
        if self.stop > 0:
            percentage = int(self.current / self.stop * 100)
        else:
            percentage = 0
        return "{:3d}% {:s}".format(percentage, "[" + "".join(self.bar) + "]")

    def _calculate_index(self):
        """Return how many bar cells the current amount fills."""
        if self.stop <= 0:
            return 0
        return math.floor(self.current / self.step)


def clear_console_line():
    """Emit an ANSI "erase entire line" sequence plus carriage return at debug level."""
    erase_sequence = "\33[2K\r"
    LOG.debug(erase_sequence)


def count_lines(*filenames):
    """Return the total number of lines across all *filenames* (text mode)."""
    total = 0
    for name in filenames:
        with open(name, "r") as handle:
            total += sum(1 for _ in handle)
    return total


def dwk2():
    """Report whether ``estimate_capacity``'s usage line ends in ``<options>``.

    Runs ``estimate_capacity -h`` (stderr merged into stdout) and inspects the
    first output line starting with ``Usage:``. Returns False when no such
    line is printed.
    """
    binary = find_kraken2_binary("estimate_capacity")
    help_text = subprocess.check_output(
        [binary, "-h"], stderr=subprocess.STDOUT
    )
    usage_lines = (
        candidate
        for candidate in help_text.split(b"\n")
        if candidate.startswith(b"Usage:")
    )
    first_usage = next(usage_lines, None)
    if first_usage is None:
        return False
    return first_usage.strip().endswith(b"<options>")


def get_binary_options(binary_pathname):
    """Return the short flags (``-x``) a binary lists in its usage text.

    The binary is run with no arguments and its stderr is scanned line by
    line. From each line, the first ``-x`` token with whitespace on both
    sides is collected (so at most one flag per line), in output order.
    """
    options = []
    # The context manager closes the pipe and waits for the child, so no
    # zombie process or open file descriptor is left behind.
    with subprocess.Popen(binary_pathname, stderr=subprocess.PIPE) as proc:
        lines = proc.stderr.readlines()
    for line in lines:
        match = re.search(br"\s(-\w)\s", line)
        if match:
            options.append(match.group(1).decode())
    return options


def construct_seed_template(args):
    """Build the spaced-seed template string for the minimizer.

    The result is ``args.minimizer_len`` characters long: a run of ``1``s
    followed by ``args.minimizer_spaces`` repetitions of ``01``. Logs an
    error and exits with status 1 when more spaces are requested than
    ``int(minimizer_len / 4)`` allows.
    """
    max_spaces = int(args.minimizer_len / 4)
    if args.minimizer_spaces > max_spaces:
        LOG.error(
            "Number of minimizer spaces, {}, exceeds max for minimizer length, {}; max: {}\n".format(
                args.minimizer_spaces,
                args.minimizer_len,
                max_spaces,
            )
        )
        sys.exit(1)
    solid_len = args.minimizer_len - 2 * args.minimizer_spaces
    spaced_tail = "01" * args.minimizer_spaces
    return "1" * solid_len + spaced_tail


def wrapper_args_to_binary_args(opts, argv, binary_args):
    """Extend *argv* in place with binary flags for the parsed wrapper options.

    Only options listed in WRAPPER_ARGS_TO_BIN_ARGS whose flag appears in
    *binary_args* are considered. A value of ``False`` is skipped, ``True``
    adds just the flag, and any other value adds the flag followed by
    ``str(value)``.
    """
    for name, value in vars(opts).items():
        flag = WRAPPER_ARGS_TO_BIN_ARGS.get(name)
        if flag is None or flag not in binary_args or value is False:
            continue
        argv.extend([flag] if value is True else [flag, str(value)])


def find_kraken2_binary(name):
    """Locate the kraken2 executable *name* and return its path.

    Search order:

    1. each directory on ``$PATH``;
    2. the directory containing this script;
    3. ``<project root>/src``, where the project root is that directory's parent.

    Logs an error and exits with status 1 if *name* is in none of them.
    """
    search_path = os.environ.get("PATH")
    if search_path is not None:
        for directory in search_path.split(":"):
            candidate = os.path.join(directory, name)
            if os.path.exists(candidate):
                return candidate
    script_dir = get_parent_directory(SCRIPT_PATHNAME)
    candidate = os.path.join(script_dir, name)
    if os.path.exists(candidate):
        return candidate
    project_root = get_parent_directory(script_dir)
    source_dir = os.path.join(project_root, "src")
    if "src" in os.listdir(project_root) and name in os.listdir(source_dir):
        return os.path.join(source_dir, name)
    LOG.error("Unable to find {:s}, exiting\n".format(name))
    sys.exit(1)


def get_parent_directory(pathname):
    """Return the directory containing *pathname*, made absolute first.

    A trailing separator is ignored, so ``/a/b/`` yields ``/a``. Returns
    None for an empty string.
    """
    if not len(pathname):
        return None
    absolute = os.path.abspath(pathname)
    if len(absolute) > 1 and absolute.endswith(os.path.sep):
        absolute = absolute[:-1]
    return os.path.dirname(absolute)


def find_database(database_name):
    """Resolve *database_name* to a Kraken 2 database directory.

    - A name containing a path separator is used directly as a path.
    - A bare name is looked up in each directory of ``$KRAKEN2_DB_PATH``
      (colon-separated).
    - If that variable is unset, a bare name is looked up in the current
      working directory.

    Returns the resolved path only if it contains ``taxo.k2d``, ``hash.k2d``
    and ``opts.k2d``; otherwise returns None.
    """
    if os.path.sep in database_name:
        located = database_name if os.path.exists(database_name) else None
    elif "KRAKEN2_DB_PATH" in os.environ:
        located = None
        for directory in os.environ["KRAKEN2_DB_PATH"].split(":"):
            candidate = os.path.join(directory, database_name)
            if os.path.exists(candidate):
                located = candidate
                break
    elif database_name in os.listdir(os.getcwd()):
        located = database_name
    else:
        located = None
    if located:
        required = ("taxo.k2d", "hash.k2d", "opts.k2d")
        if not all(os.path.exists(os.path.join(located, f)) for f in required):
            return None
    return located


def remove_files(*filepaths):
    """Delete each existing path; directories are removed recursively.

    Paths that do not exist are skipped silently.
    """
    for path in filter(os.path.exists, filepaths):
        if os.path.isdir(path):
            shutil.rmtree(path)
        else:
            os.remove(path)


def check_seqid(seqid):
    """Extract a taxonomy ID or accession number from a FASTA sequence ID.

    Tries, in order:

    1. a ``kraken:taxid|<digits>`` token at the start or after a ``|``;
    2. a sequence ID that is entirely digits;
    3. an accession-like token (uppercase letters, optional ``_``, then
       uppercase letters/digits) at the start or after a ``|``.

    Returns the extracted string (digits for taxids), or None if nothing
    matches.
    """
    taxid = None
    # re.search, not re.match: the "(?:^|\|)" prefix exists so the token can
    # also be found after a pipe, which re.match (anchored at 0) never allows.
    match = re.search(r"(?:^|\|)kraken:taxid\|(\d+)", seqid)
    if match:
        taxid = match.group(1)
    elif re.match(r"^\d+$", seqid):
        taxid = seqid
    if not taxid:
        match = re.search(r"(?:^|\|)([A-Z]+_?[A-Z0-9]+)(?:\||\b|\.)", seqid)
        if match:
            taxid = match.group(1)
    return taxid


def hash_file(filename, buf_size=8192):
    """Return the hexadecimal MD5 digest of *filename*, read in *buf_size* chunks."""
    digest = hashlib.md5()
    with open(filename, "rb") as handle:
        for chunk in iter(lambda: handle.read(buf_size), b""):
            digest.update(chunk)
    return digest.hexdigest()


# This function is part of the Kraken 2 taxonomic sequence
# classification system.
#
# Reads multi-FASTA input and examines each sequence header.  Headers are
# OK if a taxonomy ID is found (as either the entire sequence ID or as part
# of a "kraken:taxid" token), or if something looking like an accession
# number is found.  Headers that are not "OK" are fatal errors unless
# "lenient" is used.
#
# Each sequence header results in a line with three tab-separated values;
# the first indicates whether the third column is the taxonomy ID ("TAXID") or
# an accession number ("ACCNUM") for the sequence ID listed in the second
# column.
#
def scan_fasta_file(in_file, out_file, lenient=False, hash=False):
    """Write a ``TAXID``/``ACCNUM`` line for every sequence ID in *in_file*.

    Every ID in each ``>`` header line is checked, including IDs joined by
    ``\\x01`` separators. Each ID for which check_seqid finds a value produces
    ``KIND<TAB>seqid<TAB>value`` on *out_file*, where KIND is ``TAXID`` when
    the value is all digits and ``ACCNUM`` otherwise. An ID with no usable
    value is skipped when *lenient*; otherwise an error is logged and the
    program exits with status 1.

    *hash* is accepted for interface compatibility and is not used.
    """
    for line in in_file:
        if not line.startswith(">"):
            continue
        for match in re.finditer(r"(?:^>|\x01)(\S+)", line):
            seqid = match.group(1)
            taxid = check_seqid(seqid)
            if not taxid:
                if lenient:
                    continue
                LOG.error(
                    "Unable to determine taxonomy ID or accession number for sequence {:s}\n".format(
                        seqid
                    )
                )
                sys.exit(1)
            kind = "TAXID" if re.match(r"^\d+$", taxid) else "ACCNUM"
            out_file.write("{:s}\t{:s}\t{:s}\n".format(kind, seqid, taxid))


# This function is part of the Kraken 2 taxonomic sequence
# classification system.
#
# Looks up accession numbers and reports associated taxonomy IDs
#
# `lookup_list_filename` names a 2-column TSV file with sequence IDs and
# accession numbers, and `accession_map_files` is a list of
# accession2taxid files from NCBI.  Output is tab-delimited lines,
# with sequence IDs in first column and taxonomy IDs in second.
#
def lookup_accession_numbers(
    lookup_list_filename, out_filename, *accession_map_files
):
    """Map sequence IDs to taxonomy IDs via NCBI accession2taxid files.

    *lookup_list_filename* holds ``seqid<TAB>accession`` lines; several
    seqids may share one accession. Each map file is read after skipping its
    header line as ``accession<TAB>accession.version<TAB>taxid<TAB>gi``.
    For every accession found, ``seqid<TAB>taxid`` lines are appended to
    *out_filename* for all of its seqids. Scanning stops as soon as every
    accession is resolved. Blank lines are ignored. If any accessions remain
    unresolved, a warning is logged and they are written to ``unmapped.txt``
    in the current directory.
    """
    # accession number -> every sequence ID that referenced it
    target_lists = {}
    with open(lookup_list_filename, "r") as f:
        for line in f:
            line = line.strip()
            if not line:
                continue
            seqid, acc_num = line.split("\t")
            target_lists.setdefault(acc_num, []).append(seqid)
    initial_target_count = len(target_lists)
    with open(out_filename, "a") as out_file:
        for filename in accession_map_files:
            if not target_lists:
                break
            with open(filename, "r") as in_file:
                in_file.readline()  # discard header line
                for line in in_file:
                    line = line.strip()
                    if not line:
                        continue
                    accession, _with_version, taxid, _gi = line.split("\t")
                    seqids = target_lists.pop(accession, None)
                    if seqids is None:
                        continue
                    for seqid in seqids:
                        out_file.write(seqid + "\t" + taxid + "\n")
                    if not target_lists:
                        break
    if target_lists:
        LOG.warning(
            "{}/{} accession numbers remain unmapped, see unmapped.txt in DB directory\n".format(
                len(target_lists),
                initial_target_count),
        )
        with open("unmapped.txt", "w") as f:
            for k in target_lists:
                f.write(k + "\n")


def spawn_masking_subprocess(output_file, protein=False, threads=4):
    """Start a low-complexity masking process whose output goes to *output_file*.

    Binary selection:

    - ``segmasker`` is used for protein input and ``k2mask`` otherwise,
      unless the ``MASKER`` environment variable names another binary.
    - For maskers other than k2mask, lowercase letters on non-header lines
      are turned into ``x`` by a ``sed`` filter.
    - k2mask is run with ``-r x`` (presumably the same replacement, done
      natively -- confirm) and *threads* threads.

    Returns ``masker(input_file, final=False)``, which streams the binary
    file object *input_file* into the process's stdin. Passing
    ``final=True`` with the last input closes stdin and waits for the
    process. A non-zero exit status is logged and terminates the program
    with status 1, so a failed mask never replaces the caller's library.
    """
    import shlex

    masking_binary = "segmasker" if protein else "k2mask"
    if "MASKER" in os.environ:
        masking_binary = os.environ["MASKER"]
    masking_binary = find_kraken2_binary(masking_binary)
    # The command runs through a shell, so quote the path (it may contain
    # spaces or shell metacharacters).
    quoted_binary = shlex.quote(masking_binary)

    argv = quoted_binary + " -outfmt fasta | sed -e '/^>/!s/[a-z]/x/g'"
    if masking_binary.find("k2mask") >= 0:
        # k2mask can run multithreaded
        argv = (
            quoted_binary
            + " -outfmt fasta -threads "
            + str(threads)
            + " -r x"
        )

    p = subprocess.Popen(
        argv, shell=True, stdin=subprocess.PIPE, stdout=output_file
    )

    def masker(input_file, final=False):
        shutil.copyfileobj(input_file, p.stdin)
        if final:
            p.stdin.close()
            returncode = p.wait()
            if returncode != 0:
                LOG.error(
                    "Masking with {:s} failed with exit status {:d}\n".format(
                        masking_binary, returncode
                    )
                )
                sys.exit(1)

    return masker


# Mask low complexity sequences in the database
def mask_files(input_filenames, output_filename, protein=False):
    """Mask low-complexity regions of *input_filenames* into *output_filename*.

    All inputs are streamed, in order, through a single masking process
    started by spawn_masking_subprocess, whose output is written to
    *output_filename*. The log message names the library after the current
    working directory.
    """
    # The working directory does not change inside the loop, so resolve the
    # library name once.
    library_name = os.path.basename(os.getcwd())
    with open(output_filename, "wb") as fout:
        masker = spawn_masking_subprocess(fout, protein)
        number_of_files = len(input_filenames)
        for i, input_filename in enumerate(input_filenames):
            if library_name == "added":
                LOG.info("Masking low-complexity regions of added library.\n")
            else:
                LOG.info(
                    "Masking low-complexity regions of downloaded library {:s}\n".format(
                        library_name
                    )
                )
            with open(input_filename, "rb") as fin:
                masker(fin, i + 1 == number_of_files)
        if number_of_files == 0:
            # The masker still has to be finalized (stdin closed and the
            # process reaped), or the child would wait on its input forever.
            with open(os.devnull, "rb") as empty:
                masker(empty, final=True)


def add_to_library(args):
    """Copy user-supplied FASTA files into ``<args.db>/library/added``.

    For each file in ``args.files``, unless its MD5 is already recorded in
    ``added.md5``:

    1. scan its headers (leniently) into ``prelim_map_<md5>.txt``;
    2. copy it under a name ending in ``.fna`` (or ``.faa`` when
       ``args.protein``);
    3. mask it unless ``args.no_masking``;
    4. append ``<md5><TAB><name>`` to ``added.md5``.

    Already-added files are reported and skipped, and processing continues
    with the next file. Exits with status 1 if ``args.db`` is not a
    directory. Changes the process working directory to the ``added``
    directory.
    """
    if not os.path.isdir(args.db):
        LOG.error("Invalid database: {:s}\n".format(args.db))
        sys.exit(1)
    library_pathname = os.path.join(args.db, "library")
    added_pathname = os.path.join(library_pathname, "added")
    os.makedirs(added_pathname, exist_ok=True)
    # Resolve inputs before chdir so relative paths keep working.
    args.files = [os.path.abspath(f) for f in args.files]
    os.chdir(added_pathname)
    hashes = set()
    if os.path.exists("added.md5"):
        with open("added.md5", "r") as in_file:
            hashes = {line.split()[0] for line in in_file if line.strip()}
    ext = "faa" if args.protein else "fna"
    for filename in args.files:
        filehash = hash_file(filename)
        if filehash in hashes:
            LOG.info(filename + " already added to library.\n")
            LOG.info("If not the case, remove the  entry from added.md5\n")
            continue
        destination = os.path.basename(filename)
        if not destination.endswith("." + ext):
            base = destination.rsplit(".", 1)[0]
            destination = base + "." + ext
        prelim_map_filename = "prelim_map_" + filehash + ".txt"
        with open(prelim_map_filename, mode="a") as out_file:
            with open(filename, "r") as in_file:
                scan_fasta_file(in_file, out_file, lenient=True)
        shutil.copyfile(filename, destination)
        if not args.no_masking:
            mask_files(
                [destination], destination + ".masked", protein=args.protein
            )
            shutil.move(destination + ".masked", destination)
        with open("added.md5", "a") as out_file:
            out_file.write(filehash + "\t" + destination + "\n")
        # Also guard against the same file appearing twice in args.files.
        hashes.add(filehash)
        LOG.info("Added " + filename + " to library " + args.db + "\n")


def make_manifest_from_assembly_summary(
    assembly_summary_file, is_protein=False
):
    """Build a download manifest from an NCBI ``assembly_summary.txt`` stream.

    Lines starting with ``#`` are ignored. A row is kept when its
    assembly-level field (index 11) starts with "Complete Genome" or
    "Chromosome" and its ftp_path field (index 19) is not ``na``. For each
    kept row, the ``*_genomic.fna.gz`` (or ``*_protein.faa.gz``) file's URL
    path, with ``/genomes/`` removed, becomes a key mapped to the row's
    taxid field (index 5). The keys are also written one per line to
    ``manifest.txt`` in the working directory.

    Returns the ``{local_path: taxid}`` dict.
    """
    suffix = "_protein.faa.gz" if is_protein else "_genomic.fna.gz"
    wanted_levels = ("Complete Genome", "Chromosome")
    manifest_to_taxid = {}
    for raw in assembly_summary_file:
        if raw.startswith("#"):
            continue
        columns = raw.strip().split("\t")
        taxid = columns[5]
        asm_level = columns[11]
        ftp_path = columns[19]
        if not asm_level.startswith(wanted_levels) or ftp_path == "na":
            continue
        remote_path = ftp_path + "/" + os.path.basename(ftp_path) + suffix
        url_path = urllib.parse.urlsplit(remote_path).path
        manifest_to_taxid[url_path.replace("/genomes/", "")] = taxid
    with open("manifest.txt", "w") as manifest:
        manifest.writelines(path + "\n" for path in manifest_to_taxid)
    return manifest_to_taxid


def assign_taxid_to_sequences(manifest_to_taxid, is_protein):
    """Concatenate downloaded gzipped FASTA files into one taxid-tagged library.

    Files from *manifest_to_taxid* (path -> taxid string) are read in sorted
    path order. Every header ``>ID ...`` is rewritten to
    ``>kraken:taxid|<taxid>|ID ...``, and all lines are written to
    ``library.faa`` (protein) or ``library.fna`` in the working directory.
    Running project/sequence/residue counts are logged at debug level.
    """
    LOG.info("Assigning taxonomic IDs to sequences\n")
    out_filename = "library.faa" if is_protein else "library.fna"
    ch = "aa" if is_protein else "bp"
    total_projects = len(manifest_to_taxid)
    projects_added = 0
    sequences_added = 0
    ch_added = 0
    max_out_chars = 0
    with open(out_filename, "wb") as f:
        for filepath in sorted(manifest_to_taxid):
            # Same prefix for every header in this file: encode it once.
            header_prefix = str.encode(
                ">kraken:taxid|" + manifest_to_taxid[filepath] + "|"
            )
            with gzip.open(filepath) as in_file:
                for line in in_file:
                    if line.startswith(b">"):
                        # Replace only the leading ">" with the tagged prefix.
                        line = header_prefix + line[1:]
                        sequences_added += 1
                    else:
                        # Count residues, not terminators: the final line
                        # may lack "\n" and some files use "\r\n".
                        ch_added += len(line.rstrip(b"\r\n"))
                    f.write(line)
            projects_added += 1
            out_line = progress_line(
                projects_added, total_projects, sequences_added, ch_added, ch
            )
            max_out_chars = max(len(out_line), max_out_chars)
            space_line = " " * max_out_chars
            LOG.debug("\r{:s}\r{:s}".format(space_line, out_line))
        LOG.info("\nFinished assigning taxonomic IDs to sequences\n")


def progress_line(projects, total_projects, seqs, chars, ch):
    """Format a one-line progress summary.

    Example: ``Processed 3/10 project(s), 42 sequence(s), 1.50 Mbp``. The
    ``/total`` part is omitted once *projects* equals *total_projects*.
    *chars* is scaled by powers of 1000 with SI prefixes k through E and
    followed by the unit *ch*.
    """
    if projects != total_projects:
        project_part = "{:d}/{:d}".format(projects, total_projects)
    else:
        project_part = str(projects)
    unit_prefix = None
    for candidate in ("k", "M", "G", "T", "P", "E"):
        if chars < 1000:
            break
        unit_prefix = candidate
        chars /= 1000
    if unit_prefix is None:
        size_part = "{:.2f} {:s}".format(chars, ch)
    else:
        size_part = "{:.2f} {:s}{:s}".format(chars, unit_prefix, ch)
    return (
        "Processed "
        + project_part
        + " project(s), {:d} sequence(s), ".format(seqs)
        + size_part
    )


def decompress_files(compressed_filenames, out_filename=None, buf_size=8192):
    """Gunzip *compressed_filenames*, then delete the compressed originals.

    If *out_filename* is given, all inputs are concatenated, in order, into
    that single file. Otherwise each ``name.ext`` is decompressed to
    ``name``. Output goes to a ``.tmp`` sibling first and is moved into
    place only after it has been fully written and closed, so an
    interrupted run never leaves a truncated final file.
    """
    if out_filename:
        tmp_filename = out_filename + ".tmp"
        with open(tmp_filename, "wb") as out_file:
            for filename in compressed_filenames:
                with gzip.open(filename) as gz:
                    decompress_file(gz, out_file, buf_size)
        os.replace(tmp_filename, out_filename)
    else:
        for filename in compressed_filenames:
            target, _ext = os.path.splitext(filename)
            tmp_filename = target + ".tmp"
            with gzip.open(filename) as gz, open(tmp_filename, "wb") as out:
                decompress_file(gz, out, buf_size)
            os.replace(tmp_filename, target)
    remove_files(*compressed_filenames)


def decompress_file(in_file, out_file, buf_size=8192):
    """Copy everything readable from *in_file* into *out_file*.

    *in_file* is an already-open reader, such as a gzip file object; its
    ``name`` attribute is used only for the log messages. Data is copied in
    *buf_size*-byte chunks.
    """
    in_pathname = os.path.join(os.getcwd(), in_file.name)
    LOG.info("Decompressing {:s}\n".format(in_pathname))
    shutil.copyfileobj(in_file, out_file, buf_size)
    LOG.info("Finished decompressing {:s}\n".format(in_pathname))


def download_log(filename):
    """Return a ``urllib.request.urlretrieve`` reporthook that draws a progress bar.

    *filename* is accepted for interface compatibility and is not used.
    """
    pb = None

    def inner(block_number, read_size, total_size):
        nonlocal pb
        if pb is None:
            pb = ProgressBar(total_size)
        # urlretrieve passes (blocks transferred so far, block size, total
        # size) and makes its first call before reading anything. Derive the
        # byte count from block_number instead of summing read_size on every
        # call (that overcounts by a block), and cap it at the known total
        # because the last block is usually partial.
        current_size = block_number * read_size
        if total_size > 0:
            current_size = min(current_size, total_size)
        pb.progress(current_size)
        LOG.debug(
            "{:s} {: >10s}\r".format(pb.get_bar(), format_bytes(current_size))
        )

    return inner


def download_file(url, local_name=None):
    """Download *url* to *local_name*, showing a progress bar.

    *local_name* defaults to the last path component of the URL, saved in
    the working directory. Parent directories of an explicit *local_name*
    are created as needed.
    """
    if not local_name:
        local_name = urllib.parse.urlparse(url).path.split("/")[-1]
    else:
        parent = os.path.dirname(local_name)
        # os.makedirs("") raises, so only create a directory when one is named.
        if parent:
            os.makedirs(parent, exist_ok=True)
    LOG.info("Beginning download of {:s}\n".format(local_name))
    urllib.request.urlretrieve(
        url, local_name, reporthook=download_log(local_name)
    )
    clear_console_line()
    absolute_path = os.path.abspath(local_name)
    LOG.info(
        "Saved {:s} to {:s}\n".format(
            os.path.basename(absolute_path), os.path.dirname(absolute_path)
        )
    )


def make_manifest_filter(file, regex):
    """Return an FTP ``LIST`` callback that records matching filenames.

    Each listing line's last whitespace-separated field is taken as the
    filename and written to *file* (one per line) if it ends with *regex*.
    Despite its name, *regex* is used as a plain string suffix.
    """
    def record_if_matching(listing):
        name = listing.split()[-1]
        if name.endswith(regex):
            file.write(name + "\n")

    return record_if_matching


def move(src, dst):
    """Move *src* to *dst*; a file moved onto an existing directory lands inside it."""
    source = os.path.abspath(src)
    target = os.path.abspath(dst)
    into_directory = os.path.isfile(source) and os.path.isdir(target)
    if into_directory:
        target = os.path.join(target, os.path.basename(source))
    shutil.move(source, target)


def get_manifest(server, remote_path, regex):
    """Write ``manifest.txt`` listing files in *remote_path* on *server* ending with *regex*.

    Despite its name, *regex* is a plain filename suffix (see
    make_manifest_filter). The FTP session is always closed, even if listing
    fails.
    """
    # Same timeout as the FTP helper class, so a stalled server cannot hang forever.
    with open("manifest.txt", "w") as f, ftplib.FTP(server, timeout=600) as ftp:
        ftp.login()
        ftp.cwd(remote_path)
        ftp.retrlines("LIST", callback=make_manifest_filter(f, regex))


def download_files_from_manifest(
    server,
    remote_dir,
    manifest_filename="manifest.txt",
    decompress=False,
    out_filename=None,
    filepath_to_taxid_table=None,
):
    """Download every file listed in *manifest_filename* from *server*.

    Each non-blank manifest line is a path relative to *remote_dir*.

    - Paths that FTP.exists reports as missing are skipped with a warning.
      When *filepath_to_taxid_table* is given, they are also removed from it
      in place, so later steps do not try to open them.
    - The remaining files are downloaded (resuming partial files) to the
      same relative local paths.
    - With *decompress*, they are then gunzipped via decompress_files, which
      concatenates them into *out_filename* when one is given.
    """
    spinner = ["|", "/", "—", "\\"]
    filepaths = []
    with open(manifest_filename, "r") as f:
        ftp = FTP(server)
        for i, filepath in enumerate(f):
            LOG.info(
                "Checking if manifest files exist on server {:s}\r".format(
                    spinner[i % 4]
                )
            )
            filepath = filepath.strip()
            if not filepath:
                continue
            if not ftp.exists(urllib.parse.urljoin(remote_dir, filepath)):
                if filepath_to_taxid_table is not None:
                    filepath_to_taxid_table.pop(filepath, None)
                LOG.warning(
                    "{:s} does not exist on server, skipping\n".format(
                        remote_dir + filepath
                    )
                )
                continue
            filepaths.append(filepath)
        ftp.download(remote_dir, filepaths)
        ftp.close()
    if decompress:
        decompress_files(filepaths, out_filename)


def download_taxonomy(args):
    """Download NCBI taxonomy data into ``<args.db>/taxonomy``.

    Fetches and decompresses the accession-to-taxid maps, unless
    ``args.skip_maps`` is set: nucleotide ``gb`` and ``wgs`` maps, or the
    protein map when ``args.protein``. Then fetches and unpacks
    ``taxdump.tar.gz`` and removes the remaining ``.gz`` files. Changes the
    process working directory to the taxonomy directory.
    """
    taxonomy_path = os.path.join(args.db, "taxonomy")
    os.makedirs(taxonomy_path, exist_ok=True)
    os.chdir(taxonomy_path)
    ftp = FTP(NCBI_SERVER)
    if not args.skip_maps:
        if not args.protein:
            for subsection in ["gb", "wgs"]:
                LOG.info(
                    "Downloading nucleotide {:s} accession to taxon map\n".format(
                        subsection
                    )
                )
                filename = "nucl_" + subsection + ".accession2taxid.gz"
                ftp.download("/pub/taxonomy/accession2taxid/", filename)
        else:
            LOG.info("Downloading protein accession to taxon map\n")
            ftp.download(
                "/pub/taxonomy/accession2taxid", "prot.accession2taxid.gz"
            )
        LOG.info("Downloaded accession to taxon map(s)\n")
    LOG.info("Downloading taxonomy tree data\n")
    ftp.download("/pub/taxonomy", "taxdump.tar.gz")
    ftp.close()
    LOG.info("Decompressing taxonomy data\n")
    decompress_files(glob.glob("*accession2taxid.gz"))
    LOG.info("Untarring taxonomy tree data\n")
    with tarfile.open("taxdump.tar.gz", "r:gz") as tar:
        # Use the "data" extraction filter where available: it rejects
        # absolute paths and links escaping the destination directory.
        if hasattr(tarfile, "data_filter"):
            tar.extractall(filter="data")
        else:
            tar.extractall()
    remove_files(*glob.glob("*.gz"))


def download_genomic_library(args):
    """Download an NCBI reference library into ``<args.db>/library/<library>``.

    Supported values of ``args.library``:

    - RefSeq groups: archaea, bacteria, viral, fungi, plant, human
      (Genome Reference Consortium assembly only) and protozoa;
    - plasmid;
    - nr and nt (BLAST FASTA);
    - UniVec and UniVec_Core.

    Each branch leaves ``library.fna`` (``library.faa`` for protein
    databases) and a ``prelim_map.txt`` built by scan_fasta_file in that
    directory. Low-complexity regions are then masked unless
    ``args.no_masking``, and intermediate files are cleaned up. Logs an
    error and exits with status 1 for an unknown library or a
    protein/nucleotide mismatch. Changes the process working directory to
    the library directory.
    """
    library_filename = "library.faa" if args.protein else "library.fna"
    library_pathname = os.path.join(args.db, "library")
    LOG.info("Adding {:s} to {:s}\n".format(args.library, args.db))
    files_to_remove = None
    if args.library in [
        "archaea",
        "bacteria",
        "viral",
        "fungi",
        "plant",
        "human",
        "protozoa",
    ]:
        library_pathname = os.path.join(library_pathname, args.library)
        os.makedirs(library_pathname, exist_ok=True)
        os.chdir(library_pathname)
        try:
            os.remove("assembly_summary.txt")
        except FileNotFoundError:
            pass
        remote_dir_name = args.library
        if args.library == "human":
            remote_dir_name = "vertebrate_mammalian/Homo_sapiens"
        try:
            url = "ftp://{:s}/genomes/refseq/{:s}/assembly_summary.txt".format(
                NCBI_SERVER, remote_dir_name
            )
            download_file(url)
        except urllib.error.URLError:
            LOG.error(
                "Error downloading assembly summary file for {:s}, exiting\n".format(
                    args.library
                )
            )
            sys.exit(1)
        if args.library == "human":
            # Keep only the Genome Reference Consortium assembly rows.
            with open("assembly_summary.txt", "r") as f1:
                with open("grc.txt", "w") as f2:
                    for line in f1:
                        if "Genome Reference Consortium" in line:
                            f2.write(line)
            os.rename("grc.txt", "assembly_summary.txt")
        with open("assembly_summary.txt", "r") as f:
            filepath_to_taxid_table = make_manifest_from_assembly_summary(
                f, args.protein
            )
            download_files_from_manifest(
                NCBI_SERVER,
                "/genomes/",
                filepath_to_taxid_table=filepath_to_taxid_table,
            )
            assign_taxid_to_sequences(filepath_to_taxid_table, args.protein)
        with open(library_filename, "r") as in_file:
            with open("prelim_map.txt", "w") as out_file:
                scan_fasta_file(in_file, out_file)
        files_to_remove = ["all", "manifest.txt"]
    elif args.library == "plasmid":
        library_pathname = os.path.join(library_pathname, args.library)
        os.makedirs(library_pathname, exist_ok=True)
        os.chdir(library_pathname)
        pat = ".faa.gz" if args.protein else ".fna.gz"
        get_manifest(NCBI_SERVER, "genomes/refseq/plasmid/", pat)
        download_files_from_manifest(
            NCBI_SERVER,
            "/genomes/refseq/plasmid/",
            decompress=True,
            out_filename=library_filename,
        )
        with open(library_filename, "r") as in_file:
            with open("prelim_map.txt", "w") as out_file:
                scan_fasta_file(in_file, out_file)
        files_to_remove = glob.glob("plasmid.*")
        files_to_remove.append("manifest.txt")
    elif args.library in ["nr", "nt"]:
        protein_lib = args.library == "nr"
        if protein_lib and not args.protein:
            LOG.error(
                "{:s} is a protein database, and the Kraken2 database specified is nucleotide\n".format(
                    args.library
                )
            )
            sys.exit(1)
        if not protein_lib and args.protein:
            LOG.error(
                "{:s} is a nucleotide database, and the Kraken2 database specified is protein\n".format(
                    args.library
                )
            )
            sys.exit(1)
        library_pathname = os.path.join(library_pathname, args.library)
        os.makedirs(library_pathname, exist_ok=True)
        os.chdir(library_pathname)
        ftp = FTP(NCBI_SERVER)
        ftp.download("blast/db/FASTA/", args.library + ".gz")
        ftp.close()
        # The masking step below (and later steps) read library_filename, so
        # the downloaded archive must be decompressed into it.
        with gzip.open(args.library + ".gz", "rb") as compressed:
            with open(library_filename, "wb") as out_file:
                shutil.copyfileobj(compressed, out_file)
        with open(library_filename, "r") as in_file:
            with open("prelim_map.txt", "w") as out_file:
                scan_fasta_file(in_file, out_file, lenient=True)
        files_to_remove = glob.glob("*.gz")
    elif args.library in ["UniVec", "UniVec_Core"]:
        if args.protein:
            LOG.error(
                "{:s} is for nucleotide databases only\n".format(args.library)
            )
            sys.exit(1)
        library_pathname = os.path.join(library_pathname, args.library)
        os.makedirs(library_pathname, exist_ok=True)
        os.chdir(library_pathname)
        ftp = FTP(NCBI_SERVER)
        ftp.download("pub/UniVec", args.library)
        ftp.close()
        special_taxid = 28384
        LOG.info(
            "Adding taxonomy ID of {:d} to all sequences\n".format(
                special_taxid
            )
        )
        header_prefix = ">kraken:taxid|" + str(special_taxid) + "|"
        with open(args.library, "r") as in_file:
            with open("library.fna", "w") as out_file:
                for line in in_file:
                    if line.startswith(">"):
                        # Tag only the leading ">" of the header.
                        line = header_prefix + line[1:]
                    out_file.write(line)
        with open("library.fna", "r") as in_file:
            with open("prelim_map.txt", "w") as out_file:
                scan_fasta_file(in_file, out_file)
        files_to_remove = [args.library]
    else:
        LOG.error("Unknown library: {:s}\n".format(args.library))
        sys.exit(1)
    if not args.no_masking:
        mask_files(
            [library_filename], library_filename + ".masked", args.protein
        )
        shutil.move(library_filename + ".masked", library_filename)
    LOG.info("Added {:s} to {:s}\n".format(args.library, args.db))
    if files_to_remove:
        clean_up(files_to_remove)


def get_abs_path(filename):
    """Return *filename* as a normalized absolute path (relative to the cwd)."""
    absolute = os.path.abspath(filename)
    return absolute


def is_compressed(filename):
    """Return True if *filename* begins with a bzip2, gzip or xz magic number."""
    signatures = (
        b"\x42\x5A\x68",  # bzip2 ("BZh")
        b"\x1F\x8B",  # gzip
        b"\xFD\x37\x7A\x58\x5A\x00",  # xz
    )
    with open(filename, "rb") as handle:
        head = handle.read(max(len(sig) for sig in signatures))
    return head.startswith(signatures)


def get_reader(filename):
    """Pick the function that should be used to open *filename*.

    The first bytes of the file are compared with the bzip2, gzip and xz
    signatures; the matching ``bz2.open`` / ``gzip.open`` / ``lzma.open``
    is returned, or the builtin ``open`` when none match.
    """
    openers_by_magic = (
        (b"\x42\x5A\x68", bz2.open),
        (b"\x1F\x8B", gzip.open),
        (b"\xFD\x37\x7A\x58\x5A\x00", lzma.open),
    )
    probe_len = max(len(magic) for magic, _ in openers_by_magic)
    with open(filename, "rb") as handle:
        head = handle.read(probe_len)
    for magic, opener in openers_by_magic:
        if head.startswith(magic):
            return opener
    return open


def read_from_files(filename1, filename2=None):
    """Yield lines from one file, or line pairs from two files in lockstep.

    Each file may be plain or bz2/gzip/xz compressed; the opener is chosen
    by get_reader() from the file's leading bytes.  Files are opened in
    binary mode, so the yielded items are ``bytes``.

    :param filename1: path of the first (or only) input file
    :param filename2: optional path of the mate file for paired input
    :yields: lines of *filename1* when *filename2* is None, otherwise
        ``(line1, line2)`` tuples taken from both files in parallel
    Exits the process with status 1 if the two files do not have the same
    number of lines.
    """
    reader1 = get_reader(filename1)

    if filename2 is None:
        with reader1(filename1, "rb") as f:
            yield from f
        return

    reader2 = get_reader(filename2)
    with reader1(filename1, "rb") as f1, reader2(filename2, "rb") as f2:
        for seq1, seq2 in itertools.zip_longest(f1, f2):
            # zip_longest pads the exhausted side with None, so a None on
            # one side means the *other* file still has data.
            if seq1 is None:
                LOG.error(
                    "{} contains more sequences than {}".format(filename2, filename1)
                )
                sys.exit(1)
            if seq2 is None:
                LOG.error(
                    "{} contains more sequences than {}".format(filename1, filename2)
                )
                sys.exit(1)
            yield (seq1, seq2)


def write_to_fifo(filenames, fifo1=None, fifo2=None):
    """Stream the contents of *filenames* into one or two named pipes.

    Without *fifo2*, every file is written in order to *fifo1*.  With
    *fifo2*, *filenames* is consumed pairwise (mate 1, mate 2, ...) and
    each pair is copied line-for-line to *fifo1* and *fifo2*.  Files are
    read through read_from_files(), so compressed inputs are decompressed
    on the fly.
    """
    if fifo2 is None:
        with open(fifo1, "wb") as sink:
            for name in filenames:
                for chunk in read_from_files(name):
                    sink.write(chunk)
        return

    first_mates = filenames[0::2]
    second_mates = filenames[1::2]
    with open(fifo1, "wb") as sink1, open(fifo2, "wb") as sink2:
        for mate1, mate2 in zip(first_mates, second_mates):
            for line1, line2 in read_from_files(mate1, mate2):
                sink1.write(line1)
                sink2.write(line2)


def build_kraken2_db(args):
    """Build hash.k2d, taxo.k2d and opts.k2d for the database at args.db.

    The process changes its working directory to ``args.db``, which must
    already contain ``taxonomy/`` and ``library/`` subdirectories.

    Steps:
      1. Concatenate ``library/added/prelim_map_*.txt`` (if any) into
         ``library/added/prelim_map.txt``.
      2. Unless ``seqid2taxid.map`` exists, gather every
         ``prelim_map.txt`` under ``library/`` into
         ``taxonomy/prelim_map.txt``.  Copy its ``TAXID`` rows into
         ``seqid2taxid.map``.  Pass its ``ACCNUM`` rows, together with
         ``taxonomy/*.accession2taxid``, to lookup_accession_numbers().
      3. Run ``estimate_capacity`` over the library FASTA files to size
         the hash table, scaled by ``args.load_factor``.
      4. Unless ``hash.k2d`` exists, run ``build_db`` and move its
         ``*.tmp`` outputs into place.

    Exits the process with status 1 if a required directory or map is
    missing, or if either helper binary exits with a non-zero status.
    """
    if not os.path.isdir(get_abs_path(args.db)):
        LOG.error('Cannot find Kraken 2 database: "{:s}"\n'.format(args.db))
        sys.exit(1)
    os.chdir(args.db)
    if not os.path.isdir("taxonomy"):
        LOG.error("Cannot find taxonomy subdirectory in database\n")
        sys.exit(1)
    if not os.path.isdir("library"):
        LOG.error("Cannot find library subdirectory in database\n")
        sys.exit(1)

    added_dirpath = os.path.join("library", "added")
    if os.path.isdir(added_dirpath):
        prelim_map_filenames = glob.glob(
            os.path.join(added_dirpath, "prelim_map_*.txt")
        )
        if prelim_map_filenames:
            with open(
                os.path.join(added_dirpath, "prelim_map.txt"), "w"
            ) as out_file:
                for filename in prelim_map_filenames:
                    with open(filename, "r") as in_file:
                        shutil.copyfileobj(in_file, out_file)

    if os.path.isfile("seqid2taxid.map"):
        LOG.info("Sequence ID to taxonomy ID map already present, skipping\n")
    else:
        LOG.info("Creating sequence ID to taxonomy ID map\n")
        merged_prelim_map = os.path.join("taxonomy", "prelim_map.txt")
        with open(merged_prelim_map, "w") as out_file:
            for dirpath, _dirnames, filenames in os.walk("library"):
                if "prelim_map.txt" in filenames:
                    # Close each per-library map as soon as it is copied.
                    with open(
                        os.path.join(
                            os.path.abspath(dirpath), "prelim_map.txt"
                        ),
                        "r",
                    ) as in_file:
                        shutil.copyfileobj(in_file, out_file)
        if os.path.getsize(merged_prelim_map) == 0:
            os.remove(merged_prelim_map)
            LOG.error(
                "No preliminary seqid/taxid mapping files found, aborting\n"
            )
            sys.exit(1)
        # Split rows by their first column: TAXID rows already carry a
        # taxid, ACCNUM rows still need an accession-number lookup.
        with open(merged_prelim_map, "r") as in_file:
            with open("seqid2taxid.map.tmp", "w") as seqid2taxid_file:
                with open("accmap.tmp", "w") as accmap_file:
                    for line in in_file:
                        new_line = "\t".join(line.split("\t")[1:])
                        if line.startswith("TAXID"):
                            seqid2taxid_file.write(new_line)
                        elif line.startswith("ACCNUM"):
                            accmap_file.write(new_line)
        if os.path.getsize("accmap.tmp") > 0:
            accession2taxid_filenames = glob.glob("taxonomy/*.accession2taxid")
            if accession2taxid_filenames:
                lookup_accession_numbers(
                    "accmap.tmp",
                    "seqid2taxid.map.tmp",
                    *accession2taxid_filenames
                )
            else:
                LOG.error(
                    "Accession to taxid map files are required to build this database.\n"
                )
                LOG.error(
                    "Run k2 download-taxonomy --db {:s} again".format(args.db)
                )
                sys.exit(1)
        os.remove("accmap.tmp")
        move("seqid2taxid.map.tmp", "seqid2taxid.map")
        LOG.info("Created sequence ID to taxonomy ID map\n")

    estimate_capacity_binary = find_kraken2_binary("estimate_capacity")
    argv = [estimate_capacity_binary, "-S", construct_seed_template(args)]
    if args.protein:
        argv.append("-X")
    wrapper_args_to_binary_args(
        args, argv, get_binary_options(estimate_capacity_binary)
    )
    fasta_filenames = glob.glob(
        os.path.join("library", os.path.join("**", "*.f[an]a")), recursive=True
    )
    # Depending on dwk2(), the FASTA files are either passed as
    # command-line arguments or streamed to estimate_capacity on stdin.
    if not dwk2():
        argv.extend(fasta_filenames)
    LOG.info("Running: " + " ".join(argv) + "\n")
    proc = subprocess.Popen(
        argv, stdin=subprocess.PIPE, stdout=subprocess.PIPE
    )
    if dwk2():
        for filename in fasta_filenames:
            with open(filename, "rb") as in_file:
                shutil.copyfileobj(in_file, proc.stdin, 8192)
    # communicate() closes stdin and collects the estimate from stdout.
    estimate_output = proc.communicate()[0]
    if proc.returncode != 0:
        LOG.error(
            "estimate_capacity exited with status {:d}\n".format(
                proc.returncode
            )
        )
        sys.exit(1)
    estimate = estimate_output.decode()
    required_capacity = (int(estimate.strip()) + 8192) / args.load_factor
    LOG.info(
        "Estimated hash table requirement: {:s}\n".format(
            format_bytes(required_capacity * 4)
        )
    )
    if args.max_db_size:
        if args.max_db_size < required_capacity * 4:
            args.max_db_size = int(args.max_db_size / 4)
            LOG.warning(
                "Specifying lower maximum hash table size of {:d}\n".format(
                    args.max_db_size
                )
            )

    if os.path.isfile("hash.k2d"):
        LOG.info("Hash table already present, skipping build\n")
        return

    LOG.info("Starting database build\n")
    build_db_bin = find_kraken2_binary("build_db")
    argv = [
        build_db_bin,
        "-H",
        "hash.k2d.tmp",
        "-t",
        "taxo.k2d.tmp",
        "-o",
        "opts.k2d.tmp",
        "-n",
        "taxonomy",
        "-m",
        "seqid2taxid.map",
        "-c",
        str(required_capacity),
        "-S",
        construct_seed_template(args),
    ]
    if args.protein:
        argv.append("-X")
    wrapper_args_to_binary_args(
        args, argv, get_binary_options(build_db_bin)
    )

    LOG.info("Running: " + " ".join(argv) + "\n")
    proc = subprocess.Popen(
        argv, stdin=subprocess.PIPE, stdout=subprocess.PIPE
    )
    for filename in fasta_filenames:
        with open(filename, "rb") as in_file:
            shutil.copyfileobj(in_file, proc.stdin, 8192)
    proc.stdin.close()
    # Do not install partial/missing outputs if the build failed.
    if proc.wait() != 0:
        LOG.error(
            "build_db exited with status {:d}\n".format(proc.returncode)
        )
        sys.exit(1)

    move("hash.k2d.tmp", "hash.k2d")
    move("taxo.k2d.tmp", "taxo.k2d")
    move("opts.k2d.tmp", "opts.k2d")
    LOG.info("Finished building database\n")


# Parses RDP sequence data to create Kraken taxonomy
# and sequence ID -> taxonomy ID mapping
# Input: an iterable of lines from current_{Archaea,Bacteria}_unaligned.fa
def build_rdp_taxonomy(f):
    seqid_map = {}
    seen_it = {}
    child_data = {"root;no rank": {}}

    for line in f:
        if not line.startswith(">"):
            continue
        line = line.strip()
        seq_label, taxonomy_string = line.split("\t")
        seqid = seq_label.split(" ")[0]
        taxonomy_string = re.sub(
            "^Lineage=Root;rootrank;", "root;no rank;", taxonomy_string
        )
        taxonomy_string = re.sub(";$", ";no rank", taxonomy_string)
        seqid_map[seqid] = taxonomy_string
        seen_it.setdefault(taxonomy_string, 0)
        seen_it[taxonomy_string] += 1
        if seen_it[taxonomy_string] > 1:
            continue
        while True:
            match = re.search("(;[^;]+;[^;]+)$", taxonomy_string)
            if match is None:
                break
            level = match.group(1)
            taxonomy_string = re.sub(";[^;]+;[^;]+$", "", taxonomy_string)
            key = taxonomy_string + level
            child_data.setdefault(taxonomy_string, {})
            seen_it.setdefault(taxonomy_string, 0)
            child_data[taxonomy_string].setdefault(key, 0)
            child_data[taxonomy_string][key] += 1
            seen_it[taxonomy_string] += 1
            if seen_it[taxonomy_string] > 1:
                break
    id_map = {}
    next_node_id = 1
    with open("names.dmp", "w") as names_file:
        with open("nodes.dmp", "w") as nodes_file:
            bfs_queue = [["root;no rank", 1]]
            while len(bfs_queue) > 0:
                node, parent_id = bfs_queue.pop()
                match = re.search("([^;]+);([^;]+)$", node)
                if match is None:
                    LOG.error(
                        'BFS processing encountered formatting eror, "{:s}"\n'.format(
                            node
                        )
                    )
                    sys.exit(1)
                display_name, rank = match.group(1), match.group(2)
                if rank == "domain":
                    rank = "superkingdom"
                node_id, next_node_id = next_node_id, next_node_id + 1
                id_map[node] = node_id
                names_file.write(
                    "{:d}\t|\t{:s}\t|\t-\t|\tscientific name\t|\n".format(
                        node_id, display_name
                    )
                )
                nodes_file.write(
                    "{:d}\t|\t{:d}\t|\t{:s}\t|\t-\t|\n".format(
                        node_id, parent_id, rank
                    )
                )
                children = (
                    sorted([key for key in child_data[node]])
                    if node in child_data
                    else []
                )
                for node in children:
                    bfs_queue.insert(0, [node, node_id])
    with open("seqid2taxid.map", "w") as f:
        for seqid in sorted([key for key in seqid_map]):
            taxid = id_map[seqid_map[seqid]]
            f.write("{:s}\t{:d}\n".format(seqid, taxid))


# Build the standard Kraken database
def build_standard_database(args):
    """Download the taxonomy and the standard libraries, then build the DB.

    The libraries are archaea, bacteria, viral, plasmid, human and
    UniVec_Core, in that order; UniVec_Core is skipped when
    ``args.protein`` is set.  ``args.library`` is overwritten with each
    library name in turn before download_genomic_library() is called.
    """
    download_taxonomy(args)
    standard_libraries = (
        "archaea",
        "bacteria",
        "viral",
        "plasmid",
        "human",
        "UniVec_Core",
    )
    for library in standard_libraries:
        nucleotide_only = library == "UniVec_Core"
        if nucleotide_only and args.protein:
            continue
        args.library = library
        download_genomic_library(args)
    build_kraken2_db(args)


# Parses Silva taxonomy file to create Kraken taxonomy
# Input: an iterable of lines from a SILVA taxonomy file (e.g. tax_slv_ssu_138.1.txt)
def build_silva_taxonomy(in_file):
    """Write names.dmp and nodes.dmp from a SILVA taxonomy listing.

    :param in_file: iterable of tab-separated text lines whose first three
        fields are "<path>", "<taxid>", "<rank>".  <path> is a
        ";"-terminated lineage such as "Bacteria;Firmicutes;".
    A root node (ID 1) is written first, and top-level entries are
    attached to it.  A parent's path must appear on an earlier line than
    its children.  Rank "domain" is written as "superkingdom".  Files are
    written to the current working directory.

    Exits the process with status 1 when a path is malformed ("strange
    input") or when its parent path has not been seen ("orphan error").
    """
    id_map = {"root": 1}
    with open("names.dmp", "w") as names_file:
        with open("nodes.dmp", "w") as nodes_file:
            names_file.write("1\t|\troot\t|\t-\t|\tscientific name\t|\n")
            nodes_file.write("1\t|\t1\t|\tno rank\t|\t-\t|\n")
            for line in in_file:
                line = line.strip()
                taxonomy_string, node_id, rank = line.split("\t")[:3]
                id_map[taxonomy_string] = node_id
                match = re.search("^(.+;|)([^;]+);$", taxonomy_string)
                if not match:
                    LOG.error('strange input: "{:s}"\n'.format(line))
                    sys.exit(1)
                parent_name = match.group(1)
                display_name = match.group(2)
                if parent_name == "":
                    parent_name = "root"
                # .get() so an unseen parent reaches the orphan error below
                # instead of escaping as a KeyError.
                parent_id = id_map.get(parent_name)
                if not parent_id:
                    LOG.error('orphan error: "{:s}"\n'.format(line))
                    sys.exit(1)
                if rank == "domain":
                    rank = "superkingdom"
                names_file.write(
                    "{:s}\t|\t{:s}\t|\t-\t|\tscientific name\t|\n".format(
                        node_id, display_name
                    )
                )
                nodes_file.write(
                    "{:s}\t|\t{:s}\t|\t{:s}\t|\t-\t|\n".format(
                        node_id, str(parent_id), rank
                    )
                )


# Build a 16S database from Silva data
def build_16S_silva(args):
    """Download SILVA 138.1 SSU data into args.db and build a database.

    Creates ``data/``, ``taxonomy/`` and ``library/`` under ``args.db``.
    The downloads land in ``data/``.  names.dmp and nodes.dmp are moved
    to ``taxonomy/``, and the accession-to-taxid file becomes
    ``seqid2taxid.map``.  The sequences are written to
    ``library/silva.fna`` with U replaced by T on sequence lines.
    Low-complexity masking is applied unless ``args.no_masking`` is set.
    The function then returns to the caller's working directory and runs
    build_kraken2_db(args).
    """
    # Remember where we started: build_kraken2_db resolves args.db from
    # the current directory, and stepping back with os.pardir only undoes
    # a single-component relative path.
    original_cwd = os.getcwd()
    os.makedirs(args.db, exist_ok=True)
    os.chdir(args.db)
    for directory in ["data", "taxonomy", "library"]:
        os.makedirs(directory, exist_ok=True)
    os.chdir("data")
    remote_directory = "/release_138_1/Exports"
    fasta_filename = "SILVA_138.1_SSURef_NR99_tax_silva.fasta.gz"
    taxonomy_prefix = "tax_slv_ssu_138.1"
    ftp = FTP(SILVA_SERVER)
    ftp.download(remote_directory, fasta_filename)
    ftp.download(
        remote_directory + "/taxonomy", taxonomy_prefix + ".acc_taxid.gz"
    )
    decompress_files([taxonomy_prefix + ".acc_taxid.gz"])
    ftp.download(remote_directory + "/taxonomy", taxonomy_prefix + ".txt.gz")
    with gzip.open(taxonomy_prefix + ".txt.gz", "rt") as f:
        build_silva_taxonomy(f)
    os.chdir(os.path.pardir)
    move(os.path.join("data", "names.dmp"), "taxonomy")
    move(os.path.join("data", "nodes.dmp"), "taxonomy")
    move(
        os.path.join("data", taxonomy_prefix + ".acc_taxid"), "seqid2taxid.map"
    )
    # Convert the RNA alphabet to DNA on sequence lines; headers are kept.
    with gzip.open(os.path.join("data", fasta_filename), "rt") as in_file:
        with open(os.path.join("library", "silva.fna"), "w") as out_file:
            for line in in_file:
                if not line.startswith(">"):
                    line = line.replace("U", "T")
                out_file.write(line)
    if not args.no_masking:
        filename = os.path.join("library", "silva.fna")
        mask_files([filename], filename + ".masked")
        shutil.move(filename + ".masked", filename)

    os.chdir(original_cwd)
    build_kraken2_db(args)


# Parses Greengenes taxonomy file to create Kraken taxonomy
# and sequence ID -> taxonomy ID mapping
# Input: an iterable of lines from gg_13_5_taxonomy.txt
def build_gg_taxonomy(in_file):
    """Write names.dmp, nodes.dmp and seqid2taxid.map from Greengenes data.

    :param in_file: iterable of "<seqid><TAB><lineage>" text lines.
        <lineage> is a "; "-separated list of "x__name" levels, for
        example "k__Bacteria; p__Firmicutes".  Trailing empty levels such
        as "; c__" are dropped.  Sequence IDs must be integer strings,
        because the map is written in numeric order.

    Taxonomy IDs are assigned breadth-first from a synthetic "root" node
    (ID 1), visiting siblings in sorted order.  Single-letter level codes
    are expanded through ``rank_names``.  A node ending in
    "g__<genus>; s__<species>" becomes a species named
    "<genus> <species>", or just "<species>" when that contains
    " endosymbiont ".  Files are written to the current working
    directory.
    """
    rank_names = {
        "k": "superkingdom",
        "p": "phylum",
        "c": "class",
        "o": "order",
        "f": "family",
        "g": "genus",
        "s": "species",
    }
    lineage_of = {}
    times_seen = {}
    children_of = {"root": {}}
    for record in in_file:
        seqid, lineage = record.strip().split("\t")
        lineage = re.sub("(; [a-z]__)+$", "", lineage)
        lineage_of[seqid] = lineage
        times_seen[lineage] = times_seen.get(lineage, 0) + 1
        if times_seen[lineage] > 1:
            continue
        # Peel off the last level repeatedly, recording parent -> child
        # edges, until an already-processed ancestor is reached.
        while True:
            last_level = re.search("(; [a-z]__[^;]+$)", lineage)
            if not last_level:
                break
            lineage = re.sub("(; [a-z]__[^;]+$)", "", lineage)
            edges = children_of.setdefault(lineage, {})
            child = lineage + last_level.group(1)
            edges[child] = edges.get(child, 0) + 1
            times_seen[lineage] = times_seen.get(lineage, 0) + 1
            if times_seen[lineage] > 1:
                break
        # A count of exactly 1 means peeling stopped because no level
        # could be removed, so what remains hangs directly under root.
        if times_seen[lineage] == 1:
            top_level = children_of["root"]
            top_level[lineage] = top_level.get(lineage, 0) + 1

    taxid_of = {}
    next_taxid = 1
    with open("names.dmp", "w") as names_out:
        with open("nodes.dmp", "w") as nodes_out:
            # FIFO traversal: read at `head`, append children at the tail.
            queue = [("root", 1)]
            head = 0
            while head < len(queue):
                node, parent_id = queue[head]
                head += 1
                rank = None
                display_name = node
                species_match = re.search("g__([^;]+); s__([^;]+)$", node)
                if species_match:
                    genus, species = species_match.groups()
                    rank = "species"
                    display_name = (
                        species
                        if re.search(" endosymbiont ", species)
                        else genus + " " + species
                    )
                else:
                    level_match = re.search("([a-z])__([^;]+)$", node)
                    if level_match:
                        rank = rank_names[level_match.group(1)]
                        display_name = level_match.group(2)
                if not rank:
                    rank = "no rank"
                node_id = next_taxid
                next_taxid += 1
                taxid_of[node] = node_id
                names_out.write(
                    "{:d}\t|\t{:s}\t|\t-\t|\tscientific name\t|\n".format(
                        node_id, display_name
                    )
                )
                nodes_out.write(
                    "{:d}\t|\t{:d}\t|\t{:s}\t|\t-\t|\n".format(
                        node_id, parent_id, rank
                    )
                )
                queue.extend(
                    (child, node_id)
                    for child in sorted(children_of.get(node, {}))
                )

    with open("seqid2taxid.map", "w") as map_out:
        for seqid in sorted(lineage_of, key=int):
            map_out.write(
                "{:s}\t{:d}\n".format(seqid, taxid_of[lineage_of[seqid]])
            )


# Build a 16S database from Greengenes data
def build_16S_gg(args):
    """Download Greengenes gg_13_5 data into args.db and build a database.

    Creates ``data/``, ``taxonomy/`` and ``library/`` under ``args.db``.
    The FASTA and taxonomy files are downloaded and decompressed into
    ``data/``.  build_gg_taxonomy() generates names.dmp, nodes.dmp and
    seqid2taxid.map, which are moved into place.  The sequences are moved
    to ``library/gg.fna``, with low-complexity masking unless
    ``args.no_masking`` is set.  The function then returns to the
    caller's working directory and runs build_kraken2_db(args).
    """
    # Remember where we started: build_kraken2_db resolves args.db from
    # the current directory, and stepping back with os.pardir only undoes
    # a single-component relative path.
    original_cwd = os.getcwd()
    os.makedirs(args.db, exist_ok=True)
    gg_version = "gg_13_5"
    remote_directory = "/greengenes_release/" + gg_version
    os.chdir(args.db)
    for directory in ["data", "taxonomy", "library"]:
        os.makedirs(directory, exist_ok=True)
    os.chdir("data")
    ftp = FTP(GREENGENES_SERVER)
    ftp.download(remote_directory, gg_version + ".fasta.gz")
    decompress_files([gg_version + ".fasta.gz"])
    ftp.download(remote_directory, gg_version + "_taxonomy.txt.gz")
    decompress_files([gg_version + "_taxonomy.txt.gz"])
    with open(gg_version + "_taxonomy.txt", "r") as f:
        build_gg_taxonomy(f)
    os.chdir(os.path.abspath(os.path.pardir))
    move(os.path.join("data", "names.dmp"), "taxonomy")
    move(os.path.join("data", "nodes.dmp"), "taxonomy")
    move(os.path.join("data", "seqid2taxid.map"), os.getcwd())
    move(
        os.path.join("data", gg_version + ".fasta"),
        os.path.join("library", "gg.fna"),
    )
    if not args.no_masking:
        filename = os.path.join("library", "gg.fna")
        mask_files([filename], filename + ".masked")
        shutil.move(filename + ".masked", filename)
    os.chdir(original_cwd)
    build_kraken2_db(args)


# Build a 16S database from RDP data
def build_16S_rdp(args):
    """Download RDP Bacteria/Archaea sequences into args.db and build a DB.

    Creates ``data/``, ``taxonomy/`` and ``library/`` under ``args.db``.
    Both unaligned FASTA files are downloaded and decompressed into
    ``data/``.  One taxonomy (names.dmp, nodes.dmp, seqid2taxid.map) is
    built from all of them together.  Each FASTA is then moved to
    ``library/`` with a ``.fna`` extension, masked unless
    ``args.no_masking`` is set.  The function then returns to the
    caller's working directory and runs build_kraken2_db(args).
    """
    # Remember where we started so build_kraken2_db resolves args.db the
    # same way the caller did.
    original_cwd = os.getcwd()
    os.makedirs(args.db, exist_ok=True)
    os.chdir(args.db)
    for directory in ["data", "taxonomy", "library"]:
        os.makedirs(directory, exist_ok=True)
    os.chdir("data")
    download_file(
        "http://rdp.cme.msu.edu/download/current_Bacteria_unaligned.fa.gz"
    )
    download_file(
        "http://rdp.cme.msu.edu/download/current_Archaea_unaligned.fa.gz"
    )
    decompress_files(glob.glob("*gz"))
    fasta_filenames = glob.glob("current_*_unaligned.fa")

    def _all_lines():
        # build_rdp_taxonomy rewrites its output files on every call, so
        # feed it every downloaded FASTA in a single call; one call per
        # file would keep only the last file's taxonomy.
        for filename in fasta_filenames:
            with open(filename, "r") as f:
                yield from f

    build_rdp_taxonomy(_all_lines())
    os.chdir(os.pardir)
    move(os.path.join("data", "names.dmp"), "taxonomy")
    move(os.path.join("data", "nodes.dmp"), "taxonomy")
    move(os.path.join("data", "seqid2taxid.map"), os.getcwd())
    for filename in glob.glob(os.path.join("data", "*.fa")):
        new_filename = os.path.basename(re.sub(r"\.fa$", ".fna", filename))
        shutil.move(filename, os.path.join("library", new_filename))
        if not args.no_masking:
            new_filename = os.path.join("library", new_filename)
            mask_files([new_filename], new_filename + ".masked")
            shutil.move(new_filename + ".masked", new_filename)

    os.chdir(original_cwd)
    build_kraken2_db(args)


# Reads multi-FASTA input and examines each sequence header.  In quiet
# mode headers are OK if a taxonomy ID is found (as either the entire
# sequence ID or as part of a "kraken:taxid" token), or if something
# looking like a GI or accession number is found.  In normal mode, the
# taxonomy ID will be looked up (if not explicitly specified in the
# sequence ID) and reported if it can be found.  Output is
# tab-delimited lines, with sequence IDs in first column and taxonomy
# IDs in second.


# Sequence IDs with a kraken:taxid token will use that to assign taxonomy
# ID, e.g.:
# >gi|32499|ref|NC_021949.2|kraken:taxid|562|
#
# Sequence IDs that are completely numeric are assumed to be the taxonomy
# ID for that sequence.
#
# Otherwise, an accession number is searched for; if not found, a GI
# number is searched for.  Failure to find any of the above is a fatal error.
# Without `quiet`, the library map (`library_map_filename`) and the
# accession-to-taxid files (`accession_map_filenames`, which cover both
# accession numbers and GI numbers) are examined; failure to find a
# taxonomy ID that maps to a provided accession/GI number is non-fatal.
#
# With `quiet`, no mappings are printed and accession/GI numbers are not
# looked up.
#
def make_seqid_to_taxid_map(
    in_file, quiet, accession_map_filenames=False, library_map_filename=None
):
    """Print a ``seqid<TAB>taxid`` line for each FASTA header in *in_file*.

    :param in_file: iterable of text lines.  Only lines starting with ">"
        are examined, and the sequence ID is the first non-whitespace
        token after ">".
    :param quiet: if true, print no mappings.  Only check that every header
        carries a taxid, an accession number or a GI number, and log
        "External map required" if some header needs an accession/GI
        lookup.
    :param accession_map_filenames: iterable of accession-to-taxid files.
        Each has a header line, then tab-separated rows of accession,
        accession.version, taxid and gi.  A falsy value means none.
    :param library_map_filename: optional ``seqid<TAB>taxid`` file that is
        consulted before the accession maps.

    Exits via sys.exit(1) if a header has no recognizable identifier, or
    if lookups are needed but no map was supplied.  Exits via sys.exit(0)
    once nothing is left to resolve.  Otherwise it returns after scanning
    the accession maps, logging a warning if some identifiers were not
    found.
    """
    # Tried in order; the first pattern found *anywhere* in the sequence
    # ID decides how it is mapped.  re.search (not re.match) so tokens
    # after a leading prefix are found, e.g. the kraken:taxid token in
    # "gi|32499|ref|NC_021949.2|kraken:taxid|562|".
    patterns = [
        re.compile(r"(?:^|\|)kraken:taxid\|(\d+)"),  # explicit taxid
        re.compile(r"^\d+$"),  # the whole ID is a taxid
        re.compile(r"(?:^|\|)([A-Z]+_?[A-Z0-9]+)(?:\||\b|\.)"),  # accession
        re.compile(r"(?:^|\|)gi\|(\d+)"),  # GI number
    ]
    # accession/GI number -> sequence IDs waiting for that lookup
    target_lists = {}
    for line in in_file:
        header = re.match(r">(\S+)", line)
        if header is None:
            continue
        seqid = header.group(1)
        match = None
        index = None
        for i, pattern in enumerate(patterns):
            match = pattern.search(seqid)
            if match:
                index = i
                break
        if index == 0:
            if not quiet:
                print(seqid + "\t" + match.group(1))
        elif index == 1:
            if not quiet:
                print(seqid + "\t" + seqid)
        elif index in (2, 3):
            target_lists.setdefault(match.group(1), []).insert(0, seqid)
        else:
            LOG.error(
                "Unable to determine taxonomy ID for sequence {:s}\n".format(
                    seqid
                )
            )
            sys.exit(1)
    if quiet:
        # Headers that only carry an accession/GI number need a map.
        if target_lists:
            LOG.error("External map required\n")
        sys.exit(0)
    if not target_lists:
        sys.exit(0)
    if not accession_map_filenames and library_map_filename is None:
        LOG.error(
            "Found sequence ID without explicit taxonomy ID, but no map used\n"
        )
        sys.exit(1)
    # Remove targets where we've already handled the mapping
    if library_map_filename:
        with open(library_map_filename, "r") as f:
            for line in f:
                seqid, taxid = line.strip().split("\t")
                if seqid in target_lists:
                    print("{:s}\t{:s}".format(seqid, taxid))
                    del target_lists[seqid]
    if not target_lists:
        sys.exit(0)
    for filename in accession_map_filenames or ():
        with open(filename, "r") as f:
            f.readline()  # skip the header line
            for line in f:
                accession, _with_version, taxid, gi = line.strip().split("\t")
                if accession in target_lists:
                    for seqid in target_lists.pop(accession):
                        print("{:s}\t{:s}".format(seqid, taxid))
                if gi != "na" and gi in target_lists:
                    for seqid in target_lists.pop(gi):
                        print("{:s}\t{:s}".format(seqid, taxid))
    if target_lists:
        LOG.warning(
            "Unable to find taxonomy IDs for {:d} accession/GI numbers\n".format(
                len(target_lists)
            )
        )


def classify(args):
    """Classify ``args.filenames`` against the Kraken 2 database ``args.db``.

    The arguments are validated first.  The ``classify`` command line is
    then built from the database's hash/taxo/opts files plus the wrapper
    options, and the binary is run.  If any input is bz2/gzip/xz
    compressed (per is_compressed), the inputs are decompressed here and
    streamed to the binary through FIFOs in a temporary directory: one
    FIFO, or two when ``-P`` is among the options.

    Ends the process via sys.exit with the binary's exit status.
    Validation failures exit with status 1.
    """
    classify_bin = find_kraken2_binary("classify")
    database_path = find_database(args.db)
    if database_path is None:
        LOG.error("{:s} is not a valid database... exiting".format(args.db))
        sys.exit(1)
    # `"paired" in args` only tests that the attribute exists, which is
    # also true when the option is defined but was not given; test the
    # value instead.
    if getattr(args, "paired", False) and len(args.filenames) % 2 != 0:
        LOG.error("--paired requires an even number of file names")
        sys.exit(1)
    if args.confidence < 0 or args.confidence > 1:
        LOG.error(
            "--confidence, {:f}, must be between 0 and 1 inclusive".format(
                args.confidence
            )
        )
        sys.exit(1)
    argv = [
        classify_bin,
        "-H",
        os.path.join(database_path, "hash.k2d"),
        "-t",
        os.path.join(database_path, "taxo.k2d"),
        "-o",
        os.path.join(database_path, "opts.k2d"),
    ]
    wrapper_args_to_binary_args(args, argv, get_binary_options(classify_bin))

    if not any(is_compressed(filename) for filename in args.filenames):
        argv.extend(args.filenames)
        sys.exit(subprocess.call(argv))

    # Filled in by the worker thread with the classifier's exit status.
    exit_statuses = []

    def run_classifier():
        exit_statuses.append(subprocess.call(argv))

    with tempfile.TemporaryDirectory() as temp_dir_name:
        fifo1_pathname = os.path.join(temp_dir_name, "fifo1")
        fifo2_pathname = None
        try:
            os.mkfifo(fifo1_pathname, 0o600)
        except OSError:
            LOG.error(
                "Unable to create FIFO for processing compressed files"
            )
            sys.exit(1)
        if "-P" in argv:
            fifo2_pathname = os.path.join(temp_dir_name, "fifo2")
            try:
                os.mkfifo(fifo2_pathname, 0o600)
            except OSError:
                LOG.error(
                    "Unable to create FIFO for processing compressed files"
                )
                sys.exit(1)
            argv.extend([fifo1_pathname, fifo2_pathname])
        else:
            argv.append(fifo1_pathname)
        # The binary must run concurrently with the writer: opening a FIFO
        # for writing blocks until a reader opens the other end.
        thread = threading.Thread(target=run_classifier)
        thread.start()
        write_to_fifo(args.filenames, fifo1_pathname, fifo2_pathname)
        thread.join()
    # Report failure if the worker died before recording a status.
    sys.exit(exit_statuses[0] if exit_statuses else 1)


def inspect_db(args):
    """Dump the contents of the Kraken 2 database ``args.db`` via dump_table.

    Exits with status 1 if the database cannot be found or if any of
    taxo.k2d, hash.k2d or opts.k2d is missing.  Otherwise it exits with
    dump_table's exit status.
    """
    database_pathname = find_database(args.db)
    if not database_pathname:
        LOG.error("{:s} database does not exist\n".format(args.db))
        sys.exit(1)
    missing_files = [
        database_file
        for database_file in ["taxo.k2d", "hash.k2d", "opts.k2d"]
        if not os.path.isfile(os.path.join(database_pathname, database_file))
    ]
    for database_file in missing_files:
        LOG.error("{:s} does not exist\n".format(database_file))
    # dump_table needs all three files; stop instead of running it anyway.
    if missing_files:
        sys.exit(1)
    dump_table_bin = find_kraken2_binary("dump_table")
    argv = [
        dump_table_bin,
        "-H",
        os.path.join(database_pathname, "hash.k2d"),
        "-t",
        os.path.join(database_pathname, "taxo.k2d"),
        "-o",
        os.path.join(database_pathname, "opts.k2d"),
    ]
    wrapper_args_to_binary_args(args, argv, get_binary_options(dump_table_bin))
    sys.exit(subprocess.call(argv))


def format_bytes(size):
    """Render a byte count with two decimals and a binary-scaled unit.

    The size is divided by 1024 while it is at least 1024, stepping
    through kB, MB, GB, TB, PB and EB; EB is the largest unit used.
    """
    units = ["B", "kB", "MB", "GB", "TB", "PB", "EB"]
    unit_index = 0
    while unit_index < len(units) - 1 and size >= 1024:
        size /= 1024
        unit_index += 1
    return "{:.2f}{:s}".format(size, units[unit_index])


def clean_up(filenames):
    """Delete *filenames* via remove_files() and log the disk space freed.

    The freed amount is the change in free space, as reported by
    shutil.disk_usage, on the filesystem holding the current working
    directory.
    """
    LOG.info("Removing extra files\n")
    free_before = shutil.disk_usage(os.getcwd()).free
    remove_files(*filenames)
    free_after = shutil.disk_usage(os.getcwd()).free
    LOG.info(
        "Cleaned up {:s} of space\n".format(
            format_bytes(free_after - free_before)
        )
    )


def clean_db(args):
    """Remove the intermediate build inputs from the database directory.

    The working directory is changed to ``args.db``.  Then ``data``,
    ``library``, ``taxonomy`` and ``seqid2taxid.map`` are removed through
    clean_up().
    """
    os.chdir(args.db)
    intermediates = ["data", "library", "taxonomy", "seqid2taxid.map"]
    clean_up(intermediates)


def make_build_parser(subparsers):
    """Register the ``build`` sub-command and its options.

    :param subparsers: the object returned by
        ``argparse.ArgumentParser.add_subparsers()``; a ``build`` parser is
        added to it.  Nothing is returned.
    """
    parser = subparsers.add_parser(
        "build",
        help="Create DB from library\
              (requires taxonomy downloaded and at least one file\
              in library)",
    )
    parser.add_argument(
        "--db",
        type=str,
        metavar="PATHNAME",
        required=True,
        help="Name of Kraken 2 database"
    )
    # "special" help group: --standard and --special are mutually
    # exclusive; --no-masking is listed in the same group.
    group = parser.add_argument_group("special")
    mutex_group = group.add_mutually_exclusive_group()
    mutex_group.add_argument(
        "--standard", action="store_true", help="Make standard database"
    )
    mutex_group.add_argument(
        "--special",
        type=str,
        choices=["greengenes", "rdp", "silva"],
        help="Build special database",
    )
    group.add_argument(
        "--no-masking",
        action="store_true",
        help="Avoid masking low-complexity sequences prior to\
              building; masking requires dustmasker or segmasker to be\
              installed",
    )

    parser.add_argument(
        "--kmer-len",
        type=int,
        metavar="INT",
        help="K-mer length in bp/aa"
    )
    parser.add_argument(
        "--minimizer-len",
        type=int,
        metavar="INT",
        help="Minimizer length in bp/aa"
    )
    parser.add_argument(
        "--minimizer-spaces",
        type=int,
        metavar="INT",
        help="Number of characters in minimizer that are\
              ignored in comparisons",
    )
    # A KRAKEN2_NUM_THREADS environment value (a string) replaces the
    # default of 1; argparse converts string defaults with type=int.
    parser.add_argument(
        "--threads",
        type=int,
        metavar="INT",
        default=os.environ.get("KRAKEN2_NUM_THREADS") or 1,
        help="Number of threads",
    )
    # The (0,1] range is only advertised in the metavar; it is not
    # validated here.
    parser.add_argument(
        "--load-factor",
        type=float,
        metavar="FLOAT (0,1]",
        default=0.7,
        help="Proportion of the hash table to be populated",
    )
    parser.add_argument(
        "--fast-build",
        action="store_true",
        help="Do not require database to be deterministically\
              built when using multiple threads. This is faster, but\
              does introduce variability in minimizer/LCA pairs.",
    )
    parser.add_argument(
        "--max-db-size",
        type=int,
        metavar="INT",
        help="Maximum number of bytes for Kraken 2 hash table;\
              if the estimator determines more would normally be\
              needed, the reference library will be downsampled to fit",
    )
    parser.add_argument(
        "--skip-maps",
        action="store_true",
        help="Avoids downloading accession number to taxid maps",
    )
    parser.add_argument(
        "--protein",
        action="store_true",
        help="Build a protein database for translated search",
    )
    parser.add_argument(
        "--block-size",
        type=int,
        metavar="INT",
        default=16384,
        help="Read block size"
    )
    parser.add_argument(
        "--sub-block-size",
        type=int,
        metavar="INT",
        default=0,
        help="Read subblock size"
    )
    parser.add_argument(
        "--minimum-bits-for-taxid",
        type=int,
        metavar="INT",
        default=0,
        help="Bit storage requested for taxid",
    )
    parser.add_argument(
        "--log",
        type=str,
        metavar="FILENAME",
        default=None,
        help="Specify a log file (default: stderr)",
    )


def make_download_taxonomy_parser(subparsers):
    """Register the ``download-taxonomy`` subcommand on *subparsers*.

    The subcommand takes a required ``--db`` path, the ``--protein`` and
    ``--skip-maps`` switches, and an optional ``--log`` filename.
    """
    taxonomy = subparsers.add_parser(
        "download-taxonomy", help="Download NCBI taxonomic information"
    )
    # (flag, add_argument keyword settings), registered in this order.
    option_table = (
        (
            "--db",
            dict(
                type=str,
                metavar="PATHNAME",
                required=True,
                help="Name of Kraken 2 database",
            ),
        ),
        (
            "--protein",
            dict(
                action="store_true",
                help="Files being added are for a protein database",
            ),
        ),
        (
            "--skip-maps",
            dict(
                action="store_true",
                help="Avoids downloading accession number to taxid maps",
            ),
        ),
        (
            "--log",
            dict(
                type=str,
                metavar="FILENAME",
                default=None,
                help="Specify a log filename (default: stderr)",
            ),
        ),
    )
    for flag, settings in option_table:
        taxonomy.add_argument(flag, **settings)


def make_download_library_parser(subparsers):
    """Register the ``download-library`` subcommand on *subparsers*.

    Options: a required ``--db`` path, a required ``--library`` chosen from a
    fixed list of names, the ``--protein`` and ``--no-masking`` switches, and
    an optional ``--log`` filename.
    """
    parser = subparsers.add_parser(
        "download-library", help="Download and build a special database"
    )
    parser.add_argument(
        "--db",
        type=str,
        metavar="PATHNAME",
        required=True,
        help="Name of Kraken 2 database",
    )
    parser.add_argument(
        "--library",
        type=str,
        required=True,
        choices=[
            "archaea",
            "bacteria",
            "plasmid",
            "viral",
            "human",
            "fungi",
            "plant",
            "protozoa",
            "nr",
            "nt",
            "UniVec",
            "UniVec_Core",
        ],
        help="Name of library to download",
    )
    parser.add_argument(
        "--protein",
        action="store_true",
        help="Files being added are for a protein database",
    )
    parser.add_argument(
        "--no-masking",
        action="store_true",
        help="Avoid masking low-complexity sequences prior to\
              building; masking requires dustmasker or segmasker to be\
              installed",
    )
    parser.add_argument(
        "--log",
        type=str,
        metavar="FILENAME",
        default=None,
        # Same wording as every other subcommand's --log option.
        help="Specify a log filename (default: stderr)",
    )


def make_add_to_library_parser(subparsers):
    """Register the ``add-to-library`` subcommand on *subparsers*.

    Options: a required ``--db`` path, ``--file``/``--files`` (one or more
    pathnames, stored in ``files``), the ``--protein`` and ``--no-masking``
    switches, and an optional ``--log`` filename.
    """
    parser = subparsers.add_parser(
        "add-to-library", help="Add file(s) to library"
    )
    parser.add_argument(
        "--db",
        type=str,
        metavar="PATHNAME",
        required=True,
        help="Name of Kraken 2 database",
    )
    parser.add_argument(
        "--file",
        "--files",
        type=str,
        nargs="+",
        dest="files",
        help="Pathname of file(s) to be added to library",
    )
    parser.add_argument(
        "--protein",
        action="store_true",
        help="Files being added are for a protein database",
    )
    parser.add_argument(
        "--no-masking",
        action="store_true",
        help="Avoid masking low-complexity sequences prior to\
              building; masking requires dustmasker or segmasker to be\
              installed",
    )
    parser.add_argument(
        "--log",
        type=str,
        metavar="FILENAME",
        default=None,
        help="Specify a log filename (default: stderr)",
    )


def make_classify_parser(subparsers):
    """Register the ``classify`` subcommand on *subparsers*.

    Options declared with ``default=argparse.SUPPRESS`` are absent from the
    parsed namespace unless given on the command line, so only explicitly
    requested flags appear there. ``--paired`` and ``--interleaved`` are
    mutually exclusive.

    The historical, misspelled ``--mininum-base-quality`` and
    ``--mininum-hit-groups`` options are kept because their ``dest`` names
    (``mininum_base_quality``/``mininum_hit_groups``) are what the rest of the
    script keys on. Correctly spelled aliases are accepted too.
    """
    parser = subparsers.add_parser(
        "classify", help="Classify a set of sequences"
    )
    parser.add_argument(
        "--db",
        type=str,
        metavar="PATHNAME",
        required=True,
        help="Name of Kraken2 DB",
    )
    parser.add_argument(
        "--threads",
        type=int,
        metavar="INT",
        # A string default from the environment is converted by argparse
        # through ``type=int``.
        default=os.environ.get("KRAKEN2_NUM_THREADS") or 1,
        help="Number of threads",
    )
    parser.add_argument(
        "--quick",
        action="store_true",
        default=argparse.SUPPRESS,
        help="Quick operation (use first hit or hits)",
    )
    parser.add_argument(
        "--unclassified-out",
        type=str,
        default=argparse.SUPPRESS,
        metavar="FILENAME",
        help="Print unclassified sequences to filename",
    )
    parser.add_argument(
        "--classified-out",
        type=str,
        metavar="FILENAME",
        default=argparse.SUPPRESS,
        help="Print classified sequences to filename",
    )
    parser.add_argument(
        "--output",
        type=str,
        metavar="FILENAME",
        default=argparse.SUPPRESS,
        help='Print output to file (default: stdout) "-" will \
              suppress normal output',
    )
    parser.add_argument(
        "--confidence",
        type=float,
        metavar="FLOAT",
        default=0.0,
        help="Confidence score threshold (default: 0.0); must be in [0,1]",
    )
    parser.add_argument(
        "--mininum-base-quality",
        "--minimum-base-quality",
        dest="mininum_base_quality",
        type=int,
        metavar="INT",
        default=0,
        help="Minimum base quality used in classification",
    )
    parser.add_argument(
        "--report",
        type=str,
        metavar="FILENAME",
        default=argparse.SUPPRESS,
        help="Print a report with aggregate counts/clade to file",
    )
    parser.add_argument(
        "--use-mpa-style",
        action="store_true",
        default=argparse.SUPPRESS,
        help="With --report, format report output like Kraken 1's\
              kraken-mpa-report",
    )
    parser.add_argument(
        "--report-zero-counts",
        action="store_true",
        default=argparse.SUPPRESS,
        help="With --report, report counts for ALL taxa, even if\
              counts are zero",
    )
    parser.add_argument(
        "--report-minimizer-data",
        action="store_true",
        default=argparse.SUPPRESS,
        help="With --report, report minimizer and distinct minimizer\
              count information in addition to normal Kraken report",
    )
    parser.add_argument(
        "--memory-mapping",
        action="store_true",
        default=argparse.SUPPRESS,
        help="Avoids loading database into RAM",
    )
    paired_group = parser.add_mutually_exclusive_group()
    paired_group.add_argument(
        "--paired",
        action="store_true",
        default=argparse.SUPPRESS,
        help="The filenames provided have paired-end reads",
    )
    paired_group.add_argument(
        "--interleaved",
        action="store_true",
        default=argparse.SUPPRESS,
        help="The filenames provided have paired-end reads, with both "
        "mates interleaved in the same file",
    )
    parser.add_argument(
        "--use-names",
        action="store_true",
        default=argparse.SUPPRESS,
        help="Print scientific names instead of just taxids",
    )
    parser.add_argument(
        "--mininum-hit-groups",
        "--minimum-hit-groups",
        dest="mininum_hit_groups",
        type=int,
        metavar="INT",
        default=2,
        help="Minimum number of hit groups (overlapping k-mers\
              sharing the same minimizer) needed to make a call",
    )
    parser.add_argument(
        "--log",
        type=str,
        metavar="FILENAME",
        default=None,
        help="Specify a log filename (default: stderr)",
    )
    parser.add_argument(
        "filenames",
        nargs="+",
        type=str,
        help="Filenames to be classified, supports bz2, gzip, and xz",
    )


def make_inspect_parser(subparsers):
    """Register the ``inspect`` subcommand on *subparsers*.

    Options: a required ``--db`` path, ``--threads`` (defaulting to the
    ``KRAKEN2_NUM_THREADS`` environment variable, else 1), the
    ``--skip-counts``, ``--use-mpa-style`` and ``--report-zero-counts``
    switches, and an optional ``--log`` filename.
    """
    parser = subparsers.add_parser("inspect", help="Inspect Kraken 2 database")
    parser.add_argument(
        "--db",
        type=str,
        metavar="PATHNAME",
        required=True,
        help="Name of Kraken2 DB",
    )
    parser.add_argument(
        "--threads",
        type=int,
        # Same metavar as the --threads option of the other subcommands.
        metavar="INT",
        default=os.environ.get("KRAKEN2_NUM_THREADS") or 1,
        help="Number of threads",
    )
    parser.add_argument(
        "--skip-counts",
        action="store_true",
        help="Only print database summary statistics",
    )
    parser.add_argument(
        "--use-mpa-style",
        action="store_true",
        help="Format output like Kraken 1's kraken-mpa-report",
    )
    parser.add_argument(
        "--report-zero-counts",
        action="store_true",
        help="Report counts for ALL taxa, even if counts are zero",
    )
    parser.add_argument(
        "--log",
        type=str,
        metavar="FILENAME",
        default=None,
        help="Specify a log filename (default: stderr)",
    )


def make_clean_parser(subparsers):
    """Register the ``clean`` subcommand on *subparsers*.

    It takes a required ``--db`` path and an optional ``--log`` filename.
    """
    clean_cmd = subparsers.add_parser(
        "clean", help="Remove unneeded files from database"
    )
    db_option = {
        "type": str,
        "metavar": "PATHNAME",
        "required": True,
        "help": "Name of Kraken2 DB",
    }
    log_option = {
        "type": str,
        "metavar": "FILENAME",
        "default": None,
        "help": "Specify a log filename (default: stderr)",
    }
    clean_cmd.add_argument("--db", **db_option)
    clean_cmd.add_argument("--log", **log_option)


class HelpAction(argparse._HelpAction):
    """``-h/--help`` action that also prints the help of every subcommand.

    The top-level help, then a titled section per subcommand, are all written
    to standard output, after which the process exits with status 0.
    """

    def __call__(self, parser, namespace, values, option_string=None):
        # Use one stream for everything. Previously the section headers went
        # to stderr while print_help() wrote to stdout, so piping or
        # redirecting the help split and reordered it.
        out = sys.stdout
        parser.print_help(out)
        # Find the subparsers action by type instead of by its position in
        # parser._actions, so adding top-level options cannot break this.
        for action in parser._actions:
            if not isinstance(action, argparse._SubParsersAction):
                continue
            for name, sub_parser in action.choices.items():
                out.write("\n\n" + name + "\n" + "-" * len(name) + "\n")
                sub_parser.print_help(out)
        sys.exit(0)


def make_cmdline_parser():
    """Build the top-level ``kraken2`` parser with every subcommand attached.

    The default ``-h`` is replaced by ``HelpAction`` so that top-level help
    also prints each subcommand's help.
    """
    root = argparse.ArgumentParser("kraken2", add_help=False)
    root.add_argument("-h", "--help", action=HelpAction)
    commands = root.add_subparsers()
    # Registration order determines the order subcommands are listed in help.
    registrars = (
        make_add_to_library_parser,
        make_download_library_parser,
        make_download_taxonomy_parser,
        make_build_parser,
        make_classify_parser,
        make_inspect_parser,
        make_clean_parser,
    )
    for register in registrars:
        register(commands)
    return root


def setup_logger(filename=None):
    """Configure and return the ``kraken2`` logger.

    Handlers add no implicit line terminator, so callers include their own
    newline or carriage return in each message.

    Args:
        filename: When given, log to this file at INFO level with a timestamp
            prefix. Otherwise log bare messages to stderr at DEBUG level.

    Returns:
        The configured ``logging.Logger`` named ``"kraken2"``. Calling this
        again replaces the previously installed handler instead of adding a
        second one.
    """
    # Class-level switch: it affects every StreamHandler (FileHandler
    # included, since it subclasses StreamHandler), not only the ones made here.
    logging.StreamHandler.terminator = ""
    logger = logging.getLogger("kraken2")
    # Drop handlers from an earlier call so records are not emitted twice.
    for old_handler in list(logger.handlers):
        logger.removeHandler(old_handler)
        old_handler.close()
    if filename:
        # INFO level keeps DEBUG-level output (e.g. progress bars) out of
        # the log file.
        logger.setLevel(logging.INFO)
        handler = logging.FileHandler(filename)
        fmt = "%(asctime)s: %(message)s"
    else:
        logger.setLevel(logging.DEBUG)
        handler = logging.StreamHandler()
        fmt = "%(message)s"
    handler.setFormatter(logging.Formatter(fmt))
    logger.addHandler(handler)
    return logger


def _prepare_build_args(args):
    """Fill in protein/nucleotide build defaults and validate build options.

    Mutates *args* in place. It computes ``sub_block_size`` when that is 0,
    and fills ``kmer_len``, ``minimizer_len`` and ``minimizer_spaces`` when
    they are unset (falsy), choosing the protein or nucleotide default from
    ``args.protein``. If a value is invalid, an error is logged through
    ``LOG`` and the process exits with status 1.
    """
    # Protein defaults
    default_aa_minimizer_length = 12
    default_aa_kmer_length = 15
    default_aa_minimizer_spaces = 0
    # Nucleotide defaults
    default_nt_minimizer_length = 31
    default_nt_kmer_length = 35
    default_nt_minimizer_spaces = 7

    # Guard the division below. A non-positive thread count would otherwise
    # raise ZeroDivisionError or give a negative sub-block size.
    if args.threads <= 0:
        LOG.error("Number of threads must be a positive integer\n")
        sys.exit(1)
    if args.sub_block_size == 0:
        args.sub_block_size = math.ceil(args.block_size / args.threads)

    # NOTE(review): these are falsy checks, so an explicit 0 (e.g.
    # `--minimizer-spaces 0` for a nucleotide build) is replaced by the
    # default. If the build parser's defaults for these options are None,
    # `is None` checks would let 0 through -- confirm before changing.
    if not args.kmer_len:
        args.kmer_len = (
            default_aa_kmer_length if args.protein else default_nt_kmer_length
        )
    if not args.minimizer_len:
        args.minimizer_len = (
            default_aa_minimizer_length
            if args.protein
            else default_nt_minimizer_length
        )
    if not args.minimizer_spaces:
        args.minimizer_spaces = (
            default_aa_minimizer_spaces
            if args.protein
            else default_nt_minimizer_spaces
        )

    if args.minimizer_len > args.kmer_len:
        LOG.error(
            "Minimizer length ({}) must not be greater than kmer length ({})\n".format(
                args.minimizer_len, args.kmer_len
            )
        )
        sys.exit(1)
    # The chained comparison also rejects NaN, which `<= 0 or > 1` let through.
    if not 0 < args.load_factor <= 1:
        LOG.error("Load factor must be greater than 0 but no more than 1\n")
        sys.exit(1)
    if args.minimizer_len <= 0 or args.minimizer_len > 31:
        LOG.error(
            "Minimizer length must be a positive integer and cannot exceed 31\n"
        )
        sys.exit(1)


def k2_main():
    """Entry point of the ``kraken2`` wrapper script.

    Parses ``sys.argv``, sets the module-level ``SCRIPT_PATHNAME`` and ``LOG``
    globals, and dispatches to the handler for the chosen subcommand. With no
    arguments at all, it prints the help and exits with status 1.
    """
    global SCRIPT_PATHNAME
    global LOG

    SCRIPT_PATHNAME = os.path.realpath(inspect.getsourcefile(k2_main))

    parser = make_cmdline_parser()
    if len(sys.argv) == 1:
        parser.print_help()
        sys.exit(1)
    args = parser.parse_args(sys.argv[1:])
    LOG = setup_logger(args.log)
    # parse_args has already rejected unknown subcommands, and the only
    # top-level option (-h) exits, so argv[1] is the subcommand name.
    task = sys.argv[1]
    if task not in ["classify", "inspect"]:
        args.db = os.path.abspath(args.db)
    if task == "download-taxonomy":
        download_taxonomy(args)
    elif task == "classify":
        classify(args)
    elif task == "download-library":
        download_genomic_library(args)
    elif task == "add-to-library":
        add_to_library(args)
    elif task == "inspect":
        inspect_db(args)
    elif task == "clean":
        clean_db(args)
    elif task == "build":
        _prepare_build_args(args)
        if args.standard:
            build_standard_database(args)
        elif args.special:
            if args.special == "greengenes":
                build_16S_gg(args)
            elif args.special == "silva":
                build_16S_silva(args)
            else:
                build_16S_rdp(args)
        else:
            if args.no_masking:
                # The trailing space keeps the two literals from running
                # together as "and`--special`".
                LOG.warning(
                    "--no-masking only affects the `--standard` and "
                    "`--special` flags. Its effect will be ignored.\n"
                )
            build_kraken2_db(args)


if __name__ == "__main__":
    k2_main()
