2 changes: 0 additions & 2 deletions monarch_extension/src/convert.rs
@@ -355,14 +355,12 @@ fn create_map(py: Python) -> HashMap<u64, FnType> {
dims: p.parse("dims")?,
device_mesh: p.parseRef("device_mesh")?,
stream: p.parseStreamRef("stream")?,
config: None,
})
});
m.insert(key("SplitCommForProcessGroup"), |p| {
Ok(WorkerMessage::SplitCommForProcessGroup {
remote_process_group: p.parseRef("remote_process_group")?,
stream: p.parseStreamRef("stream")?,
config: None,
})
});
m.insert(key("DefineRecording"), |p| {
9 changes: 0 additions & 9 deletions monarch_messages/src/worker.rs
@@ -38,7 +38,6 @@ use pyo3::types::PyTuple;
use serde::Deserialize;
use serde::Serialize;
use thiserror::Error;
use torch_sys_cuda::nccl::NcclConfig;
use torch_sys_cuda::nccl::ReduceOp;
use torch_sys_cuda::nccl::UniqueId;
use torch_sys2::BorrowError;
@@ -800,10 +799,6 @@ pub enum WorkerMessage {
/// will be ordered with respect to other operations scheduled on this
/// stream.
stream: StreamRef,
/// Configuration for the new communicator. If None, we will not pass a
/// config object to nccl, which means that the created communicator
/// will inherit its parent's config.
config: Option<NcclConfig>,
},

/// Create a new communicator on each rank in `ranks`, capable of
@@ -816,10 +811,6 @@
/// will be ordered with respect to other operations scheduled on this
/// stream.
stream: StreamRef,
/// Configuration for the new communicator. If None, we will not pass a
/// config object to nccl, which means that the created communicator
/// will inherit its parent's config.
config: Option<NcclConfig>,
},

SendTensor {
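For reference, and not part of the diff: after this change the split messages are built without a config field. A minimal sketch, mirroring the updated worker test in monarch_tensor_worker/src/comm.rs below:

```rust
// Sketch only: WorkerMessage::SplitComm no longer carries a
// `config: Option<NcclConfig>` field, so callers simply omit it.
// Values are taken from the updated test in this PR; the import path
// is assumed from the file location.
use monarch_messages::worker::WorkerMessage;

let msg = WorkerMessage::SplitComm {
    dims: vec!["x".into()],
    device_mesh: 1.into(),
    stream: 0.into(),
};
```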
15 changes: 3 additions & 12 deletions monarch_tensor_worker/src/comm.rs
@@ -25,7 +25,6 @@ use tokio::task::spawn_blocking;
use torch_sys_cuda::cuda::Event;
use torch_sys_cuda::cuda::Stream;
use torch_sys_cuda::nccl::Communicator;
use torch_sys_cuda::nccl::NcclConfig;
use torch_sys_cuda::nccl::NcclError;
use torch_sys_cuda::nccl::NcclStatus;
use torch_sys_cuda::nccl::ReduceOp;
@@ -90,14 +89,10 @@ pub enum CommMessage {

Group(Vec<CommMessage>, Stream, #[reply] OncePortHandle<Event>),

SplitAll(
Option<NcclConfig>,
#[reply] OncePortHandle<ActorHandle<NcclCommActor>>,
),
SplitAll(#[reply] OncePortHandle<ActorHandle<NcclCommActor>>),

SplitFrom(
Vec<i32>,
Option<NcclConfig>,
#[reply] OncePortHandle<Option<ActorHandle<NcclCommActor>>>,
),
}
@@ -224,11 +219,10 @@ impl CommMessageHandler for NcclCommActor {
async fn split_all(
&mut self,
cx: &hyperactor::Context<Self>,
nccl_config: Option<NcclConfig>,
) -> Result<ActorHandle<NcclCommActor>> {
let comm = self.comm.clone();

let split_comm = spawn_blocking(move || comm.lock().split_all(nccl_config))
let split_comm = spawn_blocking(move || comm.lock().split_all())
.await
.unwrap()?;

@@ -241,11 +235,10 @@
&mut self,
cx: &hyperactor::Context<Self>,
ranks: Vec<i32>,
nccl_config: Option<NcclConfig>,
) -> Result<Option<ActorHandle<NcclCommActor>>> {
let comm = self.comm.clone();

let split_comm = spawn_blocking(move || comm.lock().split_from(ranks, nccl_config))
let split_comm = spawn_blocking(move || comm.lock().split_from(ranks))
.await
.unwrap()?;

@@ -691,7 +684,6 @@ mod tests {
dims: vec!["x".into()],
device_mesh: 1.into(),
stream: 0.into(),
config: None,
},
WorkerMessage::CallFunction(CallFunctionParams {
seq: 0.into(),
@@ -754,7 +746,6 @@
dims: vec!["x".into(), "y".into()],
device_mesh: 1.into(),
stream: 0.into(),
config: None,
},
// Test reduce over "x" and "y".
WorkerMessage::Reduce {
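For reference, and not part of the diff: a minimal sketch of driving the simplified NcclCommActor API, assuming `comm` is an `ActorHandle<NcclCommActor>`, `cx` is a hyperactor context, and the generated CommMessageClient trait is in scope, as at the call sites updated in monarch_tensor_worker/src/lib.rs below:

```rust
// Sketch only: the client-side split calls no longer take an
// Option<NcclConfig> argument.
let world_comm = comm.split_all(cx).await?;            // same world size as the parent
let sub_comm = comm.split_from(cx, vec![0, 1]).await?; // Some(..) only on ranks 0 and 1
```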
13 changes: 4 additions & 9 deletions monarch_tensor_worker/src/lib.rs
@@ -90,7 +90,6 @@ use sorted_vec::SortedVec;
use stream::StreamActor;
use stream::StreamMessageClient;
use stream::StreamParams;
use torch_sys_cuda::nccl::NcclConfig;
use torch_sys_cuda::nccl::ReduceOp;
use torch_sys_cuda::nccl::UniqueId;
use torch_sys2::CudaDevice;
@@ -340,7 +339,7 @@ impl WorkerMessageHandler for WorkerActor {
for _ in 0..sorted_streams.len() {
// Do the split in this event loop, to provide a deterministic
// order.
splits.push(comm.split_all(cx, None).await?);
splits.push(comm.split_all(cx).await?);
}
let _: Vec<()> = try_join_all(
sorted_streams
@@ -371,7 +370,7 @@ impl WorkerMessageHandler for WorkerActor {
.comm
.as_ref()
.context("tried to call Reduce before BackendNetworkInit")?;
let comm = global_comm.split_all(cx, None).await?;
let comm = global_comm.split_all(cx).await?;
self.send_recv_comms
.insert((from_stream, to_stream), Arc::new(comm));
Ok(())
@@ -803,7 +802,6 @@ impl WorkerMessageHandler for WorkerActor {
dims: Vec<String>,
device_mesh: Ref,
stream_ref: StreamRef,
config: Option<NcclConfig>,
) -> Result<()> {
let global_comm = self
.comm
@@ -833,7 +831,6 @@
.into_iter()
.map(|v| v.clone().try_into())
.collect::<Result<Vec<_>, _>>()?,
config,
)
.await?
.context("split comm should include self rank")?;
@@ -842,7 +839,7 @@
None => {
// This rank is not in the group to be split off. We still need to
// participate in the commSplit call, however.
global_comm.split_from(cx, vec![], config).await?;
global_comm.split_from(cx, vec![]).await?;
}
}
Ok(())
@@ -853,7 +850,6 @@
cx: &hyperactor::Context<Self>,
remote_process_group_ref: Ref,
stream_ref: StreamRef,
config: Option<NcclConfig>,
) -> Result<()> {
ensure!(
self.streams.contains_key(&stream_ref),
@@ -888,7 +884,6 @@
.into_iter()
.map(|v| v.clone().try_into())
.collect::<Result<Vec<_>, _>>()?,
config,
)
.await?
.context("split comm should include self rank")?;
@@ -897,7 +892,7 @@
None => {
// This rank is not in the group to be split off. We still need to
// participate in the commSplit call, however.
global_comm.split_from(cx, vec![], config).await?;
global_comm.split_from(cx, vec![]).await?;
}
}
Ok(())
1 change: 0 additions & 1 deletion nccl-sys/build.rs
@@ -68,7 +68,6 @@ fn main() {
.allowlist_type("ncclDataType_t")
.allowlist_type("ncclRedOp_t")
.allowlist_type("ncclScalarResidence_t")
.allowlist_type("ncclConfig_t")
.allowlist_type("ncclSimInfo_t")
.allowlist_var("NCCL_SPLIT_NOCOLOR")
.allowlist_var("NCCL_MAJOR")
7 changes: 0 additions & 7 deletions nccl-sys/src/lib.rs
@@ -15,13 +15,6 @@ unsafe impl ExternType for CUstream_st {
type Kind = cxx::kind::Opaque;
}

/// SAFETY: bindings
/// Trivial because this is POD struct
unsafe impl ExternType for ncclConfig_t {
type Id = type_id!("ncclConfig_t");
type Kind = cxx::kind::Trivial;
}

/// SAFETY: bindings
unsafe impl ExternType for ncclComm {
type Id = type_id!("ncclComm");
14 changes: 1 addition & 13 deletions torch-sys-cuda/src/bridge.h
@@ -8,16 +8,4 @@

#pragma once

#include <nccl.h> // @manual

namespace monarch {

/// This function exists because ncclConfig initialization requires the use of
/// a macro. We cannot reference the macro directly from Rust code, so we wrap
/// the macro use in a function and bind that to Rust instead.
inline ncclConfig_t make_nccl_config() {
ncclConfig_t ret = NCCL_CONFIG_INITIALIZER;
return ret;
}

} // namespace monarch
namespace monarch {} // namespace monarch
5 changes: 0 additions & 5 deletions torch-sys-cuda/src/bridge.rs
@@ -11,10 +11,5 @@
pub(crate) mod ffi {
unsafe extern "C++" {
include!("monarch/torch-sys-cuda/src/bridge.h");

// nccl helpers
#[namespace = ""]
type ncclConfig_t = nccl_sys::ncclConfig_t;
fn make_nccl_config() -> ncclConfig_t;
}
}
99 changes: 11 additions & 88 deletions torch-sys-cuda/src/nccl.rs
@@ -6,7 +6,6 @@
* LICENSE file in the root directory of this source tree.
*/

use std::ffi::CString;
use std::fmt;
use std::fmt::Write;
use std::hash::Hasher;
@@ -26,7 +25,6 @@ use torch_sys2::TensorCell;
use torch_sys2::factory_float_tensor;
use torch_sys2::is_float8_type;

use crate::bridge::ffi::make_nccl_config;
use crate::cuda::CudaError;
use crate::cuda::Stream;
use crate::cuda::set_device;
@@ -100,60 +98,6 @@ pub enum NcclStatus {
InProgress,
}

/// Rust version of ncclConfig_t. See nccl documentation for what each field
/// means:
/// https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/api/types.html#ncclconfig-t
///
/// Note that we don't validate field values; we rely on nccl to do that.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct NcclConfig {
pub blocking: bool,
pub cga_cluster_size: u8,
pub min_ctas: u8,
pub max_ctas: u8,
pub net_name: Option<String>,
pub split_share: bool,
}

impl Default for NcclConfig {
fn default() -> Self {
NcclConfig {
blocking: true,
cga_cluster_size: 4,
min_ctas: 1,
max_ctas: 32,
net_name: None,
split_share: false,
}
}
}

impl From<NcclConfig> for ncclConfig_t {
fn from(config: NcclConfig) -> Self {
let mut ret = make_nccl_config();
ret.blocking = config.blocking.into();
ret.cgaClusterSize = config.cga_cluster_size.into();
ret.minCTAs = config.min_ctas.into();
ret.maxCTAs = config.max_ctas.into();
if let Some(net_name) = config.net_name {
let c_string = CString::new(net_name)
.expect("failed to create CString")
.into_boxed_c_str();

// Just leak the string to avoid complicated ownership issues. I'm
// not aware of anywhere where we actually want to specify the
// network module name in configuration instead of letting nccl just
// choose it for us. If this happens + we are creating tons of
// config objects, we can revisit this.
let ptr = Box::leak(c_string).as_ptr();
ret.netName = ptr;
}
ret.splitShare = config.split_share.into();

ret
}
}

fn nccl_check(result: ncclResult_t) -> Result<NcclStatus, RawNcclError> {
match result.0 {
0 => Ok(NcclStatus::Success),
@@ -383,9 +327,9 @@ impl Communicator {

/// Split off a new communicator from this one, preserving the same world
/// size.
pub fn split_all(&mut self, config: Option<NcclConfig>) -> Result<Self, NcclError> {
pub fn split_all(&mut self) -> Result<Self, NcclError> {
let ranks = (0..self.global_world_size).collect();
Ok(self.split_from(ranks, config)?.unwrap())
Ok(self.split_from(ranks)?.unwrap())
}

/// Split off a new communicator from this one. Only `ranks` will be present
@@ -394,11 +338,7 @@
/// If `ranks` is empty, `ncclCommSplit` will be called with
/// NCCL_SPLIT_NOCOLOR. This can be useful if ranks excluded from the split
/// don't even know what ranks will be included.
pub fn split_from(
&mut self,
mut ranks: Vec<i32>,
config: Option<NcclConfig>,
) -> Result<Option<Self>, NcclError> {
pub fn split_from(&mut self, mut ranks: Vec<i32>) -> Result<Option<Self>, NcclError> {
ranks.sort();
for rank in &ranks {
if *rank < 0 || *rank >= self.global_world_size {
@@ -411,34 +351,17 @@
Err(_) => NCCL_SPLIT_NOCOLOR,
};

let config = config.map(ncclConfig_t::from);
let mut new = MaybeUninit::uninit();

// SAFETY: intended use of C function
let new = unsafe {
// This rather awkward duplication is intentional; we are passing in
// `config` as a pointer, which is only guaranteed to be valid for
// the duration of `Some(mut config)` match arm.
match config {
Some(mut config) => {
nccl_check(ncclCommSplit(
self.inner,
color,
self.rank,
new.as_mut_ptr(),
&mut config,
))?;
}
None => {
nccl_check(ncclCommSplit(
self.inner,
color,
self.rank,
new.as_mut_ptr(),
std::ptr::null_mut(),
))?;
}
}
nccl_check(ncclCommSplit(
self.inner,
color,
self.rank,
new.as_mut_ptr(),
std::ptr::null_mut(),
))?;
new.assume_init()
};

@@ -1098,7 +1021,7 @@ mod tests {
let mut comm = Communicator::new(device, 2, unique_id, i.into()).unwrap();

// Split a new comm with only rank 0
let split_comm = comm.split_from(vec![0], None).unwrap();
let split_comm = comm.split_from(vec![0]).unwrap();

match i {
0 => assert!(split_comm.is_some()),
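For reference, and not part of the diff: a rough usage sketch of the trimmed Communicator API, patterned on the updated test above (`device`, `unique_id`, and `rank` stand in for whatever setup the caller already has):

```rust
// Sketch only: split_all/split_from no longer take an Option<NcclConfig>.
// ncclCommSplit is now always called with a null config pointer, so the
// child communicator inherits its parent's configuration.
use torch_sys_cuda::nccl::Communicator;

let mut comm = Communicator::new(device, 2, unique_id, rank.into())?;
let full_copy = comm.split_all()?;          // same ranks as the parent comm
let rank0_only = comm.split_from(vec![0])?; // Some(..) on rank 0, None elsewhere
```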