xref: /dragonfly/tools/tools/toeplitz/toeplitz.c (revision a705726d80daa5154a419cdfee23366c9871d064)
1 #include <sys/types.h>
2 #include <sys/socket.h>
3 
4 #include <arpa/inet.h>
5 #include <netinet/in.h>
6 
7 #include <stdint.h>
8 #include <stdio.h>
9 #include <stdlib.h>
10 #include <string.h>
11 #include <unistd.h>
12 
13 #ifndef NBBY
14 #define NBBY        __NBBY
15 #endif
16 
17 #define HASHMASK    0x7f
18 
19 #define KEYLEN                40
20 #define HASHLEN               12
21 
22 static uint8_t      toeplitz_key[KEYLEN];
23 static uint32_t     hash_table[HASHLEN][256];
24 
25 static void         toeplitz_init(uint32_t[][256], int, const uint8_t[], int);
26 static void         getaddrport(char *, uint32_t *, uint16_t *);
27 
28 static void
usage(const char * cmd)29 usage(const char *cmd)
30 {
31           fprintf(stderr, "%s [-s s1_hex [-s s2_hex]] [-p] [-m mask] [-d div] "
32               "addr1.port1 addr2.port2\n", cmd);
33           exit(1);
34 }
35 
36 int
main(int argc,char * argv[])37 main(int argc, char *argv[])
38 {
39           uint32_t saddr, daddr;
40           uint16_t sport, dport;
41           uint32_t res, mask, divisor;
42 
43           const char *cmd = argv[0];
44           uint8_t seeds[2] = { 0x6d, 0x5a };
45           int i, opt, use_port;
46 
47           i = 0;
48           use_port = 0;
49           mask = 0xffffffff;
50           divisor = 0;
51 
52           while ((opt = getopt(argc, argv, "d:m:ps:")) != -1) {
53                     switch (opt) {
54                     case 'd':
55                               divisor = strtoul(optarg, NULL, 10);
56                               break;
57 
58                     case 'm':
59                               mask = strtoul(optarg, NULL, 16);
60                               break;
61 
62                     case 'p':
63                               use_port = 1;
64                               break;
65 
66                     case 's':
67                               if (i >= 2)
68                                         usage(cmd);
69                               seeds[i++] = strtoul(optarg, NULL, 16);
70                               break;
71 
72                     default:
73                               usage(cmd);
74                     }
75           }
76           argc -= optind;
77           argv += optind;
78 
79           if (argc != 2)
80                     usage(cmd);
81 
82           for (i = 0; i < KEYLEN; ++i) {
83                     if (i & 1)
84                               toeplitz_key[i] = seeds[1];
85                     else
86                               toeplitz_key[i] = seeds[0];
87           }
88 
89           getaddrport(argv[0], &saddr, &sport);
90           getaddrport(argv[1], &daddr, &dport);
91 
92           toeplitz_init(hash_table, HASHLEN, toeplitz_key, KEYLEN);
93 
94           res =  hash_table[0][(saddr >> 0) & 0xff];
95           res ^= hash_table[1][(saddr >> 8) & 0xff];
96           res ^= hash_table[2][(saddr >> 16)  & 0xff];
97           res ^= hash_table[3][(saddr >> 24)  & 0xff];
98           res ^= hash_table[4][(daddr >> 0) & 0xff];
99           res ^= hash_table[5][(daddr >> 8) & 0xff];
100           res ^= hash_table[6][(daddr >> 16)  & 0xff];
101           res ^= hash_table[7][(daddr >> 24)  & 0xff];
102           if (use_port) {
103                     res ^= hash_table[8][(sport >> 0)  & 0xff];
104                     res ^= hash_table[9][(sport >> 8)  & 0xff];
105                     res ^= hash_table[10][(dport >> 0)  & 0xff];
106                     res ^= hash_table[11][(dport >> 8)  & 0xff];
107           }
108 
109           printf("0x%08x, masked 0x%08x", res, res & mask);
110           if (divisor == 0)
111                     printf("\n");
112           else
113                     printf(", modulo %u\n", (res & HASHMASK) % divisor);
114           exit(0);
115 }
116 
117 static void
toeplitz_init(uint32_t cache[][256],int cache_len,const uint8_t key_str[],int key_strlen)118 toeplitz_init(uint32_t cache[][256], int cache_len,
119     const uint8_t key_str[], int key_strlen)
120 {
121           int i;
122 
123           if (key_strlen < cache_len + (int)sizeof(uint32_t))
124                     exit(1);
125 
126           for (i = 0; i < cache_len; ++i) {
127                     uint32_t key[NBBY];
128                     int j, b, shift, val;
129 
130                     bzero(key, sizeof(key));
131 
132                     /*
133                      * Calculate 32bit keys for one byte; one key for each bit.
134                      */
135                     for (b = 0; b < NBBY; ++b) {
136                               for (j = 0; j < 32; ++j) {
137                                         uint8_t k;
138                                         int bit;
139 
140                                         bit = (i * NBBY) + b + j;
141 
142                                         k = key_str[bit / NBBY];
143                                         shift = NBBY - (bit % NBBY) - 1;
144                                         if (k & (1 << shift))
145                                                   key[b] |= 1 << (31 - j);
146                               }
147                     }
148 
149                     /*
150                      * Cache the results of all possible bit combination of
151                      * one byte.
152                      */
153                     for (val = 0; val < 256; ++val) {
154                               uint32_t res = 0;
155 
156                               for (b = 0; b < NBBY; ++b) {
157                                         shift = NBBY - b - 1;
158                                         if (val & (1 << shift))
159                                                   res ^= key[b];
160                               }
161                               cache[i][val] = res;
162                     }
163           }
164 }
165 
166 static void
getaddrport(char * ap_str,uint32_t * addr,uint16_t * port0)167 getaddrport(char *ap_str, uint32_t *addr, uint16_t *port0)
168 {
169           uint16_t port;
170           char *p;
171 
172           p = strrchr(ap_str, '.');
173           if (p == NULL) {
174                     fprintf(stderr, "invalid addr.port %s\n", ap_str);
175                     exit(1);
176           }
177 
178           *p = '\0';
179           ++p;
180 
181           port = strtoul(p, NULL, 10);
182           *port0 = htons(port);
183 
184           if (inet_pton(AF_INET, ap_str, addr) <= 0) {
185                     fprintf(stderr, "invalid addr %s\n", ap_str);
186                     exit(1);
187           }
188 }
189