// mnist_nn_test.c
// 2020/09/08 by marsee
//
#include <stdio.h>
#include <stdint.h>
#include "af1_weight_float.h"
#include "af1_bias_float.h"
#include "af2_weight_float.h"
#include "af2_bias_float.h"
#include "mnist_data_10.h"
#include "xmnist_nn.h"
// Forward declarations for the helpers defined after main().
int mnist_nn_float(float in[784], float out[10]);
int max_float(float out[10]);
int max_int32_t(int32_t out[10]);
// Number of test images main() runs through hardware and software.
// 10 is used for C simulation; the commented-out value of 2 is for C/RTL
// co-simulation (presumably smaller to keep co-simulation short).
#define NUM_ITERATIONS 10 // C Simulation
// #define NUM_ITERATIONS 2 // C/RTL CoSimulation
int main(){
float t_tran_float[NUM_ITERATIONS][784];
uint8_t t_tran_uint8_t[NUM_ITERATIONS][784];
int32_t result_hard[NUM_ITERATIONS][10];
float result_soft[NUM_ITERATIONS][10];
int max_id_hw, max_id_sw, max_id_ref;
XMnist_nn mnits_nn_ap;
int mnist_nn_isdone = 0;
int32_t res;
printf("Hello World\n");
for(int i=0; i<NUM_ITERATIONS; i++){
for(int j=0; j<784; j++){
t_tran_float[i][j] = (float)(t_train_256[i][j])/256.0;
t_tran_uint8_t[i][j] = (uint32_t)(t_train_256[i][j]);
}
}
// Initialize tht Device
//printf("a"); fflush(stdout);
int XMinst_status = XMnist_nn_Initialize(&mnits_nn_ap, 0);
if (XMinst_status != XST_SUCCESS){
fprintf(stderr, "Could not Initialize XMnist_nn\n");
return(-1);
}
for(int i=0; i<NUM_ITERATIONS; i++){
//printf("a"); fflush(stdout);
u32 char_num = (u32)(&t_tran_uint8_t[i][0]);
XMnist_nn_Set_in_V(&mnits_nn_ap, char_num);
//printf("a"); fflush(stdout);
XMnist_nn_Start(&mnits_nn_ap);
while(mnist_nn_isdone == 0)
mnist_nn_isdone = XMnist_nn_IsDone(&mnits_nn_ap);
// minst nn result check
for(int j=0; j<5; j++){
XMnist_nn_Read_out_V_Words(&mnits_nn_ap, i, &res, 1);
result_hard[i][j*2] = res & 0x1fff; // 13 bit
if(result_hard[i][j*2] & 0x1000) // minus
result_hard[i][j*2] = 0xffffe000 | result_hard[i][j*2]; // Sign extension
result_hard[i][j*2+1] = (res & 0x1fff0000) >> 16;
if(result_hard[i][j*2+1] & 0x1000) // minus
result_hard[i][j*2+1] = 0xffffe000 | result_hard[i][j*2+1]; // Sign extension
}
mnist_nn_float(&t_tran_float[i][0], &result_soft[i][0]);
}
int errflag=0;
for(int i=0; i<NUM_ITERATIONS; i++){
max_id_hw = max_int32_t(&result_hard[i][0]);
max_id_sw = max_float(&result_soft[i][0]);
max_id_ref = max_float(&t_test[i][0]);
if(max_id_ref != max_id_hw){
printf("id = %d, max_id_ref = %d, max_id_hw = %d\n", i, max_id_ref, max_id_hw);
errflag = 1;
}
if(max_id_ref != max_id_sw){
printf("id = %d, max_id_ref = %d, max_id_sw = %d\n", i, max_id_ref, max_id_sw);
errflag = 1;
}
}
if(errflag == 0)
printf("No Error\n");
return(0);
}
/*
 * Float software model of the two fully connected layers:
 *   hidden[50] = ReLU(in[784] x af1_fweight + af1_fbias)
 *   out[10]    = ReLU(hidden  x af2_fweight + af2_fbias)
 * The weight and bias tables come from the included af1_ and af2_ headers.
 * Always returns 0.
 */
int mnist_nn_float(float in[784], float out[10]){
    float hidden[50];

    /* Layer 1: accumulate, add bias, then clamp negatives to zero. */
    for(int h=0; h<50; h++){
        float acc = 0;
        for(int k=0; k<784; k++)
            acc += in[k]*af1_fweight[k][h];
        acc += af1_fbias[h];
        hidden[h] = (acc < 0) ? 0 : acc;
    }

    /* Layer 2: the same pattern, written straight into out. */
    for(int o=0; o<10; o++){
        float acc = 0;
        for(int k=0; k<50; k++)
            acc += hidden[k]*af2_fweight[k][o];
        acc += af2_fbias[o];
        out[o] = (acc < 0) ? 0 : acc;
    }

    return 0;
}
/*
 * Index of the largest of the 10 values in out.
 * Comparison is strict '>', so ties resolve to the lowest index.
 */
int max_float(float out[10]){
    int best = 0;
    for(int k=1; k<10; k++){
        if(out[k] > out[best])
            best = k;
    }
    return best;
}
/*
 * Index of the largest of the 10 values in out.
 * Comparison is strict '>', so ties resolve to the lowest index.
 */
int max_int32_t(int32_t out[10]){
    int best = 0;
    for(int k=1; k<10; k++){
        if(out[k] > out[best])
            best = k;
    }
    return best;
}
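/*
 * NOTE: Stray calendar table, apparently left over from the web page this
 * file was copied from. The weekday headers are Japanese for Sun, Mon, Tue,
 * Wed, Thu, Fri and Sat. It is not part of the program and is kept verbatim
 * inside this comment so the file still compiles.
 *
日 | 月 | 火 | 水 | 木 | 金 | 土 |
---|---|---|---|---|---|---|
- | - | 1 | 2 | 3 | 4 | 5 |
6 | 7 | 8 | 9 | 10 | 11 | 12 |
13 | 14 | 15 | 16 | 17 | 18 | 19 |
20 | 21 | 22 | 23 | 24 | 25 | 26 |
27 | 28 | 29 | 30 | - | - | - |
 */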