120 lines
4.6 KiB
Python
120 lines
4.6 KiB
Python
"""inspired from https://quantlane.com/blog/ensure-asyncio-task-exceptions-get-logged/"""
|
|
from __future__ import annotations
|
|
|
|
import asyncio
|
|
from collections.abc import Awaitable, Coroutine, Generator
|
|
from typing import Any, Callable, TypeVar, cast
|
|
|
|
from . import core
|
|
from .logging import log
|
|
|
|
# Strong references to every task started via ``create``; a done callback discards each entry again.
running_tasks: set[asyncio.Task] = set()
# The task currently registered for each lazy-task name (written by ``create_lazy``).
lazy_tasks_running: dict[str, asyncio.Task] = {}
# At most one queued coroutine per lazy-task name; it is started once the task registered for that name is done.
lazy_coroutines_waiting: dict[str, Coroutine[Any, Any, Any]] = {}
# Tasks created from ``_AwaitOnShutdown`` awaitables; ``teardown`` does not cancel these explicitly.
_await_tasks_on_shutdown: set[asyncio.Task] = set()
|
|
|
|
|
|
def create(coroutine: Awaitable, *, name: str = 'unnamed task', handle_exceptions: bool = True) -> asyncio.Task:
    """Schedule an awaitable as a task on the core event loop and keep track of it.

    The task is stored in ``running_tasks`` until it is done, so that it is not garbage collected mid-execution.
    See https://docs.python.org/3/library/asyncio-task.html#asyncio.create_task.

    :param coroutine: the coroutine or awaitable to wrap
    :param name: the name of the task which is helpful for debugging (default: "unnamed task")
    :param handle_exceptions: if ``True`` (default) possible exceptions are forwarded to the global exception handlers
    :return: the newly created task
    """
    assert core.loop is not None
    if asyncio.iscoroutine(coroutine):
        runnable = coroutine
    else:
        # NOTE: plain awaitables are wrapped into a coroutine (no timeout) before being handed to the loop
        runnable = asyncio.wait_for(coroutine, None)
    task: asyncio.Task = core.loop.create_task(runnable, name=name)

    if handle_exceptions:
        task.add_done_callback(_handle_exceptions)

    # bookkeeping: each registry entry is removed again by a done callback
    running_tasks.add(task)
    task.add_done_callback(running_tasks.discard)
    if isinstance(coroutine, _AwaitOnShutdown):
        _await_tasks_on_shutdown.add(task)
        task.add_done_callback(_await_tasks_on_shutdown.discard)
    return task
|
|
|
|
|
|
def create_lazy(coroutine: Awaitable, *, name: str) -> None:
    """Start a named task, or queue the coroutine if a task with the same name is still running.

    Only one coroutine is queued per name: queuing another one closes and replaces the one queued before.
    The queued coroutine is started once the running task with that name is done.

    :param coroutine: the coroutine or awaitable to run
    :param name: the name used both for the task and for detecting an already running task
    """
    def start_queued(_: asyncio.Task) -> None:
        # done callback: release the name and start whatever has been queued in the meantime
        lazy_tasks_running.pop(name)
        queued = lazy_coroutines_waiting.pop(name, None)
        if queued is not None:
            create_lazy(queued, name=name)

    if name not in lazy_tasks_running:
        task = create(coroutine, name=name)
        lazy_tasks_running[name] = task
        task.add_done_callback(start_queued)
        return

    superseded = lazy_coroutines_waiting.get(name)
    if superseded is not None:
        superseded.close()  # the previously queued coroutine will never run
    lazy_coroutines_waiting[name] = _ensure_coroutine(coroutine)
|
|
|
|
|
|
class _AwaitOnShutdown:
    """Awaitable that defers calling a factory until it is awaited.

    ``create`` recognizes instances of this class and registers their tasks in ``_await_tasks_on_shutdown``.
    """

    def __init__(self, factory: Callable[[], Awaitable[Any]]) -> None:
        # the factory is only stored here; it is called in ``__await__``
        self._factory = factory

    def __await__(self) -> Generator[Any, None, Any]:
        awaitable = self._factory()
        return awaitable.__await__()
|
|
|
|
|
|
F = TypeVar('F', bound=Callable[..., Awaitable[Any]])


def await_on_shutdown(func: F) -> F:
    """Tag a coroutine function so tasks created from it won't be cancelled during shutdown.

    Calling the decorated function does not call ``func`` right away.
    Instead it returns an ``_AwaitOnShutdown`` awaitable which calls ``func`` with the given arguments once awaited.
    The name, docstring and other metadata of ``func`` are copied to the returned function.

    *Added in version 2.16.0*

    :param func: the coroutine function to tag
    :return: a function accepting the same arguments which returns the tagged awaitable
    """
    import functools  # NOTE: imported locally so this function does not depend on the module's import block

    @functools.wraps(func)
    def wrapper(*args: Any, **kwargs: Any) -> Awaitable[Any]:
        return _AwaitOnShutdown(lambda: func(*args, **kwargs))
    return cast(F, wrapper)
|
|
|
|
|
|
def _ensure_coroutine(awaitable: Awaitable[Any]) -> Coroutine[Any, Any, Any]:
    """Return ``awaitable`` unchanged if it is a coroutine, otherwise a coroutine that awaits it.

    :param awaitable: the coroutine or awaitable to convert
    :return: the very same object for coroutines, a new wrapping coroutine for any other awaitable
    """
    async def await_it() -> Any:
        result = await awaitable
        return result

    # NOTE: the wrapping coroutine object is only created when it is actually needed
    return awaitable if asyncio.iscoroutine(awaitable) else await_it()
|
|
|
|
|
|
def _handle_exceptions(task: asyncio.Task) -> None:
    """Done callback which forwards an exception raised by the task to ``core.app.handle_exception``.

    Cancellation is not treated as an error and is ignored silently.

    :param task: the task whose result is inspected
    """
    try:
        task.result()  # re-raises whatever the task raised
    except asyncio.CancelledError:
        return
    except Exception as error:
        core.app.handle_exception(error)
|
|
|
|
|
|
async def teardown() -> None:
    """Cancel all running tasks and coroutines on shutdown. (For internal use only.)

    Tasks registered in ``_await_tasks_on_shutdown`` are not cancelled explicitly;
    they are awaited together with all other tasks (subject to the same timeout).
    Each round waits at most two seconds and the loop repeats
    until both ``running_tasks`` and ``lazy_tasks_running`` are empty.
    Finally, the coroutines still queued in ``lazy_coroutines_waiting`` are closed.
    """
    while running_tasks or lazy_tasks_running:
        # NOTE: this union is never empty here because of the loop condition
        tasks = running_tasks | set(lazy_tasks_running.values())
        for task in tasks:
            if not task.done() and task not in _await_tasks_on_shutdown:
                task.cancel()
        await asyncio.sleep(0)  # NOTE: ensure the loop can cancel the tasks before it shuts down
        try:
            await asyncio.wait_for(asyncio.gather(*tasks, return_exceptions=True), timeout=2.0)
        except asyncio.TimeoutError:
            # report only the tasks which are still unfinished so that the count matches the listed names
            pending = [t.get_name() for t in tasks if not t.done()]
            log.error('Could not cancel %s tasks within timeout: %s', len(pending), ', '.join(pending))
        except Exception:
            log.exception('Error while cancelling tasks')
    for coro in lazy_coroutines_waiting.values():
        coro.close()
|