@@ -38,10 +38,10 @@ class McastGPUBuffer
3838 // ! \param device The CUDA device for buffer allocation.
3939 // ! \param mnNvlink Flag indicating if multi-node NVLink is used.
4040 McastGPUBuffer (
41- size_t bufSize, uint32_t groupSize, uint32_t groupRank, uint32_t splitColor, at::Device device , bool mnNvlink)
42- : mMcastDeviceMemory (bufSize, groupSize, groupRank, splitColor, device.index() , mnNvlink)
41+ size_t bufSize, uint32_t groupSize, uint32_t groupRank, uint32_t splitColor, uint32_t deviceIdx , bool mnNvlink)
42+ : mMcastDeviceMemory (bufSize, groupSize, groupRank, splitColor, deviceIdx , mnNvlink)
4343 , mBufSize (bufSize)
44- , mLocalDevice (device )
44+ , mLocalDevice (at::Device(at::DeviceType::CUDA, deviceIdx) )
4545 {
4646 }
4747
@@ -51,7 +51,7 @@ class McastGPUBuffer
5151 // ! \param dtype The data type of the tensor elements.
5252 // ! \param storageOffset The offset in elements from the start of the buffer.
5353 // ! \return An ATen tensor wrapping the unicast buffer section.
// NOTE(review): the \param docs for `rank` and `sizes` fall outside this diff hunk and are not visible here.
// Signature change in this commit: `sizes` is now a std::vector<long int> taken BY VALUE (was
// c10::IntArrayRef) and dtype is spelled torch::ScalarType -- presumably to ease Python binding; confirm.
// NOTE(review): the by-value vector is copied on every call; `std::vector<int64_t> const&` would avoid that.
54- at::Tensor getUCBuffer (uint32_t rank, c10::IntArrayRef sizes, c10 ::ScalarType dtype, int64_t storageOffset)
54+ at::Tensor getUCBuffer (uint32_t rank, std::vector< long int > sizes, torch ::ScalarType dtype, int64_t storageOffset)
5555 {
// Total element count of the requested view; presumably consumed by validation in the elided lines below.
5656 size_t const numel = std::accumulate (sizes.begin (), sizes.end (), 1UL , std::multiplies<size_t >());
5757 size_t const elementSize = c10::elementSize (dtype);
// NOTE(review): original lines 58-60 are context elided by the diff (likely a bounds check against
// mBufSize using numel * elementSize -- TODO confirm against the full file).
@@ -61,15 +61,18 @@ class McastGPUBuffer
// Base unicast pointer for `rank`, advanced by storageOffset elements (offset scaled to bytes).
6161 auto * dataPtr = static_cast <uint8_t *>(mMcastDeviceMemory .getUnicastPtr (rank)) + storageOffset * elementSize;
6262
6363 auto options = at::TensorOptions ().dtype (dtype).device (mLocalDevice );
// Only change in this hunk: an explicit c10::IntArrayRef view of `sizes` plus multi-line chaining;
// the tensor construction itself is unchanged.
64- return at::for_blob (dataPtr, sizes).options (options).target_device (mLocalDevice ).make_tensor ();
64+ return at::for_blob (dataPtr, c10::IntArrayRef (sizes))
65+ .options (options)
66+ .target_device (mLocalDevice )
67+ .make_tensor ();
6568 }
6669
6770 // ! \brief Returns a PyTorch tensor view of the multicast buffer portion.
6871 // ! \param sizes The desired shape (dimensions) of the tensor.
6972 // ! \param dtype The data type of the tensor elements.
7073 // ! \param storageOffset The offset in elements from the start of the buffer.
7174 // ! \return An ATen tensor wrapping the multicast buffer section.
// Signature change mirrors getUCBuffer in this same commit: `sizes` becomes a by-value
// std::vector<long int> (was c10::IntArrayRef), dtype spelled torch::ScalarType.
// NOTE(review): by-value vector is copied on every call; `std::vector<int64_t> const&` would avoid that.
72- at::Tensor getMCBuffer (c10::IntArrayRef sizes, c10 ::ScalarType dtype, int64_t storageOffset)
75+ at::Tensor getMCBuffer (std::vector< long int > sizes, torch ::ScalarType dtype, int64_t storageOffset)
7376 {
// Element count of the requested view; presumably consumed by validation in the elided lines below.
7477 size_t const numel = std::accumulate (sizes.begin (), sizes.end (), 1UL , std::multiplies<size_t >());
7578 size_t const elementSize = c10::elementSize (dtype);
// NOTE(review): original lines 76-78 are context elided by the diff (likely a bounds check against
// mBufSize -- TODO confirm against the full file).
@@ -79,7 +82,10 @@ class McastGPUBuffer
// Multicast base pointer advanced by storageOffset elements (offset scaled to bytes).
7982 auto * dataPtr = static_cast <uint8_t *>(mMcastDeviceMemory .getMulticastPtr ()) + storageOffset * elementSize;
8083
8184 auto options = at::TensorOptions ().dtype (dtype).device (mLocalDevice );
// Only change in this hunk: explicit c10::IntArrayRef view of `sizes` plus multi-line chaining;
// the tensor construction itself is unchanged.
82- return at::for_blob (dataPtr, sizes).options (options).target_device (mLocalDevice ).make_tensor ();
85+ return at::for_blob (dataPtr, c10::IntArrayRef (sizes))
86+ .options (options)
87+ .target_device (mLocalDevice )
88+ .make_tensor ();
8389 }
8490
8591private:
0 commit comments