from threading import Thread

from PIL import Image, ImageDraw
from PIL.ImageQt import ImageQt
from PyQt5.Qt import Qt
from queue import Queue

from lib import constants
import numpy as np
from parallel.computationprocess import ComputationProcess
from parallel.pipe import Pipe
from util import renderutil
from util.colormaps import CenteredColorMap


class _Action():
    def __init__(self, cType, params):
        """
        @param cType: computation-type. One of: "forward", "inversion".
        @param params: the parameters for this action
        """
        self.cType = cType
        self.params = params


class ComputationManager(Thread):
    """
    Daemon thread that forwards computation requests to a ComputationProcess
    and converts the raw data it sends back into images.

    Requests are issued with solveForward() / solveInversion(). Each request
    immediately returns an (initially empty) output container that is filled
    by this thread (see run()) while the results arrive over the connection.
    """

    # value ranges (m, M) used to normalize the wave data before color mapping
    _FORWARD_WAVE_RANGE = (-0.003, 0.005)
    _BACKWARD_WAVE_RANGE = (-1.5e-5, 1.1e-5)

    def __init__(self, gridSize):
        """
        creates the computation process and starts it.
        The manager thread itself is not started here.
        @param gridSize: passed on unchanged to the ComputationProcess
        """
        Thread.__init__(self)
        self.daemon = True
        # create a connection between this thread and the computationProcess
        self._connection, childConnection = Pipe()
        self._computationProcess = ComputationProcess(childConnection, gridSize)

        # the buffer for the output containers for the main thread
        self._buffer = Queue()

        # start executing the computation process
        self._computationProcess.start()

    def getProcess(self):
        """
        @return: the computation process owned by this manager
        """
        return self._computationProcess

    def terminateProcess(self):
        """
        terminates the computation process.
        """
        self._computationProcess.terminate()

    # maybe re-add sourceFunction, wavelet_type-parameter one day (probably never)
    def solveForward(self, material, emitterPos, receiverPos):
        """
        requests to start solving the forward problem.
        @param material: the material
        @param emitterPos: the position of the emitter (x, y)
        @param receiverPos: the positions of the receivers ((x, y), (x, y), ...)
        @return: a "future"-list, where each element has the following structure:
                (waveImage, seismogramImage, seismogramData)
        """
        action = _Action("forward", (material, emitterPos, receiverPos))
        out = []

        # register the output container before the request is sent
        self._buffer.put(("forward", out))
        # sends the request to the computationProcess
        self._connection.send(action)

        return out

    def solveInversion(self, initialMaterial=None):
        """
        requests to start solving the inversion problem.
        @keyword initialMaterial: the initial material for the inversion. If None, continue the inversion with the old one.
        @return: a "future"-queue. Each iteration puts the following elements:
                one dictionary for the forward solve: {"type": "forward", "frames": []}, where "frames" is a "future"-list of
                (waveImage, seismoImage, seismoImageGreen, seismoData)-tuples. Both seismo images are currently always None.
                one dictionary for the backward solve: {"type": "backward", "frames": []}, where "frames" is a "future"-list of (waveImage,)-tuples.
                one dictionary for the material update: {"type": "materialUpdate", "material": materialImage}.
                NOTE: no image of the initial material is put into the queue by this class.
        """
        action = _Action("inversion", (initialMaterial,))
        out = Queue()

        # register the output container before the request is sent
        self._buffer.put(("inversion", out))
        # sends the request to the computationProcess
        # TODO: this blocks with a standard Pipe, why?
        self._connection.send(action)

        return out

    def reset(self):
        """
        stop the computation and reset to default values
        """
        # TODO: implement this; hard reset (Exceptions)

    def run(self):
        """
        thread main loop: takes the registered requests one by one (in the
        order they were issued) and fills their output containers.
        Entries with an unknown computation-type are ignored.
        """
        while True:
            # wait until we receive the type of the next computation
            cType, out = self._buffer.get()

            if cType == "forward":
                self._processForwardData(out)
            elif cType == "inversion":
                self._processInversionData(out)

    def _processForwardData(self, out):
        """
        receives constants.FRAME_COUNT frames of a forward solve and appends
        one (waveImage, seismoImage, seismoData)-tuple per frame.
        @param out: the "future"-list returned by solveForward()
        """
        m, M = self._FORWARD_WAVE_RANGE
        colormap = self._makeColormap(m, M)

        for _ in range(constants.FRAME_COUNT):
            waveData, seismoData = self._connection.recv()
            waveImage = self._waveDataToImage(waveData, m, M, colormap)
            seismoImage = self._seismoDataToImage(seismoData)

            out.append((waveImage, seismoImage, seismoData))
        # function-call form is valid on Python 2 and Python 3
        print("finished processing forward data")

    def _processInversionData(self, out):
        """
        processes constants.MAX_INV_ITERATIONS inversion iterations, each one
        consisting of a forward solve, a backward solve and a material update.
        @param out: the "future"-queue returned by solveInversion()
        """
        for _ in range(constants.MAX_INV_ITERATIONS):
            self._processForwardIteration(out)
            self._processBackwardIteration(out)
            self._processMaterialUpdateIteration(out)

    def _processForwardIteration(self, out):
        """
        receives the frames of the forward solve of one inversion iteration.
        The result dictionary is put into the queue first and its
        "frames"-list is filled afterwards, frame by frame.
        @param out: the "future"-queue returned by solveInversion()
        """
        data = {"type": "forward", "frames": []}
        out.put(data)

        m, M = self._FORWARD_WAVE_RANGE
        colormap = self._makeColormap(m, M)

        for _ in range(constants.FRAME_COUNT):
            waveData, seismoData = self._connection.recv()
            waveImage = self._waveDataToImage(waveData, m, M, colormap)

            # the seismogram images are not rendered during the inversion;
            # the None placeholders keep the tuple layout
            # (waveImage, seismoImage, seismoImageGreen, seismoData)
            data["frames"].append((waveImage, None, None, seismoData))

    def _processBackwardIteration(self, out):
        """
        receives the frames of the backward solve of one inversion iteration.
        The result dictionary is put into the queue first and its
        "frames"-list is filled afterwards, frame by frame.
        @param out: the "future"-queue returned by solveInversion()
        """
        data = {"type": "backward", "frames": []}
        out.put(data)

        m, M = self._BACKWARD_WAVE_RANGE
        colormap = self._makeColormap(m, M)

        for _ in range(constants.FRAME_COUNT):
            waveData = self._connection.recv()
            waveImage = self._waveDataToImage(waveData, m, M, colormap)

            data["frames"].append((waveImage,))

    def _processMaterialUpdateIteration(self, out):
        """
        receives the updated material of one inversion iteration and puts its
        image into the queue.
        @param out: the "future"-queue returned by solveInversion()
        """
        material = self._connection.recv()
        materialImage = renderutil.arrayToQImage(material, amplify=True)
        out.put({"type": "materialUpdate", "material": materialImage})

#===============================================================================
# util
#===============================================================================
    @staticmethod
    def _makeColormap(m, M):
        """
        creates the color map for wave data in the value range [m, M].
        @param m: the lower bound of the value range
        @param M: the upper bound of the value range
        @return: a CenteredColorMap centered at the relative position of 0 in [m, M]
        """
        # TODO: calculate colormap once?
        if M == m:
            mid = 0.0
        else:
            mid = abs(m) / (M - m)
        return CenteredColorMap(mid, (0, 0, 1, 1), (0, 0, 1, 0), (1, 0, 0, 0), (1, 0, 0, 1), stretch=0.6)

    @staticmethod
    def _waveDataToRgba(data, m, M, colormap):
        """
        converts wave data into an array of 8 bit color values.
        NOTE: the thresholding modifies the given array in place.
        @param data: the wave data (numpy array)
        @param m: the value that is mapped to 0 by the normalization
        @param M: the value that is mapped to 1 by the normalization
        @param colormap: callable that maps the normalized data to color values;
                its result is scaled by 255, so presumably in [0, 1] -- TODO confirm
        @return: uint8 array with the first two axes of the color data swapped
        """
        # suppress values that are small compared to the value range
        data[(data < 0.012 * M) & (data > 0)] = 0.0
        data[(data > 0.012 * m) & (data < 0)] = 0.0

        # normalize data
        data = (data - m) / (M - m)

        # color channels are unsigned bytes; a signed 8 bit type would
        # overflow for every value above 127
        colorData = (255 * colormap(data)).astype(np.uint8)
        # PIL stores its images transposed
        return np.transpose(colorData, (1, 0, 2))

    def _waveDataToImage(self, data, m, M, colormap):
        """
        renders wave data as an RGBA image.
        NOTE: the given array is modified in place (see _waveDataToRgba()).
        @param data: the wave data (numpy array)
        @param m: the lower bound of the value range used for the normalization
        @param M: the upper bound of the value range used for the normalization
        @param colormap: callable that maps the normalized data to color values
        @return: the image as ImageQt
        """
        colorData = self._waveDataToRgba(data, m, M, colormap)
        image = Image.fromarray(colorData, mode="RGBA")
        return ImageQt(image)

    @staticmethod
    def _seismoTraceLines(measurements, size, data_factor):
        """
        computes the polyline of every trace of a seismogram.
        @param measurements: 2d numpy array; after transposing, each row is one trace
        @param size: the (width, height) of the image the traces are drawn on
        @param data_factor: factor the trace values are multiplied with
        @return: one flat coordinate list [x0, y0, x1, y1, ...] per trace
        """
        data = measurements.T.copy()
        # signed square root of the values
        data = np.sqrt(np.abs(data)) * np.sign(data)
        m, M = np.min(data), np.max(data)
        R, N = data.shape

        # lower bound avoids a division by zero for all-zero data
        amplitude = max(abs(m), abs(M), 0.00001)
        if R > 1:
            ym, yM = (0, 2 * (R - 1) * amplitude)
            dY = (yM - ym) / (R - 1)
        else:
            # a single trace: the general formula would divide by zero
            dY = 2 * amplitude

        data *= data_factor * dY / (2 * amplitude)

        width, height = size
        segmentWidth = width / float(N)
        segmentHeight = height / R

        data *= -1.5 / amplitude * (segmentHeight / 4.0)

        # the x-coordinates are the same for all traces
        xs = np.arange(N) * segmentWidth

        lines = []
        for i in range(R):
            tr = data[i, :]

            # move the trace to the vertical center of its segment
            # TODO: remove the +5 offset (present for matching with matplotlib)
            tr += (i + 0.5) * segmentHeight + 5

            # interleave the x- and y-coordinates
            line = np.zeros(2 * len(tr), dtype=np.float64)
            line[::2] = xs
            line[1::2] = tr

            lines.append(line.tolist())
        return lines

    def _seismoDataToImage(self, measurements, seismoSize=(3000, 600), data_factor=2, draw_baselines=False, color=(0, 0, 0, 255)):
        """
        renders a seismogram, one polyline per trace.
        @param measurements: 2d numpy array; after transposing, each row is one trace
        @keyword seismoSize: the (width, height) of the image the traces are drawn on
        @keyword data_factor: factor the trace values are multiplied with
        @keyword draw_baselines: currently unused
        @keyword color: the line color
        @return: the image, smoothly scaled to 1000x200
        """
        image = Image.new('RGBA', seismoSize, (255, 255, 255, 0))
        draw = ImageDraw.Draw(image)

        for line in self._seismoTraceLines(measurements, (image.width, image.height), data_factor):
            draw.line(line, fill=color, width=2)

        del draw

        return ImageQt(image).scaled(1000, 200, transformMode=Qt.SmoothTransformation)
