diff --git a/snuba/web/views.py b/snuba/web/views.py index 96bd76b1e5..d842df9bb0 100644 --- a/snuba/web/views.py +++ b/snuba/web/views.py @@ -9,6 +9,7 @@ Any, Callable, Dict, + Iterator, Mapping, MutableMapping, MutableSequence, @@ -371,51 +372,50 @@ def storage_delete(*, storage: WritableTableStorage, timer: Timer) -> Union[Resp return make_response(jsonify({"error": details}), 500) # i put the result inside "data" bc thats how sentry utils/snuba.py expects the result - return Response(dump_payload({"data": payload}), 200, {"Content-Type": "application/json"}) + return Response( + stream_payload({"data": payload}), 200, {"Content-Type": "application/json"} + ) else: assert False, "unexpected fallthrough" -def _sanitize_payload(payload: MutableMapping[str, Any], res: MutableMapping[str, Any]) -> None: - def hex_encode_if_bytes(value: Any) -> Any: - if isinstance(value, bytes): - try: - return value.decode("utf-8") - except UnicodeDecodeError: - # encode the byte string in a hex string - return "RAW_BYTESTRING__" + value.hex() - - return value - - for k, v in payload.items(): - if isinstance(v, dict): - res[hex_encode_if_bytes(k)] = {} - _sanitize_payload(v, res[hex_encode_if_bytes(k)]) - elif isinstance(v, list): - res[hex_encode_if_bytes(k)] = [] - for item in v: - if isinstance(item, dict): - res[hex_encode_if_bytes(k)].append({}) - _sanitize_payload(item, res[hex_encode_if_bytes(k)][-1]) - else: - res[hex_encode_if_bytes(k)].append(hex_encode_if_bytes(item)) - else: - res[hex_encode_if_bytes(k)] = hex_encode_if_bytes(v) +def _payload_default(obj: Any) -> Any: + """ + Serialization fallback for values the JSON encoder cannot handle natively. + ``bytes`` are decoded as UTF-8 when possible; otherwise they are returned as a + ``RAW_BYTESTRING__``-prefixed hex string, so one undecodable value cannot break + the whole payload or hand raw bytes to downstream clients. Any other object + is converted with ``str`` (matching the previous ``default=str`` behavior). 
+ """ + if isinstance(obj, bytes): + try: + return obj.decode("utf-8") + except UnicodeDecodeError: + # Not valid UTF-8: fall back to the prefixed hex representation. + return "RAW_BYTESTRING__" + obj.hex() + return str(obj) -def dump_payload(payload: MutableMapping[str, Any]) -> str: - try: - return json.dumps(payload, default=str) - except UnicodeDecodeError: - # If there were any string that could not be decoded, we - # encode the problematic bytes in a hex string. - # this is to prevent other clients downstream of us from having - # to deal with potentially malicious strings and to prevent one - # bad string from breaking the entire payload. - sanitized_payload: MutableMapping[str, Any] = {} - _sanitize_payload(payload, sanitized_payload) - return json.dumps(sanitized_payload, default=str) + + +def stream_payload(payload: Mapping[str, Any]) -> Iterator[str]: + """ + Incrementally encode ``payload`` to JSON, yielding chunks instead of + building the entire serialized response as a single string in memory. + + ``encoding=None`` routes ``bytes`` values to ``_payload_default`` instead of + letting simplejson decode them itself and raise ``UnicodeDecodeError`` + mid-stream; with the default ``ensure_ascii=True`` every chunk is pure ASCII. + NOTE(review): confirm ``bytes`` dict *keys* are handled too; only values are tested. + """ + encoder = json.JSONEncoder(encoding=None, default=_payload_default) + # _one_shot=False keeps the incremental (chunk-yielding) encoder path; passing + # True would invoke the C fast path that builds the whole string at once. 
+ yield from encoder.iterencode(payload, _one_shot=False) + + +def dump_payload(payload: Mapping[str, Any]) -> str: + return "".join(stream_payload(payload)) @with_span() @@ -503,7 +503,7 @@ def dataset_query( if settings.STATS_IN_RESPONSE or request.query_settings.get_debug(): payload.update(result.extra) - return Response(dump_payload(payload), 200, {"Content-Type": "application/json"}) + return Response(stream_payload(payload), 200, {"Content-Type": "application/json"}) @application.errorhandler(InvalidSubscriptionError) diff --git a/tests/web/test_views.py b/tests/web/test_views.py index 7459220555..719fcf0823 100644 --- a/tests/web/test_views.py +++ b/tests/web/test_views.py @@ -62,6 +62,49 @@ def test_response_dumping() -> None: assert json.loads(dumped_payload) == clean_data +def test_stream_payload() -> None: + from snuba.web.views import stream_payload + + data = { + "data": [ + {"count": 5181337, "release": "elsa"}, + {"count": 2170, "release": b"valid-utf8"}, + {"count": 88, "release": b"x;\x83\xc0\x05"}, + ], + "meta": [], + } + + chunks = list(stream_payload(data)) + # The whole point is incremental encoding: many small chunks, not one blob. 
+ assert len(chunks) > 1 + assert all(isinstance(chunk, str) for chunk in chunks) + + parsed = json.loads("".join(chunks)) + assert parsed["data"][0]["release"] == "elsa" + # valid UTF-8 bytes are decoded to a string + assert parsed["data"][1]["release"] == "valid-utf8" + # undecodable bytes are hex-encoded rather than breaking the stream + assert parsed["data"][2]["release"] == "RAW_BYTESTRING__" + b"x;\x83\xc0\x05".hex() + + +def test_streamed_response() -> None: + from flask import Response + + from snuba.web.views import application, stream_payload + + payload = {"data": [{"a": 1}, {"b": b"x;\x83\xc0\x05"}], "meta": []} + + with application.app_context(): + response = Response(stream_payload(payload), 200, {"Content-Type": "application/json"}) + # A generator body is streamed, not buffered into a fixed-length body. + assert response.is_streamed + body = response.get_data(as_text=True) + + parsed = json.loads(body) + assert parsed["data"][0]["a"] == 1 + assert parsed["data"][1]["b"] == "RAW_BYTESTRING__" + b"x;\x83\xc0\x05".hex() + + @pytest.mark.parametrize("exception, expected_log_level", invalid_query_exception_test_cases) def test_handle_invalid_query( caplog: Any, exception: InvalidQueryException, expected_log_level: str