Merge pull request #81659 from DarioSamo/nlm-denoiser

Replace OIDN denoiser in Lightmapper with a JNLM denoiser compute shader.
Yuri Sizov 2023-09-27 19:08:01 +02:00
commit aa82cccc41
332 changed files with 304 additions and 114499 deletions


@ -63,6 +63,13 @@ Copyright: 2011, Ole Kniemeyer, MAXON, www.maxon.net
2007-2014, Juan Linietsky, Ariel Manzur
License: Expat and Zlib
Files: ./modules/lightmapper_rd/lm_compute.glsl
Comment: Joint Non-Local Means (JNLM) denoiser
Copyright: 2020, Manuel Prandini
2014-present, Godot Engine contributors
2007-2014, Juan Linietsky, Ariel Manzur
License: Expat
Files: ./platform/android/java/lib/aidl/com/android/*
./platform/android/java/lib/res/layout/status_bar_ongoing_event_progress_bar.xml
./platform/android/java/lib/src/com/google/android/*
@ -428,11 +435,6 @@ Comment: Stripped down version of "nvapi.h" from the NVIDIA NVAPI SDK
Copyright: 2019-2022, NVIDIA Corporation
License: Expat
Files: ./thirdparty/oidn/
Comment: Intel Open Image Denoise
Copyright: 2009-2019, Intel Corporation
License: Apache-2.0
Files: ./thirdparty/openxr/
Comment: OpenXR Loader
Copyright: 2020-2023, The Khronos Group Inc.


@ -24,6 +24,9 @@
<member name="camera_attributes" type="CameraAttributes" setter="set_camera_attributes" getter="get_camera_attributes">
The [CameraAttributes] resource that specifies exposure levels to bake at. Auto-exposure and non-exposure properties will be ignored. Exposure settings should be used to reduce the dynamic range present when baking. If exposure is too high, the [LightmapGI] will have banding artifacts or may have over-exposure artifacts.
</member>
<member name="denoiser_strength" type="float" setter="set_denoiser_strength" getter="get_denoiser_strength" default="0.1">
The strength of the denoising step applied to the generated lightmaps. Only effective if [member use_denoiser] is [code]true[/code].
</member>
<member name="directional" type="bool" setter="set_directional" getter="is_directional" default="false">
If [code]true[/code], bakes lightmaps to contain directional information as spherical harmonics. This results in more realistic lighting appearance, especially with normal mapped materials and for lights that have their direct light baked ([member Light3D.light_bake_mode] set to [constant Light3D.BAKE_STATIC]). The directional information is also used to provide rough reflections for static and dynamic objects. This has a small run-time performance cost as the shader has to perform more work to interpret the direction information from the lightmap. Directional lightmaps also take longer to bake and result in larger file sizes.
[b]Note:[/b] The property's name has no relationship with [DirectionalLight3D]. [member directional] works with all light types.
@ -59,8 +62,7 @@
To further speed up bake times, decrease [member bounces], disable [member use_denoiser] and increase the lightmap texel size on 3D scenes in the Import dock.
</member>
<member name="use_denoiser" type="bool" setter="set_use_denoiser" getter="is_using_denoiser" default="true">
If [code]true[/code], uses a CPU-based denoising algorithm on the generated lightmap. This eliminates most noise within the generated lightmap at the cost of longer bake times. File sizes are generally not impacted significantly by the use of a denoiser, although lossless compression may do a better job at compressing a denoised image.
[b]Note:[/b] The built-in denoiser (OpenImageDenoise) may crash when denoising lightmaps in large scenes. If you encounter a crash at the end of lightmap baking, try disabling [member use_denoiser].
If [code]true[/code], uses a GPU-based denoising algorithm on the generated lightmap. This eliminates most noise within the generated lightmap at the cost of longer bake times. File sizes are generally not impacted significantly by the use of a denoiser, although lossless compression may do a better job at compressing a denoised image.
</member>
</members>
<constants>


@ -1,138 +0,0 @@
#!/usr/bin/env python
import resource_to_cpp
Import("env")
Import("env_modules")
env_oidn = env_modules.Clone()
# Thirdparty source files
thirdparty_obj = []
thirdparty_dir = "#thirdparty/oidn/"
thirdparty_sources = [
"core/api.cpp",
"core/device.cpp",
"core/filter.cpp",
"core/network.cpp",
"core/autoencoder.cpp",
"core/transfer_function.cpp",
"weights/rtlightmap_hdr.gen.cpp",
"mkl-dnn/src/common/batch_normalization.cpp",
"mkl-dnn/src/common/concat.cpp",
"mkl-dnn/src/common/convolution.cpp",
"mkl-dnn/src/common/convolution_pd.cpp",
"mkl-dnn/src/common/deconvolution.cpp",
"mkl-dnn/src/common/eltwise.cpp",
"mkl-dnn/src/common/engine.cpp",
"mkl-dnn/src/common/inner_product.cpp",
"mkl-dnn/src/common/inner_product_pd.cpp",
"mkl-dnn/src/common/lrn.cpp",
"mkl-dnn/src/common/memory.cpp",
"mkl-dnn/src/common/memory_desc_wrapper.cpp",
"mkl-dnn/src/common/mkldnn_debug.cpp",
"mkl-dnn/src/common/mkldnn_debug_autogenerated.cpp",
"mkl-dnn/src/common/pooling.cpp",
"mkl-dnn/src/common/primitive.cpp",
"mkl-dnn/src/common/primitive_attr.cpp",
"mkl-dnn/src/common/primitive_desc.cpp",
"mkl-dnn/src/common/primitive_exec_types.cpp",
"mkl-dnn/src/common/primitive_iterator.cpp",
"mkl-dnn/src/common/query.cpp",
"mkl-dnn/src/common/reorder.cpp",
"mkl-dnn/src/common/rnn.cpp",
"mkl-dnn/src/common/scratchpad.cpp",
"mkl-dnn/src/common/shuffle.cpp",
"mkl-dnn/src/common/softmax.cpp",
"mkl-dnn/src/common/stream.cpp",
"mkl-dnn/src/common/sum.cpp",
"mkl-dnn/src/common/utils.cpp",
"mkl-dnn/src/common/verbose.cpp",
"mkl-dnn/src/cpu/cpu_barrier.cpp",
"mkl-dnn/src/cpu/cpu_concat.cpp",
"mkl-dnn/src/cpu/cpu_engine.cpp",
"mkl-dnn/src/cpu/cpu_memory.cpp",
"mkl-dnn/src/cpu/cpu_reducer.cpp",
"mkl-dnn/src/cpu/cpu_reorder.cpp",
"mkl-dnn/src/cpu/cpu_sum.cpp",
"mkl-dnn/src/cpu/jit_avx2_conv_kernel_f32.cpp",
"mkl-dnn/src/cpu/jit_avx2_convolution.cpp",
"mkl-dnn/src/cpu/jit_avx512_common_conv_kernel.cpp",
"mkl-dnn/src/cpu/jit_avx512_common_conv_winograd_kernel_f32.cpp",
"mkl-dnn/src/cpu/jit_avx512_common_convolution.cpp",
"mkl-dnn/src/cpu/jit_avx512_common_convolution_winograd.cpp",
"mkl-dnn/src/cpu/jit_avx512_core_fp32_wino_conv_2x3.cpp",
"mkl-dnn/src/cpu/jit_avx512_core_fp32_wino_conv_4x3.cpp",
"mkl-dnn/src/cpu/jit_avx512_core_fp32_wino_conv_4x3_kernel.cpp",
"mkl-dnn/src/cpu/jit_sse42_conv_kernel_f32.cpp",
"mkl-dnn/src/cpu/jit_sse42_convolution.cpp",
"mkl-dnn/src/cpu/jit_transpose_src_utils.cpp",
"mkl-dnn/src/cpu/jit_uni_eltwise.cpp",
"mkl-dnn/src/cpu/jit_uni_pool_kernel_f32.cpp",
"mkl-dnn/src/cpu/jit_uni_pooling.cpp",
"mkl-dnn/src/cpu/jit_uni_reorder.cpp",
"mkl-dnn/src/cpu/jit_uni_reorder_utils.cpp",
"mkl-dnn/src/cpu/jit_utils/jit_utils.cpp",
"mkl-dnn/src/cpu/jit_utils/jitprofiling/jitprofiling.c",
"common/platform.cpp",
"common/thread.cpp",
"common/tensor.cpp",
]
thirdparty_sources = [thirdparty_dir + file for file in thirdparty_sources]
thirdparty_include_dirs = [
"",
"include",
"mkl-dnn/include",
"mkl-dnn/src",
"mkl-dnn/src/common",
"mkl-dnn/src/cpu/xbyak",
"mkl-dnn/src/cpu",
]
thirdparty_include_dirs = [thirdparty_dir + file for file in thirdparty_include_dirs]
env_oidn.Prepend(CPPPATH=thirdparty_include_dirs)
env_oidn.Append(
CPPDEFINES=[
"MKLDNN_THR=MKLDNN_THR_SEQ",
"OIDN_STATIC_LIB",
"__STDC_CONSTANT_MACROS",
"__STDC_LIMIT_MACROS",
"DISABLE_VERBOSE",
"MKLDNN_ENABLE_CONCURRENT_EXEC",
]
)
env_oidn.AppendUnique(CPPDEFINES=["NDEBUG"]) # No assert() even in debug builds.
env_thirdparty = env_oidn.Clone()
env_thirdparty.disable_warnings()
if env["disable_exceptions"]:
# OIDN hard-requires exceptions, so we re-enable them here.
if env.msvc and ("_HAS_EXCEPTIONS", 0) in env_thirdparty["CPPDEFINES"]:
env_thirdparty["CPPDEFINES"].remove(("_HAS_EXCEPTIONS", 0))
env_thirdparty.AppendUnique(CCFLAGS=["/EHsc"])
elif not env.msvc and "-fno-exceptions" in env_thirdparty["CCFLAGS"]:
env_thirdparty["CCFLAGS"].remove("-fno-exceptions")
env_thirdparty.add_source_files(thirdparty_obj, thirdparty_sources)
env.modules_sources += thirdparty_obj
weights_in_path = thirdparty_dir + "weights/rtlightmap_hdr.tza"
weights_out_path = thirdparty_dir + "weights/rtlightmap_hdr.gen.cpp"
env_thirdparty.Depends(weights_out_path, weights_in_path)
env_thirdparty.CommandNoCache(weights_out_path, weights_in_path, resource_to_cpp.tza_to_cpp)
# Godot source files
module_obj = []
env_oidn.add_source_files(module_obj, "*.cpp")
env.modules_sources += module_obj
# Needed to force rebuilding the module files when the thirdparty library is updated.
env.Depends(module_obj, thirdparty_obj)


@ -1,12 +0,0 @@
def can_build(env, platform):
# The third-party dependency Open Image Denoise includes the oneDNN library,
# and the version we use only supports x86_64.
# It's also only relevant for tools build and desktop platforms,
# as doing lightmap generation and denoising on Android or Web
# would be a bit far-fetched.
desktop_platforms = ["linuxbsd", "macos", "windows"]
return env.editor_build and platform in desktop_platforms and env["arch"] == "x86_64"
def configure(env):
pass


@ -1,66 +0,0 @@
/**************************************************************************/
/* denoise_wrapper.cpp */
/**************************************************************************/
/* This file is part of: */
/* GODOT ENGINE */
/* https://godotengine.org */
/**************************************************************************/
/* Copyright (c) 2014-present Godot Engine contributors (see AUTHORS.md). */
/* Copyright (c) 2007-2014 Juan Linietsky, Ariel Manzur. */
/* */
/* Permission is hereby granted, free of charge, to any person obtaining */
/* a copy of this software and associated documentation files (the */
/* "Software"), to deal in the Software without restriction, including */
/* without limitation the rights to use, copy, modify, merge, publish, */
/* distribute, sublicense, and/or sell copies of the Software, and to */
/* permit persons to whom the Software is furnished to do so, subject to */
/* the following conditions: */
/* */
/* The above copyright notice and this permission notice shall be */
/* included in all copies or substantial portions of the Software. */
/* */
/* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, */
/* EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF */
/* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. */
/* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY */
/* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, */
/* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE */
/* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. */
/**************************************************************************/
#include "denoise_wrapper.h"
#include <OpenImageDenoise/oidn.h>
#include <stdio.h>
void *oidn_denoiser_init() {
OIDNDeviceImpl *device = oidnNewDevice(OIDN_DEVICE_TYPE_CPU);
oidnCommitDevice(device);
return device;
}
bool oidn_denoise(void *deviceptr, float *p_floats, int p_width, int p_height) {
OIDNDeviceImpl *device = (OIDNDeviceImpl *)deviceptr;
OIDNFilter filter = oidnNewFilter(device, "RTLightmap");
oidnSetSharedFilterImage(filter, "color", (void *)p_floats, OIDN_FORMAT_FLOAT3, p_width, p_height, 0, 0, 0);
oidnSetSharedFilterImage(filter, "output", (void *)p_floats, OIDN_FORMAT_FLOAT3, p_width, p_height, 0, 0, 0);
oidnSetFilter1b(filter, "hdr", true);
//oidnSetFilter1f(filter, "hdrScale", 1.0f);
oidnCommitFilter(filter);
oidnExecuteFilter(filter);
const char *msg;
bool success = true;
if (oidnGetDeviceError(device, &msg) != OIDN_ERROR_NONE) {
printf("LightmapDenoiser: %s\n", msg);
success = false;
}
oidnReleaseFilter(filter);
return success;
}
void oidn_denoiser_finish(void *device) {
oidnReleaseDevice((OIDNDeviceImpl *)device);
}


@ -1,38 +0,0 @@
/**************************************************************************/
/* denoise_wrapper.h */
/**************************************************************************/
/* This file is part of: */
/* GODOT ENGINE */
/* https://godotengine.org */
/**************************************************************************/
/* Copyright (c) 2014-present Godot Engine contributors (see AUTHORS.md). */
/* Copyright (c) 2007-2014 Juan Linietsky, Ariel Manzur. */
/* */
/* Permission is hereby granted, free of charge, to any person obtaining */
/* a copy of this software and associated documentation files (the */
/* "Software"), to deal in the Software without restriction, including */
/* without limitation the rights to use, copy, modify, merge, publish, */
/* distribute, sublicense, and/or sell copies of the Software, and to */
/* permit persons to whom the Software is furnished to do so, subject to */
/* the following conditions: */
/* */
/* The above copyright notice and this permission notice shall be */
/* included in all copies or substantial portions of the Software. */
/* */
/* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, */
/* EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF */
/* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. */
/* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY */
/* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, */
/* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE */
/* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. */
/**************************************************************************/
#ifndef DENOISE_WRAPPER_H
#define DENOISE_WRAPPER_H
void *oidn_denoiser_init();
bool oidn_denoise(void *device, float *p_floats, int p_width, int p_height);
void oidn_denoiser_finish(void *device);
#endif // DENOISE_WRAPPER_H


@ -1,65 +0,0 @@
/**************************************************************************/
/* lightmap_denoiser.cpp */
/**************************************************************************/
/* This file is part of: */
/* GODOT ENGINE */
/* https://godotengine.org */
/**************************************************************************/
/* Copyright (c) 2014-present Godot Engine contributors (see AUTHORS.md). */
/* Copyright (c) 2007-2014 Juan Linietsky, Ariel Manzur. */
/* */
/* Permission is hereby granted, free of charge, to any person obtaining */
/* a copy of this software and associated documentation files (the */
/* "Software"), to deal in the Software without restriction, including */
/* without limitation the rights to use, copy, modify, merge, publish, */
/* distribute, sublicense, and/or sell copies of the Software, and to */
/* permit persons to whom the Software is furnished to do so, subject to */
/* the following conditions: */
/* */
/* The above copyright notice and this permission notice shall be */
/* included in all copies or substantial portions of the Software. */
/* */
/* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, */
/* EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF */
/* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. */
/* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY */
/* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, */
/* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE */
/* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. */
/**************************************************************************/
#include "lightmap_denoiser.h"
#include "denoise_wrapper.h"
#include "core/io/image.h"
LightmapDenoiser *LightmapDenoiserOIDN::create_oidn_denoiser() {
return memnew(LightmapDenoiserOIDN);
}
void LightmapDenoiserOIDN::make_default_denoiser() {
create_function = create_oidn_denoiser;
}
Ref<Image> LightmapDenoiserOIDN::denoise_image(const Ref<Image> &p_image) {
Ref<Image> img = p_image->duplicate();
img->convert(Image::FORMAT_RGBF);
Vector<uint8_t> data = img->get_data();
if (!oidn_denoise(device, (float *)data.ptrw(), img->get_width(), img->get_height())) {
return p_image;
}
img->set_data(img->get_width(), img->get_height(), false, img->get_format(), data);
return img;
}
LightmapDenoiserOIDN::LightmapDenoiserOIDN() {
device = oidn_denoiser_init();
}
LightmapDenoiserOIDN::~LightmapDenoiserOIDN() {
oidn_denoiser_finish(device);
}


@ -1,56 +0,0 @@
/**************************************************************************/
/* lightmap_denoiser.h */
/**************************************************************************/
/* This file is part of: */
/* GODOT ENGINE */
/* https://godotengine.org */
/**************************************************************************/
/* Copyright (c) 2014-present Godot Engine contributors (see AUTHORS.md). */
/* Copyright (c) 2007-2014 Juan Linietsky, Ariel Manzur. */
/* */
/* Permission is hereby granted, free of charge, to any person obtaining */
/* a copy of this software and associated documentation files (the */
/* "Software"), to deal in the Software without restriction, including */
/* without limitation the rights to use, copy, modify, merge, publish, */
/* distribute, sublicense, and/or sell copies of the Software, and to */
/* permit persons to whom the Software is furnished to do so, subject to */
/* the following conditions: */
/* */
/* The above copyright notice and this permission notice shall be */
/* included in all copies or substantial portions of the Software. */
/* */
/* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, */
/* EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF */
/* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. */
/* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY */
/* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, */
/* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE */
/* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. */
/**************************************************************************/
#ifndef LIGHTMAP_DENOISER_H
#define LIGHTMAP_DENOISER_H
#include "core/object/class_db.h"
#include "scene/3d/lightmapper.h"
struct OIDNDeviceImpl;
class LightmapDenoiserOIDN : public LightmapDenoiser {
GDCLASS(LightmapDenoiserOIDN, LightmapDenoiser);
protected:
void *device = nullptr;
public:
static LightmapDenoiser *create_oidn_denoiser();
Ref<Image> denoise_image(const Ref<Image> &p_image) override;
static void make_default_denoiser();
LightmapDenoiserOIDN();
~LightmapDenoiserOIDN();
};
#endif // LIGHTMAP_DENOISER_H


@ -1,49 +0,0 @@
/**************************************************************************/
/* register_types.cpp */
/**************************************************************************/
/* This file is part of: */
/* GODOT ENGINE */
/* https://godotengine.org */
/**************************************************************************/
/* Copyright (c) 2014-present Godot Engine contributors (see AUTHORS.md). */
/* Copyright (c) 2007-2014 Juan Linietsky, Ariel Manzur. */
/* */
/* Permission is hereby granted, free of charge, to any person obtaining */
/* a copy of this software and associated documentation files (the */
/* "Software"), to deal in the Software without restriction, including */
/* without limitation the rights to use, copy, modify, merge, publish, */
/* distribute, sublicense, and/or sell copies of the Software, and to */
/* permit persons to whom the Software is furnished to do so, subject to */
/* the following conditions: */
/* */
/* The above copyright notice and this permission notice shall be */
/* included in all copies or substantial portions of the Software. */
/* */
/* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, */
/* EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF */
/* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. */
/* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY */
/* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, */
/* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE */
/* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. */
/**************************************************************************/
#include "register_types.h"
#include "lightmap_denoiser.h"
#include "core/config/engine.h"
void initialize_denoise_module(ModuleInitializationLevel p_level) {
if (p_level != MODULE_INITIALIZATION_LEVEL_SCENE) {
return;
}
LightmapDenoiserOIDN::make_default_denoiser();
}
void uninitialize_denoise_module(ModuleInitializationLevel p_level) {
if (p_level != MODULE_INITIALIZATION_LEVEL_SCENE) {
return;
}
}


@ -1,39 +0,0 @@
/**************************************************************************/
/* register_types.h */
/**************************************************************************/
/* This file is part of: */
/* GODOT ENGINE */
/* https://godotengine.org */
/**************************************************************************/
/* Copyright (c) 2014-present Godot Engine contributors (see AUTHORS.md). */
/* Copyright (c) 2007-2014 Juan Linietsky, Ariel Manzur. */
/* */
/* Permission is hereby granted, free of charge, to any person obtaining */
/* a copy of this software and associated documentation files (the */
/* "Software"), to deal in the Software without restriction, including */
/* without limitation the rights to use, copy, modify, merge, publish, */
/* distribute, sublicense, and/or sell copies of the Software, and to */
/* permit persons to whom the Software is furnished to do so, subject to */
/* the following conditions: */
/* */
/* The above copyright notice and this permission notice shall be */
/* included in all copies or substantial portions of the Software. */
/* */
/* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, */
/* EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF */
/* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. */
/* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY */
/* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, */
/* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE */
/* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. */
/**************************************************************************/
#ifndef DENOISE_REGISTER_TYPES_H
#define DENOISE_REGISTER_TYPES_H
#include "modules/register_module_types.h"
void initialize_denoise_module(ModuleInitializationLevel p_level);
void uninitialize_denoise_module(ModuleInitializationLevel p_level);
#endif // DENOISE_REGISTER_TYPES_H


@ -1,68 +0,0 @@
#!/usr/bin/env python
## ======================================================================== ##
## Copyright 2009-2019 Intel Corporation ##
## ##
## Licensed under the Apache License, Version 2.0 (the "License"); ##
## you may not use this file except in compliance with the License. ##
## You may obtain a copy of the License at ##
## ##
## http://www.apache.org/licenses/LICENSE-2.0 ##
## ##
## Unless required by applicable law or agreed to in writing, software ##
## distributed under the License is distributed on an "AS IS" BASIS, ##
## WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ##
## See the License for the specific language governing permissions and ##
## limitations under the License. ##
## ======================================================================== ##
import os
from array import array
# Generates a C++ file from the specified binary resource file
def generate(in_path, out_path):
namespace = "oidn::weights"
scopes = namespace.split("::")
file_name = os.path.basename(in_path)
var_name = os.path.splitext(file_name)[0]
with open(in_path, "rb") as in_file, open(out_path, "w") as out_file:
# Header
out_file.write("// Generated from: %s\n" % file_name)
out_file.write("#include <cstddef>\n\n")
# Open the namespaces
for s in scopes:
out_file.write("namespace %s {\n" % s)
if scopes:
out_file.write("\n")
# Read the file
in_data = array("B", in_file.read())
# Write the size
out_file.write("//const size_t %s_size = %d;\n\n" % (var_name, len(in_data)))
# Write the data
out_file.write("unsigned char %s[] = {" % var_name)
for i in range(len(in_data)):
c = in_data[i]
if i > 0:
out_file.write(",")
if (i + 1) % 20 == 1:
out_file.write("\n")
out_file.write("%d" % c)
out_file.write("\n};\n")
# Close the namespaces
if scopes:
out_file.write("\n")
for scope in reversed(scopes):
out_file.write("} // namespace %s\n" % scope)
def tza_to_cpp(target, source, env):
for x in zip(source, target):
generate(str(x[0]), str(x[1]))


@ -614,24 +614,28 @@ void LightmapperRD::_raster_geometry(RenderingDevice *rd, Size2i atlas_size, int
}
}
LightmapperRD::BakeError LightmapperRD::_dilate(RenderingDevice *rd, Ref<RDShaderFile> &compute_shader, RID &compute_base_uniform_set, PushConstant &push_constant, RID &source_light_tex, RID &dest_light_tex, const Size2i &atlas_size, int atlas_slices) {
static Vector<RD::Uniform> dilate_or_denoise_common_uniforms(RID &p_source_light_tex, RID &p_dest_light_tex) {
Vector<RD::Uniform> uniforms;
{
{
RD::Uniform u;
u.uniform_type = RD::UNIFORM_TYPE_IMAGE;
u.binding = 0;
u.append_id(dest_light_tex);
uniforms.push_back(u);
}
{
RD::Uniform u;
u.uniform_type = RD::UNIFORM_TYPE_TEXTURE;
u.binding = 1;
u.append_id(source_light_tex);
uniforms.push_back(u);
}
RD::Uniform u;
u.uniform_type = RD::UNIFORM_TYPE_IMAGE;
u.binding = 0;
u.append_id(p_dest_light_tex);
uniforms.push_back(u);
}
{
RD::Uniform u;
u.uniform_type = RD::UNIFORM_TYPE_TEXTURE;
u.binding = 1;
u.append_id(p_source_light_tex);
uniforms.push_back(u);
}
return uniforms;
}
LightmapperRD::BakeError LightmapperRD::_dilate(RenderingDevice *rd, Ref<RDShaderFile> &compute_shader, RID &compute_base_uniform_set, PushConstant &push_constant, RID &source_light_tex, RID &dest_light_tex, const Size2i &atlas_size, int atlas_slices) {
Vector<RD::Uniform> uniforms = dilate_or_denoise_common_uniforms(source_light_tex, dest_light_tex);
RID compute_shader_dilate = rd->shader_create_from_spirv(compute_shader->get_spirv_stages("dilate"));
ERR_FAIL_COND_V(compute_shader_dilate.is_null(), BAKE_ERROR_LIGHTMAP_CANT_PRE_BAKE_MESHES); //internal check, should not happen
@ -667,7 +671,77 @@ LightmapperRD::BakeError LightmapperRD::_dilate(RenderingDevice *rd, Ref<RDShade
return BAKE_OK;
}
LightmapperRD::BakeError LightmapperRD::bake(BakeQuality p_quality, bool p_use_denoiser, int p_bounces, float p_bias, int p_max_texture_size, bool p_bake_sh, GenerateProbes p_generate_probes, const Ref<Image> &p_environment_panorama, const Basis &p_environment_transform, BakeStepFunc p_step_function, void *p_bake_userdata, float p_exposure_normalization) {
LightmapperRD::BakeError LightmapperRD::_denoise(RenderingDevice *p_rd, Ref<RDShaderFile> &p_compute_shader, const RID &p_compute_base_uniform_set, PushConstant &p_push_constant, RID p_source_light_tex, RID p_source_normal_tex, RID p_dest_light_tex, float p_denoiser_strength, const Size2i &p_atlas_size, int p_atlas_slices, bool p_bake_sh, BakeStepFunc p_step_function) {
RID denoise_params_buffer = p_rd->uniform_buffer_create(sizeof(DenoiseParams));
DenoiseParams denoise_params;
denoise_params.spatial_bandwidth = 5.0f;
denoise_params.light_bandwidth = p_denoiser_strength;
denoise_params.albedo_bandwidth = 1.0f;
denoise_params.normal_bandwidth = 0.1f;
denoise_params.filter_strength = 10.0f;
p_rd->buffer_update(denoise_params_buffer, 0, sizeof(DenoiseParams), &denoise_params);
Vector<RD::Uniform> uniforms = dilate_or_denoise_common_uniforms(p_source_light_tex, p_dest_light_tex);
{
RD::Uniform u;
u.uniform_type = RD::UNIFORM_TYPE_TEXTURE;
u.binding = 2;
u.append_id(p_source_normal_tex);
uniforms.push_back(u);
}
{
RD::Uniform u;
u.uniform_type = RD::UNIFORM_TYPE_UNIFORM_BUFFER;
u.binding = 3;
u.append_id(denoise_params_buffer);
uniforms.push_back(u);
}
RID compute_shader_denoise = p_rd->shader_create_from_spirv(p_compute_shader->get_spirv_stages("denoise"));
ERR_FAIL_COND_V(compute_shader_denoise.is_null(), BAKE_ERROR_LIGHTMAP_CANT_PRE_BAKE_MESHES);
RID compute_shader_denoise_pipeline = p_rd->compute_pipeline_create(compute_shader_denoise);
RID denoise_uniform_set = p_rd->uniform_set_create(uniforms, compute_shader_denoise, 1);
// We denoise in fixed size regions and synchronize execution to avoid GPU timeouts.
// We use a region with 1/4 the amount of pixels if we're denoising SH lightmaps, as
// all four of them are denoised in the shader in one dispatch.
const int max_region_size = p_bake_sh ? 512 : 1024;
int x_regions = (p_atlas_size.width - 1) / max_region_size + 1;
int y_regions = (p_atlas_size.height - 1) / max_region_size + 1;
for (int s = 0; s < p_atlas_slices; s++) {
p_push_constant.atlas_slice = s;
for (int i = 0; i < x_regions; i++) {
for (int j = 0; j < y_regions; j++) {
int x = i * max_region_size;
int y = j * max_region_size;
int w = MIN((i + 1) * max_region_size, p_atlas_size.width) - x;
int h = MIN((j + 1) * max_region_size, p_atlas_size.height) - y;
p_push_constant.region_ofs[0] = x;
p_push_constant.region_ofs[1] = y;
RD::ComputeListID compute_list = p_rd->compute_list_begin();
p_rd->compute_list_bind_compute_pipeline(compute_list, compute_shader_denoise_pipeline);
p_rd->compute_list_bind_uniform_set(compute_list, p_compute_base_uniform_set, 0);
p_rd->compute_list_bind_uniform_set(compute_list, denoise_uniform_set, 1);
p_rd->compute_list_set_push_constant(compute_list, &p_push_constant, sizeof(PushConstant));
p_rd->compute_list_dispatch(compute_list, (w - 1) / 8 + 1, (h - 1) / 8 + 1, 1);
p_rd->compute_list_end();
p_rd->submit();
p_rd->sync();
}
}
}
p_rd->free(compute_shader_denoise);
p_rd->free(denoise_params_buffer);
return BAKE_OK;
}
LightmapperRD::BakeError LightmapperRD::bake(BakeQuality p_quality, bool p_use_denoiser, float p_denoiser_strength, int p_bounces, float p_bias, int p_max_texture_size, bool p_bake_sh, GenerateProbes p_generate_probes, const Ref<Image> &p_environment_panorama, const Basis &p_environment_transform, BakeStepFunc p_step_function, void *p_bake_userdata, float p_exposure_normalization) {
if (p_step_function) {
p_step_function(0.0, RTR("Begin Bake"), p_bake_userdata, true);
}
@ -1434,27 +1508,11 @@ LightmapperRD::BakeError LightmapperRD::bake(BakeQuality p_quality, bool p_use_d
p_step_function(0.8, RTR("Denoising"), p_bake_userdata, true);
}
Ref<LightmapDenoiser> denoiser = LightmapDenoiser::create();
if (denoiser.is_valid()) {
for (int i = 0; i < atlas_slices * (p_bake_sh ? 4 : 1); i++) {
Vector<uint8_t> s = rd->texture_get_data(light_accum_tex, i);
Ref<Image> img = Image::create_from_data(atlas_size.width, atlas_size.height, false, Image::FORMAT_RGBAH, s);
Ref<Image> denoised = denoiser->denoise_image(img);
if (denoised != img) {
denoised->convert(Image::FORMAT_RGBAH);
Vector<uint8_t> ds = denoised->get_data();
denoised.unref(); //avoid copy on write
{ //restore alpha
uint32_t count = s.size() / 2; //uint16s
const uint16_t *src = (const uint16_t *)s.ptr();
uint16_t *dst = (uint16_t *)ds.ptrw();
for (uint32_t j = 0; j < count; j += 4) {
dst[j + 3] = src[j + 3];
}
}
rd->texture_update(light_accum_tex, i, ds);
}
{
SWAP(light_accum_tex, light_accum_tex2);
BakeError error = _denoise(rd, compute_shader, compute_base_uniform_set, push_constant, light_accum_tex2, normal_tex, light_accum_tex, p_denoiser_strength, atlas_size, atlas_slices, p_bake_sh, p_step_function);
if (unlikely(error != BAKE_OK)) {
return error;
}
}
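
Note on the region loop in `_denoise` above: the atlas is split into tiles of at most `max_region_size` pixels per side, and each tile is dispatched and synchronized separately. The following is a minimal standalone sketch of that arithmetic only, not the engine code. `Region`, `ceil_div` and `split_into_regions` are hypothetical names introduced for illustration, and the 8x8 workgroup size is an assumption inferred from the `(w - 1) / 8 + 1` dispatch arguments.

```cpp
#include <algorithm>
#include <cassert>
#include <vector>

// Hypothetical helper type for illustration; not part of the Godot source.
struct Region {
	int x, y; // Top-left corner of the tile inside the atlas, in pixels.
	int w, h; // Tile size in pixels (edge tiles may be smaller).
	int groups_x, groups_y; // Workgroup counts, assuming an 8x8 local size.
};

// Ceiling division for positive integers, written the same way as in the diff.
static int ceil_div(int p_value, int p_divisor) {
	assert(p_value > 0 && p_divisor > 0);
	return (p_value - 1) / p_divisor + 1;
}

// Splits an atlas of p_width x p_height pixels into tiles of at most
// p_max_region_size pixels per side, in the same i-outer, j-inner order.
static std::vector<Region> split_into_regions(int p_width, int p_height, int p_max_region_size) {
	std::vector<Region> regions;
	const int x_regions = ceil_div(p_width, p_max_region_size);
	const int y_regions = ceil_div(p_height, p_max_region_size);
	for (int i = 0; i < x_regions; i++) {
		for (int j = 0; j < y_regions; j++) {
			Region r;
			r.x = i * p_max_region_size;
			r.y = j * p_max_region_size;
			// Clamp the far edge of the tile to the atlas bounds.
			r.w = std::min((i + 1) * p_max_region_size, p_width) - r.x;
			r.h = std::min((j + 1) * p_max_region_size, p_height) - r.y;
			r.groups_x = ceil_div(r.w, 8);
			r.groups_y = ceil_div(r.h, 8);
			regions.push_back(r);
		}
	}
	return regions;
}
```

For example, `split_into_regions(2500, 1100, 1024)` yields 3 x 2 = 6 tiles. The first is a full 1024x1024 tile (128x128 workgroups) and the last is a 452x76 edge tile. With SH lightmaps the diff passes 512 instead of 1024, which quarters the pixel count per tile because all four SH textures are denoised in one dispatch.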


@ -229,11 +229,22 @@ class LightmapperRD : public Lightmapper {
Vector<Ref<Image>> bake_textures;
Vector<Color> probe_values;
struct DenoiseParams {
float spatial_bandwidth;
float light_bandwidth;
float albedo_bandwidth;
float normal_bandwidth;
float filter_strength;
float pad[3];
};
BakeError _blit_meshes_into_atlas(int p_max_texture_size, Vector<Ref<Image>> &albedo_images, Vector<Ref<Image>> &emission_images, AABB &bounds, Size2i &atlas_size, int &atlas_slices, BakeStepFunc p_step_function, void *p_bake_userdata);
void _create_acceleration_structures(RenderingDevice *rd, Size2i atlas_size, int atlas_slices, AABB &bounds, int grid_size, Vector<Probe> &probe_positions, GenerateProbes p_generate_probes, Vector<int> &slice_triangle_count, Vector<int> &slice_seam_count, RID &vertex_buffer, RID &triangle_buffer, RID &lights_buffer, RID &triangle_cell_indices_buffer, RID &probe_positions_buffer, RID &grid_texture, RID &seams_buffer, BakeStepFunc p_step_function, void *p_bake_userdata);
void _raster_geometry(RenderingDevice *rd, Size2i atlas_size, int atlas_slices, int grid_size, AABB bounds, float p_bias, Vector<int> slice_triangle_count, RID position_tex, RID unocclude_tex, RID normal_tex, RID raster_depth_buffer, RID rasterize_shader, RID raster_base_uniform);
BakeError _dilate(RenderingDevice *rd, Ref<RDShaderFile> &compute_shader, RID &compute_base_uniform_set, PushConstant &push_constant, RID &source_light_tex, RID &dest_light_tex, const Size2i &atlas_size, int atlas_slices);
BakeError _denoise(RenderingDevice *p_rd, Ref<RDShaderFile> &p_compute_shader, const RID &p_compute_base_uniform_set, PushConstant &p_push_constant, RID p_source_light_tex, RID p_source_normal_tex, RID p_dest_light_tex, float p_denoiser_strength, const Size2i &p_atlas_size, int p_atlas_slices, bool p_bake_sh, BakeStepFunc p_step_function);
public:
virtual void add_mesh(const MeshData &p_mesh) override;
@ -241,7 +252,7 @@ public:
virtual void add_omni_light(bool p_static, const Vector3 &p_position, const Color &p_color, float p_energy, float p_range, float p_attenuation, float p_size, float p_shadow_blur) override;
virtual void add_spot_light(bool p_static, const Vector3 &p_position, const Vector3 p_direction, const Color &p_color, float p_energy, float p_range, float p_attenuation, float p_spot_angle, float p_spot_attenuation, float p_size, float p_shadow_blur) override;
virtual void add_probe(const Vector3 &p_position) override;
virtual BakeError bake(BakeQuality p_quality, bool p_use_denoiser, int p_bounces, float p_bias, int p_max_texture_size, bool p_bake_sh, GenerateProbes p_generate_probes, const Ref<Image> &p_environment_panorama, const Basis &p_environment_transform, BakeStepFunc p_step_function = nullptr, void *p_bake_userdata = nullptr, float p_exposure_normalization = 1.0) override;
virtual BakeError bake(BakeQuality p_quality, bool p_use_denoiser, float p_denoiser_strength, int p_bounces, float p_bias, int p_max_texture_size, bool p_bake_sh, GenerateProbes p_generate_probes, const Ref<Image> &p_environment_panorama, const Basis &p_environment_transform, BakeStepFunc p_step_function = nullptr, void *p_bake_userdata = nullptr, float p_exposure_normalization = 1.0) override;
int get_bake_texture_count() const override;
Ref<Image> get_bake_texture(int p_index) const override;
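
Reviewer note on the `DenoiseParams` struct added in this hunk: the trailing `pad[3]` brings the struct from 20 to 32 bytes, presumably so the CPU-side copy is a multiple of 16 bytes when uploaded as a uniform buffer. That rationale is an assumption, since the diff does not state it. A minimal standalone sketch that checks the layout, assuming 4-byte `float`:

```cpp
#include <cstddef>

// Standalone copy of the struct from the hunk above, for layout checks only.
struct DenoiseParams {
	float spatial_bandwidth;
	float light_bandwidth;
	float albedo_bandwidth;
	float normal_bandwidth;
	float filter_strength;
	float pad[3]; // Assumed purpose: round the size up to a 16-byte multiple.
};

// Five parameters plus three padding floats: 8 * 4 = 32 bytes.
static_assert(sizeof(DenoiseParams) == 32, "DenoiseParams is expected to be 32 bytes");
static_assert(sizeof(DenoiseParams) % 16 == 0, "size should be a multiple of 16 bytes");
// The fifth scalar starts right after the first four tightly packed floats.
static_assert(offsetof(DenoiseParams, filter_strength) == 16, "filter_strength should sit at offset 16");
```

The five scalars line up with the five `float` members of the `DenoiseParams` uniform block declared in the compute shader.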


@ -5,6 +5,7 @@ secondary = "#define MODE_BOUNCE_LIGHT";
dilate = "#define MODE_DILATE";
unocclude = "#define MODE_UNOCCLUDE";
light_probes = "#define MODE_LIGHT_PROBES";
denoise = "#define MODE_DENOISE";
#[compute]
@ -65,11 +66,24 @@ layout(set = 1, binding = 6) uniform texture2D environment;
layout(rgba32f, set = 1, binding = 5) uniform restrict writeonly image2DArray primary_dynamic;
#endif
#ifdef MODE_DILATE
#if defined(MODE_DILATE) || defined(MODE_DENOISE)
layout(rgba16f, set = 1, binding = 0) uniform restrict writeonly image2DArray dest_light;
layout(set = 1, binding = 1) uniform texture2DArray source_light;
#endif
#ifdef MODE_DENOISE
layout(set = 1, binding = 2) uniform texture2DArray source_normal;
layout(set = 1, binding = 3) uniform DenoiseParams {
float spatial_bandwidth;
float light_bandwidth;
float albedo_bandwidth;
float normal_bandwidth;
float filter_strength;
}
denoise_params;
#endif
layout(push_constant, std430) uniform Params {
ivec2 atlas_size; // x used for light probe mode total probes
uint ray_count;
@ -735,4 +749,153 @@ void main() {
imageStore(dest_light, ivec3(atlas_pos, params.atlas_slice), c);
#endif
#ifdef MODE_DENOISE
// Joint Non-Local Means (JNLM) denoiser.
//
// Based on YoctoImageDenoiser's JNLM implementation with corrections from "Nonlinearly Weighted First-order Regression for Denoising Monte Carlo Renderings".
//
// <https://github.com/ManuelPrandini/YoctoImageDenoiser/blob/06e19489dd64e47792acffde536393802ba48607/libs/yocto_extension/yocto_extension.cpp#L207>
// <https://benedikt-bitterli.me/nfor/nfor.pdf>
//
// MIT License
//
// Copyright (c) 2020 ManuelPrandini
//
// Permission is hereby granted, free of charge, to any person obtaining a copy
// of this software and associated documentation files (the "Software"), to deal
// in the Software without restriction, including without limitation the rights
// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
// copies of the Software, and to permit persons to whom the Software is
// furnished to do so, subject to the following conditions:
//
// The above copyright notice and this permission notice shall be included in all
// copies or substantial portions of the Software.
//
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
// SOFTWARE.
//
// Most of the constants below have been hand-picked to fit the common scenarios lightmaps
// are generated with, but they can be altered freely to experiment and achieve better results.
// Half the size of the patch window around each pixel that is weighted to compute the denoised pixel.
// A value of 1 represents a 3x3 window, a value of 2 a 5x5 window, etc.
const int HALF_PATCH_WINDOW = 4;
// Half the size of the search window around each pixel that is denoised and weighted to compute the denoised pixel.
const int HALF_SEARCH_WINDOW = 10;
// For all of the following sigma values, smaller values give less weight to pixels that are farther away
// in the feature being evaluated. Smaller values are therefore likely to leave more noise in the result, but
// they also erase fewer features in the process.
// Controls how much the spatial distance of the pixels influences the denoising weight.
const float SIGMA_SPATIAL = denoise_params.spatial_bandwidth;
// Controls how much the light color distance of the pixels influences the denoising weight.
const float SIGMA_LIGHT = denoise_params.light_bandwidth;
// Controls how much the albedo color distance of the pixels influences the denoising weight.
const float SIGMA_ALBEDO = denoise_params.albedo_bandwidth;
// Controls how much the normal vector distance of the pixels influences the denoising weight.
const float SIGMA_NORMAL = denoise_params.normal_bandwidth;
// Strength of the filter. The original paper recommends values around 10 to 15 times the Sigma parameter.
const float FILTER_VALUE = denoise_params.filter_strength * SIGMA_LIGHT;
// Formula constants.
const int PATCH_WINDOW_DIMENSION = (HALF_PATCH_WINDOW * 2 + 1);
const int PATCH_WINDOW_DIMENSION_SQUARE = (PATCH_WINDOW_DIMENSION * PATCH_WINDOW_DIMENSION);
const float TWO_SIGMA_SPATIAL_SQUARE = 2.0f * SIGMA_SPATIAL * SIGMA_SPATIAL;
const float TWO_SIGMA_LIGHT_SQUARE = 2.0f * SIGMA_LIGHT * SIGMA_LIGHT;
const float TWO_SIGMA_ALBEDO_SQUARE = 2.0f * SIGMA_ALBEDO * SIGMA_ALBEDO;
const float TWO_SIGMA_NORMAL_SQUARE = 2.0f * SIGMA_NORMAL * SIGMA_NORMAL;
const float FILTER_SQUARE_TWO_SIGMA_LIGHT_SQUARE = FILTER_VALUE * FILTER_VALUE * TWO_SIGMA_LIGHT_SQUARE;
const float EPSILON = 1e-6f;
#ifdef USE_SH_LIGHTMAPS
const uint slice_count = 4;
const uint slice_base = params.atlas_slice * slice_count;
#else
const uint slice_count = 1;
const uint slice_base = params.atlas_slice;
#endif
for (uint i = 0; i < slice_count; i++) {
uint lightmap_slice = slice_base + i;
vec3 denoised_rgb = vec3(0.0f);
vec4 input_light = texelFetch(sampler2DArray(source_light, linear_sampler), ivec3(atlas_pos, lightmap_slice), 0);
vec3 input_albedo = texelFetch(sampler2DArray(albedo_tex, linear_sampler), ivec3(atlas_pos, params.atlas_slice), 0).rgb;
vec3 input_normal = texelFetch(sampler2DArray(source_normal, linear_sampler), ivec3(atlas_pos, params.atlas_slice), 0).xyz;
if (length(input_normal) > EPSILON) {
// Compute the denoised pixel if the normal is valid.
float sum_weights = 0.0f;
vec3 input_rgb = input_light.rgb;
for (int search_y = -HALF_SEARCH_WINDOW; search_y <= HALF_SEARCH_WINDOW; search_y++) {
for (int search_x = -HALF_SEARCH_WINDOW; search_x <= HALF_SEARCH_WINDOW; search_x++) {
ivec2 search_pos = atlas_pos + ivec2(search_x, search_y);
vec3 search_rgb = texelFetch(sampler2DArray(source_light, linear_sampler), ivec3(search_pos, lightmap_slice), 0).rgb;
vec3 search_albedo = texelFetch(sampler2DArray(albedo_tex, linear_sampler), ivec3(search_pos, params.atlas_slice), 0).rgb;
vec3 search_normal = texelFetch(sampler2DArray(source_normal, linear_sampler), ivec3(search_pos, params.atlas_slice), 0).xyz;
float patch_square_dist = 0.0f;
for (int offset_y = -HALF_PATCH_WINDOW; offset_y <= HALF_PATCH_WINDOW; offset_y++) {
for (int offset_x = -HALF_PATCH_WINDOW; offset_x <= HALF_PATCH_WINDOW; offset_x++) {
ivec2 offset_input_pos = atlas_pos + ivec2(offset_x, offset_y);
ivec2 offset_search_pos = search_pos + ivec2(offset_x, offset_y);
vec3 offset_input_rgb = texelFetch(sampler2DArray(source_light, linear_sampler), ivec3(offset_input_pos, lightmap_slice), 0).rgb;
vec3 offset_search_rgb = texelFetch(sampler2DArray(source_light, linear_sampler), ivec3(offset_search_pos, lightmap_slice), 0).rgb;
vec3 offset_delta_rgb = offset_input_rgb - offset_search_rgb;
patch_square_dist += dot(offset_delta_rgb, offset_delta_rgb) - TWO_SIGMA_LIGHT_SQUARE;
}
}
patch_square_dist = max(0.0f, patch_square_dist / (3.0f * PATCH_WINDOW_DIMENSION_SQUARE));
float weight = 1.0f;
// Ignore weight if search position is out of bounds.
weight *= step(0, search_pos.x) * step(search_pos.x, params.atlas_size.x - 1);
weight *= step(0, search_pos.y) * step(search_pos.y, params.atlas_size.y - 1);
// Ignore weight if normal is zero length.
weight *= step(EPSILON, length(search_normal));
// Weight with pixel distance.
vec2 pixel_delta = vec2(search_x, search_y);
float pixel_square_dist = dot(pixel_delta, pixel_delta);
weight *= exp(-pixel_square_dist / TWO_SIGMA_SPATIAL_SQUARE);
// Weight with patch.
weight *= exp(-patch_square_dist / FILTER_SQUARE_TWO_SIGMA_LIGHT_SQUARE);
// Weight with albedo.
vec3 albedo_delta = input_albedo - search_albedo;
float albedo_square_dist = dot(albedo_delta, albedo_delta);
weight *= exp(-albedo_square_dist / TWO_SIGMA_ALBEDO_SQUARE);
// Weight with normal.
vec3 normal_delta = input_normal - search_normal;
float normal_square_dist = dot(normal_delta, normal_delta);
weight *= exp(-normal_square_dist / TWO_SIGMA_NORMAL_SQUARE);
denoised_rgb += weight * search_rgb;
sum_weights += weight;
}
}
denoised_rgb /= sum_weights;
} else {
// Ignore pixels where the normal is empty, just copy the light color.
denoised_rgb = input_light.rgb;
}
imageStore(dest_light, ivec3(atlas_pos, lightmap_slice), vec4(denoised_rgb, input_light.a));
}
#endif
}
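
Reviewer note on the `MODE_DENOISE` branch above: the weight given to each candidate pixel in the search window is a product of four Gaussian terms (spatial, patch, albedo, normal). The following is a minimal CPU-side sketch of that product for a single candidate, not code from this PR. `JnlmSigmas` and `jnlm_weight` are hypothetical names, and the shader's out-of-bounds and zero-normal masks are left out for brevity.

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>

// Hypothetical container for the values the shader reads from denoise_params.
struct JnlmSigmas {
	float spatial;
	float light;
	float albedo;
	float normal;
	float filter_strength;
};

// Weight of one candidate pixel, mirroring the shader's arithmetic.
// pixel_sq_dist:  squared pixel offset between the input and candidate pixel.
// patch_sq_sum:   sum over the patch of squared RGB differences (no bias correction yet).
// patch_pixels:   number of pixels in the patch, e.g. 81 for a 9x9 patch.
// albedo_sq_dist: squared distance between the two albedo colors.
// normal_sq_dist: squared distance between the two normals.
inline float jnlm_weight(const JnlmSigmas &s, float pixel_sq_dist, float patch_sq_sum, int patch_pixels, float albedo_sq_dist, float normal_sq_dist) {
	assert(patch_pixels > 0);

	const float two_sigma_spatial_sq = 2.0f * s.spatial * s.spatial;
	const float two_sigma_light_sq = 2.0f * s.light * s.light;
	const float two_sigma_albedo_sq = 2.0f * s.albedo * s.albedo;
	const float two_sigma_normal_sq = 2.0f * s.normal * s.normal;
	const float filter_value = s.filter_strength * s.light;

	// The shader subtracts 2 * sigma_light^2 per patch pixel, then averages over
	// pixels and the three color channels, clamping the result at zero.
	const float patch_sq_dist = std::max(0.0f, (patch_sq_sum - patch_pixels * two_sigma_light_sq) / (3.0f * patch_pixels));

	float weight = std::exp(-pixel_sq_dist / two_sigma_spatial_sq);
	weight *= std::exp(-patch_sq_dist / (filter_value * filter_value * two_sigma_light_sq));
	weight *= std::exp(-albedo_sq_dist / two_sigma_albedo_sq);
	weight *= std::exp(-normal_sq_dist / two_sigma_normal_sq);
	return weight;
}
```

With the constants in the shader, each output texel evaluates 21 × 21 = 441 candidates and compares 9 × 9 = 81 patch pixels per candidate. The denoised color is the weighted sum of candidate colors divided by the sum of these weights.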


@ -1080,7 +1080,7 @@ LightmapGI::BakeError LightmapGI::bake(Node *p_from_node, String p_image_data_pa
}
}
Lightmapper::BakeError bake_err = lightmapper->bake(Lightmapper::BakeQuality(bake_quality), use_denoiser, bounces, bias, max_texture_size, directional, Lightmapper::GenerateProbes(gen_probes), environment_image, environment_transform, _lightmap_bake_step_function, &bsud, exposure_normalization);
Lightmapper::BakeError bake_err = lightmapper->bake(Lightmapper::BakeQuality(bake_quality), use_denoiser, denoiser_strength, bounces, bias, max_texture_size, directional, Lightmapper::GenerateProbes(gen_probes), environment_image, environment_transform, _lightmap_bake_step_function, &bsud, exposure_normalization);
if (bake_err == Lightmapper::BAKE_ERROR_LIGHTMAP_TOO_SMALL) {
return BAKE_ERROR_TEXTURE_SIZE_TOO_SMALL;
@ -1362,12 +1362,21 @@ AABB LightmapGI::get_aabb() const {
void LightmapGI::set_use_denoiser(bool p_enable) {
use_denoiser = p_enable;
notify_property_list_changed();
}
bool LightmapGI::is_using_denoiser() const {
return use_denoiser;
}
void LightmapGI::set_denoiser_strength(float p_denoiser_strength) {
denoiser_strength = p_denoiser_strength;
}
float LightmapGI::get_denoiser_strength() const {
return denoiser_strength;
}
void LightmapGI::set_directional(bool p_enable) {
directional = p_enable;
}
@ -1482,6 +1491,9 @@ void LightmapGI::_validate_property(PropertyInfo &p_property) const {
if (p_property.name == "environment_custom_energy" && environment_mode != ENVIRONMENT_MODE_CUSTOM_COLOR && environment_mode != ENVIRONMENT_MODE_CUSTOM_SKY) {
p_property.usage = PROPERTY_USAGE_NONE;
}
if (p_property.name == "denoiser_strength" && !use_denoiser) {
p_property.usage = PROPERTY_USAGE_NONE;
}
}
void LightmapGI::_bind_methods() {
@ -1518,6 +1530,9 @@ void LightmapGI::_bind_methods() {
ClassDB::bind_method(D_METHOD("set_use_denoiser", "use_denoiser"), &LightmapGI::set_use_denoiser);
ClassDB::bind_method(D_METHOD("is_using_denoiser"), &LightmapGI::is_using_denoiser);
ClassDB::bind_method(D_METHOD("set_denoiser_strength", "denoiser_strength"), &LightmapGI::set_denoiser_strength);
ClassDB::bind_method(D_METHOD("get_denoiser_strength"), &LightmapGI::get_denoiser_strength);
ClassDB::bind_method(D_METHOD("set_interior", "enable"), &LightmapGI::set_interior);
ClassDB::bind_method(D_METHOD("is_interior"), &LightmapGI::is_interior);
@ -1535,6 +1550,7 @@ void LightmapGI::_bind_methods() {
ADD_PROPERTY(PropertyInfo(Variant::BOOL, "directional"), "set_directional", "is_directional");
ADD_PROPERTY(PropertyInfo(Variant::BOOL, "interior"), "set_interior", "is_interior");
ADD_PROPERTY(PropertyInfo(Variant::BOOL, "use_denoiser"), "set_use_denoiser", "is_using_denoiser");
ADD_PROPERTY(PropertyInfo(Variant::FLOAT, "denoiser_strength", PROPERTY_HINT_RANGE, "0.001,0.2,0.001,or_greater"), "set_denoiser_strength", "get_denoiser_strength");
ADD_PROPERTY(PropertyInfo(Variant::FLOAT, "bias", PROPERTY_HINT_RANGE, "0.00001,0.1,0.00001,or_greater"), "set_bias", "get_bias");
ADD_PROPERTY(PropertyInfo(Variant::INT, "max_texture_size", PROPERTY_HINT_RANGE, "2048,16384,1"), "set_max_texture_size", "get_max_texture_size");
ADD_GROUP("Environment", "environment_");


@ -145,6 +145,7 @@ public:
private:
BakeQuality bake_quality = BAKE_QUALITY_MEDIUM;
bool use_denoiser = true;
float denoiser_strength = 0.1f;
int bounces = 3;
float bias = 0.0005;
int max_texture_size = 16384;
@ -239,6 +240,9 @@ public:
void set_use_denoiser(bool p_enable);
bool is_using_denoiser() const;
void set_denoiser_strength(float p_denoiser_strength);
float get_denoiser_strength() const;
void set_directional(bool p_enable);
bool is_directional() const;


@ -180,7 +180,7 @@ public:
virtual void add_omni_light(bool p_static, const Vector3 &p_position, const Color &p_color, float p_energy, float p_range, float p_attenuation, float p_size, float p_shadow_blur) = 0;
virtual void add_spot_light(bool p_static, const Vector3 &p_position, const Vector3 p_direction, const Color &p_color, float p_energy, float p_range, float p_attenuation, float p_spot_angle, float p_spot_attenuation, float p_size, float p_shadow_blur) = 0;
virtual void add_probe(const Vector3 &p_position) = 0;
virtual BakeError bake(BakeQuality p_quality, bool p_use_denoiser, int p_bounces, float p_bias, int p_max_texture_size, bool p_bake_sh, GenerateProbes p_generate_probes, const Ref<Image> &p_environment_panorama, const Basis &p_environment_transform, BakeStepFunc p_step_function = nullptr, void *p_step_userdata = nullptr, float p_exposure_normalization = 1.0) = 0;
virtual BakeError bake(BakeQuality p_quality, bool p_use_denoiser, float p_denoiser_strength, int p_bounces, float p_bias, int p_max_texture_size, bool p_bake_sh, GenerateProbes p_generate_probes, const Ref<Image> &p_environment_panorama, const Basis &p_environment_transform, BakeStepFunc p_step_function = nullptr, void *p_step_userdata = nullptr, float p_exposure_normalization = 1.0) = 0;
virtual int get_bake_texture_count() const = 0;
virtual Ref<Image> get_bake_texture(int p_index) const = 0;

thirdparty/README.md

@ -650,37 +650,6 @@ Files extracted from the upstream source:
- `nvapi_minimal.h` was created by using `nvapi.h` from upstream and removing unnecessary code.
## oidn
- Upstream: https://github.com/OpenImageDenoise/oidn
- Version: 1.1.0 (c58c5216db05ceef4cde5a096862f2eeffd14c06, 2019)
- License: Apache 2.0
Files extracted from upstream source:
- common/* (except tasking.* and CMakeLists.txt)
- core/*
- include/OpenImageDenoise/* (except version.h.in)
- LICENSE.txt
- mkl-dnn/include/*
- mkl-dnn/src/* (except CMakeLists.txt)
- weights/rtlightmap_hdr.tza
- scripts/resource_to_cpp.py
Modified files:
Modifications are marked with `// -- GODOT start --` and `// -- GODOT end --`.
Patch files are provided in `oidn/patches/`.
- core/autoencoder.cpp
- core/autoencoder.h
- core/common.h
- core/device.cpp
- core/device.h
- core/transfer_function.cpp
- scripts/resource_to_cpp.py (used in modules/denoise/resource_to_cpp.py)
## openxr
- Upstream: https://github.com/KhronosGroup/OpenXR-SDK


@ -1,202 +0,0 @@
Apache License
Version 2.0, January 2004
http://www.apache.org/licenses/
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
1. Definitions.
"License" shall mean the terms and conditions for use, reproduction,
and distribution as defined by Sections 1 through 9 of this document.
"Licensor" shall mean the copyright owner or entity authorized by
the copyright owner that is granting the License.
"Legal Entity" shall mean the union of the acting entity and all
other entities that control, are controlled by, or are under common
control with that entity. For the purposes of this definition,
"control" means (i) the power, direct or indirect, to cause the
direction or management of such entity, whether by contract or
otherwise, or (ii) ownership of fifty percent (50%) or more of the
outstanding shares, or (iii) beneficial ownership of such entity.
"You" (or "Your") shall mean an individual or Legal Entity
exercising permissions granted by this License.
"Source" form shall mean the preferred form for making modifications,
including but not limited to software source code, documentation
source, and configuration files.
"Object" form shall mean any form resulting from mechanical
transformation or translation of a Source form, including but
not limited to compiled object code, generated documentation,
and conversions to other media types.
"Work" shall mean the work of authorship, whether in Source or
Object form, made available under the License, as indicated by a
copyright notice that is included in or attached to the work
(an example is provided in the Appendix below).
"Derivative Works" shall mean any work, whether in Source or Object
form, that is based on (or derived from) the Work and for which the
editorial revisions, annotations, elaborations, or other modifications
represent, as a whole, an original work of authorship. For the purposes
of this License, Derivative Works shall not include works that remain
separable from, or merely link (or bind by name) to the interfaces of,
the Work and Derivative Works thereof.
"Contribution" shall mean any work of authorship, including
the original version of the Work and any modifications or additions
to that Work or Derivative Works thereof, that is intentionally
submitted to Licensor for inclusion in the Work by the copyright owner
or by an individual or Legal Entity authorized to submit on behalf of
the copyright owner. For the purposes of this definition, "submitted"
means any form of electronic, verbal, or written communication sent
to the Licensor or its representatives, including but not limited to
communication on electronic mailing lists, source code control systems,
and issue tracking systems that are managed by, or on behalf of, the
Licensor for the purpose of discussing and improving the Work, but
excluding communication that is conspicuously marked or otherwise
designated in writing by the copyright owner as "Not a Contribution."
"Contributor" shall mean Licensor and any individual or Legal Entity
on behalf of whom a Contribution has been received by Licensor and
subsequently incorporated within the Work.
2. Grant of Copyright License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
copyright license to reproduce, prepare Derivative Works of,
publicly display, publicly perform, sublicense, and distribute the
Work and such Derivative Works in Source or Object form.
3. Grant of Patent License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
(except as stated in this section) patent license to make, have made,
use, offer to sell, sell, import, and otherwise transfer the Work,
where such license applies only to those patent claims licensable
by such Contributor that are necessarily infringed by their
Contribution(s) alone or by combination of their Contribution(s)
with the Work to which such Contribution(s) was submitted. If You
institute patent litigation against any entity (including a
cross-claim or counterclaim in a lawsuit) alleging that the Work
or a Contribution incorporated within the Work constitutes direct
or contributory patent infringement, then any patent licenses
granted to You under this License for that Work shall terminate
as of the date such litigation is filed.
4. Redistribution. You may reproduce and distribute copies of the
Work or Derivative Works thereof in any medium, with or without
modifications, and in Source or Object form, provided that You
meet the following conditions:
(a) You must give any other recipients of the Work or
Derivative Works a copy of this License; and
(b) You must cause any modified files to carry prominent notices
stating that You changed the files; and
(c) You must retain, in the Source form of any Derivative Works
that You distribute, all copyright, patent, trademark, and
attribution notices from the Source form of the Work,
excluding those notices that do not pertain to any part of
the Derivative Works; and
(d) If the Work includes a "NOTICE" text file as part of its
distribution, then any Derivative Works that You distribute must
include a readable copy of the attribution notices contained
within such NOTICE file, excluding those notices that do not
pertain to any part of the Derivative Works, in at least one
of the following places: within a NOTICE text file distributed
as part of the Derivative Works; within the Source form or
documentation, if provided along with the Derivative Works; or,
within a display generated by the Derivative Works, if and
wherever such third-party notices normally appear. The contents
of the NOTICE file are for informational purposes only and
do not modify the License. You may add Your own attribution
notices within Derivative Works that You distribute, alongside
or as an addendum to the NOTICE text from the Work, provided
that such additional attribution notices cannot be construed
as modifying the License.
You may add Your own copyright statement to Your modifications and
may provide additional or different license terms and conditions
for use, reproduction, or distribution of Your modifications, or
for any such Derivative Works as a whole, provided Your use,
reproduction, and distribution of the Work otherwise complies with
the conditions stated in this License.
5. Submission of Contributions. Unless You explicitly state otherwise,
any Contribution intentionally submitted for inclusion in the Work
by You to the Licensor shall be under the terms and conditions of
this License, without any additional terms or conditions.
Notwithstanding the above, nothing herein shall supersede or modify
the terms of any separate license agreement you may have executed
with Licensor regarding such Contributions.
6. Trademarks. This License does not grant permission to use the trade
names, trademarks, service marks, or product names of the Licensor,
except as required for reasonable and customary use in describing the
origin of the Work and reproducing the content of the NOTICE file.
7. Disclaimer of Warranty. Unless required by applicable law or
agreed to in writing, Licensor provides the Work (and each
Contributor provides its Contributions) on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
implied, including, without limitation, any warranties or conditions
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
PARTICULAR PURPOSE. You are solely responsible for determining the
appropriateness of using or redistributing the Work and assume any
risks associated with Your exercise of permissions under this License.
8. Limitation of Liability. In no event and under no legal theory,
whether in tort (including negligence), contract, or otherwise,
unless required by applicable law (such as deliberate and grossly
negligent acts) or agreed to in writing, shall any Contributor be
liable to You for damages, including any direct, indirect, special,
incidental, or consequential damages of any character arising as a
result of this License or out of the use or inability to use the
Work (including but not limited to damages for loss of goodwill,
work stoppage, computer failure or malfunction, or any and all
other commercial damages or losses), even if such Contributor
has been advised of the possibility of such damages.
9. Accepting Warranty or Additional Liability. While redistributing
the Work or Derivative Works thereof, You may choose to offer,
and charge a fee for, acceptance of support, warranty, indemnity,
or other liability obligations and/or rights consistent with this
License. However, in accepting such obligations, You may act only
on Your own behalf and on Your sole responsibility, not on behalf
of any other Contributor, and only if You agree to indemnify,
defend, and hold each Contributor harmless for any liability
incurred by, or claims asserted against, such Contributor by reason
of your accepting any such warranty or additional liability.
END OF TERMS AND CONDITIONS
APPENDIX: How to apply the Apache License to your work.
To apply the Apache License to your work, attach the following
boilerplate notice, with the fields enclosed by brackets "[]"
replaced with your own identifying information. (Don't include
the brackets!) The text should be enclosed in the appropriate
comment syntax for the file format. We also recommend that a
file or class name and description of purpose be included on the
same "printed page" as the copyright notice for easier
identification within third-party archives.
Copyright [yyyy] [name of copyright owner]
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.


@ -1,52 +0,0 @@
// ======================================================================== //
// Copyright 2009-2019 Intel Corporation //
// //
// Licensed under the Apache License, Version 2.0 (the "License"); //
// you may not use this file except in compliance with the License. //
// You may obtain a copy of the License at //
// //
// http://www.apache.org/licenses/LICENSE-2.0 //
// //
// Unless required by applicable law or agreed to in writing, software //
// distributed under the License is distributed on an "AS IS" BASIS, //
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. //
// See the License for the specific language governing permissions and //
// limitations under the License. //
// ======================================================================== //
#pragma once
#include "platform.h"
#include <mutex>
#include <condition_variable>
namespace oidn {
class Barrier
{
private:
std::mutex m;
std::condition_variable cv;
volatile int count;
public:
Barrier(int count) : count(count) {}
void wait()
{
std::unique_lock<std::mutex> lk(m);
count--;
if (count == 0)
{
lk.unlock();
cv.notify_all();
}
else
{
cv.wait(lk, [&]{ return count == 0; });
}
}
};
} // namespace oidn


@ -1,45 +0,0 @@
// ======================================================================== //
// Copyright 2009-2019 Intel Corporation //
// //
// Licensed under the Apache License, Version 2.0 (the "License"); //
// you may not use this file except in compliance with the License. //
// You may obtain a copy of the License at //
// //
// http://www.apache.org/licenses/LICENSE-2.0 //
// //
// Unless required by applicable law or agreed to in writing, software //
// distributed under the License is distributed on an "AS IS" BASIS, //
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. //
// See the License for the specific language governing permissions and //
// limitations under the License. //
// ======================================================================== //
#pragma once
#include <exception>
#include "platform.h"
namespace oidn {
class Exception : public std::exception
{
private:
Error error;
const char* message;
public:
Exception(Error error, const char* message)
: error(error), message(message) {}
Error code() const noexcept
{
return error;
}
const char* what() const noexcept override
{
return message;
}
};
} // namespace oidn


@ -1,114 +0,0 @@
// ======================================================================== //
// Copyright 2009-2019 Intel Corporation //
// //
// Licensed under the Apache License, Version 2.0 (the "License"); //
// you may not use this file except in compliance with the License. //
// You may obtain a copy of the License at //
// //
// http://www.apache.org/licenses/LICENSE-2.0 //
// //
// Unless required by applicable law or agreed to in writing, software //
// distributed under the License is distributed on an "AS IS" BASIS, //
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. //
// See the License for the specific language governing permissions and //
// limitations under the License. //
// ======================================================================== //
#include "platform.h"
namespace oidn {
// ----------------------------------------------------------------------------
// Common functions
// ----------------------------------------------------------------------------
void* alignedMalloc(size_t size, size_t alignment)
{
if (size == 0)
return nullptr;
assert((alignment & (alignment-1)) == 0);
void* ptr = _mm_malloc(size, alignment);
if (ptr == nullptr)
throw std::bad_alloc();
return ptr;
}
void alignedFree(void* ptr)
{
if (ptr)
_mm_free(ptr);
}
// ----------------------------------------------------------------------------
// System information
// ----------------------------------------------------------------------------
std::string getPlatformName()
{
std::string name;
#if defined(__linux__)
name = "Linux";
#elif defined(__FreeBSD__)
name = "FreeBSD";
#elif defined(__CYGWIN__)
name = "Cygwin";
#elif defined(_WIN32)
name = "Windows";
#elif defined(__APPLE__)
name = "macOS";
#elif defined(__unix__)
name = "Unix";
#else
return "Unknown";
#endif
#if defined(__x86_64__) || defined(_M_X64) || defined(__ia64__) || defined(__aarch64__)
name += " (64-bit)";
#else
name += " (32-bit)";
#endif
return name;
}
std::string getCompilerName()
{
#if defined(__INTEL_COMPILER)
int mayor = __INTEL_COMPILER / 100 % 100;
int minor = __INTEL_COMPILER % 100;
std::string version = "Intel Compiler ";
version += toString(mayor);
version += "." + toString(minor);
#if defined(__INTEL_COMPILER_UPDATE)
version += "." + toString(__INTEL_COMPILER_UPDATE);
#endif
return version;
#elif defined(__clang__)
return "Clang " __clang_version__;
#elif defined(__GNUC__)
return "GCC " __VERSION__;
#elif defined(_MSC_VER)
std::string version = toString(_MSC_FULL_VER);
version.insert(4, ".");
version.insert(9, ".");
version.insert(2, ".");
return "Visual C++ Compiler " + version;
#else
return "Unknown";
#endif
}
std::string getBuildName()
{
#if defined(NDEBUG)
return "Release";
#else
return "Debug";
#endif
}
} // namespace oidn

@ -1,131 +0,0 @@
// ======================================================================== //
// Copyright 2009-2019 Intel Corporation //
// //
// Licensed under the Apache License, Version 2.0 (the "License"); //
// you may not use this file except in compliance with the License. //
// You may obtain a copy of the License at //
// //
// http://www.apache.org/licenses/LICENSE-2.0 //
// //
// Unless required by applicable law or agreed to in writing, software //
// distributed under the License is distributed on an "AS IS" BASIS, //
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. //
// See the License for the specific language governing permissions and //
// limitations under the License. //
// ======================================================================== //
#pragma once
#if defined(_WIN32)
#define WIN32_LEAN_AND_MEAN
#define NOMINMAX
#include <windows.h>
#elif defined(__APPLE__)
#include <sys/sysctl.h>
#endif
#include <xmmintrin.h>
#include <cstdint>
#include <climits>
#include <limits>
#include <atomic>
#include <algorithm>
#include <memory>
#include <cmath>
#include <string>
#include <sstream>
#include <iostream>
#include <cassert>
#include "include/OpenImageDenoise/oidn.hpp"
namespace oidn {
// ----------------------------------------------------------------------------
// Macros
// ----------------------------------------------------------------------------
#if defined(_WIN32)
// Windows
#if !defined(__noinline)
#define __noinline __declspec(noinline)
#endif
#else
// Unix
#if !defined(__forceinline)
#define __forceinline inline __attribute__((always_inline))
#endif
#if !defined(__noinline)
#define __noinline __attribute__((noinline))
#endif
#endif
#ifndef UNUSED
#define UNUSED(x) ((void)x)
#endif
#ifndef MAYBE_UNUSED
#define MAYBE_UNUSED(x) UNUSED(x)
#endif
// ----------------------------------------------------------------------------
// Error handling and debugging
// ----------------------------------------------------------------------------
struct Verbose
{
int verbose;
Verbose(int v = 0) : verbose(v) {}
__forceinline bool isVerbose(int v = 1) const { return v <= verbose; }
};
#define OIDN_WARNING(message) { if (isVerbose()) std::cerr << "Warning: " << message << std::endl; }
#define OIDN_FATAL(message) throw std::runtime_error(message);
// ----------------------------------------------------------------------------
// Common functions
// ----------------------------------------------------------------------------
using std::min;
using std::max;
template<typename T>
__forceinline T clamp(const T& value, const T& minValue, const T& maxValue)
{
return min(max(value, minValue), maxValue);
}
void* alignedMalloc(size_t size, size_t alignment);
void alignedFree(void* ptr);
template<typename T>
inline std::string toString(const T& a)
{
std::stringstream sm;
sm << a;
return sm.str();
}
#if defined(__APPLE__)
template<typename T>
bool getSysctl(const char* name, T& value)
{
int64_t result = 0;
size_t size = sizeof(result);
if (sysctlbyname(name, &result, &size, nullptr, 0) != 0)
return false;
value = T(result);
return true;
}
#endif
// ----------------------------------------------------------------------------
// System information
// ----------------------------------------------------------------------------
std::string getPlatformName();
std::string getCompilerName();
std::string getBuildName();
} // namespace oidn

@ -1,163 +0,0 @@
// ======================================================================== //
// Copyright 2009-2019 Intel Corporation //
// //
// Licensed under the Apache License, Version 2.0 (the "License"); //
// you may not use this file except in compliance with the License. //
// You may obtain a copy of the License at //
// //
// http://www.apache.org/licenses/LICENSE-2.0 //
// //
// Unless required by applicable law or agreed to in writing, software //
// distributed under the License is distributed on an "AS IS" BASIS, //
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. //
// See the License for the specific language governing permissions and //
// limitations under the License. //
// ======================================================================== //
#pragma once
#include "platform.h"
namespace oidn {
class RefCount
{
private:
std::atomic<size_t> count;
public:
__forceinline RefCount(int count = 0) noexcept : count(count) {}
__forceinline size_t incRef() noexcept
{
return count.fetch_add(1) + 1;
}
__forceinline size_t decRef()
{
const size_t newCount = decRefKeep();
if (newCount == 0)
destroy();
return newCount;
}
__forceinline size_t decRefKeep() noexcept
{
return count.fetch_sub(1) - 1;
}
__forceinline void destroy()
{
delete this;
}
protected:
// Disable copying
RefCount(const RefCount&) = delete;
RefCount& operator =(const RefCount&) = delete;
virtual ~RefCount() noexcept = default;
};
template<typename T>
class Ref
{
private:
T* ptr;
public:
__forceinline Ref() noexcept : ptr(nullptr) {}
__forceinline Ref(std::nullptr_t) noexcept : ptr(nullptr) {}
__forceinline Ref(const Ref& other) noexcept : ptr(other.ptr) { if (ptr) ptr->incRef(); }
__forceinline Ref(Ref&& other) noexcept : ptr(other.ptr) { other.ptr = nullptr; }
__forceinline Ref(T* ptr) noexcept : ptr(ptr) { if (ptr) ptr->incRef(); }
template<typename Y>
__forceinline Ref(const Ref<Y>& other) noexcept : ptr(other.get()) { if (ptr) ptr->incRef(); }
template<typename Y>
__forceinline explicit Ref(Y* ptr) noexcept : ptr(ptr) { if (ptr) ptr->incRef(); }
__forceinline ~Ref() { if (ptr) ptr->decRef(); }
__forceinline Ref& operator =(const Ref& other)
{
if (other.ptr)
other.ptr->incRef();
if (ptr)
ptr->decRef();
ptr = other.ptr;
return *this;
}
__forceinline Ref& operator =(Ref&& other)
{
if (ptr)
ptr->decRef();
ptr = other.ptr;
other.ptr = nullptr;
return *this;
}
__forceinline Ref& operator =(T* other)
{
if (other)
other->incRef();
if (ptr)
ptr->decRef();
ptr = other;
return *this;
}
__forceinline Ref& operator =(std::nullptr_t)
{
if (ptr)
ptr->decRef();
ptr = nullptr;
return *this;
}
__forceinline operator bool() const noexcept { return ptr != nullptr; }
__forceinline T& operator *() const noexcept { return *ptr; }
__forceinline T* operator ->() const noexcept { return ptr; }
__forceinline T* get() const noexcept { return ptr; }
__forceinline T* detach() noexcept
{
T* res = ptr;
ptr = nullptr;
return res;
}
};
// Ref::ptr is private, so the comparison operators go through get()
template<typename T> __forceinline bool operator < (const Ref<T>& a, const Ref<T>& b) noexcept { return a.get() < b.get(); }
template<typename T> __forceinline bool operator ==(const Ref<T>& a, std::nullptr_t) noexcept { return a.get() == nullptr; }
template<typename T> __forceinline bool operator ==(std::nullptr_t, const Ref<T>& b) noexcept { return nullptr == b.get(); }
template<typename T> __forceinline bool operator ==(const Ref<T>& a, const Ref<T>& b) noexcept { return a.get() == b.get(); }
template<typename T> __forceinline bool operator !=(const Ref<T>& a, std::nullptr_t) noexcept { return a.get() != nullptr; }
template<typename T> __forceinline bool operator !=(std::nullptr_t, const Ref<T>& b) noexcept { return nullptr != b.get(); }
template<typename T> __forceinline bool operator !=(const Ref<T>& a, const Ref<T>& b) noexcept { return a.get() != b.get(); }
template<typename T, typename... Args>
__forceinline Ref<T> makeRef(Args&&... args)
{
return Ref<T>(new T(std::forward<Args>(args)...));
}
template<typename T, typename Y>
__forceinline Ref<Y> staticRefCast(const Ref<T>& a)
{
return Ref<Y>(static_cast<Y*>(a.get()));
}
template<typename T, typename Y>
__forceinline Ref<Y> dynamicRefCast(const Ref<T>& a)
{
return Ref<Y>(dynamic_cast<Y*>(a.get()));
}
} // namespace oidn
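The `RefCount`/`Ref` pair above is an intrusive reference count: the object stores its own counter and the handle only increments and decrements it. The following is a minimal standalone sketch of that idea, not the OIDN code itself. `Counted`, `Handle`, `Probe` and `lastOwnerDestroys` are hypothetical names introduced for illustration, and the assignment operator uses copy-and-swap instead of the four separate overloads in the removed header.

```cpp
#include <atomic>
#include <cstddef>
#include <utility>

// Hypothetical base class: the object stores its own reference count.
class Counted
{
public:
  void incRef() noexcept { count.fetch_add(1); }
  void decRef() noexcept
  {
    // fetch_sub returns the previous value; 1 means this was the last owner
    if (count.fetch_sub(1) == 1)
      delete this;
  }
  std::size_t useCount() const noexcept { return count.load(); }

protected:
  Counted() = default;
  virtual ~Counted() = default;

private:
  std::atomic<std::size_t> count{0};
};

// Hypothetical handle: copying shares ownership, moving transfers it.
template<typename T>
class Handle
{
public:
  Handle() noexcept = default;
  explicit Handle(T* p) noexcept : ptr(p) { if (ptr) ptr->incRef(); }
  Handle(const Handle& o) noexcept : ptr(o.ptr) { if (ptr) ptr->incRef(); }
  Handle(Handle&& o) noexcept : ptr(o.ptr) { o.ptr = nullptr; }
  ~Handle() { if (ptr) ptr->decRef(); }

  // Copy-and-swap: the by-value parameter covers copy and move assignment
  Handle& operator=(Handle o) noexcept
  {
    std::swap(ptr, o.ptr);
    return *this;
  }

  T* get() const noexcept { return ptr; }
  T* operator->() const noexcept { return ptr; }

private:
  T* ptr = nullptr;
};

// Example payload that reports its destruction through a flag.
struct Probe : Counted
{
  explicit Probe(bool& flag) : destroyed(flag) {}
  ~Probe() override { destroyed = true; }
  bool& destroyed;
};

// Returns true if the payload is destroyed only when the last handle goes away.
bool lastOwnerDestroys()
{
  bool destroyed = false;
  {
    Handle<Probe> a(new Probe(destroyed));
    {
      Handle<Probe> b = a; // two owners
      if (a->useCount() != 2) return false;
    }
    if (destroyed) return false; // one owner is still alive
  }
  return destroyed;
}
```

Taking the assignment argument by value keeps self-assignment safe without an explicit check, because the increment happens in the copy before the old pointer is released.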

@ -1,83 +0,0 @@
// ======================================================================== //
// Copyright 2009-2019 Intel Corporation //
// //
// Licensed under the Apache License, Version 2.0 (the "License"); //
// you may not use this file except in compliance with the License. //
// You may obtain a copy of the License at //
// //
// http://www.apache.org/licenses/LICENSE-2.0 //
// //
// Unless required by applicable law or agreed to in writing, software //
// distributed under the License is distributed on an "AS IS" BASIS, //
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. //
// See the License for the specific language governing permissions and //
// limitations under the License. //
// ======================================================================== //
#include "exception.h"
#include "tensor.h"
namespace oidn {
std::map<std::string, Tensor> parseTensors(void* buffer)
{
char* input = (char*)buffer;
// Parse the magic value
const int magic = *(unsigned short*)input;
if (magic != 0x41D7)
throw Exception(Error::InvalidOperation, "invalid tensor archive");
input += sizeof(unsigned short);
// Parse the version
const int majorVersion = *(unsigned char*)input++;
const int minorVersion = *(unsigned char*)input++;
UNUSED(minorVersion);
if (majorVersion > 1)
throw Exception(Error::InvalidOperation, "unsupported tensor archive version");
// Parse the number of tensors
const int numTensors = *(int*)input;
input += sizeof(int);
// Parse the tensors
std::map<std::string, Tensor> tensorMap;
for (int i = 0; i < numTensors; ++i)
{
Tensor tensor;
// Parse the name
const int nameLen = *(unsigned char*)input++;
std::string name(input, nameLen);
input += nameLen;
// Parse the number of dimensions
const int ndims = *(unsigned char*)input++;
// Parse the shape of the tensor
tensor.dims.resize(ndims);
for (int d = 0; d < ndims; ++d) // 'd' avoids shadowing the tensor index 'i'
tensor.dims[d] = ((int*)input)[d];
input += ndims * sizeof(int);
// Parse the format of the tensor
tensor.format = std::string(input, input + ndims);
input += ndims;
// Parse the data type of the tensor
const char type = *(unsigned char*)input++;
if (type != 'f') // only float32 is supported
throw Exception(Error::InvalidOperation, "unsupported tensor data type");
// Reference the data in place (no copy), then skip past it
tensor.data = (float*)input;
input += tensor.size() * sizeof(float);
// Add the tensor to the map
tensorMap.emplace(name, std::move(tensor));
}
return tensorMap;
}
} // namespace oidn
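`parseTensors` above reads the archive by casting the raw pointer and never checks how many bytes remain, so a truncated buffer is read out of bounds. The sketch below walks the same layout visible in that function (16-bit magic `0x41D7`, major and minor version bytes, 32-bit tensor count, then per tensor: name length, name, dimension count, 32-bit dimensions, format characters, type byte, float data) with a bounds-checked cursor. It is an illustration under assumptions, not the OIDN implementation: `Reader`, `TensorInfo` and `readTensorInfos` are hypothetical names, values are read in native byte order as in the original, and overflow of the element count is not handled.

```cpp
#include <cstddef>
#include <cstdint>
#include <cstring>
#include <stdexcept>
#include <string>
#include <utility>
#include <vector>

// Hypothetical cursor over a byte buffer; every read is bounds-checked.
class Reader
{
public:
  Reader(const unsigned char* data, std::size_t size) : cur(data), end(data + size) {}

  template<typename T>
  T read()
  {
    need(sizeof(T));
    T value;
    std::memcpy(&value, cur, sizeof(T)); // avoids unaligned pointer casts
    cur += sizeof(T);
    return value;
  }

  std::string readString(std::size_t len)
  {
    need(len);
    std::string s(reinterpret_cast<const char*>(cur), len);
    cur += len;
    return s;
  }

  void skip(std::size_t len) { need(len); cur += len; }

private:
  void need(std::size_t len) const
  {
    if (static_cast<std::size_t>(end - cur) < len)
      throw std::runtime_error("truncated tensor archive");
  }

  const unsigned char* cur;
  const unsigned char* end;
};

// Hypothetical summary of one archive entry (metadata only, no data pointer).
struct TensorInfo
{
  std::string name;
  std::vector<std::int32_t> dims;
  std::string format;
};

std::vector<TensorInfo> readTensorInfos(const unsigned char* data, std::size_t size)
{
  Reader in(data, size);
  if (in.read<std::uint16_t>() != 0x41D7)
    throw std::runtime_error("invalid tensor archive");
  const int majorVersion = in.read<std::uint8_t>();
  in.read<std::uint8_t>(); // minor version, unused
  if (majorVersion > 1)
    throw std::runtime_error("unsupported tensor archive version");
  const std::int32_t numTensors = in.read<std::int32_t>();

  std::vector<TensorInfo> infos;
  for (std::int32_t t = 0; t < numTensors; ++t)
  {
    TensorInfo info;
    const std::size_t nameLen = in.read<std::uint8_t>();
    info.name = in.readString(nameLen);
    const std::size_t ndims = in.read<std::uint8_t>();
    std::size_t count = 1; // number of float values in this tensor
    for (std::size_t d = 0; d < ndims; ++d)
    {
      const std::int32_t dim = in.read<std::int32_t>();
      if (dim < 0)
        throw std::runtime_error("negative tensor dimension");
      info.dims.push_back(dim);
      count *= static_cast<std::size_t>(dim);
    }
    info.format = in.readString(ndims);
    if (in.read<char>() != 'f') // only float32 is supported
      throw std::runtime_error("unsupported tensor data type");
    in.skip(count * sizeof(float)); // step over the payload
    infos.push_back(std::move(info));
  }
  return infos;
}
```

Passing the buffer size alongside the pointer is the main design difference from the removed function, which receives only a `void*` and therefore cannot detect truncation.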

@ -1,66 +0,0 @@
// ======================================================================== //
// Copyright 2009-2019 Intel Corporation //
// //
// Licensed under the Apache License, Version 2.0 (the "License"); //
// you may not use this file except in compliance with the License. //
// You may obtain a copy of the License at //
// //
// http://www.apache.org/licenses/LICENSE-2.0 //
// //
// Unless required by applicable law or agreed to in writing, software //
// distributed under the License is distributed on an "AS IS" BASIS, //
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. //
// See the License for the specific language governing permissions and //
// limitations under the License. //
// ======================================================================== //
#pragma once
#include "platform.h"
#include <vector>
#include <map>
namespace oidn {
template<typename T>
using shared_vector = std::shared_ptr<std::vector<T>>;
// Generic tensor
struct Tensor
{
float* data;
std::vector<int64_t> dims;
std::string format;
shared_vector<char> buffer; // optional, only for reference counting
__forceinline Tensor() : data(nullptr) {}
__forceinline Tensor(const std::vector<int64_t>& dims, const std::string& format)
: dims(dims),
format(format)
{
buffer = std::make_shared<std::vector<char>>(size() * sizeof(float));
data = (float*)buffer->data();
}
__forceinline operator bool() const { return data != nullptr; }
__forceinline int ndims() const { return (int)dims.size(); }
// Returns the number of values
__forceinline size_t size() const
{
size_t size = 1;
for (int i = 0; i < ndims(); ++i)
size *= dims[i];
return size;
}
__forceinline float& operator [](size_t i) { return data[i]; }
__forceinline const float& operator [](size_t i) const { return data[i]; }
};
// Parses tensors from a buffer
std::map<std::string, Tensor> parseTensors(void* buffer);
} // namespace oidn

@ -1,297 +0,0 @@
// ======================================================================== //
// Copyright 2009-2019 Intel Corporation //
// //
// Licensed under the Apache License, Version 2.0 (the "License"); //
// you may not use this file except in compliance with the License. //
// You may obtain a copy of the License at //
// //
// http://www.apache.org/licenses/LICENSE-2.0 //
// //
// Unless required by applicable law or agreed to in writing, software //
// distributed under the License is distributed on an "AS IS" BASIS, //
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. //
// See the License for the specific language governing permissions and //
// limitations under the License. //
// ======================================================================== //
#if defined(_MSC_VER)
#pragma warning (disable : 4146) // unary minus operator applied to unsigned type, result still unsigned
#endif
#if defined(__APPLE__)
#include <mach/thread_act.h>
#include <mach/mach_init.h>
#endif
#include "thread.h"
#include <fstream>
namespace oidn {
#if defined(_WIN32)
// --------------------------------------------------------------------------
// ThreadAffinity - Windows
// --------------------------------------------------------------------------
ThreadAffinity::ThreadAffinity(int numThreadsPerCore, int verbose)
: Verbose(verbose)
{
HMODULE hLib = GetModuleHandle(TEXT("kernel32"));
pGetLogicalProcessorInformationEx = (GetLogicalProcessorInformationExFunc)GetProcAddress(hLib, "GetLogicalProcessorInformationEx");
pSetThreadGroupAffinity = (SetThreadGroupAffinityFunc)GetProcAddress(hLib, "SetThreadGroupAffinity");
if (pGetLogicalProcessorInformationEx && pSetThreadGroupAffinity)
{
// Get logical processor information
PSYSTEM_LOGICAL_PROCESSOR_INFORMATION_EX buffer = nullptr;
DWORD bufferSize = 0;
// First call the function with an empty buffer to get the required buffer size
BOOL result = pGetLogicalProcessorInformationEx(RelationProcessorCore, buffer, &bufferSize);
if (result || GetLastError() != ERROR_INSUFFICIENT_BUFFER)
{
OIDN_WARNING("GetLogicalProcessorInformationEx failed");
return;
}
// Allocate the buffer
buffer = (PSYSTEM_LOGICAL_PROCESSOR_INFORMATION_EX)malloc(bufferSize);
if (!buffer)
{
OIDN_WARNING("SYSTEM_LOGICAL_PROCESSOR_INFORMATION_EX allocation failed");
return;
}
// Call the function again, now with a properly sized buffer
result = pGetLogicalProcessorInformationEx(RelationProcessorCore, buffer, &bufferSize);
if (!result)
{
OIDN_WARNING("GetLogicalProcessorInformationEx failed");
free(buffer);
return;
}
// Iterate over the logical processor information structures
// There should be one structure for each physical core
char* ptr = (char*)buffer;
while (ptr < (char*)buffer + bufferSize)
{
PSYSTEM_LOGICAL_PROCESSOR_INFORMATION_EX item = (PSYSTEM_LOGICAL_PROCESSOR_INFORMATION_EX)ptr;
if (item->Relationship == RelationProcessorCore && item->Processor.GroupCount > 0)
{
// Iterate over the groups
int numThreads = 0;
for (int group = 0; (group < item->Processor.GroupCount) && (numThreads < numThreadsPerCore); ++group)
{
GROUP_AFFINITY coreAffinity = item->Processor.GroupMask[group];
while ((coreAffinity.Mask != 0) && (numThreads < numThreadsPerCore))
{
// Extract the next set bit/thread from the mask
GROUP_AFFINITY threadAffinity = coreAffinity;
threadAffinity.Mask = threadAffinity.Mask & -threadAffinity.Mask;
// Push the affinity for this thread
affinities.push_back(threadAffinity);
oldAffinities.push_back(threadAffinity);
numThreads++;
// Remove this bit/thread from the mask
coreAffinity.Mask ^= threadAffinity.Mask;
}
}
}
// Next structure
ptr += item->Size;
}
// Free the buffer
free(buffer);
}
}
void ThreadAffinity::set(int threadIndex)
{
if (threadIndex >= (int)affinities.size())
return;
// Save the current affinity and set the new one
const HANDLE thread = GetCurrentThread();
if (!pSetThreadGroupAffinity(thread, &affinities[threadIndex], &oldAffinities[threadIndex]))
OIDN_WARNING("SetThreadGroupAffinity failed");
}
void ThreadAffinity::restore(int threadIndex)
{
if (threadIndex >= (int)affinities.size())
return;
// Restore the original affinity
const HANDLE thread = GetCurrentThread();
if (!pSetThreadGroupAffinity(thread, &oldAffinities[threadIndex], nullptr))
OIDN_WARNING("SetThreadGroupAffinity failed");
}
#elif defined(__linux__)
// --------------------------------------------------------------------------
// ThreadAffinity - Linux
// --------------------------------------------------------------------------
ThreadAffinity::ThreadAffinity(int numThreadsPerCore, int verbose)
: Verbose(verbose)
{
std::vector<int> threadIds;
// Parse the thread/CPU topology
for (int cpuId = 0; ; cpuId++)
{
std::fstream fs;
std::string cpu = std::string("/sys/devices/system/cpu/cpu") + std::to_string(cpuId) + std::string("/topology/thread_siblings_list");
fs.open(cpu.c_str(), std::fstream::in);
if (fs.fail()) break;
int i;
int j = 0;
while ((j < numThreadsPerCore) && (fs >> i))
{
if (std::none_of(threadIds.begin(), threadIds.end(), [&](int id) { return id == i; }))
threadIds.push_back(i);
if (fs.peek() == ',')
fs.ignore();
j++;
}
fs.close();
}
#if 0
for (size_t i = 0; i < threadIds.size(); ++i)
std::cout << "thread " << i << " -> " << threadIds[i] << std::endl;
#endif
// Create the affinity structures
affinities.resize(threadIds.size());
oldAffinities.resize(threadIds.size());
for (size_t i = 0; i < threadIds.size(); ++i)
{
cpu_set_t affinity;
CPU_ZERO(&affinity);
CPU_SET(threadIds[i], &affinity);
affinities[i] = affinity;
oldAffinities[i] = affinity;
}
}
void ThreadAffinity::set(int threadIndex)
{
if (threadIndex >= (int)affinities.size())
return;
const pthread_t thread = pthread_self();
// Save the current affinity
if (pthread_getaffinity_np(thread, sizeof(cpu_set_t), &oldAffinities[threadIndex]) != 0)
{
OIDN_WARNING("pthread_getaffinity_np failed");
oldAffinities[threadIndex] = affinities[threadIndex];
return;
}
// Set the new affinity
if (pthread_setaffinity_np(thread, sizeof(cpu_set_t), &affinities[threadIndex]) != 0)
OIDN_WARNING("pthread_setaffinity_np failed");
}
void ThreadAffinity::restore(int threadIndex)
{
if (threadIndex >= (int)affinities.size())
return;
const pthread_t thread = pthread_self();
// Restore the original affinity
if (pthread_setaffinity_np(thread, sizeof(cpu_set_t), &oldAffinities[threadIndex]) != 0)
OIDN_WARNING("pthread_setaffinity_np failed");
}
#elif defined(__APPLE__)
// --------------------------------------------------------------------------
// ThreadAffinity - macOS
// --------------------------------------------------------------------------
ThreadAffinity::ThreadAffinity(int numThreadsPerCore, int verbose)
: Verbose(verbose)
{
// Query the thread/CPU topology
int numPhysicalCpus;
int numLogicalCpus;
if (!getSysctl("hw.physicalcpu", numPhysicalCpus) || !getSysctl("hw.logicalcpu", numLogicalCpus))
{
OIDN_WARNING("sysctlbyname failed");
return;
}
if ((numLogicalCpus % numPhysicalCpus != 0) && (numThreadsPerCore > 1))
return; // this shouldn't happen
const int maxThreadsPerCore = numLogicalCpus / numPhysicalCpus;
// Create the affinity structures
// macOS doesn't support binding a thread to a specific core, but we can at least group threads which
// should be on the same core together
for (int core = 1; core <= numPhysicalCpus; ++core) // tags start from 1!
{
thread_affinity_policy affinity;
affinity.affinity_tag = core;
for (int thread = 0; thread < min(numThreadsPerCore, maxThreadsPerCore); ++thread)
{
affinities.push_back(affinity);
oldAffinities.push_back(affinity);
}
}
}
void ThreadAffinity::set(int threadIndex)
{
if (threadIndex >= (int)affinities.size())
return;
const auto thread = mach_thread_self();
// Save the current affinity
mach_msg_type_number_t policyCount = THREAD_AFFINITY_POLICY_COUNT;
boolean_t getDefault = FALSE;
if (thread_policy_get(thread, THREAD_AFFINITY_POLICY, (thread_policy_t)&oldAffinities[threadIndex], &policyCount, &getDefault) != KERN_SUCCESS)
{
OIDN_WARNING("thread_policy_get failed");
oldAffinities[threadIndex] = affinities[threadIndex];
return;
}
// Set the new affinity
if (thread_policy_set(thread, THREAD_AFFINITY_POLICY, (thread_policy_t)&affinities[threadIndex], THREAD_AFFINITY_POLICY_COUNT) != KERN_SUCCESS)
OIDN_WARNING("thread_policy_set failed");
}
void ThreadAffinity::restore(int threadIndex)
{
if (threadIndex >= (int)affinities.size())
return;
const auto thread = mach_thread_self();
// Restore the original affinity
if (thread_policy_set(thread, THREAD_AFFINITY_POLICY, (thread_policy_t)&oldAffinities[threadIndex], THREAD_AFFINITY_POLICY_COUNT) != KERN_SUCCESS)
OIDN_WARNING("thread_policy_set failed");
}
#endif
} // namespace oidn
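The Windows constructor above peels one logical processor at a time off a core's affinity mask with `mask & -mask`, which isolates the lowest set bit, and then clears that bit with XOR. A small standalone sketch of the same bit trick follows; `splitMask` is a hypothetical helper, not part of OIDN. It writes the negation as `~mask + 1`, which is the same value for unsigned integers and does not trigger the MSVC warning C4146 that the removed file disables.

```cpp
#include <cstdint>
#include <vector>

// Splits a bit mask into single-bit masks, lowest set bit first.
std::vector<std::uint64_t> splitMask(std::uint64_t mask)
{
  std::vector<std::uint64_t> bits;
  while (mask != 0)
  {
    // Two's-complement negation: only the lowest set bit survives the AND
    const std::uint64_t lowest = mask & (~mask + 1);
    bits.push_back(lowest);
    mask ^= lowest; // remove that bit and continue with the rest
  }
  return bits;
}
```

For example, a mask of `0x2C` (bits 2, 3 and 5 set) would be split into `0x04`, `0x08` and `0x20`, one entry per logical processor in the removed code's terms.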

@ -1,202 +0,0 @@
// ======================================================================== //
// Copyright 2009-2019 Intel Corporation //
// //
// Licensed under the Apache License, Version 2.0 (the "License"); //
// you may not use this file except in compliance with the License. //
// You may obtain a copy of the License at //
// //
// http://www.apache.org/licenses/LICENSE-2.0 //
// //
// Unless required by applicable law or agreed to in writing, software //
// distributed under the License is distributed on an "AS IS" BASIS, //
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. //
// See the License for the specific language governing permissions and //
// limitations under the License. //
// ======================================================================== //
#pragma once
#include "platform.h"
#if !defined(_WIN32)
#include <pthread.h>
#include <sched.h>
#if defined(__APPLE__)
#include <mach/thread_policy.h>
#endif
#endif
#include <vector>
#include <mutex>
namespace oidn {
// --------------------------------------------------------------------------
// ThreadLocal
// --------------------------------------------------------------------------
// Wrapper which makes any variable thread-local
template<typename T>
class ThreadLocal : public Verbose
{
private:
#if defined(_WIN32)
DWORD key;
#else
pthread_key_t key;
#endif
std::vector<T*> instances;
std::mutex mutex;
public:
ThreadLocal(int verbose = 0)
: Verbose(verbose)
{
#if defined(_WIN32)
key = TlsAlloc();
if (key == TLS_OUT_OF_INDEXES)
OIDN_FATAL("TlsAlloc failed");
#else
if (pthread_key_create(&key, nullptr) != 0)
OIDN_FATAL("pthread_key_create failed");
#endif
}
~ThreadLocal()
{
std::lock_guard<std::mutex> lock(mutex);
for (T* ptr : instances)
delete ptr;
#if defined(_WIN32)
if (!TlsFree(key))
OIDN_WARNING("TlsFree failed");
#else
if (pthread_key_delete(key) != 0)
OIDN_WARNING("pthread_key_delete failed");
#endif
}
T& get()
{
#if defined(_WIN32)
T* ptr = (T*)TlsGetValue(key);
#else
T* ptr = (T*)pthread_getspecific(key);
#endif
if (ptr)
return *ptr;
ptr = new T;
std::lock_guard<std::mutex> lock(mutex);
instances.push_back(ptr);
#if defined(_WIN32)
if (!TlsSetValue(key, ptr))
OIDN_FATAL("TlsSetValue failed");
#else
if (pthread_setspecific(key, ptr) != 0)
OIDN_FATAL("pthread_setspecific failed");
#endif
return *ptr;
}
};
#if defined(_WIN32)
// --------------------------------------------------------------------------
// ThreadAffinity - Windows
// --------------------------------------------------------------------------
class ThreadAffinity : public Verbose
{
private:
typedef BOOL (WINAPI *GetLogicalProcessorInformationExFunc)(LOGICAL_PROCESSOR_RELATIONSHIP,
PSYSTEM_LOGICAL_PROCESSOR_INFORMATION_EX,
PDWORD);
typedef BOOL (WINAPI *SetThreadGroupAffinityFunc)(HANDLE,
CONST GROUP_AFFINITY*,
PGROUP_AFFINITY);
GetLogicalProcessorInformationExFunc pGetLogicalProcessorInformationEx = nullptr;
SetThreadGroupAffinityFunc pSetThreadGroupAffinity = nullptr;
std::vector<GROUP_AFFINITY> affinities; // thread affinities
std::vector<GROUP_AFFINITY> oldAffinities; // original thread affinities
public:
ThreadAffinity(int numThreadsPerCore = INT_MAX, int verbose = 0);
int getNumThreads() const
{
return (int)affinities.size();
}
// Sets the affinity (0..numThreads-1) of the thread after saving the current affinity
void set(int threadIndex);
// Restores the affinity of the thread
void restore(int threadIndex);
};
#elif defined(__linux__)
// --------------------------------------------------------------------------
// ThreadAffinity - Linux
// --------------------------------------------------------------------------
class ThreadAffinity : public Verbose
{
private:
std::vector<cpu_set_t> affinities; // thread affinities
std::vector<cpu_set_t> oldAffinities; // original thread affinities
public:
ThreadAffinity(int numThreadsPerCore = INT_MAX, int verbose = 0);
int getNumThreads() const
{
return (int)affinities.size();
}
// Sets the affinity (0..numThreads-1) of the thread after saving the current affinity
void set(int threadIndex);
// Restores the affinity of the thread
void restore(int threadIndex);
};
#elif defined(__APPLE__)
// --------------------------------------------------------------------------
// ThreadAffinity - macOS
// --------------------------------------------------------------------------
class ThreadAffinity : public Verbose
{
private:
std::vector<thread_affinity_policy> affinities; // thread affinities
std::vector<thread_affinity_policy> oldAffinities; // original thread affinities
public:
ThreadAffinity(int numThreadsPerCore = INT_MAX, int verbose = 0);
int getNumThreads() const
{
return (int)affinities.size();
}
// Sets the affinity (0..numThreads-1) of the thread after saving the current affinity
void set(int threadIndex);
// Restores the affinity of the thread
void restore(int threadIndex);
};
#endif
} // namespace oidn

@ -1,49 +0,0 @@
// ======================================================================== //
// Copyright 2009-2019 Intel Corporation //
// //
// Licensed under the Apache License, Version 2.0 (the "License"); //
// you may not use this file except in compliance with the License. //
// You may obtain a copy of the License at //
// //
// http://www.apache.org/licenses/LICENSE-2.0 //
// //
// Unless required by applicable law or agreed to in writing, software //
// distributed under the License is distributed on an "AS IS" BASIS, //
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. //
// See the License for the specific language governing permissions and //
// limitations under the License. //
// ======================================================================== //
#pragma once
#include "platform.h"
#include <chrono>
namespace oidn {
class Timer
{
private:
using clock = std::chrono::high_resolution_clock;
std::chrono::time_point<clock> start;
public:
Timer()
{
reset();
}
void reset()
{
start = clock::now();
}
double query() const
{
auto end = clock::now();
return std::chrono::duration_cast<std::chrono::duration<double>>(end - start).count();
}
};
} // namespace oidn

@ -1,408 +0,0 @@
// ======================================================================== //
// Copyright 2009-2019 Intel Corporation //
// //
// Licensed under the Apache License, Version 2.0 (the "License"); //
// you may not use this file except in compliance with the License. //
// You may obtain a copy of the License at //
// //
// http://www.apache.org/licenses/LICENSE-2.0 //
// //
// Unless required by applicable law or agreed to in writing, software //
// distributed under the License is distributed on an "AS IS" BASIS, //
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. //
// See the License for the specific language governing permissions and //
// limitations under the License. //
// ======================================================================== //
#ifdef _WIN32
# define OIDN_API extern "C" __declspec(dllexport)
#else
# define OIDN_API extern "C" __attribute__ ((visibility ("default")))
#endif
// Locks the device that owns the specified object
// Use *only* inside OIDN_TRY/CATCH!
#define OIDN_LOCK(obj) \
std::lock_guard<std::mutex> lock(obj->getDevice()->getMutex());
// Try/catch for converting exceptions to errors
#define OIDN_TRY \
try {
#define OIDN_CATCH(obj) \
} catch (Exception& e) { \
Device::setError(obj ? obj->getDevice() : nullptr, e.code(), e.what()); \
} catch (std::bad_alloc&) { \
Device::setError(obj ? obj->getDevice() : nullptr, Error::OutOfMemory, "out of memory"); \
} catch (mkldnn::error& e) { \
if (e.status == mkldnn_out_of_memory) \
Device::setError(obj ? obj->getDevice() : nullptr, Error::OutOfMemory, "out of memory"); \
else \
Device::setError(obj ? obj->getDevice() : nullptr, Error::Unknown, e.message); \
} catch (std::exception& e) { \
Device::setError(obj ? obj->getDevice() : nullptr, Error::Unknown, e.what()); \
} catch (...) { \
Device::setError(obj ? obj->getDevice() : nullptr, Error::Unknown, "unknown exception caught"); \
}
#include "device.h"
#include "filter.h"
#include <mutex>
namespace oidn {
namespace
{
__forceinline void checkHandle(void* handle)
{
if (handle == nullptr)
throw Exception(Error::InvalidArgument, "invalid handle");
}
template<typename T>
__forceinline void retainObject(T* obj)
{
if (obj)
{
obj->incRef();
}
else
{
OIDN_TRY
checkHandle(obj);
OIDN_CATCH(obj)
}
}
template<typename T>
__forceinline void releaseObject(T* obj)
{
if (obj == nullptr || obj->decRefKeep() == 0)
{
OIDN_TRY
checkHandle(obj);
OIDN_LOCK(obj);
obj->destroy();
OIDN_CATCH(obj)
}
}
template<>
__forceinline void releaseObject(Device* obj)
{
if (obj == nullptr || obj->decRefKeep() == 0)
{
OIDN_TRY
checkHandle(obj);
// Do NOT lock the device because it owns the mutex
obj->destroy();
OIDN_CATCH(obj)
}
}
}
OIDN_API OIDNDevice oidnNewDevice(OIDNDeviceType type)
{
Ref<Device> device = nullptr;
OIDN_TRY
if (type == OIDN_DEVICE_TYPE_CPU || type == OIDN_DEVICE_TYPE_DEFAULT)
device = makeRef<Device>();
else
throw Exception(Error::InvalidArgument, "invalid device type");
OIDN_CATCH(device)
return (OIDNDevice)device.detach();
}
OIDN_API void oidnRetainDevice(OIDNDevice hDevice)
{
Device* device = (Device*)hDevice;
retainObject(device);
}
OIDN_API void oidnReleaseDevice(OIDNDevice hDevice)
{
Device* device = (Device*)hDevice;
releaseObject(device);
}
OIDN_API void oidnSetDevice1b(OIDNDevice hDevice, const char* name, bool value)
{
Device* device = (Device*)hDevice;
OIDN_TRY
checkHandle(hDevice);
OIDN_LOCK(device);
device->set1i(name, value);
OIDN_CATCH(device)
}
OIDN_API void oidnSetDevice1i(OIDNDevice hDevice, const char* name, int value)
{
Device* device = (Device*)hDevice;
OIDN_TRY
checkHandle(hDevice);
OIDN_LOCK(device);
device->set1i(name, value);
OIDN_CATCH(device)
}
OIDN_API bool oidnGetDevice1b(OIDNDevice hDevice, const char* name)
{
Device* device = (Device*)hDevice;
OIDN_TRY
checkHandle(hDevice);
OIDN_LOCK(device);
return device->get1i(name);
OIDN_CATCH(device)
return false;
}
OIDN_API int oidnGetDevice1i(OIDNDevice hDevice, const char* name)
{
Device* device = (Device*)hDevice;
OIDN_TRY
checkHandle(hDevice);
OIDN_LOCK(device);
return device->get1i(name);
OIDN_CATCH(device)
return 0;
}
OIDN_API void oidnSetDeviceErrorFunction(OIDNDevice hDevice, OIDNErrorFunction func, void* userPtr)
{
Device* device = (Device*)hDevice;
OIDN_TRY
checkHandle(hDevice);
OIDN_LOCK(device);
device->setErrorFunction((ErrorFunction)func, userPtr);
OIDN_CATCH(device)
}
OIDN_API OIDNError oidnGetDeviceError(OIDNDevice hDevice, const char** outMessage)
{
Device* device = (Device*)hDevice;
OIDN_TRY
return (OIDNError)Device::getError(device, outMessage);
OIDN_CATCH(device)
if (outMessage) *outMessage = "";
return OIDN_ERROR_UNKNOWN;
}
OIDN_API void oidnCommitDevice(OIDNDevice hDevice)
{
Device* device = (Device*)hDevice;
OIDN_TRY
checkHandle(hDevice);
OIDN_LOCK(device);
device->commit();
OIDN_CATCH(device)
}
OIDN_API OIDNBuffer oidnNewBuffer(OIDNDevice hDevice, size_t byteSize)
{
Device* device = (Device*)hDevice;
OIDN_TRY
checkHandle(hDevice);
OIDN_LOCK(device);
Ref<Buffer> buffer = device->newBuffer(byteSize);
return (OIDNBuffer)buffer.detach();
OIDN_CATCH(device)
return nullptr;
}
OIDN_API OIDNBuffer oidnNewSharedBuffer(OIDNDevice hDevice, void* ptr, size_t byteSize)
{
Device* device = (Device*)hDevice;
OIDN_TRY
checkHandle(hDevice);
OIDN_LOCK(device);
Ref<Buffer> buffer = device->newBuffer(ptr, byteSize);
return (OIDNBuffer)buffer.detach();
OIDN_CATCH(device)
return nullptr;
}
OIDN_API void oidnRetainBuffer(OIDNBuffer hBuffer)
{
Buffer* buffer = (Buffer*)hBuffer;
retainObject(buffer);
}
OIDN_API void oidnReleaseBuffer(OIDNBuffer hBuffer)
{
Buffer* buffer = (Buffer*)hBuffer;
releaseObject(buffer);
}
OIDN_API void* oidnMapBuffer(OIDNBuffer hBuffer, OIDNAccess access, size_t byteOffset, size_t byteSize)
{
Buffer* buffer = (Buffer*)hBuffer;
OIDN_TRY
checkHandle(hBuffer);
OIDN_LOCK(buffer);
return buffer->map(byteOffset, byteSize);
OIDN_CATCH(buffer)
return nullptr;
}
OIDN_API void oidnUnmapBuffer(OIDNBuffer hBuffer, void* mappedPtr)
{
Buffer* buffer = (Buffer*)hBuffer;
OIDN_TRY
checkHandle(hBuffer);
OIDN_LOCK(buffer);
buffer->unmap(mappedPtr);
OIDN_CATCH(buffer)
}
OIDN_API OIDNFilter oidnNewFilter(OIDNDevice hDevice, const char* type)
{
Device* device = (Device*)hDevice;
OIDN_TRY
checkHandle(hDevice);
OIDN_LOCK(device);
Ref<Filter> filter = device->newFilter(type);
return (OIDNFilter)filter.detach();
OIDN_CATCH(device)
return nullptr;
}
OIDN_API void oidnRetainFilter(OIDNFilter hFilter)
{
Filter* filter = (Filter*)hFilter;
retainObject(filter);
}
OIDN_API void oidnReleaseFilter(OIDNFilter hFilter)
{
Filter* filter = (Filter*)hFilter;
releaseObject(filter);
}
OIDN_API void oidnSetFilterImage(OIDNFilter hFilter, const char* name,
OIDNBuffer hBuffer, OIDNFormat format,
size_t width, size_t height,
size_t byteOffset,
size_t bytePixelStride, size_t byteRowStride)
{
Filter* filter = (Filter*)hFilter;
OIDN_TRY
checkHandle(hFilter);
checkHandle(hBuffer);
OIDN_LOCK(filter);
Ref<Buffer> buffer = (Buffer*)hBuffer;
if (buffer->getDevice() != filter->getDevice())
throw Exception(Error::InvalidArgument, "the specified objects are bound to different devices");
Image data(buffer, (Format)format, (int)width, (int)height, byteOffset, bytePixelStride, byteRowStride);
filter->setImage(name, data);
OIDN_CATCH(filter)
}
OIDN_API void oidnSetSharedFilterImage(OIDNFilter hFilter, const char* name,
void* ptr, OIDNFormat format,
size_t width, size_t height,
size_t byteOffset,
size_t bytePixelStride, size_t byteRowStride)
{
Filter* filter = (Filter*)hFilter;
OIDN_TRY
checkHandle(hFilter);
OIDN_LOCK(filter);
Image data(ptr, (Format)format, (int)width, (int)height, byteOffset, bytePixelStride, byteRowStride);
filter->setImage(name, data);
OIDN_CATCH(filter)
}
OIDN_API void oidnSetFilter1b(OIDNFilter hFilter, const char* name, bool value)
{
Filter* filter = (Filter*)hFilter;
OIDN_TRY
checkHandle(hFilter);
OIDN_LOCK(filter);
filter->set1i(name, int(value));
OIDN_CATCH(filter)
}
OIDN_API bool oidnGetFilter1b(OIDNFilter hFilter, const char* name)
{
Filter* filter = (Filter*)hFilter;
OIDN_TRY
checkHandle(hFilter);
OIDN_LOCK(filter);
return filter->get1i(name);
OIDN_CATCH(filter)
return false;
}
OIDN_API void oidnSetFilter1i(OIDNFilter hFilter, const char* name, int value)
{
Filter* filter = (Filter*)hFilter;
OIDN_TRY
checkHandle(hFilter);
OIDN_LOCK(filter);
filter->set1i(name, value);
OIDN_CATCH(filter)
}
OIDN_API int oidnGetFilter1i(OIDNFilter hFilter, const char* name)
{
Filter* filter = (Filter*)hFilter;
OIDN_TRY
checkHandle(hFilter);
OIDN_LOCK(filter);
return filter->get1i(name);
OIDN_CATCH(filter)
return 0;
}
OIDN_API void oidnSetFilter1f(OIDNFilter hFilter, const char* name, float value)
{
Filter* filter = (Filter*)hFilter;
OIDN_TRY
checkHandle(hFilter);
OIDN_LOCK(filter);
filter->set1f(name, value);
OIDN_CATCH(filter)
}
OIDN_API float oidnGetFilter1f(OIDNFilter hFilter, const char* name)
{
Filter* filter = (Filter*)hFilter;
OIDN_TRY
checkHandle(hFilter);
OIDN_LOCK(filter);
return filter->get1f(name);
OIDN_CATCH(filter)
return 0;
}
OIDN_API void oidnSetFilterProgressMonitorFunction(OIDNFilter hFilter, OIDNProgressMonitorFunction func, void* userPtr)
{
Filter* filter = (Filter*)hFilter;
OIDN_TRY
checkHandle(hFilter);
OIDN_LOCK(filter);
filter->setProgressMonitorFunction(func, userPtr);
OIDN_CATCH(filter)
}
OIDN_API void oidnCommitFilter(OIDNFilter hFilter)
{
Filter* filter = (Filter*)hFilter;
OIDN_TRY
checkHandle(hFilter);
OIDN_LOCK(filter);
filter->commit();
OIDN_CATCH(filter)
}
OIDN_API void oidnExecuteFilter(OIDNFilter hFilter)
{
Filter* filter = (Filter*)hFilter;
OIDN_TRY
checkHandle(hFilter);
OIDN_LOCK(filter);
filter->execute();
OIDN_CATCH(filter)
}
} // namespace oidn
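
Note on the error path: oidnGetDeviceError() forwards to Device::getError(), which returns and clears the stored error code while keeping the message string alive, and Device::setError() only records an error when none is pending. The following is a minimal, self-contained sketch of those semantics, not the removed implementation. The `sketch_error` namespace and the reduced `Error` enum are assumptions made for illustration, and the per-thread storage and error callback of the real code are left out.

```cpp
#include <string>

namespace sketch_error {

// Reduced, hypothetical error enum (the removed code has more values).
enum class Error { None, InvalidArgument, InvalidOperation };

struct ErrorState
{
  Error code = Error::None;
  std::string message;
};

// Record an error only if no unqueried error is pending: the first error wins.
inline void setError(ErrorState& state, Error code, const std::string& message)
{
  if (state.code == Error::None)
  {
    state.code = code;
    state.message = message;
  }
}

// Return and clear the stored code. The message string is kept, so the
// pointer handed out stays valid until the next error is recorded.
inline Error getError(ErrorState& state, const char** outMessage)
{
  const Error code = state.code;
  if (outMessage)
    *outMessage = (code == Error::None) ? nullptr : state.message.c_str();
  state.code = Error::None;
  return code;
}

} // namespace sketch_error
```

Under these assumptions, a second error raised before the first one is queried is dropped, and a second query returns `Error::None` with a null message.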

@ -1,535 +0,0 @@
// ======================================================================== //
// Copyright 2009-2019 Intel Corporation //
// //
// Licensed under the Apache License, Version 2.0 (the "License"); //
// you may not use this file except in compliance with the License. //
// You may obtain a copy of the License at //
// //
// http://www.apache.org/licenses/LICENSE-2.0 //
// //
// Unless required by applicable law or agreed to in writing, software //
// distributed under the License is distributed on an "AS IS" BASIS, //
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. //
// See the License for the specific language governing permissions and //
// limitations under the License. //
// ======================================================================== //
#include "autoencoder.h"
namespace oidn {
// --------------------------------------------------------------------------
// AutoencoderFilter
// --------------------------------------------------------------------------
AutoencoderFilter::AutoencoderFilter(const Ref<Device>& device)
: Filter(device)
{
}
void AutoencoderFilter::setImage(const std::string& name, const Image& data)
{
if (name == "color")
color = data;
else if (name == "albedo")
albedo = data;
else if (name == "normal")
normal = data;
else if (name == "output")
output = data;
dirty = true;
}
void AutoencoderFilter::set1i(const std::string& name, int value)
{
if (name == "hdr")
hdr = value;
else if (name == "srgb")
srgb = value;
else if (name == "maxMemoryMB")
maxMemoryMB = value;
dirty = true;
}
int AutoencoderFilter::get1i(const std::string& name)
{
if (name == "hdr")
return hdr;
else if (name == "srgb")
return srgb;
else if (name == "maxMemoryMB")
return maxMemoryMB;
else if (name == "alignment")
return alignment;
else if (name == "overlap")
return overlap;
else
throw Exception(Error::InvalidArgument, "invalid parameter");
}
void AutoencoderFilter::set1f(const std::string& name, float value)
{
if (name == "hdrScale")
hdrScale = value;
dirty = true;
}
float AutoencoderFilter::get1f(const std::string& name)
{
if (name == "hdrScale")
return hdrScale;
else
throw Exception(Error::InvalidArgument, "invalid parameter");
}
void AutoencoderFilter::commit()
{
if (!dirty)
return;
// -- GODOT start --
//device->executeTask([&]()
//{
// -- GODOT end --
if (mayiuse(avx512_common))
net = buildNet<16>();
else
net = buildNet<8>();
// -- GODOT start --
//});
// -- GODOT end --
dirty = false;
}
void AutoencoderFilter::execute()
{
if (dirty)
throw Exception(Error::InvalidOperation, "changes to the filter are not committed");
if (!net)
return;
// -- GODOT start --
//device->executeTask([&]()
//{
// -- GODOT end --
Progress progress;
progress.func = progressFunc;
progress.userPtr = progressUserPtr;
progress.taskCount = tileCountH * tileCountW;
// Iterate over the tiles
int tileIndex = 0;
for (int i = 0; i < tileCountH; ++i)
{
const int h = i * (tileH - 2*overlap); // input tile position (including overlap)
const int overlapBeginH = i > 0 ? overlap : 0; // overlap on the top
const int overlapEndH = i < tileCountH-1 ? overlap : 0; // overlap on the bottom
const int tileH1 = min(H - h, tileH); // input tile size (including overlap)
const int tileH2 = tileH1 - overlapBeginH - overlapEndH; // output tile size
const int alignOffsetH = tileH - roundUp(tileH1, alignment); // align to the bottom in the tile buffer
for (int j = 0; j < tileCountW; ++j)
{
const int w = j * (tileW - 2*overlap); // input tile position (including overlap)
const int overlapBeginW = j > 0 ? overlap : 0; // overlap on the left
const int overlapEndW = j < tileCountW-1 ? overlap : 0; // overlap on the right
const int tileW1 = min(W - w, tileW); // input tile size (including overlap)
const int tileW2 = tileW1 - overlapBeginW - overlapEndW; // output tile size
const int alignOffsetW = tileW - roundUp(tileW1, alignment); // align to the right in the tile buffer
// Set the input tile
inputReorder->setTile(h, w,
alignOffsetH, alignOffsetW,
tileH1, tileW1);
// Set the output tile
outputReorder->setTile(alignOffsetH + overlapBeginH, alignOffsetW + overlapBeginW,
h + overlapBeginH, w + overlapBeginW,
tileH2, tileW2);
//printf("Tile: %d %d -> %d %d\n", w+overlapBeginW, h+overlapBeginH, w+overlapBeginW+tileW2, h+overlapBeginH+tileH2);
// Denoise the tile
net->execute(progress, tileIndex);
// Next tile
tileIndex++;
}
}
// -- GODOT start --
//});
// -- GODOT end --
}
void AutoencoderFilter::computeTileSize()
{
const int minTileSize = 3*overlap;
const int estimatedBytesPerPixel = mayiuse(avx512_common) ? estimatedBytesPerPixel16 : estimatedBytesPerPixel8;
const int64_t maxTilePixels = (int64_t(maxMemoryMB)*1024*1024 - estimatedBytesBase) / estimatedBytesPerPixel;
tileCountH = 1;
tileCountW = 1;
tileH = roundUp(H, alignment);
tileW = roundUp(W, alignment);
// Divide the image into tiles until the tile size gets below the threshold
while (int64_t(tileH) * tileW > maxTilePixels)
{
if (tileH > minTileSize && tileH > tileW)
{
tileCountH++;
tileH = max(roundUp(ceilDiv(H - 2*overlap, tileCountH), alignment) + 2*overlap, minTileSize);
}
else if (tileW > minTileSize)
{
tileCountW++;
tileW = max(roundUp(ceilDiv(W - 2*overlap, tileCountW), alignment) + 2*overlap, minTileSize);
}
else
break;
}
// Compute the final number of tiles
tileCountH = (H > tileH) ? ceilDiv(H - 2*overlap, tileH - 2*overlap) : 1;
tileCountW = (W > tileW) ? ceilDiv(W - 2*overlap, tileW - 2*overlap) : 1;
if (device->isVerbose(2))
{
std::cout << "Tile size : " << tileW << "x" << tileH << std::endl;
std::cout << "Tile count: " << tileCountW << "x" << tileCountH << std::endl;
}
}
template<int K>
std::shared_ptr<Executable> AutoencoderFilter::buildNet()
{
H = color.height;
W = color.width;
// Configure the network
int inputC;
void* weightPtr;
if (srgb && hdr)
throw Exception(Error::InvalidOperation, "srgb and hdr modes cannot be enabled at the same time");
if (color && !albedo && !normal && weightData.hdr)
{
inputC = 3;
weightPtr = hdr ? weightData.hdr : weightData.ldr;
}
else if (color && albedo && !normal && weightData.hdr_alb)
{
inputC = 6;
weightPtr = hdr ? weightData.hdr_alb : weightData.ldr_alb;
}
else if (color && albedo && normal && weightData.hdr_alb_nrm)
{
inputC = 9;
weightPtr = hdr ? weightData.hdr_alb_nrm : weightData.ldr_alb_nrm;
}
else
{
throw Exception(Error::InvalidOperation, "unsupported combination of input features");
}
if (!output)
throw Exception(Error::InvalidOperation, "output image not specified");
if ((color.format != Format::Float3)
|| (albedo && albedo.format != Format::Float3)
|| (normal && normal.format != Format::Float3)
|| (output.format != Format::Float3))
throw Exception(Error::InvalidOperation, "unsupported image format");
if ((albedo && (albedo.width != W || albedo.height != H))
|| (normal && (normal.width != W || normal.height != H))
|| (output.width != W || output.height != H))
throw Exception(Error::InvalidOperation, "image size mismatch");
// Compute the tile size
computeTileSize();
// If the image size is zero, there is nothing else to do
if (H <= 0 || W <= 0)
return nullptr;
// Parse the weights
const auto weightMap = parseTensors(weightPtr);
// Create the network
std::shared_ptr<Network<K>> net = std::make_shared<Network<K>>(device, weightMap);
// Compute the tensor sizes
const auto inputDims = memory::dims({1, inputC, tileH, tileW});
const auto inputReorderDims = net->getInputReorderDims(inputDims, alignment); //-> concat0
const auto conv1Dims = net->getConvDims("conv1", inputReorderDims); //-> temp0
const auto conv1bDims = net->getConvDims("conv1b", conv1Dims); //-> temp1
const auto pool1Dims = net->getPoolDims(conv1bDims); //-> concat1
const auto conv2Dims = net->getConvDims("conv2", pool1Dims); //-> temp0
const auto pool2Dims = net->getPoolDims(conv2Dims); //-> concat2
const auto conv3Dims = net->getConvDims("conv3", pool2Dims); //-> temp0
const auto pool3Dims = net->getPoolDims(conv3Dims); //-> concat3
const auto conv4Dims = net->getConvDims("conv4", pool3Dims); //-> temp0
const auto pool4Dims = net->getPoolDims(conv4Dims); //-> concat4
const auto conv5Dims = net->getConvDims("conv5", pool4Dims); //-> temp0
const auto pool5Dims = net->getPoolDims(conv5Dims); //-> temp1
const auto upsample4Dims = net->getUpsampleDims(pool5Dims); //-> concat4
const auto concat4Dims = net->getConcatDims(upsample4Dims, pool4Dims);
const auto conv6Dims = net->getConvDims("conv6", concat4Dims); //-> temp0
const auto conv6bDims = net->getConvDims("conv6b", conv6Dims); //-> temp1
const auto upsample3Dims = net->getUpsampleDims(conv6bDims); //-> concat3
const auto concat3Dims = net->getConcatDims(upsample3Dims, pool3Dims);
const auto conv7Dims = net->getConvDims("conv7", concat3Dims); //-> temp0
const auto conv7bDims = net->getConvDims("conv7b", conv7Dims); //-> temp1
const auto upsample2Dims = net->getUpsampleDims(conv7bDims); //-> concat2
const auto concat2Dims = net->getConcatDims(upsample2Dims, pool2Dims);
const auto conv8Dims = net->getConvDims("conv8", concat2Dims); //-> temp0
const auto conv8bDims = net->getConvDims("conv8b", conv8Dims); //-> temp1
const auto upsample1Dims = net->getUpsampleDims(conv8bDims); //-> concat1
const auto concat1Dims = net->getConcatDims(upsample1Dims, pool1Dims);
const auto conv9Dims = net->getConvDims("conv9", concat1Dims); //-> temp0
const auto conv9bDims = net->getConvDims("conv9b", conv9Dims); //-> temp1
const auto upsample0Dims = net->getUpsampleDims(conv9bDims); //-> concat0
const auto concat0Dims = net->getConcatDims(upsample0Dims, inputReorderDims);
const auto conv10Dims = net->getConvDims("conv10", concat0Dims); //-> temp0
const auto conv10bDims = net->getConvDims("conv10b", conv10Dims); //-> temp1
const auto conv11Dims = net->getConvDims("conv11", conv10bDims); //-> temp0
const auto outputDims = memory::dims({1, 3, tileH, tileW});
// Allocate two temporary ping-pong buffers to decrease memory usage
const auto temp0Dims = getMaxTensorDims({
conv1Dims,
conv2Dims,
conv3Dims,
conv4Dims,
conv5Dims,
conv6Dims,
conv7Dims,
conv8Dims,
conv9Dims,
conv10Dims,
conv11Dims
});
const auto temp1Dims = getMaxTensorDims({
conv1bDims,
pool5Dims,
conv6bDims,
conv7bDims,
conv8bDims,
conv9bDims,
conv10bDims,
});
auto temp0 = net->allocTensor(temp0Dims);
auto temp1 = net->allocTensor(temp1Dims);
// Allocate enough memory to hold the concat outputs. Then use the first
// half to hold the previous conv output and the second half to hold the
// pool/orig image output. This works because everything is C dimension
// outermost, padded to K floats, and all the concats are on the C dimension.
auto concat0Dst = net->allocTensor(concat0Dims);
auto concat1Dst = net->allocTensor(concat1Dims);
auto concat2Dst = net->allocTensor(concat2Dims);
auto concat3Dst = net->allocTensor(concat3Dims);
auto concat4Dst = net->allocTensor(concat4Dims);
// Transfer function
std::shared_ptr<TransferFunction> transferFunc = makeTransferFunc();
// Autoexposure
if (auto tf = std::dynamic_pointer_cast<HDRTransferFunction>(transferFunc))
{
if (isnan(hdrScale))
net->addAutoexposure(color, tf);
else
tf->setExposure(hdrScale);
}
// Input reorder
auto inputReorderDst = net->castTensor(inputReorderDims, concat0Dst, upsample0Dims);
inputReorder = net->addInputReorder(color, albedo, normal,
transferFunc,
alignment, inputReorderDst);
// conv1
auto conv1 = net->addConv("conv1", inputReorder->getDst(), temp0);
// conv1b
auto conv1b = net->addConv("conv1b", conv1->getDst(), temp1);
// pool1
// Adjust pointer for pool1 to eliminate concat1
auto pool1Dst = net->castTensor(pool1Dims, concat1Dst, upsample1Dims);
auto pool1 = net->addPool(conv1b->getDst(), pool1Dst);
// conv2
auto conv2 = net->addConv("conv2", pool1->getDst(), temp0);
// pool2
// Adjust pointer for pool2 to eliminate concat2
auto pool2Dst = net->castTensor(pool2Dims, concat2Dst, upsample2Dims);
auto pool2 = net->addPool(conv2->getDst(), pool2Dst);
// conv3
auto conv3 = net->addConv("conv3", pool2->getDst(), temp0);
// pool3
// Adjust pointer for pool3 to eliminate concat3
auto pool3Dst = net->castTensor(pool3Dims, concat3Dst, upsample3Dims);
auto pool3 = net->addPool(conv3->getDst(), pool3Dst);
// conv4
auto conv4 = net->addConv("conv4", pool3->getDst(), temp0);
// pool4
// Adjust pointer for pool4 to eliminate concat4
auto pool4Dst = net->castTensor(pool4Dims, concat4Dst, upsample4Dims);
auto pool4 = net->addPool(conv4->getDst(), pool4Dst);
// conv5
auto conv5 = net->addConv("conv5", pool4->getDst(), temp0);
// pool5
auto pool5 = net->addPool(conv5->getDst(), temp1);
// upsample4
auto upsample4Dst = net->castTensor(upsample4Dims, concat4Dst);
auto upsample4 = net->addUpsample(pool5->getDst(), upsample4Dst);
// conv6
auto conv6 = net->addConv("conv6", concat4Dst, temp0);
// conv6b
auto conv6b = net->addConv("conv6b", conv6->getDst(), temp1);
// upsample3
auto upsample3Dst = net->castTensor(upsample3Dims, concat3Dst);
auto upsample3 = net->addUpsample(conv6b->getDst(), upsample3Dst);
// conv7
auto conv7 = net->addConv("conv7", concat3Dst, temp0);
// conv7b
auto conv7b = net->addConv("conv7b", conv7->getDst(), temp1);
// upsample2
auto upsample2Dst = net->castTensor(upsample2Dims, concat2Dst);
auto upsample2 = net->addUpsample(conv7b->getDst(), upsample2Dst);
// conv8
auto conv8 = net->addConv("conv8", concat2Dst, temp0);
// conv8b
auto conv8b = net->addConv("conv8b", conv8->getDst(), temp1);
// upsample1
auto upsample1Dst = net->castTensor(upsample1Dims, concat1Dst);
auto upsample1 = net->addUpsample(conv8b->getDst(), upsample1Dst);
// conv9
auto conv9 = net->addConv("conv9", concat1Dst, temp0);
// conv9b
auto conv9b = net->addConv("conv9b", conv9->getDst(), temp1);
// upsample0
auto upsample0Dst = net->castTensor(upsample0Dims, concat0Dst);
auto upsample0 = net->addUpsample(conv9b->getDst(), upsample0Dst);
// conv10
auto conv10 = net->addConv("conv10", concat0Dst, temp0);
// conv10b
auto conv10b = net->addConv("conv10b", conv10->getDst(), temp1);
// conv11
auto conv11 = net->addConv("conv11", conv10b->getDst(), temp0, false /* no relu */);
// Output reorder
outputReorder = net->addOutputReorder(conv11->getDst(), transferFunc, output);
net->finalize();
return net;
}
std::shared_ptr<TransferFunction> AutoencoderFilter::makeTransferFunc()
{
if (hdr)
return std::make_shared<PQXTransferFunction>();
else if (srgb)
return std::make_shared<LinearTransferFunction>();
else
return std::make_shared<GammaTransferFunction>();
}
// -- GODOT start --
// Godot doesn't need Raytracing filters. Removing them saves space in the weights files.
#if 0
// -- GODOT end --
// --------------------------------------------------------------------------
// RTFilter
// --------------------------------------------------------------------------
namespace weights
{
// LDR
extern unsigned char rt_ldr[]; // color
extern unsigned char rt_ldr_alb[]; // color, albedo
extern unsigned char rt_ldr_alb_nrm[]; // color, albedo, normal
// HDR
extern unsigned char rt_hdr[]; // color
extern unsigned char rt_hdr_alb[]; // color, albedo
extern unsigned char rt_hdr_alb_nrm[]; // color, albedo, normal
}
RTFilter::RTFilter(const Ref<Device>& device)
: AutoencoderFilter(device)
{
weightData.ldr = weights::rt_ldr;
weightData.ldr_alb = weights::rt_ldr_alb;
weightData.ldr_alb_nrm = weights::rt_ldr_alb_nrm;
weightData.hdr = weights::rt_hdr;
weightData.hdr_alb = weights::rt_hdr_alb;
weightData.hdr_alb_nrm = weights::rt_hdr_alb_nrm;
}
// -- GODOT start --
#endif
// -- GODOT end --
// --------------------------------------------------------------------------
// RTLightmapFilter
// --------------------------------------------------------------------------
namespace weights
{
// HDR
extern unsigned char rtlightmap_hdr[]; // color
}
RTLightmapFilter::RTLightmapFilter(const Ref<Device>& device)
: AutoencoderFilter(device)
{
weightData.hdr = weights::rtlightmap_hdr;
hdr = true;
}
std::shared_ptr<TransferFunction> RTLightmapFilter::makeTransferFunc()
{
return std::make_shared<LogTransferFunction>();
}
} // namespace oidn
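
Note on tiling: computeTileSize() keeps splitting the longer tile dimension until a tile fits the pixel budget, then derives the tile counts from the overlapped stride. The sketch below reproduces that loop as a free function so it can be tried in isolation. It is an illustration under stated assumptions: `sketch_tiles`, `TileLayout` and `computeTiles` are hypothetical names, `roundUp` and `ceilDiv` are stand-ins for the helpers the removed code used, and the defaults (alignment 32, overlap 128) are taken from the constants in the filter's header.

```cpp
#include <algorithm>
#include <cstdint>

namespace sketch_tiles {

// Stand-ins for the integer helpers used by the removed code.
constexpr int ceilDiv(int a, int b) { return (a + b - 1) / b; }
constexpr int roundUp(int a, int b) { return ceilDiv(a, b) * b; }

struct TileLayout
{
  int tileH, tileW;   // tile size in pixels, including overlap
  int countH, countW; // number of tiles per dimension
};

// Split an H x W image into tiles of at most maxTilePixels pixels each.
inline TileLayout computeTiles(int H, int W, std::int64_t maxTilePixels,
                               int alignment = 32, int overlap = 128)
{
  const int minTileSize = 3 * overlap;
  TileLayout t{roundUp(H, alignment), roundUp(W, alignment), 1, 1};
  int countH = 1, countW = 1;

  // Divide the larger dimension until the tile fits the budget.
  while (std::int64_t(t.tileH) * t.tileW > maxTilePixels)
  {
    if (t.tileH > minTileSize && t.tileH > t.tileW)
    {
      ++countH;
      t.tileH = std::max(roundUp(ceilDiv(H - 2*overlap, countH), alignment) + 2*overlap, minTileSize);
    }
    else if (t.tileW > minTileSize)
    {
      ++countW;
      t.tileW = std::max(roundUp(ceilDiv(W - 2*overlap, countW), alignment) + 2*overlap, minTileSize);
    }
    else
      break; // tiles cannot get any smaller
  }

  // Final tile counts: each tile advances by (tile size - 2*overlap).
  t.countH = (H > t.tileH) ? ceilDiv(H - 2*overlap, t.tileH - 2*overlap) : 1;
  t.countW = (W > t.tileW) ? ceilDiv(W - 2*overlap, t.tileW - 2*overlap) : 1;
  return t;
}

} // namespace sketch_tiles
```

For example, `sketch_tiles::computeTiles(1024, 1024, 600000)` yields 640x640 tiles in a 2x2 grid. In execute(), each such tile then writes a 512-pixel band per dimension, so the two bands cover all 1024 rows and columns without gaps.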

@ -1,120 +0,0 @@
// ======================================================================== //
// Copyright 2009-2019 Intel Corporation //
// //
// Licensed under the Apache License, Version 2.0 (the "License"); //
// you may not use this file except in compliance with the License. //
// You may obtain a copy of the License at //
// //
// http://www.apache.org/licenses/LICENSE-2.0 //
// //
// Unless required by applicable law or agreed to in writing, software //
// distributed under the License is distributed on an "AS IS" BASIS, //
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. //
// See the License for the specific language governing permissions and //
// limitations under the License. //
// ======================================================================== //
#pragma once
#include "filter.h"
#include "network.h"
#include "transfer_function.h"
namespace oidn {
// --------------------------------------------------------------------------
// AutoencoderFilter - Direct-predicting autoencoder
// --------------------------------------------------------------------------
class AutoencoderFilter : public Filter
{
protected:
static constexpr int alignment = 32; // required spatial alignment in pixels (padding may be necessary)
static constexpr int receptiveField = 222; // receptive field in pixels
static constexpr int overlap = roundUp(receptiveField / 2, alignment); // required spatial overlap between tiles in pixels
static constexpr int estimatedBytesBase = 16*1024*1024; // estimated base memory usage
static constexpr int estimatedBytesPerPixel8 = 889; // estimated memory usage per pixel for K=8
static constexpr int estimatedBytesPerPixel16 = 2185; // estimated memory usage per pixel for K=16
Image color;
Image albedo;
Image normal;
Image output;
bool hdr = false;
float hdrScale = std::numeric_limits<float>::quiet_NaN();
bool srgb = false;
int maxMemoryMB = 6000; // approximate maximum memory usage in MBs
int H = 0; // image height
int W = 0; // image width
int tileH = 0; // tile height
int tileW = 0; // tile width
int tileCountH = 1; // number of tiles in H dimension
int tileCountW = 1; // number of tiles in W dimension
std::shared_ptr<Executable> net;
std::shared_ptr<Node> inputReorder;
std::shared_ptr<Node> outputReorder;
struct
{
void* ldr = nullptr;
void* ldr_alb = nullptr;
void* ldr_alb_nrm = nullptr;
void* hdr = nullptr;
void* hdr_alb = nullptr;
void* hdr_alb_nrm = nullptr;
} weightData;
explicit AutoencoderFilter(const Ref<Device>& device);
virtual std::shared_ptr<TransferFunction> makeTransferFunc();
public:
void setImage(const std::string& name, const Image& data) override;
void set1i(const std::string& name, int value) override;
int get1i(const std::string& name) override;
void set1f(const std::string& name, float value) override;
float get1f(const std::string& name) override;
void commit() override;
void execute() override;
private:
void computeTileSize();
template<int K>
std::shared_ptr<Executable> buildNet();
bool isCommitted() const { return bool(net); }
};
// --------------------------------------------------------------------------
// RTFilter - Generic ray tracing denoiser
// --------------------------------------------------------------------------
// -- GODOT start --
// Godot doesn't need Raytracing filters. Removing them saves space in the weights files.
#if 0
// -- GODOT end --
class RTFilter : public AutoencoderFilter
{
public:
explicit RTFilter(const Ref<Device>& device);
};
// -- GODOT start --
#endif
// -- GODOT end --
// --------------------------------------------------------------------------
// RTLightmapFilter - Ray traced lightmap denoiser
// --------------------------------------------------------------------------
class RTLightmapFilter : public AutoencoderFilter
{
public:
explicit RTLightmapFilter(const Ref<Device>& device);
std::shared_ptr<TransferFunction> makeTransferFunc() override;
};
} // namespace oidn
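
Note on the constants: the tile overlap and the per-tile pixel budget both follow from the values declared in AutoencoderFilter. The short sketch below works them out at compile time. The `sketch_budget` namespace and the `maxTilePixels` helper are hypothetical names introduced for illustration, and `roundUp` is a stand-in for the helper the header relied on.

```cpp
#include <cstdint>

namespace sketch_budget {

// Stand-in for the rounding helper used by the removed header.
constexpr int roundUp(int a, int b) { return (a + b - 1) / b * b; }

constexpr int alignment      = 32;  // spatial alignment in pixels
constexpr int receptiveField = 222; // receptive field in pixels

// Half the receptive field, rounded up to the alignment: 111 -> 128.
constexpr int overlap = roundUp(receptiveField / 2, alignment);
static_assert(overlap == 128, "overlap should round 111 up to 128");

// Pixel budget per tile: (memory limit - base usage) / bytes per pixel.
constexpr std::int64_t maxTilePixels(int maxMemoryMB, int bytesPerPixel,
                                     std::int64_t baseBytes = 16 * 1024 * 1024)
{
  return (std::int64_t(maxMemoryMB) * 1024 * 1024 - baseBytes) / bytesPerPixel;
}

} // namespace sketch_budget
```

With the header's defaults (6000 MB limit, 16 MiB base, 889 bytes per pixel for K=8) this gives a budget of about 7.06 million pixels per tile. The K=16 estimate of 2185 bytes per pixel gives about 2.87 million.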

@ -1,75 +0,0 @@
// ======================================================================== //
// Copyright 2009-2019 Intel Corporation //
// //
// Licensed under the Apache License, Version 2.0 (the "License"); //
// you may not use this file except in compliance with the License. //
// You may obtain a copy of the License at //
// //
// http://www.apache.org/licenses/LICENSE-2.0 //
// //
// Unless required by applicable law or agreed to in writing, software //
// distributed under the License is distributed on an "AS IS" BASIS, //
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. //
// See the License for the specific language governing permissions and //
// limitations under the License. //
// ======================================================================== //
#pragma once
#include "common.h"
#include "device.h"
namespace oidn {
class Device;
// Buffer which may or may not own its data
class Buffer : public RefCount
{
private:
char* ptr;
size_t byteSize;
bool shared;
Ref<Device> device;
public:
__forceinline Buffer(const Ref<Device>& device, size_t size)
: ptr((char*)alignedMalloc(size, 64)),
byteSize(size),
shared(false),
device(device) {}
__forceinline Buffer(const Ref<Device>& device, void* data, size_t size)
: ptr((char*)data),
byteSize(size),
shared(true),
device(device)
{
if (data == nullptr)
throw Exception(Error::InvalidArgument, "buffer pointer null");
}
__forceinline ~Buffer()
{
if (!shared)
alignedFree(ptr);
}
__forceinline char* data() { return ptr; }
__forceinline const char* data() const { return ptr; }
__forceinline size_t size() const { return byteSize; }
void* map(size_t offset, size_t size)
{
if (offset + size > byteSize)
throw Exception(Error::InvalidArgument, "buffer region out of range");
return ptr + offset;
}
void unmap(void* mappedPtr) {} // no-op: map() returns a pointer into the buffer itself
Device* getDevice() { return device.get(); }
};
} // namespace oidn
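
Note on range checks: a bounds test written as `offset + size > byteSize` relies on the sum not wrapping around, which `size_t` arithmetic does not guarantee for very large arguments. The sketch below shows one wrap-free way to express the same check. `sketch_buffer::regionInRange` is a hypothetical helper written for this note. It is not part of OIDN.

```cpp
#include <cstddef>

namespace sketch_buffer {

// True when [offset, offset + size) lies inside a buffer of byteSize bytes.
// The subtraction cannot underflow because offset <= byteSize is tested first.
constexpr bool regionInRange(std::size_t offset, std::size_t size, std::size_t byteSize)
{
  return offset <= byteSize && size <= byteSize - offset;
}

} // namespace sketch_buffer
```

With a 64-byte buffer, an offset equal to the maximum `size_t` value and a size of 2 is rejected here, whereas the additive form would wrap to 1 and accept it.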

@ -1,136 +0,0 @@
// ======================================================================== //
// Copyright 2009-2019 Intel Corporation //
// //
// Licensed under the Apache License, Version 2.0 (the "License"); //
// you may not use this file except in compliance with the License. //
// You may obtain a copy of the License at //
// //
// http://www.apache.org/licenses/LICENSE-2.0 //
// //
// Unless required by applicable law or agreed to in writing, software //
// distributed under the License is distributed on an "AS IS" BASIS, //
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. //
// See the License for the specific language governing permissions and //
// limitations under the License. //
// ======================================================================== //
#pragma once
#include "common/platform.h"
#include "mkl-dnn/include/mkldnn.hpp"
#include "mkl-dnn/include/mkldnn_debug.h"
#include "mkl-dnn/src/common/mkldnn_thread.hpp"
#include "mkl-dnn/src/common/type_helpers.hpp"
#include "mkl-dnn/src/cpu/jit_generator.hpp"
#include "common/ref.h"
#include "common/exception.h"
#include "common/thread.h"
// -- GODOT start --
//#include "common/tasking.h"
// -- GODOT end --
#include "math.h"
namespace oidn {
using namespace mkldnn;
using namespace mkldnn::impl::cpu;
using mkldnn::impl::parallel_nd;
using mkldnn::impl::memory_desc_matches_tag;
inline size_t getFormatBytes(Format format)
{
switch (format)
{
case Format::Undefined: return 1;
case Format::Float: return sizeof(float);
case Format::Float2: return sizeof(float)*2;
case Format::Float3: return sizeof(float)*3;
case Format::Float4: return sizeof(float)*4;
}
assert(0);
return 0;
}
inline memory::dims getTensorDims(const std::shared_ptr<memory>& mem)
{
const mkldnn_memory_desc_t& desc = mem->get_desc().data;
return memory::dims(&desc.dims[0], &desc.dims[desc.ndims]);
}
inline memory::data_type getTensorType(const std::shared_ptr<memory>& mem)
{
const mkldnn_memory_desc_t& desc = mem->get_desc().data;
return memory::data_type(desc.data_type);
}
// Returns the number of values in a tensor
inline size_t getTensorSize(const memory::dims& dims)
{
size_t res = 1;
for (int i = 0; i < (int)dims.size(); ++i)
res *= dims[i];
return res;
}
inline memory::dims getMaxTensorDims(const std::vector<memory::dims>& dims)
{
memory::dims result;
size_t maxSize = 0;
for (const auto& d : dims)
{
const size_t size = getTensorSize(d);
if (size > maxSize)
{
result = d;
maxSize = size;
}
}
return result;
}
inline size_t getTensorSize(const std::shared_ptr<memory>& mem)
{
return getTensorSize(getTensorDims(mem));
}
template<int K>
inline int getPadded(int dim)
{
return (dim + (K-1)) & ~(K-1);
}
template<int K>
inline memory::dims getPadded_nchw(const memory::dims& dims)
{
assert(dims.size() == 4);
memory::dims padDims = dims;
padDims[1] = getPadded<K>(dims[1]); // pad C
return padDims;
}
template<int K>
struct BlockedFormat;
template<>
struct BlockedFormat<8>
{
static constexpr memory::format_tag nChwKc = memory::format_tag::nChw8c;
static constexpr memory::format_tag OIhwKiKo = memory::format_tag::OIhw8i8o;
};
template<>
struct BlockedFormat<16>
{
static constexpr memory::format_tag nChwKc = memory::format_tag::nChw16c;
static constexpr memory::format_tag OIhwKiKo = memory::format_tag::OIhw16i16o;
};
} // namespace oidn
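
Note on channel padding: getPadded<K>() rounds a channel count up to the next multiple of K with a mask, which is only valid when K is a power of two (8 and 16 in the blocked formats above). The sketch below restates it as a constexpr function with that precondition made explicit. The `sketch_pad` namespace and the added static_assert are assumptions for illustration.

```cpp
namespace sketch_pad {

// Round dim up to the next multiple of K. The mask trick requires K to be
// a power of two, which the static_assert enforces at compile time.
template<int K>
constexpr int getPadded(int dim)
{
  static_assert(K > 0 && (K & (K - 1)) == 0, "K must be a power of two");
  return (dim + (K - 1)) & ~(K - 1);
}

// Channel counts used by the filter (3, 6 or 9 input channels).
static_assert(getPadded<8>(3) == 8,   "3 channels pad to one block of 8");
static_assert(getPadded<8>(9) == 16,  "9 channels pad to two blocks of 8");
static_assert(getPadded<16>(9) == 16, "9 channels pad to one block of 16");

} // namespace sketch_pad
```

Values that are already a multiple of K come back unchanged, for example `sketch_pad::getPadded<8>(8)` is 8.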

@ -1,238 +0,0 @@
// ======================================================================== //
// Copyright 2009-2019 Intel Corporation //
// //
// Licensed under the Apache License, Version 2.0 (the "License"); //
// you may not use this file except in compliance with the License. //
// You may obtain a copy of the License at //
// //
// http://www.apache.org/licenses/LICENSE-2.0 //
// //
// Unless required by applicable law or agreed to in writing, software //
// distributed under the License is distributed on an "AS IS" BASIS, //
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. //
// See the License for the specific language governing permissions and //
// limitations under the License. //
// ======================================================================== //
#include "device.h"
#include "autoencoder.h"
namespace oidn {
thread_local Device::ErrorState Device::globalError;
Device::Device()
{
if (!mayiuse(sse41))
throw Exception(Error::UnsupportedHardware, "SSE4.1 support is required at minimum");
}
Device::~Device()
{
// -- GODOT start --
//observer.reset();
// -- GODOT end --
}
void Device::setError(Device* device, Error code, const std::string& message)
{
// Update the stored error only if the previous error was queried
if (device)
{
ErrorState& curError = device->error.get();
if (curError.code == Error::None)
{
curError.code = code;
curError.message = message;
}
// Print the error message in verbose mode
if (device->isVerbose())
std::cerr << "Error: " << message << std::endl;
// Call the error callback function
ErrorFunction errorFunc;
void* errorUserPtr;
{
std::lock_guard<std::mutex> lock(device->mutex);
errorFunc = device->errorFunc;
errorUserPtr = device->errorUserPtr;
}
if (errorFunc)
errorFunc(errorUserPtr, code, (code == Error::None) ? nullptr : message.c_str());
}
else
{
if (globalError.code == Error::None)
{
globalError.code = code;
globalError.message = message;
}
}
}
Error Device::getError(Device* device, const char** outMessage)
{
// Return and clear the stored error code, but keep the error message so pointers to it will
// remain valid until the next getError call
if (device)
{
ErrorState& curError = device->error.get();
const Error code = curError.code;
if (outMessage)
*outMessage = (code == Error::None) ? nullptr : curError.message.c_str();
curError.code = Error::None;
return code;
}
else
{
const Error code = globalError.code;
if (outMessage)
*outMessage = (code == Error::None) ? nullptr : globalError.message.c_str();
globalError.code = Error::None;
return code;
}
}
void Device::setErrorFunction(ErrorFunction func, void* userPtr)
{
errorFunc = func;
errorUserPtr = userPtr;
}
int Device::get1i(const std::string& name)
{
if (name == "numThreads")
return numThreads;
else if (name == "setAffinity")
return setAffinity;
else if (name == "verbose")
return verbose;
else if (name == "version")
return OIDN_VERSION;
else if (name == "versionMajor")
return OIDN_VERSION_MAJOR;
else if (name == "versionMinor")
return OIDN_VERSION_MINOR;
else if (name == "versionPatch")
return OIDN_VERSION_PATCH;
else
throw Exception(Error::InvalidArgument, "invalid parameter");
}
void Device::set1i(const std::string& name, int value)
{
if (name == "numThreads")
numThreads = value;
else if (name == "setAffinity")
setAffinity = value;
else if (name == "verbose")
{
verbose = value;
error.verbose = value;
}
dirty = true;
}
void Device::commit()
{
if (isCommitted())
throw Exception(Error::InvalidOperation, "device can be committed only once");
// -- GODOT start --
#if 0
// -- GODOT end --
// Get the optimal thread affinities
if (setAffinity)
{
affinity = std::make_shared<ThreadAffinity>(1, verbose); // one thread per core
if (affinity->getNumThreads() == 0)
affinity.reset();
}
// Create the task arena
const int maxNumThreads = affinity ? affinity->getNumThreads() : tbb::this_task_arena::max_concurrency();
numThreads = (numThreads > 0) ? min(numThreads, maxNumThreads) : maxNumThreads;
arena = std::make_shared<tbb::task_arena>(numThreads);
// Automatically set the thread affinities
if (affinity)
observer = std::make_shared<PinningObserver>(affinity, *arena);
// -- GODOT start --
#endif
numThreads = 1;
// -- GODOT end --
dirty = false;
if (isVerbose())
print();
}
void Device::checkCommitted()
{
if (dirty)
throw Exception(Error::InvalidOperation, "changes to the device are not committed");
}
Ref<Buffer> Device::newBuffer(size_t byteSize)
{
checkCommitted();
return makeRef<Buffer>(Ref<Device>(this), byteSize);
}
Ref<Buffer> Device::newBuffer(void* ptr, size_t byteSize)
{
checkCommitted();
return makeRef<Buffer>(Ref<Device>(this), ptr, byteSize);
}
Ref<Filter> Device::newFilter(const std::string& type)
{
checkCommitted();
if (isVerbose())
std::cout << "Filter: " << type << std::endl;
Ref<Filter> filter;
// -- GODOT start --
// Godot doesn't need Raytracing filters. Removing them saves space in the weights files.
#if 0
// -- GODOT end --
if (type == "RT")
filter = makeRef<RTFilter>(Ref<Device>(this));
// -- GODOT start --
// Godot doesn't need Raytracing filters. Removing them saves space in the weights files.
#endif
if (type == "RTLightmap")
// -- GODOT end --
filter = makeRef<RTLightmapFilter>(Ref<Device>(this));
else
throw Exception(Error::InvalidArgument, "unknown filter type");
return filter;
}
void Device::print()
{
std::cout << std::endl;
std::cout << "Intel(R) Open Image Denoise " << OIDN_VERSION_STRING << std::endl;
std::cout << " Compiler: " << getCompilerName() << std::endl;
std::cout << " Build : " << getBuildName() << std::endl;
std::cout << " Platform: " << getPlatformName() << std::endl;
// -- GODOT start --
// std::cout << " Tasking :";
// std::cout << " TBB" << TBB_VERSION_MAJOR << "." << TBB_VERSION_MINOR;
// std::cout << " TBB_header_interface_" << TBB_INTERFACE_VERSION << " TBB_lib_interface_" << tbb::TBB_runtime_interface_version();
// std::cout << std::endl;
// -- GODOT end --
std::cout << std::endl;
}
} // namespace oidn
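
The error handling in the deleted device.cpp above follows a "first error wins, query clears" pattern: `setError` only overwrites the stored state when no unqueried error is pending, and `getError` returns the code and resets it while leaving the message string in place. A minimal standalone sketch of that pattern follows. `ErrorCode`, `ErrorSlot`, `set` and `take` are hypothetical names used for illustration, not part of OIDN, and the thread-local storage and error callback are left out.

```cpp
#include <cassert>
#include <string>

// Hypothetical stand-in for oidn::Error (illustration only).
enum class ErrorCode { None, InvalidArgument, InvalidOperation };

// Hypothetical stand-in for Device::ErrorState plus its set/get logic.
struct ErrorSlot
{
  ErrorCode code = ErrorCode::None;
  std::string message;

  // Record an error only if the previous one has already been queried,
  // so the first unqueried error is the one that gets reported.
  void set(ErrorCode newCode, const std::string& newMessage)
  {
    if (code == ErrorCode::None)
    {
      code = newCode;
      message = newMessage;
    }
  }

  // Return and clear the stored code. The message string itself is kept,
  // so the returned pointer stays usable until a later set() replaces it.
  ErrorCode take(const char** outMessage)
  {
    const ErrorCode result = code;
    if (outMessage)
      *outMessage = (result == ErrorCode::None) ? nullptr : message.c_str();
    code = ErrorCode::None;
    return result;
  }
};
```

Calling `set` twice and then `take` reports the first error; a second `take` reports `ErrorCode::None` with a null message.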


@ -1,102 +0,0 @@
// ======================================================================== //
// Copyright 2009-2019 Intel Corporation //
// //
// Licensed under the Apache License, Version 2.0 (the "License"); //
// you may not use this file except in compliance with the License. //
// You may obtain a copy of the License at //
// //
// http://www.apache.org/licenses/LICENSE-2.0 //
// //
// Unless required by applicable law or agreed to in writing, software //
// distributed under the License is distributed on an "AS IS" BASIS, //
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. //
// See the License for the specific language governing permissions and //
// limitations under the License. //
// ======================================================================== //
#pragma once
#include "common.h"
namespace oidn {
class Buffer;
class Filter;
class Device : public RefCount, public Verbose
{
private:
// Thread-safety
std::mutex mutex;
// Error handling
struct ErrorState
{
Error code = Error::None;
std::string message;
};
static thread_local ErrorState globalError;
ThreadLocal<ErrorState> error;
ErrorFunction errorFunc = nullptr;
void* errorUserPtr = nullptr;
// -- GODOT start --
// // Tasking
// std::shared_ptr<tbb::task_arena> arena;
// std::shared_ptr<PinningObserver> observer;
// std::shared_ptr<ThreadAffinity> affinity;
// -- GODOT end --
// Parameters
int numThreads = 0; // autodetect by default
bool setAffinity = true;
bool dirty = true;
public:
Device();
~Device();
static void setError(Device* device, Error code, const std::string& message);
static Error getError(Device* device, const char** outMessage);
void setErrorFunction(ErrorFunction func, void* userPtr);
int get1i(const std::string& name);
void set1i(const std::string& name, int value);
void commit();
// -- GODOT start --
// template<typename F>
// void executeTask(F& f)
// {
// arena->execute(f);
// }
// template<typename F>
// void executeTask(const F& f)
// {
// arena->execute(f);
// }
// -- GODOT end --
Ref<Buffer> newBuffer(size_t byteSize);
Ref<Buffer> newBuffer(void* ptr, size_t byteSize);
Ref<Filter> newFilter(const std::string& type);
__forceinline Device* getDevice() { return this; }
__forceinline std::mutex& getMutex() { return mutex; }
private:
// -- GODOT start --
//bool isCommitted() const { return bool(arena); }
bool isCommitted() const { return false; }
// -- GODOT end --
void checkCommitted();
void print();
};
} // namespace oidn


@ -1,27 +0,0 @@
// ======================================================================== //
// Copyright 2009-2019 Intel Corporation //
// //
// Licensed under the Apache License, Version 2.0 (the "License"); //
// you may not use this file except in compliance with the License. //
// You may obtain a copy of the License at //
// //
// http://www.apache.org/licenses/LICENSE-2.0 //
// //
// Unless required by applicable law or agreed to in writing, software //
// distributed under the License is distributed on an "AS IS" BASIS, //
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. //
// See the License for the specific language governing permissions and //
// limitations under the License. //
// ======================================================================== //
#include "filter.h"
namespace oidn {
void Filter::setProgressMonitorFunction(ProgressMonitorFunction func, void* userPtr)
{
progressFunc = func;
progressUserPtr = userPtr;
}
} // namespace oidn


@ -1,52 +0,0 @@
// ======================================================================== //
// Copyright 2009-2019 Intel Corporation //
// //
// Licensed under the Apache License, Version 2.0 (the "License"); //
// you may not use this file except in compliance with the License. //
// You may obtain a copy of the License at //
// //
// http://www.apache.org/licenses/LICENSE-2.0 //
// //
// Unless required by applicable law or agreed to in writing, software //
// distributed under the License is distributed on an "AS IS" BASIS, //
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. //
// See the License for the specific language governing permissions and //
// limitations under the License. //
// ======================================================================== //
#pragma once
#include "common.h"
#include "device.h"
#include "image.h"
namespace oidn {
class Filter : public RefCount
{
protected:
Ref<Device> device;
ProgressMonitorFunction progressFunc = nullptr;
void* progressUserPtr = nullptr;
bool dirty = true;
public:
explicit Filter(const Ref<Device>& device) : device(device) {}
virtual void setImage(const std::string& name, const Image& data) = 0;
virtual void set1i(const std::string& name, int value) = 0;
virtual int get1i(const std::string& name) = 0;
virtual void set1f(const std::string& name, float value) = 0;
virtual float get1f(const std::string& name) = 0;
void setProgressMonitorFunction(ProgressMonitorFunction func, void* userPtr);
virtual void commit() = 0;
virtual void execute() = 0;
Device* getDevice() { return device.get(); }
};
} // namespace oidn


@ -1,111 +0,0 @@
// ======================================================================== //
// Copyright 2009-2019 Intel Corporation //
// //
// Licensed under the Apache License, Version 2.0 (the "License"); //
// you may not use this file except in compliance with the License. //
// You may obtain a copy of the License at //
// //
// http://www.apache.org/licenses/LICENSE-2.0 //
// //
// Unless required by applicable law or agreed to in writing, software //
// distributed under the License is distributed on an "AS IS" BASIS, //
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. //
// See the License for the specific language governing permissions and //
// limitations under the License. //
// ======================================================================== //
#pragma once
#include "common.h"
#include "buffer.h"
namespace oidn {
struct Image
{
static constexpr int maxSize = 65536;
char* ptr; // pointer to the first pixel
int width; // width in number of pixels
int height; // height in number of pixels
size_t bytePixelStride; // pixel stride in number of *bytes*
size_t rowStride; // row stride in number of *pixel strides*
Format format; // pixel format
Ref<Buffer> buffer; // buffer containing the image data
Image() : ptr(nullptr), width(0), height(0), bytePixelStride(0), rowStride(0), format(Format::Undefined) {}
Image(void* ptr, Format format, int width, int height, size_t byteOffset, size_t inBytePixelStride, size_t inByteRowStride)
{
if (ptr == nullptr)
throw Exception(Error::InvalidArgument, "buffer pointer null");
init((char*)ptr + byteOffset, format, width, height, inBytePixelStride, inByteRowStride);
}
Image(const Ref<Buffer>& buffer, Format format, int width, int height, size_t byteOffset, size_t inBytePixelStride, size_t inByteRowStride)
{
init(buffer->data() + byteOffset, format, width, height, inBytePixelStride, inByteRowStride);
if (byteOffset + height * rowStride * bytePixelStride > buffer->size())
throw Exception(Error::InvalidArgument, "buffer region out of range");
}
void init(char* ptr, Format format, int width, int height, size_t inBytePixelStride, size_t inByteRowStride)
{
assert(width >= 0);
assert(height >= 0);
if (width > maxSize || height > maxSize)
throw Exception(Error::InvalidArgument, "image size too large");
this->ptr = ptr;
this->width = width;
this->height = height;
const size_t pixelSize = getFormatBytes(format);
if (inBytePixelStride != 0)
{
if (inBytePixelStride < pixelSize)
throw Exception(Error::InvalidArgument, "pixel stride smaller than pixel size");
this->bytePixelStride = inBytePixelStride;
}
else
{
this->bytePixelStride = pixelSize;
}
if (inByteRowStride != 0)
{
if (inByteRowStride < width * this->bytePixelStride)
throw Exception(Error::InvalidArgument, "row stride smaller than width * pixel stride");
if (inByteRowStride % this->bytePixelStride != 0)
throw Exception(Error::InvalidArgument, "row stride not integer multiple of pixel stride");
this->rowStride = inByteRowStride / this->bytePixelStride;
}
else
{
this->rowStride = width;
}
this->format = format;
}
__forceinline char* get(int y, int x)
{
return ptr + ((size_t(y) * rowStride + size_t(x)) * bytePixelStride);
}
__forceinline const char* get(int y, int x) const
{
return ptr + ((size_t(y) * rowStride + size_t(x)) * bytePixelStride);
}
operator bool() const
{
return ptr != nullptr;
}
};
} // namespace oidn
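
The `Image` struct above stores the pixel stride in bytes but the row stride in units of pixel strides, and treats a stride of 0 as "tightly packed". The sketch below reproduces just that stride logic in a self-contained form, assuming a simplified type: `StridedView` and `makeView` are hypothetical names, `std::invalid_argument` stands in for the OIDN `Exception`, and the buffer, format and size-limit handling are omitted.

```cpp
#include <cassert>
#include <cstddef>
#include <stdexcept>

// Hypothetical simplified view of a strided image; not the OIDN Image struct.
struct StridedView
{
  char* ptr = nullptr;
  std::size_t bytePixelStride = 0; // distance between pixels, in bytes
  std::size_t rowStride = 0;       // distance between rows, in pixel strides

  // Address of pixel (y, x), using the same formula as Image::get().
  char* at(int y, int x) const
  {
    return ptr + (std::size_t(y) * rowStride + std::size_t(x)) * bytePixelStride;
  }
};

// Builds a view from byte strides; a stride of 0 means "tightly packed".
inline StridedView makeView(char* ptr, int width, std::size_t pixelSize,
                            std::size_t inBytePixelStride, std::size_t inByteRowStride)
{
  assert(width >= 0);
  assert(pixelSize > 0);

  StridedView view;
  view.ptr = ptr;

  if (inBytePixelStride != 0 && inBytePixelStride < pixelSize)
    throw std::invalid_argument("pixel stride smaller than pixel size");
  view.bytePixelStride = (inBytePixelStride != 0) ? inBytePixelStride : pixelSize;

  if (inByteRowStride == 0)
  {
    view.rowStride = std::size_t(width);
    return view;
  }
  if (inByteRowStride < std::size_t(width) * view.bytePixelStride)
    throw std::invalid_argument("row stride smaller than width * pixel stride");
  if (inByteRowStride % view.bytePixelStride != 0)
    throw std::invalid_argument("row stride not integer multiple of pixel stride");

  // Convert the row stride from bytes to pixel strides.
  view.rowStride = inByteRowStride / view.bytePixelStride;
  return view;
}
```

For a 4-pixel-wide image of 12-byte pixels padded to a 16-byte pixel stride and an 80-byte row stride, the row stride becomes 5 pixel strides, so pixel (2, 3) sits at byte offset (2 * 5 + 3) * 16 = 208.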


@ -1,232 +0,0 @@
// ======================================================================== //
// Copyright 2009-2019 Intel Corporation //
// //
// Licensed under the Apache License, Version 2.0 (the "License"); //
// you may not use this file except in compliance with the License. //
// You may obtain a copy of the License at //
// //
// http://www.apache.org/licenses/LICENSE-2.0 //
// //
// Unless required by applicable law or agreed to in writing, software //
// distributed under the License is distributed on an "AS IS" BASIS, //
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. //
// See the License for the specific language governing permissions and //
// limitations under the License. //
// ======================================================================== //
#pragma once
#include "node.h"
#include "image.h"
namespace oidn {
// Input reorder node
template<int K, class TransferFunction>
class InputReorderNode : public Node
{
private:
// Source
Image color;
Image albedo;
Image normal;
// Destination
std::shared_ptr<memory> dst;
float* dstPtr;
int C2;
int H2;
int W2;
// Tile
int h1Begin;
int w1Begin;
int h2Begin;
int w2Begin;
int H;
int W;
std::shared_ptr<TransferFunction> transferFunc;
public:
InputReorderNode(const Image& color,
const Image& albedo,
const Image& normal,
const std::shared_ptr<memory>& dst,
const std::shared_ptr<TransferFunction>& transferFunc)
: color(color), albedo(albedo), normal(normal),
dst(dst),
h1Begin(0), w1Begin(0),
H(color.height), W(color.width),
transferFunc(transferFunc)
{
const mkldnn_memory_desc_t& dstDesc = dst->get_desc().data;
assert(memory_desc_matches_tag(dstDesc, mkldnn_format_tag_t(BlockedFormat<K>::nChwKc)));
assert(dstDesc.ndims == 4);
assert(dstDesc.data_type == memory::data_type::f32);
assert(dstDesc.dims[0] == 1);
//assert(dstDesc.dims[1] >= getPadded<K>(C1));
dstPtr = (float*)dst->get_data_handle();
C2 = dstDesc.dims[1];
H2 = dstDesc.dims[2];
W2 = dstDesc.dims[3];
}
void setTile(int h1, int w1, int h2, int w2, int H, int W) override
{
h1Begin = h1;
w1Begin = w1;
h2Begin = h2;
w2Begin = w2;
this->H = H;
this->W = W;
}
void execute(stream& sm) override
{
assert(H + h1Begin <= color.height);
assert(W + w1Begin <= color.width);
assert(H + h2Begin <= H2);
assert(W + w2Begin <= W2);
parallel_nd(H2, [&](int h2)
{
const int h = h2 - h2Begin;
if (h >= 0 && h < H)
{
const int h1 = h + h1Begin;
// Zero pad
for (int w2 = 0; w2 < w2Begin; ++w2)
{
int c = 0;
while (c < C2)
store(h2, w2, c, 0.f);
}
// Reorder
for (int w = 0; w < W; ++w)
{
const int w1 = w + w1Begin;
const int w2 = w + w2Begin;
int c = 0;
storeColor(h2, w2, c, (float*)color.get(h1, w1));
if (albedo)
storeAlbedo(h2, w2, c, (float*)albedo.get(h1, w1));
if (normal)
storeNormal(h2, w2, c, (float*)normal.get(h1, w1));
while (c < C2)
store(h2, w2, c, 0.f);
}
// Zero pad
for (int w2 = W + w2Begin; w2 < W2; ++w2)
{
int c = 0;
while (c < C2)
store(h2, w2, c, 0.f);
}
}
else
{
// Zero pad
for (int w2 = 0; w2 < W2; ++w2)
{
int c = 0;
while (c < C2)
store(h2, w2, c, 0.f);
}
}
});
}
std::shared_ptr<memory> getDst() const override { return dst; }
private:
// Stores a single value
__forceinline void store(int h, int w, int& c, float value)
{
// Destination is in nChwKc format
float* dst_c = dstPtr + (H2*W2*K*(c/K)) + h*W2*K + w*K + (c%K);
*dst_c = value;
c++;
}
// Stores a color
__forceinline void storeColor(int h, int w, int& c, const float* values)
{
#pragma unroll
for (int i = 0; i < 3; ++i)
{
// Load the value
float x = values[i];
// Sanitize the value
x = maxSafe(x, 0.f);
// Apply the transfer function
x = transferFunc->forward(x);
// Store the value
store(h, w, c, x);
}
}
// Stores an albedo
__forceinline void storeAlbedo(int h, int w, int& c, const float* values)
{
#pragma unroll
for (int i = 0; i < 3; ++i)
{
// Load the value
float x = values[i];
// Sanitize the value
x = clampSafe(x, 0.f, 1.f);
// Store the value
store(h, w, c, x);
}
}
// Stores a normal
__forceinline void storeNormal(int h, int w, int& c, const float* values)
{
// Load the normal
float x = values[0];
float y = values[1];
float z = values[2];
// Compute the length of the normal
const float lengthSqr = sqr(x) + sqr(y) + sqr(z);
// Normalize the normal and transform it to [0..1]
if (isfinite(lengthSqr))
{
const float invLength = (lengthSqr > minVectorLengthSqr) ? rsqrt(lengthSqr) : 1.f;
const float scale = invLength * 0.5f;
const float offset = 0.5f;
x = x * scale + offset;
y = y * scale + offset;
z = z * scale + offset;
}
else
{
x = 0.f;
y = 0.f;
z = 0.f;
}
// Store the normal
store(h, w, c, x);
store(h, w, c, y);
store(h, w, c, z);
}
};
} // namespace oidn
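
The `store()` helper above writes into a blocked nChwKc tensor: channels are grouped into blocks of K, each block holds a full H x W plane, and the K channels of one pixel are adjacent in memory. The following sketch isolates that index computation for a batch size of 1. `blockedOffset` is a hypothetical helper name; the arithmetic mirrors the expression in `store()`.

```cpp
#include <cassert>
#include <cstddef>

// Element offset of channel c at pixel (h, w) in an nChwKc tensor with N = 1.
// K is the channel block size, H and W are the padded tensor height and width.
constexpr std::size_t blockedOffset(int K, int H, int W, int c, int h, int w)
{
  return std::size_t(H) * std::size_t(W) * std::size_t(K) * std::size_t(c / K) // channel block
       + std::size_t(h) * std::size_t(W) * std::size_t(K)                      // row within the block
       + std::size_t(w) * std::size_t(K)                                       // pixel within the row
       + std::size_t(c % K);                                                   // channel within the block
}
```

With K = 8, H = 2 and W = 3, channels 0 to 7 of pixel (0, 0) occupy offsets 0 to 7, channel 8 starts the second block at offset 2 * 3 * 8 = 48, and channel 9 at pixel (1, 2) lands at 48 + 24 + 16 + 1 = 89.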


@ -1,78 +0,0 @@
// ======================================================================== //
// Copyright 2009-2019 Intel Corporation //
// //
// Licensed under the Apache License, Version 2.0 (the "License"); //
// you may not use this file except in compliance with the License. //
// You may obtain a copy of the License at //
// //
// http://www.apache.org/licenses/LICENSE-2.0 //
// //
// Unless required by applicable law or agreed to in writing, software //
// distributed under the License is distributed on an "AS IS" BASIS, //
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. //
// See the License for the specific language governing permissions and //
// limitations under the License. //
// ======================================================================== //
#pragma once
#include "common/platform.h"
namespace oidn {
constexpr float minVectorLength = 1e-10f;
constexpr float minVectorLengthSqr = minVectorLength * minVectorLength;
using std::log;
using std::log2;
using std::exp;
using std::exp2;
using std::pow;
using std::isfinite;
using std::isnan;
__forceinline float sqr(float x)
{
return x * x;
}
__forceinline float rcp(float x)
{
__m128 r = _mm_rcp_ss(_mm_set_ss(x));
return _mm_cvtss_f32(_mm_sub_ss(_mm_add_ss(r, r), _mm_mul_ss(_mm_mul_ss(r, r), _mm_set_ss(x))));
}
__forceinline float rsqrt(float x)
{
__m128 r = _mm_rsqrt_ss(_mm_set_ss(x));
return _mm_cvtss_f32(_mm_add_ss(_mm_mul_ss(_mm_set_ss(1.5f), r),
_mm_mul_ss(_mm_mul_ss(_mm_mul_ss(_mm_set_ss(x), _mm_set_ss(-0.5f)), r), _mm_mul_ss(r, r))));
}
__forceinline float maxSafe(float value, float minValue)
{
return isfinite(value) ? max(value, minValue) : minValue;
}
__forceinline float clampSafe(float value, float minValue, float maxValue)
{
return isfinite(value) ? clamp(value, minValue, maxValue) : minValue;
}
// Returns ceil(a / b) for non-negative integers
template<class Int>
__forceinline constexpr Int ceilDiv(Int a, Int b)
{
//assert(a >= 0);
//assert(b > 0);
return (a + b - 1) / b;
}
// Returns a rounded up to multiple of b
template<class Int>
__forceinline constexpr Int roundUp(Int a, Int b)
{
return ceilDiv(a, b) * b;
}
} // namespace oidn
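
The `rcp` and `rsqrt` helpers above take a low-precision SSE estimate and refine it with what appears to be a single Newton-Raphson step. Written out, the two updates are r' = 2r - x r^2 for the reciprocal and r' = 1.5 r - 0.5 x r^3 for the reciprocal square root. The sketch below performs the same refinement in portable scalar code, with the initial estimate passed in as a parameter instead of coming from `_mm_rcp_ss` or `_mm_rsqrt_ss`. `refineRcp` and `refineRsqrt` are hypothetical names.

```cpp
#include <cassert>
#include <cmath>

// One Newton-Raphson step towards 1/x, starting from the estimate r:
// r' = 2r - x * r^2
inline float refineRcp(float x, float r)
{
  return (r + r) - (r * r) * x;
}

// One Newton-Raphson step towards 1/sqrt(x), starting from the estimate r:
// r' = 1.5r - 0.5 * x * r^3
inline float refineRsqrt(float x, float r)
{
  return 1.5f * r + (x * -0.5f * r) * (r * r);
}
```

Starting from r = 0.24 as a rough estimate of 1/4, `refineRcp(4.f, 0.24f)` moves to about 0.2496; starting from r = 0.49 as an estimate of 1/sqrt(4), `refineRsqrt(4.f, 0.49f)` moves to about 0.4997. Each step roughly squares the relative error, which is why one step after the hardware estimate is often considered sufficient for single precision work of this kind.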


@ -1,436 +0,0 @@
// ======================================================================== //
// Copyright 2009-2019 Intel Corporation //
// //
// Licensed under the Apache License, Version 2.0 (the "License"); //
// you may not use this file except in compliance with the License. //
// You may obtain a copy of the License at //
// //
// http://www.apache.org/licenses/LICENSE-2.0 //
// //
// Unless required by applicable law or agreed to in writing, software //
// distributed under the License is distributed on an "AS IS" BASIS, //
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. //
// See the License for the specific language governing permissions and //
// limitations under the License. //
// ======================================================================== //
#include "upsample.h"
#include "weights_reorder.h"
#include "network.h"
// -- GODOT start --
#include <cstring>
// -- GODOT end --
namespace oidn {
template<int K>
Network<K>::Network(const Ref<Device>& device, const std::map<std::string, Tensor>& weightMap)
: device(device),
eng(engine::cpu, 0),
sm(eng),
weightMap(weightMap)
{
}
template<int K>
void Network<K>::execute(const Progress& progress, int taskIndex)
{
if (progress.func)
{
const double value = double(taskIndex) / double(progress.taskCount);
if (!progress.func(progress.userPtr, value))
throw Exception(Error::Cancelled, "execution was cancelled");
}
for (size_t i = 0; i < nodes.size(); ++i)
{
nodes[i]->execute(sm);
if (progress.func)
{
const double value = (double(taskIndex) + double(i+1) / double(nodes.size())) / double(progress.taskCount);
if (!progress.func(progress.userPtr, value))
throw Exception(Error::Cancelled, "execution was cancelled");
}
}
}
template<int K>
std::shared_ptr<memory> Network<K>::allocTensor(const memory::dims& dims,
memory::format_tag format,
void* data)
{
if (format == memory::format_tag::any)
{
if (dims.size() == 4)
format = BlockedFormat<K>::nChwKc;
else if (dims.size() == 1)
format = memory::format_tag::x;
else
assert(0);
}
memory::desc desc(dims, memory::data_type::f32, format);
if (data == nullptr)
{
const size_t bytes = getTensorSize(dims) * sizeof(float);
if (format == BlockedFormat<K>::nChwKc)
activationAllocBytes += bytes;
totalAllocBytes += bytes;
return std::make_shared<memory>(desc, eng);
}
else
{
return std::make_shared<memory>(desc, eng, data);
}
}
template<int K>
std::shared_ptr<memory> Network<K>::castTensor(const memory::dims& dims,
const std::shared_ptr<memory>& src,
size_t srcOffset,
memory::format_tag format)
{
const mkldnn_memory_desc_t& srcDesc = src->get_desc().data;
MAYBE_UNUSED(srcDesc);
assert(srcDesc.data_type == memory::data_type::f32);
assert(getTensorSize(src) >= srcOffset + getTensorSize(dims));
if (format == memory::format_tag::any)
{
if (dims.size() == 4)
format = BlockedFormat<K>::nChwKc;
else if (dims.size() == 1)
format = memory::format_tag::x;
else
assert(0);
}
memory::desc desc(dims, memory::data_type::f32, format);
float* srcPtr = (float*)src->get_data_handle() + srcOffset;
return std::make_shared<memory>(desc, eng, srcPtr);
}
template<int K>
std::shared_ptr<memory> Network<K>::castTensor(const memory::dims& dims,
const std::shared_ptr<memory>& src,
const memory::dims& srcOffset)
{
return castTensor(dims, src, getTensorSize(srcOffset));
}
template<int K>
void Network<K>::zeroTensor(const std::shared_ptr<memory>& dst)
{
assert(getTensorType(dst) == memory::data_type::f32);
memset(dst->get_data_handle(), 0, getTensorSize(dst)*sizeof(float));
}
template<int K>
memory::dims Network<K>::getInputReorderDims(const memory::dims& srcDims, int alignment)
{
memory::dims dstDims = srcDims;
dstDims[1] = getPadded<K>(srcDims[1]); // round up C
dstDims[2] = roundUp(srcDims[2], memory::dim(alignment)); // round up H
dstDims[3] = roundUp(srcDims[3], memory::dim(alignment)); // round up W
return dstDims;
}
template<int K>
std::shared_ptr<Node> Network<K>::addInputReorder(const Image& color,
const Image& albedo,
const Image& normal,
const std::shared_ptr<TransferFunction>& transferFunc,
int alignment,
const std::shared_ptr<memory>& userDst)
{
assert(color);
int inputC = 3;
if (albedo) inputC += 3;
if (normal) inputC += 3;
memory::dims srcDims = {1, inputC, color.height, color.width};
memory::dims dstDims = getInputReorderDims(srcDims, alignment);
// Allocate padded memory
auto dst = userDst;
if (!dst)
dst = allocTensor(dstDims);
// Push node
std::shared_ptr<Node> node;
if (auto tf = std::dynamic_pointer_cast<LinearTransferFunction>(transferFunc))
node = std::make_shared<InputReorderNode<K, LinearTransferFunction>>(color, albedo, normal, dst, tf);
else if (auto tf = std::dynamic_pointer_cast<GammaTransferFunction>(transferFunc))
node = std::make_shared<InputReorderNode<K, GammaTransferFunction>>(color, albedo, normal, dst, tf);
else if (auto tf = std::dynamic_pointer_cast<LogTransferFunction>(transferFunc))
node = std::make_shared<InputReorderNode<K, LogTransferFunction>>(color, albedo, normal, dst, tf);
else if (auto tf = std::dynamic_pointer_cast<PQXTransferFunction>(transferFunc))
node = std::make_shared<InputReorderNode<K, PQXTransferFunction>>(color, albedo, normal, dst, tf);
else
assert(0);
nodes.push_back(node);
return node;
}
template<int K>
std::shared_ptr<Node> Network<K>::addOutputReorder(const std::shared_ptr<memory>& src,
const std::shared_ptr<TransferFunction>& transferFunc,
const Image& output)
{
memory::dims srcDims = getTensorDims(src);
assert(srcDims[1] == K);
// Push node
std::shared_ptr<Node> node;
if (auto tf = std::dynamic_pointer_cast<LinearTransferFunction>(transferFunc))
node = std::make_shared<OutputReorderNode<K, LinearTransferFunction>>(src, output, tf);
else if (auto tf = std::dynamic_pointer_cast<GammaTransferFunction>(transferFunc))
node = std::make_shared<OutputReorderNode<K, GammaTransferFunction>>(src, output, tf);
else if (auto tf = std::dynamic_pointer_cast<LogTransferFunction>(transferFunc))
node = std::make_shared<OutputReorderNode<K, LogTransferFunction>>(src, output, tf);
else if (auto tf = std::dynamic_pointer_cast<PQXTransferFunction>(transferFunc))
node = std::make_shared<OutputReorderNode<K, PQXTransferFunction>>(src, output, tf);
else
assert(0);
nodes.push_back(node);
return node;
}
template<int K>
memory::dims Network<K>::getConvDims(const std::string& name, const memory::dims& srcDims)
{
auto b = weightMap[name + "/b"];
memory::dims dstDims = srcDims;
dstDims[1] = getPadded<K>(b.dims[0]); // dstDims[C] = getPadded(OC)
return dstDims;
}
template<int K>
std::shared_ptr<Node> Network<K>::addConv(const std::string& name,
const std::shared_ptr<memory>& src,
const std::shared_ptr<memory>& userDst,
bool relu)
{
const memory::dims strides = {1, 1};
const memory::dims padding = {1, 1};
memory::dims srcDims = getTensorDims(src);
// Get the weights
const auto& W = weightMap[name + "/W"];
if (W.ndims() != 4 || W.format != "oihw")
throw Exception(Error::InvalidOperation, "invalid convolution weights");
memory::dims weightsDims = W.dims;
auto userWeights = allocTensor(weightsDims, memory::format_tag::oihw, W.data);
// Pad the weights
memory::dims weightsPadDims = weightsDims;
weightsPadDims[1] = getPadded<K>(weightsDims[1]); // IC
weightsPadDims[0] = getPadded<K>(weightsDims[0]); // OC
assert(srcDims[1] == weightsPadDims[1]); // srcDims[C] == weightsPadDims[IC]
auto weightsPad = allocTensor(weightsPadDims, memory::format_tag::oihw);
WeightsReorderNode<K>(userWeights, weightsPad).execute(sm);
// Get the biases
const auto& b = weightMap[name + "/b"];
if (b.ndims() != 1)
throw Exception(Error::InvalidOperation, "invalid convolution biases");
memory::dims biasDims = b.dims;
// Copy/pad the biases
memory::dims biasPadDims = {getPadded<K>(biasDims[0])};
auto bias = allocTensor(biasPadDims);
if (biasDims[0] != biasPadDims[0])
memset(bias->get_data_handle(), 0, biasPadDims[0]*sizeof(float));
memcpy(bias->get_data_handle(), b.data, biasDims[0]*sizeof(float));
// Allocate memory for destination
memory::dims dstDims = srcDims;
dstDims[1] = weightsPadDims[0]; // dstDims[C] = weightsPadDims[OC]
std::shared_ptr<memory> dst;
if (!userDst)
dst = allocTensor(dstDims);
else if (getTensorDims(userDst) == dstDims)
dst = userDst;
else
dst = castTensor(dstDims, userDst);
// Create a convolution
// Let the convolution primitive choose the weights format
auto weightsDesc = memory::desc({ weightsPadDims }, memory::data_type::f32, memory::format_tag::any);
auto convAlgo = (K == 16) ? convolution_winograd : convolution_direct;
auto convDesc = convolution_forward::desc(
prop_kind::forward_inference, convAlgo,
src->get_desc(),
weightsDesc,
bias->get_desc(),
dst->get_desc(),
strides, padding, padding, padding_kind::zero);
// Incorporate relu
mkldnn::primitive_attr convAttr;
if (relu)
{
mkldnn::post_ops ops;
ops.append_eltwise(
1.f, // scale factor, not used
algorithm::eltwise_relu,
0.f, // max with
0.f // unused
);
convAttr.set_post_ops(ops);
}
convAttr.set_scratchpad_mode(scratchpad_mode_user);
auto convPrimDesc = convolution_forward::primitive_desc(convDesc, convAttr, eng);
// Reorder the weights to the final format, if necessary
auto weights = weightsPad;
if (convPrimDesc.weights_desc() != weightsPad->get_desc())
{
weights = std::make_shared<memory>(convPrimDesc.weights_desc(), eng);
ReorderNode(weightsPad, weights).execute(sm);
}
// Create convolution node and add it to the net
auto node = std::make_shared<ConvNode>(convPrimDesc, src, weights, bias, dst);
nodes.push_back(node);
return node;
}
template<int K>
memory::dims Network<K>::getPoolDims(const memory::dims& srcDims)
{
memory::dims dstDims = srcDims;
dstDims[2] /= 2; // H/2
dstDims[3] /= 2; // W/2
return dstDims;
}
template<int K>
std::shared_ptr<Node> Network<K>::addPool(const std::shared_ptr<memory>& src,
const std::shared_ptr<memory>& userDst)
{
const memory::dims kernel = {2, 2};
const memory::dims strides = {2, 2};
const memory::dims padding = {0, 0};
memory::dims srcDims = getTensorDims(src);
memory::dims dstDims = getPoolDims(srcDims);
std::shared_ptr<memory> dst;
if (!userDst)
dst = allocTensor(dstDims);
else if (getTensorDims(userDst) == dstDims)
dst = userDst;
else
dst = castTensor(dstDims, userDst);
auto poolDesc = pooling_forward::desc(
prop_kind::forward_inference, pooling_max,
src->get_desc(),
dst->get_desc(),
strides, kernel, padding, padding, padding_kind::zero);
mkldnn::primitive_attr poolAttr;
poolAttr.set_scratchpad_mode(scratchpad_mode_user);
auto poolPrimDesc = pooling_forward::primitive_desc(poolDesc, poolAttr, eng);
auto node = std::make_shared<PoolNode>(poolPrimDesc, src, dst);
nodes.push_back(node);
return node;
}
template<int K>
memory::dims Network<K>::getUpsampleDims(const memory::dims& srcDims)
{
memory::dims dstDims = srcDims;
dstDims[2] *= 2; // H*2
dstDims[3] *= 2; // W*2
return dstDims;
}
template<int K>
std::shared_ptr<Node> Network<K>::addUpsample(const std::shared_ptr<memory>& src,
const std::shared_ptr<memory>& userDst)
{
memory::dims srcDims = getTensorDims(src);
memory::dims dstDims = getUpsampleDims(srcDims);
std::shared_ptr<memory> dst;
if (!userDst)
dst = allocTensor(dstDims);
else if (getTensorDims(userDst) == dstDims)
dst = userDst;
else
dst = castTensor(dstDims, userDst);
// Create upsampling node and add it to net
auto node = std::make_shared<UpsampleNode<K>>(src, dst);
nodes.push_back(node);
return node;
}
template<int K>
memory::dims Network<K>::getConcatDims(const memory::dims& src1Dims, const memory::dims& src2Dims)
{
assert(src1Dims[0] == src2Dims[0]); // N
assert(src1Dims[2] == src2Dims[2]); // H
assert(src1Dims[3] == src2Dims[3]); // W
memory::dims dstDims = src1Dims;
dstDims[1] += src2Dims[1]; // C
return dstDims;
}
template<int K>
std::shared_ptr<Node> Network<K>::addAutoexposure(const Image& color,
const std::shared_ptr<HDRTransferFunction>& transferFunc)
{
auto node = std::make_shared<AutoexposureNode>(color, transferFunc);
nodes.push_back(node);
return node;
}
template <int K>
void Network<K>::finalize()
{
// Compute the size of the scratchpad
size_t scratchpadSize = 0;
for (const auto& node : nodes)
scratchpadSize = max(scratchpadSize, node->getScratchpadSize());
// Allocate the scratchpad
memory::dims scratchpadDims = { memory::dim(scratchpadSize) };
memory::desc scratchpadDesc(scratchpadDims, memory::data_type::u8, memory::format_tag::x);
auto scratchpad = std::make_shared<memory>(scratchpadDesc, eng);
activationAllocBytes += scratchpadSize;
totalAllocBytes += scratchpadSize;
// Set the scratchpad for the nodes
for (auto& node : nodes)
node->setScratchpad(scratchpad);
// Free the weights
weightMap.clear();
// Print statistics
if (device->isVerbose(2))
{
std::cout << "Activation bytes: " << activationAllocBytes << std::endl;
std::cout << "Scratchpad bytes: " << scratchpadSize << std::endl;
std::cout << "Total bytes : " << totalAllocBytes << std::endl;
}
}
template class Network<8>;
template class Network<16>;
} // namespace oidn
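
The removed `finalize()` sizes one scratchpad to the largest per-node requirement and hands the same buffer to every node. A minimal sketch of that sizing rule follows; `NodeSketch` and `sharedScratchpadSize` are hypothetical names used only for illustration, and the sketch assumes nodes run one at a time so the buffer can be shared.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical stand-in for a network node: only its scratchpad requirement matters here.
struct NodeSketch
{
  std::size_t scratchpadSize;
};

// Returns the size of a single scratchpad large enough for any one node.
// Assumption: nodes execute sequentially, so they never need scratch space at the same time.
std::size_t sharedScratchpadSize(const std::vector<NodeSketch>& nodes)
{
  std::size_t size = 0;
  for (const auto& node : nodes)
    size = std::max(size, node.scratchpadSize);
  return size;
}
```

Taking the maximum rather than the sum keeps the allocation bounded by the most demanding node.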

View file

@ -1,112 +0,0 @@
// ======================================================================== //
// Copyright 2009-2019 Intel Corporation //
// //
// Licensed under the Apache License, Version 2.0 (the "License"); //
// you may not use this file except in compliance with the License. //
// You may obtain a copy of the License at //
// //
// http://www.apache.org/licenses/LICENSE-2.0 //
// //
// Unless required by applicable law or agreed to in writing, software //
// distributed under the License is distributed on an "AS IS" BASIS, //
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. //
// See the License for the specific language governing permissions and //
// limitations under the License. //
// ======================================================================== //
#include "common/tensor.h"
#include "image.h"
#include "node.h"
#include "input_reorder.h"
#include "output_reorder.h"
#include "transfer_function.h"
#pragma once
namespace oidn {
// Progress state
struct Progress
{
ProgressMonitorFunction func;
void* userPtr;
int taskCount;
};
class Executable
{
public:
virtual ~Executable() {}
virtual void execute(const Progress& progress, int taskIndex) = 0;
};
template<int K>
class Network : public Executable
{
public:
Network(const Ref<Device>& device, const std::map<std::string, Tensor>& weightMap);
void execute(const Progress& progress, int taskIndex) override;
std::shared_ptr<memory> allocTensor(const memory::dims& dims,
memory::format_tag format = memory::format_tag::any,
void* data = nullptr);
std::shared_ptr<memory> castTensor(const memory::dims& dims,
const std::shared_ptr<memory>& src,
size_t srcOffset = 0,
memory::format_tag format = memory::format_tag::any);
std::shared_ptr<memory> castTensor(const memory::dims& dims,
const std::shared_ptr<memory>& src,
const memory::dims& srcOffset);
void zeroTensor(const std::shared_ptr<memory>& dst);
memory::dims getInputReorderDims(const memory::dims& srcDims, int alignment);
std::shared_ptr<Node> addInputReorder(const Image& color,
const Image& albedo,
const Image& normal,
const std::shared_ptr<TransferFunction>& transferFunc,
int alignment,
const std::shared_ptr<memory>& userDst = nullptr);
std::shared_ptr<Node> addOutputReorder(const std::shared_ptr<memory>& src,
const std::shared_ptr<TransferFunction>& transferFunc,
const Image& output);
memory::dims getConvDims(const std::string& name, const memory::dims& srcDims);
std::shared_ptr<Node> addConv(const std::string& name,
const std::shared_ptr<memory>& src,
const std::shared_ptr<memory>& userDst = nullptr,
bool relu = true);
memory::dims getPoolDims(const memory::dims& srcDims);
std::shared_ptr<Node> addPool(const std::shared_ptr<memory>& src,
const std::shared_ptr<memory>& userDst = nullptr);
memory::dims getUpsampleDims(const memory::dims& srcDims);
std::shared_ptr<Node> addUpsample(const std::shared_ptr<memory>& src,
const std::shared_ptr<memory>& userDst = nullptr);
memory::dims getConcatDims(const memory::dims& src1Dims, const memory::dims& src2Dims);
std::shared_ptr<Node> addAutoexposure(const Image& color,
const std::shared_ptr<HDRTransferFunction>& transferFunc);
void finalize();
private:
Ref<Device> device;
engine eng;
stream sm;
std::vector<std::shared_ptr<Node>> nodes;
std::map<std::string, Tensor> weightMap;
// Memory allocation statistics
size_t activationAllocBytes = 0; // number of allocated activation bytes
size_t totalAllocBytes = 0; // total number of allocated bytes
};
} // namespace oidn

View file

@ -1,142 +0,0 @@
// ======================================================================== //
// Copyright 2009-2019 Intel Corporation //
// //
// Licensed under the Apache License, Version 2.0 (the "License"); //
// you may not use this file except in compliance with the License. //
// You may obtain a copy of the License at //
// //
// http://www.apache.org/licenses/LICENSE-2.0 //
// //
// Unless required by applicable law or agreed to in writing, software //
// distributed under the License is distributed on an "AS IS" BASIS, //
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. //
// See the License for the specific language governing permissions and //
// limitations under the License. //
// ======================================================================== //
#pragma once
#include "common.h"
#include <vector>
namespace oidn {
class Node
{
public:
virtual ~Node() = default;
virtual void execute(stream& sm) = 0;
virtual std::shared_ptr<memory> getDst() const { return nullptr; }
virtual size_t getScratchpadSize() const { return 0; }
virtual void setScratchpad(const std::shared_ptr<memory>& mem) {}
virtual void setTile(int h1, int w1, int h2, int w2, int H, int W)
{
assert(0); // not supported
}
};
// Node wrapping an MKL-DNN primitive
class MklNode : public Node
{
private:
primitive prim;
std::unordered_map<int, memory> args;
std::shared_ptr<memory> scratchpad;
public:
MklNode(const primitive& prim, const std::unordered_map<int, memory>& args)
: prim(prim),
args(args)
{}
size_t getScratchpadSize() const override
{
const auto primDesc = prim.get_primitive_desc();
const mkldnn_memory_desc_t* scratchpadDesc = mkldnn_primitive_desc_query_md(primDesc, mkldnn_query_scratchpad_md, 0);
if (scratchpadDesc == nullptr)
return 0;
return mkldnn_memory_desc_get_size(scratchpadDesc);
}
void setScratchpad(const std::shared_ptr<memory>& mem) override
{
scratchpad = mem;
args.insert(std::make_pair(MKLDNN_ARG_SCRATCHPAD, *scratchpad));
}
void execute(stream& sm) override
{
prim.execute(sm, args);
}
};
// Convolution node
class ConvNode : public MklNode
{
private:
std::shared_ptr<memory> src;
std::shared_ptr<memory> weights;
std::shared_ptr<memory> bias;
std::shared_ptr<memory> dst;
public:
ConvNode(const convolution_forward::primitive_desc& desc,
const std::shared_ptr<memory>& src,
const std::shared_ptr<memory>& weights,
const std::shared_ptr<memory>& bias,
const std::shared_ptr<memory>& dst)
: MklNode(convolution_forward(desc),
{ { MKLDNN_ARG_SRC, *src },
{ MKLDNN_ARG_WEIGHTS, *weights },
{ MKLDNN_ARG_BIAS, *bias },
{ MKLDNN_ARG_DST, *dst } }),
src(src), weights(weights), bias(bias), dst(dst)
{}
std::shared_ptr<memory> getDst() const override { return dst; }
};
// Pooling node
class PoolNode : public MklNode
{
private:
std::shared_ptr<memory> src;
std::shared_ptr<memory> dst;
public:
PoolNode(const pooling_forward::primitive_desc& desc,
const std::shared_ptr<memory>& src,
const std::shared_ptr<memory>& dst)
: MklNode(pooling_forward(desc),
{ { MKLDNN_ARG_SRC, *src },
{ MKLDNN_ARG_DST, *dst } }),
src(src), dst(dst)
{}
std::shared_ptr<memory> getDst() const override { return dst; }
};
// Reorder node
class ReorderNode : public MklNode
{
private:
std::shared_ptr<memory> src;
std::shared_ptr<memory> dst;
public:
ReorderNode(const std::shared_ptr<memory>& src,
const std::shared_ptr<memory>& dst)
: MklNode(reorder(reorder::primitive_desc(*src, *dst)),
{ { MKLDNN_ARG_SRC, *src },
{ MKLDNN_ARG_DST, *dst } }),
src(src), dst(dst)
{}
std::shared_ptr<memory> getDst() const override { return dst; }
};
} // namespace oidn

View file

@ -1,126 +0,0 @@
// ======================================================================== //
// Copyright 2009-2019 Intel Corporation //
// //
// Licensed under the Apache License, Version 2.0 (the "License"); //
// you may not use this file except in compliance with the License. //
// You may obtain a copy of the License at //
// //
// http://www.apache.org/licenses/LICENSE-2.0 //
// //
// Unless required by applicable law or agreed to in writing, software //
// distributed under the License is distributed on an "AS IS" BASIS, //
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. //
// See the License for the specific language governing permissions and //
// limitations under the License. //
// ======================================================================== //
#pragma once
#include "node.h"
#include "image.h"
namespace oidn {
// Output reorder node
template<int K, class TransferFunction>
class OutputReorderNode : public Node
{
private:
// Source
std::shared_ptr<memory> src;
const float* srcPtr;
int H1;
int W1;
// Destination
Image output;
// Tile
int h1Begin;
int w1Begin;
int h2Begin;
int w2Begin;
int H;
int W;
std::shared_ptr<TransferFunction> transferFunc;
public:
OutputReorderNode(const std::shared_ptr<memory>& src,
const Image& output,
const std::shared_ptr<TransferFunction>& transferFunc)
: src(src),
output(output),
h1Begin(0), w1Begin(0),
h2Begin(0), w2Begin(0),
H(output.height), W(output.width),
transferFunc(transferFunc)
{
const mkldnn_memory_desc_t& srcDesc = src->get_desc().data;
MAYBE_UNUSED(srcDesc);
assert(memory_desc_matches_tag(srcDesc, mkldnn_format_tag_t(BlockedFormat<K>::nChwKc)));
assert(srcDesc.ndims == 4);
assert(srcDesc.data_type == memory::data_type::f32);
assert(srcDesc.dims[0] == 1);
// We assume the output data occupies exactly one block of K output channels
assert(srcDesc.dims[1] == K);
srcPtr = (float*)src->get_data_handle();
H1 = srcDesc.dims[2];
W1 = srcDesc.dims[3];
}
void setTile(int h1, int w1, int h2, int w2, int H, int W) override
{
h1Begin = h1;
w1Begin = w1;
h2Begin = h2;
w2Begin = w2;
this->H = H;
this->W = W;
}
void execute(stream& sm) override
{
assert(h1Begin + H <= H1);
assert(w1Begin + W <= W1);
assert(h2Begin + H <= output.height);
assert(w2Begin + W <= output.width);
const int C1 = K;
parallel_nd(H, [&](int h)
{
const int h1 = h + h1Begin;
const int h2 = h + h2Begin;
for (int w = 0; w < W; ++w)
{
const int w1 = w + w1Begin;
const int w2 = w + w2Begin;
float* dstPtr_C = (float*)output.get(h2, w2);
// Source is in nChwKc format. There is a single channel block here (C == K), so this is effectively nhwc with K channels
const float* srcPtr_C = srcPtr + h1*W1*C1 + w1*C1;
#pragma unroll
for (int i = 0; i < 3; ++i)
{
// Load the value
float x = srcPtr_C[i];
// The CNN output may contain negative values or even NaNs, so it must be sanitized
x = maxSafe(x, 0.f);
// Apply the inverse transfer function
x = transferFunc->inverse(x);
// Sanitize and store the final value
dstPtr_C[i] = max(x, 0.f);
}
}
});
}
};
} // namespace oidn

View file

@ -1,103 +0,0 @@
// ======================================================================== //
// Copyright 2009-2019 Intel Corporation //
// //
// Licensed under the Apache License, Version 2.0 (the "License"); //
// you may not use this file except in compliance with the License. //
// You may obtain a copy of the License at //
// //
// http://www.apache.org/licenses/LICENSE-2.0 //
// //
// Unless required by applicable law or agreed to in writing, software //
// distributed under the License is distributed on an "AS IS" BASIS, //
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. //
// See the License for the specific language governing permissions and //
// limitations under the License. //
// ======================================================================== //
#include "transfer_function.h"
namespace oidn {
const float LogTransferFunction::xScale = 1.f / log(LogTransferFunction::yMax + 1.f);
const float PQXTransferFunction::xScale = 1.f / PQXTransferFunction::pqxForward(PQXTransferFunction::yMax * PQXTransferFunction::yScale);
float AutoexposureNode::autoexposure(const Image& color)
{
assert(color.format == Format::Float3);
constexpr float key = 0.18f;
constexpr float eps = 1e-8f;
constexpr int K = 16; // downsampling factor (approximate block size in pixels)
// Downsample the image to minimize sensitivity to noise
const int H = color.height; // original height
const int W = color.width; // original width
const int HK = (H + K/2) / K; // downsampled height
const int WK = (W + K/2) / K; // downsampled width
// Compute the average log luminance of the downsampled image
using Sum = std::pair<float, int>;
// -- GODOT start --
// Sum sum =
// tbb::parallel_reduce(
// tbb::blocked_range2d<int>(0, HK, 0, WK),
// Sum(0.f, 0),
// [&](const tbb::blocked_range2d<int>& r, Sum sum) -> Sum
// {
// // Iterate over blocks
// for (int i = r.rows().begin(); i != r.rows().end(); ++i)
// {
// for (int j = r.cols().begin(); j != r.cols().end(); ++j)
// {
Sum sum = Sum(0.0f, 0);
for (int i = 0; i != HK; ++i)
{
for (int j = 0; j != WK; ++j)
{
// Compute the average luminance in the current block
const int beginH = int(ptrdiff_t(i) * H / HK);
const int beginW = int(ptrdiff_t(j) * W / WK);
const int endH = int(ptrdiff_t(i+1) * H / HK);
const int endW = int(ptrdiff_t(j+1) * W / WK);
float L = 0.f;
for (int h = beginH; h < endH; ++h)
{
for (int w = beginW; w < endW; ++w)
{
const float* rgb = (const float*)color.get(h, w);
const float r = maxSafe(rgb[0], 0.f);
const float g = maxSafe(rgb[1], 0.f);
const float b = maxSafe(rgb[2], 0.f);
L += luminance(r, g, b);
}
}
L /= (endH - beginH) * (endW - beginW);
// Accumulate the log luminance
if (L > eps)
{
sum.first += log2(L);
sum.second++;
}
}
}
// return sum;
// },
// [](Sum a, Sum b) -> Sum { return Sum(a.first+b.first, a.second+b.second); },
// tbb::static_partitioner()
// );
// -- GODOT end --
return (sum.second > 0) ? (key / exp2(sum.first / float(sum.second))) : 1.f;
}
} // namespace oidn
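
The autoexposure routine above can be read as: average the luminance per block, take the geometric mean of the block averages, and scale so that this mean maps to the 0.18 key. The following is a standalone sketch of that computation on an interleaved RGB float buffer. The names `autoexposureSketch`, `luminanceSketch` and `maxSafeSketch` are hypothetical, and the `maxSafe` behaviour (non-finite values fall back to the lower bound) is an assumption, since its definition is not part of this section.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// Same luminance weights as luminance() in the removed transfer_function.h.
inline float luminanceSketch(float r, float g, float b)
{
  return 0.212671f * r + 0.715160f * g + 0.072169f * b;
}

// Assumed behaviour of maxSafe: clamp from below, mapping NaN/Inf to the bound.
inline float maxSafeSketch(float x, float lo)
{
  return std::isfinite(x) ? (x > lo ? x : lo) : lo;
}

// rgb holds H*W interleaved RGB triples; K is the downsampling factor.
float autoexposureSketch(const std::vector<float>& rgb, int H, int W, int K = 16)
{
  assert(H > 0 && W > 0 && K > 0);
  assert(rgb.size() == std::size_t(H) * std::size_t(W) * 3);
  const float key = 0.18f;
  const float eps = 1e-8f;
  const int HK = (H + K / 2) / K; // downsampled height
  const int WK = (W + K / 2) / K; // downsampled width

  float logSum = 0.f;
  int count = 0;
  for (int i = 0; i < HK; ++i)
  {
    for (int j = 0; j < WK; ++j)
    {
      // Pixel range covered by block (i, j)
      const int beginH = int(std::ptrdiff_t(i) * H / HK);
      const int beginW = int(std::ptrdiff_t(j) * W / WK);
      const int endH = int(std::ptrdiff_t(i + 1) * H / HK);
      const int endW = int(std::ptrdiff_t(j + 1) * W / WK);

      // Average luminance of the block
      float L = 0.f;
      for (int h = beginH; h < endH; ++h)
        for (int w = beginW; w < endW; ++w)
        {
          const float* p = &rgb[(std::size_t(h) * W + w) * 3];
          L += luminanceSketch(maxSafeSketch(p[0], 0.f),
                               maxSafeSketch(p[1], 0.f),
                               maxSafeSketch(p[2], 0.f));
        }
      L /= float((endH - beginH) * (endW - beginW));

      // Accumulate log luminance, skipping (near-)black blocks
      if (L > eps)
      {
        logSum += std::log2(L);
        ++count;
      }
    }
  }
  // Exposure that maps the geometric-mean luminance to the key value
  return (count > 0) ? (key / std::exp2(logSum / float(count))) : 1.f;
}
```

For a uniform grey image with luminance 0.5, the sketch yields an exposure of 0.18 / 0.5; an all-black image falls back to 1.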

View file

@ -1,201 +0,0 @@
// ======================================================================== //
// Copyright 2009-2019 Intel Corporation //
// //
// Licensed under the Apache License, Version 2.0 (the "License"); //
// you may not use this file except in compliance with the License. //
// You may obtain a copy of the License at //
// //
// http://www.apache.org/licenses/LICENSE-2.0 //
// //
// Unless required by applicable law or agreed to in writing, software //
// distributed under the License is distributed on an "AS IS" BASIS, //
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. //
// See the License for the specific language governing permissions and //
// limitations under the License. //
// ======================================================================== //
#pragma once
#include "image.h"
#include "node.h"
namespace oidn {
__forceinline float luminance(float r, float g, float b)
{
return 0.212671f * r + 0.715160f * g + 0.072169f * b;
}
// Color transfer function base class
class TransferFunction
{
public:
virtual ~TransferFunction() = default;
virtual float forward(float y) const = 0;
virtual float inverse(float x) const = 0;
};
// HDR transfer function base class
class HDRTransferFunction : public TransferFunction
{
protected:
static constexpr float yMax = 65504.f;
float exposure;
float rcpExposure;
public:
HDRTransferFunction(float exposure = 1.f)
{
setExposure(exposure);
}
void setExposure(float exposure)
{
this->exposure = exposure;
this->rcpExposure = (exposure != 0.f) ? (1.f / exposure) : 0.f;
}
};
// Linear transfer function (LDR)
class LinearTransferFunction : public TransferFunction
{
public:
__forceinline float forward(float y) const override
{
return min(y, 1.f);
}
__forceinline float inverse(float x) const override
{
return min(x, 1.f);
}
};
// 2.2 gamma transfer function (LDR)
class GammaTransferFunction : public TransferFunction
{
public:
__forceinline float forward(float y) const override
{
return min(pow(y, 1.f/2.2f), 1.f);
}
__forceinline float inverse(float x) const override
{
return min(pow(x, 2.2f), 1.f);
}
};
// Logarithmic transfer function (HDR)
// Compresses [0..65504] to [0..1]
class LogTransferFunction : public HDRTransferFunction
{
private:
static const float xScale;
public:
LogTransferFunction(float exposure = 1.f)
: HDRTransferFunction(exposure)
{
}
__forceinline float forward(float y) const override
{
return log(y * exposure + 1.f) * xScale;
}
__forceinline float inverse(float x) const override
{
return (exp(x * (1.f/xScale)) - 1.f) * rcpExposure;
}
};
// PQX transfer function (HDR)
// Compresses [0..65504] to [0..1]
class PQXTransferFunction : public HDRTransferFunction
{
private:
static constexpr float m1 = 2610.f / 4096.f / 4.f;
static constexpr float m2 = 2523.f / 4096.f * 128.f;
static constexpr float c1 = 3424.f / 4096.f;
static constexpr float c2 = 2413.f / 4096.f * 32.f;
static constexpr float c3 = 2392.f / 4096.f * 32.f;
static constexpr float a = 3711.f / 4096.f / 8.f;
static constexpr float yScale = 100.f / 10000.f;
static const float xScale;
public:
PQXTransferFunction(float exposure = 1.f)
: HDRTransferFunction(exposure)
{
}
__forceinline float forward(float y) const override
{
return pqxForward(y * exposure * yScale) * xScale;
}
__forceinline float inverse(float x) const override
{
return pqxInverse(x * (1.f/xScale)) * (1.f/yScale) * rcpExposure;
}
private:
static __forceinline float pqForward(float y)
{
const float yp = pow(y, m1);
return pow((c1 + c2 * yp) * rcp(1.f + c3 * yp), m2);
}
static __forceinline float pqxForward(float y)
{
if (y <= 1.f)
return pqForward(y);
else
return a * log(y) + 1.f;
}
static __forceinline float pqInverse(float x)
{
const float xp = pow(x, 1.f/m2);
return pow(max((xp - c1) * rcp(c2 - c3 * xp), 0.f), 1.f/m1);
}
static __forceinline float pqxInverse(float x)
{
if (x <= 1.f)
return pqInverse(x);
else
return exp((x - 1.f) * (1.f/a));
}
};
// Autoexposure node
class AutoexposureNode : public Node
{
private:
Image color;
std::shared_ptr<HDRTransferFunction> transferFunc;
public:
AutoexposureNode(const Image& color,
const std::shared_ptr<HDRTransferFunction>& transferFunc)
: color(color),
transferFunc(transferFunc)
{}
void execute(stream& sm) override
{
const float exposure = autoexposure(color);
//printf("exposure = %f\n", exposure);
transferFunc->setExposure(exposure);
}
private:
static float autoexposure(const Image& color);
};
} // namespace oidn
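
To illustrate how the logarithmic transfer function above compresses [0..65504] to [0..1] and inverts cleanly, here is a self-contained sketch using the same formulas. `LogTransferSketch` and `kYMaxSketch` are hypothetical names; unlike the original, `xScale` is stored per instance rather than as a static member, purely to keep the example in one block.

```cpp
#include <cassert>
#include <cmath>

constexpr float kYMaxSketch = 65504.f; // largest finite half-float value, as in the source

// Hypothetical standalone version of the log transfer function.
struct LogTransferSketch
{
  float exposure;
  float rcpExposure;
  float xScale;

  explicit LogTransferSketch(float e = 1.f)
    : exposure(e),
      rcpExposure(e != 0.f ? 1.f / e : 0.f),
      xScale(1.f / std::log(kYMaxSketch + 1.f))
  {}

  // HDR value -> [0..1] (for exposure == 1 and y in [0..65504])
  float forward(float y) const
  {
    return std::log(y * exposure + 1.f) * xScale;
  }

  // [0..1] -> HDR value, undoing both the log curve and the exposure
  float inverse(float x) const
  {
    return (std::exp(x * (1.f / xScale)) - 1.f) * rcpExposure;
  }
};
```

Because `inverse` multiplies by `rcpExposure`, the exposure chosen by the autoexposure step is removed again when the denoised output is written back.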

View file

@ -1,92 +0,0 @@
// ======================================================================== //
// Copyright 2009-2019 Intel Corporation //
// //
// Licensed under the Apache License, Version 2.0 (the "License"); //
// you may not use this file except in compliance with the License. //
// You may obtain a copy of the License at //
// //
// http://www.apache.org/licenses/LICENSE-2.0 //
// //
// Unless required by applicable law or agreed to in writing, software //
// distributed under the License is distributed on an "AS IS" BASIS, //
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. //
// See the License for the specific language governing permissions and //
// limitations under the License. //
// ======================================================================== //
#pragma once
#include "node.h"
namespace oidn {
// 2x2 nearest-neighbor upsampling node
template<int K>
class UpsampleNode : public Node
{
private:
std::shared_ptr<memory> src;
std::shared_ptr<memory> dst;
public:
UpsampleNode(const std::shared_ptr<memory>& src,
const std::shared_ptr<memory>& dst)
: src(src),
dst(dst)
{
const mkldnn_memory_desc_t& srcDesc = src->get_desc().data;
const mkldnn_memory_desc_t& dstDesc = dst->get_desc().data;
MAYBE_UNUSED(srcDesc);
MAYBE_UNUSED(dstDesc);
assert(memory_desc_matches_tag(srcDesc, mkldnn_format_tag_t(BlockedFormat<K>::nChwKc)));
assert(memory_desc_matches_tag(dstDesc, mkldnn_format_tag_t(BlockedFormat<K>::nChwKc)));
assert(srcDesc.ndims == 4);
assert(dstDesc.ndims == 4);
assert(srcDesc.data_type == memory::data_type::f32);
assert(dstDesc.data_type == memory::data_type::f32);
assert(srcDesc.dims[0] == 1);
assert(dstDesc.dims[0] == 1);
// 2x2 upsampling
assert(dstDesc.dims[2] == srcDesc.dims[2] * 2);
assert(dstDesc.dims[3] == srcDesc.dims[3] * 2);
}
void execute(stream& sm) override
{
const mkldnn_memory_desc_t& srcDesc = src->get_desc().data;
const float* srcPtr = (float*)src->get_data_handle();
float* dstPtr = (float*)dst->get_data_handle();
const int C = srcDesc.dims[1];
const int H = srcDesc.dims[2];
const int W = srcDesc.dims[3];
const int CK = C / K;
parallel_nd(CK, H, [&](int ck, int h)
{
const size_t offset = ck*H*W*K + h*W*K;
const float* srcPtr_line = srcPtr + offset;
float* dstPtr_line0 = dstPtr + offset * 4;
float* dstPtr_line1 = dstPtr_line0 + W*2*K; // next line
for (int w = 0; w < W; ++w)
{
#pragma unroll
for (int k = 0; k < K; k += 4)
{
const __m128 m = _mm_load_ps(&srcPtr_line[w*K + k]);
_mm_stream_ps(&dstPtr_line0[w*2*K + k], m);
_mm_stream_ps(&dstPtr_line0[w*2*K+K + k], m);
_mm_stream_ps(&dstPtr_line1[w*2*K + k], m);
_mm_stream_ps(&dstPtr_line1[w*2*K+K + k], m);
}
}
});
}
std::shared_ptr<memory> getDst() const override { return dst; }
};
} // namespace oidn
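
The indexing in the removed `UpsampleNode::execute` is easier to follow without the SSE intrinsics. Below is a scalar sketch of the same 2x2 nearest-neighbor copy over a blocked [C/K][H][W][K] layout. The function name `upsample2xSketch` and the use of `std::vector` are assumptions for illustration; the original operates on MKL-DNN memory with streaming stores.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// src layout: [CK][H][W][K]; result layout: [CK][2H][2W][K].
std::vector<float> upsample2xSketch(const std::vector<float>& src, int CK, int H, int W, int K)
{
  assert(src.size() == std::size_t(CK) * H * W * K);
  std::vector<float> dst(src.size() * 4);

  for (int ck = 0; ck < CK; ++ck)
  {
    for (int h = 0; h < H; ++h)
    {
      // Start of source row (ck, h); destination row (ck, 2h) starts at 4x that offset
      const std::size_t offset = (std::size_t(ck) * H + h) * W * K;
      const float* srcLine = src.data() + offset;
      float* dstLine0 = dst.data() + offset * 4;
      float* dstLine1 = dstLine0 + std::size_t(W) * 2 * K; // next destination row

      for (int w = 0; w < W; ++w)
      {
        for (int k = 0; k < K; ++k)
        {
          const float v = srcLine[w * K + k];
          // Each source pixel fills a 2x2 block of destination pixels
          dstLine0[w * 2 * K + k] = v;
          dstLine0[w * 2 * K + K + k] = v;
          dstLine1[w * 2 * K + k] = v;
          dstLine1[w * 2 * K + K + k] = v;
        }
      }
    }
  }
  return dst;
}
```

With `CK = 1`, `H = 1`, `W = 2`, `K = 1` and input `{1, 2}`, the result is two rows of `1 1 2 2`.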

View file

@ -1,99 +0,0 @@
// ======================================================================== //
// Copyright 2009-2019 Intel Corporation //
// //
// Licensed under the Apache License, Version 2.0 (the "License"); //
// you may not use this file except in compliance with the License. //
// You may obtain a copy of the License at //
// //
// http://www.apache.org/licenses/LICENSE-2.0 //
// //
// Unless required by applicable law or agreed to in writing, software //
// distributed under the License is distributed on an "AS IS" BASIS, //
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. //
// See the License for the specific language governing permissions and //
// limitations under the License. //
// ======================================================================== //
#pragma once
#include "node.h"
namespace oidn {
// Reorders weights from oihw to padded oihw format
template<int K>
class WeightsReorderNode : public Node
{
private:
std::shared_ptr<memory> src;
std::shared_ptr<memory> dst;
public:
WeightsReorderNode(const std::shared_ptr<memory>& src,
const std::shared_ptr<memory>& dst)
: src(src),
dst(dst)
{
const mkldnn_memory_desc_t& srcDesc = src->get_desc().data;
const mkldnn_memory_desc_t& dstDesc = dst->get_desc().data;
MAYBE_UNUSED(srcDesc);
MAYBE_UNUSED(dstDesc);
assert(memory_desc_matches_tag(srcDesc, mkldnn_format_tag_t(memory::format_tag::oihw)));
assert(memory_desc_matches_tag(dstDesc, mkldnn_format_tag_t(memory::format_tag::oihw)));
assert(srcDesc.ndims == 4);
assert(dstDesc.ndims == 4);
assert(srcDesc.data_type == memory::data_type::f32);
assert(dstDesc.data_type == memory::data_type::f32);
assert(getPadded<K>(srcDesc.dims[0]) == dstDesc.dims[0]); // OC
assert(getPadded<K>(srcDesc.dims[1]) == dstDesc.dims[1]); // IC
assert(srcDesc.dims[2] == dstDesc.dims[2]);
assert(srcDesc.dims[3] == dstDesc.dims[3]);
}
void execute(stream& sm) override
{
const mkldnn_memory_desc_t& srcDesc = src->get_desc().data;
const mkldnn_memory_desc_t& dstDesc = dst->get_desc().data;
const float* srcPtr = (float*)src->get_data_handle();
float* dstPtr = (float*)dst->get_data_handle();
const int OC1 = srcDesc.dims[0];
const int OC2 = dstDesc.dims[0];
const int IC1 = srcDesc.dims[1];
const int IC2 = dstDesc.dims[1];
const int H = dstDesc.dims[2];
const int W = dstDesc.dims[3];
for (int oc = 0; oc < OC2; ++oc)
{
for (int ic = 0; ic < IC2; ++ic)
{
for (int h = 0; h < H; ++h)
{
for (int w = 0; w < W; ++w)
{
// Output is in oihw format
float* dstPtr_c = dstPtr + oc*IC2*H*W + ic*H*W + h*W + w;
if (oc < OC1 && ic < IC1)
{
// Input is in oihw format
const float* srcPtr_c = srcPtr + oc*IC1*H*W + ic*H*W + h*W + w;
*dstPtr_c = *srcPtr_c;
}
else
{
// padding
*dstPtr_c = 0;
}
}
}
}
}
}
std::shared_ptr<memory> getDst() const override { return dst; }
};
} // namespace oidn
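
The weight reorder above copies an oihw tensor into a larger oihw tensor whose output and input channel counts are padded, filling the extra entries with zeros. A compact sketch of that step follows. `roundUpSketch` stands in for `getPadded<K>`, whose definition is not shown in this section, so rounding up to a multiple of K is an assumption inferred from the assertions; `padWeightsSketch` is likewise a hypothetical name.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Assumed meaning of getPadded<K>: round n up to the next multiple of K.
inline int roundUpSketch(int n, int K)
{
  return (n + K - 1) / K * K;
}

// src is [OC1][IC1][H][W]; the result is [OC2][IC2][H][W] with zero padding,
// where OC2 and IC2 are OC1 and IC1 rounded up to multiples of K.
std::vector<float> padWeightsSketch(const std::vector<float>& src,
                                    int OC1, int IC1, int H, int W, int K)
{
  assert(src.size() == std::size_t(OC1) * IC1 * H * W);
  const int OC2 = roundUpSketch(OC1, K);
  const int IC2 = roundUpSketch(IC1, K);

  // Value-initialized to zero, so only the real weights need copying
  std::vector<float> dst(std::size_t(OC2) * IC2 * H * W, 0.f);

  for (int oc = 0; oc < OC1; ++oc)
    for (int ic = 0; ic < IC1; ++ic)
      for (int h = 0; h < H; ++h)
        for (int w = 0; w < W; ++w)
        {
          const std::size_t s = ((std::size_t(oc) * IC1 + ic) * H + h) * W + w;
          const std::size_t d = ((std::size_t(oc) * IC2 + ic) * H + h) * W + w;
          dst[d] = src[s];
        }
  return dst;
}
```

Zero padding leaves the convolution result unchanged, since the added channels contribute nothing to any sum.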

View file

@ -1,214 +0,0 @@
// ======================================================================== //
// Copyright 2009-2019 Intel Corporation //
// //
// Licensed under the Apache License, Version 2.0 (the "License"); //
// you may not use this file except in compliance with the License. //
// You may obtain a copy of the License at //
// //
// http://www.apache.org/licenses/LICENSE-2.0 //
// //
// Unless required by applicable law or agreed to in writing, software //
// distributed under the License is distributed on an "AS IS" BASIS, //
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. //
// See the License for the specific language governing permissions and //
// limitations under the License. //
// ======================================================================== //
#pragma once
#include <stddef.h>
#include <stdbool.h>
#include <stdint.h>
#include "version.h"
#if defined(__cplusplus)
extern "C" {
#endif
#ifndef OIDN_API
#if defined(_WIN32) && !defined(OIDN_STATIC_LIB)
# define OIDN_API __declspec(dllimport)
#else
# define OIDN_API
#endif
#endif
// ----------------------------------------------------------------------------
// Device
// ----------------------------------------------------------------------------
// Device types
typedef enum
{
OIDN_DEVICE_TYPE_DEFAULT = 0, // select device automatically
OIDN_DEVICE_TYPE_CPU = 1, // CPU device
} OIDNDeviceType;
// Error codes
typedef enum
{
OIDN_ERROR_NONE = 0, // no error occurred
OIDN_ERROR_UNKNOWN = 1, // an unknown error occurred
OIDN_ERROR_INVALID_ARGUMENT = 2, // an invalid argument was specified
OIDN_ERROR_INVALID_OPERATION = 3, // the operation is not allowed
OIDN_ERROR_OUT_OF_MEMORY = 4, // not enough memory to execute the operation
OIDN_ERROR_UNSUPPORTED_HARDWARE = 5, // the hardware (e.g. CPU) is not supported
OIDN_ERROR_CANCELLED = 6, // the operation was cancelled by the user
} OIDNError;
// Error callback function
typedef void (*OIDNErrorFunction)(void* userPtr, OIDNError code, const char* message);
// Device handle
typedef struct OIDNDeviceImpl* OIDNDevice;
// Creates a new device.
OIDN_API OIDNDevice oidnNewDevice(OIDNDeviceType type);
// Retains the device (increments the reference count).
OIDN_API void oidnRetainDevice(OIDNDevice device);
// Releases the device (decrements the reference count).
OIDN_API void oidnReleaseDevice(OIDNDevice device);
// Sets a boolean parameter of the device.
OIDN_API void oidnSetDevice1b(OIDNDevice device, const char* name, bool value);
// Sets an integer parameter of the device.
OIDN_API void oidnSetDevice1i(OIDNDevice device, const char* name, int value);
// Gets a boolean parameter of the device.
OIDN_API bool oidnGetDevice1b(OIDNDevice device, const char* name);
// Gets an integer parameter of the device (e.g. "version").
OIDN_API int oidnGetDevice1i(OIDNDevice device, const char* name);
// Sets the error callback function of the device.
OIDN_API void oidnSetDeviceErrorFunction(OIDNDevice device, OIDNErrorFunction func, void* userPtr);
// Returns the first unqueried error code stored in the device for the current
// thread, optionally also returning a string message (if not NULL), and clears
// the stored error. Can be called with a NULL device as well to check why a
// device creation failed.
OIDN_API OIDNError oidnGetDeviceError(OIDNDevice device, const char** outMessage);
// Commits all previous changes to the device.
// Must be called before first using the device (e.g. creating filters).
OIDN_API void oidnCommitDevice(OIDNDevice device);
// ----------------------------------------------------------------------------
// Buffer
// ----------------------------------------------------------------------------
// Formats for images and other data stored in buffers
typedef enum
{
OIDN_FORMAT_UNDEFINED = 0,
// 32-bit single-precision floating point scalar and vector formats
OIDN_FORMAT_FLOAT = 1,
OIDN_FORMAT_FLOAT2 = 2,
OIDN_FORMAT_FLOAT3 = 3,
OIDN_FORMAT_FLOAT4 = 4,
} OIDNFormat;
// Access modes for mapping buffers
typedef enum
{
OIDN_ACCESS_READ = 0, // read-only access
OIDN_ACCESS_WRITE = 1, // write-only access
OIDN_ACCESS_READ_WRITE = 2, // read and write access
OIDN_ACCESS_WRITE_DISCARD = 3, // write-only access, previous contents discarded
} OIDNAccess;
// Buffer handle
typedef struct OIDNBufferImpl* OIDNBuffer;
// Creates a new buffer (data allocated and owned by the device).
OIDN_API OIDNBuffer oidnNewBuffer(OIDNDevice device, size_t byteSize);
// Creates a new shared buffer (data allocated and owned by the user).
OIDN_API OIDNBuffer oidnNewSharedBuffer(OIDNDevice device, void* ptr, size_t byteSize);
// Maps a region of the buffer to host memory.
// If byteSize is 0, the maximum available amount of memory will be mapped.
OIDN_API void* oidnMapBuffer(OIDNBuffer buffer, OIDNAccess access, size_t byteOffset, size_t byteSize);
// Unmaps a region of the buffer.
// mappedPtr must be a pointer returned by a previous call to oidnMapBuffer.
OIDN_API void oidnUnmapBuffer(OIDNBuffer buffer, void* mappedPtr);
// Retains the buffer (increments the reference count).
OIDN_API void oidnRetainBuffer(OIDNBuffer buffer);
// Releases the buffer (decrements the reference count).
OIDN_API void oidnReleaseBuffer(OIDNBuffer buffer);
// ----------------------------------------------------------------------------
// Filter
// ----------------------------------------------------------------------------
// Progress monitor callback function
typedef bool (*OIDNProgressMonitorFunction)(void* userPtr, double n);
// Filter handle
typedef struct OIDNFilterImpl* OIDNFilter;
// Creates a new filter of the specified type (e.g. "RT").
OIDN_API OIDNFilter oidnNewFilter(OIDNDevice device, const char* type);
// Retains the filter (increments the reference count).
OIDN_API void oidnRetainFilter(OIDNFilter filter);
// Releases the filter (decrements the reference count).
OIDN_API void oidnReleaseFilter(OIDNFilter filter);
// Sets an image parameter of the filter (stored in a buffer).
// If bytePixelStride and/or byteRowStride are zero, these will be computed automatically.
OIDN_API void oidnSetFilterImage(OIDNFilter filter, const char* name,
OIDNBuffer buffer, OIDNFormat format,
size_t width, size_t height,
size_t byteOffset,
size_t bytePixelStride, size_t byteRowStride);
// Sets an image parameter of the filter (owned by the user).
// If bytePixelStride and/or byteRowStride are zero, these will be computed automatically.
OIDN_API void oidnSetSharedFilterImage(OIDNFilter filter, const char* name,
void* ptr, OIDNFormat format,
size_t width, size_t height,
size_t byteOffset,
size_t bytePixelStride, size_t byteRowStride);
// Sets a boolean parameter of the filter.
OIDN_API void oidnSetFilter1b(OIDNFilter filter, const char* name, bool value);
// Gets a boolean parameter of the filter.
OIDN_API bool oidnGetFilter1b(OIDNFilter filter, const char* name);
// Sets an integer parameter of the filter.
OIDN_API void oidnSetFilter1i(OIDNFilter filter, const char* name, int value);
// Gets an integer parameter of the filter.
OIDN_API int oidnGetFilter1i(OIDNFilter filter, const char* name);
// Sets a float parameter of the filter.
OIDN_API void oidnSetFilter1f(OIDNFilter filter, const char* name, float value);
// Gets a float parameter of the filter.
OIDN_API float oidnGetFilter1f(OIDNFilter filter, const char* name);
// Sets the progress monitor callback function of the filter.
OIDN_API void oidnSetFilterProgressMonitorFunction(OIDNFilter filter, OIDNProgressMonitorFunction func, void* userPtr);
// Commits all previous changes to the filter.
// Must be called before first executing the filter.
OIDN_API void oidnCommitFilter(OIDNFilter filter);
// Executes the filter.
OIDN_API void oidnExecuteFilter(OIDNFilter filter);
#if defined(__cplusplus)
}
#endif

View file

@ -1,468 +0,0 @@
// ======================================================================== //
// Copyright 2009-2019 Intel Corporation //
// //
// Licensed under the Apache License, Version 2.0 (the "License"); //
// you may not use this file except in compliance with the License. //
// You may obtain a copy of the License at //
// //
// http://www.apache.org/licenses/LICENSE-2.0 //
// //
// Unless required by applicable law or agreed to in writing, software //
// distributed under the License is distributed on an "AS IS" BASIS, //
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. //
// See the License for the specific language governing permissions and //
// limitations under the License. //
// ======================================================================== //
#pragma once
#include <algorithm>
#include "oidn.h"
namespace oidn {
// --------------------------------------------------------------------------
// Buffer
// --------------------------------------------------------------------------
// Formats for images and other data stored in buffers
enum class Format
{
Undefined = OIDN_FORMAT_UNDEFINED,
// 32-bit single-precision floating point scalar and vector formats
Float = OIDN_FORMAT_FLOAT,
Float2 = OIDN_FORMAT_FLOAT2,
Float3 = OIDN_FORMAT_FLOAT3,
Float4 = OIDN_FORMAT_FLOAT4,
};
// Access modes for mapping buffers
enum class Access
{
Read = OIDN_ACCESS_READ, // read-only access
Write = OIDN_ACCESS_WRITE, // write-only access
ReadWrite = OIDN_ACCESS_READ_WRITE, // read and write access
WriteDiscard = OIDN_ACCESS_WRITE_DISCARD, // write-only access, previous contents discarded
};
// Buffer object with automatic reference counting
class BufferRef
{
private:
OIDNBuffer handle;
public:
BufferRef() : handle(nullptr) {}
BufferRef(OIDNBuffer handle) : handle(handle) {}
BufferRef(const BufferRef& other) : handle(other.handle)
{
if (handle)
oidnRetainBuffer(handle);
}
BufferRef(BufferRef&& other) : handle(other.handle)
{
other.handle = nullptr;
}
BufferRef& operator =(const BufferRef& other)
{
if (&other != this)
{
if (other.handle)
oidnRetainBuffer(other.handle);
if (handle)
oidnReleaseBuffer(handle);
handle = other.handle;
}
return *this;
}
BufferRef& operator =(BufferRef&& other)
{
std::swap(handle, other.handle);
return *this;
}
BufferRef& operator =(OIDNBuffer other)
{
if (other)
oidnRetainBuffer(other);
if (handle)
oidnReleaseBuffer(handle);
handle = other;
return *this;
}
~BufferRef()
{
if (handle)
oidnReleaseBuffer(handle);
}
OIDNBuffer getHandle() const
{
return handle;
}
operator bool() const
{
return handle != nullptr;
}
// Maps a region of the buffer to host memory.
// If byteSize is 0, the maximum available amount of memory will be mapped.
void* map(Access access = Access::ReadWrite, size_t byteOffset = 0, size_t byteSize = 0)
{
return oidnMapBuffer(handle, (OIDNAccess)access, byteOffset, byteSize);
}
// Unmaps a region of the buffer.
// mappedPtr must be a pointer returned by a previous call to map.
void unmap(void* mappedPtr)
{
oidnUnmapBuffer(handle, mappedPtr);
}
};
// --------------------------------------------------------------------------
// Filter
// --------------------------------------------------------------------------
// Progress monitor callback function
typedef bool (*ProgressMonitorFunction)(void* userPtr, double n);
// Filter object with automatic reference counting
class FilterRef
{
private:
OIDNFilter handle;
public:
FilterRef() : handle(nullptr) {}
FilterRef(OIDNFilter handle) : handle(handle) {}
FilterRef(const FilterRef& other) : handle(other.handle)
{
if (handle)
oidnRetainFilter(handle);
}
FilterRef(FilterRef&& other) : handle(other.handle)
{
other.handle = nullptr;
}
FilterRef& operator =(const FilterRef& other)
{
if (&other != this)
{
if (other.handle)
oidnRetainFilter(other.handle);
if (handle)
oidnReleaseFilter(handle);
handle = other.handle;
}
return *this;
}
FilterRef& operator =(FilterRef&& other)
{
std::swap(handle, other.handle);
return *this;
}
FilterRef& operator =(OIDNFilter other)
{
if (other)
oidnRetainFilter(other);
if (handle)
oidnReleaseFilter(handle);
handle = other;
return *this;
}
~FilterRef()
{
if (handle)
oidnReleaseFilter(handle);
}
OIDNFilter getHandle() const
{
return handle;
}
operator bool() const
{
return handle != nullptr;
}
// Sets an image parameter of the filter (stored in a buffer).
void setImage(const char* name,
const BufferRef& buffer, Format format,
size_t width, size_t height,
size_t byteOffset = 0,
size_t bytePixelStride = 0, size_t byteRowStride = 0)
{
oidnSetFilterImage(handle, name,
buffer.getHandle(), (OIDNFormat)format,
width, height,
byteOffset,
bytePixelStride, byteRowStride);
}
// Sets an image parameter of the filter (owned by the user).
void setImage(const char* name,
void* ptr, Format format,
size_t width, size_t height,
size_t byteOffset = 0,
size_t bytePixelStride = 0, size_t byteRowStride = 0)
{
oidnSetSharedFilterImage(handle, name,
ptr, (OIDNFormat)format,
width, height,
byteOffset,
bytePixelStride, byteRowStride);
}
// Sets a boolean parameter of the filter.
void set(const char* name, bool value)
{
oidnSetFilter1b(handle, name, value);
}
// Sets an integer parameter of the filter.
void set(const char* name, int value)
{
oidnSetFilter1i(handle, name, value);
}
// Sets a float parameter of the filter.
void set(const char* name, float value)
{
oidnSetFilter1f(handle, name, value);
}
// Gets a parameter of the filter.
template<typename T>
T get(const char* name);
// Sets the progress monitor callback function of the filter.
void setProgressMonitorFunction(ProgressMonitorFunction func, void* userPtr = nullptr)
{
oidnSetFilterProgressMonitorFunction(handle, (OIDNProgressMonitorFunction)func, userPtr);
}
// Commits all previous changes to the filter.
void commit()
{
oidnCommitFilter(handle);
}
// Executes the filter.
void execute()
{
oidnExecuteFilter(handle);
}
};
// Gets a boolean parameter of the filter.
template<>
inline bool FilterRef::get(const char* name)
{
return oidnGetFilter1b(handle, name);
}
// Gets an integer parameter of the filter.
template<>
inline int FilterRef::get(const char* name)
{
return oidnGetFilter1i(handle, name);
}
// Gets a float parameter of the filter.
template<>
inline float FilterRef::get(const char* name)
{
return oidnGetFilter1f(handle, name);
}
// --------------------------------------------------------------------------
// Device
// --------------------------------------------------------------------------
// Device types
enum class DeviceType
{
Default = OIDN_DEVICE_TYPE_DEFAULT, // select device automatically
CPU = OIDN_DEVICE_TYPE_CPU, // CPU device
};
// Error codes
enum class Error
{
None = OIDN_ERROR_NONE, // no error occurred
Unknown = OIDN_ERROR_UNKNOWN, // an unknown error occurred
InvalidArgument = OIDN_ERROR_INVALID_ARGUMENT, // an invalid argument was specified
InvalidOperation = OIDN_ERROR_INVALID_OPERATION, // the operation is not allowed
OutOfMemory = OIDN_ERROR_OUT_OF_MEMORY, // not enough memory to execute the operation
UnsupportedHardware = OIDN_ERROR_UNSUPPORTED_HARDWARE, // the hardware (e.g. CPU) is not supported
Cancelled = OIDN_ERROR_CANCELLED, // the operation was cancelled by the user
};
// Error callback function
typedef void (*ErrorFunction)(void* userPtr, Error code, const char* message);
// Device object with automatic reference counting
class DeviceRef
{
private:
OIDNDevice handle;
public:
DeviceRef() : handle(nullptr) {}
DeviceRef(OIDNDevice handle) : handle(handle) {}
DeviceRef(const DeviceRef& other) : handle(other.handle)
{
if (handle)
oidnRetainDevice(handle);
}
DeviceRef(DeviceRef&& other) : handle(other.handle)
{
other.handle = nullptr;
}
DeviceRef& operator =(const DeviceRef& other)
{
if (&other != this)
{
if (other.handle)
oidnRetainDevice(other.handle);
if (handle)
oidnReleaseDevice(handle);
handle = other.handle;
}
return *this;
}
DeviceRef& operator =(DeviceRef&& other)
{
std::swap(handle, other.handle);
return *this;
}
DeviceRef& operator =(OIDNDevice other)
{
if (other)
oidnRetainDevice(other);
if (handle)
oidnReleaseDevice(handle);
handle = other;
return *this;
}
~DeviceRef()
{
if (handle)
oidnReleaseDevice(handle);
}
OIDNDevice getHandle() const
{
return handle;
}
operator bool() const
{
return handle != nullptr;
}
// Sets a boolean parameter of the device.
void set(const char* name, bool value)
{
oidnSetDevice1b(handle, name, value);
}
// Sets an integer parameter of the device.
void set(const char* name, int value)
{
oidnSetDevice1i(handle, name, value);
}
// Gets a parameter of the device.
template<typename T>
T get(const char* name);
// Sets the error callback function of the device.
void setErrorFunction(ErrorFunction func, void* userPtr = nullptr)
{
oidnSetDeviceErrorFunction(handle, (OIDNErrorFunction)func, userPtr);
}
// Returns the first unqueried error code and clears the stored error.
// Can be called for a null device as well to check why a device creation failed.
Error getError()
{
return (Error)oidnGetDeviceError(handle, nullptr);
}
// Returns the first unqueried error code and string message, and clears the stored error.
// Can be called for a null device as well to check why a device creation failed.
Error getError(const char*& outMessage)
{
return (Error)oidnGetDeviceError(handle, &outMessage);
}
// Commits all previous changes to the device.
// Must be called before first using the device (e.g. creating filters).
void commit()
{
oidnCommitDevice(handle);
}
// Creates a new buffer (data allocated and owned by the device).
BufferRef newBuffer(size_t byteSize)
{
return oidnNewBuffer(handle, byteSize);
}
// Creates a new shared buffer (data allocated and owned by the user).
BufferRef newBuffer(void* ptr, size_t byteSize)
{
return oidnNewSharedBuffer(handle, ptr, byteSize);
}
// Creates a new filter of the specified type (e.g. "RT").
FilterRef newFilter(const char* type)
{
return oidnNewFilter(handle, type);
}
};
// Gets a boolean parameter of the device.
template<>
inline bool DeviceRef::get(const char* name)
{
return oidnGetDevice1b(handle, name);
}
// Gets an integer parameter of the device (e.g. "version").
template<>
inline int DeviceRef::get(const char* name)
{
return oidnGetDevice1i(handle, name);
}
// Creates a new device.
inline DeviceRef newDevice(DeviceType type = DeviceType::Default)
{
return DeviceRef(oidnNewDevice((OIDNDeviceType)type));
}
} // namespace oidn
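
The BufferRef, FilterRef and DeviceRef wrappers in the removed header all follow one pattern: copying retains the handle, moving transfers it, and destruction releases it. The sketch below reproduces that pattern against a mock object so the reference counts can be observed without the OIDN library. MockObject, mockRetain, mockRelease and MockRef are hypothetical names introduced for illustration only. Unlike the real API, the mock object is owned by the caller and is never freed by a release.

```cpp
#include <cassert>
#include <utility>

// Hypothetical stand-in for an opaque, reference-counted C object.
// It is owned by the caller so the count stays inspectable; it is not an OIDN type.
struct MockObject
{
  int refCount = 1; // the creator holds the first reference
};

inline void mockRetain(MockObject* h) { ++h->refCount; }

inline void mockRelease(MockObject* h)
{
  assert(h->refCount > 0); // releasing more often than retaining is a bug
  --h->refCount;
}

// Same shape as the removed *Ref classes: one raw handle plus retain/release calls.
class MockRef
{
private:
  MockObject* handle;

public:
  MockRef() : handle(nullptr) {}
  MockRef(MockObject* handle) : handle(handle) {} // adopts the caller's reference, no retain
  MockRef(const MockRef& other) : handle(other.handle)
  {
    if (handle) mockRetain(handle); // a copy shares the object and adds a reference
  }
  MockRef(MockRef&& other) noexcept : handle(other.handle)
  {
    other.handle = nullptr; // a move transfers the reference, the count is unchanged
  }
  MockRef& operator =(const MockRef& other)
  {
    if (&other != this)
    {
      if (other.handle) mockRetain(other.handle); // retain first, then drop the old handle
      if (handle) mockRelease(handle);
      handle = other.handle;
    }
    return *this;
  }
  MockRef& operator =(MockRef&& other) noexcept
  {
    std::swap(handle, other.handle); // the old handle is released when 'other' is destroyed
    return *this;
  }
  ~MockRef()
  {
    if (handle) mockRelease(handle);
  }
  explicit operator bool() const { return handle != nullptr; }
};
```

The retain-before-release order in the copy assignment mirrors the removed code. It keeps the object alive even when both wrappers already refer to the same handle.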

@ -1,23 +0,0 @@
// ======================================================================== //
// Copyright 2009-2019 Intel Corporation //
// //
// Licensed under the Apache License, Version 2.0 (the "License"); //
// you may not use this file except in compliance with the License. //
// You may obtain a copy of the License at //
// //
// http://www.apache.org/licenses/LICENSE-2.0 //
// //
// Unless required by applicable law or agreed to in writing, software //
// distributed under the License is distributed on an "AS IS" BASIS, //
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. //
// See the License for the specific language governing permissions and //
// limitations under the License. //
// ======================================================================== //
#pragma once
#define OIDN_VERSION_MAJOR 1
#define OIDN_VERSION_MINOR 1
#define OIDN_VERSION_PATCH 0
#define OIDN_VERSION 10100
#define OIDN_VERSION_STRING "1.1.0"
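
The macro values above are consistent with a combined version number of the form major * 10000 + minor * 100 + patch. The sketch below is a minimal check of that reading. The encoding is an assumption inferred from these five values, and encodeVersion is a hypothetical helper, not part of the OIDN headers.

```cpp
#include <cassert>

// Assumed encoding, inferred from OIDN_VERSION_MAJOR/MINOR/PATCH and OIDN_VERSION above.
constexpr int encodeVersion(int major, int minor, int patch)
{
  return major * 10000 + minor * 100 + patch;
}

// Compile-time check against the values in the removed header (1.1.0 and 10100).
static_assert(encodeVersion(1, 1, 0) == 10100, "encoding does not match the header values");
```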

@ -1,214 +0,0 @@
Apache License
Version 2.0, January 2004
http://www.apache.org/licenses/
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
1. Definitions.
"License" shall mean the terms and conditions for use, reproduction,
and distribution as defined by Sections 1 through 9 of this document.
"Licensor" shall mean the copyright owner or entity authorized by
the copyright owner that is granting the License.
"Legal Entity" shall mean the union of the acting entity and all
other entities that control, are controlled by, or are under common
control with that entity. For the purposes of this definition,
"control" means (i) the power, direct or indirect, to cause the
direction or management of such entity, whether by contract or
otherwise, or (ii) ownership of fifty percent (50%) or more of the
outstanding shares, or (iii) beneficial ownership of such entity.
"You" (or "Your") shall mean an individual or Legal Entity
exercising permissions granted by this License.
"Source" form shall mean the preferred form for making modifications,
including but not limited to software source code, documentation
source, and configuration files.
"Object" form shall mean any form resulting from mechanical
transformation or translation of a Source form, including but
not limited to compiled object code, generated documentation,
and conversions to other media types.
"Work" shall mean the work of authorship, whether in Source or
Object form, made available under the License, as indicated by a
copyright notice that is included in or attached to the work
(an example is provided in the Appendix below).
"Derivative Works" shall mean any work, whether in Source or Object
form, that is based on (or derived from) the Work and for which the
editorial revisions, annotations, elaborations, or other modifications
represent, as a whole, an original work of authorship. For the purposes
of this License, Derivative Works shall not include works that remain
separable from, or merely link (or bind by name) to the interfaces of,
the Work and Derivative Works thereof.
"Contribution" shall mean any work of authorship, including
the original version of the Work and any modifications or additions
to that Work or Derivative Works thereof, that is intentionally
submitted to Licensor for inclusion in the Work by the copyright owner
or by an individual or Legal Entity authorized to submit on behalf of
the copyright owner. For the purposes of this definition, "submitted"
means any form of electronic, verbal, or written communication sent
to the Licensor or its representatives, including but not limited to
communication on electronic mailing lists, source code control systems,
and issue tracking systems that are managed by, or on behalf of, the
Licensor for the purpose of discussing and improving the Work, but
excluding communication that is conspicuously marked or otherwise
designated in writing by the copyright owner as "Not a Contribution."
"Contributor" shall mean Licensor and any individual or Legal Entity
on behalf of whom a Contribution has been received by Licensor and
subsequently incorporated within the Work.
2. Grant of Copyright License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
copyright license to reproduce, prepare Derivative Works of,
publicly display, publicly perform, sublicense, and distribute the
Work and such Derivative Works in Source or Object form.
3. Grant of Patent License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
(except as stated in this section) patent license to make, have made,
use, offer to sell, sell, import, and otherwise transfer the Work,
where such license applies only to those patent claims licensable
by such Contributor that are necessarily infringed by their
Contribution(s) alone or by combination of their Contribution(s)
with the Work to which such Contribution(s) was submitted. If You
institute patent litigation against any entity (including a
cross-claim or counterclaim in a lawsuit) alleging that the Work
or a Contribution incorporated within the Work constitutes direct
or contributory patent infringement, then any patent licenses
granted to You under this License for that Work shall terminate
as of the date such litigation is filed.
4. Redistribution. You may reproduce and distribute copies of the
Work or Derivative Works thereof in any medium, with or without
modifications, and in Source or Object form, provided that You
meet the following conditions:
(a) You must give any other recipients of the Work or
Derivative Works a copy of this License; and
(b) You must cause any modified files to carry prominent notices
stating that You changed the files; and
(c) You must retain, in the Source form of any Derivative Works
that You distribute, all copyright, patent, trademark, and
attribution notices from the Source form of the Work,
excluding those notices that do not pertain to any part of
the Derivative Works; and
(d) If the Work includes a "NOTICE" text file as part of its
distribution, then any Derivative Works that You distribute must
include a readable copy of the attribution notices contained
within such NOTICE file, excluding those notices that do not
pertain to any part of the Derivative Works, in at least one
of the following places: within a NOTICE text file distributed
as part of the Derivative Works; within the Source form or
documentation, if provided along with the Derivative Works; or,
within a display generated by the Derivative Works, if and
wherever such third-party notices normally appear. The contents
of the NOTICE file are for informational purposes only and
do not modify the License. You may add Your own attribution
notices within Derivative Works that You distribute, alongside
or as an addendum to the NOTICE text from the Work, provided
that such additional attribution notices cannot be construed
as modifying the License.
You may add Your own copyright statement to Your modifications and
may provide additional or different license terms and conditions
for use, reproduction, or distribution of Your modifications, or
for any such Derivative Works as a whole, provided Your use,
reproduction, and distribution of the Work otherwise complies with
the conditions stated in this License.
5. Submission of Contributions. Unless You explicitly state otherwise,
any Contribution intentionally submitted for inclusion in the Work
by You to the Licensor shall be under the terms and conditions of
this License, without any additional terms or conditions.
Notwithstanding the above, nothing herein shall supersede or modify
the terms of any separate license agreement you may have executed
with Licensor regarding such Contributions.
6. Trademarks. This License does not grant permission to use the trade
names, trademarks, service marks, or product names of the Licensor,
except as required for reasonable and customary use in describing the
origin of the Work and reproducing the content of the NOTICE file.
7. Disclaimer of Warranty. Unless required by applicable law or
agreed to in writing, Licensor provides the Work (and each
Contributor provides its Contributions) on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
implied, including, without limitation, any warranties or conditions
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
PARTICULAR PURPOSE. You are solely responsible for determining the
appropriateness of using or redistributing the Work and assume any
risks associated with Your exercise of permissions under this License.
8. Limitation of Liability. In no event and under no legal theory,
whether in tort (including negligence), contract, or otherwise,
unless required by applicable law (such as deliberate and grossly
negligent acts) or agreed to in writing, shall any Contributor be
liable to You for damages, including any direct, indirect, special,
incidental, or consequential damages of any character arising as a
result of this License or out of the use or inability to use the
Work (including but not limited to damages for loss of goodwill,
work stoppage, computer failure or malfunction, or any and all
other commercial damages or losses), even if such Contributor
has been advised of the possibility of such damages.
9. Accepting Warranty or Additional Liability. While redistributing
the Work or Derivative Works thereof, You may choose to offer,
and charge a fee for, acceptance of support, warranty, indemnity,
or other liability obligations and/or rights consistent with this
License. However, in accepting such obligations, You may act only
on Your own behalf and on Your sole responsibility, not on behalf
of any other Contributor, and only if You agree to indemnify,
defend, and hold each Contributor harmless for any liability
incurred by, or claims asserted against, such Contributor by reason
of your accepting any such warranty or additional liability.
END OF TERMS AND CONDITIONS
APPENDIX: How to apply the Apache License to your work.
To apply the Apache License to your work, attach the following
boilerplate notice, with the fields enclosed by brackets "{}"
replaced with your own identifying information. (Don't include
the brackets!) The text should be enclosed in the appropriate
comment syntax for the file format. We also recommend that a
file or class name and description of purpose be included on the
same "printed page" as the copyright notice for easier
identification within third-party archives.
Copyright {yyyy} {name of copyright owner}
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
============================================================================
Intel MKL-DNN includes components with separate copyright
notices and license terms.
XByak, 3-clause BSD license
Copyright (c) 2007 MITSUNARI Shigeo
See full copyright notice and license text in src/cpu/xbyak/COPYRIGHT
gtest, 3-clause BSD license
Copyright 2008, Google Inc.
See full copyright notice and license text in tests/gtests/gtest/LICENSE

File diff suppressed because it is too large

File diff suppressed because it is too large

@ -1,98 +0,0 @@
/*******************************************************************************
* Copyright 2018-2019 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/
/* DO NOT EDIT, AUTO-GENERATED */
#ifndef MKLDNN_DEBUG_H
#define MKLDNN_DEBUG_H
#ifndef DOXYGEN_SHOULD_SKIP_THIS
/* All symbols shall be internal unless marked as MKLDNN_API */
#if defined _WIN32 || defined __CYGWIN__
# define MKLDNN_HELPER_DLL_IMPORT __declspec(dllimport)
# define MKLDNN_HELPER_DLL_EXPORT __declspec(dllexport)
#else
# if __GNUC__ >= 4
# define MKLDNN_HELPER_DLL_IMPORT __attribute__ ((visibility ("default")))
# define MKLDNN_HELPER_DLL_EXPORT __attribute__ ((visibility ("default")))
# else
# define MKLDNN_HELPER_DLL_IMPORT
# define MKLDNN_HELPER_DLL_EXPORT
# endif
#endif
#ifdef MKLDNN_DLL
# ifdef MKLDNN_DLL_EXPORTS
# define MKLDNN_API MKLDNN_HELPER_DLL_EXPORT
# else
# define MKLDNN_API MKLDNN_HELPER_DLL_IMPORT
# endif
#else
# define MKLDNN_API
#endif
#if defined (__GNUC__)
# define MKLDNN_DEPRECATED __attribute__((deprecated))
#elif defined(_MSC_VER)
# define MKLDNN_DEPRECATED __declspec(deprecated)
#else
# define MKLDNN_DEPRECATED
#endif
#include "mkldnn_types.h"
#endif /* DOXYGEN_SHOULD_SKIP_THIS */
#ifdef __cplusplus
extern "C" {
#endif
const char MKLDNN_API *mkldnn_status2str(mkldnn_status_t v);
const char MKLDNN_API *mkldnn_dt2str(mkldnn_data_type_t v);
const char MKLDNN_API *mkldnn_fmt_kind2str(mkldnn_format_kind_t v);
const char MKLDNN_API *mkldnn_fmt_tag2str(mkldnn_format_tag_t v);
const char MKLDNN_API *mkldnn_prop_kind2str(mkldnn_prop_kind_t v);
const char MKLDNN_API *mkldnn_prim_kind2str(mkldnn_primitive_kind_t v);
const char MKLDNN_API *mkldnn_alg_kind2str(mkldnn_alg_kind_t v);
const char MKLDNN_API *mkldnn_rnn_direction2str(mkldnn_rnn_direction_t v);
/** Forms a format string for a given memory descriptor.
*
* The format is defined as: 'dt:[p|o|0]:fmt_kind:fmt:extra'.
* Here:
* - dt -- data type
* - p -- indicates there is non-trivial padding
* - o -- indicates there is non-trivial padding offset
* - 0 -- indicates there is non-trivial offset0
* - fmt_kind -- format kind (blocked, wino, etc...)
* - fmt -- extended format string (format_kind specific)
* - extra -- shows extra fields (underspecified)
*/
int MKLDNN_API mkldnn_md2fmt_str(char *fmt_str, size_t fmt_str_len,
const mkldnn_memory_desc_t *md);
/** Forms a dimension string for a given memory descriptor.
*
* The format is defined as: 'dim0xdim1x...xdimN'
*/
int MKLDNN_API mkldnn_md2dim_str(char *dim_str, size_t dim_str_len,
const mkldnn_memory_desc_t *md);
#ifdef __cplusplus
}
#endif
#endif

File diff suppressed because it is too large

@ -1,32 +0,0 @@
/*******************************************************************************
* Copyright 2019 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/
#ifndef MKLDNN_VERSION_H
#define MKLDNN_VERSION_H
/* Major version of MKL-DNN */
#define MKLDNN_VERSION_MAJOR 0
/* Minor version of MKL-DNN */
#define MKLDNN_VERSION_MINOR 90
/* Patch version of MKL-DNN */
#define MKLDNN_VERSION_PATCH 0
/* Git Commit Hash of MKL-DNN */
#define MKLDNN_VERSION_HASH "096bda1ca23324879f2df5a129e610e4405f775c"
#endif

@ -1,32 +0,0 @@
/*******************************************************************************
* Copyright 2019 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/
#ifndef MKLDNN_VERSION_H
#define MKLDNN_VERSION_H
/* Major version of MKL-DNN */
#define MKLDNN_VERSION_MAJOR @MKLDNN_VERSION_MAJOR@
/* Minor version of MKL-DNN */
#define MKLDNN_VERSION_MINOR @MKLDNN_VERSION_MINOR@
/* Patch version of MKL-DNN */
#define MKLDNN_VERSION_PATCH @MKLDNN_VERSION_PATCH@
/* Git Commit Hash of MKL-DNN */
#define MKLDNN_VERSION_HASH "@MKLDNN_VERSION_HASH@"
#endif

@ -1,104 +0,0 @@
/*******************************************************************************
* Copyright 2016-2018 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/
#include <assert.h>
#include "mkldnn.h"
#include "c_types_map.hpp"
#include "type_helpers.hpp"
#include "utils.hpp"
using namespace mkldnn::impl;
using namespace mkldnn::impl::utils;
using namespace mkldnn::impl::status;
using namespace mkldnn::impl::prop_kind;
using namespace mkldnn::impl::alg_kind;
using namespace mkldnn::impl::types;
namespace {
status_t bnrm_desc_init(batch_normalization_desc_t *bnrm_desc,
prop_kind_t prop_kind, const memory_desc_t *data_desc,
const memory_desc_t *diff_data_desc, float epsilon, unsigned flags) {
bool args_ok = true
&& !any_null(bnrm_desc, data_desc)
&& one_of(prop_kind, forward_training, forward_inference,
backward_data, backward)
&& IMPLICATION(prop_kind & backward, diff_data_desc != nullptr);
if (!args_ok) return invalid_arguments;
auto bd = batch_normalization_desc_t();
bd.primitive_kind = primitive_kind::batch_normalization;
bd.prop_kind = prop_kind;
bd.data_desc = *data_desc;
bd.diff_data_desc = zero_md();
if ( one_of(bd.prop_kind,backward_data, backward) )
bd.diff_data_desc = *diff_data_desc;
dims_t scaleshift_dims = { 2, data_desc->dims[1] };
mkldnn_memory_desc_init_by_tag(&bd.data_scaleshift_desc, 2,
scaleshift_dims, data_type::f32, mkldnn_nc);
bd.diff_data_scaleshift_desc = zero_md();
if (bd.prop_kind == backward) {
bd.diff_data_scaleshift_desc = bd.data_scaleshift_desc;
}
dims_t stats_dims = { data_desc->dims[1] };
mkldnn_memory_desc_init_by_tag(&bd.mean_desc, 1, stats_dims,
data_type::f32, mkldnn_x);
bd.variance_desc = bd.mean_desc;
bd.batch_norm_epsilon = epsilon;
unsigned bnorm_flags =
mkldnn_use_global_stats | mkldnn_use_scaleshift | mkldnn_fuse_bn_relu;
if ((~bnorm_flags & flags) != 0) return invalid_arguments;
bd.flags = flags;
bool consistency = true
&& utils::one_of(bd.data_desc.ndims, 2, 4, 5);
if (bd.prop_kind == backward_data)
consistency = consistency
&& utils::one_of(bd.diff_data_desc.ndims, 2, 4, 5)
&& array_cmp(bd.diff_data_desc.dims, bd.data_desc.dims,
bd.diff_data_desc.ndims);
if (!consistency) return invalid_arguments;
*bnrm_desc = bd;
return success;
}
}
status_t mkldnn_batch_normalization_forward_desc_init(
batch_normalization_desc_t *bnrm_desc, prop_kind_t prop_kind,
const memory_desc_t *data_desc, float epsilon, unsigned flags) {
if (!one_of(prop_kind, forward_training, forward_inference))
return invalid_arguments;
return bnrm_desc_init(bnrm_desc, prop_kind, data_desc, nullptr,
epsilon, flags);
}
status_t mkldnn_batch_normalization_backward_desc_init(
batch_normalization_desc_t *bnrm_desc, prop_kind_t prop_kind,
const memory_desc_t *diff_data_desc, const memory_desc_t *data_desc,
float epsilon, unsigned flags) {
if (!one_of(prop_kind, backward, backward_data))
return invalid_arguments;
return bnrm_desc_init(bnrm_desc, prop_kind, data_desc, diff_data_desc,
epsilon, flags);
}
// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s
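
The flag check in bnrm_desc_init above, (~bnorm_flags & flags) != 0, rejects any descriptor whose flags contain a bit outside the supported mask. The sketch below isolates that test. The constant names and bit values are placeholders chosen for illustration and are not taken from the MKL-DNN headers; flagsAreValid is likewise a hypothetical helper.

```cpp
#include <cassert>

// Placeholder flag bits; the real mkldnn_* values may differ.
enum : unsigned
{
  kUseGlobalStats = 0x1u,
  kUseScaleShift = 0x2u,
  kFuseBnRelu = 0x4u
};

// Mask of every bit the caller is allowed to set.
constexpr unsigned kAllowedFlags = kUseGlobalStats | kUseScaleShift | kFuseBnRelu;

// True when 'flags' sets no bit outside the allowed mask.
// ~kAllowedFlags keeps exactly the unsupported bits, so a nonzero AND means a stray bit.
constexpr bool flagsAreValid(unsigned flags)
{
  return (~kAllowedFlags & flags) == 0;
}
```

Any combination of supported bits passes, including zero, while a single unsupported bit makes the whole value invalid. This matches the early return of invalid_arguments in the removed function.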

@ -1,240 +0,0 @@
/*******************************************************************************
* Copyright 2016-2018 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/
#ifndef BATCH_NORMALIZATION_PD_HPP
#define BATCH_NORMALIZATION_PD_HPP
#include "mkldnn.h"
#include "c_types_map.hpp"
#include "primitive_desc.hpp"
#include "utils.hpp"
namespace mkldnn {
namespace impl {
struct batch_normalization_fwd_pd_t;
struct batch_normalization_pd_t: public primitive_desc_t {
static constexpr auto base_pkind = primitive_kind::batch_normalization;
batch_normalization_pd_t(engine_t *engine,
const batch_normalization_desc_t *adesc,
const primitive_attr_t *attr,
const batch_normalization_fwd_pd_t *hint_fwd_pd)
: primitive_desc_t(engine, attr, base_pkind)
, desc_(*adesc)
, hint_fwd_pd_(hint_fwd_pd)
, data_md_(desc_.data_desc)
, stat_md_(desc_.mean_desc)
, scaleshift_md_(desc_.data_scaleshift_desc)
, ws_md_()
{}
const batch_normalization_desc_t *desc() const { return &desc_; }
virtual const op_desc_t *op_desc() const override
{ return reinterpret_cast<const op_desc_t *>(this->desc()); }
virtual void init_info() override { impl::init_info(this, this->info_); }
virtual status_t query(query_t what, int idx, void *result) const override {
switch (what) {
case query::batch_normalization_d:
*(const batch_normalization_desc_t**)result = desc(); break;
default: return primitive_desc_t::query(what, idx, result);
}
return status::success;
}
/* common batch_normalization aux functions */
dim_t MB() const { return data_desc().dims[0]; }
dim_t C() const { return data_desc().dims[1]; }
dim_t D() const { return ndims() >= 5 ? data_desc().dims[ndims() - 3] : 1; }
dim_t H() const { return ndims() >= 4 ? data_desc().dims[ndims() - 2] : 1; }
dim_t W() const { return ndims() >= 3 ? data_desc().dims[ndims() - 1] : 1; }
int ndims() const { return desc_.data_desc.ndims; }
bool stats_is_src() const { return desc_.flags & mkldnn_use_global_stats; }
bool use_scaleshift() const { return desc_.flags & mkldnn_use_scaleshift; }
bool use_global_stats() const
{ return desc_.flags & mkldnn_use_global_stats; }
bool fuse_bn_relu() const { return desc_.flags & mkldnn_fuse_bn_relu; }
bool with_relu_post_op() const {
const auto &p = this->attr()->post_ops_;
return p.len_ == 1 && p.entry_[0].is_relu(true, true);
}
bool is_fwd() const {
return utils::one_of(desc_.prop_kind, prop_kind::forward_training,
prop_kind::forward_inference);
}
bool is_bwd() const { return !this->is_fwd(); }
bool is_training() const
{ return desc_.prop_kind == prop_kind::forward_training; }
bool has_zero_dim_memory() const
{ return memory_desc_wrapper(desc_.data_desc).has_zero_dim(); }
protected:
batch_normalization_desc_t desc_;
const batch_normalization_fwd_pd_t *hint_fwd_pd_;
memory_desc_t data_md_;
memory_desc_t stat_md_;
memory_desc_t scaleshift_md_;
memory_desc_t ws_md_;
void init_default_ws(size_t bits_per_element) {
const auto data_mdw = memory_desc_wrapper(data_md_);
const dim_t data_nelems = data_mdw.nelems(true);
const dim_t bits_per_byte = 8;
const dims_t ws_sz = { (dim_t)utils::div_up(
data_nelems * bits_per_element, bits_per_byte) };
mkldnn_memory_desc_init_by_tag(&ws_md_, 1, ws_sz, impl::data_type::u8,
format_tag::x);
}
private:
const memory_desc_t &data_desc() const { return desc_.data_desc; }
};
struct batch_normalization_fwd_pd_t: public batch_normalization_pd_t {
typedef batch_normalization_fwd_pd_t base_class;
typedef batch_normalization_fwd_pd_t hint_class;
batch_normalization_fwd_pd_t(engine_t *engine,
const batch_normalization_desc_t *adesc,
const primitive_attr_t *attr,
const batch_normalization_fwd_pd_t *hint_fwd_pd)
: batch_normalization_pd_t(engine, adesc, attr, hint_fwd_pd)
{}
virtual arg_usage_t arg_usage(primitive_arg_index_t arg) const override {
if (arg == MKLDNN_ARG_SRC) return arg_usage_t::input;
if (arg == MKLDNN_ARG_DST) return arg_usage_t::output;
if (utils::one_of(arg, MKLDNN_ARG_MEAN, MKLDNN_ARG_VARIANCE)) {
if (stats_is_src()) return arg_usage_t::input;
if (!stats_is_src() && is_training()) return arg_usage_t::output;
return arg_usage_t::unused;
}
if (arg == MKLDNN_ARG_SCALE_SHIFT && use_scaleshift())
return arg_usage_t::input;
if (arg == MKLDNN_ARG_WORKSPACE && is_training() && fuse_bn_relu())
return arg_usage_t::output;
return primitive_desc_t::arg_usage(arg);
}
virtual const memory_desc_t *src_md(int index = 0) const override {
if (index == 0) return &data_md_;
if (stats_is_src() && (index == 1 || index == 2)) return &stat_md_;
return nullptr;
}
virtual const memory_desc_t *dst_md(int index = 0) const override {
if (index == 0) return &data_md_;
if (!stats_is_src() && is_training() && (index == 1 || index == 2))
return &stat_md_;
return nullptr;
}
virtual const memory_desc_t *weights_md(int index = 0) const override
{ return index == 0 ? &scaleshift_md_ : nullptr; }
virtual const memory_desc_t *workspace_md(int index = 0) const override
{ return index == 0 && is_training() && fuse_bn_relu() ? &ws_md_ : nullptr; }
const memory_desc_t *stat_md() const
{ return stats_is_src() ? src_md(1) : dst_md(1); }
virtual int n_inputs() const override
{ return 1 + 2 * stats_is_src() + use_scaleshift(); }
virtual int n_outputs() const override
{ return 1 + (fuse_bn_relu() + 2 * (!stats_is_src())) * is_training(); }
};
struct batch_normalization_bwd_pd_t: public batch_normalization_pd_t {
typedef batch_normalization_bwd_pd_t base_class;
typedef batch_normalization_fwd_pd_t hint_class;
batch_normalization_bwd_pd_t(engine_t *engine,
const batch_normalization_desc_t *adesc,
const primitive_attr_t *attr,
const batch_normalization_fwd_pd_t *hint_fwd_pd)
: batch_normalization_pd_t(engine, adesc, attr, hint_fwd_pd)
, diff_data_md_(desc_.diff_data_desc)
, diff_scaleshift_md_(desc_.diff_data_scaleshift_desc)
{}
virtual arg_usage_t arg_usage(primitive_arg_index_t arg) const override {
if (utils::one_of(arg, MKLDNN_ARG_SRC, MKLDNN_ARG_MEAN,
MKLDNN_ARG_VARIANCE, MKLDNN_ARG_DIFF_DST))
return arg_usage_t::input;
if (arg == MKLDNN_ARG_SCALE_SHIFT && use_scaleshift())
return arg_usage_t::input;
if (arg == MKLDNN_ARG_WORKSPACE && fuse_bn_relu())
return arg_usage_t::input;
if (arg == MKLDNN_ARG_DIFF_SRC)
return arg_usage_t::output;
if (arg == MKLDNN_ARG_DIFF_SCALE_SHIFT && use_scaleshift())
return arg_usage_t::output;
return primitive_desc_t::arg_usage(arg);
}
virtual const memory_desc_t *src_md(int index = 0) const override
{ return index == 0 ? &data_md_ : index <= 2 ? &stat_md_ : nullptr; }
virtual const memory_desc_t *diff_dst_md(int index = 0) const override
{ return index == 0 ? &diff_data_md_ : nullptr; }
virtual const memory_desc_t *diff_src_md(int index = 0) const override
{ return index == 0 ? &diff_data_md_ : nullptr; }
virtual const memory_desc_t *weights_md(int index = 0) const override
{ return index == 0 ? &scaleshift_md_ : nullptr; }
virtual const memory_desc_t *diff_weights_md(int index = 0) const override
{ return index == 0 ? &diff_scaleshift_md_ : nullptr; }
virtual const memory_desc_t *workspace_md(int index = 0) const override
{ return index == 0 && fuse_bn_relu() ? &ws_md_ : nullptr; }
const memory_desc_t *stat_md() const { return src_md(1); }
virtual int n_inputs() const override
{ return 4 + use_scaleshift() + fuse_bn_relu(); }
virtual int n_outputs() const override
{ return 1 + (desc_.prop_kind == prop_kind::backward); }
protected:
memory_desc_t diff_data_md_;
memory_desc_t diff_scaleshift_md_;
};
}
}
#endif
// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s
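
The workspace sizing in `init_default_ws` in the removed header above stores `bits_per_element` bits for every data element in a `u8` buffer, rounding up to whole bytes. Below is a minimal standalone sketch of that arithmetic. The helper name `ws_bytes` is hypothetical, and the local `div_up` is only assumed to behave like `utils::div_up` (ceiling division), whose definition is not part of this diff.

```cpp
#include <cassert>
#include <cstdint>

// Ceiling division for a non-negative numerator and positive denominator.
// Stand-in for utils::div_up; its exact definition is an assumption here.
inline std::int64_t div_up(std::int64_t a, std::int64_t b) {
    assert(a >= 0 && b > 0);
    return (a + b - 1) / b;
}

// Hypothetical helper: bytes needed to hold `bits_per_element` bits for each
// of `nelems` elements when packed into 8-bit units.
inline std::int64_t ws_bytes(std::int64_t nelems,
                             std::int64_t bits_per_element) {
    const std::int64_t bits_per_byte = 8;
    return div_up(nelems * bits_per_element, bits_per_byte);
}
```

For instance, if an implementation keeps a one-bit mask per element, 10 elements would need 2 bytes and 17 elements would need 3; with 8 bits per element the byte count equals the element count.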

@ -1,550 +0,0 @@
/*******************************************************************************
* Copyright 2016-2018 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/
#ifndef TYPE_MAPPING_HPP
#define TYPE_MAPPING_HPP
#include "mkldnn_types.h"
namespace mkldnn {
namespace impl {
// TODO: autogenerate this
using dim_t = mkldnn_dim_t;
using dims_t = mkldnn_dims_t;
using stride_t = mkldnn_dim_t;
using strides_t = mkldnn_strides_t;
using status_t = mkldnn_status_t;
namespace status {
const status_t success = mkldnn_success;
const status_t out_of_memory = mkldnn_out_of_memory;
const status_t try_again = mkldnn_try_again;
const status_t invalid_arguments = mkldnn_invalid_arguments;
const status_t not_ready = mkldnn_not_ready;
const status_t unimplemented = mkldnn_unimplemented;
const status_t iterator_ends = mkldnn_iterator_ends;
const status_t runtime_error = mkldnn_runtime_error;
const status_t not_required = mkldnn_not_required;
}
using prop_kind_t = mkldnn_prop_kind_t;
namespace prop_kind {
const prop_kind_t undef = mkldnn_prop_kind_undef;
const prop_kind_t forward_training = mkldnn_forward_training;
const prop_kind_t forward_inference = mkldnn_forward_inference;
const prop_kind_t forward_scoring = mkldnn_forward_scoring;
const prop_kind_t forward = mkldnn_forward;
const prop_kind_t backward = mkldnn_backward;
const prop_kind_t backward_data = mkldnn_backward_data;
const prop_kind_t backward_weights = mkldnn_backward_weights;
const prop_kind_t backward_bias = mkldnn_backward_bias;
}
using alg_kind_t = mkldnn_alg_kind_t;
namespace alg_kind {
const alg_kind_t undef = mkldnn_alg_kind_undef;
const alg_kind_t convolution_auto = mkldnn_convolution_auto;
const alg_kind_t convolution_direct = mkldnn_convolution_direct;
const alg_kind_t convolution_winograd = mkldnn_convolution_winograd;
const alg_kind_t deconvolution_direct = mkldnn_deconvolution_direct;
const alg_kind_t deconvolution_winograd = mkldnn_deconvolution_winograd;
const alg_kind_t eltwise_relu = mkldnn_eltwise_relu;
const alg_kind_t eltwise_tanh = mkldnn_eltwise_tanh;
const alg_kind_t eltwise_elu = mkldnn_eltwise_elu;
const alg_kind_t eltwise_square = mkldnn_eltwise_square;
const alg_kind_t eltwise_abs = mkldnn_eltwise_abs;
const alg_kind_t eltwise_sqrt = mkldnn_eltwise_sqrt;
const alg_kind_t eltwise_linear = mkldnn_eltwise_linear;
const alg_kind_t eltwise_bounded_relu = mkldnn_eltwise_bounded_relu;
const alg_kind_t eltwise_soft_relu = mkldnn_eltwise_soft_relu;
const alg_kind_t eltwise_logistic = mkldnn_eltwise_logistic;
const alg_kind_t pooling_max = mkldnn_pooling_max;
const alg_kind_t pooling_avg = mkldnn_pooling_avg;
const alg_kind_t pooling_avg_include_padding = mkldnn_pooling_avg_include_padding;
const alg_kind_t pooling_avg_exclude_padding = mkldnn_pooling_avg_exclude_padding;
const alg_kind_t lrn_across_channels = mkldnn_lrn_across_channels;
const alg_kind_t lrn_within_channel = mkldnn_lrn_within_channel;
const alg_kind_t vanilla_rnn = mkldnn_vanilla_rnn;
const alg_kind_t vanilla_lstm = mkldnn_vanilla_lstm;
const alg_kind_t vanilla_gru = mkldnn_vanilla_gru;
const alg_kind_t gru_linear_before_reset = mkldnn_gru_linear_before_reset;
}
using data_type_t = mkldnn_data_type_t;
namespace data_type {
const data_type_t undef = mkldnn_data_type_undef;
const data_type_t f32 = mkldnn_f32;
const data_type_t s32 = mkldnn_s32;
const data_type_t s8 = mkldnn_s8;
const data_type_t u8 = mkldnn_u8;
}
using scratchpad_mode_t = mkldnn_scratchpad_mode_t;
namespace scratchpad_mode {
const scratchpad_mode_t library = mkldnn_scratchpad_mode_library;
const scratchpad_mode_t user = mkldnn_scratchpad_mode_user;
}
using rnn_packed_format_t = mkldnn_rnn_packed_memory_format_t;
namespace rnn_packed_format {
const rnn_packed_format_t undef = mkldnn_packed_format_undef;
const rnn_packed_format_t ldigo_p = mkldnn_ldigo_p;
const rnn_packed_format_t ldgoi_p = mkldnn_ldgoi_p;
}
using format_kind_t = mkldnn_format_kind_t;
namespace format_kind {
const format_kind_t undef = mkldnn_format_kind_undef;
const format_kind_t any = mkldnn_format_kind_any;
const format_kind_t blocked = mkldnn_blocked;
const format_kind_t wino = mkldnn_format_kind_wino;
const format_kind_t rnn_packed = mkldnn_format_kind_rnn_packed;
}
using format_tag_t = mkldnn_format_tag_t;
namespace format_tag {
const format_tag_t undef = mkldnn_format_tag_undef;
const format_tag_t any = mkldnn_format_tag_any;
const format_tag_t a = mkldnn_a;
const format_tag_t ab = mkldnn_ab;
const format_tag_t abc = mkldnn_abc;
const format_tag_t abcd = mkldnn_abcd;
const format_tag_t abcde = mkldnn_abcde;
const format_tag_t abcdef = mkldnn_abcdef;
const format_tag_t abdec = mkldnn_abdec;
const format_tag_t acb = mkldnn_acb;
const format_tag_t acbde = mkldnn_acbde;
const format_tag_t acdb = mkldnn_acdb;
const format_tag_t acdeb = mkldnn_acdeb;
const format_tag_t ba = mkldnn_ba;
const format_tag_t bac = mkldnn_bac;
const format_tag_t bacd = mkldnn_bacd;
const format_tag_t bcda = mkldnn_bcda;
const format_tag_t cba = mkldnn_cba;
const format_tag_t cdba = mkldnn_cdba;
const format_tag_t cdeba = mkldnn_cdeba;
const format_tag_t decab = mkldnn_decab;
const format_tag_t Abc16a = mkldnn_Abc16a;
const format_tag_t ABc16a16b = mkldnn_ABc16a16b;
const format_tag_t aBc16b = mkldnn_aBc16b;
const format_tag_t ABc16b16a = mkldnn_ABc16b16a;
const format_tag_t Abc4a = mkldnn_Abc4a;
const format_tag_t aBc4b = mkldnn_aBc4b;
const format_tag_t ABc4b16a4b = mkldnn_ABc4b16a4b;
const format_tag_t ABc4b4a = mkldnn_ABc4b4a;
const format_tag_t ABc8a16b2a = mkldnn_ABc8a16b2a;
const format_tag_t ABc8a8b = mkldnn_ABc8a8b;
const format_tag_t aBc8b = mkldnn_aBc8b;
const format_tag_t ABc8b16a2b = mkldnn_ABc8b16a2b;
const format_tag_t ABc8b8a = mkldnn_ABc8b8a;
const format_tag_t Abcd16a = mkldnn_Abcd16a;
const format_tag_t ABcd16a16b = mkldnn_ABcd16a16b;
const format_tag_t aBcd16b = mkldnn_aBcd16b;
const format_tag_t ABcd16b16a = mkldnn_ABcd16b16a;
const format_tag_t aBCd16b16c = mkldnn_aBCd16b16c;
const format_tag_t aBCd16c16b = mkldnn_aBCd16c16b;
const format_tag_t Abcd4a = mkldnn_Abcd4a;
const format_tag_t aBcd4b = mkldnn_aBcd4b;
const format_tag_t ABcd4b16a4b = mkldnn_ABcd4b16a4b;
const format_tag_t ABcd4b4a = mkldnn_ABcd4b4a;
const format_tag_t aBCd4c16b4c = mkldnn_aBCd4c16b4c;
const format_tag_t aBCd4c4b = mkldnn_aBCd4c4b;
const format_tag_t ABcd8a16b2a = mkldnn_ABcd8a16b2a;
const format_tag_t ABcd8a8b = mkldnn_ABcd8a8b;
const format_tag_t aBcd8b = mkldnn_aBcd8b;
const format_tag_t ABcd8b16a2b = mkldnn_ABcd8b16a2b;
const format_tag_t aBCd8b16c2b = mkldnn_aBCd8b16c2b;
const format_tag_t ABcd8b8a = mkldnn_ABcd8b8a;
const format_tag_t aBCd8b8c = mkldnn_aBCd8b8c;
const format_tag_t aBCd8c16b2c = mkldnn_aBCd8c16b2c;
const format_tag_t aBCd8c8b = mkldnn_aBCd8c8b;
const format_tag_t Abcde16a = mkldnn_Abcde16a;
const format_tag_t ABcde16a16b = mkldnn_ABcde16a16b;
const format_tag_t aBcde16b = mkldnn_aBcde16b;
const format_tag_t ABcde16b16a = mkldnn_ABcde16b16a;
const format_tag_t aBCde16b16c = mkldnn_aBCde16b16c;
const format_tag_t aBCde16c16b = mkldnn_aBCde16c16b;
const format_tag_t aBCde2c8b4c = mkldnn_aBCde2c8b4c;
const format_tag_t Abcde4a = mkldnn_Abcde4a;
const format_tag_t aBcde4b = mkldnn_aBcde4b;
const format_tag_t ABcde4b4a = mkldnn_ABcde4b4a;
const format_tag_t aBCde4b4c = mkldnn_aBCde4b4c;
const format_tag_t aBCde4c16b4c = mkldnn_aBCde4c16b4c;
const format_tag_t aBCde4c4b = mkldnn_aBCde4c4b;
const format_tag_t Abcde8a = mkldnn_Abcde8a;
const format_tag_t ABcde8a8b = mkldnn_ABcde8a8b;
const format_tag_t aBcde8b = mkldnn_aBcde8b;
const format_tag_t ABcde8b16a2b = mkldnn_ABcde8b16a2b;
const format_tag_t aBCde8b16c2b = mkldnn_aBCde8b16c2b;
const format_tag_t ABcde8b8a = mkldnn_ABcde8b8a;
const format_tag_t aBCde8b8c = mkldnn_aBCde8b8c;
const format_tag_t aBCde8c16b2c = mkldnn_aBCde8c16b2c;
const format_tag_t aBCde8c8b = mkldnn_aBCde8c8b;
const format_tag_t aBcdef16b = mkldnn_aBcdef16b;
const format_tag_t aBCdef16b16c = mkldnn_aBCdef16b16c;
const format_tag_t aBCdef16c16b = mkldnn_aBCdef16c16b;
const format_tag_t aBcdef4b = mkldnn_aBcdef4b;
const format_tag_t aBCdef4c4b = mkldnn_aBCdef4c4b;
const format_tag_t aBCdef8b8c = mkldnn_aBCdef8b8c;
const format_tag_t aBCdef8c16b2c = mkldnn_aBCdef8c16b2c;
const format_tag_t aBCdef8c8b = mkldnn_aBCdef8c8b;
const format_tag_t aBdc16b = mkldnn_aBdc16b;
const format_tag_t aBdc4b = mkldnn_aBdc4b;
const format_tag_t aBdc8b = mkldnn_aBdc8b;
const format_tag_t aBdec16b = mkldnn_aBdec16b;
const format_tag_t aBdec4b = mkldnn_aBdec4b;
const format_tag_t aBdec8b = mkldnn_aBdec8b;
const format_tag_t aBdefc16b = mkldnn_aBdefc16b;
const format_tag_t aBdefc4b = mkldnn_aBdefc4b;
const format_tag_t aBdefc8b = mkldnn_aBdefc8b;
const format_tag_t Acb16a = mkldnn_Acb16a;
const format_tag_t Acb4a = mkldnn_Acb4a;
const format_tag_t Acb8a = mkldnn_Acb8a;
const format_tag_t aCBd16b16c = mkldnn_aCBd16b16c;
const format_tag_t aCBde16b16c = mkldnn_aCBde16b16c;
const format_tag_t Acdb16a = mkldnn_Acdb16a;
const format_tag_t Acdb4a = mkldnn_Acdb4a;
const format_tag_t Acdb8a = mkldnn_Acdb8a;
const format_tag_t Acdeb16a = mkldnn_Acdeb16a;
const format_tag_t Acdeb4a = mkldnn_Acdeb4a;
const format_tag_t Acdeb8a = mkldnn_Acdeb8a;
const format_tag_t BAc16a16b = mkldnn_BAc16a16b;
const format_tag_t BAcd16a16b = mkldnn_BAcd16a16b;
const format_tag_t last = mkldnn_format_tag_last;
const format_tag_t x = mkldnn_x;
const format_tag_t nc = mkldnn_nc;
const format_tag_t cn = mkldnn_cn;
const format_tag_t ncw = mkldnn_ncw;
const format_tag_t nwc = mkldnn_nwc;
const format_tag_t nchw = mkldnn_nchw;
const format_tag_t nhwc = mkldnn_nhwc;
const format_tag_t chwn = mkldnn_chwn;
const format_tag_t ncdhw = mkldnn_ncdhw;
const format_tag_t ndhwc = mkldnn_ndhwc;
const format_tag_t oi = mkldnn_oi;
const format_tag_t io = mkldnn_io;
const format_tag_t oiw = mkldnn_oiw;
const format_tag_t wio = mkldnn_wio;
const format_tag_t oihw = mkldnn_oihw;
const format_tag_t hwio = mkldnn_hwio;
const format_tag_t ihwo = mkldnn_ihwo;
const format_tag_t iohw = mkldnn_iohw;
const format_tag_t oidhw = mkldnn_oidhw;
const format_tag_t dhwio = mkldnn_dhwio;
const format_tag_t goiw = mkldnn_goiw;
const format_tag_t goihw = mkldnn_goihw;
const format_tag_t hwigo = mkldnn_hwigo;
const format_tag_t giohw = mkldnn_giohw;
const format_tag_t goidhw = mkldnn_goidhw;
const format_tag_t tnc = mkldnn_tnc;
const format_tag_t ntc = mkldnn_ntc;
const format_tag_t ldsnc = mkldnn_ldsnc;
const format_tag_t ldigo = mkldnn_ldigo;
const format_tag_t ldgoi = mkldnn_ldgoi;
const format_tag_t ldgo = mkldnn_ldgo;
const format_tag_t nCdhw16c = mkldnn_nCdhw16c;
const format_tag_t nCdhw4c = mkldnn_nCdhw4c;
const format_tag_t nCdhw8c = mkldnn_nCdhw8c;
const format_tag_t nChw16c = mkldnn_nChw16c;
const format_tag_t nChw4c = mkldnn_nChw4c;
const format_tag_t nChw8c = mkldnn_nChw8c;
const format_tag_t nCw16c = mkldnn_nCw16c;
const format_tag_t nCw4c = mkldnn_nCw4c;
const format_tag_t nCw8c = mkldnn_nCw8c;
const format_tag_t IOw16o16i = mkldnn_IOw16o16i;
const format_tag_t OIw16i16o = mkldnn_OIw16i16o;
const format_tag_t OIw16o16i = mkldnn_OIw16o16i;
const format_tag_t Oiw16o = mkldnn_Oiw16o;
const format_tag_t OIw4i16o4i = mkldnn_OIw4i16o4i;
const format_tag_t OIw4i4o = mkldnn_OIw4i4o;
const format_tag_t Oiw4o = mkldnn_Oiw4o;
const format_tag_t OIw8i16o2i = mkldnn_OIw8i16o2i;
const format_tag_t OIw8i8o = mkldnn_OIw8i8o;
const format_tag_t OIw8o16i2o = mkldnn_OIw8o16i2o;
const format_tag_t OIw8o8i = mkldnn_OIw8o8i;
const format_tag_t Owi16o = mkldnn_Owi16o;
const format_tag_t Owi4o = mkldnn_Owi4o;
const format_tag_t Owi8o = mkldnn_Owi8o;
const format_tag_t IOhw16o16i = mkldnn_IOhw16o16i;
const format_tag_t Ohwi16o = mkldnn_Ohwi16o;
const format_tag_t Ohwi4o = mkldnn_Ohwi4o;
const format_tag_t Ohwi8o = mkldnn_Ohwi8o;
const format_tag_t OIhw16i16o = mkldnn_OIhw16i16o;
const format_tag_t OIhw16o16i = mkldnn_OIhw16o16i;
const format_tag_t Oihw16o = mkldnn_Oihw16o;
const format_tag_t OIhw4i16o4i = mkldnn_OIhw4i16o4i;
const format_tag_t OIhw4i4o = mkldnn_OIhw4i4o;
const format_tag_t Oihw4o = mkldnn_Oihw4o;
const format_tag_t OIhw8i16o2i = mkldnn_OIhw8i16o2i;
const format_tag_t OIhw8i8o = mkldnn_OIhw8i8o;
const format_tag_t OIhw8o16i2o = mkldnn_OIhw8o16i2o;
const format_tag_t OIhw8o8i = mkldnn_OIhw8o8i;
const format_tag_t Odhwi16o = mkldnn_Odhwi16o;
const format_tag_t Odhwi4o = mkldnn_Odhwi4o;
const format_tag_t Odhwi8o = mkldnn_Odhwi8o;
const format_tag_t OIdhw16i16o = mkldnn_OIdhw16i16o;
const format_tag_t OIdhw16o16i = mkldnn_OIdhw16o16i;
const format_tag_t Oidhw16o = mkldnn_Oidhw16o;
const format_tag_t OIdhw4i4o = mkldnn_OIdhw4i4o;
const format_tag_t Oidhw4o = mkldnn_Oidhw4o;
const format_tag_t OIdhw8i16o2i = mkldnn_OIdhw8i16o2i;
const format_tag_t OIdhw8i8o = mkldnn_OIdhw8i8o;
const format_tag_t OIdhw8o8i = mkldnn_OIdhw8o8i;
const format_tag_t gIOw16o16i = mkldnn_gIOw16o16i;
const format_tag_t Goiw16g = mkldnn_Goiw16g;
const format_tag_t gOIw16i16o = mkldnn_gOIw16i16o;
const format_tag_t gOIw16o16i = mkldnn_gOIw16o16i;
const format_tag_t gOiw16o = mkldnn_gOiw16o;
const format_tag_t gOIw4i16o4i = mkldnn_gOIw4i16o4i;
const format_tag_t gOIw4i4o = mkldnn_gOIw4i4o;
const format_tag_t gOiw4o = mkldnn_gOiw4o;
const format_tag_t gOIw8i16o2i = mkldnn_gOIw8i16o2i;
const format_tag_t gOIw8i8o = mkldnn_gOIw8i8o;
const format_tag_t gOIw8o16i2o = mkldnn_gOIw8o16i2o;
const format_tag_t gOIw8o8i = mkldnn_gOIw8o8i;
const format_tag_t gOwi16o = mkldnn_gOwi16o;
const format_tag_t gOwi4o = mkldnn_gOwi4o;
const format_tag_t gOwi8o = mkldnn_gOwi8o;
const format_tag_t gIOhw16o16i = mkldnn_gIOhw16o16i;
const format_tag_t gOhwi16o = mkldnn_gOhwi16o;
const format_tag_t gOhwi4o = mkldnn_gOhwi4o;
const format_tag_t gOhwi8o = mkldnn_gOhwi8o;
const format_tag_t Goihw16g = mkldnn_Goihw16g;
const format_tag_t gOIhw16i16o = mkldnn_gOIhw16i16o;
const format_tag_t gOIhw16o16i = mkldnn_gOIhw16o16i;
const format_tag_t gOihw16o = mkldnn_gOihw16o;
const format_tag_t gOIhw2i8o4i = mkldnn_gOIhw2i8o4i;
const format_tag_t gOIhw4i16o4i = mkldnn_gOIhw4i16o4i;
const format_tag_t gOIhw4i4o = mkldnn_gOIhw4i4o;
const format_tag_t gOIhw4o4i = mkldnn_gOIhw4o4i;
const format_tag_t gOihw4o = mkldnn_gOihw4o;
const format_tag_t Goihw8g = mkldnn_Goihw8g;
const format_tag_t gOIhw8i16o2i = mkldnn_gOIhw8i16o2i;
const format_tag_t gOIhw8i8o = mkldnn_gOIhw8i8o;
const format_tag_t gOIhw8o16i2o = mkldnn_gOIhw8o16i2o;
const format_tag_t gOIhw8o8i = mkldnn_gOIhw8o8i;
const format_tag_t gOdhwi16o = mkldnn_gOdhwi16o;
const format_tag_t gOdhwi4o = mkldnn_gOdhwi4o;
const format_tag_t gOdhwi8o = mkldnn_gOdhwi8o;
const format_tag_t gOIdhw16i16o = mkldnn_gOIdhw16i16o;
const format_tag_t gOIdhw16o16i = mkldnn_gOIdhw16o16i;
const format_tag_t gOidhw16o = mkldnn_gOidhw16o;
const format_tag_t gOIdhw4i4o = mkldnn_gOIdhw4i4o;
const format_tag_t gOidhw4o = mkldnn_gOidhw4o;
const format_tag_t gOIdhw8i16o2i = mkldnn_gOIdhw8i16o2i;
const format_tag_t gOIdhw8i8o = mkldnn_gOIdhw8i8o;
const format_tag_t gOIdhw8o8i = mkldnn_gOIdhw8o8i;
}
using memory_extra_flags_t = mkldnn_memory_extra_flags_t;
namespace memory_extra_flags {
const memory_extra_flags_t none = mkldnn_memory_extra_flag_none;
const memory_extra_flags_t compensation_conv_s8s8 = mkldnn_memory_extra_flag_compensation_conv_s8s8;
const memory_extra_flags_t scale_adjust = mkldnn_memory_extra_flag_scale_adjust;
}
using padding_kind_t = mkldnn_padding_kind_t;
namespace padding_kind {
const padding_kind_t padding_zero = mkldnn_padding_zero;
}
using engine_kind_t = mkldnn_engine_kind_t;
namespace engine_kind {
const engine_kind_t any_engine = mkldnn_any_engine;
const engine_kind_t cpu = mkldnn_cpu;
}
using primitive_kind_t = mkldnn_primitive_kind_t;
namespace primitive_kind {
const primitive_kind_t undefined = mkldnn_undefined_primitive;
const primitive_kind_t reorder = mkldnn_reorder;
const primitive_kind_t concat = mkldnn_concat;
const primitive_kind_t sum = mkldnn_sum;
const primitive_kind_t convolution = mkldnn_convolution;
const primitive_kind_t deconvolution = mkldnn_deconvolution;
const primitive_kind_t shuffle = mkldnn_shuffle;
const primitive_kind_t eltwise = mkldnn_eltwise;
const primitive_kind_t softmax = mkldnn_softmax;
const primitive_kind_t pooling = mkldnn_pooling;
const primitive_kind_t lrn = mkldnn_lrn;
const primitive_kind_t batch_normalization = mkldnn_batch_normalization;
const primitive_kind_t inner_product = mkldnn_inner_product;
const primitive_kind_t rnn = mkldnn_rnn;
}
using query_t = mkldnn_query_t;
namespace query {
const query_t undef = mkldnn_query_undef;
const query_t engine = mkldnn_query_engine;
const query_t primitive_kind = mkldnn_query_primitive_kind;
const query_t num_of_inputs_s32 = mkldnn_query_num_of_inputs_s32;
const query_t num_of_outputs_s32 = mkldnn_query_num_of_outputs_s32;
const query_t time_estimate_f64 = mkldnn_query_time_estimate_f64;
const query_t memory_consumption_s64 = mkldnn_query_memory_consumption_s64;
const query_t scratchpad_engine = mkldnn_query_scratchpad_engine;
const query_t impl_info_str = mkldnn_query_impl_info_str;
const query_t some_d = mkldnn_query_some_d;
const query_t op_d = mkldnn_query_op_d;
const query_t convolution_d = mkldnn_query_convolution_d;
const query_t deconvolution_d = mkldnn_query_deconvolution_d;
const query_t shuffle_d = mkldnn_query_shuffle_d;
const query_t eltwise_d = mkldnn_query_eltwise_d;
const query_t softmax_d = mkldnn_query_softmax_d;
const query_t pooling_d = mkldnn_query_pooling_d;
const query_t lrn_d = mkldnn_query_lrn_d;
const query_t batch_normalization_d = mkldnn_query_batch_normalization_d;
const query_t inner_product_d = mkldnn_query_inner_product_d;
const query_t rnn_d = mkldnn_query_rnn_d;
const query_t some_md = mkldnn_query_some_md;
const query_t src_md = mkldnn_query_src_md;
const query_t diff_src_md = mkldnn_query_diff_src_md;
const query_t weights_md = mkldnn_query_weights_md;
const query_t diff_weights_md = mkldnn_query_diff_weights_md;
const query_t dst_md = mkldnn_query_dst_md;
const query_t diff_dst_md = mkldnn_query_diff_dst_md;
const query_t workspace_md = mkldnn_query_workspace_md;
const query_t scratchpad_md = mkldnn_query_scratchpad_md;
}
using blocking_desc_t = mkldnn_blocking_desc_t;
using rnn_packed_desc_t = mkldnn_rnn_packed_desc_t;
using wino_desc_t = mkldnn_wino_desc_t;
using memory_extra_desc_t = mkldnn_memory_extra_desc_t;
using memory_desc_t = mkldnn_memory_desc_t;
using convolution_desc_t = mkldnn_convolution_desc_t;
using deconvolution_desc_t = mkldnn_deconvolution_desc_t;
using shuffle_desc_t = mkldnn_shuffle_desc_t;
using pooling_desc_t = mkldnn_pooling_desc_t;
using eltwise_desc_t = mkldnn_eltwise_desc_t;
using softmax_desc_t = mkldnn_softmax_desc_t;
using lrn_desc_t = mkldnn_lrn_desc_t;
using batch_normalization_desc_t = mkldnn_batch_normalization_desc_t;
using inner_product_desc_t = mkldnn_inner_product_desc_t;
using rnn_direction_t = mkldnn_rnn_direction_t;
using rnn_cell_desc_t = mkldnn_rnn_cell_desc_t;
using rnn_desc_t = mkldnn_rnn_desc_t;
/* C op_desc_t handles, which are ultimately just (void *) */
using c_op_desc_t = mkldnn_op_desc_t;
using const_c_op_desc_t = const_mkldnn_op_desc_t;
struct op_desc_t {
union {
primitive_kind_t kind;
convolution_desc_t convolution;
deconvolution_desc_t deconvolution;
shuffle_desc_t shuffle;
pooling_desc_t pooling;
eltwise_desc_t eltwise;
softmax_desc_t softmax;
lrn_desc_t lrn;
batch_normalization_desc_t batch_normalization;
inner_product_desc_t inner_product;
rnn_desc_t rnn;
};
op_desc_t(const primitive_kind_t &_): kind(_) {}
# define DECL_CTOR_AND_CONVERTERS(c_type, name) \
op_desc_t(const c_type &_): name(_) {} \
static op_desc_t *convert_from_c(c_type *_) \
{ return reinterpret_cast<op_desc_t*>(_); } \
static const op_desc_t *convert_from_c(const c_type *_) \
{ return reinterpret_cast<const op_desc_t*>(_); }
DECL_CTOR_AND_CONVERTERS(convolution_desc_t, convolution);
DECL_CTOR_AND_CONVERTERS(shuffle_desc_t, shuffle);
DECL_CTOR_AND_CONVERTERS(pooling_desc_t, pooling);
DECL_CTOR_AND_CONVERTERS(eltwise_desc_t, eltwise);
DECL_CTOR_AND_CONVERTERS(softmax_desc_t, softmax);
DECL_CTOR_AND_CONVERTERS(lrn_desc_t, lrn);
DECL_CTOR_AND_CONVERTERS(batch_normalization_desc_t, batch_normalization);
DECL_CTOR_AND_CONVERTERS(inner_product_desc_t, inner_product);
DECL_CTOR_AND_CONVERTERS(rnn_desc_t, rnn);
# undef DECL_CTOR_AND_CONVERTERS
};
using engine_t = mkldnn_engine;
using primitive_desc_iterator_t = mkldnn_primitive_desc_iterator;
using primitive_desc_t = mkldnn_primitive_desc;
using primitive_attr_t = mkldnn_primitive_attr;
using post_ops_t = mkldnn_post_ops;
using memory_t = mkldnn_memory;
using primitive_t = mkldnn_primitive;
using primitive_arg_index_t = int;
using stream_flags_t = mkldnn_stream_flags_t;
namespace stream_flags {
const stream_flags_t default_flags = mkldnn_stream_default_flags;
}
using stream_t = mkldnn_stream;
/* forward declaration of the internal primitive_desc types */
struct batch_normalization_bwd_pd_t;
struct batch_normalization_fwd_pd_t;
struct batch_normalization_pd_t;
struct concat_pd_t;
struct convolution_bwd_data_pd_t;
struct convolution_bwd_weights_pd_t;
struct convolution_fwd_pd_t;
struct convolution_pd_t;
struct deconvolution_bwd_data_pd_t;
struct deconvolution_bwd_weights_pd_t;
struct deconvolution_fwd_pd_t;
struct deconvolution_pd_t;
struct eltwise_bwd_pd_t;
struct eltwise_fwd_pd_t;
struct eltwise_pd_t;
struct inner_product_bwd_data_pd_t;
struct inner_product_bwd_weights_pd_t;
struct inner_product_fwd_pd_t;
struct inner_product_pd_t;
struct lrn_bwd_pd_t;
struct lrn_fwd_pd_t;
struct lrn_pd_t;
struct pooling_bwd_pd_t;
struct pooling_fwd_pd_t;
struct pooling_pd_t;
struct reorder_pd_t;
struct rnn_bwd_pd_t;
struct rnn_fwd_pd_t;
struct rnn_pd_t;
struct shuffle_pd_t;
struct softmax_bwd_pd_t;
struct softmax_fwd_pd_t;
struct softmax_pd_t;
struct sum_pd_t;
}
}
#endif
// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s
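
The removed header above follows one pattern throughout: each prefixed C enumerator is re-exported as a short constant inside a nested C++ namespace, so internal code can write `status::success` instead of `mkldnn_success`. A minimal sketch of that pattern follows. Every `demo_*` name and the `check_positive` function are invented for illustration and do not exist in the library.

```cpp
// Hypothetical C API with prefixed enumerators (names invented for this sketch).
typedef enum {
    demo_success = 0,
    demo_invalid_arguments = 1,
    demo_unimplemented = 2
} demo_status_t;

namespace demo {
namespace impl {

// Short alias for the C type.
using status_t = demo_status_t;

// Short names for the C enumerators, grouped in their own namespace.
namespace status {
const status_t success = demo_success;
const status_t invalid_arguments = demo_invalid_arguments;
const status_t unimplemented = demo_unimplemented;
} // namespace status

// With `using namespace status;` in scope, call sites can use the short names.
inline status_t check_positive(int n) {
    using namespace status;
    return n > 0 ? success : invalid_arguments;
}

} // namespace impl
} // namespace demo
```

The short constants have the same type and value as the C enumerators, so results can still be returned across the C API boundary unchanged.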

@ -1,86 +0,0 @@
/*******************************************************************************
* Copyright 2018 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/
#include <assert.h>
#include "mkldnn.h"
#include "c_types_map.hpp"
#include "engine.hpp"
#include "type_helpers.hpp"
#include "utils.hpp"
#include "concat_pd.hpp"
using namespace mkldnn::impl;
using namespace mkldnn::impl::utils;
using namespace mkldnn::impl::status;
status_t mkldnn_concat_primitive_desc_create(primitive_desc_t **concat_pd,
const memory_desc_t *dst_md, int n, int concat_dim,
const memory_desc_t *src_mds,
const primitive_attr_t *attr,
engine_t *engine) {
bool args_ok = !any_null(concat_pd, src_mds) && n > 0;
if (!args_ok) return invalid_arguments;
const primitive_attr_t dummy_attr;
if (attr == NULL)
attr = &dummy_attr;
const int ndims = src_mds[0].ndims;
const dims_t &dims = src_mds[0].dims;
const data_type_t dt = src_mds[0].data_type;
int concat_dim_sz = dims[concat_dim];
for (int i = 1; i < n; ++i) {
if (src_mds[i].ndims != ndims) return invalid_arguments;
for (int d = 0; d < ndims; ++d) {
if (d == concat_dim) continue;
if (src_mds[i].dims[d] != dims[d])
return invalid_arguments;
}
if (src_mds[i].data_type != dt) return invalid_arguments;
concat_dim_sz += src_mds[i].dims[concat_dim];
}
memory_desc_t dummy_dst_md;
if (dst_md) {
if (dst_md->ndims != ndims) return invalid_arguments;
for (int d = 0; d < ndims; ++d) {
if (dst_md->dims[d] !=
(d == concat_dim ? concat_dim_sz : dims[d]))
return invalid_arguments;
}
} else {
dummy_dst_md = src_mds[0];
dummy_dst_md.dims[concat_dim] = concat_dim_sz;
dummy_dst_md.format_kind = format_kind::any;
dst_md = &dummy_dst_md;
}
auto c_pd = reinterpret_cast<concat_pd_t **>(concat_pd);
for (auto c = engine->get_concat_implementation_list(); *c; ++c) {
if ((*c)(c_pd, engine, attr, dst_md, n, concat_dim, src_mds)
== success) {
(*c_pd)->init_info();
(*c_pd)->init_scratchpad_md();
return success;
}
}
return unimplemented;
}
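
The shape checks in the removed `mkldnn_concat_primitive_desc_create` above amount to this: all sources must have the same rank and agree on every dimension except `concat_dim`, and the destination size along `concat_dim` is the sum of the source sizes. Below is a standalone sketch of that rule. The `Dims` alias and the `concat_dims` helper are hypothetical, and the range check on `concat_dim` is an extra guard added in this sketch that the removed function does not perform.

```cpp
#include <cstddef>
#include <vector>

using Dims = std::vector<long>;

// Hypothetical helper: computes the destination dims for concatenating `srcs`
// along `concat_dim`. Returns false if the shapes are incompatible; `dst` is
// only meaningful when the function returns true.
inline bool concat_dims(const std::vector<Dims> &srcs, std::size_t concat_dim,
                        Dims &dst) {
    if (srcs.empty()) return false;
    const Dims &first = srcs[0];
    // Extra guard (not in the removed code): the axis must exist.
    if (concat_dim >= first.size()) return false;

    Dims result = first;
    for (std::size_t i = 1; i < srcs.size(); ++i) {
        if (srcs[i].size() != first.size()) return false; // rank mismatch
        for (std::size_t d = 0; d < first.size(); ++d) {
            if (d == concat_dim) continue;
            if (srcs[i][d] != first[d]) return false; // other dims must match
        }
        result[concat_dim] += srcs[i][concat_dim]; // sizes add along the axis
    }
    dst = result;
    return true;
}
```

Concatenating shapes {2, 3, 4} and {2, 5, 4} along axis 1 yields {2, 8, 4}, while {2, 3, 4} and {3, 5, 4} are rejected because they differ on axis 0.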

@ -1,211 +0,0 @@
/*******************************************************************************
* Copyright 2019 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/
#ifndef CONCAT_PD_HPP
#define CONCAT_PD_HPP
#include <assert.h>
#include "c_types_map.hpp"
#include "nstl.hpp"
#include "primitive_desc.hpp"
#include "type_helpers.hpp"
#include "utils.hpp"
namespace mkldnn {
namespace impl {
struct concat_pd_t: public primitive_desc_t {
concat_pd_t(engine_t *engine, const primitive_attr_t *attr,
const memory_desc_t *dst_md, int n, int concat_dim,
const memory_desc_t *src_mds)
: primitive_desc_t(engine, attr, primitive_kind::concat)
, n_(n), concat_dim_(concat_dim), dst_md_(*dst_md)
{
src_mds_.reserve(n_);
for (int i = 0; i < n_; ++i) src_mds_.push_back(src_mds[i]);
}
concat_pd_t(const concat_pd_t &rhs) = default;
virtual void init_info() override { impl::init_info(this, this->info_); }
virtual arg_usage_t arg_usage(primitive_arg_index_t arg) const override {
if (arg >= MKLDNN_ARG_MULTIPLE_SRC
&& arg < MKLDNN_ARG_MULTIPLE_SRC + n_inputs())
return arg_usage_t::input;
if (arg == MKLDNN_ARG_DST)
return arg_usage_t::output;
return primitive_desc_t::arg_usage(arg);
}
virtual const memory_desc_t *src_md(int index = 0) const override
{ return index < n_inputs() ? &src_mds_[index] : nullptr; }
virtual const memory_desc_t *dst_md(int index = 0) const override
{ return index == 0 ? &dst_md_ : nullptr; }
virtual int n_inputs() const override { return n_; }
virtual int n_outputs() const override { return 1; }
int concat_dim() const { return concat_dim_; }
const memory_desc_t *src_image_md(int index = 0) const
{ return index < n_inputs() ? &src_image_mds_[index] : nullptr; }
protected:
int n_, concat_dim_;
memory_desc_t dst_md_;
nstl::vector<memory_desc_t> src_mds_;
/* Descriptors of each src viewed as a sub-memory ("image") of dst, when
* such a view can be built. Kept here to simplify some implementations;
* an implementation may use this array only if init() returned success. */
nstl::vector<memory_desc_t> src_image_mds_;
protected:
/* inits src_image_mds_ and dst_md_ in simple cases. The call may fail */
status_t init() {
bool ok = true
&& set_default_params() == status::success
&& attr()->has_default_values();
if (!ok) return status::unimplemented;
for (int i = 0; i < n_; ++i) {
const memory_desc_wrapper i_d(&src_mds_[i]);
if (!i_d.is_blocking_desc() || i_d.is_additional_buffer())
return status::unimplemented;
}
const int ndims = dst_md_.ndims;
int current_concat_dim_offset = 0;
for (int i = 0; i < n_; ++i) {
const int dim = src_mds_[i].dims[concat_dim_];
dims_t dims, offsets = {};
utils::array_copy(dims, dst_md_.dims, ndims);
dims[concat_dim_] = dim;
offsets[concat_dim_] = current_concat_dim_offset;
memory_desc_t src_img_d;
status_t status = mkldnn_memory_desc_init_submemory(&src_img_d,
&dst_md_, dims, offsets);
if (status != status::success) return status;
src_image_mds_.push_back(src_img_d);
current_concat_dim_offset += dim;
}
return status::success;
}
status_t set_default_params() {
if (dst_md_.format_kind != format_kind::any)
return status::success;
const int ndims = dst_md_.ndims;
/* A deliberately simple heuristic (not the same as the one used before):
* - pick the format of the first non-plain (blocked) input;
* - if all inputs are plain, or a blocked format cannot be created for
* the output, pick the format of the first plain input;
* - if this fails as well, use the plain layout (abcd...).
*/
status_t status = status::unimplemented;
for (int i = 0; i < n_; ++i) {
const memory_desc_wrapper src_d(src_mds_[i]);
if (src_d.is_blocking_desc() && !src_d.is_plain()) {
status = memory_desc_init_by_blocking_desc(dst_md_,
src_d.blocking_desc());
if (status == status::success) break;
}
}
if (status == status::success) {
/* check if we can create a sub-memory for the dst */
bool desired_format_ok = true;
int current_concat_dim_offset = 0;
for (int i = 0; i < n_; ++i) {
const int dim = src_mds_[i].dims[concat_dim_];
dims_t dims, offsets = {};
utils::array_copy(dims, dst_md_.dims, ndims);
dims[concat_dim_] = dim;
offsets[concat_dim_] = current_concat_dim_offset;
memory_desc_t src_img_d;
/* named sub_status so it does not shadow the outer `status` */
status_t sub_status = mkldnn_memory_desc_init_submemory(&src_img_d,
&dst_md_, dims, offsets);
if (sub_status != status::success) {
desired_format_ok = false;
break;
}
current_concat_dim_offset += dim;
}
if (!desired_format_ok)
status = status::unimplemented;
}
/* if no success so far, try using the format of the first plain input */
if (status != status::success) {
for (int i = 0; i < n_; ++i) {
const memory_desc_wrapper src_d(src_mds_[i]);
if (src_d.is_blocking_desc() && src_d.is_plain()) {
/* use the plain input found by this loop, not always input 0 */
status = memory_desc_init_by_blocking_desc(dst_md_,
src_d.blocking_desc());
if (status == status::success) return status;
}
}
}
/* the last line of defense: use plain abcd... format */
if (status != status::success)
status = memory_desc_init_by_strides(dst_md_, nullptr);
return status;
}
};
#define DECLARE_CONCAT_PD_t(impl_name, ...) \
static status_t create(concat_pd_t **concat_pd, \
engine_t *engine, const primitive_attr_t *attr, \
const memory_desc_t *dst_md, int n, int concat_dim, \
const memory_desc_t *src_mds) { \
using namespace status; \
auto _pd = new pd_t(engine, attr, dst_md, n, concat_dim, src_mds); \
if (_pd == nullptr) return out_of_memory; \
if (_pd->init() != success) { delete _pd; return unimplemented; } \
return safe_ptr_assign<concat_pd_t>(*concat_pd, _pd); \
} \
virtual status_t create_primitive(primitive_t **p) const override { \
double ms = get_msec(); \
auto ret = safe_ptr_assign<primitive_t>(*p, new (__VA_ARGS__)(this)); \
ms = get_msec() - ms; \
if (mkldnn_verbose()->level >= 2) { \
printf("mkldnn_verbose,create,%s,%g\n", this->info(), ms); \
fflush(0); \
} \
return ret; \
} \
virtual pd_t *clone() const override { return new pd_t(*this); } \
virtual const char *name() const override { return impl_name; }
#define DECLARE_CONCAT_PD_T(impl_name, ...) \
DECLARE_CONCAT_PD_t(impl_name, __VA_ARGS__)
}
}
#endif

@ -1,200 +0,0 @@
/*******************************************************************************
* Copyright 2016-2018 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/
#include <assert.h>
#include "mkldnn.h"
#include "c_types_map.hpp"
#include "type_helpers.hpp"
#include "utils.hpp"
using namespace mkldnn::impl;
using namespace mkldnn::impl::utils;
using namespace mkldnn::impl::status;
using namespace mkldnn::impl::prop_kind;
using namespace mkldnn::impl::alg_kind;
using namespace mkldnn::impl::types;
namespace mkldnn {
namespace impl {
status_t conv_desc_init(convolution_desc_t *conv_desc,
prop_kind_t prop_kind, alg_kind_t alg_kind,
const memory_desc_t *src_desc, const memory_desc_t *weights_desc,
const memory_desc_t *bias_desc, const memory_desc_t *dst_desc,
const dims_t strides, const dims_t dilates,
const dims_t padding_l, const dims_t padding_r,
padding_kind_t padding_kind) {
bool args_ok = true
&& !any_null(conv_desc, src_desc, weights_desc, dst_desc, strides,
padding_l)
&& one_of(alg_kind, convolution_auto, convolution_direct, convolution_winograd)
&& one_of(padding_kind, padding_kind::padding_zero);
if (!args_ok) return invalid_arguments;
if (padding_r == nullptr) padding_r = padding_l;
auto cd = convolution_desc_t();
cd.primitive_kind = primitive_kind::convolution;
cd.prop_kind = prop_kind;
cd.alg_kind = alg_kind;
cd.diff_src_desc = cd.src_desc = zero_md();
cd.diff_dst_desc = cd.dst_desc = zero_md();
cd.diff_weights_desc = cd.weights_desc = zero_md();
cd.diff_bias_desc = cd.bias_desc = zero_md();
const bool is_fwd = one_of(prop_kind, forward_training, forward_inference);
const bool with_bias =
bias_desc && bias_desc->format_kind != format_kind::undef;
const bool with_groups = weights_desc->ndims == src_desc->ndims + 1;
(prop_kind == backward_data ? cd.diff_src_desc : cd.src_desc) = *src_desc;
(is_fwd ? cd.dst_desc : cd.diff_dst_desc) = *dst_desc;
(prop_kind == backward_weights ? cd.diff_weights_desc : cd.weights_desc) =
*weights_desc;
if (with_bias)
(prop_kind == backward_weights ? cd.diff_bias_desc : cd.bias_desc) =
*bias_desc;
int sp_dims = src_desc->ndims - 2;
utils::array_copy(cd.strides, strides, sp_dims);
utils::array_copy(cd.padding[0], padding_l, sp_dims);
utils::array_copy(cd.padding[1], padding_r, sp_dims);
if (dilates)
utils::array_copy(cd.dilates, dilates, sp_dims);
else
utils::array_set(cd.dilates, 0, sp_dims);
cd.padding_kind = padding_kind;
cd.accum_data_type = types::default_accum_data_type(src_desc->data_type,
weights_desc->data_type, dst_desc->data_type, prop_kind);
const int g = with_groups ? weights_desc->dims[0] : 1;
const int bias_dim = prop_kind == backward_data
? src_desc->dims[1]
: dst_desc->dims[1];
bool consistency = true
&& memory_desc_wrapper(weights_desc).nelems()
&& src_desc->ndims == dst_desc->ndims
&& utils::one_of(src_desc->ndims, 3, 4, 5)
&& utils::one_of(weights_desc->ndims, src_desc->ndims,
src_desc->ndims + 1)
&& (with_bias ? bias_desc->ndims == 1 : true)
&& (with_bias ? bias_desc->dims[0] == bias_dim : true)
&& src_desc->dims[0] == dst_desc->dims[0]
&& src_desc->dims[1] == g * weights_desc->dims[with_groups + 1]
&& dst_desc->dims[1] == g * weights_desc->dims[with_groups + 0];
for (int i = 2; i < src_desc->ndims; ++i)
{
int src = src_desc->dims[i];
int ker = weights_desc->dims[with_groups + i];
int dil = cd.dilates[i - 2];
int pad_l = padding_l[i - 2];
int pad_r = padding_r[i - 2];
int str = strides[i - 2];
int dst = dst_desc->dims[i];
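/* effective kernel extent once dilation is applied (dil == 0 is dense) */
int ker_range = 1 + (ker - 1) * (dil + 1);
if (str < 1) return invalid_arguments;
consistency = consistency
&& dil >= 0
&& pad_l >= 0
&& pad_r + str > 0
/* e.g. src=32, ker=3, dil=0, pad_l=pad_r=1, str=2: (32-3+2)/2+1 == 16 */
&& (src - ker_range + pad_l + pad_r) / str + 1 == dst;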
}
if (!consistency) return invalid_arguments;
*conv_desc = cd;
return success;
}
}
}
status_t mkldnn_convolution_forward_desc_init(convolution_desc_t *conv_desc,
prop_kind_t prop_kind, alg_kind_t alg_kind,
const memory_desc_t *src_desc, const memory_desc_t *weights_desc,
const memory_desc_t *bias_desc, const memory_desc_t *dst_desc,
const dims_t strides, const dims_t padding_l, const dims_t padding_r,
padding_kind_t padding_kind) {
if (!one_of(prop_kind, forward_training, forward_inference))
return invalid_arguments;
return mkldnn::impl::conv_desc_init(conv_desc, prop_kind, alg_kind, src_desc,
weights_desc, bias_desc, dst_desc, strides, nullptr,
padding_l, padding_r, padding_kind);
}
status_t mkldnn_dilated_convolution_forward_desc_init(
convolution_desc_t *conv_desc, prop_kind_t prop_kind,
alg_kind_t alg_kind, const memory_desc_t *src_desc,
const memory_desc_t *weights_desc, const memory_desc_t *bias_desc,
const memory_desc_t *dst_desc, const dims_t strides,
const dims_t dilates, const dims_t padding_l,
const dims_t padding_r, padding_kind_t padding_kind) {
if (!one_of(prop_kind, forward_training, forward_inference))
return invalid_arguments;
return mkldnn::impl::conv_desc_init(conv_desc, prop_kind, alg_kind, src_desc,
weights_desc, bias_desc, dst_desc, strides, dilates,
padding_l, padding_r, padding_kind);
}
status_t mkldnn_convolution_backward_data_desc_init(
convolution_desc_t *conv_desc, alg_kind_t alg_kind,
const memory_desc_t *diff_src_desc, const memory_desc_t *weights_desc,
const memory_desc_t *diff_dst_desc, const dims_t strides,
const dims_t padding_l, const dims_t padding_r,
padding_kind_t padding_kind) {
return mkldnn::impl::conv_desc_init(conv_desc, backward_data, alg_kind, diff_src_desc,
weights_desc, nullptr, diff_dst_desc, strides, nullptr,
padding_l, padding_r, padding_kind);
}
status_t mkldnn_dilated_convolution_backward_data_desc_init(
convolution_desc_t *conv_desc, alg_kind_t alg_kind,
const memory_desc_t *diff_src_desc, const memory_desc_t *weights_desc,
const memory_desc_t *diff_dst_desc, const dims_t strides,
const dims_t dilates, const dims_t padding_l, const dims_t padding_r,
padding_kind_t padding_kind) {
return mkldnn::impl::conv_desc_init(conv_desc, backward_data, alg_kind, diff_src_desc,
weights_desc, nullptr, diff_dst_desc, strides, dilates,
padding_l, padding_r, padding_kind);
}
status_t mkldnn_convolution_backward_weights_desc_init(
convolution_desc_t *conv_desc, alg_kind_t alg_kind,
const memory_desc_t *src_desc, const memory_desc_t *diff_weights_desc,
const memory_desc_t *diff_bias_desc,
const memory_desc_t *diff_dst_desc, const dims_t strides,
const dims_t padding_l, const dims_t padding_r,
padding_kind_t padding_kind) {
return mkldnn::impl::conv_desc_init(conv_desc, backward_weights, alg_kind, src_desc,
diff_weights_desc, diff_bias_desc, diff_dst_desc, strides,
nullptr, padding_l, padding_r, padding_kind);
}
status_t mkldnn_dilated_convolution_backward_weights_desc_init(
convolution_desc_t *conv_desc, alg_kind_t alg_kind,
const memory_desc_t *src_desc, const memory_desc_t *diff_weights_desc,
const memory_desc_t *diff_bias_desc,
const memory_desc_t *diff_dst_desc, const dims_t strides,
const dims_t dilates, const dims_t padding_l, const dims_t padding_r,
padding_kind_t padding_kind) {
return mkldnn::impl::conv_desc_init(conv_desc, backward_weights, alg_kind, src_desc,
diff_weights_desc, diff_bias_desc, diff_dst_desc, strides,
dilates, padding_l, padding_r, padding_kind);
}
// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s

@ -1,56 +0,0 @@
/*******************************************************************************
* Copyright 2018 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/
#include "utils.hpp"
#include "convolution_pd.hpp"
namespace mkldnn {
namespace impl {
using namespace prop_kind;
memory_desc_t *conv_prop_invariant_src_d(convolution_desc_t *desc) {
return desc->prop_kind == backward_data
? &desc->diff_src_desc : &desc->src_desc;
}
memory_desc_t *conv_prop_invariant_wei_d(convolution_desc_t *desc) {
return desc->prop_kind == backward_weights
? &desc->diff_weights_desc : &desc->weights_desc;
}
memory_desc_t *conv_prop_invariant_bia_d(convolution_desc_t *desc) {
return desc->prop_kind == backward_weights
? &desc->diff_bias_desc : &desc->bias_desc;
}
memory_desc_t *conv_prop_invariant_dst_d(convolution_desc_t *desc) {
return utils::one_of(desc->prop_kind, forward_inference, forward_training)
? &desc->dst_desc : &desc->diff_dst_desc;
}
const memory_desc_t *conv_prop_invariant_src_d(const convolution_desc_t *desc)
{ return conv_prop_invariant_src_d(const_cast<convolution_desc_t *>(desc)); }
const memory_desc_t *conv_prop_invariant_wei_d(const convolution_desc_t *desc)
{ return conv_prop_invariant_wei_d(const_cast<convolution_desc_t *>(desc)); }
const memory_desc_t *conv_prop_invariant_bia_d(const convolution_desc_t *desc)
{ return conv_prop_invariant_bia_d(const_cast<convolution_desc_t *>(desc)); }
const memory_desc_t *conv_prop_invariant_dst_d(const convolution_desc_t *desc)
{ return conv_prop_invariant_dst_d(const_cast<convolution_desc_t *>(desc)); }
}
}

@ -1,348 +0,0 @@
/*******************************************************************************
* Copyright 2016-2018 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/
#ifndef CONVOLUTION_PD_HPP
#define CONVOLUTION_PD_HPP
#include "mkldnn.h"
#include "c_types_map.hpp"
#include "primitive_desc.hpp"
#include "utils.hpp"
namespace mkldnn {
namespace impl {
status_t conv_desc_init(convolution_desc_t *conv_desc,
prop_kind_t prop_kind, alg_kind_t alg_kind,
const memory_desc_t *src_desc, const memory_desc_t *weights_desc,
const memory_desc_t *bias_desc, const memory_desc_t *dst_desc,
const dims_t strides, const dims_t dilates,
const dims_t padding_l, const dims_t padding_r,
padding_kind_t padding_kind);
memory_desc_t *conv_prop_invariant_src_d(convolution_desc_t *desc);
memory_desc_t *conv_prop_invariant_wei_d(convolution_desc_t *desc);
memory_desc_t *conv_prop_invariant_bia_d(convolution_desc_t *desc);
memory_desc_t *conv_prop_invariant_dst_d(convolution_desc_t *desc);
const memory_desc_t *conv_prop_invariant_src_d(const convolution_desc_t *desc);
const memory_desc_t *conv_prop_invariant_wei_d(const convolution_desc_t *desc);
const memory_desc_t *conv_prop_invariant_bia_d(const convolution_desc_t *desc);
const memory_desc_t *conv_prop_invariant_dst_d(const convolution_desc_t *desc);
struct convolution_fwd_pd_t;
struct convolution_pd_t: public primitive_desc_t {
static constexpr auto base_pkind = primitive_kind::convolution;
convolution_pd_t(engine_t *engine,
const convolution_desc_t *adesc,
const primitive_attr_t *attr,
const convolution_fwd_pd_t *hint_fwd_pd)
: primitive_desc_t(engine, attr, base_pkind)
, desc_(*adesc)
, hint_fwd_pd_(hint_fwd_pd)
{}
const convolution_desc_t *desc() const { return &desc_; }
virtual const op_desc_t *op_desc() const override
{ return reinterpret_cast<const op_desc_t *>(this->desc()); }
virtual void init_info() override { impl::init_info(this, this->info_); }
virtual status_t query(query_t what, int idx, void *result) const override {
switch (what) {
case pkind_traits<base_pkind>::query_d:
*(const convolution_desc_t**)result = desc(); break;
default: return primitive_desc_t::query(what, idx, result);
}
return status::success;
}
/* common conv aux functions */
dim_t MB() const { return _src_md()->dims[0]; }
dim_t IC() const { return _src_md()->dims[1]; }
dim_t OC() const { return _dst_md()->dims[1]; }
dim_t G() const { return with_groups() ? _wei_md()->dims[0] : 1; }
dim_t ID() const { return ndims() >= 5 ? _src_md()->dims[ndims() - 3] : 1; }
dim_t IH() const { return ndims() >= 4 ? _src_md()->dims[ndims() - 2] : 1; }
dim_t IW() const { return _src_md()->dims[ndims() - 1]; }
dim_t OD() const { return ndims() >= 5 ? _dst_md()->dims[ndims() - 3] : 1; }
dim_t OH() const { return ndims() >= 4 ? _dst_md()->dims[ndims() - 2] : 1; }
dim_t OW() const { return _dst_md()->dims[ndims() - 1]; }
dim_t KD() const { return ndims() >= 5 ? _wei_md()->dims[ndims() + with_groups() - 3] : 1; }
dim_t KH() const { return ndims() >= 4 ? _wei_md()->dims[ndims() + with_groups() - 2] : 1; }
dim_t KW() const { return _wei_md()->dims[ndims() + with_groups() - 1]; }
dim_t KSD() const { return ndims() >= 5 ? desc_.strides[ndims() - 5] : 1; }
dim_t KSH() const { return ndims() >= 4 ? desc_.strides[ndims() - 4] : 1; }
dim_t KSW() const { return desc_.strides[ndims() - 3]; }
dim_t KDD() const { return ndims() >= 5 ? desc_.dilates[ndims() - 5] : 0; }
dim_t KDH() const { return ndims() >= 4 ? desc_.dilates[ndims() - 4] : 0; }
dim_t KDW() const { return desc_.dilates[ndims() - 3]; }
dim_t padFront() const { return ndims() >= 5 ? desc_.padding[0][ndims() - 5] : 0; }
dim_t padBack() const { return ndims() >= 5 ? desc_.padding[1][ndims() - 5] : 0; }
dim_t padT() const { return ndims() >= 4 ? desc_.padding[0][ndims() - 4] : 0; }
dim_t padB() const { return ndims() >= 4 ? desc_.padding[1][ndims() - 4] : 0; }
dim_t padL() const { return desc_.padding[0][ndims() - 3]; }
dim_t padR() const { return desc_.padding[1][ndims() - 3]; }
int ndims() const { return _src_md()->ndims; }
bool with_bias() const { return !memory_desc_wrapper(*_bia_md()).is_zero(); }
bool with_groups() const { return _wei_md()->ndims == ndims() + 1; }
bool is_fwd() const {
return utils::one_of(desc_.prop_kind, prop_kind::forward_training,
prop_kind::forward_inference);
}
bool has_zero_dim_memory() const {
const auto s_d = memory_desc_wrapper(*_src_md());
const auto d_d = memory_desc_wrapper(*_dst_md());
return s_d.has_zero_dim() || d_d.has_zero_dim();
}
protected:
convolution_desc_t desc_;
const convolution_fwd_pd_t *hint_fwd_pd_;
bool set_default_formats_common_template(
memory_desc_t &src_md, format_tag_t src_tag,
memory_desc_t &wei_md, format_tag_t wei_tag,
memory_desc_t &dst_md, format_tag_t dst_tag,
memory_desc_t &bia_md) {
using namespace format_tag;
# define IS_OK(f) \
do { if ((f) != status::success) return false; } while(0)
if (src_md.format_kind == format_kind::any
&& !utils::one_of(src_tag, any, undef))
IS_OK(memory_desc_init_by_tag(src_md, src_tag));
if (dst_md.format_kind == format_kind::any
&& !utils::one_of(dst_tag, any, undef))
IS_OK(memory_desc_init_by_tag(dst_md, dst_tag));
if (wei_md.format_kind == format_kind::any
&& !utils::one_of(wei_tag, any, undef))
IS_OK(memory_desc_init_by_tag(wei_md, wei_tag));
if (with_bias() && bia_md.format_kind == format_kind::any)
IS_OK(memory_desc_init_by_tag(bia_md, x));
# undef IS_OK
return true;
}
bool set_default_alg_kind(alg_kind_t alg_kind) {
assert(utils::one_of(alg_kind, alg_kind::convolution_direct,
alg_kind::convolution_winograd));
if (desc_.alg_kind == alg_kind::convolution_auto)
desc_.alg_kind = alg_kind;
return desc_.alg_kind == alg_kind;
}
bool expect_data_types(data_type_t src_dt, data_type_t wei_dt,
data_type_t bia_dt, data_type_t dst_dt, data_type_t acc_dt) const {
bool ok = true
&& (src_dt == data_type::undef || _src_md()->data_type == src_dt)
&& (wei_dt == data_type::undef || _wei_md()->data_type == wei_dt)
&& (dst_dt == data_type::undef || _dst_md()->data_type == dst_dt)
&& (acc_dt == data_type::undef || desc_.accum_data_type == acc_dt);
if (with_bias() && bia_dt != data_type::undef)
ok = ok && _bia_md()->data_type == bia_dt;
return ok;
}
private:
const memory_desc_t *_src_md() const { return conv_prop_invariant_src_d(&desc_); }
const memory_desc_t *_wei_md() const { return conv_prop_invariant_wei_d(&desc_); }
const memory_desc_t *_bia_md() const { return conv_prop_invariant_bia_d(&desc_); }
const memory_desc_t *_dst_md() const { return conv_prop_invariant_dst_d(&desc_); }
};
struct convolution_fwd_pd_t: public convolution_pd_t {
typedef convolution_fwd_pd_t base_class;
typedef convolution_fwd_pd_t hint_class;
convolution_fwd_pd_t(engine_t *engine,
const convolution_desc_t *adesc,
const primitive_attr_t *attr,
const convolution_fwd_pd_t *hint_fwd_pd)
: convolution_pd_t(engine, adesc, attr, hint_fwd_pd)
, src_md_(desc_.src_desc)
, weights_md_(desc_.weights_desc)
, bias_md_(desc_.bias_desc)
, dst_md_(desc_.dst_desc)
{}
virtual arg_usage_t arg_usage(primitive_arg_index_t arg) const override {
if (utils::one_of(arg, MKLDNN_ARG_SRC, MKLDNN_ARG_WEIGHTS))
return arg_usage_t::input;
if (arg == MKLDNN_ARG_BIAS && with_bias())
return arg_usage_t::input;
if (arg == MKLDNN_ARG_DST)
return arg_usage_t::output;
return primitive_desc_t::arg_usage(arg);
}
virtual const memory_desc_t *src_md(int index = 0) const override
{ return index == 0 ? &src_md_ : nullptr; }
virtual const memory_desc_t *dst_md(int index = 0) const override
{ return index == 0 ? &dst_md_ : nullptr; }
virtual const memory_desc_t *weights_md(int index = 0) const override {
if (index == 0) return &weights_md_;
if (index == 1 && with_bias()) return &bias_md_;
return nullptr;
}
virtual int n_inputs() const override { return 2 + with_bias(); }
virtual int n_outputs() const override { return 1; }
protected:
memory_desc_t src_md_;
memory_desc_t weights_md_;
memory_desc_t bias_md_;
memory_desc_t dst_md_;
bool set_default_formats_common(format_tag_t src_tag,
format_tag_t wei_tag, format_tag_t dst_tag) {
return set_default_formats_common_template(src_md_, src_tag,
weights_md_, wei_tag, dst_md_, dst_tag, bias_md_);
}
};
struct convolution_bwd_data_pd_t: public convolution_pd_t {
typedef convolution_bwd_data_pd_t base_class;
typedef convolution_fwd_pd_t hint_class;
convolution_bwd_data_pd_t(engine_t *engine,
const convolution_desc_t *adesc,
const primitive_attr_t *attr,
const convolution_fwd_pd_t *hint_fwd_pd)
: convolution_pd_t(engine, adesc, attr, hint_fwd_pd)
, diff_src_md_(desc_.diff_src_desc)
, weights_md_(desc_.weights_desc)
, bias_md_(desc_.bias_desc)
, diff_dst_md_(desc_.diff_dst_desc)
{}
virtual arg_usage_t arg_usage(primitive_arg_index_t arg) const override {
if (utils::one_of(arg, MKLDNN_ARG_WEIGHTS, MKLDNN_ARG_DIFF_DST))
return arg_usage_t::input;
if (arg == MKLDNN_ARG_DIFF_SRC)
return arg_usage_t::output;
return primitive_desc_t::arg_usage(arg);
}
virtual const memory_desc_t *diff_src_md(int index = 0) const override
{ return index == 0 ? &diff_src_md_ : nullptr; }
virtual const memory_desc_t *diff_dst_md(int index = 0) const override
{ return index == 0 ? &diff_dst_md_ : nullptr; }
virtual const memory_desc_t *weights_md(int index = 0) const override {
if (index == 0) return &weights_md_;
if (index == 1 && with_bias()) return &bias_md_;
return nullptr;
}
virtual int n_inputs() const override { return 2 + with_bias(); }
virtual int n_outputs() const override { return 1; }
virtual bool support_bias() const { return false; }
protected:
memory_desc_t diff_src_md_;
memory_desc_t weights_md_;
memory_desc_t bias_md_;
memory_desc_t diff_dst_md_;
bool set_default_formats_common(format_tag_t diff_src_tag,
format_tag_t wei_tag, format_tag_t diff_dst_tag) {
return set_default_formats_common_template(diff_src_md_, diff_src_tag,
weights_md_, wei_tag, diff_dst_md_, diff_dst_tag, bias_md_);
}
};
struct convolution_bwd_weights_pd_t: public convolution_pd_t {
typedef convolution_bwd_weights_pd_t base_class;
typedef convolution_fwd_pd_t hint_class;
convolution_bwd_weights_pd_t(engine_t *engine,
const convolution_desc_t *adesc,
const primitive_attr_t *attr,
const convolution_fwd_pd_t *hint_fwd_pd)
: convolution_pd_t(engine, adesc, attr, hint_fwd_pd)
, src_md_(desc_.src_desc)
, diff_weights_md_(desc_.diff_weights_desc)
, diff_bias_md_(desc_.diff_bias_desc)
, diff_dst_md_(desc_.diff_dst_desc)
{}
virtual arg_usage_t arg_usage(primitive_arg_index_t arg) const override {
if (utils::one_of(arg, MKLDNN_ARG_SRC, MKLDNN_ARG_DIFF_DST))
return arg_usage_t::input;
if (arg == MKLDNN_ARG_DIFF_WEIGHTS)
return arg_usage_t::output;
if (arg == MKLDNN_ARG_DIFF_BIAS && with_bias())
return arg_usage_t::output;
return primitive_desc_t::arg_usage(arg);
}
virtual const memory_desc_t *src_md(int index = 0) const override
{ return index == 0 ? &src_md_ : nullptr; }
virtual const memory_desc_t *diff_dst_md(int index = 0) const override
{ return index == 0 ? &diff_dst_md_ : nullptr; }
virtual const memory_desc_t *diff_weights_md(int index = 0) const override {
if (index == 0) return &diff_weights_md_;
if (index == 1 && with_bias()) return &diff_bias_md_;
return nullptr;
}
virtual int n_inputs() const override { return 2; }
virtual int n_outputs() const override { return 1 + with_bias(); }
protected:
memory_desc_t src_md_;
memory_desc_t diff_weights_md_;
memory_desc_t diff_bias_md_;
memory_desc_t diff_dst_md_;
bool set_default_formats_common(format_tag_t src_tag,
format_tag_t diff_wei_tag, format_tag_t diff_dst_tag) {
return set_default_formats_common_template(src_md_, src_tag,
diff_weights_md_, diff_wei_tag, diff_dst_md_, diff_dst_tag,
diff_bias_md_);
}
};
}
}
#endif
// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s

@ -1,188 +0,0 @@
/*******************************************************************************
* Copyright 2018 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/
#include "mkldnn.h"
#include <assert.h>
#include "c_types_map.hpp"
#include "type_helpers.hpp"
#include "utils.hpp"
using namespace mkldnn::impl;
using namespace mkldnn::impl::utils;
using namespace mkldnn::impl::status;
using namespace mkldnn::impl::prop_kind;
using namespace mkldnn::impl::alg_kind;
using namespace mkldnn::impl::types;
namespace {
status_t deconv_desc_init(deconvolution_desc_t *deconv_desc,
prop_kind_t prop_kind, alg_kind_t alg_kind,
const memory_desc_t *src_desc, const memory_desc_t *weights_desc,
const memory_desc_t *bias_desc, const memory_desc_t *dst_desc,
const dims_t strides, const dims_t dilates, const dims_t padding_l,
const dims_t padding_r, padding_kind_t padding_kind) {
bool args_ok = true
&& !any_null(deconv_desc, src_desc, weights_desc, dst_desc, strides,
padding_l)
&& one_of(alg_kind, deconvolution_direct, deconvolution_winograd)
&& one_of(padding_kind, padding_kind::padding_zero);
if (!args_ok)
return invalid_arguments;
if (padding_r == nullptr)
padding_r = padding_l;
auto dd = deconvolution_desc_t();
dd.primitive_kind = primitive_kind::deconvolution;
dd.prop_kind = prop_kind;
dd.alg_kind = alg_kind;
dd.diff_src_desc = dd.src_desc = zero_md();
dd.diff_dst_desc = dd.dst_desc = zero_md();
dd.diff_weights_desc = dd.weights_desc = zero_md();
dd.diff_bias_desc = dd.bias_desc = zero_md();
const bool is_fwd = one_of(prop_kind, forward_training, forward_inference);
const bool with_bias
= bias_desc && bias_desc->format_kind != format_kind::undef;
const bool with_groups = weights_desc->ndims == src_desc->ndims + 1;
(prop_kind == backward_data ? dd.diff_src_desc : dd.src_desc) = *src_desc;
(is_fwd ? dd.dst_desc : dd.diff_dst_desc) = *dst_desc;
(prop_kind == backward_weights ? dd.diff_weights_desc : dd.weights_desc)
= *weights_desc;
if (with_bias)
(prop_kind == backward_weights ? dd.diff_bias_desc : dd.bias_desc)
= *bias_desc;
int sp_dims = src_desc->ndims - 2;
utils::array_copy(dd.strides, strides, sp_dims);
utils::array_copy(dd.padding[0], padding_l, sp_dims);
utils::array_copy(dd.padding[1], padding_r, sp_dims);
if (dilates)
utils::array_copy(dd.dilates, dilates, sp_dims);
else
utils::array_set(dd.dilates, 0, sp_dims);
dd.padding_kind = padding_kind;
dd.accum_data_type = types::default_accum_data_type(src_desc->data_type,
weights_desc->data_type, dst_desc->data_type, prop_kind);
const int g = with_groups ? weights_desc->dims[0] : 1;
bool consistency = true
&& src_desc->ndims == dst_desc->ndims
&& utils::one_of(src_desc->ndims, 3, 4, 5)
&& utils::one_of(weights_desc->ndims, src_desc->ndims,
src_desc->ndims + 1)
&& (with_bias ? bias_desc->ndims == 1 : true)
&& (with_bias ? bias_desc->dims[0] == dst_desc->dims[1] : true)
&& src_desc->dims[0] == dst_desc->dims[0]
&& src_desc->dims[1] == g * weights_desc->dims[with_groups + 1]
&& dst_desc->dims[1] == g * weights_desc->dims[with_groups + 0];
for (int i = 2; i < src_desc->ndims; ++i) {
int src = src_desc->dims[i];
int ker = weights_desc->dims[with_groups + i];
int dil = dd.dilates[i - 2];
int pad = padding_l[i - 2] + padding_r[i - 2];
int str = strides[i - 2];
int dst = dst_desc->dims[i];
int ker_range = 1 + (ker - 1) * (dil + 1);
consistency
= consistency && (dst - ker_range + pad) / str + 1 == src;
}
if (!consistency)
return invalid_arguments;
*deconv_desc = dd;
return success;
}
}
status_t mkldnn_deconvolution_forward_desc_init(
deconvolution_desc_t *deconv_desc, prop_kind_t prop_kind,
alg_kind_t alg_kind, const memory_desc_t *src_desc,
const memory_desc_t *weights_desc, const memory_desc_t *bias_desc,
const memory_desc_t *dst_desc, const dims_t strides,
const dims_t padding_l, const dims_t padding_r,
padding_kind_t padding_kind) {
if (!one_of(prop_kind, forward_training, forward_inference))
return invalid_arguments;
return deconv_desc_init(deconv_desc, prop_kind, alg_kind, src_desc,
weights_desc, bias_desc, dst_desc, strides, nullptr, padding_l,
padding_r, padding_kind);
}
status_t mkldnn_dilated_deconvolution_forward_desc_init(
deconvolution_desc_t *deconv_desc, prop_kind_t prop_kind,
alg_kind_t alg_kind, const memory_desc_t *src_desc,
const memory_desc_t *weights_desc, const memory_desc_t *bias_desc,
const memory_desc_t *dst_desc, const dims_t strides,
const dims_t dilates, const dims_t padding_l, const dims_t padding_r,
padding_kind_t padding_kind) {
if (!one_of(prop_kind, forward_training, forward_inference))
return invalid_arguments;
return deconv_desc_init(deconv_desc, prop_kind, alg_kind, src_desc,
weights_desc, bias_desc, dst_desc, strides, dilates, padding_l,
padding_r, padding_kind);
}
status_t mkldnn_deconvolution_backward_data_desc_init(
deconvolution_desc_t *deconv_desc, alg_kind_t alg_kind,
const memory_desc_t *diff_src_desc, const memory_desc_t *weights_desc,
const memory_desc_t *diff_dst_desc, const dims_t strides,
const dims_t padding_l, const dims_t padding_r,
padding_kind_t padding_kind) {
return deconv_desc_init(deconv_desc, backward_data, alg_kind, diff_src_desc,
weights_desc, nullptr, diff_dst_desc, strides, nullptr, padding_l,
padding_r, padding_kind);
}
status_t mkldnn_dilated_deconvolution_backward_data_desc_init(
deconvolution_desc_t *deconv_desc, alg_kind_t alg_kind,
const memory_desc_t *diff_src_desc, const memory_desc_t *weights_desc,
const memory_desc_t *diff_dst_desc, const dims_t strides,
const dims_t dilates, const dims_t padding_l, const dims_t padding_r,
padding_kind_t padding_kind) {
return deconv_desc_init(deconv_desc, backward_data, alg_kind, diff_src_desc,
weights_desc, nullptr, diff_dst_desc, strides, dilates, padding_l,
padding_r, padding_kind);
}
status_t mkldnn_deconvolution_backward_weights_desc_init(
deconvolution_desc_t *deconv_desc, alg_kind_t alg_kind,
const memory_desc_t *src_desc, const memory_desc_t *diff_weights_desc,
const memory_desc_t *diff_bias_desc, const memory_desc_t *diff_dst_desc,
const dims_t strides, const dims_t padding_l, const dims_t padding_r,
padding_kind_t padding_kind) {
return deconv_desc_init(deconv_desc, backward_weights, alg_kind, src_desc,
diff_weights_desc, diff_bias_desc, diff_dst_desc, strides, nullptr,
padding_l, padding_r, padding_kind);
}
status_t mkldnn_dilated_deconvolution_backward_weights_desc_init(
deconvolution_desc_t *deconv_desc, alg_kind_t alg_kind,
const memory_desc_t *src_desc, const memory_desc_t *diff_weights_desc,
const memory_desc_t *diff_bias_desc, const memory_desc_t *diff_dst_desc,
const dims_t strides, const dims_t dilates, const dims_t padding_l,
const dims_t padding_r, padding_kind_t padding_kind) {
return deconv_desc_init(deconv_desc, backward_weights, alg_kind, src_desc,
diff_weights_desc, diff_bias_desc, diff_dst_desc, strides, dilates,
padding_l, padding_r, padding_kind);
}
// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s
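The per-dimension shape check in deconv_desc_init above can be tried in isolation. The following is a minimal sketch, not part of MKL-DNN: the helper name deconv_dim_consistent and its flat integer parameters are hypothetical, and only the arithmetic is taken from the loop in the removed file. It assumes the same convention as the source, where a dilation of 0 means "no dilation".

```cpp
#include <cassert>

// Hypothetical helper (not an MKL-DNN function): mirrors the check that
// deconv_desc_init applies to a single spatial axis.
// dil uses the convention from the source: 0 means no dilation.
inline bool deconv_dim_consistent(int src, int dst, int ker, int dil,
        int pad_l, int pad_r, int str) {
    // Effective kernel extent once the dilation gaps are included.
    const int ker_range = 1 + (ker - 1) * (dil + 1);
    // The source sums left and right padding before the comparison.
    const int pad = pad_l + pad_r;
    // The formula is the convolution output-size rule applied from dst
    // back to src, which is why dst appears on the left-hand side.
    return (dst - ker_range + pad) / str + 1 == src;
}
```

For example, with src = 4, ker = 3, stride 2, padding 1 on each side and no dilation, dst = 7 passes because (7 - 3 + 2) / 2 + 1 = 4, while dst = 9 fails because the same expression gives 5. Because the division is integer division, more than one dst value can satisfy the check for a given src when the stride is larger than 1.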

@ -1,293 +0,0 @@
/*******************************************************************************
* Copyright 2018 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/
#ifndef DECONVOLUTION_PD_HPP
#define DECONVOLUTION_PD_HPP
#include "mkldnn.h"
#include "c_types_map.hpp"
#include "convolution_pd.hpp"
#include "primitive_desc.hpp"
#include "utils.hpp"
namespace mkldnn {
namespace impl {
struct deconvolution_fwd_pd_t;
struct deconvolution_pd_t: public primitive_desc_t {
static constexpr auto base_pkind = primitive_kind::deconvolution;
deconvolution_pd_t(engine_t *engine,
const deconvolution_desc_t *adesc,
const primitive_attr_t *attr,
const deconvolution_fwd_pd_t *hint_fwd_pd)
: primitive_desc_t(engine, attr, base_pkind)
, desc_(*adesc)
, hint_fwd_pd_(hint_fwd_pd)
{}
const deconvolution_desc_t *desc() const { return &desc_; }
virtual const op_desc_t *op_desc() const override
{ return reinterpret_cast<const op_desc_t *>(this->desc()); }
virtual void init_info() override { impl::init_info(this, this->info_); }
virtual status_t query(query_t what, int idx, void *result) const override {
switch (what) {
case pkind_traits<base_pkind>::query_d:
*(const deconvolution_desc_t **)result = desc();
break;
default: return primitive_desc_t::query(what, idx, result);
}
return status::success;
}
/* common deconv aux functions (note that conv_desc_t == deconv_desc_t) */
dim_t MB() const { return conv_prop_invariant_src_d(&desc_)->dims[0]; }
dim_t IC() const { return conv_prop_invariant_src_d(&desc_)->dims[1]; }
dim_t OC() const { return conv_prop_invariant_dst_d(&desc_)->dims[1]; }
dim_t G() const
{ return with_groups() ? conv_prop_invariant_wei_d(&desc_)->dims[0] : 1; }
dim_t ID() const {
return ndims() >= 5
? conv_prop_invariant_src_d(&desc_)->dims[ndims() - 3] : 1;
}
dim_t IH() const {
return ndims() >= 4
? conv_prop_invariant_src_d(&desc_)->dims[ndims() - 2] : 1;
}
dim_t IW() const {
return conv_prop_invariant_src_d(&desc_)->dims[ndims() - 1];
}
dim_t OD() const {
return ndims() >= 5
? conv_prop_invariant_dst_d(&desc_)->dims[ndims() - 3] : 1;
}
dim_t OH() const {
return ndims() >= 4
? conv_prop_invariant_dst_d(&desc_)->dims[ndims() - 2] : 1;
}
dim_t OW() const {
return conv_prop_invariant_dst_d(&desc_)->dims[ndims() - 1];
}
dim_t KD() const {
const int w_ndims = ndims() + with_groups();
return ndims() >= 5
? conv_prop_invariant_wei_d(&desc_)->dims[w_ndims - 3] : 1;
}
dim_t KH() const {
const int w_ndims = ndims() + with_groups();
return ndims() >= 4
? conv_prop_invariant_wei_d(&desc_)->dims[w_ndims - 2] : 1;
}
dim_t KW() const {
const int w_ndims = ndims() + with_groups();
return conv_prop_invariant_wei_d(&desc_)->dims[w_ndims - 1];
}
dim_t KSD() const { return ndims() >= 5 ? desc_.strides[ndims() - 5] : 1; }
dim_t KSH() const { return ndims() >= 4 ? desc_.strides[ndims() - 4] : 1; }
dim_t KSW() const { return desc_.strides[ndims() - 3]; }
dim_t KDD() const { return ndims() >= 5 ? desc_.dilates[ndims() - 5] : 0; }
dim_t KDH() const { return ndims() >= 4 ? desc_.dilates[ndims() - 4] : 1; }
dim_t KDW() const { return desc_.dilates[ndims() - 3]; }
dim_t padFront() const
{ return ndims() >= 5 ? desc_.padding[0][ndims() - 5] : 0; }
dim_t padBack() const
{ return ndims() >= 5 ? desc_.padding[1][ndims() - 5] : 0; }
dim_t padT() const
{ return ndims() >= 4 ? desc_.padding[0][ndims() - 4] : 0; }
dim_t padB() const
{ return ndims() >= 4 ? desc_.padding[1][ndims() - 4] : 0; }
dim_t padL() const { return desc_.padding[0][ndims() - 3]; }
dim_t padR() const { return desc_.padding[1][ndims() - 3]; }
bool with_bias() const {
return
!memory_desc_wrapper(*conv_prop_invariant_bia_d(&desc_)).is_zero();
}
bool with_groups() const
{ return conv_prop_invariant_wei_d(&desc_)->ndims == ndims() + 1; }
int ndims() const { return conv_prop_invariant_src_d(&desc_)->ndims; }
bool is_fwd() const {
return utils::one_of(desc_.prop_kind, prop_kind::forward_training,
prop_kind::forward_inference);
}
bool has_zero_dim_memory() const {
const auto s_d = memory_desc_wrapper(*conv_prop_invariant_src_d(&desc_));
const auto d_d = memory_desc_wrapper(*conv_prop_invariant_dst_d(&desc_));
return s_d.has_zero_dim() || d_d.has_zero_dim();
}
protected:
deconvolution_desc_t desc_;
const deconvolution_fwd_pd_t *hint_fwd_pd_;
};
struct deconvolution_fwd_pd_t: public deconvolution_pd_t {
typedef deconvolution_fwd_pd_t base_class;
typedef deconvolution_fwd_pd_t hint_class;
deconvolution_fwd_pd_t(engine_t *engine,
const deconvolution_desc_t *adesc,
const primitive_attr_t *attr,
const deconvolution_fwd_pd_t *hint_fwd_pd)
: deconvolution_pd_t(engine, adesc, attr, hint_fwd_pd)
, src_md_(desc_.src_desc)
, weights_md_(desc_.weights_desc)
, bias_md_(desc_.bias_desc)
, dst_md_(desc_.dst_desc)
{}
virtual arg_usage_t arg_usage(primitive_arg_index_t arg) const override {
if (utils::one_of(arg, MKLDNN_ARG_SRC, MKLDNN_ARG_WEIGHTS))
return arg_usage_t::input;
if (arg == MKLDNN_ARG_BIAS && with_bias())
return arg_usage_t::input;
if (arg == MKLDNN_ARG_DST)
return arg_usage_t::output;
return primitive_desc_t::arg_usage(arg);
}
virtual const memory_desc_t *src_md(int index = 0) const override
{ return index == 0 ? &src_md_ : nullptr; }
virtual const memory_desc_t *dst_md(int index = 0) const override
{ return index == 0 ? &dst_md_ : nullptr; }
virtual const memory_desc_t *weights_md(int index = 0) const override {
if (index == 0) return &weights_md_;
if (index == 1 && with_bias()) return &bias_md_;
return nullptr;
}
virtual int n_inputs() const override { return 2 + with_bias(); }
virtual int n_outputs() const override { return 1; }
protected:
memory_desc_t src_md_;
memory_desc_t weights_md_;
memory_desc_t bias_md_;
memory_desc_t dst_md_;
};
struct deconvolution_bwd_data_pd_t: public deconvolution_pd_t {
typedef deconvolution_bwd_data_pd_t base_class;
typedef deconvolution_fwd_pd_t hint_class;
deconvolution_bwd_data_pd_t(engine_t *engine,
const deconvolution_desc_t *adesc,
const primitive_attr_t *attr,
const deconvolution_fwd_pd_t *hint_fwd_pd)
: deconvolution_pd_t(engine, adesc, attr, hint_fwd_pd)
, diff_src_md_(desc_.diff_src_desc)
, weights_md_(desc_.weights_desc)
, diff_dst_md_(desc_.diff_dst_desc)
{}
virtual arg_usage_t arg_usage(primitive_arg_index_t arg) const override {
if (utils::one_of(arg, MKLDNN_ARG_WEIGHTS, MKLDNN_ARG_DIFF_DST))
return arg_usage_t::input;
if (arg == MKLDNN_ARG_DIFF_SRC)
return arg_usage_t::output;
return primitive_desc_t::arg_usage(arg);
}
virtual const memory_desc_t *diff_src_md(int index = 0) const override
{ return index == 0 ? &diff_src_md_ : nullptr; }
virtual const memory_desc_t *diff_dst_md(int index = 0) const override
{ return index == 0 ? &diff_dst_md_ : nullptr; }
virtual const memory_desc_t *weights_md(int index = 0) const override
{ return index == 0 ? &weights_md_ : nullptr; }
virtual int n_inputs() const override { return 2; }
virtual int n_outputs() const override { return 1; }
protected:
memory_desc_t diff_src_md_;
memory_desc_t weights_md_;
memory_desc_t diff_dst_md_;
};
struct deconvolution_bwd_weights_pd_t: public deconvolution_pd_t {
typedef deconvolution_bwd_weights_pd_t base_class;
typedef deconvolution_fwd_pd_t hint_class;
deconvolution_bwd_weights_pd_t(engine_t *engine,
const deconvolution_desc_t *adesc,
const primitive_attr_t *attr,
const deconvolution_fwd_pd_t *hint_fwd_pd)
: deconvolution_pd_t(engine, adesc, attr, hint_fwd_pd)
, src_md_(desc_.src_desc)
, diff_weights_md_(desc_.diff_weights_desc)
, diff_bias_md_(desc_.diff_bias_desc)
, diff_dst_md_(desc_.diff_dst_desc)
{}
virtual arg_usage_t arg_usage(primitive_arg_index_t arg) const override {
if (utils::one_of(arg, MKLDNN_ARG_SRC, MKLDNN_ARG_DIFF_DST))
return arg_usage_t::input;
if (arg == MKLDNN_ARG_DIFF_WEIGHTS)
return arg_usage_t::output;
if (arg == MKLDNN_ARG_DIFF_BIAS && with_bias())
return arg_usage_t::output;
return primitive_desc_t::arg_usage(arg);
}
virtual const memory_desc_t *src_md(int index = 0) const override
{ return index == 0 ? &src_md_ : nullptr; }
virtual const memory_desc_t *diff_dst_md(int index = 0) const override
{ return index == 0 ? &diff_dst_md_ : nullptr; }
virtual const memory_desc_t *diff_weights_md(int index = 0) const override {
if (index == 0) return &diff_weights_md_;
if (index == 1 && with_bias()) return &diff_bias_md_;
return nullptr;
}
virtual int n_inputs() const override { return 2; }
virtual int n_outputs() const override { return 1 + with_bias(); }
protected:
memory_desc_t src_md_;
memory_desc_t diff_weights_md_;
memory_desc_t diff_bias_md_;
memory_desc_t diff_dst_md_;
};
}
}
#endif
// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s

@ -1,84 +0,0 @@
/*******************************************************************************
* Copyright 2016-2018 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/
#include <assert.h>
#include "mkldnn.h"
#include "c_types_map.hpp"
#include "type_helpers.hpp"
#include "utils.hpp"
using namespace mkldnn::impl;
using namespace mkldnn::impl::utils;
using namespace mkldnn::impl::status;
using namespace mkldnn::impl::prop_kind;
using namespace mkldnn::impl::alg_kind;
using namespace mkldnn::impl::types;
namespace {
status_t eltwise_desc_init(eltwise_desc_t *eltwise_desc, prop_kind_t prop_kind,
alg_kind_t alg_kind, const memory_desc_t *data_desc,
const memory_desc_t *diff_data_desc, float alpha, float beta) {
bool args_ok = true
&& !any_null(eltwise_desc, data_desc)
&& one_of(prop_kind, forward_training, forward_inference,
backward_data)
&& one_of(alg_kind, eltwise_relu, eltwise_tanh, eltwise_elu,
eltwise_square, eltwise_abs, eltwise_sqrt, eltwise_linear,
eltwise_bounded_relu, eltwise_soft_relu, eltwise_logistic)
&& IMPLICATION(prop_kind == backward_data, diff_data_desc != nullptr);
if (!args_ok) return invalid_arguments;
auto ed = eltwise_desc_t();
ed.primitive_kind = primitive_kind::eltwise;
ed.prop_kind = prop_kind;
ed.alg_kind = alg_kind;
ed.data_desc = *data_desc;
ed.diff_data_desc =
(ed.prop_kind == backward_data) ? *diff_data_desc : zero_md();
ed.alpha = alpha;
ed.beta = beta;
bool consistency = true
&& IMPLICATION(ed.prop_kind == backward_data,
array_cmp(ed.diff_data_desc.dims, ed.data_desc.dims,
ed.diff_data_desc.ndims));
if (!consistency) return invalid_arguments;
*eltwise_desc = ed;
return success;
}
}
status_t mkldnn_eltwise_forward_desc_init(eltwise_desc_t *eltwise_desc,
prop_kind_t prop_kind, alg_kind_t alg_kind,
const memory_desc_t *data_desc, float alpha, float beta) {
if (!one_of(prop_kind, forward_training, forward_inference))
return invalid_arguments;
return eltwise_desc_init(eltwise_desc, prop_kind, alg_kind, data_desc,
nullptr, alpha, beta);
}
status_t mkldnn_eltwise_backward_desc_init(eltwise_desc_t *eltwise_desc,
alg_kind_t alg_kind, const memory_desc_t *diff_data_desc,
const memory_desc_t *data_desc, float alpha, float beta) {
return eltwise_desc_init(eltwise_desc, backward_data, alg_kind, data_desc,
diff_data_desc, alpha, beta);
}
// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s

@ -1,161 +0,0 @@
/*******************************************************************************
* Copyright 2016-2018 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/
#ifndef ELTWISE_PD_HPP
#define ELTWISE_PD_HPP
#include "mkldnn.h"
#include "c_types_map.hpp"
#include "primitive_desc.hpp"
namespace mkldnn {
namespace impl {
struct eltwise_fwd_pd_t;
struct eltwise_pd_t: public primitive_desc_t {
static constexpr auto base_pkind = primitive_kind::eltwise;
eltwise_pd_t(mkldnn::impl::engine_t *engine,
const eltwise_desc_t *adesc,
const primitive_attr_t *attr,
const eltwise_fwd_pd_t *hint_fwd_pd)
: primitive_desc_t(engine, attr, base_pkind)
, desc_(*adesc)
, hint_fwd_pd_(hint_fwd_pd)
, data_md_(desc_.data_desc)
{}
const eltwise_desc_t *desc() const { return &desc_; }
virtual const op_desc_t *op_desc() const override
{ return reinterpret_cast<const op_desc_t *>(this->desc()); }
virtual void init_info() override { impl::init_info(this, this->info_); }
virtual status_t query(query_t what, int idx, void *result) const override {
switch (what) {
case query::eltwise_d:
*(const eltwise_desc_t**)result = desc(); break;
default: return primitive_desc_t::query(what, idx, result);
}
return status::success;
}
/* common eltwise aux functions */
dim_t MB() const { return data_desc().dims[0]; }
dim_t C() const { return data_desc().dims[1]; }
dim_t D() const { return ndims() >= 5 ? data_desc().dims[ndims() - 3] : 1; }
dim_t H() const { return ndims() >= 4 ? data_desc().dims[ndims() - 2] : 1; }
dim_t W() const { return ndims() >= 3 ? data_desc().dims[ndims() - 1] : 1; }
int ndims() const { return data_desc().ndims; }
bool is_fwd() const {
return utils::one_of(desc_.prop_kind, prop_kind::forward_training,
prop_kind::forward_inference);
}
bool has_zero_dim_memory() const
{ return memory_desc_wrapper(desc_.data_desc).has_zero_dim(); }
protected:
eltwise_desc_t desc_;
const eltwise_fwd_pd_t *hint_fwd_pd_;
memory_desc_t data_md_;
private:
const memory_desc_t &data_desc() const { return desc_.data_desc; }
};
struct eltwise_fwd_pd_t: public eltwise_pd_t {
typedef eltwise_fwd_pd_t base_class;
typedef eltwise_fwd_pd_t hint_class;
eltwise_fwd_pd_t(mkldnn::impl::engine_t *engine,
const eltwise_desc_t *adesc,
const primitive_attr_t *attr,
const eltwise_fwd_pd_t *hint_fwd_pd)
: eltwise_pd_t(engine, adesc, attr, hint_fwd_pd)
{}
virtual arg_usage_t arg_usage(primitive_arg_index_t arg) const override {
if (arg == MKLDNN_ARG_SRC)
return arg_usage_t::input;
if (arg == MKLDNN_ARG_DST)
return arg_usage_t::output;
return primitive_desc_t::arg_usage(arg);
}
virtual const memory_desc_t *src_md(int index = 0) const override
{ return index == 0 ? &data_md_ : nullptr; }
virtual const memory_desc_t *dst_md(int index = 0) const override
{ return index == 0 ? &data_md_ : nullptr; }
virtual int n_inputs() const override { return 1; }
virtual int n_outputs() const override { return 1; }
bool is_zero_preserved() const
{ return math::eltwise_fwd_preserves_zero(desc_.alg_kind); }
};
struct eltwise_bwd_pd_t: public eltwise_pd_t {
typedef eltwise_bwd_pd_t base_class;
typedef eltwise_fwd_pd_t hint_class;
eltwise_bwd_pd_t(engine_t *engine,
const eltwise_desc_t *adesc,
const primitive_attr_t *attr,
const eltwise_fwd_pd_t *hint_fwd_pd)
: eltwise_pd_t(engine, adesc, attr, hint_fwd_pd)
, diff_data_md_(desc_.diff_data_desc)
{}
virtual arg_usage_t arg_usage(primitive_arg_index_t arg) const override {
if (utils::one_of(arg, MKLDNN_ARG_SRC, MKLDNN_ARG_DIFF_DST))
return arg_usage_t::input;
if (arg == MKLDNN_ARG_DIFF_SRC)
return arg_usage_t::output;
return primitive_desc_t::arg_usage(arg);
}
virtual const memory_desc_t *src_md(int index = 0) const override
{ return index == 0 ? &data_md_ : nullptr; }
virtual const memory_desc_t *diff_dst_md(int index = 0) const override
{ return index == 0 ? &diff_data_md_ : nullptr; }
virtual const memory_desc_t *diff_src_md(int index = 0) const override
{ return index == 0 ? &diff_data_md_ : nullptr; }
virtual int n_inputs() const override { return 2; }
virtual int n_outputs() const override { return 1; }
bool is_zero_preserved() const { return true; }
protected:
memory_desc_t diff_data_md_;
};
}
}
#endif
// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s

@ -1,75 +0,0 @@
/*******************************************************************************
* Copyright 2016-2018 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/
#include "mkldnn.h"
#include "engine.hpp"
#include "nstl.hpp"
#include "c_types_map.hpp"
#include "../cpu/cpu_engine.hpp"
namespace mkldnn {
namespace impl {
engine_factory_t *engine_factories[] = {
&cpu::engine_factory,
nullptr,
};
static inline engine_factory_t *get_engine_factory(engine_kind_t kind) {
for (engine_factory_t **ef = engine_factories; *ef; ef++)
if ((*ef)->kind() == kind)
return *ef;
return nullptr;
}
}
}
using namespace mkldnn::impl;
using namespace mkldnn::impl::status;
size_t mkldnn_engine_get_count(engine_kind_t kind) {
engine_factory_t *ef = get_engine_factory(kind);
return ef != nullptr ? ef->count() : 0;
}
status_t mkldnn_engine_create(engine_t **engine,
engine_kind_t kind, size_t index) {
if (engine == nullptr)
return invalid_arguments;
engine_factory_t *ef = get_engine_factory(kind);
if (ef == nullptr || index >= ef->count())
return invalid_arguments;
return ef->engine_create(engine, index);
}
status_t mkldnn_engine_get_kind(engine_t *engine, engine_kind_t *kind) {
if (engine == nullptr)
return invalid_arguments;
*kind = engine->kind();
return success;
}
status_t mkldnn_engine_destroy(engine_t *engine) {
/* TODO: engine->dec_ref_count(); */
delete engine;
return success;
}
// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s
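The engine lookup in the removed file above relies on a nullptr-terminated table of factories. The following is a minimal standalone sketch of that pattern, assuming simplified stand-in types: engine_kind, engine_factory, cpu_factory, find_factory and engine_count are hypothetical names for illustration and are not the MKL-DNN declarations.

```cpp
#include <cassert>
#include <cstddef>

// Hypothetical stand-ins for the MKL-DNN types; only the lookup pattern
// is taken from the removed source.
enum class engine_kind { cpu, gpu };

struct engine_factory {
    virtual ~engine_factory() {}
    virtual engine_kind kind() const = 0;
    virtual std::size_t count() const = 0;
};

struct cpu_factory : engine_factory {
    engine_kind kind() const override { return engine_kind::cpu; }
    std::size_t count() const override { return 1; }
};

static cpu_factory g_cpu_factory;

// The trailing nullptr is the sentinel that ends the scan, so no separate
// length has to be kept in sync with the table.
static engine_factory *g_factories[] = { &g_cpu_factory, nullptr };

inline engine_factory *find_factory(engine_kind kind) {
    for (engine_factory **ef = g_factories; *ef; ++ef)
        if ((*ef)->kind() == kind)
            return *ef;
    return nullptr; // no factory registered for this kind
}

// Same fallback as mkldnn_engine_get_count: an unknown kind reports zero
// engines instead of an error.
inline std::size_t engine_count(engine_kind kind) {
    engine_factory *ef = find_factory(kind);
    return ef != nullptr ? ef->count() : 0;
}
```

In this sketch, engine_count(engine_kind::cpu) returns 1 and engine_count(engine_kind::gpu) returns 0, since only the CPU factory is registered in the table.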

@ -1,119 +0,0 @@
/*******************************************************************************
* Copyright 2016-2018 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/
#ifndef ENGINE_HPP
#define ENGINE_HPP
#include "mkldnn.h"
#include "c_types_map.hpp"
#include "primitive.hpp"
#include "utils.hpp"
/** \brief An abstraction of an execution unit with shared resources
*
* Responsibilities:
* - Provide engine specific memory allocation
* - Provide engine specific primitive_desc_t creators
*/
struct mkldnn_engine: public mkldnn::impl::c_compatible {
mkldnn_engine(mkldnn::impl::engine_kind_t kind)
: kind_(kind)
{}
virtual ~mkldnn_engine() {}
/** get kind of the current engine */
virtual mkldnn::impl::engine_kind_t kind() const { return kind_; }
/** allocate memory */
virtual mkldnn::impl::status_t memory_create(
mkldnn::impl::memory_t **memory,
const mkldnn::impl::memory_desc_t *md,
void *handle) = 0;
/** implementation section (typedefs) */
// TODO: remove engine?
typedef mkldnn::impl::status_t (*reorder_primitive_desc_create_f)(
mkldnn::impl::reorder_pd_t **reorder_pd,
mkldnn::impl::engine_t *engine,
const mkldnn::impl::primitive_attr_t *attr,
mkldnn::impl::engine_t *src_engine,
const mkldnn::impl::memory_desc_t *src_md,
mkldnn::impl::engine_t *dst_engine,
const mkldnn::impl::memory_desc_t *dst_md);
typedef mkldnn::impl::status_t (*concat_primitive_desc_create_f)(
mkldnn::impl::concat_pd_t **concat_pd,
mkldnn::impl::engine_t *engine,
const mkldnn::impl::primitive_attr_t *attr,
const mkldnn::impl::memory_desc_t *dst_md,
int n, int concat_dim,
const mkldnn::impl::memory_desc_t *src_mds);
typedef mkldnn::impl::status_t (*sum_primitive_desc_create_f)(
mkldnn::impl::sum_pd_t **sum_pd,
mkldnn::impl::engine_t *engine,
const mkldnn::impl::primitive_attr_t *attr,
const mkldnn::impl::memory_desc_t *dst_md,
int n, const float *scales,
const mkldnn::impl::memory_desc_t *src_mds);
typedef mkldnn::impl::status_t (*primitive_desc_create_f)(
mkldnn::impl::primitive_desc_t **, const mkldnn::impl::op_desc_t *,
const mkldnn::impl::primitive_attr_t *attr,
mkldnn::impl::engine_t *, const mkldnn::impl::primitive_desc_t *);
/* implementation section */
/** return the list of reorder implementations. engine guarantees to return
* a NULL-terminated list */
virtual const reorder_primitive_desc_create_f*
get_reorder_implementation_list() const = 0;
/** return the list of concat implementations. engine guarantees to return
* a NULL-terminated list */
virtual const concat_primitive_desc_create_f*
get_concat_implementation_list() const = 0;
/** return the list of sum implementations. engine guarantees to return
* a NULL-terminated list */
virtual const sum_primitive_desc_create_f*
get_sum_implementation_list() const = 0;
/** return the list of implementations. engine guarantees to return a
* NULL-terminated list */
virtual const primitive_desc_create_f* get_implementation_list() const = 0;
protected:
mkldnn::impl::engine_kind_t kind_;
};
namespace mkldnn {
namespace impl {
struct engine_factory_t: public c_compatible {
virtual size_t count() const = 0;
virtual engine_kind_t kind() const = 0;
virtual status_t engine_create(engine_t **engine, size_t index) const = 0;
};
}
}
#endif
// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s

@ -1,106 +0,0 @@
/*******************************************************************************
* Copyright 2016-2018 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/
#include <assert.h>
#include "mkldnn.h"
#include "c_types_map.hpp"
#include "type_helpers.hpp"
#include "utils.hpp"
using namespace mkldnn::impl;
using namespace mkldnn::impl::utils;
using namespace mkldnn::impl::status;
using namespace mkldnn::impl::prop_kind;
using namespace mkldnn::impl::types;
namespace {
status_t ip_desc_init(inner_product_desc_t *ip_desc, prop_kind_t prop_kind,
const memory_desc_t *src_desc, const memory_desc_t *weights_desc,
const memory_desc_t *bias_desc, const memory_desc_t *dst_desc) {
bool args_ok = !any_null(ip_desc, src_desc, weights_desc, dst_desc);
if (!args_ok) return invalid_arguments;
auto id = inner_product_desc_t();
id.primitive_kind = primitive_kind::inner_product;
id.prop_kind = prop_kind;
id.diff_src_desc = id.src_desc = zero_md();
id.diff_dst_desc = id.dst_desc = zero_md();
id.diff_weights_desc = id.weights_desc = zero_md();
id.diff_bias_desc = id.bias_desc = zero_md();
const bool is_fwd = one_of(prop_kind, forward_training, forward_inference);
const bool with_bias =
bias_desc && bias_desc->format_kind != format_kind::undef;
(prop_kind == backward_data ? id.diff_src_desc : id.src_desc) = *src_desc;
(is_fwd ? id.dst_desc : id.diff_dst_desc) = *dst_desc;
(prop_kind == backward_weights ? id.diff_weights_desc : id.weights_desc) =
*weights_desc;
if (with_bias)
(prop_kind == backward_weights ? id.diff_bias_desc : id.bias_desc) =
*bias_desc;
id.accum_data_type = types::default_accum_data_type(src_desc->data_type,
weights_desc->data_type, dst_desc->data_type, prop_kind);
bool consistency = true
&& memory_desc_wrapper(weights_desc).nelems()
&& one_of(src_desc->ndims, 2, 3, 4, 5)
&& dst_desc->ndims == 2
&& weights_desc->ndims == src_desc->ndims
&& (with_bias ? bias_desc->ndims == 1 : true)
&& (with_bias ? bias_desc->dims[0] == dst_desc->dims[1] : true)
&& src_desc->dims[0] == dst_desc->dims[0]
&& array_cmp(&src_desc->dims[1], &weights_desc->dims[1],
src_desc->ndims - 1)
&& dst_desc->dims[1] == weights_desc->dims[0];
if (!consistency) return invalid_arguments;
*ip_desc = id;
return success;
}
}
status_t mkldnn_inner_product_forward_desc_init(inner_product_desc_t *ip_desc,
prop_kind_t prop_kind, const memory_desc_t *src_desc,
const memory_desc_t *weights_desc, const memory_desc_t *bias_desc,
const memory_desc_t *dst_desc) {
if (!one_of(prop_kind, forward_training, forward_inference))
return invalid_arguments;
return ip_desc_init(ip_desc, prop_kind, src_desc, weights_desc, bias_desc,
dst_desc);
}
status_t mkldnn_inner_product_backward_data_desc_init(
inner_product_desc_t *ip_desc, const memory_desc_t *diff_src_desc,
const memory_desc_t *weights_desc, const memory_desc_t *diff_dst_desc)
{
return ip_desc_init(ip_desc, backward_data, diff_src_desc, weights_desc,
nullptr, diff_dst_desc);
}
status_t mkldnn_inner_product_backward_weights_desc_init(
inner_product_desc_t *ip_desc, const memory_desc_t *src_desc,
const memory_desc_t *diff_weights_desc,
const memory_desc_t *diff_bias_desc,
const memory_desc_t *diff_dst_desc) {
return ip_desc_init(ip_desc, backward_weights, src_desc, diff_weights_desc,
diff_bias_desc, diff_dst_desc);
}
// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s

@ -1,56 +0,0 @@
/*******************************************************************************
* Copyright 2018 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/
#include "utils.hpp"
#include "inner_product_pd.hpp"
namespace mkldnn {
namespace impl {
using namespace prop_kind;
memory_desc_t *ip_prop_invariant_src_d(inner_product_desc_t *desc) {
return desc->prop_kind == backward_data
? &desc->diff_src_desc : &desc->src_desc;
}
memory_desc_t *ip_prop_invariant_wei_d(inner_product_desc_t *desc) {
return desc->prop_kind == backward_weights
? &desc->diff_weights_desc : &desc->weights_desc;
}
memory_desc_t *ip_prop_invariant_bia_d(inner_product_desc_t *desc) {
return desc->prop_kind == backward_weights
? &desc->diff_bias_desc : &desc->bias_desc;
}
memory_desc_t *ip_prop_invariant_dst_d(inner_product_desc_t *desc) {
return utils::one_of(desc->prop_kind, forward_inference, forward_training)
? &desc->dst_desc : &desc->diff_dst_desc;
}
const memory_desc_t *ip_prop_invariant_src_d(const inner_product_desc_t *desc)
{ return ip_prop_invariant_src_d(const_cast<inner_product_desc_t *>(desc)); }
const memory_desc_t *ip_prop_invariant_wei_d(const inner_product_desc_t *desc)
{ return ip_prop_invariant_wei_d(const_cast<inner_product_desc_t *>(desc)); }
const memory_desc_t *ip_prop_invariant_bia_d(const inner_product_desc_t *desc)
{ return ip_prop_invariant_bia_d(const_cast<inner_product_desc_t *>(desc)); }
const memory_desc_t *ip_prop_invariant_dst_d(const inner_product_desc_t *desc)
{ return ip_prop_invariant_dst_d(const_cast<inner_product_desc_t *>(desc)); }
}
}

@ -1,321 +0,0 @@
/*******************************************************************************
* Copyright 2016-2018 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/
#ifndef INNER_PRODUCT_PD_HPP
#define INNER_PRODUCT_PD_HPP
#include "mkldnn.h"
#include "c_types_map.hpp"
#include "primitive_desc.hpp"
#include "utils.hpp"
namespace mkldnn {
namespace impl {
memory_desc_t *ip_prop_invariant_src_d(inner_product_desc_t *desc);
memory_desc_t *ip_prop_invariant_wei_d(inner_product_desc_t *desc);
memory_desc_t *ip_prop_invariant_bia_d(inner_product_desc_t *desc);
memory_desc_t *ip_prop_invariant_dst_d(inner_product_desc_t *desc);
const memory_desc_t *ip_prop_invariant_src_d(const inner_product_desc_t *desc);
const memory_desc_t *ip_prop_invariant_wei_d(const inner_product_desc_t *desc);
const memory_desc_t *ip_prop_invariant_bia_d(const inner_product_desc_t *desc);
const memory_desc_t *ip_prop_invariant_dst_d(const inner_product_desc_t *desc);
struct inner_product_fwd_pd_t;
struct inner_product_pd_t: public primitive_desc_t {
static constexpr auto base_pkind = primitive_kind::inner_product;
inner_product_pd_t(engine_t *engine,
const inner_product_desc_t *adesc,
const primitive_attr_t *attr,
const inner_product_fwd_pd_t *hint_fwd_pd)
: primitive_desc_t(engine, attr, base_pkind)
, desc_(*adesc)
, hint_fwd_pd_(hint_fwd_pd)
{}
const inner_product_desc_t *desc() const { return &desc_; }
virtual const op_desc_t *op_desc() const override
{ return reinterpret_cast<const op_desc_t *>(this->desc()); }
virtual void init_info() override { impl::init_info(this, this->info_); }
virtual status_t query(query_t what, int idx, void *result) const override {
switch (what) {
case query::inner_product_d:
*(const inner_product_desc_t**)result = desc(); break;
default: return primitive_desc_t::query(what, idx, result);
}
return status::success;
}
/* common inner_product aux functions */
dim_t MB() const { return ip_prop_invariant_src_d(&desc_)->dims[0]; }
dim_t IC() const { return ip_prop_invariant_src_d(&desc_)->dims[1]; }
dim_t OC() const { return ip_prop_invariant_dst_d(&desc_)->dims[1]; }
dim_t ID() const {
return ndims() >= 5
? ip_prop_invariant_src_d(&desc_)->dims[ndims() - 3] : 1;
}
dim_t IH() const {
return ndims() >= 4
? ip_prop_invariant_src_d(&desc_)->dims[ndims() - 2] : 1;
}
dim_t IW() const {
return ndims() >= 3
? ip_prop_invariant_src_d(&desc_)->dims[ndims() - 1] : 1;
}
dim_t OD() const {
return ndims() >= 5
? ip_prop_invariant_dst_d(&desc_)->dims[ndims() - 3] : 1;
}
dim_t OH() const {
return ndims() >= 4
? ip_prop_invariant_dst_d(&desc_)->dims[ndims() - 2] : 1;
}
dim_t OW() const {
return ndims() >= 3
? ip_prop_invariant_dst_d(&desc_)->dims[ndims() - 1] : 1;
}
dim_t KD() const {
return ndims() >= 5
? ip_prop_invariant_wei_d(&desc_)->dims[ndims() - 3] : 1;
}
dim_t KH() const {
return ndims() >= 4
? ip_prop_invariant_wei_d(&desc_)->dims[ndims() - 2] : 1;
}
dim_t KW() const {
return ndims() >= 3
? ip_prop_invariant_wei_d(&desc_)->dims[ndims() - 1] : 1;
}
dim_t IC_total() const {
return utils::array_product(&ip_prop_invariant_src_d(&desc_)->dims[1],
ndims() - 1);
}
dim_t IC_total_padded() const {
auto src_d = desc()->prop_kind == prop_kind::backward_data
? memory_desc_wrapper(diff_src_md())
: memory_desc_wrapper(src_md());
assert(src_d.is_blocking_desc());
if (!src_d.is_blocking_desc()) return -1;
return utils::array_product(src_d.padded_dims() + 1, ndims() - 1);
}
int ndims() const { return ip_prop_invariant_src_d(&desc_)->ndims; }
bool with_bias() const
{ return !memory_desc_wrapper(*ip_prop_invariant_bia_d(&desc_)).is_zero(); }
bool has_zero_dim_memory() const {
const auto s_d = memory_desc_wrapper(*ip_prop_invariant_src_d(&desc_));
const auto d_d = memory_desc_wrapper(*ip_prop_invariant_dst_d(&desc_));
return s_d.has_zero_dim() || d_d.has_zero_dim();
}
bool is_fwd() const {
return utils::one_of(desc_.prop_kind, prop_kind::forward_training,
prop_kind::forward_inference);
}
protected:
inner_product_desc_t desc_;
const inner_product_fwd_pd_t *hint_fwd_pd_;
status_t template_set_default_params(memory_desc_t &src_md,
memory_desc_t &weights_md, memory_desc_t &dst_md,
memory_desc_t *bias_md) {
using namespace format_tag;
if (src_md.format_kind == format_kind::any) {
CHECK(memory_desc_init_by_tag(src_md,
utils::pick(ndims() - 2, nc, ncw, nchw, ncdhw)));
}
if (dst_md.format_kind == format_kind::any)
CHECK(memory_desc_init_by_tag(dst_md, nc));
if (weights_md.format_kind == format_kind::any) {
CHECK(memory_desc_init_by_tag(weights_md,
utils::pick(ndims() - 2, oi, oiw, oihw, oidhw)));
}
if (bias_md && bias_md->format_kind == format_kind::any)
CHECK(memory_desc_init_by_tag(*bias_md, x));
return status::success;
}
};
struct inner_product_fwd_pd_t: public inner_product_pd_t {
typedef inner_product_fwd_pd_t base_class;
typedef inner_product_fwd_pd_t hint_class;
inner_product_fwd_pd_t(engine_t *engine,
const inner_product_desc_t *adesc,
const primitive_attr_t *attr,
const inner_product_fwd_pd_t *hint_fwd_pd)
: inner_product_pd_t(engine, adesc, attr, hint_fwd_pd)
, src_md_(desc_.src_desc)
, weights_md_(desc_.weights_desc)
, bias_md_(desc_.bias_desc)
, dst_md_(desc_.dst_desc)
{}
virtual arg_usage_t arg_usage(primitive_arg_index_t arg) const override {
if (utils::one_of(arg, MKLDNN_ARG_SRC, MKLDNN_ARG_WEIGHTS))
return arg_usage_t::input;
if (arg == MKLDNN_ARG_BIAS && with_bias())
return arg_usage_t::input;
if (arg == MKLDNN_ARG_DST)
return arg_usage_t::output;
return primitive_desc_t::arg_usage(arg);
}
virtual const memory_desc_t *src_md(int index = 0) const override
{ return index == 0 ? &src_md_ : nullptr; }
virtual const memory_desc_t *dst_md(int index = 0) const override
{ return index == 0 ? &dst_md_ : nullptr; }
virtual const memory_desc_t *weights_md(int index = 0) const override {
if (index == 0) return &weights_md_;
if (index == 1 && with_bias()) return &bias_md_;
return nullptr;
}
virtual int n_inputs() const override { return 2 + with_bias(); }
virtual int n_outputs() const override { return 1; }
protected:
memory_desc_t src_md_;
memory_desc_t weights_md_;
memory_desc_t bias_md_;
memory_desc_t dst_md_;
status_t set_default_params() {
return template_set_default_params(src_md_, weights_md_, dst_md_,
&bias_md_);
}
};
struct inner_product_bwd_data_pd_t: public inner_product_pd_t {
typedef inner_product_bwd_data_pd_t base_class;
typedef inner_product_fwd_pd_t hint_class;
inner_product_bwd_data_pd_t(engine_t *engine,
const inner_product_desc_t *adesc,
const primitive_attr_t *attr,
const inner_product_fwd_pd_t *hint_fwd_pd)
: inner_product_pd_t(engine, adesc, attr, hint_fwd_pd)
, diff_src_md_(desc_.diff_src_desc)
, weights_md_(desc_.weights_desc)
, diff_dst_md_(desc_.diff_dst_desc)
{}
virtual arg_usage_t arg_usage(primitive_arg_index_t arg) const override {
if (utils::one_of(arg, MKLDNN_ARG_WEIGHTS, MKLDNN_ARG_DIFF_DST))
return arg_usage_t::input;
if (arg == MKLDNN_ARG_DIFF_SRC)
return arg_usage_t::output;
return primitive_desc_t::arg_usage(arg);
}
virtual const memory_desc_t *diff_src_md(int index = 0) const override
{ return index == 0 ? &diff_src_md_ : nullptr; }
virtual const memory_desc_t *diff_dst_md(int index = 0) const override
{ return index == 0 ? &diff_dst_md_ : nullptr; }
virtual const memory_desc_t *weights_md(int index = 0) const override
{ return index == 0 ? &weights_md_ : nullptr; }
virtual int n_inputs() const override { return 2; }
virtual int n_outputs() const override { return 1; }
protected:
memory_desc_t diff_src_md_;
memory_desc_t weights_md_;
memory_desc_t diff_dst_md_;
status_t set_default_params() {
return template_set_default_params(diff_src_md_, weights_md_,
diff_dst_md_, nullptr);
}
};
struct inner_product_bwd_weights_pd_t: public inner_product_pd_t {
typedef inner_product_bwd_weights_pd_t base_class;
typedef inner_product_fwd_pd_t hint_class;
inner_product_bwd_weights_pd_t(engine_t *engine,
const inner_product_desc_t *adesc,
const primitive_attr_t *attr,
const inner_product_fwd_pd_t *hint_fwd_pd)
: inner_product_pd_t(engine, adesc, attr, hint_fwd_pd)
, src_md_(desc_.src_desc)
, diff_weights_md_(desc_.diff_weights_desc)
, diff_bias_md_(desc_.diff_bias_desc)
, diff_dst_md_(desc_.diff_dst_desc)
{}
virtual arg_usage_t arg_usage(primitive_arg_index_t arg) const override {
if (utils::one_of(arg, MKLDNN_ARG_SRC, MKLDNN_ARG_DIFF_DST))
return arg_usage_t::input;
if (arg == MKLDNN_ARG_DIFF_WEIGHTS)
return arg_usage_t::output;
if (arg == MKLDNN_ARG_DIFF_BIAS && with_bias())
return arg_usage_t::output;
return primitive_desc_t::arg_usage(arg);
}
virtual const memory_desc_t *src_md(int index = 0) const override
{ return index == 0 ? &src_md_ : nullptr; }
virtual const memory_desc_t *diff_dst_md(int index = 0) const override
{ return index == 0 ? &diff_dst_md_ : nullptr; }
virtual const memory_desc_t *diff_weights_md(int index = 0) const override {
if (index == 0) return &diff_weights_md_;
if (index == 1 && with_bias()) return &diff_bias_md_;
return nullptr;
}
virtual int n_inputs() const override { return 2; }
virtual int n_outputs() const override { return 1 + with_bias(); }
protected:
memory_desc_t src_md_;
memory_desc_t diff_weights_md_;
memory_desc_t diff_bias_md_;
memory_desc_t diff_dst_md_;
status_t set_default_params() {
return template_set_default_params(src_md_, diff_weights_md_,
diff_dst_md_, &diff_bias_md_);
}
};
}
}
#endif
// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s

@ -1,91 +0,0 @@
/*******************************************************************************
* Copyright 2016-2018 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/
#include <assert.h>
#include "mkldnn.h"
#include "c_types_map.hpp"
#include "type_helpers.hpp"
#include "utils.hpp"
using namespace mkldnn::impl;
using namespace mkldnn::impl::utils;
using namespace mkldnn::impl::status;
using namespace mkldnn::impl::prop_kind;
using namespace mkldnn::impl::alg_kind;
using namespace mkldnn::impl::types;
namespace {
status_t lrn_desc_init(lrn_desc_t *lrn_desc,
prop_kind_t prop_kind, alg_kind_t alg_kind,
const memory_desc_t *data_desc, const memory_desc_t *diff_data_desc,
dim_t local_size, float alpha, float beta, float k) {
bool args_ok = true
&& !any_null(lrn_desc, data_desc)
&& one_of(alg_kind, lrn_within_channel, lrn_across_channels)
&& one_of(prop_kind, forward_training, forward_inference, backward_data)
&& IMPLICATION(prop_kind == backward_data, diff_data_desc != nullptr);
if (!args_ok) return invalid_arguments;
auto ld = lrn_desc_t();
ld.primitive_kind = primitive_kind::lrn;
ld.prop_kind = prop_kind;
ld.alg_kind = alg_kind;
const bool is_fwd = one_of(prop_kind, forward_training, forward_inference);
ld.data_desc = *data_desc;
if (!is_fwd)
ld.diff_data_desc = *diff_data_desc;
else
ld.diff_data_desc = zero_md();
ld.local_size = local_size;
ld.lrn_alpha = alpha;
ld.lrn_beta = beta;
ld.lrn_k = k;
bool consistency = true
&& ld.data_desc.ndims == 4;
if (ld.prop_kind == backward_data)
consistency = consistency
&& ld.diff_data_desc.ndims == 4
&& array_cmp(ld.diff_data_desc.dims, ld.data_desc.dims, 4);
if (!consistency) return invalid_arguments;
*lrn_desc = ld;
return success;
}
}
status_t mkldnn_lrn_forward_desc_init(lrn_desc_t *lrn_desc,
prop_kind_t prop_kind, alg_kind_t alg_kind,
const memory_desc_t *data_desc, dim_t local_size, float alpha,
float beta, float k) {
if (!one_of(prop_kind, forward_training, forward_inference))
return invalid_arguments;
return lrn_desc_init(lrn_desc, prop_kind, alg_kind, data_desc, nullptr,
local_size, alpha, beta, k);
}
status_t mkldnn_lrn_backward_desc_init(lrn_desc_t *lrn_desc,
alg_kind_t alg_kind, const memory_desc_t *data_desc,
const memory_desc_t *diff_data_desc, dim_t local_size, float alpha,
float beta, float k) {
return lrn_desc_init(lrn_desc, backward_data, alg_kind, data_desc,
diff_data_desc, local_size, alpha, beta, k);
}
// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s

@ -1,170 +0,0 @@
/*******************************************************************************
* Copyright 2016-2018 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/
#ifndef LRN_PD_HPP
#define LRN_PD_HPP
#include "mkldnn.h"
#include "c_types_map.hpp"
#include "primitive_desc.hpp"
namespace mkldnn {
namespace impl {
struct lrn_fwd_pd_t;
struct lrn_pd_t: public primitive_desc_t {
static constexpr auto base_pkind = primitive_kind::lrn;
lrn_pd_t(engine_t *engine,
const lrn_desc_t *adesc,
const primitive_attr_t *attr,
const lrn_fwd_pd_t *hint_fwd_pd)
: primitive_desc_t(engine, attr, base_pkind)
, desc_(*adesc)
, hint_fwd_pd_(hint_fwd_pd)
, data_md_(desc_.data_desc)
, ws_md_()
{}
const lrn_desc_t *desc() const { return &desc_; }
virtual const op_desc_t *op_desc() const override
{ return reinterpret_cast<const op_desc_t *>(this->desc()); }
virtual void init_info() override { impl::init_info(this, this->info_); }
virtual status_t query(query_t what, int idx, void *result) const override {
switch (what) {
case query::lrn_d:
*(const lrn_desc_t**)result = desc(); break;
default: return primitive_desc_t::query(what, idx, result);
}
return status::success;
}
/* common lrn aux functions */
dim_t MB() const { return data_desc().dims[0]; }
dim_t C() const { return data_desc().dims[1]; }
dim_t D() const { return ndims() >= 5 ? data_desc().dims[ndims() - 3] : 1; }
dim_t H() const { return ndims() >= 4 ? data_desc().dims[ndims() - 2] : 1; }
dim_t W() const { return ndims() >= 3 ? data_desc().dims[ndims() - 1] : 1; }
int ndims() const { return data_desc().ndims; }
bool has_zero_dim_memory() const
{ return memory_desc_wrapper(desc_.data_desc).has_zero_dim(); }
bool is_fwd() const {
return utils::one_of(desc_.prop_kind, prop_kind::forward_training,
prop_kind::forward_inference);
}
protected:
lrn_desc_t desc_;
const lrn_fwd_pd_t *hint_fwd_pd_;
memory_desc_t data_md_;
memory_desc_t ws_md_;
private:
const memory_desc_t &data_desc() const { return desc_.data_desc; }
};
struct lrn_fwd_pd_t: public lrn_pd_t {
typedef lrn_fwd_pd_t base_class;
typedef lrn_fwd_pd_t hint_class;
lrn_fwd_pd_t(engine_t *engine,
const lrn_desc_t *adesc,
const primitive_attr_t *attr,
const lrn_fwd_pd_t *hint_fwd_pd)
: lrn_pd_t(engine, adesc, attr, hint_fwd_pd)
{}
virtual arg_usage_t arg_usage(primitive_arg_index_t arg) const override {
if (arg == MKLDNN_ARG_SRC)
return arg_usage_t::input;
if (arg == MKLDNN_ARG_DST)
return arg_usage_t::output;
if (arg == MKLDNN_ARG_WORKSPACE && (workspace_md() != nullptr))
return arg_usage_t::output;
return primitive_desc_t::arg_usage(arg);
}
virtual const memory_desc_t *src_md(int index = 0) const override
{ return index == 0 ? &data_md_ : nullptr; }
virtual const memory_desc_t *dst_md(int index = 0) const override
{ return index == 0 ? &data_md_ : nullptr; }
virtual const memory_desc_t *workspace_md(int index = 0) const override
{ return index == 0 && !types::is_zero_md(&ws_md_) ? &ws_md_ : nullptr; }
virtual int n_inputs() const override { return 1; }
virtual int n_outputs() const override
{ return 1 + (workspace_md() != nullptr); }
};
struct lrn_bwd_pd_t: public lrn_pd_t {
typedef lrn_bwd_pd_t base_class;
typedef lrn_fwd_pd_t hint_class;
lrn_bwd_pd_t(engine_t *engine,
const lrn_desc_t *adesc,
const primitive_attr_t *attr,
const lrn_fwd_pd_t *hint_fwd_pd)
: lrn_pd_t(engine, adesc, attr, hint_fwd_pd)
, diff_data_md_(desc_.diff_data_desc)
{}
virtual arg_usage_t arg_usage(primitive_arg_index_t arg) const override {
if (utils::one_of(arg, MKLDNN_ARG_SRC, MKLDNN_ARG_DIFF_DST))
return arg_usage_t::input;
if (arg == MKLDNN_ARG_DIFF_SRC)
return arg_usage_t::output;
if (arg == MKLDNN_ARG_WORKSPACE && (workspace_md() != nullptr))
return arg_usage_t::input;
return primitive_desc_t::arg_usage(arg);
}
virtual const memory_desc_t *src_md(int index = 0) const override
{ return index == 0 ? &data_md_ : nullptr; }
virtual const memory_desc_t *diff_dst_md(int index = 0) const override
{ return index == 0 ? &diff_data_md_ : nullptr; }
virtual const memory_desc_t *diff_src_md(int index = 0) const override
{ return index == 0 ? &diff_data_md_ : nullptr; }
virtual const memory_desc_t *workspace_md(int index = 0) const override
{ return index == 0 && !types::is_zero_md(&ws_md_) ? &ws_md_ : nullptr; }
virtual int n_inputs() const override
{ return 2 + (workspace_md() != nullptr); }
virtual int n_outputs() const override { return 1; }
protected:
memory_desc_t diff_data_md_;
};
}
}
#endif
// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s

@ -1,280 +0,0 @@
/*******************************************************************************
* Copyright 2017-2018 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/
#ifndef MATH_UTILS_HPP
#define MATH_UTILS_HPP
#include <stdint.h>
#include <math.h>
#include "utils.hpp"
#include "nstl.hpp"
#include "mkldnn_traits.hpp"
#if defined(MKLDNN_X86_64)
#include "immintrin.h"
#endif
namespace mkldnn {
namespace impl {
namespace math {
/** rounds @p f to an integer according to the mxcsr register */
inline int mxcsr_round(float f) {
#if defined(MKLDNN_X86_64)
return _mm_cvtss_si32(_mm_load_ss(&f));
#else
return (int)nearbyintf(f); // optimism
#endif
}
template <typename data_t, typename acc_t>
inline typename utils::enable_if<!nstl::is_integral<data_t>::value,
typename utils::remove_reference<data_t>::type>::type
saturate(const acc_t &x) {
return (typename utils::remove_reference<data_t>::type)x;
}
template <typename data_t, typename acc_t>
inline typename utils::enable_if<nstl::is_integral<data_t>::value,
typename utils::remove_reference<data_t>::type>::type
saturate(const acc_t &x) {
acc_t v = x;
if (v < (acc_t)nstl::numeric_limits<data_t>::lowest())
v = (acc_t)nstl::numeric_limits<data_t>::lowest();
if (v > (acc_t)nstl::numeric_limits<data_t>::max())
v = (acc_t)nstl::numeric_limits<data_t>::max();
return (typename utils::remove_reference<data_t>::type)v;
}
template <typename data_t>
double saturate(const double &x) {
double v = x;
if (v < (double)nstl::numeric_limits<data_t>::lowest())
v = (double)nstl::numeric_limits<data_t>::lowest();
if (v > (double)nstl::numeric_limits<data_t>::max())
v = (double)nstl::numeric_limits<data_t>::max();
return v;
}
template <> inline int8_t saturate<int8_t, uint8_t>(const uint8_t &x) {
return x <= 127u ? x : 127;
}
template <> inline uint8_t saturate<uint8_t, int8_t>(const int8_t &x) {
return x >= 0 ? x : 0;
}
template <typename out_t>
typename utils::enable_if<nstl::is_integral<out_t>::value, out_t>::type
out_round(float v) { return (out_t)mxcsr_round(v); }
template <typename out_t>
typename utils::enable_if<nstl::is_integral<out_t>::value, out_t>::type
out_round(double v) { return (out_t)mxcsr_round((float)v); }
template <typename out_t>
typename utils::enable_if<!nstl::is_integral<out_t>::value, out_t>::type
out_round(float v) { return v; }
inline int gcd(int a, int b) {
a = impl::nstl::abs(a);
b = impl::nstl::abs(b);
if (a < b) { int x = a; a = b; b = x; }
if (b == 0) return a;
int r;
while ((r = a % b) != 0) { a = b; b = r; }
return b;
}
template <typename T>
inline bool is_pow2(const T& v) { return (v & (v - 1)) == 0; }
/** returns floor(log2(v)), aka the position of the leftmost non-0 bit */
inline int ilog2q(size_t v) {
if (v == 0)
return -1;
int p = 0;
# define CP(pw) do { if (v >= (1ull << pw)) { v >>= pw; p += pw; } } while(0)
CP(32); CP(16); CP(8); CP(4); CP(2); CP(1);
# undef CP
return p;
}
template <typename T, typename U = typename utils::remove_reference<T>::type>
inline U one_m_square(T x) {
return (U)(1 - x) * (1 + x);
}
template <typename T, typename U = typename utils::remove_reference<T>::type>
inline U x_m_square(T x) {
return (U)(1 - x) * x;
}
/* activation */
template <typename T, typename A,
typename U = typename utils::remove_reference<T>::type>
inline U relu_fwd(T s, A alpha) {
return s > 0 ? s : (U)(s * alpha);
}
template <typename T, typename A,
typename U = typename utils::remove_reference<T>::type>
inline U relu_bwd(T dd, T s, A alpha) {
return s > 0 ? dd : (U)(dd * alpha);
}
template <typename T, typename U = typename utils::remove_reference<T>::type>
inline U tanh_fwd(T s) {
const float e = tanhf((float) s);
return (U)e;
}
template <typename T, typename U = typename utils::remove_reference<T>::type>
inline U tanh_bwd(T dd, T s) {
const float e = tanh_fwd<float>((float) s);
return (U)(dd * (1 - e) * (1 + e));
}
template <typename T, typename A,
typename U = typename utils::remove_reference<T>::type>
inline U elu_fwd(T s, A alpha) {
return s > 0 ? s : (U)(alpha * (::expm1f((float)s)));
}
template <typename T, typename A,
typename U = typename utils::remove_reference<T>::type>
inline U elu_bwd(T dd, T s, A alpha) {
return (U)(dd * (s > 0 ? 1 : alpha * ::expf((float)s)));
}
template <typename T, typename U = typename utils::remove_reference<T>::type>
inline U square_fwd(T s) {
return s * s;
}
template <typename T, typename U = typename utils::remove_reference<T>::type>
inline U square_bwd(T dd, T s) {
return dd * 2 * s;
}
template <typename T, typename U = typename utils::remove_reference<T>::type>
inline U abs_fwd(T s) {
return s > 0 ? s : -s;
}
template <typename T, typename U = typename utils::remove_reference<T>::type>
inline U abs_bwd(T dd, T s) {
return s > 0 ? dd : s < 0 ? -dd : 0;
}
template <typename T, typename U = typename utils::remove_reference<T>::type>
inline U sqrt_fwd(T s) {
return s > 0 ? (U)(::sqrtf((float)(s))) : 0;
}
template <typename T, typename U = typename utils::remove_reference<T>::type>
inline U sqrt_bwd(T dd, T s) {
return s > 0
? (U)(dd / (2 * ::sqrtf((float)(s))))
: 0;
}
template <typename T, typename A,
typename U = typename utils::remove_reference<T>::type>
inline U linear_fwd(T s, A alpha, A beta) {
return (U)(alpha * s + beta);
}
template <typename T, typename A,
typename U = typename utils::remove_reference<T>::type>
inline U linear_bwd(T dd, T s, A alpha, A beta) {
(void) s;
(void) beta;
return (U)(dd * alpha);
}
template <typename T, typename A,
typename U = typename utils::remove_reference<T>::type>
inline U bounded_relu_fwd(T s, A alpha) {
s = s > 0 ? s : 0;
return s > alpha ? (U)(alpha) : s;
}
template <typename T, typename A,
typename U = typename utils::remove_reference<T>::type>
inline U bounded_relu_bwd(T dd, T s, A alpha) {
return dd * (0 < s && s < alpha ? 1 : 0);
}
template <typename T, typename U = typename utils::remove_reference<T>::type>
inline U soft_relu_fwd(T s) {
float max_logf = 8.872284e+01; //::logf(FLT_MAX)
return s < max_logf ? (U)(::log1pf(::expf((float)s))) : s;
}
template <typename T, typename U = typename utils::remove_reference<T>::type>
inline U soft_relu_bwd(T dd, T s) {
return (U)(dd / (1 + ::expf((float)(-s))));
}
template <typename T, typename U = typename utils::remove_reference<T>::type>
inline U logistic_fwd(T s) {
U v = (U)(::expf((float) -s));
return 1 / (1 + v);
}
template <typename T, typename U = typename utils::remove_reference<T>::type>
inline U logistic_bwd(T dd, T s) {
U v = logistic_fwd<T, U>(s);
return dd * v * (1 - v);
}
inline bool eltwise_fwd_preserves_zero(alg_kind_t alg, bool jit_impl = false) {
using namespace alg_kind;
using namespace utils;
const bool preserves_zero = true
&& !one_of(alg, eltwise_linear, eltwise_soft_relu, eltwise_logistic)
&& IMPLICATION(jit_impl, !one_of(alg, eltwise_elu, eltwise_tanh));
return preserves_zero;
}
inline float get_bias(const char *bias, size_t offset, data_type_t data_type)
{
if (!bias)
return 0.0f;
#define CASE(dt) \
case dt: return (float)((const prec_traits<dt>::type *)bias)[offset]
switch (data_type) {
CASE(data_type::s8);
CASE(data_type::u8);
CASE(data_type::s32);
CASE(data_type::f32);
default: assert(!"unimplemented");
}
return 0; // never happens (should probably be a NaN)
#undef CASE
}
}
}
}
#endif

@ -1,238 +0,0 @@
/*******************************************************************************
* Copyright 2016-2018 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/
#include <assert.h>
#include <stddef.h>
#include <stdint.h>
#include "mkldnn.h"
#include "c_types_map.hpp"
#include "engine.hpp"
#include "type_helpers.hpp"
#include "utils.hpp"
using namespace mkldnn::impl;
using namespace mkldnn::impl::utils;
using namespace mkldnn::impl::status;
using namespace mkldnn::impl::data_type;
namespace {
bool memory_desc_sanity_check(int ndims,const dims_t dims,
data_type_t data_type, format_kind_t format_kind) {
if (ndims == 0) return true;
bool ok = true
&& dims != nullptr
&& 0 < ndims && ndims <= MKLDNN_MAX_NDIMS
&& one_of(data_type, f32, s32, s8, u8)
&& format_kind != format_kind::undef;
if (!ok) return false;
for (int d = 0; d < ndims; ++d)
if (dims[d] < 0) return false;
return true;
}
bool memory_desc_sanity_check(const memory_desc_t *md) {
if (md == nullptr) return false;
return memory_desc_sanity_check(md->ndims, md->dims, md->data_type,
format_kind::any);
}
}
status_t mkldnn_memory_desc_init_by_tag(memory_desc_t *memory_desc, int ndims,
const dims_t dims, data_type_t data_type, format_tag_t tag) {
if (any_null(memory_desc)) return invalid_arguments;
if (ndims == 0 || tag == format_tag::undef) {
*memory_desc = types::zero_md();
return success;
}
format_kind_t format_kind = types::format_tag_to_kind(tag);
/* memory_desc != 0 */
bool args_ok = !any_null(memory_desc)
&& memory_desc_sanity_check(ndims, dims, data_type, format_kind);
if (!args_ok) return invalid_arguments;
auto md = memory_desc_t();
md.ndims = ndims;
array_copy(md.dims, dims, ndims);
md.data_type = data_type;
array_copy(md.padded_dims, dims, ndims);
md.format_kind = format_kind;
status_t status = success;
if (tag == format_tag::undef) {
status = invalid_arguments;
} else if (tag == format_tag::any) {
// nop
} else if (format_kind == format_kind::blocked) {
status = memory_desc_wrapper::compute_blocking(md, tag);
} else {
assert(!"unreachable");
status = invalid_arguments;
}
if (status == success)
*memory_desc = md;
return status;
}
status_t mkldnn_memory_desc_init_by_strides(memory_desc_t *memory_desc,
int ndims, const dims_t dims, data_type_t data_type,
const dims_t strides) {
if (any_null(memory_desc)) return invalid_arguments;
if (ndims == 0) {
*memory_desc = types::zero_md();
return success;
}
/* memory_desc != 0 */
bool args_ok = !any_null(memory_desc)
&& memory_desc_sanity_check(ndims, dims, data_type, format_kind::any);
if (!args_ok) return invalid_arguments;
auto md = memory_desc_t();
md.ndims = ndims;
array_copy(md.dims, dims, ndims);
md.data_type = data_type;
array_copy(md.padded_dims, dims, ndims);
md.format_kind = format_kind::blocked;
dims_t default_strides = {0};
if (strides == nullptr) {
default_strides[md.ndims - 1] = 1;
for (int d = md.ndims - 2; d >= 0; --d)
default_strides[d] = default_strides[d + 1] * md.padded_dims[d + 1];
strides = default_strides;
} else {
/* TODO: add sanity check for the provided strides */
}
array_copy(md.format_desc.blocking.strides, strides, md.ndims);
*memory_desc = md;
return status::success;
}
status_t mkldnn_memory_desc_init_submemory(memory_desc_t *md,
const memory_desc_t *parent_md, const dims_t dims,
const dims_t offsets) {
if (any_null(md, parent_md) || !memory_desc_sanity_check(parent_md))
return invalid_arguments;
const memory_desc_wrapper src_d(parent_md);
for (int d = 0; d < src_d.ndims(); ++d) {
if (dims[d] < 0 || offsets[d] < 0
|| (offsets[d] + dims[d] > src_d.dims()[d]))
return invalid_arguments;
}
if (src_d.format_kind() != format_kind::blocked)
return unimplemented;
dims_t blocks;
src_d.compute_blocks(blocks);
memory_desc_t dst_d = *parent_md;
auto &dst_d_blk = dst_d.format_desc.blocking;
/* TODO: put this into memory_desc_wrapper */
for (int d = 0; d < src_d.ndims(); ++d) {
/* very limited functionality for now */
const bool ok = true
&& offsets[d] % blocks[d] == 0 /* [r1] */
&& src_d.padded_offsets()[d] == 0
&& (false
|| dims[d] % blocks[d] == 0
|| dims[d] < blocks[d]);
if (!ok)
return unimplemented;
const bool is_right_border = offsets[d] + dims[d] == src_d.dims()[d];
dst_d.dims[d] = dims[d];
dst_d.padded_dims[d] = is_right_border
? src_d.padded_dims()[d] - offsets[d] : dst_d.dims[d];
dst_d.padded_offsets[d] = src_d.padded_offsets()[d];
dst_d.offset0 += /* [r1] */
offsets[d] / blocks[d] * dst_d_blk.strides[d];
}
*md = dst_d;
return success;
}
int mkldnn_memory_desc_equal(const memory_desc_t *lhs,
const memory_desc_t *rhs) {
if (lhs == rhs) return 1;
if (any_null(lhs, rhs)) return 0;
return memory_desc_wrapper(*lhs) == memory_desc_wrapper(*rhs);
}
size_t mkldnn_memory_desc_get_size(const memory_desc_t *md) {
if (md == nullptr) return 0;
return memory_desc_wrapper(*md).size();
}
status_t mkldnn_memory_create(memory_t **memory, const memory_desc_t *md,
engine_t *engine, void *handle) {
if (any_null(memory, engine)) return invalid_arguments;
memory_desc_t z_md = types::zero_md();
return engine->memory_create(memory, md ? md : &z_md, handle);
}
status_t mkldnn_memory_get_memory_desc(const memory_t *memory,
const memory_desc_t **md) {
if (any_null(memory, md)) return invalid_arguments;
*md = memory->md();
return success;
}
status_t mkldnn_memory_get_engine(const memory_t *memory, engine_t **engine) {
if (any_null(memory, engine)) return invalid_arguments;
*engine = memory->engine();
return success;
}
status_t mkldnn_memory_get_data_handle(const memory_t *memory,
void **handle) {
if (any_null(handle))
return invalid_arguments;
if (memory == nullptr) {
*handle = nullptr;
return success;
}
return memory->get_data_handle(handle);
}
status_t mkldnn_memory_set_data_handle(memory_t *memory, void *handle) {
if (any_null(memory)) return invalid_arguments;
return memory->set_data_handle(handle);
}
status_t mkldnn_memory_destroy(memory_t *memory) {
delete memory;
return success;
}
// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s
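
The offset arithmetic in `mkldnn_memory_desc_init_submemory` above can be sketched in isolation. This is a minimal illustration only: `sketch_submemory_offset0` and its vector-based parameters are hypothetical stand-ins for the `memory_desc_t` fields, it assumes the parent has zero `padded_offsets`, and it returns -1 where the original returns `invalid_arguments` or `unimplemented`.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical helper: computes the offset0 contribution of a sub-tensor that
// starts at `offsets` and spans `dims` inside a blocked parent tensor.
// `blocks[d]` is the overall block size of dim d, `strides[d]` the stride of
// dim d counted per block. Returns -1 for rejected or unsupported requests.
inline std::int64_t sketch_submemory_offset0(
        const std::vector<std::int64_t> &parent_dims,
        const std::vector<std::int64_t> &blocks,
        const std::vector<std::int64_t> &strides,
        const std::vector<std::int64_t> &dims,
        const std::vector<std::int64_t> &offsets) {
    const std::size_t n = parent_dims.size();
    assert(blocks.size() == n && strides.size() == n);
    assert(dims.size() == n && offsets.size() == n);

    std::int64_t offset0 = 0;
    for (std::size_t d = 0; d < n; ++d) {
        // The sub-tensor must lie inside the parent.
        if (dims[d] < 0 || offsets[d] < 0
                || offsets[d] + dims[d] > parent_dims[d])
            return -1;
        // Same restriction as the original: the origin must sit on a block
        // boundary, and the extent is either whole blocks or less than one.
        const bool ok = offsets[d] % blocks[d] == 0
                && (dims[d] % blocks[d] == 0 || dims[d] < blocks[d]);
        if (!ok) return -1;
        offset0 += offsets[d] / blocks[d] * strides[d];
    }
    return offset0;
}
```

For a parent with an `aBcd8b`-style layout, dims {2, 16, 4, 5} and strides {320, 160, 40, 8}, a sub-tensor starting at {1, 8, 0, 0} begins 480 elements into the parent buffer, while a start of 4 along the blocked dimension is rejected.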


@ -1,63 +0,0 @@
/*******************************************************************************
* Copyright 2018 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/
#ifndef MEMORY_HPP
#define MEMORY_HPP
#include <assert.h>
#include "mkldnn.h"
#include "c_types_map.hpp"
#include "nstl.hpp"
struct mkldnn_memory: public mkldnn::impl::c_compatible {
mkldnn_memory(mkldnn::impl::engine_t *engine,
const mkldnn::impl::memory_desc_t *md)
: engine_(engine), md_(*md) {}
virtual ~mkldnn_memory() {}
/** allocates/initializes memory */
virtual mkldnn::impl::status_t init() = 0;
/** returns memory's engine */
mkldnn::impl::engine_t *engine() const { return engine_; }
/** returns memory's description */
const mkldnn::impl::memory_desc_t *md() const { return &md_; }
/** returns data handle */
virtual mkldnn::impl::status_t get_data_handle(void **handle) const = 0;
/** sets data handle */
virtual mkldnn::impl::status_t set_data_handle(void *handle) = 0;
/** zeros padding */
virtual mkldnn::impl::status_t zero_pad() const
{ return mkldnn::impl::status::success; }
protected:
mkldnn::impl::engine_t *engine_;
const mkldnn::impl::memory_desc_t md_;
private:
mkldnn_memory() = delete;
mkldnn_memory(const mkldnn_memory &) = delete;
mkldnn_memory(mkldnn_memory &&) = delete;
mkldnn_memory &operator=(const mkldnn_memory &) = delete;
mkldnn_memory &operator=(mkldnn_memory &&) = delete;
};
#endif


@ -1,212 +0,0 @@
/*******************************************************************************
* Copyright 2016-2018 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/
#include <assert.h>
#include <initializer_list>
#include "c_types_map.hpp"
#include "memory_desc_wrapper.hpp"
#include "type_helpers.hpp"
#include "utils.hpp"
namespace mkldnn {
namespace impl {
status_t fill_blocked(memory_desc_t &md,
std::initializer_list<int> perm,
std::initializer_list<int> inner_blks,
std::initializer_list<int> inner_idxs) {
const bool ok = true
&& perm.size() == (size_t)md.ndims
&& inner_blks.size() == inner_idxs.size();
if (!ok) return status::invalid_arguments;
md.offset0 = 0;
blocking_desc_t &blk = md.format_desc.blocking;
dim_t block_size = 1;
dims_t blocks = {0};
utils::array_set(blocks, 1, md.ndims);
blk.inner_nblks = (int)inner_blks.size();
int iblk = 0;
for (const auto &b: inner_idxs)
blk.inner_idxs[iblk++] = b;
iblk = 0;
for (const auto &b: inner_blks) {
int dim = blk.inner_idxs[iblk];
block_size *= b;
blocks[dim] *= b;
blk.inner_blks[iblk++] = b;
}
utils::array_set(md.padded_offsets, 0, md.ndims);
for (int d = 0; d < md.ndims; ++d)
md.padded_dims[d] = utils::rnd_up(md.dims[d], blocks[d]);
dim_t stride = block_size;
// if only we used C++14, the initializer_list would have rbegin()/rend()...
for (int d = 0; d < md.ndims; ++d)
stride *= md.padded_dims[d] == 0 ? 1 : md.padded_dims[d] / blocks[d];
for (const auto &d: perm) {
if (md.padded_dims[d] == 0) {
blk.strides[d] = 1;
continue;
}
stride /= md.padded_dims[d] / blocks[d];
blk.strides[d] = stride;
}
assert(stride == block_size);
return status::success;
}
status_t memory_desc_wrapper::compute_blocking(memory_desc_t &memory_desc,
format_tag_t tag)
{
using namespace format_tag;
if (memory_desc.ndims == 0) return status::invalid_arguments;
# define C(tag, ... /* perm, inner_blks, inner_idxs */) \
case tag: return fill_blocked(memory_desc, __VA_ARGS__)
switch (tag) {
C(a, {0}, {}, {});
C(ab, {0, 1}, {}, {});
C(abc, {0, 1, 2}, {}, {});
C(abcd, {0, 1, 2, 3}, {}, {});
C(abcde, {0, 1, 2, 3, 4}, {}, {});
C(abcdef, {0, 1, 2, 3, 4, 5}, {}, {});
C(abdec, {0, 1, 3, 4, 2}, {}, {});
C(acb, {0, 2, 1}, {}, {});
C(acbde, {0, 2, 1, 3, 4}, {}, {});
C(acdb, {0, 2, 3, 1}, {}, {});
C(acdeb, {0, 2, 3, 4, 1}, {}, {});
C(ba, {1, 0}, {}, {});
C(bac, {1, 0, 2}, {}, {});
C(bacd, {1, 0, 2, 3}, {}, {});
C(bcda, {1, 2, 3, 0}, {}, {});
C(cba, {2, 1, 0}, {}, {});
C(cdba, {2, 3, 1, 0}, {}, {});
C(cdeba, {2, 3, 4, 1, 0}, {}, {});
C(decab, {3, 4, 2, 0, 1}, {}, {});
C(Abc4a, {0, 1, 2}, {4}, {0});
C(aBc4b, {0, 1, 2}, {4}, {1});
C(ABc4b16a4b, {0, 1, 2}, {4, 16, 4}, {1, 0, 1});
C(ABc4b4a, {0, 1, 2}, {4, 4}, {1, 0});
C(Abcd4a, {0, 1, 2, 3}, {4}, {0});
C(aBcd4b, {0, 1, 2, 3}, {4}, {1});
C(ABcd4b4a, {0, 1, 2, 3}, {4, 4}, {1, 0});
C(aBCd4c16b4c, {0, 1, 2, 3}, {4, 16, 4}, {2, 1, 2});
C(aBCd4c4b, {0, 1, 2, 3}, {4, 4}, {2, 1});
C(Abcde4a, {0, 1, 2, 3, 4}, {4}, {0});
C(aBcde4b, {0, 1, 2, 3, 4}, {4}, {1});
C(ABcde4b4a, {0, 1, 2, 3, 4}, {4, 4}, {1, 0});
C(aBCde4c4b, {0, 1, 2, 3, 4}, {4, 4}, {2, 1});
C(aBcdef4b, {0, 1, 2, 3, 4, 5}, {4}, {1});
C(aBCdef4c4b, {0, 1, 2, 3, 4, 5}, {4, 4}, {2, 1});
C(aBdc4b, {0, 1, 3, 2}, {4}, {1});
C(aBdec4b, {0, 1, 3, 4, 2}, {4}, {1});
C(aBdefc4b, {0, 1, 3, 4, 5, 2}, {4}, {1});
C(Acb4a, {0, 2, 1}, {4}, {0});
C(Acdb4a, {0, 2, 3, 1}, {4}, {0});
C(Acdeb4a, {0, 2, 3, 4, 1}, {4}, {0});
C(Abc16a, {0, 1, 2}, {16}, {0});
C(ABc16a16b, {0, 1, 2}, {16, 16}, {0, 1});
C(aBc16b, {0, 1, 2}, {16}, {1});
C(ABc16b16a, {0, 1, 2}, {16, 16}, {1, 0});
C(ABc8a16b2a, {0, 1, 2}, {8, 16, 2}, {0, 1, 0});
C(ABc8a8b, {0, 1, 2}, {8, 8}, {0, 1});
C(aBc8b, {0, 1, 2}, {8}, {1});
C(ABc8b16a2b, {0, 1, 2}, {8, 16, 2}, {1, 0, 1});
C(ABc8b8a, {0, 1, 2}, {8, 8}, {1, 0});
C(Abcd16a, {0, 1, 2, 3}, {16}, {0});
C(ABcd16a16b, {0, 1, 2, 3}, {16, 16}, {0, 1});
C(aBcd16b, {0, 1, 2, 3}, {16}, {1});
C(ABcd16b16a, {0, 1, 2, 3}, {16, 16}, {1, 0});
C(aBCd16b16c, {0, 1, 2, 3}, {16, 16}, {1, 2});
C(aBCd16c16b, {0, 1, 2, 3}, {16, 16}, {2, 1});
C(ABcd4b16a4b, {0, 1, 2, 3}, {4, 16, 4}, {1, 0, 1});
C(ABcd8a16b2a, {0, 1, 2, 3}, {8, 16, 2}, {0, 1, 0});
C(ABcd8a8b, {0, 1, 2, 3}, {8, 8}, {0, 1});
C(aBcd8b, {0, 1, 2, 3}, {8}, {1});
C(ABcd8b16a2b, {0, 1, 2, 3}, {8, 16, 2}, {1, 0, 1});
C(aBCd8b16c2b, {0, 1, 2, 3}, {8, 16, 2}, {1, 2, 1});
C(ABcd8b8a, {0, 1, 2, 3}, {8, 8}, {1, 0});
C(aBCd8b8c, {0, 1, 2, 3}, {8, 8}, {1, 2});
C(aBCd8c16b2c, {0, 1, 2, 3}, {8, 16, 2}, {2, 1, 2});
C(aBCd8c8b, {0, 1, 2, 3}, {8, 8}, {2, 1});
C(Abcde16a, {0, 1, 2, 3, 4}, {16}, {0});
C(ABcde16a16b, {0, 1, 2, 3, 4}, {16, 16}, {0, 1});
C(aBcde16b, {0, 1, 2, 3, 4}, {16}, {1});
C(ABcde16b16a, {0, 1, 2, 3, 4}, {16, 16}, {1, 0});
C(aBCde16b16c, {0, 1, 2, 3, 4}, {16, 16}, {1, 2});
C(aBCde16c16b, {0, 1, 2, 3, 4}, {16, 16}, {2, 1});
C(aBCde2c8b4c, {0, 1, 2, 3, 4}, {2, 8, 4}, {2, 1, 2});
C(aBCde4b4c, {0, 1, 2, 3, 4}, {4, 4}, {1, 2});
C(aBCde4c16b4c, {0, 1, 2, 3, 4}, {4, 16, 4}, {2, 1, 2});
C(Abcde8a, {0, 1, 2, 3, 4}, {8}, {0});
C(ABcde8a8b, {0, 1, 2, 3, 4}, {8, 8}, {0, 1});
C(aBcde8b, {0, 1, 2, 3, 4}, {8}, {1});
C(ABcde8b16a2b, {0, 1, 2, 3, 4}, {8, 16, 2}, {1, 0, 1});
C(aBCde8b16c2b, {0, 1, 2, 3, 4}, {8, 16, 2}, {1, 2, 1});
C(ABcde8b8a, {0, 1, 2, 3, 4}, {8, 8}, {1, 0});
C(aBCde8b8c, {0, 1, 2, 3, 4}, {8, 8}, {1, 2});
C(aBCde8c16b2c, {0, 1, 2, 3, 4}, {8, 16, 2}, {2, 1, 2});
C(aBCde8c8b, {0, 1, 2, 3, 4}, {8, 8}, {2, 1});
C(aBcdef16b, {0, 1, 2, 3, 4, 5}, {16}, {1});
C(aBCdef16b16c, {0, 1, 2, 3, 4, 5}, {16, 16}, {1, 2});
C(aBCdef16c16b, {0, 1, 2, 3, 4, 5}, {16, 16}, {2, 1});
C(aBCdef8b8c, {0, 1, 2, 3, 4, 5}, {8, 8}, {1, 2});
C(aBCdef8c16b2c, {0, 1, 2, 3, 4, 5}, {8, 16, 2}, {2, 1, 2});
C(aBCdef8c8b, {0, 1, 2, 3, 4, 5}, {8, 8}, {2, 1});
C(aBdc16b, {0, 1, 3, 2}, {16}, {1});
C(aBdc8b, {0, 1, 3, 2}, {8}, {1});
C(aBdec16b, {0, 1, 3, 4, 2}, {16}, {1});
C(aBdec8b, {0, 1, 3, 4, 2}, {8}, {1});
C(aBdefc16b, {0, 1, 3, 4, 5, 2}, {16}, {1});
C(aBdefc8b, {0, 1, 3, 4, 5, 2}, {8}, {1});
C(Acb16a, {0, 2, 1}, {16}, {0});
C(Acb8a, {0, 2, 1}, {8}, {0});
C(aCBd16b16c, {0, 2, 1, 3}, {16, 16}, {1, 2});
C(aCBde16b16c, {0, 2, 1, 3, 4}, {16, 16}, {1, 2});
C(Acdb16a, {0, 2, 3, 1}, {16}, {0});
C(Acdb8a, {0, 2, 3, 1}, {8}, {0});
C(Acdeb16a, {0, 2, 3, 4, 1}, {16}, {0});
C(Acdeb8a, {0, 2, 3, 4, 1}, {8}, {0});
C(BAc16a16b, {1, 0, 2}, {16, 16}, {0, 1});
C(BAcd16a16b, {1, 0, 2, 3}, {16, 16}, {0, 1});
default: break;
}
#undef C
return status::invalid_arguments;
}
}
}
// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s
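
The stride computation in `fill_blocked` above can be reproduced with a small standalone sketch. The names `BlockedLayout` and `sketch_fill_blocked` are hypothetical, and `std::vector` replaces the fixed-size `dims_t` arrays; only the arithmetic is meant to match.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical stand-in for the fields fill_blocked() writes.
struct BlockedLayout {
    std::vector<std::int64_t> padded_dims; // dims rounded up to the block size
    std::vector<std::int64_t> strides;     // stride of each dim, per block
};

// perm lists the dims from outermost to innermost; inner_blks/inner_idxs
// describe the inner blocks in the same order as the format tag.
inline BlockedLayout sketch_fill_blocked(
        const std::vector<std::int64_t> &dims, const std::vector<int> &perm,
        const std::vector<std::int64_t> &inner_blks,
        const std::vector<int> &inner_idxs) {
    assert(perm.size() == dims.size());
    assert(inner_blks.size() == inner_idxs.size());
    const std::size_t ndims = dims.size();

    // Overall block size per dim and the size of one full inner block.
    std::vector<std::int64_t> blocks(ndims, 1);
    std::int64_t block_size = 1;
    for (std::size_t i = 0; i < inner_blks.size(); ++i) {
        block_size *= inner_blks[i];
        blocks[inner_idxs[i]] *= inner_blks[i];
    }

    BlockedLayout out;
    out.padded_dims.resize(ndims);
    out.strides.resize(ndims);
    for (std::size_t d = 0; d < ndims; ++d)
        out.padded_dims[d] = (dims[d] + blocks[d] - 1) / blocks[d] * blocks[d];

    // Start from the total element count, then peel dims off in perm order.
    std::int64_t stride = block_size;
    for (std::size_t d = 0; d < ndims; ++d)
        stride *= out.padded_dims[d] == 0 ? 1 : out.padded_dims[d] / blocks[d];
    for (int d : perm) {
        if (out.padded_dims[d] == 0) { out.strides[d] = 1; continue; }
        stride /= out.padded_dims[d] / blocks[d];
        out.strides[d] = stride;
    }
    return out;
}
```

With dims {2, 3, 4, 5}, the `abcd` entry yields strides {60, 20, 5, 1} and the `acdb` entry yields {60, 1, 15, 3}. With dims {2, 20, 4, 5}, the `aBcd8b` entry pads the second dim to 24 and yields strides {480, 160, 40, 8}.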


@ -1,400 +0,0 @@
/*******************************************************************************
* Copyright 2016-2018 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/
#ifndef MEMORY_DESC_WRAPPER_HPP
#define MEMORY_DESC_WRAPPER_HPP
#include <assert.h>
#include "c_types_map.hpp"
#include "nstl.hpp"
#include "utils.hpp"
#include "type_helpers.hpp"
namespace mkldnn {
namespace impl {
/** thin wrapper class over \struct memory_desc_t which allows easy
* manipulations with underlying C structure, which is taken by reference */
struct memory_desc_wrapper: public c_compatible {
const memory_desc_t *md_;
/** constructor which takes a reference to a constant underlying C memory
* descriptor \param md */
memory_desc_wrapper(const memory_desc_t *md): md_(md) {}
memory_desc_wrapper(const memory_desc_t &md): memory_desc_wrapper(&md) {}
/* implementing attributes */
int ndims() const { return md_->ndims; }
const dims_t &dims() const { return md_->dims; }
data_type_t data_type() const { return md_->data_type; }
const dims_t &padded_dims() const { return md_->padded_dims; }
const dims_t &padded_offsets() const { return md_->padded_offsets; }
dim_t offset0() const { return md_->offset0; }
format_kind_t format_kind() const { return md_->format_kind; }
bool is_blocking_desc() const
{ return format_kind() == format_kind::blocked; }
bool is_wino_desc() const
{ return format_kind() == format_kind::wino; }
bool is_rnn_packed_desc() const
{ return format_kind() == format_kind::rnn_packed; }
const blocking_desc_t &blocking_desc() const {
assert(is_blocking_desc());
return md_->format_desc.blocking;
}
const wino_desc_t &wino_desc() const {
assert(is_wino_desc());
return md_->format_desc.wino_desc;
}
const rnn_packed_desc_t &rnn_packed_desc() const {
assert(is_rnn_packed_desc());
return md_->format_desc.rnn_packed_desc;
}
const memory_extra_desc_t &extra() const { return md_->extra; }
/* some useful functions */
/** returns the number of elements including padding if \param with_padding
* is true, and the number of data elements otherwise */
dim_t nelems(bool with_padding = false) const {
if (is_zero()) return 0;
return utils::array_product(
with_padding ? padded_dims() : dims(), ndims());
}
/** returns true if memory descriptor is zero */
bool is_zero() const { return ndims() == 0; }
/** returns true if memory descriptor contains zero as one of its dim */
bool has_zero_dim() const { return nelems() == 0; }
/** return the size of data type (a shortcut) */
size_t data_type_size() const
{ return types::data_type_size(data_type()); }
/** return the size of data type of additional buffer */
size_t additional_buffer_data_size() const {
if (extra().flags & memory_extra_flags::compensation_conv_s8s8)
return sizeof(int32_t);
return 0;
}
/** return true if memory format has additional buffer */
bool is_additional_buffer() const {
return (extra().flags & memory_extra_flags::compensation_conv_s8s8);
}
/** returns the size of additional buffer */
size_t additional_buffer_size() const {
if (extra().flags & memory_extra_flags::compensation_conv_s8s8) {
int cmask = extra().compensation_mask;
assert(cmask == 1 || cmask == 3);
dim_t prod = 1;
for (int d = 0; d < ndims(); ++d)
if (cmask & (1<<d)) prod *= padded_dims()[d];
return prod * additional_buffer_data_size();
}
return 0;
}
/** returns the size required to store described memory
* note: if offset0 != 0 returns 0 (need to specify the behavior) */
size_t size() const {
if (is_zero() || has_zero_dim() || format_kind() == format_kind::any)
return 0;
if (format_kind() == format_kind::wino) {
return wino_desc().size;
} else if (format_kind() == format_kind::rnn_packed) {
return rnn_packed_desc().size;
} else {
if (offset0() != 0) return 0;
dims_t blocks = {0};
compute_blocks(blocks);
const auto &bd = blocking_desc();
size_t max_size = 0;
for (int d = 0; d < ndims(); ++d)
max_size = nstl::max<size_t>(max_size,
padded_dims()[d] / blocks[d] * bd.strides[d]);
if (max_size == 1 && bd.inner_nblks != 0) {
max_size = utils::array_product(bd.inner_blks, bd.inner_nblks);
}
return max_size * data_type_size() + additional_buffer_size();
}
}
/** returns true if data is dense in memory */
bool is_dense(bool with_padding = false) const {
if (utils::one_of(format_kind(), format_kind::undef, format_kind::any))
return false;
return nelems(with_padding) * data_type_size() == size();
}
/** returns true if memory desc is fully defined */
bool is_defined() const { return format_kind() != format_kind::any; }
/** returns true if the only (potentially) padded dim is \param dim */
bool only_padded_dim(int dim) const {
for (int d = 0; d < ndims(); ++d)
if (d != dim && dims()[d] != padded_dims()[d])
return false;
return true;
}
/** returns true if memory desc has blocked layout and block dims are 1s */
bool is_plain() const {
if (!is_blocking_desc()) return false;
return blocking_desc().inner_nblks == 0;
}
/** returns overall block sizes */
void compute_blocks(dims_t blocks) const {
if (!is_blocking_desc()) {
utils::array_set(blocks, 0, ndims());
return;
}
utils::array_set(blocks, 1, ndims());
const auto &bd = blocking_desc();
for (int iblk = 0; iblk < bd.inner_nblks; ++iblk)
blocks[bd.inner_idxs[iblk]] *= bd.inner_blks[iblk];
}
/* comparison section */
bool operator==(const memory_desc_wrapper &rhs) const
{ return *this->md_ == *rhs.md_; }
bool operator!=(const memory_desc_wrapper &rhs) const
{ return !operator==(rhs); }
bool operator==(const memory_desc_t &rhs) const
{ return operator==(memory_desc_wrapper(rhs)); }
bool operator!=(const memory_desc_t &rhs) const
{ return !operator==(rhs); }
/** returns true if data (w/o padding if with_padding == false and w/
* padding otherwise) have the same physical structure, i.e. dimensions,
* strides, and blocked structure. Depending on with_data_type flag
* data_type is taken or not taken into account. dim_start allows checking
* similarity for the logical part of data [dim_start .. ndims()].
* CAUTION: format kinds any and undef are not similar to anything, hence the
* following statement might be true: lhs == rhs && !lhs.similar_to(rhs) */
/* TODO: revise */
bool similar_to(const memory_desc_wrapper &rhs,
bool with_padding = true, bool with_data_type = true,
int dim_start = 0) const;
/** returns true if one memory can be reordered to another */
bool consistent_with(const memory_desc_wrapper &rhs) const;
/** returns true if the memory desc corresponds to the given format tag and
* strides.
* @sa memory_desc_matches_tag */
bool matches_tag(format_tag_t tag, const dims_t strides = nullptr) const {
return memory_desc_matches_tag(*md_, tag, strides);
}
/** returns matching tag (or undef if match is not found)
* XXX: This is a workaround that eventually should go away! */
template <typename... Tags>
format_tag_t matches_one_of_tag(Tags ...tags) const {
for (const auto tag: {tags...}) {
if (memory_desc_matches_tag(*md_, tag))
return tag;
}
return format_tag::undef;
}
/* offset section */
/** returns physical offset by logical one. logical offset is represented by
* an array \param pos. if \param is_pos_padded is true \param pos
* represents the position in already padded area */
dim_t off_v(const dims_t pos, bool is_pos_padded = false) const {
assert(is_blocking_desc());
const blocking_desc_t &blk = blocking_desc();
dims_t pos_copy = {0};
for (int d = 0; d < ndims(); ++d)
pos_copy[d] = pos[d] + (is_pos_padded ? 0 : padded_offsets()[d]);
dim_t phys_offset = offset0();
if (blk.inner_nblks > 0) {
dim_t blk_stride = 1;
for (int iblk = blk.inner_nblks - 1; iblk >= 0; --iblk) {
const int d = blk.inner_idxs[iblk];
const dim_t p = pos_copy[d] % blk.inner_blks[iblk];
phys_offset += p * blk_stride;
pos_copy[d] /= blk.inner_blks[iblk];
blk_stride *= blk.inner_blks[iblk];
}
}
for (int d = 0; d < ndims(); ++d) {
const dim_t p = pos_copy[d];
phys_offset += p * blk.strides[d];
}
return phys_offset;
}
/** returns physical offset by logical one. logical offset is represented by
* a scalar \param l_offset. if \param is_pos_padded is true, \param
* l_offset represents logical offset in already padded area */
dim_t off_l(dim_t l_offset, bool is_pos_padded = false) const {
assert(is_blocking_desc());
dims_t pos;
for (int rd = 0; rd < ndims(); ++rd) {
const int d = ndims() - 1 - rd;
const dim_t cur_dim = is_pos_padded ? padded_dims()[d] : dims()[d];
pos[d] = l_offset % cur_dim;
l_offset /= cur_dim;
}
return off_v(pos, is_pos_padded);
}
/** returns physical offset by logical one. logical offset is represented by
* a tuple of indices (\param xn, ..., \param x1, \param x0) */
template<typename... Args>
dim_t off(Args... args) const {
assert(sizeof...(args) == ndims());
dims_t pos = { args... };
return off_v(pos, false);
}
/** returns physical offset by logical one. logical offset is represented by
* a tuple of indices (\param xn, ..., \param x1, \param x0) in already
* padded area */
template<typename... Args>
dim_t off_padding(Args... args) const {
assert(sizeof...(args) == ndims());
dims_t pos = { args... };
return off_v(pos, true);
}
/** returns physical offset by logical one. Logical offset is represented by
* a tuple of block indices (\param bn, ..., \param b1, \param b0). It is a
* user responsibility to adjust the result to get offset within blocks */
template<typename ...Args>
dim_t blk_off(Args... args) const {
return _blk_off<sizeof...(args), Args...>(args...);
}
template<bool skip_first, typename T, typename ...Args>
dim_t blk_off(T xn, Args... args) const {
return skip_first
? blk_off<Args...>(args...)
: blk_off<T, Args...>(xn, args...);
}
/* static functions section */
/* TODO: replace with non-static, once md_ becomes non-const ref */
static status_t compute_blocking(memory_desc_t &memory_desc,
format_tag_t tag);
private:
/* TODO: put logical_offset in utils */
template<typename T>
dim_t logical_offset(T x0) const { return x0; }
template<typename T, typename... Args>
dim_t logical_offset(T xn, Args... args) const {
const size_t n_args = sizeof...(args);
return xn * utils::array_product<n_args>(
&dims()[ndims() - n_args]) + logical_offset(args...);
}
template<int ORIG_LEN, typename ...Void>
dim_t _blk_off() const { return offset0(); }
template<int ORIG_LEN, typename T, typename ...Args>
dim_t _blk_off(T xc, Args ...args) const {
assert(is_blocking_desc());
constexpr int dc = ORIG_LEN - sizeof...(args) - 1;
return xc * blocking_desc().strides[dc]
+ _blk_off<ORIG_LEN, Args...>(args...);
}
};
inline bool memory_desc_wrapper::similar_to(const memory_desc_wrapper &rhs,
bool with_padding, bool with_data_type, int dim_start) const {
using namespace utils;
if (one_of(format_kind(), format_kind::undef, format_kind::any))
return false;
if (is_wino_desc() || is_rnn_packed_desc())
return false;
const int ds = dim_start;
const auto &blk = blocking_desc();
const auto &r_blk = rhs.blocking_desc();
return ndims() == rhs.ndims()
&& dim_start <= ndims() /* guard */
&& format_kind() == rhs.format_kind()
&& IMPLICATION(with_data_type, data_type() == rhs.data_type())
&& array_cmp(dims() + ds, rhs.dims() + ds, ndims() - ds)
&& array_cmp(blk.strides + ds, r_blk.strides + ds, ndims() - ds)
&& blk.inner_nblks == r_blk.inner_nblks
&& array_cmp(blk.inner_blks, r_blk.inner_blks, blk.inner_nblks)
&& array_cmp(blk.inner_idxs, r_blk.inner_idxs, blk.inner_nblks)
&& IMPLICATION(with_padding, true
&& array_cmp(padded_dims() + ds, rhs.padded_dims() + ds,
ndims() - ds)
&& array_cmp(padded_offsets() + ds, rhs.padded_offsets() + ds,
ndims() - ds));
}
inline bool memory_desc_wrapper::consistent_with(
const memory_desc_wrapper &rhs) const {
if (ndims() == rhs.ndims()) {
for (int d = 0; d < ndims(); ++d) {
if (dims()[d] != rhs.dims()[d]) return false;
}
return true;
} else {
/* TODO: revise.
* is the following possible?
* [1, a, b] <--reorder--> [a, b]
* [a, 1, b] <--reorder--> [a, b]
* not, at least for now */
return false;
}
}
}
}
#endif
// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s
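
The logical-to-physical mapping done by `off_v` above can be sketched as follows. `BlockedDesc` and `sketch_off_v` are hypothetical names, and the sketch assumes zero `padded_offsets`, so it corresponds to calling `off_v` with `is_pos_padded == true`.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical, simplified view of a blocked memory descriptor.
struct BlockedDesc {
    std::vector<std::int64_t> strides;    // per-dim strides, per block
    std::vector<std::int64_t> inner_blks; // inner block sizes, outermost first
    std::vector<int> inner_idxs;          // dim each inner block applies to
    std::int64_t offset0 = 0;             // offset of the first element
};

// Maps a logical position to a physical element offset. pos is taken by
// value because the inner-block pass divides it down to block indices.
inline std::int64_t sketch_off_v(
        const BlockedDesc &bd, std::vector<std::int64_t> pos) {
    assert(pos.size() == bd.strides.size());
    std::int64_t phys = bd.offset0;

    // Innermost block first: the remainder selects the element inside the
    // block, the quotient becomes the block index used by the outer strides.
    std::int64_t blk_stride = 1;
    for (int i = static_cast<int>(bd.inner_blks.size()) - 1; i >= 0; --i) {
        const int d = bd.inner_idxs[i];
        phys += (pos[d] % bd.inner_blks[i]) * blk_stride;
        pos[d] /= bd.inner_blks[i];
        blk_stride *= bd.inner_blks[i];
    }
    for (std::size_t d = 0; d < pos.size(); ++d)
        phys += pos[d] * bd.strides[d];
    return phys;
}
```

For an `aBcd8b`-style descriptor with strides {480, 160, 40, 8}, position {1, 10, 2, 3} lands in block 1 of the second dim at inner index 2, giving offset 480 + 160 + 80 + 24 + 2 = 746.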


@ -1,295 +0,0 @@
/*******************************************************************************
* Copyright 2018 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/
#ifndef MEMORY_TRACKING_HPP
#define MEMORY_TRACKING_HPP
#include <assert.h>
#include <unordered_map>
#include "nstl.hpp"
#include "utils.hpp"
namespace mkldnn {
namespace impl {
namespace memory_tracking {
/* Memory tracking capabilities
*
* The main purpose of this header file is to provide a uniform way to register
* required memory for a scratchpad at a primitive descriptor creation time
* and then easily access it having only the base address of the scratchpad.
*
* Primitives might contain multiple disjoint parts that require temporary
* buffers (known as scratchpad) during their execution. A primitive descriptor
* should summarize all the needs into one single number -- the buffer size
* that would be requested from a user. At execution time, the corresponding
* primitive will receive a base pointer to a scratchpad. It then needs to
* provide each part of algorithm the corresponding piece of memory. Three main
* challenges here are:
* 1. Track correct offset (from the base scratchpad address) for each piece
* 2. An algorithm might require different memory pieces to be aligned, so
* the scratchpad size is no longer just the sum of the sizes of the
* corresponding subparts.
* 3. While a primitive is responsible for its scratchpad, the implementation
* might use some other basic blocks (e.g. cpu_reducer) that also require
* scratchpad memory. So there should be a simple way of passing the
* information back and forth between the main algorithm (a primitive) and
* auxiliary code that lives completely separately from it (e.g. reducer).
*
* To address these challenges this header file provides 3 structures:
* 1. registry_t -- the class that stores the information about requested
* memory. The information includes required size and desired
* alignment for each piece. This class is also responsible
* for computing the right offset to a given piece using the
* base pointer.
* This class is basically a ledger with all entries.
* Lives in primitive descriptors.
*
* 2. registrar_t -- the interface to a registry_t to book memory. Used at
* primitive descriptor creation time only. Contains a
* reference to the corresponding *mutable* registry.
* Always modifiable.
* Allows chaining (using prefixes).
*
* 3. grantor_t -- the interface to a registry_t to access memory. Used at
* primitive execution time only. Contains a reference to
* the corresponding *constant* registry and base pointer.
* Always constant.
* Allows chaining (using prefixes).
*
* Both registrar_t and grantor_t allow chaining with extra prefix provided.
* The feature is useful when a primitive offloads a part of its computations
* to some other primitives which require their own scratchpad space
* (e.g. reducer). Prefixes are used to avoid key collisions in cases when
* multiple sub-primitives (e.g. multiple reducers) are used.
*
* A short example below demonstrates how to use the aforementioned classes.
* In it the main primitive is a convolution that uses the scratchpad for
* keeping padded bias. It also needs a reducer, which needs its own space.
*
* ``` c++
* struct reducer_t {
* static void init(registrar_t &scratchpad) {
* // preserve space for the reduction (one page aligned)
* scratchpad.book(key_reducer_space, sizeof(float) * 980 * 1024, 4096);
* }
*
* void exec(const grantor_t &scratchpad) {
* // get the pointer to preserved space. scratchpad came from
* // upper primitive (convolution in this example)
* auto space = scratchpad.get<float>(key_reducer_space);
*
* space[:] += ...;
* }
* };
*
* struct conv_t {
* struct pd_t {
* void init() {
* registrar_t scratchpad(scratchpad_registry_);
*
* // preserve a space for padded bias (using default alignment)
* scratchpad.book(key_conv_padded_bias, 128);
*
* // create a proxy registrar for the reducer. All entries made
* // by reducer would live in convolution's registry, but would
* // have their own `prefix`, so no interference with conv's
* // buffers.
* registrar_t reducer_scratchpad(scratchpad, prefix_reducer);
*
* reducer_t::init(reducer_scratchpad);
* }
*
* registry_t scratchpad_registry_;
* };
*
* void exec() {
* // get the base pointer to a scratchpad memory from a user
* void *scratchpad_ptr = this->input(MKLDNN_MEM_SCRATCHPAD);
*
* // create a grantor to the scratchpad (and provide the base
* // pointer).
* grantor_t scratchpad(pd()->scratchpad_registry_, scratchpad_ptr);
*
* // access the padded_bias (need only key name and the grantor)
* auto padded_bias = scratchpad.get<float>(key_conv_padded_bias);
*
* // to give the `right` grantor to reducer we need to add the
* // corresponding prefix, so that reducer would be able to access
* // its keys. The call is very similar to the one in pd_t::init
* // with only difference in types: grantor_t vs registrar_t.
* grantor_t reducer_scratchpad(scratchpad, prefix_reducer);
* reducer->exec(reducer_scratchpad);
* }
* };
* ```
*/
/* namespace with common keys and prefixes */
namespace names {
enum {
key_none = 0,
key_bnorm_tmp_mean,
key_bnorm_tmp_var,
key_bnorm_tmp_diff_ss,
key_bnorm_tmp_stats,
key_bnorm_reduction,
key_concat_iptrs,
key_concat_istrides,
key_concat_nelems,
key_concat_optrs,
key_conv_adjusted_scales,
key_conv_bia_reduction,
key_conv_gemm_col,
key_conv_gemm_imtr,
key_conv_int_dat_in_acc_dt,
key_conv_padded_bias,
key_conv_rtus_space,
key_conv_tr_diff_dst,
key_conv_tr_diff_dst_bctx,
key_conv_tr_src,
key_conv_tr_src_bctx,
key_conv_wei_reduction,
key_conv_wei_bia_reduction,
key_conv_wei_bia_reduction_bctx,
key_iprod_int_dat_in_acc_dt,
key_reducer_space,
key_reducer_space_bctx,
key_reorder_wino_plain,
key_reorder_wino_transform_space,
key_reorder_rnn_weights_quantization,
key_reorder_rnn_weights_reduction,
key_rnn_space,
key_rnn_ptrs_bia,
key_rnn_ptrs_wei_layer,
key_rnn_ptrs_wei_iter,
key_softmax_reduction,
key_wino_U,
key_wino_V,
key_wino_M,
key_barrier,
};
enum {
prefix_none = 0,
prefix_reducer_bia,
prefix_reducer_wei,
};
}
// level 0: 00 00 00 xxx
// level 1: 00 00 aa xxx
// level 2: 00 aa bb xxx
// level 3: aa bb cc xxx
// max # of levels: 3 + 1 (base_level)
// here:
// xxx : [1 .. MAX_KEY) : key
// aa, bb, cc : [1 .. MAX_PREFIX) : prefixes for levels 1, 2, and 3
using key_t = uint32_t;
enum { MAX_KEY = (1u << 10), MAX_PREFIX = (1u << 7), };
/// generates global key based on a prefix and a local key
inline key_t make_key(key_t prefix, key_t key) { return prefix + key; }
/// generates global prefix based on the global parent and the local ones
inline key_t make_prefix(key_t parent_prefix, key_t prefix)
{ return MAX_PREFIX * parent_prefix + MAX_KEY * prefix; }
struct registrar_t;
struct grantor_t;
struct registry_t {
void book(const key_t &key, size_t size, size_t alignment) {
if (size == 0) return;
assert(offset_map_.count(key) == 0);
size = utils::rnd_up(size, minimal_alignment);
alignment = nstl::max<size_t>(alignment, minimal_alignment);
offset_map_[key] = entry_t{size_, size, alignment};
size_ += size + alignment - minimal_alignment;
}
void *get(const key_t &key, void *base_ptr) const {
if (base_ptr == nullptr) { assert(size() == 0); return nullptr; }
if (offset_map_.count(key) != 1) return nullptr;
const auto &e = offset_map_.at(key);
base_ptr = utils::align_ptr<void>(base_ptr, minimal_alignment);
char *ptr = (char *)base_ptr + e.offset;
return utils::align_ptr<void>(ptr, e.alignment);
}
size_t size() const
{ return size_ > 0 ? size_ + minimal_alignment - 1 : 0; }
registrar_t registrar();
grantor_t grantor(void *base_ptr) const;
protected:
enum { minimal_alignment = 64 };
struct entry_t { size_t offset, size, alignment; };
std::unordered_map<key_t, entry_t> offset_map_;
size_t size_ = 0;
};
struct registrar_t {
enum { default_alignment = 64 };
registrar_t(registry_t &registry): registry_(registry), prefix_(0) {}
registrar_t(registrar_t &parent, const key_t &prefix)
: registry_(parent.registry_)
, prefix_(make_prefix(parent.prefix_, prefix)) {}
void book(const key_t &key, size_t size,
size_t alignment = default_alignment)
{ registry_.book(make_key(prefix_, key), size, alignment); }
protected:
registry_t &registry_;
const key_t prefix_;
};
struct grantor_t {
grantor_t(const registry_t &registry, void *base_ptr)
: registry_(registry), prefix_(0), base_ptr_(base_ptr) {}
grantor_t(const grantor_t &parent, const key_t &prefix)
: registry_(parent.registry_)
, prefix_(make_prefix(parent.prefix_, prefix))
, base_ptr_(parent.base_ptr_) {}
template <typename T = void> T *get(const key_t &key) const
{ return (T *)registry_.get(make_key(prefix_, key), base_ptr_); }
protected:
const registry_t &registry_;
const key_t prefix_;
void *base_ptr_;
};
inline registrar_t registry_t::registrar() { return registrar_t(*this); }
inline grantor_t registry_t::grantor(void *base_ptr) const
{ return grantor_t(*this, base_ptr); }
}
}
}
#endif

View file

@ -1,131 +0,0 @@
/*******************************************************************************
* Copyright 2019 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/
#include <assert.h>
#include <stdio.h>
#include <cinttypes>
#include "mkldnn_debug.h"
#include "mkldnn_types.h"
#include "c_types_map.hpp"
#include "type_helpers.hpp"
#include "utils.hpp"
#define DPRINT(...) do { \
int l = snprintf(str + written_len, str_len, __VA_ARGS__); \
if (l < 0) return l; \
if ((size_t)l >= str_len) return -1; \
written_len += l; str_len -= l; \
} while(0)
int mkldnn_md2fmt_str(char *str, size_t str_len,
const mkldnn_memory_desc_t *mdesc) {
using namespace mkldnn::impl;
if (str == nullptr || str_len <= 1u)
return -1;
int written_len = 0;
if (mdesc == nullptr) {
DPRINT("%s::%s::",
mkldnn_dt2str(data_type::undef),
mkldnn_fmt_kind2str(format_kind::undef));
return written_len;
}
memory_desc_wrapper md(mdesc);
DPRINT("%s:", mkldnn_dt2str(md.data_type()));
bool padded_dims = false, padded_offsets = false;
for (int d = 0; d < md.ndims(); ++d) {
if (md.dims()[d] != md.padded_dims()[d]) padded_dims = true;
if (md.padded_offsets()[d] != 0) padded_offsets = true;
}
bool offset0 = md.offset0();
DPRINT("%s%s%s:",
padded_dims ? "p" : "",
padded_offsets ? "o" : "",
offset0 ? "0" : "");
DPRINT("%s:", mkldnn_fmt_kind2str(md.format_kind()));
if (!md.is_blocking_desc()) {
/* TODO: extend */
DPRINT("%s:", "");
} else {
const auto &blk = md.blocking_desc();
dims_t blocks;
md.compute_blocks(blocks);
char dim_chars[MKLDNN_MAX_NDIMS + 1];
bool plain = true;
for (int d = 0; d < md.ndims(); ++d) {
dim_chars[d] = (blocks[d] == 1 ? 'a' : 'A') + (char)d;
if (blocks[d] != 1) plain = false;
}
dims_t strides;
utils::array_copy(strides, blk.strides, md.ndims());
utils::simultaneous_sort(strides, dim_chars, md.ndims(),
[](dim_t a, dim_t b) { return b - a; });
dim_chars[md.ndims()] = '\0';
DPRINT("%s", dim_chars);
if (!plain) {
for (int iblk = 0; iblk < blk.inner_nblks; ++iblk) {
DPRINT("%d%c", (int)blk.inner_blks[iblk],
'a' + (char)blk.inner_idxs[iblk]);
}
}
DPRINT("%s", ":");
}
DPRINT("f%lx", (unsigned long)md.extra().flags);
return written_len;
}
int mkldnn_md2dim_str(char *str, size_t str_len,
const mkldnn_memory_desc_t *mdesc) {
using namespace mkldnn::impl;
if (str == nullptr || str_len <= 1)
return -1;
int written_len = 0;
if (mdesc == nullptr || mdesc->ndims == 0) {
DPRINT("%s", "");
return written_len;
}
memory_desc_wrapper md(mdesc);
for (int d = 0; d < md.ndims() - 1; ++d)
DPRINT("%" PRId64 "x", md.dims()[d]);
DPRINT("%" PRId64, md.dims()[md.ndims() - 1]);
return written_len;
}
#undef DPRINT
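Note on the DPRINT macro above: it appends formatted text to a fixed-size buffer and returns early on an encoding error or truncation. A minimal standalone sketch of the same bookkeeping follows. The names `bounded_writer` and `append` are hypothetical and are not part of mkl-dnn; the sketch reports failure with a `bool` instead of returning from the enclosing function.

```cpp
#include <cassert>
#include <cstdarg>
#include <cstddef>
#include <cstdio>
#include <cstring>

// Hypothetical helper, not part of mkl-dnn: tracks the bytes written so far
// and the space still available in a caller-owned buffer.
struct bounded_writer {
    char *buf;
    std::size_t remaining;
    int written;
};

// Appends formatted text. Returns false on an encoding error (l < 0) or when
// the output would not fit, which are the two cases where DPRINT bails out.
inline bool append(bounded_writer &w, const char *fmt, ...) {
    va_list args;
    va_start(args, fmt);
    const int l = std::vsnprintf(w.buf + w.written, w.remaining, fmt, args);
    va_end(args);
    if (l < 0) return false;
    if ((std::size_t)l >= w.remaining) return false;
    w.written += l;
    w.remaining -= (std::size_t)l;
    return true;
}
```

A caller would create `bounded_writer w{buf, sizeof(buf), 0}` and stop at the first `append` that returns false, the same way the functions above stop at the first failing DPRINT.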

View file

@ -1,365 +0,0 @@
/*******************************************************************************
* Copyright 2018-2019 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/
/* DO NOT EDIT, AUTO-GENERATED */
#include <assert.h>
#include "mkldnn_debug.h"
#include "mkldnn_types.h"
const char *mkldnn_status2str(mkldnn_status_t v) {
if (v == mkldnn_success) return "success";
if (v == mkldnn_out_of_memory) return "out_of_memory";
if (v == mkldnn_try_again) return "try_again";
if (v == mkldnn_invalid_arguments) return "invalid_arguments";
if (v == mkldnn_not_ready) return "not_ready";
if (v == mkldnn_unimplemented) return "unimplemented";
if (v == mkldnn_iterator_ends) return "iterator_ends";
if (v == mkldnn_runtime_error) return "runtime_error";
if (v == mkldnn_not_required) return "not_required";
assert(!"unknown status");
return "unknown status";
}
const char *mkldnn_dt2str(mkldnn_data_type_t v) {
if (v == mkldnn_data_type_undef) return "undef";
if (v == mkldnn_f32) return "f32";
if (v == mkldnn_s32) return "s32";
if (v == mkldnn_s8) return "s8";
if (v == mkldnn_u8) return "u8";
assert(!"unknown dt");
return "unknown dt";
}
const char *mkldnn_fmt_kind2str(mkldnn_format_kind_t v) {
if (v == mkldnn_format_kind_undef) return "undef";
if (v == mkldnn_format_kind_any) return "any";
if (v == mkldnn_blocked) return "blocked";
if (v == mkldnn_format_kind_wino) return "wino";
if (v == mkldnn_format_kind_rnn_packed) return "rnn_packed";
assert(!"unknown fmt_kind");
return "unknown fmt_kind";
}
const char *mkldnn_fmt_tag2str(mkldnn_format_tag_t v) {
if (v == mkldnn_format_tag_undef) return "undef";
if (v == mkldnn_format_tag_any) return "format_tag_any";
if (v == mkldnn_a) return "a";
if (v == mkldnn_ab) return "ab";
if (v == mkldnn_abc) return "abc";
if (v == mkldnn_abcd) return "abcd";
if (v == mkldnn_abcde) return "abcde";
if (v == mkldnn_abcdef) return "abcdef";
if (v == mkldnn_abdec) return "abdec";
if (v == mkldnn_acb) return "acb";
if (v == mkldnn_acbde) return "acbde";
if (v == mkldnn_acdb) return "acdb";
if (v == mkldnn_acdeb) return "acdeb";
if (v == mkldnn_ba) return "ba";
if (v == mkldnn_bac) return "bac";
if (v == mkldnn_bacd) return "bacd";
if (v == mkldnn_bcda) return "bcda";
if (v == mkldnn_cba) return "cba";
if (v == mkldnn_cdba) return "cdba";
if (v == mkldnn_cdeba) return "cdeba";
if (v == mkldnn_decab) return "decab";
if (v == mkldnn_Abc16a) return "Abc16a";
if (v == mkldnn_ABc16a16b) return "ABc16a16b";
if (v == mkldnn_aBc16b) return "aBc16b";
if (v == mkldnn_ABc16b16a) return "ABc16b16a";
if (v == mkldnn_Abc4a) return "Abc4a";
if (v == mkldnn_aBc4b) return "aBc4b";
if (v == mkldnn_ABc4b16a4b) return "ABc4b16a4b";
if (v == mkldnn_ABc4b4a) return "ABc4b4a";
if (v == mkldnn_ABc8a16b2a) return "ABc8a16b2a";
if (v == mkldnn_ABc8a8b) return "ABc8a8b";
if (v == mkldnn_aBc8b) return "aBc8b";
if (v == mkldnn_ABc8b16a2b) return "ABc8b16a2b";
if (v == mkldnn_ABc8b8a) return "ABc8b8a";
if (v == mkldnn_Abcd16a) return "Abcd16a";
if (v == mkldnn_ABcd16a16b) return "ABcd16a16b";
if (v == mkldnn_aBcd16b) return "aBcd16b";
if (v == mkldnn_ABcd16b16a) return "ABcd16b16a";
if (v == mkldnn_aBCd16b16c) return "aBCd16b16c";
if (v == mkldnn_aBCd16c16b) return "aBCd16c16b";
if (v == mkldnn_Abcd4a) return "Abcd4a";
if (v == mkldnn_aBcd4b) return "aBcd4b";
if (v == mkldnn_ABcd4b16a4b) return "ABcd4b16a4b";
if (v == mkldnn_ABcd4b4a) return "ABcd4b4a";
if (v == mkldnn_aBCd4c16b4c) return "aBCd4c16b4c";
if (v == mkldnn_aBCd4c4b) return "aBCd4c4b";
if (v == mkldnn_ABcd8a16b2a) return "ABcd8a16b2a";
if (v == mkldnn_ABcd8a8b) return "ABcd8a8b";
if (v == mkldnn_aBcd8b) return "aBcd8b";
if (v == mkldnn_ABcd8b16a2b) return "ABcd8b16a2b";
if (v == mkldnn_aBCd8b16c2b) return "aBCd8b16c2b";
if (v == mkldnn_ABcd8b8a) return "ABcd8b8a";
if (v == mkldnn_aBCd8b8c) return "aBCd8b8c";
if (v == mkldnn_aBCd8c16b2c) return "aBCd8c16b2c";
if (v == mkldnn_aBCd8c8b) return "aBCd8c8b";
if (v == mkldnn_Abcde16a) return "Abcde16a";
if (v == mkldnn_ABcde16a16b) return "ABcde16a16b";
if (v == mkldnn_aBcde16b) return "aBcde16b";
if (v == mkldnn_ABcde16b16a) return "ABcde16b16a";
if (v == mkldnn_aBCde16b16c) return "aBCde16b16c";
if (v == mkldnn_aBCde16c16b) return "aBCde16c16b";
if (v == mkldnn_aBCde2c8b4c) return "aBCde2c8b4c";
if (v == mkldnn_Abcde4a) return "Abcde4a";
if (v == mkldnn_aBcde4b) return "aBcde4b";
if (v == mkldnn_ABcde4b4a) return "ABcde4b4a";
if (v == mkldnn_aBCde4b4c) return "aBCde4b4c";
if (v == mkldnn_aBCde4c16b4c) return "aBCde4c16b4c";
if (v == mkldnn_aBCde4c4b) return "aBCde4c4b";
if (v == mkldnn_Abcde8a) return "Abcde8a";
if (v == mkldnn_ABcde8a8b) return "ABcde8a8b";
if (v == mkldnn_ABcde8b16a2b) return "ABcde8b16a2b";
if (v == mkldnn_aBCde8b16c2b) return "aBCde8b16c2b";
if (v == mkldnn_ABcde8b8a) return "ABcde8b8a";
if (v == mkldnn_aBCde8b8c) return "aBCde8b8c";
if (v == mkldnn_aBCde8c16b2c) return "aBCde8c16b2c";
if (v == mkldnn_aBCde8c8b) return "aBCde8c8b";
if (v == mkldnn_aBcdef16b) return "aBcdef16b";
if (v == mkldnn_aBCdef16b16c) return "aBCdef16b16c";
if (v == mkldnn_aBCdef16c16b) return "aBCdef16c16b";
if (v == mkldnn_aBcdef4b) return "aBcdef4b";
if (v == mkldnn_aBCdef4c4b) return "aBCdef4c4b";
if (v == mkldnn_aBCdef8b8c) return "aBCdef8b8c";
if (v == mkldnn_aBCdef8c16b2c) return "aBCdef8c16b2c";
if (v == mkldnn_aBCdef8c8b) return "aBCdef8c8b";
if (v == mkldnn_aBdc16b) return "aBdc16b";
if (v == mkldnn_aBdc4b) return "aBdc4b";
if (v == mkldnn_aBdc8b) return "aBdc8b";
if (v == mkldnn_aBdec16b) return "aBdec16b";
if (v == mkldnn_aBdec4b) return "aBdec4b";
if (v == mkldnn_aBdec8b) return "aBdec8b";
if (v == mkldnn_aBdefc16b) return "aBdefc16b";
if (v == mkldnn_aBdefc4b) return "aBdefc4b";
if (v == mkldnn_aBdefc8b) return "aBdefc8b";
if (v == mkldnn_Acb16a) return "Acb16a";
if (v == mkldnn_Acb4a) return "Acb4a";
if (v == mkldnn_Acb8a) return "Acb8a";
if (v == mkldnn_aCBd16b16c) return "aCBd16b16c";
if (v == mkldnn_aCBde16b16c) return "aCBde16b16c";
if (v == mkldnn_Acdb16a) return "Acdb16a";
if (v == mkldnn_Acdb4a) return "Acdb4a";
if (v == mkldnn_Acdb8a) return "Acdb8a";
if (v == mkldnn_Acdeb16a) return "Acdeb16a";
if (v == mkldnn_Acdeb4a) return "Acdeb4a";
if (v == mkldnn_Acdeb8a) return "Acdeb8a";
if (v == mkldnn_BAc16a16b) return "BAc16a16b";
if (v == mkldnn_BAcd16a16b) return "BAcd16a16b";
if (v == mkldnn_format_tag_last) return "format_tag_last";
if (v == mkldnn_x) return "x";
if (v == mkldnn_nc) return "nc";
if (v == mkldnn_cn) return "cn";
if (v == mkldnn_ncw) return "ncw";
if (v == mkldnn_nwc) return "nwc";
if (v == mkldnn_nchw) return "nchw";
if (v == mkldnn_nhwc) return "nhwc";
if (v == mkldnn_chwn) return "chwn";
if (v == mkldnn_ncdhw) return "ncdhw";
if (v == mkldnn_ndhwc) return "ndhwc";
if (v == mkldnn_oi) return "oi";
if (v == mkldnn_io) return "io";
if (v == mkldnn_oiw) return "oiw";
if (v == mkldnn_wio) return "wio";
if (v == mkldnn_oihw) return "oihw";
if (v == mkldnn_hwio) return "hwio";
if (v == mkldnn_ihwo) return "ihwo";
if (v == mkldnn_iohw) return "iohw";
if (v == mkldnn_oidhw) return "oidhw";
if (v == mkldnn_dhwio) return "dhwio";
if (v == mkldnn_goiw) return "goiw";
if (v == mkldnn_goihw) return "goihw";
if (v == mkldnn_hwigo) return "hwigo";
if (v == mkldnn_giohw) return "giohw";
if (v == mkldnn_goidhw) return "goidhw";
if (v == mkldnn_tnc) return "tnc";
if (v == mkldnn_ntc) return "ntc";
if (v == mkldnn_ldsnc) return "ldsnc";
if (v == mkldnn_ldigo) return "ldigo";
if (v == mkldnn_ldgoi) return "ldgoi";
if (v == mkldnn_ldgo) return "ldgo";
if (v == mkldnn_nCdhw16c) return "nCdhw16c";
if (v == mkldnn_nCdhw4c) return "nCdhw4c";
if (v == mkldnn_nCdhw8c) return "nCdhw8c";
if (v == mkldnn_nChw16c) return "nChw16c";
if (v == mkldnn_nChw4c) return "nChw4c";
if (v == mkldnn_nChw8c) return "nChw8c";
if (v == mkldnn_nCw16c) return "nCw16c";
if (v == mkldnn_nCw4c) return "nCw4c";
if (v == mkldnn_nCw8c) return "nCw8c";
if (v == mkldnn_IOw16o16i) return "IOw16o16i";
if (v == mkldnn_OIw16i16o) return "OIw16i16o";
if (v == mkldnn_OIw16o16i) return "OIw16o16i";
if (v == mkldnn_Oiw16o) return "Oiw16o";
if (v == mkldnn_OIw4i16o4i) return "OIw4i16o4i";
if (v == mkldnn_OIw4i4o) return "OIw4i4o";
if (v == mkldnn_Oiw4o) return "Oiw4o";
if (v == mkldnn_OIw8i16o2i) return "OIw8i16o2i";
if (v == mkldnn_OIw8i8o) return "OIw8i8o";
if (v == mkldnn_OIw8o16i2o) return "OIw8o16i2o";
if (v == mkldnn_OIw8o8i) return "OIw8o8i";
if (v == mkldnn_Owi16o) return "Owi16o";
if (v == mkldnn_Owi4o) return "Owi4o";
if (v == mkldnn_Owi8o) return "Owi8o";
if (v == mkldnn_IOhw16o16i) return "IOhw16o16i";
if (v == mkldnn_Ohwi16o) return "Ohwi16o";
if (v == mkldnn_Ohwi4o) return "Ohwi4o";
if (v == mkldnn_Ohwi8o) return "Ohwi8o";
if (v == mkldnn_OIhw16i16o) return "OIhw16i16o";
if (v == mkldnn_OIhw16o16i) return "OIhw16o16i";
if (v == mkldnn_Oihw16o) return "Oihw16o";
if (v == mkldnn_OIhw4i16o4i) return "OIhw4i16o4i";
if (v == mkldnn_OIhw4i4o) return "OIhw4i4o";
if (v == mkldnn_Oihw4o) return "Oihw4o";
if (v == mkldnn_OIhw8i16o2i) return "OIhw8i16o2i";
if (v == mkldnn_OIhw8i8o) return "OIhw8i8o";
if (v == mkldnn_OIhw8o16i2o) return "OIhw8o16i2o";
if (v == mkldnn_OIhw8o8i) return "OIhw8o8i";
if (v == mkldnn_Odhwi16o) return "Odhwi16o";
if (v == mkldnn_Odhwi4o) return "Odhwi4o";
if (v == mkldnn_Odhwi8o) return "Odhwi8o";
if (v == mkldnn_OIdhw16i16o) return "OIdhw16i16o";
if (v == mkldnn_OIdhw16o16i) return "OIdhw16o16i";
if (v == mkldnn_Oidhw16o) return "Oidhw16o";
if (v == mkldnn_OIdhw4i4o) return "OIdhw4i4o";
if (v == mkldnn_Oidhw4o) return "Oidhw4o";
if (v == mkldnn_OIdhw8i16o2i) return "OIdhw8i16o2i";
if (v == mkldnn_OIdhw8i8o) return "OIdhw8i8o";
if (v == mkldnn_OIdhw8o8i) return "OIdhw8o8i";
if (v == mkldnn_Goiw16g) return "Goiw16g";
if (v == mkldnn_gIOw16o16i) return "gIOw16o16i";
if (v == mkldnn_gOIw16i16o) return "gOIw16i16o";
if (v == mkldnn_gOIw16o16i) return "gOIw16o16i";
if (v == mkldnn_gOiw16o) return "gOiw16o";
if (v == mkldnn_gOIw4i16o4i) return "gOIw4i16o4i";
if (v == mkldnn_gOIw4i4o) return "gOIw4i4o";
if (v == mkldnn_gOiw4o) return "gOiw4o";
if (v == mkldnn_gOIw8i16o2i) return "gOIw8i16o2i";
if (v == mkldnn_gOIw8i8o) return "gOIw8i8o";
if (v == mkldnn_gOIw8o16i2o) return "gOIw8o16i2o";
if (v == mkldnn_gOIw8o8i) return "gOIw8o8i";
if (v == mkldnn_gOwi16o) return "gOwi16o";
if (v == mkldnn_gOwi4o) return "gOwi4o";
if (v == mkldnn_gOwi8o) return "gOwi8o";
if (v == mkldnn_gIOhw16o16i) return "gIOhw16o16i";
if (v == mkldnn_gOhwi16o) return "gOhwi16o";
if (v == mkldnn_gOhwi4o) return "gOhwi4o";
if (v == mkldnn_gOhwi8o) return "gOhwi8o";
if (v == mkldnn_Goihw16g) return "Goihw16g";
if (v == mkldnn_gOIhw16i16o) return "gOIhw16i16o";
if (v == mkldnn_gOIhw16o16i) return "gOIhw16o16i";
if (v == mkldnn_gOihw16o) return "gOihw16o";
if (v == mkldnn_gOIhw2i8o4i) return "gOIhw2i8o4i";
if (v == mkldnn_gOIhw4i16o4i) return "gOIhw4i16o4i";
if (v == mkldnn_gOIhw4i4o) return "gOIhw4i4o";
if (v == mkldnn_gOIhw4o4i) return "gOIhw4o4i";
if (v == mkldnn_gOihw4o) return "gOihw4o";
if (v == mkldnn_Goihw8g) return "Goihw8g";
if (v == mkldnn_gOIhw8i16o2i) return "gOIhw8i16o2i";
if (v == mkldnn_gOIhw8i8o) return "gOIhw8i8o";
if (v == mkldnn_gOIhw8o16i2o) return "gOIhw8o16i2o";
if (v == mkldnn_gOIhw8o8i) return "gOIhw8o8i";
if (v == mkldnn_gOdhwi16o) return "gOdhwi16o";
if (v == mkldnn_gOdhwi4o) return "gOdhwi4o";
if (v == mkldnn_gOdhwi8o) return "gOdhwi8o";
if (v == mkldnn_gOIdhw16i16o) return "gOIdhw16i16o";
if (v == mkldnn_gOIdhw16o16i) return "gOIdhw16o16i";
if (v == mkldnn_gOidhw16o) return "gOidhw16o";
if (v == mkldnn_gOIdhw4i4o) return "gOIdhw4i4o";
if (v == mkldnn_gOidhw4o) return "gOidhw4o";
if (v == mkldnn_gOIdhw8i16o2i) return "gOIdhw8i16o2i";
if (v == mkldnn_gOIdhw8i8o) return "gOIdhw8i8o";
if (v == mkldnn_gOIdhw8o8i) return "gOIdhw8o8i";
assert(!"unknown fmt_tag");
return "unknown fmt_tag";
}
const char *mkldnn_prop_kind2str(mkldnn_prop_kind_t v) {
if (v == mkldnn_prop_kind_undef) return "undef";
if (v == mkldnn_forward_training) return "forward_training";
if (v == mkldnn_forward_inference) return "forward_inference";
if (v == mkldnn_forward_scoring) return "forward_scoring";
if (v == mkldnn_forward) return "forward";
if (v == mkldnn_backward) return "backward";
if (v == mkldnn_backward_data) return "backward_data";
if (v == mkldnn_backward_weights) return "backward_weights";
if (v == mkldnn_backward_bias) return "backward_bias";
assert(!"unknown prop_kind");
return "unknown prop_kind";
}
const char *mkldnn_prim_kind2str(mkldnn_primitive_kind_t v) {
if (v == mkldnn_undefined_primitive) return "undef";
if (v == mkldnn_reorder) return "reorder";
if (v == mkldnn_shuffle) return "shuffle";
if (v == mkldnn_concat) return "concat";
if (v == mkldnn_sum) return "sum";
if (v == mkldnn_convolution) return "convolution";
if (v == mkldnn_deconvolution) return "deconvolution";
if (v == mkldnn_eltwise) return "eltwise";
if (v == mkldnn_softmax) return "softmax";
if (v == mkldnn_pooling) return "pooling";
if (v == mkldnn_lrn) return "lrn";
if (v == mkldnn_batch_normalization) return "batch_normalization";
if (v == mkldnn_inner_product) return "inner_product";
if (v == mkldnn_rnn) return "rnn";
assert(!"unknown prim_kind");
return "unknown prim_kind";
}
const char *mkldnn_alg_kind2str(mkldnn_alg_kind_t v) {
if (v == mkldnn_alg_kind_undef) return "undef";
if (v == mkldnn_convolution_direct) return "convolution_direct";
if (v == mkldnn_convolution_winograd) return "convolution_winograd";
if (v == mkldnn_convolution_auto) return "convolution_auto";
if (v == mkldnn_deconvolution_direct) return "deconvolution_direct";
if (v == mkldnn_deconvolution_winograd) return "deconvolution_winograd";
if (v == mkldnn_eltwise_relu) return "eltwise_relu";
if (v == mkldnn_eltwise_tanh) return "eltwise_tanh";
if (v == mkldnn_eltwise_elu) return "eltwise_elu";
if (v == mkldnn_eltwise_square) return "eltwise_square";
if (v == mkldnn_eltwise_abs) return "eltwise_abs";
if (v == mkldnn_eltwise_sqrt) return "eltwise_sqrt";
if (v == mkldnn_eltwise_linear) return "eltwise_linear";
if (v == mkldnn_eltwise_bounded_relu) return "eltwise_bounded_relu";
if (v == mkldnn_eltwise_soft_relu) return "eltwise_soft_relu";
if (v == mkldnn_eltwise_logistic) return "eltwise_logistic";
if (v == mkldnn_pooling_max) return "pooling_max";
if (v == mkldnn_pooling_avg_include_padding) return "pooling_avg_include_padding";
if (v == mkldnn_pooling_avg_exclude_padding) return "pooling_avg_exclude_padding";
if (v == mkldnn_pooling_avg) return "pooling_avg";
if (v == mkldnn_lrn_across_channels) return "lrn_across_channels";
if (v == mkldnn_lrn_within_channel) return "lrn_within_channel";
if (v == mkldnn_vanilla_rnn) return "vanilla_rnn";
if (v == mkldnn_vanilla_lstm) return "vanilla_lstm";
if (v == mkldnn_vanilla_gru) return "vanilla_gru";
if (v == mkldnn_gru_linear_before_reset) return "gru_linear_before_reset";
assert(!"unknown alg_kind");
return "unknown alg_kind";
}
const char *mkldnn_rnn_direction2str(mkldnn_rnn_direction_t v) {
if (v == mkldnn_unidirectional_left2right) return "unidirectional_left2right";
if (v == mkldnn_unidirectional_right2left) return "unidirectional_right2left";
if (v == mkldnn_bidirectional_concat) return "bidirectional_concat";
if (v == mkldnn_bidirectional_sum) return "bidirectional_sum";
if (v == mkldnn_unidirectional) return "unidirectional";
assert(!"unknown rnn_direction");
return "unknown rnn_direction";
}

View file

@ -1,115 +0,0 @@
/*******************************************************************************
* Copyright 2017-2018 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/
#ifndef MKLDNN_THREAD_HPP
#define MKLDNN_THREAD_HPP
#include "utils.hpp"
#include "z_magic.hpp"
#define MKLDNN_THR_SEQ 0
#define MKLDNN_THR_OMP 1
#define MKLDNN_THR_TBB 2
/* Ideally the condition below should never hold (if the library is built
* using regular cmake). For 3rd-party projects that build the library from
* the sources on their own, try to guess the right threading... */
#if !defined(MKLDNN_THR)
# define MKLDNN_THR MKLDNN_THR_TBB
#endif
#if MKLDNN_THR == MKLDNN_THR_SEQ
#define MKLDNN_THR_SYNC 1
inline int mkldnn_get_max_threads() { return 1; }
inline int mkldnn_get_num_threads() { return 1; }
inline int mkldnn_get_thread_num() { return 0; }
inline int mkldnn_in_parallel() { return 0; }
inline void mkldnn_thr_barrier() {}
#define PRAGMA_OMP(...)
#elif MKLDNN_THR == MKLDNN_THR_OMP
#include <omp.h>
#define MKLDNN_THR_SYNC 1
inline int mkldnn_get_max_threads() { return omp_get_max_threads(); }
inline int mkldnn_get_num_threads() { return omp_get_num_threads(); }
inline int mkldnn_get_thread_num() { return omp_get_thread_num(); }
inline int mkldnn_in_parallel() { return omp_in_parallel(); }
inline void mkldnn_thr_barrier() {
# pragma omp barrier
}
#define PRAGMA_OMP(...) PRAGMA_MACRO(CHAIN2(omp, __VA_ARGS__))
#elif MKLDNN_THR == MKLDNN_THR_TBB
#include "tbb/task_arena.h"
#include "tbb/parallel_for.h"
#define MKLDNN_THR_SYNC 0
inline int mkldnn_get_max_threads()
{ return tbb::this_task_arena::max_concurrency(); }
inline int mkldnn_get_num_threads() { return mkldnn_get_max_threads(); }
inline int mkldnn_get_thread_num()
{ return tbb::this_task_arena::current_thread_index(); }
inline int mkldnn_in_parallel() { return 0; }
inline void mkldnn_thr_barrier() { assert(!"no barrier in TBB"); }
#define PRAGMA_OMP(...)
#endif
/* MSVC still supports omp 2.0 only */
#if defined(_MSC_VER) && !defined(__clang__) && !defined(__INTEL_COMPILER)
# define collapse(x)
# define PRAGMA_OMP_SIMD(...)
#else
# define PRAGMA_OMP_SIMD(...) PRAGMA_MACRO(CHAIN2(omp, simd __VA_ARGS__))
#endif // defined(_MSC_VER) && !defined(__clang__) && !defined(__INTEL_COMPILER)
namespace mkldnn {
namespace impl {
inline bool mkldnn_thr_syncable() { return MKLDNN_THR_SYNC == 1; }
template <typename T, typename U>
inline void balance211(T n, U team, U tid, T &n_start, T &n_end) {
T n_min = 1;
T &n_my = n_end;
if (team <= 1 || n == 0) {
n_start = 0;
n_my = n;
} else if (n_min == 1) {
// team = T1 + T2
// n = T1*n1 + T2*n2 (n1 - n2 = 1)
T n1 = utils::div_up(n, (T)team);
T n2 = n1 - 1;
T T1 = n - n2 * (T)team;
n_my = (T)tid < T1 ? n1 : n2;
n_start = (T)tid <= T1 ? tid * n1 : T1 * n1 + ((T)tid - T1) * n2;
}
n_end += n_start;
}
} // namespace impl
} // namespace mkldnn
#include "mkldnn_thread_parallel_nd.hpp"
#endif
// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s
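Note on balance211 above: it splits n work items across a team of threads so that the first T1 threads receive n1 = ceil(n / team) items and the rest receive n1 - 1. The sketch below restates that arithmetic in a standalone form for illustration. `balance211_sketch` is a hypothetical name, and the integer ceiling division stands in for `utils::div_up`, which is assumed to round up.

```cpp
#include <cassert>
#include <cstddef>

// Standalone restatement of the balance211 split (the n_min == 1 case).
// Writes the half-open range [n_start, n_end) owned by thread tid.
template <typename T, typename U>
void balance211_sketch(T n, U team, U tid, T &n_start, T &n_end) {
    if (team <= 1 || n == 0) {
        n_start = 0;
        n_end = n;
        return;
    }
    const T n1 = (n + (T)team - 1) / (T)team; // ceil(n / team), the larger chunk
    const T n2 = n1 - 1;                      // the smaller chunk
    const T T1 = n - n2 * (T)team;            // number of threads that get n1 items
    const T n_my = (T)tid < T1 ? n1 : n2;
    n_start = (T)tid <= T1 ? (T)tid * n1 : T1 * n1 + ((T)tid - T1) * n2;
    n_end = n_start + n_my;
}
```

With n = 10 and team = 4 this gives n1 = 3, n2 = 2 and T1 = 2, so the four threads own [0, 3), [3, 6), [6, 8) and [8, 10). When n is smaller than team, the trailing threads receive empty ranges.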

View file

@ -1,277 +0,0 @@
/*******************************************************************************
* Copyright 2018 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/
#ifndef MKLDNN_THREAD_PARALLEL_ND_HPP
#define MKLDNN_THREAD_PARALLEL_ND_HPP
/* This header must be included by mkldnn_thread.hpp only */
/* Functions:
* - parallel(nthr, f) - executes f in parallel using at most
* nthr threads. If nthr equals 0,
* mkldnn_get_max_threads() threads are
* used
* - for_nd(ithr, nthr, dims..., f) - multidimensional for loop for already
* created threads
* - parallel_nd(dims..., f) - creates a parallel section and then
* calls for_nd
* - parallel_nd_in_omp(dims..., f) - queries current nthr and ithr and then
* calls for_nd (mostly for convenience)
*/
namespace mkldnn {
namespace impl {
/* general parallelization */
template <typename F>
void parallel(int nthr, F f) {
if (nthr == 0) nthr = mkldnn_get_max_threads();
#if MKLDNN_THR == MKLDNN_THR_SEQ
assert(nthr == 1);
f(0, 1);
#elif MKLDNN_THR == MKLDNN_THR_OMP
if (nthr == 1) { f(0, 1); return; }
# pragma omp parallel num_threads(nthr)
f(mkldnn_get_thread_num(), mkldnn_get_num_threads());
#elif MKLDNN_THR == MKLDNN_THR_TBB
if (nthr == 1) { f(0, 1); return; }
tbb::parallel_for(0, nthr, [&](int ithr) { f(ithr, nthr); }, tbb::static_partitioner());
#endif
}
/* for_nd section */
template <typename T0, typename F>
void for_nd(const int ithr, const int nthr, const T0 &D0, F f) {
T0 start{0}, end{0};
balance211(D0, nthr, ithr, start, end);
for (T0 d0 = start; d0 < end; ++d0) f(d0);
}
template <typename T0, typename T1, typename F>
void for_nd(const int ithr, const int nthr, const T0 &D0, const T1 &D1, F f) {
const size_t work_amount = (size_t)D0 * D1;
if (work_amount == 0) return;
size_t start{0}, end{0};
balance211(work_amount, nthr, ithr, start, end);
T0 d0{0}; T1 d1{0};
utils::nd_iterator_init(start, d0, D0, d1, D1);
for (size_t iwork = start; iwork < end; ++iwork) {
f(d0, d1);
utils::nd_iterator_step(d0, D0, d1, D1);
}
}
template <typename T0, typename T1, typename T2, typename F>
void for_nd(const int ithr, const int nthr, const T0 &D0, const T1 &D1,
const T2 &D2, F f) {
const size_t work_amount = (size_t)D0 * D1 * D2;
if (work_amount == 0) return;
size_t start{0}, end{0};
balance211(work_amount, nthr, ithr, start, end);
T0 d0{0}; T1 d1{0}; T2 d2{0};
utils::nd_iterator_init(start, d0, D0, d1, D1, d2, D2);
for (size_t iwork = start; iwork < end; ++iwork) {
f(d0, d1, d2);
utils::nd_iterator_step(d0, D0, d1, D1, d2, D2);
}
}
template <typename T0, typename T1, typename T2, typename T3, typename F>
void for_nd(const int ithr, const int nthr, const T0 &D0, const T1 &D1,
const T2 &D2, const T3 &D3, F f) {
const size_t work_amount = (size_t)D0 * D1 * D2 * D3;
if (work_amount == 0) return;
size_t start{0}, end{0};
balance211(work_amount, nthr, ithr, start, end);
T0 d0{0}; T1 d1{0}; T2 d2{0}; T3 d3{0};
utils::nd_iterator_init(start, d0, D0, d1, D1, d2, D2, d3, D3);
for (size_t iwork = start; iwork < end; ++iwork) {
f(d0, d1, d2, d3);
utils::nd_iterator_step(d0, D0, d1, D1, d2, D2, d3, D3);
}
}
template <typename T0, typename T1, typename T2, typename T3, typename T4,
typename F>
void for_nd(const int ithr, const int nthr, const T0 &D0, const T1 &D1,
const T2 &D2, const T3 &D3, const T4 &D4, F f) {
const size_t work_amount = (size_t)D0 * D1 * D2 * D3 * D4;
if (work_amount == 0) return;
size_t start{0}, end{0};
balance211(work_amount, nthr, ithr, start, end);
T0 d0{0}; T1 d1{0}; T2 d2{0}; T3 d3{0}; T4 d4{0};
utils::nd_iterator_init(start, d0, D0, d1, D1, d2, D2, d3, D3, d4, D4);
for (size_t iwork = start; iwork < end; ++iwork) {
f(d0, d1, d2, d3, d4);
utils::nd_iterator_step(d0, D0, d1, D1, d2, D2, d3, D3, d4, D4);
}
}
template <typename T0, typename T1, typename T2, typename T3, typename T4,
typename T5, typename F>
void for_nd(const int ithr, const int nthr, const T0 &D0, const T1 &D1,
const T2 &D2, const T3 &D3, const T4 &D4, const T5 &D5, F f) {
const size_t work_amount = (size_t)D0 * D1 * D2 * D3 * D4 * D5;
if (work_amount == 0) return;
size_t start{0}, end{0};
balance211(work_amount, nthr, ithr, start, end);
T0 d0{0}; T1 d1{0}; T2 d2{0}; T3 d3{0}; T4 d4{0}; T5 d5{0};
utils::nd_iterator_init(start, d0, D0, d1, D1, d2, D2, d3, D3, d4, D4,
d5, D5);
for (size_t iwork = start; iwork < end; ++iwork) {
f(d0, d1, d2, d3, d4, d5);
utils::nd_iterator_step(d0, D0, d1, D1, d2, D2, d3, D3, d4, D4, d5, D5);
}
}
// Skip a lambda function in the parameter pack.
template <typename T>
constexpr size_t get_work_amount(const T &v) { return 1; }
template <typename T, typename ...Args>
constexpr size_t get_work_amount(const T &v, Args &&...args)
{ return (size_t)v * get_work_amount(utils::forward<Args>(args)...); }
/* parallel_nd and parallel_nd_in_omp section */
#if MKLDNN_THR != MKLDNN_THR_TBB
template <typename ...Args>
void parallel_nd(Args &&...args) {
#if MKLDNN_THR == MKLDNN_THR_SEQ
for_nd(0, 1, utils::forward<Args>(args)...);
#elif MKLDNN_THR == MKLDNN_THR_OMP
const bool do_parallel = get_work_amount(utils::forward<Args>(args)...) > 1;
# pragma omp parallel if (do_parallel)
{
const int nthr = !do_parallel ? 1 : mkldnn_get_num_threads();
const int ithr = !do_parallel ? 0 : mkldnn_get_thread_num();
for_nd(ithr, nthr, utils::forward<Args>(args)...);
}
#endif
}
#else // MKLDNN_THR == MKLDNN_THR_TBB
// gcc 4.8 has a bug with passing parameter pack to lambdas.
// So have to explicitly instantiate all the cases.
template <typename T0, typename F>
void parallel_nd(const T0 &D0, F f) {
const size_t work_amount = (size_t)D0;
if (work_amount == 0) return;
tbb::parallel_for(tbb::blocked_range<size_t>(0, work_amount), [&](const tbb::blocked_range<size_t>& r) {
for (size_t iwork = r.begin(); iwork != r.end(); ++iwork) {
f(T0(iwork));
}
}, tbb::static_partitioner());
}
template <typename T0, typename T1, typename F>
void parallel_nd(const T0 &D0, const T1 &D1, F f) {
const size_t work_amount = (size_t)D0 * D1;
if (work_amount == 0) return;
tbb::parallel_for(tbb::blocked_range<size_t>(0, work_amount), [&](const tbb::blocked_range<size_t>& r) {
T0 d0{0}; T1 d1{0};
utils::nd_iterator_init(r.begin(), d0, D0, d1, D1);
for (size_t iwork = r.begin(); iwork != r.end(); ++iwork) {
f(d0, d1);
utils::nd_iterator_step(d0, D0, d1, D1);
}
}, tbb::static_partitioner());
}
template <typename T0, typename T1, typename T2, typename F>
void parallel_nd(const T0 &D0, const T1 &D1, const T2 &D2, F f) {
const size_t work_amount = (size_t)D0 * D1 * D2;
if (work_amount == 0) return;
tbb::parallel_for(tbb::blocked_range<size_t>(0, work_amount), [&](const tbb::blocked_range<size_t>& r) {
T0 d0{0}; T1 d1{0}; T2 d2{0};
utils::nd_iterator_init(r.begin(), d0, D0, d1, D1, d2, D2);
for (size_t iwork = r.begin(); iwork != r.end(); ++iwork) {
f(d0, d1, d2);
utils::nd_iterator_step(d0, D0, d1, D1, d2, D2);
}
}, tbb::static_partitioner());
}
template <typename T0, typename T1, typename T2, typename T3, typename F>
void parallel_nd(const T0 &D0, const T1 &D1, const T2 &D2, const T3 &D3, F f) {
const size_t work_amount = (size_t)D0 * D1 * D2 * D3;
if (work_amount == 0) return;
tbb::parallel_for(tbb::blocked_range<size_t>(0, work_amount), [&](const tbb::blocked_range<size_t>& r) {
T0 d0{0}; T1 d1{0}; T2 d2{0}; T3 d3{0};
utils::nd_iterator_init(r.begin(), d0, D0, d1, D1, d2, D2, d3, D3);
for (size_t iwork = r.begin(); iwork != r.end(); ++iwork) {
f(d0, d1, d2, d3);
utils::nd_iterator_step(d0, D0, d1, D1, d2, D2, d3, D3);
}
}, tbb::static_partitioner());
}
template <typename T0, typename T1, typename T2, typename T3, typename T4,
typename F>
void parallel_nd(const T0 &D0, const T1 &D1, const T2 &D2, const T3 &D3,
const T4 &D4, F f) {
const size_t work_amount = (size_t)D0 * D1 * D2 * D3 * D4;
if (work_amount == 0) return;
tbb::parallel_for(tbb::blocked_range<size_t>(0, work_amount), [&](const tbb::blocked_range<size_t>& r) {
T0 d0{0}; T1 d1{0}; T2 d2{0}; T3 d3{0}; T4 d4{0};
utils::nd_iterator_init(r.begin(), d0, D0, d1, D1, d2, D2, d3, D3, d4, D4);
for (size_t iwork = r.begin(); iwork != r.end(); ++iwork) {
f(d0, d1, d2, d3, d4);
utils::nd_iterator_step(d0, D0, d1, D1, d2, D2, d3, D3, d4, D4);
}
}, tbb::static_partitioner());
}
template <typename T0, typename T1, typename T2, typename T3, typename T4,
typename T5, typename F>
void parallel_nd(const T0 &D0, const T1 &D1, const T2 &D2, const T3 &D3,
const T4 &D4, const T5 &D5, F f) {
const size_t work_amount = (size_t)D0 * D1 * D2 * D3 * D4 * D5;
if (work_amount == 0) return;
tbb::parallel_for(tbb::blocked_range<size_t>(0, work_amount), [&](const tbb::blocked_range<size_t>& r) {
T0 d0{0}; T1 d1{0}; T2 d2{0}; T3 d3{0}; T4 d4{0}; T5 d5{0};
utils::nd_iterator_init(r.begin(), d0, D0, d1, D1, d2, D2, d3, D3, d4, D4,
d5, D5);
for (size_t iwork = r.begin(); iwork != r.end(); ++iwork) {
f(d0, d1, d2, d3, d4, d5);
utils::nd_iterator_step(d0, D0, d1, D1, d2, D2, d3, D3, d4, D4, d5, D5);
}
}, tbb::static_partitioner());
}
#endif
template <typename ...Args>
void parallel_nd_in_omp(Args &&...args) {
#if MKLDNN_THR == MKLDNN_THR_SEQ
for_nd(0, 1, utils::forward<Args>(args)...);
#elif MKLDNN_THR == MKLDNN_THR_OMP
for_nd(mkldnn_get_thread_num(), mkldnn_get_num_threads(),
utils::forward<Args>(args)...);
#elif MKLDNN_THR == MKLDNN_THR_TBB
assert(!"unsupported parallel_nd_in_omp()");
#endif
}
} // namespace impl
} // namespace mkldnn
#endif
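
Reviewer note: the parallel_nd overloads above flatten the N-dimensional loop nest into one linear range and let each worker rebuild its coordinates from the start of its sub-range. The sketch below illustrates that idea for three dimensions. The helpers nd_init3, nd_step3 and visit_range are hypothetical stand-ins written for this note; the real utils::nd_iterator_init and utils::nd_iterator_step live in utils.hpp, which is not part of this hunk, and the sketch assumes they iterate with the last dimension varying fastest.

```cpp
#include <array>
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical stand-in for utils::nd_iterator_init, limited to three
// dimensions: split a linear index into (d0, d1, d2), last dimension fastest.
inline void nd_init3(std::size_t start, std::size_t &d0, std::size_t D0,
        std::size_t &d1, std::size_t D1, std::size_t &d2, std::size_t D2) {
    d2 = start % D2; start /= D2;
    d1 = start % D1; start /= D1;
    d0 = start % D0;
}

// Hypothetical stand-in for utils::nd_iterator_step: advance by one element
// in row-major order, carrying into the slower dimensions on wrap-around.
inline void nd_step3(std::size_t &d0, std::size_t D0, std::size_t &d1,
        std::size_t D1, std::size_t &d2, std::size_t D2) {
    if (++d2 < D2) return;
    d2 = 0;
    if (++d1 < D1) return;
    d1 = 0;
    if (++d0 == D0) d0 = 0;
}

// Collect the coordinates visited for the sub-range [begin, end) of the
// flattened D0 * D1 * D2 index space, mirroring the loop body that each
// worker runs over its blocked_range.
inline std::vector<std::array<std::size_t, 3>> visit_range(std::size_t begin,
        std::size_t end, std::size_t D0, std::size_t D1, std::size_t D2) {
    std::vector<std::array<std::size_t, 3>> out;
    std::size_t d0 = 0, d1 = 0, d2 = 0;
    nd_init3(begin, d0, D0, d1, D1, d2, D2);
    for (std::size_t i = begin; i != end; ++i) {
        std::array<std::size_t, 3> coord = {{d0, d1, d2}};
        out.push_back(coord);
        nd_step3(d0, D0, d1, D1, d2, D2);
    }
    return out;
}
```

Calling visit_range(4, 7, 2, 3, 2) walks linear indices 4, 5 and 6 of a 2 x 3 x 2 space. Only the first coordinate needs a division and modulo per dimension; the rest are produced by the cheaper step function.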

@ -1,77 +0,0 @@
/*******************************************************************************
* Copyright 2016-2018 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/
#ifndef MKLDNN_TRAITS_HPP
#define MKLDNN_TRAITS_HPP
#include <assert.h>
#include <stdint.h>
#include "mkldnn.h"
#include "c_types_map.hpp"
#include "nstl.hpp"
#include "utils.hpp"
#include "z_magic.hpp"
namespace mkldnn {
namespace impl {
template <data_type_t> struct prec_traits {}; /* ::type -> float */
template <typename> struct data_traits {}; /* ::data_type -> f32 */
template <int> struct typesize_traits {}; /* ::data_type_size -> f32 */
template <primitive_kind_t> struct pkind_traits {}; /* ::desc_type, ::query_d */
template <> struct prec_traits<data_type::f32> { typedef float type; };
template <> struct prec_traits<data_type::s32> { typedef int32_t type; };
template <> struct prec_traits<data_type::s8> { typedef int8_t type; };
template <> struct prec_traits<data_type::u8> { typedef uint8_t type; };
template <> struct data_traits<float>
{ static constexpr data_type_t data_type = data_type::f32; };
template <> struct data_traits<int32_t>
{ static constexpr data_type_t data_type = data_type::s32; };
template <> struct data_traits<int8_t>
{ static constexpr data_type_t data_type = data_type::s8; };
template <> struct data_traits<uint8_t>
{ static constexpr data_type_t data_type = data_type::u8; };
template <> struct typesize_traits<4> { typedef float type; };
template <> struct typesize_traits<2> { typedef int16_t type; };
template <> struct typesize_traits<1> { typedef uint8_t type; };
#define PKIND_TRAITS_INST(op) \
template <> struct pkind_traits<primitive_kind::op> { \
typedef CONCAT2(op, _desc_t) desc_type; \
static constexpr query_t query_d = query::CONCAT2(op, _d); \
}
PKIND_TRAITS_INST(convolution);
PKIND_TRAITS_INST(deconvolution);
PKIND_TRAITS_INST(shuffle);
PKIND_TRAITS_INST(eltwise);
PKIND_TRAITS_INST(softmax);
PKIND_TRAITS_INST(pooling);
PKIND_TRAITS_INST(lrn);
PKIND_TRAITS_INST(batch_normalization);
PKIND_TRAITS_INST(inner_product);
PKIND_TRAITS_INST(rnn);
#undef PKIND_TRAITS_INST
}
}
#endif
// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s
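
Reviewer note: prec_traits and data_traits in the file above are inverse compile-time mappings between a data-type tag and a C++ storage type. The sketch below shows that round trip in isolation. The enum dt and the *_sketch templates are hypothetical names introduced for this note so that it compiles without the MKL-DNN headers; they only mirror the shape of the removed code.

```cpp
#include <cassert>
#include <cstdint>
#include <type_traits>

// Hypothetical stand-in for the data_type_t enumeration.
enum class dt { f32, s32, s8, u8 };

// Tag -> storage type, in the manner of prec_traits.
template <dt> struct prec_traits_sketch {};
template <> struct prec_traits_sketch<dt::f32> { typedef float type; };
template <> struct prec_traits_sketch<dt::s32> { typedef std::int32_t type; };
template <> struct prec_traits_sketch<dt::s8> { typedef std::int8_t type; };
template <> struct prec_traits_sketch<dt::u8> { typedef std::uint8_t type; };

// Storage type -> tag, in the manner of data_traits.
template <typename> struct data_traits_sketch {};
template <> struct data_traits_sketch<float>
{ static constexpr dt data_type = dt::f32; };
template <> struct data_traits_sketch<std::int32_t>
{ static constexpr dt data_type = dt::s32; };
template <> struct data_traits_sketch<std::int8_t>
{ static constexpr dt data_type = dt::s8; };
template <> struct data_traits_sketch<std::uint8_t>
{ static constexpr dt data_type = dt::u8; };

// True when mapping a tag to its type and back yields the same tag.
template <dt d>
constexpr bool round_trips() {
    return data_traits_sketch<typename prec_traits_sketch<d>::type>::data_type
            == d;
}

static_assert(round_trips<dt::f32>(), "f32 should map to float and back");
static_assert(round_trips<dt::s32>(), "s32 should map to int32_t and back");
```

Because both directions are resolved at compile time, a kernel templated on the tag can name its element type without any run-time dispatch.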

@ -1,193 +0,0 @@
/*******************************************************************************
* Copyright 2016-2018 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/
#ifndef NSTL_HPP
#define NSTL_HPP
#include <stdint.h>
#include <limits.h>
#include <float.h>
#include <vector>
#include <map>
#include "z_magic.hpp"
namespace mkldnn {
namespace impl {
void *malloc(size_t size, int alignment);
void free(void *p);
struct c_compatible {
enum { default_alignment = 64 };
static void *operator new(size_t sz) {
return malloc(sz, default_alignment);
}
static void *operator new(size_t sz, void *p) { UNUSED(sz); return p; }
static void *operator new[](size_t sz) {
return malloc(sz, default_alignment);
}
static void operator delete(void *p) { free(p); }
static void operator delete[](void *p) { free(p); }
};
namespace nstl {
template<typename T>
inline const T abs(const T& a) {
return a >= 0 ? a : -a;
}
template<typename T>
inline const T& max(const T& a, const T& b) {
return a > b ? a : b;
}
template<typename T>
inline const T& min(const T& a, const T& b) {
return a < b ? a : b;
}
template<typename T> void swap(T& t1, T& t2) {
T tmp(t1);
t1 = t2;
t2 = tmp;
}
// Rationale: MKL-DNN needs numeric limits implementation that does not
// generate dependencies on C++ run-time libraries.
template<typename T> struct numeric_limits;
template<> struct numeric_limits<float> {
static constexpr float lowest() { return -FLT_MAX; }
static constexpr float max() { return FLT_MAX; }
};
template<> struct numeric_limits<int32_t> {
static constexpr int lowest() { return INT32_MIN; }
static constexpr int max() { return INT32_MAX; }
};
template<> struct numeric_limits<int16_t> {
static constexpr int16_t lowest() { return INT16_MIN; }
static constexpr int16_t max() { return INT16_MAX; }
};
template<> struct numeric_limits<int8_t> {
static constexpr int8_t lowest() { return INT8_MIN; }
static constexpr int8_t max() { return INT8_MAX; }
};
template<> struct numeric_limits<uint8_t> {
static constexpr uint8_t lowest() { return 0; }
static constexpr uint8_t max() { return UINT8_MAX; }
};
template<typename T> struct is_integral
{ static constexpr bool value = false; };
template<> struct is_integral<int32_t> { static constexpr bool value = true; };
template<> struct is_integral<int16_t> { static constexpr bool value = true; };
template<> struct is_integral<int8_t> { static constexpr bool value = true; };
template<> struct is_integral<uint8_t> { static constexpr bool value = true; };
template <typename T, typename U> struct is_same
{ static constexpr bool value = false; };
template <typename T> struct is_same<T, T>
{ static constexpr bool value = true; };
// Rationale: MKL-DNN needs container implementations that do not generate
// dependencies on C++ run-time libraries.
//
// Implementation philosophy: caller is responsible to check if the operation
// is valid. The only functions that have to return status are those that
// depend on memory allocation or similar operations.
//
// This means that e.g. an operator [] does not have to check for boundaries.
// The caller should have checked the boundaries. If it did not we crash and
// burn: this is a bug in MKL-DNN and throwing an exception would not have been
// recoverable.
//
// On the other hand, insert() or resize() or a similar operation needs to
// return a status because the outcome depends on factors external to the
// caller. The situation is probably not recoverable either, but MKL-DNN
// needs to be nice and report "out of memory" to the users.
enum nstl_status_t {
success = 0,
out_of_memory
};
template <typename T> class vector: public c_compatible {
private:
std::vector<T> _impl;
public:
typedef typename std::vector<T>::iterator iterator;
typedef typename std::vector<T>::const_iterator const_iterator;
typedef typename std::vector<T>::size_type size_type;
vector() {}
vector(size_type n): _impl(n) {}
vector(size_type n, const T &value): _impl(n, value) {}
template <typename input_iterator>
vector(input_iterator first, input_iterator last): _impl(first, last) {}
~vector() {}
size_type size() const { return _impl.size(); }
T& operator[] (size_type i) { return _impl[i]; }
const T& operator[] (size_type i) const { return _impl[i]; }
iterator begin() { return _impl.begin(); }
const_iterator begin() const { return _impl.begin(); }
iterator end() { return _impl.end(); }
const_iterator end() const { return _impl.end(); }
template <typename input_iterator>
nstl_status_t insert(iterator pos, input_iterator begin, input_iterator end)
{
_impl.insert(pos, begin, end);
return success;
}
void clear() { _impl.clear(); }
void push_back(const T& t) { _impl.push_back(t); }
void resize(size_type count) { _impl.resize(count); }
void reserve(size_type count) { _impl.reserve(count); }
};
template <typename Key, typename T> class map: public c_compatible {
private:
std::map<Key, T> _impl;
public:
typedef typename std::map<Key, T>::iterator iterator;
typedef typename std::map<Key, T>::const_iterator const_iterator;
typedef typename std::map<Key, T>::size_type size_type;
map() {}
~map() {}
size_type size() const { return _impl.size(); }
T& operator[](const Key &k) { return _impl[k]; }
const T& operator[](const Key &k) const { return _impl[k]; }
iterator begin() { return _impl.begin(); }
const_iterator begin() const { return _impl.begin(); }
iterator end() { return _impl.end(); }
const_iterator end() const { return _impl.end(); }
void clear() { _impl.clear(); }
};
}
}
}
#endif
// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s
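
Reviewer note: the nstl rationale above separates unchecked accessors, where the caller validates the index, from operations that depend on allocation and therefore report a status. The removed wrapper always returns success from insert(). The sketch below is one way the stated philosophy could be followed through, not what the removed code does: checked_vector and sketch_status_t are hypothetical names, and it relies on C++ exceptions, which the original deliberately avoids depending on.

```cpp
#include <cassert>
#include <cstddef>
#include <new>
#include <stdexcept>
#include <vector>

// Hypothetical status enumeration modelled on nstl_status_t.
enum sketch_status_t { sketch_success = 0, sketch_out_of_memory };

template <typename T>
class checked_vector {
    std::vector<T> impl_;

public:
    std::size_t size() const { return impl_.size(); }

    // No bounds check: the caller is responsible for passing a valid index.
    T &operator[](std::size_t i) { return impl_[i]; }
    const T &operator[](std::size_t i) const { return impl_[i]; }

    // The outcome depends on the allocator, which the caller cannot check in
    // advance, so allocation failures are converted into a status code.
    sketch_status_t resize(std::size_t count) {
        try {
            impl_.resize(count);
        } catch (const std::bad_alloc &) {
            return sketch_out_of_memory;
        } catch (const std::length_error &) {
            return sketch_out_of_memory;
        }
        return sketch_success;
    }
};
```

A caller would test the status of resize() once and then index freely within the size it requested.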

@ -1,114 +0,0 @@
/*******************************************************************************
* Copyright 2016-2018 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/
#include <assert.h>
#include "mkldnn.h"
#include "c_types_map.hpp"
#include "type_helpers.hpp"
#include "utils.hpp"
using namespace mkldnn::impl;
using namespace mkldnn::impl::utils;
using namespace mkldnn::impl::status;
using namespace mkldnn::impl::prop_kind;
using namespace mkldnn::impl::alg_kind;
using namespace mkldnn::impl::types;
namespace {
status_t pooling_desc_init(pooling_desc_t *pool_desc,
prop_kind_t prop_kind, alg_kind_t alg_kind,
const memory_desc_t *src_desc, const memory_desc_t *dst_desc,
const dims_t strides, const dims_t kernel, const dims_t padding_l,
const dims_t padding_r, padding_kind_t padding_kind) {
bool args_ok = true
&& !any_null(pool_desc, src_desc, dst_desc, strides, kernel, padding_l)
&& one_of(alg_kind, pooling_max,
pooling_avg_include_padding,
pooling_avg_exclude_padding)
&& one_of(padding_kind, padding_kind::padding_zero);
if (!args_ok) return invalid_arguments;
if (padding_r == nullptr) padding_r = padding_l;
auto pd = pooling_desc_t();
pd.primitive_kind = primitive_kind::pooling;
pd.prop_kind = prop_kind;
pd.alg_kind = alg_kind;
pd.src_desc.ndims = src_desc->ndims;
const bool is_fwd = one_of(prop_kind, forward_training, forward_inference);
pd.diff_src_desc = pd.src_desc = zero_md();
pd.diff_dst_desc = pd.dst_desc = zero_md();
(is_fwd ? pd.src_desc : pd.diff_src_desc) = *src_desc;
(is_fwd ? pd.dst_desc : pd.diff_dst_desc) = *dst_desc;
int sp_dims = src_desc->ndims - 2;
utils::array_copy(pd.strides, strides, sp_dims);
utils::array_copy(pd.kernel, kernel, sp_dims);
utils::array_copy(pd.padding[0], padding_l, sp_dims);
utils::array_copy(pd.padding[1], padding_r, sp_dims);
pd.padding_kind = padding_kind;
if (one_of(alg_kind, pooling_max, pooling_avg_include_padding,
pooling_avg_exclude_padding)) {
pd.accum_data_type = types::default_accum_data_type(
src_desc->data_type, dst_desc->data_type);
} else {
pd.accum_data_type = dst_desc->data_type;
}
bool consistency = true
&& utils::one_of(src_desc->ndims, 4, 5)
&& utils::one_of(dst_desc->ndims, 4, 5)
&& src_desc->dims[0] == dst_desc->dims[0]
&& src_desc->dims[1] == dst_desc->dims[1];
for (int i = 2; i < src_desc->ndims; ++i)
consistency = consistency && (
(src_desc->dims[i] - kernel[i - 2] + padding_l[i - 2]
+ padding_r[i - 2]) / strides[i - 2] + 1
== dst_desc->dims[i]);
if (!consistency) return invalid_arguments;
*pool_desc = pd;
return success;
}
}
status_t mkldnn_pooling_forward_desc_init(pooling_desc_t *pool_desc,
prop_kind_t prop_kind, alg_kind_t alg_kind,
const memory_desc_t *src_desc, const memory_desc_t *dst_desc,
const dims_t strides, const dims_t kernel, const dims_t padding_l,
const dims_t padding_r, padding_kind_t padding_kind) {
if (!one_of(prop_kind, forward_training, forward_inference))
return invalid_arguments;
return pooling_desc_init(pool_desc, prop_kind, alg_kind, src_desc,
dst_desc, strides, kernel, padding_l, padding_r, padding_kind);
}
status_t mkldnn_pooling_backward_desc_init(pooling_desc_t *pool_desc,
alg_kind_t alg_kind, const memory_desc_t *diff_src_desc,
const memory_desc_t *diff_dst_desc, const dims_t strides,
const dims_t kernel, const dims_t padding_l, const dims_t padding_r,
padding_kind_t padding_kind) {
return pooling_desc_init(pool_desc, prop_kind::backward_data, alg_kind,
diff_src_desc, diff_dst_desc, strides, kernel, padding_l,
padding_r, padding_kind);
}
// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s
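
Reviewer note: the consistency loop in pooling_desc_init above accepts a descriptor only if every spatial extent of the destination equals (src - kernel + padding_l + padding_r) / stride + 1, using integer division. The helper below restates that expression for a single axis so the arithmetic can be checked by hand. The name pooled_extent and the use of std::int64_t are assumptions made for this note.

```cpp
#include <cassert>
#include <cstdint>

// Expected output extent along one spatial axis, using the same expression
// as the consistency check: integer division, then add one.
inline std::int64_t pooled_extent(std::int64_t in, std::int64_t kernel,
        std::int64_t pad_l, std::int64_t pad_r, std::int64_t stride) {
    return (in - kernel + pad_l + pad_r) / stride + 1;
}
```

For a 224-wide input with a 3-wide kernel, one element of padding on each side and stride 2, this gives (224 - 3 + 1 + 1) / 2 + 1 = 111 + 1 = 112, so a destination descriptor with any other width on that axis would be rejected with invalid_arguments.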

@ -1,238 +0,0 @@
/*******************************************************************************
* Copyright 2016-2018 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/
#ifndef POOLING_PD_HPP
#define POOLING_PD_HPP
#include "mkldnn.h"
#include "c_types_map.hpp"
#include "primitive_desc.hpp"
#include "type_helpers.hpp"
namespace mkldnn {
namespace impl {
struct pooling_fwd_pd_t;
struct pooling_pd_t: public primitive_desc_t {
static constexpr auto base_pkind = primitive_kind::pooling;
pooling_pd_t(engine_t *engine,
const pooling_desc_t *adesc,
const primitive_attr_t *attr,
const pooling_fwd_pd_t *hint_fwd_pd)
: primitive_desc_t(engine, attr, base_pkind)
, desc_(*adesc)
, hint_fwd_pd_(hint_fwd_pd)
, ws_md_()
{}
const pooling_desc_t *desc() const { return &desc_; }
virtual const op_desc_t *op_desc() const override
{ return reinterpret_cast<const op_desc_t *>(this->desc()); }
virtual void init_info() override { impl::init_info(this, this->info_); }
virtual status_t query(query_t what, int idx, void *result) const override {
switch (what) {
case query::pooling_d:
*(const pooling_desc_t**)result = desc(); break;
default: return primitive_desc_t::query(what, idx, result);
}
return status::success;
}
/* common pooling aux functions */
dim_t MB() const { return src_desc().dims[0]; }
dim_t C() const { return src_desc().dims[1]; }
dim_t ID() const { return ndims() >= 5 ? src_desc().dims[ndims() - 3] : 1; }
dim_t IH() const { return ndims() >= 4 ? src_desc().dims[ndims() - 2] : 1; }
dim_t IW() const { return src_desc().dims[ndims() - 1]; }
dim_t OD() const { return ndims() >= 5 ? dst_desc().dims[ndims() - 3] : 1; }
dim_t OH() const { return ndims() >= 4 ? dst_desc().dims[ndims() - 2] : 1; }
dim_t OW() const { return dst_desc().dims[ndims() - 1]; }
dim_t KD() const { return ndims() >= 5 ? desc_.kernel[ndims() - 5] : 1; }
dim_t KH() const { return ndims() >= 4 ? desc_.kernel[ndims() - 4] : 1; }
dim_t KW() const { return desc_.kernel[ndims() - 3]; }
dim_t KSD() const { return ndims() >= 5 ? desc_.strides[ndims() - 5] : 1; }
dim_t KSH() const { return ndims() >= 4 ? desc_.strides[ndims() - 4] : 1; }
dim_t KSW() const { return desc_.strides[ndims() - 3]; }
dim_t padFront() const
{ return ndims() >= 5 ? desc_.padding[0][ndims() - 5] : 0; }
dim_t padBack() const
{ return ndims() >= 5 ? desc_.padding[1][ndims() - 5] : 0; }
dim_t padT() const
{ return ndims() >= 4 ? desc_.padding[0][ndims() - 4] : 0; }
dim_t padB() const
{ return ndims() >= 4 ? desc_.padding[1][ndims() - 4] : 0; }
dim_t padL() const { return desc_.padding[0][ndims() - 3]; }
dim_t padR() const { return desc_.padding[1][ndims() - 3]; }
int ndims() const { return src_desc().ndims; }
bool is_3d() const { return ndims() == 5; }
bool has_zero_dim_memory() const
{ return memory_desc_wrapper(src_desc()).has_zero_dim(); }
bool is_fwd() const {
return utils::one_of(desc_.prop_kind, prop_kind::forward_training,
prop_kind::forward_inference);
}
protected:
pooling_desc_t desc_;
const pooling_fwd_pd_t *hint_fwd_pd_;
memory_desc_t ws_md_;
void init_default_ws() {
ws_md_ = is_fwd() ? *dst_md() : *diff_dst_md();
ws_md_.data_type = indices_data_type();
}
data_type_t indices_data_type() const {
/* the simplest way to express 255, the largest index that fits in u8 */
const int u8_max = nstl::numeric_limits<
typename prec_traits<data_type::u8>::type>::max();
return utils::array_product(desc()->kernel, ndims()) <= u8_max
? data_type::u8 : data_type::s32;
}
private:
const memory_desc_t &src_desc() const
{ return is_fwd() ? desc_.src_desc : desc_.diff_src_desc; }
const memory_desc_t &dst_desc() const
{ return is_fwd() ? desc_.dst_desc : desc_.diff_dst_desc; }
};
struct pooling_fwd_pd_t: public pooling_pd_t {
typedef pooling_fwd_pd_t base_class;
typedef pooling_fwd_pd_t hint_class;
pooling_fwd_pd_t(engine_t *engine,
const pooling_desc_t *adesc,
const primitive_attr_t *attr,
const pooling_fwd_pd_t *hint_fwd_pd)
: pooling_pd_t(engine, adesc, attr, hint_fwd_pd)
, src_md_(desc_.src_desc)
, dst_md_(desc_.dst_desc)
{}
virtual arg_usage_t arg_usage(primitive_arg_index_t arg) const override {
if (arg == MKLDNN_ARG_SRC)
return arg_usage_t::input;
if (arg == MKLDNN_ARG_DST)
return arg_usage_t::output;
if (arg == MKLDNN_ARG_WORKSPACE && (workspace_md() != nullptr))
return arg_usage_t::output;
return primitive_desc_t::arg_usage(arg);
}
virtual const memory_desc_t *src_md(int index = 0) const override
{ return index == 0 ? &src_md_ : nullptr; }
virtual const memory_desc_t *dst_md(int index = 0) const override
{ return index == 0 ? &dst_md_ : nullptr; }
virtual const memory_desc_t *workspace_md(int index = 0) const override
{ return index == 0 && !types::is_zero_md(&ws_md_) ? &ws_md_ : nullptr; }
virtual int n_inputs() const override { return 1; }
virtual int n_outputs() const override
{ return 1 + (workspace_md() != nullptr); }
protected:
memory_desc_t src_md_;
memory_desc_t dst_md_;
virtual status_t set_default_params() {
if (dst_md()->format_kind != format_kind::any)
return status::success;
if (src_md()->format_kind != format_kind::blocked)
return status::unimplemented;
return memory_desc_init_by_blocking_desc(dst_md_,
src_md_.format_desc.blocking);
}
};
struct pooling_bwd_pd_t: public pooling_pd_t {
typedef pooling_bwd_pd_t base_class;
typedef pooling_fwd_pd_t hint_class;
pooling_bwd_pd_t(engine_t *engine,
const pooling_desc_t *adesc,
const primitive_attr_t *attr,
const pooling_fwd_pd_t *hint_fwd_pd)
: pooling_pd_t(engine, adesc, attr, hint_fwd_pd)
, diff_src_md_(desc_.diff_src_desc)
, diff_dst_md_(desc_.diff_dst_desc)
{}
virtual arg_usage_t arg_usage(primitive_arg_index_t arg) const override {
if (arg == MKLDNN_ARG_DIFF_DST)
return arg_usage_t::input;
if (arg == MKLDNN_ARG_DIFF_SRC)
return arg_usage_t::output;
if (arg == MKLDNN_ARG_WORKSPACE && (workspace_md() != nullptr))
return arg_usage_t::input;
return primitive_desc_t::arg_usage(arg);
}
virtual const memory_desc_t *diff_src_md(int index = 0) const override
{ return index == 0 ? &diff_src_md_ : nullptr; }
virtual const memory_desc_t *diff_dst_md(int index = 0) const override
{ return index == 0 ? &diff_dst_md_ : nullptr; }
virtual const memory_desc_t *workspace_md(int index = 0) const override
{ return index == 0 && !types::is_zero_md(&ws_md_) ? &ws_md_ : nullptr; }
virtual int n_inputs() const override
{ return 1 + (workspace_md() != nullptr); }
virtual int n_outputs() const override { return 1; }
protected:
memory_desc_t diff_src_md_;
memory_desc_t diff_dst_md_;
virtual status_t set_default_params() {
if (diff_src_md()->format_kind != format_kind::any)
return status::success;
if (diff_dst_md()->format_kind != format_kind::blocked)
return status::unimplemented;
return memory_desc_init_by_blocking_desc(diff_src_md_,
diff_dst_md_.format_desc.blocking);
}
};
}
}
#endif
// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s

@ -1,103 +0,0 @@
/*******************************************************************************
* Copyright 2016-2018 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/
#include <assert.h>
#include "c_types_map.hpp"
#include "engine.hpp"
#include "primitive_desc.hpp"
#include "primitive.hpp"
#include "type_helpers.hpp"
#include "stream.hpp"
#include "utils.hpp"
using namespace mkldnn::impl;
using namespace mkldnn::impl::status;
using namespace mkldnn::impl::primitive_kind;
namespace {
// XXX: this is a huge hammer. It disables any and all MSan checks on
// primitive outputs.
//
// A proper approach would be an implementation-specific unpoisoning.
void unpoison_outputs(const exec_args_t &args) {
for(const auto &arg: args) {
if (arg.second.is_const) continue;
auto *mem = arg.second.mem;
void *p;
mem->get_data_handle(&p);
size_t s = memory_desc_wrapper(*mem->md()).size();
msan_unpoison(p, s);
}
}
}
status_t mkldnn_primitive_desc_destroy(primitive_desc_t *primitive_desc) {
if (primitive_desc) delete primitive_desc;
return success;
}
status_t mkldnn_primitive_create(primitive_t **primitive,
const primitive_desc_t *primitive_desc) {
if (utils::any_null(primitive, primitive_desc))
return invalid_arguments;
return primitive_desc->create_primitive(primitive);
}
status_t mkldnn_primitive_execute(const primitive_t *primitive,
stream_t *stream, int nargs, const mkldnn_exec_arg_t *c_args) {
bool ok = true
&& !utils::any_null(primitive, stream)
&& primitive->engine() == stream->engine()
&& IMPLICATION(nargs > 0, c_args != nullptr);
if (!ok) return invalid_arguments;
exec_args_t args;
status_t status = cvt_primtive_args(primitive->pd(), nargs, c_args, args);
if (status != status::success) return status;
exec_ctx_t ctx(stream, std::move(args));
if (mkldnn_verbose()->level) {
double ms = get_msec();
status = primitive->execute(ctx);
ms = get_msec() - ms;
printf("mkldnn_verbose,exec,%s,%g\n", primitive->pd()->info(), ms);
fflush(0);
} else {
status = primitive->execute(ctx);
}
if (msan_enabled) unpoison_outputs(ctx.args());
return status;
}
status_t mkldnn_primitive_get_primitive_desc(const primitive_t *primitive,
const primitive_desc_t **primitive_desc) {
if (utils::any_null(primitive, primitive_desc))
return invalid_arguments;
return safe_ptr_assign<const primitive_desc_t>(*primitive_desc,
primitive->pd());
}
status_t mkldnn_primitive_destroy(primitive_t *primitive) {
if (primitive != nullptr)
delete primitive;
return success;
}
// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s

@ -1,76 +0,0 @@
/*******************************************************************************
* Copyright 2016-2018 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/
#ifndef PRIMITIVE_HPP
#define PRIMITIVE_HPP
#include <assert.h>
#include "mkldnn.h"
#include "c_types_map.hpp"
#include "nstl.hpp"
#include "primitive_desc.hpp"
#include "primitive_exec_types.hpp"
/** \brief A pure virtual primitive class
*
* Primitive contains links to its inputs & outputs, though it does not track
* their readiness on execution step.
*
* @remark @b Rationale.
* Dependencies are essential throughout the whole MKL-DNN library, so it
* makes sense to include them on the very low level. On the other hand,
* tracking them should be a task for a dedicated entity, such as a scheduler
* or a stream. Primitive itself should know nothing about the
* environment it is running in.
*
* @note
* To make user experience better we should provide API which allows
* achieving the best (or good enough) performance when creating primitives
* in natural order: i.e. from bottom to top for forward pass and from top to
* bottom for backward pass. Please consider restriction [1] in Level 0.
*/
struct mkldnn_primitive: public mkldnn::impl::c_compatible {
mkldnn_primitive(const mkldnn::impl::primitive_desc_t *pd)
: pd_(pd->clone()) {}
virtual ~mkldnn_primitive() { delete pd_; }
/** returns primitive's engine */
mkldnn::impl::engine_t *engine() const { return pd_->engine(); }
/** returns primitive's descriptor */
const mkldnn::impl::primitive_desc_t *pd() const { return pd_; }
/** returns primitive's kind */
mkldnn::impl::primitive_kind_t kind() const { return pd_->kind(); }
/** executes primitive with execution context @p ctx */
virtual mkldnn::impl::status_t execute(const mkldnn::impl::exec_ctx_t &ctx)
const = 0;
protected:
const mkldnn::impl::primitive_desc_t *pd_;
private:
mkldnn_primitive() = delete;
mkldnn_primitive(const mkldnn_primitive &) = delete;
mkldnn_primitive(mkldnn_primitive &&) = delete;
mkldnn_primitive &operator=(const mkldnn_primitive &) = delete;
mkldnn_primitive &operator=(mkldnn_primitive &&) = delete;
};
#endif
// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s

@ -1,290 +0,0 @@
/*******************************************************************************
* Copyright 2017-2018 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/
#include "mkldnn.h"
#include "c_types_map.hpp"
#include "primitive_attr.hpp"
#include "type_helpers.hpp"
#include "utils.hpp"
using namespace mkldnn::impl;
using namespace mkldnn::impl::status;
using namespace mkldnn::impl::utils;
namespace mkldnn {
namespace impl {
status_t scales_t::set(dim_t count, int mask, const float *scales) {
cleanup();
count_ = count;
mask_ = mask;
if (count_ == 1) {
scales_ = scales_buf_;
utils::array_set(scales_, scales[0], scales_buf_size);
} else {
scales_ = (float *)impl::malloc(count_ * sizeof(*scales_), 64);
if (scales_ == nullptr)
return status::out_of_memory;
for (dim_t c = 0; c < count_; ++c)
scales_[c] = scales[c];
}
return status::success;
}
}
}
status_t post_ops_t::append_sum(float scale) {
if (len_ == capacity)
return out_of_memory;
entry_[len_].kind = primitive_kind::sum;
entry_[len_].sum.scale = scale;
len_++;
return success;
}
status_t post_ops_t::append_eltwise(float scale, alg_kind_t alg, float alpha,
float beta) {
using namespace mkldnn::impl::alg_kind;
bool known_alg = one_of(alg, eltwise_relu, eltwise_tanh, eltwise_elu,
eltwise_square, eltwise_abs, eltwise_sqrt, eltwise_linear,
eltwise_bounded_relu, eltwise_soft_relu, eltwise_logistic);
if (!known_alg)
return invalid_arguments;
if (len_ == capacity)
return out_of_memory;
entry_[len_].kind = primitive_kind::eltwise;
entry_[len_].eltwise.scale = scale;
entry_[len_].eltwise.alg = alg;
entry_[len_].eltwise.alpha = alpha;
entry_[len_].eltwise.beta = beta;
len_++;
return success;
}
status_t primitive_attr_t::set_scratchpad_mode(
scratchpad_mode_t scratchpad_mode) {
using namespace mkldnn::impl::scratchpad_mode;
const bool ok = one_of(scratchpad_mode, library, user);
if (!ok)
return invalid_arguments;
scratchpad_mode_ = scratchpad_mode;
return success;
}
status_t primitive_attr_t::set_post_ops(const post_ops_t &post_ops) {
this->post_ops_ = post_ops;
return success;
}
/* Public C API */
status_t mkldnn_primitive_attr_create(primitive_attr_t **attr) {
if (attr == nullptr)
return invalid_arguments;
return safe_ptr_assign<mkldnn_primitive_attr>(*attr,
new mkldnn_primitive_attr);
}
status_t mkldnn_primitive_attr_clone(primitive_attr_t **attr,
const primitive_attr_t *existing_attr) {
if (any_null(attr, existing_attr))
return invalid_arguments;
return safe_ptr_assign<mkldnn_primitive_attr>(*attr,
existing_attr->clone());
}
status_t mkldnn_primitive_attr_destroy(primitive_attr_t *attr) {
if (attr)
delete attr;
return success;
}
status_t mkldnn_primitive_attr_get_scratchpad_mode(
const primitive_attr_t *attr, scratchpad_mode_t *scratchpad_mode) {
if (any_null(attr, scratchpad_mode))
return invalid_arguments;
*scratchpad_mode = attr->scratchpad_mode_;
return success;
}
status_t mkldnn_primitive_attr_set_scratchpad_mode(
primitive_attr_t *attr, scratchpad_mode_t scratchpad_mode) {
if (any_null(attr))
return invalid_arguments;
return attr->set_scratchpad_mode(scratchpad_mode);
}
status_t mkldnn_primitive_attr_get_output_scales(const primitive_attr_t *attr,
dim_t *count, int *mask, const float **scales) {
if (any_null(attr, count, mask, scales))
return invalid_arguments;
*count = attr->output_scales_.count_;
*mask = attr->output_scales_.mask_;
*scales = attr->output_scales_.scales_;
return success;
}
status_t mkldnn_primitive_attr_set_output_scales(primitive_attr_t *attr,
dim_t count, int mask, const float *scales) {
bool ok = !any_null(attr, scales) && count > 0 && mask >= 0;
if (!ok)
return invalid_arguments;
return attr->output_scales_.set(count, mask, scales);
}
status_t mkldnn_primitive_attr_get_post_ops(const primitive_attr_t *attr,
const post_ops_t **post_ops) {
if (any_null(attr, post_ops))
return invalid_arguments;
*post_ops = &attr->post_ops_;
return success;
}
status_t mkldnn_primitive_attr_set_post_ops(primitive_attr_t *attr,
const post_ops_t *post_ops) {
if (any_null(attr, post_ops))
return invalid_arguments;
return attr->set_post_ops(*post_ops);
}
status_t mkldnn_post_ops_create(post_ops_t **post_ops) {
if (post_ops == nullptr)
return invalid_arguments;
return safe_ptr_assign<mkldnn_post_ops>(*post_ops, new mkldnn_post_ops);
}
status_t mkldnn_post_ops_destroy(post_ops_t *post_ops) {
if (post_ops)
delete post_ops;
return success;
}
int mkldnn_post_ops_len(const post_ops_t *post_ops) {
if (post_ops)
return post_ops->len_;
return 0;
}
primitive_kind_t mkldnn_post_ops_get_kind(const post_ops_t *post_ops,
int index) {
bool ok = post_ops && 0 <= index && index < post_ops->len_;
if (!ok)
return primitive_kind::undefined;
return post_ops->entry_[index].kind;
}
status_t mkldnn_post_ops_append_sum(post_ops_t *post_ops, float scale) {
if (post_ops == nullptr)
return invalid_arguments;
return post_ops->append_sum(scale);
}
namespace {
bool simple_get_params_check(const post_ops_t *post_ops, int index,
primitive_kind_t kind) {
bool ok = true
&& post_ops != nullptr
&& 0 <= index
&& index < post_ops->len_
&& post_ops->entry_[index].kind == kind;
return ok;
}
}
status_t mkldnn_post_ops_get_params_sum(const post_ops_t *post_ops, int index,
float *scale) {
bool ok = true
&& simple_get_params_check(post_ops, index, primitive_kind::sum)
&& !any_null(scale);
if (!ok)
return invalid_arguments;
*scale = post_ops->entry_[index].sum.scale;
return success;
}
status_t mkldnn_post_ops_append_eltwise(post_ops_t *post_ops, float scale,
alg_kind_t kind, float alpha, float beta) {
if (post_ops == nullptr)
return invalid_arguments;
return post_ops->append_eltwise(scale, kind, alpha, beta);
}
status_t mkldnn_post_ops_get_params_eltwise(const post_ops_t *post_ops,
int index, float *scale, alg_kind_t *alg, float *alpha, float *beta) {
bool ok = true
&& simple_get_params_check(post_ops, index, primitive_kind::eltwise)
&& !any_null(scale, alg, alpha, beta);
if (!ok)
return invalid_arguments;
const auto &e = post_ops->entry_[index].eltwise;
*scale = e.scale;
*alg = e.alg;
*alpha = e.alpha;
*beta = e.beta;
return success;
}
status_t mkldnn_primitive_attr_set_rnn_data_qparams(
primitive_attr_t *attr, const float scale, const float shift) {
if (attr == nullptr)
return invalid_arguments;
return attr->rnn_data_qparams_.set(scale, shift);
}
status_t mkldnn_primitive_attr_set_rnn_weights_qparams(
primitive_attr_t *attr, dim_t count, int mask, const float *scales) {
bool ok = !any_null(attr, scales) && count > 0 && mask >= 0;
if (!ok)
return invalid_arguments;
return attr->rnn_weights_qparams_.set(count, mask, scales);
}


@ -1,183 +0,0 @@
/*******************************************************************************
* Copyright 2017-2018 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/
#ifndef PRIMITIVE_ATTR_HPP
#define PRIMITIVE_ATTR_HPP
#include "mkldnn.h"
#include "c_types_map.hpp"
#include "nstl.hpp"
#include "utils.hpp"
namespace mkldnn {
namespace impl {
struct rnn_data_qparams_t : public c_compatible {
rnn_data_qparams_t() : scale_(1.), shift_(0.) {}
bool has_default_values() const { return (scale_ == 1. && shift_ == 0.); }
status_t set(float scale, float shift) {
scale_ = scale;
shift_ = shift;
return status::success;
}
float scale_;
float shift_;
};
struct scales_t: public c_compatible {
scales_t(): count_(1), mask_(0), scales_(scales_buf_)
{ set(1.); }
scales_t(const scales_t &rhs): scales_t()
{ set(rhs.count_, rhs.mask_, rhs.scales_); }
~scales_t() { cleanup(); }
scales_t &operator=(const scales_t &rhs) {
if (&rhs == this)
return *this;
status_t status = set(rhs.count_, rhs.mask_, rhs.scales_);
assert(status == status::success);
(void)status;
return *this;
}
bool has_default_values() const {
for (dim_t c = 0; c < count_; ++c) {
if(scales_[c] != 1.) return false;
}
return true;
}
status_t set(dim_t count, int mask, const float *scales);
status_t set(float single_scale) { return this->set(1, 0, &single_scale); }
dim_t count_;
int mask_;
float *scales_;
private:
enum { scales_buf_size = 16 };
float scales_buf_[scales_buf_size];
void cleanup() {
if (scales_ != scales_buf_ && scales_ != nullptr)
impl::free(scales_);
count_ = 1;
mask_ = 0;
scales_ = scales_buf_;
}
};
}
}
struct mkldnn_post_ops: public mkldnn::impl::c_compatible {
struct entry_t {
struct eltwise_t {
mkldnn::impl::alg_kind_t alg;
float scale, alpha, beta;
};
mkldnn::impl::primitive_kind_t kind;
union {
struct { float scale; } sum;
eltwise_t eltwise;
};
bool is_eltwise(bool require_scale_one = true) const {
using namespace mkldnn::impl;
return kind == primitive_kind::eltwise
&& IMPLICATION(require_scale_one, eltwise.scale == 1.f);
}
bool is_relu(bool require_scale_one = true,
bool require_nslope_zero = true) const {
using namespace mkldnn::impl;
return is_eltwise(require_scale_one)
&& eltwise.alg == alg_kind::eltwise_relu
&& IMPLICATION(require_nslope_zero, eltwise.alpha == 0.f);
}
bool is_sum(bool require_scale_one = true) const {
using namespace mkldnn::impl;
return kind == primitive_kind::sum
&& IMPLICATION(require_scale_one, sum.scale == 1.f);
}
};
mkldnn_post_ops(): len_(0) {}
mkldnn::impl::status_t append_sum(float scale);
mkldnn::impl::status_t append_eltwise(float scale,
mkldnn::impl::alg_kind_t alg, float alpha, float beta);
int find(mkldnn::impl::primitive_kind_t kind, int start = 0,
int stop = -1) const {
if (stop == -1) stop = len_;
stop = mkldnn::impl::nstl::min(stop, len_);
for (int idx = start; idx < stop; ++idx)
if (entry_[idx].kind == kind) return idx;
return -1;
}
bool has_default_values() const { return len_ == 0; }
bool contain(mkldnn::impl::primitive_kind_t kind, int index) const
{ return find(kind, index, index + 1) == index; }
enum { capacity = 4 };
int len_;
entry_t entry_[capacity];
};
struct mkldnn_primitive_attr: public mkldnn::impl::c_compatible {
mkldnn_primitive_attr()
: scratchpad_mode_(mkldnn::impl::scratchpad_mode::library)
{}
mkldnn_primitive_attr *clone() const
{ return new mkldnn_primitive_attr(*this); }
/** Returns true if the attributes have default values.
*
* @note The scratchpad_mode_ is not taken into account */
bool has_default_values() const {
return true
&& output_scales_.has_default_values()
&& post_ops_.has_default_values()
&& rnn_data_qparams_.has_default_values()
&& rnn_weights_qparams_.has_default_values();
}
mkldnn::impl::status_t set_scratchpad_mode(
mkldnn::impl::scratchpad_mode_t scratchpad_mode);
mkldnn::impl::status_t set_post_ops(
const mkldnn::impl::post_ops_t &post_ops);
mkldnn::impl::scratchpad_mode_t scratchpad_mode_;
mkldnn::impl::scales_t output_scales_;
mkldnn::impl::post_ops_t post_ops_;
mkldnn::impl::rnn_data_qparams_t rnn_data_qparams_;
mkldnn::impl::scales_t rnn_weights_qparams_;
};
#endif


@ -1,78 +0,0 @@
/*******************************************************************************
* Copyright 2016-2018 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/
#include "mkldnn.h"
#include "c_types_map.hpp"
#include "nstl.hpp"
#include "primitive_desc.hpp"
using namespace mkldnn::impl;
using namespace mkldnn::impl::status;
status_t primitive_desc_t::query(query_t what, int idx, void *result) const {
auto safe_ret_md = [&](const memory_desc_t *_) {
if (_ == nullptr) return not_required;
*(const memory_desc_t **)result = _;
return success;
};
switch (what) {
case query::engine: *(engine_t**)result = engine(); break;
case query::primitive_kind: *(primitive_kind_t*)result = kind(); break;
case query::scratchpad_engine:
*(engine_t**)result = scratchpad_engine(); break;
case query::memory_consumption_s64:
*(dim_t *)result = scratchpad_size(scratchpad_mode::library); break;
case query::op_d:
if (idx != 0 || op_desc() == nullptr) return invalid_arguments;
*(const_c_op_desc_t *)result
= static_cast<const_c_op_desc_t>(op_desc()); break;
case query::src_md: return safe_ret_md(src_md(idx));
case query::diff_src_md: return safe_ret_md(diff_src_md(idx));
case query::dst_md: return safe_ret_md(dst_md(idx));
case query::diff_dst_md: return safe_ret_md(diff_dst_md(idx));
case query::weights_md: return safe_ret_md(weights_md(idx));
case query::diff_weights_md: return safe_ret_md(diff_weights_md(idx));
case query::workspace_md:
if (idx != 0) return status::invalid_arguments;
return safe_ret_md(workspace_md(idx));
case query::scratchpad_md:
if (idx != 0) return status::invalid_arguments;
return safe_ret_md(scratchpad_md(idx));
case query::num_of_inputs_s32: *(int*)result = n_inputs(); break;
case query::num_of_outputs_s32: *(int*)result = n_outputs(); break;
case query::impl_info_str: *(const char **)result = name(); break;
default: return unimplemented;
}
return success;
}
status_t mkldnn_primitive_desc_get_attr(const primitive_desc_t *primitive_desc,
const primitive_attr_t **attr) {
if (utils::any_null(primitive_desc, attr))
return invalid_arguments;
*attr = primitive_desc->attr();
return success;
}


@ -1,174 +0,0 @@
/*******************************************************************************
* Copyright 2016-2018 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/
#ifndef PRIMITIVE_DESC_HPP
#define PRIMITIVE_DESC_HPP
#include "mkldnn.h"
#include "c_types_map.hpp"
#include "memory_tracking.hpp"
#include "nstl.hpp"
#include "type_helpers.hpp"
#include "primitive_attr.hpp"
#include "verbose.hpp"
struct mkldnn_primitive_desc: public mkldnn::impl::c_compatible {
using md_t = mkldnn::impl::memory_desc_t;
mkldnn_primitive_desc(mkldnn::impl::engine_t *engine,
const mkldnn::impl::primitive_attr_t *attr,
mkldnn::impl::primitive_kind_t kind)
: engine_(engine), attr_(*attr), kind_(kind) { info_[0] = '\0'; }
mkldnn_primitive_desc(mkldnn::impl::engine_t *engine,
mkldnn::impl::primitive_kind_t kind)
: engine_(engine), kind_(kind) { info_[0] = '\0'; }
virtual mkldnn_primitive_desc *clone() const = 0;
virtual ~mkldnn_primitive_desc() {}
const mkldnn::impl::primitive_attr_t *attr() const { return &attr_; }
mkldnn::impl::engine_t *engine() const { return engine_; }
mkldnn::impl::primitive_kind_t kind() const { return kind_; }
virtual void init_info() {}
const char *info() const { return info_; }
mkldnn::impl::memory_tracking::registry_t &scratchpad_registry()
{ return scratchpad_registry_; }
const mkldnn::impl::memory_tracking::registry_t &scratchpad_registry() const
{ return scratchpad_registry_; }
virtual mkldnn::impl::engine_t *scratchpad_engine() const
{ return engine_; }
virtual const mkldnn::impl::op_desc_t *op_desc() const { return nullptr; }
enum class arg_usage_t { unused, input, output };
virtual arg_usage_t arg_usage(
mkldnn::impl::primitive_arg_index_t arg) const {
using mkldnn::impl::types::is_zero_md;
if (arg == MKLDNN_ARG_SCRATCHPAD && !is_zero_md(scratchpad_md()))
return arg_usage_t::output;
return arg_usage_t::unused;
}
# define DECLARE_MD_STUB(stub) \
virtual const mkldnn::impl::memory_desc_t *stub(int idx = 0) const \
{ return nullptr; }
DECLARE_MD_STUB(input_md); DECLARE_MD_STUB(output_md);
DECLARE_MD_STUB(src_md); DECLARE_MD_STUB(diff_src_md);
DECLARE_MD_STUB(dst_md); DECLARE_MD_STUB(diff_dst_md);
DECLARE_MD_STUB(weights_md); DECLARE_MD_STUB(diff_weights_md);
DECLARE_MD_STUB(workspace_md);
# undef DECLARE_MD_STUB
const mkldnn::impl::memory_desc_t *scratchpad_md(int idx = 0) const {
return idx == 0 ? &scratchpad_md_ : nullptr;
}
virtual void init_scratchpad_md() {
auto size = scratchpad_size(mkldnn::impl::scratchpad_mode::user);
mkldnn::impl::dims_t dims = { size };
mkldnn_memory_desc_init_by_tag(&scratchpad_md_, size ? 1 : 0, dims,
mkldnn::impl::data_type::u8, mkldnn_x);
}
/** returns the scratchpad size for the given scratchpad mode. */
mkldnn::impl::dim_t scratchpad_size(
mkldnn::impl::scratchpad_mode_t mode) const {
if (mode != attr_.scratchpad_mode_) return 0;
return scratchpad_registry().size();
}
virtual int n_inputs() const { return 0; }
virtual int n_outputs() const { return 0; }
virtual mkldnn::impl::status_t query(mkldnn::impl::query_t what, int idx,
void *result) const;
virtual mkldnn::impl::status_t create_primitive(
mkldnn::impl::primitive_t **primitive) const = 0;
virtual const char *name() const { return "mkldnn_primitive_desc"; }
/* static magic */
template<typename pd_t>
static mkldnn::impl::status_t create(mkldnn::impl::primitive_desc_t **pd,
const mkldnn::impl::op_desc_t *adesc,
const mkldnn::impl::primitive_attr_t *attr,
mkldnn::impl::engine_t *engine,
const mkldnn::impl::primitive_desc_t *hint_fwd) {
using namespace mkldnn::impl;
using namespace mkldnn::impl::status;
using pd_op_desc_t = typename pkind_traits<pd_t::base_pkind>::desc_type;
if (adesc->kind != pd_t::base_pkind) return invalid_arguments;
assert(hint_fwd ? hint_fwd->kind() == pd_t::base_pkind : true);
auto hint =
reinterpret_cast<const typename pd_t::hint_class *>(hint_fwd);
auto _pd = new pd_t(engine, (const pd_op_desc_t *)adesc, attr, hint);
if (_pd == nullptr) return out_of_memory;
if (_pd->init() != success) { delete _pd; return unimplemented; }
_pd->init_info();
_pd->init_scratchpad_md();
*pd = _pd;
return success;
}
protected:
mkldnn::impl::engine_t *engine_;
mkldnn::impl::primitive_attr_t attr_;
mkldnn::impl::primitive_kind_t kind_;
mkldnn::impl::memory_desc_t scratchpad_md_;
char info_[MKLDNN_VERBOSE_BUF_LEN];
mkldnn::impl::memory_tracking::registry_t scratchpad_registry_;
protected:
/** Compares the workspace of fwd_pd with this one (makes sense for bwd_pd).
* Expectation: this pd has already set its workspace, and that workspace
* should exactly match the one from fwd_pd. */
bool compare_ws(const mkldnn_primitive_desc *fwd_pd) const {
using namespace mkldnn::impl;
if (!workspace_md()) return true; // the impl lives fine w/o workspace
return fwd_pd && fwd_pd->workspace_md()
&& *fwd_pd->workspace_md() == *workspace_md();
}
};
#define DECLARE_COMMON_PD_t(impl_name, ...) \
virtual pd_t *clone() const override { return new pd_t(*this); } \
virtual status_t create_primitive(primitive_t **p) const override { \
double ms = get_msec(); \
auto ret = safe_ptr_assign<primitive_t>(*p, new (__VA_ARGS__)(this)); \
ms = get_msec() - ms; \
if (mkldnn_verbose()->level >= 2) { \
printf("mkldnn_verbose,create,%s,%g\n", this->info(), ms); \
fflush(0); \
} \
return ret; \
} \
virtual const char *name() const override { return impl_name; }
#define DECLARE_COMMON_PD_T(impl_name, ...) \
DECLARE_COMMON_PD_t(impl_name, __VA_ARGS__)
#endif
// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s


@ -1,90 +0,0 @@
/*******************************************************************************
* Copyright 2018 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/
#include "memory.hpp"
#include "primitive.hpp"
#include "primitive_exec_types.hpp"
namespace mkldnn {
namespace impl {
status_t cvt_primtive_args(const primitive_desc_t *pd, int nargs,
const mkldnn_exec_arg_t *c_args, exec_args_t &args) {
using namespace status;
if (!IMPLICATION(nargs > 0, c_args != nullptr)) return invalid_arguments;
int n_inputs = 0;
int n_outputs = 0;
for (int i = 0; i < nargs; ++i) {
primitive_arg_index_t arg = c_args[i].arg;
auto *mem = c_args[i].memory;
switch (pd->arg_usage(arg)) {
case primitive_desc_t::arg_usage_t::input:
if (args.count(arg) != 0) return invalid_arguments;
args[arg] = {mem, true};
n_inputs++;
break;
case primitive_desc_t::arg_usage_t::output:
if (args.count(arg) != 0) return invalid_arguments;
args[arg] = {mem, false};
n_outputs++;
break;
case primitive_desc_t::arg_usage_t::unused:
break;
}
}
bool scratchpad_required = !types::is_zero_md(pd->scratchpad_md());
if (n_inputs != pd->n_inputs()) return invalid_arguments;
if (n_outputs != pd->n_outputs() + (scratchpad_required ? 1 : 0))
return invalid_arguments;
return success;
}
const void *exec_ctx_t::input(primitive_arg_index_t arg) const {
if (args_.count(arg) != 1) return nullptr;
const auto ma = args_.at(arg);
assert(ma.is_const);
void *ptr;
status_t status = ma.mem->get_data_handle(&ptr);
assert(status == status::success); MAYBE_UNUSED(status);
return ptr;
}
void *exec_ctx_t::output(primitive_arg_index_t arg) const {
if (args_.count(arg) != 1) return nullptr;
const auto ma = args_.at(arg);
assert(!ma.is_const);
void *ptr;
status_t status = ma.mem->get_data_handle(&ptr);
assert(status == status::success); MAYBE_UNUSED(status);
return ptr;
}
const memory_t *exec_ctx_t::memory(primitive_arg_index_t arg) const {
assert(args_.count(arg) == 1);
const auto ma = args_.at(arg);
assert(!ma.is_const);
return ma.mem;
}
}
}


@ -1,68 +0,0 @@
/*******************************************************************************
* Copyright 2018 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/
#ifndef PRIMITIVE_EXEC_TYPES_HPP
#define PRIMITIVE_EXEC_TYPES_HPP
#include <unordered_map>
#include "mkldnn_types.h"
#include "c_types_map.hpp"
#include "memory.hpp"
#include "primitive_desc.hpp"
namespace mkldnn {
namespace impl {
struct memory_arg_t {
memory_t *mem;
bool is_const;
};
using exec_args_t = std::unordered_map<primitive_arg_index_t, memory_arg_t>;
status_t cvt_primtive_args(const primitive_desc_t *pd, int nargs,
const mkldnn_exec_arg_t *c_args, exec_args_t &args);
/** Primitive execution context (helps pass the stream, memories, and events). */
struct exec_ctx_t {
exec_ctx_t(const exec_ctx_t &) = default;
exec_ctx_t(exec_ctx_t &&) = default;
exec_ctx_t(stream_t *stream): stream_(stream) {}
exec_ctx_t(stream_t *stream, exec_args_t &&args)
: stream_(stream)
, args_(std::move(args)) {}
stream_t *stream() const { return stream_; }
const exec_args_t &args() const { return args_; }
/* tentative solution... TODO: replace with functions returning memory_t */
const void *input(primitive_arg_index_t arg) const;
void *output(primitive_arg_index_t arg) const;
const memory_t *memory(primitive_arg_index_t arg) const;
private:
stream_t *stream_;
exec_args_t args_;
};
}
}
#endif

Some files were not shown because too many files have changed in this diff