// Tencent is pleased to support the open source community by making ncnn available.
//
// Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved.
//
// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
// in compliance with the License. You may obtain a copy of the License at
//
// https://opensource.org/licenses/BSD-3-Clause
//
// Unless required by applicable law or agreed to in writing, software distributed
// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
// CONDITIONS OF ANY KIND, either express or implied. See the License for the
// specific language governing permissions and limitations under the License.
#include "ir.h"
#include <stdint.h>
#include <string.h>
#include <algorithm>
#include <fstream>
#include <sstream>
#include <string>
#include <stack>
#include <torch/script.h>
#include "storezip.h"
namespace pnnx {
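// Numeric tensor type codes used throughout pnnx:
//   0 = null/unknown, 1 = f32, 2 = f64, 3 = f16,
//   4 = i32, 5 = i64, 6 = i16, 7 = i8, 8 = u8, 9 = bool
// The helpers below convert between this encoding and its string,
// numpy dtype, torch dtype and element-size views.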
static bool type_is_integer(int type)
{
if (type == 1) return false;
if (type == 2) return false;
if (type == 3) return false;
if (type == 4) return true;
if (type == 5) return true;
if (type == 6) return true;
if (type == 7) return true;
if (type == 8) return true;
if (type == 9) return true;
return false;
}
static const char* type_to_string(int type)
{
if (type == 1) return "f32";
if (type == 2) return "f64";
if (type == 3) return "f16";
if (type == 4) return "i32";
if (type == 5) return "i64";
if (type == 6) return "i16";
if (type == 7) return "i8";
if (type == 8) return "u8";
if (type == 9) return "bool";
return "null";
}
static const char* type_to_numpy_string(int type)
{
if (type == 1) return "float32";
if (type == 2) return "float64";
if (type == 3) return "float16";
if (type == 4) return "int32";
if (type == 5) return "int64";
if (type == 6) return "int16";
if (type == 7) return "int8";
if (type == 8) return "uint8";
if (type == 9) return "bool8";
return "null";
}
static const char* type_to_dtype_string(int type)
{
if (type == 1) return "torch.float";
if (type == 2) return "torch.double";
if (type == 3) return "torch.half";
if (type == 4) return "torch.int";
if (type == 5) return "torch.long";
if (type == 6) return "torch.short";
if (type == 7) return "torch.int8";
if (type == 8) return "torch.uint8";
if (type == 9) return "torch.bool";
return "null";
}
static size_t type_to_elemsize(int type)
{
if (type == 1) return 4;
if (type == 2) return 8;
if (type == 3) return 2;
if (type == 4) return 4;
if (type == 5) return 8;
if (type == 6) return 2;
if (type == 7) return 1;
if (type == 8) return 1;
if (type == 9) return 1;
return 0; // null
}
static int string_to_type(const char* s)
{
if (strcmp(s, "f32") == 0) return 1;
if (strcmp(s, "f64") == 0) return 2;
if (strcmp(s, "f16") == 0) return 3;
if (strcmp(s, "i32") == 0) return 4;
if (strcmp(s, "i64") == 0) return 5;
if (strcmp(s, "i16") == 0) return 6;
if (strcmp(s, "i8") == 0) return 7;
if (strcmp(s, "u8") == 0) return 8;
if (strcmp(s, "bool") == 0) return 9;
return 0; // null
}
int get_at_tensor_type(const at::ScalarType& st)
{
if (st == c10::ScalarType::Float) return 1;
if (st == c10::ScalarType::Double) return 2;
if (st == c10::ScalarType::Half) return 3;
if (st == c10::ScalarType::Int) return 4;
if (st == c10::ScalarType::QInt32) return 4;
if (st == c10::ScalarType::Long) return 5;
if (st == c10::ScalarType::Short) return 6;
if (st == c10::ScalarType::Char) return 7;
if (st == c10::ScalarType::QInt8) return 7;
if (st == c10::ScalarType::Byte) return 8;
if (st == c10::ScalarType::QUInt8) return 8;
if (st == c10::ScalarType::Bool) return 9;
return 0; // unknown type
}
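// Build a Parameter from a torchscript node.
// prim::Constant yields a scalar parameter (None/bool/int/float/str, or a
// 0-dim tensor demoted to int/float); prim::ListConstruct yields an
// int/float/string array parameter.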
Parameter::Parameter(const torch::jit::Node* value_node)
{
type = 0;
if (value_node->kind() == c10::prim::Constant)
{
if (!value_node->hasAttribute(torch::jit::attr::value))
{
fprintf(stderr, "no attribute value\n");
return;
}
switch (value_node->output()->type()->kind())
{
case c10::TypeKind::NoneType:
{
type = 0;
break;
}
case c10::TypeKind::BoolType:
{
type = 1;
b = value_node->i(torch::jit::attr::value);
break;
}
case c10::TypeKind::IntType:
{
type = 2;
i = (int)value_node->i(torch::jit::attr::value);
break;
}
case c10::TypeKind::FloatType:
{
type = 3;
f = (float)value_node->f(torch::jit::attr::value);
break;
}
case c10::TypeKind::StringType:
{
type = 4;
s = value_node->s(torch::jit::attr::value);
break;
}
case c10::TypeKind::TensorType:
{
at::Tensor t = value_node->t(torch::jit::attr::value);
if (t.dim() == 0)
{
if (t.scalar_type() == c10::ScalarType::Long)
{
type = 2;
i = (int)t.item<int64_t>();
}
else if (t.scalar_type() == c10::ScalarType::Int)
{
type = 2;
i = t.item<int>();
}
else if (t.scalar_type() == c10::ScalarType::Double)
{
type = 3;
f = (float)t.item<double>();
}
else if (t.scalar_type() == c10::ScalarType::Float)
{
type = 3;
f = t.item<float>();
}
else
{
fprintf(stderr, "unknown Parameter value kind %s of TensorType, t.dim = 0\n", value_node->kind().toDisplayString());
}
}
else
{
const int ndim = (int)t.dim();
type = 8; // unsupported: a multi-dim tensor constant cannot be held in a Parameter
fprintf(stderr, "unknown Parameter value kind %s of TensorType, t.dim = %d\n", value_node->kind().toDisplayString(), ndim);
}
break;
}
default:
{
fprintf(stderr, "unknown Parameter value kind %s\n", value_node->kind().toDisplayString());
break;
}
}
}
else if (value_node->kind() == c10::prim::ListConstruct)
{
switch (value_node->output()->type()->cast<c10::ListType>()->getElementType()->kind())
{
case c10::TypeKind::IntType:
{
type = 5;
for (const auto& x : value_node->inputs())
{
ai.push_back((int)x->node()->i(torch::jit::attr::value));
}
break;
}
case c10::TypeKind::FloatType:
{
type = 6;
for (const auto& x : value_node->inputs())
{
af.push_back((float)x->node()->f(torch::jit::attr::value));
}
break;
}
case c10::TypeKind::StringType:
{
type = 7;
for (const auto& x : value_node->inputs())
{
as.push_back(x->node()->s(torch::jit::attr::value));
}
break;
}
default:
{
fprintf(stderr, "unknown Parameter value kind %s\n", value_node->kind().toDisplayString());
break;
}
}
}
else
{
fprintf(stderr, "unknown Parameter value kind %s\n", value_node->kind().toDisplayString());
}
}
Parameter::Parameter(const torch::jit::Value* value)
: Parameter(value->node())
{
}
bool operator==(const Parameter& lhs, const Parameter& rhs)
{
if (lhs.type != rhs.type)
return false;
if (lhs.type == 0)
return true;
if (lhs.type == 1 && lhs.b == rhs.b)
return true;
if (lhs.type == 2 && lhs.i == rhs.i)
return true;
if (lhs.type == 3 && lhs.f == rhs.f)
return true;
if (lhs.type == 4 && lhs.s == rhs.s)
return true;
if (lhs.type == 5 && lhs.ai == rhs.ai)
return true;
if (lhs.type == 6 && lhs.af == rhs.af)
return true;
if (lhs.type == 7 && lhs.as == rhs.as)
return true;
return false;
}
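// Attribute holds a tensor's dtype, shape and raw bytes.
// A 0-dim tensor is stored with shape {1}; otherwise the data is the
// flattened contiguous cpu copy of the tensor.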
Attribute::Attribute(const at::Tensor& t)
{
type = get_at_tensor_type(t.scalar_type());
const int ndim = (int)t.dim();
if (ndim == 0)
{
shape = {1};
data.resize(type_to_elemsize(type));
if (t.scalar_type() == c10::ScalarType::Long)
{
int64_t i = t.item<int64_t>();
memcpy((void*)data.data(), (const void*)&i, data.size());
}
else if (t.scalar_type() == c10::ScalarType::Int)
{
int i = t.item<int>();
memcpy((void*)data.data(), (const void*)&i, data.size());
}
else if (t.scalar_type() == c10::ScalarType::Double)
{
double f = t.item<double>();
memcpy((void*)data.data(), (const void*)&f, data.size());
}
else if (t.scalar_type() == c10::ScalarType::Float)
{
float f = t.item<float>();
memcpy((void*)data.data(), (const void*)&f, data.size());
}
else
{
fprintf(stderr, "unknown Attribute tensor scalar type %d\n", type);
}
return;
}
shape.resize(ndim);
for (int i = 0; i < ndim; i++)
shape[i] = t.size(i);
if (shape.size() > 0)
{
int size = shape[0];
for (size_t i = 1; i < shape.size(); i++)
{
size *= shape[i];
}
data.resize(size * type_to_elemsize(type));
memcpy((void*)data.data(), (const void*)t.cpu().contiguous().data_ptr(), data.size());
}
}
Attribute::Attribute(const std::initializer_list<int>& _shape, const std::vector<float>& t)
{
type = 1;
shape = _shape;
if (shape.size() > 0)
{
int size = shape[0];
for (size_t i = 1; i < shape.size(); i++)
{
size *= shape[i];
}
data.resize(size * type_to_elemsize(type));
memcpy((void*)data.data(), (const void*)t.data(), data.size());
}
}
bool operator==(const Attribute& lhs, const Attribute& rhs)
{
if (lhs.type != rhs.type)
return false;
if (lhs.type == 0)
return true;
if (lhs.shape != rhs.shape)
return false;
if (lhs.data != rhs.data)
return false;
return true;
}
Attribute operator+(const Attribute& a, const Attribute& b)
{
Attribute c;
if (a.type != b.type)
{
fprintf(stderr, "concat attribute type mismatch\n");
return c;
}
if (a.shape.size() != b.shape.size())
{
fprintf(stderr, "concat attribute shape rank mismatch\n");
return c;
}
for (int i = 1; i < (int)a.shape.size(); i++)
{
if (a.shape[i] != b.shape[i])
{
fprintf(stderr, "concat attribute shape mismatch\n");
return c;
}
}
c.type = a.type;
c.shape = a.shape;
c.shape[0] += b.shape[0]; // concat the first dim
c.data.resize(a.data.size() + b.data.size());
memcpy(c.data.data(), a.data.data(), a.data.size());
memcpy(c.data.data() + a.data.size(), b.data.data(), b.data.size());
return c;
}
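// Parse a parameter from its textual form, e.g.
//   "None"  "True"  "7"  "0.1"  "torch.float"  "(1,2,3)"  "[0.5,0.5]"
// The kind is inferred from the leading characters: a non-numeric start
// means string, a '.' or 'e' means float, otherwise integer; '(' or '['
// starts a list whose elements are classified the same way.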
Parameter Parameter::parse_from_string(const std::string& value)
{
Parameter p;
p.type = 0;
if (value == "None" || value == "()" || value == "[]")
{
return p;
}
if (value == "True" || value == "False")
{
// bool
p.type = 1;
p.b = value == "True";
return p;
}
if (value[0] == '(' || value[0] == '[')
{
// list
std::string lc = value.substr(1, value.size() - 2);
std::istringstream lcss(lc);
while (!lcss.eof())
{
std::string elem;
std::getline(lcss, elem, ',');
if ((elem[0] != '-' && (elem[0] < '0' || elem[0] > '9')) || (elem[0] == '-' && (elem[1] < '0' || elem[1] > '9')))
{
// string
p.type = 7;
p.as.push_back(elem);
}
else if (elem.find('.') != std::string::npos || elem.find('e') != std::string::npos)
{
// float
p.type = 6;
p.af.push_back(std::stof(elem));
}
else
{
// integer
p.type = 5;
p.ai.push_back(std::stoi(elem));
}
}
return p;
}
if ((value[0] != '-' && (value[0] < '0' || value[0] > '9')) || (value[0] == '-' && (value[1] < '0' || value[1] > '9')))
{
// string
p.type = 4;
p.s = value;
return p;
}
if (value.find('.') != std::string::npos || value.find('e') != std::string::npos)
{
// float
p.type = 3;
p.f = std::stof(value);
return p;
}
// integer
p.type = 2;
p.i = std::stoi(value);
return p;
}
Graph::Graph()
{
}
Graph::~Graph()
{
for (auto x : ops)
delete x;
for (auto x : operands)
delete x;
ops.clear();
operands.clear();
}
Graph::Graph(const Graph& /*rhs*/)
{
}
Graph& Graph::operator=(const Graph& /*rhs*/)
{
return *this;
}
static void load_parameter(Operator* op, const std::string& key, const std::string& value)
{
op->params[key] = Parameter::parse_from_string(value);
}
static void load_input_key(Operator* op, const std::string& key, const std::string& value)
{
op->inputnames.resize(op->inputs.size());
for (size_t i = 0; i < op->inputs.size(); i++)
{
const Operand* oprand = op->inputs[i];
if (oprand->name == value)
{
op->inputnames[i] = key;
break;
}
}
}
static void load_shape(Operator* op, const std::string& key, const std::string& value)
{
Operand* operand = 0;
for (auto r : op->inputs)
{
if (r->name == key)
{
operand = r;
break;
}
}
if (!operand)
{
for (auto r : op->outputs)
{
if (r->name == key)
{
operand = r;
break;
}
}
}
if (!operand)
{
fprintf(stderr, "no such operand %s for operator %s\n", key.c_str(), op->name.c_str());
return;
}
// type
std::string typestr = value.substr(value.find_last_of(')') + 1);
operand->type = string_to_type(typestr.c_str());
// shape
std::string lc = value.substr(1, value.find_last_of(')') - 1);
std::istringstream lcss(lc);
operand->shape.clear();
while (!lcss.eof())
{
std::string elem;
std::getline(lcss, elem, ',');
if (elem == "?")
{
operand->shape.push_back(-1);
}
else
{
int i = std::stoi(elem);
operand->shape.push_back(i);
}
}
}
static void load_attribute(Operator* op, const std::string& key, const std::string& value, StoreZipReader& szr)
{
Attribute& a = op->attrs[key];
// type
std::string typestr = value.substr(value.find_last_of(')') + 1);
a.type = string_to_type(typestr.c_str());
if (a.type == 0)
return;
// shape
std::string lc = value.substr(1, value.find_last_of(')') - 1);
std::istringstream lcss(lc);
a.shape.clear();
while (!lcss.eof())
{
std::string elem;
std::getline(lcss, elem, ',');
int i = std::stoi(elem);
a.shape.push_back(i);
}
if (a.shape.empty())
return;
// data
size_t size = 1;
for (int i : a.shape)
{
size *= i;
}
size_t bytesize = size * type_to_elemsize(a.type);
std::string filename = op->name + "." + key;
size_t filesize = szr.get_file_size(filename);
if (filesize == 0)
{
// no such file
return;
}
if (filesize != bytesize)
{
fprintf(stderr, "file size not match expect %lu but got %lu\n", bytesize, filesize);
}
a.data.resize(bytesize);
szr.read_file(filename, (char*)a.data.data());
}
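// The pnnx param text format, as written by Graph::save():
//   line 1: magic number 7767517
//   line 2: operator count and operand count
//   then one line per operator:
//     type name input_count output_count [input names...] [output names...] key=value...
//   where the key prefix selects the payload:
//     @key   attribute shape/type, data stored in the bin zip as "opname.key"
//     $key   named input binding
//     #key   operand shape and type, e.g. #out0=(1,3,?,?)f32 with ? for unknown dims
//     key    plain parameter
// A hypothetical line for illustration (not taken from a real model):
//   nn.Conv2d conv_0 1 1 in0 out0 bias=True @weight=(16,3,3,3)f32 #in0=(1,3,224,224)f32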
int Graph::load(const std::string& parampath, const std::string& binpath)
{
std::ifstream is(parampath, std::ios::in | std::ios::binary);
if (!is.good())
{
fprintf(stderr, "open failed\n");
return -1;
}
StoreZipReader szr;
if (szr.open(binpath) != 0)
{
fprintf(stderr, "open failed\n");
return -1;
}
int magic = 0;
{
std::string line;
std::getline(is, line);
std::istringstream iss(line);
iss >> magic;
}
int operator_count = 0;
int operand_count = 0;
{
std::string line;
std::getline(is, line);
std::istringstream iss(line);
iss >> operator_count >> operand_count;
}
for (int i = 0; i < operator_count; i++)
{
std::string line;
std::getline(is, line);
std::istringstream iss(line);
std::string type;
std::string name;
int input_count = 0;
int output_count = 0;
iss >> type >> name >> input_count >> output_count;
Operator* op = new_operator(type, name);
for (int j = 0; j < input_count; j++)
{
std::string operand_name;
iss >> operand_name;
Operand* r = get_operand(operand_name);
r->consumers.push_back(op);
op->inputs.push_back(r);
}
for (int j = 0; j < output_count; j++)
{
std::string operand_name;
iss >> operand_name;
Operand* r = new_operand(operand_name);
r->producer = op;
op->outputs.push_back(r);
}
// key=value
while (!iss.eof())
{
std::string param;
iss >> param;
std::string key;
std::string value;
std::istringstream pss(param);
std::getline(pss, key, '=');
std::getline(pss, value);
if (key[0] == '@')
{
// attribute
load_attribute(op, key.substr(1), value, szr);
}
else if (key[0] == '$')
{
// operand input key
load_input_key(op, key.substr(1), value);
}
else if (key[0] == '#')
{
// operand shape
load_shape(op, key.substr(1), value);
}
else
{
// parameter
load_parameter(op, key, value);
}
}
}
return 0;
}
int Graph::save(const std::string& parampath, const std::string& binpath)
{
FILE* paramfp = fopen(parampath.c_str(), "wb");
if (!paramfp)
{
fprintf(stderr, "fopen %s failed\n", parampath.c_str());
return -1;
}
StoreZipWriter szw;
if (szw.open(binpath) != 0)
{
fprintf(stderr, "open failed\n");
return -1;
}
// magic
fprintf(paramfp, "7767517\n");
// operator count and operand count
fprintf(paramfp, "%d %d\n", (int)ops.size(), (int)operands.size());
for (const Operator* op : ops)
{
fprintf(paramfp, "%-24s %-24s %d %d", op->type.c_str(), op->name.c_str(), (int)op->inputs.size(), (int)op->outputs.size());
for (const Operand* oprand : op->inputs)
{
fprintf(paramfp, " %s", oprand->name.c_str());
}
for (const Operand* oprand : op->outputs)
{
fprintf(paramfp, " %s", oprand->name.c_str());
}
for (const auto& it : op->params)
{
fprintf(paramfp, " %s=", it.first.c_str());
const Parameter& param = it.second;
if (param.type == 0)
{
fprintf(paramfp, "None");
}
if (param.type == 1)
{
if (param.b)
fprintf(paramfp, "True");
else
fprintf(paramfp, "False");
}
if (param.type == 2)
{
fprintf(paramfp, "%d", param.i);
}
if (param.type == 3)
{
fprintf(paramfp, "%e", param.f);
}
if (param.type == 4)
{
fprintf(paramfp, "%s", param.s.c_str());
}
if (param.type == 5)
{
fprintf(paramfp, "(");
for (size_t i = 0; i < param.ai.size(); i++)
{
fprintf(paramfp, "%d", param.ai[i]);
if (i + 1 != param.ai.size())
fprintf(paramfp, ",");
}
fprintf(paramfp, ")");
}
if (param.type == 6)
{
fprintf(paramfp, "(");
for (size_t i = 0; i < param.af.size(); i++)
{
fprintf(paramfp, "%e", param.af[i]);
if (i + 1 != param.af.size())
fprintf(paramfp, ",");
}
fprintf(paramfp, ")");
}
if (param.type == 7)
{
fprintf(paramfp, "(");
for (size_t i = 0; i < param.as.size(); i++)
{
fprintf(paramfp, "%s", param.as[i].c_str());
if (i + 1 != param.as.size())
fprintf(paramfp, ",");
}
fprintf(paramfp, ")");
}
}
for (const auto& it : op->attrs)
{
fprintf(paramfp, " @%s=", it.first.c_str());
const Attribute& attr = it.second;
fprintf(paramfp, "(");
for (int i = 0; i < (int)attr.shape.size() - 1; i++)
{
fprintf(paramfp, "%d,", attr.shape[i]);
}
if (attr.shape.size() > 0)
fprintf(paramfp, "%d", attr.shape[attr.shape.size() - 1]);
fprintf(paramfp, ")");
fprintf(paramfp, "%s", type_to_string(attr.type));
std::string filename = op->name + "." + it.first;
szw.write_file(filename, attr.data.data(), attr.data.size());
}
if (op->inputnames.size() == op->inputs.size())
{
for (size_t i = 0; i < op->inputs.size(); i++)
{
if (op->inputnames[i].empty())
continue;
const Operand* oprand = op->inputs[i];
fprintf(paramfp, " $%s=%s", op->inputnames[i].c_str(), oprand->name.c_str());
}
}
for (const Operand* oprand : op->inputs)
{
if (oprand->shape.empty())
continue;
fprintf(paramfp, " #%s=", oprand->name.c_str());
fprintf(paramfp, "(");
for (int i = 0; i < (int)oprand->shape.size() - 1; i++)
{
if (oprand->shape[i] == -1)
fprintf(paramfp, "?,");
else
fprintf(paramfp, "%d,", oprand->shape[i]);
}
if (oprand->shape.size() > 0)
{
if (oprand->shape[oprand->shape.size() - 1] == -1)
fprintf(paramfp, "?");
else
fprintf(paramfp, "%d", oprand->shape[oprand->shape.size() - 1]);
}
fprintf(paramfp, ")");
fprintf(paramfp, "%s", type_to_string(oprand->type));
}
for (const Operand* oprand : op->outputs)
{
if (oprand->shape.empty())
continue;
fprintf(paramfp, " #%s=", oprand->name.c_str());
fprintf(paramfp, "(");
for (int i = 0; i < (int)oprand->shape.size() - 1; i++)
{
if (oprand->shape[i] == -1)
fprintf(paramfp, "?,");
else
fprintf(paramfp, "%d,", oprand->shape[i]);
}
if (oprand->shape.size() > 0)
{
if (oprand->shape[oprand->shape.size() - 1] == -1)
fprintf(paramfp, "?");
else
fprintf(paramfp, "%d", oprand->shape[oprand->shape.size() - 1]);
}
fprintf(paramfp, ")");
fprintf(paramfp, "%s", type_to_string(oprand->type));
}
fprintf(paramfp, "\n");
}
fclose(paramfp);
return 0;
}
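// Make an operand/operator name usable as a Python identifier by
// replacing '.' and ':' with '_'.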
static std::string sanitize_identifier(const std::string& s)
{
std::string ss = s;
for (size_t i = 0; i < ss.size(); i++)
{
if (ss[i] == '.' || ss[i] == ':')
ss[i] = '_';
}
return ss;
}
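// Rebuild a Python expression from a pnnx.Expression "expr" param.
// The expr string is prefix function-call notation, e.g. a token stream
// like add(@0,mul(@1,2)) (illustrative); tokens are scanned right-to-left
// with a stack so operands are already on top when their operator is
// reached. "@N" refers to the operator's N-th input operand.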
static std::string expand_expression(const Operator* op)
{
std::string expr = op->params.at("expr").s;
// split into tokens
std::vector<std::string> tokens;
{
std::string t;
for (size_t i = 0; i < expr.size(); i++)
{
char ch = expr[i];
if (ch == '[') // list
{
t += ch;
tokens.push_back(t);
t.clear();
}
else if (ch == '(' || ch == ')' || ch == ',' || ch == ']')
{
if (!t.empty())
{
tokens.push_back(t);
t.clear();
}
}
else
{
t += ch;
}
}
if (!t.empty())
{
tokens.push_back(t);
}
}
// scan and stack
std::stack<std::string> exprstack;
for (int i = (int)tokens.size() - 1; i >= 0; i--)
{
const std::string& t = tokens[i];
if (t == "size")
{
std::string a = exprstack.top();
exprstack.pop();
std::string b = exprstack.top();
exprstack.pop();
std::string r = a + ".size(" + b + ")";
exprstack.push(r);
}
else if (t == "int" || t == "sqrt" || t == "rsqrt" || t == "neg")
{
std::string unaryop;
if (t == "int") unaryop = "int";
if (t == "sqrt") unaryop = "torch.sqrt";
if (t == "rsqrt") unaryop = "torch.rsqrt";
if (t == "neg") unaryop = "torch.neg";
std::string a = exprstack.top();
exprstack.pop();
std::string r = unaryop + "(" + a + ")";
exprstack.push(r);
}
else if (t == "pow")
{
std::string a = exprstack.top();
exprstack.pop();
std::string b = exprstack.top();
exprstack.pop();
std::string r = a + ".pow(" + b + ")";
exprstack.push(r);
}
else if (t == "add" || t == "sub" || t == "mul" || t == "div" || t == "floor_divide")
{
std::string binaryop;
if (t == "add") binaryop = "+";
if (t == "sub") binaryop = "-";
if (t == "mul") binaryop = "*";
if (t == "div") binaryop = "/";
if (t == "floor_divide") binaryop = "//";
std::string a = exprstack.top();
exprstack.pop();
std::string b = exprstack.top();
exprstack.pop();
std::string r = std::string("(") + a + " " + binaryop + " " + b + ")";
exprstack.push(r);
}
else if (t == "[") // list
{
std::vector<std::string> elements;
while (!exprstack.empty())
{
std::string a = exprstack.top();
exprstack.pop();
elements.push_back(a);
}
std::string r = "[";
for (int j = 0; j < (int)elements.size() - 1; j++)
{
r += elements[j];
r += ", "; // the last element is appended without a trailing separator below
}
if (!elements.empty())
{
r += elements[elements.size() - 1];
}
r += "]";
exprstack.push(r);
}
else if (t[0] == '@')
{
int input_index = std::stoi(t.substr(1));
std::string varid = std::string("v_") + sanitize_identifier(op->inputs[input_index]->name);
exprstack.push(varid);
}
else
{
// literal
exprstack.push(t);
}
}
std::string r = exprstack.top();
exprstack.pop();
return r;
}
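// Build the text inside a Python slice, e.g. "1:10,:,::2", from the
// "dims"/"starts"/"ends"/"steps" params; when a bound is not a constant
// param it is looked up among the named inputs ("start"/"end"/"step").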
static std::string make_slice_expression(const Operator* op)
{
// debug trace: dump the named inputs seen by this slice op
for (size_t j = 0; j < op->inputnames.size(); j++)
{
fprintf(stderr, "make_slice_expression %s %s\n", op->inputnames[j].c_str(), op->inputs[j]->name.c_str());
}
std::vector<int> dims = op->params.at("dims").ai;
std::string r;
int last_dim = -1;
const int ndim = (int)dims.size();
for (int i = 0; i < ndim; i++)
{
int dim = dims[i];
for (int j = last_dim + 1; j < dim; j++)
{
r += ":,";
}
last_dim = dim;
if (op->params.find("starts") != op->params.end())
{
std::vector<int> starts = op->params.at("starts").ai;
int start = starts[i];
if (start != 0)
r += std::to_string(start);
}
else
{
fprintf(stderr, "find start\n");
// find start
for (size_t j = 0; j < op->inputnames.size(); j++)
{
if (op->inputnames[j] == "start")
{
r += std::string("v_") + sanitize_identifier(op->inputs[j]->name);
fprintf(stderr, "find start %s\n", op->inputs[j]->name.c_str());
break;
}
}
}
r += ':';
if (op->params.find("ends") != op->params.end())
{
std::vector<int> ends = op->params.at("ends").ai;
int end = ends[i];
if (end != -1)
r += std::to_string(end);
}
else
{
// find end
for (size_t j = 0; j < op->inputnames.size(); j++)
{
if (op->inputnames[j] == "end")
{
r += std::string("v_") + sanitize_identifier(op->inputs[j]->name);
break;
}
}
}
if (op->params.find("steps") != op->params.end())
{
std::vector<int> steps = op->params.at("steps").ai;
int step = steps[i];
if (step != 1)
{
r += ':';
r += std::to_string(step);
}
}
else
{
// find step
for (size_t j = 0; j < op->inputnames.size(); j++)
{
if (op->inputnames[j] == "step")
{
r += ':';
r += std::string("v_") + sanitize_identifier(op->inputs[j]->name);
break;
}
}
}
if (i + 1 != ndim)
r += ',';
}
return r;
}
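// Turn a Tensor.index "expr" param into the text inside [], stripping
// the outermost brackets and collapsing leading "None," dims into "...".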
static std::string make_index_expression(const Operator* op)
{
fprintf(stderr, "make_index_expression %s\n", op->name.c_str());
std::string index_expr = op->params.at("expr").s;
// strip the outermost [ ] pair
index_expr = index_expr.substr(1, index_expr.size() - 2);
// None,None, -> ...,
bool leading_none = false;
while (index_expr.substr(0, 5) == "None,")
{
leading_none = true;
index_expr = index_expr.substr(5);
}
if (leading_none)
{
index_expr = "...," + index_expr;
}
return index_expr;
}
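// Emit a standalone Python script that rebuilds the model as a
// torch.nn.Module: nn.* / torchvision.ops.* operators become submodules,
// weights are loaded from the pnnx bin zip via numpy memmap, and the
// script gains export_torchscript() and test_inference() entry points
// that feed torch.rand / torch.randint inputs of the recorded shapes.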
int Graph::python(const std::string& pypath, const std::string& pnnxbinpath)
{
FILE* pyfp = fopen(pypath.c_str(), "wb");
if (!pyfp)
{
fprintf(stderr, "fopen %s failed\n", pypath.c_str());
return -1;
}
fprintf(pyfp, "import os\n");
fprintf(pyfp, "import numpy as np\n");
fprintf(pyfp, "import tempfile, zipfile\n");
fprintf(pyfp, "import torch\n");
fprintf(pyfp, "import torch.nn as nn\n");
fprintf(pyfp, "import torch.nn.functional as F\n");
fprintf(pyfp, "import torchvision\n");
fprintf(pyfp, "\n");
fprintf(pyfp, "class Model(nn.Module):\n");
fprintf(pyfp, " def __init__(self):\n");
fprintf(pyfp, " super(Model, self).__init__()\n");
fprintf(pyfp, "\n");
// module
{
for (const Operator* op : ops)
{
if (op->type.substr(0, 3) != "nn." && op->type.substr(0, 16) != "torchvision.ops.")
continue;
fprintf(pyfp, " self.%s = %s(", sanitize_identifier(op->name).c_str(), op->type.c_str());
int param_count = op->params.size();
if (op->type == "nn.quantized.Conv2d" || op->type == "nn.quantized.Linear")
{
param_count -= 2; // ignore scale and zero_point
}
int param_index = 0;
for (const auto& it : op->params)
{
if (op->type == "nn.quantized.Conv2d" || op->type == "nn.quantized.Linear")
{
if (it.first == "scale" || it.first == "zero_point")
continue;
}
fprintf(pyfp, "%s=", it.first.c_str());
const Parameter& param = it.second;
if (param.type == 0)
{
fprintf(pyfp, "None");
}
if (param.type == 1)
{
if (param.b)
fprintf(pyfp, "True");
else
fprintf(pyfp, "False");
}
if (param.type == 2)
{
fprintf(pyfp, "%d", param.i);
}
if (param.type == 3)
{
fprintf(pyfp, "%f", param.f);
}
if (param.type == 4)
{
if (param.s.substr(0, 6) == "torch.")
{
fprintf(pyfp, "%s", param.s.c_str());
}
else
{
fprintf(pyfp, "\'%s\'", param.s.c_str());
}
}
if (param.type == 5)
{
fprintf(pyfp, "(");
for (size_t i = 0; i < param.ai.size(); i++)
{
fprintf(pyfp, "%d", param.ai[i]);
if (i + 1 != param.ai.size() || param.ai.size() == 1)
fprintf(pyfp, ",");
}
fprintf(pyfp, ")");
}
if (param.type == 6)
{
fprintf(pyfp, "(");
for (size_t i = 0; i < param.af.size(); i++)
{
fprintf(pyfp, "%f", param.af[i]);
if (i + 1 != param.af.size() || param.af.size() == 1)
fprintf(pyfp, ",");
}
fprintf(pyfp, ")");
}
if (param.type == 7)
{
fprintf(pyfp, "(");
for (size_t i = 0; i < param.as.size(); i++)
{
if (param.as[i].substr(0, 6) == "torch.")
{
fprintf(pyfp, "%s", param.as[i].c_str());
}
else
{
fprintf(pyfp, "\'%s\'", param.as[i].c_str());
}
if (i + 1 != param.as.size() || param.as.size() == 1)
fprintf(pyfp, ",");
}
fprintf(pyfp, ")");
}
param_index++;
if (param_index != param_count)
fprintf(pyfp, ", ");
}
fprintf(pyfp, ")\n");
}
}
fprintf(pyfp, "\n");
// load weights
{
fprintf(pyfp, " archive = zipfile.ZipFile('%s', 'r')\n", pnnxbinpath.c_str());
for (const Operator* op : ops)
{
if (op->type.substr(0, 3) != "nn." && op->type.substr(0, 16) != "torchvision.ops.")
continue;
if (op->type == "nn.quantized.Conv2d" || op->type == "nn.quantized.Linear")
{
for (const auto& it : op->attrs)
{
if (it.first == "weight" || it.first == "bias")
{
fprintf(pyfp, " self_%s_%s = self.load_pnnx_bin_as_parameter(archive, '%s.%s', (", sanitize_identifier(op->name).c_str(), it.first.c_str(), op->name.c_str(), it.first.c_str());
}
else
{
// unknown attr
continue;
}
const Attribute& attr = it.second;
for (size_t i = 0; i < attr.shape.size(); i++)
{
fprintf(pyfp, "%d", attr.shape[i]);
if (i + 1 != attr.shape.size())
fprintf(pyfp, ",");
}
fprintf(pyfp, "), '%s', requires_grad=False)\n", type_to_numpy_string(attr.type));
}
fprintf(pyfp, " self.%s.set_weight_bias(self_%s_weight, self_%s_bias)\n", sanitize_identifier(op->name).c_str(), sanitize_identifier(op->name).c_str(), sanitize_identifier(op->name).c_str());
fprintf(pyfp, " self.%s.scale = %f\n", sanitize_identifier(op->name).c_str(), op->params.at("scale").f);
fprintf(pyfp, " self.%s.zero_point = %d\n", sanitize_identifier(op->name).c_str(), op->params.at("zero_point").i);
continue;
}
for (const auto& it : op->attrs)
{
if (it.first == "running_mean" || it.first == "running_var")
{
fprintf(pyfp, " self.%s.%s = self.load_pnnx_bin_as_tensor(archive, '%s.%s', (", sanitize_identifier(op->name).c_str(), it.first.c_str(), op->name.c_str(), it.first.c_str());
}
else
{
fprintf(pyfp, " self.%s.%s = self.load_pnnx_bin_as_parameter(archive, '%s.%s', (", sanitize_identifier(op->name).c_str(), it.first.c_str(), op->name.c_str(), it.first.c_str());
}
const Attribute& attr = it.second;
for (size_t i = 0; i < attr.shape.size(); i++)
{
fprintf(pyfp, "%d", attr.shape[i]);
if (i + 1 != attr.shape.size())
fprintf(pyfp, ",");
}
if (attr.type == 1 || attr.type == 2 || attr.type == 3)
{
fprintf(pyfp, "), '%s')\n", type_to_numpy_string(attr.type));
}
else
{
fprintf(pyfp, "), '%s', requires_grad=False)\n", type_to_numpy_string(attr.type));
}
}
}
for (const Operator* op : ops)
{
if (op->type != "pnnx.Attribute")
continue;
const std::string& key = op->attrs.begin()->first;
const Attribute& attr = op->attrs.begin()->second;
bool is_running_mean_var = false;
{
const Operand* r = op->outputs[0];
if (r->consumers.size() == 1)
{
const Operator* op2 = r->consumers[0];
if (op2->type == "F.batch_norm" || op2->type == "F.instance_norm")
{
if (r == op2->inputs[1] || r == op2->inputs[2])
{
is_running_mean_var = true;
}
}
}
}
if (is_running_mean_var)
{
fprintf(pyfp, " self.%s_%s = self.load_pnnx_bin_as_tensor(archive, '%s.%s', (", sanitize_identifier(op->name).c_str(), sanitize_identifier(key).c_str(), op->name.c_str(), key.c_str());
}
else
{
fprintf(pyfp, " self.%s_%s = self.load_pnnx_bin_as_parameter(archive, '%s.%s', (", sanitize_identifier(op->name).c_str(), sanitize_identifier(key).c_str(), op->name.c_str(), key.c_str());
}
for (size_t i = 0; i < attr.shape.size(); i++)
{
fprintf(pyfp, "%d", attr.shape[i]);
if (i + 1 != attr.shape.size())
fprintf(pyfp, ",");
}
if (attr.type == 1 || attr.type == 2 || attr.type == 3)
{
fprintf(pyfp, "), '%s')\n", type_to_numpy_string(attr.type));
}
else
{
fprintf(pyfp, "), '%s', requires_grad=False)\n", type_to_numpy_string(attr.type));
}
}
fprintf(pyfp, " archive.close()\n");
}
fprintf(pyfp, "\n");
// utility function
{
fprintf(pyfp, " def load_pnnx_bin_as_parameter(self, archive, key, shape, dtype, requires_grad=True):\n");
fprintf(pyfp, " return nn.Parameter(self.load_pnnx_bin_as_tensor(archive, key, shape, dtype), requires_grad)\n");
fprintf(pyfp, "\n");
fprintf(pyfp, " def load_pnnx_bin_as_tensor(self, archive, key, shape, dtype):\n");
fprintf(pyfp, " _, tmppath = tempfile.mkstemp()\n");
fprintf(pyfp, " tmpf = open(tmppath, 'wb')\n");
fprintf(pyfp, " with archive.open(key) as keyfile:\n");
fprintf(pyfp, " tmpf.write(keyfile.read())\n");
fprintf(pyfp, " tmpf.close()\n");
fprintf(pyfp, " m = np.memmap(tmppath, dtype=dtype, mode='r', shape=shape).copy()\n");
fprintf(pyfp, " os.remove(tmppath)\n");
fprintf(pyfp, " return torch.from_numpy(m)\n");
}
fprintf(pyfp, "\n");
// def forward
{
fprintf(pyfp, " def forward(self");
for (const Operator* op : ops)
{
if (op->type != "pnnx.Input")
continue;
fprintf(pyfp, ", v_%s", sanitize_identifier(op->outputs[0]->name).c_str());
}
fprintf(pyfp, "):\n");
}
// forward body
{
for (const Operator* op : ops)
{
if (op->type == "pnnx.Input" || op->type == "pnnx.Output")
continue;
fprintf(pyfp, " ");
if (op->type == "pnnx.Expression")
{
// expr
for (size_t i = 0; i < op->outputs.size(); i++)
{
fprintf(pyfp, "v_%s", sanitize_identifier(op->outputs[i]->name).c_str());
if (i + 1 != op->outputs.size())
fprintf(pyfp, ", ");
}
std::string expanded_expr = expand_expression(op);
fprintf(pyfp, " = %s\n", expanded_expr.c_str());
}
else if (op->type == "pnnx.Attribute")
{
const std::string& key = op->attrs.begin()->first;
fprintf(pyfp, "v_%s = self.%s_%s\n", sanitize_identifier(op->outputs[0]->name).c_str(), sanitize_identifier(op->name).c_str(), sanitize_identifier(key).c_str());
}
else if (op->type == "Tensor.slice")
{
// slice expr
std::string slice_expr = make_slice_expression(op);
fprintf(pyfp, "v_%s = v_%s[%s]\n", sanitize_identifier(op->outputs[0]->name).c_str(), sanitize_identifier(op->inputs[0]->name).c_str(), slice_expr.c_str());
}
else if (op->type == "Tensor.index")
{
// index expr
if (op->inputs.size() == 2)
{
std::string expanded_expr = expand_expression(op->inputs[1]->producer);
fprintf(pyfp, "v_%s = v_%s[%s]\n", sanitize_identifier(op->outputs[0]->name).c_str(), sanitize_identifier(op->inputs[0]->name).c_str(), expanded_expr.c_str());
}
else
{
std::string index_expr = make_index_expression(op);
fprintf(pyfp, "v_%s = v_%s[%s]\n", sanitize_identifier(op->outputs[0]->name).c_str(), sanitize_identifier(op->inputs[0]->name).c_str(), index_expr.c_str());
}
}
else if (op->type == "Tensor.view" || op->type == "Tensor.reshape")
{
// view reshape
fprintf(pyfp, "v_%s = v_%s.%s(", sanitize_identifier(op->outputs[0]->name).c_str(), sanitize_identifier(op->inputs[0]->name).c_str(), op->type.substr(7).c_str());
if (op->inputs.size() == 2)
{
fprintf(pyfp, "*v_%s", sanitize_identifier(op->inputs[1]->name).c_str());
}
else
{
const std::vector<int>& shape = op->params.at("shape").ai;
for (size_t i = 0; i < shape.size(); i++)
{
fprintf(pyfp, "%d", shape[i]);
if (i + 1 != shape.size())
fprintf(pyfp, ", ");
}
}
fprintf(pyfp, ")\n");
}
else if (op->type == "Tensor.repeat")
{
// repeat
fprintf(pyfp, "v_%s = v_%s.%s(", sanitize_identifier(op->outputs[0]->name).c_str(), sanitize_identifier(op->inputs[0]->name).c_str(), op->type.substr(7).c_str());
if (op->inputs.size() == 2)
{
fprintf(pyfp, "*v_%s", sanitize_identifier(op->inputs[1]->name).c_str());
}
else
{
const std::vector<int>& sizes = op->params.at("sizes").ai;
for (size_t i = 0; i < sizes.size(); i++)
{
fprintf(pyfp, "%d", sizes[i]);
if (i + 1 != sizes.size())
fprintf(pyfp, ", ");
}
}
fprintf(pyfp, ")\n");
}
else if (op->type == "torch.cat" || op->type == "torch.stack")
{
// cat/stack
fprintf(pyfp, "v_%s = %s(", sanitize_identifier(op->outputs[0]->name).c_str(), op->type.c_str());
if (op->inputs.size() == 1)
{
fprintf(pyfp, "v_%s", sanitize_identifier(op->inputs[0]->name).c_str());
}
else
{
fprintf(pyfp, "(");
for (size_t i = 0; i < op->inputs.size(); i++)
{
fprintf(pyfp, "v_%s", sanitize_identifier(op->inputs[i]->name).c_str());
if (i + 1 != op->inputs.size())
fprintf(pyfp, ", ");
}
fprintf(pyfp, ")");
}
fprintf(pyfp, ", dim=%d", op->params.at("dim").i);
fprintf(pyfp, ")\n");
}
else if (op->type == "prim::TupleUnpack")
{
for (size_t i = 0; i < op->outputs.size(); i++)
{
fprintf(pyfp, "v_%s", sanitize_identifier(op->outputs[i]->name).c_str());
if (i + 1 != op->outputs.size())
fprintf(pyfp, ", ");
}
fprintf(pyfp, " = v_%s\n", sanitize_identifier(op->inputs[0]->name).c_str());
}
else if (op->type == "prim::TupleConstruct")
{
fprintf(pyfp, "v_%s", sanitize_identifier(op->outputs[0]->name).c_str());
fprintf(pyfp, " = (");
for (size_t i = 0; i < op->inputs.size(); i++)
{
fprintf(pyfp, "v_%s, ", sanitize_identifier(op->inputs[i]->name).c_str());
}
fprintf(pyfp, ")\n");
}
else if (op->type == "prim::ListUnpack")
{
for (size_t i = 0; i < op->outputs.size(); i++)
{
fprintf(pyfp, "v_%s", sanitize_identifier(op->outputs[i]->name).c_str());
if (i + 1 != op->outputs.size())
fprintf(pyfp, ", ");
}
fprintf(pyfp, " = v_%s\n", sanitize_identifier(op->inputs[0]->name).c_str());
}
else if (op->type == "prim::ListConstruct")
{
fprintf(pyfp, "v_%s", sanitize_identifier(op->outputs[0]->name).c_str());
fprintf(pyfp, " = [");
for (size_t i = 0; i < op->inputs.size(); i++)
{
fprintf(pyfp, "v_%s", sanitize_identifier(op->inputs[i]->name).c_str());
if (i + 1 != op->inputs.size())
fprintf(pyfp, ", ");
}
fprintf(pyfp, "]\n");
}
else if (op->type == "nn.LSTM")
{
if (op->outputs.size() == 1)
{
fprintf(pyfp, "v_%s, _", sanitize_identifier(op->outputs[0]->name).c_str());
}
else
{
fprintf(pyfp, "v_%s, (v_%s, v_%s)", sanitize_identifier(op->outputs[0]->name).c_str(), sanitize_identifier(op->outputs[1]->name).c_str(), sanitize_identifier(op->outputs[2]->name).c_str());
}
fprintf(pyfp, " = self.%s(", sanitize_identifier(op->name).c_str());
fprintf(pyfp, "v_%s", sanitize_identifier(op->inputs[0]->name).c_str());
if (op->inputs.size() == 3)
{
fprintf(pyfp, ", (v_%s, v_%s)", sanitize_identifier(op->inputs[1]->name).c_str(), sanitize_identifier(op->inputs[2]->name).c_str());
}
fprintf(pyfp, ")\n");
}
else if (op->type.substr(0, 3) == "nn." || op->type.substr(0, 16) == "torchvision.ops.")
{
// self.xxx()
for (size_t i = 0; i < op->outputs.size(); i++)
{
fprintf(pyfp, "v_%s", sanitize_identifier(op->outputs[i]->name).c_str());
if (i + 1 != op->outputs.size())
fprintf(pyfp, ", ");
}
fprintf(pyfp, " = self.%s(", sanitize_identifier(op->name).c_str());
for (size_t i = 0; i < op->inputs.size(); i++)
{
fprintf(pyfp, "v_%s", sanitize_identifier(op->inputs[i]->name).c_str());
if (i + 1 != op->inputs.size())
fprintf(pyfp, ", ");
}
fprintf(pyfp, ")\n");
}
else if (op->type.find("::") != std::string::npos || op->type.find(".") != std::string::npos)
{
// direct
for (size_t i = 0; i < op->outputs.size(); i++)
{
fprintf(pyfp, "v_%s", sanitize_identifier(op->outputs[i]->name).c_str());
if (i + 1 != op->outputs.size())
fprintf(pyfp, ", ");
}
if (op->type.substr(0, 7) == "Tensor.")
{
fprintf(pyfp, " = v_%s.%s(", sanitize_identifier(op->inputs[0]->name).c_str(), op->type.substr(7).c_str());
}
else
{
fprintf(pyfp, " = %s(", op->type.c_str());
if (op->inputnames.size() == op->inputs.size())
{
for (size_t i = 0; i < op->inputs.size(); i++)
{
if (!op->inputnames[i].empty())
continue;
fprintf(pyfp, "v_%s", sanitize_identifier(op->inputs[i]->name).c_str());
if (i + 1 != op->inputs.size())
fprintf(pyfp, ", ");
}
for (size_t i = 0; i < op->inputs.size(); i++)
{
if (op->inputnames[i].empty())
continue;
fprintf(pyfp, "%s=v_%s", op->inputnames[i].c_str(), sanitize_identifier(op->inputs[i]->name).c_str());
if (i + 1 != op->inputs.size())
fprintf(pyfp, ", ");
}
}
else
{
for (size_t i = 0; i < op->inputs.size(); i++)
{
fprintf(pyfp, "v_%s", sanitize_identifier(op->inputs[i]->name).c_str());
if (i + 1 != op->inputs.size())
fprintf(pyfp, ", ");
}
}
}
int param_index = 0; // counts printed params; distinct from the loop indices below
for (const auto& it : op->params)
{
if (op->type.substr(0, 7) == "Tensor." && param_index == 0)
{
fprintf(pyfp, "%s=", it.first.c_str());
}
else if (op->inputs.empty() && param_index == 0)
{
fprintf(pyfp, "%s=", it.first.c_str());
}
else
{
fprintf(pyfp, ", %s=", it.first.c_str());
}
param_index++;
const Parameter& param = it.second;
if (param.type == 0)
{
fprintf(pyfp, "None");
}
if (param.type == 1)
{
if (param.b)
fprintf(pyfp, "True");
else
fprintf(pyfp, "False");
}
if (param.type == 2)
{
fprintf(pyfp, "%d", param.i);
}
if (param.type == 3)
{
fprintf(pyfp, "%f", param.f);
}
if (param.type == 4)
{
if (param.s.substr(0, 6) == "torch.")
{
fprintf(pyfp, "%s", param.s.c_str());
}
else
{
fprintf(pyfp, "\'%s\'", param.s.c_str());
}
}
if (param.type == 5)
{
fprintf(pyfp, "(");
for (size_t i = 0; i < param.ai.size(); i++)
{
fprintf(pyfp, "%d", param.ai[i]);
if (i + 1 != param.ai.size() || param.ai.size() == 1)
fprintf(pyfp, ",");
}
fprintf(pyfp, ")");
}
if (param.type == 6)
{
fprintf(pyfp, "(");
for (size_t i = 0; i < param.af.size(); i++)
{
fprintf(pyfp, "%f", param.af[i]);
if (i + 1 != param.af.size() || param.af.size() == 1)
fprintf(pyfp, ",");
}
fprintf(pyfp, ")");
}
if (param.type == 7)
{
fprintf(pyfp, "(");
for (size_t i = 0; i < param.as.size(); i++)
{
if (param.as[i].substr(0, 6) == "torch.")
{
fprintf(pyfp, "%s", param.as[i].c_str());
}
else
{
fprintf(pyfp, "\'%s\'", param.as[i].c_str());
}
if (i + 1 != param.as.size() || param.as.size() == 1)
fprintf(pyfp, ",");
}
fprintf(pyfp, ")");
}
}
fprintf(pyfp, ")\n");
}
else
{
fprintf(stderr, "todo %s\n", op->type.c_str());
}
}
}
// return
{
fprintf(pyfp, " return ");
int output_count = 0;
{
for (const Operator* op : ops)
{
if (op->type == "pnnx.Output")
output_count++;
}
}
int output_index = 0;
for (const Operator* op : ops)
{
if (op->type != "pnnx.Output")
continue;
fprintf(pyfp, "v_%s", sanitize_identifier(op->inputs[0]->name).c_str());
if (output_index + 1 != output_count)
fprintf(pyfp, ", ");
output_index++;
}
fprintf(pyfp, "\n");
}
fprintf(pyfp, "\n");
// export torchscript
{
fprintf(pyfp, "def export_torchscript():\n");
fprintf(pyfp, " net = Model()\n");
fprintf(pyfp, " net.eval()\n");
fprintf(pyfp, "\n");
fprintf(pyfp, " torch.manual_seed(0)\n");
std::vector<std::string> input_names;
for (const Operator* op : ops)
{
if (op->type != "pnnx.Input")
continue;
const Operand* r = op->outputs[0];
std::string input_name = std::string("v_") + sanitize_identifier(r->name);
if (type_is_integer(r->type))
{
fprintf(pyfp, " %s = torch.randint(10, (", input_name.c_str());
for (size_t i = 0; i < r->shape.size(); i++)
{
fprintf(pyfp, "%d", r->shape[i]);
if (i + 1 != r->shape.size() || r->shape.size() == 1)
fprintf(pyfp, ", ");
}
fprintf(pyfp, "), dtype=%s)\n", type_to_dtype_string(r->type));
}
else
{
fprintf(pyfp, " %s = torch.rand(", input_name.c_str());
for (size_t i = 0; i < r->shape.size(); i++)
{
fprintf(pyfp, "%d, ", r->shape[i]);
}
fprintf(pyfp, "dtype=%s)\n", type_to_dtype_string(r->type));
}
input_names.push_back(input_name);
}
fprintf(pyfp, "\n");
if (input_names.size() == 1)
{
fprintf(pyfp, " mod = torch.jit.trace(net, %s)\n", input_names[0].c_str());
}
else
{
fprintf(pyfp, " mod = torch.jit.trace(net, (");
for (size_t i = 0; i < input_names.size(); i++)
{
fprintf(pyfp, "%s", input_names[i].c_str());
if (i + 1 != input_names.size())
fprintf(pyfp, ", ");
}
fprintf(pyfp, "))\n");
}
fprintf(pyfp, " mod.save(\"%s.pt\")\n", pypath.c_str());
}
fprintf(pyfp, "\n");
// test inference
{
fprintf(pyfp, "def test_inference():\n");
fprintf(pyfp, " net = Model()\n");
fprintf(pyfp, " net.eval()\n");
fprintf(pyfp, "\n");
fprintf(pyfp, " torch.manual_seed(0)\n");
std::vector<std::string> input_names;
for (const Operator* op : ops)
{
if (op->type != "pnnx.Input")
continue;
const Operand* r = op->outputs[0];
std::string input_name = std::string("v_") + sanitize_identifier(r->name);
if (type_is_integer(r->type))
{
fprintf(pyfp, " %s = torch.randint(10, (", input_name.c_str());
for (size_t i = 0; i < r->shape.size(); i++)
{
fprintf(pyfp, "%d", r->shape[i]);
if (i + 1 != r->shape.size() || r->shape.size() == 1)
fprintf(pyfp, ", ");
}
fprintf(pyfp, "), dtype=%s)\n", type_to_dtype_string(r->type));
}
else
{
fprintf(pyfp, " %s = torch.rand(", input_name.c_str());
for (size_t i = 0; i < r->shape.size(); i++)
{
fprintf(pyfp, "%d, ", r->shape[i]);
}
fprintf(pyfp, "dtype=%s)\n", type_to_dtype_string(r->type));
}
input_names.push_back(input_name);
}
fprintf(pyfp, "\n");
if (input_names.size() == 1)
{
fprintf(pyfp, " return net(%s)\n", input_names[0].c_str());
}
else
{
fprintf(pyfp, " return net(");
for (size_t i = 0; i < input_names.size(); i++)
{
fprintf(pyfp, "%s", input_names[i].c_str());
if (i + 1 != input_names.size())
fprintf(pyfp, ", ");
}
fprintf(pyfp, ")\n");
}
}
fclose(pyfp);
return 0;
}
static bool string_is_positive_integer(const std::string& t)
{
if (t.empty())
return false;
for (size_t i = 0; i < t.size(); i++)
{
if (t[i] < '0' || t[i] > '9')
return false;
}
return true;
}
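// Convert an fp32 value to IEEE fp16 bits by field extraction:
// subnormals and underflow flush to signed zero, overflow saturates to
// infinity, and the mantissa is truncated rather than rounded.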
static unsigned short float32_to_float16(float value)
{
// 1 : 8 : 23
union
{
unsigned int u;
float f;
} tmp;
tmp.f = value;
// 1 : 8 : 23
unsigned short sign = (tmp.u & 0x80000000) >> 31;
unsigned short exponent = (tmp.u & 0x7F800000) >> 23;
unsigned int significand = tmp.u & 0x7FFFFF;
// NCNN_LOGE("%d %d %d", sign, exponent, significand);
// 1 : 5 : 10
unsigned short fp16;
if (exponent == 0)
{
// zero or denormal, always underflow
fp16 = (sign << 15) | (0x00 << 10) | 0x00;
}
else if (exponent == 0xFF)
{
// infinity or NaN
fp16 = (sign << 15) | (0x1F << 10) | (significand ? 0x200 : 0x00);
}
else
{
// normalized
short newexp = exponent + (-127 + 15);
if (newexp >= 31)
{
// overflow, return infinity
fp16 = (sign << 15) | (0x1F << 10) | 0x00;
}
else if (newexp <= 0)
{
// Some normal fp32 cannot be expressed as normal fp16
fp16 = (sign << 15) | (0x00 << 10) | 0x00;
}
else
{
// normal fp16
fp16 = (sign << 15) | (newexp << 10) | (significand >> 13);
}
}
return fp16;
}
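// Emit an ncnn model: a param file (magic 7767517, integer-keyed
// "id=value" fields, arrays encoded as "-23300-id=count,v0,v1,...") plus
// a bin file of raw attribute data, where the 0x01306B47 tag marks a
// weight blob stored as fp16. Non-integer param keys cannot be mapped to
// ncnn ids and are reported to stderr. A pyncnn test script is written
// alongside.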
int Graph::ncnn(const std::string& parampath, const std::string& binpath, const std::string& pypath)
{
FILE* paramfp = fopen(parampath.c_str(), "wb");
if (!paramfp)
{
fprintf(stderr, "fopen %s failed\n", parampath.c_str());
return -1;
}
FILE* binfp = fopen(binpath.c_str(), "wb");
if (!binfp)
{
fprintf(stderr, "fopen %s failed\n", binpath.c_str());
fclose(paramfp);
return -1;
}
// magic
fprintf(paramfp, "7767517\n");
// operator count and operand count
fprintf(paramfp, "%d %d\n", (int)ops.size(), (int)operands.size());
for (const Operator* op : ops)
{
fprintf(paramfp, "%-24s %-24s %d %d", op->type.c_str(), op->name.c_str(), (int)op->inputs.size(), (int)op->outputs.size());
for (const Operand* oprand : op->inputs)
{
fprintf(paramfp, " %s", oprand->name.c_str());
}
for (const Operand* oprand : op->outputs)
{
fprintf(paramfp, " %s", oprand->name.c_str());
}
for (const auto& it : op->params)
{
const Parameter& param = it.second;
if (!string_is_positive_integer(it.first))
{
fprintf(stderr, "ignore %s %s param %s=", op->type.c_str(), op->name.c_str(), it.first.c_str());
if (param.type == 0)
{
fprintf(stderr, "None");
}
if (param.type == 1)
{
if (param.b)
fprintf(stderr, "True");
else
fprintf(stderr, "False");
}
if (param.type == 2)
{
fprintf(stderr, "%d", param.i);
}
if (param.type == 3)
{
fprintf(stderr, "%e", param.f);
}
if (param.type == 4)
{
fprintf(stderr, "%s", param.s.c_str());
}
if (param.type == 5)
{
fprintf(stderr, "(");
for (size_t i = 0; i < param.ai.size(); i++)
{
fprintf(stderr, "%d", param.ai[i]);
if (i + 1 != param.ai.size())
fprintf(stderr, ",");
}
fprintf(stderr, ")");
}
if (param.type == 6)
{
fprintf(stderr, "(");
for (size_t i = 0; i < param.af.size(); i++)
{
fprintf(stderr, "%e", param.af[i]);
if (i + 1 != param.af.size())
fprintf(stderr, ",");
}
fprintf(stderr, ")");
}
if (param.type == 7)
{
fprintf(stderr, "(");
for (size_t i = 0; i < param.as.size(); i++)
{
fprintf(stderr, "%s", param.as[i].c_str());
if (i + 1 != param.as.size())
fprintf(stderr, ",");
}
fprintf(stderr, ")");
}
fprintf(stderr, "\n");
continue;
}
const int idkey = std::stoi(it.first);
if (param.type == 2)
{
fprintf(paramfp, " %d=%d", idkey, param.i);
}
if (param.type == 3)
{
fprintf(paramfp, " %d=%e", idkey, param.f);
}
if (param.type == 5)
{
const int array_size = (int)param.ai.size();
fprintf(paramfp, " %d=%d", -23300 - idkey, array_size);
for (size_t i = 0; i < param.ai.size(); i++)
{
fprintf(paramfp, ",%d", param.ai[i]);
}
}
if (param.type == 6)
{
const int array_size = (int)param.af.size();
fprintf(paramfp, " %d=%d", -23300 - idkey, array_size);
for (size_t i = 0; i < param.af.size(); i++)
{
fprintf(paramfp, ",%e", param.af[i]);
}
}
}
bool is_type_flag_fp32 = false;
for (const auto& it : op->attrs)
{
// fprintf(paramfp, " @%s=", it.first.c_str());
const Attribute& attr = it.second;
if (is_type_flag_fp32)
{
// fp32 -> fp16
const float* p = (const float*)attr.data.data();
int len = attr.data.size() / 4;
for (int i = 0; i < len; i++)
{
unsigned short v_fp16 = float32_to_float16(p[i]);
fwrite(&v_fp16, sizeof(v_fp16), 1, binfp);
}
is_type_flag_fp32 = false;
continue;
}
if (attr.type == 0 && attr.data == std::vector<char> {0, 0, 0, 0})
{
// write fp16 flag
unsigned int fp16_flag = 0x01306B47;
fwrite(&fp16_flag, sizeof(fp16_flag), 1, binfp);
is_type_flag_fp32 = true;
continue;
}
fwrite(attr.data.data(), attr.data.size(), 1, binfp);
}
// if (op->inputnames.size() == op->inputs.size())
// {
// for (size_t i = 0; i < op->inputs.size(); i++)
// {
// const Operand* oprand = op->inputs[i];
// fprintf(paramfp, " $%s=%s", op->inputnames[i].c_str(), oprand->name.c_str());
// }
// }
// for (const Operand* oprand : op->outputs)
// {
// if (oprand->params.find("__batch_index") == oprand->params.end())
// continue;
//
// const int batch_index = oprand->params.at("__batch_index").i;
//
// fprintf(paramfp, " #%s=%d", oprand->name.c_str(), batch_index);
// }
// for (const Operand* oprand : op->outputs)
// {
// if (oprand->shape.empty())
// continue;
//
// fprintf(paramfp, " #%s=", oprand->name.c_str());
//
// fprintf(paramfp, "(");
// for (int64_t i = 0; i < oprand->shape.size() - 1; i++)
// {
// fprintf(paramfp, "%d,", oprand->shape[i]);
// }
// if (oprand->shape.size() > 0)
// fprintf(paramfp, "%d", oprand->shape[oprand->shape.size() - 1]);
// fprintf(paramfp, ")");
//
// fprintf(paramfp, type_to_string(oprand->type));
// }
fprintf(paramfp, "\n");
}
fclose(paramfp);
fclose(binfp);
FILE* pyfp = fopen(pypath.c_str(), "wb");
if (!pyfp)
{
fprintf(stderr, "fopen %s failed\n", pypath.c_str());
return -1;
}
fprintf(pyfp, "import numpy as np\n");
fprintf(pyfp, "import ncnn\n");
fprintf(pyfp, "import torch\n");
fprintf(pyfp, "\n");
// test inference
{
fprintf(pyfp, "def test_inference():\n");
fprintf(pyfp, " torch.manual_seed(0)\n");
for (int input_index = 0;; input_index++)
{
std::string input_name = std::string("in") + std::to_string(input_index);
const Operand* r = get_operand(input_name);
if (!r)
break;
if (type_is_integer(r->type))
{
fprintf(pyfp, " %s = torch.randint(10, (", input_name.c_str());
for (size_t i = 0; i < r->shape.size(); i++)
{
fprintf(pyfp, "%d", r->shape[i]);
if (i + 1 != r->shape.size() || r->shape.size() == 1)
fprintf(pyfp, ", ");
}
fprintf(pyfp, "), dtype=%s)\n", type_to_dtype_string(r->type));
}
else
{
fprintf(pyfp, " %s = torch.rand(", input_name.c_str());
for (size_t i = 0; i < r->shape.size(); i++)
{
fprintf(pyfp, "%d, ", r->shape[i]);
}
fprintf(pyfp, "dtype=%s)\n", type_to_dtype_string(r->type));
}
}
fprintf(pyfp, " out = []\n");
fprintf(pyfp, "\n");
fprintf(pyfp, " with ncnn.Net() as net:\n");
fprintf(pyfp, " net.load_param(\"%s\")\n", parampath.c_str());
fprintf(pyfp, " net.load_model(\"%s\")\n", binpath.c_str());
fprintf(pyfp, "\n");
fprintf(pyfp, " with net.create_extractor() as ex:\n");
for (int input_index = 0;; input_index++)
{
std::string input_name = std::string("in") + std::to_string(input_index);
const Operand* r = get_operand(input_name);
if (!r)
break;
const int batch_index = r->params.at("__batch_index").i;
if (batch_index != 233)
{
fprintf(pyfp, " ex.input(\"%s\", ncnn.Mat(%s.squeeze(%d).numpy()).clone())\n", input_name.c_str(), input_name.c_str(), batch_index);
}
else
{
fprintf(pyfp, " ex.input(\"%s\", ncnn.Mat(%s.numpy()).clone())\n", input_name.c_str(), input_name.c_str());
}
}
fprintf(pyfp, "\n");
for (int output_index = 0;; output_index++)
{
std::string output_name = std::string("out") + std::to_string(output_index);
const Operand* r = get_operand(output_name);
if (!r)
break;
fprintf(pyfp, " _, %s = ex.extract(\"%s\")\n", output_name.c_str(), output_name.c_str());
const int batch_index = r->params.at("__batch_index").i;
if (batch_index != 233)
{
fprintf(pyfp, " out.append(torch.from_numpy(np.array(%s)).unsqueeze(%d))\n", output_name.c_str(), batch_index);
}
else
{
fprintf(pyfp, " out.append(torch.from_numpy(np.array(%s)))\n", output_name.c_str());
}
}
fprintf(pyfp, "\n");
fprintf(pyfp, " if len(out) == 1:\n");
fprintf(pyfp, " return out[0]\n");
fprintf(pyfp, " else:\n");
fprintf(pyfp, " return tuple(out)\n");
}
fclose(pyfp);
return 0;
}
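// Parse a param text held in memory. This mirrors Graph::load() but has
// no bin archive, so "@attribute" and "$input-key" entries are skipped;
// operand shapes and plain parameters are still applied.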
int Graph::parse(const std::string& param)
{
std::istringstream is(param);
if (!is.good())
{
fprintf(stderr, "open failed\n");
return -1;
}
int magic = 0;
{
std::string line;
std::getline(is, line);
std::istringstream iss(line);
iss >> magic;
}
int operator_count = 0;
int operand_count = 0;
{
std::string line;
std::getline(is, line);
std::istringstream iss(line);
iss >> operator_count >> operand_count;
}
for (int i = 0; i < operator_count; i++)
{
std::string line;
std::getline(is, line);
std::istringstream iss(line);
std::string type;
std::string name;
int input_count = 0;
int output_count = 0;
iss >> type >> name >> input_count >> output_count;
Operator* op = new_operator(type, name);
for (int j = 0; j < input_count; j++)
{
std::string operand_name;
iss >> operand_name;
Operand* r = get_operand(operand_name);
r->consumers.push_back(op);
op->inputs.push_back(r);
}
for (int j = 0; j < output_count; j++)
{
std::string operand_name;
iss >> operand_name;
Operand* r = new_operand(operand_name);
r->producer = op;
op->outputs.push_back(r);
}
// key=value
while (!iss.eof())
{
std::string param;
iss >> param;
std::string key;
std::string value;
std::istringstream pss(param);
std::getline(pss, key, '=');
std::getline(pss, value);
if (key[0] == '@')
{
// attribute
// load_attribute(op, key.substr(1), value, szr);
}
else if (key[0] == '$')
{
// operand input key
// load_input_key(op, key.substr(1), value);
}
else if (key[0] == '#')
{
// operand shape
load_shape(op, key.substr(1), value);
}
else
{
// parameter
load_parameter(op, key, value);
}
}
}
return 0;
}
void Operand::remove_consumer(const Operator* c)
{
auto it = std::find(consumers.begin(), consumers.end(), c);
if (it != consumers.end())
consumers.erase(it);
}
Operator* Graph::new_operator(const std::string& type, const std::string& name)
{
Operator* op = new Operator;
op->type = type;
op->name = name;
ops.push_back(op);
return op;
}
Operator* Graph::new_operator_before(const std::string& type, const std::string& name, const Operator* cur)
{
Operator* op = new Operator;
op->type = type;
op->name = name;
ops.insert(std::find(ops.begin(), ops.end(), cur), op);
return op;
}
Operator* Graph::new_operator_after(const std::string& type, const std::string& name, const Operator* cur)
{
Operator* op = new Operator;
op->type = type;
op->name = name;
ops.insert(std::find(ops.begin(), ops.end(), cur) + 1, op);
return op;
}
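// Create an operand from a torchscript Value, capturing dtype and shape
// from its TensorType when known; unknown dimensions are recorded as -1.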
Operand* Graph::new_operand(const torch::jit::Value* v)
{
Operand* r = new Operand;
r->name = v->debugName();
auto pt = v->type()->cast<c10::TensorType>();
if (pt)
{
if (pt->scalarType().has_value() && pt->dim().has_value())
{
r->type = get_at_tensor_type(pt->scalarType().value());
const int ndim = (int)pt->dim().value();
r->shape.resize(ndim);
for (int i = 0; i < ndim; i++)
{
if (pt->sizes()[i].has_value())
r->shape[i] = (int)pt->sizes()[i].value();
else
r->shape[i] = -1;
}
}
}
operands.push_back(r);
return r;
}
Operand* Graph::new_operand(const std::string& name)
{
Operand* r = new Operand;
r->name = name;
operands.push_back(r);
return r;
}
Operand* Graph::get_operand(const std::string& name)
{
for (Operand* r : operands)
{
if (r->name == name)
return r;
}
return 0;
}
} // namespace pnnx