# -*- coding: utf-8 -*-

import traceback
import argparse
import sys
import os
from glob import glob
import pprint
import uuid
import time
import subprocess
import signal
import pypdt
import numpy as np
import torch
import pyprob
from pyprob import RemoteModel, PriorInflation, InferenceEngine, InferenceNetwork, ObserveEmbedding
import pyprob.diagnostics
from pyprob.util import get_time_stamp, to_tensor, to_numpy, days_hours_mins_secs_str
import matplotlib.pyplot as plt
import matplotlib as mpl
from matplotlib import gridspec
from matplotlib.ticker import MultipleLocator
from mpl_toolkits.mplot3d import Axes3D
from mpl_toolkits.mplot3d.art3d import Poly3DCollection, Line3DCollection
import matplotlib.image as mpimg
import matplotlib.cm as cm
# Use the non-interactive Agg backend so figures can be rendered and saved without a display.
plt.switch_backend('agg')
# Draw minus signs with the ASCII hyphen instead of the Unicode minus glyph
# (presumably to avoid missing-glyph issues with some fonts).
mpl.rcParams['axes.unicode_minus'] = False
# Five evenly spaced colours sampled from the inferno colormap.
colors_inferno = [cm.inferno(x) for x in np.linspace(0, 1, 5)]

# Energy per calorimeter deposit count; plot_trace multiplies deposit counts by it to report GeV.
min_energy_deposit = 0.05  # GeV



def particle_name(pid):
    """Return a readable particle name for a (signed) particle id via pypdt.

    Negative ids are looked up by their absolute value and prefixed with
    'anti-'; non-negative ids are looked up directly.
    """
    if pid < 0:
        return 'anti-' + pypdt.get(-pid).name
    return pypdt.get(pid).name


# Human-readable tau decay channel labels, indexed by the integer 'channel_index'
# variable of a trace (see plot_trace). The list holds 38 channels (indices 0-37).
channel_names = ['tau-_nu_taunu_ebe-',
                 'tau-_nu_taunu_mubmu-',
                 'tau-_pi-nu_tau',
                 'tau-_K-nu_tau',
                 'tau-_pi-pinu_tau',
                 'tau-_K-KSnu_tau',
                 'tau-_K-KLnu_tau',
                 'tau-_KSpi-nu_tau',
                 'tau-_pi-KLnu_tau',
                 'tau-_K-pinu_tau',
                 'tau-_pi-pipinu_tau',
                 'tau-_pi+pi-pi-nu_tau',
                 'tau-_KSpi-pinu_tau',
                 'tau-_pi-KLpinu_tau',
                 'tau-_K-pi+pi-nu_tau',
                 'tau-_K-pipinu_tau',
                 'tau-_K+K-pi-nu_tau',
                 'tau-_K-KSpinu_tau',
                 'tau-_K-KLpinu_tau',
                 'tau-_KSpi-KLnu_tau',
                 'tau-_KSKSpi-nu_tau',
                 'tau-_pi-KLKLnu_tau',
                 'tau-_K+K-K-nu_tau',
                 'tau-_pi+pi-pi-pinu_tau',
                 'tau-_pi-pipipinu_tau',
                 'tau-_K-pipipinu_tau',
                 'tau-_KSpi-pipinu_tau',
                 'tau-_pi-KLpipinu_tau',
                 'tau-_K+K-pi-pinu_tau',
                 'tau-_etapi-nu_tau',
                 'tau-_K-etanu_tau',
                 'tau-_etaetapi-nu_tau',
                 'tau-_etapi-pinu_tau',
                 'tau-_pi+pi-pi-pipinu_tau',
                 'tau-_pi-pipipipinu_tau',
                 'tau-_K-pipipipinu_tau',
                 'tau-_pi+pi-pi-pipipinu_tau',
                 'tau-_pi+pi+pi-pi-pi-pinu_tau']


def trace_to_numpy(trace):
    """Flatten the variables of interest of a trace into one numpy vector.

    The result concatenates, in this order: mother momentum x, y and z, the
    channel index, the flattened final-state particle tensor and the
    calorimeter deposit tensor. Tensors are moved to the CPU before conversion.
    """
    variables = trace.named_variables
    leading_names = ('mother_momentum_x', 'mother_momentum_y', 'mother_momentum_z', 'channel_index')
    parts = [variables[name].value for name in leading_names]
    parts.append(variables['final_state_particles'].value.view(-1))
    parts.append(variables['calorimeter_n_deposits'].value)
    return torch.cat(parts).cpu().numpy()


def distribution_to_numpy(dist, file_name=None):
    """Convert every trace of *dist* to a flat numpy row with trace_to_numpy.

    Args:
        dist: pyprob ``Empirical`` distribution over traces.
        file_name: optional path; when given, the converted array is also
            written to disk with ``np.save``.

    Returns:
        The converted values as returned by ``values_numpy()`` of the mapped
        distribution (presumably one row per trace — callers index ``shape[0]``
        as the number of traces).
    """
    values = dist.map(func=trace_to_numpy).values_numpy()
    if file_name is not None:
        print('Saving to numpy: {}'.format(file_name))
        # Save the converted array itself rather than the mapped distribution
        # object, which numpy would otherwise have to coerce (or pickle).
        np.save(file_name, values)
    return values


def plot_trace(trace, file_name=None):
    """Render a single trace: calorimeter deposits as a 3D scatter plus a text summary.

    Args:
        trace: pyprob trace providing the named variables
            'calorimeter_n_deposits' (its distribution mean is viewed as a
            35 x 35 x 20 grid), 'final_state_particles',
            'mother_momentum_x/y/z' and 'channel_index'.
        file_name: optional output path prefix. When given, the figure is saved
            to ``file_name + '.png'``, the text summary to
            ``file_name + '.txt'``, and the figure is then closed.
    """
    named = trace.named_variables

    # detach/cpu make the conversion work for tensors that require grad or live on the GPU.
    data = named['calorimeter_n_deposits'].distribution.mean.view(35, 35, 20).detach().cpu().numpy()
    data_flat = data.reshape(-1)
    # Deposit counts are converted to GeV via min_energy_deposit for the summary and colour bar.
    data_flat_min = data_flat.min() * min_energy_deposit
    data_flat_max = data_flat.max() * min_energy_deposit
    data_flat_total = np.sum(data_flat) * min_energy_deposit

    particles = named['final_state_particles'].value
    # Keep only real entries (padding is presumably a large negative sentinel), 8 values per particle.
    particles = particles[particles > -99999].view(-1, 8)

    channel = int(named['channel_index'].value)
    trace_text = [
        str(trace),
        'p_x      : {}'.format(float(named['mother_momentum_x'].value)),
        'p_y      : {}'.format(float(named['mother_momentum_y'].value)),
        'p_z      : {}'.format(float(named['mother_momentum_z'].value)),
        'channel  : {} ({})'.format(channel, channel_names[channel]),
        'particles: {}'.format(particles.size(0)),
    ]
    for particle in particles:
        # The second-to-last value of each particle row is its (signed) particle id.
        pid = particle_name(int(particle[-2]))
        trace_text.append(' {} {}'.format(pid, str(list(particle.cpu().numpy()))))
    trace_text.append('calorimeter:')
    trace_text.append(' min   : {:.3f} GeV'.format(data_flat_min))
    trace_text.append(' max   : {:.3f} GeV'.format(data_flat_max))
    trace_text.append(' total : {:.3f} GeV'.format(data_flat_total))
    trace_text = '\n'.join(trace_text)

    fig = plt.figure(figsize=(15, 4))
    ax = fig.add_subplot(1, 1, 1, projection='3d')

    plt.rc('font', family='monospace')
    cmbr = mpl.colors.LinearSegmentedColormap.from_list('blue_to_red', ['b', 'r'])
    cnorm = mpl.colors.Normalize(vmin=data_flat_min, vmax=data_flat_max)
    # Stand-alone mappable that only drives the colour bar (the scatter uses explicit RGBA colours).
    cpick = mpl.cm.ScalarMappable(norm=cnorm, cmap=cmbr)
    cpick.set_array([])

    ax.set_xlabel('x')
    ax.set_ylabel('y')
    ax.set_zlabel('z')
    try:
        ax.set_aspect('equal')
    except NotImplementedError:
        # matplotlib 3.1-3.5 cannot set an equal aspect on 3D axes; keep the default.
        pass

    data_range = data_flat.max() - data_flat.min()
    if data_range > 0:
        ix, iy, iz = np.mgrid[-3:3:35j, -3:3:35j, 4:15:20j]
        normalized = np.clip((data_flat - data_flat.min()) / data_range, 0, 1)
        # Blue (low) to red (high), with opacity growing with the deposit.
        colors = [[x, 0, 1 - x, x] for x in normalized]
        ax.scatter(ix, iy, iz, s=100, c=colors, marker='o', edgecolor='none')
        # ax= is required by newer matplotlib for mappables not attached to an Axes.
        plt.colorbar(cpick, ax=ax, label='Deposited energy (GeV)', fraction=0.02)

    if file_name is not None:
        plt.savefig(file_name + '.png', bbox_inches='tight')
        with open(file_name + '.txt', 'w') as text_file:
            text_file.write(trace_text)
        plt.close(fig)


def plot_distribution(dist, ground_truth_trace=None, file_name=None):
    """Plot a summary dashboard for a distribution over tau-decay traces.

    Each trace is flattened with ``trace_to_numpy``; the slicing below assumes
    the row layout ``[px, py, pz, channel, 30 x 8 final-state particle values,
    35 x 35 x 20 calorimeter values]``. The dashboard holds histograms of the
    mother momentum components, the decay channel, the number of final-state
    particles, the EM/hadronic composition and the energies of the two most
    energetic final-state particles, plus 3D views of the mean simulated
    calorimeter, the calorimeter of the distribution's mode and, if given, the
    ground-truth calorimeter.

    Args:
        dist: pyprob ``Empirical`` distribution over traces. Nothing is drawn
            when it is empty.
        ground_truth_trace: optional trace whose values are overlaid (dashed
            lines / white markers) and whose calorimeter fills the last panel.
        file_name: optional output path prefix; the figure is saved to
            ``file_name + '.pdf'`` and then closed.
    """
    if dist.length <= 0:
        return

    # Layout of a flattened trace row (see trace_to_numpy).
    max_particles, num_features = 30, 8
    calo_shape = (35, 35, 20)
    particles_start = 4
    calo_start = particles_start + max_particles * num_features
    # Per-particle feature columns.
    feature_names = ['px', 'py', 'pz', 'E', 'theta', 'phi', 'pid', 'visible']
    energy_col, pid_col, visible_col = 3, 6, 7

    def filter_final_state(particles):
        """Drop padding entries and sort the remaining particles by energy, highest first."""
        # Padding is presumably a large negative sentinel (only values > -9999 are kept).
        real = particles[particles > -9999].reshape(-1, num_features)
        # The final reshape keeps a (0, num_features) shape when no particle survives.
        return np.array(sorted(real, key=lambda x: -x[energy_col])).reshape(-1, num_features)

    def nhad_nem_ninvis(pids, viz):
        """Return [n_em, n_had, n_calovis, n_invis] for one event's particle ids and visibility flags."""
        n_em = np.sum([1 if abs(p) in [22, 11] else 0 for p in pids])
        n_had = np.sum([1 if abs(p) > 100 else 0 for p in pids])
        n_calovis = np.sum(viz)
        n_invis = np.sum(1 - viz)
        assert n_em + n_had + n_invis == len(pids)
        return [n_em, n_had, n_calovis, n_invis]

    dist_mode_numpy = trace_to_numpy(dist.mode)
    dist_numpy = distribution_to_numpy(dist)
    num_traces = dist_numpy.shape[0]

    mother = dist_numpy[:, :3]
    channel = dist_numpy[:, 3]
    final = dist_numpy[:, particles_start:calo_start].reshape(num_traces, max_particles, num_features)
    finalfilt = [filter_final_state(f) for f in final]
    finalmult = np.array([f.shape[0] for f in finalfilt])
    obs = dist_numpy[:, calo_start:].reshape((num_traces,) + calo_shape)
    particle_types = np.array([nhad_nem_ninvis(f[:, pid_col], f[:, visible_col]) for f in finalfilt])

    has_ground_truth = ground_truth_trace is not None
    if has_ground_truth:
        ground = trace_to_numpy(ground_truth_trace)
        g_mother = ground[:3]
        g_channel = ground[3]
        g_finalfilt = filter_final_state(ground[particles_start:calo_start].reshape(max_particles, num_features))
        g_finalmult = g_finalfilt.shape[0]
        g_obs = ground[calo_start:].reshape(calo_shape)
        g_particle_types = nhad_nem_ninvis(g_finalfilt[:, pid_col], g_finalfilt[:, visible_col])

    # Colour, line style and z-range of trajectories, keyed by |pid|.
    particle_styles = {
        211: ('k', 'solid', [0, 15]),  # pi
        16: ('b', 'dashed', [0, 15]),  # tau lepton
        22: ('y', 'solid', [0, 9]),  # photon
        11: ('y', 'solid', [0, 9]),  # electron
    }

    def trajectory(momentum, zrange=(4, 10)):
        """Straight line from the origin along *momentum*, sampled so that z spans *zrange*."""
        t = np.linspace(zrange[0] / momentum['pz'], zrange[1] / momentum['pz'])
        return [momentum['px'] * t, momentum['py'] * t, momentum['pz'] * t]

    def plot_surf(ax, surfz=4):
        """Draw a translucent green square across the x/y extent at z = surfz."""
        color = mpl.colors.hex2color(mpl.colors.cnames['green']) + (0.1,)
        face = [[-3, -3, surfz], [-3, 3, surfz], [3, 3, surfz], [3, -3, surfz]]
        ax.add_collection3d(Poly3DCollection([face], facecolor=color))

    def extend_obs(observation):
        """Prepend 7 empty z-layers (same 11/19 spacing) so the grid starts in front of z = 4."""
        minz = 4 - 11. / (20. - 1.) * 7
        mgrid = np.mgrid[-3:3:35j, -3:3:35j, minz:15:27j]
        extended = np.zeros((35, 35, 27))
        extended[:, :, 7:] = observation
        return mgrid, extended

    def normalize(values):
        """Min-max scale to [0, 1]; a constant array maps to zeros instead of NaNs."""
        span = np.max(values) - np.min(values)
        if span == 0:
            return np.zeros_like(values)
        return (values - np.min(values)) / span

    def plot(ax, ix, iy, iz, observation, trajectories_trace=None):
        """Scatter calorimeter voxels on a 3D axis, optionally with particle trajectories."""
        ax.set_xlim(-3, 3)
        ax.set_ylim(-3, 3)
        ax.set_zlim(0, 15)

        if trajectories_trace is not None:
            mother_momentum = dict(zip(['px', 'py', 'pz'], trajectories_trace[:3]))
            particles = filter_final_state(
                trajectories_trace[particles_start:calo_start].reshape(max_particles, num_features))
            decay_particles = [dict(zip(feature_names, map(float, p))) for p in particles]
            ax.plot(*trajectory(mother_momentum, [4, 15]), c='r', linestyle='dotted')
            for p in decay_particles:
                # Unlisted species are drawn in red (default added by Gunes).
                c, style, zlim = particle_styles.get(abs(p['pid']), ('red', 'solid', [0, 9]))
                ax.plot(*trajectory(p, zlim), linewidth=2.5, c=c, linestyle=style)

        cutoff = 0.1
        # Voxels at or below the cutoff get marker size 0, i.e. are not visible.
        sizes = np.zeros(shape=observation.shape)
        sizes[observation > cutoff] = 100
        ax.scatter(ix, iy, iz, c=observation.ravel(), alpha=0.1, s=sizes)

        plot_surf(ax)

        ax.set_xlabel('x')
        ax.set_ylabel('y')
        ax.set_zlabel('z')

        ax.view_init(10., 20.)

    def plot_calorimeter(ax, observation, trajectories_trace):
        """Draw a 35x35x20 calorimeter grid, min-max normalised, on a 3D axis."""
        (ix, iy, iz), extended = extend_obs(observation)
        extended = normalize(extended)
        mask = extended > 0.1
        # Two extra zero-valued points at opposite corners; presumably they pin the
        # plotted extent and colour scale so the panels stay comparable.
        ixa = np.append(ix[mask], [3, -3.])
        iya = np.append(iy[mask], [3, -3.])
        iza = np.append(iz[mask], [0, 14])
        plot(ax, ixa, iya, iza, np.append(extended[mask], [0, 0]), trajectories_trace=trajectories_trace)

    fig = plt.figure()
    fig.set_size_inches(25, 10)

    shape = (4, 11)

    motheraxes = [plt.subplot2grid(shape, (0, col)) for col in (1, 2, 3)]

    emhadcomp = plt.subplot2grid(shape, (0, 4))
    nfinalstate = plt.subplot2grid(shape, (1, 4))
    channelax = plt.subplot2grid(shape, (1, 2), colspan=2)

    measimobsax = plt.subplot2grid(shape, (0, 5), rowspan=2, colspan=2, projection='3d')
    modsimobsax = plt.subplot2grid(shape, (0, 7), rowspan=2, colspan=2, projection='3d')
    obsax = plt.subplot2grid(shape, (0, 9), rowspan=2, colspan=2, projection='3d')

    # Lower-triangular grid: diagonal = marginal energy, off-diagonal = joint energy.
    finalstateaxes = {}
    for i in range(2):
        for j in range(i + 1):
            finalstateaxes.setdefault(i, {})[j] = plt.subplot2grid(shape, (i, j))

    # Number of final state particles
    ax = nfinalstate
    h = ax.hist(finalmult, bins=np.linspace(-0.5, 10.5, 12), density=True)
    ymax = np.max(h[0]) * 1.5
    if has_ground_truth:
        ax.vlines(g_finalmult, 0, ymax, linestyles='dashed')
    ax.set_ylim(0, ymax)
    ax.set_xlim(left=0)
    ax.set_xlabel('Number of Final State Particles')
    ax.set_title('Decay Products')
    ax.xaxis.set_major_locator(MultipleLocator(1))

    # EM vs hadronic composition
    ax = emhadcomp
    em_had = particle_types[:, :2]
    nx, ny = 10, 8
    # `density` replaces the removed `normed` keyword of hist2d / np.histogram2d.
    ax.hist2d(em_had[:, 0], em_had[:, 1],
              [np.linspace(-0.5, 0.5 + nx, nx + 2), np.linspace(-0.5, 0.5 + ny, ny + 2)],
              density=True, cmap='viridis')
    if has_ground_truth:
        ax.scatter(g_particle_types[0], g_particle_types[1], c='w', edgecolors='k')
    ax.set_xlabel('Number of EM Particles')
    ax.set_ylabel('Number of HAD Particles')
    ax.set_title('Event Composition')
    ax.xaxis.set_major_locator(MultipleLocator(1))
    ax.yaxis.set_major_locator(MultipleLocator(1))

    # Energies of the two most energetic final-state particles
    try:
        # Ragged (traces with < 2 particles) raises here, so it lives inside the try
        # together with the plots that depend on it.
        energies = np.array([f[:2] for f in finalfilt])
        for i in range(2):
            for j in range(i + 1):
                ax = finalstateaxes[i][j]
                if i > j:
                    ax.hist2d(energies[:, j, energy_col], energies[:, i, energy_col],
                              [np.linspace(0, 40, 11), np.linspace(0, 40, 11)], cmap='viridis')
                    if has_ground_truth:
                        ax.scatter(g_finalfilt[j, energy_col], g_finalfilt[i, energy_col], c='w', edgecolors='k')
                    ax.set_title('FSP Energy Joint')
                else:
                    h = ax.hist(energies[:, j, energy_col], bins=np.linspace(0, 40, 11), facecolor='grey', density=True)
                    if has_ground_truth:
                        ax.vlines(g_finalfilt[j, energy_col], 0, 1, linestyles='dashed')
                    ax.set_ylim(0, 1.5 * np.max(h[0]))
                    ax.set_xlim(0, 40)
                    ax.set_title('FSP Energy {}'.format(i + 1))
    except Exception as e:
        # Best effort: these panels are optional, keep drawing the rest of the figure.
        print('Error with plotting FSP: {}'.format(e))

    # Tau momentum
    colors = mpl.cm.inferno(np.linspace(0, 1, 5))[1:-1]
    titles = ['τ px', 'τ py', 'τ pz']
    limits = [[-3, 3], [-3, 3], [43, 47]]
    for i, (lim, ax, c, t) in enumerate(zip(limits, motheraxes, colors, titles)):
        n, _, _ = ax.hist(mother[:, i], density=True, bins=np.linspace(lim[0], lim[1], 11), facecolor=c)
        ymax = 1.5 * np.max(n)
        if has_ground_truth:
            ax.vlines(g_mother[i], 0, ymax, linestyles='dashed')
        ax.set_xlim(*lim)
        ax.set_title(t)
        ax.set_ylim(0, ymax)

    # Decay channel: one unit-width bin per known channel index.
    num_channels = len(channel_names)
    ax = channelax
    ax.hist(channel, np.linspace(-0.5, num_channels - 0.5, num_channels + 1), density=True, facecolor='grey')
    if has_ground_truth:
        ax.vlines(g_channel, 0, 1.0, linestyles='dashed')
    ax.set_ylim(0., 1.0)
    ax.set_xlim(-1, num_channels)
    ax.xaxis.set_minor_locator(MultipleLocator(1))
    ax.set_title('Decay Channel')

    # Calorimeter views
    if has_ground_truth:
        obsax.set_title('Observed Calorimeter')
        plot_calorimeter(obsax, g_obs, ground)

    measimobsax.set_title('Simulated Calorimeter (Mean)')
    plot_calorimeter(measimobsax, np.average(obs, 0), None)

    modsimobsax.set_title('Simulated Calorimeter (Mode)')
    plot_calorimeter(modsimobsax, dist_mode_numpy[calo_start:].reshape(calo_shape), dist_mode_numpy)

    plt.suptitle(dist.name, x=0.0, y=.99, horizontalalignment='left', verticalalignment='top')
    plt.tight_layout(rect=[0, 0.03, 1, 0.98])

    if file_name is not None:
        plt.savefig(file_name + '.pdf', bbox_inches='tight')
        plt.close(fig)


def create_path(path, directory=False):
    """Ensure the directory for *path* exists, creating it (and parents) if needed.

    Args:
        path: a file path whose parent directory should exist or, when
            ``directory`` is True, the directory path itself.
        directory: treat ``path`` as a directory rather than a file path.
    """
    dir_name = path if directory else os.path.dirname(path)
    # A bare file name has no directory component and lives in the current
    # working directory, which already exists (os.makedirs('') would raise).
    if not dir_name:
        return
    if not os.path.exists(dir_name):
        print('{} does not exist, creating'.format(dir_name))
        # exist_ok tolerates another process creating the directory in the meantime.
        os.makedirs(dir_name, exist_ok=True)


def main():
    try:
        parser = argparse.ArgumentParser(description='etalumis sherpa tau decay experiments', formatter_class=argparse.ArgumentDefaultsHelpFormatter)
        parser.add_argument('--cuda', help='Enable CUDA if available', action='store_true')
        parser.add_argument('--seed', help='Random number seed', default=None, type=int)
        parser.add_argument('--model_executable', help='Model executable to start as a background process', default='/code/sherpa_probprog/sherpa_tau_decay')
        parser.add_argument('--model_address', help='PPX protocol address', default=None, type=str)
        # parser.add_argument('--mode', help='', choices=['create_truths', 'prior', 'infer', 'plot', 'plot_diagnostics', 'plot_diagnostics_network', 'plot_diagnostics_addresses', 'save_numpy', 'save_dataset', 'train'], nargs='?', default='', type=str)
        parser.add_argument('--mode', '-m', help='Experiment mode', choices=['ground_truth', 'prior', 'posterior', 'plot', 'plot_log_prob', 'plot_autocorrelation', 'plot_gelman_rubin', 'plot_addresses', 'plot_traces', 'plot_graph', 'plot_network', 'combine', 'save_numpy', 'save_dataset', 'train'], nargs='?', default=None, type=str)
        parser.add_argument('--num_traces', '-n', help='Number of traces for various tasks', default=None, type=int)
        parser.add_argument('--num_traces_per_file', help='Number of traces per file for offline dataset generation', default=None, type=int)
        parser.add_argument('--distributed_backend', help='Backend to use with distributed training', choices=['mpi'], nargs='?', default=None, type=str)
        parser.add_argument('--input_file', '-i', help='Input file(s) for various tasks', default=[], action='append', type=str)
        parser.add_argument('--output_file', '-o', help='Output file for various tasks', default=None, type=str)
        parser.add_argument('--output_dir', help='Output directory for various tasks', default=None, type=str)
        parser.add_argument('--infer_engine', '-e', help='RMH: MCMC with Random-walk Metropolis Hastings, IS: Importance sampling, IC: Inference compilation', choices=['RMH', 'IS', 'IC'], nargs='?', type=str)
        parser.add_argument('--infer_init', help='Initialization of RMH inference (prior: start from a random trace sampled from prior, ground_truth: start from the ground truth trace)', choices=['prior', 'ground_truth'], nargs='?', default='ground_truth', type=str)
        parser.add_argument('--network_type', help='Type of inference neural network to train for IC inference', default=None, choices=['feedforward', 'lstm'], nargs='?', type=str)
        parser.add_argument('--optimizer', help='Type of optimizer to train NNs for IC inference', choices=['adam', 'sgd', 'larc_adam', 'larc_sgd'], nargs='?', default='adam', type=str)
        parser.add_argument('--LR_schedule_method', help='Type of learning rate schedule method to train NNs for IC inference', choices=['poly_2', 'poly_1', 'cosine', 'multi_steps', 'step'], nargs='?', default='poly_2', type=str)
        parser.add_argument('--learning_rate', help='IC training learning rate', default=0.0001, type=float)
        parser.add_argument('--min_index', help='Starting index of distributions for various tasks', default=None, type=int)
        parser.add_argument('--max_index', help='End index of distributions for various tasks', default=None, type=int)
        parser.add_argument('--prior_inflation', help='Use prior inflation', action='store_true')
        parser.add_argument('--channel', help='Generate a specific decay channel event in mode:ground_truth', default=None, type=int)
        parser.add_argument('--ground_truth', '-g', help='Ground truth file for various tasks', default=None, type=str)
        parser.add_argument('--address_dict', '-a', help='Address dictionary', default=None, type=str)
        parser.add_argument('--dataset_dir', help='Dataset directory for offline training', default=None, type=str)
        parser.add_argument('--dataset_valid_dir', help='Dataset directory for offline validation', default=None, type=str)
        parser.add_argument('--network_dir', help='Dictionary to keep inference neural network snapshots', default=None, type=str)
        parser.add_argument('--use_address_base', help='Use base addresses (ignore iteration counters) in diagnostics plots', action='store_true')
        parser.add_argument('--batch_size', help='Minibatch size for IC training', default=None, type=int)
        parser.add_argument('--valid_size', help='Validation size for IC training (online training only)', default=None, type=int)
        parser.add_argument('--valid_every', help='Number of traces between validations for IC training', default=None, type=int)
        parser.add_argument('--save_every', help='Interval for saving neural network to disk during IC training', default=None, type=int)
        parser.add_argument('--dataloader_offline_num_workers', help='Number of worker threads for data loading in offline training mode', default=0, type=int)
        parser.add_argument('--n_most_frequent', help='Number of most frequent trace types to use for graph construction (mode: plot_graph)', default=None, type=int)

        opt = parser.parse_args()

        if opt.mode is None:
            parser.print_help()
            quit()

        print('etalumis sherpa tau decay experiments\n')
        print('Mode:\n{}\n'.format(opt.mode))

        print('Arguments:\n{}\n'.format(' '.join(sys.argv[1:])))

        print('Config:')
        pprint.pprint(vars(opt), depth=2, width=50)

        if opt.model_address is None:
            opt.model_address = 'ipc://@sherpa_tau_decay_{}'.format(str(uuid.uuid4()))
        model_process = None

        pyprob.set_random_seed(opt.seed)  # Sets random seed from system time if seed is None
        pyprob.set_verbosity(2)
        pyprob.set_cuda(opt.cuda)
        if opt.address_dict is not None:
            print('Address dictionary in use: {}'.format(opt.address_dict))
            create_path(opt.address_dict)

        if (opt.mode == 'ground_truth') and (opt.output_file is not None):
            print('Starting model in the background: {} {}'.format(opt.model_executable, opt.model_address))
            model_process = subprocess.Popen('{} {} > /dev/null &'.format(opt.model_executable, opt.model_address), shell=True, preexec_fn=os.setsid)
            print('Started model process: {}'.format(model_process.pid))
            model = RemoteModel(opt.model_address, address_dict_file_name=opt.address_dict)
            print('Sampling ground truth trace from prior')
            if opt.channel is not None:
                print('Target channel given, enabling prior inflation')
                opt.prior_inflation = True
            found = False
            while not found:
                ground_truth_trace = next(model._trace_generator(prior_inflation=PriorInflation.ENABLED if opt.prior_inflation else PriorInflation.DISABLED, inference_engine=InferenceEngine.RANDOM_WALK_METROPOLIS_HASTINGS))
                ground_truth_channel = int(ground_truth_trace.named_variables['channel_index'].value)
                print('Sampled event with channel: {}'.format(ground_truth_channel))
                if (opt.channel is None):
                    found = True
                elif opt.channel == ground_truth_channel:
                    found = True
            print('Saving ground truth trace to: {}'.format(opt.output_file))
            create_path(opt.output_file)
            torch.save(ground_truth_trace, opt.output_file)
            print('Rendering ground truth to: {}(.png & .txt)'.format(opt.output_file))
            plot_trace(ground_truth_trace, file_name=opt.output_file)

        elif (opt.mode == 'prior') and (opt.output_file is not None) and (opt.num_traces is not None):
            print('Starting model in the background: {} {}'.format(opt.model_executable, opt.model_address))
            model_process = subprocess.Popen('{} {} > /dev/null &'.format(opt.model_executable, opt.model_address), shell=True, preexec_fn=os.setsid)
            print('Started model process: {}'.format(model_process.pid))
            model = RemoteModel(opt.model_address, address_dict_file_name=opt.address_dict)
            print('Saving prior distribution with {} traces to: {}'.format(opt.num_traces, opt.output_file))
            create_path(opt.output_file)
            prior_dist = model.prior_traces(num_traces=opt.num_traces, prior_inflation=PriorInflation.ENABLED if opt.prior_inflation else PriorInflation.DISABLED, file_name=opt.output_file)

        elif (opt.mode == 'posterior') and (opt.output_file is not None) and (opt.ground_truth is not None) and (opt.num_traces is not None) and (opt.infer_engine is not None):
            print('Starting model in the background: {} {}'.format(opt.model_executable, opt.model_address))
            model_process = subprocess.Popen('{} {} > /dev/null &'.format(opt.model_executable, opt.model_address), shell=True, preexec_fn=os.setsid)
            print('Started model process: {}'.format(model_process.pid))
            model = RemoteModel(opt.model_address, address_dict_file_name=opt.address_dict)
            print('Loading ground truth file: {}'.format(opt.ground_truth))
            ground_truth_trace = torch.load(opt.ground_truth)
            observation = ground_truth_trace.named_variables['calorimeter_n_deposits'].distribution.mean
            initial_trace = None
            if opt.infer_engine == 'RMH':
                inference_engine = InferenceEngine.RANDOM_WALK_METROPOLIS_HASTINGS
                if os.path.exists(opt.output_file):
                    print('Found distribution from previous RMH run, resuming MCMC chain: {}'.format(opt.output_file))
                    dist = pyprob.distributions.Empirical(file_name=opt.output_file)
                    dist.finalize()
                    print('Distribution length: {}'.format(dist.length))
                    initial_trace = dist[-1]
                    dist.close()
                else:
                    if opt.infer_init == 'prior':
                        print('Initializing RMH chain with prior.')
                        initial_trace = None
                    elif opt.infer_init == 'ground_truth':
                        print('Initializing RMH chain with ground truth.')
                        initial_trace = ground_truth_trace
            elif opt.infer_engine == 'IS':
                inference_engine = InferenceEngine.IMPORTANCE_SAMPLING
            elif opt.infer_engine == 'IC':
                if opt.network_dir is None:
                    print('network_dir needs to be given for running IC.')
                    quit()
                inference_engine = InferenceEngine.IMPORTANCE_SAMPLING_WITH_INFERENCE_NETWORK
                files = sorted(glob(os.path.join(opt.network_dir, 'sherpa_tau_decay*.network')))
                if len(files) > 0:
                    inference_network_file_name = files[-1]
                    print('Using latest inference network: {}'.format(inference_network_file_name))
                    model.load_inference_network(inference_network_file_name)
                else:
                    print('Cannot find an inference network in folder: ' + opt.network_dir)
                    quit()
            else:
                raise ValueError('Unknown infer_engine: {}'.format(opt.infer_engine))
            print('Saving posterior distribution with {} traces and inference engine {} to: {}'.format(opt.num_traces, opt.infer_engine, opt.output_file))
            create_path(opt.output_file)
            sposterior_dist = model.posterior_traces(num_traces=opt.num_traces, inference_engine=inference_engine, observe={'calorimeter_n_deposits': observation}, initial_trace=initial_trace, file_name=opt.output_file)

        elif (opt.mode == 'plot') and (opt.output_file is not None) and (len(opt.input_file) == 1) and (opt.num_traces is not None):
            ground_truth_trace = None
            if opt.ground_truth is not None:
                print('Loading ground truth file: {}'.format(opt.ground_truth))
                ground_truth_trace = torch.load(opt.ground_truth)

            print('Opening distribution: {}'.format(opt.input_file[0]))
            dist = pyprob.distributions.Empirical(file_name=opt.input_file[0])
            dist.finalize()
            create_path(opt.output_file)
            if dist._uniform_weights:
                dist = dist.thin(opt.num_traces, min_index=opt.min_index, max_index=opt.max_index)
                if opt.channel is not None:
                    print('Filtering channel: {}'.format(opt.channel))
                    dist = dist.filter(lambda trace: int(trace.named_variables['channel_index'].value) == opt.channel)
                print('Plotting distribution to: {}'.format(opt.output_file))
                plot_distribution(dist, ground_truth_trace=ground_truth_trace, file_name=opt.output_file)
            else:
                dist_posterior = dist.resample(opt.num_traces)
                if opt.channel is not None:
                    print('Filtering channel: {}'.format(opt.channel))
                    dist_posterior = dist_posterior.filter(lambda trace: int(trace.named_variables['channel_index'].value) == opt.channel)
                dist_posterior_file_name = opt.output_file
                print('Plotting distribution to: {}'.format(opt.output_file))
                plot_distribution(dist_posterior, ground_truth_trace=ground_truth_trace, file_name=dist_posterior_file_name)

                dist_proposal = dist.thin(opt.num_traces).unweighted().rename(dist.name.replace('Posterior', 'Proposal'))
                if opt.channel is not None:
                    print('Filtering channel: {}'.format(opt.channel))
                    dist_proposal = dist_proposal.filter(lambda trace: int(trace.named_variables['channel_index'].value) == opt.channel)
                dist_proposal_file_name = opt.output_file + '_proposal'
                print('Plotting proposal distribution to: {}'.format(dist_proposal_file_name))
                plot_distribution(dist_proposal, ground_truth_trace=ground_truth_trace, file_name=dist_proposal_file_name)
            dist.close()

        elif (opt.mode == 'plot_traces') and (len(opt.input_file) == 1) and (opt.output_file is not None) and (opt.num_traces is not None):
            dist = pyprob.distributions.Empirical(file_name=opt.input_file[0])
            dist.finalize()
            dist = dist.thin(opt.num_traces, min_index=opt.min_index, max_index=opt.max_index)
            create_path(opt.output_file)
            pyprob.diagnostics.trace_histograms(dist, plot=True, plot_show=False, file_name=opt.output_file, use_address_base=opt.use_address_base)
            dist.close()

        elif (opt.mode == 'plot_addresses') and (len(opt.input_file) > 0) and (opt.output_file is not None) and (opt.num_traces is not None):
            dists = []
            for file in opt.input_file:
                print('Opening distribution: {}'.format(file))
                dist = pyprob.distributions.Empirical(file_name=file)
                dist.finalize()
                if dist._uniform_weights:
                    dist = dist.thin(opt.num_traces, min_index=opt.min_index, max_index=opt.max_index)
                    dists.append(dist)
                else:
                    dist_posterior = dist.resample(opt.num_traces)
                    dists.append(dist_posterior)
                    dist_proposal = dist.thin(opt.num_traces).unweighted().rename(dist.name.replace('Posterior', 'Proposal'))
                    dists.append(dist_proposal)

            ground_truth_trace = None
            if opt.ground_truth is not None:
                print('Loading ground truth file: {}'.format(opt.ground_truth))
                ground_truth_trace = torch.load(opt.ground_truth)
            create_path(opt.output_file)
            pyprob.diagnostics.address_histograms(dists, plot=True, plot_show=False, file_name=opt.output_file, ground_truth_trace=ground_truth_trace, use_address_base=opt.use_address_base)

        elif (opt.mode == 'plot_log_prob') and (len(opt.input_file) > 0) and (opt.output_file is not None):
            dists = []
            for file in opt.input_file:
                print('Opening distribution: {}'.format(file))
                dist = pyprob.distributions.Empirical(file_name=file)
                dist.finalize()
                dists.append(dist)

            create_path(opt.output_file)
            log_prob_plot_file_name = opt.output_file + '.pdf'
            print('Saving log_prob plot to: {}'.format(log_prob_plot_file_name))
            iter_numpy, log_prob_numpy = pyprob.diagnostics.log_prob(dists, plot=True, plot_show=False, file_name=log_prob_plot_file_name, min_index=opt.min_index, max_index=opt.max_index, resolution=1000 if opt.num_traces is None else opt.num_traces)
            log_prob_iter_file_name = opt.output_file + '_log_prob_iter.npy'
            print('Saving log_prob iterations to Numpy: {}'.format(log_prob_iter_file_name))
            np.save(log_prob_iter_file_name, iter_numpy)
            log_prob_file_name = opt.output_file + '_log_prob.npy'
            print('Saving log_prob to Numpy: {}'.format(log_prob_file_name))
            np.save(log_prob_file_name, log_prob_numpy)

            for dist in dists:
                dist.close()

        elif (opt.mode == 'plot_gelman_rubin') and (len(opt.input_file) > 1) and (opt.output_file is not None):
            dists = []
            for file in opt.input_file:
                print('Opening distribution: {}'.format(file))
                dist = pyprob.distributions.Empirical(file_name=file)
                dist.finalize()
                dists.append(dist)

            create_path(opt.output_file)
            gr_plot_file_name = opt.output_file + '.pdf'
            print('Saving Gelman-Rubin diagnostic plot to: {}'.format(gr_plot_file_name))
            iter_numpy, gr_numpy = pyprob.diagnostics.gelman_rubin(dists, n_most_frequent=50, plot=True, plot_show=False, file_name=gr_plot_file_name)
            gr_iter_file_name = opt.output_file + '_iter.npy'
            print('Saving Gelman-Rubin diagnostic iterations to Numpy: {}'.format(gr_iter_file_name))
            np.save(gr_iter_file_name, iter_numpy)
            gr_file_name = opt.output_file + '.npy'
            print('Saving Gelman-Rubin diagnostic to Numpy: {}'.format(gr_file_name))
            np.save(gr_file_name, gr_numpy)

            for dist in dists:
                dist.close()

        elif (opt.mode == 'plot_autocorrelation') and (len(opt.input_file) == 1) and (opt.output_file is not None):
            print('Opening distribution: {}'.format(opt.input_file[0]))
            dist = pyprob.distributions.Empirical(file_name=opt.input_file[0])
            dist.finalize()
            if dist.length == 0:
                print('Distribution is empty.')
            else:
                if '.pdf' not in opt.output_file:
                    opt.output_file += '.pdf'
                create_path(opt.output_file)
                print('Saving autocorrelation plot to: {}'.format(opt.output_file))
                lags_numpy, autocorrelations_dict = pyprob.diagnostics.autocorrelations(dist, names=['channel_index', 'mother_momentum_x', 'mother_momentum_y', 'mother_momentum_z'], n_most_frequent=50, plot=True, plot_show=False, file_name=opt.output_file)
                np.save(opt.output_file + '_lags.npy', lags_numpy)
                torch.save(autocorrelations_dict, opt.output_file + '_autocorrvalues')
            dist.close()

        elif (opt.mode == 'plot_network') and (len(opt.input_file) == 1) and (opt.output_dir is not None):
            print('Loading inference network: {}'.format(opt.input_file[0]))
            inference_network = pyprob.nn.InferenceNetworkFeedForward._load(opt.input_file[0])
            print('Saving inference network diagnostics to folder: {}'.format(opt.output_dir))
            create_path(opt.output_dir, True)
            pyprob.diagnostics.network(inference_network, opt.output_dir)

        elif (opt.mode == 'plot_graph') and (len(opt.input_file) == 1 or len(opt.input_file) == 2) and (opt.output_dir is not None) and (opt.num_traces is not None):
            dists = []
            for file in opt.input_file:
                print('Opening distribution: {}'.format(file))
                dist = pyprob.distributions.Empirical(file_name=file)
                dist.finalize()
                if dist._uniform_weights:
                    dist = dist.thin(opt.num_traces, min_index=opt.min_index, max_index=opt.max_index)
                else:
                    dist = dist.resample(opt.num_traces)
                dists.append(dist)
            create_path(opt.output_dir, True)
            dist = dists[0]
            base_graph = None
            if len(dists) == 2:
                print('Two distributions give, using the first as the base (background) graph')
                base_graph = pyprob.diagnostics.graph(dists[0], n_most_frequent=opt.n_most_frequent, use_address_base=opt.use_address_base)
                dist = dists[1]
            file_name = os.path.join(opt.output_dir, 'graph')
            print('Producing graphs: {}'.format(file_name))
            pyprob.diagnostics.graph(dist, file_name=file_name, n_most_frequent=opt.n_most_frequent, use_address_base=opt.use_address_base, base_graph=base_graph)

        elif (opt.mode == 'combine') and (len(opt.input_file) > 1) and (opt.output_file is not None):
            dists = []
            for file in opt.input_file:
                print('Opening distribution: {}'.format(file))
                dist = pyprob.distributions.Empirical(file_name=file)
                dist.finalize()
                dists.append(dist)
            create_path(opt.output_file)
            print('Writing combined distribution to: {}'.format(opt.output_file))
            pyprob.distributions.Empirical.combine(dists, file_name=opt.output_file)
            for dist in dists:
                dist.close()

        elif (opt.mode == 'save_numpy') and (len(opt.input_file) == 1) and (opt.output_file is not None) and (opt.num_traces is not None):
            print('Opening distribution: {}'.format(opt.input_file[0]))
            dist = pyprob.distributions.Empirical(file_name=opt.input_file[0])
            dist.finalize()
            create_path(opt.output_file)
            if dist._uniform_weights:
                dist = dist.thin(opt.num_traces, min_index=opt.min_index, max_index=opt.max_index)
                distribution_to_numpy(dist, opt.output_file)
            else:
                dist_posterior = dist.resample(opt.num_traces)
                distribution_to_numpy(dist_posterior, opt.output_file + '_posterior')
                dist_proposal = dist.thin(opt.num_traces).unweighted().rename(dist.name.replace('Posterior', 'Proposal'))
                distribution_to_numpy(dist_proposal, opt.output_file + '_proposal')
            dist.close()

        elif (opt.mode == 'save_dataset') and (opt.dataset_dir is not None) and (opt.num_traces is not None) and (opt.num_traces_per_file is not None):
            print('Starting model in the background: {} {}'.format(opt.model_executable, opt.model_address))
            model_process = subprocess.Popen('{} {} > /dev/null &'.format(opt.model_executable, opt.model_address), shell=True, preexec_fn=os.setsid)
            print('Started model process: {}'.format(model_process.pid))
            model = RemoteModel(opt.model_address, address_dict_file_name=opt.address_dict)
            create_path(opt.dataset_dir, True)
            print('Saving offline dataset to: {}'.format(opt.dataset_dir))
            model.save_dataset(dataset_dir=opt.dataset_dir, num_traces=opt.num_traces, num_traces_per_file=opt.num_traces_per_file, prior_inflation=PriorInflation.ENABLED if opt.prior_inflation else PriorInflation.DISABLED)
            print('Offline dataset saved to {}'.format(opt.dataset_dir))

        elif (opt.mode == 'train') and (opt.num_traces is not None) and (opt.network_dir is not None) and (opt.batch_size is not None) and (opt.save_every is not None) and (opt.network_type is not None):
            model = RemoteModel(opt.model_address, address_dict_file_name=opt.address_dict)
            if opt.dataset_dir is None:
                print('Online training')
                print('Starting model in the background: {} {}'.format(opt.model_executable, opt.model_address))
                model_process = subprocess.Popen('{} {} > /dev/null &'.format(opt.model_executable, opt.model_address), shell=True, preexec_fn=os.setsid)
                print('Started model process: {}'.format(model_process.pid))
            else:
                print('Offline training with dataset at: {}'.format(opt.dataset_dir))

            create_path(opt.network_dir, True)
            inference_network_file_name = None
            files = sorted(glob(os.path.join(opt.network_dir, 'sherpa_tau_decay*.network')))
            if len(files) > 0:
                inference_network_file_name = files[-1]
                print('Resuming to train latest inference network: {}'.format(inference_network_file_name))
                model.load_inference_network(inference_network_file_name)
            else:
                print('Creating new inference network, none found in: {}'.format(opt.network_dir))
            save_file_name_prefix = os.path.join(opt.network_dir, 'sherpa_tau_decay')
            if opt.network_type == 'feedforward':
                inference_network = InferenceNetwork.FEEDFORWARD
            else:
                inference_network = InferenceNetwork.LSTM
            if opt.optimizer == 'adam':
                optimizer_type = pyprob.Optimizer.ADAM
            elif opt.optimizer == 'sgd':
                optimizer_type = pyprob.Optimizer.SGD
            elif opt.optimizer == 'larc_adam':#Lei
                optimizer_type = pyprob.Optimizer.LARC_ADAM
            elif opt.optimizer == 'larc_sgd':#Lei
                optimizer_type = pyprob.Optimizer.LARC_SGD
            else:
                print('Unknown optimizer: {}'.format(opt.optimizer))
                quit()
            if opt.LR_schedule_method== 'step':
                LR_schedule_method = pyprob.LRScheduler.STEP
            elif opt.LR_schedule_method== 'multi_steps':
                LR_schedule_method = pyprob.LRScheduler.MULTI_STEPS
            elif opt.LR_schedule_method== 'poly_2':
                LR_schedule_method = pyprob.LRScheduler.POLY_2
            elif opt.LR_schedule_method== 'poly_1':
                LR_schedule_method = pyprob.LRScheduler.POLY_1
            elif opt.LR_schedule_method== 'cosine':
                LR_schedule_method = pyprob.LRScheduler.COSINEANNEALING


            model.learn_inference_network(num_traces=opt.num_traces, observe_embeddings={'calorimeter_n_deposits': {'reshape': [1, 35, 35, 20], 'embedding': ObserveEmbedding.CNN3D5C}}, batch_size=opt.batch_size, valid_size=opt.valid_size, valid_every=opt.valid_every, save_file_name_prefix=save_file_name_prefix, save_every_sec=opt.save_every, prior_inflation=PriorInflation.ENABLED if opt.prior_inflation else PriorInflation.DISABLED, dataset_dir=opt.dataset_dir, dataset_valid_dir=opt.dataset_valid_dir, distributed_backend=opt.distributed_backend, inference_network=inference_network, dataloader_offline_num_workers=opt.dataloader_offline_num_workers, optimizer_type=optimizer_type, LR_schedule_method=LR_schedule_method, learning_rate=opt.learning_rate)

        else:
            parser.print_help()
            quit()
        print()

    except KeyboardInterrupt:
        print('Stopped.')
    except Exception:
        traceback.print_exc(file=sys.stdout)

    if model_process is not None:
        print('Done, killing model process: {}'.format(model_process.pid))
        os.killpg(os.getpgid(model_process.pid), signal.SIGTERM)



if __name__ == "__main__":
    # Time the whole run with a monotonic clock so the reported duration is
    # not skewed by wall-clock adjustments (NTP sync, DST, manual changes).
    time_start = time.monotonic()
    try:
        main()
    finally:
        # Report the duration on every exit path. This includes main() leaving
        # via SystemExit (it calls quit() on usage/error paths) or any other
        # exception that escapes it. Afterwards the exception, if any,
        # propagates unchanged.
        print('\nTotal duration: {}'.format(days_hours_mins_secs_str(time.monotonic() - time_start)))
    sys.exit(0)
