@Lunderberg
Contributor

If `R.matmul(x, R.take(weights, indices))` occurs, with `R.take` selecting along the output feature dimension, it can be rearranged to `R.take(R.matmul(x, weights), indices)`.
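
As a rough illustration of why the two orderings are equivalent, here is a NumPy analogue rather than Relax code. The shapes, the 2-D layout, and the use of `axis=1` as the output feature dimension are assumptions made for this sketch, not details taken from the pass itself.

```python
import numpy as np

rng = np.random.default_rng(0)

# Assumed shapes, chosen only for illustration.
x = rng.standard_normal((4, 16))         # [batch, in_features]
weights = rng.standard_normal((16, 32))  # [in_features, out_features]
indices = np.array([3, 7, 7, 30])        # selected output features

# Original order: gather columns of the weights, then multiply.
take_then_matmul = x @ np.take(weights, indices, axis=1)

# Reordered: multiply by the full weights, then gather output columns.
matmul_then_take = np.take(x @ weights, indices, axis=1)

# Each output column depends on exactly one weight column, so the
# gather commutes with the matmul (up to floating-point rounding).
assert take_then_matmul.shape == (4, 4)
assert np.allclose(take_then_matmul, matmul_then_take)
```

The reordered form multiplies by all 32 weight columns instead of the 4 selected ones, so on its own it does more arithmetic.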

@tqchen
Member

tqchen commented Jan 11, 2024

What is the use case for this pass? It seems such reordering would increase the amount of compute.

@Lunderberg
Contributor Author

This pass would be part of an optimization pipeline that could be used for batched LoRA models. In this usage, the `R.take(weights, indices)` would select which LoRA should be used for each prompt in the batch. By rearranging the matmul, the resulting `R.matmul(x, weights)` could be combined with `R.matmul(x, base_weights)` using `CombineParallelMatmul`.

While this individual pass does increase the total amount of compute, the overall pipeline should (depending on model size, number of LoRAs, and batch size) improve performance by changing from three matmuls (one large matmul with the base weights and two small matmuls with the LoRA components) to two matmuls (one large matmul with `concat(base_weights, lora_a)` and one small matmul with `lora_b`).
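
A minimal NumPy sketch of the three-matmul versus two-matmul forms, assuming a single LoRA with `x` of shape `[batch, in_features]`, `lora_a` of shape `[in_features, rank]`, and `lora_b` of shape `[rank, out_features]`. These shapes and the concatenation axis are assumptions for illustration, and the per-prompt `R.take` selection is left out to keep the example short.

```python
import numpy as np

rng = np.random.default_rng(0)
batch, in_features, out_features, rank = 4, 64, 64, 8  # assumed sizes

x = rng.standard_normal((batch, in_features))
base_weights = rng.standard_normal((in_features, out_features))
lora_a = rng.standard_normal((in_features, rank))
lora_b = rng.standard_normal((rank, out_features))

# Three matmuls: one large (base) and two small (LoRA A, then LoRA B).
y_three = x @ base_weights + (x @ lora_a) @ lora_b

# Two matmuls: one large matmul against the concatenated weights,
# then one small matmul with lora_b on the sliced-out LoRA columns.
combined = np.concatenate([base_weights, lora_a], axis=1)
z = x @ combined
base_out, a_out = z[:, :out_features], z[:, out_features:]
y_two = base_out + a_out @ lora_b

assert np.allclose(y_three, y_two)
```

Whether the fused form is actually faster would depend on the sizes involved, as noted above.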

@tqchen
Member

tqchen commented Jan 11, 2024

I feel that for LoRA, the best approach may be to use some specialized kernels. cc @yzh119 to see if you have more suggestions based on your experience.

Contributor

@slyubomirsky slyubomirsky left a comment

Similar to the discussion in #16313, I don't think it would be harmful to include this pass as a tool for other transformations even if we don't plan to use it in the standard compilation immediately.

@tqchen
Member

tqchen commented Jan 19, 2024

I agree, as long as it is optional.

@Lunderberg
Contributor Author

Sounds good, and thank you! I should have led with this being an optional pass to be used as needed, not one intended for the default pipeline.

@Lunderberg Lunderberg force-pushed the unity_reorder_take_after_matmul branch from 17e1ac7 to 1581cf2 Compare January 22, 2024 17:55
@Lunderberg
Contributor Author

Rebased onto main, now that the main branch includes unity. No other changes made, just wanting to avoid stale CI before merge.

@Lunderberg Lunderberg changed the base branch from unity to main January 22, 2024 17:58
@Lunderberg Lunderberg force-pushed the unity_reorder_take_after_matmul branch from 1581cf2 to 65ae3c9 Compare January 23, 2024 16:20
@Lunderberg Lunderberg merged commit 2c49e01 into apache:main Jan 23, 2024
@Lunderberg Lunderberg deleted the unity_reorder_take_after_matmul branch January 23, 2024 22:43
3 participants