virtualx-engine/thirdparty/oidn/mkl-dnn/src/cpu/ref_deconvolution.hpp
Juan Linietsky 1bea8e1eac New lightmapper
-Added LocalVector (needed it)
-Added stb_rect_pack (It's pretty cool, we could probably use it for other stuff too)
-Fixes and changes all around the place
-Added library for 128 bits fixed point (required for Delaunay3D)
2020-05-10 15:59:09 -03:00

502 lines
18 KiB
C++

/*******************************************************************************
* Copyright 2018 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/
#ifndef CPU_REF_DECONVOLUTION_HPP
#define CPU_REF_DECONVOLUTION_HPP
#include <assert.h>
#include <string.h>
#include "c_types_map.hpp"
#include "type_helpers.hpp"
#include "utils.hpp"
#include "primitive_iterator.hpp"
#include "cpu_convolution_pd.hpp"
#include "cpu_deconvolution_pd.hpp"
#include "cpu_primitive.hpp"
namespace mkldnn {
namespace impl {
namespace cpu {
static status_t compute_blocked_format(bool with_groups,
        const memory_desc_t *oi_md, memory_desc_t *io_md)
{
    /* Derives an *i*o* blocked layout for io_md from the *o*i* blocked
     * layout of oi_md by exchanging the roles of the OC and IC dimensions.
     * OC/IC live at indices 0/1, shifted by one when a leading groups
     * dimension is present (with_groups).
     *
     * Returns status::invalid_arguments when the ranks of the two
     * descriptors differ or oi_md is not in a blocked format; otherwise
     * returns the status of memory_desc_init_by_blocking_desc(). */
    const bool ranks_match = oi_md->ndims == io_md->ndims;
    const bool src_is_blocked = oi_md->format_kind == format_kind::blocked;
    if (!ranks_match || !src_is_blocked)
        return status::invalid_arguments;

    const int oc_idx = with_groups ? 1 : 0;
    const int ic_idx = oc_idx + 1;

    io_md->format_kind = format_kind::blocked;

    // Start from a copy of the source blocking, then permute OC <-> IC in
    // both the outer strides and the inner-block dimension indices.
    blocking_desc_t swapped = oi_md->format_desc.blocking;
    nstl::swap(swapped.strides[oc_idx], swapped.strides[ic_idx]);

    for (int b = 0; b < swapped.inner_nblks; ++b) {
        auto &dim_idx = swapped.inner_idxs[b];
        if (dim_idx == oc_idx)
            dim_idx = ic_idx;
        else if (dim_idx == ic_idx)
            dim_idx = oc_idx;
    }

    return memory_desc_init_by_blocking_desc(*io_md, swapped);
}
static status_t conv_descr_create(const deconvolution_desc_t *dd,
        convolution_desc_t *cd)
{
    /* Fills *cd with the convolution that implements the deconvolution
     * described by *dd:
     *   - forward (training/inference) deconv -> backward_data convolution,
     *     with the deconv dst as conv src and the deconv src as conv dst;
     *   - backward_data deconv -> forward_training convolution, with
     *     diff_dst as conv src and diff_src as conv dst;
     *   - any other prop kind (backward_weights) keeps its prop kind, with
     *     diff_dst as conv src and src as conv dst.
     * The weights descriptor is copied with its OC and IC dimensions swapped
     * and, unless its format is `any`, re-blocked accordingly. The bias
     * descriptor is passed only when the prop kind is not backward_weights. */
    using namespace prop_kind;

    const alg_kind_t alg_kind =
        (dd->alg_kind == alg_kind::deconvolution_direct)
        ? alg_kind::convolution_direct
        : alg_kind::convolution_winograd;

    const bool is_fwd = utils::one_of(dd->prop_kind,
            forward_training, forward_inference);
    const bool is_bwd_data = dd->prop_kind == backward_data;

    // Defaults correspond to the "other" (backward_weights) branch.
    prop_kind_t prop_kind = dd->prop_kind;
    const memory_desc_t *src_md = &dd->diff_dst_desc;
    const memory_desc_t *dst_md = &dd->src_desc;
    const memory_desc_t *d_weights_d = &dd->diff_weights_desc;

    if (is_fwd) {
        prop_kind = backward_data;
        src_md = &dd->dst_desc;
        d_weights_d = &dd->weights_desc;
    } else if (is_bwd_data) {
        prop_kind = forward_training;
        dst_md = &dd->diff_src_desc;
        d_weights_d = &dd->weights_desc;
    }

    // Weights carry one extra leading dimension when groups are used.
    const bool with_groups = d_weights_d->ndims == src_md->ndims + 1;
    const int oc_idx = with_groups ? 1 : 0;
    const int ic_idx = oc_idx + 1;

    memory_desc_t c_weights_d = *d_weights_d;
    nstl::swap(c_weights_d.dims[oc_idx], c_weights_d.dims[ic_idx]);
    nstl::swap(c_weights_d.padded_dims[oc_idx],
            c_weights_d.padded_dims[ic_idx]);
    nstl::swap(c_weights_d.padded_offsets[oc_idx],
            c_weights_d.padded_offsets[ic_idx]);

    if (c_weights_d.format_kind != format_kind::any)
        CHECK(compute_blocked_format(with_groups, d_weights_d, &c_weights_d));

    const memory_desc_t *bias_md =
        (prop_kind != backward_weights) ? &dd->bias_desc : nullptr;

    return conv_desc_init(cd, prop_kind, alg_kind, src_md, &c_weights_d,
            bias_md, dst_md, dd->strides, dd->dilates,
            dd->padding[0], dd->padding[1], dd->padding_kind);
}
/* Reference forward deconvolution.
 *
 * Implemented on top of a backward-data convolution (see
 * conv_descr_create()): the deconvolution src is passed as the convolution
 * diff_dst and the deconvolution dst is produced as the convolution
 * diff_src. When the selected convolution cannot apply the bias itself, the
 * bias is added afterwards by the compute_fwd_bias*() helpers (f32 only). */
struct ref_deconvolution_fwd_t: public cpu_primitive_t {
    struct pd_t: public cpu_deconvolution_fwd_pd_t {
        pd_t(engine_t *engine,
                const deconvolution_desc_t *adesc,
                const primitive_attr_t *attr,
                const deconvolution_fwd_pd_t *hint_fwd_pd)
            : cpu_deconvolution_fwd_pd_t(engine, adesc, attr, hint_fwd_pd)
            , conv_pd_(nullptr)
            , conv_supports_bias_(false)
            , dst_tag_(format_tag::undef)
        {}

        /* Deep-copies the nested convolution pd; this pd owns conv_pd_. */
        pd_t(const pd_t &other)
            : cpu_deconvolution_fwd_pd_t(other)
            , conv_pd_(other.conv_pd_ ? other.conv_pd_->clone() : nullptr)
            , conv_supports_bias_(other.conv_supports_bias_)
            , dst_tag_(other.dst_tag_)
        {}

        ~pd_t() { delete conv_pd_; }

        DECLARE_COMMON_PD_T(conv_pd_->name(), ref_deconvolution_fwd_t);

        /* Selects the first convolution implementation whose weights need
         * no extra flags and that can honour the bias, either natively or
         * through the f32-only fallback in execute(). On success conv_pd_
         * owns the chosen pd; otherwise conv_pd_ is reset to nullptr. */
        status_t init_convolution() {
            using namespace types;

            convolution_desc_t cd;
            CHECK(conv_descr_create(desc(), &cd));

            mkldnn_primitive_desc_iterator it(engine_, (op_desc_t *)&cd,
                    &attr_, nullptr);
            while (++it != it.end()) {
                conv_pd_ = *it;
                conv_supports_bias_ =
                    static_cast<cpu_convolution_bwd_data_pd_t *>(conv_pd_)
                    ->support_bias();
                bool output_f32 = utils::everyone_is(data_type::f32,
                        desc()->accum_data_type, desc()->dst_desc.data_type);

                bool ok = true
                    && conv_pd_->weights_md()->extra.flags == 0
                    /* deconv reference code can process only f32 bias */
                    && IMPLICATION(with_bias(),
                            conv_supports_bias_ || output_f32);
                if (ok) return status::success;

                delete conv_pd_;
            }
            // No candidate accepted: do not leave a dangling pointer behind
            // for ~pd_t().
            conv_pd_ = nullptr;
            return status::unimplemented;
        }

        /* Accepts forward deconvolutions with direct/winograd algorithms and
         * no post-ops, then fills every `any` memory descriptor from the
         * chosen convolution pd and records the dst layout for the bias
         * fallback. */
        status_t init() {
            using namespace format_tag;
            bool ok = true
                && is_fwd()
                && utils::one_of(desc()->alg_kind,
                        alg_kind::deconvolution_direct,
                        alg_kind::deconvolution_winograd)
                && attr()->post_ops_.has_default_values();

            if (ok) {
                CHECK(init_convolution());
                if (weights_md_.format_kind == format_kind::any) {
                    CHECK(compute_blocked_format(with_groups(),
                        conv_pd_->weights_md(), &desc_.weights_desc));
                    weights_md_ = desc_.weights_desc;
                }
                // The convolution's diff_dst/diff_src are this
                // deconvolution's src/dst.
                if (src_md_.format_kind == format_kind::any)
                    src_md_ = *conv_pd_->diff_dst_md();
                if (dst_md_.format_kind == format_kind::any)
                    dst_md_ = *conv_pd_->diff_src_md();
                if (bias_md_.format_kind == format_kind::any)
                    CHECK(memory_desc_init_by_tag(bias_md_, x));

                dst_tag_ = memory_desc_matches_one_of_tag(dst_md_,
                        utils::pick(ndims() - 3, ncw, nchw, ncdhw),
                        utils::pick(ndims() - 3, nCw8c, nChw8c, nCdhw8c),
                        utils::pick(ndims() - 3, nCw16c, nChw16c, nCdhw16c));

                return status::success;
            }

            return status::unimplemented;
        }

        virtual void init_scratchpad_md() override {
            scratchpad_md_ = *conv_pd_->scratchpad_md();
        }

        primitive_desc_t *conv_pd_; // owned
        bool conv_supports_bias_;   // conv applies the bias itself
        format_tag_t dst_tag_;      // selects the bias-fallback kernel
    };

    typedef typename prec_traits<data_type::f32>::type data_t;

    // NOTE(review): the status of create_primitive() cannot be reported from
    // a constructor; conv_p_ is pre-initialized to nullptr so the destructor
    // stays safe if it is not assigned.
    ref_deconvolution_fwd_t(const pd_t *apd): cpu_primitive_t(apd)
    { pd()->conv_pd_->create_primitive((primitive_t **)&conv_p_); }
    ~ref_deconvolution_fwd_t() { delete conv_p_; }

    /* Runs the nested convolution with remapped arguments, then adds the
     * bias manually if the convolution could not. Returns the convolution's
     * failure status, if any. */
    virtual status_t execute(const exec_ctx_t &ctx) const override {
        const auto &args = ctx.args();
        exec_args_t conv_args;
        conv_args[MKLDNN_ARG_DIFF_DST] = args.at(MKLDNN_ARG_SRC);
        conv_args[MKLDNN_ARG_WEIGHTS] = args.at(MKLDNN_ARG_WEIGHTS);
        if (pd()->with_bias() && pd()->conv_supports_bias_)
            conv_args[MKLDNN_ARG_BIAS] = args.at(MKLDNN_ARG_BIAS);
        conv_args[MKLDNN_ARG_DIFF_SRC] = args.at(MKLDNN_ARG_DST);
        if (!types::is_zero_md(pd()->scratchpad_md()))
            conv_args[MKLDNN_ARG_SCRATCHPAD] = args.at(MKLDNN_ARG_SCRATCHPAD);
        const exec_ctx_t conv_ctx(ctx.stream(), std::move(conv_args));

        // Propagate convolution failures instead of silently dropping them
        // (consistent with ref_deconvolution_bwd_weights_t::execute()).
        const status_t conv_status = conv_p_->execute(conv_ctx);
        if (conv_status != status::success) return conv_status;

        if (pd()->with_bias() && !pd()->conv_supports_bias_) {
            using namespace format_tag;

            auto bias = CTX_IN_MEM(const data_t *, MKLDNN_ARG_BIAS);
            auto dst = CTX_OUT_MEM(data_t *, MKLDNN_ARG_DST);

            switch (pd()->dst_tag_) {
            case ncdhw: case nchw: case ncw:
                compute_fwd_bias_ncdhw(bias, dst);
                break;
            case nCdhw8c: case nChw8c: case nCw8c:
                compute_fwd_bias_nCdhwXc<8>(bias, dst);
                break;
            case nCdhw16c: case nChw16c: case nCw16c:
                compute_fwd_bias_nCdhwXc<16>(bias, dst);
                break;
            default:
                compute_fwd_bias(bias, dst);
                break;
            }
        }
        return status::success;
    }

private:
    void compute_fwd_bias(const data_t *bias, data_t *dst) const;
    void compute_fwd_bias_ncdhw(const data_t *bias, data_t *dst) const;
    template <int blksize> void compute_fwd_bias_nCdhwXc(const data_t *bias,
            data_t *dst) const;

    const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); }
    primitive_t *conv_p_ = nullptr; // owned
};
/* Reference backward-data deconvolution.
 *
 * Implemented on top of a forward-training convolution (see
 * conv_descr_create()): the deconvolution diff_dst is passed as the
 * convolution src and the deconvolution diff_src is produced as the
 * convolution dst. Only f32 tensors are accepted. */
struct ref_deconvolution_bwd_data_t: public cpu_primitive_t {
    struct pd_t: public cpu_deconvolution_bwd_data_pd_t {
        pd_t(engine_t *engine, const deconvolution_desc_t *adesc,
                const primitive_attr_t *attr,
                const deconvolution_fwd_pd_t *hint_fwd_pd)
            : cpu_deconvolution_bwd_data_pd_t(engine, adesc, attr, hint_fwd_pd)
            , conv_pd_(nullptr)
        {}

        /* Deep-copies the nested convolution pd; this pd owns conv_pd_. */
        pd_t(const pd_t &other)
            : cpu_deconvolution_bwd_data_pd_t(other)
            , conv_pd_(other.conv_pd_ ? other.conv_pd_->clone() : nullptr)
        {}

        ~pd_t() { delete conv_pd_; }

        DECLARE_COMMON_PD_T(conv_pd_->name(), ref_deconvolution_bwd_data_t);

        /* Selects the first convolution implementation whose weights need
         * no extra flags. On success conv_pd_ owns the chosen pd; otherwise
         * conv_pd_ is reset to nullptr. */
        status_t init_convolution() {
            using namespace types;

            convolution_desc_t cd;
            CHECK(conv_descr_create(desc(), &cd));

            mkldnn_primitive_desc_iterator it(engine_, (op_desc_t *)&cd,
                    &attr_, nullptr);
            while (++it != it.end()) {
                conv_pd_ = *it;
                if (conv_pd_->weights_md()->extra.flags == 0)
                    return status::success;
                delete conv_pd_;
            }
            // Every candidate was rejected (and deleted): clear the pointer
            // so ~pd_t() does not delete the last candidate a second time.
            conv_pd_ = nullptr;
            return status::unimplemented;
        }

        /* Accepts f32 backward-data deconvolutions with direct/winograd
         * algorithms, then fills every `any` memory descriptor from the
         * chosen convolution pd. */
        status_t init() {
            bool ok = true
                && desc()->prop_kind == prop_kind::backward_data
                && utils::everyone_is(data_type::f32,
                        desc()->diff_src_desc.data_type,
                        desc()->weights_desc.data_type,
                        desc()->diff_dst_desc.data_type)
                && utils::one_of(desc()->alg_kind,
                        alg_kind::deconvolution_direct,
                        alg_kind::deconvolution_winograd);

            if (ok) {
                CHECK(init_convolution());
                if (weights_md_.format_kind == format_kind::any) {
                    CHECK(compute_blocked_format(with_groups(),
                        conv_pd_->weights_md(), &desc_.weights_desc));
                    weights_md_ = desc_.weights_desc;
                }
                // The convolution's dst/src are this deconvolution's
                // diff_src/diff_dst.
                if (diff_src_md_.format_kind == format_kind::any)
                    diff_src_md_ = *conv_pd_->dst_md();
                if (diff_dst_md_.format_kind == format_kind::any)
                    diff_dst_md_ = *conv_pd_->src_md();

                return status::success;
            }

            return status::unimplemented;
        }

        virtual void init_scratchpad_md() override {
            scratchpad_md_ = *conv_pd_->scratchpad_md();
        }

        primitive_desc_t *conv_pd_; // owned
    };

    typedef typename prec_traits<data_type::f32>::type data_t;

    // NOTE(review): the status of create_primitive() cannot be reported from
    // a constructor; conv_p_ is pre-initialized to nullptr so the destructor
    // stays safe if it is not assigned.
    ref_deconvolution_bwd_data_t(const pd_t *apd): cpu_primitive_t(apd)
    { pd()->conv_pd_->create_primitive((primitive_t **)&conv_p_); }
    ~ref_deconvolution_bwd_data_t() { delete conv_p_; }

    /* Runs the nested convolution with remapped arguments and returns its
     * status. */
    virtual status_t execute(const exec_ctx_t &ctx) const override {
        const auto &args = ctx.args();
        exec_args_t conv_args;
        conv_args[MKLDNN_ARG_SRC] = args.at(MKLDNN_ARG_DIFF_DST);
        conv_args[MKLDNN_ARG_WEIGHTS] = args.at(MKLDNN_ARG_WEIGHTS);
        conv_args[MKLDNN_ARG_DST] = args.at(MKLDNN_ARG_DIFF_SRC);
        if (!types::is_zero_md(pd()->scratchpad_md()))
            conv_args[MKLDNN_ARG_SCRATCHPAD] = args.at(MKLDNN_ARG_SCRATCHPAD);
        const exec_ctx_t conv_ctx(ctx.stream(), std::move(conv_args));

        // Propagate convolution failures instead of silently dropping them.
        return conv_p_->execute(conv_ctx);
    }

private:
    const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); }
    primitive_t *conv_p_ = nullptr; // owned
};
/* Reference backward-weights deconvolution.
 *
 * Implemented on top of a backward-weights convolution (see
 * conv_descr_create()) with the roles of src and diff_dst exchanged. The
 * bias gradient, if requested, is reduced afterwards from diff_dst by the
 * compute_bwd_bias*() helpers. Only f32 tensors are accepted. */
struct ref_deconvolution_bwd_weights_t: public cpu_primitive_t {
    struct pd_t: public cpu_deconvolution_bwd_weights_pd_t {
        pd_t(engine_t *engine,
                const deconvolution_desc_t *adesc,
                const primitive_attr_t *attr,
                const deconvolution_fwd_pd_t *hint_fwd_pd)
            : cpu_deconvolution_bwd_weights_pd_t(engine, adesc, attr, hint_fwd_pd)
            , conv_pd_(nullptr)
            , dst_tag_(format_tag::undef)
        {}

        /* Deep-copies the nested convolution pd; this pd owns conv_pd_. */
        pd_t(const pd_t &other)
            : cpu_deconvolution_bwd_weights_pd_t(other)
            , conv_pd_(other.conv_pd_ ? other.conv_pd_->clone() : nullptr)
            , dst_tag_(other.dst_tag_)
        {}

        ~pd_t() { delete conv_pd_; }

        DECLARE_COMMON_PD_T(conv_pd_->name(), ref_deconvolution_bwd_weights_t);

        /* Selects the first convolution implementation whose diff_weights
         * need no extra flags. On success conv_pd_ owns the chosen pd;
         * otherwise conv_pd_ is reset to nullptr. */
        status_t init_convolution() {
            using namespace types;

            convolution_desc_t cd;
            CHECK(conv_descr_create(desc(), &cd));

            mkldnn_primitive_desc_iterator it(engine_, (op_desc_t *)&cd,
                    &attr_, nullptr);
            while (++it != it.end()) {
                conv_pd_ = *it;
                if (conv_pd_->diff_weights_md()->extra.flags == 0)
                    return status::success;
                delete conv_pd_;
            }
            // Every candidate was rejected (and deleted): clear the pointer
            // so ~pd_t() does not delete the last candidate a second time.
            conv_pd_ = nullptr;
            return status::unimplemented;
        }

        /* Accepts f32 backward-weights deconvolutions with direct/winograd
         * algorithms and default attributes, fills every `any` memory
         * descriptor from the chosen convolution pd, and records the
         * diff_dst layout used by the bias-gradient kernels. */
        status_t init() {
            using namespace format_tag;
            bool ok = true
                && desc()->prop_kind == prop_kind::backward_weights
                && utils::everyone_is(data_type::f32,
                        desc()->src_desc.data_type,
                        desc()->diff_weights_desc.data_type,
                        desc()->diff_dst_desc.data_type)
                && utils::one_of(desc()->alg_kind,
                        alg_kind::deconvolution_direct,
                        alg_kind::deconvolution_winograd)
                && attr()->has_default_values();

            if (ok) {
                CHECK(init_convolution());
                if (diff_weights_md_.format_kind == format_kind::any) {
                    CHECK(compute_blocked_format(with_groups(),
                        conv_pd_->diff_weights_md(),
                        &desc_.diff_weights_desc));
                    diff_weights_md_ = desc_.diff_weights_desc;
                }
                // The convolution's diff_dst/src are this deconvolution's
                // src/diff_dst.
                if (src_md_.format_kind == format_kind::any)
                    src_md_ = *conv_pd_->diff_dst_md();
                if (diff_dst_md_.format_kind == format_kind::any)
                    diff_dst_md_ = *conv_pd_->src_md();
                if (diff_bias_md_.format_kind == format_kind::any)
                    CHECK(memory_desc_init_by_tag(diff_bias_md_, x));

                dst_tag_ = memory_desc_matches_one_of_tag(diff_dst_md_,
                        utils::pick(ndims() - 3, ncw, nchw, ncdhw),
                        utils::pick(ndims() - 3, nCw8c, nChw8c, nCdhw8c),
                        utils::pick(ndims() - 3, nCw16c, nChw16c, nCdhw16c));

                return status::success;
            }

            return status::unimplemented;
        }

        virtual void init_scratchpad_md() override {
            scratchpad_md_ = *conv_pd_->scratchpad_md();
        }

        primitive_desc_t *conv_pd_; // owned
        format_tag_t dst_tag_;      // selects the bias-gradient kernel
    };

    typedef typename prec_traits<data_type::f32>::type data_t;

    // NOTE(review): the status of create_primitive() cannot be reported from
    // a constructor; conv_p_ is pre-initialized to nullptr so the destructor
    // stays safe if it is not assigned.
    ref_deconvolution_bwd_weights_t(const pd_t *apd): cpu_primitive_t(apd)
    { pd()->conv_pd_->create_primitive((primitive_t **)&conv_p_); }
    ~ref_deconvolution_bwd_weights_t() { delete conv_p_; }

    /* Runs the nested convolution with src/diff_dst swapped, then reduces
     * the bias gradient if requested. Returns the convolution's failure
     * status, if any. */
    virtual status_t execute(const exec_ctx_t &ctx) const override {
        const auto &args = ctx.args();
        exec_args_t conv_args;
        conv_args[MKLDNN_ARG_DIFF_DST] = args.at(MKLDNN_ARG_SRC);
        conv_args[MKLDNN_ARG_SRC] = args.at(MKLDNN_ARG_DIFF_DST);
        conv_args[MKLDNN_ARG_DIFF_WEIGHTS] = args.at(MKLDNN_ARG_DIFF_WEIGHTS);
        if (!types::is_zero_md(pd()->scratchpad_md()))
            conv_args[MKLDNN_ARG_SCRATCHPAD] = args.at(MKLDNN_ARG_SCRATCHPAD);
        const exec_ctx_t conv_ctx(ctx.stream(), std::move(conv_args));

        const status_t conv_status = conv_p_->execute(conv_ctx);
        if (conv_status != status::success) return conv_status;

        if (pd()->with_bias()) {
            using namespace format_tag;

            auto diff_dst = CTX_IN_MEM(const data_t *, MKLDNN_ARG_DIFF_DST);
            auto diff_bias = CTX_OUT_MEM(data_t *, MKLDNN_ARG_DIFF_BIAS);

            switch (pd()->dst_tag_) {
            case ncdhw: case nchw: case ncw:
                compute_bwd_bias_ncdhw(diff_dst, diff_bias);
                break;
            case nCdhw8c: case nChw8c: case nCw8c:
                compute_bwd_bias_nCdhwXc<8>(diff_dst, diff_bias);
                break;
            case nCdhw16c: case nChw16c: case nCw16c:
                compute_bwd_bias_nCdhwXc<16>(diff_dst, diff_bias);
                break;
            default:
                compute_bwd_bias(diff_dst, diff_bias);
                break;
            }
        }
        return status::success;
    }

private:
    const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); }
    void compute_bwd_bias(const data_t *diff_dst, data_t *diff_bias) const;
    void compute_bwd_bias_ncdhw(const data_t *diff_dst,
            data_t *diff_bias) const;
    template <int blksize> void compute_bwd_bias_nCdhwXc(
            const data_t *diff_dst, data_t *diff_bias) const;

    primitive_t *conv_p_ = nullptr; // owned
};
}
}
}
#endif
// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s