PaiGram/utils/queues.py

Ignoring revisions in .git-blame-ignore-revs. Click here to bypass and see the normal blame view.

448 lines
14 KiB
Python
Raw Permalink Normal View History

"""线程安全的队列"""
import asyncio
import sys
from asyncio import QueueEmpty as AsyncQueueEmpty
from asyncio import QueueFull as AsyncQueueFull
from collections import deque
from heapq import heappop, heappush
from queue import Empty as SyncQueueEmpty
from queue import Full as SyncQueueFull
from threading import Condition, Lock
from typing import TYPE_CHECKING, Any, Callable, Deque, Generic, List, NoReturn, Optional, Set, TypeVar
from utils.typedefs import AsyncQueue, SyncQueue
if TYPE_CHECKING:
from asyncio import AbstractEventLoop as EventLoop
__all__ = (
"Queue",
"PriorityQueue",
"LifoQueue",
)
T = TypeVar("T")
OptFloat = Optional[float]
class Queue(Generic[T]):
"""线程安全的同步、异步队列"""
_loop: "EventLoop"
@property
def loop(self) -> "EventLoop":
"""返回该队列的事件循环"""
try:
self._loop = asyncio.get_running_loop()
except RuntimeError as e:
raise RuntimeError("没有正在运行的事件循环, 请在异步函数中使用.") from e
return self._loop
def __init__(self, maxsize: int = 0):
"""初始化队列
Args:
maxsize (int): 队列的大小
Returns:
"""
self._maxsize = maxsize
self._init()
self.unfinished_tasks = 0
self.sync_mutex = Lock()
self.sync_not_empty = Condition(self.sync_mutex)
self.sync_not_full = Condition(self.sync_mutex)
self.all_tasks_done = Condition(self.sync_mutex)
self.async_mutex = asyncio.Lock()
if sys.version_info[:3] == (3, 10, 0):
# 针对 3.10 的 bug
getattr(self.async_mutex, "_get_loop", lambda: None)()
self.async_not_empty = asyncio.Condition(self.async_mutex)
self.async_not_full = asyncio.Condition(self.async_mutex)
self.finished = asyncio.Event()
self.finished.set()
self.closing = False
self.pending = set() # type: Set[asyncio.Future[Any]]
2023-10-08 07:14:33 +00:00
def checked_call_soon_threadsafe(callback: Callable[..., None], *args: Any):
try:
self.loop.call_soon_threadsafe(callback, *args)
except RuntimeError:
pass
self._call_soon_threadsafe = checked_call_soon_threadsafe
2023-10-08 07:14:33 +00:00
def checked_call_soon(callback: Callable[..., None], *args: Any):
if not self.loop.is_closed():
self.loop.call_soon(callback, *args)
self._call_soon = checked_call_soon
self._sync_queue = _SyncQueueProxy(self)
self._async_queue = _AsyncQueueProxy(self)
2023-10-08 07:14:33 +00:00
def close(self):
"""关闭队列"""
with self.sync_mutex:
self.closing = True
for fut in self.pending:
fut.cancel()
self.finished.set() # 取消堵塞全部的 async_q.join()
self.all_tasks_done.notify_all() # 取消堵塞全部的 sync_q.join()
2023-10-08 07:14:33 +00:00
async def wait_closed(self):
if not self.closing:
raise RuntimeError("队列已被关闭")
await asyncio.sleep(0)
if not self.pending:
return
await asyncio.wait(self.pending)
@property
def closed(self) -> bool:
return self.closing and not self.pending
@property
def maxsize(self) -> int:
return self._maxsize
@property
def sync_q(self) -> "_SyncQueueProxy[T]":
return self._sync_queue
@property
def async_q(self) -> "_AsyncQueueProxy[T]":
return self._async_queue
2023-10-08 07:14:33 +00:00
def _init(self):
self._queue = deque() # type: Deque[T]
def qsize(self) -> int:
return len(self._queue)
2023-10-08 07:14:33 +00:00
def put(self, item: T):
self._queue.append(item)
def get(self) -> T:
return self._queue.popleft()
2023-10-08 07:14:33 +00:00
def put_internal(self, item: T):
self.put(item)
self.unfinished_tasks += 1
self.finished.clear()
2023-10-08 07:14:33 +00:00
def notify_sync_not_empty(self):
def f():
with self.sync_mutex:
self.sync_not_empty.notify()
self.loop.run_in_executor(None, f)
2023-10-08 07:14:33 +00:00
def notify_sync_not_full(self):
def f():
with self.sync_mutex:
self.sync_not_full.notify()
fut = asyncio.ensure_future(self.loop.run_in_executor(None, f))
fut.add_done_callback(self.pending.discard)
self.pending.add(fut)
2023-10-08 07:14:33 +00:00
def notify_async_not_empty(self, *, threadsafe: bool):
async def f():
async with self.async_mutex:
self.async_not_empty.notify()
2023-10-08 07:14:33 +00:00
def task_maker():
task = self.loop.create_task(f())
task.add_done_callback(self.pending.discard)
self.pending.add(task)
if threadsafe:
self._call_soon_threadsafe(task_maker)
else:
self._call_soon(task_maker)
2023-10-08 07:14:33 +00:00
def notify_async_not_full(self, *, threadsafe: bool):
async def f():
async with self.async_mutex:
self.async_not_full.notify()
2023-10-08 07:14:33 +00:00
def task_maker():
task = self.loop.create_task(f())
task.add_done_callback(self.pending.discard)
self.pending.add(task)
if threadsafe:
self._call_soon_threadsafe(task_maker)
else:
self._call_soon(task_maker)
2023-10-08 07:14:33 +00:00
def check_closing(self):
if self.closing:
raise RuntimeError("禁止对已关闭的队列进行操作")
# noinspection PyProtectedMember
class _SyncQueueProxy(SyncQueue[T]): # pylint: disable=W0212
"""同步"""
def __init__(self, parent: Queue[T]):
self._parent = parent
@property
def maxsize(self) -> int:
return self._parent.maxsize
@property
def closed(self) -> bool:
return self._parent.closed
2023-10-08 07:14:33 +00:00
def task_done(self):
self._parent.check_closing()
with self._parent.all_tasks_done:
unfinished = self._parent.unfinished_tasks - 1
if unfinished <= 0:
if unfinished < 0:
raise ValueError("task_done() 执行次数过多")
self._parent.all_tasks_done.notify_all()
self._parent.loop.call_soon_threadsafe(self._parent.finished.set)
self._parent.unfinished_tasks = unfinished
2023-10-08 07:14:33 +00:00
def join(self):
self._parent.check_closing()
with self._parent.all_tasks_done:
while self._parent.unfinished_tasks:
self._parent.all_tasks_done.wait()
self._parent.check_closing()
def qsize(self) -> int:
"""返回队列的大致大小(不可靠)"""
return self._parent.qsize()
@property
def unfinished_tasks(self) -> int:
"""返回未完成 task 的数量"""
return self._parent.unfinished_tasks
def empty(self) -> bool:
return not self._parent.qsize()
def full(self) -> bool:
return 0 < self._parent.maxsize <= self._parent.qsize()
2023-10-08 07:14:33 +00:00
def put(self, item: T, block: bool = True, timeout: OptFloat = None):
self._parent.check_closing()
with self._parent.sync_not_full:
if self._parent.maxsize > 0:
if not block:
if self._parent.qsize() >= self._parent.maxsize:
raise SyncQueueFull
elif timeout is None:
while self._parent.qsize() >= self._parent.maxsize:
self._parent.sync_not_full.wait()
elif timeout < 0:
raise ValueError("'timeout' 必须为一个非负数")
else:
time = self._parent.loop.time
end_time = time() + timeout
while self._parent.qsize() >= self._parent.maxsize:
remaining = end_time - time()
if remaining <= 0.0:
raise SyncQueueFull
self._parent.sync_not_full.wait(remaining)
self._parent.put_internal(item)
self._parent.sync_not_empty.notify()
self._parent.notify_async_not_empty(threadsafe=True)
def get(self, block: bool = True, timeout: OptFloat = None) -> T:
self._parent.check_closing()
with self._parent.sync_not_empty:
if not block:
if not self._parent.qsize():
raise SyncQueueEmpty
elif timeout is None:
while not self._parent.qsize():
self._parent.sync_not_empty.wait()
elif timeout < 0:
raise ValueError("'timeout' 必须为一个非负数")
else:
time = self._parent.loop.time
end_time = time() + timeout
while not self._parent.qsize():
remaining = end_time - time()
if remaining <= 0.0:
raise SyncQueueEmpty
self._parent.sync_not_empty.wait(remaining)
item = self._parent.get()
self._parent.sync_not_full.notify()
self._parent.notify_async_not_full(threadsafe=True)
return item
2023-10-08 07:14:33 +00:00
def put_nowait(self, item: T):
return self.put(item, block=False)
def get_nowait(self) -> T:
return self.get(block=False)
# noinspection PyProtectedMember
class _AsyncQueueProxy(AsyncQueue[T]): # pylint: disable=W0212
"""异步"""
def __init__(self, parent: Queue[T]):
self._parent = parent
@property
def closed(self) -> bool:
return self._parent.closed
def qsize(self) -> int:
return self._parent.qsize()
@property
def unfinished_tasks(self) -> int:
return self._parent.unfinished_tasks
@property
def maxsize(self) -> int:
return self._parent.maxsize
def empty(self) -> bool:
return self.qsize() == 0
def full(self) -> bool:
if self._parent.maxsize <= 0:
return False
return self.qsize() >= self._parent.maxsize
async def put(self, item: T) -> None:
self._parent.check_closing()
async with self._parent.async_not_full:
self._parent.sync_mutex.acquire()
locked = True
try:
if self._parent.maxsize > 0:
do_wait = True
while do_wait:
do_wait = self._parent.qsize() >= self._parent.maxsize
if do_wait:
locked = False
self._parent.sync_mutex.release()
await self._parent.async_not_full.wait()
self._parent.sync_mutex.acquire()
locked = True
self._parent.put_internal(item)
self._parent.async_not_empty.notify()
self._parent.notify_sync_not_empty()
finally:
if locked:
self._parent.sync_mutex.release()
2023-10-08 07:14:33 +00:00
def put_nowait(self, item: T):
self._parent.check_closing()
with self._parent.sync_mutex and 0 < self._parent.maxsize <= self._parent.qsize():
raise AsyncQueueFull
self._parent.put_internal(item)
self._parent.notify_async_not_empty(threadsafe=False)
self._parent.notify_sync_not_empty()
async def get(self) -> T:
self._parent.check_closing()
async with self._parent.async_not_empty:
self._parent.sync_mutex.acquire()
locked = True
try:
do_wait = True
while do_wait:
do_wait = self._parent.qsize() == 0
if do_wait:
locked = False
self._parent.sync_mutex.release()
await self._parent.async_not_empty.wait()
self._parent.sync_mutex.acquire()
locked = True
item = self._parent.get()
self._parent.async_not_full.notify()
self._parent.notify_sync_not_full()
return item
finally:
if locked:
self._parent.sync_mutex.release()
def get_nowait(self) -> T:
self._parent.check_closing()
with self._parent.sync_mutex:
if self._parent.qsize() == 0:
raise AsyncQueueEmpty
item = self._parent.get()
self._parent.notify_async_not_full(threadsafe=False)
self._parent.notify_sync_not_full()
return item
2023-10-08 07:14:33 +00:00
def task_done(self):
self._parent.check_closing()
with self._parent.all_tasks_done:
if self._parent.unfinished_tasks <= 0:
raise ValueError("task_done() called too many times")
self._parent.unfinished_tasks -= 1
if self._parent.unfinished_tasks == 0:
self._parent.finished.set()
self._parent.all_tasks_done.notify_all()
async def join(self) -> None:
while True:
with self._parent.sync_mutex:
self._parent.check_closing()
if self._parent.unfinished_tasks == 0:
break
await self._parent.finished.wait()
class PriorityQueue(Queue[T]):
"""优先级队列"""
2023-10-08 07:14:33 +00:00
def _init(self):
self._heap_queue: List[T] = []
def qsize(self) -> int:
return len(self._heap_queue)
2023-10-08 07:14:33 +00:00
def put(self, item: T):
if not isinstance(item, tuple):
if hasattr(item, "priority"):
item = (int(item.priority), item)
else:
try:
item = (int(item), item)
except (TypeError, ValueError):
pass
heappush(self._heap_queue, item)
def get(self) -> T:
return heappop(self._heap_queue)
class LifoQueue(Queue[T]):
"""后进先出队列"""
def qsize(self) -> int:
return len(self._queue)
2023-10-08 07:14:33 +00:00
def put(self, item: T):
self._queue.append(item)
def get(self) -> T:
return self._queue.pop()