diff --git a/ggml/src/ggml-vulkan.cpp b/ggml/src/ggml-vulkan.cpp index 12ad9d810327f..abe1ba46ac30b 100644 --- a/ggml/src/ggml-vulkan.cpp +++ b/ggml/src/ggml-vulkan.cpp @@ -148,6 +148,7 @@ struct vk_device_struct { vk::PhysicalDeviceProperties properties; std::string name; uint64_t max_memory_allocation_size; + uint32_t force_heap_index; bool fp16; vk::Device device; uint32_t vendor_id; @@ -1008,9 +1009,12 @@ static void ggml_vk_queue_cleanup(vk_device& device, vk_queue& q) { q.cmd_buffer_idx = 0; } -static uint32_t find_properties(const vk::PhysicalDeviceMemoryProperties* mem_props, vk::MemoryRequirements* mem_req, vk::MemoryPropertyFlags flags) { +static uint32_t find_properties(const vk::PhysicalDeviceMemoryProperties* mem_props, vk::MemoryRequirements* mem_req, vk::MemoryPropertyFlags flags, uint32_t force_heap_index = UINT32_MAX) { for (uint32_t i = 0; i < mem_props->memoryTypeCount; ++i) { vk::MemoryType memory_type = mem_props->memoryTypes[i]; + if (force_heap_index != UINT32_MAX && memory_type.heapIndex != force_heap_index) { + continue; + } if ((mem_req->memoryTypeBits & ((uint64_t)1 << i)) && (flags & memory_type.propertyFlags) == flags && mem_props->memoryHeaps[memory_type.heapIndex].size >= mem_req->size) { @@ -1053,11 +1057,11 @@ static vk_buffer ggml_vk_create_buffer(vk_device& device, size_t size, vk::Memor uint32_t memory_type_index = UINT32_MAX; - memory_type_index = find_properties(&mem_props, &mem_req, req_flags); + memory_type_index = find_properties(&mem_props, &mem_req, req_flags, device->force_heap_index); buf->memory_property_flags = req_flags; if (memory_type_index == UINT32_MAX && fallback_flags) { - memory_type_index = find_properties(&mem_props, &mem_req, fallback_flags); + memory_type_index = find_properties(&mem_props, &mem_req, fallback_flags, device->force_heap_index); buf->memory_property_flags = fallback_flags; } @@ -1851,6 +1855,14 @@ static vk_device ggml_vk_get_device(size_t idx) { device->max_memory_allocation_size = props3.maxMemoryAllocationSize; } + const char* GGML_VK_FORCE_HEAP_INDEX = getenv("GGML_VK_FORCE_HEAP_INDEX"); + + if (GGML_VK_FORCE_HEAP_INDEX != nullptr) { + device->force_heap_index = std::stoi(GGML_VK_FORCE_HEAP_INDEX); + } else { + device->force_heap_index = UINT32_MAX; + } + device->vendor_id = device->properties.vendorID; device->subgroup_size = subgroup_props.subgroupSize; device->uma = device->properties.deviceType == vk::PhysicalDeviceType::eIntegratedGpu;