@@ -81,15 +81,14 @@ ValueRef prepack_biases(
8181 ComputeGraph& graph,
8282 const ValueRef vref,
8383 const ValueRef weight,
84- const bool transposed) {
84+ const bool transposed,
85+ const api::StorageType storage_type,
86+ const api::GPUMemoryLayout memory_layout) {
8587 auto sizes = graph.get_sizes_of (weight);
8688 const int64_t out_channels = transposed ? sizes.at (1 ) : sizes.at (0 );
8789
8890 ValueRef v = graph.add_tensor (
89- {out_channels},
90- graph.get_dtype_of (weight),
91- api::kTexture2D ,
92- api::kWidthPacked );
91+ {out_channels}, graph.get_dtype_of (weight), storage_type, memory_layout);
9392 vTensorPtr t = graph.get_tensor (v);
9493
9594 api::ShaderInfo shader = get_nchw_to_image_shader (*t);
@@ -329,7 +328,13 @@ void add_conv2d_node(
329328
330329 ValueRef arg_in = prepack_if_tensor_ref (graph, in);
331330 ValueRef arg_weight = prepack_weights (graph, weight, method);
332- ValueRef arg_bias = prepack_biases (graph, bias, weight, transposed_val);
331+ ValueRef arg_bias = prepack_biases (
332+ graph,
333+ bias,
334+ weight,
335+ transposed_val,
336+ /* storage_type = */ api::kTexture2D ,
337+ /* memory_layout = */ api::kWidthPacked );
333338
334339 vTensorPtr t_in = graph.get_tensor (arg_in);
335340 vTensorPtr t_out = graph.get_tensor (out);
@@ -383,15 +388,16 @@ void add_conv1d_node(
383388 const ValueRef dilation,
384389 const ValueRef groups,
385390 const ValueRef out) {
386- if (graph.val_is_none (bias)) {
387- VK_THROW (" conv1d: Null bias is not supported yet!" );
388- }
389-
390391 ValueRef arg_in = prepack_if_tensor_ref (graph, in);
391392 ValueRef arg_weight =
392393 prepack_if_tensor_ref (graph, weight, graph.memory_layout_of (arg_in));
393- ValueRef arg_bias =
394- prepack_if_tensor_ref (graph, bias, graph.memory_layout_of (arg_in));
394+ ValueRef arg_bias = prepack_biases (
395+ graph,
396+ bias,
397+ weight,
398+ /* transposed = */ false ,
399+ /* storage_type = */ api::kTexture3D ,
400+ /* memory_layout = */ api::kChannelsPacked );
395401
396402 vTensorPtr t_in = graph.get_tensor (arg_in);
397403 vTensorPtr t_weight = graph.get_tensor (arg_weight);
0 commit comments