#!/usr/bin/env python3
#
# Wrapper script to be deployed on machines whose network interfaces should be
# controllable via the RawNetworkInterfaceDriver. A /etc/labgrid/helpers.yaml
# can deny access to network interfaces. See below.
#
# This is intended to be used via sudo. For example, add via visudo:
# %developers ALL = NOPASSWD: /usr/sbin/labgrid-raw-interface

import argparse
import os
import string
import sys

import yaml


def get_denylist():
    """Return the interface names this helper must refuse to operate on.

    The names are read from the ``raw-interface`` -> ``denied-interfaces`` list in
    /etc/labgrid/helpers.yaml. A missing ``raw-interface`` section or a missing
    ``denied-interfaces`` key is treated as an empty list. The loopback interface
    "lo" is always denied in addition to the configured names.

    Returns:
        list: The configured denied interface names followed by "lo".

    Raises:
        Exception: If the configuration file is missing, unreadable or not valid
            YAML, if the document or its ``raw-interface`` section is not a
            mapping, or if ``denied-interfaces`` is not a list.
    """
    denylist_file = "/etc/labgrid/helpers.yaml"
    try:
        with open(denylist_file) as stream:
            data = yaml.load(stream, Loader=yaml.SafeLoader)
    except (OSError, yaml.YAMLError) as e:
        raise Exception(f"No configuration file ({denylist_file}), inaccessable or invalid yaml") from e

    # An empty file loads as None, and a scalar or list document has no .get()
    # either. Fail with a readable message instead of a bare AttributeError.
    if not isinstance(data, dict):
        raise Exception(f"Configuration file ({denylist_file}) does not contain a mapping, please check it")

    section = data.get("raw-interface", {})
    if not isinstance(section, dict):
        raise Exception("raw-interface section is not a mapping, please check your configuration")

    denylist = section.get("denied-interfaces", [])
    if not isinstance(denylist, list):
        raise Exception("No explicit denied-interfaces or not a list, please check your configuration")

    # Build a new list so the loaded configuration data is not modified
    return [*denylist, "lo"]


def main(program, options):
    """Validate the request and replace this process with the requested tool.

    Args:
        program: Name of the tool to run, one of "tcpreplay", "tcpdump", "ip" or
            "ethtool".
        options: Parsed command line namespace. ``ifname`` is needed for every
            program. In addition ``count`` and ``timeout`` are read for tcpdump,
            ``action`` for ip, and ``subcommand`` plus the matching
            ``ethtool_change_args`` / ``ethtool_set_eee_args`` /
            ``ethtool_pause_args`` list for ethtool.

    Raises:
        ValueError: If the program, the interface name, the ethtool subcommand or
            one of the ethtool arguments is rejected.
        RuntimeError: If the binary for ``program`` cannot be found.
        Exception: Whatever get_denylist() raises for an unusable configuration.

    On success this function does not return, because os.execvp() replaces the
    current process with the tool.
    """
    # Check the command line input before reading the configuration file, so that
    # usage mistakes are reported as such.
    programs = ["tcpreplay", "tcpdump", "ip", "ethtool"]
    if program not in programs:
        raise ValueError(f"Invalid program {program} called with wrapper, valid programs are: {programs}")

    # "ifname" is only defined by the per-program parsers, so it can be absent
    # from the namespace; treat that like an empty name.
    ifname = getattr(options, "ifname", None)
    if not ifname:
        raise ValueError("Empty interface name.")
    if any((c == "/" or c.isspace()) for c in ifname):
        raise ValueError(f"Interface name '{ifname}' contains invalid characters.")
    if len(ifname) > 16:
        raise ValueError(f"Interface name '{ifname}' is too long.")

    if ifname in get_denylist():
        raise ValueError(f"Interface name '{ifname}' is denied in denylist.")

    args = [program]

    if program == "tcpreplay":
        args.append(f"--intf1={ifname}")
        # "-" is passed in place of a capture file name
        args.append("-")

    elif program == "tcpdump":
        args.append("-n")
        args.append(f"--interface={ifname}")
        # Write out each packet as it is received
        args.append("--packet-buffered")
        # Capture complete packets (for compatibility with older tcpdump versions)
        args.append("--snapshot-length=0")
        args.append("-w")
        args.append("-")

        # A count of 0 (or None) means no packet limit is passed on
        if options.count:
            args.append("-c")
            args.append(str(options.count))

        if options.timeout:
            # The timeout is implemented by specifying the number of seconds before rotating the
            # dump file, but limiting the number of files to 1
            args.append("-G")
            args.append(str(options.timeout))
            args.append("-W")
            args.append("1")

    elif program == "ip":
        args.append("link")
        args.append("set")
        args.append("dev")
        args.append(ifname)
        args.append(options.action)

    elif program == "ethtool":
        # subcommand -> (ethtool option, namespace attribute holding its arguments)
        ethtool_subcommands = {
            "change": ("--change", "ethtool_change_args"),
            "set-eee": ("--set-eee", "ethtool_set_eee_args"),
            "pause": ("--pause", "ethtool_pause_args"),
        }

        subcommand = getattr(options, "subcommand", None)
        if subcommand not in ethtool_subcommands:
            # Never fall through to running ethtool without a validated subcommand
            raise ValueError(
                f"Invalid ethtool subcommand {subcommand}, valid subcommands are: {list(ethtool_subcommands)}"
            )

        ethtool_option, args_attr = ethtool_subcommands[subcommand]
        ethtool_args = getattr(options, args_attr)

        # Only plain words are passed on; anything that looks like another option is rejected
        allowed_chars = set(string.ascii_letters + string.digits + "-/:")
        for arg in ethtool_args:
            if arg.startswith("-") or not allowed_chars.issuperset(arg):
                raise ValueError(f"ethtool {ethtool_option} arg '{arg}' contains invalid characters")

        args.append(ethtool_option)
        args.append(ifname)
        args.extend(ethtool_args)

    try:
        os.execvp(args[0], args)
    except FileNotFoundError as e:
        raise RuntimeError(f"Missing {program} binary") from e


def build_parser():
    """Create the command line parser of the wrapper.

    The first positional argument selects the program ("tcpdump", "tcpreplay",
    "ip" or "ethtool") and is stored as ``program``. Every program takes the
    interface name as ``ifname``. "ethtool" additionally requires one of the
    subcommands "change", "set-eee" or "pause", stored as ``subcommand``, whose
    remaining arguments are collected into a list.

    Returns:
        argparse.ArgumentParser: The configured parser.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("-d", "--debug", action="store_true", default=False, help="enable debug mode")
    subparsers = parser.add_subparsers(dest="program", help="program to run")
    # Reject an invocation without a program as a usage error. This is set as an
    # attribute because add_subparsers() only accepts required=True from Python 3.7 on.
    subparsers.required = True

    # tcpdump
    tcpdump_parser = subparsers.add_parser("tcpdump")
    tcpdump_parser.add_argument("ifname", type=str, help="interface name")
    # count is a mandatory positional, so its default is never used; main() passes
    # no packet limit to tcpdump when it is 0
    tcpdump_parser.add_argument("count", type=int, default=None, help="amount of frames to capture while recording")
    tcpdump_parser.add_argument(
        "--timeout", type=int, default=None, help="Amount of time to capture while recording. 0 means capture forever"
    )

    # tcpreplay
    tcpreplay_parser = subparsers.add_parser("tcpreplay")
    tcpreplay_parser.add_argument("ifname", type=str, help="interface name")

    # ip
    ip_parser = subparsers.add_parser("ip")
    ip_parser.add_argument("ifname", type=str, help="interface name")
    ip_parser.add_argument("action", type=str, choices=["up", "down"], help="action, one of {%(choices)s}")

    # ethtool
    ethtool_parser = subparsers.add_parser("ethtool")
    ethtool_subparsers = ethtool_parser.add_subparsers(dest="subcommand")
    # "ethtool" alone is not a complete request either
    ethtool_subparsers.required = True

    # ethtool: change
    ethtool_change_parser = ethtool_subparsers.add_parser("change")
    ethtool_change_parser.add_argument("ifname", type=str, help="interface name")
    ethtool_change_parser.add_argument(
        "ethtool_change_args", metavar="ARG", nargs=argparse.REMAINDER, help="ethtool --change args"
    )

    # ethtool: set-eee
    ethtool_set_eee_parser = ethtool_subparsers.add_parser("set-eee")
    ethtool_set_eee_parser.add_argument("ifname", type=str, help="interface name")
    ethtool_set_eee_parser.add_argument(
        "ethtool_set_eee_args", metavar="ARG", nargs=argparse.REMAINDER, help="ethtool --set-eee args"
    )

    # ethtool: pause
    ethtool_pause_parser = ethtool_subparsers.add_parser("pause")
    ethtool_pause_parser.add_argument("ifname", type=str, help="interface name")
    ethtool_pause_parser.add_argument(
        "ethtool_pause_args", metavar="ARG", nargs=argparse.REMAINDER, help="ethtool --pause args"
    )

    return parser


if __name__ == "__main__":
    parser = build_parser()
    args = parser.parse_args()
    try:
        main(args.program, args)
    except Exception as e:  # pylint: disable=broad-except
        # Report every failure as a single line; the traceback is only shown in debug mode
        if args.debug:
            import traceback

            traceback.print_exc(file=sys.stderr)
        print(f"ERROR: {e}", file=sys.stderr)
        # sys.exit() instead of the exit() builtin, which is only provided by the site module
        sys.exit(1)
