教程3:tensorflow lite c++ 接口解读
教程2:tensorflow lite 编译和安装二 使用bazel编译
教程3:tensorflow lite c++ 接口解读
教程4:TensorFlow2.0 label_image 的编译和使用
大概整理了TensorFlow lite的初始化流程。主要从BuildFromBuffer 开始到构建完毕的初始化流程。
FlatBufferModel需要包含的头文件#include <model.h>
解压模型的时候可以从文件夹中读取一个文件,也可以从数组buffer内存中直接读取16进制数据。本说明中是读取的数组数据。
任何依赖的Interpreter实例正在使用FlatBufferModel实例,所构建的实例自始至终要吃激活状态不可被毁掉。
官网中给出的代码是将 解析器,错误报告,解释器,放在一起写的。
using namespace tflite; StderrReporter error_reporter; auto model = FlatBufferModel::BuildFromFile("interesting_model.tflite", &error_reporter); MyOpResolver resolver; // You need to subclass OpResolver to provide // implementations. InterpreterBuilder builder(*model, resolver); std::unique_ptr interpreter; if(builder(&interpreter) == kTfLiteOk) { .. run model inference with interpreter }
还有就是从buffer读取
std::unique_ptr< FlatBufferModel > BuildFromBuffer( const char *caller_owned_buffer, size_t buffer_size, ErrorReporter *error_reporter )
这是基于预加载的flatbuffer缓冲区构建模型。
当调用时要保留缓冲区的所有权,并应保持其激活状态,直到销毁返回的对象为止。 调用时还保留error_reporter的所有权,并且必须确保其生存期比FlatBufferModel实例更长。 如果失败,则返回nullptr。 注意:这不会验证缓冲区,因此不应在无效/不受信任的输入上调用。
ErrorReporter在#include <error_reporter.h> 头文件中
官方解释类似于printf.
OpResolver在头文件#include <op_resolver.h>中
返回给定操作码或自定义操作名的TfLiteRegistration的抽象接口。这是在Flatbuffer模型中引用的操作被映射到可执行函数指针(TfLiteRegistrations)的机制。
InterpreterBuilder在#include <model.h>头文件中。
构建能够解释模型的解释器。一个模型的生命周期必须与构建器创建的任何解释器一样长。 原则上,可以从一个模型中创建多个解释器。 op_resolver:实现OpResolver接口的实例,该接口将自定义操作名称和内置操作代码映射到操作注册。 提供的op_resolver对象的生命周期必须至少与InterpreterBuilder一样长。 与model和error_reporter不同,op_resolver在创建的Interpreter对象的持续时间内不需要存在。 error_reporter:一个函数子,该函子被调用以报告处理printf var arg语义的错误。 error_reporter对象的生存期必须大于或等于operator()创建的解释器。
成功时返回kTfLiteOk并将解释器设置为有效的解释器。 注意:用户必须确保模型生命周期(和错误报告程序,如果提供)至少与解释程序的生命一样长。
Interpreter在#include <interpreter.h>的头文件中,这个解释器个人觉得翻译成是从张量输入输出的graph节点的解释器。该解释器对应张量流和graph节点之间的关系。具体的说每一个graph节点处理输入张量并生成对应的输出张量。所有的输入、输出张量可以由索引提供。
void init_interpreter(Settings *s)
{
if (!s->model_name.c_str())
{
LOG(ERROR) << "no model file name\n";
exit(-1);
}
std::unique_ptr<tflite::FlatBufferModel> model;
model = tflite::FlatBufferModel::BuildFromBuffer(mfn_tflite, mfn_tflite_len);
//model = tflite::FlatBufferModel::BuildFromFile("mfn.tflite");
if (!model)
{
LOG(FATAL) << "\nFailed to load model " << s->model_name << "\n";
exit(-1);
}
if (s->verbose)
LOG(INFO) << "Loaded model " << s->model_name << "\n\r";
model->error_reporter();
if (s->verbose)
LOG(INFO) << "resolved reporter\n\r";
tflite::ops::builtin::BuiltinOpResolver resolver;
tflite::InterpreterBuilder(*model, resolver)(&interpreter);
if (!interpreter)
{
LOG(FATAL) << "Failed to construct interpreter\n";
exit(-1);
}
interpreter->UseNNAPI(s->accel);
if (s->verbose)
{
LOG(INFO) << "tensors size: " << interpreter->tensors_size() << "\n\r";
LOG(INFO) << "nodes size: " << interpreter->nodes_size() << "\n\r";
LOG(INFO) << "inputs: " << interpreter->inputs().size() << "\n\r";
LOG(INFO) << "input(0) name: " << interpreter->GetInputName(0) << "\n\r";
}
if (s->number_of_threads != -1)
{
interpreter->SetNumThreads(s->number_of_threads);
}
input = interpreter->inputs()[0];
if (s->verbose)
LOG(INFO) << "input: " << input << "\n\r";
const std::vector<int> inputs = interpreter->inputs();
const std::vector<int> outputs = interpreter->outputs();
if (s->verbose)
{
LOG(INFO) << "number of inputs: " << inputs.size() << "\n\r";
LOG(INFO) << "number of outputs: " << outputs.size() << "\n\r";
}
if (interpreter->AllocateTensors() != kTfLiteOk)
{
LOG(FATAL) << "Failed to allocate tensors!";
}
if (s->verbose)
PrintInterpreterState(interpreter.get());
// get input dimension from the input tensor metadata
// assuming one input only
TfLiteIntArray *dims = interpreter->tensor(input)->dims;
wanted_height = dims->data[1];
wanted_width = dims->data[2];
wanted_channels = dims->data[3];
}
最新评论