
namespace vkcompute {

+//
+// VTensorPtr
+//
+
+#define VALUE_PTR_CLASS_IMPL(classname, ctype, type_name) \
+  classname::classname(ComputeGraph* const graph, const ValueRef idx) \
+      : graph_(graph), ptr_(&(graph_->values_.at(idx).to##type_name())) { \
+    graph_->values_in_use_++; \
+  } \
+  ctype* classname::operator->() const { \
+    return ptr_; \
+  } \
+  ctype& classname::operator*() const { \
+    return *ptr_; \
+  } \
+  classname::~classname() { \
+    graph_->values_in_use_--; \
+  }
+
+VALUE_PTR_CLASS_IMPL(vTensorPtr, vTensor, Tensor)
+VALUE_PTR_CLASS_IMPL(StagingPtr, api::StorageBuffer, Staging)
+VALUE_PTR_CLASS_IMPL(IntListPtr, std::vector<int64_t>, IntList)
+VALUE_PTR_CLASS_IMPL(DoubleListPtr, std::vector<double>, DoubleList)
+VALUE_PTR_CLASS_IMPL(BoolListPtr, std::vector<bool>, BoolList)
+VALUE_PTR_CLASS_IMPL(ValueListPtr, std::vector<ValueRef>, ValueList)
+
+#undef VALUE_PTR_CLASS_IMPL
+
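The macro stamps out one small RAII accessor class per Value variant: the constructor caches a raw pointer into `values_` and increments `values_in_use_`, and the destructor decrements it again, so the graph can tell at runtime whether any of these pointers is still live. As a rough illustration, `VALUE_PTR_CLASS_IMPL(vTensorPtr, vTensor, Tensor)` expands to approximately the following (the class declarations themselves live in the header, which is not part of this diff):

```cpp
// Approximate expansion for vTensorPtr; formatting is illustrative, the
// behavior mirrors the macro body above.
vTensorPtr::vTensorPtr(ComputeGraph* const graph, const ValueRef idx)
    : graph_(graph), ptr_(&(graph_->values_.at(idx).toTensor())) {
  graph_->values_in_use_++;
}
vTensor* vTensorPtr::operator->() const {
  return ptr_;
}
vTensor& vTensorPtr::operator*() const {
  return *ptr_;
}
vTensorPtr::~vTensorPtr() {
  graph_->values_in_use_--;
}
```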
+//
+// ComputeGraph
+//
+
ComputeGraph::ComputeGraph(GraphConfig config)
    : config_{config},
      prepack_descriptor_counts_{},
@@ -105,6 +137,35 @@ api::GPUMemoryLayout ComputeGraph::suggested_memory_layout(
  return api::kChannelsPacked;
}

+void ComputeGraph::check_no_active_value_ptrs() {
+  VK_CHECK_COND(
+      values_in_use_ == 0,
+      "Make sure that there are no pointers stored from the return values of "
+      "`ComputeGraph::get_*()` functions in scope before adding Values to the "
+      "graph. Modifying the graph's values may cause existing pointers to be "
+      "invalidated.");
+}
+
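`check_no_active_value_ptrs()` turns the scoping rule into a runtime check: while any of the `*Ptr` wrappers above is alive, `values_in_use_` is nonzero, and any attempt to append to `values_` aborts instead of silently invalidating the outstanding pointer. A hypothetical sketch of the pattern it catches, written as if inside `namespace vkcompute` and assuming `ComputeGraph::get_tensor(idx)` returns a `vTensorPtr` by value:

```cpp
// Hypothetical misuse; graph and in_ref are placeholders, not values from
// this diff.
void bad_pattern(ComputeGraph& graph, const ValueRef in_ref) {
  vTensorPtr t = graph.get_tensor(in_ref); // values_in_use_ becomes nonzero
  graph.add_none(); // calls check_no_active_value_ptrs(); VK_CHECK_COND fails
                    // here because t still points into values_
} // t is destroyed and values_in_use_ drops back to zero
```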
+std::vector<int64_t> ComputeGraph::get_sizes_of(ValueRef idx) {
+  Value& val = values_.at(idx);
+  if (val.isTensor()) {
+    return val.toTensor().sizes();
+  } else if (val.isTensorRef()) {
+    return val.toTensorRef().sizes;
+  }
+  VK_THROW("Could not get sizes of value with type ", val.type());
+}
+
+api::ScalarType ComputeGraph::get_dtype_of(ValueRef idx) {
+  Value& val = values_.at(idx);
+  if (val.isTensor()) {
+    return val.toTensor().dtype();
+  } else if (val.isTensorRef()) {
+    return val.toTensorRef().dtype;
+  }
+  VK_THROW("Could not get dtype of value with type ", val.type());
+}
+
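These two helpers cover the common case where a caller only needs a tensor's metadata: they copy the sizes or dtype out, from either a `vTensor` or a `TensorRef`, so no pointer into `values_` has to stay alive across later mutations. For example, a hypothetical helper built only on these queries:

```cpp
// Hypothetical convenience function, written as if inside namespace vkcompute;
// numel_of is not part of the API shown in this diff.
int64_t numel_of(ComputeGraph& graph, const ValueRef idx) {
  int64_t numel = 1;
  for (const int64_t s : graph.get_sizes_of(idx)) {
    numel *= s;
  }
  return numel; // metadata was copied out; no pointer into values_ survives
}
```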
ValueRef ComputeGraph::add_tensor(
    const std::vector<int64_t>& sizes,
    const api::ScalarType dtype,
@@ -114,6 +175,7 @@ ValueRef ComputeGraph::add_tensor(
  bool allocate_memory = shared_object_idx < 0;

  ValueRef idx(static_cast<int>(values_.size()));
+  check_no_active_value_ptrs();
  values_.emplace_back(vTensor(
      context(), sizes, dtype, storage_type, memory_layout, allocate_memory));

@@ -136,14 +198,14 @@ ValueRef ComputeGraph::add_tensor_like(
    const ValueRef vref,
    const api::StorageType storage_type,
    const api::GPUMemoryLayout memory_layout) {
-  TensorRef& tref = get_val(vref).toTensorRef();
+  TensorRef tref = get_tref(vref);
  return add_tensor(tref.sizes, tref.dtype, storage_type, memory_layout);
}

ValueRef ComputeGraph::add_tensor_like(
    const ValueRef vref,
    const api::GPUMemoryLayout memory_layout) {
-  TensorRef& tref = get_val(vref).toTensorRef();
+  TensorRef tref = get_tref(vref);
  return add_tensor(tref.sizes, tref.dtype, memory_layout);
}

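Replacing the `TensorRef&` binding with a by-value `TensorRef` matters because `add_tensor` emplaces into `values_`, and growing that container can reallocate the storage the old reference pointed into. The same hazard in miniature, using only standard-library types (a `std::vector` stands in for `values_` here):

```cpp
#include <vector>

int main() {
  std::vector<int> values = {1, 2, 3};
  int& ref = values.at(0); // reference into the vector's current storage
  values.emplace_back(4);  // may reallocate; ref is now dangling
  // Reading ref here would be undefined behavior. Copying the element (or just
  // the fields needed) before appending, as add_tensor_like now does with
  // TensorRef, avoids the problem.
  return values.at(0);
}
```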
@@ -160,6 +222,7 @@ ValueRef ComputeGraph::add_tensorref(
    const api::ScalarType dtype,
    const void* const data) {
  ValueRef idx(static_cast<int>(values_.size()));
+  check_no_active_value_ptrs();
  values_.emplace_back(TensorRef(sizes, dtype, data));
  return idx;
}
@@ -168,24 +231,28 @@ ValueRef ComputeGraph::add_staging(
    const api::ScalarType dtype,
    const size_t numel) {
  ValueRef idx(static_cast<int>(values_.size()));
+  check_no_active_value_ptrs();
  values_.emplace_back(api::StorageBuffer(context(), dtype, numel));
  return idx;
}

ValueRef ComputeGraph::add_none() {
  ValueRef idx(static_cast<int>(values_.size()));
+  check_no_active_value_ptrs();
  values_.emplace_back();
  return idx;
}

ValueRef ComputeGraph::add_value_list(std::vector<ValueRef>&& value) {
  ValueRef idx(static_cast<int>(values_.size()));
+  check_no_active_value_ptrs();
  values_.emplace_back(std::move(value));
  return idx;
}

ValueRef ComputeGraph::add_string(std::string&& str) {
  ValueRef idx(static_cast<int>(values_.size()));
+  check_no_active_value_ptrs();
  values_.emplace_back(std::move(str));
  return idx;
}
@@ -194,8 +261,9 @@ ValueRef ComputeGraph::set_input_tensor(
    const ValueRef idx,
    const bool use_staging) {
  if (use_staging) {
-    vTensor& tensor = get_val(idx).toTensor();
-    ValueRef staging_idx = add_staging(tensor.dtype(), tensor.gpu_numel());
+    api::ScalarType dtype = get_tensor(idx)->dtype();
+    size_t gpu_numel = get_tensor(idx)->gpu_numel();
+    ValueRef staging_idx = add_staging(dtype, gpu_numel);
    add_staging_to_tensor_node(*this, staging_idx, idx);
    inputs_.push_back({idx, staging_idx});
    return staging_idx;
@@ -208,8 +276,9 @@ ValueRef ComputeGraph::set_output_tensor(
    const ValueRef idx,
    const bool use_staging) {
  if (use_staging) {
-    vTensor& tensor = get_val(idx).toTensor();
-    ValueRef staging_idx = add_staging(tensor.dtype(), tensor.gpu_numel());
+    api::ScalarType dtype = get_tensor(idx)->dtype();
+    size_t gpu_numel = get_tensor(idx)->gpu_numel();
+    ValueRef staging_idx = add_staging(dtype, gpu_numel);
    add_tensor_to_staging_node(*this, idx, staging_idx);
    outputs_.push_back({idx, staging_idx});
    return staging_idx;
@@ -229,20 +298,18 @@ void ComputeGraph::copy_into_staging(
    const ValueRef idx,
    const void* data,
    const size_t numel) {
-  Value& in_val = get_val(idx);
-  api::StorageBuffer& staging = in_val.toStaging();
-  size_t nbytes = numel * api::element_size(staging.dtype());
-  copy_ptr_to_staging(data, staging, nbytes);
+  StagingPtr staging = get_staging(idx);
+  size_t nbytes = numel * api::element_size(staging->dtype());
+  copy_ptr_to_staging(data, *staging, nbytes);
}

void ComputeGraph::copy_from_staging(
    const ValueRef idx,
    void* data,
    const size_t numel) {
-  Value& out_val = get_val(idx);
-  api::StorageBuffer& staging = out_val.toStaging();
-  size_t nbytes = numel * api::element_size(staging.dtype());
-  copy_staging_to_ptr(staging, data, nbytes);
+  StagingPtr staging = get_staging(idx);
+  size_t nbytes = numel * api::element_size(staging->dtype());
+  copy_staging_to_ptr(*staging, data, nbytes);
}

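The staging changes above pair with `set_input_tensor`/`set_output_tensor`: those calls create the host-visible staging buffers that `copy_into_staging` and `copy_from_staging` read and write. A rough sketch of the intended round trip, with placeholder refs and element counts, and with graph execution elided since it is outside this diff:

```cpp
// Placeholder sketch, as if inside namespace vkcompute; graph, in_ref, out_ref
// and the element count 64 are assumptions, not values from this diff.
std::vector<float> host_in(64, 1.0f);
std::vector<float> host_out(64, 0.0f);

ValueRef in_staging = graph.set_input_tensor(in_ref, /*use_staging=*/true);
ValueRef out_staging = graph.set_output_tensor(out_ref, /*use_staging=*/true);

graph.copy_into_staging(in_staging, host_in.data(), host_in.size());
// ... prepare, encode, and execute the graph (not shown in this diff) ...
graph.copy_from_staging(out_staging, host_out.data(), host_out.size());
```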
void ComputeGraph::prepare() {
@@ -308,7 +375,7 @@ void ComputeGraph::resize_input(
    const int64_t idx,
    const std::vector<int64_t>& new_sizes) {
  IOValueRef io_val = inputs_.at(idx);
-  get_val(io_val.value).toTensor().virtual_resize(new_sizes);
+  get_tensor(io_val.value)->virtual_resize(new_sizes);
}

void ComputeGraph::propagate_resize() {