You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
 
 

97 lines
3.7 KiB

use wgpu::{Instance, Backends, Adapter, DeviceType, Features, Limits, DeviceDescriptor, Device, Queue, ShaderModuleDescriptor, ShaderSource, ShaderModule, BufferDescriptor, util::{BufferInitDescriptor, DeviceExt}, BufferUsages, Buffer};
use image::{ImageBuffer, Rgb};
use std::task::Poll;
use std::future::Future;
/// Layout description of the goal image, uploaded verbatim as a uniform buffer.
///
/// `#[repr(C)]` fixes the field order and layout because `init` copies this
/// struct's raw bytes to the GPU; do not reorder or retype the fields. The stride
/// values come from the `image` crate's `SampleLayout` (presumably measured in
/// samples rather than bytes — confirm against the shader).
#[repr(C)]
struct ImageFormat {
    /// Distance between consecutive channels of one pixel.
    channel_stride: u32,
    /// Image width in pixels.
    width: u32,
    /// Distance between horizontally adjacent pixels.
    width_stride: u32,
    /// Image height in pixels.
    height: u32,
    /// Distance between vertically adjacent pixels.
    height_stride: u32,
}
/// GPU resources produced by [`init`] for working on a single goal image.
pub struct GpuContext {
    /// Logical device requested from the selected adapter.
    device: Device,
    /// Command queue returned together with `device`.
    queue: Queue,
    /// Shader module built from the bundled `shader.spv`.
    shader_module: ShaderModule,
    /// Raw `f32` samples of the goal image.
    goal_image_buffer: Buffer,
    /// Uniform buffer holding the goal image's `ImageFormat` bytes.
    image_format_buffer: Buffer,
    /// Two scratch buffers created with identical size and usage flags.
    scratch_image_buffers: (Buffer, Buffer),
}
pub fn init(goal_image: &ImageBuffer<Rgb<f32>, Vec<f32>>) -> Option<GpuContext> {
let inst = Instance::new(Backends::all());
// Look for devices with the features we want
let required_features = Features::SHADER_FLOAT64 | Features::CLEAR_COMMANDS;
let adapters: Vec<_> = inst.enumerate_adapters(Backends::all()).filter(|a| a.features().contains(required_features)).collect();
if adapters.is_empty() {
return None;
}
// Pick the best type of device available
let mut target_device_type = DeviceType::Other;
if adapters.iter().any(|a| a.get_info().device_type == DeviceType::Cpu) {
target_device_type = DeviceType::Cpu;
}
if adapters.iter().any(|a| a.get_info().device_type == DeviceType::VirtualGpu) {
target_device_type = DeviceType::VirtualGpu;
}
if adapters.iter().any(|a| a.get_info().device_type == DeviceType::IntegratedGpu) {
target_device_type = DeviceType::IntegratedGpu;
}
if adapters.iter().any(|a| a.get_info().device_type == DeviceType::DiscreteGpu) {
target_device_type = DeviceType::DiscreteGpu;
}
println!("available adapters:");
for (i, adapter) in adapters.iter().enumerate() {
let info = adapter.get_info();
println!("{}: {}", i, info.name);
}
let (i, adapter) = adapters.into_iter().enumerate().filter(|(_, a)| a.get_info().device_type == target_device_type).next().unwrap();
println!("picking {}: {}", i, adapter.get_info().name);
let limits = adapter.limits();
let device_descriptor = DeviceDescriptor {
label: None,
features: required_features,
limits: limits.clone(),
};
let (device, queue) = pollster::block_on(adapter.request_device(&device_descriptor, None)).ok()?;
let sample_layout = goal_image.sample_layout();
let image_format = ImageFormat {
channel_stride: sample_layout.channel_stride as u32,
width: sample_layout.width,
width_stride: sample_layout.width_stride as u32,
height: sample_layout.height,
height_stride: sample_layout.height_stride as u32,
};
let image_format_buffer = device.create_buffer_init(&BufferInitDescriptor { label: None, contents: unsafe { std::slice::from_raw_parts(&image_format as *const ImageFormat as *const u8, std::mem::size_of::<ImageFormat>()) }, usage: BufferUsages::UNIFORM });
let goal_image_storage = &*goal_image.as_raw();
let goal_image_buffer = device.create_buffer_init(&BufferInitDescriptor { label: None, contents: bytemuck::cast_slice(goal_image_storage), usage: BufferUsages::UNIFORM | BufferUsages::STORAGE });
let scratch_image_buffer_gen = || device.create_buffer(&BufferDescriptor { label: None, size: (std::mem::size_of::<f32>() as u64 * image_format.width as u64 * image_format.height as u64), usage: BufferUsages::COPY_DST | BufferUsages::STORAGE | BufferUsages::MAP_READ, mapped_at_creation: false });
let scratch_image_buffers = (scratch_image_buffer_gen(), scratch_image_buffer_gen());
let shader_module = device.create_shader_module(&wgpu::include_spirv!("shader.spv"));
let context = GpuContext {
device,
queue,
shader_module,
goal_image_buffer,
image_format_buffer,
scratch_image_buffers,
};
Some(context)
}