diff --git a/examples/scaffolding/run_best_of_n_with_reward.py b/examples/scaffolding/run_best_of_n_with_reward.py
index 767147cfc46..e451cf6b2c0 100644
--- a/examples/scaffolding/run_best_of_n_with_reward.py
+++ b/examples/scaffolding/run_best_of_n_with_reward.py
@@ -2,7 +2,7 @@
 
 from tensorrt_llm.scaffolding.controller import (BestOfNController,
                                                  NativeGenerationController,
-                                                 QwenRewardController)
+                                                 PRMController)
 from tensorrt_llm.scaffolding.scaffolding_llm import ScaffoldingLlm
 from tensorrt_llm.scaffolding.worker import TRTLLMWorker
 
@@ -41,13 +41,13 @@ def main():
         kv_cache_free_gpu_memory_fraction=0.2,
         disable_overlap_scheduler=True)
     workers[NativeGenerationController.WorkerTag.GENERATION] = gen_worker
-    workers[QwenRewardController.WorkerTag.REWARD] = reward_worker
+    workers[PRMController.WorkerTag.REWARD] = reward_worker
 
     gen_controller = NativeGenerationController(sampling_params={
         "max_tokens": 4096,
         "temperature": 0.6,
     })
-    reward_controller = QwenRewardController(tokenizer=reward_worker.tokenizer)
+    reward_controller = PRMController(tokenizer=reward_worker.tokenizer)
     controller = BestOfNController(
         generation_controller=gen_controller,
         reward_controller=reward_controller,
@@ -61,7 +61,7 @@ def main():
     results = llm.generate(prompts)
     print(results[0].output.output_str)
 
-    llm.shutdown(shutdown_wokers=True)
+    llm.shutdown(shutdown_workers=True)
     print(f'main shut down done')
 
 
diff --git a/tensorrt_llm/scaffolding/controller.py b/tensorrt_llm/scaffolding/controller.py
index 263f97541af..713c06385f3 100644
--- a/tensorrt_llm/scaffolding/controller.py
+++ b/tensorrt_llm/scaffolding/controller.py
@@ -80,6 +80,9 @@ def process(self, tasks: List[Task], **kwargs):
 
 class NativeRewardController(Controller):
 
+    def __init__(self):
+        self.scores = None
+
     class WorkerTag(Enum):
         REWARD = "reward"
 
@@ -91,79 +94,112 @@ def process(self, tasks: List[Task], **kwargs):
         yield tasks
 
 
-class QwenRewardController(NativeRewardController):
+class PRMController(NativeRewardController):
     """
-    Controller that integrate multi Generation output into one prompt and get
-    reward values from reward model.
+    Use a PRM (Process Reward Model) to score each output. The output is
+    split into multiple steps if `split_steps` is True; otherwise only the
+    score of the last token is extracted.
+
+    Output:
+        The score of each task is stored in `self.scores`.
+
+    Example:
+        Suppose the model output is split by a special separator token:
+        Input: "Step1,...Step2,...\\boxed{answer}.."
+        The controller masks the logits and keeps only the scores at each separate_token.
+        Each score is the probability for one step, e.g. [0.98, 1.0].
+        The output can be considered good when the product of all step probabilities is high.
     """
 
-    def __init__(self, tokenizer, separate_token=""):  # nosec B107
+    def __init__(
+        self,
+        tokenizer,
+        split_steps=True,
+        step_token="\n\n",
+        separate_token="",  # nosec B107
+    ):
         super().__init__()
         self.tokenizer = tokenizer
+        self.split_steps = split_steps
+        self.step_token = step_token
         self.separate_token = separate_token
 
-    def _make_step_rewards(self, logits, token_masks):
-        probabilities = F.softmax(logits, dim=-1)
-        probabilities = probabilities * token_masks.unsqueeze(
-            -1)  # bs, seq_len, num_labels=2
-
-        all_scores_res = []
-        for i in range(probabilities.size(0)):
-            sample = probabilities[i]  # seq_len, num_labels
-            positive_probs = sample[sample != 0].view(
-                -1, 2)[:, 1]  # num_separate_tokens, num_labels
-            non_zero_elements_list = positive_probs.cpu().tolist()
-            all_scores_res.append(non_zero_elements_list)
-        return all_scores_res
-
-    def process(self, tasks: List[Task], **kwargs):
-        # Combine messages using chat template
-        content = "".join(
-            (task.output_str + self.separate_token) for task in tasks)
-        messages = [
-            {
-                "role":
-                "system",
-                "content":
-                "Please reason step by step, and put your final answer within \\boxed{}."
-            },
-            {
-                "role": "user",
-                "content": tasks[0].input_str
-            },
-            {
-                "role": "assistant",
-                "content": content
-            },
-        ]
-        combined_prompt = self.tokenizer.apply_chat_template(
-            messages, tokenize=False, add_generation_prompt=False)
-
-        # TODO: support input_ids as model input, avoid doing it again in worker
-        merged_task = GenerationTask.create_from_prompt(combined_prompt)
-        merged_task.worker_tag = self.WorkerTag.REWARD
+    def _calc_steps_score(self, logits, token_mask):
+        probs = F.softmax(logits, dim=-1)  # seq_len, num_labels=2
+        masked_probs = probs * token_mask.unsqueeze(-1)[0]
 
-        # TODO: pack this logic
-        merged_task.max_tokens = 1
-        merged_task.return_context_logits = True
+        # only keep the logits at the separate_token
+        step_probs = masked_probs[masked_probs != 0].view(-1, 2)[:, 1]
+        score = torch.prod(step_probs).item()
+        return score
 
-        yield [merged_task]
+    def _calc_last_token_score(self, logits):
+        # seq_len, num_labels=2
+        probs = F.softmax(logits, dim=-1)
+        score = probs[-1, 1].item()
+        return score
 
-        assert merged_task.context_logits is not None
-        # TODO: consider running on cpu to not interrupt worker or move
-        # tokenizer to a worker
-        input_ids = self.tokenizer.encode(
-            combined_prompt,
-            return_tensors="pt",
-        ).to(merged_task.context_logits.device)
-
-        # TODO: align add_special_tokens with SamplingParams
-        token_masks = (input_ids == self.tokenizer.encode(
-            self.separate_token, add_special_tokens=True)[0])
-        all_scores_res = self._make_step_rewards(merged_task.context_logits,
-                                                 token_masks)
-
-        return all_scores_res
+    def process(self, tasks: List[Task], **kwargs):
+        reward_tasks = []
+        for task in tasks:
+            if self.split_steps:
+                steps = task.output_str.split(self.step_token)
+                content = "".join(
+                    (step + self.separate_token) for step in steps)
+            else:
+                content = self.separate_token + task.output_str + self.separate_token
+            # Combine messages using chat template
+            messages = [
+                {
+                    "role":
+                    "system",
+                    "content":
+                    "Please reason step by step, and put your final answer within \\boxed{}."
+                },
+                {
+                    "role": "user",
+                    "content": task.input_str
+                },
+                {
+                    "role": "assistant",
+                    "content": content
+                },
+            ]
+            processed_prompt = self.tokenizer.apply_chat_template(
+                messages, tokenize=False, add_generation_prompt=False)
+
+            # TODO: support input_ids as model input, avoid doing it again in worker
+            reward_task = GenerationTask.create_from_prompt(processed_prompt)
+            reward_task.worker_tag = self.WorkerTag.REWARD
+
+            # TODO: pack this logic
+            reward_task.max_tokens = 1
+            reward_task.return_context_logits = True
+            reward_tasks.append(reward_task)
+
+        yield reward_tasks
+
+        scores = []
+        for reward_task in reward_tasks:
+            assert reward_task.context_logits is not None
+            # TODO: consider running on cpu to not interrupt worker or move
+            # tokenizer to a worker
+            input_ids = self.tokenizer.encode(
+                reward_task.input_str,
+                return_tensors="pt",
+            ).to(reward_task.context_logits.device)
+
+            if self.split_steps:
+                # TODO: align add_special_tokens with SamplingParams
+                token_mask = (input_ids == self.tokenizer.encode(
+                    self.separate_token, add_special_tokens=True)[0])
+                score = self._calc_steps_score(reward_task.context_logits,
+                                               token_mask)
+            else:
+                score = self._calc_last_token_score(reward_task.context_logits)
+            scores.append(score)
+
+        self.scores = scores
 
 
 # Controller runs a single generation task with majority vote.
@@ -243,21 +279,27 @@ def process(self,
             self.generation_controller for _ in range(sample_num)
         ]
         generation_kwargs_list = [generation_kwargs for _ in range(sample_num)]
-        generation_tasks_list = [copy.deepcopy(task) for _ in range(sample_num)]
+        generation_tasks = [copy.deepcopy(task) for _ in range(sample_num)]
 
-        # yield from self.generation_controller.process(generation_tasks_list,
-        #                                               **generation_kwargs)
         yield ParallelProcess(generation_controllers,
-                              [[t] for t in generation_tasks_list],
+                              [[t] for t in generation_tasks],
                               generation_kwargs_list)
 
-        reward_values = yield from self.reward_controller.process(
-            generation_tasks_list, **reward_kwargs)
+        yield from self.reward_controller.process(generation_tasks,
+                                                  **reward_kwargs)
+
+        assert self.reward_controller.scores is not None
+        reward_values = self.reward_controller.scores
+
+        for i, gen_task, reward_value in zip(range(sample_num),
+                                             generation_tasks, reward_values):
+            logger.info(
+                f"[output {i}, score {reward_value}]:\n{gen_task.output_str}")
 
-        best_task = self.select_best(generation_tasks_list, reward_values,
-                                     **select_best_kwargs)
+        best_task, best_idx = self.select_best(generation_tasks, reward_values,
+                                               **select_best_kwargs)
         task.output_str = best_task.output_str
 
     def select_best(self, tasks: List[Task], reward_values, **kwargs) -> Task:
         max_index = torch.argmax(torch.tensor(reward_values)).item()
-        return tasks[max_index]
+        return tasks[max_index], max_index
diff --git a/tests/unittest/scaffolding/test_scaffolding.py b/tests/unittest/scaffolding/test_scaffolding.py
index 80d96abcd49..b736ea64255 100644
--- a/tests/unittest/scaffolding/test_scaffolding.py
+++ b/tests/unittest/scaffolding/test_scaffolding.py
@@ -51,7 +51,7 @@ def test_unbatched_scaffolding_sync(default_prompt, deepseek_distill_7b_path):
     result = scaffolding_llm.generate(default_prompt)
     assert isinstance(result.output.output_str, str) and len(
         result.output.output_str) > 0, "Output should be a non-empty string"
-    scaffolding_llm.shutdown(shutdown_wokers=True)
+    scaffolding_llm.shutdown(shutdown_workers=True)
 
 
 def test_batched_scaffolding_sync(default_prompt, deepseek_distill_7b_path):
@@ -64,7 +64,7 @@ def test_batched_scaffolding_sync(default_prompt, deepseek_distill_7b_path):
     for result in results:
         assert isinstance(result.output.output_str, str) and len(
             result.output.output_str) > 0, "Output should be a non-empty string"
-    scaffolding_llm.shutdown(shutdown_wokers=True)
+    scaffolding_llm.shutdown(shutdown_workers=True)
 
 
 def test_async_scaffolding_generation(default_prompt, deepseek_distill_7b_path):
@@ -76,7 +76,7 @@ async def run_async_test():
         result = await future.aresult()
         assert isinstance(result.output.output_str, str) and len(
             result.output.output_str) > 0, "Output should be a non-empty string"
-        scaffolding_llm.shutdown(shutdown_wokers=True)
+        scaffolding_llm.shutdown(shutdown_workers=True)
 
     import asyncio
     asyncio.run(run_async_test())
@@ -88,4 +88,4 @@ def test_majority_vote(default_prompt, deepseek_distill_7b_path):
     result = scaffolding_llm.generate(default_prompt)
     assert isinstance(result.output.output_str, str) and len(
         result.output.output_str) > 0, "Output should be a non-empty string"
-    scaffolding_llm.shutdown(shutdown_wokers=True)
+    scaffolding_llm.shutdown(shutdown_workers=True)
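
Note on the prompt construction in PRMController.process: with `split_steps=True` the candidate answer is split on `step_token` and every step is suffixed with `separate_token`, so the reward model emits one score per step; with `split_steps=False` the whole answer is wrapped in a single pair of separators. A minimal standalone sketch of that string handling (the `<sep>` value below is a placeholder for the model-specific separator, not the controller's actual default):

# Standalone sketch of the content string PRMController builds (placeholder tokens).
step_token = "\n\n"        # default step delimiter from the diff above
separate_token = "<sep>"   # placeholder; the real separator token is model-specific

output_str = "Step 1: expand the square.\n\nStep 2: collect terms.\n\n\\boxed{42}"

split_steps = True
if split_steps:
    steps = output_str.split(step_token)
    content = "".join(step + separate_token for step in steps)
else:
    content = separate_token + output_str + separate_token

print(content)
# Step 1: expand the square.<sep>Step 2: collect terms.<sep>\boxed{42}<sep>

The resulting `content` is placed in the assistant message before `apply_chat_template`, so the reward worker sees the separator after every step it should judge.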
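
Note on the scoring itself: both helpers reduce to a softmax over the PRM's two-label head. `_calc_steps_score` keeps only the positions flagged by the separator mask and multiplies their positive-class probabilities, while `_calc_last_token_score` reads the final position only. A self-contained sketch with toy tensors (the real controller takes `context_logits` from the reward worker and builds the mask from tokenizer output; the boolean-mask indexing below is a simplification of the masking arithmetic in the diff):

import torch
import torch.nn.functional as F


def steps_score(context_logits: torch.Tensor, token_mask: torch.Tensor) -> float:
    # Product of P(label=1) at every separator position; high only if every step looks good.
    probs = F.softmax(context_logits, dim=-1)  # (seq_len, 2)
    step_probs = probs[token_mask][:, 1]       # (num_separators,)
    return torch.prod(step_probs).item()


def last_token_score(context_logits: torch.Tensor) -> float:
    # P(label=1) at the last position only.
    return F.softmax(context_logits, dim=-1)[-1, 1].item()


# Toy example: 6 token positions, separators at positions 2 and 5.
logits = torch.randn(6, 2)
mask = torch.tensor([False, False, True, False, False, True])
print(steps_score(logits, mask), last_token_score(logits))

Because the step scores are multiplied, a single weak step drags the whole candidate down, which is what BestOfNController relies on when it picks the argmax of `reward_controller.scores`.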