1bea8e1eac
- Added LocalVector (needed it). - Added stb_rect_pack (it's pretty cool; we could probably use it for other things too). - Fixes and changes all over the place. - Added a library for 128-bit fixed-point arithmetic (required for Delaunay3D).
112 lines
4.7 KiB
C++
112 lines
4.7 KiB
C++
// ======================================================================== //
|
|
// Copyright 2009-2019 Intel Corporation //
|
|
// //
|
|
// Licensed under the Apache License, Version 2.0 (the "License"); //
|
|
// you may not use this file except in compliance with the License. //
|
|
// You may obtain a copy of the License at //
|
|
// //
|
|
// http://www.apache.org/licenses/LICENSE-2.0 //
|
|
// //
|
|
// Unless required by applicable law or agreed to in writing, software //
|
|
// distributed under the License is distributed on an "AS IS" BASIS, //
|
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. //
|
|
// See the License for the specific language governing permissions and //
|
|
// limitations under the License. //
|
|
// ======================================================================== //
|
|
|
|
#include "common/tensor.h"
|
|
#include "image.h"
|
|
#include "node.h"
|
|
#include "input_reorder.h"
|
|
#include "output_reorder.h"
|
|
#include "transfer_function.h"
|
|
|
|
#pragma once
|
|
|
|
namespace oidn {
|
|
|
|
// Progress state
|
|
struct Progress
|
|
{
|
|
ProgressMonitorFunction func;
|
|
void* userPtr;
|
|
int taskCount;
|
|
};
|
|
|
|
class Executable
|
|
{
|
|
public:
|
|
virtual ~Executable() {}
|
|
virtual void execute(const Progress& progress, int taskIndex) = 0;
|
|
};
|
|
|
|
template<int K>
|
|
class Network : public Executable
|
|
{
|
|
public:
|
|
Network(const Ref<Device>& device, const std::map<std::string, Tensor>& weightMap);
|
|
|
|
void execute(const Progress& progress, int taskIndex) override;
|
|
|
|
std::shared_ptr<memory> allocTensor(const memory::dims& dims,
|
|
memory::format_tag format = memory::format_tag::any,
|
|
void* data = nullptr);
|
|
|
|
std::shared_ptr<memory> castTensor(const memory::dims& dims,
|
|
const std::shared_ptr<memory>& src,
|
|
size_t srcOffset = 0,
|
|
memory::format_tag format = memory::format_tag::any);
|
|
|
|
std::shared_ptr<memory> castTensor(const memory::dims& dims,
|
|
const std::shared_ptr<memory>& src,
|
|
const memory::dims& srcOffset);
|
|
|
|
void zeroTensor(const std::shared_ptr<memory>& dst);
|
|
|
|
memory::dims getInputReorderDims(const memory::dims& srcDims, int alignment);
|
|
|
|
std::shared_ptr<Node> addInputReorder(const Image& color,
|
|
const Image& albedo,
|
|
const Image& normal,
|
|
const std::shared_ptr<TransferFunction>& transferFunc,
|
|
int alignment,
|
|
const std::shared_ptr<memory>& userDst = nullptr);
|
|
|
|
std::shared_ptr<Node> addOutputReorder(const std::shared_ptr<memory>& src,
|
|
const std::shared_ptr<TransferFunction>& transferFunc,
|
|
const Image& output);
|
|
|
|
memory::dims getConvDims(const std::string& name, const memory::dims& srcDims);
|
|
std::shared_ptr<Node> addConv(const std::string& name,
|
|
const std::shared_ptr<memory>& src,
|
|
const std::shared_ptr<memory>& userDst = nullptr,
|
|
bool relu = true);
|
|
|
|
memory::dims getPoolDims(const memory::dims& srcDims);
|
|
std::shared_ptr<Node> addPool(const std::shared_ptr<memory>& src,
|
|
const std::shared_ptr<memory>& userDst = nullptr);
|
|
|
|
memory::dims getUpsampleDims(const memory::dims& srcDims);
|
|
std::shared_ptr<Node> addUpsample(const std::shared_ptr<memory>& src,
|
|
const std::shared_ptr<memory>& userDst = nullptr);
|
|
|
|
memory::dims getConcatDims(const memory::dims& src1Dims, const memory::dims& src2Dims);
|
|
|
|
std::shared_ptr<Node> addAutoexposure(const Image& color,
|
|
const std::shared_ptr<HDRTransferFunction>& transferFunc);
|
|
|
|
void finalize();
|
|
|
|
private:
|
|
Ref<Device> device;
|
|
engine eng;
|
|
stream sm;
|
|
std::vector<std::shared_ptr<Node>> nodes;
|
|
std::map<std::string, Tensor> weightMap;
|
|
|
|
// Memory allocation statistics
|
|
size_t activationAllocBytes = 0; // number of allocated activation bytes
|
|
size_t totalAllocBytes = 0; // total number of allocated bytes
|
|
};
|
|
|
|
} // namespace oidn
|