Skip to content

Latest commit

 

History

History
executable file
·
112 lines (74 loc) · 2.98 KB

File metadata and controls

executable file
·
112 lines (74 loc) · 2.98 KB

TensorFlow MUSA Extension

面向摩尔线程(Moore Threads)MUSA GPU 的 TensorFlow 插件:通过 MUSA 内核与图优化为 TensorFlow 提供 GPU 加速。

特性

  • 核心算子与常用融合路径的 MUSA 实现
  • Grappler 图优化(布局、融合、可选混合精度等)
  • Python 包 tensorflow_musa:自动加载插件与设备查询
  • 可选遥测与调试说明见 调试指南

环境要求

  • CMake ≥ 3.10,Make,GCC/G++(与 TensorFlow 2.6.1 wheel ABI 一致)
  • MUSA SDK(默认路径 /usr/local/musa):Runtime、muBLAS、muDNN
  • Python ≥ 3.7
  • TensorFlow == 2.6.1(须与此版本一致)
  • NumPy ≥ 1.19.0

安装(推荐:Wheel)

git clone <repository-url>
cd tensorflow_musa_extension

pip install tensorflow==2.6.1
./build.sh wheel
pip install dist/tensorflow_musa-*.whl --no-deps

重新构建后覆盖安装可加 --force-reinstall

快速验证

import tensorflow_musa as tf_musa

print(tf_musa.__version__)
print(tf_musa.get_musa_devices())

在计算图中使用 MUSA 设备(示例):

import tensorflow as tf
import tensorflow_musa  # 确保插件已加载

with tf.device("/device:MUSA:0"):
    a = tf.constant([[1.0, 2.0], [3.0, 4.0]])
    b = tf.matmul(a, a)

MUSA 自定义图优化器开关

tensorflow_musa 提供了 ConfigProto 级别的接口,用于启用、关闭或查询 musa_graph_optimizer。常规推理场景推荐使用 enable_musa_graph_optimizer(config),它等价于向 config.graph_options.rewrite_options.custom_optimizers 注册 musa_graph_optimizer

import tensorflow as tf
import tensorflow_musa as tf_musa

config = tf.compat.v1.ConfigProto()

# 启用 MUSA 自定义图优化器
tf_musa.enable_musa_graph_optimizer(config)

# 查询是否已启用
print(tf_musa.is_musa_graph_optimizer_enabled(config))

# 关闭 MUSA 自定义图优化器
tf_musa.disable_musa_graph_optimizer(config)

也可以使用统一接口显式传入开关值:

tf_musa.set_musa_graph_optimizer_enabled(config, enabled=True)
tf_musa.set_musa_graph_optimizer_enabled(config, enabled=False)

少数测试或调试场景需要强制设置 Grappler optimizer 列表时,可以额外传入 add_to_optimizer_list=True

tf_musa.enable_musa_graph_optimizer(config, add_to_optimizer_list=True)

从源码构建插件(可选)

仅生成 build/libmusa_plugin.so(不打包 wheel):

pip install tensorflow==2.6.1
./build.sh          # 或 ./build.sh release

开发时也可在 Python 中 tf.load_library("./build/libmusa_plugin.so") 手动加载。

文档与示例

参与贡献

欢迎提交 Issue 与 Pull Request(新算子请附带测试)。

许可证

Apache License 2.0

支持

请在仓库 Issue 中反馈问题或联系维护者。