@@ -18,7 +18,7 @@
 from vllm import envs
 from vllm.attention import AttentionMetadata
 from vllm.config import VllmConfig
-from vllm.model_executor.layers.activation import MulAndSilu, get_act_fn
+from vllm.model_executor.layers.activation import SiluAndMul, get_act_fn
 from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
 from vllm.model_executor.model_loader.loader import DefaultModelLoader
@@ -252,33 +252,50 @@ def forward(self, audio_embeds: torch.Tensor) -> torch.Tensor:
         return audio_embeds


+class FlippedSiluAndMul(SiluAndMul):
+    """Ultravox is trained with SwiGLU with flipped halves."""
+
+    def forward(self, x: torch.Tensor):
+        a, b = x.chunk(2, dim=-1)
+        flipped = torch.cat((b, a), dim=-1)
+        return super().forward(flipped)
+
+
 class UltravoxProjector(nn.Module):

     def __init__(self, config: UltravoxConfig):
         super().__init__()
         self.hidden_dim = config.hidden_size
         self._pad_and_stack = StackAudioFrames(config.stack_factor)
-        dim = config.audio_config.hidden_size * config.stack_factor
-        self.ln_pre = RMSNorm(dim)
-        self.linear_1 = nn.Linear(dim, self.hidden_dim, bias=False)
-        dim = self.hidden_dim
+        dim_in = config.audio_config.hidden_size * config.stack_factor
+        self.ln_pre = RMSNorm(dim_in)
+        self.linear_1 = nn.Linear(dim_in, self.hidden_dim, bias=False)
+        dim_mid = self.hidden_dim

         if config.projector_act == "swiglu":
-            self.act = MulAndSilu()
-            dim = dim // 2
+            self.act = FlippedSiluAndMul()
+            dim_mid = dim_mid // 2
         else:
             self.act = get_act_fn(config.projector_act)

-        self.linear_2 = nn.Linear(dim,
-                                  config.text_config.hidden_size,
-                                  bias=False)
-        self.ln_post = RMSNorm(config.text_config.hidden_size)
+        dim_out = config.text_config.hidden_size
+        self.linear_2 = nn.Linear(dim_mid, dim_out, bias=False)
+
+        # Ultravox v0.4.1 and below use layer_norm after the second linear
+        # layer, while v0.5.0 and above use it after the first.
+        if config.projector_ln_mid:
+            self.ln_mid: nn.Module = RMSNorm(dim_mid)
+            self.ln_post = nn.Identity()
+        else:
+            self.ln_mid = nn.Identity()
+            self.ln_post = RMSNorm(dim_out)

     def forward(self, audio_features: torch.Tensor) -> torch.Tensor:
         audio_features = self._pad_and_stack(audio_features)
         audio_features = self.ln_pre(audio_features)
         hidden_states = self.linear_1(audio_features)
         hidden_states = self.act(hidden_states)
+        hidden_states = self.ln_mid(hidden_states)
         hidden_states = self.linear_2(hidden_states)
         hidden_states = self.ln_post(hidden_states)
         return hidden_states
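To see why `FlippedSiluAndMul` above reproduces Ultravox's flipped SwiGLU, here is a minimal sketch in plain PyTorch, assuming (as in vLLM) that `SiluAndMul` computes `silu(first_half) * second_half`; the `silu_and_mul` function below is a stand-in for that kernel, not vLLM's implementation:

```python
import torch
import torch.nn.functional as F


def silu_and_mul(x: torch.Tensor) -> torch.Tensor:
    # Stand-in for vLLM's SiluAndMul: SiLU on the first half, times the second.
    a, b = x.chunk(2, dim=-1)
    return F.silu(a) * b


def flipped_silu_and_mul(x: torch.Tensor) -> torch.Tensor:
    # Swap the halves before SiluAndMul, as FlippedSiluAndMul does.
    a, b = x.chunk(2, dim=-1)
    return silu_and_mul(torch.cat((b, a), dim=-1))


x = torch.randn(4, 8)
a, b = x.chunk(2, dim=-1)
# Ultravox's trained ordering: SiLU gates the *second* half.
expected = a * F.silu(b)
assert torch.allclose(flipped_silu_and_mul(x), expected)
```

Feeding `cat((b, a))` into `SiluAndMul` yields `silu(b) * a`, which is exactly the `a * silu(b)` ordering the checkpoint was trained with, so no new kernel is needed.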
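The `projector_ln_mid` branch can be seen end to end in a toy re-implementation of the projector dataflow. This is only a sketch: the dimensions are made up, `nn.LayerNorm` stands in for vLLM's `RMSNorm`, and frame stacking is omitted; it is not the actual `UltravoxProjector`:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class ToyProjector(nn.Module):
    def __init__(self, dim_in=64, dim_hidden=32, dim_out=16, ln_mid=True):
        super().__init__()
        self.linear_1 = nn.Linear(dim_in, dim_hidden, bias=False)
        dim_mid = dim_hidden // 2  # SwiGLU halves the width
        self.linear_2 = nn.Linear(dim_mid, dim_out, bias=False)
        # v0.5.0+ checkpoints: norm between activation and linear_2 (ln_mid);
        # v0.4.1 and below: norm after linear_2 (ln_post).
        self.ln_mid = nn.LayerNorm(dim_mid) if ln_mid else nn.Identity()
        self.ln_post = nn.Identity() if ln_mid else nn.LayerNorm(dim_out)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        h = self.linear_1(x)
        a, b = h.chunk(2, dim=-1)
        h = a * F.silu(b)      # Ultravox's flipped SwiGLU ordering
        h = self.ln_mid(h)     # active for v0.5.0+ checkpoints
        h = self.linear_2(h)
        return self.ln_post(h) # active for <= v0.4.1 checkpoints


x = torch.randn(2, 10, 64)
print(ToyProjector(ln_mid=True)(x).shape)  # torch.Size([2, 10, 16])
```

Using `nn.Identity()` for the inactive norm keeps the forward pass a single straight-line sequence, which is why the diff adds an unconditional `self.ln_mid(...)` call rather than branching in `forward`.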