feat: switch backend to PaddleOCR-NCNN, switch project to CMake

1. The project backend has been fully migrated to the PaddleOCR-NCNN algorithm and has passed basic compatibility tests.
2. The project is now organized with CMake; to better accommodate third-party libraries, a QMake project will no longer be provided.
3. The rights/notice files and the code tree have been reorganized to minimize infringement risk.

Log: switch backend to PaddleOCR-NCNN, switch project to CMake
Change-Id: I4d5d2c5d37505a4a24b389b1a4c5d12f17bfa38c
3rdparty/ncnn/tools/mlir/CMakeLists.txt | 61 lines (vendored, new file)
@@ -0,0 +1,61 @@
project(mlir2ncnn)

cmake_minimum_required(VERSION 3.10)
set(CMAKE_CXX_STANDARD 14)

set(LLVM_PROJECT_INSTALL_DIR "/home/nihui/osd/llvm-project/build/install" CACHE STRING "")

set(LLVM_DIR "${LLVM_PROJECT_INSTALL_DIR}/lib/cmake/llvm")
find_package(LLVM REQUIRED)

set(MLIR_DIR "${LLVM_PROJECT_INSTALL_DIR}/lib/cmake/mlir")
find_package(MLIR REQUIRED)

add_definitions(-fno-rtti -fno-exceptions)

include_directories("${LLVM_PROJECT_INSTALL_DIR}/include")

include_directories(${CMAKE_CURRENT_BINARY_DIR})

include(${LLVM_DIR}/TableGen.cmake)
include(${MLIR_DIR}/AddMLIR.cmake)

set(LLVM_TARGET_DEFINITIONS tf_ops.td)
mlir_tablegen(tf_all_ops.h.inc -gen-op-decls)
mlir_tablegen(tf_all_ops.cc.inc -gen-op-defs)
add_public_tablegen_target(tf_opsIncGen)

set(LLVM_TARGET_DEFINITIONS ncnn_ops.td)
mlir_tablegen(ncnn_ops.h.inc -gen-op-decls)
mlir_tablegen(ncnn_ops.cc.inc -gen-op-defs)
add_public_tablegen_target(ncnn_opsIncGen)

set(LLVM_TARGET_DEFINITIONS ncnn_rewriter.td)
mlir_tablegen(ncnn_rewriter.inc -gen-rewriters)
add_public_tablegen_target(ncnn_rewriterIncGen)

add_executable(mlir2ncnn
    mlir2ncnn.cpp
    ncnn_dialect.cpp
    ncnn_rewriter.cpp
    tf_dialect.cpp
    tf_attributes.cc
    tf_types.cc
)

add_dependencies(mlir2ncnn
    tf_opsIncGen
    ncnn_opsIncGen
    ncnn_rewriterIncGen
)

target_link_libraries(mlir2ncnn
    MLIRIR
    MLIRDialect
    MLIRInferTypeOpInterface
    MLIRParser
    MLIRPass
    MLIRStandard
    MLIRTransforms
)

ncnn_install_tool(mlir2ncnn)
3rdparty/ncnn/tools/mlir/fix_td.sh | 13 lines (vendored, new file)
@@ -0,0 +1,13 @@
#!/bin/sh

# This dirty script eats td files :P
# https://github.com/tensorflow/tensorflow/tree/master/tensorflow/compiler/mlir/tensorflow/ir

sed -i 's!tensorflow/compiler/mlir/tensorflow/ir/!!g' *.td tf_traits.h tf_types.h tf_types.cc tf_attributes.cc

sed -i '/let hasCanonicalizer = 1;/d' *.td
sed -i '/let hasFolder = 1;/d' *.td
sed -i '/StringRef GetOptimalLayout(const RuntimeDevices& devices);/d' *.td
sed -i '/LogicalResult UpdateDataFormat(StringRef data_format);/d' *.td
sed -i '/LogicalResult FoldOperandsPermutation(ArrayRef<int64_t> permutation);/d' *.td
sed -i '/Optional<ContractionFusion> GetContractionFusion();/d' *.td
3rdparty/ncnn/tools/mlir/mlir2ncnn.cpp | 1819 lines (vendored, new file)
File diff suppressed because it is too large.
3rdparty/ncnn/tools/mlir/ncnn_dialect.cpp | 41 lines (vendored, new file)
@@ -0,0 +1,41 @@
// Tencent is pleased to support the open source community by making ncnn available.
//
// Copyright (C) 2020 THL A29 Limited, a Tencent company. All rights reserved.
//
// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
// in compliance with the License. You may obtain a copy of the License at
//
// https://opensource.org/licenses/BSD-3-Clause
//
// Unless required by applicable law or agreed to in writing, software distributed
// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
// CONDITIONS OF ANY KIND, either express or implied. See the License for the
// specific language governing permissions and limitations under the License.

#include "ncnn_dialect.h"

#include <mlir/IR/Builders.h>

namespace mlir {

namespace ncnn {

NCNNDialect::NCNNDialect(mlir::MLIRContext* context)
    : mlir::Dialect("ncnn", context, TypeID::get<NCNNDialect>())
{
    addOperations<
#define GET_OP_LIST
#include "ncnn_ops.cc.inc"
        >();

    // Support unknown operations because not all NCNN operations are
    // registered.
    allowUnknownOperations();
}

} // namespace ncnn

#define GET_OP_CLASSES
#include "ncnn_ops.cc.inc"

} // namespace mlir
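For reference, a minimal sketch (not part of this diff) of how a consumer might load this dialect before building or parsing "ncnn.*" operations. It assumes the MLIR API vintage this commit vendors (roughly LLVM 12), where MLIRContext::getOrLoadDialect is the usual entry point:

#include <mlir/IR/MLIRContext.h>

#include "ncnn_dialect.h"

int main()
{
    mlir::MLIRContext context;

    // Constructs NCNNDialect, which registers the TableGen-generated op
    // list and enables unknown operations (see the constructor above).
    context.getOrLoadDialect<mlir::ncnn::NCNNDialect>();

    // ... build or parse IR containing ncnn ops with this context ...
    return 0;
}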
3rdparty/ncnn/tools/mlir/ncnn_dialect.h | 47 lines (vendored, new file)
@@ -0,0 +1,47 @@
// Tencent is pleased to support the open source community by making ncnn available.
//
// Copyright (C) 2020 THL A29 Limited, a Tencent company. All rights reserved.
//
// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
// in compliance with the License. You may obtain a copy of the License at
//
// https://opensource.org/licenses/BSD-3-Clause
//
// Unless required by applicable law or agreed to in writing, software distributed
// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
// CONDITIONS OF ANY KIND, either express or implied. See the License for the
// specific language governing permissions and limitations under the License.

#ifndef NCNN_DIALECT_H
#define NCNN_DIALECT_H

#include <mlir/IR/BuiltinTypes.h>
#include <mlir/IR/Dialect.h>
#include <mlir/Interfaces/SideEffectInterfaces.h>
#include <mlir/Pass/Pass.h>

namespace mlir {

namespace ncnn {

class NCNNDialect : public mlir::Dialect
{
public:
    NCNNDialect(mlir::MLIRContext* context);

    static StringRef getDialectNamespace()
    {
        return "ncnn";
    }
};

std::unique_ptr<OperationPass<FuncOp> > createNCNNOptimizePass();

} // namespace ncnn

#define GET_OP_CLASSES
#include "ncnn_ops.h.inc"

} // namespace mlir

#endif // NCNN_DIALECT_H
3rdparty/ncnn/tools/mlir/ncnn_ops.td | 133 lines (vendored, new file)
@@ -0,0 +1,133 @@
// Tencent is pleased to support the open source community by making ncnn available.
//
// Copyright (C) 2020 THL A29 Limited, a Tencent company. All rights reserved.
//
// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
// in compliance with the License. You may obtain a copy of the License at
//
// https://opensource.org/licenses/BSD-3-Clause
//
// Unless required by applicable law or agreed to in writing, software distributed
// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
// CONDITIONS OF ANY KIND, either express or implied. See the License for the
// specific language governing permissions and limitations under the License.

#ifndef NCNN_OPS_TD
#define NCNN_OPS_TD

include "mlir/IR/OpBase.td"
include "mlir/Interfaces/SideEffectInterfaces.td"
include "tf_op_base.td"

def NCNN_Dialect : Dialect {
  let name = "ncnn";
  let cppNamespace = "ncnn";
}

//===----------------------------------------------------------------------===//
// NCNN op definitions
//===----------------------------------------------------------------------===//

class NCNN_Op<string mnemonic, list<OpTrait> traits = []> :
    Op<NCNN_Dialect, mnemonic, traits>;

//===----------------------------------------------------------------------===//
// NCNN operations
//===----------------------------------------------------------------------===//

def NCNN_KerasConv2DOp : NCNN_Op<"KerasConv2D", [NoSideEffect]> {

  let arguments = (ins
    F32Tensor:$x,
    F32Tensor:$weight,
    F32Tensor:$bias,

    I64ArrayAttr:$strides,
    TF_AnyStrAttrOf<["SAME", "VALID", "EXPLICIT"]>:$padding,
    DefaultValuedAttr<I64ArrayAttr, "{}">:$explicit_paddings,
    DefaultValuedAttr<I64ArrayAttr, "{1, 1, 1, 1}">:$dilations
  );

  let results = (outs
    F32Tensor:$y
  );
}

def NCNN_KerasDenseOp : NCNN_Op<"KerasDense", [NoSideEffect]> {

  let arguments = (ins
    F32Tensor:$x,
    F32Tensor:$weight,
    F32Tensor:$bias
  );

  let results = (outs
    F32Tensor:$y
  );
}

def NCNN_KerasBatchNormOp : NCNN_Op<"KerasBatchNorm", [NoSideEffect]> {

  let arguments = (ins
    F32Tensor:$x,
    F32Tensor:$gamma,
    F32Tensor:$bias
  );

  let results = (outs
    F32Tensor:$y
  );
}

def NCNN_BinaryOpOp : NCNN_Op<"BinaryOp", [NoSideEffect]> {

  let arguments = (ins
    F32Tensor:$x,
    I32Attr:$op_type,
    I32Attr:$with_scalar,
    F32Attr:$b
  );

  let results = (outs
    F32Tensor:$y
  );
}

def NCNN_InstanceNormOp : NCNN_Op<"InstanceNorm", [NoSideEffect]> {

  let arguments = (ins
    F32Tensor:$x,
    F32Attr:$epsilon
  );

  let results = (outs
    F32Tensor:$y
  );
}

def NCNN_InstanceNormAffineOp : NCNN_Op<"InstanceNormAffine", [NoSideEffect]> {

  let arguments = (ins
    F32Tensor:$x,
    F32Tensor:$gamma,
    F32Tensor:$beta,
    F32Attr:$epsilon
  );

  let results = (outs
    F32Tensor:$y
  );
}

def NCNN_SwishOp : NCNN_Op<"Swish", [NoSideEffect]> {

  let arguments = (ins
    F32Tensor:$x
  );

  let results = (outs
    F32Tensor:$y
  );
}

#endif // NCNN_OPS_TD
3rdparty/ncnn/tools/mlir/ncnn_rewriter.cpp | 54 lines (vendored, new file)
@@ -0,0 +1,54 @@
// Tencent is pleased to support the open source community by making ncnn available.
//
// Copyright (C) 2020 THL A29 Limited, a Tencent company. All rights reserved.
//
// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
// in compliance with the License. You may obtain a copy of the License at
//
// https://opensource.org/licenses/BSD-3-Clause
//
// Unless required by applicable law or agreed to in writing, software distributed
// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
// CONDITIONS OF ANY KIND, either express or implied. See the License for the
// specific language governing permissions and limitations under the License.

#include <mlir/IR/MLIRContext.h>
#include <mlir/IR/PatternMatch.h>
#include <mlir/Pass/Pass.h>
#include <mlir/Transforms/GreedyPatternRewriteDriver.h>

#include "tf_dialect.h"
#include "ncnn_dialect.h"

using namespace mlir;

namespace mlir {

namespace ncnn {

#include "ncnn_rewriter.inc"

class NCNNOptimizePass : public PassWrapper<NCNNOptimizePass, FunctionPass>
{
public:
    void runOnFunction();
};

void NCNNOptimizePass::runOnFunction()
{
    mlir::OwningRewritePatternList patterns;
    mlir::ncnn::populateWithGenerated(&getContext(), patterns);

    (void)mlir::applyPatternsAndFoldGreedily(getFunction(), std::move(patterns));
}

std::unique_ptr<OperationPass<FuncOp> > createNCNNOptimizePass()
{
    return std::make_unique<NCNNOptimizePass>();
}

static PassRegistration<NCNNOptimizePass> pass("ncnn-optimize", "ncnn optimization");

} // namespace ncnn

} // namespace mlir
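For reference, a minimal sketch (not part of this diff; assumes a parsed ModuleOp and the same MLIR vintage) of running this pass programmatically through a PassManager instead of the registered "ncnn-optimize" pipeline name:

#include <mlir/IR/BuiltinOps.h>
#include <mlir/Pass/PassManager.h>

#include "ncnn_dialect.h"

// Runs the generated fusion patterns on every function in `module`.
mlir::LogicalResult optimizeForNCNN(mlir::ModuleOp module)
{
    mlir::PassManager pm(module.getContext());
    // createNCNNOptimizePass is a function pass, so nest it under FuncOp.
    pm.addNestedPass<mlir::FuncOp>(mlir::ncnn::createNCNNOptimizePass());
    return pm.run(module);
}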
3rdparty/ncnn/tools/mlir/ncnn_rewriter.td | 210 lines (vendored, new file)
@@ -0,0 +1,210 @@
// Tencent is pleased to support the open source community by making ncnn available.
//
// Copyright (C) 2020 THL A29 Limited, a Tencent company. All rights reserved.
//
// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
// in compliance with the License. You may obtain a copy of the License at
//
// https://opensource.org/licenses/BSD-3-Clause
//
// Unless required by applicable law or agreed to in writing, software distributed
// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
// CONDITIONS OF ANY KIND, either express or implied. See the License for the
// specific language governing permissions and limitations under the License.

#ifndef NCNN_REWRITER_TD
#define NCNN_REWRITER_TD

include "tf_ops.td"
include "ncnn_ops.td"

def get_attr_f : NativeCodeCall<"$0.getValue<FloatAttr>(0)">;

def OneElementAttrPred : CPred<"$_self.cast<ElementsAttr>().getType().getNumElements() == 1">;

def OneElementAttr : ElementsAttrBase<And<[ElementsAttr.predicate, OneElementAttrPred]>, "Scalar ElementsAttr">;

def EqualOperands : Constraint<CPred<"$0 == $1">>;

def FuseBinaryOpPattern0 : Pat<
  (TF_MulOp
    $x,
    (TF_ConstOp OneElementAttr:$b)
  ),

  (NCNN_BinaryOpOp $x, ConstantAttr<I32Attr, "2">, ConstantAttr<I32Attr, "1">, (get_attr_f $b))
>;

def FuseBinaryOpPattern1 : Pat<
  (TF_AddV2Op
    $x,
    (TF_ConstOp OneElementAttr:$b)
  ),

  (NCNN_BinaryOpOp $x, ConstantAttr<I32Attr, "0">, ConstantAttr<I32Attr, "1">, (get_attr_f $b))
>;

def FuseKerasConv2DOpPattern : Pat<
  (TF_BiasAddOp
    (TF_Conv2DOp $x, $weight, $strides, $use_cudnn_on_gpu, $padding, $explicit_paddings, $data_format, $dilations),
    $bias,
    $data_format_
  ),

  (NCNN_KerasConv2DOp $x, $weight, $bias, $strides, $padding, $explicit_paddings, $dilations)
>;

def FuseKerasConv2DOpPattern1 : Pat<
  (TF_AddV2Op
    (TF_Conv2DOp $x, $weight, $strides, $use_cudnn_on_gpu, $padding, $explicit_paddings, $data_format, $dilations),
    $bias
  ),

  (NCNN_KerasConv2DOp $x, $weight, $bias, $strides, $padding, $explicit_paddings, $dilations)
>;

def FuseKerasDenseOpPattern : Pat<
  (TF_BiasAddOp
    (TF_MatMulOp $x, $weight, $transpose_a, $transpose_b),
    $bias,
    $data_format_
  ),

  (NCNN_KerasDenseOp $x, $weight, $bias)
>;

def NonOneElementAttrPred : CPred<"$_self.cast<ElementsAttr>().getType().getNumElements() != 1">;

def NonOneElementAttr : ElementsAttrBase<And<[ElementsAttr.predicate, NonOneElementAttrPred]>, "Non Scalar ElementsAttr">;

def FuseKerasBatchNormOpPattern : Pat<
  (TF_AddV2Op
    (TF_MulOp
      $x,
      (TF_ConstOp:$gamma NonOneElementAttr)
    ),
    (TF_ConstOp:$bias NonOneElementAttr)
  ),

  (NCNN_KerasBatchNormOp $x, $gamma, $bias)
>;

def FuseInstanceNormPattern0 : Pat<
  (TF_MulOp
    (TF_RsqrtOp
      (TF_AddV2Op
        (TF_MeanOp
          (TF_SquaredDifferenceOp
            (TF_MeanOp:$mean
              $x,
              (TF_ConstOp:$reduce_axis ElementsAttr),
              ConstBoolAttrTrue // keep_dims
            ),
            $x_
          ),
          $reduce_axis_,
          ConstBoolAttrTrue // keep_dims
        ),
        (TF_ConstOp ElementsAttr:$epsilon)
      )
    ),
    (TF_SubOp $x__, $mean_)
  ),

  (NCNN_InstanceNormOp $x, (get_attr_f $epsilon)),

  [
    (EqualOperands $x, $x_),
    (EqualOperands $x, $x__),
    (EqualOperands $reduce_axis, $reduce_axis_),
    (EqualOperands $mean, $mean_)
  ]
>;

def FuseInstanceNormPattern1 : Pat<
  (TF_MulOp
    (TF_RsqrtOp
      (TF_AddV2Op
        (TF_MeanOp
          (TF_SquaredDifferenceOp
            $x_,
            (TF_MeanOp:$mean
              $x,
              (TF_ConstOp:$reduce_axis ElementsAttr),
              ConstBoolAttrTrue // keep_dims
            )
          ),
          $reduce_axis_,
          ConstBoolAttrTrue // keep_dims
        ),
        (TF_ConstOp ElementsAttr:$epsilon)
      )
    ),
    (TF_SubOp $x__, $mean_)
  ),

  (NCNN_InstanceNormOp $x, (get_attr_f $epsilon)),

  [
    (EqualOperands $x, $x_),
    (EqualOperands $x, $x__),
    (EqualOperands $reduce_axis, $reduce_axis_),
    (EqualOperands $mean, $mean_)
  ]
>;

def FuseInstanceNormAffinePattern : Pat<
  (TF_ReshapeOp
    (TF_AddV2Op
      (TF_MulOp
        $reshaped__,
        (TF_MulOp:$rsqrt_var_eps_gamma
          (TF_RsqrtOp
            (TF_AddV2Op
              (TF_MeanOp
                (TF_SquaredDifferenceOp
                  $reshaped_,
                  (TF_MeanOp:$mean
                    (TF_ReshapeOp:$reshaped $x, (TF_ConstOp ElementsAttr)),
                    (TF_ConstOp:$reduce_axis ElementsAttr),
                    ConstBoolAttrTrue // keep_dims
                  )
                ),
                $reduce_axis_,
                ConstBoolAttrTrue // keep_dims
              ),
              (TF_ConstOp ElementsAttr:$epsilon)
            )
          ),
          $gamma
        )
      ),
      (TF_SubOp
        $beta,
        (TF_MulOp $rsqrt_var_eps_gamma_, $mean_)
      )
    ),
    (TF_ConstOp ElementsAttr)
  ),

  (NCNN_InstanceNormAffineOp $x, $gamma, $beta, (get_attr_f $epsilon)),

  [
    (EqualOperands $reshaped, $reshaped_),
    (EqualOperands $reshaped, $reshaped__),
    (EqualOperands $reduce_axis, $reduce_axis_),
    (EqualOperands $rsqrt_var_eps_gamma, $rsqrt_var_eps_gamma_),
    (EqualOperands $mean, $mean_)
  ]
>;

def FuseSwishPattern : Pat<
  (TF_MulOp
    (TF_SigmoidOp $x),
    $x
  ),

  (NCNN_SwishOp $x)
>;

#endif // NCNN_REWRITER_TD
3rdparty/ncnn/tools/mlir/tf_attributes.cc | 163 lines (vendored, new file)
@@ -0,0 +1,163 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tf_attributes.h"

namespace mlir {
namespace TF {

namespace detail {

// The storage class for ShapeAttr.
struct ShapeAttrStorage : public AttributeStorage
{
    using KeyTy = std::pair<ArrayRef<int64_t>, bool>;

    explicit ShapeAttrStorage(ArrayRef<int64_t> shape, bool unranked = false)
        : shape(shape), unranked(unranked)
    {
    }

    bool operator==(const KeyTy& key) const
    {
        return key == KeyTy(shape, unranked);
    }
    static unsigned hashKey(const KeyTy& key)
    {
        return llvm::hash_combine(key.first, static_cast<char>(key.second));
    }

    // NOLINTNEXTLINE
    static ShapeAttrStorage* construct(mlir::AttributeStorageAllocator& allocator,
                                       const KeyTy& key)
    {
        return new (allocator.allocate<ShapeAttrStorage>())
            ShapeAttrStorage(allocator.copyInto(key.first), key.second);
    }

    ArrayRef<int64_t> shape;
    bool unranked = false;
};

// The storage class for FuncAttr.
struct FuncAttrStorage : public AttributeStorage
{
    using KeyTy = std::pair<Attribute, Attribute>;

    explicit FuncAttrStorage(Attribute name, Attribute attrs)
        : name(name), attrs(attrs)
    {
    }

    bool operator==(const KeyTy& key) const
    {
        return key == KeyTy(name, attrs);
    }
    static unsigned hashKey(const KeyTy& key)
    {
        return llvm::hash_combine(key.first, key.second);
    }

    static FuncAttrStorage* construct(mlir::AttributeStorageAllocator& allocator,
                                      const KeyTy& key)
    {
        return new (allocator.allocate<FuncAttrStorage>())
            FuncAttrStorage(key.first, key.second);
    }

    Attribute name;
    Attribute attrs;
};

} // namespace detail

// Get or create a shape attribute.
ShapeAttr ShapeAttr::get(mlir::MLIRContext* context,
                         llvm::Optional<ArrayRef<int64_t> > shape)
{
    if (shape) return Base::get(context, *shape, /*unranked=*/false);

    return Base::get(context, ArrayRef<int64_t>(), /*unranked=*/true);
}

// Get or create a shape attribute.
ShapeAttr ShapeAttr::get(mlir::MLIRContext* context, ShapedType shaped_type)
{
    if (shaped_type.hasRank())
        return Base::get(context, shaped_type.getShape(), /*unranked=*/false);

    return Base::get(context, ArrayRef<int64_t>(), /*unranked=*/true);
}

llvm::Optional<ArrayRef<int64_t> > ShapeAttr::getValue() const
{
    if (hasRank()) return getShape();
    return llvm::None;
}

bool ShapeAttr::hasRank() const
{
    return !getImpl()->unranked;
}

int64_t ShapeAttr::getRank() const
{
    assert(hasRank());
    return getImpl()->shape.size();
}

ArrayRef<int64_t> ShapeAttr::getShape() const
{
    assert(hasRank());
    return getImpl()->shape;
}

bool ShapeAttr::hasStaticShape() const
{
    if (!hasRank()) return false;

    for (auto dim : getShape())
    {
        if (dim < 0) return false;
    }

    return true;
}

FuncAttr FuncAttr::get(mlir::MLIRContext* context, llvm::StringRef name,
                       DictionaryAttr attr)
{
    auto symbol = SymbolRefAttr::get(context, name);
    return Base::get(context, symbol, attr);
}

FuncAttr FuncAttr::get(mlir::MLIRContext* context, SymbolRefAttr symbol,
                       DictionaryAttr attr)
{
    return Base::get(context, symbol, attr);
}

SymbolRefAttr FuncAttr::GetName() const
{
    return getImpl()->name.cast<SymbolRefAttr>();
}

DictionaryAttr FuncAttr::GetAttrs() const
{
    return getImpl()->attrs.cast<DictionaryAttr>();
}

} // namespace TF
} // namespace mlir
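For reference, a minimal usage sketch (not part of this diff; hypothetical helper, exercising the accessors defined above) of building and querying a ShapeAttr:

#include <cassert>

#include <llvm/ADT/ArrayRef.h>

#include "tf_attributes.h"

// Builds a ranked 2x?x4 shape attribute and checks its properties.
void shapeAttrExample(mlir::MLIRContext* context)
{
    int64_t dims[] = {2, -1, 4}; // -1 marks a dynamic dimension
    auto attr = mlir::TF::ShapeAttr::get(context, llvm::makeArrayRef(dims));

    assert(attr.hasRank() && attr.getRank() == 3);
    assert(!attr.hasStaticShape()); // the middle dimension is dynamic

    // Passing llvm::None instead yields an unranked shape attribute.
    auto unranked = mlir::TF::ShapeAttr::get(context, llvm::None);
    assert(!unranked.hasRank());
}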
3rdparty/ncnn/tools/mlir/tf_attributes.h | 97 lines (vendored, new file)
@@ -0,0 +1,97 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// This file defines the attributes used in the TensorFlow dialect.

#ifndef TENSORFLOW_COMPILER_MLIR_TENSORFLOW_IR_TF_ATTRIBUTES_H_
#define TENSORFLOW_COMPILER_MLIR_TENSORFLOW_IR_TF_ATTRIBUTES_H_

#include "llvm/ADT/StringRef.h"
#include "mlir/IR/BuiltinAttributes.h" // from @llvm-project
#include "mlir/IR/BuiltinTypes.h"      // from @llvm-project
#include "mlir/IR/MLIRContext.h"       // from @llvm-project

namespace mlir {
namespace TF {

namespace detail {

struct ShapeAttrStorage;
struct FuncAttrStorage;

} // namespace detail

class ShapeAttr : public Attribute::AttrBase<ShapeAttr, Attribute,
                      detail::ShapeAttrStorage>
{
public:
    using Base::Base;

    // Get or create a shape attribute. If shape is llvm::None, then it is
    // unranked. Otherwise it is ranked. And for ranked shapes, the value of the
    // dimension size must be >= -1. The value of -1 means the dimension is
    // dynamic. Otherwise, the dimension is static.
    static ShapeAttr get(mlir::MLIRContext* context,
                         llvm::Optional<ArrayRef<int64_t> > shape);

    // Get or create a shape attribute from a ShapedType type.
    static ShapeAttr get(mlir::MLIRContext* context, ShapedType shaped_type);

    llvm::Optional<ArrayRef<int64_t> > getValue() const;

    bool hasRank() const;

    // If this is ranked, return the rank. Otherwise, abort.
    int64_t getRank() const;

    // If this is ranked, return the shape. Otherwise, abort.
    ArrayRef<int64_t> getShape() const;

    // If this is unranked type or any dimension has unknown size (<0), it
    // doesn't have static shape. If all dimensions have known size (>= 0), it
    // has static shape.
    bool hasStaticShape() const;
};

// Custom attribute to model AttrValue.value.func (NameAttrList type attribute).
// This attribute holds a SymbolRefAttr, for the NameAttrList.name string and a
// DictionaryAttr for the NameAttrList.attr map<string, AttrValue>. It is
// currently printed and parsed for the following format:
//
//   #tf.func<@symbol, {attr = "value"}>
//
// where the first element is the SymbolRefAttr and the second element is the
// DictionaryAttr.
class FuncAttr
    : public Attribute::AttrBase<FuncAttr, Attribute, detail::FuncAttrStorage>
{
public:
    using Base::Base;

    static FuncAttr get(mlir::MLIRContext* context, llvm::StringRef name,
                        DictionaryAttr attr);

    static FuncAttr get(mlir::MLIRContext* context, SymbolRefAttr symbol,
                        DictionaryAttr attr);

    SymbolRefAttr GetName() const;

    DictionaryAttr GetAttrs() const;
};

} // namespace TF
} // namespace mlir

#endif // TENSORFLOW_COMPILER_MLIR_TENSORFLOW_IR_TF_ATTRIBUTES_H_
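For reference, a minimal sketch (not part of this diff; the symbol name and attribute key are hypothetical) of constructing the FuncAttr documented above, i.e. the attribute printed as #tf.func<@my_func, {epsilon = 1.0e-05 : f32}>:

#include <mlir/IR/Builders.h>

#include "tf_attributes.h"

mlir::TF::FuncAttr makeFuncAttr(mlir::MLIRContext* context)
{
    mlir::Builder builder(context);
    // The DictionaryAttr plays the role of NameAttrList.attr.
    auto attrs = builder.getDictionaryAttr(
        {builder.getNamedAttr("epsilon", builder.getF32FloatAttr(1e-5f))});
    // The string overload wraps "my_func" into a SymbolRefAttr internally.
    return mlir::TF::FuncAttr::get(context, "my_func", attrs);
}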
3rdparty/ncnn/tools/mlir/tf_dialect.cpp | 323 lines (vendored, new file)
@@ -0,0 +1,323 @@
// Tencent is pleased to support the open source community by making ncnn available.
//
// Copyright (C) 2020 THL A29 Limited, a Tencent company. All rights reserved.
//
// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
// in compliance with the License. You may obtain a copy of the License at
//
// https://opensource.org/licenses/BSD-3-Clause
//
// Unless required by applicable law or agreed to in writing, software distributed
// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
// CONDITIONS OF ANY KIND, either express or implied. See the License for the
// specific language governing permissions and limitations under the License.

#include "tf_dialect.h"

#include <mlir/Dialect/Traits.h>
#include <mlir/IR/Attributes.h>
#include <mlir/IR/Builders.h>
#include <mlir/IR/Dialect.h>
#include <mlir/IR/DialectImplementation.h>
#include <mlir/IR/Location.h>
#include <mlir/IR/Matchers.h>
#include <mlir/IR/MLIRContext.h>
#include <mlir/IR/OpDefinition.h>
#include <mlir/IR/OpImplementation.h>
#include <mlir/IR/Operation.h>
#include <mlir/IR/OperationSupport.h>
#include <mlir/IR/PatternMatch.h>
#include <mlir/IR/TypeUtilities.h>
#include <mlir/IR/Types.h>
#include <mlir/IR/Value.h>
#include <mlir/IR/Verifier.h>
#include <mlir/Interfaces/CallInterfaces.h>
#include <mlir/Interfaces/DerivedAttributeOpInterface.h>
#include <mlir/Interfaces/InferTypeOpInterface.h>
#include <mlir/Interfaces/LoopLikeInterface.h>
#include <mlir/Interfaces/SideEffectInterfaces.h>
#include <mlir/Parser.h>
#include <mlir/Support/LogicalResult.h>
#include <mlir/Transforms/InliningUtils.h>

#include "tf_attributes.h"
#include "tf_side_effects.h"
#include "tf_traits.h"

namespace mlir {

static LogicalResult Verify(...)
{
    return success();
}
static LogicalResult VerifyPartitionedCall(...)
{
    return success();
}
static LogicalResult VerifyStridedSliceBase(...)
{
    return success();
}
static LogicalResult VerifyUnsortedSegmentReduction(...)
{
    return success();
}

namespace TF {

TensorFlowDialect::TensorFlowDialect(MLIRContext* context)
    : Dialect(/*name=*/"tf", context, TypeID::get<TensorFlowDialect>())
{
    addOperations<
#define GET_OP_LIST
#include "tf_all_ops.cc.inc"
        >();
    addTypes<
#define HANDLE_TF_TYPE(tftype, enumerant, name) tftype##Type,
#define HANDLE_LAST_TF_TYPE(tftype, enumerant, name) tftype##Type
#include "tf_types.def"
        >();
    // addInterfaces<TFInlinerInterface, TFDecodeAttributesInterface,
    //               TFConstantFoldInterface>();
    addAttributes<ShapeAttr, FuncAttr>();

    // Support unknown operations because not all TensorFlow operations are
    // registered.
    allowUnknownOperations();

    // for (const auto &hook : *TensorFlowDialect::additional_operation_hooks_) {
    //     hook(*this);
    // }
}

namespace {

ShapeAttr ParseShapeAttr(MLIRContext* context, StringRef spec, Location loc)
{
    auto emit_error = [&, spec]() {
        emitError(loc, "invalid TensorFlow shape attribute: ") << spec;
        return nullptr;
    };

    if (!spec.consume_front("shape<")) return emit_error();

    if (spec.consume_front("*>"))
        return mlir::TF::ShapeAttr::get(context, llvm::None);

    SmallVector<int64_t, 4> shape;
    while (!spec.consume_front(">"))
    {
        int64_t dim;

        if (spec.consume_front("?"))
            dim = -1;
        else if (spec.consumeInteger(10, dim) || dim < 0)
            return emit_error();

        spec.consume_front("x");

        shape.push_back(dim);
    }

    return mlir::TF::ShapeAttr::get(context, llvm::makeArrayRef(shape));
}

// Parses a #tf.func attribute of the following format:
//
//   #tf.func<@symbol, {attr = "value"}>
//
// where the first element is a SymbolRefAttr and the second element is a
// DictionaryAttr.
FuncAttr ParseFuncAttr(MLIRContext* context, StringRef spec, Location loc)
{
    auto emit_error = [&, spec]() {
        emitError(loc, "invalid TensorFlow func attribute: ") << spec;
        return nullptr;
    };

    if (!spec.consume_front("func<")) return emit_error();

    size_t func_name_num_read = 0;
    Attribute func_name_attr = mlir::parseAttribute(spec, context, func_name_num_read);
    if (!func_name_attr || !func_name_attr.isa<SymbolRefAttr>())
        return emit_error();
    spec = spec.drop_front(func_name_num_read);

    if (!spec.consume_front(", ")) return emit_error();

    size_t func_attrs_num_read = 0;
    Attribute func_attrs_attr = mlir::parseAttribute(spec, context, func_attrs_num_read);
    if (!func_attrs_attr || !func_attrs_attr.isa<DictionaryAttr>())
        return emit_error();
    spec = spec.drop_front(func_attrs_num_read);

    if (!spec.consume_front(">")) return emit_error();

    return mlir::TF::FuncAttr::get(context, func_name_attr.cast<SymbolRefAttr>(),
                                   func_attrs_attr.cast<DictionaryAttr>());
}

} // namespace

Attribute TensorFlowDialect::parseAttribute(DialectAsmParser& parser,
                                            Type type) const
{
    auto spec = parser.getFullSymbolSpec();
    Location loc = parser.getEncodedSourceLoc(parser.getNameLoc());

    if (spec.startswith("shape")) return ParseShapeAttr(getContext(), spec, loc);

    if (spec.startswith("func")) return ParseFuncAttr(getContext(), spec, loc);

    return (emitError(loc, "unknown TensorFlow attribute: " + spec), nullptr);
}

// Parses a type registered to this dialect.
Type TensorFlowDialect::parseType(DialectAsmParser& parser) const
{
    StringRef data;
    if (parser.parseKeyword(&data)) return Type();

#define HANDLE_TF_TYPE(tftype, enumerant, name) \
    if (data == name) return tftype##Type::get(getContext());
// Custom TensorFlow types are handled separately at the end as they do partial
// match.
#define HANDLE_CUSTOM_TF_TYPE(tftype, enumerant, name)
// NOLINTNEXTLINE
#include "tf_types.def"

    llvm::SMLoc loc = parser.getNameLoc();
    if (data.startswith("resource"))
    {
        Type ret = ParseResourceType(parser);
        if (!ret) parser.emitError(loc, "invalid resource type");
        return ret;
    }
    if (data.startswith("variant"))
    {
        Type ret = ParseVariantType(parser);
        if (!ret) parser.emitError(loc, "invalid variant type");
        return ret;
    }
    return (parser.emitError(loc, "unknown TensorFlow type: " + data), nullptr);
}

namespace {
template<typename TypeWithSubtype>
Type ParseTypeWithSubtype(MLIRContext* context, DialectAsmParser& parser)
{
    // Default type without inferred subtypes.
    if (failed(parser.parseOptionalLess())) return TypeWithSubtype::get(context);

    // Most types with subtypes have only one subtype.
    SmallVector<TensorType, 1> subtypes;
    do
    {
        TensorType tensor_ty;
        if (parser.parseType(tensor_ty)) return Type();

        // Each of the subtypes should be a valid TensorFlow type.
        // TODO(jpienaar): Remove duplication.
        if (!IsValidTFTensorType(tensor_ty))
        {
            parser.emitError(parser.getNameLoc()) << "invalid subtype: " << tensor_ty;
            return Type();
        }
        subtypes.push_back(tensor_ty);
    } while (succeeded(parser.parseOptionalComma()));

    if (parser.parseGreater()) return Type();

    return TypeWithSubtype::get(subtypes, context);
}
} // anonymous namespace

Type TensorFlowDialect::ParseResourceType(DialectAsmParser& parser) const
{
    return ParseTypeWithSubtype<ResourceType>(getContext(), parser);
}

Type TensorFlowDialect::ParseVariantType(DialectAsmParser& parser) const
{
    return ParseTypeWithSubtype<VariantType>(getContext(), parser);
}

Operation* TensorFlowDialect::materializeConstant(OpBuilder& builder,
                                                  Attribute value, Type type,
                                                  Location loc)
{
    return builder.create<ConstOp>(loc, type, value);
}

// Builds a constant op with the specified attribute `value`. The result
// op's type is deduced from `value`; if `value` is of scalar type,
// wraps it up with a tensor type of empty shape.
// TODO(jpienaar): This one differs from the autogenerated one as it takes an
// attribute but always creates an ElementsAttr internally.
void ConstOp::build(OpBuilder& builder, OperationState& result,
                    Attribute value)
{
    ShapedType type;
    if (auto elem_attr = value.dyn_cast<ElementsAttr>())
    {
        return ConstOp::build(builder, result, elem_attr);
    }
    else if (value.isa<BoolAttr, FloatAttr, IntegerAttr>())
    {
        // All TensorFlow types must be tensor types. In the build() method,
        // we want to provide more flexibility by allowing attributes of scalar
        // types. But we need to wrap it up with ElementsAttr to construct
        // valid TensorFlow constants.
        type = RankedTensorType::get(/*shape=*/ {}, value.getType());
        return ConstOp::build(builder, result, DenseElementsAttr::get(type, value));
    }
    // TODO(jpienaar): support other TensorFlow specific types.
    llvm_unreachable("unsupported attribute type for building tf.Const");
}

void ConstOp::build(OpBuilder& builder, OperationState& result, Type type,
                    Attribute value)
{
    // Handle the case where the type and value are already tensors.
    if (type.isa<TensorType>() && value.isa<ElementsAttr>())
    {
        result.addTypes(type);
        result.addAttribute("value", value);
        return;
    }

    // Otherwise, default to the attribute builder.
    ConstOp::build(builder, result, value);
    assert(type == result.types[0] && "type mismatch in construction");
}

Region& WhileRegionOp::getLoopBody()
{
    return body();
}

bool WhileRegionOp::isDefinedOutsideOfLoop(Value value)
{
    // If the Op defining the value exists and the defining op is outside the
    // scope of this WhileRegion, then we can infer that it's defined outside.
    // The defining Op is outside the scope of this WhileRegion if this
    // WhileRegionOp is not an ancestor of the defining op in the parent chain.
    Operation* def_op = value.getDefiningOp();
    return def_op && !getOperation()->isAncestor(def_op);
}

LogicalResult WhileRegionOp::moveOutOfLoop(
    llvm::ArrayRef<mlir::Operation*> ops)
{
    // Move the hoisted value to just before the while.
    Operation* while_op = this->getOperation();
    for (auto op : ops) op->moveBefore(while_op);
    return success();
}

} // namespace TF

} // namespace mlir

#define GET_OP_CLASSES
#include "tf_all_ops.cc.inc"
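For reference, a minimal sketch (not part of this diff; hypothetical helper) of exercising the ConstOp::build overload above. Passing a bare FloatAttr goes down the scalar branch, which wraps the value in a zero-dimensional DenseElementsAttr:

#include <mlir/IR/Builders.h>

#include "tf_dialect.h"

// Creates a scalar tf.Const of value 0.5 : f32 at `loc`.
mlir::TF::ConstOp makeScalarConst(mlir::OpBuilder& builder, mlir::Location loc)
{
    // create<ConstOp>(loc, attr) forwards to ConstOp::build(builder, state, attr).
    return builder.create<mlir::TF::ConstOp>(loc, builder.getF32FloatAttr(0.5f));
}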
3rdparty/ncnn/tools/mlir/tf_dialect.h | 68 lines (vendored, new file)
@@ -0,0 +1,68 @@
// Tencent is pleased to support the open source community by making ncnn available.
//
// Copyright (C) 2020 THL A29 Limited, a Tencent company. All rights reserved.
//
// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
// in compliance with the License. You may obtain a copy of the License at
//
// https://opensource.org/licenses/BSD-3-Clause
//
// Unless required by applicable law or agreed to in writing, software distributed
// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
// CONDITIONS OF ANY KIND, either express or implied. See the License for the
// specific language governing permissions and limitations under the License.

#ifndef TF_DIALECT_H
#define TF_DIALECT_H

#include <mlir/Dialect/Traits.h>
#include <mlir/IR/BuiltinOps.h>
#include <mlir/IR/Dialect.h>
#include <mlir/IR/OpImplementation.h>
#include <mlir/Interfaces/ControlFlowInterfaces.h>
#include <mlir/Interfaces/DerivedAttributeOpInterface.h>
#include <mlir/Interfaces/InferTypeOpInterface.h>
#include <mlir/Interfaces/LoopLikeInterface.h>
#include <mlir/Interfaces/SideEffectInterfaces.h>

#include "tf_traits.h"

namespace mlir {

namespace TF {

class TensorFlowDialect : public mlir::Dialect
{
public:
    TensorFlowDialect(mlir::MLIRContext* context);

    static StringRef getDialectNamespace()
    {
        return "tf";
    }

    Attribute parseAttribute(DialectAsmParser& parser, Type type) const override;

    // Parse a type registered to this dialect.
    Type parseType(DialectAsmParser& parser) const override;

    // Parses resource type with potential subtypes.
    Type ParseResourceType(DialectAsmParser& parser) const;

    // Parse and print variant type. It may have subtypes inferred using shape
    // inference.
    Type ParseVariantType(DialectAsmParser& parser) const;

    // Registered hook to materialize a constant operation from a given attribute
    // value with the desired resultant type.
    Operation* materializeConstant(OpBuilder& builder, Attribute value, Type type, Location loc) override;
};

} // namespace TF

} // namespace mlir

#define GET_OP_CLASSES
#include "tf_all_ops.h.inc"

#endif // TF_DIALECT_H
3rdparty/ncnn/tools/mlir/tf_generated_ops.td | 18565 lines (vendored, new file)
File diff suppressed because it is too large.
3rdparty/ncnn/tools/mlir/tf_op_base.td | 617 lines (vendored, new file)
@@ -0,0 +1,617 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// This is the base operation definition file for TensorFlow.
//
// This file includes the definition for the TensorFlow dialect, base TensorFlow
// op, and various commonly used TensorFlow traits, types, attributes, and
// builders.

#ifndef TF_OP_BASE
#define TF_OP_BASE

include "mlir/IR/OpBase.td"
include "mlir/Interfaces/SideEffectInterfaces.td"

//===----------------------------------------------------------------------===//
// TensorFlow dialect definitions
//===----------------------------------------------------------------------===//

def TF_Dialect : Dialect {
  let name = "tf";

  let description = [{
    The TensorFlow dialect.

    This dialect maps to TensorFlow operations.

    Invariants:

    * All values are of Tensor type (in particular, scalars are
      represented using zero-dimensional tensors);

    TODO: Make invariants more structured so that we can reference them in ops.
  }];

  let cppNamespace = "::mlir::TF";
}

//===----------------------------------------------------------------------===//
// TensorFlow traits
//===----------------------------------------------------------------------===//

// Specify this trait if the op requires all outputs to have the same type and
// the inputs either have the same type as result or a ref type corresponding to
// the result type.
def TF_OperandsSameAsResultsTypeOrRef : NativeOpTrait<
  "TF::OperandsSameAsResultsTypeOrRef">;

// Op has the same operand and result element types (or type itself, if scalar)
// after resolving reference types (i.e., after converting reference types to
// their corresponding TensorFlow or standard types).
def TF_SameOperandsAndResultElementTypeResolveRef : NativeOpTrait<
  "TF::SameOperandsAndResultElementTypeResolveRef">;

// Op has the same operand and result types after resolving reference types
// (i.e., after converting reference types to their corresponding TensorFlow or
// standard types).
def TF_SameOperandsAndResultTypeResolveRef : NativeOpTrait<
  "TF::SameOperandsAndResultTypeResolveRef">;

// Layout agnostic operations do not depend on the operands data layout (data
// format), as an example all element wise operations are layout agnostic.
def TF_LayoutAgnostic : NativeOpTrait<"TF::LayoutAgnostic">;

// Trait to indicate operations that cannot be duplicated as they might carry
// certain state around within their implementations.
def TF_CannotDuplicate : NativeOpTrait<"TF::CannotDuplicate">;

// Trait to indicate an operation cannot be constant folded.
def TF_NoConstantFold : NativeOpTrait<"TF::NoConstantFold">;

// Coefficient wise binary operation with implicit broadcasting support, for
// example tf.Sub operation.
def TF_CwiseBinary : NativeOpTrait<"TF::CwiseBinary">;

// Coefficient wise unary operation, for example tf.Sqrt operation.
def TF_CwiseUnary : NativeOpTrait<"TF::CwiseUnary">;

// Variant of broadcastable trait that considers TF's subtype behavior.
class TF_OpIsBroadcastableToRes<int opId, int resId> : And<[
    TCOpResIsShapedTypePred<opId, resId>,
    CPred<"mlir::TF::BroadcastCompatible("
              "$_op.getOperand(" # opId # ").getType(), "
              "$_op.getResult(" # resId # ").getType())">]>;


class TF_AllTypesMatchPred<list<string> values> :
    CPred<"TF::AreCastCompatible(llvm::makeArrayRef({" #
          !interleave(values, ", ") # "}))">;

class TF_AllTypesMatch<list<string> names> :
    PredOpTrait<
        "all of {" # !interleave(names, ", ") #
        "} have dynamically equal types ",
        TF_AllTypesMatchPred<
            !foreach(n, names, !subst("$_self", "$" # n, "$_self.getType()"))>>;

//===----------------------------------------------------------------------===//
// Rank/Shape helpers.
//===----------------------------------------------------------------------===//

class TF_OperandIsUnrankedPred<int n> :
    CPred<"$_op.getOperand(" # n # ").getType().isa<UnrankedTensorType>()">;

class TF_ResultIsUnrankedPred<int n> :
    CPred<"$_op.getResult(" # n # ").getType().isa<UnrankedTensorType>()">;

// Returns true if the n-th operand has unknown rank or has rank m.
class TF_OperandHasRank<int n, int m> :
    PredOpTrait<"operand " # n # " is " # m # "-D",
        Or<[TF_OperandIsUnrankedPred<n>,
            CPred<"$_op.getOperand(" # n #
                  ").getType().cast<ShapedType>().getRank() == " # m>]>>;

// Returns true if the n-th result has unknown rank or has rank m.
class TF_ResultHasRank<int n, int m> :
    PredOpTrait<"result " # n # " is " # m # "-D",
        Or<[TF_ResultIsUnrankedPred<n>,
            CPred<"$_op.getResult(" # n #
                  ").getType().cast<ShapedType>().getRank() == " # m>]>>;

//===----------------------------------------------------------------------===//
// TensorFlow op side effects
//===----------------------------------------------------------------------===//

class TF_ResourceBase<string resourceKind> :
    Resource<!strconcat("::mlir::TF::ResourceEffects::", resourceKind)> {
}

def TF_VariableResource : TF_ResourceBase<"Variable">;
def TF_StackResource : TF_ResourceBase<"Stack">;
def TF_TensorArrayResource : TF_ResourceBase<"TensorArray">;
def TF_SummaryResource : TF_ResourceBase<"Summary">;
def TF_LookupTableResource : TF_ResourceBase<"LookupTable">;
def TF_DatasetSeedGeneratorResource : TF_ResourceBase<"DatasetSeedGenerator">;
def TF_DatasetMemoryCacheResource : TF_ResourceBase<"DatasetMemoryCache">;
def TF_DatasetIteratorResource : TF_ResourceBase<"DatasetIterator">;
def TF_TPUEmbeddingResource : TF_ResourceBase<"TPUEmbedding">;

def TF_VariableRead : MemRead<TF_VariableResource>;
def TF_StackRead : MemRead<TF_StackResource>;
def TF_TensorArrayRead : MemRead<TF_TensorArrayResource>;
def TF_LookupTableRead : MemRead<TF_LookupTableResource>;
def TF_DatasetSeedGeneratorRead : MemRead<TF_DatasetSeedGeneratorResource>;
def TF_DatasetMemoryCacheRead : MemRead<TF_DatasetMemoryCacheResource>;
def TF_DatasetIteratorRead : MemRead<TF_DatasetIteratorResource>;

def TF_VariableWrite : MemWrite<TF_VariableResource>;
def TF_StackWrite : MemWrite<TF_StackResource>;
def TF_TensorArrayWrite : MemWrite<TF_TensorArrayResource>;
def TF_SummaryWrite : MemWrite<TF_SummaryResource>;
def TF_LookupTableWrite : MemWrite<TF_LookupTableResource>;
def TF_DatasetSeedGeneratorWrite : MemWrite<TF_DatasetSeedGeneratorResource>;
def TF_DatasetMemoryCacheWrite : MemWrite<TF_DatasetMemoryCacheResource>;
def TF_DatasetIteratorWrite : MemWrite<TF_DatasetIteratorResource>;

def TF_VariableAlloc : MemAlloc<TF_VariableResource>;
def TF_StackAlloc : MemAlloc<TF_StackResource>;
def TF_TensorArrayAlloc : MemAlloc<TF_TensorArrayResource>;
def TF_SummaryAlloc : MemAlloc<TF_SummaryResource>;
def TF_LookupTableAlloc : MemAlloc<TF_LookupTableResource>;
def TF_DatasetSeedGeneratorAlloc : MemAlloc<TF_DatasetSeedGeneratorResource>;
def TF_DatasetMemoryCacheAlloc : MemAlloc<TF_DatasetMemoryCacheResource>;
def TF_DatasetIteratorAlloc : MemAlloc<TF_DatasetIteratorResource>;

def TF_StackFree : MemFree<TF_StackResource>;
def TF_TensorArrayFree : MemFree<TF_TensorArrayResource>;
def TF_SummaryFree : MemFree<TF_SummaryResource>;
def TF_DatasetSeedGeneratorFree : MemFree<TF_DatasetSeedGeneratorResource>;
def TF_DatasetMemoryCacheFree : MemFree<TF_DatasetMemoryCacheResource>;
def TF_DatasetIteratorFree : MemFree<TF_DatasetIteratorResource>;

def TF_TPUEmbeddingSideEffect : MemoryEffects<[MemWrite<TF_TPUEmbeddingResource>]>;

//===----------------------------------------------------------------------===//
// TensorFlow op definitions
//===----------------------------------------------------------------------===//

class TF_Op<string mnemonic, list<OpTrait> traits = []> :
    Op<TF_Dialect, mnemonic, traits>;

//===----------------------------------------------------------------------===//
// TensorFlow attribute definitions
//===----------------------------------------------------------------------===//

class TF_TensorFlowAttr <string name, string description> :
    Attr<CPred<"$_self.isa<mlir::TF::" # name # "Attr>()">,
         "TensorFlow " # description # " attribute">;

def TF_ShapeAttr : TF_TensorFlowAttr<"Shape", "shape"> {
  let returnType = "llvm::Optional<llvm::ArrayRef<int64_t>>";
  let convertFromStorage = "$_self.cast<mlir::TF::ShapeAttr>().getValue()";

  // Create a ranked shape attr by default.
  let constBuilderCall = "mlir::TF::ShapeAttr::get($_builder.getContext(), $0)";
}

def TF_ShapeAttrArray :
    TypedArrayAttrBase<TF_ShapeAttr, "tensorflow shape attribute array">;

//===----------------------------------------------------------------------===//
// TensorFlow type definitions
//===----------------------------------------------------------------------===//

// Any tensor element type defined in the TensorFlow dialect
def TF_TFDialectType :
    Type<CPred<"$_self.isa<mlir::TF::TensorFlowType>()">, "TensorFlow type">;

// Class for any TensorFlow dialect specific type
class TF_TensorFlowType <string name, string description> :
    Type<CPred<"$_self.isa<mlir::TF::" # name # "Type>()">,
         "TensorFlow " # description # " type">,
    BuildableType<"getType<mlir::TF::" # name # "Type>()">;

//===----------------------------------------------------------------------===//
// Reference types

// Float reference types
def TF_Float16Ref : TF_TensorFlowType<"HalfRef", "f16ref">;
def TF_Float32Ref : TF_TensorFlowType<"FloatRef", "f32ref">;
def TF_Float64Ref : TF_TensorFlowType<"DoubleRef", "f64ref">;
def TF_Bfloat16Ref : TF_TensorFlowType<"Bfloat16Ref", "bf16ref">;

// Complex reference types
def TF_Complex64Ref : TF_TensorFlowType<"Complex64Ref", "complex64ref">;
def TF_Complex128Ref : TF_TensorFlowType<"Complex128Ref", "complex128ref">;

// Integer reference types
def TF_Int8Ref : TF_TensorFlowType<"Int8Ref", "i8ref">;
def TF_Int16Ref : TF_TensorFlowType<"Int16Ref", "i16ref">;
def TF_Int32Ref : TF_TensorFlowType<"Int32Ref", "i32ref">;
def TF_Int64Ref : TF_TensorFlowType<"Int64Ref", "i64ref">;

def TF_Uint8Ref : TF_TensorFlowType<"Uint8Ref", "ui8ref">;
def TF_Uint16Ref : TF_TensorFlowType<"Uint16Ref", "ui16ref">;
def TF_Uint32Ref : TF_TensorFlowType<"Uint32Ref", "ui32ref">;
def TF_Uint64Ref : TF_TensorFlowType<"Uint64Ref", "ui64ref">;

// Quantized reference types
def TF_Qint8Ref : TF_TensorFlowType<"Qint8Ref", "qint8ref">;
def TF_Qint16Ref : TF_TensorFlowType<"Qint16Ref", "qint16ref">;
def TF_Qint32Ref : TF_TensorFlowType<"Qint32Ref", "qint32ref">;
def TF_Quint8Ref : TF_TensorFlowType<"Quint8Ref", "quint8ref">;
def TF_Quint16Ref : TF_TensorFlowType<"Quint16Ref", "quint16ref">;

// Other reference types
def TF_BoolRef : TF_TensorFlowType<"BoolRef", "boolref">;
def TF_ResourceRef : TF_TensorFlowType<"ResourceRef", "resourceref">;
def TF_StrRef : TF_TensorFlowType<"StringRef", "stringref">;
def TF_VariantRef : TF_TensorFlowType<"VariantRef", "variantref">;

//===----------------------------------------------------------------------===//
// Integer types (including corresponding reference types)

def TF_Bool : AnyTypeOf<[I<1>, TF_BoolRef], "bool">;

def TF_Int8 : AnyTypeOf<[I8, TF_Int8Ref], "8-bit integer">;
def TF_Int16 : AnyTypeOf<[I16, TF_Int16Ref], "16-bit integer">;
def TF_Int32 : AnyTypeOf<[I32, TF_Int32Ref], "32-bit integer">;
def TF_Int64 : AnyTypeOf<[I64, TF_Int64Ref], "64-bit integer">;
def TF_I32OrI64 : AnyTypeOf<[I32, I64, TF_Int32Ref, TF_Int64Ref],
                            "32/64-bit signed integer">;

def TF_Uint8 : AnyTypeOf<[UI<8>, TF_Uint8Ref], "8-bit unsigned integer">;
def TF_Uint16 : AnyTypeOf<[UI<16>, TF_Uint16Ref], "16-bit unsigned integer">;
def TF_Uint32 : AnyTypeOf<[UI<32>, TF_Uint32Ref], "32-bit unsigned integer">;
def TF_Uint64 : AnyTypeOf<[UI<64>, TF_Uint64Ref], "64-bit unsigned integer">;

// Any unsigned integer type
def TF_UInt : AnyTypeOf<[TF_Uint8, TF_Uint16, TF_Uint32, TF_Uint64],
                        "unsigned integer">;

// Any signed integer type
def TF_SInt : AnyTypeOf<[TF_Int8, TF_Int16, TF_Int32, TF_Int64],
                        "signed integer">;

// Any integer type
def TF_Int : AnyTypeOf<[TF_SInt, TF_UInt], "integer">;

// Tensor types
def TF_BoolTensor : TensorOf<[TF_Bool]>;

def TF_IntTensor : TensorOf<[TF_Int]>;
def TF_Int8Tensor : TensorOf<[TF_Int8]>;
def TF_Int16Tensor : TensorOf<[TF_Int16]>;
def TF_Int32Tensor : TensorOf<[TF_Int32]>;
def TF_Int64Tensor : TensorOf<[TF_Int64]>;
def TF_I32OrI64Tensor : TensorOf<[TF_I32OrI64]>;

def TF_Uint8Tensor : TensorOf<[TF_Uint8]>;
def TF_Uint16Tensor : TensorOf<[TF_Uint16]>;
def TF_Uint32Tensor : TensorOf<[TF_Uint32]>;
def TF_Uint64Tensor : TensorOf<[TF_Uint64]>;

//===----------------------------------------------------------------------===//
// Quantized types (including corresponding reference types)

def TF_Qint8 : AnyTypeOf<
    [TF_TensorFlowType<"Qint8", "qint8">, TF_Qint8Ref],
    "8-bit quantized integer">;
def TF_Qint16 : AnyTypeOf<
    [TF_TensorFlowType<"Qint16", "qint16">, TF_Qint16Ref],
    "16-bit quantized integer">;
def TF_Qint32 : AnyTypeOf<
    [TF_TensorFlowType<"Qint32", "qint32">, TF_Qint32Ref],
    "32-bit quantized integer">;
def TF_Quint8 : AnyTypeOf<
    [TF_TensorFlowType<"Quint8", "quint8">, TF_Quint8Ref],
    "8-bit quantized unsigned integer">;
def TF_Quint16 : AnyTypeOf<
    [TF_TensorFlowType<"Quint16", "quint16">, TF_Quint16Ref],
    "16-bit quantized unsigned integer">;

// Any quantized type
def TF_Quantized : AnyTypeOf<
    [TF_Qint8, TF_Qint16, TF_Qint32, TF_Quint8, TF_Quint16], "quantized">;

//===----------------------------------------------------------------------===//
// Floating-point types (including corresponding reference types)

def TF_Float16 : AnyTypeOf<[F16, TF_Float16Ref], "16-bit float">;
def TF_Float32 : AnyTypeOf<[F32, TF_Float32Ref], "32-bit float">;
def TF_Float64 : AnyTypeOf<[F64, TF_Float64Ref], "64-bit float">;
def TF_Bfloat16 : AnyTypeOf<[BF16, TF_Bfloat16Ref], "bfloat16">;

def TF_F32OrF64 : AnyTypeOf<[TF_Float32, TF_Float64], "32/64-bit float">;

def TF_Float : AnyTypeOf<
    [TF_Float16, TF_Float32, TF_Float64, TF_Bfloat16],
    "floating-point">;

// Tensor types
def TF_FloatTensor : TensorOf<[TF_Float]>;
def TF_F32OrF64Tensor : TensorOf<[TF_F32OrF64]>;
def TF_Float16Tensor : TensorOf<[TF_Float16]>;
def TF_Float32Tensor : TensorOf<[TF_Float32]>;
def TF_Float64Tensor : TensorOf<[TF_Float64]>;
def TF_Bfloat16Tensor : TensorOf<[TF_Bfloat16]>;

//===----------------------------------------------------------------------===//
// Complex types (including corresponding reference types)

// TODO(suderman): Remove TF_Complex64 and use a standard ops declaration, along
// with the associated cleanup.
def TF_Complex64 : AnyTypeOf<[Complex<F<32>>, TF_Complex64Ref],
                             "64-bit complex">;
def TF_Complex128 : AnyTypeOf<[Complex<F<64>>, TF_Complex128Ref],
                              "128-bit complex">;
def TF_Complex : AnyTypeOf<[TF_Complex64, TF_Complex128], "complex">;

// Tensor types
def TF_ComplexTensor : TensorOf<[TF_Complex]>;
def TF_Complex64Tensor : TensorOf<[TF_Complex64]>;
|
||||
def TF_Complex128Tensor : TensorOf<[TF_Complex128]>;
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// String/variant/resource types (including corresponding reference types)
|
||||
|
||||
def TF_Str : AnyTypeOf<
|
||||
[TF_TensorFlowType<"String", "str">, TF_StrRef], "string">;
|
||||
def TF_StrTensor : TensorOf<[TF_Str]>;
|
||||
|
||||
def TF_Variant : AnyTypeOf<
|
||||
[TF_TensorFlowType<"Variant", "var">, TF_VariantRef], "variant">;
|
||||
def TF_VariantTensor : TensorOf<[TF_Variant]>;
|
||||
|
||||
def TF_Resource : AnyTypeOf<
|
||||
[TF_TensorFlowType<"Resource", "res">, TF_ResourceRef], "resource">;
|
||||
def TF_ResourceTensor : TensorOf<[TF_Resource]>;
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Multi-category type constraints
|
||||
|
||||
def TF_IntOrF32OrF64Tensor: TensorOf<[TF_Int, TF_F32OrF64]>;
|
||||
def TF_FpOrI32OrI64Tensor : TensorOf<[TF_Float, TF_I32OrI64]>;
|
||||
def TF_IntOrFpTensor : TensorOf<[TF_Int, TF_Float]>;
|
||||
def TF_SintOrFpTensor : TensorOf<[TF_SInt, TF_Float]>;
|
||||
def TF_FpOrComplexTensor : TensorOf<[TF_Float, TF_Complex]>;
|
||||
|
||||
def TF_Number : AnyTypeOf<
|
||||
[TF_Int, TF_Float, TF_Quantized, TF_Complex], "number">;
|
||||
def TF_NumberTensor : TensorOf<[TF_Number]>;
|
||||
|
||||
def TF_NumberNotQuantizedOrStr :
|
||||
AnyTypeOf<[TF_Float, TF_SInt, TF_Complex, TF_Uint8, TF_Str]>;
|
||||
def TF_NumberNotQuantizedOrStrTensor : TensorOf<[TF_NumberNotQuantizedOrStr]>;
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Tensor and tensor element types
|
||||
|
||||
// Any tensor element type allowed in TensorFlow ops
|
||||
// (see https://www.tensorflow.org/api_docs/python/tf/dtypes/DType)
|
||||
def TF_ElementType : Type<Or<[TF_Float.predicate,
|
||||
TF_Complex.predicate,
|
||||
TF_Int.predicate,
|
||||
TF_Bool.predicate,
|
||||
TF_TFDialectType.predicate]>,
|
||||
"tf.dtype">;
|
||||
|
||||
// Any TensorFlow tensor type
|
||||
def TF_Tensor : TensorOf<[TF_ElementType]>;
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// TensorFlow attribute definitions
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Tensorflow devices metadata
|
||||
|
||||
// Tensorflow GPU device metadata.
|
||||
def TF_GpuDeviceMetadata : StructAttr<"GpuDeviceMetadata", TF_Dialect, [
|
||||
// GPU device compute capability: major:minor.
|
||||
StructFieldAttr<"cc_major", I32Attr>,
|
||||
StructFieldAttr<"cc_minor", I32Attr>
|
||||
]>;
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// String attribute constraints
|
||||
|
||||
// A string attribute whose value are one of the values in `cases`.
|
||||
class TF_AnyStrAttrOf<list<string> cases> : StringBasedAttr<
|
||||
CPred<!foldl(
|
||||
"$_self.cast<StringAttr>().getValue() == \"" # !head(cases) # "\"",
|
||||
!foreach(case, !tail(cases),
|
||||
"$_self.cast<StringAttr>().getValue() == \"" # case # "\""),
|
||||
prev, cur, prev # " || " # cur)>,
|
||||
"string attribute whose value is " #
|
||||
!foldl(/*init*/!head(cases), /*list*/!tail(cases),
|
||||
prev, cur, prev # ", or " # cur)>;
|
||||
|
||||
// TODO: Use EnumAttr to define the common attribute cases
|
||||
|
||||
def TF_ConvnetDataFormatAttr : StringBasedAttr<
|
||||
CPred<"$_self.cast<StringAttr>().getValue() == \"NHWC\" || " #
|
||||
"$_self.cast<StringAttr>().getValue() == \"NCHW\"">,
|
||||
"'NHWC' or 'NCHW' convnet data format">;
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Type attributes
|
||||
|
||||
// A derived attribute that returns the size of `idx`-th ODS-declared variadic
|
||||
// operand.
|
||||
class TF_DerivedOperandSizeAttr<int idx> : DerivedAttr<
|
||||
"size_t",
|
||||
"auto range = getODSOperands(" # idx # ");\n"
|
||||
"return std::distance(range.begin(), range.end());",
|
||||
[{ $_builder.getI64IntegerAttr($_self) }]>;
|
||||
|
||||
// A derived attribute that returns the element type of `idx`-th ODS-declared
|
||||
// operand. If the `idx`-th operand is a variadic operand, then this attribute
|
||||
// just returns the element type of its first tensor, which is only meaningful
|
||||
// when the variadic operand has at least one tensor and the tensors all have
|
||||
// the same element type.
|
||||
class TF_DerivedOperandTypeAttr<int idx> : DerivedTypeAttr<
|
||||
"return mlir::getElementTypeOrSelf(*getODSOperands(" # idx # ").begin());">;
|
||||
|
||||
// A derived attribute that returns the element types of the tensors in the
|
||||
// actual value pack that corresponds to the `idx`-th ODS-declared variadic
|
||||
// operand. This returns a list of element types so it is used for variadic
|
||||
// operands that can have different element types.
|
||||
class TF_DerivedOperandTypeListAttr<int idx> : DerivedAttr<
|
||||
"mlir::OperandElementTypeRange",
|
||||
"auto values = getODSOperands(" # idx # ");\n"
|
||||
"return {mlir::OperandElementTypeIterator(values.begin()), "
|
||||
"mlir::OperandElementTypeIterator(values.end())};",
|
||||
[{
|
||||
ArrayAttr::get($_ctx,
|
||||
[&]() {
|
||||
llvm::SmallVector<Attribute, 4> ret;
|
||||
for (auto t : $_self)
|
||||
ret.push_back(TypeAttr::get(t));
|
||||
return ret;
|
||||
}())
|
||||
}]
|
||||
>;
|
||||
|
||||
// A derived attribute that returns the shapes of the tensors in the actual
|
||||
// value pack that corresponds to the `idx`-th ODS-declared variadic operand.
|
||||
// This returns a list of shapes so it is used for variadic operands that
|
||||
// can have different shapes.
|
||||
class TF_DerivedOperandShapeListAttr<int idx> : DerivedAttr<
|
||||
"::mlir::TF::OperandShapeRange",
|
||||
"auto values = getODSOperands(" # idx # ");\n"
|
||||
"return {mlir::TF::OperandShapeIterator(values.begin()), "
|
||||
"mlir::TF::OperandShapeIterator(values.end())};",
|
||||
[{
|
||||
ArrayAttr::get($_ctx,
|
||||
[&](){
|
||||
llvm::SmallVector<Attribute, 4> ret;
|
||||
for (auto shape : $_self)
|
||||
ret.push_back(mlir::TF::ShapeAttr::get($_ctx, shape));
|
||||
return ret;
|
||||
}())
|
||||
}]
|
||||
>;
|
||||
|
||||
// A derived attribute that returns the size of `idx`-th ODS-declared variadic
|
||||
// result.
|
||||
class TF_DerivedResultSizeAttr<int idx> : DerivedAttr<
|
||||
"size_t",
|
||||
"auto range = getODSResults(" # idx # ");\n"
|
||||
"return std::distance(range.begin(), range.end());",
|
||||
[{ $_builder.getI64IntegerAttr($_self) }]>;
|
||||
|
||||
// A derived attribute that returns the element type of `idx`-th ODS-declared
|
||||
// result. If the `idx`-th result is a variadic result, then this attribute
|
||||
// just returns the element type of its first tensor, which is only meaningful
|
||||
// when the variadic result has at least one tensor and the tensors all have
|
||||
// the same element type.
|
||||
class TF_DerivedResultTypeAttr<int idx> : DerivedTypeAttr<
|
||||
"return mlir::getElementTypeOrSelf(*getODSResults(" # idx # ").begin());">;
|
||||
|
||||
// A derived attribute that returns the element types of the tensors in the
|
||||
// actual value pack that corresponds to the `idx`-th ODS-declared variadic
|
||||
// result. This returns a list of element types so it is used for variadic
|
||||
// results that can have different element types.
|
||||
class TF_DerivedResultTypeListAttr<int idx> : DerivedAttr<
|
||||
"mlir::ResultElementTypeRange",
|
||||
"auto values = getODSResults(" # idx # ");\n"
|
||||
"return {mlir::ResultElementTypeIterator(values.begin()), "
|
||||
"mlir::ResultElementTypeIterator(values.end())};",
|
||||
[{
|
||||
ArrayAttr::get($_ctx,
|
||||
[&]() {
|
||||
llvm::SmallVector<Attribute, 4> ret;
|
||||
for (auto t : $_self)
|
||||
ret.push_back(TypeAttr::get(t));
|
||||
return ret;
|
||||
}())
|
||||
}]
|
||||
>;
|
||||
|
||||
// A derived attribute that returns the shapes of the tensors in the actual
|
||||
// value pack that corresponds to the `idx`-th ODS-declared variadic result.
|
||||
// This returns a list of shapes so it is used for variadic results that
|
||||
// can have different shapes.
|
||||
class TF_DerivedResultShapeListAttr<int idx> : DerivedAttr<
|
||||
"mlir::TF::ResultShapeRange",
|
||||
"auto values = getODSResults(" # idx # ");\n"
|
||||
"return {mlir::TF::ResultShapeIterator(values.begin()), "
|
||||
"mlir::TF::ResultShapeIterator(values.end())};",
|
||||
[{
|
||||
ArrayAttr::get($_ctx,
|
||||
[&](){
|
||||
llvm::SmallVector<Attribute, 4> ret;
|
||||
for (auto shape : $_self)
|
||||
ret.push_back(mlir::TF::ShapeAttr::get($_ctx, shape));
|
||||
return ret;
|
||||
}())
|
||||
}]
|
||||
>;
|
||||
|
||||
// A derived attribute that returns the shape of the first result type.
|
||||
def TF_DerivedResultShapeAttr : DerivedAttr<"ShapedType",
|
||||
"return (*getOperation()->result_type_begin()).cast<ShapedType>();",
|
||||
[{ mlir::TF::ShapeAttr::get($_ctx, $_self) }]>;
|
||||
|
||||
def TF_IntTypeAttr : TypeAttrBase<"IntegerType", "integer type"> {
|
||||
let returnType = "Type";
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// TensorFlow common builders
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
// Mixin class defining a builder for binary ops supporting broadcast
|
||||
// behavior. The result type has the same element type as both operands.
|
||||
class WithBroadcastableBinOpBuilder {
|
||||
list<OpBuilder> builders = [
|
||||
OpBuilder<(ins "Value":$x, "Value":$y),
|
||||
[{
|
||||
auto resultType =
|
||||
OpTrait::util::getBroadcastedType(x.getType(), y.getType());
|
||||
if (!resultType)
|
||||
mlir::emitError($_state.location, "non-broadcastable operands");
|
||||
return build($_builder, $_state, resultType, x, y);
|
||||
}]>];
|
||||
}
|
||||
|
||||
// Mixin class defining a builder for comparison ops supporting broadcast
|
||||
// behavior. The result type has bool element type.
|
||||
class WithBroadcastableCmpOpBuilder {
|
||||
list<OpBuilder> builders = [
|
||||
OpBuilder<(ins "Value":$x, "Value":$y),
|
||||
[{
|
||||
Type resultType;
|
||||
if (x.getType().isa<UnrankedTensorType>() ||
|
||||
y.getType().isa<UnrankedTensorType>()) {
|
||||
resultType = UnrankedTensorType::get($_builder.getI1Type());
|
||||
} else {
|
||||
SmallVector<int64_t, 4> resultShape;
|
||||
if (!OpTrait::util::getBroadcastedShape(
|
||||
x.getType().cast<ShapedType>().getShape(),
|
||||
y.getType().cast<ShapedType>().getShape(), resultShape)) {
|
||||
mlir::emitError($_state.location,
|
||||
"operands have no broadcastable shapes");
|
||||
}
|
||||
|
||||
resultType = RankedTensorType::get(resultShape, $_builder.getI1Type());
|
||||
}
|
||||
return build($_builder, $_state, resultType, x, y);
|
||||
}]>];
|
||||
}
|
||||
|
||||
#endif // TF_OP_BASE
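The two builder mixins above inline C++ into the generated op builders. As a rough illustration of what that inlined logic does at the C++ level, here is a minimal free-standing sketch; it is not part of this commit, `buildBroadcastableBinOp` is a hypothetical helper, and only the MLIR calls already referenced in the TableGen above are used.

// Minimal sketch (assumptions noted above) of the broadcastable binary
// builder: the result element type matches the operands, the result shape is
// the broadcast of the two operand shapes.
#include "mlir/Dialect/Traits.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/Diagnostics.h"

void buildBroadcastableBinOp(mlir::OpBuilder& builder,
                             mlir::OperationState& state,
                             mlir::Value x, mlir::Value y)
{
    // Same call as in WithBroadcastableBinOpBuilder; returns a null Type if
    // the operand shapes cannot be broadcast together.
    mlir::Type resultType =
        mlir::OpTrait::util::getBroadcastedType(x.getType(), y.getType());
    if (!resultType)
        mlir::emitError(state.location, "non-broadcastable operands");
    state.addOperands({x, y});
    state.addTypes({resultType});
}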
2037
3rdparty/ncnn/tools/mlir/tf_ops.td
vendored
Normal file
File diff suppressed because it is too large
106
3rdparty/ncnn/tools/mlir/tf_side_effects.h
vendored
Normal file
@ -0,0 +1,106 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// This is the side effect definition file for TensorFlow.
#ifndef TENSORFLOW_COMPILER_MLIR_TENSORFLOW_IR_TF_SIDE_EFFECTS_H_
#define TENSORFLOW_COMPILER_MLIR_TENSORFLOW_IR_TF_SIDE_EFFECTS_H_

#include "mlir/Interfaces/SideEffectInterfaces.h" // from @llvm-project

namespace mlir {
namespace TF {
namespace ResourceEffects {

struct Variable : ::mlir::SideEffects::Resource::Base<Variable>
{
    StringRef getName() final
    {
        return "Variable";
    }
};

struct Stack : ::mlir::SideEffects::Resource::Base<Stack>
{
    StringRef getName() final
    {
        return "Stack";
    }
};

struct TensorArray : ::mlir::SideEffects::Resource::Base<TensorArray>
{
    StringRef getName() final
    {
        return "TensorArray";
    }
};

struct Summary : ::mlir::SideEffects::Resource::Base<Summary>
{
    StringRef getName() final
    {
        return "Summary";
    }
};

struct LookupTable : ::mlir::SideEffects::Resource::Base<LookupTable>
{
    StringRef getName() final
    {
        return "LookupTable";
    }
};

struct DatasetSeedGenerator
    : ::mlir::SideEffects::Resource::Base<DatasetSeedGenerator>
{
    StringRef getName() final
    {
        return "DatasetSeedGenerator";
    }
};

struct DatasetMemoryCache
    : ::mlir::SideEffects::Resource::Base<DatasetMemoryCache>
{
    StringRef getName() final
    {
        return "DatasetMemoryCache";
    }
};

struct DatasetIterator : ::mlir::SideEffects::Resource::Base<DatasetIterator>
{
    StringRef getName() final
    {
        return "DatasetIterator";
    }
};

// Special resource type to track TPU Embedding specific ops, which must execute
// but do not have side effects with one another or with resource variable ops.
struct TPUEmbedding : ::mlir::SideEffects::Resource::Base<TPUEmbedding>
{
    StringRef getName() final
    {
        return "TPUEmbedding";
    }
};

} // namespace ResourceEffects
} // namespace TF
} // namespace mlir

#endif // TENSORFLOW_COMPILER_MLIR_TENSORFLOW_IR_TF_SIDE_EFFECTS_H_
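This header only declares the resource kinds; for context, the sketch below shows how such a resource is typically consumed through the standard MLIR side effect interface. The surrounding op class is omitted, and the `getEffects` hook shown here is a hypothetical member of an op implementing `MemoryEffectOpInterface`, not code from this commit.

// Minimal sketch: attach a write effect scoped to the Variable resource, so
// the effect system can order variable-touching ops correctly.
#include "llvm/ADT/SmallVector.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"

void getEffects(
    llvm::SmallVectorImpl<mlir::SideEffects::EffectInstance<
        mlir::MemoryEffects::Effect> >& effects)
{
    // A write on the shared Variable resource declared above.
    effects.emplace_back(mlir::MemoryEffects::Write::get(),
                         mlir::TF::ResourceEffects::Variable::get());
}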
189
3rdparty/ncnn/tools/mlir/tf_traits.h
vendored
Normal file
@ -0,0 +1,189 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// This file defines the op traits used in the MLIR TensorFlow dialect.

#ifndef TENSORFLOW_COMPILER_MLIR_TENSORFLOW_IR_TF_TRAITS_H_
#define TENSORFLOW_COMPILER_MLIR_TENSORFLOW_IR_TF_TRAITS_H_

#include "mlir/IR/BuiltinTypes.h" // from @llvm-project
#include "mlir/IR/OpDefinition.h" // from @llvm-project
#include "mlir/IR/TypeUtilities.h" // from @llvm-project
#include "mlir/Interfaces/SideEffectInterfaces.h" // from @llvm-project
#include "mlir/Support/LogicalResult.h" // from @llvm-project
#include "tf_types.h"

namespace mlir {
namespace OpTrait {
namespace TF {

// Verifies if 'ref_type' is a REF type corresponding to 'type'.
static inline LogicalResult VerifyRefTypeMatch(mlir::Type type,
                                               mlir::Type maybe_ref_type)
{
    if (auto ref_type = maybe_ref_type.dyn_cast<mlir::TF::TensorFlowRefType>())
        return success(ref_type.RemoveRef().getTypeID() == type.getTypeID());
    return failure();
}

// This class provides verification for ops that are known to have the same
// result types and all operands are either of the same type as result or a REF
// type corresponding to the result type.
// TODO(jpienaar): Update the name and the description.
template<typename ConcreteType>
class OperandsSameAsResultsTypeOrRef
    : public TraitBase<ConcreteType, OperandsSameAsResultsTypeOrRef>
{
public:
    static LogicalResult verifyTrait(Operation* op)
    {
        LogicalResult shapeMatch = impl::verifySameOperandsAndResultShape(op);
        if (failed(shapeMatch)) return shapeMatch;
        Type type = op->getResult(0).getType();
        // Verify that the first result type is same as the rest of the results.
        // We skip the comparison against itself.
        for (auto result_type : llvm::drop_begin(op->getResultTypes(), 1))
        {
            if (!mlir::TF::HasCompatibleElementTypes(type, result_type))
                return op->emitOpError()
                       << "requires all return types to have compatible element types";
        }
        for (auto operand_type : op->getOperandTypes())
        {
            if (!mlir::TF::HasCompatibleElementTypes(
                    operand_type, type, /*may_ignore_ref_type_lhs=*/true))
                return op->emitError() << "requires all operands and results to have "
                                          "compatible element types";
        }
        return success();
    }
};

namespace detail {
inline LogicalResult verifySameOperandsAndResultElementTypeResolveRef(
    Operation* op)
{
    Type element_type;
    if (op->getNumResults() > 0)
    {
        element_type = mlir::TF::GetElementTypeOrSelfResolveRef(op->getResult(0).getType());
    }
    else if (op->getNumOperands() > 0)
    {
        element_type = mlir::TF::GetElementTypeOrSelfResolveRef(op->getOperand(0).getType());
    }
    else
    {
        // Nothing to check.
        return success();
    }
    // Verify that all result element types are compatible to `element_type`.
    for (const auto& result_type : op->getResultTypes())
    {
        if (mlir::TF::GetElementTypeOrSelfResolveRef(result_type) != element_type)
        {
            return op->emitOpError(
                "requires compatible element types for all operands and results");
        }
    }
    // Verify that all operand element types are compatible to `element_type`.
    for (const auto& operand_type : op->getOperandTypes())
    {
        if (mlir::TF::GetElementTypeOrSelfResolveRef(operand_type) != element_type)
        {
            return op->emitOpError(
                "requires compatible element types for all operands and results");
        }
    }
    return success();
}
} // namespace detail

// Verifies that op has the same operand and result element types (or type
// itself, if scalar) after resolving reference types (i.e., after converting
// reference types to their corresponding TensorFlow or standard types).
template<typename ConcreteType>
class SameOperandsAndResultElementTypeResolveRef
    : public TraitBase<ConcreteType,
                       SameOperandsAndResultElementTypeResolveRef>
{
public:
    static LogicalResult verifyTrait(Operation* op)
    {
        return detail::verifySameOperandsAndResultElementTypeResolveRef(op);
    }
};

// Verifies that op has the same operand and result types after resolving
// reference types (i.e., after converting reference types to their
// corresponding TensorFlow or standard types).
template<typename ConcreteType>
class SameOperandsAndResultTypeResolveRef
    : public TraitBase<ConcreteType, SameOperandsAndResultTypeResolveRef>
{
public:
    static LogicalResult verifyTrait(Operation* op)
    {
        if (failed(impl::verifySameOperandsAndResultShape(op))) return failure();
        return detail::verifySameOperandsAndResultElementTypeResolveRef(op);
    }
};

// Layout agnostic operations do not depend on the operands' data layout (data
// format); as an example, all element-wise operations are layout agnostic.
template<typename ConcreteType>
class LayoutAgnostic : public TraitBase<ConcreteType, LayoutAgnostic>
{
};

// Trait to indicate operations that cannot be duplicated as they might carry
// certain state around within their implementations.
template<typename ConcreteType>
class CannotDuplicate : public TraitBase<ConcreteType, CannotDuplicate>
{
public:
    static LogicalResult verifyTrait(Operation* op)
    {
        if (MemoryEffectOpInterface::hasNoEffect(op))
            return op->emitError(
                "operations with no side effects cannot have CannotDuplicate trait");
        return success();
    }
};

// Trait to indicate an operation cannot be constant folded.
template<typename ConcreteType>
class NoConstantFold : public TraitBase<ConcreteType, NoConstantFold>
{
};

// Coefficient-wise binary operation with implicit broadcasting support, for
// example tf.Sub operation.
template<typename ConcreteType>
class CwiseBinary : public TraitBase<ConcreteType, CwiseBinary>
{
};

// Coefficient-wise unary operation, for example tf.Sqrt operation.
template<typename ConcreteType>
class CwiseUnary : public TraitBase<ConcreteType, CwiseUnary>
{
};

} // namespace TF
} // namespace OpTrait
} // namespace mlir

#endif // TENSORFLOW_COMPILER_MLIR_TENSORFLOW_IR_TF_TRAITS_H_
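As a hedged sketch of how these traits attach to an op: in MLIR they are passed as extra template parameters to `mlir::Op`, and each trait's `verifyTrait` runs during operation verification. `MySqrtOp` below is hypothetical and not part of this commit; real ops in this tree are generated from the `.td` files instead.

// Minimal sketch: a unary op that mixes in CwiseUnary and the
// ref-resolving same-type trait defined above.
#include "mlir/IR/OpDefinition.h"
#include "tf_traits.h"

class MySqrtOp
    : public mlir::Op<MySqrtOp,
                      mlir::OpTrait::OneOperand,
                      mlir::OpTrait::OneResult,
                      mlir::OpTrait::TF::CwiseUnary,
                      mlir::OpTrait::TF::SameOperandsAndResultTypeResolveRef>
{
public:
    using Op::Op;
    static llvm::StringRef getOperationName()
    {
        return "tf.MySqrt";
    }
};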
462
3rdparty/ncnn/tools/mlir/tf_types.cc
vendored
Normal file
@ -0,0 +1,462 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tf_types.h"

#include "llvm/Support/ErrorHandling.h"
#include "mlir/Dialect/Traits.h" // from @llvm-project
#include "mlir/IR/BuiltinTypes.h" // from @llvm-project
#include "mlir/IR/Dialect.h" // from @llvm-project
#include "mlir/IR/TypeUtilities.h" // from @llvm-project

namespace {
// Returns the shape of the given value if it's ranked; returns llvm::None
// otherwise.
llvm::Optional<llvm::ArrayRef<int64_t> > GetShape(mlir::Value value)
{
    auto shaped_type = value.getType().cast<mlir::ShapedType>();
    if (shaped_type.hasRank()) return shaped_type.getShape();
    return llvm::None;
}

// Merges cast compatible shapes and returns a more refined shape. The two
// shapes are cast compatible if they have the same rank and at each dimension,
// either both have same size or one of them is dynamic. Returns false if the
// given shapes are not cast compatible. The refined shape is same or more
// precise than the two input shapes.
bool GetCastCompatibleShape(llvm::ArrayRef<int64_t> a_shape,
                            llvm::ArrayRef<int64_t> b_shape,
                            llvm::SmallVectorImpl<int64_t>* refined_shape)
{
    if (a_shape.size() != b_shape.size()) return false;
    int64_t rank = a_shape.size();
    refined_shape->reserve(rank);
    for (auto dims : llvm::zip(a_shape, b_shape))
    {
        int64_t dim1 = std::get<0>(dims);
        int64_t dim2 = std::get<1>(dims);

        if (mlir::ShapedType::isDynamic(dim1))
        {
            refined_shape->push_back(dim2);
            continue;
        }
        if (mlir::ShapedType::isDynamic(dim2))
        {
            refined_shape->push_back(dim1);
            continue;
        }
        if (dim1 == dim2)
        {
            refined_shape->push_back(dim1);
            continue;
        }
        return false;
    }
    return true;
}

} // namespace

namespace mlir {
namespace TF {
//===----------------------------------------------------------------------===//
// Utility iterators
//===----------------------------------------------------------------------===//

OperandShapeIterator::OperandShapeIterator(Operation::operand_iterator it)
    : llvm::mapped_iterator<Operation::operand_iterator,
                            llvm::Optional<ArrayRef<int64_t> > (*)(Value)>(
          it, &GetShape)
{
}

ResultShapeIterator::ResultShapeIterator(Operation::result_iterator it)
    : llvm::mapped_iterator<Operation::result_iterator,
                            llvm::Optional<ArrayRef<int64_t> > (*)(Value)>(
          it, &GetShape)
{
}

//===----------------------------------------------------------------------===//
// TF types helper functions
//===----------------------------------------------------------------------===//

bool TensorFlowType::classof(Type type)
{
    return type.getDialect().getNamespace() == "tf";
}
bool TensorFlowRefType::classof(Type type)
{
    return type.isa<
#define HANDLE_TF_TYPE(tftype, enumerant, name)
#define HANDLE_TF_REF_TYPE(tftype, enumerant, name) tftype##Type,
#define HANDLE_LAST_TF_TYPE(tftype, enumerant, name) tftype##Type
        // NOLINTNEXTLINE
#include "tf_types.def"
        >();
}
bool TensorFlowTypeWithSubtype::classof(Type type)
{
    return type.isa<ResourceType, VariantType>();
}

TensorFlowType TensorFlowRefType::get(Type type)
{
    MLIRContext* ctx = type.getContext();
    type = getElementTypeOrSelf(type);
    if (type.isF16())
    {
        return HalfRefType::get(ctx);
    }
    else if (type.isF32())
    {
        return FloatRefType::get(ctx);
    }
    else if (type.isF64())
    {
        return DoubleRefType::get(ctx);
    }
    else if (type.isBF16())
    {
        return Bfloat16RefType::get(ctx);
    }
    else if (auto complex_type = type.dyn_cast<ComplexType>())
    {
        Type etype = complex_type.getElementType();
        if (etype.isF32())
        {
            return Complex64RefType::get(ctx);
        }
        else if (etype.isF64())
        {
            return Complex128RefType::get(ctx);
        }
        llvm_unreachable("unexpected complex type");
    }
    else if (auto itype = type.dyn_cast<IntegerType>())
    {
        switch (itype.getWidth())
        {
        case 1:
            return BoolRefType::get(ctx);
        case 8:
            return itype.isUnsigned() ? TensorFlowType(Uint8RefType::get(ctx))
                                      : Int8RefType::get(ctx);
        case 16:
            return itype.isUnsigned() ? TensorFlowType(Uint16RefType::get(ctx))
                                      : Int16RefType::get(ctx);
        case 32:
            return itype.isUnsigned() ? TensorFlowType(Uint32RefType::get(ctx))
                                      : Int32RefType::get(ctx);
        case 64:
            return itype.isUnsigned() ? TensorFlowType(Uint64RefType::get(ctx))
                                      : Int64RefType::get(ctx);
        default:
            llvm_unreachable("unexpected integer type");
        }
    }
#define HANDLE_TF_TYPE(tftype, enumerant, name)          \
    if (auto derived_ty = type.dyn_cast<tftype##Type>()) \
        return tftype##RefType::get(ctx);

#define HANDLE_TF_REF_TYPE(tftype, enumerant, name)
// NOLINTNEXTLINE
#include "tf_types.def"
    llvm_unreachable("unexpected type kind");
}

Type TensorFlowRefType::RemoveRef()
{
    MLIRContext* ctx = getContext();
    if (isa<HalfRefType>()) return mlir::FloatType::getF16(ctx);
    if (isa<FloatRefType>()) return mlir::FloatType::getF32(ctx);
    if (isa<DoubleRefType>()) return mlir::FloatType::getF64(ctx);
    if (isa<Bfloat16RefType>()) return mlir::FloatType::getBF16(ctx);
    if (isa<BoolRefType>()) return mlir::IntegerType::get(ctx, 1);
    if (isa<Int8RefType>()) return mlir::IntegerType::get(ctx, 8);
    if (isa<Int16RefType>()) return mlir::IntegerType::get(ctx, 16);
    if (isa<Int32RefType>()) return mlir::IntegerType::get(ctx, 32);
    if (isa<Int64RefType>()) return mlir::IntegerType::get(ctx, 64);
    if (isa<Uint8RefType>())
        return mlir::IntegerType::get(ctx, 8, IntegerType::Unsigned);
    if (isa<Uint16RefType>())
        return mlir::IntegerType::get(ctx, 16, IntegerType::Unsigned);
    if (isa<Uint32RefType>())
        return mlir::IntegerType::get(ctx, 32, IntegerType::Unsigned);
    if (isa<Uint64RefType>())
        return mlir::IntegerType::get(ctx, 64, IntegerType::Unsigned);
    if (isa<Complex64RefType>())
        return mlir::ComplexType::get(mlir::FloatType::getF32(ctx));
    if (isa<Complex128RefType>())
        return mlir::ComplexType::get(mlir::FloatType::getF64(ctx));
#define HANDLE_TF_TYPE(tftype, enumerant, name) \
    if (isa<tftype##RefType>()) return tftype##Type::get(ctx);

#define HANDLE_TF_REF_TYPE(tftype, enumerant, name)
// NOLINTNEXTLINE
#include "tf_types.def"
    llvm_unreachable("unexpected tensorflow ref type kind");
}

Type TensorFlowTypeWithSubtype::RemoveSubtypes()
{
    MLIRContext* ctx = getContext();
    if (isa<VariantType>()) return VariantType::get(ctx);
    if (isa<ResourceType>()) return ResourceType::get(ctx);
    llvm_unreachable("unexpected tensorflow type with subtypes kind");
}

ArrayRef<TensorType> TensorFlowTypeWithSubtype::GetSubtypes()
{
    if (auto variant_type = dyn_cast<VariantType>())
        return variant_type.getSubtypes();
    if (auto resource_type = dyn_cast<ResourceType>())
        return resource_type.getSubtypes();
    llvm_unreachable("unexpected tensorflow type with subtypes kind");
}

// TODO(jpienaar): BroadcastCompatible and HasCompatibleElementTypes have
// similar structure that could be extracted into helper method.
bool BroadcastCompatible(TypeRange lhs, TypeRange rhs)
{
    if (lhs.size() != rhs.size()) return false;
    for (auto types : llvm::zip(lhs, rhs))
    {
        // Drop ref types because they don't affect broadcast compatibility. E.g.,
        // `tensor<!tf.f32ref>` and `tensor<f32>` should be considered broadcast
        // compatible.
        auto lhs_type = DropRefType(std::get<0>(types));
        auto rhs_type = DropRefType(std::get<1>(types));

        // This should be true for all TF ops:
        auto lhs_tt = lhs_type.dyn_cast<TensorType>();
        auto rhs_tt = rhs_type.dyn_cast<TensorType>();
        if (!lhs_tt || !rhs_tt)
        {
            if (lhs_type != rhs_type) return false;
            continue;
        }

        // Verify matching element types. These should be identical, except for
        // variant type where unknown subtype is considered compatible with all
        // subtypes.
        auto lhs_et = lhs_tt.getElementType();
        auto rhs_et = rhs_tt.getElementType();
        if (lhs_et != rhs_et)
        {
            // If either does not have subtypes, then the element types don't match.
            auto lhs_wst = lhs_et.dyn_cast<TF::TensorFlowTypeWithSubtype>();
            auto rhs_wst = rhs_et.dyn_cast<TF::TensorFlowTypeWithSubtype>();
            if (!lhs_wst || !rhs_wst) return false;

            // Consider the subtype of variant types.
            auto lhs_wst_st = lhs_wst.GetSubtypes();
            auto rhs_wst_st = rhs_wst.GetSubtypes();
            if (!lhs_wst_st.empty() && !rhs_wst_st.empty())
            {
                for (auto subtypes : llvm::zip(lhs_wst_st, rhs_wst_st))
                {
                    if (!BroadcastCompatible(std::get<0>(subtypes),
                                             std::get<1>(subtypes)))
                        return false;
                }
            }
        }

        auto lhs_rt = lhs_type.dyn_cast<RankedTensorType>();
        auto rhs_rt = rhs_type.dyn_cast<RankedTensorType>();
        if (!lhs_rt || !rhs_rt) return true;
        SmallVector<int64_t, 4> shape;
        return OpTrait::util::getBroadcastedShape(lhs_rt.getShape(),
                                                  rhs_rt.getShape(), shape);
    }
    return true;
}

// Given two types `a` and `b`, returns a refined type which is cast compatible
// with both `a` and `b` and is equal to or more precise than both of them. It
// returns empty Type if the input types are not cast compatible.
//
// The two types are considered cast compatible if they have dynamically equal
// shapes and element type. For element types that do not have subtypes, they
// must be equal. However for TensorFlow types such as Resource and Variant,
// that also have subtypes, we recursively check for subtype compatibility for
// Resource types and assume all variant types are cast compatible. If either
// one of `a` or `b` have empty subtypes, they are considered cast compatible.
//
// The returned type is same or more precise than the input types. For example,
// if `a` and `b` are cast compatible types tensor<2x?x?xf32> and
// tensor<?x4x?xf32> respectively, the returned type is tensor<2x4x?xf32>.
//
// Provides option to ignore ref types on 'a'. This is useful for TF ops that
// might allow operands to either be same as result type or be a ref type
// corresponding to it.
mlir::Type GetCastCompatibleType(mlir::Type a, mlir::Type b,
                                 bool may_ignore_ref_type_a)
{
    // Fast path if everything is equal.
    if (a == b) return b;

    auto a_tt = a.dyn_cast<mlir::TensorType>();
    auto b_tt = b.dyn_cast<mlir::TensorType>();

    // If only one of a or b is a tensor type, they are incompatible.
    if (static_cast<bool>(a_tt) ^ static_cast<bool>(b_tt)) return nullptr;

    // For non-tensor types, we do not need to worry about shape and can return
    // early.
    if (!a_tt && !b_tt)
    {
        // Remove ref types.
        if (may_ignore_ref_type_a)
        {
            if (auto ref_type = a.dyn_cast<mlir::TF::TensorFlowRefType>())
            {
                a = ref_type.RemoveRef();
                if (a == b) return a;
            }
        }
        if (a.getTypeID() != b.getTypeID()) return nullptr;

        // If either is not a type that contains subtypes then the types are not
        // cast compatible.
        auto a_wst = a.dyn_cast<mlir::TF::TensorFlowTypeWithSubtype>();
        auto b_wst = b.dyn_cast<mlir::TF::TensorFlowTypeWithSubtype>();
        if (!a_wst || !b_wst) return nullptr;

        // For Variant types we are more permissive right now and accept all pairs
        // of Variant types. If we are more constrained and check compatibility of
        // subtypes, we might reject valid graphs.
        // TODO(prakalps): Variant doesn't have a subtype, we assign it
        // one, so we should only assign it one when we know the subtype. Then we
        // can be more constrained and check subtypes for cast compatibility as
        // well.
        if (a.isa<mlir::TF::VariantType>()) return a;

        // For Resource types, we recursively check the subtypes for cast
        // compatibility, if possible. Otherwise treat them as compatible.
        auto a_wst_st = a_wst.GetSubtypes();
        auto b_wst_st = b_wst.GetSubtypes();
        if (a_wst_st.empty() || b_wst_st.empty()) return a;
        if (a_wst_st.size() != b_wst_st.size()) return nullptr;
        llvm::SmallVector<mlir::TensorType, 4> refined_subtypes;
        for (auto subtypes : llvm::zip(a_wst_st, b_wst_st))
        {
            mlir::Type refined_st = GetCastCompatibleType(std::get<0>(subtypes), std::get<1>(subtypes),
                                                          /*may_ignore_ref_type_a=*/false);
            if (!refined_st) return nullptr;
            refined_subtypes.push_back(refined_st.cast<mlir::TensorType>());
        }

        return mlir::TF::ResourceType::get(refined_subtypes, a.getContext());
    }

    // For tensor types, check compatibility of both element type and shape.
    mlir::Type refined_element_ty = GetCastCompatibleType(
        a_tt.getElementType(), b_tt.getElementType(), may_ignore_ref_type_a);
    if (!refined_element_ty) return nullptr;

    if (!a_tt.hasRank() && !b_tt.hasRank())
    {
        return mlir::UnrankedTensorType::get(refined_element_ty);
    }
    if (!a_tt.hasRank())
    {
        return mlir::RankedTensorType::get(b_tt.getShape(), refined_element_ty);
    }
    if (!b_tt.hasRank())
    {
        return mlir::RankedTensorType::get(a_tt.getShape(), refined_element_ty);
    }

    llvm::SmallVector<int64_t, 8> refined_shape;
    if (!GetCastCompatibleShape(a_tt.getShape(), b_tt.getShape(), &refined_shape))
        return nullptr;

    return mlir::RankedTensorType::get(refined_shape, refined_element_ty);
}

bool HasCompatibleElementTypes(Type lhs, Type rhs,
                               bool may_ignore_ref_type_lhs)
{
    return GetCastCompatibleType(lhs, rhs, may_ignore_ref_type_lhs) != nullptr;
}

bool AreCastCompatible(TypeRange types)
{
    Type common = types.front();
    for (auto type : types.drop_front())
    {
        Type refined_type = GetCastCompatibleType(common, type, /*may_ignore_ref_type_a=*/false);
        if (!refined_type) return false;
        common = refined_type;
    }
    return true;
}

bool ArraysAreCastCompatible(TypeRange lhs, TypeRange rhs)
{
    if (lhs.size() != rhs.size()) return false;
    for (auto pair : llvm::zip(lhs, rhs))
    {
        auto lhs_i = std::get<0>(pair);
        auto rhs_i = std::get<1>(pair);
        if (!AreCastCompatible({lhs_i, rhs_i})) return false;
    }
    return true;
}

// Assumes a function `GetDefaultTypeOf(ComposedType)` that returns the default
// type for a composed type (such as a ref type or a type with subtypes).
template<typename ComposedType>
Type DropTypeHelper(Type ty)
{
    Type element_ty = getElementTypeOrSelf(ty);
    auto composed_type = element_ty.dyn_cast<ComposedType>();
    if (!composed_type) return ty;

    Type default_ty = GetDefaultTypeOf(composed_type);
    if (auto ranked_ty = ty.dyn_cast<RankedTensorType>())
    {
        return RankedTensorType::get(ranked_ty.getShape(), default_ty);
    }
    else if (ty.dyn_cast<UnrankedTensorType>())
    {
        return UnrankedTensorType::get(default_ty);
    }
    else
    {
        return default_ty;
    }
}

Type DropSubTypes(Type ty)
{
    return DropTypeHelper<TF::TensorFlowTypeWithSubtype>(ty);
}

Type DropRefType(Type ty)
{
    return DropTypeHelper<TF::TensorFlowRefType>(ty);
}

Type DropRefAndSubTypes(Type ty)
{
    return DropRefType(DropSubTypes(ty));
}

} // namespace TF
} // namespace mlir
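To make the shape-refinement behavior above concrete, here is a small usage sketch; the `Demo` function is hypothetical and not part of this commit. Per the doc comment on `GetCastCompatibleType`, tensor<2x?xf32> and tensor<?x4xf32> are cast compatible and refine to tensor<2x4xf32>.

// Minimal sketch of calling GetCastCompatibleType on two partially dynamic
// ranked tensor types and checking the refined result.
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/MLIRContext.h"
#include "tf_types.h"

bool Demo(mlir::MLIRContext* ctx)
{
    auto f32 = mlir::FloatType::getF32(ctx);
    auto a = mlir::RankedTensorType::get({2, mlir::ShapedType::kDynamicSize}, f32);
    auto b = mlir::RankedTensorType::get({mlir::ShapedType::kDynamicSize, 4}, f32);
    // Each dynamic dimension is refined by the other type's static size.
    mlir::Type refined =
        mlir::TF::GetCastCompatibleType(a, b, /*may_ignore_ref_type_a=*/false);
    return refined == mlir::RankedTensorType::get({2, 4}, f32);
}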
77
3rdparty/ncnn/tools/mlir/tf_types.def
vendored
Normal file
@ -0,0 +1,77 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// This file contains descriptions of the various TensorFlow types. This is
// used as a central place for enumerating the different types.

#ifdef HANDLE_TF_TYPE

// class, enumerant, name
HANDLE_TF_TYPE(Qint8, QINT8, "qint8")
HANDLE_TF_TYPE(Qint16, QINT16, "qint16")
HANDLE_TF_TYPE(Qint32, QINT32, "qint32")
HANDLE_TF_TYPE(Quint8, QUINT8, "quint8")
HANDLE_TF_TYPE(Quint16, QUINT16, "quint16")
HANDLE_TF_TYPE(String, STRING, "string")

#ifndef HANDLE_CUSTOM_TF_TYPE
#define HANDLE_CUSTOM_TF_TYPE(class, enumerant, name) \
    HANDLE_TF_TYPE(class, enumerant, name)
#endif
HANDLE_CUSTOM_TF_TYPE(Resource, RESOURCE, "resource")
HANDLE_CUSTOM_TF_TYPE(Variant, VARIANT, "variant")
#undef HANDLE_CUSTOM_TF_TYPE

// All ref types are listed below this line and FloatRef is the first ref type.
// This helps in easily differentiating ref and non-ref types, and converting
// a type to/from ref types.

#ifndef HANDLE_TF_REF_TYPE
#define HANDLE_TF_REF_TYPE(class, enumerant, name) \
    HANDLE_TF_TYPE(class, enumerant, name)
#endif
HANDLE_TF_REF_TYPE(FloatRef, FLOAT_REF, "f32ref")
HANDLE_TF_REF_TYPE(DoubleRef, DOUBLE_REF, "f64ref")
HANDLE_TF_REF_TYPE(Uint8Ref, UINT8_REF, "uint8ref")
HANDLE_TF_REF_TYPE(Int8Ref, INT8_REF, "int8ref")
HANDLE_TF_REF_TYPE(Uint16Ref, UINT16_REF, "uint16ref")
HANDLE_TF_REF_TYPE(Int16Ref, INT16_REF, "int16ref")
HANDLE_TF_REF_TYPE(Uint32Ref, UINT32_REF, "uint32ref")
HANDLE_TF_REF_TYPE(Int32Ref, INT32_REF, "int32ref")
HANDLE_TF_REF_TYPE(Uint64Ref, UINT64_REF, "uint64ref")
HANDLE_TF_REF_TYPE(Int64Ref, INT64_REF, "int64ref")
HANDLE_TF_REF_TYPE(StringRef, STRING_REF, "stringref")
HANDLE_TF_REF_TYPE(BoolRef, BOOL_REF, "boolref")
HANDLE_TF_REF_TYPE(Quint8Ref, QUINT8_REF, "quint8ref")
HANDLE_TF_REF_TYPE(Qint8Ref, QINT8_REF, "qint8ref")
HANDLE_TF_REF_TYPE(Quint16Ref, QUINT16_REF, "quint16ref")
HANDLE_TF_REF_TYPE(Qint16Ref, QINT16_REF, "qint16ref")
HANDLE_TF_REF_TYPE(Qint32Ref, QINT32_REF, "qint32ref")
HANDLE_TF_REF_TYPE(Bfloat16Ref, BFLOAT16_REF, "bfloat16ref")
HANDLE_TF_REF_TYPE(Complex64Ref, COMPLEX64_REF, "complex64ref")
HANDLE_TF_REF_TYPE(Complex128Ref, COMPLEX128_REF, "complex128ref")
HANDLE_TF_REF_TYPE(HalfRef, HALF_REF, "halfref")
HANDLE_TF_REF_TYPE(ResourceRef, RESOURCE_REF, "resourceref")

#ifndef HANDLE_LAST_TF_TYPE
#define HANDLE_LAST_TF_TYPE(class, enumerant, name) \
    HANDLE_TF_REF_TYPE(class, enumerant, name)
#endif
HANDLE_LAST_TF_TYPE(VariantRef, VARIANT_REF, "variantref")
#undef HANDLE_LAST_TF_TYPE

#undef HANDLE_TF_REF_TYPE
#undef HANDLE_TF_TYPE
#endif
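This .def file is consumed via the X-macro pattern: a client defines `HANDLE_TF_TYPE` (and optionally the ref/custom variants) before including it and receives one macro expansion per type, as `tf_types.h` and `tf_types.cc` do above. A minimal illustrative client, not part of this commit, could look like this; the file undefines the macros itself at the end, so no cleanup is needed.

// Minimal sketch: print the wire name of every TensorFlow type listed in
// tf_types.def (ref types fall back to HANDLE_TF_TYPE via the #ifndef chain).
#include <cstdio>

void PrintTensorFlowTypeNames()
{
#define HANDLE_TF_TYPE(tftype, enumerant, name) std::printf("%s\n", name);
// NOLINTNEXTLINE
#include "tf_types.def"
}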
380
3rdparty/ncnn/tools/mlir/tf_types.h
vendored
Normal file
@ -0,0 +1,380 @@
|
||||
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
// This file defines the types used in the standard MLIR TensorFlow dialect.
|
||||
|
||||
#ifndef TENSORFLOW_COMPILER_MLIR_TENSORFLOW_IR_TF_TYPES_H_
|
||||
#define TENSORFLOW_COMPILER_MLIR_TENSORFLOW_IR_TF_TYPES_H_
|
||||
|
||||
#include "mlir/IR/BuiltinTypes.h" // from @llvm-project
|
||||
#include "mlir/IR/Diagnostics.h" // from @llvm-project
|
||||
#include "mlir/IR/Location.h" // from @llvm-project
|
||||
#include "mlir/IR/Operation.h" // from @llvm-project
|
||||
#include "mlir/IR/TypeUtilities.h" // from @llvm-project
|
||||
#include "mlir/IR/Types.h" // from @llvm-project
|
||||
|
||||
namespace mlir {
|
||||
namespace TF {
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Utility iterators
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
// An iterator for the tensor shapes of an op's operands of shaped types.
|
||||
// Returns llvm::None if a operand is unranked; returns ArrayRef<int64_t> as the
|
||||
// shape otherwise.
|
||||
class OperandShapeIterator final
|
||||
: public llvm::mapped_iterator<Operation::operand_iterator,
|
||||
llvm::Optional<ArrayRef<int64_t> > (*)(
|
||||
Value)>
|
||||
{
|
||||
public:
|
||||
using reference = llvm::Optional<ArrayRef<int64_t> >;
|
||||
|
||||
/// Initializes the operand shape iterator to the specified operand iterator.
|
||||
explicit OperandShapeIterator(Operation::operand_iterator it);
|
||||
};
|
||||
|
||||
using OperandShapeRange = iterator_range<OperandShapeIterator>;
|
||||
|
||||
// An iterator for the tensor shapes of an op's results of shaped types.
|
||||
// Returns llvm::None if a result is unranked; returns ArrayRef<int64_t> as the
|
||||
// shape otherwise.
|
||||
class ResultShapeIterator final
|
||||
: public llvm::mapped_iterator<Operation::result_iterator,
|
||||
llvm::Optional<ArrayRef<int64_t> > (*)(
|
||||
Value)>
|
||||
{
|
||||
public:
|
||||
using reference = llvm::Optional<ArrayRef<int64_t> >;
|
||||
|
||||
/// Initializes the result shape iterator to the specified result iterator.
|
||||
explicit ResultShapeIterator(Operation::result_iterator it);
|
||||
};
|
||||
|
||||
using ResultShapeRange = iterator_range<ResultShapeIterator>;
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// TensorFlow types
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
// The base class in the TensorFlow type hierarchy.
|
||||
class TensorFlowType : public Type
|
||||
{
|
||||
public:
|
||||
using Type::Type;
|
||||
|
||||
// Support method to enable LLVM-style type casting.
|
||||
static bool classof(Type type);
|
||||
};
|
||||
|
||||
// Returns true if the specified type is a valid TensorFlow element type.
|
||||
static inline bool IsValidTFElementType(Type type)
|
||||
{
|
||||
return type.isa<ComplexType, FloatType, IntegerType, TensorFlowType>();
|
||||
}
|
||||
|
||||
// Returns true if this is a valid TensorFlow tensor type.
|
||||
static inline bool IsValidTFTensorType(Type type)
|
||||
{
|
||||
// TensorFlow types should be tensors of one of the valid TensorFlow element
|
||||
// types.
|
||||
if (auto tensor_ty = type.dyn_cast<TensorType>())
|
||||
return IsValidTFElementType(tensor_ty.getElementType());
|
||||
return false;
|
||||
}
|
||||
|
||||
namespace detail {
|
||||
// Common implementation of TensorFlow types. The template argument indicates
|
||||
// the concrete derived class per CRTP.
|
||||
template<typename Derived>
|
||||
class TensorFlowTypeImpl
|
||||
: public Type::TypeBase<Derived, TensorFlowType, TypeStorage>
|
||||
{
|
||||
public:
|
||||
using Base = typename Type::TypeBase<Derived, TensorFlowType, TypeStorage>;
|
||||
using TFBase = TensorFlowTypeImpl<Derived>;
|
||||
using Base::Base;
|
||||
};
|
||||
} // namespace detail
|
||||
|
||||
// TensorFlowRefType class supports all the ref types in TensorFlow dialect.
|
||||
class TensorFlowRefType : public TensorFlowType
|
||||
{
|
||||
public:
|
||||
using TensorFlowType::TensorFlowType;
|
||||
|
||||
// Checks if a type is TensorFlow Ref type.
|
||||
static bool classof(Type type);
|
||||
|
||||
// Converts a type to the corresponding TensorFlowRef type.
|
||||
static TensorFlowType get(Type type);
|
||||
static TensorFlowType getChecked(Type type, MLIRContext* context,
|
||||
Location loc)
|
||||
{
|
||||
if (failed(verify(loc, type)))
|
||||
{
|
||||
return TensorFlowRefType();
|
||||
}
|
||||
return get(type);
|
||||
}
|
||||
|
||||
static LogicalResult verify(Location loc, Type type)
|
||||
{
|
||||
// type should be a valid TensorFlow type.
|
||||
if (!IsValidTFTensorType(type))
|
||||
{
|
||||
return emitError(loc) << "invalid TensorFlow type: " << type;
|
||||
}
|
||||
return success();
|
||||
}
|
||||
|
||||
// Converts a TensorFlowRef type to the corresponding TensorFlow or standard
|
||||
// type.
|
||||
Type RemoveRef();
|
||||
};
|
||||
|
||||
// Returns the corresponding TensorFlow or standard type from TensorFlowRef
|
||||
// type.
|
||||
static inline Type GetDefaultTypeOf(TensorFlowRefType type)
|
||||
{
|
||||
return type.RemoveRef();
|
||||
}
|
||||
|
||||
// Returns the element type if `type` is a `ShapedType` and the type itself
|
||||
// otherwise, converting `TensorFlowRef` type to corresponding `TensorFlow` or
|
||||
// standard type if necessary.
|
||||
static inline Type GetElementTypeOrSelfResolveRef(Type type)
|
||||
{
|
||||
Type element_type = mlir::getElementTypeOrSelf(type);
|
||||
if (auto ref_type = element_type.dyn_cast<mlir::TF::TensorFlowRefType>())
|
||||
{
|
||||
element_type = ref_type.RemoveRef();
|
||||
}
|
||||
return element_type;
|
||||
}
|
||||
|
||||
#define HANDLE_TF_TYPE(tftype, enumerant, name) \
|
||||
class tftype##Type : public detail::TensorFlowTypeImpl<tftype##Type> \
|
||||
{ \
|
||||
public: \
|
||||
using TFBase::TFBase; \
|
||||
};
|
||||
|
||||
// Custom TensorFlow types are defined separately.
|
||||
#define HANDLE_CUSTOM_TF_TYPE(tftype, enumerant, name)
|
||||
|
||||
// NOLINTNEXTLINE
|
||||
#include "tf_types.def"
|
||||
|
||||
namespace detail {
|
||||
// Storage type contains inferred subtypes for TypeWithSubtype.
|
||||
class TypeWithSubtypeStorage : public TypeStorage
|
||||
{
|
||||
public:
|
||||
using KeyTy = ArrayRef<TensorType>;
|
||||
|
||||
// NOLINTNEXTLINE
|
||||
static TypeWithSubtypeStorage* construct(TypeStorageAllocator& allocator,
|
||||
const KeyTy& key)
|
||||
{
|
||||
ArrayRef<TensorType> subtypes = allocator.copyInto(key);
|
||||
return new (allocator.allocate<TypeWithSubtypeStorage>())
|
||||
TypeWithSubtypeStorage(subtypes);
|
||||
}
|
||||
|
||||
explicit TypeWithSubtypeStorage(const KeyTy& key)
|
||||
: subtypes_(key)
|
||||
{
|
||||
}
|
||||
|
||||
bool operator==(const KeyTy& key) const
|
||||
{
|
||||
return key == subtypes_;
|
||||
}
|
||||
|
||||
static llvm::hash_code hashKey(const KeyTy& key)
|
||||
{
|
||||
return llvm::hash_combine_range(key.begin(), key.end());
|
||||
}
|
||||
|
||||
KeyTy subtypes_;
|
||||
};

// Common implementation of TensorFlow types with subtypes. These subtypes are
// opaque and their interpretation depends on the actual underlying type.
// The template argument indicates the concrete derived class per CRTP. Concrete
// classes must implement the following:
//   - `static std::string getTypeName()` that returns the name of the type for
//     verification logging.
template<typename Derived>
class TypeWithSubtypeImpl
    : public Type::TypeBase<Derived, TensorFlowType, TypeWithSubtypeStorage>
{
public:
    using Base = Type::TypeBase<Derived, TensorFlowType, TypeWithSubtypeStorage>;
    using TFBase = TypeWithSubtypeImpl<Derived>;
    using Base::Base;

    static Derived get(ArrayRef<TensorType> subtypes, MLIRContext* context)
    {
        return Base::get(context, subtypes);
    }

    static Derived getChecked(ArrayRef<TensorType> subtypes, MLIRContext* context,
                              Location loc)
    {
        return Base::getChecked(loc, subtypes);
    }

    static Derived getChecked(function_ref<InFlightDiagnostic()> emitError,
                              MLIRContext* context,
                              ArrayRef<TensorType> subtypes)
    {
        return Base::getChecked(emitError, context, subtypes);
    }

    static Derived get(MLIRContext* context)
    {
        return get({}, context);
    }

    static LogicalResult verify(function_ref<InFlightDiagnostic()> emitError,
                                ArrayRef<TensorType> subtypes)
    {
        // Each of the subtypes should be a valid TensorFlow type.
        for (TensorType subtype : subtypes)
        {
            if (!IsValidTFTensorType(subtype))
            {
                return emitError() << "invalid " << Derived::getTypeName()
                                   << " subtype: " << subtype;
            }
        }
        return success();
    }

    ArrayRef<TensorType> getSubtypes()
    {
        return Base::getImpl()->subtypes_;
    }
};
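
// Illustrative sketch (not part of the original header): constructing a
// subtyped type through the CRTP base, using the VariantType declared below.
// The context setup (dialect registration) is elided.
//
//   mlir::MLIRContext ctx;
//   auto elem = mlir::FloatType::getF32(&ctx);
//   auto subtype = mlir::RankedTensorType::get({2}, elem); // tensor<2xf32>
//   auto variant = mlir::TF::VariantType::get({subtype}, &ctx);
//   assert(variant.getSubtypes().size() == 1);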
} // namespace detail

// TensorFlowTypeWithSubtype class supports all the types with subtypes in the
// TensorFlow dialect.
class TensorFlowTypeWithSubtype : public TensorFlowType
{
public:
    using TensorFlowType::TensorFlowType;

    // Checks if a type is a TensorFlow type with subtypes.
    static bool classof(Type type);

    // Converts a TypeWithSubtype type to the same type but without its subtypes.
    Type RemoveSubtypes();

    // Returns the subtypes.
    ArrayRef<TensorType> GetSubtypes();
};

// Returns the corresponding TensorFlow type with its subtypes removed.
static inline Type GetDefaultTypeOf(TensorFlowTypeWithSubtype type)
{
    return type.RemoveSubtypes();
}
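
// Illustrative sketch (not part of the original header): dropping subtype
// information from a subtyped value (`resource_with_subtype` is a placeholder).
//
//   // !tf.resource<tensor<f32>>  -->  !tf.resource
//   mlir::Type plain = GetDefaultTypeOf(resource_with_subtype);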

// TensorFlow resource type is used to support TensorFlow resource variables,
// which represent shared, persistent state manipulated by a TensorFlow program.
// ResourceType stores shape and datatype for subtypes unlike most other data
// types that don't have any associated information.
class ResourceType : public detail::TypeWithSubtypeImpl<ResourceType>
{
public:
    using TFBase::TFBase;

    static std::string getTypeName()
    {
        return "ResourceType";
    }
};
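
// Illustrative sketch (not part of the original header): a resource variable
// known to hold a scalar f32 tensor carries that as its subtype.
//
//   auto f32 = mlir::FloatType::getF32(&ctx);
//   auto subtype = mlir::RankedTensorType::get({}, f32); // tensor<f32>
//   auto res = mlir::TF::ResourceType::get({subtype}, &ctx);
//   // prints as !tf.resource<tensor<f32>>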

// TensorFlow variant type is used to support arbitrary custom C++ data types.
// VariantType stores inferred shape and datatype for subtypes unlike most other
// data types that don't have any associated information. For example, a variant
// encoding a TensorList stores the common shape and dtype of the list elements
// as its only subtype.
class VariantType : public detail::TypeWithSubtypeImpl<VariantType>
{
public:
    using TFBase::TFBase;

    static std::string getTypeName()
    {
        return "VariantType";
    }
};
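
// Illustrative sketch (not part of the original header): a variant encoding a
// TensorList of 2-element float vectors.
//
//   auto subtype = mlir::RankedTensorType::get({2}, mlir::FloatType::getF32(&ctx));
//   auto var = mlir::TF::VariantType::get({subtype}, &ctx);
//   // prints as !tf.variant<tensor<2xf32>>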

// Given two types `a` and `b`, returns a refined type which is cast compatible
// with both `a` and `b` and is equal to or more precise than both of them. It
// returns an empty Type if the input types are not cast compatible.
// Provides an option to ignore ref types on `a`. This is useful for TF ops that
// might allow operands to either be the same as the result type or be a ref
// type corresponding to it.
mlir::Type GetCastCompatibleType(mlir::Type a, mlir::Type b,
                                 bool may_ignore_ref_type_a);
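
// Illustrative sketch (not part of the original header): refining two
// compatible tensor types to the more precise of the two.
//
//   auto f32 = mlir::FloatType::getF32(&ctx);
//   auto dyn = mlir::RankedTensorType::get({mlir::ShapedType::kDynamicSize}, f32);
//   auto fix = mlir::RankedTensorType::get({3}, f32);
//   mlir::Type refined = mlir::TF::GetCastCompatibleType(dyn, fix,
//                                                        /*may_ignore_ref_type_a=*/false);
//   // refined == tensor<3xf32>; a null Type would signal incompatibility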

// Returns whether two arrays of Type are broadcast compatible.
bool BroadcastCompatible(TypeRange lhs, TypeRange rhs);

// Returns whether the two elemental types are compatible. Shapes are compatible
// if:
// - the types are statically equal, or
// - they could be dynamically equal, treating dynamic shapes as equal unless
//   contradictory information is known, and
// - the element types are equivalent, modulo subtypes possibly being less exact
//   (e.g., a resource type without a subtype is considered compatible with a
//   resource type with a known subtype).
// Provides an option to ignore ref types on `lhs`.
bool HasCompatibleElementTypes(Type lhs, Type rhs,
                               bool may_ignore_ref_type_lhs = false);
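
// Illustrative sketch (not part of the original header): per the comment above,
// a resource type without a subtype is compatible with one that names its
// subtype (`subtype` as constructed in the earlier sketches).
//
//   auto plain = mlir::TF::ResourceType::get(&ctx);
//   auto typed = mlir::TF::ResourceType::get({subtype}, &ctx);
//   bool ok = mlir::TF::HasCompatibleElementTypes(plain, typed); // true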

// Returns true if all TensorFlow types can be cast to one another. In other
// words, a single run-time value is legal for all of the types. For example,
// tensor<*xf32>, tensor<?xf32> and tensor<3xf32> are cast compatible.
bool AreCastCompatible(TypeRange types);
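
// Illustrative sketch (not part of the original header): building the three
// f32 tensor types named in the comment above and checking them together.
//
//   auto f32 = mlir::FloatType::getF32(&ctx);
//   auto unranked = mlir::UnrankedTensorType::get(f32);  // tensor<*xf32>
//   auto dynamic = mlir::RankedTensorType::get({mlir::ShapedType::kDynamicSize},
//                                              f32);     // tensor<?xf32>
//   auto ranked = mlir::RankedTensorType::get({3}, f32); // tensor<3xf32>
//   mlir::SmallVector<mlir::Type, 3> types = {unranked, dynamic, ranked};
//   bool ok = mlir::TF::AreCastCompatible(types);        // true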

// Returns true if corresponding elements of lhs and rhs AreCastCompatible and
// lhs and rhs are the same length.
bool ArraysAreCastCompatible(TypeRange lhs, TypeRange rhs);

// If `ty` is a tensor type and its element type has subtypes, then returns a
// new type of the same shape but with dropped subtypes for the element type.
// Otherwise, if `ty` has subtypes, then returns the corresponding type with
// dropped subtypes.
// Otherwise, returns the original type `ty`.
Type DropSubTypes(Type ty);

// If `ty` is a tensor type and has elements of a ref type, then returns a new
// type of the same shape but with the corresponding non-ref type as element
// type.
// Otherwise, if `ty` is a ref type, then returns the corresponding non-ref type.
// Otherwise, returns the original type `ty`.
Type DropRefType(Type ty);

// Convenience call for executing both `DropRefType` and `DropSubTypes`.
Type DropRefAndSubTypes(Type ty);
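
// Illustrative sketch (not part of the original header): normalizing a type by
// stripping both ref-ness and subtype information (`ty` is a placeholder).
//
//   // tensor<4x!tf.resource<tensor<f32>>>  --DropSubTypes-->  tensor<4x!tf.resource>
//   // !tf.f32ref                           --DropRefType--->  f32
//   mlir::Type clean = mlir::TF::DropRefAndSubTypes(ty);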

} // end namespace TF
} // end namespace mlir

#endif // TENSORFLOW_COMPILER_MLIR_TENSORFLOW_IR_TF_TYPES_H_