mirror of
https://github.com/PaiGramTeam/PaiGram.git
synced 2024-12-29 12:22:36 +00:00
448 lines
14 KiB
Python
448 lines
14 KiB
Python
"""Thread-safe queues."""
|
|
|
|
import asyncio
|
|
import sys
|
|
from asyncio import QueueEmpty as AsyncQueueEmpty
|
|
from asyncio import QueueFull as AsyncQueueFull
|
|
from collections import deque
|
|
from heapq import heappop, heappush
|
|
from queue import Empty as SyncQueueEmpty
|
|
from queue import Full as SyncQueueFull
|
|
from threading import Condition, Lock
|
|
from typing import TYPE_CHECKING, Any, Callable, Deque, Generic, List, NoReturn, Optional, Set, TypeVar
|
|
|
|
from utils.typedefs import AsyncQueue, SyncQueue
|
|
|
|
if TYPE_CHECKING:
|
|
from asyncio import AbstractEventLoop as EventLoop
|
|
|
|
# Names exported by ``from <module> import *``.
__all__ = (
    "Queue",
    "PriorityQueue",
    "LifoQueue",
)

# Type variable for the items stored in a queue.
T = TypeVar("T")
# Optional timeout accepted by the blocking ``put`` / ``get`` operations.
OptFloat = Optional[float]
|
|
|
|
|
|
class Queue(Generic[T]):
    """FIFO queue shared between blocking (threaded) code and asyncio code.

    The blocking view is available as :attr:`sync_q`, the coroutine view as
    :attr:`async_q`. Both views operate on the same storage and are guarded by
    ``sync_mutex`` (a ``threading.Lock``) plus ``async_mutex`` (an
    ``asyncio.Lock``).
    """

    _loop: "EventLoop"

    @property
    def loop(self) -> "EventLoop":
        """Return the event loop used by this queue.

        When accessed while an event loop is running in the current thread,
        that loop is remembered and returned. When accessed from a thread
        without a running loop (for example a worker thread using ``sync_q``),
        the most recently remembered loop is returned so that cross-thread
        notifications can still be scheduled on it.

        Raises:
            RuntimeError: no loop is running and none has been remembered yet.
        """
        try:
            self._loop = asyncio.get_running_loop()
        except RuntimeError as e:
            # ``_loop`` is only annotated on the class, so it may not exist yet.
            if getattr(self, "_loop", None) is None:
                raise RuntimeError("没有正在运行的事件循环, 请在异步函数中使用.") from e
        return self._loop

    def __init__(self, maxsize: int = 0):
        """Initialise the queue.

        Args:
            maxsize (int): capacity of the queue; values ``<= 0`` are treated
                as "unbounded" by the ``put`` / ``full`` checks of both views.
        """
        self._maxsize = maxsize

        self._init()

        self.unfinished_tasks = 0

        # Blocking side: one mutex shared by three conditions.
        self.sync_mutex = Lock()
        self.sync_not_empty = Condition(self.sync_mutex)
        self.sync_not_full = Condition(self.sync_mutex)
        self.all_tasks_done = Condition(self.sync_mutex)

        # Asyncio side: one lock shared by two conditions, plus a "joined" event.
        self.async_mutex = asyncio.Lock()
        if sys.version_info[:3] == (3, 10, 0):
            # Workaround for a bug in Python 3.10.0.
            getattr(self.async_mutex, "_get_loop", lambda: None)()
        self.async_not_empty = asyncio.Condition(self.async_mutex)
        self.async_not_full = asyncio.Condition(self.async_mutex)
        self.finished = asyncio.Event()
        self.finished.set()

        self.closing = False
        # Notification futures/tasks that have been scheduled but not finished.
        self.pending: "Set[asyncio.Future[Any]]" = set()

        def checked_call_soon_threadsafe(callback: Callable[..., None], *args: Any):
            # Best effort: a missing or already closed loop must not break the
            # blocking caller, so the RuntimeError is deliberately ignored.
            try:
                self.loop.call_soon_threadsafe(callback, *args)
            except RuntimeError:
                pass

        self._call_soon_threadsafe = checked_call_soon_threadsafe

        def checked_call_soon(callback: Callable[..., None], *args: Any):
            if not self.loop.is_closed():
                self.loop.call_soon(callback, *args)

        self._call_soon = checked_call_soon

        self._sync_queue = _SyncQueueProxy(self)
        self._async_queue = _AsyncQueueProxy(self)

    def close(self):
        """Mark the queue as closing and release everything waiting on it.

        Pending notification futures are cancelled and both kinds of ``join()``
        are woken up.
        """
        with self.sync_mutex:
            self.closing = True
            for fut in self.pending:
                fut.cancel()
            self.finished.set()  # unblock every async_q.join()
            self.all_tasks_done.notify_all()  # unblock every sync_q.join()

    async def wait_closed(self):
        """Wait until all pending notification futures have finished.

        Raises:
            RuntimeError: :meth:`close` has not been called yet.
        """
        if not self.closing:
            # The previous message claimed the queue *was* closed, which is the
            # opposite of the condition being reported.
            raise RuntimeError("队列尚未关闭, 请先调用 close()")
        # Give already scheduled callbacks one loop iteration to run.
        await asyncio.sleep(0)
        if not self.pending:
            return
        await asyncio.wait(self.pending)

    @property
    def closed(self) -> bool:
        """``True`` once the queue is closing and no notification is pending."""
        return self.closing and not self.pending

    @property
    def maxsize(self) -> int:
        """Capacity given at construction time."""
        return self._maxsize

    @property
    def sync_q(self) -> "_SyncQueueProxy[T]":
        """Blocking view of the queue."""
        return self._sync_queue

    @property
    def async_q(self) -> "_AsyncQueueProxy[T]":
        """Coroutine-based view of the queue."""
        return self._async_queue

    # ----- storage hooks (overridden by subclasses) -------------------------

    def _init(self):
        """Create the underlying storage."""
        self._queue: Deque[T] = deque()

    def qsize(self) -> int:
        """Return the number of stored items."""
        return len(self._queue)

    def put(self, item: T):
        """Store ``item`` without any locking or notification."""
        self._queue.append(item)

    def get(self) -> T:
        """Remove and return the oldest item without locking or notification."""
        return self._queue.popleft()

    # ----- helpers used by the two views --------------------------------------

    def put_internal(self, item: T):
        """Store ``item`` and account for it as an unfinished task."""
        self.put(item)
        self.unfinished_tasks += 1
        self.finished.clear()

    def notify_sync_not_empty(self):
        """Wake one blocking ``get`` waiter from the loop's default executor."""

        def f():
            with self.sync_mutex:
                self.sync_not_empty.notify()

        self.loop.run_in_executor(None, f)

    def notify_sync_not_full(self):
        """Wake one blocking ``put`` waiter from the loop's default executor.

        The resulting future is tracked in ``pending`` until it completes.
        """

        def f():
            with self.sync_mutex:
                self.sync_not_full.notify()

        fut = asyncio.ensure_future(self.loop.run_in_executor(None, f))
        fut.add_done_callback(self.pending.discard)
        self.pending.add(fut)

    def notify_async_not_empty(self, *, threadsafe: bool):
        """Schedule a task that wakes one coroutine waiting in ``async_q.get``.

        Args:
            threadsafe (bool): schedule through ``call_soon_threadsafe`` (for
                callers outside the loop thread) instead of ``call_soon``.
        """

        async def f():
            async with self.async_mutex:
                self.async_not_empty.notify()

        def task_maker():
            task = self.loop.create_task(f())
            task.add_done_callback(self.pending.discard)
            self.pending.add(task)

        if threadsafe:
            self._call_soon_threadsafe(task_maker)
        else:
            self._call_soon(task_maker)

    def notify_async_not_full(self, *, threadsafe: bool):
        """Schedule a task that wakes one coroutine waiting in ``async_q.put``.

        Args:
            threadsafe (bool): schedule through ``call_soon_threadsafe`` (for
                callers outside the loop thread) instead of ``call_soon``.
        """

        async def f():
            async with self.async_mutex:
                self.async_not_full.notify()

        def task_maker():
            task = self.loop.create_task(f())
            task.add_done_callback(self.pending.discard)
            self.pending.add(task)

        if threadsafe:
            self._call_soon_threadsafe(task_maker)
        else:
            self._call_soon(task_maker)

    def check_closing(self):
        """Raise ``RuntimeError`` if :meth:`close` has already been called."""
        if self.closing:
            raise RuntimeError("禁止对已关闭的队列进行操作")
|
|
|
|
|
|
# noinspection PyProtectedMember
class _SyncQueueProxy(SyncQueue[T]):  # pylint: disable=W0212
    """Blocking view of a :class:`Queue`, meant to be used from threads."""

    def __init__(self, parent: Queue[T]):
        """Bind this view to ``parent``, which owns the storage and the locks."""
        self._parent = parent

    @property
    def maxsize(self) -> int:
        """Capacity of the parent queue."""
        return self._parent.maxsize

    @property
    def closed(self) -> bool:
        """Whether the parent queue reports itself as closed."""
        return self._parent.closed

    def task_done(self):
        """Mark one previously fetched item as processed.

        When the count of unfinished tasks reaches zero, blocking joiners are
        notified and the parent's ``finished`` event is set on its event loop.

        Raises:
            RuntimeError: the queue is closing.
            ValueError: called more often than items were put.
        """
        parent = self._parent
        parent.check_closing()
        with parent.all_tasks_done:
            unfinished = parent.unfinished_tasks - 1
            if unfinished < 0:
                raise ValueError("task_done() 执行次数过多")
            # Store the counter first so that a failure while notifying cannot
            # leave it out of sync with the number of calls.
            parent.unfinished_tasks = unfinished
            if unfinished == 0:
                parent.all_tasks_done.notify_all()
                # Use the parent's checked helper (as the other cross-thread
                # notifiers do) so an unavailable loop does not raise here.
                parent._call_soon_threadsafe(parent.finished.set)

    def join(self):
        """Block until every item has been marked done or the queue is closed.

        Raises:
            RuntimeError: the queue is (or becomes) closing.
        """
        parent = self._parent
        parent.check_closing()
        with parent.all_tasks_done:
            while parent.unfinished_tasks:
                parent.all_tasks_done.wait()
                parent.check_closing()

    def qsize(self) -> int:
        """Return the approximate size of the queue (not reliable)."""
        return self._parent.qsize()

    @property
    def unfinished_tasks(self) -> int:
        """Number of tasks not yet marked done."""
        return self._parent.unfinished_tasks

    def empty(self) -> bool:
        """Return ``True`` when the queue currently holds no items."""
        return not self._parent.qsize()

    def full(self) -> bool:
        """Return ``True`` when a positive ``maxsize`` has been reached."""
        return 0 < self._parent.maxsize <= self._parent.qsize()

    def put(self, item: T, block: bool = True, timeout: OptFloat = None):
        """Insert ``item``, optionally waiting for a free slot.

        Args:
            item: the value to store.
            block: when ``False``, fail immediately if the queue is full.
            timeout: maximum wait when blocking; ``None`` waits indefinitely.

        Raises:
            RuntimeError: the queue is closing.
            queue.Full: no slot became available.
            ValueError: ``timeout`` is negative (only checked for bounded queues).
        """
        # The deadline is computed with the monotonic clock directly: asking the
        # parent for its event loop fails in threads without a running loop.
        from time import monotonic

        parent = self._parent
        parent.check_closing()
        with parent.sync_not_full:
            if parent.maxsize > 0:
                if not block:
                    if parent.qsize() >= parent.maxsize:
                        raise SyncQueueFull
                elif timeout is None:
                    while parent.qsize() >= parent.maxsize:
                        parent.sync_not_full.wait()
                elif timeout < 0:
                    raise ValueError("'timeout' 必须为一个非负数")
                else:
                    deadline = monotonic() + timeout
                    while parent.qsize() >= parent.maxsize:
                        remaining = deadline - monotonic()
                        if remaining <= 0.0:
                            raise SyncQueueFull
                        parent.sync_not_full.wait(remaining)
            parent.put_internal(item)
            parent.sync_not_empty.notify()
            parent.notify_async_not_empty(threadsafe=True)

    def get(self, block: bool = True, timeout: OptFloat = None) -> T:
        """Remove and return an item, optionally waiting for one.

        Args:
            block: when ``False``, fail immediately if the queue is empty.
            timeout: maximum wait when blocking; ``None`` waits indefinitely.

        Raises:
            RuntimeError: the queue is closing.
            queue.Empty: no item became available.
            ValueError: ``timeout`` is negative.
        """
        # See ``put`` for why the monotonic clock is used directly.
        from time import monotonic

        parent = self._parent
        parent.check_closing()
        with parent.sync_not_empty:
            if not block:
                if not parent.qsize():
                    raise SyncQueueEmpty
            elif timeout is None:
                while not parent.qsize():
                    parent.sync_not_empty.wait()
            elif timeout < 0:
                raise ValueError("'timeout' 必须为一个非负数")
            else:
                deadline = monotonic() + timeout
                while not parent.qsize():
                    remaining = deadline - monotonic()
                    if remaining <= 0.0:
                        raise SyncQueueEmpty
                    parent.sync_not_empty.wait(remaining)
            item = parent.get()
            parent.sync_not_full.notify()
            parent.notify_async_not_full(threadsafe=True)
            return item

    def put_nowait(self, item: T):
        """Insert ``item`` without blocking; equivalent to ``put(item, block=False)``."""
        return self.put(item, block=False)

    def get_nowait(self) -> T:
        """Fetch an item without blocking; equivalent to ``get(block=False)``."""
        return self.get(block=False)
|
|
|
|
|
|
# noinspection PyProtectedMember
class _AsyncQueueProxy(AsyncQueue[T]):  # pylint: disable=W0212
    """Coroutine-based view of a :class:`Queue`, meant to be used on the event loop."""

    def __init__(self, parent: Queue[T]):
        """Bind this view to ``parent``, which owns the storage and the locks."""
        self._parent = parent

    @property
    def closed(self) -> bool:
        """Whether the parent queue reports itself as closed."""
        return self._parent.closed

    def qsize(self) -> int:
        """Return the number of items currently stored."""
        return self._parent.qsize()

    @property
    def unfinished_tasks(self) -> int:
        """Number of tasks not yet marked done."""
        return self._parent.unfinished_tasks

    @property
    def maxsize(self) -> int:
        """Capacity of the parent queue."""
        return self._parent.maxsize

    def empty(self) -> bool:
        """Return ``True`` when the queue currently holds no items."""
        return self.qsize() == 0

    def full(self) -> bool:
        """Return ``True`` when a positive ``maxsize`` has been reached."""
        if self._parent.maxsize <= 0:
            return False
        return self.qsize() >= self._parent.maxsize

    async def put(self, item: T) -> None:
        """Insert ``item``, waiting on the async condition while the queue is full.

        Raises:
            RuntimeError: the queue is closing.
        """
        parent = self._parent
        parent.check_closing()
        async with parent.async_not_full:
            parent.sync_mutex.acquire()
            locked = True
            try:
                if parent.maxsize > 0:
                    while parent.qsize() >= parent.maxsize:
                        # Never hold the thread lock across an ``await``:
                        # release it, wait, then take it again.
                        locked = False
                        parent.sync_mutex.release()
                        await parent.async_not_full.wait()
                        parent.sync_mutex.acquire()
                        locked = True

                parent.put_internal(item)
                parent.async_not_empty.notify()
                parent.notify_sync_not_empty()
            finally:
                if locked:
                    parent.sync_mutex.release()

    def put_nowait(self, item: T):
        """Insert ``item`` immediately or fail.

        Raises:
            RuntimeError: the queue is closing.
            asyncio.QueueFull: the queue is bounded and already full.
        """
        parent = self._parent
        parent.check_closing()
        # The capacity check and the insertion happen under the thread mutex.
        # (Previously the statement was ``with mutex and <bool>:``, which used
        # a bool as context manager and therefore always failed.)
        with parent.sync_mutex:
            if 0 < parent.maxsize <= parent.qsize():
                raise AsyncQueueFull

            parent.put_internal(item)
            parent.notify_async_not_empty(threadsafe=False)
            parent.notify_sync_not_empty()

    async def get(self) -> T:
        """Remove and return an item, waiting on the async condition while empty.

        Raises:
            RuntimeError: the queue is closing.
        """
        parent = self._parent
        parent.check_closing()
        async with parent.async_not_empty:
            parent.sync_mutex.acquire()
            locked = True
            try:
                while parent.qsize() == 0:
                    # Never hold the thread lock across an ``await``.
                    locked = False
                    parent.sync_mutex.release()
                    await parent.async_not_empty.wait()
                    parent.sync_mutex.acquire()
                    locked = True

                item = parent.get()
                parent.async_not_full.notify()
                parent.notify_sync_not_full()
                return item
            finally:
                if locked:
                    parent.sync_mutex.release()

    def get_nowait(self) -> T:
        """Remove and return an item immediately or fail.

        Raises:
            RuntimeError: the queue is closing.
            asyncio.QueueEmpty: the queue holds no items.
        """
        parent = self._parent
        parent.check_closing()
        with parent.sync_mutex:
            if parent.qsize() == 0:
                raise AsyncQueueEmpty

            item = parent.get()
            parent.notify_async_not_full(threadsafe=False)
            parent.notify_sync_not_full()
            return item

    def task_done(self):
        """Mark one previously fetched item as processed.

        When the count reaches zero the ``finished`` event is set and blocking
        joiners are notified.

        Raises:
            RuntimeError: the queue is closing.
            ValueError: called more often than items were put.
        """
        parent = self._parent
        parent.check_closing()
        with parent.all_tasks_done:
            if parent.unfinished_tasks <= 0:
                raise ValueError("task_done() called too many times")
            parent.unfinished_tasks -= 1
            if parent.unfinished_tasks == 0:
                parent.finished.set()
                parent.all_tasks_done.notify_all()

    async def join(self) -> None:
        """Wait until every item has been marked done.

        Raises:
            RuntimeError: the queue is closing when the state is checked.
        """
        parent = self._parent
        while True:
            with parent.sync_mutex:
                parent.check_closing()
                if parent.unfinished_tasks == 0:
                    break
            # Waited on outside the thread mutex; the state is re-checked after
            # every wake-up.
            await parent.finished.wait()
|
|
|
|
|
|
class PriorityQueue(Queue[T]):
    """Queue variant whose storage is a binary heap (smallest entry first)."""

    def _init(self):
        """Create the empty heap that backs this queue."""
        self._heap_queue: List[T] = []

    def qsize(self) -> int:
        """Return the number of entries on the heap."""
        heap = self._heap_queue
        return len(heap)

    def put(self, item: T):
        """Push ``item`` onto the heap, wrapping it with a priority when possible.

        Tuples are pushed unchanged. Any other value is stored as
        ``(int(item.priority), item)`` if it has a ``priority`` attribute,
        otherwise as ``(int(item), item)`` when that conversion succeeds, and
        as the bare value when it does not.
        """
        entry = item
        if not isinstance(entry, tuple):
            if hasattr(entry, "priority"):
                entry = (int(entry.priority), entry)
            else:
                try:
                    entry = (int(entry), entry)
                except (TypeError, ValueError):
                    # Not convertible to an int: keep the value as it is.
                    pass
        heappush(self._heap_queue, entry)

    def get(self) -> T:
        """Pop and return the smallest heap entry.

        The entry is returned exactly as stored, i.e. as the
        ``(priority, item)`` tuple for values that ``put`` wrapped.
        """
        heap = self._heap_queue
        return heappop(heap)
|
|
|
|
|
|
class LifoQueue(Queue[T]):
    """Last-in, first-out variant of the queue."""

    def qsize(self) -> int:
        """Return how many items are currently stored."""
        stack = self._queue
        return len(stack)

    def put(self, item: T):
        """Append ``item`` to the right end of the underlying storage."""
        stack = self._queue
        stack.append(item)

    def get(self) -> T:
        """Remove and return the most recently added item."""
        stack = self._queue
        return stack.pop()
|