#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Fri Mar 24 07:54:51 2023

@author: pkooloth
"""

"""
1D Diffusive Energy Balance Model (EBM)
This script should be run serially (the problem is 1D). It computes a steady-state solution to the 1D EBM.
Requires: Dedalus-v2 (https://dedalus-project.readthedocs.io/en/v2_master/)
"""

import numpy as np
import time
import matplotlib.pyplot as plt

from dedalus import public as de
from dedalus.extras.plot_tools import quad_mesh, pad_limits
from dedalus.extras import flow_tools

import logging
logger = logging.getLogger(__name__)

# --- Numerical and physical parameters ---------------------------------------
ncc_cutoff = 1e-7   # NCC cutoff handed to de.NLBVP below
tolerance = 1e-6    # Newton loop stops once sum(|perturbation|) <= tolerance
f = 1               # multiplier applied to the solar constant ('solar_const')

Ta = 273.16         # NOTE(review): not referenced anywhere else in this script

N = 256             # number of Legendre modes in x

# --- Bases and domain ---------------------------------------------------------
# x spans (0, 1); the plot converts it to latitude with arcsin(x), i.e. x is
# treated as sin(latitude).
x_basis = de.Legendre('x', N, interval=(0, 1))
domain = de.Domain([x_basis], np.float64)

# --- Problem ------------------------------------------------------------------
# Unknowns: temperature T, Tx = (1 - x**2) * dT/dx, and the albedo alpha, which
# is solved for as a field together with T instead of being a fixed parameter.
problem = de.NLBVP(domain, variables=['T', 'Tx', 'alpha'], ncc_cutoff=ncc_cutoff)
problem.parameters.update({
    'a': 2.1,                    # slope of the a*T - b radiation term (Planck's radiation)
    'b': 211,                    # offset of a*T - b (greenhouse effects)
    'C': 3,                      # ocean heat capacity; unused by the steady equations below
    'D': 0.6,                    # diffusivity
    'del_sol': 0.482/2,          # amplitude of the (1 - 3*x**2) insolation term
    'solar_const': 0.25*1332*f,  # insolation prefactor, scaled by the multiplier f
})

# Definition of the auxiliary variable Tx.
problem.add_equation("Tx  -  (1-x**2)*dx(T)  = 0")
# Steady energy balance: a*T and the diffusion term on the left; -b plus the
# absorbed insolation (1 - alpha)*S(x) on the right, where
# S(x) = solar_const*(1 + del_sol*(1 - 3*x**2)).
problem.add_equation(" a*T - D*dx(Tx) = -b + (1-alpha)*solar_const*(1+del_sol*(1-3*(x**2)))")
# Smooth albedo switch: tends to 0.62 for T well below -10 and 0.32 well above.
problem.add_equation("alpha = 0.62 - 0.3*(1 + exp(-10*(T+10)))**(-1)")

# Tx vanishes at both ends of the interval.
problem.add_bc("Tx(x='left') = 0")
problem.add_bc("Tx(x='right') = 0")

# --- Solver ---------------------------------------------------------------------
solver = problem.build_solver()

# --- Initial guess for the Newton iteration ------------------------------------
x = domain.grid(0, scales=1)
T = solver.state['T']
Tx = solver.state['Tx']
alpha = solver.state['alpha']

# Uniform starting temperature of 10 (degrees C, per the plot's axis label).
T.set_scales(1)
T['g'] = 10*np.ones(N)

# Make the initial Tx consistent with T, i.e. Tx = (1 - x**2) * dT/dx as the
# first problem equation requires.  Differentiating T (rather than rescaling
# Tx's own prior values) keeps this correct if the initial T is ever changed
# to a non-uniform profile.
T.differentiate('x', out=Tx)
Tx.set_scales(1)
Tx['g'] = (1 - x**2)*Tx['g']

# Initial albedo from the same formula as the 'alpha' equation in the problem.
T.set_scales(1)
alpha.set_scales(1)
alpha['g'] = 0.62 - 0.3*(1 + np.exp(-10*(T['g'] + 10)))**(-1)

# Store data for final plot
T_list = [np.copy(T['g'])]
# --- Newton iterations -----------------------------------------------------------
# `pert` aliases the solver's perturbation array; the loop relies on
# newton_iteration() updating it in place.  It is pre-filled above `tolerance`
# so at least one iteration runs.  Two guards keep a failed solve from being
# reported as converged:
#   * a non-finite check -- `nan > tolerance` is False, so a diverged Newton
#     step would otherwise silently exit the loop;
#   * an iteration cap, so a non-converging iteration cannot run forever.
max_iterations = 100
pert = solver.perturbations.data
pert.fill(1+tolerance)
start_time = time.time()
n_iterations = 0
pert_norm = np.sum(np.abs(pert))
while pert_norm > tolerance:
    if n_iterations >= max_iterations:
        raise RuntimeError(
            'Newton iteration did not converge in {} iterations '
            '(perturbation norm {})'.format(max_iterations, pert_norm))
    solver.newton_iteration()
    n_iterations += 1
    pert_norm = np.sum(np.abs(pert))
    logger.info('Perturbation norm: {}'.format(pert_norm))
    if not np.isfinite(pert_norm):
        raise RuntimeError(
            'Newton iteration diverged: non-finite perturbation norm')
end_time = time.time()
logger.info('Converged in {} iterations ({:.3f} s)'.format(
    n_iterations, end_time - start_time))

# --- Post-processing --------------------------------------------------------------
T.set_scales(1)
print(T['g'])

# x is treated as sin(latitude), so arcsin(x) in degrees is the latitude.
plt.plot(np.degrees(np.arcsin(x)), T['g'])
plt.title('Temperature profile')
plt.ylabel(r'Temperature ($^\circ$C)')
plt.xlabel(r'latitude ($\theta$)')
plt.savefig('Tprofile.png')
plt.show()

# Means over x in (0, 1).  Dedalus' Legendre grid is an interior roots grid,
# so a trapezoid rule over the grid points would drop the end intervals
# [0, x[0]] and [x[-1], 1].  The spectral integral covers the whole interval,
# and since the interval has unit length the integral equals the mean.
T_mean = T.integrate('x')['g'][0]
print(T_mean)
# x-locations where T is below -10, the midpoint of the albedo transition.
T.set_scales(1)
print(x[T['g'] < -10.0])

alpha.set_scales(1)
alpha_p = alpha.integrate('x')['g'][0]
print(alpha_p)

T.set_scales(1)
np.savetxt('T.out', T['g'])
np.savetxt('lat.out', x)

