13 changes: 0 additions & 13 deletions .buildkite/pipeline_jax.yml
@@ -101,7 +101,6 @@ steps:
          exit 0
        fi

-
  - label: "JAX unit tests"
    key: test_7
    soft_fail: true
@@ -221,17 +220,6 @@ steps:
          exit 0
        fi

-  - label: "E2E data parallelism test"
-    key: test_14
-    soft_fail: true
-    env:
-      NEW_MODEL_DESIGN: "True"
-    agents:
-      queue: tpu_v6e_8_queue
-    commands:
-      - |
-        .buildkite/scripts/run_in_docker.sh \
-          bash -c 'python3 -m pytest -s -v -x /workspace/tpu_inference/tests/e2e/test_data_parallel.py'

  # -----------------------------------------------------------------
  # NOTIFICATION STEP
@@ -252,7 +240,6 @@ steps:
      - test_11
      - test_12
      - test_13
-      - test_14
    agents:
      queue: cpu
    commands:
4 changes: 2 additions & 2 deletions tests/e2e/test_async_scheduler.py
@@ -121,9 +121,9 @@ def test_performance(
     '''
     Test that async scheduler decoding provides significant performance improvement.
     Compares timing between reference LLM and async LLM using Qwen2.5-1.5B.
-    Expects async_llm to be at least 1.25x faster than ref_llm.
+    Expects async_llm to be at least 1.3x faster than ref_llm.
     '''
-    min_speed_up = 1.25
+    min_speed_up = 1.3
     _test_performance_helper(monkeypatch, sampling_config, model_name,
                              min_speed_up)

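Note: `_test_performance_helper` lives elsewhere in the repo and is not shown in this diff; the following is a hypothetical sketch of the timing-and-assert pattern the test implies (names and structure are assumed, not the repo's actual implementation):

import time
from typing import Callable

def assert_min_speedup(run_ref: Callable[[], None],
                       run_async: Callable[[], None],
                       min_speed_up: float) -> None:
    """Time both engines and assert the async one clears the speedup bar."""
    t0 = time.perf_counter()
    run_ref()
    ref_s = time.perf_counter() - t0

    t0 = time.perf_counter()
    run_async()
    async_s = time.perf_counter() - t0

    speedup = ref_s / async_s
    assert speedup >= min_speed_up, (
        f"async speedup {speedup:.2f}x below required {min_speed_up}x")
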
223 changes: 0 additions & 223 deletions tests/e2e/test_data_parallel.py

This file was deleted.

4 changes: 1 addition & 3 deletions tests/layers/jax/layers/attention/test_common_attention.py
@@ -21,10 +21,8 @@ class TestAttention(unittest.TestCase):
     def setUp(self):
         """Sets up the testing environment before each test."""
         self.mesh = Mesh(
-            np.array(jax.devices()[:1]).reshape(1, 1, 1, -1),
+            np.array(jax.devices()[:1]).reshape(1, -1),
             axis_names=(
-                "data",
-                "attn_dp",
                 "expert",
                 "model",
             ),
7 changes: 4 additions & 3 deletions tests/layers/jax/layers/test_sharding.py
@@ -70,8 +70,9 @@ def test_default_sharding_config(self):
         sharding_cfg = sharding.get_sharding_cfg()
         generate_rules = sharding_cfg.generate_rules

-        self.assertEqual(generate_rules.ffw_weight_df, (None, "model"))
-        self.assertEqual(generate_rules.moe_router_de, (None, "model"))
+        self.assertEqual(generate_rules.ffw_weight_df,
+                         (None, ("model", "expert")))
+        self.assertEqual(generate_rules.moe_router_de, (None, "expert"))
         self.assertEqual(generate_rules.attn_q_weight_dnh,
                          (None, "model", None))

@@ -86,7 +87,7 @@ def test_sharding_init_with_overrides(self):

         sharding_cfg = sharding.get_sharding_cfg()
         self.assertNotEqual(sharding_cfg.generate_rules.logits_tv,
-                            (None, "model"))
+                            (None, ("model", "expert")))
         self.assertEqual(sharding_cfg.generate_rules.logits_tv,
                          ("data", "model"))

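Note: rules such as (None, ("model", "expert")) are jax.sharding.PartitionSpec entries; a tuple of mesh-axis names shards one tensor dimension across both axes jointly. A minimal self-contained sketch of what the new ffw_weight_df rule means (mesh layout and array sizes are illustrative, not from this repo):

import jax
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec

# Illustrative 2-axis mesh; on a single-host CPU run both axes have size 1.
devices = np.array(jax.devices()).reshape(1, -1)
mesh = Mesh(devices, axis_names=("expert", "model"))

# ffw_weight_df-style rule: dim 0 replicated, dim 1 split jointly across
# the "model" and "expert" axes (their sizes multiply).
spec = PartitionSpec(None, ("model", "expert"))
x = jax.device_put(np.zeros((8, 16), dtype=np.float32),
                   NamedSharding(mesh, spec))
print(x.sharding.spec)  # PartitionSpec(None, ('model', 'expert'))
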
4 changes: 2 additions & 2 deletions tests/layers/vllm/test_attention.py
@@ -81,11 +81,11 @@ def mesh():
     """Provides a mock 1D JAX mesh for testing."""
     # Create a mesh with available devices, useful for running on CPU/GPU/TPU
     # For this test, it will likely be a single CPU device.
-    devices = np.array(jax.local_devices())[0:1]
+    devices = np.array(jax.local_devices())
     if not devices.any():
         # Add a mock device if no devices are present (e.g., in a CI environment)
         devices = np.array([jax.devices("cpu")[0]])
-    return Mesh(devices.reshape((-1, 1, 1)), ("data", "attn_dp", "model"))
+    return Mesh(devices.reshape((-1, 1)), ("data", "model"))


 class TestPallasAttentionBackend:
6 changes: 2 additions & 4 deletions tests/models/common/test_model_loader.py
@@ -4,7 +4,6 @@

 import jax
 import jax.numpy as jnp
-import numpy as np
 import pytest
 import torch
 from jax.sharding import Mesh
@@ -40,10 +39,9 @@ def __call__(self, kv_caches, input_ids, attention_metadata):
 @pytest.fixture(scope="module")
 def mesh() -> Mesh:
     """Provides a JAX device mesh for sharding."""
-    devices = np.array(jax.devices()[:1])
-    devices = devices.reshape((1, 1, 1, -1))
+    devices = jax.devices()
+    # Pass the 1D list of devices directly. Its ndim will match len(axis_names).
-    return Mesh(devices, axis_names=("data", "attn_dp", "expert", "model"))
+    return Mesh(devices, axis_names=("model", ))


 @pytest.fixture
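Note: the comment added above relies on Mesh accepting any device array whose ndim equals len(axis_names); with a single axis name, the flat jax.devices() list needs no reshape. A minimal sketch of that pattern:

import jax
from jax.sharding import Mesh

# One axis name, so the 1D device list's ndim (1) already matches.
mesh = Mesh(jax.devices(), axis_names=("model",))
print(dict(mesh.shape))  # e.g. {'model': 1} on a single-device host
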
4 changes: 2 additions & 2 deletions tests/models/jax/test_attention_interface.py
@@ -39,11 +39,11 @@ def mesh():
     """Provides a mock 1D JAX mesh for testing."""
     # Create a mesh with available devices, useful for running on CPU/GPU/TPU
     # For this test, it will likely be a single CPU device.
-    devices = np.array(jax.local_devices()[:1])
+    devices = np.array(jax.local_devices())
     if not devices.any():
         # Add a mock device if no devices are present (e.g., in a CI environment)
         devices = np.array([jax.devices("cpu")[0]])
-    return Mesh(devices.reshape((-1, 1, 1)), ("data", "attn_dp", "model"))
+    return Mesh(devices.reshape((-1, 1)), ("data", "model"))


 # ---- Test for `attention` ----
12 changes: 3 additions & 9 deletions tests/models/jax/test_llama3.py
@@ -36,10 +36,9 @@ def mesh():
     devices = np.array(jax.local_devices()[:1])
     num_devices = len(devices)
     assert num_devices == 1
-    device_mesh = devices.reshape((num_devices, 1, 1, 1))
+    device_mesh = devices.reshape((num_devices, 1))

-    with Mesh(device_mesh,
-              axis_names=('data', 'attn_dp', 'expert', 'model')) as m:
+    with Mesh(device_mesh, axis_names=('data', 'model')) as m:
         yield m

@@ -90,12 +89,7 @@ def test_llama32_1b(self, mock_vllm_config, rng, mesh, mock_model_inputs):
         model_config = mock_vllm_config.model_config
         hf_config = model_config.hf_config

-        assert model.mesh.shape == {
-            "data": 1,
-            "attn_dp": 1,
-            "expert": 1,
-            "model": 1
-        }
+        assert model.mesh.shape == {"data": 1, "model": 1}

         layers = model.model.layers
         assert len(layers) == hf_config.num_hidden_layers
5 changes: 2 additions & 3 deletions tests/models/jax/test_llama4.py
@@ -81,10 +81,9 @@ def mesh():
     # Reshape devices into a 3D array to name 3 axes: data, model, and expert.
     # The 'model' and 'expert' axes will have a size of 1.
     num_devices = len(devices)
-    device_mesh = devices.reshape((num_devices, 1, 1, 1))
+    device_mesh = devices.reshape((num_devices, 1, 1))

-    with Mesh(device_mesh,
-              axis_names=('data', 'attn_dp', 'model', 'expert')) as m:
+    with Mesh(device_mesh, axis_names=('data', 'model', 'expert')) as m:
         yield m
