#include <ATen/native/cpu/Loops.h>
#include <ATen/native/DispatchStub.h>
#include <ATen/native/Resize.h>
+#include <ATen/native/TensorFactories.h>
+#include <cmath>
#include <ATen/EmptyTensor.h>
#include <ATen/core/GeneratorForPrivateuseone.h>
#include <ATen/NativeFunctions.h>
@@ -204,19 +205,25 @@ at::Tensor custom__copy_from(const at::Tensor& self, const at::Tensor& dst, bool

  // Some dummy asserts for the basic use case: inputs are the same size / dtype, all contiguous.
  TORCH_CHECK(self.sizes() == dst.sizes());
-  TORCH_CHECK(self.scalar_type() == dst.scalar_type());

-  if (self.is_contiguous() && dst.is_contiguous()) {
+  const bool same_dtype = (self.scalar_type() == dst.scalar_type());
+  const bool both_contig = self.is_contiguous() && dst.is_contiguous();
+
+  // 1) Fast path: same dtype and both contiguous, so a raw memcpy suffices.
+  if (same_dtype && both_contig) {
    std::memcpy(dst.mutable_data_ptr(),
                self.data_ptr(),
                dst.storage().nbytes());
-  } else {
-    // Use a CPU tensor to accomplish the strided copy.
-    at::Tensor cpu_self = unsafe_create_cpu_tensor_from_dummy_tensor(self);
-    at::Tensor cpu_dst = unsafe_create_cpu_tensor_from_dummy_tensor(dst);
-    cpu_dst.copy_(cpu_self);
+    return dst;
  }

+  // 2) Slow path: go through CPU views of the dummy tensors, converting the
+  //    dtype first if source and destination disagree.
+  at::Tensor cpu_self = unsafe_create_cpu_tensor_from_dummy_tensor(self);
+  at::Tensor cpu_dst = unsafe_create_cpu_tensor_from_dummy_tensor(dst);
+  if (!same_dtype) {
+    cpu_self = cpu_self.to(cpu_dst.scalar_type(), /*non_blocking=*/false, /*copy=*/true);
+  }
+  cpu_dst.copy_(cpu_self);
  return dst;
}

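For reference, a minimal usage sketch of the new path (our example, not part of the diff; it assumes the extension is built and the PrivateUse1 backend is active):

    void copy_from_dtype_demo() {
      auto opts = at::TensorOptions().device(c10::DeviceType::PrivateUse1);
      at::Tensor src = at::empty({4}, opts.dtype(at::kFloat));   // custom_empty
      at::Tensor dst = at::empty({4}, opts.dtype(at::kDouble));  // custom_empty
      src.fill_(1.0);  // custom_fill__scalar
      dst.copy_(src);  // dtype mismatch -> custom__copy_from slow path
    }

Before this change, the mismatched-dtype `copy_` would have tripped the removed `TORCH_CHECK` instead.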
@@ -230,7 +237,6 @@ at::Tensor& custom_abs_out(const at::Tensor& self, at::Tensor& out) {

at::Tensor custom_empty_strided(c10::IntArrayRef size, c10::IntArrayRef stride, c10::optional<at::ScalarType> dtype_opt, c10::optional<at::Layout> layout_opt, c10::optional<at::Device> device_opt, c10::optional<bool> pin_memory_opt) {
  op_counter += 1;
-
  constexpr c10::DispatchKeySet private_use_ks(c10::DispatchKey::PrivateUse1);
  auto dtype = c10::dtype_or_default(dtype_opt);
  return at::detail::empty_strided_generic(size, stride, &global_custom_alloc, private_use_ks, dtype);
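A quick usage sketch (ours, assuming the extension is loaded): an `empty_strided` call on the custom device lands in this kernel, bumping `op_counter` and allocating through `global_custom_alloc`:

    at::Tensor t = at::empty_strided(
        {2, 3}, {3, 1},
        at::TensorOptions().device(c10::DeviceType::PrivateUse1).dtype(at::kFloat));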
@@ -244,7 +250,23 @@ at::Tensor custom_empty(c10::IntArrayRef size, c10::optional<at::ScalarType> dty
  return at::detail::empty_generic(size, &global_custom_alloc, private_use_ks, dtype, optional_memory_format);
}

-// This macro does the heavy lifting.
+at::Tensor& custom_arange_start_out_impl(
+    const c10::Scalar& start,
+    const c10::Scalar& end,
+    const c10::Scalar& step,
+    at::Tensor& out) {
+  // Resize `out` to ceil((end - start) / step) elements, then fill it through
+  // a CPU view: the dummy device's storage lives in ordinary host memory.
+  const int64_t n = static_cast<int64_t>(
+      std::ceil((end.toDouble() - start.toDouble()) / step.toDouble()));
+  at::native::resize_output(out, {n});
+  at::Tensor cpu_out = unsafe_create_cpu_tensor_from_dummy_tensor(out);
+  at::arange_out(cpu_out, start, end, step);
+  return out;
+}
+
+static at::Tensor custom_to_dtype_impl(const at::Tensor& self,
+                                       c10::ScalarType dtype,
+                                       bool non_blocking, bool copy,
+                                       c10::optional<c10::MemoryFormat> memory_format) {
+  // Reuse ATen's reference implementation of Tensor.to(dtype).
+  return at::native::to(self, dtype, non_blocking, copy, memory_format);
+}
+
// With TORCH_LIBRARY_IMPL, you can register custom kernels for your backend.
// For open registration, we're registering all of our kernels to the PrivateUse1 dispatch key.
// Later in this file, we map a custom device to the PrivateUse1 device type,
@@ -255,21 +277,27 @@ at::Tensor custom_empty(c10::IntArrayRef size, c10::optional<at::ScalarType> dty
// More details on the dispatcher can be found at http://blog.ezyang.com/2020/09/lets-talk-about-the-pytorch-dispatcher/.
TORCH_LIBRARY_IMPL(aten, PrivateUse1, m) {
  m.impl("to.Device", &custom_to_device);
+  m.impl("to.dtype", &custom_to_dtype_impl);
  m.impl("fill_.Scalar", &custom_fill__scalar);
  m.impl("_copy_from", &custom__copy_from);
  m.impl("_copy_from_and_resize", &custom__copy_from_and_resize);
  m.impl("empty_strided", &custom_empty_strided);
  m.impl("empty.memory_format", &custom_empty);
  m.impl("as_strided", at::native::as_strided_tensorimpl);
  m.impl("view", at::native::view);
+  m.impl("arange.start_out", &custom_arange_start_out_impl);
+}
+
+TORCH_LIBRARY_IMPL(aten, AutogradPrivateUse1, m) {
+  m.impl("to.dtype", &custom_to_dtype_impl);
}
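Registering `to.dtype` on AutogradPrivateUse1 as well gives the dispatcher a kernel at the autograd layer, so the call resolves without falling back to CPU. A minimal sketch of the effect (our example, assuming the extension is loaded):

    at::Tensor x = at::empty({2}, at::TensorOptions().device(c10::DeviceType::PrivateUse1));
    at::Tensor y = x.to(at::kDouble);  // -> custom_to_dtype_impl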

TORCH_LIBRARY_FRAGMENT(aten, m) {
-    m.def(
-        "_reinterpret_tensor(Tensor self, int[] size, int[] stride, int offset_increment=0) -> Tensor",
-        torch::dispatch(
-            c10::DispatchKey::AutogradPrivateUse1, _reinterpret_tensor),
-        {at::Tag::pt2_compliant_tag});
+  m.def(
+      "_reinterpret_tensor(Tensor self, int[] size, int[] stride, int offset_increment=0) -> Tensor",
+      torch::dispatch(
+          c10::DispatchKey::AutogradPrivateUse1, _reinterpret_tensor),
+      {at::Tag::pt2_compliant_tag});
}

void custom_cpu_fallback(const c10::OperatorHandle& op, torch::jit::Stack* stack) {
@@ -307,6 +335,9 @@ TORCH_LIBRARY_IMPL(aten, PrivateUse1, m) {
  m.impl("_foreach_addcdiv_.ScalarList", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>());
  m.impl("_foreach_add_.List", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>());
  m.impl("cat.out", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>());
+  m.impl("_native_multi_head_attention", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>());
+  m.impl("resize_", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>());
+  m.impl("exp.out", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>());
}

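The hunk above opts individual ops into the boxed CPU fallback one by one. As a sketch of the heavier-handed alternative (not part of this diff), the dispatcher also accepts a catch-all fallback for every op on a key:

    // Route every op that lacks an explicit PrivateUse1 kernel through the fallback.
    TORCH_LIBRARY_IMPL(_, PrivateUse1, m) {
      m.fallback(torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>());
    }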
// This basic implementation doesn't bother dealing with different device indices