193 changes: 130 additions & 63 deletions xtest/audit_logs.py
@@ -195,8 +195,11 @@ def record_sample(
event_time: When the event occurred (service clock, from JSON)
"""
# Convert both to UTC for comparison
# astimezone() handles both naive (assumes local) and aware datetimes
collection_utc = collection_time.astimezone(UTC)
if collection_time.tzinfo is None:
# Assume local time, convert to UTC
collection_utc = collection_time.astimezone(UTC)
else:
collection_utc = collection_time.astimezone(UTC)
Comment on lines +198 to +202
Contributor

medium

This if/else block is redundant, as both branches execute the same code: `collection_utc = collection_time.astimezone(UTC)`. The `datetime.astimezone()` method already handles both naive (assumed to be local time) and aware datetimes correctly, so the original one-line implementation was more concise. The same redundant logic appears again at line 365.

Suggested change:

        collection_utc = collection_time.astimezone(UTC)
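A quick sketch of the behavior the comment relies on: `astimezone()` converts both naive and aware datetimes without an explicit branch. This sketch assumes Python 3.11+ for `datetime.UTC`, and the printed values are illustrative:

    from datetime import UTC, datetime

    naive = datetime(2024, 1, 1, 12, 0)              # no tzinfo attached
    aware = datetime(2024, 1, 1, 12, 0, tzinfo=UTC)  # already timezone-aware

    # A naive datetime is interpreted as local wall-clock time, then converted;
    # an aware datetime is converted from its own zone. Both results are aware UTC.
    print(naive.astimezone(UTC))  # value depends on the local timezone
    print(aware.astimezone(UTC))  # 2024-01-01 12:00:00+00:00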


if event_time.tzinfo is None:
# Assume UTC if no timezone (common for service logs)
@@ -265,43 +268,47 @@ def __repr__(self) -> str:


# Audit event constants from platform/service/logger/audit/constants.go
# These are defined as Literal types for static type checking
ObjectType = Literal[
"subject_mapping",
"resource_mapping",
"attribute_definition",
"attribute_value",
"obligation_definition",
"obligation_value",
"obligation_trigger",
"namespace",
"condition_set",
"kas_registry",
"kas_attribute_namespace_assignment",
"kas_attribute_definition_assignment",
"kas_attribute_value_assignment",
"key_object",
"entity_object",
"resource_mapping_group",
"public_key",
"action",
"registered_resource",
"registered_resource_value",
"key_management_provider_config",
"kas_registry_keys",
"kas_attribute_definition_key_assignment",
"kas_attribute_value_key_assignment",
"kas_attribute_namespace_key_assignment",
"namespace_certificate",
]

ActionType = Literal["create", "read", "update", "delete", "rewrap", "rotate"]

ActionResult = Literal[
"success", "failure", "error", "encrypt", "block", "ignore", "override", "cancel"
]

AuditVerb = Literal["decision", "policy crud", "rewrap"]
OBJECT_TYPES = frozenset(
{
"subject_mapping",
"resource_mapping",
"attribute_definition",
"attribute_value",
"obligation_definition",
"obligation_value",
"obligation_trigger",
"namespace",
"condition_set",
"kas_registry",
"kas_attribute_namespace_assignment",
"kas_attribute_definition_assignment",
"kas_attribute_value_assignment",
"key_object",
"entity_object",
"resource_mapping_group",
"public_key",
"action",
"registered_resource",
"registered_resource_value",
"key_management_provider_config",
"kas_registry_keys",
"kas_attribute_definition_key_assignment",
"kas_attribute_value_key_assignment",
"kas_attribute_namespace_key_assignment",
"namespace_certificate",
}
)

ACTION_TYPES = frozenset({"create", "read", "update", "delete", "rewrap", "rotate"})

ACTION_RESULTS = frozenset(
{"success", "failure", "error", "encrypt", "block", "ignore", "override", "cancel"}
)

# Audit log message verbs
VERB_DECISION = "decision"
VERB_POLICY_CRUD = "policy crud"
VERB_REWRAP = "rewrap"


@dataclass
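A side effect of replacing the `Literal` aliases with `frozenset` constants is that the vocabulary can now be checked at runtime, not only by a static type checker. A minimal sketch of such a check; the `validate_event` helper and its field names are hypothetical, not part of this file:

    from typing import Any

    def validate_event(data: dict[str, Any]) -> bool:
        """Hypothetical helper: True if the event uses known audit vocabulary."""
        return (
            data.get("object_type") in OBJECT_TYPES
            and data.get("action_type") in ACTION_TYPES
            and data.get("action_result") in ACTION_RESULTS
        )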
@@ -354,9 +361,11 @@ def observed_skew(self) -> float | None:
return None

# Convert collection time to UTC for comparison
# astimezone() handles both naive (assumes local) and aware datetimes
collection_t = self.collection_time
collection_utc = collection_t.astimezone(UTC)
if collection_t.tzinfo is None:
collection_utc = collection_t.astimezone(UTC)
else:
collection_utc = collection_t.astimezone(UTC)

if event_t.tzinfo is None:
event_utc = event_t.replace(tzinfo=UTC)
@@ -479,7 +488,7 @@ def matches_rewrap(
Returns:
True if event matches all specified criteria
"""
if self.msg != "rewrap":
if self.msg != VERB_REWRAP:
return False
if result is not None and self.action_result != result:
return False
@@ -513,7 +522,7 @@ def matches_policy_crud(
Returns:
True if event matches all specified criteria
"""
if self.msg != "policy crud":
if self.msg != VERB_POLICY_CRUD:
return False
if result is not None and self.action_result != result:
return False
@@ -541,7 +550,7 @@ def matches_decision(
Returns:
True if event matches all specified criteria
"""
if self.msg != "decision":
if self.msg != VERB_DECISION:
return False
if result is not None and self.action_result != result:
return False
@@ -619,6 +628,7 @@ def __init__(
self._mark_counter = 0
self._threads: list[threading.Thread] = []
self._stop_event = threading.Event()
self._new_data = threading.Condition()
self._disabled = False
self._error: Exception | None = None
self.log_file_path: Path | None = None
@@ -650,18 +660,22 @@ def start(self) -> None:
self._disabled = True
return

any_file_exists = any(path.exists() for path in self.log_files.values())
if not any_file_exists:
existing_files = {
service: path for service, path in self.log_files.items() if path.exists()
}

if not existing_files:
logger.warning(
f"None of the log files exist yet: {list(self.log_files.values())}. "
f"Will wait for them to be created..."
)
existing_files = self.log_files

logger.debug(
f"Starting file-based log collection for: {list(self.log_files.keys())}"
f"Starting file-based log collection for: {list(existing_files.keys())}"
)

for service, log_path in self.log_files.items():
for service, log_path in existing_files.items():
thread = threading.Thread(
target=self._tail_file,
args=(service, log_path),
@@ -671,7 +685,7 @@ def start(self) -> None:
self._threads.append(thread)

logger.info(
f"Audit log collection started for: {', '.join(self.log_files.keys())}"
f"Audit log collection started for: {', '.join(existing_files.keys())}"
)

def stop(self) -> None:
@@ -681,6 +695,9 @@ def stop(self) -> None:

logger.debug("Stopping audit log collection")
self._stop_event.set()
# Wake any threads waiting on new data so they can exit promptly
with self._new_data:
self._new_data.notify_all()

for thread in self._threads:
if thread.is_alive():
@@ -785,6 +802,22 @@ def write_to_disk(self, path: Path) -> None:
self.log_file_written = True
logger.info(f"Wrote {len(self._buffer)} audit log entries to {path}")

def wait_for_new_data(self, timeout: float = 0.1) -> bool:
"""Wait for new log data to arrive.

Blocks until new data is appended by a tail thread, or until timeout.
More efficient than polling with time.sleep() since it wakes up
immediately when data arrives.

Args:
timeout: Maximum time to wait in seconds (default: 0.1)

Returns:
True if woken by new data, False if timed out
"""
with self._new_data:
return self._new_data.wait(timeout=timeout)
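A sketch of how a caller might consume this method, mirroring the deadline loops in the assertion helpers further down this file; the `collector` instance, the search string, and the five-second budget are illustrative, and `time` is assumed to be imported:

    # Wake promptly when a tail thread appends data, but still honor a deadline.
    deadline = time.time() + 5.0
    while time.time() < deadline:
        if any("rewrap" in e.raw_line for e in collector.get_logs()):
            break
        # Returns True if notified, False on timeout; re-check either way.
        collector.wait_for_new_data(timeout=min(deadline - time.time(), 1.0))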

def _tail_file(self, service: str, log_path: Path) -> None:
"""Background thread target that tails a log file.

@@ -808,14 +841,23 @@ def _tail_file(self, service: str, log_path: Path) -> None:
f.seek(0, 2)

while not self._stop_event.is_set():
line = f.readline()
if line:
# Batch-read all available lines before notifying
got_data = False
while True:
line = f.readline()
if not line:
break
entry = LogEntry(
timestamp=datetime.now(),
raw_line=line.rstrip(),
service_name=service,
)
self._buffer.append(entry)
got_data = True

if got_data:
with self._new_data:
self._new_data.notify_all()
else:
self._stop_event.wait(0.1)
except Exception as e:
@@ -838,6 +880,15 @@ def __init__(self, collector: AuditLogCollector | None):
"""
self._collector = collector

@property
def is_enabled(self) -> bool:
"""Check if audit log collection is enabled.

Returns:
True if collection is active, False if disabled or no collector
"""
return self._collector is not None and not self._collector._disabled
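One plausible use of this flag is short-circuiting audit assertions in tests when collection is off; the fixture name, skip message, and `assert_rewrap` arguments below are illustrative, not from this PR:

    import pytest

    def test_rewrap_audited(audit_logs):
        # Hypothetical pytest usage: skip rather than fail when disabled.
        if not audit_logs.is_enabled:
            pytest.skip("audit log collection disabled in this environment")
        audit_logs.assert_rewrap(result="success")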

def mark(self, label: str) -> str:
"""Mark a timestamp for later correlation.

@@ -927,7 +978,7 @@ def assert_contains(
matching: list[LogEntry] = []
logs: list[LogEntry] = []

while time.time() - start_time < timeout:
while True:
logs = self._collector.get_logs(since=since)
matching = [log for log in logs if regex.search(log.raw_line)]

@@ -939,8 +990,11 @@
)
return matching

# Sleep briefly before checking again
time.sleep(0.1)
remaining = timeout - (time.time() - start_time)
if remaining <= 0:
break
# Wait for new data or timeout
self._collector.wait_for_new_data(timeout=min(remaining, 1.0))

# Timeout expired, raise error if we don't have enough matches
timeout_time = datetime.now()
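The same deadline arithmetic recurs in `get_parsed_audit_logs`, `assert_rewrap`, `assert_policy_crud`, and `assert_decision_v2` below. Distilled into a single hypothetical helper (not part of this file), the pattern looks like this:

    def wait_for(predicate, collector, timeout: float = 5.0) -> list:
        """Hypothetical distillation of the retry loop in the assert_* helpers."""
        start = time.time()
        while True:
            found = [e for e in collector.get_logs() if predicate(e)]
            if found:
                return found
            remaining = timeout - (time.time() - start)
            if remaining <= 0:
                return []
            # Block until a tail thread signals new data, capped at 1s so the
            # loop still re-checks the deadline regularly.
            collector.wait_for_new_data(timeout=min(remaining, 1.0))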
@@ -1141,7 +1195,7 @@ def parse_audit_log(

# Verify msg is one of the known audit verbs
msg = data.get("msg", "")
if msg not in ("decision", "policy crud", "rewrap"):
if msg not in (VERB_DECISION, VERB_POLICY_CRUD, VERB_REWRAP):
return None

event = ParsedAuditEvent(
@@ -1185,7 +1239,7 @@ def get_parsed_audit_logs(

# Wait a bit for logs to arrive
start_time = time.time()
while time.time() - start_time < timeout:
while True:
logs = self._collector.get_logs(since=since)
parsed = []
for entry in logs:
@@ -1194,7 +1248,11 @@
parsed.append(event)
if parsed:
return parsed
time.sleep(0.1)

remaining = timeout - (time.time() - start_time)
if remaining <= 0:
break
self._collector.wait_for_new_data(timeout=min(remaining, 1.0))

return []

@@ -1245,7 +1303,7 @@ def assert_rewrap(
matching: list[ParsedAuditEvent] = []
all_logs: list[LogEntry] = []

while time.time() - start_time < timeout:
while True:
all_logs = self._collector.get_logs(since=since)
matching = []

@@ -1267,7 +1325,10 @@
)
return matching

time.sleep(0.1)
remaining = timeout - (time.time() - start_time)
if remaining <= 0:
break
self._collector.wait_for_new_data(timeout=min(remaining, 1.0))

# Build detailed error message
timeout_time = datetime.now()
@@ -1426,7 +1487,7 @@ def assert_policy_crud(
matching: list[ParsedAuditEvent] = []
all_logs: list[LogEntry] = []

while time.time() - start_time < timeout:
while True:
all_logs = self._collector.get_logs(since=since)
matching = []

@@ -1447,7 +1508,10 @@
)
return matching

time.sleep(0.1)
remaining = timeout - (time.time() - start_time)
if remaining <= 0:
break
self._collector.wait_for_new_data(timeout=min(remaining, 1.0))

# Build detailed error message
timeout_time = datetime.now()
@@ -1586,7 +1650,7 @@ def assert_decision_v2(
matching: list[ParsedAuditEvent] = []
all_logs: list[LogEntry] = []

while time.time() - start_time < timeout:
while True:
all_logs = self._collector.get_logs(since=since)
matching = []

@@ -1608,7 +1672,10 @@
)
return matching

time.sleep(0.1)
remaining = timeout - (time.time() - start_time)
if remaining <= 0:
break
self._collector.wait_for_new_data(timeout=min(remaining, 1.0))

# Build detailed error message
timeout_time = datetime.now()