import numpy     as np
import cosmology as co
from scipy.interpolate import interp1d 
from scipy.special     import jv 

# arc minutes per radian: (60 arcmin/deg) * (180 deg) / pi
arcmin_per_rad = 10800. / np.pi
    
def m200m_to_m200c(m200m,z):
    """Convert M200m to M200c using the power-law scaling adopted for websky.

    Args:
        m200m: halo mass M200m
        z:     redshift

    Returns:
        m200c = Omega_m(z)**0.35 * m200m, where Omega_m(z) is evaluated from
        co.omegam with the remaining density fraction (1 - co.omegam) held
        constant in redshift.
    """
    # matter term co.omegam * (1+z)^3, relative to the total
    matter = co.omegam*(1+z)**3
    omegamz = matter/(matter+1-co.omegam)
    return omegamz**0.35 * m200m

def gettable(tablefile):
    """Read the binary profile table from disk.

    The file is read sequentially as raw values in this order:
        3 int32                  -> n      (grid dimensions n[0], n[1], n[2])
        6 float32                -> bounds
        n[0]*n[1]*n[2] float32   -> table
        2*n[0]*n[1]*n[2] float32 -> skipped (not used here)
        n[0]*n[1] float32        -> norms

    Args:
        tablefile: path of the binary table file

    Returns:
        n:      int32 array with the 3 grid dimensions
        bounds: float32 array with the 6 bound values
        table:  float32 array of shape (n[0], n[1], n[2])
        norms:  float32 array of shape (n[0], n[1])
    """

    # binary mode + context manager: the data are raw bytes, and the handle
    # is released even if a read fails
    with open(tablefile, 'rb') as f:

        n      = np.fromfile(f,count=3,dtype=np.int32)
        bounds = np.fromfile(f,count=6,dtype=np.float32)

        # use Python ints for the item counts so the products cannot
        # overflow int32 for large grids
        n0, n1, n2 = int(n[0]), int(n[1]), int(n[2])
        ncells = n0 * n1 * n2

        table  = np.fromfile(f,count=ncells,dtype=np.float32)

        # two table-sized blocks sit between the table and the norms;
        # they are read only to advance the file position
        np.fromfile(f,count=2*ncells,dtype=np.float32)

        norms  = np.fromfile(f,count=n0*n1,dtype=np.float32)

    # np.resize (rather than reshape) is kept so that a file holding fewer
    # values than expected is handled exactly as before
    table = np.resize(table,(n0,n1,n2))
    norms = np.resize(norms,(n0,n1))

    return n, bounds, table, norms

# load the tabulated profiles once, when the module is imported
n, bounds, table, norms = gettable('proftab_websky.bin')

# grid dimensions along the (chi, halo mass, radius) axes of the table
nchit, nmht, nrt = n[0], n[1], n[2]

# table limits, stored as (min, max) pairs for chi, halo mass and radius
chimint, chimaxt = bounds[0], bounds[1]
mhmint,  mhmaxt  = bounds[2], bounds[3]
rmint,   rmaxt   = bounds[4], bounds[5]

# logarithmic grid spacing along each table axis
dchit = (np.log(chimaxt) - np.log(chimint)) / (nchit - 1)
dmht  = (np.log(mhmaxt)  - np.log(mhmint))  / (nmht - 1)
drt   = (np.log(rmaxt)   - np.log(rmint))   / (nrt - 1)

def y(theta,m200c,z):

    """Compton-y profile.

    Args:
        theta (float): angular distance from halo center in arc minutes
        m200c (float): mass of halo, M200c, in Msun
        z (float):     redshift of halo

    Returns:
        y (float):    Compton-y parameter
        ybar (float): sky-averaged Compton-y parameter (doesn't depend on theta)

    Both values are 0.0 when the projected radius chi*sin(theta) is at or
    beyond the table maximum, or chi is at or beyond the table maximum.
    Inputs outside the remaining table edges are clamped to the nearest
    edge. Inputs must be scalars (the function branches on comparisons and
    converts grid positions with int()).

    websky tsz map is obtained by summing the contribution from each halo

    the contribution from a halo, h, is obtained first by evaluating this 
    function to assign a value y_cen_i_h at each pixel, i, based on the 
    angular separation from the halo center to the pixel center

    after obtaining y_cen_i_h for all pixels overlapping the halo, the 
    values y_cen_i_h are multiplied by an overall factor W_norm_h,
        y_i_h = W_norm_h * y_cen_i_h
    such that the sky-averaged y value in the map from the halo obtained by
    assuming a top-hat window function, 
        ybar_h = ( sum_i [ y_i_h ] ) * omega_pix / (4pi) 
    is equal to the value in the table, i.e.
        ybar_h == y(0.0,m200c_h,z_h)[1]
    where the first argument is arbitrary, m200c_h is M200c for halo h, 
    and z_h is redshift for halo h. 

    """

    mh = m200c
    theta = 2 * np.pi * theta / 60. / 360. # arcmin to rad
    chi = co.chiofz(z)
    r   = chi * np.sin(theta)

    # no contribution at or beyond the tabulated radius / distance range
    if(r>=rmaxt or chi>=chimaxt): return 0.0, 0.0

    # pull values at or below the lower table edges back onto the table
    if(   r <= rmint   ):   r =   rmint + 1e-5
    if( chi <= chimint ): chi = chimint + 1e-5
    if(  mh <= mhmint  ):  mh =  mhmint + 1e-5

    if(  mh >= mhmaxt  ):  mh =  mhmaxt - 1e-5

    lnr   = np.log(r)
    lnchi = np.log(chi)
    lnmh  = np.log(mh)

    # lower corner of the grid cell containing the point (log-spaced axes)
    ir   = int( (   lnr -   np.log(rmint) ) / drt   )
    ichi = int( ( lnchi - np.log(chimint) ) / dchit )
    imh  = int( (  lnmh -  np.log(mhmint) ) / dmht  )

    # keep the lower corner at most n-2 so the +1 neighbours used below
    # always exist; the 1e-5 offset above is lost to rounding for masses of
    # order 1e12 Msun or more, so a mass at the table maximum would
    # otherwise select the last grid point and index past the table.
    # The fractional offsets are clamped to [0,1] below, so such a point
    # evaluates to the table value on the upper edge.
    ir   = min( max(  ir, 0 ),   int(nrt) - 2 )
    ichi = min( max(ichi, 0 ), int(nchit) - 2 )
    imh  = min( max( imh, 0 ),  int(nmht) - 2 )

    # fractional position inside the cell along each axis
    fc = lnchi - ( np.log(chimint) + (ichi ) * dchit )
    fm = lnmh  - (  np.log(mhmint) + ( imh ) *  dmht )
    fr = lnr   - (   np.log(rmint) + (  ir ) *   drt )

    fc /= dchit
    fm /= dmht
    fr /= drt

    if(fr<0): fr=0
    if(fr>1): fr=1
    if(fc<0): fc=0
    if(fc>1): fc=1
    if(fm<0): fm=0
    if(fm>1): fm=1

    # trilinear interpolation of the profile table in (ln chi, ln M, ln r)
    yprof = (
         table[ichi,  imh,  ir  ] * (1-fr) * (1-fm) * (1-fc) +
         table[ichi,  imh,  ir+1] * (  fr) * (1-fm) * (1-fc) +
         table[ichi,  imh+1,ir  ] * (1-fr) * (  fm) * (1-fc) +
         table[ichi,  imh+1,ir+1] * (  fr) * (  fm) * (1-fc) +
         table[ichi+1,imh,  ir  ] * (1-fr) * (1-fm) * (  fc) +
         table[ichi+1,imh,  ir+1] * (  fr) * (1-fm) * (  fc) +
         table[ichi+1,imh+1,ir  ] * (1-fr) * (  fm) * (  fc) +
         table[ichi+1,imh+1,ir+1] * (  fr) * (  fm) * (  fc)
         )

    # bilinear interpolation of the normalization table in (ln chi, ln M)
    ybar = (
        norms[ichi,  imh  ] * (1-fm) * (1-fc) +
        norms[ichi,  imh+1] * (  fm) * (1-fc) +
        norms[ichi+1,imh  ] * (1-fm) * (  fc) +
        norms[ichi+1,imh+1] * (  fm) * (  fc) 
    )
    return yprof, ybar 
    
def powerspectrum():
    """Integrate the tabulated y-profiles over mass and distance.

    For every cell of a logarithmic grid in comoving distance chi and halo
    mass M200m, the y-profile is sampled on a logarithmic grid in angle,
    transformed with an order-0 Hankel transform, squared, and accumulated
    weighted by the number of halos per steradian in the cell. The mean y
    is accumulated in the same loop, once from the sampled profile and once
    from the tabulated normalization.

    Requires rancat.hmf_websky for the halo mass function (see the comment
    in the body for how to install or replace it).

    Returns:
        ellvals: multipoles, nell log-spaced values from ellmin to ellmax
        dl:      cls * 1e12 / (2 pi) * ell * (ell + 1) at those multipoles
        ybar:    mean y from integrating the sampled profiles
        ybartab: mean y from the tabulated per-halo values
    """

    # parameters of M-z integration 
    nchi   = 100   # enough to converge?
    nM     = 100   # enough to converge?
    chimin = 50.0
    chimax = 8e3
    Mmin   = 1.3e12 # approximately websky minimum M200m value
    Mmax   = 1e16
    dlnchi = (np.log(chimax) - np.log(chimin)) / (nchi-1)
    dlnM   = (np.log(  Mmax) - np.log(  Mmin)) / (nM-1)

    # angular power spectrum multipoles
    nell    = 25
    ellmin  = 10
    ellmax  = 1e4
    ellvals = np.logspace(np.log10(ellmin),np.log10(ellmax),nell)
    cls     = np.zeros(nell)
    
    # mean y 
    ybar    = 0.0
    ybartab = 0.0    
    
    # y-profile discretization 
    thetamin = 0.01 / arcmin_per_rad # enough to converge?
    ratio    = 1.1                  # ratio of successive theta values
    dlntheta = np.log(ratio)        # enough to converge?
        
    # create function dn/dM(M,z) with input M200m in Msun, and return units 
    # 1/Mpc^3/Msun; here we use Tinker et al. (2008) with Websky cosmology 
    # to use rancat.hmf_websky do:
    # 
    #   git clone https://github.com/marcelo-alvarez/rancat
    #   cd rancat
    #   pip install .
    # 
    # otherwise assign dndmofmz with a dn/dM(M,z) function using same arguments, 
    # units, and return value
    import rancat.hmf_websky as hmfw
    dndmofmz = hmfw.dndmofmz_tinker(Mmin/1.1,             Mmax*1.1,
                       co.zofchi(chimin)/1.1, co.zofchi(chimax)*1.1)
    
    # initialize 1d arrays
    chi1d    = np.exp(np.linspace(np.log(chimin),np.log(chimax),nchi))
    M1d      = np.exp(np.linspace(np.log(  Mmin),np.log(  Mmax),nM  ))

    # np.trapz was renamed np.trapezoid in NumPy 2.0; use whichever exists
    trapezoid = getattr(np, 'trapezoid', None)
    if trapezoid is None:
        trapezoid = np.trapz

    # loop invariants of the profile sampling
    theta_start = thetamin * arcmin_per_rad # rad to arcmin for y-profile function
    theta_step  = np.exp(dlntheta)

    # NOTE(review): the range stops before i=0, so the nearest shell
    # chi1d[0] is never included -- confirm this is intended
    for i in range(nchi-1,0,-1):      
        chi  = chi1d[i]
        z    = co.zofchi(chi)
        if ybar > 0: 
            print('z: ',z,i)        
        # volume of the full-sky shell spanning half a log step either side
        chip = chi*np.exp(dlnchi/2.)
        chim = chi*np.exp(-dlnchi/2.)
        dV   = 4 * np.pi * (chip**3 - chim**3) / 3.   
        for j in range(nM): 
            M         = M1d[j]
            M200c     = m200m_to_m200c(M,z)
            dndM      = dndmofmz(M,z)[0]
            n_per_str = dV * dndM * M * dlnM / (4.*np.pi)           
            
            ytilde = np.zeros(len(ellvals)) # Hankel transform of y-profile

            # sample the profile on a log grid in theta until it drops to
            # zero (slow); plain lists avoid re-copying arrays on every step
            yvals  = []
            thetas = []
            theta  = theta_start
            ycur, ycurbartab = y(theta,M200c,z)
            # keep the tabulated ybar from the first evaluation: it does not
            # depend on theta, whereas the call that ends the loop returns
            # ybar = 0 when it stops at the table's maximum radius
            ybartab_halo = ycurbartab
            while ycur>0:
                yvals.append(ycur)
                thetas.append(theta)
                theta *= theta_step
                ycur, ycurbartab = y(theta,M200c,z)

            if yvals:
                y1d     = np.asarray(yvals,  dtype=float)
                theta1d = np.asarray(thetas, dtype=float)
                theta1d = theta1d / arcmin_per_rad # back to rad for Hankel transform
                # integrand of the solid-angle integral of y per unit ln(theta)
                dydtheta = 2*np.pi*y1d*np.sin(theta1d)*theta1d # y1d*theta1d
                ycurbar  = trapezoid(dydtheta,np.log(theta1d)) 
                ybar    += ycurbar      * n_per_str
                ybartab += ybartab_halo * n_per_str                
                # Hankel transform of order 0, summed over all theta samples
                # at once: rows are multipoles, columns are theta samples
                bessel = jv(0,np.outer(ellvals,theta1d))
                ytilde = 2 * np.pi * dlntheta * bessel.dot(theta1d**2 * y1d)

            cls  += ytilde**2 * n_per_str 
    
    return ellvals, cls*1e12/2/np.pi*ellvals*(ellvals+1), ybar, ybartab

