- 
                Notifications
    
You must be signed in to change notification settings  - Fork 0
 
[Experimental] Add initial implementation of GSPMD->Shardy pass within PyTorch/XLA #1
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
…sues. Will fix later
… into hshah/add-gspmd-to-shardy-pass
… into hshah/add-gspmd-to-shardy-pass
…xla into hshah/add-gspmd-to-shardy-pass
…sues. Will fix later
…xla into hshah/add-gspmd-to-shardy-pass
…d-gspmd-to-shardy-pass
| return t | ||
| 
               | 
          ||
| op_sharding = mesh.get_op_sharding(partition_spec) | ||
| if os.environ.get('CONVERT_SHLO_TO_SHARDY', False): | 
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
should this be hidden under a different env var or does shardy inherently only understand V2?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Shardy only understands V2. I wasn't able to get the pass working with a V1 graph, and also Kevin mentioned that V2 is a required work item for getting the Shardy pass working: pytorch#9348 (comment)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
ok, awesome, then it can remain under the CONVERT_SHLO_TO_SHARDY flag.
…n PyTorch/XLA (#1) Adds an environment variable CONVERT_SHLO_TO_SHARDY that does 2 things: - Uses V2 sharding annotations when generating the GSPMD SHLO module (i.e., in V1 a mesh annotation string like: devices=[2,1,4]0,1,2,3,4,5,6,7 becomes this in V2: devices=[2,1,4]<=[8]). - Converts the new GSPMD module with the V2 annotations into a Shardy module.
…n PyTorch/XLA (#1) Adds an environment variable CONVERT_SHLO_TO_SHARDY that does 2 things: - Uses V2 sharding annotations when generating the GSPMD SHLO module (i.e., in V1 a mesh annotation string like: devices=[2,1,4]0,1,2,3,4,5,6,7 becomes this in V2: devices=[2,1,4]<=[8]). - Converts the new GSPMD module with the V2 annotations into a Shardy module.
…n PyTorch/XLA (#1) Adds an environment variable CONVERT_SHLO_TO_SHARDY that does 2 things: - Uses V2 sharding annotations when generating the GSPMD SHLO module (i.e., in V1 a mesh annotation string like: devices=[2,1,4]0,1,2,3,4,5,6,7 becomes this in V2: devices=[2,1,4]<=[8]). - Converts the new GSPMD module with the V2 annotations into a Shardy module.
…n PyTorch/XLA (#1) Adds an environment variable CONVERT_SHLO_TO_SHARDY that does 2 things: - Uses V2 sharding annotations when generating the GSPMD SHLO module (i.e., in V1 a mesh annotation string like: devices=[2,1,4]0,1,2,3,4,5,6,7 becomes this in V2: devices=[2,1,4]<=[8]). - Converts the new GSPMD module with the V2 annotations into a Shardy module.
…n PyTorch/XLA (#1) Adds an environment variable CONVERT_SHLO_TO_SHARDY that does 2 things: - Uses V2 sharding annotations when generating the GSPMD SHLO module (i.e., in V1 a mesh annotation string like: devices=[2,1,4]0,1,2,3,4,5,6,7 becomes this in V2: devices=[2,1,4]<=[8]). - Converts the new GSPMD module with the V2 annotations into a Shardy module.
Adds an environment variable
CONVERT_SHLO_TO_SHARDYthat does 2 things:devices=[2,1,4]0,1,2,3,4,5,6,7becomes this in V2:devices=[2,1,4]<=[8]).