Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 6 additions & 10 deletions python/mlc_llm/compiler_pass/rewrite_softmax.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,5 @@
"""A compiler pass that rewrites one-shot softmax into two-stage softmax."""

- import math
-
import tvm
from tvm import relax
from tvm.ir.module import IRModule
Expand Down Expand Up @@ -81,8 +79,6 @@ def visit_call_(self, call: relax.Call) -> Expr: # pylint: disable=arguments-re
def _get_lse_and_softmax_func( # pylint: disable=too-many-locals,too-many-statements
target: tvm.target.Target, chunk_size: int
):
- log2e = math.log2(math.exp(1))
-
# pylint: disable=invalid-name
@T.prim_func
def chunk_lse(var_A: T.handle, var_chunked_lse: T.handle): # pylint: disable=too-many-locals
Expand Down Expand Up @@ -117,13 +113,13 @@ def chunk_lse(var_A: T.handle, var_chunked_lse: T.handle): # pylint: disable=to
temp_sum[v0, v1] = T.float32(0)
temp_sum[v0, v1] += T.if_then_else(
v1 * T.int64(chunk_size) + v2 < vocab_size,
- T.exp2((A_pad[v0, v1, v2] - temp_max[v0, v1]) * log2e),
+ T.exp(A_pad[v0, v1, v2] - temp_max[v0, v1]),
T.float32(0),
)
for l0, l1, l2 in T.grid(batch_size, num_chunks, T.int64(1)):
with T.block("log"):
v0, v1, v2 = T.axis.remap("SSS", [l0, l1, l2])
- chunked_lse[v0, v1] = T.log2(temp_sum[v0, v1]) + temp_max[v0, v1] * log2e
+ chunked_lse[v0, v1] = T.log(temp_sum[v0, v1]) + temp_max[v0, v1]

@T.prim_func
def softmax_with_chunked_lse(var_A: T.handle, var_chunked_lse: T.handle, var_softmax: T.handle):
Expand All @@ -148,17 +144,17 @@ def softmax_with_chunked_lse(var_A: T.handle, var_chunked_lse: T.handle, var_sof
v0, v1 = T.axis.remap("SR", [l0, l1])
with T.init():
temp_sum[v0] = T.float32(0)
- temp_sum[v0] += T.exp2(chunked_lse[v0, v1] - temp_max[v0])
+ temp_sum[v0] += T.exp(chunked_lse[v0, v1] - temp_max[v0])
for l0 in T.serial(0, batch_size):
with T.block("log"):
v0 = T.axis.remap("S", [l0])
- lse[v0] = T.log2(temp_sum[v0]) + temp_max[v0]
+ lse[v0] = T.log(temp_sum[v0]) + temp_max[v0]
for l0, l1, l2 in T.grid(batch_size, num_chunks, T.int64(chunk_size)):
with T.block("pad"):
v0, v1, v2 = T.axis.remap("SSS", [l0, l1, l2])
if v1 * T.int64(chunk_size) + v2 < vocab_size:
- softmax[v0, v1 * T.int64(chunk_size) + v2] = T.exp2(
- A[v0, v1 * T.int64(chunk_size) + v2] * log2e - lse[v0]
+ softmax[v0, v1 * T.int64(chunk_size) + v2] = T.exp(
+ A[v0, v1 * T.int64(chunk_size) + v2] - lse[v0]
)

sch = tvm.tir.Schedule(IRModule({"softmax_with_chunked_lse": softmax_with_chunked_lse}))
Expand Down