import xarray
import numpy
import matplotlib.pyplot as plt
from compass.step import Step
class Viz(Step):
    """
    A step for visualizing a cross-section through the internal wave

    The step reads ``output.nc`` and, for the first and last time index,
    writes depth sections of normal velocity, temperature, salinity,
    vertical velocity and layer thickness along the edges whose ``xEdge``
    equals the median ``xEdge``, plus a cell-mean temperature profile.
    """

    # default matplotlib figure size (inches) shared by every figure
    _figsize = (6.4, 4.8)

    def __init__(self, test_case):
        """
        Create the step

        Parameters
        ----------
        test_case : compass.TestCase
            The test case this step belongs to
        """
        super().__init__(test_case=test_case, name='viz')

        self.add_input_file(filename='output.nc',
                            target='../forward/output.nc')
        self.add_output_file('uNormal_depth_section_t0.png')
        self.add_output_file('pt_depth_section_t0.png')
        self.add_output_file('sa_depth_section_t0.png')
        self.add_output_file('layerThickness_depth_section_t0.png')

    def run(self):
        """
        Run this step of the test case

        Prints a message and returns without plotting if the dataset has no
        ``Time`` dimension.
        """
        # the context manager closes the file on every exit path, including
        # the early return below
        with xarray.open_dataset('output.nc') as ds:
            if 'Time' not in ds.dims:
                print('Dataset missing time dimension')
                return

            tend = ds.sizes['Time'] - 1
            # with a single time step the first and last index coincide;
            # plot it only once
            time_indices = [0] if tend == 0 else [0, tend]
            for j in time_indices:
                self._plot_time_slice(ds.isel(Time=j), j)

    def _plot_time_slice(self, ds1, j):
        """
        Write all figures for one time slice

        Parameters
        ----------
        ds1 : xarray.Dataset
            The dataset at a single time index

        j : int
            The time index, used as the suffix of the output file names
        """
        ds1 = ds1.sortby('yEdge')

        nCells = ds1.sizes['nCells']
        nVertLevels = ds1.sizes['nVertLevels']

        # the section is made of the edges lying at the median x coordinate
        xEdge = ds1.xEdge
        xEdge_mid = numpy.median(xEdge)
        edgeMask_x = numpy.equal(xEdge, xEdge_mid)

        # interface depths: start at the sea surface and step down by each
        # layer's thickness (missing thickness counts as zero)
        zInterface = numpy.zeros((nCells, nVertLevels + 1))
        zInterface[:, 0] = ds1.ssh.values
        for zIndex in range(nVertLevels):
            thickness = ds1.layerThickness.isel(nVertLevels=zIndex)
            thickness = thickness.fillna(0.)
            zInterface[:, zIndex + 1] = \
                zInterface[:, zIndex] - thickness.values

        # layer mid-depths: halfway between the bounding interfaces
        zMid = (zInterface[:, :-1]
                + numpy.divide(zInterface[:, 1:] - zInterface[:, :-1], 2.))

        # ------------------------------------------------------------------
        # Plot edge-centered normal velocity
        # ------------------------------------------------------------------
        # lateral boundaries of uNormal at cell centers for the x-section
        cellsOnEdge_x = ds1.cellsOnEdge[edgeMask_x, :]
        yEdges = self._section_y_edges(cellsOnEdge_x.values,
                                       ds1.yCell.values)

        # the interface depths of the first cell are used for the whole
        # section
        zInterfaces_mesh, yEdges_mesh = numpy.meshgrid(zInterface[0, :],
                                                       yEdges)

        normalVelocity_xmesh = ds1.normalVelocity[edgeMask_x, :].values
        # symmetric color limits so that zero sits at the middle of 'RdBu'
        cmax = numpy.max(numpy.abs(normalVelocity_xmesh))
        self._plot_section(yEdges_mesh, zInterfaces_mesh,
                           normalVelocity_xmesh, 'uNormal (m/s)',
                           'uNormal_depth_section_t{}.png'.format(j),
                           cmap='RdBu', vmin=-1. * cmax, vmax=cmax)

        # ------------------------------------------------------------------
        # Plot cell-centered variables
        # ------------------------------------------------------------------
        # zero-based index of the first cell on each section edge, skipping
        # the first edge so there is one cell between consecutive edges
        cellIndex = numpy.subtract(cellsOnEdge_x[1:, 0], 1)
        yEdge_x = ds1.yEdge[edgeMask_x]
        zInterfaces_mesh, yCells_mesh = numpy.meshgrid(zInterface[0, :],
                                                       yEdge_x)

        # extract every field before plotting so a missing variable fails
        # before any cell-centered figure is written
        layerThickness_x = ds1.layerThickness[cellIndex, :]
        temperature = ds1.temperature
        temperature_z = temperature.mean(dim='nCells')
        zMid_z = zMid.mean(axis=0)
        temperature_x = temperature[cellIndex, :]
        salinity_x = ds1.salinity[cellIndex, :]
        w_x = ds1.vertVelocityTop[cellIndex, :]

        # cell-mean temperature profile
        plt.figure(figsize=self._figsize, dpi=100)
        plt.plot(temperature_z.values, zMid_z)
        plt.xlabel('PT (C)')
        plt.ylabel('z (m)')
        plt.savefig('pt_depth_t{}.png'.format(j),
                    bbox_inches='tight', dpi=200)
        plt.close()

        self._plot_section(yCells_mesh, zInterfaces_mesh,
                           temperature_x.values, 'PT (C)',
                           'pt_depth_section_t{}.png'.format(j))
        self._plot_section(yCells_mesh, zInterfaces_mesh,
                           salinity_x.values, 'SA (g/kg)',
                           'sa_depth_section_t{}.png'.format(j))
        # NOTE(review): units assumed to be m/s, matching the uNormal
        # plot -- confirm against the model output metadata
        self._plot_section(yCells_mesh, zInterfaces_mesh,
                           w_x.values, 'w (m/s)',
                           'w_depth_section_t{}.png'.format(j))
        self._plot_section(yCells_mesh, zInterfaces_mesh,
                           layerThickness_x.values, 'h (m)',
                           'layerThickness_depth_section_t{}.png'.format(j))

    @staticmethod
    def _section_y_edges(cells_on_edge, y_cell):
        """
        Compute the y boundaries of the columns of an edge-based section

        Parameters
        ----------
        cells_on_edge : numpy.ndarray
            For each edge of the section, the indices of its two
            neighboring cells.  Indices are offset by one, and ``0`` is
            treated as "no cell on this side"

        y_cell : numpy.ndarray
            The y coordinate of each cell

        Returns
        -------
        y_edges : numpy.ndarray
            Array of length ``len(cells_on_edge) + 1`` with the column
            boundaries
        """
        y_edges = numpy.zeros(len(cells_on_edge) + 1)
        for i, cells in enumerate(cells_on_edge):
            cell0 = cells[0]
            cell1 = cells[1]
            if cell1 == 0:
                # only the first neighbor exists
                y_lo = y_hi = y_cell[cell0 - 1]
            elif cell0 == 0:
                # only the second neighbor exists
                y_lo = y_hi = y_cell[cell1 - 1]
            else:
                y0 = y_cell[cell0 - 1]
                y1 = y_cell[cell1 - 1]
                y_lo = min(y0, y1)
                y_hi = max(y0, y1)
            # entry i + 1 is overwritten by the next edge's lower bound,
            # so only the last edge's upper bound survives
            y_edges[i] = y_lo
            y_edges[i + 1] = y_hi
        return y_edges

    def _plot_section(self, y_mesh, z_mesh, values, title, filename,
                      cmap='viridis', vmin=None, vmax=None):
        """
        Save one pseudocolor depth section to ``filename``

        Parameters
        ----------
        y_mesh : numpy.ndarray
            y coordinates of the mesh corners, divided by 1e3 for the
            ``y (km)`` axis

        z_mesh : numpy.ndarray
            z coordinates of the mesh corners

        values : numpy.ndarray
            The field to plot

        title : str
            The title placed above the colorbar

        filename : str
            The image file to write

        cmap : str, optional
            The matplotlib colormap name

        vmin, vmax : float, optional
            Color limits; ``None`` leaves the limit to matplotlib
        """
        plt.figure(figsize=self._figsize, dpi=100)
        plt.pcolormesh(numpy.divide(y_mesh, 1e3), z_mesh, values,
                       cmap=cmap, vmin=vmin, vmax=vmax)
        plt.xlabel('y (km)')
        plt.ylabel('z (m)')
        cbar = plt.colorbar()
        cbar.ax.set_title(title)
        plt.savefig(filename, bbox_inches='tight', dpi=200)
        plt.close()