// ======================================================================== // // Copyright 2009-2019 Intel Corporation // // // // Licensed under the Apache License, Version 2.0 (the "License"); // // you may not use this file except in compliance with the License. // // You may obtain a copy of the License at // // // // http://www.apache.org/licenses/LICENSE-2.0 // // // // Unless required by applicable law or agreed to in writing, software // // distributed under the License is distributed on an "AS IS" BASIS, // // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // // See the License for the specific language governing permissions and // // limitations under the License. // // ======================================================================== // #pragma once #include "node.h" #include "image.h" namespace oidn { // Output reorder node template<int K, class TransferFunction> class OutputReorderNode : public Node { private: // Source std::shared_ptr<memory> src; const float* srcPtr; int H1; int W1; // Destination Image output; // Tile int h1Begin; int w1Begin; int h2Begin; int w2Begin; int H; int W; std::shared_ptr<TransferFunction> transferFunc; public: OutputReorderNode(const std::shared_ptr<memory>& src, const Image& output, const std::shared_ptr<TransferFunction>& transferFunc) : src(src), output(output), h1Begin(0), w1Begin(0), h2Begin(0), w2Begin(0), H(output.height), W(output.width), transferFunc(transferFunc) { const mkldnn_memory_desc_t& srcDesc = src->get_desc().data; MAYBE_UNUSED(srcDesc); assert(memory_desc_matches_tag(srcDesc, mkldnn_format_tag_t(BlockedFormat<K>::nChwKc))); assert(srcDesc.ndims == 4); assert(srcDesc.data_type == memory::data_type::f32); assert(srcDesc.dims[0] == 1); // We assume output data is <= K OC assert(srcDesc.dims[1] == K); srcPtr = (float*)src->get_data_handle(); H1 = srcDesc.dims[2]; W1 = srcDesc.dims[3]; } void setTile(int h1, int w1, int h2, int w2, int H, int W) override { h1Begin = h1; w1Begin = w1; h2Begin = h2; w2Begin = w2; this->H = H; this->W = W; } void execute(stream& sm) override { assert(h1Begin + H <= H1); assert(w1Begin + W <= W1); assert(h2Begin + H <= output.height); assert(w2Begin + W <= output.width); const int C1 = K; parallel_nd(H, [&](int h) { const int h1 = h + h1Begin; const int h2 = h + h2Begin; for (int w = 0; w < W; ++w) { const int w1 = w + w1Begin; const int w2 = w + w2Begin; float* dstPtr_C = (float*)output.get(h2, w2); // Source is in nChwKc format. In this case C is 1 so this is really nhwc const float* srcPtr_C = srcPtr + h1*W1*C1 + w1*C1; #pragma unroll for (int i = 0; i < 3; ++i) { // Load the value float x = srcPtr_C[i]; // The CNN output may contain negative values or even NaNs, so it must be sanitized x = maxSafe(x, 0.f); // Apply the inverse transfer function x = transferFunc->inverse(x); // Sanitize and store the final value dstPtr_C[i] = max(x, 0.f); } } }); } }; } // namespace oidn