import xarray as xr

from src.utils import get_dir_path, get_sectors, get_emis_vertint

def origSurfData_Checker(years, sectors, indir, filenames, start, end, var, mean, fracs):
    """Compute per-sector global surface-emission totals (Tg/yr) from original CEDS files.

    Parameters
    ----------
    years : int
        Number of years the totals should represent; totals from shorter
        time slices are scaled up to this span.
    sectors : list of str
        Sector names, resolved against each file via ``get_sectors``.
    indir : str
        Input directory, resolved through ``get_dir_path``.
    filenames : list of str
        One CEDS emission filename per sector (may repeat consecutively).
    start, end : list
        Start/end years (per sector) used to slice the time axis.
    var : str
        Species name; the first two characters select the CEDS variable.
    mean : optional
        If not ``None``, a monthly climatology is taken before summing.
    fracs : list of float
        Per-sector scaling fractions applied to each total.

    Returns
    -------
    list
        One global emission total per sector, in the order given.
    """
    species_originname = {'so': 'SO2_em_anthro',
                          'bc': 'BC_em_anthro',
                          'po': 'OC_em_anthro'}
    in_dir_path = get_dir_path(indir)
    orig_varname = species_originname[var[:2]]
    # seconds/year * kg->Tg conversion factor
    factbb = 86400.0 * 365.0 * 1e-9
    # The cell-area grid is the same for every file: load it once up front.
    # (The original code re-opened the identical file inside the loop in
    # both branches of a dead if/else on '1850'.)
    grid = xr.open_dataset(in_dir_path / 'gridarea_CEDS.nc')['cell_area']
    totlist = []
    prev_file = None
    data = None
    av_sectors = None
    for sect, file, frac, s, e in zip(sectors, filenames, fracs, start, end):
        # Only re-open the dataset when the filename changes between sectors.
        if file != prev_file:
            data = xr.open_dataset(in_dir_path / file).sel(time=slice(str(s), str(e)))
            av_sectors = get_sectors(data)
            prev_file = file
        data_sect = data[orig_varname].sel(sector=av_sectors[sect])
        if mean is not None:
            xr.set_options(use_flox=True)
            # Collapse to a 12-month climatology, keeping 'time' as the dim name.
            data_sect = data_sect.groupby('time.month').mean('time')
            data_sect = data_sect.rename({'month': 'time'})
        # Sum over time (12 months), weight by cell area, convert to Tg/yr.
        val = (data_sect.sum('time') * grid * factbb / 12).sum(['lat', 'lon']).values * frac
        # Scale up slices shorter than the requested span.
        if (e - s + 1) < years:
            val = val * years
        totlist.append(val)
    return totlist

def newSurfData_Checker(ndata, mw, indir, mean, grid):
    """Compute global surface-emission totals (Tg/yr) for each variable in a
    regridded (ne30) dataset.

    Parameters
    ----------
    ndata : xarray.Dataset
        Regridded emission dataset with 'date', 'lat', 'lon', 'lev'
        coordinate variables plus one or more emission variables in
        molecules cm-2 s-1.
    mw : float
        Molecular weight of the species (g/mol).
    indir : str
        Directory containing the grid-area file, via ``get_dir_path``.
    mean : optional
        Unused; kept for signature compatibility with the callers and
        with ``origSurfData_Checker``.
    grid : str
        Filename of the grid-area dataset (area in steradians).

    Returns
    -------
    list
        One global emission total per emission variable.
    """
    in_dir_path = get_dir_path(indir)
    avgod = 6.022e23                 # Avogadro's number (molecules/mol)
    factbb = 86400.0 * 365.0 * 1e-9  # seconds/year * kg->Tg
    factor = 10 * mw / avgod         # molecules cm-2 s-1 -> kg m-2 s-1
    # Emission variables are everything except the coordinate variables.
    coord_vars = ('date', 'lat', 'lon', 'lev')
    varbls = [v for v in ndata.variables if v not in coord_vars]
    print(varbls)
    # The cell-area field is loop-invariant: open the dataset once,
    # not once per variable as before.
    ne30area = xr.open_dataset(in_dir_path / grid)['area']
    totlist = []
    for v in varbls:
        # Sum emissions over time, then area-weight (steradians * R_earth^2
        # gives m^2) and convert to Tg/yr, dividing by 12 months.
        dd = (ndata[v] * factor).sum('time')
        totlist.append((factbb * dd * ne30area * (6.37122e6)**2 / 12).sum().values)
    return totlist

def origElevData_Checker(years, sectors, indir, outdir, filename, start, end, var, mean, fracs, mws, cgrid):
    """Compute per-sector global elevated-emission totals (Tg/yr) from the
    original source files (biomass burning, continuous volcanoes, and
    already-regridded *_ELEV sectors).

    Parameters
    ----------
    years : int
        Target span in years; shorter slices are scaled up to it.
    sectors : list of str
        Sector names: 'BB', 'contvolc', or names containing '_ELEV'.
    indir : str
        Input directory, resolved via ``get_dir_path``.
    outdir : path-like
        Directory holding regridded *_ELEV files (used directly with '/').
    filename : list
        One filename per sector; ``None`` for 'BB' selects the default
        GFED file.
    start, end : list
        Per-sector start/end years for time slicing.
    var : str
        Species name; first two characters select the BB variable name.
    mean : optional
        If not ``None``, take a monthly climatology before summing (BB only).
    fracs : list of float
        Per-sector scaling fractions.
    mws : list of float
        Per-sector molecular weights (g/mol).
    cgrid : str
        Filename of the ne30 grid-area dataset (for *_ELEV sectors).

    Returns
    -------
    list
        One global emission total per sector, in the order given.
    """
    var_originname = {'so': 'SO2',
                      'bc': 'BC',
                      'po': 'OC'}

    in_dir_path = get_dir_path(indir)
    avgod = 6.022e23                 # Avogadro's number (molecules/mol)
    factbb = 86400 * 365 * 1e-9      # seconds/year * kg->Tg
    totlist = []
    for sect, file, mw, frac, s, e in zip(sectors, filename, mws, fracs, start, end):
        factor = 10 * mw / avgod     # molecules cm-2 s-1 -> kg m-2 s-1
        if sect == 'BB':
            # BUG FIX: the original reassigned the function argument `var`
            # here, so a second 'BB' sector would raise KeyError on
            # var_originname[var[:2]]. Use a local name instead.
            spc = var_originname[var[:2]]
            if file is None:
                # Default GFED biomass-burning input with its own grid.
                gridfile = 'GFED_gridarea_p25.nc'
                file = spc + '_biomass_burning_emis_GFED_1997-2022.nc'
            else:
                gridfile = 'BB_gridarea_p25.nc'
            grid = xr.open_dataset(in_dir_path / gridfile)['cell_area']
            data = xr.open_dataset(in_dir_path / file).sel(time=slice(str(s), str(e)))
            if mean is not None:
                xr.set_options(use_flox=True)
                # Collapse to a 12-month climatology, keeping 'time' as the dim.
                data = data.groupby('time.month').mean('time')
                data = data.rename({'month': 'time'})
            cc = data[spc].sum('time')
            # Area-weight and convert to Tg/yr (the two original branches
            # computed this with equivalent operator orderings).
            val = (cc * grid * factbb / 12).sum() * frac
            if (e - s + 1) < years:
                val = val * years
            totlist.append(val.values)
        elif sect == 'contvolc':
            # Continuous volcanic degassing: vertically integrate, then
            # convert molecules -> mass with `factor`.
            grid = xr.open_dataset(in_dir_path / 'gridarea_RLL.nc')['cell_area']
            data = xr.open_dataset(in_dir_path / file)[sect]
            data = get_emis_vertint(data).sum('time')
            val = (data * grid * factbb * factor / 12).sum()
            totlist.append(val.values)
        elif '_ELEV' in sect:
            # Already-regridded elevated sector: read from outdir, weight by
            # ne30 cell area (steradians * R_earth^2 gives m^2).
            sect_name = sect.split('_')[0]
            ne30area = xr.open_dataset(in_dir_path / cgrid)['area']
            data = xr.open_dataset(outdir / file)
            bb = (data[sect_name] * factor).sum('time')
            val = (bb * factbb * ne30area * (6.37122e6)**2 / 12).sum()
            if (e - s + 1) < years:
                val = val * years
            totlist.append(val.values)
    return totlist

def newElevData_Checker(ndata, mw, indir, grid):
    """Compute global elevated-emission totals (Tg/yr) for each variable in a
    regridded (ne30) elevated-emission dataset.

    Parameters
    ----------
    ndata : xarray.Dataset
        Regridded dataset with 'date', 'lat', 'lon', 'altitude',
        'altitude_int' coordinate variables plus emission variables.
    mw : float
        Molecular weight of the species (g/mol).
    indir : str
        Directory containing the grid-area file, via ``get_dir_path``.
    grid : str
        Filename of the grid-area dataset (area in steradians).

    Returns
    -------
    list
        One global emission total per emission variable; the 'contvolc'
        total is divided by the number of years in the record, since that
        sector repeats every year.
    """
    in_dir_path = get_dir_path(indir)
    avgod = 6.022e23                 # Avogadro's number (molecules/mol)
    factbb = 86400.0 * 365.0 * 1e-9  # seconds/year * kg->Tg
    factor = 10 * mw / avgod         # molecules cm-2 s-1 -> kg m-2 s-1
    # Emission variables are everything except the coordinate variables.
    coord_vars = ('date', 'lat', 'lon', 'altitude', 'altitude_int')
    varbls = [v for v in ndata.variables if v not in coord_vars]
    # The cell-area field is loop-invariant: open the dataset once,
    # not once per variable as before.
    ne30area = xr.open_dataset(in_dir_path / grid)['area']
    totlist = []
    years = len(ndata.time) / 12
    for v in varbls:
        # Vertically integrate, sum over time, then area-weight
        # (steradians * R_earth^2 gives m^2) and convert to Tg/yr.
        cc = get_emis_vertint(ndata[v])
        dd = (cc * factor).sum('time')
        val = (factbb * dd * ne30area * (6.37122e6)**2 / 12).sum().values
        if v == 'contvolc':
            # contvolc repeats identically every year; normalize by span.
            val = val / years
        totlist.append(val)
    return totlist