@@ -1165,15 +1165,15 @@ impl Tensor {
11651165
11661166 #[ cfg( feature = "rand" ) ]
11671167 pub ( crate ) fn randn_impl < T : TensorDType + num_traits:: Float > (
1168- mean : f32 ,
1169- std : f32 ,
1168+ mean : T ,
1169+ std : T ,
11701170 shape : Shape ,
11711171 device : Device ,
11721172 is_variable : bool ,
11731173 ) -> Result < Self > {
11741174 let rng = device. get_rng ( ) ;
11751175 if device. is_cpu ( ) {
1176- let distr = Normal :: new ( mean as f64 , std as f64 ) . unwrap ( ) ;
1176+ let distr = Normal :: new ( mean. to_f64 ( ) . unwrap ( ) , std. to_f64 ( ) . unwrap ( ) ) . unwrap ( ) ;
11771177 let data = ( 0 ..shape. numel ( ) )
11781178 . map ( |_| {
11791179 let sample: f64 = distr. sample ( & mut * rng. write ( ) ) ;
@@ -1193,14 +1193,14 @@ impl Tensor {
11931193 } else {
11941194 let meta = StorageView {
11951195 shape : shape. clone ( ) ,
1196- dtype : DType :: F32 ,
1196+ dtype : T :: dtype ( ) ,
11971197 stride : Stride :: from ( & shape. clone ( ) ) ,
11981198 } ;
11991199 Ok ( Self :: new_impl (
12001200 LazyOp :: FillRandn ( FillRandn {
12011201 shape,
1202- mean,
1203- std,
1202+ mean : mean . to_f32 ( ) . unwrap ( ) ,
1203+ std : std . to_f32 ( ) . unwrap ( ) ,
12041204 seed : Some ( rng. write ( ) . next_u32 ( ) ) ,
12051205 } ) ,
12061206 meta,
@@ -1213,8 +1213,8 @@ impl Tensor {
12131213
12141214 #[ cfg( feature = "rand" ) ]
12151215 pub fn randn < T : TensorDType + num_traits:: Float > (
1216- mean : f32 ,
1217- std : f32 ,
1216+ mean : T ,
1217+ std : T ,
12181218 shape : Shape ,
12191219 device : Device ,
12201220 ) -> Result < Self > {
@@ -1223,18 +1223,18 @@ impl Tensor {
12231223
12241224 #[ cfg( feature = "rand" ) ]
12251225 pub ( crate ) fn rand_impl < T : TensorDType + num_traits:: Float > (
1226- lo : f32 ,
1227- up : f32 ,
1226+ lo : T ,
1227+ up : T ,
12281228 shape : Shape ,
12291229 device : Device ,
12301230 is_variable : bool ,
12311231 ) -> Result < Self > {
12321232 let rng = device. get_rng ( ) ;
1233- let distr = Uniform :: new ( lo, up) ;
1233+ let distr = Uniform :: new ( lo. to_f32 ( ) . unwrap ( ) , up. to_f32 ( ) . unwrap ( ) ) ;
12341234 let data = ( 0 ..shape. numel ( ) )
12351235 . map ( |_| {
12361236 let sample: f32 = distr. sample ( & mut * rng. write ( ) ) ;
1237- T :: from ( sample) . expect ( "Failed to convert sample" )
1237+ T :: from ( sample as f32 ) . expect ( "Failed to convert sample" )
12381238 } )
12391239 . collect :: < Vec < _ > > ( ) ;
12401240
@@ -1243,8 +1243,8 @@ impl Tensor {
12431243
12441244 #[ cfg( feature = "rand" ) ]
12451245 pub fn rand < T : TensorDType + num_traits:: Float > (
1246- lo : f32 ,
1247- up : f32 ,
1246+ lo : T ,
1247+ up : T ,
12481248 shape : Shape ,
12491249 device : Device ,
12501250 ) -> Result < Self > {
0 commit comments