-
Notifications
You must be signed in to change notification settings - Fork 4
dump data for gemm umd path #20
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -500,6 +500,8 @@ class GemmBaseDispatcher : public NodeDispatcher | |
| bool allow_fp16_computations = false; | ||
| bool use_dnnl_for_reference_calculations = false; | ||
|
|
||
| bool dump_resource = false; | ||
|
|
||
| inline static void add_cli_options(CLI::App* opts, create_params_t& params) | ||
| { | ||
| add_data_type_cli_option(opts, "--data_type", params.dt)->required(); | ||
|
|
@@ -531,6 +533,8 @@ class GemmBaseDispatcher : public NodeDispatcher | |
| { "sv_s_kv", GemmType::GemmType_SV_S_KV }, | ||
| }, CLI::ignore_case))->required(); | ||
|
|
||
| opts->add_flag("--dump_resource", params.dump_resource); | ||
|
|
||
| } | ||
| }; | ||
| public: | ||
|
|
@@ -1292,12 +1296,8 @@ class GemmUmdD3d12Dispatcher : public GemmBaseDispatcher | |
| std::optional<dnnl::memory> scratchpad_memory; | ||
| if(has_scratchpad_tensor()) | ||
| { | ||
| if (scratchpad_memory_desc_) | ||
| { | ||
| scratchpad_memory.emplace(create_dnnl_memory(scratchpad_memory_desc_.value(), umd_scratchpad_memory_)); | ||
| } | ||
| scratchpad_memory.emplace(create_dnnl_memory(scratchpad_memory_desc_.value(), umd_scratchpad_memory_)); | ||
| } | ||
|
|
||
|
|
||
| dnnl::memory output_memory = create_dnnl_memory(output_memory_desc_, umd_output_memory_); | ||
|
|
||
|
|
@@ -1327,6 +1327,49 @@ class GemmUmdD3d12Dispatcher : public GemmBaseDispatcher | |
| gemm_.execute(stream, args); | ||
| } | ||
|
|
||
|
|
||
| ConformanceResult validate_conformance(ID3D12CommandQueue* command_queue, | ||
| ID3D12CommandAllocator* command_allocator, ID3D12GraphicsCommandList* command_list, bool print_mismatche) override | ||
| { | ||
| auto dump_buffer_to_file = [&](const auto& buffer, const auto& file_name) | ||
| { | ||
| if (!buffer) | ||
| { | ||
| return; | ||
| } | ||
| const auto bytes_width = buffer->GetDesc().Width; | ||
| // readback data and validate | ||
| auto readback_buffer = create_buffer(d3d12_device_, bytes_width, D3D12_HEAP_TYPE_READBACK, D3D12_RESOURCE_STATE_COPY_DEST); | ||
| auto readback_output_barrirer = CD3DX12_RESOURCE_BARRIER::Transition(buffer.Get(), | ||
| D3D12_RESOURCE_STATE_UNORDERED_ACCESS, D3D12_RESOURCE_STATE_COPY_SOURCE); | ||
| command_list->ResourceBarrier(1, &readback_output_barrirer); | ||
| command_list->CopyResource(readback_buffer.Get(), persistent_buffer_.Get()); | ||
| close_execute_reset_wait(d3d12_device_, command_queue, command_allocator, command_list); | ||
|
|
||
| std::vector<std::byte> data_out(bytes_width); | ||
| std::byte* readback_mapped_ptr = nullptr; | ||
| readback_buffer->Map(0, nullptr, reinterpret_cast<void**>(&readback_mapped_ptr)); | ||
| std::memcpy(data_out.data(), readback_mapped_ptr, data_out.size()); | ||
| readback_buffer->Unmap(0, nullptr); | ||
|
|
||
| // Assuming data_out now contains the float data | ||
| float* float_ptr = reinterpret_cast<float*>(data_out.data()); | ||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. const There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Float assumption is fine, but probably not always valid as scratchpad can have arbitrary data type <- you may want to add such extended comment so it will be clear for code reader in future. |
||
| size_t num_floats = data_out.size() / sizeof(float); | ||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. const |
||
| std::ofstream file(file_name, std::ios::out); // Open in text mode; use std::ios::binary for binary mode | ||
| for (size_t i = 0; i < num_floats; ++i) { | ||
| file << float_ptr[i] << std::endl; // Write in text format; for binary, use file.write(reinterpret_cast<const char*>(&float_ptr[i]), sizeof(float)); | ||
| } | ||
| file.close(); | ||
| }; | ||
|
|
||
| if(params_.dump_resource) | ||
| { | ||
| dump_buffer_to_file(persistent_buffer_, "umd_gemm_data.txt"); | ||
| } | ||
|
|
||
| const auto ret = GemmBaseDispatcher::validate_conformance(command_queue, command_allocator, command_list, print_mismatche); | ||
| return ret; | ||
| } | ||
| private: | ||
| dnnl::memory create_dnnl_memory(const auto& desc, auto& umd_mem, std::size_t offset = 0) | ||
| { | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Similar lambda is in convolution umd dispatcher - could it be moved to be separate function?