import xarray as xr
import numpy as np
import pandas as pd
from pathlib import Path
from datetime import datetime
import pkg_resources

import warnings
warnings.filterwarnings('ignore')

from src.ggen.driver_mod import driver
from src.utils import get_dates, get_num_factor, get_dir_path, get_zcol, rounding
from src.checkers import newElevData_Checker, origElevData_Checker
from src.prep_GFED4 import gen_GFED4

    
class gen_elev_emis(object):
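    """Generate elevated (3-D) emission files for E3SM.

    For each requested sector (e.g. biomass burning 'BB', continuous volcanic
    'contvolc', or industrial/energy '*_ELEV' sectors) the class reads the input
    emissions, remaps them to the SE (or regular lat-lon) target grid, distributes
    them over altitude, and writes mass and number emission netCDF files.
    """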
    
    def __init__(self,start,end,filename,**kwargs):
        self.start = start
        self.end = end
        self.filename = filename
        self.profile = kwargs.get('profile',None)
        self.varbl = kwargs.get('variable',None)
        self.species = kwargs.get('species',None)
        self.mode = kwargs.get('mode',None)
        self.all_sects = kwargs.get('sectors',None)
        self.start_yr_list = kwargs.get('syr_list',None)
        self.end_yr_list = kwargs.get('eyr_list',None)
        self.filename_list = kwargs.get('filenames',None)
        self.profile_list = kwargs.get('profiles',None)
        self.all_diams = kwargs.get('diam',None)
        self.all_mws = kwargs.get('mw',None)
        self.all_rhos = kwargs.get('rho',None)
        self.all_sulfactors = kwargs.get('sulfactor',None)
        self.all_fracs = kwargs.get('frac',None)
        self.all_numstrs = kwargs.get('numstr',None)
        self.numfile = kwargs.get('numfile',None)
        self.param_file = kwargs.get('param_file',None)
        self.indir = kwargs.get('indir','')
        self.outdir = kwargs.get('outdir','')
        self.grid = kwargs.get('grid',None)
        self.cgrid = kwargs.get('cgrid',None)
        self.SEdata = kwargs.get('SEdata',True)
        self.res = kwargs.get('res',None)
        self.mean = kwargs.get('mean',None)
        self.mean_yr = kwargs.get('mean_yr',None)
        self.altitude = kwargs.get('altitude',None)
        self.checker = kwargs.get('checker',False)
        self.fractions = kwargs.get('fractions',[0.25,0.15,0.15,0.1,0.1,0.1,0.05,0.05,0.05,0, 0, 0, 0])
        self.ind_frac = kwargs.get('ind_frac',[0.13,0.76,0.11,0,0,0,0,0,0,0, 0, 0, 0])
        self.checkVals = kwargs.get('checkVals',pd.DataFrame())
        self.ytag = kwargs.get('ytag',None)
        self.prep = kwargs.get('prep',None)
        self.outfile = kwargs.get('output','E3SM_elev_emis.nc')
    
    def get_params(self):
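        """Resolve per-sector parameters (sectors, diameters, MW, density, etc.).

        Values not supplied by the user are looked up in the packaged default
        parameter table (E3SM_elev_emis_default_params.csv); user-supplied lists
        must have one entry per sector.
        """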
        
        if self.param_file == None:
            ## Default species specific sectors
            resource_package = __name__
            resource_path = '/'.join(('..', 'data', 'E3SM_elev_emis_default_params.csv'))
            self.param_file = pkg_resources.resource_stream(resource_package, resource_path)

        df = pd.read_csv(self.param_file)
        
        if self.varbl == None:
            if (self.species == None) or (self.mode == None):
                raise Exception('\nEither variable name or species and mode has to be declared!')
            else:
                print('\nSelected species: ',self.species)
                print('\nSelected mode: ',self.mode)
                tmp_var = self.species+'_a'+str(self.mode)
                self.varbl = tmp_var.lower()
                df = pd.DataFrame(columns=df.columns)
                df['species'] = [self.varbl]

        if self.all_sects == None:
            self.all_sects = df.query("species=='"+self.varbl+"'")['sector'].tolist()
            if self.all_sects == []:
                raise Exception('\nUnknown sectors for: '+self.varbl+'\nSpecify sector for non-default species.')
        else:
            print(df['sector'].loc[df['species'] == self.varbl])
            df.loc[df['species'] == self.varbl, 'sector'] = self.all_sects
            # df['sector'] = self.all_sects
            # df['species'] = self.varbl
        ## Get all default values
        if self.all_diams == None:
            self.all_diams = df.query("species=='"+self.varbl+"'")['diam'].tolist()
        else:
            assert len(self.all_sects) == len(self.all_diams), "List of diameters and sectors should have the same length!"
            df.loc[df['species'] == self.varbl, 'diam'] = self.all_diams
        if self.all_mws == None:
            self.all_mws = df.query("species=='"+self.varbl+"'")['mw'].tolist()
        else:
            assert len(self.all_sects) == len(self.all_mws), "List of MW and sectors should have the same length!"
            df.loc[df['species'] == self.varbl, 'mw'] = self.all_mws
        if self.all_rhos == None:
            self.all_rhos = df.query("species=='"+self.varbl+"'")['rho'].tolist()
        else:
            assert len(self.all_sects) == len(self.all_rhos), "List of density and sectors should have the same length!"
            df.loc[df['species'] == self.varbl, 'rho'] = self.all_rhos
        if self.all_sulfactors == None:
            self.all_sulfactors = df.query("species=='"+self.varbl+"'")['sulfactor'].tolist()
        else:
            assert len(self.all_sects) == len(self.all_sulfactors), "List of sulfactor (MWso2/MWso4) and sectors should have the same length!"
            df.loc[df['species'] == self.varbl, 'sulfactor'] = self.all_sulfactors
        if self.all_fracs == None:
            self.all_fracs = df.query("species=='"+self.varbl+"'")['frac'].tolist()
        else:
            assert len(self.all_sects) == len(self.all_fracs), "List of fractions and sectors should have the same length!"
            df.loc[df['species'] == self.varbl, 'frac'] = self.all_fracs
        if self.all_numstrs == None:
            self.all_numstrs = df.query("species=='"+self.varbl+"'")['numstr'].tolist()
        else:
            assert len(self.all_sects) == len(self.all_numstrs), "List of numstr and sectors should have the same length!"
            df.loc[df['species'] == self.varbl, 'numstr'] = self.all_numstrs
        
        df['start_yr'] = np.nan
        df['end_yr'] = np.nan
        
        if self.start_yr_list == None:
            self.start_yr_list = [self.start]*len(self.all_sects)
            self.end_yr_list = [self.end]*len(self.all_sects)
        else:
            assert len(self.all_sects) == len(self.start_yr_list), "List of starting years and sectors should have the same length!"
            assert len(self.all_sects) == len(self.end_yr_list), "List of ending years and sectors should have the same length!"
        
        if self.filename_list == None:
            self.filename_list = [self.filename]*len(self.all_sects)
            for i,sect,s,e in zip(range(len(self.filename_list)),self.all_sects,self.start_yr_list,self.end_yr_list):
                if sect == 'BB':
                    if self.filename == None:
                        self.filename_list[i] = None
                    else:
                        self.filename_list[i] = self.filename
                if sect == 'contvolc':
                    self.filename_list[i] = 'contvolc_'+self.varbl+'_elev_HR_emis_profile.nc'
                if '_ELEV' in sect:
                    self.filename_list[i] = self.varbl+'_IND_emis_'+str(s)+'-'+str(e)+'.nc'
        else:
            assert len(self.all_sects) == len(self.filename_list), "List of files and sectors should have the same length!"
        
        if self.profile_list == None:
            self.profile_list = [self.profile]*len(self.all_sects)
            for i,sect in zip(range(len(self.profile_list)),self.all_sects):
                if sect == 'BB':
                    self.profile_list[i] = self.varbl+'_elev_HR_emis_profile.nc'
                if sect == 'contvolc':
                    self.profile_list[i] = None
                if '_ELEV' in sect:
                    self.profile_list[i] = None
        else:
            assert len(self.all_sects) == len(self.profile_list), "List of profiles and sectors should have the same length!"

        df.loc[df['species'] == self.varbl, 'start_yr'] = self.start_yr_list
        df.loc[df['species'] == self.varbl, 'end_yr'] = self.end_yr_list
        if self.numfile == None:
            self.numfile = df.query("species=='"+self.varbl+"'")['numfile'].unique()[0]
        print('\nSelected parameters for species: ',self.varbl)
        print(df.query("species=='"+self.varbl+"'").reset_index(drop=True))
    
    def get_bb_data(self,filename,start,end):
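        """Read (or generate) biomass burning emissions for one species and period.

        Falls back to GFED4 data when no filename is given, otherwise uses the
        supplied (e.g. CMIP6) file. Returns the emission DataArray together with
        the output coordinates and global attributes.
        """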
        in_dir_path = get_dir_path(self.indir)  
        out_dir_path = get_dir_path(self.outdir)
        zcol, altitude_air, altitude_int_air = get_zcol(altitude_air=self.altitude)
        var_originname = {'so':'SO2',
                        'bc':'BC',
                        'po':'OC'}
        varbl = var_originname[self.varbl[:2]]
        if (filename==None) or (filename=='None'):
            print('\nUsing GFED4 data.')
            if self.prep == 'GFED':
                filename = gen_GFED4(varbl,str(in_dir_path),str(out_dir_path),start,end)
                print('\nGenerated GFED4 data:',filename)
            else:
                filename = varbl+'_biomass_burning_emis_GFED_1997-2022.nc'
        else:
            print('\nUsing CMIP6 biomass burning data.')
            pre_biomass = xr.open_dataset(in_dir_path / filename).sel(time=slice(str(start),str(end)))
            attrs = pre_biomass.attrs
            pre_biomass[varbl].encoding['_FillValue'] = 9.96921e+36
            pre_biomass[varbl].encoding['missing_value'] = 9.96921e+36
            pre_biomass.load().to_netcdf(out_dir_path / str(varbl+'_biomass_fixed_'+str(start)+'-'+str(end)+'.nc'))
            filename = varbl+'_biomass_fixed_'+str(start)+'-'+str(end)+'.nc'
        if self.SEdata:
            print('\nRemapping BB data from RLL to SE grid.')
            fname = driver(file=filename,ind=str(in_dir_path),out=str(out_dir_path),grid=self.grid,res=self.res,xdim='longitude',ydim='latitude').gen_remapped_files()
            biomass = xr.open_dataset(fname).sel(time=slice(str(start),str(end)))
            grid_info = biomass.attrs['grid']
            res_info = biomass.attrs['nominal_resolution']
            data_biom = biomass[varbl]
            
            if self.mean != None:
                print('\nGetting monthly climatology.')
                xr.set_options(use_flox=True)
                data_biom = data_biom.groupby('time.month').mean('time')
                data_biom = data_biom.rename({'month':'time'})
                
            coords = {
                      'altitude': (['altitude'], altitude_air.data,{'units': 'km', 'long_name':'altitude midlevel'}),
                      'altitude_int': (['altitude_int'], altitude_int_air,{'units': 'km', 'long_name':'altitude interval'}),
                      'lat': (['ncol'], np.array(data_biom['lat'].data, dtype=np.float32),{'units': 'degrees_north', 'long_name':'latitude'}),
                      'lon': (['ncol'], np.array(data_biom['lon'].data, dtype=np.float32),{'units': 'degrees_east', 'long_name':'longitude'})
                      }
        else:
            print('\nUsing the original RLL grid BB data.')
            try:
                fname = driver(file=filename,ind=str(in_dir_path),out=str(out_dir_path),grid=self.grid,res=self.res).gen_remapped_files()
            except:
                fname = driver(file=filename,ind=str(in_dir_path),out=str(out_dir_path),grid=self.grid,res=self.res,xdim='longitude',ydim='latitude').gen_remapped_files()
            print(fname)
            biomass = xr.open_dataset(fname).sel(time=slice(str(start),str(end)))
            grid_info = biomass.attrs['grid']
            res_info = biomass.attrs['nominal_resolution']
            data_biom = biomass[varbl]
            data_biom = data_biom.rename({'latitude':'lat','longitude':'lon'})
            
            if self.mean != None:
                print('\nGetting monthly climatology.')
                xr.set_options(use_flox=True)
                data_biom = data_biom.groupby('time.month').mean('time')
                data_biom = data_biom.rename({'month':'time'})
                
            coords = {
                      'altitude': (['altitude'], altitude_air.data,{'units': 'km', 'long_name':'altitude midlevel'}),
                      'altitude_int': (['altitude_int'], altitude_int_air,{'units': 'km', 'long_name':'altitude interval'}),
                      'lat': (['lat'], np.array(biomass['lat'].data, dtype=np.float64),{'units': 'degrees_north', 'long_name':'latitude'}),
                      'lon': (['lon'], np.array(biomass['lon'].data, dtype=np.float64),{'units': 'degrees_east', 'long_name':'longitude'})
                      }
        
        data_biom.encoding['_FillValue'] = 9.96921e+36
        data_biom.encoding['missing_value'] = 9.96921e+36
        attrs = {
                'comment':'This data was produced using the E3SM emission pre-processor',
                'contact':'Taufiq Hassan (taufiq.hassan@pnnl.gov)',
                'creation_date':datetime.today().strftime('%Y-%m-%d %H:%M:%S'),
                'origin_grid':grid_info,
                'origin_nominal_resolution':res_info,
                }
        return data_biom, coords, attrs
    
    def get_injection_heights(self,profile):
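        """Build the vertical injection-height fraction profile for BB emissions.

        Remaps the BB profile file to the target grid, uses the AEROCOM default
        altitudes when none are supplied (otherwise interpolates onto the user
        altitudes), normalizes by layer thickness, and fills missing columns with
        the prescribed fractions. Returns the path of the 'fireEmisfrac_*' file.
        """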
        in_dir_path = get_dir_path(self.indir)  
        out_dir_path = get_dir_path(self.outdir)
        if self.SEdata:
            print('\nRemapping BB profile from RLL to SE grid.')
            fname = driver(file=profile,ind=str(in_dir_path),out=str(out_dir_path),grid=self.grid,res=self.res).gen_remapped_files()
            print('\nRemapped BB profile file: ',fname)
            fire = xr.open_dataset(fname)['BB']
            coords = {
                      'altitude': (['altitude'], np.array(fire['altitude'].data, dtype=np.float32),{'units': 'km', 'long_name':'altitude','_FillValue':9.96921e+36}),
                      'lat': (['ncol'], np.array(fire['lat'].data, dtype=np.float32),{'units': 'degrees_north', 'long_name':'latitude','_FillValue':9.96921e+36}),
                      'lon': (['ncol'], np.array(fire['lon'].data, dtype=np.float32),{'units': 'degrees_east', 'long_name':'longitude','_FillValue':9.96921e+36})
                      }
        else:
            print('\nUsing the original RLL grid BB profile data.')
            try:
                fname = driver(file=profile,ind=str(in_dir_path),out=str(out_dir_path),grid=self.grid,res=self.res).gen_remapped_files()
            except:
                fname = driver(file=profile,ind=str(in_dir_path),out=str(out_dir_path),grid=self.grid,res=self.res,xdim='longitude',ydim='latitude').gen_remapped_files()
            print(fname)
            fire = xr.open_dataset(fname)['BB']
            coords = {
                      'altitude': (['altitude'], np.array(fire['altitude'].data, dtype=np.float64),{'units': 'km', 'long_name':'altitude','_FillValue':9.96921e+36}),
                      'lat': (['lat'], np.array(fire['lat'].data, dtype=np.float64),{'units': 'degrees_north', 'long_name':'latitude','_FillValue':9.96921e+36}),
                      'lon': (['lon'], np.array(fire['lon'].data, dtype=np.float64),{'units': 'degrees_east', 'long_name':'longitude','_FillValue':9.96921e+36})
                      }
        
        new_prof_name = str(out_dir_path / str('fireEmisfrac_'+str(fname).split('/')[-1]))

        print('\nSetting injection heights and vertical distribution fractions')
        if self.altitude==None:
            print('\nSetting to AEROCOM default.')
            self.altitude = np.array([0.063000001013279, 0.202000007033348, 0.36599999666214,0.554000020027161, \
            0.767000019550323, 1.00300002098083, 1.26199996471405,1.64100003242493, \
            2.23200011253357, 3.02500009536743, 4.00899982452393,5.15700006484985, \
            6.35599994659424], dtype=np.float64)
            self.fractions = np.array(self.fractions)
        else:
            self.altitude = np.array(self.altitude, dtype=np.float64)
            alt_cumsum = (np.array(self.fractions)*100).cumsum()
            ind_cumsum = (np.array(self.ind_frac)*100).cumsum()
            da = xr.DataArray(alt_cumsum,dims='x',coords={'x':self.altitude})
            da_ind = xr.DataArray(ind_cumsum,dims='x',coords={'x':self.altitude})
            bb = da.interp(x=self.altitude,kwargs={"fill_value": 0.0},method='linear')
            bb_ind = da_ind.interp(x=self.altitude,kwargs={"fill_value": 0.0},method='linear')
            self.fractions = (bb[1:].values-bb[:-1].values)/100
            self.ind_frac = (bb_ind[1:].values-bb_ind[:-1].values)/100
            self.fractions = np.where(self.fractions<0,0,self.fractions)
            self.ind_frac = np.where(self.ind_frac<0,0,self.ind_frac)
            fire = fire.interp(altitude=self.altitude,kwargs={"fill_value": 0.0},method='linear')
        
        ## Fractions for SO2/SO4
        self.ind_frac = xr.DataArray(self.ind_frac,coords={'altitude': self.altitude},dims=["altitude"])
        
        if not Path(new_prof_name).is_file():
            if self.altitude[0] <= 0:
                print('\nInjection height cannot be zero (surface)!')
            
            altitude_int_air = np.zeros(len(self.altitude)+1,dtype=np.float64)
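            ## Layer interfaces mirror about each midpoint: int[i+1] = 2*mid[i] - int[i],
            ## starting from a surface interface at 0 km.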
            for i in range(len(self.altitude)):
                altitude_int_air[i+1] = self.altitude[i]*2 - altitude_int_air[i]
            dz = (altitude_int_air[1:] - altitude_int_air[:-1])*1e5 # km to cm
            zcol = xr.DataArray(dz,coords={'altitude': self.altitude},dims=["altitude"])
            
            fire['altitude'] = zcol['altitude']
            fire.encoding['_FillValue'] = 9.96921e+36
            temp = fire*zcol/(fire*zcol).sum('altitude')
            
            print('\nFilling missing values with the pre-defined fractions:', self.fractions)
            for i, frac in zip(range(len(zcol)),self.fractions):
                temp[:,i] = np.where(np.isnan(temp[:,i]),frac,temp[:,i])
            temp.name = 'frac'
            data_vars = {
                        'frac':(list(fire.dims), temp.data,{'units': 'fractions','_FillValue':9.96921e+36})
                        }
            ds = xr.Dataset(data_vars=data_vars, coords=coords)
            print('\nSaving BB profile to '+str(out_dir_path))
            ds.load().to_netcdf(new_prof_name)  # new_prof_name already includes out_dir_path
        else:
            print('\n'+new_prof_name+' already exists!\n Using it.')
            print('\nNote: If changes are made to the vertical profile,\nmake sure '+new_prof_name+' does not exist!')
        return new_prof_name
    
    def get_vars(self):
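        """Assemble mass and number emission variables for every requested sector.

        Returns the data_vars/num_vars dictionaries together with the shared
        coordinates and global attributes used by prod_emis().
        """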
        ## Conversion factors
        avgod = 6.022e23
        in_dir_path = get_dir_path(self.indir)  
        out_dir_path = get_dir_path(self.outdir)
        
        ## Generate the data variables, coordinates, and attributes  
        self.get_params()
        ## Get the longest period
        ind = np.argmax(np.array(self.end_yr_list)-np.array(self.start_yr_list))
        if len(set(self.start_yr_list)) > 1:
            start = self.start_yr_list[ind]
            end = self.end_yr_list[ind]
        else:
            start = self.start 
            end = self.end
        years = end - start + 1
        if (self.mean_yr != None) and (self.mean != None):
            dates = get_dates(self.mean_yr,self.mean_yr)
        else:
            dates = get_dates(start,end)

        data_vars = {'date':(['time'], dates)}
        num_vars = {'date':(['time'], dates)}
        
        for sect,filename,profile,s,e,mw,rho,diam,sulfactor,num_str,frac in zip(self.all_sects,self.filename_list,\
                                                                                self.profile_list,self.start_yr_list,\
                                                                                self.end_yr_list,self.all_mws,\
                                                                                self.all_rhos,self.all_diams,\
                                                                                self.all_sulfactors,self.all_numstrs,\
                                                                                self.all_fracs):
            
            print('\n====Processing '+sect+' sector now.====')
            factor = 10*mw/avgod         # convert molec/cm2/s to kg/m2/s (vice versa for i/factor)
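            ## num_factor converts the mass emission to a number emission using the
            ## molecular weight, particle density, and emitted-mode diameter.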
            num_factor = get_num_factor(mw, rho, diam)
            
            if sect == 'BB':
                BB_data, coords, attrs = self.get_bb_data(filename,start=s,end=e)
                
                ## Get BB profile
                frac_file = self.get_injection_heights(profile)
                print('\nFound BB Profile: ',frac_file)
                new_prof = xr.open_dataset(frac_file)['frac']  # frac_file already includes the output directory
                
                altitude_air = new_prof.altitude
                altitude_int_air = np.zeros(len(altitude_air)+1,dtype=np.float64)
                for i in range(len(altitude_air)):
                    altitude_int_air[i+1] = altitude_air[i]*2 - altitude_int_air[i]            
                dz = (altitude_int_air[1:] - altitude_int_air[:-1])*1e5 # km to cm
                zcol = xr.DataArray(dz,coords={'altitude': altitude_air},dims=["altitude"])
                
                if (self.mean == None) and (len(BB_data.time)/12 < years):
                    print('\nCopying same year for the whole time period in ',sect)
                    data_renamed = BB_data.copy().rename({'time': 'month'})
                    data_renamed['month'] = [1,2,3,4,5,6,7,8,9,10,11,12]
                    date_range = xr.cftime_range(str(start)+"-01-01", str(end)+"-12-31", freq="MS")
                    months = date_range.month
                    selArr = xr.DataArray(months, dims=["time"], coords=[date_range])
                    BB_data = data_renamed.sel(month=selArr)
                
                print('\nCopying BB profile for the whole time period in ',sect)
                if (self.mean == None):
                    new_prof_renamed = new_prof.copy().rename({'time': 'month'})
                    new_prof_renamed['month'] = [1,2,3,4,5,6,7,8,9,10,11,12]
                    date_range = xr.cftime_range(str(start)+"-01-01", str(end)+"-12-31", freq="MS")
                    months = date_range.month
                    selArr = xr.DataArray(months, dims=["time"], coords=[date_range])
                    new_prof = new_prof_renamed.sel(month=selArr)
                new_prof['time'] = BB_data['time']
                
                bc_elev_emis = new_prof.load()*BB_data.fillna(0)/zcol
                print('\nBB data shape: ',bc_elev_emis.shape)
                
                emis_val = bc_elev_emis.data/factor*frac
                new_var = {sect:(list(bc_elev_emis.dims), np.array(emis_val,dtype=np.float32),{'units': 'molecules/cm3/s'})}
                data_vars.update(new_var)
                num_var = {num_str+'_'+self.varbl.split('_')[0].upper()+'_ELEV_'+sect:(list(bc_elev_emis.dims), np.array(emis_val*num_factor,dtype=np.float32),{'units': '(particles/cm3/s) * 6.022e26'})}
                num_vars.update(num_var)
                
            if sect == 'contvolc':
                ## This is temporary ##
                aero_alt = np.array([0.063000001013279, 0.202000007033348, 0.36599999666214,0.554000020027161, \
                0.767000019550323, 1.00300002098083, 1.26199996471405,1.64100003242493, \
                2.23200011253357, 3.02500009536743, 4.00899982452393,5.15700006484985, \
                6.35599994659424], dtype=np.float64)
                zcol, altitude_air, altitude_int_air = get_zcol(altitude_air=list(aero_alt))
                #### For so4_a2 ###
                if self.SEdata:
                    print('\nRemapping volcanic data from RLL to SE grid.')
                    fname = driver(file=filename,ind=str(in_dir_path),out=str(out_dir_path),grid=self.grid,res=self.res).gen_remapped_files()
                    volc_data = xr.open_mfdataset(out_dir_path / str(fname))['contvolc']
                    coords = {
                              'altitude': (['altitude'], altitude_air.data,{'units': 'km', 'long_name':'altitude midlevel'}),
                              'altitude_int': (['altitude_int'], altitude_int_air,{'units': 'km', 'long_name':'altitude interval'}),
                              'lat': (['ncol'], np.array(volc_data['lat'].data, dtype=np.float32),{'units': 'degrees_north', 'long_name':'latitude'}),
                              'lon': (['ncol'], np.array(volc_data['lon'].data, dtype=np.float32),{'units': 'degrees_east', 'long_name':'longitude'})
                              }
                else:
                    try:
                        remapped_volc = driver(file=filename,ind=str(in_dir_path),out=str(out_dir_path),grid=self.grid,res=self.res).gen_remapped_files()
                    except:
                        remapped_volc = driver(file=filename,ind=str(in_dir_path),out=str(out_dir_path),grid=self.grid,res=self.res,xdim='longitude',ydim='latitude').gen_remapped_files()
                    print(remapped_volc)
                    volc_data = xr.open_mfdataset(remapped_volc)['contvolc']
                    coords = {
                              'altitude': (['altitude'], altitude_air.data,{'units': 'km', 'long_name':'altitude midlevel'}),
                              'altitude_int': (['altitude_int'], altitude_int_air,{'units': 'km', 'long_name':'altitude interval'}),
                              'lat': (['lat'], np.array(volc_data['lat'].data, dtype=np.float64),{'units': 'degrees_north', 'long_name':'latitude'}),
                              'lon': (['lon'], np.array(volc_data['lon'].data, dtype=np.float64),{'units': 'degrees_east', 'long_name':'longitude'})
                              }
                
                print('\nCopying same year for the whole time period in ',sect) 
                if (self.mean == None):
                    volc_data_renamed = volc_data.copy().rename({'time': 'month'})
                    volc_data_renamed['month'] = [1,2,3,4,5,6,7,8,9,10,11,12]
                    date_range = xr.cftime_range(str(start)+"-01-01", str(end)+"-12-31", freq="MS")
                    months = date_range.month
                    selArr = xr.DataArray(months, dims=["time"], coords=[date_range])
                    volc_data = volc_data_renamed.sel(month=selArr)
                print('\nVolcanic data shape: ',volc_data.shape)
                
                new_var = {'contvolc':(list(volc_data.dims), np.array(volc_data.data,dtype=np.float32),{'units': 'molecules/cm3/s'})}
                data_vars.update(new_var)
                num_var = {num_str+'_'+self.varbl.split('_')[0].upper()+'_ELEV_contvolc':(list(volc_data.dims), np.array(volc_data.data*num_factor*sulfactor,dtype=np.float32),{'units': '(particles/cm3/s) * 6.022e26'})}
                num_vars.update(num_var)
                
                attrs = {
                        'comment':'This data was produced using the E3SM emission pre-processor',
                        'contact':'Taufiq Hassan (taufiq.hassan@pnnl.gov)',
                        'creation_date':datetime.today().strftime('%Y-%m-%d %H:%M:%S'),
                        'origin_grid':'0.625x0.46 degree latitudexlongitude',
                        'origin_nominal_resolution':'~60 km',
                        }
                
            if '_ELEV' in sect:
                zcol, altitude_air, altitude_int_air = get_zcol(altitude_air=list(self.altitude))
                ind_file = str(out_dir_path)+'/'+str(filename)
                if Path(ind_file).is_file():
                    print('\nUsing industrial/energy sector data:', ind_file)
                    if self.SEdata:
                        ind_data = xr.open_mfdataset(out_dir_path / str(filename))
                        dshape = ind_data['IND']/zcol
                        dshape = dshape.transpose('time','altitude','ncol')
                    else:
                        try:
                            remapped_ind = driver(file=ind_file,ind=str(in_dir_path),out=str(out_dir_path),grid=self.grid,res=self.res).gen_remapped_files()
                        except:
                            remapped_ind = driver(file=ind_file,ind=str(in_dir_path),out=str(out_dir_path),grid=self.grid,res=self.res,xdim='longitude',ydim='latitude').gen_remapped_files()
                        print(remapped_ind)
                        ind_data = xr.open_mfdataset(remapped_ind)
                        dshape = ind_data['IND']/zcol
                        dshape = dshape.transpose('time','altitude','lat','lon')
                else:
                    print('\nMake sure '+ind_file+' exists!\nHint: Create the surface emissions first to produce these files.')
                    raise FileNotFoundError(ind_file+' does not exist!')
                
                ind_new = xr.zeros_like(dshape)
                sect_name = sect.split('_')[0]
                for i in range(len(zcol)):
                    ind_new[:,i]=ind_data[sect_name]*self.ind_frac[i]/zcol[i]
                
                if (self.mean == None) and (len(ind_new.time)/12 < years):
                    print('\nCopying same year for the whole time period in ',sect)
                    data_renamed = ind_new.copy().rename({'time': 'month'})
                    data_renamed['month'] = [1,2,3,4,5,6,7,8,9,10,11,12]
                    date_range = xr.cftime_range(str(start)+"-01-01", str(end)+"-12-31", freq="MS")
                    months = date_range.month
                    selArr = xr.DataArray(months, dims=["time"], coords=[date_range])
                    ind_new = data_renamed.sel(month=selArr)

                print('\nIND/ENE data shape: ',ind_new.shape)
                new_var = {sect:(list(dshape.dims), np.array(ind_new.data,dtype=np.float32),{'units': 'molecules/cm3/s'})}
                data_vars.update(new_var)
                num_var = {num_str+'_'+self.varbl.split('_')[0].upper()+'_ELEV_'+sect_name:(list(dshape.dims), np.array(ind_new.data*num_factor*sulfactor,dtype=np.float32),{'units': '(particles/cm3/s) * 6.022e26'})}
                num_vars.update(num_var)
        
        if self.checker:
            orig_vals = origElevData_Checker(years,self.all_sects,in_dir_path,out_dir_path,self.filename_list,self.start_yr_list,self.end_yr_list,self.varbl,self.mean,self.all_fracs,self.all_mws,self.cgrid)
            self.checkVals['sectors'] = self.all_sects
            self.checkVals['Orig (Tg/yr)'] = orig_vals
                
        return data_vars, num_vars, attrs, coords
    
    def prod_emis(self):
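        """Write the elevated mass emission file and, except for so2, the
        corresponding number emission file (NETCDF3_64BIT); optionally run the
        budget checker."""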
        data_vars, num_vars, attrs, coords = self.get_vars()
        dir_path = get_dir_path(self.outdir)
        
        ## Number conc out
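        ## (no number emission file is written for gas-phase so2)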
        if self.varbl != 'so2':
            ds = xr.Dataset(data_vars=num_vars, coords=coords, attrs=attrs)
            if self.mean != None:
                self.ytag = str(self.mean_yr)+'CLIM'
            else:
                self.ytag = str(self.start)+'-'+str(self.end)
            ds.encoding = {'unlimited_dims': {'time'}}
            comp = dict(_FillValue=None)
            encodings = {var: comp for var in ds.data_vars}
            encodings.update({coord: comp for coord in ds.coords})
            ds.to_netcdf(dir_path / str(self.numfile+'_'+self.varbl+'_'+self.ytag+'_'+self.outfile),encoding=encodings,format="NETCDF3_64BIT")
            
        ## Mass out
        ds = xr.Dataset(data_vars=data_vars, coords=coords, attrs=attrs)
        if self.mean != None:
            self.ytag = str(self.mean_yr)+'CLIM'
        else:
            self.ytag = str(self.start)+'-'+str(self.end)
        ds.encoding = {'unlimited_dims': {'time'}}
        comp = dict(_FillValue=None)
        encodings = {var: comp for var in ds.data_vars}
        encodings.update({coord: comp for coord in ds.coords})
        ds.to_netcdf(dir_path / str(self.varbl+'_'+self.ytag+'_'+self.outfile),encoding=encodings,format="NETCDF3_64BIT")
        
        if self.checker:
            new_vals = newElevData_Checker(ds,self.all_mws[0],self.indir,self.cgrid)
            self.checkVals['New (Tg/yr)'] = new_vals
            self.checkVals[['Orig (Tg/yr)','New (Tg/yr)']] = self.checkVals[['Orig (Tg/yr)','New (Tg/yr)']].applymap(lambda x: rounding(x))
            print(self.checkVals)
    
    def combine_numvars(self,numval):
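        """Merge the per-species number emission files (numval+'_*') for this
        period into a single netCDF file."""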
        print('\nCombining number conc files.')
        dir_path = get_dir_path(self.outdir)
        if self.mean != None:
            self.ytag = str(self.mean_yr)+'CLIM'
        else:
            self.ytag = str(self.start)+'-'+str(self.end)
        data = xr.open_mfdataset(str(dir_path)+'/'+str(numval+'_*_'+self.ytag+'_'+self.outfile))
        data.load().to_netcdf(dir_path / str(numval+'_'+self.ytag+'_'+self.outfile),format="NETCDF3_64BIT")
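

## Minimal usage sketch (illustrative only; the file names, directories, grid
## settings, and species below are assumptions, not packaged defaults):
##
##   emis = gen_elev_emis(2010, 2014, None,
##                        variable='bc_a4',
##                        indir='input_data', outdir='output_data',
##                        grid='/path/to/SE_grid_file.nc', res='1deg')
##   emis.prod_emis()                      # writes bc_a4_2010-2014_E3SM_elev_emis.nc (+ number file)
##   emis.combine_numvars(emis.numfile)    # optionally merge the number emission files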
    
