diff --git a/loky/__init__.py b/loky/__init__.py
index a3b30c09..a14e2ae4 100644
--- a/loky/__init__.py
+++ b/loky/__init__.py
@@ -20,6 +20,7 @@
 from .reusable_executor import get_reusable_executor
 from .cloudpickle_wrapper import wrap_non_picklable_objects
 from .process_executor import BrokenProcessPool, ProcessPoolExecutor
+from .worker_id import get_worker_id
 
 
 __all__ = [
@@ -37,6 +38,7 @@
     "FIRST_EXCEPTION",
     "ALL_COMPLETED",
     "wrap_non_picklable_objects",
+    "get_worker_id",
     "set_loky_pickler",
 ]
 
diff --git a/loky/process_executor.py b/loky/process_executor.py
index 1e08cc21..fc3db908 100644
--- a/loky/process_executor.py
+++ b/loky/process_executor.py
@@ -384,6 +384,7 @@ def _process_worker(
     timeout,
     worker_exit_lock,
     current_depth,
+    worker_id,
 ):
     """Evaluates calls from call_queue and places the results in result_queue.
 
@@ -420,6 +421,9 @@
     _last_memory_leak_check = None
     pid = os.getpid()
 
+    # set the worker_id environment variable
+    os.environ["LOKY_WORKER_ID"] = str(worker_id)
+
     mp.util.debug(f"Worker started with timeout={timeout}")
     while True:
         try:
@@ -562,6 +566,9 @@ def weakref_cb(
         # A list of the ctx.Process instances used as workers.
         self.processes = executor._processes
 
+        # A dict mapping worker pids to worker IDs
+        self.process_worker_ids = executor._process_worker_ids
+
         # A ctx.Queue that will be filled with _CallItems derived from
         # _WorkItems for processing by the process workers.
         self.call_queue = executor._call_queue
@@ -727,6 +734,7 @@ def process_result_item(self, result_item):
             # itself: we should not mark the executor as broken.
             with self.processes_management_lock:
                 p = self.processes.pop(result_item, None)
+                self.process_worker_ids.pop(result_item, None)
 
                 # p can be None if the executor is concurrently shutting down.
                 if p is not None:
@@ -830,7 +838,9 @@ def kill_workers(self, reason=""):
         # terminates descendant workers of the children in case there is some
         # nested parallelism.
        while self.processes:
-            _, p = self.processes.popitem()
+            pid, p = self.processes.popitem()
+            self.process_worker_ids.pop(pid, None)
+
             mp.util.debug(f"terminate process {p.name}, reason: {reason}")
             try:
                 kill_process_tree(p)
@@ -1101,8 +1111,10 @@ def __init__(
         # Map of pids to processes
         self._processes = {}
 
+        # Map of pids to process worker IDs
+        self._process_worker_ids = {}
+
         # Internal variables of the ProcessPoolExecutor
-        self._processes = {}
         self._queue_count = 0
         self._pending_work_items = {}
         self._running_work_items = []
@@ -1183,9 +1195,21 @@ def _start_executor_manager_thread(self):
                    _python_exit
                )
 
+    def _get_available_worker_id(self):
+        if _CURRENT_DEPTH > 0:
+            return -1
+
+        used_ids = set(self._process_worker_ids.values())
+        available_ids = set(range(self._max_workers)) - used_ids
+        if len(available_ids):
+            return available_ids.pop()
+        else:
+            return -1
+
     def _adjust_process_count(self):
         while len(self._processes) < self._max_workers:
             worker_exit_lock = self._context.BoundedSemaphore(1)
+            worker_id = self._get_available_worker_id()
             args = (
                 self._call_queue,
                 self._result_queue,
@@ -1195,8 +1219,10 @@ def _adjust_process_count(self):
                 self._timeout,
                 worker_exit_lock,
                 _CURRENT_DEPTH + 1,
+                worker_id,
             )
             worker_exit_lock.acquire()
+
             try:
                 # Try to spawn the process with some environment variable to
                 # overwrite but it only works with the loky context for now.
@@ -1208,6 +1234,7 @@ def _adjust_process_count(self):
             p._worker_exit_lock = worker_exit_lock
             p.start()
             self._processes[p.pid] = p
+            self._process_worker_ids[p.pid] = worker_id
             mp.util.debug(
                 f"Adjusted process count to {self._max_workers}: "
                 f"{[(p.name, pid) for pid, p in self._processes.items()]}"
diff --git a/loky/worker_id.py b/loky/worker_id.py
new file mode 100644
index 00000000..8d5e0f4a
--- /dev/null
+++ b/loky/worker_id.py
@@ -0,0 +1,15 @@
+import os
+
+
+def get_worker_id():
+    """Get the worker ID of the current process.
+
+    For a `ReusableExecutor` with `max_workers=n`, the worker ID is in the
+    range [0, n). This is suited for reuse of persistent objects such as GPU
+    IDs. This function only works at the first level of parallelization (i.e.
+    not for nested parallelization). Resizing the `ReusableExecutor` will
+    result in unpredictable return values.
+
+    Returns -1 when the process is not a worker.
+    """
+    return int(os.environ.get('LOKY_WORKER_ID', -1))
diff --git a/tests/test_worker_id.py b/tests/test_worker_id.py
new file mode 100644
index 00000000..4d2f95a3
--- /dev/null
+++ b/tests/test_worker_id.py
@@ -0,0 +1,36 @@
+import time
+import pytest
+import numpy as np
+from collections import defaultdict
+from loky import get_reusable_executor, get_worker_id
+
+
+def random_sleep(args):
+    k, max_duration = args
+    rng = np.random.RandomState(seed=k)
+    duration = rng.uniform(0, max_duration)
+    t0 = time.time()
+    time.sleep(duration)
+    t1 = time.time()
+    wid = get_worker_id()
+    return (wid, t0, t1)
+
+
+@pytest.mark.parametrize("max_duration,timeout,kmax", [(0.05, 2, 100),
+                                                       (1, 0.01, 4)])
+def test_worker_ids(max_duration, timeout, kmax):
+    """Test that worker IDs are always unique, with re-use over time"""
+    num_workers = 4
+    executor = get_reusable_executor(max_workers=num_workers, timeout=timeout)
+    results = executor.map(random_sleep, [(k, max_duration)
+                                          for k in range(kmax)])
+
+    all_intervals = defaultdict(list)
+    for wid, t0, t1 in results:
+        assert wid in set(range(num_workers))
+        all_intervals[wid].append((t0, t1))
+
+    for intervals in all_intervals.values():
+        intervals = sorted(intervals)
+        for i in range(len(intervals) - 1):
+            assert intervals[i + 1][0] >= intervals[i][1]
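
For context, here is a usage sketch (not part of the patch) of how the new `get_worker_id()` API can map each first-level worker to a fixed resource slot. The `DEVICES` list, the `square_on_device` helper, and the `"cpu"` fallback are illustrative assumptions, not part of loky.

from loky import get_reusable_executor, get_worker_id

# Hypothetical per-worker resource slots, one entry per worker (illustrative only).
DEVICES = ["cuda:0", "cuda:1", "cuda:2", "cuda:3"]


def square_on_device(x):
    # Inside a first-level loky worker, get_worker_id() returns a stable
    # index in [0, max_workers); in the main process or in nested workers it
    # returns -1, so fall back to a default resource in that case.
    wid = get_worker_id()
    device = DEVICES[wid] if wid >= 0 else "cpu"
    return x * x, device


if __name__ == "__main__":
    executor = get_reusable_executor(max_workers=len(DEVICES))
    print(list(executor.map(square_on_device, range(8))))

Because `_get_available_worker_id()` hands out IDs not currently held by a live worker, a respawned worker (e.g. after a timeout) reclaims a freed slot, so the index stays within [0, max_workers), which is what the interval-overlap check in `test_worker_ids` exercises.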