@@ -78,28 +78,6 @@ CudaDevice deserialize_device(std::string device_info);
7878
7979CudaDevice get_device_info (int64_t gpu_id, nvinfer1::DeviceType device_type);
8080
81- class DeviceList {
82- using DeviceMap = std::unordered_map<int , CudaDevice>;
83- DeviceMap device_list;
84- DeviceList () {}
85-
86- public:
87- static DeviceList& instance () {
88- static DeviceList obj;
89- return obj;
90- }
91-
92- void insert (int device_id, CudaDevice cuda_device) {
93- device_list[device_id] = cuda_device;
94- }
95- CudaDevice find (int device_id) {
96- return device_list[device_id];
97- }
98- DeviceMap get_devices () {
99- return device_list;
100- }
101- };
102-
10381struct TRTEngine : torch::CustomClassHolder {
10482 // Each engine needs its own runtime object
10583 nvinfer1::IRuntime* rt;
@@ -125,6 +103,49 @@ struct TRTEngine : torch::CustomClassHolder {
125103
126104std::vector<at::Tensor> execute_engine (std::vector<at::Tensor> inputs, c10::intrusive_ptr<TRTEngine> compiled_engine);
127105
106+ class DeviceList {
107+ using DeviceMap = std::unordered_map<int , CudaDevice>;
108+ DeviceMap device_list;
109+
110+ public:
111+ // Scans and updates the list of available CUDA devices
112+ DeviceList (void ) {
113+ int num_devices = 0 ;
114+ auto status = cudaGetDeviceCount (&num_devices);
115+ TRTORCH_ASSERT ((status == cudaSuccess), " Unable to read CUDA capable devices. Return status: " << status);
116+ cudaDeviceProp device_prop;
117+ for (int i = 0 ; i < num_devices; i++) {
118+ TRTORCH_CHECK (
119+ (cudaGetDeviceProperties (&device_prop, i) == cudaSuccess),
120+ " Unable to read CUDA Device Properies for device id: " << i);
121+ std::string device_name (device_prop.name );
122+ CudaDevice device = {
123+ i, device_prop.major , device_prop.minor , nvinfer1::DeviceType::kGPU , device_name.size (), device_name};
124+ device_list[i] = device;
125+ }
126+ }
127+
128+ public:
129+ static DeviceList& instance () {
130+ static DeviceList obj;
131+ return obj;
132+ }
133+
134+ void insert (int device_id, CudaDevice cuda_device) {
135+ device_list[device_id] = cuda_device;
136+ }
137+ CudaDevice find (int device_id) {
138+ return device_list[device_id];
139+ }
140+ DeviceMap get_devices () {
141+ return device_list;
142+ }
143+ };
144+
145+ namespace {
146+ static DeviceList cuda_device_list;
147+ }
148+
128149} // namespace runtime
129150} // namespace core
130151} // namespace trtorch
0 commit comments