diff --git a/tools/common_lib/src/gemm.h b/tools/common_lib/src/gemm.h
index 2d66018..cc0781a 100644
--- a/tools/common_lib/src/gemm.h
+++ b/tools/common_lib/src/gemm.h
@@ -500,6 +500,8 @@ class GemmBaseDispatcher : public NodeDispatcher
         bool allow_fp16_computations = false;
         bool use_dnnl_for_reference_calculations = false;
 
+        bool dump_resource = false;
+
         inline static void add_cli_options(CLI::App* opts, create_params_t& params)
         {
             add_data_type_cli_option(opts, "--data_type", params.dt)->required();
@@ -531,6 +533,8 @@ class GemmBaseDispatcher : public NodeDispatcher
                 { "sv_s_kv", GemmType::GemmType_SV_S_KV },
             }, CLI::ignore_case))->required();
 
+            opts->add_flag("--dump_resource", params.dump_resource, "Dump the UMD persistent buffer as text to umd_gemm_data.txt before validation");
+
         }
     };
 public:
@@ -1292,12 +1296,8 @@ class GemmUmdD3d12Dispatcher : public GemmBaseDispatcher
         std::optional<dnnl::memory> scratchpad_memory;
         if(has_scratchpad_tensor())
         {
-            if (scratchpad_memory_desc_)
-            {
-                scratchpad_memory.emplace(create_dnnl_memory(scratchpad_memory_desc_.value(), umd_scratchpad_memory_));
-            }
+            scratchpad_memory.emplace(create_dnnl_memory(scratchpad_memory_desc_.value(), umd_scratchpad_memory_));
         }
-
         dnnl::memory output_memory = create_dnnl_memory(output_memory_desc_, umd_output_memory_);
 
 
@@ -1327,6 +1327,57 @@ class GemmUmdD3d12Dispatcher : public GemmBaseDispatcher
 
         gemm_.execute(stream, args);
     }
+
+    ConformanceResult validate_conformance(ID3D12CommandQueue* command_queue,
+        ID3D12CommandAllocator* command_allocator, ID3D12GraphicsCommandList* command_list, bool print_mismatche) override
+    {
+        // Debug aid: copies a GPU buffer into a readback heap and writes it to `file_name`
+        // as text, one value per line, interpreting the bytes as 32-bit floats.
+        auto dump_buffer_to_file = [&](const auto& buffer, const auto& file_name)
+        {
+            if (!buffer)
+            {
+                return;
+            }
+            const auto bytes_width = buffer->GetDesc().Width;
+            auto readback_buffer = create_buffer(d3d12_device_, bytes_width, D3D12_HEAP_TYPE_READBACK, D3D12_RESOURCE_STATE_COPY_DEST);
+
+            // Copy from the buffer that was passed in, then return it to the
+            // UNORDERED_ACCESS state it is assumed to start in so later work sees it unchanged.
+            auto to_copy_source_barrier = CD3DX12_RESOURCE_BARRIER::Transition(buffer.Get(),
+                D3D12_RESOURCE_STATE_UNORDERED_ACCESS, D3D12_RESOURCE_STATE_COPY_SOURCE);
+            command_list->ResourceBarrier(1, &to_copy_source_barrier);
+            command_list->CopyResource(readback_buffer.Get(), buffer.Get());
+            auto restore_state_barrier = CD3DX12_RESOURCE_BARRIER::Transition(buffer.Get(),
+                D3D12_RESOURCE_STATE_COPY_SOURCE, D3D12_RESOURCE_STATE_UNORDERED_ACCESS);
+            command_list->ResourceBarrier(1, &restore_state_barrier);
+            close_execute_reset_wait(d3d12_device_, command_queue, command_allocator, command_list);
+
+            // Map can fail; the dump is best-effort, so skip it rather than read a null pointer.
+            void* mapped_ptr = nullptr;
+            if (FAILED(readback_buffer->Map(0, nullptr, &mapped_ptr)) || mapped_ptr == nullptr)
+            {
+                return;
+            }
+
+            // Contents are assumed to be float data; trailing bytes short of a full float are ignored.
+            const auto* float_ptr = static_cast<const float*>(mapped_ptr);
+            const auto num_floats = static_cast<std::size_t>(bytes_width / sizeof(float));
+            std::ofstream file(file_name, std::ios::out); // text mode; closed when it goes out of scope
+            for (std::size_t i = 0; i < num_floats; ++i)
+            {
+                file << float_ptr[i] << '\n'; // '\n' rather than std::endl: no flush per value
+            }
+            readback_buffer->Unmap(0, nullptr);
+        };
+
+        if (params_.dump_resource)
+        {
+            dump_buffer_to_file(persistent_buffer_, "umd_gemm_data.txt");
+        }
+
+        return GemmBaseDispatcher::validate_conformance(command_queue, command_allocator, command_list, print_mismatche);
+    }
 private:
     dnnl::memory create_dnnl_memory(const auto& desc, auto& umd_mem, std::size_t offset = 0)
     {