"""
Cette bibliotheque de fonction regroupe des methodes assez
variees principalement utilisees avec le pencil-code.
"""
import numpy as N
from numpy.fft import fft
import sys
import matplotlib

__author__  = "$Author: dintrans $"
__date__   = "$Date: 2009/09/22 08:52:32 $"
__version__ = "$Revision: 1.1 $"

def crete(t, ruzm):
    """
    Fonction "detecteur de cretes" retournant le tableau
    des maxima de ruzm (f) et le tableau time des abscisses
    correspondantes

        >>> ts = pc.read_ts(plot_data=False)
        >>> tt, yy = crete(ts.t, ts.ruzm)

    @param t: l'axe des x initial
    @type t: numpy array
    @param ruzm: le tableau dont on veut extraire les cretes
    @type ruzm: numpy array
    @return: un tuple contenant:

        - tt, le nouveau temps
        - yy, les cretes de ruzm

    @rtype: numpy array
    """

    f = N.zeros(1, dtype='Float32')
    time = N.zeros_like(f)
    for i in range(1, len(ruzm)-1):
        if (ruzm[i] >= ruzm[i+1] and ruzm[i] > ruzm[i-1] ):
            f = N.append(f, ruzm[i])
            time = N.append(time, t[i])
    return time, f

def deriv(x, y):
    """
    Derivee a l'ordre 2

        >>> x = N.linspace(0., 1., 100)
        >>> y = N.cos(5*x)
        >>> dy = deriv(x, y)

    @param x: la grandeur par rapport a laquelle on derive
    @type x: numpy array
    @param y: la grandeur que l'on veut deriver
    @type y: numpy array
    @return: dy/dx
    @rtype: numpy array
    """
    if len(x) < 3:
        sys.exit("Paramaters must have at least three points")
    if len(x) != len(y):
        sys.exit("Vectors must have the same size")
    d = (N.roll(y, -1)-N.roll(y, 1))/(N.roll(x, -1)-N.roll(x, 1))
    d[0] = (-3.*y[0]+4.*y[1]-y[2])/(x[2]-x[0])
    d[-1] = (3*y[-1]-4*y[-2]+y[-3])/(x[-1]-x[-3])
    return d

def fourier_perso(x, y):
    """
    Fonction permettant de faire la TF de y(x). Elle
    retourne w, le tableau des frequences positives,
    et w2, la TF associee.

        >>> x = N.linspace(0., 1., 100)
        >>> y = N.cos(5*x)
        >>> w, w2 = fourier_perso(x, y)

    @param x: le tableau x
    @type x: numpy array
    @param y: le tableau dont on veut faire la TF
    @type y: numpy array
    @return: un tuple contenant:

        - w: me tableau des frequences positives    
        - w2: la TF associee

    @rtype: tuple
    """
    w1 = fft(y)/len(y)
    w2 = N.abs(w1[1:len(y)/2+1])
    dw = 2.*N.pi/(x[-1]-x[0])
    w = dw*N.arange(len(x))
    w = w[1:len(x)/2+1]
    return w, w2

def change_xticks(ax, newxticks):
    """
    Routine pour changer les xticks d'un plot

    @param ax: l'axe matplotlib
    @type ax: matplotlib.axes.AxesSubplot
    @param newxticks: le nouveau xticks qu'on veut mettre
    @type newxticks: numpy array
    """
    xx = ax.get_xticks()
    xmin = newxticks.min()
    xmax = newxticks.max()
    xxmarque = N.linspace(xmin, xmax, len(xx))
    xticks(xx, xxmarque, color='k')

def change_yticks(ax, newyticks):
    """
    Routine pour changer les yticks d'un plot

    @param ax: l'axe matplotlib
    @type ax: matplotlib.axes.AxesSubplot
    @param newyticks: le nouveau yticks qu'on veut mettre
    @type newyticks: numpy array
    """
    yy = ax.get_yticks()
    ymin = newyticks.min()
    ymax = newyticks.max()
    yymarque = N.linspace(ymin, ymax, len(yy))
    yticks(yy, yymarque, color='k')

def cmap_map(function, cmap):
    """
    Routine permettant d'appliquer une contion sur une colormap
    matplotlib (par exemple inverser la colormap)

        >>> func = lambda x: x[::-1]
        >>> cmap = P.get_cmap('RdBu')
        >>> cmPerso = cmap_map(func, cmap)

    @param function: la fonction de modification que l'on souhaite
                     appliquer a la colormap d'origine
    @type function: function
    @param cmap: la colormap d'origine
    @type cmap: matplotlib colormap
    @return: une nouvelle colormap
    @rtype: matplotlib colormap
    """
    cdict = cmap._segmentdata
    step_dict = {}
    # Firt get the list of points where the segments start or end
    for key in ('red', 'green', 'blue'):
        step_dict[key] = map(lambda x:x[0], cdict[key])
    step_list = reduce(lambda x, y: x+y, step_dict.values())
    step_list = N.sort(N.array(list(set(step_list))))
    # Then compute the LUT, and apply the function to the LUT
    reduced_cmap = lambda step : N.array(cmap(step)[0:3])
    old_LUT = N.array(map( reduced_cmap, step_list))
    new_LUT = N.transpose(N.array(map( function, N.transpose(old_LUT))))
    # Now try to make a minimal segment definition of the new LUT
    cdict = {}
    for i,key in enumerate(('red','green','blue')):
        this_cdict = {}
        for j,step in enumerate(step_list):
            if step in step_dict[key]:
                this_cdict[step] = new_LUT[j,i]
            elif new_LUT[j,i]!=old_LUT[j,i]:
                this_cdict[step] = new_LUT[j,i]
        colorvector =  map(lambda x: x + (x[1], ), this_cdict.items())
        colorvector.sort()
        cdict[key] = colorvector

    return matplotlib.colors.LinearSegmentedColormap('colormap',cdict,1024)

def hfluct2d(a, axis=1):
    """
    Calcule les fluctuations d'un champ 2D par rapport a la moyenne
    horizontale avec axis=0: % 1er indice, axis=1: % 2nd indice

        >>> x = hfluct2d(var.uz)

    @param a: le tableau d'entree
    @type a: numpy array
    @param axis: l'axe par rapport auquel on souhaite calculer les
                 fluctuations
    @type axis: integer
    @return: le tableau de fluctutations
    @rtype: numpy array
    """
    hav = a.mean(axis=axis)
    ap = N.zeros_like(a)
    if axis == 0:
        ap = a-hav
    else:
        ap = a.transpose()-hav
        ap = ap.transpose()
    return ap

def hfluct3dsph(a, sinth, dth, dphi):
    """
    Calcule les fluctuations d'un champ 3D dans le cas spherique

    @param a: le tableau d'entree
    @type a: numpy array
    @param sinth: M{sin} S{theta}
    @type sinth: float
    @param dth: d S{theta}
    @type dth: float
    @param dphi: d S{phi}
    @type dphi: float
    @return: le tableau de fluctutations
    @rtype: numpy array
    """
    angav = dth*dphi/(4.*N.pi) * ((sinth*a).sum(axis=0)).sum(axis=0)
    fluc = N.zeros_like(a)
    fluc = a-angav
    return fluc

def nligne(fichier):
    """
    Donne le nombre de lignes d'un fichiers

        >>> n = nligne('data/time_series.dat')

    @param fichier: le nom du fichier
    @type fichier: string
    @return: le nombre de lignes
    @rtype: integer
    """
    file = open(fichier, 'r')
    nl = len(file.readlines())
    file.close()
    return nl

def step(x, x0, width):
    """
    Smooth unit step function centred at x0; implemented as tanh profile

    @param x: tableau d'entree
    @type x: numpy array
    @param x0: la valeur centrale du step
    @type x0: float
    @param width: largeur du step
    @type width: float
    @return: le tableau d'entree fois le step
    @rtype: numpy array
    """
    tini = 1.e-10
    return 0.5*(1+N.tanh((x-x0)/(width+tini)))

def ou1(tab, val):
    """
    Routine qui dit ou un tableau est egal a une certaine valeur.

        >>> x = N.linspace(0., 1., 100)
        >>> ind = ou1(x, 0.28)

    @param tab: le tableau d'entree
    @type tab: numpy array
    @param val: la valeur a laquelle on veut comparer
    @type val: float
    @return: l'indice le plus proche
    @rtype: integer
    """
    a = N.abs(tab[:]-val)
    return N.argmin(a)

def haver(a):
    """
    Horizontal average of a 3-D array
    
    @note: in Python, the PC arrays are ordered like (nz,ny,nx)
    
    @param a: le tableau d'entree
    @type a: numpy array
    @return: la moyenne horizontale du tableau
    @rtype: numpy array
    """
    return a.sum(axis=2).sum(axis=1)/(a.shape[1]*a.shape[2])

def plot_circle(rad, ls='--', color=None, half=False):
    """
    plot_circle(rad, ls='--', color=None, half=False):
    Trace un cercle de rayon donne

    @param rad: le rayon du cercle
    @type rad: float
    @param ls: linestyle du cercle
    @type rad: string
    @param color: couleur du cercle
    @type color: string
    @param half: moitie du cercle ou pas
    @type half: logical
    @return: matplotlib figure
    """
    if (color is None): color='black'
    if half:
        an = matplotlib.pylab.linspace(0, N.pi, 100)
    else:
        an = matplotlib.pylab.linspace(0, 2*N.pi, 100)
    matplotlib.pylab.plot(rad*N.cos(an), rad*N.sin(an), 
    color=color, linestyle=ls)

def intersect(list1, list2):
    """
    Intersection entre 2 listes

    @param list1: premiere liste
    @type list1: list
    @param list2: premiere liste
    @type list2: list
    @return: l'intersection entre les deux listes
    """

    inter = filter(lambda x:x in list1,list2)
    return inter

def normalize(a, b):
    """
    16-mai-2009/dintrans: coded
    Normalization of two vectors a and b (useful to plot velocity
    fields)
    """

    c=a**2+b**2 ; norm=N.sqrt(c.max())
    a=a/norm
    b=b/norm
    return a, b

