Files
TencentOS-tiny/components/ai/onnx/operator_int/model.c
2021-09-06 21:50:44 +08:00

85 lines
3.0 KiB
C

#include <inttypes.h>
#include "onnx.h"
/*
 * Complete a node's log line with its "[in] --> [out]" 3-D shape summary.
 */
static void onnx_model_print_shapes(const int64_t* shapeIn, const int64_t* shapeOut)
{
    printf("[%2" PRId64 ", %2" PRId64 ", %2" PRId64 "] --> [%2" PRId64 ", %2" PRId64 ", %2" PRId64 "]\n",
           shapeIn[0], shapeIn[1], shapeIn[2], shapeOut[0], shapeOut[1], shapeOut[2]);
}

/*
 * Run a model by walking its graph as a chain.
 *
 * The walk starts at the node that consumes the graph's first input. It then
 * repeatedly moves to the node that consumes the current node's first output,
 * until no such node exists.
 *
 * @param model       Parsed ONNX model; only model->graph is used.
 * @param input       Heap-allocated input tensor. Ownership passes to this
 *                    function: each consumed buffer (including this one) is
 *                    freed once the next layer has produced its result.
 * @param shapeInput  3-element shape of `input`. It is updated in place, so on
 *                    return it holds the shape of the returned tensor.
 * @return            The tensor produced by the last layer that ran (the
 *                    caller frees it), `input` itself if no layer ran, or
 *                    NULL if a layer returned NULL (the run is aborted).
 */
int* onnx_model_run(Onnx__ModelProto* model, int* input, int64_t* shapeInput)
{
    /* Scratch shape written by each layer; -1 marks "not yet computed". */
    int64_t shapeOutput[3] = { -1, -1, -1 };
    Onnx__NodeProto* node = onnx_graph_get_node_by_input(model->graph, model->graph->input[0]->name);
    int i = 0;

    while (node != NULL)
    {
        const char* op = node->op_type;
        int* output = NULL;
        /* Set for nodes that do not produce a new buffer. */
        int passthrough = 0;

        printf("[%2d] %-10s %-20s ", i++, op, node->name);

        if (strcmp(op, "Conv") == 0)
        {
            output = conv2D_layer(model->graph, input, shapeInput, shapeOutput, node->name);
        }
        else if (strcmp(op, "Relu") == 0)
        {
            output = relu_layer(model->graph, input, shapeInput, shapeOutput, node->name);
        }
        else if (strcmp(op, "MaxPool") == 0)
        {
            output = maxpool_layer(model->graph, input, shapeInput, shapeOutput, node->name);
        }
        else if (strcmp(op, "Softmax") == 0)
        {
            output = softmax_layer(model->graph, input, shapeInput, shapeOutput, node->name);
        }
        else if (strcmp(op, "MatMul") == 0)
        {
            output = matmul_layer(model->graph, input, shapeInput, shapeOutput, node->name);
        }
        else if (strcmp(op, "Add") == 0)
        {
            output = add_layer(model->graph, input, shapeInput, shapeOutput, node->name);
        }
        else if (strcmp(op, "Identity") == 0 || strcmp(op, "Transpose") == 0)
        {
            /* Data and shape are left untouched.
             * NOTE(review): Transpose is not actually applied here; confirm
             * the exported model's layout makes skipping it safe. */
            printf("\n");
            passthrough = 1;
        }
        else if (strcmp(op, "Reshape") == 0)
        {
            /* Flatten [a, b, c] to [1, a*b*c, 1]. Only the shape changes; the
             * buffer is reused as-is. shapeInput is the current shape even
             * when Reshape is the first node. */
            shapeOutput[0] = 1;
            shapeOutput[1] = shapeInput[0] * shapeInput[1] * shapeInput[2];
            shapeOutput[2] = 1;
            onnx_model_print_shapes(shapeInput, shapeOutput);
            memcpy(shapeInput, shapeOutput, sizeof shapeOutput);
            passthrough = 1;
        }
        else
        {
            /* Skip the node, keeping the current tensor. Previously this path
             * went on to free `input` and adopt an uninitialised `output`. */
            printf("Unsupported operand: %s\n", op);
            passthrough = 1;
        }

        if (!passthrough)
        {
            if (output == NULL)
            {
                /* Nothing sensible can be fed to the next layer. */
                printf("layer returned NULL, aborting\n");
                free(input);
                return NULL;
            }
            onnx_model_print_shapes(shapeInput, shapeOutput);
            /* Guard against a layer that works in place and returns its input. */
            if (output != input)
            {
                free(input);
            }
            input = output;
            memcpy(shapeInput, shapeOutput, sizeof shapeOutput);
        }

        node = onnx_graph_get_node_by_input(model->graph, node->output[0]);
    }

    return input;
}