Source code for txpipe.noise_maps

from .base_stage import PipelineStage
from .maps import TXBaseMaps
from .data_types import (
    ShearCatalog,
    TomographyCatalog,
    MapsFile,
    LensingNoiseMaps,
    ClusteringNoiseMaps,
    HDFFile,
)
import numpy as np
from .utils.mpi_utils import mpi_reduce_large
from .utils import (
    choose_pixelization,
    read_shear_catalog_type,
    rename_iterated,
)
from .shear_calibration import Calibrator
from ceci.config import StageParameter


[docs] class TXSourceNoiseMaps(TXBaseMaps): """ Generate realizations of shear noise maps with random rotations This takes the shear catalogs and tomography and randomly spins the shear values in it, removing the shear signal and leaving only shape noise """ name = "TXSourceNoiseMaps" inputs = [ ("shear_catalog", ShearCatalog), ("shear_tomography_catalog", TomographyCatalog), # We get the pixelization info from the diagnostic maps ("mask", MapsFile), ] outputs = [ ("source_noise_maps", LensingNoiseMaps), ] config_options = { "chunk_rows": StageParameter(int, 100000, msg="Number of rows to process in each chunk."), "lensing_realizations": StageParameter(int, 30, msg="Number of lensing noise realizations to generate."), "true_shear": StageParameter(bool, False, msg="Whether to use true shear values for noise maps."), } # instead of reading from config we match the basic maps def choose_pixel_scheme(self): with self.open_input("mask", wrapper=True) as maps_file: pix_info = maps_file.read_map_info("mask") #overwrite nside if it differs from the mask's native nside if self.config['nside'] != pix_info["nside"]: pix_info["nside"] = self.config['nside'] return choose_pixelization(**pix_info) def prepare_mappers(self, pixel_scheme): import healsparse as hsp read_shear_catalog_type(self) with self.open_input("mask", wrapper=True) as maps_file: mask = maps_file.read_mask("mask", degrade_nside=self.config["nside"]) with self.open_input("shear_tomography_catalog", wrapper=True) as f: nbin_source = f.file["tomography"].attrs["nbin"] # Mapping from 0 .. nhit - 1 to healpix indices reverse_map = mask.valid_pixels # Get a mapping from healpix indices to masked pixel indices # This reduces memory usage compared to the full healpix array # We could use the healsparse array itself here # Possiby with a recarray for different realizations # TODO: test and implement this index_map = hsp.HealSparseMap.make_empty( pixel_scheme.nside_coverage, pixel_scheme.nside, dtype=int, sentinel=-1, ) index_map[reverse_map] = np.arange(reverse_map.size) # Number of unmasked pixels npix = reverse_map.size lensing_realizations = self.config["lensing_realizations"] # lensing g1, g2 G1 = np.zeros((npix, nbin_source, lensing_realizations)) G2 = np.zeros((npix, nbin_source, lensing_realizations)) # lensing weight GW = np.zeros((npix, nbin_source)) return (npix, G1, G2, GW, index_map, reverse_map, nbin_source) def data_iterator(self): with self.open_input("shear_catalog", wrapper=True) as f: shear_cols, renames = f.get_primary_catalog_names(self.config["true_shear"]) it = self.combined_iterators( self.config["chunk_rows"], "shear_catalog", "shear", shear_cols, "shear_tomography_catalog", "tomography", ["bin"], ) return rename_iterated(it, renames) def accumulate_maps(self, pixel_scheme, data, mappers): npix, G1, G2, GW, index_map, _, _ = mappers lensing_realizations = self.config["lensing_realizations"] source_bin = data["bin"] # Get the pixel index for each object and convert # to the reduced index ra = data["ra"] dec = data["dec"] orig_pixels = pixel_scheme.ang2pix(ra, dec) pixels = index_map[orig_pixels] # Pull out some columns we need n = len(ra) w = data["weight"] # Pre-weight the g1 values so we don't have to # weight each realization again g1 = data["g1"] * w g2 = data["g2"] * w # random rotations of the g1, g2 values phi = np.random.uniform(0, 2 * np.pi, (n, lensing_realizations)) c = np.cos(phi) s = np.sin(phi) g1r = c * g1[:, np.newaxis] + s * g2[:, np.newaxis] g2r = -s * g1[:, np.newaxis] + c * g2[:, np.newaxis] for i in range(n): sb = source_bin[i] # Skip objects we don't use if sb < 0: continue # convert to the index in the partial space pix = pixels[i] # The sentinel value for pixels is -1 if pix < 0: continue # build up the rotated map for each bin G1[pix, sb, :] += g1r[i] G2[pix, sb, :] += g2r[i] GW[pix, sb] += w[i] def finalize_mappers(self, pixel_scheme, mappers): # only one mapper here - we call its finalize method # to collect everything npix, G1, G2, GW, index_map, reverse_map, nbin_source = mappers lensing_realizations = self.config["lensing_realizations"] # Sum everything at root if self.comm is not None: mpi_reduce_large(G1, self.comm, max_chunk_count=2**26) # fiducial is 2**30 mpi_reduce_large(G2, self.comm, max_chunk_count=2**26) mpi_reduce_large(GW, self.comm, max_chunk_count=2**26) if self.rank != 0: del G1, G2, GW # build up output maps = {} # only master gets full stuff if self.rank != 0: return maps # We need to calibrate the shear maps cal, _ = Calibrator.load(self.get_input("shear_tomography_catalog")) for b in range(nbin_source): for i in range(lensing_realizations): bin_mask = np.where(GW[:, b] > 0) g1 = G1[:, b, i] / GW[:, b] g2 = G2[:, b, i] / GW[:, b] g1, g2 = cal[b].apply(g1, g2, subtract_mean=False) maps["source_noise_maps", f"rotation_{i}/g1_{b}"] = ( reverse_map[bin_mask], g1[bin_mask], ) maps["source_noise_maps", f"rotation_{i}/g2_{b}"] = ( reverse_map[bin_mask], g2[bin_mask], ) return maps
[docs] class TXLensNoiseMaps(TXBaseMaps): """ Generate lens density noise realizations using random splits This randomly assigns each galaxy to one of two bins and uses the different between the halves to get a noise estimate. """ name = "TXLensNoiseMaps" inputs = [ ("lens_tomography_catalog", TomographyCatalog), ("photometry_catalog", HDFFile), ("mask", MapsFile), ] outputs = [ ("lens_noise_maps", ClusteringNoiseMaps), ] config_options = { "chunk_rows": StageParameter(int, 100000, msg="Number of rows to process in each chunk."), "clustering_realizations": StageParameter(int, 1, msg="Number of clustering noise realizations to generate."), "mask_in_weights": StageParameter(bool, False, msg="Whether to include mask in weight calculations."), } # instead of reading from config we match the basic maps def choose_pixel_scheme(self): with self.open_input("mask", wrapper=True) as maps_file: pix_info = maps_file.read_map_info("mask") #overwrite nside if it differs from the mask's native nside if self.config['nside'] != pix_info["nside"]: pix_info["nside"] = self.config['nside'] return choose_pixelization(**pix_info) def prepare_mappers(self, pixel_scheme): import healsparse as hsp with self.open_input("mask", wrapper=True) as maps_file: mask = maps_file.read_mask("mask", degrade_nside=self.config["nside"]) with self.open_input("lens_tomography_catalog", wrapper=True) as f: nbin_lens = f.file["tomography"].attrs["nbin"] # Mapping from 0 .. nhit - 1 to healpix indices reverse_map = mask.valid_pixels # Get a mapping from healpix indices to masked pixel indices # This reduces memory usage compared to the full healpix array # We could use the healsparse array itself here # Possiby with a recarray for different realizations # TODO: test and implement this index_map = hsp.HealSparseMap.make_empty( pixel_scheme.nside_coverage, pixel_scheme.nside, dtype=int, sentinel=-1, ) index_map[reverse_map] = np.arange(reverse_map.size) # Number of unmasked pixels npix = reverse_map.size clustering_realizations = self.config["clustering_realizations"] ngal_split = np.zeros((npix, nbin_lens, clustering_realizations, 2), dtype=np.int32) # TODO: Clustering weights go here return (npix, ngal_split, index_map, reverse_map, mask, nbin_lens) def data_iterator(self): it = self.combined_iterators( self.config["chunk_rows"], "photometry_catalog", "photometry", ["ra", "dec"], "lens_tomography_catalog", "tomography", ["bin"], ) return it def accumulate_maps(self, pixel_scheme, data, mappers): npix, ngal_split, index_map, _, _, _ = mappers clustering_realizations = self.config["clustering_realizations"] # Tomographic bin lens_bin = data["bin"] # Get the pixel index for each object and convert # to the reduced index ra = data["ra"] dec = data["dec"] orig_pixels = pixel_scheme.ang2pix(ra, dec) pixels = index_map[orig_pixels] n = len(ra) # randomly select a half for each object split = np.random.binomial(1, 0.5, (n, clustering_realizations)) for i in range(n): lb = lens_bin[i] # Skip objects we don't use if lb < 0: continue # convert to the index in the partial space pix = pixels[i] # The sentinel value for pixels is -1 if pix < 0: continue for j in range(clustering_realizations): ngal_split[pix, lb, j, split[i]] += 1 def finalize_mappers(self, pixel_scheme, mappers): # only one mapper here - we call its finalize method # to collect everything npix, ngal_split, index_map, reverse_map, mask, nbin_lens = mappers clustering_realizations = self.config["clustering_realizations"] # Sum everything at root if self.comm is not None: mpi_reduce_large(ngal_split, self.comm) if self.rank != 0: del ngal_split # build up output maps = {} # only master gets full stuff if self.rank != 0: return maps for b in range(nbin_lens): for i in range(clustering_realizations): # We have computed the first half already, # and we have the total map from an earlier stage half1 = np.zeros(npix) half2 = np.zeros_like(half1) if self.config["mask_in_weights"]: half1 = ngal_split[:, b, i, 0] half2 = ngal_split[:, b, i, 1] else: half1 = (ngal_split[:, b, i, 0]) / mask[reverse_map] half2 = (ngal_split[:, b, i, 1]) / mask[reverse_map] # Convert to overdensity. I thought about # using half the mean from the full map to reduce # noise, but thought that might add covariance # to the two maps, and this shouldn't be that noisy # half1 and half2 are already weighted by the mask, so we just need the average mu1 = np.average(half1[mask[reverse_map] > 0]) mu2 = np.average(half2[mask[reverse_map] > 0]) # This will produce some mangled sentinel values # but they will be masked out rho1 = (half1 - mu1) / mu1 rho2 = (half2 - mu2) / mu2 # Save four maps - density splits and ngal splits maps["lens_noise_maps", f"split_{i}/rho1_{b}"] = (reverse_map, rho1) maps["lens_noise_maps", f"split_{i}/rho2_{b}"] = (reverse_map, rho2) maps["lens_noise_maps", f"split_{i}/ngal1_{b}"] = (reverse_map, half1) maps["lens_noise_maps", f"split_{i}/ngal2_{b}"] = (reverse_map, half2) return maps
[docs] class TXExternalLensNoiseMaps(TXLensNoiseMaps): """ Generate lens density noise realizations using random splits of an external catalog This randomly assigns each galaxy to one of two bins and uses the different between the halves to get a noise estimate. """ name = "TXExternalLensNoiseMaps" inputs = [ ("lens_tomography_catalog", TomographyCatalog), ("lens_catalog", HDFFile), ("mask", MapsFile), ] def data_iterator(self): it = self.combined_iterators( self.config["chunk_rows"], "lens_catalog", "lens", ["ra", "dec"], "lens_tomography_catalog", "tomography", ["bin"], ) return it
# These functions will be jitted and used in the TXNoiseMapsJax class below. # Note that, quoting the JAX docs: # Unlike NumPy in-place operations such as x[idx] += y, if multiple indices # refer to the same location, all updates will be applied (NumPy would only # apply the last update, rather than applying all updates.) # So this is not just the raw equivalent of GN[masked_pixels, masked_source_bin] += masked_gnr # This is better than original numpy! def GN_add(GN, masked_pixels, masked_source_bin, masked_gnr): return GN.at[masked_pixels, masked_source_bin, :].add(masked_gnr) def GW_add(GW, masked_pixels, masked_source_bin, masked_weights): return GW.at[masked_pixels, masked_source_bin].add(masked_weights) def ngal_split_add(ngal_split, pixels_lb_mask, clustering_realizations, split_lb_mask, lens_bin_lb_mask): from jax import numpy as jnp return ngal_split.at[ pixels_lb_mask, lens_bin_lb_mask, jnp.arange(clustering_realizations), split_lb_mask, ].add(1)
[docs] class TXNoiseMapsJax(PipelineStage): """ Generate noise realisations of lens and source maps using JAX This is a JAX/GPU version of the noise map stages. Need to update to stop assuming lens and source are the same and split into two stages. """ name = "TXNoiseMapsJax" inputs = [ ("shear_catalog", ShearCatalog), ("lens_tomography_catalog", TomographyCatalog), ("shear_tomography_catalog", TomographyCatalog), # We get the pixelization info from the diagnostic maps ("mask", MapsFile), ("lens_maps", MapsFile), ] outputs = [ ("source_noise_maps", LensingNoiseMaps), ("lens_noise_maps", ClusteringNoiseMaps), ] config_options = { "chunk_rows": StageParameter(int, 4000000, msg="Number of rows to process in each chunk."), "lensing_realizations": StageParameter(int, 30, msg="Number of lensing realizations."), "clustering_realizations": StageParameter(int, 1, msg="Number of clustering realizations."), "seed": StageParameter(int, 0, msg="Random seed for reproducibility."), } def run(self): from jax import numpy as jnp from jax.ops import index from jax import random, jit, device_get, device_put from .utils import choose_pixelization raise ValueError( "This code needs rewriting because source_bin and lens_bin now have the same name in the tomo files." ) # get the number of bins. nbin_source, nbin_lens, ngal_maps, mask, map_info = self.read_inputs() pixel_scheme = choose_pixelization(**map_info) lensing_realizations = self.config["lensing_realizations"] clustering_realizations = self.config["clustering_realizations"] # The columns we will need shear_cols = ["ra", "dec", "weight", "g1", "g2"] # Make the iterators chunk_rows = self.config["chunk_rows"] it = self.combined_iterators( chunk_rows, "shear_catalog", "shear", shear_cols, "shear_tomography_catalog", "tomography", ["bin"], "lens_tomography_catalog", "tomography", ["bin"], ) # Get a mapping from healpix indices to masked pixel indices index_map = np.zeros(pixel_scheme.npix, dtype=jnp.int32) - 1 counter = 0 for i in range(pixel_scheme.npix): if mask[i] > 0: index_map[i] = counter counter += 1 # Number of unmasked pixels npix = counter # The memory usage of this class can get high, so we report what is expected here, so # if a crash happens a few moments later it's clear why. if self.rank == 0: nmaps = nbin_source * (2 * lensing_realizations + 1) + nbin_lens * clustering_realizations * 2 nGB = (npix * nmaps * 8) / 1000.0**3 print(f"Allocating maps of size {nGB:.2f} GB") # lensing g1, g2. To start with we accumalate these, and normalize them later G1 = jnp.zeros((npix, nbin_source, lensing_realizations)) G2 = jnp.zeros((npix, nbin_source, lensing_realizations)) # lensing weights per pixel, which we later use to normalize g1, g2 GW = jnp.zeros((npix, nbin_source)) # clustering map - we start by generating a random split in the number count # maps, and later convert this to overdensity maps ngal_split = jnp.zeros((npix, nbin_lens, clustering_realizations, 2), dtype=np.int32) # TODO: Clustering weights go here # Initialize PRNG key for Jax with a seed, which can either be # chosen by the user or generated with numpy if self.config["seed"] == 0: seed = np.random.randint(2**32) else: seed = self.config["seed"] # ensure that every MPI rank has a different seed, and set up the JAX RNG # system seed += self.rank key = random.PRNGKey(seed) # apply the just-in-time compilation to these functions; this means that # they are compiled on first use, for the data types given to them, then # subsequent times the compiled version is used. They are used on each chunk # of the data as we loop through it. GN_add_jit = jit(GN_add) GW_add_jit = jit(GW_add) ngal_split_add_jit = jit(ngal_split_add, static_argnums=(2,)) # Loop through the data # TODO: this whole bit should be a single jax.jit kernel for speed for s, e, data in it: # Number of objects in this chunk n = e - s print(f"Rank {self.rank} processing rows {s} - {e}") # Send data to GPU source_bin = device_put(data["bin"]) lens_bin = device_put(data["bin"]) weights = device_put(data["weight"]) g1 = device_put(data["g1"]) * weights g2 = device_put(data["g2"]) * weights # Compute which pixel each object is in ra = data["ra"] dec = data["dec"] orig_pixels = device_put(pixel_scheme.ang2pix(ra, dec)) pixels = device_put(index_map[orig_pixels]) # This is how you do RNG with JAX. We use subkey for this RNG operation # and then key is passed forward for the next operation key, subkey = random.split(key) # randomly select a half for each lens bin object # random.bernoulli returns True/False arrays. Convert that # to an integer array (both on the GPU) by multiplying by 1 split = 1 * random.bernoulli(subkey, 0.5, (n, clustering_realizations)) # random rotations of the g1, g2 values key, subkey = random.split(key) phi = random.uniform(subkey, shape=(lensing_realizations, n), minval=0, maxval=2 * jnp.pi) cos = jnp.cos(phi) sin = jnp.sin(phi) g1r = jnp.transpose(cos * g1 + sin * g2) g2r = jnp.transpose(-sin * g1 + cos * g2) # masks showing which pixels to fill in pix_mask = pixels >= 0 sb_mask = (source_bin >= 0) & pix_mask # jax.jit doesn't like masks inside masks so we have to calculate these in # advance instead of doing that within the jitted functions masked_pixels = pixels[sb_mask] masked_source_bin = source_bin[sb_mask] masked_g1r = g1r[sb_mask] masked_g2r = g2r[sb_mask] masked_weights = weights[sb_mask] lb_mask = sb_mask & (lens_bin >= 0) pixels_lb_mask = pixels[lb_mask] lens_bin_lb_mask = lens_bin[lb_mask] split_lb_mask = split[lb_mask] # Accumulate into the total noise maps. Under JAX this can't be an in-place # operation, so we have to replace G1 each time. Under the hood this may # be happening in-place, I think it depends. G1 = GN_add_jit(G1, masked_pixels, masked_source_bin, masked_g1r) G2 = GN_add_jit(G2, masked_pixels, masked_source_bin, masked_g2r) GW = GW_add_jit(GW, masked_pixels, masked_source_bin, masked_weights) ngal_split = ngal_split_add_jit( ngal_split, pixels_lb_mask, clustering_realizations, split_lb_mask, lens_bin_lb_mask, ) # TODO: Currently breaks with clustering_realizations > 1 # Now we have finished looping through the data, we sum everything over the # different processes to the root process if self.comm is not None: import mpi4jax from mpi4py import MPI G1, token = mpi4jax.reduce(G1, MPI.SUM, root=0) G2, token = mpi4jax.reduce(G2, MPI.SUM, root=0, token=token) G2, token = mpi4jax.reduce(GW, MPI.SUM, root=0, token=token) ngal_split, token = mpi4jax.reduce(ngal_split, MPI.SUM, root=0, token=token) if self.rank != 0: del G1, G2, GW, ngal_split # Save the maps on the root processor if self.rank == 0: print("Saving maps") # First we save the source noise maps outfile = self.open_output("source_noise_maps", wrapper=True) # The top section has the metadata in it group = outfile.file.create_group("maps") # TODO: sort out nbin vs nbin_source, nbin_lens group.attrs["nbin_source"] = nbin_source group.attrs["nbin"] = nbin_source group.attrs["lensing_realizations"] = lensing_realizations # Get outputs from GPU G1 = device_get(G1) G2 = device_get(G2) GW = device_get(GW) metadata = {**self.config, **map_info} # We save only the hit pixels pixels = np.where(mask > 0)[0] # Loop through each realization of each bin for b in range(nbin_source): for i in range(lensing_realizations): # Normalize this bin with the weights bin_mask = np.where(GW[:, b] > 0) g1 = G1[:, b, i] / GW[:, b] g2 = G2[:, b, i] / GW[:, b] # and save g1 and g2 maps to the file. outfile.write_map(f"rotation_{i}/g1_{b}", pixels[bin_mask], g1[bin_mask], metadata) outfile.write_map(f"rotation_{i}/g2_{b}", pixels[bin_mask], g2[bin_mask], metadata) # Similar for the lensing noise maps outfile = self.open_output("lens_noise_maps", wrapper=True) group = outfile.file.create_group("maps") group.attrs["nbin_lens"] = nbin_lens group.attrs["clustering_realizations"] = clustering_realizations for b in range(nbin_lens): for i in range(clustering_realizations): # We have computed the first half already, # and we have the total map from an earlier stage half1 = ngal_split[:, b, i, 0] half2 = ngal_split[:, b, i, 1] # Convert to overdensity. I thought about # using half the mean from the full map to reduce # noise, but thought that might add covariance # to the two maps, and this shouldn't be that noisy mu1 = np.average(half1, weights=mask[pixels]) mu2 = np.average(half2, weights=mask[pixels]) # This will produce some mangled sentinel values # but they will be masked out rho1 = (half1 - mu1) / mu1 rho2 = (half2 - mu2) / mu2 # Write both overdensity and count maps # for each bin for each split outfile.write_map(f"split_{i}/rho1_{b}", pixels, rho1, metadata) outfile.write_map(f"split_{i}/rho2_{b}", pixels, rho2, metadata) # counts outfile.write_map(f"split_{i}/ngal1_{b}", pixels, half1, metadata) outfile.write_map(f"split_{i}/ngal2_{b}", pixels, half2, metadata) def read_inputs(self): with self.open_input("mask", wrapper=True) as f: mask = f.read_mask("mask", degrade_nside=self.config["nside"]) # pixelization etc map_info = f.read_map_info("mask") with self.open_input("lens_maps", wrapper=True) as f: nbin_lens = f.file["maps"].attrs["nbin"] ngal_maps = [f.read_map(f"ngal_{b}") for b in range(nbin_lens)] with self.open_input("shear_tomography_catalog") as f: nbin_source = f["tomography"].attrs["nbin"] sz1 = f["tomography/bin"].size with self.open_input("lens_tomography_catalog") as f: sz2 = f["tomography/bin"].size if sz1 != sz2: raise ValueError( "Lens and source catalogs are different sizes in " "TXNoiseMaps. In this case run TXSourceNoiseMaps " "and TXLensNoiseMaps separately." ) return nbin_source, nbin_lens, ngal_maps, mask, map_info