Skip to content

Commit da2e89a

Browse files
authored
Fix GPU detection in PerStoreFeatureNode (#17593)
1 parent d392d25 commit da2e89a

File tree

1 file changed

+2
-1
lines changed

1 file changed

+2
-1
lines changed

src/meta_schedule/feature_extractor/per_store_feature.cc

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1392,7 +1392,8 @@ class PerStoreFeatureNode : public FeatureExtractorNode {
13921392

13931393
Array<runtime::NDArray> ExtractFrom(const TuneContext& tune_context,
13941394
const Array<MeasureCandidate>& candidates) {
1395-
bool is_gpu = tune_context->target.value()->kind->name == "cuda";
1395+
auto& target_keys = tune_context->target.value()->keys;
1396+
bool is_gpu = std::find(target_keys.begin(), target_keys.end(), "gpu") != target_keys.end();
13961397
std::vector<runtime::NDArray> results;
13971398
results.resize(candidates.size());
13981399
std::unique_ptr<tir::group6::Feature> feature_group6 = nullptr;

0 commit comments

Comments
 (0)