1 /*-
2  * Copyright (c) 2016 Mindaugas Rasiukevicius <rmind at noxt eu>
3  * All rights reserved.
4  *
5  * Redistribution and use in source and binary forms, with or without
6  * modification, are permitted provided that the following conditions
7  * are met:
8  * 1. Redistributions of source code must retain the above copyright
9  *    notice, this list of conditions and the following disclaimer.
10  * 2. Redistributions in binary form must reproduce the above copyright
11  *    notice, this list of conditions and the following disclaimer in the
12  *    documentation and/or other materials provided with the distribution.
13  *
14  * THIS SOFTWARE IS PROVIDED BY THE AUTHOR AND CONTRIBUTORS ``AS IS'' AND
15  * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
16  * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
17  * ARE DISCLAIMED.  IN NO EVENT SHALL THE AUTHOR OR CONTRIBUTORS BE LIABLE
18  * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
19  * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS
20  * OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION)
21  * HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
22  * LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY
23  * OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF
24  * SUCH DAMAGE.
25  */
26 
27 /*
28  * Longest Prefix Match (LPM) library supporting IPv4 and IPv6.
29  *
30  * Algorithm:
31  *
 * Each prefix length gets its own hash map and all prefix lengths in
 * use are recorded in a bitmap.  On a lookup, we perform a linear scan
 * of the hash maps from the longest prefix length to the shortest,
 * visiting only the lengths present in the bitmap.  Zero-length
 * (default) prefixes are kept aside, one per address family.  Usually,
 * only a few distinct prefix lengths are in use and such a simple
 * algorithm is very efficient.  With many distinct IPv6 prefix lengths,
 * the linear scan might become a bottleneck.
37  */
38 
39 #if defined(_KERNEL)
40 #include <sys/cdefs.h>
41 __KERNEL_RCSID(0, "$NetBSD: lpm.c,v 1.6 2019/06/12 14:36:32 christos Exp $");
42 
43 #include <sys/param.h>
44 #include <sys/types.h>
45 #include <sys/malloc.h>
46 #include <sys/kmem.h>
47 #else
48 #include <sys/socket.h>
49 #include <arpa/inet.h>
50 
51 #include <stdio.h>
52 #include <stdlib.h>
53 #include <stdbool.h>
54 #include <stddef.h>
55 #include <string.h>
56 #include <strings.h>
57 #include <errno.h>
58 #include <assert.h>
59 #define kmem_alloc(a, b) malloc(a)
60 #define kmem_free(a, b) free(a)
61 #define kmem_zalloc(a, b) calloc(a, 1)
62 #endif
63 
64 #include "lpm.h"
65 
/* Longest supported prefix length in bits (a full IPv6 address). */
#define	LPM_MAX_PREFIX		(128)
/* Number of 32-bit words needed to hold the longest address. */
#define	LPM_MAX_WORDS		(LPM_MAX_PREFIX >> 5)
/* Convert an address length in bytes into a count of 32-bit words. */
#define	LPM_TO_WORDS(x)		((x) >> 2)
/* Bucket headroom requested above the item count when a map must grow. */
#define	LPM_HASH_STEP		(8)
/* Default-value slot for an address length: 4 bytes -> 0, 16 bytes -> 1. */
#define	LPM_LEN_IDX(len)	((len) >> 4)

/*
 * ASSERT() maps to assert() only in DEBUG builds; otherwise it expands
 * to nothing and its argument is not evaluated.
 */
#ifdef DEBUG
#define	ASSERT			assert
#else
#define	ASSERT(x)
#endif
77 
/*
 * Hash map entry holding one stored prefix.  The key (the masked
 * address bytes) is stored inline after the header.
 */
typedef struct lpm_ent {
	struct lpm_ent *next;	/* next entry in the same bucket chain */
	void *		val;	/* caller-supplied value for this prefix */
	unsigned	len;	/* key length in bytes */
	uint8_t		key[];	/* masked address bytes */
} lpm_ent_t;

/*
 * Chained hash map for a single prefix length.  hashsize is either 0
 * (no bucket array allocated) or a power of two, which lets a bucket
 * index be computed by masking the hash with (hashsize - 1).
 */
typedef struct {
	unsigned	hashsize;	/* number of buckets */
	unsigned	nitems;		/* entry count used to size the table */
	lpm_ent_t **	bucket;		/* bucket array, NULL when hashsize is 0 */
} lpm_hmap_t;

/*
 * The LPM table.  prefix[] is indexed by prefix length; IPv4 and IPv6
 * keys of the same length share one map and are told apart by key
 * length.  bitmask[] marks prefix lengths that may hold entries: length
 * L (1..LPM_MAX_PREFIX) is bit 0x80000000 >> ((L - 1) & 31) of word
 * (L - 1) >> 5.  Zero-length (default) values live in defvals[],
 * slot 0 for IPv4 and slot 1 for IPv6.
 */
struct lpm {
	uint32_t	bitmask[LPM_MAX_WORDS];
	int		flags;		/* allocation flags passed to kmem_*() */
	void *		defvals[2];
	lpm_hmap_t	prefix[LPM_MAX_PREFIX + 1];
};

/* All-zero address handed to the destructor for the default values. */
static const uint32_t zero_address[LPM_MAX_WORDS];
99 
100 lpm_t *
lpm_create(int flags)101 lpm_create(int flags)
102 {
103           lpm_t *lpm = kmem_zalloc(sizeof(*lpm), KM_SLEEP);
104           lpm->flags = flags;
105           return lpm;
106 }
107 
108 void
lpm_clear(lpm_t * lpm,lpm_dtor_t dtor,void * arg)109 lpm_clear(lpm_t *lpm, lpm_dtor_t dtor, void *arg)
110 {
111           for (unsigned n = 0; n <= LPM_MAX_PREFIX; n++) {
112                     lpm_hmap_t *hmap = &lpm->prefix[n];
113 
114                     if (!hmap->hashsize) {
115                               KASSERT(!hmap->bucket);
116                               continue;
117                     }
118                     for (unsigned i = 0; i < hmap->hashsize; i++) {
119                               lpm_ent_t *entry = hmap->bucket[i];
120 
121                               while (entry) {
122                                         lpm_ent_t *next = entry->next;
123 
124                                         if (dtor) {
125                                                   dtor(arg, entry->key,
126                                                       entry->len, entry->val);
127                                         }
128                                         kmem_free(entry,
129                                             offsetof(lpm_ent_t, key[entry->len]));
130                                         entry = next;
131                               }
132                     }
133                     kmem_free(hmap->bucket, hmap->hashsize * sizeof(lpm_ent_t *));
134                     hmap->bucket = NULL;
135                     hmap->hashsize = 0;
136                     hmap->nitems = 0;
137           }
138           if (dtor) {
139                     dtor(arg, zero_address, 4, lpm->defvals[0]);
140                     dtor(arg, zero_address, 16, lpm->defvals[1]);
141           }
142           memset(lpm->bitmask, 0, sizeof(lpm->bitmask));
143           memset(lpm->defvals, 0, sizeof(lpm->defvals));
144 }
145 
146 void
lpm_destroy(lpm_t * lpm)147 lpm_destroy(lpm_t *lpm)
148 {
149           lpm_clear(lpm, NULL, NULL);
150           kmem_free(lpm, sizeof(*lpm));
151 }
152 
/*
 * fnv1a_hash: 32-bit Fowler-Noll-Vo hash, FNV-1a variant (XOR each
 * byte into the state, then multiply by the FNV prime).
 */
static uint32_t
fnv1a_hash(const void *buf, size_t len)
{
	const uint8_t *cur = buf;
	const uint8_t *const end = cur + len;
	uint32_t h = 2166136261U;		/* 32-bit offset basis */

	for (; cur != end; cur++)
		h = (h ^ *cur) * 16777619U;	/* 32-bit FNV prime */
	return h;
}
168 
169 static bool
hashmap_rehash(lpm_hmap_t * hmap,unsigned size,int flags)170 hashmap_rehash(lpm_hmap_t *hmap, unsigned size, int flags)
171 {
172           lpm_ent_t **bucket;
173           unsigned hashsize;
174 
175           for (hashsize = 1; hashsize < size; hashsize <<= 1) {
176                     continue;
177           }
178           bucket = kmem_zalloc(hashsize * sizeof(lpm_ent_t *), flags);
179           if (bucket == NULL)
180                     return false;
181           for (unsigned n = 0; n < hmap->hashsize; n++) {
182                     lpm_ent_t *list = hmap->bucket[n];
183 
184                     while (list) {
185                               lpm_ent_t *entry = list;
186                               uint32_t hash = fnv1a_hash(entry->key, entry->len);
187                               const unsigned i = hash & (hashsize - 1);
188 
189                               list = entry->next;
190                               entry->next = bucket[i];
191                               bucket[i] = entry;
192                     }
193           }
194           if (hmap->bucket)
195                     kmem_free(hmap->bucket, hmap->hashsize * sizeof(lpm_ent_t *));
196           hmap->bucket = bucket;
197           hmap->hashsize = hashsize;
198           return true;
199 }
200 
201 static lpm_ent_t *
hashmap_insert(lpm_hmap_t * hmap,const void * key,size_t len,int flags)202 hashmap_insert(lpm_hmap_t *hmap, const void *key, size_t len, int flags)
203 {
204           const unsigned target = hmap->nitems + LPM_HASH_STEP;
205           const size_t entlen = offsetof(lpm_ent_t, key[len]);
206           uint32_t hash, i;
207           lpm_ent_t *entry;
208 
209           if (hmap->hashsize < target && !hashmap_rehash(hmap, target, flags)) {
210                     return NULL;
211           }
212 
213           hash = fnv1a_hash(key, len);
214           i = hash & (hmap->hashsize - 1);
215           entry = hmap->bucket[i];
216           while (entry) {
217                     if (entry->len == len && memcmp(entry->key, key, len) == 0) {
218                               return entry;
219                     }
220                     entry = entry->next;
221           }
222 
223           if ((entry = kmem_alloc(entlen, flags)) != NULL) {
224                     memcpy(entry->key, key, len);
225                     entry->next = hmap->bucket[i];
226                     entry->len = len;
227 
228                     hmap->bucket[i] = entry;
229                     hmap->nitems++;
230           }
231           return entry;
232 }
233 
234 static lpm_ent_t *
hashmap_lookup(lpm_hmap_t * hmap,const void * key,size_t len)235 hashmap_lookup(lpm_hmap_t *hmap, const void *key, size_t len)
236 {
237           const uint32_t hash = fnv1a_hash(key, len);
238           const unsigned i = hash & (hmap->hashsize - 1);
239           lpm_ent_t *entry;
240 
241           if (hmap->hashsize == 0) {
242                     return NULL;
243           }
244           entry = hmap->bucket[i];
245 
246           while (entry) {
247                     if (entry->len == len && memcmp(entry->key, key, len) == 0) {
248                               return entry;
249                     }
250                     entry = entry->next;
251           }
252           return NULL;
253 }
254 
255 static int
hashmap_remove(lpm_hmap_t * hmap,const void * key,size_t len)256 hashmap_remove(lpm_hmap_t *hmap, const void *key, size_t len)
257 {
258           const uint32_t hash = fnv1a_hash(key, len);
259           const unsigned i = hash & (hmap->hashsize - 1);
260           lpm_ent_t *prev = NULL, *entry;
261 
262           if (hmap->hashsize == 0) {
263                     return -1;
264           }
265           entry = hmap->bucket[i];
266 
267           while (entry) {
268                     if (entry->len == len && memcmp(entry->key, key, len) == 0) {
269                               if (prev) {
270                                         prev->next = entry->next;
271                               } else {
272                                         hmap->bucket[i] = entry->next;
273                               }
274                               kmem_free(entry, offsetof(lpm_ent_t, key[len]));
275                               return 0;
276                     }
277                     prev = entry;
278                     entry = entry->next;
279           }
280           return -1;
281 }
282 
/*
 * compute_prefix: mask an address down to its first `preflen` bits.
 *
 * => addr and prefix are arrays of nwords 32-bit words holding the
 *    address in network byte order; every word of prefix is written,
 *    with all bits past preflen cleared.
 * => addr need not be 4-byte aligned; a misaligned address is first
 *    copied into a local buffer.
 */
static inline void
compute_prefix(const unsigned nwords, const uint32_t *addr,
    unsigned preflen, uint32_t *prefix)
{
	uint32_t aligned[4];
	unsigned w;

	if (((uintptr_t)addr & 3) != 0) {
		memcpy(aligned, addr, nwords * 4);
		addr = aligned;
	}

	/* Words lying entirely inside the prefix are copied verbatim. */
	for (w = 0; w < nwords && preflen >= 32; w++, preflen -= 32)
		prefix[w] = addr[w];

	/* A partially covered word keeps only its leading preflen bits. */
	if (w < nwords && preflen > 0) {
		prefix[w] = addr[w] & htonl(0xffffffffU << (32 - preflen));
		w++;
	}

	/* Everything beyond the prefix is zero. */
	for (; w < nwords; w++)
		prefix[w] = 0;
}
313 
314 /*
315  * lpm_insert: insert the CIDR into the LPM table.
316  *
317  * => Returns zero on success and -1 on failure.
318  */
319 int
lpm_insert(lpm_t * lpm,const void * addr,size_t len,unsigned preflen,void * val)320 lpm_insert(lpm_t *lpm, const void *addr,
321     size_t len, unsigned preflen, void *val)
322 {
323           const unsigned nwords = LPM_TO_WORDS(len);
324           uint32_t prefix[LPM_MAX_WORDS];
325           lpm_ent_t *entry;
326           KASSERT(len == 4 || len == 16);
327 
328           if (preflen == 0) {
329                     /* 0-length prefix is a special case. */
330                     lpm->defvals[LPM_LEN_IDX(len)] = val;
331                     return 0;
332           }
333           compute_prefix(nwords, addr, preflen, prefix);
334           entry = hashmap_insert(&lpm->prefix[preflen], prefix, len, lpm->flags);
335           if (entry) {
336                     const unsigned n = --preflen >> 5;
337                     lpm->bitmask[n] |= 0x80000000U >> (preflen & 31);
338                     entry->val = val;
339                     return 0;
340           }
341           return -1;
342 }
343 
344 /*
345  * lpm_remove: remove the specified prefix.
346  */
347 int
lpm_remove(lpm_t * lpm,const void * addr,size_t len,unsigned preflen)348 lpm_remove(lpm_t *lpm, const void *addr, size_t len, unsigned preflen)
349 {
350           const unsigned nwords = LPM_TO_WORDS(len);
351           uint32_t prefix[LPM_MAX_WORDS];
352           KASSERT(len == 4 || len == 16);
353 
354           if (preflen == 0) {
355                     lpm->defvals[LPM_LEN_IDX(len)] = NULL;
356                     return 0;
357           }
358           compute_prefix(nwords, addr, preflen, prefix);
359           return hashmap_remove(&lpm->prefix[preflen], prefix, len);
360 }
361 
362 /*
363  * lpm_lookup: find the longest matching prefix given the IP address.
364  *
365  * => Returns the associated value on success or NULL on failure.
366  */
367 void *
lpm_lookup(lpm_t * lpm,const void * addr,size_t len)368 lpm_lookup(lpm_t *lpm, const void *addr, size_t len)
369 {
370           const unsigned nwords = LPM_TO_WORDS(len);
371           unsigned i, n = nwords;
372           uint32_t prefix[LPM_MAX_WORDS];
373 
374           while (n--) {
375                     uint32_t bitmask = lpm->bitmask[n];
376 
377                     while ((i = ffs(bitmask)) != 0) {
378                               const unsigned preflen = (32 * n) + (32 - --i);
379                               lpm_hmap_t *hmap = &lpm->prefix[preflen];
380                               lpm_ent_t *entry;
381 
382                               compute_prefix(nwords, addr, preflen, prefix);
383                               entry = hashmap_lookup(hmap, prefix, len);
384                               if (entry) {
385                                         return entry->val;
386                               }
387                               bitmask &= ~(1U << i);
388                     }
389           }
390           return lpm->defvals[LPM_LEN_IDX(len)];
391 }
392 
393 /*
394  * lpm_lookup_prefix: return the value associated with a prefix
395  *
396  * => Returns the associated value on success or NULL on failure.
397  */
398 void *
lpm_lookup_prefix(lpm_t * lpm,const void * addr,size_t len,unsigned preflen)399 lpm_lookup_prefix(lpm_t *lpm, const void *addr, size_t len, unsigned preflen)
400 {
401           const unsigned nwords = LPM_TO_WORDS(len);
402           uint32_t prefix[LPM_MAX_WORDS];
403           lpm_ent_t *entry;
404           KASSERT(len == 4 || len == 16);
405 
406           if (preflen == 0) {
407                     return lpm->defvals[LPM_LEN_IDX(len)];
408           }
409           compute_prefix(nwords, addr, preflen, prefix);
410           entry = hashmap_lookup(&lpm->prefix[preflen], prefix, len);
411           if (entry) {
412                     return entry->val;
413           }
414           return NULL;
415 }
416 
417 #if !defined(_KERNEL)
/*
 * lpm_parse_preflen: parse a decimal prefix length of at most maxlen.
 *
 * => Accepts only a non-empty run of ASCII digits (no sign, no spaces,
 *    no trailing characters).
 * => Returns 0 and stores the value on success, -1 otherwise.
 */
static int
lpm_parse_preflen(const char *s, unsigned maxlen, unsigned *preflen)
{
	unsigned val = 0;

	if (*s == '\0') {
		return -1;
	}
	for (; *s != '\0'; s++) {
		if (*s < '0' || *s > '9') {
			return -1;
		}
		val = val * 10 + (unsigned)(*s - '0');
		if (val > maxlen) {
			/* Also keeps val small, so it can never overflow. */
			return -1;
		}
	}
	*preflen = val;
	return 0;
}

/*
 * lpm_strtobin: convert CIDR string to the binary IP address and mask.
 *
 * => Accepts "address" or "address/preflen".  Without an explicit prefix
 *    length, the full address length (32 or 128 bits) is used.
 * => addr must have room for 16 bytes (an IPv6 address).
 * => The address will be in the network byte order.
 * => Returns 0 on success or -1 on failure: an unparsable address, an
 *    empty, non-numeric or out-of-range prefix length, or an input too
 *    long to be a valid CIDR.  On failure, the outputs may have been
 *    modified.
 */
int
lpm_strtobin(const char *cidr, void *addr, size_t *len, unsigned *preflen)
{
	/* Longest address text plus "/128"; INET6_ADDRSTRLEN counts the NUL. */
	char buf[INET6_ADDRSTRLEN + 4];
	const size_t cidrlen = strlen(cidr);
	const char *plen = NULL;
	char *p;

	/*
	 * Reject over-long input instead of truncating it: truncation could
	 * silently turn e.g. "/128" into "/1".
	 */
	if (cidrlen >= sizeof(buf)) {
		return -1;
	}
	memcpy(buf, cidr, cidrlen + 1);

	if ((p = strchr(buf, '/')) != NULL) {
		*p = '\0';
		plen = p + 1;
	}

	if (inet_pton(AF_INET6, buf, addr) == 1) {
		*len = 16;
	} else if (inet_pton(AF_INET, buf, addr) == 1) {
		*len = 4;
	} else {
		return -1;
	}

	if (plen == NULL) {
		*preflen = (unsigned)(*len * 8);
		return 0;
	}
	return lpm_parse_preflen(plen, (unsigned)(*len * 8), preflen);
}
453 #endif
454