// Source path: root/ml/kmeans/KMeans.h
// Source blob: 4ea3b6a89b3d1c9e6a43f8201eb996a8905b8d55
// SPDX-License-Identifier: GPL-3.0-or-later

#ifndef KMEANS_H
#define KMEANS_H

#include <atomic>
#include <vector>
#include <limits>
#include <mutex>

#include "SamplesBuffer.h"

class KMeans {
public:
    KMeans(size_t NumClusters = 2) : NumClusters(NumClusters) {
        MinDist = std::numeric_limits<CalculatedNumber>::max();
        MaxDist = std::numeric_limits<CalculatedNumber>::min();
    };

    void train(SamplesBuffer &SB, size_t MaxIterations);
    CalculatedNumber anomalyScore(SamplesBuffer &SB);

private:
    size_t NumClusters;

    std::vector<DSample> ClusterCenters;

    CalculatedNumber MinDist;
    CalculatedNumber MaxDist;

    std::mutex Mutex;
};

#endif /* KMEANS_H */