Possibly try implementing a custom version of the distance calculations for speed. Possible options:

- Cython (very much like how scipy.spatial does it)
- JAX (e.g., https://github.com/chrisflesher/jax-scipy-spatial)
- Julia, then integrating with JuliaCall
- Numba
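As a rough illustration of the Numba option, here is a minimal sketch of a JIT-compiled pairwise Euclidean distance between two sets of points. It assumes the distances needed are plain Euclidean on numeric 2-D arrays; the function name `euclidean_cdist` is hypothetical and is not a drop-in replacement for `scipy.spatial.distance.cdist`, which supports many metrics. If Numba is not installed, the sketch falls back to running the same loops as (slow) plain Python so that it still executes.

```python
import math

import numpy as np

try:
    from numba import njit
except ImportError:  # Fallback so the sketch runs without Numba (no speedup).
    def njit(func):
        return func


@njit
def euclidean_cdist(xa, xb):
    """Hypothetical helper: Euclidean distances between rows of xa (m, d) and xb (n, d)."""
    m, d = xa.shape
    n = xb.shape[0]
    out = np.empty((m, n), dtype=np.float64)
    for i in range(m):
        for j in range(n):
            # Accumulate the squared differences over the d coordinates.
            acc = 0.0
            for k in range(d):
                diff = xa[i, k] - xb[j, k]
                acc += diff * diff
            out[i, j] = math.sqrt(acc)
    return out


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    xa = rng.random((500, 3))
    xb = rng.random((400, 3))
    # Reference result via NumPy broadcasting, used only as a correctness check.
    ref = np.sqrt(((xa[:, None, :] - xb[None, :, :]) ** 2).sum(axis=-1))
    assert np.allclose(euclidean_cdist(xa, xb), ref)
```

The first call includes Numba's compilation time, so any benchmark against `scipy.spatial.distance.cdist` (e.g. with `timeit`) should time subsequent calls. Whether this actually beats SciPy's compiled C implementation would need to be measured; the likelier gains come from fusing the distance computation with whatever is done with the result, or from parallelizing the outer loop with `numba.prange`.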