1 /*        $NetBSD: bpfjit.c,v 1.48 2020/02/01 02:54:02 riastradh Exp $          */
2 
3 /*-
4  * Copyright (c) 2011-2015 Alexander Nasonov.
5  * All rights reserved.
6  *
7  * Redistribution and use in source and binary forms, with or without
8  * modification, are permitted provided that the following conditions
9  * are met:
10  *
11  * 1. Redistributions of source code must retain the above copyright
12  *    notice, this list of conditions and the following disclaimer.
13  * 2. Redistributions in binary form must reproduce the above copyright
14  *    notice, this list of conditions and the following disclaimer in
15  *    the documentation and/or other materials provided with the
16  *    distribution.
17  *
18  * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
19  * ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
20  * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS
21  * FOR A PARTICULAR PURPOSE ARE DISCLAIMED.  IN NO EVENT SHALL THE
22  * COPYRIGHT HOLDERS OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT,
23  * INCIDENTAL, SPECIAL, EXEMPLARY OR CONSEQUENTIAL DAMAGES (INCLUDING,
24  * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
25  * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED
26  * AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
27  * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT
28  * OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF
29  * SUCH DAMAGE.
30  */
31 
32 #include <sys/cdefs.h>
33 #ifdef _KERNEL
34 __KERNEL_RCSID(0, "$NetBSD: bpfjit.c,v 1.48 2020/02/01 02:54:02 riastradh Exp $");
35 #else
36 __RCSID("$NetBSD: bpfjit.c,v 1.48 2020/02/01 02:54:02 riastradh Exp $");
37 #endif
38 
39 #include <sys/types.h>
40 #include <sys/queue.h>
41 
42 #ifndef _KERNEL
43 #include <assert.h>
44 #define BJ_ASSERT(c) assert(c)
45 #else
46 #define BJ_ASSERT(c) KASSERT(c)
47 #endif
48 
49 #ifndef _KERNEL
50 #include <stdlib.h>
51 #define BJ_ALLOC(sz) malloc(sz)
52 #define BJ_FREE(p, sz) free(p)
53 #else
54 #include <sys/kmem.h>
55 #define BJ_ALLOC(sz) kmem_alloc(sz, KM_SLEEP)
56 #define BJ_FREE(p, sz) kmem_free(p, sz)
57 #endif
58 
59 #ifndef _KERNEL
60 #include <limits.h>
61 #include <stdbool.h>
62 #include <stddef.h>
63 #include <stdint.h>
64 #include <string.h>
65 #else
66 #include <sys/atomic.h>
67 #include <sys/module.h>
68 #endif
69 
70 #define   __BPF_PRIVATE
71 #include <net/bpf.h>
72 #include <net/bpfjit.h>
73 #include <sljitLir.h>
74 
75 #if !defined(_KERNEL) && defined(SLJIT_VERBOSE) && SLJIT_VERBOSE
76 #include <stdio.h> /* for stderr */
77 #endif
78 
/*
 * Number of saved registers to pass to sljit_emit_enter() function.
 * Three saved registers (SLJIT_S0, SLJIT_S1 and SLJIT_S2) are assigned
 * below.
 */
#define NSAVEDS		3

/*
 * Arguments of generated bpfjit_func_t.
 * The first argument is reassigned upon entry
 * to a more frequently used buf argument, which is why BJ_CTX_ARG and
 * BJ_BUF below name the same register (SLJIT_S0).
 */
#define BJ_CTX_ARG	SLJIT_S0
#define BJ_ARGS		SLJIT_S1

/*
 * Permanent register assignments.
 * Saved registers hold the packet buffer, the bpf_args pointer and the
 * buffer length; scratch registers hold the BPF A and X registers and
 * three temporaries.
 */
#define BJ_BUF		SLJIT_S0
/* BJ_ARGS keeps SLJIT_S1; it is defined with the arguments above. */
#define BJ_BUFLEN	SLJIT_S2
#define BJ_AREG		SLJIT_R0
#define BJ_TMP1REG	SLJIT_R1
#define BJ_TMP2REG	SLJIT_R2
#define BJ_XREG		SLJIT_R3
#define BJ_TMP3REG	SLJIT_R4
103 
/*
 * Upper bound on the number of memwords: BPF_MAX_MEMWORDS in the kernel,
 * BPF_MEMWORDS in userland.
 */
#ifdef _KERNEL
#define MAX_MEMWORDS BPF_MAX_MEMWORDS
#else
#define MAX_MEMWORDS BPF_MEMWORDS
#endif

/*
 * bpf_memword_init_t values.
 * BJ_INIT_MBIT(k) is BPF_MEMWORD_INIT(k) for memword k; the A and X
 * registers use the two indices that follow the last possible memword
 * (MAX_MEMWORDS and MAX_MEMWORDS + 1).
 */
#define BJ_INIT_NOBITS  ((bpf_memword_init_t)0)
#define BJ_INIT_MBIT(k) BPF_MEMWORD_INIT(k)
#define BJ_INIT_ABIT    BJ_INIT_MBIT(MAX_MEMWORDS)
#define BJ_INIT_XBIT    BJ_INIT_MBIT(MAX_MEMWORDS + 1)
114 
/*
 * Get the number of external memwords and the number of memwords
 * from a bpf_ctx object.
 *
 * GET_EXTWORDS(bc) is bc->extwords, or 0 when bc is NULL.
 * GET_MEMWORDS(bc) is that value when it is non-zero and BPF_MEMWORDS
 * otherwise.
 *
 * Both macros evaluate bc more than once, so the argument must not have
 * side effects.
 */
#define GET_EXTWORDS(bc) ((bc) ? (bc)->extwords : 0)
#define GET_MEMWORDS(bc) (GET_EXTWORDS(bc) ? GET_EXTWORDS(bc) : BPF_MEMWORDS)
120 
/*
 * Optimization hints.
 * A bpfjit_hint_t is a bitwise OR of the BJ_HINT_* flags below; the
 * code generator consults it to decide, among other things, how many
 * scratch registers are needed and which registers must be preserved
 * across calls.
 */
typedef unsigned int bpfjit_hint_t;
#define BJ_HINT_ABS  0x01 /* packet read at absolute offset   */
#define BJ_HINT_IND  0x02 /* packet read at variable offset   */
#define BJ_HINT_MSH  0x04 /* BPF_MSH instruction              */
#define BJ_HINT_COP  0x08 /* BPF_COP or BPF_COPX instruction  */
#define BJ_HINT_COPX 0x10 /* BPF_COPX instruction             */
#define BJ_HINT_XREG 0x20 /* BJ_XREG is needed                */
#define BJ_HINT_LDX  0x40 /* BPF_LDX instruction              */
/* Any of the three kinds of instructions that read from the packet. */
#define BJ_HINT_PKT  (BJ_HINT_ABS|BJ_HINT_IND|BJ_HINT_MSH)
133 
/*
 * Datatype for Array Bounds Check Elimination (ABC) pass.
 * A length is an offset plus a read width; the largest value,
 * UINT32_MAX + 4, does not fit in 32 bits, hence the 64-bit type.
 */
typedef uint64_t bpfjit_abc_length_t;
#define MAX_ABC_LENGTH (UINT32_MAX + UINT64_C(4)) /* max. width is 4 */
139 
/*
 * Local stack area of the generated function.
 * The emitters address its members relative to SLJIT_SP using
 * offsetof(struct bpfjit_stack, member), so the layout must stay in
 * sync with them.
 */
struct bpfjit_stack
{
	bpf_ctx_t *ctx; /* context, passed as first argument to copfuncs */
	uint32_t *extmem; /* pointer to external memory store */
	uint32_t reg; /* saved A or X register */
#ifdef _KERNEL
	int err; /* 3rd argument for m_xword/m_xhalf/m_xbyte function call */
#endif
	uint32_t mem[BPF_MEMWORDS]; /* internal memory store */
};
150 
/*
 * Data for BPF_JMP instruction.
 * Forward declaration for struct bpfjit_jump.
 */
struct bpfjit_jump_data;

/*
 * Node of bjumps list.
 * The list head lives in struct bpfjit_insn_data and collects the jumps
 * that target a given instruction.
 */
struct bpfjit_jump {
	/* NOTE(review): presumably the emitted sljit jump for this edge. */
	struct sljit_jump *sjump;
	/* Linkage within the bjumps list. */
	SLIST_ENTRY(bpfjit_jump) entries;
	/*
	 * NOTE(review): presumably points back to the bpfjit_jump_data
	 * whose jtf[] array contains this node; confirm against the code
	 * that initializes it.
	 */
	struct bpfjit_jump_data *jdata;
};
165 
/*
 * Data for BPF_JMP instruction.
 * Stored in the u.jdata member of struct bpfjit_insn_data.
 */
struct bpfjit_jump_data {
	/*
	 * These entries make up bjumps list:
	 * jtf[0] - when coming from jt path,
	 * jtf[1] - when coming from jf path.
	 */
	struct bpfjit_jump jtf[2];
	/*
	 * Length calculated by Array Bounds Check Elimination (ABC) pass.
	 */
	bpfjit_abc_length_t abc_length;
	/*
	 * Length checked by the last out-of-bounds check.
	 */
	bpfjit_abc_length_t checked_length;
};
185 
/*
 * Data for "read from packet" instructions.
 * Stored in the u.rdata member of struct bpfjit_insn_data.
 * See also read_pkt_insn() function below.
 */
struct bpfjit_read_pkt_data {
	/*
	 * Length calculated by Array Bounds Check Elimination (ABC) pass.
	 */
	bpfjit_abc_length_t abc_length;
	/*
	 * If positive, emit "if (buflen < check_length) return 0"
	 * out-of-bounds check.
	 * Values greater than UINT32_MAX generate unconditional "return 0".
	 * A value of zero means that no check is emitted.
	 */
	bpfjit_abc_length_t check_length;
};
202 
/*
 * Additional (optimization-related) data for bpf_insn.
 */
struct bpfjit_insn_data {
	/* List of jumps to this insn. */
	SLIST_HEAD(, bpfjit_jump) bjumps;

	/*
	 * Per-class data: jdata for BPF_JMP instructions, rdata for
	 * "read from packet" instructions.
	 */
	union {
		struct bpfjit_jump_data     jdata;
		struct bpfjit_read_pkt_data rdata;
	} u;

	/*
	 * NOTE(review): presumably a mask of BJ_INIT_* bits for memwords
	 * and registers that are not yet initialized at this instruction;
	 * confirm against the pass that fills it in.
	 */
	bpf_memword_init_t invalid;
	/*
	 * NOTE(review): presumably true when no execution path reaches
	 * this instruction; confirm against the pass that fills it in.
	 */
	bool unreachable;
};
218 
219 #ifdef _KERNEL
220 
/*
 * Readers that generated code calls through emit_xcall() to fetch a
 * word, half-word or byte at the given offset.
 * The third argument points at an int that the generated code tests
 * after the call; a non-zero value makes the filter return 0.
 * NOTE(review): the definitions are not in the visible part of this file.
 */
uint32_t m_xword(const struct mbuf *, uint32_t, int *);
uint32_t m_xhalf(const struct mbuf *, uint32_t, int *);
uint32_t m_xbyte(const struct mbuf *, uint32_t, int *);
224 
225 MODULE(MODULE_CLASS_MISC, bpfjit, "sljit")
226 
227 static int
bpfjit_modcmd(modcmd_t cmd,void * arg)228 bpfjit_modcmd(modcmd_t cmd, void *arg)
229 {
230 
231           switch (cmd) {
232           case MODULE_CMD_INIT:
233                     bpfjit_module_ops.bj_free_code = &bpfjit_free_code;
234                     atomic_store_release(&bpfjit_module_ops.bj_generate_code,
235                         &bpfjit_generate_code);
236                     return 0;
237 
238           case MODULE_CMD_FINI:
239                     return EOPNOTSUPP;
240 
241           default:
242                     return ENOTTY;
243           }
244 }
245 #endif
246 
247 /*
248  * Return a number of scratch registers to pass
249  * to sljit_emit_enter() function.
250  */
251 static sljit_s32
nscratches(bpfjit_hint_t hints)252 nscratches(bpfjit_hint_t hints)
253 {
254           sljit_s32 rv = 2;
255 
256 #ifdef _KERNEL
257           if (hints & BJ_HINT_PKT)
258                     rv = 3; /* xcall with three arguments */
259 #endif
260 
261           if (hints & BJ_HINT_IND)
262                     rv = 3; /* uses BJ_TMP2REG */
263 
264           if (hints & BJ_HINT_COP)
265                     rv = 3; /* calls copfunc with three arguments */
266 
267           if (hints & BJ_HINT_XREG)
268                     rv = 4; /* uses BJ_XREG */
269 
270 #ifdef _KERNEL
271           if (hints & BJ_HINT_LDX)
272                     rv = 5; /* uses BJ_TMP3REG */
273 #endif
274 
275           if (hints & BJ_HINT_COPX)
276                     rv = 5; /* uses BJ_TMP3REG */
277 
278           return rv;
279 }
280 
281 static uint32_t
read_width(const struct bpf_insn * pc)282 read_width(const struct bpf_insn *pc)
283 {
284 
285           switch (BPF_SIZE(pc->code)) {
286           case BPF_W: return 4;
287           case BPF_H: return 2;
288           case BPF_B: return 1;
289           default:    return 0;
290           }
291 }
292 
293 /*
294  * Copy buf and buflen members of bpf_args from BJ_ARGS
295  * pointer to BJ_BUF and BJ_BUFLEN registers.
296  */
297 static int
load_buf_buflen(struct sljit_compiler * compiler)298 load_buf_buflen(struct sljit_compiler *compiler)
299 {
300           int status;
301 
302           status = sljit_emit_op1(compiler,
303               SLJIT_MOV_P,
304               BJ_BUF, 0,
305               SLJIT_MEM1(BJ_ARGS),
306               offsetof(struct bpf_args, pkt));
307           if (status != SLJIT_SUCCESS)
308                     return status;
309 
310           status = sljit_emit_op1(compiler,
311               SLJIT_MOV, /* size_t source */
312               BJ_BUFLEN, 0,
313               SLJIT_MEM1(BJ_ARGS),
314               offsetof(struct bpf_args, buflen));
315 
316           return status;
317 }
318 
/*
 * Grow the array of jump pointers at *jumps, whose current capacity in
 * elements is *size.
 *
 * The capacity is doubled; an array of capacity zero grows to one
 * element, because doubling zero would report success without providing
 * any room for the caller to store into. A zero-capacity array is
 * assumed to own no storage, so nothing is copied or freed for it.
 *
 * On success the old elements are copied, the old storage is released,
 * *jumps and *size are updated and true is returned. On arithmetic
 * overflow or allocation failure false is returned and *jumps and
 * *size are left untouched.
 */
static bool
grow_jumps(struct sljit_jump ***jumps, size_t *size)
{
	struct sljit_jump **newptr;
	const size_t elemsz = sizeof(struct sljit_jump *);
	const size_t old_size = *size;
	const size_t new_size = (old_size == 0) ? 1 : 2 * old_size;

	/* The first test catches wrap-around of the doubling. */
	if (new_size < old_size || new_size > SIZE_MAX / elemsz)
		return false;

	newptr = BJ_ALLOC(new_size * elemsz);
	if (newptr == NULL)
		return false;

	if (old_size != 0) {
		memcpy(newptr, *jumps, old_size * elemsz);
		BJ_FREE(*jumps, old_size * elemsz);
	}

	*jumps = newptr;
	*size = new_size;
	return true;
}
341 
/*
 * Append jump to the array at *jumps.
 * *size is the number of elements in use and *max_size the capacity;
 * the array is grown with grow_jumps() when it is full.
 *
 * Returns false, without storing anything, if the array is full and
 * cannot be grown; true otherwise.
 */
static bool
append_jump(struct sljit_jump *jump, struct sljit_jump ***jumps,
    size_t *size, size_t *max_size)
{
	const bool full = (*size == *max_size);

	if (full && !grow_jumps(jumps, max_size))
		return false;

	(*jumps)[*size] = jump;
	*size += 1;
	return true;
}
352 
353 /*
354  * Emit code for BPF_LD+BPF_B+BPF_ABS    A <- P[k:1].
355  */
356 static int
emit_read8(struct sljit_compiler * compiler,sljit_s32 src,uint32_t k)357 emit_read8(struct sljit_compiler *compiler, sljit_s32 src, uint32_t k)
358 {
359 
360           return sljit_emit_op1(compiler,
361               SLJIT_MOV_U8,
362               BJ_AREG, 0,
363               SLJIT_MEM1(src), k);
364 }
365 
366 /*
367  * Emit code for BPF_LD+BPF_H+BPF_ABS    A <- P[k:2].
368  */
369 static int
emit_read16(struct sljit_compiler * compiler,sljit_s32 src,uint32_t k)370 emit_read16(struct sljit_compiler *compiler, sljit_s32 src, uint32_t k)
371 {
372           int status;
373 
374           BJ_ASSERT(k <= UINT32_MAX - 1);
375 
376           /* A = buf[k]; */
377           status = sljit_emit_op1(compiler,
378               SLJIT_MOV_U8,
379               BJ_AREG, 0,
380               SLJIT_MEM1(src), k);
381           if (status != SLJIT_SUCCESS)
382                     return status;
383 
384           /* tmp1 = buf[k+1]; */
385           status = sljit_emit_op1(compiler,
386               SLJIT_MOV_U8,
387               BJ_TMP1REG, 0,
388               SLJIT_MEM1(src), k+1);
389           if (status != SLJIT_SUCCESS)
390                     return status;
391 
392           /* A = A << 8; */
393           status = sljit_emit_op2(compiler,
394               SLJIT_SHL,
395               BJ_AREG, 0,
396               BJ_AREG, 0,
397               SLJIT_IMM, 8);
398           if (status != SLJIT_SUCCESS)
399                     return status;
400 
401           /* A = A + tmp1; */
402           status = sljit_emit_op2(compiler,
403               SLJIT_ADD,
404               BJ_AREG, 0,
405               BJ_AREG, 0,
406               BJ_TMP1REG, 0);
407           return status;
408 }
409 
410 /*
411  * Emit code for BPF_LD+BPF_W+BPF_ABS    A <- P[k:4].
412  */
413 static int
emit_read32(struct sljit_compiler * compiler,sljit_s32 src,uint32_t k)414 emit_read32(struct sljit_compiler *compiler, sljit_s32 src, uint32_t k)
415 {
416           int status;
417 
418           BJ_ASSERT(k <= UINT32_MAX - 3);
419 
420           /* A = buf[k]; */
421           status = sljit_emit_op1(compiler,
422               SLJIT_MOV_U8,
423               BJ_AREG, 0,
424               SLJIT_MEM1(src), k);
425           if (status != SLJIT_SUCCESS)
426                     return status;
427 
428           /* tmp1 = buf[k+1]; */
429           status = sljit_emit_op1(compiler,
430               SLJIT_MOV_U8,
431               BJ_TMP1REG, 0,
432               SLJIT_MEM1(src), k+1);
433           if (status != SLJIT_SUCCESS)
434                     return status;
435 
436           /* A = A << 8; */
437           status = sljit_emit_op2(compiler,
438               SLJIT_SHL,
439               BJ_AREG, 0,
440               BJ_AREG, 0,
441               SLJIT_IMM, 8);
442           if (status != SLJIT_SUCCESS)
443                     return status;
444 
445           /* A = A + tmp1; */
446           status = sljit_emit_op2(compiler,
447               SLJIT_ADD,
448               BJ_AREG, 0,
449               BJ_AREG, 0,
450               BJ_TMP1REG, 0);
451           if (status != SLJIT_SUCCESS)
452                     return status;
453 
454           /* tmp1 = buf[k+2]; */
455           status = sljit_emit_op1(compiler,
456               SLJIT_MOV_U8,
457               BJ_TMP1REG, 0,
458               SLJIT_MEM1(src), k+2);
459           if (status != SLJIT_SUCCESS)
460                     return status;
461 
462           /* A = A << 8; */
463           status = sljit_emit_op2(compiler,
464               SLJIT_SHL,
465               BJ_AREG, 0,
466               BJ_AREG, 0,
467               SLJIT_IMM, 8);
468           if (status != SLJIT_SUCCESS)
469                     return status;
470 
471           /* A = A + tmp1; */
472           status = sljit_emit_op2(compiler,
473               SLJIT_ADD,
474               BJ_AREG, 0,
475               BJ_AREG, 0,
476               BJ_TMP1REG, 0);
477           if (status != SLJIT_SUCCESS)
478                     return status;
479 
480           /* tmp1 = buf[k+3]; */
481           status = sljit_emit_op1(compiler,
482               SLJIT_MOV_U8,
483               BJ_TMP1REG, 0,
484               SLJIT_MEM1(src), k+3);
485           if (status != SLJIT_SUCCESS)
486                     return status;
487 
488           /* A = A << 8; */
489           status = sljit_emit_op2(compiler,
490               SLJIT_SHL,
491               BJ_AREG, 0,
492               BJ_AREG, 0,
493               SLJIT_IMM, 8);
494           if (status != SLJIT_SUCCESS)
495                     return status;
496 
497           /* A = A + tmp1; */
498           status = sljit_emit_op2(compiler,
499               SLJIT_ADD,
500               BJ_AREG, 0,
501               BJ_AREG, 0,
502               BJ_TMP1REG, 0);
503           return status;
504 }
505 
#ifdef _KERNEL
/*
 * Emit code for m_xword/m_xhalf/m_xbyte call, i.e. fn(buf, k, &err)
 * with the result moved to dst.
 *
 * A is saved in bpfjit_stack.reg around the call for BPF_LDX; otherwise
 * X is saved there, but only when BJ_HINT_XREG is set.
 * Jumps to the "return 0" exit (offset overflow for BPF_IND with
 * non-zero k, non-zero err after the call) are appended to ret0.
 *
 * @pc BPF_LD+BPF_W+BPF_ABS    A <- P[k:4]
 *     BPF_LD+BPF_H+BPF_ABS    A <- P[k:2]
 *     BPF_LD+BPF_B+BPF_ABS    A <- P[k:1]
 *     BPF_LD+BPF_W+BPF_IND    A <- P[X+k:4]
 *     BPF_LD+BPF_H+BPF_IND    A <- P[X+k:2]
 *     BPF_LD+BPF_B+BPF_IND    A <- P[X+k:1]
 *     BPF_LDX+BPF_B+BPF_MSH   X <- 4*(P[k:1]&0xf)
 *
 * Returns SLJIT_SUCCESS, the status of the first emit that failed, or
 * SLJIT_ERR_ALLOC_FAILED if a jump cannot be created or recorded.
 */
static int
emit_xcall(struct sljit_compiler *compiler, bpfjit_hint_t hints,
    const struct bpf_insn *pc, int dst, struct sljit_jump ***ret0,
    size_t *ret0_size, size_t *ret0_maxsize,
    uint32_t (*fn)(const struct mbuf *, uint32_t, int *))
{
#if BJ_XREG == SLJIT_RETURN_REG   || \
    BJ_XREG == SLJIT_R0 || \
    BJ_XREG == SLJIT_R1 || \
    BJ_XREG == SLJIT_R2
#error "Not supported assignment of registers."
#endif
	const sljit_s32 kept_reg =
	    (BPF_CLASS(pc->code) == BPF_LDX) ? BJ_AREG : BJ_XREG;
	const bool keep = (kept_reg == BJ_AREG || (hints & BJ_HINT_XREG));
	const bool indirect = (BPF_CLASS(pc->code) == BPF_LD &&
	    BPF_MODE(pc->code) == BPF_IND);
	struct sljit_jump *to_ret0;
	int rc;

	if (keep) {
		/* save A or X */
		rc = sljit_emit_op1(compiler, SLJIT_MOV_U32,
		    SLJIT_MEM1(SLJIT_SP),
		    offsetof(struct bpfjit_stack, reg),
		    kept_reg, 0);
		if (rc != SLJIT_SUCCESS)
			return rc;
	}

	/*
	 * Prepare registers for fn(mbuf, k, &err) call.
	 * First argument: the buffer register.
	 */
	rc = sljit_emit_op1(compiler, SLJIT_MOV,
	    SLJIT_R0, 0,
	    BJ_BUF, 0);
	if (rc != SLJIT_SUCCESS)
		return rc;

	/* Second argument: the offset. */
	if (!indirect) {
		/* k = pc->k */
		rc = sljit_emit_op1(compiler, SLJIT_MOV,
		    SLJIT_R1, 0,
		    SLJIT_IMM, (uint32_t)pc->k);
		if (rc != SLJIT_SUCCESS)
			return rc;
	} else if (pc->k == 0) {
		/* k = X; */
		rc = sljit_emit_op1(compiler, SLJIT_MOV,
		    SLJIT_R1, 0,
		    BJ_XREG, 0);
		if (rc != SLJIT_SUCCESS)
			return rc;
	} else {
		/* if (X > UINT32_MAX - pc->k) return 0; */
		to_ret0 = sljit_emit_cmp(compiler, SLJIT_GREATER,
		    BJ_XREG, 0,
		    SLJIT_IMM, UINT32_MAX - pc->k);
		if (to_ret0 == NULL)
			return SLJIT_ERR_ALLOC_FAILED;
		if (!append_jump(to_ret0, ret0, ret0_size, ret0_maxsize))
			return SLJIT_ERR_ALLOC_FAILED;

		/* k = X + pc->k; */
		rc = sljit_emit_op2(compiler, SLJIT_ADD,
		    SLJIT_R1, 0,
		    BJ_XREG, 0,
		    SLJIT_IMM, (uint32_t)pc->k);
		if (rc != SLJIT_SUCCESS)
			return rc;
	}

	/*
	 * The third argument of fn is an address on stack.
	 */
	rc = sljit_get_local_base(compiler,
	    SLJIT_R2, 0,
	    offsetof(struct bpfjit_stack, err));
	if (rc != SLJIT_SUCCESS)
		return rc;

	/* fn(buf, k, &err); */
	rc = sljit_emit_ijump(compiler, SLJIT_CALL3,
	    SLJIT_IMM, SLJIT_FUNC_OFFSET(fn));
	if (rc != SLJIT_SUCCESS)
		return rc;

	if (dst != SLJIT_RETURN_REG) {
		/* move return value to dst */
		rc = sljit_emit_op1(compiler, SLJIT_MOV,
		    dst, 0,
		    SLJIT_RETURN_REG, 0);
		if (rc != SLJIT_SUCCESS)
			return rc;
	}

	/* if (*err != 0) return 0; */
	to_ret0 = sljit_emit_cmp(compiler, SLJIT_NOT_EQUAL|SLJIT_I32_OP,
	    SLJIT_MEM1(SLJIT_SP),
	    offsetof(struct bpfjit_stack, err),
	    SLJIT_IMM, 0);
	if (to_ret0 == NULL)
		return SLJIT_ERR_ALLOC_FAILED;
	if (!append_jump(to_ret0, ret0, ret0_size, ret0_maxsize))
		return SLJIT_ERR_ALLOC_FAILED;

	if (keep) {
		/* restore A or X */
		rc = sljit_emit_op1(compiler, SLJIT_MOV_U32,
		    kept_reg, 0,
		    SLJIT_MEM1(SLJIT_SP),
		    offsetof(struct bpfjit_stack, reg));
		if (rc != SLJIT_SUCCESS)
			return rc;
	}

	return SLJIT_SUCCESS;
}
#endif
648 
649 /*
650  * Emit code for BPF_COP and BPF_COPX instructions.
651  */
652 static int
emit_cop(struct sljit_compiler * compiler,bpfjit_hint_t hints,const bpf_ctx_t * bc,const struct bpf_insn * pc,struct sljit_jump *** ret0,size_t * ret0_size,size_t * ret0_maxsize)653 emit_cop(struct sljit_compiler *compiler, bpfjit_hint_t hints,
654     const bpf_ctx_t *bc, const struct bpf_insn *pc,
655     struct sljit_jump ***ret0, size_t *ret0_size, size_t *ret0_maxsize)
656 {
657 #if BJ_XREG    == SLJIT_RETURN_REG   || \
658     BJ_XREG    == SLJIT_R0 || \
659     BJ_XREG    == SLJIT_R1 || \
660     BJ_XREG    == SLJIT_R2 || \
661     BJ_TMP3REG == SLJIT_R0 || \
662     BJ_TMP3REG == SLJIT_R1 || \
663     BJ_TMP3REG == SLJIT_R2
664 #error "Not supported assignment of registers."
665 #endif
666 
667           struct sljit_jump *jump;
668           sljit_s32 call_reg;
669           sljit_sw call_off;
670           int status;
671 
672           BJ_ASSERT(bc != NULL && bc->copfuncs != NULL);
673 
674           if (hints & BJ_HINT_LDX) {
675                     /* save X */
676                     status = sljit_emit_op1(compiler,
677                         SLJIT_MOV_U32,
678                         SLJIT_MEM1(SLJIT_SP),
679                         offsetof(struct bpfjit_stack, reg),
680                         BJ_XREG, 0);
681                     if (status != SLJIT_SUCCESS)
682                               return status;
683           }
684 
685           if (BPF_MISCOP(pc->code) == BPF_COP) {
686                     call_reg = SLJIT_IMM;
687                     call_off = SLJIT_FUNC_OFFSET(bc->copfuncs[pc->k]);
688           } else {
689                     /* if (X >= bc->nfuncs) return 0; */
690                     jump = sljit_emit_cmp(compiler,
691                         SLJIT_GREATER_EQUAL,
692                         BJ_XREG, 0,
693                         SLJIT_IMM, bc->nfuncs);
694                     if (jump == NULL)
695                               return SLJIT_ERR_ALLOC_FAILED;
696                     if (!append_jump(jump, ret0, ret0_size, ret0_maxsize))
697                               return SLJIT_ERR_ALLOC_FAILED;
698 
699                     /* tmp1 = ctx; */
700                     status = sljit_emit_op1(compiler,
701                         SLJIT_MOV_P,
702                         BJ_TMP1REG, 0,
703                         SLJIT_MEM1(SLJIT_SP),
704                         offsetof(struct bpfjit_stack, ctx));
705                     if (status != SLJIT_SUCCESS)
706                               return status;
707 
708                     /* tmp1 = ctx->copfuncs; */
709                     status = sljit_emit_op1(compiler,
710                         SLJIT_MOV_P,
711                         BJ_TMP1REG, 0,
712                         SLJIT_MEM1(BJ_TMP1REG),
713                         offsetof(struct bpf_ctx, copfuncs));
714                     if (status != SLJIT_SUCCESS)
715                               return status;
716 
717                     /* tmp2 = X; */
718                     status = sljit_emit_op1(compiler,
719                         SLJIT_MOV,
720                         BJ_TMP2REG, 0,
721                         BJ_XREG, 0);
722                     if (status != SLJIT_SUCCESS)
723                               return status;
724 
725                     /* tmp3 = ctx->copfuncs[tmp2]; */
726                     call_reg = BJ_TMP3REG;
727                     call_off = 0;
728                     status = sljit_emit_op1(compiler,
729                         SLJIT_MOV_P,
730                         call_reg, call_off,
731                         SLJIT_MEM2(BJ_TMP1REG, BJ_TMP2REG),
732                         SLJIT_WORD_SHIFT);
733                     if (status != SLJIT_SUCCESS)
734                               return status;
735           }
736 
737           /*
738            * Copy bpf_copfunc_t arguments to registers.
739            */
740 #if BJ_AREG != SLJIT_R2
741           status = sljit_emit_op1(compiler,
742               SLJIT_MOV_U32,
743               SLJIT_R2, 0,
744               BJ_AREG, 0);
745           if (status != SLJIT_SUCCESS)
746                     return status;
747 #endif
748 
749           status = sljit_emit_op1(compiler,
750               SLJIT_MOV_P,
751               SLJIT_R0, 0,
752               SLJIT_MEM1(SLJIT_SP),
753               offsetof(struct bpfjit_stack, ctx));
754           if (status != SLJIT_SUCCESS)
755                     return status;
756 
757           status = sljit_emit_op1(compiler,
758               SLJIT_MOV_P,
759               SLJIT_R1, 0,
760               BJ_ARGS, 0);
761           if (status != SLJIT_SUCCESS)
762                     return status;
763 
764           status = sljit_emit_ijump(compiler,
765               SLJIT_CALL3, call_reg, call_off);
766           if (status != SLJIT_SUCCESS)
767                     return status;
768 
769 #if BJ_AREG != SLJIT_RETURN_REG
770           status = sljit_emit_op1(compiler,
771               SLJIT_MOV,
772               BJ_AREG, 0,
773               SLJIT_RETURN_REG, 0);
774           if (status != SLJIT_SUCCESS)
775                     return status;
776 #endif
777 
778           if (hints & BJ_HINT_LDX) {
779                     /* restore X */
780                     status = sljit_emit_op1(compiler,
781                         SLJIT_MOV_U32,
782                         BJ_XREG, 0,
783                         SLJIT_MEM1(SLJIT_SP),
784                         offsetof(struct bpfjit_stack, reg));
785                     if (status != SLJIT_SUCCESS)
786                               return status;
787           }
788 
789           return SLJIT_SUCCESS;
790 }
791 
792 /*
793  * Generate code for
794  * BPF_LD+BPF_W+BPF_ABS    A <- P[k:4]
795  * BPF_LD+BPF_H+BPF_ABS    A <- P[k:2]
796  * BPF_LD+BPF_B+BPF_ABS    A <- P[k:1]
797  * BPF_LD+BPF_W+BPF_IND    A <- P[X+k:4]
798  * BPF_LD+BPF_H+BPF_IND    A <- P[X+k:2]
799  * BPF_LD+BPF_B+BPF_IND    A <- P[X+k:1]
800  */
801 static int
emit_pkt_read(struct sljit_compiler * compiler,bpfjit_hint_t hints,const struct bpf_insn * pc,struct sljit_jump * to_mchain_jump,struct sljit_jump *** ret0,size_t * ret0_size,size_t * ret0_maxsize)802 emit_pkt_read(struct sljit_compiler *compiler, bpfjit_hint_t hints,
803     const struct bpf_insn *pc, struct sljit_jump *to_mchain_jump,
804     struct sljit_jump ***ret0, size_t *ret0_size, size_t *ret0_maxsize)
805 {
806           int status = SLJIT_ERR_ALLOC_FAILED;
807           uint32_t width;
808           sljit_s32 ld_reg;
809           struct sljit_jump *jump;
810 #ifdef _KERNEL
811           struct sljit_label *label;
812           struct sljit_jump *over_mchain_jump;
813           const bool check_zero_buflen = (to_mchain_jump != NULL);
814 #endif
815           const uint32_t k = pc->k;
816 
817 #ifdef _KERNEL
818           if (to_mchain_jump == NULL) {
819                     to_mchain_jump = sljit_emit_cmp(compiler,
820                         SLJIT_EQUAL,
821                         BJ_BUFLEN, 0,
822                         SLJIT_IMM, 0);
823                     if (to_mchain_jump == NULL)
824                               return SLJIT_ERR_ALLOC_FAILED;
825           }
826 #endif
827 
828           ld_reg = BJ_BUF;
829           width = read_width(pc);
830           if (width == 0)
831                     return SLJIT_ERR_ALLOC_FAILED;
832 
833           if (BPF_MODE(pc->code) == BPF_IND) {
834                     /* tmp1 = buflen - (pc->k + width); */
835                     status = sljit_emit_op2(compiler,
836                         SLJIT_SUB,
837                         BJ_TMP1REG, 0,
838                         BJ_BUFLEN, 0,
839                         SLJIT_IMM, k + width);
840                     if (status != SLJIT_SUCCESS)
841                               return status;
842 
843                     /* ld_reg = buf + X; */
844                     ld_reg = BJ_TMP2REG;
845                     status = sljit_emit_op2(compiler,
846                         SLJIT_ADD,
847                         ld_reg, 0,
848                         BJ_BUF, 0,
849                         BJ_XREG, 0);
850                     if (status != SLJIT_SUCCESS)
851                               return status;
852 
853                     /* if (tmp1 < X) return 0; */
854                     jump = sljit_emit_cmp(compiler,
855                         SLJIT_LESS,
856                         BJ_TMP1REG, 0,
857                         BJ_XREG, 0);
858                     if (jump == NULL)
859                               return SLJIT_ERR_ALLOC_FAILED;
860                     if (!append_jump(jump, ret0, ret0_size, ret0_maxsize))
861                               return SLJIT_ERR_ALLOC_FAILED;
862           }
863 
864           /*
865            * Don't emit wrapped-around reads. They're dead code but
866            * dead code elimination logic isn't smart enough to figure
867            * it out.
868            */
869           if (k <= UINT32_MAX - width + 1) {
870                     switch (width) {
871                     case 4:
872                               status = emit_read32(compiler, ld_reg, k);
873                               break;
874                     case 2:
875                               status = emit_read16(compiler, ld_reg, k);
876                               break;
877                     case 1:
878                               status = emit_read8(compiler, ld_reg, k);
879                               break;
880                     }
881 
882                     if (status != SLJIT_SUCCESS)
883                               return status;
884           }
885 
886 #ifdef _KERNEL
887           over_mchain_jump = sljit_emit_jump(compiler, SLJIT_JUMP);
888           if (over_mchain_jump == NULL)
889                     return SLJIT_ERR_ALLOC_FAILED;
890 
891           /* entry point to mchain handler */
892           label = sljit_emit_label(compiler);
893           if (label == NULL)
894                     return SLJIT_ERR_ALLOC_FAILED;
895           sljit_set_label(to_mchain_jump, label);
896 
897           if (check_zero_buflen) {
898                     /* if (buflen != 0) return 0; */
899                     jump = sljit_emit_cmp(compiler,
900                         SLJIT_NOT_EQUAL,
901                         BJ_BUFLEN, 0,
902                         SLJIT_IMM, 0);
903                     if (jump == NULL)
904                               return SLJIT_ERR_ALLOC_FAILED;
905                     if (!append_jump(jump, ret0, ret0_size, ret0_maxsize))
906                               return SLJIT_ERR_ALLOC_FAILED;
907           }
908 
909           switch (width) {
910           case 4:
911                     status = emit_xcall(compiler, hints, pc, BJ_AREG,
912                         ret0, ret0_size, ret0_maxsize, &m_xword);
913                     break;
914           case 2:
915                     status = emit_xcall(compiler, hints, pc, BJ_AREG,
916                         ret0, ret0_size, ret0_maxsize, &m_xhalf);
917                     break;
918           case 1:
919                     status = emit_xcall(compiler, hints, pc, BJ_AREG,
920                         ret0, ret0_size, ret0_maxsize, &m_xbyte);
921                     break;
922           }
923 
924           if (status != SLJIT_SUCCESS)
925                     return status;
926 
927           label = sljit_emit_label(compiler);
928           if (label == NULL)
929                     return SLJIT_ERR_ALLOC_FAILED;
930           sljit_set_label(over_mchain_jump, label);
931 #endif
932 
933           return SLJIT_SUCCESS;
934 }
935 
936 static int
emit_memload(struct sljit_compiler * compiler,sljit_s32 dst,uint32_t k,size_t extwords)937 emit_memload(struct sljit_compiler *compiler,
938     sljit_s32 dst, uint32_t k, size_t extwords)
939 {
940           int status;
941           sljit_s32 src;
942           sljit_sw srcw;
943 
944           srcw = k * sizeof(uint32_t);
945 
946           if (extwords == 0) {
947                     src = SLJIT_MEM1(SLJIT_SP);
948                     srcw += offsetof(struct bpfjit_stack, mem);
949           } else {
950                     /* copy extmem pointer to the tmp1 register */
951                     status = sljit_emit_op1(compiler,
952                         SLJIT_MOV_P,
953                         BJ_TMP1REG, 0,
954                         SLJIT_MEM1(SLJIT_SP),
955                         offsetof(struct bpfjit_stack, extmem));
956                     if (status != SLJIT_SUCCESS)
957                               return status;
958                     src = SLJIT_MEM1(BJ_TMP1REG);
959           }
960 
961           return sljit_emit_op1(compiler, SLJIT_MOV_U32, dst, 0, src, srcw);
962 }
963 
964 static int
emit_memstore(struct sljit_compiler * compiler,sljit_s32 src,uint32_t k,size_t extwords)965 emit_memstore(struct sljit_compiler *compiler,
966     sljit_s32 src, uint32_t k, size_t extwords)
967 {
968           int status;
969           sljit_s32 dst;
970           sljit_sw dstw;
971 
972           dstw = k * sizeof(uint32_t);
973 
974           if (extwords == 0) {
975                     dst = SLJIT_MEM1(SLJIT_SP);
976                     dstw += offsetof(struct bpfjit_stack, mem);
977           } else {
978                     /* copy extmem pointer to the tmp1 register */
979                     status = sljit_emit_op1(compiler,
980                         SLJIT_MOV_P,
981                         BJ_TMP1REG, 0,
982                         SLJIT_MEM1(SLJIT_SP),
983                         offsetof(struct bpfjit_stack, extmem));
984                     if (status != SLJIT_SUCCESS)
985                               return status;
986                     dst = SLJIT_MEM1(BJ_TMP1REG);
987           }
988 
989           return sljit_emit_op1(compiler, SLJIT_MOV_U32, dst, dstw, src, 0);
990 }
991 
992 /*
993  * Emit code for BPF_LDX+BPF_B+BPF_MSH    X <- 4*(P[k:1]&0xf).
994  */
995 static int
emit_msh(struct sljit_compiler * compiler,bpfjit_hint_t hints,const struct bpf_insn * pc,struct sljit_jump * to_mchain_jump,struct sljit_jump *** ret0,size_t * ret0_size,size_t * ret0_maxsize)996 emit_msh(struct sljit_compiler *compiler, bpfjit_hint_t hints,
997     const struct bpf_insn *pc, struct sljit_jump *to_mchain_jump,
998     struct sljit_jump ***ret0, size_t *ret0_size, size_t *ret0_maxsize)
999 {
1000           int status;
1001 #ifdef _KERNEL
1002           struct sljit_label *label;
1003           struct sljit_jump *jump, *over_mchain_jump;
1004           const bool check_zero_buflen = (to_mchain_jump != NULL);
1005 #endif
1006           const uint32_t k = pc->k;
1007 
1008 #ifdef _KERNEL
1009           if (to_mchain_jump == NULL) {
1010                     to_mchain_jump = sljit_emit_cmp(compiler,
1011                         SLJIT_EQUAL,
1012                         BJ_BUFLEN, 0,
1013                         SLJIT_IMM, 0);
1014                     if (to_mchain_jump == NULL)
1015                               return SLJIT_ERR_ALLOC_FAILED;
1016           }
1017 #endif
1018 
1019           /* tmp1 = buf[k] */
1020           status = sljit_emit_op1(compiler,
1021               SLJIT_MOV_U8,
1022               BJ_TMP1REG, 0,
1023               SLJIT_MEM1(BJ_BUF), k);
1024           if (status != SLJIT_SUCCESS)
1025                     return status;
1026 
1027 #ifdef _KERNEL
1028           over_mchain_jump = sljit_emit_jump(compiler, SLJIT_JUMP);
1029           if (over_mchain_jump == NULL)
1030                     return SLJIT_ERR_ALLOC_FAILED;
1031 
1032           /* entry point to mchain handler */
1033           label = sljit_emit_label(compiler);
1034           if (label == NULL)
1035                     return SLJIT_ERR_ALLOC_FAILED;
1036           sljit_set_label(to_mchain_jump, label);
1037 
1038           if (check_zero_buflen) {
1039                     /* if (buflen != 0) return 0; */
1040                     jump = sljit_emit_cmp(compiler,
1041                         SLJIT_NOT_EQUAL,
1042                         BJ_BUFLEN, 0,
1043                         SLJIT_IMM, 0);
1044                     if (jump == NULL)
1045                               return SLJIT_ERR_ALLOC_FAILED;
1046                     if (!append_jump(jump, ret0, ret0_size, ret0_maxsize))
1047                               return SLJIT_ERR_ALLOC_FAILED;
1048           }
1049 
1050           status = emit_xcall(compiler, hints, pc, BJ_TMP1REG,
1051               ret0, ret0_size, ret0_maxsize, &m_xbyte);
1052           if (status != SLJIT_SUCCESS)
1053                     return status;
1054 
1055           label = sljit_emit_label(compiler);
1056           if (label == NULL)
1057                     return SLJIT_ERR_ALLOC_FAILED;
1058           sljit_set_label(over_mchain_jump, label);
1059 #endif
1060 
1061           /* tmp1 &= 0xf */
1062           status = sljit_emit_op2(compiler,
1063               SLJIT_AND,
1064               BJ_TMP1REG, 0,
1065               BJ_TMP1REG, 0,
1066               SLJIT_IMM, 0xf);
1067           if (status != SLJIT_SUCCESS)
1068                     return status;
1069 
1070           /* X = tmp1 << 2 */
1071           status = sljit_emit_op2(compiler,
1072               SLJIT_SHL,
1073               BJ_XREG, 0,
1074               BJ_TMP1REG, 0,
1075               SLJIT_IMM, 2);
1076           if (status != SLJIT_SUCCESS)
1077                     return status;
1078 
1079           return SLJIT_SUCCESS;
1080 }
1081 
1082 /*
1083  * Emit code for A = A / k or A = A % k when k is a power of 2.
1084  * @pc BPF_DIV or BPF_MOD instruction.
1085  */
1086 static int
emit_pow2_moddiv(struct sljit_compiler * compiler,const struct bpf_insn * pc)1087 emit_pow2_moddiv(struct sljit_compiler *compiler, const struct bpf_insn *pc)
1088 {
1089           uint32_t k = pc->k;
1090           int status = SLJIT_SUCCESS;
1091 
1092           BJ_ASSERT(k != 0 && (k & (k - 1)) == 0);
1093 
1094           if (BPF_OP(pc->code) == BPF_MOD) {
1095                     status = sljit_emit_op2(compiler,
1096                         SLJIT_AND,
1097                         BJ_AREG, 0,
1098                         BJ_AREG, 0,
1099                         SLJIT_IMM, k - 1);
1100           } else {
1101                     int shift = 0;
1102 
1103                     /*
1104                      * Do shift = __builtin_ctz(k).
1105                      * The loop is slower, but that's ok.
1106                      */
1107                     while (k > 1) {
1108                               k >>= 1;
1109                               shift++;
1110                     }
1111 
1112                     if (shift != 0) {
1113                               status = sljit_emit_op2(compiler,
1114                                   SLJIT_LSHR|SLJIT_I32_OP,
1115                                   BJ_AREG, 0,
1116                                   BJ_AREG, 0,
1117                                   SLJIT_IMM, shift);
1118                     }
1119           }
1120 
1121           return status;
1122 }
1123 
1124 #if !defined(BPFJIT_USE_UDIV)
1125 static sljit_uw
divide(sljit_uw x,sljit_uw y)1126 divide(sljit_uw x, sljit_uw y)
1127 {
1128 
1129           return (uint32_t)x / (uint32_t)y;
1130 }
1131 
1132 static sljit_uw
modulus(sljit_uw x,sljit_uw y)1133 modulus(sljit_uw x, sljit_uw y)
1134 {
1135 
1136           return (uint32_t)x % (uint32_t)y;
1137 }
1138 #endif
1139 
1140 /*
1141  * Emit code for A = A / div or A = A % div.
1142  * @pc BPF_DIV or BPF_MOD instruction.
1143  */
1144 static int
emit_moddiv(struct sljit_compiler * compiler,const struct bpf_insn * pc)1145 emit_moddiv(struct sljit_compiler *compiler, const struct bpf_insn *pc)
1146 {
1147           int status;
1148           const bool xdiv = BPF_OP(pc->code) == BPF_DIV;
1149           const bool xreg = BPF_SRC(pc->code) == BPF_X;
1150 
1151 #if BJ_XREG == SLJIT_RETURN_REG   || \
1152     BJ_XREG == SLJIT_R0 || \
1153     BJ_XREG == SLJIT_R1 || \
1154     BJ_AREG == SLJIT_R1
1155 #error "Not supported assignment of registers."
1156 #endif
1157 
1158 #if BJ_AREG != SLJIT_R0
1159           status = sljit_emit_op1(compiler,
1160               SLJIT_MOV,
1161               SLJIT_R0, 0,
1162               BJ_AREG, 0);
1163           if (status != SLJIT_SUCCESS)
1164                     return status;
1165 #endif
1166 
1167           status = sljit_emit_op1(compiler,
1168               SLJIT_MOV,
1169               SLJIT_R1, 0,
1170               xreg ? BJ_XREG : SLJIT_IMM,
1171               xreg ? 0 : (uint32_t)pc->k);
1172           if (status != SLJIT_SUCCESS)
1173                     return status;
1174 
1175 #if defined(BPFJIT_USE_UDIV)
1176           status = sljit_emit_op0(compiler, SLJIT_UDIV|SLJIT_I32_OP);
1177 
1178           if (BPF_OP(pc->code) == BPF_DIV) {
1179 #if BJ_AREG != SLJIT_R0
1180                     status = sljit_emit_op1(compiler,
1181                         SLJIT_MOV,
1182                         BJ_AREG, 0,
1183                         SLJIT_R0, 0);
1184 #endif
1185           } else {
1186 #if BJ_AREG != SLJIT_R1
1187                     /* Remainder is in SLJIT_R1. */
1188                     status = sljit_emit_op1(compiler,
1189                         SLJIT_MOV,
1190                         BJ_AREG, 0,
1191                         SLJIT_R1, 0);
1192 #endif
1193           }
1194 
1195           if (status != SLJIT_SUCCESS)
1196                     return status;
1197 #else
1198           status = sljit_emit_ijump(compiler,
1199               SLJIT_CALL2,
1200               SLJIT_IMM, xdiv ? SLJIT_FUNC_OFFSET(divide) :
1201                     SLJIT_FUNC_OFFSET(modulus));
1202 
1203 #if BJ_AREG != SLJIT_RETURN_REG
1204           status = sljit_emit_op1(compiler,
1205               SLJIT_MOV,
1206               BJ_AREG, 0,
1207               SLJIT_RETURN_REG, 0);
1208           if (status != SLJIT_SUCCESS)
1209                     return status;
1210 #endif
1211 #endif
1212 
1213           return status;
1214 }
1215 
1216 /*
1217  * Return true if pc is a "read from packet" instruction.
1218  * If length is not NULL and return value is true, *length will
1219  * be set to a safe length required to read a packet.
1220  */
1221 static bool
read_pkt_insn(const struct bpf_insn * pc,bpfjit_abc_length_t * length)1222 read_pkt_insn(const struct bpf_insn *pc, bpfjit_abc_length_t *length)
1223 {
1224           bool rv;
1225           bpfjit_abc_length_t width = 0; /* XXXuninit */
1226 
1227           switch (BPF_CLASS(pc->code)) {
1228           default:
1229                     rv = false;
1230                     break;
1231 
1232           case BPF_LD:
1233                     rv = BPF_MODE(pc->code) == BPF_ABS ||
1234                          BPF_MODE(pc->code) == BPF_IND;
1235                     if (rv) {
1236                               width = read_width(pc);
1237                               rv = (width != 0);
1238                     }
1239                     break;
1240 
1241           case BPF_LDX:
1242                     rv = BPF_MODE(pc->code) == BPF_MSH &&
1243                          BPF_SIZE(pc->code) == BPF_B;
1244                     width = 1;
1245                     break;
1246           }
1247 
1248           if (rv && length != NULL) {
1249                     /*
1250                      * Values greater than UINT32_MAX will generate
1251                      * unconditional "return 0".
1252                      */
1253                     *length = (uint32_t)pc->k + width;
1254           }
1255 
1256           return rv;
1257 }
1258 
1259 static void
optimize_init(struct bpfjit_insn_data * insn_dat,size_t insn_count)1260 optimize_init(struct bpfjit_insn_data *insn_dat, size_t insn_count)
1261 {
1262           size_t i;
1263 
1264           for (i = 0; i < insn_count; i++) {
1265                     SLIST_INIT(&insn_dat[i].bjumps);
1266                     insn_dat[i].invalid = BJ_INIT_NOBITS;
1267           }
1268 }
1269 
/*
 * The function divides instructions into blocks. Destination of a jump
 * instruction starts a new block. BPF_RET and BPF_JMP instructions
 * terminate a block. Blocks are linear, that is, there are no jumps out
 * from the middle of a block and there are no jumps in to the middle of
 * a block.
 *
 * The function also sets bits in *initmask for memwords that
 * need to be initialized to zero. Note that this set should be empty
 * for any valid kernel filter program.
 *
 * In addition, the function marks unreachable instructions in insn_dat,
 * links every jump into the bjumps list of its destination(s), and
 * collects BJ_HINT_* bits in *hints.
 *
 * Returns false if a jump destination lies outside of the program,
 * true otherwise.
 */
static bool
optimize_pass1(const bpf_ctx_t *bc, const struct bpf_insn *insns,
    struct bpfjit_insn_data *insn_dat, size_t insn_count,
    bpf_memword_init_t *initmask, bpfjit_hint_t *hints)
{
	struct bpfjit_jump *jtf;
	size_t i;
	uint32_t jt, jf;
	bpfjit_abc_length_t length;
	bpf_memword_init_t invalid; /* borrowed from bpf_filter() */
	bool unreachable;

	const size_t memwords = GET_MEMWORDS(bc);

	*hints = 0;
	*initmask = BJ_INIT_NOBITS;

	/*
	 * "invalid" holds the A, X and memword bits that have not been
	 * written on the way to the current instruction.  A read of such
	 * a location adds its bit to *initmask; a write clears the bit.
	 */
	unreachable = false;
	invalid = ~BJ_INIT_NOBITS;

	for (i = 0; i < insn_count; i++) {
		/* A jump recorded earlier targets this instruction. */
		if (!SLIST_EMPTY(&insn_dat[i].bjumps))
			unreachable = false;
		insn_dat[i].unreachable = unreachable;

		if (unreachable)
			continue;

		/* Merge the state handed over by jumps to this instruction. */
		invalid |= insn_dat[i].invalid;

		/*
		 * A packet read whose length exceeds UINT32_MAX never
		 * falls through to the following instruction.
		 */
		if (read_pkt_insn(&insns[i], &length) && length > UINT32_MAX)
			unreachable = true;

		switch (BPF_CLASS(insns[i].code)) {
		case BPF_RET:
			/* Returning A reads A. */
			if (BPF_RVAL(insns[i].code) == BPF_A)
				*initmask |= invalid & BJ_INIT_ABIT;

			unreachable = true;
			continue;

		case BPF_LD:
			if (BPF_MODE(insns[i].code) == BPF_ABS)
				*hints |= BJ_HINT_ABS;

			/* Indirect loads read X. */
			if (BPF_MODE(insns[i].code) == BPF_IND) {
				*hints |= BJ_HINT_IND | BJ_HINT_XREG;
				*initmask |= invalid & BJ_INIT_XBIT;
			}

			/* Loads from an in-range memword read that memword. */
			if (BPF_MODE(insns[i].code) == BPF_MEM &&
			    (uint32_t)insns[i].k < memwords) {
				*initmask |= invalid & BJ_INIT_MBIT(insns[i].k);
			}

			/* A is written. */
			invalid &= ~BJ_INIT_ABIT;
			continue;

		case BPF_LDX:
			*hints |= BJ_HINT_XREG | BJ_HINT_LDX;

			/* Loads from an in-range memword read that memword. */
			if (BPF_MODE(insns[i].code) == BPF_MEM &&
			    (uint32_t)insns[i].k < memwords) {
				*initmask |= invalid & BJ_INIT_MBIT(insns[i].k);
			}

			if (BPF_MODE(insns[i].code) == BPF_MSH &&
			    BPF_SIZE(insns[i].code) == BPF_B) {
				*hints |= BJ_HINT_MSH;
			}

			/* X is written. */
			invalid &= ~BJ_INIT_XBIT;
			continue;

		case BPF_ST:
			/* A is read, an in-range memword is written. */
			*initmask |= invalid & BJ_INIT_ABIT;

			if ((uint32_t)insns[i].k < memwords)
				invalid &= ~BJ_INIT_MBIT(insns[i].k);

			continue;

		case BPF_STX:
			/* X is read, an in-range memword is written. */
			*hints |= BJ_HINT_XREG;
			*initmask |= invalid & BJ_INIT_XBIT;

			if ((uint32_t)insns[i].k < memwords)
				invalid &= ~BJ_INIT_MBIT(insns[i].k);

			continue;

		case BPF_ALU:
			/* A is read and written. */
			*initmask |= invalid & BJ_INIT_ABIT;

			/* Every BPF_X operation except BPF_NEG reads X too. */
			if (insns[i].code != (BPF_ALU|BPF_NEG) &&
			    BPF_SRC(insns[i].code) == BPF_X) {
				*hints |= BJ_HINT_XREG;
				*initmask |= invalid & BJ_INIT_XBIT;
			}

			invalid &= ~BJ_INIT_ABIT;
			continue;

		case BPF_MISC:
			switch (BPF_MISCOP(insns[i].code)) {
			case BPF_TAX: // X <- A
				*hints |= BJ_HINT_XREG;
				*initmask |= invalid & BJ_INIT_ABIT;
				invalid &= ~BJ_INIT_XBIT;
				continue;

			case BPF_TXA: // A <- X
				*hints |= BJ_HINT_XREG;
				*initmask |= invalid & BJ_INIT_XBIT;
				invalid &= ~BJ_INIT_ABIT;
				continue;

			case BPF_COPX:
				*hints |= BJ_HINT_XREG | BJ_HINT_COPX;
				/* FALLTHROUGH */

			case BPF_COP:
				/* A is read and then written. */
				*hints |= BJ_HINT_COP;
				*initmask |= invalid & BJ_INIT_ABIT;
				invalid &= ~BJ_INIT_ABIT;
				continue;
			}

			/* Any other BPF_MISC operation changes nothing here. */
			continue;

		case BPF_JMP:
			/* Initialize abc_length for ABC pass. */
			insn_dat[i].u.jdata.abc_length = MAX_ABC_LENGTH;

			/* Jumps are treated as reading A. */
			*initmask |= invalid & BJ_INIT_ABIT;

			if (BPF_SRC(insns[i].code) == BPF_X) {
				*hints |= BJ_HINT_XREG;
				*initmask |= invalid & BJ_INIT_XBIT;
			}

			/* BPF_JA has a single destination held in k. */
			if (BPF_OP(insns[i].code) == BPF_JA) {
				jt = jf = insns[i].k;
			} else {
				jt = insns[i].jt;
				jf = insns[i].jf;
			}

			/* Offsets are relative to the next instruction. */
			if (jt >= insn_count - (i + 1) ||
			    jf >= insn_count - (i + 1)) {
				return false;
			}

			/* Neither branch falls through to the next instruction. */
			if (jt > 0 && jf > 0)
				unreachable = true;

			/* Turn relative offsets into instruction indices. */
			jt += i + 1;
			jf += i + 1;

			jtf = insn_dat[i].u.jdata.jtf;

			/* Register this jump with its destination(s). */
			jtf[0].jdata = &insn_dat[i].u.jdata;
			SLIST_INSERT_HEAD(&insn_dat[jt].bjumps,
			    &jtf[0], entries);

			if (jf != jt) {
				jtf[1].jdata = &insn_dat[i].u.jdata;
				SLIST_INSERT_HEAD(&insn_dat[jf].bjumps,
				    &jtf[1], entries);
			}

			/*
			 * Hand the current state over to both destinations;
			 * they merge it back in when they are visited.
			 */
			insn_dat[jf].invalid |= invalid;
			insn_dat[jt].invalid |= invalid;
			invalid = 0;

			continue;
		}
	}

	return true;
}
1462 
1463 /*
1464  * Array Bounds Check Elimination (ABC) pass.
1465  */
1466 static void
optimize_pass2(const bpf_ctx_t * bc,const struct bpf_insn * insns,struct bpfjit_insn_data * insn_dat,size_t insn_count)1467 optimize_pass2(const bpf_ctx_t *bc, const struct bpf_insn *insns,
1468     struct bpfjit_insn_data *insn_dat, size_t insn_count)
1469 {
1470           struct bpfjit_jump *jmp;
1471           const struct bpf_insn *pc;
1472           struct bpfjit_insn_data *pd;
1473           size_t i;
1474           bpfjit_abc_length_t length, abc_length = 0;
1475 
1476           const size_t extwords = GET_EXTWORDS(bc);
1477 
1478           for (i = insn_count; i != 0; i--) {
1479                     pc = &insns[i-1];
1480                     pd = &insn_dat[i-1];
1481 
1482                     if (pd->unreachable)
1483                               continue;
1484 
1485                     switch (BPF_CLASS(pc->code)) {
1486                     case BPF_RET:
1487                               /*
1488                                * It's quite common for bpf programs to
1489                                * check packet bytes in increasing order
1490                                * and return zero if bytes don't match
1491                                * specified critetion. Such programs disable
1492                                * ABC optimization completely because for
1493                                * every jump there is a branch with no read
1494                                * instruction.
1495                                * With no side effects, BPF_STMT(BPF_RET+BPF_K, 0)
1496                                * is indistinguishable from out-of-bound load.
1497                                * Therefore, abc_length can be set to
1498                                * MAX_ABC_LENGTH and enable ABC for many
1499                                * bpf programs.
1500                                * If this optimization encounters any
1501                                * instruction with a side effect, it will
1502                                * reset abc_length.
1503                                */
1504                               if (BPF_RVAL(pc->code) == BPF_K && pc->k == 0)
1505                                         abc_length = MAX_ABC_LENGTH;
1506                               else
1507                                         abc_length = 0;
1508                               break;
1509 
1510                     case BPF_MISC:
1511                               if (BPF_MISCOP(pc->code) == BPF_COP ||
1512                                   BPF_MISCOP(pc->code) == BPF_COPX) {
1513                                         /* COP instructions can have side effects. */
1514                                         abc_length = 0;
1515                               }
1516                               break;
1517 
1518                     case BPF_ST:
1519                     case BPF_STX:
1520                               if (extwords != 0) {
1521                                         /* Write to memory is visible after a call. */
1522                                         abc_length = 0;
1523                               }
1524                               break;
1525 
1526                     case BPF_JMP:
1527                               abc_length = pd->u.jdata.abc_length;
1528                               break;
1529 
1530                     default:
1531                               if (read_pkt_insn(pc, &length)) {
1532                                         if (abc_length < length)
1533                                                   abc_length = length;
1534                                         pd->u.rdata.abc_length = abc_length;
1535                               }
1536                               break;
1537                     }
1538 
1539                     SLIST_FOREACH(jmp, &pd->bjumps, entries) {
1540                               if (jmp->jdata->abc_length > abc_length)
1541                                         jmp->jdata->abc_length = abc_length;
1542                     }
1543           }
1544 }
1545 
1546 static void
optimize_pass3(const struct bpf_insn * insns,struct bpfjit_insn_data * insn_dat,size_t insn_count)1547 optimize_pass3(const struct bpf_insn *insns,
1548     struct bpfjit_insn_data *insn_dat, size_t insn_count)
1549 {
1550           struct bpfjit_jump *jmp;
1551           size_t i;
1552           bpfjit_abc_length_t checked_length = 0;
1553 
1554           for (i = 0; i < insn_count; i++) {
1555                     if (insn_dat[i].unreachable)
1556                               continue;
1557 
1558                     SLIST_FOREACH(jmp, &insn_dat[i].bjumps, entries) {
1559                               if (jmp->jdata->checked_length < checked_length)
1560                                         checked_length = jmp->jdata->checked_length;
1561                     }
1562 
1563                     if (BPF_CLASS(insns[i].code) == BPF_JMP) {
1564                               insn_dat[i].u.jdata.checked_length = checked_length;
1565                     } else if (read_pkt_insn(&insns[i], NULL)) {
1566                               struct bpfjit_read_pkt_data *rdata =
1567                                   &insn_dat[i].u.rdata;
1568                               rdata->check_length = 0;
1569                               if (checked_length < rdata->abc_length) {
1570                                         checked_length = rdata->abc_length;
1571                                         rdata->check_length = checked_length;
1572                               }
1573                     }
1574           }
1575 }
1576 
1577 static bool
optimize(const bpf_ctx_t * bc,const struct bpf_insn * insns,struct bpfjit_insn_data * insn_dat,size_t insn_count,bpf_memword_init_t * initmask,bpfjit_hint_t * hints)1578 optimize(const bpf_ctx_t *bc, const struct bpf_insn *insns,
1579     struct bpfjit_insn_data *insn_dat, size_t insn_count,
1580     bpf_memword_init_t *initmask, bpfjit_hint_t *hints)
1581 {
1582 
1583           optimize_init(insn_dat, insn_count);
1584 
1585           if (!optimize_pass1(bc, insns, insn_dat, insn_count, initmask, hints))
1586                     return false;
1587 
1588           optimize_pass2(bc, insns, insn_dat, insn_count);
1589           optimize_pass3(insns, insn_dat, insn_count);
1590 
1591           return true;
1592 }
1593 
1594 /*
1595  * Convert BPF_ALU operations except BPF_NEG and BPF_DIV to sljit operation.
1596  */
1597 static bool
alu_to_op(const struct bpf_insn * pc,int * res)1598 alu_to_op(const struct bpf_insn *pc, int *res)
1599 {
1600           const uint32_t k = pc->k;
1601 
1602           /*
1603            * Note: all supported 64bit arches have 32bit multiply
1604            * instruction so SLJIT_I32_OP doesn't have any overhead.
1605            */
1606           switch (BPF_OP(pc->code)) {
1607           case BPF_ADD:
1608                     *res = SLJIT_ADD;
1609                     return true;
1610           case BPF_SUB:
1611                     *res = SLJIT_SUB;
1612                     return true;
1613           case BPF_MUL:
1614                     *res = SLJIT_MUL|SLJIT_I32_OP;
1615                     return true;
1616           case BPF_OR:
1617                     *res = SLJIT_OR;
1618                     return true;
1619           case BPF_XOR:
1620                     *res = SLJIT_XOR;
1621                     return true;
1622           case BPF_AND:
1623                     *res = SLJIT_AND;
1624                     return true;
1625           case BPF_LSH:
1626                     *res = SLJIT_SHL;
1627                     return k < 32;
1628           case BPF_RSH:
1629                     *res = SLJIT_LSHR|SLJIT_I32_OP;
1630                     return k < 32;
1631           default:
1632                     return false;
1633           }
1634 }
1635 
1636 /*
1637  * Convert BPF_JMP operations except BPF_JA to sljit condition.
1638  */
1639 static bool
jmp_to_cond(const struct bpf_insn * pc,bool negate,int * res)1640 jmp_to_cond(const struct bpf_insn *pc, bool negate, int *res)
1641 {
1642 
1643           /*
1644            * Note: all supported 64bit arches have 32bit comparison
1645            * instructions so SLJIT_I32_OP doesn't have any overhead.
1646            */
1647           *res = SLJIT_I32_OP;
1648 
1649           switch (BPF_OP(pc->code)) {
1650           case BPF_JGT:
1651                     *res |= negate ? SLJIT_LESS_EQUAL : SLJIT_GREATER;
1652                     return true;
1653           case BPF_JGE:
1654                     *res |= negate ? SLJIT_LESS : SLJIT_GREATER_EQUAL;
1655                     return true;
1656           case BPF_JEQ:
1657                     *res |= negate ? SLJIT_NOT_EQUAL : SLJIT_EQUAL;
1658                     return true;
1659           case BPF_JSET:
1660                     *res |= negate ? SLJIT_EQUAL : SLJIT_NOT_EQUAL;
1661                     return true;
1662           default:
1663                     return false;
1664           }
1665 }
1666 
1667 /*
1668  * Convert BPF_K and BPF_X to sljit register.
1669  */
1670 static int
kx_to_reg(const struct bpf_insn * pc)1671 kx_to_reg(const struct bpf_insn *pc)
1672 {
1673 
1674           switch (BPF_SRC(pc->code)) {
1675           case BPF_K: return SLJIT_IMM;
1676           case BPF_X: return BJ_XREG;
1677           default:
1678                     BJ_ASSERT(false);
1679                     return 0;
1680           }
1681 }
1682 
1683 static sljit_sw
kx_to_reg_arg(const struct bpf_insn * pc)1684 kx_to_reg_arg(const struct bpf_insn *pc)
1685 {
1686 
1687           switch (BPF_SRC(pc->code)) {
1688           case BPF_K: return (uint32_t)pc->k; /* SLJIT_IMM, pc->k, */
1689           case BPF_X: return 0;               /* BJ_XREG, 0,      */
1690           default:
1691                     BJ_ASSERT(false);
1692                     return 0;
1693           }
1694 }
1695 
1696 static bool
generate_insn_code(struct sljit_compiler * compiler,bpfjit_hint_t hints,const bpf_ctx_t * bc,const struct bpf_insn * insns,struct bpfjit_insn_data * insn_dat,size_t insn_count)1697 generate_insn_code(struct sljit_compiler *compiler, bpfjit_hint_t hints,
1698     const bpf_ctx_t *bc, const struct bpf_insn *insns,
1699     struct bpfjit_insn_data *insn_dat, size_t insn_count)
1700 {
1701           /* a list of jumps to out-of-bound return from a generated function */
1702           struct sljit_jump **ret0;
1703           size_t ret0_size, ret0_maxsize;
1704 
1705           struct sljit_jump *jump;
1706           struct sljit_label *label;
1707           const struct bpf_insn *pc;
1708           struct bpfjit_jump *bjump, *jtf;
1709           struct sljit_jump *to_mchain_jump;
1710 
1711           size_t i;
1712           unsigned int rval, mode, src, op;
1713           int branching, negate;
1714           int status, cond, op2;
1715           uint32_t jt, jf;
1716 
1717           bool unconditional_ret;
1718           bool rv;
1719 
1720           const size_t extwords = GET_EXTWORDS(bc);
1721           const size_t memwords = GET_MEMWORDS(bc);
1722 
1723           ret0 = NULL;
1724           rv = false;
1725 
1726           ret0_size = 0;
1727           ret0_maxsize = 64;
1728           ret0 = BJ_ALLOC(ret0_maxsize * sizeof(ret0[0]));
1729           if (ret0 == NULL)
1730                     goto fail;
1731 
1732           /* reset sjump members of jdata */
1733           for (i = 0; i < insn_count; i++) {
1734                     if (insn_dat[i].unreachable ||
1735                         BPF_CLASS(insns[i].code) != BPF_JMP) {
1736                               continue;
1737                     }
1738 
1739                     jtf = insn_dat[i].u.jdata.jtf;
1740                     jtf[0].sjump = jtf[1].sjump = NULL;
1741           }
1742 
1743           /* main loop */
1744           for (i = 0; i < insn_count; i++) {
1745                     if (insn_dat[i].unreachable)
1746                               continue;
1747 
1748                     /*
1749                      * Resolve jumps to the current insn.
1750                      */
1751                     label = NULL;
1752                     SLIST_FOREACH(bjump, &insn_dat[i].bjumps, entries) {
1753                               if (bjump->sjump != NULL) {
1754                                         if (label == NULL)
1755                                                   label = sljit_emit_label(compiler);
1756                                         if (label == NULL)
1757                                                   goto fail;
1758                                         sljit_set_label(bjump->sjump, label);
1759                               }
1760                     }
1761 
1762                     to_mchain_jump = NULL;
1763                     unconditional_ret = false;
1764 
1765                     if (read_pkt_insn(&insns[i], NULL)) {
1766                               if (insn_dat[i].u.rdata.check_length > UINT32_MAX) {
1767                                         /* Jump to "return 0" unconditionally. */
1768                                         unconditional_ret = true;
1769                                         jump = sljit_emit_jump(compiler, SLJIT_JUMP);
1770                                         if (jump == NULL)
1771                                                   goto fail;
1772                                         if (!append_jump(jump, &ret0,
1773                                             &ret0_size, &ret0_maxsize))
1774                                                   goto fail;
1775                               } else if (insn_dat[i].u.rdata.check_length > 0) {
1776                                         /* if (buflen < check_length) return 0; */
1777                                         jump = sljit_emit_cmp(compiler,
1778                                             SLJIT_LESS,
1779                                             BJ_BUFLEN, 0,
1780                                             SLJIT_IMM,
1781                                             insn_dat[i].u.rdata.check_length);
1782                                         if (jump == NULL)
1783                                                   goto fail;
1784 #ifdef _KERNEL
1785                                         to_mchain_jump = jump;
1786 #else
1787                                         if (!append_jump(jump, &ret0,
1788                                             &ret0_size, &ret0_maxsize))
1789                                                   goto fail;
1790 #endif
1791                               }
1792                     }
1793 
1794                     pc = &insns[i];
1795                     switch (BPF_CLASS(pc->code)) {
1796 
1797                     default:
1798                               goto fail;
1799 
1800                     case BPF_LD:
1801                               /* BPF_LD+BPF_IMM          A <- k */
1802                               if (pc->code == (BPF_LD|BPF_IMM)) {
1803                                         status = sljit_emit_op1(compiler,
1804                                             SLJIT_MOV,
1805                                             BJ_AREG, 0,
1806                                             SLJIT_IMM, (uint32_t)pc->k);
1807                                         if (status != SLJIT_SUCCESS)
1808                                                   goto fail;
1809 
1810                                         continue;
1811                               }
1812 
1813                               /* BPF_LD+BPF_MEM          A <- M[k] */
1814                               if (pc->code == (BPF_LD|BPF_MEM)) {
1815                                         if ((uint32_t)pc->k >= memwords)
1816                                                   goto fail;
1817                                         status = emit_memload(compiler,
1818                                             BJ_AREG, pc->k, extwords);
1819                                         if (status != SLJIT_SUCCESS)
1820                                                   goto fail;
1821 
1822                                         continue;
1823                               }
1824 
1825                               /* BPF_LD+BPF_W+BPF_LEN    A <- len */
1826                               if (pc->code == (BPF_LD|BPF_W|BPF_LEN)) {
1827                                         status = sljit_emit_op1(compiler,
1828                                             SLJIT_MOV, /* size_t source */
1829                                             BJ_AREG, 0,
1830                                             SLJIT_MEM1(BJ_ARGS),
1831                                             offsetof(struct bpf_args, wirelen));
1832                                         if (status != SLJIT_SUCCESS)
1833                                                   goto fail;
1834 
1835                                         continue;
1836                               }
1837 
1838                               mode = BPF_MODE(pc->code);
1839                               if (mode != BPF_ABS && mode != BPF_IND)
1840                                         goto fail;
1841 
1842                               if (unconditional_ret)
1843                                         continue;
1844 
1845                               status = emit_pkt_read(compiler, hints, pc,
1846                                   to_mchain_jump, &ret0, &ret0_size, &ret0_maxsize);
1847                               if (status != SLJIT_SUCCESS)
1848                                         goto fail;
1849 
1850                               continue;
1851 
1852                     case BPF_LDX:
1853                               mode = BPF_MODE(pc->code);
1854 
1855                               /* BPF_LDX+BPF_W+BPF_IMM    X <- k */
1856                               if (mode == BPF_IMM) {
1857                                         if (BPF_SIZE(pc->code) != BPF_W)
1858                                                   goto fail;
1859                                         status = sljit_emit_op1(compiler,
1860                                             SLJIT_MOV,
1861                                             BJ_XREG, 0,
1862                                             SLJIT_IMM, (uint32_t)pc->k);
1863                                         if (status != SLJIT_SUCCESS)
1864                                                   goto fail;
1865 
1866                                         continue;
1867                               }
1868 
1869                               /* BPF_LDX+BPF_W+BPF_LEN    X <- len */
1870                               if (mode == BPF_LEN) {
1871                                         if (BPF_SIZE(pc->code) != BPF_W)
1872                                                   goto fail;
1873                                         status = sljit_emit_op1(compiler,
1874                                             SLJIT_MOV, /* size_t source */
1875                                             BJ_XREG, 0,
1876                                             SLJIT_MEM1(BJ_ARGS),
1877                                             offsetof(struct bpf_args, wirelen));
1878                                         if (status != SLJIT_SUCCESS)
1879                                                   goto fail;
1880 
1881                                         continue;
1882                               }
1883 
1884                               /* BPF_LDX+BPF_W+BPF_MEM    X <- M[k] */
1885                               if (mode == BPF_MEM) {
1886                                         if (BPF_SIZE(pc->code) != BPF_W)
1887                                                   goto fail;
1888                                         if ((uint32_t)pc->k >= memwords)
1889                                                   goto fail;
1890                                         status = emit_memload(compiler,
1891                                             BJ_XREG, pc->k, extwords);
1892                                         if (status != SLJIT_SUCCESS)
1893                                                   goto fail;
1894 
1895                                         continue;
1896                               }
1897 
1898                               /* BPF_LDX+BPF_B+BPF_MSH    X <- 4*(P[k:1]&0xf) */
1899                               if (mode != BPF_MSH || BPF_SIZE(pc->code) != BPF_B)
1900                                         goto fail;
1901 
1902                               if (unconditional_ret)
1903                                         continue;
1904 
1905                               status = emit_msh(compiler, hints, pc,
1906                                   to_mchain_jump, &ret0, &ret0_size, &ret0_maxsize);
1907                               if (status != SLJIT_SUCCESS)
1908                                         goto fail;
1909 
1910                               continue;
1911 
1912                     case BPF_ST:
1913                               if (pc->code != BPF_ST ||
1914                                   (uint32_t)pc->k >= memwords) {
1915                                         goto fail;
1916                               }
1917 
1918                               status = emit_memstore(compiler,
1919                                   BJ_AREG, pc->k, extwords);
1920                               if (status != SLJIT_SUCCESS)
1921                                         goto fail;
1922 
1923                               continue;
1924 
1925                     case BPF_STX:
1926                               if (pc->code != BPF_STX ||
1927                                   (uint32_t)pc->k >= memwords) {
1928                                         goto fail;
1929                               }
1930 
1931                               status = emit_memstore(compiler,
1932                                   BJ_XREG, pc->k, extwords);
1933                               if (status != SLJIT_SUCCESS)
1934                                         goto fail;
1935 
1936                               continue;
1937 
1938                     case BPF_ALU:
1939                               if (pc->code == (BPF_ALU|BPF_NEG)) {
1940                                         status = sljit_emit_op1(compiler,
1941                                             SLJIT_NEG,
1942                                             BJ_AREG, 0,
1943                                             BJ_AREG, 0);
1944                                         if (status != SLJIT_SUCCESS)
1945                                                   goto fail;
1946 
1947                                         continue;
1948                               }
1949 
1950                               op = BPF_OP(pc->code);
1951                               if (op != BPF_DIV && op != BPF_MOD) {
1952                                         if (!alu_to_op(pc, &op2))
1953                                                   goto fail;
1954 
1955                                         status = sljit_emit_op2(compiler,
1956                                             op2, BJ_AREG, 0, BJ_AREG, 0,
1957                                             kx_to_reg(pc), kx_to_reg_arg(pc));
1958                                         if (status != SLJIT_SUCCESS)
1959                                                   goto fail;
1960 
1961                                         continue;
1962                               }
1963 
1964                               /* BPF_DIV/BPF_MOD */
1965 
1966                               src = BPF_SRC(pc->code);
1967                               if (src != BPF_X && src != BPF_K)
1968                                         goto fail;
1969 
1970                               /* division by zero? */
1971                               if (src == BPF_X) {
1972                                         jump = sljit_emit_cmp(compiler,
1973                                             SLJIT_EQUAL|SLJIT_I32_OP,
1974                                             BJ_XREG, 0,
1975                                             SLJIT_IMM, 0);
1976                                         if (jump == NULL)
1977                                                   goto fail;
1978                                         if (!append_jump(jump, &ret0,
1979                                             &ret0_size, &ret0_maxsize))
1980                                                   goto fail;
1981                               } else if (pc->k == 0) {
1982                                         jump = sljit_emit_jump(compiler, SLJIT_JUMP);
1983                                         if (jump == NULL)
1984                                                   goto fail;
1985                                         if (!append_jump(jump, &ret0,
1986                                             &ret0_size, &ret0_maxsize))
1987                                                   goto fail;
1988                               }
1989 
1990                               if (src == BPF_X) {
1991                                         status = emit_moddiv(compiler, pc);
1992                                         if (status != SLJIT_SUCCESS)
1993                                                   goto fail;
1994                               } else if (pc->k != 0) {
1995                                         if (pc->k & (pc->k - 1)) {
1996                                                   status = emit_moddiv(compiler, pc);
1997                                         } else {
1998                                                   status = emit_pow2_moddiv(compiler, pc);
1999                                         }
2000                                         if (status != SLJIT_SUCCESS)
2001                                                   goto fail;
2002                               }
2003 
2004                               continue;
2005 
2006                     case BPF_JMP:
2007                               op = BPF_OP(pc->code);
2008                               if (op == BPF_JA) {
2009                                         jt = jf = pc->k;
2010                               } else {
2011                                         jt = pc->jt;
2012                                         jf = pc->jf;
2013                               }
2014 
2015                               negate = (jt == 0) ? 1 : 0;
2016                               branching = (jt == jf) ? 0 : 1;
2017                               jtf = insn_dat[i].u.jdata.jtf;
2018 
2019                               if (branching) {
2020                                         if (op != BPF_JSET) {
2021                                                   if (!jmp_to_cond(pc, negate, &cond))
2022                                                             goto fail;
2023                                                   jump = sljit_emit_cmp(compiler,
2024                                                       cond, BJ_AREG, 0,
2025                                                       kx_to_reg(pc), kx_to_reg_arg(pc));
2026                                         } else {
2027                                                   status = sljit_emit_op2(compiler,
2028                                                       SLJIT_AND,
2029                                                       BJ_TMP1REG, 0,
2030                                                       BJ_AREG, 0,
2031                                                       kx_to_reg(pc), kx_to_reg_arg(pc));
2032                                                   if (status != SLJIT_SUCCESS)
2033                                                             goto fail;
2034 
2035                                                   if (!jmp_to_cond(pc, negate, &cond))
2036                                                             goto fail;
2037                                                   jump = sljit_emit_cmp(compiler,
2038                                                       cond, BJ_TMP1REG, 0, SLJIT_IMM, 0);
2039                                         }
2040 
2041                                         if (jump == NULL)
2042                                                   goto fail;
2043 
2044                                         BJ_ASSERT(jtf[negate].sjump == NULL);
2045                                         jtf[negate].sjump = jump;
2046                               }
2047 
2048                               if (!branching || (jt != 0 && jf != 0)) {
2049                                         jump = sljit_emit_jump(compiler, SLJIT_JUMP);
2050                                         if (jump == NULL)
2051                                                   goto fail;
2052 
2053                                         BJ_ASSERT(jtf[branching].sjump == NULL);
2054                                         jtf[branching].sjump = jump;
2055                               }
2056 
2057                               continue;
2058 
2059                     case BPF_RET:
2060                               rval = BPF_RVAL(pc->code);
2061                               if (rval == BPF_X)
2062                                         goto fail;
2063 
2064                               /* BPF_RET+BPF_K    accept k bytes */
2065                               if (rval == BPF_K) {
2066                                         status = sljit_emit_return(compiler,
2067                                             SLJIT_MOV_U32,
2068                                             SLJIT_IMM, (uint32_t)pc->k);
2069                                         if (status != SLJIT_SUCCESS)
2070                                                   goto fail;
2071                               }
2072 
2073                               /* BPF_RET+BPF_A    accept A bytes */
2074                               if (rval == BPF_A) {
2075                                         status = sljit_emit_return(compiler,
2076                                             SLJIT_MOV_U32,
2077                                             BJ_AREG, 0);
2078                                         if (status != SLJIT_SUCCESS)
2079                                                   goto fail;
2080                               }
2081 
2082                               continue;
2083 
2084                     case BPF_MISC:
2085                               switch (BPF_MISCOP(pc->code)) {
2086                               case BPF_TAX:
2087                                         status = sljit_emit_op1(compiler,
2088                                             SLJIT_MOV_U32,
2089                                             BJ_XREG, 0,
2090                                             BJ_AREG, 0);
2091                                         if (status != SLJIT_SUCCESS)
2092                                                   goto fail;
2093 
2094                                         continue;
2095 
2096                               case BPF_TXA:
2097                                         status = sljit_emit_op1(compiler,
2098                                             SLJIT_MOV,
2099                                             BJ_AREG, 0,
2100                                             BJ_XREG, 0);
2101                                         if (status != SLJIT_SUCCESS)
2102                                                   goto fail;
2103 
2104                                         continue;
2105 
2106                               case BPF_COP:
2107                               case BPF_COPX:
2108                                         if (bc == NULL || bc->copfuncs == NULL)
2109                                                   goto fail;
2110                                         if (BPF_MISCOP(pc->code) == BPF_COP &&
2111                                             (uint32_t)pc->k >= bc->nfuncs) {
2112                                                   goto fail;
2113                                         }
2114 
2115                                         status = emit_cop(compiler, hints, bc, pc,
2116                                             &ret0, &ret0_size, &ret0_maxsize);
2117                                         if (status != SLJIT_SUCCESS)
2118                                                   goto fail;
2119 
2120                                         continue;
2121                               }
2122 
2123                               goto fail;
2124                     } /* switch */
2125           } /* main loop */
2126 
2127           BJ_ASSERT(ret0_size <= ret0_maxsize);
2128 
2129           if (ret0_size > 0) {
2130                     label = sljit_emit_label(compiler);
2131                     if (label == NULL)
2132                               goto fail;
2133                     for (i = 0; i < ret0_size; i++)
2134                               sljit_set_label(ret0[i], label);
2135           }
2136 
2137           status = sljit_emit_return(compiler,
2138               SLJIT_MOV_U32,
2139               SLJIT_IMM, 0);
2140           if (status != SLJIT_SUCCESS)
2141                     goto fail;
2142 
2143           rv = true;
2144 
2145 fail:
2146           if (ret0 != NULL)
2147                     BJ_FREE(ret0, ret0_maxsize * sizeof(ret0[0]));
2148 
2149           return rv;
2150 }
2151 
2152 bpfjit_func_t
bpfjit_generate_code(const bpf_ctx_t * bc,const struct bpf_insn * insns,size_t insn_count)2153 bpfjit_generate_code(const bpf_ctx_t *bc,
2154     const struct bpf_insn *insns, size_t insn_count)
2155 {
2156           void *rv;
2157           struct sljit_compiler *compiler;
2158 
2159           size_t i;
2160           int status;
2161 
2162           /* optimization related */
2163           bpf_memword_init_t initmask;
2164           bpfjit_hint_t hints;
2165 
2166           /* memory store location for initial zero initialization */
2167           sljit_s32 mem_reg;
2168           sljit_sw mem_off;
2169 
2170           struct bpfjit_insn_data *insn_dat;
2171 
2172           const size_t extwords = GET_EXTWORDS(bc);
2173           const size_t memwords = GET_MEMWORDS(bc);
2174           const bpf_memword_init_t preinited = extwords ? bc->preinited : 0;
2175 
2176           rv = NULL;
2177           compiler = NULL;
2178           insn_dat = NULL;
2179 
2180           if (memwords > MAX_MEMWORDS)
2181                     goto fail;
2182 
2183           if (insn_count == 0 || insn_count > SIZE_MAX / sizeof(insn_dat[0]))
2184                     goto fail;
2185 
2186           insn_dat = BJ_ALLOC(insn_count * sizeof(insn_dat[0]));
2187           if (insn_dat == NULL)
2188                     goto fail;
2189 
2190           if (!optimize(bc, insns, insn_dat, insn_count, &initmask, &hints))
2191                     goto fail;
2192 
2193           compiler = sljit_create_compiler(NULL);
2194           if (compiler == NULL)
2195                     goto fail;
2196 
2197 #if !defined(_KERNEL) && defined(SLJIT_VERBOSE) && SLJIT_VERBOSE
2198           sljit_compiler_verbose(compiler, stderr);
2199 #endif
2200 
2201           status = sljit_emit_enter(compiler, 0, 2, nscratches(hints),
2202               NSAVEDS, 0, 0, sizeof(struct bpfjit_stack));
2203           if (status != SLJIT_SUCCESS)
2204                     goto fail;
2205 
2206           if (hints & BJ_HINT_COP) {
2207                     /* save ctx argument */
2208                     status = sljit_emit_op1(compiler,
2209                         SLJIT_MOV_P,
2210                         SLJIT_MEM1(SLJIT_SP),
2211                         offsetof(struct bpfjit_stack, ctx),
2212                         BJ_CTX_ARG, 0);
2213                     if (status != SLJIT_SUCCESS)
2214                               goto fail;
2215           }
2216 
2217           if (extwords == 0) {
2218                     mem_reg = SLJIT_MEM1(SLJIT_SP);
2219                     mem_off = offsetof(struct bpfjit_stack, mem);
2220           } else {
2221                     /* copy "mem" argument from bpf_args to bpfjit_stack */
2222                     status = sljit_emit_op1(compiler,
2223                         SLJIT_MOV_P,
2224                         BJ_TMP1REG, 0,
2225                         SLJIT_MEM1(BJ_ARGS), offsetof(struct bpf_args, mem));
2226                     if (status != SLJIT_SUCCESS)
2227                               goto fail;
2228 
2229                     status = sljit_emit_op1(compiler,
2230                         SLJIT_MOV_P,
2231                         SLJIT_MEM1(SLJIT_SP),
2232                         offsetof(struct bpfjit_stack, extmem),
2233                         BJ_TMP1REG, 0);
2234                     if (status != SLJIT_SUCCESS)
2235                               goto fail;
2236 
2237                     mem_reg = SLJIT_MEM1(BJ_TMP1REG);
2238                     mem_off = 0;
2239           }
2240 
2241           /*
2242            * Exclude pre-initialised external memory words but keep
2243            * initialization statuses of A and X registers in case
2244            * bc->preinited wrongly sets those two bits.
2245            */
2246           initmask &= ~preinited | BJ_INIT_ABIT | BJ_INIT_XBIT;
2247 
2248 #if defined(_KERNEL)
2249           /* bpf_filter() checks initialization of memwords. */
2250           BJ_ASSERT((initmask & (BJ_INIT_MBIT(memwords) - 1)) == 0);
2251 #endif
2252           for (i = 0; i < memwords; i++) {
2253                     if (initmask & BJ_INIT_MBIT(i)) {
2254                               /* M[i] = 0; */
2255                               status = sljit_emit_op1(compiler,
2256                                   SLJIT_MOV_U32,
2257                                   mem_reg, mem_off + i * sizeof(uint32_t),
2258                                   SLJIT_IMM, 0);
2259                               if (status != SLJIT_SUCCESS)
2260                                         goto fail;
2261                     }
2262           }
2263 
2264           if (initmask & BJ_INIT_ABIT) {
2265                     /* A = 0; */
2266                     status = sljit_emit_op1(compiler,
2267                         SLJIT_MOV,
2268                         BJ_AREG, 0,
2269                         SLJIT_IMM, 0);
2270                     if (status != SLJIT_SUCCESS)
2271                               goto fail;
2272           }
2273 
2274           if (initmask & BJ_INIT_XBIT) {
2275                     /* X = 0; */
2276                     status = sljit_emit_op1(compiler,
2277                         SLJIT_MOV,
2278                         BJ_XREG, 0,
2279                         SLJIT_IMM, 0);
2280                     if (status != SLJIT_SUCCESS)
2281                               goto fail;
2282           }
2283 
2284           status = load_buf_buflen(compiler);
2285           if (status != SLJIT_SUCCESS)
2286                     goto fail;
2287 
2288           if (!generate_insn_code(compiler, hints,
2289               bc, insns, insn_dat, insn_count)) {
2290                     goto fail;
2291           }
2292 
2293           rv = sljit_generate_code(compiler);
2294 
2295 fail:
2296           if (compiler != NULL)
2297                     sljit_free_compiler(compiler);
2298 
2299           if (insn_dat != NULL)
2300                     BJ_FREE(insn_dat, insn_count * sizeof(insn_dat[0]));
2301 
2302           return (bpfjit_func_t)rv;
2303 }
2304 
2305 void
bpfjit_free_code(bpfjit_func_t code)2306 bpfjit_free_code(bpfjit_func_t code)
2307 {
2308 
2309           sljit_free_code((void *)code);
2310 }
2311