/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// This is the base operation definition file for TensorFlow.
//
// This file includes the definition for the TensorFlow dialect, base TensorFlow
// op, and various commonly used TensorFlow traits, types, attributes, and
// builders.

#ifndef TF_OP_BASE
#define TF_OP_BASE

include "mlir/IR/OpBase.td"
include "mlir/Interfaces/SideEffectInterfaces.td"

//===----------------------------------------------------------------------===//
// TensorFlow dialect definitions
//===----------------------------------------------------------------------===//

// The `tf` dialect record. Generated C++ for this dialect lives in the
// `::mlir::TF` namespace.
def TF_Dialect : Dialect {
  let name = "tf";

  let description = [{
The TensorFlow dialect.

This dialect maps to TensorFlow operations.

Invariants:

* All values are of Tensor type (in particular, scalars are
  represented using zero-dimensional tensors);

TODO: Make invariants more structured so that we can reference them in ops.
  }];

  let cppNamespace = "::mlir::TF";
}

//===----------------------------------------------------------------------===//
// TensorFlow traits
//===----------------------------------------------------------------------===//

// Specify this trait if the op requires all outputs to have the same type and
// the inputs either have the same type as result or a ref type corresponding to
// the result type.
def TF_OperandsSameAsResultsTypeOrRef : NativeOpTrait<
  "TF::OperandsSameAsResultsTypeOrRef">;

// Op has the same operand and result element types (or type itself, if scalar)
// after resolving reference types (i.e., after converting reference types to
// their corresponding TensorFlow or standard types).
def TF_SameOperandsAndResultElementTypeResolveRef : NativeOpTrait<
  "TF::SameOperandsAndResultElementTypeResolveRef">;

// Op has the same operand and result types after resolving reference types
// (i.e., after converting reference types to their corresponding TensorFlow or
// standard types).
def TF_SameOperandsAndResultTypeResolveRef : NativeOpTrait<
  "TF::SameOperandsAndResultTypeResolveRef">;

// Layout agnostic operations do not depend on the operands data layout (data
// format), as an example all element wise operations are layout agnostic.
def TF_LayoutAgnostic : NativeOpTrait<"TF::LayoutAgnostic">;

// Trait to indicate operations that cannot be duplicated as they might carry
// certain state around within their implementations.
def TF_CannotDuplicate : NativeOpTrait<"TF::CannotDuplicate">;

// Trait to indicate an operation cannot be constant folded.
def TF_NoConstantFold : NativeOpTrait<"TF::NoConstantFold">;

// Coefficient wise binary operation with implicit broadcasting support, for
// example tf.Sub operation.
def TF_CwiseBinary : NativeOpTrait<"TF::CwiseBinary">;

// Coefficient wise unary operation, for example tf.Sqrt operation.
def TF_CwiseUnary : NativeOpTrait<"TF::CwiseUnary">;

// Variant of broadcastable trait that considers TF's subtype behavior.
class TF_OpIsBroadcastableToRes<int opId, int resId> : And<[
    TCOpResIsShapedTypePred<opId, resId>,
    CPred<"mlir::TF::BroadcastCompatible("
              "$_op.getOperand(" # opId # ").getType(), "
              "$_op.getResult(" # resId # ").getType())">]>;

// Predicate that passes the given C++ type expressions, as one array, to
// `TF::AreCastCompatible`.
class TF_AllTypesMatchPred<list<string> values> :
    CPred<"TF::AreCastCompatible(llvm::makeArrayRef({" #
      !interleave(values, ", ") # "}))">;

// Trait built on TF_AllTypesMatchPred: each entry of `names` is turned into
// the expression `$<name>.getType()` before being handed to the predicate.
class TF_AllTypesMatch<list<string> names> :
    PredOpTrait<
        "all of {" # !interleave(names, ", ") #
          "} have dynamically equal types ",
        TF_AllTypesMatchPred<
            !foreach(n, names, !subst("$_self", "$" # n, "$_self.getType()"))>>;

//===----------------------------------------------------------------------===//
// Rank/Shape helpers.
//===----------------------------------------------------------------------===//

// Holds when the type of the n-th operand is an unranked tensor type.
class TF_OperandIsUnrankedPred<int n> :
  CPred<"$_op.getOperand(" # n # ").getType().isa<UnrankedTensorType>()">;

// Holds when the type of the n-th result is an unranked tensor type.
class TF_ResultIsUnrankedPred<int n> :
  CPred<"$_op.getResult(" # n # ").getType().isa<UnrankedTensorType>()">;

// Returns true if the n-th operand has unknown rank or has rank m.
class TF_OperandHasRank<int n, int m> :
  PredOpTrait<"operand " # n # " is " # m # "-D",
    Or<[TF_OperandIsUnrankedPred<n>,
      CPred<"$_op.getOperand(" # n #
      ").getType().cast<ShapedType>().getRank() == " # m>]>>;

// Returns true if the n-th result has unknown rank or has rank m.
class TF_ResultHasRank<int n, int m> :
  PredOpTrait<"result " # n # " is " # m # "-D",
    Or<[TF_ResultIsUnrankedPred<n>,
      CPred<"$_op.getResult(" # n #
      ").getType().cast<ShapedType>().getRank() == " # m>]>>;

//===----------------------------------------------------------------------===//
// TensorFlow op side effects
//===----------------------------------------------------------------------===//

// Base class for the TensorFlow side-effect resources declared below; the
// resource kind string selects the C++ resource class.
class TF_ResourceBase<string resourceKind> :
  Resource<!strconcat("::mlir::TF::ResourceEffects::", resourceKind)> {
}

def TF_VariableResource : TF_ResourceBase<"Variable">;
def TF_StackResource : TF_ResourceBase<"Stack">;
def TF_TensorArrayResource : TF_ResourceBase<"TensorArray">;
def TF_SummaryResource : TF_ResourceBase<"Summary">;
def TF_LookupTableResource : TF_ResourceBase<"LookupTable">;
def TF_DatasetSeedGeneratorResource : TF_ResourceBase<"DatasetSeedGenerator">;
def TF_DatasetMemoryCacheResource : TF_ResourceBase<"DatasetMemoryCache">;
def TF_DatasetIteratorResource : TF_ResourceBase<"DatasetIterator">;
def TF_TPUEmbeddingResource : TF_ResourceBase<"TPUEmbedding">;

def TF_VariableRead : MemRead<TF_VariableResource>;
def TF_StackRead : MemRead<TF_StackResource>;
def TF_TensorArrayRead : MemRead<TF_TensorArrayResource>;
def TF_LookupTableRead : MemRead<TF_LookupTableResource>;
def TF_DatasetSeedGeneratorRead : MemRead<TF_DatasetSeedGeneratorResource>;
def TF_DatasetMemoryCacheRead : MemRead<TF_DatasetMemoryCacheResource>;
def TF_DatasetIteratorRead : MemRead<TF_DatasetIteratorResource>;

def TF_VariableWrite : MemWrite<TF_VariableResource>;
def TF_StackWrite : MemWrite<TF_StackResource>;
def TF_TensorArrayWrite : MemWrite<TF_TensorArrayResource>;
def TF_SummaryWrite : MemWrite<TF_SummaryResource>;
def TF_LookupTableWrite : MemWrite<TF_LookupTableResource>;
def TF_DatasetSeedGeneratorWrite : MemWrite<TF_DatasetSeedGeneratorResource>;
def TF_DatasetMemoryCacheWrite : MemWrite<TF_DatasetMemoryCacheResource>;
def TF_DatasetIteratorWrite : MemWrite<TF_DatasetIteratorResource>;

def TF_VariableAlloc : MemAlloc<TF_VariableResource>;
def TF_StackAlloc : MemAlloc<TF_StackResource>;
def TF_TensorArrayAlloc : MemAlloc<TF_TensorArrayResource>;
def TF_SummaryAlloc : MemAlloc<TF_SummaryResource>;
def TF_LookupTableAlloc : MemAlloc<TF_LookupTableResource>;
def TF_DatasetSeedGeneratorAlloc : MemAlloc<TF_DatasetSeedGeneratorResource>;
def TF_DatasetMemoryCacheAlloc : MemAlloc<TF_DatasetMemoryCacheResource>;
def TF_DatasetIteratorAlloc : MemAlloc<TF_DatasetIteratorResource>;

def TF_StackFree : MemFree<TF_StackResource>;
def TF_TensorArrayFree : MemFree<TF_TensorArrayResource>;
def TF_SummaryFree : MemFree<TF_SummaryResource>;
def TF_DatasetSeedGeneratorFree : MemFree<TF_DatasetSeedGeneratorResource>;
def TF_DatasetMemoryCacheFree : MemFree<TF_DatasetMemoryCacheResource>;
def TF_DatasetIteratorFree : MemFree<TF_DatasetIteratorResource>;

def TF_TPUEmbeddingSideEffect : MemoryEffects<[MemWrite<TF_TPUEmbeddingResource>]>;

//===----------------------------------------------------------------------===//
// TensorFlow op definitions
//===----------------------------------------------------------------------===//

// Base class for ops in the TensorFlow dialect.
class TF_Op<string mnemonic, list<OpTrait> traits = []> :
    Op<TF_Dialect, mnemonic, traits>;

//===----------------------------------------------------------------------===//
// TensorFlow attribute definitions
//===----------------------------------------------------------------------===//

// Attribute constraint matching the C++ attribute class
// `mlir::TF::<name>Attr`.
class TF_TensorFlowAttr <string name, string description> :
    Attr<CPred<"$_self.isa<mlir::TF::" # name # "Attr>()">,
         "TensorFlow " # description # " attribute">;

def TF_ShapeAttr : TF_TensorFlowAttr<"Shape", "shape"> {
  let returnType = "llvm::Optional<llvm::ArrayRef<int64_t>>";
  let convertFromStorage = "$_self.cast<mlir::TF::ShapeAttr>().getValue()";

  // Create a ranked shape attr by default.
  let constBuilderCall = "mlir::TF::ShapeAttr::get($_builder.getContext(), $0)";
}

def TF_ShapeAttrArray :
    TypedArrayAttrBase<TF_ShapeAttr, "tensorflow shape attribute array">;

//===----------------------------------------------------------------------===//
// TensorFlow type definitions
//===----------------------------------------------------------------------===//

// Any tensor element type defined in the TensorFlow dialect
def TF_TFDialectType :
    Type<CPred<"$_self.isa<mlir::TF::TensorFlowType>()">, "TensorFlow type">;

// Class for any TensorFlow dialect specific type
class TF_TensorFlowType <string name, string description> :
    Type<CPred<"$_self.isa<mlir::TF::" # name # "Type>()">,
         "TensorFlow " # description # " type">,
    BuildableType<"getType<mlir::TF::" # name # "Type>()">;

//===----------------------------------------------------------------------===//
// Reference types

// Float reference types
def TF_Float16Ref : TF_TensorFlowType<"HalfRef", "f16ref">;
def TF_Float32Ref : TF_TensorFlowType<"FloatRef", "f32ref">;
def TF_Float64Ref : TF_TensorFlowType<"DoubleRef", "f64ref">;
def TF_Bfloat16Ref : TF_TensorFlowType<"Bfloat16Ref", "bf16ref">;

// Complex reference types
def TF_Complex64Ref : TF_TensorFlowType<"Complex64Ref", "complex64ref">;
def TF_Complex128Ref : TF_TensorFlowType<"Complex128Ref", "complex128ref">;

// Integer reference types
def TF_Int8Ref : TF_TensorFlowType<"Int8Ref", "i8ref">;
def TF_Int16Ref : TF_TensorFlowType<"Int16Ref", "i16ref">;
def TF_Int32Ref : TF_TensorFlowType<"Int32Ref", "i32ref">;
def TF_Int64Ref : TF_TensorFlowType<"Int64Ref", "i64ref">;

def TF_Uint8Ref : TF_TensorFlowType<"Uint8Ref", "ui8ref">;
def TF_Uint16Ref : TF_TensorFlowType<"Uint16Ref", "ui16ref">;
def TF_Uint32Ref : TF_TensorFlowType<"Uint32Ref", "ui32ref">;
def TF_Uint64Ref : TF_TensorFlowType<"Uint64Ref", "ui64ref">;

// Quantized reference types
def TF_Qint8Ref : TF_TensorFlowType<"Qint8Ref", "qint8ref">;
def TF_Qint16Ref : TF_TensorFlowType<"Qint16Ref", "qint16ref">;
def TF_Qint32Ref : TF_TensorFlowType<"Qint32Ref", "qint32ref">;
def TF_Quint8Ref : TF_TensorFlowType<"Quint8Ref", "quint8ref">;
def TF_Quint16Ref : TF_TensorFlowType<"Quint16Ref", "quint16ref">;

// Other reference types
def TF_BoolRef : TF_TensorFlowType<"BoolRef", "boolref">;
def TF_ResourceRef : TF_TensorFlowType<"ResourceRef", "resourceref">;
def TF_StrRef : TF_TensorFlowType<"StringRef", "stringref">;
def TF_VariantRef : TF_TensorFlowType<"VariantRef", "variantref">;

//===----------------------------------------------------------------------===//
// Integer types (including corresponding reference types)

def TF_Bool : AnyTypeOf<[I<1>, TF_BoolRef], "bool">;

def TF_Int8 : AnyTypeOf<[I8, TF_Int8Ref], "8-bit integer">;
def TF_Int16 : AnyTypeOf<[I16, TF_Int16Ref], "16-bit integer">;
def TF_Int32 : AnyTypeOf<[I32, TF_Int32Ref], "32-bit integer">;
def TF_Int64 : AnyTypeOf<[I64, TF_Int64Ref], "64-bit integer">;
def TF_I32OrI64 : AnyTypeOf<[I32, I64, TF_Int32Ref, TF_Int64Ref],
                           "32/64-bit signed integer">;

def TF_Uint8 : AnyTypeOf<[UI<8>, TF_Uint8Ref], "8-bit unsigned integer">;
def TF_Uint16 : AnyTypeOf<[UI<16>, TF_Uint16Ref], "16-bit unsigned integer">;
def TF_Uint32 : AnyTypeOf<[UI<32>, TF_Uint32Ref], "32-bit unsigned integer">;
def TF_Uint64 : AnyTypeOf<[UI<64>, TF_Uint64Ref], "64-bit unsigned integer">;

// Any unsigned integer type
def TF_UInt : AnyTypeOf<[TF_Uint8, TF_Uint16, TF_Uint32, TF_Uint64],
                        "unsigned integer">;

// Any signed integer type
def TF_SInt : AnyTypeOf<[TF_Int8, TF_Int16, TF_Int32, TF_Int64],
                        "signed integer">;

// Any integer type
def TF_Int : AnyTypeOf<[TF_SInt, TF_UInt], "integer">;

// Tensor types
def TF_BoolTensor : TensorOf<[TF_Bool]>;

def TF_IntTensor : TensorOf<[TF_Int]>;
def TF_Int8Tensor : TensorOf<[TF_Int8]>;
def TF_Int16Tensor : TensorOf<[TF_Int16]>;
def TF_Int32Tensor : TensorOf<[TF_Int32]>;
def TF_Int64Tensor : TensorOf<[TF_Int64]>;
def TF_I32OrI64Tensor : TensorOf<[TF_I32OrI64]>;

def TF_Uint8Tensor : TensorOf<[TF_Uint8]>;
def TF_Uint16Tensor : TensorOf<[TF_Uint16]>;
def TF_Uint32Tensor : TensorOf<[TF_Uint32]>;
def TF_Uint64Tensor : TensorOf<[TF_Uint64]>;

//===----------------------------------------------------------------------===//
// Quantized types (including corresponding reference types)

def TF_Qint8 : AnyTypeOf<
  [TF_TensorFlowType<"Qint8", "qint8">, TF_Qint8Ref],
  "8-bit quantized integer">;
def TF_Qint16 : AnyTypeOf<
  [TF_TensorFlowType<"Qint16", "qint16">, TF_Qint16Ref],
  "16-bit quantized integer">;
def TF_Qint32 : AnyTypeOf<
  [TF_TensorFlowType<"Qint32", "qint32">, TF_Qint32Ref],
  "32-bit quantized integer">;
def TF_Quint8 : AnyTypeOf<
  [TF_TensorFlowType<"Quint8", "quint8">, TF_Quint8Ref],
  "8-bit quantized unsigned integer">;
def TF_Quint16 : AnyTypeOf<
  [TF_TensorFlowType<"Quint16", "quint16">, TF_Quint16Ref],
  "16-bit quantized unsigned integer">;

// Any quantized type
def TF_Quantized : AnyTypeOf<
  [TF_Qint8, TF_Qint16, TF_Qint32, TF_Quint8, TF_Quint16], "quantized">;

//===----------------------------------------------------------------------===//
// Floating-point types (including corresponding reference types)

def TF_Float16 : AnyTypeOf<[F16, TF_Float16Ref], "16-bit float">;
def TF_Float32 : AnyTypeOf<[F32, TF_Float32Ref], "32-bit float">;
def TF_Float64 : AnyTypeOf<[F64, TF_Float64Ref], "64-bit float">;
def TF_Bfloat16 : AnyTypeOf<[BF16, TF_Bfloat16Ref], "bfloat16">;

def TF_F32OrF64 : AnyTypeOf<[TF_Float32, TF_Float64], "32/64-bit float">;

def TF_Float : AnyTypeOf<
  [TF_Float16, TF_Float32, TF_Float64, TF_Bfloat16],
  "floating-point">;

// Tensor types
def TF_FloatTensor : TensorOf<[TF_Float]>;
def TF_F32OrF64Tensor : TensorOf<[TF_F32OrF64]>;
def TF_Float16Tensor : TensorOf<[TF_Float16]>;
def TF_Float32Tensor : TensorOf<[TF_Float32]>;
def TF_Float64Tensor : TensorOf<[TF_Float64]>;
def TF_Bfloat16Tensor : TensorOf<[TF_Bfloat16]>;

//===----------------------------------------------------------------------===//
// Complex types (including corresponding reference types)

// TODO(suderman): Remove TF_Complex64 and use a standard ops declaration, along
// with the associated cleanup.
def TF_Complex64 : AnyTypeOf<[Complex<F<32>>, TF_Complex64Ref],
  "64-bit complex">;
def TF_Complex128 : AnyTypeOf<[Complex<F<64>>, TF_Complex128Ref],
  "128-bit complex">;
def TF_Complex : AnyTypeOf<[TF_Complex64, TF_Complex128], "complex">;

// Tensor types
def TF_ComplexTensor : TensorOf<[TF_Complex]>;
def TF_Complex64Tensor : TensorOf<[TF_Complex64]>;
def TF_Complex128Tensor : TensorOf<[TF_Complex128]>;

//===----------------------------------------------------------------------===//
// String/variant/resource types (including corresponding reference types)

def TF_Str : AnyTypeOf<
  [TF_TensorFlowType<"String", "str">, TF_StrRef], "string">;
def TF_StrTensor : TensorOf<[TF_Str]>;

def TF_Variant : AnyTypeOf<
  [TF_TensorFlowType<"Variant", "var">, TF_VariantRef], "variant">;
def TF_VariantTensor : TensorOf<[TF_Variant]>;

def TF_Resource : AnyTypeOf<
  [TF_TensorFlowType<"Resource", "res">, TF_ResourceRef], "resource">;
def TF_ResourceTensor : TensorOf<[TF_Resource]>;

//===----------------------------------------------------------------------===//
// Multi-category type constraints

def TF_IntOrF32OrF64Tensor: TensorOf<[TF_Int, TF_F32OrF64]>;
def TF_FpOrI32OrI64Tensor : TensorOf<[TF_Float, TF_I32OrI64]>;
def TF_IntOrFpTensor : TensorOf<[TF_Int, TF_Float]>;
def TF_SintOrFpTensor : TensorOf<[TF_SInt, TF_Float]>;
def TF_FpOrComplexTensor : TensorOf<[TF_Float, TF_Complex]>;

def TF_Number : AnyTypeOf<
  [TF_Int, TF_Float, TF_Quantized, TF_Complex], "number">;
def TF_NumberTensor : TensorOf<[TF_Number]>;

def TF_NumberNotQuantizedOrStr :
  AnyTypeOf<[TF_Float, TF_SInt, TF_Complex, TF_Uint8, TF_Str]>;
def TF_NumberNotQuantizedOrStrTensor : TensorOf<[TF_NumberNotQuantizedOrStr]>;

//===----------------------------------------------------------------------===//
// Tensor and tensor element types

// Any tensor element type allowed in TensorFlow ops
// (see https://www.tensorflow.org/api_docs/python/tf/dtypes/DType)
def TF_ElementType : Type<Or<[TF_Float.predicate,
                              TF_Complex.predicate,
                              TF_Int.predicate,
                              TF_Bool.predicate,
                              TF_TFDialectType.predicate]>,
                          "tf.dtype">;

// Any TensorFlow tensor type
def TF_Tensor : TensorOf<[TF_ElementType]>;

//===----------------------------------------------------------------------===//
// TensorFlow attribute definitions
//===----------------------------------------------------------------------===//

//===----------------------------------------------------------------------===//
// Tensorflow devices metadata

// Tensorflow GPU device metadata.
def TF_GpuDeviceMetadata : StructAttr<"GpuDeviceMetadata", TF_Dialect, [
    // GPU device compute capability: major:minor.
    StructFieldAttr<"cc_major", I32Attr>,
    StructFieldAttr<"cc_minor", I32Attr>
]>;

//===----------------------------------------------------------------------===//
// String attribute constraints

// A string attribute whose value is one of the values in `cases`.
// The predicate is an `||`-chain of equality tests against each case; the
// summary lists the cases joined with ", or ". `cases` must be non-empty
// because both folds start from `!head(cases)`.
class TF_AnyStrAttrOf<list<string> cases> : StringBasedAttr<
  CPred<!foldl(
      "$_self.cast<StringAttr>().getValue() == \"" # !head(cases) # "\"",
      !foreach(case, !tail(cases),
               "$_self.cast<StringAttr>().getValue() == \"" # case # "\""),
      prev, cur, prev # " || " # cur)>,
  "string attribute whose value is " #
    !foldl(/*init*/!head(cases), /*list*/!tail(cases),
           prev, cur, prev # ", or " # cur)>;

// TODO: Use EnumAttr to define the common attribute cases

def TF_ConvnetDataFormatAttr : StringBasedAttr<
    CPred<"$_self.cast<StringAttr>().getValue() == \"NHWC\" || " #
          "$_self.cast<StringAttr>().getValue() == \"NCHW\"">,
    "'NHWC' or 'NCHW' convnet data format">;

//===----------------------------------------------------------------------===//
// Type attributes

// A derived attribute that returns the size of `idx`-th ODS-declared variadic
// operand.
class TF_DerivedOperandSizeAttr<int idx> : DerivedAttr<
  "size_t",
  "auto range = getODSOperands(" # idx # ");\n"
  "return std::distance(range.begin(), range.end());",
  [{ $_builder.getI64IntegerAttr($_self) }]>;

// A derived attribute that returns the element type of `idx`-th ODS-declared
// operand. If the `idx`-th operand is a variadic operand, then this attribute
// just returns the element type of its first tensor, which is only meaningful
// when the variadic operand has at least one tensor and the tensors all have
// the same element type.
class TF_DerivedOperandTypeAttr<int idx> : DerivedTypeAttr<
  "return mlir::getElementTypeOrSelf(*getODSOperands(" # idx # ").begin());">;

// A derived attribute that returns the element types of the tensors in the
// actual value pack that corresponds to the `idx`-th ODS-declared variadic
// operand. This returns a list of element types so it is used for variadic
// operands that can have different element types.
class TF_DerivedOperandTypeListAttr<int idx> : DerivedAttr<
  "mlir::OperandElementTypeRange",
  "auto values = getODSOperands(" # idx # ");\n"
  "return {mlir::OperandElementTypeIterator(values.begin()), "
          "mlir::OperandElementTypeIterator(values.end())};",
  [{
    ArrayAttr::get($_ctx,
      [&]() {
        llvm::SmallVector<Attribute, 4> ret;
        for (auto t : $_self)
          ret.push_back(TypeAttr::get(t));
        return ret;
      }())
  }]
>;

// A derived attribute that returns the shapes of the tensors in the actual
// value pack that corresponds to the `idx`-th ODS-declared variadic operand.
// This returns a list of shapes so it is used for variadic operands that
// can have different shapes.
class TF_DerivedOperandShapeListAttr<int idx> : DerivedAttr<
  "::mlir::TF::OperandShapeRange",
  "auto values = getODSOperands(" # idx # ");\n"
  "return {mlir::TF::OperandShapeIterator(values.begin()), "
          "mlir::TF::OperandShapeIterator(values.end())};",
  [{
    ArrayAttr::get($_ctx,
      [&](){
        llvm::SmallVector<Attribute, 4> ret;
        for (auto shape : $_self)
          ret.push_back(mlir::TF::ShapeAttr::get($_ctx, shape));
        return ret;
      }())
  }]
>;

// A derived attribute that returns the size of `idx`-th ODS-declared variadic
// result.
class TF_DerivedResultSizeAttr<int idx> : DerivedAttr<
  "size_t",
  "auto range = getODSResults(" # idx # ");\n"
  "return std::distance(range.begin(), range.end());",
  [{ $_builder.getI64IntegerAttr($_self) }]>;

// A derived attribute that returns the element type of `idx`-th ODS-declared
// result. If the `idx`-th result is a variadic result, then this attribute
// just returns the element type of its first tensor, which is only meaningful
// when the variadic result has at least one tensor and the tensors all have
// the same element type.
class TF_DerivedResultTypeAttr<int idx> : DerivedTypeAttr<
  "return mlir::getElementTypeOrSelf(*getODSResults(" # idx # ").begin());">;

// A derived attribute that returns the element types of the tensors in the
// actual value pack that corresponds to the `idx`-th ODS-declared variadic
// result. This returns a list of element types so it is used for variadic
// results that can have different element types.
class TF_DerivedResultTypeListAttr<int idx> : DerivedAttr<
  "mlir::ResultElementTypeRange",
  "auto values = getODSResults(" # idx # ");\n"
  "return {mlir::ResultElementTypeIterator(values.begin()), "
          "mlir::ResultElementTypeIterator(values.end())};",
  [{
    ArrayAttr::get($_ctx,
      [&]() {
        llvm::SmallVector<Attribute, 4> ret;
        for (auto t : $_self)
          ret.push_back(TypeAttr::get(t));
        return ret;
      }())
  }]
>;

// A derived attribute that returns the shapes of the tensors in the actual
// value pack that corresponds to the `idx`-th ODS-declared variadic result.
// This returns a list of shapes so it is used for variadic results that
// can have different shapes.
class TF_DerivedResultShapeListAttr<int idx> : DerivedAttr<
  "mlir::TF::ResultShapeRange",
  "auto values = getODSResults(" # idx # ");\n"
  "return {mlir::TF::ResultShapeIterator(values.begin()), "
          "mlir::TF::ResultShapeIterator(values.end())};",
  [{
    ArrayAttr::get($_ctx,
      [&](){
        llvm::SmallVector<Attribute, 4> ret;
        for (auto shape : $_self)
          ret.push_back(mlir::TF::ShapeAttr::get($_ctx, shape));
        return ret;
      }())
  }]
>;

// A derived attribute that returns the shape of the first result type.
def TF_DerivedResultShapeAttr : DerivedAttr<"ShapedType",
  "return (*getOperation()->result_type_begin()).cast<ShapedType>();",
  [{ mlir::TF::ShapeAttr::get($_ctx, $_self) }]>;

def TF_IntTypeAttr : TypeAttrBase<"IntegerType", "integer type"> {
  let returnType = "Type";
}

//===----------------------------------------------------------------------===//
// TensorFlow common builders
//===----------------------------------------------------------------------===//

// Mixin class defining a builder for binary ops supporting broadcast
// behavior. The result type has the same element type as both operands.
class WithBroadcastableBinOpBuilder {
  list<OpBuilder> builders = [
    OpBuilder<(ins "Value":$x, "Value":$y),
    [{
  auto resultType =
      OpTrait::util::getBroadcastedType(x.getType(), y.getType());
  // On failure an error is emitted at the op location, but `build` is still
  // invoked with the (null) result type.
  if (!resultType)
    mlir::emitError($_state.location, "non-broadcastable operands");
  return build($_builder, $_state, resultType, x, y);
}]>];
}

// Mixin class defining a builder for comparison ops supporting broadcast
// behavior. The result type has bool element type.
class WithBroadcastableCmpOpBuilder {
  list<OpBuilder> builders = [
    OpBuilder<(ins "Value":$x, "Value":$y),
    [{
  Type resultType;
  // If either operand is unranked the result is an unranked i1 tensor;
  // otherwise the result shape is the broadcast of the two operand shapes.
  if (x.getType().isa<UnrankedTensorType>() ||
      y.getType().isa<UnrankedTensorType>()) {
    resultType = UnrankedTensorType::get($_builder.getI1Type());
  } else {
    SmallVector<int64_t, 4> resultShape;
    // On failure an error is emitted, and the result type is still built
    // from whatever `resultShape` holds.
    if (!OpTrait::util::getBroadcastedShape(
            x.getType().cast<ShapedType>().getShape(),
            y.getType().cast<ShapedType>().getShape(), resultShape)) {
      mlir::emitError($_state.location,
                      "operands have no broadcastable shapes");
    }

    resultType = RankedTensorType::get(resultShape, $_builder.getI1Type());
  }
  return build($_builder, $_state, resultType, x, y);
}]>];
}

#endif // TF_OP_BASE