from pysit.core.domain import RectangularDomain, PML
from pysit.core.mesh import CartesianMesh
from pysit.core.receivers import PointReceiver
from pysit.core.receivers import ReceiverSet
from pysit.core.shot import Shot
from pysit.core.sources import PointSource
from pysit.core.wave_source import RickerWavelet, GaussianPulse, DerivativeGaussianPulse, SourceWaveletBase, _arrayify
from pysit.objective_functions.temporal_least_squares import TemporalLeastSquares
from pysit.solvers.wave_factory import ConstantDensityAcousticWave

from lib import constants
import numpy as np
from util.decorators import benchmark


# TODO: benchmarkWidth, benchmarkHeight

def getGridResolution(aspectRatio, gridPoints=None):
    """Return a ``(width, height)`` grid with roughly ``gridPoints`` cells.

    The grid satisfies ``width / height ~= aspectRatio`` and
    ``width * height ~= gridPoints`` (both after rounding to integers).

    Parameters
    ----------
    aspectRatio : number
        Desired width-to-height ratio; must be positive.
    gridPoints : number, optional
        Total number of grid cells to aim for.  Defaults to
        ``constants.TOTAL_GRID_SIZE``.

    Returns
    -------
    tuple of int
        ``(width, height)``.

    Raises
    ------
    ValueError
        If ``aspectRatio`` is not a positive number.
    """
    if gridPoints is None:
        # TODO: use results from benchmark here
        gridPoints = constants.TOTAL_GRID_SIZE

    # Force float arithmetic: with an integer aspect ratio, Python 2 would
    # otherwise evaluate ``width / aspectRatio`` as truncating integer division.
    ratio = float(aspectRatio)
    if not ratio > 0:
        raise ValueError("aspectRatio must be positive, got {!r}".format(aspectRatio))

    # height^2 * aspectRatio = gridPoints  and  width = height * aspectRatio
    #   =>  width = sqrt(gridPoints * aspectRatio)
    width = int(round(np.sqrt(gridPoints * ratio)))
    height = int(round(width / ratio))
    print("GridSize: ({}, {})".format(width, height))

    return (width, height)

# def calculateGridSize(width, water_height, bottom_height, N=TOTAL_GRID_SIZE):
#     total_height = water_height + bottom_height
#     r = float(total_height) / width
#     w = round(np.sqrt(N / r))
#     h_total = round(r * w)
#     h_water = round(water_height * h_total / total_height)
#     h_bottom = round(bottom_height * h_total / total_height)
#     return (int(w), int(h_water), int(h_bottom))

def perturb_material(material, water_height=10, background_material=None):
    """Return a copy of ``material`` with everything below the water replaced.

    All columns from index ``water_height`` onwards (second axis) are
    overwritten with ``background_material``; the columns before it are kept.
    The input array is not modified.

    Parameters
    ----------
    material : numpy.ndarray
        Two-dimensional material array.
    water_height : int, optional
        Index along the second axis at which the replaced region starts.
    background_material : scalar or array-like, optional
        Value assigned to the replaced region.  When omitted, the module-level
        name ``BACKGROUND_MATERIAL`` is used, as before.

    Returns
    -------
    numpy.ndarray
        The perturbed copy.
    """
    # TODO find better ways to perturb (e.g. Gaussian smoothing of the
    # region below the water instead of a flat background value).
    if background_material is None:
        # NOTE(review): BACKGROUND_MATERIAL is neither defined nor imported in
        # the visible part of this module, so this fallback raises NameError
        # unless it is provided elsewhere -- confirm where it should come from.
        background_material = BACKGROUND_MATERIAL

    perturbed = material.copy()
    perturbed[:, water_height:] = background_material
    return perturbed


class CustomSource(SourceWaveletBase):
    """Time-domain source wavelet defined by a user-supplied function.

    The wavelet value at time ``t`` is ``scaling * source_fct(t / 10.0)``.
    No frequency-domain evaluation is provided.
    """

    def __init__(self, scaling, source_fct, *args, **kwargs):
        """Store the amplitude factor and the wavelet function.

        Extra positional and keyword arguments are accepted and ignored.
        """
        self._source_fct = source_fct
        self._scaling = scaling

    @property
    def time_source(self):
        """bool, Indicates if wavelet is defined in time domain."""
        return True

    def _evaluate_time(self, ts):
        """Evaluate the wavelet at the time(s) ``ts``.

        ``_arrayify`` returns a flag together with an array version of ``ts``;
        when the flag is set, only the first element of the result is returned
        (presumably so a scalar input yields a scalar -- confirm against
        ``_arrayify``).
        """
        unwrap_result, times = _arrayify(ts)
        values = self._scaling * self._source_fct(times / 10.0)
        if unwrap_result:
            return values[0]
        return values

    def _evaluate_frequency(self, nus):
        """Frequency-domain evaluation is not supported."""
        raise NotImplementedError


class ForwardSolver:
    def __init__(self, dimensions=((-1, 1), (-1, 1)), grid_resolution=(30, 30), trange=(0, 1.7),
                 background_velocity=1.0):
        """Build the domain, mesh, wave solver and objective function.

        Parameters
        ----------
        dimensions : pair of (min, max) pairs
            Extent of the domain in x and in y.
        grid_resolution : (Nx, Ny)
            Values handed to ``CartesianMesh`` for the two axes.
        trange : (t_start, t_end)
            Time range handed to the wave solver.
        background_velocity : float
            Constant value used to fill the initial material array.
        """
        self.Nx, self.Ny = grid_resolution
        self._trange = trange
        self._background_velocity = background_velocity

        x_lo, x_hi = dimensions[0]
        y_lo, y_hi = dimensions[1]

        # One PML per axis, sized to 10% of the domain extent along that axis
        # and used on both ends of the axis.
        self._pmlx = PML(0.1 * (x_hi - x_lo), 100)
        self._pmly = PML(0.1 * (y_hi - y_lo), 100)

        x_config = (x_lo, x_hi, self._pmlx, self._pmlx)
        # Alternative kept for reference: Dirichlet() on the lower y boundary.
        y_config = (y_lo, y_hi, self._pmly, self._pmly)

        self._domain = RectangularDomain(x_config, y_config)
        self._mesh = CartesianMesh(self._domain, self.Nx, self.Ny)

        # Homogeneous starting material; the "true" material starts as a copy.
        self._initial_material = self._background_velocity * np.ones(self._mesh.shape())
        self._true_material = self._initial_material.copy()

        solver_options = dict(
            spatial_accuracy_order=8,
            trange=self._trange,
            kernel_implementation='cpp',
            # cfl_safety=0.1,
            formulation='scalar',
            model_parameters={'C': self._true_material},
        )
        self._solver = ConstantDensityAcousticWave(self._mesh, **solver_options)
        self._objective = TemporalLeastSquares(self._solver)

    def createShot(self, source_position, receiver_positions,
                   wavelet_type="Ricker", source_fct=None):
        """Create a ``Shot`` with one point source and a set of point receivers.

        Parameters
        ----------
        source_position : position on the mesh
            Location of the single ``PointSource``.
        receiver_positions : iterable of positions
            One ``PointReceiver`` is created per entry.
        wavelet_type : {"Ricker", "Gaussian", "dGaussian", "custom"}
            Which source wavelet to use.
        source_fct : callable, optional
            Wavelet function; required when ``wavelet_type == "custom"``.

        Returns
        -------
        Shot

        Raises
        ------
        KeyError
            If ``wavelet_type`` is not one of the supported names.
        TypeError
            If ``wavelet_type`` is "custom" but ``source_fct`` is not callable.
        """
        if wavelet_type == "custom":
            # Fail here rather than deep inside the time-stepping loop, where a
            # missing function only shows up as an obscure call error.
            if not callable(source_fct):
                raise TypeError('wavelet_type "custom" requires a callable source_fct')
            wavelet = CustomSource(10.0, source_fct)
        else:
            # Map names to classes so only the requested wavelet is built.
            wavelet_classes = {"Ricker": RickerWavelet,
                               "Gaussian": GaussianPulse,
                               "dGaussian": DerivativeGaussianPulse}
            wavelet = wavelet_classes[wavelet_type](10.0)

        receivers = ReceiverSet(self._mesh,
                                [PointReceiver(self._mesh, p) for p in receiver_positions])
        source = PointSource(self._mesh, source_position, wavelet)
        return Shot(source, receivers)
    
    def _setup_forward_rhs(self, rhs_array, data):
        """Pad ``data`` via the solver mesh, passing ``rhs_array`` as the output buffer.

        Returns whatever ``mesh.pad_array`` returns.
        """
        solver_mesh = self._solver.mesh
        return solver_mesh.pad_array(data, out_array=rhs_array)
    
    def _forward_model(self, shot, m0, update_shot=False, res_queue = None):
        """Run the forward wave simulation for one shot.

        Originally copied from modeling/temporal_modeling.py.

        Parameters
        ----------
        shot : Shot
            Provides the source (``shot.sources``) and the receivers.
        m0 : model parameters
            Assigned to ``solver.model_parameters`` before stepping.
        update_shot : bool, optional
            If true, the receivers also record each sample into their own
            storage (``sample_data_from_array`` without ``data=``).
        res_queue : queue-like, optional
            If given, ``constants.FRAME_COUNT`` snapshots spread evenly over
            the time steps are pushed with ``put`` as
            ``(frame_number, wavefield_copy, simdata_copy)``, where
            ``frame_number`` starts at 1.

        Returns
        -------
        us : list
            Copy of the unpadded wavefield at every time step.
        simdata : numpy.ndarray
            Receiver samples, shape ``(nsteps, receiver_count)``.
        dWaveOp : list
            Result of ``solver.compute_dWaveOp('time', ...)`` per time step.
        """
        # Local references
        solver = self._solver
        solver.model_parameters = m0

        mesh = solver.mesh
        dt = solver.dt
        nsteps = solver.nsteps
        source = shot.sources

        # Storage for the field
        us = list()

        # Setup data storage for the forward modeled data
        simdata = np.zeros((solver.nsteps, shot.receivers.receiver_count))

        # Storage for the time derivatives of p
        dWaveOp = list()

        # Time steps at which a snapshot is published to res_queue.
        publish_frames = res_queue is not None
        if publish_frames:
            frame_count = constants.FRAME_COUNT
            # Avoid dividing by zero when only a single frame is requested.
            last_frame = max(frame_count - 1, 1)
            frame_steps = [int((nsteps - 1) * frame / last_frame) for frame in range(frame_count)]
            frame_idx = 0

        # Step k = 0
        # p_0 is a zero array because if we assume the input signal is causal
        # and we assume that the initial system (i.e., p_(-2) and p_(-1)) is
        # uniformly zero, then the leapfrog scheme would compute that p_0 = 0 as
        # well. ukm1 is needed to compute the temporal derivative.
        solver_data = solver.SolverData()

        rhs_k = np.zeros(mesh.shape(include_bc=True))
        rhs_kp1 = np.zeros(mesh.shape(include_bc=True))
        print("steps = " + str(nsteps))
        for k in range(nsteps):
            uk = solver_data.k.primary_wavefield
            uk_bulk = mesh.unpad_array(uk)
            us.append(uk_bulk.copy())

            # Record the data at t_k
            shot.receivers.sample_data_from_array(uk_bulk, k, data=simdata)

            if publish_frames:
                # A ``while`` (not ``if``) so that repeated step indices, which
                # occur when there are fewer time steps than frames, do not
                # stall the publishing of all later frames.
                while frame_idx < len(frame_steps) and frame_steps[frame_idx] == k:
                    frame_idx += 1
                    res_queue.put((frame_idx, uk_bulk.copy(), simdata.copy()))

            if update_shot:
                shot.receivers.sample_data_from_array(uk_bulk, k)

            if k == 0:
                rhs_k = self._setup_forward_rhs(rhs_k, source.f(k * dt))
            else:
                # shift time forward: the old k+1 buffer becomes the k buffer
                rhs_k, rhs_kp1 = rhs_kp1, rhs_k
            # The k+1 right-hand side is needed on every iteration.
            rhs_kp1 = self._setup_forward_rhs(rhs_kp1, source.f((k + 1) * dt))

            # Note, we compute result for k+1 even when k == nsteps-1.  We need
            # it for the time derivative at k=nsteps-1.
            solver.time_step(solver_data, rhs_k, rhs_kp1)

            # Compute time derivative of p at time k
            # Note that this is returned as a PADDED array
            dWaveOp.append(solver.compute_dWaveOp('time', solver_data))

            # When k is the last step, the next step is unneeded, so don't
            # advance.  This way, uk at the end is always the final step.
            if k == (nsteps - 1):
                break

            # Don't know what data is needed for the solver, so the solver data
            # handles advancing everything forward by one time step.
            # k-1 <-- k, k <-- k+1, etc
            solver_data.advance()

        return us, simdata, dWaveOp

    @benchmark
    def solve(self, material, source_position, receiver_positions,
              wavelet_type="Ricker", source_fct=None, update_shot=False,
              res_queue = None):
        """Simulate one shot through ``material`` and return the results.

        A shot is created from the source/receiver positions, the solver is
        given model parameters built from ``material``, and the forward model
        is run.

        Returns
        -------
        wavefields : list
            Wavefield per time step, as returned by ``_forward_model``.
        measurements : numpy.ndarray
            The simulated receiver data ('simdata'), i.e. what is measured;
            roughly what ``shots[0].gather(as_array=True)`` used to provide.
        shot : Shot
            The shot that was created for this run.
        dWaveOp : list
            The (second) derivative field per time step.
        """
        solver = self._solver

        shot = self.createShot(source_position, receiver_positions, wavelet_type, source_fct)
        base_model = solver.ModelParameters(self._mesh, {'C': material})

        # The model has to be set before querying the time axis.
        solver.model_parameters = base_model
        shot.reset_time_series(solver.ts())

        shot.dt = solver.dt
        shot.trange = solver.trange

        # Custom forward model used in place of
        # self._objective.modeling_tools.forward_model(...) so that frames can
        # be pushed to res_queue while the simulation is running.
        wavefields, measurements, dWaveOp = self._forward_model(
            shot, base_model, update_shot=update_shot, res_queue=res_queue)

        return wavefields, measurements, shot, dWaveOp

    @benchmark
    def backprop_and_grad_multi(self, shot, m, s_k, u_kdWaveOp, res_queue):        
        s_obs = shot.receivers.interpolate_data(self._solver.ts())
        s_residual = s_obs - s_k
        
        print "misfit: ", 0.5 * np.sum(s_residual ** 2) * self._solver.dt
        
        model_k = self._solver.ModelParameters(self._mesh, {'C': m})
        
        #g = self.migrate_shot(shot, model_k, s_residual,dWaveOp=u_kdWaveOp, res_queue=res_queue)
        ic = self.adjoint_model_multi(shot, model_k, s_residual, dWaveOp=u_kdWaveOp, res_queue=res_queue)
        
        return ic.without_padding().C
        
    def _setup_adjoint_rhs(self, rhs_array, shot, k, operand_simdata):
        """Build the adjoint right-hand side for time index ``k``.

        The receiver data ``operand_simdata`` at index ``k`` is extended to an
        array by the shot's receivers and then padded via the solver mesh,
        with ``rhs_array`` passed as the output buffer.
        """
        receiver_field = shot.receivers.extend_data_to_array(k, data=operand_simdata)
        return self._solver.mesh.pad_array(receiver_field, out_array=rhs_array)
    
    def adjoint_model_multi(self, shot, m0, operand_simdata, dWaveOp=None, res_queue = None):
        """Run the time-reversed (adjoint) simulation for one shot.

        Adapted from temporal_modeling.py.

        Parameters
        ----------
        shot : Shot
            Provides the receivers used to inject ``operand_simdata``.
        m0 : model parameters
            Assigned to ``solver.model_parameters`` before stepping.
        operand_simdata : numpy.ndarray
            Receiver data (typically the residual) used as adjoint source.
        dWaveOp : sequence, optional
            Per-time-step derivative fields from the forward run.  When given,
            the imaging condition is accumulated; otherwise nothing is
            accumulated.
        res_queue : queue-like, optional
            If given, ``constants.FRAME_COUNT`` snapshots of the adjoint field
            are pushed with ``put`` as ``(frame_number, wavefield_copy)`` while
            stepping backwards in time; ``frame_number`` starts at 1.

        Returns
        -------
        The unpadded imaging condition scaled by ``-dt``, or ``None`` when
        ``dWaveOp`` was not supplied.
        """
        # Local references
        solver = self._solver
        solver.model_parameters = m0

        mesh = solver.mesh
        dt = solver.dt
        nsteps = solver.nsteps

        # Always bind both names so that the default dWaveOp=None does not
        # fail with a NameError further down.
        do_ic = dWaveOp is not None
        ic = solver.model_parameters.perturbation() if do_ic else None

        # Time-reversed wave solver
        solver_data = solver.SolverData()

        rhs_k   = np.zeros(mesh.shape(include_bc=True))
        rhs_km1 = np.zeros(mesh.shape(include_bc=True))

        # Time steps (in descending order) at which a snapshot is published.
        publish_frames = res_queue is not None
        if publish_frames:
            frame_count = constants.FRAME_COUNT
            # Avoid dividing by zero when only a single frame is requested.
            last_frame = max(frame_count - 1, 1)
            frame_steps = [int((nsteps - 1) * (frame_count - 1 - frame) / last_frame)
                           for frame in range(frame_count)]
            frame_idx = 0

        # Loop goes over the valid indices backwards
        for k in range(nsteps - 1, -1, -1):

            # Local reference
            vk = solver_data.k.primary_wavefield

            if publish_frames:
                vk_bulk = mesh.unpad_array(vk)
                # ``while`` so that repeated step indices (fewer time steps
                # than frames) do not stall the publishing of later frames.
                while frame_idx < len(frame_steps) and frame_steps[frame_idx] == k:
                    frame_idx += 1
                    res_queue.put((frame_idx, vk_bulk.copy()))

            # can maybe speed up by using only the bulk and not unpadding later
            if do_ic:
                ic += vk * dWaveOp[k]

            if k == nsteps - 1:
                rhs_k = self._setup_adjoint_rhs(rhs_k, shot, k, operand_simdata)
            else:
                # shift time: the old k-1 buffer becomes the k buffer
                rhs_k, rhs_km1 = rhs_km1, rhs_k
            # NOTE(review): at k == 0 this requests receiver data for time
            # index -1; how extend_data_to_array treats that index is not
            # visible here -- confirm it is harmless.
            rhs_km1 = self._setup_adjoint_rhs(rhs_km1, shot, k - 1, operand_simdata)

            solver.time_step(solver_data, rhs_k, rhs_km1)

            # If k is 0, we don't need results for k-1, so stop early
            if k == 0:
                break

            # Don't know what data is needed for the solver, so the solver data
            # handles advancing everything forward by one time step.
            # k-1 <-- k, k <-- k+1, etc
            solver_data.advance()

        if do_ic:
            ic *= (-1 * dt)
            ic = ic.without_padding()  # gradient is never padded

        return ic
    
    @benchmark
    def backpropAndGrad(self, shot, m, s_k, u_kdWaveOp):
        s_obs = shot.receivers.interpolate_data(self._solver.ts())
        s_residual = s_obs - s_k
        
        print "misfit: ", 0.5 * np.sum(s_residual ** 2) * self._solver.dt
        
        model_k = self._solver.ModelParameters(self._mesh, {'C': m})
        
        adjointfield = []
        g = self._objective.modeling_tools.migrate_shot(shot, model_k, s_residual,
                                                        dWaveOp=u_kdWaveOp,
                                                        adjointfield=adjointfield)
        
        return g.C, adjointfield

    @benchmark
    def caluclatealpha(self, shot, m, g, u_kdWaveOp):
        model_k = self._solver.ModelParameters(self._mesh, {'C': m})
        gradient = model_k.perturbation(data=g)
        linear_retval = self._objective.modeling_tools.linear_forward_model(shot, model_k, gradient,
                                                                            dWaveOp0=u_kdWaveOp,
                                                                            return_parameters=['simdata'])        
        delta_t = self._solver.dt
        norm_lin_sq = np.sum(linear_retval['simdata'] ** 2) * delta_t  # ???? * (delta_t*delta_t) * 7.5
        hsquared = np.prod(self._mesh.deltas)
        norm_adj_sq = np.sum(g ** 2) * hsquared
        alpha = norm_adj_sq / norm_lin_sq
        
        print "hsquared", hsquared, "dt", delta_t
        print "norm_adj", np.sqrt(norm_adj_sq)
        print "norm_lin", np.sqrt(norm_lin_sq)
        print "alpha", alpha
        
        return alpha


    def mesh(self):
        """Return the mesh stored on this solver (``self._mesh``)."""
        return self._mesh
    def domain(self):
        """Return the domain stored on this solver (``self._domain``)."""
        return self._domain
    
    def domainNxNy(self):
        """Return the tuple ``(domain, Nx, Ny)``."""
        result = (self._domain, self.Nx, self.Ny)
        return result


class FrameGenerator(object):
    """Lazily renders animation frames and caches them for replay.

    Frames come from ``fwd.animate_generator(wavefields, frame_count=...,
    colormap=...)``.  Each frame is rendered at most once; it can then be
    fetched by index (``gen[k]``) or by iterating over the object.  Iteration
    can be repeated: after a pass ends, the next pass starts again at the
    first frame and is served from the cache.

    Parameters
    ----------
    fwd : object
        Must provide ``animate_generator`` as described above.
    wavefields : sequence
        Passed through to ``animate_generator``.
    frame_count : int, optional
        Maximum number of frames.  Defaults to ``constants.FRAME_COUNT``,
        looked up when the object is created.
    colormap : optional
        Passed through to ``animate_generator``.
    """

    def __init__(self, fwd, wavefields, frame_count=None, colormap=None):
        if frame_count is None:
            frame_count = constants.FRAME_COUNT

        self._images = []
        self._wavefields = wavefields
        self._k = 0  # position of the iteration cursor
        self._frame_count = frame_count
        self._animate_generator = fwd.animate_generator(self._wavefields, frame_count=self._frame_count,
                                                        colormap=colormap)
        self._finished = False

        # Render the first frame eagerly.  An empty generator simply marks the
        # object as finished instead of raising StopIteration from here.
        self._render_next()

    def _render_next(self):
        """Render and cache one more frame; return False if none is left."""
        if len(self._images) >= self._frame_count:
            self._finished = True
            return False

        try:
            frame = next(self._animate_generator)
        except StopIteration:
            # Fewer frames than announced: shrink the advertised length.
            self._frame_count = len(self._images)
            self._finished = True
            return False

        self._images.append(frame)
        if len(self._images) >= self._frame_count:
            self._finished = True
        return True

    def __iter__(self):
        return self

    def __len__(self):
        return self._frame_count

    def finished_frame_count(self):
        """Return the number of frames rendered so far."""
        return len(self._images)

    def rendering_finished(self):
        """Return True once no further frame will be rendered."""
        return self._finished

    def next(self):
        """Return the frame at the cursor and advance, rendering if needed.

        Raises ``StopIteration`` at the end of a pass and rewinds the cursor
        so that the next pass starts at the first frame.
        """
        if self._k >= self._frame_count:
            self._k = 0
            self._finished = True
            raise StopIteration

        if self._k >= len(self._images) and not self._render_next():
            self._k = 0
            raise StopIteration

        frame = self._images[self._k]
        self._k += 1
        return frame

    # Python 3 spelling of the iterator protocol.
    __next__ = next

    def __getitem__(self, k):
        """Return frame ``k``, rendering frames until it is available.

        Does not move the iteration cursor.  Raises ``IndexError`` when the
        frame does not exist.
        """
        while len(self._images) <= k and self._render_next():
            pass
        return self._images[k]
            


if __name__ == "__main__":
    # Small manual example: build a solver, create a synthetic material, run
    # one forward simulation and print the resulting shot.

    def create_example_material(mesh, domain):
        """Return a material array of ones plus a pulse in y, scaled by 10.

        ``domain`` is accepted but not used.
        """
        def _gaussian_derivative_pulse(XX, YY, threshold, **kwargs):
            """ Derivative of a Gaussian at a specific sigma """
            pulse = -5 * YY * np.exp(-(YY ** 2 * 10))
            # Zero out values whose magnitude is below the threshold.
            pulse[np.abs(pulse) < threshold] = 0
            return pulse

        XX, YY = mesh.mesh_coords()
        base = np.ones(mesh.shape())

        # Pulse centred at y = 30.
        base += _gaussian_derivative_pulse(XX, YY - 30, 1e-7)
        return 10 * base

    source_pos = (40.13251783893986, 6.477272727272727)

    # Seven receivers on one line at constant y.
    receiver_y = 4.534090909090909
    receiver_xs = (52.330275229357795, 55.869520897043834, 59.408766564729866,
                   62.9480122324159, 66.48725790010194, 70.02650356778797,
                   73.56574923547402)
    receiver_pos = [(x, receiver_y) for x in receiver_xs]

    # Alternative: N evenly spaced receivers between the first and last x.
    # N = 10
    # receiver_pos = [(52.330275229357795 + i*(73.56574923547402-52.330275229357795)/(N-1), 4.534090909090909) for i in range(N)]

    # NOTE(review): the domain is 100 wide and 70 high while grid_resolution
    # is (70, 100) -- confirm the axis order is intended.
    FWD = ForwardSolver(
        dimensions=((0.0, 100.0), (0.0, 70.0)),
        grid_resolution=(70, 100),
        trange=(0, 10),
    )

    mesh = FWD.mesh()
    domain = FWD.domain()
    material = create_example_material(mesh, domain)

    wavefields, s_obs, shot, retval = FWD.solve(material,
                                                source_position=source_pos,
                                                receiver_positions=receiver_pos,
                                                update_shot=True)

    shot.receivers.data = s_obs
    print(shot)
