// Tencent is pleased to support the open source community by making ncnn available.
//
// Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved.
//
// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
// in compliance with the License. You may obtain a copy of the License at
//
// https://opensource.org/licenses/BSD-3-Clause
//
// Unless required by applicable law or agreed to in writing, software distributed
// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
// CONDITIONS OF ANY KIND, either express or implied. See the License for the
// specific language governing permissions and limitations under the License.

#include "ir.h"

#include <stdio.h>
#include <string.h>
#include <algorithm>
#include <fstream>
#include <sstream>
#include <stack>
#include <string>

#include <torch/script.h>

#include "storezip.h"

namespace pnnx {

// Operand/attribute data type codes used throughout pnnx:
// 0=null 1=f32 2=f64 3=f16 4=i32 5=i64 6=i16 7=i8 8=u8 9=bool
static bool type_is_integer(int type)
{
    if (type == 1) return false;
    if (type == 2) return false;
    if (type == 3) return false;
    if (type == 4) return true;
    if (type == 5) return true;
    if (type == 6) return true;
    if (type == 7) return true;
    if (type == 8) return true;
    if (type == 9) return true;
    return false;
}

static const char* type_to_string(int type)
{
    if (type == 1) return "f32";
    if (type == 2) return "f64";
    if (type == 3) return "f16";
    if (type == 4) return "i32";
    if (type == 5) return "i64";
    if (type == 6) return "i16";
    if (type == 7) return "i8";
    if (type == 8) return "u8";
    if (type == 9) return "bool";
    return "null";
}

static const char* type_to_numpy_string(int type)
{
    if (type == 1) return "float32";
    if (type == 2) return "float64";
    if (type == 3) return "float16";
    if (type == 4) return "int32";
    if (type == 5) return "int64";
    if (type == 6) return "int16";
    if (type == 7) return "int8";
    if (type == 8) return "uint8";
    if (type == 9) return "bool8";
    return "null";
}

static const char* type_to_dtype_string(int type)
{
    if (type == 1) return "torch.float";
    if (type == 2) return "torch.double";
    if (type == 3) return "torch.half";
    if (type == 4) return "torch.int";
    if (type == 5) return "torch.long";
    if (type == 6) return "torch.short";
    if (type == 7) return "torch.int8";
    if (type == 8) return "torch.uint8";
    if (type == 9) return "torch.bool";
    return "null";
}

static size_t type_to_elemsize(int type)
{
    if (type == 1) return 4;
    if (type == 2) return 8;
    if (type == 3) return 2;
    if (type == 4) return 4;
    if (type == 5) return 8;
    if (type == 6) return 2;
    if (type == 7) return 1;
    if (type == 8) return 1;
    if (type == 9) return 1;
    return 0; // null
}

static int string_to_type(const char* s)
{
    if (strcmp(s, "f32") == 0) return 1;
    if (strcmp(s, "f64") == 0) return 2;
    if (strcmp(s, "f16") == 0) return 3;
    if (strcmp(s, "i32") == 0) return 4;
    if (strcmp(s, "i64") == 0) return 5;
    if (strcmp(s, "i16") == 0) return 6;
    if (strcmp(s, "i8") == 0) return 7;
    if (strcmp(s, "u8") == 0) return 8;
    if (strcmp(s, "bool") == 0) return 9;
    return 0; // null
}

int get_at_tensor_type(const at::ScalarType& st)
{
    if (st == c10::ScalarType::Float) return 1;
    if (st == c10::ScalarType::Double) return 2;
    if (st == c10::ScalarType::Half) return 3;
    if (st == c10::ScalarType::Int) return 4;
    if (st == c10::ScalarType::QInt32) return 4;
    if (st == c10::ScalarType::Long) return 5;
    if (st == c10::ScalarType::Short) return 6;
    if (st == c10::ScalarType::Char) return 7;
    if (st == c10::ScalarType::QInt8) return 7;
    if (st == c10::ScalarType::Byte) return 8;
    if (st == c10::ScalarType::QUInt8) return 8;
    if (st == c10::ScalarType::Bool) return 9;
    return 0; // unknown type
}
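// Parameter holds a scalar or small container value attached to an operator.
// Its 'type' field follows a separate convention from the dtype codes above,
// matching the member actually populated by the constructors below:
// 0=null 1=bool(b) 2=int(i) 3=float(f) 4=string(s)
// 5=int array(ai) 6=float array(af) 7=string array(as) 8=others (e.g. a non-scalar tensor)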
Parameter::Parameter(const torch::jit::Node* value_node)
{
    type = 0;

    if (value_node->kind() == c10::prim::Constant)
    {
        if (!value_node->hasAttribute(torch::jit::attr::value))
        {
            fprintf(stderr, "no attribute value\n");
            return;
        }

        switch (value_node->output()->type()->kind())
        {
        case c10::TypeKind::NoneType:
        {
            type = 0;
            break;
        }
        case c10::TypeKind::BoolType:
        {
            type = 1;
            b = value_node->i(torch::jit::attr::value);
            break;
        }
        case c10::TypeKind::IntType:
        {
            type = 2;
            i = (int)value_node->i(torch::jit::attr::value);
            break;
        }
        case c10::TypeKind::FloatType:
        {
            type = 3;
            f = (float)value_node->f(torch::jit::attr::value);
            break;
        }
        case c10::TypeKind::StringType:
        {
            type = 4;
            s = value_node->s(torch::jit::attr::value);
            break;
        }
        case c10::TypeKind::TensorType:
        {
            at::Tensor t = value_node->t(torch::jit::attr::value);

            if (t.dim() == 0)
            {
                // a 0-dim tensor constant degrades to a plain int/float parameter
                if (t.scalar_type() == c10::ScalarType::Long)
                {
                    type = 2;
                    i = (int)t.item<int64_t>();
                }
                else if (t.scalar_type() == c10::ScalarType::Int)
                {
                    type = 2;
                    i = t.item<int>();
                }
                else if (t.scalar_type() == c10::ScalarType::Double)
                {
                    type = 3;
                    f = (float)t.item<double>();
                }
                else if (t.scalar_type() == c10::ScalarType::Float)
                {
                    type = 3;
                    f = t.item<float>();
                }
                else
                {
                    fprintf(stderr, "unknown Parameter value kind %s of TensorType, t.dim = 0\n", value_node->kind().toDisplayString());
                }
            }
            else
            {
                const int ndim = (int)t.dim();
                type = 8;
                fprintf(stderr, "unknown Parameter value kind %s of TensorType, t.dim = %d\n", value_node->kind().toDisplayString(), ndim);
            }

            break;
        }
        default:
        {
            fprintf(stderr, "unknown Parameter value kind %s\n", value_node->kind().toDisplayString());
            break;
        }
        }
    }
    else if (value_node->kind() == c10::prim::ListConstruct)
    {
        switch (value_node->output()->type()->cast<c10::ListType>()->getElementType()->kind())
        {
        case c10::TypeKind::IntType:
        {
            type = 5;
            for (const auto& x : value_node->inputs())
            {
                ai.push_back((int)x->node()->i(torch::jit::attr::value));
            }
            break;
        }
        case c10::TypeKind::FloatType:
        {
            type = 6;
            for (const auto& x : value_node->inputs())
            {
                af.push_back((float)x->node()->f(torch::jit::attr::value));
            }
            break;
        }
        case c10::TypeKind::StringType:
        {
            type = 7;
            for (const auto& x : value_node->inputs())
            {
                as.push_back(x->node()->s(torch::jit::attr::value));
            }
            break;
        }
        default:
        {
            fprintf(stderr, "unknown Parameter value kind %s\n", value_node->kind().toDisplayString());
            break;
        }
        }
    }
    else
    {
        fprintf(stderr, "unknown Parameter value kind %s\n", value_node->kind().toDisplayString());
    }
}

Parameter::Parameter(const torch::jit::Value* value)
    : Parameter(value->node())
{
}

bool operator==(const Parameter& lhs, const Parameter& rhs)
{
    if (lhs.type != rhs.type)
        return false;

    if (lhs.type == 0)
        return true;

    if (lhs.type == 1 && lhs.b == rhs.b)
        return true;

    if (lhs.type == 2 && lhs.i == rhs.i)
        return true;

    if (lhs.type == 3 && lhs.f == rhs.f)
        return true;

    if (lhs.type == 4 && lhs.s == rhs.s)
        return true;

    if (lhs.type == 5 && lhs.ai == rhs.ai)
        return true;

    if (lhs.type == 6 && lhs.af == rhs.af)
        return true;

    if (lhs.type == 7 && lhs.as == rhs.as)
        return true;

    return false;
}
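// Attribute carries tensor data: a dtype code from the table above, a shape,
// and the raw element bytes in host byte order. As the constructor below
// shows, a 0-dim torch scalar is stored as shape {1} with a single element.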
Attribute::Attribute(const at::Tensor& t)
{
    type = get_at_tensor_type(t.scalar_type());

    const int ndim = (int)t.dim();

    if (ndim == 0)
    {
        shape = {1};

        data.resize(type_to_elemsize(type));

        if (t.scalar_type() == c10::ScalarType::Long)
        {
            int64_t i = t.item<int64_t>();
            memcpy((void*)data.data(), (const void*)&i, data.size());
        }
        else if (t.scalar_type() == c10::ScalarType::Int)
        {
            int i = t.item<int>();
            memcpy((void*)data.data(), (const void*)&i, data.size());
        }
        else if (t.scalar_type() == c10::ScalarType::Double)
        {
            double f = t.item<double>();
            memcpy((void*)data.data(), (const void*)&f, data.size());
        }
        else if (t.scalar_type() == c10::ScalarType::Float)
        {
            float f = t.item<float>();
            memcpy((void*)data.data(), (const void*)&f, data.size());
        }
        else
        {
            fprintf(stderr, "unknown Attribute tensor scalar type %d\n", type);
        }

        return;
    }

    shape.resize(ndim);
    for (int i = 0; i < ndim; i++)
        shape[i] = t.size(i);

    if (shape.size() > 0)
    {
        int size = shape[0];
        for (size_t i = 1; i < shape.size(); i++)
        {
            size *= shape[i];
        }

        data.resize(size * type_to_elemsize(type));
        memcpy((void*)data.data(), (const void*)t.cpu().contiguous().data_ptr(), data.size());
    }
}

Attribute::Attribute(const std::initializer_list<int>& _shape, const std::vector<float>& t)
{
    type = 1;
    shape = _shape;

    if (shape.size() > 0)
    {
        int size = shape[0];
        for (size_t i = 1; i < shape.size(); i++)
        {
            size *= shape[i];
        }

        data.resize(size * type_to_elemsize(type));
        memcpy((void*)data.data(), (const void*)t.data(), data.size());
    }
}

bool operator==(const Attribute& lhs, const Attribute& rhs)
{
    if (lhs.type != rhs.type)
        return false;

    if (lhs.type == 0)
        return true;

    if (lhs.shape != rhs.shape)
        return false;

    if (lhs.data != rhs.data)
        return false;

    return true;
}

Attribute operator+(const Attribute& a, const Attribute& b)
{
    Attribute c;

    if (a.type != b.type)
    {
        fprintf(stderr, "concat attribute type mismatch\n");
        return c;
    }

    if (a.shape.size() != b.shape.size())
    {
        fprintf(stderr, "concat attribute shape rank mismatch\n");
        return c;
    }

    for (int i = 1; i < (int)a.shape.size(); i++)
    {
        if (a.shape[i] != b.shape[i])
        {
            fprintf(stderr, "concat attribute shape mismatch\n");
            return c;
        }
    }

    c.type = a.type;
    c.shape = a.shape;
    c.shape[0] += b.shape[0]; // concat the first dim

    c.data.resize(a.data.size() + b.data.size());
    memcpy(c.data.data(), a.data.data(), a.data.size());
    memcpy(c.data.data() + a.data.size(), b.data.data(), b.data.size());

    return c;
}

// Recover a Parameter from its textual form in the param file, e.g.
// "None" -> null, "True" -> bool, "7" -> int, "0.1" -> float,
// "(1,2,3)" -> int array, "relu" -> string (anything not starting with
// a digit or minus-digit is treated as a string).
Parameter Parameter::parse_from_string(const std::string& value)
{
    Parameter p;
    p.type = 0;

    if (value == "None" || value == "()" || value == "[]")
    {
        return p;
    }

    if (value == "True" || value == "False")
    {
        // bool
        p.type = 1;
        p.b = value == "True";
        return p;
    }

    if (value[0] == '(' || value[0] == '[')
    {
        // list
        std::string lc = value.substr(1, value.size() - 2);
        std::istringstream lcss(lc);

        while (!lcss.eof())
        {
            std::string elem;
            std::getline(lcss, elem, ',');

            if ((elem[0] != '-' && (elem[0] < '0' || elem[0] > '9')) || (elem[0] == '-' && (elem[1] < '0' || elem[1] > '9')))
            {
                // string
                p.type = 7;
                p.as.push_back(elem);
            }
            else if (elem.find('.') != std::string::npos || elem.find('e') != std::string::npos)
            {
                // float
                p.type = 6;
                p.af.push_back(std::stof(elem));
            }
            else
            {
                // integer
                p.type = 5;
                p.ai.push_back(std::stoi(elem));
            }
        }
        return p;
    }

    if ((value[0] != '-' && (value[0] < '0' || value[0] > '9')) || (value[0] == '-' && (value[1] < '0' || value[1] > '9')))
    {
        // string
        p.type = 4;
        p.s = value;
        return p;
    }

    if (value.find('.') != std::string::npos || value.find('e') != std::string::npos)
    {
        // float
        p.type = 3;
        p.f = std::stof(value);
        return p;
    }

    // integer
    p.type = 2;
    p.i = std::stoi(value);
    return p;
}

Graph::Graph()
{
}

Graph::~Graph()
{
    for (auto x : ops)
        delete x;

    for (auto x : operands)
        delete x;

    ops.clear();
    operands.clear();
}

Graph::Graph(const Graph& /*rhs*/)
{
}

Graph& Graph::operator=(const Graph& /*rhs*/)
{
    return *this;
}

static void load_parameter(Operator* op, const std::string& key, const std::string& value)
{
    op->params[key] = Parameter::parse_from_string(value);
}

static void load_input_key(Operator* op, const std::string& key, const std::string& value)
{
    op->inputnames.resize(op->inputs.size());

    for (size_t i = 0; i < op->inputs.size(); i++)
    {
        const Operand* oprand = op->inputs[i];
        if (oprand->name == value)
        {
            op->inputnames[i] = key;
            break;
        }
    }
}
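// load_shape parses an operand shape annotation whose value has the form
// "(d0,d1,...)dtype", e.g. "(1,3,224,224)f32"; a '?' dimension is recorded
// as -1 (dynamic). The "#name=" prefix was already split off by the caller.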
static void load_shape(Operator* op, const std::string& key, const std::string& value)
{
    Operand* operand = 0;
    for (auto r : op->inputs)
    {
        if (r->name == key)
        {
            operand = r;
            break;
        }
    }

    if (!operand)
    {
        for (auto r : op->outputs)
        {
            if (r->name == key)
            {
                operand = r;
                break;
            }
        }
    }

    if (!operand)
    {
        fprintf(stderr, "no such operand %s for operator %s\n", key.c_str(), op->name.c_str());
        return;
    }

    // type
    std::string typestr = value.substr(value.find_last_of(')') + 1);
    operand->type = string_to_type(typestr.c_str());

    // shape
    std::string lc = value.substr(1, value.find_last_of(')') - 1);
    std::istringstream lcss(lc);

    operand->shape.clear();
    while (!lcss.eof())
    {
        std::string elem;
        std::getline(lcss, elem, ',');

        if (elem == "?")
        {
            operand->shape.push_back(-1);
        }
        else
        {
            int i = std::stoi(elem);
            operand->shape.push_back(i);
        }
    }
}

static void load_attribute(Operator* op, const std::string& key, const std::string& value, StoreZipReader& szr)
{
    Attribute& a = op->attrs[key];

    // type
    std::string typestr = value.substr(value.find_last_of(')') + 1);
    a.type = string_to_type(typestr.c_str());

    if (a.type == 0)
        return;

    // shape
    std::string lc = value.substr(1, value.find_last_of(')') - 1);
    std::istringstream lcss(lc);

    a.shape.clear();
    while (!lcss.eof())
    {
        std::string elem;
        std::getline(lcss, elem, ',');

        int i = std::stoi(elem);
        a.shape.push_back(i);
    }

    if (a.shape.empty())
        return;

    // data
    size_t size = 1;
    for (int i : a.shape)
    {
        size *= i;
    }

    size_t bytesize = size * type_to_elemsize(a.type);

    std::string filename = op->name + "." + key;
    size_t filesize = szr.get_file_size(filename);

    if (filesize == 0)
    {
        // no such file
        return;
    }

    if (filesize != bytesize)
    {
        fprintf(stderr, "file size does not match, expect %lu but got %lu\n", bytesize, filesize);
    }

    a.data.resize(bytesize);
    szr.read_file(filename, (char*)a.data.data());
}

int Graph::load(const std::string& parampath, const std::string& binpath)
{
    std::ifstream is(parampath, std::ios::in | std::ios::binary);
    if (!is.good())
    {
        fprintf(stderr, "open failed\n");
        return -1;
    }

    StoreZipReader szr;
    if (szr.open(binpath) != 0)
    {
        fprintf(stderr, "open failed\n");
        return -1;
    }

    int magic = 0;
    {
        std::string line;
        std::getline(is, line);
        std::istringstream iss(line);
        iss >> magic;
    }

    int operator_count = 0;
    int operand_count = 0;
    {
        std::string line;
        std::getline(is, line);
        std::istringstream iss(line);
        iss >> operator_count >> operand_count;
    }

    for (int i = 0; i < operator_count; i++)
    {
        std::string line;
        std::getline(is, line);
        std::istringstream iss(line);

        std::string type;
        std::string name;
        int input_count = 0;
        int output_count = 0;

        iss >> type >> name >> input_count >> output_count;

        Operator* op = new_operator(type, name);

        for (int j = 0; j < input_count; j++)
        {
            std::string operand_name;
            iss >> operand_name;

            Operand* r = get_operand(operand_name);
            r->consumers.push_back(op);
            op->inputs.push_back(r);
        }

        for (int j = 0; j < output_count; j++)
        {
            std::string operand_name;
            iss >> operand_name;

            Operand* r = new_operand(operand_name);
            r->producer = op;
            op->outputs.push_back(r);
        }

        // key=value
        while (!iss.eof())
        {
            std::string param;
            iss >> param;

            std::string key;
            std::string value;
            std::istringstream pss(param);
            std::getline(pss, key, '=');
            std::getline(pss, value);

            if (key[0] == '@')
            {
                // attribute
                load_attribute(op, key.substr(1), value, szr);
            }
            else if (key[0] == '$')
            {
                // operand input key
                load_input_key(op, key.substr(1), value);
            }
            else if (key[0] == '#')
            {
                // operand shape
                load_shape(op, key.substr(1), value);
            }
            else
            {
                // parameter
                load_parameter(op, key, value);
            }
        }
    }

    return 0;
}
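// Graph::save writes the text param file consumed by load() above.
// Illustrative layout (whitespace-separated, one operator per line; the
// attribute payloads go into the zip at binpath as "<opname>.<attrname>"
// entries). The operator line below is made up, not taken from a real model:
//
//   7767517
//   <operator count> <operand count>
//   nn.Conv2d  conv_0  1 1  in0 1  bias=True in_channels=3 @weight=(16,3,3,3)f32 #in0=(1,3,224,224)f32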
int Graph::save(const std::string& parampath, const std::string& binpath)
{
    FILE* paramfp = fopen(parampath.c_str(), "wb");
    if (!paramfp)
    {
        fprintf(stderr, "fopen %s failed\n", parampath.c_str());
        return -1;
    }

    StoreZipWriter szw;
    if (szw.open(binpath) != 0)
    {
        fprintf(stderr, "open failed\n");
        return -1;
    }

    // magic
    fprintf(paramfp, "7767517\n");

    // operator count and operand count
    fprintf(paramfp, "%d %d\n", (int)ops.size(), (int)operands.size());

    for (const Operator* op : ops)
    {
        fprintf(paramfp, "%-24s %-24s %d %d", op->type.c_str(), op->name.c_str(), (int)op->inputs.size(), (int)op->outputs.size());

        for (const Operand* oprand : op->inputs)
        {
            fprintf(paramfp, " %s", oprand->name.c_str());
        }

        for (const Operand* oprand : op->outputs)
        {
            fprintf(paramfp, " %s", oprand->name.c_str());
        }

        for (const auto& it : op->params)
        {
            fprintf(paramfp, " %s=", it.first.c_str());

            const Parameter& param = it.second;
            if (param.type == 0)
            {
                fprintf(paramfp, "None");
            }
            if (param.type == 1)
            {
                if (param.b)
                    fprintf(paramfp, "True");
                else
                    fprintf(paramfp, "False");
            }
            if (param.type == 2)
            {
                fprintf(paramfp, "%d", param.i);
            }
            if (param.type == 3)
            {
                fprintf(paramfp, "%e", param.f);
            }
            if (param.type == 4)
            {
                fprintf(paramfp, "%s", param.s.c_str());
            }
            if (param.type == 5)
            {
                fprintf(paramfp, "(");
                for (size_t i = 0; i < param.ai.size(); i++)
                {
                    fprintf(paramfp, "%d", param.ai[i]);
                    if (i + 1 != param.ai.size())
                        fprintf(paramfp, ",");
                }
                fprintf(paramfp, ")");
            }
            if (param.type == 6)
            {
                fprintf(paramfp, "(");
                for (size_t i = 0; i < param.af.size(); i++)
                {
                    fprintf(paramfp, "%e", param.af[i]);
                    if (i + 1 != param.af.size())
                        fprintf(paramfp, ",");
                }
                fprintf(paramfp, ")");
            }
            if (param.type == 7)
            {
                fprintf(paramfp, "(");
                for (size_t i = 0; i < param.as.size(); i++)
                {
                    fprintf(paramfp, "%s", param.as[i].c_str());
                    if (i + 1 != param.as.size())
                        fprintf(paramfp, ",");
                }
                fprintf(paramfp, ")");
            }
        }
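        // Attributes are serialized in two parts: shape and dtype go into the
        // param text as "@key=(d0,d1,...)dtype", while the raw bytes are
        // written to the zip under "<opname>.<key>" (szw.write_file below).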
        for (const auto& it : op->attrs)
        {
            fprintf(paramfp, " @%s=", it.first.c_str());

            const Attribute& attr = it.second;
            fprintf(paramfp, "(");
            for (int i = 0; i < (int)attr.shape.size() - 1; i++)
            {
                fprintf(paramfp, "%d,", attr.shape[i]);
            }
            if (attr.shape.size() > 0)
                fprintf(paramfp, "%d", attr.shape[attr.shape.size() - 1]);
            fprintf(paramfp, ")");

            fprintf(paramfp, "%s", type_to_string(attr.type));

            std::string filename = op->name + "." + it.first;
            szw.write_file(filename, attr.data.data(), attr.data.size());
        }

        if (op->inputnames.size() == op->inputs.size())
        {
            for (size_t i = 0; i < op->inputs.size(); i++)
            {
                if (op->inputnames[i].empty())
                    continue;

                const Operand* oprand = op->inputs[i];
                fprintf(paramfp, " $%s=%s", op->inputnames[i].c_str(), oprand->name.c_str());
            }
        }

        for (const Operand* oprand : op->inputs)
        {
            if (oprand->shape.empty())
                continue;

            fprintf(paramfp, " #%s=", oprand->name.c_str());

            fprintf(paramfp, "(");
            for (int i = 0; i < (int)oprand->shape.size() - 1; i++)
            {
                if (oprand->shape[i] == -1)
                    fprintf(paramfp, "?,");
                else
                    fprintf(paramfp, "%d,", oprand->shape[i]);
            }
            if (oprand->shape.size() > 0)
            {
                if (oprand->shape[oprand->shape.size() - 1] == -1)
                    fprintf(paramfp, "?");
                else
                    fprintf(paramfp, "%d", oprand->shape[oprand->shape.size() - 1]);
            }
            fprintf(paramfp, ")");

            fprintf(paramfp, "%s", type_to_string(oprand->type));
        }

        for (const Operand* oprand : op->outputs)
        {
            if (oprand->shape.empty())
                continue;

            fprintf(paramfp, " #%s=", oprand->name.c_str());

            fprintf(paramfp, "(");
            for (int i = 0; i < (int)oprand->shape.size() - 1; i++)
            {
                if (oprand->shape[i] == -1)
                    fprintf(paramfp, "?,");
                else
                    fprintf(paramfp, "%d,", oprand->shape[i]);
            }
            if (oprand->shape.size() > 0)
            {
                if (oprand->shape[oprand->shape.size() - 1] == -1)
                    fprintf(paramfp, "?");
                else
                    fprintf(paramfp, "%d", oprand->shape[oprand->shape.size() - 1]);
            }
            fprintf(paramfp, ")");

            fprintf(paramfp, "%s", type_to_string(oprand->type));
        }

        fprintf(paramfp, "\n");
    }

    fclose(paramfp);

    return 0;
}

static std::string sanitize_identifier(const std::string& s)
{
    std::string ss = s;
    for (size_t i = 0; i < ss.size(); i++)
    {
        if (ss[i] == '.' || ss[i] == ':')
            ss[i] = '_';
    }

    return ss;
}
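// pnnx.Expression stores its formula in a compact prefix notation, e.g.
// "mul(@0,sqrt(size(@1,0)))" where @N names the operator's Nth input operand.
// expand_expression() below rebuilds an equivalent Python expression by
// scanning the tokens right-to-left with a stack; the example above becomes
// "(v_a * torch.sqrt(v_b.size(0)))" (the operand names here are illustrative).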
static std::string expand_expression(const Operator* op)
{
    std::string expr = op->params.at("expr").s;

    // split into tokens
    std::vector<std::string> tokens;
    {
        std::string t;
        for (size_t i = 0; i < expr.size(); i++)
        {
            char ch = expr[i];

            if (ch == '[') // list
            {
                t += ch;
                tokens.push_back(t);
                t.clear();
            }
            else if (ch == '(' || ch == ')' || ch == ',' || ch == ']')
            {
                if (!t.empty())
                {
                    tokens.push_back(t);
                    t.clear();
                }
            }
            else
            {
                t += ch;
            }
        }

        if (!t.empty())
        {
            tokens.push_back(t);
        }
    }

    // scan and stack
    std::stack<std::string> exprstack;
    for (int i = (int)tokens.size() - 1; i >= 0; i--)
    {
        const std::string& t = tokens[i];

        if (t == "size")
        {
            std::string a = exprstack.top();
            exprstack.pop();
            std::string b = exprstack.top();
            exprstack.pop();

            std::string r = a + ".size(" + b + ")";
            exprstack.push(r);
        }
        else if (t == "int" || t == "sqrt" || t == "rsqrt" || t == "neg")
        {
            std::string unaryop;
            if (t == "int") unaryop = "int";
            if (t == "sqrt") unaryop = "torch.sqrt";
            if (t == "rsqrt") unaryop = "torch.rsqrt";
            if (t == "neg") unaryop = "torch.neg";

            std::string a = exprstack.top();
            exprstack.pop();

            std::string r = unaryop + "(" + a + ")";
            exprstack.push(r);
        }
        else if (t == "pow")
        {
            std::string a = exprstack.top();
            exprstack.pop();
            std::string b = exprstack.top();
            exprstack.pop();

            std::string r = a + ".pow(" + b + ")";
            exprstack.push(r);
        }
        else if (t == "add" || t == "sub" || t == "mul" || t == "div" || t == "floor_divide")
        {
            std::string binaryop;
            if (t == "add") binaryop = "+";
            if (t == "sub") binaryop = "-";
            if (t == "mul") binaryop = "*";
            if (t == "div") binaryop = "/";
            if (t == "floor_divide") binaryop = "//";

            std::string a = exprstack.top();
            exprstack.pop();
            std::string b = exprstack.top();
            exprstack.pop();

            std::string r = std::string("(") + a + " " + binaryop + " " + b + ")";
            exprstack.push(r);
        }
        else if (t == "[") // list
        {
            std::vector<std::string> elements;
            while (!exprstack.empty())
            {
                std::string a = exprstack.top();
                exprstack.pop();

                elements.push_back(a);
            }

            std::string r = "[";
            for (int j = 0; j < (int)elements.size() - 1; j++)
            {
                r += elements[j];
                if (j + 1 != (int)elements.size())
                    r += ", ";
            }
            if (!elements.empty())
            {
                r += elements[elements.size() - 1];
            }
            r += "]";

            exprstack.push(r);
        }
        else if (t[0] == '@')
        {
            int input_index = std::stoi(t.substr(1));
            std::string varid = std::string("v_") + sanitize_identifier(op->inputs[input_index]->name);
            exprstack.push(varid);
        }
        else
        {
            // literal
            exprstack.push(t);
        }
    }

    std::string r = exprstack.top();
    exprstack.pop();

    return r;
}

// Build the Python subscript for a Tensor.slice operator, e.g.
// dims=[1] starts=[0] ends=[10] steps=[2] produces ":,:10:2",
// which the forward body wraps as v_out = v_in[:,:10:2].
static std::string make_slice_expression(const Operator* op)
{
    for (size_t j = 0; j < op->inputnames.size(); j++)
    {
        fprintf(stderr, "make_slice_expression %s %s\n", op->inputnames[j].c_str(), op->inputs[j]->name.c_str());
    }

    std::vector<int> dims = op->params.at("dims").ai;

    std::string r;

    int last_dim = -1;
    const int ndim = (int)dims.size();
    for (int i = 0; i < ndim; i++)
    {
        int dim = dims[i];
        for (int j = last_dim + 1; j < dim; j++)
        {
            r += ":,";
        }
        last_dim = dim;

        if (op->params.find("starts") != op->params.end())
        {
            std::vector<int> starts = op->params.at("starts").ai;
            int start = starts[i];

            if (start != 0)
                r += std::to_string(start);
        }
        else
        {
            fprintf(stderr, "find start\n");
            // find start
            for (size_t j = 0; j < op->inputnames.size(); j++)
            {
                if (op->inputnames[j] == "start")
                {
                    r += std::string("v_") + sanitize_identifier(op->inputs[j]->name);
                    fprintf(stderr, "find start %s\n", op->inputs[j]->name.c_str());
                    break;
                }
            }
        }

        r += ':';

        if (op->params.find("ends") != op->params.end())
        {
            std::vector<int> ends = op->params.at("ends").ai;
            int end = ends[i];
            if (end != -1)
                r += std::to_string(end);
        }
        else
        {
            // find end
            for (size_t j = 0; j < op->inputnames.size(); j++)
            {
                if (op->inputnames[j] == "end")
                {
                    r += std::string("v_") + sanitize_identifier(op->inputs[j]->name);
                    break;
                }
            }
        }

        if (op->params.find("steps") != op->params.end())
        {
            std::vector<int> steps = op->params.at("steps").ai;
            int step = steps[i];
            if (step != 1)
            {
                r += ':';
                r += std::to_string(step);
            }
        }
        else
        {
            // find step
            for (size_t j = 0; j < op->inputnames.size(); j++)
            {
                if (op->inputnames[j] == "step")
                {
                    r += ':';
                    r += std::string("v_") + sanitize_identifier(op->inputs[j]->name);
                    break;
                }
            }
        }

        if (i + 1 != ndim)
            r += ',';
    }

    return r;
}

static std::string make_index_expression(const Operator* op)
{
    fprintf(stderr, "make_index_expression %s\n", op->name.c_str());

    std::string index_expr = op->params.at("expr").s;

    // strip the outermost [ ] pair
    index_expr = index_expr.substr(1, index_expr.size() - 2);

    // None,None,  ->  ...,
    bool leading_none = false;
    while (index_expr.substr(0, 5) == "None,")
    {
        leading_none = true;
        index_expr = index_expr.substr(5);
    }
    if (leading_none)
    {
        index_expr = "...," + index_expr;
    }

    return index_expr;
}
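// Graph::python emits a self-contained PyTorch script: a Model class that
// recreates every nn.*/torchvision.ops.* module, loads the weights back from
// the pnnx bin zip, a forward() transcribed from the graph, plus
// export_torchscript() and test_inference() helpers that build random inputs
// of the recorded shapes and dtypes.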
int Graph::python(const std::string& pypath, const std::string& pnnxbinpath)
{
    FILE* pyfp = fopen(pypath.c_str(), "wb");
    if (!pyfp)
    {
        fprintf(stderr, "fopen %s failed\n", pypath.c_str());
        return -1;
    }

    fprintf(pyfp, "import os\n");
    fprintf(pyfp, "import numpy as np\n");
    fprintf(pyfp, "import tempfile, zipfile\n");
    fprintf(pyfp, "import torch\n");
    fprintf(pyfp, "import torch.nn as nn\n");
    fprintf(pyfp, "import torch.nn.functional as F\n");
    fprintf(pyfp, "import torchvision\n");
    fprintf(pyfp, "\n");
    fprintf(pyfp, "class Model(nn.Module):\n");
    fprintf(pyfp, "    def __init__(self):\n");
    fprintf(pyfp, "        super(Model, self).__init__()\n");
    fprintf(pyfp, "\n");

    // module
    {
        for (const Operator* op : ops)
        {
            if (op->type.substr(0, 3) != "nn." && op->type.substr(0, 16) != "torchvision.ops.")
                continue;

            fprintf(pyfp, "        self.%s = %s(", sanitize_identifier(op->name).c_str(), op->type.c_str());

            int param_count = op->params.size();
            if (op->type == "nn.quantized.Conv2d" || op->type == "nn.quantized.Linear")
            {
                param_count -= 2; // ignore scale and zero_point
            }

            int param_index = 0;
            for (const auto& it : op->params)
            {
                if (op->type == "nn.quantized.Conv2d" || op->type == "nn.quantized.Linear")
                {
                    if (it.first == "scale" || it.first == "zero_point")
                        continue;
                }

                fprintf(pyfp, "%s=", it.first.c_str());

                const Parameter& param = it.second;
                if (param.type == 0)
                {
                    fprintf(pyfp, "None");
                }
                if (param.type == 1)
                {
                    if (param.b)
                        fprintf(pyfp, "True");
                    else
                        fprintf(pyfp, "False");
                }
                if (param.type == 2)
                {
                    fprintf(pyfp, "%d", param.i);
                }
                if (param.type == 3)
                {
                    fprintf(pyfp, "%f", param.f);
                }
                if (param.type == 4)
                {
                    if (param.s.substr(0, 6) == "torch.")
                    {
                        fprintf(pyfp, "%s", param.s.c_str());
                    }
                    else
                    {
                        fprintf(pyfp, "\'%s\'", param.s.c_str());
                    }
                }
                if (param.type == 5)
                {
                    fprintf(pyfp, "(");
                    for (size_t i = 0; i < param.ai.size(); i++)
                    {
                        fprintf(pyfp, "%d", param.ai[i]);
                        if (i + 1 != param.ai.size() || param.ai.size() == 1)
                            fprintf(pyfp, ",");
                    }
                    fprintf(pyfp, ")");
                }
                if (param.type == 6)
                {
                    fprintf(pyfp, "(");
                    for (size_t i = 0; i < param.af.size(); i++)
                    {
                        fprintf(pyfp, "%f", param.af[i]);
                        if (i + 1 != param.af.size() || param.af.size() == 1)
                            fprintf(pyfp, ",");
                    }
                    fprintf(pyfp, ")");
                }
                if (param.type == 7)
                {
                    fprintf(pyfp, "(");
                    for (size_t i = 0; i < param.as.size(); i++)
                    {
                        if (param.as[i].substr(0, 6) == "torch.")
                        {
                            fprintf(pyfp, "%s", param.as[i].c_str());
                        }
                        else
                        {
                            fprintf(pyfp, "\'%s\'", param.as[i].c_str());
                        }
                        if (i + 1 != param.as.size() || param.as.size() == 1)
                            fprintf(pyfp, ",");
                    }
                    fprintf(pyfp, ")");
                }

                param_index++;
                if (param_index != param_count)
                    fprintf(pyfp, ", ");
            }

            fprintf(pyfp, ")\n");
        }
    }

    fprintf(pyfp, "\n");
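    // Weight loading: plain module attributes become nn.Parameter
    // (running_mean/running_var stay tensors). nn.quantized.Conv2d/Linear are
    // special-cased: their scale/zero_point params were skipped in __init__
    // above and are assigned here after set_weight_bias().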
    // load weights
    {
        fprintf(pyfp, "        archive = zipfile.ZipFile('%s', 'r')\n", pnnxbinpath.c_str());

        for (const Operator* op : ops)
        {
            if (op->type.substr(0, 3) != "nn." && op->type.substr(0, 16) != "torchvision.ops.")
                continue;

            if (op->type == "nn.quantized.Conv2d" || op->type == "nn.quantized.Linear")
            {
                for (const auto& it : op->attrs)
                {
                    if (it.first == "weight" || it.first == "bias")
                    {
                        fprintf(pyfp, "        self_%s_%s = self.load_pnnx_bin_as_parameter(archive, '%s.%s', (", sanitize_identifier(op->name).c_str(), it.first.c_str(), op->name.c_str(), it.first.c_str());
                    }
                    else
                    {
                        // unknown attr
                        continue;
                    }

                    const Attribute& attr = it.second;
                    for (size_t i = 0; i < attr.shape.size(); i++)
                    {
                        fprintf(pyfp, "%d", attr.shape[i]);
                        if (i + 1 != attr.shape.size())
                            fprintf(pyfp, ",");
                    }

                    fprintf(pyfp, "), '%s', requires_grad=False)\n", type_to_numpy_string(attr.type));
                }

                fprintf(pyfp, "        self.%s.set_weight_bias(self_%s_weight, self_%s_bias)\n", sanitize_identifier(op->name).c_str(), sanitize_identifier(op->name).c_str(), sanitize_identifier(op->name).c_str());
                fprintf(pyfp, "        self.%s.scale = %f\n", sanitize_identifier(op->name).c_str(), op->params.at("scale").f);
                fprintf(pyfp, "        self.%s.zero_point = %d\n", sanitize_identifier(op->name).c_str(), op->params.at("zero_point").i);

                continue;
            }

            for (const auto& it : op->attrs)
            {
                if (it.first == "running_mean" || it.first == "running_var")
                {
                    fprintf(pyfp, "        self.%s.%s = self.load_pnnx_bin_as_tensor(archive, '%s.%s', (", sanitize_identifier(op->name).c_str(), it.first.c_str(), op->name.c_str(), it.first.c_str());
                }
                else
                {
                    fprintf(pyfp, "        self.%s.%s = self.load_pnnx_bin_as_parameter(archive, '%s.%s', (", sanitize_identifier(op->name).c_str(), it.first.c_str(), op->name.c_str(), it.first.c_str());
                }

                const Attribute& attr = it.second;
                for (size_t i = 0; i < attr.shape.size(); i++)
                {
                    fprintf(pyfp, "%d", attr.shape[i]);
                    if (i + 1 != attr.shape.size())
                        fprintf(pyfp, ",");
                }

                if (attr.type == 1 || attr.type == 2 || attr.type == 3)
                {
                    fprintf(pyfp, "), '%s')\n", type_to_numpy_string(attr.type));
                }
                else
                {
                    fprintf(pyfp, "), '%s', requires_grad=False)\n", type_to_numpy_string(attr.type));
                }
            }
        }

        for (const Operator* op : ops)
        {
            if (op->type != "pnnx.Attribute")
                continue;

            const std::string& key = op->attrs.begin()->first;
            const Attribute& attr = op->attrs.begin()->second;

            bool is_running_mean_var = false;
            {
                const Operand* r = op->outputs[0];
                if (r->consumers.size() == 1)
                {
                    const Operator* op2 = r->consumers[0];
                    if (op2->type == "F.batch_norm" || op2->type == "F.instance_norm")
                    {
                        if (r == op2->inputs[1] || r == op2->inputs[2])
                        {
                            is_running_mean_var = true;
                        }
                    }
                }
            }

            if (is_running_mean_var)
            {
                fprintf(pyfp, "        self.%s_%s = self.load_pnnx_bin_as_tensor(archive, '%s.%s', (", sanitize_identifier(op->name).c_str(), sanitize_identifier(key).c_str(), op->name.c_str(), key.c_str());
            }
            else
            {
                fprintf(pyfp, "        self.%s_%s = self.load_pnnx_bin_as_parameter(archive, '%s.%s', (", sanitize_identifier(op->name).c_str(), sanitize_identifier(key).c_str(), op->name.c_str(), key.c_str());
            }

            for (size_t i = 0; i < attr.shape.size(); i++)
            {
                fprintf(pyfp, "%d", attr.shape[i]);
                if (i + 1 != attr.shape.size())
                    fprintf(pyfp, ",");
            }

            if (attr.type == 1 || attr.type == 2 || attr.type == 3)
            {
                fprintf(pyfp, "), '%s')\n", type_to_numpy_string(attr.type));
            }
            else
            {
                fprintf(pyfp, "), '%s', requires_grad=False)\n", type_to_numpy_string(attr.type));
            }
        }

        fprintf(pyfp, "        archive.close()\n");
    }

    fprintf(pyfp, "\n");

    // utility function
    {
        fprintf(pyfp, "    def load_pnnx_bin_as_parameter(self, archive, key, shape, dtype, requires_grad=True):\n");
        fprintf(pyfp, "        return nn.Parameter(self.load_pnnx_bin_as_tensor(archive, key, shape, dtype), requires_grad)\n");
        fprintf(pyfp, "\n");
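        // The generated helper below round-trips a zip entry through a
        // temporary file so np.memmap can reinterpret the raw bytes with the
        // recorded dtype and shape; the .copy() detaches the array from the
        // mapping before the temp file is removed.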
        fprintf(pyfp, "    def load_pnnx_bin_as_tensor(self, archive, key, shape, dtype):\n");
        fprintf(pyfp, "        _, tmppath = tempfile.mkstemp()\n");
        fprintf(pyfp, "        tmpf = open(tmppath, 'wb')\n");
        fprintf(pyfp, "        with archive.open(key) as keyfile:\n");
        fprintf(pyfp, "            tmpf.write(keyfile.read())\n");
        fprintf(pyfp, "        tmpf.close()\n");
        fprintf(pyfp, "        m = np.memmap(tmppath, dtype=dtype, mode='r', shape=shape).copy()\n");
        fprintf(pyfp, "        os.remove(tmppath)\n");
        fprintf(pyfp, "        return torch.from_numpy(m)\n");
    }

    fprintf(pyfp, "\n");

    // def forward
    {
        fprintf(pyfp, "    def forward(self");

        for (const Operator* op : ops)
        {
            if (op->type != "pnnx.Input")
                continue;

            fprintf(pyfp, ", v_%s", sanitize_identifier(op->outputs[0]->name).c_str());
        }

        fprintf(pyfp, "):\n");
    }

    // forward body
    {
        for (const Operator* op : ops)
        {
            if (op->type == "pnnx.Input" || op->type == "pnnx.Output")
                continue;

            fprintf(pyfp, "        ");

            if (op->type == "pnnx.Expression")
            {
                // expr
                for (size_t i = 0; i < op->outputs.size(); i++)
                {
                    fprintf(pyfp, "v_%s", sanitize_identifier(op->outputs[i]->name).c_str());
                    if (i + 1 != op->outputs.size())
                        fprintf(pyfp, ", ");
                }
                std::string expanded_expr = expand_expression(op);
                fprintf(pyfp, " = %s\n", expanded_expr.c_str());
            }
            else if (op->type == "pnnx.Attribute")
            {
                const std::string& key = op->attrs.begin()->first;
                fprintf(pyfp, "v_%s = self.%s_%s\n", sanitize_identifier(op->outputs[0]->name).c_str(), sanitize_identifier(op->name).c_str(), sanitize_identifier(key).c_str());
            }
            else if (op->type == "Tensor.slice")
            {
                // slice expr
                std::string slice_expr = make_slice_expression(op);
                fprintf(pyfp, "v_%s = v_%s[%s]\n", sanitize_identifier(op->outputs[0]->name).c_str(), sanitize_identifier(op->inputs[0]->name).c_str(), slice_expr.c_str());
            }
            else if (op->type == "Tensor.index")
            {
                // index expr
                if (op->inputs.size() == 2)
                {
                    std::string expanded_expr = expand_expression(op->inputs[1]->producer);
                    fprintf(pyfp, "v_%s = v_%s[%s]\n", sanitize_identifier(op->outputs[0]->name).c_str(), sanitize_identifier(op->inputs[0]->name).c_str(), expanded_expr.c_str());
                }
                else
                {
                    std::string index_expr = make_index_expression(op);
                    fprintf(pyfp, "v_%s = v_%s[%s]\n", sanitize_identifier(op->outputs[0]->name).c_str(), sanitize_identifier(op->inputs[0]->name).c_str(), index_expr.c_str());
                }
            }
            else if (op->type == "Tensor.view" || op->type == "Tensor.reshape")
            {
                // view reshape
                fprintf(pyfp, "v_%s = v_%s.%s(", sanitize_identifier(op->outputs[0]->name).c_str(), sanitize_identifier(op->inputs[0]->name).c_str(), op->type.substr(7).c_str());
                if (op->inputs.size() == 2)
                {
                    fprintf(pyfp, "*v_%s", sanitize_identifier(op->inputs[1]->name).c_str());
                }
                else
                {
                    const std::vector<int>& shape = op->params.at("shape").ai;
                    for (size_t i = 0; i < shape.size(); i++)
                    {
                        fprintf(pyfp, "%d", shape[i]);
                        if (i + 1 != shape.size())
                            fprintf(pyfp, ", ");
                    }
                }
                fprintf(pyfp, ")\n");
            }
            else if (op->type == "Tensor.repeat")
            {
                // repeat
                fprintf(pyfp, "v_%s = v_%s.%s(", sanitize_identifier(op->outputs[0]->name).c_str(), sanitize_identifier(op->inputs[0]->name).c_str(), op->type.substr(7).c_str());
                if (op->inputs.size() == 2)
                {
                    fprintf(pyfp, "*v_%s", sanitize_identifier(op->inputs[1]->name).c_str());
                }
                else
                {
                    const std::vector<int>& sizes = op->params.at("sizes").ai;
                    for (size_t i = 0; i < sizes.size(); i++)
                    {
                        fprintf(pyfp, "%d", sizes[i]);
                        if (i + 1 != sizes.size())
                            fprintf(pyfp, ", ");
                    }
                }
                fprintf(pyfp, ")\n");
            }
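            // torch.cat / torch.stack take the input tensors as a single
            // tuple argument, so multiple inputs are wrapped in parentheses
            // before dim= is appended.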
            else if (op->type == "torch.cat" || op->type == "torch.stack")
            {
                // cat
                fprintf(pyfp, "v_%s = %s(", sanitize_identifier(op->outputs[0]->name).c_str(), op->type.c_str());
                if (op->inputs.size() == 1)
                {
                    fprintf(pyfp, "v_%s", sanitize_identifier(op->inputs[0]->name).c_str());
                }
                else
                {
                    fprintf(pyfp, "(");
                    for (size_t i = 0; i < op->inputs.size(); i++)
                    {
                        fprintf(pyfp, "v_%s", sanitize_identifier(op->inputs[i]->name).c_str());
                        if (i + 1 != op->inputs.size())
                            fprintf(pyfp, ", ");
                    }
                    fprintf(pyfp, ")");
                }
                fprintf(pyfp, ", dim=%d", op->params.at("dim").i);
                fprintf(pyfp, ")\n");
            }
            else if (op->type == "prim::TupleUnpack")
            {
                for (size_t i = 0; i < op->outputs.size(); i++)
                {
                    fprintf(pyfp, "v_%s", sanitize_identifier(op->outputs[i]->name).c_str());
                    if (i + 1 != op->outputs.size())
                        fprintf(pyfp, ", ");
                }
                fprintf(pyfp, " = v_%s\n", sanitize_identifier(op->inputs[0]->name).c_str());
            }
            else if (op->type == "prim::TupleConstruct")
            {
                fprintf(pyfp, "v_%s", sanitize_identifier(op->outputs[0]->name).c_str());
                fprintf(pyfp, " = (");
                for (size_t i = 0; i < op->inputs.size(); i++)
                {
                    fprintf(pyfp, "v_%s, ", sanitize_identifier(op->inputs[i]->name).c_str());
                }
                fprintf(pyfp, ")\n");
            }
            else if (op->type == "prim::ListUnpack")
            {
                for (size_t i = 0; i < op->outputs.size(); i++)
                {
                    fprintf(pyfp, "v_%s", sanitize_identifier(op->outputs[i]->name).c_str());
                    if (i + 1 != op->outputs.size())
                        fprintf(pyfp, ", ");
                }
                fprintf(pyfp, " = v_%s\n", sanitize_identifier(op->inputs[0]->name).c_str());
            }
            else if (op->type == "prim::ListConstruct")
            {
                fprintf(pyfp, "v_%s", sanitize_identifier(op->outputs[0]->name).c_str());
                fprintf(pyfp, " = [");
                for (size_t i = 0; i < op->inputs.size(); i++)
                {
                    fprintf(pyfp, "v_%s", sanitize_identifier(op->inputs[i]->name).c_str());
                    if (i + 1 != op->inputs.size())
                        fprintf(pyfp, ", ");
                }
                fprintf(pyfp, "]\n");
            }
            else if (op->type == "nn.LSTM")
            {
                if (op->outputs.size() == 1)
                {
                    fprintf(pyfp, "v_%s, _", sanitize_identifier(op->outputs[0]->name).c_str());
                }
                else
                {
                    fprintf(pyfp, "v_%s, (v_%s, v_%s)", sanitize_identifier(op->outputs[0]->name).c_str(), sanitize_identifier(op->outputs[1]->name).c_str(), sanitize_identifier(op->outputs[2]->name).c_str());
                }

                fprintf(pyfp, " = self.%s(", sanitize_identifier(op->name).c_str());

                fprintf(pyfp, "v_%s", sanitize_identifier(op->inputs[0]->name).c_str());

                if (op->inputs.size() == 3)
                {
                    fprintf(pyfp, ", (v_%s, v_%s)", sanitize_identifier(op->inputs[1]->name).c_str(), sanitize_identifier(op->inputs[2]->name).c_str());
                }

                fprintf(pyfp, ")\n");
            }
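            // Generic case: any other nn.*/torchvision.ops.* operator becomes
            // a call on the submodule registered in __init__ above.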
            else if (op->type.substr(0, 3) == "nn." || op->type.substr(0, 16) == "torchvision.ops.")
            {
                // self.xxx()
                for (size_t i = 0; i < op->outputs.size(); i++)
                {
                    fprintf(pyfp, "v_%s", sanitize_identifier(op->outputs[i]->name).c_str());
                    if (i + 1 != op->outputs.size())
                        fprintf(pyfp, ", ");
                }

                fprintf(pyfp, " = self.%s(", sanitize_identifier(op->name).c_str());

                for (size_t i = 0; i < op->inputs.size(); i++)
                {
                    fprintf(pyfp, "v_%s", sanitize_identifier(op->inputs[i]->name).c_str());
                    if (i + 1 != op->inputs.size())
                        fprintf(pyfp, ", ");
                }

                fprintf(pyfp, ")\n");
            }
            else if (op->type.find("::") != std::string::npos || op->type.find(".") != std::string::npos)
            {
                // direct
                for (size_t i = 0; i < op->outputs.size(); i++)
                {
                    fprintf(pyfp, "v_%s", sanitize_identifier(op->outputs[i]->name).c_str());
                    if (i + 1 != op->outputs.size())
                        fprintf(pyfp, ", ");
                }

                if (op->type.substr(0, 7) == "Tensor.")
                {
                    fprintf(pyfp, " = v_%s.%s(", sanitize_identifier(op->inputs[0]->name).c_str(), op->type.substr(7).c_str());
                }
                else
                {
                    fprintf(pyfp, " = %s(", op->type.c_str());

                    if (op->inputnames.size() == op->inputs.size())
                    {
                        for (size_t i = 0; i < op->inputs.size(); i++)
                        {
                            if (!op->inputnames[i].empty())
                                continue;

                            fprintf(pyfp, "v_%s", sanitize_identifier(op->inputs[i]->name).c_str());
                            if (i + 1 != op->inputs.size())
                                fprintf(pyfp, ", ");
                        }
                        for (size_t i = 0; i < op->inputs.size(); i++)
                        {
                            if (op->inputnames[i].empty())
                                continue;

                            fprintf(pyfp, "%s=v_%s", op->inputnames[i].c_str(), sanitize_identifier(op->inputs[i]->name).c_str());
                            if (i + 1 != op->inputs.size())
                                fprintf(pyfp, ", ");
                        }
                    }
                    else
                    {
                        for (size_t i = 0; i < op->inputs.size(); i++)
                        {
                            fprintf(pyfp, "v_%s", sanitize_identifier(op->inputs[i]->name).c_str());
                            if (i + 1 != op->inputs.size())
                                fprintf(pyfp, ", ");
                        }
                    }
                }
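                // Render the remaining operator params as Python keyword
                // arguments, using the same literal formatting as the module
                // section above (single-element tuples keep a trailing comma).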
                int i = 0;
                for (const auto& it : op->params)
                {
                    if (op->type.substr(0, 7) == "Tensor." && i == 0)
                    {
                        fprintf(pyfp, "%s=", it.first.c_str());
                    }
                    else if (op->inputs.empty() && i == 0)
                    {
                        fprintf(pyfp, "%s=", it.first.c_str());
                    }
                    else
                    {
                        fprintf(pyfp, ", %s=", it.first.c_str());
                    }
                    i++;

                    const Parameter& param = it.second;
                    if (param.type == 0)
                    {
                        fprintf(pyfp, "None");
                    }
                    if (param.type == 1)
                    {
                        if (param.b)
                            fprintf(pyfp, "True");
                        else
                            fprintf(pyfp, "False");
                    }
                    if (param.type == 2)
                    {
                        fprintf(pyfp, "%d", param.i);
                    }
                    if (param.type == 3)
                    {
                        fprintf(pyfp, "%f", param.f);
                    }
                    if (param.type == 4)
                    {
                        if (param.s.substr(0, 6) == "torch.")
                        {
                            fprintf(pyfp, "%s", param.s.c_str());
                        }
                        else
                        {
                            fprintf(pyfp, "\'%s\'", param.s.c_str());
                        }
                    }
                    if (param.type == 5)
                    {
                        fprintf(pyfp, "(");
                        for (size_t i = 0; i < param.ai.size(); i++)
                        {
                            fprintf(pyfp, "%d", param.ai[i]);
                            if (i + 1 != param.ai.size() || param.ai.size() == 1)
                                fprintf(pyfp, ",");
                        }
                        fprintf(pyfp, ")");
                    }
                    if (param.type == 6)
                    {
                        fprintf(pyfp, "(");
                        for (size_t i = 0; i < param.af.size(); i++)
                        {
                            fprintf(pyfp, "%f", param.af[i]);
                            if (i + 1 != param.af.size() || param.af.size() == 1)
                                fprintf(pyfp, ",");
                        }
                        fprintf(pyfp, ")");
                    }
                    if (param.type == 7)
                    {
                        fprintf(pyfp, "(");
                        for (size_t i = 0; i < param.as.size(); i++)
                        {
                            if (param.as[i].substr(0, 6) == "torch.")
                            {
                                fprintf(pyfp, "%s", param.as[i].c_str());
                            }
                            else
                            {
                                fprintf(pyfp, "\'%s\'", param.as[i].c_str());
                            }
                            if (i + 1 != param.as.size() || param.as.size() == 1)
                                fprintf(pyfp, ",");
                        }
                        fprintf(pyfp, ")");
                    }
                }

                fprintf(pyfp, ")\n");
            }
            else
            {
                fprintf(stderr, "todo %s\n", op->type.c_str());
            }
        }
    }

    // return
    {
        fprintf(pyfp, "        return ");

        int output_count = 0;
        {
            for (const Operator* op : ops)
            {
                if (op->type == "pnnx.Output")
                    output_count++;
            }
        }

        int output_index = 0;
        for (const Operator* op : ops)
        {
            if (op->type != "pnnx.Output")
                continue;

            fprintf(pyfp, "v_%s", sanitize_identifier(op->inputs[0]->name).c_str());
            if (output_index + 1 != output_count)
                fprintf(pyfp, ", ");
            output_index++;
        }

        fprintf(pyfp, "\n");
    }

    fprintf(pyfp, "\n");

    // export torchscript
    {
        fprintf(pyfp, "def export_torchscript():\n");
        fprintf(pyfp, "    net = Model()\n");
        fprintf(pyfp, "    net.eval()\n");
        fprintf(pyfp, "\n");
        fprintf(pyfp, "    torch.manual_seed(0)\n");

        std::vector<std::string> input_names;
        for (const Operator* op : ops)
        {
            if (op->type != "pnnx.Input")
                continue;

            const Operand* r = op->outputs[0];

            std::string input_name = std::string("v_") + sanitize_identifier(r->name);

            if (type_is_integer(r->type))
            {
                fprintf(pyfp, "    %s = torch.randint(10, (", input_name.c_str());
                for (size_t i = 0; i < r->shape.size(); i++)
                {
                    fprintf(pyfp, "%d", r->shape[i]);
                    if (i + 1 != r->shape.size() || r->shape.size() == 1)
                        fprintf(pyfp, ", ");
                }
                fprintf(pyfp, "), dtype=%s)\n", type_to_dtype_string(r->type));
            }
            else
            {
                fprintf(pyfp, "    %s = torch.rand(", input_name.c_str());
                for (size_t i = 0; i < r->shape.size(); i++)
                {
                    fprintf(pyfp, "%d, ", r->shape[i]);
                }
                fprintf(pyfp, "dtype=%s)\n", type_to_dtype_string(r->type));
            }

            input_names.push_back(input_name);
        }

        fprintf(pyfp, "\n");

        if (input_names.size() == 1)
        {
            fprintf(pyfp, "    mod = torch.jit.trace(net, %s)\n", input_names[0].c_str());
        }
        else
        {
            fprintf(pyfp, "    mod = torch.jit.trace(net, (");
            for (size_t i = 0; i < input_names.size(); i++)
            {
                fprintf(pyfp, "%s", input_names[i].c_str());
                if (i + 1 != input_names.size())
                    fprintf(pyfp, ", ");
            }
            fprintf(pyfp, "))\n");
        }

        fprintf(pyfp, "    mod.save(\"%s.pt\")\n", pypath.c_str());
    }

    fprintf(pyfp, "\n");

    // test inference
    {
        fprintf(pyfp, "def test_inference():\n");
        fprintf(pyfp, "    net = Model()\n");
        fprintf(pyfp, "    net.eval()\n");
        fprintf(pyfp, "\n");
        fprintf(pyfp, "    torch.manual_seed(0)\n");
        std::vector<std::string> input_names;
        for (const Operator* op : ops)
        {
            if (op->type != "pnnx.Input")
                continue;

            const Operand* r = op->outputs[0];

            std::string input_name = std::string("v_") + sanitize_identifier(r->name);

            if (type_is_integer(r->type))
            {
                fprintf(pyfp, "    %s = torch.randint(10, (", input_name.c_str());
                for (size_t i = 0; i < r->shape.size(); i++)
                {
                    fprintf(pyfp, "%d", r->shape[i]);
                    if (i + 1 != r->shape.size() || r->shape.size() == 1)
                        fprintf(pyfp, ", ");
                }
                fprintf(pyfp, "), dtype=%s)\n", type_to_dtype_string(r->type));
            }
            else
            {
                fprintf(pyfp, "    %s = torch.rand(", input_name.c_str());
                for (size_t i = 0; i < r->shape.size(); i++)
                {
                    fprintf(pyfp, "%d, ", r->shape[i]);
                }
                fprintf(pyfp, "dtype=%s)\n", type_to_dtype_string(r->type));
            }

            input_names.push_back(input_name);
        }

        fprintf(pyfp, "\n");

        if (input_names.size() == 1)
        {
            fprintf(pyfp, "    return net(%s)\n", input_names[0].c_str());
        }
        else
        {
            fprintf(pyfp, "    return net(");
            for (size_t i = 0; i < input_names.size(); i++)
            {
                fprintf(pyfp, "%s", input_names[i].c_str());
                if (i + 1 != input_names.size())
                    fprintf(pyfp, ", ");
            }
            fprintf(pyfp, ")\n");
        }
    }

    fclose(pyfp);

    return 0;
}

static bool string_is_positive_integer(const std::string& t)
{
    for (size_t i = 0; i < t.size(); i++)
    {
        if (t[i] < '0' || t[i] > '9')
            return false;
    }

    return true;
}
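// float32_to_float16 repacks an IEEE 754 binary32 value (1:8:23 bit layout)
// into binary16 (1:5:10): the exponent is rebiased from 127 to 15, the
// significand is truncated (>>13), overflow saturates to infinity, and values
// too small for a normal fp16 are flushed to signed zero. For example, 1.0f
// (0x3F800000) becomes 0x3C00.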
static unsigned short float32_to_float16(float value)
{
    // 1 : 8 : 23
    union
    {
        unsigned int u;
        float f;
    } tmp;

    tmp.f = value;

    // 1 : 8 : 23
    unsigned short sign = (tmp.u & 0x80000000) >> 31;
    unsigned short exponent = (tmp.u & 0x7F800000) >> 23;
    unsigned int significand = tmp.u & 0x7FFFFF;

    // NCNN_LOGE("%d %d %d", sign, exponent, significand);

    // 1 : 5 : 10
    unsigned short fp16;
    if (exponent == 0)
    {
        // zero or denormal, always underflow
        fp16 = (sign << 15) | (0x00 << 10) | 0x00;
    }
    else if (exponent == 0xFF)
    {
        // infinity or NaN
        fp16 = (sign << 15) | (0x1F << 10) | (significand ? 0x200 : 0x00);
    }
    else
    {
        // normalized
        short newexp = exponent + (-127 + 15);
        if (newexp >= 31)
        {
            // overflow, return infinity
            fp16 = (sign << 15) | (0x1F << 10) | 0x00;
        }
        else if (newexp <= 0)
        {
            // Some normal fp32 cannot be expressed as normal fp16
            fp16 = (sign << 15) | (0x00 << 10) | 0x00;
        }
        else
        {
            // normal fp16
            fp16 = (sign << 15) | (newexp << 10) | (significand >> 13);
        }
    }

    return fp16;
}

int Graph::ncnn(const std::string& parampath, const std::string& binpath, const std::string& pypath)
{
    FILE* paramfp = fopen(parampath.c_str(), "wb");
    if (!paramfp)
    {
        fprintf(stderr, "fopen %s failed\n", parampath.c_str());
        return -1;
    }

    FILE* binfp = fopen(binpath.c_str(), "wb");
    if (!binfp)
    {
        fprintf(stderr, "fopen %s failed\n", binpath.c_str());
        fclose(paramfp);
        return -1;
    }

    // magic
    fprintf(paramfp, "7767517\n");

    // operator count and operand count
    fprintf(paramfp, "%d %d\n", (int)ops.size(), (int)operands.size());

    for (const Operator* op : ops)
    {
        fprintf(paramfp, "%-24s %-24s %d %d", op->type.c_str(), op->name.c_str(), (int)op->inputs.size(), (int)op->outputs.size());

        for (const Operand* oprand : op->inputs)
        {
            fprintf(paramfp, " %s", oprand->name.c_str());
        }

        for (const Operand* oprand : op->outputs)
        {
            fprintf(paramfp, " %s", oprand->name.c_str());
        }

        for (const auto& it : op->params)
        {
            const Parameter& param = it.second;

            if (!string_is_positive_integer(it.first))
            {
                fprintf(stderr, "ignore %s %s param %s=", op->type.c_str(), op->name.c_str(), it.first.c_str());

                if (param.type == 0)
                {
                    fprintf(stderr, "None");
                }
                if (param.type == 1)
                {
                    if (param.b)
                        fprintf(stderr, "True");
                    else
                        fprintf(stderr, "False");
                }
                if (param.type == 2)
                {
                    fprintf(stderr, "%d", param.i);
                }
                if (param.type == 3)
                {
                    fprintf(stderr, "%e", param.f);
                }
                if (param.type == 4)
                {
                    fprintf(stderr, "%s", param.s.c_str());
                }
                if (param.type == 5)
                {
                    fprintf(stderr, "(");
                    for (size_t i = 0; i < param.ai.size(); i++)
                    {
                        fprintf(stderr, "%d", param.ai[i]);
                        if (i + 1 != param.ai.size())
                            fprintf(stderr, ",");
                    }
                    fprintf(stderr, ")");
                }
                if (param.type == 6)
                {
                    fprintf(stderr, "(");
                    for (size_t i = 0; i < param.af.size(); i++)
                    {
                        fprintf(stderr, "%e", param.af[i]);
                        if (i + 1 != param.af.size())
                            fprintf(stderr, ",");
                    }
                    fprintf(stderr, ")");
                }
                if (param.type == 7)
                {
                    fprintf(stderr, "(");
                    for (size_t i = 0; i < param.as.size(); i++)
                    {
                        fprintf(stderr, "%s", param.as[i].c_str());
                        if (i + 1 != param.as.size())
                            fprintf(stderr, ",");
                    }
                    fprintf(stderr, ")");
                }
                fprintf(stderr, "\n");

                continue;
            }

            const int idkey = std::stoi(it.first);
            if (param.type == 2)
            {
                fprintf(paramfp, " %d=%d", idkey, param.i);
            }
            if (param.type == 3)
            {
                fprintf(paramfp, " %d=%e", idkey, param.f);
            }
            if (param.type == 5)
            {
                const int array_size = (int)param.ai.size();
                fprintf(paramfp, " %d=%d", -23300 - idkey, array_size);
                for (size_t i = 0; i < param.ai.size(); i++)
                {
                    fprintf(paramfp, ",%d", param.ai[i]);
                }
            }
            if (param.type == 6)
            {
                const int array_size = (int)param.af.size();
                fprintf(paramfp, " %d=%d", -23300 - idkey, array_size);
                for (size_t i = 0; i < param.af.size(); i++)
                {
                    fprintf(paramfp, ",%e", param.af[i]);
                }
            }
        }

        bool is_type_flag_fp32 = false;
        for (const auto& it : op->attrs)
        {
            // fprintf(paramfp, " @%s=", it.first.c_str());

            const Attribute& attr = it.second;

            if (is_type_flag_fp32)
            {
                // fp32 -> fp16
                const float* p = (const float*)attr.data.data();
                int len = attr.data.size() / 4;
                for (int i = 0; i < len; i++)
                {
                    unsigned short v_fp16 = float32_to_float16(p[i]);
                    fwrite(&v_fp16, sizeof(v_fp16), 1, binfp);
                }

                is_type_flag_fp32 = false;
                continue;
            }
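            // A null-typed attribute whose payload is exactly four zero bytes
            // acts as an in-band marker: write the ncnn fp16 storage flag and
            // convert the next attribute from fp32 to fp16 on the fly.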
            if (attr.type == 0 && attr.data == std::vector<char> {0, 0, 0, 0})
            {
                // write fp16 flag
                unsigned int fp16_flag = 0x01306B47;
                fwrite(&fp16_flag, sizeof(fp16_flag), 1, binfp);

                is_type_flag_fp32 = true;
                continue;
            }

            fwrite(attr.data.data(), attr.data.size(), 1, binfp);
        }

        // if (op->inputnames.size() == op->inputs.size())
        // {
        //     for (size_t i = 0; i < op->inputs.size(); i++)
        //     {
        //         const Operand* oprand = op->inputs[i];
        //         fprintf(paramfp, " $%s=%s", op->inputnames[i].c_str(), oprand->name.c_str());
        //     }
        // }

        // for (const Operand* oprand : op->outputs)
        // {
        //     if (oprand->params.find("__batch_index") == oprand->params.end())
        //         continue;
        //
        //     const int batch_index = oprand->params.at("__batch_index").i;
        //
        //     fprintf(paramfp, " #%s=%d", oprand->name.c_str(), batch_index);
        // }

        // for (const Operand* oprand : op->outputs)
        // {
        //     if (oprand->shape.empty())
        //         continue;
        //
        //     fprintf(paramfp, " #%s=", oprand->name.c_str());
        //
        //     fprintf(paramfp, "(");
        //     for (int64_t i = 0; i < oprand->shape.size() - 1; i++)
        //     {
        //         fprintf(paramfp, "%d,", oprand->shape[i]);
        //     }
        //     if (oprand->shape.size() > 0)
        //         fprintf(paramfp, "%d", oprand->shape[oprand->shape.size() - 1]);
        //     fprintf(paramfp, ")");
        //
        //     fprintf(paramfp, type_to_string(oprand->type));
        // }

        fprintf(paramfp, "\n");
    }

    fclose(paramfp);
    fclose(binfp);

    FILE* pyfp = fopen(pypath.c_str(), "wb");
    if (!pyfp)
    {
        fprintf(stderr, "fopen %s failed\n", pypath.c_str());
        return -1;
    }

    fprintf(pyfp, "import numpy as np\n");
    fprintf(pyfp, "import ncnn\n");
    fprintf(pyfp, "import torch\n");
    fprintf(pyfp, "\n");

    // test inference
    {
        fprintf(pyfp, "def test_inference():\n");
        fprintf(pyfp, "    torch.manual_seed(0)\n");

        for (int input_index = 0;; input_index++)
        {
            std::string input_name = std::string("in") + std::to_string(input_index);
            const Operand* r = get_operand(input_name);
            if (!r)
                break;

            if (type_is_integer(r->type))
            {
                fprintf(pyfp, "    %s = torch.randint(10, (", input_name.c_str());
                for (size_t i = 0; i < r->shape.size(); i++)
                {
                    fprintf(pyfp, "%d", r->shape[i]);
                    if (i + 1 != r->shape.size() || r->shape.size() == 1)
                        fprintf(pyfp, ", ");
                }
                fprintf(pyfp, "), dtype=%s)\n", type_to_dtype_string(r->type));
            }
            else
            {
                fprintf(pyfp, "    %s = torch.rand(", input_name.c_str());
                for (size_t i = 0; i < r->shape.size(); i++)
                {
                    fprintf(pyfp, "%d, ", r->shape[i]);
                }
                fprintf(pyfp, "dtype=%s)\n", type_to_dtype_string(r->type));
            }
        }

        fprintf(pyfp, "    out = []\n");
        fprintf(pyfp, "\n");
        fprintf(pyfp, "    with ncnn.Net() as net:\n");
        fprintf(pyfp, "        net.load_param(\"%s\")\n", parampath.c_str());
        fprintf(pyfp, "        net.load_model(\"%s\")\n", binpath.c_str());
        fprintf(pyfp, "\n");
        fprintf(pyfp, "        with net.create_extractor() as ex:\n");

        for (int input_index = 0;; input_index++)
        {
            std::string input_name = std::string("in") + std::to_string(input_index);
            const Operand* r = get_operand(input_name);
            if (!r)
                break;

            const int batch_index = r->params.at("__batch_index").i;
            if (batch_index != 233)
            {
                fprintf(pyfp, "            ex.input(\"%s\", ncnn.Mat(%s.squeeze(%d).numpy()).clone())\n", input_name.c_str(), input_name.c_str(), batch_index);
            }
            else
            {
                fprintf(pyfp, "            ex.input(\"%s\", ncnn.Mat(%s.numpy()).clone())\n", input_name.c_str(), input_name.c_str());
            }
        }

        fprintf(pyfp, "\n");
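        // __batch_index records which axis was the batch dimension; the magic
        // value 233 means "no batch axis". Inputs were squeezed at that axis
        // before feeding ncnn above, and outputs get it unsqueezed back below.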
        for (int output_index = 0;; output_index++)
        {
            std::string output_name = std::string("out") + std::to_string(output_index);
            const Operand* r = get_operand(output_name);
            if (!r)
                break;

            fprintf(pyfp, "            _, %s = ex.extract(\"%s\")\n", output_name.c_str(), output_name.c_str());

            const int batch_index = r->params.at("__batch_index").i;
            if (batch_index != 233)
            {
                fprintf(pyfp, "            out.append(torch.from_numpy(np.array(%s)).unsqueeze(%d))\n", output_name.c_str(), batch_index);
            }
            else
            {
                fprintf(pyfp, "            out.append(torch.from_numpy(np.array(%s)))\n", output_name.c_str());
            }
        }

        fprintf(pyfp, "\n");
        fprintf(pyfp, "    if len(out) == 1:\n");
        fprintf(pyfp, "        return out[0]\n");
        fprintf(pyfp, "    else:\n");
        fprintf(pyfp, "        return tuple(out)\n");
    }

    fclose(pyfp);

    return 0;
}

int Graph::parse(const std::string& param)
{
    std::istringstream is(param);
    if (!is.good())
    {
        fprintf(stderr, "open failed\n");
        return -1;
    }

    int magic = 0;
    {
        std::string line;
        std::getline(is, line);
        std::istringstream iss(line);
        iss >> magic;
    }

    int operator_count = 0;
    int operand_count = 0;
    {
        std::string line;
        std::getline(is, line);
        std::istringstream iss(line);
        iss >> operator_count >> operand_count;
    }

    for (int i = 0; i < operator_count; i++)
    {
        std::string line;
        std::getline(is, line);
        std::istringstream iss(line);

        std::string type;
        std::string name;
        int input_count = 0;
        int output_count = 0;

        iss >> type >> name >> input_count >> output_count;

        Operator* op = new_operator(type, name);

        for (int j = 0; j < input_count; j++)
        {
            std::string operand_name;
            iss >> operand_name;

            Operand* r = get_operand(operand_name);
            r->consumers.push_back(op);
            op->inputs.push_back(r);
        }

        for (int j = 0; j < output_count; j++)
        {
            std::string operand_name;
            iss >> operand_name;

            Operand* r = new_operand(operand_name);
            r->producer = op;
            op->outputs.push_back(r);
        }

        // key=value
        while (!iss.eof())
        {
            std::string param;
            iss >> param;

            std::string key;
            std::string value;
            std::istringstream pss(param);
            std::getline(pss, key, '=');
            std::getline(pss, value);

            if (key[0] == '@')
            {
                // attribute
                // load_attribute(op, key.substr(1), value, szr);
            }
            else if (key[0] == '$')
            {
                // operand input key
                // load_input_key(op, key.substr(1), value);
            }
            else if (key[0] == '#')
            {
                // operand shape
                load_shape(op, key.substr(1), value);
            }
            else
            {
                // parameter
                load_parameter(op, key, value);
            }
        }
    }

    return 0;
}

void Operand::remove_consumer(const Operator* c)
{
    auto it = std::find(consumers.begin(), consumers.end(), c);
    consumers.erase(it);
}

Operator* Graph::new_operator(const std::string& type, const std::string& name)
{
    Operator* op = new Operator;
    op->type = type;
    op->name = name;
    ops.push_back(op);
    return op;
}

Operator* Graph::new_operator_before(const std::string& type, const std::string& name, const Operator* cur)
{
    Operator* op = new Operator;
    op->type = type;
    op->name = name;
    ops.insert(std::find(ops.begin(), ops.end(), cur), op);
    return op;
}

Operator* Graph::new_operator_after(const std::string& type, const std::string& name, const Operator* cur)
{
    Operator* op = new Operator;
    op->type = type;
    op->name = name;
    ops.insert(std::find(ops.begin(), ops.end(), cur) + 1, op);
    return op;
}

Operand* Graph::new_operand(const torch::jit::Value* v)
{
    Operand* r = new Operand;
    r->name = v->debugName();

    auto pt = v->type()->cast<c10::TensorType>();
    if (pt)
    {
        if (pt->scalarType().has_value() && pt->dim().has_value())
        {
            r->type = get_at_tensor_type(pt->scalarType().value());

            const int ndim = (int)pt->dim().value();
            r->shape.resize(ndim);
            for (int i = 0; i < ndim; i++)
            {
                if (pt->sizes()[i].has_value())
                    r->shape[i] = (int)pt->sizes()[i].value();
                else
                    r->shape[i] = -1;
            }
        }
    }

    operands.push_back(r);
    return r;
}

Operand* Graph::new_operand(const std::string& name)
{
    Operand* r = new Operand;
    r->name = name;
    operands.push_back(r);
    return r;
}

Operand* Graph::get_operand(const std::string& name)
{
    for (Operand* r : operands)
    {
        if (r->name == name)
            return r;
    }

    return 0;
}

} // namespace pnnx