/* leradp.cpp - Anomaly rule learning algorithm and detector

Copyright (C) 2003, Matt Mahoney.  This program is distributed
without warranty under terms of the GNU general public license.
See http://www.gnu.org/licenses/gpl.txt

Usage: leradp seed train.tcpdump test.tcpdump... | perl afil.pl | eval

LERADP reads tcpdump files, trains on the first one, and reports
anomalies in subsequent files in format suitable for EVAL, e.g.

iiiiiiii MM/DD/YYYY HH:MM:SS 172.016.DA1.DA0 S.SSSSSS #Comments

The format is an ID number (0), date, time (converted
to local time if you compile with Borland, not DJGPP), destination
IP address and anomaly score between 0 and 9.999999.  Comments
include the source IP and the rule violated.  LERADP uses a randomized
learning algorithm initialized with a seed (an integer).  Results
can be repeated if the same seed is used.

LERADP learns rules from the first tcpdump file.  It uses the LERAD
algorithm but extracts attributes directly from the input packets
rather than a database file.  For instance, given attributes
A, B, C, D and the tuples:

A B C D
1 2 3 4
1 2 3 5

It might learn the following rules (output to rules.txt)

1 n=2 if then B=2
2 n=2 if A=1 then B=2
3 n=2 if A=1 C=3 then B=2

where n is the number of tuples satisfying the antecedent.
Then given the tuple,

A B C D
1 8 4 5

then the antecedents of the first 2 rules are satisfied, and they
are updated:

1 n=3 if then B=2 8
2 n=3 if A=1 then B=2 8
3 n=2 if A=1 C=3 then B=2 (unchanged)

If a tuple in subsequent files satisfies the antecedent but the consequent
is not in the list, then an alarm is generated with an anomaly score
proportional to tn/r, where t is the time (in number of tuples) since
the last anomaly and r is the number of allowed values in the consequent
(r=2 for rules 1 and 2, r=1 for rule 3).  For instance, if the tuple
below occurs 100 tuples later,

A B C D
1 8 3 0

then since it violates rule 3, it generates an anomaly score of
100 * 2 / 1 = 200.

*/

#include <cstdio>
#include <cstdlib>
#include <cctype>
#include <ctime>
#include <cmath>
#include <string>
#include <vector>
#include <map>
#include <algorithm>
#include <iostream>
#include <fstream>
#include <iomanip>
using namespace std;

// Big-endian decoders: assemble 2, 3 or 4 bytes starting at p into an
// integer, most significant byte first.

// Decode the 2 bytes p[0..1]; the result is in 0..65535.
int i2(const unsigned char* p) {
  const int high=p[0];
  const int low=p[1];
  return high*256+low;
}

// Decode the 3 bytes p[0..2], most significant byte first; the result
// is in 0..0xffffff.
int i3(const unsigned char* p) {
  int value=0;
  for (int k=0; k<3; ++k)
    value=(value<<8)|p[k];
  return value;
}

// Decode the 4 bytes p[0..3], most significant byte first; the result
// is in 0..0xffffffff.
// Each byte is widened to unsigned long before it is shifted.  Shifting
// in int would overflow when p[0]>=0x80, and where unsigned long is
// wider than int the negative result would be sign extended, so that
// e.g. the bytes d4 c3 b2 a1 would not compare equal to 0xd4c3b2a1.
unsigned long i4(const unsigned char* p) {
  return (static_cast<unsigned long>(p[0])<<24)
       | (static_cast<unsigned long>(p[1])<<16)
       | (static_cast<unsigned long>(p[2])<<8)
       |  static_cast<unsigned long>(p[3]);
}

// Format a time given in seconds since 1970 as "MM/DD/YYYY HH:MM:SS"
// using localtime().  The result points to a static buffer that is
// overwritten by the next call.  It is an empty string if localtime()
// cannot convert the value.
const char* print_time(unsigned long seconds) {
  static char text[30];
  const time_t when=time_t(seconds);
  const tm* const parts=localtime(&when);
  if (!parts) {
    text[0]=0;
    return text;
  }
  strftime(text, 30, "%m/%d/%Y %H:%M:%S", parts);
  return text;
}

/* PacketReader - a class for reading tcpdump packets from a file.

PacketReader pr(int argc, char** argv);

  Prepares pr to read packets from a list of files named in argv[0..argc-1].
  Files must be tcpdump files.

const unsigned char* pr.read()

  Reads one packet and returns it, or 0 at end of last file.
  The first call reads the first packet from argv[0].  At the end of
  each file, the file is closed and read() returns the first packet from
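  the next file.  The returned packet p starts with the 16 byte tcpdump
  record header; i4(p+8) is the number of captured bytes that follow
  and i4(p+12) is the original packet length.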
*/

// Reads packets one at a time from a list of tcpdump files.
// The object owns its packet buffer and the currently open FILE, so it
// is not copyable, and the destructor releases both.
class PacketReader {
private:
  enum {MAX_PACKET=1600};  // Max packet size including tcpdump header
  unsigned char* buf;  // Current input packet, MAX_PACKET bytes (owned)
  int argc;  // Number of files remaining to be read
  const char* const* argv;  // Names of files remaining to be read
  FILE* f;  // Currently open file, or 0 if all are closed
  void close(const char* msg=0);  // Close file, print msg if any
  bool little_endian;  // If true, reverse 4x4 header bytes

  // Copying is disabled: a copy would delete[] buf a second time and
  // share f.  Declared private and intentionally never defined.
  PacketReader(const PacketReader&);
  PacketReader& operator=(const PacketReader&);
public:
  // ac file names in av[0..ac-1]; nothing is opened until read() is called
  PacketReader(int ac, const char* const* av):
    buf(new unsigned char[MAX_PACKET]), argc(ac), argv(av), f(0),
      little_endian(false) {}
  ~PacketReader() {
    if (f)
      fclose(f);  // Reader destroyed while a file was still open
    delete[] buf;
  }
  const unsigned char* read();  // Next packet, or 0 when no files remain
};

// Finish with the current file: close it if it is open, report msg
// (prefixed by the current file name) on stderr if msg is not 0, then
// advance argc/argv to the next file name.
void PacketReader::close(const char* msg) {
  if (f)
    fclose(f);
  f=0;
  if (msg!=0)
    fprintf(stderr, "%s: %s\n", argv[0], msg);
  argv+=1;
  argc-=1;
}

// Return a pointer to the next packet (the 16 byte tcpdump record header,
// converted to big endian, followed by the captured bytes), or 0 when
// no files remain.  Files that cannot be opened or parsed are reported
// on stderr and skipped.
const unsigned char* PacketReader::read() {
  for (;;) {

    // No file open: try to open the next one and check its 24 byte header
    if (!f) {
      if (argc<1)
        return 0;  // No file to open
      fprintf(stderr, "%s\n", argv[0]);
      f=fopen(argv[0], "rb");
      if (!f) {
        close("file not found");
        continue;
      }
      if (fread(buf, 1, 24, f)!=24) {
        close("file is too small");
        continue;
      }
      const unsigned long magic=i4(buf);
      if (magic==0xd4c3b2a1)
        little_endian=true;
      else if (magic==0xa1b2c3d4)
        little_endian=false;
      else
        close("not in tcpdump format");
      continue;
    }

    // File open: read the 16 byte record header
    if (fread(buf, 1, 16, f)!=16) {
      close("end of file");  // EOF
      continue;
    }
    if (little_endian) {  // Convert Windows SNORT headers to big endian
      for (int w=0; w<16; w+=4) {
        swap(buf[w], buf[w+3]);
        swap(buf[w+1], buf[w+2]);
      }
    }

    // The captured length must not exceed the original length, and the
    // original length must leave room for the header in buf
    const unsigned long captured=i4(buf+8);
    const unsigned long original=i4(buf+12);
    if (captured>original || original>MAX_PACKET-16) {
      close("bad tcpdump packet header");
      continue;
    }
    if (fread(buf+16, 1, captured, f)!=captured) {
      close("truncated packet");
      continue;
    }
    return buf;
  }
}

// LERAD follows

// Natural log of 10; dividing a natural log by this gives a base 10 log.
// The argument is written as a double because log(10) with an int
// argument is an ambiguous overload on pre-C++11 libraries.
const double LOG10=log(10.0);

const int m=32;  // Global number of attributes

typedef int Nominal;  // Attribute values
typedef map<Nominal, int> Val;  // Set of allowed values
typedef vector<Nominal> Anteceedent;  // Conditions: 0=don't care, 1=cons.

// Convert packet p to tuple t[m], return false to ignore packet
// Attribute values should be 2 or higher (0 means *, 1 means ?)
bool p2t(const unsigned char* p, vector<Nominal>& t) {
  int len=i4(p+8)+16;  // End of captured packet
  for (int i=0, j=30; i<m; ++i, j+=2)
    t[i]=j<len-1 ? i2(p+j)+3 : 2;
  return true;
}

struct Consequent {  // Consequent and counts of a rule
  int n;  // Support
  int t;  // Time of last anomaly
  Val val;  // Allowed values, each mapped to a count

  // No allowed values, zero counts
  Consequent(): n(0), t(0) {}

  // One allowed value v with count 1; support stays 0.
  // Deliberately not explicit: main assigns a Nominal to a Consequent.
  Consequent(Nominal v): n(0), t(0) {
    val.insert(Val::value_type(v, 1));
  }

  // Count one more supporting tuple whose value is v.
  // Returns the new count for v, so 1 means v was not allowed before.
  int add(Nominal v) {
    n+=1;
    int& count=val[v];
    count+=1;
    return count;
  }
};

// A ruleset maps antecedents to consequents.
// The antecedent is m elements where 0 = any, 1 = consequent, 2 = absent,
// 3 or more is the constrained value for the attribute.
typedef map<Anteceedent, Consequent> Ruleset;

struct Rule: public Consequent {  // Complete rule
  Anteceedent rule;  // 0=* 1=? >1 = constraint
  int cons;  // index of ? in rule
  vector<int> ante;  // Indexes of anteceedents >1
  Rule(): cons(0) {}
  Rule(const Anteceedent& v, const Consequent& c): Consequent(c), rule(v) {
    cons=rule.size();
    for (int i=0; i<int(rule.size()); ++i) {
      if (rule[i]==1)
        cons=i;
      else if (rule[i]>1)
        ante.push_back(i);
    }
  }
  bool match(const Anteceedent& a) const {  // Rule applies to a?
    for (int i=0; i<int(ante.size()); ++i)
      if (rule[ante[i]]!=a[ante[i]])
        return false;
    return true;
  }
};

// Sort order for rules: larger n/r first (r = number of allowed values,
// compared by cross multiplication instead of division), and fewer
// antecedents first if n/r is equal
bool operator < (const Rule& a, const Rule& b) {
  const int lhs=a.n*b.val.size();
  const int rhs=b.n*a.val.size();
  if (lhs!=rhs)
    return lhs>rhs;
  return a.ante.size()<b.ante.size();
}

// 30-31 bit non-negative random number built from two rand() calls:
// the first supplies the low bits, the second is shifted up 15 bits.
// The calls are made in separate statements so that their order is
// fixed, and the arithmetic is unsigned so that a rand() returning more
// than 16 bits cannot overflow a signed int.
inline int rnd() {
  const unsigned low=rand();
  const unsigned high=rand();
  return int((low+(high<<15))&0x7fffffffu);
}

int main(int argc, const char* const* argv) {

  // Test args
  if (argc<4) {
    cerr << "Usage: " << argv[0] << " seed train test...\n";
    return 1;
  }

  // Init random numbers
  srand(argc>1 ? atoi(argv[1]) : 0);

  // Read relation from training file
  PacketReader pr(1, argv+2);
  const unsigned char* packet;
  vector<Anteceedent> relation;  // [tuple][attribute]
  Anteceedent t(m);
  while ((packet=pr.read())!=0) {
    if (p2t(packet, t))
      relation.push_back(t);
  }
  if (relation.size()<2) {
    cerr << "Only " << relation.size() << " tuples\n";
    return 1;
  }

  // Select random samples from the relation for preliminary training
  vector<int> samples;  // Saved ascending random indexes of relation
  int r=min(int(relation.size()), 200);  // Number of samples still needed
  for (int i=0; r>0 && i<int(relation.size()); ++i) {
    int rn=rnd()%(int(relation.size())-i);
    if (rn<r) {
      --r;
      samples.push_back(i);
    }
  }

  // Construct a ruleset.
  // The key is a rule, coded 0="*", 1="?", 2 or more = constrained
  // attribute.  Rules are constructed by sampling 2 tuples
  // in the relation and constructing up to 4 rules for 4 randomly
  // selected matching attributes in random order where the first is "?"
  // and the other 3 are the anteceedents.

  cerr << relation.size() << " tuples.  Constructing ruleset\n";
  Ruleset ruleset;
  Anteceedent rule(m);
  for (int i=0; i<2000; ++i) {  // Number of tuple pairs
/*
    int r1=rnd()%relation.size();  // Pick 2 random tuples
    int r2=rnd()%(relation.size()-1);
    if (r1==r2)
      r2=relation.size()-1;  // so r1 != r2
*/
    // Pick random pair of tuples from the sample set
    int r1=rnd()%samples.size();
    int r2=rnd()%(samples.size()-1);
    if (r1==r2)
      r2=samples.size()-1;
    r1=samples[r1];
    r2=samples[r2];

    // Generate rules by matching attribute values
    for (int j=0; j<m; ++j)
      rule[j]=0;
    int count=0;
    int result=0;
    for (int j=0; j<m*4 && count<4; ++j) {
      int r3=rnd()%m;  // Pick random attribute
      if (relation[r1][r3]>=3 && relation[r1][r3]==relation[r2][r3]) {  
        if (count==0) {  // First match is ?
          result=relation[r1][r3];
          rule[r3]=1;
          count=1;
          ruleset[rule];
        }
        else if (rule[r3]==0) {  // Other matches are anteceedents
          rule[r3]=relation[r1][r3];
          ++count;
          ruleset[rule]=result;
        }
      }
    }
  }

  // Estimate ruleset support by sampling
  cerr << ruleset.size() << " rules.  Estimating support\n";
  vector<Rule> rules;  // Sorted ruleset
  for (Ruleset::iterator p=ruleset.begin(); p!=ruleset.end(); ++p) {
    int k;
    for (k=0; k<m; ++k)  // Find "?"
      if (p->first[k]==1)
        break;
    if (k==m) {
      cerr << "oops, rule with missing ?\n";
      return 1;
    }
    int i;
    for (i=0; i<int(samples.size()); ++i) {
      int r=samples[i];
      int j;
      for (j=0; j<m; ++j)
        if (p->first[j]>=2 && p->first[j]!=relation[r][j])
          break;
      if (j==m) // Rule applies?
        p->second.add(relation[r][k]);
    }
    rules.push_back(Rule(p->first, p->second));
  }

  // Estimate n/r again so that each tuple attribute is predicted by
  // only the rule with the highest n/r from before
  cerr << "Removing duplicate rules from " << rules.size() << " rules\n";
  sort(rules.begin(), rules.end());
  vector<vector<bool> > cover(samples.size());  // cover[i][j] is true
    // if there is a rule predicting attribute relation[sample[i]][j]
  for (int i=0; i<int(cover.size()); ++i)
    cover[i].resize(m);
  for (int i=0; i<int(rules.size()); ++i) {
    Rule& r=rules[i];
    int cons=r.cons;
    r.val.clear();
    r.n=0;
    for (int j=0; j<int(samples.size()); ++j) {
      if (!cover[j][cons] && r.match(relation[samples[j]])) {
        r.add(relation[samples[j]][cons]);
        cover[j][cons]=true;
      }
    }
  }

  // Discard unsupported rules
  for (int i=0; i<int(rules.size()); ++i) {
    if (rules[i].n==0) {
      rules[i]=rules.back();
      rules.pop_back();
      --i;
    }
  }

  // Calculate exact support for top rules on entire training set
  cerr << "Calculating exact coverage for " << rules.size() << " rules\n";
  sort(rules.begin(), rules.end());
  cover.clear();
  cover.resize(relation.size());
  for (int i=0; i<int(cover.size()); ++i)
    cover[i].resize(m);
  int now=0;  // Time
  for (int i=0; i<int(rules.size()); ++i) {
    int cons=rules[i].cons;
    rules[i].n=0;
    rules[i].val.clear();
    now=0;
    for (int j=0; j<int(relation.size()); ++j) {
      ++now;
      if (rules[i].match(relation[j])) {
        cover[j][cons]=true;
        if (rules[i].add(relation[j][cons])==1) {
          rules[i].t=now;

          // Late anomaly or r too high, remove rule
          if (j*10 > int(relation.size()*9) || rules[i].val.size() > 32) {
            rules[i]=rules.back();
            rules.pop_back();
            --i;
            break;
          }
        }
      }
    }
  }

  // Print coverage
  cerr << "Coverage:";
  for (int i=0; i<m; ++i) {
    int count=0;
    for (int j=0; j<int(cover.size()); ++j)
      if (cover[j][i])
        ++count;
    cerr << i << "=" << count << " ";
  }
  cerr << "\n";

  // Output rules.txt
  sort(rules.begin(), rules.end());
  ofstream out("rules.txt");
  if (out) {
    for (int i=0; i<int(rules.size()); ++i) {
      out << i+1 << " " << rules[i].n << "/" << rules[i].val.size() << " if";
      for (int j=0; j<m; ++j) {
        if (rules[i].rule[j]>1)
          out << " " << j << "=" << rules[i].rule[j]-3;
      }
      out << " then " << rules[i].cons << " =";
      for (Val::iterator q=rules[i].val.begin(); q!=rules[i].val.end(); ++q)
        out << " " << (q->first)-3;
      out << "\n";
    }
    out.close();
  }

  // Statistics
  int nsum=0, nrsum=0;
  for (int i=0; i<int(rules.size()); ++i) {
    nsum+=rules[i].n;
    nrsum+=rules[i].n*rules[i].val.size();
    if (i==0 || i==9 || i==99 || i==999 || i==9999 || i==99999
        || i==int(rules.size())-1) {
      cerr << "Rules: " << i+1 << " Coverage: " << nsum << " "
           << (double(nsum)/relation.size()/m)
           << " Diversity: " << double(nrsum)/nsum << "\n";
    }
  }
  cerr << "\n";

  // Score anomalies in test.txt
  cerr << "Testing with " << rules.size() << " rules\n";
  PacketReader pr2(argc-3, argv+3);
  vector<int> ta(rules.size());  // Time of last anomaly per rule
  int alarms=0, tests=0;
  while ((packet=pr2.read())!=0) {

    // Read tuple t
    ++now;
    if (!p2t(packet, t))
      continue;
    if (++tests%10000==0)
      cerr << tests/10000 << " ";

    // Evaluate tuple t
    double score=0;  // Total score of tuple
    double bs=0;  // Highest scoring rule
    int brule=0;  // Index of highest scoring rule

    for (int i=0; i<int(rules.size()); ++i) {
      double score1=0;  // score for this rule
      const Rule& r=rules[i];
      if (r.match(t) && r.val.find(t[r.cons])==r.val.end() && r.val.size()>0){
        score1=double(now-rules[i].t)*rules[i].n/rules[i].val.size();
        if (score1>bs) {
          brule=i;
          bs=score1;
        }
        score+=score1;
        rules[i].t=now;
      }
    }

    // Print anomaly
    double pct=0;
    if (score>0) {
      pct=100*bs/score;
      score=log(score)/LOG10-5;
    }
    if (score>0) {
      ++alarms;
      cout << "       0 " << print_time(i4(packet)) << " "
        << setw(3) << setfill('0') << int(packet[46]) << "."
        << setw(3) << setfill('0') << int(packet[47]) << "."
        << setw(3) << setfill('0') << int(packet[48]) << "."
        << setw(3) << setfill('0') << int(packet[49]);
      cout << " " << int(score)
        << "." << setw(6) << setfill('0')
        << int(score*1000000)-int(score)*1000000
        << " # "
        << setw(3) << setfill('0') << int(packet[42]) << "."
        << setw(3) << setfill('0') << int(packet[43]) << "."
        << setw(3) << setfill('0') << int(packet[44]) << "."
        << setw(3) << setfill('0') << int(packet[45]) << " "
        << setw(3) << brule+1 << " ("
        << setprecision(4) << pct << ")";
      for (int i=0; i<m; ++i) {
        if (rules[brule].rule[i]>=1) {
          cout << " " << i;
          if (rules[brule].rule[i]==1)
            cout << "?";
          cout << "=" << t[i]-3;
        }
      }
      cout << "\n";
    }
  }
  cerr << "\n" << tests << " tests, " << alarms << " alarms\n";
  return 0;
}

