#!/usr/bin/python3
# Copyright 2024 Chris Hofstaedtler <chris@hofstaedtler.name>
# SPDX-License-Identifier: GPL-3.0+
#
# Use `black --line-length 120 bin/reform-mcu-tool` to reformat.
#
# Installation
#
# Install `python3-usb1` as a prerequisite. No other dependencies should be necessary.
#
# Do not run shellcheck on this file
# shellcheck disable=SC1071

# Program description shown at the top of `--help` (argparse `description` in parse_args()).
# NOTE(review): MCU_TYPES also lists pocket-input-1.0 and rp2350-boot; this text may be out of date.
_DOC = """
Tool to interact with Microcontrollers used by MNT Research in Reform projects.

It can currently talk to the pocket sysctl firmware, and to the RP2040 bootrom.
"""

# Usage examples shown after the argument help (argparse `epilog`); parse_args() uses
# RawTextHelpFormatter, so the line breaks and indentation below are kept verbatim.
_EPILOG = """
Example usage:

To reboot the pocket sysctl into the bootrom:

  $ sudo reform-mcu-tool bootsel pocket-sysctl-1.0

To reset the RP2040 bootrom back into the application:

  $ sudo reform-mcu-tool reset rp2040-boot

"""

import argparse
import struct
import sys

try:
    import usb1
except ModuleNotFoundError as except_inst:
    raise RuntimeError('Library "usb1" not found, please install python3-usb1') from except_inst


# USB vendor IDs of the supported devices.
USB_VID_PIDCODES = 0x1209
USB_VID_RASPBERRY = 0x2E8A
# USB product IDs: MNT application firmwares (under USB_VID_PIDCODES) and the
# Raspberry Pi RP-series bootroms (under USB_VID_RASPBERRY), as paired in MCU_TYPES.
USB_PID_MNT_POCKET_REFORM_INPUT_10 = 0x6D06
USB_PID_MNT_POCKET_REFORM_SYSCTL_10 = 0x6D07
USB_PID_RASPBERRY_RP2040_BOOT = 0x0003
USB_PID_RASPBERRY_RP2350_BOOT = 0x000F

# Bit flags stored as the third element of each MCU_TYPES entry. The actions test them to
# pick a reset mechanism: IS_APP targets are reset through the vendor Reset interface,
# IS_RP_BOOTROM targets through the PICOBOOT interface.
IS_APP = 0b0000_0001
IS_RP_BOOTROM = 0b0000_0010

# Known targets, keyed by the TARGET name accepted on the command line:
# name -> (USB vendor ID, USB product ID, flags).
MCU_TYPES = {
    "pocket-input-1.0": (USB_VID_PIDCODES, USB_PID_MNT_POCKET_REFORM_INPUT_10, IS_APP),
    "pocket-sysctl-1.0": (USB_VID_PIDCODES, USB_PID_MNT_POCKET_REFORM_SYSCTL_10, IS_APP),
    "rp2040-boot": (USB_VID_RASPBERRY, USB_PID_RASPBERRY_RP2040_BOOT, IS_RP_BOOTROM),
    "rp2350-boot": (USB_VID_RASPBERRY, USB_PID_RASPBERRY_RP2350_BOOT, IS_RP_BOOTROM),
}

# Subclass and protocol that identify the vendor-specific (class 0xFF) Reset interface;
# matched by find_reset_interface().
RESET_INTERFACE_SUBCLASS = 0
RESET_INTERFACE_PROTOCOL = 1

# bRequest values of the class control request sent to the Reset interface.
RESET_REQUEST_BOOTSEL = 1  # sent by the "bootsel" action (reboot into the bootrom)
RESET_REQUEST_FLASH = 2  # sent by the "reset" action (reboot into the application)

# PICOBOOT protocol constants used to build the bootrom command in picoboot_reset().
PICOBOOT_MAGIC = 0x431FD10B  # value of the command's leading magic field
PC_REBOOT = 0x2  # command ID of the REBOOT command
PICOBOOT_IF_RESET = 0x41  # NOTE(review): not referenced by the code in this file as seen


def device_reset(handle: usb1.USBDeviceHandle, reset_interface: int, reset_request: int):
    """Issue a class-specific reset request on the device's Reset interface.

    Claims ``reset_interface`` and sends a zero-length, host-to-device control request
    with bRequest ``reset_request`` (wValue 0, wIndex = interface number, 2 s timeout).

    The interface is released only when the request completes normally. An I/O or pipe
    error is treated as success: the MCU resets and drops off the bus, so the request
    usually does not complete cleanly. Any other USB error propagates.
    """
    bm_request_type = usb1.TYPE_CLASS | usb1.RECIPIENT_INTERFACE
    handle.claimInterface(reset_interface)
    try:
        handle.controlWrite(bm_request_type, reset_request, 0, reset_interface, b"", timeout=2000)
    except (usb1.USBErrorIO, usb1.USBErrorPipe):
        # Expected: the MCU reset and vanished from USB before acknowledging. Which error
        # shows up appears to depend on timing and firmware. Nothing left to release.
        return
    handle.releaseInterface(reset_interface)


def picoboot_reset(handle: usb1.USBDeviceHandle, picoboot_interface: usb1.USBInterfaceSetting):
    """Send a PICOBOOT REBOOT command to an RP bootrom and wait for its acknowledgement.

    ``picoboot_interface`` must have its bulk OUT endpoint at index 0 and its bulk IN
    endpoint at index 1 (the layout find_picoboot_interface() checks for). The interface
    is claimed, both endpoints' halt conditions are cleared, the command is written, one
    IN transfer is read as the acknowledgement, and the interface is released.

    Raises:
        ValueError: if the command was not written completely.
        usb1.USBError subclasses from the underlying transfers propagate unchanged.
    """
    interface_number = picoboot_interface.getNumber()
    handle.claimInterface(interface_number)
    out_address = picoboot_interface[0].getAddress()
    handle.clearHalt(out_address)
    in_address = picoboot_interface[1].getAddress()
    handle.clearHalt(in_address)

    # REBOOT arguments: initial PC, initial SP and the delay before rebooting.
    # pc = sp = 0 presumably selects a normal boot rather than a jump to an address -- TODO confirm.
    # The argument area following the header is 16 bytes; only 12 are used, zero-pad the rest.
    pc = 0
    sp = 0
    delay_ms = 500
    reboot_cmd = struct.pack("<LLL", pc, sp, delay_ms)
    reboot_cmd_padded = reboot_cmd + struct.pack("<L", 0)

    # Header: magic, token, command ID, argument length, 2 unused bytes, and a transfer
    # length of 0 (no data phase). Header (16 bytes) + padded arguments (16 bytes) = 32 bytes.
    # NOTE(review): the RP2350 bootrom may not accept PC_REBOOT (it documents a separate
    # REBOOT2 command); confirm that "reset rp2350-boot" actually reboots the chip.
    token = 1
    picoboot_cmd = struct.pack("<LLBBHL", PICOBOOT_MAGIC, token, PC_REBOOT, len(reboot_cmd), 0, 0) + reboot_cmd_padded
    expected = len(picoboot_cmd)
    sent = handle.bulkWrite(out_address, picoboot_cmd, timeout=3000)
    if sent != expected:
        raise ValueError(f"Expected to send picoboot_cmd of size {expected}, but sent {sent}")

    # Wait for the command's acknowledgement on the IN endpoint; its payload is not used.
    handle.bulkRead(in_address, 1, timeout=10000)

    handle.releaseInterface(interface_number)


def find_reset_interface(device: usb1.USBDevice) -> int | None:
    """Find the vendor-specific Reset interface of ``device``.

    A setting qualifies when:
    - its class is 0xFF;
    - its subclass equals RESET_INTERFACE_SUBCLASS;
    - its protocol equals RESET_INTERFACE_PROTOCOL;
    - it declares no endpoints.

    Returns the interface number of the first qualifying setting, or None if none does.
    """
    for candidate in device.iterSettings():
        if candidate.getClass() != 0xFF:
            continue
        if candidate.getSubClass() != RESET_INTERFACE_SUBCLASS:
            continue
        if candidate.getProtocol() != RESET_INTERFACE_PROTOCOL:
            continue
        if candidate.getNumEndpoints() != 0:
            continue
        return candidate.getNumber()
    return None


def find_picoboot_interface(device: usb1.USBDevice) -> usb1.USBInterfaceSetting | None:
    """Find the PICOBOOT interface setting of an RP bootrom.

    A setting qualifies when:
    - its class is 0xFF;
    - it has exactly two endpoints;
    - endpoint 0 is OUT (address bit 7 clear);
    - endpoint 1 is IN (address bit 7 set).

    This is the endpoint order picoboot_reset() relies on.

    Returns the first qualifying setting, or None if none does.
    """
    for setting in device.iterSettings():
        # Check class and endpoint count BEFORE indexing endpoints: indexing a setting
        # that has fewer than two endpoints raises IndexError.
        if setting.getClass() != 0xFF or setting.getNumEndpoints() != 2:
            continue
        out_address = setting[0].getAddress()
        in_address = setting[1].getAddress()
        if out_address & 0x80 == 0 and in_address & 0x80 == 0x80:
            return setting
    return None


def action_bootsel(args, device: usb1.USBDevice):
    """Handle the "bootsel" sub-command: reboot an application-mode MCU into its bootrom.

    Sends RESET_REQUEST_BOOTSEL through the device's Reset interface.

    Returns:
        0 on success;
        1 if the device has no Reset interface;
        2 if the selected target is already a bootrom.
    """
    target_flags = MCU_TYPES[args.target][2]
    if target_flags & IS_RP_BOOTROM:
        print("E: Device is already in bootrom.")
        return 2

    if (reset_interface := find_reset_interface(device)) is None:
        print("E: Could not find Reset USB Interface.")
        return 1

    # Read the serial before resetting: once reset, the MCU vanishes from USB.
    serial_number = device.getSerialNumber()
    print(f"I: Resetting device with serial {serial_number} into BOOTSEL")
    handle = device.open()
    try:
        device_reset(handle, reset_interface, RESET_REQUEST_BOOTSEL)
    finally:
        # Don't leak the device handle, whether or not the reset request raised.
        handle.close()
    print(f"I: You may now use: $ picotool info --ser {serial_number}")
    return 0


def action_reset(args, device: usb1.USBDevice):
    """Handle the "reset" sub-command: reboot the target MCU into its application.

    - IS_APP targets get RESET_REQUEST_FLASH on their vendor Reset interface.
    - IS_RP_BOOTROM targets get a PICOBOOT REBOOT command.

    Returns 0 on success, 1 if the required USB interface is missing.
    """
    target_flags = MCU_TYPES[args.target][2]
    if target_flags & IS_APP:
        if (reset_interface := find_reset_interface(device)) is None:
            print("E: Could not find Reset USB Interface.")
            return 1

        print("I: Resetting device")
        handle = device.open()
        try:
            device_reset(handle, reset_interface, RESET_REQUEST_FLASH)
        finally:
            handle.close()

    elif target_flags & IS_RP_BOOTROM:
        if (picoboot_interface := find_picoboot_interface(device)) is None:
            print("E: Could not find PICOBOOT USB Interface.")
            return 1

        print("I: Resetting bootrom into application")
        handle = device.open()
        try:
            picoboot_reset(handle, picoboot_interface)
        finally:
            handle.close()
    return 0


def action_list(args, usb_context: usb1.USBContext):
    """Handle the "list" sub-command.

    Prints one line per attached USB device whose (vendor ID, product ID) matches an
    MCU_TYPES entry. ``args`` is unused; it is accepted so every action has the same
    call shape. Always returns 0.
    """
    # (vid, pid) -> target name. setdefault keeps the first MCU_TYPES entry on a
    # collision, the same one a first-match scan over MCU_TYPES would pick.
    names_by_id = {}
    for mcu_name, (mcu_vid, mcu_pid, _flags) in MCU_TYPES.items():
        names_by_id.setdefault((mcu_vid, mcu_pid), mcu_name)

    for device in usb_context.getDeviceIterator(skip_on_error=True):
        vid = device.getVendorID()
        pid = device.getProductID()
        mcu_name = names_by_id.get((vid, pid))
        if mcu_name is None:
            continue
        try:
            serial_number = device.getSerialNumber()
        except usb1.USBError as err:
            # Reading the serial string requires opening the device, which can fail
            # (e.g. missing permissions). Keep listing instead of aborting.
            serial_number = f"(unreadable: {err})"
        print(
            f"Target {mcu_name} ID {vid:04x}:{pid:04x} Serial# {serial_number} "
            f"USB bus {device.getBusNumber()} address {device.getDeviceAddress()}"
        )
    return 0


def parse_args():
    """Build the command-line interface and parse ``sys.argv``.

    Sub-commands "bootsel" and "reset" take one TARGET (a key of MCU_TYPES); "list"
    takes none. Each sub-command stores its handler in ``func``. If no sub-command was
    given, the help text is printed and the process exits.
    """
    parser = argparse.ArgumentParser(
        prog="reform-mcu-tool", description=_DOC, epilog=_EPILOG, formatter_class=argparse.RawTextHelpFormatter
    )
    subparsers = parser.add_subparsers(help="Action to execute")

    target_help = f"Target device to operate on. Choices: {', '.join(MCU_TYPES.keys())}"
    # Sub-commands that act on a single TARGET, in the order they appear in --help.
    targeted_commands = (
        ("bootsel", "Reboot MCU into BOOTSEL mode", action_bootsel),
        ("reset", "Reboot MCU into application mode", action_reset),
    )
    for command, command_help, handler in targeted_commands:
        sub = subparsers.add_parser(command, help=command_help)
        sub.set_defaults(func=handler)
        sub.add_argument(
            "target",
            choices=MCU_TYPES.keys(),
            metavar="TARGET",
            help=target_help,
        )

    list_parser = subparsers.add_parser("list", help="List USB devices matching known VID/PIDs")
    list_parser.set_defaults(func=action_list)

    parsed = parser.parse_args()
    if "func" in parsed:
        return parsed
    # No sub-command given: show usage and exit (parser.exit raises SystemExit).
    parser.print_help()
    parser.exit()


def run(args, usb_context: usb1.USBContext) -> int:
    """Dispatch the parsed command to its action and return the process exit status.

    Sub-commands with a TARGET get the first attached USB device matching that target's
    vendor/product ID. Other sub-commands (``list``) get the whole ``usb_context``.
    Returns 1 if the target device is not attached.
    """
    if "target" in args:
        (vid, pid, _) = MCU_TYPES[args.target]
        device = usb_context.getByVendorIDAndProductID(vid, pid, skip_on_error=True)
        if device is None:
            # Report IDs in the conventional 4-digit hex form (as "list" does), not decimal.
            print(f"E: USB device with Vendor-ID {vid:04x} Product-ID {pid:04x} not found.")
            return 1

        print(
            f"I: Found {device.getManufacturer()} {device.getProduct()} "
            f"on bus {device.getBusNumber()} address {device.getDeviceAddress()}"
        )
        return args.func(args, device)
    else:
        return args.func(args, usb_context)


def main() -> int:
    """Program entry point.

    Parses the command line, then runs the selected action inside a libusb context that
    is closed before returning. Returns the action's exit status.
    """
    parsed = parse_args()
    with usb1.USBContext() as ctx:
        status = run(parsed, ctx)
    return status


if __name__ == "__main__":
    sys.exit(main())
