diff --git a/core_engine/src/outputs/asr_sink.rs b/core_engine/src/outputs/asr_sink.rs index afcfba7..253be7b 100644 --- a/core_engine/src/outputs/asr_sink.rs +++ b/core_engine/src/outputs/asr_sink.rs @@ -178,15 +178,20 @@ impl InputState { } fn poll(&mut self, callback: &mut dyn AsrSinkCallback) -> Result { - self.poll_with_limit(callback, None) - } - - fn poll_stop(&mut self, callback: &mut dyn AsrSinkCallback) -> Result { let channels = MASTER_FORMAT.channels.max(1) as usize; let bounded_samples = self.consumer.occupied_len() / channels * channels; self.poll_with_limit(callback, Some(bounded_samples)) } + fn stop_now(&mut self) { + self.drained_master.clear(); + self.converted_output.clear(); + self.pending_output.clear(); + self.pending_offset = 0; + self.quantized_output.clear(); + self.metrics.pending_frames.store(0, Ordering::Relaxed); + } + fn poll_with_limit( &mut self, callback: &mut dyn AsrSinkCallback, @@ -323,7 +328,7 @@ impl AsrSinkMetrics { pub struct AsrSink { stop: Arc, - handle: Option>>, + handle: Option, AsrSinkError>>>, metrics: Arc, } @@ -387,13 +392,13 @@ impl AsrSink { let stop = Arc::new(AtomicBool::new(false)); let stop_thread = stop.clone(); - let handle = thread::spawn(move || -> Result<(), AsrSinkError> { + let handle = thread::spawn(move || -> Result, AsrSinkError> { let idle_sleep = Duration::from_micros(200); loop { if stop_thread.load(Ordering::Relaxed) { for state in &mut states { - let _ = state.poll_stop(&mut *callback)?; + state.stop_now(); } break; } @@ -408,7 +413,13 @@ impl AsrSink { } } - Ok(()) + Ok(states + .into_iter() + .map(|state| AsrSinkInput { + input_id: state.input_id, + consumer: state.consumer, + }) + .collect()) }); Ok(Self { @@ -430,7 +441,7 @@ impl AsrSink { self.metrics.snapshot() } - pub fn stop(&mut self) -> Result<(), AsrSinkError> { + pub fn stop(&mut self) -> Result, AsrSinkError> { self.stop.store(true, Ordering::Relaxed); let Some(handle) = self.handle.take() else { return Err(AsrSinkError::AlreadyStopped); @@ -672,8 +683,7 @@ mod tests { .expect("spawn sink"); sink.stop().expect("first stop"); - let err = sink.stop().expect_err("second stop should fail"); - assert!(matches!(err, AsrSinkError::AlreadyStopped)); + assert!(matches!(sink.stop(), Err(AsrSinkError::AlreadyStopped))); } #[test] diff --git a/core_engine/src/outputs/wav_file.rs b/core_engine/src/outputs/wav_file.rs index 6f9e5eb..47f29c1 100644 --- a/core_engine/src/outputs/wav_file.rs +++ b/core_engine/src/outputs/wav_file.rs @@ -91,7 +91,7 @@ pub struct WavSinkMetricsSnapshot { pub struct WavFileOutput { stop: Arc, - handle: Option>>, + handle: Option, WavOutputError>>>, metrics: Arc, } @@ -123,7 +123,7 @@ impl WavFileOutput { let channels = format.channels.max(1) as u64; let frame_channels = format.channels.max(1) as usize; - let handle = thread::spawn(move || -> Result<(), WavOutputError> { + let handle = thread::spawn(move || -> Result, WavOutputError> { let mut writer = hound::WavWriter::new(writer, spec)?; let idle_sleep = Duration::from_micros(200); let mut consumers = consumers; @@ -207,7 +207,7 @@ impl WavFileOutput { metrics_thread .finalize .record(duration_to_u32_us(finalize_start.elapsed())); - Ok(()) + Ok(consumers) }); Ok(Self { @@ -305,7 +305,7 @@ impl WavFileOutput { self.metrics.snapshot() } - pub fn stop(&mut self) -> Result<(), WavOutputError> { + pub fn stop(&mut self) -> Result, WavOutputError> { self.stop.store(true, Ordering::Relaxed); let Some(handle) = self.handle.take() else { return Err(WavOutputError::AlreadyStopped); @@ -539,8 +539,7 @@ mod tests { .expect("spawn wav"); wav.stop().expect("first stop"); - let err = wav.stop().expect_err("second stop should fail"); - assert!(matches!(err, WavOutputError::AlreadyStopped)); + assert!(matches!(wav.stop(), Err(WavOutputError::AlreadyStopped))); } #[test] diff --git a/macloop/__init__.py b/macloop/__init__.py index 40c279f..571f8de 100644 --- a/macloop/__init__.py +++ b/macloop/__init__.py @@ -68,6 +68,25 @@ def _raise_on_unexpected_kwargs(name: str, kwargs: dict[str, Any]) -> None: raise TypeError(f"{name} got unexpected keyword arguments: {unexpected}") +def _close_backend_with_optional_engine(backend: Any, engine_backend: Any | None) -> None: + if engine_backend is not None: + try: + backend.close(engine_backend) + return + except TypeError as exc: + try: + backend.close() + return + except TypeError: + raise exc + + close_no_restore = getattr(backend, "close_no_restore", None) + if close_no_restore is not None: + close_no_restore() + else: + backend.close() + + @dataclass(frozen=True, slots=True) class AudioChunk: route_id: str @@ -506,13 +525,16 @@ def close(self) -> None: return err: Optional[Exception] = None + engine = self._engine_ref() try: - self._backend.close() + _close_backend_with_optional_engine( + self._backend, + engine._backend if engine is not None else None, + ) except Exception as exc: err = exc finally: self._closed = True - engine = self._engine_ref() if engine is not None: engine._release_routes(self._route_ids) _drop_oldest_put(self._queue, _STOP) @@ -634,13 +656,16 @@ def close(self) -> None: return err: Optional[Exception] = None + engine = self._engine_ref() try: - self._backend.close() + _close_backend_with_optional_engine( + self._backend, + engine._backend if engine is not None else None, + ) except Exception as exc: err = exc finally: self._closed = True - engine = self._engine_ref() if engine is not None: engine._release_routes(self._route_ids) diff --git a/python_ffi/src/lib.rs b/python_ffi/src/lib.rs index b20a7c2..0a11e4c 100644 --- a/python_ffi/src/lib.rs +++ b/python_ffi/src/lib.rs @@ -904,7 +904,12 @@ impl PyAsrSinkBackend { Ok(out) } - fn close(&mut self, py: Python<'_>) -> PyResult<()> { + #[pyo3(signature = (engine=None))] + fn close( + &mut self, + py: Python<'_>, + mut engine: Option>, + ) -> PyResult<()> { let Some(mut sink) = self.sink.take() else { return Ok(()); }; @@ -915,15 +920,41 @@ impl PyAsrSinkBackend { (stop_result, final_stats) }); self.final_stats = Some(final_stats); - stop_result + let route_consumers = stop_result .map_err(|e| PyRuntimeError::new_err(format!("failed to stop asr sink: {e}")))?; + + if let Some(engine) = engine.as_mut() { + engine + .restore_route_consumers( + route_consumers + .into_iter() + .map(|input| (input.input_id, input.consumer)) + .collect(), + ) + .map_err(|e| PyRuntimeError::new_err(format!("failed to restore asr sink routes: {e}")))?; + } + Ok(()) + } + + fn close_no_restore(&mut self, py: Python<'_>) -> PyResult<()> { + let Some(mut sink) = self.sink.take() else { + return Ok(()); + }; + + let (stop_result, final_stats) = py.detach(move || { + let stop_result = sink.stop().map_err(|e| e.to_string()); + let final_stats = sink.stats(); + (stop_result, final_stats) + }); + self.final_stats = Some(final_stats); + let _ = stop_result; Ok(()) } } impl Drop for PyAsrSinkBackend { fn drop(&mut self) { - let _ = Python::try_attach(|py| self.close(py)); + let _ = Python::try_attach(|py| self.close_no_restore(py)); } } @@ -931,6 +962,7 @@ impl Drop for PyAsrSinkBackend { struct PyWavSinkBackend { sink: Option, final_stats: Option, + route_ids: Vec, } #[pymethods] @@ -944,10 +976,16 @@ impl PyWavSinkBackend { Py::new(py, PyWavSinkStats::from_snapshot(py, snapshot)?) } - fn close(&mut self, py: Python<'_>) -> PyResult<()> { + #[pyo3(signature = (engine=None))] + fn close( + &mut self, + py: Python<'_>, + mut engine: Option>, + ) -> PyResult<()> { let Some(mut sink) = self.sink.take() else { return Ok(()); }; + let route_ids = self.route_ids.clone(); let (stop_result, final_stats) = py.detach(move || { let stop_result = sink.stop().map_err(|e| e.to_string()); @@ -955,15 +993,36 @@ impl PyWavSinkBackend { (stop_result, final_stats) }); self.final_stats = Some(final_stats); - stop_result + let consumers = stop_result .map_err(|e| PyRuntimeError::new_err(format!("failed to stop wav sink: {e}")))?; + + if let Some(engine) = engine.as_mut() { + engine + .restore_route_consumers(route_ids.into_iter().zip(consumers).collect()) + .map_err(|e| PyRuntimeError::new_err(format!("failed to restore wav sink routes: {e}")))?; + } + Ok(()) + } + + fn close_no_restore(&mut self, py: Python<'_>) -> PyResult<()> { + let Some(mut sink) = self.sink.take() else { + return Ok(()); + }; + + let (stop_result, final_stats) = py.detach(move || { + let stop_result = sink.stop().map_err(|e| e.to_string()); + let final_stats = sink.stats(); + (stop_result, final_stats) + }); + self.final_stats = Some(final_stats); + let _ = stop_result; Ok(()) } } impl Drop for PyWavSinkBackend { fn drop(&mut self) { - let _ = Python::try_attach(|py| self.close(py)); + let _ = Python::try_attach(|py| self.close_no_restore(py)); } } @@ -1539,6 +1598,7 @@ impl PyAudioEngineBackend { self.restore_stream_states(started_states); let route_consumers = self.take_route_consumers(&route_ids)?; + let route_ids_for_sink = route_ids.clone(); let master_format = self.controller.master_format(); let detached_result: DetachedWavStartResult = py.detach(move || { let consumers = route_consumers @@ -1559,6 +1619,7 @@ impl PyAudioEngineBackend { Ok(sink) => Ok(PyWavSinkBackend { sink: Some(sink), final_stats: None, + route_ids: route_ids_for_sink, }), Err((err, route_consumers)) => { if let Err(restore_err) = self.restore_route_consumers(route_consumers) { diff --git a/tests/test_e2e_synthetic.py b/tests/test_e2e_synthetic.py index 3558d3c..fa56345 100644 --- a/tests/test_e2e_synthetic.py +++ b/tests/test_e2e_synthetic.py @@ -323,3 +323,100 @@ def test_hot_synthetic_engine_close_completes_with_wav_and_asr_sinks(tmp_path: P assert stats["wav_size"] > 44 assert stats["asr_closed"] is True assert stats["wav_closed"] is True + + +def test_hot_synthetic_restart_on_single_engine_does_not_hang(tmp_path: Path) -> None: + child_code = textwrap.dedent( + """ + import json + import sys + import time + from pathlib import Path + + import macloop + + base = Path(sys.argv[1]) + engine = macloop.AudioEngine() + stream = engine.create_stream( + macloop.SyntheticSource, + frames_per_callback=160, + callback_count=1_000_000, + start_value=0.1, + step_value=0.0, + interval_ms=0, + start_delay_ms=0, + ) + route_wav = engine.route("restart_wav", stream=stream) + route_asr = engine.route("restart_asr", stream=stream) + + wav_sizes = [] + first_close_elapsed = None + final_close_elapsed = None + final_asr_closed = None + final_wav_closed = None + + for cycle in range(2): + wav_path = base / f"restart_{cycle}.wav" + wav_sink = macloop.WavSink(route=route_wav, file=wav_path) + asr_sink = macloop.AsrSink( + routes=[route_asr], + chunk_frames=320, + sample_rate=16_000, + channels=1, + sample_format="f32", + ) + time.sleep(0.05) + + if cycle == 0: + started = time.monotonic() + wav_sink.close() + asr_sink.close() + first_close_elapsed = time.monotonic() - started + else: + started = time.monotonic() + engine.close() + final_close_elapsed = time.monotonic() - started + final_asr_closed = asr_sink._closed + final_wav_closed = wav_sink._closed + + wav_sizes.append(wav_path.stat().st_size if wav_path.exists() else 0) + + print( + json.dumps( + { + "first_close_elapsed_s": round(first_close_elapsed, 6), + "final_close_elapsed_s": round(final_close_elapsed, 6), + "wav_sizes": wav_sizes, + "final_asr_closed": final_asr_closed, + "final_wav_closed": final_wav_closed, + } + ), + flush=True, + ) + """ + ) + + try: + completed = subprocess.run( + [sys.executable, "-c", child_code, str(tmp_path)], + capture_output=True, + text=True, + timeout=10.0, + check=False, + ) + except subprocess.TimeoutExpired as exc: + pytest.fail( + "single-engine restart hung for hot SyntheticSource with WavSink + AsrSink; " + f"stdout={((exc.stdout or '')[:400])!r} stderr={((exc.stderr or '')[:400])!r}" + ) + + assert completed.returncode == 0, completed.stderr + payload = completed.stdout.strip() + assert payload, completed.stderr + stats = json.loads(payload) + assert stats["first_close_elapsed_s"] < 10.0 + assert stats["final_close_elapsed_s"] < 10.0 + assert stats["wav_sizes"] == [size for size in stats["wav_sizes"] if size > 44] + assert len(stats["wav_sizes"]) == 2 + assert stats["final_asr_closed"] is True + assert stats["final_wav_closed"] is True diff --git a/tests/test_medium_e2e_real_capture.py b/tests/test_medium_e2e_real_capture.py index f40793b..616d56a 100644 --- a/tests/test_medium_e2e_real_capture.py +++ b/tests/test_medium_e2e_real_capture.py @@ -1,6 +1,7 @@ from __future__ import annotations import math +import queue import shutil import struct import subprocess @@ -90,10 +91,25 @@ def _read_float_wav(path: Path) -> tuple[tuple[int, int, int], list[float]]: def _invoke_with_timeout(fn: Callable[[], None], *, timeout_s: float, label: str) -> float: started_at = time.monotonic() - fn() + errors: queue.Queue[BaseException] = queue.Queue() + + def runner() -> None: + try: + fn() + except BaseException as exc: # pragma: no cover - exercised via callers + errors.put(exc) + + thread = threading.Thread(target=runner, name=f"timeout:{label}", daemon=True) + thread.start() + thread.join(timeout=timeout_s) + elapsed = time.monotonic() - started_at - if elapsed > timeout_s: - pytest.fail(f"{label} exceeded {timeout_s:.1f}s (took {elapsed:.3f}s)") + if thread.is_alive(): + pytest.fail(f"{label} exceeded {timeout_s:.1f}s (timed out after {elapsed:.3f}s)") + + if not errors.empty(): + raise errors.get() + return elapsed