Source code for s2_sdk._append_session

from __future__ import annotations

import asyncio
from collections import deque
from typing import AsyncIterable, NamedTuple, Self

from s2_sdk._client import HttpClient
from s2_sdk._exceptions import S2ClientError
from s2_sdk._s2s._append_session import run_append_session
from s2_sdk._types import (
    AppendAck,
    AppendInput,
    Compression,
    Retry,
    metered_bytes,
)
from s2_sdk._validators import validate_append_input


class _UnackedBatch(NamedTuple):
    ack_fut: asyncio.Future[AppendAck]
    metered_bytes: int


[docs] class AppendSession: """Session for high-throughput appending with backpressure control. Supports pipelining multiple :class:`AppendInput`\\ s while preserving submission order. Caution: Returned by :meth:`S2Stream.append_session`. Do not instantiate directly. """ __slots__ = ( "_closed", "_client", "_compression", "_error", "_encryption_key", "_permits", "_queue", "_retry", "_stream_name", "_task", "_unacked", ) def __init__( self, client: HttpClient, stream_name: str, retry: Retry, compression: Compression, max_unacked_bytes: int, max_unacked_batches: int | None, encryption_key: str | None = None, ) -> None: self._client = client self._stream_name = stream_name self._retry = retry self._compression = compression self._encryption_key = encryption_key self._permits = _AppendPermits(max_unacked_bytes, max_unacked_batches) self._queue: asyncio.Queue[AppendInput | None] = asyncio.Queue() self._unacked: deque[_UnackedBatch] = deque() self._closed = False self._error: BaseException | None = None self._task = asyncio.get_running_loop().create_task(self._run())
[docs] async def submit(self, inp: AppendInput) -> BatchSubmitTicket: """Submit a batch of records for appending. Waits when backpressure limits are reached. """ self._check_ready() batch_bytes = metered_bytes(inp.records) validate_append_input(len(inp.records), batch_bytes) await self._permits.acquire(batch_bytes) # Re-check after potentially waiting on backpressure. try: self._check_ready() except BaseException: self._permits.release(batch_bytes) raise ack_fut: asyncio.Future[AppendAck] = asyncio.get_running_loop().create_future() self._unacked.append(_UnackedBatch(ack_fut, batch_bytes)) await self._queue.put(inp) return BatchSubmitTicket(ack_fut)
def _check_ready(self) -> None: if self._closed: raise S2ClientError("AppendSession is closed") if self._error is not None: raise self._error
[docs] async def close(self) -> None: """Close the session and wait for all submitted batches to be appended.""" if self._closed: return self._closed = True await self._queue.put(None) await self._task if self._error is not None: raise self._error
async def __aenter__(self) -> Self: return self async def __aexit__(self, exc_type, exc_val, exc_tb) -> bool: await self.close() return False async def _run(self) -> None: try: async for ack in run_append_session( self._client, self._stream_name, self._input_iter(), retry=self._retry, compression=self._compression, ack_timeout=self._client._request_timeout, encryption_key=self._encryption_key, ): self._resolve_next(ack) except BaseException as e: # Unwrap single-exception ExceptionGroups so callers see # the original exception type (e.g. S2ServerError, SeqNumMismatchError). exc = e while isinstance(exc, BaseExceptionGroup) and len(exc.exceptions) == 1: exc = exc.exceptions[0] self._fail_all(exc) async def _input_iter(self) -> AsyncIterable[AppendInput]: while True: item = await self._queue.get() if item is None: return yield item def _resolve_next(self, ack: AppendAck) -> None: unacked = self._unacked.popleft() self._permits.release(unacked.metered_bytes) if not unacked.ack_fut.done(): unacked.ack_fut.set_result(ack) def _fail_all(self, error: BaseException) -> None: self._error = error for unacked in self._unacked: self._permits.release(unacked.metered_bytes) if not unacked.ack_fut.done(): unacked.ack_fut.set_exception(error) self._unacked.clear() # Drain queue while not self._queue.empty(): try: self._queue.get_nowait() except asyncio.QueueEmpty: break
[docs] class BatchSubmitTicket: """Awaitable that resolves to an :class:`AppendAck` once the batch is appended.""" __slots__ = ("_ack_fut",) def __init__(self, ack_fut: asyncio.Future[AppendAck]) -> None: self._ack_fut = ack_fut def __await__(self): return self._ack_fut.__await__()
class _Semaphore: __slots__ = ("_event", "_lock", "_value") def __init__(self, value: int) -> None: self._value = value self._event = asyncio.Event() self._event.set() self._lock = asyncio.Lock() async def acquire(self, n: int) -> None: while True: async with self._lock: if self._value >= n: self._value -= n return self._event.clear() await self._event.wait() def release(self, n: int) -> None: self._value += n self._event.set() class _AppendPermits: __slots__ = ("_bytes", "_count") def __init__(self, max_bytes: int, max_count: int | None = None) -> None: self._bytes = _Semaphore(max_bytes) self._count = _Semaphore(max_count) if max_count is not None else None async def acquire(self, n_bytes: int) -> None: if self._count is not None: await self._count.acquire(1) try: await self._bytes.acquire(n_bytes) except BaseException: if self._count is not None: self._count.release(1) raise def release(self, n_bytes: int) -> None: self._bytes.release(n_bytes) if self._count is not None: self._count.release(1)