diff --git a/.github/workflows/linting.yml b/.github/workflows/linting.yml index 3c49e43..6098a6e 100644 --- a/.github/workflows/linting.yml +++ b/.github/workflows/linting.yml @@ -22,6 +22,7 @@ jobs: run: | sudo apt-get update python -m pip install --upgrade pip + sudo apt-get install -y libbz2-dev pip install wheel pip install . pip install .[dev] diff --git a/.github/workflows/publish-to-pypi.yml b/.github/workflows/publish-to-pypi.yml index a616f50..01b9342 100644 --- a/.github/workflows/publish-to-pypi.yml +++ b/.github/workflows/publish-to-pypi.yml @@ -30,6 +30,7 @@ jobs: - name: Install dependencies run: | python -m pip install --upgrade pip + sudo apt-get install -y libbz2-dev pip install build - name: Build package run: python -m build diff --git a/.github/workflows/smoke-test.yml b/.github/workflows/smoke-test.yml index 2692b32..2f01df1 100644 --- a/.github/workflows/smoke-test.yml +++ b/.github/workflows/smoke-test.yml @@ -28,6 +28,7 @@ jobs: run: | sudo apt-get update python -m pip install --upgrade pip + sudo apt-get install -y libbz2-dev pip install wheel pip install . pip install .[dev] diff --git a/.github/workflows/testing-and-coverage.yml b/.github/workflows/testing-and-coverage.yml index 0705bfb..032b2cd 100644 --- a/.github/workflows/testing-and-coverage.yml +++ b/.github/workflows/testing-and-coverage.yml @@ -27,6 +27,7 @@ jobs: run: | sudo apt-get update python -m pip install --upgrade pip + sudo apt-get install -y libbz2-dev pip install wheel pip install . pip install .[dev] diff --git a/pyproject.toml b/pyproject.toml index af9b221..0ecc7e2 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -19,7 +19,6 @@ dynamic = ["version"] dependencies = [ "pz-rail-base", "click", - "pyarrow", ] # On a mac, install optional dependencies with `pip install '.[dev]'` (include the single quotes) @@ -37,9 +36,6 @@ dev = [ "pylint", # Used for static linting of files ] -[project.scripts] -rail_pipe = "rail.cli.rail_pipe.pipe_commands:pipe_cli" - [build-system] requires = [ "setuptools>=62", # Used to build and package the Python project diff --git a/src/rail/cli/rail_pipe/__init__.py b/src/rail/cli/rail_pipe/__init__.py deleted file mode 100644 index e69de29..0000000 diff --git a/src/rail/cli/rail_pipe/__main__.py b/src/rail/cli/rail_pipe/__main__.py deleted file mode 100644 index e941646..0000000 --- a/src/rail/cli/rail_pipe/__main__.py +++ /dev/null @@ -1,5 +0,0 @@ -# This file must exist with these contents -from .pipe_commands import pipe_cli - -if __name__ == "__main__": - pipe_cli() diff --git a/src/rail/cli/rail_pipe/pipe_commands.py b/src/rail/cli/rail_pipe/pipe_commands.py deleted file mode 100644 index 922ceda..0000000 --- a/src/rail/cli/rail_pipe/pipe_commands.py +++ /dev/null @@ -1,337 +0,0 @@ -from typing import Any - -import click - -from rail.core import __version__ - -from ...utils.project import RailProject -from . 
import pipe_options, pipe_scripts -from .reduce_roman_rubin_data import reduce_roman_rubin_data - - -@click.group() -@click.version_option(__version__) -def pipe_cli() -> None: - """RAIL pipeline scripts""" - - -@pipe_cli.command(name="inspect") -@pipe_options.config_file() -def inspect_command(config_file: str) -> int: - """Inspect a rail pipeline project config""" - return pipe_scripts.inspect(config_file) - - -@pipe_cli.command(name="build") -@pipe_options.config_file() -@pipe_options.flavor() -@pipe_options.force() -def build_command(config_file: str, **kwargs: Any) -> int: - """Build the ceci pipeline configuration files""" - project = RailProject.load_config(config_file) - flavors = project.get_flavor_args(kwargs.pop('flavor')) - iter_kwargs = project.generate_kwargs_iterable(flavor=flavors) - ok = 0 - for kw in iter_kwargs: - ok |= pipe_scripts.build_pipelines(project, **kw, **kwargs) - return ok - - -@pipe_cli.command(name="subsample") -@pipe_options.config_file() -@pipe_options.selection() -@pipe_options.flavor() -@pipe_options.label() -@pipe_options.run_mode() -def subsample_command(config_file: str, **kwargs: Any) -> int: - """Make a training or test data set by randomly selecting objects""" - project = RailProject.load_config(config_file) - flavors = project.get_flavor_args(kwargs.pop('flavor')) - selections = project.get_selection_args(kwargs.pop('selection')) - iter_kwargs = project.generate_kwargs_iterable(flavor=flavors, selection=selections) - ok = 0 - for kw in iter_kwargs: - ok |= pipe_scripts.subsample_data(project, **kw, **kwargs) - return ok - - -@pipe_cli.group(name="reduce") -def reduce_group() -> None: - """Reduce input data for PZ analysis""" - - -@reduce_group.command(name="roman_rubin") -@pipe_options.config_file() -@pipe_options.input_tag() -@pipe_options.input_selection() -@pipe_options.selection() -@pipe_options.run_mode() -def reduce_roman_rubin(config_file: str, **kwargs: Any) -> int: - """Reduce the roman rubin simulations for PZ analysis""" - project = RailProject.load_config(config_file) - selections = project.get_selection_args(kwargs.pop('selection')) - input_selections = kwargs.pop('input_selection') - iter_kwargs = project.generate_kwargs_iterable(selection=selections, input_selection=input_selections) - input_tag = kwargs.pop('input_tag', 'truth') - ok = 0 - for kw in iter_kwargs: - ok |= reduce_roman_rubin_data(project, input_tag, **kw, **kwargs) - return ok - - -@pipe_cli.group(name="run") -def run_group() -> None: - """Run a pipeline""" - - -@run_group.command(name="phot-errors") -@pipe_options.config_file() -@pipe_options.selection() -@pipe_options.flavor() -@pipe_options.run_mode() -def photometric_errors_pipeline(config_file: str, **kwargs: Any) -> int: - """Run the photometric errors analysis pipeline""" - project = RailProject.load_config(config_file) - flavors = project.get_flavor_args(kwargs.pop('flavor')) - selections = project.get_selection_args(kwargs.pop('selection')) - iter_kwargs = project.generate_kwargs_iterable(flavor=flavors, selection=selections) - ok = 0 - pipeline_name = "photometric_errors" - pipeline_info = project.get_pipeline(pipeline_name) - input_catalog_name = pipeline_info['InputCatalogTag'] - pipeline_catalog_config = pipe_scripts.PhotometricErrorsPipelineCatalogConfiguration( - project, source_catalog_tag=input_catalog_name, sink_catalog_tag='degraded', - ) - - for kw in iter_kwargs: - ok |= pipe_scripts.run_pipeline_on_catalog( - project, pipeline_name, 
pipeline_catalog_config, - **kw, **kwargs, - ) - return ok - - -@run_group.command(name="truth-to-observed") -@pipe_options.config_file() -@pipe_options.selection() -@pipe_options.flavor() -@pipe_options.run_mode() -def truth_to_observed_pipeline(config_file: str, **kwargs: Any) -> int: - """Run the truth-to-observed data pipeline""" - project = RailProject.load_config(config_file) - flavors = project.get_flavor_args(kwargs.pop('flavor')) - selections = project.get_selection_args(kwargs.pop('selection')) - iter_kwargs = project.generate_kwargs_iterable(flavor=flavors, selection=selections) - ok = 0 - pipeline_name = "truth_to_observed" - pipeline_info = project.get_pipeline(pipeline_name) - input_catalog_name = pipeline_info['InputCatalogTag'] - pipeline_catalog_config = pipe_scripts.SpectroscopicPipelineCatalogConfiguration( - project, - source_catalog_tag=input_catalog_name, - sink_catalog_tag='degraded', - source_catalog_basename="output_dereddener_errors.pq", - ) - - for kw in iter_kwargs: - ok |= pipe_scripts.run_pipeline_on_catalog( - project, pipeline_name, - pipeline_catalog_config, - **kw, **kwargs, - ) - return ok - - -@run_group.command(name="blending") -@pipe_options.config_file() -@pipe_options.selection() -@pipe_options.flavor() -@pipe_options.run_mode() -def blending_pipeline(config_file: str, **kwargs: Any) -> int: - """Run the blending analysis pipeline""" - project = RailProject.load_config(config_file) - flavors = project.get_flavor_args(kwargs.pop('flavor')) - selections = project.get_selection_args(kwargs.pop('selection')) - iter_kwargs = project.generate_kwargs_iterable(flavor=flavors, selection=selections) - ok = 0 - pipeline_name = "blending" - pipeline_info = project.get_pipeline(pipeline_name) - input_catalog_name = pipeline_info['InputCatalogTag'] - pipeline_catalog_config = pipe_scripts.BlendingPipelineCatalogConfiguration( - project, - source_catalog_tag=input_catalog_name, - sink_catalog_tag='degraded', - ) - - for kw in iter_kwargs: - ok |= pipe_scripts.run_pipeline_on_catalog( - project, pipeline_name, - pipeline_catalog_config, - **kw, **kwargs, - ) - return ok - - -@run_group.command(name="spec-selection") -@pipe_options.config_file() -@pipe_options.selection() -@pipe_options.flavor() -@pipe_options.run_mode() -def spectroscopic_selection_pipeline(config_file: str, **kwargs: Any) -> int: - """Run the spectroscopic selection data pipeline""" - project = RailProject.load_config(config_file) - flavors = project.get_flavor_args(kwargs.pop('flavor')) - selections = project.get_selection_args(kwargs.pop('selection')) - iter_kwargs = project.generate_kwargs_iterable(flavor=flavors, selection=selections) - ok = 0 - pipeline_name = "spec_selection" - pipeline_info = project.get_pipeline(pipeline_name) - input_catalog_name = pipeline_info['InputCatalogTag'] - pipeline_catalog_config = pipe_scripts.SpectroscopicPipelineCatalogConfiguration( - project, - source_catalog_tag=input_catalog_name, - sink_catalog_tag='degraded', - source_catalog_basename="output_dereddener_errors.pq", - ) - - for kw in iter_kwargs: - ok |= pipe_scripts.run_pipeline_on_catalog( - project, pipeline_name, - pipeline_catalog_config, - **kw, **kwargs, - ) - return ok - - -@run_group.command(name="inform") -@pipe_options.config_file() -@pipe_options.flavor() -@pipe_options.selection() -@pipe_options.run_mode() -def inform_single(config_file: str, **kwargs: Any) -> int: - """Run the inform pipeline""" - pipeline_name = "inform" - project = RailProject.load_config(config_file) - flavors = 
project.get_flavor_args(kwargs.pop('flavor')) - selections = project.get_selection_args(kwargs.pop('selection')) - iter_kwargs = project.generate_kwargs_iterable(flavor=flavors, selection=selections) - ok = 0 - for kw in iter_kwargs: - ok |= pipe_scripts.run_pipeline_on_single_input( - project, pipeline_name, - pipe_scripts.inform_input_callback, - **kw, **kwargs, - ) - return ok - - -@run_group.command(name="estimate") -@pipe_options.config_file() -@pipe_options.flavor() -@pipe_options.selection() -@pipe_options.run_mode() -def estimate_single(config_file: str, **kwargs: Any) -> int: - """Run the estimation pipeline""" - pipeline_name = "estimate" - project = RailProject.load_config(config_file) - flavors = project.get_flavor_args(kwargs.pop('flavor')) - selections = project.get_selection_args(kwargs.pop('selection')) - iter_kwargs = project.generate_kwargs_iterable(flavor=flavors, selection=selections) - ok = 0 - for kw in iter_kwargs: - ok |= pipe_scripts.run_pipeline_on_single_input( - project, pipeline_name, - pipe_scripts.estimate_input_callback, - **kw, **kwargs, - ) - return ok - - -@run_group.command(name="evaluate") -@pipe_options.config_file() -@pipe_options.flavor() -@pipe_options.selection() -@pipe_options.run_mode() -def evaluate_single(config_file: str, **kwargs: Any) -> int: - """Run the evaluation pipeline""" - pipeline_name = "evaluate" - project = RailProject.load_config(config_file) - flavors = project.get_flavor_args(kwargs.pop('flavor')) - selections = project.get_selection_args(kwargs.pop('selection')) - iter_kwargs = project.generate_kwargs_iterable(flavor=flavors, selection=selections) - ok = 0 - for kw in iter_kwargs: - ok |= pipe_scripts.run_pipeline_on_single_input( - project, pipeline_name, - pipe_scripts.evaluate_input_callback, - **kw, **kwargs, - ) - return ok - - -@run_group.command(name="pz") -@pipe_options.config_file() -@pipe_options.flavor() -@pipe_options.selection() -@pipe_options.run_mode() -def pz_single(config_file: str, **kwargs: Any) -> int: - """Run the pz pipeline""" - pipeline_name = "pz" - project = RailProject.load_config(config_file) - flavors = project.get_flavor_args(kwargs.pop('flavor')) - selections = project.get_selection_args(kwargs.pop('selection')) - iter_kwargs = project.generate_kwargs_iterable(flavor=flavors, selection=selections) - ok = 0 - for kw in iter_kwargs: - ok |= pipe_scripts.run_pipeline_on_single_input( - project, pipeline_name, - pipe_scripts.pz_input_callback, - **kw, **kwargs, - ) - return ok - - -@run_group.command(name="tomography") -@pipe_options.config_file() -@pipe_options.flavor() -@pipe_options.selection() -@pipe_options.run_mode() -def tomography_single(config_file : str, **kwargs: Any) -> int: - """Run the tomography pipeline""" - pipeline_name = "tomography" - project = RailProject.load_config(config_file) - flavors = project.get_flavor_args(kwargs.pop('flavor')) - selections = project.get_selection_args(kwargs.pop('selection')) - iter_kwargs = project.generate_kwargs_iterable(flavor=flavors, selection=selections) - ok = 0 - for kw in iter_kwargs: - ok |= pipe_scripts.run_pipeline_on_single_input( - project, pipeline_name, - pipe_scripts.tomography_input_callback, - **kw, **kwargs, - ) - return ok - - -@run_group.command(name="sompz") -@pipe_options.config_file() -@pipe_options.flavor() -@pipe_options.selection() -@pipe_options.run_mode() -def sompz_single(config_file: str, **kwargs: Any) -> int: - """Run the sompz pipeline""" - pipeline_name = "sompz" - project = RailProject.load_config(config_file) - 
flavors = project.get_flavor_args(kwargs.pop('flavor')) - selections = project.get_selection_args(kwargs.pop('selection')) - iter_kwargs = project.generate_kwargs_iterable(flavor=flavors, selection=selections) - ok = 0 - for kw in iter_kwargs: - ok |= pipe_scripts.run_pipeline_on_single_input( - project, pipeline_name, - pipe_scripts.sompz_input_callback, - **kw, **kwargs, - ) - return ok diff --git a/src/rail/cli/rail_pipe/pipe_options.py b/src/rail/cli/rail_pipe/pipe_options.py deleted file mode 100644 index a792782..0000000 --- a/src/rail/cli/rail_pipe/pipe_options.py +++ /dev/null @@ -1,205 +0,0 @@ -import enum - -import click - -from rail.cli.rail.options import ( - EnumChoice, - PartialOption, - PartialArgument, -) - - -__all__: list[str] = [ - "RunMode", - "config_path", - "force", - "flavor", - "input_dir", - "input_file", - "input_selection", - "input_tag", - "label", - "maglim", - "model_dir", - "model_name", - "model_path", - "output_dir", - "pdf_dir", - "pdf_path", - "run_mode", - "selection", - "output_file", - "truth_path", - "seed", -] - - -class RunMode(enum.Enum): - """Choose the run mode""" - - dry_run = 0 - bash = 1 - slurm = 2 - - -config_file = PartialArgument( - "config_file", - type=click.Path(), -) - - -config_path = PartialOption( - "--config_path", - help="Path to configuration file", - type=click.Path(), -) - -force = PartialOption( - "--force", - help="Overwrite existing ceci configuration files", - is_flag=True, -) - -flavor = PartialOption( - "--flavor", - help="Pipeline configuration flavor", - multiple=True, - default=["baseline"], -) - - -label = PartialOption( - "--label", - help="File label (e.g., 'test' or 'train')", - type=str, -) - - -selection = PartialOption( - "--selection", - help="Data selection", - multiple=True, - default=["gold"], -) - - -input_dir = PartialOption( - "--input_dir", - help="Input Directory", - type=click.Path(), -) - - -input_file = PartialOption( - "--input_file", - type=click.Path(), - help="Input file", -) - - -input_selection = PartialOption( - "--input_selection", - help="Data selection", - multiple=True, - default=[None], -) - - -input_tag = PartialOption( - "--input_tag", - type=str, - default=None, - help="Input Catalog tag", -) - - -maglim = PartialOption( - "--maglim", - help="Magnitude limit", - type=float, - default=25.5, -) - - -model_dir = PartialOption( - "--model_dir", - help="Path to directory with model files", - type=click.Path(), -) - - -model_path = PartialOption( - "--model_path", - help="Path to model file", - type=click.Path(), -) - - -model_name = PartialOption( - "--model_name", - help="Model Name", - type=str, -) - - -pdf_dir = PartialOption( - "--pdf_dir", - help="Path to directory with p(z) files", - type=click.Path(), -) - - -pdf_path = PartialOption( - "--pdf_path", - help="Path to p(z) estimate file", - type=click.Path(), -) - - -run_mode = PartialOption( - "--run_mode", - type=EnumChoice(RunMode), - default="bash", - help="Mode to run script", -) - -size = PartialOption( - "--size", - type=int, - default=100_000, - help="Number of objects in file", -) - - -output_dir = PartialOption( - "--output_dir", - type=click.Path(), - help="Path to directory for output", -) - - -output_file = PartialOption( - "--output_file", - type=click.Path(), - help="Output file", -) - - -truth_path = PartialOption( - "--truth_path", - help="Path to truth redshift file", - type=click.Path(), -) - 
-seed = PartialOption( - "--seed", - help="Random seed", - type=int, -) diff --git a/src/rail/cli/rail_pipe/pipe_scripts.py b/src/rail/cli/rail_pipe/pipe_scripts.py deleted file mode 100644 index e8fe49a..0000000 --- a/src/rail/cli/rail_pipe/pipe_scripts.py +++ /dev/null @@ -1,848 +0,0 @@ -import os -import subprocess -import pprint -import time -import itertools -from typing import Any, Callable - -import numpy as np -import pyarrow.parquet as pq -import pyarrow.dataset as ds -import yaml - -from rail.utils import catalog_utils -from rail.core.stage import RailPipeline -from rail.utils.project import RailProject -from rail.cli.rail_pipe.pipe_options import RunMode - - -def handle_command( - run_mode: RunMode, - command_line: list[str], -) -> int: - """ Run a single command in the mode requested - - Parameters - ---------- - run_mode: RunMode - How to run the command, e.g., dry_run, bash or slurm - - command_line: list[str] - Tokens in the command line - - Returns - ------- - returncode: int - Status returned by the command. 0 for success, exit code otherwise - """ - print("subprocess:", *command_line) - _start_time = time.time() - print(">>>>>>>>") - if run_mode == RunMode.dry_run: - # print(command_line) - command_line.insert(0, "echo") - finished = subprocess.run(command_line, check=False) - elif run_mode == RunMode.bash: - # return os.system(command_line) - finished = subprocess.run(command_line, check=False) - elif run_mode == RunMode.slurm: - raise RuntimeError("handle_command should not be called with run_mode == RunMode.slurm") - - returncode = finished.returncode - _end_time = time.time() - _elapsed_time = _end_time - _start_time - print("<<<<<<<<") - print(f"subprocess completed with status {returncode} in {_elapsed_time} seconds\n") - return returncode - - -def handle_commands( - run_mode: RunMode, - command_lines: list[list[str]], - script_path:str | None=None, -) -> int: - """ Run multiple commands in the mode requested - - Parameters - ---------- - run_mode: RunMode - How to run the command, e.g., dry_run, bash or slurm - - command_lines: list[list[str]] - List of commands to run, each one is the list of tokens in the command line - - script_path: str | None - Path to write the slurm submit script to - - Returns - ------- - returncode: int - Status returned by the commands. 
0 for success, exit code otherwise - """ - - if run_mode in [RunMode.dry_run, RunMode.bash]: - for command_ in command_lines: - retcode = handle_command(run_mode, command_) - if retcode: - return retcode - return 0 - # At this point we are using slurm and need a script to send to batch - if script_path is None: - raise ValueError( - "handle_commands with run_mode == RunMode.slurm requires a path to a script to write", - ) - - try: - os.makedirs(os.path.dirname(script_path)) - except FileExistsError: - pass - with open(script_path, 'w', encoding='utf-8') as fout: - fout.write("#!/usr/bin/bash\n\n") - for command_ in command_lines: - com_line = ' '.join(command_) - fout.write(f"{com_line}\n") - - script_log = script_path.replace('.sh', '.log') - try: - with subprocess.Popen( - ["sbatch", "-o", script_log, "--mem", "16448", "-p", "milano", "--parsable", script_path], - stdout=subprocess.PIPE, - ) as sbatch: - assert sbatch.stdout - line = sbatch.stdout.read().decode().strip() - ret_val = int(line.split("|")[0]) - except TypeError as msg: - raise TypeError(f"Bad slurm submit: {msg}") from msg - return ret_val - - -def inspect(config_file: str) -> int: - """ Inspect a rail project file and print out the configuration - - Parameters - ---------- - config_file: str - Project file in question - - Returns - ------- - returncode: int - Status. 0 for success, exit code otherwise - """ - project = RailProject.load_config(config_file) - printable_config = pprint.pformat(project.config, compact=True) - print(f"RAIL Project: {project}") - print(">>>>>>>>") - print(printable_config) - print("<<<<<<<<") - return 0 - - -class PipelineCatalogConfiguration: - """Small plugin class to handle configuring a pipeline to run on a catalog - - Sub-classes will have to implement the "get_convert_commands" function - """ - - def __init__( - self, - project: RailProject, - source_catalog_tag: str, - sink_catalog_tag: str, - source_catalog_basename: str | None=None, - sink_catalog_basename: str | None=None, - ): - self._project = project - self._source_catalog_tag = source_catalog_tag - self._sink_catalog_tag = sink_catalog_tag - self._source_catalog_basename = source_catalog_basename - self._sink_catalog_basename = sink_catalog_basename - - def get_source_catalog(self, **kwargs: Any) -> str: - """Get the name of the source (i.e. 
input) catalog file""" - return self._project.get_catalog( - self._source_catalog_tag, basename=self._source_catalog_basename, **kwargs, - ) - - def get_sink_catalog(self, **kwargs: Any) -> str: - """Get the name of the sink (i.e., output) catalog file""" - return self._project.get_catalog( - self._sink_catalog_tag, basename=self._sink_catalog_basename, **kwargs, - ) - - def get_script_path(self, pipeline_name: str, sink_dir: str, **kwargs: Any) -> str: - """Get path to use for the slurm batch submit script""" - selection = kwargs['selection'] - flavor = kwargs['flavor'] - return os.path.join( - sink_dir, - f"submit_{pipeline_name}_{selection}_{flavor}.sh" - ) - - def get_convert_commands(self, sink_dir: str) -> list[list[str]]: - """Get the set of commands to run after the pipeline to - convert output files - """ - raise NotImplementedError() - - -class TruthToObservedPipelineCatalogConfiguration(PipelineCatalogConfiguration): - - def get_convert_commands(self, sink_dir: str) -> list[list[str]]: - convert_command = [ - "tables-io", - "convert", - "--input", - f"{sink_dir}/output_dereddener_errors.pq", - "--output", - f"{sink_dir}/output.hdf5", - ] - convert_commands = [convert_command] - return convert_commands - - -class PhotometricErrorsPipelineCatalogConfiguration(PipelineCatalogConfiguration): - - def get_convert_commands(self, sink_dir: str) -> list[list[str]]: - convert_command = [ - "tables-io", - "convert", - "--input", - f"{sink_dir}/output_dereddener_errors.pq", - "--output", - f"{sink_dir}/output.hdf5", - ] - convert_commands = [convert_command] - return convert_commands - - -class SpectroscopicPipelineCatalogConfiguration(PipelineCatalogConfiguration): - - def get_convert_commands(self, sink_dir: str) -> list[list[str]]: - convert_commands = [] - spec_selections = self._project.get_spec_selections() - for spec_selection_ in spec_selections.keys(): - convert_command = [ - "tables-io", - "convert", - "--input", - f"{sink_dir}/output_select_{spec_selection_}.pq", - "--output", - f"{sink_dir}/output_select_{spec_selection_}.hdf5", - ] - convert_commands.append(convert_command) - return convert_commands - - -class BlendingPipelineCatalogConfiguration(PipelineCatalogConfiguration): - - def get_convert_commands(self, sink_dir: str) -> list[list[str]]: - convert_command = [ - "tables-io", - "convert", - "--input", - f"{sink_dir}/output_blended.pq", - "--output", - f"{sink_dir}/output_blended.hdf5", - ] - convert_commands = [convert_command] - return convert_commands - - -def run_pipeline_on_catalog( - project: RailProject, - pipeline_name: str, - pipeline_catalog_configuration: PipelineCatalogConfiguration, - run_mode: RunMode=RunMode.bash, - **kwargs: Any, -) -> int: - """ Run a pipeline on an entire catalog - - Parameters - ---------- - project: RailProject - Object with project configuration - - pipeline_name: str - Name of the pipeline to run - - pipeline_catalog_configuration: PipelineCatalogConfiguration - Class to manage input and output catalogs and files - - run_mode: RunMode - How to run the command, e.g., dry_run, bash or slurm - - kwargs: Any - Additional parameters to specify pipeline, e.g., flavor, selection, ... - - - Returns - ------- - returncode: int - Status returned by the command. 
0 for success, exit code otherwise - """ - - pipeline_info = project.get_pipeline(pipeline_name) - pipeline_path = project.get_path('pipeline_path', pipeline=pipeline_name, **kwargs) - - input_catalog_name = pipeline_info['InputCatalogTag'] - input_catalog = project.get_catalogs().get(input_catalog_name, {}) - - # Loop through all possible combinations of the iteration variables that are - # relevant to this pipeline - if (iteration_vars := input_catalog.get("IterationVars", {})) is not None: - iterations = itertools.product( - *[ - project.config.get("IterationVars", {}).get(iteration_var, "") - for iteration_var in iteration_vars - ] - ) - for iteration_args in iterations: - iteration_kwargs = { - iteration_vars[i]: iteration_args[i] - for i in range(len(iteration_vars)) - } - - source_catalog = pipeline_catalog_configuration.get_source_catalog(**kwargs, **iteration_kwargs) - sink_catalog = pipeline_catalog_configuration.get_sink_catalog(**kwargs, **iteration_kwargs) - sink_dir = os.path.dirname(sink_catalog) - script_path = pipeline_catalog_configuration.get_script_path( - pipeline_name, - sink_dir, - **kwargs, **iteration_kwargs, - ) - convert_commands = pipeline_catalog_configuration.get_convert_commands(sink_dir) - - ceci_command = project.generate_ceci_command( - pipeline_path=pipeline_path, - config=pipeline_path.replace('.yaml', '_config.yml'), - inputs=dict(input=source_catalog), - output_dir=sink_dir, - log_dir=sink_dir, - ) - - if not os.path.isfile(source_catalog): - raise ValueError(f"Input file {source_catalog} not found") - try: - handle_commands( - run_mode, - [ - ["mkdir", "-p", f"{sink_dir}"], - ceci_command, - *convert_commands, - ], - script_path, - ) - except Exception as msg: - print(msg) - return 1 - return 0 - - # FIXME need to get catalogs even if iteration not specified; this return fallback isn't ideal - return 1 - - -def run_pipeline_on_single_input( - project: RailProject, - pipeline_name: str, - input_callback: Callable, - run_mode: RunMode=RunMode.bash, - **kwargs: Any, -) -> int: - """ Run a single pipeline - - Parameters - ---------- - project: RailProject - Object with project configuration - - pipeline_name: str - Name of the pipeline to run - - input_callback: Callable - Function that creates dict of input files - - run_mode: RunMode - How to run the command, e.g., dry_run, bash or slurm - - kwargs: Any - Additional parameters to specify pipeline, e.g., flavor, selection, ... - - - Returns - ------- - returncode: int - Status returned by the command. 
0 for success, exit code otherwise - """ - pipeline_path = project.get_path('pipeline_path', pipeline=pipeline_name, **kwargs) - pipeline_config = pipeline_path.replace('.yaml', '_config.yaml') - sink_dir = project.get_path('ceci_output_dir', **kwargs) - script_path = os.path.join(sink_dir, f"submit_{pipeline_name}.sh") - - input_files = input_callback(project, pipeline_name, sink_dir, **kwargs) - - command_line = project.generate_ceci_command( - pipeline_path=pipeline_path, - config=pipeline_config, - inputs=input_files, - output_dir=sink_dir, - log_dir=f"{sink_dir}/logs", - ) - - try: - statuscode = handle_commands(run_mode, [command_line], script_path) - except Exception as msg: - print(msg) - statuscode = 1 - return statuscode - - -def inform_input_callback( - project: RailProject, - pipeline_name: str, - sink_dir: str, # pylint: disable=unused-argument - **kwargs: Any, -) -> dict[str, str]: - """Make dict of input tags and paths for the inform pipeline - - Parameters - ---------- - project: RailProject - Object with project configuration - - pipeline_name: str - Name of the pipeline to run - - sink_dir: str - Path to output directory - - kwargs: Any - Additional parameters to specify pipeline, e.g., flavor, selection, ... - - Returns - ------- - input_files: dict[str, str] - Dictionary of input file tags and paths - """ - pipeline_info = project.get_pipeline(pipeline_name) - input_files = {} - input_file_tags = pipeline_info['InputFileTags'] - flavor = kwargs.pop('flavor', 'baseline') - for key, val in input_file_tags.items(): - input_file_flavor = val.get('flavor', flavor) - input_files[key] = project.get_file_for_flavor(input_file_flavor, val['tag'], **kwargs) - return input_files - - -def estimate_input_callback( - project: RailProject, - pipeline_name: str, - sink_dir: str, - **kwargs: Any, -) -> dict[str, str]: - """Make dict of input tags and paths for the estimate pipeline - - Parameters - ---------- - project: RailProject - Object with project configuration - - pipeline_name: str - Name of the pipeline to run - - sink_dir: str - Path to output directory - - kwargs: Any - Additional parameters to specify pipeline, e.g., flavor, selection, ... - - Returns - ------- - input_files: dict[str, str] - Dictionary of input file tags and paths - """ - pipeline_info = project.get_pipeline(pipeline_name) - input_files = {} - input_file_tags = pipeline_info['InputFileTags'] - flavor = kwargs.pop('flavor', 'baseline') - for key, val in input_file_tags.items(): - input_file_flavor = val.get('flavor', flavor) - input_files[key] = project.get_file_for_flavor(input_file_flavor, val['tag'], **kwargs) - - pz_algorithms = project.get_pzalgorithms() - for pz_algo_ in pz_algorithms.keys(): - input_files[f"model_{pz_algo_}"] = os.path.join(sink_dir, f'inform_model_{pz_algo_}.pkl') - return input_files - - -def evaluate_input_callback( - project: RailProject, - pipeline_name: str, - sink_dir: str, - **kwargs: Any, -) -> dict[str, str]: - """Make dict of input tags and paths for the evaluate pipeline - - Parameters - ---------- - project: RailProject - Object with project configuration - - pipeline_name: str - Name of the pipeline to run - - sink_dir: str - Path to output directory - - kwargs: Any - Additional parameters to specify pipeline, e.g., flavor, selection, ... 
- - Returns - ------- - input_files: dict[str, str] - Dictionary of input file tags and paths - """ - pipeline_info = project.get_pipeline(pipeline_name) - input_files = {} - input_file_tags = pipeline_info['InputFileTags'] - flavor = kwargs.pop('flavor', 'baseline') - for key, val in input_file_tags.items(): - input_file_flavor = val.get('flavor', flavor) - input_files[key] = project.get_file_for_flavor(input_file_flavor, val['tag'], **kwargs) - - pdfs_dir = sink_dir - pz_algorithms = project.get_pzalgorithms() - for pz_algo_ in pz_algorithms.keys(): - input_files[f"input_evaluate_{pz_algo_}"] = os.path.join(pdfs_dir, f'estimate_output_{pz_algo_}.hdf5') - return input_files - - -def pz_input_callback( - project: RailProject, - pipeline_name: str, - sink_dir: str, # pylint: disable=unused-argument - **kwargs: Any, -) -> dict[str, str]: - """Make dict of input tags and paths for the pz pipeline - - Parameters - ---------- - project: RailProject - Object with project configuration - - pipeline_name: str - Name of the pipeline to run - - sink_dir: str - Path to output directory - - kwargs: Any - Additional parameters to specify pipeline, e.g., flavor, selection, ... - - Returns - ------- - input_files: dict[str, str] - Dictionary of input file tags and paths - """ - pipeline_info = project.get_pipeline(pipeline_name) - input_files = {} - input_file_tags = pipeline_info['InputFileTags'] - flavor = kwargs.pop('flavor') - for key, val in input_file_tags.items(): - input_file_flavor = val.get('flavor', flavor) - input_files[key] = project.get_file_for_flavor(input_file_flavor, val['tag'], **kwargs) - return input_files - - -def tomography_input_callback( - project: RailProject, - pipeline_name: str, - sink_dir: str, - **kwargs: Any, -) -> dict[str, str]: - """Make dict of input tags and paths for the tomography pipeline - - Parameters - ---------- - project: RailProject - Object with project configuration - - pipeline_name: str - Name of the pipeline to run - - sink_dir: str - Path to output directory - - kwargs: Any - Additional parameters to specify pipeline, e.g., flavor, selection, ... - - Returns - ------- - input_files: dict[str, str] - Dictionary of input file tags and paths - """ - pipeline_info = project.get_pipeline(pipeline_name) - input_files = {} - input_file_tags = pipeline_info['InputFileTags'] - flavor = kwargs.pop('flavor') - selection = kwargs.get('selection') - for key, val in input_file_tags.items(): - input_file_flavor = val.get('flavor', flavor) - input_files[key] = project.get_file_for_flavor(input_file_flavor, val['tag'], selection=selection) - - pdfs_dir = sink_dir - pz_algorithms = project.get_pzalgorithms() - for pz_algo_ in pz_algorithms.keys(): - input_files[f"input_{pz_algo_}"] = os.path.join(pdfs_dir, f'output_estimate_{pz_algo_}.hdf5') - - return input_files - - -def sompz_input_callback( - project: RailProject, - pipeline_name: str, - sink_dir: str, # pylint: disable=unused-argument - **kwargs: Any, -) -> dict[str, str]: - """Make dict of input tags and paths for the sompz pipeline - - Parameters - ---------- - project: RailProject - Object with project configuration - - pipeline_name: str - Name of the pipeline to run - - sink_dir: str - Path to output directory - - kwargs: Any - Additional parameters to specify pipeline, e.g., flavor, selection, ... 
- - Returns - ------- - input_files: dict[str, str] - Dictionary of input file tags and paths - """ - pipeline_info = project.get_pipeline(pipeline_name) - input_file_dict = {} - input_file_tags = pipeline_info['InputFileTags'] - flavor = kwargs.pop('flavor') - selection = kwargs.get('selection') - for key, val in input_file_tags.items(): - input_file_flavor = val.get('flavor', flavor) - input_file_dict[key] = project.get_file_for_flavor(input_file_flavor, val['tag'], selection=selection) - - input_files = dict( - train_deep_data = input_file_dict['input_train'], - train_wide_data = input_file_dict['input_train'], - test_spec_data = input_file_dict['input_test'], - test_balrog_data = input_file_dict['input_test'], - test_wide_data = input_file_dict['input_test'], - truth = input_file_dict['input_test'], - ) - return input_files - - -def subsample_data( - project: RailProject, - source_tag: str="degraded", - selection: str="gold", - flavor: str="baseline", - label: str="train_file", - run_mode: RunMode=RunMode.bash, -) -> int: - """Make a training or test data set by randomly selecting objects - - Parameters - ---------- - project: RailProject - Object with project configuration - - source_tag: str - Tag for the input catalog - - selection: str - Which sub-selection of data to draw from - - flavor: str - Which analysis flavor to draw from - - label: str - Which label to apply to output dataset - - run_mode: RunMode - How to run the command, e.g., dry_run, bash or slurm - - Returns - ------- - returncode: int - Status returned by the command. 0 for success, exit code otherwise - """ - - hdf5_output = project.get_file_for_flavor(flavor, label, selection=selection) - output = hdf5_output.replace('.hdf5', '.parquet') - output_metadata = project.get_file_metadata_for_flavor(flavor, label) - basename = output_metadata['SourceFileBasename'] - output_dir = os.path.dirname(output) - size = output_metadata.get("NumObjects") - seed = output_metadata.get("Seed") - catalog_metadata = project.get_catalogs()['degraded'] - iteration_vars = catalog_metadata['IterationVars'] - - iterations = itertools.product( - *[ - project.config.get("IterationVars", {}).get(iteration_var, "") - for iteration_var in iteration_vars - ] - ) - sources = [] - for iteration_args in iterations: - iteration_kwargs = { - iteration_vars[i]: iteration_args[i] - for i in range(len(iteration_vars)) - } - - source_catalog = project.get_catalog( - source_tag, - selection=selection, - flavor=flavor, - basename=basename, - **iteration_kwargs, - ) - sources.append(source_catalog) - - if run_mode == RunMode.slurm: - raise NotImplementedError("subsample_data not set up to run under slurm") - - dataset = ds.dataset(sources) - num_rows = dataset.count_rows() - print("num rows", num_rows) - rng = np.random.default_rng(seed) - print("sampling", size) - - size = min(size, num_rows) - indices = rng.choice(num_rows, size=size, replace=False) - subset = dataset.take(indices) - print("writing", output) - - if run_mode == RunMode.bash: - os.makedirs(output_dir, exist_ok=True) - pq.write_table( - subset, - output, - ) - print("done") - handle_command(run_mode, ["tables-io", "convert", "--input", f"{output}", "--output", f"{hdf5_output}"]) - return 0 - - - -def build_pipelines( - project: RailProject, - flavor: str='baseline', - *, - force: bool=False, -) -> int: - """Build ceci pipeline configuration files for this project - - Parameters - ---------- - project: RailProject - Object with project configuration - - flavor: str - Which analysis flavor to draw 
from - - force: bool - Force overwriting of existing pipeline files - - Returns - ------- - returncode: int - Status returned by the command. 0 for success, exit code otherwise - """ - - output_dir = project.get_common_path('project_scratch_dir') - flavor_dict = project.get_flavor(flavor) - pipelines_to_build = flavor_dict['Pipelines'] - pipeline_overrides = flavor_dict.get('PipelineOverrides', {}) - do_all = 'all' in pipelines_to_build - - for pipeline_name, pipeline_info in project.get_pipelines().items(): - if not (do_all or pipeline_name in pipelines_to_build): - print(f"Skipping pipeline {pipeline_name} from flavor {flavor}") - continue - output_yaml = project.get_path('pipeline_path', pipeline=pipeline_name, flavor=flavor) - if os.path.exists(output_yaml): - if force: - print(f"Overwriting existing pipeline {output_yaml}") - else: - print(f"Skipping existing pipeline {output_yaml}") - continue - pipe_out_dir = os.path.dirname(output_yaml) - - try: - os.makedirs(pipe_out_dir) - except FileExistsError: - pass - - overrides = pipeline_overrides.get('default', {}) - overrides.update(**pipeline_overrides.get(pipeline_name, {})) - - pipeline_kwargs = pipeline_info.get('kwargs', {}) - for key, val in pipeline_kwargs.items(): - if val == 'SpecSelections': - pipeline_kwargs[key] = project.get_spec_selections() - elif val == 'PZAlgorithms': - pipeline_kwargs[key] = project.get_pzalgorithms() - elif val == 'NZAlgorithms': - pipeline_kwargs[key] = project.get_nzalgorithms() - elif val == 'Classifiers': - pipeline_kwargs[key] = project.get_classifiers() - elif val == 'Summarizers': - pipeline_kwargs[key] = project.get_summarizers() - elif val == 'ErrorModels': - pipeline_kwargs[key] = project.get_error_models() - - if overrides: - pipe_ctor_kwargs = overrides.pop('kwargs', {}) - pz_algorithms = pipe_ctor_kwargs.pop('PZAlgorithms', None) - if pz_algorithms: - orig_pz_algorithms = project.get_pzalgorithms().copy() - pipe_ctor_kwargs['algorithms'] = { - pz_algo_: orig_pz_algorithms[pz_algo_] for pz_algo_ in pz_algorithms - } - pipeline_kwargs.update(**pipe_ctor_kwargs) - stages_config = os.path.join(pipe_out_dir, f"{pipeline_name}_{flavor}_overrides.yml") - with open(stages_config, 'w', encoding='utf-8') as fout: - yaml.dump(overrides, fout) - else: - stages_config = None - - pipeline_class = pipeline_info['PipelineClass'] - catalog_tag = pipeline_info['CatalogTag'] - - if catalog_tag: - catalog_utils.apply_defaults(catalog_tag) - - tokens = pipeline_class.split('.') - module = '.'.join(tokens[:-1]) - class_name = tokens[-1] - log_dir = f"{output_dir}/logs/{pipeline_name}" - - print(f"Writing {output_yaml}") - - __import__(module) - RailPipeline.build_and_write( - class_name, - output_yaml, - None, - stages_config, - output_dir, - log_dir, - **pipeline_kwargs, - ) - - return 0 diff --git a/src/rail/cli/rail_pipe/reduce_roman_rubin_data.py b/src/rail/cli/rail_pipe/reduce_roman_rubin_data.py deleted file mode 100644 index e3bfe41..0000000 --- a/src/rail/cli/rail_pipe/reduce_roman_rubin_data.py +++ /dev/null @@ -1,206 +0,0 @@ -import math -import itertools -import os - -import pyarrow.compute as pc -import pyarrow.dataset as ds -import pyarrow.parquet as pq -from pyarrow import acero - -from .pipe_options import RunMode -from ...utils.project import RailProject - - -COLUMNS = [ - "galaxy_id", - "ra", - "dec", - "redshift", - "LSST_obs_u", - "LSST_obs_g", - "LSST_obs_r", - "LSST_obs_i", - "LSST_obs_z", - "LSST_obs_y", - "ROMAN_obs_F184", - "ROMAN_obs_J129", - "ROMAN_obs_H158", - "ROMAN_obs_W146", - 
"ROMAN_obs_Z087", - "ROMAN_obs_Y106", - "ROMAN_obs_K213", - "ROMAN_obs_R062", - "totalEllipticity", - "totalEllipticity1", - "totalEllipticity2", - "diskHalfLightRadiusArcsec", - "spheroidHalfLightRadiusArcsec", - "bulge_frac", - # "healpix", -] - -PROJECTIONS = [ - { - "mag_u_lsst": pc.field("LSST_obs_u"), - "mag_g_lsst": pc.field("LSST_obs_g"), - "mag_r_lsst": pc.field("LSST_obs_r"), - "mag_i_lsst": pc.field("LSST_obs_i"), - "mag_z_lsst": pc.field("LSST_obs_z"), - "mag_y_lsst": pc.field("LSST_obs_y"), - "totalHalfLightRadiusArcsec": pc.add( - pc.multiply( - pc.field("diskHalfLightRadiusArcsec"), - pc.subtract(pc.scalar(1), pc.field("bulge_frac")), - ), - pc.multiply( - pc.field("spheroidHalfLightRadiusArcsec"), - pc.field("bulge_frac"), - ) - ), - "_orientationAngle": pc.atan2(pc.field("totalEllipticity2"), pc.field("totalEllipticity1")), - }, - { - "major": pc.divide( - pc.field("totalHalfLightRadiusArcsec"), - pc.sqrt(pc.field("totalEllipticity")), - ), - "minor": pc.multiply( - pc.field("totalHalfLightRadiusArcsec"), - pc.sqrt(pc.field("totalEllipticity")), - ), - "orientationAngle": pc.multiply( - pc.scalar(0.5), - pc.subtract( - pc.field("_orientationAngle"), - pc.multiply( - pc.floor( - pc.divide( - pc.field("_orientationAngle"), - pc.scalar(2 * math.pi) - ) - ), - pc.scalar(2 * math.pi) - ) - ) - ), - } -] - - -def reduce_roman_rubin_data( - project: RailProject, - input_tag: str, - input_selection: str, - selection: str|None, - run_mode: RunMode=RunMode.bash, -) -> int: - - source_catalogs = [] - sink_catalogs = [] - catalogs = [] - predicates = [] - - if selection is not None: - selection_dict = project.get_selection(selection) - else: - selection_dict = {} - - - # FIXME - iteration_vars = list(project.config.get("IterationVars", {}).keys()) - if iteration_vars is not None: - iterations = itertools.product( - *[ - project.config.get("IterationVars", {}).get(iteration_var, "") - for iteration_var in iteration_vars - ] - ) - for iteration_args in iterations: - iteration_kwargs = { - iteration_vars[i]: iteration_args[i] - for i in range(len(iteration_vars)) - } - source_catalog = project.get_catalog(input_tag, selection=input_selection, **iteration_kwargs) - sink_catalog = project.get_catalog('reduced', selection=selection, **iteration_kwargs) - sink_dir = os.path.dirname(sink_catalog) - if selection_dict: - predicate = pc.field("LSST_obs_i") < selection_dict["maglim_i"][1] - else: - predicate = None - - if not os.path.isfile(source_catalog): - raise ValueError(f"Input file {source_catalog} not found") - - # FIXME properly warn - if os.path.isfile(sink_catalog): - # raise ValueError(f"Input file {source_catalog} not found") - print(f"Warning: output file {sink_catalog} found; may be rewritten...") - - source_catalogs.append(source_catalog) - sink_catalogs.append(sink_catalog) - - catalogs.append((source_catalog, sink_catalog)) - - predicates.append(predicate) - - dataset = ds.dataset( - source_catalog, - format="parquet", - ) - - scan_node = acero.Declaration( - "scan", - acero.ScanNodeOptions( - dataset, - columns=COLUMNS, - filter=predicate, - ), - ) - - filter_node = acero.Declaration( - "filter", - acero.FilterNodeOptions( - predicate, - ), - ) - - column_projection = { - k: pc.field(k) - for k in COLUMNS - } - projection = column_projection - project_nodes = [] - for _projection in PROJECTIONS: - for k, v in _projection.items(): - projection[k] = v - project_node = acero.Declaration( - "project", - acero.ProjectNodeOptions( - [v for k, v in projection.items()], - names=[k for 
k, v in projection.items()], - ) - ) - project_nodes.append(project_node) - - seq = [ - scan_node, - filter_node, - *project_nodes, - ] - plan = acero.Declaration.from_sequence(seq) - print(plan) - - if run_mode == RunMode.dry_run: - continue - if run_mode == RunMode.slurm: - raise NotImplementedError("run_mode == RunMode.slurm not implemented for reduce_roman_rubin") - - # batches = plan.to_reader(use_threads=True) - table = plan.to_table(use_threads=True) - print(f"writing dataset to {sink_catalog}") - os.makedirs(sink_dir, exist_ok=True) - pq.write_table(table, sink_catalog) - - print("writing completed") - - return 0 diff --git a/src/rail/utils/name_utils.py b/src/rail/utils/name_utils.py deleted file mode 100644 index 5e17579..0000000 --- a/src/rail/utils/name_utils.py +++ /dev/null @@ -1,342 +0,0 @@ -""" -Utility code to help define standard paths for various data products -""" - -import copy -import re -from functools import partial -from typing import Any, Mapping - - -CommonPaths = dict( - root='.', - scratch_root='.', - project='', - project_dir='{root}/projects/{project}', - project_scratch_dir='{scratch_root}/projects/{project}', - catalogs_dir='{root}/catalogs', - pipelines_dir='{project_dir}/pipelines', -) - -PathTemplates = dict( - pipeline_path="{pipelines_dir}/{pipeline}_{flavor}.yaml", - ceci_output_dir="{project_dir}/data/{selection}_{flavor}", - ceci_file_path="{tag}_{stage}.{suffix}", -) - - -def update_include_dict( - orig_dict: dict[str, Any], - include_dict: dict[str, Any], -) -> None: - """Update a dict by updating (instead of replacing) sub-dicts - - Parameters - ---------- - orig_dict: dict[str, Any] - Original dict - include_dict: dict[str, Any], - Dict used to update the original - """ - for key, val in include_dict.items(): - if isinstance(val, Mapping) and key in orig_dict: - update_include_dict(orig_dict[key], val) - else: - orig_dict[key] = val - - -def _get_required_interpolants(template: str) -> list[str]: - """ Get the list of interpolants required to format a template string - - Notes - ----- - 'interpolants' are strings that must be replaced in order to format a string, - e.g., in "{project_dir}/models" "{project_dir}" would be an interpolant - """ - return re.findall('{.*?}', template) - - -def _format_template(template: str, **kwargs: Any) -> str: - """ Resolve a specific template - - This is fault-tolerant and will not raise KeyError if some - of the required interpolants are missing, but rather just - leave them untouched - """ - - required_interpolants = re.findall('{.*?}', template) - interpolants = kwargs.copy() - - for interpolant_ in required_interpolants: - interpolants.setdefault(interpolant_.replace('}', '').replace('{', ''), interpolant_) - return template.format(**interpolants) - - -def _resolve_dict(source: dict, interpolants: dict) -> dict: - """ Recursively resolve a dictionary using interpolants - - Parameters - ---------- - source: dict - Dictionary of string templates - - interpolants: dict - Dictionary of strings used to resolve templates - - Returns - ------- - sink : dict - Dictionary of resolved templates - """ - if source: - sink = copy.deepcopy(source) - for k, v in source.items(): - v_interpolated: list | dict | str = "" - match v: - case dict(): - v_interpolated = _resolve_dict(source[k], interpolants) - case list(): - v_interpolated = [_resolve_dict(_v, interpolants) for _v in v] - case str(): - v_interpolated = v.format(**interpolants) - case _: - raise ValueError("Cannot interpolate type!") - - sink[k] = v_interpolated - else: - 
sink = {} - - return sink - - -def _resolve(templates: dict, source: dict, interpolants: dict) -> dict: - """ Resolve a set of templates using interpolants and allow for overrides - - Parameters - ---------- - templates: dict - Dictionary of string templates - - source: dict - Dictionary of overrides - - interpolants: dict - Dictionary of strings used to resolve templates - - - Returns - ------- - sink : dict - Dictionary of resolved templates - """ - - sink = copy.deepcopy(templates) - if (overrides := source) is not None: - for k, v in overrides.items(): - sink[k] = v - for k, v in sink.items(): - match v: - case partial(): - sink[k] = v(**sink) - case _: - continue - sink = _resolve_dict(sink, interpolants) - return sink - - -class NameFactory: - """ Class defining standard paths for various data products - - """ - config_template = dict( - CommonPaths = CommonPaths, - PathTemplates = PathTemplates, - ) - - def __init__( - self, - config: dict | None=None, - templates: dict | None=None, - interpolants: dict | None=None, - ): - """ C'tor - - """ - if config is None: - config = {} - if templates is None: - templates = {} - if interpolants is None: - interpolants = {} - - self._config = copy.deepcopy(self.config_template) - for key, _val in config.items(): - if key in self._config: - self._config[key].update(**config[key]) - self._templates = copy.deepcopy(self._config['PathTemplates']) - self._templates.update(**templates) - self._interpolants: dict = {} - - self.templates = {} - for k, v in templates.items(): - self.templates[k] = partial(v.format, **templates) - - self.interpolants = self._config['CommonPaths'] - self.interpolants = interpolants - - def get_path_templates(self) -> dict: - return self._config['PathTemplates'] - - def get_common_paths(self) -> dict: - return self._config['CommonPaths'] - - @property - def interpolants(self) -> dict: - """ Return the dict of interpolants that are used to resolve templates """ - return self._interpolants - - @interpolants.setter - def interpolants(self, config: dict) -> None: - """ Update the dict of interpolants that are used to resolve templates """ - for key, value in config.items(): - new_value = value.format(**self.interpolants) - self.interpolants[key] = new_value - - @interpolants.deleter - def interpolants(self) -> None: - """ Reset the dict of interpolants that are used to resolve templates""" - self._interpolants = {} - - def resolve_from_config(self, config: dict) -> dict: - """ Resolve all the templates in a dict - - Parameters - ---------- - config: dict - Dictionary containing templates to be resolved - - Returns - ------- - resolved: dict - Dictionary with resolved versions of the templates - """ - resolved = _resolve( - self.templates, - config, - self.interpolants, - ) - config.update(resolved) - - return resolved - - def resolve_path(self, config: dict, path_key: str, **kwargs: Any) -> str: - """ Resolve a particular template in a config dict - - Parameters - ---------- - config: dict - Dictionary containing templates to be resolved - - path_key: str - Key for the specific template - - Returns - ------- - formatted: str - Resolved version of the template - """ - if (path_value := config.get(path_key)) is not None: - formatted = _format_template(path_value, **kwargs, **self.interpolants) - else: - raise KeyError(f"Path '{path_key}' not found in {config}") - return formatted - - - def get_template(self, section_key: str, path_key: str) -> str: - """ Return the template for a particular file type - - Parameters - ---------- 
- section_key: str - Which part of the config to look in - E.g., (CommonPaths, PathTemplates, Files) - - path_key: str - Key for the specific template - - Returns - ------- - the_template: str - Template for file of this type - """ - try: - section = self._config[section_key] - except KeyError as msg: - raise KeyError( - f"Config section {section_key} not present:" - f"available sections are {list(self._config.keys())}", - ) from msg - try: - return section[path_key] - except KeyError as msg: - raise KeyError( - f"Config key {path_key} not present in {section_key}:" - f"available paths are {list(section.keys())}", - ) from msg - - def resolve_template(self, section_key: str, path_key: str, **kwargs: Any) -> str: - """ Return the template for a particular file type - - Parameters - ---------- - section_key: str - Which part of the config to look in - E.g., (CommonPaths, PathTemplates, Files) - - path_key: str - Key for the specific template - - Returns - ------- - resolved: str - Resolved path - """ - template = self.get_template(section_key, path_key) - return _format_template(template, **self.interpolants, **kwargs) - - def resolve_path_template(self, path_key: str, **kwargs: Any) -> str: - """ Resolve and return a particular path template - - Parameters - ---------- - path_key: str - Key for the specific template - - Returns - ------- - resolved: str - Resolved path - """ - template = self.get_template('PathTemplates', path_key) - interp_dict = self.interpolants.copy() - interp_dict.update(**kwargs) - return _format_template(template, **interp_dict) - - - def resolve_common_path(self, path_key: str, **kwargs: Any) -> str: - """ Return a particular common path template - - Parameters - ---------- - path_key: str - Key for the specific template - - Returns - ------- - resolved: str - Resolved path - """ - template = self.get_template('CommonPaths', path_key) - interp_dict = self.interpolants.copy() - interp_dict.update(**kwargs) - return _format_template(template, **interp_dict) diff --git a/src/rail/utils/project.py b/src/rail/utils/project.py deleted file mode 100644 index 4e5ce93..0000000 --- a/src/rail/utils/project.py +++ /dev/null @@ -1,319 +0,0 @@ -from __future__ import annotations - -import copy -from pathlib import Path -import itertools -from typing import Any - -import yaml - -from rail.utils import name_utils - - -class RailProject: - config_template: dict[str, dict] = { - "IterationVars": {}, - "CommonPaths": {}, - "PathTemplates": {}, - "Catalogs": {}, - "Files": {}, - "Pipelines": {}, - "Flavors": {}, - "Selections": {}, - "ErrorModels": {}, - "PZAlgorithms": {}, - "NZAlgorithms": {}, - "SpecSelections": {}, - "Classifiers": {}, - "Summarizers": {}, - } - - def __init__(self, name: str, config_dict: dict): - self.name = name - self._config_dict = config_dict - self.config = copy.deepcopy(self.config_template) - for k in self.config.keys(): - if (v := self._config_dict.get(k)) is not None: - self.config[k] = v - # self.interpolants = self.get_interpolants() - self.name_factory = name_utils.NameFactory( - config=self.config, - templates=config_dict.get('PathTemplates', {}), - interpolants=self.config.get("CommonPaths", {}), - ) - self.name_factory.resolve_from_config( - self.config.get("CommonPaths", {}) - ) - - def __repr__(self) -> str: - return f"{self.name}" - - @staticmethod - def load_config(config_file: str) -> RailProject: - """ Create and return a RailProject from a yaml config file""" - project_name = Path(config_file).stem - with open(config_file, "r", encoding='utf-8') as 
fp: - config_orig = yaml.safe_load(fp) - includes = config_orig.get('Includes', []) - config_dict = {} - # FIXME, make this recursive to allow for multiple layers of includes - for include_ in includes: - with open(include_, "r", encoding='utf-8') as fp: - config_extra = yaml.safe_load(fp) - name_utils.update_include_dict(config_dict, config_extra) - name_utils.update_include_dict(config_dict, config_orig) - project = RailProject(project_name, config_dict) - # project.resolve_common() - return project - - def get_path_templates(self) -> dict: - """ Return the dictionary of templates used to construct paths """ - return self.name_factory.get_path_templates() - - def get_path(self, path_key: str, **kwargs: Any) -> str: - """ Resolve and return a path using the kwargs as interpolants """ - return self.name_factory.resolve_path_template(path_key, **kwargs) - - def get_common_paths(self) -> dict: - """ Return the dictionary of common paths """ - return self.name_factory.get_common_paths() - - def get_common_path(self, path_key: str, **kwargs: Any) -> str: - """ Resolve and return a common path using the kwargs as interpolants """ - return self.name_factory.resolve_common_path(path_key, **kwargs) - - def get_files(self) -> dict: - """ Return the dictionary of specific files """ - return self.config.get("Files", {}) - - def get_file(self, name: str, **kwargs: Any) -> str: - """ Resolve and return a file using the kwargs as interpolants """ - files = self.get_files() - file_dict = files.get(name, None) - if file_dict is None: - raise KeyError(f"file '{name}' not found in {self}") - path = self.name_factory.resolve_path(file_dict, "PathTemplate", **kwargs) - return path - - def get_flavors(self) -> dict: - """ Return the dictionary of analysis flavor variants """ - flavors = self.config.get("Flavors", {}) - baseline = flavors.get("baseline", {}) - for k, v in flavors.items(): - if k != "baseline": - flavors[k] = baseline | v - - return flavors - - def get_flavor(self, name: str) -> dict: - """ Resolve the configuration for a particular analysis flavor variant """ - flavors = self.get_flavors() - flavor = flavors.get(name, None) - if flavor is None: - raise KeyError(f"flavor '{name}' not found in {self}") - return flavor - - def get_file_for_flavor(self, flavor: str, label: str, **kwargs: Any) -> str: - """ Resolve the file associated to a particular flavor and label - - E.g., flavor=baseline and label=train would give the baseline training file - """ - flavor_dict = self.get_flavor(flavor) - try: - file_alias = flavor_dict['FileAliases'][label] - except KeyError as msg: - raise KeyError(f"Label '{label}' not found in flavor '{flavor}'") from msg - return self.get_file(file_alias, flavor=flavor, label=label, **kwargs) - - def get_file_metadata_for_flavor(self, flavor: str, label: str) -> dict: - """ Resolve the metadata associated to a particular flavor and label - - E.g., flavor=baseline and label=train would give the baseline training metadata - """ - flavor_dict = self.get_flavor(flavor) - try: - file_alias = flavor_dict['FileAliases'][label] - except KeyError as msg: - raise KeyError(f"Label '{label}' not found in flavor '{flavor}'") from msg - return self.get_files()[file_alias] - - def get_selections(self) -> dict: - """ Get the dictionary describing all the selections""" - return self.config.get("Selections", {}) - - def get_selection(self, name: str) -> dict: - """ Get a particular selection by name""" - selections = self.get_selections() - selection = selections.get(name, None) - if 
-
-    def get_path_templates(self) -> dict:
-        """ Return the dictionary of templates used to construct paths """
-        return self.name_factory.get_path_templates()
-
-    def get_path(self, path_key: str, **kwargs: Any) -> str:
-        """ Resolve and return a path using the kwargs as interpolants """
-        return self.name_factory.resolve_path_template(path_key, **kwargs)
-
-    def get_common_paths(self) -> dict:
-        """ Return the dictionary of common paths """
-        return self.name_factory.get_common_paths()
-
-    def get_common_path(self, path_key: str, **kwargs: Any) -> str:
-        """ Resolve and return a common path using the kwargs as interpolants """
-        return self.name_factory.resolve_common_path(path_key, **kwargs)
-
-    def get_files(self) -> dict:
-        """ Return the dictionary of specific files """
-        return self.config.get("Files", {})
-
-    def get_file(self, name: str, **kwargs: Any) -> str:
-        """ Resolve and return a file using the kwargs as interpolants """
-        files = self.get_files()
-        file_dict = files.get(name, None)
-        if file_dict is None:
-            raise KeyError(f"file '{name}' not found in {self}")
-        path = self.name_factory.resolve_path(file_dict, "PathTemplate", **kwargs)
-        return path
-
-    def get_flavors(self) -> dict:
-        """ Return the dictionary of analysis flavor variants """
-        flavors = self.config.get("Flavors", {})
-        baseline = flavors.get("baseline", {})
-        for k, v in flavors.items():
-            if k != "baseline":
-                flavors[k] = baseline | v
-
-        return flavors
-
-    def get_flavor(self, name: str) -> dict:
-        """ Resolve the configuration for a particular analysis flavor variant """
-        flavors = self.get_flavors()
-        flavor = flavors.get(name, None)
-        if flavor is None:
-            raise KeyError(f"flavor '{name}' not found in {self}")
-        return flavor
-
-    def get_file_for_flavor(self, flavor: str, label: str, **kwargs: Any) -> str:
-        """ Resolve the file associated with a particular flavor and label
-
-        E.g., flavor=baseline and label=train would give the baseline training file
-        """
-        flavor_dict = self.get_flavor(flavor)
-        try:
-            file_alias = flavor_dict['FileAliases'][label]
-        except KeyError as msg:
-            raise KeyError(f"Label '{label}' not found in flavor '{flavor}'") from msg
-        return self.get_file(file_alias, flavor=flavor, label=label, **kwargs)
-
-    def get_file_metadata_for_flavor(self, flavor: str, label: str) -> dict:
-        """ Resolve the metadata associated with a particular flavor and label
-
-        E.g., flavor=baseline and label=train would give the baseline training metadata
-        """
-        flavor_dict = self.get_flavor(flavor)
-        try:
-            file_alias = flavor_dict['FileAliases'][label]
-        except KeyError as msg:
-            raise KeyError(f"Label '{label}' not found in flavor '{flavor}'") from msg
-        return self.get_files()[file_alias]
-
-    def get_selections(self) -> dict:
-        """ Get the dictionary describing all the selections"""
-        return self.config.get("Selections", {})
-
-    def get_selection(self, name: str) -> dict:
-        """ Get a particular selection by name"""
-        selections = self.get_selections()
-        selection = selections.get(name, None)
-        if selection is None:
-            raise KeyError(f"selection '{name}' not found in {self}")
-        return selection
-
-    def get_error_models(self) -> dict:
-        """ Get the dictionary describing all the photometric error model algorithms"""
-        return self.config.get("ErrorModels", {})
-
-    def get_error_model(self, name: str) -> dict:
-        """ Get the information about a particular photometric error model algorithm"""
-        error_models = self.get_error_models()
-        error_model = error_models.get(name, None)
-        if error_model is None:
-            raise KeyError(f"error model '{name}' not found in {self}")
-        return error_model
-
-    def get_pzalgorithms(self) -> dict:
-        """ Get the dictionary describing all the PZ estimation algorithms"""
-        return self.config.get("PZAlgorithms", {})
-
-    def get_pzalgorithm(self, name: str) -> dict:
-        """ Get the information about a particular PZ estimation algorithm"""
-        pzalgorithms = self.get_pzalgorithms()
-        pzalgorithm = pzalgorithms.get(name, None)
-        if pzalgorithm is None:
-            raise KeyError(f"pz algorithm '{name}' not found in {self}")
-        return pzalgorithm
-
-    def get_nzalgorithms(self) -> dict:
-        """ Get the dictionary describing all the NZ estimation algorithms"""
-        return self.config.get("NZAlgorithms", {})
-
-    def get_nzalgorithm(self, name: str) -> dict:
-        """ Get the information about a particular NZ estimation algorithm"""
-        nzalgorithms = self.get_nzalgorithms()
-        nzalgorithm = nzalgorithms.get(name, None)
-        if nzalgorithm is None:
-            raise KeyError(f"nz algorithm '{name}' not found in {self}")
-        return nzalgorithm
-
-    def get_spec_selections(self) -> dict:
-        """ Get the dictionary describing all the spectroscopic selection algorithms"""
-        return self.config.get("SpecSelections", {})
-
-    def get_spec_selection(self, name: str) -> dict:
-        """ Get the information about a particular spectroscopic selection algorithm"""
-        spec_selections = self.get_spec_selections()
-        spec_selection = spec_selections.get(name, None)
-        if spec_selection is None:
-            raise KeyError(f"spectroscopic selection '{name}' not found in {self}")
-        return spec_selection
-
-    def get_classifiers(self) -> dict:
-        """ Get the dictionary describing all the tomographic bin classification algorithms"""
-        return self.config.get("Classifiers", {})
-
-    def get_classifier(self, name: str) -> dict:
-        """ Get the information about a particular tomographic bin classifier"""
-        classifiers = self.get_classifiers()
-        classifier = classifiers.get(name, None)
-        if classifier is None:
-            raise KeyError(f"tomographic bin classifier '{name}' not found in {self}")
-        return classifier
-
-    def get_summarizers(self) -> dict:
-        """ Get the dictionary describing all the NZ summarization algorithms"""
-        return self.config.get("Summarizers", {})
-
-    def get_summarizer(self, name: str) -> dict:
-        """ Get the information about a particular NZ summarization algorithm"""
-        summarizers = self.get_summarizers()
-        summarizer = summarizers.get(name, None)
-        if summarizer is None:
-            raise KeyError(f"NZ summarizer '{name}' not found in {self}")
-        return summarizer
-
-    def get_catalogs(self) -> dict:
-        """ Get the dictionary describing all the types of data catalogs"""
-        return self.config.get('Catalogs', {})
-
-    def get_catalog(self, catalog: str, **kwargs: Any) -> str:
-        """ Resolve the path for a particular catalog file"""
-        catalog_dict = self.config['Catalogs'].get(catalog, {})
-        try:
-            path = self.name_factory.resolve_path(catalog_dict, "PathTemplate", **kwargs)
-            return path
-        except KeyError as msg:
-            raise KeyError(f"PathTemplate not found in {catalog}") from msg
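All of the get_* accessors above share the same lookup-or-KeyError shape. A short usage sketch (algorithm name taken from the example configs below):

project = RailProject.load_config("tests/example.yaml")

gpz = project.get_pzalgorithm("gpz")
print(gpz["Module"])  # rail.estimation.algos.gpz

# A missing name raises KeyError with the project name in the message.
project.get_pzalgorithm("does_not_exist")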
-
-    def get_pipelines(self) -> dict:
-        """ Get the dictionary describing all the types of ceci pipelines"""
-        return self.config.get("Pipelines", {})
-
-    def get_pipeline(self, name: str) -> dict:
-        """ Get the information about a particular ceci pipeline"""
-        pipelines = self.get_pipelines()
-        pipeline = pipelines.get(name, None)
-        if pipeline is None:
-            raise KeyError(f"pipeline '{name}' not found in {self}")
-        return pipeline
-
-    def get_flavor_args(self, flavors: list[str]) -> list[str]:
-        """ Get the 'flavors' to iterate a particular command over
-
-        Notes
-        -----
-        If the flavor 'all' is included in the list of flavors, this
-        will replace the list with all the flavors defined in this project
-        """
-        flavor_dict = self.get_flavors()
-        if 'all' in flavors:
-            return list(flavor_dict.keys())
-        return flavors
-
-    def get_selection_args(self, selections: list[str]) -> list[str]:
-        """ Get the 'selections' to iterate a particular command over
-
-        Notes
-        -----
-        If the selection 'all' is included in the list of selections, this
-        will replace the list with all the selections defined in this project
-        """
-        selection_dict = self.get_selections()
-        if 'all' in selections:
-            return list(selection_dict.keys())
-        return selections
-
-    def generate_kwargs_iterable(self, **iteration_dict: Any) -> list[dict]:
-        iteration_vars = list(iteration_dict.keys())
-        iterations = itertools.product(
-            *[
-                iteration_dict.get(key, []) for key in iteration_vars
-            ]
-        )
-        iteration_kwarg_list = []
-        for iteration_args in iterations:
-            iteration_kwargs = {
-                iteration_vars[i]: iteration_args[i]
-                for i in range(len(iteration_vars))
-            }
-            iteration_kwarg_list.append(iteration_kwargs)
-        return iteration_kwarg_list
-
-    def generate_ceci_command(
-        self,
-        pipeline_path: str,
-        config: str|None,
-        inputs: dict,
-        output_dir: str='.',
-        log_dir: str='.',
-        **kwargs: Any,
-    ) -> list[str]:
-
-        if config is None:
-            config = pipeline_path.replace('.yaml', '_config.yml')
-
-        command_line = [
-            "ceci",
-            f"{pipeline_path}",
-            f"config={config}",
-            f"output_dir={output_dir}",
-            f"log_dir={log_dir}",
-        ]
-
-        for key, val in inputs.items():
-            command_line.append(f"inputs.{key}={val}")
-
-        for key, val in kwargs.items():
-            command_line.append(f"{key}={val}")
-
-        return command_line
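generate_kwargs_iterable above is essentially itertools.product repackaged as a list of kwargs dicts. An equivalent standalone expansion (flavor and selection names taken from the example configs below):

import itertools

iteration_dict = {"flavor": ["baseline", "gpz_gl"], "selection": ["gold"]}

keys = list(iteration_dict)
expanded = [
    dict(zip(keys, values))
    for values in itertools.product(*(iteration_dict[k] for k in keys))
]
# [{'flavor': 'baseline', 'selection': 'gold'},
#  {'flavor': 'gpz_gl', 'selection': 'gold'}]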
diff --git a/tests/example.yaml b/tests/example.yaml
deleted file mode 100644
index cb9dfaa..0000000
--- a/tests/example.yaml
+++ /dev/null
@@ -1,32 +0,0 @@
-# Include other configuration files
-Includes:
-  - tests/example_common.yaml
-
-# These are used to make all the other paths
-CommonPaths:
-  project: eac_test
-  sim_version: v1.1.3
-
-# These define the variant configurations for the various parts of the analysis
-Flavors:
-  # Baseline configuration, included in others by default
-  baseline:
-    Pipelines: ['all']
-    FileAliases:  # Set the training and test files
-      test: test_file_100k
-      train: train_file_100k
-      train_zCOSMOS: train_file_zCOSMOS_100k
-  train_cosmos:
-    Pipelines: ['pz', 'tomography']
-    FileAliases:  # Set the training and test files
-      test: test_file_100k
-      train: train_file_zCOSMOS_100k
-  gpz_gl:
-    Pipelines: ['inform', 'estimate', 'evaluate', 'pz']
-    PipelineOverrides:  # Override specifics for particular pipelines
-      default:
-        kwargs:
-          PZAlgorithms: ['gpz']
-      inform:
-        inform_gpz:
-          gpz_method: GL
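Note how get_flavors above overlays each non-baseline flavor onto baseline with the dict-union operator, so the gpz_gl flavor here inherits baseline's FileAliases. A minimal illustration of that shallow, top-level merge:

baseline = {
    "Pipelines": ["all"],
    "FileAliases": {"test": "test_file_100k", "train": "train_file_100k"},
}
gpz_gl = {"Pipelines": ["inform", "estimate", "evaluate", "pz"]}

# Python 3.9+ dict union: right-hand keys win, everything else is inherited.
merged = baseline | gpz_gl
# merged["Pipelines"]   -> ['inform', 'estimate', 'evaluate', 'pz']
# merged["FileAliases"] -> inherited unchanged from baseline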
diff --git a/tests/example_common.yaml b/tests/example_common.yaml
deleted file mode 100644
index 72d8a2a..0000000
--- a/tests/example_common.yaml
+++ /dev/null
@@ -1,240 +0,0 @@
-# These are used to make all the other paths
-CommonPaths:
-  root: /sdf/data/rubin/shared/pz
-  scratch_root: "{root}"
-  catalogs_dir: "{root}/data"
-
-# These are templates for catalogs produced in the early stages of the analysis
-Catalogs:
-  truth:
-    PathTemplate: "{catalogs_dir}/{project}_{sim_version}/{healpix}/part-0.parquet"
-    IterationVars: ['healpix']
-  reduced:
-    PathTemplate: "{catalogs_dir}/{project}_{sim_version}_{selection}/{healpix}/part-0.pq"
-    IterationVars: ['healpix']
-  degraded:
-    PathTemplate: "{catalogs_dir}/{project}_{sim_version}_{selection}_{flavor}/{healpix}/{basename}"
-    IterationVars: ['healpix']
-
-# These are templates for specific files, such as testing and training files
-Files:
-  test_file_100k:
-    NumObjects: 100000
-    Seed: 1234
-    PathTemplate: "{catalogs_dir}/test/{project}_{selection}_baseline_100k.hdf5"
-    SourceFileBasename: output_dereddener_errors.pq
-  train_file_100k:
-    NumObjects: 100000
-    Seed: 4321
-    PathTemplate: "{catalogs_dir}/test/{project}_{selection}_baseline_100k.hdf5"
-    SourceFileBasename: output_dereddener_errors.pq
-  train_file_zCOSMOS_100k:
-    NumObjects: 100000
-    Seed: 4321
-    PathTemplate: "{catalogs_dir}/train/{project}_{selection}_zCOSMOS_100k.hdf5"
-    SourceFileBasename: output_select_zCOSMOS.pq
-
-# These are ceci pipelines that we will be running
-Pipelines:
-  truth_to_observed:
-    PipelineClass: rail.pipelines.degradation.truth_to_observed.TruthToObservedPipeline
-    CatalogTag: roman_rubin
-    InputCatalogTag: reduced
-    kwargs:
-      error_models: ErrorModels
-      selectors: SpecSelections
-      blending: true
-  photometric_errors:
-    PipelineClass: rail.pipelines.degradation.apply_phot_errors.ApplyPhotErrorsPipeline
-    CatalogTag: roman_rubin
-    InputCatalogTag: reduced
-    kwargs:
-      error_models: ErrorModels
-  blending:
-    PipelineClass: rail.pipelines.degradation.blending_pipeline.BlendingPipeline
-    CatalogTag: roman_rubin
-    InputCatalogTag: degraded
-    kwargs: {}
-  spec_selection:
-    PipelineClass: rail.pipelines.degradation.spectroscopic_selection_pipeline.SpectroscopicSelectionPipeline
-    CatalogTag: roman_rubin
-    kwargs:
-      selectors: SpecSelections
-    InputCatalogTag: degraded
-  inform:
-    PipelineClass: rail.pipelines.estimation.inform_all.InformPipeline
-    CatalogTag: roman_rubin
-    kwargs:
-      algorithms: PZAlgorithms
-    InputFileTags:
-      input:
-        flavor: baseline
-        tag: train
-  estimate:
-    PipelineClass: rail.pipelines.estimation.estimate_all.EstimatePipeline
-    CatalogTag: roman_rubin
-    kwargs:
-      algorithms: PZAlgorithms
-    InputFileTags:
-      input:
-        tag: test
-        flavor: baseline
-    InputCatalogTag: degraded
-  evaluate:
-    PipelineClass: rail.pipelines.evaluation.evaluate_all.EvaluationPipeline
-    CatalogTag: roman_rubin
-    kwargs:
-      algorithms: PZAlgorithms
-    InputFileTags:
-      truth:
-        tag: test
-        flavor: baseline
-    InputCatalogTag: degraded
-  pz:
-    PipelineClass: rail.pipelines.estimation.pz_all.PzPipeline
-    CatalogTag: roman_rubin
-    kwargs:
-      algorithms: PZAlgorithms
-    InputFileTags:
-      input_train:
-        tag: train
-      input_test:
-        tag: test
-  tomography:
-    PipelineClass: rail.pipelines.estimation.tomography.TomographyPipeline
-    CatalogTag: roman_rubin
-    kwargs:
-      algorithms: PZAlgorithms
-      classifiers: Classifiers
-      summarizers: Summarizers
-      n_tomo_bins: 5
-    InputFileTags:
-      truth:
-        tag: test
-    InputCatalogTag: degraded
-
-
-# These describe the selections going from Input to Reduced catalog
-Selections:
-  maglim_25.5:
-    maglim_i: [null, 25.5]
-  gold:
-    maglim_i: [null, 25.5]
-  blend:
-    maglim_i: [null, 26.0]
-  crap:
-    maglim_i: [null, 30.0]
-  all:
-    maglim_i: [null, null]
-
-
-# These describe all the algorithms that emulate spectroscopic selections
-SpecSelections:
-#  GAMA:
-#    Select: SpecSelection_GAMA
-#    Module: rail.creation.degraders.spectroscopic_selections
-#  BOSS:
-#    Select: SpecSelection_BOSS
-#    Module: rail.creation.degraders.spectroscopic_selections
-#  VVDSf02:
-#    Select: SpecSelection_VVDSf02
-#    Module: rail.creation.degraders.spectroscopic_selections
-  zCOSMOS:
-    Select: SpecSelection_zCOSMOS
-    Module: rail.creation.degraders.spectroscopic_selections
-#  HSC:
-#    Select: SpecSelection_HSC
-#    Module: rail.creation.degraders.spectroscopic_selections
-
-
-# These describe all the algorithms that estimate PZ
-PZAlgorithms:
-  trainz:
-    Estimate: TrainZEstimator
-    Inform: TrainZInformer
-    Module: rail.estimation.algos.train_z
-  simplenn:
-    Estimate: SklNeurNetEstimator
-    Inform: SklNeurNetInformer
-    Module: rail.estimation.algos.sklearn_neurnet
-  fzboost:
-    Estimate: FlexZBoostEstimator
-    Inform: FlexZBoostInformer
-    Module: rail.estimation.algos.flexzboost
-  knn:
-    Estimate: KNearNeighEstimator
-    Inform: KNearNeighInformer
-    Module: rail.estimation.algos.k_nearneigh
-  gpz:
-    Estimate: GPzEstimator
-    Inform: GPzInformer
-    Module: rail.estimation.algos.gpz
-
-
-# These describe all the algorithms that classify objects into tomographic bins
-Classifiers:
-  equal_count:
-    Classify: EqualCountClassifier
-    Module: rail.estimation.algos.equal_count
-  uniform_binning:
-    Classify: UniformBinningClassifier
-    Module: rail.estimation.algos.uniform_binning
-
-
-# These describe all the algorithms that summarize PZ information into NZ distributions
-Summarizers:
-  naive_stack:
-    Summarize: NaiveStackMaskedSummarizer
-    Module: rail.estimation.algos.naive_stack
-  point_est_hist:
-    Summarize: PointEstHistMaskedSummarizer
-    Module: rail.estimation.algos.point_est_hist
-
-
-# These describe the error models we use in the truth_to_observed pipeline
-ErrorModels:
-  lsst:
-    ErrorModel: LSSTErrorModel
-    Module: rail.creation.degraders.photometric_errors
-  roman:
-    ErrorModel: RomanErrorModel
-    Module: rail.creation.degraders.photometric_errors
-
-
-# These are variables that we iterate over when running over entire catalogs
-IterationVars:
-  healpix:
-    - 10050
-    - 10051
-    - 10052
-    - 10053
-    - 10177
-    - 10178
-    - 10179
-    - 10180
-    - 10181
-    - 10305
-    - 10306
-    - 10307
-    - 10308
-    - 10429
-    - 10430
-    - 10431
-    - 10432
-    - 10549
-    - 10550
-    - 10551
-    - 10552
-    - 10665
-    - 10666
-    - 10667
-    - 10668
-    - 10777
-    - 10778
-    - 10779
-    - 9921
-    - 9922
-    - 9923
-    - 9924
-    - 9925
diff --git a/tests/test_name_utils.py b/tests/test_name_utils.py
deleted file mode 100644
index 8c11e86..0000000
--- a/tests/test_name_utils.py
+++ /dev/null
@@ -1,25 +0,0 @@
-import os
-
-import pytest
-from rail.utils import name_utils
-
-def test_name_utils():
-
-    assert name_utils._get_required_interpolants('xx_{alice}_{bob}') == ['{alice}', '{bob}']
-    assert name_utils._format_template('xx_{alice}_{bob}', alice='x', bob='x') == 'xx_x_x'
-
-    test_dict = dict(
-        a='a_{alice}',
-        # b=['b1_{alice}', 'b2_{bob}'],
-        c=dict(c1='c1_{alice}', c2='c2_{alice}'),
-        # c2=dict(c2_1='c1_{alice}', c2_2=['c2_{alice}', 'c2_{bob}']),
-    )
-
-    name_utils._resolve_dict(test_dict, dict(alice='x', bob='y'))
-
-    assert not name_utils._resolve_dict(None, {})
-    with pytest.raises(ValueError):
-        name_utils._resolve_dict(dict(a=('s', 'd',)), dict(alice='x', bob='y'))
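The Selections block in example_common.yaml above maps a selection name to per-band limits, with null meaning unbounded on that side. A hedged sketch of how such a window could be applied (the helper is hypothetical and not part of the deleted code, which applied selections elsewhere in the reduction scripts):

def passes_maglim(mag_i: float, bounds: tuple) -> bool:
    # bounds is (lower, upper); None means no limit on that side.
    lower, upper = bounds
    if lower is not None and mag_i < lower:
        return False
    if upper is not None and mag_i > upper:
        return False
    return True

# 'gold' selection from the example config: maglim_i: [null, 25.5]
passes_maglim(24.9, (None, 25.5))  # True
passes_maglim(26.1, (None, 25.5))  # False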
diff --git a/tests/test_project.py b/tests/test_project.py
deleted file mode 100644
index 1ef4638..0000000
--- a/tests/test_project.py
+++ /dev/null
@@ -1,92 +0,0 @@
-import os
-
-import pytest
-
-from rail.utils.project import RailProject
-
-
-def check_get_func(func, check_dict):
-    for key, val in check_dict.items():
-        check_val = func(key)
-        if isinstance(check_val, dict):
-            for kk, vv in check_val.items():
-                assert vv == val[kk]
-    with pytest.raises(KeyError):
-        func('does_not_exist')
-
-
-def test_project():
-
-    project = RailProject.load_config('tests/example.yaml')
-
-    print(project)
-
-    templates = project.get_path_templates()
-    check_get_func(project.get_path, templates)
-
-    common_paths = project.get_common_paths()
-    check_get_func(project.get_common_path, common_paths)
-
-    files = project.get_files()
-    check_get_func(project.get_file, files)
-
-    flavors = project.get_flavors()
-    check_get_func(project.get_flavor, flavors)
-    all_flavors = project.get_flavor_args(['all'])
-    assert set(all_flavors) == set(flavors.keys())
-    assert project.get_flavor_args(['dummy'])[0] == 'dummy'
-
-    project.get_file_for_flavor('baseline', 'test')
-    with pytest.raises(KeyError):
-        project.get_file_for_flavor('baseline', 'does not exist')
-
-    project.get_file_metadata_for_flavor('baseline', 'test')
-    with pytest.raises(KeyError):
-        project.get_file_metadata_for_flavor('baseline', 'does not exist')
-
-    selections = project.get_selections()
-    check_get_func(project.get_selection, selections)
-    all_selections = project.get_selection_args(['all'])
-    assert set(all_selections) == set(selections.keys())
-    assert project.get_selection_args(['dummy'])[0] == 'dummy'
-
-    itr = project.generate_kwargs_iterable(
-        selections=all_selections,
-        flavors=all_flavors,
-    )
-    for x_ in itr:
-        assert isinstance(x_, dict)
-
-    error_models = project.get_error_models()
-    check_get_func(project.get_error_model, error_models)
-
-    pz_algos = project.get_pzalgorithms()
-    check_get_func(project.get_pzalgorithm, pz_algos)
-
-    nz_algos = project.get_nzalgorithms()
-    check_get_func(project.get_nzalgorithm, nz_algos)
-
-    spec_selections = project.get_spec_selections()
-    check_get_func(project.get_spec_selection, spec_selections)
-
-    classifiers = project.get_classifiers()
-    check_get_func(project.get_classifier, classifiers)
-
-    summarizers = project.get_summarizers()
-    check_get_func(project.get_summarizer, summarizers)
-
-    catalogs = project.get_catalogs()
-    check_get_func(project.get_catalog, catalogs)
-
-    pipelines = project.get_pipelines()
-    check_get_func(project.get_pipeline, pipelines)
-
-    ceci_command = project.generate_ceci_command(
-        pipeline_path='dummy.yaml',
-        config=None,
-        inputs={'bob': 'bob.pkl'},
-        output_dir='.',
-        log_dir='.',
-        alice='bob',
-    )
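Taken together, the deleted modules supported a flow like the following sketch (written against the deleted API, so it no longer runs after this diff; the pipeline path is a placeholder name):

from rail.utils.project import RailProject  # removed by this diff

project = RailProject.load_config("tests/example.yaml")

# Resolve one concrete catalog file for a single healpix pixel.
path = project.get_catalog("truth", healpix=10050)

# Assemble (but do not execute) the corresponding ceci invocation.
command = project.generate_ceci_command(
    pipeline_path="pipeline.yaml",  # placeholder
    config=None,
    inputs={"input": path},
    output_dir=".",
    log_dir=".",
)
# ['ceci', 'pipeline.yaml', 'config=pipeline_config.yml',
#  'output_dir=.', 'log_dir=.', 'inputs.input=...']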