You cannot select more than 25 topics.
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
97 lines
3.7 KiB
97 lines
3.7 KiB
|
|
use wgpu::{Instance, Backends, Adapter, DeviceType, Features, Limits, DeviceDescriptor, Device, Queue, ShaderModuleDescriptor, ShaderSource, ShaderModule, BufferDescriptor, util::{BufferInitDescriptor, DeviceExt}, BufferUsages, Buffer};
|
|
use image::{ImageBuffer, Rgb};
|
|
use std::task::Poll;
|
|
use std::future::Future;
|
|
|
|
/// Layout description of an image, uploaded byte-for-byte as the contents of a
/// uniform buffer.
///
/// `#[repr(C)]` fixes the field order, and because every field is a `u32` the struct
/// contains no padding: its bytes are exactly the five values in declaration order.
/// NOTE(review): this order must match the shader's declaration of the uniform —
/// confirm against the shader source.
///
/// The stride values are copied from the `image` crate's `SampleLayout` of the goal
/// image.
#[repr(C)]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
struct ImageFormat {
    /// `SampleLayout::channel_stride`: offset between consecutive channels of a pixel.
    channel_stride: u32,
    /// Image width in pixels.
    width: u32,
    /// `SampleLayout::width_stride`: offset between horizontally adjacent pixels.
    width_stride: u32,
    /// Image height in pixels.
    height: u32,
    /// `SampleLayout::height_stride`: offset between vertically adjacent pixels.
    height_stride: u32,
}
|
|
|
|
pub struct GpuContext {
|
|
device: Device,
|
|
queue: Queue,
|
|
shader_module: ShaderModule,
|
|
goal_image_buffer: Buffer,
|
|
image_format_buffer: Buffer,
|
|
scratch_image_buffers: (Buffer, Buffer),
|
|
}
|
|
|
|
pub fn init(goal_image: &ImageBuffer<Rgb<f32>, Vec<f32>>) -> Option<GpuContext> {
|
|
let inst = Instance::new(Backends::all());
|
|
|
|
// Look for devices with the features we want
|
|
let required_features = Features::SHADER_FLOAT64 | Features::CLEAR_COMMANDS;
|
|
let adapters: Vec<_> = inst.enumerate_adapters(Backends::all()).filter(|a| a.features().contains(required_features)).collect();
|
|
if adapters.is_empty() {
|
|
return None;
|
|
}
|
|
|
|
// Pick the best type of device available
|
|
let mut target_device_type = DeviceType::Other;
|
|
if adapters.iter().any(|a| a.get_info().device_type == DeviceType::Cpu) {
|
|
target_device_type = DeviceType::Cpu;
|
|
}
|
|
if adapters.iter().any(|a| a.get_info().device_type == DeviceType::VirtualGpu) {
|
|
target_device_type = DeviceType::VirtualGpu;
|
|
}
|
|
if adapters.iter().any(|a| a.get_info().device_type == DeviceType::IntegratedGpu) {
|
|
target_device_type = DeviceType::IntegratedGpu;
|
|
}
|
|
if adapters.iter().any(|a| a.get_info().device_type == DeviceType::DiscreteGpu) {
|
|
target_device_type = DeviceType::DiscreteGpu;
|
|
}
|
|
println!("available adapters:");
|
|
for (i, adapter) in adapters.iter().enumerate() {
|
|
let info = adapter.get_info();
|
|
println!("{}: {}", i, info.name);
|
|
}
|
|
|
|
let (i, adapter) = adapters.into_iter().enumerate().filter(|(_, a)| a.get_info().device_type == target_device_type).next().unwrap();
|
|
println!("picking {}: {}", i, adapter.get_info().name);
|
|
|
|
let limits = adapter.limits();
|
|
|
|
let device_descriptor = DeviceDescriptor {
|
|
label: None,
|
|
features: required_features,
|
|
limits: limits.clone(),
|
|
};
|
|
let (device, queue) = pollster::block_on(adapter.request_device(&device_descriptor, None)).ok()?;
|
|
|
|
let sample_layout = goal_image.sample_layout();
|
|
let image_format = ImageFormat {
|
|
channel_stride: sample_layout.channel_stride as u32,
|
|
width: sample_layout.width,
|
|
width_stride: sample_layout.width_stride as u32,
|
|
height: sample_layout.height,
|
|
height_stride: sample_layout.height_stride as u32,
|
|
};
|
|
|
|
let image_format_buffer = device.create_buffer_init(&BufferInitDescriptor { label: None, contents: unsafe { std::slice::from_raw_parts(&image_format as *const ImageFormat as *const u8, std::mem::size_of::<ImageFormat>()) }, usage: BufferUsages::UNIFORM });
|
|
|
|
let goal_image_storage = &*goal_image.as_raw();
|
|
let goal_image_buffer = device.create_buffer_init(&BufferInitDescriptor { label: None, contents: bytemuck::cast_slice(goal_image_storage), usage: BufferUsages::UNIFORM | BufferUsages::STORAGE });
|
|
|
|
let scratch_image_buffer_gen = || device.create_buffer(&BufferDescriptor { label: None, size: (std::mem::size_of::<f32>() as u64 * image_format.width as u64 * image_format.height as u64), usage: BufferUsages::COPY_DST | BufferUsages::STORAGE | BufferUsages::MAP_READ, mapped_at_creation: false });
|
|
let scratch_image_buffers = (scratch_image_buffer_gen(), scratch_image_buffer_gen());
|
|
|
|
let shader_module = device.create_shader_module(&wgpu::include_spirv!("shader.spv"));
|
|
|
|
let context = GpuContext {
|
|
device,
|
|
queue,
|
|
shader_module,
|
|
goal_image_buffer,
|
|
image_format_buffer,
|
|
scratch_image_buffers,
|
|
};
|
|
|
|
Some(context)
|
|
}
|
|
|