| 
 | 1 | +import pickle  | 
 | 2 | +from dataclasses import dataclass  | 
 | 3 | +from io import BufferedIOBase  | 
 | 4 | +from typing import Any, Dict, List, Tuple  | 
 | 5 | + | 
 | 6 | +import torch  | 
 | 7 | +import torch._weights_only_unpickler as _weights_only_unpickler  | 
 | 8 | +from torch.serialization import _load, _save, DEFAULT_PROTOCOL, MAP_LOCATION  | 
 | 9 | + | 
 | 10 | + | 
# Intentionally empty: everything in this module is an internal implementation
# detail of the streaming (de)serialization helpers.
__all__: List[str] = []
 | 12 | + | 
 | 13 | + | 
@dataclass
class _Entry:
    """Header metadata for one record in the streamed archive."""

    # Record name used to look the payload up (mirrors zip record names).
    key: str
    # True when the payload is a ``torch.UntypedStorage`` (raw tensor bytes).
    is_storage: bool
    # Number of payload bytes that follow the pickled header in the stream.
    length: int
 | 19 | + | 
 | 20 | + | 
# _Entry objects are pickled into the stream header; register the class so the
# weights-only unpickler will accept it when the header is read back.
_weights_only_unpickler._add_safe_globals([_Entry])
 | 22 | + | 
 | 23 | + | 
 | 24 | +class _PseudoZipFile:  | 
 | 25 | +    def __init__(self) -> None:  | 
 | 26 | +        self.records: Dict[str, Tuple[object, int]] = {}  | 
 | 27 | + | 
 | 28 | +    def write_record(self, key: str, data: object, length: int) -> None:  | 
 | 29 | +        self.records[key] = (data, length)  | 
 | 30 | + | 
 | 31 | +    def write_to(self, f: BufferedIOBase) -> None:  | 
 | 32 | +        entries = []  | 
 | 33 | +        for key, (data, length) in self.records.items():  | 
 | 34 | +            entries.append(  | 
 | 35 | +                _Entry(  | 
 | 36 | +                    key=key,  | 
 | 37 | +                    is_storage=isinstance(data, torch.UntypedStorage),  | 
 | 38 | +                    length=length,  | 
 | 39 | +                )  | 
 | 40 | +            )  | 
 | 41 | + | 
 | 42 | +        pickle.dump(entries, f, protocol=DEFAULT_PROTOCOL)  | 
 | 43 | + | 
 | 44 | +        for key, (data, length) in self.records.items():  | 
 | 45 | +            if isinstance(data, bytes):  | 
 | 46 | +                f.write(data)  | 
 | 47 | +            elif isinstance(data, str):  | 
 | 48 | +                f.write(data.encode("utf-8"))  | 
 | 49 | +            elif isinstance(data, torch.UntypedStorage):  | 
 | 50 | +                data._write_file(f, False, False, 1)  | 
 | 51 | +            else:  | 
 | 52 | +                raise TypeError(f"unknown type: {type(data)}")  | 
 | 53 | + | 
 | 54 | +    def read_from(self, f: BufferedIOBase) -> None:  | 
 | 55 | +        entries = _weights_only_unpickler.load(f)  | 
 | 56 | + | 
 | 57 | +        for entry in entries:  | 
 | 58 | +            data = f.read(entry.length)  | 
 | 59 | +            if entry.is_storage:  | 
 | 60 | +                storage = torch.frombuffer(  | 
 | 61 | +                    data,  | 
 | 62 | +                    dtype=torch.uint8,  | 
 | 63 | +                ).untyped_storage()  | 
 | 64 | + | 
 | 65 | +                self.records[entry.key] = (  | 
 | 66 | +                    storage,  | 
 | 67 | +                    entry.length,  | 
 | 68 | +                )  | 
 | 69 | +            else:  | 
 | 70 | +                self.records[entry.key] = (data, entry.length)  | 
 | 71 | + | 
 | 72 | +    def has_record(self, key: str) -> bool:  | 
 | 73 | +        return key in self.records  | 
 | 74 | + | 
 | 75 | +    def get_record(self, key: str) -> object:  | 
 | 76 | +        return self.records[key][0]  | 
 | 77 | + | 
 | 78 | +    def get_storage_from_record(  | 
 | 79 | +        self, key: str, _length: int, _type: int  | 
 | 80 | +    ) -> torch.Tensor:  | 
 | 81 | +        return torch.tensor(self.records[key][0], dtype=torch.uint8)  | 
 | 82 | + | 
 | 83 | +    def serialization_id(self) -> str:  | 
 | 84 | +        return "torchft"  | 
 | 85 | + | 
 | 86 | + | 
 | 87 | +def _streaming_save(  | 
 | 88 | +    obj: object,  | 
 | 89 | +    f: BufferedIOBase,  | 
 | 90 | +    pickle_module: Any = pickle,  | 
 | 91 | +    pickle_protocol: int = DEFAULT_PROTOCOL,  | 
 | 92 | +) -> None:  | 
 | 93 | +    """  | 
 | 94 | +    Save the object to a file-like object in a streaming fashion compatible with  | 
 | 95 | +    network sockets.  | 
 | 96 | +
  | 
 | 97 | +    This behaves similarly to :func:`torch.save` with a few notable differences:  | 
 | 98 | +
  | 
 | 99 | +    * A non-seekable file like object can be used when loading.  | 
 | 100 | +    * No forwards/backwards compatiblity is provided for the serialization  | 
 | 101 | +      format. This is only intended to be used with a single version of PyTorch  | 
 | 102 | +      with transient storage (i.e. sockets or temp files).  | 
 | 103 | +    * mmap is not supported  | 
 | 104 | +
  | 
 | 105 | +    See :func:`torch.save` for more details on specific arguments.  | 
 | 106 | +    """  | 
 | 107 | + | 
 | 108 | +    zip_file = _PseudoZipFile()  | 
 | 109 | +    _save(  | 
 | 110 | +        obj,  | 
 | 111 | +        zip_file=zip_file,  | 
 | 112 | +        pickle_module=pickle_module,  | 
 | 113 | +        pickle_protocol=pickle_protocol,  | 
 | 114 | +        _disable_byteorder_record=False,  | 
 | 115 | +    )  | 
 | 116 | +    zip_file.write_to(f)  | 
 | 117 | + | 
 | 118 | + | 
 | 119 | +def _streaming_load(  | 
 | 120 | +    f: BufferedIOBase,  | 
 | 121 | +    map_location: MAP_LOCATION = None,  | 
 | 122 | +    pickle_module: Any = None,  | 
 | 123 | +    *,  | 
 | 124 | +    weights_only: bool = True,  | 
 | 125 | +    **pickle_load_args: Any,  | 
 | 126 | +) -> object:  | 
 | 127 | +    """  | 
 | 128 | +    Load the object from a file-like object in a streaming fashion compatible with  | 
 | 129 | +    network sockets.  | 
 | 130 | +
  | 
 | 131 | +    See :func:`_streaming_save` for more details about the streaming behavior.  | 
 | 132 | +
  | 
 | 133 | +    See :func:`torch.load` for more details on specific arguments.  | 
 | 134 | +    """  | 
 | 135 | +    if weights_only:  | 
 | 136 | +        if pickle_module is not None:  | 
 | 137 | +            raise RuntimeError(  | 
 | 138 | +                "Can not safely load weights when explicit pickle_module is specified"  | 
 | 139 | +            )  | 
 | 140 | +        pickle_module = _weights_only_unpickler  | 
 | 141 | +    else:  | 
 | 142 | +        if pickle_module is None:  | 
 | 143 | +            pickle_module = pickle  | 
 | 144 | + | 
 | 145 | +    if "encoding" not in pickle_load_args.keys():  | 
 | 146 | +        pickle_load_args["encoding"] = "utf-8"  | 
 | 147 | + | 
 | 148 | +    zip_file = _PseudoZipFile()  | 
 | 149 | +    zip_file.read_from(f)  | 
 | 150 | +    return _load(  | 
 | 151 | +        zip_file=zip_file,  | 
 | 152 | +        map_location=map_location,  | 
 | 153 | +        pickle_module=pickle_module,  | 
 | 154 | +        **pickle_load_args,  | 
 | 155 | +    )  | 
0 commit comments