cpp/serve/model.cc (1 addition & 1 deletion)

@@ -64,7 +64,7 @@ class ModelImpl : public ModelObj {
         memory::MemoryManager::GetOrCreateAllocator(device_host, memory::AllocatorType::kNaive);
     ICHECK_NOTNULL(allocator);
     token_ids_storage_ =
-        memory::Storage(allocator->Alloc(device_host, {prefill_chunk_size_}, DataType::Int(32)));
+        memory::Storage(allocator->Alloc(device_host, {prefill_chunk_size_}, DataType::Int(32)), allocator);
     this->logit_pos_arr_ = NDArray::Empty({max_num_sequence}, DataType::Int(32), device_host);
   }
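The extra `allocator` argument appears to hand `memory::Storage` a handle to its owning allocator, presumably so the token-ids buffer can be freed back to that allocator when the storage is destroyed; this reading is inferred from the signature change alone.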
python/mlc_llm/compiler_pass/pipeline.py (8 additions & 0 deletions)

@@ -9,6 +9,7 @@
 from tvm.relax import register_pipeline  # pylint: disable=no-name-in-module
 from tvm.relax.frontend import nn
 
+from mlc_llm.interface.compiler_flags import AllReduceStrategyType
 from mlc_llm.support import logging
 
 from .attach_embedding_allocator import AttachAllocEmbeddingTensorFunc
@@ -75,6 +76,7 @@ def _mlc_llm_pipeline(  # pylint: disable=too-many-arguments
     flashinfer: bool = False,
     cublas_gemm: bool = False,
     faster_transformer: bool = False,  # pylint: disable=unused-argument
+    allreduce_strategy: AllReduceStrategyType = AllReduceStrategyType.RING,
     variable_bounds: Dict[str, int] = None,
     additional_tirs: Dict[str, tvm.tir.PrimFunc] = None,
     metadata: Dict[str, Any] = None,
@@ -147,7 +149,13 @@ def _pipeline(mod: tvm.ir.IRModule, _ctx: tvm.transform.PassContext) -> tvm.ir.IRModule:
                 tvm.relax.transform.ToNonDataflow(),
                 tvm.relax.transform.RemovePurityChecking(),
                 tvm.relax.transform.CallTIRRewrite(),
+                (
+                    tvm.relax.transform.IPCAllReduceRewrite(allreduce_strategy)
+                    if allreduce_strategy != AllReduceStrategyType.RING
+                    else tvm.transform.Sequential([])
+                ),
                 tvm.relax.transform.StaticPlanBlockMemory(),
+                tvm.relax.transform.LowerGPUIPCAllocStorage(),
                 AttachMetadataWithMemoryUsage(metadata),
                 tvm.relax.transform.RewriteCUDAGraph(),
                 tvm.relax.transform.LowerAllocTensor(),
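The conditional slot is the interesting part of this hunk: ring all-reduce stays on the default lowering path, while one-shot/two-shot splice the IPC rewrite in before memory planning. A minimal sketch isolating that pattern (names taken from the diff; the surrounding pass list is elided):

```python
import tvm
from mlc_llm.interface.compiler_flags import AllReduceStrategyType


def _allreduce_slot(allreduce_strategy: AllReduceStrategyType):
    """Return the pass occupying the all-reduce slot in the Sequential above.

    RING keeps the default lowering, so its slot is filled with an empty
    (no-op) Sequential to keep the pass list well-formed; ONESHOT and
    TWOSHOT rewrite all-reduce calls into the IPC-memory-based kernels,
    whose allocations LowerGPUIPCAllocStorage then lowers.
    """
    if allreduce_strategy != AllReduceStrategyType.RING:
        return tvm.relax.transform.IPCAllReduceRewrite(allreduce_strategy)
    return tvm.transform.Sequential([])
```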
python/mlc_llm/interface/compile.py (1 addition & 0 deletions)

@@ -184,6 +184,7 @@ def _find_kv_cache_bytes(model: nn.Module, model_config) -> int:
         flashinfer=args.opt.flashinfer,
         cublas_gemm=args.opt.cublas_gemm,
         faster_transformer=args.opt.faster_transformer,
+        allreduce_strategy=args.opt.allreduce_strategy,
         variable_bounds=variable_bounds,
         additional_tirs=additional_tirs,
         ext_mods=ext_mods,
python/mlc_llm/interface/compiler_flags.py (20 additions & 0 deletions)

@@ -1,6 +1,7 @@
 """Flags for overriding model config."""
 
 import dataclasses
+import enum
 import re
 from io import StringIO
 from typing import Optional
@@ -13,6 +14,14 @@
 logger = logging.getLogger(__name__)
 
 
+class AllReduceStrategyType(enum.IntEnum):
+    """The all-reduce strategy."""
+
+    RING = 0
+    ONESHOT = 1
+    TWOSHOT = 2
+
+
 @dataclasses.dataclass
 class OptimizationFlags:
     """Optimization flags"""
@@ -22,6 +31,7 @@ class OptimizationFlags:
     faster_transformer: bool = False
     cudagraph: bool = False
     cutlass: bool = False
+    allreduce_strategy: AllReduceStrategyType = AllReduceStrategyType.RING
 
     def __repr__(self) -> str:
         out = StringIO()
@@ -30,6 +40,7 @@ def __repr__(self) -> str:
         print(f";faster_transformer={int(self.faster_transformer)}", file=out, end="")
         print(f";cudagraph={int(self.cudagraph)}", file=out, end="")
         print(f";cutlass={int(self.cutlass)}", file=out, end="")
+        print(f";allreduce_strategy={self.allreduce_strategy.name}", file=out, end="")
         return out.getvalue().rstrip()
 
     @staticmethod
@@ -52,13 +63,22 @@ def boolean(value: str) -> bool:
         parser.add_argument("--faster_transformer", type=boolean, default=False)
         parser.add_argument("--cudagraph", type=boolean, default=False)
         parser.add_argument("--cutlass", type=boolean, default=False)
+        parser.add_argument(
+            "--allreduce-strategy",
+            type=str,
+            choices=["ring", "one-shot", "two-shot"],
+            default="ring",
+        )
         results = parser.parse_args([f"--{i}" for i in source.split(";") if i])
         return OptimizationFlags(
             flashinfer=results.flashinfer,
             cublas_gemm=results.cublas_gemm,
             faster_transformer=results.faster_transformer,
             cudagraph=results.cudagraph,
             cutlass=results.cutlass,
+            allreduce_strategy=AllReduceStrategyType[
+                results.allreduce_strategy.replace("-", "").upper()
+            ],
         )
 
     def update(self, target, quantization) -> None:
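A hypothetical round-trip check, assuming `from_str` consumes the usual semicolon-separated `key=value` string: the parser expects the dashed lower-case spelling (`two-shot`) and maps it onto the enum name via `replace("-", "").upper()`, whereas `__repr__` prints the bare enum name (`TWOSHOT`):

```python
from mlc_llm.interface.compiler_flags import AllReduceStrategyType, OptimizationFlags

# Parse a detailed optimization-flags string containing the new option.
flags = OptimizationFlags.from_str("cudagraph=1;allreduce-strategy=two-shot")
assert flags.allreduce_strategy is AllReduceStrategyType.TWOSHOT

# "two-shot" -> "TWOSHOT" is exactly the transformation from_str applies:
assert "two-shot".replace("-", "").upper() == "TWOSHOT"
```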