
Commit dc64529

update tests
Signed-off-by: Ming Yang <[email protected]>
1 parent: 54be252

2 files changed, 6 insertions(+), 4 deletions(-)

hopper/test_flash_attn.py

4 additions, 2 deletions

@@ -122,9 +122,9 @@
     ],
 )
 @pytest.mark.parametrize(
-    "cp_world_size", [4, 2],
+    "cp_world_size", [4, 2, 1], # 1 means disabling cp
 )
-# @pytest.mark.parametrize('seqlen_q,seqlen_k', [(128, 128)])
+#@pytest.mark.parametrize('seqlen_q,seqlen_k', [(1, 1)])
 def test_flash_attn_output(
     seqlen_q, seqlen_k, d, causal, local, softcap, V_colmajor, deterministic, has_qv_, mha_type, dtype, test_sink,
     cp_world_size,
@@ -135,6 +135,8 @@ def test_flash_attn_output(
         pytest.skip("Has Qv requires hdim 64 and dtype to be float16 or bfloat16 (not float8_e4m3fn)")
     if test_sink and has_qv_:
         pytest.skip("Sink disabled for Qv")
+    if cp_world_size > 1 and local:
+        pytest.skip("context parallelism is not supported with local attention yet")
     device = "cuda"
     # set seed
     torch.random.manual_seed(0)
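
In this test module, cp_world_size is the number of context-parallel ranks the key/value sequence is split across, so parametrizing it with 1 also exercises the ordinary non-CP path, and the new skip sidesteps the still-unsupported combination of context parallelism with local (sliding-window) attention. A minimal sketch of the sharding idea, where shard_kv_for_rank is a hypothetical helper and not part of this repository:

import torch

def shard_kv_for_rank(kv: torch.Tensor, cp_world_size: int, cp_rank: int) -> torch.Tensor:
    # Split the key/value tensor along the sequence dimension (dim=1) into
    # cp_world_size contiguous chunks and return the chunk owned by cp_rank.
    if cp_world_size == 1:  # cp disabled: this rank keeps the full sequence
        return kv
    return kv.chunk(cp_world_size, dim=1)[cp_rank]

kv = torch.randn(2, 128, 8, 64)  # (batch, seqlen_k, nheads, headdim)
print(shard_kv_for_rank(kv, cp_world_size=4, cp_rank=0).shape)  # torch.Size([2, 32, 8, 64])
print(shard_kv_for_rank(kv, cp_world_size=1, cp_rank=0).shape)  # torch.Size([2, 128, 8, 64])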

hopper/test_util.py

2 additions, 2 deletions

@@ -250,12 +250,12 @@ def construct_cp_mask(
 
     # Calculate effective sequence lengths
     sk = (
-        seqlen_k * cp_world_size # Global seqlen_k for DCP
+        torch.tensor(seqlen_k * cp_world_size, device=device, dtype=torch.long) # Global seqlen_k for DCP
         if key_padding_mask is None
         else rearrange(key_padding_mask.sum(-1), "b -> b 1 1 1") * cp_world_size
     )
     sq = (
-        seqlen_q
+        torch.tensor(seqlen_q, device=device, dtype=torch.long) # Global seqlen_k for DCP
         if query_padding_mask is None
         else rearrange(query_padding_mask.sum(-1), "b -> b 1 1 1")
     )
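
The change above wraps the scalar sequence lengths in 0-dim torch.long tensors so that sk and sq are tensors whether or not a padding mask is supplied. A plausible reason (an assumption, not stated in the commit) is that the mask arithmetic that follows uses elementwise tensor ops such as torch.minimum, which require tensor operands on both sides; a small self-contained sketch of the difference:

import torch

device = "cpu"  # the real tests run on "cuda"
seqlen_k, cp_world_size = 8, 4

# Column indices over the global (CP-wide) key length, as a mask builder might use.
col_idx = torch.arange(seqlen_k * cp_world_size, device=device, dtype=torch.long)

sk_int = seqlen_k * cp_world_size                                  # plain Python int
sk_tensor = torch.tensor(sk_int, device=device, dtype=torch.long)  # what the diff switches to

# torch.minimum expects Tensor operands, so the int form is rejected while the
# 0-dim tensor broadcasts against col_idx without complaint.
try:
    torch.minimum(col_idx, sk_int)
except TypeError as err:
    print("int operand rejected:", err)

print(torch.minimum(col_idx, sk_tensor))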
