import traceback
import argparse
import sys
import os
import subprocess
import pprint
from glob import glob
import gc
import uuid
import shutil
from functools import reduce
import torch

import pyprob
from pyprob import ModelRemote, InferenceEngine, PriorInflation
from pyprob.distributions import Distribution, Empirical

import numpy as np
import matplotlib.pyplot as plt
plt.switch_backend('agg')
import matplotlib as mpl
from matplotlib.ticker import MultipleLocator
from matplotlib import cm
colors = [cm.inferno(x) for x in np.linspace(0, 1, 5)]
from mpl_toolkits.mplot3d import Axes3D
from matplotlib import cm
colors = [cm.inferno(x) for x in np.linspace(0, 1, 5)]


# Scale factor applied to the observed calorimeter values: plot_trace multiplies
# the min/max/total of the observed mean by it before reporting them in GeV, and
# main multiplies the ground-truth observation by it before running inference.
# NOTE(review): 'sharpa' in the file name below is presumably a typo for
# 'sherpa' (the default model executable is 'sherpa_tau_decay') — confirm.
min_energy_deposit = 0.05  # caloutils::minEnergyDeposit in sharpa_tau_decay.cpp


def plot_distribution(dist, plot_samples=250000, title='', ground_truths=None, file_name=None):
    """Plot histograms of the decay channel and the three momentum components.

    Every value of ``dist`` is indexed as ``x[0]`` (p_x), ``x[1]`` (p_y),
    ``x[2]`` (p_z) and ``x[3]`` (channel). One panel is drawn per quantity, in
    the order channel, p_x, p_y, p_z, each with a solid line at the mean and a
    text line giving mean and standard deviation.

    Args:
        dist: distribution exposing ``map``; the mapped distributions are used
            through ``sample()``, ``mean``, ``stddev``, ``values`` and
            ``weights``.
        plot_samples: number of samples drawn per quantity to fill the
            histograms.
        title: figure-level title.
        ground_truths: optional sequence ordered ``[channel, p_x, p_y, p_z]``;
            when given, each panel gets a dashed line and a label at that value.
        file_name: optional output path. When given, the figure is saved there
            and then closed; when ``None`` the figure is left open as the
            current pyplot figure.
    """
    dist_px = dist.map(lambda x: float(x[0]))
    dist_py = dist.map(lambda x: float(x[1]))
    dist_pz = dist.map(lambda x: float(x[2]))
    dist_channel = dist.map(lambda x: int(x[3]))
    # Sampling order (p_x, p_y, p_z, channel) is kept as-is so the random
    # stream is consumed in the same sequence as before.
    dist_px_samples = [dist_px.sample() for _ in range(plot_samples)]
    dist_py_samples = [dist_py.sample() for _ in range(plot_samples)]
    dist_pz_samples = [dist_pz.sample() for _ in range(plot_samples)]
    dist_channel_samples = [dist_channel.sample() for _ in range(plot_samples)]

    # Merge duplicate channel values so each channel is listed once with its total weight.
    dist_channel_combined = pyprob.distributions.Empirical(dist_channel.values, weights=dist_channel.weights, combine_duplicates=True)
    channel_weights = [float(w) for w in dist_channel_combined.weights]
    dist_channel_str = 'channels: ' + ', '.join(
        '{}: {:.3f}'.format(v, w) for v, w in zip(dist_channel_combined.values, channel_weights))

    fig, (ax1, ax2, ax3, ax4) = plt.subplots(1, 4, figsize=(20, 4))
    fig.suptitle(title)

    # One row per panel, listed in the same order as ``ground_truths`` and the
    # colour palette:
    # (axes, title, distribution, samples, bins, x limits, y limits,
    #  text x, y of the statistics text, y of the ground-truth text)
    panels = [
        (ax1, 'Channel', dist_channel, dist_channel_samples, np.arange(-0.5, 38+0.5, 1), [-1, 39], [0, 1], -1, -0.3, -0.2),
        (ax2, 'p_x', dist_px, dist_px_samples, np.arange(-2.5, 2.5+0.1, 0.1), [-2.5, 2.5], [0, 10], -2.5, -3, -2),
        (ax3, 'p_y', dist_py, dist_py_samples, np.arange(-2.5, 2.5+0.1, 0.1), [-2.5, 2.5], [0, 10], -2.5, -3, -2),
        (ax4, 'p_z', dist_pz, dist_pz_samples, np.arange(43, 47+0.08, 0.08), [43, 47], [0, 12.5], 43, -3.6, -2.4),
    ]

    for index, (ax, panel_title, panel_dist, samples, bins, xlim, ylim, text_x, stats_y, truth_y) in enumerate(panels):
        mean = float(panel_dist.mean)
        ax.title.set_text(panel_title)
        # Annotations use data coordinates below the x axis (negative y).
        ax.text(text_x, stats_y, 'mean={:.3f} (solid), stddev={:.3f}'.format(mean, float(panel_dist.stddev)))
        ax.hist(samples, bins=bins, density=1, alpha=0.75, color=colors[index])
        ax.set_xlim(xlim)
        ax.set_ylim(ylim)
        ax.axvline(mean, color='gray', linestyle='solid', linewidth=1)
        if ground_truths is not None:
            truth = float(ground_truths[index])
            ax.axvline(truth, color='gray', linestyle='dashed', linewidth=2)
            ax.text(text_x, truth_y, 'ground_truth={:.3f} (dashed)'.format(truth))

    # Channel-only extras: per-channel weights and one minor tick per channel index.
    ax1.text(-1, -0.4, dist_channel_str)
    ax1.xaxis.set_minor_locator(MultipleLocator(1))

    if file_name is not None:
        try:
            fig.savefig(file_name, bbox_inches='tight')
        finally:
            # pyplot keeps a reference to every figure until it is closed;
            # this function is called many times per run, so release it here.
            plt.close(fig)


def plot_trace(trace, file_name=None):
    """Plot a textual summary of a trace next to a 3D view of its observation.

    The left panel lists the trace's string form, the latents read from
    ``trace.result[0:4]`` (p_x, p_y, p_z, channel), the particle rows found in
    ``trace.result[4:]`` and the min/max/total of the observation scaled by
    ``min_energy_deposit``. The right panel is a 3D scatter of the observation,
    taken from the mean of the first observed sample's distribution viewed as a
    35 x 35 x 20 grid.

    Args:
        trace: trace object exposing ``samples_observed`` and ``result``.
        file_name: optional output path. When given, the figure is saved there
            and then closed; when ``None`` the figure is left open as the
            current pyplot figure.
    """
    data = trace.samples_observed[0].distribution.mean.view(35, 35, 20).data.numpy()
    data_flat = data.reshape(-1)
    # ndarray reductions instead of the Python built-ins, which iterate the
    # array element by element.
    raw_min = data_flat.min()
    raw_max = data_flat.max()
    data_flat_min = raw_min * min_energy_deposit
    data_flat_max = raw_max * min_energy_deposit
    data_flat_total = np.sum(data_flat) * min_energy_deposit

    # Entries at or below -9999 are dropped; the rest is viewed as rows of 8 values.
    particles = trace.result[4:]
    particles = particles[particles > -9999].view(-1, 8)

    trace_text = '\n'.join([
        str(trace),
        '\nlatents',
        'p_x      : {}'.format(float(trace.result[0])),
        'p_y      : {}'.format(float(trace.result[1])),
        'p_z      : {}'.format(float(trace.result[2])),
        'channel  : {}'.format(int(trace.result[3])),
        'particles: {}'.format(particles.size(0)),
        str(particles),
        'calorimeter',
        'min   : {:.3f} GeV'.format(data_flat_min),
        'max   : {:.3f} GeV'.format(data_flat_max),
        'total : {:.3f} GeV'.format(data_flat_total),
    ])

    fig = plt.figure(figsize=(15, 4))
    ax1 = fig.add_subplot(1, 2, 1)
    ax2 = fig.add_subplot(1, 2, 2, projection='3d')

    ax1.text(0, 0, trace_text)
    ax1.axis('off')

    # Stand-alone mappable that only drives the colour bar (blue = min, red = max).
    cmbr = mpl.colors.LinearSegmentedColormap.from_list('blue_to_red', ['b', 'r'])
    cnorm = mpl.colors.Normalize(vmin=data_flat_min, vmax=data_flat_max)
    cpick = mpl.cm.ScalarMappable(norm=cnorm, cmap=cmbr)
    cpick.set_array([])

    ax2.set_xlabel('x')
    ax2.set_ylabel('y')
    ax2.set_zlabel('z')
    # NOTE(review): some Matplotlib releases reject 'equal' on 3D axes —
    # confirm against the Matplotlib version this project pins.
    ax2.set_aspect('equal')

    value_range = raw_max - raw_min
    if value_range > 0:
        ix, iy, iz = np.mgrid[-3:3:35j, -3:3:35j, 4:15:20j]
        normalized = np.clip((data_flat - raw_min) / value_range, 0, 1)
        # One RGBA row per cell: red and opacity grow with the value, blue shrinks.
        point_colors = np.column_stack((normalized, np.zeros_like(normalized), 1 - normalized, normalized))
        ax2.scatter(ix, iy, iz, s=100, c=point_colors, marker='o', edgecolor='none')

        # The mappable is not attached to any axes, so name the axes to take
        # space from; ax2 is the axes pyplot picked implicitly before.
        plt.colorbar(cpick, ax=ax2, label='Deposited energy (GeV)', fraction=0.02)

    if file_name is not None:
        try:
            fig.savefig(file_name, bbox_inches='tight')
        finally:
            # pyplot keeps a reference to every figure until it is closed.
            plt.close(fig)


def _ensure_folder(folder):
    """Create ``folder`` (with parents) when it is missing, announcing it on stdout."""
    if not os.path.exists(folder):
        print('{} does not exist, creating...'.format(folder))
        # exist_ok covers the folder appearing between the check and the call.
        os.makedirs(folder, exist_ok=True)


def main():
    """Command-line entry point for the Sherpa tau decay experiments.

    Parses the command line, starts the model executable as a background
    process, copies this script into the experiment directory and then runs
    the selected ``--mode``:

    * ``save_traces``: save a trace cache to ``<dir>/trace_cache``.
    * ``train``: train the inference network into ``<dir>/inference_networks``.
    * ``create_experiments``: generate, plot and save ground-truth traces under
      ``<dir>/experiments``.
    * ``infer``: plot the (inflated) prior and, for every ground truth, run each
      requested inference engine and plot/save the posterior.
    * ``infer_combine``: combine previously saved posteriors per ground truth
      and inference engine.

    The function always ends by raising ``SystemExit``: status 0 on success or
    keyboard interrupt, status 1 on invalid arguments, missing inputs or an
    unhandled exception (whose traceback is printed to stdout).
    """
    exit_code = 0
    try:
        parser = argparse.ArgumentParser(description='Sherpa tau decay experiments', formatter_class=argparse.ArgumentDefaultsHelpFormatter)
        parser.add_argument('--dir', help='experiment root directory', default='', type=str)
        parser.add_argument('--cuda', help='enable CUDA if available', action='store_true')
        parser.add_argument('--model_executable', help='model executable to start as a background process', default='/code/sherpa_probprog/sherpa_tau_decay')
        parser.add_argument('--mode', help='experiment mode (save_traces: save traces to disk for later training, train: train inference neural network, create_experiments: create ground truths for experiments, infer: perform inference, infer_combine: combine existing inference results into a single posterior per inference engine)', choices=['save_traces', 'train', 'create_experiments', 'infer', 'infer_combine'], nargs='?', default='', type=str)
        parser.add_argument('--save_traces_files', help='number of files to save (-1: unlimited)', default=1024, type=int)
        parser.add_argument('--save_traces_per_file', help='number of traces per file', default=256, type=int)
        parser.add_argument('--train_traces', help='number of traces to use for training', default=150000000, type=int)
        parser.add_argument('--train_batch_size', help='size of the minibatches for training', default=32, type=int)
        parser.add_argument('--train_valid_size', help='size of the validation set for training', default=256, type=int)
        parser.add_argument('--train_valid_interval', help='number of traces between validations (and neural network saving)', default=20000, type=int)
        parser.add_argument('--train_use_cache', help='use existing trace cache from disk (offline training) instead of trace generation from model (online training)', action='store_true')
        parser.add_argument('--ground_truths', help='number of ground truth traces to generate', default=20, type=int)
        parser.add_argument('--infer_burn_in', help='list of burn in values (traces to discard in the beginning) for Metropolis Hastings (-1: default, 10%% of infer_traces)', default=[-1, -1, -1, -1], type=int, nargs='+')
        parser.add_argument('--infer_engine', help='list of inference engines for each inference', default=['LMH', 'RMH', 'IS', 'IC'], choices=['LMH', 'RMH', 'IS', 'IC'], type=str, nargs='+')
        parser.add_argument('--infer_traces', help='list of number of traces for each inference', default=[1000, 1000, 1000, 1000], type=int, nargs='+')
        parser.add_argument('--infer_report_traces', help='number of traces to plot in the inference network report', default=10, type=int)
        parser.add_argument('--prior_traces', help='number of traces to use for plotting the prior', default=10, type=int)
        parser.add_argument('--plot_samples', help='number of samples from the distributions to use for histogram plotting purposes', default=250000, type=int)
        parser.add_argument('--save_distributions', help='save resulting distributions (in addition to the distribution plots) to file for later analysis and use. Beware, potentially very large file sizes up to several GB per distribution.', action='store_true')
        parser.add_argument('--prior_inflation', help='use prior inflation', action='store_true')

        opt = parser.parse_args()

        # sys.exit raises SystemExit, which neither handler below catches, so
        # these early exits leave with the given status directly.
        if opt.dir == '' or opt.mode == '':
            parser.print_help()
            sys.exit(1)

        if not (len(opt.infer_traces) == len(opt.infer_burn_in) == len(opt.infer_engine)):
            print('infer_traces, infer_burn_in, and infer_engine should have the same number of elements.')
            parser.print_help()
            sys.exit(1)

        # Resolve the -1 placeholders to 10% of the matching trace count. A new
        # list is built so argparse's default list object is not modified.
        opt.infer_burn_in = [traces // 10 if burn_in == -1 else burn_in
                             for burn_in, traces in zip(opt.infer_burn_in, opt.infer_traces)]

        print('Mode: ' + opt.mode)

        pprint.pprint(vars(opt), depth=2, width=50)

        # A fresh address per run, passed to both the model process and ModelRemote.
        model_address = 'ipc://@sherpa_tau_decay_{}'.format(str(uuid.uuid4()))
        print('Starting model in the background: {} {}'.format(opt.model_executable, model_address))
        os.system('{} {} > /dev/null &'.format(opt.model_executable, model_address))

        _ensure_folder(opt.dir)
        inference_networks_folder = os.path.join(opt.dir, 'inference_networks')
        experiments_folder = os.path.join(opt.dir, 'experiments')
        trace_cache_folder = os.path.join(opt.dir, 'trace_cache')
        # Keep a copy of the script next to the results it produced.
        current_script_file_name = os.path.abspath(__file__)
        print('Copying current script to {} ...'.format(opt.dir))
        shutil.copy(current_script_file_name, opt.dir)

        pyprob.set_random_seed(None)  # Set random seed from system time
        pyprob.set_verbosity(2)
        model = ModelRemote(model_address)
        pyprob.set_cuda(opt.cuda)

        if opt.mode == 'save_traces':
            _ensure_folder(trace_cache_folder)

            pyprob.set_verbosity(3)
            print('Saving trace cache to {}...'.format(trace_cache_folder))
            model.save_trace_cache(trace_cache_folder, opt.save_traces_files, opt.save_traces_per_file, PriorInflation.ENABLED if opt.prior_inflation else PriorInflation.DISABLED)
            print('Trace cache saved to {}'.format(trace_cache_folder))

        elif opt.mode == 'train':
            _ensure_folder(inference_networks_folder)
            if opt.train_use_cache:
                _ensure_folder(trace_cache_folder)
                model.use_trace_cache(trace_cache_folder)

            model.learn_inference_network(
                num_traces=opt.train_traces,
                observe_embedding=pyprob.nn.ObserveEmbedding.CONVNET_3D_4C,
                observe_reshape=[1, 35, 35, 20],
                batch_size=opt.train_batch_size,
                valid_size=opt.train_valid_size,
                valid_interval=opt.train_valid_interval,
                auto_save=True,
                auto_save_file_name=os.path.join(inference_networks_folder, 'pyprob_inference_network'),
                use_trace_cache=opt.train_use_cache,
                prior_inflation=PriorInflation.ENABLED if opt.prior_inflation else PriorInflation.DISABLED)

        elif opt.mode == 'create_experiments':
            _ensure_folder(experiments_folder)

            for i in range(opt.ground_truths):
                print('\n\n** Creating ground truth {}/{}\n\n'.format(i+1, opt.ground_truths))
                ground_truth_trace = next(model._trace_generator(inference_engine=InferenceEngine.LIGHTWEIGHT_METROPOLIS_HASTINGS))

                # Each ground truth gets its own uniquely named folder.
                ground_truth_folder = os.path.join(experiments_folder, 'ground_truth_' + str(uuid.uuid4()))
                os.makedirs(ground_truth_folder)

                ground_truth_plot_file_name = os.path.join(ground_truth_folder, 'ground_truth.pdf')
                print('Plotting ground truth to {} ...'.format(ground_truth_plot_file_name))
                plot_trace(ground_truth_trace, file_name=ground_truth_plot_file_name)
                ground_truth_trace_file_name = os.path.join(ground_truth_folder, 'ground_truth.trace')
                print('Saving trace to          {} ...'.format(ground_truth_trace_file_name))
                model._save_traces([ground_truth_trace], ground_truth_trace_file_name)

        elif opt.mode == 'infer_combine':
            ground_truth_folders = sorted(glob(os.path.join(experiments_folder, 'ground_truth_*')))
            if not ground_truth_folders:
                print('Cannot find any ground truths in experiments folder ' + experiments_folder)
                sys.exit(1)

            for i, ground_truth_folder in enumerate(ground_truth_folders):
                print('\n\n** Ground truth {}/{}\n\n'.format(i+1, len(ground_truth_folders)))

                ground_truth_file = os.path.join(ground_truth_folder, 'ground_truth.trace')
                print('Loading ground truth file {} ...'.format(ground_truth_file))
                ground_truth_trace = model._load_traces(ground_truth_file)[0]
                # plot_distribution expects [channel, p_x, p_y, p_z].
                ground_truths = [ground_truth_trace.result[3], ground_truth_trace.result[0], ground_truth_trace.result[1], ground_truth_trace.result[2]]

                for engine in opt.infer_engine:
                    posterior_file_names = sorted(glob(os.path.join(ground_truth_folder, '{}_posterior_num_traces*.distribution'.format(engine))))
                    if not posterior_file_names:
                        continue

                    inference_suffix = str(uuid.uuid4())
                    posteriors = []
                    for j, posterior_file_name in enumerate(posterior_file_names):
                        print('Loading {} posterior {}/{}'.format(engine, j+1, len(posterior_file_names)))
                        posteriors.append(Distribution.load(posterior_file_name))
                    print('Combining {} {} posteriors...'.format(len(posteriors), engine))
                    combined_posterior_dist = Empirical.combine(posteriors)
                    total_num_traces = sum(dist.length for dist in posteriors)
                    if opt.save_distributions:
                        combined_posterior_dist_file_name = os.path.join(ground_truth_folder, '{}_posterior_combined_num_traces_{}_{}.distribution'.format(engine, total_num_traces, inference_suffix))
                        print('Saving combined {} posterior to {} ...'.format(engine, combined_posterior_dist_file_name))
                        combined_posterior_dist.save(combined_posterior_dist_file_name)

                    combined_posterior_plot_file_name = os.path.join(ground_truth_folder, '{}_posterior_combined_num_traces_{}_{}.pdf'.format(engine, total_num_traces, inference_suffix))
                    print('Plotting combined {} posterior to {} ...'.format(engine, combined_posterior_plot_file_name))
                    plot_distribution(combined_posterior_dist, plot_samples=opt.plot_samples, title='Combined posterior, {}, num_posteriors={}, total_traces={}'.format(engine, len(posteriors), total_num_traces), ground_truths=ground_truths, file_name=combined_posterior_plot_file_name)

        elif opt.mode == 'infer':
            if 'IC' in opt.infer_engine:
                files = sorted(glob(os.path.join(inference_networks_folder, 'pyprob_inference_network*')))
                if not files:
                    print('Cannot find an inference network in folder ' + inference_networks_folder)
                    sys.exit(1)

                # The last file name in sorted order is treated as the latest network.
                inference_network_file_name = files[-1]
                print('Selecting latest inference network {}'.format(inference_network_file_name))
                model.load_inference_network(inference_network_file_name)
                analytics_file_name = os.path.join(opt.dir, 'inference_network_report')
                print('Saving inference network analytics report to {} ...'.format(analytics_file_name))
                # NOTE(review): report generation presumably needs a TeX
                # installation (e.g., texlive-full) — confirm.
                model.save_analytics(analytics_file_name, detailed_traces=opt.infer_report_traces)

            prior_file_name = os.path.join(opt.dir, 'prior_num_traces_{}.pdf'.format(opt.prior_traces))
            print('Plotting prior to {} ...'.format(prior_file_name))
            prior_dist = model.prior_distribution(num_traces=opt.prior_traces)
            plot_distribution(prior_dist, plot_samples=opt.plot_samples, title='Prior, num_traces={}'.format(opt.prior_traces), file_name=prior_file_name)

            prior_inflated_file_name = os.path.join(opt.dir, 'prior_inflated_num_traces_{}.pdf'.format(opt.prior_traces))
            print('Plotting inflated prior to {} ...'.format(prior_inflated_file_name))
            prior_inflated_dist = model.prior_distribution(num_traces=opt.prior_traces, prior_inflation=PriorInflation.ENABLED)
            plot_distribution(prior_inflated_dist, plot_samples=opt.plot_samples, title='Prior (inflated), num_traces={}'.format(opt.prior_traces), file_name=prior_inflated_file_name)

            ground_truth_folders = sorted(glob(os.path.join(experiments_folder, 'ground_truth_*')))
            if not ground_truth_folders:
                print('Cannot find any ground truths in experiments folder ' + experiments_folder)
                sys.exit(1)

            for i, ground_truth_folder in enumerate(ground_truth_folders):
                print('\n\n** Ground truth {}/{}\n\n'.format(i+1, len(ground_truth_folders)))

                ground_truth_file = os.path.join(ground_truth_folder, 'ground_truth.trace')
                print('Loading ground truth file {} ...'.format(ground_truth_file))
                ground_truth_trace = model._load_traces(ground_truth_file)[0]
                observation = ground_truth_trace.samples_observed[0].distribution.mean * min_energy_deposit
                # plot_distribution expects [channel, p_x, p_y, p_z].
                ground_truths = [ground_truth_trace.result[3], ground_truth_trace.result[0], ground_truth_trace.result[1], ground_truth_trace.result[2]]

                # The three per-inference lists were checked above to have equal length.
                for num_traces, burn_in, engine in zip(opt.infer_traces, opt.infer_burn_in, opt.infer_engine):
                    gc.collect()

                    inference_suffix = str(uuid.uuid4())
                    if engine == 'LMH':
                        inference_engine = InferenceEngine.LIGHTWEIGHT_METROPOLIS_HASTINGS
                    elif engine == 'RMH':
                        inference_engine = InferenceEngine.RANDOM_WALK_METROPOLIS_HASTINGS
                    elif engine == 'IS':
                        inference_engine = InferenceEngine.IMPORTANCE_SAMPLING
                    else:  # engine == 'IC':
                        inference_engine = InferenceEngine.IMPORTANCE_SAMPLING_WITH_INFERENCE_NETWORK

                    print('\n\n**** Running inference with {}, {} traces...\n\n'.format(engine, num_traces))
                    posterior_dist = model.posterior_distribution(num_traces=num_traces, inference_engine=inference_engine, burn_in=burn_in, initial_trace=ground_truth_trace, observation=observation)
                    if opt.save_distributions:
                        posterior_dist_file_name = os.path.join(ground_truth_folder, '{}_posterior_num_traces_{}_{}.distribution'.format(engine, num_traces, inference_suffix))
                        print('Saving posterior to      {} ...'.format(posterior_dist_file_name))
                        posterior_dist.save(posterior_dist_file_name)

                    posterior_plot_file_name = os.path.join(ground_truth_folder, '{}_posterior_num_traces_{}_{}.pdf'.format(engine, num_traces, inference_suffix))
                    print('Plotting posterior to {} ...'.format(posterior_plot_file_name))
                    plot_distribution(posterior_dist, plot_samples=opt.plot_samples, title=posterior_dist.name, ground_truths=ground_truths, file_name=posterior_plot_file_name)

                    # For the importance-sampling engines, also plot the
                    # unweighted samples, labelled as the proposal.
                    if engine == 'IS' or engine == 'IC':
                        proposal_plot_file_name = os.path.join(ground_truth_folder, '{}_proposal_num_traces_{}_{}.pdf'.format(engine, num_traces, inference_suffix))
                        print('Plotting proposal to  {} ...'.format(proposal_plot_file_name))
                        plot_distribution(posterior_dist.unweighted(), plot_samples=opt.plot_samples, title=posterior_dist.name.replace('Posterior', 'Unweighted posterior (i.e., proposal)'), ground_truths=ground_truths, file_name=proposal_plot_file_name)

                    print()
                print()

    except KeyboardInterrupt:
        print('Stopped.')
    except Exception:
        traceback.print_exc(file=sys.stdout)
        # Report the failure to the calling shell instead of exiting with success.
        exit_code = 1

    sys.exit(exit_code)


# Run the experiment driver only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
