From c13ae6ff67c297f339150355ae6bb9ceedf3bb25 Mon Sep 17 00:00:00 2001 From: "Dvoretckii, Mikhail" Date: Fri, 22 Nov 2024 06:01:33 -0800 Subject: [PATCH 1/2] Enable managed quantization parameters --- tools/common_lib/src/gemm.h | 58 +++++++++++++++++++++++++++++++------ 1 file changed, 49 insertions(+), 9 deletions(-) diff --git a/tools/common_lib/src/gemm.h b/tools/common_lib/src/gemm.h index 99d9d17..ebd1515 100644 --- a/tools/common_lib/src/gemm.h +++ b/tools/common_lib/src/gemm.h @@ -499,7 +499,7 @@ class QuantGemm : public DirectMlBaseNode QuantGemm(const DML_TENSOR_DATA_TYPE data_type, const dml::TensorPolicy& tensor_policy_ab, const dml::TensorPolicy& tensor_policy_c, const TensorShape& shape_a, const TensorShape& shape_b, const TensorShape& shape_c, const TensorShape& shape_out, const TensorShape& shape_scale, const TensorShape& shape_zeropoint, bool a_quantized, bool b_quantized, bool c_quantized, const uint32_t block_size, const DML_TENSOR_DATA_TYPE quantized_data_type, bool has_zero_point, - bool a_transposed, bool b_managed, bool b_transposed, bool c_managed, float alpha, float beta, + bool a_transposed, bool b_managed, bool b_transposed, bool bs_managed, bool bz_managed, bool c_managed, float alpha, float beta, bool allow_fp16_computations, const ActivationSettings& activation_settings, IDMLDevice* dml_device, ID3D12Device* d3d12_device, bool allow_descriptors_volatile, bool disable_mc = false) : DirectMlBaseNode(dml_device, d3d12_device) @@ -546,7 +546,7 @@ class QuantGemm : public DirectMlBaseNode dimensions_scale.push_back(shape_scale.c); dimensions_scale.push_back(shape_scale.h); dimensions_scale.push_back(shape_scale.w); - dml::TensorDesc desc_scale = { data_type, DML_TENSOR_FLAG_NONE, dimensions_scale }; + dml::TensorDesc desc_scale = { data_type, bs_managed ? 
DML_TENSOR_FLAG_OWNED_BY_DML : DML_TENSOR_FLAG_NONE, dimensions_scale }; input_1_scale_ = dml::InputTensor(graph_, index, desc_scale); dml::TensorDesc::Dimensions dimensions_zero_point; @@ -554,7 +554,7 @@ class QuantGemm : public DirectMlBaseNode dimensions_zero_point.push_back(shape_zeropoint.c); dimensions_zero_point.push_back(shape_zeropoint.h); dimensions_zero_point.push_back(shape_zeropoint.w); - dml::TensorDesc desc_zeropoint = { quantized_data_type, DML_TENSOR_FLAG_NONE, dimensions_zero_point }; + dml::TensorDesc desc_zeropoint = { quantized_data_type, bz_managed ? DML_TENSOR_FLAG_OWNED_BY_DML : DML_TENSOR_FLAG_NONE, dimensions_zero_point }; input_1_zeropoint_ = dml::InputTensor(graph_, index + 1, desc_zeropoint); std::vector tensor_b_quantization_params(2); @@ -657,8 +657,27 @@ class QuantGemm : public DirectMlBaseNode } } - input_bindings.push_back({ DML_BINDING_TYPE_BUFFER, &input_scale_buffer_binding }); - input_bindings.push_back({ DML_BINDING_TYPE_BUFFER, &input_zeropoint_buffer_binding }); + if (input_1_scale_.GetOutputDesc().flags == DML_TENSOR_FLAG_OWNED_BY_DML) + { + input_scale_buffer_binding = { nullptr, 0, 0 }; + input_bindings.push_back({ DML_BINDING_TYPE_NONE, &input_scale_buffer_binding }); + } + else + { + input_scale_buffer_binding = { resource_scale, 0, resource_scale->GetDesc().Width }; + input_bindings.push_back({ DML_BINDING_TYPE_BUFFER, &input_scale_buffer_binding }); + } + + if (input_1_zeropoint_.GetOutputDesc().flags == DML_TENSOR_FLAG_OWNED_BY_DML) + { + input_zeropoint_buffer_binding = { nullptr, 0, 0 }; + input_bindings.push_back({ DML_BINDING_TYPE_NONE, &input_zeropoint_buffer_binding }); + } + else + { + input_zeropoint_buffer_binding = { resource_zeropoint, 0, resource_zeropoint->GetDesc().Width }; + input_bindings.push_back({ DML_BINDING_TYPE_BUFFER, &input_zeropoint_buffer_binding }); + } std::vector output_bindings; @@ -713,8 +732,25 @@ class QuantGemm : public DirectMlBaseNode input_binds.push_back({ nullptr, 0, 0 }); } - 
input_binds.push_back({ nullptr, 0, 0 }); // tensor scale - input_binds.push_back({ nullptr, 0, 0 }); // tensor zero point + // tensor scale + if (input_1_scale_.GetOutputDesc().flags == DML_TENSOR_FLAG_OWNED_BY_DML) + { + input_binds.push_back({ resource_scale, 0, resource_scale->GetDesc().Width }); + } + else + { + input_binds.push_back({ nullptr, 0, 0 }); + } + + // tensor zero point + if (input_1_zeropoint_.GetOutputDesc().flags == DML_TENSOR_FLAG_OWNED_BY_DML) + { + input_binds.push_back({ resource_zeropoint, 0, resource_zeropoint->GetDesc().Width }); + } + else + { + input_binds.push_back({ nullptr, 0, 0 }); + } DML_BUFFER_ARRAY_BINDING input_bind{}; input_bind.BindingCount = static_cast(input_binds.size()); @@ -769,6 +805,8 @@ class QuantGemmBaseDispatcher : public NodeDispatcher float beta = 1.0f; bool b_managed = false; + bool bs_managed = false; + bool bz_managed = false; bool c_managed = false; bool a_transposed = false; bool b_transposed = false; @@ -800,6 +838,8 @@ class QuantGemmBaseDispatcher : public NodeDispatcher opts->add_flag("--a_transposed", params.a_transposed)->default_val(false); opts->add_flag("--b_transposed", params.b_transposed)->default_val(false); opts->add_flag("--b_managed", params.b_managed)->default_val(false); + opts->add_flag("--bs_managed", params.bs_managed)->default_val(false); + opts->add_flag("--bz_managed", params.bz_managed)->default_val(false); opts->add_flag("--c_managed", params.c_managed)->default_val(false); opts->add_flag("--allow_fp16_computations", params.allow_fp16_computations); @@ -1081,7 +1121,7 @@ class QuantGemmBaseDispatcher : public NodeDispatcher gpu_op::QuantGemm gemm_ref(to_dml_data_type(params_.dt), to_dml_tensor_policy(params_.layout), to_dml_tensor_policy(params_.layout_c), params_.shape_a, params_.shape_b, params_.shape_c, get_shape_output(), get_shape_quant_param(), get_shape_quant_param(), params_.a_quantized, params_.b_quantized, params_.c_quantized, params_.block_size, 
to_dml_data_type(params_.quant_dt), params_.has_zero_point, - params_.a_transposed, false /*params_.b_managed*/, params_.b_transposed, false /*params_.c_managed*/, params_.alpha, params_.beta, + params_.a_transposed, false /*params_.b_managed*/, params_.b_transposed, false /*params_.bs_managed*/, false /*params_.bz_managed*/, false /*params_.c_managed*/, params_.alpha, params_.beta, params_.allow_fp16_computations, params_.activation, dml_device_, d3d12_device_, false, true); // bind descriptor heap @@ -1202,7 +1242,7 @@ class QuantGemmDmlDispatcher : public QuantGemmBaseDispatcher : QuantGemmBaseDispatcher(std::move(params), d3d12_device, dml_device, dml_cmd_recorder, cmd_list) , quantgemm_(to_dml_data_type(params_.dt), to_dml_tensor_policy(params_.layout), to_dml_tensor_policy(params_.layout_c), params_.shape_a, params_.shape_b, params_.shape_c, get_shape_output(), get_shape_quant_param(), get_shape_quant_param(), params_.a_quantized, params_.b_quantized, params_.c_quantized, params_.block_size, to_dml_data_type(params_.quant_dt), params_.has_zero_point, - params_.a_transposed, params_.b_managed, params_.b_transposed, params_.c_managed, + params_.a_transposed, params_.b_managed, params_.b_transposed, params_.bs_managed, params_.bz_managed, params_.c_managed, params_.alpha, params_.beta, params_.allow_fp16_computations, params_.activation, dml_device, d3d12_device, allow_descriptors_volatile, false) { From 472d92f3b44fe05a533bfc125002722e40a437d6 Mon Sep 17 00:00:00 2001 From: "Dvoretckii, Mikhail" Date: Fri, 29 Nov 2024 07:59:12 -0800 Subject: [PATCH 2/2] Implement a UMDD3D12 dispatcher for quantized GEMM --- .../drivers.gpu.compute.ai.dnnlpluginnext | 2 +- tools/common_lib/src/dnnl_utils.h | 1 + tools/common_lib/src/gemm.h | 661 +++++++++++++++++- tools/common_lib/src/iumd_d3d12_impl.h | 5 + tools/common_lib/src/node_dispatcher.h | 1 + tools/cross_runner/src/main.cpp | 14 +- 6 files changed, 673 insertions(+), 11 deletions(-) diff --git 
a/thirdparty/drivers.gpu.compute.ai.dnnlpluginnext b/thirdparty/drivers.gpu.compute.ai.dnnlpluginnext index e8a6323..1e86221 160000 --- a/thirdparty/drivers.gpu.compute.ai.dnnlpluginnext +++ b/thirdparty/drivers.gpu.compute.ai.dnnlpluginnext @@ -1 +1 @@ -Subproject commit e8a6323a8156aeebffda260e3ef86ad40a1e86e9 +Subproject commit 1e862214ddc383b5d141a433bdc01de5ef312e58 diff --git a/tools/common_lib/src/dnnl_utils.h b/tools/common_lib/src/dnnl_utils.h index 880c799..9d2415f 100644 --- a/tools/common_lib/src/dnnl_utils.h +++ b/tools/common_lib/src/dnnl_utils.h @@ -69,6 +69,7 @@ inline dnnl::memory::data_type to_dnnl_data_type(const DataType l) { case DataType::eFp32: return dnnl::memory::data_type::f32; case DataType::eFp16: return dnnl::memory::data_type::f16; + case DataType::eUint4: return dnnl::memory::data_type::u4; default: return dnnl::memory::data_type::undef; } diff --git a/tools/common_lib/src/gemm.h b/tools/common_lib/src/gemm.h index ebd1515..a18be28 100644 --- a/tools/common_lib/src/gemm.h +++ b/tools/common_lib/src/gemm.h @@ -534,8 +534,8 @@ class QuantGemm : public DirectMlBaseNode dml::TensorDesc::Dimensions dimensions_1; dimensions_1.push_back(shape_b.n); dimensions_1.push_back(shape_b.c); - dimensions_1.push_back(shape_b.h); - dimensions_1.push_back(shape_b.w); + dimensions_1.push_back(b_transposed ? shape_b.w : shape_b.h); + dimensions_1.push_back(b_transposed ? shape_b.h : shape_b.w); dml::TensorDesc desc_input_1 = { quantized_data_type, b_managed ? 
DML_TENSOR_FLAG_OWNED_BY_DML : DML_TENSOR_FLAG_NONE, dimensions_1 }; input_1_ = dml::InputTensor(graph_, 1, desc_input_1); @@ -965,11 +965,15 @@ class QuantGemmBaseDispatcher : public NodeDispatcher { input_buffer_scale_ = create_buffer(d3d12_device_, tensor_input_scale_bytes_width, D3D12_HEAP_TYPE_DEFAULT, D3D12_RESOURCE_STATE_COPY_DEST, D3D12_RESOURCE_FLAG_ALLOW_UNORDERED_ACCESS); + reorder_buffer_scale_ = create_buffer(d3d12_device_, tensor_input_scale_bytes_width, + D3D12_HEAP_TYPE_DEFAULT, D3D12_RESOURCE_STATE_COPY_DEST, D3D12_RESOURCE_FLAG_ALLOW_UNORDERED_ACCESS); } if (tensor_input_zeropoint_bytes_width > 0) { input_buffer_zeropoint_ = create_buffer(d3d12_device_, tensor_input_zeropoint_bytes_width, D3D12_HEAP_TYPE_DEFAULT, D3D12_RESOURCE_STATE_COPY_DEST, D3D12_RESOURCE_FLAG_ALLOW_UNORDERED_ACCESS); + reorder_buffer_zeropoint_ = create_buffer(d3d12_device_, tensor_input_zeropoint_bytes_width, + D3D12_HEAP_TYPE_DEFAULT, D3D12_RESOURCE_STATE_COPY_DEST, D3D12_RESOURCE_FLAG_ALLOW_UNORDERED_ACCESS); } output_buffer_ = create_buffer(d3d12_device_, tensor_out_bytes_width, @@ -1181,8 +1185,8 @@ class QuantGemmBaseDispatcher : public NodeDispatcher TensorShape ret{}; ret.n = get_batch(); ret.c = get_channels(); - ret.h = params_.b_transposed ? params_.shape_b.h : params_.shape_b.w; - ret.w = params_.b_transposed ? params_.shape_b.w / params_.block_size : params_.shape_b.h / params_.block_size; + ret.h = params_.b_transposed ? params_.shape_b.w : params_.shape_b.h; + ret.w = params_.b_transposed ? params_.shape_b.h / params_.block_size : params_.shape_b.w / params_.block_size; return ret; } @@ -1198,17 +1202,17 @@ class QuantGemmBaseDispatcher : public NodeDispatcher std::uint32_t get_M() const { - return params_.a_transposed ? params_.shape_a.w : params_.shape_a.h; + return params_.shape_a.h; } std::uint32_t get_K() const { - return params_.a_transposed ? 
params_.shape_a.h : params_.shape_a.w; + return params_.shape_a.w; } std::uint32_t get_N() const { - return params_.b_transposed ? params_.shape_b.h : params_.shape_b.w; + return params_.shape_b.w; } protected: @@ -1229,7 +1233,9 @@ class QuantGemmBaseDispatcher : public NodeDispatcher ComPtr input_buffer_c_; ComPtr input_buffer_scale_; + ComPtr reorder_buffer_scale_; ComPtr input_buffer_zeropoint_; + ComPtr reorder_buffer_zeropoint_; ComPtr output_buffer_; ComPtr upload_buffer_; @@ -1274,6 +1280,645 @@ class QuantGemmDmlDispatcher : public QuantGemmBaseDispatcher gpu_op::QuantGemm quantgemm_; }; +class QuantGemmUmdD3d12Dispatcher : public QuantGemmBaseDispatcher +{ +public: + struct qgemm_umdd3d12_params_t + { + std::uint32_t verbose_mode = 0; // 0: disabled; 1: execution; 2: creation and execution + bool verbose_dump_to_file = false; + bool cache_blob = false; + + inline static void add_cli_options(CLI::App* opts, qgemm_umdd3d12_params_t& params) + { + opts->add_option("--verbose_mode", params.verbose_mode)->default_val(0); + opts->add_flag("--verbose_file", params.verbose_dump_to_file)->default_val(false); + opts->add_flag("--cache_blob", params.cache_blob, "Use to test persistent cache blob.")->default_val(false); + } + }; +public: + QuantGemmUmdD3d12Dispatcher(create_params_t&& params, const qgemm_umdd3d12_params_t& umdd3d12_param, IntelExtension& intc_ext, ID3D12Device* d3d12_device, IDMLDevice* dml_device, IDMLCommandRecorder* dml_cmd_recorder, ID3D12GraphicsCommandList* cmd_list) + : QuantGemmBaseDispatcher(std::move(params), d3d12_device, dml_device, dml_cmd_recorder, cmd_list) + , device_(d3d12_device, intc_ext.get_info()) + , dnnl_engine_(dnnl::iumd_interop::make_engine(&device_)) + { + using namespace dnnl_utils; + + //dnnl::set_verbose(umdd3d12_param.verbose_mode); + + //input_a_memory_desc_ = to_dnnl_mem_desc(params_.a_transposed ? 
TensorShape{ params_.shape_a.n, params_.shape_a.c, params_.shape_a.w, params_.shape_a.h } : params_.shape_a, params_.layout, params_.dt); + input_a_memory_desc_ = to_dnnl_mem_desc(params_.shape_a, params_.layout, params_.dt); + if (params_.a_transposed) + { + input_a_memory_desc_ = convert_to_ncwh_format(input_a_memory_desc_); + } + // const auto input_b_memory_desc = to_dnnl_mem_desc(params_.b_transposed ? TensorShape{ params_.shape_b.n, params_.shape_b.c, params_.shape_b.w, params_.shape_b.h } : params_.shape_b, params_.b_managed ? DataLayout::eWeightsLayoutStart : params_.layout, params_.dt); + if (params_.b_managed) + { + input_b_memory_desc_ = to_dnnl_mem_desc(params_.shape_b, DataLayout::eWeightsLayoutStart, params_.quant_dt); + } + else + { + input_b_memory_desc_ = to_dnnl_mem_desc(params_.shape_b, params_.layout, params_.quant_dt); + if (params_.b_transposed) + { + input_b_memory_desc_ = convert_to_ncwh_format(input_b_memory_desc_); + } + } + TensorShape shape_bsz = get_shape_quant_param(); + auto bs_layout = params_.bs_managed ? DataLayout::eWeightsLayoutStart : params_.layout; + input_scales_memory_desc_ = to_dnnl_mem_desc(shape_bsz, bs_layout, params_.dt); + auto bz_layout = params_.bz_managed ? 
DataLayout::eWeightsLayoutStart : params_.layout; + // FIXME: allocate extra memory and do reorders for these params, since they're expected to be row-major only + if (params_.has_zero_point) + { + input_zeropoints_memory_desc_.emplace(to_dnnl_mem_desc(shape_bsz, bz_layout, params_.quant_dt)); + } + output_memory_desc_ = to_dnnl_mem_desc(get_shape_output(), params_.layout, params_.dt); + + if (has_c_tensor()) + { + input_c_memory_desc_.emplace(to_dnnl_mem_desc(params_.shape_c, params_.layout, params_.dt)); + } + + const dnnl::primitive_attr attr = [this]() + { + // create a post-op with relu + dnnl::post_ops ops; + dnnl::primitive_attr attr; + + // sanity check + assert(attr.get_scratchpad_mode() == dnnl::scratchpad_mode::library); + // set scratchpad mode to user provided + attr.set_scratchpad_mode(dnnl::scratchpad_mode::user); + + const bool force_fp32_accu = params_.dt == DataType::eFp16 && !params_.allow_fp16_computations; + if (force_fp32_accu) + { + attr.set_accumulation_mode(dnnl::accumulation_mode::strict); + } + + dnnl::memory::dims scale_dims(2); + scale_dims[0] = params_.block_size; + scale_dims[1] = 1; + attr.set_scales(DNNL_ARG_WEIGHTS, 0xc, scale_dims, dnnl_utils::to_dnnl_data_type(params_.dt)); + if (params_.has_zero_point) + { + attr.set_zero_points(DNNL_ARG_WEIGHTS, 0xc, scale_dims, dnnl_utils::to_dnnl_data_type(params_.quant_dt)); + } + + attr.set_fpmath_mode(dnnl::fpmath_mode::strict, true); + + // alpha + if (params_.alpha != 1.0f || has_beta_scaling_factors()) + { + ops.append_eltwise(dnnl::algorithm::eltwise_linear, has_beta_scaling_factors() ? 
params_.alpha / params_.beta : params_.alpha, 0.0f); + } + + if (has_c_tensor()) + { + ops.append_binary(dnnl::algorithm::binary_add, input_c_memory_desc_.value()); + } + + if (has_beta_scaling_factors()) + { + ops.append_eltwise(dnnl::algorithm::eltwise_linear, params_.beta, 0.0f); + } + + if (params_.activation.type != ActivationType::eUnknown) + { + ops.append_eltwise(to_dnnl_activation_type(params_.activation.type), params_.activation.alpha, params_.activation.beta); + attr.set_post_ops(ops); + } + + attr.set_post_ops(ops); + return attr; + }(); + + dnnl::matmul::primitive_desc matmul_desc(dnnl_engine_, + input_a_memory_desc_, + input_b_memory_desc_, + output_memory_desc_, + attr + ); + std::cout << "dnnl-umd kernel impl: " << matmul_desc.impl_info_str() << std::endl; + + input_b_memory_desc_ = matmul_desc.query_md(dnnl::query::weights_md, 0); + const auto persistent_resource_size = [&]() + { + std::size_t ret = 0ull; + + if (params_.b_managed) + { + ret += input_b_memory_desc_.get_size(); + } + if (params_.bs_managed) + { + ret += input_scales_memory_desc_.get_size(); + } + if (params_.bz_managed) + { + ret += input_zeropoints_memory_desc_->get_size(); + } + + if (params_.c_managed) + { + assert(!"params_.c_managed is nt not tested option, most likely bugs hidden somewhere!"); + assert(input_c_memory_desc_.has_value()); + ret += input_c_memory_desc_->get_size(); + } + return ret; + }(); + + if (persistent_resource_size != 0) + { + persistent_buffer_ = create_buffer(d3d12_device, persistent_resource_size, + D3D12_HEAP_TYPE_DEFAULT, D3D12_RESOURCE_STATE_COMMON, D3D12_RESOURCE_FLAG_ALLOW_UNORDERED_ACCESS); + } + + assert(matmul_desc.query_s64(dnnl::query::memory_consumption_s64) == 0); // we provide scratchpad, so sanity check that primitive does not require any "hidden" memory + scratchpad_memory_desc_.emplace(matmul_desc.query_md(dnnl::query::scratchpad_md)); + const auto temporary_resoruce_size = [&]() + { + return scratchpad_memory_desc_->get_size(); + }(); + 
if (temporary_resoruce_size != 0) + { + const auto heap_props = CD3DX12_HEAP_PROPERTIES(D3D12_HEAP_TYPE_DEFAULT); + temporary_buffer_ = create_buffer(d3d12_device, temporary_resoruce_size, + D3D12_HEAP_TYPE_DEFAULT, D3D12_RESOURCE_STATE_COMMON, D3D12_RESOURCE_FLAG_ALLOW_UNORDERED_ACCESS); + } + + // create qgemm primitive + + + std::ifstream in_key_file("onednn_persistent_cache.key", std::ofstream::in | std::ifstream::binary); + std::ifstream in_value_file("onednn_persistent_cache.value", std::ofstream::in | std::ifstream::binary); + std::vector buffer_key; + std::vector buffer_value; + const auto conv_blob_key = matmul_desc.get_cache_blob_id(); + if (umdd3d12_param.cache_blob && in_key_file.is_open()) + { + buffer_key = std::vector(std::istreambuf_iterator(in_key_file), {}); + } + if (buffer_key == conv_blob_key) + { + std::cout << "Found persistent cache blob files. Using them to create gemm primitive!" << std::endl; + assert(in_value_file.is_open()); // Proper file with key value exists, but file with cache blob (value) does not exist. Delete file with key and rerun application. + buffer_value = std::vector(std::istreambuf_iterator(in_value_file), {}); + } + const auto t0 = std::chrono::high_resolution_clock::now(); + if (buffer_value.empty()) + { + gemm_ = dnnl::matmul(matmul_desc); + } + else + { + gemm_ = dnnl::matmul(matmul_desc, buffer_value); + } + const auto t1 = std::chrono::high_resolution_clock::now(); + const auto diff = std::chrono::duration_cast(t1 - t0); + std::cout << "Primitive create time: " << diff << std::endl; + + if (umdd3d12_param.cache_blob && buffer_value.empty()) + { + std::cout << "Storing persistent cache blob files for." 
<< std::endl; + auto store_binary_data_to_file = [](const auto& file_name, const auto& data) + { + std::ofstream out_file(file_name, std::ofstream::out | std::ofstream::binary); + std::copy(data.begin(), data.end(), std::ostream_iterator(out_file)); + out_file.close(); + }; + const auto cache_blob_id = matmul_desc.get_cache_blob_id(); + store_binary_data_to_file("onednn_persistent_cache.key", cache_blob_id); + + const auto cache_blob = gemm_.get_cache_blob(); + store_binary_data_to_file("onednn_persistent_cache.value", cache_blob); + } + + if (params_.b_managed) + { + auto input_b_memory_desc_physical_ = to_dnnl_mem_desc(params_.shape_b, params_.layout, params_.quant_dt); + if (params_.b_transposed) + input_b_memory_desc_physical_ = convert_to_ncwh_format(input_b_memory_desc_physical_); + dnnl::reorder::primitive_desc reorder_desc(dnnl_engine_, input_b_memory_desc_physical_, dnnl_engine_, input_b_memory_desc_); + reorder_input_b_ = dnnl::reorder(reorder_desc); + } + + if (params_.b_transposed) + { + dnnl::reorder::primitive_desc reorder_desc_bs(dnnl_engine_, to_dnnl_mem_desc(shape_bsz, params_.layout, params_.dt), dnnl_engine_, convert_to_ncwh_format(input_scales_memory_desc_)); + reorder_input_scales_ = dnnl::reorder(reorder_desc_bs); + dnnl::reorder::primitive_desc reorder_desc_bz(dnnl_engine_, to_dnnl_mem_desc(shape_bsz, params_.layout, params_.quant_dt), dnnl_engine_, convert_to_ncwh_format(input_zeropoints_memory_desc_.value())); + reorder_input_zeropoints_ = dnnl::reorder(reorder_desc_bz); + } + + if (params_.c_managed) + { + assert(input_c_memory_desc_.has_value()); + // its just a copy + dnnl::reorder::primitive_desc reorder_desc(dnnl_engine_, input_c_memory_desc_.value(), dnnl_engine_, input_c_memory_desc_.value()); + reorder_input_c_ = dnnl::reorder(reorder_desc); + } + } + + std::uint32_t get_total_descriptor_count()override + { + // allocate enough descriptor upfront + return 50u; + } + + void initialize(ID3D12GraphicsCommandList* cmd_list, 
D3D12_CPU_DESCRIPTOR_HANDLE cpu_handle, D3D12_GPU_DESCRIPTOR_HANDLE gpu_handle) override + { + ID3D12GraphicsCommandList4* cmd_list4 = nullptr; + throw_if_failed(cmd_list->QueryInterface(&cmd_list4), "cant cast d3d12 device to ID3D12Device5"); + iumd::custom_metacommand::UmdD3d12CommandList cmd(cmd_list4); + dnnl::stream stream = dnnl::iumd_interop::make_stream(dnnl_engine_, &cmd); + + base_cpu_handle_ = CD3DX12_CPU_DESCRIPTOR_HANDLE{ cpu_handle }; + base_gpu_handle_ = CD3DX12_GPU_DESCRIPTOR_HANDLE{ gpu_handle }; + + if (!reorder_input_b_ && !reorder_input_c_ && !reorder_input_scales_ && !reorder_input_zeropoints_ /* && !copy_alpha_shader_*/) + { + // early exit, as no reordering needed or copy shader for alpha value not needed + return; + } + + std::vector> resources_list; + resources_list.reserve(4); + if (persistent_buffer_) + { + resources_list.push_back({ DescType::eUav, persistent_buffer_.Get() }); + } + if (reorder_input_b_) + { + resources_list.push_back({ DescType::eUav, input_buffer_b_.Get() }); + } + if (reorder_input_scales_) + { + resources_list.push_back({ DescType::eUav, input_buffer_scale_.Get() }); + resources_list.push_back({ DescType::eUav, reorder_buffer_scale_.Get() }); + } + if (reorder_input_zeropoints_) + { + resources_list.push_back({ DescType::eUav, input_buffer_zeropoint_.Get() }); + resources_list.push_back({ DescType::eUav, reorder_buffer_zeropoint_.Get() }); + } + if (reorder_input_c_) + { + resources_list.push_back({ DescType::eUav, input_buffer_c_.Get() }); + } + const auto gpu_handles = create_resource_views_and_handles(d3d12_device_, resources_list, base_cpu_handle_, base_gpu_handle_); + + std::size_t rsc_idx = 0; + auto umd_persistent_mem = persistent_buffer_ ? 
iumd::custom_metacommand::UmdD3d12Memory(gpu_handles[rsc_idx++]) : iumd::custom_metacommand::UmdD3d12Memory{}; + std::size_t persistent_mem_offset = 0; + + // weights reorder + if (reorder_input_b_) + { + auto umd_input_mem = iumd::custom_metacommand::UmdD3d12Memory(gpu_handles[rsc_idx++]); + + auto input_memory_desc_ = dnnl_utils::to_dnnl_mem_desc(params_.shape_b, params_.layout, params_.quant_dt); + if (params_.b_transposed) + { + input_memory_desc_ = dnnl_utils::convert_to_ncwh_format(input_memory_desc_); + } + dnnl::memory input_memory = create_dnnl_memory(input_memory_desc_, umd_input_mem); + dnnl::memory reorder_memory = create_dnnl_memory(input_b_memory_desc_, umd_persistent_mem, persistent_mem_offset); + + std::unordered_map args; + args.insert({ DNNL_ARG_SRC, input_memory }); + args.insert({ DNNL_ARG_DST, reorder_memory }); + + reorder_input_b_.execute(stream, args); + persistent_mem_offset += input_b_memory_desc_.get_size(); + } + + // scales reorder + if (reorder_input_scales_) + { + auto umd_input_mem = iumd::custom_metacommand::UmdD3d12Memory(gpu_handles[rsc_idx++]); + auto umd_reorder_mem = iumd::custom_metacommand::UmdD3d12Memory(gpu_handles[rsc_idx++]); + + TensorShape shape_bsz = get_shape_quant_param(); + auto input_memory_desc_ = dnnl_utils::to_dnnl_mem_desc(shape_bsz, params_.layout, params_.dt); + /*if (params_.b_transposed) + { + input_memory_desc_ = dnnl_utils::convert_to_ncwh_format(input_memory_desc_); + }*/ + dnnl::memory input_memory = create_dnnl_memory(input_memory_desc_, umd_input_mem); + dnnl::memory reorder_memory = create_dnnl_memory(dnnl_utils::convert_to_ncwh_format(input_scales_memory_desc_), umd_reorder_mem); + + std::unordered_map args; + args.insert({ DNNL_ARG_SRC, input_memory }); + args.insert({ DNNL_ARG_DST, reorder_memory }); + + reorder_input_scales_.execute(stream, args); + } + + // zeropoints reorder + if (reorder_input_zeropoints_) + { + auto umd_input_mem = 
iumd::custom_metacommand::UmdD3d12Memory(gpu_handles[rsc_idx++]); + auto umd_reorder_mem = iumd::custom_metacommand::UmdD3d12Memory(gpu_handles[rsc_idx++]); + + TensorShape shape_bsz = get_shape_quant_param(); + auto input_memory_desc_ = dnnl_utils::to_dnnl_mem_desc(shape_bsz, params_.layout, params_.quant_dt); + /*if (params_.b_transposed) + { + input_memory_desc_ = dnnl_utils::convert_to_ncwh_format(input_memory_desc_); + }*/ + dnnl::memory input_memory = create_dnnl_memory(input_memory_desc_, umd_input_mem); + dnnl::memory reorder_memory = create_dnnl_memory(dnnl_utils::convert_to_ncwh_format(input_zeropoints_memory_desc_.value()), umd_reorder_mem); + + std::unordered_map args; + args.insert({ DNNL_ARG_SRC, input_memory }); + args.insert({ DNNL_ARG_DST, reorder_memory }); + + reorder_input_zeropoints_.execute(stream, args); + } + + // weights reorder + if (reorder_input_c_) + { + assert(input_c_memory_desc_.has_value()); + auto umd_input_mem = iumd::custom_metacommand::UmdD3d12Memory(gpu_handles[rsc_idx++]); + + dnnl::memory input_memory = create_dnnl_memory(input_c_memory_desc_.value(), umd_input_mem); + dnnl::memory reorder_memory = create_dnnl_memory(input_c_memory_desc_.value(), umd_persistent_mem, persistent_mem_offset); + + std::unordered_map args; + args.insert({ DNNL_ARG_SRC, input_memory }); + args.insert({ DNNL_ARG_DST, reorder_memory }); + + reorder_input_c_.execute(stream, args); + } + } + + void execute(ID3D12GraphicsCommandList* cmd_list) override + { + std::vector> resources_list; + resources_list.reserve(get_total_descriptor_count()); + resources_list.push_back({ DescType::eUav, input_buffer_a_.Get() }); + if (persistent_buffer_) + { + resources_list.push_back({ DescType::eUav, persistent_buffer_.Get() }); + } + if (!reorder_input_b_) + { + resources_list.push_back({ DescType::eUav, input_buffer_b_.Get() }); + } + if (!reorder_input_scales_) + { + resources_list.push_back({ DescType::eUav, input_buffer_scale_.Get() }); + } + else + { + 
resources_list.push_back({ DescType::eUav, reorder_buffer_scale_.Get() }); + } + if (params_.has_zero_point && !reorder_input_zeropoints_) + { + resources_list.push_back({ DescType::eUav, input_buffer_zeropoint_.Get() }); + } + else if (params_.has_zero_point && reorder_input_zeropoints_) + { + resources_list.push_back({ DescType::eUav, reorder_buffer_zeropoint_.Get() }); + } + + resources_list.push_back({ DescType::eUav, output_buffer_.Get() }); + if (input_buffer_c_ && !reorder_input_c_) + { + resources_list.push_back({ DescType::eUav, input_buffer_c_.Get() }); + } + if (temporary_buffer_) + { + resources_list.push_back({ DescType::eUav, temporary_buffer_.Get() }); + } + const auto gpu_handles = create_resource_views_and_handles(d3d12_device_, resources_list, base_cpu_handle_, base_gpu_handle_); + + std::size_t res_idx = 0; + auto umd_input_a_memory_ = iumd::custom_metacommand::UmdD3d12Memory(gpu_handles[res_idx++]); + auto umd_persitent_memory = persistent_buffer_ ? iumd::custom_metacommand::UmdD3d12Memory(gpu_handles[res_idx++]) : iumd::custom_metacommand::UmdD3d12Memory(); + auto umd_input_b_memory_ = reorder_input_b_ ? umd_persitent_memory : iumd::custom_metacommand::UmdD3d12Memory(gpu_handles[res_idx++]); + auto umd_input_scales_memory_ = /*reorder_input_scales_ ? umd_persitent_memory :*/ iumd::custom_metacommand::UmdD3d12Memory(gpu_handles[res_idx++]); + auto umd_input_zeropoints_memory_ = params_.has_zero_point ? (/*reorder_input_zeropoints_ ? 
umd_persitent_memory :*/ iumd::custom_metacommand::UmdD3d12Memory(gpu_handles[res_idx++])) : iumd::custom_metacommand::UmdD3d12Memory(); + auto umd_output_memory_ = iumd::custom_metacommand::UmdD3d12Memory(gpu_handles[res_idx++]); + auto umd_input_c_memory_ = [&]() + { + if (input_buffer_c_ && !reorder_input_c_) + { + return iumd::custom_metacommand::UmdD3d12Memory(gpu_handles[res_idx++]); + } + else if (reorder_input_c_) + { + return umd_persitent_memory; + } + return iumd::custom_metacommand::UmdD3d12Memory{}; + }(); + + //auto umd_alpha_beta_memory_ = umd_persitent_memory; + auto umd_scratchpad_memory_ = temporary_buffer_ ? iumd::custom_metacommand::UmdD3d12Memory(gpu_handles[res_idx++]) : iumd::custom_metacommand::UmdD3d12Memory(); + + // stream is created in execute(...), because in MetaCommand cmd list object can be different from execute-to-execute + ID3D12GraphicsCommandList4* cmd_list4 = nullptr; + throw_if_failed(cmd_list->QueryInterface(&cmd_list4), "cant cast d3d12 device to ID3D12Device5"); + iumd::custom_metacommand::UmdD3d12CommandList cmd(cmd_list4); + dnnl::stream stream = dnnl::iumd_interop::make_stream(dnnl_engine_, &cmd); + + // memory resources are created in execute(...), because in MetaCommand these objects can be different from execute-to-execute + dnnl::memory input_memory = create_dnnl_memory(input_a_memory_desc_, umd_input_a_memory_); + + std::size_t persistent_mem_offset = 0; + + dnnl::memory input_b_memory = [&](std::size_t& persistent_mem_offset) + { + std::size_t offset = persistent_mem_offset; + if (reorder_input_b_) + { + persistent_mem_offset += input_b_memory_desc_.get_size(); + } + return create_dnnl_memory(input_b_memory_desc_, umd_input_b_memory_, reorder_input_b_ ? 
offset : 0ull); + }(persistent_mem_offset); + + dnnl::memory input_scales_memory = [&](std::size_t& persistent_mem_offset) + { + std::size_t offset = 0; // persistent_mem_offset; + if (false) //reorder_input_scales_) + { + persistent_mem_offset += input_scales_memory_desc_.get_size(); + } + return create_dnnl_memory(dnnl_utils::convert_to_ncwh_format(input_scales_memory_desc_), umd_input_scales_memory_, reorder_input_scales_ ? offset : 0ull); + }(persistent_mem_offset); + + dnnl::memory input_zeropoints_memory = [&](std::size_t& persistent_mem_offset) + { + if (params_.has_zero_point) + { + std::size_t offset = 0; // persistent_mem_offset; + if (false) // reorder_input_zeropoints_) + { + persistent_mem_offset += input_zeropoints_memory_desc_->get_size(); + } + return create_dnnl_memory(dnnl_utils::convert_to_ncwh_format(input_zeropoints_memory_desc_.value()), umd_input_zeropoints_memory_, reorder_input_zeropoints_ ? offset : 0ull); + } + return dnnl::memory{}; + }(persistent_mem_offset); + + dnnl::memory input_c_memory = [&](std::size_t& persistent_mem_offset) + { + if (has_c_tensor()) + { + std::size_t offset = persistent_mem_offset; + if (reorder_input_c_) + { + persistent_mem_offset += input_c_memory_desc_->get_size(); + } + return create_dnnl_memory(input_c_memory_desc_.value(), umd_input_c_memory_, reorder_input_c_ ? 
offset : 0ull); + } + return dnnl::memory{}; + }(persistent_mem_offset); + + std::optional scratchpad_memory; + if (has_scratchpad_tensor()) + { + scratchpad_memory.emplace(create_dnnl_memory(scratchpad_memory_desc_.value(), umd_scratchpad_memory_)); + } + + dnnl::memory output_memory = create_dnnl_memory(output_memory_desc_, umd_output_memory_); + + std::unordered_map args; + args.insert({ DNNL_ARG_SRC, input_memory }); + args.insert({ DNNL_ARG_WEIGHTS, input_b_memory }); + args.insert({ DNNL_ARG_ATTR_SCALES | DNNL_ARG_WEIGHTS, input_scales_memory }); + if (input_zeropoints_memory) + { + args.insert({ DNNL_ARG_ATTR_ZERO_POINTS | DNNL_ARG_WEIGHTS, input_zeropoints_memory }); + } + std::size_t post_ops_idx = 0ull; + if (params_.alpha != 1 || has_beta_scaling_factors()) + { + post_ops_idx++; + } + + if (input_c_memory) + { + args.insert({ static_cast(DNNL_ARG_ATTR_MULTIPLE_POST_OP(post_ops_idx) | DNNL_ARG_SRC_1), input_c_memory }); + post_ops_idx++; + } + + if (has_beta_scaling_factors()) + { + post_ops_idx++; + } + + if (scratchpad_memory_desc_) + { + args.insert({ DNNL_ARG_SCRATCHPAD, scratchpad_memory.value() }); + } + + args.insert({ DNNL_ARG_DST, output_memory }); + + gemm_.execute(stream, args); + } + + + ConformanceResult validate_conformance(ID3D12CommandQueue* command_queue, + ID3D12CommandAllocator* command_allocator, ID3D12GraphicsCommandList* command_list, bool print_mismatche, std::size_t reference_dispatch_iterations) override + { + auto dump_buffer_to_file = [&](const auto& buffer, const auto& file_name) + { + if (!buffer) + { + return; + } + const auto bytes_width = buffer->GetDesc().Width; + // readback data and validate + auto readback_buffer = create_buffer(d3d12_device_, bytes_width, D3D12_HEAP_TYPE_READBACK, D3D12_RESOURCE_STATE_COPY_DEST); + auto readback_output_barrirer = CD3DX12_RESOURCE_BARRIER::Transition(buffer.Get(), + D3D12_RESOURCE_STATE_UNORDERED_ACCESS, D3D12_RESOURCE_STATE_COPY_SOURCE); + command_list->ResourceBarrier(1, 
&readback_output_barrirer); + command_list->CopyResource(readback_buffer.Get(), buffer.Get()); + close_execute_reset_wait(d3d12_device_, command_queue, command_allocator, command_list); + + std::vector data_out(bytes_width); + std::byte* readback_mapped_ptr = nullptr; + readback_buffer->Map(0, nullptr, reinterpret_cast(&readback_mapped_ptr)); + std::memcpy(data_out.data(), readback_mapped_ptr, data_out.size()); + readback_buffer->Unmap(0, nullptr); + + // Assuming data_out now contains the char data + char* char_ptr = reinterpret_cast(data_out.data()); + size_t num_chars = data_out.size() / sizeof(char); + std::ofstream file(file_name, std::ios::out); // Open in text mode; use std::ios::binary for binary mode + for (size_t i = 0; i < num_chars; ++i) { + file << (int)char_ptr[i] << std::endl; // Write in text format; for binary, use file.write(reinterpret_cast(&char_ptr[i]), sizeof(char)); + } + file.close(); + }; + + if (params_.dump_resource) + { + dump_buffer_to_file(persistent_buffer_, "umd_qgemm_data.txt"); + dump_buffer_to_file(input_buffer_zeropoint_, "umd_qgemm_zeros.txt"); + dump_buffer_to_file(reorder_buffer_zeropoint_, "umd_qgemm_zeros_reorder.txt"); + } + + const auto ret = QuantGemmBaseDispatcher::validate_conformance(command_queue, command_allocator, command_list, print_mismatche, reference_dispatch_iterations); + return ret; + } +private: + dnnl::memory create_dnnl_memory(const auto& desc, auto& umd_mem, std::size_t offset = 0) + { + return dnnl::iumd_interop::make_memory(desc, dnnl_engine_, &umd_mem, offset); + }; + +private: + bool has_beta_scaling_factors() const + { + // OneDNNL has a bit different GEMM API definition: alpha*A*B + beta*C + // DirectML: beta*(alpha/beta*(A*B)+C)) + // we will pass alpha as alpha/beta if beta value is effective + return params_.beta != 0.0f && params_.beta != 1.0f; + } + + bool has_c_tensor() const + { + return input_buffer_c_ != nullptr; + } + + bool has_scratchpad_tensor() const + { + return 
scratchpad_memory_desc_.has_value(); + } + +private: + iumd::custom_metacommand::UmdD3d12Device device_; + dnnl::engine dnnl_engine_; + + dnnl::matmul gemm_; + dnnl::reorder reorder_input_b_; + dnnl::reorder reorder_input_scales_; + dnnl::reorder reorder_input_zeropoints_; + dnnl::reorder reorder_input_c_; + + iumd::IUMDPipelineStateObject::Ptr copy_alpha_shader_ = nullptr; + + dnnl::memory::desc input_a_memory_desc_; + dnnl::memory::desc input_b_memory_desc_; + dnnl::memory::desc input_scales_memory_desc_; + std::optional input_zeropoints_memory_desc_; + dnnl::memory::desc output_memory_desc_; + std::optional input_c_memory_desc_; + std::optional scratchpad_memory_desc_; + + ComPtr temporary_buffer_; + ComPtr persistent_buffer_; // ToDo: input_b can be managed, then it should be used for that + + CD3DX12_CPU_DESCRIPTOR_HANDLE base_cpu_handle_; + CD3DX12_GPU_DESCRIPTOR_HANDLE base_gpu_handle_; +}; + class GemmBaseDispatcher : public NodeDispatcher { public: @@ -1684,7 +2329,7 @@ class GemmBaseDispatcher : public NodeDispatcher { if (params_.type == GemmType::GemmType_AB || params_.type == GemmType::GemmType_SV_S_QKV || params_.type == GemmType::GemmType_SV_S_KV) { - return params_.b_transposed ? 
params_.shape_b.h : params_.shape_b.w; + return params_.shape_b.w; } else if (params_.type == GemmType::GemmType_QK_QKV) { diff --git a/tools/common_lib/src/iumd_d3d12_impl.h b/tools/common_lib/src/iumd_d3d12_impl.h index c81d073..7896e27 100644 --- a/tools/common_lib/src/iumd_d3d12_impl.h +++ b/tools/common_lib/src/iumd_d3d12_impl.h @@ -138,6 +138,11 @@ namespace custom_metacommand return sku_.threads_per_eu * sku_.eu_per_dss * sku_.hw_simd_size; }; + std::uint32_t get_l3_cache_size() const //override + { + return 16 * 1024 * 1024; // FixMe: 16MB is a plug + }; + UMD_IGFX get_umd_igfx() const override { return sku_.igfx; diff --git a/tools/common_lib/src/node_dispatcher.h b/tools/common_lib/src/node_dispatcher.h index 52246d8..b826c1a 100644 --- a/tools/common_lib/src/node_dispatcher.h +++ b/tools/common_lib/src/node_dispatcher.h @@ -19,6 +19,7 @@ enum class NodeType eMhaCm, eMemoryBandwidth, eQuantGemmDml, + eQuantGemmUmdD3d12, eCount }; diff --git a/tools/cross_runner/src/main.cpp b/tools/cross_runner/src/main.cpp index 77ac9e6..b2f4d2e 100644 --- a/tools/cross_runner/src/main.cpp +++ b/tools/cross_runner/src/main.cpp @@ -73,6 +73,7 @@ struct CliOptions ConvolutionUmdD3d12Dispatcher::conv_umdd3d12_params_t conv_umdd3d12_params{}; MvnCmDispatcher::mvn_cm_params_t mvn_cm_params{}; SoftmaxCmDispatcher::softmax_cm_params_t softmax_cm_params{}; + QuantGemmUmdD3d12Dispatcher::qgemm_umdd3d12_params_t quant_gemm_umdd3d12_opts{}; GemmCmDispatcher::cm_params_t gemm_cm_params{}; GemmUmdD3d12Dispatcher::gemm_umdd3d12_params_t gemm_umdd3d12_params{}; MhaCmDispatcher::cm_params_t mha_cm_opts{}; @@ -93,7 +94,8 @@ int main(int argc, const char*argv[]) NodeType::eSoftmaxDml, NodeType::eSoftmaxCm, NodeType::eMvnDml, NodeType::eMvnCm, NodeType::eMhaDml, NodeType::eMhaCm, - NodeType::eMemoryBandwidth, NodeType::eQuantGemmDml + NodeType::eMemoryBandwidth, NodeType::eQuantGemmDml, + NodeType::eQuantGemmUmdD3d12 }))-> transform(CLI::Transformer(std::map{ { "conv_dml", 
NodeType::eConvDml }, @@ -110,6 +112,7 @@ int main(int argc, const char*argv[]) { "mha_cm", NodeType::eMhaCm}, { "mem_bw", NodeType::eMemoryBandwidth }, { "quant_gemm_dml", NodeType::eQuantGemmDml }, + { "quant_gemm_umd_d3d12", NodeType::eQuantGemmUmdD3d12 }, }, CLI::ignore_case, CLI::ignore_underscore)); dml_runner_app.add_option("--iters", opts.dispatch_iterations, "How many iterations to run.")->check(CLI::Range(1u, MAX_ITERATIONS)); dml_runner_app.add_flag("--no_conform", opts.no_conformance_check); @@ -144,6 +147,8 @@ int main(int argc, const char*argv[]) GemmCmDispatcher::cm_params_t::add_cli_options(gemm_cm_option_groups, opts.gemm_cm_params); auto gemm_umdd3d12_option_groups = dml_runner_app.add_subcommand("gemm_umdd3d12_opts", "Options for gemm layer with CM implementation."); GemmUmdD3d12Dispatcher::gemm_umdd3d12_params_t::add_cli_options(gemm_umdd3d12_option_groups, opts.gemm_umdd3d12_params); + auto quant_gemm_umdd3d12_option_groups = dml_runner_app.add_subcommand("quant_gemm_umdd3d12_opts", "Options for Quantized Gemm layer with oneDNN implementation."); + QuantGemmUmdD3d12Dispatcher::qgemm_umdd3d12_params_t::add_cli_options(quant_gemm_umdd3d12_option_groups, opts.quant_gemm_umdd3d12_opts); auto mem_bw_option_group = dml_runner_app.add_subcommand("mem_bw_opts", "Options for memory bandwidths measurements"); gpu_op::MemoryBandwidthDispatcher::MemoryBandwidthDispatcher::create_params_t::add_cli_options(mem_bw_option_group, opts.memory_bw_params); auto mha_cm_option_groups = dml_runner_app.add_subcommand("mha_cm_opts", "Options for mha layer with CM implementation."); @@ -173,7 +178,7 @@ int main(int argc, const char*argv[]) std::cout << "Gemm options not set.\n"; return -1; } - else if (opts.node_type == NodeType::eQuantGemmDml && !quant_gemm_option_groups->parsed()) + else if ((opts.node_type == NodeType::eQuantGemmDml || opts.node_type == NodeType::eQuantGemmUmdD3d12) && !quant_gemm_option_groups->parsed()) { std::cout << "Quant Gemm options not 
set.\n"; return -1; @@ -281,6 +286,11 @@ int main(int argc, const char*argv[]) node = std::make_unique(std::move(opts.quant_gemm_opts), true, d3d12_device.Get(), dml_device.Get(), dml_command_recorder.Get(), command_list.Get()); } + else if (opts.node_type == NodeType::eQuantGemmUmdD3d12) + { + node = std::make_unique(std::move(opts.quant_gemm_opts), std::move(opts.quant_gemm_umdd3d12_opts), + intel_extension_d3d12, d3d12_device.Get(), dml_device.Get(), dml_command_recorder.Get(), command_list.Get()); + } else { assert(false && "Unknown node type!");