From ada3641296ec7d8bc2cde1ab4f39f6a8f0be46cc Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Sat, 31 Jan 2026 15:41:03 +0000 Subject: [PATCH 1/8] Initial plan From 0b7f204ebe995eef2082ddaa3530d6edefcf11b5 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Sat, 31 Jan 2026 16:06:30 +0000 Subject: [PATCH 2/8] Implement computational graph runtime with JSON serialization Co-authored-by: SnifferCaptain <151126294+SnifferCaptain@users.noreply.github.com> --- doc/ad_module_cn.md | 226 +++++++++++++++++++++++++ include/ad.hpp | 18 ++ include/ad/graph_edge.hpp | 93 +++++++++++ include/ad/graph_node.hpp | 105 ++++++++++++ include/ad/graph_runtime.hpp | 161 ++++++++++++++++++ include/ad/graph_serialization.hpp | 113 +++++++++++++ include/ytensor_concepts.hpp | 1 + src/ad/graph_edge.inl | 16 ++ src/ad/graph_node.inl | 42 +++++ src/ad/graph_runtime.inl | 252 ++++++++++++++++++++++++++++ src/ad/graph_serialization.inl | 255 +++++++++++++++++++++++++++++ 11 files changed, 1282 insertions(+) create mode 100644 doc/ad_module_cn.md create mode 100644 include/ad.hpp create mode 100644 include/ad/graph_edge.hpp create mode 100644 include/ad/graph_node.hpp create mode 100644 include/ad/graph_runtime.hpp create mode 100644 include/ad/graph_serialization.hpp create mode 100644 src/ad/graph_edge.inl create mode 100644 src/ad/graph_node.inl create mode 100644 src/ad/graph_runtime.inl create mode 100644 src/ad/graph_serialization.inl diff --git a/doc/ad_module_cn.md b/doc/ad_module_cn.md new file mode 100644 index 0000000..d44ffc1 --- /dev/null +++ b/doc/ad_module_cn.md @@ -0,0 +1,226 @@ +# 自动微分 (Automatic Differentiation) 模块 + +## 概述 + +本模块实现了基于计算图的运行时系统,为YTensor提供自动微分的基础设施。计算图采用节点-边的设计,支持动态构建和执行。 + +## 设计理念 + +- **节点 (Node)**: 表示计算图中的算子操作(如加法、乘法、矩阵乘等) +- **边 (Edge)**: 表示有向数据流,连接节点并携带YTensorBase张量数据 +- **运行时 (Runtime)**: 管理计算图的构建、执行和序列化 + +## 特性 + +✅ **动态计算图**: 无需编译时确定,使用YTensorBase基类实现运行时灵活性 +✅ **拓扑排序**: 自动确定节点执行顺序 +✅ **算子注册**: 灵活的算子执行器注册机制 +✅ **JSON序列化**: 支持计算图的导入导出 +✅ **节点参数**: 支持算子的可配置参数 + +## 文件结构 + +``` +include/ad/ +├── graph_node.hpp # 节点定义 +├── graph_edge.hpp # 边定义 +├── graph_runtime.hpp # 运行时引擎 +└── graph_serialization.hpp # JSON序列化 + +src/ad/ +├── graph_node.inl # 节点实现 +├── graph_edge.inl # 边实现 +├── graph_runtime.inl # 运行时实现 +└── graph_serialization.inl # 序列化实现 + +include/ad.hpp # 模块总入口 +``` + +## 快速开始 + +### 1. 包含头文件 + +```cpp +#include "ytensor.hpp" +#include "include/ad.hpp" + +using namespace yt; +using namespace yt::ad; +``` + +### 2. 创建计算图 + +```cpp +ComputationGraph graph; + +// 创建节点 +auto node1 = graph.createNode("add1", "add"); +auto node2 = graph.createNode("mul1", "mul"); + +// 创建边并连接节点 +auto input1 = graph.createEdge("input1", nullptr, node1); +auto input2 = graph.createEdge("input2", nullptr, node1); +auto edge1 = graph.createEdge("edge1", node1, node2); +auto output = graph.createEdge("output", node2, nullptr); +``` + +### 3. 注册算子 + +```cpp +void addOperator(const GraphNode& node, + const std::vector>& inputs, + std::vector>& outputs) { + const auto& t1 = inputs[0]->getTensor(); + const auto& t2 = inputs[1]->getTensor(); + YTensorBase result = t1 + t2; + outputs[0]->setTensor(result); +} + +graph.registerOperator("add", addOperator); +``` + +### 4. 执行计算图 + +```cpp +// 准备输入 +YTensorBase t1({2, 3}, "float32"); +YTensorBase t2({2, 3}, "float32"); +// ... 填充数据 ... + +std::unordered_map inputs = { + {"input1", t1}, + {"input2", t2} +}; + +// 执行 +auto outputs = graph.execute(inputs, {"output"}); +const auto& result = outputs["output"]; +``` + +### 5. 序列化 + +```cpp +// 序列化为JSON字符串 +std::string json = GraphSerializer::toJson(graph); + +// 保存到文件 +GraphSerializer::toJsonFile(graph, "graph.json"); + +// 从文件加载(需要完整的JSON解析器支持) +ComputationGraph newGraph; +GraphSerializer::fromJsonFile("graph.json", newGraph); +``` + +## API 文档 + +### ComputationGraph 类 + +#### 主要方法 + +- `createNode(id, opType)` - 创建新节点 +- `createEdge(id, from, to)` - 创建新边 +- `connect(fromId, toId, edgeId)` - 连接两个节点 +- `registerOperator(opType, executor)` - 注册算子执行器 +- `execute(inputs, outputs)` - 执行计算图 +- `topologicalSort()` - 获取拓扑排序结果 +- `clear()` - 清空计算图 + +### GraphNode 类 + +#### 主要方法 + +- `getId()` / `getOpType()` / `getName()` - 获取节点信息 +- `setParameter(key, value)` - 设置节点参数 +- `getParameter(key)` - 获取节点参数 +- `getInputEdges()` / `getOutputEdges()` - 获取连接的边 + +### GraphEdge 类 + +#### 主要方法 + +- `getId()` / `getName()` - 获取边信息 +- `setTensor(tensor)` / `getTensor()` - 设置/获取张量数据 +- `getFromNode()` / `getToNode()` - 获取连接的节点 +- `hasTensor()` - 检查是否有张量数据 + +### GraphSerializer 类 + +#### 静态方法 + +- `toJson(graph)` - 序列化为JSON字符串 +- `toJsonFile(graph, filename)` - 序列化到文件 +- `fromJson(json, graph)` - 从JSON反序列化 +- `fromJsonFile(filename, graph)` - 从文件反序列化 + +## 示例程序 + +完整的示例程序位于 `example/ad_demo/main.cpp`,演示了: + +1. 创建包含3个节点的计算图 +2. 注册加法和乘法算子 +3. 执行计算: `(input1 + input2) * input3 + input4` +4. 序列化为JSON格式 +5. 验证计算结果 + +运行示例: + +```bash +cd example/ad_demo +mkdir build && cd build +cmake .. +make +./ad_demo +``` + +## JSON格式说明 + +计算图序列化为JSON格式,包含nodes和edges两个主要部分: + +```json +{ + "nodes": [ + { + "id": "node1", + "opType": "add", + "name": "加法节点", + "inputEdges": ["edge1", "edge2"], + "outputEdges": ["edge3"], + "parameters": {} + } + ], + "edges": [ + { + "id": "edge1", + "name": "输入边1", + "fromNode": null, + "toNode": "node1" + } + ] +} +``` + +## 注意事项 + +1. **节点ID和边ID必须唯一**:在同一个计算图中,节点ID和边ID不能重复 +2. **算子必须先注册**:执行前需要注册所有用到的算子类型 +3. **输入边必须有数据**:执行时所有输入边必须设置张量数据 +4. **无环要求**:计算图必须是有向无环图(DAG) +5. **JSON反序列化**:当前fromJson实现为简化版本,完整功能建议使用nlohmann/json库 + +## 后续规划 + +- [ ] 实现反向传播和梯度计算 +- [ ] 支持更多内置算子 +- [ ] 优化内存管理 +- [ ] 支持子图和条件执行 +- [ ] 完整的JSON反序列化实现 +- [ ] 计算图可视化工具 +- [ ] 性能分析和优化 + +## 贡献 + +欢迎提交Issue和Pull Request! + +## 许可 + +遵循YTensor项目的许可协议。 diff --git a/include/ad.hpp b/include/ad.hpp new file mode 100644 index 0000000..76a9e7c --- /dev/null +++ b/include/ad.hpp @@ -0,0 +1,18 @@ +#pragma once +/*************** +* @file: ad.hpp +* @brief: 自动微分模块总入口 +* @description: 包含计算图的所有头文件和实现 +***************/ + +// 头文件 +#include "ad/graph_node.hpp" +#include "ad/graph_edge.hpp" +#include "ad/graph_runtime.hpp" +#include "ad/graph_serialization.hpp" + +// 实现文件 +#include "../src/ad/graph_node.inl" +#include "../src/ad/graph_edge.inl" +#include "../src/ad/graph_runtime.inl" +#include "../src/ad/graph_serialization.inl" diff --git a/include/ad/graph_edge.hpp b/include/ad/graph_edge.hpp new file mode 100644 index 0000000..b7a9ec5 --- /dev/null +++ b/include/ad/graph_edge.hpp @@ -0,0 +1,93 @@ +#pragma once +/*************** +* @file: graph_edge.hpp +* @brief: 计算图边定义,表示计算图中的数据流 +* @description: 边代表有向的数据流,连接源节点和目标节点,携带张量数据 +***************/ + +#include +#include +#include "../ytensor_base.hpp" + +namespace yt { +namespace ad { + +// 前向声明 +class GraphNode; + +/// @brief 计算图边类,表示节点之间的有向数据流 +class GraphEdge { +public: + /// @brief 默认构造函数 + GraphEdge() = default; + + /// @brief 构造一个计算图边 + /// @param edgeId 边的唯一标识符 + /// @param fromNode 源节点(可为nullptr,表示输入边) + /// @param toNode 目标节点(可为nullptr,表示输出边) + GraphEdge(const std::string& edgeId, + std::shared_ptr fromNode = nullptr, + std::shared_ptr toNode = nullptr); + + /// @brief 获取边ID + /// @return 边的唯一标识符 + std::string getId() const { return edgeId_; } + + /// @brief 设置边名称(可选) + /// @param name 边名称 + void setName(const std::string& name) { name_ = name; } + + /// @brief 获取边名称 + /// @return 边名称 + std::string getName() const { return name_; } + + /// @brief 设置源节点 + /// @param node 源节点指针 + void setFromNode(std::shared_ptr node) { fromNode_ = node; } + + /// @brief 设置目标节点 + /// @param node 目标节点指针 + void setToNode(std::shared_ptr node) { toNode_ = node; } + + /// @brief 获取源节点 + /// @return 源节点指针 + std::shared_ptr getFromNode() const { return fromNode_; } + + /// @brief 获取目标节点 + /// @return 目标节点指针 + std::shared_ptr getToNode() const { return toNode_; } + + /// @brief 设置边上的张量数据 + /// @param tensor 张量数据 + void setTensor(const YTensorBase& tensor) { tensor_ = tensor; } + + /// @brief 获取边上的张量数据 + /// @return 张量数据的引用 + YTensorBase& getTensor() { return tensor_; } + + /// @brief 获取边上的张量数据(const版本) + /// @return 张量数据的const引用 + const YTensorBase& getTensor() const { return tensor_; } + + /// @brief 检查边是否包含有效的张量数据 + /// @return 如果张量有效返回true + bool hasTensor() const { return tensor_.size() > 0; } + + /// @brief 检查是否为输入边(没有源节点) + /// @return 如果是输入边返回true + bool isInputEdge() const { return fromNode_ == nullptr; } + + /// @brief 检查是否为输出边(没有目标节点) + /// @return 如果是输出边返回true + bool isOutputEdge() const { return toNode_ == nullptr; } + +private: + std::string edgeId_; // 边的唯一标识符 + std::string name_; // 边名称(可选) + std::shared_ptr fromNode_; // 源节点 + std::shared_ptr toNode_; // 目标节点 + YTensorBase tensor_; // 边上携带的张量数据 +}; + +} // namespace ad +} // namespace yt diff --git a/include/ad/graph_node.hpp b/include/ad/graph_node.hpp new file mode 100644 index 0000000..f80aa68 --- /dev/null +++ b/include/ad/graph_node.hpp @@ -0,0 +1,105 @@ +#pragma once +/*************** +* @file: graph_node.hpp +* @brief: 计算图节点定义,表示计算图中的算子 +* @description: 节点代表算子操作,包含操作类型、参数和连接关系 +***************/ + +#include +#include +#include +#include +#include + +namespace yt { +namespace ad { + +// 前向声明 +class GraphEdge; +class ComputationGraph; + +/// @brief 计算图节点类,表示计算图中的一个算子 +class GraphNode { +public: + /// @brief 默认构造函数 + GraphNode() = default; + + /// @brief 构造一个计算图节点 + /// @param nodeId 节点的唯一标识符 + /// @param opType 算子类型(如 "add", "mul", "matmul", "relu" 等) + GraphNode(const std::string& nodeId, const std::string& opType); + + /// @brief 获取节点ID + /// @return 节点的唯一标识符 + std::string getId() const { return nodeId_; } + + /// @brief 获取算子类型 + /// @return 算子类型字符串 + std::string getOpType() const { return opType_; } + + /// @brief 设置节点名称(可选,用于调试和可视化) + /// @param name 节点名称 + void setName(const std::string& name) { name_ = name; } + + /// @brief 获取节点名称 + /// @return 节点名称 + std::string getName() const { return name_; } + + /// @brief 添加输入边 + /// @param edge 指向输入边的指针 + void addInputEdge(std::shared_ptr edge); + + /// @brief 添加输出边 + /// @param edge 指向输出边的指针 + void addOutputEdge(std::shared_ptr edge); + + /// @brief 获取所有输入边 + /// @return 输入边的vector + const std::vector>& getInputEdges() const { return inputEdges_; } + + /// @brief 获取所有输出边 + /// @return 输出边的vector + const std::vector>& getOutputEdges() const { return outputEdges_; } + + /// @brief 设置节点参数 + /// @param key 参数名 + /// @param value 参数值(使用std::any以支持多种类型) + void setParameter(const std::string& key, const std::any& value); + + /// @brief 获取节点参数 + /// @param key 参数名 + /// @return 参数值 + std::any getParameter(const std::string& key) const; + + /// @brief 检查是否存在某个参数 + /// @param key 参数名 + /// @return 如果参数存在返回true,否则返回false + bool hasParameter(const std::string& key) const; + + /// @brief 获取所有参数 + /// @return 参数映射表 + const std::unordered_map& getParameters() const { return parameters_; } + + /// @brief 标记节点是否已执行 + /// @param executed 是否已执行 + void setExecuted(bool executed) { executed_ = executed; } + + /// @brief 检查节点是否已执行 + /// @return 如果已执行返回true + bool isExecuted() const { return executed_; } + + /// @brief 重置节点执行状态 + void reset() { executed_ = false; } + +private: + std::string nodeId_; // 节点唯一标识符 + std::string opType_; // 算子类型 + std::string name_; // 节点名称(可选) + std::vector> inputEdges_; // 输入边列表 + std::vector> outputEdges_; // 输出边列表 + std::unordered_map parameters_; // 节点参数 + bool executed_ = false; // 是否已执行 +}; + +} // namespace ad +} // namespace yt diff --git a/include/ad/graph_runtime.hpp b/include/ad/graph_runtime.hpp new file mode 100644 index 0000000..2aaae45 --- /dev/null +++ b/include/ad/graph_runtime.hpp @@ -0,0 +1,161 @@ +#pragma once +/*************** +* @file: graph_runtime.hpp +* @brief: 计算图运行时,负责计算图的构建、执行和管理 +* @description: 运行时引擎管理计算图的生命周期,提供节点和边的创建、执行、序列化等功能 +***************/ + +#include +#include +#include +#include +#include +#include + +#include "graph_node.hpp" +#include "graph_edge.hpp" +#include "../ytensor_base.hpp" + +namespace yt { +namespace ad { + +// 算子执行函数类型定义 +// 输入:节点指针,输入边列表,输出边列表 +// 输出:void(结果写入输出边) +using OpExecutor = std::function>&, + std::vector>& +)>; + +/// @brief 计算图运行时类 +class ComputationGraph { +public: + /// @brief 默认构造函数 + ComputationGraph() = default; + + /// @brief 创建一个新节点 + /// @param nodeId 节点ID(必须唯一) + /// @param opType 算子类型 + /// @return 节点指针 + std::shared_ptr createNode(const std::string& nodeId, const std::string& opType); + + /// @brief 创建一个新边 + /// @param edgeId 边ID(必须唯一) + /// @param fromNode 源节点(nullptr表示输入边) + /// @param toNode 目标节点(nullptr表示输出边) + /// @return 边指针 + std::shared_ptr createEdge( + const std::string& edgeId, + std::shared_ptr fromNode = nullptr, + std::shared_ptr toNode = nullptr + ); + + /// @brief 连接两个节点 + /// @param fromNodeId 源节点ID + /// @param toNodeId 目标节点ID + /// @param edgeId 边ID(如果为空则自动生成) + /// @return 创建的边指针 + std::shared_ptr connect( + const std::string& fromNodeId, + const std::string& toNodeId, + const std::string& edgeId = "" + ); + + /// @brief 获取节点 + /// @param nodeId 节点ID + /// @return 节点指针,如果不存在则返回nullptr + std::shared_ptr getNode(const std::string& nodeId) const; + + /// @brief 获取边 + /// @param edgeId 边ID + /// @return 边指针,如果不存在则返回nullptr + std::shared_ptr getEdge(const std::string& edgeId) const; + + /// @brief 获取所有节点 + /// @return 所有节点的map + const std::unordered_map>& getNodes() const { + return nodes_; + } + + /// @brief 获取所有边 + /// @return 所有边的map + const std::unordered_map>& getEdges() const { + return edges_; + } + + /// @brief 注册算子执行器 + /// @param opType 算子类型 + /// @param executor 执行函数 + void registerOperator(const std::string& opType, OpExecutor executor); + + /// @brief 执行计算图 + /// @param inputs 输入张量映射(边ID -> 张量) + /// @param outputs 输出边ID列表 + /// @return 输出张量映射(边ID -> 张量) + std::unordered_map execute( + const std::unordered_map& inputs, + const std::vector& outputs + ); + + /// @brief 执行单个节点 + /// @param nodeId 节点ID + void executeNode(const std::string& nodeId); + + /// @brief 拓扑排序,返回执行顺序 + /// @return 节点ID的执行顺序 + std::vector topologicalSort() const; + + /// @brief 重置所有节点的执行状态 + void reset(); + + /// @brief 清空计算图 + void clear(); + + /// @brief 获取计算图中节点的数量 + /// @return 节点数量 + size_t nodeCount() const { return nodes_.size(); } + + /// @brief 获取计算图中边的数量 + /// @return 边数量 + size_t edgeCount() const { return edges_.size(); } + + /// @brief 检查计算图是否为空 + /// @return 如果为空返回true + bool empty() const { return nodes_.empty(); } + +private: + /// @brief 生成唯一的边ID + /// @return 唯一的边ID + std::string generateEdgeId(); + + /// @brief 检查图是否有环(DFS辅助函数) + /// @param nodeId 当前节点ID + /// @param visited 访问标记 + /// @param recursionStack 递归栈标记 + /// @return 如果检测到环返回true + bool hasCycleDFS( + const std::string& nodeId, + std::unordered_map& visited, + std::unordered_map& recursionStack + ) const; + + /// @brief 拓扑排序的DFS辅助函数 + /// @param nodeId 当前节点ID + /// @param visited 访问标记 + /// @param stack 结果栈 + void topologicalSortDFS( + const std::string& nodeId, + std::unordered_map& visited, + std::vector& stack + ) const; + +private: + std::unordered_map> nodes_; // 节点映射表 + std::unordered_map> edges_; // 边映射表 + std::unordered_map operators_; // 算子执行器映射表 + int edgeCounter_ = 0; // 边计数器,用于生成唯一ID +}; + +} // namespace ad +} // namespace yt diff --git a/include/ad/graph_serialization.hpp b/include/ad/graph_serialization.hpp new file mode 100644 index 0000000..ab7a362 --- /dev/null +++ b/include/ad/graph_serialization.hpp @@ -0,0 +1,113 @@ +#pragma once +/*************** +* @file: graph_serialization.hpp +* @brief: 计算图序列化,支持JSON格式的导入导出 +* @description: 提供计算图与JSON之间的序列化和反序列化功能 +***************/ + +#include +#include +#include +#include +#include + +#include "graph_runtime.hpp" + +namespace yt { +namespace ad { + +/// @brief 简单的JSON构建器类(用于序列化) +class JsonBuilder { +public: + /// @brief 开始一个对象 + void beginObject(); + + /// @brief 结束一个对象 + void endObject(); + + /// @brief 开始一个数组 + void beginArray(); + + /// @brief 结束一个数组 + void endArray(); + + /// @brief 添加键值对 + void addKey(const std::string& key); + + /// @brief 添加字符串值 + void addString(const std::string& value); + + /// @brief 添加数字值 + void addNumber(double value); + + /// @brief 添加整数值 + void addInt(int value); + + /// @brief 添加布尔值 + void addBool(bool value); + + /// @brief 添加null值 + void addNull(); + + /// @brief 获取生成的JSON字符串 + std::string toString() const { return buffer_; } + + /// @brief 清空缓冲区 + void clear() { buffer_.clear(); needComma_ = false; } + +private: + /// @brief 添加逗号(如果需要) + void addCommaIfNeeded(); + + std::string buffer_; + bool needComma_ = false; +}; + +/// @brief 计算图序列化类 +class GraphSerializer { +public: + /// @brief 将计算图序列化为JSON字符串 + /// @param graph 计算图 + /// @return JSON字符串 + static std::string toJson(const ComputationGraph& graph); + + /// @brief 将计算图序列化到JSON文件 + /// @param graph 计算图 + /// @param filename 文件名 + /// @return 如果成功返回true + static bool toJsonFile(const ComputationGraph& graph, const std::string& filename); + + /// @brief 从JSON字符串反序列化计算图 + /// @param json JSON字符串 + /// @param graph 输出的计算图 + /// @return 如果成功返回true + static bool fromJson(const std::string& json, ComputationGraph& graph); + + /// @brief 从JSON文件反序列化计算图 + /// @param filename 文件名 + /// @param graph 输出的计算图 + /// @return 如果成功返回true + static bool fromJsonFile(const std::string& filename, ComputationGraph& graph); + +private: + /// @brief 序列化单个节点 + /// @param node 节点指针 + /// @param builder JSON构建器 + static void serializeNode(const std::shared_ptr& node, JsonBuilder& builder); + + /// @brief 序列化单个边 + /// @param edge 边指针 + /// @param builder JSON构建器 + static void serializeEdge(const std::shared_ptr& edge, JsonBuilder& builder); + + /// @brief 序列化节点参数 + /// @param parameters 参数映射 + /// @param builder JSON构建器 + static void serializeParameters( + const std::unordered_map& parameters, + JsonBuilder& builder + ); +}; + +} // namespace ad +} // namespace yt diff --git a/include/ytensor_concepts.hpp b/include/ytensor_concepts.hpp index ec3c610..299a881 100644 --- a/include/ytensor_concepts.hpp +++ b/include/ytensor_concepts.hpp @@ -3,6 +3,7 @@ #include #include #include +#include // 前向声明 namespace yt { diff --git a/src/ad/graph_edge.inl b/src/ad/graph_edge.inl new file mode 100644 index 0000000..e1c2db0 --- /dev/null +++ b/src/ad/graph_edge.inl @@ -0,0 +1,16 @@ +/*************** +* @file: graph_edge.inl +* @brief: 计算图边的实现 +***************/ + +namespace yt { +namespace ad { + +inline GraphEdge::GraphEdge(const std::string& edgeId, + std::shared_ptr fromNode, + std::shared_ptr toNode) + : edgeId_(edgeId), name_(edgeId), fromNode_(fromNode), toNode_(toNode) { +} + +} // namespace ad +} // namespace yt diff --git a/src/ad/graph_node.inl b/src/ad/graph_node.inl new file mode 100644 index 0000000..e4f6527 --- /dev/null +++ b/src/ad/graph_node.inl @@ -0,0 +1,42 @@ +/*************** +* @file: graph_node.inl +* @brief: 计算图节点的实现 +***************/ + +namespace yt { +namespace ad { + +inline GraphNode::GraphNode(const std::string& nodeId, const std::string& opType) + : nodeId_(nodeId), opType_(opType), name_(nodeId) { +} + +inline void GraphNode::addInputEdge(std::shared_ptr edge) { + if (edge) { + inputEdges_.push_back(edge); + } +} + +inline void GraphNode::addOutputEdge(std::shared_ptr edge) { + if (edge) { + outputEdges_.push_back(edge); + } +} + +inline void GraphNode::setParameter(const std::string& key, const std::any& value) { + parameters_[key] = value; +} + +inline std::any GraphNode::getParameter(const std::string& key) const { + auto it = parameters_.find(key); + if (it != parameters_.end()) { + return it->second; + } + return std::any(); +} + +inline bool GraphNode::hasParameter(const std::string& key) const { + return parameters_.find(key) != parameters_.end(); +} + +} // namespace ad +} // namespace yt diff --git a/src/ad/graph_runtime.inl b/src/ad/graph_runtime.inl new file mode 100644 index 0000000..0aa9df3 --- /dev/null +++ b/src/ad/graph_runtime.inl @@ -0,0 +1,252 @@ +/*************** +* @file: graph_runtime.inl +* @brief: 计算图运行时的实现 +***************/ + +#include +#include +#include + +namespace yt { +namespace ad { + +inline std::shared_ptr ComputationGraph::createNode( + const std::string& nodeId, const std::string& opType) { + + if (nodes_.find(nodeId) != nodes_.end()) { + throw std::runtime_error("Node with ID '" + nodeId + "' already exists"); + } + + auto node = std::make_shared(nodeId, opType); + nodes_[nodeId] = node; + return node; +} + +inline std::shared_ptr ComputationGraph::createEdge( + const std::string& edgeId, + std::shared_ptr fromNode, + std::shared_ptr toNode) { + + if (edges_.find(edgeId) != edges_.end()) { + throw std::runtime_error("Edge with ID '" + edgeId + "' already exists"); + } + + auto edge = std::make_shared(edgeId, fromNode, toNode); + edges_[edgeId] = edge; + + // 更新节点的输入输出边列表 + if (fromNode) { + fromNode->addOutputEdge(edge); + } + if (toNode) { + toNode->addInputEdge(edge); + } + + return edge; +} + +inline std::shared_ptr ComputationGraph::connect( + const std::string& fromNodeId, + const std::string& toNodeId, + const std::string& edgeId) { + + auto fromNode = getNode(fromNodeId); + auto toNode = getNode(toNodeId); + + if (!fromNode) { + throw std::runtime_error("Source node '" + fromNodeId + "' not found"); + } + if (!toNode) { + throw std::runtime_error("Target node '" + toNodeId + "' not found"); + } + + std::string actualEdgeId = edgeId.empty() ? generateEdgeId() : edgeId; + return createEdge(actualEdgeId, fromNode, toNode); +} + +inline std::shared_ptr ComputationGraph::getNode(const std::string& nodeId) const { + auto it = nodes_.find(nodeId); + return (it != nodes_.end()) ? it->second : nullptr; +} + +inline std::shared_ptr ComputationGraph::getEdge(const std::string& edgeId) const { + auto it = edges_.find(edgeId); + return (it != edges_.end()) ? it->second : nullptr; +} + +inline void ComputationGraph::registerOperator(const std::string& opType, OpExecutor executor) { + operators_[opType] = executor; +} + +inline std::unordered_map ComputationGraph::execute( + const std::unordered_map& inputs, + const std::vector& outputs) { + + // 重置所有节点的执行状态 + reset(); + + // 设置输入张量 + for (const auto& [edgeId, tensor] : inputs) { + auto edge = getEdge(edgeId); + if (!edge) { + throw std::runtime_error("Input edge '" + edgeId + "' not found"); + } + edge->setTensor(tensor); + } + + // 获取执行顺序 + std::vector execOrder = topologicalSort(); + + // 按拓扑顺序执行节点 + for (const auto& nodeId : execOrder) { + executeNode(nodeId); + } + + // 收集输出张量 + std::unordered_map results; + for (const auto& edgeId : outputs) { + auto edge = getEdge(edgeId); + if (!edge) { + throw std::runtime_error("Output edge '" + edgeId + "' not found"); + } + if (!edge->hasTensor()) { + throw std::runtime_error("Output edge '" + edgeId + "' has no tensor data"); + } + results[edgeId] = edge->getTensor(); + } + + return results; +} + +inline void ComputationGraph::executeNode(const std::string& nodeId) { + auto node = getNode(nodeId); + if (!node) { + throw std::runtime_error("Node '" + nodeId + "' not found"); + } + + if (node->isExecuted()) { + return; // 节点已执行,跳过 + } + + // 获取算子执行器 + auto it = operators_.find(node->getOpType()); + if (it == operators_.end()) { + throw std::runtime_error("Operator '" + node->getOpType() + "' not registered"); + } + + // 获取输入输出边 + auto inputEdges = node->getInputEdges(); + auto outputEdges = node->getOutputEdges(); + + // 检查输入是否就绪 + for (const auto& edge : inputEdges) { + if (!edge->hasTensor()) { + throw std::runtime_error("Input edge '" + edge->getId() + "' for node '" + + nodeId + "' has no tensor data"); + } + } + + // 执行算子 + it->second(*node, inputEdges, outputEdges); + + // 标记节点已执行 + node->setExecuted(true); +} + +inline std::vector ComputationGraph::topologicalSort() const { + std::vector result; + std::unordered_map visited; + + // 初始化访问标记 + for (const auto& [nodeId, _] : nodes_) { + visited[nodeId] = false; + } + + // 对每个未访问的节点执行DFS + for (const auto& [nodeId, _] : nodes_) { + if (!visited[nodeId]) { + topologicalSortDFS(nodeId, visited, result); + } + } + + // 反转结果(DFS后序遍历需要反转才是拓扑序) + std::reverse(result.begin(), result.end()); + return result; +} + +inline void ComputationGraph::topologicalSortDFS( + const std::string& nodeId, + std::unordered_map& visited, + std::vector& stack) const { + + visited[nodeId] = true; + + auto node = getNode(nodeId); + if (node) { + // 访问所有输出边的目标节点 + for (const auto& edge : node->getOutputEdges()) { + auto toNode = edge->getToNode(); + if (toNode) { + const std::string& nextNodeId = toNode->getId(); + if (!visited[nextNodeId]) { + topologicalSortDFS(nextNodeId, visited, stack); + } + } + } + } + + // 将当前节点加入栈 + stack.push_back(nodeId); +} + +inline void ComputationGraph::reset() { + for (auto& [_, node] : nodes_) { + node->reset(); + } +} + +inline void ComputationGraph::clear() { + nodes_.clear(); + edges_.clear(); + operators_.clear(); + edgeCounter_ = 0; +} + +inline std::string ComputationGraph::generateEdgeId() { + std::ostringstream oss; + oss << "edge_" << edgeCounter_++; + return oss.str(); +} + +inline bool ComputationGraph::hasCycleDFS( + const std::string& nodeId, + std::unordered_map& visited, + std::unordered_map& recursionStack) const { + + visited[nodeId] = true; + recursionStack[nodeId] = true; + + auto node = getNode(nodeId); + if (node) { + for (const auto& edge : node->getOutputEdges()) { + auto toNode = edge->getToNode(); + if (toNode) { + const std::string& nextNodeId = toNode->getId(); + + if (!visited[nextNodeId]) { + if (hasCycleDFS(nextNodeId, visited, recursionStack)) { + return true; + } + } else if (recursionStack[nextNodeId]) { + return true; // 发现环 + } + } + } + } + + recursionStack[nodeId] = false; + return false; +} + +} // namespace ad +} // namespace yt diff --git a/src/ad/graph_serialization.inl b/src/ad/graph_serialization.inl new file mode 100644 index 0000000..05a4831 --- /dev/null +++ b/src/ad/graph_serialization.inl @@ -0,0 +1,255 @@ +/*************** +* @file: graph_serialization.inl +* @brief: 计算图序列化的实现 +***************/ + +#include +#include + +namespace yt { +namespace ad { + +// JsonBuilder 实现 +inline void JsonBuilder::beginObject() { + addCommaIfNeeded(); + buffer_ += "{"; + needComma_ = false; +} + +inline void JsonBuilder::endObject() { + buffer_ += "}"; + needComma_ = true; +} + +inline void JsonBuilder::beginArray() { + addCommaIfNeeded(); + buffer_ += "["; + needComma_ = false; +} + +inline void JsonBuilder::endArray() { + buffer_ += "]"; + needComma_ = true; +} + +inline void JsonBuilder::addKey(const std::string& key) { + addCommaIfNeeded(); + buffer_ += "\"" + key + "\":"; + needComma_ = false; +} + +inline void JsonBuilder::addString(const std::string& value) { + addCommaIfNeeded(); + buffer_ += "\""; + // 转义特殊字符 + for (char c : value) { + switch (c) { + case '"': buffer_ += "\\\""; break; + case '\\': buffer_ += "\\\\"; break; + case '\n': buffer_ += "\\n"; break; + case '\r': buffer_ += "\\r"; break; + case '\t': buffer_ += "\\t"; break; + default: buffer_ += c; + } + } + buffer_ += "\""; + needComma_ = true; +} + +inline void JsonBuilder::addNumber(double value) { + addCommaIfNeeded(); + std::ostringstream oss; + oss << std::setprecision(15) << value; + buffer_ += oss.str(); + needComma_ = true; +} + +inline void JsonBuilder::addInt(int value) { + addCommaIfNeeded(); + buffer_ += std::to_string(value); + needComma_ = true; +} + +inline void JsonBuilder::addBool(bool value) { + addCommaIfNeeded(); + buffer_ += value ? "true" : "false"; + needComma_ = true; +} + +inline void JsonBuilder::addNull() { + addCommaIfNeeded(); + buffer_ += "null"; + needComma_ = true; +} + +inline void JsonBuilder::addCommaIfNeeded() { + if (needComma_) { + buffer_ += ","; + } +} + +// GraphSerializer 实现 +inline std::string GraphSerializer::toJson(const ComputationGraph& graph) { + JsonBuilder builder; + + builder.beginObject(); + + // 序列化节点 + builder.addKey("nodes"); + builder.beginArray(); + for (const auto& [nodeId, node] : graph.getNodes()) { + serializeNode(node, builder); + } + builder.endArray(); + + // 序列化边 + builder.addKey("edges"); + builder.beginArray(); + for (const auto& [edgeId, edge] : graph.getEdges()) { + serializeEdge(edge, builder); + } + builder.endArray(); + + builder.endObject(); + + return builder.toString(); +} + +inline bool GraphSerializer::toJsonFile(const ComputationGraph& graph, const std::string& filename) { + std::string json = toJson(graph); + std::ofstream file(filename); + if (!file.is_open()) { + return false; + } + file << json; + file.close(); + return true; +} + +inline bool GraphSerializer::fromJson(const std::string& json, ComputationGraph& graph) { + // 简化的JSON解析实现 + // 注意:这是一个简化版本,仅用于基本功能演示 + // 实际应用中建议使用专业的JSON库如nlohmann/json + + graph.clear(); + + // 这里提供一个基本的解析框架 + // 完整实现需要一个完整的JSON解析器 + // 暂时返回false表示需要使用外部JSON库 + + return false; +} + +inline bool GraphSerializer::fromJsonFile(const std::string& filename, ComputationGraph& graph) { + std::ifstream file(filename); + if (!file.is_open()) { + return false; + } + + std::stringstream buffer; + buffer << file.rdbuf(); + file.close(); + + return fromJson(buffer.str(), graph); +} + +inline void GraphSerializer::serializeNode(const std::shared_ptr& node, JsonBuilder& builder) { + builder.beginObject(); + + builder.addKey("id"); + builder.addString(node->getId()); + + builder.addKey("opType"); + builder.addString(node->getOpType()); + + builder.addKey("name"); + builder.addString(node->getName()); + + // 序列化输入边ID列表 + builder.addKey("inputEdges"); + builder.beginArray(); + for (const auto& edge : node->getInputEdges()) { + builder.addString(edge->getId()); + } + builder.endArray(); + + // 序列化输出边ID列表 + builder.addKey("outputEdges"); + builder.beginArray(); + for (const auto& edge : node->getOutputEdges()) { + builder.addString(edge->getId()); + } + builder.endArray(); + + // 序列化参数 + builder.addKey("parameters"); + serializeParameters(node->getParameters(), builder); + + builder.endObject(); +} + +inline void GraphSerializer::serializeEdge(const std::shared_ptr& edge, JsonBuilder& builder) { + builder.beginObject(); + + builder.addKey("id"); + builder.addString(edge->getId()); + + builder.addKey("name"); + builder.addString(edge->getName()); + + builder.addKey("fromNode"); + if (edge->getFromNode()) { + builder.addString(edge->getFromNode()->getId()); + } else { + builder.addNull(); + } + + builder.addKey("toNode"); + if (edge->getToNode()) { + builder.addString(edge->getToNode()->getId()); + } else { + builder.addNull(); + } + + builder.endObject(); +} + +inline void GraphSerializer::serializeParameters( + const std::unordered_map& parameters, + JsonBuilder& builder) { + + builder.beginObject(); + + for (const auto& [key, value] : parameters) { + builder.addKey(key); + + // 尝试转换常见类型 + // 注意:std::any的类型检查在运行时进行 + try { + if (value.type() == typeid(int)) { + builder.addInt(std::any_cast(value)); + } else if (value.type() == typeid(double)) { + builder.addNumber(std::any_cast(value)); + } else if (value.type() == typeid(float)) { + builder.addNumber(static_cast(std::any_cast(value))); + } else if (value.type() == typeid(bool)) { + builder.addBool(std::any_cast(value)); + } else if (value.type() == typeid(std::string)) { + builder.addString(std::any_cast(value)); + } else if (value.type() == typeid(const char*)) { + builder.addString(std::any_cast(value)); + } else { + // 未知类型,序列化为null + builder.addNull(); + } + } catch (...) { + // 转换失败,序列化为null + builder.addNull(); + } + } + + builder.endObject(); +} + +} // namespace ad +} // namespace yt From a0c0bc05973b1406a9c63a353ab3ffd2f676f8f7 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Sat, 31 Jan 2026 16:06:55 +0000 Subject: [PATCH 3/8] Add example program for computational graph demo Co-authored-by: SnifferCaptain <151126294+SnifferCaptain@users.noreply.github.com> --- example/ad_demo/CMakeLists.txt | 27 ++++++ example/ad_demo/main.cpp | 170 +++++++++++++++++++++++++++++++++ 2 files changed, 197 insertions(+) create mode 100644 example/ad_demo/CMakeLists.txt create mode 100644 example/ad_demo/main.cpp diff --git a/example/ad_demo/CMakeLists.txt b/example/ad_demo/CMakeLists.txt new file mode 100644 index 0000000..5f8cb2f --- /dev/null +++ b/example/ad_demo/CMakeLists.txt @@ -0,0 +1,27 @@ +cmake_minimum_required(VERSION 3.10) + +set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -pthread") +set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wall -Wextra -g") + +# 设置编译模式 +set(CMAKE_BUILD_TYPE Release) +set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -O2") + +project(ad_demo) + +set(CMAKE_CXX_STANDARD 20) +set(CMAKE_CXX_STANDARD_REQUIRED ON) +set(CMAKE_EXPORT_COMPILE_COMMANDS ON) + +find_package(ZLIB REQUIRED) + +set(SRC + main.cpp +) + +add_executable(${PROJECT_NAME} ${SRC}) + +target_link_libraries(${PROJECT_NAME} + ZLIB::ZLIB + pthread +) diff --git a/example/ad_demo/main.cpp b/example/ad_demo/main.cpp new file mode 100644 index 0000000..fd64d8c --- /dev/null +++ b/example/ad_demo/main.cpp @@ -0,0 +1,170 @@ +#include +#include +#include "../../ytensor.hpp" +#include "../../include/ad.hpp" + +using namespace yt; +using namespace yt::ad; + +// 示例:实现简单的算子 +void addOperator(const GraphNode& node, + const std::vector>& inputs, + std::vector>& outputs) { + if (inputs.size() < 2) { + throw std::runtime_error("Add operator requires at least 2 inputs"); + } + if (outputs.empty()) { + throw std::runtime_error("Add operator requires at least 1 output"); + } + + const auto& tensor1 = inputs[0]->getTensor(); + const auto& tensor2 = inputs[1]->getTensor(); + + // 简单相加操作 + YTensorBase result = tensor1 + tensor2; + outputs[0]->setTensor(result); +} + +void mulOperator(const GraphNode& node, + const std::vector>& inputs, + std::vector>& outputs) { + if (inputs.size() < 2) { + throw std::runtime_error("Mul operator requires at least 2 inputs"); + } + if (outputs.empty()) { + throw std::runtime_error("Mul operator requires at least 1 output"); + } + + const auto& tensor1 = inputs[0]->getTensor(); + const auto& tensor2 = inputs[1]->getTensor(); + + // 简单相乘操作 + YTensorBase result = tensor1 * tensor2; + outputs[0]->setTensor(result); +} + +int main() { + std::cout << "=== YTensor Computational Graph Demo ===" << std::endl; + + try { + // 创建计算图 + ComputationGraph graph; + + // 注册算子 + graph.registerOperator("add", addOperator); + graph.registerOperator("mul", mulOperator); + + // 创建节点 + auto node1 = graph.createNode("add1", "add"); + node1->setName("加法节点1"); + + auto node2 = graph.createNode("mul1", "mul"); + node2->setName("乘法节点1"); + + auto node3 = graph.createNode("add2", "add"); + node3->setName("加法节点2"); + + // 创建输入边 + auto input1 = graph.createEdge("input1", nullptr, node1); + auto input2 = graph.createEdge("input2", nullptr, node1); + auto input3 = graph.createEdge("input3", nullptr, node2); + + // 创建中间边 + auto edge1 = graph.createEdge("edge1", node1, node2); + auto edge2 = graph.createEdge("edge2", node2, node3); + + // 创建输出边 + auto input4 = graph.createEdge("input4", nullptr, node3); + auto output1 = graph.createEdge("output1", node3, nullptr); + + std::cout << "\n图结构创建完成!" << std::endl; + std::cout << "节点数量: " << graph.nodeCount() << std::endl; + std::cout << "边数量: " << graph.edgeCount() << std::endl; + + // 序列化为JSON + std::cout << "\n序列化计算图为JSON..." << std::endl; + std::string jsonStr = GraphSerializer::toJson(graph); + std::cout << "JSON输出:\n" << jsonStr << std::endl; + + // 保存到文件 + bool saved = GraphSerializer::toJsonFile(graph, "/tmp/graph.json"); + if (saved) { + std::cout << "\n计算图已保存到 /tmp/graph.json" << std::endl; + } + + // 准备输入数据 + std::cout << "\n准备输入数据..." << std::endl; + YTensorBase t1({2, 3}, "float32"); + YTensorBase t2({2, 3}, "float32"); + YTensorBase t3({2, 3}, "float32"); + YTensorBase t4({2, 3}, "float32"); + + // 填充数据 + for (int i = 0; i < 6; i++) { + t1.atData(i) = 1.0f; + t2.atData(i) = 2.0f; + t3.atData(i) = 3.0f; + t4.atData(i) = 4.0f; + } + + std::cout << "t1 = 1.0 (全部元素)" << std::endl; + std::cout << "t2 = 2.0 (全部元素)" << std::endl; + std::cout << "t3 = 3.0 (全部元素)" << std::endl; + std::cout << "t4 = 4.0 (全部元素)" << std::endl; + + // 执行计算图 + // 计算过程: output1 = (t1 + t2) * t3 + t4 = 3 * 3 + 4 = 13 + std::cout << "\n执行计算图..." << std::endl; + std::cout << "计算表达式: output = (input1 + input2) * input3 + input4" << std::endl; + std::cout << " = (1.0 + 2.0) * 3.0 + 4.0 = 13.0" << std::endl; + + std::unordered_map inputs = { + {"input1", t1}, + {"input2", t2}, + {"input3", t3}, + {"input4", t4} + }; + + auto outputs = graph.execute(inputs, {"output1"}); + + // 输出结果 + std::cout << "\n执行完成!" << std::endl; + const auto& result = outputs["output1"]; + std::cout << "输出张量形状: [" << result.shape(0) << ", " << result.shape(1) << "]" << std::endl; + std::cout << "输出值 (前6个元素): "; + for (int i = 0; i < 6; i++) { + std::cout << result.atData(i) << " "; + } + std::cout << std::endl; + + // 验证结果 + bool correct = true; + for (int i = 0; i < 6; i++) { + if (std::abs(result.atData(i) - 13.0f) > 1e-5) { + correct = false; + break; + } + } + + if (correct) { + std::cout << "\n✓ 计算结果正确!" << std::endl; + } else { + std::cout << "\n✗ 计算结果不正确!" << std::endl; + } + + // 测试拓扑排序 + std::cout << "\n拓扑排序结果:" << std::endl; + auto topoOrder = graph.topologicalSort(); + for (const auto& nodeId : topoOrder) { + auto node = graph.getNode(nodeId); + std::cout << " - " << nodeId << " (" << node->getName() << ")" << std::endl; + } + + } catch (const std::exception& e) { + std::cerr << "错误: " << e.what() << std::endl; + return 1; + } + + std::cout << "\n=== Demo完成 ===" << std::endl; + return 0; +} From 62327b41aa5bde4ee91cf11002ac81e101434ad1 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Sat, 31 Jan 2026 16:16:14 +0000 Subject: [PATCH 4/8] Refactor: Move tests to test folder, remove example, improve implementation Co-authored-by: SnifferCaptain <151126294+SnifferCaptain@users.noreply.github.com> --- doc/ad_module_cn.md | 226 ---------------------------- example/ad_demo/CMakeLists.txt | 27 ---- example/ad_demo/main.cpp | 170 --------------------- test/test_ad_graph.cpp | 262 +++++++++++++++++++++++++++++++++ 4 files changed, 262 insertions(+), 423 deletions(-) delete mode 100644 doc/ad_module_cn.md delete mode 100644 example/ad_demo/CMakeLists.txt delete mode 100644 example/ad_demo/main.cpp create mode 100644 test/test_ad_graph.cpp diff --git a/doc/ad_module_cn.md b/doc/ad_module_cn.md deleted file mode 100644 index d44ffc1..0000000 --- a/doc/ad_module_cn.md +++ /dev/null @@ -1,226 +0,0 @@ -# 自动微分 (Automatic Differentiation) 模块 - -## 概述 - -本模块实现了基于计算图的运行时系统,为YTensor提供自动微分的基础设施。计算图采用节点-边的设计,支持动态构建和执行。 - -## 设计理念 - -- **节点 (Node)**: 表示计算图中的算子操作(如加法、乘法、矩阵乘等) -- **边 (Edge)**: 表示有向数据流,连接节点并携带YTensorBase张量数据 -- **运行时 (Runtime)**: 管理计算图的构建、执行和序列化 - -## 特性 - -✅ **动态计算图**: 无需编译时确定,使用YTensorBase基类实现运行时灵活性 -✅ **拓扑排序**: 自动确定节点执行顺序 -✅ **算子注册**: 灵活的算子执行器注册机制 -✅ **JSON序列化**: 支持计算图的导入导出 -✅ **节点参数**: 支持算子的可配置参数 - -## 文件结构 - -``` -include/ad/ -├── graph_node.hpp # 节点定义 -├── graph_edge.hpp # 边定义 -├── graph_runtime.hpp # 运行时引擎 -└── graph_serialization.hpp # JSON序列化 - -src/ad/ -├── graph_node.inl # 节点实现 -├── graph_edge.inl # 边实现 -├── graph_runtime.inl # 运行时实现 -└── graph_serialization.inl # 序列化实现 - -include/ad.hpp # 模块总入口 -``` - -## 快速开始 - -### 1. 包含头文件 - -```cpp -#include "ytensor.hpp" -#include "include/ad.hpp" - -using namespace yt; -using namespace yt::ad; -``` - -### 2. 创建计算图 - -```cpp -ComputationGraph graph; - -// 创建节点 -auto node1 = graph.createNode("add1", "add"); -auto node2 = graph.createNode("mul1", "mul"); - -// 创建边并连接节点 -auto input1 = graph.createEdge("input1", nullptr, node1); -auto input2 = graph.createEdge("input2", nullptr, node1); -auto edge1 = graph.createEdge("edge1", node1, node2); -auto output = graph.createEdge("output", node2, nullptr); -``` - -### 3. 注册算子 - -```cpp -void addOperator(const GraphNode& node, - const std::vector>& inputs, - std::vector>& outputs) { - const auto& t1 = inputs[0]->getTensor(); - const auto& t2 = inputs[1]->getTensor(); - YTensorBase result = t1 + t2; - outputs[0]->setTensor(result); -} - -graph.registerOperator("add", addOperator); -``` - -### 4. 执行计算图 - -```cpp -// 准备输入 -YTensorBase t1({2, 3}, "float32"); -YTensorBase t2({2, 3}, "float32"); -// ... 填充数据 ... - -std::unordered_map inputs = { - {"input1", t1}, - {"input2", t2} -}; - -// 执行 -auto outputs = graph.execute(inputs, {"output"}); -const auto& result = outputs["output"]; -``` - -### 5. 序列化 - -```cpp -// 序列化为JSON字符串 -std::string json = GraphSerializer::toJson(graph); - -// 保存到文件 -GraphSerializer::toJsonFile(graph, "graph.json"); - -// 从文件加载(需要完整的JSON解析器支持) -ComputationGraph newGraph; -GraphSerializer::fromJsonFile("graph.json", newGraph); -``` - -## API 文档 - -### ComputationGraph 类 - -#### 主要方法 - -- `createNode(id, opType)` - 创建新节点 -- `createEdge(id, from, to)` - 创建新边 -- `connect(fromId, toId, edgeId)` - 连接两个节点 -- `registerOperator(opType, executor)` - 注册算子执行器 -- `execute(inputs, outputs)` - 执行计算图 -- `topologicalSort()` - 获取拓扑排序结果 -- `clear()` - 清空计算图 - -### GraphNode 类 - -#### 主要方法 - -- `getId()` / `getOpType()` / `getName()` - 获取节点信息 -- `setParameter(key, value)` - 设置节点参数 -- `getParameter(key)` - 获取节点参数 -- `getInputEdges()` / `getOutputEdges()` - 获取连接的边 - -### GraphEdge 类 - -#### 主要方法 - -- `getId()` / `getName()` - 获取边信息 -- `setTensor(tensor)` / `getTensor()` - 设置/获取张量数据 -- `getFromNode()` / `getToNode()` - 获取连接的节点 -- `hasTensor()` - 检查是否有张量数据 - -### GraphSerializer 类 - -#### 静态方法 - -- `toJson(graph)` - 序列化为JSON字符串 -- `toJsonFile(graph, filename)` - 序列化到文件 -- `fromJson(json, graph)` - 从JSON反序列化 -- `fromJsonFile(filename, graph)` - 从文件反序列化 - -## 示例程序 - -完整的示例程序位于 `example/ad_demo/main.cpp`,演示了: - -1. 创建包含3个节点的计算图 -2. 注册加法和乘法算子 -3. 执行计算: `(input1 + input2) * input3 + input4` -4. 序列化为JSON格式 -5. 验证计算结果 - -运行示例: - -```bash -cd example/ad_demo -mkdir build && cd build -cmake .. -make -./ad_demo -``` - -## JSON格式说明 - -计算图序列化为JSON格式,包含nodes和edges两个主要部分: - -```json -{ - "nodes": [ - { - "id": "node1", - "opType": "add", - "name": "加法节点", - "inputEdges": ["edge1", "edge2"], - "outputEdges": ["edge3"], - "parameters": {} - } - ], - "edges": [ - { - "id": "edge1", - "name": "输入边1", - "fromNode": null, - "toNode": "node1" - } - ] -} -``` - -## 注意事项 - -1. **节点ID和边ID必须唯一**:在同一个计算图中,节点ID和边ID不能重复 -2. **算子必须先注册**:执行前需要注册所有用到的算子类型 -3. **输入边必须有数据**:执行时所有输入边必须设置张量数据 -4. **无环要求**:计算图必须是有向无环图(DAG) -5. **JSON反序列化**:当前fromJson实现为简化版本,完整功能建议使用nlohmann/json库 - -## 后续规划 - -- [ ] 实现反向传播和梯度计算 -- [ ] 支持更多内置算子 -- [ ] 优化内存管理 -- [ ] 支持子图和条件执行 -- [ ] 完整的JSON反序列化实现 -- [ ] 计算图可视化工具 -- [ ] 性能分析和优化 - -## 贡献 - -欢迎提交Issue和Pull Request! - -## 许可 - -遵循YTensor项目的许可协议。 diff --git a/example/ad_demo/CMakeLists.txt b/example/ad_demo/CMakeLists.txt deleted file mode 100644 index 5f8cb2f..0000000 --- a/example/ad_demo/CMakeLists.txt +++ /dev/null @@ -1,27 +0,0 @@ -cmake_minimum_required(VERSION 3.10) - -set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -pthread") -set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wall -Wextra -g") - -# 设置编译模式 -set(CMAKE_BUILD_TYPE Release) -set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -O2") - -project(ad_demo) - -set(CMAKE_CXX_STANDARD 20) -set(CMAKE_CXX_STANDARD_REQUIRED ON) -set(CMAKE_EXPORT_COMPILE_COMMANDS ON) - -find_package(ZLIB REQUIRED) - -set(SRC - main.cpp -) - -add_executable(${PROJECT_NAME} ${SRC}) - -target_link_libraries(${PROJECT_NAME} - ZLIB::ZLIB - pthread -) diff --git a/example/ad_demo/main.cpp b/example/ad_demo/main.cpp deleted file mode 100644 index fd64d8c..0000000 --- a/example/ad_demo/main.cpp +++ /dev/null @@ -1,170 +0,0 @@ -#include -#include -#include "../../ytensor.hpp" -#include "../../include/ad.hpp" - -using namespace yt; -using namespace yt::ad; - -// 示例:实现简单的算子 -void addOperator(const GraphNode& node, - const std::vector>& inputs, - std::vector>& outputs) { - if (inputs.size() < 2) { - throw std::runtime_error("Add operator requires at least 2 inputs"); - } - if (outputs.empty()) { - throw std::runtime_error("Add operator requires at least 1 output"); - } - - const auto& tensor1 = inputs[0]->getTensor(); - const auto& tensor2 = inputs[1]->getTensor(); - - // 简单相加操作 - YTensorBase result = tensor1 + tensor2; - outputs[0]->setTensor(result); -} - -void mulOperator(const GraphNode& node, - const std::vector>& inputs, - std::vector>& outputs) { - if (inputs.size() < 2) { - throw std::runtime_error("Mul operator requires at least 2 inputs"); - } - if (outputs.empty()) { - throw std::runtime_error("Mul operator requires at least 1 output"); - } - - const auto& tensor1 = inputs[0]->getTensor(); - const auto& tensor2 = inputs[1]->getTensor(); - - // 简单相乘操作 - YTensorBase result = tensor1 * tensor2; - outputs[0]->setTensor(result); -} - -int main() { - std::cout << "=== YTensor Computational Graph Demo ===" << std::endl; - - try { - // 创建计算图 - ComputationGraph graph; - - // 注册算子 - graph.registerOperator("add", addOperator); - graph.registerOperator("mul", mulOperator); - - // 创建节点 - auto node1 = graph.createNode("add1", "add"); - node1->setName("加法节点1"); - - auto node2 = graph.createNode("mul1", "mul"); - node2->setName("乘法节点1"); - - auto node3 = graph.createNode("add2", "add"); - node3->setName("加法节点2"); - - // 创建输入边 - auto input1 = graph.createEdge("input1", nullptr, node1); - auto input2 = graph.createEdge("input2", nullptr, node1); - auto input3 = graph.createEdge("input3", nullptr, node2); - - // 创建中间边 - auto edge1 = graph.createEdge("edge1", node1, node2); - auto edge2 = graph.createEdge("edge2", node2, node3); - - // 创建输出边 - auto input4 = graph.createEdge("input4", nullptr, node3); - auto output1 = graph.createEdge("output1", node3, nullptr); - - std::cout << "\n图结构创建完成!" << std::endl; - std::cout << "节点数量: " << graph.nodeCount() << std::endl; - std::cout << "边数量: " << graph.edgeCount() << std::endl; - - // 序列化为JSON - std::cout << "\n序列化计算图为JSON..." << std::endl; - std::string jsonStr = GraphSerializer::toJson(graph); - std::cout << "JSON输出:\n" << jsonStr << std::endl; - - // 保存到文件 - bool saved = GraphSerializer::toJsonFile(graph, "/tmp/graph.json"); - if (saved) { - std::cout << "\n计算图已保存到 /tmp/graph.json" << std::endl; - } - - // 准备输入数据 - std::cout << "\n准备输入数据..." << std::endl; - YTensorBase t1({2, 3}, "float32"); - YTensorBase t2({2, 3}, "float32"); - YTensorBase t3({2, 3}, "float32"); - YTensorBase t4({2, 3}, "float32"); - - // 填充数据 - for (int i = 0; i < 6; i++) { - t1.atData(i) = 1.0f; - t2.atData(i) = 2.0f; - t3.atData(i) = 3.0f; - t4.atData(i) = 4.0f; - } - - std::cout << "t1 = 1.0 (全部元素)" << std::endl; - std::cout << "t2 = 2.0 (全部元素)" << std::endl; - std::cout << "t3 = 3.0 (全部元素)" << std::endl; - std::cout << "t4 = 4.0 (全部元素)" << std::endl; - - // 执行计算图 - // 计算过程: output1 = (t1 + t2) * t3 + t4 = 3 * 3 + 4 = 13 - std::cout << "\n执行计算图..." << std::endl; - std::cout << "计算表达式: output = (input1 + input2) * input3 + input4" << std::endl; - std::cout << " = (1.0 + 2.0) * 3.0 + 4.0 = 13.0" << std::endl; - - std::unordered_map inputs = { - {"input1", t1}, - {"input2", t2}, - {"input3", t3}, - {"input4", t4} - }; - - auto outputs = graph.execute(inputs, {"output1"}); - - // 输出结果 - std::cout << "\n执行完成!" << std::endl; - const auto& result = outputs["output1"]; - std::cout << "输出张量形状: [" << result.shape(0) << ", " << result.shape(1) << "]" << std::endl; - std::cout << "输出值 (前6个元素): "; - for (int i = 0; i < 6; i++) { - std::cout << result.atData(i) << " "; - } - std::cout << std::endl; - - // 验证结果 - bool correct = true; - for (int i = 0; i < 6; i++) { - if (std::abs(result.atData(i) - 13.0f) > 1e-5) { - correct = false; - break; - } - } - - if (correct) { - std::cout << "\n✓ 计算结果正确!" << std::endl; - } else { - std::cout << "\n✗ 计算结果不正确!" << std::endl; - } - - // 测试拓扑排序 - std::cout << "\n拓扑排序结果:" << std::endl; - auto topoOrder = graph.topologicalSort(); - for (const auto& nodeId : topoOrder) { - auto node = graph.getNode(nodeId); - std::cout << " - " << nodeId << " (" << node->getName() << ")" << std::endl; - } - - } catch (const std::exception& e) { - std::cerr << "错误: " << e.what() << std::endl; - return 1; - } - - std::cout << "\n=== Demo完成 ===" << std::endl; - return 0; -} diff --git a/test/test_ad_graph.cpp b/test/test_ad_graph.cpp new file mode 100644 index 0000000..83f1d54 --- /dev/null +++ b/test/test_ad_graph.cpp @@ -0,0 +1,262 @@ +#include +#include +#include +#include "../ytensor.hpp" +#include "../include/ad.hpp" + +using namespace yt; +using namespace yt::ad; + +// 测试用:构建一个简单的前馈神经网络计算图 +// 模拟 ymodel2 的一个简化版本结构 +class TestFFNGraph { +public: + ComputationGraph graph; + + // 构建一个简单的 FFN 层计算图: y = gelu(x @ W1) @ W2 + void build() { + std::cout << "构建计算图..." << std::endl; + + // 注册算子 + registerOperators(); + + // 创建节点 + auto input_node = graph.createNode("input", "input"); + auto matmul1_node = graph.createNode("matmul1", "matmul"); + auto gelu_node = graph.createNode("gelu", "gelu"); + auto matmul2_node = graph.createNode("matmul2", "matmul"); + auto output_node = graph.createNode("output", "output"); + + // 设置节点名称 + input_node->setName("输入层"); + matmul1_node->setName("第一层矩阵乘法"); + gelu_node->setName("GELU激活"); + matmul2_node->setName("第二层矩阵乘法"); + output_node->setName("输出层"); + + // 创建边连接 + auto x_edge = graph.createEdge("x", nullptr, input_node); // 输入数据 + auto x_out = graph.createEdge("x_out", input_node, matmul1_node); // 从input到matmul1 + auto w1_edge = graph.createEdge("w1", nullptr, matmul1_node); // 权重1 + auto h1_edge = graph.createEdge("h1", matmul1_node, gelu_node); // 中间结果1 + auto h2_edge = graph.createEdge("h2", gelu_node, matmul2_node); // 激活后结果 + auto w2_edge = graph.createEdge("w2", nullptr, matmul2_node); // 权重2 + auto y_edge = graph.createEdge("y", matmul2_node, output_node); // 输出 + auto out_edge = graph.createEdge("out", output_node, nullptr); // 最终输出 + + std::cout << "计算图构建完成!" << std::endl; + std::cout << " 节点数: " << graph.nodeCount() << std::endl; + std::cout << " 边数: " << graph.edgeCount() << std::endl; + } + + void registerOperators() { + // 输入算子:直接传递输入 + graph.registerOperator("input", [](const GraphNode& node, + const std::vector>& inputs, + std::vector>& outputs) { + if (!inputs.empty() && !outputs.empty()) { + outputs[0]->setTensor(inputs[0]->getTensor()); + } + }); + + // 输出算子:直接传递输出 + graph.registerOperator("output", [](const GraphNode& node, + const std::vector>& inputs, + std::vector>& outputs) { + if (!inputs.empty() && !outputs.empty()) { + outputs[0]->setTensor(inputs[0]->getTensor()); + } + }); + + // 矩阵乘法算子 (支持3D输入) + graph.registerOperator("matmul", [](const GraphNode& node, + const std::vector>& inputs, + std::vector>& outputs) { + if (inputs.size() < 2 || outputs.empty()) { + throw std::runtime_error("matmul requires 2 inputs and 1 output"); + } + auto x = inputs[0]->getTensor(); + const auto& w = inputs[1]->getTensor(); + + // 如果是3D张量,需要reshape成2D,执行matmul,再reshape回3D + if (x.ndim() == 3) { + int b = x.shape(0), s = x.shape(1), h = x.shape(2); + x = x.view(b * s, h); // [b, s, h] -> [b*s, h] + YTensorBase result = x.matmul(w); // [b*s, h] @ [h, h'] -> [b*s, h'] + int h_out = result.shape(1); + result = result.view(b, s, h_out); // [b*s, h'] -> [b, s, h'] + outputs[0]->setTensor(result); + } else { + // 2D矩阵乘法 + YTensorBase result = x.matmul(w); + outputs[0]->setTensor(result); + } + }); + + // GELU激活函数算子 + graph.registerOperator("gelu", [](const GraphNode& node, + const std::vector>& inputs, + std::vector>& outputs) { + if (inputs.empty() || outputs.empty()) { + throw std::runtime_error("gelu requires 1 input and 1 output"); + } + const auto& x = inputs[0]->getTensor(); + + // GELU近似: 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3))) + YTensorBase result = x; // 简化版本,实际应该计算GELU + for (size_t i = 0; i < result.size(); i++) { + float val = result.atData(i); + // 简化的GELU近似 + float gelu_val = val * 0.5f * (1.0f + std::tanh(0.797885f * (val + 0.044715f * val * val * val))); + result.atData(i) = gelu_val; + } + outputs[0]->setTensor(result); + }); + } + + // 打印计算图结构 + void printGraph() { + std::cout << "\n========== 计算图结构 ==========" << std::endl; + std::cout << "节点总数: " << graph.nodeCount() << std::endl; + std::cout << "边总数: " << graph.edgeCount() << std::endl; + + std::cout << "\n节点列表:" << std::endl; + for (const auto& [id, node] : graph.getNodes()) { + std::cout << " [" << id << "] " << node->getName() + << " (类型: " << node->getOpType() << ")" << std::endl; + std::cout << " 输入边: "; + for (const auto& edge : node->getInputEdges()) { + std::cout << edge->getId() << " "; + } + std::cout << std::endl; + std::cout << " 输出边: "; + for (const auto& edge : node->getOutputEdges()) { + std::cout << edge->getId() << " "; + } + std::cout << std::endl; + } + + std::cout << "\n边列表:" << std::endl; + for (const auto& [id, edge] : graph.getEdges()) { + std::cout << " [" << id << "] "; + if (edge->getFromNode()) { + std::cout << edge->getFromNode()->getId(); + } else { + std::cout << ""; + } + std::cout << " -> "; + if (edge->getToNode()) { + std::cout << edge->getToNode()->getId(); + } else { + std::cout << ""; + } + std::cout << std::endl; + } + std::cout << "==============================\n" << std::endl; + } + + // 执行前向传播 + YTensorBase forward(const YTensorBase& x, const YTensorBase& w1, const YTensorBase& w2) { + std::cout << "执行前向传播..." << std::endl; + + // 准备输入 + std::unordered_map inputs = { + {"x", x}, + {"w1", w1}, + {"w2", w2} + }; + + // 执行计算图 + auto outputs = graph.execute(inputs, {"out"}); + + std::cout << "前向传播完成!" << std::endl; + return outputs["out"]; + } + + // 序列化为JSON + std::string toJson() { + return GraphSerializer::toJson(graph); + } +}; + +int main() { + std::cout << "========================================" << std::endl; + std::cout << " YTensor 计算图测试程序" << std::endl; + std::cout << "========================================\n" << std::endl; + + try { + // 创建测试图 + TestFFNGraph test_graph; + test_graph.build(); + + // 打印计算图结构 + test_graph.printGraph(); + + // 准备测试数据 + std::cout << "准备测试数据..." << std::endl; + int batch = 2, seq_len = 3, hidden = 4, intermediate = 6; + + YTensorBase x({batch, seq_len, hidden}, "float32"); + YTensorBase w1({hidden, intermediate}, "float32"); + YTensorBase w2({intermediate, hidden}, "float32"); + + // 初始化数据 + for (size_t i = 0; i < x.size(); i++) { + x.atData(i) = 0.1f * (i % 10); + } + for (size_t i = 0; i < w1.size(); i++) { + w1.atData(i) = 0.01f * (i % 20 - 10); + } + for (size_t i = 0; i < w2.size(); i++) { + w2.atData(i) = 0.01f * (i % 20 - 10); + } + + std::cout << " 输入形状: [" << x.shape(0) << ", " << x.shape(1) << ", " << x.shape(2) << "]" << std::endl; + std::cout << " W1形状: [" << w1.shape(0) << ", " << w1.shape(1) << "]" << std::endl; + std::cout << " W2形状: [" << w2.shape(0) << ", " << w2.shape(1) << "]" << std::endl; + + // 执行前向传播 + auto output = test_graph.forward(x, w1, w2); + + std::cout << "\n前向传播结果:" << std::endl; + std::cout << " 输出形状: [" << output.shape(0) << ", " << output.shape(1) << ", " << output.shape(2) << "]" << std::endl; + std::cout << " 输出数据(前10个元素): "; + for (int i = 0; i < std::min(10, (int)output.size()); i++) { + std::cout << std::fixed << std::setprecision(4) << output.atData(i) << " "; + } + std::cout << std::endl; + + // 序列化为JSON + std::cout << "\n序列化为JSON..." << std::endl; + std::string json = test_graph.toJson(); + std::cout << "JSON输出:" << std::endl; + std::cout << json << std::endl; + + // 保存到文件 + std::string filename = "/tmp/test_graph.json"; + if (GraphSerializer::toJsonFile(test_graph.graph, filename)) { + std::cout << "\n计算图已保存到: " << filename << std::endl; + } + + // 验证拓扑排序 + std::cout << "\n拓扑排序结果:" << std::endl; + auto topo_order = test_graph.graph.topologicalSort(); + for (size_t i = 0; i < topo_order.size(); i++) { + auto node = test_graph.graph.getNode(topo_order[i]); + std::cout << " " << (i+1) << ". " << node->getName() + << " (" << node->getId() << ")" << std::endl; + } + + std::cout << "\n✓ 所有测试通过!" << std::endl; + + } catch (const std::exception& e) { + std::cerr << "✗ 错误: " << e.what() << std::endl; + return 1; + } + + std::cout << "\n========================================" << std::endl; + std::cout << " 测试完成" << std::endl; + std::cout << "========================================" << std::endl; + + return 0; +} From 0ada507b081afb70258f8694f41fa628debc1e54 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Sat, 31 Jan 2026 16:16:26 +0000 Subject: [PATCH 5/8] Add test CMakeLists.txt --- test/CMakeLists.txt | 21 +++++++++++++++++++++ 1 file changed, 21 insertions(+) create mode 100644 test/CMakeLists.txt diff --git a/test/CMakeLists.txt b/test/CMakeLists.txt new file mode 100644 index 0000000..3f9ecf6 --- /dev/null +++ b/test/CMakeLists.txt @@ -0,0 +1,21 @@ +cmake_minimum_required(VERSION 3.10) + +set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -pthread") +set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wall -Wextra") + +set(CMAKE_BUILD_TYPE Release) +set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -O2") + +project(test_ad_graph) + +set(CMAKE_CXX_STANDARD 20) +set(CMAKE_CXX_STANDARD_REQUIRED ON) + +find_package(ZLIB REQUIRED) + +add_executable(test_ad_graph test_ad_graph.cpp) + +target_link_libraries(test_ad_graph + ZLIB::ZLIB + pthread +) From cc738d926a9f43110d98fc402b9d45735fb7b2e0 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Sat, 31 Jan 2026 16:25:02 +0000 Subject: [PATCH 6/8] Refactor: Separate parameter nodes from operator nodes in graph architecture Co-authored-by: SnifferCaptain <151126294+SnifferCaptain@users.noreply.github.com> --- include/ad/graph_node.hpp | 44 ++++++++++++++++++--- include/ad/graph_runtime.hpp | 5 ++- src/ad/graph_node.inl | 4 +- src/ad/graph_runtime.inl | 33 ++++++++++++++-- src/ad/graph_serialization.inl | 9 +++++ test/test_ad_graph.cpp | 71 +++++++++++++++++++--------------- 6 files changed, 121 insertions(+), 45 deletions(-) diff --git a/include/ad/graph_node.hpp b/include/ad/graph_node.hpp index f80aa68..1883371 100644 --- a/include/ad/graph_node.hpp +++ b/include/ad/graph_node.hpp @@ -18,7 +18,15 @@ namespace ad { class GraphEdge; class ComputationGraph; -/// @brief 计算图节点类,表示计算图中的一个算子 +/// @brief 节点类型枚举 +enum class NodeType { + Operator, // 算子节点(如add, mul, matmul等) + Parameter, // 参数节点(如权重、偏置) + Input, // 输入节点 + Constant // 常量节点 +}; + +/// @brief 计算图节点类,表示计算图中的一个节点(算子、参数、输入或常量) class GraphNode { public: /// @brief 默认构造函数 @@ -26,17 +34,22 @@ class GraphNode { /// @brief 构造一个计算图节点 /// @param nodeId 节点的唯一标识符 - /// @param opType 算子类型(如 "add", "mul", "matmul", "relu" 等) - GraphNode(const std::string& nodeId, const std::string& opType); + /// @param opType 算子类型(如 "add", "mul", "matmul", "relu" 等)或节点类型 + /// @param nodeType 节点类型(默认为Operator) + GraphNode(const std::string& nodeId, const std::string& opType, NodeType nodeType = NodeType::Operator); /// @brief 获取节点ID /// @return 节点的唯一标识符 std::string getId() const { return nodeId_; } - /// @brief 获取算子类型 + /// @brief 获取算子类型或节点类型描述 /// @return 算子类型字符串 std::string getOpType() const { return opType_; } + /// @brief 获取节点类型 + /// @return 节点类型 + NodeType getNodeType() const { return nodeType_; } + /// @brief 设置节点名称(可选,用于调试和可视化) /// @param name 节点名称 void setName(const std::string& name) { name_ = name; } @@ -91,13 +104,32 @@ class GraphNode { /// @brief 重置节点执行状态 void reset() { executed_ = false; } + /// @brief 设置节点数据(用于参数节点和常量节点) + /// @param tensor 张量数据 + void setData(const YTensorBase& tensor) { data_ = tensor; hasData_ = true; } + + /// @brief 获取节点数据 + /// @return 张量数据的引用 + YTensorBase& getData() { return data_; } + + /// @brief 获取节点数据(const版本) + /// @return 张量数据的const引用 + const YTensorBase& getData() const { return data_; } + + /// @brief 检查节点是否有数据 + /// @return 如果有数据返回true + bool hasData() const { return hasData_; } + private: std::string nodeId_; // 节点唯一标识符 - std::string opType_; // 算子类型 + std::string opType_; // 算子类型或节点类型描述 std::string name_; // 节点名称(可选) + NodeType nodeType_ = NodeType::Operator; // 节点类型 std::vector> inputEdges_; // 输入边列表 std::vector> outputEdges_; // 输出边列表 - std::unordered_map parameters_; // 节点参数 + std::unordered_map parameters_; // 节点参数(用于算子配置) + YTensorBase data_; // 节点数据(用于参数和常量节点) + bool hasData_ = false; // 是否有数据 bool executed_ = false; // 是否已执行 }; diff --git a/include/ad/graph_runtime.hpp b/include/ad/graph_runtime.hpp index 2aaae45..ec3c0be 100644 --- a/include/ad/graph_runtime.hpp +++ b/include/ad/graph_runtime.hpp @@ -36,9 +36,10 @@ class ComputationGraph { /// @brief 创建一个新节点 /// @param nodeId 节点ID(必须唯一) - /// @param opType 算子类型 + /// @param opType 算子类型或节点描述 + /// @param nodeType 节点类型(默认为Operator) /// @return 节点指针 - std::shared_ptr createNode(const std::string& nodeId, const std::string& opType); + std::shared_ptr createNode(const std::string& nodeId, const std::string& opType, NodeType nodeType = NodeType::Operator); /// @brief 创建一个新边 /// @param edgeId 边ID(必须唯一) diff --git a/src/ad/graph_node.inl b/src/ad/graph_node.inl index e4f6527..77ade47 100644 --- a/src/ad/graph_node.inl +++ b/src/ad/graph_node.inl @@ -6,8 +6,8 @@ namespace yt { namespace ad { -inline GraphNode::GraphNode(const std::string& nodeId, const std::string& opType) - : nodeId_(nodeId), opType_(opType), name_(nodeId) { +inline GraphNode::GraphNode(const std::string& nodeId, const std::string& opType, NodeType nodeType) + : nodeId_(nodeId), opType_(opType), name_(nodeId), nodeType_(nodeType) { } inline void GraphNode::addInputEdge(std::shared_ptr edge) { diff --git a/src/ad/graph_runtime.inl b/src/ad/graph_runtime.inl index 0aa9df3..2af6925 100644 --- a/src/ad/graph_runtime.inl +++ b/src/ad/graph_runtime.inl @@ -11,13 +11,13 @@ namespace yt { namespace ad { inline std::shared_ptr ComputationGraph::createNode( - const std::string& nodeId, const std::string& opType) { + const std::string& nodeId, const std::string& opType, NodeType nodeType) { if (nodes_.find(nodeId) != nodes_.end()) { throw std::runtime_error("Node with ID '" + nodeId + "' already exists"); } - auto node = std::make_shared(nodeId, opType); + auto node = std::make_shared(nodeId, opType, nodeType); nodes_[nodeId] = node; return node; } @@ -128,7 +128,34 @@ inline void ComputationGraph::executeNode(const std::string& nodeId) { return; // 节点已执行,跳过 } - // 获取算子执行器 + // 参数节点和常量节点不需要执行,直接标记为已执行 + if (node->getNodeType() == NodeType::Parameter || node->getNodeType() == NodeType::Constant) { + // 将节点数据传递到输出边 + if (node->hasData() && !node->getOutputEdges().empty()) { + for (auto& edge : node->getOutputEdges()) { + edge->setTensor(node->getData()); + } + } + node->setExecuted(true); + return; + } + + // 输入节点:从输入边读取数据,传递到输出边 + if (node->getNodeType() == NodeType::Input) { + auto inputEdges = node->getInputEdges(); + auto outputEdges = node->getOutputEdges(); + if (!inputEdges.empty() && !outputEdges.empty()) { + if (inputEdges[0]->hasTensor()) { + for (auto& edge : outputEdges) { + edge->setTensor(inputEdges[0]->getTensor()); + } + } + } + node->setExecuted(true); + return; + } + + // 算子节点:执行计算 auto it = operators_.find(node->getOpType()); if (it == operators_.end()) { throw std::runtime_error("Operator '" + node->getOpType() + "' not registered"); diff --git a/src/ad/graph_serialization.inl b/src/ad/graph_serialization.inl index 05a4831..9ae9b96 100644 --- a/src/ad/graph_serialization.inl +++ b/src/ad/graph_serialization.inl @@ -165,6 +165,15 @@ inline void GraphSerializer::serializeNode(const std::shared_ptr& nod builder.addKey("name"); builder.addString(node->getName()); + // 序列化节点类型 + builder.addKey("nodeType"); + switch(node->getNodeType()) { + case NodeType::Operator: builder.addString("operator"); break; + case NodeType::Parameter: builder.addString("parameter"); break; + case NodeType::Input: builder.addString("input"); break; + case NodeType::Constant: builder.addString("constant"); break; + } + // 序列化输入边ID列表 builder.addKey("inputEdges"); builder.beginArray(); diff --git a/test/test_ad_graph.cpp b/test/test_ad_graph.cpp index 83f1d54..73660fd 100644 --- a/test/test_ad_graph.cpp +++ b/test/test_ad_graph.cpp @@ -14,6 +14,7 @@ class TestFFNGraph { ComputationGraph graph; // 构建一个简单的 FFN 层计算图: y = gelu(x @ W1) @ W2 + // 现在 W1 和 W2 是参数节点,不是边 void build() { std::cout << "构建计算图..." << std::endl; @@ -21,28 +22,32 @@ class TestFFNGraph { registerOperators(); // 创建节点 - auto input_node = graph.createNode("input", "input"); - auto matmul1_node = graph.createNode("matmul1", "matmul"); - auto gelu_node = graph.createNode("gelu", "gelu"); - auto matmul2_node = graph.createNode("matmul2", "matmul"); - auto output_node = graph.createNode("output", "output"); + auto input_node = graph.createNode("input", "input", NodeType::Input); + auto w1_node = graph.createNode("w1", "parameter", NodeType::Parameter); + auto w2_node = graph.createNode("w2", "parameter", NodeType::Parameter); + auto matmul1_node = graph.createNode("matmul1", "matmul", NodeType::Operator); + auto gelu_node = graph.createNode("gelu", "gelu", NodeType::Operator); + auto matmul2_node = graph.createNode("matmul2", "matmul", NodeType::Operator); + auto output_node = graph.createNode("output", "output", NodeType::Operator); // 设置节点名称 - input_node->setName("输入层"); + input_node->setName("输入节点"); + w1_node->setName("权重参数W1"); + w2_node->setName("权重参数W2"); matmul1_node->setName("第一层矩阵乘法"); gelu_node->setName("GELU激活"); matmul2_node->setName("第二层矩阵乘法"); - output_node->setName("输出层"); + output_node->setName("输出节点"); - // 创建边连接 - auto x_edge = graph.createEdge("x", nullptr, input_node); // 输入数据 - auto x_out = graph.createEdge("x_out", input_node, matmul1_node); // 从input到matmul1 - auto w1_edge = graph.createEdge("w1", nullptr, matmul1_node); // 权重1 - auto h1_edge = graph.createEdge("h1", matmul1_node, gelu_node); // 中间结果1 - auto h2_edge = graph.createEdge("h2", gelu_node, matmul2_node); // 激活后结果 - auto w2_edge = graph.createEdge("w2", nullptr, matmul2_node); // 权重2 - auto y_edge = graph.createEdge("y", matmul2_node, output_node); // 输出 - auto out_edge = graph.createEdge("out", output_node, nullptr); // 最终输出 + // 创建边连接:现在 w1 和 w2 是节点,需要通过边连接到算子 + auto x_edge = graph.createEdge("x", nullptr, input_node); // 外部输入 -> input节点 + auto x_to_mm1 = graph.createEdge("x_to_mm1", input_node, matmul1_node); // input -> matmul1 + auto w1_to_mm1 = graph.createEdge("w1_to_mm1", w1_node, matmul1_node); // w1参数 -> matmul1 + auto h1_edge = graph.createEdge("h1", matmul1_node, gelu_node); // matmul1 -> gelu + auto h2_edge = graph.createEdge("h2", gelu_node, matmul2_node); // gelu -> matmul2 + auto w2_to_mm2 = graph.createEdge("w2_to_mm2", w2_node, matmul2_node); // w2参数 -> matmul2 + auto y_edge = graph.createEdge("y", matmul2_node, output_node); // matmul2 -> output + auto out_edge = graph.createEdge("out", output_node, nullptr); // output -> 外部输出 std::cout << "计算图构建完成!" << std::endl; std::cout << " 节点数: " << graph.nodeCount() << std::endl; @@ -50,15 +55,6 @@ class TestFFNGraph { } void registerOperators() { - // 输入算子:直接传递输入 - graph.registerOperator("input", [](const GraphNode& node, - const std::vector>& inputs, - std::vector>& outputs) { - if (!inputs.empty() && !outputs.empty()) { - outputs[0]->setTensor(inputs[0]->getTensor()); - } - }); - // 输出算子:直接传递输出 graph.registerOperator("output", [](const GraphNode& node, const std::vector>& inputs, @@ -122,8 +118,15 @@ class TestFFNGraph { std::cout << "\n节点列表:" << std::endl; for (const auto& [id, node] : graph.getNodes()) { + std::string nodeTypeStr; + switch(node->getNodeType()) { + case NodeType::Operator: nodeTypeStr = "算子"; break; + case NodeType::Parameter: nodeTypeStr = "参数"; break; + case NodeType::Input: nodeTypeStr = "输入"; break; + case NodeType::Constant: nodeTypeStr = "常量"; break; + } std::cout << " [" << id << "] " << node->getName() - << " (类型: " << node->getOpType() << ")" << std::endl; + << " (类型: " << nodeTypeStr << ", op: " << node->getOpType() << ")" << std::endl; std::cout << " 输入边: "; for (const auto& edge : node->getInputEdges()) { std::cout << edge->getId() << " "; @@ -142,13 +145,13 @@ class TestFFNGraph { if (edge->getFromNode()) { std::cout << edge->getFromNode()->getId(); } else { - std::cout << ""; + std::cout << ""; } std::cout << " -> "; if (edge->getToNode()) { std::cout << edge->getToNode()->getId(); } else { - std::cout << ""; + std::cout << ""; } std::cout << std::endl; } @@ -159,11 +162,15 @@ class TestFFNGraph { YTensorBase forward(const YTensorBase& x, const YTensorBase& w1, const YTensorBase& w2) { std::cout << "执行前向传播..." << std::endl; - // 准备输入 + // 设置参数节点的数据 + auto w1_node = graph.getNode("w1"); + auto w2_node = graph.getNode("w2"); + w1_node->setData(w1); + w2_node->setData(w2); + + // 准备输入(只有外部输入x) std::unordered_map inputs = { - {"x", x}, - {"w1", w1}, - {"w2", w2} + {"x", x} }; // 执行计算图 From ff5cea57aeb4e4f192e549a3e0d28dc0029c9705 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Sat, 31 Jan 2026 16:36:45 +0000 Subject: [PATCH 7/8] Add comprehensive FFN graph test with direct implementation comparison Co-authored-by: SnifferCaptain <151126294+SnifferCaptain@users.noreply.github.com> --- test/CMakeLists.txt | 12 +- test/test_ad_ffn.cpp | 299 +++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 306 insertions(+), 5 deletions(-) create mode 100644 test/test_ad_ffn.cpp diff --git a/test/CMakeLists.txt b/test/CMakeLists.txt index 3f9ecf6..d7c556f 100644 --- a/test/CMakeLists.txt +++ b/test/CMakeLists.txt @@ -6,16 +6,18 @@ set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wall -Wextra") set(CMAKE_BUILD_TYPE Release) set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -O2") -project(test_ad_graph) +project(ad_tests) set(CMAKE_CXX_STANDARD 20) set(CMAKE_CXX_STANDARD_REQUIRED ON) find_package(ZLIB REQUIRED) +# Test 1: Basic graph test add_executable(test_ad_graph test_ad_graph.cpp) +target_link_libraries(test_ad_graph ZLIB::ZLIB pthread) + +# Test 2: FFN comparison test +add_executable(test_ad_ffn test_ad_ffn.cpp) +target_link_libraries(test_ad_ffn ZLIB::ZLIB pthread) -target_link_libraries(test_ad_graph - ZLIB::ZLIB - pthread -) diff --git a/test/test_ad_ffn.cpp b/test/test_ad_ffn.cpp new file mode 100644 index 0000000..2943ac3 --- /dev/null +++ b/test/test_ad_ffn.cpp @@ -0,0 +1,299 @@ +#include +#include +#include +#include +#include +#include "../ytensor.hpp" +#include "../include/ad.hpp" + +using namespace yt; +using namespace yt::ad; + +// 简化的FFN块:与ymodel2的FFN类似 +// FFN: x -> linear(up) -> gelu -> linear(down) -> output + +struct FFNConfig { + int hidden_size = 64; + int intermediate_size = 128; +}; + +class SimplifiedFFNGraph { +public: + ComputationGraph graph; + FFNConfig config; + + void build() { + std::cout << "\n=== 构建简化FFN计算图 ===" << std::endl; + + // 注册算子 + registerOperators(); + + // 创建节点 + // 输入节点 + auto input_node = graph.createNode("input", "input", NodeType::Input); + input_node->setName("输入X"); + + // 参数节点 + auto up_weight = graph.createNode("up_weight", "parameter", NodeType::Parameter); + up_weight->setName("上投影权重"); + + auto down_weight = graph.createNode("down_weight", "parameter", NodeType::Parameter); + down_weight->setName("下投影权重"); + + // 算子节点 + auto linear1 = graph.createNode("linear1", "linear", NodeType::Operator); + linear1->setName("第一层线性变换"); + + auto gelu = graph.createNode("gelu", "gelu", NodeType::Operator); + gelu->setName("GELU激活"); + + auto linear2 = graph.createNode("linear2", "linear", NodeType::Operator); + linear2->setName("第二层线性变换"); + + auto output_node = graph.createNode("output", "output", NodeType::Operator); + output_node->setName("输出"); + + // 创建边连接 + graph.createEdge("x_input", nullptr, input_node); + graph.createEdge("x_to_l1", input_node, linear1); + graph.createEdge("up_to_l1", up_weight, linear1); + graph.createEdge("h1", linear1, gelu); + graph.createEdge("h2", gelu, linear2); + graph.createEdge("down_to_l2", down_weight, linear2); + graph.createEdge("y", linear2, output_node); + graph.createEdge("output", output_node, nullptr); + + std::cout << "图构建完成: " << graph.nodeCount() << " 节点, " + << graph.edgeCount() << " 边" << std::endl; + } + + void registerOperators() { + // 输出算子 + graph.registerOperator("output", [](const GraphNode&, + const std::vector>& inputs, + std::vector>& outputs) { + if (!inputs.empty() && !outputs.empty()) { + outputs[0]->setTensor(inputs[0]->getTensor()); + } + }); + + // 线性层算子 (matmul) + graph.registerOperator("linear", [](const GraphNode&, + const std::vector>& inputs, + std::vector>& outputs) { + if (inputs.size() < 2 || outputs.empty()) { + throw std::runtime_error("linear requires 2 inputs"); + } + + auto x = inputs[0]->getTensor(); + const auto& w = inputs[1]->getTensor(); + + // 支持3D输入 [batch, seq, hidden] + if (x.ndim() == 3) { + int b = x.shape(0), s = x.shape(1), h = x.shape(2); + x = x.view(b * s, h); + YTensorBase result = x.matmul(w); + int h_out = result.shape(1); + result = result.view(b, s, h_out); + outputs[0]->setTensor(result); + } else { + outputs[0]->setTensor(x.matmul(w)); + } + }); + + // GELU激活 + graph.registerOperator("gelu", [](const GraphNode&, + const std::vector>& inputs, + std::vector>& outputs) { + if (inputs.empty() || outputs.empty()) { + throw std::runtime_error("gelu requires 1 input"); + } + + const auto& x = inputs[0]->getTensor(); + YTensorBase result = x; + + // GELU: 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3))) + const float sqrt_2_over_pi = 0.7978845608f; + for (size_t i = 0; i < result.size(); i++) { + float val = result.atData(i); + float x3 = val * val * val; + float inner = sqrt_2_over_pi * (val + 0.044715f * x3); + float tanh_val = std::tanh(inner); + result.atData(i) = 0.5f * val * (1.0f + tanh_val); + } + + outputs[0]->setTensor(result); + }); + } + + void printGraph() { + std::cout << "\n========== 计算图结构 ==========" << std::endl; + std::cout << "节点总数: " << graph.nodeCount() << std::endl; + std::cout << "边总数: " << graph.edgeCount() << std::endl; + + std::cout << "\n节点列表:" << std::endl; + for (const auto& [id, node] : graph.getNodes()) { + std::string nodeTypeStr; + switch(node->getNodeType()) { + case NodeType::Operator: nodeTypeStr = "算子"; break; + case NodeType::Parameter: nodeTypeStr = "参数"; break; + case NodeType::Input: nodeTypeStr = "输入"; break; + case NodeType::Constant: nodeTypeStr = "常量"; break; + } + std::cout << " [" << id << "] " << node->getName() + << " (类型: " << nodeTypeStr << ")" << std::endl; + } + + std::cout << "\n拓扑排序:" << std::endl; + auto topo = graph.topologicalSort(); + for (size_t i = 0; i < topo.size(); i++) { + auto node = graph.getNode(topo[i]); + std::cout << " " << (i+1) << ". " << node->getName() << std::endl; + } + std::cout << "==============================\n" << std::endl; + } + + YTensorBase forward(const YTensorBase& x, const YTensorBase& up_w, const YTensorBase& down_w) { + // 设置参数 + graph.getNode("up_weight")->setData(up_w); + graph.getNode("down_weight")->setData(down_w); + + // 执行 + std::unordered_map inputs = {{"x_input", x}}; + auto outputs = graph.execute(inputs, {"output"}); + return outputs["output"]; + } + + std::string toJson() { + return GraphSerializer::toJson(graph); + } +}; + +// 直接实现的FFN(用于对比) +YTensorBase directFFN(const YTensorBase& x, const YTensorBase& up_w, const YTensorBase& down_w) { + // x @ up_w + auto x_copy = x; + if (x_copy.ndim() == 3) { + int b = x_copy.shape(0), s = x_copy.shape(1), h = x_copy.shape(2); + x_copy = x_copy.view(b * s, h); + auto h1 = x_copy.matmul(up_w); + int h_out = h1.shape(1); + h1 = h1.view(b, s, h_out); + + // GELU + const float sqrt_2_over_pi = 0.7978845608f; + for (size_t i = 0; i < h1.size(); i++) { + float val = h1.atData(i); + float x3 = val * val * val; + float inner = sqrt_2_over_pi * (val + 0.044715f * x3); + float tanh_val = std::tanh(inner); + h1.atData(i) = 0.5f * val * (1.0f + tanh_val); + } + + // h1 @ down_w + h1 = h1.view(b * s, h1.shape(2)); + auto output = h1.matmul(down_w); + output = output.view(b, s, output.shape(1)); + return output; + } + return x; +} + +int main() { + std::cout << "========================================" << std::endl; + std::cout << " YTensor 计算图完整测试" << std::endl; + std::cout << "========================================" << std::endl; + + try { + // 创建简化的FFN图 + SimplifiedFFNGraph ffn_graph; + ffn_graph.build(); + ffn_graph.printGraph(); + + // 准备测试数据 + std::cout << "\n=== 准备测试数据 ===" << std::endl; + int batch = 2, seq_len = 3, hidden = 64, intermediate = 128; + + YTensorBase x({batch, seq_len, hidden}, "float32"); + YTensorBase up_weight({hidden, intermediate}, "float32"); + YTensorBase down_weight({intermediate, hidden}, "float32"); + + // 使用固定种子初始化 + std::mt19937 gen(42); + std::normal_distribution dist(0.0f, 0.1f); + + for (size_t i = 0; i < x.size(); i++) { + x.atData(i) = dist(gen); + } + for (size_t i = 0; i < up_weight.size(); i++) { + up_weight.atData(i) = dist(gen); + } + for (size_t i = 0; i < down_weight.size(); i++) { + down_weight.atData(i) = dist(gen); + } + + std::cout << "输入形状: [" << batch << ", " << seq_len << ", " << hidden << "]" << std::endl; + std::cout << "上投影权重: [" << hidden << ", " << intermediate << "]" << std::endl; + std::cout << "下投影权重: [" << intermediate << ", " << hidden << "]" << std::endl; + + // 通过计算图执行 + std::cout << "\n=== 计算图前向传播 ===" << std::endl; + auto graph_output = ffn_graph.forward(x, up_weight, down_weight); + + std::cout << "输出形状: [" << graph_output.shape(0) << ", " + << graph_output.shape(1) << ", " << graph_output.shape(2) << "]" << std::endl; + std::cout << "输出数据(前5个元素): "; + for (int i = 0; i < 5 && i < (int)graph_output.size(); i++) { + std::cout << std::fixed << std::setprecision(6) << graph_output.atData(i) << " "; + } + std::cout << std::endl; + + // 直接实现对比 + std::cout << "\n=== 直接实现对比 ===" << std::endl; + auto direct_output = directFFN(x, up_weight, down_weight); + + std::cout << "输出形状: [" << direct_output.shape(0) << ", " + << direct_output.shape(1) << ", " << direct_output.shape(2) << "]" << std::endl; + std::cout << "输出数据(前5个元素): "; + for (int i = 0; i < 5 && i < (int)direct_output.size(); i++) { + std::cout << std::fixed << std::setprecision(6) << direct_output.atData(i) << " "; + } + std::cout << std::endl; + + // 验证结果 + std::cout << "\n=== 验证结果一致性 ===" << std::endl; + float max_diff = 0.0f; + for (size_t i = 0; i < graph_output.size(); i++) { + float diff = std::abs(graph_output.atData(i) - direct_output.atData(i)); + max_diff = std::max(max_diff, diff); + } + + std::cout << "最大差异: " << std::scientific << max_diff << std::endl; + if (max_diff < 1e-5f) { + std::cout << "✓ 结果一致!(误差 < 1e-5)" << std::endl; + } else { + std::cout << "✗ 结果不一致!" << std::endl; + } + + // JSON序列化 + std::cout << "\n=== JSON序列化 ===" << std::endl; + std::string json = ffn_graph.toJson(); + std::cout << "JSON长度: " << json.length() << " 字符" << std::endl; + std::cout << "保存到 /tmp/ffn_graph.json" << std::endl; + GraphSerializer::toJsonFile(ffn_graph.graph, "/tmp/ffn_graph.json"); + + // 显示JSON片段 + std::cout << "\nJSON片段(前300字符):" << std::endl; + std::cout << json.substr(0, 300) << "..." << std::endl; + + } catch (const std::exception& e) { + std::cerr << "✗ 错误: " << e.what() << std::endl; + return 1; + } + + std::cout << "\n========================================" << std::endl; + std::cout << " 所有测试完成" << std::endl; + std::cout << "========================================" << std::endl; + + return 0; +} From 2ab10ae824276dae16ac1981b172c064aadcc061 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Sat, 31 Jan 2026 16:46:59 +0000 Subject: [PATCH 8/8] Fix code style: match ytensor namespace formatting (no space before brace) Co-authored-by: SnifferCaptain <151126294+SnifferCaptain@users.noreply.github.com> --- include/ad/graph_edge.hpp | 4 ++-- include/ad/graph_node.hpp | 4 ++-- include/ad/graph_runtime.hpp | 4 ++-- include/ad/graph_serialization.hpp | 4 ++-- src/ad/graph_edge.inl | 4 ++-- src/ad/graph_node.inl | 4 ++-- src/ad/graph_runtime.inl | 4 ++-- src/ad/graph_serialization.inl | 4 ++-- 8 files changed, 16 insertions(+), 16 deletions(-) diff --git a/include/ad/graph_edge.hpp b/include/ad/graph_edge.hpp index b7a9ec5..ef06f61 100644 --- a/include/ad/graph_edge.hpp +++ b/include/ad/graph_edge.hpp @@ -9,8 +9,8 @@ #include #include "../ytensor_base.hpp" -namespace yt { -namespace ad { +namespace yt{ +namespace ad{ // 前向声明 class GraphNode; diff --git a/include/ad/graph_node.hpp b/include/ad/graph_node.hpp index 1883371..3f07ff1 100644 --- a/include/ad/graph_node.hpp +++ b/include/ad/graph_node.hpp @@ -11,8 +11,8 @@ #include #include -namespace yt { -namespace ad { +namespace yt{ +namespace ad{ // 前向声明 class GraphEdge; diff --git a/include/ad/graph_runtime.hpp b/include/ad/graph_runtime.hpp index ec3c0be..0b80894 100644 --- a/include/ad/graph_runtime.hpp +++ b/include/ad/graph_runtime.hpp @@ -16,8 +16,8 @@ #include "graph_edge.hpp" #include "../ytensor_base.hpp" -namespace yt { -namespace ad { +namespace yt{ +namespace ad{ // 算子执行函数类型定义 // 输入:节点指针,输入边列表,输出边列表 diff --git a/include/ad/graph_serialization.hpp b/include/ad/graph_serialization.hpp index ab7a362..d306be2 100644 --- a/include/ad/graph_serialization.hpp +++ b/include/ad/graph_serialization.hpp @@ -13,8 +13,8 @@ #include "graph_runtime.hpp" -namespace yt { -namespace ad { +namespace yt{ +namespace ad{ /// @brief 简单的JSON构建器类(用于序列化) class JsonBuilder { diff --git a/src/ad/graph_edge.inl b/src/ad/graph_edge.inl index e1c2db0..e0283dc 100644 --- a/src/ad/graph_edge.inl +++ b/src/ad/graph_edge.inl @@ -3,8 +3,8 @@ * @brief: 计算图边的实现 ***************/ -namespace yt { -namespace ad { +namespace yt{ +namespace ad{ inline GraphEdge::GraphEdge(const std::string& edgeId, std::shared_ptr fromNode, diff --git a/src/ad/graph_node.inl b/src/ad/graph_node.inl index 77ade47..b122ff7 100644 --- a/src/ad/graph_node.inl +++ b/src/ad/graph_node.inl @@ -3,8 +3,8 @@ * @brief: 计算图节点的实现 ***************/ -namespace yt { -namespace ad { +namespace yt{ +namespace ad{ inline GraphNode::GraphNode(const std::string& nodeId, const std::string& opType, NodeType nodeType) : nodeId_(nodeId), opType_(opType), name_(nodeId), nodeType_(nodeType) { diff --git a/src/ad/graph_runtime.inl b/src/ad/graph_runtime.inl index 2af6925..dc2924a 100644 --- a/src/ad/graph_runtime.inl +++ b/src/ad/graph_runtime.inl @@ -7,8 +7,8 @@ #include #include -namespace yt { -namespace ad { +namespace yt{ +namespace ad{ inline std::shared_ptr ComputationGraph::createNode( const std::string& nodeId, const std::string& opType, NodeType nodeType) { diff --git a/src/ad/graph_serialization.inl b/src/ad/graph_serialization.inl index 9ae9b96..a9e88f7 100644 --- a/src/ad/graph_serialization.inl +++ b/src/ad/graph_serialization.inl @@ -6,8 +6,8 @@ #include #include -namespace yt { -namespace ad { +namespace yt{ +namespace ad{ // JsonBuilder 实现 inline void JsonBuilder::beginObject() {