Simulated Satellite Scan Strategies - MPI Example

In [ ]:
# Are you using a special reservation for a workshop?
# If so, set it here:
nersc_reservation = None

# Load common tools for all lessons.  The parent directory is placed first
# on sys.path so that lesson_tools can be imported from there.
import sys
sys.path.insert(0, "..")
from lesson_tools import (
    check_nersc,
)
# check_nersc hands back a (host, repo) pair.  nersc_host is compared against
# None here and again in the launch cell to choose between srun and mpirun.
nersc_host, nersc_repo = check_nersc(reservation=nersc_reservation)
if nersc_host is not None:
    # NOTE(review): presumably slurm_magic is what provides the %srun line
    # magic used by the launch cell -- confirm.
    %reload_ext slurm_magic
In [ ]:
%%writefile simscan_satellite_mpi.py

import toast
from toast.mpi import MPI

# Load common tools for all lessons
import sys
sys.path.insert(0, "..")
from lesson_tools import (
    fake_focalplane
)

import numpy as np
import healpy as hp
import matplotlib.pyplot as plt

from toast.todmap import (
    slew_precession_axis,
    TODSatellite,
    get_submaps_nested,
    OpPointingHpix,
    OpAccumDiag
)
from toast.map import (
    DistPixels
)

# Runtime environment: fetch the toast environment object, set its thread
# count to 2 and print its settings.

env = toast.Environment.get()
env.set_threads(2)
print(env)

# The job consists of many small observations, so the process groups are
# kept as small as possible: one process per group.

comm = toast.Comm(groupsize=1, world=MPI.COMM_WORLD)

# Build the fake focalplane, then collect the detector names in sorted
# order together with the "quat" entry of every detector.

fp = fake_focalplane()

detnames = sorted(fp)
detquat = {name: fp[name]["quat"] for name in detnames}

# Scan strategy and sampling configuration

alpha = 50.0      # opening angle of the precession, in degrees
beta = 45.0       # opening angle of the spin, in degrees
p_alpha = 25.0    # period of the precession, in minutes
p_beta = 1.25     # period of the spin, in minutes
samplerate = 8.9  # sampling rate, in Hz
hwprpm = 5.0      # rotation rate of the HWP, in RPM
nside = 32        # NSIDE of the Healpix pixelization

# The simulation covers one year, split into back-to-back observations of
# one day each.  Every observation holds one sample fewer than a full day
# at this sample rate.

nobs = 366
day_seconds = 24 * 3600.0
obs_samples = int(day_seconds * samplerate) - 1

# Daily slew of the precession axis, chosen so that the axis goes once
# around the full circle over the whole set of observations.

deg_per_day = 360.0 / nobs

# Create distributed data

data = toast.Data(comm)

# Append observations.  Observation indices are dealt out round-robin over
# the process groups: a group only builds the observations whose index,
# modulo the number of groups, equals its own group number.

for ob in range(nobs):
    # Am I in the group that has this observation?
    if (ob % comm.ngroups) != comm.group:
        # nope...
        continue
    obsname = "{:03d}".format(ob)
    # Offset of this observation's first sample.  The stride between
    # observations is obs_samples + 1.
    obsfirst = ob * (obs_samples + 1)
    # Start time in seconds, advancing by one day per observation so the
    # timestamps move forward together with the sample offset.  (A constant
    # start time would give every observation the same, overlapping times.)
    obsstart = ob * 24 * 3600.0
    tod = TODSatellite(
        comm.comm_group,
        detquat,
        obs_samples,
        firstsamp=obsfirst,
        firsttime=obsstart,
        rate=samplerate,
        spinperiod=p_beta,
        spinangle=beta,
        precperiod=p_alpha,
        precangle=alpha,
        coord="E",
        hwprpm=hwprpm
    )
    # One row of 4 values per local sample.  np.empty leaves the array
    # uninitialized; slew_precession_axis is expected to fill it in place.
    qprec = np.empty((tod.local_samples[1], 4), dtype=np.float64)
    slew_precession_axis(
        qprec,
        firstsamp=obsfirst,
        samplerate=samplerate,
        degday=deg_per_day,
    )
    tod.set_prec_axis(qprec=qprec)
    # Store the observation name alongside the TOD so the observation can be
    # identified later on.
    obs = {"name": obsname, "tod": tod}
    data.obs.append(obs)

# Run a simple Healpix pointing operator over the data (NESTED ordering,
# "IQU" mode) at the configured NSIDE.

OpPointingHpix(nside=nside, nest=True, mode="IQU").exec(data)

# Work out which pixels / submaps are hit by the data held locally.

localpix, localsm, subnpix = get_submaps_nested(data, nside)

# Set up the distributed hit map: one int64 value per pixel, spanning the
# world communicator, restricted to the local submaps and zeroed at start.

npix = 12 * nside * nside

hits = DistPixels(
    comm=data.comm.comm_world,
    size=npix,
    nnz=1,
    dtype=np.int64,
    submap=subnpix,
    local=localsm,
)
hits.data.fill(0)

# Accumulate the hits coming from the local data.

OpAccumDiag(hits=hits).exec(data)

# Reduce the map across processes.  Each process only accumulated hits from
# the observations its own group holds, so this step matters whenever the
# job runs with more than one process.

hits.allreduce()

# Save the hit map as a Healpix FITS file.

hitsfile = "simscan_satellite_hits_mpi.fits"
hits.write_healpix_fits(hitsfile)

# Plot the map.  Only world rank 0 does this: it reads the FITS file back
# in, draws a Mollweide view and saves the figure next to the FITS file.

if comm.world_rank == 0:
    plotfile = "{}.png".format(hitsfile)
    hitdata = hp.read_map(hitsfile, nest=True)
    hp.mollview(hitdata, nest=True, xsize=800, min=0, cmap="cool")
    plt.savefig(plotfile)
    plt.close()
In [ ]:
if nersc_host is not None:
    %srun -N 1 -C haswell -n 32 -c 2 --cpu_bind=cores -t 00:05:00 python simscan_satellite_mpi.py >simscan_satellite_mpi.log 2>&1
else:
    # Just use mpirun
    import subprocess as sp
    sp.check_call("mpirun -np 4 python simscan_satellite_mpi.py >simscan_satellite_mpi.log 2>&1", shell=True)
In [ ]: