import struct
import time
from enum import IntEnum
from queue import Queue
from typing import Protocol, Any, cast, Final

import bpy
from bpy.types import PoseBone, Object, ChildOfConstraint, Operator
from mathutils import Matrix, Vector, Euler, Quaternion

from .preset_storage import Preset
from .stream import Stream
from ..properties import Context


class CoordinateSystem(IntEnum):
    """Coordinate-system identifiers exchanged with the client as integers.

    ``MessageHandler.on_connected`` sends ``int(CoordinateSystem.R_XZY)`` during
    setup. The ``L_``/``R_`` prefixes presumably mean left-/right-handed and the
    suffix the axis order -- confirm against the Marionette protocol. The numeric
    values are part of the wire format; do not renumber.
    """

    L_XYZ = 0
    L_XZY = 1
    L_YXZ = 2
    L_YZX = 3
    L_ZXY = 4
    L_ZYX = 5

    R_XYZ = 6
    R_XZY = 7
    R_YXZ = 8
    R_YZX = 9
    R_ZXY = 10
    R_ZYX = 11


class PositionUnit(IntEnum):
    """Position unit identifiers sent to the client as integers.

    ``MessageHandler.on_connected`` sends ``PositionUnit.Meter``. The values are
    part of the wire format; do not renumber.
    """

    Meter = 0
    Centimeter = 1


class RotationUnit(IntEnum):
    """Rotation unit identifiers sent to the client as integers.

    ``MessageHandler.on_connected`` sends ``RotationUnit.Radians``. The values
    are part of the wire format; do not renumber.
    """

    Degrees = 0
    Radians = 1


class ExecuteOnMainThreadFunction(Protocol):
    """Unit of work that MessageHandler puts on its queues.

    It is called with the Blender ``context`` and a ``discard`` flag. As one
    example of how ``discard`` is used, ``MessageHandler._write_pose`` returns
    early when it is set for a discardable world-space write. For requests on
    ``input_queue``, MessageHandler blocks on ``output_queue`` for the result;
    presumably the queue consumer puts this callable's return value there
    (TODO confirm).
    """

    def __call__(self, context: Context, discard: bool, *args: Any) -> Any: ...


class UserFacingError(Exception):
    """Error whose message is intended to be shown to the user.

    NOTE(review): not raised in this part of the module; presumably raised and
    surfaced to the user elsewhere in the add-on -- confirm.
    """


class MsgType(IntEnum):
    """Message type codes, read as the leading int of each client message.

    ``MessageHandler.read_message`` handles HEARTBEAT, both CONVERT_* codes,
    SET_LOCAL_KEYFRAMES, SET_FRAME_RATE, CLOSE, both READ_* codes and both
    WRITE_* codes. Any other value reaching it (including SET_UP_RIG,
    RECEIVE_MESSAGES_FROM_MAYA and SET_WORLD_KEYFRAMES) hits its default case
    and raises ``struct.error``. The values are part of the wire format; do not
    renumber.
    """

    SET_UP_RIG = 0
    CONVERT_WORLD_POSE_TO_LOCAL_POSE = 1
    CONVERT_LOCAL_POSE_TO_WORLD_POSE = 2
    RECEIVE_MESSAGES_FROM_MAYA = 3
    SET_LOCAL_KEYFRAMES = 4
    SET_FRAME_RATE = 5
    CLOSE = 6
    READ_LOCAL_POSE = 7
    WRITE_LOCAL_POSE = 8
    READ_WORLD_POSE = 9
    WRITE_WORLD_POSE = 10
    SET_WORLD_KEYFRAMES = 11
    HEARTBEAT = 12


class TangentMode(IntEnum):
    """Tangent mode codes.

    NOTE(review): presumably the meaning of the ``tangent_mode`` int read for
    SET_LOCAL_KEYFRAMES in ``MessageHandler.read_message``. That value is
    currently read but not used -- confirm.
    """

    AUTO = 0
    STEPPED = 1
    CUSTOM = 2


# The six Euler rotation orders (every ordering of the X, Y and Z axes), spelled
# the way Blender names its Euler rotation modes.
supported_rotation_modes: Final[list[str]] = "XYZ XZY YXZ YZX ZXY ZYX".split()

# Payload built by MessageHandler._create_set_up_message. It holds one tuple per
# mapped bone: (human bone name, rig bone name, T-pose world position, T-pose
# world rotation as an Euler, the pose bone's rotation mode).
SetUpRigMessage = list[tuple[str, str, Vector, Euler, str]]


class Keyframe:
    """A single animation key: a value at a point in time, without tangents."""

    time: float
    value: float

    def __init__(
        self,
        *,
        time: float,
        value: float,
    ):
        self.time = time
        self.value = value

    def __str__(self) -> str:
        # Same "Keyframe(field=value, ...)" layout as KeyframeWithTangents; the
        # previous version left a stray trailing ", " before the closing paren.
        return f"Keyframe(time={self.time}, value={self.value})"


class KeyframeWithTangents:
    """An animation key together with the tangent slopes and weights around it.

    ``in_*`` describe the segment arriving at the key and ``out_*`` the segment
    leaving it. ``MessageHandler._set_local_keyframes`` turns these into Blender
    bezier handles.
    """

    time: float
    value: float
    in_tangent: float
    out_tangent: float
    in_weight: float
    out_weight: float

    def __init__(
        self,
        *,
        time: float,
        value: float,
        in_tangent: float,
        out_tangent: float,
        in_weight: float,
        out_weight: float,
    ):
        # Assigned pairwise, left to right, so attribute order matches the signature.
        self.time, self.value = time, value
        self.in_tangent, self.out_tangent = in_tangent, out_tangent
        self.in_weight, self.out_weight = in_weight, out_weight

    def __str__(self) -> str:
        field_names = (
            "time",
            "value",
            "in_tangent",
            "out_tangent",
            "in_weight",
            "out_weight",
        )
        rendered = ", ".join(f"{field}={getattr(self, field)}" for field in field_names)
        return "Keyframe(" + rendered + ")"


class MessageHandler:
    input_queue: Queue[ExecuteOnMainThreadFunction] = Queue()
    output_queue: Queue[Any] = Queue()
    deferred_queue: Queue[ExecuteOnMainThreadFunction] = Queue()
    _preset: Preset | None = None

    def read_message(self, stream: Stream) -> bool:
        assert self._preset is not None

        if not stream.can_read():
            return True

        msg_type = stream.receive_int()
        match msg_type:
            case MsgType.HEARTBEAT:
                stream.send_bool(True)
            case (
                MsgType.CONVERT_WORLD_POSE_TO_LOCAL_POSE
                | MsgType.CONVERT_LOCAL_POSE_TO_WORLD_POSE
            ):
                pose_count = stream.receive_int()
                restore_original_pose = stream.receive_bool()
                poses = [
                    [
                        (
                            (
                                stream.receive_float(),
                                stream.receive_float(),
                                stream.receive_float(),
                            ),
                            (
                                stream.receive_float(),
                                stream.receive_float(),
                                stream.receive_float(),
                            ),
                        )
                        for _, name, copy_rotation, copy_position in self._preset[
                            "bone_map"
                        ]
                    ]
                    for _ in range(pose_count)
                ]
                assert len(poses) == pose_count
                from_world = msg_type == MsgType.CONVERT_WORLD_POSE_TO_LOCAL_POSE
                stopwatch = time.time()
                self.input_queue.put(
                    cast(
                        ExecuteOnMainThreadFunction,
                        lambda context, discard, a=[
                            poses[:1],
                            from_world,
                            self._preset,
                        ]: self._convert_poses(context, discard, *a),
                    )
                )
                self.output_queue.get()
                elapsed = time.time() - stopwatch
                batch_size = 1
                while batch_size * elapsed < 30 and batch_size < len(poses):
                    batch_size += 1
                for i in range(0, len(poses), batch_size):
                    batch = poses[i : i + batch_size]
                    self.input_queue.put(
                        cast(
                            ExecuteOnMainThreadFunction,
                            lambda context, discard, a=[
                                batch,
                                from_world,
                                self._preset,
                            ]: self._convert_poses(context, discard, *a),
                        )
                    )
                    converted_poses = self.output_queue.get()
                    for pose in converted_poses:
                        for position, rotation in pose:
                            for axis in range(3):
                                stream.send_float(position[axis])
                            for axis in range(3):
                                stream.send_float(rotation[axis])
            case MsgType.SET_LOCAL_KEYFRAMES:
                name = stream.receive_str()
                tangent_mode = stream.receive_int()
                keys = []
                for bone_index in range(len(self._preset["bone_map"])):
                    key_count = stream.receive_int()
                    if key_count > 0:
                        keys.append(
                            [
                                [
                                    [
                                        KeyframeWithTangents(
                                            time=stream.receive_float(),
                                            value=stream.receive_float(),
                                            in_tangent=stream.receive_float(),
                                            out_tangent=stream.receive_float(),
                                            in_weight=stream.receive_float(),
                                            out_weight=stream.receive_float(),
                                        )
                                        for axis in range(3)
                                    ]
                                    for key_type in range(2)
                                ]
                                for key_index in range(key_count)
                            ]
                        )
                    else:
                        keys.append([])
                self.deferred_queue.put(
                    cast(
                        ExecuteOnMainThreadFunction,
                        lambda context, discard, n=name, k=keys, p=self._preset: self._set_local_keyframes(
                            context, discard, n, k, p
                        ),
                    )
                )
            case MsgType.SET_FRAME_RATE:
                frame_rate = stream.receive_int()
                self.deferred_queue.put(
                    cast(
                        ExecuteOnMainThreadFunction,
                        lambda context, discard, f=frame_rate: self._set_frame_rate(
                            context, f
                        ),
                    )
                )
            case MsgType.CLOSE:
                return False
            case MsgType.READ_LOCAL_POSE | MsgType.READ_WORLD_POSE:
                world = msg_type == MsgType.READ_WORLD_POSE
                args = [world, self._preset]
                self.input_queue.put(
                    cast(
                        ExecuteOnMainThreadFunction,
                        lambda context, discard, a=args: self._read_pose(
                            context, discard, *a
                        ),
                    )
                )
                pose = self.output_queue.get()
                for position, rotation in pose:
                    for axis in range(3):
                        stream.send_float(position[axis])
                    for axis in range(3):
                        stream.send_float(rotation[axis])
            case MsgType.WRITE_LOCAL_POSE | MsgType.WRITE_WORLD_POSE:
                can_be_discarded = stream.receive_bool()
                pose = [
                    (
                        (
                            stream.receive_float(),
                            stream.receive_float(),
                            stream.receive_float(),
                        ),
                        (
                            stream.receive_float(),
                            stream.receive_float(),
                            stream.receive_float(),
                        ),
                    )
                    for _, name, copy_rotation, copy_position in self._preset[
                        "bone_map"
                    ]
                ]
                world = msg_type == MsgType.WRITE_WORLD_POSE
                args = [
                    pose,
                    world,
                    self._preset,
                    can_be_discarded,
                ]
                self.deferred_queue.put(
                    cast(
                        ExecuteOnMainThreadFunction,
                        lambda context, discard, a=args: self._write_pose(
                            context, discard, *a
                        ),
                    )
                )
            case _:
                raise struct.error(f"Invalid msg type {msg_type}")

        return True

    def on_connected(
        self, stream: Stream, protocol_version: int, preset: Preset
    ) -> None:
        marionette_protocol_version = stream.receive_int()
        print(f"Client: Received v{marionette_protocol_version}")
        stream.send_int(protocol_version)
        print(f"Client: Sending v{protocol_version}")
        if marionette_protocol_version != protocol_version:
            print(
                "Mismatched DCC protocol version. Make sure the plugin version is compatible with the current "
                "Marionette version."
            )
            return
        stream.send_str("BLENDER")
        stream.send_bool(False)

        paired_bones = [r for r in preset["bone_map"] if r[0] and r[1]]
        self._preset = cast(
            Preset,
            {
                "bone_map": paired_bones,
                "tpose": preset["tpose"],
                "enable_ik": preset["enable_ik"],
            },
        )
        print(f"Client: Creating set up message")
        self.input_queue.put(
            cast(
                ExecuteOnMainThreadFunction,
                lambda context, discard, p=self._preset: self._create_set_up_message(
                    context, discard, p
                ),
            )
        )
        message: SetUpRigMessage = self.output_queue.get()
        print(f"Client: Sending coord system")
        stream.send_int(int(CoordinateSystem.R_XZY))
        stream.send_int(int(PositionUnit.Meter))
        stream.send_int(int(RotationUnit.Radians))
        stream.send_int(len(self._preset["bone_map"]))
        for bone, name, position, rotation, rotate_order in message:
            stream.send_str(bone)
            stream.send_str(name)
            for i in range(3):
                stream.send_float(position[i])
            stream.send_str(rotate_order)
            for i in range(3):
                stream.send_float(rotation[i])
        print(f"Client: Waiting for handshake...")
        assert stream.receive_str() == "HANDSHAKE COMPLETED"
        print(f"Client: Received hand shake")
        stream.send_str("ROGER")
        print("Client: Set up completed")

    @staticmethod
    def _create_set_up_message(
        context: Context, discard: bool, preset: Preset
    ) -> SetUpRigMessage:
        context.view_layer.update()  # get latest scene updates
        rig: bpy.types.Object = context.scene.mb_control_rig
        msg: SetUpRigMessage = []
        for human_name, bone_name, copy_rotation, copy_position in preset["bone_map"]:
            pose_bone = rig.pose.bones[bone_name]
            tpose_mapping = next(
                t for t in preset["tpose"] if t["bone_name"] == bone_name
            )
            pos = Vector(tpose_mapping["world_position"])
            rot = Quaternion(tpose_mapping["world_rotation"]).to_euler(
                pose_bone.rotation_mode
            )  # type: ignore[call-arg]
            msg.append((human_name, bone_name, pos, rot, pose_bone.rotation_mode))
        return msg

    @staticmethod
    def get_world_coords(rig: Object, transform: PoseBone) -> tuple[Vector, Quaternion]:
        mat = rig.matrix_world @ transform.matrix
        l, r, s = mat.decompose()  # current location, rotation, scale
        return l, r

    @staticmethod
    def get_world_coords_as_euler(
        rig: Object, transform: PoseBone, rotation_order: str
    ) -> tuple[Vector, Euler]:
        mat = rig.matrix_world @ transform.matrix
        l, r, s = mat.decompose()  # current location, rotation, scale
        return l, r.to_euler(rotation_order)  # type: ignore[call-arg]

    @staticmethod
    def _read_pose(
        context: Context,
        discard: bool,
        world: bool,
        preset: Preset,
    ) -> list[tuple[Vector, Euler]]:
        pose = []
        for (_, name, copy_rotation, copy_position), rotation_order in zip(
            preset["bone_map"], MessageHandler.rotation_orders(preset)
        ):
            transform = context.scene.mb_control_rig.pose.bones[name]
            if world:
                translation, rotation = MessageHandler.get_world_coords_as_euler(
                    context.scene.mb_control_rig,
                    transform,
                    rotation_order,
                )
            else:
                translation = transform.location.copy()
                rotation = transform.rotation_euler.copy()
            pose.append((translation, rotation))
        return pose

    @staticmethod
    def _matrix_from_world_coords(
        rig: Object,
        transform: PoseBone,
        copy_rotation: bool,
        copy_position: bool,
        rotation: tuple[float, float, float],
        rotation_order: str,
        position: tuple[float, float, float],
    ) -> Matrix:
        l, r, s = (
            rig.matrix_world @ transform.matrix
        ).decompose()  # current location, rotation, scale
        world_mat = Matrix.LocRotScale(
            Vector(position) if copy_position else l,
            (Euler(rotation, rotation_order).to_quaternion() if copy_rotation else r),
            s,
        )
        arm_mat = rig.matrix_world.inverted() @ world_mat
        return arm_mat

    @staticmethod
    def _matrix_from_local_coords(
        transform: PoseBone,
        copy_rotation: bool,
        copy_position: bool,
        rotation: tuple[float, float, float],
        rotation_order: str,
        position: tuple[float, float, float],
    ) -> Matrix:
        orig_mat = transform.matrix.copy()
        transform.location = position if copy_position else transform.location
        transform.rotation_euler = (
            rotation if copy_rotation else transform.rotation_euler
        )
        arm_mat = transform.matrix.copy()
        transform.matrix = orig_mat
        return arm_mat

    @staticmethod
    def _write_pose(
        context: Context,
        discard: bool,
        pose: list[tuple[tuple[float, float, float], tuple[float, float, float]]],
        world_space: bool,
        preset: Preset,
        can_be_discarded: bool,
    ) -> None:
        if discard and world_space and can_be_discarded:
            return
        for (
            (_, name, copy_rotation, copy_position),
            (position, rotation),
            rotation_order,
        ) in zip(preset["bone_map"], pose, MessageHandler.rotation_orders(preset)):
            if not copy_position and not copy_rotation:
                continue
            transform = context.scene.mb_control_rig.pose.bones[name]
            rotation = (rotation[0], rotation[1], rotation[2])
            if world_space:
                target_matrix = MessageHandler._matrix_from_world_coords(
                    context.scene.mb_control_rig,
                    transform,
                    copy_rotation,
                    copy_position,
                    rotation,
                    rotation_order,
                    position,
                )
                MessageHandler.set_with_constraints(context, transform, target_matrix)
            else:
                if copy_position:
                    transform.location = position
                if copy_rotation:
                    transform.rotation_mode = rotation_order
                    transform.rotation_euler = Euler(rotation, rotation_order)
                context.evaluated_depsgraph_get().update()

    @staticmethod
    def _convert_poses(
        context: Context,
        discard: bool,
        poses: list[
            list[tuple[tuple[float, float, float], tuple[float, float, float]]]
        ],
        world_space: bool,
        preset: Preset,
    ) -> list[list[tuple[Vector, Euler]]]:
        rig: bpy.types.Object = context.scene.mb_control_rig
        transforms: list[bpy.types.PoseBone] = [
            rig.pose.bones[name] for _, name, _, _ in preset["bone_map"]
        ]
        converted_poses = []
        for pose in poses:
            MessageHandler._write_pose(
                context, discard, pose, world_space, preset, False
            )
            if world_space:
                converted_pose = [
                    (transform.location.copy(), transform.rotation_euler.copy())
                    for transform, rotation_order in zip(
                        transforms, MessageHandler.rotation_orders(preset)
                    )
                ]
            else:
                converted_pose = [
                    MessageHandler.get_world_coords_as_euler(
                        rig, transform, rotation_order
                    )
                    for transform, rotation_order in zip(
                        transforms, MessageHandler.rotation_orders(preset)
                    )
                ]
            converted_poses.append(converted_pose)
        return converted_poses

    @staticmethod
    def rotation_orders(preset: Preset) -> list[str]:
        return [
            next(t["rotation_mode"] for t in preset["tpose"] if t["bone_name"] == name)
            for _, name, _, _ in preset["bone_map"]
        ]

    @staticmethod
    def _calc_tan_coords(
        x: float, x_p: float, y: float, slope: float, weight: float, c: float
    ) -> tuple[float, float]:
        x_delta = (x_p - x) * weight
        tan_x = (x + x_delta) * c
        tan_y = y + slope * x_delta
        return tan_x, tan_y

    @staticmethod
    def _set_local_keyframes(
        context: Context,
        discard: bool,
        name: str,
        keys: list[list[list[list[KeyframeWithTangents]]]],
        preset: Preset,
    ) -> None:
        rig: Object = context.scene.mb_control_rig
        if rig.animation_data is None:
            rig.animation_data_create()
        action_name = rig.name + "_" + name
        action = bpy.data.actions.get(name)
        if action is None:
            action = bpy.data.actions.new(action_name)
        rig.animation_data.action = action

        fps = bpy.context.scene.render.fps / bpy.context.scene.render.fps_base
        data_paths = ["location", "rotation_euler"]
        for bone_index in range(len(preset["bone_map"])):
            # set up bone related stuff
            _, name, copy_rotation, copy_position = preset["bone_map"][bone_index]
            should_copy = (copy_position, copy_rotation)
            pose_bone = rig.pose.bones[name]
            for key_index in range(len(keys[bone_index])):
                for key_type in range(2):
                    # set up keyframe related stuff
                    data_path = data_paths[key_type]
                    if should_copy[key_type]:
                        for axis in range(3):
                            key = keys[bone_index][key_index][key_type][axis]
                            frame = fps * key.time
                            pose_bone.keyframe_insert(
                                data_path=data_path, index=axis, frame=frame
                            )
                            fcurve = rig.animation_data.action.fcurves.find(
                                pose_bone.path_from_id(data_path), index=axis
                            )
                            kf_point = fcurve.keyframe_points[-1]
                            # set Y value of keyframe
                            kf_point.co[1] = key.value
                            # set interpolation to bezier
                            kf_point.interpolation = "BEZIER"

                            # Set tangents
                            # IN
                            if (
                                key_index > 0
                            ):  # do not care about in tangent for first kf
                                prev_key = keys[bone_index][key_index - 1][key_type][
                                    axis
                                ]
                                t_x, t_y = MessageHandler._calc_tan_coords(
                                    key.time,
                                    prev_key.time,
                                    key.value,
                                    key.in_tangent,
                                    key.in_weight,
                                    fps,
                                )  # account for transformation time-fps
                                kf_point.handle_left_type = "FREE"
                                kf_point.handle_left = (t_x, t_y)

                            # OUT
                            if key_index < (
                                len(keys[bone_index]) - 1
                            ):  # do not care about out tangent for last kf
                                next_key = keys[bone_index][key_index + 1][key_type][
                                    axis
                                ]
                                t_x, t_y = MessageHandler._calc_tan_coords(
                                    key.time,
                                    next_key.time,
                                    key.value,
                                    key.out_tangent,
                                    key.out_weight,
                                    fps,
                                )  # account for transformation time-fps
                                kf_point.handle_right_type = "FREE"
                                kf_point.handle_right = (t_x, t_y)

    @staticmethod
    def _set_frame_rate(context: Context, fps: int) -> None:
        context.scene.render.fps = fps
        context.scene.render.fps_base = 1

    @staticmethod
    def set_with_constraints(
        context: Context, transform: PoseBone, target_matrix: Matrix
    ) -> None:
        child_of_constraint: ChildOfConstraint | None = None
        for constraint in transform.constraints:
            if constraint.type == "CHILD_OF" and constraint.influence >= 1:
                if child_of_constraint is not None:
                    child_of_constraint = None
                    break
                else:
                    child_of_constraint = cast(ChildOfConstraint, constraint)
            elif constraint.influence > 0:
                child_of_constraint = None
                break

        if child_of_constraint is not None:
            parent_m = child_of_constraint.target.pose.bones[
                child_of_constraint.subtarget
            ].matrix
            transform.matrix = (
                child_of_constraint.inverse_matrix.inverted()
                @ parent_m.inverted()
                @ target_matrix
            )
        else:
            transform.matrix = target_matrix
        context.evaluated_depsgraph_get().update()
        if not MessageHandler.almost_equal(transform.matrix, target_matrix):
            print(
                f'ERROR: Failed to set matrix due to unsupported constraints on {transform}. Please contact us at "https://marionettexr.com".'
            )
            print(f"{transform.matrix=}")
            print(f"{target_matrix=}")

    @staticmethod
    def almost_equal(a: Matrix, b: Matrix, tolerance: float = 0.001) -> bool:
        for i in range(4):
            for j in range(4):
                if abs(a[i][j] - b[i][j]) > tolerance:
                    return False
        return True
