tests/unittest/_torch/modeling/test_modeling_nemotron_h.py (35 additions & 22 deletions)

@@ -1,4 +1,3 @@
-import pytest
 import torch
 from utils.llm_data import llm_models_root
 from utils.util import skip_gpu_memory_less_than
@@ -238,15 +237,15 @@ def test_nemotron_h_correctness():
     nemotron_h.shutdown()
 
 
-@pytest.mark.skip(reason="https://nvbugs/5404046")
 def test_nemotron_h_cuda_graph_overlap_scheduler():
     prompts = [
         "Tell me something I don't know about the future of AI",
         "The president of the United States is",
         "The capital of France is",
         "Hello, this is a beautiful day and I'm eager to start my day and",
         "The sky is blue because",
         "The sum of two and two is",
         "The largest mammal is the",
         "The chemical symbol for water is",
     ]
-    sampling_config = SamplingParams(max_tokens=12,
+    sampling_config = SamplingParams(max_tokens=10,
                                      temperature=0.0,
                                      return_generation_logits=True)
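A minimal sketch of how this greedy sampling configuration is consumed, assuming the tensorrt_llm LLM API that this test already uses; the model path and prompt are illustrative only, and outputs[0].generation_logits is assumed to be a [max_tokens, vocab_size] tensor:

from tensorrt_llm import LLM, SamplingParams

sampling = SamplingParams(max_tokens=10,  # ten generated tokens per prompt
                          temperature=0.0,  # greedy decoding, so repeated runs should agree
                          return_generation_logits=True)

llm = LLM(model="/models/Nemotron-H-8B-Base-8K")  # hypothetical local path
try:
    out = llm.generate(["The capital of France is"],
                       sampling_params=sampling)[0]
    text = out.outputs[0].text  # generated continuation
    logits = out.outputs[0].generation_logits  # assumed shape: [10, vocab_size]
finally:
    llm.shutdown()  # mirrors the explicit shutdown used earlier in this file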

@@ -273,32 +272,46 @@ def test_nemotron_h_cuda_graph_overlap_scheduler():
             prompts, sampling_params=sampling_config, use_tqdm=True)
 
     # Verify outputs are consistent
-    for (no_cg_no_overlap, with_cg_no_overlap,
-         with_cg_with_overlap) in zip(outputs_no_cg_no_overlap,
-                                      outputs_with_cg_no_overlap,
-                                      outputs_with_cg_with_overlap):
-
-        assert (no_cg_no_overlap.outputs[0].text ==
-                with_cg_no_overlap.outputs[0].text)
-        assert (with_cg_no_overlap.outputs[0].text ==
-                with_cg_with_overlap.outputs[0].text)
+    for i, (no_cg_no_overlap, with_cg_no_overlap,
+            with_cg_with_overlap) in enumerate(
+                zip(outputs_no_cg_no_overlap, outputs_with_cg_no_overlap,
+                    outputs_with_cg_with_overlap)):
+
+        assert (
+            no_cg_no_overlap.outputs[0].text ==
+            with_cg_no_overlap.outputs[0].text
+        ), f"Prompt {i}: no CG no overlap generated text != with CG no overlap generated text"
+        assert (
+            with_cg_no_overlap.outputs[0].text ==
+            with_cg_with_overlap.outputs[0].text
+        ), f"Prompt {i}: with CG no overlap generated text != with CG with overlap generated text"
 
         # similar to other unittests comparing with / without CG, compare logits of first generation step (2nd generated token)
         torch.testing.assert_close(
             no_cg_no_overlap.outputs[0].generation_logits[1, :],
             with_cg_no_overlap.outputs[0].generation_logits[1, :],
             atol=0.2,
-            rtol=0.2)
+            rtol=0.2,
+            msg=lambda x:
+            f"Prompt {i}: with/without CG (no overlap) logits for first generated step {x}"
+        )
 
         # compare logprobs of all generated tokens
-        torch.testing.assert_close(extract_decode_logprobs(no_cg_no_overlap),
-                                   extract_decode_logprobs(with_cg_no_overlap),
-                                   atol=0.2,
-                                   rtol=0.2)
+        torch.testing.assert_close(
+            extract_decode_logprobs(no_cg_no_overlap),
+            extract_decode_logprobs(with_cg_no_overlap),
+            atol=0.2,
+            rtol=0.2,
+            msg=lambda x:
+            f"Prompt {i}: with/without CG (no overlap) logprobs for all selected tokens {x}"
+        )
 
         # overlap scheduler should have no effect on all logits - low tolerance
         torch.testing.assert_close(
             with_cg_no_overlap.outputs[0].generation_logits,
             with_cg_with_overlap.outputs[0].generation_logits,
             atol=0.05,
-            rtol=0.05)
+            rtol=0.05,
+            msg=lambda x:
+            f"Prompt {i}: with/without overlap (with CG) all generation logits {x}"
+        )
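The msg arguments added above use the callable form of torch.testing.assert_close's msg parameter: the callable receives the auto-generated mismatch description and returns the final assertion message. Below is a standalone sketch of that mechanism, plus a hedged guess at what a helper like extract_decode_logprobs computes; the real helper is defined earlier in this test file and may differ.

import torch

a = torch.zeros(3)
b = torch.tensor([0.0, 0.0, 1.0])
try:
    # msg accepts a callable that takes the generated message and returns a string.
    torch.testing.assert_close(
        a, b, atol=0.05, rtol=0.05,
        msg=lambda generated: f"Prompt 0: logits mismatch {generated}")
except AssertionError as err:
    print(err)  # message starts with "Prompt 0: logits mismatch"

# Hedged guess: log-softmax the generation logits, then gather the
# log-probability of each token that was actually generated.
def extract_decode_logprobs_sketch(request_output):
    completion = request_output.outputs[0]
    logits = completion.generation_logits  # assumed [num_tokens, vocab_size]
    token_ids = torch.tensor(completion.token_ids)  # generated token ids
    logprobs = torch.log_softmax(logits.float(), dim=-1)
    return logprobs[torch.arange(len(token_ids)), token_ids]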