summaryrefslogtreecommitdiffstats
path: root/utils/fann_train.pl
diff options
context:
space:
mode:
Diffstat (limited to 'utils/fann_train.pl')
-rwxr-xr-xutils/fann_train.pl247
1 files changed, 247 insertions, 0 deletions
diff --git a/utils/fann_train.pl b/utils/fann_train.pl
new file mode 100755
index 0000000..2ce422e
--- /dev/null
+++ b/utils/fann_train.pl
@@ -0,0 +1,247 @@
+#!/usr/bin/env perl
+
+# This script is a very simple prototype to learn fann from rspamd logs
+# For now, it is intended for internal use only
+
+use strict;
+use warnings FATAL => 'all';
+use AI::FANN qw(:all);
+use Getopt::Std;
+
+my %sym_idx; # Symbols by index
+my %sym_names; # Symbols by name
+my $num = 1; # Number of symbols
+my @spam;
+my @ham;
+my $max_samples = -1;
+my $split = 1;
+my $preprocessed = 0; # output is in format <score>:<0|1>:<SYM1,...SYMN>
+my $score_spam = 12;
+my $score_ham = -6;
+
+sub process {
+ my ( $input, $spam, $ham ) = @_;
+ my $samples = 0;
+
+ while (<$input>) {
+ if ( !$preprocessed ) {
+ if (/^.*rspamd_task_write_log.*: \[(-?\d+\.?\d*)\/(\d+\.?\d*)\]\s*\[(.+)\].*$/) {
+ if ( $1 > $score_spam ) {
+ $_ = "$1:1: $3";
+ }
+ elsif ( $1 < $score_ham ) {
+ $_ = "$1:0: $3\n";
+ }
+ else {
+ # Out of boundary
+ next;
+ }
+ }
+ else {
+ # Not our log message
+ next;
+ }
+ }
+
+ $_ =~ /^(-?\d+\.?\d*):([01]):\s*(\S.*)$/;
+
+ my $is_spam = 0;
+
+ if ( $2 == 1 ) {
+ $is_spam = 1;
+ }
+
+ my @ar = split /,/, $3;
+ my %sample;
+
+ foreach my $sym (@ar) {
+ chomp $sym;
+ if ( !$sym_idx{$sym} ) {
+ $sym_idx{$sym} = $num;
+ $sym_names{$num} = $sym;
+ $num++;
+ }
+
+ $sample{ $sym_idx{$sym} } = 1;
+ }
+
+ if ($is_spam) {
+ push @{$spam}, \%sample;
+ }
+ else {
+ push @{$ham}, \%sample;
+ }
+
+ $samples++;
+ if ( $max_samples > 0 && $samples > $max_samples ) {
+ return;
+ }
+ }
+}
+
+# Shuffle array
+sub fisher_yates_shuffle {
+ my $array = shift;
+ my $i = @$array;
+
+ while ( --$i ) {
+ my $j = int rand( $i + 1 );
+ @$array[ $i, $j ] = @$array[ $j, $i ];
+ }
+}
+
+# Train network
+sub train {
+ my ( $ann, $sample, $result ) = @_;
+
+ my @row;
+
+ for ( my $i = 1 ; $i < $num ; $i++ ) {
+ if ( $sample->{$i} ) {
+ push @row, 1;
+ }
+ else {
+ push @row, 0;
+ }
+ }
+
+ #print "@row -> @{$result}\n";
+
+ $ann->train( \@row, \@{$result} );
+}
+
+sub test {
+ my ( $ann, $sample ) = @_;
+
+ my @row;
+
+ for ( my $i = 1 ; $i < $num ; $i++ ) {
+ if ( $sample->{$i} ) {
+ push @row, 1;
+ }
+ else {
+ push @row, 0;
+ }
+ }
+
+ my $ret = $ann->run( \@row );
+
+ return $ret;
+}
+
+my %opts;
+getopts( 'o:i:s:n:t:hpS:H:', \%opts );
+
+if ( $opts{'h'} ) {
+ print "$0 [-i input] [-o output] [-s scores] [-n max_samples] [-S spam_score] [-H ham_score] [-ph]\n";
+ exit;
+}
+
+my $input = *STDIN;
+
+if ( $opts{'i'} ) {
+ open( $input, '<', $opts{'i'} ) or die "cannot open $opts{i}";
+}
+
+if ( $opts{'n'} ) {
+ $max_samples = $opts{'n'};
+}
+
+if ( $opts{'t'} ) {
+
+ # Test split
+ $split = $opts{'t'};
+}
+if ( $opts{'p'} ) {
+ $preprocessed = 1;
+}
+
+if ( $opts{'H'} ) {
+ $score_ham = $opts{'H'};
+}
+
+if ( $opts{'S'} ) {
+ $score_spam = $opts{'S'};
+}
+
+# ham_prob, spam_prob
+my @spam_out = (1);
+my @ham_out = (0);
+
+process( $input, \@spam, \@ham );
+fisher_yates_shuffle( \@spam );
+fisher_yates_shuffle( \@ham );
+
+my $nspam = int( scalar(@spam) / $split );
+my $nham = int( scalar(@ham) / $split );
+
+my $ann = AI::FANN->new_standard( $num - 1, ( $num + 2 ) / 2, 1 );
+
+my @train_data;
+
+# Train ANN
+for ( my $i = 0 ; $i < $nham ; $i++ ) {
+ push @train_data, [ $ham[$i], \@ham_out ];
+}
+
+for ( my $i = 0 ; $i < $nspam ; $i++ ) {
+ push @train_data, [ $spam[$i], \@spam_out ];
+}
+
+fisher_yates_shuffle( \@train_data );
+
+foreach my $train_row (@train_data) {
+ train( $ann, @{$train_row}[0], @{$train_row}[1] );
+}
+
+print "Trained $nspam SPAM and $nham HAM samples\n";
+
+# Now run fann
+if ( $split > 1 ) {
+ my $sample = 0.0;
+ my $correct = 0.0;
+ for ( my $i = $nham ; $i < $nham * $split ; $i++ ) {
+ my $ret = test( $ann, $ham[$i] );
+
+ #print "@{$ret}\n";
+ if ( @{$ret}[0] < 0.5 ) {
+ $correct++;
+ }
+ $sample++;
+ }
+
+ print "Tested $sample HAM samples, correct matched: $correct, rate: " . ( $correct / $sample ) . "\n";
+
+ $sample = 0.0;
+ $correct = 0.0;
+
+ for ( my $i = $nspam ; $i < $nspam * $split ; $i++ ) {
+ my $ret = test( $ann, $spam[$i] );
+
+ #print "@{$ret}\n";
+ if ( @{$ret}[0] > 0.5 ) {
+ $correct++;
+ }
+ $sample++;
+ }
+
+ print "Tested $sample SPAM samples, correct matched: $correct, rate: " . ( $correct / $sample ) . "\n";
+}
+
+if ( $opts{'o'} ) {
+ $ann->save( $opts{'o'} ) or die "cannot save ann into $opts{o}";
+}
+
+if ( $opts{'s'} ) {
+ open( my $scores, '>', $opts{'s'} ) or die "cannot open score file $opts{'s'}";
+ print $scores "{";
+ for ( my $i = 1 ; $i < $num ; $i++ ) {
+ my $n = $i - 1;
+ if ( $i != $num - 1 ) {
+ print $scores "\"$sym_names{$i}\":$n,";
+ }
+ else {
+ print $scores "\"$sym_names{$i}\":$n}\n";
+ }
+ }
+}