Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 11 additions & 3 deletions tensorrt_llm/auto_parallel/parallelization.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import contextlib
import copy
import itertools
import pickle # nosec B403
import re
from abc import ABC, abstractmethod
from dataclasses import dataclass, field
Expand All @@ -12,6 +11,7 @@
import torch
from filelock import FileLock

from tensorrt_llm import serialization
from tensorrt_llm._utils import (str_dtype_to_trt, trt_dtype_to_np,
trt_dtype_to_torch)
from tensorrt_llm.functional import AllReduceParams, create_allreduce_plugin
Expand Down Expand Up @@ -39,6 +39,13 @@

default_int_dtype = trt.int64

# These dataclasses are used in ParallelConfig serialization. If there are
# other classes that need to be serialized, please add them to this list.
# Maps module path -> list of class names approved for deserialization
# (passed as `approved_imports` to serialization.load).
BASE_AUTOPP_CLASSES = {
    "tensorrt_llm.auto_parallel.parallelization": ["ParallelConfig"],
    "tensorrt_llm.auto_parallel.config": ["AutoParallelConfig", "CostModel"],
    "tensorrt_llm.auto_parallel.simplifier": ["GraphConfig", "StageType"]
}


@dataclass
class ParallelConfig:
Expand All @@ -55,12 +62,13 @@ class ParallelConfig:

def save(self, filename):
    """Serialize this ParallelConfig to *filename*.

    Uses the restricted ``tensorrt_llm.serialization`` module rather than
    raw ``pickle.dump`` (the pre-diff line), so the on-disk format pairs
    with the allow-listed loader in :meth:`from_file`.
    """
    with open(filename, 'wb') as file:
        serialization.dump(self, file)

@staticmethod
def from_file(filename) -> "ParallelConfig":
    """Deserialize a ParallelConfig previously written by :meth:`save`.

    Loads through ``serialization.load`` with ``BASE_AUTOPP_CLASSES`` as the
    import allow-list, so only the expected auto-parallel classes can be
    instantiated. The unrestricted ``pickle.load`` call from the old version
    must NOT remain above this return — it would execute first and bypass
    the allow-list entirely.
    """
    with open(filename, "rb") as file:
        return serialization.load(file,
                                  approved_imports=BASE_AUTOPP_CLASSES)

def print_graph_strategy(self, file=None):
for index, (node_name,
Expand Down
File renamed without changes.
41 changes: 40 additions & 1 deletion tests/unittest/llmapi/test_serialization.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,12 @@
import os
import tempfile

import torch

import tensorrt_llm.executor.serialization as serialization
from tensorrt_llm import serialization
from tensorrt_llm.auto_parallel.config import AutoParallelConfig
from tensorrt_llm.auto_parallel.parallelization import ParallelConfig
from tensorrt_llm.auto_parallel.simplifier import GraphConfig, StageType


class TestClass:
Expand Down Expand Up @@ -77,5 +83,38 @@ def test_serialization_complex_object_disallowed_class():
excep) == "Import torch._utils | _rebuild_tensor_v2 is not allowed"


def test_parallel_config_serialization():
    """Round-trip a populated ParallelConfig through save()/from_file()
    and verify every field survives restricted serialization."""
    with tempfile.TemporaryDirectory() as workdir:
        # Build a config exercising nested dataclasses and an enum field.
        original = ParallelConfig()
        original.version = "test_version"
        original.network_hash = "test_hash"
        original.auto_parallel_config = AutoParallelConfig(
            world_size=2, gpus_per_node=2, cluster_key="test_cluster")
        original.graph_config = GraphConfig(num_micro_batches=2,
                                            num_blocks=3,
                                            num_stages=2)
        original.cost = 1.5
        original.stage_type = StageType.START

        path = os.path.join(workdir, "parallel_config.pkl")
        original.save(path)
        restored = ParallelConfig.from_file(path)

        # Top-level scalar and enum fields.
        for attr in ("version", "network_hash", "cost", "stage_type"):
            assert getattr(restored, attr) == getattr(original, attr)
        # Nested AutoParallelConfig fields.
        for attr in ("world_size", "gpus_per_node", "cluster_key"):
            assert getattr(restored.auto_parallel_config, attr) == \
                getattr(original.auto_parallel_config, attr)
        # Nested GraphConfig fields.
        for attr in ("num_micro_batches", "num_blocks", "num_stages"):
            assert getattr(restored.graph_config, attr) == \
                getattr(original.graph_config, attr)


if __name__ == "__main__":
    # NOTE(review): only a subset of this module's tests is invoked here
    # (e.g. test_serialization_complex_object_disallowed_class is not) —
    # run the file under pytest to execute the full suite; confirm whether
    # the omission is intentional.
    test_serialization_allowed_class()
    test_parallel_config_serialization()