@ibsidorenko (Contributor) commented Apr 16, 2024

This commit replaces the fp16 compute dtype and scale dtype with fp32 in the cuBLAS matmul.

According to the cuBLAS docs, there are two possible options for the compute/scale dtype when the input/output dtype is fp16:

  1. compute dtype is fp16 and scale dtype is fp16
  2. compute dtype is fp32 and scale dtype is fp32

By default, we use option 1 in apache/tvm and option 2 in octoml/tvm. This commit aligns the two behaviours and sets fp32 as the default.
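
To illustrate why the compute dtype matters even when inputs and outputs stay in fp16, here is a minimal pure-Python sketch. It does not call cuBLAS or TVM; it only models the accumulator rounding using the standard `struct` module, and the helper names (`to_fp16`, `to_fp32`, `dot_fp16_inputs`) are hypothetical. Real GPU kernels may accumulate in a different order, so treat this as an illustration of the rounding effect, not a reproduction of cuBLAS results.

```python
import struct


def to_fp16(x: float) -> float:
    """Round x to the nearest IEEE 754 half-precision value."""
    return struct.unpack("<e", struct.pack("<e", x))[0]


def to_fp32(x: float) -> float:
    """Round x to the nearest IEEE 754 single-precision value."""
    return struct.unpack("<f", struct.pack("<f", x))[0]


def dot_fp16_inputs(a, b, round_acc):
    """Dot product with fp16 inputs and an fp16 output.

    round_acc models the compute dtype: it is applied to each product
    and to the running sum after every accumulation step.
    """
    acc = 0.0
    for x, y in zip(a, b):
        prod = round_acc(to_fp16(x) * to_fp16(y))
        acc = round_acc(acc + prod)
    return to_fp16(acc)  # the result is stored as fp16 in both cases


ones = [1.0] * 4096

# Option 1 (fp16 compute): the running sum stalls at 2048, because
# 2048 + 1 is not representable in fp16 and rounds back to 2048.
print(dot_fp16_inputs(ones, ones, to_fp16))  # 2048.0

# Option 2 (fp32 compute): the sum is exact and fits in the fp16 output.
print(dot_fp16_inputs(ones, ones, to_fp32))  # 4096.0
```

The example is deliberately extreme; for typical matmul reductions the error from fp16 accumulation is smaller, but it grows with the reduction length in the same way.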

cc @vinx13 @masahi

@github-actions github-actions bot requested review from masahi and vinx13 April 16, 2024 15:35
@vinx13 vinx13 merged commit 08965f0 into apache:main Apr 16, 2024
@ibsidorenko ibsidorenko deleted the cublas-fp32-compute-dtype branch April 17, 2024 08:05