summaryrefslogtreecommitdiffstats
path: root/fluent-bit/lib/wasm-micro-runtime-WAMR-1.2.2/core/iwasm/libraries/wasi-nn/test/test_tensorflow.c
blob: 2fa5165384ceaba21d22b1c196cfb83f32c7e750 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
/*
 * Copyright (C) 2019 Intel Corporation.  All rights reserved.
 * SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
 */

#include <stdio.h>
#include <stdlib.h>
#include <assert.h>
#include <string.h>
#include <math.h>

#include "utils.h"
#include "logger.h"

/*
 * Sum-reduction test: runs /assets/models/sum.tflite on a 1x5x5x1 input built
 * by create_input() and asserts that exactly one output value is produced and
 * that it equals 300 within EPSILON. The result is then logged, matching the
 * other reduction tests in this file.
 *
 * NOTE(review): 300 == 0 + 1 + ... + 24, which suggests create_input() fills
 * the tensor with each element's index -- confirm against utils.c.
 */
void
test_sum(execution_target target)
{
    int dims[] = { 1, 5, 5, 1 };
    input_info input = create_input(dims);

    uint32_t output_size = 0;
    /* The trailing 1 presumably is the expected number of output tensors
     * (test_mult_outputs passes 2) -- TODO confirm in utils.h. */
    float *output = run_inference(target, input.input_tensor, input.dim,
                                  &output_size, "/assets/models/sum.tflite", 1);

    assert(output_size == 1);
    assert(fabs(output[0] - 300.0) < EPSILON);
    NN_INFO_PRINTF("Result: sum is %f", output[0]);

    /* input.dim, input.input_tensor and output are all released here with
     * free(), so this test treats them as heap buffers it owns. */
    free(input.dim);
    free(input.input_tensor);
    free(output);
}

/*
 * Max-reduction test.
 *
 * Builds a 1x5x5x1 tensor with create_input(), pushes it through
 * /assets/models/max.tflite and requires a single output equal to 24
 * (within EPSILON). The value is logged once the checks pass.
 *
 * NOTE(review): 24 is the last index of a 25-element tensor, so the expected
 * value presumably relies on create_input() storing element indices -- confirm.
 */
void
test_max(execution_target target)
{
    int shape[] = { 1, 5, 5, 1 };
    input_info in = create_input(shape);

    uint32_t n_out = 0;
    float *result =
        run_inference(target, in.input_tensor, in.dim, &n_out,
                      "/assets/models/max.tflite", 1);

    /* Exactly one scalar is expected back from the reduction. */
    assert(n_out == 1);
    assert(fabs(result[0] - 24.0) < EPSILON);
    NN_INFO_PRINTF("Result: max is %f", result[0]);

    /* Release every buffer this test received. */
    free(in.dim);
    free(in.input_tensor);
    free(result);
}

/*
 * Mean-reduction test.
 *
 * Feeds the 1x5x5x1 tensor produced by create_input() to
 * /assets/models/average.tflite and checks that one value comes back and
 * that it equals 12 within EPSILON, logging it afterwards.
 *
 * NOTE(review): 12 is the mean of 0..24, consistent with create_input()
 * filling the tensor with element indices -- confirm in utils.c.
 */
void
test_average(execution_target target)
{
    int shape[] = { 1, 5, 5, 1 };
    input_info in = create_input(shape);

    uint32_t n_out = 0;
    float *result = run_inference(target, in.input_tensor, in.dim, &n_out,
                                  "/assets/models/average.tflite", 1);

    /* A single scalar is the only acceptable output shape. */
    assert(n_out == 1);
    assert(fabs(result[0] - 12.0) < EPSILON);
    NN_INFO_PRINTF("Result: average is %f", result[0]);

    /* Hand all three buffers back to the allocator. */
    free(in.dim);
    free(in.input_tensor);
    free(result);
}

/*
 * Multi-dimensional output test.
 *
 * Runs /assets/models/mult_dim.tflite on a 1x3x3x1 tensor from
 * create_input(). It expects 9 output values, where value k equals k
 * (within EPSILON).
 */
void
test_mult_dimensions(execution_target target)
{
    enum { EXPECTED_COUNT = 9 }; /* 1 * 3 * 3 * 1 elements */

    int shape[] = { 1, 3, 3, 1 };
    input_info in = create_input(shape);

    uint32_t n_out = 0;
    float *result = run_inference(target, in.input_tensor, in.dim, &n_out,
                                  "/assets/models/mult_dim.tflite", 1);

    assert(n_out == EXPECTED_COUNT);
    /* Each output element should equal its own flat index. */
    for (int k = 0; k < EXPECTED_COUNT; k++) {
        assert(fabs(result[k] - k) < EPSILON);
    }

    free(in.dim);
    free(in.input_tensor);
    free(result);
}

/*
 * Multiple-output-tensor test.
 *
 * Runs /assets/models/mult_out.tflite on a 1x4x4x1 tensor from create_input(),
 * requesting 2 outputs. The values come back concatenated into one buffer of
 * 8 floats:
 *   - the first 4 must equal 24, 28, 32, 36;
 *   - the last 4 must equal 6, 7, 8, 9 (all within EPSILON).
 */
void
test_mult_outputs(execution_target target)
{
    enum { PER_TENSOR = 4 };

    int shape[] = { 1, 4, 4, 1 };
    input_info in = create_input(shape);

    uint32_t n_out = 0;
    float *result = run_inference(target, in.input_tensor, in.dim, &n_out,
                                  "/assets/models/mult_out.tflite", 2);

    assert(n_out == 8);

    /* Leading slice: first output tensor. */
    for (int k = 0; k < PER_TENSOR; k++) {
        assert(fabs(result[k] - (k * 4 + 24)) < EPSILON);
    }

    /* Trailing slice: second output tensor. */
    for (int k = 0; k < PER_TENSOR; k++) {
        assert(fabs(result[k + 4] - (k + 6)) < EPSILON);
    }

    free(in.dim);
    free(in.input_tensor);
    free(result);
}

/*
 * Test driver.
 *
 * Reads the TARGET environment variable to choose the execution target:
 * "cpu" or "gpu". It then runs every TensorFlow Lite test in sequence.
 *
 * Returns 0 once all tests have run. Returns 1 if TARGET is unset or names
 * an unknown target. Individual test failures surface through assert() inside
 * the test functions.
 */
int
main(void)
{
    const char *env = getenv("TARGET");
    if (env == NULL) {
        NN_INFO_PRINTF("Usage:\n--env=\"TARGET=[cpu|gpu]\"");
        return 1;
    }

    execution_target target;
    if (strcmp(env, "cpu") == 0) {
        target = cpu;
    }
    else if (strcmp(env, "gpu") == 0) {
        target = gpu;
    }
    else {
        NN_ERR_PRINTF("Wrong target!");
        return 1;
    }

    NN_INFO_PRINTF("################### Testing sum...");
    test_sum(target);
    NN_INFO_PRINTF("################### Testing max...");
    test_max(target);
    NN_INFO_PRINTF("################### Testing average...");
    test_average(target);
    NN_INFO_PRINTF("################### Testing multiple dimensions...");
    test_mult_dimensions(target);
    NN_INFO_PRINTF("################### Testing multiple outputs...");
    test_mult_outputs(target);

    NN_INFO_PRINTF("Tests: passed!");
    return 0;
}