## This script pre-processes GFED4 emissions for FlexEmis.
## Use it to generate fire emissions from raw/new GFED4 data, i.e. data NOT included in the CMIP6 input4MIPs forcing datasets.

import numpy as np
import pandas as pd
import xarray as xr
from datetime import datetime
import pkg_resources
import h5py

from src.utils import sec_in_mon, get_dir_path

def _read_gfed4_emission_factors(aer, sources):
    """Return the per-source emission factors for species ``aer``.

    Reads ``../data/GFED4_Emission_Factors.csv`` (relative to this module, via
    ``pkg_resources``) and selects the row whose ``spc`` column equals ``aer``.
    The result is a 1-D array ordered like ``sources``.

    Raises:
        ValueError: if ``aer`` has no row in the table.
    """
    resource_path = '/'.join(('..', 'data', 'GFED4_Emission_Factors.csv'))
    # NOTE(review): pkg_resources is deprecated and may warn about '..' in
    # resource paths; consider migrating to importlib.resources.
    with pkg_resources.resource_stream(__name__, resource_path) as stream:
        table = pd.read_csv(stream)
    rows = table.loc[table['spc'] == aer, sources]
    if rows.empty:
        raise ValueError(
            f"species {aer!r} not found in GFED4_Emission_Factors.csv"
        )
    return rows.values[0]


def _gfed4_monthly_emissions(h5file, month, sources, emission_factors):
    """Return one month of species emissions on the 720x1440 GFED4 grid.

    For ``month`` (1-12), multiplies ``/emissions/<MM>/DM`` by each source's
    ``/emissions/<MM>/partitioning/DM_<source>`` fraction and by that source's
    emission factor, summing over ``sources`` (paired with
    ``emission_factors`` by position).
    """
    group = f'/emissions/{month:02d}'
    dry_matter = h5file[group + '/DM'][:]
    total = np.zeros((720, 1440))
    for source, factor in zip(sources, emission_factors):
        fraction = h5file[f'{group}/partitioning/DM_{source}'][:]
        total += dry_matter * fraction * factor
    return total


def _build_gfed4_dataset(aer, emissions, start_year, end_year):
    """Wrap ``emissions`` (time, latitude, longitude) in an xarray Dataset.

    The time axis holds one month-end timestamp per month from January of
    ``start_year`` through December of ``end_year``.
    """
    # MonthEnd() is the offset behind the deprecated 'M' frequency alias.
    time = pd.date_range(f'{start_year}-01-01', f'{end_year}-12-31',
                         freq=pd.offsets.MonthEnd())
    # 0.25-degree cell centres: latitude runs north to south, longitude -180..180.
    lat = np.arange(89.875, -90.125, -0.25).round(3).astype(np.float32)
    lon = np.arange(-179.875, 180.125, 0.25).round(3).astype(np.float32)

    data_vars = {
        aer: (['time', 'latitude', 'longitude'], emissions,
              {'units': 'kg/m2/s'}),
    }
    coords = {
        'time': (['time'], time),
        'latitude': (['latitude'], lat,
                     {'units': 'degrees_north', 'long_name': 'latitude'}),
        'longitude': (['longitude'], lon,
                      {'units': 'degrees_east', 'long_name': 'longitude'}),
    }
    attrs = {
        'comment': 'GFED data produced for the E3SM emission pre-processor',
        'contact': 'Taufiq Hassan (taufiq.hassan@pnnl.gov)',
        'creation_date': datetime.today().strftime('%Y-%m-%d %H:%M:%S'),
        'grid': "0.25x0.25 degree latitudexlongitude",
        'nominal_resolution': "25 km",
    }
    ds = xr.Dataset(data_vars=data_vars, coords=coords, attrs=attrs)
    ds.encoding = {'unlimited_dims': {'time'}}
    return ds


def gen_GFED4(aer, indir, outdir, start_year, end_year):
    """Build a monthly biomass-burning emission NetCDF file for one species.

    For each year from ``start_year`` to ``end_year`` (inclusive) this reads
    ``<indir>/GFED4.1s_<year>.hdf5``. It combines each month's dry matter (DM)
    with the per-source partitioning and the emission factors for ``aer``, and
    writes ``<outdir>/<aer>_biomass_burning_emis_GFED_<start>-<end>.nc``.

    The per-month division by ``sec_in_mon`` and by 1e3 converts the result to
    kg/m2/s (the units attribute written). This presumes emission factors in
    g/kg DM and monthly DM totals. TODO confirm against the GFED4.1s docs.

    Args:
        aer: species name. It must match a value in the ``spc`` column of
            GFED4_Emission_Factors.csv and is used as the output variable name.
        indir: input location, resolved with ``get_dir_path``.
        outdir: output location, resolved with ``get_dir_path``.
        start_year: first year to process.
        end_year: last year to process (inclusive).

    Returns:
        The output file name (without directory).

    Raises:
        ValueError: if ``end_year < start_year`` or ``aer`` is not in the
            emission-factor table.
    """
    if end_year < start_year:
        raise ValueError(
            f"end_year ({end_year}) must not be before start_year ({start_year})"
        )

    sources = ['SAVA', 'BORF', 'TEMF', 'DEFO', 'PEAT', 'AGRI']
    emission_factors = _read_gfed4_emission_factors(aer, sources)
    print(emission_factors)

    n_months = (end_year - start_year + 1) * 12
    emissions = np.zeros((n_months, 720, 1440))

    in_dir = get_dir_path(indir)
    for year_idx, year in enumerate(range(start_year, end_year + 1)):
        fname = f'{in_dir}/GFED4.1s_{year}.hdf5'
        print(fname)
        # Context manager so each yearly HDF5 file is closed after use.
        with h5py.File(fname, 'r') as h5file:
            for month in range(1, 13):
                monthly = _gfed4_monthly_emissions(
                    h5file, month, sources, emission_factors)
                # Divide by seconds in the month, then by 1e3 -> kg/m2/s.
                emissions[12 * year_idx + month - 1] = (
                    monthly / sec_in_mon(year, month) / 1e3)

    ds = _build_gfed4_dataset(aer, emissions, start_year, end_year)

    # Disable _FillValue on every variable; use a fresh dict per variable.
    encodings = {name: {'_FillValue': None}
                 for name in [*ds.data_vars, *ds.coords]}

    out_name = f'{aer}_biomass_burning_emis_GFED_{start_year}-{end_year}.nc'
    out_dir = get_dir_path(outdir)
    ds.load().to_netcdf(out_dir / out_name, encoding=encodings)
    return out_name
