aoc-2022/venv/Lib/site-packages/pebble/pool/channel.py

196 lines
5.9 KiB
Python
Raw Normal View History

# This file is part of Pebble.
# Copyright (c) 2013-2022, Matteo Cafasso
# Pebble is free software: you can redistribute it and/or modify
# it under the terms of the GNU Lesser General Public License
# as published by the Free Software Foundation,
# either version 3 of the License, or (at your option) any later version.
# Pebble is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Lesser General Public License for more details.
# You should have received a copy of the GNU Lesser General Public License
# along with Pebble. If not, see <http://www.gnu.org/licenses/>.
import os
import select
import multiprocessing
import multiprocessing.connection
from contextlib import contextmanager
from typing import Any, Callable, Tuple
class ChannelError(OSError):
    """Raised by the process channel, e.g. when a channel mutex times out."""
def channels(mp_context: multiprocessing.context.BaseContext) -> tuple:
    """Create a connected ``(Channel, WorkerChannel)`` pair.

    Two one-way pipes are created with *mp_context*. The ``Channel`` reads
    from the second pipe and writes into the first; the ``WorkerChannel``
    reads from the first and writes into the second. The ``Channel``'s two
    ends are also handed to the ``WorkerChannel`` as its ``unused``
    connections.
    """
    # First pipe: Channel -> WorkerChannel.
    worker_reader, channel_writer = mp_context.Pipe(duplex=False)
    # Second pipe: WorkerChannel -> Channel.
    channel_reader, worker_writer = mp_context.Pipe(duplex=False)

    channel_end = Channel(channel_reader, channel_writer)
    worker_end = WorkerChannel(worker_reader, worker_writer,
                               (channel_reader, channel_writer), mp_context)
    return channel_end, worker_end
class Channel:
    """One end of a process channel built from two one-way connections.

    Objects are received from ``reader`` and sent through ``writer``.
    """

    def __init__(self, reader: multiprocessing.connection.Connection,
                 writer: multiprocessing.connection.Connection):
        self.reader, self.writer = reader, writer
        # The platform-specific poll function is chosen once per instance.
        self.poll = self._make_poll_method()

    def _make_poll_method(self):
        """Return a ``poll(timeout=None) -> bool`` function for this platform.

        On Windows the reader connection's own ``poll`` is used. Elsewhere
        the reader is polled with ``select.poll``, with *timeout* converted
        from seconds to milliseconds.
        """
        if os.name == 'nt':
            def windows_poll(timeout: float = None) -> bool:
                return self.reader.poll(timeout)
            return windows_poll

        def unix_poll(timeout: float = None) -> bool:
            event_mask = (select.POLLIN | select.POLLPRI
                          | select.POLLHUP | select.POLLERR)
            poller = select.poll()
            poller.register(self.reader, event_mask)
            # select.poll takes milliseconds; None is passed through as-is.
            wait = timeout if timeout is None else timeout * MILLISECONDS
            return bool(poller.poll(wait))
        return unix_poll

    def recv(self) -> Any:
        """Receive the next object from the reader connection."""
        return self.reader.recv()

    def send(self, obj: Any):
        """Send *obj* through the writer connection."""
        return self.writer.send(obj)

    def close(self):
        """Close the reader connection, then the writer connection."""
        for connection in (self.reader, self.writer):
            connection.close()
class WorkerChannel(Channel):
    """Channel end that guards its connections with a ``ChannelMutex``.

    Receiving holds the mutex's reader lock. Sending holds its writer lock,
    except on Windows where sends are unguarded. ``unused`` holds
    connections that :meth:`initialize` closes.
    """

    def __init__(self, reader: multiprocessing.connection.Connection,
                 writer: multiprocessing.connection.Connection,
                 unused: tuple,
                 mp_context: multiprocessing.context.BaseContext):
        super().__init__(reader, writer)
        self.mutex = ChannelMutex(mp_context)
        # Instance-level recv/send replace the unlocked base-class methods.
        self.recv = self._make_recv_method()
        self.send = self._make_send_method()
        self.unused = unused

    def __getstate__(self) -> tuple:
        # Only these four fields are pickled; the poll/recv/send closures
        # are rebuilt in __setstate__.
        return (self.reader, self.writer, self.mutex, self.unused)

    def __setstate__(self, state: tuple):
        reader, writer, mutex, unused = state
        self.reader = reader
        self.writer = writer
        self.mutex = mutex
        self.unused = unused
        self.poll = self._make_poll_method()
        self.recv = self._make_recv_method()
        self.send = self._make_send_method()

    def _make_recv_method(self) -> Callable:
        """Return a ``recv()`` that holds the reader lock while receiving."""
        def recv():
            with self.mutex.reader:
                return self.reader.recv()
        return recv

    def _make_send_method(self) -> Callable:
        """Return a ``send(obj)`` suited to the current platform."""
        if os.name == 'nt':
            # No writer lock is used on Windows: send directly.
            def windows_send(obj: Any):
                return self.writer.send(obj)
            return windows_send

        def unix_send(obj: Any):
            with self.mutex.writer:
                return self.writer.send(obj)
        return unix_send

    @property
    @contextmanager
    def lock(self):
        """Hold ``self.mutex`` for the duration of a ``with`` block."""
        with self.mutex:
            yield self

    def initialize(self):
        """Close unused connections."""
        for conn in self.unused:
            conn.close()
class ChannelMutex:
    """Pair of ``RLock`` objects guarding a channel's reader and writer.

    Both locks come from *mp_context*. On Windows (``os.name == 'nt'``) only
    the reader lock is created and ``writer_mutex`` is ``None``.

    Every acquisition passes ``LOCK_TIMEOUT`` as the lock's ``timeout``. A
    timeout is reported as ``False`` by :meth:`acquire`, and as
    :class:`ChannelError` by ``with mutex:``, ``with mutex.reader:`` and
    ``with mutex.writer:``.
    """

    def __init__(self, mp_context: multiprocessing.context.BaseContext):
        self.reader_mutex = mp_context.RLock()
        # Sends are not locked on Windows, so no writer lock is created there.
        self.writer_mutex = mp_context.RLock() if os.name != 'nt' else None
        self.acquire = self._make_acquire_method()
        self.release = self._make_release_method()

    def __getstate__(self):
        # Only the locks are pickled; the acquire/release closures are
        # rebuilt in __setstate__.
        return self.reader_mutex, self.writer_mutex

    def __setstate__(self, state):
        self.reader_mutex, self.writer_mutex = state
        self.acquire = self._make_acquire_method()
        self.release = self._make_release_method()

    def __enter__(self):
        """Acquire every lock, raising ChannelError on timeout."""
        if self.acquire():
            return self
        raise ChannelError("Channel mutex time out")

    def __exit__(self, *_):
        self.release()

    def _make_acquire_method(self) -> Callable:
        """Return an all-or-nothing ``acquire() -> bool`` for this platform.

        On success every lock is held; on failure none of the locks taken
        by this call remain held.
        """
        def unix_acquire() -> bool:
            if not self.reader_mutex.acquire(timeout=LOCK_TIMEOUT):
                return False
            if self.writer_mutex.acquire(timeout=LOCK_TIMEOUT):
                return True
            # The writer lock timed out: hand the reader lock back. Callers
            # such as __enter__ never call release() after a failed
            # acquire(), so otherwise the reader lock would stay held.
            self.reader_mutex.release()
            return False

        def windows_acquire() -> bool:
            return self.reader_mutex.acquire(timeout=LOCK_TIMEOUT)

        return unix_acquire if os.name != 'nt' else windows_acquire

    def _make_release_method(self) -> Callable:
        """Return a ``release()`` that frees every lock ``acquire`` took."""
        def unix_release():
            self.reader_mutex.release()
            self.writer_mutex.release()

        def windows_release():
            self.reader_mutex.release()

        return unix_release if os.name != 'nt' else windows_release

    @property
    @contextmanager
    def reader(self):
        """Hold only the reader lock; raise ChannelError on timeout."""
        if self.reader_mutex.acquire(timeout=LOCK_TIMEOUT):
            try:
                yield self
            finally:
                self.reader_mutex.release()
        else:
            raise ChannelError("Channel mutex time out")

    @property
    @contextmanager
    def writer(self):
        """Hold only the writer lock; raise ChannelError on timeout.

        Not usable on Windows, where ``writer_mutex`` is ``None``.
        """
        if self.writer_mutex.acquire(timeout=LOCK_TIMEOUT):
            try:
                yield self
            finally:
                self.writer_mutex.release()
        else:
            raise ChannelError("Channel mutex time out")
# Multiplier used by Channel's unix poll to turn a timeout in seconds into
# the milliseconds expected by select.poll.
MILLISECONDS = 1_000
# Timeout handed to every ChannelMutex lock acquisition.
LOCK_TIMEOUT = 60