diff -Nwbaur paq-8f.dist/Makefile paq-8f/Makefile --- paq-8f.dist/Makefile 2006-11-23 05:27:45.000000000 -0600 +++ paq-8f/Makefile 2006-12-15 12:00:35.000000000 -0600 @@ -84,9 +84,10 @@ OBJS = $(SRCS:.c=.o) EXE = $(PACKAGE) -ASM = nasm -ASM_FLAGS = -f elf -ASM_SRCS = $(PACKAGE)7asm.asm +ASM = yasm +#ASM_FLAGS = -g stabs -f elf -m amd64 +ASM_FLAGS = -f elf -m amd64 +ASM_SRCS = $(PACKAGE)7asm-x86_64.asm ASM_OBJS = $(ASM_SRCS:.asm=.o) diff -Nwbaur paq-8f.dist/paq7asm-x86_64.asm paq-8f/paq7asm-x86_64.asm --- paq-8f.dist/paq7asm-x86_64.asm 1969-12-31 18:00:00.000000000 -0600 +++ paq-8f/paq7asm-x86_64.asm 2006-12-15 11:54:23.000000000 -0600 @@ -0,0 +1,98 @@ +; YASM assembly language code for PAQ7. +; (C) 2005, Matt Mahoney. +; This is free software under GPL, http://www.gnu.org/licenses/gpl.txt +; +; MINGW g++: nasm paq7asm.asm -f win32 --prefix _ +; DJGPP g++: nasm paq7asm.asm -f coff --prefix _ +; Borland, Mars: nasm paq7asm.asm -f obj --prefix _ +; Linux x86-64: yasm paq7asm-x86_64.asm -f elf -m amd64 +; +; For other Windows compilers try -f win32 or -f obj. Some old versions +; of Linux should use -f aout instead of -f elf. +; +; This x86-64 version uses SSE2 instructions on the xmm registers, which +; every x86-64 processor supports. + +section .text + +BITS 64 + +; Vector product a*b of n signed words, returning signed dword scaled +; down by 8 bits. n is rounded up to a multiple of 8. 
+; a and b must be 16-byte aligned (movdqa). Returns 0 when n is 0.
+; n is a 32-bit int, so only edx is read: the upper half of rdx is
+; unspecified by the SysV AMD64 calling convention.
+
+;1st arg rdi -> *a
+;2nd arg rsi -> *b
+;3rd arg edx -> n
+
+  global dot_product_x86_64   ; int (short* a, short* b, int n)
+  align 16
+dot_product_x86_64:
+  xor eax, eax                ; result when n == 0 (rax no longer holds a pointer)
+  mov ecx, edx                ; n, zero-extended into rcx
+  add ecx, 7                  ; n rounding up
+  and ecx, -8
+  jz .done
+  sub rdi, 16                 ; bias pointers so that index rcx = 8 addresses element 0
+  sub rsi, 16
+  pxor xmm0, xmm0             ; sum = 0
+.loop:                        ; each loop sums 8 products
+  movdqa xmm1, [rdi+rcx*2]    ; put partial sums of vector product in xmm1
+  pmaddwd xmm1, [rsi+rcx*2]
+  psrad xmm1, 8
+  paddd xmm0, xmm1
+  sub rcx, 8
+  ja .loop
+  movdqa xmm1, xmm0           ; add 4 parts of xmm0 and return in eax
+  psrldq xmm1, 8
+  paddd xmm0, xmm1
+  movdqa xmm1, xmm0
+  psrldq xmm1, 4
+  paddd xmm0, xmm1
+  movd eax, xmm0
+.done:
+  ret
+
+; Train n neural network weights w[n] on inputs t[n] and err.
+; w[i] += t[i]*err*2+1 >> 17 bounded to +- 32K.
+; n is rounded up to a multiple of 8.
+; t and w must be 16-byte aligned (movdqa), as for dot_product_x86_64.
+; The index steps by 8 words per pass, so full 128-bit loads and stores
+; are needed to update all 8 weights of each group. The n == 0 exit has
+; to follow the AND directly so that it tests the count, not a pointer.
+
+;1st arg rdi -> *t
+;2nd arg rsi -> *w
+;3rd arg edx -> n
+;4th arg ecx -> err (only the low 16 bits are used)
+
+  global train_x86_64         ; void (short* t, short* w, int n, int err)
+  align 16
+train_x86_64:
+  movd xmm0, ecx              ; err
+  pshuflw xmm0, xmm0, 0       ; put 8 copies of err in xmm0
+  pshufd xmm0, xmm0, 0
+  pcmpeqb xmm1, xmm1          ; 8 copies of 1 in xmm1
+  psrlw xmm1, 15
+  mov ecx, edx                ; n, zero-extended into rcx
+  add ecx, 7                  ; n rounding up
+  and ecx, -8
+  jz .done                    ; nothing to train: do not touch memory
+  sub rdi, 16                 ; bias pointers so that index rcx = 8 addresses element 0
+  sub rsi, 16
+.loop:                        ; each iteration adjusts 8 weights
+  movdqa xmm2, [rsi+rcx*2]    ; w[i]
+  movdqa xmm3, [rdi+rcx*2]    ; t[i]
+  paddsw xmm3, xmm3           ; t[i]*2
+  pmulhw xmm3, xmm0           ; t[i]*err
+  paddsw xmm3, xmm1
+  psraw xmm3, 1
+  paddsw xmm2, xmm3
+  movdqa [rsi+rcx*2], xmm2
+  sub rcx, 8
+  ja .loop
+.done:
+  ret
+
diff -Nwbaur paq-8f.dist/paq8f.cpp paq-8f/paq8f.cpp --- paq-8f.dist/paq8f.cpp 2006-11-22 00:16:14.000000000 -0600 +++ paq-8f/paq8f.cpp 2006-12-15 11:56:29.000000000 -0600 @@ -542,6 +542,9 @@ #define DEFAULT_OPTION 5 #endif + + + // 8, 16, 32 bit unsigned types (adjust as appropriate) typedef unsigned char U8; typedef unsigned short 
U16; @@ -1095,6 +1098,8 @@ sum+=(t[i]*w[i]+t[i+1]*w[i+1]) >> 8; return sum; } +#elif __x86_64 +extern "C" int dot_product_x86_64(short *t, short *w, int n); // in NASM #else // The NASM version uses MMX and is about 8 times faster. extern "C" int dot_product(short *t, short *w, int n); // in NASM #endif @@ -1113,6 +1118,8 @@ w[i]=wt; } } +#elif __x86_64 +extern "C" void train_x86_64(short *t, short *w, int n, int err); // in NASM #else extern "C" void train(short *t, short *w, int n, int err); // in NASM #endif @@ -1135,7 +1142,13 @@ for (int i=0; i=-32768 && err<32768); +#ifdef NOASM train(&tx[0], &wx[cxt[i]*N], nx, err); +#elif __x86_64 + train_x86_64(&tx[0], &wx[cxt[i]*N], nx, err); +#else + train(&tx[0], &wx[cxt[i]*N], nx, err); +#endif } nx=base=ncxt=0; } @@ -1162,14 +1175,27 @@ if (mp) { // combine outputs mp->update(); for (int i=0; i>5); +#elif __x86_64 + pr[i]=squash(dot_product_x86_64(&tx[0], &wx[cxt[i]*N], nx)>>5); +#else + pr[i]=squash(dot_product(&tx[0], &wx[cxt[i]*N], nx)>>5); +#endif mp->add(stretch(pr[i])); } mp->set(0, 1); return mp->p(); } else { // S=1 context +#ifdef NOASM // no assembly language return pr[0]=squash(dot_product(&tx[0], &wx[0], nx)>>8); +#elif __x86_64 + return pr[0]=squash(dot_product_x86_64(&tx[0], &wx[0], nx)>>8); +#else + return pr[0]=squash(dot_product(&tx[0], &wx[0], nx)>>8); +#endif } } ~Mixer();