diff --git a/docs/source/api-reference.md b/docs/source/api-reference.md index 569a06c..056953f 100644 --- a/docs/source/api-reference.md +++ b/docs/source/api-reference.md @@ -92,6 +92,8 @@ .. autoenum:: AppendRetryPolicy +.. autoenum:: Encryption + .. autoenum:: StorageClass .. autoenum:: TimestampingMode diff --git a/src/s2_sdk/__init__.py b/src/s2_sdk/__init__.py index 0a52f00..2d0850c 100644 --- a/src/s2_sdk/__init__.py +++ b/src/s2_sdk/__init__.py @@ -29,6 +29,7 @@ Batching, CommandRecord, Compression, + Encryption, Endpoints, ExactMatch, Gauge, @@ -84,6 +85,7 @@ "Record", "AppendInput", "AppendAck", + "Encryption", "IndexedAppendAck", "StreamPosition", "SeqNum", diff --git a/src/s2_sdk/_append_session.py b/src/s2_sdk/_append_session.py index e26f10a..4d52151 100644 --- a/src/s2_sdk/_append_session.py +++ b/src/s2_sdk/_append_session.py @@ -37,6 +37,7 @@ class AppendSession: "_client", "_compression", "_error", + "_encryption_key", "_permits", "_queue", "_retry", @@ -53,11 +54,13 @@ def __init__( compression: Compression, max_unacked_bytes: int, max_unacked_batches: int | None, + encryption_key: str | None = None, ) -> None: self._client = client self._stream_name = stream_name self._retry = retry self._compression = compression + self._encryption_key = encryption_key self._permits = _AppendPermits(max_unacked_bytes, max_unacked_batches) self._queue: asyncio.Queue[AppendInput | None] = asyncio.Queue() @@ -121,6 +124,7 @@ async def _run(self) -> None: retry=self._retry, compression=self._compression, ack_timeout=self._client._request_timeout, + encryption_key=self._encryption_key, ): self._resolve_next(ack) except BaseException as e: diff --git a/src/s2_sdk/_mappers.py b/src/s2_sdk/_mappers.py index e473d57..7ba91c9 100644 --- a/src/s2_sdk/_mappers.py +++ b/src/s2_sdk/_mappers.py @@ -11,6 +11,7 @@ BasinConfig, BasinInfo, BasinScope, + Encryption, ExactMatch, Gauge, Label, @@ -47,6 +48,8 @@ def basin_config_to_json(config: BasinConfig | None) -> dict[str, Any] | None: result["default_stream_config"] = stream_config_to_json( config.default_stream_config ) + if config.stream_cipher is not None: + result["stream_cipher"] = config.stream_cipher.value if config.create_stream_on_append is not None: result["create_stream_on_append"] = config.create_stream_on_append if config.create_stream_on_read is not None: @@ -56,8 +59,10 @@ def basin_config_to_json(config: BasinConfig | None) -> dict[str, Any] | None: def basin_config_from_json(data: dict[str, Any]) -> BasinConfig: dsc = data.get("default_stream_config") + stream_cipher = data.get("stream_cipher") return BasinConfig( default_stream_config=stream_config_from_json(dsc) if dsc else None, + stream_cipher=Encryption(stream_cipher) if stream_cipher else None, create_stream_on_append=data.get("create_stream_on_append"), create_stream_on_read=data.get("create_stream_on_read"), ) @@ -146,10 +151,12 @@ def stream_info_from_json(data: dict[str, Any]) -> StreamInfo: created_at = datetime.fromisoformat(data["created_at"]) deleted_at_str = data.get("deleted_at") deleted_at = datetime.fromisoformat(deleted_at_str) if deleted_at_str else None + cipher = data.get("cipher") return StreamInfo( name=data["name"], created_at=created_at, deleted_at=deleted_at, + cipher=Encryption(cipher) if cipher else None, ) diff --git a/src/s2_sdk/_ops.py b/src/s2_sdk/_ops.py index f91fa0a..86a2f4a 100644 --- a/src/s2_sdk/_ops.py +++ b/src/s2_sdk/_ops.py @@ -1,3 +1,4 @@ +import base64 import uuid from collections.abc import AsyncIterator from datetime import datetime @@ -31,11 +32,20 @@ from s2_sdk._producer import Producer from s2_sdk._retrier import Retrier, http_retry_on, is_safe_to_retry_unary from s2_sdk._s2s._read_session import run_read_session -from s2_sdk._types import ONE_MIB, Compression, Endpoints, Retry, Timeout, metered_bytes +from s2_sdk._types import ( + _S2_ENCRYPTION_KEY_HEADER, + ONE_MIB, + Compression, + Endpoints, + Retry, + Timeout, + metered_bytes, +) from s2_sdk._validators import ( validate_append_input, validate_basin, validate_batching, + validate_encryption_key, validate_max_unacked, validate_retry, ) @@ -608,11 +618,22 @@ async def create_stream( ) return stream_info_from_json(response.json()) - def stream(self, name: str) -> "S2Stream": + def stream( + self, + name: str, + *, + encryption_key: bytes | str | None = None, + ) -> "S2Stream": """Get an :class:`S2Stream` for performing stream-level operations. Args: name: Name of the stream. + encryption_key: Key for encrypting records on append and decrypting + them on read. Required when encryption is enabled via + :attr:`BasinConfig.stream_cipher` (see :class:`Encryption` + for supported algorithms). + If ``bytes``, it will get converted to a base64 encoded str. + If ``str``, it must be base64 encoded. Returns: An :class:`S2Stream` bound to the given stream name. @@ -620,11 +641,17 @@ def stream(self, name: str) -> "S2Stream": Tip: Also available via subscript: ``s2["my-basin"]["my-stream"]``. """ + if isinstance(encryption_key, str): + validate_encryption_key(encryption_key) + elif isinstance(encryption_key, bytes): + encryption_key = base64.b64encode(encryption_key).decode() + return S2Stream( name, self._client, retry=self._retry, compression=self._compression, + encryption_key=encryption_key, ) @fallible @@ -757,6 +784,7 @@ class S2Stream: "_name", "_client", "_compression", + "_encryption_key", "_retry", "_retrier", "_append_retrier", @@ -769,11 +797,13 @@ def __init__( *, retry: Retry, compression: Compression, + encryption_key: str | None = None, ) -> None: self._name = name self._client = client self._retry = retry self._compression = compression + self._encryption_key = encryption_key self._retrier = Retrier( should_retry_on=http_retry_on, max_attempts=retry.max_attempts, @@ -797,6 +827,15 @@ def name(self) -> str: """Stream name.""" return self._name + def _request_headers( + self, headers: dict[str, str] | None = None + ) -> dict[str, str] | None: + if self._encryption_key is None: + return headers + merged = dict(headers or {}) + merged[_S2_ENCRYPTION_KEY_HEADER] = self._encryption_key + return merged + @fallible async def check_tail(self) -> types.StreamPosition: """Check the tail of a stream. @@ -831,10 +870,12 @@ async def append(self, inp: types.AppendInput) -> types.AppendAck: "POST", _stream_path(self.name, "/records"), content=body, - headers={ - "content-type": "application/x-protobuf", - "accept": "application/x-protobuf", - }, + headers=self._request_headers( + { + "content-type": "application/x-protobuf", + "accept": "application/x-protobuf", + } + ), ) ack = pb.AppendAck() ack.ParseFromString(response.content) @@ -878,6 +919,7 @@ def append_session( compression=self._compression, max_unacked_bytes=max_unacked_bytes, max_unacked_batches=max_unacked_batches, + encryption_key=self._encryption_key, ) def producer( @@ -922,6 +964,7 @@ def producer( stream_name=self.name, retry=self._retry, compression=self._compression, + encryption_key=self._encryption_key, fencing_token=fencing_token, match_seq_num=match_seq_num, max_unacked_bytes=max_unacked_bytes, @@ -971,7 +1014,7 @@ async def read( "GET", _stream_path(self.name, "/records"), params=params, - headers={"accept": "application/x-protobuf"}, + headers=self._request_headers({"accept": "application/x-protobuf"}), ) proto_batch = pb.ReadBatch() @@ -1030,6 +1073,7 @@ async def read_session( wait, ignore_command_records, retry=self._retry, + encryption_key=self._encryption_key, ): yield batch diff --git a/src/s2_sdk/_producer.py b/src/s2_sdk/_producer.py index c51bc90..aeaefc0 100644 --- a/src/s2_sdk/_producer.py +++ b/src/s2_sdk/_producer.py @@ -60,6 +60,7 @@ def __init__( match_seq_num: int | None, max_unacked_bytes: int, batching: Batching, + encryption_key: str | None = None, ) -> None: self._session = AppendSession( client=client, @@ -68,6 +69,7 @@ def __init__( compression=compression, max_unacked_bytes=max_unacked_bytes, max_unacked_batches=None, + encryption_key=encryption_key, ) self._fencing_token = fencing_token self._match_seq_num = match_seq_num diff --git a/src/s2_sdk/_s2s/_append_session.py b/src/s2_sdk/_s2s/_append_session.py index 137edaa..2f08725 100644 --- a/src/s2_sdk/_s2s/_append_session.py +++ b/src/s2_sdk/_s2s/_append_session.py @@ -19,6 +19,7 @@ read_messages, ) from s2_sdk._types import ( + _S2_ENCRYPTION_KEY_HEADER, AppendAck, AppendInput, AppendRetryPolicy, @@ -43,6 +44,7 @@ async def run_append_session( retry: Retry, compression: Compression, ack_timeout: float | None = None, + encryption_key: str | None = None, ) -> AsyncIterable[AppendAck]: input_queue: asyncio.Queue[AppendInput | None] = asyncio.Queue( maxsize=_QUEUE_MAX_SIZE @@ -85,6 +87,7 @@ async def retrying_inner(): compression, frame_signal, ack_timeout, + encryption_key, ) return except Exception as e: @@ -135,14 +138,19 @@ async def _run_attempt( compression: Compression, frame_signal: FrameSignal | None, ack_timeout: float | None = None, + encryption_key: str | None = None, ) -> None: + headers = { + "content-type": "s2s/proto", + "accept": "s2s/proto", + } + if encryption_key is not None: + headers[_S2_ENCRYPTION_KEY_HEADER] = encryption_key + async with client.streaming_request( "POST", _stream_records_path(stream_name), - headers={ - "content-type": "s2s/proto", - "accept": "s2s/proto", - }, + headers=headers, content=_body_gen(inflight_inputs, input_queue, pending_resend, compression), frame_signal=frame_signal, ) as response: diff --git a/src/s2_sdk/_s2s/_read_session.py b/src/s2_sdk/_s2s/_read_session.py index 9d0330b..0c70b84 100644 --- a/src/s2_sdk/_s2s/_read_session.py +++ b/src/s2_sdk/_s2s/_read_session.py @@ -12,6 +12,7 @@ from s2_sdk._s2s import _stream_records_path from s2_sdk._s2s._protocol import parse_error_info, read_messages from s2_sdk._types import ( + _S2_ENCRYPTION_KEY_HEADER, ReadBatch, ReadLimit, Retry, @@ -36,6 +37,7 @@ async def run_read_session( wait: int | None, ignore_command_records: bool, retry: Retry, + encryption_key: str | None = None, ) -> AsyncIterable[ReadBatch]: params = _build_read_params(start, limit, until_timestamp, clamp_to_tail, wait) backoffs = compute_backoffs( @@ -50,6 +52,10 @@ async def run_read_session( last_tail_at: float | None = None + headers = {"content-type": "s2s/proto"} + if encryption_key is not None: + headers[_S2_ENCRYPTION_KEY_HEADER] = encryption_key + while True: if wait is not None: params["wait"] = _remaining_wait(wait, last_tail_at) @@ -59,7 +65,7 @@ async def run_read_session( "GET", _stream_records_path(stream_name), params=params, - headers={"content-type": "s2s/proto"}, + headers=headers, ) as response: if response.status_code != 200: body = await response.aread() diff --git a/src/s2_sdk/_types.py b/src/s2_sdk/_types.py index 8e1fb8f..e9303e5 100644 --- a/src/s2_sdk/_types.py +++ b/src/s2_sdk/_types.py @@ -11,6 +11,7 @@ T = TypeVar("T") ONE_MIB = 1024 * 1024 +_S2_ENCRYPTION_KEY_HEADER = "s2-encryption-key" def _parse_scheme(url: str) -> str: @@ -50,6 +51,13 @@ class AppendRetryPolicy(_DocEnum): ) +class Encryption(_DocEnum): + """Encryption algorithm.""" + + AEGIS_256 = "aegis-256", "AEGIS-256." + AES_256_GCM = "aes-256-gcm", "AES-256-GCM." + + class Endpoints: """S2 service endpoints. See `endpoints `_.""" @@ -479,6 +487,9 @@ class BasinConfig: default_stream_config: StreamConfig | None = None """Default configuration for streams in this basin.""" + stream_cipher: Encryption | None = None + """Encryption algorithm to apply to newly created streams in the basin.""" + create_stream_on_append: bool | None = None """Create stream on append if it doesn't exist.""" @@ -516,6 +527,9 @@ class StreamInfo: deleted_at: datetime | None """Deletion time if the stream is being deleted.""" + cipher: Encryption | None = None + """Encryption algorithm for this stream, if encryption is enabled.""" + @dataclass(slots=True) class ExactMatch: diff --git a/src/s2_sdk/_validators.py b/src/s2_sdk/_validators.py index 7e7a007..a7bd7b2 100644 --- a/src/s2_sdk/_validators.py +++ b/src/s2_sdk/_validators.py @@ -1,3 +1,4 @@ +import base64 import re from s2_sdk._exceptions import S2ClientError @@ -51,3 +52,10 @@ def validate_append_input(num_records: int, num_bytes: int) -> None: raise S2ClientError( f"Invalid append input: num_records={num_records}, metered_bytes={num_bytes}" ) + + +def validate_encryption_key(key: str) -> None: + try: + base64.b64decode(key, validate=True) + except Exception: + raise S2ClientError("encryption_key must be a base64 encoded str") diff --git a/tests/test_session.py b/tests/test_session.py index 129e61e..45de423 100644 --- a/tests/test_session.py +++ b/tests/test_session.py @@ -38,7 +38,15 @@ def _input(n_records: int = 1, body: bytes = b"payload") -> AppendInput: return AppendInput(records=[Record(body=body) for _ in range(n_records)]) -async def _fake_run(client, stream_name, inputs, retry, compression, ack_timeout=None): +async def _fake_run( + client, + stream_name, + inputs, + retry, + compression, + ack_timeout=None, + encryption_key=None, +): seq = 0 async for inp in inputs: n = len(inp.records) @@ -47,7 +55,13 @@ async def _fake_run(client, stream_name, inputs, retry, compression, ack_timeout async def _slow_fake_run( - client, stream_name, inputs, retry, compression, ack_timeout=None + client, + stream_name, + inputs, + retry, + compression, + ack_timeout=None, + encryption_key=None, ): seq = 0 async for inp in inputs: @@ -58,7 +72,13 @@ async def _slow_fake_run( async def _failing_run( - client, stream_name, inputs, retry, compression, ack_timeout=None + client, + stream_name, + inputs, + retry, + compression, + ack_timeout=None, + encryption_key=None, ): async for _ in inputs: raise RuntimeError("session failed") diff --git a/tests/test_stream_ops.py b/tests/test_stream_ops.py index ca362f0..864b273 100644 --- a/tests/test_stream_ops.py +++ b/tests/test_stream_ops.py @@ -1,4 +1,6 @@ import asyncio +import base64 +import os import time import uuid from datetime import timedelta @@ -12,6 +14,7 @@ Batching, CommandRecord, Compression, + Encryption, Endpoints, FencingTokenMismatchError, ReadLimit, @@ -844,3 +847,164 @@ async def test_compression_roundtrip_session( await basin.delete_stream(stream_name) finally: await s2.delete_basin(basin_name) + + +def _make_key(key_type: str) -> bytes | str: + raw = os.urandom(32) + return raw if key_type == "bytes" else base64.b64encode(raw).decode() + + +@pytest.mark.stream +@pytest.mark.parametrize( + "cipher", + [Encryption.AEGIS_256, Encryption.AES_256_GCM], +) +class TestEncryption: + @pytest.mark.parametrize("key_type", ["bytes", "str"]) + async def test_encryption_roundtrip_unary( + self, + access_token: str, + endpoints: Endpoints | None, + cipher: Encryption, + key_type: str, + ): + async with S2(access_token, endpoints=endpoints) as s2: + basin_name = f"test-py-sdk-{uuid.uuid4().hex[:8]}" + await s2.create_basin( + name=basin_name, + config=BasinConfig(stream_cipher=cipher), + ) + try: + basin = s2.basin(basin_name) + stream_name = f"stream-{uuid.uuid4().hex[:8]}" + info = await basin.create_stream(stream_name) + assert info.cipher == cipher + + key = _make_key(key_type) + try: + stream = basin.stream(stream_name, encryption_key=key) + ack = await stream.append( + AppendInput( + records=[ + Record(body=b"hello"), + Record(body=b"world"), + ] + ) + ) + assert ack.start.seq_num == 0 + assert ack.end.seq_num == 2 + + batch = await stream.read(start=SeqNum(0)) + assert len(batch.records) == 2 + assert batch.records[0].body == b"hello" + assert batch.records[1].body == b"world" + finally: + await basin.delete_stream(stream_name) + finally: + await s2.delete_basin(basin_name) + + @pytest.mark.parametrize("key_type", ["bytes", "str"]) + async def test_encryption_roundtrip_session( + self, + access_token: str, + endpoints: Endpoints | None, + cipher: Encryption, + key_type: str, + ): + async with S2(access_token, endpoints=endpoints) as s2: + basin_name = f"test-py-sdk-{uuid.uuid4().hex[:8]}" + await s2.create_basin( + name=basin_name, + config=BasinConfig(stream_cipher=cipher), + ) + try: + basin = s2.basin(basin_name) + stream_name = f"stream-{uuid.uuid4().hex[:8]}" + info = await basin.create_stream(stream_name) + assert info.cipher == cipher + + key = _make_key(key_type) + try: + stream = basin.stream(stream_name, encryption_key=key) + async with stream.append_session() as session: + ticket = await session.submit( + AppendInput(records=[Record(body=b"s2" * 10240)]) + ) + ack = await ticket + + assert ack.start.seq_num == 0 + assert ack.end.seq_num == 1 + + batch = await stream.read(start=SeqNum(0)) + assert len(batch.records) == 1 + assert batch.records[0].body == b"s2" * 10240 + finally: + await basin.delete_stream(stream_name) + finally: + await s2.delete_basin(basin_name) + + async def test_encrypted_stream_requires_key( + self, + access_token: str, + endpoints: Endpoints | None, + cipher: Encryption, + ): + async with S2(access_token, endpoints=endpoints) as s2: + basin_name = f"test-py-sdk-{uuid.uuid4().hex[:8]}" + await s2.create_basin( + name=basin_name, + config=BasinConfig(stream_cipher=cipher), + ) + try: + basin = s2.basin(basin_name) + stream_name = f"stream-{uuid.uuid4().hex[:8]}" + await basin.create_stream(stream_name) + try: + stream = basin.stream(stream_name) + with pytest.raises(S2ServerError) as append_exc: + await stream.append( + AppendInput(records=[Record(body=b"hello")]) + ) + assert append_exc.value.code == "invalid" + assert cipher.value in str(append_exc.value) + with pytest.raises(S2ServerError) as read_exc: + await stream.read(start=SeqNum(0)) + assert read_exc.value.code == "invalid" + assert cipher.value in str(read_exc.value) + finally: + await basin.delete_stream(stream_name) + finally: + await s2.delete_basin(basin_name) + + async def test_encrypted_stream_wrong_key_read_fails( + self, + access_token: str, + endpoints: Endpoints | None, + cipher: Encryption, + ): + async with S2(access_token, endpoints=endpoints) as s2: + basin_name = f"test-py-sdk-{uuid.uuid4().hex[:8]}" + await s2.create_basin( + name=basin_name, + config=BasinConfig(stream_cipher=cipher), + ) + try: + basin = s2.basin(basin_name) + stream_name = f"stream-{uuid.uuid4().hex[:8]}" + await basin.create_stream(stream_name) + try: + key_a = os.urandom(32) + key_b = os.urandom(32) + + writer = basin.stream(stream_name, encryption_key=key_a) + reader = basin.stream(stream_name, encryption_key=key_b) + + await writer.append(AppendInput(records=[Record(body=b"secret")])) + + with pytest.raises(S2ServerError) as exc_info: + await reader.read(start=SeqNum(0)) + assert exc_info.value.code == "decryption_failed" + finally: + await basin.delete_stream(stream_name) + finally: + await s2.delete_basin(basin_name)