import math
import numpy as np
import matplotlib.pyplot as plt
import datetime
#from mpl_toolkits.basemap import Basemap, cm
from scipy.signal import savgol_filter
import scipy.ndimage as sp
# requires netcdf4-python (netcdf4-python.googlecode.com)
from netCDF4 import Dataset as NetCDFFile
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.patches as patches
from PIL import Image
import cartopy
import datetime
# Start from a clean slate: close any figures left over from a previous run.
plt.close('all')
# Execute a helper script in this namespace. ``runfile`` is not defined in this
# file; presumably it is the Spyder/IPython console helper, so this script is
# meant to be run from that environment. importtools.py presumably provides
# names used below but not defined here (e.g. my2dcmap, getseasonalmean, mk,
# stats) -- TODO confirm.
runfile('/global/homes/s/smhagos/tools/importtools.py', wdir='/global/homes/s/smhagos/scripts/climsec')
#matplotlib qt



# Calendar lookup for the 35 years 1980-2014, one entry per month (420 total):
# entry k is month ``months[k]`` (1-12) of year ``years[k]``. The sub-sampling
# functions below select a year's entries with np.where(years == iyear).
# np.tile/np.repeat build the same flat arrays as the former nested lists +
# reshape; that version relied on np.product, which NumPy 2.0 removed.
months = np.tile(np.arange(1, 13), 35)
years = np.repeat(np.arange(1980, 1980 + 35), 12)



def getpr_hist(experiment):
    """Return the all-record mean precipitation of one CMIP6 historical file.

    Parameters
    ----------
    experiment : str
        File-name suffix appended to ``.../CMIP6DATA/historical/pr_``.

    Returns
    -------
    meanprecip, lat, lon2
        ``pr`` multiplied by 86400 and averaged over axis 0, then rolled along
        axis 1 by half the number of longitudes; the file's ``lat``; and the
        equally rolled ``lon`` with values above 180 shifted by -360, so a
        0..360 grid becomes -180..180.
    """
    filem = '/global/cfs/cdirs/cpmmjo/CLIMSEC/CMIP6DATA/historical/pr_'+experiment
    print(filem)
    # Close the file as soon as the variables are read ([:] loads them into memory).
    with NetCDFFile(filem) as ncm:
        # 86400 s per day; presumably converts kg m-2 s-1 to mm/day.
        precip = 86400.0*ncm.variables['pr'][:]
        lat = ncm.variables['lat'][:]
        lon = ncm.variables['lon'][:]
    # Half the longitude count (180 for a 360-point grid, the value previously
    # hard-coded) moves the 0..360 seam to the dateline.
    shift = lon.shape[0] // 2
    meanprecip = np.roll(np.mean(precip, axis=0), shift, axis=1)
    lon2 = np.roll(lon, shift)
    lon2[lon2 > 180] -= 360
    return meanprecip, lat, lon2

def getpr_ssp585(experiment, start=420, stop=600):
    """Return the mean SSP5-8.5 precipitation over a window of records.

    Parameters
    ----------
    experiment : str
        File-name suffix appended to ``.../CMIP6DATA/ssp585/pr_``.
    start, stop : int, optional
        Records ``[start:stop]`` along axis 0 are averaged (default 420:600,
        the previously hard-coded window). Assuming one record per month
        starting 2015-01, as the ``201501-206412`` file names used with this
        function indicate, the default covers 2050-01 through 2064-12.
        NOTE(review): the figure's colorbar label in this script says
        2040-2064 -- confirm which window is intended.

    Returns
    -------
    meanprecip, lat, lon2
        The windowed mean of ``pr`` times 86400, rolled along axis 1 by half
        the longitude count; the file's ``lat``; and the equally rolled
        ``lon`` with values above 180 shifted by -360 (-180..180 grid).
    """
    filem = '/global/cfs/cdirs/cpmmjo/CLIMSEC/CMIP6DATA/ssp585/pr_'+experiment
    print(filem)
    # Read only the requested records, then close the file.
    with NetCDFFile(filem) as ncm:
        # 86400 s per day; presumably converts kg m-2 s-1 to mm/day.
        precip = 86400.0*ncm.variables['pr'][start:stop]
        lat = ncm.variables['lat'][:]
        lon = ncm.variables['lon'][:]
    # Half the longitude count (180 for a 360-point grid, the value previously
    # hard-coded) moves the 0..360 seam to the dateline.
    shift = lon.shape[0] // 2
    meanprecip = np.roll(np.mean(precip, axis=0), shift, axis=1)
    lon2 = np.roll(lon, shift)
    lon2[lon2 > 180] -= 360
    return meanprecip, lat, lon2




# Mean precipitation for five E3SM-1-0 ensemble members (r1..r5).
# Historical: mean over every record of the 2000-2014 files.
hist1,lat,lon = getpr_hist("Amon_E3SM-1-0_historical_r1i1p1f1_gr_200001-201412.nc")
hist2,lat,lon = getpr_hist("Amon_E3SM-1-0_historical_r2i1p1f1_gr_200001-201412.nc")
hist3,lat,lon = getpr_hist("Amon_E3SM-1-0_historical_r3i1p1f1_gr_200001-201412.nc")
hist4,lat,lon = getpr_hist("Amon_E3SM-1-0_historical_r4i1p1f1_gr_200001-201412.nc")
hist5,lat,lon = getpr_hist("Amon_E3SM-1-0_historical_r5i1p1f1_gr_200001-201412.nc")

# SSP5-8.5: mean over the record window selected inside getpr_ssp585.
ssp5851,lat,lon = getpr_ssp585('Amon_E3SM-1-0_ssp585_r1i1p1f1_gr_201501-206412.nc')
ssp5852,lat,lon = getpr_ssp585('Amon_E3SM-1-0_ssp585_r2i1p1f1_gr_201501-206412.nc')
ssp5853,lat,lon = getpr_ssp585('Amon_E3SM-1-0_ssp585_r3i1p1f1_gr_201501-206412.nc')
ssp5854,lat,lon = getpr_ssp585('Amon_E3SM-1-0_ssp585_r4i1p1f1_gr_201501-206412.nc')
ssp5855,lat,lon = getpr_ssp585('Amon_E3SM-1-0_ssp585_r5i1p1f1_gr_201501-206412.nc')

# Ensemble means. lat/lon are those of the last file read; all members are
# assumed to share one grid.
precip_hist = np.mean([hist1, hist2, hist3, hist4, hist5], axis=0)
precip_ssp585 = np.mean([ssp5851, ssp5852, ssp5853, ssp5854, ssp5855], axis=0)
# Filled-contour levels -0.95, -0.90, ..., 0.85 (37 levels).
clevs = [0.05 * x for x in range(-19, 18)]

plt.close('all')
# Create the cartopy GeoAxes directly rather than hiding a plain Axes and
# stacking a second projection axes on top of it.
fig, ax = plt.subplots(
    figsize=(9, 6),
    subplot_kw={'projection': cartopy.crs.PlateCarree(central_longitude=0)},
)
# lon/lat are geographic degrees; state the data CRS explicitly.
cs = ax.contourf(lon, lat, precip_ssp585 - precip_hist, clevs, cmap=my2dcmap,
                 transform=cartopy.crs.PlateCarree())
ax.add_feature(cartopy.feature.LAND)
ax.add_feature(cartopy.feature.OCEAN)
ax.add_feature(cartopy.feature.COASTLINE)
ax.add_feature(cartopy.feature.BORDERS)
ax.add_feature(cartopy.feature.LAKES, alpha=0.5)
ax.add_feature(cartopy.feature.RIVERS)
# Ticks, limits, labels and colorbar are all set before the single
# plt.show() at the end; showing earlier would display an unfinished map
# on a blocking backend.
plt.xticks([-20,0,20,40,60], ['20W','0','20E','40E','60E'],fontsize='18')
plt.yticks([0,20,40,60], ['EQ','20N','40N','60N'],fontsize='18')
ax.tick_params(axis='both', which='major', labelsize=18)
plt.axis([-25, 60,  10,60])
plt.xlabel('Longitude',size=16)
plt.ylabel('Latitude',size=18)
fig.canvas.draw()
cbar = plt.colorbar(cs, orientation="horizontal")
# NOTE(review): getpr_ssp585 averages records 420:600 by default, i.e.
# 2050-2064 for a monthly file starting 2015-01, but this label says
# 2040-2064 -- confirm which window is intended.
cbar.set_label('2040-2064 minus 2000-2014 annual mean precip difference (mm/day)',size=16)
cbar.ax.tick_params(labelsize=16)
plt.title( 'Precipitation change (E3SM-V1 SSP585 minus Historical)',size=18)
plt.show()
# NOTE(review): figurename is never used (no savefig call) and "ssp282" does
# not match the SSP585 data plotted -- confirm before saving with it.
figurename = '/global/cfs/cdirs/cpmmjo/CLIMSEC/figures/precip_ssp282_hist_diff4064'






def getsoilm(year,month,experiment):
    """Read ``H2OSOI`` and its ``lat``/``lon`` from one monthly ELM history file.

    Parameters
    ----------
    year : int-like
        Calendar year, formatted with ``str``.
    month : int-like or str
        Calendar month, zero-padded to two characters for the file name.
    experiment : str
        Run name, used both as the sub-directory and as the file prefix.

    Returns
    -------
    soilm, lat, lon
        The full ``H2OSOI`` variable (presumably soil water content -- confirm
        units) and the coordinate variables, all read into memory.
    """
    smonth = str(month).zfill(2)
    filem = ('/global/project/projectdirs/m1867/smhagos/' + experiment + '/'
             + experiment + '.elm.h0.' + str(year) + '-' + smonth + '.nc')
    print(filem)
    # Close the file once the variables are read ([:] loads them into memory).
    with NetCDFFile(filem) as ncm:
        soilm = ncm.variables['H2OSOI'][:]
        lat = ncm.variables['lat'][:]
        lon = ncm.variables['lon'][:]
    return soilm, lat, lon

# Join the monthly fields of one experiment for a set of calendar indices.
def subsample(index,exiperimentin):
    """Concatenate ``getseasonalmean`` fields for the given calendar positions.

    For each ``i`` in ``index``, ``getseasonalmean(years[i], months[i], ...)``
    is called and the returned fields are joined along axis 0. No averaging
    is done here. ``getseasonalmean`` is not defined in this file --
    presumably it comes from the helper script run at the top; its fields
    presumably already carry a leading time axis (TODO confirm).

    Parameters
    ----------
    index : iterable of int
        Positions in the module-level ``years``/``months`` calendar arrays.
    exiperimentin : str
        Experiment (run) name passed through to ``getseasonalmean``.

    Returns
    -------
    precipts, lat, lon
        The concatenated fields and the lat/lon returned by the last read.

    Raises
    ------
    ValueError
        If ``index`` is empty.
    """
    fields = []
    for i in index:
        precipin, lat, lon = getseasonalmean(years[i], months[i], exiperimentin)
        fields.append(precipin)
    if not fields:
        raise ValueError('subsample: index selects no months')
    # np.ma.concatenate keeps the masks of masked inputs. The former repeated
    # np.append goes through np.concatenate, which drops them and exposes the
    # underlying fill values.
    return np.ma.concatenate(fields, axis=0), lat, lon

def subsamplesoil(index,exiperimentin):
    """Concatenate ``getsoilm`` fields for the given calendar positions.

    For each ``i`` in ``index``, ``getsoilm(years[i], months[i], ...)`` reads
    one monthly file and the ``H2OSOI`` arrays are joined along axis 0.

    Parameters
    ----------
    index : iterable of int
        Positions in the module-level ``years``/``months`` calendar arrays.
    exiperimentin : str
        Experiment (run) name passed through to ``getsoilm``.

    Returns
    -------
    soilmts, lat, lon
        The concatenated fields and the lat/lon of the last file read.

    Raises
    ------
    ValueError
        If ``index`` is empty.
    """
    fields = []
    for i in index:
        soilmin, lat, lon = getsoilm(years[i], months[i], exiperimentin)
        fields.append(soilmin)
    if not fields:
        raise ValueError('subsamplesoil: index selects no months')
    # np.ma.concatenate keeps masks of masked netCDF reads. The former repeated
    # np.append goes through np.concatenate, which drops them and exposes the
    # underlying fill values.
    return np.ma.concatenate(fields, axis=0), lat, lon

def gettrend(experiment):
    """Stack each year's fields (via ``subsample``) for 1980-2014.

    For every year, the matching entries of the module-level ``years``
    calendar are passed to ``subsample``. The per-year arrays are stacked on
    a new leading axis, so axis 0 of the result is the year (35 entries).
    The year is printed as a progress indicator.
    """
    per_year = []
    for iyear in range(1980, 2015):
        index = np.where(years == iyear)[0]
        precipy, lat, lon = subsample(index, experiment)
        print(iyear)
        per_year.append(precipy)
    # np.ma.stack keeps masks; the former expand_dims + np.append dropped them
    # and re-copied the growing array every year.
    return np.ma.stack(per_year, axis=0)







def gettrendsoilm(experiment):
    """Stack each year's soil-moisture fields (via ``subsamplesoil``) for 1980-2014.

    For every year, the matching entries of the module-level ``years``
    calendar are passed to ``subsamplesoil``. The per-year arrays are stacked
    on a new leading axis, so axis 0 of the result is the year (35 entries).
    The year is printed as a progress indicator.
    """
    per_year = []
    for iyear in range(1980, 2015):
        index = np.where(years == iyear)[0]
        soily, lat, lon = subsamplesoil(index, experiment)
        print(iyear)
        per_year.append(soily)
    # np.ma.stack keeps masks; the former expand_dims + np.append dropped them
    # and re-copied the growing array every year.
    return np.ma.stack(per_year, axis=0)

def gettrendsst(experiment):
    """Stack each year's ``TS`` fields (via ``subsamplesst``) for 1980-2014.

    For every year, the matching entries of the module-level ``years``
    calendar are passed to ``subsamplesst``. The per-year arrays are stacked
    on a new leading axis, so axis 0 of the result is the year (35 entries).
    The year is printed as a progress indicator.
    """
    per_year = []
    for iyear in range(1980, 2015):
        index = np.where(years == iyear)[0]
        ssty, lat, lon = subsamplesst(index, experiment)
        print(iyear)
        per_year.append(ssty)
    # np.ma.stack keeps masks; the former expand_dims + np.append dropped them
    # and re-copied the growing array every year.
    return np.ma.stack(per_year, axis=0)
def getsst(year,month,experiment):
    """Read ``TS`` and its ``lat``/``lon`` from one monthly EAM history file.

    Parameters
    ----------
    year : int-like
        Calendar year, formatted with ``str``.
    month : int-like or str
        Calendar month, zero-padded to two characters for the file name.
    experiment : str
        Run name, used both as the sub-directory and as the file prefix.

    Returns
    -------
    sst, lat, lon
        The full ``TS`` variable (presumably surface temperature -- confirm)
        and the coordinate variables, all read into memory.
    """
    smonth = str(month).zfill(2)
    filem = ('/global/project/projectdirs/m1867/smhagos/' + experiment + '/'
             + experiment + '.eam.h0.' + str(year) + '-' + smonth + '.nc')
    print(filem)
    # Close the file once the variables are read ([:] loads them into memory).
    with NetCDFFile(filem) as ncm:
        sst = ncm.variables['TS'][:]
        lat = ncm.variables['lat'][:]
        lon = ncm.variables['lon'][:]
    return sst, lat, lon

def subsamplesst(index,exiperimentin):
    """Concatenate ``getsst`` fields for the given calendar positions.

    For each ``i`` in ``index``, ``getsst(years[i], months[i], ...)`` reads
    one monthly file and the ``TS`` arrays are joined along axis 0.

    Parameters
    ----------
    index : iterable of int
        Positions in the module-level ``years``/``months`` calendar arrays.
    exiperimentin : str
        Experiment (run) name passed through to ``getsst``.

    Returns
    -------
    sstts, lat, lon
        The concatenated fields and the lat/lon of the last file read.

    Raises
    ------
    ValueError
        If ``index`` is empty.
    """
    fields = []
    for i in index:
        sstin, lat, lon = getsst(years[i], months[i], exiperimentin)
        fields.append(sstin)
    if not fields:
        raise ValueError('subsamplesst: index selects no months')
    # np.ma.concatenate keeps masks of masked netCDF reads. The former repeated
    # np.append goes through np.concatenate, which drops them.
    return np.ma.concatenate(fields, axis=0), lat, lon


def getslope(datain):
    """Mann-Kendall slope at each grid point of the mean over months 2..4.

    Parameters
    ----------
    datain : array, 4-D
        Indexed as ``datain[:, 2:5, :, :]``; presumably (year, month, lat, lon)
        as built by ``gettrend`` -- TODO confirm.

    Returns
    -------
    slope : array, shape (nlat, nlon)
        Element 7 of ``mk.original_test`` where its element 2 is below 0.5,
        else 0.0. ``mk`` is not defined in this file; presumably it is
        pymannkendall, where those elements are the slope and the p-value.
    """
    # Average the entries at indices 2, 3, 4 of axis 1 (Mar-May if that axis is
    # Jan..Dec) to get one value per year. The former code kept the 4-D slice
    # and unpacked its shape into three names, which always raised ValueError.
    mdata = np.mean(datain[:, 2:5, :, :], axis=1)
    print(np.shape(mdata))
    nt, nlat, nlon = np.shape(mdata)
    # Copy keeps the input's array type/dtype; then reset every point to 0.
    slope = mdata[0, :, :].copy()
    slope[:, :] = 0.0
    for i in range(nlat):
        for j in range(nlon):
            test = mk.original_test(mdata[:, i, j])
            # NOTE(review): 0.5 is a very loose significance threshold
            # (0.05 intended?) -- confirm.
            if test[2] < 0.5:
                slope[i, j] = test[7]
    return slope


def afrslope(mdata):
    """Mann-Kendall slope at each grid point of a 3-D (time, lat, lon) array.

    ``mk.original_test`` runs on each point's series along axis 0. Where
    element 2 of its result is below 0.5, element 7 is stored; elsewhere the
    value is 0.0. ``mk`` is not defined in this file -- presumably it is
    pymannkendall, where those elements are the p-value and the slope.
    The shape of ``mdata`` is printed first.
    """
    _, n_lat, n_lon = np.shape(mdata)
    print(np.shape(mdata))
    # Start from a copy of the first time slice so array type/dtype match.
    out = mdata[0, :, :].copy()
    out[:, :] = 0.0
    for row, col in np.ndindex(n_lat, n_lon):
        result = mk.original_test(mdata[:, row, col])
        if result[2] < 0.5:
            out[row, col] = result[7]
    return out


def gettest(data1, data2, alpha=0.20):
    """Mean difference ``data1 - data2`` where a two-sample t-test is significant.

    Parameters
    ----------
    data1, data2 : array, 4-D
        Samples along axis 0; the remaining axes are (nm, nlat, nlon). Both
        need the same length along axis 0, because the mean of the
        element-wise difference is taken.
    alpha : float, optional
        A point is kept when the ``scipy.stats.ttest_ind`` p-value is strictly
        below ``alpha`` (default 0.20, the value previously hard-coded).

    Returns
    -------
    datadiff : array, shape (nm, nlat, nlon)
        Same array type and dtype as ``data1[0]``. It holds the mean over
        axis 0 of ``data1 - data2`` where significant, and 0.0 elsewhere
        (including where the p-value is NaN).

    Raises
    ------
    ValueError
        If ``data1`` is not 4-D.
    """
    # Local import: this file's imports do not include scipy.stats.
    from scipy import stats

    if np.ndim(data1) != 4:
        raise ValueError('gettest: data1 must be 4-D (sample, nm, nlat, nlon)')
    print(np.shape(data1))
    # One vectorised test along axis 0 replaces nm*nlat*nlon separate calls.
    pvalue = stats.ttest_ind(data1, data2, axis=0).pvalue
    meandiff = np.mean(data1 - data2, axis=0)
    datadiff = data1[0, :, :, :].copy()
    datadiff[...] = 0.0
    significant = np.asarray(pvalue < alpha)
    datadiff[significant] = meandiff[significant]
    return datadiff


def getseasonalmeangpcp(year,month):
    """Read ``precip`` from one monthly GPCP CDR v2.3 file.

    No seasonal averaging happens here, whatever the name suggests. The file
    for ``year``/``month`` is read and its ``precip`` field is returned,
    reversed along axis 0.

    Parameters
    ----------
    year : int-like
        Calendar year, formatted with ``str``.
    month : int-like or str
        Calendar month, zero-padded to two characters for the file name.

    Returns
    -------
    precip, latg, long
        ``precip`` reversed along axis 0, and the file's ``latitude`` and
        ``longitude`` variables in their stored order.
    """
    smonth = str(month).zfill(2)
    filem = '/global/project/projectdirs/m1867/smhagos/EAMONSOON/OBS/GPCP/gpcp_cdr_v23rB1_y'+str(year)+'_m'+smonth+'.nc'
    print(filem)
    # Close the file once the variables are read ([:] loads them into memory).
    with NetCDFFile(filem) as ncm:
        precip = ncm.variables['precip'][:]
        latg = ncm.variables['latitude'][:]
        lon_g = ncm.variables['longitude'][:]
    # NOTE(review): precip is reversed along axis 0 but latg is returned in
    # file order. If axis 0 is latitude, the returned field and latitudes run
    # in opposite directions -- confirm callers expect this.
    precip = np.flip(precip, axis=0)
    return precip, latg, lon_g



# Stack the GPCP monthly fields for a set of calendar indices.
def subsamplegpcp(index):
    """Stack ``getseasonalmeangpcp`` fields for the given calendar positions.

    For each ``i`` in ``index``, ``getseasonalmeangpcp(years[i], months[i])``
    reads one monthly file. The fields are stacked on a new leading axis, one
    entry per position. No averaging is done here.

    Parameters
    ----------
    index : iterable of int
        Positions in the module-level ``years``/``months`` calendar arrays.

    Returns
    -------
    precipts, lat, lon
        The stacked fields and the lat/lon of the last file read.

    Raises
    ------
    ValueError
        If ``index`` is empty.
    """
    fields = []
    for i in index:
        precipin, lat, lon = getseasonalmeangpcp(years[i], months[i])
        fields.append(precipin)
    if not fields:
        raise ValueError('subsamplegpcp: index selects no months')
    # np.ma.stack keeps masks; the former expand_dims + np.append dropped them.
    return np.ma.stack(fields, axis=0), lat, lon


# Stack the GPCP fields year by year.
def gettrendgpcp():
    """Stack each year's GPCP fields (via ``subsamplegpcp``) for 1980-2014.

    For every year, the matching entries of the module-level ``years``
    calendar are passed to ``subsamplegpcp``. The per-year arrays are stacked
    on a new leading axis, so axis 0 of the result is the year (35 entries).
    """
    per_year = []
    for iyear in range(1980, 2015):
        index = np.where(years == iyear)[0]
        precipy, lat, lon = subsamplegpcp(index)
        per_year.append(precipy)
    # np.ma.stack keeps masks; the former expand_dims + np.append dropped them
    # and re-copied the growing array every year.
    return np.ma.stack(per_year, axis=0)


def getlinear(x,trend):
    """Evaluate the least-squares straight line through (x, trend) at x.

    A degree-1 ``np.polyfit`` gives slope and intercept; the result is
    ``slope * x + intercept``.
    """
    coeffs = np.polyfit(x, trend, deg=1)
    slope, intercept = coeffs[0], coeffs[1]
    return slope * x + intercept






# ctlpcpmap = gettrend('EAMCONTROL01')   
# lndpcpmap = gettrend('EAMCLIMLAND02')




# gctl = np.mean(np.mean(ctlpcpmap,axis=1),axis=0)
# glnd = np.mean(np.mean(lndpcpmap,axis=1),axis=0)

# yearc = np.array(range(1980,2022))

# # define the box overwhich you want to average
# lonmin = 30
# lonmax = 50
# latmin = -10
# latmax = 10


# ilatmax = np.max(np.where(lat<=latmax))
# ilatmin = np.max(np.where(lat<=latmin))

# ilonmin = np.max(np.where(lon<=lonmin))
# ilonmax = np.max(np.where(lon<=lonmax))

# latp = lat[ilatmin:ilatmax]
# lonp = lon[ilonmin:ilonmax]


# pcpeam1 = np.nanmean(np.nanmean(tream[:,0,ilatmin:ilatmax,ilonmin:ilonmax],axis=2),axis=1)

# variance = np.var(trend,axis=0)








