233 lines
6.2 KiB
C++
233 lines
6.2 KiB
C++
|
// ======================================================================== //
// Copyright 2009-2019 Intel Corporation                                    //
//                                                                          //
// Licensed under the Apache License, Version 2.0 (the "License");          //
// you may not use this file except in compliance with the License.         //
// You may obtain a copy of the License at                                  //
//                                                                          //
//     http://www.apache.org/licenses/LICENSE-2.0                           //
//                                                                          //
// Unless required by applicable law or agreed to in writing, software      //
// distributed under the License is distributed on an "AS IS" BASIS,        //
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. //
// See the License for the specific language governing permissions and      //
// limitations under the License.                                           //
// ======================================================================== //
|
||
|
|
||
|
#pragma once
|
||
|
|
||
|
#include "node.h"
|
||
|
#include "image.h"
|
||
|
|
||
|
namespace oidn {
|
||
|
|
||
|
// Input reorder node
|
||
|
template<int K, class TransferFunction>
|
||
|
class InputReorderNode : public Node
|
||
|
{
|
||
|
private:
|
||
|
// Source
|
||
|
Image color;
|
||
|
Image albedo;
|
||
|
Image normal;
|
||
|
|
||
|
// Destination
|
||
|
std::shared_ptr<memory> dst;
|
||
|
float* dstPtr;
|
||
|
int C2;
|
||
|
int H2;
|
||
|
int W2;
|
||
|
|
||
|
// Tile
|
||
|
int h1Begin;
|
||
|
int w1Begin;
|
||
|
int h2Begin;
|
||
|
int w2Begin;
|
||
|
int H;
|
||
|
int W;
|
||
|
|
||
|
std::shared_ptr<TransferFunction> transferFunc;
|
||
|
|
||
|
public:
|
||
|
InputReorderNode(const Image& color,
|
||
|
const Image& albedo,
|
||
|
const Image& normal,
|
||
|
const std::shared_ptr<memory>& dst,
|
||
|
const std::shared_ptr<TransferFunction>& transferFunc)
|
||
|
: color(color), albedo(albedo), normal(normal),
|
||
|
dst(dst),
|
||
|
h1Begin(0), w1Begin(0),
|
||
|
H(color.height), W(color.width),
|
||
|
transferFunc(transferFunc)
|
||
|
{
|
||
|
const mkldnn_memory_desc_t& dstDesc = dst->get_desc().data;
|
||
|
assert(memory_desc_matches_tag(dstDesc, mkldnn_format_tag_t(BlockedFormat<K>::nChwKc)));
|
||
|
assert(dstDesc.ndims == 4);
|
||
|
assert(dstDesc.data_type == memory::data_type::f32);
|
||
|
assert(dstDesc.dims[0] == 1);
|
||
|
//assert(dstDesc.dims[1] >= getPadded<K>(C1));
|
||
|
|
||
|
dstPtr = (float*)dst->get_data_handle();
|
||
|
C2 = dstDesc.dims[1];
|
||
|
H2 = dstDesc.dims[2];
|
||
|
W2 = dstDesc.dims[3];
|
||
|
}
|
||
|
|
||
|
void setTile(int h1, int w1, int h2, int w2, int H, int W) override
|
||
|
{
|
||
|
h1Begin = h1;
|
||
|
w1Begin = w1;
|
||
|
h2Begin = h2;
|
||
|
w2Begin = w2;
|
||
|
this->H = H;
|
||
|
this->W = W;
|
||
|
}
|
||
|
|
||
|
void execute(stream& sm) override
|
||
|
{
|
||
|
assert(H + h1Begin <= color.height);
|
||
|
assert(W + w1Begin <= color.width);
|
||
|
assert(H + h2Begin <= H2);
|
||
|
assert(W + w2Begin <= W2);
|
||
|
|
||
|
parallel_nd(H2, [&](int h2)
|
||
|
{
|
||
|
const int h = h2 - h2Begin;
|
||
|
|
||
|
if (h >= 0 && h < H)
|
||
|
{
|
||
|
const int h1 = h + h1Begin;
|
||
|
|
||
|
// Zero pad
|
||
|
for (int w2 = 0; w2 < w2Begin; ++w2)
|
||
|
{
|
||
|
int c = 0;
|
||
|
while (c < C2)
|
||
|
store(h2, w2, c, 0.f);
|
||
|
}
|
||
|
|
||
|
// Reorder
|
||
|
for (int w = 0; w < W; ++w)
|
||
|
{
|
||
|
const int w1 = w + w1Begin;
|
||
|
const int w2 = w + w2Begin;
|
||
|
|
||
|
int c = 0;
|
||
|
storeColor(h2, w2, c, (float*)color.get(h1, w1));
|
||
|
if (albedo)
|
||
|
storeAlbedo(h2, w2, c, (float*)albedo.get(h1, w1));
|
||
|
if (normal)
|
||
|
storeNormal(h2, w2, c, (float*)normal.get(h1, w1));
|
||
|
while (c < C2)
|
||
|
store(h2, w2, c, 0.f);
|
||
|
}
|
||
|
|
||
|
// Zero pad
|
||
|
for (int w2 = W + w2Begin; w2 < W2; ++w2)
|
||
|
{
|
||
|
int c = 0;
|
||
|
while (c < C2)
|
||
|
store(h2, w2, c, 0.f);
|
||
|
}
|
||
|
}
|
||
|
else
|
||
|
{
|
||
|
// Zero pad
|
||
|
for (int w2 = 0; w2 < W2; ++w2)
|
||
|
{
|
||
|
int c = 0;
|
||
|
while (c < C2)
|
||
|
store(h2, w2, c, 0.f);
|
||
|
}
|
||
|
}
|
||
|
});
|
||
|
}
|
||
|
|
||
|
std::shared_ptr<memory> getDst() const override { return dst; }
|
||
|
|
||
|
private:
|
||
|
// Stores a single value
|
||
|
__forceinline void store(int h, int w, int& c, float value)
|
||
|
{
|
||
|
// Destination is in nChwKc format
|
||
|
float* dst_c = dstPtr + (H2*W2*K*(c/K)) + h*W2*K + w*K + (c%K);
|
||
|
*dst_c = value;
|
||
|
c++;
|
||
|
}
|
||
|
|
||
|
// Stores a color
|
||
|
__forceinline void storeColor(int h, int w, int& c, const float* values)
|
||
|
{
|
||
|
#pragma unroll
|
||
|
for (int i = 0; i < 3; ++i)
|
||
|
{
|
||
|
// Load the value
|
||
|
float x = values[i];
|
||
|
|
||
|
// Sanitize the value
|
||
|
x = maxSafe(x, 0.f);
|
||
|
|
||
|
// Apply the transfer function
|
||
|
x = transferFunc->forward(x);
|
||
|
|
||
|
// Store the value
|
||
|
store(h, w, c, x);
|
||
|
}
|
||
|
}
|
||
|
|
||
|
// Stores an albedo
|
||
|
__forceinline void storeAlbedo(int h, int w, int& c, const float* values)
|
||
|
{
|
||
|
#pragma unroll
|
||
|
for (int i = 0; i < 3; ++i)
|
||
|
{
|
||
|
// Load the value
|
||
|
float x = values[i];
|
||
|
|
||
|
// Sanitize the value
|
||
|
x = clampSafe(x, 0.f, 1.f);
|
||
|
|
||
|
// Store the value
|
||
|
store(h, w, c, x);
|
||
|
}
|
||
|
}
|
||
|
|
||
|
// Stores a normal
|
||
|
__forceinline void storeNormal(int h, int w, int& c, const float* values)
|
||
|
{
|
||
|
// Load the normal
|
||
|
float x = values[0];
|
||
|
float y = values[1];
|
||
|
float z = values[2];
|
||
|
|
||
|
// Compute the length of the normal
|
||
|
const float lengthSqr = sqr(x) + sqr(y) + sqr(z);
|
||
|
|
||
|
// Normalize the normal and transform it to [0..1]
|
||
|
if (isfinite(lengthSqr))
|
||
|
{
|
||
|
const float invLength = (lengthSqr > minVectorLengthSqr) ? rsqrt(lengthSqr) : 1.f;
|
||
|
|
||
|
const float scale = invLength * 0.5f;
|
||
|
const float offset = 0.5f;
|
||
|
|
||
|
x = x * scale + offset;
|
||
|
y = y * scale + offset;
|
||
|
z = z * scale + offset;
|
||
|
}
|
||
|
else
|
||
|
{
|
||
|
x = 0.f;
|
||
|
y = 0.f;
|
||
|
z = 0.f;
|
||
|
}
|
||
|
|
||
|
// Store the normal
|
||
|
store(h, w, c, x);
|
||
|
store(h, w, c, y);
|
||
|
store(h, w, c, z);
|
||
|
}
|
||
|
};
|
||
|
|
||
|
} // namespace oidn
|