Bases: TaskGroup
Asynchronous context manager for managing groups of tasks.
See python asyncio task groups documentation.
Adapted for pulsefire; key differences from `asyncio.TaskGroup`:
- Accepts a semaphore to restrict the amount of concurrent running coroutines.
- Due to semaphore support, the `create_task` method is now async.
- Allows internal collection of results and exceptions, similar to `asyncio.Task`.
- If exception collection is on (default), the task group will not abort on task exceptions.
Example:
async with TaskGroup(asyncio.Semaphore(100)) as tg:
    await tg.create_task(coro_func(...))
results = tg.results()
Source code in pulsefire/taskgroups.py
| class TaskGroup(asyncio.TaskGroup):
"""Asynchronous context manager for managing groups of tasks.
See [python asyncio task groups documentation](https://docs.python.org/3/library/asyncio-task.html#task-groups).
Adapted for pulsefire, key differences from `asyncio.TaskGroup`:
- Accepts a semaphore to restrict the amount of concurrent running coroutines.
- Due to semaphore support, the `create_task` method is now async.
- Allows internal collection of results and exceptions, similar to `asyncio.Task`.
- If exception collection is on (default), the task group will not abort on task exceptions.
Example:
```python
async with TaskGroup(asyncio.Semaphore(100)) as tg:
await tg.create_task(coro_func(...))
results = tg.results()
```
"""
semaphore: asyncio.Semaphore | None = None
"""Semaphore for restricting concurrent running coroutines."""
collect_results: bool = True
"""Flag for collecting task results."""
collect_exceptions: bool = True
"""Flag for collecting task exceptions, disables abort."""
def __init__(
self,
semaphore: asyncio.Semaphore | None = None,
*,
collect_results: bool = True,
collect_exceptions: bool = True,
) -> None:
super().__init__()
self.semaphore = semaphore
self.collect_results = collect_results
self.collect_exceptions = collect_exceptions
self._exceptions: list[BaseException] = []
self._results = []
async def __aenter__(self):
self._exceptions = []
self._results = []
return await super().__aenter__()
def results[T](self) -> list[T]:
"""Return the collected results returned from created tasks."""
if not self.collect_results:
raise RuntimeError(f"TaskGroup {self!r} has `collect_results` off")
return self._results
def exceptions(self) -> list[BaseException]:
"""Return the collected exceptions raised from created tasks."""
if not self.collect_exceptions:
raise RuntimeError(f"TaskGroup {self!r} has `collect_exceptions` off")
return self._exceptions
@override
async def create_task[T](self, coro: Awaitable[T], *, name: str | None = None, context: Context | None = None) -> asyncio.Task[T]:
"""Create a new task in this group and return it.
If this group has a semaphore, wrap this semaphore on the coroutine.
"""
_coro = coro
if self.semaphore:
await self.semaphore.acquire()
async def semaphored():
try:
return await _coro
finally:
self.semaphore.release()
coro = semaphored()
return super().create_task(coro, name=name, context=context)
def _on_task_done(self, task) -> None:
if task.cancelled():
return super()._on_task_done(task)
if exc := task.exception():
if self.collect_exceptions:
LOGGER.warning(
"TaskGroup: unhandled exception\n" +
"".join(traceback.format_exception(type(exc), exc, exc.__traceback__))
)
self._exceptions.append(exc)
self._tasks.discard(task)
if self._on_completed_fut is not None and not self._tasks:
if not self._on_completed_fut.done():
self._on_completed_fut.set_result(True)
return
elif self.collect_results:
self._results.append(task.result())
return super()._on_task_done(task)
|
Attributes
semaphore
class-attribute
instance-attribute
semaphore: Semaphore | None = semaphore
Semaphore for restricting concurrent running coroutines.
collect_results
class-attribute
instance-attribute
collect_results: bool = collect_results
Flag for collecting task results.
collect_exceptions
class-attribute
instance-attribute
collect_exceptions: bool = collect_exceptions
Flag for collecting task exceptions, disables abort.
Functions
__init__
__init__(
semaphore: asyncio.Semaphore | None = None,
*,
collect_results: bool = True,
collect_exceptions: bool = True
) -> None
Source code in pulsefire/taskgroups.py
| def __init__(
self,
semaphore: asyncio.Semaphore | None = None,
*,
collect_results: bool = True,
collect_exceptions: bool = True,
) -> None:
super().__init__()
self.semaphore = semaphore
self.collect_results = collect_results
self.collect_exceptions = collect_exceptions
self._exceptions: list[BaseException] = []
self._results = []
|
__aenter__
async
Source code in pulsefire/taskgroups.py
| async def __aenter__(self):
self._exceptions = []
self._results = []
return await super().__aenter__()
|
results
Return the collected results returned from created tasks.
Source code in pulsefire/taskgroups.py
| def results[T](self) -> list[T]:
"""Return the collected results returned from created tasks."""
if not self.collect_results:
raise RuntimeError(f"TaskGroup {self!r} has `collect_results` off")
return self._results
|
exceptions
exceptions() -> list[BaseException]
Return the collected exceptions raised from created tasks.
Source code in pulsefire/taskgroups.py
| def exceptions(self) -> list[BaseException]:
"""Return the collected exceptions raised from created tasks."""
if not self.collect_exceptions:
raise RuntimeError(f"TaskGroup {self!r} has `collect_exceptions` off")
return self._exceptions
|
create_task
async
create_task(
coro: Awaitable[T],
*,
name: str | None = None,
context: Context | None = None
) -> asyncio.Task[T]
Create a new task in this group and return it.
If this group has a semaphore, wrap this semaphore on the coroutine.
Source code in pulsefire/taskgroups.py
| @override
async def create_task[T](self, coro: Awaitable[T], *, name: str | None = None, context: Context | None = None) -> asyncio.Task[T]:
"""Create a new task in this group and return it.
If this group has a semaphore, wrap this semaphore on the coroutine.
"""
_coro = coro
if self.semaphore:
await self.semaphore.acquire()
async def semaphored():
try:
return await _coro
finally:
self.semaphore.release()
coro = semaphored()
return super().create_task(coro, name=name, context=context)
|