from PIL import Image, ImageDraw
from PIL.ImageQt import ImageQt
from PyQt5.QtCore import Qt, QLineF, QRectF
from PyQt5.QtGui import QPainter, QColor
from PyQt5.QtWidgets import QWidget, QHBoxLayout, QComboBox, QVBoxLayout

from lib import constants
from lib.i18n import lu
import numpy as np


# Fixed pixel height of the SeismogramView strip (passed to setFixedHeight).
SEISMO_HEIGHT = 225

# Line colours for rendered traces, as RGBA tuples (red, green, blue, alpha).
# Half-transparent dark grey: reference ("original") traces, and the
# overview's current traces when not in forward mode.
SEISMO_BG_COLOR = (50, 50, 50, 128)
# Opaque green: current traces in forward mode (and in the detail view).
SEISMO_FWD_COLOR = (0, 128, 0, 255)

class SeismoWidget(QWidget):
    """Widget that paints up to two seismogram images and a frame cursor.

    The *original* seismogram is always drawn over the whole widget. The
    *current* seismogram is drawn over the whole widget in backward mode, or
    only to the left of the frame cursor in forward mode. The cursor is a
    vertical light-gray line at ``width * currentFrame // FRAME_COUNT``.
    """

    def __init__(self, forward=True, parent=None):
        QWidget.__init__(self, parent=parent)
        self._originalSeismogram = None
        self._currentSeismogram = None

        self._currentFrame = 0
        self._forward = forward

    def setOriginalSeismogram(self, seismogram):
        """Set the original seismogram image (``None`` hides it) and repaint."""
        self._originalSeismogram = seismogram
        self.update()

    def setCurrentSeismogram(self, seismogram):
        """Set the current seismogram image (``None`` hides it) and repaint."""
        self._currentSeismogram = seismogram
        self.update()

    def setForward(self, forward):
        """Switch between forward (clip at cursor) and backward (full) drawing."""
        self._forward = forward
        # paintEvent depends on this flag, so repaint like the other setters do.
        self.update()

    def setCurrentFrame(self, currentFrame):
        """Move the frame cursor to ``currentFrame`` and repaint."""
        self._currentFrame = currentFrame
        self.update()

    def _scaledToSize(self, image):
        """Return ``image`` smoothly rescaled to exactly the widget size."""
        # Default aspect-ratio mode of QImage.scaled ignores the aspect ratio,
        # so the result matches self.size() exactly.
        return image.scaled(self.size(), transformMode=Qt.SmoothTransformation)

    def paintEvent(self, event):
        """Paint the frame cursor, then the original and current seismograms."""
        QWidget.paintEvent(self, event)
        painter = QPainter(self)
        try:
            x = self.width() * self._currentFrame // constants.FRAME_COUNT
            painter.setPen(QColor(Qt.lightGray))
            painter.drawLine(QLineF(x, 0, x, self.height()))

            if self._originalSeismogram is not None:
                # Use the smoothly rescaled copy (previously computed but unused).
                painter.drawImage(self.rect(), self._scaledToSize(self._originalSeismogram))

            if self._currentSeismogram is not None:
                img = self._scaledToSize(self._currentSeismogram)
                if self._forward:
                    # Reveal only the part left of the cursor. The scaled image
                    # has the widget's size, so source and target rects coincide.
                    rect = QRectF(0, 0, x, self.height())
                    painter.drawImage(rect, img, rect)
                else:
                    painter.drawImage(self.rect(), img)
        finally:
            painter.end()


class SeismoDetailWidget(QWidget):
    """Receiver drop-down above a single-trace seismogram view.

    The seismogram data last passed to ``updateOriginalSeismoData`` /
    ``updateCurrentSeismoData`` is kept and re-rendered for the selected
    receiver whenever the drop-down selection changes.

    :param parent: Qt parent widget.
    :param receiverCount: number of receivers offered in the drop-down
        (entries are labelled 1..receiverCount, indices are 0-based).
    """

    def __init__(self, parent=None, *, receiverCount=19):
        QWidget.__init__(self, parent=parent)
        self.setLayout(QVBoxLayout())
        self.layout().setContentsMargins(0, 0, 0, 0)

        # Initialise state before connecting any signal, so indexChanged can
        # never run against a half-constructed widget.
        self._lastOriginalData = None
        self._lastCurrentData = None

        self._dropDown = QComboBox()
        lReceiver = lu("labelReceiver")
        # TODO: derive the receiver count from the data instead of a default.
        self._dropDown.addItems([lReceiver + ": " + str(i) for i in range(1, receiverCount + 1)])
        self.layout().addWidget(self._dropDown)

        self.seismo = SeismoWidget(parent=self)
        self.layout().addWidget(self.seismo)

        self._dropDown.currentIndexChanged.connect(self.indexChanged)

    def _renderTrace(self, seismoData, index, color):
        """Render receiver ``index`` of ``seismoData`` at twice the view size.

        Returns ``None`` when there is no data or no valid selection (Qt
        reports -1), so a negative index never wraps to the last receiver.
        """
        if seismoData is None or index < 0:
            return None
        return _seismoDataToImageSingle(
            seismoData,
            index,
            seismoSize=(2 * self.seismo.width(), 2 * self.seismo.height()),
            color=color,
        )

    def setCurrentReceiver(self, index):
        """Select receiver ``index`` (0-based), clamped to the valid range.

        Out-of-range indices would otherwise make QComboBox fall back to -1.
        """
        count = self._dropDown.count()
        if count == 0:
            return
        self._dropDown.setCurrentIndex(min(max(int(index), 0), count - 1))

    def updateOriginalSeismoData(self, seismoData):
        """Store and display the original seismogram (``None`` clears it)."""
        self._lastOriginalData = seismoData
        self.seismo.setOriginalSeismogram(
            self._renderTrace(seismoData, self._dropDown.currentIndex(), SEISMO_BG_COLOR))

    def updateCurrentSeismoData(self, seismoData):
        """Store and display the current seismogram (``None`` clears it)."""
        self._lastCurrentData = seismoData
        self.seismo.setCurrentSeismogram(
            self._renderTrace(seismoData, self._dropDown.currentIndex(), SEISMO_FWD_COLOR))

    def indexChanged(self, index):
        """Re-render both stored seismograms for the newly selected receiver."""
        self.seismo.setOriginalSeismogram(
            self._renderTrace(self._lastOriginalData, index, SEISMO_BG_COLOR))
        self.seismo.setCurrentSeismogram(
            self._renderTrace(self._lastCurrentData, index, SEISMO_FWD_COLOR))


class SeismogramView(QWidget):
    """Fixed-height strip with an all-receiver overview and a detail view.

    The overview (left, stretch 3) shows every receiver's trace. Clicking a
    trace selects that receiver in the detail view (right, stretch 1).
    """

    def __init__(self, parent=None):
        QWidget.__init__(self, parent=parent)
        self.setFixedHeight(SEISMO_HEIGHT)
        self.setLayout(QHBoxLayout())
        self.layout().setContentsMargins(0, 0, 0, 0)
        self.layout().setSpacing(10)

        self._seismo = SeismoWidget(parent=self)
        self.layout().addWidget(self._seismo, stretch=3)

        self._seismoDetail = SeismoDetailWidget(parent=self)
        self.layout().addWidget(self._seismoDetail, stretch=1)

    def _overviewSize(self):
        """Return the render size for overview images: twice the view size."""
        return (2 * self._seismo.width(), 2 * self._seismo.height())

    def mousePressEvent(self, event):
        """Select the receiver whose trace band was clicked in the overview."""
        geometry = self._seismo.geometry()
        if not geometry.contains(event.pos()):
            return
        # Use the drop-down's entry count so the result is always a valid index.
        # TODO: that count is itself hardcoded; derive it from the data.
        receiverCount = self._seismoDetail._dropDown.count()
        height = geometry.height()
        if receiverCount == 0 or height <= 0:
            return
        y = event.y() - geometry.y()
        # _seismoDataToImage gives receiver i the band [i, i + 1) * height / R
        # (trace centred at i + 0.5). The band index is therefore the floor of
        # the scaled position, not round(), which was off by one in the lower
        # half of every band and yielded R at the bottom edge.
        receiver = int(y * receiverCount // height)
        self._seismoDetail.setCurrentReceiver(min(max(receiver, 0), receiverCount - 1))

    def updateOriginalSeismoData(self, seismoData):
        """Show ``seismoData`` as the original seismogram (``None`` clears it)."""
        if seismoData is None:
            self._seismo.setOriginalSeismogram(None)
        else:
            self._seismo.setOriginalSeismogram(
                _seismoDataToImage(seismoData, seismoSize=self._overviewSize()))

        self._seismoDetail.updateOriginalSeismoData(seismoData)
        self.update()

    def updateCurrentSeismoData(self, seismoData):
        """Show ``seismoData`` as the current seismogram (``None`` clears it).

        The overview colour is chosen from the forward flag at call time.
        """
        if seismoData is None:
            self._seismo.setCurrentSeismogram(None)
        else:
            color = SEISMO_FWD_COLOR if self._seismo._forward else SEISMO_BG_COLOR
            self._seismo.setCurrentSeismogram(
                _seismoDataToImage(seismoData, seismoSize=self._overviewSize(), color=color))

        self._seismoDetail.updateCurrentSeismoData(seismoData)
        self.update()

    def setForward(self, forward):
        """Forward the playback direction to both seismogram widgets."""
        self._seismo.setForward(forward)
        self._seismoDetail.seismo.setForward(forward)

    def setCurrentFrame(self, currentFrame):
        """Move the frame cursor in both seismogram widgets."""
        self._seismo.setCurrentFrame(currentFrame)
        self._seismoDetail.seismo.setCurrentFrame(currentFrame)


#===============================================================================
# util
#===============================================================================
# TODO: maybe draw this in computationmanager again instead of here
def _seismoDataToImage(measurements, seismoSize=(2000, 700), data_factor=2, color=SEISMO_BG_COLOR):
    """Render all traces of a seismogram into one transparent RGBA image.

    ``measurements`` is transposed so that each row is one trace; presumably
    it is indexed ``[sample, receiver]`` (TODO confirm with callers). With R
    rows, row ``i`` is drawn as a polyline centred in the ``i``-th of R equally
    high horizontal bands, spanning the full image width.

    Amplitudes are compressed with a signed square root and normalised by the
    largest absolute value. A full-scale sample is then displaced by
    ``0.375 * data_factor`` band heights, upwards for positive values because
    PIL's y axis points down.

    :param measurements: 2-D array-like of trace samples.
    :param seismoSize: ``(width, height)`` of the output image in pixels.
    :param data_factor: extra amplitude gain on top of the normalisation.
    :param color: RGBA line colour.
    :returns: ``ImageQt`` wrapping the rendered image (empty if no samples).
    """
    image = Image.new('RGBA', seismoSize, (255, 255, 255, 0))

    data = np.asarray(measurements, dtype=np.float64).T
    R, N = data.shape
    if R == 0 or N == 0:
        # Nothing to draw; np.max would raise on an empty array.
        return ImageQt(image)

    # Signed square root to compress the dynamic range for better visibility.
    data = np.sqrt(np.abs(data)) * np.sign(data)
    amplitude = max(float(np.max(np.abs(data))), 0.00001)

    segmentWidth = image.width / float(N)
    segmentHeight = image.height / R
    lineWidth = max(seismoSize[1] // 150, 1)

    # Normalise, apply the gain and convert to pixel offsets in a single step.
    # This also avoids the former (R - 1) division, which failed for R == 1.
    data *= data_factor * (-1.5 / amplitude) * (segmentHeight / 4.0)
    # Shift every trace to the vertical centre of its band.
    data += ((np.arange(R) + 0.5) * segmentHeight)[:, np.newaxis]

    draw = ImageDraw.Draw(image)
    # Interleaved [x0, y0, x1, y1, ...] buffer; the x coordinates are shared.
    line = np.empty(2 * N, dtype=np.float64)
    line[::2] = np.arange(N) * segmentWidth
    for trace in data:
        line[1::2] = trace
        draw.line(line.tolist(), fill=color, width=lineWidth)

    return ImageQt(image)

def _seismoDataToImageSingle(measurements, index, seismoSize=(800, 750), color=(0, 0, 0, 255)):
    """Render one trace of a seismogram as a polyline in a transparent image.

    The trace is row ``index`` of ``measurements`` transposed (presumably the
    receiver index; TODO confirm). It is scaled so that its largest absolute
    value reaches the top or bottom edge, with zero on the vertical centre
    line, and spans the full image width.

    :param measurements: 2-D array-like of trace samples; may be integer typed.
    :param index: row of ``measurements.T`` to draw.
    :param seismoSize: ``(width, height)`` of the output image in pixels.
    :param color: RGBA line colour.
    :returns: ``ImageQt`` wrapping the rendered image (empty if no samples).
    """
    image = Image.new('RGBA', seismoSize, (255, 255, 255, 0))

    # Copy as float64: the scaling below would raise on integer input
    # (in-place float multiply cannot cast back to an integer dtype).
    data = np.array(np.asarray(measurements).T[index, :], dtype=np.float64)
    N = data.shape[0]
    if N == 0:
        # Nothing to draw; np.max and the x spacing would fail on empty input.
        return ImageQt(image)

    amplitude = max(float(np.max(np.abs(data))), 0.00001)

    # Scale +/-amplitude to half the image height (negated because PIL's y axis
    # points downwards), then move the zero line to the vertical centre.
    ys = data * (-image.height / (2 * amplitude)) + 0.5 * image.height
    xs = np.arange(N) * (image.width / float(N))

    # Interleave into [x0, y0, x1, y1, ...] as expected by ImageDraw.line.
    line = np.empty(2 * N, dtype=np.float64)
    line[::2] = xs
    line[1::2] = ys

    lineWidth = max(seismoSize[1] // 150, 1)
    ImageDraw.Draw(image).line(line.tolist(), fill=color, width=lineWidth)

    return ImageQt(image)
