#!/usr/bin/python3

# pylint: disable=invalid-name
# pylint: enable=invalid-name

"""Run given checks on the log output."""

import argparse
import json
import pathlib
import re
import shlex
import subprocess
import sys


class Check:
    """Collection of checks that can be run against a captured log.

    The log is split into named command outputs once at construction time.
    Every failed check prints an ``ERROR:`` line and increments ``errors``.
    """

    def __init__(self, log: str) -> None:
        # Number of failed checks so far.
        self.errors = 0
        # Mapping from the output key (e.g. a command line) to its output.
        self.command_outputs = self._extract_command_outputs_from_log(log)

    def _error(self, msg: str) -> None:
        """Print the given error message and count it as one failure."""
        print("ERROR: " + msg)
        self.errors += 1

    @staticmethod
    def _extract_command_outputs_from_log(log: str) -> dict:
        """Extract command outputs from the given log output.

        The output must be framed by a header and a footer line. The header
        line contains the key surrounded by 10 # characters. The footer
        line consist of 40 # characters. Returns a mapping from the output
        key to the command output.
        """
        marker = "#" * 10
        footer = "#" * 40
        # The footer is anchored to the start of a line instead of requiring
        # a newline inside the captured group. That way an empty output
        # (header directly followed by the footer) ends at its own footer
        # instead of swallowing the following sections.
        matches = re.findall(
            f"^{marker} ([^#]+) {marker}\n(.*?)^{footer}$",
            log,
            flags=re.DOTALL | re.MULTILINE,
        )
        return dict(matches)

    def get_commands_from_ps_output(self) -> list:
        """Get list of command from `ps -ww aux` output.

        The first line (the column header) is skipped. The command is the
        eleventh whitespace separated column and may itself contain spaces.
        """
        ps_output = self.command_outputs["ps -ww aux"]
        lines = ps_output.strip().split("\n")[1:]
        commands = []
        for line in lines:
            # maxsplit=10 keeps the whole command line in the last column.
            columns = re.split(r"\s+", line, maxsplit=10)
            commands.append(columns[10])
        return commands

    def get_ip_addr(self) -> list:
        """Get IP address information from `ip addr` JSON output."""
        ip_addr = json.loads(self.command_outputs["ip -json addr"])
        assert isinstance(ip_addr, list)
        return ip_addr

    def get_ip_route(self) -> list:
        """Get IP route information from `ip route` JSON output."""
        ip_route = json.loads(self.command_outputs["ip -json route"])
        assert isinstance(ip_route, list)
        return ip_route

    def _is_check(self, name: str) -> bool:
        """Return True if the name refers to a public, callable method.

        Plain data attributes (like `errors`) and private helpers must not be
        mistaken for check names, because they can legitimately show up as
        check arguments (e.g. as a hostname pattern).
        """
        if name.startswith("_"):
            return False
        return callable(getattr(self, name, None))

    def run_checks(self, args) -> int:
        """Run the checks and return the number of errors found.

        `args` is a flat list of strings: the name of a check followed by
        the arguments for that check, then the name of the next check and
        so on. The public methods of this class can be used as checks. An
        argument that is not the name of such a method is passed to the
        preceding check. The first element must be the name of a check,
        otherwise an error is logged and nothing is run.
        """
        if not args:
            return self.errors

        if not self._is_check(args[0]):
            self._error(f"Check '{args[0]}' not found.")
            return self.errors
        check = getattr(self, args[0])
        check_args = []

        for arg in args[1:]:
            if not self._is_check(arg):
                check_args.append(arg)
                continue

            # A new check name ends the argument list of the previous check.
            check(*check_args)
            check = getattr(self, arg)
            check_args = []

        # Run the last check collected by the loop above.
        check(*check_args)
        return self.errors

    def _check_is_subset(self, expected, actual) -> None:
        """Check that the actual set contains no unexpected entries.

        Both parameters are sets. Log one error if `actual` has entries that
        are not part of `expected`.
        """
        unexpected = actual - expected
        if unexpected:
            self._error(f"Not expected entries: {unexpected}")

    def _check_is_subdict(
        self,
        expected_dict,
        actual_dict,
        log_prefix: str,
    ) -> None:
        """Check that the first dictionary is a subset of the second one.

        Log errors if differences are found: one error if keys are missing
        and one error per key whose value differs. A missing key is compared
        as if its value were the empty string. Every message is prefixed
        with `log_prefix`.
        """
        missing_keys = set(expected_dict.keys()) - set(actual_dict.keys())
        if missing_keys:
            self._error(f"{log_prefix}Missing keys: {missing_keys}")
        for key, expected_value in sorted(expected_dict.items()):
            actual_value = actual_dict.get(key, "")
            if expected_value != actual_value:
                self._error(
                    f"{log_prefix}Value for key '{key}' differs:"
                    f" '{expected_value}' expected, but got '{actual_value}'"
                )

    # Below are all checks that the user might call.

    def has_hostname(self, hostname_pattern: str) -> None:
        """Check that the hostname matches the given regular expression."""
        hostname = self.command_outputs["hostname"].strip()
        if re.fullmatch(hostname_pattern, hostname):
            print(f"hostname '{hostname}' matches pattern '{hostname_pattern}'")
            return
        self._error(
            f"hostname '{hostname}' does not match"
            f" expected pattern '{hostname_pattern}'"
        )

    def has_interface_mtu(self, device_pattern: str, expected_mtu: str) -> None:
        """Check that a matching network device has the expected MTU set.

        Only the first device whose name fully matches `device_pattern` is
        inspected. The MTU is compared as a string.
        """
        for device in self.get_ip_addr():
            if not re.fullmatch(device_pattern, device["ifname"]):
                continue
            if str(device["mtu"]) == expected_mtu:
                print(f"device {device['ifname']} has MTU {device['mtu']}")
                return
            self._error(
                f"device {device['ifname']} has MTU {device['mtu']}"
                f" but expected {expected_mtu}"
            )
            return
        self._error(f"no link found that matches '{device_pattern}'")

    def has_ip_addr(self, family: str, addr_pattern: str, device_pattern: str) -> None:
        """Check that a matching network device has a matching IP address.

        `family` is the address family as found in the `ip -json addr`
        output (`inet` or `inet6`). Only addresses with global scope are
        considered, and the first such address on a matching device decides
        the result. The address is matched as `<address>/<prefix length>`.
        """
        for device in self.get_ip_addr():
            if not re.fullmatch(device_pattern, device["ifname"]):
                continue
            for addr in device["addr_info"]:
                if addr["family"] != family or addr["scope"] != "global":
                    continue
                address = f"{addr['local']}/{addr['prefixlen']}"
                if re.fullmatch(addr_pattern, address):
                    print(f"found addr {address} for {device['ifname']}: {addr}")
                    return
                self._error(
                    f"addr {address} for {device['ifname']}"
                    f" does not match {addr_pattern}: {addr}"
                )
                return
        # Fall back to the raw family name so that an unknown family is
        # reported as a failed check instead of raising a KeyError.
        name = {"inet": "IPv4", "inet6": "IPv6"}.get(family, family)
        self._error(
            f"no link found that matches '{device_pattern}' and has an {name} address"
        )

    def has_ipv4_addr(self, addr_pattern: str, device_pattern: str) -> None:
        """Check that a matching network device has a matching IPv4 address."""
        self.has_ip_addr("inet", addr_pattern, device_pattern)

    def has_ipv6_addr(self, addr_pattern: str, device_pattern: str) -> None:
        """Check that a matching network device has a matching IPv6 address."""
        self.has_ip_addr("inet6", addr_pattern, device_pattern)

    def has_ipv4_default_route(self, gateway_pattern: str, device_pattern: str) -> None:
        """Check that the IPv4 default route is via a matching gateway and device.

        Every default route that does not match is logged as an error. The
        check stops at the first default route that matches both patterns.
        """
        for route in self.get_ip_route():
            if route["dst"] != "default":
                continue
            # A default route might have no gateway or device entry. Treat
            # a missing entry as an empty string so that the route is
            # reported as not matching instead of raising a KeyError.
            gateway = route.get("gateway", "")
            device = route.get("dev", "")
            if not re.fullmatch(gateway_pattern, gateway) or not re.fullmatch(
                device_pattern, device
            ):
                self._error(
                    f"Default IPv4 route does not match expected gateway pattern"
                    f" '{gateway_pattern}' or dev pattern '{device_pattern}': {route}"
                )
                continue
            print(f"found IPv4 default route via {gateway} for {device}: {route}")
            return
        self._error("no IPv4 default route found")

    def has_net_conf(self, min_expected_files: str, *expected_net_confs: str) -> None:
        """Compare the /run/net*.conf files.

        There must be at least `min_expected_files` /run/net*.conf files in
        the log output and no unexpected one. The format of each expected
        entry is `<file name>=<expected content>`. Every key in the expected
        content must be present with the same value in the actual file.
        """

        expected = dict(nc.split("=", maxsplit=1) for nc in expected_net_confs)
        prog = re.compile(r"/run/net[^#]+\.conf")
        got = {
            key: value for key, value in self.command_outputs.items() if prog.match(key)
        }

        if len(got) < int(min_expected_files):
            self._error(
                f"Expected at least {min_expected_files} /run/net*.conf files,"
                f" but got only {len(got)}: {set(got.keys())}"
            )
        self._check_is_subset(set(expected.keys()), set(got.keys()))

        for net_dev in sorted(got.keys()):
            log_prefix = f"{net_dev}: "
            # Files without an expectation are compared against an empty one.
            expected_net_conf = parse_net_conf(expected.get(net_dev, ""))
            actual_net_conf = parse_net_conf(got.get(net_dev, ""))
            self._check_is_subdict(expected_net_conf, actual_net_conf, log_prefix)
            print(f"compared {len(expected_net_conf)} items from {net_dev}")

    def has_no_running_processes(self) -> None:
        """Check that there are no remaining running processes from the initrd.

        After dropping the kernel processes exactly two processes are
        expected to be left (init and ps).
        """
        processes = drop_kernel_processes(self.get_commands_from_ps_output())

        udevd = "/lib/systemd/systemd-udevd --daemon --resolve-names=never"
        if udevd in processes and get_architecture() != "amd64":
            print(
                "Warning: Ignoring running /lib/systemd/systemd-udevd processes."
                " See https://launchpad.net/bugs/2025369"
            )
            processes = [p for p in processes if p != udevd]

        if len(processes) == 2:
            print(f"found only expected init and ps process: {processes}")
            return
        self._error(
            f"Expected only init and ps process, but got {len(processes)}: {processes}"
        )


def drop_kernel_processes(processes):
    """Return a list of processes with the kernel processes dropped.

    A command is treated as a kernel process only if it is completely
    wrapped in square brackets (like `[kthreadd]`). A command that merely
    starts with `[` or ends with `]` is kept.
    """
    return [p for p in processes if not (p.startswith("[") and p.endswith("]"))]


def get_architecture() -> str:
    """Ask dpkg for the architecture of the packages it installs.

    Runs `dpkg --print-architecture` and returns its standard output
    without surrounding whitespace. Raises subprocess.CalledProcessError
    if dpkg exits with a non-zero status.
    """
    result = subprocess.run(
        ["dpkg", "--print-architecture"],
        stdout=subprocess.PIPE,
        stderr=subprocess.PIPE,
        text=True,
        check=True,
    )
    architecture = result.stdout.strip()
    return architecture


def parse_net_conf(net_conf: str):
    """Turn the content of a /run/net*.conf file into a dictionary.

    The content is split with shell-like quoting rules and every resulting
    token is divided at its first `=` into key and value. Raises ValueError
    if a token contains no `=`.
    """
    parsed = {}
    for token in shlex.split(net_conf):
        key, value = token.split("=", maxsplit=1)
        parsed[key] = value
    return parsed


def main(arguments) -> int:
    """Run given checks on the log output. Return number of errors.

    `arguments` is the command line without the program name: the path to
    the log file followed by at least one check name with its arguments.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("log_file", metavar="log-file")
    parser.add_argument("checks", metavar="check", nargs="+")
    args = parser.parse_args(arguments)

    # A captured log can contain bytes outside of ASCII. Decode it as UTF-8
    # (a superset of ASCII) and replace undecodable bytes, so that stray
    # bytes do not abort the run before any check is executed.
    log = pathlib.Path(args.log_file).read_text(encoding="utf-8", errors="replace")
    check = Check(log)
    return check.run_checks(args.checks)


if __name__ == "__main__":
    # The number of errors returned by main() becomes the exit status.
    sys.exit(main(sys.argv[1:]))
