from multiprocessing import Process

from pysit.core.domain import PML, RectangularDomain
from pysit.core.mesh import CartesianMesh
from pysit.core.receivers import PointReceiver, ReceiverSet
from pysit.core.shot import Shot
from pysit.core.sources import PointSource
from pysit.core.wave_source import DerivativeGaussianPulse, GaussianPulse, RickerWavelet
from pysit.objective_functions.temporal_least_squares import TemporalLeastSquares
from pysit.solvers.wave_factory import ConstantDensityAcousticWave

from lib import constants
from lib.pysit_interface import CustomSource
import numpy as np


class ComputationProcess(Process):
    """Worker process running pysit forward solves and inversions on request.

    The process waits on ``connection`` for action objects (with ``cType`` and
    ``params`` attributes) and streams results back through the same
    connection:

    * forward modelling sends ``(wavefield, receiver_data_so_far)`` tuples,
    * adjoint (back-propagation) modelling sends bare wavefield arrays,
    * every inversion iteration sends the updated material (water layer
      removed, ``constants.BACKGROUND_MATERIAL`` subtracted).
    """

    def __init__(self, connection, gridSize):
        """Store the connection and derive the water/material split of the grid.

        :param connection: pipe end used both to receive actions and to send
            results (anything with ``recv``/``send``).
        :param gridSize: ``(Nx, Ny)`` number of grid cells.
        """
        Process.__init__(self)
        # self.daemon = True # doesn't work
        self._connection = connection
        print("[ComputationProcess] starting with connection {}".format(connection))

        #=======================================================================
        # initialize constants
        #=======================================================================
        self._gridSize = gridSize
        # shape of a flattened grid as handed to the pysit model parameters
        self._flattenedShape = (gridSize[0] * gridSize[1], 1)
        # Split the vertical extent into water (top) and material (bottom).
        # float() guards against integer division under Python 2 when the
        # ratio constant happens to be an int (which would make the water
        # layer disappear).
        ratio = constants.WATER_TO_MAT_RATIO
        factor = float(ratio) / (1 + ratio)
        self._waterHeight = int(round(gridSize[1] * factor))
        self._materialHeight = gridSize[1] - self._waterHeight

        #=======================================================================
        # state parameters
        #=======================================================================
        # the current material (flattened, including the water layer)
        self._material = None
        # the shot recorded by the last forward solve ("observed" data)
        self._shot = None
        # emitter / receiver positions in physical coordinates
        self._emitterPos = None
        self._receiverPos = None

    def _initialize(self):
        """Create the pysit domain, mesh, solver and objective.

        Runs inside the child process (called from ``run``). The physical
        x-extent is fixed to [0, 100]; the y-extent keeps the grid's aspect
        ratio. Both axes get a PML of 10% of their length on each side.
        """
        self.Nx, self.Ny = self._gridSize

        min_x, max_x = 0.0, 100.0
        length_x = max_x - min_x
        min_y, max_y = 0.0, length_x * self.Ny / self.Nx
        length_y = max_y - min_y

        self._pmlx = PML(0.1 * length_x, 100)
        self._pmly = PML(0.1 * length_y, 100)

        x_config = (min_x, max_x, self._pmlx, self._pmlx)
        y_config = (min_y, max_y, self._pmly, self._pmly)

        domain = RectangularDomain(x_config, y_config)
        self._mesh = CartesianMesh(domain, self.Nx, self.Ny)

        initialMaterial = constants.DEFAULT_BACKGROUND_VELOCITY * np.ones(self._mesh.shape())
        self._solver = ConstantDensityAcousticWave(self._mesh,
                                                   spatial_accuracy_order=8,
                                                   trange=(0, constants.DEFAULT_TIME_RANGE),
                                                   kernel_implementation='cpp',
                                                   formulation='scalar',
                                                   model_parameters={'C': initialMaterial}, # TODO: do we need this?
                                                   cfl_safety=1.0/5
                                                   )
        self._objective = TemporalLeastSquares(self._solver)

    def materialSize(self):
        """Return ``(width, height)`` in grid cells of the non-water region."""
        return (self._gridSize[0], self._materialHeight)

    def _toPhysicalPosition(self, pos):
        """Map a normalized ``(x, y)`` position to physical mesh coordinates.

        x is scaled by the full grid width, y by the height of the water layer
        (so y in [0, 1] addresses the water region).
        """
        x, y = pos
        return (x * self.Nx * self._mesh.deltas[0],
                y * self._waterHeight * self._mesh.deltas[1])

    def _buildCalculationMaterial(self, material):
        """Embed ``material`` below the water layer and flatten it.

        :param material: values added to ``constants.BACKGROUND_MATERIAL`` in
            the non-water rows; must broadcast to ``materialSize()``.
        :returns: array of shape ``self._flattenedShape``.
        """
        mat = constants.WATER_MATERIAL * np.ones(self._gridSize, dtype=float)
        mat[:, self._waterHeight:] = constants.BACKGROUND_MATERIAL + material
        return mat.reshape(self._flattenedShape)

    def solveForward(self, material, emitterPos, receiverPos):
        """Run a forward simulation and record the resulting shot.

        The recorded shot becomes the "observed" data for later inversions.
        Wavefield frames are streamed through the connection by
        ``_forwardModel``.
        """
        print("[ComputationProcess] starting forward solve")
        self._emitterPos = self._toPhysicalPosition(emitterPos)
        self._receiverPos = [self._toPhysicalPosition(p) for p in receiverPos]

        # create calculation material
        calcMaterial = self._buildCalculationMaterial(material)

        # solve and update shot (_forwardModel sends the data through the pipe)
        _, _, self._shot, _ = self._solve(calcMaterial, updateShot=True)
        print("[ComputationProcess] finished forward solve")

    def solveInversion(self, initialMaterial):
        """Run ``constants.MAX_INV_ITERATIONS`` gradient-descent iterations.

        :param initialMaterial: new starting material (same layout as in
            ``solveForward``), or None to continue from the current material.

        Returns early (after printing a message) when there is no material to
        start from or no recorded shot from a previous forward solve, instead
        of crashing the worker process.
        """
        print("[ComputationProcess] starting inversion computation")
        # new initial material -> use it
        if initialMaterial is not None:
            self._material = self._buildCalculationMaterial(initialMaterial)
        # no initial material -> continue the inversion with the current material

        if self._material is None:
            print("[ComputationProcess] Something went wrong, material should be set!")
            return
        if self._shot is None:
            print("[ComputationProcess] no recorded shot, run a forward solve before inverting")
            return

        for k in range(constants.MAX_INV_ITERATIONS):
            _, measurements, _, dWaveOp = self._solve(self._material)

            gradient = self._backpropAndGrad(self._material, measurements, dWaveOp)
            # The water layer is fixed, so zero the gradient there: with a zero
            # step the update formula below maps C back to C, leaving the water
            # untouched. Reshape back explicitly so this also works if
            # reshape() had to copy.
            gradientGrid = gradient.reshape(self._gridSize)
            gradientGrid[:, :self._waterHeight] = 0.0
            gradient = gradientGrid.reshape(gradient.shape)

            # debugging aid: scipy.misc.imsave("/tmp/grad_{}.png".format(k), gradientGrid)

            alpha = self._caluclateAlpha(self._material, gradient, dWaveOp)
            step = constants.MAGIC_INV_STEP_SIZE_FACTOR * alpha * gradient
            # NOTE(review): 'C' is passed to pysit as a velocity. If the gradient
            # is w.r.t. squared slowness m = C**-2, the consistent update would be
            # C_new = (C**-2 + step)**-0.5. The expression below updates
            # 1/sqrt(C) instead -- confirm this is intended (the magic step
            # factor may be compensating).
            newMaterial = 1.0 / (1.0 / np.sqrt(self._material) + step) ** 2
            self._material = newMaterial

            # only send the part without water
            newMaterialReshaped = newMaterial.copy().reshape(self._gridSize)[:, self._waterHeight:]
            newMaterialReshaped -= constants.BACKGROUND_MATERIAL
            self._connection.send(newMaterialReshaped)

            print("finished iteration {}".format(k + 1))
        print("[ComputationProcess] finished inversion computation")

    def reset(self):
        """Placeholder; currently does nothing."""
        pass
        # TODO: WIP

    def run(self):
        """Process entry point: initialize, then serve actions until the pipe closes."""
        self._initialize()
        while True:
            # wait for request from the interface-thread
            try:
                action = self._connection.recv()
            except EOFError:
                # the other end closed the connection -> nothing more to do
                print("[ComputationProcess] connection closed, stopping")
                break

            # proof of concept
            if action.cType == "forward":
                self.solveForward(*action.params)
            elif action.cType == "inversion":
                self.solveInversion(*action.params)
            else:
                print("[ComputationProcess] ignoring unknown action type {}".format(action.cType))

    #===========================================================================
    # numeric / math
    #===========================================================================
    def perturb_material(self, material):
        """Return a copy of ``material`` with every non-water row set to the background.

        Expects a 2-D array indexed ``[x, y]`` (gridSize layout) -- TODO confirm.
        """
        # TODO find better ways to perturb.
        tmp_material = material.copy()
        tmp_material[:, self._waterHeight:] = constants.BACKGROUND_MATERIAL
        return tmp_material

    @staticmethod
    def _frameSchedule(nsteps, frameCount):
        """Return ``{time step index: number of frames to emit at that step}``.

        Frame ``j`` (``0 <= j < frameCount``) is taken at step
        ``(nsteps - 1) * j // (frameCount - 1)``, spreading the frames evenly
        from the first to the last step. When ``nsteps < frameCount``, several
        frames fall on the same step; they are counted rather than dropped, so
        exactly ``frameCount`` frames are emitted. A single frame is taken at
        step 0; a non-positive ``frameCount`` yields no frames.
        """
        divisor = max(frameCount - 1, 1)
        schedule = {}
        for j in range(frameCount):
            step = (nsteps - 1) * j // divisor
            schedule[step] = schedule.get(step, 0) + 1
        return schedule

    def _solve(self, material, wavelet_type="Ricker", sourceFct=None, updateShot=False):
        """Create a shot, install ``material`` as the model and forward-model it.

        The model is installed before querying ``ts()`` / ``dt`` because the
        solver derives its time stepping from the current model.

        :returns: ``(wavefields, simulated receiver data, shot, dWaveOp)``.
        """
        shot = self._createShot(wavelet_type, sourceFct)
        baseModel = self._solver.ModelParameters(self._mesh, {'C': material})

        self._solver.model_parameters = baseModel
        ts = self._solver.ts()
        shot.reset_time_series(ts)

        shot.dt = self._solver.dt
        shot.trange = self._solver.trange

        wavefields, measurements, dWaveOp = self._forwardModel(shot, baseModel, updateShot=updateShot)

        return wavefields, measurements, shot, dWaveOp

    def _createShot(self, wavelet_type="Ricker", sourceFct=None):
        """Build a shot at the current emitter/receiver positions.

        :param wavelet_type: "Ricker", "Gaussian", "dGaussian" or "custom"
            (the latter wraps ``sourceFct`` in a ``CustomSource``). Any other
            value raises KeyError. All wavelets use the parameter 10.0.
        """
        # map names to constructors so only the requested wavelet is built
        wavelet_types = {"Ricker": RickerWavelet,
                         "Gaussian": GaussianPulse,
                         "dGaussian": DerivativeGaussianPulse}

        receivers = ReceiverSet(self._mesh, [PointReceiver(self._mesh, p) for p in self._receiverPos])

        if wavelet_type == "custom":
            wavelet = CustomSource(10.0, sourceFct)
        else:
            wavelet = wavelet_types[wavelet_type](10.0)

        source = PointSource(self._mesh, self._emitterPos, wavelet)
        return Shot(source, receivers)

    def _forwardModel(self, shot, baseModel, updateShot=False):
        """Time-step the forward wave equation, streaming frames to the connection.

        Originally copied from pysit's modeling/temporal_modeling.py.
        ``constants.FRAME_COUNT`` frames are sent as ``(wavefield,
        receiver_data_so_far)`` tuples (see ``_frameSchedule``).

        :param updateShot: also record the receiver data into ``shot`` itself.
        :returns: ``(us, simdata, dWaveOp)`` -- unpadded wavefield per step,
            receiver data array ``(nsteps, receiver_count)``, and the padded
            time derivative per step.
        """
        # Local references
        solver = self._solver
        solver.model_parameters = baseModel

        mesh = solver.mesh
        dt = solver.dt
        nsteps = solver.nsteps
        source = shot.sources

        # Storage for the field
        us = list()

        # Setup data storage for the forward modeled data
        simdata = np.zeros((solver.nsteps, shot.receivers.receiver_count))

        # Storage for the time derivatives of p
        dWaveOp = list()

        # Step k = 0
        # p_0 is a zero array because if we assume the input signal is causal
        # and we assume that the initial system (i.e., p_(-2) and p_(-1)) is
        # uniformly zero, then the leapfrog scheme would compute that p_0 = 0 as
        # well. ukm1 is needed to compute the temporal derivative.
        solver_data = solver.SolverData()

        rhs_k = np.zeros(mesh.shape(include_bc=True))
        rhs_kp1 = np.zeros(mesh.shape(include_bc=True))

        frameSchedule = self._frameSchedule(nsteps, constants.FRAME_COUNT)

        for k in range(nsteps):
            uk = solver_data.k.primary_wavefield
            uk_bulk = mesh.unpad_array(uk)
            us.append(uk_bulk.copy())

            # Record the data at t_k
            shot.receivers.sample_data_from_array(uk_bulk, k, data=simdata)

            # send the waveData and seismoData through the connection
            frameRepeats = frameSchedule.get(k, 0)
            if frameRepeats:
                frame = (uk_bulk.copy().reshape(self._gridSize), simdata.copy())
                for _ in range(frameRepeats):
                    self._connection.send(frame)

            if updateShot:
                shot.receivers.sample_data_from_array(uk_bulk, k)

            if k == 0:
                rhs_k = self._setupForwardRhs(rhs_k, source.f(k * dt))
            else:
                # shift time forward
                rhs_k, rhs_kp1 = rhs_kp1, rhs_k
            rhs_kp1 = self._setupForwardRhs(rhs_kp1, source.f((k + 1) * dt))

            # Note, we compute result for k+1 even when k == nsteps-1.  We need
            # it for the time derivative at k=nsteps-1.
            solver.time_step(solver_data, rhs_k, rhs_kp1)

            # Compute time derivative of p at time k
            # Note that this is returned as a PADDED array
            dWaveOp.append(solver.compute_dWaveOp('time', solver_data))

            # When k is the nth step, the next step is unneeded, so don't swap
            # any values.  This way, uk at the end is always the final step
            if k == nsteps - 1:
                break

            # Don't know what data is needed for the solver, so the solver data
            # handles advancing everything forward by one time step.
            # k-1 <-- k, k <-- k+1, etc
            solver_data.advance()

        return us, simdata, dWaveOp

    def _setupForwardRhs(self, rhsArray, data):
        """Pad ``data`` into the preallocated ``rhsArray`` and return it."""
        return self._solver.mesh.pad_array(data, out_array=rhsArray)

    def _backpropAndGrad(self, material, s_k, u_kdWaveOp):
        """Back-propagate the data residual and return the gradient values.

        The residual is the recorded shot (``self._shot``) minus the simulated
        receiver data ``s_k``; the misfit is printed for monitoring.
        """
        s_obs = self._shot.receivers.interpolate_data(self._solver.ts())
        s_residual = s_obs - s_k

        print("*** misfit: {}".format(0.5 * np.sum(s_residual ** 2) * self._solver.dt))

        model_k = self._solver.ModelParameters(self._mesh, {'C': material})

        ic = self._adjointModel(model_k, s_residual, dWaveOp=u_kdWaveOp)

        return ic.without_padding().C

    def _adjointModel(self, model_k, operandSimdata, dWaveOp=None):
        """Time-reversed (adjoint) solve, streaming frames to the connection.

        Adapted from pysit's temporal_modeling.py. ``constants.FRAME_COUNT``
        wavefield frames are sent, latest time step first.

        :param dWaveOp: forward time derivatives per step. When given, the
            imaging condition (gradient) is accumulated and returned unpadded;
            when None, only the adjoint propagation runs and None is returned.
        """
        # Local references
        solver = self._solver
        solver.model_parameters = model_k

        mesh = solver.mesh

        dt = solver.dt
        nsteps = solver.nsteps

        # previously do_ic was only defined when dWaveOp was given, so calling
        # without dWaveOp raised UnboundLocalError
        do_ic = dWaveOp is not None
        ic = solver.model_parameters.perturbation() if do_ic else None

        # Time-reversed wave solver
        solver_data = solver.SolverData()

        rhs_k = np.zeros(mesh.shape(include_bc=True))
        rhs_km1 = np.zeros(mesh.shape(include_bc=True))

        frameSchedule = self._frameSchedule(nsteps, constants.FRAME_COUNT)

        # Loop goes over the valid indices backwards
        for k in range(nsteps - 1, -1, -1):
            # Local reference
            vk = solver_data.k.primary_wavefield
            vk_bulk = mesh.unpad_array(vk)

            # send the waveData through the connection
            frameRepeats = frameSchedule.get(k, 0)
            if frameRepeats:
                waveData = vk_bulk.copy().reshape(self._gridSize)
                for _ in range(frameRepeats):
                    self._connection.send(waveData)

            # can maybe speed up by using only the bulk and not unpadding later
            if do_ic:
                ic += vk * dWaveOp[k]

            if k == nsteps - 1:
                rhs_k = self._setupAdjointRhs(rhs_k, self._shot, k, operandSimdata)
            else:
                # shift time forward
                rhs_k, rhs_km1 = rhs_km1, rhs_k
            rhs_km1 = self._setupAdjointRhs(rhs_km1, self._shot, k - 1, operandSimdata)

            solver.time_step(solver_data, rhs_k, rhs_km1)

            # If k is 0, we don't need results for k-1, so save computation and
            # stop early
            if k == 0:
                break

            # Don't know what data is needed for the solver, so the solver data
            # handles advancing everything forward by one time step.
            # k-1 <-- k, k <-- k+1, etc
            solver_data.advance()

        if do_ic:
            ic *= (-1 * dt)
            ic = ic.without_padding()  # gradient is never padded

        return ic

    def _setupAdjointRhs(self, rhsArray, shot, k, operandSimdata):
        """Spread the residual at step ``k`` onto the grid and pad it into ``rhsArray``."""
        return self._solver.mesh.pad_array(shot.receivers.extend_data_to_array(k, data=operandSimdata), out_array=rhsArray)

    def _caluclateAlpha(self, material, g, u_kdWaveOp):
        """Return the step length ``||g||^2 / ||F g||^2`` for gradient ``g``.

        ``F g`` is the linearized forward response (receiver data) to the
        perturbation ``g`` around ``material``. Returns 0.0 when that response
        vanishes (e.g. a zero gradient) instead of producing NaN/inf, which
        would otherwise corrupt the material on the next update.
        """
        model_k = self._solver.ModelParameters(self._mesh, {'C': material})
        gradient = model_k.perturbation(data=g)
        linear_retval = self._objective.modeling_tools.linear_forward_model(self._shot, model_k, gradient,
                                                                            dWaveOp0=u_kdWaveOp,
                                                                            return_parameters=['simdata'])
        delta_t = self._solver.dt
        # NOTE(review): the original carried "???? * (delta_t*delta_t) * 7.5"
        # here -- the time scaling of this norm is unconfirmed.
        norm_lin_sq = np.sum(linear_retval['simdata'] ** 2) * delta_t
        hsquared = np.prod(self._mesh.deltas)
        norm_adj_sq = np.sum(g ** 2) * hsquared
        if norm_lin_sq > 0:
            alpha = norm_adj_sq / norm_lin_sq
        else:
            print("[ComputationProcess] linearized response is zero, using alpha = 0")
            alpha = 0.0

        print("hsquared {} dt {}".format(hsquared, delta_t))
        print("norm_adj {}".format(np.sqrt(norm_adj_sq)))
        print("norm_lin {}".format(np.sqrt(norm_lin_sq)))
        print("alpha {}".format(alpha))

        return alpha