Skip to content
18 changes: 18 additions & 0 deletions include/ad.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
#pragma once
/***************
* @file: ad.hpp
* @brief: 自动微分模块总入口
* @description: 包含计算图的所有头文件和实现
***************/

// 头文件
#include "ad/graph_node.hpp"
#include "ad/graph_edge.hpp"
#include "ad/graph_runtime.hpp"
#include "ad/graph_serialization.hpp"

// 实现文件
#include "../src/ad/graph_node.inl"
#include "../src/ad/graph_edge.inl"
#include "../src/ad/graph_runtime.inl"
#include "../src/ad/graph_serialization.inl"
93 changes: 93 additions & 0 deletions include/ad/graph_edge.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
#pragma once
/***************
* @file: graph_edge.hpp
* @brief: 计算图边定义,表示计算图中的数据流
* @description: 边代表有向的数据流,连接源节点和目标节点,携带张量数据
***************/

#include <string>
#include <memory>
#include "../ytensor_base.hpp"

namespace yt{
namespace ad{

// 前向声明
class GraphNode;

/// @brief 计算图边类,表示节点之间的有向数据流
class GraphEdge {
public:
/// @brief 默认构造函数
GraphEdge() = default;

/// @brief 构造一个计算图边
/// @param edgeId 边的唯一标识符
/// @param fromNode 源节点(可为nullptr,表示输入边)
/// @param toNode 目标节点(可为nullptr,表示输出边)
GraphEdge(const std::string& edgeId,
std::shared_ptr<GraphNode> fromNode = nullptr,
std::shared_ptr<GraphNode> toNode = nullptr);

/// @brief 获取边ID
/// @return 边的唯一标识符
std::string getId() const { return edgeId_; }

/// @brief 设置边名称(可选)
/// @param name 边名称
void setName(const std::string& name) { name_ = name; }

/// @brief 获取边名称
/// @return 边名称
std::string getName() const { return name_; }

/// @brief 设置源节点
/// @param node 源节点指针
void setFromNode(std::shared_ptr<GraphNode> node) { fromNode_ = node; }

/// @brief 设置目标节点
/// @param node 目标节点指针
void setToNode(std::shared_ptr<GraphNode> node) { toNode_ = node; }

/// @brief 获取源节点
/// @return 源节点指针
std::shared_ptr<GraphNode> getFromNode() const { return fromNode_; }

/// @brief 获取目标节点
/// @return 目标节点指针
std::shared_ptr<GraphNode> getToNode() const { return toNode_; }

/// @brief 设置边上的张量数据
/// @param tensor 张量数据
void setTensor(const YTensorBase& tensor) { tensor_ = tensor; }

/// @brief 获取边上的张量数据
/// @return 张量数据的引用
YTensorBase& getTensor() { return tensor_; }

/// @brief 获取边上的张量数据(const版本)
/// @return 张量数据的const引用
const YTensorBase& getTensor() const { return tensor_; }

/// @brief 检查边是否包含有效的张量数据
/// @return 如果张量有效返回true
bool hasTensor() const { return tensor_.size() > 0; }

/// @brief 检查是否为输入边(没有源节点)
/// @return 如果是输入边返回true
bool isInputEdge() const { return fromNode_ == nullptr; }

/// @brief 检查是否为输出边(没有目标节点)
/// @return 如果是输出边返回true
bool isOutputEdge() const { return toNode_ == nullptr; }

private:
std::string edgeId_; // 边的唯一标识符
std::string name_; // 边名称(可选)
std::shared_ptr<GraphNode> fromNode_; // 源节点
std::shared_ptr<GraphNode> toNode_; // 目标节点
YTensorBase tensor_; // 边上携带的张量数据
};

} // namespace ad
} // namespace yt
137 changes: 137 additions & 0 deletions include/ad/graph_node.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,137 @@
#pragma once
/***************
* @file: graph_node.hpp
* @brief: 计算图节点定义,表示计算图中的算子
* @description: 节点代表算子操作,包含操作类型、参数和连接关系
***************/

#include <string>
#include <vector>
#include <memory>
#include <unordered_map>
#include <any>

namespace yt{
namespace ad{

// 前向声明
class GraphEdge;
class ComputationGraph;

/// @brief 节点类型枚举
enum class NodeType {
Operator, // 算子节点(如add, mul, matmul等)
Parameter, // 参数节点(如权重、偏置)
Input, // 输入节点
Constant // 常量节点
};

/// @brief 计算图节点类,表示计算图中的一个节点(算子、参数、输入或常量)
class GraphNode {
public:
/// @brief 默认构造函数
GraphNode() = default;

/// @brief 构造一个计算图节点
/// @param nodeId 节点的唯一标识符
/// @param opType 算子类型(如 "add", "mul", "matmul", "relu" 等)或节点类型
/// @param nodeType 节点类型(默认为Operator)
GraphNode(const std::string& nodeId, const std::string& opType, NodeType nodeType = NodeType::Operator);

/// @brief 获取节点ID
/// @return 节点的唯一标识符
std::string getId() const { return nodeId_; }

/// @brief 获取算子类型或节点类型描述
/// @return 算子类型字符串
std::string getOpType() const { return opType_; }

/// @brief 获取节点类型
/// @return 节点类型
NodeType getNodeType() const { return nodeType_; }

/// @brief 设置节点名称(可选,用于调试和可视化)
/// @param name 节点名称
void setName(const std::string& name) { name_ = name; }

/// @brief 获取节点名称
/// @return 节点名称
std::string getName() const { return name_; }

/// @brief 添加输入边
/// @param edge 指向输入边的指针
void addInputEdge(std::shared_ptr<GraphEdge> edge);

/// @brief 添加输出边
/// @param edge 指向输出边的指针
void addOutputEdge(std::shared_ptr<GraphEdge> edge);

/// @brief 获取所有输入边
/// @return 输入边的vector
const std::vector<std::shared_ptr<GraphEdge>>& getInputEdges() const { return inputEdges_; }

/// @brief 获取所有输出边
/// @return 输出边的vector
const std::vector<std::shared_ptr<GraphEdge>>& getOutputEdges() const { return outputEdges_; }

/// @brief 设置节点参数
/// @param key 参数名
/// @param value 参数值(使用std::any以支持多种类型)
void setParameter(const std::string& key, const std::any& value);

/// @brief 获取节点参数
/// @param key 参数名
/// @return 参数值
std::any getParameter(const std::string& key) const;

/// @brief 检查是否存在某个参数
/// @param key 参数名
/// @return 如果参数存在返回true,否则返回false
bool hasParameter(const std::string& key) const;

/// @brief 获取所有参数
/// @return 参数映射表
const std::unordered_map<std::string, std::any>& getParameters() const { return parameters_; }

/// @brief 标记节点是否已执行
/// @param executed 是否已执行
void setExecuted(bool executed) { executed_ = executed; }

/// @brief 检查节点是否已执行
/// @return 如果已执行返回true
bool isExecuted() const { return executed_; }

/// @brief 重置节点执行状态
void reset() { executed_ = false; }

/// @brief 设置节点数据(用于参数节点和常量节点)
/// @param tensor 张量数据
void setData(const YTensorBase& tensor) { data_ = tensor; hasData_ = true; }

/// @brief 获取节点数据
/// @return 张量数据的引用
YTensorBase& getData() { return data_; }

/// @brief 获取节点数据(const版本)
/// @return 张量数据的const引用
const YTensorBase& getData() const { return data_; }

/// @brief 检查节点是否有数据
/// @return 如果有数据返回true
bool hasData() const { return hasData_; }

private:
std::string nodeId_; // 节点唯一标识符
std::string opType_; // 算子类型或节点类型描述
std::string name_; // 节点名称(可选)
NodeType nodeType_ = NodeType::Operator; // 节点类型
std::vector<std::shared_ptr<GraphEdge>> inputEdges_; // 输入边列表
std::vector<std::shared_ptr<GraphEdge>> outputEdges_; // 输出边列表
std::unordered_map<std::string, std::any> parameters_; // 节点参数(用于算子配置)
YTensorBase data_; // 节点数据(用于参数和常量节点)
bool hasData_ = false; // 是否有数据
bool executed_ = false; // 是否已执行
};

} // namespace ad
} // namespace yt
Loading