Source code for specsim.config

# Licensed under a 3-clause BSD style license - see LICENSE.rst
"""Manage simulation configuration data.

Configuration data is normally loaded from a yaml file. Some standard
configurations are included with this package and can be loaded by name,
for example:

    >>> test_config = load_config('test')

Otherwise any filename with extension .yaml can be loaded::

    my_config = load_config('path/my_config.yaml')

Configuration data is accessed using attribute notation to specify a
sequence of keys:

    >>> test_config.name
    'Test Simulation'
    >>> test_config.atmosphere.airmass
    1.0

Use :meth:`Configuration.get_constants` to parse values with dimensions and
:meth:`Configuration.load_table` to load and interpolate tabular data.
"""
import os
import re
import warnings

import yaml

import numpy as np
import scipy.interpolate

import astropy.units
import astropy.table
import astropy.coordinates
import astropy.time
import astropy.io.fits
import astropy.wcs

from importlib.resources import files

# Extract a number from a string with optional leading and
# trailing whitespace.
_float_pattern = re.compile(r'\s*([-+]?[0-9]*\.?[0-9]+([eE][-+]?[0-9]+)?)\s*')


def parse_quantity(quantity, dimensions=None):
    """Parse a string containing a numeric value with optional units.

    The result is a :class:`Quantity <astropy.units.Quantity>` object even
    when units are not present. Optional units are interpreted by
    :class:`astropy.units.Unit`. Some valid examples::

        1.23
        1.23um
        123 um / arcsec
        1 electron/adu

    Used by :meth:`Configuration.get_constants`.

    Parameters
    ----------
    quantity : str or astropy.units.Quantity
        String to parse. If a quantity is provided, it is checked against
        the expected dimensions and passed through.
    dimensions : str or astropy.units.Unit or None
        The units of the input quantity are expected to have the same
        dimensions as these units, if not None. Raises a ValueError if
        the input quantity is not convertible to dimensions.

    Returns
    -------
    astropy.units.Quantity
        If dimensions is not None, the returned quantity will be converted
        to its units.

    Raises
    ------
    ValueError
        Unable to parse quantity, or the quantity is not convertible to
        the requested dimensions.
    """
    if not isinstance(quantity, astropy.units.Quantity):
        # Look for a valid number starting the string.
        found_number = _float_pattern.match(quantity)
        if not found_number:
            raise ValueError('Unable to parse quantity.')
        value = float(found_number.group(1))
        # Everything after the number (and its trailing whitespace) is
        # interpreted as the unit string; it may be empty.
        unit = quantity[found_number.end():]
        quantity = astropy.units.Quantity(value, unit)
    if dimensions is not None:
        try:
            if not isinstance(dimensions, astropy.units.Unit):
                dimensions = astropy.units.Unit(dimensions)
            quantity = quantity.to(dimensions)
        except (ValueError, astropy.units.UnitConversionError) as err:
            # Chain the original error so the astropy diagnostic is kept.
            raise ValueError('Quantity "{0}" is not convertible to {1}.'
                             .format(quantity, dimensions)) from err
    return quantity
class Node(object):
    """A single node of a configuration data structure.

    A node wraps a dictionary and exposes its keys as attributes. Values
    that are themselves dictionaries are returned wrapped in a new
    :class:`Node`; any other value is returned directly as a leaf value.
    Existing leaf values may be reassigned, but new keys cannot be added
    and non-leaf nodes cannot be replaced.

    Parameters
    ----------
    value : dict
        Dictionary of configuration data wrapped by this node.
    path : list of str or None
        Sequence of keys leading from the configuration root to this node,
        used for ``str(node)`` and error messages. None (the default) means
        an empty path, i.e. the root node.
    """

    def __init__(self, value, path=None):
        # Avoid a shared mutable default: each root node gets its own list.
        if path is None:
            path = []
        self._assign('_value', value)
        self._assign('_path', path)

    def keys(self):
        """Return the keys of the wrapped dictionary."""
        return self._value.keys()

    def _assign(self, name, value):
        # Bypass our __setattr__, which only allows updating existing leaves.
        super(Node, self).__setattr__(name, value)

    def __str__(self):
        return '.'.join(self._path)

    def __getattr__(self, name):
        # This method is only called when normal attribute lookup fails.
        if name in ('_value', '_path'):
            # Instances created without __init__ (e.g. by copy or pickle,
            # which probe for __setstate__ before restoring state) lack these
            # attributes; fail cleanly instead of recursing forever.
            raise AttributeError(name)
        child_path = list(self._path) + [name]
        if name not in self._value:
            raise AttributeError(
                'No such config node: {0}'.format('.'.join(child_path)))
        child_value = self._value[name]
        if isinstance(child_value, dict):
            return Node(child_value, child_path)
        # Return the actual value for leaf nodes.
        return child_value

    def __setattr__(self, name, value):
        # This method is always triggered by self.name = ...
        child_path = list(self._path) + [name]
        if name not in self._value:
            raise AttributeError(
                'No such config node: {0}'.format('.'.join(child_path)))
        if isinstance(self._value[name], dict):
            raise AttributeError(
                'Cannot assign to non-leaf config node: {0}'
                .format('.'.join(child_path)))
        self._value[name] = value
class Configuration(Node):
    """Configuration parameters container and utilities.

    This class specifies the required top-level keys and delegates the
    interpretation and validation of their values to other functions.

    Parameters
    ----------
    config : dict
        Dictionary of configuration parameters, normally obtained by parsing
        a YAML file with :func:`load_config`.

    Raises
    ------
    AttributeError
        Missing required top-level configuration key (for example
        wavelength_grid or base_path).
    ValueError
        Invalid wavelength grid, or base_path refers to an environment
        variable that is not set.

    Attributes
    ----------
    wavelength : astropy.units.Quantity
        Array of evenly spaced wavelength values used for all simulation
        calculations. Determined by the wavelength_grid configuration
        parameters.
    abs_base_path : str
        Base path prepended to the relative paths of tabulated data files.
        Determined by the base_path configuration parameter.
    """
    def __init__(self, config):
        Node.__init__(self, config)
        self.update()

    def update(self):
        """Update this configuration.

        Recomputes the wavelength and abs_base_path attributes from the
        current settings of the wavelength_grid and base_path nodes.

        Raises
        ------
        ValueError
            The wavelength grid has a zero step or no points, or base_path
            refers to an environment variable that is not set.
        """
        # Initialize our wavelength grid.
        grid = self.wavelength_grid
        if grid.step == 0:
            # Report a configuration error rather than a ZeroDivisionError.
            raise ValueError('Invalid wavelength grid.')
        nwave = 1 + int(np.round((grid.max - grid.min) / grid.step))
        if nwave <= 0:
            raise ValueError('Invalid wavelength grid.')
        wave_unit = astropy.units.Unit(grid.unit)
        wave = (grid.min + grid.step * np.arange(nwave)) * wave_unit
        self._assign('wavelength', wave)

        # Resolve the special <PACKAGE_DATA> value to this package's data
        # directory, or use environment variables to interpolate {NAME}.
        base_path = self.base_path
        if base_path == '<PACKAGE_DATA>':
            abs_base_path = str(files('specsim').joinpath('data'))
        else:
            try:
                abs_base_path = base_path.format(**os.environ)
            except KeyError as e:
                raise ValueError(
                    'Environment variable not set: {0}.'.format(e)) from e
        self._assign('abs_base_path', abs_base_path)

    def get_sky(self, parent):
        """Create a sky coordinate from a configuration node.

        Parameters
        ----------
        parent : :class:`Node`
            Parent node in this configuration whose ``sky`` child
            will be processed.

        Returns
        -------
        astropy.coordinates.SkyCoord
            Sky coordinates object constructed from node parameters.
        """
        node = parent.sky
        frame = getattr(node, 'frame', None)
        return astropy.coordinates.SkyCoord(node.coordinates, frame=frame)

    def get_timestamp(self, parent):
        """Create a timestamp from a configuration node.

        Parameters
        ----------
        parent : :class:`Node`
            Parent node in this configuration whose ``timestamp`` child
            will be processed.

        Returns
        -------
        astropy.time.Time
            Timestamp object constructed from node parameters.
        """
        node = parent.timestamp
        # Optional children; None lets astropy use its defaults.
        fmt = getattr(node, 'format', None)
        scale = getattr(node, 'scale', None)
        return astropy.time.Time(node.when, format=fmt, scale=scale)

    def get_constants(self, parent, required_names=None, optional_names=None):
        """Interpret a constants node in this configuration.

        String constant values are parsed by :func:`parse_quantity`; any
        other value is converted with ``float()`` into a dimensionless
        :class:`astropy.units.Quantity`.

        Parameters
        ----------
        parent : :class:`Node`
            Parent node in this configuration whose ``constants`` child
            will be processed.
        required_names : iterable or None
            List of constant names that are required to be present for this
            method to succeed. If None, then no specific names are required.
            When specified, any other names not listed in ``optional_names``
            will raise a RuntimeError.
        optional_names : iterable or None
            List of constant names that are optional for the parent node.
            When specified, all non-required names must be listed here or
            else a RuntimeError will be raised.

        Returns
        -------
        dict
            Dictionary of (name, value) pairs where each value is an
            :class:`astropy.units.Quantity`. When ``required_names`` is
            specified, they are guaranteed to be present as keys of the
            returned dictionary.

        Raises
        ------
        RuntimeError
            Constants present in the node do not match the required or
            optional names, or a constant value cannot be interpreted.
        """
        constants = {}
        node = parent.constants
        names = [] if node is None else sorted(node.keys())

        # All required names must be present, if specified.
        if required_names is not None:
            required = set(required_names)
            if not required <= set(names):
                raise RuntimeError(
                    'Expected {0} for "{1}.constants"'
                    .format(required_names, parent))
            extra_names = set(names) - required
        else:
            extra_names = set(names)
        # All non-required names must be listed in optional_names, if specified.
        if optional_names is not None:
            extra_names -= set(optional_names)
        # If either required_names or optional_names is specified, there
        # should not be any extra names.
        if required_names is not None or optional_names is not None:
            if extra_names:
                raise RuntimeError(
                    'Unexpected "{0}.constants" names: {1}.'
                    .format(parent, extra_names))

        for name in names:
            value = getattr(node, name)
            try:
                if isinstance(value, str):
                    constants[name] = parse_quantity(value)
                else:
                    constants[name] = astropy.units.Quantity(float(value))
            except (ValueError, TypeError) as err:
                # TypeError covers values float() cannot convert at all,
                # such as None or a nested mapping.
                raise RuntimeError('Invalid value for {0}.{1}: {2}'
                                   .format(node, name, value)) from err
        return constants

    def load_table(self, parent, column_names, interpolate=True,
                   as_dict=False):
        """Load and optionally interpolate tabular data from one or more files.

        Reads a single file if ``parent.table.path`` exists, or else reads
        multiple files listed under ``parent.table.paths`` (and returns a
        dictionary keyed by the names used there).

        Parameters
        ----------
        parent : :class:`Node`
            Parent node whose ``table`` child describes the file(s) and
            columns to load. Its ``columns`` child must list exactly the
            requested column names, plus ``wavelength`` when interpolating.
            Optional ``format`` and ``hdu`` children are passed to
            :meth:`astropy.table.Table.read`, and an optional
            ``extrapolated_value`` child enables extrapolation outside the
            tabulated wavelength range.
        column_names : str or iterable of str
            Name(s) of the column(s) to return.
        interpolate : bool
            If True, linearly interpolate each column onto this
            configuration's wavelength grid.
        as_dict : bool
            If True, always return a dictionary, using the key ``'default'``
            when only a single ``path`` is configured.

        Returns
        -------
        object
            For each file, the single column when ``column_names`` is a str,
            or else a dictionary of columns keyed by name. Columns with units
            are returned as :class:`astropy.units.Quantity`. When only a
            single ``path`` is configured and ``as_dict`` is False, that
            result is returned directly; otherwise a dictionary of results
            keyed by path name is returned.

        Raises
        ------
        RuntimeError
            The configured columns do not match the requested names, a
            column's units are inconsistent or invalid, or wavelength units
            are missing when interpolating.
        """
        node = parent.table

        # Normalize the requested column names to a list.
        return_scalar = isinstance(column_names, str)
        if return_scalar:
            column_names = [column_names]
        else:
            column_names = list(column_names)

        # Check that the configured columns match the requested ones.
        required_names = list(column_names)
        if interpolate:
            required_names.append('wavelength')
        required_names = sorted(required_names)
        columns = node.columns
        config_column_names = sorted(columns.keys())
        if required_names != config_column_names:
            raise RuntimeError(
                'Expected {0} for "{1}"'.format(required_names, columns))

        # Prepare the arguments we will send to astropy.table.Table.read()
        keys = node.keys()
        read_args = {
            key: getattr(node, key) for key in ('format', 'hdu') if key in keys}

        # Prepare a list of paths we will load tables from, looking for
        # parent.table.path first and falling back to parent.table.paths.
        path_keys = None
        try:
            relative_paths = [node.path]
        except AttributeError:
            path_keys = list(node.paths.keys())
            relative_paths = [getattr(node.paths, key) for key in path_keys]
        paths = [os.path.join(self.abs_base_path, relative_path)
                 for relative_path in relative_paths]
        table_keys = ['default'] if path_keys is None else path_keys

        tables = {}
        # Loop over tables to load.
        for key, path in zip(table_keys, paths):
            with warnings.catch_warnings():
                warnings.simplefilter(
                    'ignore', category=astropy.units.core.UnitsWarning)
                table = astropy.table.Table.read(path, **read_args)
            if self.verbose:
                print('Loaded {0} rows from {1} with args {2}'
                      .format(len(table), path, read_args))

            # Loop over columns to read.
            loaded_columns = {}
            for config_name in config_column_names:
                column = getattr(columns, config_name)
                # Look up the column data by index first, then by name.
                try:
                    column_data = table.columns[column.index]
                except AttributeError:
                    column_data = table[column.name]

                # Resolve column units.
                try:
                    column_unit = astropy.units.Unit(column.unit)
                except AttributeError:
                    column_unit = None
                try:
                    override_unit = column.override_unit
                except AttributeError:
                    override_unit = False
                # Validate explicitly: an assert would vanish under -O.
                if override_unit not in (True, False):
                    raise RuntimeError(
                        'Invalid override_unit for "{0}": {1}.'
                        .format(column, override_unit))
                if override_unit or column_data.unit is None:
                    if column_unit is not None:
                        # Assign the unit specified in our config.
                        column_data.unit = column_unit
                elif (column_unit is not None and
                      column_unit != column_data.unit):
                    raise RuntimeError(
                        'Units do not match for "{0}".'.format(column))

                if interpolate:
                    # Keep the column object; units are applied after
                    # interpolation below.
                    loaded_columns[config_name] = column_data
                elif column_data.unit:
                    loaded_columns[config_name] = (
                        column_data.data * column_data.unit)
                else:
                    loaded_columns[config_name] = column_data.data

            if interpolate:
                # The wavelength column is only needed for interpolation.
                wavelength_column = loaded_columns.pop('wavelength')
                # Convert wavelength column units if necessary.
                if wavelength_column.unit is None:
                    raise RuntimeError(
                        'Wavelength units required for "{0}"'.format(columns))
                wavelength = wavelength_column.data * wavelength_column.unit
                if wavelength.unit != self.wavelength.unit:
                    wavelength = wavelength.to(self.wavelength.unit)
                # Initialize extrapolation if requested.
                try:
                    fill_value = node.extrapolated_value
                    bounds_error = False
                except AttributeError:
                    fill_value = None
                    bounds_error = True
                # Interpolate the other columns onto our wavelength grid.
                for column_name in column_names:
                    column_data = loaded_columns[column_name]
                    interpolator = scipy.interpolate.interp1d(
                        wavelength.value, column_data.data, kind='linear',
                        copy=False, bounds_error=bounds_error,
                        fill_value=fill_value)
                    interpolated_values = interpolator(self.wavelength.value)
                    if column_data.unit:
                        interpolated_values = (
                            interpolated_values * column_data.unit)
                    loaded_columns[column_name] = interpolated_values

            if return_scalar:
                # Return just the one column that was requested.
                tables[key] = loaded_columns[column_names[0]]
            else:
                # Return a dictionary of all requested columns.
                tables[key] = loaded_columns

        if path_keys is None and not as_dict:
            return tables['default']
        return tables

    def load_table2d(self, node, y_column_name, x_column_prefix):
        """Read values for some quantity tabulated along 2 axes.

        Parameters
        ----------
        node : :class:`Node`
            Configuration node with a ``path`` child, giving the file to
            read relative to :attr:`abs_base_path`, and an optional
            ``format`` child passed to :meth:`astropy.table.Table.read`.
        y_column_name : str
            Name of the column containing y coordinate values.
        x_column_prefix : str
            Prefix for column names at different values of the x coordinate.
            The remainder of the column name must be interpretable by
            :func:`specsim.config.parse_quantity` as the x coordinate value.
            Values in each column correspond to ``data[:, x]``.

        Returns
        -------
        callable
            Function ``f(x, y)`` performing 2D linear interpolation with
            :class:`scipy.interpolate.RectBivariateSpline`. When the x or y
            values carry units, the corresponding argument must be a
            :class:`astropy.units.Quantity` convertible to them. The result
            carries the unit of the data columns, if they have one.

        Raises
        ------
        RuntimeError
            No column name starts with x_column_prefix, or the matching
            columns have inconsistent x or data units.
        """
        path = os.path.join(self.abs_base_path, node.path)
        fmt = getattr(node, 'format', None)
        table = astropy.table.Table.read(path, format=fmt)
        ny = len(table)
        y_col = table[y_column_name]
        y_value = np.array(y_col.data)
        # 1 is used as a "no units" marker below.
        y_unit = y_col.unit if y_col.unit is not None else 1

        # Look for columns whose name has the specified prefix.
        x_value, x_index = [], []
        x_unit = data_unit = None
        for i, colname in enumerate(table.colnames):
            if not colname.startswith(x_column_prefix):
                continue
            # Parse the rest of the column name as an x value.
            x = parse_quantity(colname[len(x_column_prefix):])
            col_unit = table[colname].unit
            if not x_index:
                # The first matching column fixes the expected units.
                x_unit, data_unit = x.unit, col_unit
            elif x.unit != x_unit:
                raise RuntimeError('Column unit mismatch: {0} != {1}.'
                                   .format(x_unit, x.unit))
            elif col_unit != data_unit:
                raise RuntimeError('Data unit mismatch: {0} != {1}.'
                                   .format(data_unit, col_unit))
            x_value.append(x.value)
            x_index.append(i)
        if not x_index:
            raise RuntimeError('No columns with prefix "{0}" found in {1}.'
                               .format(x_column_prefix, path))

        # Extract values for each x,y pair, ordering rows by increasing x
        # since RectBivariateSpline requires strictly increasing x values.
        nx = len(x_value)
        order = np.argsort(x_value, kind='stable')
        x_value = np.asarray(x_value)[order]
        data = np.empty((nx, ny))
        for j, k in enumerate(order):
            data[j] = table.columns[x_index[k]].data
        if self.verbose:
            print('Loaded {0} x {1} values from {2}.'.format(nx, ny, path))

        # Build a 2D linear interpolator.
        interpolator = scipy.interpolate.RectBivariateSpline(
            x_value, y_value, data, kx=1, ky=1, s=0)

        # Data columns without units yield plain (unitless) values.
        if data_unit is None:
            data_unit = 1

        # Return a wrapper that handles units. Default parameters capture
        # values (rather than references) in the closures.
        if x_unit != 1:
            def get_x(x, u=x_unit):
                return x.to(u).value
        else:
            def get_x(x):
                return np.asarray(x)
        if y_unit != 1:
            def get_y(y, u=y_unit):
                return y.to(u).value
        else:
            def get_y(y):
                return np.asarray(y)

        def evaluate(x, y, f=interpolator, u=data_unit):
            return f.ev(get_x(x), get_y(y)) * u

        return evaluate

    def load_fits2d(self, filename, xy_unit, **hdus):
        """Load 2D interpolators for image HDUs of the specified FITS file.

        The data in each image HDU is interpreted with x mapped to columns
        (NAXIS1) and y mapped to rows (NAXIS2). The x, y coordinates are
        inferred from each image HDU's basic WCS parameters.

        The returned interpolators expect parameters with units and return
        interpolated values with units. Units for x, y are specified via a
        parameter and assumed to be the same for all HDUs. Units for the
        interpolated data are taken from the BUNIT header keyword, and must
        be interpretable by astropy.

        Parameters
        ----------
        filename : str
            Name of the FITS file to read with :func:`astropy.io.fits.open`,
            relative to :attr:`abs_base_path`.
        xy_unit : astropy.units.Unit
            Unit of x, y coordinates.
        hdus : dict
            Dictionary of name, hdu mappings where each hdu is specified by
            its integer offset or its name.

        Returns
        -------
        dict
            Dictionary of 2D linear interpolators corresponding to each hdu,
            with the same keys that appear in the hdus input parameter.

        Raises
        ------
        KeyError
            An HDU has no BUNIT header keyword.
        ValueError
            An HDU's BUNIT value is not a valid astropy unit.
        """
        path = os.path.join(self.abs_base_path, filename)
        interpolators = {}
        # The context manager closes the file even if an HDU fails to load.
        with astropy.io.fits.open(path, memmap=False) as hdu_list:
            for name, hdu_key in hdus.items():
                hdu = hdu_list[hdu_key]
                ny, nx = hdu.data.shape
                # Use the header WCS to reconstruct the x,y grids.
                wcs = astropy.wcs.WCS(hdu.header)
                x, _ = wcs.wcs_pix2world(np.arange(nx), [0], 0)
                _, y = wcs.wcs_pix2world([0], np.arange(ny), 0)
                try:
                    bunit = hdu.header['BUNIT']
                except KeyError as err:
                    raise KeyError('Missing BUNIT header keyword for HDU {0}.'
                                   .format(hdu_key)) from err
                try:
                    data_unit = astropy.units.Unit(bunit)
                except ValueError as err:
                    raise ValueError('Invalid BUNIT "{0}" for HDU {1}.'
                                     .format(bunit, hdu_key)) from err
                # NOTE(review): RectBivariateSpline expects z with shape
                # (x.size, y.size) = (nx, ny), but hdu.data has shape
                # (ny, nx). This only succeeds for square images, and then
                # pairs x with rows rather than columns as documented above.
                # Confirm against the data files before switching to
                # hdu.data.T.
                dimensionless_interpolator = (
                    scipy.interpolate.RectBivariateSpline(
                        x, y, hdu.data, kx=1, ky=1, s=0))

                # Note that the default arg values are used to capture the
                # current values of dimensionless_interpolator and data_unit
                # in the closure of this inner function.
                def interpolator(x, y, f=dimensionless_interpolator,
                                 u=data_unit):
                    return f.ev(x.to(xy_unit).value,
                                y.to(xy_unit).value) * u

                interpolators[name] = interpolator
                if self.verbose:
                    print('Loaded {0} from HDU[{1}] of {2}.'
                          .format(name, hdu_key, path))
        return interpolators
def load_config(name, config_type=Configuration):
    """Load configuration data from a YAML file.

    Valid configuration files are YAML files containing no custom types, no
    sequences (lists), and with all mapping (dict) keys being valid python
    identifiers.

    Parameters
    ----------
    name : str
        Name of the configuration to load, which can either be a pre-defined
        name or else the name of a yaml file (with extension .yaml) to load.
        Pre-defined names are mapped to corresponding files in this package's
        data/config/ directory.
    config_type : type
        Class used to wrap the parsed data; it is called with the parsed
        top-level dictionary. Defaults to :class:`Configuration`.

    Returns
    -------
    Configuration
        Initialized configuration object.

    Raises
    ------
    ValueError
        File name has wrong extension or does not exist.
    RuntimeError
        Configuration data failed a validation test.
    """
    _, extension = os.path.splitext(name)
    if extension not in ('', '.yaml'):
        raise ValueError('Config file must have .yaml extension.')
    if extension:
        file_name = name
    else:
        file_name = str(
            files('specsim').joinpath('data', 'config', f'{name}.yaml'))
    if not os.path.isfile(file_name):
        raise ValueError('No such config file "{0}".'.format(file_name))

    # Files are opened in binary mode so that PyYAML detects the encoding
    # (UTF-8/UTF-16) itself instead of relying on the platform default text
    # encoding, while still reporting the file name in parse errors.

    # Validate that all mapping keys are valid python identifiers.
    valid_key = re.compile(r'^[a-zA-Z_][a-zA-Z0-9_]*\Z')
    with open(file_name, 'rb') as f:
        next_value_is_key = False
        for token in yaml.scan(f):
            if isinstance(
                    token,
                    (yaml.BlockSequenceStartToken,
                     yaml.FlowSequenceStartToken)):
                raise RuntimeError('Config sequences not implemented yet.')
            if next_value_is_key:
                if not isinstance(token, yaml.ScalarToken):
                    raise RuntimeError(
                        'Invalid config key type: {0}'.format(token))
                if not valid_key.match(token.value):
                    raise RuntimeError(
                        'Invalid config key name: {0}'.format(token.value))
            # The token following a KeyToken is the key itself.
            next_value_is_key = isinstance(token, yaml.KeyToken)

    # safe_load only builds standard Python types (no arbitrary objects).
    with open(file_name, 'rb') as f:
        config_data = yaml.safe_load(f)
    if not isinstance(config_data, dict):
        # An empty file loads as None; configurations must be mappings.
        raise RuntimeError(
            'Config file "{0}" does not contain a mapping.'.format(file_name))
    return config_type(config_data)