Fixing datatype error for gpt-2 #18328
Merged
Before this PR I was getting the following error:
Traceback (most recent call last):
  File "/home/thais/Dev/TVM-LLM-Bench/test-export-gpt.py", line 35, in <module>
    mod: tvm.IRModule = from_exported_program(exported_program)
  File "/home/thais/Dev/new-tvm/tvm/python/tvm/relax/frontend/torch/exported_program_translator.py", line 806, in from_exported_program
    return ExportedProgramImporter().from_exported_program(
  File "/home/thais/Dev/new-tvm/tvm/python/tvm/relax/frontend/torch/exported_program_translator.py", line 697, in from_exported_program
    self.env[node] = self.convert_map[func_name](node)
  File "/home/thais/Dev/new-tvm/tvm/python/tvm/relax/frontend/torch/base_fx_graph_translator.py", line 409, in convert
    return call_binary_op(relax_op, lhs, rhs)
  File "/home/thais/Dev/new-tvm/tvm/python/tvm/relax/frontend/torch/base_fx_graph_translator.py", line 405, in call_binary_op
    return self.block_builder.emit(op(lhs, rhs))
  File "/home/thais/Dev/new-tvm/tvm/python/tvm/relax/block_builder.py", line 323, in emit
    return _ffi_api.BlockBuilderEmit(self, expr, name_hint)  # type: ignore
  File "tvm/ffi/cython/./function.pxi", line 228, in tvm.ffi.core.Function.__call__
  File "tvm/ffi/cython/core.cpp", line 24315, in __pyx_pw_3tvm_3ffi_4core_8Function_1__call__
  File "tvm/ffi/cython/core.cpp", line 24369, in __pyx_pf_3tvm_3ffi_4core_8Function___call__
  File "tvm/ffi/cython/core.cpp", line 23995, in __pyx_f_3tvm_3ffi_4core_FuncCall
  File "tvm/ffi/cython/core.cpp", line 23890, in __pyx_f_3tvm_3ffi_4core_FuncCall3
TypeError: Binary operators must have the same datatype for both operands. However, R.multiply(lv15, lv18) uses datatype float32 on the LHS (StructInfo of R.Tensor((14, 14), dtype="float32")), and datatype bool on the RHS (StructInfo of R.Tensor((14, 14), dtype="bool")).
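For context, here is a minimal sketch of the export path that hits this error. Only the `from_exported_program` call and the script name `test-export-gpt.py` come from the traceback above; the HuggingFace model choice, tokenizer, and input text are assumptions used for illustration.

```python
# Sketch of test-export-gpt.py, assuming the stock HuggingFace GPT-2 model.
import torch
from torch.export import export
from transformers import GPT2LMHeadModel, GPT2Tokenizer

import tvm
from tvm.relax.frontend.torch import from_exported_program

tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
model = GPT2LMHeadModel.from_pretrained("gpt2").eval()

# Any short prompt works; the (14, 14) shapes in the error correspond to the
# sequence length of the traced inputs.
inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")

with torch.no_grad():
    exported_program = export(model, (inputs["input_ids"],))

# Before this fix, this call raised the TypeError shown in the traceback above:
# the converter emitted R.multiply with a float32 tensor on one side and a
# bool mask tensor on the other.
mod: tvm.IRModule = from_exported_program(exported_program)
```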