Files
TencentOS-tiny/components/ai/nnom/src/layers/nnom_gru_cell.c
2021-09-08 23:47:15 +08:00

339 lines
11 KiB
C

/*
* Copyright (c) 2018-2020
* Jianjia Ma
* majianjia@live.com
*
* SPDX-License-Identifier: Apache-2.0
*
* Change Logs:
* Date Author Notes
* 2020-08-24 Jianjia Ma The first version
*/
#include <stdint.h>
#include <string.h>
#include <stdbool.h>
#include "nnom.h"
#include "nnom_local.h"
#include "nnom_layers.h"
#include "layers/nnom_gru_cell.h"
#ifdef NNOM_USING_CMSIS_NN
#include "arm_math.h"
#include "arm_nnfunctions.h"
#endif
/**
 * Create a GRU cell from a static configuration.
 *
 * The configuration pointer and the weight / recurrent weight / bias tensors
 * it refers to are stored in the cell by reference; nothing is copied.
 *
 * @param config  GRU cell configuration (units, tensors, intermediate Q formats).
 * @return the new cell viewed as a generic RNN cell, or NULL when nnom_mem()
 *         could not provide the memory.
 */
nnom_rnn_cell_t *gru_cell_s(const nnom_gru_cell_config_t *config)
{
    nnom_gru_cell_t *gru = nnom_mem(sizeof *gru);

    if (!gru)
    {
        return NULL;
    }

    /* model parameters */
    gru->weights = config->weights;
    gru->recurrent_weights = config->recurrent_weights;
    gru->bias = config->bias;

    /* Q formats used for the intermediate calculations */
    gru->q_dec_z = config->q_dec_z;
    gru->q_dec_h = config->q_dec_h;

    /* generic RNN cell part: identity, size and methods */
    gru->super.type = NNOM_GRU_CELL;
    gru->super.units = config->units;
    gru->super.config = (void *)config;
    gru->super.build = gru_cell_build;
    gru->super.run = gru_cell_run;
    gru->super.free = gru_cell_free;

    return (nnom_rnn_cell_t *)gru;
}
/**
 * Free hook of the GRU cell.
 *
 * This implementation releases nothing and always reports success.
 *
 * @param cell  the cell being freed (unused).
 * @return NN_SUCCESS.
 */
nnom_status_t gru_cell_free(nnom_rnn_cell_t *cell)
{
    (void)cell; /* intentionally unused */
    return NN_SUCCESS;
}
/**
 * Build step of the GRU cell.
 *
 * Derives the fixed-point shifts from the Q formats of the input tensor,
 * the kernels and the bias, then fills in the state size, the computing
 * buffer size and the per-timestamp MAC count of the cell.
 *
 * @param cell  GRU cell; cell->layer and cell->feature_size are read here.
 * @return NN_SUCCESS.
 */
nnom_status_t gru_cell_build(nnom_rnn_cell_t *cell)
{
    nnom_gru_cell_t *gru = (nnom_gru_cell_t *)cell;
    nnom_layer_t *owner = cell->layer;

    /*
     * Output shift of a dot product = input_dec + weight_dec - output_dec.
     * iw (input x kernel) and hw (hidden x recurrent kernel) are added
     * together afterwards, so both use the same output format: q_dec_z.
     */
    gru->oshift_iw = owner->in->tensor->q_dec[0] + gru->weights->q_dec[0] - gru->q_dec_z;
    gru->oshift_hw = gru->q_dec_h + gru->recurrent_weights->q_dec[0] - gru->q_dec_z;

    /* bias shift = input_dec + weight_dec - bias_dec */
    gru->bias_shift = owner->in->tensor->q_dec[0] + gru->weights->q_dec[0] - gru->bias->q_dec[0];

    /* state of one timestamp: 2 bytes per unit (state is handled as q15 in run) */
    cell->state_size = cell->units * 2;

    /*
     * computing buffer, 2 bytes per element:
     * three buffers of units*3 elements each, plus the input widened from q7 to q15
     */
    cell->comp_buf_size = cell->units * (3*3) * 2 + cell->feature_size * 2;

    /* MAC count of one timestamp, for information */
    cell->macc = cell->feature_size * cell->units * 3   /* input: feature * state * 3 gates */
               + cell->units * cell->units * 8          /* recurrent: state * output_unit * (5 gate + 3 mult) */
               + cell->units * (3 + 3 + 5);             /* 3 gates, 3 mult, 5 addition */

    return NN_SUCCESS;
}
// keras implementation as below.
/*
def step(cell_inputs, cell_states):
"""Step function that will be used by Keras RNN backend."""
h_tm1 = cell_states[0]
# inputs projected by all gate matrices at once
matrix_x = K.dot(cell_inputs, kernel)
matrix_x = K.bias_add(matrix_x, input_bias)
x_z, x_r, x_h = array_ops.split(matrix_x, 3, axis=1)
# hidden state projected by all gate matrices at once
matrix_inner = K.dot(h_tm1, recurrent_kernel)
matrix_inner = K.bias_add(matrix_inner, recurrent_bias)
recurrent_z, recurrent_r, recurrent_h = array_ops.split(matrix_inner, 3,
axis=1)
z = nn.sigmoid(x_z + recurrent_z)
r = nn.sigmoid(x_r + recurrent_r)
hh = nn.tanh(x_h + r * recurrent_h)
# previous and candidate state mixed by update gate
h = z * h_tm1 + (1 - z) * hh
return h, [h]
*/
//
/**
 * Run one timestamp of the GRU cell (see the Keras step function above).
 *
 * Reads cell->in_data (q7) and cell->in_state (q15), writes the new state to
 * cell->out_state (q15) and its q7 conversion to cell->out_data. All
 * intermediate values live in the layer's computing buffer.
 *
 * @param cell  GRU cell that went through gru_cell_build().
 * @return NN_SUCCESS.
 */
nnom_status_t gru_cell_run(nnom_rnn_cell_t *cell)
{
    nnom_gru_cell_t *gru = (nnom_gru_cell_t *)cell;
    nnom_layer_t *owner = cell->layer;

    /* integer bits of the pre-activation values, handed to sigmoid / tanh */
    int act_int_bit = 7 - gru->q_dec_z;

    /* one bias tensor: units*3 values for the input dot, then units*3 for the recurrent dot */
    q7_t *in_bias = (q7_t *)gru->bias->p_data;
    q7_t *rec_bias = (q7_t *)gru->bias->p_data + cell->units*3;

    /* previous and next hidden state */
    q15_t *h_prev = (q15_t *)cell->in_state;
    q15_t *h_next = (q15_t *)cell->out_state;

    /*
     * computing buffer layout (q15 elements):
     * low |-- matrix_x --|-- matrix_inner --|-- work --|-- in_q15 --|
     *        units*3          units*3         units*3   feature_size
     */
    q15_t *scratch = (q15_t *)owner->comp->mem->blk;
    q15_t *matrix_x = scratch;
    q15_t *matrix_inner = scratch + cell->units*3;
    q15_t *work = scratch + cell->units*6;
    q15_t *in_q15 = scratch + cell->units*9;

    /* widen the q7 input to q15 */
    local_q7_to_q15(cell->in_data, in_q15, cell->feature_size);

    /*
     * matrix_x = K.dot(cell_inputs, kernel) + input_bias
     * The "+ 8" on the bias shift presumably accounts for the vector being
     * q15 instead of q7 -- TODO confirm against local_q7_to_q15().
     */
#ifdef NNOM_USING_CMSIS_NN
    arm_fully_connected_mat_q7_vec_q15_opt(in_q15, gru->weights->p_data, cell->feature_size,
        cell->units*3, gru->bias_shift + 8, gru->oshift_iw, in_bias, matrix_x, NULL);
#else
    local_fully_connected_mat_q7_vec_q15_opt(in_q15, gru->weights->p_data, cell->feature_size,
        cell->units*3, gru->bias_shift + 8, gru->oshift_iw, in_bias, matrix_x, NULL);
#endif

    /*
     * matrix_inner = K.dot(h_tm1, recurrent_kernel) + recurrent_bias
     * NOTE(review): bias_shift is derived from the input-side Q formats in
     * gru_cell_build() and is reused here for the recurrent dot -- confirm
     * this is intended.
     */
#ifdef NNOM_USING_CMSIS_NN
    arm_fully_connected_mat_q7_vec_q15_opt(h_prev, gru->recurrent_weights->p_data, cell->units,
        cell->units*3, gru->bias_shift + 8, gru->oshift_hw, rec_bias, matrix_inner, NULL);
#else
    local_fully_connected_mat_q7_vec_q15_opt(h_prev, gru->recurrent_weights->p_data, cell->units,
        cell->units*3, gru->bias_shift + 8, gru->oshift_hw, rec_bias, matrix_inner, NULL);
#endif

    /* split both projections into the z / r / h gates, units elements each */
    q15_t *x_z = matrix_x;
    q15_t *x_r = matrix_x + cell->units;
    q15_t *x_h = matrix_x + cell->units*2;
    q15_t *rec_z = matrix_inner;
    q15_t *rec_r = matrix_inner + cell->units;
    q15_t *rec_h = matrix_inner + cell->units*2;

    /* three work slots; each one is reused below, so the order of the calls matters */
    q15_t *z = work;                        /* update gate, later h3 */
    q15_t *r_hh = work + cell->units;       /* reset gate, later the candidate state hh */
    q15_t *prod = work + cell->units*2;     /* r * recurrent_h, later z * h_tm1 */

    /* z = sigmoid(x_z + recurrent_z) */
    local_add_q15(x_z, rec_z, z, 0, cell->units);
    local_sigmoid_q15(z, cell->units, act_int_bit);

    /* r = sigmoid(x_r + recurrent_r) */
    local_add_q15(x_r, rec_r, r_hh, 0, cell->units);
    local_sigmoid_q15(r_hh, cell->units, act_int_bit);

    /* hh = tanh(x_h + r * recurrent_h); r is overwritten by hh */
    local_mult_q15(r_hh, rec_h, prod, 15, cell->units);
    local_add_q15(x_h, prod, r_hh, 0, cell->units);
    local_tanh_q15(r_hh, cell->units, act_int_bit);

    /* h = z * h_tm1 + (1 - z) * hh */
    local_mult_q15(z, h_prev, prod, 15, cell->units);       /* h1 = z * h_tm1 */
    local_1_minor_z_q15(z, h_next, 15, cell->units);        /* h2 = 1 - z, parked in the output state */
    local_mult_q15(h_next, r_hh, z, 15, cell->units);       /* h3 = h2 * hh, overwrites z */
    local_add_q15(prod, z, h_next, 0, cell->units);         /* h = h1 + h3 */

    /* the q15 state is also the output: convert it to q7 with a shift of 8 */
    local_q15_to_q7(h_next, cell->out_data, 8, cell->units);

    return NN_SUCCESS;
}
// Reserved for debugging: prints the intermediate variables/data.
#if 0
// delete after testing completed
/*
 * Debug helper: print a q15 buffer as floating point numbers, 8 per row.
 *
 * data     values to print
 * name     label printed before the values
 * dec_bit  number of fractional bits; each value is divided by (1 << dec_bit)
 * size     number of values in data
 */
static void print_variable_q15(q15_t *data,char*name, int dec_bit, int size)
{
    const int one = 1 << dec_bit;
    int idx = 0;

    printf("\n\n");
    printf("%s", name);
    while (idx < size)
    {
        /* start a new row before every 8th value, including the first */
        if ((idx % 8) == 0)
        {
            printf("\n");
        }
        printf("%f\t", (float) data[idx] / one);
        idx++;
    }
    printf("\n");
}
//
/*
 * Debug variant of gru_cell_run(): performs the same sequence of operations
 * for one timestamp and prints every intermediate buffer with
 * print_variable_q15().
 */
nnom_status_t gru_cell_run(nnom_rnn_cell_t *cell)
{
    nnom_gru_cell_t *gru = (nnom_gru_cell_t *)cell;
    nnom_layer_t *owner = cell->layer;

    /* integer bits of the pre-activation values, handed to sigmoid / tanh */
    int act_int_bit = 7 - gru->q_dec_z;

    /* test: force a constant input */
    //nnom_memset(cell->in_data, 5 * (1<<layer->in->tensor->q_dec[0]), cell->feature_size);

    /* one bias tensor: units*3 values for the input dot, then units*3 for the recurrent dot */
    q7_t *in_bias = (q7_t *)gru->bias->p_data;
    q7_t *rec_bias = (q7_t *)gru->bias->p_data + cell->units*3;

    /* previous and next hidden state */
    q15_t *h_prev = (q15_t *)cell->in_state;
    q15_t *h_next = (q15_t *)cell->out_state;

    /*
     * computing buffer layout (q15 elements):
     * low |-- matrix_x --|-- matrix_inner --|-- work --|-- in_q15 --|
     */
    q15_t *scratch = (q15_t *)owner->comp->mem->blk;
    q15_t *matrix_x = scratch;
    q15_t *matrix_inner = scratch + cell->units*3;
    q15_t *work = scratch + cell->units*6;
    q15_t *in_q15 = scratch + cell->units*9;

    /* widen the q7 input to q15 */
    local_q7_to_q15(cell->in_data, in_q15, cell->feature_size);

    /* matrix_x = K.dot(cell_inputs, kernel) + input_bias */
#ifdef NNOM_USING_CMSIS_NN
    arm_fully_connected_mat_q7_vec_q15_opt(in_q15, gru->weights->p_data, cell->feature_size,
        cell->units*3, gru->bias_shift + 8, gru->oshift_iw, in_bias, matrix_x, NULL);
#else
    local_fully_connected_mat_q7_vec_q15_opt(in_q15, gru->weights->p_data, cell->feature_size,
        cell->units*3, gru->bias_shift + 8, gru->oshift_iw, in_bias, matrix_x, NULL);
#endif

    /* matrix_inner = K.dot(h_tm1, recurrent_kernel) + recurrent_bias */
#ifdef NNOM_USING_CMSIS_NN
    arm_fully_connected_mat_q7_vec_q15_opt(h_prev, gru->recurrent_weights->p_data, cell->units,
        cell->units*3, gru->bias_shift + 8, gru->oshift_hw, rec_bias, matrix_inner, NULL);
#else
    local_fully_connected_mat_q7_vec_q15_opt(h_prev, gru->recurrent_weights->p_data, cell->units,
        cell->units*3, gru->bias_shift + 8, gru->oshift_hw, rec_bias, matrix_inner, NULL);
#endif

    print_variable_q15(in_q15, "input", owner->in->tensor->q_dec[0]+8, cell->feature_size);
    print_variable_q15(matrix_x, "matrix_x", gru->q_dec_z+8, cell->units*3);
    print_variable_q15(matrix_inner, "matrix_recurrent", gru->q_dec_z+8, cell->units*3);

    /* split both projections into the z / r / h gates, units elements each */
    q15_t *x_z = matrix_x;
    q15_t *x_r = matrix_x + cell->units;
    q15_t *x_h = matrix_x + cell->units*2;
    q15_t *rec_z = matrix_inner;
    q15_t *rec_r = matrix_inner + cell->units;
    q15_t *rec_h = matrix_inner + cell->units*2;

    /* three work slots; each one is reused below, so the order of the calls matters */
    q15_t *z = work;                        /* update gate, later h3 */
    q15_t *r_hh = work + cell->units;       /* reset gate, later the candidate state hh */
    q15_t *prod = work + cell->units*2;     /* r * recurrent_h, later z * h_tm1 */

    /* z = sigmoid(x_z + recurrent_z) */
    local_add_q15(x_z, rec_z, z, 0, cell->units);
    local_sigmoid_q15(z, cell->units, act_int_bit);
    print_variable_q15(z, "z", 15, cell->units);

    /* r = sigmoid(x_r + recurrent_r) */
    local_add_q15(x_r, rec_r, r_hh, 0, cell->units);
    local_sigmoid_q15(r_hh, cell->units, act_int_bit);
    print_variable_q15(r_hh, "r", 15, cell->units);

    /* hh = tanh(x_h + r * recurrent_h); r is overwritten by hh */
    local_mult_q15(r_hh, rec_h, prod, 15, cell->units);
    local_add_q15(x_h, prod, r_hh, 0, cell->units);
    local_tanh_q15(r_hh, cell->units, act_int_bit);
    print_variable_q15(r_hh, "hh", 15, cell->units);

    /* h = z * h_tm1 + (1 - z) * hh */
    local_mult_q15(z, h_prev, prod, 15, cell->units);       /* h1 = z * h_tm1 */
    print_variable_q15( prod, "h1", 15, cell->units);
    local_1_minor_z_q15(z, h_next, 15, cell->units);        /* h2 = 1 - z, parked in the output state */
    print_variable_q15( h_next, "h2", 15, cell->units);
    local_mult_q15(h_next, r_hh, z, 15, cell->units);       /* h3 = h2 * hh, overwrites z */
    print_variable_q15( z, "h3", 15, cell->units);
    local_add_q15(prod, z, h_next, 0, cell->units);         /* h = h1 + h3 */
    print_variable_q15(h_next, "h", 15, cell->units);

    /* the q15 state is also the output: convert it to q7 with a shift of 8 */
    local_q15_to_q7(h_next, cell->out_data, 8, cell->units);

    return NN_SUCCESS;
}
#endif