Source code for pylag.parallel.mediator

"""
Module containing the derived class MPIMediator, which helps to manage
access to input data during parallel execution.

See Also
--------
pylag.mediator - Serial mediator for serial execution
"""

import numpy as np
import logging
import traceback

# For parallel simulations
from mpi4py import MPI

from pylag.data_types_python import DTYPE_INT
from pylag.file_reader import FileReader
from pylag.file_reader import DiskFileNameReader
from pylag.file_reader import NetCDFDatasetReader
from pylag.mediator import Mediator


class MPIMediator(Mediator):
    """ MPI mediator

    MPI mediator for parallel runs. Only the root process (rank 0) touches
    the file system through a FileReader; all other ranks obtain the values
    they need through MPI broadcasts from the root. Any exception raised on
    the root while accessing data is logged and the whole MPI job is aborted.

    Parameters
    ----------
    config : ConfigParser
        Run configuration object

    data_source : str
        String indicating what type of data the datetime objects will be
        associated with. Options are: 'ocean', 'atmosphere', and 'wave'.

    datetime_start : Datetime
        Simulation start date/time.

    datetime_end : Datetime
        Simulation end date/time.

    Attributes
    ----------
    config : ConfigParser
        Run configuration object

    file_reader : pylag.FileReader or None
        FileReader object on the root process (rank 0); None on every
        other rank.
    """
    def __init__(self, config, data_source, datetime_start, datetime_end):
        self.config = config

        def _make_file_reader():
            file_name_reader = DiskFileNameReader()
            dataset_reader = NetCDFDatasetReader()
            return FileReader(config, data_source, file_name_reader,
                              dataset_reader, datetime_start, datetime_end)

        # Only the root process accesses the file system; _on_root returns
        # None on all other ranks.
        self.file_reader = self._on_root('reading input file',
                                         _make_file_reader)

    # ------------------------------------------------------------------
    # Private helpers shared by all data-access methods
    # ------------------------------------------------------------------

    @staticmethod
    def _on_root(action, compute):
        """Evaluate ``compute()`` on rank 0 only, aborting the job on error.

        Parameters
        ----------
        action : str
            Description of the operation, inserted into the log message
            'Caught exception when <action>. Terminating all tasks ...'.

        compute : callable
            Zero-argument callable. It is only ever called on rank 0, so it
            may safely dereference ``self.file_reader``.

        Returns
        -------
        The result of ``compute()`` on rank 0; None on every other rank.
        """
        comm = MPI.COMM_WORLD
        if comm.Get_rank() != 0:
            return None

        try:
            return compute()
        except Exception:
            logger = logging.getLogger(__name__)
            logger.error('Caught exception when %s. Terminating all tasks ...',
                         action)
            logger.error(traceback.format_exc())
            # A non-zero error code reports the failure to the MPI launcher;
            # mpi4py's default (0) would make the aborted run look successful.
            comm.Abort(1)
            # Abort is not expected to return; re-raise so rank 0 can never
            # continue with undefined data if it somehow does.
            raise

    @classmethod
    def _broadcast_object(cls, action, compute):
        """Evaluate ``compute()`` on rank 0 and ``comm.bcast`` it to all ranks.

        Used for generic Python objects (scalars, tuples, lists, ...).
        """
        value = cls._on_root(action, compute)
        return MPI.COMM_WORLD.bcast(value, root=0)

    @classmethod
    def _broadcast_array(cls, action, compute, var_dims, var_type):
        """Evaluate ``compute()`` on rank 0 and ``comm.Bcast`` the array.

        On rank 0 the array returned by ``compute()`` is cast to
        ``var_type``. Other ranks allocate an ``np.empty(var_dims,
        dtype=var_type)`` receive buffer. The array read on rank 0 is assumed
        to have shape ``var_dims`` (not checked here).
        """
        comm = MPI.COMM_WORLD
        if comm.Get_rank() == 0:
            # order='C' forces a C-contiguous copy so the bytes sent match the
            # element order of the C-ordered np.empty receive buffers below.
            var = cls._on_root(
                action, lambda: compute().astype(var_type, order='C'))
        else:
            var = np.empty(var_dims, dtype=var_type)

        comm.Bcast(var, root=0)
        return var

    # ------------------------------------------------------------------
    # Public interface
    # ------------------------------------------------------------------

    def setup_data_access(self, datetime_start, datetime_end):
        """Call ``file_reader.setup_data_access`` on rank 0 only.

        Nothing is broadcast; other ranks return immediately.
        """
        self._on_root(
            'setting up data access',
            lambda: self.file_reader.setup_data_access(datetime_start,
                                                       datetime_end))

    def update_reading_frames(self, time):
        """Call ``file_reader.update_reading_frames(time)`` on rank 0 only.

        Nothing is broadcast; other ranks return immediately.
        """
        self._on_root(
            'updating reading frames',
            lambda: self.file_reader.update_reading_frames(time))

    def get_dimension_variable(self, var_name):
        """Return ``file_reader.get_dimension_variable(var_name)`` from rank 0,
        broadcast to every rank."""
        return self._broadcast_object(
            'getting dimension variable',
            lambda: self.file_reader.get_dimension_variable(var_name))

    def get_grid_variable(self, var_name, var_dims, var_type):
        """Return grid variable ``var_name`` as a ``var_type`` array on all ranks.

        The variable is read on rank 0 and broadcast; ``var_dims`` is the
        shape of the receive buffer allocated on the other ranks.
        """
        return self._broadcast_array(
            'getting grid variable',
            lambda: self.file_reader.get_grid_variable(var_name),
            var_dims, var_type)

    def get_time_at_last_time_index(self):
        """Return ``file_reader.get_time_at_last_time_index()`` from rank 0,
        broadcast to every rank."""
        return self._broadcast_object(
            'getting last time index',
            lambda: self.file_reader.get_time_at_last_time_index())

    def get_time_at_next_time_index(self):
        """Return ``file_reader.get_time_at_next_time_index()`` from rank 0,
        broadcast to every rank."""
        return self._broadcast_object(
            'getting next time index',
            lambda: self.file_reader.get_time_at_next_time_index())

    def get_grid_variable_dimensions(self, var_name):
        """Return ``file_reader.get_grid_variable_dimensions(var_name)`` from
        rank 0, broadcast to every rank."""
        return self._broadcast_object(
            'getting variable dimensions',
            lambda: self.file_reader.get_grid_variable_dimensions(var_name))

    def get_variable_dimensions(self, var_name, include_time=True):
        """Return ``file_reader.get_variable_dimensions(var_name,
        include_time)`` from rank 0, broadcast to every rank."""
        return self._broadcast_object(
            'getting variable dimensions',
            lambda: self.file_reader.get_variable_dimensions(var_name,
                                                             include_time))

    def get_variable_shape(self, var_name, include_time=True):
        """Return ``file_reader.get_variable_shape(var_name, include_time)``
        from rank 0, broadcast to every rank."""
        return self._broadcast_object(
            'getting variable shape',
            lambda: self.file_reader.get_variable_shape(var_name,
                                                        include_time))

    def get_time_dependent_variable_at_last_time_index(self, var_name,
                                                       var_dims, var_type):
        """Return ``var_name`` at the last time index as a ``var_type`` array.

        Read on rank 0 and broadcast; ``var_dims`` is the shape of the
        receive buffer allocated on the other ranks.
        """
        return self._broadcast_array(
            'getting time variable at last time index',
            lambda: self.file_reader
                        .get_time_dependent_variable_at_last_time_index(
                            var_name),
            var_dims, var_type)

    def get_time_dependent_variable_at_next_time_index(self, var_name,
                                                       var_dims, var_type):
        """Return ``var_name`` at the next time index as a ``var_type`` array.

        Read on rank 0 and broadcast; ``var_dims`` is the shape of the
        receive buffer allocated on the other ranks.
        """
        return self._broadcast_array(
            'getting time variable at next time index',
            lambda: self.file_reader
                        .get_time_dependent_variable_at_next_time_index(
                            var_name),
            var_dims, var_type)

    def get_mask_at_last_time_index(self, var_name, var_dims):
        """Return the mask for ``var_name`` at the last time index.

        Read on rank 0, cast to DTYPE_INT and broadcast; ``var_dims`` is the
        shape of the receive buffer allocated on the other ranks.
        """
        return self._broadcast_array(
            'getting mask at last time index',
            lambda: self.file_reader.get_mask_at_last_time_index(var_name),
            var_dims, DTYPE_INT)

    def get_mask_at_next_time_index(self, var_name, var_dims):
        """Return the mask for ``var_name`` at the next time index.

        Read on rank 0, cast to DTYPE_INT and broadcast; ``var_dims`` is the
        shape of the receive buffer allocated on the other ranks.
        """
        return self._broadcast_array(
            'getting mask at next time index',
            lambda: self.file_reader.get_mask_at_next_time_index(var_name),
            var_dims, DTYPE_INT)