#!/usr/bin/python3

import xml.etree.ElementTree as ET
import textwrap
import argparse
import sys
import os
import re
import logging

# Root logger (getLogger() is called without a name), shared by the whole
# script; its level and output stream are configured in main().
log = logging.getLogger()

class Markdown:
    """
    Markdown output writer

    Writes Markdown text to a file-like object, and renders doxygen XML
    <para> nodes together with the inline and block elements handled by
    node_para().
    """

    # Reference to an API routine in running text, with or without a
    # trailing "()"
    _ROUTINE_RE = re.compile(r"idba_([a-z_]+)(\(\))?")

    def __init__(self, out):
        """
        out: file-like object that receives all the generated output
        """
        self.out = out

    def print(self, *args, **kw):
        """
        Same as the builtin print(), but always writing to self.out
        """
        print(*args, file=self.out, **kw)

    def title(self, text, level=1):
        """
        Print a heading made of `level` "#" characters, then an empty line
        """
        self.print("{} {}".format("#" * level, text))
        self.print()

    def anchor(self, key):
        """
        Print an HTML anchor named `key`
        """
        self.print("<a name='{key}'></a>".format(key=key))

    def link_target(self, key, link):
        """
        Print a Markdown reference-style link definition mapping key to link
        """
        self.print("[{key}]: {link}".format(key=key, link=link))

    def def_list(self, items):
        """
        Print a bullet list of definitions, followed by an empty line.

        items: sequence of (head, root) pairs, where head is the text that
               introduces the bullet and root is an iterable of XML nodes
               (printed with node()) that describe it.
        """
        for head, root in items:
            for idx, node in enumerate(root):
                if idx == 0:
                    initial_indent = "* {head}: ".format(head=head)
                else:
                    # Further nodes continue the same bullet: separate them
                    # with an empty line and align them with the bullet text
                    initial_indent = "  "
                    self.print()
                self.node(node, initial_indent=initial_indent, subsequent_indent="  ")
        self.print()

    def node(self, node, **wrap_kw):
        """
        Print an XML node. Only <para> is supported: anything else is
        reported with a warning and skipped.

        wrap_kw: extra keyword arguments for textwrap.wrap()
        """
        if node.tag == "para":
            self.node_para(node, **wrap_kw)
        else:
            log.warning("Unknown node to print: %s", node.tag)

    def node_para(self, node, **wrap_kw):
        """
        Print a <para> node as wrapped text.

        Inline children (<computeroutput>, <ref>, <anchor>) are merged into
        the running text; block children (<itemizedlist>, <programlisting>)
        are printed between empty lines after flushing the text collected
        so far. <parameterlist> and <simplesect> children are skipped.

        wrap_kw: extra keyword arguments for textwrap.wrap()
        """
        content = []
        if node.text:
            content.append(node.text)

        def flush_content():
            # Print and forget the text collected so far
            text = "".join(content).strip()
            content.clear()
            if not text:
                return
            # Plain string replacement: the "." in the file name must not
            # act as a regular expression wildcard
            text = text.replace("fapi_parms.md", "[fapi_parms.md](fapi_parms.md)")
            text = self._ROUTINE_RE.sub(r"[idba_\1()](#idba_\1)", text)
            for line in textwrap.wrap(text, break_on_hyphens=False, **wrap_kw):
                self.print(line)

        for el in node:
            if el.tag == "computeroutput":
                # An empty element has text None: emit nothing for it
                if el.text:
                    content.append("`{}`".format(el.text))
            elif el.tag == "ref":
                # An empty <ref/> has text None, which would break the join
                # in flush_content()
                if el.text:
                    content.append(el.text)
            elif el.tag == "anchor":
                # Only the text that follows the anchor is kept
                pass
            elif el.tag == "itemizedlist":
                flush_content()
                self.print()
                self.node_list(el, **wrap_kw)
                self.print()
            elif el.tag == "programlisting":
                flush_content()
                self.print()
                self.node_programlisting(el, **wrap_kw)
                self.print()
            elif el.tag in ("parameterlist", "simplesect"):
                # Skipped together with their trailing text
                continue
            else:
                log.warning("Invalid tag inside para: %s", el.tag)
                continue

            if el.tail:
                content.append(el.tail)

        flush_content()

    def node_list(self, node, **wrap_kw):
        """
        Print an <itemizedlist> node as a bullet list.

        If wrap_kw already carries an indentation, the list would be nested
        inside another indented block: that is not supported, so a warning
        is logged and nothing is printed.
        """
        if "initial_indent" in wrap_kw or "subsequent_indent" in wrap_kw:
            log.warning("Nested lists are not supported")
            return

        for el in node:
            if el.tag != "listitem":
                log.warning("Found unsupported %s inside itemizedlist", el.tag)
                continue
            for n in el:
                self.node(n, initial_indent="* ", subsequent_indent="  ")

    def node_programlisting(self, node, **wrap_kw):
        """
        Print a <programlisting> node as a fenced fortran code block,
        surrounded by empty lines.

        If wrap_kw carries an indentation, the code would be inside a list:
        that is not supported, so a warning is logged and nothing is printed.
        """
        if "initial_indent" in wrap_kw or "subsequent_indent" in wrap_kw:
            log.warning("Code inside lists is not supported")
            return

        def harvest_codeline(node):
            # Recursively collect the text fragments of a code line, turning
            # each <sp> element into a space
            content = []
            if node.text:
                content.append(node.text)
            for n in node:
                if n.tag == "sp":
                    content.append(" ")
                content.extend(harvest_codeline(n))
                if n.tail:
                    content.append(n.tail)
            return content

        self.print()
        self.print("```fortran")
        for node_line in node:
            if node_line.tag != "codeline":
                log.warning("Found unsupported %s inside programlisting", node_line.tag)
                continue
            content = harvest_codeline(node_line)
            self.print("".join(content))
        self.print("```")
        self.print()

    def para(self, text, **wrap_kw):
        """
        Print text wrapped with textwrap.wrap(), followed by an empty line

        wrap_kw: extra keyword arguments for textwrap.wrap()
        """
        for line in textwrap.wrap(text, break_on_hyphens=False, **wrap_kw):
            self.print(line)
        self.print()

    def raw(self, text):
        """
        Print text as it is
        """
        self.print(text)


def collect_text(node):
    """
    Return the text found in node and in all its descendants, joined in
    document order.

    Fragments that are empty or made only of whitespace are dropped, and
    the result is stripped of leading and trailing whitespace.
    """
    fragments = []
    for piece in node.itertext():
        if not piece or piece.isspace():
            continue
        fragments.append(piece)
    joined = "".join(fragments)
    return joined.strip()

def collect_desc(node):
    """
    Generate the text of each <para> that is a direct child of node,
    skipping the paragraphs that have a <parameterlist> as a direct child.
    """
    for para in node.iterfind("para"):
        # find() returns an Element or None. Compare with None explicitly:
        # the truth value of an Element is False when it has no children,
        # and testing it is deprecated.
        if para.find("parameterlist") is not None:
            continue
        yield collect_text(para)

class Argument:
    """
    Documented parameter, built from an XML node (a <parameteritem> when
    created by Function).

    name is the text of the first <parametername> found anywhere below the
    node, desc_node is the first <parameterdescription> element found
    anywhere below it (None if there is none).
    """
    def __init__(self, node):
        name_el = node.find(".//parametername")
        desc_el = node.find(".//parameterdescription")
        self.name, self.desc_node = name_el.text, desc_el


class Function:
    """
    Function built from a <memberdef> element

    Attributes set by the constructor:
      name: text of the <name> child
      summary: text collected from the <briefdescription> child
      params: names of the parameters, in declaration order
      desc_node: the <detaileddescription> child, or None
      args: Argument objects for the documented parameters
      retval_desc: the <simplesect kind="return"> element, or None
    """
    def __init__(self, node):
        self.name = node.find("name").text
        log.debug("Found function %s", self.name)
        self.summary = collect_text(node.find("briefdescription"))

        self.params = []
        for param in node.iter("param"):
            # A parameter whose type is SUBROUTINE(func) is listed as "func"
            if param.findtext("type") == "SUBROUTINE(func)":
                self.params.append("func")
                continue
            declname = param.findtext("declname")
            if not declname:
                # Without a name there is nothing to show in the signature
                log.warning("Function %s: skipping a parameter without a name", self.name)
                continue
            self.params.append(declname)

        self.desc_node = node.find("detaileddescription")

        self.args = []
        for paramlist_node in node.findall(".//parameterlist"):
            if paramlist_node.attrib["kind"] not in ("param", "retval"):
                continue
            for item in paramlist_node.findall("parameteritem"):
                self.args.append(Argument(item))

        self.retval_desc = node.find(".//simplesect[@kind='return']")

    @property
    def arg_string(self):
        """
        Parameter names joined with ", ", as shown in the function signature
        """
        return ", ".join(self.params)

    def print_details(self, md):
        """
        Print the reference entry for this function: anchor, title,
        documented parameters, return value, summary and detailed
        description.

        md: the Markdown writer to print to
        """
        md.anchor(self.name)
        md.title("{f.name}({f.arg_string})".format(f=self), level=4)

        if self.args:
            md.para("Parameters:")
            md.def_list([
                ("`{a.name}`".format(a=a), a.desc_node)
                for a in self.args
            ])

        # Explicit checks instead of the deprecated Element truth value: the
        # section is printed only when the element exists and has children
        if self.retval_desc is not None and len(self.retval_desc) > 0:
            md.para("Return value:")
            for n in self.retval_desc:
                md.node(n)

        md.para(self.summary)
        # find() leaves desc_node set to None when the function has no
        # <detaileddescription>
        if self.desc_node is not None:
            for node in self.desc_node:
                md.node(node)


class Section:
    """
    Group of functions built from a <sectiondef> element

    Attributes set by the constructor:
      summary: text of the <header> child, or None if there is no header
      desc_node: the <description> child, or None
      functions: Function objects for the function members whose name
                 starts with "idba_"
    """
    def __init__(self, tree):
        self.functions = []

        header_node = tree.find("header")
        if header_node is not None:
            self.summary = header_node.text
        else:
            self.summary = None
        log.debug("Found section %s", self.summary)

        # A section without <description> is accepted: desc_node is None
        self.desc_node = tree.find("description")

        for node in tree.iter("memberdef"):
            if node.attrib["kind"] != "function":
                continue
            f = Function(node)
            if not f.name.startswith("idba_"):
                continue
            self.functions.append(f)

    def print_links(self, md):
        """
        Print one link definition per function, pointing to its anchor in
        fapi_reference.md
        """
        for f in self.functions:
            md.link_target(f.name, "fapi_reference.md#{f.name}".format(f=f))

    def print_summary(self, md):
        """
        Print the section title, its description if present, and an HTML
        table with one row per function linking to the function anchor.
        """
        md.title(self.summary, level=3)

        # Compare with None: the truth value of an Element depends on its
        # number of children and testing it is deprecated
        if self.desc_node is not None:
            for node in self.desc_node:
                log.debug("%s SUMMARY %s", self.summary, node.tag)
                md.node(node)

        md.raw("""<table class="table">
<thead>
<tr>
    <th>Name</th>
    <th>Description</th>
</tr>
</thead>
<tbody>""")

        # NOTE(review): names and summaries are interpolated in the HTML
        # without escaping; confirm they cannot contain "<" or "&"
        for f in self.functions:
            md.raw("<tr><td><code><a href='#{f.name}'>{f.name}({f.arg_string})</a></code></td><td>{f.summary}</td></tr>".format(f=f))

        md.raw("""</tbody>
</table>
""")

    def print_details(self, md):
        """
        Print the section title followed by the reference entry of each
        function
        """
        md.title(self.summary, level=3)

        for f in self.functions:
            f.print_details(md)


class API:
    """
    The documented API, as a list of Section objects loaded from the
    doxygen XML files error_8cc.xml and binding_8cc.xml
    """
    def __init__(self, args):
        """
        args: parsed command line arguments; args.xmldir is the directory
              that contains the doxygen XML files
        """
        self.sections = []
        self.load_file(os.path.join(args.xmldir, "error_8cc.xml"))
        self.load_file(os.path.join(args.xmldir, "binding_8cc.xml"))

    def load_file(self, pathname):
        """
        Load a doxygen XML file and append its sections to self.sections.

        Only <sectiondef> elements of kind "define", "user-defined" or
        "func" are considered, and only if they have both a header and at
        least one function.
        """
        log.info("Loading file %s", pathname)
        # Let the XML parser read the file, so that the encoding stated in
        # the XML declaration is used instead of the locale default
        root = ET.parse(pathname).getroot()

        for node in root.iter("sectiondef"):
            if node.attrib["kind"] not in ("define", "user-defined", "func"):
                continue
            sec = Section(node)
            if not sec.functions or not sec.summary:
                continue
            self.sections.append(sec)

    def print_links(self, md):
        """
        Print the link definitions for the functions of all sections
        """
        for s in self.sections:
            s.print_links(md)

    def print_reference(self, md):
        """
        Print the whole reference: main title, summary and details
        """
        md.title("Fortran API reference", 1)
        self.print_summary(md)
        self.print_details(md)

    def print_summary(self, md):
        """
        Print the summary chapter, with the summary of each section
        """
        md.title("Summary of routines", 2)

        for s in self.sections:
            s.print_summary(md)

    def print_details(self, md):
        """
        Print the reference chapter, with the details of each section
        """
        md.title("Reference of routines", 2)

        for s in self.sections:
            s.print_details(md)


def main():
    """
    Command line entry point.

    Parse the command line, configure logging to standard error, load the
    API from the doxygen XML directory and print to standard output either
    the link definitions (--links) or the full reference.
    """
    parser = argparse.ArgumentParser(description="Build fortran API reference markdown documentation.")
    parser.add_argument("xmldir", help="doxygen xml directory")
    parser.add_argument("--links", action="store_true", help="print link shortcuts to the functions in the reference")
    parser.add_argument("-v", "--verbose", action="store_true", help="verbose output")
    parser.add_argument("--debug", action="store_true", help="debug output")

    args = parser.parse_args()

    # --debug takes precedence over --verbose
    if args.debug:
        level = logging.DEBUG
    elif args.verbose:
        level = logging.INFO
    else:
        level = logging.WARNING
    logging.basicConfig(
        level=level,
        stream=sys.stderr,
        format="%(asctime)-15s %(levelname)s %(message)s",
    )

    api = API(args)
    md = Markdown(sys.stdout)
    if args.links:
        api.print_links(md)
    else:
        api.print_reference(md)



# Run main() only when this file is executed as a script, not when it is
# imported as a module
if __name__ == "__main__":
    main()
