618 lines
25 KiB
TableGen
618 lines
25 KiB
TableGen
|
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||
|
|
||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||
|
you may not use this file except in compliance with the License.
|
||
|
You may obtain a copy of the License at
|
||
|
|
||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||
|
|
||
|
Unless required by applicable law or agreed to in writing, software
|
||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||
|
See the License for the specific language governing permissions and
|
||
|
limitations under the License.
|
||
|
==============================================================================*/
|
||
|
|
||
|
// This is the base operation definition file for TensorFlow.
|
||
|
//
|
||
|
// This file includes the definition for the TensorFlow dialect, base TensorFlow
|
||
|
// op, and various commonly used TensorFlow traits, types, attributes, and
|
||
|
// builders.
|
||
|
|
||
|
#ifndef TF_OP_BASE
|
||
|
#define TF_OP_BASE
|
||
|
|
||
|
include "mlir/IR/OpBase.td"
|
||
|
include "mlir/Interfaces/SideEffectInterfaces.td"
|
||
|
|
||
|
//===----------------------------------------------------------------------===//
|
||
|
// TensorFlow dialect definitions
|
||
|
//===----------------------------------------------------------------------===//
|
||
|
|
||
|
// Dialect record for TensorFlow operations. Ops registered here use the `tf`
// mnemonic prefix and their generated C++ classes live in ::mlir::TF.
def TF_Dialect : Dialect {
  let name = "tf";
  let cppNamespace = "::mlir::TF";

  let description = [{
The TensorFlow dialect.

This dialect maps to TensorFlow operations.

Invariants:

* All values are of Tensor type (in particular, scalars are
  represented using zero-dimensional tensors);

TODO: Make invariants more structured so that we can reference them in ops.
  }];
}
|
||
|
|
||
|
//===----------------------------------------------------------------------===//
|
||
|
// TensorFlow traits
|
||
|
//===----------------------------------------------------------------------===//
|
||
|
|
||
|
// Native op traits: each def below names a C++ trait class in the TF
// namespace that implements the described check or marker.

// All results share a single type; every operand either has that same type or
// is the reference type corresponding to it.
def TF_OperandsSameAsResultsTypeOrRef :
    NativeOpTrait<"TF::OperandsSameAsResultsTypeOrRef">;

// Operands and results share one element type (or one type, for scalars)
// after reference types are resolved, i.e. converted to their corresponding
// TensorFlow or standard types.
def TF_SameOperandsAndResultElementTypeResolveRef :
    NativeOpTrait<"TF::SameOperandsAndResultElementTypeResolveRef">;

// Operands and results share one type after reference types are resolved,
// i.e. converted to their corresponding TensorFlow or standard types.
def TF_SameOperandsAndResultTypeResolveRef :
    NativeOpTrait<"TF::SameOperandsAndResultTypeResolveRef">;

// Marker for ops whose behavior does not depend on the data layout (data
// format) of their operands; every element-wise op is an example.
def TF_LayoutAgnostic : NativeOpTrait<"TF::LayoutAgnostic">;

// Marker for ops that must not be duplicated, since their implementation may
// carry state around.
def TF_CannotDuplicate : NativeOpTrait<"TF::CannotDuplicate">;

// Marker for ops that must not be constant folded.
def TF_NoConstantFold : NativeOpTrait<"TF::NoConstantFold">;

// Coefficient-wise binary op with implicit broadcasting, e.g. tf.Sub.
def TF_CwiseBinary : NativeOpTrait<"TF::CwiseBinary">;

// Coefficient-wise unary op, e.g. tf.Sqrt.
def TF_CwiseUnary : NativeOpTrait<"TF::CwiseUnary">;
|
||
|
|
||
|
// TF-specific flavor of the broadcastable check, taking TF's subtype behavior
// into account. Holds when TCOpResIsShapedTypePred<opId, resId> holds and
// mlir::TF::BroadcastCompatible accepts the type of operand `opId` together
// with the type of result `resId`.
class TF_OpIsBroadcastableToRes<int opId, int resId> : And<[
    TCOpResIsShapedTypePred<opId, resId>,
    CPred<"mlir::TF::BroadcastCompatible("
          "$_op.getOperand(" # opId # ").getType(), "
          "$_op.getResult(" # resId # ").getType())">
  ]>;
|
||
|
|
||
|
|
||
|
// Predicate over a list of C++ expressions, each yielding an mlir::Type.
// Holds when TF::AreCastCompatible accepts the whole list of types.
class TF_AllTypesMatchPred<list<string> values> :
    CPred<"TF::AreCastCompatible(llvm::makeArrayRef({" #
          !interleave(values, ", ") # "}))">;

// Op trait requiring the named ODS entities (operands/results) to have types
// that TF::AreCastCompatible reports as compatible ("dynamically equal").
// Each name `n` is turned into the C++ expression `$n.getType()` before being
// handed to TF_AllTypesMatchPred.
class TF_AllTypesMatch<list<string> names> :
    PredOpTrait<
      // No trailing space: this summary is embedded verbatim in verifier
      // error messages and generated documentation.
      "all of {" # !interleave(names, ", ") #
        "} have dynamically equal types",
      TF_AllTypesMatchPred<
        !foreach(n, names, !subst("$_self", "$" # n, "$_self.getType()"))>>;
|
||
|
|
||
|
//===----------------------------------------------------------------------===//
|
||
|
// Rank/Shape helpers.
|
||
|
//===----------------------------------------------------------------------===//
|
||
|
|
||
|
// Predicate: the type of operand `n` is an UnrankedTensorType.
class TF_OperandIsUnrankedPred<int n> :
    CPred<"$_op.getOperand(" # n # ").getType().isa<UnrankedTensorType>()">;

// Predicate: the type of result `n` is an UnrankedTensorType.
class TF_ResultIsUnrankedPred<int n> :
    CPred<"$_op.getResult(" # n # ").getType().isa<UnrankedTensorType>()">;

// Op trait: operand `n` is unranked, or it is a shaped type of rank exactly
// `m`.
class TF_OperandHasRank<int n, int m> :
    PredOpTrait<"operand " # n # " is " # m # "-D",
      Or<[
        TF_OperandIsUnrankedPred<n>,
        CPred<"$_op.getOperand(" # n #
              ").getType().cast<ShapedType>().getRank() == " # m>
      ]>>;

// Op trait: result `n` is unranked, or it is a shaped type of rank exactly
// `m`.
class TF_ResultHasRank<int n, int m> :
    PredOpTrait<"result " # n # " is " # m # "-D",
      Or<[
        TF_ResultIsUnrankedPred<n>,
        CPred<"$_op.getResult(" # n #
              ").getType().cast<ShapedType>().getRank() == " # m>
      ]>>;
|
||
|
|
||
|
//===----------------------------------------------------------------------===//
|
||
|
// TensorFlow op side effects
|
||
|
//===----------------------------------------------------------------------===//
|
||
|
|
||
|
// A side-effect resource whose C++ name is
// ::mlir::TF::ResourceEffects::<resourceKind>.
class TF_ResourceBase<string resourceKind> :
    Resource<"::mlir::TF::ResourceEffects::" # resourceKind>;

// Each group below defines one resource followed by the memory effects
// modeled on it. Not every resource has every effect kind.

// Variables: read, write, alloc.
def TF_VariableResource : TF_ResourceBase<"Variable">;
def TF_VariableRead : MemRead<TF_VariableResource>;
def TF_VariableWrite : MemWrite<TF_VariableResource>;
def TF_VariableAlloc : MemAlloc<TF_VariableResource>;

// Stacks: read, write, alloc, free.
def TF_StackResource : TF_ResourceBase<"Stack">;
def TF_StackRead : MemRead<TF_StackResource>;
def TF_StackWrite : MemWrite<TF_StackResource>;
def TF_StackAlloc : MemAlloc<TF_StackResource>;
def TF_StackFree : MemFree<TF_StackResource>;

// Tensor arrays: read, write, alloc, free.
def TF_TensorArrayResource : TF_ResourceBase<"TensorArray">;
def TF_TensorArrayRead : MemRead<TF_TensorArrayResource>;
def TF_TensorArrayWrite : MemWrite<TF_TensorArrayResource>;
def TF_TensorArrayAlloc : MemAlloc<TF_TensorArrayResource>;
def TF_TensorArrayFree : MemFree<TF_TensorArrayResource>;

// Summaries: write, alloc, free.
def TF_SummaryResource : TF_ResourceBase<"Summary">;
def TF_SummaryWrite : MemWrite<TF_SummaryResource>;
def TF_SummaryAlloc : MemAlloc<TF_SummaryResource>;
def TF_SummaryFree : MemFree<TF_SummaryResource>;

// Lookup tables: read, write, alloc.
def TF_LookupTableResource : TF_ResourceBase<"LookupTable">;
def TF_LookupTableRead : MemRead<TF_LookupTableResource>;
def TF_LookupTableWrite : MemWrite<TF_LookupTableResource>;
def TF_LookupTableAlloc : MemAlloc<TF_LookupTableResource>;

// Dataset seed generators: read, write, alloc, free.
def TF_DatasetSeedGeneratorResource : TF_ResourceBase<"DatasetSeedGenerator">;
def TF_DatasetSeedGeneratorRead : MemRead<TF_DatasetSeedGeneratorResource>;
def TF_DatasetSeedGeneratorWrite : MemWrite<TF_DatasetSeedGeneratorResource>;
def TF_DatasetSeedGeneratorAlloc : MemAlloc<TF_DatasetSeedGeneratorResource>;
def TF_DatasetSeedGeneratorFree : MemFree<TF_DatasetSeedGeneratorResource>;

// Dataset memory caches: read, write, alloc, free.
def TF_DatasetMemoryCacheResource : TF_ResourceBase<"DatasetMemoryCache">;
def TF_DatasetMemoryCacheRead : MemRead<TF_DatasetMemoryCacheResource>;
def TF_DatasetMemoryCacheWrite : MemWrite<TF_DatasetMemoryCacheResource>;
def TF_DatasetMemoryCacheAlloc : MemAlloc<TF_DatasetMemoryCacheResource>;
def TF_DatasetMemoryCacheFree : MemFree<TF_DatasetMemoryCacheResource>;

// Dataset iterators: read, write, alloc, free.
def TF_DatasetIteratorResource : TF_ResourceBase<"DatasetIterator">;
def TF_DatasetIteratorRead : MemRead<TF_DatasetIteratorResource>;
def TF_DatasetIteratorWrite : MemWrite<TF_DatasetIteratorResource>;
def TF_DatasetIteratorAlloc : MemAlloc<TF_DatasetIteratorResource>;
def TF_DatasetIteratorFree : MemFree<TF_DatasetIteratorResource>;

// TPU embedding: modeled only through a single write effect, packaged as a
// ready-to-use MemoryEffects trait.
def TF_TPUEmbeddingResource : TF_ResourceBase<"TPUEmbedding">;
def TF_TPUEmbeddingSideEffect :
    MemoryEffects<[MemWrite<TF_TPUEmbeddingResource>]>;
|
||
|
|
||
|
//===----------------------------------------------------------------------===//
|
||
|
// TensorFlow op definitions
|
||
|
//===----------------------------------------------------------------------===//
|
||
|
|
||
|
// Common base class for ops of the TensorFlow dialect: an Op registered in
// TF_Dialect under `mnemonic`, carrying the given `traits` (none by default).
class TF_Op<string mnemonic, list<OpTrait> traits = []> :
    Op<TF_Dialect, mnemonic, traits>;
|
||
|
|
||
|
//===----------------------------------------------------------------------===//
|
||
|
// TensorFlow attribute definitions
|
||
|
//===----------------------------------------------------------------------===//
|
||
|
|
||
|
// Attribute constraint satisfied by the C++ attribute class
// mlir::TF::<name>Attr. Its summary reads "TensorFlow <description>
// attribute".
class TF_TensorFlowAttr <string name, string description> :
    Attr<CPred<"$_self.isa<mlir::TF::" # name # "Attr>()">,
         "TensorFlow " # description # " attribute">;

// TensorFlow shape attribute. Accessors return the value of
// mlir::TF::ShapeAttr::getValue() as an optional dimension list (an empty
// Optional presumably means "unranked"; confirm against the C++ ShapeAttr).
def TF_ShapeAttr : TF_TensorFlowAttr<"Shape", "shape"> {
  let returnType = "llvm::Optional<llvm::ArrayRef<int64_t>>";
  let convertFromStorage = "$_self.cast<mlir::TF::ShapeAttr>().getValue()";

  // Constant builders produce a ranked shape attribute.
  let constBuilderCall = "mlir::TF::ShapeAttr::get($_builder.getContext(), $0)";
}

// Array attribute whose every element is a TF_ShapeAttr.
def TF_ShapeAttrArray :
    TypedArrayAttrBase<TF_ShapeAttr, "tensorflow shape attribute array">;
|
||
|
|
||
|
//===----------------------------------------------------------------------===//
|
||
|
// TensorFlow type definitions
|
||
|
//===----------------------------------------------------------------------===//
|
||
|
|
||
|
// Matches any type for which isa<mlir::TF::TensorFlowType>() holds, i.e. any
// element type defined by the TensorFlow dialect.
def TF_TFDialectType :
    Type<CPred<"$_self.isa<mlir::TF::TensorFlowType>()">, "TensorFlow type">;

// Constraint for one specific TensorFlow dialect type, the C++ class
// mlir::TF::<name>Type. It is also buildable, via
// getType<mlir::TF::<name>Type>().
class TF_TensorFlowType <string name, string description> :
    Type<CPred<"$_self.isa<mlir::TF::" # name # "Type>()">,
         "TensorFlow " # description # " type">,
    BuildableType<"getType<mlir::TF::" # name # "Type>()">;
|
||
|
|
||
|
//===----------------------------------------------------------------------===//
|
||
|
// Reference types
|
||
|
|
||
|
// Each def below binds a TableGen name to the TF dialect reference type
// mlir::TF::<first argument>Type.

// Floating-point references.
def TF_Float16Ref  : TF_TensorFlowType<"HalfRef", "f16ref">;
def TF_Float32Ref  : TF_TensorFlowType<"FloatRef", "f32ref">;
def TF_Float64Ref  : TF_TensorFlowType<"DoubleRef", "f64ref">;
def TF_Bfloat16Ref : TF_TensorFlowType<"Bfloat16Ref", "bf16ref">;

// Complex references.
def TF_Complex64Ref  : TF_TensorFlowType<"Complex64Ref", "complex64ref">;
def TF_Complex128Ref : TF_TensorFlowType<"Complex128Ref", "complex128ref">;

// Integer references.
def TF_Int8Ref  : TF_TensorFlowType<"Int8Ref", "i8ref">;
def TF_Int16Ref : TF_TensorFlowType<"Int16Ref", "i16ref">;
def TF_Int32Ref : TF_TensorFlowType<"Int32Ref", "i32ref">;
def TF_Int64Ref : TF_TensorFlowType<"Int64Ref", "i64ref">;

// Unsigned integer references.
def TF_Uint8Ref  : TF_TensorFlowType<"Uint8Ref", "ui8ref">;
def TF_Uint16Ref : TF_TensorFlowType<"Uint16Ref", "ui16ref">;
def TF_Uint32Ref : TF_TensorFlowType<"Uint32Ref", "ui32ref">;
def TF_Uint64Ref : TF_TensorFlowType<"Uint64Ref", "ui64ref">;

// Quantized references.
def TF_Qint8Ref   : TF_TensorFlowType<"Qint8Ref", "qint8ref">;
def TF_Qint16Ref  : TF_TensorFlowType<"Qint16Ref", "qint16ref">;
def TF_Qint32Ref  : TF_TensorFlowType<"Qint32Ref", "qint32ref">;
def TF_Quint8Ref  : TF_TensorFlowType<"Quint8Ref", "quint8ref">;
def TF_Quint16Ref : TF_TensorFlowType<"Quint16Ref", "quint16ref">;

// Remaining references: bool, resource, string, variant.
def TF_BoolRef     : TF_TensorFlowType<"BoolRef", "boolref">;
def TF_ResourceRef : TF_TensorFlowType<"ResourceRef", "resourceref">;
def TF_StrRef      : TF_TensorFlowType<"StringRef", "stringref">;
def TF_VariantRef  : TF_TensorFlowType<"VariantRef", "variantref">;
|
||
|
|
||
|
//===----------------------------------------------------------------------===//
|
||
|
// Integer types (including corresponding reference types)
|
||
|
|
||
|
// Boolean: i1 or its reference type.
def TF_Bool : AnyTypeOf<[I<1>, TF_BoolRef], "bool">;

// Fixed-width integers, each accepted together with its reference type.
def TF_Int8  : AnyTypeOf<[I8, TF_Int8Ref], "8-bit integer">;
def TF_Int16 : AnyTypeOf<[I16, TF_Int16Ref], "16-bit integer">;
def TF_Int32 : AnyTypeOf<[I32, TF_Int32Ref], "32-bit integer">;
def TF_Int64 : AnyTypeOf<[I64, TF_Int64Ref], "64-bit integer">;
def TF_I32OrI64 :
    AnyTypeOf<[I32, I64, TF_Int32Ref, TF_Int64Ref], "32/64-bit signed integer">;

// Unsigned integers, each accepted together with its reference type.
def TF_Uint8  : AnyTypeOf<[UI<8>, TF_Uint8Ref], "8-bit unsigned integer">;
def TF_Uint16 : AnyTypeOf<[UI<16>, TF_Uint16Ref], "16-bit unsigned integer">;
def TF_Uint32 : AnyTypeOf<[UI<32>, TF_Uint32Ref], "32-bit unsigned integer">;
def TF_Uint64 : AnyTypeOf<[UI<64>, TF_Uint64Ref], "64-bit unsigned integer">;

// Unions across widths: any unsigned, any signed, any integer.
def TF_UInt :
    AnyTypeOf<[TF_Uint8, TF_Uint16, TF_Uint32, TF_Uint64], "unsigned integer">;
def TF_SInt :
    AnyTypeOf<[TF_Int8, TF_Int16, TF_Int32, TF_Int64], "signed integer">;
def TF_Int : AnyTypeOf<[TF_SInt, TF_UInt], "integer">;

// Tensors over the boolean and integer element types above.
def TF_BoolTensor     : TensorOf<[TF_Bool]>;
def TF_IntTensor      : TensorOf<[TF_Int]>;
def TF_Int8Tensor     : TensorOf<[TF_Int8]>;
def TF_Int16Tensor    : TensorOf<[TF_Int16]>;
def TF_Int32Tensor    : TensorOf<[TF_Int32]>;
def TF_Int64Tensor    : TensorOf<[TF_Int64]>;
def TF_I32OrI64Tensor : TensorOf<[TF_I32OrI64]>;
def TF_Uint8Tensor    : TensorOf<[TF_Uint8]>;
def TF_Uint16Tensor   : TensorOf<[TF_Uint16]>;
def TF_Uint32Tensor   : TensorOf<[TF_Uint32]>;
def TF_Uint64Tensor   : TensorOf<[TF_Uint64]>;
|
||
|
|
||
|
//===----------------------------------------------------------------------===//
|
||
|
// Quantized types (including corresponding reference types)
|
||
|
|
||
|
// Quantized element types: the TF dialect type plus its reference type.
def TF_Qint8 : AnyTypeOf<[TF_TensorFlowType<"Qint8", "qint8">, TF_Qint8Ref],
                         "8-bit quantized integer">;
def TF_Qint16 : AnyTypeOf<[TF_TensorFlowType<"Qint16", "qint16">, TF_Qint16Ref],
                          "16-bit quantized integer">;
def TF_Qint32 : AnyTypeOf<[TF_TensorFlowType<"Qint32", "qint32">, TF_Qint32Ref],
                          "32-bit quantized integer">;
def TF_Quint8 : AnyTypeOf<[TF_TensorFlowType<"Quint8", "quint8">, TF_Quint8Ref],
                          "8-bit quantized unsigned integer">;
def TF_Quint16 :
    AnyTypeOf<[TF_TensorFlowType<"Quint16", "quint16">, TF_Quint16Ref],
              "16-bit quantized unsigned integer">;

// Union of every quantized type above.
def TF_Quantized :
    AnyTypeOf<[TF_Qint8, TF_Qint16, TF_Qint32, TF_Quint8, TF_Quint16],
              "quantized">;
|
||
|
|
||
|
//===----------------------------------------------------------------------===//
|
||
|
// Floating-point types (including corresponding reference types)
|
||
|
|
||
|
// Floating-point element types, each accepted with its reference type.
def TF_Float16  : AnyTypeOf<[F16, TF_Float16Ref], "16-bit float">;
def TF_Float32  : AnyTypeOf<[F32, TF_Float32Ref], "32-bit float">;
def TF_Float64  : AnyTypeOf<[F64, TF_Float64Ref], "64-bit float">;
def TF_Bfloat16 : AnyTypeOf<[BF16, TF_Bfloat16Ref], "bfloat16">;

// Subset and full union of the floating-point types above.
def TF_F32OrF64 : AnyTypeOf<[TF_Float32, TF_Float64], "32/64-bit float">;
def TF_Float :
    AnyTypeOf<[TF_Float16, TF_Float32, TF_Float64, TF_Bfloat16],
              "floating-point">;

// Tensors over the floating-point element types above.
def TF_FloatTensor    : TensorOf<[TF_Float]>;
def TF_F32OrF64Tensor : TensorOf<[TF_F32OrF64]>;
def TF_Float16Tensor  : TensorOf<[TF_Float16]>;
def TF_Float32Tensor  : TensorOf<[TF_Float32]>;
def TF_Float64Tensor  : TensorOf<[TF_Float64]>;
def TF_Bfloat16Tensor : TensorOf<[TF_Bfloat16]>;
|
||
|
|
||
|
//===----------------------------------------------------------------------===//
|
||
|
// Complex types (including corresponding reference types)
|
||
|
|
||
|
// Complex element types, each accepted with its reference type.
// TODO(suderman): Remove TF_Complex64 and use a standard ops declaration, along
// with the associated cleanup.
def TF_Complex64 :
    AnyTypeOf<[Complex<F<32>>, TF_Complex64Ref], "64-bit complex">;
def TF_Complex128 :
    AnyTypeOf<[Complex<F<64>>, TF_Complex128Ref], "128-bit complex">;
def TF_Complex : AnyTypeOf<[TF_Complex64, TF_Complex128], "complex">;

// Tensors over the complex element types above.
def TF_ComplexTensor    : TensorOf<[TF_Complex]>;
def TF_Complex64Tensor  : TensorOf<[TF_Complex64]>;
def TF_Complex128Tensor : TensorOf<[TF_Complex128]>;
|
||
|
|
||
|
//===----------------------------------------------------------------------===//
|
||
|
// String/variant/resource types (including corresponding reference types)
|
||
|
|
||
|
// Opaque TF dialect element types (string, variant, resource), each accepted
// with its reference type, plus the matching tensor constraint.
def TF_Str :
    AnyTypeOf<[TF_TensorFlowType<"String", "str">, TF_StrRef], "string">;
def TF_StrTensor : TensorOf<[TF_Str]>;

def TF_Variant :
    AnyTypeOf<[TF_TensorFlowType<"Variant", "var">, TF_VariantRef], "variant">;
def TF_VariantTensor : TensorOf<[TF_Variant]>;

def TF_Resource :
    AnyTypeOf<[TF_TensorFlowType<"Resource", "res">, TF_ResourceRef],
              "resource">;
def TF_ResourceTensor : TensorOf<[TF_Resource]>;
|
||
|
|
||
|
//===----------------------------------------------------------------------===//
|
||
|
// Multi-category type constraints
|
||
|
|
||
|
// Tensors whose element type may come from more than one category.
def TF_IntOrF32OrF64Tensor : TensorOf<[TF_Int, TF_F32OrF64]>;
def TF_FpOrI32OrI64Tensor  : TensorOf<[TF_Float, TF_I32OrI64]>;
def TF_IntOrFpTensor       : TensorOf<[TF_Int, TF_Float]>;
def TF_SintOrFpTensor      : TensorOf<[TF_SInt, TF_Float]>;
def TF_FpOrComplexTensor   : TensorOf<[TF_Float, TF_Complex]>;

// Any numeric element type: integer, float, quantized or complex.
def TF_Number :
    AnyTypeOf<[TF_Int, TF_Float, TF_Quantized, TF_Complex], "number">;
def TF_NumberTensor : TensorOf<[TF_Number]>;

// Float, signed integer, complex, 8-bit unsigned integer, or string. The
// wider unsigned integers and the quantized types are not included.
def TF_NumberNotQuantizedOrStr :
    AnyTypeOf<[TF_Float, TF_SInt, TF_Complex, TF_Uint8, TF_Str]>;
def TF_NumberNotQuantizedOrStrTensor : TensorOf<[TF_NumberNotQuantizedOrStr]>;
|
||
|
|
||
|
//===----------------------------------------------------------------------===//
|
||
|
// Tensor and tensor element types
|
||
|
|
||
|
// Any element type a TensorFlow tensor may carry: float, complex, integer,
// bool, or any TF dialect type (which covers quantized, string, variant,
// resource and the reference types).
// See https://www.tensorflow.org/api_docs/python/tf/dtypes/DType.
def TF_ElementType : Type<
    Or<[
      TF_Float.predicate,
      TF_Complex.predicate,
      TF_Int.predicate,
      TF_Bool.predicate,
      TF_TFDialectType.predicate
    ]>,
    "tf.dtype">;

// Any tensor whose element type satisfies TF_ElementType.
def TF_Tensor : TensorOf<[TF_ElementType]>;
|
||
|
|
||
|
//===----------------------------------------------------------------------===//
|
||
|
// TensorFlow attribute definitions
|
||
|
//===----------------------------------------------------------------------===//
|
||
|
|
||
|
//===----------------------------------------------------------------------===//
|
||
|
// TensorFlow devices metadata
|
||
|
|
||
|
// TensorFlow GPU device metadata: a struct attribute in the TF dialect whose
// two i32 fields hold the device compute capability as major.minor.
def TF_GpuDeviceMetadata : StructAttr<"GpuDeviceMetadata", TF_Dialect, [
    StructFieldAttr<"cc_major", I32Attr>,
    StructFieldAttr<"cc_minor", I32Attr>
  ]>;
|
||
|
|
||
|
//===----------------------------------------------------------------------===//
|
||
|
// String attribute constraints
|
||
|
|
||
|
// A string attribute whose value is one of the strings in `cases`.
// The predicate is "value == c0 || value == c1 || ..." and the summary reads
// "string attribute whose value is c0, or c1, ...". `cases` must be
// non-empty; `!head` rejects an empty list when TableGen evaluates the class.
class TF_AnyStrAttrOf<list<string> cases> : StringBasedAttr<
    CPred<!interleave(
        !foreach(str, cases,
                 "$_self.cast<StringAttr>().getValue() == \"" # str # "\""),
        " || ")>,
    "string attribute whose value is " #
        !foldl(!head(cases), !tail(cases), acc, cur, acc # ", or " # cur)>;
|
||
|
|
||
|
// TODO: Use EnumAttr to define the common attribute cases

// String attribute restricted to the convnet data formats "NHWC" and "NCHW".
def TF_ConvnetDataFormatAttr : StringBasedAttr<
    CPred<!interleave(["$_self.cast<StringAttr>().getValue() == \"NHWC\"",
                       "$_self.cast<StringAttr>().getValue() == \"NCHW\""],
                      " || ")>,
    "'NHWC' or 'NCHW' convnet data format">;
|
||
|
|
||
|
//===----------------------------------------------------------------------===//
|
||
|
// Type attributes
|
||
|
|
||
|
// Derived attribute: how many values are bound to the `idx`-th ODS-declared
// (variadic) operand. Materialized as an i64 integer attribute.
class TF_DerivedOperandSizeAttr<int idx> : DerivedAttr<
  "size_t",
  "return getODSOperands(" # idx # ").size();",
  [{ $_builder.getI64IntegerAttr($_self) }]>;

// Derived type attribute: the element type of the first value bound to the
// `idx`-th ODS-declared operand. For a variadic operand this is only
// meaningful when the pack is non-empty and all of its values share an
// element type; an empty pack dereferences an end iterator.
class TF_DerivedOperandTypeAttr<int idx> : DerivedTypeAttr<
  "auto range = getODSOperands(" # idx # ");\n"
  "return mlir::getElementTypeOrSelf(*range.begin());">;

// Derived attribute: the element type of every value bound to the `idx`-th
// ODS-declared variadic operand, for packs whose element types may differ.
// Materialized as an ArrayAttr of TypeAttr.
class TF_DerivedOperandTypeListAttr<int idx> : DerivedAttr<
  "mlir::OperandElementTypeRange",
  "auto values = getODSOperands(" # idx # ");\n"
  "return {mlir::OperandElementTypeIterator(values.begin()), "
  "mlir::OperandElementTypeIterator(values.end())};",
  [{
    ArrayAttr::get($_ctx, [&]() {
      llvm::SmallVector<Attribute, 4> ret;
      for (auto t : $_self)
        ret.push_back(TypeAttr::get(t));
      return ret;
    }())
  }]
>;

// Derived attribute: the shape of every value bound to the `idx`-th
// ODS-declared variadic operand, for packs whose shapes may differ.
// Materialized as an ArrayAttr of mlir::TF::ShapeAttr.
class TF_DerivedOperandShapeListAttr<int idx> : DerivedAttr<
  "::mlir::TF::OperandShapeRange",
  "auto values = getODSOperands(" # idx # ");\n"
  "return {mlir::TF::OperandShapeIterator(values.begin()), "
  "mlir::TF::OperandShapeIterator(values.end())};",
  [{
    ArrayAttr::get($_ctx, [&]() {
      llvm::SmallVector<Attribute, 4> ret;
      for (auto shape : $_self)
        ret.push_back(mlir::TF::ShapeAttr::get($_ctx, shape));
      return ret;
    }())
  }]
>;
|
||
|
|
||
|
// Derived attribute: how many values are bound to the `idx`-th ODS-declared
// (variadic) result. Materialized as an i64 integer attribute.
class TF_DerivedResultSizeAttr<int idx> : DerivedAttr<
  "size_t",
  "return getODSResults(" # idx # ").size();",
  [{ $_builder.getI64IntegerAttr($_self) }]>;

// Derived type attribute: the element type of the first value bound to the
// `idx`-th ODS-declared result. For a variadic result this is only
// meaningful when the pack is non-empty and all of its values share an
// element type; an empty pack dereferences an end iterator.
class TF_DerivedResultTypeAttr<int idx> : DerivedTypeAttr<
  "auto range = getODSResults(" # idx # ");\n"
  "return mlir::getElementTypeOrSelf(*range.begin());">;

// Derived attribute: the element type of every value bound to the `idx`-th
// ODS-declared variadic result, for packs whose element types may differ.
// Materialized as an ArrayAttr of TypeAttr.
class TF_DerivedResultTypeListAttr<int idx> : DerivedAttr<
  "mlir::ResultElementTypeRange",
  "auto values = getODSResults(" # idx # ");\n"
  "return {mlir::ResultElementTypeIterator(values.begin()), "
  "mlir::ResultElementTypeIterator(values.end())};",
  [{
    ArrayAttr::get($_ctx, [&]() {
      llvm::SmallVector<Attribute, 4> ret;
      for (auto t : $_self)
        ret.push_back(TypeAttr::get(t));
      return ret;
    }())
  }]
>;

// Derived attribute: the shape of every value bound to the `idx`-th
// ODS-declared variadic result, for packs whose shapes may differ.
// Materialized as an ArrayAttr of mlir::TF::ShapeAttr.
class TF_DerivedResultShapeListAttr<int idx> : DerivedAttr<
  "mlir::TF::ResultShapeRange",
  "auto values = getODSResults(" # idx # ");\n"
  "return {mlir::TF::ResultShapeIterator(values.begin()), "
  "mlir::TF::ResultShapeIterator(values.end())};",
  [{
    ArrayAttr::get($_ctx, [&]() {
      llvm::SmallVector<Attribute, 4> ret;
      for (auto shape : $_self)
        ret.push_back(mlir::TF::ShapeAttr::get($_ctx, shape));
      return ret;
    }())
  }]
>;

// Derived attribute: the type of the op's first result, cast to ShapedType
// and materialized as a mlir::TF::ShapeAttr. Assumes the op has at least one
// result.
def TF_DerivedResultShapeAttr : DerivedAttr<"ShapedType",
  "return (*getOperation()->result_type_begin()).cast<ShapedType>();",
  [{ mlir::TF::ShapeAttr::get($_ctx, $_self) }]>;
|
||
|
|
||
|
// Type attribute constrained to hold an IntegerType. Its accessor returns the
// stored type as a plain mlir Type.
def TF_IntTypeAttr : TypeAttrBase<"IntegerType", "integer type"> {
  let returnType = "Type";
}
|
||
|
|
||
|
//===----------------------------------------------------------------------===//
|
||
|
// TensorFlow common builders
|
||
|
//===----------------------------------------------------------------------===//
|
||
|
|
||
|
// Mixin providing an `(x, y)` builder for binary ops with broadcasting
// semantics. The result type comes from
// OpTrait::util::getBroadcastedType(x.getType(), y.getType()), so it shares
// the operands' element type.
//
// When the operands cannot be broadcast together, an error is emitted at the
// op's location and the result type falls back to an unranked tensor of x's
// element type. The op is never created with a null result type, which would
// crash on the first query of that result's type; the emitted error is what
// reports the failure.
class WithBroadcastableBinOpBuilder {
  list<OpBuilder> builders = [
    OpBuilder<(ins "Value":$x, "Value":$y),
    [{
  Type resultType =
      OpTrait::util::getBroadcastedType(x.getType(), y.getType());
  if (!resultType) {
    mlir::emitError($_state.location, "non-broadcastable operands");
    // Conservative, well-formed fallback instead of a null type.
    resultType =
        UnrankedTensorType::get(mlir::getElementTypeOrSelf(x.getType()));
  }
  return build($_builder, $_state, resultType, x, y);
}]>];
}
|
||
|
|
||
|
// Mixin providing an `(x, y)` builder for comparison ops with broadcasting
// semantics. The result element type is always i1.
//  - If either operand is unranked, the result is an unranked i1 tensor.
//  - Otherwise the result is a ranked i1 tensor with the broadcast shape of
//    the two operand shapes.
//  - If the shapes cannot be broadcast, an error is emitted at the op's
//    location and the result stays an unranked i1 tensor. It is never given
//    a shape derived from the failed broadcast computation.
class WithBroadcastableCmpOpBuilder {
  list<OpBuilder> builders = [
    OpBuilder<(ins "Value":$x, "Value":$y),
    [{
  // Unranked unless a valid broadcast shape is computed below.
  Type resultType = UnrankedTensorType::get($_builder.getI1Type());
  if (!x.getType().isa<UnrankedTensorType>() &&
      !y.getType().isa<UnrankedTensorType>()) {
    SmallVector<int64_t, 4> resultShape;
    if (OpTrait::util::getBroadcastedShape(
            x.getType().cast<ShapedType>().getShape(),
            y.getType().cast<ShapedType>().getShape(), resultShape)) {
      resultType = RankedTensorType::get(resultShape, $_builder.getI1Type());
    } else {
      mlir::emitError($_state.location,
                      "operands have no broadcastable shapes");
    }
  }
  return build($_builder, $_state, resultType, x, y);
}]>];
}
|
||
|
|
||
|
#endif // TF_OP_BASE
|