Conversation

@kliuae kliuae (Contributor) commented Sep 2, 2025

Purpose

This PR targets ROCm AITER, introducing a flag-gated path that fuses DeepSeek models' shared_experts into the AITER FusedMoE kernel, reducing the overhead of the separate shared-expert MLP and the subsequent addition while preserving numeric behavior.

When shared experts fusion is enabled, the shared experts are treated as synthetic routed experts appended after the original routed experts and are allocated top-k slots through grouped_topk, so a single fused MoE dispatch covers both shared and routed experts.
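As a rough illustration of the idea (not code from this PR; the helper name and shapes are assumptions), with E routed experts and S shared experts, every token's routed top-k selection is extended with the synthetic expert IDs E..E+S-1 at a fixed weight:

import torch

def append_shared_experts(topk_ids: torch.Tensor,
                          topk_weights: torch.Tensor,
                          n_routed_experts: int,
                          n_shared_experts: int,
                          shared_experts_score: float = 1.0):
    # topk_ids / topk_weights: (num_tokens, top_k) as produced by grouped_topk.
    # Returns (num_tokens, top_k + n_shared_experts) tensors in which the
    # shared experts appear as always-selected "routed" experts appended
    # after the real routed expert IDs.
    num_tokens = topk_ids.shape[0]
    shared_ids = torch.arange(n_routed_experts,
                              n_routed_experts + n_shared_experts,
                              dtype=topk_ids.dtype,
                              device=topk_ids.device).expand(num_tokens, -1)
    shared_weights = torch.full((num_tokens, n_shared_experts),
                                shared_experts_score,
                                dtype=topk_weights.dtype,
                                device=topk_weights.device)
    return (torch.cat([topk_ids, shared_ids], dim=1),
            torch.cat([topk_weights, shared_weights], dim=1))

The PR itself preallocates these extra columns once per buffer (see init_aiter_topK_meta_data later in the review thread) rather than concatenating on every step.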

This feature is controlled by the environment flag VLLM_ROCM_USE_AITER_FUSION_SHARED_EXPERTS, which only takes effect when VLLM_ROCM_USE_AITER_MOE is also set.
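For reference, the gating helper referenced later in the diff (is_rocm_aiter_fusion_shared_expert_enabled) presumably reduces to checking both flags; a sketch under that assumption, not the actual vLLM helper body:

import vllm.envs as envs

def is_rocm_aiter_fusion_shared_expert_enabled() -> bool:
    # Assumption based on the PR description: fusion of shared experts is
    # only honored when the AITER MoE path itself is enabled.
    return bool(envs.VLLM_ROCM_USE_AITER_MOE
                and envs.VLLM_ROCM_USE_AITER_FUSION_SHARED_EXPERTS)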

Test Plan

The following tests validate DeepSeek models by collecting benchmark metrics and performing correctness tests through lm_eval.

vLLM server launch command:

# Toggle VLLM_ROCM_USE_AITER_FUSION_SHARED_EXPERTS

VLLM_USE_V1=1 \
VLLM_ROCM_USE_AITER=1 \
VLLM_ROCM_USE_AITER_MOE=1 \
VLLM_ROCM_USE_AITER_FUSION_SHARED_EXPERTS=${FUSE_SHARED_EXPERTS} \
VLLM_DISABLE_COMPILE_CACHE=1 \
vllm serve ${model_name} --tensor-parallel-size ${tp_size} --block-size 1 --compilation-config '{"cudagraph_mode": "FULL_AND_PIECEWISE"}'

Benchmark commands:

# sharegpt dataset
vllm bench serve --model ${model_name} --dataset-name sharegpt --dataset-path ShareGPT_V3_unfiltered_cleaned_split.json --percentile-metrics ttft,tpot,itl,e2el --request-rate ${request_rate}

# random dataset
vllm bench serve --model ${model_name} --percentile-metrics ttft,tpot,itl,e2el --request-rate 10 --num-prompts 100 --dataset-name random --random-input-len 1024 --random-output-len 1024

lm_eval command:

lm_eval --model local-completions --tasks gsm8k --model_args model=${model_name},base_url=http://localhost:8000/v1/completions,num_concurrent=128,max_retries=3,tokenized_requests=False

Test Result

Benchmark results

deepseek-ai/DeepSeek-R1 on sharegpt dataset

| Aiter Fused Shared Experts | Request Rate | QPS   | Throughput | TTFT P50 | TTFT P99 | TPOT P50 | TPOT P99 | ITL P50 | ITL P99 |
|----------------------------|--------------|-------|------------|----------|----------|----------|----------|---------|---------|
| No                         | 4            | 3.75  | 1479.63    | 120.81   | 380.14   | 45.37    | 83.91    | 35.65   | 147.64  |
| Yes                        | 4            | 3.76  | 1479.35    | 119.48   | 363.16   | 44.02    | 74.72    | 34.87   | 138.54  |
| Δ (improvement)            |              | 0.27% | -0.02%     | 1.11%    | 4.68%    | 3.07%    | 12.30%   | 2.24%   | 6.57%   |
| No                         | 8            | 6.82  | 2696.90    | 175.49   | 411.43   | 72.03    | 127.79   | 48.63   | 164.63  |
| Yes                        | 8            | 6.93  | 2729.55    | 161.74   | 457.21   | 65.76    | 122.38   | 44.60   | 157.29  |
| Δ (improvement)            |              | 1.61% | 1.21%      | 8.50%    | -10.01%  | 9.53%    | 4.42%    | 9.04%   | 4.67%   |
| No                         | inf          | 14.97 | 5805.71    | 9457.86  | 17323.00 | 127.69   | 683.91   | 75.30   | 574.26  |
| Yes                        | inf          | 15.05 | 5959.04    | 9240.45  | 16705.93 | 126.12   | 659.70   | 73.93   | 541.18  |
| Δ (improvement)            |              | 0.53% | 2.64%      | 2.35%    | 3.69%    | 1.24%    | 3.67%    | 1.85%   | 6.11%   |

deepseek-ai/DeepSeek-R1 on random dataset, input-len/output-len: 1k/1k

| Aiter Fused Shared Experts | Request Rate | QPS   | Throughput | TTFT P50 | TTFT P99 | TPOT P50 | TPOT P99 | ITL P50 | ITL P99 |
|----------------------------|--------------|-------|------------|----------|----------|----------|----------|---------|---------|
| No                         | 10           | 1.81  | 3450.08    | 641.56   | 1254.48  | 47.84    | 625.85   | 42.68   | 218.86  |
| Yes                        | 10           | 1.92  | 3719.75    | 576.87   | 1277.88  | 44.18    | 513.07   | 40.21   | 287.79  |
| Δ (improvement)            |              | 6.08% | 7.82%      | 11.21%   | -1.83%   | 8.28%    | 21.98%   | 6.14%   | -23.95% |

Accuracy test

deepseek-ai/DeepSeek-R1

| Aiter Fused Shared Experts | Tasks | Version | Filter           | n-shot | Metric      | Value  | Stderr   |
|----------------------------|-------|---------|------------------|--------|-------------|--------|----------|
| No                         | gsm8k | 3       | flexible-extract | 5      | exact_match | 0.9545 | ± 0.0057 |
|                            |       |         | strict-match     | 5      | exact_match | 0.9538 | ± 0.0058 |
| Yes                        | gsm8k | 3       | flexible-extract | 5      | exact_match | 0.9568 | ± 0.0056 |
|                            |       |         | strict-match     | 5      | exact_match | 0.9560 | ± 0.0056 |

deepseek-ai/DeepSeek-V3

| Aiter Fused Shared Experts | Tasks | Version | Filter           | n-shot | Metric      | Value  | Stderr   |
|----------------------------|-------|---------|------------------|--------|-------------|--------|----------|
| No                         | gsm8k | 3       | flexible-extract | 5      | exact_match | 0.9492 | ± 0.006  |
|                            |       |         | strict-match     | 5      | exact_match | 0.9492 | ± 0.006  |
| Yes                        | gsm8k | 3       | flexible-extract | 5      | exact_match | 0.9507 | ± 0.0060 |
|                            |       |         | strict-match     | 5      | exact_match | 0.9484 | ± 0.0061 |

deepseek-ai/DeepSeek-V2-Lite-Chat

| Aiter Fused Shared Experts | Tasks | Version | Filter           | n-shot | Metric      | Value  | Stderr   |
|----------------------------|-------|---------|------------------|--------|-------------|--------|----------|
| No                         | gsm8k | 3       | flexible-extract | 5      | exact_match | 0.4845 | ± 0.0138 |
|                            |       |         | strict-match     | 5      | exact_match | 0.4776 | ± 0.0138 |
| Yes                        | gsm8k | 3       | flexible-extract | 5      | exact_match | 0.4936 | ± 0.0138 |
|                            |       |         | strict-match     | 5      | exact_match | 0.4890 | ± 0.0138 |

Essential Elements of an Effective PR Description Checklist
  • The purpose of the PR, such as "Fix some issue (link existing issues this PR will resolve)".
  • The test plan, such as providing test command.
  • The test results, such as pasting the results comparison before and after, or e2e results
  • (Optional) The necessary documentation update, such as updating supported_models.md and examples for a new model.
  • (Optional) Release notes update. If your change is user facing, please update the release notes draft in the Google Doc.

valarLip and others added 9 commits August 13, 2025 09:13
Deepseek 085 sharedexperts aiter jun new

Signed-off-by: chenjun <[email protected]>
Signed-off-by: kliuae <[email protected]>
Signed-off-by: kliuae <[email protected]>
Signed-off-by: kliuae <[email protected]>
Signed-off-by: kliuae <[email protected]>
Signed-off-by: kliuae <[email protected]>
Signed-off-by: kliuae <[email protected]>
@mergify mergify bot added deepseek Related to DeepSeek models rocm Related to AMD ROCm labels Sep 2, 2025
@mergify mergify bot commented Sep 2, 2025

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @kliuae.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify bot added the needs-rebase label Sep 2, 2025
@gemini-code-assist gemini-code-assist bot (Contributor) left a comment

Code Review

This pull request introduces an optimization for DeepSeek models on ROCm by fusing shared experts into the AITER FusedMoE kernel. This is controlled by a new environment flag. The changes span environment variable setup, the core FusedMoE layer, quantization layers, and the DeepSeek model implementation to correctly handle the fused logic and weight loading.

The implementation looks solid and the changes are consistent with the goal of the PR. I've found one area for improvement in the initialization logic for the shared expert metadata, which could be made more memory and performance efficient. My detailed feedback is in the comment below.

Comment on lines 75 to 85
if is_EP:
    s_topk_ids_list = [[fake_expertid] *
                       (n_shared_experts + is_EP)] * max_num_tokens
    for i in range(tp_rank, max_num_tokens, tp_size):
        s_topk_ids_list[i] = shared_expert_ids
else:
    s_topk_ids_list = [range(n_routed_experts, fake_expertid)
                       ] * max_num_tokens
s_topk_ids[:] = torch.tensor(s_topk_ids_list,
                             dtype=torch.int32,
                             device='cuda')
gemini-code-assist bot (Contributor), severity: high

The current implementation for initializing s_topk_ids can be inefficient. It constructs a large Python list of lists (s_topk_ids_list) on the host, which is then converted to a PyTorch tensor on the CPU before being moved to the GPU. For a large max_num_tokens, this can lead to significant host memory consumption and slow down the initialization process.

A more efficient approach would be to perform these operations directly on the GPU tensor, avoiding the large intermediate host-side data structures. This can be achieved using tensor broadcasting and slicing.

Suggested change

- if is_EP:
-     s_topk_ids_list = [[fake_expertid] *
-                        (n_shared_experts + is_EP)] * max_num_tokens
-     for i in range(tp_rank, max_num_tokens, tp_size):
-         s_topk_ids_list[i] = shared_expert_ids
- else:
-     s_topk_ids_list = [range(n_routed_experts, fake_expertid)
-                        ] * max_num_tokens
- s_topk_ids[:] = torch.tensor(s_topk_ids_list,
-                              dtype=torch.int32,
-                              device='cuda')
+ if is_EP:
+     s_topk_ids.fill_(fake_expertid)
+     shared_expert_ids_tensor = torch.tensor(shared_expert_ids,
+                                             dtype=torch.int32,
+                                             device='cuda')
+     s_topk_ids[tp_rank::tp_size] = shared_expert_ids_tensor
+ else:
+     s_topk_ids_row = torch.arange(n_routed_experts,
+                                   fake_expertid,
+                                   dtype=torch.int32,
+                                   device='cuda')
+     s_topk_ids.copy_(s_topk_ids_row.expand(max_num_tokens, -1))

@gemini-code-assist gemini-code-assist bot (Contributor) left a comment

Code Review

This pull request introduces a significant performance optimization for DeepSeek models on ROCm by fusing shared experts into the AITER MoE kernel. The implementation is gated behind environment variables and includes comprehensive benchmark and accuracy tests, which is great. However, I've identified two critical issues that need to be addressed. The first is related to the use of a global variable for model-specific metadata, which can lead to race conditions and incorrect behavior when serving multiple models. The second is a bug in the weight loading logic for the fused shared experts, which fails to correctly track loaded parameters and will likely cause errors. Addressing these issues will ensure the stability and correctness of this new feature.

Comment on lines 50 to 94
aiter_topK_meta_data = None


@lru_cache(maxsize=1)
def init_aiter_topK_meta_data(n_routed_experts: int,
                              n_shared_experts: int,
                              top_k: int,
                              tp_rank: int,
                              tp_size: int,
                              shared_experts_score: float = 1.0,
                              max_num_tokens: int = 32768,
                              is_EP: bool = False):
    global aiter_topK_meta_data
    fake_expertid = n_routed_experts + n_shared_experts

    # all layers reuse same buffer
    total_topk_ids = torch.empty(
        (max_num_tokens, top_k + n_shared_experts + is_EP),
        dtype=torch.int32,
        device='cuda')
    ns_topk_ids, s_topk_ids = total_topk_ids.split(
        [top_k, n_shared_experts + is_EP], dim=1)
    shared_expert_ids = [
        n_routed_experts + i for i in range(n_shared_experts + is_EP)
    ]
    if is_EP:
        s_topk_ids_list = [[fake_expertid] *
                           (n_shared_experts + is_EP)] * max_num_tokens
        for i in range(tp_rank, max_num_tokens, tp_size):
            s_topk_ids_list[i] = shared_expert_ids
    else:
        s_topk_ids_list = [range(n_routed_experts, fake_expertid)
                           ] * max_num_tokens
    s_topk_ids[:] = torch.tensor(s_topk_ids_list,
                                 dtype=torch.int32,
                                 device='cuda')

    total_topk_weights = torch.empty(
        (max_num_tokens, top_k + n_shared_experts + is_EP),
        dtype=torch.float32,
        device='cuda')
    ns_topk_weights, s_topk_weights = total_topk_weights.split(
        [top_k, n_shared_experts + is_EP], dim=1)
    s_topk_weights.fill_(shared_experts_score)
    aiter_topK_meta_data = (total_topk_weights, total_topk_ids)
gemini-code-assist bot (Contributor), severity: critical

The use of a global variable aiter_topK_meta_data to store model-specific metadata is problematic. If vLLM serves multiple models with different MoE configurations in the same process, this global variable will be overwritten, leading to incorrect behavior for one of the models. This can cause race conditions and hard-to-debug errors.

The metadata should be managed without using a global variable. A better approach would be:

  1. Modify init_aiter_topK_meta_data to return the metadata tuple instead of modifying a global variable. The @lru_cache decorator should then be used on a function that is pure (has no side effects).
  2. In FusedMoE.__init__, store the returned metadata in an instance attribute, e.g., self.aiter_topK_meta_data.
  3. Pass this instance attribute down through the call chain (forward_cuda -> select_experts -> rocm_aiter_grouped_topk).
  4. rocm_aiter_grouped_topk should then use the passed metadata instead of the global variable.

This change will ensure that each model's metadata is properly encapsulated and avoids race conditions.
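A minimal sketch of that shape (hypothetical; the name and argument list are abbreviated and EP handling is omitted relative to the actual init_aiter_topK_meta_data):

from functools import lru_cache

import torch

@lru_cache(maxsize=1)
def build_aiter_topk_meta_data(n_routed_experts: int,
                               n_shared_experts: int,
                               top_k: int,
                               max_num_tokens: int = 32768,
                               shared_experts_score: float = 1.0):
    # Pure function: build and *return* the shared buffers instead of
    # writing them to a module-level global.
    width = top_k + n_shared_experts
    total_topk_ids = torch.empty((max_num_tokens, width),
                                 dtype=torch.int32, device='cuda')
    total_topk_weights = torch.empty((max_num_tokens, width),
                                     dtype=torch.float32, device='cuda')
    # The last n_shared_experts columns always point at the appended
    # shared-expert slots with a fixed score.
    total_topk_ids[:, top_k:] = torch.arange(
        n_routed_experts, n_routed_experts + n_shared_experts,
        dtype=torch.int32, device='cuda')
    total_topk_weights[:, top_k:].fill_(shared_experts_score)
    return total_topk_weights, total_topk_ids

# In FusedMoE.__init__ (sketch):
#     self.aiter_topK_meta_data = build_aiter_topk_meta_data(...)
# The tuple would then be threaded through select_experts /
# rocm_aiter_grouped_topk instead of being read from a global.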

Comment on lines 929 to 1033
else:
    is_expert_weight = False
    for mapping in expert_params_mapping:
        param_name, weight_name, expert_id, shard_id = mapping
        if weight_name not in name:
            continue

        # Anyway, this is an expert weight and should not be
        # attempted to load as other weights later
        is_expert_weight = True

        # Do not modify `name` since the loop may continue here
        # Instead, create a new variable
        name_mapped = name.replace(weight_name, param_name)

        if is_pp_missing_parameter(name_mapped, self):
            continue

        param = params_dict[name_mapped]
        # We should ask the weight loader to return success or not
        # here since otherwise we may skip experts with other
        # available replicas.
        weight_loader = typing.cast(Callable[..., bool],
                                    param.weight_loader)
        success = weight_loader(param,
                                loaded_weight,
                                name_mapped,
                                shard_id=shard_id,
                                expert_id=expert_id,
                                return_success=True)
        if success:
            name = name_mapped
            break
    else:
        if is_expert_weight:
            # We've checked that this is an expert weight
            # However it's not mapped locally to this rank
            # So we simply skip it
            continue

        # Skip loading extra bias for GPTQ models.
        if name.endswith(".bias") and name not in params_dict:
            continue

        # Remapping the name of FP8 kv-scale.
        name = maybe_remap_kv_scale_name(name, params_dict)
        if name is None:
            continue

        if is_pp_missing_parameter(name, self):
            continue

        param = params_dict[name]
        weight_loader = getattr(param, "weight_loader",
                                default_weight_loader)
        weight_loader(param, loaded_weight)

    # Special handling: when AITER fusion_shared_experts is enabled,
    # checkpoints may provide a single widened shared_experts tensor
    # without explicit expert indices
    # (e.g. ...mlp.shared_experts.gate_proj.weight).
    # For models with multiple shared experts, split that tensor
    # evenly into per-shared-expert slices and load them into
    # appended expert slots mlp.experts.{n_routed_experts + j}.*
    # accordingly.
    num_chunks = 1
    if is_fuse_shared_experts_layer:
        num_chunks = getattr(self.config, "n_shared_experts", 1) or 1
        # Determine split axis based on op type
        # gate/up: ColumnParallel → split along dim 0
        # down: RowParallel → split along dim 1
        split_dim = 1 if "down_proj.weight" in name else 0
        total = loaded_weight.shape[split_dim]
        assert total % num_chunks == 0, (
            f"Shared expert weight dim {total} "
            f"not divisible by num_chunks {num_chunks}")
        chunk_size = total // num_chunks

    for j in range(num_chunks):
        chunk_name = name
        weight_to_load = loaded_weight

        if is_fuse_shared_experts_layer:
            if split_dim == 0:
                weight_to_load = loaded_weight[
                    j * chunk_size:(j + 1) * chunk_size, :]
            else:
                weight_to_load = loaded_weight[
                    :, j * chunk_size:(j + 1) * chunk_size]
            # Synthesize an expert-style name so expert mapping
            # can route it
            chunk_name = name.replace(
                "mlp.shared_experts",
                f"mlp.experts.{self.config.n_routed_experts + j}")

        # Use expert_params_mapping to locate the destination
        # param and delegate to its expert-aware weight_loader
        # with expert_id.
        for mapping in expert_params_mapping:
            param_name, weight_name, expert_id, shard_id = mapping
            if weight_name not in chunk_name:
                continue

            # Anyway, this is an expert weight and should not be
            # attempted to load as other weights later
            is_expert_weight = True

            # Do not modify `name` since the loop may continue here
            # Instead, create a new variable
            name_mapped = chunk_name.replace(weight_name, param_name)

            if is_pp_missing_parameter(name_mapped, self):
                continue

            param = params_dict[name_mapped]
            # We should ask the weight loader to return success or
            # not here since otherwise we may skip experts with
            # other available replicas.
            weight_loader = typing.cast(Callable[..., bool],
                                        param.weight_loader)
            success = weight_loader(param,
                                    weight_to_load,
                                    name_mapped,
                                    shard_id=shard_id,
                                    expert_id=expert_id,
                                    return_success=True)
            if success:
                if not is_fuse_shared_experts_layer:
                    name = name_mapped
                break
        else:
            if is_expert_weight:
                # We've checked that this is an expert weight
                # However it's not mapped locally to this rank
                # So we simply skip it
                continue

            # Skip loading extra bias for GPTQ models.
            if name.endswith(".bias") and name not in params_dict:
                continue

            # Remapping the name of FP8 kv-scale.
            name = maybe_remap_kv_scale_name(name, params_dict)
            if name is None:
                continue

            if is_pp_missing_parameter(name, self):
                continue

            param = params_dict[name]
            weight_loader = getattr(param, "weight_loader",
                                    default_weight_loader)
            weight_loader(param, loaded_weight)
loaded_params.add(name)

gemini-code-assist bot (Contributor), severity: critical

The logic for loading shared expert weights when is_fuse_shared_experts_layer is true does not correctly update the loaded_params set. It adds the original shared expert tensor name (e.g., ...mlp.shared_experts.gate_proj.weight) to loaded_params, but it actually loads the weights into multiple, chunked expert parameters (e.g., mlp.experts.64.*, mlp.experts.65.*, etc.).

As a result, vLLM will not be aware that these chunked expert parameters have been loaded, which will likely lead to "missing keys" errors at the end of the weight loading process or incorrect model behavior if those checks are bypassed.

The fix is to add each name_mapped to loaded_params as it is successfully loaded within the for j in range(num_chunks): loop, and then prevent the original shared expert name from being added to loaded_params at the end of the outer loop over weights.
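A self-contained sketch of the bookkeeping being asked for (the helper name is hypothetical; in the PR this logic would live inline in load_weights):

def record_loaded_names(loaded_params: set[str],
                        name: str,
                        chunk_names_loaded: list[str],
                        is_fuse_shared_experts_layer: bool) -> None:
    # For a fused shared-experts tensor, the data actually lands in the
    # synthesized per-expert parameters (e.g. "...mlp.experts.64.*"), so
    # those names are what should be recorded; the original
    # "...mlp.shared_experts.*" checkpoint name is not a parameter that
    # was loaded and should be skipped.
    if is_fuse_shared_experts_layer:
        loaded_params.update(chunk_names_loaded)
    else:
        loaded_params.add(name)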

Contributor:

@kliuae , is this comment relevant? I'm not familiar enough with the weight loading to know.

kliuae (Contributor, Author):

Currently this isn't checked in vLLM, and yes, down the line it would be good to align loaded_params with the names that are actually loaded. We'll make updates accordingly to reflect this.

@gshtras gshtras (Collaborator) commented Sep 2, 2025

cc @qli88

@mergify mergify bot removed the needs-rebase label Sep 5, 2025
Signed-off-by: kliuae <[email protected]>
@mergify mergify bot commented Sep 10, 2025

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @kliuae.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify bot added the needs-rebase label Sep 10, 2025
Signed-off-by: kliuae <[email protected]>
@mergify mergify bot commented Oct 8, 2025

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @kliuae.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@kliuae kliuae (Contributor, Author) commented Oct 9, 2025

@bnellnm I have made changes addressing your comments, and the branch has been synced with upstream. Can you help with the review? Thanks

@mergify mergify bot commented Oct 10, 2025

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @kliuae.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify bot added the needs-rebase label Oct 10, 2025
Comment on lines -207 to +214
- if config.n_shared_experts is None:
+ if (
+     config.n_shared_experts is None
+     or is_rocm_aiter_fusion_shared_expert_enabled()
+ ):
@bnellnm bnellnm (Contributor) commented Oct 10, 2025

Note that I refactored this recently so that there's only one instance of SharedFusedMoE. It can handle the case when self.shared_experts is None, so it should be simple to keep this mostly the same except for passing n_shared_experts when config.n_shared_experts is true.

Comment on lines 1521 to 1582
# Use expert_params_mapping to locate the destination
# param and delegate to its expert-aware weight_loader
# with expert_id.
for mapping in expert_params_mapping:
    param_name, weight_name, expert_id, shard_id = mapping
    if weight_name not in chunk_name:
        continue

    # Anyway, this is an expert weight and should not be
    # attempted to load as other weights later
    is_expert_weight = True

    # Do not modify `name` since the loop may continue here
    # Instead, create a new variable
    name_mapped = chunk_name.replace(weight_name, param_name)

    if is_pp_missing_parameter(name_mapped, self):
        continue

    param = params_dict[name_mapped]
    # We should ask the weight loader to return success or
    # not here since otherwise we may skip experts with
    # other available replicas.
    weight_loader = typing.cast(
        Callable[..., bool], param.weight_loader
    )
    success = weight_loader(
        param,
        weight_to_load,
        name_mapped,
        shard_id=shard_id,
        expert_id=expert_id,
        return_success=True,
    )
    if success:
        if not is_fuse_shared_experts_layer:
            name = name_mapped
        break
else:
    if is_expert_weight:
        # We've checked that this is an expert weight
        # However it's not mapped locally to this rank
        # So we simply skip it
        continue

    # Skip loading extra bias for GPTQ models.
    if name.endswith(".bias") and name not in params_dict:
        continue

    # Remapping the name of FP8 kv-scale.
    name = maybe_remap_kv_scale_name(name, params_dict)
    if name is None:
        continue

    if is_pp_missing_parameter(name, self):
        continue

    param = params_dict[name]
    weight_loader = getattr(
        param, "weight_loader", default_weight_loader
    )
    weight_loader(param, loaded_weight)
Contributor:

Is this bit basically the same as before? It's a little hard to tell the way the diff shows up.

kliuae (Contributor, Author):

Yes, this is largely the same as before. The change is that, since DeepSeek-V2-Lite provides the weights of its multiple shared experts as single widened tensors, load_weights chunks those tensors by the number of shared experts and wraps their loading in a loop. For the other layers, and when shared experts fusion is not enabled, the number of chunks is set to one and the loading logic remains the same.
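In other words (an illustrative helper, not the PR's exact code), the widened checkpoint tensor is split along the expert dimension before being routed through the expert-aware loader:

import torch

def split_shared_experts_weight(loaded_weight: torch.Tensor,
                                weight_name: str,
                                n_shared_experts: int) -> list[torch.Tensor]:
    # gate_proj/up_proj are column-parallel, so the per-expert slices are
    # stacked along dim 0; down_proj is row-parallel, so along dim 1
    # (this mirrors the split_dim logic shown earlier in the diff).
    split_dim = 1 if "down_proj.weight" in weight_name else 0
    assert loaded_weight.shape[split_dim] % n_shared_experts == 0
    return list(loaded_weight.chunk(n_shared_experts, dim=split_dim))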

    [top_k, n_shared_experts + is_EP], dim=1
)
s_topk_weights.fill_(shared_experts_score)
aiter_topK_meta_data = (total_topk_weights, total_topk_ids)
@bnellnm bnellnm (Contributor) commented Oct 10, 2025

Can you assert aiter_topK_meta_data is None here so that if we run into the situation where the parameters to the init function change, we don't silently overwrite the global with a different value? I assume since the init function is cached it should only ever do this assignment once as long as the input parameters remain unchanged.
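Something along these lines just before the assignment in init_aiter_topK_meta_data (a sketch of the requested guard, keyed to the function shown above):

# Guard against the cached init being re-entered with different parameters
# and silently replacing buffers that other layers already hold.
assert aiter_topK_meta_data is None, (
    "aiter_topK_meta_data is already initialized; refusing to overwrite "
    "it with buffers built from different parameters")
aiter_topK_meta_data = (total_topk_weights, total_topk_ids)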

@bnellnm bnellnm (Contributor) left a comment

LGTM, nice work! I had a few final questions/comments though.

Signed-off-by: kliuae <[email protected]>
Signed-off-by: kliuae <[email protected]>
@mergify mergify bot removed the needs-rebase label Oct 14, 2025
@DarkLight1337 DarkLight1337 added the ready ONLY add when PR is ready to merge/full CI is needed label Oct 15, 2025
@kliuae kliuae requested a review from WoosukKwon as a code owner October 15, 2025 12:30
@DarkLight1337 DarkLight1337 merged commit 1317034 into vllm-project:main Oct 16, 2025
64 checks passed
mandy-li pushed a commit to mandy-li/vllm that referenced this pull request Oct 16, 2025
albertoperdomo2 pushed a commit to albertoperdomo2/vllm that referenced this pull request Oct 16, 2025
…llm-project#24097)

Signed-off-by: chenjun <[email protected]>
Signed-off-by: kliuae <[email protected]>
Co-authored-by: valarLip <[email protected]>
Co-authored-by: TJian <[email protected]>
Signed-off-by: Alberto Perdomo <[email protected]>
lywa1998 pushed a commit to lywa1998/vllm that referenced this pull request Oct 20, 2025

Labels

deepseek (Related to DeepSeek models), ready (ONLY add when PR is ready to merge/full CI is needed), rocm (Related to AMD ROCm)
