Source code for ocgis.vmachine.core

from contextlib import contextmanager

import numpy as np
from ocgis.base import AbstractOcgisObject
from ocgis.constants import MPIOps
from ocgis.exc import SubcommNotFoundError, SubcommAlreadyCreatedError
from ocgis.vmachine.mpi import MPI_COMM, get_nonempty_ranks, MPI_SIZE, MPI_RANK, COMM_NULL, MPI_TYPE_MAPPING, \
    DummyMPIComm


[docs]class OcgVM(AbstractOcgisObject): """ Manages communicators for parallel execution. Provides access to a dummy communicator when running in serial. :param comm: The default communicator. :type comm: MPI Communicator or :class:`~ocgis.vmachine.mpi.DummyMPIComm` """ def __init__(self, comm=None): self._subcomms = {} self._current_comm_name = None if comm is None: comm = MPI_COMM self._comm = comm self._original_comm = comm if isinstance(comm, DummyMPIComm): is_dummy = True else: is_dummy = False self._is_dummy = is_dummy def __del__(self): try: self.finalize() except: pass @property def comm(self): if self.current_comm_name is None: ret = self._comm else: ret = self.get_subcomm(self.current_comm_name) return ret @property def comm_world(self): return MPI_COMM @property def current_comm_name(self): return self._current_comm_name @property def is_null(self): return self.comm == COMM_NULL @property def rank(self): if self.is_null: ret = None else: ret = self.comm.Get_rank() return ret @property def rank_global(self): return MPI_RANK @property def ranks(self): return range(self.size) @property def size(self): if self.is_null: ret = None else: ret = self.comm.Get_size() return ret @property def size_global(self): return MPI_SIZE def abort(self, msg=None, exc=None, int_errorcode=1): try: prefix = "OCGIS MPI Abort (Current Comm Name={}) Message: ".format(self._current_comm_name) if msg is not None: self.rank_print(prefix + msg) if exc is not None: self.rank_print('{}{}: {}'.format(prefix, exc.__class__.__name__, str(exc))) finally: self.comm_world.Abort(int_errorcode) def barrier(self): self.comm.Barrier() def Barrier(self): self.barrier() def bcast(self, *args, **kwargs): return self.comm.bcast(*args, **kwargs) def create_subcomm(self, name, ranks, is_current=False, clobber=False): if self._is_dummy: self._subcomms[name] = self._comm else: if len(ranks) == 0: self._subcomms[name] = COMM_NULL else: the_pool = self.comm.Get_group() sub_group = the_pool.Incl(ranks) try: ret = self.comm.Create(sub_group) if name in self._subcomms: if clobber: vm.free_subcomm(name=name) else: raise SubcommAlreadyCreatedError(name) self._subcomms[name] = ret finally: sub_group.Free() if is_current: self._current_comm_name = name return name def create_subcomm_by_emptyable(self, name, emptyable, **kwargs): live_ranks = self.get_live_ranks_from_object(emptyable) name = self.create_subcomm(name, live_ranks, **kwargs) return name, live_ranks def free_subcomm(self, subcomm=None, name=None): if not self._is_dummy: if subcomm is None: if name not in self._subcomms: raise SubcommNotFoundError(name) subcomm = self._subcomms.pop(name) if subcomm != COMM_NULL: subcomm.Free() def finalize(self): for v in self._subcomms.values(): self.free_subcomm(subcomm=v) self._subcomms = {} self._current_comm_name = None self._comm = self._original_comm def gather(self, *args, **kwargs): return self.comm.gather(*args, **kwargs) def get_live_ranks_from_object(self, target): return get_nonempty_ranks(target, self) @staticmethod def get_mpi_type(other): other = np.dtype(other) if other == np.int32: other = np.int32 elif other == np.int64: other = np.int64 if MPI_TYPE_MAPPING is None: ret = None else: ret = MPI_TYPE_MAPPING[other] return ret def scatter(self, *args, **kwargs): return self.comm.scatter(*args, **kwargs) @staticmethod def barrier_print(*args, **kwargs): from ocgis.vmachine.mpi import barrier_print barrier_print(*args, **kwargs) def get_subcomm(self, name): try: return self._subcomms[name] except KeyError: raise SubcommNotFoundError(name) @staticmethod def rank_print(*args, **kwargs): from ocgis.vmachine.mpi import rank_print rank_print(*args, **kwargs) def reduce(self, target, op, root=0): if self._is_dummy: ret = target else: ret = vm.comm.reduce(target, MPIOps.get_op(op), root=root) return ret def set_comm(self, name=None): self._current_comm_name = name def scoped(self, *args, **kwargs): return vm_scope(self, *args, **kwargs) def scoped_barrier(self, **kwargs): return vm_scoped_barrier(self, **kwargs) def scoped_by_emptyable(self, name, emptyable): live_ranks = self.get_live_ranks_from_object(emptyable) return self.scoped(name, live_ranks) def scoped_by_name(self, name): return vm_scoped_by_name(self, name)
@contextmanager def vm_scoped_barrier(vm_obj, first=True, last=True): if first: vm_obj.barrier() try: yield vm_obj finally: if last: vm_obj.barrier() @contextmanager def vm_scope(vm_obj, name, ranks): original = vm_obj.current_comm_name vm_obj.create_subcomm(name, ranks, is_current=True) try: yield vm_obj finally: vm_obj.free_subcomm(name=name) vm_obj.set_comm(name=original) @contextmanager def vm_scoped_by_name(vm_obj, name): original_name = vm_obj.current_comm_name vm_obj.set_comm(name=name) try: yield vm_obj finally: vm_obj.set_comm(name=original_name) vm = OcgVM()