Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
78 changes: 62 additions & 16 deletions csrc/evaluator_common.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -186,7 +186,7 @@ void PrecomputedValues::bindParallelExtents(
auto raw_val = launch_constraint.getRawVal(it.first);
if (raw_val > 0) {
for (auto extent : it.second) {
bindValue(extent->evaluatorIndex(), raw_val);
bindValue(extent->evaluatorIndex(), raw_val, extent);
}
}
}
Expand All @@ -198,13 +198,17 @@ void PrecomputedValues::bindConcreteParallelTypeValue(
auto index_list_it = thread_dim_value_indices_.find(pt);
if (index_list_it != thread_dim_value_indices_.end()) {
for (auto index : *(index_list_it->second)) {
bindValue(index, value);
const Val* ir_node = (index >= 0 && index < (int)symbols_.size())
? symbols_[index]
: nullptr;
bindValue(index, value, ir_node);
}
}
}

void PrecomputedValues::bindInputs(const KernelArgumentHolder& args) {
FUSER_PERF_SCOPE("PrecomputedValues::bindInputs");
debug() << "[DEBUG] PrecomputedValues::bindInputs called" << std::endl;
if (hasValidValues()) {
invalidate();
}
Expand All @@ -219,6 +223,9 @@ void PrecomputedValues::bindValues(
std::ssize(inputs),
"kernel inputs size does not match args");

debug() << "[DEBUG] PrecomputedValues::bindValues called with "
<< inputs.size() << " inputs" << std::endl;

for (const auto i : arange((int64_t)inputs.size())) {
const auto input = inputs[i];
NVF_ERROR(input != nullptr);
Expand All @@ -228,7 +235,7 @@ void PrecomputedValues::bindValues(
bindTensorMetaData(tv, tensor);
}
} else {
bindValue(input->evaluatorIndex(), args[i]);
bindValue(input->evaluatorIndex(), args[i], input);
}
}
}
Expand Down Expand Up @@ -360,15 +367,34 @@ void PrecomputedValues::initializeNamedScalars() {
void PrecomputedValues::validate() {
FUSER_PERF_SCOPE("PrecomputedValuess::Validate");
using namespace PolymorphicValue_functions;
for (const auto& it : binding_log_) {
NVF_ERROR(
isSame(values_[it.first], it.second),
"Precomputed values failed to validate.",
"\nSomething unexpected changed between the compilation and "
"execution.\n",
values_[it.first],
" != ",
it.second);
for (const auto& [index, expected_value, ir_node] : binding_log_) {
if (!isSame(values_[index], expected_value)) {
std::stringstream error_msg;
error_msg << "Precomputed values failed to validate.\n"
<< "Something unexpected changed between the compilation and "
"execution.\n";
if (ir_node != nullptr) {
error_msg << "IR node: " << ir_node->toString() << "\n";
}
error_msg << "Computed value: " << toString(values_[index]) << "\n"
<< "Expected value: " << toString(expected_value);

// Debug: Show binding history for this index
debug() << "[DEBUG] ===== VALIDATION FAILED =====" << std::endl;
debug() << "[DEBUG] Binding history for index " << index << ":" << std::endl;
for (const auto& [idx, val, node] : binding_log_) {
if (idx == index) {
debug() << "[DEBUG] Bound to: " << toString(val);
if (node != nullptr) {
debug() << " (node: " << node->toString() << ")";
}
debug() << std::endl;
}
}
debug() << "[DEBUG] ================================" << std::endl;

NVF_ERROR(false, error_msg.str());
}
}
has_valid_values_ = true;
}
Expand All @@ -383,6 +409,21 @@ void PrecomputedValues::bindTensorMetaData(
"Something went wrong configuring launch. Inputs do not match.");

std::vector<int64_t> logical_sizes = unshardedSizes(tv, tensor.sizes());

debug() << "[DEBUG] bindTensorMetaData for TV: " << tv->toString() << std::endl;
debug() << "[DEBUG] Actual tensor.sizes(): [";
for (size_t i = 0; i < tensor.sizes().size(); ++i) {
if (i > 0) debug() << ", ";
debug() << tensor.sizes()[i];
}
debug() << "]" << std::endl;
debug() << "[DEBUG] Unsharded logical_sizes: [";
for (size_t i = 0; i < logical_sizes.size(); ++i) {
if (i > 0) debug() << ", ";
debug() << logical_sizes[i];
}
debug() << "]" << std::endl;

adjustEvaluatorSizes(tv, logical_sizes);

for (const auto dim : arange(static_cast<int64_t>(logical_domain.size()))) {
Expand All @@ -391,12 +432,17 @@ void PrecomputedValues::bindTensorMetaData(
if (id->isBroadcast()) {
// DIDs are ignored for broadcast. See MultideviceShardingTest.Broadcast
// and .ExpandedBroadcast.
bindValue(id->extent()->evaluatorIndex(), 1L);
bindValue(id->extent()->evaluatorIndex(), 1L, id->extent());
if (id->hasExpandedExtent()) {
bindValue(id->expandedExtent()->evaluatorIndex(), dim_size);
bindValue(
id->expandedExtent()->evaluatorIndex(),
dim_size,
id->expandedExtent());
}
} else {
bindValue(id->extent()->evaluatorIndex(), dim_size);
debug() << "[DEBUG] Binding " << id->extent()->toString()
<< " = " << dim_size << std::endl;
bindValue(id->extent()->evaluatorIndex(), dim_size, id->extent());
}
}

Expand Down Expand Up @@ -424,7 +470,7 @@ void PrecomputedValues::bindTensorMetaData(
tv->toString(),
" with input tensor ",
tensor);
bindValue(metadata_val->evaluatorIndex(), metadata);
bindValue(metadata_val->evaluatorIndex(), metadata, metadata_val);
}

NaiveValueMachine::NaiveValueMachine(PrecomputedValues& precomputed_values)
Expand Down
25 changes: 20 additions & 5 deletions csrc/evaluator_common.h
Original file line number Diff line number Diff line change
Expand Up @@ -211,17 +211,32 @@ class PrecomputedValues {

//! Bind concrete value to the given index
//! if the index is valid.
void bindValue_(int index, const PolymorphicValue& value) {
void bindValue_(
int index,
const PolymorphicValue& value,
const Val* ir_node = nullptr) {
if (index < 0 || is_constant_[index]) {
return;
}

// Debug: show if we're rebinding a value
if (defined_[index]) {
debug() << "[DEBUG] REBINDING index " << index;
if (ir_node != nullptr) {
debug() << " (node: " << ir_node->toString() << ")";
}
debug() << " from " << PolymorphicValue_functions::toString(values_[index])
<< " to " << PolymorphicValue_functions::toString(value) << std::endl;
}

defined_[index] = true;
values_[index] = value;
binding_log_.emplace_back(index, value);
binding_log_.emplace_back(index, value, ir_node);
validate();
}
template <typename T>
void bindValue(int index, const T& value) {
bindValue_(index, PolymorphicValue(value));
void bindValue(int index, const T& value, const Val* ir_node = nullptr) {
bindValue_(index, PolymorphicValue(value), ir_node);
}

//! Invalidate all computed values in the workspace.
Expand Down Expand Up @@ -292,7 +307,7 @@ class PrecomputedValues {
//! An internal log to keep track of all the bindings
//! used in each evaluation cycle. To be used for
//! consistency check.
std::vector<std::pair<int, PolymorphicValue>> binding_log_;
std::vector<std::tuple<int, PolymorphicValue, const Val*>> binding_log_;

//! Integer runtime for realizing the values computations.
std::unique_ptr<NaiveValueMachine> value_machine_;
Expand Down
6 changes: 6 additions & 0 deletions csrc/multidevice/execution_utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,12 @@ std::vector<int64_t> unshardedSizes(
sharded_id) != tv->getLogicalDomain().end()) {
return 1;
}
if (std::find(
tv->getMaybeAllocationDomain().begin(),
tv->getMaybeAllocationDomain().end(),
sharded_id) != tv->getMaybeAllocationDomain().end()) {
return 1;
}

NVF_ERROR(
sharded_id->extent()->isConstInt(),
Expand Down
6 changes: 6 additions & 0 deletions csrc/runtime/fusion_executor_cache.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,10 @@ KernelArgumentHolder FusionExecutorCache::runFusionWithInputs(
std::optional<int8_t> selected_device) {
FUSER_PERF_SCOPE("FusionExecutorCache::runFusionWithInputs");

// Print fusion IR every run
debug() << "Fusion IR in FusionExecutorCache::runFusionWithInputs:" << std::endl;
fusion_->print();

if (isProfilerEnabled()) {
FusionProfiler::start(!isProfilerEnabledWithCupti());
}
Expand All @@ -63,6 +67,7 @@ KernelArgumentHolder FusionExecutorCache::runFusionWithInputs(
}

if (!kernel_runtime->isCompiled()) {
debug() << "[DEBUG] ===== COMPILING KERNEL =====" << std::endl;
kernel_runtime->compileFusionParallel(args);
}

Expand All @@ -80,6 +85,7 @@ KernelArgumentHolder FusionExecutorCache::runFusionWithInputs(
" failed.");
}

debug() << "[DEBUG] ===== EXECUTING KERNEL =====" << std::endl;
auto outputs = kernel_runtime->runWithInputs(args);

// Kernel time measurement is off by default
Expand Down
3 changes: 3 additions & 0 deletions csrc/runtime/fusion_kernel_runtime.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -578,11 +578,14 @@ std::optional<std::unique_ptr<HeuristicParamsList>> FusionKernelRuntime::
{
FUSER_PERF_SCOPE(
"FusionKernelRuntime::getMaybeHeuristicsFor::PrecomputedValues");
debug() << "[DEBUG] compileFusionParallel: Creating PrecomputedValues and binding inputs" << std::endl;
evaluator_precomputed_values =
std::make_unique<PrecomputedValues>(fusion_to_run);
debug() << "[DEBUG] compileFusionParallel: Calling bindInputs (group_runtime_inputs)" << std::endl;
evaluator_precomputed_values->bindInputs(group_runtime_inputs);
// TODO Remove binding the original fusion inputs when creating
// heuristics for fusion segment.
debug() << "[DEBUG] compileFusionParallel: Calling bindValues (complete fusion inputs)" << std::endl;
evaluator_precomputed_values->bindValues(
group_to_run->getCompleteFusionInputs(), args);
evaluator_precomputed_values->evaluate();
Expand Down
94 changes: 94 additions & 0 deletions tests/cpp/test_stream.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -193,4 +193,98 @@ TEST_F(StreamTest, ReplicatedAllocation) {
}
}

// Stream-parallelized single matmul: the output's column dimension is
// outer-split by a factor of `c` and the resulting outer axis is
// parallelized with ParallelType::Stream, so the `c` column chunks are
// presumably dispatched on separate CUDA streams — confirm against the
// host IR dump reproduced below.
TEST_F(StreamTest, Matmul) {
// Number of stream-parallel chunks; chosen so w's column count (c * 2)
// divides evenly under outer_split.
constexpr int64_t c = 3;

auto fusion = std::make_unique<Fusion>();
{
FusionGuard fg(fusion.get());
// out = matmul(in, w) with two symbolic 2-D inputs.
TensorView* in = makeSymbolicTensor(2);
TensorView* w = makeSymbolicTensor(2);
TensorView* out = matmul(in, w);
fusion->addInput(in);
fusion->addInput(w);
fusion->addOutput(out);

// Split the output's column axis (axis 1) into [c, cols/c] and mark the
// outer factor as Stream-parallel.
out->outer_split(1, c);
out->axis(1)->parallelize(ParallelType::Stream);
}

auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA);
at::Tensor in_tensor = at::randn({5, 7}, options);
// c * 2 columns so the split by c is exact (no remainder chunk).
at::Tensor w_tensor = at::randn({7, c * 2}, options);

// With NVFUSER_DUMP=host_ir, you'll see the host IR container like the
// following:
// clang-format off
// %HostIrContainer { (T0_g_float[iS0{i0}, iS1{i2}], T1_g_float[istreamIdx7{3}, iS11{i2}, iS8{( ceilDiv(i4, 3) )}]) -> (T2_g_float[istreamIdx9{3}, iS4{i0}, iS10{( ceilDiv(i4, 3) )}, rS6{i2}]) :
//   FOR i18 from 0 to 3:
//     T2_g_float[istreamIdx9{3}, iS4{i0}, iS10{( ceilDiv(i4, 3) )}, rS6{i2}]
//        = matmul(T0_g_float[iS0{i0}, iS1{i2}],
//                 T1_g_float[istreamIdx7{3}, iS11{i2}, iS8{( ceilDiv(i4, 3) )}])
// } // %HostIrContainer
// clang-format on
FusionExecutorCache executor_cache(std::move(fusion));
// Run end-to-end through the executor cache; result is the first output.
auto out_tensor = executor_cache.runFusionWithInputs({in_tensor, w_tensor})[0]
.as<at::Tensor>();

// Validate the streamed result against the unscheduled fusion's
// reference evaluation.
testValidate(
executor_cache.fusion(),
{out_tensor},
{in_tensor, w_tensor},
__LINE__,
__FILE__);
}

// Two chained matmuls with the *input*'s row dimension stream-parallelized.
// Unlike the single-matmul test, the Stream axis is placed on `in` (axis 0),
// so the split propagates through both matmuls; the host IR dump below shows
// ShardByStream ops slicing producer tensors per stream iteration.
TEST_F(StreamTest, TwoMatmuls) {
// Number of stream-parallel chunks; in's row count (c * 2) divides evenly.
constexpr int64_t c = 3;

auto fusion = std::make_unique<Fusion>();
{
FusionGuard fg(fusion.get());
// out = (in x w1) x w2, all symbolic 2-D tensors.
TensorView* in = makeSymbolicTensor(2);
TensorView* w1 = makeSymbolicTensor(2);
TensorView* w2 = makeSymbolicTensor(2);
TensorView* out = matmul(in, w1);
out = matmul(out, w2);
fusion->addInput(in);
fusion->addInput(w1);
fusion->addInput(w2);
fusion->addOutput(out);

// Split in's row axis (axis 0) into [c, rows/c]; the outer factor is
// Stream-parallel and is expected to propagate to both matmul outputs.
in->outer_split(0, c);
in->axis(0)->parallelize(ParallelType::Stream);
}

{
auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA);
// Shapes chosen so in x w1 -> (c*2, 5) and (in x w1) x w2 -> (c*2, 3).
at::Tensor in = at::randn({c * 2, 3}, options);
at::Tensor w1 = at::randn({3, 5}, options);
at::Tensor w2 = at::randn({5, 3}, options);

// With NVFUSER_DUMP=host_ir, you'll see the host IR container like the
// following:
// clang-format off
// %HostIrContainer { (T0_g_float[istreamIdx12{3}, iS13{( ceilDiv(i0, 3) )}, iS1{i2}], T1_g_float[iS14{i2}, iS3{i4}], T2_g_float[iS15{i4}, iS5{i6}]) -> (T4_g_float[istreamIdx18{3}, iS19{( ceilDiv(i0, 3) )}, iS10{i6}, rS11{i4}]) :
//   T4_g_float[istreamIdx18{3}, iS19{( ceilDiv(i0, 3) )}, iS10{i6}, rS11{i4}] = ALLOCATE(buffer=T4_g_float[istreamIdx18{3}, iS19{( ceilDiv(i0, 3) )}, iS10{i6}, rS11{i4}], mem_type=global, size=( i0 * i6 ), zero_init=false, resets_to_zero=false)
//   FOR i99 from 0 to 3:
//     T5_l_float[istreamIdx22{3}, iS23{( ceilDiv(i0, 3) )}, iS21{i2}] = ShardByStream(T0_g_float[istreamIdx12{3}, iS13{( ceilDiv(i0, 3) )}, iS1{i2}], stream_index = i99)
//     T3_g_float[istreamIdx16{3}, iS17{( ceilDiv(i0, 3) )}, iS7{i4}, rS8{i2}]
//        = matmul(T5_l_float[istreamIdx22{3}, iS23{( ceilDiv(i0, 3) )}, iS21{i2}],
//                 T1_g_float[iS14{i2}, iS3{i4}])
//     T6_l_float[istreamIdx26{3}, iS27{( ceilDiv(i0, 3) )}, iS25{i6}] = ShardByStream(T4_g_float[istreamIdx18{3}, iS19{( ceilDiv(i0, 3) )}, iS10{i6}, rS11{i4}], stream_index = i99)
//     T6_l_float[istreamIdx26{3}, iS27{( ceilDiv(i0, 3) )}, iS25{i6}]
//        = matmul(T3_g_float[istreamIdx16{3}, iS17{( ceilDiv(i0, 3) )}, iS7{i4}, rS8{i2}],
//                 T2_g_float[iS15{i4}, iS5{i6}])
// } // %HostIrContainer
// clang-format on
FusionExecutorCache executor_cache(std::move(fusion));
// Run end-to-end; result is the first (and only) fusion output.
auto out =
executor_cache.runFusionWithInputs({in, w1, w2})[0].as<at::Tensor>();

// Validate the streamed result against the unscheduled fusion's
// reference evaluation.
testValidate(
executor_cache.fusion(), {out}, {in, w1, w2}, __LINE__, __FILE__);
}
}

} // namespace nvfuser
Loading