Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
175 changes: 117 additions & 58 deletions xtest/audit_logs.py
Original file line number Diff line number Diff line change
Expand Up @@ -195,7 +195,6 @@ def record_sample(
event_time: When the event occurred (service clock, from JSON)
"""
# Convert both to UTC for comparison
# astimezone() handles both naive (assumes local) and aware datetimes
collection_utc = collection_time.astimezone(UTC)

if event_time.tzinfo is None:
Expand Down Expand Up @@ -265,43 +264,47 @@ def __repr__(self) -> str:


# Audit event constants from platform/service/logger/audit/constants.go
# These are defined as Literal types for static type checking
ObjectType = Literal[
"subject_mapping",
"resource_mapping",
"attribute_definition",
"attribute_value",
"obligation_definition",
"obligation_value",
"obligation_trigger",
"namespace",
"condition_set",
"kas_registry",
"kas_attribute_namespace_assignment",
"kas_attribute_definition_assignment",
"kas_attribute_value_assignment",
"key_object",
"entity_object",
"resource_mapping_group",
"public_key",
"action",
"registered_resource",
"registered_resource_value",
"key_management_provider_config",
"kas_registry_keys",
"kas_attribute_definition_key_assignment",
"kas_attribute_value_key_assignment",
"kas_attribute_namespace_key_assignment",
"namespace_certificate",
]

ActionType = Literal["create", "read", "update", "delete", "rewrap", "rotate"]

ActionResult = Literal[
"success", "failure", "error", "encrypt", "block", "ignore", "override", "cancel"
]

AuditVerb = Literal["decision", "policy crud", "rewrap"]
# Known audit object types from platform/service/logger/audit/constants.go,
# kept as runtime-checkable frozensets (membership tests on values parsed
# from log JSON) rather than static-only typing.Literal aliases.
OBJECT_TYPES: frozenset[str] = frozenset(
    {
        "subject_mapping",
        "resource_mapping",
        "attribute_definition",
        "attribute_value",
        "obligation_definition",
        "obligation_value",
        "obligation_trigger",
        "namespace",
        "condition_set",
        "kas_registry",
        "kas_attribute_namespace_assignment",
        "kas_attribute_definition_assignment",
        "kas_attribute_value_assignment",
        "key_object",
        "entity_object",
        "resource_mapping_group",
        "public_key",
        "action",
        "registered_resource",
        "registered_resource_value",
        "key_management_provider_config",
        "kas_registry_keys",
        "kas_attribute_definition_key_assignment",
        "kas_attribute_value_key_assignment",
        "kas_attribute_namespace_key_assignment",
        "namespace_certificate",
    }
)

# Known action types an audit event may record.
ACTION_TYPES: frozenset[str] = frozenset({"create", "read", "update", "delete", "rewrap", "rotate"})

# Known outcomes an audit event may report for an action.
ACTION_RESULTS: frozenset[str] = frozenset(
    {"success", "failure", "error", "encrypt", "block", "ignore", "override", "cancel"}
)

# Audit log message verbs (the "msg" field of an audit log line)
VERB_DECISION = "decision"
VERB_POLICY_CRUD = "policy crud"
VERB_REWRAP = "rewrap"


@dataclass
Expand Down Expand Up @@ -354,7 +357,6 @@ def observed_skew(self) -> float | None:
return None

# Convert collection time to UTC for comparison
# astimezone() handles both naive (assumes local) and aware datetimes
collection_t = self.collection_time
collection_utc = collection_t.astimezone(UTC)

Expand Down Expand Up @@ -479,7 +481,7 @@ def matches_rewrap(
Returns:
True if event matches all specified criteria
"""
if self.msg != "rewrap":
if self.msg != VERB_REWRAP:
return False
if result is not None and self.action_result != result:
return False
Expand Down Expand Up @@ -513,7 +515,7 @@ def matches_policy_crud(
Returns:
True if event matches all specified criteria
"""
if self.msg != "policy crud":
if self.msg != VERB_POLICY_CRUD:
return False
if result is not None and self.action_result != result:
return False
Expand Down Expand Up @@ -541,7 +543,7 @@ def matches_decision(
Returns:
True if event matches all specified criteria
"""
if self.msg != "decision":
if self.msg != VERB_DECISION:
return False
if result is not None and self.action_result != result:
return False
Expand Down Expand Up @@ -619,6 +621,7 @@ def __init__(
self._mark_counter = 0
self._threads: list[threading.Thread] = []
self._stop_event = threading.Event()
self._new_data = threading.Condition()
self._disabled = False
self._error: Exception | None = None
self.log_file_path: Path | None = None
Expand Down Expand Up @@ -650,8 +653,11 @@ def start(self) -> None:
self._disabled = True
return

any_file_exists = any(path.exists() for path in self.log_files.values())
if not any_file_exists:
existing_files = {
service: path for service, path in self.log_files.items() if path.exists()
}

if not existing_files:
logger.warning(
f"None of the log files exist yet: {list(self.log_files.values())}. "
f"Will wait for them to be created..."
Expand Down Expand Up @@ -681,6 +687,9 @@ def stop(self) -> None:

logger.debug("Stopping audit log collection")
self._stop_event.set()
# Wake any threads waiting on new data so they can exit promptly
with self._new_data:
self._new_data.notify_all()

for thread in self._threads:
if thread.is_alive():
Expand Down Expand Up @@ -785,6 +794,22 @@ def write_to_disk(self, path: Path) -> None:
self.log_file_written = True
logger.info(f"Wrote {len(self._buffer)} audit log entries to {path}")

def wait_for_new_data(self, timeout: float = 0.1) -> bool:
    """Block until a tail thread signals that new log data arrived.

    Preferable to polling with time.sleep(): the caller wakes as soon
    as a tail thread appends entries and notifies the condition, instead
    of always sleeping the full interval.

    Args:
        timeout: Maximum seconds to block (default: 0.1)

    Returns:
        True when woken by a notification, False when the wait timed out
    """
    cond = self._new_data
    with cond:
        return cond.wait(timeout=timeout)

def _tail_file(self, service: str, log_path: Path) -> None:
"""Background thread target that tails a log file.

Expand All @@ -808,14 +833,23 @@ def _tail_file(self, service: str, log_path: Path) -> None:
f.seek(0, 2)

while not self._stop_event.is_set():
line = f.readline()
if line:
# Batch-read all available lines before notifying
got_data = False
while True:
line = f.readline()
if not line:
break
entry = LogEntry(
timestamp=datetime.now(),
raw_line=line.rstrip(),
service_name=service,
)
self._buffer.append(entry)
got_data = True

if got_data:
with self._new_data:
self._new_data.notify_all()
else:
self._stop_event.wait(0.1)
except Exception as e:
Expand All @@ -838,6 +872,15 @@ def __init__(self, collector: AuditLogCollector | None):
"""
self._collector = collector

@property
def is_enabled(self) -> bool:
    """Whether audit log collection is currently active.

    Returns:
        True when a collector is attached and has not been disabled,
        False otherwise.
    """
    collector = self._collector
    if collector is None:
        return False
    return not collector._disabled

def mark(self, label: str) -> str:
"""Mark a timestamp for later correlation.

Expand Down Expand Up @@ -927,7 +970,7 @@ def assert_contains(
matching: list[LogEntry] = []
logs: list[LogEntry] = []

while time.time() - start_time < timeout:
while True:
logs = self._collector.get_logs(since=since)
matching = [log for log in logs if regex.search(log.raw_line)]

Expand All @@ -939,8 +982,11 @@ def assert_contains(
)
return matching

# Sleep briefly before checking again
time.sleep(0.1)
remaining = timeout - (time.time() - start_time)
if remaining <= 0:
break
# Wait for new data or timeout
self._collector.wait_for_new_data(timeout=min(remaining, 1.0))

# Timeout expired, raise error if we don't have enough matches
timeout_time = datetime.now()
Expand Down Expand Up @@ -1141,7 +1187,7 @@ def parse_audit_log(

# Verify msg is one of the known audit verbs
msg = data.get("msg", "")
if msg not in ("decision", "policy crud", "rewrap"):
if msg not in (VERB_DECISION, VERB_POLICY_CRUD, VERB_REWRAP):
return None

event = ParsedAuditEvent(
Expand Down Expand Up @@ -1185,7 +1231,7 @@ def get_parsed_audit_logs(

# Wait a bit for logs to arrive
start_time = time.time()
while time.time() - start_time < timeout:
while True:
logs = self._collector.get_logs(since=since)
parsed = []
for entry in logs:
Expand All @@ -1194,7 +1240,11 @@ def get_parsed_audit_logs(
parsed.append(event)
if parsed:
return parsed
time.sleep(0.1)

remaining = timeout - (time.time() - start_time)
if remaining <= 0:
break
self._collector.wait_for_new_data(timeout=min(remaining, 1.0))

return []

Expand Down Expand Up @@ -1245,7 +1295,7 @@ def assert_rewrap(
matching: list[ParsedAuditEvent] = []
all_logs: list[LogEntry] = []

while time.time() - start_time < timeout:
while True:
all_logs = self._collector.get_logs(since=since)
matching = []

Expand All @@ -1267,7 +1317,10 @@ def assert_rewrap(
)
return matching

time.sleep(0.1)
remaining = timeout - (time.time() - start_time)
if remaining <= 0:
break
self._collector.wait_for_new_data(timeout=min(remaining, 1.0))

# Build detailed error message
timeout_time = datetime.now()
Expand Down Expand Up @@ -1426,7 +1479,7 @@ def assert_policy_crud(
matching: list[ParsedAuditEvent] = []
all_logs: list[LogEntry] = []

while time.time() - start_time < timeout:
while True:
all_logs = self._collector.get_logs(since=since)
matching = []

Expand All @@ -1447,7 +1500,10 @@ def assert_policy_crud(
)
return matching

time.sleep(0.1)
remaining = timeout - (time.time() - start_time)
if remaining <= 0:
break
self._collector.wait_for_new_data(timeout=min(remaining, 1.0))

# Build detailed error message
timeout_time = datetime.now()
Expand Down Expand Up @@ -1586,7 +1642,7 @@ def assert_decision_v2(
matching: list[ParsedAuditEvent] = []
all_logs: list[LogEntry] = []

while time.time() - start_time < timeout:
while True:
all_logs = self._collector.get_logs(since=since)
matching = []

Expand All @@ -1608,7 +1664,10 @@ def assert_decision_v2(
)
return matching

time.sleep(0.1)
remaining = timeout - (time.time() - start_time)
if remaining <= 0:
break
self._collector.wait_for_new_data(timeout=min(remaining, 1.0))

# Build detailed error message
timeout_time = datetime.now()
Expand Down
38 changes: 38 additions & 0 deletions xtest/fixtures/attributes.py
Original file line number Diff line number Diff line change
Expand Up @@ -473,3 +473,41 @@ def ns_and_value_kas_grants_and(
otdfctl.key_assign_ns(kas_key_ns, temp_namespace)

return allof


# Default KAS RSA key fixture for tests that need explicit RSA wrapping
@pytest.fixture(scope="module")
def attribute_default_rsa(
    otdfctl: OpentdfCommandLineTool,
    kas_entry_default: abac.KasEntry,
    kas_public_key_r1: abac.KasPublicKey,
    otdf_client_scs: abac.SubjectConditionSet,
    temporary_namespace: abac.Namespace,
) -> abac.Attribute:
    """Attribute with RSA key mapping on default KAS.

    Use this fixture when tests need to ensure RSA wrapping is used,
    regardless of what base_key may be configured on the platform.
    This prevents test order sensitivity when base_key tests run.
    """
    attr = otdfctl.attribute_create(
        temporary_namespace, "defaultrsa", abac.AttributeRule.ANY_OF, ["wrapped"]
    )
    assert attr.values
    (value,) = attr.values
    assert value.value == "wrapped"

    # Assign to all clientIds = opentdf-sdk
    mapping = otdfctl.scs_map(otdf_client_scs, value)
    assert mapping.attribute_value.value == "wrapped"

    # Assign RSA key on default KAS (legacy grant when the platform
    # lacks the key_management feature)
    if "key_management" in tdfs.PlatformFeatureSet().features:
        kas_key = otdfctl.kas_registry_create_public_key_only(
            kas_entry_default, kas_public_key_r1
        )
        otdfctl.key_assign_value(kas_key, value)
    else:
        otdfctl.grant_assign_value(kas_entry_default, value)

    return attr
Loading
Loading