1 change: 1 addition & 0 deletions tools/common_lib/src/dnnl_utils.h
@@ -69,6 +69,7 @@ inline dnnl::memory::data_type to_dnnl_data_type(const DataType l)
{
case DataType::eFp32: return dnnl::memory::data_type::f32;
case DataType::eFp16: return dnnl::memory::data_type::f16;
case DataType::eUint4: return dnnl::memory::data_type::u4;
default:
return dnnl::memory::data_type::undef;
}
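
The one-line dnnl_utils.h change above only matters once a tensor is actually described to oneDNN with the new mapping. A minimal sketch, assuming a oneDNN build that exposes data_type::u4 (int4 support); the helper name, dimensions and row-major layout are illustrative and are not taken from the gemm.h diff (which is not rendered below).

// Sketch only -- not part of this PR's gemm.h. Shows how the eUint4 -> u4
// mapping could be consumed when describing a packed uint4 weight tensor.
#include "dnnl.hpp"

inline dnnl::memory::desc make_u4_weights_desc(dnnl::memory::dim K, dnnl::memory::dim N)
{
    // to_dnnl_data_type(DataType::eUint4) would return data_type::u4 here;
    // oneDNN stores u4 data packed, two 4-bit values per byte.
    return dnnl::memory::desc({K, N}, dnnl::memory::data_type::u4,
                              dnnl::memory::format_tag::ab);
}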
719 changes: 702 additions & 17 deletions tools/common_lib/src/gemm.h

Large diffs are not rendered by default.

5 changes: 5 additions & 0 deletions tools/common_lib/src/iumd_d3d12_impl.h
@@ -138,6 +138,11 @@ namespace custom_metacommand
return sku_.threads_per_eu * sku_.eu_per_dss * sku_.hw_simd_size;
};

std::uint32_t get_l3_cache_size() const //override
{
return 16 * 1024 * 1024; // FixMe: 16 MB is a placeholder
};

UMD_IGFX get_umd_igfx() const override
{
return sku_.igfx;
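
get_l3_cache_size() above returns a hard-coded 16 MB value for now (per the FixMe). A hedged sketch of the kind of consumer such a query enables, e.g. bounding the K-blocking of packed uint4 weights so one tile stays L3-resident; the helper, the half-of-L3 budget and the call pattern are assumptions for illustration, not code from this PR.

#include <algorithm>
#include <cstdint>

// Hypothetical blocking heuristic driven by the reported L3 size: choose the
// largest K block whose (K_block x N) uint4 weight tile (2 values per byte)
// fits in roughly half of L3, leaving the rest for A/C tiles.
inline std::uint32_t pick_k_block(std::uint64_t l3_bytes, std::uint32_t N, std::uint32_t K)
{
    const std::uint64_t budget    = l3_bytes / 2;
    const std::uint64_t row_bytes = (static_cast<std::uint64_t>(N) + 1) / 2;  // one K-row of N uint4 values
    const std::uint64_t k_block   = row_bytes ? budget / row_bytes : K;
    return static_cast<std::uint32_t>(std::clamp<std::uint64_t>(k_block, 1u, K));
}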
1 change: 1 addition & 0 deletions tools/common_lib/src/node_dispatcher.h
@@ -19,6 +19,7 @@ enum class NodeType
eMhaCm,
eMemoryBandwidth,
eQuantGemmDml,
eQuantGemmUmdD3d12,
eCount
};

14 changes: 12 additions & 2 deletions tools/cross_runner/src/main.cpp
@@ -73,6 +73,7 @@ struct CliOptions
ConvolutionUmdD3d12Dispatcher::conv_umdd3d12_params_t conv_umdd3d12_params{};
MvnCmDispatcher::mvn_cm_params_t mvn_cm_params{};
SoftmaxCmDispatcher::softmax_cm_params_t softmax_cm_params{};
QuantGemmUmdD3d12Dispatcher::qgemm_umdd3d12_params_t quant_gemm_umdd3d12_opts{};
GemmCmDispatcher::cm_params_t gemm_cm_params{};
GemmUmdD3d12Dispatcher::gemm_umdd3d12_params_t gemm_umdd3d12_params{};
MhaCmDispatcher::cm_params_t mha_cm_opts{};
@@ -93,7 +94,8 @@ int main(int argc, const char*argv[])
NodeType::eSoftmaxDml, NodeType::eSoftmaxCm,
NodeType::eMvnDml, NodeType::eMvnCm,
NodeType::eMhaDml, NodeType::eMhaCm,
NodeType::eMemoryBandwidth, NodeType::eQuantGemmDml
NodeType::eMemoryBandwidth, NodeType::eQuantGemmDml,
NodeType::eQuantGemmUmdD3d12
}))->
transform(CLI::Transformer(std::map<std::string, NodeType>{
{ "conv_dml", NodeType::eConvDml },
@@ -110,6 +112,7 @@ int main(int argc, const char*argv[])
{ "mha_cm", NodeType::eMhaCm},
{ "mem_bw", NodeType::eMemoryBandwidth },
{ "quant_gemm_dml", NodeType::eQuantGemmDml },
{ "quant_gemm_umd_d3d12", NodeType::eQuantGemmUmdD3d12 },
}, CLI::ignore_case, CLI::ignore_underscore));
dml_runner_app.add_option("--iters", opts.dispatch_iterations, "How many iterations to run.")->check(CLI::Range(1u, MAX_ITERATIONS));
dml_runner_app.add_flag("--no_conform", opts.no_conformance_check);
@@ -144,6 +147,8 @@ int main(int argc, const char*argv[])
GemmCmDispatcher::cm_params_t::add_cli_options(gemm_cm_option_groups, opts.gemm_cm_params);
auto gemm_umdd3d12_option_groups = dml_runner_app.add_subcommand("gemm_umdd3d12_opts", "Options for gemm layer with UMD D3D12 implementation.");
GemmUmdD3d12Dispatcher::gemm_umdd3d12_params_t::add_cli_options(gemm_umdd3d12_option_groups, opts.gemm_umdd3d12_params);
auto quant_gemm_umdd3d12_option_groups = dml_runner_app.add_subcommand("quant_gemm_umdd3d12_opts", "Options for Quantized Gemm layer with oneDNN implementation.");
QuantGemmUmdD3d12Dispatcher::qgemm_umdd3d12_params_t::add_cli_options(quant_gemm_umdd3d12_option_groups, opts.quant_gemm_umdd3d12_opts);
auto mem_bw_option_group = dml_runner_app.add_subcommand("mem_bw_opts", "Options for memory bandwidths measurements");
gpu_op::MemoryBandwidthDispatcher::MemoryBandwidthDispatcher::create_params_t::add_cli_options(mem_bw_option_group, opts.memory_bw_params);
auto mha_cm_option_groups = dml_runner_app.add_subcommand("mha_cm_opts", "Options for mha layer with CM implementation.");
@@ -173,7 +178,7 @@ int main(int argc, const char*argv[])
std::cout << "Gemm options not set.\n";
return -1;
}
else if (opts.node_type == NodeType::eQuantGemmDml && !quant_gemm_option_groups->parsed())
else if ((opts.node_type == NodeType::eQuantGemmDml || opts.node_type == NodeType::eQuantGemmUmdD3d12) && !quant_gemm_option_groups->parsed())
{
std::cout << "Quant Gemm options not set.\n";
return -1;
@@ -281,6 +286,11 @@ int main(int argc, const char*argv[])
node = std::make_unique<QuantGemmDmlDispatcher>(std::move(opts.quant_gemm_opts), true,
d3d12_device.Get(), dml_device.Get(), dml_command_recorder.Get(), command_list.Get());
}
else if (opts.node_type == NodeType::eQuantGemmUmdD3d12)
{
node = std::make_unique<QuantGemmUmdD3d12Dispatcher>(std::move(opts.quant_gemm_opts), std::move(opts.quant_gemm_umdd3d12_opts),
intel_extension_d3d12, d3d12_device.Get(), dml_device.Get(), dml_command_recorder.Get(), command_list.Get());
}
else
{
assert(false && "Unknown node type!");
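
The bulk of this PR lives in gemm.h, whose diff is not rendered above, so the new QuantGemmUmdD3d12Dispatcher itself is not visible here. Inferred purely from the main.cpp call site, its constructor takes the shared quant-GEMM options, the new UMD/oneDNN-specific options, the Intel D3D12 extension wrapper and the usual D3D12/DML objects. The declaration below is only that inference spelled out; every type name other than the D3D12/DirectML interfaces is an assumption.

#include <d3d12.h>
#include <DirectML.h>

// Stand-ins for types defined elsewhere in the repo; their real definitions
// are not shown in this rendered diff.
struct quant_gemm_params_t {};        // shared quant-GEMM options (assumed name)
struct qgemm_umdd3d12_params_t {};    // UMD/oneDNN-specific options (nested as
                                      // QuantGemmUmdD3d12Dispatcher::qgemm_umdd3d12_params_t
                                      // in the real code; flattened here)
class IntelExtensionD3D12;            // assumed wrapper around the Intel D3D12 extension

// Constructor shape inferred from the call in main.cpp:
//   std::make_unique<QuantGemmUmdD3d12Dispatcher>(std::move(opts.quant_gemm_opts),
//       std::move(opts.quant_gemm_umdd3d12_opts), intel_extension_d3d12,
//       d3d12_device.Get(), dml_device.Get(), dml_command_recorder.Get(), command_list.Get());
class QuantGemmUmdD3d12Dispatcher
{
public:
    QuantGemmUmdD3d12Dispatcher(quant_gemm_params_t&& base_params,
                                qgemm_umdd3d12_params_t&& umd_params,
                                IntelExtensionD3D12& extension,
                                ID3D12Device* d3d12_device,
                                IDMLDevice* dml_device,
                                IDMLCommandRecorder* dml_command_recorder,
                                ID3D12GraphicsCommandList* command_list);
};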