Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
314 changes: 228 additions & 86 deletions fluxon_py/tests/test_api_chan_mpmc/test_api_chan_mpmc_base.py

Large diffs are not rendered by default.

Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,6 @@
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple

import etcd3

# Ensure absolute imports work when running this file directly
import os as _os
import sys as _sys
Expand All @@ -37,14 +35,13 @@
)
from fluxon_py.tests.test_lib import ( # noqa: E402
CHAN_CONFIG_TEST,
ETCD_HOST,
ETCD_PORT,
TEST_TIMEOUT_SECONDS,
setup_test_environment,
new_test_consumer,
new_test_producer,
load_test_fluxon_cluster_name,
run_with_argmatrix,
etcd_control_call_with_retry as _etcd_call_with_retry,
)
from fluxon_py.api_ext_chan import ( # noqa: E402
_new_unique_lock_key,
Expand Down Expand Up @@ -267,7 +264,10 @@ def wait_for_processes(processes: List[Tuple[str, subprocess.Popen, str]]) -> No
if proc.returncode != 0:
raise RuntimeError(
f"Process {process_type} failed (log: {log_file}),"
f" return code: {proc.returncode}"
f" return code: {proc.returncode}\n"
"--- child log tail ---\n"
f"{_read_log_tail(log_file)}\n"
"--- end child log tail ---"
)


Expand Down Expand Up @@ -448,27 +448,29 @@ def consumer_fairness_tolerance(total_consumed: int) -> int:
def clean_namespace() -> None:
if LOG_DIR.exists():
shutil.rmtree(LOG_DIR)
with etcd3.client(ETCD_HOST, ETCD_PORT) as etcd_client:

def _clean(etcd_client: Any) -> None:
# Delete unique mapping and its lock key to keep this scenario deterministic across reruns.
etcd_client.delete(_new_unique_mapping_key(CHANNEL_KEY))
etcd_client.delete(_new_unique_lock_key(CHANNEL_KEY))
etcd_client.delete_prefix("/mpmc_channels")
etcd_client.delete_prefix("/channels")
# Delete all producer done keys with correct format
for p_idx in range(PRODUCER_COUNT):
producer_id = f"P{p_idx}"
etcd_client.delete(f"{PRODUCER_DONE_KEY}_{producer_id}")

_etcd_call_with_retry("clean quick fair consume namespace", _clean)


def reset_producer_done_flag() -> None:
with etcd3.client(ETCD_HOST, ETCD_PORT) as etcd_client:
# Delete legacy unsuffixed key (if any)
def _reset(etcd_client: Any) -> None:
etcd_client.delete(PRODUCER_DONE_KEY)
# Delete all per-producer done keys
for p_idx in range(PRODUCER_COUNT):
producer_id = f"P{p_idx}"
etcd_client.delete(f"{PRODUCER_DONE_KEY}_{producer_id}")

_etcd_call_with_retry("reset producer done flags", _reset)


def _read_log_tail(path: str, *, max_lines: int = 80) -> str:
try:
Expand Down Expand Up @@ -505,31 +507,34 @@ def _wait_unique_key_mapping(
bootstrap_log: str,
) -> str:
deadline = time.time() + float(timeout_seconds)
with etcd3.client(ETCD_HOST, ETCD_PORT) as etcd_client:
while time.time() < deadline:
bootstrap_state = _bootstrap_process_state(
bootstrap_proc=bootstrap_proc,
bootstrap_log=bootstrap_log,
mapping_key = _new_unique_mapping_key(CHANNEL_KEY)
while time.time() < deadline:
bootstrap_state = _bootstrap_process_state(
bootstrap_proc=bootstrap_proc,
bootstrap_log=bootstrap_log,
)
if bootstrap_state is not None:
raise RuntimeError(
"Bootstrap producer exited before publishing channel mapping: "
f"unique_key={CHANNEL_KEY!r} {bootstrap_state}"
)
if bootstrap_state is not None:
raise RuntimeError(
"Bootstrap producer exited before publishing channel mapping: "
f"unique_key={CHANNEL_KEY!r} {bootstrap_state}"
)
value, _ = etcd_client.get(_new_unique_mapping_key(CHANNEL_KEY))
if value is not None:
try:
chan_id = value.decode("utf-8")
except Exception as err: # noqa: BLE001
raise RuntimeError(
f"Invalid channel mapping value for unique_key={CHANNEL_KEY!r}: {value!r}, err={err}"
) from None
if chan_id.isdigit():
return chan_id
value = _etcd_call_with_retry(
f"wait channel mapping unique_key={CHANNEL_KEY!r}",
lambda etcd_client: etcd_client.get(mapping_key),
)
if value is not None:
try:
chan_id = value.decode("utf-8")
except Exception as err: # noqa: BLE001
raise RuntimeError(
f"Invalid channel mapping for unique_key={CHANNEL_KEY!r}: {chan_id!r} (expected digit-only chan_id)"
)
time.sleep(0.2)
f"Invalid channel mapping value for unique_key={CHANNEL_KEY!r}: {value!r}, err={err}"
) from None
if chan_id.isdigit():
return chan_id
raise RuntimeError(
f"Invalid channel mapping for unique_key={CHANNEL_KEY!r}: {chan_id!r} (expected digit-only chan_id)"
)
time.sleep(0.2)
bootstrap_state = _bootstrap_process_state(
bootstrap_proc=bootstrap_proc,
bootstrap_log=bootstrap_log,
Expand Down Expand Up @@ -610,8 +615,11 @@ def run_producer(env, args: Dict[str, Any]) -> None:
# Close returns Result[OkNone, ApiError]; consume explicitly
producer.close().unwrap()
finally:
with etcd3.client(ETCD_HOST, ETCD_PORT) as etcd_client:
etcd_client.put(f"{PRODUCER_DONE_KEY}_{producer_id}", str(produced))
done_key = f"{PRODUCER_DONE_KEY}_{producer_id}"
_etcd_call_with_retry(
f"mark producer {producer_id} done",
lambda etcd_client: etcd_client.put(done_key, str(produced).encode()),
)
finally:
configure_backend(env, backend_type=prev_type, backend_ip=prev_ip)

Expand Down Expand Up @@ -647,74 +655,76 @@ def run_consumer(env, args: Dict[str, Any]) -> None:
start_time = time.monotonic()
max_deadline = start_time + MAX_CONSUMER_RUNTIME
try:
with etcd3.client(ETCD_HOST, ETCD_PORT) as etcd_client:
all_producers_done = False
last_producer_check = time.monotonic()
producer_check_interval = 0.5
all_producers_done = False
last_producer_check = time.monotonic()
producer_check_interval = 0.5

while True:
now = time.monotonic()

if now - last_producer_check >= producer_check_interval:
if not all_producers_done:
done_count = 0
for p_idx in range(PRODUCER_COUNT):
producer_id = f"P{p_idx}"
done_key = f"{PRODUCER_DONE_KEY}_{producer_id}"
value = _etcd_call_with_retry(
f"read producer {producer_id} done flag",
lambda etcd_client, key=done_key: etcd_client.get(key),
)
if value is not None:
done_count += 1
all_producers_done = done_count == PRODUCER_COUNT
if all_producers_done:
msg = f"🎉 Consumer {consumer_id}: All {PRODUCER_COUNT} producers done! consumed={consumed}"
print(msg, file=sys.stdout, flush=True)
last_producer_check = now

while True:
now = time.monotonic()
res = consumer.get_data(batch_size=1, try_time=1)

if now - last_producer_check >= producer_check_interval:
if not all_producers_done:
done_count = 0
for p_idx in range(PRODUCER_COUNT):
producer_id = f"P{p_idx}"
value, _metadata = etcd_client.get(f"{PRODUCER_DONE_KEY}_{producer_id}")
if value is not None:
done_count += 1
all_producers_done = done_count == PRODUCER_COUNT
if all_producers_done:
import sys
msg = f"🎉 Consumer {consumer_id}: All {PRODUCER_COUNT} producers done! consumed={consumed}"
print(msg, file=sys.stdout, flush=True)
last_producer_check = now

res = consumer.get_data(batch_size=1, try_time=1)

if res is None:
now = time.monotonic()
if res is None:
now = time.monotonic()
if now >= max_deadline:
raise RuntimeError(
f"Consumer {consumer_id} get_data returned None unexpectedly"
)
elif res.is_ok():
success = res.unwrap()
now = time.monotonic()
if isinstance(success, list) and success:
consumed += 1
last_activity = now
if isinstance(success[0], dict):
msg_key = str(success[0]["unique_id"])
if msg_key.startswith("quick-msg-"):
parts = msg_key.split("-")
if len(parts) >= 3:
producer_id_str = parts[2]
producer_consumed_counts[producer_id_str] = producer_consumed_counts.get(producer_id_str, 0) + 1
else:
err = res.unwrap_error()
now = time.monotonic()
if isinstance(err, MessageConsumptionNoNewMessageError):
if now >= max_deadline:
raise RuntimeError(
f"Consumer {consumer_id} get_data returned None unexpectedly"
f"Consumer {consumer_id} exceeded max runtime with no new message"
)
elif res.is_ok():
success = res.unwrap()
now = time.monotonic()
if isinstance(success, list) and success:
consumed += 1
last_activity = now
if isinstance(success[0], dict):
msg_key = str(success[0]["unique_id"])
if msg_key.startswith("quick-msg-"):
parts = msg_key.split("-")
if len(parts) >= 3:
producer_id_str = parts[2]
producer_consumed_counts[producer_id_str] = producer_consumed_counts.get(producer_id_str, 0) + 1
else:
err = res.unwrap_error()
now = time.monotonic()
if isinstance(err, MessageConsumptionNoNewMessageError):
if now >= max_deadline:
raise RuntimeError(
f"Consumer {consumer_id} exceeded max runtime with no new message"
)
else:
raise RuntimeError(
f"Consumer {consumer_id} get_data failed: {err}"
)

idle_time = now - last_activity
if all_producers_done and idle_time >= idle_timeout:
print(
f"✅ Consumer {consumer_id} exiting: all producers done and idle for {idle_time:.1f}s (consumed {consumed} messages)"
)
break
if now >= max_deadline:
raise RuntimeError(
f"Consumer {consumer_id} exceeded max runtime"
f"Consumer {consumer_id} get_data failed: {err}"
)
time.sleep(PRODUCER_DONE_POLL_INTERVAL)

idle_time = now - last_activity
if all_producers_done and idle_time >= idle_timeout:
print(
f"✅ Consumer {consumer_id} exiting: all producers done and idle for {idle_time:.1f}s (consumed {consumed} messages)"
)
break
if now >= max_deadline:
raise RuntimeError(
f"Consumer {consumer_id} exceeded max runtime"
)
time.sleep(PRODUCER_DONE_POLL_INTERVAL)
if consumed == 0:
raise AssertionError(
f"Consumer {consumer_id} did not receive any message"
Expand Down
Loading
Loading