#!/usr/bin/env python3
"""
Automates bootloader and firmware flashing over serial for Renesas platforms.

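Features:
    - Uploads the Flash Writer and bootloader firmware images via serial
    - Supports multiple platforms via DEFAULT_BOOTASSETS_MAPPING or an external JSON config
    - Writes MAC addresses to the U-Boot environment (--mac-addresses)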

Usage:
    sudo python3 flash_bootloader.py g2l --serial-port /dev/ttyUSB0 --baud-rate 115200
    sudo python3 flash_bootloader.py g2l --bootassets-dir /path/to/files --bootassets-config bootassets.json
    sudo python3 flash_bootloader.py g2l --mac-addresses 0e2503363770,0e2503363771

    If the system has already been flashed and can boot into the OS, pass
    --user and --pwd when the login credentials differ from the default
    ubuntu/ubuntu; they are used to log in and reboot into U-Boot.
    sudo python3 flash_bootloader.py g2l --mac-addresses 0e2503363770,0e2503363771 --user test --pwd test
"""

import os
import time
import serial
import argparse
import logging
import json
import re
import glob

# Default Platform to boot asset address mapping
# Default platform -> boot asset mapping.
#
# Each key is a value accepted by the positional "platform" CLI argument and
# maps to:
#   "flash_writer": file name of the Flash Writer image, sent raw over serial
#                   before anything else (see deploy_firmware_via_serial).
#   "bootloader":   {file name: [program_top_address, qspi_save_address]},
#                   flashed in insertion order. The two strings are typed
#                   verbatim at the Flash Writer's "Please Input Program Top
#                   Address" and "Please Input Qspi Save Address" prompts
#                   (see run_xls2_stage).
# All file names are resolved relative to --bootassets-dir. A JSON file given
# via --bootassets-config replaces this mapping entirely and must use the same
# shape.
# NOTE(review): the addresses look like hex without a 0x prefix and the *_sec
# entries use the *_TBB images (presumably trusted/secure boot) — confirm
# against the Renesas Flash Writer documentation before editing.
DEFAULT_BOOTASSETS_MAPPING = {
    "g2l": {
        "flash_writer": "Flash_Writer_SCIF_RZG2L_SMARC_DDR4_2GB.mot",
        "bootloader": {
            "bl2_bp-smarc-rzg2l_pmic.srec": ["11E00", "00000"],
            "fip-smarc-rzg2l_pmic.srec": ["00000", "1D200"],
        },
    },
    "g2l_sec": {
        "flash_writer": "Flash_Writer_SCIF_RZG2L_SMARC_PMIC_DDR4_2GB_1PCS_TBB.mot",
        "bootloader": {
            "bl2_bp-smarc-rzg2l_pmic_tbb.srec": ["11E00", "00000"],
            "fip-smarc-rzg2l_pmic_tbb.srec": ["00000", "1D200"],
        },
    },
    "g2lc": {
        "flash_writer": "Flash_Writer_SCIF_RZG2LC_SMARC_DDR4_2GB.mot",
        "bootloader": {
            "bl2_bp-smarc-rzg2lc.srec": ["11E00", "00000"],
            "fip-smarc-rzg2lc.srec": ["00000", "1D200"],
        },
    },
    "g2lc_sec": {
        "flash_writer": "Flash_Writer_SCIF_RZG2LC_SMARC_DDR4_1GB_1PCS_TBB.mot",
        "bootloader": {
            "bl2_bp-smarc-rzg2lc_tbb.srec": ["11E00", "00000"],
            "fip-smarc-rzg2lc_tbb.srec": ["00000", "1D200"],
        },
    },
    "g2ul": {
        "flash_writer": "Flash_Writer_SCIF_RZG2UL_SMARC_DDR4_1GB_1PCS.mot",
        "bootloader": {
            "bl2_bp-smarc-rzg2ul.srec": ["11E00", "00000"],
            "fip-smarc-rzg2ul.srec": ["00000", "1D200"],
        },
    },
    "g2ul_sec": {
        "flash_writer": "Flash_Writer_SCIF_RZG2UL_SMARC_DDR4_1GB_1PCS_TBB.mot",
        "bootloader": {
            "bl2_bp-smarc-rzg2ul_tbb.srec": ["11E00", "00000"],
            "fip-smarc-rzg2ul_tbb.srec": ["00000", "1D200"],
        },
    },
    "g3s": {
        "flash_writer": "FlashWriter-smarc-rzg3s.mot",
        "bootloader": {
            "bl2_bp_spi-smarc-rzg3s.srec": ["A1E00", "00000"],
            "fip-smarc-rzg3s.srec": ["00000", "64000"],
        },
    },
}

# Root logger configured at import time; the __main__ block raises it to
# DEBUG when --debug is passed.
logging.basicConfig(
    level=logging.INFO,  # Default to INFO
    format="%(asctime)s [%(levelname)s] %(message)s",
)


class FirmwareDeploymentError(Exception):
    """Raised when boot assets are missing/unloadable or a flashing step fails."""


class MacDeploymentError(Exception):
    """Raised when MAC address validation or the U-Boot MAC write fails."""


class SerialConsole:
    """Line-oriented wrapper around a pyserial port used to drive the target.

    Commands written with :meth:`write_con_no_wait` are terminated with a
    carriage return; :meth:`read_con` blocks until a non-empty line arrives.
    """

    # Consecutive empty reads (each bounded by ``timeout``) after which
    # read_con nudges the target with a bare carriage return.
    EMPTY_LINES_BEFORE_ENTER = 10

    def __init__(self, port, baudrate=115200, timeout=3):
        self.port = port
        self.baudrate = baudrate
        # Per-read timeout in seconds, passed straight to serial.Serial.
        self.timeout = timeout
        self.con = None  # serial.Serial instance once connect() succeeds
        self._escape = "\r"  # terminator appended to every command

    def connect(self):
        """Open the port.

        Returns:
            True on success; False (after logging the error) on failure.
        """
        try:
            self.con = serial.Serial(
                self.port, self.baudrate, timeout=self.timeout
            )
        except Exception as e:
            # Best-effort by design: callers check the boolean result.
            logging.error("Serial connection error: %s", e)
            return False
        logging.info(
            "Connected to serial port %s at baudrate %s",
            self.port,
            self.baudrate,
        )
        return True

    def close(self):
        """Close the port if it is open; safe to call repeatedly."""
        if self.con and self.con.is_open:
            self.con.close()
            logging.info("Serial connection closed")

    def write_con_no_wait(self, data_str):
        """Send *data_str* followed by CR without waiting for a reply.

        Does nothing when the port is not open.
        """
        if self.con and self.con.is_open:
            self.con.write((data_str + self._escape).encode("utf-8"))

    def read_con(self):
        """Block until a non-empty line is received and return it stripped.

        Each ``readline`` returns after at most ``timeout`` seconds, possibly
        with a partial line (e.g. a prompt that has no trailing newline).
        After every EMPTY_LINES_BEFORE_ENTER consecutive empty reads a bare
        carriage return is sent to wake up a console idling at a prompt.

        Raises:
            OSError: if the port is not connected or has been closed (the
                loop would otherwise spin forever without reading anything).
        """
        empty_count = 0
        while True:
            if not (self.con and self.con.is_open):
                raise OSError(f"Serial port {self.port} is not open")
            line = (
                self.con.readline()
                .decode("utf-8", errors="ignore")
                .strip()
            )
            logging.debug("Received from serial: %s", line)

            if line:
                return line

            empty_count += 1
            if empty_count >= self.EMPTY_LINES_BEFORE_ENTER:
                logging.debug(
                    "Received %d empty lines from serial. Sending enter...",
                    empty_count,
                )
                self.write_con_no_wait("")
                # Start counting again so we nudge once per batch of empty
                # reads instead of on every read after the first batch.
                empty_count = 0

    def send_file_con(self, filepath):
        """Send the raw bytes of *filepath* over the port.

        Returns:
            True on success; False (after logging) if the file cannot be read
            or the write fails (including when the port was never connected).
        """
        try:
            logging.info("Sending file: %s", filepath)
            with open(filepath, "rb") as file:
                data = file.read()
            self.con.write(data)
            self.con.flush()
            logging.info("File sent successfully")
            return True
        except Exception as e:
            logging.error("Failed to send file %s: %s", filepath, e)
            return False


def run_xls2(con):
    """Wait for the Flash Writer '>' prompt, then issue the xls2 command.

    xls2 is deliberately sent twice, 0.5 s apart, to work around the first
    xls2 always failing, e.g.:

        [DEBUG] Received from serial: xls2
        [DEBUG] Received from serial: command not found
        [DEBUG] Received from serial: >
    """
    logging.info("Waiting for prompt to run xls2")
    while ">" not in con.read_con():
        pass  # keep draining lines until the prompt shows up

    first_try, second_try = "xls2", "xls2"
    con.write_con_no_wait(first_try)
    time.sleep(0.5)
    con.write_con_no_wait(second_try)
    logging.info("'xls2' command sent")


def run_sup(con):
    """Wait for the Flash Writer '>' prompt, then send the SUP command.

    deploy_firmware_via_serial calls this right before reopening the port at
    a higher baud rate.
    """
    logging.info("Waiting for prompt to run SUP")
    while ">" not in con.read_con():
        pass  # ignore everything printed before the prompt
    con.write_con_no_wait("SUP")
    logging.info("'SUP' command sent")


def run_xls2_stage(con, data1, data2):
    """Answer the xls2 address prompts and return once data upload is requested.

    Args:
        con: connected console.
        data1: value typed at the "Please Input Program Top Address" prompt.
        data2: value typed at the "Please Input Qspi Save Address" prompt.
    """
    logging.info("Running xls2 stage with addresses: %s, %s", data1, data2)
    # (prompt substring, reply, log format) — checked in this order, first hit wins.
    answers = (
        ("Please Input Program Top Address", data1, "Sent Program Top Address: %s"),
        ("Please Input Qspi Save Address", data2, "Sent Qspi Save Address: %s"),
    )
    upload_prompt = "please send ! ('.' & CR stop load)"
    while True:
        line = con.read_con()
        for prompt, reply, log_fmt in answers:
            if prompt in line:
                con.write_con_no_wait(reply)
                logging.info(log_fmt, reply)
                break
        else:
            if upload_prompt in line:
                logging.info("Received prompt to start sending data")
                return


def check_xls2_upload(con, settle_delay=5):
    """Wait for the "Clear OK?(y/n)" prompt after an xls2 upload and answer "y".

    Args:
        con: connected console.
        settle_delay: seconds to sleep after answering (default 5, the
            previously hard-coded value). Presumably this gives the Flash
            Writer time to finish before the next command is sent; pass a
            smaller value only if the target is known to be ready sooner.
    """
    logging.info("Checking xls2 upload confirmation")
    while "Clear OK?(y/n)" not in con.read_con():
        pass
    con.write_con_no_wait("y")
    logging.info("Confirmed uploaded!")
    time.sleep(settle_delay)


def write_mac(con, mac: list, user, pwd):
    """Drive the target into the U-Boot shell and store MAC addresses there.

    Reacts to whatever the console prints:
      * "ubuntu login"                  -> log in with *user* / *pwd*
      * "[sudo] password for"           -> answer sudo's prompt with *pwd*
      * "@ubuntu:" (shell prompt)       -> ``sudo reboot`` to get to U-Boot
      * "Hit any key to stop autoboot"  -> interrupt autoboot
      * "=>" (U-Boot prompt)            -> setenv the MACs, saveenv, return
      * anything else                   -> send a bare carriage return
    NOTE(review): prompts are matched on the hostname "ubuntu" even when a
    different *user* is given.

    Args:
        con: connected console.
        mac: normalized MAC strings; mac[0] is stored as ``ethaddr`` and
            mac[i] (i >= 1) as ``eth{i}addr``.
        user: Linux login user (CLI default "ubuntu").
        pwd: Linux login and sudo password (CLI default "ubuntu").

    Raises:
        ValueError: if *mac* is empty; checked before the console is touched.
    """
    if not mac:
        raise ValueError("At least one MAC address is required")

    logging.info("Writing MAC address to U-Boot")
    while True:
        mesg = con.read_con()
        if "ubuntu login" in mesg:
            logging.info("In Ubuntu system. Trying to login into system.")
            con.write_con_no_wait(user)
            time.sleep(3)
            con.write_con_no_wait(pwd)
        elif "[sudo] password for" in mesg:
            # Without this, the catch-all branch below answers sudo with an
            # empty password and the reboot never happens.
            logging.info("sudo asked for a password. Sending it.")
            con.write_con_no_wait(pwd)
        elif "@ubuntu:" in mesg:
            # Already logged in: reboot so the autoboot prompt shows up.
            logging.info("In Ubuntu system. Trying to reboot into U-Boot.")
            con.write_con_no_wait("sudo reboot")
        elif "Hit any key to stop autoboot" in mesg:
            con.write_con_no_wait("")
            logging.info("Boot into U-Boot shell")
        elif "=>" in mesg:
            logging.info("In U-Boot shell. Starting to write MAC.")
            # NOTE(review): U-Boot may refuse to overwrite an already-set
            # ethaddr unless built with CONFIG_ENV_OVERWRITE — confirm.
            for index, address in enumerate(mac):
                var_name = "ethaddr" if index == 0 else "eth%daddr" % index
                con.write_con_no_wait("setenv %s %s" % (var_name, address))
                time.sleep(1)
            con.write_con_no_wait("saveenv")
            time.sleep(3)
            logging.info("Write MAC addresses finished.")
            return
        else:
            con.write_con_no_wait("")


def normalize_mac(mac):
    """Return *mac* as upper-case, colon-separated ``AA:BB:CC:DD:EE:FF``.

    Every non-hex character (colons, dashes, whitespace, ...) is dropped
    first; ``None`` is returned unless exactly 12 hex digits remain.
    """
    hex_digits = "".join(re.findall(r"[0-9A-Fa-f]", mac)).upper()
    if len(hex_digits) == 12:
        octets = [hex_digits[pos:pos + 2] for pos in range(0, 12, 2)]
        return ":".join(octets)
    return None


# Accepted MAC forms: six hex octets separated by ':' or '-' (separators may be
# mixed), or twelve bare hex digits. Compiled once instead of on every call.
_MAC_PATTERN = re.compile(
    r"(?:[0-9A-Fa-f]{2}[:-]){5}[0-9A-Fa-f]{2}|[0-9A-Fa-f]{12}"
)


def validate_mac_list(mac_list):
    """Validate and normalize a list of MAC addresses.

    Each entry may be ``aa:bb:cc:dd:ee:ff``, ``aa-bb-cc-dd-ee-ff`` or
    ``aabbccddeeff`` (any case). Surrounding whitespace is ignored, so CLI
    input such as ``"aabbccddeeff, aabbccddeef0"`` is accepted.

    Returns:
        The addresses in input order as ``AA:BB:CC:DD:EE:FF`` strings.

    Raises:
        MacDeploymentError: on the first entry that is not a valid MAC; the
            message names the offending value.
    """
    valid_mac = []
    for mac in mac_list:
        candidate = mac.strip()
        normalized = (
            normalize_mac(candidate)
            if _MAC_PATTERN.fullmatch(candidate)
            else None
        )
        if normalized is None:
            raise MacDeploymentError(
                f"MAC address validation failed: {mac!r}"
            )
        valid_mac.append(normalized)
    return valid_mac


def verify_bootassets_exist(platform_config, bootassets_dir):
    """Check that every file named in *platform_config* exists in *bootassets_dir*.

    The Flash Writer image and every key of the ``"bootloader"`` mapping are
    checked, and all missing names are reported together.

    Returns:
        True when every file is present.

    Raises:
        FirmwareDeploymentError: listing each missing file name.
    """
    wanted = [
        platform_config["flash_writer"],
        *platform_config["bootloader"].keys(),
    ]
    missing_files = [
        name
        for name in wanted
        if not os.path.isfile(os.path.join(bootassets_dir, name))
    ]

    if not missing_files:
        logging.info("All required boot asset files found.")
        return True

    raise FirmwareDeploymentError(
        f"Missing required boot asset file(s): {', '.join(missing_files)}"
    )


def load_bootassets_mapping(config_path):
    """Load a platform -> boot asset mapping from a JSON file.

    The file must hold a JSON object with the same shape as
    DEFAULT_BOOTASSETS_MAPPING; it replaces that mapping entirely.

    Returns:
        The decoded mapping (a dict).

    Raises:
        FirmwareDeploymentError: if the file cannot be read or decoded (the
            original error is chained as ``__cause__``), or if its top-level
            value is not a JSON object.
    """
    try:
        with open(config_path, "r", encoding="utf-8") as f:
            mapping = json.load(f)
    except (OSError, ValueError) as e:
        # json.JSONDecodeError and UnicodeDecodeError both subclass ValueError.
        raise FirmwareDeploymentError(
            f"Failed to load boot asset mapping JSON: {e}"
        ) from e

    if not isinstance(mapping, dict):
        # A list/str would otherwise fail later with an uncaught TypeError.
        raise FirmwareDeploymentError(
            "Failed to load boot asset mapping JSON: expected an object at "
            f"the top level, got {type(mapping).__name__}"
        )

    logging.info("Loaded boot asset mapping from %s", config_path)
    return mapping


def deploy_firmware_via_serial(
    platform, bootassets_dir, bootassets_mapping, con, flash_baudrate=921600
):
    """Flash the Flash Writer and bootloader images for *platform* over serial.

    Sequence:
      1. send the Flash Writer image at the console's current baud rate;
      2. issue SUP, then reopen the same port at *flash_baudrate*;
      3. for every bootloader image, in mapping order: run xls2, answer the
         address prompts, send the file and confirm the clear prompt.

    Args:
        platform: key into *bootassets_mapping*.
        bootassets_dir: directory holding the files named in the mapping.
        bootassets_mapping: platform -> {"flash_writer", "bootloader"} mapping.
        con: connected SerialConsole. Once the first transfer starts it is
            closed by this function (and replaced internally by a new console
            at *flash_baudrate*, which is closed as well).
        flash_baudrate: baud rate used after SUP (default 921600, the
            previously hard-coded value).

    Raises:
        FirmwareDeploymentError: unsupported platform, missing files, a failed
            file transfer, or a failed reconnect.
    """
    logging.info("Starting firmware deployment for platform: %s", platform)

    if platform not in bootassets_mapping:
        raise FirmwareDeploymentError(f"Unsupported platform: {platform}")

    platform_config = bootassets_mapping[platform]

    verify_bootassets_exist(platform_config, bootassets_dir)

    flash_writer_path = os.path.join(
        bootassets_dir, platform_config["flash_writer"]
    )
    try:
        if not con.send_file_con(flash_writer_path):
            raise FirmwareDeploymentError("Failed to send Flash Writer")

        run_sup(con)

        # The host side reconnects at flash_baudrate after SUP (presumably
        # SUP switches the Flash Writer's SCIF to that rate). Reuse the port
        # the caller opened rather than a module-level ``args`` global, which
        # only exists when this file runs as a script.
        port = con.port
        timeout = con.timeout
        con.close()
        logging.info(
            "Reconnecting console with baudrate %s ...", flash_baudrate
        )
        con = SerialConsole(port=port, baudrate=flash_baudrate, timeout=timeout)
        if not con.connect():
            raise FirmwareDeploymentError(f"Failed to connect to {port}")

        for asset_filename, addresses in platform_config["bootloader"].items():
            asset_path = os.path.join(bootassets_dir, asset_filename)
            logging.info("Starting deployment stage for %s", asset_filename)

            run_xls2(con)
            # addresses = [program_top_address, qspi_save_address]
            run_xls2_stage(con, addresses[0], addresses[1])

            if not con.send_file_con(asset_path):
                raise FirmwareDeploymentError(
                    f"Failed to send {asset_filename}"
                )

            check_xls2_upload(con)
            logging.info("Completed deployment stage for %s", asset_filename)

        logging.info("Firmware deployment complete")

    finally:
        con.close()


def find_ftdi_uart_port(
    pattern="/dev/serial/by-id/usb-FTDI_FT230X_Basic_UART_*-if00-port0",
    default="/dev/ttyUSB0",
):
    """Return the serial device to use when --serial-port is not given.

    Prefers a stable ``/dev/serial/by-id`` link matching *pattern* (by default
    an FTDI FT230X adapter). glob() returns matches in arbitrary order, so
    they are sorted to make the choice deterministic when several adapters
    are attached. Falls back to *default* when nothing matches.
    """
    matches = sorted(glob.glob(pattern))
    return matches[0] if matches else default


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description=(
            "Flash Renesas Flash Writer/bootloader images, or write MAC "
            "addresses to U-Boot, over a serial console."
        )
    )
    parser.add_argument(
        "platform", help="Platform name (must match key in boot asset mapping)"
    )
    # Resolved when the parser is built: an FTDI by-id link if one exists,
    # otherwise /dev/ttyUSB0.
    parser.add_argument("--serial-port", default=find_ftdi_uart_port())
    parser.add_argument("--baud-rate", type=int, default=115200)
    parser.add_argument(
        "--bootassets-dir",
        default=os.getcwd(),
        help=(
            "Directory where boot assets are stored "
            "(default: current directory)"
        ),
    )
    parser.add_argument(
        "--bootassets-config",
        type=str,
        help="Path to JSON file defining boot asset mapping",
    )
    parser.add_argument(
        "--debug", action="store_true", help="Enable debug logging"
    )
    parser.add_argument(
        "--mac-addresses",
        type=lambda s: s.split(","),
        help="Comma-separated list of MAC addresses",
    )
    parser.add_argument(
        "--user", type=str, default="ubuntu",
        help="Login user of the booted OS (used with --mac-addresses)",
    )
    parser.add_argument(
        "--pwd", type=str, default="ubuntu",
        help="Login/sudo password of the booted OS (used with --mac-addresses)",
    )

    args = parser.parse_args()

    if args.debug:
        logging.getLogger().setLevel(logging.DEBUG)

    # Create and connect serial console once
    con = SerialConsole(port=args.serial_port, baudrate=args.baud_rate)
    if not con.connect():
        logging.error("Failed to connect to %s", args.serial_port)
        raise SystemExit(1)

    try:
        if args.mac_addresses:
            # MAC mode: only the U-Boot environment is written; no flashing.
            try:
                mac = validate_mac_list(args.mac_addresses)
                write_mac(con, mac, user=args.user, pwd=args.pwd)
            except Exception as e:
                # Keep the root cause in both the message and the traceback.
                raise MacDeploymentError(f"MAC deployment failed: {e}") from e
        else:
            try:
                # Load bootassets mapping and run firmware deployment
                bootassets_mapping = (
                    load_bootassets_mapping(args.bootassets_config)
                    if args.bootassets_config
                    else DEFAULT_BOOTASSETS_MAPPING
                )
                deploy_firmware_via_serial(
                    platform=args.platform,
                    bootassets_dir=args.bootassets_dir,
                    bootassets_mapping=bootassets_mapping,
                    con=con,
                )
            except FirmwareDeploymentError as e:
                logging.error(str(e))
                # Exit non-zero so automation notices the failed flash.
                raise SystemExit(1)
    finally:
        con.close()
