718c41634f
1.项目后端整体迁移至PaddleOCR-NCNN算法,已通过基本的兼容性测试 2.工程改为使用CMake组织,后续为了更好地兼容第三方库,不再提供QMake工程 3.重整权利声明文件,重整代码工程,确保最小化侵权风险 Log: 切换后端至PaddleOCR-NCNN,切换工程为CMake Change-Id: I4d5d2c5d37505a4a24b389b1a4c5d12f17bfa38c
164 lines
4.1 KiB
C++
164 lines
4.1 KiB
C++
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
|
|
|
Licensed under the Apache License, Version 2.0 (the "License");
|
|
you may not use this file except in compliance with the License.
|
|
You may obtain a copy of the License at
|
|
|
|
http://www.apache.org/licenses/LICENSE-2.0
|
|
|
|
Unless required by applicable law or agreed to in writing, software
|
|
distributed under the License is distributed on an "AS IS" BASIS,
|
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
See the License for the specific language governing permissions and
|
|
limitations under the License.
|
|
==============================================================================*/
|
|
|
|
#include "tf_attributes.h"
|
|
|
|
namespace mlir {
|
|
namespace TF {
|
|
|
|
namespace detail {
|
|
|
|
// The storage class for ShapeAttr.
|
|
struct ShapeAttrStorage : public AttributeStorage
|
|
{
|
|
using KeyTy = std::pair<ArrayRef<int64_t>, bool>;
|
|
|
|
explicit ShapeAttrStorage(ArrayRef<int64_t> shape, bool unranked = false)
|
|
: shape(shape), unranked(unranked)
|
|
{
|
|
}
|
|
|
|
bool operator==(const KeyTy& key) const
|
|
{
|
|
return key == KeyTy(shape, unranked);
|
|
}
|
|
static unsigned hashKey(const KeyTy& key)
|
|
{
|
|
return llvm::hash_combine(key.first, static_cast<char>(key.second));
|
|
}
|
|
|
|
// NOLINTNEXTLINE
|
|
static ShapeAttrStorage* construct(mlir::AttributeStorageAllocator& allocator,
|
|
const KeyTy& key)
|
|
{
|
|
return new (allocator.allocate<ShapeAttrStorage>())
|
|
ShapeAttrStorage(allocator.copyInto(key.first), key.second);
|
|
}
|
|
|
|
ArrayRef<int64_t> shape;
|
|
bool unranked = false;
|
|
};
|
|
|
|
// The storage class for FuncAttr.
|
|
struct FuncAttrStorage : public AttributeStorage
|
|
{
|
|
using KeyTy = std::pair<Attribute, Attribute>;
|
|
|
|
explicit FuncAttrStorage(Attribute name, Attribute attrs)
|
|
: name(name), attrs(attrs)
|
|
{
|
|
}
|
|
|
|
bool operator==(const KeyTy& key) const
|
|
{
|
|
return key == KeyTy(name, attrs);
|
|
}
|
|
static unsigned hashKey(const KeyTy& key)
|
|
{
|
|
return llvm::hash_combine(key.first, key.second);
|
|
}
|
|
|
|
static FuncAttrStorage* construct(mlir::AttributeStorageAllocator& allocator,
|
|
const KeyTy& key)
|
|
{
|
|
return new (allocator.allocate<FuncAttrStorage>())
|
|
FuncAttrStorage(key.first, key.second);
|
|
}
|
|
|
|
Attribute name;
|
|
Attribute attrs;
|
|
};
|
|
|
|
} // namespace detail
|
|
|
|
// Get or create a shape attribute.
|
|
ShapeAttr ShapeAttr::get(mlir::MLIRContext* context,
|
|
llvm::Optional<ArrayRef<int64_t> > shape)
|
|
{
|
|
if (shape) return Base::get(context, *shape, /*unranked=*/false);
|
|
|
|
return Base::get(context, ArrayRef<int64_t>(), /*unranked=*/true);
|
|
}
|
|
|
|
// Get or create a shape attribute.
|
|
ShapeAttr ShapeAttr::get(mlir::MLIRContext* context, ShapedType shaped_type)
|
|
{
|
|
if (shaped_type.hasRank())
|
|
return Base::get(context, shaped_type.getShape(), /*unranked=*/false);
|
|
|
|
return Base::get(context, ArrayRef<int64_t>(), /*unranked=*/true);
|
|
}
|
|
|
|
// Returns the dimensions when the attribute is ranked, llvm::None otherwise.
llvm::Optional<ArrayRef<int64_t> > ShapeAttr::getValue() const
{
    if (!hasRank())
    {
        return llvm::None;
    }

    return getShape();
}
|
|
|
|
bool ShapeAttr::hasRank() const
|
|
{
|
|
return !getImpl()->unranked;
|
|
}
|
|
|
|
int64_t ShapeAttr::getRank() const
|
|
{
|
|
assert(hasRank());
|
|
return getImpl()->shape.size();
|
|
}
|
|
|
|
// Returns the stored dimensions. Asserts that the attribute is ranked.
ArrayRef<int64_t> ShapeAttr::getShape() const
{
    assert(hasRank());

    const auto* storage = getImpl();
    return storage->shape;
}
|
|
|
|
bool ShapeAttr::hasStaticShape() const
|
|
{
|
|
if (!hasRank()) return false;
|
|
|
|
for (auto dim : getShape())
|
|
{
|
|
if (dim < 0) return false;
|
|
}
|
|
|
|
return true;
|
|
}
|
|
|
|
// Builds a FuncAttr from a plain name: the name is first turned into a symbol
// reference via SymbolRefAttr::get, then paired with `attr`.
FuncAttr FuncAttr::get(mlir::MLIRContext* context, llvm::StringRef name,
                       DictionaryAttr attr)
{
    return Base::get(context, SymbolRefAttr::get(context, name), attr);
}
|
|
|
|
// Builds a FuncAttr from an existing symbol reference and a dictionary of
// attributes; both are forwarded unchanged to Base::get.
FuncAttr FuncAttr::get(mlir::MLIRContext* context,
                       SymbolRefAttr symbol,
                       DictionaryAttr attr)
{
    return Base::get(context, symbol, attr);
}
|
|
|
|
// Returns the stored `name` attribute cast to SymbolRefAttr.
SymbolRefAttr FuncAttr::GetName() const
{
    const Attribute stored = getImpl()->name;
    return stored.cast<SymbolRefAttr>();
}
|
|
|
|
// Returns the stored `attrs` attribute cast to DictionaryAttr.
DictionaryAttr FuncAttr::GetAttrs() const
{
    const Attribute stored = getImpl()->attrs;
    return stored.cast<DictionaryAttr>();
}
|
|
|
|
} // namespace TF
|
|
} // namespace mlir
|