Commit 566a58c

Revert "Data Parallelism support (#865)"
This reverts commit a27922a.
1 parent f0cdad5 commit 566a58c

41 files changed: +297 -2352 lines

.buildkite/pipeline_jax.yml

Lines changed: 0 additions & 13 deletions
@@ -101,7 +101,6 @@ steps:
         exit 0
       fi
 
-
   - label: "JAX unit tests"
     key: test_7
     soft_fail: true
@@ -221,17 +220,6 @@ steps:
         exit 0
       fi
 
-  - label: "E2E data parallelism test"
-    key: test_14
-    soft_fail: true
-    env:
-      NEW_MODEL_DESIGN: "True"
-    agents:
-      queue: tpu_v6e_8_queue
-    commands:
-      - |
-        .buildkite/scripts/run_in_docker.sh \
-          bash -c 'python3 -m pytest -s -v -x /workspace/tpu_inference/tests/e2e/test_data_parallel.py'
 
   # -----------------------------------------------------------------
   # NOTIFICATION STEP
@@ -252,7 +240,6 @@ steps:
       - test_11
       - test_12
       - test_13
-      - test_14
     agents:
       queue: cpu
     commands:

tests/e2e/test_async_scheduler.py

Lines changed: 2 additions & 2 deletions
@@ -121,9 +121,9 @@ def test_performance(
     '''
     Test that async scheduler decoding provides significant performance improvement.
     Compares timing between reference LLM and async LLM using Qwen2.5-1.5B.
-    Expects async_llm to be at least 1.25x faster than ref_llm.
+    Expects async_llm to be at least 1.3x faster than ref_llm.
     '''
-    min_speed_up = 1.25
+    min_speed_up = 1.3
     _test_performance_helper(monkeypatch, sampling_config, model_name,
                              min_speed_up)

tests/e2e/test_data_parallel.py

Lines changed: 0 additions & 223 deletions
This file was deleted.

tests/layers/jax/layers/attention/test_common_attention.py

Lines changed: 1 addition & 3 deletions
@@ -21,10 +21,8 @@ class TestAttention(unittest.TestCase):
     def setUp(self):
         """Sets up the testing environment before each test."""
         self.mesh = Mesh(
-            np.array(jax.devices()[:1]).reshape(1, 1, 1, -1),
+            np.array(jax.devices()[:1]).reshape(1, -1),
             axis_names=(
-                "data",
-                "attn_dp",
                 "expert",
                 "model",
             ),

tests/layers/jax/layers/test_sharding.py

Lines changed: 4 additions & 3 deletions
@@ -70,8 +70,9 @@ def test_default_sharding_config(self):
         sharding_cfg = sharding.get_sharding_cfg()
         generate_rules = sharding_cfg.generate_rules
 
-        self.assertEqual(generate_rules.ffw_weight_df, (None, "model"))
-        self.assertEqual(generate_rules.moe_router_de, (None, "model"))
+        self.assertEqual(generate_rules.ffw_weight_df,
+                         (None, ("model", "expert")))
+        self.assertEqual(generate_rules.moe_router_de, (None, "expert"))
         self.assertEqual(generate_rules.attn_q_weight_dnh,
                          (None, "model", None))
 
@@ -86,7 +87,7 @@ def test_sharding_init_with_overrides(self):
 
         sharding_cfg = sharding.get_sharding_cfg()
         self.assertNotEqual(sharding_cfg.generate_rules.logits_tv,
-                            (None, "model"))
+                            (None, ("model", "expert")))
         self.assertEqual(sharding_cfg.generate_rules.logits_tv,
                          ("data", "model"))
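
For context on these rule tuples: each position maps a tensor dimension to zero or more mesh axis names, in the same way a JAX PartitionSpec does, so (None, ("model", "expert")) means dimension 0 is replicated and dimension 1 is sharded jointly across the "model" and "expert" axes. A minimal sketch of that interpretation follows; it is not code from this repository, and it simulates a four-device mesh on CPU purely for illustration.

# Hypothetical illustration, not part of this commit.
import os
os.environ.setdefault("XLA_FLAGS", "--xla_force_host_platform_device_count=4")

import jax
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec

# A 1x2x2 mesh using the post-revert axis names.
devices = np.array(jax.devices()[:4]).reshape(1, 2, 2)
mesh = Mesh(devices, axis_names=("data", "model", "expert"))

# (None, ("model", "expert")): dim 0 replicated, dim 1 sharded 4-way
# across the combined "model" and "expert" axes.
spec = PartitionSpec(None, ("model", "expert"))
x = jax.device_put(np.ones((8, 8), dtype=np.float32), NamedSharding(mesh, spec))
print(x.sharding)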

tests/layers/vllm/test_attention.py

Lines changed: 2 additions & 2 deletions
@@ -81,11 +81,11 @@ def mesh():
     """Provides a mock 1D JAX mesh for testing."""
     # Create a mesh with available devices, useful for running on CPU/GPU/TPU
     # For this test, it will likely be a single CPU device.
-    devices = np.array(jax.local_devices())[0:1]
+    devices = np.array(jax.local_devices())
     if not devices.any():
         # Add a mock device if no devices are present (e.g., in a CI environment)
         devices = np.array([jax.devices("cpu")[0]])
-    return Mesh(devices.reshape((-1, 1, 1)), ("data", "attn_dp", "model"))
+    return Mesh(devices.reshape((-1, 1)), ("data", "model"))
 
 
 class TestPallasAttentionBackend:
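
The reverted fixture pattern is easy to reproduce standalone. Below is a hypothetical, self-contained sketch of the post-revert construction (not part of this commit): all local devices land on the "data" axis, "model" keeps size 1, and a CPU device is used as a fallback when none are reported.

# Hypothetical standalone sketch of the post-revert mesh construction.
import jax
import numpy as np
from jax.sharding import Mesh

devices = np.array(jax.local_devices())
if not devices.any():
    # Fall back to a CPU device, e.g. in a CI environment with no accelerators.
    devices = np.array([jax.devices("cpu")[0]])

# Every device sits on the "data" axis; the "model" axis has size 1.
mesh = Mesh(devices.reshape((-1, 1)), ("data", "model"))
print(dict(mesh.shape))  # e.g. {'data': 1, 'model': 1} on a single-device host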

tests/models/common/test_model_loader.py

Lines changed: 2 additions & 4 deletions
@@ -4,7 +4,6 @@
 
 import jax
 import jax.numpy as jnp
-import numpy as np
 import pytest
 import torch
 from jax.sharding import Mesh
@@ -40,10 +39,9 @@ def __call__(self, kv_caches, input_ids, attention_metadata):
 @pytest.fixture(scope="module")
 def mesh() -> Mesh:
     """Provides a JAX device mesh for sharding."""
-    devices = np.array(jax.devices()[:1])
-    devices = devices.reshape((1, 1, 1, -1))
+    devices = jax.devices()
     # Pass the 1D list of devices directly. Its ndim will match len(axis_names).
-    return Mesh(devices, axis_names=("data", "attn_dp", "expert", "model"))
+    return Mesh(devices, axis_names=("model", ))
 
 
 @pytest.fixture

tests/models/jax/test_attention_interface.py

Lines changed: 2 additions & 2 deletions
@@ -39,11 +39,11 @@ def mesh():
     """Provides a mock 1D JAX mesh for testing."""
     # Create a mesh with available devices, useful for running on CPU/GPU/TPU
     # For this test, it will likely be a single CPU device.
-    devices = np.array(jax.local_devices()[:1])
+    devices = np.array(jax.local_devices())
     if not devices.any():
         # Add a mock device if no devices are present (e.g., in a CI environment)
         devices = np.array([jax.devices("cpu")[0]])
-    return Mesh(devices.reshape((-1, 1, 1)), ("data", "attn_dp", "model"))
+    return Mesh(devices.reshape((-1, 1)), ("data", "model"))
 
 
 # ---- Test for `attention` ----

tests/models/jax/test_llama3.py

Lines changed: 3 additions & 9 deletions
@@ -36,10 +36,9 @@ def mesh():
     devices = np.array(jax.local_devices()[:1])
     num_devices = len(devices)
     assert num_devices == 1
-    device_mesh = devices.reshape((num_devices, 1, 1, 1))
+    device_mesh = devices.reshape((num_devices, 1))
 
-    with Mesh(device_mesh,
-              axis_names=('data', 'attn_dp', 'expert', 'model')) as m:
+    with Mesh(device_mesh, axis_names=('data', 'model')) as m:
         yield m
 
 
@@ -90,12 +89,7 @@ def test_llama32_1b(self, mock_vllm_config, rng, mesh, mock_model_inputs):
         model_config = mock_vllm_config.model_config
         hf_config = model_config.hf_config
 
-        assert model.mesh.shape == {
-            "data": 1,
-            "attn_dp": 1,
-            "expert": 1,
-            "model": 1
-        }
+        assert model.mesh.shape == {"data": 1, "model": 1}
 
         layers = model.model.layers
         assert len(layers) == hf_config.num_hidden_layers
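
The collapsed one-line assertion works because Mesh.shape is a mapping from axis name to axis size, so it compares equal to a plain dict. A hedged sketch (not from this commit) with the post-revert two-axis, single-device mesh:

# Hypothetical illustration of the post-revert mesh-shape assertion.
import jax
import numpy as np
from jax.sharding import Mesh

devices = np.array(jax.local_devices()[:1])
device_mesh = devices.reshape((1, 1))

with Mesh(device_mesh, axis_names=("data", "model")) as m:
    # Mesh.shape maps each axis name to its size.
    assert m.shape == {"data": 1, "model": 1}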

tests/models/jax/test_llama4.py

Lines changed: 2 additions & 3 deletions
@@ -81,10 +81,9 @@ def mesh():
     # Reshape devices into a 3D array to name 3 axes: data, model, and expert.
     # The 'model' and 'expert' axes will have a size of 1.
     num_devices = len(devices)
-    device_mesh = devices.reshape((num_devices, 1, 1, 1))
+    device_mesh = devices.reshape((num_devices, 1, 1))
 
-    with Mesh(device_mesh,
-              axis_names=('data', 'attn_dp', 'model', 'expert')) as m:
+    with Mesh(device_mesh, axis_names=('data', 'model', 'expert')) as m:
         yield m