Source code for txpipe.masks

import numpy as np
from .utils import choose_pixelization
from .base_stage import PipelineStage
from .data_types import MapsFile
from ceci.config import StageParameter
from .mapping import degrade_healsparse

class TXBaseMask(PipelineStage):
    """
    Base class for generating binary survey masks using auxiliary input maps.

    Subclasses must implement `make_binary_mask`, which defines the logic for
    mask construction. The base class handles writing the mask to the output
    file and computing the ``area`` and ``f_sky`` metadata.
    """

    name = "TXBaseMask"
    parallel = False
    # make a mask from the auxiliary maps
    outputs = [("mask", MapsFile)]

    def run(self):
        """
        Run the pipeline stage: generate a binary mask, finalize it,
        and write to output.
        """
        mask, pixel_scheme, metadata = self.make_binary_mask()
        mask, metadata = self.finalize_mask(mask, pixel_scheme, metadata)

        with self.open_output("mask", wrapper=True) as f:
            f.file.create_group("maps")
            f.write_map("mask", mask, metadata)

    def make_binary_mask(self):
        """
        Build the binary mask. Must be overridden by subclasses.

        Returns
        -------
        mask : hsp.HealSparseMap
            Boolean mask map.
        pixel_scheme : PixelScheme
            Pixelization object matching the mask.
        metadata : dict
            Metadata to be written alongside the mask.

        Raises
        ------
        NotImplementedError
            Always, in the base class.
        """
        raise NotImplementedError(
            f"{type(self).__name__} must implement make_binary_mask"
        )

    def finalize_mask(self, mask, pixel_scheme, metadata):
        """
        Compute the survey area and sky fraction covered by the mask and
        record them in the metadata. The mask itself is returned unchanged.

        Parameters
        ----------
        mask : hsp.HealSparseMap
            binary mask map.
        pixel_scheme : PixelScheme
            Pixelization scheme used for area calculations.
        metadata : dict
            Metadata dictionary, updated in place with "area" and "f_sky".

        Returns
        -------
        mask : hsp.HealSparseMap
            The input mask, unmodified.
        metadata : dict
            Updated metadata.
        """

        # Total survey area calculation. This is simplistic:
        # TODO: account for weights / hit fractions here, and allow
        # for different lens and shear survey areas
        num_hit = float(mask.n_valid)
        area = pixel_scheme.pixel_area(degrees=True) * num_hit

        # Area of the full sphere in square degrees: 4*pi*(180/pi)**2
        full_sky_area = 41252.96125
        f_sky = area / full_sky_area

        print(f"f_sky = {f_sky}")
        print(f"area = {area:.2f} sq deg")
        metadata["area"] = area
        metadata["f_sky"] = f_sky

        return mask, metadata

    def compute_fracdet_from_hsp(self, metadata):
        """
        Computes detection fraction from an input healsparse map (higher resolution, binary)

        The map is read from the path given by the "supreme_map_file"
        configuration option.

        Parameters
        ----------
        metadata : dict
            Metadata with target resolution (read from the 'nside' key).

        Returns
        -------
        fracdet : hsp.HealSparseMap
            Fractional detection map at metadata['nside'].
        """
        import healsparse

        # load healsparse map
        supreme_map_file = self.config["supreme_map_file"]
        print(f"Generating fracdet map from {supreme_map_file}")
        spmap = healsparse.HealSparseMap.read(supreme_map_file)

        # fracdet it
        nside = metadata["nside"]
        fracdet = spmap.fracdet_map(nside)

        return fracdet

    def compute_fracdet_from_hsp_list(self, metadata):
        """
        Computes detection fraction from a list of input healsparse maps (higher resolution, binary)

        The maps are read from the paths in the "supreme_map_files"
        configuration option and combined with healsparse's
        ``max_intersection`` before the detection fraction is computed, so
        the result describes the pixels that are present in *all* the
        input maps.

        Parameters
        ----------
        metadata : dict
            Metadata with target resolution (read from the 'nside' key).

        Returns
        -------
        fracdet : hsp.HealSparseMap
            Fractional detection map at metadata['nside'].
        """
        import healsparse

        # load healsparse maps
        supreme_map_file_list = self.config["supreme_map_files"]
        print(f"Generating fracdet map from {supreme_map_file_list}")
        spmap_list = [
            healsparse.HealSparseMap.read(supreme_map_file)
            for supreme_map_file in supreme_map_file_list
        ]
        spmap = healsparse.operations.max_intersection(spmap_list)

        # fracdet it
        nside = metadata["nside"]
        fracdet = spmap.fracdet_map(nside)

        return fracdet


class TXSimpleMask(TXBaseMask):
    """
    Generate a simple binary mask using cuts on depth and bright object maps.
    """

    name = "TXSimpleMask"
    inputs = [("aux_lens_maps", MapsFile)]
    config_options = {
        "depth_cut": StageParameter(float, 23.5, msg="Depth cut for mask creation."),
        "bright_object_max": StageParameter(float, 10.0, msg="Maximum allowed bright object count."),
    }

    def make_binary_mask(self):
        """
        Create a binary mask by applying cuts to depth and bright object count maps.

        A pixel is kept when its depth is greater than the "depth_cut" option
        and its bright object count does not exceed "bright_object_max".
        Only the valid pixels of the depth map are considered.

        Returns
        -------
        mask : hsp.HealSparseMap
            Boolean mask array.
        pixel_scheme : PixelScheme
            Pixelization object.
        metadata : dict
            Metadata from input file.
        """
        import healsparse as hsp

        with self.open_input("aux_lens_maps", wrapper=True) as f:
            metadata = dict(f.file["maps"].attrs)
            bright_obj = f.read_map("bright_objects/count")
            depth = f.read_map("depth/depth")

        pixel_scheme = choose_pixelization(**metadata)
        valid_pix = depth.valid_pixels

        # Map of every pixel with a depth measurement; only used to report
        # the fraction of pixels removed by each cut.
        # np.bool_ rather than np.bool: the latter alias does not exist in
        # NumPy 1.24-1.26.
        hit = hsp.HealSparseMap.make_empty(
            depth.nside_coverage, depth.nside_sparse, dtype=np.bool_
        )
        hit.update_values_pix(valid_pix, 1)

        masks = {
            "depth": hsp.HealSparseMap.make_empty_like(hit),
            "bright_obj": hsp.HealSparseMap.make_empty_like(hit),
        }
        masks["depth"].update_values_pix(
            valid_pix, depth[valid_pix] > self.config["depth_cut"]
        )
        masks["bright_obj"].update_values_pix(
            valid_pix, ~(bright_obj[valid_pix] > self.config["bright_object_max"])
        )

        for name, hsp_map in masks.items():
            frac = 1 - (hsp_map & hit).n_valid / hit.n_valid
            print(f"Mask '{name}' removes fraction {frac:.3f} of hit pixels")

        # Overall mask: pixels that pass every individual cut
        mask = hsp.operations.and_intersection(list(masks.values()))

        return mask, pixel_scheme, metadata
class TXSimpleMaskSource(TXBaseMask):
    """
    Generate a binary mask for source galaxies using positive lensing weights across source bins.
    """

    name = "TXSimpleMaskSource"
    # make a mask from the source maps
    inputs = [("source_maps", MapsFile)]
    config_options = {
        "require_all_bins": StageParameter(
            bool,
            True,
            msg="If True keep pixels with positive lensing weight in all source bins; "
            "if False keep pixels with positive weight in any bin.",
        ),
    }

    def make_binary_mask(self):
        """
        Build a mask from the per-bin lensing weight maps.

        With "require_all_bins" set, a pixel is kept only when every source
        bin has a positive lensing weight there. Otherwise a pixel is kept
        when at least one source bin has a positive lensing weight.

        Returns
        -------
        mask : hsp.HealSparseMap
            Boolean mask array.
        pixel_scheme : PixelScheme
            Pixelization object.
        metadata : dict
            Metadata from input file.

        Raises
        ------
        ValueError
            If the input file reports no source bins.
        """
        import healsparse as hsp

        with self.open_input("source_maps", wrapper=True) as f:
            metadata = dict(f.file["maps"].attrs)
            pixel_scheme = choose_pixelization(**metadata)
            lensing_weights = [
                f.read_map(f"lensing_weight_{i}")
                for i in range(metadata["nbin_source"])
            ]

        if not lensing_weights:
            raise ValueError(
                "source_maps contains no lensing weight maps (nbin_source is 0)"
            )

        if self.config["require_all_bins"]:
            # if *all* bins have a positive lensing weight, set mask to True
            # find min lensing weight across all bins
            # if min > 0.0 -> all bins have a non-zero lensing weight so we set that pixel to True
            combined = hsp.operations.min_intersection(lensing_weights)
        else:
            # if *any* bin has a positive lensing weight, set mask to True
            # find max lensing weight across all bins
            # if max > 0.0 -> at least 1 bin has a non-zero lensing weight so we set that pixel to True
            combined = hsp.operations.max_union(lensing_weights)

        candidate_pix = combined.valid_pixels
        valid_pix = candidate_pix[combined[candidate_pix] > 0.0]

        # The output mask takes its geometry from the last weight map read.
        reference = lensing_weights[-1]
        mask = hsp.HealSparseMap.make_empty(
            reference.nside_coverage, reference.nside_sparse, np.bool_
        )
        mask[valid_pix] = True

        return mask, pixel_scheme, metadata
class TXSimpleMaskFrac(TXSimpleMask):
    """
    Make a simple mask using a depth cut and bright object cut.

    Include the fractional coverage of each pixel using a high-res survey
    property map (e.g. N-exposures for given band(s)).

    NOTE: This assumes all cuts to the mask are being done within TXpipe at
    the config Nside. If we want to apply cuts on the survey property map
    nside (potentially higher resolution), we will need to add this feature.
    """

    name = "TXSimpleMaskFrac"
    parallel = False
    # make a mask from the auxiliary maps
    inputs = [
        ("aux_lens_maps", MapsFile),
    ]
    outputs = [("mask", MapsFile)]
    config_options = {
        "depth_cut": StageParameter(float, 23.5, msg="Depth cut for mask creation."),
        "bright_object_max": StageParameter(float, 10.0, msg="Maximum allowed bright object count."),
        "supreme_map_file": StageParameter(str, "none", msg="Path to supreme map file for fracdet computation."),
        "supreme_map_files": StageParameter(list, [], msg="List of supreme map files for fracdet computation."),
        "frac_cut": StageParameter(float, 0.0, msg="Minimum fractional coverage to keep pixel in mask."),
    }

    def run(self):
        """
        Apply mask logic and replace selected pixels with fractional coverage values.

        Exactly one of the "supreme_map_file" and "supreme_map_files" options
        must be set; it selects the map(s) the fractional coverage is
        computed from.

        Raises
        ------
        ValueError
            If both, or neither, of "supreme_map_file" and
            "supreme_map_files" are specified.
        """
        import healsparse as hsp

        # Validate the configuration before doing any expensive work.
        use_single_file = self.config["supreme_map_file"] != "none"
        use_file_list = len(self.config["supreme_map_files"]) > 0
        if use_single_file and use_file_list:
            raise ValueError(
                "You have specified both supreme_map_file and supreme_map_files, pick one"
            )
        if not (use_single_file or use_file_list):
            raise ValueError(
                "You must specify one of supreme_map_file or supreme_map_files"
            )

        mask, pixel_scheme, metadata = self.make_binary_mask()
        # NOTE(review): area and f_sky are computed here from the binary mask,
        # before the frac_cut and fractional coverage are applied - confirm
        # that this is the intended definition.
        mask, metadata = self.finalize_mask(mask, pixel_scheme, metadata)

        if use_single_file:
            fracdet = self.compute_fracdet_from_hsp(metadata)
        else:
            fracdet = self.compute_fracdet_from_hsp_list(metadata)

        # boolean mask of pixels with high fracdet
        frac_cut_mask = hsp.HealSparseMap.make_empty_like(mask)
        frac_pix = fracdet.valid_pixels
        frac_cut_mask[frac_pix] = fracdet[frac_pix] > self.config["frac_cut"]

        # find the intersection of the mask cuts and the fractional coverage map
        # fracdet contains the fractional coverage value
        # mask is the boolean mask built by the cuts to aux maps
        # frac_cut_mask is a boolean mask with only the high frac pixels
        frac_det_mask = hsp.operations.product_intersection(
            [fracdet, mask, frac_cut_mask]
        )

        with self.open_output("mask", wrapper=True) as f:
            f.file.create_group("maps")
            f.write_map("mask", frac_det_mask, metadata)
class TXCustomMask(TXSimpleMaskFrac):
    """
    Make a mask from a custom list of cuts to aux maps
    (e.g depth cut and bright object cuts)

    Fracdet currently taken from aux_lens_maps
    TODO: add option to compute this from hsp map
    """

    name = "TXCustomMask"
    # make a mask from the auxiliary maps
    inputs = [
        ("aux_lens_maps", MapsFile),
    ]
    outputs = [("mask", MapsFile)]
    config_options = {
        "fracdet_name": StageParameter(str, "footprint/fracdet_griz", msg="Fracdet map name."),
        "cuts": StageParameter(list, ["footprint/fracdet_griz > 0"], msg="List of mask cuts to apply."),
        "degrade": StageParameter(bool, False, msg="Degrade resolution if input map Nside differs from config nside."),
    }

    def run(self):
        """
        Apply custom mask logic, assign fracdet values, and optionally degrade resolution.
        """
        import healsparse as hsp

        mask, pixel_scheme, metadata = self.make_binary_mask()
        mask, metadata = self.finalize_mask(mask, pixel_scheme, metadata)

        fracdet = self.get_fracdet()

        # assign fracdet for all selected pixels
        valid_pix = mask.valid_pixels
        masked_frac_map = hsp.HealSparseMap.make_empty_like(fracdet)
        masked_frac_map[valid_pix] = fracdet[valid_pix]

        if self.config["degrade"]:
            # NOTE(review): "nside" is not declared in this stage's
            # config_options; presumably it is supplied by the pipeline
            # configuration - confirm.
            if self.config["nside"] == metadata["nside"]:
                print("Nsides match, no degrading necessary")
            else:
                print(
                    f"Input Nside={metadata['nside']}, degrading to {self.config['nside']}"
                )
                masked_frac_map, metadata = self.degrade(
                    masked_frac_map, metadata, self.config["nside"]
                )

        with self.open_output("mask", wrapper=True) as f:
            f.file.create_group("maps")
            f.write_map("mask", masked_frac_map, metadata)

    def make_binary_mask(self):
        """
        Create a binary mask from arbitrary user-defined cuts on auxiliary maps.

        Each entry of the "cuts" option is a string of the form
        "<map name> <operator> <value>", where the operator is one of
        >, >=, <, <=, == or !=. The value is converted to the dtype of the
        map it is compared against. The final mask keeps the pixels that
        pass every cut.

        Returns
        -------
        mask : hsp.HealSparseMap
            Boolean mask array.
        pixel_scheme : PixelScheme
            Pixelization object.
        metadata : dict
            Metadata from input.

        Raises
        ------
        ValueError
            If no cuts are configured, or a cut string cannot be parsed.
        """
        import healsparse as hsp
        import re
        import operator

        OP_MAP = {
            ">": operator.gt,
            ">=": operator.ge,
            "<": operator.lt,
            "<=": operator.le,
            "==": operator.eq,
            "!=": operator.ne,
        }

        # get a list of quantities to be cut
        cuts = self.config["cuts"]
        if len(cuts) == 0:
            raise ValueError("At least one cut must be specified in 'cuts'")

        # Join operators and put in () so they can be used in re.split.
        # Two-character operators must come first: regex alternation takes
        # the first alternative that matches, so trying ">" ahead of ">="
        # would split "depth >= 23" into (">", "= 23").
        ordered_ops = sorted(OP_MAP, key=len, reverse=True)
        compare_pattern = "(" + "|".join(re.escape(op) for op in ordered_ops) + ")"

        masks = []
        with self.open_input("aux_lens_maps", wrapper=True) as f:
            metadata = dict(f.file["maps"].attrs)
            pixel_scheme = choose_pixelization(**metadata)

            for icut, cut_string in enumerate(cuts):
                print(cut_string)
                parts = [
                    s.strip()
                    for s in re.split(compare_pattern, cut_string, maxsplit=1)
                ]
                if len(parts) != 3 or not parts[0] or not parts[2]:
                    raise ValueError(
                        f"Could not parse cut '{cut_string}': expected "
                        "'<map name> <operator> <value>'"
                    )
                map_name, operator_string, value = parts
                compare = OP_MAP[operator_string]

                # load map
                m = f.read_map(map_name)
                valid_pix = m.valid_pixels

                # set value dtype to match the map
                value = np.array(value, dtype=m.dtype).item()

                # make hit map
                # hit is only used to print fraction of pixels removed from each cut
                # Arbitrarily choose the first map specified
                if icut == 0:
                    hit_name = map_name
                    hit = hsp.HealSparseMap.make_empty(
                        m.nside_coverage, m.nside_sparse, dtype=np.bool_
                    )
                    hit.update_values_pix(valid_pix, 1)

                # make mask for this cut
                cut_mask = hsp.HealSparseMap.make_empty(
                    m.nside_coverage, m.nside_sparse, dtype=np.bool_
                )
                cut_mask.update_values_pix(valid_pix, compare(m[valid_pix], value))
                masks.append(cut_mask)

        for cut_string, cut_mask in zip(cuts, masks):
            frac = 1 - (cut_mask & hit).n_valid / hit.n_valid
            print(
                f"Cut '{cut_string}' removes fraction {frac:.3f} of {hit.n_valid} {hit_name} pixels"
            )

        # Overall mask
        mask = hsp.operations.and_intersection(masks)

        return mask, pixel_scheme, metadata

    def get_fracdet(self):
        """
        Load fractional detection map from auxiliary maps input.

        The map name is taken from the "fracdet_name" option.

        Returns
        -------
        fracdet : hsp.HealSparseMap
            Fracdet array.
        """
        with self.open_input("aux_lens_maps", wrapper=True) as f:
            fracdet = f.read_map(self.config["fracdet_name"])
        return fracdet

    def degrade(self, mask, metadata_in, nside_out):
        """
        Degrade a high-resolution fractional mask to lower resolution using healsparse

        Parameters
        ----------
        mask : hsp.HealSparseMap
            Input healsparse mask (either bool or a fractional coverage map)
        metadata_in : dict
            Input metadata. Its 'nside' entry is taken as the input resolution.
        nside_out : int
            Desired lower-resolution nside.

        Returns
        -------
        mask_out : hsp.HealSparseMap
            Degraded mask.
        metadata_out : dict
            Shallow copy of the input metadata with 'nside' set to nside_out.
        """
        import healpy as hp
        import copy

        # use custom reduction "mask" which allows us to degrade a fractional coverage map
        mask_out = degrade_healsparse(mask, nside_out, reduction="mask")

        metadata_out = copy.copy(metadata_in)
        metadata_out["nside"] = nside_out

        # sanity check: make sure area in == area out
        area_in = np.sum(mask[mask.valid_pixels]) * hp.nside2pixarea(
            metadata_in["nside"], degrees=True
        )
        area_out = np.sum(mask_out[mask_out.valid_pixels]) * hp.nside2pixarea(
            metadata_out["nside"], degrees=True
        )
        # Compare with an absolute tolerance rather than rounding each side
        # separately, which can disagree for values straddling a rounding
        # boundary.
        assert np.isclose(area_in, area_out, rtol=0.0, atol=1e-3), (
            f"Area changed when degrading mask: {area_in} -> {area_out} sq deg"
        )

        return mask_out, metadata_out