
Commit 086f20a

Fix test_barrier hang by using static global rank in ProcessGroupXCCL (#2036)
Fixes #1978

In ProcessGroupNCCL, `globalRank()` returns a static int `globalRank` that is initialized once during ProcessGroup setup, so the global rank assigned to each thread matches the id of its CUDA device. ProcessGroupXCCL did not follow this pattern; it used the thread's assigned (group-local) rank, which is not necessarily the same as the global rank.

The failing test `test_barrier` creates two separate groups of 2 ranks each, and then 4 threads call barrier, all on the same 2-thread group. Because no device id is specified in this barrier call, each thread attempts to "guess" its device index. With the previous code, this guess was always 0 or 1, since the rank of each thread was not actually the global rank. In `barrier`, the guessed id is used to initialize XCCLComm objects and then to run an allreduce on a single-element tensor. As a result, allreduce was called twice on each device, which could hang depending on the execution order of the 4 threads.

With this PR, ProcessGroupXCCL uses the static global rank in the same places as ProcessGroupNCCL, so the initialized XCCLComm objects map to unique devices and allreduce is not called on the same device multiple times.
1 parent f301733 commit 086f20a
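
As a rough illustration of why the fix works, here is a minimal standalone sketch (FakePG is a hypothetical stand-in, not the actual ProcessGroupXCCL): the function-local static in globalRank() is initialized exactly once, on the first call, which happens while the default (world) group is being constructed, so later subgroup instances still return the world rank.

// Minimal sketch, not the real ProcessGroupXCCL: shows how a function-local
// static captures the rank of the first group that calls globalRank().
#include <iostream>

struct FakePG {
  int rank_;
  explicit FakePG(int rank) : rank_(rank) {
    // The real constructor logs globalRank(), so the default (world) group
    // initializes the static before any subgroup is created.
    std::cout << "created group, global rank: " << globalRank() << "\n";
  }
  const int& globalRank() const {
    static int globalRank = rank_;  // initialized once, on the first call
    return globalRank;
  }
};

int main() {
  FakePG world(/*rank=*/3);     // world group on this process: global rank 3
  FakePG subgroup(/*rank=*/1);  // 2-rank subgroup: local rank 1
  std::cout << subgroup.globalRank() << "\n";  // still prints 3, not 1
  return 0;
}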

3 files changed: +10 −4 lines changed


src/xccl/ProcessGroupXCCL.cpp

Lines changed: 8 additions & 3 deletions
@@ -352,6 +352,11 @@ const std::string& ProcessGroupXCCL::logPrefix() const {
   return logPrefix_;
 }

+const int& ProcessGroupXCCL::globalRank() const {
+  static int globalRank = rank_;
+  return globalRank;
+}
+
 ProcessGroupXCCL::ProcessGroupXCCL(
     c10::intrusive_ptr<Store> store,
     int rank,
@@ -379,7 +384,7 @@ ProcessGroupXCCL::ProcessGroupXCCL(
   std::string torch_distributed_debug =
       getCvarString({"TORCH_DISTRIBUTED_DEBUG"}, OFF.c_str());
   LOG(INFO) << logPrefix() << "ProcessGroupXCCL initialization options: "
-            << "size: " << size << ", global rank: " << rank_
+            << "size: " << size << ", global rank: " << globalRank()
             << ", USE_HIGH_PRIORITY_STREAM: "
             << options_->is_high_priority_stream
             << ", PG Name: " << options_->group_name;
@@ -410,7 +415,7 @@ bool ProcessGroupXCCL::dumpDebuggingInfo(bool includeStackTrace /*=true*/) {
   if (traceBufferSize_ > 0) {
     // TODO: dump_xccl_trace
     auto xcclTrace = dump_xccl_trace(true, includeStackTrace, false);
-    DebugInfoWriter& writer = DebugInfoWriter::getWriter(rank_);
+    DebugInfoWriter& writer = DebugInfoWriter::getWriter(globalRank());
     LOG(INFO) << logPrefix() << "ProcessGroupXCCL dumping xccl trace to "
               << writer.getWriterTarget();
     writer.write(xcclTrace);
@@ -2021,7 +2026,7 @@ c10::DeviceIndex ProcessGroupXCCL::guessDeviceId() const {
     return *usedDeviceIdxs_.begin();
   }
   int devIdx =
-      static_cast<int16_t>(rank_ % at::detail::getXPUHooks().getNumGPUs());
+      static_cast<int16_t>(globalRank() % at::detail::getXPUHooks().getNumGPUs());
   LOG(WARNING)
       << logPrefix()
       << c10::str(
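
For concreteness, the guessDeviceId() change can be sanity-checked with a small standalone sketch; the thread and device counts below are assumptions based on the commit description (4 threads, two 2-rank subgroups, 4 devices), not values taken from the actual test.

// Sketch of the device-index guess before and after the fix. The rank layout
// is an assumption based on the commit description, not the real test code.
#include <cstdio>

int main() {
  const int numGPUs = 4;                     // assumed XPU device count
  const int globalRanks[4] = {0, 1, 2, 3};   // one global rank per thread
  const int groupRanks[4] = {0, 1, 0, 1};    // rank within each 2-rank subgroup
  for (int i = 0; i < 4; ++i) {
    int oldGuess = groupRanks[i] % numGPUs;  // before: 0,1,0,1 -> two threads per device
    int newGuess = globalRanks[i] % numGPUs; // after:  0,1,2,3 -> one thread per device
    std::printf("thread %d: old devIdx=%d, new devIdx=%d\n", i, oldGuess, newGuess);
  }
  return 0;
}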

src/xccl/ProcessGroupXCCL.hpp

Lines changed: 1 addition & 0 deletions
@@ -423,6 +423,7 @@ class TORCH_API ProcessGroupXCCL : public Backend {
   c10::DeviceIndex guessDeviceId() const;

   const std::vector<uint64_t>& groupRanks() const;
+  const int& globalRank() const;
   void setEnqueuedPgStatus(c10::intrusive_ptr<ProcessGroupXCCL::WorkXCCL> work);
   bool dumpDebuggingInfo(bool includeStackTrace = true);

src/xccl/ProcessGroupXCCLMonitor.cpp

Lines changed: 1 addition & 1 deletion
@@ -39,7 +39,7 @@ void HeartbeatMonitorXCCL::runLoop() {
   // We only need to dump once per PG, so we use local_id_ == 0 for the first PG
   if (pg_->local_id_ == 0) {
     // DumpPipe is one per-trainer process
-    dumpPipe.emplace(pg_->getRank());
+    dumpPipe.emplace(pg_->globalRank());
     while (true) {
       std::unique_lock<std::mutex> lock(monitorMutex_);
       if (monitorWakeUpCV_.wait_for(
