Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions tools/common_lib/src/dml_base_node.h
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,7 @@ inline dml::TensorPolicy to_dml_tensor_policy(DataLayout layout)
switch (layout)
{
case DataLayout::eCHW:
case DataLayout::eNCDHW:
case DataLayout::eNCHW: return dml::TensorPolicy::Default();
case DataLayout::eNHWC: return dml::TensorPolicy::InterleavedChannel();
case DataLayout::eW: return dml::TensorPolicy(compute_w_tensor_policy);
Expand Down
35 changes: 31 additions & 4 deletions tools/common_lib/src/layers_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -160,8 +160,9 @@ enum class DataLayout
eNHWC_AlignH48, // example layout for unpacked tensor, ToDo: refactor cross runner to work with strides instead of hardcoded data layouts
eCHW, // 3d dims for GEMMS
eW,


eNCDHW,
eNHW,
eNCH,
// ..
// ..

Expand Down Expand Up @@ -194,6 +195,7 @@ inline std::string data_layout_name(DataLayout l)
case DataLayout::eIO_i8_o8_i2: return "IO_i8_o8_i2";
case DataLayout::eOYXI_o8: return "OYXI_o8";
case DataLayout::eOYXI_o16: return "OYXI_o16";
case DataLayout::eNCDHW: return "Stacked_qkv";
default:
assert(false && "Unknown data layout name.");
return "";
Expand All @@ -215,6 +217,8 @@ inline std::uint8_t data_layout_dimensions_count(DataLayout l)
return 3;
case DataLayout::eW:
return 1;
case DataLayout::eNCDHW:
return 5;
default:
return 0;
}
Expand Down Expand Up @@ -252,6 +256,7 @@ inline std::size_t data_layout_h_alignment(const DataLayout l)
inline TensorShape data_layout_to_strides(TensorShape shape, DataLayout l)
{
const auto c = shape.c;
const auto d = shape.d;
const auto h = static_cast<std::uint32_t>(align(shape.h, data_layout_h_alignment(l)));
const auto w = static_cast<std::uint32_t>(align(shape.w, data_layout_w_alignment(l)));

Expand Down Expand Up @@ -279,6 +284,28 @@ inline TensorShape data_layout_to_strides(TensorShape shape, DataLayout l)
c);
break;
}
case DataLayout::eNCDHW:
{
ret = TensorShape(
c * h * d * w,
h * d* w,
d * w,
w,
1);
break;
case DataLayout::eNHW:
ret = TensorShape(
h * w,
w,
1);
break;
case DataLayout::eNCH:
ret = TensorShape(
c * h,
h,
1);
break;
}
default:
assert("!unsupported right now");
}
Expand Down Expand Up @@ -361,8 +388,8 @@ inline auto add_data_type_cli_option(CLI::App* opts, std::string_view opt_name,

// Registers a CLI option that parses a textual layout name into a DataLayout.
// Accepted spellings (case- and underscore-insensitive, per the Transformer
// flags below): nchw, nhwc, w, chw, nchw_alignw320, nhwc_alignh48, ncdhw.
// Returns the created CLI::Option* so callers can chain further configuration.
// NOTE(review): this span contained both the pre- and post-diff versions of the
// IsMember/Transformer lines; resolved here to the post-diff version (with eNCDHW).
// NOTE(review): eNHW and eNCH exist in the DataLayout enum but are not exposed
// here — confirm with the author whether they should be selectable from the CLI.
inline auto add_data_layout_cli_option(CLI::App* opts, std::string_view opt_name, DataLayout& layout)
{
    return opts->add_option(opt_name.data(), layout)->check(CLI::IsMember({DataLayout::eNCHW, DataLayout::eNHWC, DataLayout::eW, DataLayout::eCHW, DataLayout::eNCHW_AlignW320, DataLayout::eNHWC_AlignH48, DataLayout::eNCDHW }))
        ->transform(CLI::Transformer(std::map<std::string, DataLayout>{
            {"nchw", DataLayout::eNCHW}, { "nhwc", DataLayout::eNHWC }, { "w", DataLayout::eW }, { "chw", DataLayout::eCHW }, { "nchw_alignw320", DataLayout::eNCHW_AlignW320 }, { "nhwc_alignh48", DataLayout::eNHWC_AlignH48 }, { "ncdhw", DataLayout::eNCDHW },
        }, CLI::ignore_case, CLI::ignore_underscore));
}
Loading