import math
import struct
import time
from enum import IntEnum
from typing import Callable, Optional

import maya
import maya.api.OpenMaya as om
import maya.utils
from maya import cmds

from marionette.preset_storage import Preset
from marionette.progress_window import ProgressWindow
from marionette.stream import Stream
from marionette.undo_chunk import UndoChunk


class CoordinateSystem(IntEnum):
    """Axis convention advertised to Marionette during the handshake.

    The ``L_``/``R_`` prefixes presumably mean left-/right-handed (TODO confirm
    against the Marionette protocol); the suffix gives the axis ordering.
    ``MessageHandler.on_connected`` sends ``R_XZY`` for Z-up scenes and
    ``R_XYZ`` otherwise.
    """

    # -- "L" family: values 0..5 --
    L_XYZ = 0
    L_XZY = 1
    L_YXZ = 2
    L_YZX = 3
    L_ZXY = 4
    L_ZYX = 5

    # -- "R" family: values 6..11 --
    R_XYZ = 6
    R_XZY = 7
    R_YXZ = 8
    R_YZX = 9
    R_ZXY = 10
    R_ZYX = 11


class PositionUnit(IntEnum):
    """Length unit announced to Marionette; this plugin sends ``Centimeter``."""

    Meter = 0
    Centimeter = 1


class RotationUnit(IntEnum):
    """Angle unit announced to Marionette; this plugin sends ``Degrees``."""

    Degrees = 0
    Radians = 1


class MsgType(IntEnum):
    """Message identifiers read as the leading int of each Marionette message.

    Value 3 has no member here; ``MessageHandler.read_message`` rejects any
    identifier not listed below.
    """

    # Pose conversion (batched, answered with the converted poses).
    CONVERT_WORLD_POSE_TO_LOCAL_POSE = 1
    CONVERT_LOCAL_POSE_TO_WORLD_POSE = 2
    # Animation / scene settings.
    SET_LOCAL_KEYFRAMES = 4
    SET_FRAME_RATE = 5
    # Connection control.
    CLOSE = 6
    # Single-pose reads and writes.
    READ_LOCAL_POSE = 7
    WRITE_LOCAL_POSE = 8
    READ_WORLD_POSE = 9
    WRITE_WORLD_POSE = 10
    # World-space keyframes.
    SET_WORLD_KEYFRAMES = 11
    # Liveness probe, answered with a single bool.
    HEARTBEAT = 12


class TangentMode(IntEnum):
    """Tangent style requested for world-space keyframes.

    ``STEPPED`` maps to Maya ``stepnext``/``step`` tangents and every other
    value to ``auto`` (see ``MessageHandler._set_world_keyframes``).
    """

    AUTO = 0
    STEPPED = 1
    CUSTOM = 2


# One entry per preset bone, in the order MessageHandler._create_set_up_message builds it.
SetUpRigMessage = list[
    tuple[
        str,                         # preset human name
        str,                         # Maya node name of the bone
        tuple[float, float, float],  # world-space rotate pivot
        tuple[float, float, float],  # world-space rotation
        str,                         # rotate order string reported by xform
    ]
]


class Keyframe:
    """A single value sampled at a point in time.

    ``MessageHandler`` uses these for world-space keys and formats ``time`` as
    ``'<time>sec'`` when keying, so ``time`` is in seconds.
    """

    time: float  # seconds
    value: float

    def __init__(
            self, *,
            time: float,
            value: float,
    ):
        self.time = time
        self.value = value

    def __str__(self) -> str:
        # Previously emitted a stray trailing ", " before the closing parenthesis.
        return f"{type(self).__name__}(time={self.time}, value={self.value})"

    def __repr__(self) -> str:
        # Same text as __str__ so containers of keyframes print readably.
        return self.__str__()


class KeyframeWithTangents:
    """A keyframe value plus its incoming/outgoing tangent slopes and weights.

    ``MessageHandler._set_local_keyframes`` keys ``value`` at ``'<time>sec'``.
    For the segment between two consecutive keys, it derives the later key's
    in tangent from ``in_tangent``/``in_weight`` and the earlier key's out
    tangent from ``out_tangent``/``out_weight``.
    """

    time: float  # seconds
    value: float
    in_tangent: float
    out_tangent: float
    in_weight: float
    out_weight: float

    def __init__(
            self, *,
            time: float,
            value: float,
            in_tangent: float,
            out_tangent: float,
            in_weight: float,
            out_weight: float
    ):
        self.time = time
        self.value = value
        self.in_tangent = in_tangent
        self.out_tangent = out_tangent
        self.in_weight = in_weight
        self.out_weight = out_weight

    def __str__(self) -> str:
        # Use the real class name (this used to print "Keyframe(" and was easy to confuse).
        return (f"{type(self).__name__}("
                f"time={self.time}, "
                f"value={self.value}, "
                f"in_tangent={self.in_tangent}, "
                f"out_tangent={self.out_tangent}, "
                f"in_weight={self.in_weight}, "
                f"out_weight={self.out_weight}"
                f")")

    def __repr__(self) -> str:
        # Same text as __str__ so containers of keyframes print readably.
        return self.__str__()


class MessageHandler:
    """Serves the Marionette wire protocol for the current Maya session.

    Messages are read from a ``Stream``. Every access to the Maya scene is
    marshalled through ``maya.utils.executeInMainThreadWithResult`` or
    ``maya.utils.executeDeferred``, which suggests this object is driven from
    a non-main thread (TODO confirm with the caller). The bones that are
    driven come from the ``Preset`` passed to :meth:`on_connected`.

    Poses on the wire are one ``((px, py, pz), (rx, ry, rz))`` pair per
    preset bone. Positions are in centimeters and rotations are Euler angles
    in degrees; these are the units announced during the handshake.
    """

    # Number of poses handed to the main thread per _convert_poses call.
    _CONVERT_BATCH_SIZE = 100
    # Index of the position / rotation entries in per-bone ``[key_type][axis]`` key lists.
    _KEY_TYPE_POSITION = 0
    _KEY_TYPE_ROTATION = 1
    # Maya attribute keyed for each (key_type, axis).
    _ATTRIBUTE_NAMES = (
        ('translateX', 'translateY', 'translateZ'),
        ('rotateX', 'rotateY', 'rotateZ'),
    )
    # Maya seems to act as if the timeline is always set to 3 fps when interpreting tangents.
    _TANGENT_X_SCALE = 3
    # The y constant baffles me, could be 0.33333.. rad converted to degrees.
    _TANGENT_Y_SCALE = 19.098595
    # Frame rates that Maya exposes under a named time unit; others use '<fps>fps'.
    _NAMED_TIME_UNITS = {
        15: 'game',
        24: 'film',
        25: 'pal',
        30: 'ntsc',
        48: 'show',
        50: 'palf',
        60: 'ntscf',
    }

    def __init__(
            self,
            connection_attempt_started: Optional[Callable] = None,
            connected: Optional[Callable] = None,
            disconnected: Optional[Callable] = None,
            set_up_completed: Optional[Callable] = None,
    ) -> None:
        """Store the optional notification callbacks; any of them may be ``None``.

        Args:
            connection_attempt_started: Run on the main thread by on_connection_attempt_started.
            connected: Run on the main thread once protocol versions match and the rig is resolved.
            disconnected: Run on the main thread after the original pose has been restored.
            set_up_completed: Deferred onto the main thread when the handshake finishes.
        """
        self.connection_attempt_started = connection_attempt_started
        self.connected = connected
        self.disconnected = disconnected
        self.set_up_completed = set_up_completed
        self._rig_has_been_set_up = False
        # The following are populated by on_connected / _set_up_transforms.
        self._preset = None
        self.transforms = []
        self.rotation_orders = []
        self.original_pose = None
        self.last_read_message_time = 0

    # ------------------------------------------------------------------ #
    # Message loop
    # ------------------------------------------------------------------ #

    def read_message(self, stream: Stream) -> bool:
        """Read and handle at most one message from *stream*.

        Args:
            stream: Connection to Marionette.

        Returns:
            ``False`` when Marionette sent ``CLOSE``. ``True`` otherwise,
            including when no data was ready to be read.

        Raises:
            struct.error: If the message type is not a known ``MsgType``.
        """
        if not stream.can_read():
            return True

        # If congested we drop messages that are not essential: a message arriving within
        # 0.1 s of the last uncongested one counts as congested.
        congested = time.time() - self.last_read_message_time < 0.1
        msg_type = stream.receive_int()
        if msg_type == MsgType.HEARTBEAT:
            stream.send_bool(True)
        elif msg_type in (MsgType.CONVERT_WORLD_POSE_TO_LOCAL_POSE, MsgType.CONVERT_LOCAL_POSE_TO_WORLD_POSE):
            pose_count = stream.receive_int()
            restore_original_pose = stream.receive_bool()
            poses = [self._receive_pose(stream) for _ in range(pose_count)]
            from_world = msg_type == MsgType.CONVERT_WORLD_POSE_TO_LOCAL_POSE
            # Convert in fixed-size batches on the main thread, streaming each batch's
            # results back before starting the next.
            batch_size = self._CONVERT_BATCH_SIZE
            for start in range(0, len(poses), batch_size):
                converted_poses = maya.utils.executeInMainThreadWithResult(
                    self._convert_poses, poses[start:start + batch_size], from_world, restore_original_pose,
                    self._preset, self.transforms, self.rotation_orders, self.original_pose
                )
                for converted_pose in converted_poses:
                    self._send_pose(stream, converted_pose)
        elif msg_type == MsgType.SET_LOCAL_KEYFRAMES:
            # Both fields must be consumed to stay in sync with the stream; neither is used yet.
            _name = stream.receive_str()
            _tangent_mode = stream.receive_int()
            keys = self._receive_local_keyframes(stream)
            maya.utils.executeDeferred(self._set_local_keyframes, keys, self._preset)
        elif msg_type == MsgType.SET_WORLD_KEYFRAMES:
            tangent_mode = stream.receive_int()
            keys, has_bone = self._receive_world_keyframes(stream)
            maya.utils.executeDeferred(self._set_world_keyframes, keys, has_bone, self._preset, tangent_mode,
                                       self.transforms,
                                       self.rotation_orders)
        elif msg_type == MsgType.SET_FRAME_RATE:
            frame_rate = stream.receive_int()
            maya.utils.executeDeferred(self._set_frame_rate, frame_rate)
        elif msg_type == MsgType.CLOSE:
            return False
        elif msg_type in (MsgType.READ_LOCAL_POSE, MsgType.READ_WORLD_POSE):
            world = msg_type == MsgType.READ_WORLD_POSE
            pose = maya.utils.executeInMainThreadWithResult(
                self._read_pose, world, self._preset, self.transforms, self.rotation_orders
            )
            self._send_pose(stream, pose)
        elif msg_type in (MsgType.WRITE_LOCAL_POSE, MsgType.WRITE_WORLD_POSE):
            can_be_discarded = stream.receive_bool()
            pose = self._receive_pose(stream)
            world = msg_type == MsgType.WRITE_WORLD_POSE
            # The pose is always read (to keep the stream in sync), but only applied
            # unless it is both discardable and arriving while congested.
            if not (congested and can_be_discarded):
                maya.utils.executeDeferred(
                    self._write_pose, pose, world, self._preset, self.transforms, self.rotation_orders
                )
        else:
            raise struct.error(f"Invalid msg type {msg_type}")

        if not congested:
            self.last_read_message_time = time.time()
        return True

    def _receive_pose(self, stream: Stream) -> list[tuple[tuple[float, float, float], tuple[float, float, float]]]:
        """Read one ``((px, py, pz), (rx, ry, rz))`` pair per preset bone from *stream*."""
        return [
            (
                (stream.receive_float(), stream.receive_float(), stream.receive_float()),
                (stream.receive_float(), stream.receive_float(), stream.receive_float()),
            )
            for _ in self._preset.human_names
        ]

    @staticmethod
    def _send_pose(stream: Stream, pose) -> None:
        """Write each bone's position triple followed by its rotation triple to *stream*."""
        for position, rotation in pose:
            for axis in range(3):
                stream.send_float(position[axis])
            for axis in range(3):
                stream.send_float(rotation[axis])

    def _receive_local_keyframes(self, stream: Stream) -> list[list[list[list[KeyframeWithTangents]]]]:
        """Read per-bone keys indexed ``[bone][key][key_type][axis]``; bones with no keys get ``[]``."""
        keys = []
        for _ in self._preset.human_names:
            key_count = stream.receive_int()
            keys.append([
                [
                    [
                        KeyframeWithTangents(
                            time=stream.receive_float(),
                            value=stream.receive_float(),
                            in_tangent=stream.receive_float(),
                            out_tangent=stream.receive_float(),
                            in_weight=stream.receive_float(),
                            out_weight=stream.receive_float(),
                        )
                        for _axis in range(3)
                    ]
                    for _key_type in range(2)
                ]
                for _key in range(key_count)
            ])
        return keys

    def _receive_world_keyframes(self, stream: Stream) -> tuple[list[list[list[list[Keyframe]]]], list[bool]]:
        """Read world keys indexed ``[frame][bone][key_type][axis]`` plus the per-bone presence flags.

        A bone whose presence flag is false holds ``[]`` in every frame.
        """
        key_count = stream.receive_int()
        has_bone = [stream.receive_bool() for _ in self._preset.human_names]
        keys = []
        for _ in range(key_count):
            frame = []
            for bone_present in has_bone:
                if bone_present:
                    frame.append([
                        [
                            Keyframe(
                                time=stream.receive_float(),
                                value=stream.receive_float(),
                            )
                            for _axis in range(3)
                        ]
                        for _key_type in range(2)
                    ])
                else:
                    frame.append([])
            # Exactly one entry per frame, appended after all of its bones have been read.
            keys.append(frame)
        return keys, has_bone

    # ------------------------------------------------------------------ #
    # Connection lifecycle
    # ------------------------------------------------------------------ #

    def on_connection_attempt_started(self) -> None:
        """Run the ``connection_attempt_started`` callback on the main thread, if one was given."""
        if self.connection_attempt_started is not None:
            maya.utils.executeInMainThreadWithResult(self.connection_attempt_started)

    def on_connected(self, stream: Stream, protocol_version: int, preset: Preset) -> None:
        """Run the handshake and describe the rig to Marionette.

        The two sides exchange protocol versions; on a mismatch this method
        disconnects and returns. Otherwise it resolves the preset's bones to
        Maya transforms. It then sends the scene's coordinate system, the
        units, and every bone's world-space rotate pivot, rotation and rotate
        order.

        Raises:
            AssertionError: If Marionette does not acknowledge the handshake.
        """
        marionette_protocol_version = stream.receive_int()
        stream.send_int(protocol_version)
        if marionette_protocol_version != protocol_version:
            print("Mismatched DCC protocol version. Make sure the plugin version is compatible with the current "
                  "Marionette version.")
            self.on_disconnected()
            return
        stream.send_str("MAYA")
        stream.send_bool(False)
        self._preset = preset
        maya.utils.executeInMainThreadWithResult(self._set_up_transforms, preset)
        if self.connected is not None:
            maya.utils.executeInMainThreadWithResult(self.connected)

        message: SetUpRigMessage = maya.utils.executeInMainThreadWithResult(
            lambda p=self._preset: self._create_set_up_message(p)
        )
        up_axis = maya.utils.executeInMainThreadWithResult(lambda: cmds.upAxis(q=True, axis=True))
        stream.send_int(CoordinateSystem.R_XZY if up_axis == 'z' else CoordinateSystem.R_XYZ)
        stream.send_int(PositionUnit.Centimeter)
        stream.send_int(RotationUnit.Degrees)
        stream.send_int(len(self._preset.bone_names))
        for bone, name, position, rotation, rotate_order in message:
            stream.send_str(bone)
            stream.send_str(name)
            for i in range(3):
                stream.send_float(position[i])
            stream.send_str(rotate_order)
            for i in range(3):
                stream.send_float(rotation[i])
        # Explicit check instead of `assert`, so the reply is still consumed under `python -O`.
        handshake_reply = stream.receive_str()
        if handshake_reply != "HANDSHAKE COMPLETED":
            raise AssertionError(f"Unexpected handshake reply: {handshake_reply!r}")
        stream.send_str("ROGER")
        self._rig_has_been_set_up = True
        if self.set_up_completed is not None:
            maya.utils.executeDeferred(self.set_up_completed)

    def on_disconnected(self) -> None:
        """Restore the pose captured at set-up, then run the ``disconnected`` callback if one was given."""
        self._rig_has_been_set_up = False
        maya.utils.executeInMainThreadWithResult(self._restore_original_pose)
        if self.disconnected is not None:
            maya.utils.executeInMainThreadWithResult(self.disconnected)

    def _set_up_transforms(self, preset: Preset) -> None:
        """Resolve the preset's bone names to ``MFnTransform`` function sets.

        Called on the main thread from on_connected. It also caches each
        transform's rotation order and snapshots each transform-space
        translation and rotation (as a quaternion) for _restore_original_pose.
        """
        selection_list = om.MSelectionList()
        for name in preset.bone_names:
            selection_list.add(name)
        self.transforms = [
            om.MFnTransform(selection_list.getDagPath(index)) for index in range(len(preset.bone_names))
        ]
        # Subtract rotation orders by 1, because it is apparently not zero index in the python api;
        # the adjusted value is what MEulerRotation / reorder receive elsewhere in this class.
        self.rotation_orders = [transform.rotationOrder() - 1 for transform in self.transforms]
        self.original_pose = [
            (
                transform,
                transform.translation(om.MSpace.kTransform),
                transform.rotation(om.MSpace.kTransform, True)
            )
            for transform in self.transforms
        ]

    def _restore_original_pose(self) -> None:
        """Reapply the snapshot taken by _set_up_transforms, if there is one (main thread)."""
        if self.original_pose is not None:
            self._apply_pose_snapshot(self.original_pose)

    @staticmethod
    def _create_set_up_message(preset: Preset) -> SetUpRigMessage:
        """Return ``(human_name, bone_name, world rotate pivot, world rotation, rotate order)`` per bone."""
        return [
            (
                human_name,
                bone_name,
                cmds.xform(bone_name, q=True, rp=True, ws=True),
                cmds.xform(bone_name, query=True, ws=True, rotation=True),
                cmds.xform(bone_name, query=True, ws=True, rotateOrder=True)
            )
            # The copy flags are zipped in only so iteration stops at the shortest preset list.
            for human_name, bone_name, _copy_rotation, _copy_position in
            zip(preset.human_names, preset.bone_names, preset.copy_rotation, preset.copy_position)
        ]

    # ------------------------------------------------------------------ #
    # Maya transform helpers (main thread only)
    # ------------------------------------------------------------------ #

    @staticmethod
    def _apply_pose_snapshot(snapshot) -> None:
        """Write back ``(transform, translation, rotation)`` triples in transform space."""
        for transform, position, rotation in snapshot:
            transform.setTranslation(position, om.MSpace.kTransform)
            transform.setRotation(rotation, om.MSpace.kTransform)

    @staticmethod
    def _euler_degrees(transform, space, rotation_order) -> tuple[float, float, float]:
        """Return *transform*'s rotation in *space* as Euler degrees in *rotation_order*."""
        rotation = transform.rotation(space, True).asEulerRotation().reorder(rotation_order)
        return (
            math.degrees(rotation[0]),
            math.degrees(rotation[1]),
            math.degrees(rotation[2]),
        )

    @staticmethod
    def _set_rotation_degrees(transform, rotation, rotation_order, space) -> None:
        """Set *transform*'s rotation in *space* from Euler *rotation* (degrees) in *rotation_order*."""
        transform.setRotation(om.MEulerRotation(
            math.radians(rotation[0]),
            math.radians(rotation[1]),
            math.radians(rotation[2]),
            rotation_order
        ).asQuaternion(), space)

    @staticmethod
    def _move_rotate_pivot_to(transform, pivot_position, space) -> None:
        """Shift *transform*'s translation in *space* by (*pivot_position* - current rotate pivot)."""
        translation = transform.translation(space)
        rotate_pivot = transform.rotatePivot(space)
        transform.setTranslation(om.MVector([
            translation[axis] + pivot_position[axis] - rotate_pivot[axis] for axis in range(3)
        ]), space)

    @staticmethod
    def _read_pose(
            world: bool,
            preset: Preset,
            transforms,
            rotation_orders
    ) -> list[tuple[tuple[float, float, float], tuple[float, float, float]]]:
        """Return each bone's rotate pivot and Euler rotation (degrees).

        Values are read in world space when *world* is true and in transform
        space otherwise. *preset* is accepted for signature symmetry but is
        not used.
        """
        space = om.MSpace.kWorld if world else om.MSpace.kTransform
        return [
            (transform.rotatePivot(space), MessageHandler._euler_degrees(transform, space, rotation_order))
            for transform, rotation_order in zip(transforms, rotation_orders)
        ]

    @staticmethod
    def _write_pose(
            pose: list[tuple[tuple[float, float, float], tuple[float, float, float]]],
            world: bool,
            preset: Preset,
            transforms,
            rotation_orders
    ) -> None:
        """Apply *pose* in world or transform space, honouring the preset's per-bone copy flags.

        Positions are rotate-pivot targets and rotations are Euler degrees in
        each bone's rotation order.
        """
        space = om.MSpace.kWorld if world else om.MSpace.kTransform
        for copy_rotation, copy_position, (pivot_position, rotation), transform, rotation_order in zip(
                preset.copy_rotation, preset.copy_position, pose, transforms, rotation_orders
        ):
            if copy_position:
                MessageHandler._move_rotate_pivot_to(transform, pivot_position, space)
            if copy_rotation:
                MessageHandler._set_rotation_degrees(transform, rotation, rotation_order, space)

    @staticmethod
    def _convert_poses(
            poses: list[list[tuple[tuple[float, float, float], tuple[float, float, float]]]],
            from_world: bool,
            restore_original_pose: bool,
            preset: Preset,
            transforms,
            rotation_orders,
            original_pose
    ) -> list[list[tuple[tuple[float, float, float], tuple[float, float, float]]]]:
        """Convert poses between world and local space by posing the rig and reading it back.

        World to local reads each bone's transform-space translation; local to
        world reads its world-space rotate pivot. Rotations come back as Euler
        degrees. The rig keeps the last pose unless *restore_original_pose* is
        set, in which case the *original_pose* snapshot is reapplied.
        """
        from_space = om.MSpace.kWorld if from_world else om.MSpace.kTransform
        to_space = om.MSpace.kTransform if from_world else om.MSpace.kWorld
        converted_poses = []
        for pose in poses:
            converted_pose = []
            for copy_rotation, copy_position, (from_position, rotation), transform, rotation_order in zip(
                    preset.copy_rotation, preset.copy_position, pose, transforms, rotation_orders
            ):
                if copy_position:
                    if from_world:
                        MessageHandler._move_rotate_pivot_to(transform, from_position, from_space)
                    else:
                        transform.setTranslation(om.MVector(from_position), from_space)
                if copy_rotation:
                    MessageHandler._set_rotation_degrees(transform, rotation, rotation_order, from_space)
                if from_world:
                    converted_translation = transform.translation(to_space)
                else:
                    converted_translation = transform.rotatePivot(to_space)
                converted_pose.append(
                    (converted_translation, MessageHandler._euler_degrees(transform, to_space, rotation_order))
                )
            converted_poses.append(converted_pose)
        if restore_original_pose and original_pose is not None:
            MessageHandler._apply_pose_snapshot(original_pose)
        return converted_poses

    # ------------------------------------------------------------------ #
    # Keyframing (main thread only)
    # ------------------------------------------------------------------ #

    @staticmethod
    def _set_world_keyframes(
            keys: list[list[list[list[Keyframe]]]],
            has_bone: list[bool],
            preset: Preset,
            tangent_mode: TangentMode,
            transforms,
            rotation_orders
    ) -> None:
        """Key world-space samples as local translate/rotate keys inside one undo chunk.

        For each frame and each bone flagged in *has_bone*, the bone is posed
        in world space from the received values. Its transform-space
        translation and rotation (degrees) are read back and keyed at the
        samples' times. Tangents are ``stepnext``/``step`` when *tangent_mode*
        is ``TangentMode.STEPPED`` and ``auto`` otherwise.

        Args:
            keys: Indexed ``[frame][bone][key_type][axis]``; key_type 0 is
                position and 1 is rotation.
            has_bone: Per-bone flag telling which bones carry data.
            preset: Supplies bone names and the copy_position/copy_rotation flags.
            tangent_mode: Requested tangent style.
            transforms: ``MFnTransform`` per bone.
            rotation_orders: Euler rotation order per bone.
        """
        stepped = tangent_mode == TangentMode.STEPPED
        in_tangent_type = "stepnext" if stepped else "auto"
        out_tangent_type = "step" if stepped else "auto"
        attribute_names = MessageHandler._ATTRIBUTE_NAMES
        position_type = MessageHandler._KEY_TYPE_POSITION
        rotation_type = MessageHandler._KEY_TYPE_ROTATION
        with UndoChunk("ReceiveMarionetteKeyframes"):
            for frame in keys:
                for bone_index, bone_keys in enumerate(frame):
                    if not has_bone[bone_index]:
                        continue
                    bone_name = preset.bone_names[bone_index]
                    copy_rotation = preset.copy_rotation[bone_index]
                    copy_position = preset.copy_position[bone_index]
                    transform = transforms[bone_index]
                    rotation_order = rotation_orders[bone_index]
                    if copy_position:
                        MessageHandler._move_rotate_pivot_to(
                            transform, [key.value for key in bone_keys[position_type]], om.MSpace.kWorld
                        )
                    if copy_rotation:
                        MessageHandler._set_rotation_degrees(
                            transform, [key.value for key in bone_keys[rotation_type]], rotation_order,
                            om.MSpace.kWorld
                        )
                    local_values = (
                        transform.translation(om.MSpace.kTransform),
                        MessageHandler._euler_degrees(transform, om.MSpace.kTransform, rotation_order),
                    )
                    for key_type, should_copy in ((position_type, copy_position), (rotation_type, copy_rotation)):
                        if not should_copy:
                            continue
                        for axis in range(3):
                            key_time = bone_keys[key_type][axis].time
                            cmds.setKeyframe(
                                bone_name,
                                attribute=attribute_names[key_type][axis],
                                v=local_values[key_type][axis],
                                t=f'{key_time}sec',
                                inTangentType=in_tangent_type,
                                outTangentType=out_tangent_type,
                            )

    @staticmethod
    def _set_local_keyframes(
            keys: list[list[list[list[KeyframeWithTangents]]]],
            preset: Preset
    ) -> None:
        """Key local translate/rotate curves with weighted spline tangents.

        This runs inside one undo chunk with a cancellable progress window,
        which is refreshed at most about once per second; cancelling stops
        keying early. Each touched curve is switched to weighted tangents
        once. For every key after the first, the key's in tangent and the
        previous key's out tangent are set (see _set_segment_tangents).

        Args:
            keys: Indexed ``[bone][key][key_type][axis]``; key_type 0 is
                position and 1 is rotation. Bones without keys hold ``[]``.
            preset: Supplies bone names and the copy_position/copy_rotation flags.
        """
        attribute_names = MessageHandler._ATTRIBUTE_NAMES
        with UndoChunk("ReceiveMarionetteKeyframes"), \
                ProgressWindow('Loading', 'Setting key frames...', True) as progress:
            last_report_time = time.time()
            weighted_curves = set()
            bone_count = len(preset.human_names)
            for bone_index in range(bone_count):
                bone_keys = keys[bone_index]
                for key_index, key_set in enumerate(bone_keys):
                    if time.time() - last_report_time > 1:
                        if progress.has_been_cancelled():
                            return
                        progress.report('Setting key frames...',
                                        (bone_index + key_index / len(bone_keys)) / bone_count * 100)
                        last_report_time = time.time()
                    bone_name = preset.bone_names[bone_index]
                    # Indexed by key_type: 0 -> position, 1 -> rotation.
                    copy_flags = (preset.copy_position[bone_index], preset.copy_rotation[bone_index])
                    for key_type in range(2):
                        if not copy_flags[key_type]:
                            continue
                        for axis in range(3):
                            attribute = attribute_names[key_type][axis]
                            curve_name = f'{bone_name}.{attribute}'
                            key = key_set[key_type][axis]
                            cmds.setKeyframe(
                                bone_name,
                                attribute=attribute,
                                v=key.value,
                                t=f'{key.time}sec'
                            )
                            if curve_name not in weighted_curves:
                                cmds.keyTangent(curve_name, weightedTangents=True)
                                weighted_curves.add(curve_name)
                            if key_index > 0:
                                MessageHandler._set_segment_tangents(
                                    curve_name, bone_keys[key_index - 1][key_type][axis], key
                                )

    @staticmethod
    def _set_segment_tangents(
            curve_name: str,
            previous_key: KeyframeWithTangents,
            key: KeyframeWithTangents,
    ) -> None:
        """Set the in tangent of *key* and the out tangent of *previous_key* on *curve_name*.

        The tangent x extent is the scaled time gap between the two keys,
        multiplied by the owning key's weight. The y extent is x multiplied by
        the slope, after converting the slope from value/second to value/frame
        at 3 fps and dividing by Maya's empirical y scale.
        """
        x_scale = MessageHandler._TANGENT_X_SCALE
        y_scale = MessageHandler._TANGENT_Y_SCALE
        span = key.time * x_scale - previous_key.time * x_scale

        # In tangent of the later key: its own in-weight and in-slope.
        in_x = span * key.in_weight
        in_y = in_x * (key.in_tangent / x_scale / y_scale)
        cmds.keyTangent(
            curve_name,
            time=(f'{key.time}sec',),
            edit=True,
            iy=in_y,
            ix=in_x,
            inTangentType='spline',
            lock=False,
            unify=False,
        )

        # Out tangent of the earlier key: its own out-weight and out-slope
        # (previously this wrongly reused the later key's in-weight).
        out_x = span * previous_key.out_weight
        out_y = out_x * (previous_key.out_tangent / x_scale / y_scale)
        cmds.keyTangent(
            curve_name,
            time=(f'{previous_key.time}sec',),
            edit=True,
            oy=out_y,
            ox=out_x,
            outTangentType='spline',
            lock=False,
            unify=False,
        )

    @staticmethod
    def _set_frame_rate(fps: float) -> None:
        """Set Maya's time unit to match *fps*, using a named unit when Maya has one."""
        unit = MessageHandler._NAMED_TIME_UNITS.get(fps, str(fps) + 'fps')
        cmds.currentUnit(time=unit)
