@@ -128,7 +128,6 @@ impl DMatrix {
128128 pub fn from_csr ( indptr : & [ usize ] , indices : & [ usize ] , data : & [ f32 ] , num_cols : Option < usize > ) -> XGBResult < Self > {
129129 assert_eq ! ( indices. len( ) , data. len( ) ) ;
130130 let mut handle = ptr:: null_mut ( ) ;
131- let indptr: Vec < u64 > = indptr. iter ( ) . map ( |x| * x as u64 ) . collect ( ) ;
132131 let indices: Vec < u32 > = indices. iter ( ) . map ( |x| * x as u32 ) . collect ( ) ;
133132 let num_cols = num_cols. unwrap_or ( 0 ) ; // infer from data if 0
134133 xgb_call ! ( xgboost_sys:: XGDMatrixCreateFromCSREx ( indptr. as_ptr( ) ,
@@ -152,7 +151,6 @@ impl DMatrix {
152151 pub fn from_csc ( indptr : & [ usize ] , indices : & [ usize ] , data : & [ f32 ] , num_rows : Option < usize > ) -> XGBResult < Self > {
153152 assert_eq ! ( indices. len( ) , data. len( ) ) ;
154153 let mut handle = ptr:: null_mut ( ) ;
155- let indptr: Vec < u64 > = indptr. iter ( ) . map ( |x| * x as u64 ) . collect ( ) ;
156154 let indices: Vec < u32 > = indices. iter ( ) . map ( |x| * x as u32 ) . collect ( ) ;
157155 let num_rows = num_rows. unwrap_or ( 0 ) ; // infer from data if 0
158156 xgb_call ! ( xgboost_sys:: XGDMatrixCreateFromCSCEx ( indptr. as_ptr( ) ,
@@ -349,7 +347,7 @@ mod tests {
349347
350348 #[ test]
351349 fn read_num_cols ( ) {
352- assert_eq ! ( read_train_matrix( ) . unwrap( ) . num_cols( ) , 126 ) ;
350+ assert_eq ! ( read_train_matrix( ) . unwrap( ) . num_cols( ) , 127 ) ;
353351 }
354352
355353 #[ test]
@@ -380,7 +378,7 @@ mod tests {
380378 #[ test]
381379 fn get_set_weights ( ) {
382380 let mut dmat = read_train_matrix ( ) . unwrap ( ) ;
383- assert_eq ! ( dmat. get_weights( ) . unwrap( ) , & [ ] ) ;
381+ assert ! ( dmat. get_weights( ) . unwrap( ) . is_empty ( ) ) ;
384382
385383 let weight = [ 1.0 , 10.0 , 44.9555 ] ;
386384 assert ! ( dmat. set_weights( & weight) . is_ok( ) ) ;
@@ -390,17 +388,20 @@ mod tests {
390388 #[ test]
391389 fn get_set_base_margin ( ) {
392390 let mut dmat = read_train_matrix ( ) . unwrap ( ) ;
393- assert_eq ! ( dmat. get_base_margin( ) . unwrap( ) , & [ ] ) ;
391+ assert ! ( dmat. get_base_margin( ) . unwrap( ) . is_empty ( ) ) ;
394392
395393 let base_margin = [ 0.00001 , 0.000002 , 1.23 ] ;
394+ println ! ( "rows: {:?}, {:?}" , dmat. num_rows( ) , base_margin. len( ) ) ;
395+ let result = dmat. set_base_margin ( & base_margin) ;
396+ println ! ( "{:?}" , result) ;
396397 assert ! ( dmat. set_base_margin( & base_margin) . is_ok( ) ) ;
397398 assert_eq ! ( dmat. get_base_margin( ) . unwrap( ) , base_margin) ;
398399 }
399400
400401 #[ test]
401402 fn get_set_group ( ) {
402403 let mut dmat = read_train_matrix ( ) . unwrap ( ) ;
403- assert_eq ! ( dmat. get_group( ) . unwrap( ) , & [ ] ) ;
404+ assert ! ( dmat. get_group( ) . unwrap( ) . is_empty ( ) ) ;
404405
405406 let group = [ 1 ] ;
406407 assert ! ( dmat. set_group( & group) . is_ok( ) ) ;
0 commit comments