import numpy as np
import matplotlib.pyplot as plt
from scipy.optimize import curve_fit
import pandas as pd
import os
import seaborn as sns
from matplotlib.backends.backend_pdf import PdfPages
import h5py
import matplotlib.ticker as ticker
import matplotlib.colors as mcolors
import matplotlib.cm as cm
import argparse


#Start defining functions: 

def thd_correct(array, segment_length=25, n_lowest=1):
    """Subtract a per-waveform baseline estimated from the flattest segment(s).

    The last axis of ``array`` (the sample axis) is cut into consecutive,
    non-overlapping segments of ``segment_length`` samples. Trailing samples
    that do not fill a whole segment are not used for the baseline estimate,
    but they are still baseline-subtracted. For each waveform the segments
    are ranked by peak-to-peak range. The mean of the ``n_lowest`` segments
    with the smallest range is the baseline.

    Segments whose range is exactly zero have their mean replaced by NaN.
    Zero-range segments also rank first, so a waveform whose flattest segment
    is perfectly constant gets a NaN baseline and an all-NaN result.
    NOTE(review): this may be intended to flag flat (dead) channels. If the
    intent was instead to skip flat segments, the ranking should ignore
    them -- confirm.

    Parameters
    ----------
    array : array_like
        Waveform samples with time on the last axis. The caller in this file
        passes the raw waveform block truncated to 400 samples and scaled
        by 1/4.
    segment_length : int, optional
        Number of samples per baseline segment (default 25).
    n_lowest : int, optional
        Number of flattest segments whose means are averaged (default 1).

    Returns
    -------
    np.ndarray
        ``array`` minus its per-waveform baseline, with the same shape as
        ``array``. If the sample axis is shorter than ``segment_length``,
        there are no segments and the baseline (hence the result) is NaN.
    """
    array = np.asarray(array)
    n_samples = array.shape[-1]
    n_segments = n_samples // segment_length

    # (n_segments, segment_length) table of sample indices, one row per segment.
    seg_starts = np.arange(n_segments) * segment_length
    seg_index = seg_starts[:, None] + np.arange(segment_length)

    segments = array[..., seg_index]  # (..., n_segments, segment_length)
    ranges = np.ptp(segments, axis=-1)  # (..., n_segments)
    means = np.mean(segments, axis=-1)  # (..., n_segments)
    del segments  # free the largest temporary before sorting

    # Perfectly flat segments get a NaN mean (see docstring).
    means = np.where(ranges != 0, means, np.nan)
    order = np.argsort(ranges, axis=-1)
    del ranges
    sorted_means = np.take_along_axis(means, order, axis=-1)
    del means, order

    # keepdims leaves a trailing length-1 axis that broadcasts over every
    # sample, so no full-size copy of the baseline is built.
    baseline = np.mean(sorted_means[..., :n_lowest], axis=-1, keepdims=True)
    return array - baseline

_DEFAULT_DECAYS_DIR = (
    '/global/cfs/cdirs/dune/www/data/2x2/nearline_run2/flowed_light/'
    'source_dtg_bin1/two_trig_40us_period_500tick/intensity_scan/'
    'hv_60kV_beamC_20uA/decays'
)
_DEFAULT_FOAS_CSV = (
    '/global/cfs/cdirs/dune/www/data/2x2/LRS_det_config_run2/foas_csv/'
    'FOAS_20251203_103411.csv'
)


def _count_triggers(wvfm, sum_adc_chan, foas_thd, foas_trig_status=None,
                    window=(70, 130)):
    """Count, per sum-ADC channel, the events whose summed waveform hits threshold.

    ``wvfm`` is the baseline-corrected SiPM block with shape
    (n_events, 8, 48, n_samples); the reshape below enforces 8 x 48. Each
    run of 6 consecutive SiPM columns is summed into one of 64 summed
    channels, with row index 8 * first_axis + group. Summed channel i fires
    when its maximum inside ``window`` is >= foas_thd[sum_adc_chan[i]] / 4.
    If ``foas_trig_status`` is given, channel i must also have
    foas_trig_status[sum_adc_chan[i]] == 1. Each hit is credited to sum-ADC
    channel sum_adc_chan[i].

    Returns an int64 array of 64 trigger counts.
    """
    start, stop = window
    # Only the trigger window matters for the peak search, so clean and
    # convert just that slice to keep memory low.
    win = np.nan_to_num(wvfm[..., start:stop], nan=0.0, posinf=0.0, neginf=0.0)
    win = win.astype(np.int64)  # truncation toward zero, applied before summing
    n_events, n_win = win.shape[0], win.shape[-1]
    summed = win.reshape(n_events, 8, 8, 6, n_win).sum(axis=3, dtype=np.int64)
    peaks = summed.reshape(n_events, 64, n_win).max(axis=-1)  # (n_events, 64)

    chan_map = np.asarray(sum_adc_chan)[:64]
    fired = peaks >= (np.asarray(foas_thd)[chan_map] / 4)
    if foas_trig_status is not None:
        fired &= (np.asarray(foas_trig_status)[chan_map] == 1)

    counts = np.zeros(64, dtype=np.int64)
    # add.at accumulates correctly even if chan_map repeats an index.
    np.add.at(counts, chan_map, fired.sum(axis=0))
    return counts


def track_rate_changes(Filenum_list, data_dir=_DEFAULT_DECAYS_DIR,
                       foas_input_path=_DEFAULT_FOAS_CSV,
                       require_trig_active=False):
    """Reconstruct per-channel trigger rates for a list of flowed-light HDF5 files.

    For every name in ``Filenum_list`` the file ``f'{data_dir}/{name}'`` is
    read. Missing files are skipped and leave all-zero rows in both outputs.
    Processing per file:

    * The first 400 waveform samples are divided by 4.
    * They are baseline-corrected with ``thd_correct``.
    * They are reduced to the 48 SiPM channels listed below.
    * They are summed in groups of six into 64 summed channels.

    An event counts as a trigger when a summed channel's maximum in samples
    70-129 reaches the FOAS ``threshold`` of its sum-ADC channel, divided
    by 4.

    Parameters
    ----------
    Filenum_list : sequence of str
        File names relative to ``data_dir``.
    data_dir : str, optional
        Directory holding the files (defaults to the decays directory of
        this analysis).
    foas_input_path : str, optional
        FOAS configuration CSV. Its ``threshold`` column (and
        ``trig_active``, if ``require_trig_active``) is indexed by sum-ADC
        channel number.
    require_trig_active : bool, optional
        If True, also require ``trig_active == 1`` for a channel to count a
        trigger. The default False applies the threshold only.

    Returns
    -------
    output_rates : np.ndarray, shape (len(Filenum_list), 64), float64
        Trigger counts divided by the file's ``utime_ms`` span / 1e3, which
        is seconds if ``utime_ms`` is in milliseconds. A file whose events
        all share one timestamp has a zero span and gives inf/NaN rates.
    time_lengths : np.ndarray, shape (len(Filenum_list), 2), int64
        Per-file [min, max] of ``utime_ms``.
    """
    n_files = len(Filenum_list)
    output_rates = np.zeros((n_files, 64), dtype=np.float64)
    time_lengths = np.zeros((n_files, 2), dtype=np.int64)

    foas_config = pd.read_csv(foas_input_path)
    foas_thd = foas_config['threshold'].to_numpy(dtype=int)
    foas_trig_status = (foas_config['trig_active'].to_numpy(dtype=int)
                        if require_trig_active else None)
    sum_adc_chan = get_sumadc_chans()

    # Raw channel numbers of the 48 SiPM channels, in summing order.
    sipm_channels = ([4, 5, 6, 7, 8, 9] +
                     [10, 11, 12, 13, 14, 15] +
                     [20, 21, 22, 23, 24, 25] +
                     [26, 27, 28, 29, 30, 31] +
                     [36, 37, 38, 39, 40, 41] +
                     [42, 43, 44, 45, 46, 47] +
                     [52, 53, 54, 55, 56, 57] +
                     [58, 59, 60, 61, 62, 63])

    for row, name in enumerate(Filenum_list):
        path = f'{data_dir}/{name}'
        if not os.path.isfile(path):
            continue
        with h5py.File(path, 'r') as h5:
            utime = h5['light/events/data']['utime_ms']  # read once
            t_min, t_max = np.min(utime), np.max(utime)
            wvfm = h5['light/wvfm/data']['samples'][:, :, :, :400].astype(float)
        time_lengths[row] = (t_min, t_max)
        file_length = (t_max - t_min) / 1e3

        wvfm /= 4  # in place: wvfm is already a private float copy
        wvfm = thd_correct(wvfm)[:, :, sipm_channels, :]
        counts = _count_triggers(wvfm, sum_adc_chan, foas_thd, foas_trig_status)
        del wvfm

        output_rates[row] = counts / file_length
    return output_rates, time_lengths

_SUM_ADC_MAP_CSV = (
    '/global/cfs/cdirs/dune/users/ajwhite/2x2_LRS_DataAssess/2025_Calibration/'
    '102025_ThreshTests/Find_Energy_Thd_Co60/11192025_Corrected_Sum2ADC_Map.csv'
)


def get_sumadc_chans(csv_file_path=_SUM_ADC_MAP_CSV):
    """Load the summed-channel to sum-ADC channel map.

    Parameters
    ----------
    csv_file_path : str, optional
        CSV file with a ``sum_adc_chan`` column. Defaults to the corrected
        map used by this analysis.

    Returns
    -------
    np.ndarray of int
        The ``sum_adc_chan`` column in file order.
    """
    return pd.read_csv(csv_file_path)['sum_adc_chan'].to_numpy(dtype=int)

def get_new_dead_array():
    """Return an (8, 48) float mask with 0.0 at dead SiPM channels and 1.0 elsewhere.

    Row i corresponds to entry i of the dead-channel table below. Column j
    corresponds to raw channel ``sipm_channels[j]``. NOTE(review): the rows
    presumably index the 8 ADCs/modules of the waveform block -- confirm.
    """
    # Raw SiPM channel numbers: four runs of 12 consecutive channels
    # (4-15, 20-31, 36-47, 52-63), 48 columns in total.
    sipm_channels = [chan for base in (4, 20, 36, 52) for chan in range(base, base + 12)]
    column_of = {chan: col for col, chan in enumerate(sipm_channels)}

    # Raw channel numbers known to be dead, one list per mask row.
    dead_channels = [
        [7, 20],
        [],
        [22, 54],
        [61],
        [36, 47],
        [],
        [20, 21, 22, 23, 46, 47],
        [4, 15],
    ]

    new_dead_array = np.ones((len(dead_channels), len(sipm_channels)))
    for row, chans in enumerate(dead_channels):
        for chan in chans:
            new_dead_array[row, column_of[chan]] = 0.0
    return new_dead_array


def get_every_10min(last_filenum, num_switch=126, num_switch_2=330):
    """Return file numbers below ``last_filenum``, sampled with a changing stride.

    The stride is 5 files below ``num_switch``, 10 files from
    ``num_switch + 1`` up to ``num_switch_2``, and 20 files from
    ``num_switch_2 + 1`` upward. Presumably each stride is ~10 minutes of
    data, assuming files last 2, 1 and 0.5 minutes in those three ranges --
    confirm.

    Parameters
    ----------
    last_filenum : int
        Exclusive upper bound on the returned file numbers.
    num_switch : int, optional
        File number where the stride changes from 5 to 10 (default 126).
    num_switch_2 : int, optional
        File number where the stride changes from 10 to 20 (default 330).

    Returns
    -------
    np.ndarray of int
        The selected file numbers in increasing order. All are
        < ``last_filenum``.
    """
    # Clip the first range too, so short runs never return out-of-range files.
    sets = [
        np.arange(0, min(num_switch, last_filenum), 5),
        np.arange(num_switch + 1, min(num_switch_2, last_filenum), 10),
    ]
    if last_filenum > num_switch_2:
        sets.append(np.arange(num_switch_2 + 1, last_filenum, 20))
    return np.concatenate(sets)

def output_plot(trigger_rate_array, file_length_array, png_output_name):
    """Plot per-channel trigger-rate changes relative to a reference file and save a PNG.

    Row 0 of ``trigger_rate_array`` is the reference, labelled "Pre-Injection
    Rate", and is drawn as absolute values. Every later row is drawn as its
    ratio to row 0, coloured light to dark blue in file order. A colourbar
    labels that colour range with time in hr:min. The range runs from the
    start of file 1 to the end of the last file, both measured from the
    start of file 0.

    Parameters
    ----------
    trigger_rate_array : np.ndarray, shape (n_files, 64)
        Per-file trigger rates per sum-ADC channel, as returned by
        track_rate_changes. Requires n_files >= 2.
    file_length_array : np.ndarray, shape (n_files, 2)
        Per-file [min, max] ``utime_ms`` timestamps.
    png_output_name : str
        Path of the PNG file to write.
    """
    fig, ax = plt.subplots(figsize=(12, 7))
    ms_per_minute = 1e3 * 60
    t0 = file_length_array[0, 0]
    # End of the LAST file (row -1, column 1 = max utime_ms) relative to the
    # start of the first file, in whole minutes.
    time_minutes = np.int32((file_length_array[-1, 1] - t0) / ms_per_minute)
    ax.set_title(
        f'Reconstructed Trigger Rates During Next {time_minutes} Minutes of HV Ramp\n'
        'MORCs Run 60168 (LRS Run 948)',
        fontsize=14, fontweight='bold')

    # ACL channel ranges (green) and swapped ACL/LCM channels (hatched).
    ax.axvspan(-0.5, 7.5, color='green', alpha=0.15, label='ACL Channels')
    ax.axvspan(53.5, 55.5, color='coral', alpha=0.5, label='Swapped ACL/LCMs',
               hatch='xx', fill=False)
    ax.axvspan(61.5, 63.5, color='coral', alpha=0.5, hatch='xx', fill=False)
    for lo, hi in ((15.5, 23.5), (31.5, 39.5), (47.5, 55.5)):
        ax.axvspan(lo, hi, color='green', alpha=0.15)
    # Coloured boundaries of the four 16-channel module groups.
    for x, color in ((-0.5, 'gold'), (15.4, 'gold'),
                     (15.6, 'rebeccapurple'), (31.4, 'rebeccapurple'),
                     (31.6, 'cornflowerblue'), (47.4, 'cornflowerblue'),
                     (47.6, 'darkorange'), (63.5, 'darkorange')):
        ax.axvline(x, color=color, linewidth=2)

    channels = np.arange(0, 64, 1)
    cmap = plt.cm.Blues
    reference = trigger_rate_array[0, :]
    # NOTE(review): colours follow file index (evenly spaced in [0.1, 1]),
    # not actual file time, so they line up with the time colourbar only
    # approximately.
    colors = cmap(np.linspace(0.1, 1, len(trigger_rate_array) - 1))
    for rates, color in zip(trigger_rate_array[1:, :], colors):
        ratio = rates / reference
        ax.plot(channels, ratio, color=color, linestyle='dashed', alpha=0.4)
        ax.scatter(channels, ratio, color=color, marker='o')
    ax.plot(channels, reference, color='darkorange', linewidth=1, marker='o',
            label='Pre-Injection Rate')

    # Colourbar in minutes since the start of file 0.
    t_start = np.int32((file_length_array[1, 0] - t0) / ms_per_minute)
    t_end = time_minutes
    norm = mcolors.Normalize(vmin=t_start, vmax=t_end)
    sm = cm.ScalarMappable(norm=norm, cmap=cmap)  # drives the colourbar only
    sm.set_array([])
    cax = ax.inset_axes([1.01, 0.0, 0.02, 1.0])  # [x0, y0, width, height]
    cbar = fig.colorbar(sm, cax=cax)
    cbar.set_label("Time (hr:min)")

    def fmt_hr_min(value, pos):
        """Format a minute count as 'H:MM'."""
        hr, mn = divmod(int(round(value)), 60)
        return f"{hr}:{mn:02d}"

    cbar.set_ticks(np.arange(t_start, t_end + 1, 30))  # one tick every 30 minutes
    cbar.formatter = ticker.FuncFormatter(fmt_hr_min)
    cbar.update_ticks()

    ax.set_xlabel('Channel Number: Sum ADC')
    ax.set_xlim(-1, 64)
    # NOTE(review): rows 1.. are dimensionless ratios to row 0, but the axis
    # is labelled in Hz, which is true only for the reference curve -- confirm.
    ax.set_ylabel('Trigger Rate [Hz]')
    ax.legend(facecolor="white", edgecolor="black", framealpha=1.0, loc='center',
              bbox_to_anchor=(0.375, 0.85))
    ax.grid(True)

    module_labels = (
        (0.125, 'Module 1 Channels', 'gold'),
        (0.375, 'Module 0 Channels', 'rebeccapurple'),
        (0.625, 'Module 3 Channels', 'cornflowerblue'),
        (0.875, 'Module 2 Channels', 'darkorange'),
    )
    for x_pos, text, edge in module_labels:
        ax.text(
            x_pos, 0.95, text,
            transform=ax.transAxes,
            fontsize=12,
            verticalalignment='top',
            horizontalalignment='center',
            bbox=dict(boxstyle='round', facecolor='white', edgecolor=edge),
        )

    fig.tight_layout()
    fig.savefig(png_output_name)
    plt.close(fig)



def main(filename=None):
    """Print the reconstructed trigger rates of one file as a single row on stdout.

    Calls track_rate_changes on ``[filename]`` and writes one
    whitespace-separated line with ``np.savetxt``. The line holds the 64
    per-channel trigger rates followed by the file's minimum and maximum
    ``utime_ms``. A file that does not exist is skipped by
    track_rate_changes, so its row is all zeros.

    Parameters
    ----------
    filename : str
        File name relative to the data directory used by track_rate_changes.

    Raises
    ------
    ValueError
        If ``filename`` is None (e.g. ``-f/--filename`` was not given). This
        prevents silently printing an all-zero row.
    """
    import sys

    if filename is None:
        raise ValueError('main() requires a file name; pass -f/--filename.')
    trigger_rate_array, file_length_array = track_rate_changes([filename])
    row = np.concatenate([trigger_rate_array.ravel(), file_length_array[0].ravel()], axis=0)
    np.savetxt(sys.stdout, row.reshape(1, -1))



if __name__ == '__main__':
    # Command-line entry point: one optional -f/--filename argument forwarded to main().
    cli = argparse.ArgumentParser()
    cli.add_argument(
        '-f',
        '--filename',
        type=str,
        required=False,
        default=None,
        help='''string corresponding to data list chunk subset.''',
    )
    cli_args = cli.parse_args()
    main(filename=cli_args.filename)


