8 changes: 4 additions & 4 deletions examples/scaffolding/run_best_of_n_with_reward.py
@@ -2,7 +2,7 @@

from tensorrt_llm.scaffolding.controller import (BestOfNController,
NativeGenerationController,
QwenRewardController)
PRMController)
from tensorrt_llm.scaffolding.scaffolding_llm import ScaffoldingLlm
from tensorrt_llm.scaffolding.worker import TRTLLMWorker

@@ -41,13 +41,13 @@ def main():
kv_cache_free_gpu_memory_fraction=0.2,
disable_overlap_scheduler=True)
workers[NativeGenerationController.WorkerTag.GENERATION] = gen_worker
workers[QwenRewardController.WorkerTag.REWARD] = reward_worker
workers[PRMController.WorkerTag.REWARD] = reward_worker

gen_controller = NativeGenerationController(sampling_params={
"max_tokens": 4096,
"temperature": 0.6,
})
reward_controller = QwenRewardController(tokenizer=reward_worker.tokenizer)
reward_controller = PRMController(tokenizer=reward_worker.tokenizer)
controller = BestOfNController(
generation_controller=gen_controller,
reward_controller=reward_controller,
@@ -61,7 +61,7 @@

results = llm.generate(prompts)
print(results[0].output.output_str)
llm.shutdown(shutdown_wokers=True)
llm.shutdown(shutdown_workers=True)
print(f'main shut down done')


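For readers tracking the rename from QwenRewardController to PRMController, here is a minimal wiring sketch assembled only from the fragments visible in this diff. The TRTLLMWorker construction and ScaffoldingLlm setup are folded away above, so the reward tokenizer is taken as a parameter and only the arguments shown in the diff are passed.

# Minimal sketch (illustration only, not part of the PR); assumes the same
# imports the example uses and passes only the arguments visible in the diff.
from tensorrt_llm.scaffolding.controller import (BestOfNController,
                                                 NativeGenerationController,
                                                 PRMController)


def build_best_of_n(reward_tokenizer):
    """Wire the generation and PRM reward controllers as the example does."""
    gen_controller = NativeGenerationController(sampling_params={
        "max_tokens": 4096,
        "temperature": 0.6,
    })
    reward_controller = PRMController(tokenizer=reward_tokenizer)
    return BestOfNController(generation_controller=gen_controller,
                             reward_controller=reward_controller)

As in the example above, the workers still have to be registered under NativeGenerationController.WorkerTag.GENERATION and PRMController.WorkerTag.REWARD before handing everything to ScaffoldingLlm.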
186 changes: 114 additions & 72 deletions tensorrt_llm/scaffolding/controller.py
@@ -80,6 +80,9 @@ def process(self, tasks: List[Task], **kwargs):

class NativeRewardController(Controller):

def __init__(self):
self.scores = None

class WorkerTag(Enum):
REWARD = "reward"

@@ -91,79 +94,112 @@ def process(self, tasks: List[Task], **kwargs):
yield tasks


class QwenRewardController(NativeRewardController):
class PRMController(NativeRewardController):
"""
Controller that integrate multi Generation output into one prompt and get
reward values from reward model.
Use a PRM (Process Reward Model) to score each output. If `split_steps` is
True, the output is split into steps and each step is scored; otherwise
only the last-token score is extracted.

Output:
The score of each task will be stored in `self.scores`.

Example:
Suppose the model output is split using a special token like <extra_0>:
Input: "Step1,...<extra_0>Step2,...\\boxed{answer}.<extra_0>."
The controller masks the logits so that only the scores at each
separate_token remain. Each score is the probability that the
corresponding step is correct, e.g. [0.98, 1.0]. The output can be
considered good when the product of these probabilities is high.
"""

def __init__(self, tokenizer, separate_token="<extra_0>"): # nosec B107
def __init__(
self,
tokenizer,
split_steps=True,
step_token="\n\n",
separate_token="<extra_0>", # nosec B107
):
super().__init__()
self.tokenizer = tokenizer
self.split_steps = split_steps
self.step_token = step_token
self.separate_token = separate_token

def _make_step_rewards(self, logits, token_masks):
probabilities = F.softmax(logits, dim=-1)
probabilities = probabilities * token_masks.unsqueeze(
-1) # bs, seq_len, num_labels=2

all_scores_res = []
for i in range(probabilities.size(0)):
sample = probabilities[i] # seq_len, num_labels
positive_probs = sample[sample != 0].view(
-1, 2)[:, 1] # num_separate_tokens, num_labels
non_zero_elements_list = positive_probs.cpu().tolist()
all_scores_res.append(non_zero_elements_list)
return all_scores_res

def process(self, tasks: List[Task], **kwargs):
# Combine messages using chat template
content = "".join(
(task.output_str + self.separate_token) for task in tasks)
messages = [
{
"role":
"system",
"content":
"Please reason step by step, and put your final answer within \\boxed{}."
},
{
"role": "user",
"content": tasks[0].input_str
},
{
"role": "assistant",
"content": content
},
]
combined_prompt = self.tokenizer.apply_chat_template(
messages, tokenize=False, add_generation_prompt=False)

# TODO: support input_ids as model input, avoid doing it again in worker
merged_task = GenerationTask.create_from_prompt(combined_prompt)
merged_task.worker_tag = self.WorkerTag.REWARD
def _calc_steps_score(self, logits, token_mask):
probs = F.softmax(logits, dim=-1) # seq_len, num_labels=2
masked_probs = probs * token_mask.unsqueeze(-1)[0]

# TODO: pack this logic
merged_task.max_tokens = 1
merged_task.return_context_logits = True
# only keep the logits at the separate_token
step_probs = masked_probs[masked_probs != 0].view(-1, 2)[:, 1]
score = torch.prod(step_probs).item()
return score

yield [merged_task]
def _calc_last_token_score(self, logits):
# seq_len, num_labels=2
probs = F.softmax(logits, dim=-1)
score = probs[-1, 1].item()
return score

assert merged_task.context_logits is not None
# TODO: consider running on cpu to not interrupt worker or move
# tokenizer to a worker
input_ids = self.tokenizer.encode(
combined_prompt,
return_tensors="pt",
).to(merged_task.context_logits.device)

# TODO: align add_special_tokens with SamplingParams
token_masks = (input_ids == self.tokenizer.encode(
self.separate_token, add_special_tokens=True)[0])
all_scores_res = self._make_step_rewards(merged_task.context_logits,
token_masks)

return all_scores_res
def process(self, tasks: List[Task], **kwargs):
reward_tasks = []
for task in tasks:
if self.split_steps:
steps = task.output_str.split(self.step_token)
content = "".join(
(step + self.separate_token) for step in steps)
else:
content = self.separate_token + task.output_str + self.separate_token
# Combine messages using chat template
messages = [
{
"role":
"system",
"content":
"Please reason step by step, and put your final answer within \\boxed{}."
},
{
"role": "user",
"content": task.input_str
},
{
"role": "assistant",
"content": content
},
]
processed_prompt = self.tokenizer.apply_chat_template(
messages, tokenize=False, add_generation_prompt=False)

# TODO: support input_ids as model input, avoid doing it again in worker
reward_task = GenerationTask.create_from_prompt(processed_prompt)
reward_task.worker_tag = self.WorkerTag.REWARD

# TODO: pack this logic
reward_task.max_tokens = 1
reward_task.return_context_logits = True
reward_tasks.append(reward_task)

yield reward_tasks

scores = []
for reward_task in reward_tasks:
assert reward_task.context_logits is not None
# TODO: consider running on cpu to not interrupt worker or move
# tokenizer to a worker
input_ids = self.tokenizer.encode(
reward_task.input_str,
return_tensors="pt",
).to(reward_task.context_logits.device)

if self.split_steps:
# TODO: align add_special_tokens with SamplingParams
token_mask = (input_ids == self.tokenizer.encode(
self.separate_token, add_special_tokens=True)[0])
score = self._calc_steps_score(reward_task.context_logits,
token_mask)
else:
score = self._calc_last_token_score(reward_task.context_logits)
scores.append(score)

self.scores = scores


# Controller runs a single generation task with majority vote.
@@ -243,21 +279,27 @@ def process(self,
self.generation_controller for _ in range(sample_num)
]
generation_kwargs_list = [generation_kwargs for _ in range(sample_num)]
generation_tasks_list = [copy.deepcopy(task) for _ in range(sample_num)]
generation_tasks = [copy.deepcopy(task) for _ in range(sample_num)]

# yield from self.generation_controller.process(generation_tasks_list,
# **generation_kwargs)
yield ParallelProcess(generation_controllers,
[[t] for t in generation_tasks_list],
[[t] for t in generation_tasks],
generation_kwargs_list)

reward_values = yield from self.reward_controller.process(
generation_tasks_list, **reward_kwargs)
yield from self.reward_controller.process(generation_tasks,
**reward_kwargs)

assert self.reward_controller.scores is not None
reward_values = self.reward_controller.scores

for i, gen_task, reward_value in zip(range(sample_num),
generation_tasks, reward_values):
logger.info(
f"[output {i}, score {reward_value}]:\n{gen_task.output_str}")

best_task = self.select_best(generation_tasks_list, reward_values,
**select_best_kwargs)
best_task, best_idx = self.select_best(generation_tasks, reward_values,
**select_best_kwargs)
task.output_str = best_task.output_str

def select_best(self, tasks: List[Task], reward_values, **kwargs) -> Task:
max_index = torch.argmax(torch.tensor(reward_values)).item()
return tasks[max_index]
return tasks[max_index], max_index
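To make the new scoring path concrete, here is a minimal, self-contained sketch of what `_calc_steps_score` computes: softmax over the two reward labels, keep only the rows at the separator token, and take the product of the positive-class probabilities. The separator token ID and the logits below are synthetic stand-ins, not values from a real PRM checkpoint.

# Illustrative sketch only: synthetic stand-ins for the token IDs and
# context logits; mirrors the _calc_steps_score logic from the diff above.
import torch
import torch.nn.functional as F

SEP_TOKEN_ID = 9999  # hypothetical ID for "<extra_0>"

# A 6-token prompt containing two separator tokens, and 2 reward labels
# per token (label 1 = "step is good").
input_ids = torch.tensor([[101, 7, 8, SEP_TOKEN_ID, 11, SEP_TOKEN_ID]])
context_logits = torch.randn(6, 2)

probs = F.softmax(context_logits, dim=-1)            # seq_len, num_labels=2
token_mask = (input_ids == SEP_TOKEN_ID)             # 1, seq_len
masked_probs = probs * token_mask.unsqueeze(-1)[0]   # zero out non-separator rows
step_probs = masked_probs[masked_probs != 0].view(-1, 2)[:, 1]  # one prob per step
score = torch.prod(step_probs).item()
print(step_probs.tolist(), score)

With `split_steps=False`, `_calc_last_token_score` instead reads `probs[-1, 1]`, the positive-class probability at the final position.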
8 changes: 4 additions & 4 deletions tests/unittest/scaffolding/test_scaffolding.py
@@ -51,7 +51,7 @@ def test_unbatched_scaffolding_sync(default_prompt, deepseek_distill_7b_path):
result = scaffolding_llm.generate(default_prompt)
assert isinstance(result.output.output_str, str) and len(
result.output.output_str) > 0, "Output should be a non-empty string"
scaffolding_llm.shutdown(shutdown_wokers=True)
scaffolding_llm.shutdown(shutdown_workers=True)


def test_batched_scaffolding_sync(default_prompt, deepseek_distill_7b_path):
@@ -64,7 +64,7 @@ def test_batched_scaffolding_sync(default_prompt, deepseek_distill_7b_path):
for result in results:
assert isinstance(result.output.output_str, str) and len(
result.output.output_str) > 0, "Output should be a non-empty string"
scaffolding_llm.shutdown(shutdown_wokers=True)
scaffolding_llm.shutdown(shutdown_workers=True)


def test_async_scaffolding_generation(default_prompt, deepseek_distill_7b_path):
@@ -76,7 +76,7 @@ async def run_async_test():
result = await future.aresult()
assert isinstance(result.output.output_str, str) and len(
result.output.output_str) > 0, "Output should be a non-empty string"
scaffolding_llm.shutdown(shutdown_wokers=True)
scaffolding_llm.shutdown(shutdown_workers=True)

import asyncio
asyncio.run(run_async_test())
@@ -88,4 +88,4 @@ def test_majority_vote(default_prompt, deepseek_distill_7b_path):
result = scaffolding_llm.generate(default_prompt)
assert isinstance(result.output.output_str, str) and len(
result.output.output_str) > 0, "Output should be a non-empty string"
scaffolding_llm.shutdown(shutdown_wokers=True)
scaffolding_llm.shutdown(shutdown_workers=True)