// Tencent is pleased to support the open source community by making ncnn available.
//
// Copyright (C) 2020 THL A29 Limited, a Tencent company. All rights reserved.
//
// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
// in compliance with the License. You may obtain a copy of the License at
//
// https://opensource.org/licenses/BSD-3-Clause
//
// Unless required by applicable law or agreed to in writing, software distributed
// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
// CONDITIONS OF ANY KIND, either express or implied. See the License for the
// specific language governing permissions and limitations under the License.

#include "tf_dialect.h"

#include <mlir/Dialect/Traits.h>
#include <mlir/IR/Attributes.h>
#include <mlir/IR/Builders.h>
#include <mlir/IR/Dialect.h>
#include <mlir/IR/DialectImplementation.h>
#include <mlir/IR/Location.h>
#include <mlir/IR/Matchers.h>
#include <mlir/IR/MLIRContext.h>
#include <mlir/IR/OpDefinition.h>
#include <mlir/IR/OpImplementation.h>
#include <mlir/IR/Operation.h>
#include <mlir/IR/OperationSupport.h>
#include <mlir/IR/PatternMatch.h>
#include <mlir/IR/TypeUtilities.h>
#include <mlir/IR/Types.h>
#include <mlir/IR/Value.h>
#include <mlir/IR/Verifier.h>
#include <mlir/Interfaces/CallInterfaces.h>
#include <mlir/Interfaces/DerivedAttributeOpInterface.h>
#include <mlir/Interfaces/InferTypeOpInterface.h>
#include <mlir/Interfaces/LoopLikeInterface.h>
#include <mlir/Interfaces/SideEffectInterfaces.h>
#include <mlir/Parser.h>
#include <mlir/Support/LogicalResult.h>
#include <mlir/Transforms/InliningUtils.h>

#include "tf_attributes.h"
#include "tf_side_effects.h"
#include "tf_traits.h"

namespace mlir {

static LogicalResult Verify(...)
{
    return success();
}
static LogicalResult VerifyPartitionedCall(...)
{
    return success();
}
static LogicalResult VerifyStridedSliceBase(...)
{
    return success();
}
static LogicalResult VerifyUnsortedSegmentReduction(...)
{
    return success();
}

namespace TF {

TensorFlowDialect::TensorFlowDialect(MLIRContext* context)
    : Dialect(/*name=*/"tf", context, TypeID::get<TensorFlowDialect>())
{
    addOperations<
#define GET_OP_LIST
#include "tf_all_ops.cc.inc"
    >();
    addTypes<
#define HANDLE_TF_TYPE(tftype, enumerant, name)      tftype##Type,
#define HANDLE_LAST_TF_TYPE(tftype, enumerant, name) tftype##Type
#include "tf_types.def"
    >();
    //   addInterfaces<TFInlinerInterface, TFDecodeAttributesInterface,
    //                 TFConstantFoldInterface>();
    addAttributes<ShapeAttr, FuncAttr>();

    // Support unknown operations because not all TensorFlow operations are
    // registered.
    allowUnknownOperations();

    //   for (const auto &hook : *TensorFlowDialect::additional_operation_hooks_) {
    //     hook(*this);
    //   }
}

namespace {

ShapeAttr ParseShapeAttr(MLIRContext* context, StringRef spec, Location loc)
{
    auto emit_error = [&, spec]() {
        emitError(loc, "invalid TensorFlow shape attribute: ") << spec;
        return nullptr;
    };

    if (!spec.consume_front("shape<")) return emit_error();

    if (spec.consume_front("*>"))
        return mlir::TF::ShapeAttr::get(context, llvm::None);

    SmallVector<int64_t, 4> shape;
    while (!spec.consume_front(">"))
    {
        int64_t dim;

        if (spec.consume_front("?"))
            dim = -1;
        else if (spec.consumeInteger(10, dim) || dim < 0)
            return emit_error();

        spec.consume_front("x");

        shape.push_back(dim);
    }

    return mlir::TF::ShapeAttr::get(context, llvm::makeArrayRef(shape));
}

// Parses a #tf.func attribute of the following format:
//
//   #tf.func<@symbol, {attr = "value"}>
//
// where the first element is a SymbolRefAttr and the second element is a
// DictionaryAttr.
FuncAttr ParseFuncAttr(MLIRContext* context, StringRef spec, Location loc)
{
    auto emit_error = [&, spec]() {
        emitError(loc, "invalid TensorFlow func attribute: ") << spec;
        return nullptr;
    };

    if (!spec.consume_front("func<")) return emit_error();

    size_t func_name_num_read = 0;
    Attribute func_name_attr = mlir::parseAttribute(spec, context, func_name_num_read);
    if (!func_name_attr || !func_name_attr.isa<SymbolRefAttr>())
        return emit_error();
    spec = spec.drop_front(func_name_num_read);

    if (!spec.consume_front(", ")) return emit_error();

    size_t func_attrs_num_read = 0;
    Attribute func_attrs_attr = mlir::parseAttribute(spec, context, func_attrs_num_read);
    if (!func_attrs_attr || !func_attrs_attr.isa<DictionaryAttr>())
        return emit_error();
    spec = spec.drop_front(func_attrs_num_read);

    if (!spec.consume_front(">")) return emit_error();

    return mlir::TF::FuncAttr::get(context, func_name_attr.cast<SymbolRefAttr>(),
                                   func_attrs_attr.cast<DictionaryAttr>());
}

} // namespace

Attribute TensorFlowDialect::parseAttribute(DialectAsmParser& parser,
        Type type) const
{
    auto spec = parser.getFullSymbolSpec();
    Location loc = parser.getEncodedSourceLoc(parser.getNameLoc());

    if (spec.startswith("shape")) return ParseShapeAttr(getContext(), spec, loc);

    if (spec.startswith("func")) return ParseFuncAttr(getContext(), spec, loc);

    return (emitError(loc, "unknown TensorFlow attribute: " + spec), nullptr);
}

// Parses a type registered to this dialect.
Type TensorFlowDialect::parseType(DialectAsmParser& parser) const
{
    StringRef data;
    if (parser.parseKeyword(&data)) return Type();

#define HANDLE_TF_TYPE(tftype, enumerant, name) \
    if (data == name) return tftype##Type::get(getContext());
// Custom TensorFlow types are handled separately at the end as they do partial
// match.
#define HANDLE_CUSTOM_TF_TYPE(tftype, enumerant, name)
// NOLINTNEXTLINE
#include "tf_types.def"

    llvm::SMLoc loc = parser.getNameLoc();
    if (data.startswith("resource"))
    {
        Type ret = ParseResourceType(parser);
        if (!ret) parser.emitError(loc, "invalid resource type");
        return ret;
    }
    if (data.startswith("variant"))
    {
        Type ret = ParseVariantType(parser);
        if (!ret) parser.emitError(loc, "invalid variant type");
        return ret;
    }
    return (parser.emitError(loc, "unknown TensorFlow type: " + data), nullptr);
}

namespace {
template<typename TypeWithSubtype>
Type ParseTypeWithSubtype(MLIRContext* context, DialectAsmParser& parser)
{
    // Default type without inferred subtypes.
    if (failed(parser.parseOptionalLess())) return TypeWithSubtype::get(context);

    // Most types with subtypes have only one subtype.
    SmallVector<TensorType, 1> subtypes;
    do
    {
        TensorType tensor_ty;
        if (parser.parseType(tensor_ty)) return Type();

        // Each of the subtypes should be a valid TensorFlow type.
        // TODO(jpienaar): Remove duplication.
        if (!IsValidTFTensorType(tensor_ty))
        {
            parser.emitError(parser.getNameLoc()) << "invalid subtype: " << tensor_ty;
            return Type();
        }
        subtypes.push_back(tensor_ty);
    } while (succeeded(parser.parseOptionalComma()));

    if (parser.parseGreater()) return Type();

    return TypeWithSubtype::get(subtypes, context);
}
} // anonymous namespace

Type TensorFlowDialect::ParseResourceType(DialectAsmParser& parser) const
{
    return ParseTypeWithSubtype<ResourceType>(getContext(), parser);
}

Type TensorFlowDialect::ParseVariantType(DialectAsmParser& parser) const
{
    return ParseTypeWithSubtype<VariantType>(getContext(), parser);
}

Operation* TensorFlowDialect::materializeConstant(OpBuilder& builder,
        Attribute value, Type type,
        Location loc)
{
    return builder.create<ConstOp>(loc, type, value);
}

// Builds a constant op with the specified attribute `value`. The result
// op's type is deduced from `value`; if `value` is of scalar type,
// wraps it up with a tensor type of empty shape.
// TODO(jpienaar): This one differs from the autogenerated one as it takes an
// attribute but always creates an ElementsAttr internally.
void ConstOp::build(OpBuilder& builder, OperationState& result,
                    Attribute value)
{
    ShapedType type;
    if (auto elem_attr = value.dyn_cast<ElementsAttr>())
    {
        return ConstOp::build(builder, result, elem_attr);
    }
    else if (value.isa<BoolAttr, FloatAttr, IntegerAttr>())
    {
        // All TensorFlow types must be tensor types. In the build() method,
        // we want to provide more flexibility by allowing attributes of scalar
        // types. But we need to wrap it up with ElementsAttr to construct
        // valid TensorFlow constants.
        type = RankedTensorType::get(/*shape=*/ {}, value.getType());
        return ConstOp::build(builder, result, DenseElementsAttr::get(type, value));
    }
    // TODO(jpienaar): support other TensorFlow specific types.
    llvm_unreachable("unsupported attribute type for building tf.Const");
}

void ConstOp::build(OpBuilder& builder, OperationState& result, Type type,
                    Attribute value)
{
    // Handle the case where the type and value are already tensors.
    if (type.isa<TensorType>() && value.isa<ElementsAttr>())
    {
        result.addTypes(type);
        result.addAttribute("value", value);
        return;
    }

    // Otherwise, default to the attribute builder.
    ConstOp::build(builder, result, value);
    assert(type == result.types[0] && "type mismatch in construction");
}

Region& WhileRegionOp::getLoopBody()
{
    return body();
}

bool WhileRegionOp::isDefinedOutsideOfLoop(Value value)
{
    // If the Op defining the value exists and the defining op is outside the
    // scope of this WhileRegion, then we can infer that its defined outside.
    // The defining Op is outside the scope of this WhileRegion if this
    // WhileRegionOp is not an ancestor of the defining op in the parent chain.
    Operation* def_op = value.getDefiningOp();
    return def_op && !getOperation()->isAncestor(def_op);
}

LogicalResult WhileRegionOp::moveOutOfLoop(
    llvm::ArrayRef<mlir::Operation*> ops)
{
    // Move the hoisted value to just before the while.
    Operation* while_op = this->getOperation();
    for (auto op : ops) op->moveBefore(while_op);
    return success();
}

} // namespace TF

} // namespace mlir

#define GET_OP_CLASSES
#include "tf_all_ops.cc.inc"