
Conversation

@AndrewZhaoLuo
Contributor

@tvm-bot
Collaborator

tvm-bot commented Dec 1, 2022

Thanks for contributing to TVM! Please refer to the contributing guidelines https://tvm.apache.org/docs/contribute/ for useful information and tips. Please request code reviews from Reviewers by @-ing them in a comment.

Generated by tvm-bot

@tqchen
Member

tqchen commented Dec 1, 2022

I am not too sure we should do this, mainly because the fp16 intrinsics can be slow and software-based in some of the kernels. How about passing fp32 as the scalar argument and doing a live conversion to fp16 inside the kernel when we are targeting GPUs?
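A minimal TE-level sketch of that idea (the toy compute and names here are illustrative, not from this PR): the scalar enters the function as fp32 and is cast to fp16 inside the compute, so the device function never receives a float16 scalar argument.

import tvm
from tvm import te

# Hypothetical sketch of the suggestion above: the scalar argument stays
# fp32 (which the runtime can already marshal) and is cast to fp16
# inside the compute, so the device function never takes a float16 scalar.
n = te.var("n")
scale_f32 = te.var("scale", dtype="float32")
A = te.placeholder((n,), dtype="float16", name="A")
B = te.compute((n,), lambda i: A[i] * scale_f32.astype("float16"), name="B")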

@shingjan

Hit this problem while tuning some convolutions from resnet18 on CUDA with fp16. Commenting to stay in the loop on this problem.

@AndrewZhaoLuo
Contributor Author

@shingjan I will have time next week to look at this and try TQ's suggestion. One workaround for now, I believe, is disabling the TIR CSE pass.
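For reference, a sketch of that workaround (the toy schedule below is hypothetical and only stands in for the real workload; it assumes the `tir.disable_cse_tir` PassContext option is available in the TVM build):

import tvm
from tvm import te

# Toy schedule standing in for the real workload; the key part is the
# PassContext config, which assumes the `tir.disable_cse_tir` option
# exists in this TVM build.
n = te.var("n")
A = te.placeholder((n,), dtype="float16", name="A")
B = te.compute((n,), lambda i: A[i] + tvm.tir.const(1, "float16"), name="B")
s = te.create_schedule(B.op)

with tvm.transform.PassContext(config={"tir.disable_cse_tir": True}):
    lib = tvm.build(s, [A, B], target="llvm")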

@masahi
Member

masahi commented Dec 15, 2022

Also getting the error

  2: tvm::runtime::CUDAModuleNode::GetFunction(std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > const&, tvm::runtime::ObjectPtr<tvm::runtime::Object> const&)
  1: tvm::runtime::PackedFunc tvm::runtime::PackFuncVoidAddr<tvm::runtime::CUDAWrappedFunc>(tvm::runtime::CUDAWrappedFunc, std::vector<DLDataType, std::allocator<DLDataType> > const&)
  0: tvm::runtime::detail::GetArgConvertCode(DLDataType)
  File "/home/masa/projects/dev/tvm/src/runtime/cuda/../pack_args.h", line 146
  File "/home/masa/projects/dev/tvm/src/runtime/library_module.cc", line 80
TVMError:
---------------------------------------------------------------
An error occurred during the execution of TVM.
For more information, please see: https://tvm.apache.org/docs/errors.html
---------------------------------------------------------------

  Check failed: ret == 0 (-1 vs. 0) : TVMError: Cannot handle float16 as device function argument

when tuning winograd fp16 tasks from stable diffusion UNet. Normal convolution works.

@masahi
Member

masahi commented Dec 15, 2022

One workaround for now is disabling CSE-TIR

Yup, can confirm that disabling TIR CSE fixes this issue. It turns out the combination of CSE + SplitHostDevice is generating a signature like this:

extern "C" __global__ void __launch_bounds__(256) main_kernel2(half* __restrict__ inverse, half* __restrict__ bgemm, half cse_var_12, half cse_var_10, half cse_var_3, half cse_var_9, half cse_var_5, half cse_var_8, half cse_var_15, half cse_var_7, half cse_var_6) {

and I'm not sure if this is a good idea.

@shingjan

Same concern here. Right now I am just cherry-picking this PR, and it works. Will hopefully circle back next week.

@masahi
Member

masahi commented Dec 15, 2022

Turns out, for winograd fp16, the issue was due to the following cse_vars

  let cse_var_11 = (0h*2h)
  let cse_var_10 = (0h*2.5h)
  let cse_var_9 = (0h*1h)
  let cse_var_8 = (0h*1.5h)
  let cse_var_7 = (0h*0h)
  let cse_var_6 = (0h*0.5h)
  let cse_var_5 = (0h*-2h)
  let cse_var_4 = (0h*-2.5h)
  let cse_var_3 = (0h*-1h)
  let cse_var_2 = (0h*-1.5h)
  let cse_var_1 = (0h*-0.5h)

which are used by two kernels in winograd, so they are computed on the host once and passed to the winograd kernels via SplitHostDevice. I'm going to update the arithmetic simplifier to remove such useless math (surprised it doesn't do it already).
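As an illustration of the intended fix (a hypothetical snippet, using `tvm.arith.Analyzer` as the simplifier entry point): constant products like (0h*2h) should fold to a single float16 constant instead of surviving as cse_vars.

import tvm
from tvm import tir

# Hypothetical check: a product of float16 constants like (0h*2h) should
# simplify to a single constant rather than survive as a cse_var.
analyzer = tvm.arith.Analyzer()
zero = tir.const(0, "float16")
two = tir.const(2, "float16")
print(analyzer.simplify(zero * two))  # expected: a float16 constant 0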

@AndrewZhaoLuo I wonder, for layer_norm, what variables are being passed between the two kernels? Useless variables like the ones above, or something meaningful? Given that you said disabling CSE works around this issue, maybe it is the former; then improved algebraic simplification could fix it as well.

@masahi
Member

masahi commented Dec 16, 2022

Found a bug #13631

@shingjan This may fix your issue.

@AndrewZhaoLuo
Contributor Author

Can confirm that this was a CSE issue, which is now fixed for all the use cases I care about. masahi has a good summary above of what occurred. I am closing for now to avoid delving deeper.
