#!/usr/bin/env python3
# pylint: disable=C0114,C0115,C0116,C0209,C0302,R0902,R0911,R0912,R0914,R0915,E1101
#
# Copyright 2022-2023 by Wilson Snyder. Verilator is free software; you
# can redistribute it and/or modify it under the terms of either the GNU Lesser
# General Public License Version 3 or the Apache License 2.0.
# SPDX-License-Identifier: LGPL-3.0-only OR Apache-2.0

import argparse
import os
import sys
import shlex
from typing import Callable, Iterable, Optional, Union
import dataclasses
from dataclasses import dataclass
import enum
from enum import Enum
import multiprocessing
import tempfile

import clang.cindex
from clang.cindex import CursorKind, Index, TranslationUnitSaveError, TranslationUnitLoadError


def fully_qualified_name(node):
    """Return the qualified name of ``node`` as a list of display names.

    The semantic parents are walked upwards until ``None`` or the
    translation unit is reached; the outermost scope comes first in the
    result. Scopes whose display name is empty are omitted. ``None`` yields
    an empty list.
    """
    parts = []
    cursor = node
    while cursor is not None and cursor.kind != CursorKind.TRANSLATION_UNIT:
        if cursor.displayname:
            parts.append(cursor.displayname)
        cursor = cursor.semantic_parent
    # Collected innermost-first; callers expect outermost-first.
    parts.reverse()
    return parts


@dataclass
class VlAnnotations:
    """Verilator function annotations, one boolean flag per annotation.

    Flags are filled from clang ``annotate`` attributes by
    ``from_nodes_list``, can be merged with ``|``, and are interpreted by the
    ``is_*_context`` / ``is_*_call`` predicates.
    """
    mt_start: bool = False
    mt_safe: bool = False
    mt_safe_postinit: bool = False
    mt_unsafe: bool = False
    mt_unsafe_one: bool = False
    pure: bool = False
    guarded: bool = False
    requires: bool = False
    excludes: bool = False
    acquire: bool = False
    release: bool = False

    def is_mt_safe_context(self):
        """True if a body with these annotations must only make MT-safe calls."""
        if self.mt_unsafe or self.mt_unsafe_one:
            return False
        return self.mt_safe or self.mt_start

    def is_pure_context(self):
        """True if a body with these annotations must only make pure calls."""
        return self.pure

    def is_mt_unsafe_call(self):
        """True if a callee with these annotations is explicitly MT-unsafe."""
        return self.mt_unsafe or self.mt_unsafe_one

    def is_mt_safe_call(self):
        """True if calling a function with these annotations is MT-safe.

        Any MT-safe, purity or lock-related annotation qualifies, unless an
        MT-unsafe annotation is present as well.
        """
        if self.is_mt_unsafe_call():
            return False
        return any((self.mt_safe, self.mt_safe_postinit, self.pure,
                    self.requires, self.excludes, self.acquire, self.release))

    def is_pure_call(self):
        """True if a callee with these annotations is pure."""
        return self.pure

    def __or__(self, other: "VlAnnotations"):
        """Return a new instance with every flag that is set in either operand."""
        merged = {
            fld.name: getattr(self, fld.name) | getattr(other, fld.name)
            for fld in dataclasses.fields(self)
        }
        return VlAnnotations(**merged)

    def is_empty(self):
        """True if no flag is set."""
        return not any(
            getattr(self, fld.name) for fld in dataclasses.fields(self))

    def __str__(self):
        """Comma-separated names of the set flags, in declaration order."""
        return ", ".join(fld.name for fld in dataclasses.fields(self)
                         if getattr(self, fld.name))

    @staticmethod
    def from_nodes_list(nodes: Iterable):
        """Build annotations from the leading attribute children of a cursor.

        ``ANNOTATE_ATTR`` children with a recognized spelling set the
        matching flag. Other attribute children are skipped, and scanning
        stops at the first child that is not an attribute.
        """
        # Annotation spelling -> flag it sets.
        flag_for_spelling = {
            "MT_START": "mt_start",
            "MT_SAFE": "mt_safe",
            "MT_SAFE_POSTINIT": "mt_safe_postinit",
            "MT_UNSAFE": "mt_unsafe",
            "MT_UNSAFE_ONE": "mt_unsafe_one",
            "PURE": "pure",
            "ACQUIRE": "acquire",
            "ACQUIRE_SHARED": "acquire",
            "RELEASE": "release",
            "RELEASE_SHARED": "release",
            "REQUIRES": "requires",
            "EXCLUDES": "excludes",
            "MT_SAFE_EXCLUDES": "excludes",
            "GUARDED_BY": "guarded",
        }
        annotations = VlAnnotations()
        for child in nodes:
            if child.kind != CursorKind.ANNOTATE_ATTR:
                # Attributes are always at the beginning
                if not child.kind.is_attribute():
                    break
                continue
            flag = flag_for_spelling.get(child.displayname)
            if flag is not None:
                setattr(annotations, flag, True)
        return annotations


class FunctionType(Enum):
    """Kind of callable a clang cursor refers to."""
    UNKNOWN = enum.auto()
    FUNCTION = enum.auto()
    METHOD = enum.auto()
    STATIC_METHOD = enum.auto()
    CONSTRUCTOR = enum.auto()

    @staticmethod
    def from_node(node: clang.cindex.Cursor):
        """Classify ``node``; ``None`` and unrecognized kinds give UNKNOWN."""
        if node is not None:
            kind = node.kind
            if kind == CursorKind.FUNCTION_DECL:
                return FunctionType.FUNCTION
            if kind == CursorKind.CXX_METHOD:
                if node.is_static_method():
                    return FunctionType.STATIC_METHOD
                return FunctionType.METHOD
            if kind == CursorKind.CONSTRUCTOR:
                return FunctionType.CONSTRUCTOR
        return FunctionType.UNKNOWN


@dataclass(eq=False)
class FunctionInfo:
    """Identity, location and annotations of a function or a call site.

    Equality and hashing use only ``(usr, file, line)``; the name parts,
    annotations and function type do not take part.
    """
    name_parts: list[str]
    usr: str
    file: str
    line: int
    annotations: VlAnnotations
    ftype: FunctionType

    # Lazily computed cache for __hash__. It is not a constructor argument,
    # and dataclasses.replace() (used by copy()) does not carry it over.
    _hash: Optional[int] = dataclasses.field(default=None,
                                             init=False,
                                             repr=False)

    @property
    def name(self):
        """Qualified name joined with ``::``."""
        return "::".join(self.name_parts)

    def __str__(self):
        return f"[{self.name}@{self.file}:{self.line}]"

    def __hash__(self):
        # `is None` instead of a falsiness test, so that a hash value of 0 is
        # cached too.
        if self._hash is None:
            self._hash = hash(f"{self.usr}:{self.file}:{self.line}")
        return self._hash

    def __eq__(self, other):
        if not isinstance(other, FunctionInfo):
            return NotImplemented
        return (self.usr == other.usr and self.file == other.file
                and self.line == other.line)

    def __getstate__(self):
        # str hashes are salted per interpreter (PYTHONHASHSEED). A cached
        # hash must not be pickled into another process (e.g. multiprocessing
        # workers that are not forked), where it would disagree with freshly
        # computed hashes of equal objects.
        state = self.__dict__.copy()
        state["_hash"] = None
        return state

    def copy(self, /, **changes):
        """Return a copy with ``changes`` applied (see dataclasses.replace)."""
        return dataclasses.replace(self, **changes)

    @staticmethod
    def from_node(node: clang.cindex.Cursor,
                  refd: Optional[clang.cindex.Cursor] = None,
                  annotations: Optional[VlAnnotations] = None):
        """Build a FunctionInfo located at ``node`` and identified by ``refd``.

        Args:
            node: Cursor whose file/line becomes the location.
            refd: Cursor naming the function; defaults to ``node.referenced``.
                Its canonical cursor supplies the name, USR and type.
            annotations: Annotations to store; defaults to those found on
                ``node``'s children.
        """
        file = os.path.abspath(node.location.file.name)
        line = node.location.line
        if annotations is None:
            annotations = VlAnnotations.from_nodes_list(node.get_children())
        if refd is None:
            refd = node.referenced
        if refd is not None:
            refd = refd.canonical
        assert refd is not None
        name_parts = fully_qualified_name(refd)
        usr = refd.get_usr()
        ftype = FunctionType.from_node(refd)

        return FunctionInfo(name_parts, usr, file, line, annotations, ftype)


class DiagnosticKind(Enum):
    """Category of a reported problem; ordered by declaration order."""
    ANNOTATIONS_DEF_DECL_MISMATCH = enum.auto()
    NON_PURE_CALL_IN_PURE_CTX = enum.auto()
    NON_MT_SAFE_CALL_IN_MT_SAFE_CTX = enum.auto()

    def __lt__(self, other):
        # Returning NotImplemented for foreign operands lets Python raise the
        # usual TypeError instead of an AttributeError on `.value`.
        if not isinstance(other, DiagnosticKind):
            return NotImplemented
        return self.value < other.value


@dataclass
class Diagnostic:
    """A single finding produced by CallAnnotationsValidator.

    ``target`` is the offending callee (or, for a mismatch, the declaration).
    ``source`` is the function being analyzed. ``source_ctx`` is the context
    the finding was made in; for calls it carries the call site's file/line.
    """
    target: FunctionInfo
    source: FunctionInfo
    source_ctx: FunctionInfo
    kind: DiagnosticKind

    # Lazily computed cache for __hash__. compare=False keeps it out of the
    # generated __eq__; otherwise equality would depend on whether hash() had
    # already been called on either object.
    _hash: Optional[int] = dataclasses.field(default=None,
                                             init=False,
                                             repr=False,
                                             compare=False)

    def __hash__(self):
        # `is None` so that a hash value of 0 is cached as well.
        if self._hash is None:
            self._hash = hash(
                hash(self.target) ^ hash(self.source_ctx) ^ hash(self.kind))
        return self._hash

    def __getstate__(self):
        # Diagnostics are pickled back from worker processes; a cached hash
        # (derived from salted str hashes) must not cross process boundaries.
        state = self.__dict__.copy()
        state["_hash"] = None
        return state


class CallAnnotationsValidator:
    """Check calls inside annotated functions against the annotations.

    It walks libclang translation units. For each function definition that
    the ``is_ignored_*`` filters do not exclude, it:

    * reports non-empty definition annotations that differ from the
      canonical declaration's (ANNOTATIONS_DEF_DECL_MISMATCH);
    * in an MT-safe context, reports calls not considered MT-safe
      (NON_MT_SAFE_CALL_IN_MT_SAFE_CTX). Constructors are always accepted;
      method calls on GUARDED_BY members or on locals/arguments are too;
    * in a pure context, reports calls to non-pure callees
      (NON_PURE_CALL_IN_PURE_CTX). Default constructors are exempt.

    Every finding is passed to ``diagnostic_cb``.
    """

    def __init__(self, diagnostic_cb: Callable[[Diagnostic], None],
                 is_ignored_top_level: Callable[[clang.cindex.Cursor], bool],
                 is_ignored_def: Callable[
                     [clang.cindex.Cursor, clang.cindex.Cursor], bool],
                 is_ignored_call: Callable[[clang.cindex.Cursor], bool]):
        self._diagnostic_cb = diagnostic_cb
        self._is_ignored_top_level = is_ignored_top_level
        self._is_ignored_call = is_ignored_call
        self._is_ignored_def = is_ignored_def

        self._index = Index.create()

        # Absolute paths of files included by already processed translation
        # units; their top-level cursors are skipped in later units.
        self._processed_headers: set[str] = set()

        # Current context
        self._call_location: Optional[FunctionInfo] = None
        self._caller: Optional[FunctionInfo] = None
        self._level: int = 0

    def compile_and_analyze_file(self, source_file: str,
                                 compiler_args: list[str],
                                 build_dir: Optional[str]):
        """Parse ``source_file`` with libclang and analyze it.

        If ``build_dir`` is given, it is the working directory during parsing
        and analysis. The previous directory is restored afterwards, even
        when an exception propagates. Units that fail to load or that have
        error-level diagnostics are reported on stderr and skipped.
        """
        filename = os.path.abspath(source_file)
        initial_cwd = "."

        if build_dir:
            initial_cwd = os.getcwd()
            os.chdir(build_dir)
        try:
            try:
                translation_unit = self._index.parse(filename, compiler_args)
            except TranslationUnitLoadError:
                translation_unit = None
            errors = []
            if translation_unit:
                # Both Error and Fatal diagnostics mean the AST is incomplete,
                # so analyzing it would give unreliable results.
                errors = [
                    str(diag) for diag in translation_unit.diagnostics
                    if diag.severity >= clang.cindex.Diagnostic.Error
                ]
            if translation_unit and not errors:
                self.process_translation_unit(translation_unit)
            else:
                print(f"%Error: parsing failed: {filename}", file=sys.stderr)
                for error in errors:
                    print(f"    {error}", file=sys.stderr)
        finally:
            if build_dir:
                os.chdir(initial_cwd)

    def emit_diagnostic(self, target: Union[FunctionInfo, clang.cindex.Cursor],
                        kind: DiagnosticKind):
        """Report ``target`` for the current caller and call location."""
        assert self._caller is not None
        assert self._call_location is not None
        source = self._caller
        source_ctx = self._call_location
        if isinstance(target, FunctionInfo):
            self._diagnostic_cb(Diagnostic(target, source, source_ctx, kind))
        else:
            self._diagnostic_cb(
                Diagnostic(FunctionInfo.from_node(target), source, source_ctx,
                           kind))

    def iterate_children(self, children: Iterable[clang.cindex.Cursor],
                         handler: Callable[[clang.cindex.Cursor], None]):
        """Call ``handler`` on each child while tracking the nesting level."""
        if children:
            self._level += 1
            for child in children:
                handler(child)
            self._level -= 1

    @staticmethod
    def get_referenced_node_info(
        node: clang.cindex.Cursor
    ) -> tuple[bool, Optional[clang.cindex.Cursor], VlAnnotations,
               Iterable[clang.cindex.Cursor]]:
        """Resolve the canonical cursor ``node`` refers to.

        Returns:
            ``(supported, canonical_refd, annotations, refd_children)``.
            ``supported`` is False (with empty data) for nameless nodes.

        Raises:
            ValueError: if a named node has no referenced cursor.
        """
        if not node.spelling and not node.displayname:
            return (False, None, VlAnnotations(), [])

        refd = node.referenced
        if refd is None:
            raise ValueError("The node does not specify referenced node.")

        refd = refd.canonical
        children = list(refd.get_children())

        annotations = VlAnnotations.from_nodes_list(children)
        return (True, refd, annotations, children)

    # Call handling

    def process_method_call(self, node: clang.cindex.Cursor,
                            refd: clang.cindex.Cursor,
                            annotations: VlAnnotations):
        """Check a non-static method (or conversion operator) call."""
        assert self._call_location
        ctx = self._call_location.annotations

        # MT-safe context
        if ctx.is_mt_safe_context():
            is_mt_safe = False

            if annotations.is_mt_safe_call():
                is_mt_safe = True
            elif not annotations.is_mt_unsafe_call():
                # Check whether the object the method is called on is mt-safe
                def find_object_ref(node):
                    try:
                        node = next(node.get_children())
                        if node.kind == CursorKind.DECL_REF_EXPR:
                            # Operator on an argument or local object
                            return node
                        if node.kind != CursorKind.MEMBER_REF_EXPR:
                            return None
                        if node.referenced and node.referenced.kind == CursorKind.FIELD_DECL:
                            # Operator on a member object
                            return node
                        node = next(node.get_children())
                        if node.kind == CursorKind.UNEXPOSED_EXPR:
                            node = next(node.get_children())
                        return node
                    except StopIteration:
                        return None

                refn = find_object_ref(node)
                # class/struct member
                if refn and refn.kind == CursorKind.MEMBER_REF_EXPR and refn.referenced:
                    refn = refn.referenced
                    refna = VlAnnotations.from_nodes_list(refn.get_children())
                    if refna.guarded:
                        is_mt_safe = True
                # variable
                elif refn and refn.kind == CursorKind.DECL_REF_EXPR and refn.referenced:
                    # This is probably a local or an argument. Assume it's safe.
                    is_mt_safe = True

            if not is_mt_safe:
                self.emit_diagnostic(
                    FunctionInfo.from_node(refd, refd, annotations),
                    DiagnosticKind.NON_MT_SAFE_CALL_IN_MT_SAFE_CTX)

        if ctx.is_pure_context():
            if not annotations.is_pure_call():
                self.emit_diagnostic(
                    FunctionInfo.from_node(refd, refd, annotations),
                    DiagnosticKind.NON_PURE_CALL_IN_PURE_CTX)

    def process_function_call(self, refd: clang.cindex.Cursor,
                              annotations: VlAnnotations):
        """Check a call to a free/static function or a function pointer."""
        assert self._call_location
        ctx = self._call_location.annotations

        # MT-safe context
        if ctx.is_mt_safe_context():
            if not annotations.is_mt_safe_call():
                self.emit_diagnostic(
                    FunctionInfo.from_node(refd, refd, annotations),
                    DiagnosticKind.NON_MT_SAFE_CALL_IN_MT_SAFE_CTX)
        # pure context
        if ctx.is_pure_context():
            if not annotations.is_pure_call():
                self.emit_diagnostic(
                    FunctionInfo.from_node(refd, refd, annotations),
                    DiagnosticKind.NON_PURE_CALL_IN_PURE_CTX)

    def process_constructor_call(self, refd: clang.cindex.Cursor,
                                 annotations: VlAnnotations):
        """Check a constructor call (only the purity rule applies)."""
        assert self._call_location
        ctx = self._call_location.annotations

        # Constructors are always OK in MT-safe context.

        # pure context
        if ctx.is_pure_context():
            if not annotations.is_pure_call(
            ) and not refd.is_default_constructor():
                self.emit_diagnostic(
                    FunctionInfo.from_node(refd, refd, annotations),
                    DiagnosticKind.NON_PURE_CALL_IN_PURE_CTX)

    def dispatch_call_node(self, node: clang.cindex.Cursor):
        """Route a CALL_EXPR to the checker for the kind of its callee.

        Nameless calls are not checked themselves; their children are
        descended into instead.
        NOTE(review): arguments of a handled call are not traversed, so calls
        nested in arguments are not checked. Confirm this is intended.
        """
        [supported, refd, annotations, _] = self.get_referenced_node_info(node)

        if not supported:
            self.iterate_children(node.get_children(),
                                  self.dispatch_node_inside_definition)
            return

        assert refd is not None
        if self._is_ignored_call(refd):
            return

        assert self._call_location is not None
        # Diagnostics for this call point at the call site itself.
        node_file = os.path.abspath(node.location.file.name)
        self._call_location = self._call_location.copy(file=node_file,
                                                       line=node.location.line)

        # Standalone functions and static class methods
        if (refd.kind == CursorKind.FUNCTION_DECL
                or refd.kind == CursorKind.CXX_METHOD
                and refd.is_static_method()):
            self.process_function_call(refd, annotations)
            return
        # Function pointer
        if refd.kind in [
                CursorKind.VAR_DECL, CursorKind.FIELD_DECL,
                CursorKind.PARM_DECL
        ]:
            self.process_function_call(refd, annotations)
            return
        # Non-static class methods
        if refd.kind == CursorKind.CXX_METHOD:
            self.process_method_call(node, refd, annotations)
            return
        # Conversion method (e.g. `operator int()`)
        if refd.kind == CursorKind.CONVERSION_FUNCTION:
            self.process_method_call(node, refd, annotations)
            return
        # Constructors
        if refd.kind == CursorKind.CONSTRUCTOR:
            self.process_constructor_call(refd, annotations)
            return

        # Ignore other callables
        print(f"{refd.location.file.name}:{refd.location.line}: "
              f"{refd.displayname}    {refd.kind}\n"
              f"    from: {node.location.file.name}:{node.location.line}")

    # Definition handling

    def dispatch_node_inside_definition(self, node: clang.cindex.Cursor):
        """Visit a cursor inside a function body, checking call expressions."""
        if node.kind == CursorKind.CALL_EXPR:
            self.dispatch_call_node(node)
            return None

        return self.iterate_children(node.get_children(),
                                     self.dispatch_node_inside_definition)

    def process_function_definition(self, node: clang.cindex.Cursor):
        """Analyze one function definition and the calls in its body.

        The caller context is set for the body and restored afterwards.
        """
        [supported, refd, annotations, _] = self.get_referenced_node_info(node)

        if refd and self._is_ignored_def(node, refd):
            return None

        node_children = list(node.get_children())

        if not supported:
            return self.iterate_children(node_children, self.dispatch_node)

        assert refd is not None

        prev_caller = self._caller
        prev_call_location = self._call_location

        def_annotations = VlAnnotations.from_nodes_list(node_children)

        if not (def_annotations.is_empty() or def_annotations == annotations):
            # Use definition's annotations for the diagnostic
            # source (i.e. the definition)
            self._caller = FunctionInfo.from_node(node, refd, def_annotations)
            self._call_location = self._caller

            self.emit_diagnostic(
                FunctionInfo.from_node(refd, refd, annotations),
                DiagnosticKind.ANNOTATIONS_DEF_DECL_MISMATCH)

        # Use concatenation of definition and declaration annotations
        # for callees validation.
        self._caller = FunctionInfo.from_node(node, refd,
                                              def_annotations | annotations)
        self._call_location = self._caller

        self.iterate_children(node_children,
                              self.dispatch_node_inside_definition)

        self._call_location = prev_call_location
        self._caller = prev_caller

        return None

    # Nodes not located inside definition

    def dispatch_node(self, node: clang.cindex.Cursor):
        """Visit a cursor outside any function body.

        Function-like definitions are analyzed; any other cursor (namespaces,
        records and everything else) is descended into.
        """
        if node.is_definition() and node.kind in [
                CursorKind.CXX_METHOD, CursorKind.FUNCTION_DECL,
                CursorKind.CONSTRUCTOR, CursorKind.CONVERSION_FUNCTION
        ]:
            return self.process_function_definition(node)

        return self.iterate_children(node.get_children(), self.dispatch_node)

    def process_translation_unit(
            self, translation_unit: clang.cindex.TranslationUnit):
        """Analyze the top-level cursors of ``translation_unit``.

        Cursors from files included by earlier units are skipped. This
        unit's includes are then recorded for later units.
        """
        self._level += 1
        for child in translation_unit.cursor.get_children():
            if self._is_ignored_top_level(child):
                continue
            if self._processed_headers:
                filename = os.path.abspath(child.location.file.name)
                if filename in self._processed_headers:
                    continue
            self.dispatch_node(child)
        self._level -= 1

        self._processed_headers.update([
            os.path.abspath(str(hdr.source))
            for hdr in translation_unit.get_includes()
        ])


@dataclass
class CompileCommand:
    """One libclang parse job: a file, its compiler arguments and its cwd.

    Callers in this script set ``refid`` to the input's index on the command
    line; it keeps generated PCH file names unique. ``directory`` is the
    working directory used while parsing. By default it is the current
    directory at construction time.
    """
    refid: int
    filename: str
    args: list[str]
    directory: str = dataclasses.field(
        default_factory=os.getcwd)


def get_filter_funcs(verilator_root: str):
    """Build the cursor filters used by CallAnnotationsValidator.

    Returns:
        ``(is_ignored_top_level, is_ignored_def, is_ignored_call)``. Each is
        a predicate that is True for cursors outside ``verilator_root``;
        definitions and calls named ``__*`` are ignored too, and so are calls
        into ``std``.
    """
    root_prefix = os.path.abspath(verilator_root) + "/"

    def outside_root(cursor: clang.cindex.Cursor) -> bool:
        # Cursors without a file count as outside the Verilator root.
        location_file = cursor.location.file
        if not location_file:
            return True
        return not os.path.abspath(location_file.name).startswith(root_prefix)

    def is_ignored_top_level(node: clang.cindex.Cursor) -> bool:
        # Anything defined in a header outside Verilator root
        return outside_root(node)

    def is_ignored_def(node: clang.cindex.Cursor,
                       refd: clang.cindex.Cursor) -> bool:
        # Reserved `__*` names, or definitions outside Verilator root
        return str(refd.spelling).startswith("__") or outside_root(node)

    def is_ignored_call(refd: clang.cindex.Cursor) -> bool:
        # __*
        if str(refd.spelling).startswith("__"):
            return True
        # std::*
        qualified = fully_qualified_name(refd)
        if qualified and qualified[0] == "std":
            return True
        # Anything declared in a header outside Verilator root
        return outside_root(refd)

    return (is_ignored_top_level, is_ignored_def, is_ignored_call)


def precompile_header(compile_command: CompileCommand, tmp_dir: str) -> str:
    """Precompile ``compile_command.filename`` into a PCH file in ``tmp_dir``.

    The header is parsed with ``compile_command.args`` and with
    ``compile_command.directory`` as the working directory. The previous
    working directory is always restored, also when libclang or the OS
    raises.

    Returns:
        Path of the saved PCH file. Returns "" if parsing produced
        error-level diagnostics or a libclang/OS error occurred; a warning
        and the collected error messages are printed in that case.
    """
    errors: list[str] = []
    try:
        initial_cwd = os.getcwd()
        os.chdir(compile_command.directory)
        try:
            index = Index.create()
            translation_unit = index.parse(compile_command.filename,
                                           compile_command.args)
            # Error and Fatal diagnostics both mean the AST is incomplete;
            # such a unit must not be saved as a PCH.
            errors = [
                str(diag) for diag in translation_unit.diagnostics
                if diag.severity >= clang.cindex.Diagnostic.Error
            ]
            if not errors:
                pch_file = os.path.join(
                    tmp_dir,
                    f"{compile_command.refid:02}_{os.path.basename(compile_command.filename)}.pch"
                )
                translation_unit.save(pch_file)
                return pch_file
        finally:
            os.chdir(initial_cwd)
    except (TranslationUnitSaveError, TranslationUnitLoadError,
            OSError) as exception:
        # Best effort: a header that cannot be precompiled is skipped.
        errors.append(str(exception))

    print(
        f"%Warning: Precompiling failed, skipping: {compile_command.filename}")
    for error in errors:
        print(f"    {error}")
    return ""


# Compile and analyze inputs in a single process.
def run_analysis(ccl: Iterable[CompileCommand],
                 pccl: Iterable[CompileCommand],
                 diagnostic_cb: Callable[[Diagnostic], None],
                 verilator_root: str):
    """Precompile ``pccl`` headers, then analyze each ``ccl`` file in-process.

    Every successfully built PCH is added with ``-include-pch`` to the
    arguments of each analyzed file. Findings go to ``diagnostic_cb``.
    """
    filters = get_filter_funcs(verilator_root)

    with tempfile.TemporaryDirectory(
            prefix="verilator_clang_check_attributes_") as tmp_dir:
        pch_args: list[str] = []
        for header_cmd in pccl:
            pch_path = precompile_header(header_cmd, tmp_dir)
            if pch_path:
                pch_args.extend(("-include-pch", pch_path))

        # filters is (is_ignored_top_level, is_ignored_def, is_ignored_call),
        # matching the validator's positional parameters.
        validator = CallAnnotationsValidator(diagnostic_cb, *filters)
        for cmd in ccl:
            validator.compile_and_analyze_file(cmd.filename,
                                               cmd.args + pch_args,
                                               cmd.directory)


class ParallelAnalysisProcess:
    """Per-worker state and task functions for run_parallel_analysis.

    ``init_data`` is the multiprocessing.Pool initializer; the other static
    methods are the tasks submitted to the pool. State is kept on the class,
    and each worker process has its own copy.
    """
    cav: Optional[CallAnnotationsValidator] = None
    # Diagnostics collected by the current analyze_cpp_file() call. This
    # class is not a dataclass, so a real set is needed here; a
    # dataclasses.field() placeholder would just be a Field object.
    diags: set[Diagnostic] = set()
    tmp_dir: str = ""

    @staticmethod
    def init_data(verilator_root: str, tmp_dir: str):
        """Create this worker's validator and remember the PCH directory."""
        (is_ignored_top_level, is_ignored_def,
         is_ignored_call) = get_filter_funcs(verilator_root)

        ParallelAnalysisProcess.cav = CallAnnotationsValidator(
            ParallelAnalysisProcess._diagnostic_handler, is_ignored_top_level,
            is_ignored_def, is_ignored_call)
        ParallelAnalysisProcess.tmp_dir = tmp_dir

    @staticmethod
    def _diagnostic_handler(diag: Diagnostic):
        ParallelAnalysisProcess.diags.add(diag)

    @staticmethod
    def analyze_cpp_file(compile_command: CompileCommand) -> set[Diagnostic]:
        """Analyze one file and return its (deduplicated) diagnostics."""
        ParallelAnalysisProcess.diags = set()
        assert ParallelAnalysisProcess.cav is not None
        ParallelAnalysisProcess.cav.compile_and_analyze_file(
            compile_command.filename, compile_command.args,
            compile_command.directory)
        return ParallelAnalysisProcess.diags

    @staticmethod
    def precompile_header(compile_command: CompileCommand) -> str:
        """Precompile one header into the worker's temporary directory."""
        return precompile_header(compile_command,
                                 ParallelAnalysisProcess.tmp_dir)


# Compile and analyze inputs in multiple processes.
def run_parallel_analysis(ccl: Iterable[CompileCommand],
                          pccl: Iterable[CompileCommand],
                          diagnostic_cb: Callable[[Diagnostic], None],
                          jobs_count: int, verilator_root: str):
    """Precompile ``pccl`` and analyze ``ccl`` using ``jobs_count`` processes.

    Every successfully built PCH is passed with ``-include-pch`` to each
    analyzed file. The caller's CompileCommand objects are not modified.
    Diagnostics from all workers are forwarded to ``diagnostic_cb``.
    """
    prefix = "verilator_clang_check_attributes_"
    with tempfile.TemporaryDirectory(prefix=prefix) as tmp_dir:
        with multiprocessing.Pool(
                processes=jobs_count,
                initializer=ParallelAnalysisProcess.init_data,
                initargs=[verilator_root, tmp_dir]) as pool:
            extra_args = []
            # Ordered imap: the -include-pch order follows the input order,
            # as in run_analysis, regardless of which worker finishes first.
            for pch_file in pool.imap(
                    ParallelAnalysisProcess.precompile_header, pccl):
                if pch_file:
                    extra_args += ["-include-pch", pch_file]

            # New command objects instead of mutating the caller's. This also
            # consumes `ccl` exactly once, so one-shot iterators work.
            commands = [
                dataclasses.replace(cmd, args=cmd.args + extra_args)
                for cmd in ccl
            ]

            for diags in pool.imap_unordered(
                    ParallelAnalysisProcess.analyze_cpp_file, commands, 1):
                for diag in diags:
                    diagnostic_cb(diag)


class TopDownSummaryPrinter():
    """Group diagnostics by the analyzed function and print a report."""

    @dataclass
    class FunctionCallees:
        # The analyzed function the diagnostics were reported for.
        info: FunctionInfo
        # Offending callees (field name kept as-is for compatibility).
        calees: set[FunctionInfo]
        # Declaration whose annotations differ from the definition's, if any.
        mismatch: Optional[FunctionInfo] = None
        # Last reported offending callee per category, if any.
        non_mt_safe_call_in_mt_safe: Optional[FunctionInfo] = None
        non_pure_call_in_pure: Optional[FunctionInfo] = None

    def __init__(self):
        self._is_first_group = True

        # Keyed by the analyzed function's USR.
        self._funcs: dict[str, TopDownSummaryPrinter.FunctionCallees] = {}
        # USRs of all callees reported for a context violation.
        self._unsafe_in_safe: set[str] = set()

    def begin_group(self, label):
        """Print a group header, separated from the previous group."""
        if not self._is_first_group:
            print()

        print(f"%Error: {label}")

        self._is_first_group = False

    def handle_diagnostic(self, diag: Diagnostic):
        """Record ``diag`` under its source function."""
        usr = diag.source.usr
        func = self._funcs.get(usr, None)
        if func is None:
            func = TopDownSummaryPrinter.FunctionCallees(diag.source, set())
            self._funcs[usr] = func
        if diag.kind == DiagnosticKind.ANNOTATIONS_DEF_DECL_MISMATCH:
            func.mismatch = diag.target
        else:
            func.calees.add(diag.target)
            self._unsafe_in_safe.add(diag.target.usr)
            if diag.kind == DiagnosticKind.NON_MT_SAFE_CALL_IN_MT_SAFE_CTX:
                func.non_mt_safe_call_in_mt_safe = diag.target
            elif diag.kind == DiagnosticKind.NON_PURE_CALL_IN_PURE_CTX:
                func.non_pure_call_in_pure = diag.target

    def print_summary(self, root_dir: str):
        """Print one error group per reported function, sorted by label.

        File paths are shown relative to ``root_dir``. The location and
        annotation columns are padded to a common width across all groups.
        """
        # A list rather than a dict keyed by label: distinct functions can
        # produce identical labels (same name and reason, e.g. same-named
        # file-local functions), and each group must still be printed.
        row_groups: list[tuple[str, list[list[str]]]] = []
        column_widths = [0, 0]
        for func in sorted(self._funcs.values(),
                           key=lambda func:
                           (func.info.file, func.info.line, func.info.usr)):
            func_info = func.info
            relfile = os.path.relpath(func_info.file, root_dir)

            row_group = []
            name = f"\"{func_info.name}\" "
            if func.mismatch:
                name += "declaration does not match definition"
            elif func.non_mt_safe_call_in_mt_safe:
                name += "is mtsafe but calls non-mtsafe function(s)"
            elif func.non_pure_call_in_pure:
                name += "is pure but calls non-pure function(s)"
            else:
                name += "for unknown reason (please add description)"

            if func.mismatch:
                mrelfile = os.path.relpath(func.mismatch.file, root_dir)
                row_group.append([
                    f"{mrelfile}:{func.mismatch.line}:",
                    f"[{func.mismatch.annotations}]",
                    func.mismatch.name + " [declaration]"
                ])

            row_group.append([
                f"{relfile}:{func_info.line}:", f"[{func_info.annotations}]",
                func_info.name
            ])

            for callee in sorted(func.calees,
                                 key=lambda func:
                                 (func.file, func.line, func.usr)):
                crelfile = os.path.relpath(callee.file, root_dir)
                row_group.append([
                    f"{crelfile}:{callee.line}:", f"[{callee.annotations}]",
                    "  " + callee.name
                ])

            row_groups.append((name, row_group))

            # Only the first two columns are padded; the name column is last.
            for row in row_group:
                for row_id, value in enumerate(row[0:-1]):
                    column_widths[row_id] = max(column_widths[row_id],
                                                len(value))

        # sorted() is stable, so groups with equal labels keep file/line order.
        for label, rows in sorted(row_groups, key=lambda group: group[0]):
            self.begin_group(label)
            for row in rows:
                print(f"{row[0]:<{column_widths[0]}}  "
                      f"{row[1]:<{column_widths[1]}}    "
                      f"{row[2]}")
        print(
            f"Number of functions reported unsafe: {len(self._unsafe_in_safe)}"
        )


def main():
    """Command-line entry point: parse options, analyze files, print summary."""
    default_verilator_root = os.path.abspath(
        os.path.join(os.path.dirname(__file__), ".."))

    parser = argparse.ArgumentParser(
        allow_abbrev=False,
        formatter_class=argparse.RawDescriptionHelpFormatter,
        description="""Check function annotations for correctness""",
        epilog=
        """Copyright 2022-2023 by Wilson Snyder. Verilator is free software;
    you can redistribute it and/or modify it under the terms of either the GNU
    Lesser General Public License Version 3 or the Apache License 2.0.
    SPDX-License-Identifier: LGPL-3.0-only OR Apache-2.0""")

    parser.add_argument("--verilator-root",
                        type=str,
                        default=default_verilator_root,
                        help="Path to Verilator sources root directory.")
    parser.add_argument("--jobs",
                        "-j",
                        type=int,
                        default=0,
                        help="Number of parallel jobs to use.")
    parser.add_argument("--cxxflags",
                        type=str,
                        default=None,
                        help="Flags passed to clang++.")
    parser.add_argument(
        "--compilation-root",
        type=str,
        default=os.getcwd(),
        help="Directory used as CWD when compiling source files.")
    parser.add_argument(
        "-c",
        "--precompile",
        action="append",
        help="Header file to be precompiled and cached at the start.")
    parser.add_argument("file",
                        type=str,
                        nargs="+",
                        help="Source file to analyze.")

    cmdline = parser.parse_args()

    if cmdline.jobs < 0:
        parser.error("--jobs must not be negative")
    if cmdline.jobs == 0:
        # os.sched_getaffinity() is not available on every platform; fall
        # back to the total CPU count there.
        if hasattr(os, "sched_getaffinity"):
            cmdline.jobs = max(1, len(os.sched_getaffinity(0)))
        else:
            cmdline.jobs = max(1, os.cpu_count() or 1)

    if not cmdline.compilation_root:
        cmdline.compilation_root = cmdline.verilator_root

    verilator_root = os.path.abspath(cmdline.verilator_root)
    compilation_root = os.path.abspath(cmdline.compilation_root)

    default_cxx_flags = [
        f"-I{verilator_root}/src",
        f"-I{verilator_root}/include",
        f"-I{verilator_root}/src/obj_opt",
        "-fcoroutines-ts",
    ]
    if cmdline.cxxflags is not None:
        cxxflags = shlex.split(cmdline.cxxflags)
    else:
        cxxflags = default_cxx_flags

    precompile_commands_list = []

    if cmdline.precompile:
        hdr_cxxflags = ['-xc++-header'] + cxxflags
        for refid, file in enumerate(cmdline.precompile):
            filename = os.path.abspath(file)
            compile_command = CompileCommand(refid, filename, hdr_cxxflags,
                                             compilation_root)
            precompile_commands_list.append(compile_command)

    compile_commands_list = []
    for refid, file in enumerate(cmdline.file):
        filename = os.path.abspath(file)
        compile_command = CompileCommand(refid, filename, cxxflags,
                                         compilation_root)
        compile_commands_list.append(compile_command)

    summary_printer = TopDownSummaryPrinter()

    if cmdline.jobs == 1:
        run_analysis(compile_commands_list, precompile_commands_list,
                     summary_printer.handle_diagnostic, verilator_root)
    else:
        run_parallel_analysis(compile_commands_list, precompile_commands_list,
                              summary_printer.handle_diagnostic, cmdline.jobs,
                              verilator_root)

    summary_printer.print_summary(verilator_root)


if __name__ == "__main__":
    # Run the checker only when executed as a script, not when imported.
    main()
