# Source code for ocgis.calc.engine

import logging

import numpy as np
from ocgis.base import get_variable_names
from ocgis.calc.base import AbstractMultivariateFunction
from ocgis.calc.eval_function import EvalFunction, MultivariateEvalFunction
from ocgis.util.logging_ocgis import ocgis_lh


class CalculationEngine(object):
    """
    Manages calculation execution.

    :param grouping: Temporal grouping applied before calculating (e.g. ``['month', 'year']``). If ``None``,
     calculations are performed on each element.
    :type grouping: list
    :param funcs: The calculations to perform.
    :type funcs: :class:`list` of `function dictionaries`
    :param bool calc_sample_size: If ``True``, calculate sample sizes for the calculations.
    :param bool spatial_aggregation: Passed through unchanged to each calculation function's constructor.
    :param progress: A progress object to update. Optional.
    :type progress: :class:`~ocgis.util.logging_ocgis.ProgressOcgOperations`
    """

    def __init__(self, grouping, funcs, calc_sample_size=False, spatial_aggregation=False, progress=None):
        self.grouping = grouping
        self.funcs = funcs
        self.calc_sample_size = calc_sample_size
        self.spatial_aggregation = spatial_aggregation
        # Temporal group dimensions keyed by field name. Used (and populated) by ``execute`` when no external
        # dictionary is supplied, so groupings persist across calls.
        self._tgds = {}
        self._progress = progress

    @property
    def has_multivariate_functions(self):
        """``True`` if any requested calculation is a multivariate function or a multivariate eval function."""
        multivariate_classes = [AbstractMultivariateFunction, MultivariateEvalFunction]
        return any(self._check_calculation_members_(self.funcs, k) for k in multivariate_classes)

    @staticmethod
    def _check_calculation_members_(funcs, klass):
        """
        Return ``True`` if a subclass of type ``klass`` is contained in the calculation list.

        :param funcs: Sequence of calculation dictionaries. Each must have a ``'ref'`` key holding a class.
        :param klass: The class to test against (e.g. a calculation base class such as
         ``ocgis.calc.base.AbstractMultivariateFunction``).
        :rtype: bool
        """
        return any(issubclass(f['ref'], klass) for f in funcs)

    def execute(self, coll, file_only=False, tgds=None):
        """
        Run every calculation on every field of ``coll``. Each field in the collection is replaced in place by
        a copy holding the calculation outputs.

        :param coll: The collection to calculate on. Its children are modified in place.
        :type coll: :class:`~ocgis.SpatialCollection`
        :param bool file_only: Forwarded to each calculation function. When ``True``, outgoing calculation
         variables are asserted to hold no computed values.
        :param dict tgds: ``{'field_alias': <temporal group dimension>, ...}``. When provided and the engine
         has a grouping, it must already contain an entry for every field in ``coll``. When ``None``, the
         engine's internal cache is used and populated.
        :returns: ``coll`` (the same object, modified in place).
        """
        from ocgis import VariableCollection

        tgds_to_use = self._get_tgds_(coll, tgds)

        # Iterate over the fields of every container in the collection.
        for ugid, container in list(coll.children.items()):
            for field_name, field in list(container.children.items()):
                new_temporal = self._get_field_temporal_(tgds_to_use, field_name)

                # The collection returned by each calculation's ``execute`` is handed to the next calculation as
                # its ``vc`` argument.
                out_vc = VariableCollection()
                for f in self.funcs:
                    try:
                        ocgis_lh('Calculating: {0}'.format(f['func']), logger='calc.engine')
                        # Initialize the function.
                        function = f['ref'](alias=f['name'], dtype=None, field=field, file_only=file_only,
                                            vc=out_vc, parms=f['kwds'], tgd=new_temporal,
                                            calc_sample_size=self.calc_sample_size,
                                            meta_attrs=f.get('meta_attrs'),
                                            spatial_aggregation=self.spatial_aggregation)
                        # Allow a calculation to create a temporal aggregation after initialization.
                        if new_temporal is None and function.tgd is not None:
                            new_temporal = function.tgd.extract()
                    except KeyError:
                        # Likely an eval function which does not have the name key. Build it from the function
                        # dictionary currently being processed (not always the first one in the list).
                        function = EvalFunction(field=field, file_only=file_only, vc=out_vc, expr=f['func'],
                                                meta_attrs=f.get('meta_attrs'))
                    ocgis_lh('calculation initialized', logger='calc.engine', level=logging.DEBUG)

                    # Return the variable collection from the calculations.
                    out_vc = function.execute()
                    self._check_output_variables_(out_vc, file_only)
                    ocgis_lh('calculation finished', logger='calc.engine', level=logging.DEBUG)

                    # Try to mark progress. Okay if it is not there.
                    try:
                        self._progress.mark()
                    except AttributeError:
                        pass

                # NOTE(review): if ``self.funcs`` is empty, ``function`` is unbound here.
                out_field = function.field.copy()
                function_tag = function.tag

                # Format the returned field. This removes the original data variables and, when a new time
                # dimension exists, replaces the time variable.
                if new_temporal is not None:
                    new_temporal = new_temporal.extract()
                format_return_field(function_tag, out_field, new_temporal=new_temporal)

                # Add the calculation variables.
                for variable in list(out_vc.values()):
                    out_field.add_variable(variable.extract())

                # Tag the calculation data as data variables.
                out_field.append_to_tags(function_tag, list(out_vc.keys()))

                # Update the field if there is a CRS. This will ensure accurate tagging of data variables.
                if out_field.crs is not None:
                    out_field.crs.format_spatial_object(out_field)

                coll.children[ugid].children[field_name] = out_field

        return coll

    def _get_tgds_(self, coll, tgds):
        """
        Select the dictionary holding temporal group dimensions and, if the engine has a grouping, make sure
        it covers every field in ``coll``.

        :param coll: The collection whose fields are grouped.
        :param tgds: An externally supplied dictionary, or ``None`` to use the engine's internal cache.
        :returns: The selected dictionary.
        """
        if tgds is None:
            tgds_to_use = self._tgds
            tgds_overloaded = False
        else:
            tgds_to_use = tgds
            tgds_overloaded = True

        # Group the variables. If grouping is None, calculations are performed on each element.
        if self.grouping is not None:
            ocgis_lh('Setting temporal groups: {0}'.format(self.grouping), 'calc.engine')
            for field in coll.iter_fields():
                if tgds_overloaded:
                    # An externally supplied dictionary must already contain every field.
                    assert field.name in tgds_to_use
                elif field.name not in tgds_to_use:
                    tgds_to_use[field.name] = field.time.get_grouping(self.grouping)

        return tgds_to_use

    def _get_field_temporal_(self, tgds_to_use, field_name):
        """
        Return a copy of the temporal group dimension stored for ``field_name``, or ``None`` if there is none.

        :raises ValueError: (through ``ocgis_lh``) if the engine has a grouping that differs from the stored
         dimension's grouping.
        """
        new_temporal = tgds_to_use.get(field_name)
        if new_temporal is None:
            return None

        new_temporal = new_temporal.copy()
        # If the engine has a grouping, ensure it is equivalent to the new temporal dimension.
        if self.grouping is not None:
            try:
                compare = set(new_temporal.grouping) == set(self.grouping)
            # Types may be unhashable (e.g. seasonal groups given as lists), compare directly.
            except TypeError:
                compare = new_temporal.grouping == self.grouping
            if not compare:
                msg = 'Engine temporal grouping and field temporal grouping are not equivalent. Perhaps ' \
                      'optimizations are incorrect?'
                ocgis_lh(logger='calc.engine', exc=ValueError(msg))
        return new_temporal

    @staticmethod
    def _check_output_variables_(out_vc, file_only):
        """Sanity-check variables returned by a calculation (assertions only; disabled under ``python -O``)."""
        for dv in out_vc.values():
            # Any outgoing variables from a calculation must have an associated data type.
            assert dv.dtype is not None
            # If this is a file only operation, there should be no computed values.
            if file_only:
                assert dv._value is None
def format_return_field(function_tag, out_field, new_temporal=None): # Remove the variables used by the calculation. try: to_remove = get_variable_names(out_field.get_by_tag(function_tag)) except KeyError: # Let this fail quietly as the tag may not exist on incoming fields. pass else: for tr in to_remove: out_field.remove_variable(tr) # Remove the original time variable and replace with the new one if there is a new time dimension. New # time dimensions may not be present for calculations that do not compute one. if new_temporal is not None: out_field.remove_variable(out_field.time) out_field.set_time(new_temporal, force=True)