66 * LICENSE file in the root directory of this source tree.
77 */
88
9- use std:: ffi:: CString ;
109use std:: fmt;
1110use std:: fmt:: Write ;
1211use std:: hash:: Hasher ;
@@ -26,7 +25,6 @@ use torch_sys2::TensorCell;
2625use torch_sys2:: factory_float_tensor;
2726use torch_sys2:: is_float8_type;
2827
29- use crate :: bridge:: ffi:: make_nccl_config;
3028use crate :: cuda:: CudaError ;
3129use crate :: cuda:: Stream ;
3230use crate :: cuda:: set_device;
@@ -100,60 +98,6 @@ pub enum NcclStatus {
10098 InProgress ,
10199}
102100
103- /// Rust version of ncclConfig_t. See nccl documentation for what each field
104- /// means:
105- /// https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/api/types.html#ncclconfig-t
106- ///
107- /// Note that we don't validate field values; we rely on nccl to do that.
108- #[ derive( Debug , Clone , Serialize , Deserialize ) ]
109- pub struct NcclConfig {
110- pub blocking : bool ,
111- pub cga_cluster_size : u8 ,
112- pub min_ctas : u8 ,
113- pub max_ctas : u8 ,
114- pub net_name : Option < String > ,
115- pub split_share : bool ,
116- }
117-
118- impl Default for NcclConfig {
119- fn default ( ) -> Self {
120- NcclConfig {
121- blocking : true ,
122- cga_cluster_size : 4 ,
123- min_ctas : 1 ,
124- max_ctas : 32 ,
125- net_name : None ,
126- split_share : false ,
127- }
128- }
129- }
130-
131- impl From < NcclConfig > for ncclConfig_t {
132- fn from ( config : NcclConfig ) -> Self {
133- let mut ret = make_nccl_config ( ) ;
134- ret. blocking = config. blocking . into ( ) ;
135- ret. cgaClusterSize = config. cga_cluster_size . into ( ) ;
136- ret. minCTAs = config. min_ctas . into ( ) ;
137- ret. maxCTAs = config. max_ctas . into ( ) ;
138- if let Some ( net_name) = config. net_name {
139- let c_string = CString :: new ( net_name)
140- . expect ( "failed to create CString" )
141- . into_boxed_c_str ( ) ;
142-
143- // Just leak the string to avoid complicated ownership issues. I'm
144- // not aware of anywhere where we actually want to specify the
145- // network module name in configuration instead of letting nccl just
146- // choose it for us. If this happens + we are creating tons of
147- // config objects, we can revisit this.
148- let ptr = Box :: leak ( c_string) . as_ptr ( ) ;
149- ret. netName = ptr;
150- }
151- ret. splitShare = config. split_share . into ( ) ;
152-
153- ret
154- }
155- }
156-
157101fn nccl_check ( result : ncclResult_t ) -> Result < NcclStatus , RawNcclError > {
158102 match result. 0 {
159103 0 => Ok ( NcclStatus :: Success ) ,
@@ -383,9 +327,9 @@ impl Communicator {
383327
384328 /// Split off a new communicator from this one, preserving the same world
385329 /// size.
386- pub fn split_all ( & mut self , config : Option < NcclConfig > ) -> Result < Self , NcclError > {
330+ pub fn split_all ( & mut self ) -> Result < Self , NcclError > {
387331 let ranks = ( 0 ..self . global_world_size ) . collect ( ) ;
388- Ok ( self . split_from ( ranks, config ) ?. unwrap ( ) )
332+ Ok ( self . split_from ( ranks) ?. unwrap ( ) )
389333 }
390334
391335 /// Split off a new communicator from this one. Only `ranks` will be present
@@ -394,11 +338,7 @@ impl Communicator {
394338 /// If `ranks` is empty, `ncclCommSplit` will be called with
395339 /// NCCL_SPLIT_NOCOLOR. This can be useful if ranks excluded from the split
396340 /// don't even know what ranks will be included.
397- pub fn split_from (
398- & mut self ,
399- mut ranks : Vec < i32 > ,
400- config : Option < NcclConfig > ,
401- ) -> Result < Option < Self > , NcclError > {
341+ pub fn split_from ( & mut self , mut ranks : Vec < i32 > ) -> Result < Option < Self > , NcclError > {
402342 ranks. sort ( ) ;
403343 for rank in & ranks {
404344 if * rank < 0 || * rank >= self . global_world_size {
@@ -411,34 +351,17 @@ impl Communicator {
411351 Err ( _) => NCCL_SPLIT_NOCOLOR ,
412352 } ;
413353
414- let config = config. map ( ncclConfig_t:: from) ;
415354 let mut new = MaybeUninit :: uninit ( ) ;
416355
417356 // SAFETY: intended use of C function
418357 let new = unsafe {
419- // This rather awkward duplication is intentional; we are passing in
420- // `config` as a pointer, which is only guaranteed to be valid for
421- // the duration of `Some(mut config)` match arm.
422- match config {
423- Some ( mut config) => {
424- nccl_check ( ncclCommSplit (
425- self . inner ,
426- color,
427- self . rank ,
428- new. as_mut_ptr ( ) ,
429- & mut config,
430- ) ) ?;
431- }
432- None => {
433- nccl_check ( ncclCommSplit (
434- self . inner ,
435- color,
436- self . rank ,
437- new. as_mut_ptr ( ) ,
438- std:: ptr:: null_mut ( ) ,
439- ) ) ?;
440- }
441- }
358+ nccl_check ( ncclCommSplit (
359+ self . inner ,
360+ color,
361+ self . rank ,
362+ new. as_mut_ptr ( ) ,
363+ std:: ptr:: null_mut ( ) ,
364+ ) ) ?;
442365 new. assume_init ( )
443366 } ;
444367
0 commit comments