import numpy as np
import xarray as xr
from pathlib import Path
import calendar

def get_dates(start, end, *, first_year=1849, last_year=2029):
    """Return mid-month dates as YYYYMMDD integers for every month of years start..end.

    Each entry has the form ``YYYYMM15`` (day fixed at 15). Years are clipped to
    the window ``[first_year, last_year]``; the defaults reproduce the historical
    hard-coded 1849-2029 window, so years outside it are dropped unless the
    window is widened by the caller.

    Args:
        start: First year (anything accepted by ``int()``), inclusive.
        end: Last year (anything accepted by ``int()``), inclusive.
        first_year: Earliest year that may appear in the output (keyword-only).
        last_year: Latest year that may appear in the output (keyword-only).

    Returns:
        1-D ``np.int32`` array, chronologically ordered; empty when the
        requested range does not intersect the window or ``start > end``.
        Note: YYYYMMDD values stay within int32 only for years below ~214748.
    """
    lo = max(int(start), first_year)
    hi = min(int(end), last_year)
    years = np.arange(lo, hi + 1, dtype=np.int64)
    months = np.arange(1, 13, dtype=np.int64)
    # Broadcast (n_years, 1) against (1, 12) -> one row of 12 months per year,
    # using exact integer arithmetic instead of float offsets.
    dates = years[:, None] * 10000 + months[None, :] * 100 + 15
    return dates.ravel().astype(np.int32)

def get_num_factor(mw, rho, diam):
    """Return the number-conversion factor for spherical particles of diameter *diam*.

    Computes ``mw / N_A / rho / V * 1e-3 * 6.022e26``, where ``V`` is the volume
    of a sphere whose diameter is ``diam * 1e-6`` (presumably micrometres to
    metres -- confirm units with callers) and ``N_A = 6.022e23``.
    Works element-wise on NumPy arrays as well as on scalars.
    """
    avogadro = 6.022e23
    # Sphere volume: (pi / 6) * d**3, kept in the original operand order so the
    # floating-point result is bit-identical.
    particle_volume = (1 / 6) * np.pi * (diam * 1e-6) ** 3
    return mw / avogadro / rho / particle_volume * 1e-3 * 6.022e26

def get_dir_path(path):
    """Return *path* as a ``Path``; an empty string means the absolute current directory."""
    if path == '':
        return Path('.').absolute()
    return Path(path)

def get_sectors(data):
    """Map E3SM sector abbreviations to the sector indices declared in *data*.

    Reads the ``ids`` attribute of ``data['sector']``, expected to look like
    ``"0: Agriculture; 1: Energy; ..."``, and translates each full sector name
    into its E3SM abbreviation (e.g. ``'Agriculture' -> 'AGR'``).

    Args:
        data: Dataset-like object whose ``data['sector'].ids`` is a
            ``;``-separated string of ``<index>: <name>`` entries.

    Returns:
        dict mapping abbreviation -> integer index, in the order listed in ``ids``.

    Raises:
        Exception: when ``data`` has no ``sector`` entry or it carries no ``ids``
            string (original error chained as ``__cause__``).
        ValueError: when an entry is not of the form ``<index>: <name>``.
        KeyError: when a sector name is not one of the known E3SM sectors.
    """
    print('\nGrabbing E3SM sector info in a dict')
    orig_sectors = {'Agriculture': 'AGR',
             'Energy': 'ENE',
             'Industrial': 'IND',
             'Transportation': 'TRA',
             'Residential, Commercial, Other': 'RCO',
             'Solvents production and application': 'SLV',
             'Waste': 'WST',
             'International Shipping': 'SHP'}
    print('\nE3SM sectors:',orig_sectors)
    print('\nGetting the sectors available in the dataset')
    try:
        items_data = data['sector'].ids.split(';')
    except (KeyError, AttributeError, TypeError) as err:
        # Narrow catch: a bare except would also swallow KeyboardInterrupt/SystemExit.
        raise Exception('\nNo sectors available in the selected dataset!\nPossibly due to no sector dimension.') from err
    sectors = {}
    for item in items_data:
        if not item.strip():
            continue  # tolerate empty entries, e.g. produced by a trailing ';'
        parts = item.split(':')
        if len(parts) < 2:
            raise ValueError(f'Malformed sector entry {item!r}; expected "<index>: <name>"')
        name = parts[1].strip()
        if name not in orig_sectors:
            raise KeyError(f'Unknown sector name {name!r}; expected one of {list(orig_sectors)}')
        sectors[orig_sectors[name]] = int(parts[0])
    print('\nAvailable sectors: ', sectors)
    return sectors

def get_zcol(altitude_air=None):
    """Compute layer thicknesses for a set of altitude mid-points.

    Interfaces are reconstructed so that each mid-point is the average of the
    two interfaces bounding it, starting from an interface at 0.

    Args:
        altitude_air: Layer mid-point altitudes (km, per the km->cm conversion
            below). ``None`` selects the built-in 13-level default grid. Any
            array-like (list, tuple, NumPy array, scalar) is accepted.

    Returns:
        tuple ``(zcol, altitude_air, altitude_int_air)``:
            zcol: ``xr.DataArray`` of layer thicknesses in cm on dim ``altitude``.
            altitude_air: the mid-points as a 1-D float64 array (a copy).
            altitude_int_air: the ``len(altitude_air) + 1`` interface altitudes.
    """
    # `is None` (not `== None`): comparing a NumPy array to None is element-wise
    # and makes the `if` raise "truth value of an array is ambiguous".
    if altitude_air is None:
        altitude_air = np.array([0.063000001013279, 0.202000007033348, 0.36599999666214,0.554000020027161, \
                                0.767000019550323, 1.00300002098083, 1.26199996471405,1.64100003242493, \
                                2.23200011253357, 3.02500009536743, 4.00899982452393,5.15700006484985, \
                                6.35599994659424], dtype=np.float64)
    else:
        altitude_air = np.atleast_1d(np.array(altitude_air, dtype=np.float64))

    altitude_int_air = np.zeros(len(altitude_air) + 1, dtype=np.float64)
    for i, mid in enumerate(altitude_air):
        # mid = (lower + upper) / 2  =>  upper = 2 * mid - lower
        altitude_int_air[i + 1] = mid * 2 - altitude_int_air[i]

    dz = np.diff(altitude_int_air) * 1e5  # km to cm
    zcol = xr.DataArray(dz, coords={'altitude': altitude_air}, dims=["altitude"])
    return zcol, altitude_air, altitude_int_air

def rounding(n):
    """Format a number as a short display string.

    Rules (note: fractional digits are *truncated*, not rounded):
      * ``None``, strings and NaN -> ``'-'``.
      * ``|n| <= 1e-4`` or ``|n| >= 1e4`` (including 0) -> ``'{:.0e}'`` format,
        e.g. ``12345.0 -> '1e+04'``.
      * otherwise, with the fraction taken to 6 decimals:
          - fraction is all zeros -> ``str(|n|)`` with sign, e.g. ``5.0 -> '5.0'``;
          - first decimal digit non-zero -> 3 decimals, e.g. ``1.23456 -> '1.234'``;
          - else keep the leading zeros plus 2 more digits,
            e.g. ``1.05 -> '1.050'``, ``0.00123 -> '0.0012'``.
      * any TypeError/ValueError/OverflowError while formatting -> ``'-'``.
    """
    # isinstance also covers str subclasses such as np.str_, which would
    # otherwise reach np.isnan and raise TypeError.
    if n is None or isinstance(n, str) or np.isnan(n):
        return '-'
    if not (1e-4 < abs(n) < 1e4):
        return '{:.0e}'.format(n)
    try:
        sgn = '-' if n < 0 else ''
        mag = abs(n)
        whole = int(mag)
        frac = format(mag - whole, 'f')  # fixed-point string such as '0.012345'
        digits = frac[2:]
        if int(digits) < 1:
            return sgn + str(mag)
        for i, ch in enumerate(digits):
            if ch != '0':
                stop = 5 if i == 0 else i + 4
                return sgn + str(whole) + frac[1:stop]
    except (TypeError, ValueError, OverflowError):
        return '-'
    
def get_emis_vertint(data):
    """Vertically integrate *data* over its ``altitude`` dimension.

    Layer interfaces are rebuilt from the ``altitude`` mid-points (each
    mid-point is the average of its bounding interfaces, starting at 0). Layer
    thicknesses are scaled by 1e5 (presumably km -> cm; confirm the altitude
    units) and rounded to 2 decimals. They are multiplied into *data*, and the
    result is summed over ``altitude``.

    Args:
        data: xarray object with an ``altitude`` dimension/coordinate.

    Returns:
        *data* weighted by layer thickness and summed over ``altitude``.
    """
    # Work on the raw values: indexing the DataArray element-by-element would
    # build a new 0-d DataArray for every level.
    mid = np.asarray(data.altitude.values, dtype=np.float64)
    interfaces = np.zeros(len(mid) + 1, dtype=np.float64)
    for i, m in enumerate(mid):
        interfaces[i + 1] = m * 2 - interfaces[i]
    dz = np.diff(interfaces) * 1e5
    weights = xr.DataArray(np.round(dz, 2), coords={'altitude': data.altitude}, dims=["altitude"])
    return (data * weights).sum(dim='altitude')

def sec_in_mon(year, month):
    """Return the number of seconds in calendar month *month* of *year* (leap-year aware)."""
    n_days = calendar.monthrange(year, month)[1]
    return n_days * 86400  # 24 h * 60 min * 60 s