#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Tue Jul  9 12:16:25 2024

@author: pkooloth
"""

import numpy as np
import time
import matplotlib.pyplot as plt
from mpl_toolkits.basemap import Basemap
import netCDF4 as nc

from sklearn.linear_model import Ridge

# Problem dimensions shared by the functions in this module.
nlat = 48    # latitude points of the model grid
nlon = 96    # longitude points of the model grid
N = 120      # number of patch experiments / basis patterns (rows of dF)
Nk = 5       # number of stacked response variables in dX
cp = 1008    # divisor applied to the '*msegz' fields on load

def read_data():
    """Load the Green's-function training experiments and project them onto the patch basis.

    Every response field is read from a ``.npy`` file in the working
    directory and processed in three steps:

    1. The basis patterns ``dF`` are each divided by their own maximum.
    2. Each response is projected onto that basis with a pseudo-inverse
       (least-squares) fit.
    3. Each projection is divided by the global maximum of its coefficients.

    Returns
    -------
    dX : ndarray, shape (5*N, N)
        Normalised projections stacked in the order (dT, dTo, dalb, dcod, dla).
    dFg : ndarray, shape (N, N)
        Normalised projection of ``dFg.npy`` weighted pointwise by ``coalb.npy``.
    dF : ndarray
        Basis patterns (first axis indexes the N patterns), each divided by
        its own maximum.
    std : ndarray, shape (5,)
        Scale factors applied to the blocks of ``dX``, in the same order.
        Despite the name these are maxima, not standard deviations.
    dFs : float
        Scale factor applied to ``dFg``.
    """
    npts = nlat * nlon

    dT = np.load('dmsegz.npy') / cp
    dalb = np.load('dalb.npy')
    dTo = np.load('dsst.npy')
    dcod = np.load('dcod_int.npy')
    dla = -np.load('dla.npy')

    dF = np.load('dF.npy')
    coalb = np.load('coalb.npy')
    dFtoa = np.load('dFg.npy')

    # Weight every experiment's dFg field pointwise by the single coalb map.
    dFe = np.reshape(dFtoa, (N, npts)) * np.reshape(coalb, npts)
    dFe = np.reshape(dFe, (N, nlat, nlon))

    # Divide each basis pattern by its own maximum.
    for i in range(N):
        dF[i, :, :] = dF[i, :, :] / np.max(dF[i, :, :])

    # The pseudo-inverse of the basis is the same for every projected
    # variable, so compute it (an SVD) once instead of once per variable.
    pinv_basis = np.linalg.pinv(np.reshape(dF.T, (npts, N)))

    def project(field):
        # Least-squares coefficients of each experiment's field in the basis.
        return np.matmul(pinv_basis, np.reshape(field.T, (npts, N)))

    dTg = project(dT)
    dalbg = project(dalb)
    dcodg = project(dcod)
    dTog = project(dTo)
    dlag = project(dla)
    dFg = project(dFe)

    dTs = np.max(dTg)
    dalbs = np.max(dalbg)
    dTos = np.max(dTog)
    dlas = np.max(dlag)
    dcods = np.max(dcodg)
    dFs = np.max(dFg)

    dTg = dTg / dTs
    dalbg = dalbg / dalbs
    dTog = dTog / dTos
    dlag = dlag / dlas
    dcodg = dcodg / dcods
    dFg = dFg / dFs

    std = np.array([dTs, dTos, dalbs, dcods, dlas])

    dX = np.vstack((dTg, dTog, dalbg, dcodg, dlag))

    return dX, dFg, dF, std, dFs

def compute_patch_area():
    """Return the unit-sphere area (steradians) of every cell of the 48 x 96 grid.

    Latitude and longitude centres (degrees) are read from ``lat.npy`` and
    ``lon.npy``. Grid spacing is taken from the first two entries of each.
    A cell centred at latitude ``phi`` has area
    ``|sin(phi + dphi/2) - sin(phi - dphi/2)| * |dtheta|``. Magnitudes are
    used so the result is positive whether the coordinates are stored
    ascending or descending.

    Returns
    -------
    dA : ndarray, shape (48, 96), float64
        Cell areas; every column is identical because area depends only on
        latitude.
    """
    lat = np.load('lat.npy')
    lon = np.load('lon.npy')

    dphi = (lat[1] - lat[0]) * np.pi / 180
    dtheta = (lon[1] - lon[0]) * np.pi / 180

    # Explicit indexing makes a lat array with fewer than 48 entries raise
    # IndexError instead of silently producing a smaller grid.
    phi = np.asarray(lat)[np.arange(48)] * np.pi / 180
    band = np.abs(np.sin(phi + dphi / 2) - np.sin(phi - dphi / 2)) * abs(dtheta)

    dA = np.zeros((48, 96))
    dA[:, :] = band[:, np.newaxis]
    return dA
              
def plotting_blocks():
    """Build 0/1 masks for the 120 rectangular patches of the 48 x 96 grid.

    The grid is split into 15 latitude bands and 8 longitude sectors of 12
    columns each. Band edges are ``round(k * 3.2)`` rows for k = 0..15.
    Patch ``band * 8 + sector`` is 1 inside its rectangle and 0 elsewhere.

    Returns
    -------
    ndarray, shape (120, 48, 96)
    """
    n_bands, n_sectors = 15, 8
    row_step, col_step = 3.2, 12

    edges = [int(np.round(k * row_step, decimals=0)) for k in range(n_bands + 1)]
    masks = np.zeros((n_bands * n_sectors, 48, 96))

    for band, (top, bottom) in enumerate(zip(edges[:-1], edges[1:])):
        for sector in range(n_sectors):
            left = sector * col_step
            masks[band * n_sectors + sector, top:bottom, left:left + col_step] = 1

    return masks
    
def read_testdata(std, dF):
    """Load the independent test experiment and project it onto the patch basis.

    Parameters
    ----------
    std : sequence of at least 5 floats
        Scale factors in the order (T, To, alb, cod, la), as returned by
        ``read_data``.
    dF : ndarray
        Normalised basis patterns (first axis indexes patterns), as returned
        by ``read_data``.

    Returns
    -------
    tdX : ndarray, shape (5 * dF.shape[0], 1)
        Projections stacked as (T, To, alb, cod, la), each divided by the
        matching entry of ``std``.
    tdF : ndarray
        ``dF6w.npy`` multiplied pointwise by ``tcoalb.npy``.
    """
    tdT = np.load('tdmsegz.npy') / cp
    tdalb = np.load('tdalb.npy')
    tdTo = np.load('tdsst.npy')
    tdcod = np.load('tdcod_int.npy')
    tdla = -np.load('tdla.npy')

    tcoalb = np.load('tcoalb.npy')

    # One pseudo-inverse of the basis serves all five projections.
    n_basis = dF.shape[0]
    pinv_basis = np.linalg.pinv(np.reshape(dF.T, (-1, n_basis)))
    n_points = pinv_basis.shape[1]

    fields = (tdT, tdTo, tdalb, tdcod, tdla)
    tdX = np.vstack([
        np.matmul(pinv_basis, np.reshape(field.T, (n_points, 1))) / std[k]
        for k, field in enumerate(fields)
    ])

    tdF = np.load('dF6w.npy') * tcoalb

    return tdX, tdF

def read_testdata_2xCO2(std, dF):
    """Load the 2xCO2 test experiment and project it onto the patch basis.

    Parameters
    ----------
    std : sequence of at least 5 floats
        Scale factors in the order (T, To, alb, cod, la), as returned by
        ``read_data``.
    dF : ndarray
        Normalised basis patterns (first axis indexes patterns), as returned
        by ``read_data``.

    Returns
    -------
    tdX : ndarray, shape (5 * dF.shape[0], 1)
        Projections stacked as (T, To, alb, cod, la), each divided by the
        matching entry of ``std``.
    tdF : ndarray
        Contents of ``cdF.npy``, unmodified.
    """
    # NOTE(review): SST is read from 'cdsst_temp.npy' here, while
    # read_responses_test_2xCO2 reads 'cdsst.npy' -- confirm which is intended.
    # NOTE(review): unlike read_testdata / read_testdata_dipole, tdF is not
    # multiplied by a co-albedo map ('ccoalb.npy' was loaded but never used)
    # -- confirm this is intended.
    tdT = np.load('cdmsegz.npy') / cp
    tdalb = np.load('cdalb.npy')
    tdTo = np.load('cdsst_temp.npy')
    tdcod = np.load('cdcod.npy')
    tdla = -np.load('cdla.npy')

    # Projection to patch basis: one pseudo-inverse serves all five fields.
    n_basis = dF.shape[0]
    pinv_basis = np.linalg.pinv(np.reshape(dF.T, (-1, n_basis)))
    n_points = pinv_basis.shape[1]

    # Normalisation by the training scale factors, in dX order.
    fields = (tdT, tdTo, tdalb, tdcod, tdla)
    tdX = np.vstack([
        np.matmul(pinv_basis, np.reshape(field.T, (n_points, 1))) / std[k]
        for k, field in enumerate(fields)
    ])

    tdF = np.load('cdF.npy')

    return tdX, tdF

def read_testdata_dipole(std, dF):
    """Load the dipole test experiment and project it onto the patch basis.

    Parameters
    ----------
    std : sequence of at least 5 floats
        Scale factors in the order (T, To, alb, cod, la), as returned by
        ``read_data``.
    dF : ndarray
        Normalised basis patterns (first axis indexes patterns), as returned
        by ``read_data``.

    Returns
    -------
    tdX : ndarray, shape (5 * dF.shape[0], 1)
        Projections stacked as (T, To, alb, cod, la), each divided by the
        matching entry of ``std``.
    tdF : ndarray
        ``ddF.npy`` multiplied pointwise by ``dcoalb.npy``.
    """
    # NOTE(review): unlike the other readers, 'ddmsegz.npy' is not divided
    # by cp here (read_responses_test_dipole does the same) -- confirm the
    # dipole file is already in the intended units.
    tdT = np.load('ddmsegz.npy')
    tdalb = np.load('ddalb.npy')
    tdTo = np.load('ddTo.npy')
    tdcod = np.load('ddcod.npy')
    tdla = -np.load('ddla.npy')

    tcoalb = np.load('dcoalb.npy')

    # Projection to patch basis: one pseudo-inverse serves all five fields.
    n_basis = dF.shape[0]
    pinv_basis = np.linalg.pinv(np.reshape(dF.T, (-1, n_basis)))
    n_points = pinv_basis.shape[1]

    # Normalisation by the training scale factors, in dX order.
    fields = (tdT, tdTo, tdalb, tdcod, tdla)
    tdX = np.vstack([
        np.matmul(pinv_basis, np.reshape(field.T, (n_points, 1))) / std[k]
        for k, field in enumerate(fields)
    ])

    tdF = np.load('ddF.npy') * tcoalb

    return tdX, tdF

def compute_kernel(X, y, a=5e-1):
    """Fit a ridge regression ``y ~ K @ X`` without intercept and return K.

    Columns of ``X`` and ``y`` are samples; rows are features and targets
    respectively, so the returned ``K`` is sklearn's ``coef_`` with shape
    (n_targets, n_features).

    Parameters
    ----------
    X : ndarray, (n_features, n_samples)
    y : ndarray, (n_targets, n_samples)
    a : float
        Ridge penalty ``alpha``.
    """
    model = Ridge(alpha=a, fit_intercept=False, max_iter=10**6)
    model.fit(X.T, y.T)
    return model.coef_

def get_Kp(p, K):
    """Return the reduced comprehensive kernel built from K's trailing p SVD modes.

    ``np.linalg.svd`` orders singular values from largest to smallest, so
    the last ``p`` modes kept here are the ones with the *smallest* singular
    values. ``get_modei`` indexes modes from the same end.

    Parameters
    ----------
    p : int
        Number of trailing modes to keep (p == 0 keeps all modes, because
        ``-0:`` slices the whole axis).
    K : ndarray, 2-D
        Kernel matrix.
    """
    U, S, Vh = np.linalg.svd(K, full_matrices=False)
    left = U[:, -p:]
    sigma = np.diag(S[-p:])
    right = Vh[-p:]
    return left.dot(sigma).dot(right)

def get_modei(i, K, dF, dFs, std):
    """Reconstruct on the grid the i-th SVD mode of K, counted from the end.

    ``np.linalg.svd`` returns singular values in descending order, so column
    ``-i`` is the i-th *smallest* mode (i = 1 is the smallest; i = 0 would
    select column 0, the largest).

    Parameters
    ----------
    i : int
        Mode index counted from the smallest singular value.
    K : ndarray, 2-D
        Kernel whose columns are assumed to be Nk blocks of N coefficients
        (one per response variable) -- TODO confirm against callers.
    dF : ndarray
        Basis patterns passed to ``recon_from_basis``.
    dFs : float
        Forcing scale factor.
    std : sequence of Nk floats
        Per-variable scale factors; block k is rescaled by ``dFs * std[k]``.

    Returns
    -------
    Fi : ndarray, (nlat, nlon)
        Left singular vector reconstructed with scale ``dFs``.
    ClMode : ndarray, (Nk, nlat, nlon)
        Right singular vector split into Nk blocks of N coefficients, each
        reconstructed with scale ``dFs * std[k]``.
    """
    U, S, Vh = np.linalg.svd(K, full_matrices=False)

    forcing_coeffs = U[0:N, -i]
    Fi = recon_from_basis(forcing_coeffs, dF, dFs)

    ClMode = np.zeros((Nk, nlat, nlon))
    response_coeffs = Vh[-i]   # same values as Vh.T[:, -i]
    for k in range(Nk):
        block = response_coeffs[k * N:(k + 1) * N]
        ClMode[k] = recon_from_basis(block, dF, dFs * std[k])

    return Fi, ClMode

def project_to_basis(X, dF, x_std):
    """Project one 48 x 96 field onto the 120 basis patterns and rescale.

    Solves for the least-squares basis coefficients via the pseudo-inverse of
    the (4608, 120) basis matrix built from ``dF.T``, then divides them by
    ``x_std``.

    Returns
    -------
    ndarray, shape (120, 1)
    """
    basis = np.reshape(dF.T, (4608, 120))
    target = np.reshape(X.T, (4608, 1))
    coeffs = np.linalg.pinv(basis) @ target
    return coeffs / x_std

def recon_from_basis(Xg, dF, std):
    """Rebuild a grid field from basis coefficients.

    Accumulates ``Xg[i] * dF[i] * std`` for the first N patterns onto an
    (nlat, nlon) array of zeros.

    Parameters
    ----------
    Xg : indexable of length >= N
        Basis coefficients (a flat vector or an (N, 1) column both work).
    dF : ndarray
        Basis patterns indexed along the first axis.
    std : float
        Scale factor applied to every contribution.
    """
    field = np.zeros((nlat, nlon))
    for idx in range(N):
        contribution = Xg[idx] * dF[idx, :, :] * std
        field = field + contribution
    return field

def compute_da_dT(dX):
    """Ridge-regress each 120-row block of dX onto the first block.

    The first block (rows 0-119) is the predictor. For each of the 5 blocks,
    a no-intercept ridge fit (alpha 0.5) maps it onto that block, and the
    (120, 120) coefficient matrix is stored. Block 0 is therefore
    regressed on itself.

    Returns
    -------
    ndarray, shape (5, 120, 120)
    """
    predictors = dX[0:120].T
    kernels = np.zeros((5, 120, 120))
    for block in range(5):
        targets = dX[block * 120:(block + 1) * 120].T
        model = Ridge(alpha=5e-1, fit_intercept=False, max_iter=10**6)
        kernels[block] = model.fit(predictors, targets).coef_
    return kernels

def read_responses_test():
    """Load the raw (unprojected) response fields of the independent test experiment.

    Returns
    -------
    tuple
        ``(tdT, tdalb, tdTo, tdcod, tdla)``. ``tdT`` is divided by ``cp`` and
        ``tdla`` is negated, matching the preprocessing in ``read_testdata``.
    """
    return (
        np.load('tdmsegz.npy') / cp,
        np.load('tdalb.npy'),
        np.load('tdsst.npy'),
        np.load('tdcod_int.npy'),
        -np.load('tdla.npy'),
    )

def read_responses_test_2xCO2():
    """Load the raw (unprojected) response fields of the 2xCO2 test experiment.

    Returns
    -------
    tuple
        ``(tdT, tdalb, tdTo, tdcod, tdla)``. ``tdT`` is divided by ``cp`` and
        ``tdla`` is negated.

    NOTE(review): SST comes from 'cdsst.npy' here, whereas read_testdata_2xCO2
    reads 'cdsst_temp.npy' -- confirm which file is intended.
    """
    return (
        np.load('cdmsegz.npy') / cp,
        np.load('cdalb.npy'),
        np.load('cdsst.npy'),
        np.load('cdcod.npy'),
        -np.load('cdla.npy'),
    )

def read_responses_test_dipole():
    """Load the raw (unprojected) response fields of the dipole test experiment.

    Returns
    -------
    tuple
        ``(tdT, tdalb, tdTo, tdcod, tdla)`` with ``tdla`` negated. Unlike the
        other response readers, ``tdT`` is *not* divided by ``cp``.
    """
    return (
        np.load('ddmsegz.npy'),
        np.load('ddalb.npy'),
        np.load('ddTo.npy'),
        np.load('ddcod.npy'),
        -np.load('ddla.npy'),
    )

def compute_lambda(tdXg, Kp, k, dF, dFs, tdTm):
    """Map variable block k through the kernel and divide by tdTm.

    Uses columns ``k*120:(k+1)*120`` of ``Kp`` to map the basis coefficients
    ``tdXg`` to forcing coefficients. These are reconstructed on the grid with
    scale ``dFs``, divided by ``tdTm`` over the flattened 48*96 field, and
    negated. Presumably this is a local feedback parameter (forcing per unit
    of the response ``tdTm``) -- confirm against callers.

    Returns
    -------
    ndarray, shape (48, 96)
    """
    kernel_block = Kp[:, k * 120:(k + 1) * 120]
    forcing_coeffs = np.matmul(kernel_block, tdXg)
    forcing_grid = recon_from_basis(forcing_coeffs, dF, dFs)
    per_unit = forcing_grid.reshape(48 * 96) / tdTm
    return -per_unit.reshape((48, 96))
   

def get_data():
    """Load the raw (unprojected) training response fields as one stacked array.

    Uses the same files and preprocessing as ``read_data``: ``dmsegz.npy`` is
    divided by ``cp`` and ``dla.npy`` is negated.

    Returns
    -------
    dX : ndarray
        ``np.stack((dT, dTo, dalb, dcod, dla))``: index 0 is dT, 1 is dTo,
        2 is dalb, 3 is dcod, 4 is dla.
    """
    # The basis file 'dF.npy' used to be loaded here but was never used;
    # it is no longer read.
    dT = np.load('dmsegz.npy') / cp
    dalb = np.load('dalb.npy')
    dTo = np.load('dsst.npy')
    dcod = np.load('dcod_int.npy')
    dla = -np.load('dla.npy')

    dX = np.stack((dT, dTo, dalb, dcod, dla))

    return dX

def read_data_uq(dX):
    """Project a supplied stack of raw training responses onto the patch basis.

    Variant of ``read_data`` that takes the response fields as an argument
    instead of reading them from disk. The basis (``dF.npy``) and the forcing
    fields (``coalb.npy``, ``dFg.npy``) are still read from disk.

    Parameters
    ----------
    dX : array_like, leading axis of length 5
        Raw responses in the order (dT, dTo, dalb, dcod, dla), which is the
        layout ``get_data`` produces. Values are used as given (``get_data``
        already divides dT by cp and negates dla).

    Returns
    -------
    dX : ndarray, shape (5*N, N)
        Normalised projections stacked as (dT, dTo, dalb, dcod, dla).
    dFg : ndarray, shape (N, N)
        Normalised projection of ``dFg.npy`` weighted pointwise by ``coalb.npy``.
    dF : ndarray
        Basis patterns, each divided by its own maximum.
    std : ndarray, shape (5,)
        Scale factors (maxima, not standard deviations) for the blocks of dX.
    dFs : float
        Scale factor applied to ``dFg``.
    """
    # Unpack in the order get_data stacks the fields: index 1 is dTo and
    # index 2 is dalb. Reading them the other way round silently swaps the
    # SST and albedo rows of the output dX and of std.
    dT = dX[0]
    dTo = dX[1]
    dalb = dX[2]
    dcod = dX[3]
    dla = dX[4]

    npts = nlat * nlon

    dF = np.load('dF.npy')
    coalb = np.load('coalb.npy')
    dFtoa = np.load('dFg.npy')

    # Weight every experiment's dFg field pointwise by the single coalb map.
    dFe = np.reshape(dFtoa, (N, npts)) * np.reshape(coalb, npts)
    dFe = np.reshape(dFe, (N, nlat, nlon))

    # Divide each basis pattern by its own maximum.
    for i in range(N):
        dF[i, :, :] = dF[i, :, :] / np.max(dF[i, :, :])

    # One pseudo-inverse (an SVD) of the basis serves every projection.
    pinv_basis = np.linalg.pinv(np.reshape(dF.T, (npts, N)))

    def project(field):
        # Least-squares coefficients of each experiment's field in the basis.
        return np.matmul(pinv_basis, np.reshape(field.T, (npts, N)))

    dTg = project(dT)
    dalbg = project(dalb)
    dcodg = project(dcod)
    dTog = project(dTo)
    dlag = project(dla)
    dFg = project(dFe)

    dTs = np.max(dTg)
    dalbs = np.max(dalbg)
    dTos = np.max(dTog)
    dlas = np.max(dlag)
    dcods = np.max(dcodg)
    dFs = np.max(dFg)

    dTg = dTg / dTs
    dalbg = dalbg / dalbs
    dTog = dTog / dTos
    dlag = dlag / dlas
    dcodg = dcodg / dcods
    dFg = dFg / dFs

    std = np.array([dTs, dTos, dalbs, dcods, dlas])

    dXg = np.vstack((dTg, dTog, dalbg, dcodg, dlag))

    return dXg, dFg, dF, std, dFs