import glob
import os
import numpy as np
import h5py
import pickle 
import re

import numpy as np
import matplotlib.pylab as plt
import pandas as pd
import h5py
import glob
import os
import time


## Loading single skypatch. 

def load_single_skypatch(output_file):
    """
    Read one sky-patch HDF5 file into a dictionary of NumPy arrays.

    Every top-level entry of the file is treated as a group (one per core);
    each dataset inside a group is read in full, and datasets that share a
    name across groups are concatenated along axis 0 in the order the groups
    are iterated.

    Parameters:
    - output_file (str): Path to the HDF5 file to read.

    Returns:
    - dict: Mapping of dataset name -> concatenated NumPy array.
    """
    pieces = {}
    with h5py.File(output_file, "r") as handle:
        for group_name in handle.keys():
            group = handle[group_name]
            for dataset_name in group.keys():
                pieces.setdefault(dataset_name, []).append(group[dataset_name][...])
    # Stitch each dataset's per-core chunks together end to end.
    return {name: np.concatenate(chunks, axis=0) for name, chunks in pieces.items()}


def combine_multiple_skypatches(file_pattern):
    """
    Combines multiple HDF5 catalogs into a single dictionary of concatenated NumPy arrays.

    Each file matching ``file_pattern`` (processed in sorted order) is read with
    ``load_single_skypatch`` and its per-key arrays are concatenated along axis 0.
    A file is skipped, with a printed message, when:

    - it cannot be loaded (any exception raised by ``load_single_skypatch``);
    - it has no 'tree_node_mass' key;
    - its 'tree_node_mass' array has a different ``ndim`` than the first accepted file;
    - its set of keys differs from the first accepted file. Combining catalogs
      with different keys would give some keys fewer rows than others, silently
      breaking row alignment across keys.

    Parameters:
    - file_pattern (str): Glob pattern to match HDF5 files (e.g., "data_skypatch_*.h5").

    Returns:
    - combined_catalog (dict): Dictionary with keys mapping to concatenated NumPy arrays.
      Empty if no file was accepted.

    Raises:
    - FileNotFoundError: If no files match ``file_pattern``.
    - ValueError: If the arrays of some key cannot be concatenated
      (e.g. mismatched trailing dimensions across files).
    """
    # Sort for a deterministic, reproducible row order in the combined catalog.
    output_files = sorted(glob.glob(file_pattern))
    if not output_files:
        raise FileNotFoundError(f"No files found matching pattern: {file_pattern}")

    print(f"Found {len(output_files)} files matching pattern '{file_pattern}'.")

    collected = {}         # key -> list of per-file arrays awaiting concatenation
    expected_ndim = None   # 'tree_node_mass' ndim of the first accepted file
    expected_keys = None   # key set of the first accepted file
    n_accepted = 0

    for idx, file_path in enumerate(output_files, 1):
        file_name = os.path.basename(file_path)
        print(f"Loading file {idx}/{len(output_files)}: {file_name}")
        try:
            catalog = load_single_skypatch(file_path)
        except Exception as e:
            # Best-effort: an unreadable file is reported and skipped, not fatal.
            print(f"Error loading file '{file_name}': {e}")
            continue

        if 'tree_node_mass' not in catalog:
            print(f"'tree_node_mass' key not found in file '{file_name}'. Skipping.")
            continue

        current_ndim = catalog['tree_node_mass'].ndim
        if expected_ndim is not None and current_ndim != expected_ndim:
            print(f"Dimension mismatch in file '{file_name}': Expected ndim={expected_ndim}, Found ndim={current_ndim}")
            print(f"Skipping file '{file_name}' due to dimension mismatch.")
            continue

        current_keys = set(catalog)
        if expected_keys is not None and current_keys != expected_keys:
            missing = sorted(expected_keys - current_keys)
            extra = sorted(current_keys - expected_keys)
            print(f"Key mismatch in file '{file_name}': missing={missing}, extra={extra}")
            print(f"Skipping file '{file_name}' due to key mismatch.")
            continue

        # The first accepted file defines the reference layout for the rest.
        if expected_ndim is None:
            expected_ndim = current_ndim
            expected_keys = current_keys
            print(f"Expected 'tree_node_mass' ndim set to {expected_ndim} based on '{file_name}'.")

        for key, data in catalog.items():
            collected.setdefault(key, []).append(data)
        n_accepted += 1

    if not collected:
        print("Warning: no catalogs were combined; returning an empty dictionary.")
        return {}

    combined_catalog = {}
    for key, arrays in collected.items():
        try:
            combined_catalog[key] = np.concatenate(arrays, axis=0)
        except ValueError as e:
            print(f"Error concatenating key '{key}': {e}")
            print("Please check the consistency of your data across all files.")
            raise
        print(f"Combined key '{key}' with shape {combined_catalog[key].shape}.")

    print(f"All catalogs have been successfully combined ({n_accepted}/{len(output_files)} files used).")
    return combined_catalog


'''

### From previous loader

# Custom key function to extract <x> from the filename
def extract_core_number(filename):
    match = re.search(r'core_(\d+)', filename)
    return int(match.group(1)) if match else 0


def load_and_clean_single_catalog(fileIn):
    print('Catalog: ' + fileIn)
    with h5py.File(fileIn, 'r') as f:
        items = list(f.keys())
        # print(items)
        raw_data = {}
        for item in items:
            if isinstance(f[item], h5py.Dataset):
                raw_data[item] = f[item][()]
            elif isinstance(f[item], h5py.Group):
                group_data = {}
                for sub_item in f[item].keys():
                    group_data[sub_item] = f[item][sub_item][()]
                raw_data[item] = group_data
        
        f.close()
        print('Total number of original galaxies: %d'%len(raw_data[items[3]]))  # Displaying the length of mag_i_sdss as an example


    # Identify invalid entries across all datasets starting with "mag_"
    
    mag_invalid_indices = set()
    for key, value in raw_data.items():
        if key.startswith("mag_") and isinstance(value, np.ndarray):
            invalid_indices = np.where(np.isinf(value) | np.isnan(value))[0]
            
            if (len(invalid_indices) > 5):
                print(4*'>>' + f'  {key}: {len(invalid_indices)} /{len(value)} invalid ' + 4*'<<')
                
            if (len(invalid_indices) == 0): invalid_indices = [0]
            # invalid_indices = [0]
            mag_invalid_indices.update(invalid_indices)

            

    mag_invalid_indices = np.array(list(mag_invalid_indices))
    print(mag_invalid_indices)
    mag_invalid_indices = np.array([0])
    # print(mag_invalid_indices)
    

    cleaned_data = {}
    removed_data = {}  # New dictionary for storing removed items
    for key, value in raw_data.items():
        if isinstance(value, np.ndarray):
            valid_indices = mag_invalid_indices[mag_invalid_indices < len(value)]
            if key not in ['SED_wavelength', 'time_bins_SFH']:
                cleaned_value = np.delete(value, valid_indices, axis=0)
                removed_value = value[valid_indices]  # Extracting removed items
            else:
                cleaned_value = value
                removed_value = np.array([])  # No items removed for exceptions
            cleaned_data[key] = cleaned_value
            removed_data[key] = removed_value  # Storing removed items

    print('Total number of cleaned galaxies: %d'%len(cleaned_data[items[3]]))
    print('Total number of removed galaxies: %d'%len(removed_data[items[3]]))
    if (cleaned_data['redshift_true'].shape[0] > 0):
        print('Valid REDSHIFT: %.2f <<->> %.2f'%(np.min(cleaned_data['redshift_true']), np.max(cleaned_data['redshift_true'])))
        print('Invalid REDSHIFT: %.2f <<->> %.2f'%(np.min(removed_data['redshift_true']), np.max(removed_data['redshift_true'])))
    print(10*'=--=')
    

    return cleaned_data, removed_data, items  # Now also returning removed_data




def load_all_available_catalogs(dirIn=None, 
                                exclude_core_files_numbers=[]):
    all_data = {}
    all_removed_data = {}
    all_items = []
    special_items = ['SED_wavelength', 'time_bins_SFH']
    special_items_included = {key: False for key in special_items}
    
    files = sorted(glob.glob(os.path.join(dirIn, '*.hdf5')), key=extract_core_number)
    
    for file_path in files:
        # Extract the number following 'core_' using regular expressions
        match = re.search(r'core_(\d+)', file_path)
        
        if match:
            x_number = int(match.group(1))  # Convert the extracted part to an integer
            
            if x_number in exclude_core_files_numbers:
                continue  # Skip this file and move to the next one

        
        cleaned_data, removed_data, items = load_and_clean_single_catalog(file_path)

        if not all_data:
            all_data = {key: [] for key in cleaned_data if key not in special_items}
            all_items = items
            
        if not all_removed_data:
            all_removed_data = {key: [] for key in removed_data if key not in special_items}

        for key, value in cleaned_data.items():
            if key in special_items and not special_items_included[key]:
                all_data[key] = value
                special_items_included[key] = True
            elif key not in special_items:
                all_data[key].append(value)
                
        for key, value in removed_data.items():
            if key in special_items and not special_items_included[key]:
                all_removed_data[key] = value
                # special_items_included[key] = True
            elif key not in special_items:
                all_removed_data[key].append(value)

    for key in all_data:
        if key not in special_items:
            all_data[key] = np.concatenate(all_data[key], axis=0)
            
    for key in all_removed_data:
        if key not in special_items:
            all_removed_data[key] = np.concatenate(all_removed_data[key], axis=0)
    print('Grand total number of removed galaxies: %d'%len(all_removed_data[items[3]]))
    
    
    return all_data, all_removed_data, all_items



def load_survey_pickle(survey, dirIn_bands='Bands/'):
        
    if (survey=='LSST'):
        FILTER_NAME = dirIn_bands + 'LSST.pickle'
    elif (survey=='SPHEREx'):
        FILTER_NAME = dirIn_bands + 'SPHEREx.pickle'
    elif (survey=='COSMOS'):
        FILTER_NAME = dirIn_bands + 'COSMOS.pickle'      
    elif (survey=='WISE'):
        FILTER_NAME = dirIn_bands + 'WISE.pickle'      
    elif (survey=='LEGACYSURVEY'):
        FILTER_NAME = dirIn_bands + 'LEGACYSURVEY.pickle'       
    elif (survey=='2MASS'):
        FILTER_NAME = dirIn_bands + '2MASS.pickle'
    elif (survey=='F784'):
        FILTER_NAME = dirIn_bands + 'F784.pickle'
    else: 
        raise NotImplementedError("Filter specifications not included")
        
    with open(FILTER_NAME, 'rb') as f:
     central_wavelengths, bandpass_wavs, bandpass_vals, bandpass_names = pickle.load(f)
    
    return central_wavelengths, bandpass_wavs, bandpass_vals, bandpass_names
    
    
'''