Metal: Improve startup times by using concurrent shader compilation APIs

This commit is contained in:
Stuart Carnie 2024-08-25 08:04:02 +10:00
parent e3550cb20f
commit 2ef1ef63a5
No known key found for this signature in database
GPG key ID: 848D9C9718D78B4F
5 changed files with 403 additions and 30 deletions

View file

@ -57,10 +57,12 @@
#import "servers/rendering/rendering_device_driver.h"
#import <CommonCrypto/CommonDigest.h>
#import <Foundation/Foundation.h>
#import <Metal/Metal.h>
#import <QuartzCore/CAMetalLayer.h>
#import <simd/simd.h>
#import <zlib.h>
#import <initializer_list>
#import <optional>
#import <spirv.hpp>
@ -497,6 +499,76 @@ struct API_AVAILABLE(macos(11.0), ios(14.0)) UniformSet {
HashMap<RDC::ShaderStage, id<MTLArgumentEncoder>> encoders;
};
struct ShaderCacheEntry;
enum class ShaderLoadStrategy {
DEFAULT,
LAZY,
};
/**
* A Metal shader library.
*/
@interface MDLibrary : NSObject
- (id<MTLLibrary>)library;
- (NSError *)error;
- (void)setLabel:(NSString *)label;
+ (instancetype)newLibraryWithCacheEntry:(ShaderCacheEntry *)entry
device:(id<MTLDevice>)device
source:(NSString *)source
options:(MTLCompileOptions *)options
strategy:(ShaderLoadStrategy)strategy;
@end
struct SHA256Digest {
unsigned char data[CC_SHA256_DIGEST_LENGTH];
uint32_t hash() const {
uint32_t c = crc32(0, data, CC_SHA256_DIGEST_LENGTH);
return c;
}
SHA256Digest() {
bzero(data, CC_SHA256_DIGEST_LENGTH);
}
SHA256Digest(const char *p_data, size_t p_length) {
CC_SHA256(p_data, (CC_LONG)p_length, data);
}
};
template <>
struct HashMapComparatorDefault<SHA256Digest> {
static bool compare(const SHA256Digest &p_lhs, const SHA256Digest &p_rhs) {
return memcmp(p_lhs.data, p_rhs.data, CC_SHA256_DIGEST_LENGTH) == 0;
}
};
/**
* A cache entry for a Metal shader library.
*/
struct ShaderCacheEntry {
RenderingDeviceDriverMetal &owner;
SHA256Digest key;
CharString name;
CharString short_sha;
RD::ShaderStage stage = RD::SHADER_STAGE_VERTEX;
/**
* This reference must be weak, to ensure that when the last strong reference to the library
* is released, the cache entry is freed.
*/
MDLibrary *__weak library = nil;
/** Notify the cache that this entry is no longer needed. */
void notify_free() const;
ShaderCacheEntry(RenderingDeviceDriverMetal &p_owner, SHA256Digest p_key) :
owner(p_owner), key(p_key) {
}
~ShaderCacheEntry() = default;
};
class API_AVAILABLE(macos(11.0), ios(14.0)) MDShader {
public:
CharString name;
@ -517,15 +589,14 @@ public:
} push_constants;
MTLSize local = {};
id<MTLLibrary> kernel;
MDLibrary *kernel;
#if DEV_ENABLED
CharString kernel_source;
#endif
void encode_push_constant_data(VectorView<uint32_t> p_data, MDCommandBuffer *p_cb) final;
MDComputeShader(CharString p_name, Vector<UniformSet> p_sets, id<MTLLibrary> p_kernel);
~MDComputeShader() override = default;
MDComputeShader(CharString p_name, Vector<UniformSet> p_sets, MDLibrary *p_kernel);
};
class API_AVAILABLE(macos(11.0), ios(14.0)) MDRenderShader final : public MDShader {
@ -541,8 +612,8 @@ public:
} frag;
} push_constants;
id<MTLLibrary> vert;
id<MTLLibrary> frag;
MDLibrary *vert;
MDLibrary *frag;
#if DEV_ENABLED
CharString vert_source;
CharString frag_source;
@ -550,8 +621,7 @@ public:
void encode_push_constant_data(VectorView<uint32_t> p_data, MDCommandBuffer *p_cb) final;
MDRenderShader(CharString p_name, Vector<UniformSet> p_sets, id<MTLLibrary> p_vert, id<MTLLibrary> p_frag);
~MDRenderShader() override = default;
MDRenderShader(CharString p_name, Vector<UniformSet> p_sets, MDLibrary *p_vert, MDLibrary *p_frag);
};
enum StageResourceUsage : uint32_t {

View file

@ -50,9 +50,12 @@
#import "metal_objects.h"
#import "metal_utils.h"
#import "pixel_formats.h"
#import "rendering_device_driver_metal.h"
#import <os/signpost.h>
void MDCommandBuffer::begin() {
DEV_ASSERT(commandBuffer == nil);
commandBuffer = queue.commandBuffer;
@ -850,7 +853,7 @@ void MDCommandBuffer::_end_blit() {
type = MDCommandBufferStateType::None;
}
MDComputeShader::MDComputeShader(CharString p_name, Vector<UniformSet> p_sets, id<MTLLibrary> p_kernel) :
MDComputeShader::MDComputeShader(CharString p_name, Vector<UniformSet> p_sets, MDLibrary *p_kernel) :
MDShader(p_name, p_sets), kernel(p_kernel) {
}
@ -868,7 +871,7 @@ void MDComputeShader::encode_push_constant_data(VectorView<uint32_t> p_data, MDC
[enc setBytes:ptr length:length atIndex:push_constants.binding];
}
MDRenderShader::MDRenderShader(CharString p_name, Vector<UniformSet> p_sets, id<MTLLibrary> _Nonnull p_vert, id<MTLLibrary> _Nonnull p_frag) :
MDRenderShader::MDRenderShader(CharString p_name, Vector<UniformSet> p_sets, MDLibrary *_Nonnull p_vert, MDLibrary *_Nonnull p_frag) :
MDShader(p_name, p_sets), vert(p_vert), frag(p_frag) {
}
@ -1378,3 +1381,204 @@ id<MTLDepthStencilState> MDResourceCache::get_depth_stencil_state(bool p_use_dep
}
return *val;
}
static const char *SHADER_STAGE_NAMES[] = {
[RD::SHADER_STAGE_VERTEX] = "vert",
[RD::SHADER_STAGE_FRAGMENT] = "frag",
[RD::SHADER_STAGE_TESSELATION_CONTROL] = "tess_ctrl",
[RD::SHADER_STAGE_TESSELATION_EVALUATION] = "tess_eval",
[RD::SHADER_STAGE_COMPUTE] = "comp",
};
void ShaderCacheEntry::notify_free() const {
owner.shader_cache_free_entry(key);
}
@interface MDLibrary ()
- (instancetype)initWithCacheEntry:(ShaderCacheEntry *)entry;
- (ShaderCacheEntry *)entry;
@end
@interface MDLazyLibrary : MDLibrary {
id<MTLLibrary> _library;
NSError *_error;
std::shared_mutex _mu;
bool _loaded;
id<MTLDevice> _device;
NSString *_source;
MTLCompileOptions *_options;
}
- (instancetype)initWithCacheEntry:(ShaderCacheEntry *)entry
device:(id<MTLDevice>)device
source:(NSString *)source
options:(MTLCompileOptions *)options;
@end
@interface MDImmediateLibrary : MDLibrary {
id<MTLLibrary> _library;
NSError *_error;
std::mutex _cv_mutex;
std::condition_variable _cv;
std::atomic<bool> _complete;
bool _ready;
}
- (instancetype)initWithCacheEntry:(ShaderCacheEntry *)entry
device:(id<MTLDevice>)device
source:(NSString *)source
options:(MTLCompileOptions *)options;
@end
@implementation MDLibrary {
ShaderCacheEntry *_entry;
}
+ (instancetype)newLibraryWithCacheEntry:(ShaderCacheEntry *)entry
device:(id<MTLDevice>)device
source:(NSString *)source
options:(MTLCompileOptions *)options
strategy:(ShaderLoadStrategy)strategy {
switch (strategy) {
case ShaderLoadStrategy::DEFAULT:
[[fallthrough]];
default:
return [[MDImmediateLibrary alloc] initWithCacheEntry:entry device:device source:source options:options];
case ShaderLoadStrategy::LAZY:
return [[MDLazyLibrary alloc] initWithCacheEntry:entry device:device source:source options:options];
}
}
- (ShaderCacheEntry *)entry {
return _entry;
}
- (id<MTLLibrary>)library {
CRASH_NOW_MSG("Not implemented");
return nil;
}
- (NSError *)error {
CRASH_NOW_MSG("Not implemented");
return nil;
}
- (void)setLabel:(NSString *)label {
}
- (instancetype)initWithCacheEntry:(ShaderCacheEntry *)entry {
self = [super init];
_entry = entry;
_entry->library = self;
return self;
}
- (void)dealloc {
_entry->notify_free();
}
@end
@implementation MDImmediateLibrary
- (instancetype)initWithCacheEntry:(ShaderCacheEntry *)entry
device:(id<MTLDevice>)device
source:(NSString *)source
options:(MTLCompileOptions *)options {
self = [super initWithCacheEntry:entry];
_complete = false;
_ready = false;
__block os_signpost_id_t compile_id = (os_signpost_id_t)(uintptr_t)self;
os_signpost_interval_begin(LOG_INTERVALS, compile_id, "shader_compile",
"shader_name=%{public}s stage=%{public}s hash=%{public}s",
entry->name.get_data(), SHADER_STAGE_NAMES[entry->stage], entry->short_sha.get_data());
[device newLibraryWithSource:source
options:options
completionHandler:^(id<MTLLibrary> library, NSError *error) {
os_signpost_interval_end(LOG_INTERVALS, compile_id, "shader_compile");
self->_library = library;
self->_error = error;
if (error) {
ERR_PRINT(String(U"Error compiling shader %s: %s").format(entry->name.get_data(), error.localizedDescription.UTF8String));
}
{
std::lock_guard<std::mutex> lock(self->_cv_mutex);
_ready = true;
}
_cv.notify_all();
_complete = true;
}];
return self;
}
- (id<MTLLibrary>)library {
if (!_complete) {
std::unique_lock<std::mutex> lock(_cv_mutex);
_cv.wait(lock, [&] { return _ready; });
}
return _library;
}
- (NSError *)error {
if (!_complete) {
std::unique_lock<std::mutex> lock(_cv_mutex);
_cv.wait(lock, [&] { return _ready; });
}
return _error;
}
@end
@implementation MDLazyLibrary
- (instancetype)initWithCacheEntry:(ShaderCacheEntry *)entry
device:(id<MTLDevice>)device
source:(NSString *)source
options:(MTLCompileOptions *)options {
self = [super initWithCacheEntry:entry];
_device = device;
_source = source;
_options = options;
return self;
}
- (void)load {
{
std::shared_lock<std::shared_mutex> lock(_mu);
if (_loaded) {
return;
}
}
std::unique_lock<std::shared_mutex> lock(_mu);
if (_loaded) {
return;
}
ShaderCacheEntry *entry = [self entry];
__block os_signpost_id_t compile_id = (os_signpost_id_t)(uintptr_t)self;
os_signpost_interval_begin(LOG_INTERVALS, compile_id, "shader_compile",
"shader_name=%{public}s stage=%{public}s hash=%{public}s",
entry->name.get_data(), SHADER_STAGE_NAMES[entry->stage], entry->short_sha.get_data());
NSError *error;
_library = [_device newLibraryWithSource:_source options:_options error:&error];
os_signpost_interval_end(LOG_INTERVALS, compile_id, "shader_compile");
_device = nil;
_source = nil;
_options = nil;
_loaded = true;
}
- (id<MTLLibrary>)library {
[self load];
return _library;
}
- (NSError *)error {
[self load];
return _error;
}
@end

View file

@ -31,6 +31,8 @@
#ifndef METAL_UTILS_H
#define METAL_UTILS_H
#import <os/log.h>
#pragma mark - Boolean flags
namespace flags {
@ -78,4 +80,22 @@ static constexpr uint64_t round_up_to_alignment(uint64_t p_value, uint64_t p_ali
return aligned_value;
}
class Defer {
public:
Defer(std::function<void()> func) :
func_(func) {}
~Defer() { func_(); }
private:
std::function<void()> func_;
};
#define CONCAT_INTERNAL(x, y) x##y
#define CONCAT(x, y) CONCAT_INTERNAL(x, y)
#define DEFER const Defer &CONCAT(defer__, __LINE__) = Defer
extern os_log_t LOG_DRIVER;
// Used for dynamic tracing.
extern os_log_t LOG_INTERVALS;
#endif // METAL_UTILS_H

View file

@ -48,6 +48,8 @@
class RenderingContextDriverMetal;
class API_AVAILABLE(macos(11.0), ios(14.0)) RenderingDeviceDriverMetal : public RenderingDeviceDriver {
friend struct ShaderCacheEntry;
template <typename T>
using Result = std::variant<T, Error>;
@ -77,6 +79,19 @@ class API_AVAILABLE(macos(11.0), ios(14.0)) RenderingDeviceDriverMetal : public
Error _create_device();
Error _check_capabilities();
#pragma mark - Shader Cache
ShaderLoadStrategy _shader_load_strategy = ShaderLoadStrategy::DEFAULT;
/**
* The shader cache is a map of hashes of the Metal source to shader cache entries.
*
* To prevent unbounded growth of the cache, cache entries are automatically freed when
* there are no more references to the MDLibrary associated with the cache entry.
*/
HashMap<SHA256Digest, ShaderCacheEntry *, HashableHasher<SHA256Digest>> _shader_cache;
void shader_cache_free_entry(const SHA256Digest &key);
public:
Error initialize(uint32_t p_device_index, uint32_t p_frame_count) override final;
@ -270,7 +285,7 @@ public:
#pragma mark Pipeline
private:
Result<id<MTLFunction>> _create_function(id<MTLLibrary> p_library, NSString *p_name, VectorView<PipelineSpecializationConstant> &p_specialization_constants);
Result<id<MTLFunction>> _create_function(MDLibrary *p_library, NSString *p_name, VectorView<PipelineSpecializationConstant> &p_specialization_constants);
public:
virtual void pipeline_free(PipelineID p_pipeline_id) override final;

View file

@ -60,9 +60,22 @@
#import <Metal/MTLTexture.h>
#import <Metal/Metal.h>
#import <os/log.h>
#import <os/signpost.h>
#import <spirv_msl.hpp>
#import <spirv_parser.hpp>
#pragma mark - Logging
os_log_t LOG_DRIVER;
// Used for dynamic tracing.
os_log_t LOG_INTERVALS;
__attribute__((constructor)) static void InitializeLogging(void) {
LOG_DRIVER = os_log_create("org.stuartcarnie.godot.metal", OS_LOG_CATEGORY_POINTS_OF_INTEREST);
LOG_INTERVALS = os_log_create("org.stuartcarnie.godot.metal", "events");
}
/*****************/
/**** GENERIC ****/
/*****************/
@ -2258,6 +2271,15 @@ Vector<uint8_t> RenderingDeviceDriverMetal::shader_compile_binary_from_spirv(Vec
return ret;
}
void RenderingDeviceDriverMetal::shader_cache_free_entry(const SHA256Digest &key) {
if (ShaderCacheEntry **pentry = _shader_cache.getptr(key); pentry != nullptr) {
ShaderCacheEntry *entry = *pentry;
_shader_cache.erase(key);
entry->library = nil;
memdelete(entry);
}
}
RDD::ShaderID RenderingDeviceDriverMetal::shader_create_from_bytecode(const Vector<uint8_t> &p_shader_binary, ShaderDescription &r_shader_desc, String &r_name) {
r_shader_desc = {}; // Driver-agnostic.
@ -2285,18 +2307,32 @@ RDD::ShaderID RenderingDeviceDriverMetal::shader_create_from_bytecode(const Vect
MTLCompileOptions *options = [MTLCompileOptions new];
options.languageVersion = binary_data.get_msl_version();
HashMap<ShaderStage, id<MTLLibrary>> libraries;
HashMap<ShaderStage, MDLibrary *> libraries;
for (ShaderStageData &shader_data : binary_data.stages) {
NSString *source = [[NSString alloc] initWithBytesNoCopy:(void *)shader_data.source.ptr()
length:shader_data.source.length()
encoding:NSUTF8StringEncoding
freeWhenDone:NO];
NSError *error = nil;
id<MTLLibrary> library = [device newLibraryWithSource:source options:options error:&error];
if (error != nil) {
print_error(error.localizedDescription.UTF8String);
ERR_FAIL_V_MSG(ShaderID(), "failed to compile Metal source");
SHA256Digest key = SHA256Digest(shader_data.source.ptr(), shader_data.source.length());
if (ShaderCacheEntry **p = _shader_cache.getptr(key); p != nullptr) {
libraries[shader_data.stage] = (*p)->library;
continue;
}
NSString *source = [[NSString alloc] initWithBytes:(void *)shader_data.source.ptr()
length:shader_data.source.length()
encoding:NSUTF8StringEncoding];
ShaderCacheEntry *cd = memnew(ShaderCacheEntry(*this, key));
cd->name = binary_data.shader_name;
String sha_hex = String::hex_encode_buffer(key.data, CC_SHA256_DIGEST_LENGTH);
cd->short_sha = sha_hex.substr(0, 8).utf8();
cd->stage = shader_data.stage;
MDLibrary *library = [MDLibrary newLibraryWithCacheEntry:cd
device:device
source:source
options:options
strategy:_shader_load_strategy];
_shader_cache[key] = cd;
libraries[shader_data.stage] = library;
}
@ -3062,8 +3098,13 @@ void RenderingDeviceDriverMetal::command_render_set_line_width(CommandBufferID p
// ----- PIPELINE -----
RenderingDeviceDriverMetal::Result<id<MTLFunction>> RenderingDeviceDriverMetal::_create_function(id<MTLLibrary> p_library, NSString *p_name, VectorView<PipelineSpecializationConstant> &p_specialization_constants) {
id<MTLFunction> function = [p_library newFunctionWithName:p_name];
RenderingDeviceDriverMetal::Result<id<MTLFunction>> RenderingDeviceDriverMetal::_create_function(MDLibrary *p_library, NSString *p_name, VectorView<PipelineSpecializationConstant> &p_specialization_constants) {
id<MTLLibrary> library = p_library.library;
if (!library) {
ERR_FAIL_V_MSG(ERR_CANT_CREATE, "Failed to compile Metal library");
}
id<MTLFunction> function = [library newFunctionWithName:p_name];
ERR_FAIL_NULL_V_MSG(function, ERR_CANT_CREATE, "No function named main0");
if (function.functionConstantsDictionary.count == 0) {
@ -3141,9 +3182,9 @@ RenderingDeviceDriverMetal::Result<id<MTLFunction>> RenderingDeviceDriverMetal::
}
NSError *err = nil;
function = [p_library newFunctionWithName:@"main0"
constantValues:constantValues
error:&err];
function = [library newFunctionWithName:@"main0"
constantValues:constantValues
error:&err];
ERR_FAIL_NULL_V_MSG(function, ERR_CANT_CREATE, String("specialized function failed: ") + err.localizedDescription.UTF8String);
return function;
@ -3188,6 +3229,14 @@ RDD::PipelineID RenderingDeviceDriverMetal::render_pipeline_create(
MTLVertexDescriptor *vert_desc = rid::get(p_vertex_format);
MDRenderPass *pass = (MDRenderPass *)(p_render_pass.id);
os_signpost_id_t reflect_id = os_signpost_id_make_with_pointer(LOG_INTERVALS, shader);
os_signpost_interval_begin(LOG_INTERVALS, reflect_id, "render_pipeline_create", "shader_name=%{public}s", shader->name.get_data());
DEFER([=]() {
os_signpost_interval_end(LOG_INTERVALS, reflect_id, "render_pipeline_create");
});
os_signpost_event_emit(LOG_DRIVER, OS_SIGNPOST_ID_EXCLUSIVE, "create_pipeline");
MTLRenderPipelineDescriptor *desc = [MTLRenderPipelineDescriptor new];
{
@ -3482,9 +3531,15 @@ void RenderingDeviceDriverMetal::command_compute_dispatch_indirect(CommandBuffer
RDD::PipelineID RenderingDeviceDriverMetal::compute_pipeline_create(ShaderID p_shader, VectorView<PipelineSpecializationConstant> p_specialization_constants) {
MDComputeShader *shader = (MDComputeShader *)(p_shader.id);
id<MTLLibrary> library = shader->kernel;
os_signpost_id_t reflect_id = os_signpost_id_make_with_pointer(LOG_INTERVALS, shader);
os_signpost_interval_begin(LOG_INTERVALS, reflect_id, "compute_pipeline_create", "shader_name=%{public}s", shader->name.get_data());
DEFER([=]() {
os_signpost_interval_end(LOG_INTERVALS, reflect_id, "compute_pipeline_create");
});
Result<id<MTLFunction>> function_or_err = _create_function(library, @"main0", p_specialization_constants);
os_signpost_event_emit(LOG_DRIVER, OS_SIGNPOST_ID_EXCLUSIVE, "create_pipeline");
Result<id<MTLFunction>> function_or_err = _create_function(shader->kernel, @"main0", p_specialization_constants);
ERR_FAIL_COND_V(std::holds_alternative<Error>(function_or_err), PipelineID());
id<MTLFunction> function = std::get<id<MTLFunction>>(function_or_err);
@ -3585,12 +3640,13 @@ void RenderingDeviceDriverMetal::set_object_name(ObjectType p_type, ID p_driver_
buffer.label = [NSString stringWithUTF8String:p_name.utf8().get_data()];
} break;
case OBJECT_TYPE_SHADER: {
NSString *label = [NSString stringWithUTF8String:p_name.utf8().get_data()];
MDShader *shader = (MDShader *)(p_driver_id.id);
if (MDRenderShader *rs = dynamic_cast<MDRenderShader *>(shader); rs != nullptr) {
rs->vert.label = [NSString stringWithUTF8String:p_name.utf8().get_data()];
rs->frag.label = [NSString stringWithUTF8String:p_name.utf8().get_data()];
[rs->vert setLabel:label];
[rs->frag setLabel:label];
} else if (MDComputeShader *cs = dynamic_cast<MDComputeShader *>(shader); cs != nullptr) {
cs->kernel.label = [NSString stringWithUTF8String:p_name.utf8().get_data()];
[cs->kernel setLabel:label];
} else {
DEV_ASSERT(false);
}
@ -3830,12 +3886,20 @@ size_t RenderingDeviceDriverMetal::get_texel_buffer_alignment_for_format(MTLPixe
RenderingDeviceDriverMetal::RenderingDeviceDriverMetal(RenderingContextDriverMetal *p_context_driver) :
context_driver(p_context_driver) {
DEV_ASSERT(p_context_driver != nullptr);
if (String res = OS::get_singleton()->get_environment("GODOT_MTL_SHADER_LOAD_STRATEGY"); res == U"lazy") {
_shader_load_strategy = ShaderLoadStrategy::LAZY;
}
}
RenderingDeviceDriverMetal::~RenderingDeviceDriverMetal() {
for (MDCommandBuffer *cb : command_buffers) {
delete cb;
}
for (KeyValue<SHA256Digest, ShaderCacheEntry *> &kv : _shader_cache) {
memdelete(kv.value);
}
}
#pragma mark - Initialization