import xarray as xr
import numpy as np
import pandas as pd
from datetime import datetime
import pkg_resources

import warnings
warnings.filterwarnings('ignore')

from src.utils import get_dates, get_num_factor, get_dir_path, rounding, get_sectors
from src.checkers import origSurfData_Checker, newSurfData_Checker
from src.ggen.driver_mod import driver
    
class gen_surf_emis(object):
    """Generate E3SM surface emission files (mass and number) for one aerosol variable.

    Typical flow: ``get_params`` resolves the per-sector parameters,
    ``get_data`` remaps/re-formats each input file, ``get_vars`` builds the
    emission variables, and ``prod_emis`` writes the NetCDF outputs.
    """

    def __init__(self,start,end,filename,**kwargs):
        """Store run settings; no files are read here.

        Parameters
        ----------
        start, end : int
            First and last year of the period to process.
        filename : str
            Input file used for every sector unless ``filenames`` is given.
        **kwargs
            Optional settings (defaults in parentheses):

            - ``variable`` (None): output variable name, e.g. ``'so4_a1'``;
              when omitted it is built from ``species`` and ``mode``.
            - ``species``, ``mode`` (None): used to build ``variable``.
            - ``sectors``, ``diam``, ``mw``, ``rho``, ``sulfactor``, ``frac``,
              ``numstr`` (None): per-sector lists; missing ones are read from
              the parameter table.
            - ``syr_list``, ``eyr_list``, ``filenames`` (None): per-sector
              start years, end years and input files.
            - ``numfile`` (None): prefix of the number-emission output file.
            - ``ncvars`` (None): per-sector variable names in the input files.
            - ``param_file`` (None): parameter CSV; the packaged defaults are
              used when omitted.
            - ``indir``, ``outdir`` (''): input and output directories.
            - ``grid``, ``res`` (None): passed to the remapping driver;
              ``cgrid`` (None) is passed to the output checker.
            - ``mean``, ``mean_yr`` (None): any non-None ``mean`` produces a
              monthly climatology whose outputs are tagged with ``mean_yr``.
            - ``checker`` (False): compare original and new totals (Tg/yr).
            - ``SEdata`` (True): produce output on the SE (``ncol``) grid;
              otherwise the output keeps lat/lon axes.
            - ``checkVals`` (empty DataFrame), ``ytag`` (None),
              ``output`` ('E3SM_surf_emis.nc'): checker table, year tag and
              output file-name suffix.
        """
        self.start = start
        self.end = end
        self.filename = filename
        self.varbl = kwargs.get('variable',None)
        self.species = kwargs.get('species',None)
        self.mode = kwargs.get('mode',None)
        self.all_sects = kwargs.get('sectors',None)
        self.start_yr_list = kwargs.get('syr_list',None)
        self.end_yr_list = kwargs.get('eyr_list',None)
        self.filename_list = kwargs.get('filenames',None)
        self.all_diams = kwargs.get('diam',None)
        self.all_mws = kwargs.get('mw',None)
        self.all_rhos = kwargs.get('rho',None)
        self.all_sulfactors = kwargs.get('sulfactor',None)
        self.all_fracs = kwargs.get('frac',None)
        self.all_numstrs = kwargs.get('numstr',None)
        self.numfile = kwargs.get('numfile',None)
        self.ncvars = kwargs.get('ncvars',None)
        self.param_file = kwargs.get('param_file',None)
        self.indir = kwargs.get('indir','')
        self.outdir = kwargs.get('outdir','')
        self.grid = kwargs.get('grid',None)
        self.cgrid = kwargs.get('cgrid',None)
        self.res = kwargs.get('res',None)
        self.mean = kwargs.get('mean',None)
        self.mean_yr = kwargs.get('mean_yr',None)
        self.checker = kwargs.get('checker',False)
        self.SEdata = kwargs.get('SEdata',True)
        self.checkVals = kwargs.get('checkVals',pd.DataFrame())
        self.ytag = kwargs.get('ytag',None)
        self.outfile = kwargs.get('output','E3SM_surf_emis.nc')

    def get_params(self):
        """Resolve per-sector parameters for ``self.varbl`` and print a summary.

        Values not supplied by the caller are read from the parameter CSV
        (``self.param_file``, defaulting to the packaged
        ``E3SM_emis_default_params.csv``). Resolved lists are stored on
        ``self``; the local table is only used for the printed summary.

        Raises
        ------
        ValueError
            If neither ``variable`` nor both ``species`` and ``mode`` are set,
            if no sectors are known for the variable, or if a user-supplied
            per-sector list does not have one entry per sector.
        KeyError
            If a parameter is neither supplied nor present in the table, or
            ``ncvars`` is omitted for a variable prefix not in the default map.
        """
        if self.param_file is None:
            ## Default species specific sectors
            resource_package = __name__
            resource_path = '/'.join(('..', 'data', 'E3SM_emis_default_params.csv'))
            self.param_file = pkg_resources.resource_stream(resource_package, resource_path)

        df = pd.read_csv(self.param_file)

        species_originname = {'so':'SO2_em_anthro',
                              'bc':'BC_em_anthro',
                              'po':'OC_em_anthro'}

        # define variable
        if self.varbl is None:
            if (self.species is None) or (self.mode is None):
                raise ValueError('\nEither variable name or species and mode has to be declared!')
            print('\nSelected species: ',self.species)
            print('\nSelected mode: ',self.mode)
            tmp_var = self.species+'_a'+str(self.mode)
            self.varbl = tmp_var.lower()

        # Boolean row selector for this variable (recomputed if the table is rebuilt).
        rows = df['species'] == self.varbl

        # define sectors
        if self.all_sects is None:
            self.all_sects = df.loc[rows, 'sector'].tolist()
            if not self.all_sects:
                raise ValueError('\nUnknown sectors for: '+self.varbl+'\nSpecify sector for non-default species.')
        elif int(rows.sum()) == len(self.all_sects):
            # Same number of sectors as the defaults: relabel the default rows.
            df.loc[rows, 'sector'] = list(self.all_sects)
        else:
            # Different sector layout: start a fresh table for this variable only.
            df = pd.DataFrame({'species': self.varbl, 'sector': list(self.all_sects)})
            rows = df['species'] == self.varbl

        # define diameters, MWs, densities, sulfate factors, fractions and number strings
        self.all_diams = self._resolve_param(df, rows, 'diam', self.all_diams,
                                             "List of diameters and sectors should have the same length!")
        self.all_mws = self._resolve_param(df, rows, 'mw', self.all_mws,
                                           "List of MW and sectors should have the same length!")
        self.all_rhos = self._resolve_param(df, rows, 'rho', self.all_rhos,
                                            "List of density and sectors should have the same length!")
        self.all_sulfactors = self._resolve_param(df, rows, 'sulfactor', self.all_sulfactors,
                                                  "List of sulfactor (MWso2/MWso4) and sectors should have the same length!")
        self.all_fracs = self._resolve_param(df, rows, 'frac', self.all_fracs,
                                             "List of fractions and sectors should have the same length!")
        self.all_numstrs = self._resolve_param(df, rows, 'numstr', self.all_numstrs,
                                               "List of numstr and sectors should have the same length!")

        # define variable names in the netcdfs
        if self.ncvars is None:
            self.ncvars = [species_originname[self.varbl[:2]]]*len(self.all_sects)

        df['start_yr'] = np.nan
        df['end_yr'] = np.nan
        # Object dtype so file names can be stored without a float->str upcast.
        df['OrigFileName'] = pd.Series(np.nan, index=df.index, dtype=object)

        if self.start_yr_list is None:
            self.start_yr_list = [self.start]*len(self.all_sects)
            self.end_yr_list = [self.end]*len(self.all_sects)
        else:
            if len(self.all_sects) != len(self.start_yr_list):
                raise ValueError("List of starting years and sectors should have the same length!")
            if len(self.all_sects) != len(self.end_yr_list):
                raise ValueError("List of ending years and sectors should have the same length!")

        if self.filename_list is None:
            self.filename_list = [self.filename]*len(self.all_sects)
        elif len(self.all_sects) != len(self.filename_list):
            raise ValueError("List of files and sectors should have the same length!")

        df.loc[rows, 'start_yr'] = list(self.start_yr_list)
        df.loc[rows, 'end_yr'] = list(self.end_yr_list)
        df.loc[rows, 'OrigFileName'] = list(self.filename_list)

        # define number conc file (one name shared by every sector of this variable)
        if self.numfile is None:
            self.numfile = df.loc[rows, 'numfile'].unique()[0]
        else:
            self._set_rows(df, rows, 'numfile', self.numfile)

        print('\nSelected parameters for species: ',self.varbl,'\n')
        print(df.loc[rows].reset_index(drop=True))
        print('\n')

    def _resolve_param(self, df, rows, column, values, mismatch_msg):
        """Return the per-sector list for *column*.

        With *values* None the defaults are read from *df*; otherwise *values*
        must have one entry per sector and is copied into *df* for the summary.
        Raises ValueError(*mismatch_msg*) on a length mismatch.
        """
        if values is None:
            return df.loc[rows, column].tolist()
        if len(values) != len(self.all_sects):
            raise ValueError(mismatch_msg)
        self._set_rows(df, rows, column, list(values))
        return values

    @staticmethod
    def _set_rows(df, rows, column, values):
        """Assign *values* to *column* on the selected *rows* of *df*, in place."""
        if column not in df.columns:
            # Object dtype accepts numbers or strings without dtype upcasting.
            df[column] = pd.Series(np.nan, index=df.index, dtype=object)
        # One .loc call: chained df[col].loc[...] = ... may only modify a copy.
        df.loc[rows, column] = values

    def _remap(self, filename, in_dir_path, out_dir_path):
        """Run the remapping driver on *filename* and return the output file name."""
        settings = dict(file=filename, ind=str(in_dir_path), out=str(out_dir_path),
                        grid=self.grid, res=self.res)
        try:
            return driver(**settings).gen_remapped_files()
        except Exception:
            # Retry with explicit 'longitude'/'latitude' dimension names
            # (the first attempt uses the driver's defaults).
            return driver(xdim='longitude', ydim='latitude', **settings).gen_remapped_files()

    def _apply_climatology(self, data):
        """Return monthly means of *data* when ``self.mean`` is set, keeping global attrs."""
        if self.mean is None:
            return data
        attrs = data.attrs
        print('\nGetting monthly climotology.')
        xr.set_options(use_flox=True)
        data = data.groupby('time.month').mean('time')
        data.attrs = attrs
        return data

    def get_data(self,filename,start,end,varbl):
        """Remap one input file and return ``(data, coords)`` for the output grid.

        *data* is the remapped dataset restricted to years ``start``-``end``
        (monthly climatology if ``self.mean`` is set). *coords* maps
        coordinate names to ``(dims, values, attrs)`` tuples: ``lat``/``lon``
        on ``ncol`` plus a dummy ``lev`` for SE output, or 1-D ``lat``/``lon``
        axes otherwise.
        """
        print(filename)
        in_dir_path = get_dir_path(self.indir)
        out_dir_path = get_dir_path(self.outdir)

        if self.SEdata == True:
            print('\nRemapping from RLL to SE grid.')
            fname = self._remap(filename, in_dir_path, out_dir_path)
            print(fname)
            print('\nRemapped files are output to:',str(out_dir_path))
            data = xr.open_dataset(fname).sel(time=slice(str(start),str(end)))
            print('\nRe-formatting data for E3SM.')
            # Insert a single dummy vertical level (1e5 Pa) at axis position 1.
            lev=np.array([1e5])
            data = data.expand_dims('lev',axis=1)
            data = data.assign_coords(lev=('lev',lev))
            data = self._apply_climatology(data)
            print('\nAssigning coordinates.')
            coords = {
                      'lat': (['ncol'], np.array(data['lat'].data, dtype=np.float32),{'units': 'degrees_north', 'long_name':'latitude'}),
                      'lon': (['ncol'], np.array(data['lon'].data, dtype=np.float32),{'units': 'degrees_east', 'long_name':'longitude'}),
                      'lev': (['lev'], np.array(data['lev'].data, dtype=np.float32),{'units': 'Pa','long_name':'dummy_dim'})
                      }
            return data, coords

        if (start < 1997) and ('em_anthro' not in varbl):
            print('\nUsing CMIP6 biomass burning data.')
            fixed_name = varbl+'_biomass_fixed_'+str(start)+'-'+str(end)+'.nc'
            with xr.open_dataset(in_dir_path / filename) as raw:
                pre_biomass = raw.sel(time=slice(str(start),str(end)))
                pre_biomass[varbl].encoding['_FillValue'] = 9.96921e+36
                pre_biomass[varbl].encoding['missing_value'] = 9.96921e+36
                pre_biomass.load().to_netcdf(out_dir_path / fixed_name)
            # NOTE(review): the fixed file is written to out_dir but the driver
            # below is given ind=in_dir; confirm the driver finds it there.
            filename = fixed_name
        print('\nUsing the RLL grid data.')
        fname = self._remap(filename, in_dir_path, out_dir_path)
        print(fname)
        print('\nRemapped files are output to:',str(out_dir_path))
        data = xr.open_dataset(fname).sel(time=slice(str(start),str(end)))
        try:
            data = data.rename({'latitude':'lat','longitude':'lon'})
        except ValueError:
            # Names not present (already lat/lon): keep the dataset as is.
            pass
        data = self._apply_climatology(data)
        print('\nAssigning coordinates.')
        lat, lon = data['lat'], data['lon']
        if len(lat.shape) > 1:
            # 2-D lat/lon fields: take the 1-D axes from the first column / row.
            lat, lon = lat[:,0], lon[0,:]
        coords = {
                  'lat': (['lat'], np.array(lat.data, dtype=np.float64),{'units': 'degrees_north', 'long_name':'latitude'}),
                  'lon': (['lon'], np.array(lon.data, dtype=np.float64),{'units': 'degrees_east', 'long_name':'longitude'})
                  }
        return data, coords

    def get_vars(self):
        """Build the per-sector mass and number emission variables.

        Returns
        -------
        data_vars, num_vars : dict
            ``{name: (dims, values, attrs)}`` mappings (plus a ``'date'``
            entry) for the mass and the number-emission datasets.
        attrs : dict
            Global attributes for the output files.
        coords : dict
            Coordinates returned by the last ``get_data`` call.
        """
        ## Conversion factors
        avgod = 6.022e23

        ## Generate the data variables, coordinates, and attributes
        self.get_params()

        ## Get the longest period
        ind = np.argmax(np.array(self.end_yr_list)-np.array(self.start_yr_list))
        if len(set(self.start_yr_list)) > 1:
            start = self.start_yr_list[ind]
            end = self.end_yr_list[ind]
        else:
            start = self.start
            end = self.end
        years = end - start + 1
        if (self.mean_yr is not None) and (self.mean is not None):
            dates = get_dates(self.mean_yr,self.mean_yr)
        else:
            dates = get_dates(start,end)
        data_vars = {'date':(['time'], dates)}
        num_vars = {'date':(['time'], dates)}
        species_tag = self.varbl.split('_')[0].upper()
        # Sector mapping from the most recent non-BB input; None until one is read.
        av_sectors = None

        for sect,filename,s,e,mw,rho,diam,sulfactor,num_str,frac,orig_varname in zip(self.all_sects,self.filename_list,self.start_yr_list,self.end_yr_list,self.all_mws,\
                                                                                     self.all_rhos,self.all_diams,self.all_sulfactors,self.all_numstrs,self.all_fracs,self.ncvars):

            print('\n====Processing '+sect+' sector now.====')
            factor = 10*mw/avgod         # convert molec/cm2/s to kg/m2/s (vice versa for i/factor)
            num_factor = get_num_factor(mw, rho, diam)
            data, coords = self.get_data(filename,start=s,end=e,varbl=orig_varname)

            if sect != 'BB':
                av_sectors = get_sectors(data)

            if (self.mean is None) and (len(data.time)/12 < years):
                print('\nCopying same year for the whole time period in ',sect)
                data_renamed = data.copy().rename({'time': 'month'})
                data_renamed['month'] = [1,2,3,4,5,6,7,8,9,10,11,12]
                date_range = xr.cftime_range(str(start)+"-01-01", str(end)+"-12-31", freq="MS")
                months = date_range.month
                selArr = xr.DataArray(months, dims=["time"], coords=[date_range])
                data = data_renamed.sel(month=selArr)

            try:
                emis = data[orig_varname].sel(sector=av_sectors[sect])
                dims = list(data[orig_varname].sel(sector=0).dims)
            except Exception:
                # No usable sector selection (e.g. BB input without a sector
                # axis, or no sector map yet): use the whole variable, NaN -> 0.
                emis = data.fillna(0)[orig_varname]
                dims = list(data[orig_varname].dims)
            emis_val = emis.data/factor*frac
            data_vars[sect] = (dims, emis_val, {'units': 'molecules/cm2/s'})
            num_vars[num_str+'_'+species_tag+'_'+sect] = (dims, emis_val*num_factor*sulfactor, {'units': '(particles/cm2/s) * 6.022e26'})

        ## General attributes are added here. User may change the attrs below if needed.
        attrs = {
                'comment':'This data was produced using the E3SM emission pre-processor',
                'contact':'Taufiq Hassan (taufiq.hassan@pnnl.gov)',
                'creation_date':datetime.today().strftime('%Y-%m-%d %H:%M:%S'),
                'origin_grid':data.attrs['grid'],
                'origin_nominal_resolution':data.attrs['nominal_resolution'],
                }
        ## Specific for IND and ENE elev sources
        if 'so' in self.varbl:
            # NOTE(review): s, e, factor and frac here are left over from the last
            # loop iteration, while the file/variable come from the first sector;
            # confirm this pairing is intended.
            if av_sectors is None:
                raise ValueError('\nNo sector mapping available to extract IND/ENE emissions for: '+self.varbl)
            orig_varname = self.ncvars[0]
            data, coords = self.get_data(self.filename_list[0],start=s,end=e,varbl=orig_varname)
            print('\nProducing the '+self.varbl+' IND and ENE outputs to be used in the elevated emissions.')
            dir_path = get_dir_path(self.outdir)
            if self.mean is not None:
                data = data.rename({'month':'time'})
            emis_ind = data[orig_varname].sel(sector=av_sectors['IND']).data/factor*frac
            emis_ene = data[orig_varname].sel(sector=av_sectors['ENE']).data/factor*frac
            sector_dims = list(data[orig_varname].sel(sector=0).dims)
            data_vars_ind = {'date':(['time'], dates),
                             'IND':(sector_dims, emis_ind,{'units': 'molecules/cm2/s'}),
                             'ENE':(sector_dims, emis_ene,{'units': 'molecules/cm2/s'})
                             }
            ds_ind = xr.Dataset(data_vars=data_vars_ind, coords=coords)
            if 'lev' in ds_ind.dims:
                ds_ind = ds_ind.isel(lev=0).drop_vars('lev')
            ds_ind.encoding = {'unlimited_dims': {'time'}}
            ds_ind.to_netcdf(dir_path / str(self.varbl+'_IND_emis_'+str(start)+'-'+str(end)+'.nc'))

        if self.checker:
            orig_vals = origSurfData_Checker(years,self.all_sects,self.indir,self.filename_list,self.start_yr_list,self.end_yr_list,self.varbl,self.mean,self.all_fracs)
            self.checkVals['sectors'] = self.all_sects
            self.checkVals['Orig (Tg/yr)'] = orig_vals
            self.checkVals = self.checkVals[self.checkVals.sectors.isin(self.all_sects)]

        return data_vars, num_vars, attrs, coords

    def _year_tag(self):
        """Return the year label used in output file names."""
        if self.mean is not None:
            return str(self.mean_yr)+'CLIM'
        return str(self.start)+'-'+str(self.end)

    def _assemble_output(self, variables, coords, attrs):
        """Build an output Dataset and refresh ``self.ytag`` for its file name."""
        ds = xr.Dataset(data_vars=variables, coords=coords, attrs=attrs)
        if self.mean is not None:
            # Climatologies carry the 'month' axis produced by the groupby in
            # get_data; output files name it 'time'.
            ds = ds.rename({'month':'time'})
        self.ytag = self._year_tag()
        return ds

    @staticmethod
    def _write_netcdf(ds, path):
        """Write *ds* as NETCDF3_64BIT with an unlimited time axis and no _FillValue."""
        ds.encoding = {'unlimited_dims': {'time'}}
        comp = dict(_FillValue=None)
        encodings = {var: comp for var in ds.data_vars}
        encodings.update({coord: comp for coord in ds.coords})
        ds.to_netcdf(path,encoding=encodings,format="NETCDF3_64BIT")

    def prod_emis(self):
        """Produce and write the number (except for 'so2') and mass emission files."""
        data_vars, num_vars, attrs, coords = self.get_vars()
        dir_path = get_dir_path(self.outdir)

        ## Number conc out
        if self.varbl != 'so2':
            ds = self._assemble_output(num_vars, coords, attrs)
            self._write_netcdf(ds, dir_path / str(self.numfile+'_'+self.varbl+'_'+self.ytag+'_'+self.outfile))

        ## Mass out
        ds = self._assemble_output(data_vars, coords, attrs)
        self._write_netcdf(ds, dir_path / str(self.varbl+'_'+self.ytag+'_'+self.outfile))

        if self.checker:
            new_vals = newSurfData_Checker(ds,self.all_mws[0],self.indir,self.mean,grid=self.cgrid)
            self.checkVals['New (Tg/yr)'] = new_vals
            value_cols = ['Orig (Tg/yr)','New (Tg/yr)']
            # Element-wise rounding via Series.map (DataFrame.applymap is deprecated).
            self.checkVals[value_cols] = self.checkVals[value_cols].apply(lambda col: col.map(rounding))
            print(self.checkVals)

    def combine_numvars(self,numval):
        """Merge all ``<numval>_*_<ytag>_<outfile>`` files into ``<numval>_<ytag>_<outfile>``."""
        print('\nCombining number conc files.')
        self.ytag = self._year_tag()
        dir_path = get_dir_path(self.outdir)
        pattern = str(dir_path)+'/'+str(numval+'_*_'+self.ytag+'_'+self.outfile)
        print('File source:',pattern)
        with xr.open_mfdataset(pattern) as data:
            data.load().to_netcdf(dir_path / str(numval+'_'+self.ytag+'_'+self.outfile),format="NETCDF3_64BIT")
    
