# ##### BEGIN GPL LICENSE BLOCK #####
#
#  This program is free software; you can redistribute it and/or
#  modify it under the terms of the GNU General Public License
#  as published by the Free Software Foundation; either version 2
#  of the License, or (at your option) any later version.
#
#  This program is distributed in the hope that it will be useful,
#  but WITHOUT ANY WARRANTY; without even the implied warranty of
#  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
#  GNU General Public License for more details.
#
#  You should have received a copy of the GNU General Public License
#  along with this program; if not, write to the Free Software Foundation,
#  Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA.
#
# ##### END GPL LICENSE BLOCK #####

""" Zen Part Groups System """
# blender
import bpy
import bmesh

from collections import defaultdict
import numpy as np

from ...labels import ZsLabels
from ...blender_zen_utils import ZenLocks
from ...vlog import Log
from ...blender_zen_utils import update_areas_in_all_screens

from ..basic_sets import Zs_UL_BaseList
from ..draw_sets import mark_groups_modified
from ..bl_sets import ZsBlenderSets, ZsBlenderScenePropList, ZsBlenderObjectPropList, ZsBlenderGroupIndexSync
from ..basic_map_sets import ZsMapLayerManager
from ..draw_cache import FaceMapsUniqueCacher


class ZsFaceMapsLayerManager(ZsMapLayerManager, ZsBlenderSets):
    """Zen Sets manager for unique face groups stored as Blender face maps.

    A face belongs to at most one group: the BMesh ``face_map`` layer holds,
    per face, the index of its face map in ``Object.face_maps``, and ``-1``
    marks a face that is not assigned to any face map.
    """

    id_group = 'blgroup_u'
    list_item_prefix = 'FMaps'
    id_mask = 'ZSUFM'
    id_element = 'face'
    id_display_element = 'blgroup_u'
    id_uv_select_mode = 'FACE'
    is_unique = True
    is_blender = True

    # --- Parts section ---

    @classmethod
    def get_bm_items(cls, bm: bmesh.types.BMesh):
        """Return ``bm.faces`` after ensuring its index lookup table."""
        bm.faces.ensure_lookup_table()
        return bm.faces

    @classmethod
    def get_selected_count(cls, p_obj: bpy.types.Object):
        """Return the number of selected faces in the object's mesh."""
        return p_obj.data.total_face_sel

    @classmethod
    def get_mesh_select_mode(cls):
        """Return the mesh select mode as (vert, edge, face): face only."""
        return (False, False, True)

    @classmethod
    def get_item_loops(cls, p_item):
        """Return the loops of a BMesh face."""
        return p_item.loops

    @classmethod
    def get_bm_item_zero_val(cls):
        """Return the face_map layer value that means "no face map assigned"."""
        return -1

    @classmethod
    def get_bl_groups(cls, p_obj: bpy.types.Object):
        """Return the native Blender groups (face maps) of *p_obj*."""
        return p_obj.face_maps

    @classmethod
    def get_hash_layer(cls, p_bm):
        """Return the active face_map layer of *p_bm*, or None if there is none."""
        return cls.get_bm_items(p_bm).layers.face_map.active

    @classmethod
    def ensure_hash_layer(cls, p_bm):
        """Return the active face_map layer of *p_bm*, creating it if missing."""
        return cls.get_bm_items(p_bm).layers.face_map.verify()

    @classmethod
    def ensure_mesh_layer(cls, p_obj: bpy.types.Object, p_bm: bmesh.types.BMesh, layerName):
        """Return ``(face_map_layer, face_map_index)`` for *layerName*.

        The object face map is created only when no face map with that name
        exists yet. The BMesh face_map layer is verified independently, so an
        existing face map whose BMesh layer has not been allocated yet is
        reused rather than duplicated by a second ``face_maps.new`` call.
        """
        idx = p_obj.face_maps.find(layerName)
        if idx == -1:
            idx = p_obj.face_maps.new(name=layerName).index
        p_face_map = cls.ensure_hash_layer(p_bm)
        return (p_face_map, idx)

    @classmethod
    def get_mesh_layer(cls, p_obj: bpy.types.Object, p_bm: bmesh.types.BMesh, layerName):
        """Return ``(face_map_layer, face_map_index)`` for *layerName*, or None.

        None is returned when the object has no face map with that name, and
        also when the face map exists but the BMesh has no active face_map layer.
        """
        idx = p_obj.face_maps.find(layerName)
        if idx != -1:
            p_face_map = cls.get_bm_items(p_bm).layers.face_map.active
            if p_face_map:
                return (p_face_map, idx)

        return None

    @classmethod
    def create_unique_layer_name(cls):
        """Return the first '<prefix>.NNN' name not yet used in the scene list."""
        p_scene_list = cls.get_list(bpy.context.scene)
        i_count = 1
        s_layer_name = cls.list_item_prefix + f'.{i_count:03d}'
        while cls._index_of_layer(p_scene_list, s_layer_name) != -1:
            i_count += 1
            s_layer_name = cls.list_item_prefix + f'.{i_count:03d}'
        return s_layer_name

    @classmethod
    def get_hash_from_str(cls, p_obj, p_str):
        """Return the face map index named *p_str* on *p_obj* (-1 if absent)."""
        return p_obj.face_maps.find(p_str)

    @classmethod
    def is_bm_item_set(cls, p_bm_item, p_layer_pair):
        """Return True if the face is assigned to the face map in *p_layer_pair*.

        *p_layer_pair* is a ``(face_map_layer, face_map_index)`` tuple as
        returned by ``get_mesh_layer`` / ``ensure_mesh_layer``.
        """
        p_face_maps, idx = p_layer_pair
        return p_bm_item[p_face_maps] == idx

    @classmethod
    def _add_list_layer(cls, p_list, layerName, sItemPrefix, p_color):
        """Append a group entry named *layerName* to *p_list* and return its index.

        *sItemPrefix* is unused here (presumably kept to match the base-class
        signature).
        """
        p_list.add()
        p_item = p_list[-1]
        p_item.name = layerName
        p_item.layer_name = layerName
        p_item.group_color = p_color
        return len(p_list) - 1

    @classmethod
    def remove_mesh_layer(cls, p_obj: bpy.types.Object, p_bm: bmesh.types.BMesh, layerName, cleanup=False):
        """Remove the face map *layerName* from *p_obj* if present.

        *p_bm* and *cleanup* are not used by this implementation.
        """
        p_group = p_obj.face_maps.get(layerName)
        if p_group:
            p_obj.face_maps.remove(p_group)

    @staticmethod
    def _count_hide_flags(hide_flags):
        """Return ``(total, hidden)`` counts for a sequence of face hide flags."""
        arr = np.fromiter(hide_flags, 'b')
        return len(arr), np.count_nonzero(arr)

    @classmethod
    def update_all_obj_groups_count(cls, p_obj: bpy.types.Object, no_lookup=False):
        """Refresh face and hidden-face counts of every group of *p_obj*.

        If the object group list no longer mirrors ``p_obj.face_maps`` (same
        names in the same order), the list is rebuilt instead and, unless
        *no_lookup* is set, the lookup table too.

        Returns:
            True when anything was updated.
        """
        p_obj_list = cls.get_list(p_obj)
        if [g.layer_name for g in p_obj_list] != [fmap.name for fmap in p_obj.face_maps]:
            cls.update_list(bpy.context)
            if not no_lookup:
                cls.build_lookup_table(bpy.context)
            return True

        bm = cls._get_bm(p_obj)
        layer = cls.get_hash_layer(bm)
        b_need_update = False
        if layer:
            # Collect hide flags of assigned faces, grouped by face map index.
            dic = defaultdict(list)
            for item in cls.get_bm_items(bm):
                if item[layer] != -1:
                    dic[item[layer]].append(item.hide)

            hash_obj_list = {p_face_map.name: p_face_map.index for p_face_map in p_obj.face_maps}

            for group in p_obj_list:
                i_group_count = 0
                i_hide_count = 0

                # pop() consumes matched indices so only orphans remain in dic.
                hide_flags = dic.pop(hash_obj_list.get(group.layer_name, -1), None)
                if hide_flags is not None:
                    i_group_count, i_hide_count = cls._count_hide_flags(hide_flags)

                if i_group_count != group.group_count or i_hide_count != group.group_hide_count:
                    group.group_count = i_group_count
                    group.group_hide_count = i_hide_count
                    b_need_update = True

            if dic:
                # Face indices not matched by the object list: try to restore
                # their groups from the scene list.
                p_scene_list = cls.get_list(bpy.context.scene)
                hash_list = {hash_obj_list.get(p_group.layer_name, -1): p_group for p_group in p_scene_list}

                for c_g_hash, c_g_value in dic.items():
                    if c_g_hash in hash_list:
                        p_group = hash_list[c_g_hash]
                        mesh_layer_name = p_group.layer_name
                        Log.warn(f'Object:[{p_obj.name}] - Layer:[{mesh_layer_name}] will be restored!')
                        p_obj_group = cls.ensure_group_in_object(p_obj, p_group)
                        i_group_count, i_hide_count = cls._count_hide_flags(c_g_value)
                        p_obj_group.group_count = i_group_count
                        p_obj_group.group_hide_count = i_hide_count
                        b_need_update = True

            if b_need_update:
                if no_lookup:
                    mark_groups_modified(cls, p_obj)
                else:
                    cls.build_lookup_table(bpy.context)
        return b_need_update

    @classmethod
    def fetch_uv_selections(cls, bm):
        """Return indices of visible faces whose UV loops and face are all selected.

        An empty set is returned when the mesh has no active UV layer.
        """
        uv_layer = bm.loops.layers.uv.active
        if not uv_layer:
            return set()

        return {
            item.index
            for item in cls.get_bm_items(bm)
            if not item.hide and all(
                (loop[uv_layer].select and loop.face.select)
                for loop in item.loops)
        }

    @classmethod
    def check_uv_select_mode(cls) -> bool:
        """Switch UV select mode to face unless it is already face or island.

        Returns:
            True if the tool setting was changed.
        """
        tool_settings = bpy.context.tool_settings
        if tool_settings.uv_select_mode in {cls.id_uv_select_mode, 'ISLAND'}:
            return False
        tool_settings.uv_select_mode = cls.id_uv_select_mode
        return True

    @classmethod
    def get_cacher(cls):
        """Return a new draw cacher for unique face-map groups."""
        return FaceMapsUniqueCacher()

    @classmethod
    def execute_DrawMenu(cls, menu, context):
        """Draw the base menu plus the face-map specific operators."""
        super().execute_DrawMenu(menu, context)

        layout = menu.layout

        layout.separator()
        layout.operator('zsts.group_linked')  # icon='LINKED'
        layout.operator(ZSUFM_OT_AssignMaterialsToGroups.bl_idname)  # icon='MATERIAL'

    @classmethod
    def execute_DrawImport(cls, layout, context, is_menu=False):
        """Draw the base import UI (no face-map specific additions)."""
        super().execute_DrawImport(layout, context, is_menu)

    @classmethod
    def execute_DrawExport(cls, layout, context, is_menu=False):
        """Draw the base export UI (no face-map specific additions)."""
        super().execute_DrawExport(layout, context, is_menu)

    @classmethod
    def execute_DrawTools(cls, tools, context):
        """Draw the tools panel: 'group linked', base tools, then assign materials."""
        layout = tools.layout
        layout.operator('zsts.group_linked')

        super().execute_DrawTools(tools, context)

        layout.operator(ZSUFM_OT_AssignMaterialsToGroups.bl_idname)


class ZSUFM_UL_List(Zs_UL_BaseList, ZsFaceMapsLayerManager):
    """UI list of face-map groups; all behavior comes from the base classes."""


class ZSUFM_OT_AssignMaterialsToGroups(bpy.types.Operator, ZsFaceMapsLayerManager):
    """ Assign a per-group material to the faces of each face-map group """
    bl_idname = ZsFaceMapsLayerManager.list_prop_name() + '.assign_materials_to_groups'
    bl_description = ZsLabels.OT_ASSIGN_MATERIAL_TO_GROUP_DESC
    bl_label = ZsLabels.OT_ASSIGN_MATERIAL_TO_GROUP_LABEL
    bl_options = {'REGISTER', 'UNDO'}

    @classmethod
    def poll(cls, context):
        """Allow running only when the scene group list has at least one entry."""
        p_list = cls.get_list(context.scene)
        return p_list is not None and len(p_list) > 0

    @staticmethod
    def _ensure_material(mat_name, color, reset_existing):
        """Return the material *mat_name*, creating and coloring it if missing.

        When *reset_existing* is True, an already existing material is also
        re-colored (viewport color and Principled BSDF base color).
        """
        p_mat = bpy.data.materials.get(mat_name)
        if p_mat is None:
            p_mat = bpy.data.materials.new(name=mat_name)
        elif not reset_existing:
            return p_mat

        p_mat.diffuse_color = color
        p_mat.use_nodes = True
        p_mat.use_fake_user = True
        pr_bsdf_node = p_mat.node_tree.nodes.get("Principled BSDF")
        if pr_bsdf_node:
            pr_bsdf_node.inputs['Base Color'].default_value = color
        return p_mat

    @staticmethod
    def _ensure_material_slot(p_obj, p_mat, mat_name):
        """Return the slot index of *mat_name* on *p_obj*'s mesh, appending *p_mat* if absent."""
        materials = p_obj.data.materials
        idx = materials.find(mat_name)
        if idx == -1:
            materials.append(p_mat)
            idx = len(materials) - 1
        return idx

    def execute(self, context):
        """Assign 'Material.<group name>' to each group's faces.

        Faces of the edited objects that belong to none of the current groups
        get 'Material.Default' (created white if missing, never re-colored).
        Group materials are (re)colored with the group color on every run.
        Each object in Edit Mode receives a slot for every one of these
        materials, even if none of its faces use it.
        """
        p_group_pairs = self.get_current_group_pairs(context)
        if not p_group_pairs:
            return {'FINISHED'}

        def_mat_name = 'Material.Default'
        def_mat = self._ensure_material(def_mat_name, (1, 1, 1, 1), reset_existing=False)

        update_objs = set()

        for p_obj in context.objects_in_mode:
            id_def_mat = self._ensure_material_slot(p_obj, def_mat, def_mat_name)

            bm = self._get_bm(p_obj)
            nonlayered_items = self.get_bm_nonlayered_items(p_obj, bm, p_group_pairs)
            if len(nonlayered_items):
                for item in nonlayered_items:
                    item.material_index = id_def_mat
                update_objs.add(p_obj)

        for _, p_group in p_group_pairs:
            material_name = 'Material.' + p_group.name
            material_color = (p_group.group_color.r, p_group.group_color.g, p_group.group_color.b, 1)
            new_mat = self._ensure_material(material_name, material_color, reset_existing=True)

            for p_obj in context.objects_in_mode:
                id_mat = self._ensure_material_slot(p_obj, new_mat, material_name)

                bm = self._get_bm(p_obj)
                items = self.get_bm_layer_items(p_obj, bm, p_group.layer_name)
                if len(items):
                    for item in items:
                        item.material_index = id_mat
                    update_objs.add(p_obj)

        for p_obj in update_objs:
            bmesh.update_edit_mesh(p_obj.data, loop_triangles=False, destructive=False)
            ZenLocks.lock_depsgraph_update_one()

        return {'FINISHED'}


def zen_sets_update_face_maps_list():
    """Rebuild the face-map group list and lookup table, then redraw all areas.

    Does nothing unless Blender is in mesh Edit Mode and the scene's active
    sets manager is the face-map manager.
    """
    context = bpy.context
    if context.mode != 'EDIT_MESH':
        return

    from ..factories import get_sets_mgr
    active_mgr = get_sets_mgr(context.scene)
    if not (active_mgr == ZsFaceMapsLayerManager):
        return

    active_mgr.update_list(context)
    active_mgr.build_lookup_table(context)
    update_areas_in_all_screens()


class ZSUFMObjectListGroup(ZsBlenderObjectPropList, bpy.types.PropertyGroup):
    """
    Property group describing one entry of the
    Zen Sets face-map group list stored on an OBJECT
    """


class ZSUFMSceneListGroup(ZsBlenderScenePropList, bpy.types.PropertyGroup):
    """
    Property group describing one entry of the
    Zen Sets face-map group list stored on the SCENE
    """
    def get_object_groups(self, p_obj: bpy.types.Object):
        """Return the native groups of *p_obj*, i.e. its face maps."""
        object_face_maps = p_obj.face_maps
        return object_face_maps

    def get_update_method(self):
        """Return the callback that refreshes the face-map group list."""
        refresh_callback = zen_sets_update_face_maps_list
        return refresh_callback


class ZsFaceMapsIndexSync(ZsBlenderGroupIndexSync):
    """Group index synchronizer bound to the face-map layer manager."""

    def get_cls_mgr(self):
        """Return the manager class whose group indices are synchronized."""
        mgr_cls = ZsFaceMapsLayerManager
        return mgr_cls


# Module-level face-map index synchronizer instance; presumably used by
# registration / handler code elsewhere in the add-on (confirm against callers).
faceMapsIndexSync = ZsFaceMapsIndexSync()


class ZSUFM_Factory:
    """Entry point bundling the face-map set classes and their accessors."""

    # Classes belonging to the face-map sets feature; presumably registered
    # with Blender by the caller.
    classes = (
        ZSUFM_UL_List,
        ZSUFM_OT_AssignMaterialsToGroups,
        ZSUFMObjectListGroup,
        ZSUFMSceneListGroup
    )

    # Accessors are static so they work when called both on the class and on
    # an instance (the original bare functions failed on instance calls).

    @staticmethod
    def get_mgr():
        """Return the face-map layer manager class."""
        return ZsFaceMapsLayerManager

    @staticmethod
    def get_ui_list():
        """Return the scene-level list item property group class."""
        return ZSUFMSceneListGroup

    @staticmethod
    def get_obj_ui_list():
        """Return the object-level list item property group class."""
        return ZSUFMObjectListGroup
