import matplotlib
import numpy as np
from matplotlib import cm, cbook
from matplotlib.colors import Colormap
from numpy import ma


class CenteredColorMap(Colormap):
    """A (nonlinear) centered colormap.

    Samples on ``[0, center)`` blend ``leftColor`` -> ``leftMidColor`` and
    samples on ``[center, 1]`` blend ``rightMidColor`` -> ``rightColor``.
    ``stretch`` is applied as a power to the normalised distance from
    ``center``.  Constructing an instance registers it with matplotlib under
    ``name``.
    """

    def __init__(self, center, leftColor, leftMidColor, rightMidColor, rightColor,
                 stretch=1.0, name="centered", N=256):
        """
        creates a new centered color-map with center point 'center'
        @param center: the 'center' value for this color-map that determines where the mid colors will be.
        Has to be in [0.0, 1.0].
        @param leftColor: a color (any matplotlib color spec, e.g. an RGB/RGBA tuple) for the far left side (0.0)
        @param leftMidColor: the color at the 'center' when coming from the left
        @param rightMidColor: the color at the 'center' when coming from the right
        @param rightColor: the color of the far right side (1.0)
        @param stretch: a non-negative exponent applied to the normalised 'distance from center' when evaluating
        the color-map. stretch < 1.0 compresses the map (colors move away from the mid colors faster),
        stretch > 1.0 stretches it (slower change). The end colors are always reached at 0.0 and 1.0.
        @param name: the name this colormap gets registered with
        @param N: the number of samples for the lookup-table (should be in [1, 256])
        @raise ValueError: if center lies outside [0.0, 1.0], stretch is negative/NaN, or a color is invalid
        """
        if not 0.0 <= center <= 1.0:
            raise ValueError("center must lie in [0.0, 1.0], got %r" % (center,))
        if not stretch >= 0.0:
            raise ValueError("stretch must be non-negative, got %r" % (stretch,))
        Colormap.__init__(self, name, N=N)
        self._center = center
        self._leftColor = leftColor
        self._leftMidColor = leftMidColor
        self._rightMidColor = rightMidColor
        self._rightColor = rightColor
        self._stretch = stretch
        self.initLut()
        self._register()

    def _register(self):
        """Register this colormap with matplotlib under ``self.name``."""
        registry = getattr(matplotlib, "colormaps", None)
        if registry is not None and hasattr(registry, "register"):
            # matplotlib >= 3.6: cm.register_cmap is deprecated there and was
            # removed in 3.9.  force=True lets a later instance with the same
            # name replace an earlier one instead of raising ValueError.
            registry.register(self, force=True)
        else:
            cm.register_cmap(cmap=self)

    def _init(self):
        """Build the lookup table (the hook matplotlib's Colormap base class calls lazily)."""
        self.initLut()

    def initLut(self):
        """Sample the colormap at ``N`` evenly spaced points into ``self._lut``.

        Each sample ``x`` is remapped to ``xi`` by raising its distance from
        ``center`` to the power ``stretch``.  The distance is first normalised
        by the width of the side it lies on (``center`` on the left,
        ``1 - center`` on the right), so 0.0 and 1.0 always map to
        ``leftColor`` and ``rightColor``.  ``xi`` is then linearly
        interpolated between the two colors of its side.  The result is an
        ``(N, 4)`` float64 RGBA table with no extra under/over/bad rows.
        """
        center = float(self._center)
        x = np.linspace(0.0, 1.0, self.N)

        offset = x - center
        width = np.where(offset < 0.0, center, 1.0 - center)
        # width is 0 only when center == 1.0, for the sample x == 1.0 whose
        # offset is 0; leave its distance at 0 instead of dividing by zero.
        dist = np.divide(np.abs(offset), width, out=np.zeros_like(x), where=width > 0.0)
        xi = center + np.sign(offset) * width * dist ** self._stretch
        np.clip(xi, 0.0, 1.0, out=xi)

        to_rgba = matplotlib.colors.to_rgba
        colors = np.array(
            [to_rgba(color) for color in (self._leftColor, self._leftMidColor,
                                          self._rightMidColor, self._rightColor)],
            dtype=np.float64)

        # Interpolate each side separately rather than relying on np.interp
        # with a duplicated sample point at the center.  A sample exactly at
        # the center takes rightMidColor.
        on_left = xi < center
        lut = np.empty((self.N, 4), dtype=np.float64)
        for channel in range(4):
            lut[:, channel] = np.where(
                on_left,
                np.interp(xi, (0.0, center), colors[0:2, channel]),
                np.interp(xi, (center, 1.0), colors[2:4, channel]))
        self._lut = lut

    def __call__(self, X, alpha=None, bytes=False):
        """
        Map values to RGBA colors.
        @param X: a scalar or array-like. Float values are treated as normalised data: values below 0.0 (and NaN)
        take the first color, values above 1.0 the last. Integer values are used directly as lookup-table indices,
        clipped to [0, N-1]. Masked entries are replaced by the array's fill value before lookup.
        @param alpha: optional alpha, clipped to [0.0, 1.0], applied to every returned color
        @param bytes: if True return uint8 RGBA values in [0, 255], otherwise floats
        @return: an RGBA tuple for scalar X, otherwise an array of shape X.shape + (4,)
        """
        if np.iterable(X):
            vtype = 'array'
            # Copy so the caller's data is never modified in place, then fill
            # masked entries (avoids infs, etc.).
            xa = ma.array(X, copy=True).filled()
        else:
            vtype = 'scalar'
            xa = np.array([X])

        # Calculations with native byteorder are faster, and avoid a bug that
        # otherwise can occur with putmask when the last argument is a numpy
        # scalar.  ndarray.newbyteorder() was removed in NumPy 2.0, so view
        # the swapped buffer through a byte-order-flipped dtype instead.
        if not xa.dtype.isnative:
            xa = xa.byteswap().view(xa.dtype.newbyteorder())

        if xa.dtype.kind == "f":
            # Treat 1.0 as slightly less than 1 so it lands in the last bin.
            one, zero = np.array([1, 0], dtype=xa.dtype)
            np.copyto(xa, np.nextafter(one, zero), where=(xa == 1.0))

            # The following clip is fast, and prevents possible conversion of
            # large positive values to negative integers.
            xa *= self.N
            np.clip(xa, -1, self.N, out=xa)

            # Keep every 'under' value negative after truncation toward zero,
            # and send NaN there too: casting NaN to int is undefined (and warns).
            np.copyto(xa, -1, where=(xa < 0.0) | np.isnan(xa))
            xa = xa.astype(int)

        if bytes:
            lut = (self._lut * 255).astype(np.uint8)
        else:
            lut = self._lut.copy()  # Don't let alpha modify the stored _lut.

        if alpha is not None:
            alpha = min(max(alpha, 0.0), 1.0)  # alpha must be between 0 and 1
            if bytes:
                alpha = int(alpha * 255)
            # initLut builds exactly N data rows (no trailing 'bad' row), so
            # the last row is the real rightColor and must get alpha as well,
            # even when rightColor is fully transparent black.
            lut[:, -1] = alpha

        # mode='clip' maps the -1 'under' index to the first row and N to the last.
        rgba = lut.take(xa, axis=0, mode='clip')
        if vtype == 'scalar':
            rgba = tuple(rgba[0, :])
        return rgba
