from enum import IntEnum, _EnumDict
from dataclasses import dataclass
from typing import Optional, List
from math import floor

class FusionLightEffect(IntEnum):
    """Lighting effect ids used in ``FusionLightData.fusion_effect``.

    Integer ids that are not named here are still accepted by
    ``FusionLightEffect(value)``; see ``_missing_``.
    """
    Static = 1
    Breathing = 2
    Wave = 3
    Fadeonkeypress = 4
    Marquee = 5
    Ripple = 6
    Flashonkeypress = 7
    Neon = 8
    Rainbowmarquee = 9
    Raindrop = 10
    Circlemarquee = 11
    Hedge = 12
    Rotate = 13
    Custom1 = 51
    Custom2 = 52
    Custom3 = 53
    Custom4 = 54
    Custom5 = 55

    # Accept plain ints that are not named members (e.g. "technically correct"
    # profile ids) instead of raising ValueError. Enum members cannot simply be
    # fabricated with object.__new__, and an enum that already has members
    # cannot normally be subclassed, so this leans on CPython enum internals
    # (_EnumDict, _member_names_) and may need revisiting on new Python releases.
    @classmethod
    def _missing_(cls, value):
        """Return a pseudo-member whose value is *value*.

        Each call builds a fresh subclass ``FusionLightEffectCustom`` holding a
        single member ``Unknown`` with the given value. Two lookups of the same
        unknown id therefore compare equal (IntEnum compares as int) but are
        not the same object.

        Any exception raised while creating the member propagates, e.g. for a
        value that ``int`` cannot convert. ``FusionLightEffect`` itself is
        always left intact.
        """
        # The enum metaclass refuses to subclass an enum that has members
        # (_check_for_existing_members); hide them only while the subclass is built.
        backup = FusionLightEffect._member_names_
        FusionLightEffect._member_names_ = []
        try:
            enum_dict = _EnumDict()
            enum_dict._cls_name = cls
            enum_dict['Unknown'] = value
            member = type('FusionLightEffectCustom', (FusionLightEffect,), enum_dict).Unknown
        finally:
            # Always restore. Otherwise a failed lookup would leave
            # FusionLightEffect looking empty (len() == 0, nothing to iterate).
            FusionLightEffect._member_names_ = backup
        return member

class FusionLightColor(IntEnum):
    """Preset colour codes (0-8) stored in ``FusionLightData.fusion_color``."""

    Black = 0
    Red = 1
    Green = 2
    Yellow = 3
    Blue = 4
    Orange = 5
    Purple = 6
    White = 7
    Random = 8  # highest preset code

class IoneLightDirection(IntEnum):
    """Direction codes numbered from 0, with six distinct values (no aliases).

    Note the numbering differs from ``FusionLightDirection``: here Up2Down (2)
    comes before Down2Up (3).
    """

    # linear sweeps
    Left2Right = 0
    Right2Left = 1
    Up2Down = 2
    Down2Up = 3
    # rotations
    Clockwise = 4
    AntiClockwise = 5

class FusionLightDirection(IntEnum):
    """Direction codes stored in ``FusionLightData.fusion_direction``.

    Numbering starts at 1 and lists Down2Up before Up2Down, unlike
    ``IoneLightDirection`` (0-based, Up2Down before Down2Up).

    NOTE: ``Clockwise`` and ``AntiClockwise`` reuse the values 1 and 2, so Enum
    makes them aliases of ``Left2Right`` and ``Right2Left``. In particular
    ``FusionLightDirection.Clockwise is FusionLightDirection.Left2Right``, and
    ``FusionLightDirection(1).name == 'Left2Right'``. Iterating the enum yields
    only the four linear directions.
    NOTE(review): presumably the device reads the same code as linear or
    rotational depending on the effect. Confirm before relying on member names.
    """
    Left2Right = 1
    Right2Left = 2
    Down2Up = 3
    Up2Down = 4
    Clockwise = 1  # alias of Left2Right (same value)
    AntiClockwise = 2  # alias of Right2Left (same value)


@dataclass
class FusionLightData:
    """One set of lighting-effect settings, held as a plain record.

    Fields are positional in this order: effect, speed, brightness, colour,
    direction.
    """

    fusion_effect: FusionLightEffect
    # Observed to act like a duration rather than a speed: smaller values
    # make the effect run faster.
    fusion_speed: int
    fusion_brightness: int
    fusion_color: FusionLightColor
    fusion_direction: FusionLightDirection

@dataclass
class RawInputDevice:
    """Raw-input device registration record.

    NOTE(review): the field names match the Win32 ``RAWINPUTDEVICE`` structure
    (usage page, usage, flags, target window handle). Presumably this record
    is converted to that struct elsewhere; confirm against the caller.
    """

    usUsagePage: int
    usUsage: int
    dwFlags: int
    # Optional window handle; left as None when the caller does not supply one.
    hwndTarget: Optional[int] = None

@dataclass
class RGB:
    """A colour as three integer channels, in red, green, blue order."""

    red: int
    green: int
    blue: int

@dataclass
class PictureMatrix:
    pixels: List[RGB]

    def to_bytes(self, adjust: RGB = None) -> bytes:
        if len(self.pixels) not in range(106):
            raise TypeError('Too many pixels!')
        arr = bytearray(512)
        for i, pixel in enumerate(self.pixels):
            idx = KEY_MATRIX_INDEX_VALUES[i]*4
            if adjust:
                arr[idx:idx+4] = [0, (pixel.red * adjust.red) // 255, (pixel.green * adjust.green) // 255, (pixel.blue * adjust.blue) // 255]
            else:
                arr[idx:idx+4] = [0, pixel.red, pixel.green, pixel.blue]
        return arr

    @classmethod
    def from_bytes(cls, data: bytes) -> 'PictureMatrix':
        if len(data) != 512:
            raise TypeError('Invalid payload size!')
        #the nones should be gone after the loop since KEY_MATRIX_INDEX_VALUES should have all the positions for 105 keys regardless of color
        pixels = [None]*105
        for chunk, idx in enumerate(range(0, len(data), 4)):
            if chunk in KEY_MATRIX_INDEX_VALUES:
                #apparently there are duplicates in the mapping so we need to loop through it
                start = 0
                try:
                    while True:
                        key_idx = KEY_MATRIX_INDEX_VALUES.index(chunk, start)
                        start = key_idx + 1
                        pixels[key_idx] = RGB(*data[idx+1:idx+4])
                except ValueError:
                    continue
        return PictureMatrix(pixels)

    @classmethod
    def pixel_matrix_to_keys(cls, list: List[RGB]) -> 'PictureMatrix':
        """Turns a proper 19x6 matrix of RGB values into the keyboard matrix by averaging the color over larger keys"""
        if len(list) != 19*6:
            raise TypeError('Has to be a 19x6 matrix!')
        
        mat = [list[i:i+19] for i in range(0, 19*6, 19)]
        
        def avg(ints: List[int]):
            return sum(ints) // len(ints)

        def merge(pos: int, length: int, pixels: List[RGB]) -> RGB:
            curr_length = 1 - (pos - floor(pos))  #how much of the key is actually in the first pixel location
            mat_pos = floor(pos)   #first pixel location
            
            length -= curr_length

            r, g, b = [], [], []
            while (curr_length + length) > 0:
                #print(mat_pos, length, curr_length)
                #apply weight
                r.append(pixels[mat_pos].red * curr_length)
                g.append(pixels[mat_pos].green * curr_length)
                b.append(pixels[mat_pos].blue * curr_length)
                #either it takes a full pixel, or a portion of the pixel if nothing else is left
                curr_length = min(1, length)
                length -= curr_length
                mat_pos += 1

            return RGB(round(avg(r)), round(avg(g)), round(avg(b)))

        key_pixels = []

        for pixels, weights in zip(mat, KEY_WEIGHTS):
            pos = 0
            for weight in weights:
                key_pixels.append(merge(pos, weight, pixels))
                pos += weight

        #account for the 2 vertical keys
        def merge_vertical(*pos: int):
            p1, p2 = key_pixels[pos[0]], key_pixels[pos[1]], 

            key_pixels[pos[1]] = RGB(avg([p1.red, p2.red]), avg([p1.green, p2.green]), avg([p1.blue, p2.blue]))
            #first position is the one to be removed according to KEY_TITLES mapping
            key_pixels.pop(pos[0])


        #note the order - it has to be from top of list to bottom of list to avoid repositioning
        merge_vertical(88, 102)  #numpad enter
        merge_vertical(54, 71)  #numpad plus
        
        return PictureMatrix(key_pixels)
        


# Payload slot of each key, in KEY_TITLES order. PictureMatrix.to_bytes writes
# key i's colour as [0, r, g, b] at bytes slot*4 .. slot*4+3 of the 512-byte
# picture payload. There are 105 entries, the same as len(KEY_TITLES).
# Two slots are shared: 92 at indices 66 (btnEnter) and 102 (btnEnterUk), and
# 7 at indices 71 (btnLshift) and 103 (btnLshiftUk).
KEY_MATRIX_INDEX_VALUES = [
    11, 17, 23, 29, 35, 41, 47, 53, 59, 65,
    71, 77, 83, 89, 95, 101, 107, 113, 119, 10,
    16, 22, 28, 34, 40, 46, 52, 58, 64, 70,
    76, 82, 94, 100, 106, 112, 118, 9, 15, 21,
    27, 33, 39, 45, 51, 57, 63, 69, 75, 81,
    87, 99, 105, 111, 8, 14, 20, 26, 32, 38,
    44, 50, 56, 62, 68, 74, 92, 98, 104, 110,
    116, 7, 19, 25, 31, 37, 43, 49, 55, 61,
    67, 73, 85, 91, 97, 103, 109, 6, 12, 18,
    24, 42, 60, 66, 72, 84, 90, 96, 102, 108,
    114, 86, 92, 7, 13
]

# Key names in the same order as KEY_MATRIX_INDEX_VALUES and PictureMatrix.pixels.
# Each line is one keyboard row. The last line holds the keys suffixed "Uk"
# (presumably UK/ISO-layout variants). PictureMatrix.pixel_matrix_to_keys
# produces only the first 101 entries, i.e. everything before that line.
KEY_TITLES = [
    "btnEsc", "btnF1", "btnF2", "btnF3", "btnF4", "btnF5", "btnF6", "btnF7", "btnF8", "btnF9", "btnF10", "btnF11", "btnF12", "btnPause", "btnDel", "btnHome", "btnPgUp", "btnPgDn", "btnEnd", 
    "btnGrave", "btn1", "btn2", "btn3", "btn4", "btn5", "btn6", "btn7", "btn8", "btn9", "btn0", "btnHyphen", "btnEqual", "btnBackspace", "btnNumLk", "btnSlash2", "btnAsterisk", "btnMinus", 
    "btnTab", "btnQ", "btnW", "btnE", "btnR", "btnT", "btnY", "btnU", "btnI", "btnO", "btnP", "btnLsquarebracket", "btnRsquarebracket", "btnBackslash", "btn_7", "btn_8", "btn_9", 
    "btnCapsLock", "btnA", "btnS", "btnD", "btnF", "btnG", "btnH", "btnJ", "btnK", "btnL", "btnSemicolon", "btnApostrophe", "btnEnter", "btn_4", "btn_5", "btn_6", "btnPlus", 
    "btnLshift", "btnZ", "btnX", "btnC", "btnV", "btnB", "btnN", "btnM", "btnComma", "btnFullstop", "btn_Slash", "btnRshift", "btnUp", "btn_1", "btn_2", "btn_3", 
    "btnLctrl", "btnFn", "btnWin", "btnLalt", "btnSpace", "btnRalt", "btnApp", "btnRctrl", "btnLeft", "btnDown", "btnRight", "btn_0", "btn_Del", "btnEnter2", 
    "btnSharpUk", "btnEnterUk", "btnLshiftUk", "btnSlashUk"
]

# Width of each key, in cells of the 19-column grid used by
# PictureMatrix.pixel_matrix_to_keys. There is one list per grid row, and each
# row sums to 19.
# Rows 2 and 4 (0-based) end with an extra entry for the upper half of the
# two-row numpad Plus / Enter keys. pixel_matrix_to_keys folds that entry into
# the key below it (merge_vertical).
KEY_WEIGHTS = [
    [1,    1,    1,    1,    1,    1,    1,    1,    1,    1,    1,    1,    1,    1,    1,    1,    1,    1,    1],
    [1,    1,    1,    1,    1,    1,    1,    1,    1,    1,    1,    1,    1,          2,    1,    1,    1,    1],
    [1.5,     1,    1,    1,    1,    1,    1,    1,    1,    1,    1,    1,    1,     1.5,    1,    1,    1,    1], 
    [1.8,       1,    1,    1,    1,    1,    1,    1,    1,    1,    1,    1,         2.2,    1,    1,    1,    1],
    [2.3,         1,    1,    1,    1,    1,    1,    1,    1,    1,    1,        1.7,   1,    1,    1,    1,    1],
    [1.2,   1,    1,    1,                            5.2,    1,    1,     1.6,    1,    1,    1,    1,    1,    1],
]