Commit 4569d2b
[Refactor][Runtime] Always specify device in allocator interface
Prior to this PR, each allocator was closely tied to a single device. To allow the same allocator to be reused across different devices of the same kind when needed, we lift the device up into the allocator `Alloc` interface.
1 parent 95ec38b · commit 4569d2b
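In call-site terms, the change looks like this (a before/after sketch; `allocator`, `dev`, `nbytes`, `alignment`, and `type_hint` are illustrative names, the signatures come from this diff):

    // Before this PR: the allocator was constructed for one device,
    // so callers never said where the buffer should live.
    Buffer buf = allocator->Alloc(nbytes, alignment, type_hint);

    // After this PR: the target device is an explicit argument, so one
    // allocator can serve several devices of the same kind.
    Buffer buf = allocator->Alloc(dev, nbytes, alignment, type_hint);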

6 files changed: 37 additions, 35 deletions

include/tvm/runtime/memory/memory_manager.h

5 additions & 6 deletions

@@ -71,19 +71,22 @@ class Allocator {
   /*! \brief Return the allocator type. */
   inline AllocatorType type() const { return type_; }
   /*! \brief Allocate a buffer given a size, alignment and type.
+   * \param dev The device where the array is allocated.
    * \param nbytes The size of the buffer.
    * \param alignment The alignment of the buffer.
    * \param type_hint A type hint to the allocator.
    * \return A sized allocation in the form of a buffer.
    */
-  TVM_DLL virtual Buffer Alloc(size_t nbytes, size_t alignment, DLDataType type_hint) = 0;
+  TVM_DLL virtual Buffer Alloc(Device dev, size_t nbytes, size_t alignment,
+                               DLDataType type_hint) = 0;
   /*! \brief Allocate a buffer given a shape and type.
+   * \param dev The device where the array is allocated.
    * \param shape The shape of the tensor.
    * \param type_hint A type hint to the allocator.
    * \param mem_scope A memory scope of the buffer.
    * \return A sized allocation in the form of a buffer.
    */
-  TVM_DLL virtual Buffer Alloc(ShapeTuple shape, DLDataType type_hint,
+  TVM_DLL virtual Buffer Alloc(Device dev, ShapeTuple shape, DLDataType type_hint,
                                const std::string& mem_scope = "") = 0;
   /*! \brief Free a buffer allocated by the allocator.
    * \param buffer The buffer to free.
@@ -96,10 +99,6 @@ class Allocator {
    */
   TVM_DLL virtual size_t UsedMemory() const = 0;

- protected:
-  TVM_DLL virtual Buffer Alloc(Device dev, ShapeTuple shape, DLDataType type_hint,
-                               const std::string& mem_scope);
-
  private:
   AllocatorType type_;
 };
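Both overloads are now pure virtual and device-taking, so every implementation must accept the device per call. A minimal conforming sketch (a hypothetical PassthroughAllocator that simply forwards to DeviceAPI; it assumes only the interface shown in this diff):

    class PassthroughAllocator final : public Allocator {
     public:
      PassthroughAllocator() : Allocator(kNaive) {}

      Buffer Alloc(Device dev, size_t nbytes, size_t alignment,
                   DLDataType type_hint) final {
        Buffer buf;
        buf.device = dev;  // the device arrives per call, not at construction
        buf.size = nbytes;
        buf.alloc_type = kNaive;
        buf.data = DeviceAPI::Get(dev)->AllocDataSpace(dev, nbytes, alignment, type_hint);
        return buf;
      }

      Buffer Alloc(Device dev, ShapeTuple shape, DLDataType type_hint,
                   const std::string& mem_scope) final {
        // Delegate to the base helper (defined in memory_manager.cc), which
        // computes size/alignment for the "global" scope and dispatches back
        // into the overload above.
        return Allocator::Alloc(dev, shape, type_hint, mem_scope);
      }

      void Free(const Buffer& buffer) final {
        // Resolve the DeviceAPI from the buffer itself, as NaiveAllocator now does.
        DeviceAPI::Get(buffer.device)->FreeDataSpace(buffer.device, buffer.data);
      }

      size_t UsedMemory() const final { return 0; }  // accounting omitted in this sketch
    };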

src/runtime/memory/memory_manager.cc

5 additions & 5 deletions

@@ -138,12 +138,12 @@ Allocator* MemoryManager::GetOrCreateAllocator(Device dev, AllocatorType type) {
   switch (type) {
     case kNaive: {
       VLOG(1) << "New naive allocator for " << dev;
-      alloc.reset(new NaiveAllocator(dev));
+      alloc.reset(new NaiveAllocator());
       break;
     }
     case kPooled: {
       VLOG(1) << "New pooled allocator for " << dev;
-      alloc.reset(new PooledAllocator(dev));
+      alloc.reset(new PooledAllocator());
       break;
     }
     default:
@@ -194,9 +194,9 @@ NDArray Allocator::Empty(ShapeTuple shape, DLDataType dtype, DLDevice dev,
   size_t alignment = GetDataAlignment(container->dl_tensor);
   Buffer* buffer = new Buffer;
   if (!mem_scope.defined() || mem_scope.value().empty() || mem_scope.value() == "global") {
-    *buffer = this->Alloc(size, alignment, dtype);
+    *buffer = this->Alloc(dev, size, alignment, dtype);
   } else {
-    *buffer = this->Alloc(shape, dtype, mem_scope.value());
+    *buffer = this->Alloc(dev, shape, dtype, mem_scope.value());
   }
   container->manager_ctx = reinterpret_cast<void*>(buffer);
   container->dl_tensor.data = buffer->data;
@@ -210,7 +210,7 @@ Buffer Allocator::Alloc(Device dev, ShapeTuple shape, DLDataType type_hint,
     NDArray::Container container(nullptr, shape, type_hint, dev);
     size_t size = DeviceAPI::Get(dev)->GetDataSize(container.dl_tensor);
     size_t alignment = GetDataAlignment(container.dl_tensor);
-    return Alloc(size, alignment, type_hint);
+    return Alloc(dev, size, alignment, type_hint);
   }
   LOG(FATAL) << "Allocator cannot allocate data space with "
              << "specified memory scope: " << mem_scope;

src/runtime/memory/naive_allocator.h

11 additions & 11 deletions

@@ -35,46 +35,47 @@ namespace memory {

 class NaiveAllocator final : public Allocator {
  public:
-  explicit NaiveAllocator(Device dev) : Allocator(kNaive), used_memory_(0), device_(dev) {}
+  explicit NaiveAllocator() : Allocator(kNaive), used_memory_(0) {}

-  Buffer Alloc(size_t nbytes, size_t alignment, DLDataType type_hint) override {
+  Buffer Alloc(Device dev, size_t nbytes, size_t alignment, DLDataType type_hint) final {
     Buffer buf;
-    buf.device = device_;
+    buf.device = dev;
     buf.size = nbytes;
     buf.alloc_type = kNaive;
-    buf.data = DeviceAPI::Get(device_)->AllocDataSpace(device_, nbytes, alignment, type_hint);
+    buf.data = DeviceAPI::Get(dev)->AllocDataSpace(dev, nbytes, alignment, type_hint);
     used_memory_.fetch_add(nbytes, std::memory_order_relaxed);
     DLOG(INFO) << "allocate " << nbytes << " B, used memory " << used_memory_ << " B";
     return buf;
   }

-  Buffer Alloc(ShapeTuple shape, DLDataType type_hint, const std::string& mem_scope) override {
+  Buffer Alloc(Device dev, ShapeTuple shape, DLDataType type_hint,
+               const std::string& mem_scope) final {
     Buffer buf;
     size_t nbytes = 1;
     for (int i = 0; i < static_cast<int>(shape.size()); ++i) {
       nbytes *= static_cast<size_t>(shape[i]);
     }
     nbytes *= (type_hint.bits * type_hint.lanes + 7) / 8;
-    buf.device = device_;
+    buf.device = dev;
     if (mem_scope.empty() || mem_scope == "global") {
-      auto tmp_buf = Allocator::Alloc(device_, shape, type_hint, mem_scope);
+      auto tmp_buf = Allocator::Alloc(dev, shape, type_hint, mem_scope);
       buf.size = tmp_buf.size;
       buf.data = tmp_buf.data;
       buf.alloc_type = kNaive;
       return buf;
     }

     buf.size = nbytes;
-    buf.data = DeviceAPI::Get(device_)->AllocDataSpace(device_, shape.size(), shape.data(),
-                                                       type_hint, String(mem_scope));
+    buf.data = DeviceAPI::Get(dev)->AllocDataSpace(dev, shape.size(), shape.data(), type_hint,
+                                                   String(mem_scope));
     used_memory_.fetch_add(nbytes, std::memory_order_relaxed);
     DLOG(INFO) << "allocate " << nbytes << " B, used memory " << used_memory_ << " B";
     buf.alloc_type = kNaive;
     return buf;
   }

   void Free(const Buffer& buffer) override {
-    DeviceAPI::Get(device_)->FreeDataSpace(buffer.device, buffer.data);
+    DeviceAPI::Get(buffer.device)->FreeDataSpace(buffer.device, buffer.data);
     used_memory_.fetch_sub(buffer.size, std::memory_order_relaxed);
     DLOG(INFO) << "free " << buffer.size << " B, used memory " << used_memory_ << " B";
   }
@@ -83,7 +84,6 @@ class NaiveAllocator final : public Allocator {

  private:
   std::atomic<size_t> used_memory_;
-  Device device_;
 };

 }  // namespace memory
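This rewrite is where the commit's motivation becomes visible: a single NaiveAllocator can now service two devices of the same kind, and Free recovers the right DeviceAPI from buffer.device rather than from allocator state. A hypothetical sketch (device setup invented for illustration):

    NaiveAllocator naive;  // no longer bound to one device at construction
    Device gpu0{kDLCUDA, 0}, gpu1{kDLCUDA, 1};
    DLDataType f32{kDLFloat, 32, 1};
    Buffer a = naive.Alloc(gpu0, 1024, 64, f32);
    Buffer b = naive.Alloc(gpu1, 1024, 64, f32);
    naive.Free(a);  // dispatches through a.device, i.e. gpu0
    naive.Free(b);  // dispatches through b.device, i.e. gpu1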

src/runtime/memory/pooled_allocator.h

9 additions & 9 deletions

@@ -40,12 +40,12 @@ class PooledAllocator final : public Allocator {
  public:
   static constexpr size_t kDefaultPageSize = 4096;

-  explicit PooledAllocator(Device dev, size_t page_size = kDefaultPageSize)
-      : Allocator(kPooled), page_size_(page_size), used_memory_(0), device_(dev) {}
+  explicit PooledAllocator(size_t page_size = kDefaultPageSize)
+      : Allocator(kPooled), page_size_(page_size), used_memory_(0) {}

   ~PooledAllocator() { ReleaseAll(); }

-  Buffer Alloc(size_t nbytes, size_t alignment, DLDataType type_hint) override {
+  virtual Buffer Alloc(Device dev, size_t nbytes, size_t alignment, DLDataType type_hint) override {
     std::lock_guard<std::recursive_mutex> lock(mu_);
     size_t size = ((nbytes + page_size_ - 1) / page_size_) * page_size_;
     auto&& it = memory_pool_.find(size);
@@ -56,26 +56,27 @@ class PooledAllocator final : public Allocator {
       return ret;
     }
     Buffer buf;
-    buf.device = device_;
+    buf.device = dev;
     buf.size = size;
     buf.alloc_type = kPooled;
     try {
-      buf.data = DeviceAPI::Get(device_)->AllocDataSpace(device_, size, alignment, type_hint);
+      buf.data = DeviceAPI::Get(dev)->AllocDataSpace(dev, size, alignment, type_hint);
     } catch (InternalError& err) {
       LOG(WARNING) << "PooledAllocator got InternalError during allocation: " << err.message();
       LOG(WARNING) << "Trying to release all unused memory and reallocate...";
       ReleaseAll();
-      buf.data = DeviceAPI::Get(device_)->AllocDataSpace(device_, size, alignment, type_hint);
+      buf.data = DeviceAPI::Get(dev)->AllocDataSpace(dev, size, alignment, type_hint);
     }

     used_memory_.fetch_add(size, std::memory_order_relaxed);
     VLOG(1) << "allocate " << size << " B, used memory " << used_memory_ << " B";
     return buf;
   }

-  Buffer Alloc(ShapeTuple shape, DLDataType type_hint, const std::string& mem_scope) override {
+  virtual Buffer Alloc(Device dev, ShapeTuple shape, DLDataType type_hint,
+                       const std::string& mem_scope) override {
     if (mem_scope.empty() || mem_scope == "global") {
-      return Allocator::Alloc(device_, shape, type_hint, mem_scope);
+      return Allocator::Alloc(dev, shape, type_hint, mem_scope);
     }
     LOG(FATAL) << "This alloc should be implemented";
     return {};
@@ -113,7 +114,6 @@ class PooledAllocator final : public Allocator {
   std::atomic<size_t> used_memory_;
   std::unordered_map<size_t, std::vector<Buffer>> memory_pool_;
   std::recursive_mutex mu_;
-  Device device_;
 };

 }  // namespace memory
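As a reminder of how the pool is keyed (unchanged by this PR), Alloc rounds every request up to whole pages before consulting memory_pool_. A worked example with the default 4096-byte page:

    // nbytes = 5000, page_size_ = 4096
    // size = ((5000 + 4096 - 1) / 4096) * 4096
    //      = (9095 / 4096) * 4096   // integer division yields 2
    //      = 8192
    // memory_pool_ maps this rounded size to freed buffers, so a later
    // request that also rounds to 8192 B can be served without another
    // DeviceAPI::AllocDataSpace call.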

src/runtime/relax_vm/builtin.cc

2 additions & 1 deletion

@@ -347,7 +347,8 @@ Storage VMAllocStorage(void* ctx_ptr, ShapeTuple buffer_shape, Index device_index,
   auto* alloc = vm->allocators[device_index];
   ICHECK(alloc) << "Did you forget to init the VirtualMachine with devices?";

-  storage_obj->buffer = alloc->Alloc(buffer_shape, dtype_hint, mem_scope);
+  storage_obj->buffer =
+      alloc->Alloc(vm->devices[device_index], buffer_shape, dtype_hint, mem_scope);
   Storage storage(storage_obj);
   return storage;
 }
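The Relax VM resolves the device from the same index it already uses to pick the allocator; a comment-only sketch of the invariant being relied on (not code from this PR):

    // vm->devices[i] and vm->allocators[i] are parallel per-device tables
    // populated when the VirtualMachine is initialized, so a single
    // device_index consistently selects both the Device argument and the
    // Allocator that services it.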

src/runtime/vm/vm.cc

5 additions & 3 deletions

@@ -823,6 +823,7 @@ void VirtualMachine::RunLoop(const std::vector<Index>& output_tensor_reg_indices

   auto storage_obj = SimpleObjAllocator().make_object<StorageObj>();
   Allocator* allocator = GetAllocator(instr.alloc_storage.device_index);
+  Device device = devices_[instr.alloc_storage.device_index];
   ICHECK(allocator) << "Did you forget to init the VirtualMachine with devices?";

   if (instr.alloc_storage.ndim > 0) {
@@ -844,15 +845,16 @@ void VirtualMachine::RunLoop(const std::vector<Index>& output_tensor_reg_indices
       shape_.resize(instr.alloc_storage.ndim);
       shape_.assign(instr.alloc_storage.shape,
                     instr.alloc_storage.shape + instr.alloc_storage.ndim);
-      storage_obj->buffer =
-          allocator->Alloc(ShapeTuple(shape_), instr.alloc_storage.dtype_hint, mem_scope);
+      storage_obj->buffer = allocator->Alloc(device, ShapeTuple(shape_),
+                                             instr.alloc_storage.dtype_hint, mem_scope);
     } else {
       auto size = LoadScalarInt(instr.alloc_storage.allocation_size);
       auto alignment = instr.alloc_storage.alignment;
       VLOG(2) << "allocating with allocation_size=" << size << ", alignment=" << alignment
               << ", dtype_hint=" << DLDataType2String(instr.alloc_storage.dtype_hint)
               << ", device_index=" << instr.alloc_storage.device_index;
-      storage_obj->buffer = allocator->Alloc(size, alignment, instr.alloc_storage.dtype_hint);
+      storage_obj->buffer =
+          allocator->Alloc(device, size, alignment, instr.alloc_storage.dtype_hint);
     }
     Storage storage(storage_obj);
     WriteRegister(instr.dst, storage);
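Note that RunLoop now exercises both overloads of the new interface, one per branch (an annotated map of the calls above, using names from this diff):

    // ndim > 0  (scoped, shape-aware): Alloc(device, ShapeTuple(shape_), dtype_hint, mem_scope)
    // ndim == 0 (raw bytes):           Alloc(device, size, alignment, dtype_hint)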
