diff --git a/src/amd_detail/CMakeLists.txt b/src/amd_detail/CMakeLists.txt index c264f630..143d078e 100644 --- a/src/amd_detail/CMakeLists.txt +++ b/src/amd_detail/CMakeLists.txt @@ -20,6 +20,7 @@ set(HIPFILE_SOURCES file-descriptor.cpp hip.cpp hipfile.cpp + hipfile-stats.cpp io.cpp mountinfo.cpp state.cpp diff --git a/src/amd_detail/hipfile-stats.cpp b/src/amd_detail/hipfile-stats.cpp new file mode 100644 index 00000000..a82bc061 --- /dev/null +++ b/src/amd_detail/hipfile-stats.cpp @@ -0,0 +1,95 @@ +/* Copyright (c) Advanced Micro Devices, Inc. All rights reserved. + * + * SPDX-License-Identifier: MIT + */ + +#include "include_internal/hipfile-stats.h" +#include "stats.h" + +#include +#include + +using namespace hipFile; + +hipFileStatsError_t +hipFileStatsCreateContext(hipFileStatsContext_t **context, int targetPid) +{ + if (!context) { + return hipFileStatsInvalidArgument; + } + if (targetPid <= 0) { + *context = nullptr; + return hipFileStatsInvalidArgument; + } + try { + *context = reinterpret_cast(new StatsClient{targetPid}); + } + catch (...) { + *context = nullptr; + return hipFileStatsTargetProcessNotFound; + } + return hipFileStatsSuccess; +} + +void +hipFileStatsCloseContext(hipFileStatsContext_t *context) +{ + if (!context) { + return; + } + delete reinterpret_cast(context); +} + +hipFileStatsError_t +hipFileStatsConnectToTargetProcess(hipFileStatsContext_t *context) +{ + if (!context) { + return hipFileStatsInvalidArgument; + } + IStatsClient *client{reinterpret_cast(context)}; + if (!client->connectServer()) { + return hipFileStatsTargetProcessNotAccessible; + } + return hipFileStatsSuccess; +} + +hipFileStatsError_t +hipFileStatsPollTargetProcess(const hipFileStatsContext_t *context, bool block) +{ + if (!context) { + return hipFileStatsInvalidArgument; + } + const IStatsClient *client{reinterpret_cast(context)}; + if (!client->pollProcess(block ? -1 : 0)) { + return hipFileStatsTargetProcessNotAccessible; + } + return hipFileStatsSuccess; +} + +hipFileStatsError_t +hipFileStatsGenerateReport(const hipFileStatsContext_t *context, int fd) +{ + if (!context || fd < 0) { + return hipFileStatsInvalidArgument; + } + const IStatsClient *client{reinterpret_cast(context)}; + std::ostringstream stream; + if (!client->generateReport(stream)) { + return hipFileStatsReportGenerationFailed; + } + std::string report{stream.str()}; + const char *data{report.c_str()}; + size_t total_size{report.size()}; + size_t written{0}; + while (written < total_size) { + ssize_t n{write(fd, data + written, total_size - written)}; + if (n <= 0) { + if (errno == EINTR) { + continue; + } + return hipFileStatsReportGenerationFailed; + } + written += static_cast(n); + } + return hipFileStatsSuccess; +} diff --git a/src/amd_detail/include_internal/hipfile-stats.h b/src/amd_detail/include_internal/hipfile-stats.h new file mode 100644 index 00000000..3efeb919 --- /dev/null +++ b/src/amd_detail/include_internal/hipfile-stats.h @@ -0,0 +1,83 @@ +/* Copyright (c) Advanced Micro Devices, Inc. All rights reserved. + * + * SPDX-License-Identifier: MIT + */ + +#pragma once + +#include "hipfile.h" + +#ifdef __cplusplus +extern "C" { +#endif + +/** @brief Opaque pointer to a stats collection context */ +typedef struct hipFileStatsContext hipFileStatsContext_t; + +/** + * @enum hipFileStatsError + * @brief Error codes returned by stats collection API functions + */ +typedef enum hipFileStatsError { + hipFileStatsSuccess, /**< Operation completed successfully */ + hipFileStatsInvalidArgument, /**< Invalid argument passed to function */ + hipFileStatsTargetProcessNotFound, /**< Target process with given PID not found */ + hipFileStatsTargetProcessNotAccessible, /**< Cannot access target process */ + hipFileStatsReportGenerationFailed, /**< Failed to generate or write report */ +} hipFileStatsError_t; + +/** + * @brief Create a new stats collection context for a target process + * @param[out] context Pointer to store the created context handle + * @param[in] targetPid Process ID of the target process to monitor + * @return #hipFileStatsSuccess on success, error code otherwise + * + * Creates a new statistics collection context for the specified target process. + * The returned context must be freed with hipFileStatsCloseContext(). + */ +HIPFILE_API hipFileStatsError_t hipFileStatsCreateContext(hipFileStatsContext_t **context, int targetPid); + +/** + * @brief Close and free a stats collection context + * @param[in] context Stats context handle to close (may be NULL) + * + * Closes the specified stats context and releases all associated resources. + * Safe to call with NULL pointer. + */ +HIPFILE_API void hipFileStatsCloseContext(hipFileStatsContext_t *context); + +/** + * @brief Connect the stats context to its target process + * @param[in] context Stats context handle + * @return #hipFileStatsSuccess on success, error code otherwise + * + * Establishes connection to the target process for statistics collection. + * Must be called before hipFileStatsGenerateReport(). + */ +HIPFILE_API hipFileStatsError_t hipFileStatsConnectToTargetProcess(hipFileStatsContext_t *context); + +/** + * @brief Poll the target process for updated statistics + * @param[in] context Stats context handle + * @param[in] block Whether to block until the target process completes + * @return #hipFileStatsSuccess on success, error code otherwise + * + * Polls the target process for completion. + * If block is true, waits indefinitely. + */ +HIPFILE_API hipFileStatsError_t hipFileStatsPollTargetProcess(const hipFileStatsContext_t *context, + bool block); + +/** + * @brief Generate a statistics report and write it to a file descriptor + * @param[in] context Stats context handle + * @param[in] fd File descriptor to write the report to + * @return #hipFileStatsSuccess on success, error code otherwise + * + * Generates a formatted statistics report from collected data and writes it to + * the specified file descriptor. + */ +HIPFILE_API hipFileStatsError_t hipFileStatsGenerateReport(const hipFileStatsContext_t *context, int fd); +#ifdef __cplusplus +} +#endif diff --git a/src/amd_detail/stats.cpp b/src/amd_detail/stats.cpp index 7e792d34..214e1a30 100644 --- a/src/amd_detail/stats.cpp +++ b/src/amd_detail/stats.cpp @@ -167,7 +167,7 @@ StatsClient::StatsClient(pid_t p) } bool -StatsClient::pollProcess(int timeout) +StatsClient::pollProcess(int timeout) const { if (m_pfd.get() == -1) { return true; @@ -201,7 +201,7 @@ StatsClient::connectServer() } bool -StatsClient::generateReport(std::ostream &stream) +StatsClient::generateReport(std::ostream &stream) const { if (m_sfd.get() == -1) { return false; diff --git a/src/amd_detail/stats.h b/src/amd_detail/stats.h index 3351d2fc..c4a3bc94 100644 --- a/src/amd_detail/stats.h +++ b/src/amd_detail/stats.h @@ -75,12 +75,20 @@ class StatsServer { std::thread m_thread; }; -class StatsClient { +class IStatsClient { +public: + virtual ~IStatsClient() = default; + virtual bool pollProcess(int timeout) const = 0; + virtual bool connectServer() = 0; + virtual bool generateReport(std::ostream &stream) const = 0; +}; + +class StatsClient : public IStatsClient { public: explicit StatsClient(pid_t p); - bool pollProcess(int timeout); - bool connectServer(); - bool generateReport(std::ostream &stream); + bool pollProcess(int timeout) const override; + bool connectServer() override; + bool generateReport(std::ostream &stream) const override; static void generateReportV1(std::ostream &stream, const Stats *stats); diff --git a/test/amd_detail/CMakeLists.txt b/test/amd_detail/CMakeLists.txt index 96e871ad..b09f9a8e 100644 --- a/test/amd_detail/CMakeLists.txt +++ b/test/amd_detail/CMakeLists.txt @@ -22,6 +22,7 @@ set(TEST_SOURCE_FILES handle.cpp hip.cpp hipfile-api.cpp + hipfile-stats.cpp fallback.cpp fastpath.cpp main.cpp diff --git a/test/amd_detail/hipfile-stats.cpp b/test/amd_detail/hipfile-stats.cpp new file mode 100644 index 00000000..130ff627 --- /dev/null +++ b/test/amd_detail/hipfile-stats.cpp @@ -0,0 +1,114 @@ +/* Copyright (c) Advanced Micro Devices, Inc. All rights reserved. + * + * SPDX-License-Identifier: MIT + */ +#include "hipfile-test.h" +#include "include_internal/hipfile-stats.h" +#include "msys.h" +#include "stats.h" + +#include +#include + +using namespace hipFile; +using ::testing::StrictMock; + +// Put tests inside the macros to suppress the global constructor +// warnings +HIPFILE_WARN_NO_GLOBAL_CTOR_OFF + +class MStatsClient : public IStatsClient { +public: + MOCK_METHOD(bool, pollProcess, (int timeout), (const, override)); + MOCK_METHOD(bool, connectServer, (), (override)); + MOCK_METHOD(bool, generateReport, (std::ostream & stream), (const, override)); +}; + +struct HipFileStatsApi : public HipFileUnopened {}; + +TEST_F(HipFileStatsApi, CreateContextInvalidArgument) +{ + hipFileStatsContext_t *context; + EXPECT_EQ(hipFileStatsCreateContext(nullptr, 1234), hipFileStatsInvalidArgument); + EXPECT_EQ(hipFileStatsCreateContext(&context, -1), hipFileStatsInvalidArgument); + EXPECT_EQ(context, nullptr); + EXPECT_EQ(hipFileStatsCreateContext(&context, 0), hipFileStatsInvalidArgument); + EXPECT_EQ(context, nullptr); +} + +TEST_F(HipFileStatsApi, CreateContextInvalidTarget) +{ + StrictMock msys{}; + EXPECT_CALL(msys, pidfd_open) + .WillOnce(testing::Throw(std::system_error(EINVAL, std::generic_category()))); + hipFileStatsContext_t *context; + EXPECT_EQ(hipFileStatsCreateContext(&context, 1234), hipFileStatsTargetProcessNotFound); + EXPECT_EQ(context, nullptr); +} + +TEST_F(HipFileStatsApi, ConnectToTargetProcessInvalidArgument) +{ + EXPECT_EQ(hipFileStatsConnectToTargetProcess(nullptr), hipFileStatsInvalidArgument); +} + +TEST_F(HipFileStatsApi, ConnectToTargetProcessFailure) +{ + StrictMock client{}; + EXPECT_CALL(client, connectServer).WillOnce(testing::Return(false)); + EXPECT_EQ(hipFileStatsConnectToTargetProcess(reinterpret_cast(&client)), + hipFileStatsTargetProcessNotAccessible); +} + +TEST_F(HipFileStatsApi, ConnectToTargetProcessSuccess) +{ + StrictMock client{}; + EXPECT_CALL(client, connectServer).WillOnce(testing::Return(true)); + EXPECT_EQ(hipFileStatsConnectToTargetProcess(reinterpret_cast(&client)), + hipFileStatsSuccess); +} + +TEST_F(HipFileStatsApi, PollTargetProcessInvalidArgument) +{ + EXPECT_EQ(hipFileStatsPollTargetProcess(nullptr, false), hipFileStatsInvalidArgument); +} + +TEST_F(HipFileStatsApi, PollTargetProcessFailure) +{ + StrictMock client{}; + EXPECT_CALL(client, pollProcess).WillOnce(testing::Return(false)); + EXPECT_EQ(hipFileStatsPollTargetProcess(reinterpret_cast(&client), false), + hipFileStatsTargetProcessNotAccessible); +} + +TEST_F(HipFileStatsApi, PollTargetProcessSuccess) +{ + StrictMock client{}; + EXPECT_CALL(client, pollProcess).WillOnce(testing::Return(true)); + EXPECT_EQ(hipFileStatsPollTargetProcess(reinterpret_cast(&client), false), + hipFileStatsSuccess); +} + +TEST_F(HipFileStatsApi, GenerateReportInvalidArgument) +{ + EXPECT_EQ(hipFileStatsGenerateReport(nullptr, 1), hipFileStatsInvalidArgument); + EXPECT_EQ(hipFileStatsGenerateReport(reinterpret_cast(0x1234), -1), + hipFileStatsInvalidArgument); +} + +TEST_F(HipFileStatsApi, GenerateReportSuccess) +{ + StrictMock client{}; + EXPECT_CALL(client, generateReport).WillOnce(testing::Return(true)); + EXPECT_EQ(hipFileStatsGenerateReport(reinterpret_cast(&client), 1), + hipFileStatsSuccess); +} + +TEST_F(HipFileStatsApi, GenerateReportFailure) +{ + StrictMock client{}; + EXPECT_CALL(client, generateReport).WillOnce(testing::Return(false)); + EXPECT_EQ(hipFileStatsGenerateReport(reinterpret_cast(&client), 1), + hipFileStatsReportGenerationFailed); +} + +HIPFILE_WARN_NO_GLOBAL_CTOR_ON diff --git a/tools/ais-stats/ais-stats.cpp b/tools/ais-stats/ais-stats.cpp index 46543394..2b4863af 100644 --- a/tools/ais-stats/ais-stats.cpp +++ b/tools/ais-stats/ais-stats.cpp @@ -3,7 +3,7 @@ * SPDX-License-Identifier: MIT */ -#include "stats.h" +#include "include_internal/hipfile-stats.h" #include #include #include @@ -47,17 +47,37 @@ main(int argc, char *argv[]) return 1; } } - hipFile::StatsClient client{pid}; - if (!client.connectServer()) { - std::cerr << "Failed to collect info from target process.\n"; + + hipFileStatsContext_t *context; + if (hipFileStatsCreateContext(&context, pid) != hipFileStatsSuccess) { + std::cerr << "Failed to create context for target process.\n"; + return 1; + } + if (hipFileStatsConnectToTargetProcess(context) != hipFileStatsSuccess) { + std::cerr << "Failed to connect to target process.\n"; + hipFileStatsCloseContext(context); return 1; } if (!imm) { - client.pollProcess(-1); + if (hipFileStatsPollTargetProcess(context, true) != hipFileStatsSuccess) { + std::cerr << "Failed to poll target process.\n"; + hipFileStatsCloseContext(context); + return 1; + } + } + int fd{dup(STDOUT_FILENO)}; + if (fd < 0) { + std::cerr << "Failed to duplicate stdout file descriptor.\n"; + hipFileStatsCloseContext(context); + return 1; } - if (!client.generateReport(std::cout)) { - std::cerr << "No stats could be collected from target process.\n"; + if (hipFileStatsGenerateReport(context, fd) != hipFileStatsSuccess) { + std::cerr << "Failed to generate report from target process.\n"; + close(fd); + hipFileStatsCloseContext(context); return 1; } + close(fd); + hipFileStatsCloseContext(context); return 0; }