@@ -40,12 +40,12 @@ class PooledAllocator final : public Allocator {
4040 public:
4141 static constexpr size_t kDefaultPageSize = 4096 ;
4242
43- explicit PooledAllocator (Device dev, size_t page_size = kDefaultPageSize )
44- : Allocator(kPooled ), page_size_(page_size), used_memory_(0 ), device_(dev) {}
43+ explicit PooledAllocator (size_t page_size = kDefaultPageSize )
44+ : Allocator(kPooled ), page_size_(page_size), used_memory_(0 ) {}
4545
4646 ~PooledAllocator () { ReleaseAll (); }
4747
48- Buffer Alloc (size_t nbytes, size_t alignment, DLDataType type_hint) override {
48+ virtual Buffer Alloc (Device dev, size_t nbytes, size_t alignment, DLDataType type_hint) override {
4949 std::lock_guard<std::recursive_mutex> lock (mu_);
5050 size_t size = ((nbytes + page_size_ - 1 ) / page_size_) * page_size_;
5151 auto && it = memory_pool_.find (size);
@@ -56,26 +56,27 @@ class PooledAllocator final : public Allocator {
5656 return ret;
5757 }
5858 Buffer buf;
59- buf.device = device_ ;
59+ buf.device = dev ;
6060 buf.size = size;
6161 buf.alloc_type = kPooled ;
6262 try {
63- buf.data = DeviceAPI::Get (device_ )->AllocDataSpace (device_ , size, alignment, type_hint);
63+ buf.data = DeviceAPI::Get (dev )->AllocDataSpace (dev , size, alignment, type_hint);
6464 } catch (InternalError& err) {
6565 LOG (WARNING) << " PooledAllocator got InternalError during allocation: " << err.message ();
6666 LOG (WARNING) << " Trying to release all unused memory and reallocate..." ;
6767 ReleaseAll ();
68- buf.data = DeviceAPI::Get (device_ )->AllocDataSpace (device_ , size, alignment, type_hint);
68+ buf.data = DeviceAPI::Get (dev )->AllocDataSpace (dev , size, alignment, type_hint);
6969 }
7070
7171 used_memory_.fetch_add (size, std::memory_order_relaxed);
7272 VLOG (1 ) << " allocate " << size << " B, used memory " << used_memory_ << " B" ;
7373 return buf;
7474 }
7575
76- Buffer Alloc (ShapeTuple shape, DLDataType type_hint, const std::string& mem_scope) override {
76+ virtual Buffer Alloc (Device dev, ShapeTuple shape, DLDataType type_hint,
77+ const std::string& mem_scope) override {
7778 if (mem_scope.empty () || mem_scope == " global" ) {
78- return Allocator::Alloc (device_ , shape, type_hint, mem_scope);
79+ return Allocator::Alloc (dev , shape, type_hint, mem_scope);
7980 }
8081 LOG (FATAL) << " This alloc should be implemented" ;
8182 return {};
@@ -113,7 +114,6 @@ class PooledAllocator final : public Allocator {
113114 std::atomic<size_t > used_memory_;
114115 std::unordered_map<size_t , std::vector<Buffer>> memory_pool_;
115116 std::recursive_mutex mu_;
116- Device device_;
117117};
118118
119119} // namespace memory
0 commit comments