diff --git a/tensorrt_llm/auto_parallel/parallelization.py b/tensorrt_llm/auto_parallel/parallelization.py index d25d273d112..0e0e0d78c3f 100644 --- a/tensorrt_llm/auto_parallel/parallelization.py +++ b/tensorrt_llm/auto_parallel/parallelization.py @@ -1,7 +1,6 @@ import contextlib import copy import itertools -import pickle # nosec B403 import re from abc import ABC, abstractmethod from dataclasses import dataclass, field @@ -12,6 +11,7 @@ import torch from filelock import FileLock +from tensorrt_llm import serialization from tensorrt_llm._utils import (str_dtype_to_trt, trt_dtype_to_np, trt_dtype_to_torch) from tensorrt_llm.functional import AllReduceParams, create_allreduce_plugin @@ -39,6 +39,13 @@ default_int_dtype = trt.int64 +# These dataclasses are used in ParallelConfig serialization. If there are other classes need to be serialized, please add to this list. +BASE_AUTOPP_CLASSES = { + "tensorrt_llm.auto_parallel.parallelization": ["ParallelConfig"], + "tensorrt_llm.auto_parallel.config": ["AutoParallelConfig", "CostModel"], + "tensorrt_llm.auto_parallel.simplifier": ["GraphConfig", "StageType"] +} + @dataclass class ParallelConfig: @@ -55,12 +62,13 @@ class ParallelConfig: def save(self, filename): with open(filename, 'wb') as file: - pickle.dump(self, file) + serialization.dump(self, file) @staticmethod def from_file(filename) -> "ParallelConfig": with open(filename, "rb") as file: - return pickle.load(file) # nosec B301 + return serialization.load(file, + approved_imports=BASE_AUTOPP_CLASSES) def print_graph_strategy(self, file=None): for index, (node_name, diff --git a/tensorrt_llm/executor/serialization.py b/tensorrt_llm/serialization.py similarity index 100% rename from tensorrt_llm/executor/serialization.py rename to tensorrt_llm/serialization.py diff --git a/tests/unittest/llmapi/test_serialization.py b/tests/unittest/llmapi/test_serialization.py index e2bf6ed2df6..02caa0c72cc 100644 --- a/tests/unittest/llmapi/test_serialization.py +++ b/tests/unittest/llmapi/test_serialization.py @@ -1,6 +1,12 @@ +import os +import tempfile + import torch -import tensorrt_llm.executor.serialization as serialization +from tensorrt_llm import serialization +from tensorrt_llm.auto_parallel.config import AutoParallelConfig +from tensorrt_llm.auto_parallel.parallelization import ParallelConfig +from tensorrt_llm.auto_parallel.simplifier import GraphConfig, StageType class TestClass: @@ -77,5 +83,38 @@ def test_serialization_complex_object_disallowed_class(): excep) == "Import torch._utils | _rebuild_tensor_v2 is not allowed" +def test_parallel_config_serialization(): + with tempfile.TemporaryDirectory() as tmpdir: + # Create a ParallelConfig instance with some test data + config = ParallelConfig() + config.version = "test_version" + config.network_hash = "test_hash" + config.auto_parallel_config = AutoParallelConfig( + world_size=2, gpus_per_node=2, cluster_key="test_cluster") + config.graph_config = GraphConfig(num_micro_batches=2, + num_blocks=3, + num_stages=2) + config.cost = 1.5 + config.stage_type = StageType.START + + config_path = os.path.join(tmpdir, "parallel_config.pkl") + config.save(config_path) + + loaded_config = ParallelConfig.from_file(config_path) + + # Verify the loaded config matches the original + assert loaded_config.version == config.version + assert loaded_config.network_hash == config.network_hash + assert loaded_config.auto_parallel_config.world_size == config.auto_parallel_config.world_size + assert loaded_config.auto_parallel_config.gpus_per_node == config.auto_parallel_config.gpus_per_node + assert loaded_config.auto_parallel_config.cluster_key == config.auto_parallel_config.cluster_key + assert loaded_config.graph_config.num_micro_batches == config.graph_config.num_micro_batches + assert loaded_config.graph_config.num_blocks == config.graph_config.num_blocks + assert loaded_config.graph_config.num_stages == config.graph_config.num_stages + assert loaded_config.cost == config.cost + assert loaded_config.stage_type == config.stage_type + + if __name__ == "__main__": test_serialization_allowed_class() + test_parallel_config_serialization()