Source code for s2_sdk._batching

from __future__ import annotations

import asyncio
from typing import AsyncIterable

from s2_sdk._types import AppendInput, Batching, Record, metered_bytes
from s2_sdk._validators import validate_batching


class BatchAccumulator:
    """Buffer records for one batch and track their metered size.

    The limits (``max_records``, ``max_bytes``) and the linger duration are
    read from the ``Batching`` configuration supplied at construction.
    """

    __slots__ = ("_batching", "_bytes", "_records")

    def __init__(self, batching: Batching) -> None:
        self._batching = batching
        self._records: list[Record] = []
        self._bytes = 0

    def add(self, record: Record) -> None:
        """Append ``record`` and add its metered size to the running total."""
        self._records.append(record)
        size = metered_bytes((record,))
        self._bytes += size

    def take(self) -> list[Record]:
        """Return the buffered records and reset the accumulator to empty."""
        drained, self._records = self._records, []
        self._bytes = 0
        return drained

    def is_full(self) -> bool:
        """Return ``True`` once either the record or the byte limit is reached."""
        limits = self._batching
        if len(self._records) >= limits.max_records:
            return True
        return self._bytes >= limits.max_bytes

    def is_empty(self) -> bool:
        """Return ``True`` when no records are buffered."""
        return not self._records

    @property
    def linger(self) -> float:
        """Configured linger duration, in seconds."""
        return self._batching.linger.total_seconds()


async def append_record_batches(
    records: AsyncIterable[Record],
    *,
    batching: Batching | None = None,
) -> AsyncIterable[list[Record]]:
    """Group records into batches based on count, bytes, and linger time.

    A batch is opened when its first record arrives and is emitted when the
    accumulator reports full, when the linger deadline (measured from that
    first record) passes, or when ``records`` is exhausted. With a linger of
    zero or less there is no deadline, so a batch is emitted only when it is
    full or the source ends.

    If ``records`` raises an :class:`Exception`, any partially filled batch
    is yielded first and the exception is then re-raised.

    :param records: Source of records. A ``None`` item is treated as the end
        of the stream, because ``None`` is the exhaustion sentinel passed to
        ``anext``.
    :param batching: Batching limits; a default ``Batching()`` is used when
        omitted.
    :returns: An async iterator of non-empty record lists.
    """
    if batching is None:
        batching = Batching()
    # NOTE(review): presumably raises on invalid limits -- confirm in _validators.
    validate_batching(batching.max_records, batching.max_bytes)

    acc = BatchAccumulator(batching)
    linger_secs = batching.linger.total_seconds()
    record_iter = records.__aiter__()
    # In-flight fetch of the next record. It is either carried over to the
    # next batch after a linger timeout, or cancelled in ``finally``.
    pending_next: asyncio.Task | None = None

    try:
        while True:
            # First record of a batch: reuse a fetch left over from the
            # previous batch's linger timeout, otherwise wait indefinitely.
            if pending_next is not None:
                record = await pending_next
                pending_next = None
            else:
                record = await anext(record_iter, None)
            if record is None:
                break
            # NOTE(review): records are added before the fullness check, so a
            # batch can exceed ``max_bytes`` by up to one record -- confirm
            # that this is acceptable downstream.
            acc.add(record)

            deadline = (
                asyncio.get_running_loop().time() + linger_secs
                if linger_secs > 0
                else None
            )

            while not acc.is_full():
                if deadline is not None:
                    remaining = deadline - asyncio.get_running_loop().time()
                    if remaining <= 0:
                        break
                    # Register the fetch *before* awaiting: asyncio.wait()
                    # does not cancel the task it waits on, so if this
                    # generator is cancelled mid-wait the ``finally`` block
                    # must still be able to find and cancel the task.
                    pending_next = asyncio.create_task(anext(record_iter, None))
                    done, _ = await asyncio.wait({pending_next}, timeout=remaining)
                    if not done:
                        # Linger expired; the fetch stays pending and will
                        # supply the first record of the next batch.
                        break
                    finished, pending_next = pending_next, None
                    record = finished.result()
                else:
                    record = await anext(record_iter, None)
                if record is None:
                    break
                acc.add(record)

            yield acc.take()
    except Exception:
        # Flush what was already accumulated before surfacing the error.
        if not acc.is_empty():
            yield acc.take()
        raise
    finally:
        if pending_next is not None:
            pending_next.cancel()
async def append_inputs(
    records: AsyncIterable[Record],
    *,
    match_seq_num: int | None = None,
    fencing_token: str | None = None,
    batching: Batching | None = None,
) -> AsyncIterable[AppendInput]:
    """Group records into :class:`AppendInput` batches based on count, bytes, and linger time.

    If ``match_seq_num`` is set, it applies to the first input and is
    auto-incremented for subsequent ones.

    :param records: Source of records to batch.
    :param match_seq_num: Sequence number attached to the first input; each
        later input receives the previous value plus the previous batch's
        record count. ``None`` is passed through unchanged to every input.
    :param fencing_token: Passed through unchanged to every input.
    :param batching: Batching limits; a default ``Batching()`` is used when
        omitted.
    :returns: An async iterator of :class:`AppendInput`, one per non-empty
        batch.
    """
    effective_batching = Batching() if batching is None else batching
    expected_seq_num = match_seq_num

    async for chunk in append_record_batches(records, batching=effective_batching):
        # Empty batches carry nothing to append and are skipped.
        if chunk:
            item = AppendInput(
                records=chunk,
                match_seq_num=expected_seq_num,
                fencing_token=fencing_token,
            )
            if expected_seq_num is not None:
                # Advance by the number of records consumed by this input.
                expected_seq_num += len(chunk)
            yield item