@@ -530,7 +530,8 @@ TEST_F(VulkanComputeAPITest, spec_var_classes_test) {
530530
531531TEST_F (VulkanComputeAPITest, spec_var_shader_test) {
532532 size_t len = 16 ;
533- StagingBuffer buffer (context (), vkapi::kFloat , len);
533+ StagingBuffer buffer (
534+ context (), vkapi::kFloat , len, vkapi::CopyDirection::DEVICE_TO_HOST);
534535
535536 float scale = 3 .0f ;
536537 float offset = 1 .5f ;
@@ -602,7 +603,10 @@ TEST_F(VulkanComputeAPITest, update_params_between_submit) {
602603 }
603604
604605 StagingBuffer staging_buffer (
605- context (), vkapi::kFloat , a.staging_buffer_numel ());
606+ context (),
607+ vkapi::kFloat ,
608+ a.staging_buffer_numel (),
609+ vkapi::CopyDirection::DEVICE_TO_HOST);
606610 record_image_to_nchw_op (context (), a, staging_buffer.buffer ());
607611
608612 submit_to_gpu ();
@@ -622,7 +626,8 @@ TEST_F(VulkanComputeAPITest, update_params_between_submit) {
622626
623627template <typename T, vkapi::ScalarType dtype>
624628void test_storage_buffer_type (const size_t len) {
625- StagingBuffer buffer (context (), dtype, len);
629+ StagingBuffer buffer (
630+ context (), dtype, len, vkapi::CopyDirection::DEVICE_TO_HOST);
626631
627632 std::string kernel_name (" idx_fill_buffer" );
628633 switch (dtype) {
@@ -2013,7 +2018,11 @@ void run_from_gpu_test(
20132018 vten.sizes_ubo ());
20142019 }
20152020
2016- StagingBuffer staging_buffer (context (), dtype, vten.staging_buffer_numel ());
2021+ StagingBuffer staging_buffer (
2022+ context (),
2023+ dtype,
2024+ vten.staging_buffer_numel (),
2025+ vkapi::CopyDirection::DEVICE_TO_HOST);
20172026
20182027 if (dtype == vkapi::kChar &&
20192028 !context ()->adapter_ptr ()->has_full_int8_buffers_support ()) {
@@ -2049,7 +2058,10 @@ void round_trip_test(
20492058
20502059 // Create and fill input staging buffer
20512060 StagingBuffer staging_buffer_in (
2052- context (), dtype, vten.staging_buffer_numel ());
2061+ context (),
2062+ dtype,
2063+ vten.staging_buffer_numel (),
2064+ vkapi::CopyDirection::HOST_TO_DEVICE);
20532065
20542066 std::vector<T> data_in (staging_buffer_in.numel ());
20552067 for (int i = 0 ; i < staging_buffer_in.numel (); i++) {
@@ -2059,7 +2071,10 @@ void round_trip_test(
20592071
20602072 // Output staging buffer
20612073 StagingBuffer staging_buffer_out (
2062- context (), dtype, vten.staging_buffer_numel ());
2074+ context (),
2075+ dtype,
2076+ vten.staging_buffer_numel (),
2077+ vkapi::CopyDirection::DEVICE_TO_HOST);
20632078
20642079 record_nchw_to_image_op (context (), staging_buffer_in.buffer (), vten);
20652080
0 commit comments