
[Bug] gelu_new models on Metal with f32 - "Output probabilities are all NaNs" #1505

@CharlieFRuan


🐛 Bug

For models that use gelu_new (gpt_neox, phi), running in f32 on Metal fails with:

tvm-unity/src/runtime/relax_vm/lm_support.cc:487: The output probabilities are all NaNs, can not sample from it

With regular gelu the issue does not occur (gelu_new and gelu_tanh are synonymous), i.e.:

  • hidden_states = op.gelu(hidden_states, approximate="tanh") runs into the issue
  • hidden_states = op.gelu(hidden_states) works fine
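
For reference, the two variants above can be sketched in plain Python. This is only the textbook math for the erf-based GELU and its tanh approximation, written with the standard library; the function names are made up for illustration, and the sketch says nothing about how TVM lowers either op on Metal. It does suggest that the two formulas agree closely and stay finite even for large inputs when evaluated on the CPU, so the NaNs are more likely to come from how the tanh path is executed on the backend than from the formula itself.

```python
import math


def gelu_erf(x: float) -> float:
    """Exact GELU: 0.5 * x * (1 + erf(x / sqrt(2)))."""
    return 0.5 * x * (1.0 + math.erf(x / math.sqrt(2.0)))


def gelu_tanh(x: float) -> float:
    """Tanh approximation of GELU (the variant called gelu_new / gelu_tanh)."""
    inner = math.sqrt(2.0 / math.pi) * (x + 0.044715 * x ** 3)
    return 0.5 * x * (1.0 + math.tanh(inner))


if __name__ == "__main__":
    # Compare both variants, including large-magnitude inputs.
    for x in (-100.0, -3.0, -1.0, 0.0, 1.0, 3.0, 100.0):
        print(f"x={x:8.1f}  erf={gelu_erf(x):.6f}  tanh={gelu_tanh(x):.6f}")
```

Note that the argument passed to tanh grows with x cubed, so the tanh variant feeds much larger magnitudes into tanh than the erf variant feeds into erf. Whether that matters on Metal is an open question here, not something this sketch establishes.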

The issue is only observed on Metal (M3 MacBook, M2 Mac Studio). CUDA, Vulkan, and Intel Macs are all fine.

This probably needs step-by-step debugging to find where the values first become NaN.
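
One way to approach that debugging, as a minimal sketch: dump the intermediate outputs in execution order and scan for the first stage that contains a non-finite value. The helper below is hypothetical and uses plain Python lists. The stage names and values are made up for illustration, and how to extract intermediates from the compiled model is not shown.

```python
import math
from typing import Iterable, Optional, Sequence, Tuple


def first_nonfinite(
    stages: Iterable[Tuple[str, Sequence[float]]]
) -> Optional[Tuple[str, int, float]]:
    """Return (stage_name, index, value) for the first NaN/Inf, or None."""
    for name, values in stages:
        for i, v in enumerate(values):
            if not math.isfinite(v):
                return name, i, v
    return None


if __name__ == "__main__":
    # Hypothetical intermediates, listed in the order they are computed.
    stages = [
        ("fc1_out", [0.5, -1.25, 3.0]),
        ("gelu_out", [0.35, float("nan"), 2.99]),
        ("fc2_out", [float("nan"), float("nan"), float("nan")]),
    ]
    print(first_nonfinite(stages))  # ('gelu_out', 1, nan)
```

Reporting the earliest offending stage matters because NaNs propagate: once one appears, every later stage is contaminated, so only the first hit points at the culprit.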

To Reproduce

Steps to reproduce the behavior:

  1. Compile phi-2 or gpt2 with f32 on Metal:
  • python -m mlc_chat gen_config dist/models/phi-2 --quantization q0f32 -o dist/phi-2-q0f32/ --conv-template phi-2
  • python -m mlc_chat compile dist/phi-2-q0f32/mlc-chat-config.json --device metal -o dist/libs/phi-2-q0f32-metal.so
  • python -m mlc_chat convert_weight dist/models/phi-2/ --quantization q0f32 -o dist/phi-2-q0f32/
  2. Run the model (either CLI or Python):
  • ./build/mlc_chat_cli --model dist/phi-2-q0f32 --model-lib-path dist/libs/phi-2-q0f32-metal.so
Instruct: Hi
[04:50:15] /Users/cfruan/Documents/tvm-unity/src/runtime/relax_vm/lm_support.cc:487: The output probabilities are all NaNs, can not sample from it
Stack trace:
  [bt] (0) 1   libtvm_runtime.dylib                0x0000000100abcf3c tvm::runtime::detail::LogFatal::Entry::Finalize() + 68
  [bt] (1) 2   libtvm_runtime.dylib                0x0000000100abcef8 tvm::runtime::detail::LogFatal::Entry::Finalize() + 0
  [bt] (2) 3   libtvm_runtime.dylib                0x0000000100ab6d58 __clang_call_terminate + 0
  [bt] (3) 4   libtvm_runtime.dylib                0x0000000100b76d08 tvm::runtime::relax_vm::SampleTopPFromProb(tvm::runtime::NDArray, double, double) + 1244
  [bt] (4) 5   libtvm_runtime.dylib                0x0000000100b8088c void tvm::runtime::TypedPackedFunc<int (tvm::runtime::NDArray, double, double)>::AssignTypedLambda<int (*)(tvm::runtime::NDArray, double, double)>(int (*)(tvm::runtime::NDArray, double, double), std::__1::basic_string<char, std::__1::char_traits<char>, std::__1::allocator<char>>)::'lambda'(tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*)::operator()(tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*) const + 188
  [bt] (5) 6   libmlc_llm.dylib                    0x0000000101206f98 mlc::llm::LLMChat::SampleFromProbOnCPU(float) + 256
  [bt] (6) 7   libmlc_llm.dylib                    0x000000010120557c mlc::llm::LLMChat::SampleTokenFromLogits(tvm::runtime::NDArray, std::__1::unordered_map<std::__1::basic_string<char, std::__1::char_traits<char>, std::__1::allocator<char>>, picojson::value, std::__1::hash<std::__1::basic_string<char, std::__1::char_traits<char>, std::__1::allocator<char>>>, std::__1::equal_to<std::__1::basic_string<char, std::__1::char_traits<char>, std::__1::allocator<char>>>, std::__1::allocator<std::__1::pair<std::__1::basic_string<char, std::__1::char_traits<char>, std::__1::allocator<char>> const, picojson::value>>>) + 520
  [bt] (7) 8   libmlc_llm.dylib                    0x0000000101208190 mlc::llm::LLMChat::PrefillStep(std::__1::basic_string<char, std::__1::char_traits<char>, std::__1::allocator<char>>, bool, bool, mlc::llm::PlaceInPrompt, tvm::runtime::String) + 1920
  [bt] (8) 9   libmlc_llm.dylib                    0x0000000101207654 mlc::llm::LLMChatModule::GetFunction(tvm::runtime::String const&, tvm::runtime::ObjectPtr<tvm::runtime::Object> const&)::'lambda3'(tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*)::operator()(tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*) const + 168

Labels: bug (Confirmed bugs)
