#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Fri Mar 24 07:54:51 2023

@author: pkooloth
"""

"""
1D Energy Balance model with optimal control.
This script should be run serially (the problem is 1D).
Computes the time-varying solution of the optimally controlled EBM under prescribed forcing.
Requires: Dedalus-v2 (https://dedalus-project.readthedocs.io/en/v2_master/)
"""

import numpy as np
import time
import matplotlib.pyplot as plt
import matplotlib.colors as colors

from dedalus import public as de
from dedalus.extras.plot_tools import quad_mesh, pad_limits
from dedalus.extras import flow_tools

import logging
logger = logging.getLogger(__name__)

# ---------------------------------------------------------------------------
# Run parameters
# ---------------------------------------------------------------------------
# ncc_cutoff is passed to the IVP; tolerance is not referenced elsewhere in this script.
ncc_cutoff, tolerance = 1e-7, 1e-6
f = 1           # multiplier applied to the solar constant parameter

b = 205         # greenhouse constant, passed to the problem as parameter 'b'
N = 512         # Chebyshev basis size
tf = 140 * 365  # number of iterations to run (140 years); also sets the forcing time scale via tf*dt

# ---------------------------------------------------------------------------
# Chebyshev basis on x in [0, 1] with 3/2 dealiasing, and the 1D domain
# ---------------------------------------------------------------------------
x_basis = de.Chebyshev('x', N,
                       interval=(0, 1),
                       dealias=3 / 2)
domain = de.Domain([x_basis], np.float64)

def GHForcing(*args):
    """Greenhouse forcing ramp, evaluated by Dedalus for ``F(t, x)``.

    Positional arguments (as supplied through the 'F' parseable):
        args[0]: time operand; its scalar value is read via ``.value``.
        args[1]: x operand; its ``.data`` is read, but the forcing is uniform in x.

    Returns the scalar forcing. It rises linearly from 0 at t = 0 to 3.5 at
    t = tf*dt/2, then falls linearly at the same rate. ``tf`` and ``dt`` are
    module-level globals looked up at call time.
    """
    time_now = args[0].value
    _ = args[1].data  # read for parity with the call signature; not used
    inverse_half = (tf/2*dt)**(-1)
    rising = time_now < tf*dt/2
    return (3.5*time_now*inverse_half if rising
            else 3.5 - 3.5*(time_now - tf/2*dt)*inverse_half)
    

def Forcing(*args, domain=domain, F=GHForcing):
    """
    Build a Dedalus GeneralFunction on `domain` that calls F with *args,
    working in grid layout ('g').

    The defaults bind the module's domain and GHForcing at definition time.
    """
    general_function = de.operators.GeneralFunction(
        domain,
        layout='g',
        func=F,
        args=args,
    )
    return general_function


# Expose Forcing under the name 'F' so equation strings can call F(t, x).
de.operators.parseables['F'] = Forcing

# Problem
# Optimally controlled 1D energy-balance model on x in [0, 1]. x is used as
# sin(latitude): the flux is written as (1 - x**2)*dx(T), the insolation shape as
# 1 + del_sol*(1 - 3*x**2), and the plots convert x to degrees with arcsin.
# Variables:
#   T, Tt, Tx        temperature, Tt = dt(T), and flux Tx = (1 - x**2)*dx(T)
#   Tu, Tut, Tux     the same trio for Tu. Tu's equation is the T equation
#                    differentiated with respect to u, i.e. the sensitivity dT/du.
#   alpha, alpha_T   albedo and its analytic derivative d(alpha)/dT (algebraic in T)
#   u                albedo control
#   Ti               target temperature, held constant in time by dt(Ti) = 0
problem = de.IVP(domain, variables=['T','Tt', 'Tx', 'alpha', 'u', 'Tu', 'Tux', 'alpha_T','Ti', 'Tut'],ncc_cutoff=ncc_cutoff)
problem.parameters['a'] = 2.1  # coefficient a in the outgoing radiation a*T - b (alternative value noted: 1.151)
problem.parameters['b'] = b  # greenhouse constant b in a*T - b
problem.parameters['C'] = 3  # ocean heat capacity
# Constant-albedo alternative (disabled; albedo is the 'alpha' variable below):
#problem.parameters['alpha'] = 0.31 #albedo
problem.parameters['D'] = 0.6  # diffusivity
problem.parameters['del_sol']  = 0.482/2  # amplitude of the (1 - 3*x**2) insolation term
problem.parameters['solar_const'] = 0.25*1332*f  # a quarter of 1332, scaled by f
problem.parameters['w12'] = 10  # weight ratio; divides the control u below

# Earlier steady-state form in cos/sin latitude coordinates (disabled):
#problem.add_equation("a * T  - D/cos(x)*Cp*dx(cos(x)*Tx) = b + (1 - alpha) * solar_const  *  (1  + del_sol *  ( 1 - 3 * ( sin(x) )**2 ) ) ")

# Temperature: C*dT/dt = -(a*T - b) + D*dx((1 - x**2)*dx(T)) + F(t,x)
#              + (1 - alpha - u)*solar_const*(1 + del_sol*(1 - 3*x**2)).
# F(t,x) is the GeneralFunction registered under the parseable name 'F'.
problem.add_equation("C*Tt + a*T - D*dx(Tx)  = -b + F(t,x) + (1-alpha-u)*solar_const*(1+del_sol*(1-3*(x**2)))")
# Sensitivity Tu: the equation above differentiated with respect to u
# (d(alpha)/du = alpha_T*Tu, and d(-u)/du = -1).
problem.add_equation("C*dt(Tu) + a*Tu - D*dx(Tux) = (-alpha_T*Tu-1)*solar_const*(1+del_sol*(1-3*(x**2)))")

# First-order reduction: fluxes and time derivatives as separate variables.
problem.add_equation("Tx  -  (1-x**2)*dx(T)  = 0")
problem.add_equation("Tux  -  (1-x**2)*dx(Tu)  = 0")
problem.add_equation("Tt  -  dt(T)  = 0")
problem.add_equation("Tut  -  dt(Tu)  = 0")
# The target state does not evolve.
problem.add_equation("dt(Ti) = 0")


# Alternative control that cancels the forcing directly (disabled):
#problem.add_equation("u = F(t,x)/(solar_const*(1+del_sol*(1-3*(x**2))))")
# Control: proportional to the mismatch (T - Ti) times the sensitivity Tu, scaled by
# (0.3/15)**2 and divided by the weight w12. The trailing comment is a disabled extra term.
problem.add_equation("u  = -((T-Ti)*Tu* (0.3/15)**2 )/w12") #+ Tt*Tut/(2e+1)**2
# Albedo: smooth switch between 0.62 (cold) and 0.32 (warm) centred at T = -10.
problem.add_equation("alpha = 0.62 - 0.3*(1 + exp(-10*(T+10)))**(-1)")
# alpha_T = d(alpha)/dT of the expression above: -3*e/(1 + e)**2 with e = exp(-10*(T+10)).
problem.add_equation("alpha_T  =  -3*(1 + exp(-10*(T+10)))**(-2)*(exp(-10*(T+10)))")
# tanh albedo alternative (disabled):
#problem.add_equation("alpha = 0.4 + 0.3*tanh(T-270)")

# No-flux conditions at both ends of x for T and for the sensitivity Tu.
problem.add_bc("Tx(x='left') = 0")
problem.add_bc("Tx(x='right') = 0")

problem.add_bc("Tux(x='left') = 0")
problem.add_bc("Tux(x='right') = 0")


# Build solver
# RK443 IMEX timestepper. The run stops after tf iterations or 600 s of wall-clock
# time, whichever comes first.
solver = problem.build_solver(de.timesteppers.RK443)
solver.stop_wall_time = 600
solver.stop_iteration = tf

# Initial conditions
x = domain.grid(0, scales=1)  # grid points at the base (scale-1) resolution
T = solver.state['T']
Ti = solver.state['Ti']
Tu = solver.state['Tu']
Tx = solver.state['Tx']
Tut = solver.state['Tut']
Tt = solver.state['Tt']
alpha = solver.state['alpha']
alpha_T = solver.state['alpha_T']
u = solver.state['u']


T['g'] = np.loadtxt('T1.out')    # initial state
Ti['g'] = np.loadtxt('Tf1.out')  # target state

# Make the auxiliary variables agree with their defining equations at t = 0.
# Tx is defined by Tx = (1 - x**2)*dx(T), so dx(T) alone is not consistent.
T.differentiate(0, out=Tx)
Tx.set_scales(1)
Tx['g'] = (1 - x**2) * Tx['g']

T.set_scales(1)
u.set_scales(1)
alpha['g'] = 0.62 - 0.3*(1 + np.exp(-10*(T['g']+10)))**(-1)
# d(alpha)/dT, written as -3*s*(1 - s) with s = 1/(1 + exp(-10*(T+10))). This is
# algebraically equal to the alpha_T equation, but avoids inf/inf for very cold T.
ice_switch = 1.0 / (1.0 + np.exp(-10*(T['g']+10)))
alpha_T.set_scales(1)
alpha_T['g'] = -3 * ice_switch * (1 - ice_switch)

# Control and sensitivity start at zero.
Tu.set_scales(1)
Tu['g'] = 0*T['g']
u['g'] = 0*T['g']

# Store data for final plot
T_list = [np.copy(T['g'])]
u_list = [np.copy(u['g'])]
Tu_list = [np.copy(Tu['g'])]
t_list = [solver.sim_time]
Tm_list = [np.mean(T['g'])]

# Position of the T = -10 crossing (presumably the ice edge; 1.0 when nowhere is below it):
# x at the coldest grid point still warmer than -10. argmin runs over the warm subset,
# so map its index back into the full grid before indexing x.
if np.min(T['g']) + 10.01 < 0:
    warm = np.flatnonzero(T['g'] + 10 > 0)
    s_list = [x[warm[np.argmin(T['g'][warm])]]]
else:
    s_list = [1.0]
um_list = [0.0]
logger.info('Initial ice-edge position: %s', s_list)
#Monitoring convergence
# Track max |dT/dt| every 100 iterations. The magnitude is formed as sqrt(Tt*Tt),
# because max(Tt) alone can read ~0 while T is still decreasing everywhere. That
# would trigger the convergence break too early.
flow = flow_tools.GlobalFlowProperty(solver, cadence=100)
flow.add_property("sqrt(Tt*Tt)", name='dQ')

# Main loop
dt = 1e-3  # fixed time step; GHForcing also reads this global to set its time scale
try:
    logger.info('Starting loop')
    while solver.proceed:
        solver.step(dt)
        # Sampling, logging and the convergence test share the flow-property cadence.
        if solver.iteration % 100 != 0:
            continue

        # Return to scale 1 so grid data line up with x (built with scales=1).
        T.set_scales(1)
        u.set_scales(1)
        Tu.set_scales(1)
        T_list.append(np.copy(T['g']))
        Tu_list.append(np.copy(Tu['g']))
        u_list.append(np.copy(u['g']))
        t_list.append(solver.sim_time)
        Tm_list.append(np.mean(T['g']))
        # L2 norm of the control over x
        um_list.append(np.sqrt(np.trapz(u['g']**2, x)))

        if np.min(T['g']) + 10.001 < 0:
            # x of the coldest grid point still warmer than -10. Map the index
            # within the warm subset back to an index into x.
            warm = np.flatnonzero(T['g'] + 10 > 0)
            s_list.append(x[warm[np.argmin(T['g'][warm])]])
        else:
            s_list.append(1.0)

        max_dQ = flow.max('dQ')  # query once, reuse for logging and the stop test
        logger.info('Iteration: %i, Time: %e, dt: %e', solver.iteration, solver.sim_time, dt)
        logger.info('Max dQ = %f', max_dQ)
        logger.info('%f', T['g'][N//2])
        if max_dQ < 1e-8 and solver.iteration > 1000:
            break
except Exception:
    logger.error('Exception raised, triggering end of main loop.')
    raise
finally:
    solver.log_stats()
# Forcing at each saved time, using the same profile as GHForcing: a linear rise
# to 3.5 at t = tf*dt/2, then a linear fall at the same rate.
t_saved = np.array(t_list)
inverse_half = (tf/2*dt)**(-1)
ft = np.where(t_saved < tf*dt/2,
              3.5*t_saved*inverse_half,
              3.5 - 3.5*(t_saved - tf/2*dt)*inverse_half)
    
T.set_scales(1)

# Simulation time in years: tf iterations of size dt map to the 140-year run.
t_years = np.array(t_list)/365*1e3

np.savetxt('um.out', np.array(um_list))
np.savetxt('t.out', t_years)

# Use an explicit figure for each plot. With a non-interactive backend, plt.show()
# does not close the current figure, so a second implicit plt.plot would draw the
# forcing onto the control figure.
fig_um, ax_um = plt.subplots()
ax_um.plot(t_years, um_list)
ax_um.set_title('Variation in mean albedo control')
ax_um.set_ylabel('Mean albedo control')
ax_um.set_xlabel('Time (years)')
fig_um.savefig('u.png', dpi=300, bbox_inches="tight")
plt.show()

fig_ft, ax_ft = plt.subplots()
ax_ft.plot(t_years, ft)
ax_ft.set_title('External forcing')
ax_ft.set_ylabel(r'Forcing $(W/m^2)$')
ax_ft.set_xlabel('Time (years)')
fig_ft.savefig('forcing.png', dpi=300, bbox_inches="tight")
plt.show()
u.set_scales(1)

# Filled-contour levels for the control plots, log-spaced by decade:
# 1..9 x 1e-3, then 1..9 x 1e-2, then 0.1 (19 levels in total).
_decades = 10.0 ** np.arange(-3, -1)   # [1e-3, 1e-2]
_mantissas = np.arange(1, 10)          # 1, 2, ..., 9
w12 = np.append(np.outer(_decades, _mantissas).ravel(), 0.1)


levels = np.array([-10])  # T = -10 isotherm, drawn in cyan over the control field
t_years = np.array(t_list)/365*1e3
lat_deg = np.arcsin(x)*180/np.pi  # x is sin(latitude); plot in degrees

fig1, ax1 = plt.subplots(figsize=(6,3))
# Log color scale over the control levels. The limits are set on the LogNorm itself.
# A log scale cannot have vmin = 0, so the lower limit is the smallest level, 1e-3.
CT2 = ax1.contourf(t_years, lat_deg, np.array(u_list).T, w12, cmap=plt.cm.bone_r,
                   norm=colors.LogNorm(vmin=1e-3, vmax=0.1))
CT3 = ax1.contour(t_years, lat_deg, np.array(T_list).T, levels, colors='c')
ax1.set(ylabel=r'latitude', xlabel='Time (years)')
cbar1 = fig1.colorbar(CT2)
ax1.set_xlim([35,140])
ax1.set_yticks(np.arange(0, 90.1, 30))
ax1.set_xticks(np.arange(20, 140.1, 20))
fig1.savefig('u_history.png', dpi=300, bbox_inches="tight")

plt.show()


levels = np.array([-10])  # T = -10 isotherm, drawn in cyan over the control field
t_years = np.array(t_list)/365*1e3
lat_deg = np.arcsin(x)*180/np.pi  # x is sin(latitude); plot in degrees

# Close-up of the first 30 years. The color limits go on the LogNorm rather than
# being passed separately alongside it.
fig, ax = plt.subplots(figsize=(2,3))
CT = ax.contourf(t_years, lat_deg, np.array(u_list).T, w12, cmap=plt.cm.bone_r,
                 norm=colors.LogNorm(vmin=1e-3, vmax=0.09))
CT1 = ax.contour(t_years, lat_deg, np.array(T_list).T, levels, colors='c')
ax.set(ylabel=r'latitude', xlabel='Time (years)')
cbar = fig.colorbar(CT)
ax.set_xlim([0,30.01])
ax.set_yticks(np.arange(0, 90.1, 30))
ax.set_xticks(np.arange(0, 30.1, 30))
fig.savefig('u_history_fi.png', dpi=300, bbox_inches="tight")
plt.show()

# Optional reference run to overlay (disabled):
#Tbt = np.loadtxt('Tbt.out')


levels = np.array([-10])  # T = -10 isotherm, drawn in cyan
t_years = np.array(t_list)/365*1e3
lat_deg = np.arcsin(x)*180/np.pi  # x is sin(latitude); plot in degrees
# Temperature levels spanning the color range. The control levels (w12, 1e-3..0.1)
# would leave almost the entire temperature field unfilled. extend='both' colors
# values outside [-30, 30] with the end colors.
T_levels = np.linspace(-30, 30, 61)
fig, ax = plt.subplots()
CT = ax.contourf(t_years, lat_deg, np.array(T_list).T, levels=T_levels, cmap=plt.cm.bwr,
                 vmin=-30, vmax=30, extend='both')
CT1 = ax.contour(t_years, lat_deg, np.array(T_list).T, levels, colors='c')
#CT2 = ax.contour(np.array(t_list)/365*1e3,np.arcsin(x)*180/np.pi,np.array(Tbt).T,levels,colors='w')
ax.set(ylabel=r'latitude', xlabel='Time (years)', title='Temperature variation')
cbar = fig.colorbar(CT)
ax.set_yticks(np.arange(0, 90.1, 10))
ax.set_xticks(np.arange(0, 140.1, 20))
cbar.ax.set_ylabel(r'$^0$ C')
plt.show()

