Skip to content

Commit ca38b59

Browse files
committed
Fix some misspecified generics
1 parent 12df49e commit ca38b59

2 files changed

Lines changed: 18 additions & 18 deletions

File tree

crates/ratchet-core/src/tensor.rs

Lines changed: 14 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1165,15 +1165,15 @@ impl Tensor {
11651165

11661166
#[cfg(feature = "rand")]
11671167
pub(crate) fn randn_impl<T: TensorDType + num_traits::Float>(
1168-
mean: f32,
1169-
std: f32,
1168+
mean: T,
1169+
std: T,
11701170
shape: Shape,
11711171
device: Device,
11721172
is_variable: bool,
11731173
) -> Result<Self> {
11741174
let rng = device.get_rng();
11751175
if device.is_cpu() {
1176-
let distr = Normal::new(mean as f64, std as f64).unwrap();
1176+
let distr = Normal::new(mean.to_f64().unwrap(), std.to_f64().unwrap()).unwrap();
11771177
let data = (0..shape.numel())
11781178
.map(|_| {
11791179
let sample: f64 = distr.sample(&mut *rng.write());
@@ -1193,14 +1193,14 @@ impl Tensor {
11931193
} else {
11941194
let meta = StorageView {
11951195
shape: shape.clone(),
1196-
dtype: DType::F32,
1196+
dtype: T::dtype(),
11971197
stride: Stride::from(&shape.clone()),
11981198
};
11991199
Ok(Self::new_impl(
12001200
LazyOp::FillRandn(FillRandn {
12011201
shape,
1202-
mean,
1203-
std,
1202+
mean: mean.to_f32().unwrap(),
1203+
std: std.to_f32().unwrap(),
12041204
seed: Some(rng.write().next_u32()),
12051205
}),
12061206
meta,
@@ -1213,8 +1213,8 @@ impl Tensor {
12131213

12141214
#[cfg(feature = "rand")]
12151215
pub fn randn<T: TensorDType + num_traits::Float>(
1216-
mean: f32,
1217-
std: f32,
1216+
mean: T,
1217+
std: T,
12181218
shape: Shape,
12191219
device: Device,
12201220
) -> Result<Self> {
@@ -1223,18 +1223,18 @@ impl Tensor {
12231223

12241224
#[cfg(feature = "rand")]
12251225
pub(crate) fn rand_impl<T: TensorDType + num_traits::Float>(
1226-
lo: f32,
1227-
up: f32,
1226+
lo: T,
1227+
up: T,
12281228
shape: Shape,
12291229
device: Device,
12301230
is_variable: bool,
12311231
) -> Result<Self> {
12321232
let rng = device.get_rng();
1233-
let distr = Uniform::new(lo, up);
1233+
let distr = Uniform::new(lo.to_f32().unwrap(), up.to_f32().unwrap());
12341234
let data = (0..shape.numel())
12351235
.map(|_| {
12361236
let sample: f32 = distr.sample(&mut *rng.write());
1237-
T::from(sample).expect("Failed to convert sample")
1237+
T::from(sample as f32).expect("Failed to convert sample")
12381238
})
12391239
.collect::<Vec<_>>();
12401240

@@ -1243,8 +1243,8 @@ impl Tensor {
12431243

12441244
#[cfg(feature = "rand")]
12451245
pub fn rand<T: TensorDType + num_traits::Float>(
1246-
lo: f32,
1247-
up: f32,
1246+
lo: T,
1247+
up: T,
12481248
shape: Shape,
12491249
device: Device,
12501250
) -> Result<Self> {

crates/ratchet-core/src/variable.rs

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -69,8 +69,8 @@ impl Var {
6969

7070
#[cfg(feature = "rand")]
7171
pub fn rand<T: TensorDType + num_traits::Float>(
72-
lo: f32,
73-
up: f32,
72+
lo: T,
73+
up: T,
7474
shape: Shape,
7575
device: Device,
7676
) -> Result<Self> {
@@ -80,8 +80,8 @@ impl Var {
8080

8181
#[cfg(feature = "rand")]
8282
pub fn randn<T: TensorDType + num_traits::Float>(
83-
mean: f32,
84-
std: f32,
83+
mean: T,
84+
std: T,
8585
shape: Shape,
8686
device: Device,
8787
) -> Result<Self> {

0 commit comments

Comments
 (0)