import pandas as pd
import numpy as np
import sys
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
import glob
import os


# Column layout of every Trace.* CSV row, in file order. The leading
# "junk" field is never loaded (readtrace reads columns 1..end only).
cols = (
    "junk rank seqno tracelen timestamp duration "
    "loss_allreduce optimizer gradient_allreduce "
    "backward forward batch_read arsize"
).split()

# Per-rank phases that run without synchronizing; summed into 'work'.
workcols = "optimizer backward forward batch_read".split()

# Synchronization (allreduce) phases; summed into 'sync'.
synccols = "loss_allreduce gradient_allreduce".split()

def readtrace(dirname):
    """Load all ``Trace.*`` CSV files in *dirname* together with run metadata.

    Every file is parsed with the column layout in ``cols`` (skipping its
    first, "junk" field) as int64. The files are stacked into one DataFrame
    with a fresh 0..n-1 index. Two derived columns are added: 'work' (row sum
    of ``workcols``) and 'sync' (row sum of ``synccols``).

    Run parameters come from the last path component of *dirname*, which
    must look like ``perf-dcn02-1-1-64-300000``, i.e.
    ``<prefix>-<node>-<nodes>-<ppn>-<bs>-<ntraces>``. Trailing separators
    are allowed.

    Returns a dict with int entries 'nodes', 'ppn', 'bs' and 'ntraces'. It
    also holds an 'arch' label ('BDW', 'SKL', 'KNL' or 'UNK', guessed from
    the node-name prefix) and the trace DataFrame under 'df'.

    Raises ValueError if *dirname* contains no ``Trace.*`` files.
    """
    # Escape the directory so glob metacharacters in it are taken literally.
    # Sort so that row order does not depend on directory listing order.
    filelist = sorted(glob.glob(os.path.join(glob.escape(dirname), "Trace.*")))
    print(filelist)
    if not filelist:
        raise ValueError("no Trace.* files found in %r" % (dirname,))
    # Each file's rows are numbered from 0. Without ignore_index the combined
    # frame would carry duplicate index labels.
    df = pd.concat(
        [pd.read_csv(path, names=cols, dtype=np.int64,
                     usecols=range(1, len(cols)))
         for path in filelist],
        ignore_index=True)
    df['work'] = df[workcols].sum(axis=1)
    df['sync'] = df[synccols].sum(axis=1)

    # e.g. perf-dcn02-1-1-64-300000; normpath drops any trailing separator.
    descr = os.path.basename(os.path.normpath(dirname)).split('-')
    ret = {'nodes': int(descr[2]), 'ppn': int(descr[3]), 'bs': int(descr[4]),
           'ntraces': int(descr[5])}
    # kludge based on diamond node naming convention
    node = descr[1]
    if node.startswith('dcn'):
        arch = 'BDW'
    elif node.startswith('skl'):
        arch = 'SKL'
    elif node.startswith('knl'):
        arch = 'KNL'
    else:
        arch = 'UNK'
    ret['arch'] = arch
    ret['df'] = df
    return ret

def basic_chart(df, skip=2, outfile=None):
    """Line chart of per-batch total time, mean/max work and minimum sync.

    Statistics are taken over ranks within each batch (``seqno``). The first
    *skip* batches are left out; the original analysis always dropped 2.
    Values are multiplied by 1e-6, presumably converting trace microseconds
    to seconds (TODO confirm units).

    The x axis is the batch position ``skip .. nbatches-1``, not the seqno
    value. This module selects the non-interactive Agg backend at import, so
    ``plt.show()`` may display nothing. Pass *outfile* to also save the
    figure.
    """
    g = df.groupby('seqno')
    r = range(skip, len(g))
    # Aggregate only the columns that are plotted.
    workmean = g['work'].mean().iloc[skip:] * 1e-6
    workmax = g['work'].max().iloc[skip:] * 1e-6
    syncmin = g['sync'].min().iloc[skip:] * 1e-6
    dur = g['duration'].max().iloc[skip:] * 1e-6
    plt.plot(r, dur, label='total time', color='black')
    plt.plot(r, workmean, label='work mean')
    plt.plot(r, workmax, label='work max')
    plt.plot(r, syncmin, label='sync min')
    plt.xlabel('Batch #')
    plt.ylabel('Time(s)')
    plt.legend()
    if outfile is not None:
        plt.savefig(outfile)
    plt.show()


def overall_barchart(df, outfile=None):
    """Stacked bar chart of the total time each rank spent in each phase.

    One bar is drawn per rank. Each bar stacks the sums over all rows of
    every column in ``workcols`` and then ``synccols``, scaled by 1e-6
    (presumably microseconds to seconds; TODO confirm).

    Pass *outfile* to also save the figure. The Agg backend selected at
    import means ``plt.show()`` may display nothing.
    """
    allcols = workcols + synccols
    width = 0.3
    # One groupby pass for all phases, instead of re-summing per column.
    totals = df.groupby('rank')[allcols].sum() * 1e-6
    ind = np.arange(len(totals))
    bottom = np.zeros(len(totals))
    legendv = []
    for t in allcols:
        yvals = totals[t].to_numpy(dtype=float)
        bars = plt.bar(ind, yvals, width, bottom=bottom)
        bottom = bottom + yvals
        legendv.append(bars[0])
    # Label bars with the actual rank ids rather than their positions.
    plt.xticks(ind, totals.index)
    plt.xlabel('rank')
    plt.ylabel('Time(s)')
    plt.legend(legendv, allcols)
    if outfile is not None:
        plt.savefig(outfile)
    plt.show()



def imbalance_chart(df, skip=2, outfile=None):
    """Plot per-batch work imbalance (lower panel) and tracelen range (upper).

    For each batch (``seqno``) after the first *skip*, and for each column
    in ``workcols``, the spread ``max - min`` across ranks is stacked. The
    batch duration (max over ranks) is drawn on top of it. The upper panel
    shows the min and max ``tracelen`` over ranks. Times are scaled by 1e-6
    (presumably microseconds to seconds; TODO confirm).

    Prints the total duration and the total imbalance: the sum of per-batch
    maxima minus the sum of per-batch minima, over all work columns.

    Pass *outfile* to also save the figure. The Agg backend selected at
    import means ``plt.show()`` may display nothing.
    """
    g = df.groupby('seqno')
    lo = g[workcols].min().iloc[skip:]
    hi = g[workcols].max().iloc[skip:]

    totalimbal = 0
    for col in workcols:
        # worst case (sum of maxima) minus best case (sum of minima)
        totalimbal += hi[col].sum() * 1e-6 - lo[col].sum() * 1e-6
    spread = [(hi[col] - lo[col]) * 1e-6 for col in workcols]

    dur = g['duration'].max().iloc[skip:] * 1e-6
    print('time', dur.sum(), 'imbal', totalimbal)

    fig, axes = plt.subplots(2, 1, sharex=True,
                             gridspec_kw={'height_ratios': [1, 3]})
    r = range(skip, len(g))

    ax = axes[1]
    ax.stackplot(r, spread, labels=workcols)
    ax.plot(r, dur, color='black', label='duration')
    ax.set_ylabel('Time(s)')
    ax.set_title('Imbalance and total time')
    ax.legend(ncol=3, loc='center left')

    ax = axes[0]
    ax.plot(r, g['tracelen'].min().iloc[skip:], color='red', marker='.',
            linestyle='none', label='min tracelen')
    ax.plot(r, g['tracelen'].max().iloc[skip:], color='black', marker='.',
            linestyle='none', label='max tracelen')
    ax.legend()
    ax.set_ylabel('tracelen')
    ax.set_title('Trace Length per rank')

    axes[1].set_xlabel('Batch #')
    if outfile is not None:
        fig.savefig(outfile)
    plt.show()



def batch_stats(df):
    """Print load-imbalance totals for work and allreduce time, plus bytes.

    For each batch (``seqno``), the minimum and maximum over ranks of 'work'
    and 'sync' are taken. Each is then summed over all batches and scaled
    by 1e-6 (presumably microseconds to seconds; TODO confirm).
    - Summing the per-batch minima gives the least time the run could have
      taken.
    - Summing the maxima gives the time it actually took.
    - The difference is the load imbalance.

    Also prints the allreduce payload. This is the per-batch minimum
    'arsize' summed over batches and multiplied by 4 (presumably 4-byte
    elements; TODO confirm). That total is also printed divided by the
    number of batches.
    """
    g = df.groupby('seqno')
    # Per-batch extremes across ranks.
    workmin = g['work'].min() * 1e-6
    workmax = g['work'].max() * 1e-6
    armin = g['sync'].min() * 1e-6
    armax = g['sync'].max() * 1e-6
    print("work min", workmin.sum(), 'max', workmax.sum())
    print("allreduce min", armin.sum(), 'max', armax.sum())
    arbytes = g['arsize'].min().sum() * 4
    print('arbytes', arbytes, 'arbytes/batch', arbytes / len(g))

def summary(dat):
    """Condense one run (a dict as returned by readtrace) into a 2-row DataFrame.

    Row 'Actual': for every batch (``seqno``), take the row whose 'work' is
    largest (the slowest rank). Sum its work columns and 'sync' over all
    batches. If several ranks tie for the maximum, a warning is printed and
    only the first such row per batch is used.

    Row 'Best': sum over batches of the per-batch mean (over ranks) of each
    work column, i.e. perfectly balanced work. It reuses the Actual 'sync'.

    Columns: arch, class, nodes, ppn, tps, ntps, then each of ``workcols``
    and 'sync'.
    - tps is ``ntraces / (total * 1e-6)``, presumably traces per second if
      trace times are in microseconds (TODO confirm).
    - ntps is tps divided by the rank count ``nodes * ppn``.
    - The per-phase columns are ``1e-3 * ranks * time / ntraces``,
      presumably milliseconds per trace normalized by rank count.
    """
    df = dat['df']
    g = df.groupby('seqno')
    means = g[workcols].mean().sum()
    # Rows holding the per-batch maximum 'work'.
    ismax = g['work'].transform('max') == df['work']
    dfmax = df[ismax]
    if len(g) != len(dfmax):
        print("Warning: too many maxima", len(g), len(dfmax))
        # Ties would otherwise count a batch more than once in the sums.
        dfmax = dfmax.drop_duplicates(subset='seqno')
    maxcols = workcols + ['sync']
    maxes = dfmax[maxcols].sum()
    totranks = dat['nodes'] * dat['ppn']
    tps = dat['ntraces'] / (maxes.sum() * 1e-6)
    besttps = dat['ntraces'] / ((means.sum() + maxes['sync']) * 1e-6)
    dfdict = {
        'arch':   [dat['arch'], dat['arch']],
        'class':  ['Actual', 'Best'],
        'nodes':  [dat['nodes'], dat['nodes']],
        'ppn':    [dat['ppn'], dat['ppn']],
        'tps':    [tps, besttps],
        'ntps':   [tps / totranks, besttps / totranks],
    }
    columns = ['arch', 'class', 'nodes', 'ppn', 'tps', 'ntps'] + maxcols
    for k in workcols:
        maxval = 1e-3 * totranks * maxes[k] / dat['ntraces']
        meanval = 1e-3 * totranks * means[k] / dat['ntraces']
        dfdict[k] = [maxval, meanval]
    synper = 1e-3 * totranks * maxes['sync'] / dat['ntraces']
    dfdict['sync'] = [synper, synper]
    return pd.DataFrame(dfdict, columns=columns)


def _label_segments(containers, df1, allcols):
    """Write each stacked segment's value just above that segment's base."""
    for t, bars in zip(allcols, containers):
        # The 'sync' bars start at row 1 (row 0 has none), so offset values.
        vals = df1[t].iloc[1:] if t == 'sync' else df1[t]
        for patch, val in zip(bars.patches, vals):
            x0, y0 = patch.get_xy()
            plt.text(x0 + patch.get_width() / 2, y0 + 0.1, "%.1f" % (val),
                     ha='center')


def _mark_imbalance(df1, allcols, actual_row, best_row, width, text_dy):
    """Mark the *actual_row* stack height over the *best_row* bar and label the gap.

    The line spans the bar at x = *best_row* in data coordinates. The label
    is the percentage by which the actual total exceeds the best total.
    """
    actual = df1[allcols].iloc[actual_row].sum()
    best = df1[allcols].iloc[best_row].sum()
    imbal = 100 * (actual - best) / actual
    plt.hlines(actual, best_row - width / 2, best_row + width / 2)
    plt.text(best_row - 0.1, actual - text_dy, "%.0f%%" % (imbal))


def plotbreakdown(df):
    """Stacked bar chart of normalized per-phase time per trace; saves 'imbal.png'.

    *df* presumably holds rows produced by summary(). It needs the columns
    'class', 'nodes', 'ppn', each of ``workcols`` and 'sync'. All 'Actual'
    rows are kept, and 'Best' rows only when ppn > 1. Each segment is
    labeled with its value.

    When at least 5 rows remain, rows 1/2 and 3/4 are treated as
    Actual/Best pairs. For each pair, the imbalance percentage is drawn over
    the Best bar.
    """
    width = 0.5
    plt.figure(figsize=(5.5, 5.7))
    allcols = workcols + ['sync']
    df1 = df[(df['ppn'] > 1) | (df['class'] == 'Actual')]
    ind = np.arange(len(df1))
    start = np.zeros(len(df1))

    # One colour per stacked column. Two extra samples are taken so that
    # both ends of the colormap can be trimmed off.
    colors = plt.cm.inferno(np.linspace(0, 1, len(allcols) + 2))[1:-1]
    containers = []
    for t, c in zip(allcols, colors):
        vals = df1[t].to_numpy(dtype=float)
        if t == 'sync':  # skip sync for first bar because it should be 0
            bars = plt.bar(ind[1:], vals[1:], width, bottom=start[1:], color=c)
        else:
            bars = plt.bar(ind, vals, width, bottom=start, color=c)
        containers.append(bars)
        start += vals
    _label_segments(containers, df1, allcols)

    xlabels = [cls + '\n' + str(nodes * ppn) + ' socket'
               for cls, nodes, ppn in zip(df1['class'], df1['nodes'], df1['ppn'])]

    # Rows 1/2 and 3/4 are presumably the Actual/Best pairs of two
    # multi-socket runs, with row 0 being the single-socket Actual.
    if len(df1) >= 5:
        _mark_imbalance(df1, allcols, 1, 2, width, 1.5)
        _mark_imbalance(df1, allcols, 3, 4, width, 5)

    plt.legend([bars[0] for bars in containers], allcols, ncol=2)
    plt.xticks(range(len(df1)), xlabels)
    plt.ylabel('Normalized time/trace(msec)')
    plt.title('BDW Scaling and Imbalance')
    plt.savefig('imbal.png')
    plt.show()


