feat: 切换后端至PaddleOCR-NCNN,切换工程为CMake
1.项目后端整体迁移至PaddleOCR-NCNN算法,已通过基本的兼容性测试 2.工程改为使用CMake组织,后续为了更好地兼容第三方库,不再提供QMake工程 3.重整权利声明文件,重整代码工程,确保最小化侵权风险 Log: 切换后端至PaddleOCR-NCNN,切换工程为CMake Change-Id: I4d5d2c5d37505a4a24b389b1a4c5d12f17bfa38c
This commit is contained in:
		
							
								
								
									
										61
									
								
								3rdparty/ncnn/tools/mlir/CMakeLists.txt
									
									
									
									
										vendored
									
									
										Normal file
									
								
							
							
						
						
									
										61
									
								
								3rdparty/ncnn/tools/mlir/CMakeLists.txt
									
									
									
									
										vendored
									
									
										Normal file
									
								
							@ -0,0 +1,61 @@
 | 
			
		||||
project(mlir2ncnn)
 | 
			
		||||
 | 
			
		||||
cmake_minimum_required(VERSION 3.10)
 | 
			
		||||
set(CMAKE_CXX_STANDARD 14)
 | 
			
		||||
 | 
			
		||||
set(LLVM_PROJECT_INSTALL_DIR "/home/nihui/osd/llvm-project/build/install" CACHE STRING "")
 | 
			
		||||
 | 
			
		||||
set(LLVM_DIR "${LLVM_PROJECT_INSTALL_DIR}/lib/cmake/llvm")
 | 
			
		||||
find_package(LLVM REQUIRED)
 | 
			
		||||
 | 
			
		||||
set(MLIR_DIR "${LLVM_PROJECT_INSTALL_DIR}/lib/cmake/mlir")
 | 
			
		||||
find_package(MLIR REQUIRED)
 | 
			
		||||
 | 
			
		||||
add_definitions(-fno-rtti -fno-exceptions)
 | 
			
		||||
 | 
			
		||||
include_directories("${LLVM_PROJECT_INSTALL_DIR}/include")
 | 
			
		||||
 | 
			
		||||
include_directories(${CMAKE_CURRENT_BINARY_DIR})
 | 
			
		||||
 | 
			
		||||
include(${LLVM_DIR}/TableGen.cmake)
 | 
			
		||||
include(${MLIR_DIR}/AddMLIR.cmake)
 | 
			
		||||
 | 
			
		||||
set(LLVM_TARGET_DEFINITIONS tf_ops.td)
 | 
			
		||||
mlir_tablegen(tf_all_ops.h.inc -gen-op-decls)
 | 
			
		||||
mlir_tablegen(tf_all_ops.cc.inc -gen-op-defs)
 | 
			
		||||
add_public_tablegen_target(tf_opsIncGen)
 | 
			
		||||
 | 
			
		||||
set(LLVM_TARGET_DEFINITIONS ncnn_ops.td)
 | 
			
		||||
mlir_tablegen(ncnn_ops.h.inc -gen-op-decls)
 | 
			
		||||
mlir_tablegen(ncnn_ops.cc.inc -gen-op-defs)
 | 
			
		||||
add_public_tablegen_target(ncnn_opsIncGen)
 | 
			
		||||
 | 
			
		||||
set(LLVM_TARGET_DEFINITIONS ncnn_rewriter.td)
 | 
			
		||||
mlir_tablegen(ncnn_rewriter.inc -gen-rewriters)
 | 
			
		||||
add_public_tablegen_target(ncnn_rewriterIncGen)
 | 
			
		||||
 | 
			
		||||
add_executable(mlir2ncnn
 | 
			
		||||
    mlir2ncnn.cpp
 | 
			
		||||
    ncnn_dialect.cpp
 | 
			
		||||
    ncnn_rewriter.cpp
 | 
			
		||||
    tf_dialect.cpp
 | 
			
		||||
    tf_attributes.cc
 | 
			
		||||
    tf_types.cc
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
add_dependencies(mlir2ncnn
 | 
			
		||||
    tf_opsIncGen
 | 
			
		||||
    ncnn_opsIncGen
 | 
			
		||||
    ncnn_rewriterIncGen
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
target_link_libraries(mlir2ncnn
 | 
			
		||||
    MLIRIR
 | 
			
		||||
    MLIRDialect
 | 
			
		||||
    MLIRInferTypeOpInterface
 | 
			
		||||
    MLIRParser
 | 
			
		||||
    MLIRPass
 | 
			
		||||
    MLIRStandard
 | 
			
		||||
    MLIRTransforms
 | 
			
		||||
)
 | 
			
		||||
ncnn_install_tool(mlir2ncnn)
 | 
			
		||||
							
								
								
									
										13
									
								
								3rdparty/ncnn/tools/mlir/fix_td.sh
									
									
									
									
										vendored
									
									
										Normal file
									
								
							
							
						
						
									
										13
									
								
								3rdparty/ncnn/tools/mlir/fix_td.sh
									
									
									
									
										vendored
									
									
										Normal file
									
								
							@ -0,0 +1,13 @@
 | 
			
		||||
#!/bin/sh
 | 
			
		||||
 | 
			
		||||
# This dirty script eat td files :P
 | 
			
		||||
# https://github.com/tensorflow/tensorflow/tree/master/tensorflow/compiler/mlir/tensorflow/ir
 | 
			
		||||
 | 
			
		||||
sed -i 's!tensorflow/compiler/mlir/tensorflow/ir/!!g' *.td tf_traits.h tf_types.h tf_types.cc tf_attributes.cc
 | 
			
		||||
 | 
			
		||||
sed -i '/let hasCanonicalizer = 1;/d' *.td
 | 
			
		||||
sed -i '/let hasFolder = 1;/d' *.td
 | 
			
		||||
sed -i '/StringRef GetOptimalLayout(const RuntimeDevices& devices);/d' *.td
 | 
			
		||||
sed -i '/LogicalResult UpdateDataFormat(StringRef data_format);/d' *.td
 | 
			
		||||
sed -i '/LogicalResult FoldOperandsPermutation(ArrayRef<int64_t> permutation);/d' *.td
 | 
			
		||||
sed -i '/Optional<ContractionFusion> GetContractionFusion();/d' *.td
 | 
			
		||||
							
								
								
									
										1819
									
								
								3rdparty/ncnn/tools/mlir/mlir2ncnn.cpp
									
									
									
									
										vendored
									
									
										Normal file
									
								
							
							
						
						
									
										1819
									
								
								3rdparty/ncnn/tools/mlir/mlir2ncnn.cpp
									
									
									
									
										vendored
									
									
										Normal file
									
								
							
										
											
												File diff suppressed because it is too large
												Load Diff
											
										
									
								
							
							
								
								
									
										41
									
								
								3rdparty/ncnn/tools/mlir/ncnn_dialect.cpp
									
									
									
									
										vendored
									
									
										Normal file
									
								
							
							
						
						
									
										41
									
								
								3rdparty/ncnn/tools/mlir/ncnn_dialect.cpp
									
									
									
									
										vendored
									
									
										Normal file
									
								
							@ -0,0 +1,41 @@
 | 
			
		||||
// Tencent is pleased to support the open source community by making ncnn available.
 | 
			
		||||
//
 | 
			
		||||
// Copyright (C) 2020 THL A29 Limited, a Tencent company. All rights reserved.
 | 
			
		||||
//
 | 
			
		||||
// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
 | 
			
		||||
// in compliance with the License. You may obtain a copy of the License at
 | 
			
		||||
//
 | 
			
		||||
// https://opensource.org/licenses/BSD-3-Clause
 | 
			
		||||
//
 | 
			
		||||
// Unless required by applicable law or agreed to in writing, software distributed
 | 
			
		||||
// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
 | 
			
		||||
// CONDITIONS OF ANY KIND, either express or implied. See the License for the
 | 
			
		||||
// specific language governing permissions and limitations under the License.
 | 
			
		||||
 | 
			
		||||
#include "ncnn_dialect.h"
 | 
			
		||||
 | 
			
		||||
#include <mlir/IR/Builders.h>
 | 
			
		||||
 | 
			
		||||
namespace mlir {
 | 
			
		||||
 | 
			
		||||
namespace ncnn {
 | 
			
		||||
 | 
			
		||||
NCNNDialect::NCNNDialect(mlir::MLIRContext* context)
 | 
			
		||||
    : mlir::Dialect("ncnn", context, TypeID::get<NCNNDialect>())
 | 
			
		||||
{
 | 
			
		||||
    addOperations<
 | 
			
		||||
#define GET_OP_LIST
 | 
			
		||||
#include "ncnn_ops.cc.inc"
 | 
			
		||||
    >();
 | 
			
		||||
 | 
			
		||||
    // Support unknown operations because not all NCNN operations are
 | 
			
		||||
    // registered.
 | 
			
		||||
    allowUnknownOperations();
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
} // namespace ncnn
 | 
			
		||||
 | 
			
		||||
#define GET_OP_CLASSES
 | 
			
		||||
#include "ncnn_ops.cc.inc"
 | 
			
		||||
 | 
			
		||||
} // namespace mlir
 | 
			
		||||
							
								
								
									
										47
									
								
								3rdparty/ncnn/tools/mlir/ncnn_dialect.h
									
									
									
									
										vendored
									
									
										Normal file
									
								
							
							
						
						
									
										47
									
								
								3rdparty/ncnn/tools/mlir/ncnn_dialect.h
									
									
									
									
										vendored
									
									
										Normal file
									
								
							@ -0,0 +1,47 @@
 | 
			
		||||
// Tencent is pleased to support the open source community by making ncnn available.
 | 
			
		||||
//
 | 
			
		||||
// Copyright (C) 2020 THL A29 Limited, a Tencent company. All rights reserved.
 | 
			
		||||
//
 | 
			
		||||
// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
 | 
			
		||||
// in compliance with the License. You may obtain a copy of the License at
 | 
			
		||||
//
 | 
			
		||||
// https://opensource.org/licenses/BSD-3-Clause
 | 
			
		||||
//
 | 
			
		||||
// Unless required by applicable law or agreed to in writing, software distributed
 | 
			
		||||
// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
 | 
			
		||||
// CONDITIONS OF ANY KIND, either express or implied. See the License for the
 | 
			
		||||
// specific language governing permissions and limitations under the License.
 | 
			
		||||
 | 
			
		||||
#ifndef NCNN_DIALECT_H
 | 
			
		||||
#define NCNN_DIALECT_H
 | 
			
		||||
 | 
			
		||||
#include <mlir/IR/BuiltinTypes.h>
 | 
			
		||||
#include <mlir/IR/Dialect.h>
 | 
			
		||||
#include <mlir/Interfaces/SideEffectInterfaces.h>
 | 
			
		||||
#include <mlir/Pass/Pass.h>
 | 
			
		||||
 | 
			
		||||
namespace mlir {
 | 
			
		||||
 | 
			
		||||
namespace ncnn {
 | 
			
		||||
 | 
			
		||||
class NCNNDialect : public mlir::Dialect
 | 
			
		||||
{
 | 
			
		||||
public:
 | 
			
		||||
    NCNNDialect(mlir::MLIRContext* context);
 | 
			
		||||
 | 
			
		||||
    static StringRef getDialectNamespace()
 | 
			
		||||
    {
 | 
			
		||||
        return "ncnn";
 | 
			
		||||
    }
 | 
			
		||||
};
 | 
			
		||||
 | 
			
		||||
std::unique_ptr<OperationPass<FuncOp> > createNCNNOptimizePass();
 | 
			
		||||
 | 
			
		||||
} // namespace ncnn
 | 
			
		||||
 | 
			
		||||
#define GET_OP_CLASSES
 | 
			
		||||
#include "ncnn_ops.h.inc"
 | 
			
		||||
 | 
			
		||||
} // namespace mlir
 | 
			
		||||
 | 
			
		||||
#endif // NCNN_DIALECT_H
 | 
			
		||||
							
								
								
									
										133
									
								
								3rdparty/ncnn/tools/mlir/ncnn_ops.td
									
									
									
									
										vendored
									
									
										Normal file
									
								
							
							
						
						
									
										133
									
								
								3rdparty/ncnn/tools/mlir/ncnn_ops.td
									
									
									
									
										vendored
									
									
										Normal file
									
								
							@ -0,0 +1,133 @@
 | 
			
		||||
// Tencent is pleased to support the open source community by making ncnn available.
 | 
			
		||||
//
 | 
			
		||||
// Copyright (C) 2020 THL A29 Limited, a Tencent company. All rights reserved.
 | 
			
		||||
//
 | 
			
		||||
// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
 | 
			
		||||
// in compliance with the License. You may obtain a copy of the License at
 | 
			
		||||
//
 | 
			
		||||
// https://opensource.org/licenses/BSD-3-Clause
 | 
			
		||||
//
 | 
			
		||||
// Unless required by applicable law or agreed to in writing, software distributed
 | 
			
		||||
// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
 | 
			
		||||
// CONDITIONS OF ANY KIND, either express or implied. See the License for the
 | 
			
		||||
// specific language governing permissions and limitations under the License.
 | 
			
		||||
 | 
			
		||||
#ifndef NCNN_OPS_TD
 | 
			
		||||
#define NCNN_OPS_TD
 | 
			
		||||
 | 
			
		||||
include "mlir/IR/OpBase.td"
 | 
			
		||||
include "mlir/Interfaces/SideEffectInterfaces.td"
 | 
			
		||||
include "tf_op_base.td"
 | 
			
		||||
 | 
			
		||||
def NCNN_Dialect : Dialect {
 | 
			
		||||
  let name = "ncnn";
 | 
			
		||||
  let cppNamespace = "ncnn";
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
//===----------------------------------------------------------------------===//
 | 
			
		||||
// NCNN op definitions
 | 
			
		||||
//===----------------------------------------------------------------------===//
 | 
			
		||||
 | 
			
		||||
class NCNN_Op<string mnemonic, list<OpTrait> traits = []> :
 | 
			
		||||
    Op<NCNN_Dialect, mnemonic, traits>;
 | 
			
		||||
 | 
			
		||||
//===----------------------------------------------------------------------===//
 | 
			
		||||
// NCNN operations
 | 
			
		||||
//===----------------------------------------------------------------------===//
 | 
			
		||||
 | 
			
		||||
def NCNN_KerasConv2DOp : NCNN_Op<"KerasConv2D", [NoSideEffect]> {
 | 
			
		||||
 | 
			
		||||
  let arguments = (ins
 | 
			
		||||
    F32Tensor:$x,
 | 
			
		||||
    F32Tensor:$weight,
 | 
			
		||||
    F32Tensor:$bias,
 | 
			
		||||
 | 
			
		||||
    I64ArrayAttr:$strides,
 | 
			
		||||
    TF_AnyStrAttrOf<["SAME", "VALID", "EXPLICIT"]>:$padding,
 | 
			
		||||
    DefaultValuedAttr<I64ArrayAttr, "{}">:$explicit_paddings,
 | 
			
		||||
    DefaultValuedAttr<I64ArrayAttr, "{1, 1, 1, 1}">:$dilations
 | 
			
		||||
  );
 | 
			
		||||
 | 
			
		||||
  let results = (outs
 | 
			
		||||
    F32Tensor:$y
 | 
			
		||||
  );
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
def NCNN_KerasDenseOp : NCNN_Op<"KerasDense", [NoSideEffect]> {
 | 
			
		||||
 | 
			
		||||
  let arguments = (ins
 | 
			
		||||
    F32Tensor:$x,
 | 
			
		||||
    F32Tensor:$weight,
 | 
			
		||||
    F32Tensor:$bias
 | 
			
		||||
  );
 | 
			
		||||
 | 
			
		||||
  let results = (outs
 | 
			
		||||
    F32Tensor:$y
 | 
			
		||||
  );
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
def NCNN_KerasBatchNormOp : NCNN_Op<"KerasBatchNorm", [NoSideEffect]> {
 | 
			
		||||
 | 
			
		||||
  let arguments = (ins
 | 
			
		||||
    F32Tensor:$x,
 | 
			
		||||
    F32Tensor:$gamma,
 | 
			
		||||
    F32Tensor:$bias
 | 
			
		||||
  );
 | 
			
		||||
 | 
			
		||||
  let results = (outs
 | 
			
		||||
    F32Tensor:$y
 | 
			
		||||
  );
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
def NCNN_BinaryOpOp : NCNN_Op<"BinaryOp", [NoSideEffect]> {
 | 
			
		||||
 | 
			
		||||
  let arguments = (ins
 | 
			
		||||
    F32Tensor:$x,
 | 
			
		||||
    I32Attr:$op_type,
 | 
			
		||||
    I32Attr:$with_scalar,
 | 
			
		||||
    F32Attr:$b
 | 
			
		||||
  );
 | 
			
		||||
 | 
			
		||||
  let results = (outs
 | 
			
		||||
    F32Tensor:$y
 | 
			
		||||
  );
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
def NCNN_InstanceNormOp : NCNN_Op<"InstanceNorm", [NoSideEffect]> {
 | 
			
		||||
 | 
			
		||||
  let arguments = (ins
 | 
			
		||||
    F32Tensor:$x,
 | 
			
		||||
    F32Attr:$epsilon
 | 
			
		||||
  );
 | 
			
		||||
 | 
			
		||||
  let results = (outs
 | 
			
		||||
    F32Tensor:$y
 | 
			
		||||
  );
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
def NCNN_InstanceNormAffineOp : NCNN_Op<"InstanceNormAffine", [NoSideEffect]> {
 | 
			
		||||
 | 
			
		||||
  let arguments = (ins
 | 
			
		||||
    F32Tensor:$x,
 | 
			
		||||
    F32Tensor:$gamma,
 | 
			
		||||
    F32Tensor:$beta,
 | 
			
		||||
    F32Attr:$epsilon
 | 
			
		||||
  );
 | 
			
		||||
 | 
			
		||||
  let results = (outs
 | 
			
		||||
    F32Tensor:$y
 | 
			
		||||
  );
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
def NCNN_SwishOp : NCNN_Op<"Swish", [NoSideEffect]> {
 | 
			
		||||
 | 
			
		||||
  let arguments = (ins
 | 
			
		||||
    F32Tensor:$x
 | 
			
		||||
  );
 | 
			
		||||
 | 
			
		||||
  let results = (outs
 | 
			
		||||
    F32Tensor:$y
 | 
			
		||||
  );
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
#endif // NCNN_OPS_TD
 | 
			
		||||
							
								
								
									
										54
									
								
								3rdparty/ncnn/tools/mlir/ncnn_rewriter.cpp
									
									
									
									
										vendored
									
									
										Normal file
									
								
							
							
						
						
									
										54
									
								
								3rdparty/ncnn/tools/mlir/ncnn_rewriter.cpp
									
									
									
									
										vendored
									
									
										Normal file
									
								
							@ -0,0 +1,54 @@
 | 
			
		||||
// Tencent is pleased to support the open source community by making ncnn available.
 | 
			
		||||
//
 | 
			
		||||
// Copyright (C) 2020 THL A29 Limited, a Tencent company. All rights reserved.
 | 
			
		||||
//
 | 
			
		||||
// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
 | 
			
		||||
// in compliance with the License. You may obtain a copy of the License at
 | 
			
		||||
//
 | 
			
		||||
// https://opensource.org/licenses/BSD-3-Clause
 | 
			
		||||
//
 | 
			
		||||
// Unless required by applicable law or agreed to in writing, software distributed
 | 
			
		||||
// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
 | 
			
		||||
// CONDITIONS OF ANY KIND, either express or implied. See the License for the
 | 
			
		||||
// specific language governing permissions and limitations under the License.
 | 
			
		||||
 | 
			
		||||
#include <mlir/IR/MLIRContext.h>
 | 
			
		||||
#include <mlir/IR/PatternMatch.h>
 | 
			
		||||
#include <mlir/Pass/Pass.h>
 | 
			
		||||
#include <mlir/Transforms/GreedyPatternRewriteDriver.h>
 | 
			
		||||
 | 
			
		||||
#include "tf_dialect.h"
 | 
			
		||||
#include "ncnn_dialect.h"
 | 
			
		||||
 | 
			
		||||
using namespace mlir;
 | 
			
		||||
 | 
			
		||||
namespace mlir {
 | 
			
		||||
 | 
			
		||||
namespace ncnn {
 | 
			
		||||
 | 
			
		||||
#include "ncnn_rewriter.inc"
 | 
			
		||||
 | 
			
		||||
class NCNNOptimizePass : public PassWrapper<NCNNOptimizePass, FunctionPass>
 | 
			
		||||
{
 | 
			
		||||
public:
 | 
			
		||||
    void runOnFunction();
 | 
			
		||||
};
 | 
			
		||||
 | 
			
		||||
void NCNNOptimizePass::runOnFunction()
 | 
			
		||||
{
 | 
			
		||||
    mlir::OwningRewritePatternList patterns;
 | 
			
		||||
    mlir::ncnn::populateWithGenerated(&getContext(), patterns);
 | 
			
		||||
 | 
			
		||||
    (void)mlir::applyPatternsAndFoldGreedily(getFunction(), std::move(patterns));
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
std::unique_ptr<OperationPass<FuncOp> > createNCNNOptimizePass()
 | 
			
		||||
{
 | 
			
		||||
    return std::make_unique<NCNNOptimizePass>();
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
static PassRegistration<NCNNOptimizePass> pass("ncnn-optimize", "ncnn optimization");
 | 
			
		||||
 | 
			
		||||
} // namespace ncnn
 | 
			
		||||
 | 
			
		||||
} // namespace mlir
 | 
			
		||||
							
								
								
									
										210
									
								
								3rdparty/ncnn/tools/mlir/ncnn_rewriter.td
									
									
									
									
										vendored
									
									
										Normal file
									
								
							
							
						
						
									
										210
									
								
								3rdparty/ncnn/tools/mlir/ncnn_rewriter.td
									
									
									
									
										vendored
									
									
										Normal file
									
								
							@ -0,0 +1,210 @@
 | 
			
		||||
// Tencent is pleased to support the open source community by making ncnn available.
 | 
			
		||||
//
 | 
			
		||||
// Copyright (C) 2020 THL A29 Limited, a Tencent company. All rights reserved.
 | 
			
		||||
//
 | 
			
		||||
// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
 | 
			
		||||
// in compliance with the License. You may obtain a copy of the License at
 | 
			
		||||
//
 | 
			
		||||
// https://opensource.org/licenses/BSD-3-Clause
 | 
			
		||||
//
 | 
			
		||||
// Unless required by applicable law or agreed to in writing, software distributed
 | 
			
		||||
// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
 | 
			
		||||
// CONDITIONS OF ANY KIND, either express or implied. See the License for the
 | 
			
		||||
// specific language governing permissions and limitations under the License.
 | 
			
		||||
 | 
			
		||||
#ifndef NCNN_REWRITER_TD
 | 
			
		||||
#define NCNN_REWRITER_TD
 | 
			
		||||
 | 
			
		||||
include "tf_ops.td"
 | 
			
		||||
include "ncnn_ops.td"
 | 
			
		||||
 | 
			
		||||
def get_attr_f : NativeCodeCall<"$0.getValue<FloatAttr>(0)">;
 | 
			
		||||
 | 
			
		||||
def OneElementAttrPred : CPred<"$_self.cast<ElementsAttr>().getType().getNumElements() == 1">;
 | 
			
		||||
 | 
			
		||||
def OneElementAttr : ElementsAttrBase<And<[ElementsAttr.predicate, OneElementAttrPred]>, "Scalar ElementsAttr">;
 | 
			
		||||
 | 
			
		||||
def EqualOperands : Constraint<CPred<"$0 == $1">>;
 | 
			
		||||
 | 
			
		||||
def FuseBinaryOpPattern0 : Pat<
 | 
			
		||||
    (TF_MulOp
 | 
			
		||||
        $x,
 | 
			
		||||
        (TF_ConstOp OneElementAttr:$b)
 | 
			
		||||
    ),
 | 
			
		||||
 | 
			
		||||
    (NCNN_BinaryOpOp $x, ConstantAttr<I32Attr, "2">, ConstantAttr<I32Attr, "1">, (get_attr_f $b))
 | 
			
		||||
>;
 | 
			
		||||
 | 
			
		||||
def FuseBinaryOpPattern1 : Pat<
 | 
			
		||||
    (TF_AddV2Op
 | 
			
		||||
        $x,
 | 
			
		||||
        (TF_ConstOp OneElementAttr:$b)
 | 
			
		||||
    ),
 | 
			
		||||
 | 
			
		||||
    (NCNN_BinaryOpOp $x, ConstantAttr<I32Attr, "0">, ConstantAttr<I32Attr, "1">, (get_attr_f $b))
 | 
			
		||||
>;
 | 
			
		||||
 | 
			
		||||
def FuseKerasConv2DOpPattern : Pat<
 | 
			
		||||
    (TF_BiasAddOp
 | 
			
		||||
        (TF_Conv2DOp $x, $weight, $strides, $use_cudnn_on_gpu, $padding, $explicit_paddings, $data_format, $dilations),
 | 
			
		||||
        $bias,
 | 
			
		||||
        $data_format_
 | 
			
		||||
    ),
 | 
			
		||||
 | 
			
		||||
    (NCNN_KerasConv2DOp $x, $weight, $bias, $strides, $padding, $explicit_paddings, $dilations)
 | 
			
		||||
>;
 | 
			
		||||
 | 
			
		||||
def FuseKerasConv2DOpPattern1 : Pat<
 | 
			
		||||
    (TF_AddV2Op
 | 
			
		||||
        (TF_Conv2DOp $x, $weight, $strides, $use_cudnn_on_gpu, $padding, $explicit_paddings, $data_format, $dilations),
 | 
			
		||||
        $bias
 | 
			
		||||
    ),
 | 
			
		||||
 | 
			
		||||
    (NCNN_KerasConv2DOp $x, $weight, $bias, $strides, $padding, $explicit_paddings, $dilations)
 | 
			
		||||
>;
 | 
			
		||||
 | 
			
		||||
def FuseKerasDenseOpPattern : Pat<
 | 
			
		||||
    (TF_BiasAddOp
 | 
			
		||||
        (TF_MatMulOp $x, $weight, $transpose_a, $transpose_b),
 | 
			
		||||
        $bias,
 | 
			
		||||
        $data_format_
 | 
			
		||||
    ),
 | 
			
		||||
 | 
			
		||||
    (NCNN_KerasDenseOp $x, $weight, $bias)
 | 
			
		||||
>;
 | 
			
		||||
 | 
			
		||||
def NonOneElementAttrPred : CPred<"$_self.cast<ElementsAttr>().getType().getNumElements() != 1">;
 | 
			
		||||
 | 
			
		||||
def NonOneElementAttr : ElementsAttrBase<And<[ElementsAttr.predicate, NonOneElementAttrPred]>, "Non Scalar ElementsAttr">;
 | 
			
		||||
 | 
			
		||||
def FuseKerasBatchNormOpPattern : Pat<
 | 
			
		||||
    (TF_AddV2Op
 | 
			
		||||
        (TF_MulOp
 | 
			
		||||
            $x,
 | 
			
		||||
            (TF_ConstOp:$gamma NonOneElementAttr)
 | 
			
		||||
        ),
 | 
			
		||||
        (TF_ConstOp:$bias NonOneElementAttr)
 | 
			
		||||
    ),
 | 
			
		||||
 | 
			
		||||
    (NCNN_KerasBatchNormOp $x, $gamma, $bias)
 | 
			
		||||
>;
 | 
			
		||||
 | 
			
		||||
def FuseInstanceNormPattern0 : Pat<
 | 
			
		||||
    (TF_MulOp
 | 
			
		||||
        (TF_RsqrtOp
 | 
			
		||||
            (TF_AddV2Op
 | 
			
		||||
                (TF_MeanOp
 | 
			
		||||
                    (TF_SquaredDifferenceOp
 | 
			
		||||
                        (TF_MeanOp:$mean
 | 
			
		||||
                            $x,
 | 
			
		||||
                            (TF_ConstOp:$reduce_axis ElementsAttr),
 | 
			
		||||
                            ConstBoolAttrTrue // keep_dims
 | 
			
		||||
                        ),
 | 
			
		||||
                        $x_
 | 
			
		||||
                    ),
 | 
			
		||||
                    $reduce_axis_,
 | 
			
		||||
                    ConstBoolAttrTrue // keep_dims
 | 
			
		||||
                ),
 | 
			
		||||
                (TF_ConstOp ElementsAttr:$epsilon)
 | 
			
		||||
            )
 | 
			
		||||
        ),
 | 
			
		||||
        (TF_SubOp $x__, $mean_)
 | 
			
		||||
    ),
 | 
			
		||||
 | 
			
		||||
    (NCNN_InstanceNormOp $x, (get_attr_f $epsilon)),
 | 
			
		||||
 | 
			
		||||
    [
 | 
			
		||||
        (EqualOperands $x, $x_),
 | 
			
		||||
        (EqualOperands $x, $x__),
 | 
			
		||||
        (EqualOperands $reduce_axis, $reduce_axis_),
 | 
			
		||||
        (EqualOperands $mean, $mean_)
 | 
			
		||||
    ]
 | 
			
		||||
>;
 | 
			
		||||
 | 
			
		||||
def FuseInstanceNormPattern1 : Pat<
 | 
			
		||||
    (TF_MulOp
 | 
			
		||||
        (TF_RsqrtOp
 | 
			
		||||
            (TF_AddV2Op
 | 
			
		||||
                (TF_MeanOp
 | 
			
		||||
                    (TF_SquaredDifferenceOp
 | 
			
		||||
                        $x_,
 | 
			
		||||
                        (TF_MeanOp:$mean
 | 
			
		||||
                            $x,
 | 
			
		||||
                            (TF_ConstOp:$reduce_axis ElementsAttr),
 | 
			
		||||
                            ConstBoolAttrTrue // keep_dims
 | 
			
		||||
                        )
 | 
			
		||||
                    ),
 | 
			
		||||
                    $reduce_axis_,
 | 
			
		||||
                    ConstBoolAttrTrue // keep_dims
 | 
			
		||||
                ),
 | 
			
		||||
                (TF_ConstOp ElementsAttr:$epsilon)
 | 
			
		||||
            )
 | 
			
		||||
        ),
 | 
			
		||||
        (TF_SubOp $x__, $mean_)
 | 
			
		||||
    ),
 | 
			
		||||
 | 
			
		||||
    (NCNN_InstanceNormOp $x, (get_attr_f $epsilon)),
 | 
			
		||||
 | 
			
		||||
    [
 | 
			
		||||
        (EqualOperands $x, $x_),
 | 
			
		||||
        (EqualOperands $x, $x__),
 | 
			
		||||
        (EqualOperands $reduce_axis, $reduce_axis_),
 | 
			
		||||
        (EqualOperands $mean, $mean_)
 | 
			
		||||
    ]
 | 
			
		||||
>;
 | 
			
		||||
 | 
			
		||||
def FuseInstanceNormAffinePattern : Pat<
 | 
			
		||||
    (TF_ReshapeOp
 | 
			
		||||
        (TF_AddV2Op
 | 
			
		||||
            (TF_MulOp
 | 
			
		||||
                $reshaped__,
 | 
			
		||||
                (TF_MulOp:$rsqrt_var_eps_gamma
 | 
			
		||||
                    (TF_RsqrtOp
 | 
			
		||||
                        (TF_AddV2Op
 | 
			
		||||
                            (TF_MeanOp
 | 
			
		||||
                                (TF_SquaredDifferenceOp
 | 
			
		||||
                                    $reshaped_,
 | 
			
		||||
                                    (TF_MeanOp:$mean
 | 
			
		||||
                                        (TF_ReshapeOp:$reshaped $x, (TF_ConstOp ElementsAttr)),
 | 
			
		||||
                                        (TF_ConstOp:$reduce_axis ElementsAttr),
 | 
			
		||||
                                        ConstBoolAttrTrue // keep_dims
 | 
			
		||||
                                    )
 | 
			
		||||
                                ),
 | 
			
		||||
                                $reduce_axis_,
 | 
			
		||||
                                ConstBoolAttrTrue // keep_dims
 | 
			
		||||
                            ),
 | 
			
		||||
                            (TF_ConstOp ElementsAttr:$epsilon)
 | 
			
		||||
                        )
 | 
			
		||||
                    ),
 | 
			
		||||
                    $gamma
 | 
			
		||||
                )
 | 
			
		||||
            ),
 | 
			
		||||
            (TF_SubOp
 | 
			
		||||
                $beta,
 | 
			
		||||
                (TF_MulOp $rsqrt_var_eps_gamma_, $mean_)
 | 
			
		||||
            )
 | 
			
		||||
        ),
 | 
			
		||||
        (TF_ConstOp ElementsAttr)
 | 
			
		||||
    ),
 | 
			
		||||
 | 
			
		||||
    (NCNN_InstanceNormAffineOp $x, $gamma, $beta, (get_attr_f $epsilon)),
 | 
			
		||||
 | 
			
		||||
    [
 | 
			
		||||
        (EqualOperands $reshaped, $reshaped_),
 | 
			
		||||
        (EqualOperands $reshaped, $reshaped__),
 | 
			
		||||
        (EqualOperands $reduce_axis, $reduce_axis_),
 | 
			
		||||
        (EqualOperands $rsqrt_var_eps_gamma, $rsqrt_var_eps_gamma_),
 | 
			
		||||
        (EqualOperands $mean, $mean_)
 | 
			
		||||
    ]
 | 
			
		||||
>;
 | 
			
		||||
 | 
			
		||||
def FuseSwishPattern : Pat<
 | 
			
		||||
    (TF_MulOp
 | 
			
		||||
        (TF_SigmoidOp $x),
 | 
			
		||||
        $x
 | 
			
		||||
    ),
 | 
			
		||||
 | 
			
		||||
    (NCNN_SwishOp $x)
 | 
			
		||||
>;
 | 
			
		||||
 | 
			
		||||
#endif // NCNN_REWRITER_TD
 | 
			
		||||
							
								
								
									
										163
									
								
								3rdparty/ncnn/tools/mlir/tf_attributes.cc
									
									
									
									
										vendored
									
									
										Normal file
									
								
							
							
						
						
									
										163
									
								
								3rdparty/ncnn/tools/mlir/tf_attributes.cc
									
									
									
									
										vendored
									
									
										Normal file
									
								
							@ -0,0 +1,163 @@
 | 
			
		||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
 | 
			
		||||
 | 
			
		||||
Licensed under the Apache License, Version 2.0 (the "License");
 | 
			
		||||
you may not use this file except in compliance with the License.
 | 
			
		||||
You may obtain a copy of the License at
 | 
			
		||||
 | 
			
		||||
    http://www.apache.org/licenses/LICENSE-2.0
 | 
			
		||||
 | 
			
		||||
Unless required by applicable law or agreed to in writing, software
 | 
			
		||||
distributed under the License is distributed on an "AS IS" BASIS,
 | 
			
		||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | 
			
		||||
See the License for the specific language governing permissions and
 | 
			
		||||
limitations under the License.
 | 
			
		||||
==============================================================================*/
 | 
			
		||||
 | 
			
		||||
#include "tf_attributes.h"
 | 
			
		||||
 | 
			
		||||
namespace mlir {
 | 
			
		||||
namespace TF {
 | 
			
		||||
 | 
			
		||||
namespace detail {
 | 
			
		||||
 | 
			
		||||
// The storage class for ShapeAttr.
 | 
			
		||||
struct ShapeAttrStorage : public AttributeStorage
 | 
			
		||||
{
 | 
			
		||||
    using KeyTy = std::pair<ArrayRef<int64_t>, bool>;
 | 
			
		||||
 | 
			
		||||
    explicit ShapeAttrStorage(ArrayRef<int64_t> shape, bool unranked = false)
 | 
			
		||||
        : shape(shape), unranked(unranked)
 | 
			
		||||
    {
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    bool operator==(const KeyTy& key) const
 | 
			
		||||
    {
 | 
			
		||||
        return key == KeyTy(shape, unranked);
 | 
			
		||||
    }
 | 
			
		||||
    static unsigned hashKey(const KeyTy& key)
 | 
			
		||||
    {
 | 
			
		||||
        return llvm::hash_combine(key.first, static_cast<char>(key.second));
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    // NOLINTNEXTLINE
 | 
			
		||||
    static ShapeAttrStorage* construct(mlir::AttributeStorageAllocator& allocator,
 | 
			
		||||
                                       const KeyTy& key)
 | 
			
		||||
    {
 | 
			
		||||
        return new (allocator.allocate<ShapeAttrStorage>())
 | 
			
		||||
               ShapeAttrStorage(allocator.copyInto(key.first), key.second);
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    ArrayRef<int64_t> shape;
 | 
			
		||||
    bool unranked = false;
 | 
			
		||||
};
 | 
			
		||||
 | 
			
		||||
// The storage class for FuncAttr.
 | 
			
		||||
struct FuncAttrStorage : public AttributeStorage
 | 
			
		||||
{
 | 
			
		||||
    using KeyTy = std::pair<Attribute, Attribute>;
 | 
			
		||||
 | 
			
		||||
    explicit FuncAttrStorage(Attribute name, Attribute attrs)
 | 
			
		||||
        : name(name), attrs(attrs)
 | 
			
		||||
    {
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    bool operator==(const KeyTy& key) const
 | 
			
		||||
    {
 | 
			
		||||
        return key == KeyTy(name, attrs);
 | 
			
		||||
    }
 | 
			
		||||
    static unsigned hashKey(const KeyTy& key)
 | 
			
		||||
    {
 | 
			
		||||
        return llvm::hash_combine(key.first, key.second);
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    static FuncAttrStorage* construct(mlir::AttributeStorageAllocator& allocator,
 | 
			
		||||
                                      const KeyTy& key)
 | 
			
		||||
    {
 | 
			
		||||
        return new (allocator.allocate<FuncAttrStorage>())
 | 
			
		||||
               FuncAttrStorage(key.first, key.second);
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    Attribute name;
 | 
			
		||||
    Attribute attrs;
 | 
			
		||||
};
 | 
			
		||||
 | 
			
		||||
} // namespace detail
 | 
			
		||||
 | 
			
		||||
// Get or create a shape attribute.
 | 
			
		||||
ShapeAttr ShapeAttr::get(mlir::MLIRContext* context,
 | 
			
		||||
                         llvm::Optional<ArrayRef<int64_t> > shape)
 | 
			
		||||
{
 | 
			
		||||
    if (shape) return Base::get(context, *shape, /*unranked=*/false);
 | 
			
		||||
 | 
			
		||||
    return Base::get(context, ArrayRef<int64_t>(), /*unranked=*/true);
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Get or create a shape attribute.
 | 
			
		||||
ShapeAttr ShapeAttr::get(mlir::MLIRContext* context, ShapedType shaped_type)
 | 
			
		||||
{
 | 
			
		||||
    if (shaped_type.hasRank())
 | 
			
		||||
        return Base::get(context, shaped_type.getShape(), /*unranked=*/false);
 | 
			
		||||
 | 
			
		||||
    return Base::get(context, ArrayRef<int64_t>(), /*unranked=*/true);
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
llvm::Optional<ArrayRef<int64_t> > ShapeAttr::getValue() const
 | 
			
		||||
{
 | 
			
		||||
    if (hasRank()) return getShape();
 | 
			
		||||
    return llvm::None;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
bool ShapeAttr::hasRank() const
 | 
			
		||||
{
 | 
			
		||||
    return !getImpl()->unranked;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
int64_t ShapeAttr::getRank() const
 | 
			
		||||
{
 | 
			
		||||
    assert(hasRank());
 | 
			
		||||
    return getImpl()->shape.size();
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
ArrayRef<int64_t> ShapeAttr::getShape() const
 | 
			
		||||
{
 | 
			
		||||
    assert(hasRank());
 | 
			
		||||
    return getImpl()->shape;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
bool ShapeAttr::hasStaticShape() const
 | 
			
		||||
{
 | 
			
		||||
    if (!hasRank()) return false;
 | 
			
		||||
 | 
			
		||||
    for (auto dim : getShape())
 | 
			
		||||
    {
 | 
			
		||||
        if (dim < 0) return false;
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    return true;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
FuncAttr FuncAttr::get(mlir::MLIRContext* context, llvm::StringRef name,
 | 
			
		||||
                       DictionaryAttr attr)
 | 
			
		||||
{
 | 
			
		||||
    auto symbol = SymbolRefAttr::get(context, name);
 | 
			
		||||
    return Base::get(context, symbol, attr);
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
FuncAttr FuncAttr::get(mlir::MLIRContext* context, SymbolRefAttr symbol,
 | 
			
		||||
                       DictionaryAttr attr)
 | 
			
		||||
{
 | 
			
		||||
    return Base::get(context, symbol, attr);
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
SymbolRefAttr FuncAttr::GetName() const
 | 
			
		||||
{
 | 
			
		||||
    return getImpl()->name.cast<SymbolRefAttr>();
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
DictionaryAttr FuncAttr::GetAttrs() const
 | 
			
		||||
{
 | 
			
		||||
    return getImpl()->attrs.cast<DictionaryAttr>();
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
} // namespace TF
 | 
			
		||||
} // namespace mlir
 | 
			
		||||
							
								
								
									
										97
									
								
								3rdparty/ncnn/tools/mlir/tf_attributes.h
									
									
									
									
										vendored
									
									
										Normal file
									
								
							
							
						
						
									
										97
									
								
								3rdparty/ncnn/tools/mlir/tf_attributes.h
									
									
									
									
										vendored
									
									
										Normal file
									
								
							@ -0,0 +1,97 @@
 | 
			
		||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
 | 
			
		||||
 | 
			
		||||
Licensed under the Apache License, Version 2.0 (the "License");
 | 
			
		||||
you may not use this file except in compliance with the License.
 | 
			
		||||
You may obtain a copy of the License at
 | 
			
		||||
 | 
			
		||||
    http://www.apache.org/licenses/LICENSE-2.0
 | 
			
		||||
 | 
			
		||||
Unless required by applicable law or agreed to in writing, software
 | 
			
		||||
distributed under the License is distributed on an "AS IS" BASIS,
 | 
			
		||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | 
			
		||||
See the License for the specific language governing permissions and
 | 
			
		||||
limitations under the License.
 | 
			
		||||
==============================================================================*/
 | 
			
		||||
 | 
			
		||||
// This file defines the attributes used in the TensorFlow dialect.
 | 
			
		||||
 | 
			
		||||
#ifndef TENSORFLOW_COMPILER_MLIR_TENSORFLOW_IR_TF_ATTRIBUTES_H_
 | 
			
		||||
#define TENSORFLOW_COMPILER_MLIR_TENSORFLOW_IR_TF_ATTRIBUTES_H_
 | 
			
		||||
 | 
			
		||||
#include "llvm/ADT/StringRef.h"
 | 
			
		||||
#include "mlir/IR/BuiltinAttributes.h" // from @llvm-project
 | 
			
		||||
#include "mlir/IR/BuiltinTypes.h"      // from @llvm-project
 | 
			
		||||
#include "mlir/IR/MLIRContext.h"       // from @llvm-project
 | 
			
		||||
 | 
			
		||||
namespace mlir {
 | 
			
		||||
namespace TF {
 | 
			
		||||
 | 
			
		||||
namespace detail {
 | 
			
		||||
 | 
			
		||||
struct ShapeAttrStorage;
 | 
			
		||||
struct FuncAttrStorage;
 | 
			
		||||
 | 
			
		||||
} // namespace detail
 | 
			
		||||
 | 
			
		||||
class ShapeAttr : public Attribute::AttrBase<ShapeAttr, Attribute,
 | 
			
		||||
    detail::ShapeAttrStorage>
 | 
			
		||||
{
 | 
			
		||||
public:
 | 
			
		||||
    using Base::Base;
 | 
			
		||||
 | 
			
		||||
    // Get or create a shape attribute. If shape is llvm::None, then it is
 | 
			
		||||
    // unranked. Otherwise it is ranked. And for ranked shapes, the value of the
 | 
			
		||||
    // dimension size must be >= -1. The value of -1 means the dimension is
 | 
			
		||||
    // dynamic. Otherwise, the dimension is static.
 | 
			
		||||
    static ShapeAttr get(mlir::MLIRContext* context,
 | 
			
		||||
                         llvm::Optional<ArrayRef<int64_t> > shape);
 | 
			
		||||
 | 
			
		||||
    // Get or create a shape attribute from a ShapedType type.
 | 
			
		||||
    static ShapeAttr get(mlir::MLIRContext* context, ShapedType shaped_type);
 | 
			
		||||
 | 
			
		||||
    llvm::Optional<ArrayRef<int64_t> > getValue() const;
 | 
			
		||||
 | 
			
		||||
    bool hasRank() const;
 | 
			
		||||
 | 
			
		||||
    // If this is ranked, return the rank. Otherwise, abort.
 | 
			
		||||
    int64_t getRank() const;
 | 
			
		||||
 | 
			
		||||
    // If this is ranked, return the shape. Otherwise, abort.
 | 
			
		||||
    ArrayRef<int64_t> getShape() const;
 | 
			
		||||
 | 
			
		||||
    // If this is unranked type or any dimension has unknown size (<0), it doesn't
 | 
			
		||||
    // have static shape. If all dimensions have known size (>= 0), it has static
 | 
			
		||||
    // shape.
 | 
			
		||||
    bool hasStaticShape() const;
 | 
			
		||||
};
 | 
			
		||||
 | 
			
		||||
// Custom attribute to model AttrValue.value.func (NameAttrList type attribute).
 | 
			
		||||
// This attribute holds a SymbolRefAttr, for the NameAttrList.name string and a
 | 
			
		||||
// DictionaryAttr for the NameAttrList.attr map<string, AttrValue>. It is
 | 
			
		||||
// currently printed and parsed for the following format:
 | 
			
		||||
//
 | 
			
		||||
//   #tf.func<@symbol, {attr = "value"}>
 | 
			
		||||
//
 | 
			
		||||
// where the first element is the SymbolRefAttr and the second element is the
 | 
			
		||||
// DictionaryAttr.
 | 
			
		||||
class FuncAttr
 | 
			
		||||
    : public Attribute::AttrBase<FuncAttr, Attribute, detail::FuncAttrStorage>
 | 
			
		||||
{
 | 
			
		||||
public:
 | 
			
		||||
    using Base::Base;
 | 
			
		||||
 | 
			
		||||
    static FuncAttr get(mlir::MLIRContext* context, llvm::StringRef name,
 | 
			
		||||
                        DictionaryAttr attr);
 | 
			
		||||
 | 
			
		||||
    static FuncAttr get(mlir::MLIRContext* context, SymbolRefAttr symbol,
 | 
			
		||||
                        DictionaryAttr attr);
 | 
			
		||||
 | 
			
		||||
    SymbolRefAttr GetName() const;
 | 
			
		||||
 | 
			
		||||
    DictionaryAttr GetAttrs() const;
 | 
			
		||||
};
 | 
			
		||||
 | 
			
		||||
} // namespace TF
 | 
			
		||||
} // namespace mlir
 | 
			
		||||
 | 
			
		||||
#endif // TENSORFLOW_COMPILER_MLIR_TENSORFLOW_IR_TF_ATTRIBUTES_H_
 | 
			
		||||
							
								
								
									
										323
									
								
								3rdparty/ncnn/tools/mlir/tf_dialect.cpp
									
									
									
									
										vendored
									
									
										Normal file
									
								
							
							
						
						
									
										323
									
								
								3rdparty/ncnn/tools/mlir/tf_dialect.cpp
									
									
									
									
										vendored
									
									
										Normal file
									
								
							@ -0,0 +1,323 @@
 | 
			
		||||
// Tencent is pleased to support the open source community by making ncnn available.
 | 
			
		||||
//
 | 
			
		||||
// Copyright (C) 2020 THL A29 Limited, a Tencent company. All rights reserved.
 | 
			
		||||
//
 | 
			
		||||
// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
 | 
			
		||||
// in compliance with the License. You may obtain a copy of the License at
 | 
			
		||||
//
 | 
			
		||||
// https://opensource.org/licenses/BSD-3-Clause
 | 
			
		||||
//
 | 
			
		||||
// Unless required by applicable law or agreed to in writing, software distributed
 | 
			
		||||
// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
 | 
			
		||||
// CONDITIONS OF ANY KIND, either express or implied. See the License for the
 | 
			
		||||
// specific language governing permissions and limitations under the License.
 | 
			
		||||
 | 
			
		||||
#include "tf_dialect.h"
 | 
			
		||||
 | 
			
		||||
#include <mlir/Dialect/Traits.h>
 | 
			
		||||
#include <mlir/IR/Attributes.h>
 | 
			
		||||
#include <mlir/IR/Builders.h>
 | 
			
		||||
#include <mlir/IR/Dialect.h>
 | 
			
		||||
#include <mlir/IR/DialectImplementation.h>
 | 
			
		||||
#include <mlir/IR/Location.h>
 | 
			
		||||
#include <mlir/IR/Matchers.h>
 | 
			
		||||
#include <mlir/IR/MLIRContext.h>
 | 
			
		||||
#include <mlir/IR/OpDefinition.h>
 | 
			
		||||
#include <mlir/IR/OpImplementation.h>
 | 
			
		||||
#include <mlir/IR/Operation.h>
 | 
			
		||||
#include <mlir/IR/OperationSupport.h>
 | 
			
		||||
#include <mlir/IR/PatternMatch.h>
 | 
			
		||||
#include <mlir/IR/TypeUtilities.h>
 | 
			
		||||
#include <mlir/IR/Types.h>
 | 
			
		||||
#include <mlir/IR/Value.h>
 | 
			
		||||
#include <mlir/IR/Verifier.h>
 | 
			
		||||
#include <mlir/Interfaces/CallInterfaces.h>
 | 
			
		||||
#include <mlir/Interfaces/DerivedAttributeOpInterface.h>
 | 
			
		||||
#include <mlir/Interfaces/InferTypeOpInterface.h>
 | 
			
		||||
#include <mlir/Interfaces/LoopLikeInterface.h>
 | 
			
		||||
#include <mlir/Interfaces/SideEffectInterfaces.h>
 | 
			
		||||
#include <mlir/Parser.h>
 | 
			
		||||
#include <mlir/Support/LogicalResult.h>
 | 
			
		||||
#include <mlir/Transforms/InliningUtils.h>
 | 
			
		||||
 | 
			
		||||
#include "tf_attributes.h"
 | 
			
		||||
#include "tf_side_effects.h"
 | 
			
		||||
#include "tf_traits.h"
 | 
			
		||||
 | 
			
		||||
namespace mlir {
 | 
			
		||||
 | 
			
		||||
static LogicalResult Verify(...)
 | 
			
		||||
{
 | 
			
		||||
    return success();
 | 
			
		||||
}
 | 
			
		||||
static LogicalResult VerifyPartitionedCall(...)
 | 
			
		||||
{
 | 
			
		||||
    return success();
 | 
			
		||||
}
 | 
			
		||||
static LogicalResult VerifyStridedSliceBase(...)
 | 
			
		||||
{
 | 
			
		||||
    return success();
 | 
			
		||||
}
 | 
			
		||||
static LogicalResult VerifyUnsortedSegmentReduction(...)
 | 
			
		||||
{
 | 
			
		||||
    return success();
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
namespace TF {
 | 
			
		||||
 | 
			
		||||
TensorFlowDialect::TensorFlowDialect(MLIRContext* context)
 | 
			
		||||
    : Dialect(/*name=*/"tf", context, TypeID::get<TensorFlowDialect>())
 | 
			
		||||
{
 | 
			
		||||
    addOperations<
 | 
			
		||||
#define GET_OP_LIST
 | 
			
		||||
#include "tf_all_ops.cc.inc"
 | 
			
		||||
    >();
 | 
			
		||||
    addTypes<
 | 
			
		||||
#define HANDLE_TF_TYPE(tftype, enumerant, name)      tftype##Type,
 | 
			
		||||
#define HANDLE_LAST_TF_TYPE(tftype, enumerant, name) tftype##Type
 | 
			
		||||
#include "tf_types.def"
 | 
			
		||||
    >();
 | 
			
		||||
    //   addInterfaces<TFInlinerInterface, TFDecodeAttributesInterface,
 | 
			
		||||
    //                 TFConstantFoldInterface>();
 | 
			
		||||
    addAttributes<ShapeAttr, FuncAttr>();
 | 
			
		||||
 | 
			
		||||
    // Support unknown operations because not all TensorFlow operations are
 | 
			
		||||
    // registered.
 | 
			
		||||
    allowUnknownOperations();
 | 
			
		||||
 | 
			
		||||
    //   for (const auto &hook : *TensorFlowDialect::additional_operation_hooks_) {
 | 
			
		||||
    //     hook(*this);
 | 
			
		||||
    //   }
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
namespace {
 | 
			
		||||
 | 
			
		||||
ShapeAttr ParseShapeAttr(MLIRContext* context, StringRef spec, Location loc)
 | 
			
		||||
{
 | 
			
		||||
    auto emit_error = [&, spec]() {
 | 
			
		||||
        emitError(loc, "invalid TensorFlow shape attribute: ") << spec;
 | 
			
		||||
        return nullptr;
 | 
			
		||||
    };
 | 
			
		||||
 | 
			
		||||
    if (!spec.consume_front("shape<")) return emit_error();
 | 
			
		||||
 | 
			
		||||
    if (spec.consume_front("*>"))
 | 
			
		||||
        return mlir::TF::ShapeAttr::get(context, llvm::None);
 | 
			
		||||
 | 
			
		||||
    SmallVector<int64_t, 4> shape;
 | 
			
		||||
    while (!spec.consume_front(">"))
 | 
			
		||||
    {
 | 
			
		||||
        int64_t dim;
 | 
			
		||||
 | 
			
		||||
        if (spec.consume_front("?"))
 | 
			
		||||
            dim = -1;
 | 
			
		||||
        else if (spec.consumeInteger(10, dim) || dim < 0)
 | 
			
		||||
            return emit_error();
 | 
			
		||||
 | 
			
		||||
        spec.consume_front("x");
 | 
			
		||||
 | 
			
		||||
        shape.push_back(dim);
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    return mlir::TF::ShapeAttr::get(context, llvm::makeArrayRef(shape));
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Parses a #tf.func attribute of the following format:
 | 
			
		||||
//
 | 
			
		||||
//   #tf.func<@symbol, {attr = "value"}>
 | 
			
		||||
//
 | 
			
		||||
// where the first element is a SymbolRefAttr and the second element is a
 | 
			
		||||
// DictionaryAttr.
 | 
			
		||||
FuncAttr ParseFuncAttr(MLIRContext* context, StringRef spec, Location loc)
 | 
			
		||||
{
 | 
			
		||||
    auto emit_error = [&, spec]() {
 | 
			
		||||
        emitError(loc, "invalid TensorFlow func attribute: ") << spec;
 | 
			
		||||
        return nullptr;
 | 
			
		||||
    };
 | 
			
		||||
 | 
			
		||||
    if (!spec.consume_front("func<")) return emit_error();
 | 
			
		||||
 | 
			
		||||
    size_t func_name_num_read = 0;
 | 
			
		||||
    Attribute func_name_attr = mlir::parseAttribute(spec, context, func_name_num_read);
 | 
			
		||||
    if (!func_name_attr || !func_name_attr.isa<SymbolRefAttr>())
 | 
			
		||||
        return emit_error();
 | 
			
		||||
    spec = spec.drop_front(func_name_num_read);
 | 
			
		||||
 | 
			
		||||
    if (!spec.consume_front(", ")) return emit_error();
 | 
			
		||||
 | 
			
		||||
    size_t func_attrs_num_read = 0;
 | 
			
		||||
    Attribute func_attrs_attr = mlir::parseAttribute(spec, context, func_attrs_num_read);
 | 
			
		||||
    if (!func_attrs_attr || !func_attrs_attr.isa<DictionaryAttr>())
 | 
			
		||||
        return emit_error();
 | 
			
		||||
    spec = spec.drop_front(func_attrs_num_read);
 | 
			
		||||
 | 
			
		||||
    if (!spec.consume_front(">")) return emit_error();
 | 
			
		||||
 | 
			
		||||
    return mlir::TF::FuncAttr::get(context, func_name_attr.cast<SymbolRefAttr>(),
 | 
			
		||||
                                   func_attrs_attr.cast<DictionaryAttr>());
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
} // namespace
 | 
			
		||||
 | 
			
		||||
Attribute TensorFlowDialect::parseAttribute(DialectAsmParser& parser,
 | 
			
		||||
        Type type) const
 | 
			
		||||
{
 | 
			
		||||
    auto spec = parser.getFullSymbolSpec();
 | 
			
		||||
    Location loc = parser.getEncodedSourceLoc(parser.getNameLoc());
 | 
			
		||||
 | 
			
		||||
    if (spec.startswith("shape")) return ParseShapeAttr(getContext(), spec, loc);
 | 
			
		||||
 | 
			
		||||
    if (spec.startswith("func")) return ParseFuncAttr(getContext(), spec, loc);
 | 
			
		||||
 | 
			
		||||
    return (emitError(loc, "unknown TensorFlow attribute: " + spec), nullptr);
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Parses a type registered to this dialect.
 | 
			
		||||
Type TensorFlowDialect::parseType(DialectAsmParser& parser) const
 | 
			
		||||
{
 | 
			
		||||
    StringRef data;
 | 
			
		||||
    if (parser.parseKeyword(&data)) return Type();
 | 
			
		||||
 | 
			
		||||
#define HANDLE_TF_TYPE(tftype, enumerant, name) \
 | 
			
		||||
    if (data == name) return tftype##Type::get(getContext());
 | 
			
		||||
// Custom TensorFlow types are handled separately at the end as they do partial
 | 
			
		||||
// match.
 | 
			
		||||
#define HANDLE_CUSTOM_TF_TYPE(tftype, enumerant, name)
 | 
			
		||||
// NOLINTNEXTLINE
 | 
			
		||||
#include "tf_types.def"
 | 
			
		||||
 | 
			
		||||
    llvm::SMLoc loc = parser.getNameLoc();
 | 
			
		||||
    if (data.startswith("resource"))
 | 
			
		||||
    {
 | 
			
		||||
        Type ret = ParseResourceType(parser);
 | 
			
		||||
        if (!ret) parser.emitError(loc, "invalid resource type");
 | 
			
		||||
        return ret;
 | 
			
		||||
    }
 | 
			
		||||
    if (data.startswith("variant"))
 | 
			
		||||
    {
 | 
			
		||||
        Type ret = ParseVariantType(parser);
 | 
			
		||||
        if (!ret) parser.emitError(loc, "invalid variant type");
 | 
			
		||||
        return ret;
 | 
			
		||||
    }
 | 
			
		||||
    return (parser.emitError(loc, "unknown TensorFlow type: " + data), nullptr);
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
namespace {
 | 
			
		||||
template<typename TypeWithSubtype>
 | 
			
		||||
Type ParseTypeWithSubtype(MLIRContext* context, DialectAsmParser& parser)
 | 
			
		||||
{
 | 
			
		||||
    // Default type without inferred subtypes.
 | 
			
		||||
    if (failed(parser.parseOptionalLess())) return TypeWithSubtype::get(context);
 | 
			
		||||
 | 
			
		||||
    // Most types with subtypes have only one subtype.
 | 
			
		||||
    SmallVector<TensorType, 1> subtypes;
 | 
			
		||||
    do
 | 
			
		||||
    {
 | 
			
		||||
        TensorType tensor_ty;
 | 
			
		||||
        if (parser.parseType(tensor_ty)) return Type();
 | 
			
		||||
 | 
			
		||||
        // Each of the subtypes should be a valid TensorFlow type.
 | 
			
		||||
        // TODO(jpienaar): Remove duplication.
 | 
			
		||||
        if (!IsValidTFTensorType(tensor_ty))
 | 
			
		||||
        {
 | 
			
		||||
            parser.emitError(parser.getNameLoc()) << "invalid subtype: " << tensor_ty;
 | 
			
		||||
            return Type();
 | 
			
		||||
        }
 | 
			
		||||
        subtypes.push_back(tensor_ty);
 | 
			
		||||
    } while (succeeded(parser.parseOptionalComma()));
 | 
			
		||||
 | 
			
		||||
    if (parser.parseGreater()) return Type();
 | 
			
		||||
 | 
			
		||||
    return TypeWithSubtype::get(subtypes, context);
 | 
			
		||||
}
 | 
			
		||||
} // anonymous namespace
 | 
			
		||||
 | 
			
		||||
Type TensorFlowDialect::ParseResourceType(DialectAsmParser& parser) const
 | 
			
		||||
{
 | 
			
		||||
    return ParseTypeWithSubtype<ResourceType>(getContext(), parser);
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
Type TensorFlowDialect::ParseVariantType(DialectAsmParser& parser) const
 | 
			
		||||
{
 | 
			
		||||
    return ParseTypeWithSubtype<VariantType>(getContext(), parser);
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
Operation* TensorFlowDialect::materializeConstant(OpBuilder& builder,
 | 
			
		||||
        Attribute value, Type type,
 | 
			
		||||
        Location loc)
 | 
			
		||||
{
 | 
			
		||||
    return builder.create<ConstOp>(loc, type, value);
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Builds a constant op with the specified attribute `value`. The result
 | 
			
		||||
// op's type is deduced from `value`; if `value` is of scalar type,
 | 
			
		||||
// wraps it up with a tensor type of empty shape.
 | 
			
		||||
// TODO(jpienaar): This one differs from the autogenerated one as it takes an
 | 
			
		||||
// attribute but always creates an ElementsAttr internally.
 | 
			
		||||
void ConstOp::build(OpBuilder& builder, OperationState& result,
 | 
			
		||||
                    Attribute value)
 | 
			
		||||
{
 | 
			
		||||
    ShapedType type;
 | 
			
		||||
    if (auto elem_attr = value.dyn_cast<ElementsAttr>())
 | 
			
		||||
    {
 | 
			
		||||
        return ConstOp::build(builder, result, elem_attr);
 | 
			
		||||
    }
 | 
			
		||||
    else if (value.isa<BoolAttr, FloatAttr, IntegerAttr>())
 | 
			
		||||
    {
 | 
			
		||||
        // All TensorFlow types must be tensor types. In the build() method,
 | 
			
		||||
        // we want to provide more flexibility by allowing attributes of scalar
 | 
			
		||||
        // types. But we need to wrap it up with ElementsAttr to construct
 | 
			
		||||
        // valid TensorFlow constants.
 | 
			
		||||
        type = RankedTensorType::get(/*shape=*/ {}, value.getType());
 | 
			
		||||
        return ConstOp::build(builder, result, DenseElementsAttr::get(type, value));
 | 
			
		||||
    }
 | 
			
		||||
    // TODO(jpienaar): support other TensorFlow specific types.
 | 
			
		||||
    llvm_unreachable("unsupported attribute type for building tf.Const");
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
void ConstOp::build(OpBuilder& builder, OperationState& result, Type type,
 | 
			
		||||
                    Attribute value)
 | 
			
		||||
{
 | 
			
		||||
    // Handle the case where the type and value are already tensors.
 | 
			
		||||
    if (type.isa<TensorType>() && value.isa<ElementsAttr>())
 | 
			
		||||
    {
 | 
			
		||||
        result.addTypes(type);
 | 
			
		||||
        result.addAttribute("value", value);
 | 
			
		||||
        return;
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    // Otherwise, default to the attribute builder.
 | 
			
		||||
    ConstOp::build(builder, result, value);
 | 
			
		||||
    assert(type == result.types[0] && "type mismatch in construction");
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
Region& WhileRegionOp::getLoopBody()
 | 
			
		||||
{
 | 
			
		||||
    return body();
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
bool WhileRegionOp::isDefinedOutsideOfLoop(Value value)
 | 
			
		||||
{
 | 
			
		||||
    // If the Op defining the value exists and the defining op is outside the
 | 
			
		||||
    // scope of this WhileRegion, then we can infer that its defined outside.
 | 
			
		||||
    // The defining Op is outside the scope of this WhileRegion if this
 | 
			
		||||
    // WhileRegionOp is not an ancestor of the defining op in the parent chain.
 | 
			
		||||
    Operation* def_op = value.getDefiningOp();
 | 
			
		||||
    return def_op && !getOperation()->isAncestor(def_op);
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
LogicalResult WhileRegionOp::moveOutOfLoop(
 | 
			
		||||
    llvm::ArrayRef<mlir::Operation*> ops)
 | 
			
		||||
{
 | 
			
		||||
    // Move the hoisted value to just before the while.
 | 
			
		||||
    Operation* while_op = this->getOperation();
 | 
			
		||||
    for (auto op : ops) op->moveBefore(while_op);
 | 
			
		||||
    return success();
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
} // namespace TF
 | 
			
		||||
 | 
			
		||||
} // namespace mlir
 | 
			
		||||
 | 
			
		||||
#define GET_OP_CLASSES
 | 
			
		||||
#include "tf_all_ops.cc.inc"
 | 
			
		||||
							
								
								
									
										68
									
								
								3rdparty/ncnn/tools/mlir/tf_dialect.h
									
									
									
									
										vendored
									
									
										Normal file
									
								
							
							
						
						
									
										68
									
								
								3rdparty/ncnn/tools/mlir/tf_dialect.h
									
									
									
									
										vendored
									
									
										Normal file
									
								
							@ -0,0 +1,68 @@
 | 
			
		||||
// Tencent is pleased to support the open source community by making ncnn available.
 | 
			
		||||
//
 | 
			
		||||
// Copyright (C) 2020 THL A29 Limited, a Tencent company. All rights reserved.
 | 
			
		||||
//
 | 
			
		||||
// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
 | 
			
		||||
// in compliance with the License. You may obtain a copy of the License at
 | 
			
		||||
//
 | 
			
		||||
// https://opensource.org/licenses/BSD-3-Clause
 | 
			
		||||
//
 | 
			
		||||
// Unless required by applicable law or agreed to in writing, software distributed
 | 
			
		||||
// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
 | 
			
		||||
// CONDITIONS OF ANY KIND, either express or implied. See the License for the
 | 
			
		||||
// specific language governing permissions and limitations under the License.
 | 
			
		||||
 | 
			
		||||
#ifndef TF_DIALECT_H
 | 
			
		||||
#define TF_DIALECT_H
 | 
			
		||||
 | 
			
		||||
#include <mlir/Dialect/Traits.h>
 | 
			
		||||
#include <mlir/IR/BuiltinOps.h>
 | 
			
		||||
#include <mlir/IR/Dialect.h>
 | 
			
		||||
#include <mlir/IR/OpImplementation.h>
 | 
			
		||||
#include <mlir/Interfaces/ControlFlowInterfaces.h>
 | 
			
		||||
#include <mlir/Interfaces/DerivedAttributeOpInterface.h>
 | 
			
		||||
#include <mlir/Interfaces/InferTypeOpInterface.h>
 | 
			
		||||
#include <mlir/Interfaces/LoopLikeInterface.h>
 | 
			
		||||
#include <mlir/Interfaces/SideEffectInterfaces.h>
 | 
			
		||||
 | 
			
		||||
#include "tf_traits.h"
 | 
			
		||||
 | 
			
		||||
namespace mlir {
 | 
			
		||||
 | 
			
		||||
namespace TF {
 | 
			
		||||
 | 
			
		||||
class TensorFlowDialect : public mlir::Dialect
 | 
			
		||||
{
 | 
			
		||||
public:
 | 
			
		||||
    TensorFlowDialect(mlir::MLIRContext* context);
 | 
			
		||||
 | 
			
		||||
    static StringRef getDialectNamespace()
 | 
			
		||||
    {
 | 
			
		||||
        return "tf";
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    Attribute parseAttribute(DialectAsmParser& parser, Type type) const override;
 | 
			
		||||
 | 
			
		||||
    // Parse a type registered to this dialect.
 | 
			
		||||
    Type parseType(DialectAsmParser& parser) const override;
 | 
			
		||||
 | 
			
		||||
    // Parses resource type with potential subtypes.
 | 
			
		||||
    Type ParseResourceType(DialectAsmParser& parser) const;
 | 
			
		||||
 | 
			
		||||
    // Parse and print variant type. It may have subtypes inferred using shape
 | 
			
		||||
    // inference.
 | 
			
		||||
    Type ParseVariantType(DialectAsmParser& parser) const;
 | 
			
		||||
 | 
			
		||||
    // Registered hook to materialize a constant operation from a given attribute
 | 
			
		||||
    // value with the desired resultant type.
 | 
			
		||||
    Operation* materializeConstant(OpBuilder& builder, Attribute value, Type type, Location loc) override;
 | 
			
		||||
};
 | 
			
		||||
 | 
			
		||||
} // namespace TF
 | 
			
		||||
 | 
			
		||||
} // namespace mlir
 | 
			
		||||
 | 
			
		||||
#define GET_OP_CLASSES
 | 
			
		||||
#include "tf_all_ops.h.inc"
 | 
			
		||||
 | 
			
		||||
#endif // TF_DIALECT_H
 | 
			
		||||
							
								
								
									
										18565
									
								
								3rdparty/ncnn/tools/mlir/tf_generated_ops.td
									
									
									
									
										vendored
									
									
										Normal file
									
								
							
							
						
						
									
										18565
									
								
								3rdparty/ncnn/tools/mlir/tf_generated_ops.td
									
									
									
									
										vendored
									
									
										Normal file
									
								
							
										
											
												File diff suppressed because it is too large
												Load Diff
											
										
									
								
							
							
								
								
									
										617
									
								
								3rdparty/ncnn/tools/mlir/tf_op_base.td
									
									
									
									
										vendored
									
									
										Normal file
									
								
							
							
						
						
									
										617
									
								
								3rdparty/ncnn/tools/mlir/tf_op_base.td
									
									
									
									
										vendored
									
									
										Normal file
									
								
							@ -0,0 +1,617 @@
 | 
			
		||||
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
 | 
			
		||||
 | 
			
		||||
Licensed under the Apache License, Version 2.0 (the "License");
 | 
			
		||||
you may not use this file except in compliance with the License.
 | 
			
		||||
You may obtain a copy of the License at
 | 
			
		||||
 | 
			
		||||
    http://www.apache.org/licenses/LICENSE-2.0
 | 
			
		||||
 | 
			
		||||
Unless required by applicable law or agreed to in writing, software
 | 
			
		||||
distributed under the License is distributed on an "AS IS" BASIS,
 | 
			
		||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | 
			
		||||
See the License for the specific language governing permissions and
 | 
			
		||||
limitations under the License.
 | 
			
		||||
==============================================================================*/
 | 
			
		||||
 | 
			
		||||
// This is the base operation definition file for TensorFlow.
 | 
			
		||||
//
 | 
			
		||||
// This file includes the definition for the TensorFlow dialect, base TensorFlow
 | 
			
		||||
// op, and various commonly used TensorFlow traits, types, attributes, and
 | 
			
		||||
// builders.
 | 
			
		||||
 | 
			
		||||
#ifndef TF_OP_BASE
 | 
			
		||||
#define TF_OP_BASE
 | 
			
		||||
 | 
			
		||||
include "mlir/IR/OpBase.td"
 | 
			
		||||
include "mlir/Interfaces/SideEffectInterfaces.td"
 | 
			
		||||
 | 
			
		||||
//===----------------------------------------------------------------------===//
 | 
			
		||||
// TensorFlow dialect definitions
 | 
			
		||||
//===----------------------------------------------------------------------===//
 | 
			
		||||
 | 
			
		||||
def TF_Dialect : Dialect {
 | 
			
		||||
  let name = "tf";
 | 
			
		||||
 | 
			
		||||
  let description = [{
 | 
			
		||||
The TensorFlow dialect.
 | 
			
		||||
 | 
			
		||||
This dialect maps to TensorFlow operations.
 | 
			
		||||
 | 
			
		||||
Invariants:
 | 
			
		||||
 | 
			
		||||
* All values are of Tensor type (in particular, scalars are
 | 
			
		||||
  represented using zero-dimensional tensors);
 | 
			
		||||
 | 
			
		||||
TODO: Make invariants more structured so that we can reference them in ops.
 | 
			
		||||
  }];
 | 
			
		||||
 | 
			
		||||
  let cppNamespace = "::mlir::TF";
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
//===----------------------------------------------------------------------===//
 | 
			
		||||
// TensorFlow traits
 | 
			
		||||
//===----------------------------------------------------------------------===//
 | 
			
		||||
 | 
			
		||||
// Specify this trait if the op requires all outputs to have the same type and
 | 
			
		||||
// the inputs either have the same type as result or a ref type corresponding to
 | 
			
		||||
// the result type.
 | 
			
		||||
def TF_OperandsSameAsResultsTypeOrRef : NativeOpTrait<
 | 
			
		||||
  "TF::OperandsSameAsResultsTypeOrRef">;
 | 
			
		||||
 | 
			
		||||
// Op has the same operand and result element types (or type itself, if scalar)
 | 
			
		||||
// after resolving reference types (i.e., after converting reference types to
 | 
			
		||||
// their corresponding TensorFlow or standard types).
 | 
			
		||||
def TF_SameOperandsAndResultElementTypeResolveRef : NativeOpTrait<
 | 
			
		||||
  "TF::SameOperandsAndResultElementTypeResolveRef">;
 | 
			
		||||
 | 
			
		||||
// Op has the same operand and result types after resolving reference types
 | 
			
		||||
// (i.e., after converting reference types to their corresponding TensorFlow or
 | 
			
		||||
// standard types).
 | 
			
		||||
def TF_SameOperandsAndResultTypeResolveRef : NativeOpTrait<
 | 
			
		||||
  "TF::SameOperandsAndResultTypeResolveRef">;
 | 
			
		||||
 | 
			
		||||
// Layout agnostic operations do not depend on the operands data layout (data
 | 
			
		||||
// format), as an example all element wise operations are layout agnostic.
 | 
			
		||||
def TF_LayoutAgnostic : NativeOpTrait<"TF::LayoutAgnostic">;
 | 
			
		||||
 | 
			
		||||
// Trait to indicate operations that cannot be duplicated as they might carry
 | 
			
		||||
// certain state around within their implementations.
 | 
			
		||||
def TF_CannotDuplicate : NativeOpTrait<"TF::CannotDuplicate">;
 | 
			
		||||
 | 
			
		||||
// Trait to indicate an operation cannot be constant folded.
 | 
			
		||||
def TF_NoConstantFold : NativeOpTrait<"TF::NoConstantFold">;
 | 
			
		||||
 | 
			
		||||
// Coefficient wise binary operation with implicit broadcasting support, for
 | 
			
		||||
// example tf.Sub operation.
 | 
			
		||||
def TF_CwiseBinary : NativeOpTrait<"TF::CwiseBinary">;
 | 
			
		||||
 | 
			
		||||
// Coefficient wise unary operation, for example tf.Sqrt operation.
 | 
			
		||||
def TF_CwiseUnary : NativeOpTrait<"TF::CwiseUnary">;
 | 
			
		||||
 | 
			
		||||
// Variant of broadcastable trait that considers TF's subtype behavior.
 | 
			
		||||
class TF_OpIsBroadcastableToRes<int opId, int resId> : And<[
 | 
			
		||||
    TCOpResIsShapedTypePred<opId, resId>,
 | 
			
		||||
    CPred<"mlir::TF::BroadcastCompatible("
 | 
			
		||||
              "$_op.getOperand(" # opId # ").getType(), "
 | 
			
		||||
              "$_op.getResult(" # resId # ").getType())">]>;
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class TF_AllTypesMatchPred<list<string> values> :
 | 
			
		||||
    CPred<"TF::AreCastCompatible(llvm::makeArrayRef({" #
 | 
			
		||||
      !interleave(values, ", ") # "}))">;
 | 
			
		||||
 | 
			
		||||
class TF_AllTypesMatch<list<string> names> :
 | 
			
		||||
    PredOpTrait<
 | 
			
		||||
        "all of {" # !interleave(names, ", ") #
 | 
			
		||||
          "} have dynamically equal types ",
 | 
			
		||||
        TF_AllTypesMatchPred<
 | 
			
		||||
            !foreach(n, names, !subst("$_self", "$" # n, "$_self.getType()"))>>;
 | 
			
		||||
 | 
			
		||||
//===----------------------------------------------------------------------===//
 | 
			
		||||
// Rank/Shape helpers.
 | 
			
		||||
//===----------------------------------------------------------------------===//
 | 
			
		||||
 | 
			
		||||
class TF_OperandIsUnrankedPred<int n> :
 | 
			
		||||
  CPred<"$_op.getOperand(" # n # ").getType().isa<UnrankedTensorType>()">;
 | 
			
		||||
 | 
			
		||||
class TF_ResultIsUnrankedPred<int n> :
 | 
			
		||||
  CPred<"$_op.getResult(" # n # ").getType().isa<UnrankedTensorType>()">;
 | 
			
		||||
 | 
			
		||||
// Returns true if the n-th operand has unknown rank or has rank m.
 | 
			
		||||
class TF_OperandHasRank<int n, int m> :
 | 
			
		||||
  PredOpTrait<"operand " # n # " is " # m # "-D",
 | 
			
		||||
    Or<[TF_OperandIsUnrankedPred<n>,
 | 
			
		||||
      CPred<"$_op.getOperand(" # n #
 | 
			
		||||
      ").getType().cast<ShapedType>().getRank() == " # m>]>>;
 | 
			
		||||
 | 
			
		||||
// Returns true if the n-th result has unknown rank or has rank m.
 | 
			
		||||
class TF_ResultHasRank<int n, int m> :
 | 
			
		||||
  PredOpTrait<"result " # n # " is " # m # "-D",
 | 
			
		||||
    Or<[TF_ResultIsUnrankedPred<n>,
 | 
			
		||||
      CPred<"$_op.getResult(" # n #
 | 
			
		||||
      ").getType().cast<ShapedType>().getRank() == " # m>]>>;
 | 
			
		||||
 | 
			
		||||
//===----------------------------------------------------------------------===//
 | 
			
		||||
// TensorFlow op side effects
 | 
			
		||||
//===----------------------------------------------------------------------===//
 | 
			
		||||
 | 
			
		||||
class TF_ResourceBase<string resourceKind> :
 | 
			
		||||
  Resource<!strconcat("::mlir::TF::ResourceEffects::", resourceKind)> {
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
def TF_VariableResource : TF_ResourceBase<"Variable">;
 | 
			
		||||
def TF_StackResource : TF_ResourceBase<"Stack">;
 | 
			
		||||
def TF_TensorArrayResource : TF_ResourceBase<"TensorArray">;
 | 
			
		||||
def TF_SummaryResource : TF_ResourceBase<"Summary">;
 | 
			
		||||
def TF_LookupTableResource : TF_ResourceBase<"LookupTable">;
 | 
			
		||||
def TF_DatasetSeedGeneratorResource : TF_ResourceBase<"DatasetSeedGenerator">;
 | 
			
		||||
def TF_DatasetMemoryCacheResource : TF_ResourceBase<"DatasetMemoryCache">;
 | 
			
		||||
def TF_DatasetIteratorResource : TF_ResourceBase<"DatasetIterator">;
 | 
			
		||||
def TF_TPUEmbeddingResource : TF_ResourceBase<"TPUEmbedding">;
 | 
			
		||||
 | 
			
		||||
def TF_VariableRead : MemRead<TF_VariableResource>;
 | 
			
		||||
def TF_StackRead : MemRead<TF_StackResource>;
 | 
			
		||||
def TF_TensorArrayRead : MemRead<TF_TensorArrayResource>;
 | 
			
		||||
def TF_LookupTableRead : MemRead<TF_LookupTableResource>;
 | 
			
		||||
def TF_DatasetSeedGeneratorRead : MemRead<TF_DatasetSeedGeneratorResource>;
 | 
			
		||||
def TF_DatasetMemoryCacheRead : MemRead<TF_DatasetMemoryCacheResource>;
 | 
			
		||||
def TF_DatasetIteratorRead : MemRead<TF_DatasetIteratorResource>;
 | 
			
		||||
 | 
			
		||||
def TF_VariableWrite : MemWrite<TF_VariableResource>;
 | 
			
		||||
def TF_StackWrite : MemWrite<TF_StackResource>;
 | 
			
		||||
def TF_TensorArrayWrite : MemWrite<TF_TensorArrayResource>;
 | 
			
		||||
def TF_SummaryWrite : MemWrite<TF_SummaryResource>;
 | 
			
		||||
def TF_LookupTableWrite : MemWrite<TF_LookupTableResource>;
 | 
			
		||||
def TF_DatasetSeedGeneratorWrite : MemWrite<TF_DatasetSeedGeneratorResource>;
 | 
			
		||||
def TF_DatasetMemoryCacheWrite : MemWrite<TF_DatasetMemoryCacheResource>;
 | 
			
		||||
def TF_DatasetIteratorWrite : MemWrite<TF_DatasetIteratorResource>;
 | 
			
		||||
 | 
			
		||||
def TF_VariableAlloc : MemAlloc<TF_VariableResource>;
 | 
			
		||||
def TF_StackAlloc : MemAlloc<TF_StackResource>;
 | 
			
		||||
def TF_TensorArrayAlloc : MemAlloc<TF_TensorArrayResource>;
 | 
			
		||||
def TF_SummaryAlloc : MemAlloc<TF_SummaryResource>;
 | 
			
		||||
def TF_LookupTableAlloc : MemAlloc<TF_LookupTableResource>;
 | 
			
		||||
def TF_DatasetSeedGeneratorAlloc : MemAlloc<TF_DatasetSeedGeneratorResource>;
 | 
			
		||||
def TF_DatasetMemoryCacheAlloc : MemAlloc<TF_DatasetMemoryCacheResource>;
 | 
			
		||||
def TF_DatasetIteratorAlloc : MemAlloc<TF_DatasetIteratorResource>;
 | 
			
		||||
 | 
			
		||||
def TF_StackFree : MemFree<TF_StackResource>;
 | 
			
		||||
def TF_TensorArrayFree : MemFree<TF_TensorArrayResource>;
 | 
			
		||||
def TF_SummaryFree : MemFree<TF_SummaryResource>;
 | 
			
		||||
def TF_DatasetSeedGeneratorFree : MemFree<TF_DatasetSeedGeneratorResource>;
 | 
			
		||||
def TF_DatasetMemoryCacheFree : MemFree<TF_DatasetMemoryCacheResource>;
 | 
			
		||||
def TF_DatasetIteratorFree : MemFree<TF_DatasetIteratorResource>;
 | 
			
		||||
 | 
			
		||||
def TF_TPUEmbeddingSideEffect : MemoryEffects<[MemWrite<TF_TPUEmbeddingResource>]>;
 | 
			
		||||
 | 
			
		||||
//===----------------------------------------------------------------------===//
 | 
			
		||||
// TensorFlow op definitions
 | 
			
		||||
//===----------------------------------------------------------------------===//
 | 
			
		||||
 | 
			
		||||
class TF_Op<string mnemonic, list<OpTrait> traits = []> :
 | 
			
		||||
    Op<TF_Dialect, mnemonic, traits>;
 | 
			
		||||
 | 
			
		||||
//===----------------------------------------------------------------------===//
 | 
			
		||||
// TensorFlow attribute definitions
 | 
			
		||||
//===----------------------------------------------------------------------===//
 | 
			
		||||
 | 
			
		||||
class TF_TensorFlowAttr <string name, string description> :
 | 
			
		||||
    Attr<CPred<"$_self.isa<mlir::TF::" # name # "Attr>()">,
 | 
			
		||||
         "TensorFlow " # description # " attribute">;
 | 
			
		||||
 | 
			
		||||
def TF_ShapeAttr : TF_TensorFlowAttr<"Shape", "shape"> {
 | 
			
		||||
  let returnType = "llvm::Optional<llvm::ArrayRef<int64_t>>";
 | 
			
		||||
  let convertFromStorage = "$_self.cast<mlir::TF::ShapeAttr>().getValue()";
 | 
			
		||||
 | 
			
		||||
  // Create a ranked shape attr by default.
 | 
			
		||||
  let constBuilderCall = "mlir::TF::ShapeAttr::get($_builder.getContext(), $0)";
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
def TF_ShapeAttrArray :
 | 
			
		||||
    TypedArrayAttrBase<TF_ShapeAttr, "tensorflow shape attribute array">;
 | 
			
		||||
 | 
			
		||||
//===----------------------------------------------------------------------===//
 | 
			
		||||
// TensorFlow type definitions
 | 
			
		||||
//===----------------------------------------------------------------------===//
 | 
			
		||||
 | 
			
		||||
// Any tensor element type defined in the TensorFlow dialect
 | 
			
		||||
def TF_TFDialectType :
 | 
			
		||||
    Type<CPred<"$_self.isa<mlir::TF::TensorFlowType>()">, "TensorFlow type">;
 | 
			
		||||
 | 
			
		||||
// Class for any TensorFlow dialect specific type
 | 
			
		||||
class TF_TensorFlowType <string name, string description> :
 | 
			
		||||
    Type<CPred<"$_self.isa<mlir::TF::" # name # "Type>()">,
 | 
			
		||||
         "TensorFlow " # description # " type">,
 | 
			
		||||
    BuildableType<"getType<mlir::TF::" # name # "Type>()">;
 | 
			
		||||
 | 
			
		||||
//===----------------------------------------------------------------------===//
 | 
			
		||||
// Reference types
 | 
			
		||||
 | 
			
		||||
// Float reference types
 | 
			
		||||
def TF_Float16Ref : TF_TensorFlowType<"HalfRef", "f16ref">;
 | 
			
		||||
def TF_Float32Ref : TF_TensorFlowType<"FloatRef", "f32ref">;
 | 
			
		||||
def TF_Float64Ref : TF_TensorFlowType<"DoubleRef", "f64ref">;
 | 
			
		||||
def TF_Bfloat16Ref : TF_TensorFlowType<"Bfloat16Ref", "bf16ref">;
 | 
			
		||||
 | 
			
		||||
// Complex reference types
 | 
			
		||||
def TF_Complex64Ref : TF_TensorFlowType<"Complex64Ref", "complex64ref">;
 | 
			
		||||
def TF_Complex128Ref : TF_TensorFlowType<"Complex128Ref", "complex128ref">;
 | 
			
		||||
 | 
			
		||||
// Integer reference types
 | 
			
		||||
def TF_Int8Ref : TF_TensorFlowType<"Int8Ref", "i8ref">;
 | 
			
		||||
def TF_Int16Ref : TF_TensorFlowType<"Int16Ref", "i16ref">;
 | 
			
		||||
def TF_Int32Ref : TF_TensorFlowType<"Int32Ref", "i32ref">;
 | 
			
		||||
def TF_Int64Ref : TF_TensorFlowType<"Int64Ref", "i64ref">;
 | 
			
		||||
 | 
			
		||||
def TF_Uint8Ref : TF_TensorFlowType<"Uint8Ref", "ui8ref">;
 | 
			
		||||
def TF_Uint16Ref : TF_TensorFlowType<"Uint16Ref", "ui16ref">;
 | 
			
		||||
def TF_Uint32Ref : TF_TensorFlowType<"Uint32Ref", "ui32ref">;
 | 
			
		||||
def TF_Uint64Ref : TF_TensorFlowType<"Uint64Ref", "ui64ref">;
 | 
			
		||||
 | 
			
		||||
// Quantized reference types
 | 
			
		||||
def TF_Qint8Ref : TF_TensorFlowType<"Qint8Ref", "qint8ref">;
 | 
			
		||||
def TF_Qint16Ref : TF_TensorFlowType<"Qint16Ref", "qint16ref">;
 | 
			
		||||
def TF_Qint32Ref : TF_TensorFlowType<"Qint32Ref", "qint32ref">;
 | 
			
		||||
def TF_Quint8Ref : TF_TensorFlowType<"Quint8Ref", "quint8ref">;
 | 
			
		||||
def TF_Quint16Ref : TF_TensorFlowType<"Quint16Ref", "quint16ref">;
 | 
			
		||||
 | 
			
		||||
// Other reference types
 | 
			
		||||
def TF_BoolRef : TF_TensorFlowType<"BoolRef", "boolref">;
 | 
			
		||||
def TF_ResourceRef : TF_TensorFlowType<"ResourceRef", "resourceref">;
 | 
			
		||||
def TF_StrRef : TF_TensorFlowType<"StringRef", "stringref">;
 | 
			
		||||
def TF_VariantRef : TF_TensorFlowType<"VariantRef", "variantref">;
 | 
			
		||||
 | 
			
		||||
//===----------------------------------------------------------------------===//
 | 
			
		||||
// Integer types (including corresponding reference types)
 | 
			
		||||
 | 
			
		||||
def TF_Bool : AnyTypeOf<[I<1>, TF_BoolRef], "bool">;
 | 
			
		||||
 | 
			
		||||
def TF_Int8 : AnyTypeOf<[I8, TF_Int8Ref], "8-bit integer">;
 | 
			
		||||
def TF_Int16 : AnyTypeOf<[I16, TF_Int16Ref], "16-bit integer">;
 | 
			
		||||
def TF_Int32 : AnyTypeOf<[I32, TF_Int32Ref], "32-bit integer">;
 | 
			
		||||
def TF_Int64 : AnyTypeOf<[I64, TF_Int64Ref], "64-bit integer">;
 | 
			
		||||
def TF_I32OrI64 : AnyTypeOf<[I32, I64, TF_Int32Ref, TF_Int64Ref],
 | 
			
		||||
                           "32/64-bit signed integer">;
 | 
			
		||||
 | 
			
		||||
def TF_Uint8 : AnyTypeOf<[UI<8>, TF_Uint8Ref], "8-bit unsigned integer">;
 | 
			
		||||
def TF_Uint16 : AnyTypeOf<[UI<16>, TF_Uint16Ref], "16-bit unsigned integer">;
 | 
			
		||||
def TF_Uint32 : AnyTypeOf<[UI<32>, TF_Uint32Ref], "32-bit unsigned integer">;
 | 
			
		||||
def TF_Uint64 : AnyTypeOf<[UI<64>, TF_Uint64Ref], "64-bit unsigned integer">;
 | 
			
		||||
 | 
			
		||||
// Any unsigned integer type
 | 
			
		||||
def TF_UInt : AnyTypeOf<[TF_Uint8, TF_Uint16, TF_Uint32, TF_Uint64],
 | 
			
		||||
                        "unsigned integer">;
 | 
			
		||||
 | 
			
		||||
// Any signed integer type
 | 
			
		||||
def TF_SInt : AnyTypeOf<[TF_Int8, TF_Int16, TF_Int32, TF_Int64],
 | 
			
		||||
                        "signed integer">;
 | 
			
		||||
 | 
			
		||||
// Any integer type
 | 
			
		||||
def TF_Int : AnyTypeOf<[TF_SInt, TF_UInt], "integer">;
 | 
			
		||||
 | 
			
		||||
// Tensor types
 | 
			
		||||
def TF_BoolTensor : TensorOf<[TF_Bool]>;
 | 
			
		||||
 | 
			
		||||
def TF_IntTensor : TensorOf<[TF_Int]>;
 | 
			
		||||
def TF_Int8Tensor : TensorOf<[TF_Int8]>;
 | 
			
		||||
def TF_Int16Tensor : TensorOf<[TF_Int16]>;
 | 
			
		||||
def TF_Int32Tensor : TensorOf<[TF_Int32]>;
 | 
			
		||||
def TF_Int64Tensor : TensorOf<[TF_Int64]>;
 | 
			
		||||
def TF_I32OrI64Tensor : TensorOf<[TF_I32OrI64]>;
 | 
			
		||||
 | 
			
		||||
def TF_Uint8Tensor : TensorOf<[TF_Uint8]>;
 | 
			
		||||
def TF_Uint16Tensor : TensorOf<[TF_Uint16]>;
 | 
			
		||||
def TF_Uint32Tensor : TensorOf<[TF_Uint32]>;
 | 
			
		||||
def TF_Uint64Tensor : TensorOf<[TF_Uint64]>;
 | 
			
		||||
 | 
			
		||||
//===----------------------------------------------------------------------===//
 | 
			
		||||
// Quantized types (including corresponding reference types)
 | 
			
		||||
 | 
			
		||||
def TF_Qint8   : AnyTypeOf<
 | 
			
		||||
  [TF_TensorFlowType<"Qint8", "qint8">, TF_Qint8Ref],
 | 
			
		||||
  "8-bit quantized integer">;
 | 
			
		||||
def TF_Qint16  : AnyTypeOf<
 | 
			
		||||
  [TF_TensorFlowType<"Qint16", "qint16">, TF_Qint16Ref],
 | 
			
		||||
  "16-bit quantized integer">;
 | 
			
		||||
def TF_Qint32  : AnyTypeOf<
 | 
			
		||||
  [TF_TensorFlowType<"Qint32", "qint32">, TF_Qint32Ref],
 | 
			
		||||
  "32-bit quantized integer">;
 | 
			
		||||
def TF_Quint8  : AnyTypeOf<
 | 
			
		||||
  [TF_TensorFlowType<"Quint8", "quint8">, TF_Quint8Ref],
 | 
			
		||||
  "8-bit quantized unsigned integer">;
 | 
			
		||||
def TF_Quint16 : AnyTypeOf<
 | 
			
		||||
  [TF_TensorFlowType<"Quint16", "quint16">, TF_Quint16Ref],
 | 
			
		||||
  "16-bit quantized unsigned integer">;
 | 
			
		||||
 | 
			
		||||
// Any quantized type
 | 
			
		||||
def TF_Quantized : AnyTypeOf<
 | 
			
		||||
  [TF_Qint8, TF_Qint16, TF_Qint32, TF_Quint8, TF_Quint16], "quantized">;
 | 
			
		||||
 | 
			
		||||
//===----------------------------------------------------------------------===//
 | 
			
		||||
// Floating-point types (including corresponding reference types)
 | 
			
		||||
 | 
			
		||||
def TF_Float16 : AnyTypeOf<[F16, TF_Float16Ref], "16-bit float">;
 | 
			
		||||
def TF_Float32 : AnyTypeOf<[F32, TF_Float32Ref], "32-bit float">;
 | 
			
		||||
def TF_Float64 : AnyTypeOf<[F64, TF_Float64Ref], "64-bit float">;
 | 
			
		||||
def TF_Bfloat16 : AnyTypeOf<[BF16, TF_Bfloat16Ref], "bfloat16">;
 | 
			
		||||
 | 
			
		||||
def TF_F32OrF64 : AnyTypeOf<[TF_Float32, TF_Float64], "32/64-bit float">;
 | 
			
		||||
 | 
			
		||||
def TF_Float : AnyTypeOf<
 | 
			
		||||
  [TF_Float16, TF_Float32, TF_Float64, TF_Bfloat16],
 | 
			
		||||
  "floating-point">;
 | 
			
		||||
 | 
			
		||||
// Tensor types
 | 
			
		||||
def TF_FloatTensor : TensorOf<[TF_Float]>;
 | 
			
		||||
def TF_F32OrF64Tensor : TensorOf<[TF_F32OrF64]>;
 | 
			
		||||
def TF_Float16Tensor : TensorOf<[TF_Float16]>;
 | 
			
		||||
def TF_Float32Tensor : TensorOf<[TF_Float32]>;
 | 
			
		||||
def TF_Float64Tensor : TensorOf<[TF_Float64]>;
 | 
			
		||||
def TF_Bfloat16Tensor : TensorOf<[TF_Bfloat16]>;
 | 
			
		||||
 | 
			
		||||
//===----------------------------------------------------------------------===//
 | 
			
		||||
// Complex types (including corresponding reference types)
 | 
			
		||||
 | 
			
		||||
// TODO(suderman): Remove TF_Complex64 and use a standard ops declaration, along
 | 
			
		||||
// with the associated cleanup.
 | 
			
		||||
def TF_Complex64 : AnyTypeOf<[Complex<F<32>>, TF_Complex64Ref],
 | 
			
		||||
  "64-bit complex">;
 | 
			
		||||
def TF_Complex128 : AnyTypeOf<[Complex<F<64>>, TF_Complex128Ref],
 | 
			
		||||
  "128-bit complex">;
 | 
			
		||||
def TF_Complex : AnyTypeOf<[TF_Complex64, TF_Complex128], "complex">;
 | 
			
		||||
 | 
			
		||||
// Tensor types
 | 
			
		||||
def TF_ComplexTensor : TensorOf<[TF_Complex]>;
 | 
			
		||||
def TF_Complex64Tensor : TensorOf<[TF_Complex64]>;
 | 
			
		||||
def TF_Complex128Tensor : TensorOf<[TF_Complex128]>;
 | 
			
		||||
 | 
			
		||||
//===----------------------------------------------------------------------===//
 | 
			
		||||
// String/variant/resource types (including corresponding reference types)
 | 
			
		||||
 | 
			
		||||
def TF_Str : AnyTypeOf<
 | 
			
		||||
  [TF_TensorFlowType<"String", "str">, TF_StrRef], "string">;
 | 
			
		||||
def TF_StrTensor : TensorOf<[TF_Str]>;
 | 
			
		||||
 | 
			
		||||
def TF_Variant : AnyTypeOf<
 | 
			
		||||
  [TF_TensorFlowType<"Variant", "var">, TF_VariantRef], "variant">;
 | 
			
		||||
def TF_VariantTensor : TensorOf<[TF_Variant]>;
 | 
			
		||||
 | 
			
		||||
def TF_Resource : AnyTypeOf<
 | 
			
		||||
  [TF_TensorFlowType<"Resource", "res">, TF_ResourceRef], "resource">;
 | 
			
		||||
def TF_ResourceTensor : TensorOf<[TF_Resource]>;
 | 
			
		||||
 | 
			
		||||
//===----------------------------------------------------------------------===//
 | 
			
		||||
// Multi-category type constraints
 | 
			
		||||
 | 
			
		||||
def TF_IntOrF32OrF64Tensor: TensorOf<[TF_Int, TF_F32OrF64]>;
 | 
			
		||||
def TF_FpOrI32OrI64Tensor : TensorOf<[TF_Float, TF_I32OrI64]>;
 | 
			
		||||
def TF_IntOrFpTensor : TensorOf<[TF_Int, TF_Float]>;
 | 
			
		||||
def TF_SintOrFpTensor : TensorOf<[TF_SInt, TF_Float]>;
 | 
			
		||||
def TF_FpOrComplexTensor : TensorOf<[TF_Float, TF_Complex]>;
 | 
			
		||||
 | 
			
		||||
def TF_Number : AnyTypeOf<
 | 
			
		||||
  [TF_Int, TF_Float, TF_Quantized, TF_Complex], "number">;
 | 
			
		||||
def TF_NumberTensor : TensorOf<[TF_Number]>;
 | 
			
		||||
 | 
			
		||||
def TF_NumberNotQuantizedOrStr :
 | 
			
		||||
  AnyTypeOf<[TF_Float, TF_SInt, TF_Complex, TF_Uint8, TF_Str]>;
 | 
			
		||||
def TF_NumberNotQuantizedOrStrTensor : TensorOf<[TF_NumberNotQuantizedOrStr]>;
 | 
			
		||||
 | 
			
		||||
//===----------------------------------------------------------------------===//
 | 
			
		||||
// Tensor and tensor element types
 | 
			
		||||
 | 
			
		||||
// Any tensor element type allowed in TensorFlow ops
 | 
			
		||||
// (see https://www.tensorflow.org/api_docs/python/tf/dtypes/DType)
 | 
			
		||||
def TF_ElementType : Type<Or<[TF_Float.predicate,
 | 
			
		||||
                              TF_Complex.predicate,
 | 
			
		||||
                              TF_Int.predicate,
 | 
			
		||||
                              TF_Bool.predicate,
 | 
			
		||||
                              TF_TFDialectType.predicate]>,
 | 
			
		||||
                          "tf.dtype">;
 | 
			
		||||
 | 
			
		||||
// Any TensorFlow tensor type
 | 
			
		||||
def TF_Tensor : TensorOf<[TF_ElementType]>;
 | 
			
		||||
 | 
			
		||||
//===----------------------------------------------------------------------===//
 | 
			
		||||
// TensorFlow attribute definitions
 | 
			
		||||
//===----------------------------------------------------------------------===//
 | 
			
		||||
 | 
			
		||||
//===----------------------------------------------------------------------===//
 | 
			
		||||
// Tensorflow devices metadata
 | 
			
		||||
 | 
			
		||||
// Tensorflow GPU device metadata.
 | 
			
		||||
def TF_GpuDeviceMetadata : StructAttr<"GpuDeviceMetadata", TF_Dialect, [
 | 
			
		||||
    // GPU device compute capability: major:minor.
 | 
			
		||||
    StructFieldAttr<"cc_major", I32Attr>,
 | 
			
		||||
    StructFieldAttr<"cc_minor", I32Attr>
 | 
			
		||||
]>;
 | 
			
		||||
 | 
			
		||||
//===----------------------------------------------------------------------===//
 | 
			
		||||
// String attribute constraints
 | 
			
		||||
 | 
			
		||||
// A string attribute whose value are one of the values in `cases`.
 | 
			
		||||
class TF_AnyStrAttrOf<list<string> cases> : StringBasedAttr<
 | 
			
		||||
  CPred<!foldl(
 | 
			
		||||
      "$_self.cast<StringAttr>().getValue() == \"" # !head(cases) # "\"",
 | 
			
		||||
      !foreach(case, !tail(cases),
 | 
			
		||||
               "$_self.cast<StringAttr>().getValue() == \"" # case # "\""),
 | 
			
		||||
      prev, cur, prev # " || " # cur)>,
 | 
			
		||||
  "string attribute whose value is " #
 | 
			
		||||
    !foldl(/*init*/!head(cases), /*list*/!tail(cases),
 | 
			
		||||
           prev, cur, prev # ", or " # cur)>;
 | 
			
		||||
 | 
			
		||||
// TODO: Use EnumAttr to define the common attribute cases
 | 
			
		||||
 | 
			
		||||
def TF_ConvnetDataFormatAttr : StringBasedAttr<
 | 
			
		||||
    CPred<"$_self.cast<StringAttr>().getValue() == \"NHWC\" || " #
 | 
			
		||||
          "$_self.cast<StringAttr>().getValue() == \"NCHW\"">,
 | 
			
		||||
    "'NHWC' or 'NCHW' convnet data format">;
 | 
			
		||||
 | 
			
		||||
//===----------------------------------------------------------------------===//
 | 
			
		||||
// Type attributes
 | 
			
		||||
 | 
			
		||||
// A derived attribute that returns the size of `idx`-th ODS-declared variadic
 | 
			
		||||
// operand.
 | 
			
		||||
class TF_DerivedOperandSizeAttr<int idx> : DerivedAttr<
 | 
			
		||||
  "size_t",
 | 
			
		||||
  "auto range = getODSOperands(" # idx # ");\n"
 | 
			
		||||
  "return std::distance(range.begin(), range.end());",
 | 
			
		||||
  [{ $_builder.getI64IntegerAttr($_self) }]>;
 | 
			
		||||
 | 
			
		||||
// A derived attribute that returns the element type of `idx`-th ODS-declared
 | 
			
		||||
// operand. If the `idx`-th operand is a variadic operand, then this attribute
 | 
			
		||||
// just returns the element type of its first tensor, which is only meaningful
 | 
			
		||||
// when the variadic operand has at least one tensor and the tensors all have
 | 
			
		||||
// the same element type.
 | 
			
		||||
class TF_DerivedOperandTypeAttr<int idx> : DerivedTypeAttr<
 | 
			
		||||
  "return mlir::getElementTypeOrSelf(*getODSOperands(" # idx # ").begin());">;
 | 
			
		||||
 | 
			
		||||
// A derived attribute that returns the element types of the tensors in the
 | 
			
		||||
// actual value pack that corresponds to the `idx`-th ODS-declared variadic
 | 
			
		||||
// operand. This returns a list of element types so it is used for variadic
 | 
			
		||||
// operands that can have different element types.
 | 
			
		||||
class TF_DerivedOperandTypeListAttr<int idx> : DerivedAttr<
 | 
			
		||||
  "mlir::OperandElementTypeRange",
 | 
			
		||||
  "auto values = getODSOperands(" # idx # ");\n"
 | 
			
		||||
  "return {mlir::OperandElementTypeIterator(values.begin()), "
 | 
			
		||||
          "mlir::OperandElementTypeIterator(values.end())};",
 | 
			
		||||
  [{
 | 
			
		||||
    ArrayAttr::get($_ctx, 
 | 
			
		||||
    [&]() {
 | 
			
		||||
      llvm::SmallVector<Attribute, 4> ret;
 | 
			
		||||
      for (auto t : $_self)
 | 
			
		||||
        ret.push_back(TypeAttr::get(t));
 | 
			
		||||
      return ret;
 | 
			
		||||
    }())
 | 
			
		||||
  }]
 | 
			
		||||
>;
 | 
			
		||||
 | 
			
		||||
// A derived attribute that returns the shapes of the tensors in the actual
 | 
			
		||||
// value pack that corresponds to the `idx`-th ODS-declared variadic operand.
 | 
			
		||||
// This returns a list of shapes so it is used for variadic operands that
 | 
			
		||||
// can have different shapes.
 | 
			
		||||
class TF_DerivedOperandShapeListAttr<int idx> : DerivedAttr<
 | 
			
		||||
  "::mlir::TF::OperandShapeRange",
 | 
			
		||||
  "auto values = getODSOperands(" # idx # ");\n"
 | 
			
		||||
  "return {mlir::TF::OperandShapeIterator(values.begin()), "
 | 
			
		||||
          "mlir::TF::OperandShapeIterator(values.end())};",
 | 
			
		||||
  [{
 | 
			
		||||
    ArrayAttr::get($_ctx,
 | 
			
		||||
      [&](){
 | 
			
		||||
        llvm::SmallVector<Attribute, 4> ret;
 | 
			
		||||
        for (auto shape : $_self)
 | 
			
		||||
          ret.push_back(mlir::TF::ShapeAttr::get($_ctx, shape));
 | 
			
		||||
        return ret;
 | 
			
		||||
      }())
 | 
			
		||||
  }]
 | 
			
		||||
>;
 | 
			
		||||
 | 
			
		||||
// A derived attribute that returns the size of `idx`-th ODS-declared variadic
 | 
			
		||||
// result.
 | 
			
		||||
class TF_DerivedResultSizeAttr<int idx> : DerivedAttr<
 | 
			
		||||
  "size_t",
 | 
			
		||||
  "auto range = getODSResults(" # idx # ");\n"
 | 
			
		||||
  "return std::distance(range.begin(), range.end());",
 | 
			
		||||
  [{ $_builder.getI64IntegerAttr($_self) }]>;
 | 
			
		||||
 | 
			
		||||
// A derived attribute that returns the element type of `idx`-th ODS-declared
 | 
			
		||||
// result. If the `idx`-th result is a variadic result, then this attribute
 | 
			
		||||
// just returns the element type of its first tensor, which is only meaningful
 | 
			
		||||
// when the variadic result has at least one tensor and the tensors all have
 | 
			
		||||
// the same element type.
 | 
			
		||||
class TF_DerivedResultTypeAttr<int idx> : DerivedTypeAttr<
 | 
			
		||||
  "return mlir::getElementTypeOrSelf(*getODSResults(" # idx # ").begin());">;
 | 
			
		||||
 | 
			
		||||
// A derived attribute that returns the element types of the tensors in the
 | 
			
		||||
// actual value pack that corresponds to the `idx`-th ODS-declared variadic
 | 
			
		||||
// result. This returns a list of element types so it is used for variadic
 | 
			
		||||
// results that can have different element types.
 | 
			
		||||
class TF_DerivedResultTypeListAttr<int idx> : DerivedAttr<
 | 
			
		||||
  "mlir::ResultElementTypeRange",
 | 
			
		||||
  "auto values = getODSResults(" # idx # ");\n"
 | 
			
		||||
  "return {mlir::ResultElementTypeIterator(values.begin()), "
 | 
			
		||||
          "mlir::ResultElementTypeIterator(values.end())};",
 | 
			
		||||
  [{
 | 
			
		||||
    ArrayAttr::get($_ctx, 
 | 
			
		||||
    [&]() {
 | 
			
		||||
      llvm::SmallVector<Attribute, 4> ret;
 | 
			
		||||
      for (auto t : $_self)
 | 
			
		||||
        ret.push_back(TypeAttr::get(t));
 | 
			
		||||
      return ret;
 | 
			
		||||
    }())
 | 
			
		||||
  }]
 | 
			
		||||
>;
 | 
			
		||||
 | 
			
		||||
// A derived attribute that returns the shapes of the tensors in the actual
 | 
			
		||||
// value pack that corresponds to the `idx`-th ODS-declared variadic result.
 | 
			
		||||
// This returns a list of shapes so it is used for variadic results that
 | 
			
		||||
// can have different shapes.
 | 
			
		||||
class TF_DerivedResultShapeListAttr<int idx> : DerivedAttr<
 | 
			
		||||
  "mlir::TF::ResultShapeRange",
 | 
			
		||||
  "auto values = getODSResults(" # idx # ");\n"
 | 
			
		||||
  "return {mlir::TF::ResultShapeIterator(values.begin()), "
 | 
			
		||||
          "mlir::TF::ResultShapeIterator(values.end())};",
 | 
			
		||||
  [{
 | 
			
		||||
    ArrayAttr::get($_ctx,
 | 
			
		||||
      [&](){
 | 
			
		||||
        llvm::SmallVector<Attribute, 4> ret;
 | 
			
		||||
        for (auto shape : $_self)
 | 
			
		||||
          ret.push_back(mlir::TF::ShapeAttr::get($_ctx, shape));
 | 
			
		||||
        return ret;
 | 
			
		||||
      }())
 | 
			
		||||
  }]
 | 
			
		||||
>;
 | 
			
		||||
 | 
			
		||||
// A derived attribute that returns the shape of the first result type.
 | 
			
		||||
def TF_DerivedResultShapeAttr : DerivedAttr<"ShapedType",
 | 
			
		||||
  "return (*getOperation()->result_type_begin()).cast<ShapedType>();",
 | 
			
		||||
  [{ mlir::TF::ShapeAttr::get($_ctx, $_self) }]>;
 | 
			
		||||
 | 
			
		||||
def TF_IntTypeAttr : TypeAttrBase<"IntegerType", "integer type"> {
 | 
			
		||||
  let returnType = "Type";
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
//===----------------------------------------------------------------------===//
 | 
			
		||||
// TensorFlow common builders
 | 
			
		||||
//===----------------------------------------------------------------------===//
 | 
			
		||||
 | 
			
		||||
// Mixin class defining a builder for binary ops supporting broadcast
 | 
			
		||||
// behavior. The result type has the same element type as both operands.
 | 
			
		||||
class WithBroadcastableBinOpBuilder {
 | 
			
		||||
  list<OpBuilder> builders = [
 | 
			
		||||
    OpBuilder<(ins "Value":$x, "Value":$y),
 | 
			
		||||
    [{
 | 
			
		||||
  auto resultType =
 | 
			
		||||
      OpTrait::util::getBroadcastedType(x.getType(), y.getType());
 | 
			
		||||
  if (!resultType)
 | 
			
		||||
    mlir::emitError($_state.location, "non-broadcastable operands");
 | 
			
		||||
  return build($_builder, $_state, resultType, x, y);
 | 
			
		||||
}]>];
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Mixin class defining a builder for comparison ops supporting broadcast
 | 
			
		||||
// behavior. The result type has bool element type.
 | 
			
		||||
class WithBroadcastableCmpOpBuilder {
 | 
			
		||||
  list<OpBuilder> builders = [
 | 
			
		||||
    OpBuilder<(ins "Value":$x, "Value":$y),
 | 
			
		||||
    [{
 | 
			
		||||
  Type resultType;
 | 
			
		||||
  if (x.getType().isa<UnrankedTensorType>() ||
 | 
			
		||||
      y.getType().isa<UnrankedTensorType>()) {
 | 
			
		||||
    resultType = UnrankedTensorType::get($_builder.getI1Type());
 | 
			
		||||
  } else {
 | 
			
		||||
    SmallVector<int64_t, 4> resultShape;
 | 
			
		||||
    if (!OpTrait::util::getBroadcastedShape(
 | 
			
		||||
            x.getType().cast<ShapedType>().getShape(),
 | 
			
		||||
            y.getType().cast<ShapedType>().getShape(), resultShape)) {
 | 
			
		||||
      mlir::emitError($_state.location,
 | 
			
		||||
                      "operands have no broadcastable shapes");
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    resultType = RankedTensorType::get(resultShape, $_builder.getI1Type());
 | 
			
		||||
  }
 | 
			
		||||
  return build($_builder, $_state, resultType, x, y);
 | 
			
		||||
}]>];
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
#endif // TF_OP_BASE
 | 
			
		||||
							
								
								
									
										2037
									
								
								3rdparty/ncnn/tools/mlir/tf_ops.td
									
									
									
									
										vendored
									
									
										Normal file
									
								
							
							
						
						
									
										2037
									
								
								3rdparty/ncnn/tools/mlir/tf_ops.td
									
									
									
									
										vendored
									
									
										Normal file
									
								
							
										
											
												File diff suppressed because it is too large
												Load Diff
											
										
									
								
							
							
								
								
									
										106
									
								
								3rdparty/ncnn/tools/mlir/tf_side_effects.h
									
									
									
									
										vendored
									
									
										Normal file
									
								
							
							
						
						
									
										106
									
								
								3rdparty/ncnn/tools/mlir/tf_side_effects.h
									
									
									
									
										vendored
									
									
										Normal file
									
								
							@ -0,0 +1,106 @@
 | 
			
		||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
 | 
			
		||||
 | 
			
		||||
Licensed under the Apache License, Version 2.0 (the "License");
 | 
			
		||||
you may not use this file except in compliance with the License.
 | 
			
		||||
You may obtain a copy of the License at
 | 
			
		||||
 | 
			
		||||
    http://www.apache.org/licenses/LICENSE-2.0
 | 
			
		||||
 | 
			
		||||
Unless required by applicable law or agreed to in writing, software
 | 
			
		||||
distributed under the License is distributed on an "AS IS" BASIS,
 | 
			
		||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | 
			
		||||
See the License for the specific language governing permissions and
 | 
			
		||||
limitations under the License.
 | 
			
		||||
==============================================================================*/
 | 
			
		||||
 | 
			
		||||
// This is the side effect definition file for TensorFlow.
 | 
			
		||||
#ifndef TENSORFLOW_COMPILER_MLIR_TENSORFLOW_IR_TF_SIDE_EFFECTS_H_
 | 
			
		||||
#define TENSORFLOW_COMPILER_MLIR_TENSORFLOW_IR_TF_SIDE_EFFECTS_H_
 | 
			
		||||
 | 
			
		||||
#include "mlir/Interfaces/SideEffectInterfaces.h" // from @llvm-project
 | 
			
		||||
 | 
			
		||||
namespace mlir {
 | 
			
		||||
namespace TF {
 | 
			
		||||
namespace ResourceEffects {
 | 
			
		||||
 | 
			
		||||
struct Variable : ::mlir::SideEffects::Resource::Base<Variable>
 | 
			
		||||
{
 | 
			
		||||
    StringRef getName() final
 | 
			
		||||
    {
 | 
			
		||||
        return "Variable";
 | 
			
		||||
    }
 | 
			
		||||
};
 | 
			
		||||
 | 
			
		||||
struct Stack : ::mlir::SideEffects::Resource::Base<Stack>
 | 
			
		||||
{
 | 
			
		||||
    StringRef getName() final
 | 
			
		||||
    {
 | 
			
		||||
        return "Stack";
 | 
			
		||||
    }
 | 
			
		||||
};
 | 
			
		||||
 | 
			
		||||
struct TensorArray : ::mlir::SideEffects::Resource::Base<TensorArray>
 | 
			
		||||
{
 | 
			
		||||
    StringRef getName() final
 | 
			
		||||
    {
 | 
			
		||||
        return "TensorArray";
 | 
			
		||||
    }
 | 
			
		||||
};
 | 
			
		||||
 | 
			
		||||
struct Summary : ::mlir::SideEffects::Resource::Base<Summary>
 | 
			
		||||
{
 | 
			
		||||
    StringRef getName() final
 | 
			
		||||
    {
 | 
			
		||||
        return "Summary";
 | 
			
		||||
    }
 | 
			
		||||
};
 | 
			
		||||
 | 
			
		||||
struct LookupTable : ::mlir::SideEffects::Resource::Base<LookupTable>
 | 
			
		||||
{
 | 
			
		||||
    StringRef getName() final
 | 
			
		||||
    {
 | 
			
		||||
        return "LookupTable";
 | 
			
		||||
    }
 | 
			
		||||
};
 | 
			
		||||
 | 
			
		||||
struct DatasetSeedGenerator
 | 
			
		||||
    : ::mlir::SideEffects::Resource::Base<DatasetSeedGenerator>
 | 
			
		||||
{
 | 
			
		||||
    StringRef getName() final
 | 
			
		||||
    {
 | 
			
		||||
        return "DatasetSeedGenerator";
 | 
			
		||||
    }
 | 
			
		||||
};
 | 
			
		||||
 | 
			
		||||
struct DatasetMemoryCache
 | 
			
		||||
    : ::mlir::SideEffects::Resource::Base<DatasetMemoryCache>
 | 
			
		||||
{
 | 
			
		||||
    StringRef getName() final
 | 
			
		||||
    {
 | 
			
		||||
        return "DatasetMemoryCache";
 | 
			
		||||
    }
 | 
			
		||||
};
 | 
			
		||||
 | 
			
		||||
struct DatasetIterator : ::mlir::SideEffects::Resource::Base<DatasetIterator>
 | 
			
		||||
{
 | 
			
		||||
    StringRef getName() final
 | 
			
		||||
    {
 | 
			
		||||
        return "DatasetIterator";
 | 
			
		||||
    }
 | 
			
		||||
};
 | 
			
		||||
 | 
			
		||||
// Special resource type to track TPU Embedding specific ops, which must execute
 | 
			
		||||
// but do not have side effects with one another or with resource variable ops.
 | 
			
		||||
struct TPUEmbedding : ::mlir::SideEffects::Resource::Base<TPUEmbedding>
 | 
			
		||||
{
 | 
			
		||||
    StringRef getName() final
 | 
			
		||||
    {
 | 
			
		||||
        return "TPUEmbedding";
 | 
			
		||||
    }
 | 
			
		||||
};
 | 
			
		||||
 | 
			
		||||
} // namespace ResourceEffects
 | 
			
		||||
} // namespace TF
 | 
			
		||||
} // namespace mlir
 | 
			
		||||
 | 
			
		||||
#endif // TENSORFLOW_COMPILER_MLIR_TENSORFLOW_IR_TF_SIDE_EFFECTS_H_
 | 
			
		||||
							
								
								
									
										189
									
								
								3rdparty/ncnn/tools/mlir/tf_traits.h
									
									
									
									
										vendored
									
									
										Normal file
									
								
							
							
						
						
									
										189
									
								
								3rdparty/ncnn/tools/mlir/tf_traits.h
									
									
									
									
										vendored
									
									
										Normal file
									
								
							@ -0,0 +1,189 @@
 | 
			
		||||
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
 | 
			
		||||
 | 
			
		||||
Licensed under the Apache License, Version 2.0 (the "License");
 | 
			
		||||
you may not use this file except in compliance with the License.
 | 
			
		||||
You may obtain a copy of the License at
 | 
			
		||||
 | 
			
		||||
    http://www.apache.org/licenses/LICENSE-2.0
 | 
			
		||||
 | 
			
		||||
Unless required by applicable law or agreed to in writing, software
 | 
			
		||||
distributed under the License is distributed on an "AS IS" BASIS,
 | 
			
		||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | 
			
		||||
See the License for the specific language governing permissions and
 | 
			
		||||
limitations under the License.
 | 
			
		||||
==============================================================================*/
 | 
			
		||||
 | 
			
		||||
// This file defines the op traits used in the MLIR TensorFlow dialect.
 | 
			
		||||
 | 
			
		||||
#ifndef TENSORFLOW_COMPILER_MLIR_TENSORFLOW_IR_TF_TRAITS_H_
 | 
			
		||||
#define TENSORFLOW_COMPILER_MLIR_TENSORFLOW_IR_TF_TRAITS_H_
 | 
			
		||||
 | 
			
		||||
#include "mlir/IR/BuiltinTypes.h"                 // from @llvm-project
 | 
			
		||||
#include "mlir/IR/OpDefinition.h"                 // from @llvm-project
 | 
			
		||||
#include "mlir/IR/TypeUtilities.h"                // from @llvm-project
 | 
			
		||||
#include "mlir/Interfaces/SideEffectInterfaces.h" // from @llvm-project
 | 
			
		||||
#include "mlir/Support/LogicalResult.h"           // from @llvm-project
 | 
			
		||||
#include "tf_types.h"
 | 
			
		||||
 | 
			
		||||
namespace mlir {
 | 
			
		||||
namespace OpTrait {
 | 
			
		||||
namespace TF {
 | 
			
		||||
 | 
			
		||||
// Verifies if 'ref_type' is a REF type corresponding to 'type'.
 | 
			
		||||
static inline LogicalResult VerifyRefTypeMatch(mlir::Type type,
 | 
			
		||||
        mlir::Type maybe_ref_type)
 | 
			
		||||
{
 | 
			
		||||
    if (auto ref_type = maybe_ref_type.dyn_cast<mlir::TF::TensorFlowRefType>())
 | 
			
		||||
        return success(ref_type.RemoveRef().getTypeID() == type.getTypeID());
 | 
			
		||||
    return failure();
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// This class provides verification for ops that are known to have the same
 | 
			
		||||
// result types and all operands are either of the same type as result or a REF
 | 
			
		||||
// type corresponding to the result type.
 | 
			
		||||
// TODO(jpienaar): Update the name and the description.
 | 
			
		||||
template<typename ConcreteType>
 | 
			
		||||
class OperandsSameAsResultsTypeOrRef
 | 
			
		||||
    : public TraitBase<ConcreteType, OperandsSameAsResultsTypeOrRef>
 | 
			
		||||
{
 | 
			
		||||
public:
 | 
			
		||||
    static LogicalResult verifyTrait(Operation* op)
 | 
			
		||||
    {
 | 
			
		||||
        LogicalResult shapeMatch = impl::verifySameOperandsAndResultShape(op);
 | 
			
		||||
        if (failed(shapeMatch)) return shapeMatch;
 | 
			
		||||
        Type type = op->getResult(0).getType();
 | 
			
		||||
        // Verify that the first result type is same as the rest of the results.
 | 
			
		||||
        // We skip the comparison against itself.
 | 
			
		||||
        for (auto result_type : llvm::drop_begin(op->getResultTypes(), 1))
 | 
			
		||||
        {
 | 
			
		||||
            if (!mlir::TF::HasCompatibleElementTypes(type, result_type))
 | 
			
		||||
                return op->emitOpError()
 | 
			
		||||
                       << "requires all return types to have compatible element types";
 | 
			
		||||
        }
 | 
			
		||||
        for (auto operand_type : op->getOperandTypes())
 | 
			
		||||
        {
 | 
			
		||||
            if (!mlir::TF::HasCompatibleElementTypes(
 | 
			
		||||
                        operand_type, type, /*may_ignore_ref_type_lhs=*/true))
 | 
			
		||||
                return op->emitError() << "requires all operands and results to have "
 | 
			
		||||
                       "compatible element types";
 | 
			
		||||
        }
 | 
			
		||||
        return success();
 | 
			
		||||
    }
 | 
			
		||||
};
 | 
			
		||||
 | 
			
		||||
namespace detail {
 | 
			
		||||
inline LogicalResult verifySameOperandsAndResultElementTypeResolveRef(
 | 
			
		||||
    Operation* op)
 | 
			
		||||
{
 | 
			
		||||
    Type element_type;
 | 
			
		||||
    if (op->getNumResults() > 0)
 | 
			
		||||
    {
 | 
			
		||||
        element_type = mlir::TF::GetElementTypeOrSelfResolveRef(op->getResult(0).getType());
 | 
			
		||||
    }
 | 
			
		||||
    else if (op->getNumOperands() > 0)
 | 
			
		||||
    {
 | 
			
		||||
        element_type = mlir::TF::GetElementTypeOrSelfResolveRef(op->getOperand(0).getType());
 | 
			
		||||
    }
 | 
			
		||||
    else
 | 
			
		||||
    {
 | 
			
		||||
        // Nothing to check.
 | 
			
		||||
        return success();
 | 
			
		||||
    }
 | 
			
		||||
    // Verify that all result element types are compatible to `element_type`.
 | 
			
		||||
    for (const auto& result_type : op->getResultTypes())
 | 
			
		||||
    {
 | 
			
		||||
        if (mlir::TF::GetElementTypeOrSelfResolveRef(result_type) != element_type)
 | 
			
		||||
        {
 | 
			
		||||
            return op->emitOpError(
 | 
			
		||||
                       "requires compatible element types for all operands and results");
 | 
			
		||||
        }
 | 
			
		||||
    }
 | 
			
		||||
    // Verify that all operand element types are compatible to `element_type`.
 | 
			
		||||
    for (const auto& operand_type : op->getOperandTypes())
 | 
			
		||||
    {
 | 
			
		||||
        if (mlir::TF::GetElementTypeOrSelfResolveRef(operand_type) != element_type)
 | 
			
		||||
        {
 | 
			
		||||
            return op->emitOpError(
 | 
			
		||||
                       "requires compatible element types for all operands and results");
 | 
			
		||||
        }
 | 
			
		||||
    }
 | 
			
		||||
    return success();
 | 
			
		||||
}
 | 
			
		||||
} // namespace detail
 | 
			
		||||
 | 
			
		||||
// Verifies that op has the same operand and result element types (or type
 | 
			
		||||
// itself, if scalar) after resolving reference types (i.e., after converting
 | 
			
		||||
// reference types to their corresponding TensorFlow or standard types).
 | 
			
		||||
template<typename ConcreteType>
 | 
			
		||||
class SameOperandsAndResultElementTypeResolveRef
 | 
			
		||||
    : public TraitBase<ConcreteType,
 | 
			
		||||
      SameOperandsAndResultElementTypeResolveRef>
 | 
			
		||||
{
 | 
			
		||||
public:
 | 
			
		||||
    static LogicalResult verifyTrait(Operation* op)
 | 
			
		||||
    {
 | 
			
		||||
        return detail::verifySameOperandsAndResultElementTypeResolveRef(op);
 | 
			
		||||
    }
 | 
			
		||||
};
 | 
			
		||||
 | 
			
		||||
// Verifies that op has the same operand and result types after resolving
 | 
			
		||||
// reference types (i.e., after converting reference types to their
 | 
			
		||||
// corresponding TensorFlow or standard types).
 | 
			
		||||
template<typename ConcreteType>
 | 
			
		||||
class SameOperandsAndResultTypeResolveRef
 | 
			
		||||
    : public TraitBase<ConcreteType, SameOperandsAndResultTypeResolveRef>
 | 
			
		||||
{
 | 
			
		||||
public:
 | 
			
		||||
    static LogicalResult verifyTrait(Operation* op)
 | 
			
		||||
    {
 | 
			
		||||
        if (failed(impl::verifySameOperandsAndResultShape(op))) return failure();
 | 
			
		||||
        return detail::verifySameOperandsAndResultElementTypeResolveRef(op);
 | 
			
		||||
    }
 | 
			
		||||
};
 | 
			
		||||
 | 
			
		||||
// Layout agnostic operations do not depend on the operands data layout (data
 | 
			
		||||
// format), as and example all element wise operations are layout agnostic.
 | 
			
		||||
template<typename ConcreteType>
 | 
			
		||||
class LayoutAgnostic : public TraitBase<ConcreteType, LayoutAgnostic>
 | 
			
		||||
{
 | 
			
		||||
};
 | 
			
		||||
 | 
			
		||||
// Trait to indicate operations that cannot be duplicated as they might carry
 | 
			
		||||
// certain state around within their implementations.
 | 
			
		||||
template<typename ConcreteType>
 | 
			
		||||
class CannotDuplicate : public TraitBase<ConcreteType, CannotDuplicate>
 | 
			
		||||
{
 | 
			
		||||
public:
 | 
			
		||||
    static LogicalResult verifyTrait(Operation* op)
 | 
			
		||||
    {
 | 
			
		||||
        if (MemoryEffectOpInterface::hasNoEffect(op))
 | 
			
		||||
            return op->emitError(
 | 
			
		||||
                       "operations with no side effects cannot have CannotDuplicate trait");
 | 
			
		||||
        return success();
 | 
			
		||||
    }
 | 
			
		||||
};
 | 
			
		||||
 | 
			
		||||
// Trait to indicate an operation cannot be constant folded.
 | 
			
		||||
template<typename ConcreteType>
 | 
			
		||||
class NoConstantFold : public TraitBase<ConcreteType, NoConstantFold>
 | 
			
		||||
{
 | 
			
		||||
};
 | 
			
		||||
 | 
			
		||||
// Coefficient-wise binary operation with implicit broadcasting support, for
 | 
			
		||||
// example tf.Sub operation.
 | 
			
		||||
template<typename ConcreteType>
 | 
			
		||||
class CwiseBinary : public TraitBase<ConcreteType, CwiseBinary>
 | 
			
		||||
{
 | 
			
		||||
};
 | 
			
		||||
 | 
			
		||||
// Coefficient-wise unary operation, for example tf.Sqrt operation.
 | 
			
		||||
template<typename ConcreteType>
 | 
			
		||||
class CwiseUnary : public TraitBase<ConcreteType, CwiseUnary>
 | 
			
		||||
{
 | 
			
		||||
};
 | 
			
		||||
 | 
			
		||||
} // namespace TF
 | 
			
		||||
} // namespace OpTrait
 | 
			
		||||
} // namespace mlir
 | 
			
		||||
 | 
			
		||||
#endif // TENSORFLOW_COMPILER_MLIR_TENSORFLOW_IR_TF_TRAITS_H_
 | 
			
		||||
							
								
								
									
										462
									
								
								3rdparty/ncnn/tools/mlir/tf_types.cc
									
									
									
									
										vendored
									
									
										Normal file
									
								
							
							
						
						
									
										462
									
								
								3rdparty/ncnn/tools/mlir/tf_types.cc
									
									
									
									
										vendored
									
									
										Normal file
									
								
							@ -0,0 +1,462 @@
 | 
			
		||||
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
 | 
			
		||||
 | 
			
		||||
Licensed under the Apache License, Version 2.0 (the "License");
 | 
			
		||||
you may not use this file except in compliance with the License.
 | 
			
		||||
You may obtain a copy of the License at
 | 
			
		||||
 | 
			
		||||
    http://www.apache.org/licenses/LICENSE-2.0
 | 
			
		||||
 | 
			
		||||
Unless required by applicable law or agreed to in writing, software
 | 
			
		||||
distributed under the License is distributed on an "AS IS" BASIS,
 | 
			
		||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | 
			
		||||
See the License for the specific language governing permissions and
 | 
			
		||||
limitations under the License.
 | 
			
		||||
==============================================================================*/
 | 
			
		||||
 | 
			
		||||
#include "tf_types.h"
 | 
			
		||||
 | 
			
		||||
#include "llvm/Support/ErrorHandling.h"
 | 
			
		||||
#include "mlir/Dialect/Traits.h"   // from @llvm-project
 | 
			
		||||
#include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
 | 
			
		||||
#include "mlir/IR/Dialect.h"       // from @llvm-project
 | 
			
		||||
#include "mlir/IR/TypeUtilities.h" // from @llvm-project
 | 
			
		||||
 | 
			
		||||
namespace {
 | 
			
		||||
// Returns the shape of the given value if it's ranked; returns llvm::None
 | 
			
		||||
// otherwise.
 | 
			
		||||
llvm::Optional<llvm::ArrayRef<int64_t> > GetShape(mlir::Value value)
 | 
			
		||||
{
 | 
			
		||||
    auto shaped_type = value.getType().cast<mlir::ShapedType>();
 | 
			
		||||
    if (shaped_type.hasRank()) return shaped_type.getShape();
 | 
			
		||||
    return llvm::None;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Merges cast compatible shapes and returns a more refined shape. The two
 | 
			
		||||
// shapes are cast compatible if they have the same rank and at each dimension,
 | 
			
		||||
// either both have same size or one of them is dynamic. Returns false if the
 | 
			
		||||
// given shapes are not cast compatible. The refined shape is same or more
 | 
			
		||||
// precise than the two input shapes.
 | 
			
		||||
bool GetCastCompatibleShape(llvm::ArrayRef<int64_t> a_shape,
 | 
			
		||||
                            llvm::ArrayRef<int64_t> b_shape,
 | 
			
		||||
                            llvm::SmallVectorImpl<int64_t>* refined_shape)
 | 
			
		||||
{
 | 
			
		||||
    if (a_shape.size() != b_shape.size()) return false;
 | 
			
		||||
    int64_t rank = a_shape.size();
 | 
			
		||||
    refined_shape->reserve(rank);
 | 
			
		||||
    for (auto dims : llvm::zip(a_shape, b_shape))
 | 
			
		||||
    {
 | 
			
		||||
        int64_t dim1 = std::get<0>(dims);
 | 
			
		||||
        int64_t dim2 = std::get<1>(dims);
 | 
			
		||||
 | 
			
		||||
        if (mlir::ShapedType::isDynamic(dim1))
 | 
			
		||||
        {
 | 
			
		||||
            refined_shape->push_back(dim2);
 | 
			
		||||
            continue;
 | 
			
		||||
        }
 | 
			
		||||
        if (mlir::ShapedType::isDynamic(dim2))
 | 
			
		||||
        {
 | 
			
		||||
            refined_shape->push_back(dim1);
 | 
			
		||||
            continue;
 | 
			
		||||
        }
 | 
			
		||||
        if (dim1 == dim2)
 | 
			
		||||
        {
 | 
			
		||||
            refined_shape->push_back(dim1);
 | 
			
		||||
            continue;
 | 
			
		||||
        }
 | 
			
		||||
        return false;
 | 
			
		||||
    }
 | 
			
		||||
    return true;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
} // namespace
 | 
			
		||||
 | 
			
		||||
namespace mlir {
 | 
			
		||||
namespace TF {
 | 
			
		||||
//===----------------------------------------------------------------------===//
 | 
			
		||||
// Utility iterators
 | 
			
		||||
//===----------------------------------------------------------------------===//
 | 
			
		||||
 | 
			
		||||
OperandShapeIterator::OperandShapeIterator(Operation::operand_iterator it)
 | 
			
		||||
    : llvm::mapped_iterator<Operation::operand_iterator,
 | 
			
		||||
      llvm::Optional<ArrayRef<int64_t> > (*)(Value)>(
 | 
			
		||||
          it, &GetShape)
 | 
			
		||||
{
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
ResultShapeIterator::ResultShapeIterator(Operation::result_iterator it)
 | 
			
		||||
    : llvm::mapped_iterator<Operation::result_iterator,
 | 
			
		||||
      llvm::Optional<ArrayRef<int64_t> > (*)(Value)>(
 | 
			
		||||
          it, &GetShape)
 | 
			
		||||
{
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
//===----------------------------------------------------------------------===//
 | 
			
		||||
// TF types helper functions
 | 
			
		||||
//===----------------------------------------------------------------------===//
 | 
			
		||||
 | 
			
		||||
bool TensorFlowType::classof(Type type)
 | 
			
		||||
{
 | 
			
		||||
    return type.getDialect().getNamespace() == "tf";
 | 
			
		||||
}
 | 
			
		||||
bool TensorFlowRefType::classof(Type type)
 | 
			
		||||
{
 | 
			
		||||
    return type.isa<
 | 
			
		||||
#define HANDLE_TF_TYPE(tftype, enumerant, name)
 | 
			
		||||
#define HANDLE_TF_REF_TYPE(tftype, enumerant, name)  tftype##Type,
 | 
			
		||||
#define HANDLE_LAST_TF_TYPE(tftype, enumerant, name) tftype##Type
 | 
			
		||||
// NOLINTNEXTLINE
 | 
			
		||||
#include "tf_types.def"
 | 
			
		||||
           >();
 | 
			
		||||
}
 | 
			
		||||
bool TensorFlowTypeWithSubtype::classof(Type type)
 | 
			
		||||
{
 | 
			
		||||
    return type.isa<ResourceType, VariantType>();
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
TensorFlowType TensorFlowRefType::get(Type type)
 | 
			
		||||
{
 | 
			
		||||
    MLIRContext* ctx = type.getContext();
 | 
			
		||||
    type = getElementTypeOrSelf(type);
 | 
			
		||||
    if (type.isF16())
 | 
			
		||||
    {
 | 
			
		||||
        return HalfRefType::get(ctx);
 | 
			
		||||
    }
 | 
			
		||||
    else if (type.isF32())
 | 
			
		||||
    {
 | 
			
		||||
        return FloatRefType::get(ctx);
 | 
			
		||||
    }
 | 
			
		||||
    else if (type.isF64())
 | 
			
		||||
    {
 | 
			
		||||
        return DoubleRefType::get(ctx);
 | 
			
		||||
    }
 | 
			
		||||
    else if (type.isBF16())
 | 
			
		||||
    {
 | 
			
		||||
        return Bfloat16RefType::get(ctx);
 | 
			
		||||
    }
 | 
			
		||||
    else if (auto complex_type = type.dyn_cast<ComplexType>())
 | 
			
		||||
    {
 | 
			
		||||
        Type etype = complex_type.getElementType();
 | 
			
		||||
        if (etype.isF32())
 | 
			
		||||
        {
 | 
			
		||||
            return Complex64RefType::get(ctx);
 | 
			
		||||
        }
 | 
			
		||||
        else if (etype.isF64())
 | 
			
		||||
        {
 | 
			
		||||
            return Complex128RefType::get(ctx);
 | 
			
		||||
        }
 | 
			
		||||
        llvm_unreachable("unexpected complex type");
 | 
			
		||||
    }
 | 
			
		||||
    else if (auto itype = type.dyn_cast<IntegerType>())
 | 
			
		||||
    {
 | 
			
		||||
        switch (itype.getWidth())
 | 
			
		||||
        {
 | 
			
		||||
        case 1:
 | 
			
		||||
            return BoolRefType::get(ctx);
 | 
			
		||||
        case 8:
 | 
			
		||||
            return itype.isUnsigned() ? TensorFlowType(Uint8RefType::get(ctx))
 | 
			
		||||
                   : Int8RefType::get(ctx);
 | 
			
		||||
        case 16:
 | 
			
		||||
            return itype.isUnsigned() ? TensorFlowType(Uint16RefType::get(ctx))
 | 
			
		||||
                   : Int16RefType::get(ctx);
 | 
			
		||||
        case 32:
 | 
			
		||||
            return itype.isUnsigned() ? TensorFlowType(Uint32RefType::get(ctx))
 | 
			
		||||
                   : Int32RefType::get(ctx);
 | 
			
		||||
        case 64:
 | 
			
		||||
            return itype.isUnsigned() ? TensorFlowType(Uint64RefType::get(ctx))
 | 
			
		||||
                   : Int64RefType::get(ctx);
 | 
			
		||||
        default:
 | 
			
		||||
            llvm_unreachable("unexpected integer type");
 | 
			
		||||
        }
 | 
			
		||||
    }
 | 
			
		||||
#define HANDLE_TF_TYPE(tftype, enumerant, name)          \
 | 
			
		||||
    if (auto derived_ty = type.dyn_cast<tftype##Type>()) \
 | 
			
		||||
        return tftype##RefType::get(ctx);
 | 
			
		||||
 | 
			
		||||
#define HANDLE_TF_REF_TYPE(tftype, enumerant, name)
 | 
			
		||||
// NOLINTNEXTLINE
 | 
			
		||||
#include "tf_types.def"
 | 
			
		||||
    llvm_unreachable("unexpected type kind");
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
Type TensorFlowRefType::RemoveRef()
 | 
			
		||||
{
 | 
			
		||||
    MLIRContext* ctx = getContext();
 | 
			
		||||
    if (isa<HalfRefType>()) return mlir::FloatType::getF16(ctx);
 | 
			
		||||
    if (isa<FloatRefType>()) return mlir::FloatType::getF32(ctx);
 | 
			
		||||
    if (isa<DoubleRefType>()) return mlir::FloatType::getF64(ctx);
 | 
			
		||||
    if (isa<Bfloat16RefType>()) return mlir::FloatType::getBF16(ctx);
 | 
			
		||||
    if (isa<BoolRefType>()) return mlir::IntegerType::get(ctx, 1);
 | 
			
		||||
    if (isa<Int8RefType>()) return mlir::IntegerType::get(ctx, 8);
 | 
			
		||||
    if (isa<Int16RefType>()) return mlir::IntegerType::get(ctx, 16);
 | 
			
		||||
    if (isa<Int32RefType>()) return mlir::IntegerType::get(ctx, 32);
 | 
			
		||||
    if (isa<Int64RefType>()) return mlir::IntegerType::get(ctx, 64);
 | 
			
		||||
    if (isa<Uint8RefType>())
 | 
			
		||||
        return mlir::IntegerType::get(ctx, 8, IntegerType::Unsigned);
 | 
			
		||||
    if (isa<Uint16RefType>())
 | 
			
		||||
        return mlir::IntegerType::get(ctx, 16, IntegerType::Unsigned);
 | 
			
		||||
    if (isa<Uint32RefType>())
 | 
			
		||||
        return mlir::IntegerType::get(ctx, 32, IntegerType::Unsigned);
 | 
			
		||||
    if (isa<Uint64RefType>())
 | 
			
		||||
        return mlir::IntegerType::get(ctx, 64, IntegerType::Unsigned);
 | 
			
		||||
    if (isa<Complex64RefType>())
 | 
			
		||||
        return mlir::ComplexType::get(mlir::FloatType::getF32(ctx));
 | 
			
		||||
    if (isa<Complex128RefType>())
 | 
			
		||||
        return mlir::ComplexType::get(mlir::FloatType::getF64(ctx));
 | 
			
		||||
#define HANDLE_TF_TYPE(tftype, enumerant, name) \
 | 
			
		||||
    if (isa<tftype##RefType>()) return tftype##Type::get(ctx);
 | 
			
		||||
 | 
			
		||||
#define HANDLE_TF_REF_TYPE(tftype, enumerant, name)
 | 
			
		||||
// NOLINTNEXTLINE
 | 
			
		||||
#include "tf_types.def"
 | 
			
		||||
    llvm_unreachable("unexpected tensorflow ref type kind");
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
Type TensorFlowTypeWithSubtype::RemoveSubtypes()
 | 
			
		||||
{
 | 
			
		||||
    MLIRContext* ctx = getContext();
 | 
			
		||||
    if (isa<VariantType>()) return VariantType::get(ctx);
 | 
			
		||||
    if (isa<ResourceType>()) return ResourceType::get(ctx);
 | 
			
		||||
    llvm_unreachable("unexpected tensorflow type with subtypes kind");
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
ArrayRef<TensorType> TensorFlowTypeWithSubtype::GetSubtypes()
 | 
			
		||||
{
 | 
			
		||||
    if (auto variant_type = dyn_cast<VariantType>())
 | 
			
		||||
        return variant_type.getSubtypes();
 | 
			
		||||
    if (auto resource_type = dyn_cast<ResourceType>())
 | 
			
		||||
        return resource_type.getSubtypes();
 | 
			
		||||
    llvm_unreachable("unexpected tensorflow type with subtypes kind");
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// TODO(jpienaar): BroadcastCompatible and HasCompatibleElementTypes have
 | 
			
		||||
// similar structure that could be extracted into helper method.
 | 
			
		||||
bool BroadcastCompatible(TypeRange lhs, TypeRange rhs)
 | 
			
		||||
{
 | 
			
		||||
    if (lhs.size() != rhs.size()) return false;
 | 
			
		||||
    for (auto types : llvm::zip(lhs, rhs))
 | 
			
		||||
    {
 | 
			
		||||
        // Drop ref types because they don't affect broadcast compatibility. E.g.,
 | 
			
		||||
        // `tensor<!tf.f32ref>` and `tensor<f32>` should be considered broadcast
 | 
			
		||||
        // compatible.
 | 
			
		||||
        auto lhs_type = DropRefType(std::get<0>(types));
 | 
			
		||||
        auto rhs_type = DropRefType(std::get<1>(types));
 | 
			
		||||
 | 
			
		||||
        // This should be true for all TF ops:
 | 
			
		||||
        auto lhs_tt = lhs_type.dyn_cast<TensorType>();
 | 
			
		||||
        auto rhs_tt = rhs_type.dyn_cast<TensorType>();
 | 
			
		||||
        if (!lhs_tt || !rhs_tt)
 | 
			
		||||
        {
 | 
			
		||||
            if (lhs_type != rhs_type) return false;
 | 
			
		||||
            continue;
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
        // Verify matching element types. These should be identical, except for
 | 
			
		||||
        // variant type where unknown subtype is considered compatible with all
 | 
			
		||||
        // subtypes.
 | 
			
		||||
        auto lhs_et = lhs_tt.getElementType();
 | 
			
		||||
        auto rhs_et = rhs_tt.getElementType();
 | 
			
		||||
        if (lhs_et != rhs_et)
 | 
			
		||||
        {
 | 
			
		||||
            // If either does not have subtypes, then the element types don't match.
 | 
			
		||||
            auto lhs_wst = lhs_et.dyn_cast<TF::TensorFlowTypeWithSubtype>();
 | 
			
		||||
            auto rhs_wst = rhs_et.dyn_cast<TF::TensorFlowTypeWithSubtype>();
 | 
			
		||||
            if (!lhs_wst || !rhs_wst) return false;
 | 
			
		||||
 | 
			
		||||
            // Consider the subtype of variant types.
 | 
			
		||||
            auto lhs_wst_st = lhs_wst.GetSubtypes();
 | 
			
		||||
            auto rhs_wst_st = rhs_wst.GetSubtypes();
 | 
			
		||||
            if (!lhs_wst_st.empty() && !rhs_wst_st.empty())
 | 
			
		||||
            {
 | 
			
		||||
                for (auto subtypes : llvm::zip(lhs_wst_st, rhs_wst_st))
 | 
			
		||||
                {
 | 
			
		||||
                    if (!BroadcastCompatible(std::get<0>(subtypes),
 | 
			
		||||
                                             std::get<1>(subtypes)))
 | 
			
		||||
                        return false;
 | 
			
		||||
                }
 | 
			
		||||
            }
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
        auto lhs_rt = lhs_type.dyn_cast<RankedTensorType>();
 | 
			
		||||
        auto rhs_rt = rhs_type.dyn_cast<RankedTensorType>();
 | 
			
		||||
        if (!lhs_rt || !rhs_rt) return true;
 | 
			
		||||
        SmallVector<int64_t, 4> shape;
 | 
			
		||||
        return OpTrait::util::getBroadcastedShape(lhs_rt.getShape(),
 | 
			
		||||
                rhs_rt.getShape(), shape);
 | 
			
		||||
    }
 | 
			
		||||
    return true;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Given two types `a` and `b`, returns a refined type which is cast compatible
 | 
			
		||||
// with both `a` and `b` and is equal to or more precise than both of them. It
 | 
			
		||||
// returns empty Type if the input types are not cast compatible.
 | 
			
		||||
//
 | 
			
		||||
// The two types are considered cast compatible if they have dynamically equal
 | 
			
		||||
// shapes and element type. For element types that do not have subtypes, they
 | 
			
		||||
// must be equal. However for TensorFlow types such as Resource and Variant,
 | 
			
		||||
// that also have subtypes, we recursively check for subtype compatibilty for
 | 
			
		||||
// Resource types and assume all variant types are cast compatible. If either
 | 
			
		||||
// one of `a` or `b` have empty subtypes, they are considered cast compatible.
 | 
			
		||||
//
 | 
			
		||||
// The returned type is same or more precise than the input types. For example,
 | 
			
		||||
// if `a` and `b` are cast compatible types tensor<2x?x?xf32> and
 | 
			
		||||
// tensor<?x4x?xf32> respectively, the returned type is tensor<2x4x?xf32>.
 | 
			
		||||
//
 | 
			
		||||
// Provides option to ignore ref types on 'a'. This is useful for TF ops that
 | 
			
		||||
// might allow operands to either be same as result type or be a ref type
 | 
			
		||||
// corresponding to it.
 | 
			
		||||
mlir::Type GetCastCompatibleType(mlir::Type a, mlir::Type b,
 | 
			
		||||
                                 bool may_ignore_ref_type_a)
 | 
			
		||||
{
 | 
			
		||||
    // Fast path if everything is equal.
 | 
			
		||||
    if (a == b) return b;
 | 
			
		||||
 | 
			
		||||
    auto a_tt = a.dyn_cast<mlir::TensorType>();
 | 
			
		||||
    auto b_tt = b.dyn_cast<mlir::TensorType>();
 | 
			
		||||
 | 
			
		||||
    // If only one of a or b is a tensor type, they are incompatible.
 | 
			
		||||
    if (static_cast<bool>(a_tt) ^ static_cast<bool>(b_tt)) return nullptr;
 | 
			
		||||
 | 
			
		||||
    // For non-tensor types, we do not need to worry about shape and can return
 | 
			
		||||
    // early.
 | 
			
		||||
    if (!a_tt && !b_tt)
 | 
			
		||||
    {
 | 
			
		||||
        // Remove ref types.
 | 
			
		||||
        if (may_ignore_ref_type_a)
 | 
			
		||||
        {
 | 
			
		||||
            if (auto ref_type = a.dyn_cast<mlir::TF::TensorFlowRefType>())
 | 
			
		||||
            {
 | 
			
		||||
                a = ref_type.RemoveRef();
 | 
			
		||||
                if (a == b) return a;
 | 
			
		||||
            }
 | 
			
		||||
        }
 | 
			
		||||
        if (a.getTypeID() != b.getTypeID()) return nullptr;
 | 
			
		||||
 | 
			
		||||
        // If either is not a type that contain subtypes then the types are not cast
 | 
			
		||||
        // compatible.
 | 
			
		||||
        auto a_wst = a.dyn_cast<mlir::TF::TensorFlowTypeWithSubtype>();
 | 
			
		||||
        auto b_wst = b.dyn_cast<mlir::TF::TensorFlowTypeWithSubtype>();
 | 
			
		||||
        if (!a_wst || !b_wst) return nullptr;
 | 
			
		||||
 | 
			
		||||
        // For Variant types we are more permissive right now and accept all pairs
 | 
			
		||||
        // of Variant types. If we are more constrainted and check compatibility of
 | 
			
		||||
        // subtypes, we might reject valid graphs.
 | 
			
		||||
        // TODO(prakalps): Variant doesn't have a subtype, we assign it
 | 
			
		||||
        // one, so we should only assign it one when we know the subtype. Then we
 | 
			
		||||
        // can be more constrained and check subtypes for cast compatibility as
 | 
			
		||||
        // well.
 | 
			
		||||
        if (a.isa<mlir::TF::VariantType>()) return a;
 | 
			
		||||
 | 
			
		||||
        // For Resource types, we recursively check the subtypes for cast
 | 
			
		||||
        // compatibility, if possible. Otherwise treat them as compatible.
 | 
			
		||||
        auto a_wst_st = a_wst.GetSubtypes();
 | 
			
		||||
        auto b_wst_st = b_wst.GetSubtypes();
 | 
			
		||||
        if (a_wst_st.empty() || b_wst_st.empty()) return a;
 | 
			
		||||
        if (a_wst_st.size() != b_wst_st.size()) return nullptr;
 | 
			
		||||
        llvm::SmallVector<mlir::TensorType, 4> refined_subtypes;
 | 
			
		||||
        for (auto subtypes : llvm::zip(a_wst_st, b_wst_st))
 | 
			
		||||
        {
 | 
			
		||||
            mlir::Type refined_st = GetCastCompatibleType(std::get<0>(subtypes), std::get<1>(subtypes),
 | 
			
		||||
                                    /*may_ignore_ref_type_a=*/false);
 | 
			
		||||
            if (!refined_st) return nullptr;
 | 
			
		||||
            refined_subtypes.push_back(refined_st.cast<mlir::TensorType>());
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
        return mlir::TF::ResourceType::get(refined_subtypes, a.getContext());
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    // For tensor types, check compatibility of both element type and shape.
 | 
			
		||||
    mlir::Type refined_element_ty = GetCastCompatibleType(
 | 
			
		||||
                                        a_tt.getElementType(), b_tt.getElementType(), may_ignore_ref_type_a);
 | 
			
		||||
    if (!refined_element_ty) return nullptr;
 | 
			
		||||
 | 
			
		||||
    if (!a_tt.hasRank() && !b_tt.hasRank())
 | 
			
		||||
    {
 | 
			
		||||
        return mlir::UnrankedTensorType::get(refined_element_ty);
 | 
			
		||||
    }
 | 
			
		||||
    if (!a_tt.hasRank())
 | 
			
		||||
    {
 | 
			
		||||
        return mlir::RankedTensorType::get(b_tt.getShape(), refined_element_ty);
 | 
			
		||||
    }
 | 
			
		||||
    if (!b_tt.hasRank())
 | 
			
		||||
    {
 | 
			
		||||
        return mlir::RankedTensorType::get(a_tt.getShape(), refined_element_ty);
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    llvm::SmallVector<int64_t, 8> refined_shape;
 | 
			
		||||
    if (!GetCastCompatibleShape(a_tt.getShape(), b_tt.getShape(), &refined_shape))
 | 
			
		||||
        return nullptr;
 | 
			
		||||
 | 
			
		||||
    return mlir::RankedTensorType::get(refined_shape, refined_element_ty);
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
bool HasCompatibleElementTypes(Type lhs, Type rhs,
 | 
			
		||||
                               bool may_ignore_ref_type_lhs)
 | 
			
		||||
{
 | 
			
		||||
    return GetCastCompatibleType(lhs, rhs, may_ignore_ref_type_lhs) != nullptr;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
bool AreCastCompatible(TypeRange types)
 | 
			
		||||
{
 | 
			
		||||
    Type common = types.front();
 | 
			
		||||
    for (auto type : types.drop_front())
 | 
			
		||||
    {
 | 
			
		||||
        Type refined_type = GetCastCompatibleType(common, type, /*may_ignore_ref_type_a=*/false);
 | 
			
		||||
        if (!refined_type) return false;
 | 
			
		||||
        common = refined_type;
 | 
			
		||||
    }
 | 
			
		||||
    return true;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
bool ArraysAreCastCompatible(TypeRange lhs, TypeRange rhs)
 | 
			
		||||
{
 | 
			
		||||
    if (lhs.size() != rhs.size()) return false;
 | 
			
		||||
    for (auto pair : llvm::zip(lhs, rhs))
 | 
			
		||||
    {
 | 
			
		||||
        auto lhs_i = std::get<0>(pair);
 | 
			
		||||
        auto rhs_i = std::get<1>(pair);
 | 
			
		||||
        if (!AreCastCompatible({lhs_i, rhs_i})) return false;
 | 
			
		||||
    }
 | 
			
		||||
    return true;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Assumes a function `GetDefaultTypeOf(ComposedType)` that returns the default
 | 
			
		||||
// type for a composed type (such as a ref type or a type with subtypes).
 | 
			
		||||
template<typename ComposedType>
 | 
			
		||||
Type DropTypeHelper(Type ty)
 | 
			
		||||
{
 | 
			
		||||
    Type element_ty = getElementTypeOrSelf(ty);
 | 
			
		||||
    auto composed_type = element_ty.dyn_cast<ComposedType>();
 | 
			
		||||
    if (!composed_type) return ty;
 | 
			
		||||
 | 
			
		||||
    Type default_ty = GetDefaultTypeOf(composed_type);
 | 
			
		||||
    if (auto ranked_ty = ty.dyn_cast<RankedTensorType>())
 | 
			
		||||
    {
 | 
			
		||||
        return RankedTensorType::get(ranked_ty.getShape(), default_ty);
 | 
			
		||||
    }
 | 
			
		||||
    else if (ty.dyn_cast<UnrankedTensorType>())
 | 
			
		||||
    {
 | 
			
		||||
        return UnrankedTensorType::get(default_ty);
 | 
			
		||||
    }
 | 
			
		||||
    else
 | 
			
		||||
    {
 | 
			
		||||
        return default_ty;
 | 
			
		||||
    }
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
Type DropSubTypes(Type ty)
 | 
			
		||||
{
 | 
			
		||||
    return DropTypeHelper<TF::TensorFlowTypeWithSubtype>(ty);
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
Type DropRefType(Type ty)
 | 
			
		||||
{
 | 
			
		||||
    return DropTypeHelper<TF::TensorFlowRefType>(ty);
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
Type DropRefAndSubTypes(Type ty)
 | 
			
		||||
{
 | 
			
		||||
    return DropRefType(DropSubTypes(ty));
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
} // namespace TF
 | 
			
		||||
} // namespace mlir
 | 
			
		||||
							
								
								
									
										77
									
								
								3rdparty/ncnn/tools/mlir/tf_types.def
									
									
									
									
										vendored
									
									
										Normal file
									
								
							
							
						
						
									
										77
									
								
								3rdparty/ncnn/tools/mlir/tf_types.def
									
									
									
									
										vendored
									
									
										Normal file
									
								
							@ -0,0 +1,77 @@
 | 
			
		||||
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
 | 
			
		||||
 | 
			
		||||
Licensed under the Apache License, Version 2.0 (the "License");
 | 
			
		||||
you may not use this file except in compliance with the License.
 | 
			
		||||
You may obtain a copy of the License at
 | 
			
		||||
 | 
			
		||||
    http://www.apache.org/licenses/LICENSE-2.0
 | 
			
		||||
 | 
			
		||||
Unless required by applicable law or agreed to in writing, software
 | 
			
		||||
distributed under the License is distributed on an "AS IS" BASIS,
 | 
			
		||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | 
			
		||||
See the License for the specific language governing permissions and
 | 
			
		||||
limitations under the License.
 | 
			
		||||
==============================================================================*/
 | 
			
		||||
 | 
			
		||||
// This file contains descriptions of the various TensorFlow types. This is
 | 
			
		||||
// used as a central place for enumerating the different types.
 | 
			
		||||
 | 
			
		||||
#ifdef HANDLE_TF_TYPE
 | 
			
		||||
 | 
			
		||||
//             class, enumerant, name
 | 
			
		||||
HANDLE_TF_TYPE(Qint8, QINT8, "qint8")
 | 
			
		||||
HANDLE_TF_TYPE(Qint16, QINT16, "qint16")
 | 
			
		||||
HANDLE_TF_TYPE(Qint32, QINT32, "qint32")
 | 
			
		||||
HANDLE_TF_TYPE(Quint8, QUINT8, "quint8")
 | 
			
		||||
HANDLE_TF_TYPE(Quint16, QUINT16, "quint16")
 | 
			
		||||
HANDLE_TF_TYPE(String, STRING, "string")
 | 
			
		||||
 | 
			
		||||
#ifndef HANDLE_CUSTOM_TF_TYPE
 | 
			
		||||
#define HANDLE_CUSTOM_TF_TYPE(class, enumerant, name) \
 | 
			
		||||
  HANDLE_TF_TYPE(class, enumerant, name)
 | 
			
		||||
#endif
 | 
			
		||||
HANDLE_CUSTOM_TF_TYPE(Resource, RESOURCE, "resource")
 | 
			
		||||
HANDLE_CUSTOM_TF_TYPE(Variant, VARIANT, "variant")
 | 
			
		||||
#undef HANDLE_CUSTOM_TF_TYPE
 | 
			
		||||
 | 
			
		||||
// All ref types are listed below this line and FloatRef is the first ref type.
 | 
			
		||||
// This helps in easily differentiating ref and non-ref types, and converting
 | 
			
		||||
// a type to/from ref types.
 | 
			
		||||
 | 
			
		||||
#ifndef HANDLE_TF_REF_TYPE
 | 
			
		||||
#define HANDLE_TF_REF_TYPE(class, enumerant, name) \
 | 
			
		||||
  HANDLE_TF_TYPE(class, enumerant, name)
 | 
			
		||||
#endif
 | 
			
		||||
HANDLE_TF_REF_TYPE(FloatRef, FLOAT_REF, "f32ref")
 | 
			
		||||
HANDLE_TF_REF_TYPE(DoubleRef, DOUBLE_REF, "f64ref")
 | 
			
		||||
HANDLE_TF_REF_TYPE(Uint8Ref, UINT8_REF, "uint8ref")
 | 
			
		||||
HANDLE_TF_REF_TYPE(Int8Ref, INT8_REF, "int8ref")
 | 
			
		||||
HANDLE_TF_REF_TYPE(Uint16Ref, UINT16_REF, "uint16ref")
 | 
			
		||||
HANDLE_TF_REF_TYPE(Int16Ref, INT16_REF, "int16ref")
 | 
			
		||||
HANDLE_TF_REF_TYPE(Uint32Ref, UINT32_REF, "uint32ref")
 | 
			
		||||
HANDLE_TF_REF_TYPE(Int32Ref, INT32_REF, "int32ref")
 | 
			
		||||
HANDLE_TF_REF_TYPE(Uint64Ref, UINT64_REF, "uint64ref")
 | 
			
		||||
HANDLE_TF_REF_TYPE(Int64Ref, INT64_REF, "int64ref")
 | 
			
		||||
HANDLE_TF_REF_TYPE(StringRef, STRING_REF, "stringref")
 | 
			
		||||
HANDLE_TF_REF_TYPE(BoolRef, BOOL_REF, "boolref")
 | 
			
		||||
HANDLE_TF_REF_TYPE(Quint8Ref, QUINT8_REF, "quint8ref")
 | 
			
		||||
HANDLE_TF_REF_TYPE(Qint8Ref, QINT8_REF, "qint8ref")
 | 
			
		||||
HANDLE_TF_REF_TYPE(Quint16Ref, QUINT16_REF, "quint16ref")
 | 
			
		||||
HANDLE_TF_REF_TYPE(Qint16Ref, QINT16_REF, "qint16ref")
 | 
			
		||||
HANDLE_TF_REF_TYPE(Qint32Ref, QINT32_REF, "qint32ref")
 | 
			
		||||
HANDLE_TF_REF_TYPE(Bfloat16Ref, BFLOAT16_REF, "bfloat16ref")
 | 
			
		||||
HANDLE_TF_REF_TYPE(Complex64Ref, COMPLEX64_REF, "complex64ref")
 | 
			
		||||
HANDLE_TF_REF_TYPE(Complex128Ref, COMPLEX128_REF, "complex128ref")
 | 
			
		||||
HANDLE_TF_REF_TYPE(HalfRef, HALF_REF, "halfref")
 | 
			
		||||
HANDLE_TF_REF_TYPE(ResourceRef, RESOURCE_REF, "resourceref")
 | 
			
		||||
 | 
			
		||||
#ifndef HANDLE_LAST_TF_TYPE
 | 
			
		||||
#define HANDLE_LAST_TF_TYPE(class, enumerant, name) \
 | 
			
		||||
  HANDLE_TF_REF_TYPE(class, enumerant, name)
 | 
			
		||||
#endif
 | 
			
		||||
HANDLE_LAST_TF_TYPE(VariantRef, VARIANT_REF, "variantref")
 | 
			
		||||
#undef HANDLE_LAST_TF_TYPE
 | 
			
		||||
 | 
			
		||||
#undef HANDLE_TF_REF_TYPE
 | 
			
		||||
#undef HANDLE_TF_TYPE
 | 
			
		||||
#endif
 | 
			
		||||
							
								
								
									
										380
									
								
								3rdparty/ncnn/tools/mlir/tf_types.h
									
									
									
									
										vendored
									
									
										Normal file
									
								
							
							
						
						
									
										380
									
								
								3rdparty/ncnn/tools/mlir/tf_types.h
									
									
									
									
										vendored
									
									
										Normal file
									
								
							@ -0,0 +1,380 @@
 | 
			
		||||
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
 | 
			
		||||
 | 
			
		||||
Licensed under the Apache License, Version 2.0 (the "License");
 | 
			
		||||
you may not use this file except in compliance with the License.
 | 
			
		||||
You may obtain a copy of the License at
 | 
			
		||||
 | 
			
		||||
    http://www.apache.org/licenses/LICENSE-2.0
 | 
			
		||||
 | 
			
		||||
Unless required by applicable law or agreed to in writing, software
 | 
			
		||||
distributed under the License is distributed on an "AS IS" BASIS,
 | 
			
		||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | 
			
		||||
See the License for the specific language governing permissions and
 | 
			
		||||
limitations under the License.
 | 
			
		||||
==============================================================================*/
 | 
			
		||||
 | 
			
		||||
// This file defines the types used in the standard MLIR TensorFlow dialect.
 | 
			
		||||
 | 
			
		||||
#ifndef TENSORFLOW_COMPILER_MLIR_TENSORFLOW_IR_TF_TYPES_H_
 | 
			
		||||
#define TENSORFLOW_COMPILER_MLIR_TENSORFLOW_IR_TF_TYPES_H_
 | 
			
		||||
 | 
			
		||||
#include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
 | 
			
		||||
#include "mlir/IR/Diagnostics.h"   // from @llvm-project
 | 
			
		||||
#include "mlir/IR/Location.h"      // from @llvm-project
 | 
			
		||||
#include "mlir/IR/Operation.h"     // from @llvm-project
 | 
			
		||||
#include "mlir/IR/TypeUtilities.h" // from @llvm-project
 | 
			
		||||
#include "mlir/IR/Types.h"         // from @llvm-project
 | 
			
		||||
 | 
			
		||||
namespace mlir {
 | 
			
		||||
namespace TF {
 | 
			
		||||
//===----------------------------------------------------------------------===//
 | 
			
		||||
// Utility iterators
 | 
			
		||||
//===----------------------------------------------------------------------===//
 | 
			
		||||
 | 
			
		||||
// An iterator for the tensor shapes of an op's operands of shaped types.
 | 
			
		||||
// Returns llvm::None if a operand is unranked; returns ArrayRef<int64_t> as the
 | 
			
		||||
// shape otherwise.
 | 
			
		||||
class OperandShapeIterator final
 | 
			
		||||
    : public llvm::mapped_iterator<Operation::operand_iterator,
 | 
			
		||||
      llvm::Optional<ArrayRef<int64_t> > (*)(
 | 
			
		||||
          Value)>
 | 
			
		||||
{
 | 
			
		||||
public:
 | 
			
		||||
    using reference = llvm::Optional<ArrayRef<int64_t> >;
 | 
			
		||||
 | 
			
		||||
    /// Initializes the operand shape iterator to the specified operand iterator.
 | 
			
		||||
    explicit OperandShapeIterator(Operation::operand_iterator it);
 | 
			
		||||
};
 | 
			
		||||
 | 
			
		||||
using OperandShapeRange = iterator_range<OperandShapeIterator>;
 | 
			
		||||
 | 
			
		||||
// An iterator for the tensor shapes of an op's results of shaped types.
 | 
			
		||||
// Returns llvm::None if a result is unranked; returns ArrayRef<int64_t> as the
 | 
			
		||||
// shape otherwise.
 | 
			
		||||
class ResultShapeIterator final
 | 
			
		||||
    : public llvm::mapped_iterator<Operation::result_iterator,
 | 
			
		||||
      llvm::Optional<ArrayRef<int64_t> > (*)(
 | 
			
		||||
          Value)>
 | 
			
		||||
{
 | 
			
		||||
public:
 | 
			
		||||
    using reference = llvm::Optional<ArrayRef<int64_t> >;
 | 
			
		||||
 | 
			
		||||
    /// Initializes the result shape iterator to the specified result iterator.
 | 
			
		||||
    explicit ResultShapeIterator(Operation::result_iterator it);
 | 
			
		||||
};
 | 
			
		||||
 | 
			
		||||
using ResultShapeRange = iterator_range<ResultShapeIterator>;
 | 
			
		||||
 | 
			
		||||
//===----------------------------------------------------------------------===//
 | 
			
		||||
// TensorFlow types
 | 
			
		||||
//===----------------------------------------------------------------------===//
 | 
			
		||||
 | 
			
		||||
// The base class in the TensorFlow type hierarchy.
 | 
			
		||||
class TensorFlowType : public Type
 | 
			
		||||
{
 | 
			
		||||
public:
 | 
			
		||||
    using Type::Type;
 | 
			
		||||
 | 
			
		||||
    // Support method to enable LLVM-style type casting.
 | 
			
		||||
    static bool classof(Type type);
 | 
			
		||||
};
 | 
			
		||||
 | 
			
		||||
// Returns true if the specified type is a valid TensorFlow element type.
 | 
			
		||||
static inline bool IsValidTFElementType(Type type)
 | 
			
		||||
{
 | 
			
		||||
    return type.isa<ComplexType, FloatType, IntegerType, TensorFlowType>();
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Returns true if this is a valid TensorFlow tensor type.
 | 
			
		||||
static inline bool IsValidTFTensorType(Type type)
 | 
			
		||||
{
 | 
			
		||||
    // TensorFlow types should be tensors of one of the valid TensorFlow element
 | 
			
		||||
    // types.
 | 
			
		||||
    if (auto tensor_ty = type.dyn_cast<TensorType>())
 | 
			
		||||
        return IsValidTFElementType(tensor_ty.getElementType());
 | 
			
		||||
    return false;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
namespace detail {
 | 
			
		||||
// Common implementation of TensorFlow types. The template argument indicates
 | 
			
		||||
// the concrete derived class per CRTP.
 | 
			
		||||
template<typename Derived>
 | 
			
		||||
class TensorFlowTypeImpl
 | 
			
		||||
    : public Type::TypeBase<Derived, TensorFlowType, TypeStorage>
 | 
			
		||||
{
 | 
			
		||||
public:
 | 
			
		||||
    using Base = typename Type::TypeBase<Derived, TensorFlowType, TypeStorage>;
 | 
			
		||||
    using TFBase = TensorFlowTypeImpl<Derived>;
 | 
			
		||||
    using Base::Base;
 | 
			
		||||
};
 | 
			
		||||
} // namespace detail
 | 
			
		||||
 | 
			
		||||
// TensorFlowRefType class supports all the ref types in TensorFlow dialect.
 | 
			
		||||
class TensorFlowRefType : public TensorFlowType
 | 
			
		||||
{
 | 
			
		||||
public:
 | 
			
		||||
    using TensorFlowType::TensorFlowType;
 | 
			
		||||
 | 
			
		||||
    // Checks if a type is TensorFlow Ref type.
 | 
			
		||||
    static bool classof(Type type);
 | 
			
		||||
 | 
			
		||||
    // Converts a type to the corresponding TensorFlowRef type.
 | 
			
		||||
    static TensorFlowType get(Type type);
 | 
			
		||||
    static TensorFlowType getChecked(Type type, MLIRContext* context,
 | 
			
		||||
                                     Location loc)
 | 
			
		||||
    {
 | 
			
		||||
        if (failed(verify(loc, type)))
 | 
			
		||||
        {
 | 
			
		||||
            return TensorFlowRefType();
 | 
			
		||||
        }
 | 
			
		||||
        return get(type);
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    static LogicalResult verify(Location loc, Type type)
 | 
			
		||||
    {
 | 
			
		||||
        // type should be a valid TensorFlow type.
 | 
			
		||||
        if (!IsValidTFTensorType(type))
 | 
			
		||||
        {
 | 
			
		||||
            return emitError(loc) << "invalid TensorFlow type: " << type;
 | 
			
		||||
        }
 | 
			
		||||
        return success();
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    // Converts a TensorFlowRef type to the corresponding TensorFlow or standard
 | 
			
		||||
    // type.
 | 
			
		||||
    Type RemoveRef();
 | 
			
		||||
};
 | 
			
		||||
 | 
			
		||||
// Returns the corresponding TensorFlow or standard type from TensorFlowRef
 | 
			
		||||
// type.
 | 
			
		||||
static inline Type GetDefaultTypeOf(TensorFlowRefType type)
 | 
			
		||||
{
 | 
			
		||||
    return type.RemoveRef();
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Returns the element type if `type` is a `ShapedType` and the type itself
 | 
			
		||||
// otherwise, converting `TensorFlowRef` type to corresponding `TensorFlow` or
 | 
			
		||||
// standard type if necessary.
 | 
			
		||||
static inline Type GetElementTypeOrSelfResolveRef(Type type)
 | 
			
		||||
{
 | 
			
		||||
    Type element_type = mlir::getElementTypeOrSelf(type);
 | 
			
		||||
    if (auto ref_type = element_type.dyn_cast<mlir::TF::TensorFlowRefType>())
 | 
			
		||||
    {
 | 
			
		||||
        element_type = ref_type.RemoveRef();
 | 
			
		||||
    }
 | 
			
		||||
    return element_type;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
#define HANDLE_TF_TYPE(tftype, enumerant, name)                          \
 | 
			
		||||
    class tftype##Type : public detail::TensorFlowTypeImpl<tftype##Type> \
 | 
			
		||||
    {                                                                    \
 | 
			
		||||
    public:                                                              \
 | 
			
		||||
        using TFBase::TFBase;                                            \
 | 
			
		||||
    };
 | 
			
		||||
 | 
			
		||||
// Custom TensorFlow types are defined separately.
 | 
			
		||||
#define HANDLE_CUSTOM_TF_TYPE(tftype, enumerant, name)
 | 
			
		||||
 | 
			
		||||
// NOLINTNEXTLINE
 | 
			
		||||
#include "tf_types.def"
 | 
			
		||||
 | 
			
		||||
namespace detail {
 | 
			
		||||
// Storage type contains inferred subtypes for TypeWithSubtype.
 | 
			
		||||
class TypeWithSubtypeStorage : public TypeStorage
 | 
			
		||||
{
 | 
			
		||||
public:
 | 
			
		||||
    using KeyTy = ArrayRef<TensorType>;
 | 
			
		||||
 | 
			
		||||
    // NOLINTNEXTLINE
 | 
			
		||||
    static TypeWithSubtypeStorage* construct(TypeStorageAllocator& allocator,
 | 
			
		||||
            const KeyTy& key)
 | 
			
		||||
    {
 | 
			
		||||
        ArrayRef<TensorType> subtypes = allocator.copyInto(key);
 | 
			
		||||
        return new (allocator.allocate<TypeWithSubtypeStorage>())
 | 
			
		||||
               TypeWithSubtypeStorage(subtypes);
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    explicit TypeWithSubtypeStorage(const KeyTy& key)
 | 
			
		||||
        : subtypes_(key)
 | 
			
		||||
    {
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    bool operator==(const KeyTy& key) const
 | 
			
		||||
    {
 | 
			
		||||
        return key == subtypes_;
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    static llvm::hash_code hashKey(const KeyTy& key)
 | 
			
		||||
    {
 | 
			
		||||
        return llvm::hash_combine_range(key.begin(), key.end());
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    KeyTy subtypes_;
 | 
			
		||||
};
 | 
			
		||||
 | 
			
		||||
// Common implementation of TensorFlow types with subtypes. These subtypes are
 | 
			
		||||
// opaque and their interpretation depends on the actual underlying type.
 | 
			
		||||
// The template argument indicates the concrete derived class per CRTP. Concrete
 | 
			
		||||
// classes must implement the following:
 | 
			
		||||
//   - `static std::string getTypeName()` that returns the name of the type for
 | 
			
		||||
//     verification logging.
 | 
			
		||||
template<typename Derived>
 | 
			
		||||
class TypeWithSubtypeImpl
 | 
			
		||||
    : public Type::TypeBase<Derived, TensorFlowType, TypeWithSubtypeStorage>
 | 
			
		||||
{
 | 
			
		||||
public:
 | 
			
		||||
    using Base = Type::TypeBase<Derived, TensorFlowType, TypeWithSubtypeStorage>;
 | 
			
		||||
    using TFBase = TypeWithSubtypeImpl<Derived>;
 | 
			
		||||
    using Base::Base;
 | 
			
		||||
 | 
			
		||||
    static Derived get(ArrayRef<TensorType> subtypes, MLIRContext* context)
 | 
			
		||||
    {
 | 
			
		||||
        return Base::get(context, subtypes);
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    static Derived getChecked(ArrayRef<TensorType> subtypes, MLIRContext* context,
 | 
			
		||||
                              Location loc)
 | 
			
		||||
    {
 | 
			
		||||
        return Base::getChecked(loc, subtypes);
 | 
			
		||||
    }
 | 
			
		||||
    static Derived getChecked(function_ref<InFlightDiagnostic()> emitError,
 | 
			
		||||
                              MLIRContext* context,
 | 
			
		||||
                              ArrayRef<TensorType> subtypes)
 | 
			
		||||
    {
 | 
			
		||||
        return Base::getChecked(emitError, context, subtypes);
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    static Derived get(MLIRContext* context)
 | 
			
		||||
    {
 | 
			
		||||
        return get({}, context);
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    static LogicalResult verify(function_ref<InFlightDiagnostic()> emitError,
 | 
			
		||||
                                ArrayRef<TensorType> subtypes)
 | 
			
		||||
    {
 | 
			
		||||
        // Each of the subtypes should be a valid TensorFlow type.
 | 
			
		||||
        for (TensorType subtype : subtypes)
 | 
			
		||||
        {
 | 
			
		||||
            if (!IsValidTFTensorType(subtype))
 | 
			
		||||
            {
 | 
			
		||||
                return emitError() << "invalid " << Derived::getTypeName()
 | 
			
		||||
                       << " subtype: " << subtype;
 | 
			
		||||
            }
 | 
			
		||||
        }
 | 
			
		||||
        return success();
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    ArrayRef<TensorType> getSubtypes()
 | 
			
		||||
    {
 | 
			
		||||
        return Base::getImpl()->subtypes_;
 | 
			
		||||
    }
 | 
			
		||||
};
 | 
			
		||||
} // namespace detail
 | 
			
		||||
 | 
			
		||||
// TensorFlowTypeWithSubtype class supports all the types with subtypes in
 | 
			
		||||
// TensorFlow dialect.
 | 
			
		||||
class TensorFlowTypeWithSubtype : public TensorFlowType
 | 
			
		||||
{
 | 
			
		||||
public:
 | 
			
		||||
    using TensorFlowType::TensorFlowType;
 | 
			
		||||
 | 
			
		||||
    // Checks if a type is TensorFlow type with subtypes.
 | 
			
		||||
    static bool classof(Type type);
 | 
			
		||||
 | 
			
		||||
    // Converts a TypeWithSubtype type to the same type but without its subtypes.
 | 
			
		||||
    Type RemoveSubtypes();
 | 
			
		||||
 | 
			
		||||
    // Returns the subtypes.
 | 
			
		||||
    ArrayRef<TensorType> GetSubtypes();
 | 
			
		||||
};
 | 
			
		||||
 | 
			
		||||
// Returns the corresponding TensorFlow type with subtypes but without its
 | 
			
		||||
// subtypes.
 | 
			
		||||
static inline Type GetDefaultTypeOf(TensorFlowTypeWithSubtype type)
 | 
			
		||||
{
 | 
			
		||||
    return type.RemoveSubtypes();
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// TensorFlow resource type is used to support TensorFlow resource variables,
 | 
			
		||||
// which represent shared, persistent state manipulated by a TensorFlow program.
 | 
			
		||||
// ResourceType stores shape and datatype for subtypes unlike most other data
 | 
			
		||||
// types that don't have any associated information.
 | 
			
		||||
class ResourceType : public detail::TypeWithSubtypeImpl<ResourceType>
 | 
			
		||||
{
 | 
			
		||||
public:
 | 
			
		||||
    using TFBase::TFBase;
 | 
			
		||||
    static std::string getTypeName()
 | 
			
		||||
    {
 | 
			
		||||
        return "ResourceType";
 | 
			
		||||
    }
 | 
			
		||||
};
 | 
			
		||||
 | 
			
		||||
// TensorFlow variant type is used to support arbitrary custom C++ data types.
 | 
			
		||||
// VariantType stores inferred shape and datatype for subtypes unlike most other
 | 
			
		||||
// data types that don't have any associated information. For example, variants
 | 
			
		||||
// encoding TensorList type stores the common shape and dtype of the list
 | 
			
		||||
// elements as the only subtype.
 | 
			
		||||
class VariantType : public detail::TypeWithSubtypeImpl<VariantType>
 | 
			
		||||
{
 | 
			
		||||
public:
 | 
			
		||||
    using TFBase::TFBase;
 | 
			
		||||
    static std::string getTypeName()
 | 
			
		||||
    {
 | 
			
		||||
        return "VariantType";
 | 
			
		||||
    }
 | 
			
		||||
};
 | 
			
		||||
 | 
			
		||||
// Given two types `a` and `b`, returns a refined type which is cast compatible
 | 
			
		||||
// with both `a` and `b` and is equal to or more precise than both of them. It
 | 
			
		||||
// returns empty Type if the input types are not cast compatible.
 | 
			
		||||
// Provides option to ignore ref types on 'a'. This is useful for TF ops that
 | 
			
		||||
// might allow operands to either be same as result type or be a ref type
 | 
			
		||||
// corresponding to it.
 | 
			
		||||
mlir::Type GetCastCompatibleType(mlir::Type a, mlir::Type b,
 | 
			
		||||
                                 bool may_ignore_ref_type_a);
 | 
			
		||||
 | 
			
		||||
// Returns whether two arrays of Type are broadcast compatible.
 | 
			
		||||
bool BroadcastCompatible(TypeRange lhs, TypeRange rhs);
 | 
			
		||||
 | 
			
		||||
// Returns whether the two elemental types are compatible. Shapes are compatible
 | 
			
		||||
// if:
 | 
			
		||||
// - the types are statically equal
 | 
			
		||||
// - could be dynamically equal
 | 
			
		||||
//   - considering dynamic shapes equal unless contradictory info known;
 | 
			
		||||
//   - element types are equivalent, modulo subtypes possible be less exact
 | 
			
		||||
//     (e.g., a resource type without subtype is considered compatible with
 | 
			
		||||
//      resource type with known subtype).
 | 
			
		||||
// Provide option to ignore ref types on 'lhs'.
 | 
			
		||||
bool HasCompatibleElementTypes(Type lhs, Type rhs,
 | 
			
		||||
                               bool may_ignore_ref_type_lhs = false);
 | 
			
		||||
 | 
			
		||||
// Returns true if all TensorFlow types can be cast to one
 | 
			
		||||
// another. In other words, a single run-time value is legal for both the types.
 | 
			
		||||
// For example, tensor<*xf32>, tensor<?xf32> and tensor<3xf32> are cast
 | 
			
		||||
// compatible.
 | 
			
		||||
bool AreCastCompatible(TypeRange types);
 | 
			
		||||
 | 
			
		||||
// Returns true if corresponding elements of lhs and rhs AreCastCompatible and
 | 
			
		||||
// lhs and rhs are the same length.
 | 
			
		||||
bool ArraysAreCastCompatible(TypeRange lhs, TypeRange rhs);
 | 
			
		||||
 | 
			
		||||
// If `ty` is a tensor type and its element type has subtypes, then returns a
 | 
			
		||||
// new type of same shape but dropped subtypes for the element type.
 | 
			
		||||
// Otherwise, if `ty` has subtypes, then returns corresponding type with dropped
 | 
			
		||||
// subtypes.
 | 
			
		||||
// Otherwise, returns the original type `ty`.
 | 
			
		||||
Type DropSubTypes(Type ty);
 | 
			
		||||
 | 
			
		||||
// If `ty` is a tensor type and has elements of a ref type, then returns a new
 | 
			
		||||
// type of same shape but corresponding non-ref type as element type.
 | 
			
		||||
// Otherwise, if `ty` is a ref type, then returns corresponding non-ref type.
 | 
			
		||||
// Otherwise, returns the original type `ty`.
 | 
			
		||||
Type DropRefType(Type ty);
 | 
			
		||||
 | 
			
		||||
// Convenience call for executing both `DropRefType` and `DropSubTypes`.
 | 
			
		||||
Type DropRefAndSubTypes(Type ty);
 | 
			
		||||
 | 
			
		||||
} // end namespace TF
 | 
			
		||||
} // end namespace mlir
 | 
			
		||||
 | 
			
		||||
#endif // TENSORFLOW_COMPILER_MLIR_TENSORFLOW_IR_TF_TYPES_H_
 | 
			
		||||
		Reference in New Issue
	
	Block a user