Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
78 changes: 39 additions & 39 deletions snuba/web/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
Any,
Callable,
Dict,
Iterator,
Mapping,
MutableMapping,
MutableSequence,
Expand Down Expand Up @@ -371,51 +372,50 @@ def storage_delete(*, storage: WritableTableStorage, timer: Timer) -> Union[Resp
return make_response(jsonify({"error": details}), 500)

# The result is nested under "data" because that is the shape sentry's utils/snuba.py expects.
return Response(dump_payload({"data": payload}), 200, {"Content-Type": "application/json"})
return Response(
stream_payload({"data": payload}), 200, {"Content-Type": "application/json"}
)

else:
assert False, "unexpected fallthrough"


def _sanitize_payload(payload: MutableMapping[str, Any], res: MutableMapping[str, Any]) -> None:
def hex_encode_if_bytes(value: Any) -> Any:
if isinstance(value, bytes):
try:
return value.decode("utf-8")
except UnicodeDecodeError:
# encode the byte string in a hex string
return "RAW_BYTESTRING__" + value.hex()

return value

for k, v in payload.items():
if isinstance(v, dict):
res[hex_encode_if_bytes(k)] = {}
_sanitize_payload(v, res[hex_encode_if_bytes(k)])
elif isinstance(v, list):
res[hex_encode_if_bytes(k)] = []
for item in v:
if isinstance(item, dict):
res[hex_encode_if_bytes(k)].append({})
_sanitize_payload(item, res[hex_encode_if_bytes(k)][-1])
else:
res[hex_encode_if_bytes(k)].append(hex_encode_if_bytes(item))
else:
res[hex_encode_if_bytes(k)] = hex_encode_if_bytes(v)
def _payload_default(obj: Any) -> Any:
"""
Serialization fallback for values the JSON encoder does not natively
handle. Bytes are decoded as UTF-8 when possible, otherwise hex-encoded so
one undecodable string cannot break the whole payload (and to avoid handing
potentially malicious raw bytes to downstream clients). Everything else
falls back to ``str`` (matching the previous ``default=str`` behavior).
"""
if isinstance(obj, bytes):
try:
return obj.decode("utf-8")
except UnicodeDecodeError:
# encode the byte string in a hex string
return "RAW_BYTESTRING__" + obj.hex()

return str(obj)

def dump_payload(payload: MutableMapping[str, Any]) -> str:
    """
    Serialize ``payload`` to a JSON string, stringifying unknown types.

    If encoding raises ``UnicodeDecodeError`` (a byte string that cannot be
    decoded), the payload is first copied through ``_sanitize_payload`` and
    then serialized again. Hex-encoding the offending bytes keeps one bad
    string from breaking the entire payload and spares downstream clients
    from handling potentially malicious strings.
    """
    try:
        serialized = json.dumps(payload, default=str)
    except UnicodeDecodeError:
        cleaned: MutableMapping[str, Any] = {}
        _sanitize_payload(payload, cleaned)
        serialized = json.dumps(cleaned, default=str)
    return serialized

def stream_payload(payload: Mapping[str, Any]) -> Iterator[str]:
    """
    Yield the JSON encoding of ``payload`` piece by piece.

    The response body is produced as a sequence of string chunks rather than
    one fully materialized string.

    The encoder is created with ``encoding=None`` and
    ``default=_payload_default`` so that ``bytes`` values are handed to
    ``_payload_default`` instead of being decoded by the encoder itself, which
    could raise ``UnicodeDecodeError`` part-way through the stream.
    NOTE(review): ``encoding=`` is not accepted by the stdlib ``json`` encoder,
    so this presumably relies on the module-level ``json`` being simplejson;
    confirm against the file's imports.
    """
    chunk_encoder = json.JSONEncoder(default=_payload_default, encoding=None)
    # _one_shot=False requests the incremental, chunk-yielding encoder path
    # rather than the one-shot path that builds the whole string at once.
    chunks = chunk_encoder.iterencode(payload, _one_shot=False)
    yield from chunks


def dump_payload(payload: Mapping[str, Any]) -> str:
    """Return the full JSON string for ``payload`` by joining the chunks from ``stream_payload``."""
    pieces = stream_payload(payload)
    return "".join(pieces)


@with_span()
Expand Down Expand Up @@ -503,7 +503,7 @@ def dataset_query(
if settings.STATS_IN_RESPONSE or request.query_settings.get_debug():
payload.update(result.extra)

return Response(dump_payload(payload), 200, {"Content-Type": "application/json"})
return Response(stream_payload(payload), 200, {"Content-Type": "application/json"})


@application.errorhandler(InvalidSubscriptionError)
Expand Down
43 changes: 43 additions & 0 deletions tests/web/test_views.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,49 @@ def test_response_dumping() -> None:
assert json.loads(dumped_payload) == clean_data


def test_stream_payload() -> None:
    """stream_payload yields several str chunks that join into JSON with bytes made safe."""
    from snuba.web.views import stream_payload

    undecodable = b"x;\x83\xc0\x05"
    rows = [
        {"count": 5181337, "release": "elsa"},
        {"count": 2170, "release": b"valid-utf8"},
        {"count": 88, "release": undecodable},
    ]

    pieces = list(stream_payload({"data": rows, "meta": []}))
    # Incremental encoding should produce many small chunks, not a single blob.
    assert len(pieces) > 1
    assert all(isinstance(piece, str) for piece in pieces)

    decoded_rows = json.loads("".join(pieces))["data"]
    assert decoded_rows[0]["release"] == "elsa"
    # Valid UTF-8 bytes come back as the decoded string.
    assert decoded_rows[1]["release"] == "valid-utf8"
    # Undecodable bytes are hex-encoded instead of breaking the stream.
    assert decoded_rows[2]["release"] == "RAW_BYTESTRING__" + undecodable.hex()


def test_streamed_response() -> None:
    """A Response built from stream_payload is streamed and its body is sanitized JSON."""
    from flask import Response

    from snuba.web.views import application, stream_payload

    raw = b"x;\x83\xc0\x05"
    payload = {"data": [{"a": 1}, {"b": raw}], "meta": []}
    headers = {"Content-Type": "application/json"}

    with application.app_context():
        response = Response(stream_payload(payload), 200, headers)
        # A generator body must be streamed, not buffered into a fixed-length body.
        assert response.is_streamed
        body = response.get_data(as_text=True)

    rows = json.loads(body)["data"]
    assert rows[0]["a"] == 1
    assert rows[1]["b"] == "RAW_BYTESTRING__" + raw.hex()


@pytest.mark.parametrize("exception, expected_log_level", invalid_query_exception_test_cases)
def test_handle_invalid_query(
caplog: Any, exception: InvalidQueryException, expected_log_level: str
Expand Down
Loading