import bpy
from mathutils import Matrix, Vector
from .copy_paste_utils import quick_sort, translateScale
import numpy as np


class WS_OT_copy_bone_relative_matrix(bpy.types.Operator):
    """Copy world transforms of the selected pose bones relative to the active bone.

    Stores, on the window manager:
      * ``relative_bone`` / ``armature``: name and owning object of the active
        pose bone (the bone transforms are made relative to),
      * ``bone_relative_matrix`` / ``source_tail``: that bone's world matrix and
        world-space tail at copy time,
      * ``bone_data``: one entry per other selected pose bone holding its world
        matrix, flattened for FloatVector storage.
    """
    bl_idname = "bone_relative.copy"
    bl_label = "Copy Relative Transform"
    bl_options = {'UNDO'}
    bl_description = "Copy Transforms Relative to Active Bone"

    @classmethod
    def poll(cls, context):
        return context.active_pose_bone is not None

    def execute(self, context):
        wm = context.window_manager
        relative_bone = context.active_pose_bone
        relative_arm = relative_bone.id_data

        wm['bone_rel_clipboard_data'] = {}
        wm['relative_bone'] = relative_bone.name
        wm['armature'] = relative_arm

        # Record the relative bone's frame unconditionally: the active bone is
        # not necessarily part of the selection. If it were only written while
        # iterating the selection, paste could read stale or missing data.
        wm['bone_relative_matrix'] = relative_arm.matrix_world @ relative_bone.matrix
        wm['source_tail'] = relative_arm.matrix_world @ relative_bone.tail

        # clear clipboard property group
        wm.bone_data.clear()

        copied_any = False
        for pbone in context.selected_pose_bones or ():
            if pbone == relative_bone:
                continue

            arm = pbone.id_data
            entry = wm.bone_data.add()
            entry.arm_data = arm
            entry.name = pbone.name
            # Flatten column-major (order='F') for FloatVector storage; the
            # paste operator casts it back with Matrix().
            world_matrix = arm.matrix_world @ pbone.matrix
            entry.bone_matrix = np.array(world_matrix).ravel(order='F')
            copied_any = True

        # Class-level flag read by the paste operator's poll(). Reset it when
        # nothing but the relative bone was copied, because there is then
        # nothing to paste.
        bpy.types.WindowManager.bone_rel_clipboard_data = copied_any

        return {'FINISHED'}


class WS_OT_paste_bone_relative_matrix(bpy.types.Operator):
    """Re-apply copied bone transforms relative to the copy-time active bone.

    The copy operator recorded the world matrix and tail of a "relative" bone
    plus the world matrices of other selected bones. Pasting measures how the
    relative bone has moved since then. It applies that change (location about
    the bone's head or tail, rotation, scale — each optional) to every selected
    bone that has clipboard data.

    Bones with an enabled Child Of constraint at influence 1.0 are compensated
    using the last such constraint. The intent is that the constrained result
    lands on the pasted transform.
    """
    bl_idname = "bone_relative.paste"
    bl_label = "Paste Relative Transforms"
    bl_options = {'REGISTER', 'UNDO'}
    bl_description = "Paste Transforms Relative to Active Bone"

    include_loc: bpy.props.BoolProperty(name="Location",
                                        description="Paste relative location",
                                        default=True)

    include_rot: bpy.props.BoolProperty(name="Rotation",
                                        description="Paste relative rotation",
                                        default=True)

    include_scale: bpy.props.BoolProperty(name="Scale",
                                          description="Paste relative scale",
                                          default=True)

    set_inverse: bpy.props.BoolProperty(name="Set inverse",
                                        description="'Set Inverse' for 'Child Of' constraints",
                                        default=False)

    relative_to: bpy.props.EnumProperty(name="Relative to bone",
                                        description="Paste transforms relative to bone head/tail",
                                        items=[
                                            ("TAIL", "Tail", "Relative to bone tail"),
                                            ("HEAD", "Head", "Relative to bone head")
                                            ], default='TAIL')

    @classmethod
    def poll(cls, context):
        # The flag is a plain class attribute assigned by the copy operator;
        # getattr() keeps poll from raising if it has not been assigned yet.
        has_clipboard = getattr(bpy.types.WindowManager, 'bone_rel_clipboard_data', False) is True
        obj = context.active_object
        return has_clipboard and obj is not None and obj.mode == 'POSE'

    def draw(self, context):
        layout = self.layout
        layout.use_property_split = False
        layout.use_property_decorate = False
        row = layout.row()
        row.label(text='Relative to :')
        row.prop(self, "relative_to", expand=True)
        layout.row().separator()
        row = layout.row()
        row.label(text='Include :')
        layout.use_property_split = True
        row = layout.row()
        row.prop(self, "include_loc")
        row = layout.row()
        row.prop(self, "include_rot")
        row = layout.row()
        row.prop(self, "include_scale")
        layout.row().separator()
        row = layout.row()
        row.prop(self, "set_inverse")

    @staticmethod
    def _childof_target_matrix(con):
        """Return the world matrix of a Child Of constraint's parent, or None.

        An empty subtarget means the target object itself is the parent. A
        subtarget on an armature names a pose bone. Anything else (e.g. a
        vertex group on a mesh, or a missing bone) cannot be resolved here.
        """
        target = con.target
        if not con.subtarget:
            return target.matrix_world
        if target.type == 'ARMATURE':
            target_bone = target.pose.bones.get(con.subtarget)
            if target_bone is not None:
                # Use the *target's* world matrix: it may be a different
                # armature than the constrained bone's.
                return target.matrix_world @ target_bone.matrix
        return None

    @staticmethod
    def _set_childof_inverse(context, arm, pbone, con):
        """Run the Child Of 'Set Inverse' operator for *con* on *pbone*."""
        # The operator works on the active object / active bone.
        context.view_layer.objects.active = arm
        arm.data.bones.active = pbone.bone
        if hasattr(context, "temp_override"):
            # Blender 3.2+; passing an override dict positionally was
            # deprecated in 3.2 and removed in 4.0.
            with context.temp_override(constraint=con):
                bpy.ops.constraint.childof_set_inverse(constraint=con.name, owner='BONE')
        else:
            override = context.copy()
            override["constraint"] = con
            bpy.ops.constraint.childof_set_inverse(override, constraint=con.name, owner='BONE')
        context.view_layer.update()

    def execute(self, context):
        wm = context.window_manager
        bone_data = wm.bone_data
        qsort = quick_sort(context.selected_pose_bones)

        try:
            relative_bone_name = wm['relative_bone']
            relative_arm = wm['armature']
            source_tail_loc = Vector(wm['source_tail'])
            source_matrix = Matrix(wm['bone_relative_matrix'])
        except KeyError:
            self.report({'ERROR'}, "No relative transform has been copied")
            return {'CANCELLED'}

        relative_bone = None
        if relative_arm is not None and relative_arm.pose is not None:
            relative_bone = relative_arm.pose.bones.get(relative_bone_name)
        if relative_bone is None:
            self.report({'ERROR'}, f"Relative bone '{relative_bone_name}' not found")
            return {'CANCELLED'}

        # Relative bone at copy time.
        source_loc, _, source_scale = source_matrix.decompose()
        source_scale_mat = translateScale(source_scale)

        # Relative bone now.
        current_tail = relative_arm.matrix_world @ relative_bone.tail
        current_matrix = relative_arm.matrix_world @ relative_bone.matrix
        current_loc, _, current_scale = current_matrix.decompose()
        current_scale_mat = translateScale(current_scale)

        # Rotation that carries the copy-time frame onto the current frame.
        _, offset_rot, _ = (source_matrix @ current_matrix.inverted()).inverted().decompose()
        offset_rot_mat = offset_rot.to_matrix().to_4x4()

        # Loop-invariant pieces.
        scale_rel_mat = source_scale_mat.inverted() @ current_scale_mat
        use_tail = self.relative_to == 'TAIL'
        source_anchor = source_tail_loc if use_tail else source_loc
        current_anchor = current_tail if use_tail else current_loc

        # qsort presumably orders bones so parents are handled before children
        # (each bone is applied and the view layer updated in sequence).
        for pbone in qsort or ():
            arm = pbone.id_data

            # match clipboard entry by bone name and armature
            stored = next((item for item in bone_data
                           if item.name == pbone.name and item.arm_data == arm), None)
            if stored is None:
                continue

            # cast FloatVector to Matrix
            stored_matrix = Matrix(stored.bone_matrix)
            stored_loc, stored_rot, stored_scale = stored_matrix.decompose()
            stored_rot_mat = stored_rot.to_matrix().to_4x4()
            stored_scale_mat = translateScale(stored_scale)

            if self.include_loc:
                relative_vec = stored_loc - source_anchor
                relative_vec.rotate(offset_rot)
                loc_final_mat = Matrix.Translation(current_anchor + relative_vec)
            else:
                loc_final_mat = Matrix.Translation(stored_loc)

            if self.include_scale:
                scale_final_mat = stored_scale_mat @ scale_rel_mat
            else:
                scale_final_mat = stored_scale_mat

            # Same left-to-right evaluation order as loc @ [offset_rot @] rot @ scale.
            matrix_final = loc_final_mat
            if self.include_rot:
                matrix_final = matrix_final @ offset_rot_mat
            matrix_final = matrix_final @ stored_rot_mat @ scale_final_mat

            # only supporting child_of constraints with influence of 1 for now
            valid_childof_cons = [con for con in pbone.constraints
                                  if con.type == "CHILD_OF" and con.target
                                  and con.influence == 1.0 and not con.mute]

            # use the last valid childof constraint
            if valid_childof_cons:
                con = valid_childof_cons[-1]

                if self.set_inverse:
                    self._set_childof_inverse(context, arm, pbone, con)

                target_matrix = self._childof_target_matrix(con)
                if target_matrix is None:
                    self.report({'WARNING'},
                                f"{pbone.name}: could not resolve target of '{con.name}', "
                                f"pasted without Child Of compensation")
                else:
                    con_matrix = arm.matrix_world @ con.inverse_matrix.inverted()
                    parent_offset = target_matrix @ con_matrix.inverted()
                    matrix_final = parent_offset.inverted() @ matrix_final

            pbone.matrix = arm.convert_space(pose_bone=pbone,
                                             matrix=matrix_final,
                                             from_space='WORLD',
                                             to_space='POSE')

            # Evaluate now so later (child) bones see this bone's new pose.
            context.view_layer.update()

            arm.data.bones.active = pbone.bone

            if context.scene.tool_settings.use_keyframe_insert_auto:
                try:
                    bpy.ops.anim.keyframe_insert_menu(type='Available')
                except RuntimeError:
                    self.report({'WARNING'}, f'{pbone.name} has no active keyframes')

        return {'FINISHED'}
