Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
53 changes: 48 additions & 5 deletions tools/common_lib/src/gemm.h
Original file line number Diff line number Diff line change
Expand Up @@ -500,6 +500,8 @@ class GemmBaseDispatcher : public NodeDispatcher
bool allow_fp16_computations = false;
bool use_dnnl_for_reference_calculations = false;

bool dump_resource = false;

inline static void add_cli_options(CLI::App* opts, create_params_t& params)
{
add_data_type_cli_option(opts, "--data_type", params.dt)->required();
Expand Down Expand Up @@ -531,6 +533,8 @@ class GemmBaseDispatcher : public NodeDispatcher
{ "sv_s_kv", GemmType::GemmType_SV_S_KV },
}, CLI::ignore_case))->required();

opts->add_flag("--dump_resource", params.dump_resource);

}
};
public:
Expand Down Expand Up @@ -1292,12 +1296,8 @@ class GemmUmdD3d12Dispatcher : public GemmBaseDispatcher
std::optional<dnnl::memory> scratchpad_memory;
if(has_scratchpad_tensor())
{
if (scratchpad_memory_desc_)
{
scratchpad_memory.emplace(create_dnnl_memory(scratchpad_memory_desc_.value(), umd_scratchpad_memory_));
}
scratchpad_memory.emplace(create_dnnl_memory(scratchpad_memory_desc_.value(), umd_scratchpad_memory_));
}


dnnl::memory output_memory = create_dnnl_memory(output_memory_desc_, umd_output_memory_);

Expand Down Expand Up @@ -1327,6 +1327,49 @@ class GemmUmdD3d12Dispatcher : public GemmBaseDispatcher
gemm_.execute(stream, args);
}


// Optionally dumps persistent_buffer_ to a text file, then runs the base-class conformance check.
//
// When params_.dump_resource is set (CLI flag --dump_resource), the contents of persistent_buffer_
// are read back from the GPU and written to "umd_gemm_data.txt", one value per line.
// The returned ConformanceResult always comes from GemmBaseDispatcher::validate_conformance;
// the dump does not influence it.
//
// command_queue / command_allocator / command_list: used to record and submit the readback copy,
//     then forwarded unchanged to the base implementation.
// print_mismatche: forwarded unchanged to the base implementation.
ConformanceResult validate_conformance(ID3D12CommandQueue* command_queue,
    ID3D12CommandAllocator* command_allocator, ID3D12GraphicsCommandList* command_list, bool print_mismatche) override
{
    // TODO(review): a similar lambda reportedly exists in the convolution UMD dispatcher;
    // consider moving both into one shared free function.
    auto dump_buffer_to_file = [&](const auto& buffer, const auto& file_name)
    {
        if (!buffer)
        {
            return;
        }
        const auto bytes_width = buffer->GetDesc().Width;

        // Copy the resource into a CPU-visible readback buffer of the same size.
        auto readback_buffer = create_buffer(d3d12_device_, bytes_width, D3D12_HEAP_TYPE_READBACK, D3D12_RESOURCE_STATE_COPY_DEST);
        // NOTE(review): assumes the buffer is currently in UNORDERED_ACCESS state - confirm for new callers.
        auto to_copy_source = CD3DX12_RESOURCE_BARRIER::Transition(buffer.Get(),
            D3D12_RESOURCE_STATE_UNORDERED_ACCESS, D3D12_RESOURCE_STATE_COPY_SOURCE);
        command_list->ResourceBarrier(1, &to_copy_source);
        // Copy from the buffer that was passed in (not from a fixed member), so the source always
        // matches the resource that was sized and transitioned above.
        command_list->CopyResource(readback_buffer.Get(), buffer.Get());
        // Put the resource back into its previous state so that the tracked state is the same
        // whether or not a dump was requested.
        auto back_to_uav = CD3DX12_RESOURCE_BARRIER::Transition(buffer.Get(),
            D3D12_RESOURCE_STATE_COPY_SOURCE, D3D12_RESOURCE_STATE_UNORDERED_ACCESS);
        command_list->ResourceBarrier(1, &back_to_uav);
        close_execute_reset_wait(d3d12_device_, command_queue, command_allocator, command_list);

        // The data is interpreted as 32-bit floats. This is a debugging convenience only:
        // a scratchpad/persistent resource can hold data of an arbitrary type, in which case the
        // dumped numbers are just a reinterpretation of the raw bytes. Trailing bytes that do not
        // form a whole float are ignored.
        const std::size_t num_floats = static_cast<std::size_t>(bytes_width) / sizeof(float);
        std::vector<float> values(num_floats);

        void* readback_mapped_ptr = nullptr;
        if (FAILED(readback_buffer->Map(0, nullptr, &readback_mapped_ptr)) || readback_mapped_ptr == nullptr)
        {
            // Nothing to dump if the readback memory cannot be mapped; the dump is best-effort.
            return;
        }
        // memcpy into float storage instead of casting the byte pointer (avoids aliasing issues).
        std::memcpy(values.data(), readback_mapped_ptr, num_floats * sizeof(float));
        // Empty written range: the CPU did not modify the readback memory.
        const D3D12_RANGE nothing_written{ 0, 0 };
        readback_buffer->Unmap(0, &nothing_written);

        // Text mode, one value per line. For a binary dump open with std::ios::binary and use file.write().
        std::ofstream file(file_name, std::ios::out);
        for (const float value : values)
        {
            file << value << '\n';  // '\n' instead of std::endl: no flush per element
        }
        // The stream is flushed and closed by its destructor.
    };

    if (params_.dump_resource)
    {
        dump_buffer_to_file(persistent_buffer_, "umd_gemm_data.txt");
    }

    return GemmBaseDispatcher::validate_conformance(command_queue, command_allocator, command_list, print_mismatche);
}
private:
dnnl::memory create_dnnl_memory(const auto& desc, auto& umd_mem, std::size_t offset = 0)
{
Expand Down