TensorFlow plugin for Moore Threads MUSA GPUs: MUSA kernels and graph optimizations that accelerate TensorFlow on MUSA hardware.

Features:
- MUSA implementations for core ops and common fusion paths
- Grappler-based graph optimizations (layout, fusion, optional mixed precision, etc.)
- Python package tensorflow_musa: plugin loading and device discovery
- Optional telemetry and debugging: see the Debug guide

Requirements:

- CMake ≥ 3.10, Make, GCC/G++ (ABI-compatible with TensorFlow 2.6.1 pip wheels)
- MUSA SDK (default /usr/local/musa): runtime, muBLAS, muDNN
- Python ≥ 3.7
- TensorFlow == 2.6.1 (must match this version)
- NumPy ≥ 1.19.0

Build and install the wheel:

```bash
git clone <repository-url>
cd tensorflow_musa_extension
pip install tensorflow==2.6.1
./build.sh wheel
pip install dist/tensorflow_musa-*.whl --no-deps
```

Use --force-reinstall when replacing an existing install.
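
If you rebuild often, the install step can be scripted. The sketch below uses a hypothetical helper, wheel_install_command (not part of this repository), that picks the most recently built wheel in dist/ and appends --force-reinstall when a tensorflow_musa package is already importable. It only builds the command list; running it is left to the caller.

```python
import glob
import importlib.util
import os
import sys


def wheel_install_command(dist_dir="dist"):
    """Build the pip command that installs the newest tensorflow_musa wheel."""
    wheels = glob.glob(os.path.join(dist_dir, "tensorflow_musa-*.whl"))
    if not wheels:
        raise FileNotFoundError(
            f"no tensorflow_musa wheel in {dist_dir}; run ./build.sh wheel"
        )
    newest = max(wheels, key=os.path.getmtime)  # most recently built file
    cmd = [sys.executable, "-m", "pip", "install", newest, "--no-deps"]
    # Follows the advice above: use --force-reinstall when replacing an existing install.
    if importlib.util.find_spec("tensorflow_musa") is not None:
        cmd.append("--force-reinstall")
    return cmd
```

From the repository root, subprocess.run(wheel_install_command(), check=True) would then perform the install.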

Verify the installation:

```python
import tensorflow_musa as tf_musa

print(tf_musa.__version__)
print(tf_musa.get_musa_devices())
```

Example with a MUSA device:

```python
import tensorflow as tf
import tensorflow_musa  # ensure the plugin is loaded

# Place the ops on the first MUSA device.
with tf.device("/device:MUSA:0"):
    a = tf.constant([[1.0, 2.0], [3.0, 4.0]])
    b = tf.matmul(a, a)

print(b)
```

MUSA allocator memory growth defaults to False, matching TensorFlow's native GPU behavior. You can configure it explicitly before MUSA devices are initialized:

```python
import tensorflow_musa as tf_musa

tf_musa.set_musa_allow_growth(enabled=True)
```

To disable it explicitly:

```python
tf_musa.set_musa_allow_growth(enabled=False)
```

The TensorFlow-compatible environment variable TF_FORCE_GPU_ALLOW_GROWTH can also override the Python setting:

```bash
export TF_FORCE_GPU_ALLOW_GROWTH=true
```

Enable or disable the MUSA custom graph optimizer:

```python
import tensorflow as tf
import tensorflow_musa as tf_musa

config = tf.compat.v1.ConfigProto()
tf_musa.set_musa_graph_optimizer_enabled(config, enabled=True)
# To disable it:
# tf_musa.set_musa_graph_optimizer_enabled(config, enabled=False)
```

Disable selected fusion patterns from Python by passing parameters to the C++ optimizer:

```python
# Disable selected fusion patterns
tf_musa.disable_musa_fusion_patterns(
    config,
    patterns=["MusaGeluFusion", "MusaLayerNormFusion"],
)

# Disable all fusion patterns
tf_musa.disable_musa_fusion_patterns(config, patterns="all")

# Clear the disabled fusion pattern list
tf_musa.clear_musa_disabled_fusion_patterns(config)
```

To build only the plugin library build/libmusa_plugin.so (no wheel):

```bash
pip install tensorflow==2.6.1
./build.sh  # or ./build.sh release
```

For experiments you can load it directly with tf.load_library("./build/libmusa_plugin.so").
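
For such experiments, a small guard gives a clearer error when the library has not been built yet. load_musa_plugin below is a hypothetical helper, not part of tensorflow_musa; it assumes the default output path build/libmusa_plugin.so and calls tf.load_library only after confirming the file exists.

```python
import os

DEFAULT_PLUGIN = os.path.join("build", "libmusa_plugin.so")  # output of ./build.sh


def load_musa_plugin(path=DEFAULT_PLUGIN):
    """Load the MUSA plugin shared library, failing early if it is missing."""
    path = os.path.abspath(path)
    if not os.path.isfile(path):
        raise FileNotFoundError(f"{path} not found; run ./build.sh first")
    # Deferred import: the file check above works even without TensorFlow installed.
    import tensorflow as tf

    tf.load_library(path)
    return path
```

Run it from the repository root so the relative default path resolves; after a wheel install, import tensorflow_musa loads the plugin instead.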

See also:

- Debugging and environment variables
- More examples: TensorFlow MUSA Playground

Contributions: issues and PRs are welcome; please add tests for new ops.

License: Apache License 2.0.

Support: please open an issue in the repository or contact the maintainers.