-
Notifications
You must be signed in to change notification settings - Fork 370
fix(module_fallback): Catching recursive search if method doesnt exist #619
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
There are some changes that do not conform to C++ style guidelines:
diff --git a/workspace/core/lowering/lowering.cpp b/tmp/changes.txt
index 42be486..dae163a 100644
--- a/workspace/core/lowering/lowering.cpp
+++ b/tmp/changes.txt
@@ -60,10 +60,7 @@ void LowerGraph(std::shared_ptr<torch::jit::Graph>& g, LowerInfo lower_info) {
LOG_GRAPH(*g);
}
-torch::jit::Module LowerModule(
- const torch::jit::Module& mod,
- std::string method_name,
- const LowerInfo& lower_info) {
+torch::jit::Module LowerModule(const torch::jit::Module& mod, std::string method_name, const LowerInfo& lower_info) {
std::unordered_set<std::string> forced_fallback_modules(
lower_info.forced_fallback_modules.begin(), lower_info.forced_fallback_modules.end());
passes::NotateModuleForFallback(mod, "", method_name, forced_fallback_modules);
ERROR: Some files do not conform to style guidelinesThere was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
There are some changes that do not conform to C++ style guidelines:
diff --git a/workspace/core/lowering/lowering.cpp b/tmp/changes.txt
index 42be486..dae163a 100644
--- a/workspace/core/lowering/lowering.cpp
+++ b/tmp/changes.txt
@@ -60,10 +60,7 @@ void LowerGraph(std::shared_ptr<torch::jit::Graph>& g, LowerInfo lower_info) {
LOG_GRAPH(*g);
}
-torch::jit::Module LowerModule(
- const torch::jit::Module& mod,
- std::string method_name,
- const LowerInfo& lower_info) {
+torch::jit::Module LowerModule(const torch::jit::Module& mod, std::string method_name, const LowerInfo& lower_info) {
std::unordered_set<std::string> forced_fallback_modules(
lower_info.forced_fallback_modules.begin(), lower_info.forced_fallback_modules.end());
passes::NotateModuleForFallback(mod, "", method_name, forced_fallback_modules);
ERROR: Some files do not conform to style guidelinesThere was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Code conforms to Python style guidelines
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Code conforms to Python style guidelines
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Code conforms to Python style guidelines
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
There are some changes that do not conform to C++ style guidelines:
diff --git a/workspace/core/lowering/lowering.cpp b/tmp/changes.txt
index 42be486..dae163a 100644
--- a/workspace/core/lowering/lowering.cpp
+++ b/tmp/changes.txt
@@ -60,10 +60,7 @@ void LowerGraph(std::shared_ptr<torch::jit::Graph>& g, LowerInfo lower_info) {
LOG_GRAPH(*g);
}
-torch::jit::Module LowerModule(
- const torch::jit::Module& mod,
- std::string method_name,
- const LowerInfo& lower_info) {
+torch::jit::Module LowerModule(const torch::jit::Module& mod, std::string method_name, const LowerInfo& lower_info) {
std::unordered_set<std::string> forced_fallback_modules(
lower_info.forced_fallback_modules.begin(), lower_info.forced_fallback_modules.end());
passes::NotateModuleForFallback(mod, "", method_name, forced_fallback_modules);
ERROR: Some files do not conform to style guidelinesc049894 to
29ded84
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Code conforms to Python style guidelines
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Code conforms to C++ style guidelines
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Code conforms to Python style guidelines
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Code conforms to C++ style guidelines
|
|
||
| for (const auto sub_mod : mod.named_children()) { | ||
| NotateModuleForFallback(sub_mod.value, sub_mod.name, method_name, forced_fallback_modules); | ||
| if (sub_mod.value.find_method(method_name)) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Is there ever a case where we might pass in a name other than "forward" as the method name here? In that case, would the desired behavior be to only recurse through submodules with that specific method (and ignore others, and their children)?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
What we really should be doing is finding the Method calls and using that to recurse vs all children probably
peri044
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
One minor change maybe when forced_fallback_modules is empty ?
PR runs fine and BERT 1.9 graph passes
core/lowering/lowering.cpp
Outdated
| torch::jit::Module LowerModule(const torch::jit::Module& mod, std::string method_name, const LowerInfo& lower_info) { | ||
| std::unordered_set<std::string> forced_fallback_modules( | ||
| lower_info.forced_fallback_modules.begin(), lower_info.forced_fallback_modules.end()); | ||
| passes::NotateModuleForFallback(mod, "", method_name, forced_fallback_modules); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
When the forced_fallback_modules is empty, we don't need to have this pass so we can disable it. This can prevent iterating through all the nodes in graph. Disabling this also fixed BERT issue but your other fix is necessary.
|
Can we merge this ? Seems like we are getting duplicate "Method forward is not defined" bugs. |
|
Yeah when can we expect to get this patch as a new version? |
|
We are trying to get out a patch release in the next couple weeks. The issue is I think this WAR is probably not the correct solution. I am trying to determine if a formal solution could be implemented in a reasonable amount of time |
This commit fixes module level fallback by using method calls to determine modules to recurse down too. This should be robust to names other than forward used for methods as well as ignoring functional modules. Signed-off-by: Naren Dasan <[email protected]> Signed-off-by: Naren Dasan <[email protected]>
f7240e3 to
f94ae8f
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Code conforms to C++ style guidelines
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Code conforms to C++ style guidelines
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Code conforms to Python style guidelines
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Code conforms to Python style guidelines
peri044
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM. Tests pass
|
@peri044 can you also test with a network you know failed previously? Also try with module level fallback specifying a module that is not part of the module to be compiled |
|
@narendasan Tested a network (DETR) which failed previously. Tried a module which is not a part of the graph. Works fine. |
Description
There was an issue where the module fallback notation pass just assumes
methods exist. This should catch that. This might be something we may want to release as a patch
Fixes #613, (potentially #609, Also probably partially solves: #608)
Type of change
Please delete options that are not relevant and/or add your own.
Checklist: