summaryrefslogtreecommitdiffstats
path: root/ml/KMeans.h
blob: 0398eeb86378b6ec70f81ae622cfe61922b94ec5 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
// SPDX-License-Identifier: GPL-3.0-or-later

#ifndef KMEANS_H
#define KMEANS_H

#include <atomic>
#include <vector>
#include <limits>
#include <mutex>

#include "SamplesBuffer.h"
#include "json/single_include/nlohmann/json.hpp"

class KMeans {
public:
    KMeans(size_t NumClusters = 2) : NumClusters(NumClusters) {
        MinDist = std::numeric_limits<CalculatedNumber>::max();
        MaxDist = std::numeric_limits<CalculatedNumber>::min();
    };

    void train(const std::vector<DSample> &Samples, size_t MaxIterations);
    CalculatedNumber anomalyScore(const DSample &Sample) const;

    void toJson(nlohmann::json &J) const {
        J = nlohmann::json{
            {"CCs", ClusterCenters},
            {"MinDist", MinDist},
            {"MaxDist", MaxDist}
        };
    }

private:
    size_t NumClusters;

    std::vector<DSample> ClusterCenters;

    CalculatedNumber MinDist;
    CalculatedNumber MaxDist;
};

#endif /* KMEANS_H */