Skip to content

Commit d947d69

Browse files
authored
Merge pull request #162 from PSAL-POSTECH/ViT
ViT E2E model support
2 parents 4385e5a + ca4083f commit d947d69

19 files changed

Lines changed: 1670 additions & 176 deletions

.github/workflows/pytorchsim_test.yml

Lines changed: 224 additions & 6 deletions
Large diffs are not rendered by default.

Dockerfile.base

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@ RUN wget https://github.com/riscv-collab/riscv-gnu-toolchain/releases/download/2
4242
rm *.tar.gz
4343

4444
# Install torchsim dependency
45-
RUN apt install ninja-build && pip install onnx matplotlib && pip install --user conan==1.56.0
45+
RUN apt install ninja-build && pip install onnx matplotlib && pip install --user conan==1.56.0 && pip install "transformers<4.44" && pip install diffusers==0.34.0
4646

4747
ENV RISCV=/workspace/riscv
4848
ENV PATH=$RISCV/bin:$PATH

PyTorchSimFrontend/extension_device.cpp

Lines changed: 45 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
#include <ATen/native/cpu/Loops.h>
1212
#include <ATen/native/DispatchStub.h>
1313
#include <ATen/native/Resize.h>
14+
#include <ATen/native/TensorFactories.h>
1415
#include <ATen/EmptyTensor.h>
1516
#include <ATen/core/GeneratorForPrivateuseone.h>
1617
#include <ATen/NativeFunctions.h>
@@ -204,19 +205,25 @@ at::Tensor custom__copy_from(const at::Tensor& self, const at::Tensor& dst, bool
204205

205206
// Some dummy asserts for the basic use case: inputs are the same size / dtype, all contiguous.
206207
TORCH_CHECK(self.sizes() == dst.sizes());
207-
TORCH_CHECK(self.scalar_type() == dst.scalar_type());
208208

209-
if (self.is_contiguous() && dst.is_contiguous()) {
209+
const bool same_dtype = (self.scalar_type() == dst.scalar_type());
210+
const bool both_contig = self.is_contiguous() && dst.is_contiguous();
211+
212+
// 1) fast path
213+
if (same_dtype && both_contig) {
210214
std::memcpy(dst.mutable_data_ptr(),
211215
self.data_ptr(),
212216
dst.storage().nbytes());
213-
} else {
214-
// Using a CPU tensor to accomplish the strided copy.
215-
at::Tensor cpu_self = unsafe_create_cpu_tensor_from_dummy_tensor(self);
216-
at::Tensor cpu_dst = unsafe_create_cpu_tensor_from_dummy_tensor(dst);
217-
cpu_dst.copy_(cpu_self);
217+
return dst;
218218
}
219219

220+
// 2) slow path
221+
at::Tensor cpu_self = unsafe_create_cpu_tensor_from_dummy_tensor(self);
222+
at::Tensor cpu_dst = unsafe_create_cpu_tensor_from_dummy_tensor(dst);
223+
if (!same_dtype) {
224+
cpu_self = cpu_self.to(cpu_dst.scalar_type(), /*non_blocking=*/false, /*copy=*/true);
225+
}
226+
cpu_dst.copy_(cpu_self);
220227
return dst;
221228
}
222229

@@ -230,7 +237,6 @@ at::Tensor& custom_abs_out(const at::Tensor& self, at::Tensor& out) {
230237

231238
at::Tensor custom_empty_strided(c10::IntArrayRef size, c10::IntArrayRef stride, c10::optional<at::ScalarType> dtype_opt, c10::optional<at::Layout> layout_opt, c10::optional<at::Device> device_opt, c10::optional<bool> pin_memory_opt) {
232239
op_counter += 1;
233-
234240
constexpr c10::DispatchKeySet private_use_ks(c10::DispatchKey::PrivateUse1);
235241
auto dtype = c10::dtype_or_default(dtype_opt);
236242
return at::detail::empty_strided_generic(size, stride, &global_custom_alloc, private_use_ks, dtype);
@@ -244,7 +250,23 @@ at::Tensor custom_empty(c10::IntArrayRef size, c10::optional<at::ScalarType> dty
244250
return at::detail::empty_generic(size, &global_custom_alloc, private_use_ks, dtype, optional_memory_format);
245251
}
246252

247-
// This macro does the heavy lifting.
253+
at::Tensor& custom_arange_start_out_impl(
254+
const c10::Scalar& start,
255+
const c10::Scalar& end,
256+
const c10::Scalar& step,
257+
at::Tensor& out) {
258+
//const int64_t n = arange_len(start.toDouble(), end.toDouble(), step.toDouble());
259+
//at::native::resize_output(out, {n});
260+
return out;
261+
}
262+
263+
static at::Tensor custom_to_dtype_impl(const at::Tensor& self,
264+
c10::ScalarType dtype,
265+
bool non_blocking, bool copy,
266+
c10::optional<c10::MemoryFormat> memory_format) {
267+
return at::native::to(self, dtype, non_blocking, copy, memory_format);
268+
}
269+
248270
// With TORCH_LIBRARY_IMPL, you can register custom kernels for your backend.
249271
// For open registration, we're registering all of our kernels to the PrivateUse1 dispatch key.
250272
// Later in this file, we map a custom device to the PrivateUse1 device type,
@@ -255,21 +277,27 @@ at::Tensor custom_empty(c10::IntArrayRef size, c10::optional<at::ScalarType> dty
255277
// More details on the dispatcher can be found at http://blog.ezyang.com/2020/09/lets-talk-about-the-pytorch-dispatcher/.
256278
TORCH_LIBRARY_IMPL(aten, PrivateUse1, m) {
257279
m.impl("to.Device", &custom_to_device);
280+
m.impl("to.dtype", &custom_to_dtype_impl);
258281
m.impl("fill_.Scalar", &custom_fill__scalar);
259282
m.impl("_copy_from", &custom__copy_from);
260283
m.impl("_copy_from_and_resize", &custom__copy_from_and_resize);
261284
m.impl("empty_strided", &custom_empty_strided);
262285
m.impl("empty.memory_format", &custom_empty);
263286
m.impl("as_strided", at::native::as_strided_tensorimpl);
264287
m.impl("view", at::native::view);
288+
m.impl("arange.start_out", &custom_arange_start_out_impl);
289+
}
290+
291+
TORCH_LIBRARY_IMPL(aten, AutogradPrivateUse1, m) {
292+
m.impl("to.dtype", &custom_to_dtype_impl);
265293
}
266294

267295
TORCH_LIBRARY_FRAGMENT(aten, m) {
268-
m.def(
269-
"_reinterpret_tensor(Tensor self, int[] size, int[] stride, int offset_increment=0) -> Tensor",
270-
torch::dispatch(
271-
c10::DispatchKey::AutogradPrivateUse1, _reinterpret_tensor),
272-
{at::Tag::pt2_compliant_tag});
296+
m.def(
297+
"_reinterpret_tensor(Tensor self, int[] size, int[] stride, int offset_increment=0) -> Tensor",
298+
torch::dispatch(
299+
c10::DispatchKey::AutogradPrivateUse1, _reinterpret_tensor),
300+
{at::Tag::pt2_compliant_tag});
273301
}
274302

275303
void custom_cpu_fallback(const c10::OperatorHandle& op, torch::jit::Stack* stack) {
@@ -307,6 +335,9 @@ TORCH_LIBRARY_IMPL(aten, PrivateUse1, m) {
307335
m.impl("_foreach_addcdiv_.ScalarList", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>());
308336
m.impl("_foreach_add_.List", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>());
309337
m.impl("cat.out", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>());
338+
m.impl("_native_multi_head_attention", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>());
339+
m.impl("resize_", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>());
340+
m.impl("exp.out", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>());
310341
}
311342

312343
// This basic implementation doesn't bother dealing with different device indices

PyTorchSimFrontend/mlir/mlir_bmm_template.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -173,9 +173,9 @@ def render(self,
173173

174174
W_tensor = empty_strided(W.layout.size, W.layout.stride)
175175
X_tensor = empty_strided(X.layout.size, X.layout.stride)
176-
if len(W_tensor.size()) > 3:
176+
if len(W_tensor.size()) > 3 or len(W_tensor.size()) == 2:
177177
W_tensor = W_tensor.view([-1, W_tensor.shape[-2], W_tensor.shape[-1]])
178-
if len(X_tensor.size()) > 3:
178+
if len(X_tensor.size()) > 3 or len(X_tensor.size()) == 2:
179179
X_tensor = X_tensor.view([-1, X_tensor.shape[-2], X_tensor.shape[-1]])
180180
B, M, N, K = X_tensor.size()[0], X_tensor.size()[1], W_tensor.size()[2], X_tensor.size()[2]
181181

@@ -217,6 +217,7 @@ def render(self,
217217
X_tile_desc = mlir_common.MLIRMultiDimTile(X_tile_size, kernel.vector_lane, vlane_split_axis, vlane_stride)
218218
X_tile_desc.set_tile_size_stride(X_tile_size, X_tile_stride)
219219
X_tile_desc.set_name("X_buffer")
220+
X_tile_desc.offset = X.get_layout().offset
220221
X_stride = X_tensor.stride()
221222
X_idx = [loop_dim[0]*X_stride[0], loop_dim[1]*X_stride[1], loop_dim[3]*X_stride[2]] # To keep index argument order, we used index_list
222223

@@ -225,6 +226,7 @@ def render(self,
225226
W_tile_desc = mlir_common.MLIRMultiDimTile(X_tile_size, kernel.vector_lane, vlane_split_axis, vlane_stride)
226227
W_tile_desc.set_tile_size_stride(W_tile_size, W_tile_stride)
227228
W_tile_desc.set_name("W_buffer")
229+
W_tile_desc.offset = W.get_layout().offset
228230
W_stride = W_tensor.stride()
229231
W_idx = [loop_dim[0]*W_stride[0], loop_dim[3]*W_stride[1], loop_dim[2]*W_stride[2]]
230232

@@ -241,8 +243,12 @@ def render(self,
241243
Y_idx = [loop_dim[0]*Y_stride[0], loop_dim[2]*Y_stride[2], loop_dim[1]*Y_stride[1]]
242244

243245
# Extract Bias info
246+
Bias_tile_desc = mlir_common.MLIRMultiDimTile(Y_tile_size, kernel.vector_lane, vlane_split_axis, vlane_stride)
247+
Bias_tile_desc.set_tile_size_stride(Y_tile_size, Y_tile_stride)
248+
Bias_tile_desc.set_name("Y_buffer")
244249
if Bias is not None:
245250
Bias_stride = Bias.get_layout().stride
251+
Bias_tile_desc.offset = Bias.get_layout().offset
246252
if nr_rdim == 0:
247253
Bias_idx = [loop_dim[0]*Bias_stride[0], loop_dim[1]*Bias_stride[1], loop_dim[2]*Bias_stride[2]]
248254
else:

0 commit comments

Comments (0)