/* netad.cpp, Matt Mahoney, mmahoney@cs.fit.edu

Copyright (C) 2002, Matt Mahoney.  This program is distributed
without warranty under terms of the GNU general public license.
See http://www.gnu.org/licenses/gpl.txt

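Network anomaly detection system.  Reads packets from tcpdump (libpcap)
files and prints one alarm line per anomalous test packet.  To use:

  netad tcpdump_files... |sort /+46 /r >x
  eval4 s=x

The sort step (DOS/Windows sort syntax) sorts the alarm lines in reverse
order of the text starting at column 46, which is where the score field
begins, so the highest-scoring alarms come first.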

*/

#include <cstdio>
#include <cstdlib>
#include <cmath>
#include <vector>
#include <map>
#include <algorithm>
#include <ctime>
#include <cstring>
using namespace std;

// Integer aliases.  U8 is one byte.  U32 is unsigned long: at least 32 bits
// (64 on LP64 platforms).  main() prints U32 values with %lu, which relies on
// U32 being exactly unsigned long.
typedef unsigned char U8;
typedef unsigned long U32;

// Number of attributes scored per packet (tuple bytes 0 .. M-1).
const int M=48;

// Attribute names, indexed by byte position in the tuple.  Only the first M
// entries are used; the remaining C-names are spare.
const char* const label[]={
  // IP header, bytes 0-19
  "IPver", "TOS", "IPlen1", "IPlen0",
  "ID1", "ID0", "Frag1", "Frag0",
  "TTL", "Prot", "IPchk1", "IPchk0",
  "SA3", "SA2", "SA1", "SA0",
  "DA3", "DA2", "DA1", "DA0",
  // TCP header, bytes 20-39
  "SP1", "SP0", "DP1", "DP0",
  "Seq3", "Seq2", "Seq1", "Seq0",
  "Ack3", "Ack2", "Ack1", "Ack0",
  "TCPhdr", "UAPRSF", "Win1", "Win0",
  "TCPchk1", "TCPchk0", "Urg1", "Urg0",
  // Following bytes, 40-59
  "C0", "C1", "C2", "C3", "C4",
  "C5", "C6", "C7", "C8", "C9",
  "C10", "C11", "C12", "C13", "C14",
  "C15", "C16", "C17", "C18", "C19"};

// Anomaly detection engine.  Holds a tree of Rules: the root rule `tree`,
// plus child rules that apply when a tuple attribute has a specific value.
// The tree is owned through a raw pointer, so the class is made non-copyable.
class IDS {
  struct Rule {
    int trains, tests;  // Number of eval calls with h=0, 1
    int n[M];  // Training tuples since the last novel value of attribute i
    int t[M];  // Time of the last novel value (set in training and testing)
    int r[M];  // Number of distinct values seen in training
    int val[M][256];  // Training count of each value
    int last[M][256]; // Time of last occurrence of each value
    Rule* p[M][256];  // Child rules by antecedent (attribute i == value j)
    // Zero all counters/times; child pointers become all-bits-zero (null on
    // mainstream platforms).
    Rule() {memset(this, 0, sizeof(Rule));}
    void print(FILE* f=stdout, int level=0) const;  // Print in readable form
    void eval(const U8* tuple, int h, double* scores, int now);
  };
  Rule* tree;  // Root rule; child nodes keyed by antecedent (attr == value)
  void eval(const U8* tuple, int h, double* scores, int now, Rule* branch);
  void print(FILE* f, Rule* p, int level) const;  // Recursive print

  // Not copyable: a member-wise copy would make two IDS objects alias the
  // same rule tree.  Declared private and deliberately never defined
  // (pre-C++11 equivalent of "= delete").
  IDS(const IDS&);
  IDS& operator=(const IDS&);
public:
  IDS();
  // Train (h=0) or test (h=1) on tuple[M] at time now; fills scores[M] and
  // returns their sum.
  double eval(const U8* tuple, int h, double* scores, int now);
  ~IDS();
};

// Build the fixed rule tree.  Each child applies when the named attribute
// has exactly the given byte value:
//   root
//   `- Prot=06 (TCP)
//      |- UAPRSF=02 (SYN only)
//      `- UAPRSF=10 (ACK only)
//         `- DP1=00 (destination ports 0-255)
//            `- DP0 = 21 FTP, 23 Telnet, 25 SMTP, 80 HTTP
IDS::IDS() {
  // Attribute indices into the tuple (positions in label[])
  const int ATTR_PROT=9, ATTR_FLAGS=33, ATTR_DP1=22, ATTR_DP0=23;

  tree=new Rule;
//  tree->p[6][64]=new Rule;  // DF
  Rule* tcp=new Rule;                  // TCP
  tree->p[ATTR_PROT][6]=tcp;
  tcp->p[ATTR_FLAGS][2]=new Rule;      // TCP SYN
  Rule* ack=new Rule;                  // TCP ACK
  tcp->p[ATTR_FLAGS][16]=ack;
  Rule* low_port=new Rule;             // Ports 0-255
  ack->p[ATTR_DP1][0]=low_port;
  low_port->p[ATTR_DP0][21]=new Rule;  // FTP
  low_port->p[ATTR_DP0][23]=new Rule;  // Telnet
  low_port->p[ATTR_DP0][25]=new Rule;  // SMTP
  low_port->p[ATTR_DP0][80]=new Rule;  // HTTP
}

// Print this rule's counters to f, indented by `level` two-space steps.
// Output: a "trains, tests" header, then one line per attribute with its
// name, t, n/r, and each trained value (hex=count) that is frequent
// (count*10 > trains) or belongs to an attribute with fewer than 10 values.
void IDS::Rule::print(FILE* f, int level) const {
  for (int k=level; k>0; --k)
    fputs("  ", f);
  fprintf(f, "%d trains, %d tests, tn/r:\n", trains, tests);
  for (int attr=0; attr<M; ++attr) {
    for (int k=level; k>0; --k)
      fputs("  ", f);
    fprintf(f, "%s %d %d/%d", label[attr], t[attr], n[attr], r[attr]);
    const int* count=val[attr];
    const bool few_values=(r[attr]<10);
    for (int v=0; v<256; ++v) {
      const bool frequent=(count[v]*10>trains);
      if (frequent || (count[v]!=0 && few_values))
        fprintf(f, " %02X=%d", v, count[v]);
    }
    fputc('\n', f);
  }
}

// Recursively print rules to f for tree rooted at p
void IDS::print(FILE* f, Rule* p, int level) const {
  if (!p)
    return;
  for (int i=0; i<M; ++i) {
    for (int j=0; j<256; ++j) {
      if (p->p[i][j]) {
        for (int k=0; k<level; ++k)
          fprintf(f, "  ");
        fprintf(f, "if %s=%02X then\n", label[i], j);
        print(f, p->p[i][j], level+1);
        for (int k=0; k<level; ++k)
          fprintf(f, "  ");
        fprintf(f, "end if %s=%02X\n", label[i], j);
      }
    }
  }
  p->print(f, level);
}

// Write the learned rules to rules.txt, then free every Rule in the tree.
IDS::~IDS() {

  // Print rules to file rules.txt.  If the file can't be opened, skip
  // printing but still release the tree below.
  FILE* f=fopen("rules.txt", "w");
  if (f) {
    print(f, tree, 0);
    fclose(f);
  }

  // Free the tree with an explicit stack, so no extra member function is
  // needed.  Each node is linked under exactly one parent (see the
  // constructor), so each is deleted once.
  std::vector<Rule*> pending(1, tree);
  while (!pending.empty()) {
    Rule* node=pending.back();
    pending.pop_back();
    if (!node)
      continue;
    for (int i=0; i<M; ++i)
      for (int j=0; j<256; ++j)
        if (node->p[i][j])
          pending.push_back(node->p[i][j]);
    delete node;
  }
  tree=0;
}

// Train (h=0) or test (h=1) one rule on tuple[M], add to anomaly scores[M]
// at time now
//
// For each attribute i with value v=tuple[i], up to two terms are added to
// scores[i]:
//  1. If v was never seen in training (val[i][v]==0), and r[i]>0 and
//     trains>0:  (now-t[i]) * n[i]/r[i] * (256-r[i])/256
//     where t[i] is the time of the previous novel value, n[i] the training
//     tuples since then, and r[i] the number of distinct training values.
//  2. If trains>0:  (now-last[i][v]) / (val[i][v] + r[i]/256)
//     i.e. the time since v last occurred, divided by roughly its training
//     count.
// Only training calls update n, r and val.  t (on a novel value) and last
// are updated in both modes.
// Divisors are safe: the first training call finds every val[i][*]==0 and
// so increments every r[i], hence r[i]>=1 whenever trains>0.
void IDS::Rule::eval(const U8* tuple, int h, double* scores, int now) {
  if (h)
    ++tests;
  else
    ++trains;  // Incremented before the loop: trains>0 inside it when training
  for (int i=0; i<M; ++i) {
    if (!h)
      ++n[i];
    if (val[i][tuple[i]]==0) {  // Anomaly
      if (r[i]>0 && trains>0)
        scores[i]+=double(now-t[i])*n[i]/r[i]*(256-r[i])/256.0;
      if (!h) {
        ++r[i];  // One more distinct value allowed by training
        n[i]=0;  // Restart the count of tuples since the last novel value
      }
      t[i]=now;  // Updated in testing too
    }
    // The next statement is the body of `if (trains>0)`: an untrained rule
    // contributes nothing here.
    if (trains>0)
    scores[i]+=double(now-last[i][tuple[i]])/(val[i][tuple[i]]+r[i]/256.0);
    if (!h)
      ++val[i][tuple[i]];
    last[i][tuple[i]]=now;
  }
}

// Recursively train/test rule tree rooted at p
void IDS::eval(const U8* tuple, int h, double* scores, int now, Rule* p) {
  if (!p)
    return;
  int count=0;
  for (int i=0; i<M; ++i) {
    if (p->p[i][tuple[i]]) {
      eval(tuple, h, scores, now, p->p[i][tuple[i]]);
      ++count;
    }
  }
  if (!count)
    p->eval(tuple, h, scores, now);
}

// Train (h=0) or test (h=1) the whole rule tree on tuple[M] at time now.
// scores[M] is overwritten with the per-attribute anomaly scores; the return
// value is their sum.
double IDS::eval(const U8* tuple, int h, double* scores, int now) {
  std::fill(scores, scores+M, 0.0);  // Every attribute starts at zero
  eval(tuple, h, scores, now, tree);  // Accumulate from the matching rules

  double total=0;
  for (const double* s=scores; s<scores+M; ++s)
    total+=*s;
  return total;
}

// Functor that builds a 32-bit value from 4 bytes, either most-significant
// byte first (msb_first, the default) or least-significant byte first.
// Used to read the tcpdump headers in whichever byte order the file uses.
class I4 {
public:
  bool msb_first;
  I4(): msb_first(true) {}
  U32 operator()(U8* p) const {
    U32 x=0;
    for (int k=0; k<4; ++k)
      x=(x<<8)|p[msb_first ? k : 3-k];
    return x;
  }
} i4;  // Shared instance; main() sets the byte order for each input file

// Format UTC time t (seconds since the epoch) as Eastern time,
// "MM/DD/YYYY HH:MM:SS" (for Mar/Apr 1999 only).
// The EST/EDT offset is applied by hand, so the shifted value is broken down
// with gmtime().  localtime() would apply the host's time zone a second time.
// Returns a pointer to a static buffer that the next call overwrites.
// If the time can't be converted, a same-width placeholder is returned so
// the output columns stay aligned.
const char* print_time(U32 t) {
  time_t ts = (time_t) t;  // Whole seconds
  ts-=18000;  // Convert to EST (UTC-5h)
  // NOTE(review): 923205600 is 1999-04-04 06:00 on the EST-shifted clock,
  // not 02:00 as the comment below says; confirm the intended threshold.
  if (ts>=923205600)
    ts+=3600;  // Convert to EDT after 0200 4/4/1999
  static char buf[50];
  const struct tm *tp = gmtime(&ts);
  if (!tp || strftime(buf, sizeof buf, "%m/%d/%Y %H:%M:%S", tp)==0)
    strcpy(buf, "--/--/---- --:--:--");
  return buf;
}

// Entry point.  Reads each tcpdump (libpcap) file named on the command line,
// in order, and feeds every packet to the model.
// - Packets with timestamp >= 0x36fc0000 are tested (h=1); earlier packets
//   train the model (h=0).
// - For each test packet scoring above 1 and within a factor of 10^6 of the
//   highest score seen so far, one alarm line is printed to stdout.
// - A summary count goes to stderr.
// - Files that can't be opened or parsed are skipped, and are always closed.
int main(int argc, char** argv) {

  // Check program args
  if (argc<2) {
    fprintf(stderr, "Usage: netad tcpdump_files... |sort /+46 /r\n");
    return 1;
  }

  IDS ids;  // Model
  int anomalies=0;  // Number of anomalies output
  double highscore=0;  // Highest anomaly score
  int now=0;  // Number of tuples

  // Open each file, skip if not found
  for (int i=1; i<argc; ++i) {
    FILE* in=fopen(argv[i], "rb");
    if (!in) {
      perror(argv[i]);
      continue;
    }

    // Read TCPDUMP header, skip if bad.  Every early exit closes the file,
    // so a long list of bad files doesn't leak descriptors.
    const int MAX_BUF=1532;
    static U8 buf[MAX_BUF];  // Input buffer
    if (fread(buf, 1, 24, in)!=24) {
      fprintf(stderr, "%s: file too small\n", argv[i]);
      fclose(in);
      continue;
    }

    // Determine if input is MSB or LSB first from the magic number
    i4.msb_first=true;
    if (i4(buf)!=0xa1b2c3d4)
      i4.msb_first=false;
    if (i4(buf)!=0xa1b2c3d4) {
      fprintf(stderr, "%s: not in tcpdump format\n", argv[i]);
      fclose(in);
      continue;
    }

    // Read packets: a 16-byte record header, then len1 captured bytes
    while (true) {
      if (fread(buf, 1, 16, in)!=16)
        break;  // EOF
      U32 seconds=i4(buf);  // Time in seconds
      U32 len1=i4(buf+8);  // Captured packet length
      U32 len2=i4(buf+12); // Original packet length
      if (len1>len2 || len2>MAX_BUF-16) {
        fprintf(stderr, "%s: corrupted: len1=%lu len2=%lu at %ld\n",
          argv[i], len1, len2, ftell(in));
        break;
      }
      if (fread(buf+16, 1, len1, in)!=len1)
        break;  // EOF

      // Zero everything past the IP datagram (or past the captured bytes,
      // if fewer), so stale bytes from earlier packets never reach the
      // tuple.  buf[32..33] is the IP total-length field.
      int len=min(U32(buf[32]*256+buf[33]+14), len1); // min IP, Ether length
      if (MAX_BUF-len>16)
        memset(buf+len+16, 0, MAX_BUF-len-16);

      // Remove artifacts
      buf[38]=0;  // TTL

      // Evaluate packet and print anomaly.  The tuple starts at the IP
      // header: 16-byte record header + 14-byte Ethernet header.
      const U8* tuple=buf+30;
      static double scores[M];
      double score=ids.eval(tuple, (seconds>=0x36fc0000), scores, ++now);
      if (seconds>=0x36fc0000 && score>1 && score*1000000>highscore) {
        if (score>highscore)
          highscore=score;
        const static double DECABEL=0.1/log(10);
        // Alarm line: time, destination IP (tuple bytes 16-19), score, then
        // every attribute contributing more than 10% of the score.
        printf("       0 %s %03d.%03d.%03d.%03d %8.6f #",
            print_time(seconds), buf[46], buf[47], buf[48], buf[49],
            log(score)*DECABEL);
        for (int a=0; a<M; ++a) {
          if (scores[a]*10>score)
            printf(" %s=%02X,%1.0f%%", label[a], tuple[a],
                scores[a]*100/score);
        }
        printf("\n");
        ++anomalies;
      }
    }
    fclose(in);
  }
  fprintf(stderr, "%d anomalies out of %d\n", anomalies, now);
  return 0;
}

