Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 14 additions & 3 deletions tests/cpp/test_welford.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@
#include <tests/cpp/utils.h>
#include <tests/cpp/validator.h>
#include <type.h>
#include "scheduler/utils.h"
#include "utils.h"

namespace nvfuser {

Expand Down Expand Up @@ -714,9 +716,18 @@ TEST_F(NVFuserTest, Translate1Welford) {
runtime1->fusionSegments()->groups()[0]->exprs().size() > 2);

// Run an un-translated Welford; use a large inner size to ensure it is not
// translated. Cluster reduction uses 16 SMs, can hold up to 32K * 16 = 512K
// elements.
auto runtime2 = run_test(512 * 1024 + 1024);
// translated.
// Context: cluster reduction uses only register persistence, whereas shared
// memory persistence is also used when cluster reduction is not supported.
const int64_t sm_per_cluster = scheduler_utils::getMaxClusterSize();
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

syntax: Add #include <scheduler/utils.h> at the top of the file (after line 14) to declare scheduler_utils::getMaxClusterSize(). Without this include, the code will not compile.

Reference: tests/cpp/test_cluster.cpp includes this header on line 17.

const int64_t regs_buffer_count =
scheduler_utils::register_file_size_bit / 32;
const int64_t smem_buffer_count =
ceilDiv(deviceAvailableSharedMemoryBytes(), 4);
const int64_t total_elements = sm_per_cluster == 1
? scheduler_utils::roundUpPow2Or8(smem_buffer_count)
: regs_buffer_count * sm_per_cluster;
auto runtime2 = run_test(total_elements + 1024);

bool found_welford = false;
for (auto group : runtime2->fusionSegments()->groups()) {
Expand Down