Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions src/amd_detail/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ set(HIPFILE_SOURCES
file-descriptor.cpp
hip.cpp
hipfile.cpp
hipfile-stats.cpp
io.cpp
mountinfo.cpp
state.cpp
Expand Down
95 changes: 95 additions & 0 deletions src/amd_detail/hipfile-stats.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
/* Copyright (c) Advanced Micro Devices, Inc. All rights reserved.
*
* SPDX-License-Identifier: MIT
*/

#include "include_internal/hipfile-stats.h"
#include "stats.h"

#include <sstream>
#include <unistd.h>

using namespace hipFile;

hipFileStatsError_t
hipFileStatsCreateContext(hipFileStatsContext_t **context, int targetPid)
{
if (!context) {
return hipFileStatsInvalidArgument;
}
if (targetPid <= 0) {
*context = nullptr;
return hipFileStatsInvalidArgument;
}
try {
*context = reinterpret_cast<hipFileStatsContext_t *>(new StatsClient{targetPid});
}
catch (...) {
*context = nullptr;
return hipFileStatsTargetProcessNotFound;
}
return hipFileStatsSuccess;
}

void
hipFileStatsCloseContext(hipFileStatsContext_t *context)
{
if (!context) {
return;
}
delete reinterpret_cast<IStatsClient *>(context);
}

hipFileStatsError_t
hipFileStatsConnectToTargetProcess(hipFileStatsContext_t *context)
{
if (!context) {
return hipFileStatsInvalidArgument;
}
IStatsClient *client{reinterpret_cast<IStatsClient *>(context)};
if (!client->connectServer()) {
return hipFileStatsTargetProcessNotAccessible;
}
return hipFileStatsSuccess;
}

hipFileStatsError_t
hipFileStatsPollTargetProcess(const hipFileStatsContext_t *context, bool block)
{
if (!context) {
return hipFileStatsInvalidArgument;
}
const IStatsClient *client{reinterpret_cast<const IStatsClient *>(context)};
if (!client->pollProcess(block ? -1 : 0)) {
return hipFileStatsTargetProcessNotAccessible;
}
return hipFileStatsSuccess;
}

hipFileStatsError_t
hipFileStatsGenerateReport(const hipFileStatsContext_t *context, int fd)
{
if (!context || fd < 0) {
return hipFileStatsInvalidArgument;
}
const IStatsClient *client{reinterpret_cast<const IStatsClient *>(context)};
std::ostringstream stream;
if (!client->generateReport(stream)) {
return hipFileStatsReportGenerationFailed;
}
std::string report{stream.str()};
const char *data{report.c_str()};
size_t total_size{report.size()};
size_t written{0};
while (written < total_size) {
ssize_t n{write(fd, data + written, total_size - written)};
if (n <= 0) {
if (errno == EINTR) {
continue;
}
return hipFileStatsReportGenerationFailed;
}
written += static_cast<size_t>(n);
}
return hipFileStatsSuccess;
}
83 changes: 83 additions & 0 deletions src/amd_detail/include_internal/hipfile-stats.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
/* Copyright (c) Advanced Micro Devices, Inc. All rights reserved.
*
* SPDX-License-Identifier: MIT
*/

#pragma once

#include "hipfile.h"

#ifdef __cplusplus
extern "C" {
#endif

/** @brief Opaque pointer to a stats collection context */
typedef struct hipFileStatsContext hipFileStatsContext_t;

/**
* @enum hipFileStatsError
* @brief Error codes returned by stats collection API functions
*/
typedef enum hipFileStatsError {
hipFileStatsSuccess, /**< Operation completed successfully */
hipFileStatsInvalidArgument, /**< Invalid argument passed to function */
hipFileStatsTargetProcessNotFound, /**< Target process with given PID not found */
hipFileStatsTargetProcessNotAccessible, /**< Cannot access target process */
hipFileStatsReportGenerationFailed, /**< Failed to generate or write report */
} hipFileStatsError_t;

/**
* @brief Create a new stats collection context for a target process
* @param[out] context Pointer to store the created context handle
* @param[in] targetPid Process ID of the target process to monitor
* @return #hipFileStatsSuccess on success, error code otherwise
*
* Creates a new statistics collection context for the specified target process.
* The returned context must be freed with hipFileStatsCloseContext().
*/
HIPFILE_API hipFileStatsError_t hipFileStatsCreateContext(hipFileStatsContext_t **context, int targetPid);

/**
* @brief Close and free a stats collection context
* @param[in] context Stats context handle to close (may be NULL)
*
* Closes the specified stats context and releases all associated resources.
* Safe to call with NULL pointer.
*/
HIPFILE_API void hipFileStatsCloseContext(hipFileStatsContext_t *context);

/**
* @brief Connect the stats context to its target process
* @param[in] context Stats context handle
* @return #hipFileStatsSuccess on success, error code otherwise
*
* Establishes connection to the target process for statistics collection.
* Must be called before hipFileStatsGenerateReport().
*/
HIPFILE_API hipFileStatsError_t hipFileStatsConnectToTargetProcess(hipFileStatsContext_t *context);

/**
* @brief Poll the target process for updated statistics
* @param[in] context Stats context handle
* @param[in] block Whether to block until the target process completes
* @return #hipFileStatsSuccess on success, error code otherwise
*
* Polls the target process for completion.
* If block is true, waits indefinitely.
*/
HIPFILE_API hipFileStatsError_t hipFileStatsPollTargetProcess(const hipFileStatsContext_t *context,
bool block);

/**
* @brief Generate a statistics report and write it to a file descriptor
* @param[in] context Stats context handle
* @param[in] fd File descriptor to write the report to
* @return #hipFileStatsSuccess on success, error code otherwise
*
* Generates a formatted statistics report from collected data and writes it to
* the specified file descriptor.
*/
HIPFILE_API hipFileStatsError_t hipFileStatsGenerateReport(const hipFileStatsContext_t *context, int fd);
#ifdef __cplusplus
}
#endif
4 changes: 2 additions & 2 deletions src/amd_detail/stats.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -167,7 +167,7 @@ StatsClient::StatsClient(pid_t p)
}

bool
StatsClient::pollProcess(int timeout)
StatsClient::pollProcess(int timeout) const
{
if (m_pfd.get() == -1) {
return true;
Expand Down Expand Up @@ -201,7 +201,7 @@ StatsClient::connectServer()
}

bool
StatsClient::generateReport(std::ostream &stream)
StatsClient::generateReport(std::ostream &stream) const
{
if (m_sfd.get() == -1) {
return false;
Expand Down
16 changes: 12 additions & 4 deletions src/amd_detail/stats.h
Original file line number Diff line number Diff line change
Expand Up @@ -75,12 +75,20 @@ class StatsServer {
std::thread m_thread;
};

class StatsClient {
class IStatsClient {
public:
virtual ~IStatsClient() = default;
virtual bool pollProcess(int timeout) const = 0;
virtual bool connectServer() = 0;
virtual bool generateReport(std::ostream &stream) const = 0;
};

class StatsClient : public IStatsClient {
public:
explicit StatsClient(pid_t p);
bool pollProcess(int timeout);
bool connectServer();
bool generateReport(std::ostream &stream);
bool pollProcess(int timeout) const override;
bool connectServer() override;
bool generateReport(std::ostream &stream) const override;

static void generateReportV1(std::ostream &stream, const Stats *stats);

Expand Down
1 change: 1 addition & 0 deletions test/amd_detail/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ set(TEST_SOURCE_FILES
handle.cpp
hip.cpp
hipfile-api.cpp
hipfile-stats.cpp
fallback.cpp
fastpath.cpp
main.cpp
Expand Down
114 changes: 114 additions & 0 deletions test/amd_detail/hipfile-stats.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,114 @@
/* Copyright (c) Advanced Micro Devices, Inc. All rights reserved.
*
* SPDX-License-Identifier: MIT
*/
#include "hipfile-test.h"
#include "include_internal/hipfile-stats.h"
#include "msys.h"
#include "stats.h"

#include <gmock/gmock.h>
#include <gtest/gtest.h>

using namespace hipFile;
using ::testing::StrictMock;

// Put tests inside the macros to suppress the global constructor
// warnings
HIPFILE_WARN_NO_GLOBAL_CTOR_OFF

class MStatsClient : public IStatsClient {
public:
MOCK_METHOD(bool, pollProcess, (int timeout), (const, override));
MOCK_METHOD(bool, connectServer, (), (override));
MOCK_METHOD(bool, generateReport, (std::ostream & stream), (const, override));
};

struct HipFileStatsApi : public HipFileUnopened {};

TEST_F(HipFileStatsApi, CreateContextInvalidArgument)
{
hipFileStatsContext_t *context;
EXPECT_EQ(hipFileStatsCreateContext(nullptr, 1234), hipFileStatsInvalidArgument);
EXPECT_EQ(hipFileStatsCreateContext(&context, -1), hipFileStatsInvalidArgument);
EXPECT_EQ(context, nullptr);
EXPECT_EQ(hipFileStatsCreateContext(&context, 0), hipFileStatsInvalidArgument);
EXPECT_EQ(context, nullptr);
}

TEST_F(HipFileStatsApi, CreateContextInvalidTarget)
{
StrictMock<MSys> msys{};
EXPECT_CALL(msys, pidfd_open)
.WillOnce(testing::Throw(std::system_error(EINVAL, std::generic_category())));
hipFileStatsContext_t *context;
EXPECT_EQ(hipFileStatsCreateContext(&context, 1234), hipFileStatsTargetProcessNotFound);
EXPECT_EQ(context, nullptr);
}

TEST_F(HipFileStatsApi, ConnectToTargetProcessInvalidArgument)
{
EXPECT_EQ(hipFileStatsConnectToTargetProcess(nullptr), hipFileStatsInvalidArgument);
}

TEST_F(HipFileStatsApi, ConnectToTargetProcessFailure)
{
StrictMock<MStatsClient> client{};
EXPECT_CALL(client, connectServer).WillOnce(testing::Return(false));
EXPECT_EQ(hipFileStatsConnectToTargetProcess(reinterpret_cast<hipFileStatsContext_t *>(&client)),
hipFileStatsTargetProcessNotAccessible);
}

TEST_F(HipFileStatsApi, ConnectToTargetProcessSuccess)
{
StrictMock<MStatsClient> client{};
EXPECT_CALL(client, connectServer).WillOnce(testing::Return(true));
EXPECT_EQ(hipFileStatsConnectToTargetProcess(reinterpret_cast<hipFileStatsContext_t *>(&client)),
hipFileStatsSuccess);
}

TEST_F(HipFileStatsApi, PollTargetProcessInvalidArgument)
{
EXPECT_EQ(hipFileStatsPollTargetProcess(nullptr, false), hipFileStatsInvalidArgument);
}

TEST_F(HipFileStatsApi, PollTargetProcessFailure)
{
StrictMock<MStatsClient> client{};
EXPECT_CALL(client, pollProcess).WillOnce(testing::Return(false));
EXPECT_EQ(hipFileStatsPollTargetProcess(reinterpret_cast<const hipFileStatsContext_t *>(&client), false),
hipFileStatsTargetProcessNotAccessible);
}

TEST_F(HipFileStatsApi, PollTargetProcessSuccess)
{
StrictMock<MStatsClient> client{};
EXPECT_CALL(client, pollProcess).WillOnce(testing::Return(true));
EXPECT_EQ(hipFileStatsPollTargetProcess(reinterpret_cast<const hipFileStatsContext_t *>(&client), false),
hipFileStatsSuccess);
}

TEST_F(HipFileStatsApi, GenerateReportInvalidArgument)
{
EXPECT_EQ(hipFileStatsGenerateReport(nullptr, 1), hipFileStatsInvalidArgument);
EXPECT_EQ(hipFileStatsGenerateReport(reinterpret_cast<const hipFileStatsContext_t *>(0x1234), -1),
hipFileStatsInvalidArgument);
}

TEST_F(HipFileStatsApi, GenerateReportSuccess)
{
StrictMock<MStatsClient> client{};
EXPECT_CALL(client, generateReport).WillOnce(testing::Return(true));
EXPECT_EQ(hipFileStatsGenerateReport(reinterpret_cast<const hipFileStatsContext_t *>(&client), 1),
hipFileStatsSuccess);
}

TEST_F(HipFileStatsApi, GenerateReportFailure)
{
StrictMock<MStatsClient> client{};
EXPECT_CALL(client, generateReport).WillOnce(testing::Return(false));
EXPECT_EQ(hipFileStatsGenerateReport(reinterpret_cast<const hipFileStatsContext_t *>(&client), 1),
hipFileStatsReportGenerationFailed);
}

HIPFILE_WARN_NO_GLOBAL_CTOR_ON
Loading
Loading