# ##### BEGIN GPL LICENSE BLOCK #####
#
#  This program is free software; you can redistribute it and/or
#  modify it under the terms of the GNU General Public License
#  as published by the Free Software Foundation; either version 2
#  of the License, or (at your option) any later version.
#
#  This program is distributed in the hope that it will be useful,
#  but WITHOUT ANY WARRANTY; without even the implied warranty of
#  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
#  GNU General Public License for more details.
#
#  You should have received a copy of the GNU General Public License
#  along with this program; if not, write to the Free Software Foundation,
#  Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA.
#
# ##### END GPL LICENSE BLOCK #####

""" Zen Maps Groups System """
# blender
import bmesh
import bpy

from collections import defaultdict
import numpy as np

from .basic_sets import ZsLayerManager
from .draw_sets import check_update_cache, mark_groups_modified

from ..hash_utils import hash32
from ..blender_zen_utils import (
    ZenLocks, ZenPolls, ZenSelectionStats,
    update_areas_in_all_screens)
from ..vlog import Log


class ZsMapLayerManager(ZsLayerManager):
    """Group manager that stores group membership in one int hash layer.

    Every bmesh item returned by ``get_bm_items`` holds, in the int layer named
    ``layer_hash_name()``, the hash (``get_hash_from_str``) of the layer name of
    the group it belongs to. The value returned by ``get_bm_item_zero_val``
    marks an item that belongs to no group. Because each item stores a single
    int, an item can be a member of at most one group at a time.
    """

    @classmethod
    def get_bm_item_zero_val(self):
        """Return the hash-layer value that means "item is in no group"."""
        return 0

    @classmethod
    def is_bm_item_set(self, p_bm_item, p_layer_pair):
        """Return True if ``p_bm_item`` belongs to the group of ``p_layer_pair``.

        ``p_layer_pair`` is ``(hash_layer, group_hash)`` as returned by
        ``get_mesh_layer`` / ``ensure_mesh_layer``.
        """
        p_layer, idx = p_layer_pair
        return p_bm_item[p_layer] == idx

    @classmethod
    def set_bm_item(self, p_bm_item, p_layer_pair, p_val):
        """Add ``p_bm_item`` to the group (``p_val`` truthy) or remove it from it.

        ``p_layer_pair`` is ``(hash_layer, group_hash)``. Removing only resets
        the item when it currently carries this group's hash. An item that
        belongs to a different group keeps that membership, since the single
        hash slot is shared by all groups.
        """
        p_layer, idx = p_layer_pair
        if p_val:
            p_bm_item[p_layer] = idx
        elif p_bm_item[p_layer] == idx:
            p_bm_item[p_layer] = self.get_bm_item_zero_val()

    @classmethod
    def _map_group_hashes(self, p_obj, p_group_pairs):
        """Return ``{group_hash: layer_name}`` for the groups in ``p_group_pairs``.

        The zero value is removed because in the hash layer it means
        "no group", not a real group.
        """
        p_hashes = {
            self.get_hash_from_str(p_obj, p_group.layer_name): p_group.layer_name
            for _, p_group in p_group_pairs}
        p_hashes.pop(self.get_bm_item_zero_val(), None)
        return p_hashes

    @classmethod
    def get_bm_nonlayered_items(self, p_obj, bm, p_group_pairs):
        """Return the items of ``bm`` that belong to none of ``p_group_pairs``.

        Without a hash layer no item can be grouped, so every item is returned.
        """
        hash_layer = self.get_hash_layer(bm)
        if not hash_layer:
            return self.get_bm_items(bm)
        group_hashes = self._map_group_hashes(p_obj, p_group_pairs)
        return [item for item in self.get_bm_items(bm) if item[hash_layer] not in group_hashes]

    @classmethod
    def update_all_obj_groups_count(self, p_obj, no_lookup=False):
        """Recount members and hidden members of every group of ``p_obj``.

        Hashes found in the mesh that belong to a scene group missing from the
        object's list are restored into the object's list (with a warning).

        Args:
            p_obj: mesh object whose bmesh is inspected.
            no_lookup: when True, skip ``build_lookup_table`` even if counters
                changed.

        Returns:
            True if any group counter changed or a group was restored.
        """
        bm = self._get_bm(p_obj)
        layer = self.get_hash_layer(bm)
        b_need_update = False
        if layer:
            p_zero_val = self.get_bm_item_zero_val()

            # group hash -> list of ``item.hide`` flags of the items carrying it
            dic = defaultdict(list)
            for item in self.get_bm_items(bm):
                if item[layer] != p_zero_val:
                    dic[item[layer]].append(item.hide)

            for group in self.get_list(p_obj):
                p_hash = self.get_hash_from_str(p_obj, group.layer_name)
                # pop, so that whatever remains in ``dic`` afterwards is orphaned
                hides = dic.pop(p_hash, [])
                i_group_count = len(hides)
                i_hide_count = int(np.count_nonzero(hides))
                if i_group_count != group.group_count or i_hide_count != group.group_hide_count:
                    group.group_count = i_group_count
                    group.group_hide_count = i_hide_count
                    b_need_update = True

            if dic:
                # hashes present in the mesh but not in the object's list:
                # restore them if a scene group with that hash exists
                p_scene_list = self.get_list(bpy.context.scene)
                scene_groups = {
                    self.get_hash_from_str(p_obj, p_group.layer_name): p_group
                    for p_group in p_scene_list}
                for c_g_hash, hides in dic.items():
                    p_group = scene_groups.get(c_g_hash)
                    if p_group is None:
                        continue
                    mesh_layer_name = p_group.layer_name
                    Log.warn(f'Object:[{p_obj.name}] - Layer:[{mesh_layer_name}] will be restored!')
                    p_obj_group = self.ensure_group_in_object(p_obj, p_group)
                    p_obj_group.group_count = len(hides)
                    p_obj_group.group_hide_count = int(np.count_nonzero(hides))
                    b_need_update = True

            if b_need_update and not no_lookup:
                self.build_lookup_table(bpy.context)
        return b_need_update

    @classmethod
    def set_selection_to_new_group(self, p_obj, p_scene_group, i_start, i_end):
        """Assign the selected items with index in ``[i_start, i_end)`` to a group.

        Scanning stops as soon as all selected items (per
        ``get_selected_count``) have been assigned. This presumably relies on
        the group being new, so items after the last selected one cannot
        already carry its hash. TODO confirm with callers.
        """
        layerName = p_scene_group.layer_name

        me = p_obj.data
        bm = bmesh.from_edit_mesh(me)

        i_selected_count = self.get_selected_count(p_obj)

        # only create the hash layer when there is something to store in it
        if i_selected_count:
            mesh_layer = self.ensure_mesh_layer(p_obj, bm, layerName)
        else:
            mesh_layer = self.get_mesh_layer(p_obj, bm, layerName)
        if mesh_layer and i_selected_count != 0:
            hash_layer, p_bytes = mesh_layer
            p_zero_val = self.get_bm_item_zero_val()
            i_group_count = 0
            bm_items = self.get_bm_items(bm)
            for idx in range(i_start, i_end, 1):
                item = bm_items[idx]
                if item.select:
                    item[hash_layer] = p_bytes
                    i_group_count += 1
                    i_selected_count -= 1
                    if i_selected_count == 0:
                        break
                elif item[hash_layer] == p_bytes:
                    item[hash_layer] = p_zero_val

            if i_group_count:
                p_group = self.ensure_group_in_object(p_obj, p_scene_group)
                p_group.group_count = i_group_count
                p_group.group_hide_count = 0

            bm.select_flush_mode()
            bmesh.update_edit_mesh(me, loop_triangles=False, destructive=False)

    @classmethod
    def set_selection_to_group(self, p_obj, p_scene_group, indices):
        """Make ``p_scene_group`` contain exactly the chosen items of ``p_obj``.

        The chosen items are those whose ``index`` is in ``indices``. When
        ``indices`` is None they are the selected items, further limited to
        the UV selection when ``is_uv_area_and_not_sync()`` is true. Chosen
        items get the group hash. Unchosen items that carried this group's
        hash are reset to the zero value.
        """
        layerName = p_scene_group.layer_name

        me = p_obj.data
        bm = bmesh.from_edit_mesh(me)

        b_indices_mode = indices is not None
        b_is_uv = False
        uv_sel = None
        index_set = None
        if b_indices_mode:
            i_selected_count = len(indices)
            # set for O(1) membership tests in the per-item loop below
            index_set = set(indices)
        else:
            b_is_uv = self.is_uv_area_and_not_sync()
            if b_is_uv:
                uv_sel = self.fetch_uv_selections(bm)
                i_selected_count = len(uv_sel)
            else:
                i_selected_count = self.get_selected_count(p_obj)

        # only create the hash layer when there is something to store in it
        if i_selected_count:
            mesh_layer = self.ensure_mesh_layer(p_obj, bm, layerName)
        else:
            mesh_layer = self.get_mesh_layer(p_obj, bm, layerName)
        if mesh_layer:
            hash_layer, p_bytes = mesh_layer
            p_zero_val = self.get_bm_item_zero_val()
            i_group_count = 0
            for item in self.get_bm_items(bm):
                if b_indices_mode:
                    is_selected = item.index in index_set
                else:
                    is_selected = item.select and (not b_is_uv or (item.index in uv_sel))
                if is_selected:
                    item[hash_layer] = p_bytes
                    i_group_count += 1
                elif item[hash_layer] == p_bytes:
                    item[hash_layer] = p_zero_val

            if i_group_count:
                p_group = self.ensure_group_in_object(p_obj, p_scene_group)
                p_group.group_count = i_group_count
                p_group.group_hide_count = 0

            bm.select_flush_mode()
            bmesh.update_edit_mesh(me, loop_triangles=False, destructive=False)

    @classmethod
    def select_ungroupped(self, p_obj, p_group_pairs):
        """Select the items of ``p_obj`` that belong to none of ``p_group_pairs``.

        In the mesh context ``item.select`` is set to the result. In a UV
        editor without sync, only the UV loop selection is changed, and it can
        only be narrowed: loops of grouped items are deselected and nothing
        new is selected.
        """
        me = p_obj.data
        bm = bmesh.from_edit_mesh(me)

        group_hashes = self._map_group_hashes(p_obj, p_group_pairs)

        hash_layer = self.get_hash_layer(bm)
        b_is_uv = self.is_uv_area_and_not_sync()
        uv_layer = bm.loops.layers.uv.active
        for item in self.get_bm_items(bm):
            b_select = (hash_layer is None) or (item[hash_layer] not in group_hashes)
            if b_is_uv and uv_layer:
                for loop in self.get_item_loops(item):
                    loop[uv_layer].select = loop[uv_layer].select and b_select
                    if ZenPolls.version_greater_3_2_0:
                        loop[uv_layer].select_edge = loop[uv_layer].select
            else:
                item.select = b_select

        bm.select_flush_mode()
        bmesh.update_edit_mesh(p_obj.data, loop_triangles=False, destructive=False)

        if self.is_uv_force_update():
            mark_groups_modified(self, p_obj, modes={'UV'})
            check_update_cache(self, p_obj)
        ZenLocks.lock_depsgraph_update_one()

        update_areas_in_all_screens()

    @classmethod
    def _smart_apply_to_object(self, p_obj, is_present_fn, b_is_uv, b_is_face, selection_stats):
        """Apply a smart-select decision to every item of one edit-mode mesh.

        ``is_present_fn(item, hash_layer)`` returns whether the item must end
        up selected. ``hash_layer`` may be None when the mesh has no hash
        layer. ``selection_stats`` counters are updated in place.

        Returns:
            True if the mesh or UV selection of this object changed.
        """
        me = p_obj.data
        bm = bmesh.from_edit_mesh(me)
        hash_layer = self.get_hash_layer(bm)
        uv_layer = bm.loops.layers.uv.active

        selected_loops = set()
        unselected_loops = set()
        b_obj_selection_changed = False

        for item in self.get_bm_items(bm):
            is_present = is_present_fn(item, hash_layer)
            was_selected = item.select
            self._smart_process_item(
                item, b_is_uv, uv_layer, is_present,
                b_is_face, selected_loops, unselected_loops)

            if is_present:
                selection_stats.new_mesh_sel_count += 1

            if item.select != was_selected:
                b_obj_selection_changed = True

        if b_is_uv:
            # a loop collected as selected by any item wins over "unselected"
            for loop in selected_loops.union(unselected_loops):
                b_select = loop in selected_loops
                was_selected = loop[uv_layer].select
                loop[uv_layer].select = b_select
                if ZenPolls.version_greater_3_2_0 and b_is_face:
                    loop[uv_layer].select_edge = b_select
                # count only loops that end up selected, mirroring new_mesh_sel_count
                if b_select:
                    selection_stats.new_uv_sel_count += 1

                if was_selected != b_select:
                    b_obj_selection_changed = True

        if b_obj_selection_changed:
            selection_stats.selection_changed = True
            bm.select_flush_mode()
            bmesh.update_edit_mesh(me, loop_triangles=True, destructive=False)

        return b_obj_selection_changed

    @classmethod
    def smart_select(self, context, select_active_group_only, keep_active_group) -> ZenSelectionStats:
        """Extend the current selection to whole groups ("smart select").

        If no selected item belongs to one of the current groups, all
        ungrouped items are selected. Otherwise the items of every group the
        selection touches are selected, plus the ungrouped (zero-hash) items
        of objects whose selection contained ungrouped items. With
        ``select_active_group_only``, only the items of the chosen active
        group are selected. The scene list index and the "last smart select"
        group name are updated accordingly.

        Args:
            context: Blender context with meshes in edit mode.
            select_active_group_only: select only the chosen active group.
            keep_active_group: do not switch the active group to the group of
                the active (select history) element.

        Returns:
            ZenSelectionStats with the before/after counters filled in.
        """
        selection_stats = ZenSelectionStats()
        selection_stats.was_mesh_sel_count = self.get_context_selected_count(context)
        if selection_stats.was_mesh_sel_count == 0:
            return selection_stats

        p_scene = context.scene
        p_group_pairs = self.get_current_group_pairs(context)

        self.set_mesh_select_mode(context)

        b_is_uv = self.is_uv_area_and_not_sync()

        # the selection may touch several groups, so collect their layer names
        sel_mesh_layers = set()
        # object -> {group hash: layer name}; filled in the first pass only for
        # objects that have a hash layer (and, in UV mode, a UV selection)
        map_mesh_layer = dict()
        # object -> {group hash: layer name} restricted to touched groups
        map_obj_sel_layers = defaultdict(dict)
        # objects whose selection contains items of no current group
        map_non_sel = set()

        selection_stats.was_uv_sel_count = 0

        p_zero_val = self.get_bm_item_zero_val()

        p_mesh_objects = set(p_obj for p_obj in context.objects_in_mode_unique_data if p_obj.type == 'MESH')

        for p_obj in p_mesh_objects:
            bm = bmesh.from_edit_mesh(p_obj.data)

            uv_sel = None
            if b_is_uv:
                uv_sel = self.fetch_uv_selections(bm)
                i_cur_uv_selected = len(uv_sel)
                selection_stats.was_uv_sel_count += i_cur_uv_selected
                if i_cur_uv_selected == 0:
                    continue

            hash_layer = self.get_hash_layer(bm)
            if hash_layer:
                p_group_hashes = self._map_group_hashes(p_obj, p_group_pairs)
                map_mesh_layer[p_obj] = p_group_hashes

                for item in self.get_bm_items(bm):
                    if item.select and (not b_is_uv or (item.index in uv_sel)):
                        p_layer_name = p_group_hashes.get(item[hash_layer], None)
                        if p_layer_name:
                            sel_mesh_layers.add(p_layer_name)
                        else:
                            map_non_sel.add(p_obj)

        if b_is_uv and selection_stats.was_uv_sel_count == 0:
            return selection_stats

        for p_obj, p_hashes in map_mesh_layer.items():
            for p_hash, p_layer_name in p_hashes.items():
                if p_layer_name in sel_mesh_layers:
                    map_obj_sel_layers[p_obj][p_hash] = p_layer_name

        b_is_face = self.id_element == 'face'

        modified_objects_data = set()

        # 1) selected elements do not belong to any group
        if len(sel_mesh_layers) == 0:
            for p_obj in p_mesh_objects:
                p_group_hashes = map_mesh_layer.get(p_obj)
                if p_group_hashes is None:
                    # in UV mode the first pass skips objects without a UV
                    # selection, so their group hashes were never collected
                    p_group_hashes = self._map_group_hashes(p_obj, p_group_pairs)

                def is_present(item, hash_layer, p_hashes=p_group_hashes):
                    return (hash_layer is None) or (item[hash_layer] not in p_hashes)

                if self._smart_apply_to_object(p_obj, is_present, b_is_uv, b_is_face, selection_stats):
                    modified_objects_data.add(p_obj.data)

            self.set_list_index(p_scene, -1)
            self._do_set_last_smart_select('')
        # 2) selection touches at least one group
        else:
            sel_active_layer_name = ''
            sel_last_layer_name = sel_active_layer_name
            p_group_pair = self.get_current_group_pair(context)
            p_scene_layer_name = p_group_pair[1].layer_name if p_group_pair else ''
            # check if we have a group present in selection to keep it remaining
            if p_scene_layer_name in sel_mesh_layers:
                sel_active_layer_name = p_scene_layer_name
                sel_last_layer_name = sel_active_layer_name

            # by default we will select first group
            if sel_active_layer_name == '':
                sel_active_layer_name = next(iter(sel_mesh_layers))
                sel_last_layer_name = sel_active_layer_name

            # the group of the active (select history) element takes priority
            if context.active_object:
                p_obj = context.active_object
                bm = self._get_bm(p_obj)
                items = self.get_bm_items(bm)
                hash_layer = self.get_hash_layer(bm)
                if (hash_layer is not None) and len(items):
                    active_element = bm.select_history.active
                    if active_element and (type(items[0]) is type(active_element)):
                        p_active_element_hash = active_element[hash_layer]
                        p_layer_name = map_obj_sel_layers[p_obj].get(p_active_element_hash, None)
                        if p_layer_name:
                            if not keep_active_group:
                                sel_active_layer_name = p_layer_name
                            sel_last_layer_name = p_layer_name

            for p_obj in p_mesh_objects:
                has_nongroupped = p_obj in map_non_sel

                p_sel_active_hash = p_zero_val
                if sel_active_layer_name:
                    p_sel_active_hash = self.get_hash_from_str(p_obj, sel_active_layer_name)

                p_sel_hashes = map_obj_sel_layers[p_obj].keys()

                if select_active_group_only:
                    def is_present(item, hash_layer, p_active_hash=p_sel_active_hash):
                        return (hash_layer is not None) and (
                            item[hash_layer] != p_zero_val and
                            item[hash_layer] == p_active_hash)
                else:
                    def is_present(item, hash_layer, p_hashes=p_sel_hashes, b_nongroupped=has_nongroupped):
                        return (
                            ((hash_layer is not None) and (item[hash_layer] in p_hashes)) or
                            (b_nongroupped and (hash_layer is None or item[hash_layer] == p_zero_val))
                        )

                if self._smart_apply_to_object(p_obj, is_present, b_is_uv, b_is_face, selection_stats):
                    modified_objects_data.add(p_obj.data)

            # point the scene list at the active group and remember the last one
            for i, g in p_group_pairs:
                layer_name = g.layer_name
                if sel_active_layer_name:
                    if sel_active_layer_name == layer_name:
                        self.set_list_index(p_scene, i)
                        if sel_last_layer_name == sel_active_layer_name:
                            self._do_set_last_smart_select(layer_name)
                            sel_last_layer_name = None
                        sel_active_layer_name = None
                if sel_last_layer_name:
                    if layer_name == sel_last_layer_name:
                        self._do_set_last_smart_select(layer_name)
                if sel_active_layer_name is None and sel_last_layer_name is None:
                    break

        if selection_stats.selection_changed:
            if not b_is_uv and self.is_uv_force_update():
                for p_obj in context.objects_in_mode:
                    if p_obj.data in modified_objects_data:
                        mark_groups_modified(self, p_obj, modes={'UV'})
                        check_update_cache(self, p_obj)
            ZenLocks.lock_depsgraph_update_one()

        return selection_stats

    @classmethod
    def ensure_mesh_layer(self, p_obj, p_bm, layerName):
        """Return ``(hash_layer, group_hash)`` for ``layerName``, creating the hash layer if needed."""
        p_layer = self.get_mesh_layer(p_obj, p_bm, layerName)
        if p_layer is None:
            p_hash_layer = self.ensure_hash_layer(p_bm)
            idx = self.get_hash_from_str(p_obj, layerName)
            p_layer = (p_hash_layer, idx)
        return p_layer

    @classmethod
    def get_mesh_layer(self, p_obj, p_bm, layerName):
        """Return ``(hash_layer, group_hash)`` for ``layerName``, or None without a hash layer."""
        hash_layer = self.get_hash_layer(p_bm)
        if hash_layer:
            return (hash_layer, self.get_hash_from_str(p_obj, layerName))
        else:
            return None

    @classmethod
    def get_hash_from_str(self, p_obj, p_str):
        """Return the value stored in the hash layer for group layer name ``p_str``.

        ``p_obj`` is unused here; it is part of the signature presumably so
        overrides can hash per object.
        """
        return hash32(p_str)

    @classmethod
    def get_hash_layer(self, p_bm):
        """Return the int hash layer of the managed bmesh items, or what ``_get_mesh_layer`` yields when absent."""
        return self._get_mesh_layer(self.get_bm_items(p_bm).layers.int, self.layer_hash_name())

    @classmethod
    def ensure_hash_layer(self, p_bm):
        """Return the int hash layer of the managed bmesh items, creating it if missing."""
        return self._ensure_mesh_layer(self.get_bm_items(p_bm).layers.int, self.layer_hash_name())

    @classmethod
    def remove_mesh_layer(self, p_obj, p_bm, layerName, cleanup=False):
        """Unassign every item of group ``layerName`` when ``cleanup`` is True.

        The shared int hash layer itself is not removed here. Without
        ``cleanup`` this method does nothing.
        """
        if cleanup:
            hash_layer = self.get_hash_layer(p_bm)
            if hash_layer:
                p_zero_val = self.get_bm_item_zero_val()
                # same hash function that ensure_mesh_layer used to write it
                p_bytes = self.get_hash_from_str(p_obj, layerName)
                for item in self.get_bm_items(p_bm):
                    if item[hash_layer] == p_bytes:
                        item[hash_layer] = p_zero_val

    @classmethod
    def _do_cleanup_object(self, obj):
        """Run the base cleanup, then drop this manager's int and string hash layers."""
        super()._do_cleanup_object(obj)

        bm = self._get_bm(obj)
        self._remove_mesh_layer(self.get_bm_items(bm).layers.int, self.layer_hash_name())
        self._remove_mesh_layer(self.get_bm_items(bm).layers.string, self.layer_hash_name())

        bmesh.update_edit_mesh(obj.data, loop_triangles=False, destructive=False)

    @classmethod
    def import_vertex_colors_to_zen_groups(self, p_scene, p_obj, b_active, p_ignored_color, s_prefix):
        """Create or refresh groups from the distinct RGB colors of color layers.

        For each color layer of ``p_obj`` (only the one whose name equals the
        active layer name with ``s_prefix`` stripped, when ``b_active``), every
        distinct RGB value becomes a group. The group is named after the
        layer, with a `` [N]`` suffix when the layer has several colors. An
        item joins the group when any of its loops has that color. The RGB
        value equal to ``p_ignored_color`` (compared channel by channel on
        ``.r/.g/.b``) is skipped. The scene list index is set to the last
        processed group.
        """
        bm = self._get_bm(p_obj)
        p_scene_list = self.get_list(p_scene)
        bm_items = self.get_bm_items(bm)
        p_zero_val = self.get_bm_item_zero_val()
        i_group_index = -1
        act_name = bm.loops.layers.color.active.name if bm.loops.layers.color.active else ''
        act_name = act_name.replace(s_prefix, '', 1)

        # vertices and edges reach their loops via link_loops, faces via loops
        b_use_link_loops = self.id_element in ('vert', 'edge')

        def item_loops(item):
            return item.link_loops if b_use_link_loops else item.loops

        def loop_rgb(loop, color_layer):
            col = loop[color_layer]
            return (col[0], col[1], col[2])

        def is_ignored(rgb):
            # skip only the exact ignored color; the previous per-channel
            # "!= and !=" test also dropped colors sharing a single channel
            return (p_ignored_color is not None and
                    rgb[0] == p_ignored_color.r and
                    rgb[1] == p_ignored_color.g and
                    rgb[2] == p_ignored_color.b)

        for p_vert_color_layer in bm.loops.layers.color.values():
            if b_active and p_vert_color_layer.name != act_name:
                continue

            vc_name = p_vert_color_layer.name

            colors = set()
            for item in bm_items:
                for loop in item_loops(item):
                    rgb = loop_rgb(loop, p_vert_color_layer)
                    if not is_ignored(rgb):
                        colors.add(rgb)

            b_numbered = len(colors) > 1
            for i, color in enumerate(colors):
                group_name = vc_name
                if b_numbered:
                    group_name += f' [{i+1}]'

                i_group_index = self._index_of_group_name(p_scene_list, group_name)
                if i_group_index == -1:
                    s_layer_name = self.create_unique_layer_name()
                    i_group_index = self.add_layer_to_list(p_scene, s_layer_name, color)
                    p_scene_list[i_group_index].name = group_name

                    i_obj_group_index = self.add_layer_to_list(p_obj, s_layer_name, color)
                    p_obj_list = self.get_list(p_obj)
                    p_obj_list[i_obj_group_index].name = group_name
                else:
                    self.ensure_group_in_object(p_obj, p_scene_list[i_group_index])

                if p_scene_list[i_group_index].group_color != color:
                    p_scene_list[i_group_index].group_color = color

                hash_layer = self.ensure_hash_layer(bm)
                p_bytes = self.get_hash_from_str(p_obj, p_scene_list[i_group_index].layer_name)
                for item in bm_items:
                    has_color = any(
                        loop_rgb(loop, p_vert_color_layer) == color
                        for loop in item_loops(item))

                    if has_color:
                        item[hash_layer] = p_bytes
                    elif item[hash_layer] == p_bytes:
                        item[hash_layer] = p_zero_val

        self.update_all_obj_groups_count(p_obj)

        self.set_list_index(p_scene, i_group_index)