@@ -22,7 +22,7 @@ def clear_cache():
 
 # Define MLA and non-MLA backends separately
 DEVICE_MLA_BACKENDS = {
-    "cuda": ["TRITON_MLA", "FLASHMLA"],
+    "cuda": ["TRITON_MLA", "FLASHMLA", "FLASH_ATTN_MLA", "CUTLASS_MLA"],
     "hip": ["TRITON_MLA", "ROCM_AITER_MLA"],
     "cpu": [],
 }
@@ -98,21 +98,14 @@ def test_env(
             with patch("vllm.attention.selector.current_platform",
                        RocmPlatform()):
                 if use_mla:
-                    # Validate HIP MLA backend-block_size combinations
-                    valid_combination = (
-                        (name == "TRITON_MLA" and block_size != 1)
-                        or (name == "ROCM_AITER_MLA" and block_size == 1))
-
-                    if valid_combination:
-                        backend = get_attn_backend(16,
-                                                   torch.float16,
-                                                   torch.float16,
-                                                   block_size,
-                                                   False,
-                                                   use_mla=use_mla)
-                        expected = f"{name}_VLLM_V1" if use_v1 else name
-                        assert backend.get_name() == expected
-                    else:
+                    # ROCm MLA backend logic:
+                    # - TRITON_MLA: supported when block_size != 1
+                    # - ROCM_AITER_MLA: supported when block_size == 1
+                    # If backend is forced but doesn't match block_size,
+                    # should raise ValueError
+
+                    if name == "TRITON_MLA" and block_size == 1:
+                        # TRITON_MLA doesn't support block_size == 1
                         with pytest.raises(ValueError) as exc_info:
                             get_attn_backend(16,
                                              torch.float16,
@@ -122,6 +115,27 @@ def test_env(
                                              use_mla=use_mla)
                         assert f"The selected backend, {name}" in str(
                             exc_info.value)
+                    elif name == "ROCM_AITER_MLA" and block_size != 1:
+                        # ROCM_AITER_MLA only supports block_size == 1
+                        with pytest.raises(ValueError) as exc_info:
+                            get_attn_backend(16,
+                                             torch.float16,
+                                             torch.float16,
+                                             block_size,
+                                             False,
+                                             use_mla=use_mla)
+                        assert f"The selected backend, {name}" in str(
+                            exc_info.value)
+                    else:
+                        # Valid backend-block_size combination
+                        backend = get_attn_backend(16,
+                                                   torch.float16,
+                                                   torch.float16,
+                                                   block_size,
+                                                   False,
+                                                   use_mla=use_mla)
+                        expected = f"{name}_VLLM_V1" if use_v1 else name
+                        assert backend.get_name() == expected
                 else:
                     backend = get_attn_backend(16,
                                                torch.float16,
@@ -136,26 +150,68 @@ def test_env(
             with patch("vllm.attention.selector.current_platform",
                        CudaPlatform()):
                 if use_mla:
-                    if name == "FLASHMLA" and block_size == 64:
-                        from vllm.attention.backends.flashmla import (
-                            is_flashmla_supported)
-
-                        # only on cuda platforms with specific capability.
-                        is_supported, _ = is_flashmla_supported()
-
-                        if not is_supported:
-                            # if platform is not supported then skip this case.
-                            pytest.skip()
+                    # CUDA MLA backend logic:
+                    # - CUTLASS_MLA: only supported with block_size == 128
+                    #   and Blackwell GPUs (SM 10.0), V1 only
+                    # - FLASHMLA: only supported with block_size == 64
+                    # - FLASH_ATTN_MLA: V1 only
+                    # - TRITON_MLA: fallback for other cases
+
+                    if name == "CUTLASS_MLA":
+                        if not use_v1:
+                            # CUTLASS_MLA only supported on V1 engine
+                            pytest.skip(
+                                "CUTLASS_MLA only supported on V1 engine")
+                        elif block_size != 128:
+                            # CUTLASS_MLA only supports block_size == 128
+                            pytest.skip(
+                                "CUTLASS_MLA only supports block_size 128")
+                        else:
+                            backend = get_attn_backend(16,
+                                                       torch.float16,
+                                                       torch.float16,
+                                                       block_size,
+                                                       False,
+                                                       use_mla=use_mla)
+                            expected = "CUTLASS_MLA_VLLM_V1"
+                            assert backend.get_name() == expected
+                    elif name == "FLASHMLA":
+                        if block_size != 64:
+                            # FlashMLA only supports block_size == 64
+                            pytest.skip("FlashMLA only supports block_size 64")
+                        else:
+                            from vllm.attention.backends.flashmla import (
+                                is_flashmla_supported)
+                            is_supported, _ = is_flashmla_supported()
+                            if not is_supported:
+                                pytest.skip(
+                                    "FlashMLA not supported on this platform")
+                            else:
+                                backend = get_attn_backend(16,
+                                                           torch.float16,
+                                                           torch.float16,
+                                                           block_size,
+                                                           False,
+                                                           use_mla=use_mla)
+                                expected = f"{name}_VLLM_V1" if use_v1 else name
+                                assert backend.get_name() == expected
+                    elif name == "FLASH_ATTN_MLA":
+                        if not use_v1:
+                            # FlashAttention MLA only supported on V1 engine
+                            pytest.skip(
+                                "FlashAttention MLA only supported on V1 engine"
+                            )
                         else:
                             backend = get_attn_backend(16,
                                                        torch.float16,
                                                        torch.float16,
                                                        block_size,
                                                        False,
                                                        use_mla=use_mla)
-                            expected = f"{name}_VLLM_V1" if use_v1 else name
+                            expected = "FLASH_ATTN_MLA"
                             assert backend.get_name() == expected
                     else:
+                        # TRITON_MLA or other fallback
                         backend = get_attn_backend(16,
                                                    torch.float16,
                                                    torch.float16,
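For review convenience, the compatibility rules exercised by the new test branches can be summarized as a small predicate table. The sketch below is illustrative only: `MLA_COMPAT_RULES` and `is_expected_to_select` are hypothetical names, not part of vLLM or of this change. It simply restates the constraints from the comments in the diff (ROCm: TRITON_MLA needs block_size != 1, ROCM_AITER_MLA needs block_size == 1; CUDA: FLASHMLA needs block_size == 64, CUTLASS_MLA needs the V1 engine and block_size == 128, FLASH_ATTN_MLA is V1-only, TRITON_MLA is the fallback).

```python
# Minimal sketch, not part of this diff or of vLLM: the helper and table
# names are hypothetical. It restates the MLA backend/block_size rules that
# the test branches above encode.
from typing import Callable, Dict

# Backend name -> predicate over (block_size, use_v1) under which the test
# expects get_attn_backend() to resolve successfully; other combinations are
# either skipped or expected to raise ValueError. The TRITON_MLA rule
# (block_size != 1) applies to the ROCm branch; on CUDA it is the fallback.
MLA_COMPAT_RULES: Dict[str, Callable[[int, bool], bool]] = {
    "TRITON_MLA": lambda block_size, use_v1: block_size != 1,
    "ROCM_AITER_MLA": lambda block_size, use_v1: block_size == 1,
    "FLASHMLA": lambda block_size, use_v1: block_size == 64,
    "CUTLASS_MLA": lambda block_size, use_v1: use_v1 and block_size == 128,
    "FLASH_ATTN_MLA": lambda block_size, use_v1: use_v1,
}


def is_expected_to_select(name: str, block_size: int, use_v1: bool) -> bool:
    """Return True when the test expects backend selection to succeed."""
    rule = MLA_COMPAT_RULES.get(name)
    return rule is not None and rule(block_size, use_v1)


# Examples mirroring the branches above:
assert is_expected_to_select("FLASHMLA", 64, use_v1=True)
assert not is_expected_to_select("ROCM_AITER_MLA", 16, use_v1=True)
assert not is_expected_to_select("CUTLASS_MLA", 128, use_v1=False)
```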