 
 # Define models, templates, and their corresponding expected outputs
 MODEL_TEMPLATE_GENERATON_OUTPUT = [
-    ("facebook/opt-125m", chatml_jinja_path, True, """<|im_start|>user
+    ("facebook/opt-125m", chatml_jinja_path, True, False, """<|im_start|>user
 Hello<|im_end|>
 <|im_start|>assistant
 Hi there!<|im_end|>
 <|im_start|>user
 What is the capital of<|im_end|>
 <|im_start|>assistant
 """),
-    ("facebook/opt-125m", chatml_jinja_path, False, """<|im_start|>user
+    ("facebook/opt-125m", chatml_jinja_path, False, False, """<|im_start|>user
 Hello<|im_end|>
 <|im_start|>assistant
 Hi there!<|im_end|>
 <|im_start|>user
-What is the capital of""")
+What is the capital of"""),
+    ("facebook/opt-125m", chatml_jinja_path, False, True, """<|im_start|>user
+Hello<|im_end|>
+<|im_start|>assistant
+Hi there!<|im_end|>
+<|im_start|>user
+What is the capital of<|im_end|>
+<|im_start|>assistant
+The capital of"""),
 ]
 
 TEST_MESSAGES = [
@@ ... @@
         'content': 'What is the capital of'
     },
 ]
+ASSISTANT_MESSAGE_TO_CONTINUE = {
+    'role': 'assistant',
+    'content': 'The capital of'
+}
 
 
 def test_load_chat_template():
@@ -73,26 +85,30 @@ def test_no_load_chat_template_literallike(): |
 
 
 @pytest.mark.parametrize(
-    "model,template,add_generation_prompt,expected_output",
+    "model,template,add_generation_prompt,continue_final_message,expected_output",
     MODEL_TEMPLATE_GENERATON_OUTPUT)
 def test_get_gen_prompt(model, template, add_generation_prompt,
-                        expected_output):
+                        continue_final_message, expected_output):
     # Initialize the tokenizer
     tokenizer = get_tokenizer(tokenizer_name=model)
     template_content = load_chat_template(chat_template=template)
 
     # Create a mock request object using keyword arguments
     mock_request = ChatCompletionRequest(
         model=model,
-        messages=TEST_MESSAGES,
-        add_generation_prompt=add_generation_prompt)
+        messages=TEST_MESSAGES + [ASSISTANT_MESSAGE_TO_CONTINUE]
+        if continue_final_message else TEST_MESSAGES,
+        add_generation_prompt=add_generation_prompt,
+        continue_final_message=continue_final_message,
+    )
 
     # Call the function and get the result
     result = apply_hf_chat_template(
         tokenizer,
         conversation=mock_request.messages,
         chat_template=mock_request.chat_template or template_content,
         add_generation_prompt=mock_request.add_generation_prompt,
+        continue_final_message=mock_request.continue_final_message,
     )
 
     # Test assertion
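
For anyone reproducing the new continue_final_message=True case outside this test suite, here is a minimal sketch of the behavior being asserted. It assumes a transformers release whose apply_chat_template accepts the continue_final_message flag, and "chatml.jinja" is a hypothetical local copy of the ChatML template that the test references as chatml_jinja_path:

from transformers import AutoTokenizer

# Same conversation as TEST_MESSAGES, plus the partial assistant turn
# that the test appends via ASSISTANT_MESSAGE_TO_CONTINUE.
messages = [
    {"role": "user", "content": "Hello"},
    {"role": "assistant", "content": "Hi there!"},
    {"role": "user", "content": "What is the capital of"},
    {"role": "assistant", "content": "The capital of"},
]

tokenizer = AutoTokenizer.from_pretrained("facebook/opt-125m")
with open("chatml.jinja") as f:  # hypothetical path to a ChatML template
    chat_template = f.read()

prompt = tokenizer.apply_chat_template(
    messages,
    chat_template=chat_template,
    tokenize=False,
    add_generation_prompt=False,  # incompatible with continuing the turn
    continue_final_message=True,  # keep the last assistant turn open
)

# The rendered prompt ends with "The capital of" and no closing <|im_end|>,
# so the model extends the assistant message in place instead of starting
# a new turn.
print(prompt)

This is also why the new parametrization pairs add_generation_prompt=False with continue_final_message=True: adding a generation prompt opens a fresh assistant turn, while continuing reopens the existing one, so the two options cannot be combined.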