#!/bin/python3
import sys,re, math
from copy import deepcopy
import numpy as np
from pprint import pprint
sys.path.insert(0, '../../')
from fred import list2int, toGrid, nprint,get_re

# Name of the puzzle-input file to read the enhancement rules from.
input_f = "test"

# Which puzzle part to run (1 or 2).
part = 1
#########################################
#                                       #
#              Part 1                   #
#                                       #
#########################################

# Enhancement rules (flattened input pattern -> flattened output pattern);
# populated by the Part 1 driver via toSet().
instructions = {}

# Starting 3x3 pixel grid:  .#.  /  ..#  /  ###
grid = [list(row) for row in ('.#.', '..#', '###')]

def toSet(input, regex):
    """Parse the rule file into a dict of input pattern -> output pattern.

    Each non-blank line is matched with ``get_re(regex, line)``; capture
    group 1 is the input pattern and group 2 the output pattern. The '/'
    row separators are removed from both so the keys compare directly
    against a grid flattened row by row.

    Args:
        input: path of the rule file to read.
        regex: pattern with two capture groups (input, output).

    Returns:
        dict mapping flattened input pattern -> flattened output pattern.

    Raises:
        ValueError: if a non-blank line does not match *regex*.
    """
    rules = {}  # renamed from ``set`` to avoid shadowing the builtin
    with open(input, encoding='utf-8') as handle:
        for raw in handle:
            line = raw.rstrip()
            if not line:
                # Skip blank lines (e.g. a trailing empty line) instead of
                # crashing on a failed match.
                continue
            match = get_re(regex, line)
            # NOTE(review): assumes get_re returns an re.Match or None on no
            # match — confirm against fred.get_re.
            if match is None:
                raise ValueError(f'unparseable rule line: {line!r}')
            rules[match.group(1).replace('/', '')] = match.group(2).replace('/', '')
    return rules
    
def need2split(grid):
    """Return True when *grid* has more than 3 rows and must be split into blocks.

    2x2 and 3x3 grids are looked up directly; anything larger is split first.
    """
    return len(grid) > 3
    
def grid2line(grid):
    """Flatten *grid* into a single string by joining its rows top to bottom.

    Each row may be a list of single characters, a string, or a NumPy row of
    strings; the cells are concatenated without any separator.
    """
    return ''.join(''.join(row) for row in grid)

def line2grid(line):
    """Turn a flattened pattern string back into a square list-of-lists grid.

    The side length is the (truncated) square root of ``len(line)``; each row
    is a list of single-character strings.
    """
    side = int(math.sqrt(len(line)))
    return [list(line[r * side:(r + 1) * side]) for r in range(side)]

def rotate90(grid):
    """Return a copy of *grid* rotated 90 degrees counter-clockwise.

    The result is a NumPy array (``np.rot90`` with its default k=1); the
    input grid is not modified because ``np.array`` copies it first.
    """
    as_array = np.array(grid)
    return np.rot90(as_array)

def splitGrid(grid):
    """Split a square grid into 2x2 blocks if its side is even, else 3x3 blocks.

    Args:
        grid: square grid whose rows support slicing (lists of characters,
            strings, or NumPy rows).

    Returns:
        List of blocks in row-major order (left to right, then top to
        bottom). Each block is a list of rows, each row a list of cells.

    Raises:
        ValueError: if the side length is divisible by neither 2 nor 3.
            Previously this printed a message and then crashed inside
            ``range()`` with a step of 0.
    """
    size = len(grid)
    if size % 2 == 0:
        block_size = 2
    elif size % 3 == 0:
        block_size = 3
    else:
        raise ValueError(f'Grid side {size} is not divisible by 2 or 3 in splitGrid()')

    return [
        # list() normalises string / NumPy row slices into lists of cells.
        [list(row[left:left + block_size]) for row in grid[top:top + block_size]]
        for top in range(0, size, block_size)
        for left in range(0, size, block_size)
    ]

def findInst(grid, instructions):
    """Return the enhancement output for *grid*, trying every rotation and flip.

    The grid is looked up as given and then after each 90-degree rotation.
    After that, the same four lookups are made for its left-right mirror
    image. Together these eight variants cover every rotation/flip of a
    square, so a rule stored in any orientation is found. The original
    version only tried rotations and looped forever when a match needed a
    flip.

    Args:
        grid: square grid of single-character cells (list of lists, list of
            strings, or 2-D NumPy array).
        instructions: dict mapping flattened input pattern -> output pattern.

    Returns:
        The output pattern string stored in *instructions*.

    Raises:
        KeyError: if no orientation of *grid* is present in *instructions*.
    """
    # list(row) accepts list, str and NumPy rows alike, so np.array always
    # builds a 2-D array that np.rot90 can handle.
    rows = [list(row) for row in grid]
    if rows:
        candidate = np.array(rows)
        for _ in range(2):          # pass 1: as given; pass 2: mirrored
            for _ in range(4):      # the four 90-degree rotations
                key = ''.join(''.join(row) for row in candidate)
                if key in instructions:
                    return instructions[key]
                candidate = np.rot90(candidate)
            # Four rotations bring us back to the start; now mirror it.
            candidate = np.fliplr(candidate)
    pattern = '/'.join(''.join(row) for row in rows)
    raise KeyError(f'no rule matches grid {pattern}')

def transform_grid(input_grid):
    """Stitch a row-major list of square blocks back into one grid.

    The original only handled exactly 4 blocks of 3 rows each. This version
    handles any N*N blocks of equal size, e.g. 9 blocks from a split 6x6
    grid. For the 4-block, 3-row case it returns the same result as before.

    Args:
        input_grid: list of N*N equally sized square blocks in row-major order
            (the order splitGrid() produces). Each block is a sequence of
            rows; each row is a sequence of single characters (list or str).

    Returns:
        List of row strings for the combined grid (empty list if there are
        no blocks).

    Raises:
        ValueError: if the number of blocks is not a perfect square.
    """
    count = len(input_grid)
    if count == 0:
        return []
    per_side = int(round(math.sqrt(count)))  # blocks along each edge
    if per_side * per_side != count:
        raise ValueError(f'cannot arrange {count} blocks into a square')
    block_rows = len(input_grid[0])  # rows in every block

    result = []
    for start in range(0, count, per_side):
        band = input_grid[start:start + per_side]  # one horizontal strip of blocks
        for r in range(block_rows):
            # Concatenate row r of every block in this strip.
            result.append(''.join(''.join(block[r]) for block in band))
    return result

if part == 1:
    instructions = toSet(input_f, r"^(.*) => (.*)$")

    print(instructions)
    iterations = 5  # number of enhancement steps to apply
    for _ in range(iterations):
        size = len(grid)
        nprint(grid)
        print()
        if size % 3 == 0 or size % 2 == 0:
            if not need2split(grid):
                # 2x2 / 3x3: enhance the whole grid with a single lookup.
                grid = line2grid(findInst(grid, instructions))
            else:
                # Larger grids: enhance every block independently, then
                # stitch the enhanced blocks back together. (The former
                # debug print + blocking input() pause has been removed.)
                blocks = splitGrid(grid)
                mixed_grid = [line2grid(findInst(block, instructions)) for block in blocks]
                grid = transform_grid(mixed_grid)
        else:
            print('Something is wrong with the grid size of ', size)

    # Report how many cells are '#' in the final grid (presumably the puzzle
    # answer for part 1 — confirm).
    print(grid2line(grid).count('#'))
#########################################
#                                       #
#              Part 2                   #
#                                       #
#########################################
if part == 2:
    # Part 2 currently just terminates the script. sys.exit() is used
    # because the bare exit() helper comes from the site module, which is
    # not guaranteed to be loaded (e.g. python -S).
    sys.exit()