#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Fri Mar 24 07:54:51 2023

@author: pkooloth
"""

"""
1D Energy Balance model equation
This script should be run serially (it is 1D). It computes the time-varying solution of the 1D energy balance model (EBM) with prescribed forcing.
Requires: Dedalus-v2 (https://dedalus-project.readthedocs.io/en/v2_master/)
"""

import numpy as np
import time
import matplotlib.pyplot as plt
plt.rcParams.update({'font.size': 14})
plt.rcParams['axes.autolimit_mode'] = 'round_numbers'


from dedalus import public as de
from dedalus.extras.plot_tools import quad_mesh, pad_limits
from dedalus.extras import flow_tools

import logging
logger = logging.getLogger(__name__)

# Parameters
# -- numerical settings --
ncc_cutoff = 1e-7   # handed to de.IVP as its ncc_cutoff
tolerance = 1e-6    # NOTE(review): not referenced anywhere else in this script
N = 512             # number of Chebyshev modes in x

# -- physical / run settings --
f = 1               # multiplier applied to the solar constant ('solar_const')
b = 208             # constant part of the outgoing longwave radiation (W/m^2)
tf = 300*365        # total solver iterations: 300 "years" of 365 steps each

# Bases and domain: x = sin(latitude) on [0, 1], equator to pole
x_basis = de.Chebyshev('x', N, interval=(0, 1), dealias=3/2)
domain = de.Domain([x_basis], np.float64)

F = 5  # maximum greenhouse forcing in W/m^2 (peak value used by GHForcing)
def GHForcing(*args):
    """Return the prescribed greenhouse forcing (W/m^2) at the current time.

    Called through ``Forcing`` / ``GeneralFunction`` with the arguments written
    in the equation string ``F(t,x)``: ``args[0]`` is the simulation time (its
    scalar is read via ``.value``); ``args[1]`` is the coordinate field, which
    this spatially uniform profile does not need.

    The profile is a symmetric triangle: it rises linearly from 0 at t = 0 to
    ``F`` at ``t = tf*dt/2`` and falls linearly back to 0 at ``t = tf*dt``,
    so one full run sweeps the forcing up and then back down. Beyond
    ``tf*dt`` the expression would turn negative.

    NOTE: reads the module-level globals ``F``, ``tf`` and ``dt``. ``dt`` is
    assigned just before the main loop, so evaluating this any earlier
    raises NameError.
    """
    t = args[0].value
    t_half = tf/2*dt  # time of peak forcing: half of the total run time tf*dt
    if t < t_half:
        return F*t/t_half
    return F - F*(t - t_half)/t_half
    # Alternative (parabolic) profile kept for reference:
    # return 5e-2*(t*tf*dt - t**2)/6 #5e-2*t

def Forcing(*args, domain=domain, F=GHForcing):
    """
    Wrap the forcing callable ``F`` as a Dedalus ``GeneralFunction``.

    The returned operator calls ``F`` with ``args`` in grid space ('g'
    layout) on ``domain``. By default ``F`` is ``GHForcing`` and ``domain``
    is the module-level domain.
    """
    operator_cls = de.operators.GeneralFunction
    return operator_cls(domain, layout='g', func=F, args=args)

# Register the wrapper so that ``F(...)`` inside equation strings builds it.
de.operators.parseables['F'] = Forcing

# Problem
# Coordinate: x = sin(latitude) on [0, 1] (equator to pole).
# Unknowns:
#   T        surface temperature;  Tt = dt(T);  Tx = (1 - x**2)*dT/dx (diffusive flux)
#   alpha    temperature-dependent albedo;  alpha_T = d(alpha)/dT
#   u        control term on the absorbed shortwave (fixed to 0 here: no control)
#   Tu       forced by d/du of the T equation's shortwave term, so it appears to
#            be the tangent-linear sensitivity dT/du; Tut, Tux mirror Tt, Tx
#   Ti       target state, held constant (dt(Ti) = 0)
problem = de.IVP(domain, variables=['T','Tt', 'Tx', 'alpha', 'u', 'Tu', 'Tux', 'alpha_T','Ti', 'Tut'],ncc_cutoff=ncc_cutoff)
problem.parameters['a'] = 2.1  # Planck's radiation: outgoing longwave = a*T + b
problem.parameters['b'] = b  # accounts for greenhouse effects (enters the T equation as -b)
problem.parameters['C'] = 3.  # ocean heat capacity
problem.parameters['D'] = 0.6  # diffusivity
problem.parameters['del_sol'] = 0.482/2  # amplitude of the (1 - 3*x**2) insolation term
problem.parameters['solar_const'] = 0.25*1332*f  # 0.25 * 1332 W/m^2, scaled by f
problem.parameters['w12'] = 100  # NOTE(review): not used by any equation below

# Energy balance for T, and the sensitivity equation for Tu.
problem.add_equation("C*Tt + a*T - D*dx(Tx)  = -b + F(t,x) + (1-alpha-u)*solar_const*(1+del_sol*(1-3*(x**2)))")
problem.add_equation("C*dt(Tu) + a*Tu - D*dx(Tux) = (-alpha_T*Tu-1)*solar_const*(1+del_sol*(1-3*(x**2)))")

# First-order reduction: flux and time-derivative auxiliary variables.
problem.add_equation("Tx  -  (1-x**2)*dx(T)  = 0")
problem.add_equation("Tux  -  (1-x**2)*dx(Tu)  = 0")
problem.add_equation("Tt  -  dt(T)  = 0")
problem.add_equation("Tut  -  dt(Tu)  = 0")
problem.add_equation("dt(Ti) = 0")



problem.add_equation("u  = 0") #no control
# Albedo: smooth step centred at T = -10, from ~0.62 (cold) to ~0.32 (warm).
problem.add_equation("alpha = 0.62 - 0.3*(1 + exp(-10*(T+10)))**(-1)")
# Exact derivative of the albedo above. With e = exp(-10*(T+10)):
#   d(alpha)/dT = -0.3 * 10*e/(1 + e)**2 = -3*e*(1 + e)**(-2)
# (the exponent in the numerator must match the -10 used in alpha).
problem.add_equation("alpha_T  =  -3*(1 + exp(-10*(T+10)))**(-2)*(exp(-10*(T+10)))")


# No-flux conditions at the equator (left) and the pole (right).
problem.add_bc("Tx(x='left') = 0")
problem.add_bc("Tx(x='right') = 0")

problem.add_bc("Tux(x='left') = 0")
problem.add_bc("Tux(x='right') = 0")


# Build solver (RK443 timestepper) and set its stopping criteria.
solver = problem.build_solver(de.timesteppers.RK443)
solver.stop_iteration = tf   # one full forcing ramp up and down (see GHForcing)
solver.stop_wall_time = 600  # seconds of wall-clock time
# NOTE(review): with N = 512 the 600 s wall-clock limit may be hit before tf
# iterations, truncating the forcing sweep -- confirm for long runs.

# Initial conditions
x = domain.grid(0, scales=1)  # base-resolution grid in x = sin(latitude)
T = solver.state['T']
Ti = solver.state['Ti']
Tu = solver.state['Tu']
Tx = solver.state['Tx']
Tut = solver.state['Tut']
Tt = solver.state['Tt']
alpha = solver.state['alpha']
u = solver.state['u']


T['g'] = np.loadtxt('Th.out')  # initial state (must hold N grid values)
Ti['g'] = np.loadtxt('Tf1.out')  # target state

# The 'Tx' equation constrains Tx = (1 - x**2)*dT/dx, so scale the plain
# derivative to start from a state consistent with that constraint.
T.differentiate(0, out=Tx)
Tx.set_scales(1)
Tx['g'] = (1 - x**2)*Tx['g']
T.set_scales(1)
u.set_scales(1)
# alpha enters the right-hand side of the T equation; initialise it from T
# with the same formula as the 'alpha' equation.
alpha['g'] = 0.62 - 0.3*(1 + np.exp(-10*(T['g']+10)))**(-1)
Tu['g'] = 0*T['g']
u['g'] = 0*T['g']
Tu.set_scales(1)

# Store data for final plot
T_list = [np.copy(T['g'])]
u_list = [np.copy(u['g'])]
Tu_list = [np.copy(Tu['g'])]
t_list = [solver.sim_time]
Tm_list = [np.copy(np.mean(T['g']))]

# Iceline position in x (= sin of its latitude): the coldest ice-free
# (T > -10) point when part of the domain is ice-covered, otherwise the pole.
T_grid = T['g']
if np.min(T_grid) + 10.01 < 0:
    warm = T_grid + 10 > 0
    if np.any(warm):
        # Index x with the same mask as T so the result does not depend on
        # the ice-free points forming a prefix of the grid.
        s_list = [x[warm][np.argmin(T_grid[warm])]]
    else:
        # Fully ice-covered: argmin of an empty array would raise, so place
        # the iceline at the equator.
        s_list = [0.0]
else:
    s_list = [1.0]
um_list = [0.0]
print(s_list)
# Monitoring convergence: dQ = Tt = dT/dt, evaluated every 100 iterations.
flow = flow_tools.GlobalFlowProperty(solver, cadence=100)
flow.add_property("Tt", name='dQ')

# Main loop
dt = 1e-3  # also read by GHForcing and by the post-processing below
try:
    logger.info('Starting loop')
    while solver.proceed:
        solver.step(dt)
        if solver.iteration % 600 == 0:
            # Snapshot on the base (scale 1) grid so arrays line up with x.
            T.set_scales(1)
            u.set_scales(1)
            Tu.set_scales(1)
            T_list.append(np.copy(T['g']))
            Tu_list.append(np.copy(Tu['g']))
            u_list.append(np.copy(u['g']))
            t_list.append(solver.sim_time)
            Tm_list.append(np.mean(T['g']))
            um_list.append(np.sqrt(np.trapz(u['g']**2, x)))

            # Iceline: coldest ice-free point, the pole if ice-free, the
            # equator if fully ice-covered (argmin of an empty array raises).
            # NOTE(review): the initial-state check uses 10.01, this one 10.001.
            T_grid = T['g']
            if np.min(T_grid) + 10.001 < 0:
                warm = T_grid + 10 > 0
                if np.any(warm):
                    s_list.append(x[warm][np.argmin(T_grid[warm])])
                else:
                    s_list.append(0.0)
            else:
                s_list.append(1.0)
        if solver.iteration % 100 == 0:
            dQ_max = flow.max('dQ')
            dQ_min = flow.min('dQ')
            logger.info('Iteration: %i, Time: %e, dt: %e', solver.iteration, solver.sim_time, dt)
            logger.info('Max dQ = %f', dQ_max)
            # Reset to the base grid so index N//2 is the mid-point of x.
            T.set_scales(1)
            logger.info('%f', T['g'][N//2])
            # Converged only when |dT/dt| is small everywhere. Checking the
            # maximum alone would stop while strong cooling is still going on.
            if max(abs(dQ_max), abs(dQ_min)) < 1e-8 and solver.iteration > 1000:
                break
except Exception:
    logger.error('Exception raised, triggering end of main loop.')
    raise
finally:
    solver.log_stats()
# Forcing at each saved time, using the same triangular ramp as GHForcing
# (peak F at t = tf*dt/2). It is the abscissa of the hysteresis plot.
t_arr = np.asarray(t_list, dtype=float)
t_half = tf/2*dt
ft = np.where(t_arr < t_half, F*t_arr/t_half, F - F*(t_arr - t_half)/t_half)
    
# Hysteresis of the iceline: iceline latitude (degrees) against forcing.
iceline_deg = np.arcsin(s_list)*180/np.pi
plt.plot(ft, iceline_deg)
plt.title('Hysteresis in iceline')
plt.ylabel(r'Iceline ($\theta$)')
plt.xlabel(r'F $(W/m^2)$')
plt.ylim((50, 92))
plt.xlim((0, F+0.5))
plt.savefig('hysteresis.png', dpi=300, bbox_inches="tight")
plt.show()

# NOTE: np.save appends '.npy', so these are written as 'xs.out.npy' and
# 'ft.out.npy' (binary), unlike the text '.out' inputs read with np.loadtxt.
np.save('xs.out', iceline_deg)
np.save('ft.out', ft)

#Tbt = np.loadtxt('Tbt.out')

# Space-time diagram of T with the T = -10 contour (albedo switch) overlaid.
# One "year" is 365 steps of size dt, so time in years is t/(365*dt).
t_years = np.array(t_list)/(365*dt)
lat_deg = np.arcsin(x)*180/np.pi
T_history = np.array(T_list).T
levels = np.array([-10])
fig, ax = plt.subplots()
CT = ax.contourf(t_years, lat_deg, T_history, 50, cmap=plt.cm.bwr, vmin=-30, vmax=30)
CT1 = ax.contour(t_years, lat_deg, T_history, levels, colors='k')
#CT2 = ax.contour(t_years, lat_deg, np.array(Tbt).T, levels, colors='w')
ax.set(ylabel=r'latitude', xlabel='Time (years)', title='Temperature variation')
cbar = fig.colorbar(CT)
cbar.ax.set_ylabel(r'$^0$ C')
ax.set_yticks(np.arange(0, 90.1, 10))
ax.set_xticks(np.arange(0, tf/365 + 0.1, 50))  # run length in years, ticks every 50
plt.savefig('T_hysteresis.png', dpi=300, bbox_inches="tight")
plt.show()

