add nnom pack and example
This commit is contained in:
86
components/ai/nnom/inc/layers/nnom_simple_cell.h
Normal file
86
components/ai/nnom/inc/layers/nnom_simple_cell.h
Normal file
@@ -0,0 +1,86 @@
|
||||
/*
|
||||
* Copyright (c) 2018-2020
|
||||
* Jianjia Ma
|
||||
* majianjia@live.com
|
||||
*
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*
|
||||
* Change Logs:
|
||||
* Date Author Notes
|
||||
* 2020-08-20 Jianjia Ma The first version
|
||||
*/
|
||||
|
||||
#ifndef __NNOM_SIMPLE_CELL_H__
|
||||
#define __NNOM_SIMPLE_CELL_H__
|
||||
|
||||
#ifdef __cplusplus
|
||||
extern "C" {
|
||||
#endif
|
||||
|
||||
#include "nnom_rnn.h"
|
||||
#include "nnom_activation.h"
|
||||
|
||||
|
||||
// This Simple Cell replicate the Keras's SimpleCell as blow
|
||||
/*
|
||||
def call(self, inputs, states, training=None):
|
||||
prev_output = states[0] if nest.is_sequence(states) else states
|
||||
|
||||
h = K.dot(inputs, self.kernel)
|
||||
h = K.bias_add(h, self.bias)
|
||||
|
||||
output = h + K.dot(prev_output, self.recurrent_kernel)
|
||||
output = self.activation(output)
|
||||
|
||||
new_state = [output] if nest.is_sequence(states) else output
|
||||
return output, new_state
|
||||
*/
|
||||
|
||||
// a machine interface for configuration
|
||||
typedef struct _nnom_simple_cell_config_t
|
||||
{
|
||||
nnom_layer_config_t super;
|
||||
nnom_tensor_t *weights;
|
||||
nnom_tensor_t* recurrent_weights;
|
||||
nnom_tensor_t *bias;
|
||||
nnom_qformat_param_t q_dec_iw, q_dec_hw, q_dec_h;
|
||||
nnom_activation_type_t act_type; // type of the activation
|
||||
uint16_t units;
|
||||
} nnom_simple_cell_config_t;
|
||||
|
||||
|
||||
typedef struct _nnom_simple_cell_t
|
||||
{
|
||||
nnom_rnn_cell_t super;
|
||||
nnom_activation_type_t act_type;
|
||||
|
||||
nnom_tensor_t* weights;
|
||||
nnom_tensor_t* recurrent_weights;
|
||||
nnom_tensor_t* bias;
|
||||
|
||||
// experimental,
|
||||
// iw: input x weight
|
||||
// hw: hidden state x recurrent weight
|
||||
// h: hidden state
|
||||
nnom_qformat_param_t q_dec_iw, q_dec_hw, q_dec_h;
|
||||
nnom_qformat_param_t oshift_iw, oshift_hw, bias_shift;
|
||||
|
||||
} nnom_simple_cell_t;
|
||||
|
||||
|
||||
// RNN cells
|
||||
// The shape for RNN input is (batch, timestamp, feature), where batch is always 1.
|
||||
//
|
||||
// SimpleCell
|
||||
nnom_rnn_cell_t *simple_cell_s(const nnom_simple_cell_config_t* config);
|
||||
|
||||
nnom_status_t simple_cell_free(nnom_rnn_cell_t* cell);
|
||||
nnom_status_t simple_cell_build(nnom_rnn_cell_t* cell);
|
||||
nnom_status_t simple_cell_run(nnom_rnn_cell_t* cell);
|
||||
|
||||
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
#endif
|
||||
|
||||
#endif /* __NNOM_SIMPLE_CELL_H__ */
|
Reference in New Issue
Block a user