Skip to content

Commit 8a4645f

Browse files
authored
Fix tokenization of <|constrain|> content type in rendering (#47)
1 parent 2387e4a commit 8a4645f

File tree

2 files changed

+49
-2
lines changed

2 files changed

+49
-2
lines changed

src/encoding.rs

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -835,7 +835,22 @@ impl Render<Message> for HarmonyEncoding {
835835

836836
// finally content type
837837
if let Some(content_type) = &message.content_type {
838-
self.render_text_into(format!(" {content_type}"), into)?;
838+
// <|constrain|> is a unique case which needs to be tokenized as a special token
839+
if let Some(constrain_marker) = self.mapped_format_token(FormattingToken::ConstrainedFormat) {
840+
if content_type.starts_with(constrain_marker) {
841+
// Render the space, then the constrain marker as a special token, then the rest as text (if any)
842+
self.render_text_into(" ", into)?;
843+
self.render_formatting_token_into(FormattingToken::ConstrainedFormat, into)?;
844+
let rest = &content_type[constrain_marker.len()..];
845+
if !rest.is_empty() {
846+
self.render_text_into(rest, into)?;
847+
}
848+
} else {
849+
self.render_text_into(format!(" {content_type}"), into)?;
850+
}
851+
} else {
852+
self.render_text_into(format!(" {content_type}"), into)?;
853+
}
839854
}
840855

841856
self.render_formatting_token_into(FormattingToken::Message, into)?;

tests/test_harmony.py

Lines changed: 33 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -233,6 +233,36 @@ def test_simple_tool_call(encoding_name):
233233
assert parsed == expected
234234

235235

236+
@pytest.mark.parametrize(
237+
"encoding_name",
238+
[
239+
HarmonyEncodingName.HARMONY_GPT_OSS,
240+
],
241+
)
242+
def test_tool_call_with_constrain_tokenized_correctly(encoding_name):
243+
"""
244+
Despite passing <|constrain|> as a string in "content_type" it has to be kept as a special token.
245+
"""
246+
encoding = load_harmony_encoding(encoding_name)
247+
text = (
248+
"<|start|>assistant to=functions.get_weather<|channel|>commentary"
249+
' <|constrain|>json<|message|>{"location": "Tokyo"}<|call|>'
250+
)
251+
tokens = encoding.encode(text, allowed_special="all")
252+
parsed = encoding.parse_messages_from_completion_tokens(tokens, role=None)
253+
expected = [
254+
Message.from_role_and_content(Role.ASSISTANT, '{"location": "Tokyo"}')
255+
.with_channel("commentary")
256+
.with_recipient("functions.get_weather")
257+
.with_content_type("<|constrain|>json"),
258+
]
259+
assert parsed == expected
260+
261+
rendered = encoding.render_conversation(Conversation.from_messages(expected))
262+
assert text == encoding.decode_utf8(tokens)
263+
assert rendered == tokens
264+
265+
236266
@pytest.mark.parametrize(
237267
"encoding_name",
238268
[
@@ -248,7 +278,7 @@ def test_tool_call_with_constrain_marker_adjacent(encoding_name):
248278
encoding = load_harmony_encoding(encoding_name)
249279
text = (
250280
"<|start|>assistant to=functions.get_weather<|channel|>commentary"
251-
'<|constrain|>json<|message|>{"location": "Tokyo"}<|end|>'
281+
'<|constrain|>json<|message|>{"location": "Tokyo"}<|call|>'
252282
)
253283
tokens = encoding.encode(text, allowed_special="all")
254284
parsed = encoding.parse_messages_from_completion_tokens(tokens, role=None)
@@ -702,6 +732,8 @@ def test_does_not_drop_if_ongoing_analysis():
702732
)
703733

704734
assert encoding.decode_utf8(tokens) == expected_output
735+
# ensure that <|constrain|>json part is tokenized correctly as special tokens
736+
assert encoding.encode(expected_output, allowed_special="all") == tokens
705737

706738

707739
def test_preserve_cot():

0 commit comments

Comments
 (0)