Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 21 additions & 11 deletions core_engine/src/outputs/asr_sink.rs
Original file line number Diff line number Diff line change
Expand Up @@ -178,15 +178,20 @@ impl InputState {
}

fn poll(&mut self, callback: &mut dyn AsrSinkCallback) -> Result<bool, AsrSinkError> {
self.poll_with_limit(callback, None)
}

fn poll_stop(&mut self, callback: &mut dyn AsrSinkCallback) -> Result<bool, AsrSinkError> {
let channels = MASTER_FORMAT.channels.max(1) as usize;
let bounded_samples = self.consumer.occupied_len() / channels * channels;
self.poll_with_limit(callback, Some(bounded_samples))
}

fn stop_now(&mut self) {
self.drained_master.clear();
self.converted_output.clear();
self.pending_output.clear();
self.pending_offset = 0;
self.quantized_output.clear();
self.metrics.pending_frames.store(0, Ordering::Relaxed);
}

fn poll_with_limit(
&mut self,
callback: &mut dyn AsrSinkCallback,
Expand Down Expand Up @@ -323,7 +328,7 @@ impl AsrSinkMetrics {

pub struct AsrSink {
stop: Arc<AtomicBool>,
handle: Option<JoinHandle<Result<(), AsrSinkError>>>,
handle: Option<JoinHandle<Result<Vec<AsrSinkInput>, AsrSinkError>>>,
metrics: Arc<AsrSinkMetrics>,
}

Expand Down Expand Up @@ -387,13 +392,13 @@ impl AsrSink {
let stop = Arc::new(AtomicBool::new(false));
let stop_thread = stop.clone();

let handle = thread::spawn(move || -> Result<(), AsrSinkError> {
let handle = thread::spawn(move || -> Result<Vec<AsrSinkInput>, AsrSinkError> {
let idle_sleep = Duration::from_micros(200);

loop {
if stop_thread.load(Ordering::Relaxed) {
for state in &mut states {
let _ = state.poll_stop(&mut *callback)?;
state.stop_now();
}
break;
}
Expand All @@ -408,7 +413,13 @@ impl AsrSink {
}
}

Ok(())
Ok(states
.into_iter()
.map(|state| AsrSinkInput {
input_id: state.input_id,
consumer: state.consumer,
})
.collect())
});

Ok(Self {
Expand All @@ -430,7 +441,7 @@ impl AsrSink {
self.metrics.snapshot()
}

pub fn stop(&mut self) -> Result<(), AsrSinkError> {
pub fn stop(&mut self) -> Result<Vec<AsrSinkInput>, AsrSinkError> {
self.stop.store(true, Ordering::Relaxed);
let Some(handle) = self.handle.take() else {
return Err(AsrSinkError::AlreadyStopped);
Expand Down Expand Up @@ -672,8 +683,7 @@ mod tests {
.expect("spawn sink");

sink.stop().expect("first stop");
let err = sink.stop().expect_err("second stop should fail");
assert!(matches!(err, AsrSinkError::AlreadyStopped));
assert!(matches!(sink.stop(), Err(AsrSinkError::AlreadyStopped)));
}

#[test]
Expand Down
11 changes: 5 additions & 6 deletions core_engine/src/outputs/wav_file.rs
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ pub struct WavSinkMetricsSnapshot {

pub struct WavFileOutput {
stop: Arc<AtomicBool>,
handle: Option<JoinHandle<Result<(), WavOutputError>>>,
handle: Option<JoinHandle<Result<Vec<RouteConsumer>, WavOutputError>>>,
metrics: Arc<WavSinkMetrics>,
}

Expand Down Expand Up @@ -123,7 +123,7 @@ impl WavFileOutput {
let channels = format.channels.max(1) as u64;
let frame_channels = format.channels.max(1) as usize;

let handle = thread::spawn(move || -> Result<(), WavOutputError> {
let handle = thread::spawn(move || -> Result<Vec<RouteConsumer>, WavOutputError> {
let mut writer = hound::WavWriter::new(writer, spec)?;
let idle_sleep = Duration::from_micros(200);
let mut consumers = consumers;
Expand Down Expand Up @@ -207,7 +207,7 @@ impl WavFileOutput {
metrics_thread
.finalize
.record(duration_to_u32_us(finalize_start.elapsed()));
Ok(())
Ok(consumers)
});

Ok(Self {
Expand Down Expand Up @@ -305,7 +305,7 @@ impl WavFileOutput {
self.metrics.snapshot()
}

pub fn stop(&mut self) -> Result<(), WavOutputError> {
pub fn stop(&mut self) -> Result<Vec<RouteConsumer>, WavOutputError> {
self.stop.store(true, Ordering::Relaxed);
let Some(handle) = self.handle.take() else {
return Err(WavOutputError::AlreadyStopped);
Expand Down Expand Up @@ -539,8 +539,7 @@ mod tests {
.expect("spawn wav");

wav.stop().expect("first stop");
let err = wav.stop().expect_err("second stop should fail");
assert!(matches!(err, WavOutputError::AlreadyStopped));
assert!(matches!(wav.stop(), Err(WavOutputError::AlreadyStopped)));
}

#[test]
Expand Down
33 changes: 29 additions & 4 deletions macloop/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,25 @@ def _raise_on_unexpected_kwargs(name: str, kwargs: dict[str, Any]) -> None:
raise TypeError(f"{name} got unexpected keyword arguments: {unexpected}")


def _close_backend_with_optional_engine(backend: Any, engine_backend: Any | None) -> None:
if engine_backend is not None:
try:
backend.close(engine_backend)
return
except TypeError as exc:
try:
backend.close()
return
except TypeError:
raise exc

close_no_restore = getattr(backend, "close_no_restore", None)
if close_no_restore is not None:
close_no_restore()
else:
backend.close()


@dataclass(frozen=True, slots=True)
class AudioChunk:
route_id: str
Expand Down Expand Up @@ -506,13 +525,16 @@ def close(self) -> None:
return

err: Optional[Exception] = None
engine = self._engine_ref()
try:
self._backend.close()
_close_backend_with_optional_engine(
self._backend,
engine._backend if engine is not None else None,
)
except Exception as exc:
err = exc
finally:
self._closed = True
engine = self._engine_ref()
if engine is not None:
engine._release_routes(self._route_ids)
_drop_oldest_put(self._queue, _STOP)
Expand Down Expand Up @@ -634,13 +656,16 @@ def close(self) -> None:
return

err: Optional[Exception] = None
engine = self._engine_ref()
try:
self._backend.close()
_close_backend_with_optional_engine(
self._backend,
engine._backend if engine is not None else None,
)
except Exception as exc:
err = exc
finally:
self._closed = True
engine = self._engine_ref()
if engine is not None:
engine._release_routes(self._route_ids)

Expand Down
73 changes: 67 additions & 6 deletions python_ffi/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -904,7 +904,12 @@ impl PyAsrSinkBackend {
Ok(out)
}

fn close(&mut self, py: Python<'_>) -> PyResult<()> {
#[pyo3(signature = (engine=None))]
fn close(
&mut self,
py: Python<'_>,
mut engine: Option<PyRefMut<'_, PyAudioEngineBackend>>,
) -> PyResult<()> {
let Some(mut sink) = self.sink.take() else {
return Ok(());
};
Expand All @@ -915,22 +920,49 @@ impl PyAsrSinkBackend {
(stop_result, final_stats)
});
self.final_stats = Some(final_stats);
stop_result
let route_consumers = stop_result
.map_err(|e| PyRuntimeError::new_err(format!("failed to stop asr sink: {e}")))?;

if let Some(engine) = engine.as_mut() {
engine
.restore_route_consumers(
route_consumers
.into_iter()
.map(|input| (input.input_id, input.consumer))
.collect(),
)
.map_err(|e| PyRuntimeError::new_err(format!("failed to restore asr sink routes: {e}")))?;
}
Ok(())
}

fn close_no_restore(&mut self, py: Python<'_>) -> PyResult<()> {
let Some(mut sink) = self.sink.take() else {
return Ok(());
};

let (stop_result, final_stats) = py.detach(move || {
let stop_result = sink.stop().map_err(|e| e.to_string());
let final_stats = sink.stats();
(stop_result, final_stats)
});
self.final_stats = Some(final_stats);
let _ = stop_result;
Ok(())
}
}

impl Drop for PyAsrSinkBackend {
fn drop(&mut self) {
let _ = Python::try_attach(|py| self.close(py));
let _ = Python::try_attach(|py| self.close_no_restore(py));
}
}

#[pyclass(name = "_WavSinkBackend", module = "macloop._macloop", unsendable)]
struct PyWavSinkBackend {
sink: Option<WavFileOutput>,
final_stats: Option<WavSinkMetricsSnapshot>,
route_ids: Vec<String>,
}

#[pymethods]
Expand All @@ -944,26 +976,53 @@ impl PyWavSinkBackend {
Py::new(py, PyWavSinkStats::from_snapshot(py, snapshot)?)
}

fn close(&mut self, py: Python<'_>) -> PyResult<()> {
#[pyo3(signature = (engine=None))]
fn close(
&mut self,
py: Python<'_>,
mut engine: Option<PyRefMut<'_, PyAudioEngineBackend>>,
) -> PyResult<()> {
let Some(mut sink) = self.sink.take() else {
return Ok(());
};
let route_ids = self.route_ids.clone();

let (stop_result, final_stats) = py.detach(move || {
let stop_result = sink.stop().map_err(|e| e.to_string());
let final_stats = sink.stats();
(stop_result, final_stats)
});
self.final_stats = Some(final_stats);
stop_result
let consumers = stop_result
.map_err(|e| PyRuntimeError::new_err(format!("failed to stop wav sink: {e}")))?;

if let Some(engine) = engine.as_mut() {
engine
.restore_route_consumers(route_ids.into_iter().zip(consumers).collect())
.map_err(|e| PyRuntimeError::new_err(format!("failed to restore wav sink routes: {e}")))?;
}
Ok(())
}

fn close_no_restore(&mut self, py: Python<'_>) -> PyResult<()> {
let Some(mut sink) = self.sink.take() else {
return Ok(());
};

let (stop_result, final_stats) = py.detach(move || {
let stop_result = sink.stop().map_err(|e| e.to_string());
let final_stats = sink.stats();
(stop_result, final_stats)
});
self.final_stats = Some(final_stats);
let _ = stop_result;
Ok(())
}
}

impl Drop for PyWavSinkBackend {
fn drop(&mut self) {
let _ = Python::try_attach(|py| self.close(py));
let _ = Python::try_attach(|py| self.close_no_restore(py));
}
}

Expand Down Expand Up @@ -1539,6 +1598,7 @@ impl PyAudioEngineBackend {
self.restore_stream_states(started_states);

let route_consumers = self.take_route_consumers(&route_ids)?;
let route_ids_for_sink = route_ids.clone();
let master_format = self.controller.master_format();
let detached_result: DetachedWavStartResult = py.detach(move || {
let consumers = route_consumers
Expand All @@ -1559,6 +1619,7 @@ impl PyAudioEngineBackend {
Ok(sink) => Ok(PyWavSinkBackend {
sink: Some(sink),
final_stats: None,
route_ids: route_ids_for_sink,
}),
Err((err, route_consumers)) => {
if let Err(restore_err) = self.restore_route_consumers(route_consumers) {
Expand Down
Loading
Loading