Describe the bug
I have repeatedly observed TabPFN causing non-deterministic segmentation faults in my pipeline. Similar issues have been reported by other users (see here).
While the suggested fix of setting "OMP_NUM_THREADS" to 1 seems to make the pipeline more stable, I still observe segmentation faults in my rather complex pipeline. After quite a bit of debugging, I have now managed to create a "minimal" reproducible example that eventually produces a segmentation fault during the run.
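Here is a minimal sketch of how that workaround is usually applied, for anyone reproducing this. The variable has to be set before torch (or numpy) is imported, because the OpenMP runtime generally reads it only once, at initialization. The helper name limit_native_threads and the extra MKL/OpenBLAS variables are my own additions, not part of the suggested fix.

```python
import os

# Hypothetical helper (not part of TabPFN): pin the native thread pools of
# OpenMP and common BLAS backends to n threads. The MKL/OpenBLAS variables are
# an extra assumption beyond the suggested OMP_NUM_THREADS fix.
THREAD_ENV_VARS = ("OMP_NUM_THREADS", "MKL_NUM_THREADS", "OPENBLAS_NUM_THREADS")


def limit_native_threads(n: int = 1) -> None:
    for var in THREAD_ENV_VARS:
        os.environ[var] = str(n)


# Must run before the heavy imports: these runtimes typically read the
# variables once, when they are first loaded.
limit_native_threads(1)

# Only import torch / TabPFN afterwards, e.g.:
# import torch
# torch.set_num_threads(1)  # additionally caps PyTorch's intra-op thread pool
# from tabpfn import TabPFNClassifier
```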
Steps/Code to Reproduce
The compute node I am running this script on has around 750 GB of RAM, and I have observed this behaviour on several different L40S GPUs installed in that node.
To rule out issues arising from interactions with other packages, I created a minimal conda environment to run this script. The .yaml is the following:
Expected Results
The script runs to completion without a segmentation fault.
Actual Results
As stated above, the point at which a segmentation fault occurs is non-deterministic. The output below is from one of my runs. I have not managed to run the above script to completion without a segmentation fault occurring.
[COPIED TABPFN_INFO TO THE FIELD BELOW]
[OMITTED RUNS 1-3 FOR READABILITY]
RUN 4
Running run 1/3 for 100 test samples...
Running inference on 100 test samples...
Accuracy: 0.810 for 100 test samples
Running run 2/3 for 100 test samples...
Running inference on 100 test samples...
Accuracy: 0.830 for 100 test samples
Running run 3/3 for 100 test samples...
Running inference on 100 test samples...
Accuracy: 0.810 for 100 test samples
Running run 1/3 for 500 test samples...
Running inference on 500 test samples...
Accuracy: 0.778 for 500 test samples
Running run 2/3 for 500 test samples...
Running inference on 500 test samples...
Accuracy: 0.844 for 500 test samples
Running run 3/3 for 500 test samples...
Running inference on 500 test samples...
Accuracy: 0.836 for 500 test samples
Running run 1/3 for 1000 test samples...
Running inference on 1000 test samples...
Accuracy: 0.789 for 1000 test samples
Running run 2/3 for 1000 test samples...
Running inference on 1000 test samples...
Accuracy: 0.854 for 1000 test samples
Running run 3/3 for 1000 test samples...
Running inference on 1000 test samples...
Accuracy: 0.834 for 1000 test samples
Running run 1/3 for 5000 test samples...
Running inference on 5000 test samples...
Accuracy: 0.816 for 5000 test samples
Running run 2/3 for 5000 test samples...
Running inference on 5000 test samples...
Accuracy: 0.840 for 5000 test samples
Running run 3/3 for 5000 test samples...
Running inference on 5000 test samples...
Accuracy: 0.860 for 5000 test samples
Running run 1/3 for 10000 test samples...
Running inference on 10000 test samples...
Fatal Python error: Segmentation fault
Current thread 0x00007fa63e9db740 (most recent call first):
File "/opt/conda/envs/tabpfn_segfault/lib/python3.11/site-packages/tabpfn/model/mlp.py", line 97 in _compute
File "/opt/conda/envs/tabpfn_segfault/lib/python3.11/site-packages/tabpfn/model/memory.py", line 100 in method_
File "/opt/conda/envs/tabpfn_segfault/lib/python3.11/site-packages/tabpfn/model/mlp.py", line 132 in forward
File "/opt/conda/envs/tabpfn_segfault/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1762 in _call_impl
File "/opt/conda/envs/tabpfn_segfault/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1751 in _wrapped_call_impl
File "/opt/conda/envs/tabpfn_segfault/lib/python3.11/site-packages/tabpfn/model/layer.py", line 440 in forward
File "/opt/conda/envs/tabpfn_segfault/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1762 in _call_impl
File "/opt/conda/envs/tabpfn_segfault/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1751 in _wrapped_call_impl
File "/opt/conda/envs/tabpfn_segfault/lib/python3.11/site-packages/tabpfn/model/transformer.py", line 89 in forward
File "/opt/conda/envs/tabpfn_segfault/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1762 in _call_impl
File "/opt/conda/envs/tabpfn_segfault/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1751 in _wrapped_call_impl
File "/opt/conda/envs/tabpfn_segfault/lib/python3.11/site-packages/tabpfn/model/transformer.py", line 605 in _forward
File "/opt/conda/envs/tabpfn_segfault/lib/python3.11/site-packages/tabpfn/model/transformer.py", line 383 in forward
File "/opt/conda/envs/tabpfn_segfault/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1762 in _call_impl
File "/opt/conda/envs/tabpfn_segfault/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1751 in _wrapped_call_impl
File "/opt/conda/envs/tabpfn_segfault/lib/python3.11/site-packages/tabpfn/inference.py", line 512 in iter_outputs
File "/opt/conda/envs/tabpfn_segfault/lib/python3.11/site-packages/tabpfn/classifier.py", line 754 in forward
File "/opt/conda/envs/tabpfn_segfault/lib/python3.11/site-packages/tabpfn/classifier.py", line 685 in predict_proba
File "/opt/conda/envs/tabpfn_segfault/lib/python3.11/contextlib.py", line 81 in inner
File "/workspaces/[some_repo]/misc/segmentation_error_simple_example.py", line 77 in main
File "/workspaces/[some_repo]/misc/segmentation_error_simple_example.py", line 84 in <module>
Extension modules: numpy._core._multiarray_umath, numpy.linalg._umath_linalg, torch._C, torch._C._dynamo.autograd_compiler, torch._C._dynamo.eval_frame, torch._C._dynamo.guards, torch._C._dynamo.utils, torch._C._fft, torch._C._linalg, torch._C._nested, torch._C._nn, torch._C._sparse, torch._C._special, sklearn.__check_build._check_build, scipy._lib._ccallback_c, numpy.random._common, numpy.random.bit_generator, numpy.random._bounded_integers, numpy.random._pcg64, numpy.random._mt19937, numpy.random._generator, numpy.random._philox, numpy.random._sfc64, numpy.random.mtrand, charset_normalizer.md, scipy.sparse._sparsetools, _csparsetools, _cyutility, scipy._cyutility, scipy.sparse._csparsetools, scipy.special._ufuncs_cxx, scipy.special._ellip_harm_2, scipy.special._special_ufuncs, scipy.special._gufuncs, scipy.special._ufuncs, scipy.special._specfun, scipy.special._comb, scipy.linalg._fblas, scipy.linalg._flapack, scipy.linalg.cython_lapack, scipy.linalg._cythonized_array_utils, scipy.linalg._solve_toeplitz, scipy.linalg._decomp_lu_cython, scipy.linalg._matfuncs_schur_sqrtm, scipy.linalg._matfuncs_expm, scipy.linalg._linalg_pythran, scipy.linalg.cython_blas, scipy.linalg._decomp_update, scipy.sparse.linalg._dsolve._superlu, scipy.sparse.linalg._eigen.arpack._arpack, scipy.sparse.linalg._propack._spropack, scipy.sparse.linalg._propack._dpropack, scipy.sparse.linalg._propack._cpropack, scipy.sparse.linalg._propack._zpropack, scipy.spatial._ckdtree, scipy._lib.messagestream, scipy.spatial._qhull, scipy.spatial._voronoi, scipy.spatial._hausdorff, scipy.spatial._distance_wrap, scipy.spatial.transform._rotation, scipy.spatial.transform._rigid_transform, scipy.optimize._group_columns, scipy.optimize._trlib._trlib, scipy.optimize._lbfgsb, _moduleTNC, scipy.optimize._moduleTNC, scipy.optimize._slsqplib, scipy.optimize._minpack, scipy.optimize._lsq.givens_elimination, scipy.optimize._zeros, scipy._lib._uarray._uarray, scipy.linalg._decomp_interpolative, 
scipy.optimize._bglu_dense, scipy.optimize._lsap, scipy.optimize._direct, scipy.integrate._odepack, scipy.integrate._quadpack, scipy.integrate._vode, scipy.integrate._dop, scipy.integrate._lsoda, scipy.interpolate._fitpack, scipy.interpolate._dfitpack, scipy.interpolate._dierckx, scipy.interpolate._ppoly, scipy.interpolate._interpnd, scipy.interpolate._rbfinterp_pythran, scipy.interpolate._rgi_cython, scipy.special.cython_special, scipy.stats._stats, scipy.stats._biasedurn, scipy.stats._stats_pythran, scipy.stats._levy_stable.levyst, scipy.stats._ansari_swilk_statistics, scipy.sparse.csgraph._tools, scipy.sparse.csgraph._shortest_path, scipy.sparse.csgraph._traversal, scipy.sparse.csgraph._min_spanning_tree, scipy.sparse.csgraph._flow, scipy.sparse.csgraph._matching, scipy.sparse.csgraph._reordering, scipy.stats._sobol, scipy.stats._qmc_cy, scipy.stats._rcont.rcont, scipy.stats._qmvnt_cy, scipy.ndimage._nd_image, scipy.ndimage._rank_filter_1d, _ni_label, scipy.ndimage._ni_label, pandas._libs.tslibs.ccalendar, pandas._libs.tslibs.np_datetime, pandas._libs.tslibs.dtypes, pandas._libs.tslibs.base, pandas._libs.tslibs.nattype, pandas._libs.tslibs.timezones, pandas._libs.tslibs.fields, pandas._libs.tslibs.timedeltas, pandas._libs.tslibs.tzconversion, pandas._libs.tslibs.timestamps, pandas._libs.properties, pandas._libs.tslibs.offsets, pandas._libs.tslibs.strptime, pandas._libs.tslibs.parsing, pandas._libs.tslibs.conversion, pandas._libs.tslibs.period, pandas._libs.tslibs.vectorized, pandas._libs.ops_dispatch, pandas._libs.missing, pandas._libs.hashtable, pandas._libs.algos, pandas._libs.interval, pandas._libs.lib, pandas._libs.ops, pandas._libs.hashing, pandas._libs.arrays, pandas._libs.tslib, pandas._libs.sparse, pandas._libs.internals, pandas._libs.indexing, pandas._libs.index, pandas._libs.writers, pandas._libs.join, pandas._libs.window.aggregations, pandas._libs.window.indexers, pandas._libs.reshape, pandas._libs.groupby, pandas._libs.json, pandas._libs.parsers, 
pandas._libs.testing, sklearn.utils._isfinite, sklearn.utils.sparsefuncs_fast, sklearn.utils.murmurhash, sklearn.utils._openmp_helpers, sklearn.preprocessing._csr_polynomial_expansion, sklearn.preprocessing._target_encoder_fast, sklearn.utils._random, sklearn.utils._seq_dataset, sklearn.metrics.cluster._expected_mutual_info_fast, sklearn.metrics._dist_metrics, sklearn.metrics._pairwise_distances_reduction._datasets_pair, sklearn.utils._cython_blas, sklearn.metrics._pairwise_distances_reduction._base, sklearn.metrics._pairwise_distances_reduction._middle_term_computer, sklearn.utils._heap, sklearn.utils._sorting, sklearn.metrics._pairwise_distances_reduction._argkmin, sklearn.metrics._pairwise_distances_reduction._argkmin_classmode, sklearn.utils._vector_sentinel, sklearn.metrics._pairwise_distances_reduction._radius_neighbors, sklearn.metrics._pairwise_distances_reduction._radius_neighbors_classmode, sklearn.metrics._pairwise_fast, sklearn.linear_model._cd_fast, _loss, sklearn._loss._loss, sklearn.utils.arrayfuncs, sklearn.svm._liblinear, sklearn.svm._libsvm, sklearn.svm._libsvm_sparse, sklearn.linear_model._sag_fast, sklearn.utils._weight_vector, sklearn.linear_model._sgd_fast, sklearn.decomposition._online_lda_fast, sklearn.decomposition._cdnmf_fast, sklearn.neighbors._partition_nodes, sklearn.neighbors._ball_tree, sklearn.neighbors._kd_tree, sklearn._isotonic, sklearn.utils._fast_dict, sklearn.cluster._hierarchical_fast, sklearn.cluster._k_means_common, sklearn.cluster._k_means_elkan, sklearn.cluster._k_means_lloyd, sklearn.cluster._k_means_minibatch, sklearn.cluster._dbscan_inner, sklearn.cluster._hdbscan._tree, sklearn.cluster._hdbscan._linkage, sklearn.cluster._hdbscan._reachability, sklearn.tree._utils, sklearn.tree._tree, sklearn.tree._partitioner, sklearn.tree._splitter, sklearn.tree._criterion, sklearn.neighbors._quad_tree, sklearn.manifold._barnes_hut_tsne, sklearn.manifold._utils, sklearn.ensemble._gradient_boosting, 
sklearn.ensemble._hist_gradient_boosting.common, sklearn.ensemble._hist_gradient_boosting._gradient_boosting, sklearn.ensemble._hist_gradient_boosting._binning, sklearn.ensemble._hist_gradient_boosting._bitset, sklearn.ensemble._hist_gradient_boosting.histogram, sklearn.ensemble._hist_gradient_boosting._predictor, sklearn.ensemble._hist_gradient_boosting.splitting, scipy.io.matlab._mio_utils, scipy.io.matlab._streams, scipy.io.matlab._mio5_utils, sklearn.datasets._svmlight_format_fast, sklearn.feature_extraction._hashing_fast (total: 218)
Segmentation fault (core dumped)
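Because the crash point varies from run to run, a single log says little about how often the crash happens. Below is a minimal sketch of a harness that reruns the reproduction script in fresh processes and counts how many runs die with SIGSEGV. It assumes a POSIX system, where subprocess reports a child killed by signal N as return code -N. count_segfaults is a hypothetical helper. Running the child with -X faulthandler makes each crash print a Python traceback like the one above.

```python
import os
import signal
import subprocess
import sys


def count_segfaults(script_args, runs=10, extra_env=None):
    """Run the current interpreter on `script_args` `runs` times.

    Returns (number of runs killed by SIGSEGV, list of all return codes).
    Hypothetical helper; assumes a POSIX system.
    """
    # -X faulthandler dumps the Python stack when the child crashes.
    cmd = [sys.executable, "-X", "faulthandler", *script_args]
    # Start from the current environment and overlay extra variables,
    # e.g. {"OMP_NUM_THREADS": "1"}.
    env = {**os.environ, **(extra_env or {})}
    codes = []
    for _ in range(runs):
        # A fresh interpreter per run, so one crash cannot affect the next.
        proc = subprocess.run(cmd, env=env)
        codes.append(proc.returncode)
    # On POSIX, "killed by signal N" is reported as return code -N.
    segfaults = sum(1 for code in codes if code == -signal.SIGSEGV)
    return segfaults, codes
```

For example, count_segfaults(["segmentation_error_simple_example.py"], runs=5, extra_env={"OMP_NUM_THREADS": "1"}) would rerun the reproduction script five times with the OpenMP workaround applied. Repeating this with a different OMP_NUM_THREADS value gives a rough comparison of crash frequency.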
Versions
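If the TabPFN info output is not at hand, a stdlib-only sketch like the following collects the interpreter, OS, and package versions relevant to this report. The package list and the collect_versions name are my own choices and can be adjusted. GPU and driver details would still need to be gathered separately.

```python
import platform
from importlib import metadata

# Distribution names to report; an assumption, adjust as needed.
PACKAGES = ("tabpfn", "torch", "numpy", "scipy", "scikit-learn", "pandas")


def collect_versions(packages=PACKAGES):
    """Return a dict of interpreter, OS, and installed package versions."""
    info = {
        "python": platform.python_version(),
        "platform": platform.platform(),
    }
    for name in packages:
        try:
            info[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            info[name] = "not installed"
    return info


for key, value in collect_versions().items():
    print(f"{key}: {value}")
```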