diff --git a/backends/vulkan/runtime/graph/ComputeGraph.cpp b/backends/vulkan/runtime/graph/ComputeGraph.cpp index dada914b22d..789a0945366 100644 --- a/backends/vulkan/runtime/graph/ComputeGraph.cpp +++ b/backends/vulkan/runtime/graph/ComputeGraph.cpp @@ -132,16 +132,27 @@ ValueRef ComputeGraph::add_tensor( sizes, dtype, suggested_storage_type(), memory_layout, shared_object_idx); } +ValueRef ComputeGraph::add_tensor_like( + const ValueRef vref, + const api::StorageType storage_type, + const api::GPUMemoryLayout memory_layout) { + TensorRef& tref = get_val(vref).toTensorRef(); + return add_tensor(tref.sizes, tref.dtype, storage_type, memory_layout); +} + +ValueRef ComputeGraph::add_tensor_like( + const ValueRef vref, + const api::GPUMemoryLayout memory_layout) { + TensorRef& tref = get_val(vref).toTensorRef(); + return add_tensor(tref.sizes, tref.dtype, memory_layout); +} + ValueRef ComputeGraph::add_tensor( const std::vector& sizes, const api::ScalarType dtype, const int64_t shared_object_idx) { return add_tensor( - sizes, - dtype, - suggested_storage_type(), - suggested_memory_layout(sizes), - shared_object_idx); + sizes, dtype, suggested_memory_layout(sizes), shared_object_idx); } ValueRef ComputeGraph::add_tensorref( diff --git a/backends/vulkan/runtime/graph/ComputeGraph.h b/backends/vulkan/runtime/graph/ComputeGraph.h index 00aa60020f3..24117d39f9f 100644 --- a/backends/vulkan/runtime/graph/ComputeGraph.h +++ b/backends/vulkan/runtime/graph/ComputeGraph.h @@ -172,7 +172,7 @@ class ComputeGraph final { const api::ScalarType dtype, const api::StorageType storage_type, const api::GPUMemoryLayout memory_layout, - const int64_t shared_object_idx); + const int64_t shared_object_idx = -1); /* * Add a `vTensor` value to the graph with the specified properties. The @@ -191,9 +191,25 @@ class ComputeGraph final { */ ValueRef add_tensor( const std::vector& sizes, - const api::ScalarType dtype = api::ScalarType::Float, + const api::ScalarType dtype, const int64_t shared_object_idx = -1); + /* + * Add a `vTensor` value to the graph with the properties of `vref`. + */ + ValueRef add_tensor_like( + const ValueRef vref, + const api::StorageType storage_type, + const api::GPUMemoryLayout memory_layout); + + /* + * Add a `vTensor` value to the graph with the properties of `vref`. The + * suggested storage type will be used to construct the `vTensor`. + */ + ValueRef add_tensor_like( + const ValueRef vref, + const api::GPUMemoryLayout memory_layout); + /* * Add a `TensorRef` value to the graph with the specific properties. A * `TensorRef` is a reference to a `vTensor` whose data is stored in an diff --git a/backends/vulkan/runtime/graph/ops/impl/Staging.cpp b/backends/vulkan/runtime/graph/ops/impl/Staging.cpp index e08fad5df83..7d646a27111 100644 --- a/backends/vulkan/runtime/graph/ops/impl/Staging.cpp +++ b/backends/vulkan/runtime/graph/ops/impl/Staging.cpp @@ -63,8 +63,7 @@ ValueRef prepack( ComputeGraph& graph, const ValueRef vref, const api::GPUMemoryLayout layout) { - TensorRef& tref = graph.get_val(vref).toTensorRef(); - ValueRef v = graph.add_tensor(tref.sizes, tref.dtype, layout); + ValueRef v = graph.add_tensor_like(vref, layout); vTensor& t = graph.get_val(v).toTensor(); api::ShaderInfo shader = get_nchw_to_image_shader(t);