# Source code for s2_sdk._producer

from __future__ import annotations

import asyncio
from collections import deque
from dataclasses import dataclass
from typing import Self

from s2_sdk._append_session import AppendSession, BatchSubmitTicket
from s2_sdk._batching import BatchAccumulator
from s2_sdk._client import HttpClient
from s2_sdk._exceptions import S2ClientError
from s2_sdk._types import (
    AppendAck,
    AppendInput,
    Batching,
    Compression,
    IndexedAppendAck,
    Record,
    Retry,
)


@dataclass(slots=True)
class _UnackedBatch:
    ticket: BatchSubmitTicket
    indexed_ack_futs: tuple[asyncio.Future[IndexedAppendAck], ...]


class Producer:
    """High-level interface for submitting individual records.

    Handles batching into :class:`AppendInput` automatically and uses an
    append session internally. Can be used as an async context manager;
    leaving the ``async with`` block calls :meth:`close`.

    Caution:
        Returned by :meth:`S2Stream.producer`. Do not instantiate directly.
    """

    __slots__ = (
        "_accumulator",
        "_indexed_ack_futs",
        "_batch_ready",
        "_closed",
        "_drain_task",
        "_error",
        "_fencing_token",
        "_linger_task",
        "_match_seq_num",
        "_unacked",
        "_session",
        "_submit_lock",
        "_submissions_done",
    )

    def __init__(
        self,
        client: HttpClient,
        stream_name: str,
        retry: Retry,
        compression: Compression,
        fencing_token: str | None,
        match_seq_num: int | None,
        max_unacked_bytes: int,
        batching: Batching,
        encryption_key: str | None = None,
    ) -> None:
        """Create the producer, its append session and its ack-draining task.

        Must be called while an event loop is running: the drain task is
        created on it immediately.

        Args:
            client: Forwarded to the internal :class:`AppendSession`.
            stream_name: Forwarded to the internal :class:`AppendSession`.
            retry: Forwarded to the internal :class:`AppendSession`.
            compression: Forwarded to the internal :class:`AppendSession`.
            fencing_token: Attached to every batch this producer submits.
            match_seq_num: If not ``None``, attached to the first batch and
                advanced by each batch's record count for the next batch.
            max_unacked_bytes: Forwarded to the internal
                :class:`AppendSession`.
            batching: Configuration for the :class:`BatchAccumulator`.
            encryption_key: Forwarded to the internal :class:`AppendSession`.
        """
        self._session = AppendSession(
            client=client,
            stream_name=stream_name,
            retry=retry,
            compression=compression,
            max_unacked_bytes=max_unacked_bytes,
            max_unacked_batches=None,
            encryption_key=encryption_key,
        )
        self._fencing_token = fencing_token
        self._match_seq_num = match_seq_num
        self._accumulator = BatchAccumulator(batching)
        # Ack futures for the records currently held in the accumulator,
        # index-aligned with those records.
        self._indexed_ack_futs: list[asyncio.Future[IndexedAppendAck]] = []
        self._linger_task: asyncio.Task[None] | None = None
        # Batches accepted by the session whose acks are not resolved yet,
        # in submission order.
        self._unacked: deque[_UnackedBatch] = deque()
        self._batch_ready = asyncio.Event()
        # Serializes hand-off of batches to the session. asyncio.Lock wakes
        # waiters in FIFO order, so batches reach the session (and
        # ``_unacked``) in the order they were cut from the accumulator, and
        # close() can wait for a flush that is already in flight.
        self._submit_lock = asyncio.Lock()
        # Set by close() once no further batch can be enqueued. Only then may
        # the drain task exit on an empty queue.
        self._submissions_done = False
        self._closed = False
        self._error: BaseException | None = None
        # Created last so every attribute exists before the task first runs.
        self._drain_task = asyncio.get_running_loop().create_task(
            self._drain_acks()
        )

    async def submit(self, record: Record) -> RecordSubmitTicket:
        """Submit a record for appending.

        Waits when backpressure limits are reached.

        Returns:
            A ticket that resolves to the record's :class:`IndexedAppendAck`.

        Raises:
            S2ClientError: If the producer is closed.
            BaseException: An error that previously failed a batch, or the
                error raised while handing a full batch to the session.
        """
        if self._closed:
            raise S2ClientError("Producer is closed")
        if self._error is not None:
            raise self._error
        loop = asyncio.get_running_loop()
        ack_fut: asyncio.Future[IndexedAppendAck] = loop.create_future()
        first_in_batch = self._accumulator.is_empty()
        self._accumulator.add(record)
        # Register the future only once the record is in the accumulator, so
        # futures stay index-aligned with records even if add() raises.
        self._indexed_ack_futs.append(ack_fut)
        linger = self._accumulator.linger
        if self._accumulator.is_full():
            await self._flush()
        elif first_in_batch and linger is not None and linger > 0:
            # First record of a new batch: flush it after the linger delay,
            # unless a full batch or close() flushes it (and cancels this
            # timer) first.
            self._linger_task = loop.create_task(self._linger_flush())
        return RecordSubmitTicket(ack_fut)

    async def close(self) -> None:
        """Close the producer and wait for all submitted records to be appended.

        Calling it again after the first call is a no-op.

        Raises:
            BaseException: The first error recorded by the producer (a failed
                batch, or a failed hand-off to the session), if any.
        """
        if self._closed:
            return
        self._closed = True
        try:
            await self._flush()
            # A linger flush (or a concurrent submit) may still be handing a
            # batch to the session. Taking the submit lock waits for it, and
            # keeps the session open until that batch has been submitted.
            async with self._submit_lock:
                await self._session.close()
        finally:
            # No batch can be enqueued from here on: let the drain task
            # resolve whatever remains and then exit.
            self._submissions_done = True
            self._batch_ready.set()
            await self._drain_task
        if self._error is not None:
            raise self._error

    async def __aenter__(self) -> Self:
        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb) -> bool:
        await self.close()
        return False

    async def _flush(self) -> None:
        """Hand the accumulated records to the session as a single batch.

        No-op when the accumulator is empty. Cancels a pending linger timer.

        Raises:
            BaseException: Whatever the session raised while accepting the
                batch, or a cancellation while waiting to hand it over. The
                error is also stored in ``self._error`` and set on every
                record future of the batch.
        """
        if self._accumulator.is_empty():
            return
        if self._linger_task is not None:
            self._linger_task.cancel()
            self._linger_task = None
        records = self._accumulator.take()
        indexed_ack_futs = tuple(self._indexed_ack_futs)
        self._indexed_ack_futs.clear()
        batch = AppendInput(
            records=records,
            fencing_token=self._fencing_token,
            match_seq_num=self._match_seq_num,
        )
        if self._match_seq_num is not None:
            self._match_seq_num += len(records)
        try:
            # Nothing above yields to the event loop, so the FIFO lock is
            # requested in the same order batches were cut and their
            # match_seq_num values assigned.
            async with self._submit_lock:
                ticket = await self._session.submit(batch)
                self._enqueue_batch(ticket, indexed_ack_futs)
        except BaseException as e:
            self._error = e
            for ack_fut in indexed_ack_futs:
                if not ack_fut.done():
                    ack_fut.set_exception(e)
                    # Suppress "Future exception was never retrieved" for
                    # futures the caller never got back (submit raised).
                    ack_fut.exception()
            raise

    def _enqueue_batch(
        self,
        ticket: BatchSubmitTicket,
        indexed_ack_futs: tuple[asyncio.Future[IndexedAppendAck], ...],
    ) -> None:
        """Queue a batch accepted by the session for the drain task.

        If the drain task has already exited (e.g. after a failed batch),
        nothing would ever resolve these futures, so fail them right away.
        """
        if self._drain_task.done():
            exc = (
                self._error
                if self._error is not None
                else S2ClientError("Producer is closed")
            )
            for ack_fut in indexed_ack_futs:
                if not ack_fut.done():
                    ack_fut.set_exception(exc)
            return
        self._unacked.append(
            _UnackedBatch(ticket=ticket, indexed_ack_futs=indexed_ack_futs)
        )
        self._batch_ready.set()

    async def _drain_acks(self) -> None:
        """Single background task that resolves batches in FIFO order.

        Exits when close() has signalled that no more batches will be
        enqueued and the queue is empty. It also exits after the first failed
        batch, failing every batch still queued behind it.
        """
        while True:
            while not self._unacked:
                if self._submissions_done:
                    return
                self._batch_ready.clear()
                await self._batch_ready.wait()
            unacked = self._unacked.popleft()
            try:
                ack: AppendAck = await unacked.ticket  # type: ignore[assignment]
                for i, ack_fut in enumerate(unacked.indexed_ack_futs):
                    if not ack_fut.done():
                        ack_fut.set_result(
                            IndexedAppendAck(
                                seq_num=ack.start.seq_num + i,
                                batch=ack,
                            )
                        )
            except BaseException as e:
                self._error = e
                # Fail this batch and every batch queued behind it.
                for failed in (unacked, *self._unacked):
                    for ack_fut in failed.indexed_ack_futs:
                        if not ack_fut.done():
                            ack_fut.set_exception(e)
                self._unacked.clear()
                return

    async def _linger_flush(self) -> None:
        """Flush the current batch once the linger delay has elapsed."""
        assert self._accumulator.linger is not None
        await asyncio.sleep(self._accumulator.linger)
        # Clear before calling _flush() so it doesn't cancel this task.
        self._linger_task = None
        try:
            await self._flush()
        except Exception:
            # _flush() has already stored the error in self._error and failed
            # the batch's futures; the next submit()/close() re-raises it.
            # Re-raising here would only surface as "Task exception was never
            # retrieved". Cancellation (BaseException) still propagates.
            pass
class RecordSubmitTicket:
    """Awaitable handle for a single submitted record.

    Awaiting the ticket yields the record's :class:`IndexedAppendAck` once
    the record has been appended, or raises the exception its future was
    failed with.
    """

    __slots__ = ("_ack_fut",)

    def __init__(self, ack_fut: asyncio.Future[IndexedAppendAck]) -> None:
        # The ticket does not resolve anything itself; it only exposes the
        # future as an awaitable.
        self._ack_fut = ack_fut

    def __await__(self):
        fut = self._ack_fut
        return fut.__await__()