import numpy as np

# Downsample a raw, C-ordered Nin^3 float32 grid to Nout^3 by averaging every
# non-overlapping (Nin/Nout)^3 cube of cells.  The input is streamed in slabs
# of whole x-slices so that each raw slab read stays within ``maxmem`` bytes.

Nout = 2048
maxmem = 4 * 1024**3  # 4 GB, in bytes

Nin = 6144
infname = "/global/cfs/cdirs/sobs/www/users/websky/ICs/Fvec_7700Mpc_n6144_nb30_nt16"
outfname = "/global/cfs/cdirs/sobs/www/users/websky/ICs/Fvec_7700Mpc_n6144_nb30_nt16_no2048"


def _block_mean(cube, nper):
    """Return the 3-D array *cube* averaged over non-overlapping nper**3 blocks.

    Every axis of *cube* must be divisible by *nper*.  The axes are reduced one
    at a time, starting with the last (fastest-varying) one.  This means each
    later step works on data that is already *nper* times smaller, and only
    reshaped views are used (no transposed copies).
    """
    nx, ny, nz = cube.shape
    mx, my, mz = nx // nper, ny // nper, nz // nper
    cube = cube.reshape(nx, ny, mz, nper).mean(axis=3)  # average z
    cube = cube.reshape(nx, my, nper, mz).mean(axis=2)  # average y
    cube = cube.reshape(mx, nper, my, mz).mean(axis=1)  # average x
    return cube


def downsample(infname, outfname, nin, nout, maxmem, dtype=np.float32):
    """Block-average a raw nin^3 grid in *infname* down to nout^3 in *outfname*.

    The input is read as a flat, C-ordered [x, y, z] array of *dtype*.  The
    output is written in the same layout and dtype.

    Args:
        infname: path of the raw binary input grid.
        outfname: path of the raw binary output grid (overwritten).
        nin: cells per side of the input grid.
        nout: cells per side of the output grid; must divide *nin*.
        maxmem: byte budget for one raw input slab.  Peak usage is roughly
            (1 + 1/(nin//nout)) times this, because the z-averaged copy of the
            slab coexists with it briefly.
        dtype: element type of the input (and output) data.

    Raises:
        ValueError: if *nout* does not evenly divide *nin*, or *maxmem* cannot
            hold the nin//nout x-slices needed for one output slice.
        EOFError: if the input file holds fewer than nin**3 elements.
    """
    if nout <= 0 or nin % nout:
        raise ValueError(f"nout={nout} must be a positive divisor of nin={nin}")
    nper = nin // nout

    itemsize = np.dtype(dtype).itemsize
    slicemem = itemsize * nin * nin          # bytes in one input x-slice
    nslabout_max = maxmem // slicemem // nper
    if nslabout_max < 1:
        raise ValueError(
            f"maxmem={maxmem} bytes cannot hold {nper} input slices "
            f"of {slicemem} bytes each"
        )
    # A multiple of nper, so every slab (including the last, since nper | nin)
    # maps onto a whole number of output x-slices.
    nslabin_max = nslabout_max * nper

    with open(infname, "rb") as infile, open(outfname, "wb") as outfile:
        for firstin in range(0, nin, nslabin_max):  # loop over x slabs
            nslabin = min(nslabin_max, nin - firstin)
            print(firstin + nslabin - 1, nin)  # progress: last slice in slab

            ncellin = nslabin * nin * nin
            slab = np.fromfile(infile, dtype=dtype, count=ncellin)
            if slab.size != ncellin:
                raise EOFError(
                    f"{infname}: expected {ncellin} elements for x-slices "
                    f"{firstin}..{firstin + nslabin - 1}, got {slab.size}"
                )

            out = _block_mean(slab.reshape(nslabin, nin, nin), nper)
            # Keep the output dtype equal to the input dtype, even for integer
            # input where mean() would otherwise promote to float64.
            out.astype(dtype, copy=False).tofile(outfile)


if __name__ == "__main__":
    downsample(infname, outfname, Nin, Nout, maxmem)
