Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 7 additions & 2 deletions hipfile/src/amd_detail/batch/batch.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,7 @@ BatchContext::submit_operations(const hipFileIOParams_t *params, unsigned num_pa
throw std::invalid_argument(msg.str());
}

std::vector<std::shared_ptr<BatchOperation>> pending_ops{};
std::vector<std::shared_ptr<IBatchOperation>> pending_ops{};

// It would be more performant to be able to perform multiple lookups
// rather than waiting to lock the DriverState lock for each lookup.
Expand All @@ -124,7 +124,7 @@ BatchContext::submit_operations(const hipFileIOParams_t *params, unsigned num_pa
// file flags.
auto [_file, _buffer] = Context<DriverState>::get()->getFileAndBuffer(
param_copy->fh, param_copy->u.batch.devPtr_base, param_copy->u.batch.size, 0);
auto op = std::make_shared<BatchOperation>(std::move(param_copy), _buffer, _file);
auto op = std::shared_ptr<IBatchOperation>{new BatchOperation{std::move(param_copy), _buffer, _file}};

pending_ops.push_back(op);
}
Expand All @@ -133,6 +133,11 @@ BatchContext::submit_operations(const hipFileIOParams_t *params, unsigned num_pa
outstanding_ops.insert(pending_ops.begin(), pending_ops.end());
}

std::unordered_set<std::shared_ptr<IBatchOperation>>&
BatchContextAccessor::get_ops_set(BatchContext& _context){
return _context.outstanding_ops;
}

void
BatchContextMap::clear()
{
Expand Down
23 changes: 21 additions & 2 deletions hipfile/src/amd_detail/batch/batch.h
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,13 @@ struct InvalidBatchHandle : public std::invalid_argument {
}
};

class IBatchOperation {
public:
virtual ~IBatchOperation() = default;
};

/// @brief Represents a single IO Request
class BatchOperation {
class BatchOperation : public IBatchOperation {
public:
/// @brief Create an operation to handle and track an IO request.
/// @param [in] params IO parameters
Expand Down Expand Up @@ -89,13 +94,27 @@ class BatchContext : public IBatchContext {
/// but is not yet complete or completed but not yet retrieved by the
/// application.
/// shared_ptr as it may need to be passed to a backend.
std::unordered_set<std::shared_ptr<BatchOperation>> outstanding_ops;
std::unordered_set<std::shared_ptr<IBatchOperation>> outstanding_ops;

BatchContext(unsigned capacity);

friend class BatchContextAccessor;
friend class BatchContextMap;
};

/*
* Friend class of BatchContext
*
* Can be used to peer into BatchContext's hidden members.
* Should not be used in production.
*/
class BatchContextAccessor {
public:
// Return a reference to the unordered_set to modify what ops are loaded
// in the context.
std::unordered_set<std::shared_ptr<IBatchOperation>>& get_ops_set(BatchContext& _context);
};

class BatchContextMap {
public:
/*!
Expand Down
14 changes: 14 additions & 0 deletions hipfile/test/amd_detail/batch/batch.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
#include "hipfile-test.h"
#include "hipfile-warnings.h"
#include "invalid-enum.h"
#include "mbatch.h"
#include "mbuffer.h"
#include "mfile.h"
#include "mstate.h"
Expand Down Expand Up @@ -348,4 +349,17 @@ TEST_F(HipFileBatchContext, SubmitSingleBadParamModeInvalid)
ASSERT_THROW(_context->submit_operations(&bad_io_params, 1), std::invalid_argument);
}

// Not a real test - proof of concept
TEST_F(HipFileBatchContext, _InsertMBatchOperationIntoContext)
{
BatchContextAccessor bca;
auto ops = bca.get_ops_set(*std::dynamic_pointer_cast<BatchContext>(_context));

std::shared_ptr<IBatchOperation> mock_op = std::make_unique<MBatchOperation>();

ops.insert(mock_op);

ASSERT_EQ(1, ops.size());
}

HIPFILE_WARN_NO_GLOBAL_CTOR_ON
3 changes: 3 additions & 0 deletions hipfile/test/amd_detail/mbatch.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,9 @@

namespace hipFile {

class MBatchOperation : public IBatchOperation {
};

class MBatchContext : public IBatchContext {
public:
MOCK_METHOD(unsigned, get_capacity, (), (const, noexcept, override));
Expand Down
Loading