99import sys
1010from pathlib import Path
1111from typing import Any
12+ #from typing import Any, Literal, NamedTuple, TypeVar, Union
1213
1314import numpy as np
15+ import numpy .typing as npt
1416
1517# Necessary to load the local gguf package
1618if "NO_LOCAL_GGUF" not in os .environ and (Path (__file__ ).parent .parent .parent / 'gguf-py' ).exists ():
1719 sys .path .insert (0 , str (Path (__file__ ).parent .parent ))
1820
19- from gguf import GGUFReader , GGUFWriter , ReaderField , GGUFEndian , GGUFValueType , Keys # noqa: E402
21+ from gguf import GGUFReader , GGUFWriter , ReaderField , GGMLQuantizationType , GGUFEndian , GGUFValueType , Keys # noqa: E402
2022
2123logger = logging .getLogger ("gguf-addfile" )
2224
@@ -54,17 +56,11 @@ def decode_field(field: ReaderField) -> Any:
5456 sub_type = field .types [- 1 ]
5557
5658 if sub_type == GGUFValueType .STRING :
57- if not field .name [0 ] == Keys .General .FILE_MARK :
58- return [str (bytes (field .parts [idx ]), encoding = 'utf8' ) for idx in field .data ]
59- else :
60- return [bytes (field .parts [idx ]) for idx in field .data ]
59+ return [str (bytes (field .parts [idx ]), encoding = 'utf8' ) for idx in field .data ]
6160 else :
6261 return [pv for idx in field .data for pv in field .parts [idx ].tolist ()]
6362 if main_type == GGUFValueType .STRING :
64- if not field .name [0 ] == Keys .General .FILE_MARK :
65- return str (bytes (field .parts [- 1 ]), encoding = 'utf8' )
66- else :
67- return bytes (field .parts [- 1 ])
63+ return str (bytes (field .parts [- 1 ]), encoding = 'utf8' )
6864 else :
6965 return field .parts [- 1 ][0 ]
7066
@@ -77,54 +73,64 @@ def get_field_data(reader: GGUFReader, key: str) -> Any:
7773 return decode_field (field )
7874
7975
80- def copy_with_new_metadata (reader : gguf .GGUFReader , writer : gguf .GGUFWriter , new_metadata : Mapping [str , str ]) -> None :
76+ def copy_with_filename (reader : gguf .GGUFReader , writer : gguf .GGUFWriter , filename : str [Any ]) -> None :
77+ logger .debug (f'copy_with_filename: { filename } ' ) #debug
78+ val = filename
8179 for field in reader .fields .values ():
8280 # Suppress virtual fields and fields written by GGUFWriter
8381 if field .name == Keys .General .ARCHITECTURE or field .name .startswith ('GGUF.' ):
8482 logger .debug (f'Suppressing { field .name } ' )
8583 continue
8684
87- # Skip old chat templates if we have new ones
88- if field .name .startswith (Keys .Tokenizer .CHAT_TEMPLATE ) and Keys .Tokenizer .CHAT_TEMPLATE in new_metadata :
89- logger .debug (f'Skipping { field .name } ' )
85+ # Copy existed fields except 'embedded_files'
86+ if not field .name == Keys .EMBEDDED_FILES :
87+ cur_val = decode_field (field )
88+ writer .add_key (field .name )
89+ writer .add_val (cur_val , field .types [0 ])
90+ logger .debug (f'Copying { field .name } ' )
9091 continue
9192
92- old_val = decode_field (field )
93- val = new_metadata .get (field .name , old_val )
94-
95- if field .name in new_metadata :
96- logger .debug (f'Modifying { field .name } : "{ old_val } " -> "{ val } "' )
97- del new_metadata [field .name ]
98- elif val is not None :
99- logger .debug (f'Copying { field .name } ' )
93+ # Update embedded_files
94+ val = decode_field (field )
95+ for path in filename :
96+ logger .debug (f'Adding { field .name } : { path } ' )
97+ val .append (path )
10098
101- if val is not None :
102- writer .add_key (field .name )
103- writer .add_val (val , field .types [0 ])
104-
105- if Keys .Tokenizer .CHAT_TEMPLATE in new_metadata :
106- logger .debug ('Adding chat template(s)' )
107- writer .add_chat_template (new_metadata [Keys .Tokenizer .CHAT_TEMPLATE ])
108- del new_metadata [Keys .Tokenizer .CHAT_TEMPLATE ]
109-
110- for key , name in new_metadata .items ():
111- logger .debug (f'Adding { key } : { name } ' )
112- with open (name , "rb" ) as f :
113- val = f .read ()
114- writer .add_object (key , val )
99+ # Add filenames to kv
100+ logger .info (f'* Modifying { Keys .EMBEDDED_FILES } to { val } ' )
101+ writer .add_array (Keys .EMBEDDED_FILES , val )
115102
116103 for tensor in reader .tensors :
117104 # Dimensions are written in reverse order, so flip them first
118105 shape = np .flipud (tensor .shape )
119106 writer .add_tensor_info (tensor .name , shape , tensor .data .dtype , tensor .data .nbytes , tensor .tensor_type )
120107
108+ # Add file info as tensor_info
109+ for path in filename :
110+ logger .debug (f'Adding tensor_info { path } ' )
111+ with open (path , "rb" ) as f :
112+ data = f .read ()
113+ data_len = len (data )
114+ dims = [data_len ]
115+ raw_dtype = GGMLQuantizationType .I8
116+ writer .add_tensor_info (path , dims , np .float16 , data_len , raw_dtype )
117+
121118 writer .write_header_to_file ()
122119 writer .write_kv_data_to_file ()
123120 writer .write_ti_data_to_file ()
124121
125122 for tensor in reader .tensors :
126123 writer .write_tensor_data (tensor .data )
127124
125+ # Write file body as tensor data
126+ for path in filename :
127+ logger .debug (f'Adding tensor data { path } ' )
128+ with open (path , "rb" ) as f :
129+ data = f .read ()
130+ data_len = len (data )
131+ # write data with padding
132+ writer .write_data (data )
133+
128134 writer .close ()
129135
130136
@@ -133,6 +139,7 @@ def main() -> None:
133139 parser .add_argument ("input" , type = str , help = "GGUF format model input filename" )
134140 parser .add_argument ("output" , type = str , help = "GGUF format model output filename" )
135141 parser .add_argument ("addfiles" , type = str , nargs = '+' , help = "add filenames ..." )
142+ parser .add_argument ("--force" , action = "store_true" , help = "Bypass warnings without confirmation" )
136143 parser .add_argument ("--verbose" , action = "store_true" , help = "Increase output verbosity" )
137144 args = parser .parse_args (None if len (sys .argv ) > 1 else ["--help" ])
138145 logging .basicConfig (level = logging .DEBUG if args .verbose else logging .INFO )
@@ -142,6 +149,15 @@ def main() -> None:
142149 arch = get_field_data (reader , Keys .General .ARCHITECTURE )
143150 endianess = get_byteorder (reader )
144151
152+ if os .path .isfile (args .output ) and not args .force :
153+ logger .warning ('*** Warning *** Warning *** Warning **' )
154+ logger .warning (f'* The "{ args .output } " GGUF file already exists, it will be overwritten!' )
155+ logger .warning ('* Enter exactly YES if you are positive you want to proceed:' )
156+ response = input ('YES, I am sure> ' )
157+ if response != 'YES' :
158+ logger .info ("You didn't enter YES. Okay then, see ya!" )
159+ sys .exit (0 )
160+
145161 logger .info (f'* Writing: { args .output } ' )
146162 writer = GGUFWriter (args .output , arch = arch , endianess = endianess )
147163
@@ -150,15 +166,13 @@ def main() -> None:
150166 logger .debug (f'Setting custom alignment: { alignment } ' )
151167 writer .data_alignment = alignment
152168
153- logger .info (f'* Adding: { args .addfiles } ' )
154- new_metadata = {}
155- for path in args .addfiles :
156- # add FILE_MARK to key
157- key = Keys .General .FILE_MARK + path
158- new_metadata [key ] = path
159- logger .info (f'* Adding: { key } = { path } ' )
160- copy_with_new_metadata (reader , writer , new_metadata )
161-
169+ if args .addfiles is not None :
170+ filename = []
171+ for path in args .addfiles :
172+ filename .append (path )
173+ logger .info (f'* Adding: { path } ' )
174+ copy_with_filename (reader , writer , filename )
175+
162176
163177if __name__ == '__main__' :
164178 main ()
0 commit comments