Commit eb32490

Remove ncclConfig_t bindings
Pull Request resolved: #2086

These were never used, but they complicate how we have to generate bindings to nccl.

ghstack-source-id: 327879735
@exported-using-ghexport

Differential Revision: [D88656649](https://our.internmc.facebook.com/intern/diff/D88656649/)
1 parent 0708dd3 commit eb32490
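For context, the removed `NcclConfig` struct mirrored NCCL's `ncclConfig_t` and was threaded through the communicator-split APIs as an `Option`, but every call site passed `None`. After this change the split methods take no config argument and `ncclCommSplit` always receives a null config, so the new communicator inherits its parent's settings. A minimal caller-side sketch of the post-change API, assuming a `Communicator` obtained elsewhere (the helper function and the rank list are illustrative, not part of this commit):

use torch_sys_cuda::nccl::{Communicator, NcclError};

// Illustrative helper: split communicators with the post-change API.
// There is no config parameter anymore; the children inherit the parent's
// NCCL configuration.
fn split_examples(comm: &mut Communicator) -> Result<(), NcclError> {
    // New communicator spanning the same ranks as the parent.
    let _same_world = comm.split_all()?;

    // New communicator containing only ranks 0-3; returns Ok(None) on ranks
    // that are not part of the split.
    let _subgroup = comm.split_from(vec![0, 1, 2, 3])?;

    Ok(())
}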

9 files changed (+18, -143 lines)

monarch_extension/src/convert.rs

Lines changed: 0 additions & 2 deletions
@@ -355,14 +355,12 @@ fn create_map(py: Python) -> HashMap<u64, FnType> {
             dims: p.parse("dims")?,
             device_mesh: p.parseRef("device_mesh")?,
             stream: p.parseStreamRef("stream")?,
-            config: None,
         })
     });
     m.insert(key("SplitCommForProcessGroup"), |p| {
         Ok(WorkerMessage::SplitCommForProcessGroup {
             remote_process_group: p.parseRef("remote_process_group")?,
             stream: p.parseStreamRef("stream")?,
-            config: None,
         })
     });
     m.insert(key("DefineRecording"), |p| {

monarch_messages/src/worker.rs

Lines changed: 0 additions & 9 deletions
@@ -38,7 +38,6 @@ use pyo3::types::PyTuple;
 use serde::Deserialize;
 use serde::Serialize;
 use thiserror::Error;
-use torch_sys_cuda::nccl::NcclConfig;
 use torch_sys_cuda::nccl::ReduceOp;
 use torch_sys_cuda::nccl::UniqueId;
 use torch_sys2::BorrowError;
@@ -800,10 +799,6 @@ pub enum WorkerMessage {
         /// will be ordered with respect to other operations scheduled on this
         /// stream.
         stream: StreamRef,
-        /// Configuration for the new communicator. If None, we will not pass a
-        /// config object to nccl, which means that the created communicator
-        /// will inherit its parent's config.
-        config: Option<NcclConfig>,
     },

     /// Create a new communicator on each rank in `ranks`, capable of
@@ -816,10 +811,6 @@
         /// will be ordered with respect to other operations scheduled on this
         /// stream.
         stream: StreamRef,
-        /// Configuration for the new communicator. If None, we will not pass a
-        /// config object to nccl, which means that the created communicator
-        /// will inherit its parent's config.
-        config: Option<NcclConfig>,
     },

     SendTensor {

monarch_tensor_worker/src/comm.rs

Lines changed: 3 additions & 10 deletions
@@ -25,7 +25,6 @@ use tokio::task::spawn_blocking;
 use torch_sys_cuda::cuda::Event;
 use torch_sys_cuda::cuda::Stream;
 use torch_sys_cuda::nccl::Communicator;
-use torch_sys_cuda::nccl::NcclConfig;
 use torch_sys_cuda::nccl::NcclError;
 use torch_sys_cuda::nccl::NcclStatus;
 use torch_sys_cuda::nccl::ReduceOp;
@@ -90,14 +89,10 @@ pub enum CommMessage {

     Group(Vec<CommMessage>, Stream, #[reply] OncePortHandle<Event>),

-    SplitAll(
-        Option<NcclConfig>,
-        #[reply] OncePortHandle<ActorHandle<NcclCommActor>>,
-    ),
+    SplitAll(#[reply] OncePortHandle<ActorHandle<NcclCommActor>>),

     SplitFrom(
         Vec<i32>,
-        Option<NcclConfig>,
         #[reply] OncePortHandle<Option<ActorHandle<NcclCommActor>>>,
     ),
 }
@@ -224,11 +219,10 @@ impl CommMessageHandler for NcclCommActor {
     async fn split_all(
         &mut self,
         cx: &hyperactor::Context<Self>,
-        nccl_config: Option<NcclConfig>,
     ) -> Result<ActorHandle<NcclCommActor>> {
         let comm = self.comm.clone();

-        let split_comm = spawn_blocking(move || comm.lock().split_all(nccl_config))
+        let split_comm = spawn_blocking(move || comm.lock().split_all())
             .await
             .unwrap()?;

@@ -241,11 +235,10 @@
         &mut self,
         cx: &hyperactor::Context<Self>,
         ranks: Vec<i32>,
-        nccl_config: Option<NcclConfig>,
     ) -> Result<Option<ActorHandle<NcclCommActor>>> {
         let comm = self.comm.clone();

-        let split_comm = spawn_blocking(move || comm.lock().split_from(ranks, nccl_config))
+        let split_comm = spawn_blocking(move || comm.lock().split_from(ranks))
             .await
             .unwrap()?;


monarch_tensor_worker/src/lib.rs

Lines changed: 4 additions & 9 deletions
@@ -90,7 +90,6 @@ use sorted_vec::SortedVec;
 use stream::StreamActor;
 use stream::StreamMessageClient;
 use stream::StreamParams;
-use torch_sys_cuda::nccl::NcclConfig;
 use torch_sys_cuda::nccl::ReduceOp;
 use torch_sys_cuda::nccl::UniqueId;
 use torch_sys2::CudaDevice;
@@ -340,7 +339,7 @@ impl WorkerMessageHandler for WorkerActor {
         for _ in 0..sorted_streams.len() {
             // Do the split in this event loop, to provide a deterministic
             // order.
-            splits.push(comm.split_all(cx, None).await?);
+            splits.push(comm.split_all(cx).await?);
         }
         let _: Vec<()> = try_join_all(
             sorted_streams
@@ -371,7 +370,7 @@
             .comm
             .as_ref()
             .context("tried to call Reduce before BackendNetworkInit")?;
-        let comm = global_comm.split_all(cx, None).await?;
+        let comm = global_comm.split_all(cx).await?;
         self.send_recv_comms
             .insert((from_stream, to_stream), Arc::new(comm));
         Ok(())
@@ -803,7 +802,6 @@
         dims: Vec<String>,
         device_mesh: Ref,
         stream_ref: StreamRef,
-        config: Option<NcclConfig>,
     ) -> Result<()> {
         let global_comm = self
             .comm
@@ -833,7 +831,6 @@
                     .into_iter()
                     .map(|v| v.clone().try_into())
                     .collect::<Result<Vec<_>, _>>()?,
-                config,
             )
             .await?
             .context("split comm should include self rank")?;
@@ -842,7 +839,7 @@
             None => {
                 // This rank is not in the group to be split off. We still need to
                 // participate in the commSplit call, however.
-                global_comm.split_from(cx, vec![], config).await?;
+                global_comm.split_from(cx, vec![]).await?;
             }
         }
         Ok(())
@@ -853,7 +850,6 @@
         cx: &hyperactor::Context<Self>,
         remote_process_group_ref: Ref,
         stream_ref: StreamRef,
-        config: Option<NcclConfig>,
     ) -> Result<()> {
         ensure!(
             self.streams.contains_key(&stream_ref),
@@ -888,7 +884,6 @@
                     .into_iter()
                     .map(|v| v.clone().try_into())
                     .collect::<Result<Vec<_>, _>>()?,
-                config,
             )
             .await?
             .context("split comm should include self rank")?;
@@ -897,7 +892,7 @@
             None => {
                 // This rank is not in the group to be split off. We still need to
                 // participate in the commSplit call, however.
-                global_comm.split_from(cx, vec![], config).await?;
+                global_comm.split_from(cx, vec![]).await?;
             }
         }
         Ok(())

nccl-sys/build.rs

Lines changed: 0 additions & 1 deletion
@@ -68,7 +68,6 @@ fn main() {
         .allowlist_type("ncclDataType_t")
         .allowlist_type("ncclRedOp_t")
         .allowlist_type("ncclScalarResidence_t")
-        .allowlist_type("ncclConfig_t")
         .allowlist_type("ncclSimInfo_t")
         .allowlist_var("NCCL_SPLIT_NOCOLOR")
         .allowlist_var("NCCL_MAJOR")

nccl-sys/src/lib.rs

Lines changed: 0 additions & 7 deletions
@@ -15,13 +15,6 @@ unsafe impl ExternType for CUstream_st {
     type Kind = cxx::kind::Opaque;
 }

-/// SAFETY: bindings
-/// Trivial because this is POD struct
-unsafe impl ExternType for ncclConfig_t {
-    type Id = type_id!("ncclConfig_t");
-    type Kind = cxx::kind::Trivial;
-}
-
 /// SAFETY: bindings
 unsafe impl ExternType for ncclComm {
     type Id = type_id!("ncclComm");

torch-sys-cuda/src/bridge.h

Lines changed: 1 addition & 13 deletions
@@ -8,16 +8,4 @@

 #pragma once

-#include <nccl.h> // @manual
-
-namespace monarch {
-
-/// This function exists because ncclConfig initialization requires the use of
-/// a macro. We cannot reference the macro directly from Rust code, so we wrap
-/// the macro use in a function and bind that to Rust instead.
-inline ncclConfig_t make_nccl_config() {
-  ncclConfig_t ret = NCCL_CONFIG_INITIALIZER;
-  return ret;
-}
-
-} // namespace monarch
+namespace monarch {} // namespace monarch

torch-sys-cuda/src/bridge.rs

Lines changed: 0 additions & 5 deletions
@@ -11,10 +11,5 @@
 pub(crate) mod ffi {
     unsafe extern "C++" {
         include!("monarch/torch-sys-cuda/src/bridge.h");
-
-        // nccl helpers
-        #[namespace = ""]
-        type ncclConfig_t = nccl_sys::ncclConfig_t;
-        fn make_nccl_config() -> ncclConfig_t;
     }
 }

torch-sys-cuda/src/nccl.rs

Lines changed: 10 additions & 87 deletions
@@ -6,7 +6,6 @@
  * LICENSE file in the root directory of this source tree.
  */

-use std::ffi::CString;
 use std::fmt;
 use std::fmt::Write;
 use std::hash::Hasher;
@@ -26,7 +25,6 @@ use torch_sys2::TensorCell;
 use torch_sys2::factory_float_tensor;
 use torch_sys2::is_float8_type;

-use crate::bridge::ffi::make_nccl_config;
 use crate::cuda::CudaError;
 use crate::cuda::Stream;
 use crate::cuda::set_device;
@@ -100,60 +98,6 @@ pub enum NcclStatus {
     InProgress,
 }

-/// Rust version of ncclConfig_t. See nccl documentation for what each field
-/// means:
-/// https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/api/types.html#ncclconfig-t
-///
-/// Note that we don't validate field values; we rely on nccl to do that.
-#[derive(Debug, Clone, Serialize, Deserialize)]
-pub struct NcclConfig {
-    pub blocking: bool,
-    pub cga_cluster_size: u8,
-    pub min_ctas: u8,
-    pub max_ctas: u8,
-    pub net_name: Option<String>,
-    pub split_share: bool,
-}
-
-impl Default for NcclConfig {
-    fn default() -> Self {
-        NcclConfig {
-            blocking: true,
-            cga_cluster_size: 4,
-            min_ctas: 1,
-            max_ctas: 32,
-            net_name: None,
-            split_share: false,
-        }
-    }
-}
-
-impl From<NcclConfig> for ncclConfig_t {
-    fn from(config: NcclConfig) -> Self {
-        let mut ret = make_nccl_config();
-        ret.blocking = config.blocking.into();
-        ret.cgaClusterSize = config.cga_cluster_size.into();
-        ret.minCTAs = config.min_ctas.into();
-        ret.maxCTAs = config.max_ctas.into();
-        if let Some(net_name) = config.net_name {
-            let c_string = CString::new(net_name)
-                .expect("failed to create CString")
-                .into_boxed_c_str();
-
-            // Just leak the string to avoid complicated ownership issues. I'm
-            // not aware of anywhere where we actually want to specify the
-            // network module name in configuration instead of letting nccl just
-            // choose it for us. If this happens + we are creating tons of
-            // config objects, we can revisit this.
-            let ptr = Box::leak(c_string).as_ptr();
-            ret.netName = ptr;
-        }
-        ret.splitShare = config.split_share.into();
-
-        ret
-    }
-}
-
 fn nccl_check(result: ncclResult_t) -> Result<NcclStatus, RawNcclError> {
     match result.0 {
         0 => Ok(NcclStatus::Success),
@@ -383,9 +327,9 @@ impl Communicator {

     /// Split off a new communicator from this one, preserving the same world
     /// size.
-    pub fn split_all(&mut self, config: Option<NcclConfig>) -> Result<Self, NcclError> {
+    pub fn split_all(&mut self) -> Result<Self, NcclError> {
         let ranks = (0..self.global_world_size).collect();
-        Ok(self.split_from(ranks, config)?.unwrap())
+        Ok(self.split_from(ranks)?.unwrap())
     }

     /// Split off a new communicator from this one. Only `ranks` will be present
@@ -394,11 +338,7 @@
     /// If `ranks` is empty, `ncclCommSplit` will be called with
     /// NCCL_SPLIT_NOCOLOR. This can be useful if ranks excluded from the split
     /// don't even know what ranks will be included.
-    pub fn split_from(
-        &mut self,
-        mut ranks: Vec<i32>,
-        config: Option<NcclConfig>,
-    ) -> Result<Option<Self>, NcclError> {
+    pub fn split_from(&mut self, mut ranks: Vec<i32>) -> Result<Option<Self>, NcclError> {
         ranks.sort();
         for rank in &ranks {
             if *rank < 0 || *rank >= self.global_world_size {
@@ -411,34 +351,17 @@
             Err(_) => NCCL_SPLIT_NOCOLOR,
         };

-        let config = config.map(ncclConfig_t::from);
         let mut new = MaybeUninit::uninit();

         // SAFETY: intended use of C function
         let new = unsafe {
-            // This rather awkward duplication is intentional; we are passing in
-            // `config` as a pointer, which is only guaranteed to be valid for
-            // the duration of `Some(mut config)` match arm.
-            match config {
-                Some(mut config) => {
-                    nccl_check(ncclCommSplit(
-                        self.inner,
-                        color,
-                        self.rank,
-                        new.as_mut_ptr(),
-                        &mut config,
-                    ))?;
-                }
-                None => {
-                    nccl_check(ncclCommSplit(
-                        self.inner,
-                        color,
-                        self.rank,
-                        new.as_mut_ptr(),
-                        std::ptr::null_mut(),
-                    ))?;
-                }
-            }
+            nccl_check(ncclCommSplit(
+                self.inner,
+                color,
+                self.rank,
+                new.as_mut_ptr(),
+                std::ptr::null_mut(),
+            ))?;
             new.assume_init()
         };

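With the config gone, `split_from` always passes `std::ptr::null_mut()` as the config to `ncclCommSplit`, so the split communicator inherits its parent's configuration, exactly what the old `config: None` path already did. One detail the worker code above relies on is that every rank must participate in the split, even ranks excluded from the new group. A sketch of that pattern, assuming a `Communicator` and with the helper name and `my_rank` argument being illustrative only:

use torch_sys_cuda::nccl::{Communicator, NcclError};

// Illustrative sketch: all ranks call split_from, but ranks outside `group`
// pass an empty vec, which maps to NCCL_SPLIT_NOCOLOR and yields Ok(None).
fn split_subgroup(
    comm: &mut Communicator,
    my_rank: i32,
    group: Vec<i32>,
) -> Result<Option<Communicator>, NcclError> {
    if group.contains(&my_rank) {
        // This rank joins the new communicator.
        comm.split_from(group)
    } else {
        // Excluded ranks still participate in the collective ncclCommSplit call.
        comm.split_from(vec![])
    }
}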
