Source code for pyretlife.retrieval_plotting.parallel_computation
"""
This module contains the `RetrievalPlottingObject` class, which is the main
class used to generate plots of the pyretlife retrievals.
"""
__author__ = "Konrad, Alei, Molliere, Quanz"
__copyright__ = "Copyright 2022, Konrad, Alei, Molliere, Quanz"
__maintainer__ = "Björn S. Konrad, Eleonora Alei"
__email__ = "konradb@ethz.ch, elalei@phys.ethz.ch"
__status__ = "Development"
# -----------------------------------------------------------------------------
# IMPORTS
# -----------------------------------------------------------------------------
import multiprocessing as mp
import contextlib
import pyretlife.retrieval_plotting.calculate_secondary_quantities as secondary_quantites
# Class that enables the parallelization of a function
[docs]
class parallel():
"""
A class used to represent parallel function execution across multiple processes.
This class enables parallelization of a function across a specified number of processes using
Python's `multiprocessing` module. It manages the creation of processes, distributes tasks, and
collects the results in a shared memory dictionary.
Attributes
----------
num_proc : int
The number of processes to be used for parallel execution.
manager : mp.Manager
A manager object to handle shared data between processes.
result : mp.Manager.dict
A dictionary to store the results from each process.
jobs : list
A list of process objects representing the parallel tasks.
Methods
-------
__init__(num_proc)
Initializes the parallel execution setup with the specified number of processes.
calculate(results_directory, function, function_args)
Executes the given function in parallel across the specified number of processes.
__worker(process, results_directory, function, function_args)
The worker method executed by each process to run the given function and store the results.
"""
[docs]
def __init__(self,num_proc):
"""
Initializes the parallel execution setup.
:param num_proc: The number of processes to be used for parallel execution.
:type num_proc: int
"""
self.num_proc = num_proc
# Define a manager process that collects the
# data from the other processes
self.manager = mp.Manager()
self.result = self.manager.dict()
self.jobs = []
[docs]
def calculate(self,results_directory,function,function_args):
"""
Executes the provided function in parallel using multiple processes.
This method initializes the processes, starts the calculations, and waits until all processes are done.
The results are stored in a shared dictionary and returned once all processes finish.
:param results_directory: Directory where the results will be saved or accessed.
:type results_directory: str
:param function: The function to be executed in parallel.
:type function: callable
:param function_args: Arguments to be passed to the function.
:type function_args: dict
:return: A dictionary containing the results from each process.
:rtype: dict
"""
# Initialize the processes and start the calculation
for process in range(self.num_proc):
p = mp.Process(target=self.__worker, args=(process,results_directory,function,function_args))
self.jobs.append(p)
p.start()
# Wait untill all the processes are done
for proc in self.jobs:
proc.join()
# Return the data to the user
return self.result
def __worker(self,process,results_direectory,function,function_args):
"""
Worker function executed by each process.
This function initializes a new object (e.g., a `retrieval_plotting_object`), executes the provided
function, and stores the result in a shared dictionary.
:param process: The index of the current process.
:type process: int
:param results_directory: Directory where the results will be saved or accessed.
:type results_directory: str
:param function: The function to be executed by this process.
:type function: callable
:param function_args: Arguments to be passed to the function.
:type function_args: dict
"""
# Initialization of a new radtrans object
with contextlib.redirect_stdout(None):
from pyretlife.retrieval_plotting.run_plotting import retrieval_plotting_object
results_temp = retrieval_plotting_object(results_directory = results_direectory)
# Function calculation
function_args['process'] = process
function_args['rp_object'] = results_temp
self.result[process] = getattr(secondary_quantites,function)(**function_args)