From 838e153bf259becc26d9eabcbb258f0bd35bcc92 Mon Sep 17 00:00:00 2001 From: Xiang Gao Date: Wed, 17 Dec 2025 23:58:04 -0800 Subject: [PATCH 1/5] Check consistency of multiplier --- csrc/multidevice/execution_utils.cpp | 19 +++++++++++++++++++ 1 file changed, 19 insertions(+) diff --git a/csrc/multidevice/execution_utils.cpp b/csrc/multidevice/execution_utils.cpp index a7a7da703e2..5caee19c070 100644 --- a/csrc/multidevice/execution_utils.cpp +++ b/csrc/multidevice/execution_utils.cpp @@ -68,6 +68,9 @@ std::vector unshardedSizes( "Producing logical axis not found for ", sharded_id); + // Global map to track extent -> multiplier relationships + static std::unordered_map extent_to_multiplier_map; + auto multiplier = [&]() -> int64_t { if (parallel_type == ParallelType::Stream) { // TODO(#5525): hack for MultiDeviceExecutor. MultiDeviceExecutor looks @@ -101,6 +104,22 @@ std::vector unshardedSizes( NVF_THROW("Unexpected parallel type: ", parallel_type); }(); + + // Check consistency: for the same extent, we should always get the same multiplier + Val* extent = sharded_id->extent(); + auto it = extent_to_multiplier_map.find(extent); + if (it != extent_to_multiplier_map.end()) { + NVF_ERROR( + it->second == multiplier, + "Inconsistent multiplier for extent ", + extent->toString(), + ": expected ", + it->second, + " but got ", + multiplier); + } else { + extent_to_multiplier_map[extent] = multiplier; + } unsharded_sizes.at(sharded_axis) *= multiplier; } From 88e982cf1c70f3d92cd90b0e6e4ef310617523c2 Mon Sep 17 00:00:00 2001 From: Xiang Gao Date: Mon, 5 Jan 2026 15:19:29 -0800 Subject: [PATCH 2/5] let PrecomputedValues own the map --- csrc/evaluator_common.cpp | 2 +- csrc/evaluator_common.h | 8 +++++++ csrc/expr_evaluator.cpp | 2 +- csrc/expr_evaluator.h | 5 ++++ csrc/multidevice/execution_utils.cpp | 35 ++++++++++++++-------------- csrc/multidevice/execution_utils.h | 3 ++- csrc/runtime/allocations.cpp | 9 +++---- csrc/runtime/executor.cpp | 2 +- csrc/tensor_metadata.cpp | 4 ++-- 9 files changed, 43 insertions(+), 27 deletions(-) diff --git a/csrc/evaluator_common.cpp b/csrc/evaluator_common.cpp index 5e983777b04..bdf5e70e77f 100644 --- a/csrc/evaluator_common.cpp +++ b/csrc/evaluator_common.cpp @@ -382,7 +382,7 @@ void PrecomputedValues::bindTensorMetaData( tensor.dim() == static_cast(logical_domain.size()), "Something went wrong configuring launch. Inputs do not match."); - std::vector logical_sizes = unshardedSizes(tv, tensor.sizes()); + std::vector logical_sizes = unshardedSizes(tv, tensor.sizes(), &extent_to_multiplier_map_); adjustEvaluatorSizes(tv, logical_sizes); for (const auto dim : arange(static_cast(logical_domain.size()))) { diff --git a/csrc/evaluator_common.h b/csrc/evaluator_common.h index aabf029ed4d..c68045535ac 100644 --- a/csrc/evaluator_common.h +++ b/csrc/evaluator_common.h @@ -181,6 +181,11 @@ class PrecomputedValues { return has_valid_values_; } + //! Get the extent to multiplier map for unshardedSizes + std::unordered_map* getExtentToMultiplierMap() { + return &extent_to_multiplier_map_; + } + //! Runs the internal value machine that will compute //! the values allocated in the workspace. void evaluate(); @@ -289,6 +294,9 @@ class PrecomputedValues { //! Stores the IR nodes corresponding to each index. std::vector symbols_; + //! Extent to multiplier map for unshardedSizes - owned by this PrecomputedValues + std::unordered_map extent_to_multiplier_map_; + //! An internal log to keep track of all the bindings //! used in each evaluation cycle. To be used for //! consistency check. diff --git a/csrc/expr_evaluator.cpp b/csrc/expr_evaluator.cpp index 38bbad4fda7..afec1cd2ca5 100644 --- a/csrc/expr_evaluator.cpp +++ b/csrc/expr_evaluator.cpp @@ -146,7 +146,7 @@ void ExpressionEvaluator::bindTensorDomain( tv->toString(), ", to be bound to a tensor of equal rank."); - std::vector logical_sizes = unshardedSizes(tv, t.sizes()); + std::vector logical_sizes = unshardedSizes(tv, t.sizes(), getExtentToMultiplierMap()); adjustEvaluatorSizes(tv, logical_sizes); for (const auto& [i, id] : enumerate(logical_domain)) { diff --git a/csrc/expr_evaluator.h b/csrc/expr_evaluator.h index 2c7b2fdb2ef..ca4190fea0a 100644 --- a/csrc/expr_evaluator.h +++ b/csrc/expr_evaluator.h @@ -79,6 +79,11 @@ class ExpressionEvaluator { return precomputed_values_; } + //! Get the extent to multiplier map for unshardedSizes from PrecomputedValues + std::unordered_map* getExtentToMultiplierMap() const { + return precomputed_values_ ? precomputed_values_->getExtentToMultiplierMap() : nullptr; + } + //! Augment the evaluator with the exact root-domain map such that //! if the extent of a root ID is known, the extents of all other //! root IDs that are exactly mapped also get bound to the same diff --git a/csrc/multidevice/execution_utils.cpp b/csrc/multidevice/execution_utils.cpp index 5caee19c070..92ad1a29364 100644 --- a/csrc/multidevice/execution_utils.cpp +++ b/csrc/multidevice/execution_utils.cpp @@ -50,7 +50,8 @@ at::Tensor shardTensor( std::vector unshardedSizes( const TensorView* tv, - c10::IntArrayRef sizes) { + c10::IntArrayRef sizes, + std::unordered_map* extent_to_multiplier_map) { std::vector unsharded_sizes = sizes.vec(); for (ParallelType parallel_type : deviceAndStreamParallelTypes()) { const DomainType domain_type = parallel_type == ParallelType::Stream @@ -68,9 +69,6 @@ std::vector unshardedSizes( "Producing logical axis not found for ", sharded_id); - // Global map to track extent -> multiplier relationships - static std::unordered_map extent_to_multiplier_map; - auto multiplier = [&]() -> int64_t { if (parallel_type == ParallelType::Stream) { // TODO(#5525): hack for MultiDeviceExecutor. MultiDeviceExecutor looks @@ -106,19 +104,22 @@ std::vector unshardedSizes( }(); // Check consistency: for the same extent, we should always get the same multiplier - Val* extent = sharded_id->extent(); - auto it = extent_to_multiplier_map.find(extent); - if (it != extent_to_multiplier_map.end()) { - NVF_ERROR( - it->second == multiplier, - "Inconsistent multiplier for extent ", - extent->toString(), - ": expected ", - it->second, - " but got ", - multiplier); - } else { - extent_to_multiplier_map[extent] = multiplier; + // Only perform this check if a map is provided + if (extent_to_multiplier_map) { + Val* extent = sharded_id->extent(); + auto it = extent_to_multiplier_map->find(extent); + if (it != extent_to_multiplier_map->end()) { + NVF_ERROR( + it->second == multiplier, + "Inconsistent multiplier for extent ", + extent->toString(), + ": expected ", + it->second, + " but got ", + multiplier); + } else { + (*extent_to_multiplier_map)[extent] = multiplier; + } } unsharded_sizes.at(sharded_axis) *= multiplier; } diff --git a/csrc/multidevice/execution_utils.h b/csrc/multidevice/execution_utils.h index 2032b440f22..99323f73d10 100644 --- a/csrc/multidevice/execution_utils.h +++ b/csrc/multidevice/execution_utils.h @@ -67,6 +67,7 @@ NVF_API at::Tensor shardTensor( // ExpressionEvaluator, and so on, which is an API overhaul. std::vector unshardedSizes( const TensorView* tv, - c10::IntArrayRef sizes); + c10::IntArrayRef sizes, + std::unordered_map* extent_to_multiplier_map = nullptr); } // namespace nvfuser diff --git a/csrc/runtime/allocations.cpp b/csrc/runtime/allocations.cpp index 3c5a9c2fb46..31dce3fd246 100644 --- a/csrc/runtime/allocations.cpp +++ b/csrc/runtime/allocations.cpp @@ -886,6 +886,7 @@ std::pair, std::vector> inferShapeOfOutput( TensorShapeInfo inferTensorShapes( TensorView* tv, const ExpressionEvaluator& expr_eval) { + auto* extent_map = expr_eval.getExtentToMultiplierMap(); // Alias handling: auto alias_info = tv->fusion()->getOutputAlias(tv); if (alias_info.type != AllocationType::New) { @@ -902,7 +903,7 @@ TensorShapeInfo inferTensorShapes( return TensorShapeInfo{ tensor.sizes().vec(), tensor.strides().vec(), - isSharded(tv) ? unshardedSizes(tv, tensor.sizes().vec()) + isSharded(tv) ? unshardedSizes(tv, tensor.sizes().vec(), extent_map) : std::vector(), }; } @@ -911,7 +912,7 @@ TensorShapeInfo inferTensorShapes( return TensorShapeInfo{ tensor.sizes().vec(), tensor.strides().vec(), - isSharded(tv) ? unshardedSizes(tv, tensor.sizes().vec()) + isSharded(tv) ? unshardedSizes(tv, tensor.sizes().vec(), extent_map) : std::vector(), allocation_size_stride.first, allocation_size_stride.second}; @@ -923,7 +924,7 @@ TensorShapeInfo inferTensorShapes( return TensorShapeInfo{ allocation_size_stride.first, allocation_size_stride.second, - isSharded(tv) ? unshardedSizes(tv, allocation_size_stride.first) + isSharded(tv) ? unshardedSizes(tv, allocation_size_stride.first, extent_map) : std::vector(), }; } @@ -940,7 +941,7 @@ TensorShapeInfo inferTensorShapes( return TensorShapeInfo{ logical_meta_tensor.sizes().vec(), logical_meta_tensor.strides().vec(), - isSharded(tv) ? unshardedSizes(tv, logical_meta_tensor.sizes().vec()) + isSharded(tv) ? unshardedSizes(tv, logical_meta_tensor.sizes().vec(), extent_map) : std::vector(), allocation_size_stride.first, allocation_size_stride.second}; diff --git a/csrc/runtime/executor.cpp b/csrc/runtime/executor.cpp index baa478f095f..271790e7048 100644 --- a/csrc/runtime/executor.cpp +++ b/csrc/runtime/executor.cpp @@ -744,7 +744,7 @@ void KernelExecutor::initializeExecutorEntry( shape_info.logical_strides = arg_tensor.strides().vec(); if (isSharded(input_tv)) { shape_info.unsharded_logical_sizes = - unshardedSizes(input_tv, shape_info.logical_sizes); + unshardedSizes(input_tv, shape_info.logical_sizes, expr_eval.getExtentToMultiplierMap()); } shape_info.allocation_sizes = alloc_sizes; shape_info.allocation_strides = alloc_strides; diff --git a/csrc/tensor_metadata.cpp b/csrc/tensor_metadata.cpp index 89d83a97eea..f18b644cf2c 100644 --- a/csrc/tensor_metadata.cpp +++ b/csrc/tensor_metadata.cpp @@ -302,7 +302,7 @@ inferAllocationSizesAndStrides( const auto& alloc = tv->getMaybeAllocationDomain(); // active IDs and their shape and stride - std::vector logical_sizes = unshardedSizes(tv, tensor.sizes()); + std::vector logical_sizes = unshardedSizes(tv, tensor.sizes(), ee.getExtentToMultiplierMap()); std::unordered_map> active_ids; int64_t dim_index = 0; for (IterDomain* id : logical | TensorDomain::kNoReductions) { @@ -398,7 +398,7 @@ std::vector GetMetaData::evaluate( metadata->data = input.data_ptr(); if (isSharded(tv)) { - std::vector unsharded_sizes = unshardedSizes(tv, input.sizes()); + std::vector unsharded_sizes = unshardedSizes(tv, input.sizes(), ee.getExtentToMultiplierMap()); metadata->logical_size_data = std::move(unsharded_sizes); metadata->logical_size = c10::makeArrayRef(metadata->logical_size_data); } else { From 113dc05b8ae78cb2de8fd50841efbf26927dedbc Mon Sep 17 00:00:00 2001 From: Xiang Gao Date: Mon, 5 Jan 2026 18:43:48 -0800 Subject: [PATCH 3/5] save --- csrc/multidevice/execution_utils.cpp | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/csrc/multidevice/execution_utils.cpp b/csrc/multidevice/execution_utils.cpp index 92ad1a29364..cf764e6d4c9 100644 --- a/csrc/multidevice/execution_utils.cpp +++ b/csrc/multidevice/execution_utils.cpp @@ -118,7 +118,8 @@ std::vector unshardedSizes( " but got ", multiplier); } else { - (*extent_to_multiplier_map)[extent] = multiplier; + NVF_ERROR(false, "Extent to multiplier map not found for extent ", extent->toString()); + // (*extent_to_multiplier_map)[extent] = multiplier; } } unsharded_sizes.at(sharded_axis) *= multiplier; From b4821fd4585d91653cf15efdf4d2aec3b721a458 Mon Sep 17 00:00:00 2001 From: Xiang Gao Date: Tue, 6 Jan 2026 10:19:02 -0800 Subject: [PATCH 4/5] fix --- csrc/multidevice/execution_utils.cpp | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/csrc/multidevice/execution_utils.cpp b/csrc/multidevice/execution_utils.cpp index cf764e6d4c9..ae97f07a704 100644 --- a/csrc/multidevice/execution_utils.cpp +++ b/csrc/multidevice/execution_utils.cpp @@ -118,9 +118,10 @@ std::vector unshardedSizes( " but got ", multiplier); } else { - NVF_ERROR(false, "Extent to multiplier map not found for extent ", extent->toString()); - // (*extent_to_multiplier_map)[extent] = multiplier; + (*extent_to_multiplier_map)[extent] = multiplier; } + } else { + NVF_ERROR(false, "Extent to multiplier map not provided"); } unsharded_sizes.at(sharded_axis) *= multiplier; } From 438a860a67ce936c19ffffa25e3d5895e216dfc0 Mon Sep 17 00:00:00 2001 From: Xiang Gao Date: Tue, 6 Jan 2026 12:25:55 -0800 Subject: [PATCH 5/5] save --- csrc/multidevice/execution_utils.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/csrc/multidevice/execution_utils.cpp b/csrc/multidevice/execution_utils.cpp index ae97f07a704..39a05cdde4a 100644 --- a/csrc/multidevice/execution_utils.cpp +++ b/csrc/multidevice/execution_utils.cpp @@ -121,7 +121,7 @@ std::vector unshardedSizes( (*extent_to_multiplier_map)[extent] = multiplier; } } else { - NVF_ERROR(false, "Extent to multiplier map not provided"); + // NVF_ERROR(false, "Extent to multiplier map not provided"); } unsharded_sizes.at(sharded_axis) *= multiplier; }