from PyQt5.Qt import Qt
from PyQt5.QtGui import QColor, QImage, QPainter, QPen
from PyQt5.QtWidgets import QApplication, QDialog, QDialogButtonBox, QGraphicsLineItem, QGraphicsScene, QGraphicsView, QVBoxLayout
from scipy.interpolate.fitpack2 import UnivariateSpline
from util import utils
from widgets.signalconfig.nodeitem import NodeItem
import matplotlib
from lib.i18n import lu
matplotlib.use("Agg")
from matplotlib import pyplot
import numpy
import sys

# Dimensions of the editable scene. WaveConfiguration uses them for the scene
# rect (x in [0, SCENE_WIDTH], y centred on 0), for the fixed view size
# (each dimension plus 50) and for placing the right anchor; SCENE_WIDTH is
# also the scale factor applied in WaveConfiguration.getFunction().
SCENE_WIDTH = 600
SCENE_HEIGHT = 400

class WaveConfiguration(QGraphicsView):
    """Graphics view for editing a wave shape through a chain of nodes.

    The scene spans x in [0, SCENE_WIDTH] and y in [-SCENE_HEIGHT/2,
    SCENE_HEIGHT/2]. Two stationary anchor nodes sit at (0, 0) and
    (SCENE_WIDTH, 0); further nodes are inserted between them by clicking
    close to the curve. The nodes are linked left-to-right (getLeft() /
    getRight()), and a cubic spline through them is kept in ``self._f`` and
    painted as the view background.
    """

    def __init__(self, parent=None):
        """Configure the view, create the scene, the anchors and the spline.

        :param parent: optional parent widget, forwarded to QGraphicsView.
        """
        QGraphicsView.__init__(self, parent)

        self.setCacheMode(QGraphicsView.CacheBackground)
        self.setViewportUpdateMode(QGraphicsView.BoundingRectViewportUpdate)
        self.setRenderHint(QPainter.Antialiasing)
        self.setTransformationAnchor(QGraphicsView.AnchorUnderMouse)
        # The view is larger than the scene, so there is a margin around it
        # in which mouse events map to scene coordinates outside the rect.
        self.setFixedSize(SCENE_WIDTH + 50, SCENE_HEIGHT + 50)
        self.setWindowTitle("Wave Configuration")
        self.setHorizontalScrollBarPolicy(Qt.ScrollBarAlwaysOff)
        self.setVerticalScrollBarPolicy(Qt.ScrollBarAlwaysOff)

        self.setScene(QGraphicsScene(self))
        self.scene().setItemIndexMethod(QGraphicsScene.NoIndex)
        self.scene().setSceneRect(0, -SCENE_HEIGHT/2, SCENE_WIDTH, SCENE_HEIGHT)

        # dashed (stationary) line marking y = 0
        line = QGraphicsLineItem(0, 0, SCENE_WIDTH, 0)
        pen = QPen(QColor(150, 150, 150))
        pen.setStyle(Qt.DashLine)
        line.setPen(pen)
        self.scene().addItem(line)

        # left and right fixed points; they delimit the node chain
        self._leftAnchor = NodeItem(self, stationary=True)
        self._leftAnchor.setPos(0, 0)
        self._rightAnchor = NodeItem(self, stationary=True)
        self._rightAnchor.setPos(SCENE_WIDTH, 0)
        self._rightAnchor.setLeft(self._leftAnchor)

        self.scene().addItem(self._leftAnchor)
        self.scene().addItem(self._rightAnchor)

        # interpolation function, (re)built by recalculate()
        self._f = None
        self.recalculate()

    def recalculate(self):
        """Rebuild the spline from the current node chain and repaint.

        Walks the chain from the left anchor up to (not including) the right
        anchor, collecting node positions. Scene y grows downwards, so node
        y values are negated to get a conventional "up is positive" curve.
        """
        # One padding sample on either side of the scene (both at height 0)
        # guarantees the minimum of four samples a cubic spline needs even
        # when only the two anchors exist.
        xs = [-1]
        ys = [0]
        node = self._leftAnchor
        while node is not self._rightAnchor:
            xs.append(node.x())
            ys.append(-node.y())
            node = node.getRight()

        # The right anchor is appended as the fixed point (SCENE_WIDTH, 0).
        xs += [SCENE_WIDTH, SCENE_WIDTH + 1]
        ys += [0, 0]

        # cubic spline interpolation; s=0 makes it pass through every sample
        self._f = UnivariateSpline(xs, ys, k=3, s=0)
        # redraw the cached background that shows the curve
        self.invalidateScene(rect=self.sceneRect(), layers=QGraphicsScene.BackgroundLayer)

    def drawBackground(self, painter, rect):
        """Paint the spline curve as the scene background.

        The curve is rendered off-screen with matplotlib into an RGBA buffer
        of the scene's pixel size and then blitted onto the scene.

        :param painter: QPainter supplied by Qt.
        :param rect: exposed rectangle, forwarded to the base implementation.
        """
        QGraphicsView.drawBackground(self, painter, rect)

        scene_rect = self.sceneRect()
        width, height = scene_rect.width(), scene_rect.height()

        # sample the spline once per scene unit
        xn = numpy.arange(0, width, 1.0)
        yn = self._f(xn)

        # Read the default dpi from rcParams instead of pyplot.gcf(), which
        # would implicitly create (and never close) an extra figure.
        dpi = pyplot.rcParams["figure.dpi"]
        fig = pyplot.figure(figsize=(width / dpi, height / dpi), dpi=dpi)
        try:
            # axes fill the whole figure so data units map 1:1 to pixels
            ax = fig.add_axes((0, 0, 1, 1))
            ax.set_axis_off()
            ax.set_autoscale_on(False)
            ax.set_xlim(0, width)
            ax.set_ylim(-height / 2, height / 2)

            ax.plot(xn, yn, color="black", linewidth=2.0)
            fig.canvas.draw()

            rgba_data = bytes(fig.canvas.buffer_rgba())
            # integral pixel size; the renderer's width/height may be floats
            img_width, img_height = fig.canvas.get_width_height()
        finally:
            # always release the figure, even if rendering failed
            pyplot.close(fig)

        # buffer_rgba() yields bytes in R, G, B, A order, which is
        # Format_RGBA8888 (Format_ARGB32 would swap red and blue on
        # little-endian machines). rgba_data stays referenced until the
        # image has been drawn, as QImage does not copy the buffer.
        image = QImage(rgba_data, img_width, img_height, QImage.Format_RGBA8888)

        # The scene rect set in __init__ starts at (0, -height / 2); using
        # the QPointF overload avoids passing floats to an int signature.
        painter.drawImage(scene_rect.topLeft(), image)

    def findNeighbour(self, x):
        """Return the first node, from the left, whose x is not below *x*.

        The walk never goes past the right anchor, so the right anchor is
        returned for any *x* beyond the last node.

        :param x: scene x coordinate.
        :returns: the node that would be the right neighbour of a node at *x*.
        """
        node = self._leftAnchor

        while node is not self._rightAnchor and node.x() < x:
            node = node.getRight()

        return node

    def mousePressEvent(self, event):
        """Insert a new node when an unhandled click lands near the curve.

        The event is first offered to the base class (and thereby to the
        scene items). If nothing accepted it and the click is close enough to
        the curve, a node is created on the curve at the clicked x and linked
        between its neighbours.
        """
        QGraphicsView.mousePressEvent(self, event)

        if event.isAccepted():
            return

        pos = self.mapToScene(event.pos())

        # Nodes may only live strictly between the two anchors; clicks in
        # the margin around the scene map outside this range.
        if not 0 < pos.x() < SCENE_WIDTH:
            return

        # The spline returns 0-d numpy arrays; convert to plain floats
        # before handing the values to Qt.
        curve_y = -float(self._f(pos.x()))        # curve height in scene coordinates
        slope = float(self._f(pos.x(), nu=1))     # derivative of f

        # Vertical distance to the curve is used as the hit test; the
        # tolerance is widened with the slope because on steep segments the
        # vertical distance is larger than the actual distance to the curve.
        if abs(pos.y() - curve_y) >= 10 + 3 * abs(slope):
            return

        right = self.findNeighbour(pos.x())
        # A second node at an existing x would make the sample x values
        # non-strictly increasing, which the interpolating spline rejects.
        if right.x() == pos.x():
            return

        node = NodeItem(self)
        node.setPos(pos.x(), utils.clamp(curve_y, -SCENE_HEIGHT / 2, SCENE_HEIGHT / 2))
        node.setLeft(right.getLeft())
        node.setRight(right)

        self.scene().addItem(node)

    def getFunction(self):
        """Return the edited curve as a function on a normalized x axis.

        The returned callable maps x in [0, 1] onto scene x in
        [0, SCENE_WIDTH] and divides the spline value by SCENE_WIDTH / 2.
        Keyword arguments are forwarded to the spline unchanged. It reads
        ``self._f`` at call time, so it reflects later edits of the curve.

        NOTE(review): when a derivative order is forwarded (``nu=...``) the
        result is not rescaled for the x-axis normalization - confirm
        whether callers rely on that.
        """
        def f_normed(x, **kwargs):
            # uniformly scaled by SCENE_WIDTH
            return self._f(x * SCENE_WIDTH, **kwargs) / (SCENE_WIDTH / 2.0)
        return f_normed

class WaveDialog(QDialog):
    """Application-modal dialog hosting a WaveConfiguration editor.

    The editor is stacked above a Cancel/Save button box; Save accepts the
    dialog and Cancel rejects it.
    """

    def __init__(self, parent=None):
        """Build the dialog contents.

        :param parent: optional parent widget, forwarded to QDialog.
        """
        QDialog.__init__(self, parent)
        self.setWindowTitle(lu("labelWaveConfiguration"))
        self.setWindowModality(Qt.ApplicationModal)

        column = QVBoxLayout()
        self.setLayout(column)

        # the wave editor itself
        self._config = WaveConfiguration()
        column.addWidget(self._config)

        # standard buttons wired to the dialog's accept/reject slots
        buttons = QDialogButtonBox(QDialogButtonBox.Cancel | QDialogButtonBox.Save)
        buttons.accepted.connect(self.accept)
        buttons.rejected.connect(self.reject)
        column.addWidget(buttons)

    def getResult(self):
        """Return the function produced by the embedded editor's getFunction()."""
        return self._config.getFunction()
        
if __name__ == "__main__":
    # Manual smoke test: open the dialog and, if it was accepted, print the
    # resulting function evaluated at a scalar and at an array.
    sys.excepthook = utils.excepthook

    app = QApplication(sys.argv)
    widget = WaveDialog()
    widget.show()
    if widget.exec_():
        f = widget.getResult()
        # A single pre-formatted argument keeps this line valid under both
        # Python 2 and Python 3 and prints the same text as the former
        # `print a, b` statement.
        print("%s %s" % (f(0.5), f(numpy.array([0.0, 0.5, 1.0]))))

    sys.exit(app.exec_())
