#!/usr/bin/env python3

"""
Script containing python test suite for SCREAM's CIME
namelist-related infrastructure.
"""

from utils import check_minimum_python_version, expect, ensure_pylint, run_cmd_assert_result, get_timestamp

check_minimum_python_version(3, 6)

from machines_specs import is_machine_supported

import unittest, argparse, sys, os, shutil
from pathlib import Path

# All paths are derived from the resolved location of this script:
# EAMXX_DIR is two directory levels above it.
EAMXX_DIR = Path(__file__).resolve().parent.parent
EAMXX_CIME_DIR = EAMXX_DIR.joinpath("cime_config")
EAMXX_SCRIPTS_DIR = EAMXX_DIR.joinpath("scripts")
# CIME's scripts directory sits two levels above EAMXX_DIR, under cime/scripts
CIME_SCRIPTS_DIR = EAMXX_DIR.parent.parent.joinpath("cime", "scripts")

# Suite-wide settings; "machine" is filled in by scripts_tests() when a
# machine name is provided.
CONFIG = dict(machine=None)

###############################################################################
class TestBuildnml(unittest.TestCase):
###############################################################################
    """
    Test suite for the namelist-related tooling (atmchange, atmquery and their
    interaction with xmlchange / case.setup).

    Most tests create a CIME case through create_test and then run commands
    inside that case directory. Every case directory created through
    _create_test is removed again in tearDown.
    """

    ###########################################################################
    def _create_test(self, extra_args, env_changes=""):
    ###########################################################################
        """
        Convenience wrapper around create_test.

        extra_args  : list of command-line tokens passed to create_test. The
                      list is not modified. Flags are only recognized when
                      they are stand-alone list elements.
        env_changes : text placed in front of the create_test command.

        A unique test id is always added. If none of the flags that stop the
        test early (-n, --namelist-only, --no-setup, --no-build, --no-run) is
        present, --wait is added as well.

        Returns the full path of the created case if exactly one case was
        created, otherwise a list of full paths. If multiple cases, the order
        of the returned list is not guaranteed to match the order of the
        arguments.
        """
        # Work on a copy so the caller's list is never altered
        cmd_args = list(extra_args)

        test_id = f"cmd_nml_tests-{get_timestamp()}"
        cmd_args.append(f"-t {test_id}")

        partial_run_flags = {"-n", "--namelist-only", "--no-setup", "--no-build", "--no-run"}
        if partial_run_flags.isdisjoint(cmd_args):
            # Full run requested
            cmd_args.append("--wait")

        output = run_cmd_assert_result(
            self, f"{env_changes} {CIME_SCRIPTS_DIR}/create_test {' '.join(cmd_args)}")

        # Case directories are reported on lines containing "Case dir:", as
        # the last whitespace-separated token of the line
        cases = []
        for line in output.splitlines():
            if "Case dir:" in line:
                casedir = line.split()[-1]
                self.assertTrue(os.path.isdir(casedir), msg=f"Missing casedir {casedir}")
                cases.append(casedir)

        self.assertTrue(cases, "create_test made no cases")

        self._dirs_to_cleanup.extend(cases)

        return cases[0] if len(cases) == 1 else cases

    ###########################################################################
    def _chg_atmconfig(self, changes, case, buff=True, reset=False, expect_lost=False):
    ###########################################################################
        """
        Apply atm config changes in a case and check whether they survive a
        subsequent case.setup.

        changes     : list of (name, value) pairs. It is iterated twice, so it
                      must not be a one-shot iterator. Each value is compared
                      against the text returned by atmquery.
        case        : case directory in which all commands are run
        buff        : if False, atmchange is called with --no-buffer
        reset       : if True, "./atmchange --reset" is run before case.setup
        expect_lost : if True, assert that after case.setup the queried values
                      differ from the requested ones; otherwise assert that
                      they still match.
        """
        buffer_opt = "" if buff else "--no-buffer"

        for name, value in changes:
            # The change must be a real change, otherwise the checks below
            # could not tell whether it was applied or lost
            orig = run_cmd_assert_result(self, f"./atmquery {name} --value", from_dir=case)
            self.assertNotEqual(orig, value)

            run_cmd_assert_result(self, f"./atmchange {buffer_opt} {name}={value}", from_dir=case)
            curr_value = run_cmd_assert_result(self, f"./atmquery {name} --value", from_dir=case)
            self.assertEqual(curr_value, value)

        if reset:
            run_cmd_assert_result(self, "./atmchange --reset", from_dir=case)

        run_cmd_assert_result(self, "./case.setup", from_dir=case)

        for name, value in changes:
            curr_value = run_cmd_assert_result(self, f"./atmquery {name} --value", from_dir=case)
            if expect_lost:
                self.assertNotEqual(curr_value, value)
            else:
                self.assertEqual(curr_value, value)

    ###########################################################################
    def setUp(self):
    ###########################################################################
        """
        Start each test with an empty list of case directories to clean up
        """
        self._dirs_to_cleanup = []

    ###########################################################################
    def tearDown(self):
    ###########################################################################
        """
        Remove every case directory registered by _create_test
        """
        for item in self._dirs_to_cleanup:
            shutil.rmtree(item)

    ###########################################################################
    def test_doctests(self):
    ###########################################################################
        """
        Run doctests for all eamxx/cime_config python files and nml-related files in scripts
        """
        run_cmd_assert_result(self, "python3 -m doctest *.py", from_dir=EAMXX_CIME_DIR)
        run_cmd_assert_result(self, "python3 -m doctest atm_manip.py", from_dir=EAMXX_SCRIPTS_DIR)

    ###########################################################################
    def test_pylint(self):
    ###########################################################################
        """
        Run pylint on all eamxx/cime_config python files and nml-related files in scripts
        """
        ensure_pylint()
        run_cmd_assert_result(self, "python3 -m pylint --disable C,R,E0401,W1514 *.py", from_dir=EAMXX_CIME_DIR)
        run_cmd_assert_result(self, "python3 -m pylint --disable C,R,E0401,W1514 atm_manip.py", from_dir=EAMXX_SCRIPTS_DIR)

    ###########################################################################
    def test_xmlchange_propagates_to_atmconfig(self):
    ###########################################################################
        """
        Test that xmlchanges impact atm config files
        """
        case = self._create_test("ERS_Ln22.ne30_ne30.F2010-SCREAMv1 --no-build".split())

        # atm config should match case test opts
        case_rest_n = run_cmd_assert_result(self, "./xmlquery REST_N --value", from_dir=case)
        atm_freq    = run_cmd_assert_result(self, "./atmquery Frequency --value", from_dir=case)
        self.assertEqual(case_rest_n, atm_freq)

        # Change XML and check that atmquery reflects this change
        new_rest_n = "6"
        self.assertNotEqual(new_rest_n, case_rest_n)
        run_cmd_assert_result(self, f"./xmlchange REST_N={new_rest_n}", from_dir=case)
        run_cmd_assert_result(self, "./case.setup", from_dir=case)
        new_case_rest_n = run_cmd_assert_result(self, "./xmlquery REST_N --value", from_dir=case)
        new_atm_freq    = run_cmd_assert_result(self, "./atmquery Frequency --value", from_dir=case)
        self.assertEqual(new_case_rest_n, new_rest_n)
        self.assertEqual(new_atm_freq, new_rest_n)

    ###########################################################################
    def test_atmchanges_are_preserved(self):
    ###########################################################################
        """
        Test that atmchanges are not lost when eamxx setup is called
        """
        case = self._create_test("SMS.ne30_ne30.F2010-SCREAMv1 --no-build".split())

        self._chg_atmconfig([("atm_log_level", "trace")], case)

    ###########################################################################
    def test_manual_atmchanges_are_lost(self):
    ###########################################################################
        """
        Test that manual atmchanges are lost when eamxx setup is called
        """
        case = self._create_test("SMS.ne30_ne30.F2010-SCREAMv1 --no-build".split())

        # An unbuffered atmchange is semantically the same as a manual edit
        self._chg_atmconfig([("atm_log_level", "trace")], case, buff=False, expect_lost=True)

    ###########################################################################
    def test_reset_atmchanges_are_lost(self):
    ###########################################################################
        """
        Test that atmchanges are lost when eamxx setup is called after the
        atmchanges have been reset
        """
        case = self._create_test("SMS.ne30_ne30.F2010-SCREAMv1 --no-build".split())

        # The change is made, then "atmchange --reset" is run before
        # case.setup, so the change must not survive
        self._chg_atmconfig([("atm_log_level", "trace")], case, reset=True, expect_lost=True)

    ###########################################################################
    def test_manual_atmchanges_are_not_lost_hack_xml(self):
    ###########################################################################
        """
        Test that manual atmchanges are not lost when eamxx setup is called if
        xml hacking is enabled.
        """
        case = self._create_test("SMS.ne30_ne30.F2010-SCREAMv1 --no-build".split())

        run_cmd_assert_result(self, "./xmlchange SCREAM_HACK_XML=TRUE", from_dir=case)

        self._chg_atmconfig([("atm_log_level", "trace")], case, buff=False)

    ###########################################################################
    def test_multiple_atmchanges_are_preserved(self):
    ###########################################################################
        """
        Test that multiple atmchanges are not lost when eamxx setup is called
        """
        case = self._create_test("SMS.ne30_ne30.F2010-SCREAMv1 --no-build".split())

        self._chg_atmconfig([("atm_log_level", "trace"), ("output_to_screen", "true")], case)

    ###########################################################################
    def test_atmchanges_are_preserved_testmod(self):
    ###########################################################################
        """
        Test that atmchanges are not lost when eamxx setup is called when that
        parameter is impacted by an active testmod
        """
        # Take the last dot-separated field of the last listed test name
        # (presumably the default machine_compiler suffix -- TODO confirm)
        test_list = run_cmd_assert_result(
            self, "../CIME/Tools/list_e3sm_tests cime_tiny", from_dir=CIME_SCRIPTS_DIR)
        def_mach_comp = test_list.splitlines()[-1].split(".")[-1]

        case = self._create_test(f"SMS.ne30_ne30.F2010-SCREAMv1.{def_mach_comp}.scream-scream_example_testmod_atmchange --no-build".split())

        self._chg_atmconfig([("cubed_sphere_map", "84")], case)

    ###########################################################################
    def test_atmchanges_on_arrays(self):
    ###########################################################################
        """
        Test that atmchange works for array data
        """
        case = self._create_test("SMS.ne30_ne30.F2010-SCREAMv1 --no-build".split())

        self._chg_atmconfig([("surf_mom_flux", "40.0,2.0")], case)

###############################################################################
def parse_command_line(args, desc):
###############################################################################
    """
    Parse custom args for this test suite. Will delete our custom args from
    sys.argv so that only args meant for unittest remain.

    args : full argument vector; args[0] is the program name (only its
           basename is shown in the usage text), args[1:] are parsed.
    desc : description text shown in the help output.

    Returns the argparse namespace holding this suite's own options
    (currently only "machine"). Unrecognized arguments are written back to
    sys.argv[1:] for unittest to consume.
    """
    help_str = \
"""
{0} [TEST] [TEST]
OR
{0} --help

\033[1mEXAMPLES:\033[0m
    \033[1;32m# Run basic pylint and doctests for everything \033[0m
    > {0}

""".format(Path(args[0]).name)

    parser = argparse.ArgumentParser(
        usage=help_str,
        description=desc,
        formatter_class=argparse.ArgumentDefaultsHelpFormatter)

    parser.add_argument("-m", "--machine",
                        help="Provide machine name. This is required for full (not dry) runs")

    # Parse the arguments we were handed instead of implicitly re-reading
    # sys.argv; for the usual call with args=sys.argv this is identical.
    parsed_args, py_ut_args = parser.parse_known_args(args[1:])
    sys.argv[1:] = py_ut_args

    return parsed_args

###############################################################################
def scripts_tests(machine=None):
###############################################################################
    """
    Record the requested machine (if any) in CONFIG, then hand control to
    unittest.

    machine : optional machine name. When given, it must be accepted by
              is_machine_supported, otherwise expect() reports the error.
    """
    if machine:
        supported = is_machine_supported(machine)
        expect(supported, "Machine {} is not supported".format(machine))
        # Test parameters are kept in the module-level CONFIG dict
        CONFIG["machine"] = machine

    unittest.main(verbosity=2)

###############################################################################
def _main_func(desc):
###############################################################################
    """
    Script entry point: parse the command line, then run the test suite with
    the parsed options passed as keyword arguments.
    """
    cli_options = parse_command_line(sys.argv, desc)
    scripts_tests(**vars(cli_options))

# Only run the suite when executed as a script; the module docstring is used
# as the description text.
if __name__ == "__main__":
    _main_func(__doc__)
