import contextlib

import cdms2


def _nlat(mv):
    """Return the number of latitude points of *mv*.

    Uses the latitude axis reported by ``mv.getLatitude()``. If no latitude
    axis is found, fall back to the length of the second axis. That was the
    original heuristic, which only matches latitude for data whose first axis
    is a leading dimension such as time, e.g. (time, lat, lon).
    """
    lat = mv.getLatitude()
    if lat is not None:
        return len(lat)
    return len(mv.getAxisList()[1])


def _regrid_onto(src, target, regrid_tool, regrid_method):
    """Regrid *src* onto the grid of *target* and carry over *src*'s units."""
    regridded = src.regrid(
        target.getGrid(), regridTool=regrid_tool, regridMethod=regrid_method
    )
    # Units are copied explicitly, presumably because regrid() does not
    # preserve them. Skip the copy when *src* has no units instead of raising
    # AttributeError.
    if hasattr(src, "units"):
        regridded.units = src.units
    return regridded


def regrid_to_lower_res(mv1, mv2, regrid_tool, regrid_method):
    """Regrid transient variable toward lower resolution of two variables.

    Args:
        mv1: First variable (must provide getLatitude/getAxisList/getGrid/
            regrid, e.g. a cdms2 TransientVariable).
        mv2: Second variable, same requirements as *mv1*.
        regrid_tool: Passed through as ``regridTool`` to ``regrid()``.
        regrid_method: Passed through as ``regridMethod`` to ``regrid()``.

    Returns:
        ``(mv1_reg, mv2_reg)`` in the same order as the inputs. The variable
        with more latitude points is regridded onto the other one's grid. The
        other variable is returned unchanged (the same object). On a tie,
        *mv2* is regridded onto *mv1*'s grid.
    """
    # Use nlat to decide data resolution: a higher number means higher data
    # resolution. For the difference plot, regrid toward lower resolution.
    if _nlat(mv1) <= _nlat(mv2):
        return mv1, _regrid_onto(mv2, mv1, regrid_tool, regrid_method)
    return _regrid_onto(mv1, mv2, regrid_tool, regrid_method), mv2


# Test (file1) and reference (file2) SST climatology files to compare.
file1 = '/global/cfs/cdirs/e3sm/www/zhang40/tests_e3sm_diags_issue561/lat_lon/SST_CL_HadISST/HadISST_CL-SST-ANN-global_test.nc'
file2 = '/global/cfs/cdirs/e3sm/www/zhang40/tests_e3sm_diags_issue561/lat_lon/SST_CL_HadISST/HadISST_CL-SST-ANN-global_ref.nc'

# Read each variable, then close its file immediately instead of leaking
# the handle. The variable name differs in case between the two files.
with contextlib.closing(cdms2.open(file1)) as fin1:
    var1 = fin1('SST')
with contextlib.closing(cdms2.open(file2)) as fin2:
    var2 = fin2('sst')

var1_reg, var2_reg = regrid_to_lower_res(var1, var2, "esmf", "conservative")
# NOTE: the difference is file2 minus file1 (reference minus test).
diff = var2_reg - var1_reg
outfile = 'new_diff.nc'
# Guarantee the output file is closed (and flushed) even if write() raises.
with contextlib.closing(cdms2.open(outfile, 'w')) as fout:
    fout.write(diff)

