2 changes: 1 addition & 1 deletion colossalai/inference/README.md
```diff
@@ -32,7 +32,7 @@ Colossal Inference is composed of three main components:
 
 In this section we discuss how the colossal inference works and integrates with the `Shardformer` . The details can be found in our codes.
 
-![Colossal-Inference](https://raw.githubusercontent.com/hpcaitech/public_assets/main/colossalai/img/inference/inference-arch.png)
+<img src="https://raw.githubusercontent.com/hpcaitech/public_assets/main/colossalai/img/inference/inference-arch.png" alt="Colossal-Inference" style="zoom: 33%;"/>
 
 ## Roadmap of our implementation
```
45 changes: 6 additions & 39 deletions colossalai/shardformer/policies/auto_policy.py
```diff
@@ -1,10 +1,8 @@
 import importlib
 from dataclasses import dataclass
-from typing import Optional
 
 import torch.nn as nn
 
-from ..shard.shard_config import ShardConfig
 from .base_policy import Policy
 
 __all__ = ["PolicyLocation", "get_autopolicy", "import_policy"]
@@ -150,39 +148,12 @@ class PolicyLocation:
     ),
 }
 
-_INFER_POLICY_LIST = {
-    # LlaMa
-    "transformers.models.llama.modeling_llama.LlamaModel": PolicyLocation(
-        file_name="llama", class_name="LlamaModelInferPolicy"
-    ),
-    "transformers.models.llama.modeling_llama.LlamaForCausalLM": PolicyLocation(
-        file_name="llama", class_name="LlamaModelInferPolicy"
-    ),
-    # Bloom
-    "transformers.models.bloom.modeling_bloom.BloomModel": PolicyLocation(
-        file_name="bloom", class_name="BloomModelInferPolicy"
-    ),
-    "transformers.models.bloom.modeling_bloom.BloomForCausalLM": PolicyLocation(
-        file_name="bloom", class_name="BloomModelInferPolicy"
-    ),
-    # ChatGLM2
-    "colossalai.shardformer.modeling.chatglm2_6b.modeling_chatglm.ChatGLMModel": PolicyLocation(
-        file_name="chatglm2", class_name="ChatGLM2InferPolicy"
-    ),
-    "colossalai.shardformer.modeling.chatglm2_6b.modeling_chatglm.ChatGLMForConditionalGeneration": PolicyLocation(
-        file_name="chatglm2", class_name="ChatGLM2ForConditionalGenerationInferPolicy"
-    ),
-}
-
 
 
-def import_policy(policy_location: PolicyLocation, inference_only: Optional[bool] = False) -> Policy:
+def import_policy(policy_location: PolicyLocation) -> Policy:
     """
     Dynamically import a Policy class based on the policy location.
     """
-    if inference_only:
-        module_name = f"colossalai.inference.tensor_parallel.policies.{policy_location.file_name}"
-    else:
-        module_name = f"colossalai.shardformer.policies.{policy_location.file_name}"
+    module_name = f"colossalai.shardformer.policies.{policy_location.file_name}"
     module = importlib.import_module(module_name)
     return getattr(module, policy_location.class_name)
```

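With the `inference_only` branch gone, `import_policy` resolves every policy against the single `colossalai.shardformer.policies` namespace. A minimal sketch of that lookup, assuming a hypothetical entry with `file_name="llama"` and `class_name="LlamaPolicy"` exists in `_POLICY_LIST`:

```python
# Standalone sketch of the simplified lookup; "llama"/"LlamaPolicy" are
# assumptions for illustration, not confirmed _POLICY_LIST contents.
import importlib


def resolve_policy_class(file_name: str, class_name: str):
    # Single module root after this PR: colossalai.shardformer.policies.<file_name>
    module = importlib.import_module(f"colossalai.shardformer.policies.{file_name}")
    return getattr(module, class_name)


policy_cls = resolve_policy_class("llama", "LlamaPolicy")  # hypothetical entry
policy = policy_cls()
```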
Further down in the same file, `get_autopolicy` loses its `shard_config` parameter and the inference-only lookup path:

```diff
@@ -198,7 +169,7 @@ def _fullname(obj):
     return module + "." + klass.__qualname__
 
 
-def get_autopolicy(model: nn.Module, shard_config: ShardConfig = None) -> Policy:
+def get_autopolicy(model: nn.Module) -> Policy:
     r"""
     Return the auto policy for the model
 
@@ -209,16 +180,12 @@ def get_autopolicy(model: nn.Module, shard_config: ShardConfig = None) -> Policy
         :class:`Policy`: The auto policy for the model
     """
     full_name = _fullname(model)
-    inference_only = shard_config.extra_kwargs.get("inference_only", None)
-    if inference_only:
-        policy_location = _INFER_POLICY_LIST.get(full_name, None)
-    else:
-        policy_location = _POLICY_LIST.get(full_name, None)
+    policy_location = _POLICY_LIST.get(full_name, None)
 
     if policy_location is None:
         raise NotImplementedError(
-            f"Auto policy for {model.__class__.__qualname__} is not implemented\n. Supported models are {list(_POLICY_LIST.keys())} and {list(_INFER_POLICY_LIST.keys())}"
+            f"Auto policy for {model.__class__.__qualname__} is not implemented\n. Supported models are {list(_POLICY_LIST.keys())}"
        )
     else:
-        policy = import_policy(policy_location, inference_only)
+        policy = import_policy(policy_location)
     return policy()
```
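The lookup key is still the model's fully qualified class name, as computed by `_fullname` above. A runnable sketch of that key derivation, using `torch.nn.Linear` as a stand-in because it needs no extra dependencies:

```python
# Mirrors the _fullname logic shown in the hunk above; Linear is only a
# stand-in for a supported transformer model.
import torch.nn as nn


def fullname(obj) -> str:
    klass = obj.__class__
    return klass.__module__ + "." + klass.__qualname__


print(fullname(nn.Linear(2, 2)))  # torch.nn.modules.linear.Linear
# get_autopolicy(model) uses this string to index _POLICY_LIST and raises
# NotImplementedError when the class is not registered.
```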
2 changes: 1 addition & 1 deletion colossalai/shardformer/shard/sharder.py
```diff
@@ -28,7 +28,7 @@ class ModelSharder(object):
     def __init__(self, model: nn.Module, policy: Policy, shard_config: ShardConfig = None) -> None:
         self.model = model
         self.shard_config = shard_config
-        self.policy = get_autopolicy(self.model, shard_config) if policy is None else policy
+        self.policy = get_autopolicy(self.model) if policy is None else policy
 
     def shard(self) -> List[Dict[int, Tensor]]:
         r"""
```
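For context, a hedged sketch of the fallback this line implements: with `policy=None`, the sharder now derives the policy from the model alone. Import paths follow the file paths in this diff; the tiny Llama model is illustrative and assumes `LlamaForCausalLM` is registered in `_POLICY_LIST`:

```python
# Sketch only: assumes transformers is installed and LlamaForCausalLM has a
# registered policy; config sizes are shrunk just to keep the example cheap.
from transformers import LlamaConfig, LlamaForCausalLM

from colossalai.shardformer.shard.shard_config import ShardConfig
from colossalai.shardformer.shard.sharder import ModelSharder

model = LlamaForCausalLM(
    LlamaConfig(hidden_size=64, intermediate_size=128, num_hidden_layers=2, num_attention_heads=4)
)
# policy=None triggers get_autopolicy(model); shard_config is no longer
# consulted for policy selection. TP disabled here to avoid needing an
# initialized process group in this sketch.
sharder = ModelSharder(model=model, policy=None, shard_config=ShardConfig(enable_tensor_parallelism=False))
```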
1 change: 0 additions & 1 deletion tests/test_infer/_utils.py
```diff
@@ -19,7 +19,6 @@ def build_model(
         enable_tensor_parallelism=enable_tensor_parallelism,
         enable_flash_attention=enable_flash_attention,
         enable_jit_fused=enable_jit_fused,
-        extra_kwargs={"inference_only": True},
    )
    model_copy = copy.deepcopy(org_model)
    shard_former = ShardFormer(shard_config=shard_config)
```
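The test helper's config construction accordingly shrinks to the feature flags alone. A hedged sketch of the resulting call, with flag values illustrative and the top-level import path assumed:

```python
# Sketch of the post-PR build in tests/test_infer/_utils.py; flag values are
# illustrative, and the public import path is an assumption.
from colossalai.shardformer import ShardConfig, ShardFormer

shard_config = ShardConfig(
    enable_tensor_parallelism=True,
    enable_flash_attention=False,
    enable_jit_fused=False,
    # extra_kwargs={"inference_only": True} is gone with this PR
)
shard_former = ShardFormer(shard_config=shard_config)
```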