diff --git a/include/ad.hpp b/include/ad.hpp new file mode 100644 index 0000000..76a9e7c --- /dev/null +++ b/include/ad.hpp @@ -0,0 +1,18 @@ +#pragma once +/*************** +* @file: ad.hpp +* @brief: 自动微分模块总入口 +* @description: 包含计算图的所有头文件和实现 +***************/ + +// 头文件 +#include "ad/graph_node.hpp" +#include "ad/graph_edge.hpp" +#include "ad/graph_runtime.hpp" +#include "ad/graph_serialization.hpp" + +// 实现文件 +#include "../src/ad/graph_node.inl" +#include "../src/ad/graph_edge.inl" +#include "../src/ad/graph_runtime.inl" +#include "../src/ad/graph_serialization.inl" diff --git a/include/ad/graph_edge.hpp b/include/ad/graph_edge.hpp new file mode 100644 index 0000000..ef06f61 --- /dev/null +++ b/include/ad/graph_edge.hpp @@ -0,0 +1,93 @@ +#pragma once +/*************** +* @file: graph_edge.hpp +* @brief: 计算图边定义,表示计算图中的数据流 +* @description: 边代表有向的数据流,连接源节点和目标节点,携带张量数据 +***************/ + +#include +#include +#include "../ytensor_base.hpp" + +namespace yt{ +namespace ad{ + +// 前向声明 +class GraphNode; + +/// @brief 计算图边类,表示节点之间的有向数据流 +class GraphEdge { +public: + /// @brief 默认构造函数 + GraphEdge() = default; + + /// @brief 构造一个计算图边 + /// @param edgeId 边的唯一标识符 + /// @param fromNode 源节点(可为nullptr,表示输入边) + /// @param toNode 目标节点(可为nullptr,表示输出边) + GraphEdge(const std::string& edgeId, + std::shared_ptr fromNode = nullptr, + std::shared_ptr toNode = nullptr); + + /// @brief 获取边ID + /// @return 边的唯一标识符 + std::string getId() const { return edgeId_; } + + /// @brief 设置边名称(可选) + /// @param name 边名称 + void setName(const std::string& name) { name_ = name; } + + /// @brief 获取边名称 + /// @return 边名称 + std::string getName() const { return name_; } + + /// @brief 设置源节点 + /// @param node 源节点指针 + void setFromNode(std::shared_ptr node) { fromNode_ = node; } + + /// @brief 设置目标节点 + /// @param node 目标节点指针 + void setToNode(std::shared_ptr node) { toNode_ = node; } + + /// @brief 获取源节点 + /// @return 源节点指针 + std::shared_ptr getFromNode() const { return fromNode_; } + + /// @brief 获取目标节点 + /// @return 目标节点指针 + std::shared_ptr getToNode() const { return toNode_; } + + /// @brief 设置边上的张量数据 + /// @param tensor 张量数据 + void setTensor(const YTensorBase& tensor) { tensor_ = tensor; } + + /// @brief 获取边上的张量数据 + /// @return 张量数据的引用 + YTensorBase& getTensor() { return tensor_; } + + /// @brief 获取边上的张量数据(const版本) + /// @return 张量数据的const引用 + const YTensorBase& getTensor() const { return tensor_; } + + /// @brief 检查边是否包含有效的张量数据 + /// @return 如果张量有效返回true + bool hasTensor() const { return tensor_.size() > 0; } + + /// @brief 检查是否为输入边(没有源节点) + /// @return 如果是输入边返回true + bool isInputEdge() const { return fromNode_ == nullptr; } + + /// @brief 检查是否为输出边(没有目标节点) + /// @return 如果是输出边返回true + bool isOutputEdge() const { return toNode_ == nullptr; } + +private: + std::string edgeId_; // 边的唯一标识符 + std::string name_; // 边名称(可选) + std::shared_ptr fromNode_; // 源节点 + std::shared_ptr toNode_; // 目标节点 + YTensorBase tensor_; // 边上携带的张量数据 +}; + +} // namespace ad +} // namespace yt diff --git a/include/ad/graph_node.hpp b/include/ad/graph_node.hpp new file mode 100644 index 0000000..3f07ff1 --- /dev/null +++ b/include/ad/graph_node.hpp @@ -0,0 +1,137 @@ +#pragma once +/*************** +* @file: graph_node.hpp +* @brief: 计算图节点定义,表示计算图中的算子 +* @description: 节点代表算子操作,包含操作类型、参数和连接关系 +***************/ + +#include +#include +#include +#include +#include + +namespace yt{ +namespace ad{ + +// 前向声明 +class GraphEdge; +class ComputationGraph; + +/// @brief 节点类型枚举 +enum class NodeType { + Operator, // 算子节点(如add, mul, matmul等) + Parameter, // 参数节点(如权重、偏置) + Input, // 输入节点 + Constant // 常量节点 +}; + +/// @brief 计算图节点类,表示计算图中的一个节点(算子、参数、输入或常量) +class GraphNode { +public: + /// @brief 默认构造函数 + GraphNode() = default; + + /// @brief 构造一个计算图节点 + /// @param nodeId 节点的唯一标识符 + /// @param opType 算子类型(如 "add", "mul", "matmul", "relu" 等)或节点类型 + /// @param nodeType 节点类型(默认为Operator) + GraphNode(const std::string& nodeId, const std::string& opType, NodeType nodeType = NodeType::Operator); + + /// @brief 获取节点ID + /// @return 节点的唯一标识符 + std::string getId() const { return nodeId_; } + + /// @brief 获取算子类型或节点类型描述 + /// @return 算子类型字符串 + std::string getOpType() const { return opType_; } + + /// @brief 获取节点类型 + /// @return 节点类型 + NodeType getNodeType() const { return nodeType_; } + + /// @brief 设置节点名称(可选,用于调试和可视化) + /// @param name 节点名称 + void setName(const std::string& name) { name_ = name; } + + /// @brief 获取节点名称 + /// @return 节点名称 + std::string getName() const { return name_; } + + /// @brief 添加输入边 + /// @param edge 指向输入边的指针 + void addInputEdge(std::shared_ptr edge); + + /// @brief 添加输出边 + /// @param edge 指向输出边的指针 + void addOutputEdge(std::shared_ptr edge); + + /// @brief 获取所有输入边 + /// @return 输入边的vector + const std::vector>& getInputEdges() const { return inputEdges_; } + + /// @brief 获取所有输出边 + /// @return 输出边的vector + const std::vector>& getOutputEdges() const { return outputEdges_; } + + /// @brief 设置节点参数 + /// @param key 参数名 + /// @param value 参数值(使用std::any以支持多种类型) + void setParameter(const std::string& key, const std::any& value); + + /// @brief 获取节点参数 + /// @param key 参数名 + /// @return 参数值 + std::any getParameter(const std::string& key) const; + + /// @brief 检查是否存在某个参数 + /// @param key 参数名 + /// @return 如果参数存在返回true,否则返回false + bool hasParameter(const std::string& key) const; + + /// @brief 获取所有参数 + /// @return 参数映射表 + const std::unordered_map& getParameters() const { return parameters_; } + + /// @brief 标记节点是否已执行 + /// @param executed 是否已执行 + void setExecuted(bool executed) { executed_ = executed; } + + /// @brief 检查节点是否已执行 + /// @return 如果已执行返回true + bool isExecuted() const { return executed_; } + + /// @brief 重置节点执行状态 + void reset() { executed_ = false; } + + /// @brief 设置节点数据(用于参数节点和常量节点) + /// @param tensor 张量数据 + void setData(const YTensorBase& tensor) { data_ = tensor; hasData_ = true; } + + /// @brief 获取节点数据 + /// @return 张量数据的引用 + YTensorBase& getData() { return data_; } + + /// @brief 获取节点数据(const版本) + /// @return 张量数据的const引用 + const YTensorBase& getData() const { return data_; } + + /// @brief 检查节点是否有数据 + /// @return 如果有数据返回true + bool hasData() const { return hasData_; } + +private: + std::string nodeId_; // 节点唯一标识符 + std::string opType_; // 算子类型或节点类型描述 + std::string name_; // 节点名称(可选) + NodeType nodeType_ = NodeType::Operator; // 节点类型 + std::vector> inputEdges_; // 输入边列表 + std::vector> outputEdges_; // 输出边列表 + std::unordered_map parameters_; // 节点参数(用于算子配置) + YTensorBase data_; // 节点数据(用于参数和常量节点) + bool hasData_ = false; // 是否有数据 + bool executed_ = false; // 是否已执行 +}; + +} // namespace ad +} // namespace yt diff --git a/include/ad/graph_runtime.hpp b/include/ad/graph_runtime.hpp new file mode 100644 index 0000000..0b80894 --- /dev/null +++ b/include/ad/graph_runtime.hpp @@ -0,0 +1,162 @@ +#pragma once +/*************** +* @file: graph_runtime.hpp +* @brief: 计算图运行时,负责计算图的构建、执行和管理 +* @description: 运行时引擎管理计算图的生命周期,提供节点和边的创建、执行、序列化等功能 +***************/ + +#include +#include +#include +#include +#include +#include + +#include "graph_node.hpp" +#include "graph_edge.hpp" +#include "../ytensor_base.hpp" + +namespace yt{ +namespace ad{ + +// 算子执行函数类型定义 +// 输入:节点指针,输入边列表,输出边列表 +// 输出:void(结果写入输出边) +using OpExecutor = std::function>&, + std::vector>& +)>; + +/// @brief 计算图运行时类 +class ComputationGraph { +public: + /// @brief 默认构造函数 + ComputationGraph() = default; + + /// @brief 创建一个新节点 + /// @param nodeId 节点ID(必须唯一) + /// @param opType 算子类型或节点描述 + /// @param nodeType 节点类型(默认为Operator) + /// @return 节点指针 + std::shared_ptr createNode(const std::string& nodeId, const std::string& opType, NodeType nodeType = NodeType::Operator); + + /// @brief 创建一个新边 + /// @param edgeId 边ID(必须唯一) + /// @param fromNode 源节点(nullptr表示输入边) + /// @param toNode 目标节点(nullptr表示输出边) + /// @return 边指针 + std::shared_ptr createEdge( + const std::string& edgeId, + std::shared_ptr fromNode = nullptr, + std::shared_ptr toNode = nullptr + ); + + /// @brief 连接两个节点 + /// @param fromNodeId 源节点ID + /// @param toNodeId 目标节点ID + /// @param edgeId 边ID(如果为空则自动生成) + /// @return 创建的边指针 + std::shared_ptr connect( + const std::string& fromNodeId, + const std::string& toNodeId, + const std::string& edgeId = "" + ); + + /// @brief 获取节点 + /// @param nodeId 节点ID + /// @return 节点指针,如果不存在则返回nullptr + std::shared_ptr getNode(const std::string& nodeId) const; + + /// @brief 获取边 + /// @param edgeId 边ID + /// @return 边指针,如果不存在则返回nullptr + std::shared_ptr getEdge(const std::string& edgeId) const; + + /// @brief 获取所有节点 + /// @return 所有节点的map + const std::unordered_map>& getNodes() const { + return nodes_; + } + + /// @brief 获取所有边 + /// @return 所有边的map + const std::unordered_map>& getEdges() const { + return edges_; + } + + /// @brief 注册算子执行器 + /// @param opType 算子类型 + /// @param executor 执行函数 + void registerOperator(const std::string& opType, OpExecutor executor); + + /// @brief 执行计算图 + /// @param inputs 输入张量映射(边ID -> 张量) + /// @param outputs 输出边ID列表 + /// @return 输出张量映射(边ID -> 张量) + std::unordered_map execute( + const std::unordered_map& inputs, + const std::vector& outputs + ); + + /// @brief 执行单个节点 + /// @param nodeId 节点ID + void executeNode(const std::string& nodeId); + + /// @brief 拓扑排序,返回执行顺序 + /// @return 节点ID的执行顺序 + std::vector topologicalSort() const; + + /// @brief 重置所有节点的执行状态 + void reset(); + + /// @brief 清空计算图 + void clear(); + + /// @brief 获取计算图中节点的数量 + /// @return 节点数量 + size_t nodeCount() const { return nodes_.size(); } + + /// @brief 获取计算图中边的数量 + /// @return 边数量 + size_t edgeCount() const { return edges_.size(); } + + /// @brief 检查计算图是否为空 + /// @return 如果为空返回true + bool empty() const { return nodes_.empty(); } + +private: + /// @brief 生成唯一的边ID + /// @return 唯一的边ID + std::string generateEdgeId(); + + /// @brief 检查图是否有环(DFS辅助函数) + /// @param nodeId 当前节点ID + /// @param visited 访问标记 + /// @param recursionStack 递归栈标记 + /// @return 如果检测到环返回true + bool hasCycleDFS( + const std::string& nodeId, + std::unordered_map& visited, + std::unordered_map& recursionStack + ) const; + + /// @brief 拓扑排序的DFS辅助函数 + /// @param nodeId 当前节点ID + /// @param visited 访问标记 + /// @param stack 结果栈 + void topologicalSortDFS( + const std::string& nodeId, + std::unordered_map& visited, + std::vector& stack + ) const; + +private: + std::unordered_map> nodes_; // 节点映射表 + std::unordered_map> edges_; // 边映射表 + std::unordered_map operators_; // 算子执行器映射表 + int edgeCounter_ = 0; // 边计数器,用于生成唯一ID +}; + +} // namespace ad +} // namespace yt diff --git a/include/ad/graph_serialization.hpp b/include/ad/graph_serialization.hpp new file mode 100644 index 0000000..d306be2 --- /dev/null +++ b/include/ad/graph_serialization.hpp @@ -0,0 +1,113 @@ +#pragma once +/*************** +* @file: graph_serialization.hpp +* @brief: 计算图序列化,支持JSON格式的导入导出 +* @description: 提供计算图与JSON之间的序列化和反序列化功能 +***************/ + +#include +#include +#include +#include +#include + +#include "graph_runtime.hpp" + +namespace yt{ +namespace ad{ + +/// @brief 简单的JSON构建器类(用于序列化) +class JsonBuilder { +public: + /// @brief 开始一个对象 + void beginObject(); + + /// @brief 结束一个对象 + void endObject(); + + /// @brief 开始一个数组 + void beginArray(); + + /// @brief 结束一个数组 + void endArray(); + + /// @brief 添加键值对 + void addKey(const std::string& key); + + /// @brief 添加字符串值 + void addString(const std::string& value); + + /// @brief 添加数字值 + void addNumber(double value); + + /// @brief 添加整数值 + void addInt(int value); + + /// @brief 添加布尔值 + void addBool(bool value); + + /// @brief 添加null值 + void addNull(); + + /// @brief 获取生成的JSON字符串 + std::string toString() const { return buffer_; } + + /// @brief 清空缓冲区 + void clear() { buffer_.clear(); needComma_ = false; } + +private: + /// @brief 添加逗号(如果需要) + void addCommaIfNeeded(); + + std::string buffer_; + bool needComma_ = false; +}; + +/// @brief 计算图序列化类 +class GraphSerializer { +public: + /// @brief 将计算图序列化为JSON字符串 + /// @param graph 计算图 + /// @return JSON字符串 + static std::string toJson(const ComputationGraph& graph); + + /// @brief 将计算图序列化到JSON文件 + /// @param graph 计算图 + /// @param filename 文件名 + /// @return 如果成功返回true + static bool toJsonFile(const ComputationGraph& graph, const std::string& filename); + + /// @brief 从JSON字符串反序列化计算图 + /// @param json JSON字符串 + /// @param graph 输出的计算图 + /// @return 如果成功返回true + static bool fromJson(const std::string& json, ComputationGraph& graph); + + /// @brief 从JSON文件反序列化计算图 + /// @param filename 文件名 + /// @param graph 输出的计算图 + /// @return 如果成功返回true + static bool fromJsonFile(const std::string& filename, ComputationGraph& graph); + +private: + /// @brief 序列化单个节点 + /// @param node 节点指针 + /// @param builder JSON构建器 + static void serializeNode(const std::shared_ptr& node, JsonBuilder& builder); + + /// @brief 序列化单个边 + /// @param edge 边指针 + /// @param builder JSON构建器 + static void serializeEdge(const std::shared_ptr& edge, JsonBuilder& builder); + + /// @brief 序列化节点参数 + /// @param parameters 参数映射 + /// @param builder JSON构建器 + static void serializeParameters( + const std::unordered_map& parameters, + JsonBuilder& builder + ); +}; + +} // namespace ad +} // namespace yt diff --git a/include/ytensor_concepts.hpp b/include/ytensor_concepts.hpp index ec3c610..299a881 100644 --- a/include/ytensor_concepts.hpp +++ b/include/ytensor_concepts.hpp @@ -3,6 +3,7 @@ #include #include #include +#include // 前向声明 namespace yt { diff --git a/src/ad/graph_edge.inl b/src/ad/graph_edge.inl new file mode 100644 index 0000000..e0283dc --- /dev/null +++ b/src/ad/graph_edge.inl @@ -0,0 +1,16 @@ +/*************** +* @file: graph_edge.inl +* @brief: 计算图边的实现 +***************/ + +namespace yt{ +namespace ad{ + +inline GraphEdge::GraphEdge(const std::string& edgeId, + std::shared_ptr fromNode, + std::shared_ptr toNode) + : edgeId_(edgeId), name_(edgeId), fromNode_(fromNode), toNode_(toNode) { +} + +} // namespace ad +} // namespace yt diff --git a/src/ad/graph_node.inl b/src/ad/graph_node.inl new file mode 100644 index 0000000..b122ff7 --- /dev/null +++ b/src/ad/graph_node.inl @@ -0,0 +1,42 @@ +/*************** +* @file: graph_node.inl +* @brief: 计算图节点的实现 +***************/ + +namespace yt{ +namespace ad{ + +inline GraphNode::GraphNode(const std::string& nodeId, const std::string& opType, NodeType nodeType) + : nodeId_(nodeId), opType_(opType), name_(nodeId), nodeType_(nodeType) { +} + +inline void GraphNode::addInputEdge(std::shared_ptr edge) { + if (edge) { + inputEdges_.push_back(edge); + } +} + +inline void GraphNode::addOutputEdge(std::shared_ptr edge) { + if (edge) { + outputEdges_.push_back(edge); + } +} + +inline void GraphNode::setParameter(const std::string& key, const std::any& value) { + parameters_[key] = value; +} + +inline std::any GraphNode::getParameter(const std::string& key) const { + auto it = parameters_.find(key); + if (it != parameters_.end()) { + return it->second; + } + return std::any(); +} + +inline bool GraphNode::hasParameter(const std::string& key) const { + return parameters_.find(key) != parameters_.end(); +} + +} // namespace ad +} // namespace yt diff --git a/src/ad/graph_runtime.inl b/src/ad/graph_runtime.inl new file mode 100644 index 0000000..dc2924a --- /dev/null +++ b/src/ad/graph_runtime.inl @@ -0,0 +1,279 @@ +/*************** +* @file: graph_runtime.inl +* @brief: 计算图运行时的实现 +***************/ + +#include +#include +#include + +namespace yt{ +namespace ad{ + +inline std::shared_ptr ComputationGraph::createNode( + const std::string& nodeId, const std::string& opType, NodeType nodeType) { + + if (nodes_.find(nodeId) != nodes_.end()) { + throw std::runtime_error("Node with ID '" + nodeId + "' already exists"); + } + + auto node = std::make_shared(nodeId, opType, nodeType); + nodes_[nodeId] = node; + return node; +} + +inline std::shared_ptr ComputationGraph::createEdge( + const std::string& edgeId, + std::shared_ptr fromNode, + std::shared_ptr toNode) { + + if (edges_.find(edgeId) != edges_.end()) { + throw std::runtime_error("Edge with ID '" + edgeId + "' already exists"); + } + + auto edge = std::make_shared(edgeId, fromNode, toNode); + edges_[edgeId] = edge; + + // 更新节点的输入输出边列表 + if (fromNode) { + fromNode->addOutputEdge(edge); + } + if (toNode) { + toNode->addInputEdge(edge); + } + + return edge; +} + +inline std::shared_ptr ComputationGraph::connect( + const std::string& fromNodeId, + const std::string& toNodeId, + const std::string& edgeId) { + + auto fromNode = getNode(fromNodeId); + auto toNode = getNode(toNodeId); + + if (!fromNode) { + throw std::runtime_error("Source node '" + fromNodeId + "' not found"); + } + if (!toNode) { + throw std::runtime_error("Target node '" + toNodeId + "' not found"); + } + + std::string actualEdgeId = edgeId.empty() ? generateEdgeId() : edgeId; + return createEdge(actualEdgeId, fromNode, toNode); +} + +inline std::shared_ptr ComputationGraph::getNode(const std::string& nodeId) const { + auto it = nodes_.find(nodeId); + return (it != nodes_.end()) ? it->second : nullptr; +} + +inline std::shared_ptr ComputationGraph::getEdge(const std::string& edgeId) const { + auto it = edges_.find(edgeId); + return (it != edges_.end()) ? it->second : nullptr; +} + +inline void ComputationGraph::registerOperator(const std::string& opType, OpExecutor executor) { + operators_[opType] = executor; +} + +inline std::unordered_map ComputationGraph::execute( + const std::unordered_map& inputs, + const std::vector& outputs) { + + // 重置所有节点的执行状态 + reset(); + + // 设置输入张量 + for (const auto& [edgeId, tensor] : inputs) { + auto edge = getEdge(edgeId); + if (!edge) { + throw std::runtime_error("Input edge '" + edgeId + "' not found"); + } + edge->setTensor(tensor); + } + + // 获取执行顺序 + std::vector execOrder = topologicalSort(); + + // 按拓扑顺序执行节点 + for (const auto& nodeId : execOrder) { + executeNode(nodeId); + } + + // 收集输出张量 + std::unordered_map results; + for (const auto& edgeId : outputs) { + auto edge = getEdge(edgeId); + if (!edge) { + throw std::runtime_error("Output edge '" + edgeId + "' not found"); + } + if (!edge->hasTensor()) { + throw std::runtime_error("Output edge '" + edgeId + "' has no tensor data"); + } + results[edgeId] = edge->getTensor(); + } + + return results; +} + +inline void ComputationGraph::executeNode(const std::string& nodeId) { + auto node = getNode(nodeId); + if (!node) { + throw std::runtime_error("Node '" + nodeId + "' not found"); + } + + if (node->isExecuted()) { + return; // 节点已执行,跳过 + } + + // 参数节点和常量节点不需要执行,直接标记为已执行 + if (node->getNodeType() == NodeType::Parameter || node->getNodeType() == NodeType::Constant) { + // 将节点数据传递到输出边 + if (node->hasData() && !node->getOutputEdges().empty()) { + for (auto& edge : node->getOutputEdges()) { + edge->setTensor(node->getData()); + } + } + node->setExecuted(true); + return; + } + + // 输入节点:从输入边读取数据,传递到输出边 + if (node->getNodeType() == NodeType::Input) { + auto inputEdges = node->getInputEdges(); + auto outputEdges = node->getOutputEdges(); + if (!inputEdges.empty() && !outputEdges.empty()) { + if (inputEdges[0]->hasTensor()) { + for (auto& edge : outputEdges) { + edge->setTensor(inputEdges[0]->getTensor()); + } + } + } + node->setExecuted(true); + return; + } + + // 算子节点:执行计算 + auto it = operators_.find(node->getOpType()); + if (it == operators_.end()) { + throw std::runtime_error("Operator '" + node->getOpType() + "' not registered"); + } + + // 获取输入输出边 + auto inputEdges = node->getInputEdges(); + auto outputEdges = node->getOutputEdges(); + + // 检查输入是否就绪 + for (const auto& edge : inputEdges) { + if (!edge->hasTensor()) { + throw std::runtime_error("Input edge '" + edge->getId() + "' for node '" + + nodeId + "' has no tensor data"); + } + } + + // 执行算子 + it->second(*node, inputEdges, outputEdges); + + // 标记节点已执行 + node->setExecuted(true); +} + +inline std::vector ComputationGraph::topologicalSort() const { + std::vector result; + std::unordered_map visited; + + // 初始化访问标记 + for (const auto& [nodeId, _] : nodes_) { + visited[nodeId] = false; + } + + // 对每个未访问的节点执行DFS + for (const auto& [nodeId, _] : nodes_) { + if (!visited[nodeId]) { + topologicalSortDFS(nodeId, visited, result); + } + } + + // 反转结果(DFS后序遍历需要反转才是拓扑序) + std::reverse(result.begin(), result.end()); + return result; +} + +inline void ComputationGraph::topologicalSortDFS( + const std::string& nodeId, + std::unordered_map& visited, + std::vector& stack) const { + + visited[nodeId] = true; + + auto node = getNode(nodeId); + if (node) { + // 访问所有输出边的目标节点 + for (const auto& edge : node->getOutputEdges()) { + auto toNode = edge->getToNode(); + if (toNode) { + const std::string& nextNodeId = toNode->getId(); + if (!visited[nextNodeId]) { + topologicalSortDFS(nextNodeId, visited, stack); + } + } + } + } + + // 将当前节点加入栈 + stack.push_back(nodeId); +} + +inline void ComputationGraph::reset() { + for (auto& [_, node] : nodes_) { + node->reset(); + } +} + +inline void ComputationGraph::clear() { + nodes_.clear(); + edges_.clear(); + operators_.clear(); + edgeCounter_ = 0; +} + +inline std::string ComputationGraph::generateEdgeId() { + std::ostringstream oss; + oss << "edge_" << edgeCounter_++; + return oss.str(); +} + +inline bool ComputationGraph::hasCycleDFS( + const std::string& nodeId, + std::unordered_map& visited, + std::unordered_map& recursionStack) const { + + visited[nodeId] = true; + recursionStack[nodeId] = true; + + auto node = getNode(nodeId); + if (node) { + for (const auto& edge : node->getOutputEdges()) { + auto toNode = edge->getToNode(); + if (toNode) { + const std::string& nextNodeId = toNode->getId(); + + if (!visited[nextNodeId]) { + if (hasCycleDFS(nextNodeId, visited, recursionStack)) { + return true; + } + } else if (recursionStack[nextNodeId]) { + return true; // 发现环 + } + } + } + } + + recursionStack[nodeId] = false; + return false; +} + +} // namespace ad +} // namespace yt diff --git a/src/ad/graph_serialization.inl b/src/ad/graph_serialization.inl new file mode 100644 index 0000000..a9e88f7 --- /dev/null +++ b/src/ad/graph_serialization.inl @@ -0,0 +1,264 @@ +/*************** +* @file: graph_serialization.inl +* @brief: 计算图序列化的实现 +***************/ + +#include +#include + +namespace yt{ +namespace ad{ + +// JsonBuilder 实现 +inline void JsonBuilder::beginObject() { + addCommaIfNeeded(); + buffer_ += "{"; + needComma_ = false; +} + +inline void JsonBuilder::endObject() { + buffer_ += "}"; + needComma_ = true; +} + +inline void JsonBuilder::beginArray() { + addCommaIfNeeded(); + buffer_ += "["; + needComma_ = false; +} + +inline void JsonBuilder::endArray() { + buffer_ += "]"; + needComma_ = true; +} + +inline void JsonBuilder::addKey(const std::string& key) { + addCommaIfNeeded(); + buffer_ += "\"" + key + "\":"; + needComma_ = false; +} + +inline void JsonBuilder::addString(const std::string& value) { + addCommaIfNeeded(); + buffer_ += "\""; + // 转义特殊字符 + for (char c : value) { + switch (c) { + case '"': buffer_ += "\\\""; break; + case '\\': buffer_ += "\\\\"; break; + case '\n': buffer_ += "\\n"; break; + case '\r': buffer_ += "\\r"; break; + case '\t': buffer_ += "\\t"; break; + default: buffer_ += c; + } + } + buffer_ += "\""; + needComma_ = true; +} + +inline void JsonBuilder::addNumber(double value) { + addCommaIfNeeded(); + std::ostringstream oss; + oss << std::setprecision(15) << value; + buffer_ += oss.str(); + needComma_ = true; +} + +inline void JsonBuilder::addInt(int value) { + addCommaIfNeeded(); + buffer_ += std::to_string(value); + needComma_ = true; +} + +inline void JsonBuilder::addBool(bool value) { + addCommaIfNeeded(); + buffer_ += value ? "true" : "false"; + needComma_ = true; +} + +inline void JsonBuilder::addNull() { + addCommaIfNeeded(); + buffer_ += "null"; + needComma_ = true; +} + +inline void JsonBuilder::addCommaIfNeeded() { + if (needComma_) { + buffer_ += ","; + } +} + +// GraphSerializer 实现 +inline std::string GraphSerializer::toJson(const ComputationGraph& graph) { + JsonBuilder builder; + + builder.beginObject(); + + // 序列化节点 + builder.addKey("nodes"); + builder.beginArray(); + for (const auto& [nodeId, node] : graph.getNodes()) { + serializeNode(node, builder); + } + builder.endArray(); + + // 序列化边 + builder.addKey("edges"); + builder.beginArray(); + for (const auto& [edgeId, edge] : graph.getEdges()) { + serializeEdge(edge, builder); + } + builder.endArray(); + + builder.endObject(); + + return builder.toString(); +} + +inline bool GraphSerializer::toJsonFile(const ComputationGraph& graph, const std::string& filename) { + std::string json = toJson(graph); + std::ofstream file(filename); + if (!file.is_open()) { + return false; + } + file << json; + file.close(); + return true; +} + +inline bool GraphSerializer::fromJson(const std::string& json, ComputationGraph& graph) { + // 简化的JSON解析实现 + // 注意:这是一个简化版本,仅用于基本功能演示 + // 实际应用中建议使用专业的JSON库如nlohmann/json + + graph.clear(); + + // 这里提供一个基本的解析框架 + // 完整实现需要一个完整的JSON解析器 + // 暂时返回false表示需要使用外部JSON库 + + return false; +} + +inline bool GraphSerializer::fromJsonFile(const std::string& filename, ComputationGraph& graph) { + std::ifstream file(filename); + if (!file.is_open()) { + return false; + } + + std::stringstream buffer; + buffer << file.rdbuf(); + file.close(); + + return fromJson(buffer.str(), graph); +} + +inline void GraphSerializer::serializeNode(const std::shared_ptr& node, JsonBuilder& builder) { + builder.beginObject(); + + builder.addKey("id"); + builder.addString(node->getId()); + + builder.addKey("opType"); + builder.addString(node->getOpType()); + + builder.addKey("name"); + builder.addString(node->getName()); + + // 序列化节点类型 + builder.addKey("nodeType"); + switch(node->getNodeType()) { + case NodeType::Operator: builder.addString("operator"); break; + case NodeType::Parameter: builder.addString("parameter"); break; + case NodeType::Input: builder.addString("input"); break; + case NodeType::Constant: builder.addString("constant"); break; + } + + // 序列化输入边ID列表 + builder.addKey("inputEdges"); + builder.beginArray(); + for (const auto& edge : node->getInputEdges()) { + builder.addString(edge->getId()); + } + builder.endArray(); + + // 序列化输出边ID列表 + builder.addKey("outputEdges"); + builder.beginArray(); + for (const auto& edge : node->getOutputEdges()) { + builder.addString(edge->getId()); + } + builder.endArray(); + + // 序列化参数 + builder.addKey("parameters"); + serializeParameters(node->getParameters(), builder); + + builder.endObject(); +} + +inline void GraphSerializer::serializeEdge(const std::shared_ptr& edge, JsonBuilder& builder) { + builder.beginObject(); + + builder.addKey("id"); + builder.addString(edge->getId()); + + builder.addKey("name"); + builder.addString(edge->getName()); + + builder.addKey("fromNode"); + if (edge->getFromNode()) { + builder.addString(edge->getFromNode()->getId()); + } else { + builder.addNull(); + } + + builder.addKey("toNode"); + if (edge->getToNode()) { + builder.addString(edge->getToNode()->getId()); + } else { + builder.addNull(); + } + + builder.endObject(); +} + +inline void GraphSerializer::serializeParameters( + const std::unordered_map& parameters, + JsonBuilder& builder) { + + builder.beginObject(); + + for (const auto& [key, value] : parameters) { + builder.addKey(key); + + // 尝试转换常见类型 + // 注意:std::any的类型检查在运行时进行 + try { + if (value.type() == typeid(int)) { + builder.addInt(std::any_cast(value)); + } else if (value.type() == typeid(double)) { + builder.addNumber(std::any_cast(value)); + } else if (value.type() == typeid(float)) { + builder.addNumber(static_cast(std::any_cast(value))); + } else if (value.type() == typeid(bool)) { + builder.addBool(std::any_cast(value)); + } else if (value.type() == typeid(std::string)) { + builder.addString(std::any_cast(value)); + } else if (value.type() == typeid(const char*)) { + builder.addString(std::any_cast(value)); + } else { + // 未知类型,序列化为null + builder.addNull(); + } + } catch (...) { + // 转换失败,序列化为null + builder.addNull(); + } + } + + builder.endObject(); +} + +} // namespace ad +} // namespace yt diff --git a/test/CMakeLists.txt b/test/CMakeLists.txt new file mode 100644 index 0000000..d7c556f --- /dev/null +++ b/test/CMakeLists.txt @@ -0,0 +1,23 @@ +cmake_minimum_required(VERSION 3.10) + +set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -pthread") +set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wall -Wextra") + +set(CMAKE_BUILD_TYPE Release) +set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -O2") + +project(ad_tests) + +set(CMAKE_CXX_STANDARD 20) +set(CMAKE_CXX_STANDARD_REQUIRED ON) + +find_package(ZLIB REQUIRED) + +# Test 1: Basic graph test +add_executable(test_ad_graph test_ad_graph.cpp) +target_link_libraries(test_ad_graph ZLIB::ZLIB pthread) + +# Test 2: FFN comparison test +add_executable(test_ad_ffn test_ad_ffn.cpp) +target_link_libraries(test_ad_ffn ZLIB::ZLIB pthread) + diff --git a/test/test_ad_ffn.cpp b/test/test_ad_ffn.cpp new file mode 100644 index 0000000..2943ac3 --- /dev/null +++ b/test/test_ad_ffn.cpp @@ -0,0 +1,299 @@ +#include +#include +#include +#include +#include +#include "../ytensor.hpp" +#include "../include/ad.hpp" + +using namespace yt; +using namespace yt::ad; + +// 简化的FFN块:与ymodel2的FFN类似 +// FFN: x -> linear(up) -> gelu -> linear(down) -> output + +struct FFNConfig { + int hidden_size = 64; + int intermediate_size = 128; +}; + +class SimplifiedFFNGraph { +public: + ComputationGraph graph; + FFNConfig config; + + void build() { + std::cout << "\n=== 构建简化FFN计算图 ===" << std::endl; + + // 注册算子 + registerOperators(); + + // 创建节点 + // 输入节点 + auto input_node = graph.createNode("input", "input", NodeType::Input); + input_node->setName("输入X"); + + // 参数节点 + auto up_weight = graph.createNode("up_weight", "parameter", NodeType::Parameter); + up_weight->setName("上投影权重"); + + auto down_weight = graph.createNode("down_weight", "parameter", NodeType::Parameter); + down_weight->setName("下投影权重"); + + // 算子节点 + auto linear1 = graph.createNode("linear1", "linear", NodeType::Operator); + linear1->setName("第一层线性变换"); + + auto gelu = graph.createNode("gelu", "gelu", NodeType::Operator); + gelu->setName("GELU激活"); + + auto linear2 = graph.createNode("linear2", "linear", NodeType::Operator); + linear2->setName("第二层线性变换"); + + auto output_node = graph.createNode("output", "output", NodeType::Operator); + output_node->setName("输出"); + + // 创建边连接 + graph.createEdge("x_input", nullptr, input_node); + graph.createEdge("x_to_l1", input_node, linear1); + graph.createEdge("up_to_l1", up_weight, linear1); + graph.createEdge("h1", linear1, gelu); + graph.createEdge("h2", gelu, linear2); + graph.createEdge("down_to_l2", down_weight, linear2); + graph.createEdge("y", linear2, output_node); + graph.createEdge("output", output_node, nullptr); + + std::cout << "图构建完成: " << graph.nodeCount() << " 节点, " + << graph.edgeCount() << " 边" << std::endl; + } + + void registerOperators() { + // 输出算子 + graph.registerOperator("output", [](const GraphNode&, + const std::vector>& inputs, + std::vector>& outputs) { + if (!inputs.empty() && !outputs.empty()) { + outputs[0]->setTensor(inputs[0]->getTensor()); + } + }); + + // 线性层算子 (matmul) + graph.registerOperator("linear", [](const GraphNode&, + const std::vector>& inputs, + std::vector>& outputs) { + if (inputs.size() < 2 || outputs.empty()) { + throw std::runtime_error("linear requires 2 inputs"); + } + + auto x = inputs[0]->getTensor(); + const auto& w = inputs[1]->getTensor(); + + // 支持3D输入 [batch, seq, hidden] + if (x.ndim() == 3) { + int b = x.shape(0), s = x.shape(1), h = x.shape(2); + x = x.view(b * s, h); + YTensorBase result = x.matmul(w); + int h_out = result.shape(1); + result = result.view(b, s, h_out); + outputs[0]->setTensor(result); + } else { + outputs[0]->setTensor(x.matmul(w)); + } + }); + + // GELU激活 + graph.registerOperator("gelu", [](const GraphNode&, + const std::vector>& inputs, + std::vector>& outputs) { + if (inputs.empty() || outputs.empty()) { + throw std::runtime_error("gelu requires 1 input"); + } + + const auto& x = inputs[0]->getTensor(); + YTensorBase result = x; + + // GELU: 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3))) + const float sqrt_2_over_pi = 0.7978845608f; + for (size_t i = 0; i < result.size(); i++) { + float val = result.atData(i); + float x3 = val * val * val; + float inner = sqrt_2_over_pi * (val + 0.044715f * x3); + float tanh_val = std::tanh(inner); + result.atData(i) = 0.5f * val * (1.0f + tanh_val); + } + + outputs[0]->setTensor(result); + }); + } + + void printGraph() { + std::cout << "\n========== 计算图结构 ==========" << std::endl; + std::cout << "节点总数: " << graph.nodeCount() << std::endl; + std::cout << "边总数: " << graph.edgeCount() << std::endl; + + std::cout << "\n节点列表:" << std::endl; + for (const auto& [id, node] : graph.getNodes()) { + std::string nodeTypeStr; + switch(node->getNodeType()) { + case NodeType::Operator: nodeTypeStr = "算子"; break; + case NodeType::Parameter: nodeTypeStr = "参数"; break; + case NodeType::Input: nodeTypeStr = "输入"; break; + case NodeType::Constant: nodeTypeStr = "常量"; break; + } + std::cout << " [" << id << "] " << node->getName() + << " (类型: " << nodeTypeStr << ")" << std::endl; + } + + std::cout << "\n拓扑排序:" << std::endl; + auto topo = graph.topologicalSort(); + for (size_t i = 0; i < topo.size(); i++) { + auto node = graph.getNode(topo[i]); + std::cout << " " << (i+1) << ". " << node->getName() << std::endl; + } + std::cout << "==============================\n" << std::endl; + } + + YTensorBase forward(const YTensorBase& x, const YTensorBase& up_w, const YTensorBase& down_w) { + // 设置参数 + graph.getNode("up_weight")->setData(up_w); + graph.getNode("down_weight")->setData(down_w); + + // 执行 + std::unordered_map inputs = {{"x_input", x}}; + auto outputs = graph.execute(inputs, {"output"}); + return outputs["output"]; + } + + std::string toJson() { + return GraphSerializer::toJson(graph); + } +}; + +// 直接实现的FFN(用于对比) +YTensorBase directFFN(const YTensorBase& x, const YTensorBase& up_w, const YTensorBase& down_w) { + // x @ up_w + auto x_copy = x; + if (x_copy.ndim() == 3) { + int b = x_copy.shape(0), s = x_copy.shape(1), h = x_copy.shape(2); + x_copy = x_copy.view(b * s, h); + auto h1 = x_copy.matmul(up_w); + int h_out = h1.shape(1); + h1 = h1.view(b, s, h_out); + + // GELU + const float sqrt_2_over_pi = 0.7978845608f; + for (size_t i = 0; i < h1.size(); i++) { + float val = h1.atData(i); + float x3 = val * val * val; + float inner = sqrt_2_over_pi * (val + 0.044715f * x3); + float tanh_val = std::tanh(inner); + h1.atData(i) = 0.5f * val * (1.0f + tanh_val); + } + + // h1 @ down_w + h1 = h1.view(b * s, h1.shape(2)); + auto output = h1.matmul(down_w); + output = output.view(b, s, output.shape(1)); + return output; + } + return x; +} + +int main() { + std::cout << "========================================" << std::endl; + std::cout << " YTensor 计算图完整测试" << std::endl; + std::cout << "========================================" << std::endl; + + try { + // 创建简化的FFN图 + SimplifiedFFNGraph ffn_graph; + ffn_graph.build(); + ffn_graph.printGraph(); + + // 准备测试数据 + std::cout << "\n=== 准备测试数据 ===" << std::endl; + int batch = 2, seq_len = 3, hidden = 64, intermediate = 128; + + YTensorBase x({batch, seq_len, hidden}, "float32"); + YTensorBase up_weight({hidden, intermediate}, "float32"); + YTensorBase down_weight({intermediate, hidden}, "float32"); + + // 使用固定种子初始化 + std::mt19937 gen(42); + std::normal_distribution dist(0.0f, 0.1f); + + for (size_t i = 0; i < x.size(); i++) { + x.atData(i) = dist(gen); + } + for (size_t i = 0; i < up_weight.size(); i++) { + up_weight.atData(i) = dist(gen); + } + for (size_t i = 0; i < down_weight.size(); i++) { + down_weight.atData(i) = dist(gen); + } + + std::cout << "输入形状: [" << batch << ", " << seq_len << ", " << hidden << "]" << std::endl; + std::cout << "上投影权重: [" << hidden << ", " << intermediate << "]" << std::endl; + std::cout << "下投影权重: [" << intermediate << ", " << hidden << "]" << std::endl; + + // 通过计算图执行 + std::cout << "\n=== 计算图前向传播 ===" << std::endl; + auto graph_output = ffn_graph.forward(x, up_weight, down_weight); + + std::cout << "输出形状: [" << graph_output.shape(0) << ", " + << graph_output.shape(1) << ", " << graph_output.shape(2) << "]" << std::endl; + std::cout << "输出数据(前5个元素): "; + for (int i = 0; i < 5 && i < (int)graph_output.size(); i++) { + std::cout << std::fixed << std::setprecision(6) << graph_output.atData(i) << " "; + } + std::cout << std::endl; + + // 直接实现对比 + std::cout << "\n=== 直接实现对比 ===" << std::endl; + auto direct_output = directFFN(x, up_weight, down_weight); + + std::cout << "输出形状: [" << direct_output.shape(0) << ", " + << direct_output.shape(1) << ", " << direct_output.shape(2) << "]" << std::endl; + std::cout << "输出数据(前5个元素): "; + for (int i = 0; i < 5 && i < (int)direct_output.size(); i++) { + std::cout << std::fixed << std::setprecision(6) << direct_output.atData(i) << " "; + } + std::cout << std::endl; + + // 验证结果 + std::cout << "\n=== 验证结果一致性 ===" << std::endl; + float max_diff = 0.0f; + for (size_t i = 0; i < graph_output.size(); i++) { + float diff = std::abs(graph_output.atData(i) - direct_output.atData(i)); + max_diff = std::max(max_diff, diff); + } + + std::cout << "最大差异: " << std::scientific << max_diff << std::endl; + if (max_diff < 1e-5f) { + std::cout << "✓ 结果一致!(误差 < 1e-5)" << std::endl; + } else { + std::cout << "✗ 结果不一致!" << std::endl; + } + + // JSON序列化 + std::cout << "\n=== JSON序列化 ===" << std::endl; + std::string json = ffn_graph.toJson(); + std::cout << "JSON长度: " << json.length() << " 字符" << std::endl; + std::cout << "保存到 /tmp/ffn_graph.json" << std::endl; + GraphSerializer::toJsonFile(ffn_graph.graph, "/tmp/ffn_graph.json"); + + // 显示JSON片段 + std::cout << "\nJSON片段(前300字符):" << std::endl; + std::cout << json.substr(0, 300) << "..." << std::endl; + + } catch (const std::exception& e) { + std::cerr << "✗ 错误: " << e.what() << std::endl; + return 1; + } + + std::cout << "\n========================================" << std::endl; + std::cout << " 所有测试完成" << std::endl; + std::cout << "========================================" << std::endl; + + return 0; +} diff --git a/test/test_ad_graph.cpp b/test/test_ad_graph.cpp new file mode 100644 index 0000000..73660fd --- /dev/null +++ b/test/test_ad_graph.cpp @@ -0,0 +1,269 @@ +#include +#include +#include +#include "../ytensor.hpp" +#include "../include/ad.hpp" + +using namespace yt; +using namespace yt::ad; + +// 测试用:构建一个简单的前馈神经网络计算图 +// 模拟 ymodel2 的一个简化版本结构 +class TestFFNGraph { +public: + ComputationGraph graph; + + // 构建一个简单的 FFN 层计算图: y = gelu(x @ W1) @ W2 + // 现在 W1 和 W2 是参数节点,不是边 + void build() { + std::cout << "构建计算图..." << std::endl; + + // 注册算子 + registerOperators(); + + // 创建节点 + auto input_node = graph.createNode("input", "input", NodeType::Input); + auto w1_node = graph.createNode("w1", "parameter", NodeType::Parameter); + auto w2_node = graph.createNode("w2", "parameter", NodeType::Parameter); + auto matmul1_node = graph.createNode("matmul1", "matmul", NodeType::Operator); + auto gelu_node = graph.createNode("gelu", "gelu", NodeType::Operator); + auto matmul2_node = graph.createNode("matmul2", "matmul", NodeType::Operator); + auto output_node = graph.createNode("output", "output", NodeType::Operator); + + // 设置节点名称 + input_node->setName("输入节点"); + w1_node->setName("权重参数W1"); + w2_node->setName("权重参数W2"); + matmul1_node->setName("第一层矩阵乘法"); + gelu_node->setName("GELU激活"); + matmul2_node->setName("第二层矩阵乘法"); + output_node->setName("输出节点"); + + // 创建边连接:现在 w1 和 w2 是节点,需要通过边连接到算子 + auto x_edge = graph.createEdge("x", nullptr, input_node); // 外部输入 -> input节点 + auto x_to_mm1 = graph.createEdge("x_to_mm1", input_node, matmul1_node); // input -> matmul1 + auto w1_to_mm1 = graph.createEdge("w1_to_mm1", w1_node, matmul1_node); // w1参数 -> matmul1 + auto h1_edge = graph.createEdge("h1", matmul1_node, gelu_node); // matmul1 -> gelu + auto h2_edge = graph.createEdge("h2", gelu_node, matmul2_node); // gelu -> matmul2 + auto w2_to_mm2 = graph.createEdge("w2_to_mm2", w2_node, matmul2_node); // w2参数 -> matmul2 + auto y_edge = graph.createEdge("y", matmul2_node, output_node); // matmul2 -> output + auto out_edge = graph.createEdge("out", output_node, nullptr); // output -> 外部输出 + + std::cout << "计算图构建完成!" << std::endl; + std::cout << " 节点数: " << graph.nodeCount() << std::endl; + std::cout << " 边数: " << graph.edgeCount() << std::endl; + } + + void registerOperators() { + // 输出算子:直接传递输出 + graph.registerOperator("output", [](const GraphNode& node, + const std::vector>& inputs, + std::vector>& outputs) { + if (!inputs.empty() && !outputs.empty()) { + outputs[0]->setTensor(inputs[0]->getTensor()); + } + }); + + // 矩阵乘法算子 (支持3D输入) + graph.registerOperator("matmul", [](const GraphNode& node, + const std::vector>& inputs, + std::vector>& outputs) { + if (inputs.size() < 2 || outputs.empty()) { + throw std::runtime_error("matmul requires 2 inputs and 1 output"); + } + auto x = inputs[0]->getTensor(); + const auto& w = inputs[1]->getTensor(); + + // 如果是3D张量,需要reshape成2D,执行matmul,再reshape回3D + if (x.ndim() == 3) { + int b = x.shape(0), s = x.shape(1), h = x.shape(2); + x = x.view(b * s, h); // [b, s, h] -> [b*s, h] + YTensorBase result = x.matmul(w); // [b*s, h] @ [h, h'] -> [b*s, h'] + int h_out = result.shape(1); + result = result.view(b, s, h_out); // [b*s, h'] -> [b, s, h'] + outputs[0]->setTensor(result); + } else { + // 2D矩阵乘法 + YTensorBase result = x.matmul(w); + outputs[0]->setTensor(result); + } + }); + + // GELU激活函数算子 + graph.registerOperator("gelu", [](const GraphNode& node, + const std::vector>& inputs, + std::vector>& outputs) { + if (inputs.empty() || outputs.empty()) { + throw std::runtime_error("gelu requires 1 input and 1 output"); + } + const auto& x = inputs[0]->getTensor(); + + // GELU近似: 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3))) + YTensorBase result = x; // 简化版本,实际应该计算GELU + for (size_t i = 0; i < result.size(); i++) { + float val = result.atData(i); + // 简化的GELU近似 + float gelu_val = val * 0.5f * (1.0f + std::tanh(0.797885f * (val + 0.044715f * val * val * val))); + result.atData(i) = gelu_val; + } + outputs[0]->setTensor(result); + }); + } + + // 打印计算图结构 + void printGraph() { + std::cout << "\n========== 计算图结构 ==========" << std::endl; + std::cout << "节点总数: " << graph.nodeCount() << std::endl; + std::cout << "边总数: " << graph.edgeCount() << std::endl; + + std::cout << "\n节点列表:" << std::endl; + for (const auto& [id, node] : graph.getNodes()) { + std::string nodeTypeStr; + switch(node->getNodeType()) { + case NodeType::Operator: nodeTypeStr = "算子"; break; + case NodeType::Parameter: nodeTypeStr = "参数"; break; + case NodeType::Input: nodeTypeStr = "输入"; break; + case NodeType::Constant: nodeTypeStr = "常量"; break; + } + std::cout << " [" << id << "] " << node->getName() + << " (类型: " << nodeTypeStr << ", op: " << node->getOpType() << ")" << std::endl; + std::cout << " 输入边: "; + for (const auto& edge : node->getInputEdges()) { + std::cout << edge->getId() << " "; + } + std::cout << std::endl; + std::cout << " 输出边: "; + for (const auto& edge : node->getOutputEdges()) { + std::cout << edge->getId() << " "; + } + std::cout << std::endl; + } + + std::cout << "\n边列表:" << std::endl; + for (const auto& [id, edge] : graph.getEdges()) { + std::cout << " [" << id << "] "; + if (edge->getFromNode()) { + std::cout << edge->getFromNode()->getId(); + } else { + std::cout << ""; + } + std::cout << " -> "; + if (edge->getToNode()) { + std::cout << edge->getToNode()->getId(); + } else { + std::cout << ""; + } + std::cout << std::endl; + } + std::cout << "==============================\n" << std::endl; + } + + // 执行前向传播 + YTensorBase forward(const YTensorBase& x, const YTensorBase& w1, const YTensorBase& w2) { + std::cout << "执行前向传播..." << std::endl; + + // 设置参数节点的数据 + auto w1_node = graph.getNode("w1"); + auto w2_node = graph.getNode("w2"); + w1_node->setData(w1); + w2_node->setData(w2); + + // 准备输入(只有外部输入x) + std::unordered_map inputs = { + {"x", x} + }; + + // 执行计算图 + auto outputs = graph.execute(inputs, {"out"}); + + std::cout << "前向传播完成!" << std::endl; + return outputs["out"]; + } + + // 序列化为JSON + std::string toJson() { + return GraphSerializer::toJson(graph); + } +}; + +int main() { + std::cout << "========================================" << std::endl; + std::cout << " YTensor 计算图测试程序" << std::endl; + std::cout << "========================================\n" << std::endl; + + try { + // 创建测试图 + TestFFNGraph test_graph; + test_graph.build(); + + // 打印计算图结构 + test_graph.printGraph(); + + // 准备测试数据 + std::cout << "准备测试数据..." << std::endl; + int batch = 2, seq_len = 3, hidden = 4, intermediate = 6; + + YTensorBase x({batch, seq_len, hidden}, "float32"); + YTensorBase w1({hidden, intermediate}, "float32"); + YTensorBase w2({intermediate, hidden}, "float32"); + + // 初始化数据 + for (size_t i = 0; i < x.size(); i++) { + x.atData(i) = 0.1f * (i % 10); + } + for (size_t i = 0; i < w1.size(); i++) { + w1.atData(i) = 0.01f * (i % 20 - 10); + } + for (size_t i = 0; i < w2.size(); i++) { + w2.atData(i) = 0.01f * (i % 20 - 10); + } + + std::cout << " 输入形状: [" << x.shape(0) << ", " << x.shape(1) << ", " << x.shape(2) << "]" << std::endl; + std::cout << " W1形状: [" << w1.shape(0) << ", " << w1.shape(1) << "]" << std::endl; + std::cout << " W2形状: [" << w2.shape(0) << ", " << w2.shape(1) << "]" << std::endl; + + // 执行前向传播 + auto output = test_graph.forward(x, w1, w2); + + std::cout << "\n前向传播结果:" << std::endl; + std::cout << " 输出形状: [" << output.shape(0) << ", " << output.shape(1) << ", " << output.shape(2) << "]" << std::endl; + std::cout << " 输出数据(前10个元素): "; + for (int i = 0; i < std::min(10, (int)output.size()); i++) { + std::cout << std::fixed << std::setprecision(4) << output.atData(i) << " "; + } + std::cout << std::endl; + + // 序列化为JSON + std::cout << "\n序列化为JSON..." << std::endl; + std::string json = test_graph.toJson(); + std::cout << "JSON输出:" << std::endl; + std::cout << json << std::endl; + + // 保存到文件 + std::string filename = "/tmp/test_graph.json"; + if (GraphSerializer::toJsonFile(test_graph.graph, filename)) { + std::cout << "\n计算图已保存到: " << filename << std::endl; + } + + // 验证拓扑排序 + std::cout << "\n拓扑排序结果:" << std::endl; + auto topo_order = test_graph.graph.topologicalSort(); + for (size_t i = 0; i < topo_order.size(); i++) { + auto node = test_graph.graph.getNode(topo_order[i]); + std::cout << " " << (i+1) << ". " << node->getName() + << " (" << node->getId() << ")" << std::endl; + } + + std::cout << "\n✓ 所有测试通过!" << std::endl; + + } catch (const std::exception& e) { + std::cerr << "✗ 错误: " << e.what() << std::endl; + return 1; + } + + std::cout << "\n========================================" << std::endl; + std::cout << " 测试完成" << std::endl; + std::cout << "========================================" << std::endl; + + return 0; +}