diff --git a/crates/processing_pyo3/examples/rectangle.py b/crates/processing_pyo3/examples/rectangle.py index cce3fe1..ac03477 100644 --- a/crates/processing_pyo3/examples/rectangle.py +++ b/crates/processing_pyo3/examples/rectangle.py @@ -1,7 +1,7 @@ from processing import * -# TODO: this should be in a setup function -size(800, 600) +def setup(): + size(800, 600) def draw(): background(220) @@ -12,4 +12,4 @@ def draw(): rect(100, 100, 200, 150) # TODO: this should happen implicitly on module load somehow -run(draw) +run() diff --git a/crates/processing_pyo3/src/lib.rs b/crates/processing_pyo3/src/lib.rs index 9fd0611..432c433 100644 --- a/crates/processing_pyo3/src/lib.rs +++ b/crates/processing_pyo3/src/lib.rs @@ -12,7 +12,7 @@ mod glfw; mod graphics; use graphics::{Graphics, get_graphics, get_graphics_mut}; -use pyo3::{exceptions::PyRuntimeError, prelude::*, types::PyAny}; +use pyo3::{exceptions::PyRuntimeError, prelude::*}; #[pymodule] fn processing(m: &Bound<'_, PyModule>) -> PyResult<()> { @@ -38,27 +38,37 @@ fn size(module: &Bound<'_, PyModule>, width: u32, height: u32) -> PyResult<()> { } #[pyfunction] -#[pyo3(pass_module, signature = (draw_fn=None))] -fn run(module: &Bound<'_, PyModule>, draw_fn: Option>) -> PyResult<()> { - loop { - { - let mut graphics = get_graphics_mut(module)?; - if !graphics.surface.poll_events() { - break; +#[pyo3(pass_module)] +fn run(module: &Bound<'_, PyModule>) -> PyResult<()> { + Python::attach(|py| { + let builtins = PyModule::import(py, "builtins")?; + let locals = builtins.getattr("locals")?.call0()?; + + let setup_fn = locals.get_item("setup")?; + let draw_fn = locals.get_item("draw")?; + + // call setup + setup_fn.call0()?; + + // start draw loop + loop { + { + let mut graphics = get_graphics_mut(module)?; + if !graphics.surface.poll_events() { + break; + } + graphics.begin_draw()?; } - graphics.begin_draw()?; - } - if let Some(ref draw) = draw_fn { - Python::attach(|py| { - draw.call0(py) - .map_err(|e| PyRuntimeError::new_err(format!("{e}"))) - })?; + draw_fn + .call0() + .map_err(|e| PyRuntimeError::new_err(format!("{e}")))?; + + get_graphics(module)?.end_draw()?; } - get_graphics(module)?.end_draw()?; - } - Ok(()) + Ok(()) + }) } #[pyfunction]