55#include < aclnnop/aclnn_layer_norm.h>
66#include < aclnnop/aclnn_repeat.h>
77#include < aclnnop/aclnn_softmax.h>
8+ #include < aclnnop/aclnn_upsample_nearest_2d.h>
89#include < aclnnop/aclnn_reduce_sum.h>
910
1011#include < cmath>
@@ -486,10 +487,6 @@ void ggml_cann_sum_rows(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
486487 GGML_ASSERT (dst->ne [0 ] == 1 );
487488 aclTensor* acl_dst = create_acl_tensor (dst);
488489
489- uint64_t workspaceSize = 0 ;
490- aclOpExecutor* executor;
491- void * workspaceAddr = nullptr ;
492-
493490 int64_t reduce_dims_host[] = {3 };
494491 aclIntArray* reduce_dims = aclCreateIntArray (reduce_dims_host, 1 );
495492
@@ -503,6 +500,41 @@ void ggml_cann_sum_rows(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
503500 aclrtStream stream = ctx.stream ();
504501 ACL_CHECK (aclnnReduceSum (workspaceAddr, workspaceSize, executor, stream));
505502
503+ ACL_CHECK (aclDestroyTensor (acl_src));
504+ ACL_CHECK (aclDestroyTensor (acl_dst));
505+ }
506+
507+ void ggml_cann_upsample_nearest2d (ggml_backend_cann_context& ctx,
508+ ggml_tensor* dst) {
509+
510+ ggml_tensor* src = dst->src [0 ];
511+
512+ aclTensor* acl_src = create_acl_tensor (src, nullptr , nullptr , 0 ,
513+ ACL_FORMAT_NCHW);
514+ aclTensor* acl_dst = create_acl_tensor (dst, nullptr , nullptr , 0 ,
515+ ACL_FORMAT_NCHW);
516+
517+ const int scale_factor = dst->op_params [0 ];
518+ std::vector<int64_t > output_size{dst->ne [1 ], dst->ne [0 ]};
519+ auto output_size_array = aclCreateIntArray (output_size.data (), 2 );
520+
521+ uint64_t workspaceSize = 0 ;
522+ aclOpExecutor* executor;
523+ void * workspaceAddr = nullptr ;
524+
525+ aclrtStream stream = ctx.stream ();
526+
527+ ACL_CHECK (aclnnUpsampleNearest2dGetWorkspaceSize (acl_src, output_size_array,
528+ acl_dst, &workspaceSize,
529+ &executor));
530+ if (workspaceSize > 0 ) {
531+ workspaceAddr = ctx.alloc_buffer (workspaceSize);
532+ }
533+
534+ ACL_CHECK (aclnnUpsampleNearest2d (workspaceAddr, workspaceSize, executor,
535+ stream));
536+
537+ ACL_CHECK (aclDestroyIntArray (output_size_array));
506538 ACL_CHECK (aclDestroyTensor (acl_src));
507539 ACL_CHECK (aclDestroyTensor (acl_dst));
508540}
0 commit comments