// Tencent is pleased to support the open source community by making ncnn available.
//
// Copyright (C) 2020 THL A29 Limited, a Tencent company. All rights reserved.
//
// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
// in compliance with the License. You may obtain a copy of the License at
//
// https://opensource.org/licenses/BSD-3-Clause
//
// Unless required by applicable law or agreed to in writing, software distributed
// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
// CONDITIONS OF ANY KIND, either express or implied. See the License for the
// specific language governing permissions and limitations under the License.

// Implementation of the "tf" MLIR dialect: dialect registration, textual
// parsing of its shape/func attributes and resource/variant types, and a few
// hand-written op members (ConstOp builders, WhileRegionOp loop hooks).

#include "tf_dialect.h"

// NOTE(review): the header names of the following includes are missing from
// this copy of the file (it looks like everything between '<' and '>' was
// stripped). They must be restored from the upstream file before this can
// build.
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include

#include "tf_attributes.h"
#include "tf_side_effects.h"
#include "tf_traits.h"

namespace mlir {

// Variadic stubs: each accepts any arguments and unconditionally returns
// success().
// NOTE(review): presumably these exist so that verifier calls emitted into
// tf_all_ops.cc.inc resolve without pulling in the real verification logic --
// confirm against the generated code.
static LogicalResult Verify(...)
{
    return success();
}
static LogicalResult VerifyPartitionedCall(...)
{
    return success();
}
static LogicalResult VerifyStridedSliceBase(...)
{
    return success();
}
static LogicalResult VerifyUnsortedSegmentReduction(...)
{
    return success();
}

namespace TF {

// Constructs the dialect under the name "tf" and registers:
//   * the operations listed by tf_all_ops.cc.inc (under GET_OP_LIST),
//   * the types enumerated by tf_types.def (via the HANDLE_*_TF_TYPE macros),
//   * the dialect attributes,
// and finally allows operations that are not registered.
//
// NOTE(review): TypeID::get(), addInterfaces() and addAttributes() appear to
// have lost their template arguments in this copy -- restore from upstream.
TensorFlowDialect::TensorFlowDialect(MLIRContext* context)
    : Dialect(/*name=*/"tf", context, TypeID::get())
{
    addOperations<
#define GET_OP_LIST
#include "tf_all_ops.cc.inc"
    >();
    addTypes<
#define HANDLE_TF_TYPE(tftype, enumerant, name) tftype##Type,
#define HANDLE_LAST_TF_TYPE(tftype, enumerant, name) tftype##Type
#include "tf_types.def"
    >();
    // addInterfaces();
    addAttributes();

    // Support unknown operations because not all TensorFlow operations are
    // registered.
    allowUnknownOperations();

    // for (const auto &hook : *TensorFlowDialect::additional_operation_hooks_) {
    //   hook(*this);
    // }
}

namespace {

// Parses a shape attribute from `spec`.
//
// Accepted input, as implemented below:
//   * "shape<*>"  -> ShapeAttr built from llvm::None.
//   * "shape<" followed by dimensions and a closing ">", where each dimension
//     is either "?" (stored as -1) or a base-10 integer that must not be
//     negative. An "x" after a dimension is consumed if present; its absence
//     is not treated as an error here.
//
// On malformed input an error is emitted at `loc` and the result of
// emit_error() (nullptr) is returned.
ShapeAttr ParseShapeAttr(MLIRContext* context, StringRef spec, Location loc)
{
    // `spec` is captured by value so the diagnostic prints the complete
    // original text even though the local `spec` is consumed while parsing.
    auto emit_error = [&, spec]() {
        emitError(loc, "invalid TensorFlow shape attribute: ") << spec;
        return nullptr;
    };

    if (!spec.consume_front("shape<")) return emit_error();

    if (spec.consume_front("*>"))
        return mlir::TF::ShapeAttr::get(context, llvm::None);

    // NOTE(review): the template arguments of SmallVector are missing in this
    // copy; the elements pushed below are int64_t.
    SmallVector shape;
    while (!spec.consume_front(">"))
    {
        int64_t dim;

        if (spec.consume_front("?"))
            dim = -1;
        // A true result from consumeInteger is handled as a parse failure.
        else if (spec.consumeInteger(10, dim) || dim < 0)
            return emit_error();

        // Return value intentionally unused: the separator is optional here.
        spec.consume_front("x");

        shape.push_back(dim);
    }

    return mlir::TF::ShapeAttr::get(context, llvm::makeArrayRef(shape));
}

// Parses a #tf.func attribute of the following format:
//
//   #tf.func<@symbol, {attr = "value"}>
//
// where the first element is a SymbolRefAttr and the second element is a
// DictionaryAttr.
//
// The two elements are parsed with mlir::parseAttribute and must be separated
// by exactly ", ". On any mismatch an error is emitted at `loc` and the
// result of emit_error() (nullptr) is returned.
//
// NOTE(review): the isa()/cast() calls below have lost their template
// arguments in this copy; going by the comment above they are presumably
// SymbolRefAttr for the name and DictionaryAttr for the attributes -- confirm.
FuncAttr ParseFuncAttr(MLIRContext* context, StringRef spec, Location loc)
{
    // `spec` is captured by value so the diagnostic prints the complete
    // original text even though the local `spec` is consumed while parsing.
    auto emit_error = [&, spec]() {
        emitError(loc, "invalid TensorFlow func attribute: ") << spec;
        return nullptr;
    };

    if (!spec.consume_front("func<")) return emit_error();

    // First element: the function symbol. func_name_num_read is passed to
    // mlir::parseAttribute and then used to drop that many characters.
    size_t func_name_num_read = 0;
    Attribute func_name_attr = mlir::parseAttribute(spec, context, func_name_num_read);
    if (!func_name_attr || !func_name_attr.isa()) return emit_error();
    spec = spec.drop_front(func_name_num_read);

    if (!spec.consume_front(", ")) return emit_error();

    // Second element: the attribute dictionary, handled the same way.
    size_t func_attrs_num_read = 0;
    Attribute func_attrs_attr = mlir::parseAttribute(spec, context, func_attrs_num_read);
    if (!func_attrs_attr || !func_attrs_attr.isa()) return emit_error();
    spec = spec.drop_front(func_attrs_num_read);

    if (!spec.consume_front(">")) return emit_error();

    return mlir::TF::FuncAttr::get(context, func_name_attr.cast(), func_attrs_attr.cast());
}

} // namespace

// Parses an attribute of this dialect from the parser's full symbol spec.
// Specs starting with "shape" go to ParseShapeAttr and specs starting with
// "func" go to ParseFuncAttr; anything else emits an "unknown TensorFlow
// attribute" error and yields nullptr. The `type` parameter is not used in
// this function.
Attribute TensorFlowDialect::parseAttribute(DialectAsmParser& parser, Type type) const
{
    auto spec = parser.getFullSymbolSpec();
    Location loc = parser.getEncodedSourceLoc(parser.getNameLoc());

    if (spec.startswith("shape")) return ParseShapeAttr(getContext(), spec, loc);

    if (spec.startswith("func")) return ParseFuncAttr(getContext(), spec, loc);

    return (emitError(loc, "unknown TensorFlow attribute: " + spec), nullptr);
}

// Parses a type registered to this dialect.
//
// A keyword is read first and compared against each name listed in
// tf_types.def (exact match). If none matches, keywords starting with
// "resource" or "variant" are handed to ParseResourceType / ParseVariantType.
// Any other keyword emits an "unknown TensorFlow type" error and yields
// nullptr.
Type TensorFlowDialect::parseType(DialectAsmParser& parser) const
{
    StringRef data;
    if (parser.parseKeyword(&data)) return Type();

#define HANDLE_TF_TYPE(tftype, enumerant, name) \
    if (data == name) return tftype##Type::get(getContext());
// Custom TensorFlow types are handled separately at the end as they do partial
// match.
#define HANDLE_CUSTOM_TF_TYPE(tftype, enumerant, name)
// NOLINTNEXTLINE
#include "tf_types.def"

    llvm::SMLoc loc = parser.getNameLoc();
    if (data.startswith("resource"))
    {
        Type ret = ParseResourceType(parser);
        if (!ret) parser.emitError(loc, "invalid resource type");
        return ret;
    }
    if (data.startswith("variant"))
    {
        Type ret = ParseVariantType(parser);
        if (!ret) parser.emitError(loc, "invalid variant type");
        return ret;
    }
    return (parser.emitError(loc, "unknown TensorFlow type: " + data), nullptr);
}

namespace {
// Parses the optional subtype list of a type that carries subtypes.
//
// If no '<' follows, the default TypeWithSubtype::get(context) is returned.
// Otherwise a comma-separated list of tensor types is read up to the closing
// '>'; each entry must satisfy IsValidTFTensorType. Returns a null Type() when
// parsing fails or a subtype is invalid.
//
// NOTE(review): the template parameter list is missing in this copy; the body
// uses a type parameter named TypeWithSubtype. The SmallVector template
// arguments are missing as well (its elements are TensorType values).
template
Type ParseTypeWithSubtype(MLIRContext* context, DialectAsmParser& parser)
{
    // Default type without inferred subtypes.
    if (failed(parser.parseOptionalLess())) return TypeWithSubtype::get(context);

    // Most types with subtypes have only one subtype.
    SmallVector subtypes;
    do
    {
        TensorType tensor_ty;
        if (parser.parseType(tensor_ty)) return Type();

        // Each of the subtypes should be a valid TensorFlow type.
        // TODO(jpienaar): Remove duplication.
        if (!IsValidTFTensorType(tensor_ty))
        {
            parser.emitError(parser.getNameLoc()) << "invalid subtype: " << tensor_ty;
            return Type();
        }
        subtypes.push_back(tensor_ty);
    } while (succeeded(parser.parseOptionalComma()));

    if (parser.parseGreater()) return Type();
    return TypeWithSubtype::get(subtypes, context);
}
} // anonymous namespace

// Parses the remainder of a resource type by delegating to
// ParseTypeWithSubtype.
// NOTE(review): the explicit template argument of the call is missing in this
// copy (presumably the resource type class -- confirm against upstream).
Type TensorFlowDialect::ParseResourceType(DialectAsmParser& parser) const
{
    return ParseTypeWithSubtype(getContext(), parser);
}

// Parses the remainder of a variant type by delegating to
// ParseTypeWithSubtype.
// NOTE(review): the explicit template argument of the call is missing in this
// copy (presumably the variant type class -- confirm against upstream).
Type TensorFlowDialect::ParseVariantType(DialectAsmParser& parser) const
{
    return ParseTypeWithSubtype(getContext(), parser);
}

// Creates an operation at `loc` from `type` and `value` via the builder and
// returns it.
// NOTE(review): the op class argument of builder.create is missing in this
// copy (presumably the constant op -- confirm against upstream).
Operation* TensorFlowDialect::materializeConstant(OpBuilder& builder, Attribute value, Type type, Location loc)
{
    return builder.create(loc, type, value);
}

// Builds a constant op with the specified attribute `value`. The result
// op's type is deduced from `value`; if `value` is of scalar type,
// wraps it up with a tensor type of empty shape.
// TODO(jpienaar): This one differs from the autogenerated one as it takes an
// attribute but always creates an ElementsAttr internally.
//
// Attributes matching neither branch below reach llvm_unreachable.
// NOTE(review): the dyn_cast()/isa() template arguments are missing in this
// copy; per the comments they are presumably ElementsAttr for the first branch
// and the scalar attribute kinds for the second -- confirm against upstream.
void ConstOp::build(OpBuilder& builder, OperationState& result, Attribute value)
{
    ShapedType type;
    if (auto elem_attr = value.dyn_cast())
    {
        return ConstOp::build(builder, result, elem_attr);
    }
    else if (value.isa())
    {
        // All TensorFlow types must be tensor types. In the build() method,
        // we want to provide more flexibility by allowing attributes of scalar
        // types. But we need to wrap it up with ElementsAttr to construct
        // valid TensorFlow constants.
        type = RankedTensorType::get(/*shape=*/ {}, value.getType());
        return ConstOp::build(builder, result, DenseElementsAttr::get(type, value));
    }
    // TODO(jpienaar): support other TensorFlow specific types.
    llvm_unreachable("unsupported attribute type for building tf.Const");
}

// Builds a constant op from an explicit result `type` and a `value`.
// When both checks in the condition hold, `type` is added as the result type
// and `value` is stored under the "value" attribute directly. Otherwise the
// attribute-only builder is used, and an assert checks that the type it
// produced equals `type`.
// NOTE(review): the isa() template arguments are missing in this copy.
void ConstOp::build(OpBuilder& builder, OperationState& result, Type type, Attribute value)
{
    // Handle the case where the type and value are already tensors.
    if (type.isa() && value.isa())
    {
        result.addTypes(type);
        result.addAttribute("value", value);
        return;
    }

    // Otherwise, default to the attribute builder.
    ConstOp::build(builder, result, value);
    assert(type == result.types[0] && "type mismatch in construction");
}

// Returns the region obtained from body() as the loop body.
Region& WhileRegionOp::getLoopBody()
{
    return body();
}

// Returns true when `value` has a defining op and this op is not an ancestor
// of it. Values without a defining op yield false.
bool WhileRegionOp::isDefinedOutsideOfLoop(Value value)
{
    // If the Op defining the value exists and the defining op is outside the
    // scope of this WhileRegion, then we can infer that its defined outside.
    // The defining Op is outside the scope of this WhileRegion if this
    // WhileRegionOp is not an ancestor of the defining op in the parent chain.
    Operation* def_op = value.getDefiningOp();
    return def_op && !getOperation()->isAncestor(def_op);
}

// Moves every op in `ops` to directly before this op, in iteration order, and
// returns success().
// NOTE(review): the element type of llvm::ArrayRef is missing in this copy;
// the elements are used as pointers to operations (op->moveBefore).
LogicalResult WhileRegionOp::moveOutOfLoop(llvm::ArrayRef ops)
{
    // Move the hoisted value to just before the while.
    Operation* while_op = this->getOperation();
    for (auto op : ops) op->moveBefore(while_op);
    return success();
}

} // namespace TF
} // namespace mlir

// Pull in the op class definitions from tf_all_ops.cc.inc.
#define GET_OP_CLASSES
#include "tf_all_ops.cc.inc"