import matplotlib.pyplot as plt
import xarray as xr
import pandas as pd
import numpy as np
from manuscript_plot_functions import get_emisMasked, get_ElevEmisMasked, get_scatter_plot3
from manuscript_plot_functions import get_ts, get_vertint, get_var, get_scatter_plot2

## BC panels: all available sites with emission mask
# Panel tags for the 2x4 subplot grid. This section uses lab[0] and lab[2];
# the POM section further down uses lab[1] and lab[3].
lab=['(a)','(b)','(c)','(d)']
# Passed through as `typ` to get_scatter_plot2 below.
obs='obs'
varbs = ['bc_a1_SRF','SE_bc_a1_SRF','bc_a4_SRF','SE_bc_a4_SRF','bc','SE_bc']
df = pd.read_csv('/global/cfs/cdirs/e3sm/www/hass877/share/Data_manuscript/BC_IMPROVE_dailydata.csv')
variable='BC'  # NOTE(review): not referenced again in the visible code
# Index by timestamp so the per-site resampling below can work on the index.
df.index=pd.to_datetime(df['time'])
## Get rid of nans & -ve vals
# A single mask keeps rows whose observation is present and strictly positive.
df_nona = df[df['obs'].notna() & (df['obs']>0)]
res='RRM'
# Emission masking is delegated to the project helper; its exact criterion is
# not visible here.
df_nona=get_emisMasked(df_nona,res,stat='diff',factor=0.68)
## Monthly and annual means of the numeric columns, per 'ncols_RRM' group
# NOTE(review): pandas >= 2.2 deprecates the '1M'/'1Y' aliases in favour of
# 'ME'/'YE'; they are kept here for compatibility with older pandas.
df_monthly = df_nona.select_dtypes(include=[np.number]).groupby('ncols_'+res).resample('1M').mean()
# NOTE(review): df_annual is computed but not used in the visible code.
df_annual = df_nona.select_dtypes(include=[np.number]).groupby('ncols_'+res).resample('1Y').mean()
## Plotting
plt.figure(figsize=(22,12))
# i is the subplot slot (1, then 5: first column of each row);
# j-1 indexes lab (0, then 2).
i=1
j=1
# The backslash before the space is written as '\\' so the literal holds a
# valid escape sequence; the resulting text (mathtext '\ ' spacing) is unchanged.
for v,treat,ylab,xlab in zip(
        varbs[-2:],
        ['BC ($\u03BCg\\ m^{-3}$)                  0.22',''],
        ['RRM-PD\n\nModel','RRM-SE-PD\n\nModel'],
        ['','Obs']):
    ax=plt.subplot(2,4,i)
    # cax presumably sets the axis limits -- TODO confirm against get_scatter_plot2.
    get_scatter_plot2(df_monthly,v,res,ax,treatment=treat,temp='monthly',size=5,cax=[5e-3,5e-3,3e0,3e0],vv=ylab,vx=xlab,typ=obs)
    # Panel tag in the upper-left corner (axes coordinates) on a white patch.
    ax.text(0.05,0.95,lab[j-1],size=20,transform=ax.transAxes,va='top',bbox={'facecolor':'white','pad':1,'edgecolor':'none'})
    i+=4
    j+=2
## POM panels: all available sites with emission mask
# Relies on `lab`, `obs` and the figure created in the BC section above.
varbs = ['pom_a1_SRF','SE_pom_a1_SRF','pom_a4_SRF','SE_pom_a4_SRF','pom','SE_pom']
df = pd.read_csv('/global/cfs/cdirs/e3sm/www/hass877/share/Data_manuscript/POM_IMPROVE_dailydata.csv')
variable='POM'  # NOTE(review): not referenced again in the visible code
# Index by timestamp so the per-site resampling below can work on the index.
df.index=pd.to_datetime(df['time'])
## Get rid of nans & -ve vals
df_nona = df[df['obs'].notna()]
df_nona = df_nona[df_nona['obs']>0]
res='RRM'
## Masking data with the project helper
# NOTE(review): the earlier comment here described the mask as
# "mean +/- 0.5*Standard deviation", but the call passes factor=0.68 --
# confirm which value is intended against get_emisMasked.
df_nona=get_emisMasked(df_nona,res,stat='diff',factor=0.68)
## Estimating the monthly and annual averages from daily data
# Numeric columns only, grouped by the 'ncols_RRM' column.
# NOTE(review): pandas >= 2.2 deprecates the '1M'/'1Y' aliases in favour of
# 'ME'/'YE'; they are kept here for compatibility with older pandas.
df_monthly = df_nona.select_dtypes(include=[np.number]).groupby('ncols_'+res).resample('1M').mean()
# NOTE(review): df_annual is computed but not used in the visible code.
df_annual = df_nona.select_dtypes(include=[np.number]).groupby('ncols_'+res).resample('1Y').mean()
## Plotting
# i is the subplot slot (2, then 6: second column of each row);
# j-1 indexes lab (1, then 3).
i=2
j=2
# The backslash before the space is written as '\\' so the literal holds a
# valid escape sequence; the resulting text (mathtext '\ ' spacing) is unchanged.
for v,treat,ylab,xlab in zip(
        varbs[-2:],
        ['POM ($\u03BCg\\ m^{-3}$)               1.12',''],
        ['',''],
        ['','Obs']):
    ax=plt.subplot(2,4,i)
    # cax presumably sets the axis limits -- TODO confirm against get_scatter_plot2.
    get_scatter_plot2(df_monthly,v,res,ax,treatment=treat,temp='monthly',size=5,cax=[1e-2,1e-2,5e1,5e1],vv=ylab,vx=xlab,typ=obs)
    # Panel tag in the upper-left corner (axes coordinates) on a white patch.
    ax.text(0.05,0.95,lab[j-1],size=20,transform=ax.transAxes,va='top',bbox={'facecolor':'white','pad':1,'edgecolor':'none'})
    i+=4
    j+=2

# plt.savefig('fig13.png',format='png',dpi=300,bbox_inches='tight',pad_inches=0.1)
