Skip to content

Commit 565569c

Browse files
committed
Add Candle-like indexer
1 parent adb7dcd commit 565569c

2 files changed

Lines changed: 144 additions & 0 deletions

File tree

crates/piston-core/src/indexer.rs

Lines changed: 142 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,142 @@
1+
// Adapted from Candle: https://github.com/huggingface/candle/blob/main/candle-core/src/indexer.rs
2+
use crate::OpTensor;
3+
use anyhow::Error;
4+
use std::ops::{
5+
Bound, Range, RangeBounds, RangeFrom, RangeFull, RangeInclusive, RangeTo, RangeToInclusive,
6+
};
7+
8+
impl OpTensor {
9+
fn index(&self, indexers: &[TensorIndexer]) -> Result<Self, Error> {
10+
let mut x = self.clone();
11+
let dims = self.shape().as_slice();
12+
let mut current_dim = 0;
13+
for (i, indexer) in indexers.iter().enumerate() {
14+
x = match indexer {
15+
TensorIndexer::Select(n) => x.narrow(current_dim, *n, 1)?.squeeze(current_dim)?,
16+
TensorIndexer::Narrow(left_bound, right_bound) => {
17+
let start = match left_bound {
18+
Bound::Included(n) => *n,
19+
Bound::Excluded(n) => *n + 1,
20+
Bound::Unbounded => 0,
21+
};
22+
let stop = match right_bound {
23+
Bound::Included(n) => *n + 1,
24+
Bound::Excluded(n) => *n,
25+
Bound::Unbounded => dims[i],
26+
};
27+
let out = x.narrow(current_dim, start, stop.saturating_sub(start))?;
28+
current_dim += 1;
29+
out
30+
}
31+
TensorIndexer::IndexSelect(indexes) => {
32+
if indexes.rank() != 1 {
33+
anyhow::bail!("multi-dimensional tensor indexing is not supported")
34+
}
35+
if indexes.device() != x.device() {
36+
anyhow::bail!("indexing device mismatch: index tensor is on {:?} but input tensor is on {:?}", indexes.device(), x.device())
37+
}
38+
let out = x.index_select(indexes.clone(), current_dim)?;
39+
current_dim += 1;
40+
out
41+
}
42+
TensorIndexer::Err(e) => anyhow::bail!("indexing error {e:?}"),
43+
};
44+
}
45+
Ok(x)
46+
}
47+
}
48+
49+
#[derive(Debug)]
50+
/// Generic structure used to index a slice of the tensor
51+
pub enum TensorIndexer {
52+
/// This selects the elements for which an index has some specific value.
53+
Select(usize),
54+
/// This is a regular slice, purely indexing a chunk of the tensor
55+
Narrow(Bound<usize>, Bound<usize>),
56+
/// Indexing via a 1d tensor
57+
IndexSelect(OpTensor),
58+
Err(Error),
59+
}
60+
61+
impl From<usize> for TensorIndexer {
62+
fn from(index: usize) -> Self {
63+
TensorIndexer::Select(index)
64+
}
65+
}
66+
67+
impl From<&OpTensor> for TensorIndexer {
68+
fn from(tensor: &OpTensor) -> Self {
69+
TensorIndexer::IndexSelect(tensor.clone())
70+
}
71+
}
72+
73+
trait RB: RangeBounds<usize> {}
74+
impl RB for Range<usize> {}
75+
impl RB for RangeFrom<usize> {}
76+
impl RB for RangeFull {}
77+
impl RB for RangeInclusive<usize> {}
78+
impl RB for RangeTo<usize> {}
79+
impl RB for RangeToInclusive<usize> {}
80+
81+
impl<T: RB> From<T> for TensorIndexer {
82+
fn from(range: T) -> Self {
83+
use std::ops::Bound::*;
84+
let start = match range.start_bound() {
85+
Included(idx) => Included(*idx),
86+
Excluded(idx) => Excluded(*idx),
87+
Unbounded => Unbounded,
88+
};
89+
let end = match range.end_bound() {
90+
Included(idx) => Included(*idx),
91+
Excluded(idx) => Excluded(*idx),
92+
Unbounded => Unbounded,
93+
};
94+
TensorIndexer::Narrow(start, end)
95+
}
96+
}
97+
98+
/// Trait used to implement multiple signatures for ease of use of the slicing
99+
/// of a tensor
100+
pub trait IndexOp<T> {
101+
/// Returns a slicing iterator which are the chunks of data necessary to
102+
/// reconstruct the desired tensor.
103+
fn i(&self, index: T) -> Result<OpTensor, Error>;
104+
}
105+
106+
impl<T> IndexOp<T> for OpTensor
107+
where
108+
T: Into<TensorIndexer>,
109+
{
110+
fn i(&self, index: T) -> Result<OpTensor, Error> {
111+
self.index(&[index.into()])
112+
}
113+
}
114+
115+
impl<A> IndexOp<(A,)> for OpTensor
116+
where
117+
A: Into<TensorIndexer>,
118+
{
119+
fn i(&self, (a,): (A,)) -> Result<OpTensor, Error> {
120+
self.index(&[a.into()])
121+
}
122+
}
123+
124+
macro_rules! index_op_tuple {
125+
($($t:ident),+) => {
126+
#[allow(non_snake_case)]
127+
impl<$($t),*> IndexOp<($($t,)*)> for OpTensor
128+
where
129+
$($t: Into<TensorIndexer>,)*
130+
{
131+
fn i(&self, ($($t,)*): ($($t,)*)) -> Result<OpTensor, Error> {
132+
self.index(&[$($t.into(),)*])
133+
}
134+
}
135+
};
136+
}
137+
138+
index_op_tuple!(A, B, C);
139+
index_op_tuple!(A, B, C, D);
140+
index_op_tuple!(A, B, C, D, E);
141+
index_op_tuple!(A, B, C, D, E, F);
142+
index_op_tuple!(A, B, C, D, E, F, G);

crates/piston-core/src/lib.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ mod dtype;
77
mod enforcer;
88
mod executable;
99
mod gpu;
10+
mod indexer;
1011
mod ndarray_ext;
1112
mod op;
1213
mod ops;
@@ -28,6 +29,7 @@ pub use dtype::*;
2829
pub use enforcer::*;
2930
pub use executable::*;
3031
pub use gpu::*;
32+
pub use indexer::*;
3133
pub use ndarray_ext::*;
3234
pub use op::*;
3335
pub use ops::*;

0 commit comments

Comments
 (0)