import abc
import itertools
import logging
from collections import OrderedDict
from copy import deepcopy

import numpy as np
import six
from ocgis import constants
from ocgis import env
from ocgis.base import get_variables, get_dimension_names, AbstractOcgisObject
from ocgis.constants import TagName, DimensionMapKey, HeaderName
from ocgis.exc import SampleSizeNotImplemented, DefinitionValidationError, UnitsValidationError
from ocgis.util.broadcaster import broadcast_array_by_dimension_names
from ocgis.util.helpers import get_default_or_apply, get_iter
from ocgis.util.logging_ocgis import ocgis_lh
from ocgis.util.units import get_are_units_equal_by_string_or_cfunits
from ocgis.variable.base import Variable, VariableCollection, get_default_fill_value_from_dtype
from six.moves import zip_longest

# Canonical ordering of dimension-map keys that data arrays are conformed to before a calculation runs:
# (realization, time, level, y, x).
STANDARD_DIMENSIONS = (
    DimensionMapKey.REALIZATION,
    DimensionMapKey.TIME,
    DimensionMapKey.LEVEL,
    DimensionMapKey.Y,
    DimensionMapKey.X,
)

@six.add_metaclass(abc.ABCMeta)
class AbstractFunction(AbstractOcgisObject):
    """
    Base class for all calculations.

    Required class attributes to overload:

    * **description** (str): An arbitrary-length string describing the calculation.
    * **key** (str): The function's unique string identifier.
    * **standard_name** (str): Standard name to store in output metadata.
    * **long_name** (str): Long name description to store in output metadata.

    :param alias: The string identifier to use for the calculation. Defaults to ``key``.
    :type alias: str
    :param dtype: The output data type. Set this to ``'int'`` or ``'float'`` to use the default datatype for the output
     format and NumPy installation (recommended). If a specific NumPy type is needed, provide the string representation
     of the type (i.e. ``'int32'``).
    :type dtype: str or :class:`numpy.core.multiarray.dtype`
    :param field: The field object over which the calculation is applied.
    :type field: :class:`ocgis.interface.base.Field`
    :param file_only: If ``True`` pass through but compute output sizes, etc.
    :type file_only: bool
    :param vc: The :class:`ocgis.interface.base.variable.VariableCollection` to append output calculation arrays to. If
     ``None`` a new collection will be created.
    :type vc: :class:`ocgis.interface.base.variable.VariableCollection`
    :param parms: A dictionary of parameter values. This includes any parameters for the calculation.
    :type parms: dict
    :param tgd: An instance of :class:`ocgis.interface.base.dimension.temporal.TemporalGroupDimension`.
    :type tgd: :class:`ocgis.interface.base.dimension.temporal.TemporalGroupDimension`
    :param calc_sample_size: If ``True``, also compute sample sizes for the calculation.
    :type calc_sample_size: bool
    :param fill_value: Fill value for output variables. If ``None``, the fill value of the archetype (source) variable
     is used.
    :param meta_attrs: Contains overloads for variable and/or field attribute values.
    :type meta_attrs: :class:`ocgis.driver.parms.definition_helpers.MetadataAttributes`
    :param str tag: The tag to use for variable iteration on the source field (the source variables for calculation).
    :param bool spatial_aggregation: If ``True``, each temporally-grouped result is passed through
     :meth:`aggregate_spatial` before it is stored in the output array.
    """

    @abc.abstractproperty
    def description(self):
        pass

    Group = None

    @abc.abstractproperty
    def key(self):
        pass

    #: The calculation's long name.
    @abc.abstractproperty
    def long_name(self):
        pass

    #: The calculation's standard name.
    @abc.abstractproperty
    def standard_name(self):
        pass

    #: The output data type is NumPy's default float representation. This may be overloaded by subclasses.
    dtype_default = 'float'

    #: The calculation's output units. Modify :meth:`get_output_units` for more complex units calculations. If the units
    #: are left as the default '_input_' then the input variable units are maintained. Otherwise, they will be set to
    #: units attribute value. The string flag is used to allow ``None`` units to be applied.
    units = '_input_'

    # standard empty dictionary to use for calculation outputs when the operation is file only
    _empty_fill = {'fill': None, 'sample_size': None}

    def __init__(self, alias=None, dtype=None, field=None, file_only=False, vc=None, parms=None, tgd=None,
                 calc_sample_size=False, fill_value=None, meta_attrs=None, tag=TagName.DATA_VARIABLES,
                 spatial_aggregation=False):
        self._curr_variable = None
        self._current_conformed_array = None
        self._dtype = dtype

        self.alias = alias or self.key
        self.fill_value = fill_value
        self.vc = vc or VariableCollection()
        self.field = field
        self.file_only = file_only
        self.parms = get_default_or_apply(parms, self._format_parms_, default={})
        self.tgd = tgd
        self.calc_sample_size = calc_sample_size
        # Copy so later attribute updates never touch the caller's object.
        self.meta_attrs = deepcopy(meta_attrs)
        self.tag = tag
        self.spatial_aggregation = spatial_aggregation

    @property
    def dtype(self):
        """The data type passed at initialization (may be ``None``)."""
        return self._dtype

    def aggregate_spatial(self, values, weights):
        """
        Optional method overload for spatial aggregation.

        :param values: The input array with dimensions `(m, n)`.
        :type values: :class:`numpy.ma.core.MaskedArray`
        :param weights: The input weights array with dimension matching ``value``.
        :type weights: :class:`numpy.core.multiarray.ndarray`
        :rtype: :class:`numpy.ma.core.MaskedArray`
        """
        ret = np.ma.average(values, weights=weights)
        return ret

    def aggregate_temporal(self, values, **kwargs):
        """
        Optional method to overload for temporal aggregation. The default is a masked mean along the first axis.

        :param values: The input array to reduce along ``axis=0``.
        :type values: :class:`numpy.ma.core.MaskedArray`
        :param kwargs: Ignored by the default implementation.
        """
        return np.ma.mean(values, axis=0)

    @abc.abstractmethod
    def calculate(self, values, **kwargs):
        """
        The calculation method to overload. Values are explicitly passed to avoid dereferencing. Reducing along the
        time axis is required (i.e. axis=0).

        :param values: A three-dimensional array with dimensions (time, row, column).
        :type values: :class:`numpy.ma.core.MaskedArray`
        :param kwargs: Any keyword parameters for the function.
        :rtype: :class:`numpy.ma.core.MaskedArray`
        """
        pass

    def execute(self):
        """
        Execute the computation over the input field.

        :returns: The variable collection holding the calculation outputs.
        :rtype: :class:`ocgis.interface.base.variable.VariableCollection`
        """
        # call the subclass execute method
        self._execute_()
        # allow the field metadata to be modified
        self.set_field_metadata()
        return self.vc

    @classmethod
    def get_default_dtype(cls):
        """
        Resolve ``dtype_default`` to a concrete type. ``'int'`` and ``'float'`` map to the environment defaults; any
        other value is returned unchanged.

        :returns: The data type for the calculation.
        :rtype: type
        """
        cls_dtype = cls.dtype_default
        if cls_dtype == 'int':
            ret = env.NP_INT
        elif cls_dtype == 'float':
            ret = env.NP_FLOAT
        else:
            ret = cls_dtype
        return ret

    def get_fill_variable(self, archetype, name, dimensions, file_only=False, dtype=None, add_repeat_record=True,
                          add_repeat_record_archetype_name=True, variable_value=None):
        """
        Initialize a return variable for a calculation.

        :param archetype: An archetypical variable to use for the creation of the output variable.
        :type archetype: :class:`ocgis.Variable`
        :param str name: Name of the output variable.
        :param dimensions: Dimension tuple for the variable creation. The dimensions from `archetype` are not used
         because output dimensions are often different. Temporal grouping is an example of this.
        :type dimensions: tuple(:class:`ocgis.Dimension`, ...)
        :param bool file_only: If `True`, this is a file-only operation and no value should be allocated.
        :param type dtype: The data type for the output variable.
        :param bool add_repeat_record: If `True`, add a repeat record to the variable containing the calculation key.
        :param add_repeat_record_archetype_name: If `True`, add the `archetype` name repeat record.
        :param variable_value: If not `None`, use this as the variable value during initialization.
        :return: :class:`ocgis.Variable`
        """
        # If no data type was provided at initialization, use the data type of the input variable. Otherwise use the
        # class-level default data type.
        if dtype is None:
            if self.dtype is None:
                dtype = archetype.dtype
            else:
                dtype = self.get_default_dtype()

        if self.fill_value is None:
            fill_value = archetype.fill_value
        else:
            fill_value = self.fill_value

        # Provide a default fill value for file only operations. This will guarantee variable value arrays are
        # initialized with a default fill value.
        if file_only and fill_value is None:
            fill_value = get_default_fill_value_from_dtype(dtype)

        attrs = OrderedDict()
        attrs['standard_name'] = self.standard_name
        attrs['long_name'] = self.long_name
        units = self.get_output_units(archetype)

        if add_repeat_record:
            repeat_record = [(HeaderName.CALCULATION_KEY, self.key)]
            if add_repeat_record_archetype_name:
                repeat_record.append((HeaderName.CALCULATION_SOURCE_VARIABLE, archetype.name))
        else:
            repeat_record = None

        fill = Variable(name=name, dimensions=dimensions, dtype=dtype, fill_value=fill_value, attrs=attrs, units=units,
                        repeat_record=repeat_record, value=variable_value)

        # Only allocate when a real array is needed and the caller did not supply one.
        if not file_only and variable_value is None:
            fill.allocate_value()

        return fill

    @staticmethod
    def get_fill_sample_size_variable(archetype, file_only):
        """
        Create the sample size variable paired with a calculation output.

        :param archetype: The variable whose name and dimensions are used. The output is named ``n_<archetype name>``.
        :param bool file_only: If ``True``, do not allocate the variable's value.
        :return: :class:`ocgis.Variable`
        """
        attrs = OrderedDict()
        attrs['standard_name'] = constants.DEFAULT_SAMPLE_SIZE_STANDARD_NAME
        attrs['long_name'] = constants.DEFAULT_SAMPLE_SIZE_LONG_NAME
        fill_sample_size = Variable(name='n_{}'.format(archetype.name), dimensions=archetype.dimensions, attrs=attrs,
                                    dtype=np.int32)
        if not file_only:
            fill_sample_size.allocate_value()
        return fill_sample_size

    def get_function_definition(self):
        """
        Return a dictionary representation of the function definition.

        :rtype: dict
        """
        ret = {'key': self.key, 'alias': self.alias, 'parms': self.parms}
        return ret

    def get_output_units(self, variable):
        """
        Get the output units. The input variable's units are returned when ``units`` is the ``'_input_'`` flag.

        :type variable: :class:`ocgis.interface.base.variable.Variable`
        :rtype: str
        """
        if self.units == '_input_':
            ret = variable.units
        else:
            ret = self.units
        return ret

    def get_sample_size(self, values):
        """
        Calculate the sample size of the temporal group based on the mask.

        :param values: Masked array reduced along ``axis=0``; indexed as three-dimensional.
        :type values: :class:`numpy.ma.core.MaskedArray`
        :rtype: :class:`numpy.ma.core.MaskedArray`
        """
        # Unmasked elements count toward the sample size.
        to_sum = np.invert(values.mask)
        ret = np.ma.sum(to_sum, axis=0)
        # The output mask is taken from the first slice along the reduced axis.
        ret = np.ma.array(ret, mask=values.mask[0, :, :])
        return ret

    @staticmethod
    def get_variable_value(variable):
        """
        Select the appropriate value to use for the calculation. This will return the raw or aggregated values. If the
        data is not spatially aggregated, then this will have no effect.

        :param variable: :class:`ocgis.interface.base.variable.Variable`
        :rtype: :class:`numpy.ma.core.MaskedArray`
        """
        return variable.get_masked_value()

    def iter_calculation_targets(self, variable_names=None, yield_calculation_name=True, validate_units=True):
        """
        Iterate over the source variables for the calculation.

        :param variable_names: If provided, the names of the variables to pull from the field. Otherwise variables are
         selected from the field using the function's tag.
        :param bool yield_calculation_name: If ``True``, yield ``(variable, calculation_name)`` tuples. Otherwise yield
         only the variable.
        :param bool validate_units: If ``True``, call :meth:`validate_units` on each variable before yielding.
        :raises: ValueError if no variables are found.
        """
        if variable_names is None:
            seq = list(self.field.get_by_tag(self.tag))
        else:
            seq = get_variables(variable_names, self.field)

        if len(seq) == 0:
            raise ValueError('No data variables available on the field.')

        for variable in seq:
            if validate_units:
                self.validate_units(variable)

            self._curr_variable = variable

            if yield_calculation_name:
                # Append the variable to the calculation if there are more than one to calculate across.
                if len(seq) > 1:
                    calculation_name = '{}_{}'.format(self.alias, variable.name)
                else:
                    calculation_name = self.alias
                yield variable, calculation_name
            else:
                yield variable

    def set_field_metadata(self):
        """
        Modify the :class:~`ocgis.interface.base.field.Field` metadata dictionary.
        """
        pass

    def set_variable_metadata(self, variable):
        """
        Set variable level metadata. If units are to be updated, this must be done on the "units" attribute of the
        variable as this value is read directly from the variable object during conversion.
        """
        pass

    @classmethod
    def validate(cls, ops):
        """
        Optional method to overload that validates the input :class:`ocgis.OcgOperations`.

        :type ops: :class:`ocgis.driver.operations.OcgOperations`
        :raises: :class:`ocgis.exc.DefinitionValidationError`
        """

    @classmethod
    def validate_definition(cls, definition):
        """
        Method to validate calculation definitions passed to operations.

        :param dict definition: The dictionary definition for the function as returned from
         :attr:`~ocgis.driver.parms.definition.Calc.value`.
        :raises: :class:`ocgis.exc.DefinitionValidationError`
        """

    def validate_units(self, *args, **kwargs):
        """Optional method to overload for units validation at the calculation level."""

    def _add_to_collection_(self, value):
        """
        Add a calculation output (and optionally its sample size) to the internal variable collection.

        :param value: Dictionary with a ``'fill'`` key holding the derived variable and an optional ``'sample_size'``
         key holding the sample size variable.
        :type value: dict
        """
        dv = value['fill']
        sample_size = value.get('sample_size', None)

        # allow more complex manipulations of metadata
        self.set_variable_metadata(dv)
        # overload the metadata attributes with any provided
        if self.meta_attrs is not None:
            dv.attrs.update(self.meta_attrs.value['variable'])
            self.field.attrs.update(self.meta_attrs.value['field'])
        # NOTE(review): the two add_variable calls below were reconstructed; confirm call order against upstream.
        self.vc.add_variable(dv)

        # add the sample size if it is present in the fill dictionary
        if sample_size is not None:
            self.vc.add_variable(sample_size)

    @abc.abstractmethod
    def _execute_(self):
        pass

    def _format_parms_(self, values):
        """Hook for subclasses to coerce parameter values. The default returns ``values`` unchanged."""
        return values

    def _get_dimension_crosswalk_(self, variable):
        """
        Map each of the variable's dimensions to a standard dimension key where possible.

        :returns: One entry per variable dimension: the matching key from ``STANDARD_DIMENSIONS`` or, when there is no
         match in the field's dimension map, the dimension's own name.
        :rtype: list
        """
        crosswalk = []
        dimension_map = self.field.dimension_map
        for dim in variable.dimensions:
            found = False
            for k in STANDARD_DIMENSIONS:
                dmap_variable = dimension_map.get_variable(k)
                if dmap_variable is not None:
                    dmap_variable_dimensions = dimension_map.get_dimension(k)
                    if dim.name in dmap_variable_dimensions:
                        crosswalk.append(k)
                        found = True
                        break
            if not found:
                crosswalk.append(dim.name)
        return crosswalk

    def _get_extra_indices_itr_and_src_names_(self, crosswalk, variable_shape):
        """
        Build the iterator over non-standard ("extra") dimensions.

        :param crosswalk: Output of :meth:`_get_dimension_crosswalk_`.
        :param variable_shape: Shape of the variable the crosswalk was built for.
        :returns: A tuple of (iterator over tuples of ``(axis index, position)`` pairs, crosswalk names with the extra
         dimensions removed).
        """
        extra = [dn for dn in crosswalk if dn not in STANDARD_DIMENSIONS]
        extra = {e: {'index': crosswalk.index(e)} for e in extra}
        for v in list(extra.values()):
            v['size'] = variable_shape[v['index']]

        # Remove the extra dimensions. The only dimension names remaining are standard field names.
        src_names_extra_removed = [dn for dn in crosswalk if dn not in extra]

        # Handles extra dimensions in the outer loop.
        izl = [zip_longest([v['index']], list(range(v['size'])), fillvalue=v['index'])
               for v in list(extra.values())]
        itr = itertools.product(*izl)

        return itr, src_names_extra_removed

    def _get_parms_(self):
        return self.parms

    def _get_temporal_agg_fill_(self, variable, name, file_only, f=None, parms=None,
                                add_repeat_record_archetype_name=True):
        """
        Apply ``f`` to each temporal group of ``variable`` and collect the results in a new output variable.

        :param variable: The source variable.
        :param str name: Name of the output variable.
        :param bool file_only: If ``True``, only create the output variables; no values are computed.
        :param f: Callable applied to each group. Defaults to :meth:`calculate`.
        :param dict parms: Keyword arguments for ``f``. Falls back to ``self.parms`` when falsy.
        :param bool add_repeat_record_archetype_name: Passed through to :meth:`get_fill_variable`.
        :returns: Dictionary with ``'fill'`` and ``'sample_size'`` keys. ``'sample_size'`` is ``None`` unless sample
         sizes are being calculated.
        :rtype: dict
        """
        # Depending on the computational class we may be just aggregating temporally or actually executing the
        # calculation.
        f = f or self.calculate

        # Allow parameter overloading.
        # NOTE(review): an empty dict is falsy, so callers passing ``parms={}`` still receive ``self.parms``.
        parms = parms or self.parms

        # Variable dimension names remapped to standard field dimension names.
        crosswalk = self._get_dimension_crosswalk_(variable)

        # Create the fill variable. The time dimension is replaced by the temporal group dimension.
        time_axis = crosswalk.index(DimensionMapKey.TIME)
        fill_dimensions = list(variable.dimensions)
        fill_dimensions[time_axis] = self.tgd.dimensions[0]
        fill = self.get_fill_variable(variable, name, fill_dimensions, file_only,
                                      add_repeat_record_archetype_name=add_repeat_record_archetype_name)

        # Create the sample size variable.
        if self.calc_sample_size:
            fill_sample_size = self.get_fill_sample_size_variable(fill, file_only)
        else:
            fill_sample_size = None

        if not file_only:
            # Get value arrays.
            arr = self.get_variable_value(variable)
            arr_fill = self.get_variable_value(fill)
            if self.calc_sample_size:
                arr_fill_sample_size = self.get_variable_value(fill_sample_size)
            else:
                arr_fill_sample_size = None

            # Extra dimensions are not standard field dimensions.
            for yld in self._iter_conformed_arrays_(crosswalk, variable.shape, arr, arr_fill, arr_fill_sample_size):
                if not self.calc_sample_size:
                    carr, carr_fill = yld
                    carr_fill_sample_size = None
                else:
                    carr, carr_fill, carr_fill_sample_size = yld

                # Some variables need access to the entire 5d conformed array.
                self._current_conformed_array = carr

                # Standard field dimension iterators (realization, level, temporal group).
                standard_itrs = [list(range(carr.shape[ii])) for ii in [0, 2]]
                standard_itrs.append(list(range(self.tgd.shape[0])))

                # Execute the calculation.
                for ir, il, it in itertools.product(*standard_itrs):
                    self._curr_group = self.tgd.dgroups[it]
                    calculation_value = carr[ir, self._curr_group, il, :, :]
                    assert calculation_value.ndim == 3
                    res = f(calculation_value, **parms)
                    if self.spatial_aggregation:
                        # Weights are not currently conformed so should not be used.
                        res = self.aggregate_spatial(res, None)
                        carr_fill.data[ir, it, il, :, :] = res
                    else:
                        try:
                            carr_fill.data[ir, it, il, :, :] = res.data
                        except ValueError:
                            if not hasattr(res, 'mask'):
                                raise ValueError('Array return from calculation is not a masked array.')
                        else:
                            carr_fill.mask[ir, it, il, :, :] = res.mask
                    if self.calc_sample_size:
                        ss = self.get_sample_size(calculation_value)
                        carr_fill_sample_size.data[ir, it, il, :, :] = ss.data
                        carr_fill_sample_size.mask[ir, it, il, :, :] = ss.mask

            # Setting the values ensures the mask is updated on the output variables.
            fill.set_value(arr_fill)
            if self.calc_sample_size:
                fill_sample_size.set_value(arr_fill_sample_size)
                fill_sample_size.set_mask(fill.get_mask())

        return {'fill': fill, 'sample_size': fill_sample_size}

    def _iter_conformed_arrays_(self, crosswalk, variable_shape, arr, arr_fill, arr_fill_sample_size):
        """
        Yield arrays conformed to ``STANDARD_DIMENSIONS`` for every combination of extra-dimension indices.

        :param crosswalk: Output of :meth:`_get_dimension_crosswalk_` for the source variable.
        :param variable_shape: Shape of the source variable.
        :param arr: Source value array.
        :param arr_fill: Output value array that receives the calculation result.
        :param arr_fill_sample_size: Sample size output array or ``None``.
        :returns: Yields ``(carr, carr_fill)`` or, when sample sizes are active, ``(carr, carr_fill,
         carr_fill_sample_size)``.
        """
        # Allow sample size array to be set to None.
        if arr_fill_sample_size is None:
            calc_sample_size = False
        else:
            calc_sample_size = self.calc_sample_size

        itr_extra_indices, src_names_extra_removed = self._get_extra_indices_itr_and_src_names_(crosswalk,
                                                                                                variable_shape)

        # Loop for the extra dimensions.
        for indices in itr_extra_indices:
            # Slice for the extra dimensions.
            slc = [slice(None)] * arr.ndim
            for ii in indices:
                slc[ii[0]] = ii[1]
            # NumPy requires a tuple for multi-dimensional indexing; a list of slices is rejected by modern releases.
            slc = tuple(slc)
            extras_removed = arr[slc]
            extras_removed_fill = arr_fill[slc]

            # Swap axes for the calculation values, the fill array for the calculation result, and (potentially) the
            # sample size.
            carr, carr_fill = [broadcast_array_by_dimension_names(t, src_names_extra_removed, STANDARD_DIMENSIONS)
                               for t in [extras_removed, extras_removed_fill]]

            if calc_sample_size:
                extras_removed_fill_sample_size = arr_fill_sample_size[slc]
                carr_fill_sample_size = broadcast_array_by_dimension_names(extras_removed_fill_sample_size,
                                                                           src_names_extra_removed,
                                                                           STANDARD_DIMENSIONS)
                yld = (carr, carr_fill, carr_fill_sample_size)
            else:
                yld = (carr, carr_fill)

            yield yld

    def _set_derived_variable_alias_(self, dv, parent_variables):
        """
        Set the alias of the derived variable. When the field has more than one variable, the first parent variable's
        alias is appended to keep the alias unique and a warning is logged.
        """
        if len(self.field.variables) > 1:
            original_alias = dv.alias
            dv.alias = '{0}_{1}'.format(dv.alias, parent_variables[0].alias)
            msg = 'Alias updated to maintain uniqueness. Changing "{0}" to "{1}".'.format(original_alias, dv.alias)
            ocgis_lh(logger='calc.base', level=logging.WARNING, msg=msg)
@six.add_metaclass(abc.ABCMeta)
class AbstractFieldFunction(AbstractFunction):
    """Base class for functions whose ``calculate`` is invoked once with the function parameters as keywords."""

    @abc.abstractmethod
    def calculate(self, **kwargs):
        """
        Add variables to the internal variable collection. These variables will be treated as data variables in the
        output field. Clients are responsible for creating well-formed variables with appropriate units, etc. as this
        function type bypasses most output variable preparation steps.

        :param kwargs: Typically keyword arguments to multivariate and/or parameterized functions.
        """

    def _execute_(self):
        # Forward the stored parameters directly; whatever ``calculate`` returns is passed back to the caller.
        keywords = self.parms
        result = self.calculate(**keywords)
        return result
@six.add_metaclass(abc.ABCMeta)
class AbstractParameterizedFunction(AbstractFunction):
    """
    Base class for functions accepting parameters.
    """

    #: Set to a tuple containing keys of required parameters. The keys correspond to keys in ``parms_definition``.
    parms_required = None

    def __init__(self, **kwargs):
        super(AbstractParameterizedFunction, self).__init__(**kwargs)
        self.parms = self._format_parms_(self.parms)

    @abc.abstractproperty
    def parms_definition(self):
        """
        A dictionary describing the input parameters with keys corresponding to parameter names and values to their
        types. Set the type to `None` for no type checking.

        >>> {'threshold': float, 'operation': str, 'basis': None}
        """
        dict

    @classmethod
    def validate_definition(cls, definition):
        """
        Validate the keyword arguments of a calculation definition against ``parms_definition`` and
        ``parms_required``.

        :param dict definition: The function definition. Keyword arguments are read from the
         ``constants.CALC_KEY_KEYWORDS`` key.
        :raises: :class:`ocgis.exc.DefinitionValidationError`
        """
        AbstractFunction.validate_definition(definition)

        assert isinstance(definition, dict)
        # Imported here, not at module level, as in the original source.
        from ocgis.ops.parms.definition import Calc

        key = constants.CALC_KEY_KEYWORDS
        if key not in definition:
            msg = 'Keyword arguments are required using the "{0}" key: {1}'.format(key, cls.parms_definition)
            raise DefinitionValidationError(Calc, msg)

        kwds = definition[key]
        try:
            required = cls.required_variables
        except AttributeError:
            # this function likely does not have required variables and is not a multivariate function
            assert not issubclass(cls, AbstractMultivariateFunction)
        else:
            # Required variables are not parameters. Drop them from a copy before checking the keywords.
            kwds = kwds.copy()
            for r in required:
                kwds.pop(r, None)

        if not set(kwds.keys()).issubset(list(cls.parms_definition.keys())):
            msg = 'Keyword arguments incorrect. Correct keyword arguments are: {0}'.format(cls.parms_definition)
            raise DefinitionValidationError(Calc, msg)

        if cls.parms_required is not None:
            for k in cls.parms_required:
                if k not in kwds:
                    msg = 'The keyword parameter "{0}" is required.'.format(k)
                    raise DefinitionValidationError(Calc, msg)

    def _format_parms_(self, values):
        """
        Coerce parameter values to the types declared in ``parms_definition``.

        :param values: A dictionary containing the parameter values to check.
        :type values: dict[str, type]
        :returns: A new dictionary with the formatted values.
        :rtype: dict
        """
        ret = {}
        for k, v in values.items():
            try:
                if isinstance(v, self.parms_definition[k]):
                    formatted = v
                else:
                    formatted = self.parms_definition[k](v)
            # likely a nonetype
            except TypeError as e:
                if self.parms_definition[k] is None:
                    formatted = v
                else:
                    ocgis_lh(exc=e, logger='calc.base')
            # likely a required variable for a multivariate calculation
            except KeyError as e:
                if k in self.required_variables:
                    formatted = values[k]
                else:
                    ocgis_lh(exc=e, logger='calc.base')
            ret[k] = formatted
        return ret
@six.add_metaclass(abc.ABCMeta)
class AbstractUnivariateFunction(AbstractFunction):
    """
    Base class for functions accepting a single univariate input.
    """

    # Some calculations use a temporal group dimension but do not want any temporal aggregation to occur after the
    # calculation.
    should_temporally_aggregate = True

    #: Optional sequence of acceptable string units definitions for input variables. If this is set to ``None``, no unit
    #: validation will occur.
    required_units = None

    def __init__(self, *args, **kwargs):
        super(AbstractUnivariateFunction, self).__init__(*args, **kwargs)

        # Sample sizes are only computed during temporal grouping. Without a group dimension the request is dropped.
        if self.calc_sample_size and self.tgd is None:
            msg = 'Sample sizes not relevant for scalar transforms with no temporal grouping. Setting to False.'
            ocgis_lh(msg=msg, logger='calc.base', level=logging.WARN)
            self.calc_sample_size = False

    def validate_units(self, variable):
        """
        Check the variable's units against ``required_units``. No check occurs if ``required_units`` is ``None``.

        :param variable: The source variable to validate.
        :raises: :class:`ocgis.exc.UnitsValidationError` if none of the required units match.
        """
        if self.required_units is not None:
            matches = [get_are_units_equal_by_string_or_cfunits(variable.units, target, try_cfunits=env.USE_CFUNITS)
                       for target in self.required_units]
            if not any(matches):
                raise UnitsValidationError(variable, self.required_units, self.key)

    def _execute_(self):
        for variable, calculation_name in self.iter_calculation_targets():
            crosswalk = self._get_dimension_crosswalk_(variable)

            fill_dimensions = variable.dimensions
            # Some calculations introduce a temporal group dimension during initialization and perform the temporal
            # aggregation implicit to the function.
            if self.tgd is not None and not self.should_temporally_aggregate:
                time_dimension_name = self.field.time.dimensions[0].name
                dimension_name = get_dimension_names(fill_dimensions)
                time_index = dimension_name.index(time_dimension_name)
                fill_dimensions = list(variable.dimensions)
                fill_dimensions[time_index] = self.tgd.dimensions[0]

            fill = self.get_fill_variable(variable, calculation_name, fill_dimensions, self.file_only)

            if not self.file_only:
                arr = self.get_variable_value(variable)
                arr_fill = self.get_variable_value(fill)

                for yld in self._iter_conformed_arrays_(crosswalk, variable.shape, arr, arr_fill, None):
                    carr, carr_fill = yld
                    res_calculation = self.calculate(carr, **self.parms)
                    carr_fill.data[:] = res_calculation.data
                    carr_fill.mask[:] = res_calculation.mask

                # Setting the values ensures the mask is updated on the output variables.
                fill.set_value(arr_fill)

            if self.tgd is not None and self.should_temporally_aggregate:
                fill = self._get_temporal_agg_fill_(fill, calculation_name, self.file_only, f=self.aggregate_temporal,
                                                    parms={})
            else:
                fill = {'fill': fill}

            self._add_to_collection_(fill)
@six.add_metaclass(abc.ABCMeta)
class AbstractUnivariateSetFunction(AbstractUnivariateFunction):
    """
    Base class for functions operating on a single variable but always reducing input data along the time dimension.
    """

    def aggregate_temporal(self, *args, **kwargs):
        """
        This operation is always implicit to :meth:`~ocgis.calc.base.AbstractFunction.calculate`.

        :raises: NotImplementedError
        """
        raise NotImplementedError('aggregation implicit to calculate method')

    def _execute_(self):
        for variable, calculation_name in self.iter_calculation_targets():
            # This executes a calculation with a temporal aggregation.
            fill = self._get_temporal_agg_fill_(variable, calculation_name, self.file_only)
            # Add the output to the variable collection
            self._add_to_collection_(value=fill)

    @classmethod
    def validate(cls, ops):
        """
        Require a temporal grouping on the operations. When ``ops.calc_grouping`` is ``None``, a
        :class:`ocgis.exc.DefinitionValidationError` is handed to ``ocgis_lh``.

        :type ops: :class:`ocgis.driver.operations.OcgOperations`
        """
        if ops.calc_grouping is None:
            from ocgis.ops.parms.definition import Calc
            msg = 'Set functions must have a temporal grouping.'
            ocgis_lh(exc=DefinitionValidationError(Calc, msg), logger='calc.base')
@six.add_metaclass(abc.ABCMeta)
class AbstractMultivariateFunction(AbstractFunction):
    """
    Base class for functions operating on multivariate inputs. Multivariate functions also double as set functions
    (i.e. they can temporally group). This can be turned off at the class level by setting
    ``time_aggregation_external`` to ``False``.
    """

    #: Optional dictionary mapping unit definitions for required variables.
    #: For example: required_units = {'tas':'fahrenheit','rhs':'percent'}
    required_units = None

    #: If True, time aggregation is external to the calculation and will require running the standard time aggregation
    #: methods.
    time_aggregation_external = True

    def __init__(self, *args, **kwargs):
        if kwargs.get('calc_sample_size') is True:
            exc = SampleSizeNotImplemented(self.__class__,
                                           'Multivariate functions do not calculate sample size at this time.')
            ocgis_lh(exc=exc, logger='calc.base')
        else:
            AbstractFunction.__init__(self, *args, **kwargs)

    @abc.abstractproperty
    def required_variables(self):
        """
        Required property/attribute containing the list of input variables expected by the function.

        >>> ('tas', 'rhs')
        """

    def get_output_units(self, *args, **kwargs):
        """Multivariate outputs have no default units; always returns ``None``."""
        return None

    def _get_slice_and_calculation_(self, f, ir, il, parms, value=None):
        # NOTE(review): ``AbstractFunction`` as visible in this module defines no ``_get_slice_and_calculation_``;
        # the first branch relies on a parent class providing it - confirm before use.
        if self.time_aggregation_external:
            ret = AbstractFunction._get_slice_and_calculation_(self, f, ir, il, parms, value=value)
        else:
            new_parms = {}
            for k, v in parms.items():
                if k in self.required_variables:
                    # Required variables are sliced down to the current realization, temporal group and level.
                    new_parms[k] = v[ir, self._curr_group, il, :, :]
                else:
                    new_parms[k] = v
            cc = f(**new_parms)
            ret = (cc, None)
        return ret

    def _execute_(self):
        # Multivariate unit validation requires both variables as opposed to individual unit validation that typically
        # occurs in the calculation target iteration.
        self.validate_units()

        try:
            variable_names = [self.parms[r] for r in self.required_variables]
        # Try again without the parms dictionary assuming the variables are named appropriately with the parameter
        # dictionary.
        except KeyError:
            variable_names = self.required_variables

        # Get the variable calculation targets out of the field.
        itr = self.iter_calculation_targets(variable_names=variable_names, yield_calculation_name=False,
                                            validate_units=False)
        calculation_targets = {self.required_variables[idx]: var for idx, var in enumerate(itr)}

        # Synchronize the parameters that only contain the variable mappings with the other parameters passed in at
        # calculation initialization.
        keys = list(calculation_targets.keys())
        crosswalks = [self._get_dimension_crosswalk_(calculation_targets[k]) for k in keys]
        variable_shapes = [calculation_targets[k].shape for k in keys]
        arrs = [self.get_variable_value(calculation_targets[k]) for k in keys]
        # The first calculation target is the template for the output variable.
        archetype = calculation_targets[keys[0]]
        fill = self.get_fill_variable(archetype, self.alias, archetype.dimensions, self.file_only,
                                      add_repeat_record_archetype_name=False)
        fill.units = self.get_output_units()
        arr_fill = self.get_variable_value(fill)

        itrs = [self._iter_conformed_arrays_(crosswalks[idx], variable_shapes[idx], arrs[idx], arr_fill, None)
                for idx in range(len(crosswalks))]

        for yld in zip(*itrs):
            # Conformed source arrays are passed to the calculation under their required variable names.
            parms = {}
            for idx in range(len(keys)):
                parms[keys[idx]] = yld[idx][0]
            for k, v in self.parms.items():
                if k not in self.required_variables:
                    parms.update({k: v})
            res = self.calculate(**parms)
            carr_fill = yld[0][1]
            carr_fill.data[:] = res.data
            carr_fill.mask[:] = res.mask

        if not self.file_only:
            fill.set_value(arr_fill)

        if self.tgd is not None:
            fill = self._get_temporal_agg_fill_(fill, fill.name, self.file_only, f=self.aggregate_temporal, parms={},
                                                add_repeat_record_archetype_name=False)
        else:
            fill = {'fill': fill}

        self._add_to_collection_(fill)

    @classmethod
    def validate(cls, ops):
        """
        Validate operations for a multivariate function: sample size calculation is not supported and the required
        variables must be mapped to names present in the request datasets.

        :type ops: :class:`ocgis.driver.operations.OcgOperations`
        """
        if ops.calc_sample_size:
            from ocgis.ops.parms.definition import CalcSampleSize
            exc = DefinitionValidationError(CalcSampleSize,
                                            'Multivariate functions do not calculate sample size at this time.')
            ocgis_lh(exc=exc, logger='calc.base')

        # ensure the required variables are present
        should_raise = False
        for c in ops.calc:
            if c['func'] == cls.key:
                kwds = c['kwds']

                # Check the required variables are keyword arguments.
                if not len(set(kwds.keys()).intersection(set(cls.required_variables))) >= 2:
                    should_raise = True
                    break

                # Ensure the mapped aliases exist.
                fnames = []
                for d in ops.dataset:
                    try:
                        for r in get_iter(d.rename_variable):
                            fnames.append(r)
                    except AttributeError:
                        # Fields do not have a rename variable attribute.
                        fnames += list(d.keys())
                for xx in cls.required_variables:
                    to_check = kwds[xx]
                    if to_check not in fnames:
                        should_raise = True
                        break

        if should_raise:
            from ocgis.ops.parms.definition import Calc
            msg = 'These field names are missing for multivariate function "{0}": {1}.'
            exc = DefinitionValidationError(Calc, msg.format(cls.__name__, cls.required_variables))
            ocgis_lh(exc=exc, logger='calc.base')

    def validate_units(self):
        """
        Check the units of each required variable against ``required_units``. No check occurs if ``required_units``
        is ``None``.

        :raises: :class:`ocgis.exc.UnitsValidationError` on the first mismatch.
        """
        if self.required_units is not None:
            for required_variable in self.required_variables:
                alias_variable = self.parms[required_variable]
                variable = self.field[alias_variable]
                source = variable.units
                target = self.required_units[required_variable]
                match = get_are_units_equal_by_string_or_cfunits(source, target, try_cfunits=env.USE_CFUNITS)
                if not match:
                    raise UnitsValidationError(variable, target, self.key)
@six.add_metaclass(abc.ABCMeta)
class AbstractKeyedOutputFunction(AbstractOcgisObject):
    """Interface for functions that must declare a ``structure_dtype``."""

    @abc.abstractproperty
    def structure_dtype(self):
        """Abstract property to be overloaded by subclasses. The base implementation returns ``None``."""
        return None