diff --git a/tests/appsec/api_security/test_api_security_rc.py b/tests/appsec/api_security/test_api_security_rc.py index e8e3f837c26..80fd10165c9 100644 --- a/tests/appsec/api_security/test_api_security_rc.py +++ b/tests/appsec/api_security/test_api_security_rc.py @@ -9,7 +9,7 @@ def get_schema(request: HttpResponse, address: str): """Get api security schema from spans""" - for _, _, span in interfaces.library.get_spans(request): + for _, _, span, _ in interfaces.library.get_spans(request): meta = span.get("meta", {}) key = "_dd.appsec.s." + address payload = meta.get(key) diff --git a/tests/appsec/api_security/test_apisec_sampling.py b/tests/appsec/api_security/test_apisec_sampling.py index 6ef509bb5b8..9e1d3927ceb 100644 --- a/tests/appsec/api_security/test_apisec_sampling.py +++ b/tests/appsec/api_security/test_apisec_sampling.py @@ -17,7 +17,7 @@ def get_schema(request: HttpResponse, address: str): """Get api security schema from spans""" - for _, _, span in interfaces.library.get_spans(request): + for _, _, span, _ in interfaces.library.get_spans(request): meta = span.get("meta", {}) payload = meta.get("_dd.appsec.s." + address) if payload is not None: diff --git a/tests/appsec/api_security/test_custom_data_classification.py b/tests/appsec/api_security/test_custom_data_classification.py index 767c989e513..ec3d214710e 100644 --- a/tests/appsec/api_security/test_custom_data_classification.py +++ b/tests/appsec/api_security/test_custom_data_classification.py @@ -10,7 +10,7 @@ def get_schema(request: HttpResponse, address: str): """Get api security schema from spans""" - for _, _, span in interfaces.library.get_spans(request): + for _, _, span, _ in interfaces.library.get_spans(request): meta = span.get("meta", {}) key = "_dd.appsec.s." + address payload = meta.get(key) diff --git a/tests/appsec/api_security/test_endpoint_fallback.py b/tests/appsec/api_security/test_endpoint_fallback.py index 97e6b09effc..e717cf7a8ee 100644 --- a/tests/appsec/api_security/test_endpoint_fallback.py +++ b/tests/appsec/api_security/test_endpoint_fallback.py @@ -16,7 +16,7 @@ def get_schema(request: HttpResponse, address: str): """Get api security schema from spans""" - for _, _, span in interfaces.library.get_spans(request): + for _, _, span, _ in interfaces.library.get_spans(request): meta = span.get("meta", {}) payload = meta.get("_dd.appsec.s." + address) if payload is not None: @@ -26,8 +26,9 @@ def get_schema(request: HttpResponse, address: str): def get_span_meta(request: HttpResponse, key: str): """Get a specific meta value from the root span""" - span = interfaces.library.get_root_span(request) - return span.get("meta", {}).get(key) + span, span_format = interfaces.library.get_root_span(request) + meta = interfaces.library.get_span_meta(span, span_format) + return meta.get(key) @rfc("https://docs.google.com/document/d/1GnWwiaw6dkVtgn5f1wcHJETND_Svqd-sJl6FSVVuCkI") diff --git a/tests/appsec/api_security/test_schemas.py b/tests/appsec/api_security/test_schemas.py index d9544f153de..c19f8a9edb8 100644 --- a/tests/appsec/api_security/test_schemas.py +++ b/tests/appsec/api_security/test_schemas.py @@ -9,8 +9,8 @@ def get_schema(request: HttpResponse, address: str): """Get api security schema from spans""" - span = interfaces.library.get_root_span(request) - meta = span.get("meta", {}) + span, span_format = interfaces.library.get_root_span(request) + meta = interfaces.library.get_span_meta(span, span_format) key = "_dd.appsec.s." + address if key not in meta: logger.info(f"Schema not found in span meta for {key}") diff --git a/tests/appsec/iast/test_vulnerability_schema.py b/tests/appsec/iast/test_vulnerability_schema.py index 68328054258..a24b2226baf 100644 --- a/tests/appsec/iast/test_vulnerability_schema.py +++ b/tests/appsec/iast/test_vulnerability_schema.py @@ -11,9 +11,9 @@ def test_vulnerability_schema(self): with open(schema_path, "r") as f: schema = json.load(f) validator = jsonschema.Draft7Validator(schema) - spans = [s for _, s in interfaces.library.get_root_spans()] - for span in spans: - meta = span.get("meta", {}) + spans_with_format = list(interfaces.library.get_root_spans()) + for _, span, span_format in spans_with_format: + meta = interfaces.library.get_span_meta(span, span_format) if "_dd.iast.json" not in meta: continue iast_data = meta["_dd.iast.json"] diff --git a/tests/appsec/iast/utils.py b/tests/appsec/iast/utils.py index 3db7ecd85ef..5ef3ac7b730 100644 --- a/tests/appsec/iast/utils.py +++ b/tests/appsec/iast/utils.py @@ -17,8 +17,8 @@ def _get_expectation(d: str | dict | None) -> str | None: def _get_span_meta(request: HttpResponse): - span = interfaces.library.get_root_span(request) - meta = span.get("meta", {}) + span, span_format = interfaces.library.get_root_span(request) + meta = interfaces.library.get_span_meta(span, span_format) meta_struct = span.get("meta_struct", {}) return meta, meta_struct @@ -56,8 +56,9 @@ def assert_iast_vulnerability( def assert_metric(request: HttpResponse, metric: str, *, expected: bool) -> None: spans_checked = 0 metric_available = False - for _, __, span in interfaces.library.get_spans(request): - if metric in span["metrics"]: + for _, __, span, span_format in interfaces.library.get_spans(request): + metrics = interfaces.library.get_span_metrics(span, span_format) + if metric in metrics: metric_available = True spans_checked += 1 assert spans_checked == 1 @@ -77,13 +78,13 @@ def _check_telemetry_response_from_agent(): def get_all_iast_events() -> list: - spans = [span[2] for span in interfaces.library.get_spans()] - assert spans, "No spans found" - spans_meta = [span.get("meta") for span in spans if span.get("meta")] - spans_meta_struct = [span.get("meta_struct") for span in spans if span.get("meta_struct")] + spans_with_format = [(span, span_format) for _, _, span, span_format in interfaces.library.get_spans()] + assert spans_with_format, "No spans found" + spans_meta = [interfaces.library.get_span_meta(span, span_format) for span, span_format in spans_with_format] + spans_meta_struct = [span.get("meta_struct") for span, _ in spans_with_format if span.get("meta_struct")] assert spans_meta or spans_meta_struct, "No spans meta found" - iast_events = [meta.get("_dd.iast.json") for meta in spans_meta if meta.get("_dd.iast.json")] - iast_events += [metastruct.get("iast") for metastruct in spans_meta_struct if metastruct.get("iast")] + iast_events = [meta.get("_dd.iast.json") for meta in spans_meta if meta and meta.get("_dd.iast.json")] + iast_events += [metastruct.get("iast") for metastruct in spans_meta_struct if metastruct and metastruct.get("iast")] assert iast_events, "No iast events found" return iast_events @@ -198,8 +199,8 @@ def assert_no_iast_event(request: HttpResponse, tested_vulnerability_type: str | def validate_stack_traces(request: HttpResponse) -> None: - span = interfaces.library.get_root_span(request) - meta = span.get("meta", {}) + span, span_format = interfaces.library.get_root_span(request) + meta = interfaces.library.get_span_meta(span, span_format) meta_struct = span.get("meta_struct", {}) iast = meta.get("_dd.iast.json") or meta_struct.get("iast") assert iast is not None, "No iast event in root span" @@ -289,9 +290,12 @@ def validate_stack_traces(request: HttpResponse) -> None: def validate_extended_location_data( request: HttpResponse, vulnerability_type: str | None, *, is_expected_location_required: bool = True ) -> None: - span = interfaces.library.get_root_span(request) - iast = span.get("meta", {}).get("_dd.iast.json") or span.get("meta_struct", {}).get("iast") - assert iast, f"Expected at least one vulnerability in span {span.get('span_id')}" + span, span_format = interfaces.library.get_root_span(request) + meta = interfaces.library.get_span_meta(span, span_format) + meta_struct = span.get("meta_struct", {}) + iast = meta.get("_dd.iast.json") or meta_struct.get("iast") + span_id = interfaces.library.get_span_span_id(span, span_format) + assert iast, f"Expected at least one vulnerability in span {span_id}" assert iast["vulnerabilities"], f"Expected at least one vulnerability: {iast['vulnerabilities']}" # Filter by vulnerability @@ -321,8 +325,8 @@ def validate_extended_location_data( if context.library.name not in ("python", "nodejs"): assert all(field in location for field in ["class", "method"]) else: - assert "vulnerability" in span["meta_struct"]["_dd.stack"], "'vulnerability' not found in '_dd.stack'" - stack_traces = span["meta_struct"]["_dd.stack"]["vulnerability"] + assert "vulnerability" in meta_struct["_dd.stack"], "'vulnerability' not found in '_dd.stack'" + stack_traces = meta_struct["_dd.stack"]["vulnerability"] assert stack_traces, "No vulnerability stack traces found" stack_traces = [s for s in stack_traces if s.get("id") == stack_id] assert stack_traces, f"No vulnerability stack trace found for id {stack_id}" @@ -364,16 +368,19 @@ def _norm(s: str | None) -> str | None: def get_hardcoded_vulnerabilities(vulnerability_type: str, request: HttpResponse | None = None) -> list: - spans = [s for _, s in interfaces.library.get_root_spans(request=request)] - assert spans, "No spans found" - spans_meta = [span.get("meta") for span in spans] + spans_with_format = [ + (span, span_format) for _, span, span_format in interfaces.library.get_root_spans(request=request) + ] + assert spans_with_format, "No spans found" + spans_meta = [interfaces.library.get_span_meta(span, span_format) for span, span_format in spans_with_format] assert spans_meta, "No spans meta found" - iast_events = [meta.get("_dd.iast.json") for meta in spans_meta if meta.get("_dd.iast.json")] + iast_events = [meta.get("_dd.iast.json") for meta in spans_meta if meta and meta.get("_dd.iast.json")] assert iast_events, "No iast events found" vulnerabilities: list = [] for event in iast_events: - vulnerabilities.extend(event.get("vulnerabilities", [])) + if event: + vulnerabilities.extend(event.get("vulnerabilities", [])) assert vulnerabilities, "No vulnerabilities found" diff --git a/tests/appsec/rasp/test_api10.py b/tests/appsec/rasp/test_api10.py index b66037d15f0..6e1220e1aa6 100644 --- a/tests/appsec/rasp/test_api10.py +++ b/tests/appsec/rasp/test_api10.py @@ -385,7 +385,7 @@ def test_api10_redirect(self): assert self.r.status_code == 200 # interfaces.library.validate_one_span(self.r, validator=self.validate) interfaces.library.validate_one_span(self.r, validator=self.validate_metric) - for _, _trace, span in interfaces.library.get_spans(request=self.r): + for _, _trace, span, _ in interfaces.library.get_spans(request=self.r): meta = span.get("meta", {}) assert isinstance(meta.get("appsec.api.redirection.move_target", None), str), f"missing tag in {meta}" assert "/redirect?totalRedirects=2" in meta["appsec.api.redirection.move_target"] diff --git a/tests/appsec/rasp/utils.py b/tests/appsec/rasp/utils.py index d759ef5f538..28969040fa9 100644 --- a/tests/appsec/rasp/utils.py +++ b/tests/appsec/rasp/utils.py @@ -13,12 +13,12 @@ def validate_span_tags( request: HttpResponse, expected_meta: Sequence[str] = (), expected_metrics: Sequence[str] = () ) -> None: """Validate RASP span tags are added when an event is generated""" - span = interfaces.library.get_root_span(request) - meta = span["meta"] + span, span_format = interfaces.library.get_root_span(request) + meta = interfaces.library.get_span_meta(span, span_format) for m in expected_meta: assert m in meta, f"missing span meta tag `{m}` in {meta}" - metrics = span["metrics"] + metrics = interfaces.library.get_span_metrics(span, span_format) for m in expected_metrics: assert m in metrics, f"missing span metric tag `{m}` in {metrics}" diff --git a/tests/appsec/test_asm_standalone.py b/tests/appsec/test_asm_standalone.py index 4590c96374b..b2a5babe2bb 100644 --- a/tests/appsec/test_asm_standalone.py +++ b/tests/appsec/test_asm_standalone.py @@ -5,7 +5,7 @@ from requests.structures import CaseInsensitiveDict -from utils.dd_constants import SAMPLING_PRIORITY_KEY, SamplingPriority +from utils.dd_constants import SAMPLING_PRIORITY_KEY, SamplingPriority, TraceLibraryPayloadFormat from utils.telemetry_utils import TelemetryUtils from utils._weblog import HttpResponse, _Weblog from utils import context, weblog, interfaces, scenarios, features, rfc, missing_feature, logger @@ -75,16 +75,16 @@ def assert_product_is_enabled(response: HttpResponse, product: str | None) -> No product_enabled = False tags = "_dd.iast.json" if product == "iast" else "_dd.appsec.json" meta_struct_key = "iast" if product == "iast" else "appsec" - spans = list(items[2] for items in interfaces.library.get_spans(request=response)) - logger.debug(f"Found {len(spans)} spans") - for span in spans: + spans_with_format = list(interfaces.library.get_spans(request=response)) + logger.debug(f"Found {len(spans_with_format)} spans") + for _, _, span, span_format in spans_with_format: # Check if the product is enabled in meta - meta = span["meta"] + meta = interfaces.library.get_span_meta(span, span_format) if tags in meta: product_enabled = True break # Check if the product is enabled in meta_struct - meta_struct = span["meta_struct"] + meta_struct = interfaces.library.get_span_meta_struct(span, span_format) if meta_struct and meta_struct.get(meta_struct_key): product_enabled = True break @@ -126,9 +126,13 @@ def setup_no_appsec_upstream__no_asm_event__is_kept_with_priority_1__from_minus_ ) def fix_priority_lambda( - self, span: dict, default_checks: dict[str, str | Callable | None] + self, + span: dict, + span_format: TraceLibraryPayloadFormat | None, + default_checks: dict[str, str | Callable | None], ) -> dict[str, str | Callable | None]: - if "_dd.appsec.s.req.headers" in span["meta"]: + meta = interfaces.library.get_span_meta(span, span_format) + if "_dd.appsec.s.req.headers" in meta: return { SAMPLING_PRIORITY_KEY: lambda x: x == SamplingPriority.USER_KEEP } # if we find evidence of API Sec schema, priority should be 2 (Manual Keep) @@ -141,13 +145,14 @@ def test_no_appsec_upstream__no_asm_event__is_kept_with_priority_1__from_minus_1 tested_meta: dict[str, str | Callable | None] = {self.propagated_tag(): None, "_dd.p.other": "1"} tested_metrics: dict[str, str | Callable | None] = {SAMPLING_PRIORITY_KEY: lambda x: x < 2} - for data, trace, span in interfaces.library.get_spans(request=self.r): + for data, trace, span, span_format in interfaces.library.get_spans(request=self.r): assert assert_tags(trace[0], span, "meta", tested_meta) - assert assert_tags(trace[0], span, "metrics", self.fix_priority_lambda(span, tested_metrics)) + metrics = interfaces.library.get_span_metrics(span, span_format) + assert assert_tags(trace[0], span, "metrics", self.fix_priority_lambda(span, span_format, tested_metrics)) - assert span["metrics"]["_dd.apm.enabled"] == 0 # if key missing -> APPSEC-55222 - assert span["trace_id"] == 1212121212121212121 - assert trace[0]["trace_id"] == 1212121212121212121 + assert metrics["_dd.apm.enabled"] == 0 # if key missing -> APPSEC-55222 + trace_id = interfaces.library.get_trace_id(trace, span_format) + assert trace_id == 1212121212121212121 # Some tracers use true while others use yes assert any( @@ -187,13 +192,14 @@ def test_no_appsec_upstream__no_asm_event__is_kept_with_priority_1__from_0(self) tested_meta: dict[str, str | Callable | None] = {self.propagated_tag(): None, "_dd.p.other": "1"} tested_metrics: dict[str, str | Callable | None] = {SAMPLING_PRIORITY_KEY: lambda x: x < 2} - for data, trace, span in interfaces.library.get_spans(request=self.r): + for data, trace, span, span_format in interfaces.library.get_spans(request=self.r): assert assert_tags(trace[0], span, "meta", tested_meta) - assert assert_tags(trace[0], span, "metrics", self.fix_priority_lambda(span, tested_metrics)) + metrics = interfaces.library.get_span_metrics(span, span_format) + assert assert_tags(trace[0], span, "metrics", self.fix_priority_lambda(span, span_format, tested_metrics)) - assert span["metrics"]["_dd.apm.enabled"] == 0 # if key missing -> APPSEC-55222 - assert span["trace_id"] == 1212121212121212121 - assert trace[0]["trace_id"] == 1212121212121212121 + assert metrics["_dd.apm.enabled"] == 0 # if key missing -> APPSEC-55222 + trace_id = interfaces.library.get_trace_id(trace, span_format) + assert trace_id == 1212121212121212121 # Some tracers use true while others use yes assert any( @@ -233,13 +239,14 @@ def test_no_appsec_upstream__no_asm_event__is_kept_with_priority_1__from_1(self) tested_meta: dict[str, str | Callable | None] = {self.propagated_tag(): None, "_dd.p.other": "1"} tested_metrics: dict[str, str | Callable | None] = {SAMPLING_PRIORITY_KEY: lambda x: x < 2} - for data, trace, span in interfaces.library.get_spans(request=self.r): + for data, trace, span, span_format in interfaces.library.get_spans(request=self.r): assert assert_tags(trace[0], span, "meta", tested_meta) - assert assert_tags(trace[0], span, "metrics", self.fix_priority_lambda(span, tested_metrics)) + metrics = interfaces.library.get_span_metrics(span, span_format) + assert assert_tags(trace[0], span, "metrics", self.fix_priority_lambda(span, span_format, tested_metrics)) - assert span["metrics"]["_dd.apm.enabled"] == 0 # if key missing -> APPSEC-55222 - assert span["trace_id"] == 1212121212121212121 - assert trace[0]["trace_id"] == 1212121212121212121 + assert metrics["_dd.apm.enabled"] == 0 # if key missing -> APPSEC-55222 + trace_id = interfaces.library.get_trace_id(trace, span_format) + assert trace_id == 1212121212121212121 # Some tracers use true while others use yes assert any( @@ -279,13 +286,14 @@ def test_no_appsec_upstream__no_asm_event__is_kept_with_priority_1__from_2(self) tested_meta: dict[str, str | Callable | None] = {self.propagated_tag(): None, "_dd.p.other": "1"} tested_metrics: dict[str, str | Callable | None] = {SAMPLING_PRIORITY_KEY: lambda x: x < 2} - for data, trace, span in interfaces.library.get_spans(request=self.r): + for data, trace, span, span_format in interfaces.library.get_spans(request=self.r): assert assert_tags(trace[0], span, "meta", tested_meta) - assert assert_tags(trace[0], span, "metrics", self.fix_priority_lambda(span, tested_metrics)) + metrics = interfaces.library.get_span_metrics(span, span_format) + assert assert_tags(trace[0], span, "metrics", self.fix_priority_lambda(span, span_format, tested_metrics)) - assert span["metrics"]["_dd.apm.enabled"] == 0 # if key missing -> APPSEC-55222 - assert span["trace_id"] == 1212121212121212121 - assert trace[0]["trace_id"] == 1212121212121212121 + assert metrics["_dd.apm.enabled"] == 0 # if key missing -> APPSEC-55222 + trace_id = interfaces.library.get_trace_id(trace, span_format) + assert trace_id == 1212121212121212121 # Some tracers use true while others use yes assert any( @@ -323,13 +331,14 @@ def test_no_upstream_appsec_propagation__with_asm_event__is_kept_with_priority_2 tested_meta: dict[str, str | Callable | None] = {self.propagated_tag(): self.propagated_tag_value()} tested_metrics: dict[str, str | Callable | None] = {SAMPLING_PRIORITY_KEY: lambda x: x == 2} - for data, trace, span in interfaces.library.get_spans(request=self.r): + for data, trace, span, span_format in interfaces.library.get_spans(request=self.r): assert assert_tags(trace[0], span, "meta", tested_meta) assert assert_tags(trace[0], span, "metrics", tested_metrics) - assert span["metrics"]["_dd.apm.enabled"] == 0 # if key missing -> APPSEC-55222 - assert span["trace_id"] == 1212121212121212121 - assert trace[0]["trace_id"] == 1212121212121212121 + metrics = interfaces.library.get_span_metrics(span, span_format) + assert metrics["_dd.apm.enabled"] == 0 # if key missing -> APPSEC-55222 + trace_id = interfaces.library.get_trace_id(trace, span_format) + assert trace_id == 1212121212121212121 # Some tracers use true while others use yes assert any( @@ -367,13 +376,13 @@ def test_no_upstream_appsec_propagation__with_asm_event__is_kept_with_priority_2 tested_meta: dict[str, str | Callable | None] = {self.propagated_tag(): self.propagated_tag_value()} tested_metrics: dict[str, str | Callable | None] = {SAMPLING_PRIORITY_KEY: lambda x: x == 2} - for data, trace, span in interfaces.library.get_spans(request=self.r): + for data, trace, span, span_format in interfaces.library.get_spans(request=self.r): assert assert_tags(trace[0], span, "meta", tested_meta) assert assert_tags(trace[0], span, "metrics", tested_metrics) assert span["metrics"]["_dd.apm.enabled"] == 0 - assert span["trace_id"] == 1212121212121212121 - assert trace[0]["trace_id"] == 1212121212121212121 + trace_id = interfaces.library.get_trace_id(trace, span_format) + assert trace_id == 1212121212121212121 # Some tracers use true while others use yes assert any( @@ -413,13 +422,14 @@ def test_upstream_appsec_propagation__no_asm_event__is_propagated_as_is__being_0 tested_meta: dict[str, str | Callable | None] = {self.propagated_tag(): self.propagated_tag_value()} tested_metrics: dict[str, str | Callable | None] = {SAMPLING_PRIORITY_KEY: lambda x: x in [0, 2]} - for data, trace, span in interfaces.library.get_spans(request=self.r): + for data, trace, span, span_format in interfaces.library.get_spans(request=self.r): assert assert_tags(trace[0], span, "meta", tested_meta) assert assert_tags(trace[0], span, "metrics", tested_metrics) - assert span["metrics"]["_dd.apm.enabled"] == 0 # if key missing -> APPSEC-55222 - assert span["trace_id"] == 1212121212121212121 - assert trace[0]["trace_id"] == 1212121212121212121 + metrics = interfaces.library.get_span_metrics(span, span_format) + assert metrics["_dd.apm.enabled"] == 0 # if key missing -> APPSEC-55222 + trace_id = interfaces.library.get_trace_id(trace, span_format) + assert trace_id == 1212121212121212121 # Some tracers use true while others use yes assert any( @@ -458,13 +468,14 @@ def test_upstream_appsec_propagation__no_asm_event__is_propagated_as_is__being_1 tested_meta: dict[str, str | Callable | None] = {self.propagated_tag(): self.propagated_tag_value()} tested_metrics: dict[str, str | Callable | None] = {SAMPLING_PRIORITY_KEY: lambda x: x in [1, 2]} - for data, trace, span in interfaces.library.get_spans(request=self.r): + for data, trace, span, span_format in interfaces.library.get_spans(request=self.r): assert assert_tags(trace[0], span, "meta", tested_meta) assert assert_tags(trace[0], span, "metrics", tested_metrics) - assert span["metrics"]["_dd.apm.enabled"] == 0 # if key missing -> APPSEC-55222 - assert span["trace_id"] == 1212121212121212121 - assert trace[0]["trace_id"] == 1212121212121212121 + metrics = interfaces.library.get_span_metrics(span, span_format) + assert metrics["_dd.apm.enabled"] == 0 # if key missing -> APPSEC-55222 + trace_id = interfaces.library.get_trace_id(trace, span_format) + assert trace_id == 1212121212121212121 # Some tracers use true while others use yes assert any( @@ -503,13 +514,14 @@ def test_upstream_appsec_propagation__no_asm_event__is_propagated_as_is__being_2 tested_meta: dict[str, str | Callable | None] = {self.propagated_tag(): self.propagated_tag_value()} tested_metrics: dict[str, str | Callable | None] = {SAMPLING_PRIORITY_KEY: lambda x: x == 2} - for data, trace, span in interfaces.library.get_spans(request=self.r): + for data, trace, span, span_format in interfaces.library.get_spans(request=self.r): assert assert_tags(trace[0], span, "meta", tested_meta) assert assert_tags(trace[0], span, "metrics", tested_metrics) - assert span["metrics"]["_dd.apm.enabled"] == 0 # if key missing -> APPSEC-55222 - assert span["trace_id"] == 1212121212121212121 - assert trace[0]["trace_id"] == 1212121212121212121 + metrics = interfaces.library.get_span_metrics(span, span_format) + assert metrics["_dd.apm.enabled"] == 0 # if key missing -> APPSEC-55222 + trace_id = interfaces.library.get_trace_id(trace, span_format) + assert trace_id == 1212121212121212121 # Some tracers use true while others use yes assert any( @@ -545,13 +557,14 @@ def test_any_upstream_propagation__with_asm_event__raises_priority_to_2__from_mi tested_meta: dict[str, str | Callable | None] = {self.propagated_tag(): self.propagated_tag_value()} tested_metrics: dict[str, str | Callable | None] = {SAMPLING_PRIORITY_KEY: lambda x: x == 2} - for data, trace, span in interfaces.library.get_spans(request=self.r): + for data, trace, span, span_format in interfaces.library.get_spans(request=self.r): assert assert_tags(trace[0], span, "meta", tested_meta) assert assert_tags(trace[0], span, "metrics", tested_metrics) - assert span["metrics"]["_dd.apm.enabled"] == 0 # if key missing -> APPSEC-55222 - assert span["trace_id"] == 1212121212121212121 - assert trace[0]["trace_id"] == 1212121212121212121 + metrics = interfaces.library.get_span_metrics(span, span_format) + assert metrics["_dd.apm.enabled"] == 0 # if key missing -> APPSEC-55222 + trace_id = interfaces.library.get_trace_id(trace, span_format) + assert trace_id == 1212121212121212121 # Some tracers use true while others use yes assert any( @@ -587,13 +600,14 @@ def test_any_upstream_propagation__with_asm_event__raises_priority_to_2__from_0( tested_meta: dict[str, str | Callable | None] = {self.propagated_tag(): self.propagated_tag_value()} tested_metrics: dict[str, str | Callable | None] = {SAMPLING_PRIORITY_KEY: lambda x: x == 2} - for data, trace, span in interfaces.library.get_spans(request=self.r): + for data, trace, span, span_format in interfaces.library.get_spans(request=self.r): assert assert_tags(trace[0], span, "meta", tested_meta) assert assert_tags(trace[0], span, "metrics", tested_metrics) - assert span["metrics"]["_dd.apm.enabled"] == 0 # if key missing -> APPSEC-55222 - assert span["trace_id"] == 1212121212121212121 - assert trace[0]["trace_id"] == 1212121212121212121 + metrics = interfaces.library.get_span_metrics(span, span_format) + assert metrics["_dd.apm.enabled"] == 0 # if key missing -> APPSEC-55222 + trace_id = interfaces.library.get_trace_id(trace, span_format) + assert trace_id == 1212121212121212121 # Some tracers use true while others use yes assert any( @@ -629,13 +643,14 @@ def test_any_upstream_propagation__with_asm_event__raises_priority_to_2__from_1( tested_meta: dict[str, str | Callable | None] = {self.propagated_tag(): self.propagated_tag_value()} tested_metrics: dict[str, str | Callable | None] = {SAMPLING_PRIORITY_KEY: lambda x: x == 2} - for data, trace, span in interfaces.library.get_spans(request=self.r): + for data, trace, span, span_format in interfaces.library.get_spans(request=self.r): assert assert_tags(trace[0], span, "meta", tested_meta) assert assert_tags(trace[0], span, "metrics", tested_metrics) - assert span["metrics"]["_dd.apm.enabled"] == 0 # if key missing -> APPSEC-55222 - assert span["trace_id"] == 1212121212121212121 - assert trace[0]["trace_id"] == 1212121212121212121 + metrics = interfaces.library.get_span_metrics(span, span_format) + assert metrics["_dd.apm.enabled"] == 0 # if key missing -> APPSEC-55222 + trace_id = interfaces.library.get_trace_id(trace, span_format) + assert trace_id == 1212121212121212121 # Some tracers use true while others use yes assert any( @@ -701,8 +716,11 @@ class BaseSCAStandaloneTelemetry: def assert_standalone_is_enabled(self, request0: HttpResponse, request1: HttpResponse): # test standalone is enabled and dropping traces spans_checked = 0 - for _, __, span in list(interfaces.library.get_spans(request0)) + list(interfaces.library.get_spans(request1)): - if span["metrics"]["_dd.apm.enabled"] == 0: + for _, __, span, span_format in list(interfaces.library.get_spans(request=request0)) + list( + interfaces.library.get_spans(request=request1) + ): + metrics = interfaces.library.get_span_metrics(span, span_format) + if metrics["_dd.apm.enabled"] == 0: spans_checked += 1 assert spans_checked > 0 @@ -796,8 +814,9 @@ def setup_client_computed_stats_header_is_not_present(self): def test_client_computed_stats_header_is_not_present(self): spans_checked = 0 - for data, _, span in interfaces.library.get_spans(request=self.r): - assert span["trace_id"] == 1212121212121212122 + for data, trace, _, span_format in interfaces.library.get_spans(request=self.r): + trace_id = interfaces.library.get_trace_id(trace, span_format) + assert trace_id == 1212121212121212122 assert "datadog-client-computed-stats" not in [x.lower() for x, y in data["request"]["headers"]] spans_checked += 1 assert spans_checked == 1 @@ -857,8 +876,8 @@ def propagated_tag_value(self) -> str: @staticmethod def get_schema(request: HttpResponse, address: str) -> list | None: """Extract API security schema from span metadata""" - for _, _, span in interfaces.library.get_spans(request=request): - meta = span.get("meta", {}) + for _, _, span, span_format in interfaces.library.get_spans(request=request): + meta = interfaces.library.get_span_meta(span, span_format) if payload := meta.get("_dd.appsec.s." + address): return payload return None @@ -871,11 +890,12 @@ def check_trace_retained(request: HttpResponse, *, should_be_retained: bool) -> tested_metrics: dict[str, str | Callable | None] = { SAMPLING_PRIORITY_KEY: lambda x: x == 2 if should_be_retained else x <= 0 } - for data, trace, span in interfaces.library.get_spans(request=request): - assert span["trace_id"] == 1212121212121212121 - assert trace[0]["trace_id"] == 1212121212121212121 + for data, trace, span, span_format in interfaces.library.get_spans(request=request): + trace_id = interfaces.library.get_trace_id(trace, span_format) + assert trace_id == 1212121212121212121 assert assert_tags(trace[0], span, "metrics", tested_metrics) - assert span["metrics"]["_dd.apm.enabled"] == 0 # if key missing -> APPSEC-55222 + metrics = interfaces.library.get_span_metrics(span, span_format) + assert metrics["_dd.apm.enabled"] == 0 # if key missing -> APPSEC-55222 # Check for client-computed-stats header headers = data["request"]["headers"] @@ -1031,19 +1051,20 @@ def _get_standalone_span_meta(self, trace_id: int): tested_meta: dict[str, str | Callable | None] = { "_dd.p.ts": "02", } - for data, trace, span in interfaces.library.get_spans(request=self.r): + for data, trace, span, span_format in interfaces.library.get_spans(request=self.r): assert assert_tags(trace[0], span, "meta", tested_meta) - assert span["metrics"]["_dd.apm.enabled"] == 0 # if key missing -> APPSEC-55222 - assert span["trace_id"] == trace_id - assert trace[0]["trace_id"] == trace_id + metrics = interfaces.library.get_span_metrics(span, span_format) + assert metrics["_dd.apm.enabled"] == 0 # if key missing -> APPSEC-55222 + span_trace_id = interfaces.library.get_trace_id(trace, span_format) + assert span_trace_id == trace_id # Some tracers use true while others use yes assert any( header.lower() == "datadog-client-computed-stats" and value.lower() in TRUTHY_VALUES for header, value in data["request"]["headers"] ) - return span["meta"] + return interfaces.library.get_span_meta(span, span_format) return None @@ -1104,19 +1125,20 @@ def _get_standalone_span_meta(self, trace_id: int): tested_meta: dict[str, str | Callable | None] = { "_dd.p.ts": "02", } - for data, trace, span in interfaces.library.get_spans(request=self.r): + for data, trace, span, span_format in interfaces.library.get_spans(request=self.r): assert assert_tags(trace[0], span, "meta", tested_meta) - assert span["metrics"]["_dd.apm.enabled"] == 0 # if key missing -> APPSEC-55222 - assert span["trace_id"] == trace_id - assert trace[0]["trace_id"] == trace_id + metrics = interfaces.library.get_span_metrics(span, span_format) + assert metrics["_dd.apm.enabled"] == 0 # if key missing -> APPSEC-55222 + span_trace_id = interfaces.library.get_trace_id(trace, span_format) + assert span_trace_id == trace_id # Some tracers use true while others use yes assert any( header.lower() == "datadog-client-computed-stats" and value.lower() in TRUTHY_VALUES for header, value in data["request"]["headers"] ) - return span["meta"] + return interfaces.library.get_span_meta(span, span_format) return None @@ -1168,19 +1190,20 @@ def _get_standalone_span_meta(self, trace_id: int): tested_meta: dict[str, str | Callable | None] = { "_dd.p.ts": "02", } - for data, trace, span in interfaces.library.get_spans(request=self.r): + for data, trace, span, span_format in interfaces.library.get_spans(request=self.r): assert assert_tags(trace[0], span, "meta", tested_meta) - assert span["metrics"]["_dd.apm.enabled"] == 0 # if key missing -> APPSEC-55222 - assert span["trace_id"] == trace_id - assert trace[0]["trace_id"] == trace_id + metrics = interfaces.library.get_span_metrics(span, span_format) + assert metrics["_dd.apm.enabled"] == 0 # if key missing -> APPSEC-55222 + span_trace_id = interfaces.library.get_trace_id(trace, span_format) + assert span_trace_id == trace_id # Some tracers use true while others use yes assert any( header.lower() == "datadog-client-computed-stats" and value.lower() in TRUTHY_VALUES for header, value in data["request"]["headers"] ) - return span["meta"] + return interfaces.library.get_span_meta(span, span_format) return None diff --git a/tests/appsec/test_automated_login_events.py b/tests/appsec/test_automated_login_events.py index f597521352a..2226328b0d0 100644 --- a/tests/appsec/test_automated_login_events.py +++ b/tests/appsec/test_automated_login_events.py @@ -91,48 +91,60 @@ def setup_login_pii_success_local(self): def test_login_pii_success_local(self): assert self.r_pii_success.status_code == 200 - for _, trace, span in interfaces.library.get_spans(request=self.r_pii_success): + for _, trace, span, _ in interfaces.library.get_spans(request=self.r_pii_success): meta = span.get("meta", {}) assert "usr.id" not in meta assert meta["_dd.appsec.events.users.login.success.auto.mode"] == "safe" assert meta["appsec.events.users.login.success.track"] == "true" - assert_priority(span, trace) + # Convert trace to list format for assert_priority + + trace_list = trace.get("spans", []) if isinstance(trace, dict) else trace + assert_priority(span, trace_list) def setup_login_pii_success_basic(self): self.r_pii_success = weblog.get("/login?auth=basic", headers={"Authorization": BASIC_AUTH_USER_HEADER}) def test_login_pii_success_basic(self): assert self.r_pii_success.status_code == 200 - for _, trace, span in interfaces.library.get_spans(request=self.r_pii_success): + for _, trace, span, _ in interfaces.library.get_spans(request=self.r_pii_success): meta = span.get("meta", {}) assert "usr.id" not in meta assert meta["_dd.appsec.events.users.login.success.auto.mode"] == "safe" assert meta["appsec.events.users.login.success.track"] == "true" - assert_priority(span, trace) + # Convert trace to list format for assert_priority + + trace_list = trace.get("spans", []) if isinstance(trace, dict) else trace + assert_priority(span, trace_list) def setup_login_success_local(self): self.r_success = weblog.post("/login?auth=local", data=login_data(UUID_USER, PASSWORD)) def test_login_success_local(self): assert self.r_success.status_code == 200 - for _, trace, span in interfaces.library.get_spans(request=self.r_success): + for _, trace, span, _ in interfaces.library.get_spans(request=self.r_success): meta = span.get("meta", {}) assert meta["_dd.appsec.events.users.login.success.auto.mode"] == "safe" assert meta["appsec.events.users.login.success.track"] == "true" assert meta["usr.id"] == "591dc126-8431-4d0f-9509-b23318d3dce4" - assert_priority(span, trace) + # Convert trace to list format for assert_priority + + trace_list = trace.get("spans", []) if isinstance(trace, dict) else trace + assert_priority(span, trace_list) def setup_login_success_basic(self): self.r_success = weblog.get("/login?auth=basic", headers={"Authorization": BASIC_AUTH_USER_UUID_HEADER}) def test_login_success_basic(self): assert self.r_success.status_code == 200 - for _, trace, span in interfaces.library.get_spans(request=self.r_success): + for _, trace, span, _ in interfaces.library.get_spans(request=self.r_success): meta = span.get("meta", {}) assert meta["_dd.appsec.events.users.login.success.auto.mode"] == "safe" assert meta["appsec.events.users.login.success.track"] == "true" assert meta["usr.id"] == "591dc126-8431-4d0f-9509-b23318d3dce4" - assert_priority(span, trace) + # Convert trace to list format for assert_priority + + trace_list = trace.get("spans", []) if isinstance(trace, dict) else trace + assert_priority(span, trace_list) def setup_login_wrong_user_failure_local(self): self.r_wrong_user_failure = weblog.post("/login?auth=local", data=login_data(INVALID_USER, PASSWORD)) @@ -140,7 +152,7 @@ def setup_login_wrong_user_failure_local(self): @missing_feature(weblog_variant="spring-boot-openliberty", reason="weblog returns error 500") def test_login_wrong_user_failure_local(self): assert self.r_wrong_user_failure.status_code == 401 - for _, trace, span in interfaces.library.get_spans(request=self.r_wrong_user_failure): + for _, trace, span, _ in interfaces.library.get_spans(request=self.r_wrong_user_failure): meta = span.get("meta", {}) if context.library not in ("nodejs", "java"): # Currently in nodejs/java there is no way to check if the user exists upon authentication failure so @@ -150,7 +162,10 @@ def test_login_wrong_user_failure_local(self): assert "appsec.events.users.login.failure.usr.id" not in meta assert meta["_dd.appsec.events.users.login.failure.auto.mode"] == "safe" assert meta["appsec.events.users.login.failure.track"] == "true" - assert_priority(span, trace) + # Convert trace to list format for assert_priority + + trace_list = trace.get("spans", []) if isinstance(trace, dict) else trace + assert_priority(span, trace_list) def setup_login_wrong_user_failure_basic(self): self.r_wrong_user_failure = weblog.get( @@ -160,7 +175,7 @@ def setup_login_wrong_user_failure_basic(self): @missing_feature(weblog_variant="spring-boot-openliberty", reason="weblog returns error 500") def test_login_wrong_user_failure_basic(self): assert self.r_wrong_user_failure.status_code == 401 - for _, trace, span in interfaces.library.get_spans(request=self.r_wrong_user_failure): + for _, trace, span, _ in interfaces.library.get_spans(request=self.r_wrong_user_failure): meta = span.get("meta", {}) if context.library not in ("nodejs", "java"): # Currently in nodejs/java there is no way to check if the user exists upon authentication failure so @@ -170,7 +185,10 @@ def test_login_wrong_user_failure_basic(self): assert "appsec.events.users.login.failure.usr.id" not in meta assert meta["_dd.appsec.events.users.login.failure.auto.mode"] == "safe" assert meta["appsec.events.users.login.failure.track"] == "true" - assert_priority(span, trace) + # Convert trace to list format for assert_priority + + trace_list = trace.get("spans", []) if isinstance(trace, dict) else trace + assert_priority(span, trace_list) def setup_login_wrong_password_failure_local(self): self.r_wrong_user_failure = weblog.post("/login?auth=local", data=login_data(USER, "12345")) @@ -178,7 +196,7 @@ def setup_login_wrong_password_failure_local(self): @missing_feature(weblog_variant="spring-boot-openliberty", reason="weblog returns error 500") def test_login_wrong_password_failure_local(self): assert self.r_wrong_user_failure.status_code == 401 - for _, trace, span in interfaces.library.get_spans(request=self.r_wrong_user_failure): + for _, trace, span, _ in interfaces.library.get_spans(request=self.r_wrong_user_failure): meta = span.get("meta", {}) if context.library not in ("nodejs", "java"): # Currently in nodejs/java there is no way to check if the user exists upon authentication failure so @@ -188,7 +206,10 @@ def test_login_wrong_password_failure_local(self): assert "appsec.events.users.login.failure.usr.id" not in meta assert meta["_dd.appsec.events.users.login.failure.auto.mode"] == "safe" assert meta["appsec.events.users.login.failure.track"] == "true" - assert_priority(span, trace) + # Convert trace to list format for assert_priority + + trace_list = trace.get("spans", []) if isinstance(trace, dict) else trace + assert_priority(span, trace_list) def setup_login_wrong_password_failure_basic(self): self.r_wrong_user_failure = weblog.get( @@ -198,7 +219,7 @@ def setup_login_wrong_password_failure_basic(self): @missing_feature(weblog_variant="spring-boot-openliberty", reason="weblog returns error 500") def test_login_wrong_password_failure_basic(self): assert self.r_wrong_user_failure.status_code == 401 - for _, trace, span in interfaces.library.get_spans(request=self.r_wrong_user_failure): + for _, trace, span, _ in interfaces.library.get_spans(request=self.r_wrong_user_failure): meta = span.get("meta", {}) if context.library not in ("nodejs", "java"): # Currently in nodejs/java there is no way to check if the user exists upon authentication failure so @@ -208,7 +229,10 @@ def test_login_wrong_password_failure_basic(self): assert "appsec.events.users.login.failure.usr.id" not in meta assert meta["_dd.appsec.events.users.login.failure.auto.mode"] == "safe" assert meta["appsec.events.users.login.failure.track"] == "true" - assert_priority(span, trace) + # Convert trace to list format for assert_priority + + trace_list = trace.get("spans", []) if isinstance(trace, dict) else trace + assert_priority(span, trace_list) def setup_login_sdk_success_local(self): self.r_sdk_success = weblog.post( @@ -218,13 +242,16 @@ def setup_login_sdk_success_local(self): def test_login_sdk_success_local(self): assert self.r_sdk_success.status_code == 200 - for _, trace, span in interfaces.library.get_spans(request=self.r_sdk_success): + for _, trace, span, _ in interfaces.library.get_spans(request=self.r_sdk_success): meta = span.get("meta", {}) assert meta["_dd.appsec.events.users.login.success.auto.mode"] == "safe" assert meta["_dd.appsec.events.users.login.success.sdk"] == "true" assert meta["appsec.events.users.login.success.track"] == "true" assert meta["usr.id"] == "sdkUser" - assert_priority(span, trace) + # Convert trace to list format for assert_priority + + trace_list = trace.get("spans", []) if isinstance(trace, dict) else trace + assert_priority(span, trace_list) def setup_login_sdk_success_basic(self): self.r_sdk_success = weblog.get( @@ -234,13 +261,16 @@ def setup_login_sdk_success_basic(self): def test_login_sdk_success_basic(self): assert self.r_sdk_success.status_code == 200 - for _, trace, span in interfaces.library.get_spans(request=self.r_sdk_success): + for _, trace, span, _ in interfaces.library.get_spans(request=self.r_sdk_success): meta = span.get("meta", {}) assert meta["_dd.appsec.events.users.login.success.auto.mode"] == "safe" assert meta["_dd.appsec.events.users.login.success.sdk"] == "true" assert meta["appsec.events.users.login.success.track"] == "true" assert meta["usr.id"] == "sdkUser" - assert_priority(span, trace) + # Convert trace to list format for assert_priority + + trace_list = trace.get("spans", []) if isinstance(trace, dict) else trace + assert_priority(span, trace_list) def setup_login_sdk_failure_local(self): self.r_sdk_failure = weblog.post( @@ -251,14 +281,17 @@ def setup_login_sdk_failure_local(self): @missing_feature(weblog_variant="spring-boot-openliberty", reason="weblog returns error 500") def test_login_sdk_failure_local(self): assert self.r_sdk_failure.status_code == 401 - for _, trace, span in interfaces.library.get_spans(request=self.r_sdk_failure): + for _, trace, span, _ in interfaces.library.get_spans(request=self.r_sdk_failure): meta = span.get("meta", {}) assert meta["_dd.appsec.events.users.login.failure.auto.mode"] == "safe" assert meta["_dd.appsec.events.users.login.failure.sdk"] == "true" assert meta["appsec.events.users.login.failure.track"] == "true" assert meta["appsec.events.users.login.failure.usr.id"] == "sdkUser" assert meta["appsec.events.users.login.failure.usr.exists"] == "true" - assert_priority(span, trace) + # Convert trace to list format for assert_priority + + trace_list = trace.get("spans", []) if isinstance(trace, dict) else trace + assert_priority(span, trace_list) def setup_login_sdk_failure_basic(self): self.r_sdk_failure = weblog.get( @@ -269,14 +302,17 @@ def setup_login_sdk_failure_basic(self): @missing_feature(weblog_variant="spring-boot-openliberty", reason="weblog returns error 500") def test_login_sdk_failure_basic(self): assert self.r_sdk_failure.status_code == 401 - for _, trace, span in interfaces.library.get_spans(request=self.r_sdk_failure): + for _, trace, span, _ in interfaces.library.get_spans(request=self.r_sdk_failure): meta = span.get("meta", {}) assert meta["_dd.appsec.events.users.login.failure.auto.mode"] == "safe" assert meta["_dd.appsec.events.users.login.failure.sdk"] == "true" assert meta["appsec.events.users.login.failure.track"] == "true" assert meta["appsec.events.users.login.failure.usr.id"] == "sdkUser" assert meta["appsec.events.users.login.failure.usr.exists"] == "true" - assert_priority(span, trace) + # Convert trace to list format for assert_priority + + trace_list = trace.get("spans", []) if isinstance(trace, dict) else trace + assert_priority(span, trace_list) @rfc("https://docs.google.com/document/d/1-trUpphvyZY7k5ldjhW-MgqWl0xOm7AMEQDJEAZ63_Q/edit#heading=h.8d3o7vtyu1y1") @@ -290,7 +326,7 @@ def setup_login_success_local(self): def test_login_success_local(self): assert self.r_success.status_code == 200 - for _, trace, span in interfaces.library.get_spans(request=self.r_success): + for _, trace, span, _ in interfaces.library.get_spans(request=self.r_success): meta = span.get("meta", {}) assert meta["_dd.appsec.events.users.login.success.auto.mode"] == "extended" assert meta["appsec.events.users.login.success.track"] == "true" @@ -314,14 +350,17 @@ def test_login_success_local(self): assert meta["usr.username"] == "test" assert meta["usr.login"] == "test" - assert_priority(span, trace) + # Convert trace to list format for assert_priority + + trace_list = trace.get("spans", []) if isinstance(trace, dict) else trace + assert_priority(span, trace_list) def setup_login_success_basic(self): self.r_success = weblog.get("/login?auth=basic", headers={"Authorization": BASIC_AUTH_USER_HEADER}) def test_login_success_basic(self): assert self.r_success.status_code == 200 - for _, trace, span in interfaces.library.get_spans(request=self.r_success): + for _, trace, span, _ in interfaces.library.get_spans(request=self.r_success): meta = span.get("meta", {}) assert meta["_dd.appsec.events.users.login.success.auto.mode"] == "extended" assert meta["appsec.events.users.login.success.track"] == "true" @@ -342,7 +381,10 @@ def test_login_success_basic(self): assert meta["usr.login"] == "test" assert meta["usr.email"] == "testuser@ddog.com" - assert_priority(span, trace) + # Convert trace to list format for assert_priority + + trace_list = trace.get("spans", []) if isinstance(trace, dict) else trace + assert_priority(span, trace_list) def setup_login_wrong_user_failure_local(self): self.r_wrong_user_failure = weblog.post("/login?auth=local", data=login_data(INVALID_USER, PASSWORD)) @@ -350,7 +392,7 @@ def setup_login_wrong_user_failure_local(self): @missing_feature(weblog_variant="spring-boot-openliberty", reason="weblog returns error 500") def test_login_wrong_user_failure_local(self): assert self.r_wrong_user_failure.status_code == 401 - for _, trace, span in interfaces.library.get_spans(request=self.r_wrong_user_failure): + for _, trace, span, _ in interfaces.library.get_spans(request=self.r_wrong_user_failure): meta = span.get("meta", {}) if context.library not in ("nodejs", "java"): # Currently in nodejs/java there is no way to check if the user exists upon authentication failure so @@ -370,7 +412,10 @@ def test_login_wrong_user_failure_local(self): assert meta["appsec.events.users.login.failure.username"] == INVALID_USER else: assert meta["appsec.events.users.login.failure.usr.id"] == INVALID_USER - assert_priority(span, trace) + # Convert trace to list format for assert_priority + + trace_list = trace.get("spans", []) if isinstance(trace, dict) else trace + assert_priority(span, trace_list) def setup_login_wrong_user_failure_basic(self): self.r_wrong_user_failure = weblog.get( @@ -380,7 +425,7 @@ def setup_login_wrong_user_failure_basic(self): @missing_feature(weblog_variant="spring-boot-openliberty", reason="weblog returns error 500") def test_login_wrong_user_failure_basic(self): assert self.r_wrong_user_failure.status_code == 401 - for _, trace, span in interfaces.library.get_spans(request=self.r_wrong_user_failure): + for _, trace, span, _ in interfaces.library.get_spans(request=self.r_wrong_user_failure): meta = span.get("meta", {}) if context.library not in ("nodejs", "java"): # Currently in nodejs/java there is no way to check if the user exists upon authentication failure so @@ -400,7 +445,10 @@ def test_login_wrong_user_failure_basic(self): assert meta["appsec.events.users.login.failure.username"] == INVALID_USER else: assert meta["appsec.events.users.login.failure.usr.id"] == INVALID_USER - assert_priority(span, trace) + # Convert trace to list format for assert_priority + + trace_list = trace.get("spans", []) if isinstance(trace, dict) else trace + assert_priority(span, trace_list) def setup_login_wrong_password_failure_local(self): self.r_wrong_user_failure = weblog.post("/login?auth=local", data=login_data(USER, "12345")) @@ -408,7 +456,7 @@ def setup_login_wrong_password_failure_local(self): @missing_feature(weblog_variant="spring-boot-openliberty", reason="weblog returns error 500") def test_login_wrong_password_failure_local(self): assert self.r_wrong_user_failure.status_code == 401 - for _, trace, span in interfaces.library.get_spans(request=self.r_wrong_user_failure): + for _, trace, span, _ in interfaces.library.get_spans(request=self.r_wrong_user_failure): meta = span.get("meta", {}) if context.library in ("nodejs", "java"): # Currently in nodejs/java there is no way to check if the user exists upon authentication failure so @@ -423,7 +471,10 @@ def test_login_wrong_password_failure_local(self): assert meta["_dd.appsec.events.users.login.failure.auto.mode"] == "extended" assert meta["appsec.events.users.login.failure.track"] == "true" - assert_priority(span, trace) + # Convert trace to list format for assert_priority + + trace_list = trace.get("spans", []) if isinstance(trace, dict) else trace + assert_priority(span, trace_list) def setup_login_wrong_password_failure_basic(self): self.r_wrong_user_failure = weblog.get( @@ -433,7 +484,7 @@ def setup_login_wrong_password_failure_basic(self): @missing_feature(weblog_variant="spring-boot-openliberty", reason="weblog returns error 500") def test_login_wrong_password_failure_basic(self): assert self.r_wrong_user_failure.status_code == 401 - for _, trace, span in interfaces.library.get_spans(request=self.r_wrong_user_failure): + for _, trace, span, _ in interfaces.library.get_spans(request=self.r_wrong_user_failure): meta = span.get("meta", {}) if context.library not in ("nodejs", "java"): # Currently in nodejs/java there is no way to check if the user exists upon authentication failure so @@ -448,7 +499,10 @@ def test_login_wrong_password_failure_basic(self): assert meta["_dd.appsec.events.users.login.failure.auto.mode"] == "extended" assert meta["appsec.events.users.login.failure.track"] == "true" - assert_priority(span, trace) + # Convert trace to list format for assert_priority + + trace_list = trace.get("spans", []) if isinstance(trace, dict) else trace + assert_priority(span, trace_list) def setup_login_sdk_success_local(self): self.r_sdk_success = weblog.post( @@ -458,13 +512,16 @@ def setup_login_sdk_success_local(self): def test_login_sdk_success_local(self): assert self.r_sdk_success.status_code == 200 - for _, trace, span in interfaces.library.get_spans(request=self.r_sdk_success): + for _, trace, span, _ in interfaces.library.get_spans(request=self.r_sdk_success): meta = span.get("meta", {}) assert meta["_dd.appsec.events.users.login.success.auto.mode"] == "extended" assert meta["_dd.appsec.events.users.login.success.sdk"] == "true" assert meta["appsec.events.users.login.success.track"] == "true" assert meta["usr.id"] == "sdkUser" - assert_priority(span, trace) + # Convert trace to list format for assert_priority + + trace_list = trace.get("spans", []) if isinstance(trace, dict) else trace + assert_priority(span, trace_list) def setup_login_sdk_success_basic(self): self.r_sdk_success = weblog.get( @@ -474,13 +531,16 @@ def setup_login_sdk_success_basic(self): def test_login_sdk_success_basic(self): assert self.r_sdk_success.status_code == 200 - for _, trace, span in interfaces.library.get_spans(request=self.r_sdk_success): + for _, trace, span, _ in interfaces.library.get_spans(request=self.r_sdk_success): meta = span.get("meta", {}) assert meta["_dd.appsec.events.users.login.success.auto.mode"] == "extended" assert meta["_dd.appsec.events.users.login.success.sdk"] == "true" assert meta["appsec.events.users.login.success.track"] == "true" assert meta["usr.id"] == "sdkUser" - assert_priority(span, trace) + # Convert trace to list format for assert_priority + + trace_list = trace.get("spans", []) if isinstance(trace, dict) else trace + assert_priority(span, trace_list) def setup_login_sdk_failure_basic(self): self.r_sdk_failure = weblog.get( @@ -491,14 +551,17 @@ def setup_login_sdk_failure_basic(self): @missing_feature(weblog_variant="spring-boot-openliberty", reason="weblog returns error 500") def test_login_sdk_failure_basic(self): assert self.r_sdk_failure.status_code == 401 - for _, trace, span in interfaces.library.get_spans(request=self.r_sdk_failure): + for _, trace, span, _ in interfaces.library.get_spans(request=self.r_sdk_failure): meta = span.get("meta", {}) assert meta["_dd.appsec.events.users.login.failure.auto.mode"] == "extended" assert meta["_dd.appsec.events.users.login.failure.sdk"] == "true" assert meta["appsec.events.users.login.failure.track"] == "true" assert meta["appsec.events.users.login.failure.usr.id"] == "sdkUser" assert meta["appsec.events.users.login.failure.usr.exists"] == "true" - assert_priority(span, trace) + # Convert trace to list format for assert_priority + + trace_list = trace.get("spans", []) if isinstance(trace, dict) else trace + assert_priority(span, trace_list) def setup_login_sdk_failure_local(self): self.r_sdk_failure = weblog.post( @@ -509,14 +572,17 @@ def setup_login_sdk_failure_local(self): @missing_feature(weblog_variant="spring-boot-openliberty", reason="weblog returns error 500") def test_login_sdk_failure_local(self): assert self.r_sdk_failure.status_code == 401 - for _, trace, span in interfaces.library.get_spans(request=self.r_sdk_failure): + for _, trace, span, _ in interfaces.library.get_spans(request=self.r_sdk_failure): meta = span.get("meta", {}) assert meta["_dd.appsec.events.users.login.failure.auto.mode"] == "extended" assert meta["_dd.appsec.events.users.login.failure.sdk"] == "true" assert meta["appsec.events.users.login.failure.track"] == "true" assert meta["appsec.events.users.login.failure.usr.id"] == "sdkUser" assert meta["appsec.events.users.login.failure.usr.exists"] == "true" - assert_priority(span, trace) + # Convert trace to list format for assert_priority + + trace_list = trace.get("spans", []) if isinstance(trace, dict) else trace + assert_priority(span, trace_list) def setup_login_success_headers(self): self.r_hdr_success = weblog.post( @@ -588,7 +654,7 @@ def setup_login_pii_success_local(self): def test_login_pii_success_local(self): assert self.r_pii_success.status_code == 200 - for _, trace, span in interfaces.library.get_spans(request=self.r_pii_success): + for _, trace, span, _ in interfaces.library.get_spans(request=self.r_pii_success): meta = span.get("meta", {}) assert "usr.id" in meta assert meta["usr.id"] == "social-security-id" @@ -598,14 +664,17 @@ def test_login_pii_success_local(self): assert "appsec.events.users.login.success.login" not in meta assert meta["_dd.appsec.events.users.login.success.auto.mode"] == "identification" assert meta["appsec.events.users.login.success.track"] == "true" - assert_priority(span, trace) + # Convert trace to list format for assert_priority + + trace_list = trace.get("spans", []) if isinstance(trace, dict) else trace + assert_priority(span, trace_list) def setup_login_pii_success_basic(self): self.r_pii_success = weblog.get("/login?auth=basic", headers={"Authorization": BASIC_AUTH_USER_HEADER}) def test_login_pii_success_basic(self): assert self.r_pii_success.status_code == 200 - for _, trace, span in interfaces.library.get_spans(request=self.r_pii_success): + for _, trace, span, _ in interfaces.library.get_spans(request=self.r_pii_success): meta = span.get("meta", {}) assert "usr.id" in meta assert meta["usr.id"] == "social-security-id" @@ -615,14 +684,17 @@ def test_login_pii_success_basic(self): assert "appsec.events.users.login.success.login" not in meta assert meta["_dd.appsec.events.users.login.success.auto.mode"] == "identification" assert meta["appsec.events.users.login.success.track"] == "true" - assert_priority(span, trace) + # Convert trace to list format for assert_priority + + trace_list = trace.get("spans", []) if isinstance(trace, dict) else trace + assert_priority(span, trace_list) def setup_login_success_local(self): self.r_success = weblog.post("/login?auth=local", data=login_data(UUID_USER, PASSWORD)) def test_login_success_local(self): assert self.r_success.status_code == 200 - for _, trace, span in interfaces.library.get_spans(request=self.r_success): + for _, trace, span, _ in interfaces.library.get_spans(request=self.r_success): meta = span.get("meta", {}) assert meta["_dd.appsec.events.users.login.success.auto.mode"] == "identification" assert meta["appsec.events.users.login.success.track"] == "true" @@ -631,14 +703,17 @@ def test_login_success_local(self): assert "appsec.events.users.login.success.email" not in meta assert "appsec.events.users.login.success.username" not in meta assert "appsec.events.users.login.success.login" not in meta - assert_priority(span, trace) + # Convert trace to list format for assert_priority + + trace_list = trace.get("spans", []) if isinstance(trace, dict) else trace + assert_priority(span, trace_list) def setup_login_success_basic(self): self.r_success = weblog.get("/login?auth=basic", headers={"Authorization": BASIC_AUTH_USER_UUID_HEADER}) def test_login_success_basic(self): assert self.r_success.status_code == 200 - for _, trace, span in interfaces.library.get_spans(request=self.r_success): + for _, trace, span, _ in interfaces.library.get_spans(request=self.r_success): meta = span.get("meta", {}) assert meta["_dd.appsec.events.users.login.success.auto.mode"] == "identification" assert meta["appsec.events.users.login.success.track"] == "true" @@ -647,14 +722,17 @@ def test_login_success_basic(self): assert "appsec.events.users.login.success.email" not in meta assert "appsec.events.users.login.success.username" not in meta assert "appsec.events.users.login.success.login" not in meta - assert_priority(span, trace) + # Convert trace to list format for assert_priority + + trace_list = trace.get("spans", []) if isinstance(trace, dict) else trace + assert_priority(span, trace_list) def setup_login_wrong_user_failure_local(self): self.r_wrong_user_failure = weblog.post("/login?auth=local", data=login_data(INVALID_USER, PASSWORD)) def test_login_wrong_user_failure_local(self): assert self.r_wrong_user_failure.status_code == 401 - for _, trace, span in interfaces.library.get_spans(request=self.r_wrong_user_failure): + for _, trace, span, _ in interfaces.library.get_spans(request=self.r_wrong_user_failure): meta = span.get("meta", {}) if context.library not in ("nodejs", "java"): # Currently in nodejs/java there is no way to check if the user exists upon authentication failure so @@ -666,7 +744,10 @@ def test_login_wrong_user_failure_local(self): assert "appsec.events.users.login.failure.usr.login" not in meta assert meta["_dd.appsec.events.users.login.failure.auto.mode"] == "identification" assert meta["appsec.events.users.login.failure.track"] == "true" - assert_priority(span, trace) + # Convert trace to list format for assert_priority + + trace_list = trace.get("spans", []) if isinstance(trace, dict) else trace + assert_priority(span, trace_list) def setup_login_wrong_user_failure_basic(self): self.r_wrong_user_failure = weblog.get( @@ -675,7 +756,7 @@ def setup_login_wrong_user_failure_basic(self): def test_login_wrong_user_failure_basic(self): assert self.r_wrong_user_failure.status_code == 401 - for _, trace, span in interfaces.library.get_spans(request=self.r_wrong_user_failure): + for _, trace, span, _ in interfaces.library.get_spans(request=self.r_wrong_user_failure): meta = span.get("meta", {}) if context.library not in ("nodejs", "java"): # Currently in nodejs/java there is no way to check if the user exists upon authentication failure so @@ -687,14 +768,17 @@ def test_login_wrong_user_failure_basic(self): assert "appsec.events.users.login.failure.usr.login" not in meta assert meta["_dd.appsec.events.users.login.failure.auto.mode"] == "identification" assert meta["appsec.events.users.login.failure.track"] == "true" - assert_priority(span, trace) + # Convert trace to list format for assert_priority + + trace_list = trace.get("spans", []) if isinstance(trace, dict) else trace + assert_priority(span, trace_list) def setup_login_wrong_password_failure_local(self): self.r_wrong_user_failure = weblog.post("/login?auth=local", data=login_data(USER, "12345")) def test_login_wrong_password_failure_local(self): assert self.r_wrong_user_failure.status_code == 401 - for _, trace, span in interfaces.library.get_spans(request=self.r_wrong_user_failure): + for _, trace, span, _ in interfaces.library.get_spans(request=self.r_wrong_user_failure): meta = span.get("meta", {}) if context.library not in ("nodejs", "java"): # Currently in nodejs/java there is no way to check if the user exists upon authentication failure so @@ -712,7 +796,10 @@ def test_login_wrong_password_failure_local(self): assert "appsec.events.users.login.failure.usr.login" not in meta assert meta["_dd.appsec.events.users.login.failure.auto.mode"] == "identification" assert meta["appsec.events.users.login.failure.track"] == "true" - assert_priority(span, trace) + # Convert trace to list format for assert_priority + + trace_list = trace.get("spans", []) if isinstance(trace, dict) else trace + assert_priority(span, trace_list) def setup_login_wrong_password_failure_basic(self): self.r_wrong_user_failure = weblog.get( @@ -721,7 +808,7 @@ def setup_login_wrong_password_failure_basic(self): def test_login_wrong_password_failure_basic(self): assert self.r_wrong_user_failure.status_code == 401 - for _, trace, span in interfaces.library.get_spans(request=self.r_wrong_user_failure): + for _, trace, span, _ in interfaces.library.get_spans(request=self.r_wrong_user_failure): meta = span.get("meta", {}) if context.library not in ("nodejs", "java"): # Currently in nodejs/java there is no way to check if the user exists upon authentication failure so @@ -738,7 +825,10 @@ def test_login_wrong_password_failure_basic(self): assert "appsec.events.users.login.failure.usr.login" not in meta assert meta["_dd.appsec.events.users.login.failure.auto.mode"] == "identification" assert meta["appsec.events.users.login.failure.track"] == "true" - assert_priority(span, trace) + # Convert trace to list format for assert_priority + + trace_list = trace.get("spans", []) if isinstance(trace, dict) else trace + assert_priority(span, trace_list) def setup_login_sdk_success_local(self): self.r_sdk_success = weblog.post( @@ -748,13 +838,16 @@ def setup_login_sdk_success_local(self): def test_login_sdk_success_local(self): assert self.r_sdk_success.status_code == 200 - for _, trace, span in interfaces.library.get_spans(request=self.r_sdk_success): + for _, trace, span, _ in interfaces.library.get_spans(request=self.r_sdk_success): meta = span.get("meta", {}) assert meta["_dd.appsec.events.users.login.success.auto.mode"] == "identification" assert meta["_dd.appsec.events.users.login.success.sdk"] == "true" assert meta["appsec.events.users.login.success.track"] == "true" assert meta["usr.id"] == "sdkUser" - assert_priority(span, trace) + # Convert trace to list format for assert_priority + + trace_list = trace.get("spans", []) if isinstance(trace, dict) else trace + assert_priority(span, trace_list) def setup_login_sdk_success_basic(self): self.r_sdk_success = weblog.get( @@ -764,13 +857,16 @@ def setup_login_sdk_success_basic(self): def test_login_sdk_success_basic(self): assert self.r_sdk_success.status_code == 200 - for _, trace, span in interfaces.library.get_spans(request=self.r_sdk_success): + for _, trace, span, _ in interfaces.library.get_spans(request=self.r_sdk_success): meta = span.get("meta", {}) assert meta["_dd.appsec.events.users.login.success.auto.mode"] == "identification" assert meta["_dd.appsec.events.users.login.success.sdk"] == "true" assert meta["appsec.events.users.login.success.track"] == "true" assert meta["usr.id"] == "sdkUser" - assert_priority(span, trace) + # Convert trace to list format for assert_priority + + trace_list = trace.get("spans", []) if isinstance(trace, dict) else trace + assert_priority(span, trace_list) def setup_login_sdk_failure_local(self): self.r_sdk_failure = weblog.post( @@ -780,14 +876,17 @@ def setup_login_sdk_failure_local(self): def test_login_sdk_failure_local(self): assert self.r_sdk_failure.status_code == 401 - for _, trace, span in interfaces.library.get_spans(request=self.r_sdk_failure): + for _, trace, span, _ in interfaces.library.get_spans(request=self.r_sdk_failure): meta = span.get("meta", {}) assert meta["_dd.appsec.events.users.login.failure.auto.mode"] == "identification" assert meta["_dd.appsec.events.users.login.failure.sdk"] == "true" assert meta["appsec.events.users.login.failure.track"] == "true" assert meta["appsec.events.users.login.failure.usr.id"] == "sdkUser" assert meta["appsec.events.users.login.failure.usr.exists"] == "true" - assert_priority(span, trace) + # Convert trace to list format for assert_priority + + trace_list = trace.get("spans", []) if isinstance(trace, dict) else trace + assert_priority(span, trace_list) def setup_login_sdk_failure_basic(self): self.r_sdk_failure = weblog.get( @@ -797,14 +896,17 @@ def setup_login_sdk_failure_basic(self): def test_login_sdk_failure_basic(self): assert self.r_sdk_failure.status_code == 401 - for _, trace, span in interfaces.library.get_spans(request=self.r_sdk_failure): + for _, trace, span, _ in interfaces.library.get_spans(request=self.r_sdk_failure): meta = span.get("meta", {}) assert meta["_dd.appsec.events.users.login.failure.auto.mode"] == "identification" assert meta["_dd.appsec.events.users.login.failure.sdk"] == "true" assert meta["appsec.events.users.login.failure.track"] == "true" assert meta["appsec.events.users.login.failure.usr.id"] == "sdkUser" assert meta["appsec.events.users.login.failure.usr.exists"] == "true" - assert_priority(span, trace) + # Convert trace to list format for assert_priority + + trace_list = trace.get("spans", []) if isinstance(trace, dict) else trace + assert_priority(span, trace_list) @rfc("https://docs.google.com/document/d/19VHLdJLVFwRb_JrE87fmlIM5CL5LdOBv4AmLxgdo9qI/edit") @@ -821,7 +923,7 @@ def setup_login_success_local(self): def test_login_success_local(self): assert self.r_success.status_code == 200 - for _, trace, span in interfaces.library.get_spans(request=self.r_success): + for _, trace, span, _ in interfaces.library.get_spans(request=self.r_success): meta = span.get("meta", {}) assert meta["_dd.appsec.events.users.login.success.auto.mode"] == "anonymization" assert meta["appsec.events.users.login.success.track"] == "true" @@ -834,14 +936,17 @@ def test_login_success_local(self): # "usr.username" not in meta # "usr.login" not in meta - assert_priority(span, trace) + # Convert trace to list format for assert_priority + + trace_list = trace.get("spans", []) if isinstance(trace, dict) else trace + assert_priority(span, trace_list) def setup_login_success_basic(self): self.r_success = weblog.get("/login?auth=basic", headers={"Authorization": BASIC_AUTH_USER_HEADER}) def test_login_success_basic(self): assert self.r_success.status_code == 200 - for _, trace, span in interfaces.library.get_spans(request=self.r_success): + for _, trace, span, _ in interfaces.library.get_spans(request=self.r_success): meta = span.get("meta", {}) assert meta["_dd.appsec.events.users.login.success.auto.mode"] == "anonymization" assert meta["appsec.events.users.login.success.track"] == "true" @@ -854,14 +959,17 @@ def test_login_success_basic(self): # "usr.username" not in meta # "usr.login" not in meta - assert_priority(span, trace) + # Convert trace to list format for assert_priority + + trace_list = trace.get("spans", []) if isinstance(trace, dict) else trace + assert_priority(span, trace_list) def setup_login_wrong_user_failure_local(self): self.r_wrong_user_failure = weblog.post("/login?auth=local", data=login_data(INVALID_USER, PASSWORD)) def test_login_wrong_user_failure_local(self): assert self.r_wrong_user_failure.status_code == 401 - for _, trace, span in interfaces.library.get_spans(request=self.r_wrong_user_failure): + for _, trace, span, _ in interfaces.library.get_spans(request=self.r_wrong_user_failure): meta = span.get("meta", {}) assert meta["appsec.events.users.login.failure.usr.exists"] == "false" @@ -872,7 +980,10 @@ def test_login_wrong_user_failure_local(self): assert "appsec.events.users.login.failure.email" not in meta assert "appsec.events.users.login.failure.username" not in meta - assert_priority(span, trace) + # Convert trace to list format for assert_priority + + trace_list = trace.get("spans", []) if isinstance(trace, dict) else trace + assert_priority(span, trace_list) def setup_login_wrong_user_failure_basic(self): self.r_wrong_user_failure = weblog.get( @@ -881,7 +992,7 @@ def setup_login_wrong_user_failure_basic(self): def test_login_wrong_user_failure_basic(self): assert self.r_wrong_user_failure.status_code == 401 - for _, trace, span in interfaces.library.get_spans(request=self.r_wrong_user_failure): + for _, trace, span, _ in interfaces.library.get_spans(request=self.r_wrong_user_failure): meta = span.get("meta", {}) assert meta["appsec.events.users.login.failure.usr.exists"] == "false" @@ -892,14 +1003,17 @@ def test_login_wrong_user_failure_basic(self): assert "appsec.events.users.login.failure.email" not in meta assert "appsec.events.users.login.failure.username" not in meta - assert_priority(span, trace) + # Convert trace to list format for assert_priority + + trace_list = trace.get("spans", []) if isinstance(trace, dict) else trace + assert_priority(span, trace_list) def setup_login_wrong_password_failure_local(self): self.r_wrong_user_failure = weblog.post("/login?auth=local", data=login_data(USER, "12345")) def test_login_wrong_password_failure_local(self): assert self.r_wrong_user_failure.status_code == 401 - for _, trace, span in interfaces.library.get_spans(request=self.r_wrong_user_failure): + for _, trace, span, _ in interfaces.library.get_spans(request=self.r_wrong_user_failure): meta = span.get("meta", {}) if context.library != "java": # Currently in java there is no way to check if the user exists upon authentication failure so @@ -917,7 +1031,10 @@ def test_login_wrong_password_failure_local(self): assert "appsec.events.users.login.failure.email" not in meta assert "appsec.events.users.login.failure.username" not in meta - assert_priority(span, trace) + # Convert trace to list format for assert_priority + + trace_list = trace.get("spans", []) if isinstance(trace, dict) else trace + assert_priority(span, trace_list) def setup_login_wrong_password_failure_basic(self): self.r_wrong_user_failure = weblog.get( @@ -926,7 +1043,7 @@ def setup_login_wrong_password_failure_basic(self): def test_login_wrong_password_failure_basic(self): assert self.r_wrong_user_failure.status_code == 401 - for _, trace, span in interfaces.library.get_spans(request=self.r_wrong_user_failure): + for _, trace, span, _ in interfaces.library.get_spans(request=self.r_wrong_user_failure): meta = span.get("meta", {}) if context.library != "java": # Currently in java there is no way to check if the user exists upon authentication failure so @@ -944,7 +1061,10 @@ def test_login_wrong_password_failure_basic(self): assert "appsec.events.users.login.failure.email" not in meta assert "appsec.events.users.login.failure.username" not in meta - assert_priority(span, trace) + # Convert trace to list format for assert_priority + + trace_list = trace.get("spans", []) if isinstance(trace, dict) else trace + assert_priority(span, trace_list) def setup_login_sdk_success_local(self): self.r_sdk_success = weblog.post( @@ -954,13 +1074,16 @@ def setup_login_sdk_success_local(self): def test_login_sdk_success_local(self): assert self.r_sdk_success.status_code == 200 - for _, trace, span in interfaces.library.get_spans(request=self.r_sdk_success): + for _, trace, span, _ in interfaces.library.get_spans(request=self.r_sdk_success): meta = span.get("meta", {}) assert meta["_dd.appsec.events.users.login.success.auto.mode"] == "anonymization" assert meta["_dd.appsec.events.users.login.success.sdk"] == "true" assert meta["appsec.events.users.login.success.track"] == "true" assert meta["usr.id"] == "sdkUser" - assert_priority(span, trace) + # Convert trace to list format for assert_priority + + trace_list = trace.get("spans", []) if isinstance(trace, dict) else trace + assert_priority(span, trace_list) def setup_login_sdk_success_basic(self): self.r_sdk_success = weblog.get( @@ -970,13 +1093,16 @@ def setup_login_sdk_success_basic(self): def test_login_sdk_success_basic(self): assert self.r_sdk_success.status_code == 200 - for _, trace, span in interfaces.library.get_spans(request=self.r_sdk_success): + for _, trace, span, _ in interfaces.library.get_spans(request=self.r_sdk_success): meta = span.get("meta", {}) assert meta["_dd.appsec.events.users.login.success.auto.mode"] == "anonymization" assert meta["_dd.appsec.events.users.login.success.sdk"] == "true" assert meta["appsec.events.users.login.success.track"] == "true" assert meta["usr.id"] == "sdkUser" - assert_priority(span, trace) + # Convert trace to list format for assert_priority + + trace_list = trace.get("spans", []) if isinstance(trace, dict) else trace + assert_priority(span, trace_list) def setup_login_sdk_failure_basic(self): self.r_sdk_failure = weblog.get( @@ -986,14 +1112,17 @@ def setup_login_sdk_failure_basic(self): def test_login_sdk_failure_basic(self): assert self.r_sdk_failure.status_code == 401 - for _, trace, span in interfaces.library.get_spans(request=self.r_sdk_failure): + for _, trace, span, _ in interfaces.library.get_spans(request=self.r_sdk_failure): meta = span.get("meta", {}) assert meta["_dd.appsec.events.users.login.failure.auto.mode"] == "anonymization" assert meta["_dd.appsec.events.users.login.failure.sdk"] == "true" assert meta["appsec.events.users.login.failure.track"] == "true" assert meta["appsec.events.users.login.failure.usr.id"] == "sdkUser" assert meta["appsec.events.users.login.failure.usr.exists"] == "true" - assert_priority(span, trace) + # Convert trace to list format for assert_priority + + trace_list = trace.get("spans", []) if isinstance(trace, dict) else trace + assert_priority(span, trace_list) def setup_login_sdk_failure_local(self): self.r_sdk_failure = weblog.post( @@ -1003,14 +1132,17 @@ def setup_login_sdk_failure_local(self): def test_login_sdk_failure_local(self): assert self.r_sdk_failure.status_code == 401 - for _, trace, span in interfaces.library.get_spans(request=self.r_sdk_failure): + for _, trace, span, _ in interfaces.library.get_spans(request=self.r_sdk_failure): meta = span.get("meta", {}) assert meta["_dd.appsec.events.users.login.failure.auto.mode"] == "anonymization" assert meta["_dd.appsec.events.users.login.failure.sdk"] == "true" assert meta["appsec.events.users.login.failure.track"] == "true" assert meta["appsec.events.users.login.failure.usr.id"] == "sdkUser" assert meta["appsec.events.users.login.failure.usr.exists"] == "true" - assert_priority(span, trace) + # Convert trace to list format for assert_priority + + trace_list = trace.get("spans", []) if isinstance(trace, dict) else trace + assert_priority(span, trace_list) def setup_login_success_headers(self): self.r_hdr_success = weblog.post( @@ -1053,12 +1185,19 @@ def validate_login_failure_headers(span: dict): interfaces.library.validate_one_span(self.r_hdr_failure, validator=validate_login_failure_headers) -def assert_priority(span: dict, trace: list[dict]): - if "_sampling_priority_v1" not in span["metrics"]: +def assert_priority(span: dict, trace: list[dict] | dict): + # Convert trace to list format if it's a dict (v1 format) + if isinstance(trace, dict): + trace = trace.get("spans", []) + # Use get_span_metrics to handle both v04 and v1 formats (pass None to auto-detect) + span_metrics = interfaces.library.get_span_metrics(span, None) + if "_sampling_priority_v1" not in span_metrics: # some tracers like java only send the priority in the first and last span of the trace - assert trace[0]["metrics"].get("_sampling_priority_v1") == SamplingPriority.USER_KEEP + first_span = trace[0] + first_span_metrics = interfaces.library.get_span_metrics(first_span, None) + assert first_span_metrics.get("_sampling_priority_v1") == SamplingPriority.USER_KEEP else: - assert span["metrics"].get("_sampling_priority_v1") == SamplingPriority.USER_KEEP + assert span_metrics.get("_sampling_priority_v1") == SamplingPriority.USER_KEEP @rfc("https://docs.google.com/document/d/19VHLdJLVFwRb_JrE87fmlIM5CL5LdOBv4AmLxgdo9qI/edit") @@ -1131,7 +1270,7 @@ def _assert_response(self, test: dict, validation: Callable): assert config_states.state == rc.ApplyState.ACKNOWLEDGED assert request.status_code == 200 - spans = [s for _, _, s in interfaces.library.get_spans(request=request)] + spans = [s for _, _, s, _ in interfaces.library.get_spans(request=request)] assert spans, "No spans to validate" for span in spans: meta = span.get("meta", {}) @@ -1199,8 +1338,11 @@ def setup_login_success_local(self): def test_login_success_local(self): assert self.r_success.status_code == 200 - for _, trace, span in interfaces.library.get_spans(request=self.r_success): - assert_priority(span, trace) + for _, trace, span, _ in interfaces.library.get_spans(request=self.r_success): + # Convert trace to list format for assert_priority + + trace_list = trace.get("spans", []) if isinstance(trace, dict) else trace + assert_priority(span, trace_list) meta = span.get("meta", {}) # mandatory @@ -1219,8 +1361,11 @@ def setup_login_success_basic(self): def test_login_success_basic(self): assert self.r_success.status_code == 200 - for _, trace, span in interfaces.library.get_spans(request=self.r_success): - assert_priority(span, trace) + for _, trace, span, _ in interfaces.library.get_spans(request=self.r_success): + # Convert trace to list format for assert_priority + + trace_list = trace.get("spans", []) if isinstance(trace, dict) else trace + assert_priority(span, trace_list) meta = span.get("meta", {}) # mandatory @@ -1239,8 +1384,11 @@ def setup_login_wrong_user_failure_local(self): def test_login_wrong_user_failure_local(self): assert self.r_wrong_user_failure.status_code == 401 - for _, trace, span in interfaces.library.get_spans(request=self.r_wrong_user_failure): - assert_priority(span, trace) + for _, trace, span, _ in interfaces.library.get_spans(request=self.r_wrong_user_failure): + # Convert trace to list format for assert_priority + + trace_list = trace.get("spans", []) if isinstance(trace, dict) else trace + assert_priority(span, trace_list) meta = span.get("meta", {}) # mandatory @@ -1260,8 +1408,11 @@ def setup_login_wrong_user_failure_basic(self): def test_login_wrong_user_failure_basic(self): assert self.r_wrong_user_failure.status_code == 401 - for _, trace, span in interfaces.library.get_spans(request=self.r_wrong_user_failure): - assert_priority(span, trace) + for _, trace, span, _ in interfaces.library.get_spans(request=self.r_wrong_user_failure): + # Convert trace to list format for assert_priority + + trace_list = trace.get("spans", []) if isinstance(trace, dict) else trace + assert_priority(span, trace_list) meta = span.get("meta", {}) # mandatory @@ -1279,8 +1430,11 @@ def setup_login_wrong_password_failure_local(self): def test_login_wrong_password_failure_local(self): assert self.r_wrong_user_failure.status_code == 401 - for _, trace, span in interfaces.library.get_spans(request=self.r_wrong_user_failure): - assert_priority(span, trace) + for _, trace, span, _ in interfaces.library.get_spans(request=self.r_wrong_user_failure): + # Convert trace to list format for assert_priority + + trace_list = trace.get("spans", []) if isinstance(trace, dict) else trace + assert_priority(span, trace_list) meta = span.get("meta", {}) # mandatory @@ -1304,8 +1458,11 @@ def setup_login_wrong_password_failure_basic(self): def test_login_wrong_password_failure_basic(self): assert self.r_wrong_user_failure.status_code == 401 - for _, trace, span in interfaces.library.get_spans(request=self.r_wrong_user_failure): - assert_priority(span, trace) + for _, trace, span, _ in interfaces.library.get_spans(request=self.r_wrong_user_failure): + # Convert trace to list format for assert_priority + + trace_list = trace.get("spans", []) if isinstance(trace, dict) else trace + assert_priority(span, trace_list) meta = span.get("meta", {}) # mandatory @@ -1334,8 +1491,11 @@ def setup_login_sdk_success_local(self): def test_login_sdk_success_local(self): for request in self.r_sdk_success: assert request.status_code == 200 - for _, trace, span in interfaces.library.get_spans(request=request): - assert_priority(span, trace) + for _, trace, span, _ in interfaces.library.get_spans(request=request): + # Convert trace to list format for assert_priority + + trace_list = trace.get("spans", []) if isinstance(trace, dict) else trace + assert_priority(span, trace_list) meta = span.get("meta", {}) # mandatory @@ -1362,8 +1522,11 @@ def setup_login_sdk_success_basic(self): def test_login_sdk_success_basic(self): for request in self.r_sdk_success: assert request.status_code == 200 - for _, trace, span in interfaces.library.get_spans(request=request): - assert_priority(span, trace) + for _, trace, span, _ in interfaces.library.get_spans(request=request): + # Convert trace to list format for assert_priority + + trace_list = trace.get("spans", []) if isinstance(trace, dict) else trace + assert_priority(span, trace_list) meta = span.get("meta", {}) # mandatory @@ -1390,8 +1553,11 @@ def setup_login_sdk_failure_local(self): def test_login_sdk_failure_local(self): for request in self.r_sdk_failure: assert request.status_code == 401 - for _, trace, span in interfaces.library.get_spans(request=request): - assert_priority(span, trace) + for _, trace, span, _ in interfaces.library.get_spans(request=request): + # Convert trace to list format for assert_priority + + trace_list = trace.get("spans", []) if isinstance(trace, dict) else trace + assert_priority(span, trace_list) meta = span.get("meta", {}) # mandatory @@ -1414,8 +1580,11 @@ def setup_login_sdk_failure_basic(self): def test_login_sdk_failure_basic(self): for request in self.r_sdk_failure: assert request.status_code == 401 - for _, trace, span in interfaces.library.get_spans(request=request): - assert_priority(span, trace) + for _, trace, span, _ in interfaces.library.get_spans(request=request): + # Convert trace to list format for assert_priority + + trace_list = trace.get("spans", []) if isinstance(trace, dict) else trace + assert_priority(span, trace_list) meta = span.get("meta", {}) # mandatory @@ -1431,8 +1600,11 @@ def setup_signup_local(self): def test_signup_local(self): assert self.r_success.status_code == 200 - for _, trace, span in interfaces.library.get_spans(request=self.r_success): - assert_priority(span, trace) + for _, trace, span, _ in interfaces.library.get_spans(request=self.r_success): + # Convert trace to list format for assert_priority + + trace_list = trace.get("spans", []) if isinstance(trace, dict) else trace + assert_priority(span, trace_list) meta = span.get("meta", {}) # mandatory @@ -1501,8 +1673,11 @@ def setup_login_success_local(self): def test_login_success_local(self): assert self.r_success.status_code == 200 - for _, trace, span in interfaces.library.get_spans(request=self.r_success): - assert_priority(span, trace) + for _, trace, span, _ in interfaces.library.get_spans(request=self.r_success): + # Convert trace to list format for assert_priority + + trace_list = trace.get("spans", []) if isinstance(trace, dict) else trace + assert_priority(span, trace_list) meta = span.get("meta", {}) # mandatory @@ -1521,8 +1696,11 @@ def setup_login_success_basic(self): def test_login_success_basic(self): assert self.r_success.status_code == 200 - for _, trace, span in interfaces.library.get_spans(request=self.r_success): - assert_priority(span, trace) + for _, trace, span, _ in interfaces.library.get_spans(request=self.r_success): + # Convert trace to list format for assert_priority + + trace_list = trace.get("spans", []) if isinstance(trace, dict) else trace + assert_priority(span, trace_list) meta = span.get("meta", {}) # mandatory @@ -1541,8 +1719,11 @@ def setup_login_wrong_user_failure_local(self): def test_login_wrong_user_failure_local(self): assert self.r_wrong_user_failure.status_code == 401 - for _, trace, span in interfaces.library.get_spans(request=self.r_wrong_user_failure): - assert_priority(span, trace) + for _, trace, span, _ in interfaces.library.get_spans(request=self.r_wrong_user_failure): + # Convert trace to list format for assert_priority + + trace_list = trace.get("spans", []) if isinstance(trace, dict) else trace + assert_priority(span, trace_list) meta = span.get("meta", {}) # mandatory @@ -1562,8 +1743,11 @@ def setup_login_wrong_user_failure_basic(self): def test_login_wrong_user_failure_basic(self): assert self.r_wrong_user_failure.status_code == 401 - for _, trace, span in interfaces.library.get_spans(request=self.r_wrong_user_failure): - assert_priority(span, trace) + for _, trace, span, _ in interfaces.library.get_spans(request=self.r_wrong_user_failure): + # Convert trace to list format for assert_priority + + trace_list = trace.get("spans", []) if isinstance(trace, dict) else trace + assert_priority(span, trace_list) meta = span.get("meta", {}) # mandatory @@ -1581,8 +1765,11 @@ def setup_login_wrong_password_failure_local(self): def test_login_wrong_password_failure_local(self): assert self.r_wrong_user_failure.status_code == 401 - for _, trace, span in interfaces.library.get_spans(request=self.r_wrong_user_failure): - assert_priority(span, trace) + for _, trace, span, _ in interfaces.library.get_spans(request=self.r_wrong_user_failure): + # Convert trace to list format for assert_priority + + trace_list = trace.get("spans", []) if isinstance(trace, dict) else trace + assert_priority(span, trace_list) meta = span.get("meta", {}) # mandatory @@ -1606,8 +1793,11 @@ def setup_login_wrong_password_failure_basic(self): def test_login_wrong_password_failure_basic(self): assert self.r_wrong_user_failure.status_code == 401 - for _, trace, span in interfaces.library.get_spans(request=self.r_wrong_user_failure): - assert_priority(span, trace) + for _, trace, span, _ in interfaces.library.get_spans(request=self.r_wrong_user_failure): + # Convert trace to list format for assert_priority + + trace_list = trace.get("spans", []) if isinstance(trace, dict) else trace + assert_priority(span, trace_list) meta = span.get("meta", {}) # mandatory @@ -1636,8 +1826,11 @@ def setup_login_sdk_success_local(self): def test_login_sdk_success_local(self): for request in self.r_sdk_success: assert request.status_code == 200 - for _, trace, span in interfaces.library.get_spans(request=request): - assert_priority(span, trace) + for _, trace, span, _ in interfaces.library.get_spans(request=request): + # Convert trace to list format for assert_priority + + trace_list = trace.get("spans", []) if isinstance(trace, dict) else trace + assert_priority(span, trace_list) meta = span.get("meta", {}) # mandatory @@ -1664,8 +1857,11 @@ def setup_login_sdk_success_basic(self): def test_login_sdk_success_basic(self): for request in self.r_sdk_success: assert request.status_code == 200 - for _, trace, span in interfaces.library.get_spans(request=request): - assert_priority(span, trace) + for _, trace, span, _ in interfaces.library.get_spans(request=request): + # Convert trace to list format for assert_priority + + trace_list = trace.get("spans", []) if isinstance(trace, dict) else trace + assert_priority(span, trace_list) meta = span.get("meta", {}) # mandatory @@ -1692,8 +1888,11 @@ def setup_login_sdk_failure_local(self): def test_login_sdk_failure_local(self): for request in self.r_sdk_failure: assert request.status_code == 401 - for _, trace, span in interfaces.library.get_spans(request=request): - assert_priority(span, trace) + for _, trace, span, _ in interfaces.library.get_spans(request=request): + # Convert trace to list format for assert_priority + + trace_list = trace.get("spans", []) if isinstance(trace, dict) else trace + assert_priority(span, trace_list) meta = span.get("meta", {}) # mandatory @@ -1716,8 +1915,11 @@ def setup_login_sdk_failure_basic(self): def test_login_sdk_failure_basic(self): for request in self.r_sdk_failure: assert request.status_code == 401 - for _, trace, span in interfaces.library.get_spans(request=request): - assert_priority(span, trace) + for _, trace, span, _ in interfaces.library.get_spans(request=request): + # Convert trace to list format for assert_priority + + trace_list = trace.get("spans", []) if isinstance(trace, dict) else trace + assert_priority(span, trace_list) meta = span.get("meta", {}) # mandatory @@ -1733,8 +1935,11 @@ def setup_signup_local(self): def test_signup_local(self): assert self.r_success.status_code == 200 - for _, trace, span in interfaces.library.get_spans(request=self.r_success): - assert_priority(span, trace) + for _, trace, span, _ in interfaces.library.get_spans(request=self.r_success): + # Convert trace to list format for assert_priority + + trace_list = trace.get("spans", []) if isinstance(trace, dict) else trace + assert_priority(span, trace_list) meta = span.get("meta", {}) # mandatory @@ -1769,7 +1974,7 @@ def _assert_response(self, test: dict, validation: Callable): assert config_state.state == rc.ApplyState.ACKNOWLEDGED assert request.status_code == 200 - spans = [s for _, _, s in interfaces.library.get_spans(request=request)] + spans = [s for _, _, s, _ in interfaces.library.get_spans(request=request)] assert spans, "No spans to validate" for span in spans: meta = span.get("meta", {}) diff --git a/tests/appsec/test_automated_user_and_session_tracking.py b/tests/appsec/test_automated_user_and_session_tracking.py index 3a1fe819d22..9134bca7a9c 100644 --- a/tests/appsec/test_automated_user_and_session_tracking.py +++ b/tests/appsec/test_automated_user_and_session_tracking.py @@ -57,7 +57,7 @@ def test_user_tracking_auto(self): assert self.r_login.status_code == 200 assert self.r_home.status_code == 200 - for _, _, span in interfaces.library.get_spans(request=self.r_home): + for _, _, span, _ in interfaces.library.get_spans(request=self.r_home): meta = span.get("meta", {}) if context.library in libs_without_user_id: assert meta["usr.id"] == USER @@ -76,7 +76,7 @@ def test_user_tracking_sdk_overwrite(self): assert self.r_login.status_code == 200 assert self.r_users.status_code == 200 - for _, _, span in interfaces.library.get_spans(request=self.r_users): + for _, _, span, _ in interfaces.library.get_spans(request=self.r_users): meta = span.get("meta", {}) assert meta["usr.id"] == "sdkUser" if context.library in libs_without_user_id: diff --git a/tests/appsec/test_blocking_addresses.py b/tests/appsec/test_blocking_addresses.py index e90c635f82b..9fc733e9ae8 100644 --- a/tests/appsec/test_blocking_addresses.py +++ b/tests/appsec/test_blocking_addresses.py @@ -814,8 +814,8 @@ def setup_request_block_attack(self): def test_request_block_attack(self): assert self.r_attack.status_code == 403 - span = interfaces.library.get_root_span(request=self.r_attack) - meta = span.get("meta", {}) + span, span_format = interfaces.library.get_root_span(request=self.r_attack) + meta = interfaces.library.get_span_meta(span, span_format) meta_struct = span.get("meta_struct", {}) assert meta["appsec.event"] == "true" assert ("_dd.appsec.json" in meta) ^ ("appsec" in meta_struct) @@ -851,8 +851,8 @@ def setup_request_block_attack_directive(self): def test_request_block_attack_directive(self): assert self.r_attack.status_code == 403 - span = interfaces.library.get_root_span(request=self.r_attack) - meta = span.get("meta", {}) + span, span_format = interfaces.library.get_root_span(request=self.r_attack) + meta = interfaces.library.get_span_meta(span, span_format) meta_struct = span.get("meta_struct", {}) assert meta["appsec.event"] == "true" assert ("_dd.appsec.json" in meta) ^ ("appsec" in meta_struct) diff --git a/tests/appsec/test_event_tracking.py b/tests/appsec/test_event_tracking.py index ea63ffc82a0..3cd1c9bd844 100644 --- a/tests/appsec/test_event_tracking.py +++ b/tests/appsec/test_event_tracking.py @@ -2,6 +2,7 @@ # This product includes software developed at Datadog (https://www.datadoghq.com/). # Copyright 2021 Datadog, Inc. from utils import weblog, interfaces, features +from utils.dd_constants import TraceLibraryPayloadFormat from tests.appsec.utils import find_series HEADERS = { @@ -48,7 +49,7 @@ def setup_user_login_success_event(self): def test_user_login_success_event(self): # Call the user login success SDK and validate tags - def validate_user_login_success_tags(span: dict): + def validate_user_login_success_tags(span: dict, span_format: TraceLibraryPayloadFormat | None = None): expected_tags = { "http.client_ip": "1.2.3.4", "usr.id": "system_tests_user", @@ -57,9 +58,14 @@ def validate_user_login_success_tags(span: dict): "appsec.events.users.login.success.metadata1": "value1", } + # Use helper method to get meta for both v04 and v1 formats + if span_format is None: + span_format = interfaces.library._detect_span_format(span) # noqa: SLF001 + meta = interfaces.library.get_span_meta(span, span_format) + for tag, expected_value in expected_tags.items(): - assert tag in span["meta"], f"Can't find {tag} in span's meta" - value = span["meta"][tag] + assert tag in meta, f"Can't find {tag} in span's meta" + value = meta[tag] if value != expected_value: raise Exception(f"{tag} value is '{value}', should be '{expected_value}'") @@ -73,12 +79,19 @@ def setup_user_login_success_header_collection(self): def test_user_login_success_header_collection(self): # Validate that all relevant headers are included on user login success - def validate_user_login_success_header_collection(span: dict) -> bool: + def validate_user_login_success_header_collection( + span: dict, span_format: TraceLibraryPayloadFormat | None = None + ) -> bool: if span.get("parent_id") not in (0, None): return False + # Use helper method to get meta for both v04 and v1 formats + if span_format is None: + span_format = interfaces.library._detect_span_format(span) # noqa: SLF001 + meta = interfaces.library.get_span_meta(span, span_format) + for header in HEADERS: - assert f"http.request.headers.{header.lower()}" in span["meta"], f"Can't find {header} in span's meta" + assert f"http.request.headers.{header.lower()}" in meta, f"Can't find {header} in span's meta" return True @@ -117,7 +130,7 @@ def setup_user_login_failure_event(self): def test_user_login_failure_event(self): # Call the user login failure SDK and validate tags - def validate_user_login_failure_tags(span: dict): + def validate_user_login_failure_tags(span: dict, span_format: TraceLibraryPayloadFormat | None = None): expected_tags = { "http.client_ip": "1.2.3.4", "appsec.events.users.login.failure.usr.id": "system_tests_user", @@ -127,9 +140,14 @@ def validate_user_login_failure_tags(span: dict): "appsec.events.users.login.failure.metadata1": "value1", } + # Use helper method to get meta for both v04 and v1 formats + if span_format is None: + span_format = interfaces.library._detect_span_format(span) # noqa: SLF001 + meta = interfaces.library.get_span_meta(span, span_format) + for tag, expected_value in expected_tags.items(): - assert tag in span["meta"], f"Can't find {tag} in span's meta" - value = span["meta"][tag] + assert tag in meta, f"Can't find {tag} in span's meta" + value = meta[tag] if value != expected_value: raise Exception(f"{tag} value is '{value}', should be '{expected_value}'") @@ -143,12 +161,19 @@ def setup_user_login_failure_header_collection(self): def test_user_login_failure_header_collection(self): # Validate that all relevant headers are included on user login failure - def validate_user_login_failure_header_collection(span: dict): + def validate_user_login_failure_header_collection( + span: dict, span_format: TraceLibraryPayloadFormat | None = None + ): if span.get("parent_id") not in (0, None): return None + # Use helper method to get meta for both v04 and v1 formats + if span_format is None: + span_format = interfaces.library._detect_span_format(span) # noqa: SLF001 + meta = interfaces.library.get_span_meta(span, span_format) + for header in HEADERS: - assert f"http.request.headers.{header.lower()}" in span["meta"], f"Can't find {header} in span's meta" + assert f"http.request.headers.{header.lower()}" in meta, f"Can't find {header} in span's meta" return True interfaces.library.validate_one_span(self.r, validator=validate_user_login_failure_header_collection) @@ -186,7 +211,7 @@ def setup_custom_event_event(self): def test_custom_event_event(self): # Call the user login failure SDK and validate tags - def validate_custom_event_tags(span: dict): + def validate_custom_event_tags(span: dict, span_format: TraceLibraryPayloadFormat | None = None): expected_tags = { "http.client_ip": "1.2.3.4", "appsec.events.system_tests_event.track": "true", @@ -194,9 +219,14 @@ def validate_custom_event_tags(span: dict): "appsec.events.system_tests_event.metadata1": "value1", } + # Use helper method to get meta for both v04 and v1 formats + if span_format is None: + span_format = interfaces.library._detect_span_format(span) # noqa: SLF001 + meta = interfaces.library.get_span_meta(span, span_format) + for tag, expected_value in expected_tags.items(): - assert tag in span["meta"], f"Can't find {tag} in span's meta" - value = span["meta"][tag] + assert tag in meta, f"Can't find {tag} in span's meta" + value = meta[tag] if value != expected_value: raise Exception(f"{tag} value is '{value}', should be '{expected_value}'") diff --git a/tests/appsec/test_extended_data_collection.py b/tests/appsec/test_extended_data_collection.py index 0f6af4c86ed..47c49618ef2 100644 --- a/tests/appsec/test_extended_data_collection.py +++ b/tests/appsec/test_extended_data_collection.py @@ -98,8 +98,8 @@ def test_extended_data_collection_with_rc(self): assert self.response.status_code == 200 # Verify extended data collection is working by checking span metadata - span = interfaces.library.get_root_span(request=self.response) - meta = span.get("meta", {}) + span, span_format = interfaces.library.get_root_span(request=self.response) + meta = interfaces.library.get_span_meta(span, span_format) # Check that extended headers are collected when the rule matches assert meta.get("http.request.headers.x-my-header-1") == "value1" @@ -109,7 +109,7 @@ def test_extended_data_collection_with_rc(self): assert meta.get("http.request.headers.content-type") == "text/html" # Check that no headers were discarded (within the 50 limit) - metrics = span.get("metrics", {}) + metrics = interfaces.library.get_span_metrics(span, span_format) assert metrics.get("_dd.appsec.request.header_collection.discarded") is None def setup_no_extended_data_collection_without_event(self): @@ -143,8 +143,8 @@ def test_no_extended_data_collection_without_event(self): assert self.response.status_code == 200 # Verify extended data collection is ignored by checking span metadata - span = interfaces.library.get_root_span(request=self.response) - meta = span.get("meta", {}) + span, span_format = interfaces.library.get_root_span(request=self.response) + meta = interfaces.library.get_span_meta(span, span_format) # Check that extended headers are NOT collected when there is no event assert meta.get("http.request.headers.x-my-header-1") is None @@ -156,7 +156,7 @@ def test_no_extended_data_collection_without_event(self): assert meta.get("http.request.headers.content-type") == "text/html" # Check that no headers were discarded (within the 50 limit) - metrics = span.get("metrics", {}) + metrics = interfaces.library.get_span_metrics(span, span_format) assert metrics.get("_dd.appsec.request.header_collection.discarded") is None def setup_extended_data_collection_with_rc_header_limit(self): @@ -188,8 +188,8 @@ def test_extended_data_collection_with_rc_header_limit(self): assert self.response.status_code == 200 # Verify extended data collection header limit is working by checking span metadata - span = interfaces.library.get_root_span(request=self.response) - meta = span.get("meta", {}) + span, span_format = interfaces.library.get_root_span(request=self.response) + meta = interfaces.library.get_span_meta(span, span_format) # Ensure no more than 50 meta entries start with "http.request.headers." header_keys = [k for k in meta if k.startswith("http.request.headers.")] @@ -198,7 +198,7 @@ def test_extended_data_collection_with_rc_header_limit(self): # Ensure allowed headers are collected assert meta.get("http.request.headers.content-type") == "text/html" - metrics = span.get("metrics", {}) + metrics = interfaces.library.get_span_metrics(span, span_format) # Confirm _dd.appsec.request.header_collection.discarded exists and is > 0 discarded = metrics.get("_dd.appsec.request.header_collection.discarded") @@ -241,8 +241,8 @@ def test_extended_data_collection_with_rc_and_authentication_headers(self): assert self.response.status_code == 200 # Verify extended data collection is working by checking span metadata - span = interfaces.library.get_root_span(request=self.response) - meta = span.get("meta", {}) + span, span_format = interfaces.library.get_root_span(request=self.response) + meta = interfaces.library.get_span_meta(span, span_format) # Check that extended headers are redacted assert meta.get("http.request.headers.authorization") == "" @@ -256,7 +256,7 @@ def test_extended_data_collection_with_rc_and_authentication_headers(self): assert meta.get("http.request.headers.content-type") == "text/html" # Check that no headers were discarded (within the 50 limit) - metrics = span.get("metrics", {}) + metrics = interfaces.library.get_span_metrics(span, span_format) assert metrics.get("_dd.appsec.request.header_collection.discarded") is None @@ -290,8 +290,8 @@ def test_extended_response_headers_collection_with_rc(self): assert self.response.status_code == 200 # Verify extended response headers data collection is working by checking span metadata - span = interfaces.library.get_root_span(request=self.response) - meta = span.get("meta", {}) + span, span_format = interfaces.library.get_root_span(request=self.response) + meta = interfaces.library.get_span_meta(span, span_format) # Check that extended response headers are collected when the rule matches assert meta.get("http.response.headers.x-test-header-1") == "value1" @@ -301,7 +301,7 @@ def test_extended_response_headers_collection_with_rc(self): assert meta.get("http.response.headers.content-language") == "en-US" # Check that no response headers were discarded (within the 50 limit) - metrics = span.get("metrics", {}) + metrics = interfaces.library.get_span_metrics(span, span_format) assert metrics.get("_dd.appsec.response.header_collection.discarded") is None def setup_no_extended_response_headers_collection_without_event(self): @@ -328,8 +328,8 @@ def test_no_extended_response_headers_collection_without_event(self): assert self.response.status_code == 200 # Verify extended response headers data collection is not working when rule is not triggered - span = interfaces.library.get_root_span(request=self.response) - meta = span.get("meta", {}) + span, span_format = interfaces.library.get_root_span(request=self.response) + meta = interfaces.library.get_span_meta(span, span_format) # Check that extended response headers are NOT collected when the rule is not triggered assert meta.get("http.response.headers.x-test-header-1") is None @@ -341,7 +341,7 @@ def test_no_extended_response_headers_collection_without_event(self): assert meta.get("http.response.headers.content-language") is None # Check that no response headers were discarded (within the 50 limit) - metrics = span.get("metrics", {}) + metrics = interfaces.library.get_span_metrics(span, span_format) assert metrics.get("_dd.appsec.response.header_collection.discarded") is None def setup_extended_response_headers_collection_with_rc_header_limit(self): @@ -369,8 +369,8 @@ def test_extended_response_headers_collection_with_rc_header_limit(self): assert self.response.status_code == 200 # Verify extended response headers data collection header limit is working by checking span metadata - span = interfaces.library.get_root_span(request=self.response) - meta = span.get("meta", {}) + span, span_format = interfaces.library.get_root_span(request=self.response) + meta = interfaces.library.get_span_meta(span, span_format) # Ensure exactly 50 meta entries start with "http.response.headers." header_keys = [k for k in meta if k.startswith("http.response.headers.")] @@ -379,7 +379,7 @@ def test_extended_response_headers_collection_with_rc_header_limit(self): # Ensure allowed response headers are collected assert meta.get("http.response.headers.content-language") == "en-US" - metrics = span.get("metrics", {}) + metrics = interfaces.library.get_span_metrics(span, span_format) # Confirm _dd.appsec.response.header_collection.discarded exists and is > 0 discarded = metrics.get("_dd.appsec.response.header_collection.discarded") assert discarded is not None @@ -410,8 +410,8 @@ def test_extended_data_collection_with_rc_and_authentication_headers(self): assert self.response.status_code == 200 # Verify extended response headers data collection header limit is working by checking span metadata - span = interfaces.library.get_root_span(request=self.response) - meta = span.get("meta", {}) + span, span_format = interfaces.library.get_root_span(request=self.response) + meta = interfaces.library.get_span_meta(span, span_format) # Check that extended headers are redacted assert meta.get("http.response.headers.authorization") == "" @@ -453,7 +453,7 @@ def test_extended_request_body_collection(self): assert self.response.status_code == 200 # Verify extended request body data collection is working by checking span meta_struct - span = interfaces.library.get_root_span(request=self.response) + span, _span_format = interfaces.library.get_root_span(request=self.response) meta_struct = span.get("meta_struct", {}) # Check that request body is collected in meta_struct when the rule matches @@ -486,7 +486,7 @@ def test_no_extended_request_body_collection_without_event(self): assert self.response.status_code == 200 # Verify extended request body data collection is not working when rule is not triggered - span = interfaces.library.get_root_span(request=self.response) + span, _span_format = interfaces.library.get_root_span(request=self.response) meta_struct = span.get("meta_struct", {}) # Check that request body is NOT collected when the rule is not triggered @@ -517,7 +517,7 @@ def test_extended_request_body_collection_truncated(self): assert self.response.status_code == 200 # Verify extended request body data collection with truncation is working by checking span meta_struct - span = interfaces.library.get_root_span(request=self.response) + span, _span_format = interfaces.library.get_root_span(request=self.response) meta_struct = span.get("meta_struct", {}) # Check that request body is collected in meta_struct when the rule matches diff --git a/tests/appsec/test_extended_header_collection.py b/tests/appsec/test_extended_header_collection.py index b2fc92a13f8..08ba7d0fd09 100644 --- a/tests/appsec/test_extended_header_collection.py +++ b/tests/appsec/test_extended_header_collection.py @@ -13,8 +13,8 @@ class Test_ExtendedHeaderCollection: @staticmethod def assert_feature_is_enabled(response: HttpResponse) -> None: assert response.status_code == 200 - span = interfaces.library.get_root_span(request=response) - meta = span.get("meta", {}) + span, span_format = interfaces.library.get_root_span(request=response) + meta = interfaces.library.get_span_meta(span, span_format) assert meta.get("http.request.headers.x-my-header-1") == "value1" def setup_feature_is_enabled(self): @@ -41,14 +41,14 @@ def setup_if_appsec_event_collect_all_request_headers(self): def test_if_appsec_event_collect_all_request_headers(self): assert self.r.status_code == 200 - span = interfaces.library.get_root_span(request=self.r) - meta = span.get("meta", {}) + span, span_format = interfaces.library.get_root_span(request=self.r) + meta = interfaces.library.get_span_meta(span, span_format) assert meta.get("http.request.headers.x-my-header-1") == "value1" assert meta.get("http.request.headers.x-my-header-2") == "value2" assert meta.get("http.request.headers.x-my-header-3") == "value3" assert meta.get("http.request.headers.x-my-header-4") == "value4" assert meta.get("http.request.headers.content-type") == "text/html" - metrics = span.get("metrics", {}) + metrics = interfaces.library.get_span_metrics(span, span_format) assert metrics.get("_dd.appsec.request.header_collection.discarded") is None def setup_if_no_appsec_event_collect_allowed_request_headers(self): @@ -67,14 +67,14 @@ def setup_if_no_appsec_event_collect_allowed_request_headers(self): def test_if_no_appsec_event_collect_allowed_request_headers(self): self.assert_feature_is_enabled(self.check_r) assert self.r.status_code == 200 - span = interfaces.library.get_root_span(request=self.r) - meta = span.get("meta", {}) + span, span_format = interfaces.library.get_root_span(request=self.r) + meta = interfaces.library.get_span_meta(span, span_format) assert meta.get("http.request.headers.x-my-header-1") is None assert meta.get("http.request.headers.x-my-header-2") is None assert meta.get("http.request.headers.x-my-header-3") is None assert meta.get("http.request.headers.x-my-header-4") is None assert meta.get("http.request.headers.content-type") == "text/html" - metrics = span.get("metrics", {}) + metrics = interfaces.library.get_span_metrics(span, span_format) assert metrics.get("_dd.appsec.request.header_collection.discarded") is None def setup_not_exceed_default_50_maximum_request_header_collection(self): @@ -91,8 +91,8 @@ def setup_not_exceed_default_50_maximum_request_header_collection(self): def test_not_exceed_default_50_maximum_request_header_collection(self): self.assert_feature_is_enabled(self.check_r) assert self.r.status_code == 200 - span = interfaces.library.get_root_span(request=self.r) - meta = span.get("meta", {}) + span, span_format = interfaces.library.get_root_span(request=self.r) + meta = interfaces.library.get_span_meta(span, span_format) # Ensure no more than 50 meta entries start with "http.request.headers." header_keys = [k for k in meta if k.startswith("http.request.headers.")] @@ -101,7 +101,7 @@ def test_not_exceed_default_50_maximum_request_header_collection(self): # Ensure allowed headers are collected assert meta.get("http.request.headers.content-type") == "text/html" - metrics = span.get("metrics", {}) + metrics = interfaces.library.get_span_metrics(span, span_format) # Confirm _dd.appsec.request.header_collection.discarded exists and is > 0 discarded = metrics.get("_dd.appsec.request.header_collection.discarded") assert discarded is not None @@ -121,15 +121,15 @@ def setup_if_appsec_event_collect_all_response_headers(self): ) def test_if_appsec_event_collect_all_response_headers(self): assert self.r.status_code == 200 - span = interfaces.library.get_root_span(request=self.r) - meta = span.get("meta", {}) + span, span_format = interfaces.library.get_root_span(request=self.r) + meta = interfaces.library.get_span_meta(span, span_format) assert meta.get("http.response.headers.x-test-header-1") == "value1" assert meta.get("http.response.headers.x-test-header-2") == "value2" assert meta.get("http.response.headers.x-test-header-3") == "value3" assert meta.get("http.response.headers.x-test-header-4") == "value4" assert meta.get("http.response.headers.x-test-header-5") == "value5" assert meta.get("http.response.headers.content-language") == "en-US" - metrics = span.get("metrics", {}) + metrics = interfaces.library.get_span_metrics(span, span_format) assert metrics.get("_dd.appsec.response.header_collection.discarded") is None def setup_if_no_appsec_event_collect_allowed_response_headers(self): @@ -139,8 +139,8 @@ def setup_if_no_appsec_event_collect_allowed_response_headers(self): def test_if_no_appsec_event_collect_allowed_response_headers(self): self.assert_feature_is_enabled(self.check_r) assert self.r.status_code == 200 - span = interfaces.library.get_root_span(request=self.r) - meta = span.get("meta", {}) + span, span_format = interfaces.library.get_root_span(request=self.r) + meta = interfaces.library.get_span_meta(span, span_format) assert meta.get("http.response.headers.x-test-header-1") is None assert meta.get("http.response.headers.x-test-header-2") is None assert meta.get("http.response.headers.x-test-header-3") is None @@ -149,7 +149,7 @@ def test_if_no_appsec_event_collect_allowed_response_headers(self): assert ( meta.get("http.response.headers.content-language") is None ) # at least in java we are not collecting response headers by default - metrics = span.get("metrics", {}) + metrics = interfaces.library.get_span_metrics(span, span_format) assert metrics.get("_dd.appsec.response.header_collection.discarded") is None def setup_not_exceed_default_50_maximum_response_header_collection(self): @@ -168,8 +168,8 @@ def setup_not_exceed_default_50_maximum_response_header_collection(self): def test_not_exceed_default_50_maximum_response_header_collection(self): self.assert_feature_is_enabled(self.check_r) assert self.r.status_code == 200 - span = interfaces.library.get_root_span(request=self.r) - meta = span.get("meta", {}) + span, span_format = interfaces.library.get_root_span(request=self.r) + meta = interfaces.library.get_span_meta(span, span_format) # Ensure no more than 50 meta entries start with "http.request.headers." header_keys = [k for k in meta if k.startswith("http.response.headers.")] @@ -178,7 +178,7 @@ def test_not_exceed_default_50_maximum_response_header_collection(self): # Ensure allowed headers are collected assert meta.get("http.response.headers.content-language") == "en-US" - metrics = span.get("metrics", {}) + metrics = interfaces.library.get_span_metrics(span, span_format) # Confirm _dd.appsec.response.header_collection.discarded exists and is > 0 discarded = metrics.get("_dd.appsec.response.header_collection.discarded") assert discarded is not None diff --git a/tests/appsec/test_extended_request_body_collection.py b/tests/appsec/test_extended_request_body_collection.py index bab112777a7..44dbdc8107a 100644 --- a/tests/appsec/test_extended_request_body_collection.py +++ b/tests/appsec/test_extended_request_body_collection.py @@ -36,7 +36,7 @@ def assert_feature_is_enabled(response: HttpResponse) -> None: }, }, ) - span = interfaces.library.get_root_span(request=response) + span, _span_format = interfaces.library.get_root_span(request=response) meta_struct = span.get("meta_struct", {}) body = meta_struct.get("http.request.body") assert body is not None @@ -70,12 +70,12 @@ def test_request_body_truncated(self): }, }, ) - span = interfaces.library.get_root_span(request=self.r) + span, span_format = interfaces.library.get_root_span(request=self.r) meta_struct = span.get("meta_struct", {}) body = meta_struct.get("http.request.body") assert body is not None assert_body_property(body, "command", "/usr/bin/touch /tmp/passwd" + "A" * 4070) - meta = span.get("meta", {}) + meta = interfaces.library.get_span_meta(span, span_format) assert meta.get("_dd.appsec.rasp.request_body_size.exceeded") == "true" def setup_if_no_rasp_event_no_collect_request_body(self): @@ -90,7 +90,7 @@ def setup_if_no_rasp_event_no_collect_request_body(self): def test_if_no_rasp_event_no_collect_request_body(self): self.assert_feature_is_enabled(self.check_r) assert self.r.status_code == 200 - span = interfaces.library.get_root_span(request=self.r) + span, _span_format = interfaces.library.get_root_span(request=self.r) meta_struct = span.get("meta_struct", {}) assert meta_struct.get("http.request.body") is None diff --git a/tests/appsec/test_fingerprinting.py b/tests/appsec/test_fingerprinting.py index 3a56cff8529..ff68577eebd 100644 --- a/tests/appsec/test_fingerprinting.py +++ b/tests/appsec/test_fingerprinting.py @@ -14,7 +14,11 @@ def get_span_meta(r: HttpResponse): - res = [span.get("meta", {}) for _, _, span in interfaces.library.get_spans(request=r)] + # Use helper method to get meta for both v04 and v1 formats + res = [] + for _, _, span, span_format in interfaces.library.get_spans(request=r): + meta = interfaces.library.get_span_meta(span, span_format) + res.append(meta) assert res, f"no spans found in {r}" return res diff --git a/tests/appsec/test_identify.py b/tests/appsec/test_identify.py index 2e4091b0cf3..ac01b9c553a 100644 --- a/tests/appsec/test_identify.py +++ b/tests/appsec/test_identify.py @@ -3,6 +3,7 @@ # Copyright 2021 Datadog, Inc. from utils import weblog, interfaces, features +from utils.dd_constants import TraceLibraryPayloadFormat @features.user_monitoring @@ -15,13 +16,18 @@ def setup_identify_tags_with_attack(self): def test_identify_tags_with_attack(self): # Send a random attack on the identify endpoint - should not affect the usr.id tag - def validate_identify_tags(span: dict): + def validate_identify_tags(span: dict, span_format: TraceLibraryPayloadFormat | None = None): + # Use helper method to get meta for both v04 and v1 formats + if span_format is None: + span_format = interfaces.library._detect_span_format(span) # noqa: SLF001 + meta = interfaces.library.get_span_meta(span, span_format) + for tag in ["id", "name", "email", "session_id", "role", "scope"]: key = f"usr.{tag}" - assert key in span["meta"], f"Can't find {key} in span's meta" + assert key in meta, f"Can't find {key} in span's meta" expected_value = f"usr.{tag}" # key and value are the same on weblog spec - value = span["meta"][key] + value = meta[key] if value != expected_value: raise Exception(f"{key} value is '{value}', should be '{expected_value}'") diff --git a/tests/appsec/test_metastruct.py b/tests/appsec/test_metastruct.py index 858f6cfe584..2adc0f127a3 100644 --- a/tests/appsec/test_metastruct.py +++ b/tests/appsec/test_metastruct.py @@ -14,9 +14,9 @@ def setup_appsec_event_use_metastruct(self): self.r = weblog.get("/", headers={"User-Agent": "Arachni/v1"}) def test_appsec_event_use_metastruct(self): - span = interfaces.library.get_root_span(request=self.r) - meta = span.get("meta", {}) - meta_struct = span.get("meta_struct", {}) + span, span_format = interfaces.library.get_root_span(request=self.r) + meta = interfaces.library.get_span_meta(span, span_format) + meta_struct = interfaces.library.get_span_meta_struct(span, span_format) assert meta["appsec.event"] == "true" assert "_dd.appsec.json" not in meta assert "appsec" in meta_struct @@ -37,10 +37,10 @@ def setup_iast_event_use_metastruct(self): self.r = weblog.get("/iast/source/cookievalue/test", cookies={"table": "user"}) def test_iast_event_use_metastruct(self): - span = interfaces.library.get_root_span(request=self.r) - meta = span.get("meta", {}) - metrics = span.get("metrics", {}) - meta_struct = span.get("meta_struct", {}) + span, span_format = interfaces.library.get_root_span(request=self.r) + meta = interfaces.library.get_span_meta(span, span_format) + metrics = interfaces.library.get_span_metrics(span, span_format) + meta_struct = interfaces.library.get_span_meta_struct(span, span_format) assert meta.get("_dd.iast.enabled") == "1" or metrics.get("_dd.iast.enabled") == 1.0 assert "_dd.iast.json" not in meta assert "iast" in meta_struct @@ -61,9 +61,9 @@ def setup_appsec_event_fallback_json(self): self.r = weblog.get("/", headers={"User-Agent": "Arachni/v1"}) def test_appsec_event_fallback_json(self): - span = interfaces.library.get_root_span(request=self.r) - meta = span.get("meta", {}) - meta_struct = span.get("meta_struct", {}) + span, span_format = interfaces.library.get_root_span(request=self.r) + meta = interfaces.library.get_span_meta(span, span_format) + meta_struct = interfaces.library.get_span_meta_struct(span, span_format) assert meta["appsec.event"] == "true" assert "_dd.appsec.json" in meta assert "appsec" not in meta_struct @@ -85,9 +85,9 @@ def setup_iast_event_fallback_json(self): self.r = weblog.get("/set_cookie", params={"name": "metastruct-no", "value": "no"}) def test_iast_event_fallback_json(self): - span = interfaces.library.get_root_span(request=self.r) - meta = span.get("meta", {}) - meta_struct = span.get("meta_struct", {}) + span, span_format = interfaces.library.get_root_span(request=self.r) + meta = interfaces.library.get_span_meta(span, span_format) + meta_struct = interfaces.library.get_span_meta_struct(span, span_format) assert meta["_dd.iast.enabled"] == "1" assert "_dd.iast.json" in meta assert "iast" not in meta_struct diff --git a/tests/appsec/test_reports.py b/tests/appsec/test_reports.py index 8dbbbf146b4..35d18c11504 100644 --- a/tests/appsec/test_reports.py +++ b/tests/appsec/test_reports.py @@ -3,6 +3,7 @@ # Copyright 2021 Datadog, Inc. from utils import weblog, interfaces, scenarios, rfc, features from utils._weblog import HttpResponse +from utils.dd_constants import TraceLibraryPayloadFormat @features.security_events_metadata @@ -22,8 +23,12 @@ def check_http_code_legacy(event: dict): return True - def check_http_code(span: dict, appsec_data: dict): # noqa: ARG001 - status_code = span["meta"]["http.status_code"] + def check_http_code(span: dict, appsec_data: dict, span_format: TraceLibraryPayloadFormat | None = None): # noqa: ARG001 + # Use helper method to get meta for both v04 and v1 formats + if span_format is None: + span_format = interfaces.library._detect_span_format(span) # noqa: SLF001 + meta = interfaces.library.get_span_meta(span, span_format) + status_code = meta["http.status_code"] assert status_code == "404", f"404 should have been reported, not {status_code}" return True @@ -51,9 +56,14 @@ def _check_service_legacy(event: dict): return True - def _check_service(span: dict, appsec_data: dict): # noqa: ARG001 + def _check_service(span: dict, appsec_data: dict, span_format: TraceLibraryPayloadFormat | None = None): # noqa: ARG001 + # Use helper method to get meta for both v04 and v1 formats + if span_format is None: + span_format = interfaces.library._detect_span_format(span) # noqa: SLF001 + meta = interfaces.library.get_span_meta(span, span_format) + # Service name is at the top level in both formats name = span.get("service") - environment = span.get("meta", {}).get("env") + environment = meta.get("env") assert name == "weblog", f"weblog should have been reported, not {name}" assert environment == "system-tests", f"system-tests should have been reported, not {environment}" diff --git a/tests/appsec/test_shell_execution.py b/tests/appsec/test_shell_execution.py index a271c01efda..b23f00c8abf 100644 --- a/tests/appsec/test_shell_execution.py +++ b/tests/appsec/test_shell_execution.py @@ -15,7 +15,7 @@ class Test_ShellExecution: def fetch_command_execution_span(r: HttpResponse) -> dict: assert r.status_code == 200 - traces = [t for _, t in interfaces.library.get_traces(request=r)] + traces = [t for _, t, _ in interfaces.library.get_traces(request=r)] assert traces, "No traces found" assert len(traces) == 1 spans = traces[0] diff --git a/tests/appsec/test_traces.py b/tests/appsec/test_traces.py index 91f6f6baa84..754541eab8a 100644 --- a/tests/appsec/test_traces.py +++ b/tests/appsec/test_traces.py @@ -14,7 +14,7 @@ features, ) from utils.tools import nested_lookup -from utils.dd_constants import SamplingPriority +from utils.dd_constants import SamplingPriority, TraceLibraryPayloadFormat RUNTIME_FAMILIES = ["nodejs", "ruby", "jvm", "dotnet", "go", "php", "python", "cpp"] @@ -40,23 +40,28 @@ def test_appsec_event_span_tags(self): _sampling_priority_v1 tags """ - def validate_appsec_event_span_tags(span: dict): + def validate_appsec_event_span_tags(span: dict, span_format: TraceLibraryPayloadFormat | None = None): if span.get("parent_id") not in (0, None): # do nothing if not root span return None - if "appsec.event" not in span["meta"]: + # Use helper methods to get meta and metrics for both v04 and v1 formats + if span_format is None: + span_format = interfaces.library._detect_span_format(span) # noqa: SLF001 + meta = interfaces.library.get_span_meta(span, span_format) + metrics = interfaces.library.get_span_metrics(span, span_format) + + if "appsec.event" not in meta: raise Exception("Can't find appsec.event in span's meta") - if span["meta"]["appsec.event"] != "true": - raise Exception(f'appsec.event in span\'s meta should be "true", not {span["meta"]["appsec.event"]}') + if meta["appsec.event"] != "true": + raise Exception(f'appsec.event in span\'s meta should be "true", not {meta["appsec.event"]}') - if "_sampling_priority_v1" not in span["metrics"]: + if "_sampling_priority_v1" not in metrics: raise Exception("Metric _sampling_priority_v1 should be set on traces that are manually kept") - if span["metrics"]["_sampling_priority_v1"] != SamplingPriority.USER_KEEP: - raise Exception( - f"Trace id {span['trace_id']} , sampling priority should be {SamplingPriority.USER_KEEP}" - ) + if metrics["_sampling_priority_v1"] != SamplingPriority.USER_KEEP: + trace_id = interfaces.library.get_span_trace_id(span, None, span_format) + raise Exception(f"Trace id {trace_id} , sampling priority should be {SamplingPriority.USER_KEEP}") return True @@ -77,23 +82,30 @@ def setup_custom_span_tags(self): def test_custom_span_tags(self): """AppSec should store in all APM spans some tags when enabled.""" - spans = [span for _, span in interfaces.library.get_root_spans()] - assert spans, "No root spans to validate" - spans = [s for s in spans if s.get("type") in ("web", "serverless")] - assert spans, "No spans of type web or serverless to validate" - for span in spans: - if span.get("type") == "serverless" and "_dd.appsec.unsupported_event_type" in span["metrics"]: + spans_with_format = list(interfaces.library.get_root_spans()) + assert spans_with_format, "No root spans to validate" + # Filter spans by type using helper method + filtered_spans = [] + for _, span, span_format in spans_with_format: + span_type = interfaces.library.get_span_type(span, span_format) + if span_type in ("web", "serverless"): + filtered_spans.append((span, span_format)) + assert filtered_spans, "No spans of type web or serverless to validate" + for span, span_format in filtered_spans: + span_type = interfaces.library.get_span_type(span, span_format) + metrics = interfaces.library.get_span_metrics(span, span_format) + meta = interfaces.library.get_span_meta(span, span_format) + + if span_type == "serverless" and "_dd.appsec.unsupported_event_type" in metrics: # For serverless, the `healthcheck` event is not supported - assert span["metrics"]["_dd.appsec.unsupported_event_type"] == 1, ( + assert metrics["_dd.appsec.unsupported_event_type"] == 1, ( "_dd.appsec.unsupported_event_type should be 1 or 1.0" ) continue - assert "_dd.appsec.enabled" in span["metrics"], "Cannot find _dd.appsec.enabled in span metrics" - assert span["metrics"]["_dd.appsec.enabled"] == 1, "_dd.appsec.enabled should be 1 or 1.0" - assert "_dd.runtime_family" in span["meta"], "Cannot find _dd.runtime_family in span meta" - assert span["meta"]["_dd.runtime_family"] in RUNTIME_FAMILIES, ( - f"_dd.runtime_family should be in {RUNTIME_FAMILIES}" - ) + assert "_dd.appsec.enabled" in metrics, "Cannot find _dd.appsec.enabled in span metrics" + assert metrics["_dd.appsec.enabled"] == 1, "_dd.appsec.enabled should be 1 or 1.0" + assert "_dd.runtime_family" in meta, "Cannot find _dd.runtime_family in span meta" + assert meta["_dd.runtime_family"] in RUNTIME_FAMILIES, f"_dd.runtime_family should be in {RUNTIME_FAMILIES}" def setup_header_collection(self): self.r = weblog.get("/headers", headers={"User-Agent": "Arachni/v1", "Content-Type": "text/plain"}) @@ -108,32 +120,45 @@ def test_header_collection(self): """AppSec should collect some headers for http.request and http.response and store them in span tags. Note that this test checks for collection, not data. """ - spans = [span for _, _, span in interfaces.library.get_spans(request=self.r)] - assert spans, "No spans to validate" - for span in spans: + spans_with_format = list(interfaces.library.get_spans(request=self.r)) + assert spans_with_format, "No spans to validate" + for _, _, span, span_format in spans_with_format: + meta = interfaces.library.get_span_meta(span, span_format) required_request_headers = ["user-agent", "host", "content-type"] required_request_headers = [f"http.request.headers.{header}" for header in required_request_headers] - missing_request_headers = set(required_request_headers) - set(span.get("meta", {}).keys()) + missing_request_headers = set(required_request_headers) - set(meta.keys()) assert not missing_request_headers, f"Missing request headers: {missing_request_headers}" required_response_headers = ["content-type", "content-length", "content-language"] required_response_headers = [f"http.response.headers.{header}" for header in required_response_headers] - missing_response_headers = set(required_response_headers) - set(span.get("meta", {}).keys()) + missing_response_headers = set(required_response_headers) - set(meta.keys()) assert not missing_response_headers, f"Missing response headers: {missing_response_headers}" def test_root_span_coherence(self): """Appsec tags are not on span where type is not web, http or rpc""" valid_appsec_span_types = ["web", "http", "rpc", "serverless"] - spans = [span for _, _, span in interfaces.library.get_spans()] - assert spans, "No spans to validate" - assert any("_dd.appsec.enabled" in s.get("metrics", {}) for s in spans), "No appsec-enabled spans found" - for span in spans: - if span.get("type") in valid_appsec_span_types: + spans_with_format = list(interfaces.library.get_spans()) + assert spans_with_format, "No spans to validate" + # Check if any span has appsec enabled + has_appsec_enabled = False + for _, _, span, span_format in spans_with_format: + metrics = interfaces.library.get_span_metrics(span, span_format) + if "_dd.appsec.enabled" in metrics: + has_appsec_enabled = True + break + assert has_appsec_enabled, "No appsec-enabled spans found" + + for _, _, span, span_format in spans_with_format: + span_type = interfaces.library.get_span_type(span, span_format) + metrics = interfaces.library.get_span_metrics(span, span_format) + meta = interfaces.library.get_span_meta(span, span_format) + + if span_type in valid_appsec_span_types: continue - assert "_dd.appsec.enabled" not in span.get("metrics", {}), ( + assert "_dd.appsec.enabled" not in metrics, ( f"_dd.appsec.enabled should be present only when span type is any of {', '.join(valid_appsec_span_types)}" ) - assert "_dd.runtime_family" not in span.get("meta", {}), ( + assert "_dd.runtime_family" not in meta, ( f"_dd.runtime_family should be present only when span type is any of {', '.join(valid_appsec_span_types)}" ) @@ -305,13 +330,14 @@ def setup_header_collection(self): reason="The endpoint /headers is not implemented in the weblog", ) def test_header_collection(self): - def assert_header_in_span_meta(span: dict, header: str): - if header not in span["meta"]: + def assert_header_in_span_meta(span: dict, span_format: TraceLibraryPayloadFormat | None, header: str): + meta = interfaces.library.get_span_meta(span, span_format) + if header not in meta: raise Exception(f"Can't find {header} in span's meta") - def validate_response_headers(span: dict): + def validate_response_headers(span: dict, span_format: TraceLibraryPayloadFormat | None): for header in ["content-type", "content-length", "content-language"]: - assert_header_in_span_meta(span, f"http.response.headers.{header}") + assert_header_in_span_meta(span, span_format, f"http.response.headers.{header}") return True interfaces.library.validate_one_span(self.r, validator=validate_response_headers) @@ -337,9 +363,9 @@ def test_collect_default_request_headers(self): if context.library != "golang": # TODO(APPSEC-56898): Golang weblogs do not respond to this request. assert self.r.status_code == 200 - span = interfaces.library.get_root_span(self.r) + span, span_format = interfaces.library.get_root_span(self.r) + meta = interfaces.library.get_span_meta(span, span_format) for key, value in self.HEADERS.items(): - meta = span.get("meta", {}) meta_key = f"http.request.headers.{key.lower()}" assert meta_key in meta if key == "User-Agent": @@ -374,11 +400,12 @@ def setup_external_wafs_header_collection(self): def test_external_wafs_header_collection(self): """Collect external wafs request identifier and other security info when appsec is enabled.""" - def assert_header_in_span_meta(span: dict, header: str): - if header not in span["meta"]: + def assert_header_in_span_meta(span: dict, span_format: TraceLibraryPayloadFormat | None, header: str): + meta = interfaces.library.get_span_meta(span, span_format) + if header not in meta: raise Exception(f"Can't find {header} in span's meta") - def validate_request_headers(span: dict): + def validate_request_headers(span: dict, span_format: TraceLibraryPayloadFormat | None): for header in [ "x-amzn-trace-id", "cloudfront-viewer-ja3-fingerprint", @@ -389,7 +416,7 @@ def validate_request_headers(span: dict): "x-sigsci-tags", "akamai-user-risk", ]: - assert_header_in_span_meta(span, f"http.request.headers.{header}") + assert_header_in_span_meta(span, span_format, f"http.request.headers.{header}") return True interfaces.library.validate_one_span(self.r, validator=validate_request_headers) diff --git a/tests/appsec/test_user_blocking_full_denylist.py b/tests/appsec/test_user_blocking_full_denylist.py index 2dc332fc84d..7d00a459842 100644 --- a/tests/appsec/test_user_blocking_full_denylist.py +++ b/tests/appsec/test_user_blocking_full_denylist.py @@ -39,7 +39,8 @@ def test_blocking_test(self): for r in self.r_blocked_requests: assert r.status_code == 403 interfaces.library.assert_waf_attack(r, rule="blk-001-002", address="usr.id") - span = interfaces.library.get_root_span(r) - assert span["meta"]["appsec.event"] == "true" - assert span["meta"]["appsec.blocked"] == "true" - assert span["meta"]["http.status_code"] == "403" + span, span_format = interfaces.library.get_root_span(r) + meta = interfaces.library.get_span_meta(span, span_format) + assert meta["appsec.event"] == "true" + assert meta["appsec.blocked"] == "true" + assert meta["http.status_code"] == "403" diff --git a/tests/appsec/waf/test_addresses.py b/tests/appsec/waf/test_addresses.py index 57f550492d1..af578d68948 100644 --- a/tests/appsec/waf/test_addresses.py +++ b/tests/appsec/waf/test_addresses.py @@ -121,9 +121,13 @@ def test_specific_wrong_key(self): for r in [self.r_wk_1, self.r_wk_2]: logger.debug(f"Testing {r.request.headers}") assert r.status_code == 200 - spans = [span for _, span in interfaces.library.get_root_spans(request=r)] - assert spans, "No spans to validate" - assert any("_dd.appsec.enabled" in s.get("metrics", {}) for s in spans), "No appsec-enabled spans found" + spans_with_format = list(interfaces.library.get_root_spans(request=r)) + assert spans_with_format, "No spans to validate" + # Use helper method to get metrics for both v04 and v1 formats + assert any( + "_dd.appsec.enabled" in interfaces.library.get_span_metrics(span, span_format) + for _, span, span_format in spans_with_format + ), "No appsec-enabled spans found" interfaces.library.assert_no_appsec_event(self.r_wk_1) interfaces.library.assert_no_appsec_event(self.r_wk_2) @@ -399,11 +403,15 @@ class Test_GrpcServerMethod: def validate_span(self, span: dict, appsec_data: dict): tag = "rpc.grpc.full_method" - if tag not in span["meta"]: + # Use helper method to get meta for both v04 and v1 formats + # Note: span_validator receives the span and span_format is detected internally + span_format = interfaces.library._detect_span_format(span) # noqa: SLF001 + meta = interfaces.library.get_span_meta(span, span_format) + if tag not in meta: logger.info(f"Can't find '{tag}' in span's meta") return False - expected = span["meta"][tag] + expected = meta[tag] value = appsec_data["triggers"][0]["rule_matches"][0]["parameters"][0]["value"] if value != expected: logger.info( diff --git a/tests/appsec/waf/test_blocking_security_response_id.py b/tests/appsec/waf/test_blocking_security_response_id.py index 4044a7f14de..15327821042 100644 --- a/tests/appsec/waf/test_blocking_security_response_id.py +++ b/tests/appsec/waf/test_blocking_security_response_id.py @@ -203,8 +203,8 @@ def test_security_response_id_in_span_trigger(self): assert self.r.status_code == 403, f"Expected 403, got {self.r.status_code}" # Get the root span and extract appsec data - span = interfaces.library.get_root_span(request=self.r) - meta = span.get("meta", {}) + span, span_format = interfaces.library.get_root_span(request=self.r) + meta = interfaces.library.get_span_meta(span, span_format) meta_struct = span.get("meta_struct", {}) # Extract appsec data (support both formats: meta_struct.appsec or meta._dd.appsec.json) diff --git a/tests/appsec/waf/test_exclusions.py b/tests/appsec/waf/test_exclusions.py index 042c2694ded..450a16d8f3c 100644 --- a/tests/appsec/waf/test_exclusions.py +++ b/tests/appsec/waf/test_exclusions.py @@ -25,7 +25,7 @@ def setup_input_exclusion_positive_test(self): def test_input_exclusion_positive_test(self): assert self.r_iexpt.status_code == 200, "Request failed" - spans = [span for _, _, span in interfaces.library.get_spans(request=self.r_iexpt)] + spans = [span for _, _, span, _ in interfaces.library.get_spans(request=self.r_iexpt)] assert spans, "No spans to validate" assert any("_dd.appsec.enabled" in s.get("metrics", {}) for s in spans), "No appsec-enabled spans found" interfaces.library.assert_no_appsec_event(self.r_iexpt) @@ -49,7 +49,7 @@ def setup_rule_exclusion_positive_test(self): def test_rule_exclusion_positive_test(self): assert self.r_rept.status_code == 200, "Request failed" - spans = [span for _, _, span in interfaces.library.get_spans(request=self.r_rept)] + spans = [span for _, _, span, _ in interfaces.library.get_spans(request=self.r_rept)] assert spans, "No spans to validate" assert any("_dd.appsec.enabled" in s.get("metrics", {}) for s in spans), "No appsec-enabled spans found" interfaces.library.assert_no_appsec_event(self.r_rept) diff --git a/tests/appsec/waf/test_reports.py b/tests/appsec/waf/test_reports.py index 8a3f71c63b2..3b7bfe4a1af 100644 --- a/tests/appsec/waf/test_reports.py +++ b/tests/appsec/waf/test_reports.py @@ -5,6 +5,7 @@ import json from utils import weblog, context, interfaces, irrelevant, scenarios, features +from utils.dd_constants import TraceLibraryPayloadFormat @features.support_in_app_waf_metrics_report @@ -26,14 +27,21 @@ def test_waf_monitoring(self): # Tags that are expected to be reported at least once at some point - def validate_waf_monitoring_span_tags(span: dict, appsec_data: dict): # noqa: ARG001 + def validate_waf_monitoring_span_tags( + span: dict, + appsec_data: dict, # noqa: ARG001 + span_format: TraceLibraryPayloadFormat | None = None, + ): """Validate the mandatory waf monitoring span tags are added to the request span having an attack""" + # Use helper methods to get meta and metrics for both v04 and v1 formats + if span_format is None: + span_format = interfaces.library._detect_span_format(span) # noqa: SLF001 + meta = interfaces.library.get_span_meta(span, span_format) + metrics = interfaces.library.get_span_metrics(span, span_format) - meta = span["meta"] for m in expected_waf_monitoring_meta_tags: assert m in meta, f"missing span meta tag `{m}` in meta" - metrics = span["metrics"] for m in expected_waf_monitoring_metrics_tags: assert m in metrics, f"missing span metric tag `{m}` in metrics" @@ -68,16 +76,19 @@ def test_waf_monitoring_once(self): expected_rules_monitoring_nb_errors_tag, ] - def validate_rules_monitoring_span_tags(span: dict): + def validate_rules_monitoring_span_tags(span: dict, span_format: TraceLibraryPayloadFormat | None = None): """Validate the mandatory rules monitoring span tags are added to a request span at some point such as the first request or first attack. """ + # Use helper methods to get meta and metrics for both v04 and v1 formats + if span_format is None: + span_format = interfaces.library._detect_span_format(span) # noqa: SLF001 + meta = interfaces.library.get_span_meta(span, span_format) + metrics = interfaces.library.get_span_metrics(span, span_format) - meta = span["meta"] if expected_waf_version_tag not in meta: return None # Skip this span - metrics = span["metrics"] for m in expected_rules_monitoring_metrics_tags: if m not in metrics: return None # Skip this span @@ -134,12 +145,15 @@ def test_waf_monitoring_once_rfc1025(self): # Tags that are expected to be reported at least once at some point expected_waf_version_tag = "_dd.appsec.waf.version" - def validate_rules_monitoring_span_tags(span: dict): + def validate_rules_monitoring_span_tags(span: dict, span_format: TraceLibraryPayloadFormat | None = None): """Validate the mandatory rules monitoring span tags are added to a request span at some point such as the first request or first attack. """ + # Use helper methods to get meta for both v04 and v1 formats + if span_format is None: + span_format = interfaces.library._detect_span_format(span) # noqa: SLF001 + meta = interfaces.library.get_span_meta(span, span_format) - meta = span["meta"] if expected_waf_version_tag not in meta: return None # Skip this span @@ -165,8 +179,12 @@ def test_waf_monitoring_optional(self): expected_bindings_duration_metric = "_dd.appsec.waf.duration_ext" expected_metrics_tags = [expected_waf_duration_metric, expected_bindings_duration_metric] - def validate_waf_span_tags(span: dict, appsec_data: dict): # noqa: ARG001 - metrics = span["metrics"] + def validate_waf_span_tags(span: dict, appsec_data: dict, span_format: TraceLibraryPayloadFormat | None = None): # noqa: ARG001 + # Use helper method to get metrics for both v04 and v1 formats + if span_format is None: + span_format = interfaces.library._detect_span_format(span) # noqa: SLF001 + metrics = interfaces.library.get_span_metrics(span, span_format) + for m in expected_metrics_tags: if m not in metrics: raise Exception(f"missing span metric tag `{m}` in {metrics}") @@ -208,17 +226,20 @@ def test_waf_monitoring_errors(self): expected_nb_errors = 2 expected_error_details = {"missing key 'name'": ["missing-name"], "missing key 'tags'": ["missing-tags"]} - def validate_rules_monitoring_span_tags(span: dict): + def validate_rules_monitoring_span_tags(span: dict, span_format: TraceLibraryPayloadFormat | None = None): """Validate the mandatory rules monitoring span tags are added to a request span at some point such as the first request or first attack. """ + # Use helper methods to get meta and metrics for both v04 and v1 formats + if span_format is None: + span_format = interfaces.library._detect_span_format(span) # noqa: SLF001 + meta = interfaces.library.get_span_meta(span, span_format) + metrics = interfaces.library.get_span_metrics(span, span_format) - meta = span["meta"] for m in expected_rules_monitoring_meta_tags: if m not in meta: return None # Skip this span - metrics = span["metrics"] for m in expected_rules_monitoring_metrics_tags: if m not in metrics: return None # Skip this span diff --git a/tests/appsec/waf/test_telemetry.py b/tests/appsec/waf/test_telemetry.py index 6f0fa8af4f4..c212392841b 100644 --- a/tests/appsec/waf/test_telemetry.py +++ b/tests/appsec/waf/test_telemetry.py @@ -160,14 +160,15 @@ def test_metric_waf_requests(self): def test_waf_requests_match_traced_requests(self): """Total waf.requests metric should match the number of requests in traces.""" - spans = [s for _, s in interfaces.library.get_root_spans()] - spans = [ - s - for s in spans - if s.get("meta", {}).get("span.kind") == "server" - # excluding graphql introspection query executed on startup in nodejs - and s.get("meta", {}).get("graphql.operation.name") != "IntrospectionQuery" - ] + spans_with_format = [(span, span_format) for _, span, span_format in interfaces.library.get_root_spans()] + spans = [] + for span, span_format in spans_with_format: + meta = interfaces.library.get_span_meta(span, span_format) + # Filter for server spans only + if meta.get("span.kind") == "server": + # excluding graphql introspection query executed on startup in nodejs + if meta.get("graphql.operation.name") != "IntrospectionQuery": + spans.append(span) request_count = len(spans) assert request_count >= 3 diff --git a/tests/appsec/waf/test_truncation.py b/tests/appsec/waf/test_truncation.py index 890b5dfbf12..6af0ee84df4 100644 --- a/tests/appsec/waf/test_truncation.py +++ b/tests/appsec/waf/test_truncation.py @@ -31,8 +31,8 @@ def setup_truncation(self): ) def test_truncation(self): - span = interfaces.library.get_root_span(self.req) - metrics = span.get("metrics") + span, span_format = interfaces.library.get_root_span(self.req) + metrics = interfaces.library.get_span_metrics(span, span_format) assert metrics, "Expected metrics" assert int(metrics["_dd.appsec.truncated.string_length"]) == 5000 diff --git a/tests/external_processing/test_apm.py b/tests/external_processing/test_apm.py index 373ed87c9a0..d1095884454 100644 --- a/tests/external_processing/test_apm.py +++ b/tests/external_processing/test_apm.py @@ -10,8 +10,9 @@ def setup_correct_span_structure(self): def test_correct_span_structure(self): assert self.r.status_code == 200 interfaces.library.assert_trace_exists(self.r) - span = interfaces.library.get_root_span(self.r) + span, span_format = interfaces.library.get_root_span(self.r) assert span["type"] == "web" - assert span["meta"]["span.kind"] == "server" - assert span["meta"]["http.url"] == "http://localhost:7777/" - assert span["meta"]["http.host"] == "localhost:7777" + meta = interfaces.library.get_span_meta(span, span_format) + assert meta["span.kind"] == "server" + assert meta["http.url"] == "http://localhost:7777/" + assert meta["http.host"] == "localhost:7777" diff --git a/tests/integrations/crossed_integrations/test_kafka.py b/tests/integrations/crossed_integrations/test_kafka.py index e8b2c1a2682..f691cc7762b 100644 --- a/tests/integrations/crossed_integrations/test_kafka.py +++ b/tests/integrations/crossed_integrations/test_kafka.py @@ -18,8 +18,13 @@ class _BaseKafka: def get_span(cls, interface: interfaces.LibraryInterfaceValidator, span_kind: str, topic: str) -> dict | None: logger.debug(f"Trying to find traces with span kind: {span_kind} and topic: {topic} in {interface}") - for data, trace in interface.get_traces(): - for span in trace: + for data, trace, _trace_format in interface.get_traces(): + # Handle both v04 (list) and v1 (dict) formats + if isinstance(trace, list): + spans = trace + else: + spans = trace.get("spans", []) + for span in spans: if not span.get("meta"): continue diff --git a/tests/integrations/crossed_integrations/test_kinesis.py b/tests/integrations/crossed_integrations/test_kinesis.py index 55911d8858c..c2a4a585d74 100644 --- a/tests/integrations/crossed_integrations/test_kinesis.py +++ b/tests/integrations/crossed_integrations/test_kinesis.py @@ -20,8 +20,13 @@ def get_span( ) -> dict | None: logger.debug(f"Trying to find traces with span kind: {span_kind} and stream: {stream} in {interface}") - for data, trace in interface.get_traces(): - for span in trace: + for data, trace, _trace_format in interface.get_traces(): + # Handle both v04 (list) and v1 (dict) formats + if isinstance(trace, list): + spans = trace + else: + spans = trace.get("spans", []) + for span in spans: if not span.get("meta"): continue diff --git a/tests/integrations/crossed_integrations/test_rabbitmq.py b/tests/integrations/crossed_integrations/test_rabbitmq.py index 5af6110c534..a4384e41bbd 100644 --- a/tests/integrations/crossed_integrations/test_rabbitmq.py +++ b/tests/integrations/crossed_integrations/test_rabbitmq.py @@ -29,8 +29,13 @@ def get_span( ) -> dict | None: logger.debug(f"Trying to find traces with span kind: {span_kind} and queue: {queue} in {interface}") - for data, trace in interface.get_traces(): - for span in trace: + for data, trace, _trace_format in interface.get_traces(): + # Handle both v04 (list) and v1 (dict) formats + if isinstance(trace, list): + spans = trace + else: + spans = trace.get("spans", []) + for span in spans: if not span.get("meta"): continue diff --git a/tests/integrations/crossed_integrations/test_sns_to_sqs.py b/tests/integrations/crossed_integrations/test_sns_to_sqs.py index fb2e6c2977e..ee683edefc1 100644 --- a/tests/integrations/crossed_integrations/test_sns_to_sqs.py +++ b/tests/integrations/crossed_integrations/test_sns_to_sqs.py @@ -28,10 +28,15 @@ def get_span( logger.debug(f"Trying to find traces with span kind: {span_kind} and queue: {queue} in {interface}") manual_span_found = False - for data, trace in interface.get_traces(): + for data, trace, _trace_format in interface.get_traces(): + # Handle both v04 (list) and v1 (dict) formats + if isinstance(trace, list): + spans = trace + else: + spans = trace.get("spans", []) # we iterate the trace backwards to deal with the case of JS "aws.response" callback spans, which are similar for this test and test_sqs. # Instead, we look for the custom span created after the "aws.response" span - for span in reversed(trace): + for span in reversed(spans): if not span.get("meta"): continue diff --git a/tests/integrations/crossed_integrations/test_sqs.py b/tests/integrations/crossed_integrations/test_sqs.py index 63dd5123912..f7c763cd066 100644 --- a/tests/integrations/crossed_integrations/test_sqs.py +++ b/tests/integrations/crossed_integrations/test_sqs.py @@ -21,10 +21,15 @@ def get_span( logger.debug(f"Trying to find traces with span kind: {span_kind} and queue: {queue} in {interface}") manual_span_found = False - for data, trace in interface.get_traces(): + for data, trace, _trace_format in interface.get_traces(): + # Handle both v04 (list) and v1 (dict) formats + if isinstance(trace, list): + spans = trace + else: + spans = trace.get("spans", []) # we iterate the trace backwards to deal with the case of JS "aws.response" callback spans, which are similar for this test and test_sns_to_sqs. # Instead, we look for the custom span created after the "aws.response" span - for span in reversed(trace): + for span in reversed(spans): assert isinstance(span, dict), f"Span is not a dict: {data['log_filename']}" if not span.get("meta"): continue diff --git a/tests/integrations/test_dbm.py b/tests/integrations/test_dbm.py index 701bfce4a5b..6827a6ca83b 100644 --- a/tests/integrations/test_dbm.py +++ b/tests/integrations/test_dbm.py @@ -59,8 +59,12 @@ def _get_db_span(self, response: HttpResponse) -> dict: spans = [] # we do not use get_spans: the span we look for is not directly the span that carry the request information - for data, trace in interfaces.library.get_traces(request=response): - spans += [(data, span) for span in trace if span.get("type") == "sql"] + for data, trace, _trace_format in interfaces.library.get_traces(request=response): + # Handle both v04 (list) and v1 (dict) formats + if isinstance(trace, list): + spans += [(data, span) for span in trace if span.get("type") == "sql"] + else: + spans += [(data, span) for span in trace.get("spans", []) if span.get("type") == "sql"] if len(spans) == 0: raise ValueError("No span found with meta.type == 'sql'") diff --git a/tests/integrations/test_inferred_proxy.py b/tests/integrations/test_inferred_proxy.py index 71ae2ccd0d7..b33c8ac581b 100644 --- a/tests/integrations/test_inferred_proxy.py +++ b/tests/integrations/test_inferred_proxy.py @@ -118,8 +118,13 @@ def test_api_gateway_inferred_span_creation_error(self): def get_span(interface: interfaces.LibraryInterfaceValidator, resource: str) -> dict | None: logger.debug(f"Trying to find API Gateway span for interface: {interface}") - for data, trace in interface.get_traces(): - for span in trace: + for data, trace, _trace_format in interface.get_traces(): + # Handle both v04 (list) and v1 (dict) formats + if isinstance(trace, list): + spans = trace + else: + spans = trace.get("spans", []) + for span in spans: if not span.get("meta"): continue diff --git a/tests/integrations/utils.py b/tests/integrations/utils.py index 4cb047e0556..e6dddd8f396 100644 --- a/tests/integrations/utils.py +++ b/tests/integrations/utils.py @@ -85,11 +85,11 @@ def get_requests( @staticmethod def get_span_from_tracer(weblog_request: HttpResponse) -> dict: - for _, _, span in interfaces.library.get_spans(weblog_request): + for _, _, span, _ in interfaces.library.get_spans(weblog_request): logger.info(f"Span found with trace id: {span['trace_id']} and span id: {span['span_id']}") # iterate over all trace to be sure to miss nothing - for _, _, span_child in interfaces.library.get_spans(): + for _, _, span_child, _ in interfaces.library.get_spans(): if span_child["trace_id"] != span["trace_id"]: continue diff --git a/tests/test_config_consistency.py b/tests/test_config_consistency.py index bee197f2604..3ba757352ef 100644 --- a/tests/test_config_consistency.py +++ b/tests/test_config_consistency.py @@ -49,7 +49,8 @@ def test_status_code_400(self): assert interfaces.agent.get_span_type(span, span_format) == "web" span_meta = interfaces.agent.get_span_meta(span, span_format) assert span_meta["http.status_code"] == "400" - assert "error" not in span or span["error"] == 0 + # Error field is the same in both formats (top-level field) + assert "error" not in span or span.get("error") == 0 def setup_status_code_500(self): self.r = weblog.get("/status?code=500") @@ -64,7 +65,8 @@ def test_status_code_500(self): span_meta = interfaces.agent.get_span_meta(span, span_format) assert span_meta["http.status_code"] == "500" - assert span["error"] + # Error field is the same in both formats (top-level field) + assert span.get("error") @scenarios.tracing_config_nondefault @@ -85,7 +87,8 @@ def test_status_code_200(self): assert interfaces.agent.get_span_type(span, span_format) == "web" span_meta = interfaces.agent.get_span_meta(span, span_format) assert span_meta["http.status_code"] == "200" - assert span["error"] + # Error field is the same in both formats (top-level field) + assert span.get("error") def setup_status_code_202(self): self.r = weblog.get("/status?code=202") @@ -117,19 +120,23 @@ def setup_query_string_obfuscation_empty_client(self): reason="APMAPI-770", ) def test_query_string_obfuscation_empty_client(self): - spans = [s for _, _, s in interfaces.library.get_spans(request=self.r, full_trace=True)] - client_span = _get_span_by_tags( + spans = [(s, f) for _, _, s, f in interfaces.library.get_spans(request=self.r, full_trace=True)] + client_span_result = _get_span_by_tags( spans, tags={"span.kind": "client", "http.url": "http://weblog:7777/?key=monkey"} ) - assert client_span, "\n".join([str(s) for s in spans]) + # Handle both (span, format) tuple and plain span return values + client_span = client_span_result[0] if isinstance(client_span_result, tuple) else client_span_result + assert client_span, "\n".join([str(s) for s, _ in spans]) def setup_query_string_obfuscation_empty_server(self): self.r = weblog.get("/?application_key=value") def test_query_string_obfuscation_empty_server(self): - spans = [s for _, _, s in interfaces.library.get_spans(request=self.r, full_trace=True)] - server_span = _get_span_by_tags(spans, tags={"http.url": "http://localhost:7777/?application_key=value"}) - assert server_span, "\n".join([str(s) for s in spans]) + spans = [(s, f) for _, _, s, f in interfaces.library.get_spans(request=self.r, full_trace=True)] + server_span_result = _get_span_by_tags(spans, tags={"http.url": "http://localhost:7777/?application_key=value"}) + # Handle both (span, format) tuple and plain span return values + server_span = server_span_result[0] if isinstance(server_span_result, tuple) else server_span_result + assert server_span, "\n".join([str(s) for s, _ in spans]) @scenarios.tracing_config_nondefault @@ -143,11 +150,13 @@ def setup_query_string_obfuscation_configured_client(self): reason="Missing endpoint", ) def test_query_string_obfuscation_configured_client(self): - spans = [s for _, _, s in interfaces.library.get_spans(request=self.r, full_trace=True)] - client_span = _get_span_by_tags( + spans = [(s, f) for _, _, s, f in interfaces.library.get_spans(request=self.r, full_trace=True)] + client_span_result = _get_span_by_tags( spans, tags={"span.kind": "client", "http.url": "http://weblog:7777/?"} ) - assert client_span, "\n".join([str(s) for s in spans]) + # Handle both (span, format) tuple and plain span return values + client_span = client_span_result[0] if isinstance(client_span_result, tuple) else client_span_result + assert client_span, "\n".join([str(s) for s, _ in spans]) def setup_query_string_obfuscation_configured_server(self): self.r = weblog.get("/?ssn=123-45-6789") @@ -168,11 +177,13 @@ def setup_query_string_obfuscation_configured_client(self): reason="Missing endpoint", ) def test_query_string_obfuscation_configured_client(self): - spans = [s for _, _, s in interfaces.library.get_spans(request=self.r, full_trace=True)] - client_span = _get_span_by_tags( + spans = [(s, f) for _, _, s, f in interfaces.library.get_spans(request=self.r, full_trace=True)] + client_span_result = _get_span_by_tags( spans, tags={"span.kind": "client", "http.url": "http://weblog:7777/?"} ) - assert client_span, "\n".join([str(s) for s in spans]) + # Handle both (span, format) tuple and plain span return values + client_span = client_span_result[0] if isinstance(client_span_result, tuple) else client_span_result + assert client_span, "\n".join([str(s) for s, _ in spans]) def setup_query_string_obfuscation_configured_server(self): self.r = weblog.get("/?token=value") @@ -197,10 +208,13 @@ def test_status_code_400(self): assert content["status_code"] == 400 interfaces.library.assert_trace_exists(self.r) - spans = [s for _, _, s in interfaces.library.get_spans(request=self.r, full_trace=True)] + spans = [(s, f) for _, _, s, f in interfaces.library.get_spans(request=self.r, full_trace=True)] - client_span = _get_span_by_tags(spans, tags={"span.kind": "client", "http.status_code": "400"}) + client_span_result = _get_span_by_tags(spans, tags={"span.kind": "client", "http.status_code": "400"}) + # Handle both (span, format) tuple and plain span return values + client_span = client_span_result[0] if isinstance(client_span_result, tuple) else client_span_result assert client_span, spans + # Error field is the same in both formats (top-level field) assert client_span.get("error") == 1 def setup_status_code_500(self): @@ -212,10 +226,13 @@ def test_status_code_500(self): assert content["status_code"] == 500 interfaces.library.assert_trace_exists(self.r) - spans = [s for _, _, s in interfaces.library.get_spans(request=self.r, full_trace=True)] + spans = [(s, f) for _, _, s, f in interfaces.library.get_spans(request=self.r, full_trace=True)] - client_span = _get_span_by_tags(spans, tags={"span.kind": "client", "http.status_code": "500"}) + client_span_result = _get_span_by_tags(spans, tags={"span.kind": "client", "http.status_code": "500"}) + # Handle both (span, format) tuple and plain span return values + client_span = client_span_result[0] if isinstance(client_span_result, tuple) else client_span_result assert client_span, spans + # Error field is the same in both formats (top-level field) assert client_span.get("error") is None or client_span.get("error") == 0 @@ -233,10 +250,13 @@ def test_status_code_200(self): assert content["status_code"] == 200 interfaces.library.assert_trace_exists(self.r) - spans = [s for _, _, s in interfaces.library.get_spans(request=self.r, full_trace=True)] + spans = [(s, f) for _, _, s, f in interfaces.library.get_spans(request=self.r, full_trace=True)] - client_span = _get_span_by_tags(spans, tags={"span.kind": "client", "http.status_code": "200"}) + client_span_result = _get_span_by_tags(spans, tags={"span.kind": "client", "http.status_code": "200"}) + # Handle both (span, format) tuple and plain span return values + client_span = client_span_result[0] if isinstance(client_span_result, tuple) else client_span_result assert client_span, spans + # Error field is the same in both formats (top-level field) assert client_span.get("error") == 1 def setup_status_code_202(self): @@ -248,10 +268,13 @@ def test_status_code_202(self): assert content["status_code"] == 202 interfaces.library.assert_trace_exists(self.r) - spans = [s for _, _, s in interfaces.library.get_spans(request=self.r, full_trace=True)] + spans = [(s, f) for _, _, s, f in interfaces.library.get_spans(request=self.r, full_trace=True)] - client_span = _get_span_by_tags(spans, tags={"span.kind": "client", "http.status_code": "202"}) + client_span_result = _get_span_by_tags(spans, tags={"span.kind": "client", "http.status_code": "202"}) + # Handle both (span, format) tuple and plain span return values + client_span = client_span_result[0] if isinstance(client_span_result, tuple) else client_span_result assert client_span, spans + # Error field is the same in both formats (top-level field) assert client_span.get("error") == 1 @@ -264,7 +287,7 @@ def setup_query_string_redaction_unset(self): self.r = weblog.get("/make_distant_call", params={"url": "http://weblog:7777/?hi=monkey"}) def test_query_string_redaction_unset(self): - trace = [span for _, _, span in interfaces.library.get_spans(self.r, full_trace=True)] + trace = [(span, f) for _, _, span, f in interfaces.library.get_spans(self.r, full_trace=True)] expected_tags = {"http.url": "http://weblog:7777/?hi=monkey"} assert _get_span_by_tags(trace, expected_tags), f"Span with tags {expected_tags} not found in {trace}" @@ -278,7 +301,7 @@ def setup_query_string_redaction(self): self.r = weblog.get("/make_distant_call", params={"url": "http://weblog:7777/?hi=monkey"}) def test_query_string_redaction(self): - trace = [span for _, _, span in interfaces.library.get_spans(self.r, full_trace=True)] + trace = [(span, f) for _, _, span, f in interfaces.library.get_spans(self.r, full_trace=True)] expected_tags = {"http.url": "http://weblog:7777/"} assert _get_span_by_tags(trace, expected_tags), f"Span with tags {expected_tags} not found in {trace}" @@ -297,7 +320,7 @@ def setup_ip_headers_sent_in_one_request(self): def test_ip_headers_sent_in_one_request(self): # Ensures the header set in DD_TRACE_CLIENT_IP_HEADER takes precedence over all supported ip headers - trace = [span for _, _, span in interfaces.library.get_spans(self.req, full_trace=True)] + trace = [(span, f) for _, _, span, f in interfaces.library.get_spans(self.req, full_trace=True)] expected_tags = {"http.client_ip": "5.6.7.9"} assert _get_span_by_tags(trace, expected_tags), f"Span with tags {expected_tags} not found in {trace}" @@ -313,7 +336,7 @@ def setup_ip_headers_sent_in_one_request(self): ) def test_ip_headers_sent_in_one_request(self): - spans = [span for _, _, span in interfaces.library.get_spans(self.req, full_trace=True)] + spans = [(span, f) for _, _, span, f in interfaces.library.get_spans(self.req, full_trace=True)] logger.info(spans) expected_tags = {"http.client_ip": "5.6.7.9"} assert _get_span_by_tags(spans, expected_tags) == {} @@ -364,29 +387,39 @@ def test_ip_headers_precedence(self): ip = ip.removeprefix("for=") - trace = [span for _, _, span in interfaces.library.get_spans(req, full_trace=True)] + trace = [(span, f) for _, _, span, f in interfaces.library.get_spans(req, full_trace=True)] expected_tags = {"http.client_ip": ip} assert _get_span_by_tags(trace, expected_tags), f"Span with tags {expected_tags} not found in {trace}" def _get_span_by_tags(spans: list, tags: dict): + """Find a span by tags. Accepts either list of spans or list of (span, span_format) tuples.""" logger.info(f"Try to find span with metag tags {tags}") - for span in spans: - meta = span["meta"] - logger.debug(f"Checking span {span['span_id']} meta:\n{'\n'.join(map(str, meta.items()))}") + for span_item in spans: + # Handle both (span, span_format) tuples and plain spans for backward compatibility + if isinstance(span_item, tuple) and len(span_item) == 2: + span, span_format = span_item + else: + span = span_item + span_format = None + + meta = interfaces.library.get_span_meta(span, span_format) + span_id = span.get("span_id") or span.get("id", "unknown") + logger.debug(f"Checking span {span_id} meta:\n{'\n'.join(map(str, meta.items()))}") # Avoids retrieving the client span by the operation/resource name, this value varies between languages # Use the expected tags to identify the span for k, v in tags.items(): if k not in meta: - logger.debug(f"Span {span['span_id']} does not have tag {k}") + logger.debug(f"Span {span_id} does not have tag {k}") break elif meta[k] != v: - logger.debug(f"Span {span['span_id']} has tag {k}={meta[k]} instead of {v}") + logger.debug(f"Span {span_id} has tag {k}={meta[k]} instead of {v}") break else: - logger.info(f"Span found: {span['span_id']}") - return span + logger.info(f"Span found: {span_id}") + # Return (span, span_format) if format was provided, otherwise just span + return (span, span_format) if span_format is not None else span logger.warning("No span with those tags has been found") return {} @@ -448,18 +481,32 @@ def setup_integration_enabled_false(self): def test_integration_enabled_false(self): assert self.r.status_code == 200 - spans = [span for _, _, span in interfaces.library.get_spans(request=self.r, full_trace=True)] - assert spans, "No spans found in trace" + spans_with_format = [ + (span, f) for _, _, span, f in interfaces.library.get_spans(request=self.r, full_trace=True) + ] + assert spans_with_format, "No spans found in trace" # Ruby kafka integration generates a span with the name "kafka.producer.*", # unlike python/dotnet/etc. which generates a "kafka.produce" span if context.library == "php": - assert list(filter(lambda span: "pdo" in span.get("service"), spans)) == [], ( - f"PDO span was found in trace: {spans}" - ) + assert ( + list( + filter( + lambda item: "pdo" in item[0].get("service", ""), + spans_with_format, + ) + ) + == [] + ), f"PDO span was found in trace: {spans_with_format}" else: - assert list(filter(lambda span: "kafka.produce" in span.get("name"), spans)) == [], ( - f"kafka.produce span was found in trace: {spans}" - ) + assert ( + list( + filter( + lambda item: "kafka.produce" in interfaces.library.get_span_name(item[0], item[1]), + spans_with_format, + ) + ) + == [] + ), f"kafka.produce span was found in trace: {spans_with_format}" @rfc("https://docs.google.com/document/d/1kI-gTAKghfcwI7YzKhqRv2ExUstcHqADIWA4-TZ387o/edit#heading=h.8v16cioi7qxp") @@ -477,19 +524,27 @@ def setup_integration_enabled_true(self): def test_integration_enabled_true(self): assert self.r.status_code == 200 - spans = [span for _, _, span in interfaces.library.get_spans(request=self.r, full_trace=True)] - assert spans, "No spans found in trace" + spans_with_format = [ + (span, f) for _, _, span, f in interfaces.library.get_spans(request=self.r, full_trace=True) + ] + assert spans_with_format, "No spans found in trace" # PHP uses the pdo integration if context.library == "php": - assert list(filter(lambda span: "pdo" in span.get("service"), spans)), ( - f"No PDO span found in trace: {spans}" - ) + assert list( + filter( + lambda item: "pdo" in item[0].get("service", ""), + spans_with_format, + ) + ), f"No PDO span found in trace: {spans_with_format}" else: # Ruby kafka integration generates a span with the name "kafka.producer.*", # unlike python/dotnet/etc. which generates a "kafka.produce" span - assert list(filter(lambda span: "kafka.produce" in span.get("name"), spans)), ( - f"No kafka.produce span found in trace: {spans}" - ) + assert list( + filter( + lambda item: "kafka.produce" in interfaces.library.get_span_name(item[0], item[1]), + spans_with_format, + ) + ), f"No kafka.produce span found in trace: {spans_with_format}" @rfc("https://docs.google.com/document/d/1kI-gTAKghfcwI7YzKhqRv2ExUstcHqADIWA4-TZ387o/edit#heading=h.8v16cioi7qxp") diff --git a/tests/test_data_integrity.py b/tests/test_data_integrity.py index e31df252c26..981a54539b6 100644 --- a/tests/test_data_integrity.py +++ b/tests/test_data_integrity.py @@ -6,7 +6,7 @@ import string from utils import weblog, interfaces, context, rfc, missing_feature, features, scenarios, logger -from utils.dd_constants import SamplingPriority, TraceAgentPayloadFormat +from utils.dd_constants import SamplingPriority, TraceAgentPayloadFormat, TraceLibraryPayloadFormat from utils.cgroup_info import get_container_id @@ -208,13 +208,35 @@ def test_headers(self): def test_traces_coherence(self): """Agent does not like incoherent data. Check that no incoherent data are coming from the tracer""" - for data, trace in interfaces.library.get_traces(): + for data, trace, trace_format in interfaces.library.get_traces(): assert data["response"]["status_code"] == 200 - trace_id = trace[0]["trace_id"] + # Handle both v04 (list of spans) and v1 (trace chunk) formats + if trace_format == TraceLibraryPayloadFormat.v1: + # v1 format: trace is a chunk dict with spans array + assert isinstance(trace, dict), "v1 format trace must be a dict" + spans = trace.get("spans", []) + if not spans: + continue + trace_id = interfaces.library.get_span_trace_id(spans[0], trace, trace_format) + else: + # v04 format: trace is a list of spans + assert isinstance(trace, list), "v04 format trace must be a list" + if not trace: + continue + trace_id = interfaces.library.get_span_trace_id(trace[0], None, trace_format) assert isinstance(trace_id, int) assert trace_id > 0 - for span in trace: - assert span["trace_id"] == trace_id + spans_to_check = ( + trace.get("spans", []) + if trace_format == TraceLibraryPayloadFormat.v1 and isinstance(trace, dict) + else trace + ) + for span in spans_to_check: + trace_dict: dict | None = ( + trace if trace_format == TraceLibraryPayloadFormat.v1 and isinstance(trace, dict) else None + ) + span_trace_id = interfaces.library.get_span_trace_id(span, trace_dict, trace_format) + assert span_trace_id == trace_id @features.agent_data_integrity @@ -238,37 +260,50 @@ def test_agent_do_not_drop_traces(self): trace_ids_reported_by_agent.add(int(span["traceID"])) break - def get_span_with_sampling_data(trace: list): + def get_span_with_sampling_data(trace: dict | list[dict], trace_format: TraceLibraryPayloadFormat): # The root span is not necessarily the span wherein the sampling priority can be found. # If present, the root will take precedence, and otherwise the first span with the # sampling priority tag will be returned. This is the same logic found on the trace-agent. + if trace_format == TraceLibraryPayloadFormat.v1: + assert isinstance(trace, dict), "v1 format trace must be a dict" + spans_to_check = trace.get("spans", []) + else: + assert isinstance(trace, list), "v04 format trace must be a list" + spans_to_check = trace span_with_sampling_data = None - for span in trace: - if span.get("metrics", {}).get("_sampling_priority_v1", None) is not None: - if span.get("parent_id") in (0, None): - return span + for span in spans_to_check: + metrics = interfaces.library.get_span_metrics(span, trace_format) + if metrics.get("_sampling_priority_v1", None) is not None: + parent_id = interfaces.library.get_span_parent_id(span, trace_format) + if parent_id in (0, None): + return span, trace_format elif span_with_sampling_data is None: - span_with_sampling_data = span + span_with_sampling_data = (span, trace_format) return span_with_sampling_data all_traces_are_reported = True trace_ids_reported_by_tracer = set() # check that all traces reported by the tracer are also reported by the agent - for data, trace in interfaces.library.get_traces(): - span = get_span_with_sampling_data(trace) - if not span: + for data, trace, trace_format in interfaces.library.get_traces(): + span_result = get_span_with_sampling_data(trace, trace_format) + if not span_result: continue - metrics = span["metrics"] + span, span_format = span_result + metrics = interfaces.library.get_span_metrics(span, span_format) sampling_priority = metrics.get("_sampling_priority_v1") if sampling_priority in (SamplingPriority.AUTO_KEEP, SamplingPriority.USER_KEEP): - trace_ids_reported_by_tracer.add(span["trace_id"]) - if span["trace_id"] not in trace_ids_reported_by_agent: - logger.error(f"Trace {span['trace_id']} has not been reported ({data['log_filename']})") + trace_dict: dict | None = ( + trace if trace_format == TraceLibraryPayloadFormat.v1 and isinstance(trace, dict) else None + ) + trace_id = interfaces.library.get_span_trace_id(span, trace_dict, span_format) + trace_ids_reported_by_tracer.add(trace_id) + if trace_id not in trace_ids_reported_by_agent: + logger.error(f"Trace {trace_id} has not been reported ({data['log_filename']})") all_traces_are_reported = False else: - logger.debug(f"Trace {span['trace_id']} has been reported ({data['log_filename']})") + logger.debug(f"Trace {trace_id} has been reported ({data['log_filename']})") if not all_traces_are_reported: logger.info(f"Tracer reported {len(trace_ids_reported_by_tracer)} traces") diff --git a/tests/test_distributed.py b/tests/test_distributed.py index 05f913cf844..f725cac7a60 100644 --- a/tests/test_distributed.py +++ b/tests/test_distributed.py @@ -58,22 +58,34 @@ def setup_span_links_from_conflicting_contexts(self): def test_span_links_from_conflicting_contexts(self): trace = [ - span - for _, _, span in interfaces.library.get_spans(self.req, full_trace=True) - if _retrieve_span_links(span) is not None - and span["trace_id"] == 2 - and span["parent_id"] == 10 # Only fetch the trace that is related to the header extractions + (span, span_format, trace_chunk) + for _, trace_chunk, span, span_format in interfaces.library.get_spans(self.req, full_trace=True) + if interfaces.library.get_span_links(span, span_format) is not None + and interfaces.library.get_span_trace_id( + span, trace_chunk if isinstance(trace_chunk, dict) else None, span_format + ) + == 2 + and interfaces.library.get_span_parent_id(span, span_format) + == 10 # Only fetch the trace that is related to the header extractions ] - assert len(trace) == 1 - span = trace[0] - links = _retrieve_span_links(span) - assert len(links) == 1 + assert len(trace) == 1, f"Expected 1 span with matching criteria, got {len(trace)}. Trace: {trace}" + span, span_format, _trace_chunk = trace[0] + links = interfaces.library.get_span_links(span, span_format) + assert links is not None, f"No span links found in span. Span keys: {list(span.keys())}" + assert len(links) == 1, f"Expected 1 link, got {len(links)}. Links: {links}" link1 = links[0] - assert link1["trace_id"] == 2 + assert "trace_id" in link1, f"trace_id not found in link. Link keys: {list(link1.keys())}, Link: {link1}" + # Convert hex trace_id to int if needed + trace_id_value = link1["trace_id"] + if isinstance(trace_id_value, str) and trace_id_value.startswith("0x"): + trace_id_value = int(trace_id_value[-16:], 16) + assert trace_id_value == 2, f"Expected trace_id 2, got {trace_id_value}" assert link1["span_id"] == 987654321 assert link1["attributes"] == {"reason": "terminated_context", "context_headers": "tracecontext"} - assert link1["trace_id_high"] == 1229782938247303441 + # trace_id_high might not be present in v1 format + if "trace_id_high" in link1: + assert link1["trace_id_high"] == 1229782938247303441 """Datadog and tracecontext headers, trace-id does match, Datadog is primary context we want to make sure there's no span link since they match""" @@ -92,11 +104,15 @@ def setup_no_span_links_from_nonconflicting_contexts(self): def test_no_span_links_from_nonconflicting_contexts(self): trace = [ - span - for _, _, span in interfaces.library.get_spans(self.req, full_trace=True) - if _retrieve_span_links(span) is not None - and span["trace_id"] == 1 - and span["parent_id"] == 987654321 # Only fetch the trace that is related to the header extractions + (span, span_format, trace_chunk) + for _, trace_chunk, span, span_format in interfaces.library.get_spans(self.req, full_trace=True) + if interfaces.library.get_span_links(span, span_format) is not None + and interfaces.library.get_span_trace_id( + span, trace_chunk if isinstance(trace_chunk, dict) else None, span_format + ) + == 1 + and interfaces.library.get_span_parent_id(span, span_format) + == 987654321 # Only fetch the trace that is related to the header extractions ] assert len(trace) == 0 @@ -119,11 +135,15 @@ def setup_no_span_links_from_invalid_trace_id(self): def test_no_span_links_from_invalid_trace_id(self): trace = [ - span - for _, _, span in interfaces.library.get_spans(self.req, full_trace=True) - if _retrieve_span_links(span) is not None - and span["trace_id"] == 5 - and span["parent_id"] == 987654324 # Only fetch the trace that is related to the header extractions + (span, span_format, trace_chunk) + for _, trace_chunk, span, span_format in interfaces.library.get_spans(self.req, full_trace=True) + if interfaces.library.get_span_links(span, span_format) is not None + and interfaces.library.get_span_trace_id( + span, trace_chunk if isinstance(trace_chunk, dict) else None, span_format + ) + == 5 + and interfaces.library.get_span_parent_id(span, span_format) + == 987654324 # Only fetch the trace that is related to the header extractions ] assert len(trace) == 0 @@ -153,19 +173,24 @@ def setup_span_links_flags_from_conflicting_contexts(self): def test_span_links_flags_from_conflicting_contexts(self): spans = [ - span - for _, _, span in interfaces.library.get_spans(self.req, full_trace=True) - if _retrieve_span_links(span) is not None - and span["trace_id"] == 2 - and span["parent_id"] == 987654321 # Only fetch the trace that is related to the header extractions + (span, span_format, trace_chunk) + for _, trace_chunk, span, span_format in interfaces.library.get_spans(self.req, full_trace=True) + if interfaces.library.get_span_links(span, span_format) is not None + and interfaces.library.get_span_trace_id( + span, trace_chunk if isinstance(trace_chunk, dict) else None, span_format + ) + == 2 + and interfaces.library.get_span_parent_id(span, span_format) + == 987654321 # Only fetch the trace that is related to the header extractions ] if len(spans) != 1: logger.error(json.dumps(spans, indent=2)) raise ValueError(f"Expected 1 span, got {len(spans)}") - span = spans[0] - span_links = _retrieve_span_links(span) + span, span_format, _ = spans[0] + span_links = interfaces.library.get_span_links(span, span_format) + assert span_links is not None assert len(span_links) == 2 link1 = span_links[0] assert link1["flags"] == 1 | TRACECONTEXT_FLAGS_SET @@ -194,53 +219,29 @@ def setup_span_links_omit_tracestate_from_conflicting_contexts(self): def test_span_links_omit_tracestate_from_conflicting_contexts(self): spans = [ - span - for _, _, span in interfaces.library.get_spans(self.req, full_trace=True) - if _retrieve_span_links(span) is not None - and span["trace_id"] == 2 - and span["parent_id"] == 987654321 # Only fetch the trace that is related to the header extractions + (span, span_format, trace_chunk) + for _, trace_chunk, span, span_format in interfaces.library.get_spans(self.req, full_trace=True) + if interfaces.library.get_span_links(span, span_format) is not None + and interfaces.library.get_span_trace_id( + span, trace_chunk if isinstance(trace_chunk, dict) else None, span_format + ) + == 2 + and interfaces.library.get_span_parent_id(span, span_format) + == 987654321 # Only fetch the trace that is related to the header extractions ] if len(spans) != 1: logger.error(json.dumps(spans, indent=2)) raise ValueError(f"Expected 1 span, got {len(spans)}") - span = spans[0] - links = _retrieve_span_links(span) + span, span_format, _ = spans[0] + links = interfaces.library.get_span_links(span, span_format) + assert links is not None assert len(links) == 1 link1 = links[0] assert link1.get("tracestate") is None -def _retrieve_span_links(span: dict): - if span.get("span_links") is not None: - return span["span_links"] - - if span["meta"].get("_dd.span_links") is not None: - # Convert span_links tags into msgpack v0.4 format - json_links = json.loads(span["meta"].get("_dd.span_links")) - links = [] - for json_link in json_links: - link = {} - link["trace_id"] = int(json_link["trace_id"][-16:], base=16) - link["span_id"] = int(json_link["span_id"], base=16) - if len(json_link["trace_id"]) > 16: - link["trace_id_high"] = int(json_link["trace_id"][:16], base=16) - if "attributes" in json_link: - link["attributes"] = json_link.get("attributes") - if "tracestate" in json_link: - link["tracestate"] = json_link.get("tracestate") - elif "trace_state" in json_link: - link["tracestate"] = json_link.get("trace_state") - if "flags" in json_link: - link["flags"] = json_link.get("flags") | 1 << 31 - else: - link["flags"] = 0 - links.append(link) - return links - return None - - # The Datadog specific tracecontext flags to mark flags are set TRACECONTEXT_FLAGS_SET = 1 << 31 diff --git a/tests/test_graphql.py b/tests/test_graphql.py index c190ada4e4e..911226dbb4d 100644 --- a/tests/test_graphql.py +++ b/tests/test_graphql.py @@ -66,7 +66,7 @@ def test_execute_error_span_event(self): spans = list( span - for _, _, span in interfaces.library.get_spans(request=self.request, full_trace=True) + for _, _, span, _ in interfaces.library.get_spans(request=self.request, full_trace=True) if self._is_graphql_execute_span(span) ) diff --git a/tests/test_identify.py b/tests/test_identify.py index 96e35ade20e..bac3aa18869 100644 --- a/tests/test_identify.py +++ b/tests/test_identify.py @@ -3,25 +3,27 @@ # Copyright 2021 Datadog, Inc. from utils import weblog, interfaces, rfc, features +from utils.dd_constants import TraceLibraryPayloadFormat -def assert_tag_in_span_meta(span: dict, tag: str, expected: str): - if tag not in span["meta"]: +def assert_tag_in_span_meta(span: dict, tag: str, expected: str, span_format: TraceLibraryPayloadFormat | None = None): + meta = interfaces.library.get_span_meta(span, span_format) + if tag not in meta: raise Exception(f"Can't find {tag} in span's meta") - val = span["meta"][tag] + val = meta[tag] if val != expected: raise Exception(f"{tag} value is '{val}', should be '{expected}'") def validate_identify_tags(tags: dict[str, str] | list[str]): - def inner_validate(span: dict): + def inner_validate(span: dict, span_format: TraceLibraryPayloadFormat | None): for tag in tags: if isinstance(tags, dict): - assert_tag_in_span_meta(span, tag, tags[tag]) + assert_tag_in_span_meta(span, tag, tags[tag], span_format) else: full_tag = f"usr.{tag}" - assert_tag_in_span_meta(span, full_tag, full_tag) + assert_tag_in_span_meta(span, full_tag, full_tag, span_format) return True return inner_validate @@ -97,8 +99,9 @@ def setup_identify_tags_incoming(self): def test_identify_tags_incoming(self): """With W3C : this test expect to fail with DD_TRACE_PROPAGATION_STYLE_INJECT=W3C""" - def usr_id_not_present(span: dict): - if "usr.id" in span["meta"]: + def usr_id_not_present(span: dict, span_format: TraceLibraryPayloadFormat | None): + meta = interfaces.library.get_span_meta(span, span_format) + if "usr.id" in meta: raise Exception("usr.id must not be present in this span") return True diff --git a/tests/test_library_conf.py b/tests/test_library_conf.py index 1918d80487c..5c82e2ebd8d 100644 --- a/tests/test_library_conf.py +++ b/tests/test_library_conf.py @@ -40,9 +40,11 @@ class Test_HeaderTags: def test_trace_header_tags_basic(self): """Test that http.request.headers.user-agent is in all web spans""" - for _, span in interfaces.library.get_root_spans(): - if span.get("type") == "web": - assert "http.request.headers.user-agent" in span.get("meta", {}) + for _, span, span_format in interfaces.library.get_root_spans(): + span_type = interfaces.library.get_span_type(span, span_format) + if span_type == "web": + meta = interfaces.library.get_span_meta(span, span_format) + assert "http.request.headers.user-agent" in meta @scenarios.library_conf_custom_header_tags @@ -57,9 +59,10 @@ def setup_trace_header_tags(self): def test_trace_header_tags(self): tags = {TAG_SHORT: HEADER_VAL_BASIC} - for _, _, span in interfaces.library.get_spans(request=self.r): + for _, _, span, span_format in interfaces.library.get_spans(request=self.r): + meta = interfaces.library.get_span_meta(span, span_format) for tag in tags: - assert tag in span["meta"] + assert tag in meta @scenarios.library_conf_custom_header_tags @@ -74,9 +77,10 @@ def setup_trace_header_tags(self): def test_trace_header_tags(self): tags = {TAG_LONG: HEADER_VAL_BASIC} - for _, _, span in interfaces.library.get_spans(request=self.r): + for _, _, span, span_format in interfaces.library.get_spans(request=self.r): + meta = interfaces.library.get_span_meta(span, span_format) for tag in tags: - assert tag in span["meta"] + assert tag in meta @scenarios.library_conf_custom_header_tags @@ -93,9 +97,10 @@ def setup_trace_header_tags(self): def test_trace_header_tags(self): tags = {TAG_WHITESPACE_HEADER: HEADER_VAL_BASIC} - for _, _, span in interfaces.library.get_spans(request=self.r): + for _, _, span, span_format in interfaces.library.get_spans(request=self.r): + meta = interfaces.library.get_span_meta(span, span_format) for tag in tags: - assert tag in span["meta"] + assert tag in meta @scenarios.library_conf_custom_header_tags @@ -112,9 +117,10 @@ def setup_trace_header_tags(self): def test_trace_header_tags(self): tags = {TAG_WHITESPACE_TAG: HEADER_VAL_BASIC} - for _, _, span in interfaces.library.get_spans(request=self.r): + for _, _, span, span_format in interfaces.library.get_spans(request=self.r): + meta = interfaces.library.get_span_meta(span, span_format) for tag in tags: - assert tag in span["meta"] + assert tag in meta @scenarios.library_conf_custom_header_tags @@ -131,9 +137,10 @@ def setup_trace_header_tags(self): def test_trace_header_tags(self): tags = {TAG_WHITESPACE_VAL_SHORT: HEADER_VAL_WHITESPACE_VAL_SHORT.strip()} - for _, _, span in interfaces.library.get_spans(request=self.r): + for _, _, span, span_format in interfaces.library.get_spans(request=self.r): + meta = interfaces.library.get_span_meta(span, span_format) for tag in tags: - assert tag in span["meta"] + assert tag in meta @scenarios.library_conf_custom_header_tags @@ -150,9 +157,10 @@ def setup_trace_header_tags(self): def test_trace_header_tags(self): tags = {TAG_WHITESPACE_VAL_LONG: HEADER_VAL_WHITESPACE_VAL_LONG.strip()} - for _, _, span in interfaces.library.get_spans(request=self.r): + for _, _, span, span_format in interfaces.library.get_spans(request=self.r): + meta = interfaces.library.get_span_meta(span, span_format) for tag in tags: - assert tag in span["meta"] + assert tag in meta @scenarios.library_conf_custom_header_tags_invalid @@ -172,9 +180,10 @@ def test_trace_header_tags(self): CONFIG_COLON_LEADING.split(":")[1], ] - for _, _, span in interfaces.library.get_spans(request=self.r): + for _, _, span, span_format in interfaces.library.get_spans(request=self.r): + meta = interfaces.library.get_span_meta(span, span_format) for tag in nottags: - assert tag not in span["meta"] + assert tag not in meta @scenarios.library_conf_custom_header_tags_invalid @@ -194,9 +203,10 @@ def test_trace_header_tags(self): CONFIG_COLON_TRAILING.split(":")[1], ] - for _, _, span in interfaces.library.get_spans(request=self.r): + for _, _, span, span_format in interfaces.library.get_spans(request=self.r): + meta = interfaces.library.get_span_meta(span, span_format) for tag in nottags: - assert tag not in span["meta"] + assert tag not in meta @scenarios.library_conf_custom_header_tags @@ -243,29 +253,41 @@ def test_tracing_client_http_header_tags(self): Requests are made to the test agent. """ # Validate the spans generated by the first request - spans = [span for _, _, span in interfaces.library.get_spans(request=self.req1, full_trace=True)] - for s in spans: - if "/status" in s["resource"]: + spans = [ + (span, span_format) + for _, _, span, span_format in interfaces.library.get_spans(request=self.req1, full_trace=True) + ] + for s, span_format in spans: + # resource is a top-level field in both formats + resource = s.get("resource", "") + if "/status" in resource: + meta = interfaces.library.get_span_meta(s, span_format) # Header tags set via remote config - assert s["meta"].get("test_header_rc") - assert s["meta"].get("test_header_rc2") - assert s["meta"].get("http.request.headers.content-length") + assert meta.get("test_header_rc") + assert meta.get("test_header_rc2") + assert meta.get("http.request.headers.content-length") # Does not have headers set via Enviorment variables - assert TAG_SHORT not in s["meta"] + assert TAG_SHORT not in meta break else: pytest.fail(f"A span with /status in the resource name was not found {spans}") # Validate the spans generated by the second request - spans = [span for _, _, span in interfaces.library.get_spans(request=self.req2, full_trace=True)] - for s in spans: - if "/status" in s["resource"]: + spans = [ + (span, span_format) + for _, _, span, span_format in interfaces.library.get_spans(request=self.req2, full_trace=True) + ] + for s, span_format in spans: + # resource is a top-level field in both formats + resource = s.get("resource", "") + if "/status" in resource: + meta = interfaces.library.get_span_meta(s, span_format) # Headers tags set via remote config - assert s["meta"].get(TAG_SHORT) == HEADER_VAL_BASIC + assert meta.get(TAG_SHORT) == HEADER_VAL_BASIC # Does not have headers set via remote config - assert "test_header_rc" not in s["meta"], s["meta"] - assert "test_header_rc2" not in s["meta"], s["meta"] - assert "http.request.headers.content-length" in s["meta"], s["meta"] + assert "test_header_rc" not in meta, meta + assert "test_header_rc2" not in meta, meta + assert "http.request.headers.content-length" in meta, meta break else: pytest.fail(f"A span with /status in the resource name was not found {spans}") @@ -336,44 +358,62 @@ def setup_tracing_client_http_header_tags_apm_multiconfig(self): def test_tracing_client_http_header_tags_apm_multiconfig(self): """Ensure the tracing http header tags can be set via RC with the APM_TRACING_MULTICONFIG capability.""" # Validate the spans generated by the first request - spans = [span for _, _, span in interfaces.library.get_spans(request=self.req1, full_trace=True)] - for s in spans: - if "/status" in s["resource"]: + spans = [ + (span, span_format) + for _, _, span, span_format in interfaces.library.get_spans(request=self.req1, full_trace=True) + ] + for s, span_format in spans: + # resource is a top-level field in both formats + resource = s.get("resource", "") + if "/status" in resource: + meta = interfaces.library.get_span_meta(s, span_format) # Header tags set via remote config - assert s["meta"].get("test_header_rc") - assert s["meta"].get("test_header_rc2") - assert s["meta"].get("http.request.headers.content-length") + assert meta.get("test_header_rc") + assert meta.get("test_header_rc2") + assert meta.get("http.request.headers.content-length") # Does not have headers set via Enviorment variables - assert TAG_SHORT not in s["meta"] + assert TAG_SHORT not in meta break else: pytest.fail(f"A span with /status in the resource name was not found {spans}") # Validate the spans generated by the second request - spans = [span for _, _, span in interfaces.library.get_spans(request=self.req2, full_trace=True)] - for s in spans: - if "/status" in s["resource"]: + spans = [ + (span, span_format) + for _, _, span, span_format in interfaces.library.get_spans(request=self.req2, full_trace=True) + ] + for s, span_format in spans: + # resource is a top-level field in both formats + resource = s.get("resource", "") + if "/status" in resource: + meta = interfaces.library.get_span_meta(s, span_format) # Headers tags set via remote config - assert s["meta"].get(TAG_SHORT) == HEADER_VAL_BASIC - assert s["meta"].get("test_header_rc_override") + assert meta.get(TAG_SHORT) == HEADER_VAL_BASIC + assert meta.get("test_header_rc_override") # Does not have headers set via remote config - assert "test_header_rc2" not in s["meta"], s["meta"] - assert "http.request.headers.content-length" in s["meta"], s["meta"] + assert "test_header_rc2" not in meta, meta + assert "http.request.headers.content-length" in meta, meta break else: pytest.fail(f"A span with /status in the resource name was not found {spans}") # Validate the spans generated by the third request. This should be identical to the first request, because # we deleted the config with the weblog service and env. - spans = [span for _, _, span in interfaces.library.get_spans(request=self.req3, full_trace=True)] - for s in spans: - if "/status" in s["resource"]: + spans = [ + (span, span_format) + for _, _, span, span_format in interfaces.library.get_spans(request=self.req3, full_trace=True) + ] + for s, span_format in spans: + # resource is a top-level field in both formats + resource = s.get("resource", "") + if "/status" in resource: + meta = interfaces.library.get_span_meta(s, span_format) # Header tags set via remote config - assert s["meta"].get("test_header_rc") - assert s["meta"].get("test_header_rc2") - assert s["meta"].get("http.request.headers.content-length") + assert meta.get("test_header_rc") + assert meta.get("test_header_rc2") + assert meta.get("http.request.headers.content-length") # Does not have headers set via Enviorment variables - assert TAG_SHORT not in s["meta"] + assert TAG_SHORT not in meta break else: pytest.fail(f"A span with /status in the resource name was not found {spans}") diff --git a/tests/test_resource_renaming.py b/tests/test_resource_renaming.py index 114abf2363c..247894f492c 100644 --- a/tests/test_resource_renaming.py +++ b/tests/test_resource_renaming.py @@ -4,9 +4,10 @@ def get_endpoint_tag(response: HttpResponse) -> str | None: spans = interfaces.library.get_spans(response) - for _, _, span in spans: - if "http.endpoint" in span.get("meta", {}): - return span["meta"]["http.endpoint"] + for _, _, span, span_format in spans: + meta = interfaces.library.get_span_meta(span, span_format) + if "http.endpoint" in meta: + return meta["http.endpoint"] return None diff --git a/tests/test_sampling_rates.py b/tests/test_sampling_rates.py index 7060bf39c64..e1cf20f1973 100644 --- a/tests/test_sampling_rates.py +++ b/tests/test_sampling_rates.py @@ -10,7 +10,7 @@ from urllib.parse import urlparse from utils import weblog, interfaces, context, scenarios, features, logger -from utils.dd_constants import SamplingPriority +from utils.dd_constants import SamplingPriority, TraceLibraryPayloadFormat """Those are the constants used by the sampling algorithm in all the tracers @@ -21,11 +21,20 @@ MAX_UINT64 = 2**64 - 1 -def get_trace_request_path(root_span: dict) -> str | None: - if root_span.get("type") != "web": +def get_trace_request_path(root_span: dict, span_format: TraceLibraryPayloadFormat | None = None) -> str | None: + if span_format is None: + # Auto-detect format + if "attributes" in root_span and "meta" not in root_span: + span_format = TraceLibraryPayloadFormat.v1 + else: + span_format = TraceLibraryPayloadFormat.v04 + + span_type = interfaces.library.get_span_type(root_span, span_format) + if span_type != "web": return None - url = root_span["meta"].get("http.url") + meta = interfaces.library.get_span_meta(root_span, span_format) + url = meta.get("http.url") if url is None: return None @@ -36,8 +45,8 @@ def get_trace_request_path(root_span: dict) -> str | None: def assert_all_traces_requests_forwarded(paths: list[str] | set[str]) -> None: path_to_logfile = {} - for data, span in interfaces.library.get_root_spans(): - path = get_trace_request_path(span) + for data, span, span_format in interfaces.library.get_root_spans(): + path = get_trace_request_path(span, span_format) path_to_logfile[path] = data["log_filename"] has_error = False @@ -72,14 +81,25 @@ def trace_should_be_kept(sampling_rate: float, trace_id: int): def _spans_with_parent(traces: list, parent_ids: list): + """Extract spans with matching parent_ids from traces. + Returns (span, span_format) tuples. + """ if not isinstance(traces, list): logger.error("Traces should be an array") yield from [] # do not fail here, it's schema's job else: for trace in traces: for span in trace: - if span.get("parent_id") in parent_ids: - yield span + # Detect format: v1 spans have "attributes" at top level, v04 have "meta" + if isinstance(span, dict): + if "attributes" in span and "meta" not in span: + span_format = TraceLibraryPayloadFormat.v1 + else: + span_format = TraceLibraryPayloadFormat.v04 + + parent_id = interfaces.library.get_span_parent_id(span, span_format) + if parent_id in parent_ids: + yield span, span_format def generate_request_id() -> Generator[int, Any, Any]: @@ -114,8 +134,8 @@ def test_sampling_rates(self): # test sampling sampled_count = {True: 0, False: 0} - for data, root_span in interfaces.library.get_root_spans(): - metrics = root_span["metrics"] + for data, root_span, span_format in interfaces.library.get_root_spans(): + metrics = interfaces.library.get_span_metrics(root_span, span_format) assert "_sampling_priority_v1" in metrics, f"_sampling_priority_v1 is missing in {data['log_filename']}" sampled_count[priority_should_be_kept(metrics["_sampling_priority_v1"])] += 1 @@ -149,8 +169,9 @@ def setup_sampling_decision(self): def test_sampling_decision(self): """Verify that traces are sampled following the sample rate""" - def validator(datum: dict, root_span: dict): - sampling_priority = root_span["metrics"].get("_sampling_priority_v1") + def validator(datum: dict, root_span: dict, span_format: TraceLibraryPayloadFormat | None): + metrics = interfaces.library.get_span_metrics(root_span, span_format) + sampling_priority = metrics.get("_sampling_priority_v1") if sampling_priority is None: raise ValueError( f"Message: {datum['log_filename']}:" @@ -158,21 +179,23 @@ def validator(datum: dict, root_span: dict): ) sampling_decision = priority_should_be_kept(sampling_priority) - expected_decision = trace_should_be_kept(context.tracer_sampling_rate, root_span["trace_id"]) + trace_id = interfaces.library.get_span_trace_id(root_span, None, span_format) + expected_decision = trace_should_be_kept(context.tracer_sampling_rate, trace_id) if sampling_decision != expected_decision: - if sampling_decision and root_span["meta"].get("_dd.p.dm") == "-5": + meta = interfaces.library.get_span_meta(root_span, span_format) + if sampling_decision and meta.get("_dd.p.dm") == "-5": # If the decision maker is set to -5, it means that the trace has been sampled due # to AppSec, it should not impact this test and should be ignored. # In this case it is most likely the Healthcheck as it is the first request # and AppSec WAF always samples the first request. return raise ValueError( - f"Trace id {root_span['trace_id']}, sampling priority {sampling_priority}, " + f"Trace id {trace_id}, sampling priority {sampling_priority}, " f"sampling decision {sampling_decision} differs from the expected {expected_decision}" ) - for data, span in interfaces.library.get_root_spans(): - validator(data, span) + for data, span, span_format in interfaces.library.get_root_spans(): + validator(data, span, span_format) @scenarios.sampling @@ -196,20 +219,26 @@ def test_sampling_decision_added(self): spans = [] def validator(data: dict): - for span in _spans_with_parent(data["request"]["content"], list(traces.keys())): - expected_trace_id = traces[span["parent_id"]]["trace_id"] + for span, span_format in _spans_with_parent(data["request"]["content"], list(traces.keys())): + parent_id = interfaces.library.get_span_parent_id(span, span_format) + if parent_id is None or parent_id not in traces: + continue + expected_trace_id = traces[parent_id]["trace_id"] spans.append(span) - assert span["trace_id"] == expected_trace_id, ( + trace_id = interfaces.library.get_span_trace_id(span, None, span_format) + assert trace_id == expected_trace_id, ( f"Message: {data['log_filename']}: If parent_id matches, " f"trace_id should match too expected trace_id {expected_trace_id} " - f"span trace_id : {span['trace_id']}, span parent_id : {span['parent_id']}", + f"span trace_id : {trace_id}, span parent_id : {parent_id}", ) - sampling_priority = span["metrics"].get("_sampling_priority_v1") + metrics = interfaces.library.get_span_metrics(span, span_format) + sampling_priority = metrics.get("_sampling_priority_v1") + span_id = interfaces.library.get_span_span_id(span, span_format) assert sampling_priority is not None, ( - f"Message: {data['log_filename']}: sampling priority should be set on span {span['span_id']}", + f"Message: {data['log_filename']}: sampling priority should be set on span {span_id}", ) interfaces.library.validate_all(validator, path_filters=["/v0.4/traces", "/v0.5/traces"], allow_no_data=True) @@ -246,15 +275,20 @@ def test_sampling_determinism(self): sampling_decisions_per_trace_id = defaultdict(list) def validator(data: dict): - for span in _spans_with_parent(data["request"]["content"], list(traces.keys())): - expected_trace_id = traces[(span["parent_id"])]["trace_id"] - sampling_priority = span["metrics"].get("_sampling_priority_v1") - sampling_decisions_per_trace_id[span["trace_id"]].append(sampling_priority) + for span, span_format in _spans_with_parent(data["request"]["content"], list(traces.keys())): + parent_id = interfaces.library.get_span_parent_id(span, span_format) + if parent_id is None or parent_id not in traces: + continue + expected_trace_id = traces[parent_id]["trace_id"] + trace_id = interfaces.library.get_span_trace_id(span, None, span_format) + metrics = interfaces.library.get_span_metrics(span, span_format) + sampling_priority = metrics.get("_sampling_priority_v1") + sampling_decisions_per_trace_id[trace_id].append(sampling_priority) - assert span["trace_id"] == expected_trace_id, ( + assert trace_id == expected_trace_id, ( f"Message: {data['log_filename']}: If parent_id matches, " f"trace_id should match too expected trace_id {expected_trace_id} " - f"span trace_id : {span['trace_id']}, span parent_id : {span['parent_id']}", + f"span trace_id : {trace_id}, span parent_id : {parent_id}", ) assert sampling_priority is not None, ( @@ -316,10 +350,14 @@ def test_sample_rate_function(self): # Ensure the request succeeded, any failure would make the test incorrect. assert req.status_code == 200, "Call to /sample_rate_route/:i failed" - for data, _, span in interfaces.library.get_spans(request=req): + for data, trace, span, span_format in interfaces.library.get_spans(request=req): # Validate the sampling decision - trace_id = span["trace_id"] - sampling_priority = span["metrics"].get("_sampling_priority_v1") + trace_dict: dict | None = ( + trace if span_format == TraceLibraryPayloadFormat.v1 and isinstance(trace, dict) else None + ) + trace_id = interfaces.library.get_span_trace_id(span, trace_dict, span_format) + metrics = interfaces.library.get_span_metrics(span, span_format) + sampling_priority = metrics.get("_sampling_priority_v1") logger.info(f"Trying to validate trace_id:{trace_id} from {data['log_filename']}") logger.info(f"Sampling priority: {sampling_priority}") assert sampling_priority is not None, ( diff --git a/tests/test_scrubbing.py b/tests/test_scrubbing.py index 8a1a42994c2..f742364a9b2 100644 --- a/tests/test_scrubbing.py +++ b/tests/test_scrubbing.py @@ -6,6 +6,7 @@ import re from utils import context, interfaces, rfc, weblog, missing_feature, features, scenarios, logger +from utils.dd_constants import TraceLibraryPayloadFormat def validate_no_leak(needle: str, whitelist_pattern: str | None = None) -> Callable[[dict], None]: @@ -81,10 +82,22 @@ def test_main(self): assert self.r.status_code == 200 def validate_report(trace: list): + # For v1 format, trace is a list of spans from a chunk + # For v04 format, trace is a list of spans + # We need to detect format and use appropriate helper methods for span in trace: - if span.get("type") == "http": - logger.info(f"span found: {span}") - return "agent:8127" in span["meta"]["http.url"] + if isinstance(span, dict): + # Detect format: v1 spans have "attributes" at top level, v04 have "meta" + if "attributes" in span and "meta" not in span: + span_format = TraceLibraryPayloadFormat.v1 + else: + span_format = TraceLibraryPayloadFormat.v04 + + span_type = interfaces.library.get_span_type(span, span_format) + if span_type == "http": + logger.info(f"span found: {span}") + meta = interfaces.library.get_span_meta(span, span_format) + return "agent:8127" in meta.get("http.url", "") return False diff --git a/tests/test_semantic_conventions.py b/tests/test_semantic_conventions.py index 152f54b1b6e..1048e95d755 100644 --- a/tests/test_semantic_conventions.py +++ b/tests/test_semantic_conventions.py @@ -6,6 +6,7 @@ from urllib.parse import urlparse from utils import context, interfaces, features, scenarios +from utils.dd_constants import TraceLibraryPayloadFormat RUNTIME_LANGUAGE_MAP = { @@ -173,15 +174,17 @@ class Test_Meta: def test_meta_span_kind(self): """Validates that traces from an http framework carry a span.kind meta tag, with value server or client""" - def validator(span: dict): + def validator(span: dict, span_format: TraceLibraryPayloadFormat | None): if span.get("parent_id") not in (0, None): # do nothing if not root span return None - if span.get("type") != "web": # do nothing if is not web related + span_type = interfaces.library.get_span_type(span, span_format) + if span_type != "web": # do nothing if is not web related return None - assert "span.kind" in span["meta"], "Web span expects a span.kind meta tag" - assert span["meta"]["span.kind"] in ["server", "client"], "Meta tag span.kind should be client or server" + meta = interfaces.library.get_span_meta(span, span_format) + assert "span.kind" in meta, "Web span expects a span.kind meta tag" + assert meta["span.kind"] in ["server", "client"], "Meta tag span.kind should be client or server" return True @@ -190,16 +193,18 @@ def validator(span: dict): def test_meta_http_url(self): """Validates that traces from an http framework carry a http.url meta tag, formatted as a URL""" - def validator(span: dict): + def validator(span: dict, span_format: TraceLibraryPayloadFormat | None): if span.get("parent_id") not in (0, None): # do nothing if not root span return None - if span.get("type") != "web": # do nothing if is not web related + span_type = interfaces.library.get_span_type(span, span_format) + if span_type != "web": # do nothing if is not web related return None - assert "http.url" in span["meta"], "web span expect an http.url meta tag" + meta = interfaces.library.get_span_meta(span, span_format) + assert "http.url" in meta, "web span expect an http.url meta tag" - scheme = urlparse(span["meta"]["http.url"]).scheme + scheme = urlparse(meta["http.url"]).scheme assert scheme in ["http", "https"], f"Meta http.url's scheme should be http or https, not {scheme}" return True @@ -209,16 +214,18 @@ def validator(span: dict): def test_meta_http_status_code(self): """Validates that traces from an http framework carry a http.status_code meta tag, formatted as a int""" - def validator(span: dict): + def validator(span: dict, span_format: TraceLibraryPayloadFormat | None): if span.get("parent_id") not in (0, None): # do nothing if not root span return None - if span.get("type") != "web": # do nothing if is not web related + span_type = interfaces.library.get_span_type(span, span_format) + if span_type != "web": # do nothing if is not web related return None - assert "http.status_code" in span["meta"], "web span expect an http.status_code meta tag" + meta = interfaces.library.get_span_meta(span, span_format) + assert "http.status_code" in meta, "web span expect an http.status_code meta tag" - _ = int(span["meta"]["http.status_code"]) + _ = int(meta["http.status_code"]) return True @@ -227,16 +234,18 @@ def validator(span: dict): def test_meta_http_method(self): """Validates that traces from an http framework carry a http.method meta tag, with a legal HTTP method""" - def validator(span: dict): + def validator(span: dict, span_format: TraceLibraryPayloadFormat | None): if span.get("parent_id") not in (0, None): # do nothing if not root span return None - if span.get("type") != "web": # do nothing if is not web related + span_type = interfaces.library.get_span_type(span, span_format) + if span_type != "web": # do nothing if is not web related return None - assert "http.method" in span["meta"], "web span expect an http.method meta tag" + meta = interfaces.library.get_span_meta(span, span_format) + assert "http.method" in meta, "web span expect an http.method meta tag" - value = span["meta"]["http.method"] + value = meta["http.method"] assert isinstance(value, (str, bytes)), "Method should always be a string" @@ -261,16 +270,17 @@ def validator(span: dict): def test_meta_language_tag(self): """Assert that all spans have required language tag.""" - def validator(span: dict): + def validator(span: dict, span_format: TraceLibraryPayloadFormat | None): if span.get("parent_id") not in (0, None): # do nothing if not root span return - assert "language" in span["meta"], "Span must have a language tag set." + meta = interfaces.library.get_span_meta(span, span_format) + assert "language" in meta, "Span must have a language tag set." library = context.library.name expected_language = RUNTIME_LANGUAGE_MAP.get(library, library) - actual_language = span["meta"]["language"] + actual_language = meta["language"] assert actual_language == expected_language, ( f"Span actual language, {actual_language}, did not match expected language, {expected_language}." ) @@ -282,24 +292,27 @@ def validator(span: dict): def test_meta_component_tag(self): """Assert that all spans generated from a weblog_variant have component metadata tag matching integration name.""" - def validator(span: dict): - if span.get("type") != "web": # do nothing if is not web related + def validator(span: dict, span_format: TraceLibraryPayloadFormat | None): + span_type = interfaces.library.get_span_type(span, span_format) + if span_type != "web": # do nothing if is not web related return - expected_component = get_component_name(span["name"]) + span_name = interfaces.library.get_span_name(span, span_format) + expected_component = get_component_name(span_name) - assert "component" in span.get("meta", {}), ( - f"No component tag found. Expected span {span['name']} component to be: {expected_component}." + meta = interfaces.library.get_span_meta(span, span_format) + assert "component" in meta, ( + f"No component tag found. Expected span {span_name} component to be: {expected_component}." ) - actual_component = span["meta"]["component"] + actual_component = meta["component"] if isinstance(expected_component, list): - exception_message = f"""Expected span {span["name"]} to have component meta tag equal + exception_message = f"""Expected span {span_name} to have component meta tag equal to one of the following, [{expected_component}], got: {actual_component}.""" assert actual_component in expected_component, exception_message else: - exception_message = f"Expected span {span['name']} to have component meta tag, {expected_component}, got: {actual_component}." + exception_message = f"Expected span {span_name} to have component meta tag, {expected_component}, got: {actual_component}." assert actual_component == expected_component, exception_message interfaces.library.validate_all_spans(validator=validator, allow_no_data=True) @@ -309,11 +322,12 @@ def validator(span: dict): def test_meta_runtime_id_tag(self): """Assert that all spans generated from a weblog_variant have runtime-id metadata tag with some value.""" - def validator(span: dict): + def validator(span: dict, span_format: TraceLibraryPayloadFormat | None): if span.get("parent_id") not in (0, None): # do nothing if not root span return - assert "runtime-id" in span["meta"], "No runtime-id tag found. Expected tag to be present." + meta = interfaces.library.get_span_meta(span, span_format) + assert "runtime-id" in meta, "No runtime-id tag found. Expected tag to be present." interfaces.library.validate_all_spans(validator=validator, allow_no_data=True) # checking that we have at least one root span @@ -325,12 +339,11 @@ class Test_MetaDatadogTags: """Spans carry meta tags that were set in DD_TAGS tracer environment""" def test_meta_dd_tags(self): - def validator(span: dict): - assert span["meta"]["key1"] == "val1", ( - f'keyTag tag in span\'s meta should be "test", not {span["meta"]["env"]}' - ) - assert span["meta"]["key2"] == "val2", ( - f'dKey tag in span\'s meta should be "key2:val2", not {span["meta"]["key2"]}' + def validator(span: dict, span_format: TraceLibraryPayloadFormat | None): + meta = interfaces.library.get_span_meta(span, span_format) + assert meta["key1"] == "val1", f'keyTag tag in span\'s meta should be "test", not {meta.get("env", "N/A")}' + assert meta["key2"] == "val2", ( + f'dKey tag in span\'s meta should be "key2:val2", not {meta.get("key2", "N/A")}' ) return True @@ -346,7 +359,8 @@ class Test_MetricsStandardTags: def test_metrics_process_id(self): """Validates that root spans from traces contain a process_id field""" - spans = [s for _, s in interfaces.library.get_root_spans()] - assert spans, "Did not receive any root spans to validate." - for span in spans: - assert "process_id" in span["metrics"], "Root span expect a process_id metrics tag" + spans_with_format = [(s, f) for _, s, f in interfaces.library.get_root_spans()] + assert spans_with_format, "Did not receive any root spans to validate." + for span, span_format in spans_with_format: + metrics = interfaces.library.get_span_metrics(span, span_format) + assert "process_id" in metrics, "Root span expect a process_id metrics tag" diff --git a/tests/test_span_events.py b/tests/test_span_events.py index eb8ad316a9a..955632b07ef 100644 --- a/tests/test_span_events.py +++ b/tests/test_span_events.py @@ -22,8 +22,8 @@ def setup_v04_v07_default_format(self): def test_v04_v07_default_format(self): """For traces that default to the v0.4 or v0.7 format, send events as a top-level `span_events` field""" interfaces.library.assert_trace_exists(self.r) - span = interfaces.library.get_root_span(self.r) - meta = span.get("meta", {}) + span, span_format = interfaces.library.get_root_span(self.r) + meta = interfaces.library.get_span_meta(span, span_format) assert "span_events" in span assert "events" not in meta @@ -36,8 +36,8 @@ def test_v05_default_format(self): given this format does not support native serialization. """ interfaces.library.assert_trace_exists(self.r) - span = interfaces.library.get_root_span(self.r) - meta = span.get("meta", {}) + span, span_format = interfaces.library.get_root_span(self.r) + meta = interfaces.library.get_span_meta(span, span_format) assert "span_events" not in span assert "events" in meta @@ -57,7 +57,7 @@ def setup_send_as_a_tag(self): def test_send_as_a_tag(self): """Send span events as the tag `events` when the agent does not support native serialization""" interfaces.library.assert_trace_exists(self.r) - span = interfaces.library.get_root_span(self.r) - meta = span.get("meta", {}) + span, span_format = interfaces.library.get_root_span(self.r) + meta = interfaces.library.get_span_meta(span, span_format) assert "span_events" not in span assert "events" in meta diff --git a/tests/test_standard_tags.py b/tests/test_standard_tags.py index 09fafe432ec..642bda0d75e 100644 --- a/tests/test_standard_tags.py +++ b/tests/test_standard_tags.py @@ -308,8 +308,8 @@ def test_client_ip_with_appsec_event_and_vendor_headers(self): assert meta[tag] == value def _get_root_span_meta(self, request: HttpResponse): - span = interfaces.library.get_root_span(request) - return span.get("meta", {}) + span, span_format = interfaces.library.get_root_span(request) + return interfaces.library.get_span_meta(span, span_format) @features.referrer_hostname @@ -372,5 +372,5 @@ def test_referrer_hostname(self): ) def _get_root_span_meta(self, request: HttpResponse): - span = interfaces.library.get_root_span(request) - return span.get("meta", {}) + span, span_format = interfaces.library.get_root_span(request) + return interfaces.library.get_span_meta(span, span_format) diff --git a/tests/test_the_test/test_deserializer.py b/tests/test_the_test/test_deserializer.py index 2c2c3fd2021..a47966a66ae 100644 --- a/tests/test_the_test/test_deserializer.py +++ b/tests/test_the_test/test_deserializer.py @@ -5,140 +5,269 @@ @scenarios.test_the_test -def test_deserialize_http_message(): - content = msgpack.packb( - { - 2: "hello", - 11: [ +class Test_Deserializer: + def test_deserialize_http_message(self): + content = msgpack.packb( + { + 2: "hello", + 11: [ + { + 1: 1, + 2: "rum", + 3: ["some-global", 1, "cool-value", 1, 1, 2], + 4: [ + { + 1: "my-service", + 2: "span-name", + 3: 1, + 4: 1234, + 5: 5555, + 6: 987, + 7: 150, + 8: True, + 9: ["foo", 1, "bar", "fooNum", 3, 3.14], + 10: "span-type", + 13: "some-env", + 14: "my-version", + 15: "my-component", + 16: 1, + } + ], + 6: bytes( + [ + 0x00, + 0x00, + 0x00, + 0x00, + 0x00, + 0x00, + 0x00, + 0x55, + 0x00, + 0x00, + 0x00, + 0x00, + 0x00, + 0x00, + 0x21, + 0xE3, + ] + ), + 7: 4, + } + ], + } + ) + + result = deserialize_v1_trace(content=content) + + assert result == { + "container_id": "hello", + "chunks": [ { - 1: 1, - 2: "rum", - 3: ["some-global", 1, "cool-value", 1, 1, 2], - 4: [ + "spans": [ { - 1: "my-service", - 2: "span-name", - 3: 1, - 4: 1234, - 5: 5555, - 6: 987, - 7: 150, - 8: True, - 9: ["foo", 1, "bar", "fooNum", 3, 3.14], - 10: "span-type", - 13: "some-env", - 14: "my-version", - 15: "my-component", - 16: 1, + "service": "my-service", + "name_value": "span-name", + "resource": "hello", + "span_id": 1234, + "parent_id": 5555, + "component": "my-component", + "span_kind": 1, + "version": "my-version", + "env": "some-env", + "start": 987, + "duration": 150, + "error": True, + "attributes": {"foo": "bar", "fooNum": 3.14}, + "type_value": "span-type", } ], - 6: bytes( - [0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x55, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x21, 0xE3] - ), - 7: 4, + "priority": 1, + "origin": "rum", + "attributes": {"some-global": "cool-value", "hello": "rum"}, + "trace_id": "0x000000000000005500000000000021E3", + "sampling_mechanism": 4, } ], } - ) - result = deserialize_v1_trace(content=content) + def test_uncompress_agent_v1_trace_with_span_links(self): + """Test that span links traceID is properly deserialized from base64 in idxTracerPayloads.""" + # Create a 16-byte trace ID and encode it as base64 (mimics what protobuf returns) + trace_id_bytes = bytes( + [0x12, 0x34, 0x56, 0x78, 0x90, 0x12, 0x34, 0x56, 0x78, 0x90, 0x12, 0x34, 0x56, 0x78, 0x90, 0x12] + ) + trace_id_base64 = base64.b64encode(trace_id_bytes).decode("utf-8") - assert result == { - "container_id": "hello", - "chunks": [ - { - "spans": [ - { - "service": "my-service", - "name_value": "span-name", - "resource": "hello", - "span_id": 1234, - "parent_id": 5555, - "component": "my-component", - "span_kind": 1, - "version": "my-version", - "env": "some-env", - "start": 987, - "duration": 150, - "error": True, - "attributes": {"foo": "bar", "fooNum": 3.14}, - "type_value": "span-type", - } - ], - "priority": 1, - "origin": "rum", - "attributes": {"some-global": "cool-value", "hello": "rum"}, - "trace_id": "0x000000000000005500000000000021E3", - "sampling_mechanism": 4, - } - ], - } + # Chunk traceID also needs to be base64 encoded + chunk_trace_id_bytes = bytes( + [0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x55, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x21, 0xE3] + ) + chunk_trace_id_base64 = base64.b64encode(chunk_trace_id_bytes).decode("utf-8") + # Simulated data structure as returned by protobuf MessageToDict + data = { + "idxTracerPayloads": [ + { + "strings": ["", "my-service", "span-name", "web", "link-key", "link-value", "tracestate-value"], + "attributes": {}, + "chunks": [ + { + "traceID": chunk_trace_id_base64, + "spans": [ + { + "service": "my-service", + "name_value": "span-name", + "typeRef": "web", + "attributes": { + "4": {"stringValueRef": 5} # "link-key": "link-value" + }, + "links": [ + { + "traceID": trace_id_base64, + "spanID": "987654321", + "attributes": { + "4": {"stringValueRef": 5} # "link-key": "link-value" + }, + "tracestateRef": 6, + "flags": 2147483649, + } + ], + } + ], + "attributes": {}, + } + ], + } + ] + } -@scenarios.test_the_test -def test_uncompress_agent_v1_trace_with_span_links(): - """Test that span links traceID is properly deserialized from base64 in idxTracerPayloads.""" - # Create a 16-byte trace ID and encode it as base64 (mimics what protobuf returns) - trace_id_bytes = bytes( - [0x12, 0x34, 0x56, 0x78, 0x90, 0x12, 0x34, 0x56, 0x78, 0x90, 0x12, 0x34, 0x56, 0x78, 0x90, 0x12] - ) - trace_id_base64 = base64.b64encode(trace_id_bytes).decode("utf-8") - - # Chunk traceID also needs to be base64 encoded - chunk_trace_id_bytes = bytes( - [0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x55, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x21, 0xE3] - ) - chunk_trace_id_base64 = base64.b64encode(chunk_trace_id_bytes).decode("utf-8") - - # Simulated data structure as returned by protobuf MessageToDict - data = { - "idxTracerPayloads": [ + result = _uncompress_agent_v1_trace(data, "agent") + + # Verify chunk traceID is deserialized + assert result["idxTracerPayloads"][0]["chunks"][0]["traceID"] == "0x000000000000005500000000000021E3" + + # Verify span link traceID is deserialized from base64 to hex + span_link = result["idxTracerPayloads"][0]["chunks"][0]["spans"][0]["links"][0] + assert span_link["traceID"] == "0x12345678901234567890123456789012" + + # Verify span link attributes are uncompressed + assert span_link["attributes"] == {"link-key": "link-value"} + + # Verify span link tracestate is resolved from string reference + assert span_link["tracestate"] == "tracestate-value" + assert "tracestateRef" not in span_link + + def test_deserialize_v1_trace_with_span_links(self): + """Test that span links are properly deserialized in v1 trace format (library interface).""" + # Create a 16-byte trace ID for the span link + link_trace_id_bytes = bytes( + [0x12, 0x34, 0x56, 0x78, 0x90, 0x12, 0x34, 0x56, 0x78, 0x90, 0x12, 0x34, 0x56, 0x78, 0x90, 0x12] + ) + + # Include tracestate string in content so it gets added to strings array + # Strings array: [""] (index 0), then strings are added as encountered + # Order: "hello" (1), "rum" (2), "some-global" (3), "cool-value" (4), + # "my-service" (5), "span-name" (6), "foo" (7), "bar" (8), + # "tracestate-ref" (9), "tracestate-value" (10), "span-type" (11), + # "some-env" (12), "my-version" (13), "my-component" (14), + # "link-key" (15), "link-value" (16) + content = msgpack.packb( { - "strings": ["", "my-service", "span-name", "web", "link-key", "link-value", "tracestate-value"], - "attributes": {}, - "chunks": [ + 2: "hello", + 11: [ { - "traceID": chunk_trace_id_base64, - "spans": [ + 1: 1, + 2: "rum", + 3: ["some-global", 1, "cool-value"], + 4: [ { - "service": "my-service", - "name_value": "span-name", - "typeRef": "web", - "attributes": { - "4": {"stringValueRef": 5} # "link-key": "link-value" - }, - "links": [ + 1: "my-service", + 2: "span-name", + 3: 1, + 4: 1234, + 5: 5555, + 6: 987, + 7: 150, + 8: True, + 9: [ + "foo", + 1, + "bar", + "tracestate-ref", + 1, + "tracestate-value", + ], # Include tracestate string + 10: "span-type", + # Span links: key 11 + # Span link keys: 1=trace_id, 2=span_id, 3=attributes, 4=trace_state, 5=flags + 11: [ { - "traceID": trace_id_base64, - "spanID": "987654321", - "attributes": { - "4": {"stringValueRef": 5} # "link-key": "link-value" - }, - "tracestateRef": 6, - "flags": 2147483649, + 1: link_trace_id_bytes, # trace_id as bytes + 2: 987654321, # span_id + 3: ["link-key", 1, "link-value"], # attributes [key, type, value] + 4: 10, # trace_state as string reference (index into strings array, "tracestate-value") + 5: 2147483649, # flags } ], + 13: "some-env", + 14: "my-version", + 15: "my-component", + 16: 1, } ], - "attributes": {}, + 6: bytes( + [ + 0x00, + 0x00, + 0x00, + 0x00, + 0x00, + 0x00, + 0x00, + 0x55, + 0x00, + 0x00, + 0x00, + 0x00, + 0x00, + 0x00, + 0x21, + 0xE3, + ] + ), + 7: 4, } ], } - ] - } + ) + + result = deserialize_v1_trace(content=content) + + # Verify span links are present and properly deserialized + span = result["chunks"][0]["spans"][0] + assert "span_links" in span + assert len(span["span_links"]) == 1 + + span_link = span["span_links"][0] - result = _uncompress_agent_v1_trace(data, "agent") + # Verify trace_id is deserialized from bytes to hex string (same as chunk trace IDs) + assert "trace_id" in span_link + assert span_link["trace_id"] == "0x12345678901234567890123456789012" + assert isinstance(span_link["trace_id"], str) - # Verify chunk traceID is deserialized - assert result["idxTracerPayloads"][0]["chunks"][0]["traceID"] == "0x000000000000005500000000000021E3" + # Verify span_id is preserved + assert span_link["span_id"] == 987654321 - # Verify span link traceID is deserialized from base64 to hex - span_link = result["idxTracerPayloads"][0]["chunks"][0]["spans"][0]["links"][0] - assert span_link["traceID"] == "0x12345678901234567890123456789012" + # Verify attributes are uncompressed + assert span_link["attributes"] == {"link-key": "link-value"} - # Verify span link attributes are uncompressed - assert span_link["attributes"] == {"link-key": "link-value"} + # Verify trace_state is resolved from string reference to tracestate + # The trace_state index should resolve to "tracestate-value" from the strings array + # The code adds "tracestate" field when trace_state is a valid string reference + assert "tracestate" in span_link + assert span_link["tracestate"] == "tracestate-value" - # Verify span link tracestate is resolved from string reference - assert span_link["tracestate"] == "tracestate-value" - assert "tracestateRef" not in span_link + # Verify flags are preserved + assert span_link["flags"] == 2147483649 diff --git a/utils/build/docker/golang/install_ddtrace.sh b/utils/build/docker/golang/install_ddtrace.sh index 09455e2cc09..c8e0ad27e92 100755 --- a/utils/build/docker/golang/install_ddtrace.sh +++ b/utils/build/docker/golang/install_ddtrace.sh @@ -25,13 +25,16 @@ elif [ -e "/binaries/golang-load-from-go-get" ]; then for line in "${lines[@]}"; do path="${line%@*}" commit="${line#*@}" - # Get the correct pseudo-version using go list - pseudo_version=$(go list -m -json "$path@$commit" | jq -r .Version) + # Fetch the main module at the specified commit - this will update go.mod with the correct pseudo-version + echo "Fetching $path@$commit" + go get "$path@$commit" + # Get the pseudo-version that was added to go.mod + pseudo_version=$(go list -m -f '{{.Version}}' "$path") + echo "Using pseudo-version: $pseudo_version" + # Replace the main module to use the pseudo-version explicitly go mod edit -replace "$path=$path@$pseudo_version" - for contrib in $CONTRIBS; do - echo "Install contrib $contrib from go get -v $contrib@commit" - go mod edit -replace "$contrib=$contrib@$pseudo_version" - done + # For contrib modules, let go mod tidy resolve them based on the main module's requirements + # They will automatically resolve to compatible versions from the same commit break done else diff --git a/utils/dd_constants.py b/utils/dd_constants.py index 5f8a230c285..1716bc1d973 100644 --- a/utils/dd_constants.py +++ b/utils/dd_constants.py @@ -128,3 +128,17 @@ class TraceAgentPayloadFormat(StrEnum): efficient_trace_payload_format = "efficient_trace_payload_format" """ Efficient format introduced in agent version 7.73.0. Uses idxTracerPayloads field instead of tracerPayloads RFC: https://docs.google.com/document/d/1hNS6anKYutOYW-nmR759UlKXUdT6H0mRwVt7_L70ESc/edit?usp=sharing""" + + +class TraceLibraryPayloadFormat(StrEnum): + """Describe which format is used to carry trace payloads from the library to the agent + This enum is used only in system-tests to differentiate between different library payloads + and is not exposed directly in trace payloads. + """ + + v04 = "v04" + """ v0.4/v0.5 format - list of spans with meta/metrics separated""" + + v1 = "v1" + """ v1.0 format - chunks with spans using attributes and name_value/type_value fields + RFC: https://docs.google.com/document/d/1hNS6anKYutOYW-nmR759UlKXUdT6H0mRwVt7_L70ESc/edit?usp=sharing""" diff --git a/utils/interfaces/_agent.py b/utils/interfaces/_agent.py index 392f4b1b3b5..8c53748c669 100644 --- a/utils/interfaces/_agent.py +++ b/utils/interfaces/_agent.py @@ -98,8 +98,8 @@ def assert_headers_presence( def get_traces(self, request: HttpResponse | None = None) -> Generator[tuple[dict, dict, TraceAgentPayloadFormat]]: """Attempts to fetch the traces the agent will submit to the backend. - When a valid request is given, then we filter the spans to the ones sampled - during that request's execution, and only return those. + When a valid request is given, then we filter the traces to the ones that contain + spans sampled during that request's execution, and only return those. Returns data, trace and trace_format """ @@ -115,22 +115,38 @@ def get_traces(self, request: HttpResponse | None = None) -> Generator[tuple[dic for payload in content: for trace in payload["chunks"]: - for span in trace["spans"]: - if rid is None or get_rid_from_span(span) == rid: - logger.info(f"Found a trace in {data['log_filename']}") - yield data, trace, TraceAgentPayloadFormat.legacy - break + # Check if any span in the trace matches the RID + trace_has_matching_span = False + if rid is None: + trace_has_matching_span = True + else: + for span in trace["spans"]: + if get_rid_from_span(span) == rid: + trace_has_matching_span = True + break + + if trace_has_matching_span: + logger.info(f"Found a trace in {data['log_filename']}") + yield data, trace, TraceAgentPayloadFormat.legacy if "idxTracerPayloads" in data["request"]["content"]: content: list[dict] = data["request"]["content"]["idxTracerPayloads"] for payload in content: for trace in payload.get("chunks", []): - for span in trace["spans"]: - if rid is None or get_rid_from_span(span) == rid: - logger.info(f"Found a trace in {data['log_filename']}") - yield data, trace, TraceAgentPayloadFormat.efficient_trace_payload_format - break + # Check if any span in the trace matches the RID + trace_has_matching_span = False + if rid is None: + trace_has_matching_span = True + else: + for span in trace["spans"]: + if get_rid_from_span(span) == rid: + trace_has_matching_span = True + break + + if trace_has_matching_span: + logger.info(f"Found a trace in {data['log_filename']}") + yield data, trace, TraceAgentPayloadFormat.efficient_trace_payload_format def get_spans( self, request: HttpResponse | None = None diff --git a/utils/interfaces/_backend.py b/utils/interfaces/_backend.py index 01396f21149..3af8621b39d 100644 --- a/utils/interfaces/_backend.py +++ b/utils/interfaces/_backend.py @@ -17,6 +17,7 @@ from utils.tools import get_rid_from_span from utils._logger import logger from utils._weblog import HttpResponse +from utils.dd_constants import TraceLibraryPayloadFormat class _BackendInterfaceValidator(ProxyBasedInterfaceValidator): @@ -68,13 +69,26 @@ def load_data_from_logs(self): def _init_rid_to_library_trace_ids(self): # Map each request ID to the spans created and submitted during that request call. - for _, span in self.library_interface.get_root_spans(): + # Use get_spans to get trace chunks for v1 format + for _, trace, span, span_format in self.library_interface.get_spans(): + parent_id = span.get("parent_id") + if parent_id not in (0, None): + continue # Only process root spans + rid = get_rid_from_span(span) + # Get trace_id using helper method (pass trace chunk for v1 format) + trace_chunk: dict | None = ( + trace if span_format == TraceLibraryPayloadFormat.v1 and isinstance(trace, dict) else None + ) + trace_id = self.library_interface.get_span_trace_id(span, trace_chunk, span_format) + if trace_id == 0: + continue # Skip spans without trace_id + if not self.rid_to_library_trace_ids.get(rid): - self.rid_to_library_trace_ids[rid] = [span["trace_id"]] + self.rid_to_library_trace_ids[rid] = [trace_id] else: - self.rid_to_library_trace_ids[rid].append(span["trace_id"]) + self.rid_to_library_trace_ids[rid].append(trace_id) ################################# ######### API for tests ######### diff --git a/utils/interfaces/_library/appsec.py b/utils/interfaces/_library/appsec.py index 0787652d27f..25c971d4b2e 100644 --- a/utils/interfaces/_library/appsec.py +++ b/utils/interfaces/_library/appsec.py @@ -6,8 +6,10 @@ from collections import Counter from collections.abc import Callable + from utils.interfaces._library.appsec_data import rule_id_to_type from utils._logger import logger +from utils.dd_constants import TraceLibraryPayloadFormat class _WafAttack: @@ -172,8 +174,16 @@ def validate_legacy(self, event: dict): return True - def validate(self, span: dict, appsec_data: dict): # noqa: ARG002 - headers = [n.lower() for n in span["meta"] if n.startswith("http.request.headers.")] + def validate(self, span: dict, appsec_data: dict, span_format: TraceLibraryPayloadFormat | None = None): # noqa: ARG002 + # Use helper method to get meta for both v04 and v1 formats + # Import here to avoid circular import + from utils.interfaces._library.core import LibraryInterfaceValidator # noqa: PLC0415 + + if span_format is None: + span_format = LibraryInterfaceValidator._detect_span_format(span) # noqa: SLF001 + meta = LibraryInterfaceValidator.get_span_meta(span, span_format) + + headers = [n.lower() for n in meta if n.startswith("http.request.headers.")] assert f"http.request.headers.{self.header_name}" in headers, f"header {self.header_name} not reported" return True diff --git a/utils/interfaces/_library/core.py b/utils/interfaces/_library/core.py index 985e8e0a3b9..1246742d296 100644 --- a/utils/interfaces/_library/core.py +++ b/utils/interfaces/_library/core.py @@ -2,18 +2,21 @@ # This product includes software developed at Datadog (https://www.datadoghq.com/). # Copyright 2021 Datadog, Inc. +import ast import base64 from collections.abc import Callable, Iterable, Generator import copy +import inspect import json +import msgpack +import re import threading from utils.tools import get_rid_from_user_agent, get_rid_from_span from utils._logger import logger -from utils.dd_constants import RemoteConfigApplyState, Capabilities +from utils.dd_constants import RemoteConfigApplyState, Capabilities, TraceLibraryPayloadFormat from utils.interfaces._core import ProxyBasedInterfaceValidator from utils.interfaces._library.appsec import _WafAttack, _ReportedHeader -from utils.interfaces._library.miscs import _SpanTagValidator from utils.interfaces._library.telemetry import ( _SeqIdLatencyValidation, _NoSkippedSeqId, @@ -22,15 +25,407 @@ from utils.interfaces._misc_validators import HeadersPresenceValidator +class _SpanTagValidator: + """Will run an arbitrary check on spans. If a request is provided, only span""" + + path_filters = ["/v0.4/traces", "/v0.5/traces"] + + def __init__( + self, + tags: dict | None, + *, + value_as_regular_expression: bool, + library_interface: "LibraryInterfaceValidator | None" = None, + ): + self.tags = {} if tags is None else tags + self.value_as_regular_expression = value_as_regular_expression + self.library_interface = library_interface + + def __call__(self, span: dict, span_format: TraceLibraryPayloadFormat | None = None): + # If span_format is provided, use helper methods; otherwise assume v04 format for backward compatibility + if span_format is not None and self.library_interface is not None: + meta = self.library_interface.get_span_meta(span, span_format) + else: + # Backward compatibility: assume v04 format + meta = span.get("meta", {}) + + for tag_key in self.tags: + if tag_key not in meta: + raise ValueError(f"{tag_key} tag not found in span's meta") + + expect_value = self.tags[tag_key] + actual_value = meta[tag_key] + + if self.value_as_regular_expression: + if not re.compile(expect_value).fullmatch(actual_value): + raise ValueError( + f'{tag_key} tag value is "{actual_value}", and should match regex "{expect_value}"' + ) + elif expect_value != actual_value: + raise ValueError(f'{tag_key} tag in span\'s meta should be "{expect_value}", not "{actual_value}"') + + return True + + class LibraryInterfaceValidator(ProxyBasedInterfaceValidator): """Validate library/agent interface""" - trace_paths = ["/v0.4/traces", "/v0.5/traces"] + trace_paths = ["/v0.4/traces", "/v0.5/traces", "/v1.0/traces"] + # Number of hex characters needed to represent 64 bits (lower trace ID) + _TRACE_ID_HEX_LENGTH = 16 def __init__(self, name: str): super().__init__(name) self.ready = threading.Event() + @staticmethod + def _detect_span_format(span: dict) -> TraceLibraryPayloadFormat: + """Detect the format of a span based on its structure.""" + if "name_value" in span or "type_value" in span or "attributes" in span: + return TraceLibraryPayloadFormat.v1 + return TraceLibraryPayloadFormat.v04 + + @staticmethod + def get_span_meta(span: dict, span_format: TraceLibraryPayloadFormat | None = None) -> dict: + """Returns the meta dictionary of a span according to its format.""" + if span_format is None: + span_format = LibraryInterfaceValidator._detect_span_format(span) + + if span_format == TraceLibraryPayloadFormat.v04: + return span.get("meta", {}) + + if span_format == TraceLibraryPayloadFormat.v1: + # In v1 format, meta and metrics are joined in attributes + return span.get("attributes", {}) + + raise ValueError(f"Unknown span format: {span_format}") + + @staticmethod + def _is_numeric_value(value: float | str) -> bool: + """Check if a value is numeric (int, float, or string that can be converted to number).""" + if isinstance(value, (int, float)): + return True + if isinstance(value, str): + try: + float(value) + return True + except (ValueError, TypeError): + return False + return False + + @staticmethod + def get_span_metrics(span: dict, span_format: TraceLibraryPayloadFormat | None = None) -> dict: + """Returns the metrics dictionary of a span according to its format.""" + if span_format is None: + span_format = LibraryInterfaceValidator._detect_span_format(span) + + if span_format == TraceLibraryPayloadFormat.v04: + return span.get("metrics", {}) + + if span_format == TraceLibraryPayloadFormat.v1: + # In v1 format, metrics are in attributes, but we need to filter for numeric values + attributes = span.get("attributes", {}) + metric_keys = { + "process_id", + "_dd.top_level", + "_dd.sampling_priority", + "_sampling_priority_v1", + "_dd.agent_psr", + "_dd.trace_span_attribute_schema", + } + + return { + key: value + for key, value in attributes.items() + if key in metric_keys or (key.startswith("_dd.") and LibraryInterfaceValidator._is_numeric_value(value)) + } + + raise ValueError(f"Unknown span format: {span_format}") + + @staticmethod + def get_span_name(span: dict, span_format: TraceLibraryPayloadFormat | None = None) -> str: + """Returns the name of a span according to its format.""" + if span_format is None: + span_format = LibraryInterfaceValidator._detect_span_format(span) + + if span_format == TraceLibraryPayloadFormat.v04: + return span.get("name", "") + + if span_format == TraceLibraryPayloadFormat.v1: + return span.get("name_value", "") + + raise ValueError(f"Unknown span format: {span_format}") + + @staticmethod + def get_span_type(span: dict, span_format: TraceLibraryPayloadFormat | None = None) -> str: + """Returns the type of a span according to its format.""" + if span_format is None: + span_format = LibraryInterfaceValidator._detect_span_format(span) + + if span_format == TraceLibraryPayloadFormat.v04: + return span.get("type", "") + + if span_format == TraceLibraryPayloadFormat.v1: + return span.get("type_value", "") + + raise ValueError(f"Unknown span format: {span_format}") + + @staticmethod + def get_span_trace_id( + span: dict, trace_chunk: dict | None = None, span_format: TraceLibraryPayloadFormat | None = None + ) -> int: + """Returns the trace_id of a span according to its format.""" + if span_format is None: + span_format = LibraryInterfaceValidator._detect_span_format(span) + + if span_format == TraceLibraryPayloadFormat.v04: + return span.get("trace_id", 0) + + if span_format == TraceLibraryPayloadFormat.v1: + # In v1 format, trace_id is at the chunk level + if trace_chunk and "trace_id" in trace_chunk: + trace_id = trace_chunk["trace_id"] + # Convert hex string to int (extract lower 64 bits) + if isinstance(trace_id, str) and trace_id.startswith("0x"): + try: + return int(trace_id[-16:], 16) + except ValueError: + return int(trace_id, 16) if trace_id.startswith("0x") else 0 + return trace_id if isinstance(trace_id, int) else 0 + return 0 + + raise ValueError(f"Unknown span format: {span_format}") + + @staticmethod + def get_trace_id(trace: dict | list[dict], span_format: TraceLibraryPayloadFormat | None = None) -> int: + """Returns the trace_id from a trace according to its format. + + For v04 format, trace is a list of spans and trace_id is in the first span. + For v1 format, trace is a dict (chunk) with trace_id at the chunk level. + """ + if span_format is None: + # Try to detect format from trace structure + if isinstance(trace, dict) and "spans" in trace: + span_format = TraceLibraryPayloadFormat.v1 + elif isinstance(trace, list) and len(trace) > 0: + span_format = TraceLibraryPayloadFormat.v04 + else: + raise ValueError("Cannot determine span format from trace structure") + + if span_format == TraceLibraryPayloadFormat.v04: + # For v04, trace is a list of spans, trace_id is in the first span + if isinstance(trace, list) and len(trace) > 0: + return trace[0].get("trace_id", 0) + return 0 + + if span_format == TraceLibraryPayloadFormat.v1: + # For v1, trace is a dict (chunk) with trace_id at the chunk level + if isinstance(trace, dict) and "trace_id" in trace: + trace_id = trace["trace_id"] + # Convert hex string to int (extract lower 64 bits) + if isinstance(trace_id, str) and trace_id.startswith("0x"): + try: + return int(trace_id[-16:], 16) + except ValueError: + return int(trace_id, 16) if trace_id.startswith("0x") else 0 + return trace_id if isinstance(trace_id, int) else 0 + return 0 + + raise ValueError(f"Unknown span format: {span_format}") + + @staticmethod + def get_span_parent_id(span: dict, span_format: TraceLibraryPayloadFormat | None = None) -> int | None: + """Returns the parent_id of a span according to its format.""" + if span_format is None: + span_format = LibraryInterfaceValidator._detect_span_format(span) + + # parent_id is a top-level field in both formats + return span.get("parent_id") + + @staticmethod + def get_span_links(span: dict, span_format: TraceLibraryPayloadFormat | None = None) -> list[dict] | None: + """Returns the span links of a span according to its format. + + Returns a normalized list of span links with consistent field names: + - trace_id: int (lower 64 bits) + - trace_id_high: int | None (upper 64 bits, if present) + - span_id: int + - attributes: dict | None + - tracestate: str | None + - flags: int + """ + if span_format is None: + span_format = LibraryInterfaceValidator._detect_span_format(span) + + # Check for span_links at top level (v1 format or v04 with span_links field) + # Also check for "links" which is used in v1 format + span_links = span.get("span_links") or span.get("links") + if span_links is not None: + normalized_links = [] + for link in span_links: + normalized_link = {} + + # Handle trace_id - normalize to int (lower 64 bits) and optionally trace_id_high + trace_id = link.get("trace_id") + if trace_id is not None: + if isinstance(trace_id, str): + # Handle hex string format (e.g., "0x1234..." or "1234...") + trace_id_str = trace_id.removeprefix("0x") + if len(trace_id_str) >= LibraryInterfaceValidator._TRACE_ID_HEX_LENGTH: + # Extract lower 64 bits + normalized_link["trace_id"] = int( + trace_id_str[-LibraryInterfaceValidator._TRACE_ID_HEX_LENGTH :], 16 + ) + # Extract upper 64 bits if present + if len(trace_id_str) > LibraryInterfaceValidator._TRACE_ID_HEX_LENGTH: + normalized_link["trace_id_high"] = int( + trace_id_str[: -LibraryInterfaceValidator._TRACE_ID_HEX_LENGTH], 16 + ) + else: + normalized_link["trace_id"] = int(trace_id_str, 16) + elif isinstance(trace_id, int): + normalized_link["trace_id"] = trace_id + # Note: bytes should already be converted to hex string by deserializer + + # Handle span_id + span_id = link.get("span_id") + if span_id is not None: + if isinstance(span_id, str): + # Convert hex string to int if needed + normalized_link["span_id"] = ( + int(span_id, 16) + if span_id.startswith("0x") or all(c in "0123456789abcdefABCDEF" for c in span_id) + else int(span_id) + ) + else: + normalized_link["span_id"] = span_id + + # Copy other fields + if "attributes" in link: + normalized_link["attributes"] = link["attributes"] + if "tracestate" in link: + normalized_link["tracestate"] = link["tracestate"] + elif "trace_state" in link: + normalized_link["tracestate"] = link["trace_state"] + if "flags" in link: + # Ensure flags have the TRACECONTEXT_FLAGS_SET bit + normalized_link["flags"] = link["flags"] | (1 << 31) + else: + normalized_link["flags"] = 0 + + normalized_links.append(normalized_link) + return normalized_links + + # Check meta for _dd.span_links (v04 format stored in meta) + meta = LibraryInterfaceValidator.get_span_meta(span, span_format) + span_links_value = meta.get("_dd.span_links") + if span_links_value is not None: + # Convert span_links tags into normalized format + json_links = json.loads(span_links_value) + normalized_links = [] + for json_link in json_links: + normalized_link = {} + # Parse trace_id from hex string + trace_id_str = json_link["trace_id"] + if len(trace_id_str) >= LibraryInterfaceValidator._TRACE_ID_HEX_LENGTH: + normalized_link["trace_id"] = int( + trace_id_str[-LibraryInterfaceValidator._TRACE_ID_HEX_LENGTH :], 16 + ) + if len(trace_id_str) > LibraryInterfaceValidator._TRACE_ID_HEX_LENGTH: + normalized_link["trace_id_high"] = int( + trace_id_str[: -LibraryInterfaceValidator._TRACE_ID_HEX_LENGTH], 16 + ) + else: + normalized_link["trace_id"] = int(trace_id_str, 16) + # Parse span_id from hex string + normalized_link["span_id"] = int(json_link["span_id"], 16) + if "attributes" in json_link: + normalized_link["attributes"] = json_link["attributes"] + if "tracestate" in json_link: + normalized_link["tracestate"] = json_link["tracestate"] + elif "trace_state" in json_link: + normalized_link["tracestate"] = json_link["trace_state"] + if "flags" in json_link: + normalized_link["flags"] = json_link["flags"] | (1 << 31) + else: + normalized_link["flags"] = 0 + normalized_links.append(normalized_link) + return normalized_links + + return None + + @staticmethod + def get_span_meta_struct(span: dict, span_format: TraceLibraryPayloadFormat | None = None) -> dict: + """Returns the meta_struct dictionary of a span according to its format. + + For v04 format, returns span.get("meta_struct", {}). + For v1 format, checks if there's binary appsec data in attributes and decodes it. + """ + if span_format is None: + span_format = LibraryInterfaceValidator._detect_span_format(span) + + if span_format == TraceLibraryPayloadFormat.v04: + return span.get("meta_struct", {}) + + if span_format == TraceLibraryPayloadFormat.v1: + # In v1 format, appsec data that was in meta_struct is now in attributes as binary data + attributes = span.get("attributes", {}) + if not isinstance(attributes, dict): + return {} + + meta_struct = {} + + # Check if there's an "appsec" key in attributes with binary data + # The binary data is msgpack-encoded, similar to how meta_struct worked in v04 + if "appsec" in attributes: + appsec_value = attributes["appsec"] + + # Handle bytes value (msgpack-encoded binary data) + if isinstance(appsec_value, bytes): + try: + decoded = msgpack.unpackb(appsec_value, unicode_errors="replace", strict_map_key=False) + # The decoded value should be the appsec data dict directly + if isinstance(decoded, dict) and decoded: + meta_struct["appsec"] = decoded + except (msgpack.UnpackException, ValueError, TypeError): + # Not msgpack-encoded or invalid data, skip silently + pass + except Exception: + # Any other exception, log but don't fail + logger.debug(f"Unexpected error decoding appsec data: {type(Exception).__name__}") + elif isinstance(appsec_value, str): + # Handle string representation of bytes (e.g., "b'\\x81\\xa8triggers...'") + # This happens when the bytes are serialized to JSON as a string literal + try: + # Try to parse the bytes literal string (e.g., "b'\\x81...'" -> bytes) + if appsec_value.startswith(("b'", 'b"')): + bytes_value = ast.literal_eval(appsec_value) + if isinstance(bytes_value, bytes): + decoded = msgpack.unpackb(bytes_value, unicode_errors="replace", strict_map_key=False) + if isinstance(decoded, dict) and decoded: + meta_struct["appsec"] = decoded + except (ValueError, SyntaxError, TypeError, msgpack.UnpackException): + # Failed to parse or decode, skip silently + pass + except Exception: + # Any other exception, log but don't fail + logger.debug(f"Unexpected error decoding appsec string data: {type(Exception).__name__}") + + return meta_struct + + raise ValueError(f"Unknown span format: {span_format}") + + @staticmethod + def get_span_span_id(span: dict, span_format: TraceLibraryPayloadFormat | None = None) -> int | None: + """Returns the span_id of a span according to its format.""" + if span_format is None: + span_format = LibraryInterfaceValidator._detect_span_format(span) + + # span_id is a top-level field in both formats (may be called "id" in v1) + if span_format == TraceLibraryPayloadFormat.v1: + return span.get("id") or span.get("span_id") + return span.get("span_id") + def ingest_file(self, src_path: str): self.ready.set() return super().ingest_file(src_path) @@ -61,7 +456,7 @@ def wait_function(data: dict): ############################################################ def get_traces( self, request: HttpResponse | GrpcResponse | None = None - ) -> Generator[tuple[dict, list[dict]], None, None]: + ) -> Generator[tuple[dict, list[dict] | dict, TraceLibraryPayloadFormat], None, None]: rid: str | None = None if request: @@ -71,7 +466,8 @@ def get_traces( logger.warning("HTTP app failed to respond, it will very probably fail") trace_found = False - for data in self.get_data(path_filters=self.trace_paths): + # Handle v04/v05 traces + for data in self.get_data(path_filters=["/v0.4/traces", "/v0.5/traces"]): traces = data["request"]["content"] if not traces: # may be none continue @@ -79,13 +475,34 @@ def get_traces( for trace in traces: if rid is None: trace_found = True - yield data, trace + yield data, trace, TraceLibraryPayloadFormat.v04 else: for span in trace: if rid == get_rid_from_span(span): logger.debug(f"Found a trace in {data['log_filename']}") trace_found = True - yield data, trace + yield data, trace, TraceLibraryPayloadFormat.v04 + break + + # Handle v1.0 traces + for data in self.get_data(path_filters="/v1.0/traces"): + traces = data["request"]["content"] + if not traces: # may be none + continue + + if not traces.get("chunks"): + continue + + for trace_chunk in traces.get("chunks"): + if rid is None: + trace_found = True + yield data, trace_chunk, TraceLibraryPayloadFormat.v1 + else: + for span in trace_chunk.get("spans", []): + if rid == get_rid_from_span(span): + logger.debug(f"Found a trace in {data['log_filename']}") + trace_found = True + yield data, trace_chunk, TraceLibraryPayloadFormat.v1 break if not trace_found: @@ -125,54 +542,88 @@ def get_traces_v1(self, request: HttpResponse | GrpcResponse | None = None): if not trace_found: logger.warning("No trace found") - def get_spans(self, request: HttpResponse | None = None, *, full_trace: bool = False): + def get_spans( + self, request: HttpResponse | None = None, *, full_trace: bool = False + ) -> Generator[tuple[dict, dict | list[dict], dict, TraceLibraryPayloadFormat], None, None]: """Iterate over all spans reported by the tracer to the agent. If request is not None and full_trace is False, only span triggered by that request will be returned. If request is not None and full_trace is True, all spans from a trace triggered by that request will be returned. + + Returns: (data, trace, span, span_format) """ rid = request.get_rid() if request else None - for data, trace in self.get_traces(request=request): - for span in trace: - if rid is None or full_trace: - yield data, trace, span - elif rid == get_rid_from_span(span): - logger.debug(f"Found a span in {data['log_filename']}") - yield data, trace, span + for data, trace, trace_format in self.get_traces(request=request): + # Check if this is a v1 format trace (has "spans" key) or v04 format (list of spans) + if trace_format == TraceLibraryPayloadFormat.v1: + # v1 format: trace is a chunk with "spans" array + assert isinstance(trace, dict), "v1 format trace must be a dict" + spans = trace.get("spans", []) + for span in spans: + if rid is None or full_trace: + yield data, trace, span, trace_format + elif rid == get_rid_from_span(span): + logger.debug(f"Found a span in {data['log_filename']}") + yield data, trace, span, trace_format + else: + # v04 format: trace is a list of spans + assert isinstance(trace, list), "v04 format trace must be a list" + for span in trace: + if rid is None or full_trace: + yield data, trace, span, trace_format + elif rid == get_rid_from_span(span): + logger.debug(f"Found a span in {data['log_filename']}") + yield data, trace, span, trace_format def get_root_spans(self, request: HttpResponse | None = None): - for data, _, span in self.get_spans(request=request): - if span.get("parent_id") in (0, None): - yield data, span + """Returns root spans in their native format along with the format. + + Returns: (data, span, span_format) + Use helper methods (get_span_meta, get_span_metrics, etc.) to access span fields. + """ + for data, _trace, span, span_format in self.get_spans(request=request): + parent_id = span.get("parent_id") + if parent_id in (0, None): + yield data, span, span_format - def get_root_span(self, request: HttpResponse) -> dict: + def get_root_span(self, request: HttpResponse) -> tuple[dict, TraceLibraryPayloadFormat]: """Get the root span associated with a given request. This function will fail if a request is not given, if there is no root span, or if there is more than one root span. For special cases, use get_root_spans. + + Returns: (span, span_format) + Use helper methods to access span fields. """ assert request is not None, "A request object is mandatory" - spans = [s for _, s in self.get_root_spans(request=request)] + spans = [(s, f) for _, s, f in self.get_root_spans(request=request)] assert spans, "No root spans found" assert len(spans) == 1, "More then one root span found" return spans[0] def get_appsec_events(self, request: HttpResponse | None = None, *, full_trace: bool = False): - for data, trace, span in self.get_spans(request=request, full_trace=full_trace): - if "appsec" in span.get("meta_struct", {}): + for data, trace, span, span_format in self.get_spans(request=request, full_trace=full_trace): + # Try to get appsec data from meta_struct (works for both v04 and v1 formats) + meta_struct = self.get_span_meta_struct(span, span_format) + if "appsec" in meta_struct: if request: # do not spam log if all data are sent to the validator - logger.debug(f"Try to find relevant appsec data in {data['log_filename']}; span #{span['span_id']}") + span_id = self.get_span_span_id(span, span_format) or "unknown" + logger.debug(f"Try to find relevant appsec data in {data['log_filename']}; span #{span_id}") - appsec_data = span["meta_struct"]["appsec"] + appsec_data = meta_struct["appsec"] yield data, trace, span, appsec_data - elif "_dd.appsec.json" in span.get("meta", {}): - if request: # do not spam log if all data are sent to the validator - logger.debug(f"Try to find relevant appsec data in {data['log_filename']}; span #{span['span_id']}") + else: + # Fallback to _dd.appsec.json in meta/attributes + meta = self.get_span_meta(span, span_format) + if "_dd.appsec.json" in meta: + if request: # do not spam log if all data are sent to the validator + span_id = self.get_span_span_id(span, span_format) or "unknown" + logger.debug(f"Try to find relevant appsec data in {data['log_filename']}; span #{span_id}") - appsec_data = span["meta"]["_dd.appsec.json"] - yield data, trace, span, appsec_data + appsec_data = meta["_dd.appsec.json"] + yield data, trace, span, appsec_data def get_legacy_appsec_events(self, request: HttpResponse | None = None): paths_with_appsec_events = ["/appsec/proxy/v1/input", "/appsec/proxy/api/v2/appsecevts"] @@ -274,7 +725,7 @@ def validator_skip_onboarding_event(data: dict) -> None: def validate_one_appsec( self, request: HttpResponse | None = None, - validator: Callable[[dict, dict], bool] | None = None, + validator: Callable[[dict, dict], bool] | Callable[[dict, dict, TraceLibraryPayloadFormat], bool] | None = None, *, legacy_validator: Callable | None = None, full_trace: bool = False, @@ -285,11 +736,28 @@ def validate_one_appsec( * If validator() raise an exception. the validate_one will fail If no payload satisfies validator(), then validate_one will fail + + The validator can accept either: + - (span, appsec_data, span_format) - recommended, use helper methods for format-agnostic access + - (span, appsec_data) - for backward compatibility, but will only work with v04 format spans """ if validator: for _, _, span, appsec_data in self.get_appsec_events(request=request, full_trace=full_trace): - if validator(span, appsec_data) is True: - return + # Detect span format and try to call validator with format if it accepts 3 parameters + span_format = self._detect_span_format(span) + try: + sig = inspect.signature(validator) + if len(sig.parameters) == 3: # noqa: PLR2004 + # Validator accepts (span, appsec_data, span_format) + if validator(span, appsec_data, span_format) is True: # type: ignore[call-arg] + return + # Validator accepts only (span, appsec_data) - backward compatibility + elif validator(span, appsec_data) is True: # type: ignore[call-arg] + return + except TypeError: + # Fallback: try calling with 2 parameters + if validator(span, appsec_data) is True: # type: ignore[call-arg] + return if legacy_validator: for _, event in self.get_legacy_appsec_events(request=request): @@ -320,11 +788,13 @@ def validate_all_appsec( ###################################################### def assert_iast_implemented(self): - for _, span in self.get_root_spans(): - if "_dd.iast.enabled" in span.get("metrics", {}): + for _, span, span_format in self.get_root_spans(): + metrics = self.get_span_metrics(span, span_format) + if "_dd.iast.enabled" in metrics: return - if "_dd.iast.enabled" in span.get("meta", {}): + meta = self.get_span_meta(span, span_format) + if "_dd.iast.enabled" in meta: return raise ValueError("_dd.iast.enabled has not been found in any metrics") @@ -342,8 +812,9 @@ def assert_headers_presence( def assert_receive_request_root_trace(self): # TODO : move this in test class """Asserts that a trace for a request has been sent to the agent""" - for _, span in self.get_root_spans(): - if span.get("type") == "web": + for _, span, span_format in self.get_root_spans(): + span_type = self.get_span_type(span, span_format) + if span_type == "web": return raise ValueError("Nothing has been reported. No request root span with has been found") @@ -351,14 +822,25 @@ def assert_receive_request_root_trace(self): # TODO : move this in test class def assert_trace_id_uniqueness(self): trace_ids: dict[int, str] = {} - for data, trace in self.get_traces(): - spans = [span for span in trace if span.get("parent_id") in ("0", 0, None)] + for data, trace, trace_format in self.get_traces(): + if trace_format == TraceLibraryPayloadFormat.v1: + assert isinstance(trace, dict), "v1 format trace must be a dict" + spans = trace.get("spans", []) + else: + assert isinstance(trace, list), "v04 format trace must be a list" + spans = trace + + root_spans = [span for span in spans if span.get("parent_id") in ("0", 0, None)] - if spans: + if root_spans: log_filename = data["log_filename"] - span = spans[0] - assert "trace_id" in span, f"'trace_id' is missing in {log_filename}" - trace_id = span["trace_id"] + span = root_spans[0] + span_format = self._detect_span_format(span) + trace_dict: dict | None = ( + trace if trace_format == TraceLibraryPayloadFormat.v1 and isinstance(trace, dict) else None + ) + trace_id = self.get_span_trace_id(span, trace_dict, span_format) + assert trace_id != 0, f"'trace_id' is missing in {log_filename}" if trace_id in trace_ids: raise ValueError( @@ -430,9 +912,17 @@ def validate_one_trace(self, request: HttpResponse, validator: Callable[[list[di If no payload satisfies validator(), then validate_one will fail """ - for data, trace in self.get_traces(request=request): + for data, trace, trace_format in self.get_traces(request=request): + # For v1 format, extract spans from chunk for backward compatibility + if trace_format == TraceLibraryPayloadFormat.v1: + assert isinstance(trace, dict), "v1 format trace must be a dict" + trace_spans = trace.get("spans", []) + else: + assert isinstance(trace, list), "v04 format trace must be a list" + trace_spans = trace + try: - if validator(trace) is True: + if validator(trace_spans) is True: return except Exception: logger.error(f"{data['log_filename']} did not validate this test") @@ -445,7 +935,7 @@ def validate_one_span( self, request: HttpResponse | None = None, *, - validator: Callable[[dict], bool], + validator: Callable[[dict, TraceLibraryPayloadFormat], bool] | Callable[[dict], bool], full_trace: bool = False, ): """Will call validator() on all spans (eventually filtered on span trigerred by request). @@ -456,11 +946,23 @@ def validate_one_span( * If validator() raise an exception. the validate_one will fail If no payload satisfies validator(), then validate_one will fail + + The validator can accept either: + - (span, span_format) - recommended, use helper methods for format-agnostic access + - (span) - for backward compatibility, but will only work with v04 format spans """ - for _, _, span in self.get_spans(request=request, full_trace=full_trace): + for _, _trace, span, span_format in self.get_spans(request=request, full_trace=full_trace): try: - if validator(span) is True: - return + sig = inspect.signature(validator) + if len(sig.parameters) == 2: # noqa: PLR2004 + # Validator accepts (span, format) + if validator(span, span_format) is True: # type: ignore[call-arg] + return + # Validator accepts only (span) - backward compatibility + # Only works for v04 format + elif span_format == TraceLibraryPayloadFormat.v04: + if validator(span) is True: # type: ignore[call-arg] + return except Exception as e: logger.error(f"This span is failing validation ({e}): {json.dumps(span, indent=2)}") raise @@ -471,18 +973,29 @@ def validate_all_spans( self, request: HttpResponse | None = None, *, - validator: Callable[[dict], None], + validator: Callable[[dict, TraceLibraryPayloadFormat], None] | Callable[[dict], None], full_trace: bool = False, allow_no_data: bool = False, ): """Will call validator() on all spans (eventually filtered on span trigerred by request) If ever a validator raise an exception, the validation will fail + + The validator can accept either: + - (span, span_format) - recommended, use helper methods for format-agnostic access + - (span) - for backward compatibility, but will only work with v04 format spans """ data_is_missing = True - for _, _, span in self.get_spans(request=request, full_trace=full_trace): + for _, _trace, span, span_format in self.get_spans(request=request, full_trace=full_trace): data_is_missing = False try: - validator(span) + sig = inspect.signature(validator) + if len(sig.parameters) == 2: # noqa: PLR2004 + # Validator accepts (span, format) + validator(span, span_format) # type: ignore[call-arg] + # Validator accepts only (span) - backward compatibility + # Only works for v04 format + elif span_format == TraceLibraryPayloadFormat.v04: + validator(span) # type: ignore[call-arg] except Exception as e: logger.error(f"This span is failing validation ({e}): {json.dumps(span, indent=2)}") raise @@ -498,10 +1011,12 @@ def add_span_tag_validation( value_as_regular_expression: bool = False, full_trace: bool = False, ): - validator = _SpanTagValidator(tags=tags, value_as_regular_expression=value_as_regular_expression) + validator = _SpanTagValidator( + tags=tags, value_as_regular_expression=value_as_regular_expression, library_interface=self + ) success = False - for _, _, span in self.get_spans(request=request, full_trace=full_trace): - success = success or validator(span) + for _, _trace, span, span_format in self.get_spans(request=request, full_trace=full_trace): + success = success or validator(span, span_format) if not success: raise ValueError("Can't find anything to validate this test") @@ -520,8 +1035,11 @@ def get_profiling_data(self): yield from self.get_data(path_filters="/profiling/v1/input") def assert_trace_exists(self, request: HttpResponse, span_type: str | None = None): - for _, _, span in self.get_spans(request=request): - if span_type is None or span.get("type") == span_type: + for _, _trace, span, span_format in self.get_spans(request=request): + if span_type is None: + return + actual_type = self.get_span_type(span, span_format) + if actual_type == span_type: return raise ValueError(f"No trace has been found for request {request.get_rid()}") diff --git a/utils/interfaces/_library/miscs.py b/utils/interfaces/_library/miscs.py index 0d13fe14ab8..3511c81b3a5 100644 --- a/utils/interfaces/_library/miscs.py +++ b/utils/interfaces/_library/miscs.py @@ -4,36 +4,6 @@ """Misc validations""" -import re - - -class _SpanTagValidator: - """will run an arbitrary check on spans. If a request is provided, only span""" - - path_filters = ["/v0.4/traces", "/v0.5/traces"] - - def __init__(self, tags: dict | None, *, value_as_regular_expression: bool): - self.tags = {} if tags is None else tags - self.value_as_regular_expression = value_as_regular_expression - - def __call__(self, span: dict): - for tag_key in self.tags: - if tag_key not in span["meta"]: - raise ValueError(f"{tag_key} tag not found in span's meta") - - expect_value = self.tags[tag_key] - actual_value = span["meta"][tag_key] - - if self.value_as_regular_expression: - if not re.compile(expect_value).fullmatch(actual_value): - raise ValueError( - f'{tag_key} tag value is "{actual_value}", and should match regex "{expect_value}"' - ) - elif expect_value != actual_value: - raise ValueError(f'{tag_key} tag in span\'s meta should be "{expect_value}", not "{actual_value}"') - - return True - def validate_process_tags(process_tags: str): # entrypoint name and workdir can always be defined. diff --git a/utils/proxy/traces/trace_v1.py b/utils/proxy/traces/trace_v1.py index 16b11b4215b..85f73dcb93f 100644 --- a/utils/proxy/traces/trace_v1.py +++ b/utils/proxy/traces/trace_v1.py @@ -1,4 +1,5 @@ import base64 +import contextlib from enum import IntEnum import msgpack @@ -231,6 +232,15 @@ def _uncompress_spans(spans: list, strings: list[str]) -> list: enum_key = V1SpanKeys(k) if enum_key.name in _span_key_strings and isinstance(value, int): value = strings[v] + # Handle span_links specially - they need to be uncompressed + if enum_key == V1SpanKeys.span_links: + if value is not None: + value = _uncompress_span_links_list(value, strings) + # Handle span_events specially - they need to be uncompressed + elif enum_key == V1SpanKeys.span_events: + if value is not None: + # Debug: Uncompressing span events + value = _uncompress_span_events_list(value, strings) uncompressed_span[enum_key.name] = value except ValueError as e: raise ValueError(f"Unknown V1SpanKey: {k}") from e @@ -288,10 +298,143 @@ def deserialize_v1_trace(content: bytes) -> dict: return data +def _uncompress_span_links_list(span_links: list | None, strings: list[str]) -> list | None: + """Uncompress a list of span links by converting integer keys to string keys.""" + if span_links is None or not isinstance(span_links, list): + return span_links + + uncompressed_links = [] + for link in span_links: + if not isinstance(link, dict): + uncompressed_links.append(link) + continue + uncompressed_link = {} + for k, v in link.items(): + value = v + try: + # Check if k is a valid enum value by trying to create the enum + enum_key = V1SpanLinkKeys(k) + # Convert integer key to string key name + uncompressed_link[enum_key.name] = value + except ValueError: + # Keep non-enum keys as-is (for backward compatibility) + uncompressed_link[k] = value + + # Deserialize trace_id if present (handle both bytes and base64-encoded string) + if "trace_id" in uncompressed_link: + trace_id = uncompressed_link["trace_id"] + if isinstance(trace_id, bytes): + # Convert bytes to hex string (same as chunk trace IDs) + uncompressed_link["trace_id"] = "0x" + trace_id.hex().upper() + elif isinstance(trace_id, str): + try: + # Decode the base64-encoded trace_id string to bytes, then to hex + trace_id_bytes = base64.b64decode(trace_id) + uncompressed_link["trace_id"] = "0x" + trace_id_bytes.hex().upper() + except Exception: # noqa: S110 + # If it's not base64, it might already be in hex format + pass + + # Uncompress attributes + if "attributes" in uncompressed_link: + attrs = uncompressed_link["attributes"] + # Check if attributes are in list format (key, type, value triplets) + if isinstance(attrs, list): + uncompressed_link["attributes"] = _attributes_to_dict(attrs, strings) + else: + with contextlib.suppress(Exception): + # If attributes can't be uncompressed, keep as-is + uncompressed_link["attributes"] = _uncompress_attributes(attrs, strings) + + # Resolve tracestateRef to tracestate (if present as integer key) + if "trace_state" in uncompressed_link: + trace_state = uncompressed_link["trace_state"] + # If trace_state is an integer, it might be a reference to strings array + if isinstance(trace_state, int) and trace_state < len(strings): + uncompressed_link["tracestate"] = strings[trace_state] + + uncompressed_links.append(uncompressed_link) + return uncompressed_links + + +def _uncompress_span_events_list(span_events: list | None, strings: list[str]) -> list | None: + """Uncompress a list of span events by converting integer keys to string keys.""" + if span_events is None or not isinstance(span_events, list): + return span_events + + uncompressed_events = [] + for event in span_events: + if not isinstance(event, dict): + uncompressed_events.append(event) + continue + uncompressed_event = {} + for k, v in event.items(): + value = v + try: + # Check if k is a valid enum value by trying to create the enum + enum_key = V1SpanEventKeys(k) + # Convert integer key to string key name + # Map time -> time_unix_nano and name_value -> name for consistency + if enum_key == V1SpanEventKeys.time: + uncompressed_event["time_unix_nano"] = value + elif enum_key == V1SpanEventKeys.name_value: + # Resolve name_value from strings array if it's an integer + if isinstance(value, int) and value < len(strings): + uncompressed_event["name"] = strings[value] + else: + uncompressed_event["name"] = value + else: + uncompressed_event[enum_key.name] = value + except ValueError: + # Keep non-enum keys as-is (for backward compatibility) + uncompressed_event[k] = value + + # Uncompress attributes + if "attributes" in uncompressed_event: + attrs = uncompressed_event["attributes"] + # Check if attributes are in list format (key, type, value triplets) + if isinstance(attrs, list): + uncompressed_event["attributes"] = _attributes_to_dict(attrs, strings) + else: + with contextlib.suppress(Exception): + # If attributes can't be uncompressed, keep as-is + uncompressed_event["attributes"] = _uncompress_attributes(attrs, strings) + + uncompressed_events.append(uncompressed_event) + return uncompressed_events + + def _uncompress_span_link(link: dict, strings: list[str]) -> None: - """Uncompress a span link by deserializing traceID, attributes, and tracestate.""" - # Deserialize the base64-encoded traceID - _deserialize_base64_trace_id(link) + """Uncompress a span link by deserializing traceID, attributes, and tracestate. + This function is used for agent interface where links are already partially processed. + """ + # Convert integer keys to string keys if needed + if any(isinstance(k, int) for k in link): + uncompressed_link = {} + for k, v in link.items(): + try: + enum_key = V1SpanLinkKeys(k) + uncompressed_link[enum_key.name] = v + except ValueError: + # Keep non-enum keys as-is + uncompressed_link[k] = v + link.clear() + link.update(uncompressed_link) + + # Deserialize traceID (handle both bytes and base64-encoded string) + if "trace_id" in link: + trace_id = link["trace_id"] + if isinstance(trace_id, bytes): + # Convert bytes to hex string (same as chunk trace IDs) + link["trace_id"] = "0x" + trace_id.hex().upper() + elif isinstance(trace_id, str) and not trace_id.startswith("0x"): + try: + trace_id_bytes = base64.b64decode(trace_id) + link["trace_id"] = "0x" + trace_id_bytes.hex().upper() + except Exception: # noqa: S110 + pass + elif "traceID" in link: + _deserialize_base64_trace_id(link) # Uncompress attributes if "attributes" in link: @@ -302,6 +445,53 @@ def _uncompress_span_link(link: dict, strings: list[str]) -> None: tracestate_ref = link.pop("tracestateRef") if isinstance(tracestate_ref, int) and tracestate_ref < len(strings): link["tracestate"] = strings[tracestate_ref] + elif "trace_state" in link: + trace_state = link["trace_state"] + if isinstance(trace_state, int) and trace_state < len(strings): + link["tracestate"] = strings[trace_state] + + +def _uncompress_span_event(event: dict, strings: list[str]) -> None: + """Uncompress a span event by deserializing time, name, and attributes. + This function is used for agent interface where events are already partially processed. + """ + # Convert integer keys to string keys if needed + if any(isinstance(k, int) for k in event): + uncompressed_event = {} + for k, v in event.items(): + try: + enum_key = V1SpanEventKeys(k) + # Map time -> time_unix_nano and name_value -> name for consistency + if enum_key == V1SpanEventKeys.time: + uncompressed_event["time_unix_nano"] = v + elif enum_key == V1SpanEventKeys.name_value: + # Resolve name_value from strings array if it's an integer + if isinstance(v, int) and v < len(strings): + uncompressed_event["name"] = strings[v] + else: + uncompressed_event["name"] = v + else: + uncompressed_event[enum_key.name] = v + except ValueError: + # Keep non-enum keys as-is + uncompressed_event[k] = v + event.clear() + event.update(uncompressed_event) + else: + # Handle name_value -> name mapping even if keys are already strings + if "name_value" in event: + name_value = event.pop("name_value") + if isinstance(name_value, int) and name_value < len(strings): + event["name"] = strings[name_value] + else: + event["name"] = name_value + # Handle time -> time_unix_nano mapping + if "time" in event and "time_unix_nano" not in event: + event["time_unix_nano"] = event.pop("time") + + # Uncompress attributes + if "attributes" in event: + event["attributes"] = _uncompress_attributes(event["attributes"], strings) def _uncompress_agent_v1_trace(data: dict, interface: str): @@ -323,4 +513,9 @@ def _uncompress_agent_v1_trace(data: dict, interface: str): # Uncompress span links for link in span.get("links", []): _uncompress_span_link(link, strings) + # Uncompress span events (handle both camelCase and snake_case field names) + span_events = span.get("spanEvents") or span.get("span_events") + if span_events: + for event in span_events: + _uncompress_span_event(event, strings) return data diff --git a/utils/tools.py b/utils/tools.py index adb9f81f7ab..e46f598218d 100644 --- a/utils/tools.py +++ b/utils/tools.py @@ -72,7 +72,9 @@ def get_rid_from_span(span: dict) -> str | None: user_agent = None - if span.get("type") == "rpc": + # Handle both v04 format (type) and v1 format (type_value) + span_type = span.get("type") or span.get("type_value") + if span_type == "rpc": user_agent = meta.get("grpc.metadata.user-agent") # java does not fill this tag; it uses the normal http tags