# This file is part of Pebble.
# Copyright (c) 2013-2022, Matteo Cafasso

# Pebble is free software: you can redistribute it and/or modify
# it under the terms of the GNU Lesser General Public License
# as published by the Free Software Foundation,
# either version 3 of the License, or (at your option) any later version.

# Pebble is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Lesser General Public License for more details.

# You should have received a copy of the GNU Lesser General Public License
# along with Pebble. If not, see <http://www.gnu.org/licenses/>.

import os
import time
import atexit
import pickle
import multiprocessing

from itertools import count
from collections import namedtuple
from signal import SIG_IGN, SIGINT, signal
from typing import Any, Callable, Optional
from concurrent.futures import CancelledError, TimeoutError
from concurrent.futures.process import BrokenProcessPool

from pebble.pool.base_pool import ProcessMapFuture, map_results
from pebble.pool.base_pool import CREATED, ERROR, RUNNING, SLEEP_UNIT
from pebble.pool.base_pool import Worker, iter_chunks, run_initializer
from pebble.pool.base_pool import PoolContext, BasePool, Task, TaskPayload
from pebble.pool.channel import ChannelError, WorkerChannel, channels
from pebble.common import launch_process, stop_process
from pebble.common import ProcessExpired, ProcessFuture
from pebble.common import process_execute, launch_thread


class ProcessPool(BasePool):
    """Allows scheduling of jobs within a Pool of Processes.

    max_workers is an integer representing the number of desired worker
    processes managed by the pool.
    If max_tasks is a number greater than zero,
    each worker will be restarted after performing that number of tasks.

    initializer must be a callable; if passed, it will be called
    every time a worker is started, receiving initargs as arguments.

    """
    def __init__(self, max_workers: int = multiprocessing.cpu_count(),
                 max_tasks: int = 0,
                 initializer: Callable = None,
                 initargs: list = (),
                 context: multiprocessing.context.BaseContext = None):
        super().__init__(max_workers, max_tasks, initializer, initargs)
        mp_context = multiprocessing if context is None else context
        self._pool_manager = PoolManager(self._context, mp_context)
        self._task_scheduler_loop = None
        self._pool_manager_loop = None
        self._message_manager_loop = None

    def _start_pool(self):
        with self._context.state_mutex:
            if self._context.state == CREATED:
                self._pool_manager.start()

                self._task_scheduler_loop = launch_thread(
                    None, task_scheduler_loop, True, self._pool_manager)
                self._pool_manager_loop = launch_thread(
                    None, pool_manager_loop, True, self._pool_manager)
                self._message_manager_loop = launch_thread(
                    None, message_manager_loop, True, self._pool_manager)

                self._context.state = RUNNING

    def _stop_pool(self):
        if self._pool_manager_loop is not None:
            self._pool_manager_loop.join()
        self._pool_manager.stop()
        if self._task_scheduler_loop is not None:
            self._task_scheduler_loop.join()
        if self._message_manager_loop is not None:
            self._message_manager_loop.join()

    def schedule(self, function: Callable,
                 args: list = (),
                 kwargs: dict = {},
                 timeout: float = None) -> ProcessFuture:
        """Schedules *function* to be run in the Pool.

        *args* and *kwargs* will be forwarded to the scheduled function
        respectively as arguments and keyword arguments.

        *timeout* is a number of seconds: if it expires, the task will be
        terminated and *Future.result()* will raise *TimeoutError*.

        A *pebble.ProcessFuture* object is returned.

        """
        self._check_pool_state()

        future = ProcessFuture()
        payload = TaskPayload(function, args, kwargs)
        task = Task(next(self._task_counter), future, timeout, payload)

        self._context.task_queue.put(task)

        return future

    def submit(self, function: Callable,
               timeout: Optional[float],
               *args, **kwargs) -> ProcessFuture:
        """This function is provided for compatibility with
        `asyncio.loop.run_in_executor`.

        For scheduling jobs within the pool use `schedule` instead.

        """
        return self.schedule(
            function, args=args, kwargs=kwargs, timeout=timeout)

    def map(self, function: Callable,
            *iterables, **kwargs) -> ProcessMapFuture:
        """Computes the *function* using arguments from
        each of the iterables. Stops when the shortest iterable is exhausted.

        *timeout* is a number of seconds: if it expires, the task will be
        terminated and the call to next will raise *TimeoutError*.
        The *timeout* is applied to each chunk of the iterable.

        *chunksize* controls the size of the chunks the iterable will
        be broken into before being passed to the function.

        A *pebble.ProcessFuture* object is returned.

        """
        self._check_pool_state()

        timeout = kwargs.get('timeout')
        chunksize = kwargs.get('chunksize', 1)

        if chunksize < 1:
            raise ValueError("chunksize must be >= 1")

        futures = [self.schedule(
            process_chunk, args=(function, chunk), timeout=timeout)
            for chunk in iter_chunks(chunksize, *iterables)]

        return map_results(ProcessMapFuture(futures), timeout)


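# The three loops below run in the daemon threads started by
# ProcessPool._start_pool: one dispatches queued tasks to the workers,
# one supervises task timeouts, cancellations and worker health,
# and one consumes the messages sent back by the worker processes.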
def task_scheduler_loop(pool_manager: 'PoolManager'):
    context = pool_manager.context
    task_queue = context.task_queue

    try:
        while context.alive and not GLOBAL_SHUTDOWN:
            task = task_queue.get()

            if task is not None:
                if task.future.cancelled():
                    task.set_running_or_notify_cancel()
                    task_queue.task_done()
                else:
                    pool_manager.schedule(task)
            else:
                task_queue.task_done()
    except BrokenProcessPool:
        context.state = ERROR


def pool_manager_loop(pool_manager: 'PoolManager'):
    context = pool_manager.context

    try:
        while context.alive and not GLOBAL_SHUTDOWN:
            pool_manager.update_status()
            time.sleep(SLEEP_UNIT)
    except BrokenProcessPool:
        context.state = ERROR


def message_manager_loop(pool_manager: 'PoolManager'):
    context = pool_manager.context

    try:
        while context.alive and not GLOBAL_SHUTDOWN:
            pool_manager.process_next_message(SLEEP_UNIT)
    except BrokenProcessPool:
        context.state = ERROR


class PoolManager:
    """Combines the Task and Worker managers providing a higher level interface."""
    def __init__(self, context: PoolContext,
                 mp_context: multiprocessing.context.BaseContext):
        self.context = context
        self.task_manager = TaskManager(context.task_queue.task_done)
        self.worker_manager = WorkerManager(context.workers,
                                            context.worker_parameters,
                                            mp_context)

    def start(self):
        self.worker_manager.create_workers()

    def stop(self):
        self.worker_manager.close_channels()
        self.worker_manager.stop_workers()

    def schedule(self, task: Task):
        """Schedules a new Task in the PoolManager."""
        self.task_manager.register(task)
        try:
            self.worker_manager.dispatch(task)
        except (pickle.PicklingError, TypeError) as error:
            self.task_manager.task_problem(task.id, error)

    def process_next_message(self, timeout: float):
        """Processes the next message coming from the workers."""
        message = self.worker_manager.receive(timeout)

        if isinstance(message, Acknowledgement):
            self.task_manager.task_start(message.task, message.worker)
        elif isinstance(message, Result):
            self.task_manager.task_done(message.task, message.result)
        elif isinstance(message, Problem):
            self.task_manager.task_problem(message.task, message.error)

    def update_status(self):
        self.update_tasks()
        self.update_workers()

    def update_tasks(self):
        """Handles timing out Tasks."""
        for task in self.task_manager.timeout_tasks():
            self.worker_manager.stop_worker(task.worker_id)
            self.task_manager.task_done(
                task.id, TimeoutError("Task timeout", task.timeout))

        for task in self.task_manager.cancelled_tasks():
            self.worker_manager.stop_worker(task.worker_id)
            self.task_manager.task_done(
                task.id, CancelledError())

    def update_workers(self):
        """Handles unexpected process terminations."""
        for expiration in self.worker_manager.inspect_workers():
            self.handle_worker_expiration(expiration)

        self.worker_manager.create_workers()

    def handle_worker_expiration(self, expiration: tuple):
        worker_id, exitcode = expiration

        try:
            task = self.find_expired_task(worker_id)
        except LookupError:
            return
        else:
            error = ProcessExpired('Abnormal termination', code=exitcode)
            self.task_manager.task_done(task.id, error)

    def find_expired_task(self, worker_id: int) -> Task:
        tasks = tuple(self.task_manager.tasks.values())
        running_tasks = tuple(t for t in tasks if t.worker_id != 0)

        if running_tasks:
            return task_worker_lookup(running_tasks, worker_id)
        else:
            raise BrokenProcessPool("All workers expired")


class TaskManager:
    """Manages the task flow within the Pool.

    Tasks are registered, acknowledged and completed.
    Timing out and cancelled tasks are handled as well.
    """

    def __init__(self, task_done_callback: Callable):
        self.tasks = {}
        self.task_done_callback = task_done_callback

    def register(self, task: Task):
        self.tasks[task.id] = task

    def task_start(self, task_id: int, worker_id: Optional[int]):
        task = self.tasks[task_id]
        task.worker_id = worker_id
        task.timestamp = time.time()
        task.set_running_or_notify_cancel()

    def task_done(self, task_id: int, result: Any):
        """Sets the task result and runs the done callback."""
        try:
            task = self.tasks.pop(task_id)
        except KeyError:
            return  # result of a task which previously timed out
        else:
            if task.future.cancelled():
                task.set_running_or_notify_cancel()
            elif isinstance(result, BaseException):
                task.future.set_exception(result)
            else:
                task.future.set_result(result)

        self.task_done_callback()

    def task_problem(self, task_id: int, error: Exception):
        """Sets as the task result the error it caused within the Pool."""
        self.task_start(task_id, None)
        self.task_done(task_id, error)

    def timeout_tasks(self) -> tuple:
        return tuple(t for t in tuple(self.tasks.values()) if self.timeout(t))

    def cancelled_tasks(self) -> tuple:
        return tuple(t for t in tuple(self.tasks.values())
                     if t.timestamp != 0 and t.future.cancelled())

    @staticmethod
    def timeout(task: Task) -> bool:
        if task.timeout and task.started:
            return time.time() - task.timestamp > task.timeout
        else:
            return False


class WorkerManager:
    """Manages the worker related mechanics within the Pool.

    Keeps the workers alive and encapsulates their communication logic.
    """

    def __init__(self, workers: int,
                 worker_parameters: Worker,
                 mp_context: multiprocessing.context.BaseContext):
        self.workers = {}
        self.workers_number = workers
        self.worker_parameters = worker_parameters
        self.pool_channel, self.workers_channel = channels(mp_context)
        self.mp_context = mp_context

    def dispatch(self, task: Task):
        try:
            self.pool_channel.send(WorkerTask(task.id, task.payload))
        except (pickle.PicklingError, TypeError) as error:
            raise error
        except OSError as error:
            raise BrokenProcessPool from error

    def receive(self, timeout: float):
        try:
            if self.pool_channel.poll(timeout):
                return self.pool_channel.recv()
            else:
                return NoMessage()
        except (OSError, TypeError) as error:
            raise BrokenProcessPool from error
        except EOFError:  # Pool shutdown
            return NoMessage()

    def inspect_workers(self) -> tuple:
        """Updates the workers' status.

        Returns the workers which have unexpectedly ended.

        """
        workers = tuple(self.workers.values())
        expired = tuple(w for w in workers if not w.is_alive())

        for worker in expired:
            self.workers.pop(worker.pid)

        return tuple((w.pid, w.exitcode) for w in expired if w.exitcode != 0)

    def create_workers(self):
        for _ in range(self.workers_number - len(self.workers)):
            self.new_worker()

    def close_channels(self):
        self.pool_channel.close()
        self.workers_channel.close()

    def stop_workers(self):
        for worker_id in tuple(self.workers.keys()):
            self.stop_worker(worker_id, force=True)

    def new_worker(self):
        try:
            worker = launch_process(
                WORKERS_NAME, worker_process, False, self.mp_context,
                self.worker_parameters, self.workers_channel)
            self.workers[worker.pid] = worker
        except OSError as error:
            raise BrokenProcessPool from error

    def stop_worker(self, worker_id: int, force=False):
        try:
            if force:
                stop_process(self.workers.pop(worker_id))
            else:
                with self.workers_channel.lock:
                    stop_process(self.workers.pop(worker_id))
        except ChannelError as error:
            raise BrokenProcessPool from error
        except KeyError:
            return  # worker already expired


def worker_process(params: Worker, channel: WorkerChannel):
    """The worker process routines."""
    signal(SIGINT, SIG_IGN)

    channel.initialize()

    if params.initializer is not None:
        if not run_initializer(params.initializer, params.initargs):
            os._exit(1)

    try:
        for task in worker_get_next_task(channel, params.max_tasks):
            payload = task.payload
            result = process_execute(
                payload.function, *payload.args, **payload.kwargs)
            send_result(channel, Result(task.id, result))
    except (OSError, RuntimeError) as error:
        errno = getattr(error, 'errno', 1)
        os._exit(errno if isinstance(errno, int) else 1)
    except EOFError:
        os._exit(0)


def worker_get_next_task(channel: WorkerChannel, max_tasks: int):
    counter = count()

    while max_tasks == 0 or next(counter) < max_tasks:
        yield fetch_task(channel)


def send_result(channel: WorkerChannel, result: Any):
    """Send result handling pickling and communication errors."""
    try:
        channel.send(result)
    except (pickle.PicklingError, TypeError) as error:
        channel.send(Problem(result.task, error))


def fetch_task(channel: WorkerChannel) -> Task:
    while channel.poll():
        try:
            return task_transaction(channel)
        except RuntimeError:
            continue  # another worker got the task


def task_transaction(channel: WorkerChannel) -> Task:
    """Ensures a task is fetched and acknowledged atomically."""
    with channel.lock:
        if channel.poll(0):
            task = channel.recv()
            channel.send(Acknowledgement(os.getpid(), task.id))
        else:
            raise RuntimeError("Race condition between workers")

    return task


def task_worker_lookup(running_tasks: tuple, worker_id: int) -> Task:
    for task in running_tasks:
        if task.worker_id == worker_id:
            return task

    raise LookupError("Not found")


def process_chunk(function: Callable, chunk: list) -> list:
    """Processes a chunk of the iterable passed to map, dealing with errors."""
    return [process_execute(function, *args) for args in chunk]


def interpreter_shutdown():
    global GLOBAL_SHUTDOWN
    GLOBAL_SHUTDOWN = True

    workers = [p for p in multiprocessing.active_children()
               if p.name == WORKERS_NAME]

    for worker in workers:
        stop_process(worker)


atexit.register(interpreter_shutdown)


GLOBAL_SHUTDOWN = False
WORKERS_NAME = 'pebble_pool_worker'
NoMessage = namedtuple('NoMessage', ())
Result = namedtuple('Result', ('task', 'result'))
Problem = namedtuple('Problem', ('task', 'error'))
WorkerTask = namedtuple('WorkerTask', ('id', 'payload'))
Acknowledgement = namedtuple('Acknowledgement', ('worker', 'task'))