# Source code for polaris.ocean.tasks.manufactured_solution.viz

import datetime

import cmocean  # noqa: F401
import matplotlib.pyplot as plt
import numpy as np
import xarray as xr

from polaris import Step
from polaris.ocean.resolution import resolution_to_subdir
from polaris.ocean.tasks.manufactured_solution.exact_solution import (
    ExactSolution,
)
from polaris.viz import plot_horiz_field, use_mplstyle


def _elapsed_seconds(ds):
    """
    Compute the time span covered by a model output file

    Parameters
    ----------
    ds : xarray.Dataset
        A dataset with an ``xtime`` variable whose entries are byte strings
        in the format ``YYYY-MM-DD_hh:mm:ss``

    Returns
    -------
    float
        The number of seconds between the first and last ``xtime`` entries
    """
    time_format = '%Y-%m-%d_%H:%M:%S'
    t0 = datetime.datetime.strptime(ds.xtime.values[0].decode(), time_format)
    tf = datetime.datetime.strptime(ds.xtime.values[-1].decode(), time_format)
    return (tf - t0).total_seconds()


class Viz(Step):
    """
    A step for visualizing the output from the manufactured solution
    test case

    Attributes
    ----------
    resolutions : list of float
        The resolutions of the meshes that have been run
    """

    def __init__(self, component, resolutions, taskdir):
        """
        Create the step

        Parameters
        ----------
        component : polaris.Component
            The component the step belongs to

        resolutions : list of float
            The resolutions of the meshes that have been run

        taskdir : str
            The subdirectory that the task belongs to
        """
        super().__init__(component=component, name='viz', indir=taskdir)
        self.resolutions = resolutions

        # link the mesh, initial state and forward output for every
        # resolution into this step's work directory
        for resolution in resolutions:
            mesh_name = resolution_to_subdir(resolution)
            self.add_input_file(
                filename=f'mesh_{mesh_name}.nc',
                target=f'../init/{mesh_name}/culled_mesh.nc')
            self.add_input_file(
                filename=f'init_{mesh_name}.nc',
                target=f'../init/{mesh_name}/initial_state.nc')
            self.add_input_file(
                filename=f'output_{mesh_name}.nc',
                target=f'../forward/{mesh_name}/output.nc')
        self.add_output_file('comparison.png')

    def run(self):
        """
        Run this step of the test case

        For each resolution, plot the final-time numerical SSH, the exact
        SSH at the same time, and their difference, one row per resolution,
        and save the figure to ``comparison.png``.
        """
        plt.switch_backend('Agg')
        config = self.config
        resolutions = self.resolutions
        nres = len(resolutions)

        section = config['manufactured_solution']
        eta0 = section.getfloat('ssh_amplitude')

        use_mplstyle()
        # squeeze=False keeps ``axes`` 2-D even when only one resolution
        # was run, so the ``axes[i, j]`` indexing below always works
        fig, axes = plt.subplots(nrows=nres, ncols=3,
                                 figsize=(12, 2 * nres), squeeze=False)

        # colour range of the error panels; it is set from the first
        # resolution only, so every error row shares one colour scale
        error_range = None

        for i, res in enumerate(resolutions):
            mesh_name = resolution_to_subdir(res)
            with xr.open_dataset(f'mesh_{mesh_name}.nc') as ds_mesh, \
                    xr.open_dataset(f'init_{mesh_name}.nc') as ds_init, \
                    xr.open_dataset(f'output_{mesh_name}.nc') as ds:
                exact = ExactSolution(config, ds_init)
                t = _elapsed_seconds(ds)

                # evaluate the exact solution once and reuse it
                ssh_exact = exact.ssh(t)
                ssh_model = ds.ssh.values[-1, :]

                # Comparison plots
                ds['ssh_exact'] = ssh_exact
                ds['ssh_error'] = ssh_model - ssh_exact
                if error_range is None:
                    error_range = np.max(np.abs(ds.ssh_error.values))

                cell_mask = ds_init.maxLevelCell >= 1
                patches, patch_mask = plot_horiz_field(
                    ds, ds_mesh, 'ssh', ax=axes[i, 0], cmap='cmo.balance',
                    t_index=ds.sizes["Time"] - 1, vmin=-eta0, vmax=eta0,
                    cmap_title="SSH", cell_mask=cell_mask)
                # the exact and error panels reuse the same cell patches
                plot_horiz_field(ds, ds_mesh, 'ssh_exact', ax=axes[i, 1],
                                 cmap='cmo.balance', vmin=-eta0, vmax=eta0,
                                 cmap_title="SSH", patches=patches,
                                 patch_mask=patch_mask)
                plot_horiz_field(ds, ds_mesh, 'ssh_error', ax=axes[i, 2],
                                 cmap='cmo.balance', cmap_title="dSSH",
                                 vmin=-error_range, vmax=error_range,
                                 patches=patches, patch_mask=patch_mask)

        axes[0, 0].set_title('Numerical solution')
        axes[0, 1].set_title('Analytical solution')
        axes[0, 2].set_title('Error (Numerical - Analytical)')

        # label each row with its resolution, to the left of the first column
        pad = 5
        for ax, res in zip(axes[:, 0], resolutions):
            ax.annotate(f'{res}km', xy=(0, 0.5),
                        xytext=(-ax.yaxis.labelpad - pad, 0),
                        xycoords=ax.yaxis.label, textcoords='offset points',
                        size='large', ha='right', va='center')

        fig.savefig('comparison.png', bbox_inches='tight', pad_inches=0.1)
        # release the figure's memory once it has been written
        plt.close(fig)