#!/usr/bin/env python3

import argparse, pathlib, os, sys, subprocess

###############################################################################
def expect(condition, error_msg, exc_type=SystemExit, error_prefix="ERROR:"):
###############################################################################
    """
    Raise exc_type with a prefixed message when condition is falsy.

    Meant for reporting user errors (as opposed to programming errors): with
    the default exc_type of SystemExit the message is reported without a
    stack trace. Does nothing and returns None when condition is truthy.

    >>> expect(True, "error1")
    >>> expect(False, "error2")
    Traceback (most recent call last):
        ...
    SystemExit: ERROR: error2
    """
    if condition:
        return

    # The message is always "<prefix> <error_msg>", separated by one space
    raise exc_type(" ".join((error_prefix, error_msg)))

###############################################################################
def run_cmd(cmd, env):
###############################################################################
    """
    Run the shell command cmd via os.system, with env as its environment.

    We've noticed some LLNL clusters cannot run MPI programs via subprocess for
    some unknown and extremely subtle reason. Therefore, we use os.system unless
    the user requests buffering, in which case we must use subprocess.

    The child started by os.system inherits the real process environment,
    which only follows in-place modifications of os.environ. Rebinding the
    name os.environ to another mapping does not reach the child, and leaves
    a plain dict behind. For this reason os.environ is edited in place and
    restored afterwards, even if running the command raises.

    cmd: command line, passed verbatim to the shell.
    env: mapping of environment variables for the command (may be os.environ
         itself).

    Returns the status reported by os.system (0 on success).
    """
    # Snapshot both sides before touching anything: env may be os.environ itself
    saved_env = dict(os.environ)
    target_env = dict(env)

    # Flush so this header is not emitted after output the child writes directly
    print("RUN: {}\nFROM: {}".format(cmd, os.getcwd()), flush=True)

    # Nothing to swap (and nothing to restore) when env matches the current one
    needs_swap = target_env != saved_env

    try:
        if needs_swap:
            os.environ.clear()
            os.environ.update(target_env)
        stat = os.system(cmd)
    finally:
        if needs_swap:
            os.environ.clear()
            os.environ.update(saved_env)

    return stat

###############################################################################
def run_cmd_buff(cmd, env):
###############################################################################
    """
    Run the shell command cmd with environment env, capturing its output.

    stderr is merged into stdout, so a single combined stream is collected.
    The captured bytes are decoded as UTF-8 (undecodable bytes are dropped)
    and stripped of leading/trailing whitespace.

    Returns a (status, output) pair: the exit status of the command and the
    captured output.
    """
    print("RUN: {}\nFROM: {}".format(cmd, os.getcwd()))

    child = subprocess.Popen(cmd,
                             shell=True,
                             stdout=subprocess.PIPE,
                             stderr=subprocess.STDOUT,
                             env=env)

    # Second element is the stderr capture, unused since stderr goes to stdout
    captured = child.communicate()[0]

    if captured is not None:
        try:
            captured = captured.decode('utf-8', errors='ignore').strip()
        except AttributeError:
            print("Something funky for output")

    return child.wait(), captured

###############################################################################
def parse_command_line(args, description):
###############################################################################
    """
    Parse the launcher's own options out of args.

    args: full argument vector; args[0] is the program path (only its file
          name is shown in the usage text), the rest are the options.
    description: text used as the parser description.

    Returns the pair produced by parse_known_args: the namespace holding
    buffer_output and print_omp_affinity, and the list of arguments the
    parser did not recognize.
    """
    prog_name = pathlib.Path(args[0]).name

    usage_text = """\n{0} -- <executable> [<exe-args>]
OR
{0} --help

\033[1mEXAMPLES:\033[0m
    \033[1;32m# Lightweight python wrapper for a test that handles thread binding.\033[0m
    > {0} -- ./my_exec

""".format(prog_name)

    parser = argparse.ArgumentParser(
        usage=usage_text,
        description=description,
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )

    parser.add_argument(
        "-b", "--buffer-output",
        action="store_true",
        help="Buffer all out and print all at once. Useful for avoiding interleaving output when multiple tests are running concurrently")
    parser.add_argument(
        "-p", "--print-omp-affinity",
        action="store_true",
        help="Print threads affinity at runtime")

    launcher_args = args[1:]
    return parser.parse_known_args(launcher_args)

###############################################################################
def run_test(buffer_output,print_omp_affinity,exec_args):
###############################################################################

    expect(len(exec_args) >= 2 and exec_args[0] == "--",
           "Expected exec_args='-- <exe> [<exe-args>], got: {}".format(exec_args))

    env = os.environ

    omp_env = []
    if print_omp_affinity:
        omp_env.append("OMP_DISPLAY_ENV=verbose")
        omp_env.append("OMP_DISPLAY_AFFINITY=true")

    # Be sure to respect user env settings if they are there
    if "OMP_PROC_BIND" not in env:
        omp_env.append("OMP_PROC_BIND=spread")

    if "OMP_PLACES" not in env:
        omp_env.append("OMP_PLACES=threads")

    # Check if we're run with MPI
    if "${EKAT_MPIRUN_COMM_WORLD_SIZE}" in env and "${EKAT_MPIRUN_COMM_WORLD_RANK}" in env:
        nranks = int(env["${EKAT_MPIRUN_COMM_WORLD_SIZE}"])
        myrank = int(env["${EKAT_MPIRUN_COMM_WORLD_RANK}"])
    else:
        nranks = 1
        myrank = 0

    # Start with all env vars settings
    cmd = " ".join(omp_env)

    # If run with --resource-spec-file, we will have some special env var set
    if ${TEST_LAUNCHER_MANAGE_RESOURCES}:
        if "CTEST_RESOURCE_GROUP_COUNT" in env:
            print ("HAS RESOURCE SPECS")
            # Total number of resources for this test (for all ranks/threads)
            res_count = int(env["CTEST_RESOURCE_GROUP_COUNT"])
            my_res_id = myrank % res_count

            # Get the list of devices ids in my resource group
            key = "CTEST_RESOURCE_GROUP_" + str(my_res_id)
            expect(key in env, "Error! CTEST_RESOURCE_GROUP_COUNT found in the env, but can't find {}".format(key))
            my_res_name = env[key]
            expect (my_res_name=="devices", "Error! My ctest resource group should be 'devices'.")
            key += "_DEVICES"
            expect(key in env, "Error! CTEST_RESOURCE_GROUP_{} found in the env, but can't find {}".format(my_res_id,key))
            my_res_str = env[key]

            # Split the CTEST_RESOURCE_GROUP_[N]_DEVICE string into tokens, and get the slots ids
            my_res = my_res_str.split(';')
            ids = []
            for res in my_res:
                id_slots_tokens = res.split(',')
                id_tokens = id_slots_tokens[0].split(':')
                ids.append(int(id_tokens[1]))

            # CTest should already present these in lexicographic order, which should match
            # numeric order thanks to our recent change to prepend leading 0's to fill empty
            # digits. This sort shouldn't be necessary but it can't hurt.
            ids.sort()

            # Convert back to str so ids can be used in shell commands
            ids = [str(item) for item in ids]

            # Now get the chunk of ids that belongs to this rank
            expect (len(ids) % nranks == 0, "Error! Number of ranks does not divide the number of resources.")
            ids_per_rank = int(len(ids) / nranks)
            my_ids = ids[myrank*ids_per_rank : (myrank+1)*ids_per_rank]

            # This file is to be run through cmake's configure_file, which will expand the
            # variable TEST_LAUNCHER_ON_GPU to either True or False
            if ${TEST_LAUNCHER_ON_GPU}:
                expect (len(my_ids)==1, "Error! For GPU runs, each rank should use only one device.")
                exec_args.append(" --ekat-kokkos-device {}".format(ids[0]))
            else:
                # Start placing on the correct cpu set
                cmd += " taskset -c " + ",".join(my_ids)

        else:
            print ("WARNING: NO RESOUCE SPECS. EKAT was configured to manage resources, but no --resource-spec-file option was provided.")

    else:
        print ("EKAT is not managing resources.")

    cmd += " {}".format(" ".join(exec_args[1:]))

    if buffer_output:
        stat, out = run_cmd_buff(cmd, env)
        sys.stdout.flush()
        sys.stdout.write(out)
        sys.stdout.flush()
    else:
        stat = run_cmd(cmd, env)

    return stat==0

###############################################################################
def _main_func(description):
###############################################################################
    """
    Script driver: parse sys.argv, run the test, and exit the process.

    description is forwarded to the command line parser. The arguments the
    parser does not recognize are handed to run_test as exec_args. The process
    exits with code 0 if run_test reports success, and 1 otherwise.
    """
    parsed_opts, leftover = parse_command_line(sys.argv, description)

    # The namespace entries are exactly run_test's keyword arguments
    launcher_kwargs = vars(parsed_opts)
    passed = run_test(exec_args=leftover, **launcher_kwargs)

    if passed:
        sys.exit(0)
    else:
        sys.exit(1)

###############################################################################

# Run the launcher only when executed as a script (not when imported).
# __doc__ is the module docstring; it is None if the module has none.
if __name__ == "__main__":
    _main_func(__doc__)
