Diffstat (limited to 'arch/x86/net/bpf_jit_comp.c')
 arch/x86/net/bpf_jit_comp.c | 234 +++++++++++++++++++++++++++----------------
 1 file changed, 134 insertions(+), 100 deletions(-)
diff --git a/arch/x86/net/bpf_jit_comp.c b/arch/x86/net/bpf_jit_comp.c
index ce5b2ebd5701..b725154182cc 100644
--- a/arch/x86/net/bpf_jit_comp.c
+++ b/arch/x86/net/bpf_jit_comp.c
@@ -11,10 +11,10 @@
 #include <linux/netdevice.h>
 #include <linux/filter.h>
 #include <linux/if_vlan.h>
-#include <asm/cacheflush.h>
+#include <linux/bpf.h>
+
 #include <asm/set_memory.h>
 #include <asm/nospec-branch.h>
-#include <linux/bpf.h>
 
 /*
  * assembly code in arch/x86/net/bpf_jit.S
@@ -61,7 +61,12 @@ static bool is_imm8(int value)
 
 static bool is_simm32(s64 value)
 {
-	return value == (s64) (s32) value;
+	return value == (s64)(s32)value;
+}
+
+static bool is_uimm32(u64 value)
+{
+	return value == (u64)(u32)value;
 }
 
 /* mov dst, src */
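
For illustration, a minimal user-space sketch (not kernel code; <stdint.h>
types in place of the kernel's s64/u32) of what the two predicates accept:
is_simm32() holds when a value survives sign extension from 32 bits,
is_uimm32() when it survives zero extension.

#include <stdint.h>
#include <stdio.h>

static int is_simm32(int64_t v)  { return v == (int64_t)(int32_t)v; }
static int is_uimm32(uint64_t v) { return v == (uint64_t)(uint32_t)v; }

int main(void)
{
	/* -1 sign-extends cleanly, but 0xffffffffffffffff is no uimm32 */
	printf("-1:         simm32=%d uimm32=%d\n",
	       is_simm32(-1), is_uimm32((uint64_t)-1));
	/* 0xffffffff zero-extends cleanly, but is no simm32 */
	printf("0xffffffff: simm32=%d uimm32=%d\n",
	       is_simm32(0xffffffffLL), is_uimm32(0xffffffffULL));
	return 0;	/* prints 1/0 and 0/1 respectively */
}
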
@@ -98,16 +103,6 @@ static int bpf_size_to_x86_bytes(int bpf_size)
 #define X86_JLE 0x7E
 #define X86_JG  0x7F
 
-static void bpf_flush_icache(void *start, void *end)
-{
-	mm_segment_t old_fs = get_fs();
-
-	set_fs(KERNEL_DS);
-	smp_wmb();
-	flush_icache_range((unsigned long)start, (unsigned long)end);
-	set_fs(old_fs);
-}
-
 #define CHOOSE_LOAD_FUNC(K, func) \
 	((int)K < 0 ? ((int)K >= SKF_LL_OFF ? func##_negative_offset : func) : func##_positive_offset)
 
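
(Dropping bpf_flush_icache() needs no replacement: on x86 the instruction
cache is kept coherent by the hardware, so flush_icache_range() is a no-op
on this architecture, and the set_fs() dance around it was pure overhead.)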
@@ -212,7 +207,7 @@ struct jit_context {
 /* emit x64 prologue code for BPF program and check its size.
  * bpf_tail_call helper will skip it while jumping into another program
  */
-static void emit_prologue(u8 **pprog, u32 stack_depth)
+static void emit_prologue(u8 **pprog, u32 stack_depth, bool ebpf_from_cbpf)
 {
 	u8 *prog = *pprog;
 	int cnt = 0;
@@ -247,18 +242,21 @@ static void emit_prologue(u8 **pprog, u32 stack_depth)
 	/* mov qword ptr [rbp+24],r15 */
 	EMIT4(0x4C, 0x89, 0x7D, 24);
 
-	/* Clear the tail call counter (tail_call_cnt): for eBPF tail calls
-	 * we need to reset the counter to 0. It's done in two instructions,
-	 * resetting rax register to 0 (xor on eax gets 0 extended), and
-	 * moving it to the counter location.
-	 */
+	if (!ebpf_from_cbpf) {
+		/* Clear the tail call counter (tail_call_cnt): for eBPF tail
+		 * calls we need to reset the counter to 0. It's done in two
+		 * instructions, resetting rax register to 0, and moving it
+		 * to the counter location.
+		 */
 
-	/* xor eax, eax */
-	EMIT2(0x31, 0xc0);
-	/* mov qword ptr [rbp+32], rax */
-	EMIT4(0x48, 0x89, 0x45, 32);
+		/* xor eax, eax */
+		EMIT2(0x31, 0xc0);
+		/* mov qword ptr [rbp+32], rax */
+		EMIT4(0x48, 0x89, 0x45, 32);
+
+		BUILD_BUG_ON(cnt != PROLOGUE_SIZE);
+	}
 
-	BUILD_BUG_ON(cnt != PROLOGUE_SIZE);
 	*pprog = prog;
 }
 
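(The reasoning behind the new guard: programs translated from classic BPF
can never contain bpf_tail_call(), so their prologue need not zero
tail_call_cnt; skipping the xor (2 bytes, EMIT2) plus the store (4 bytes,
EMIT4) saves 6 bytes per cBPF-derived program. The BUILD_BUG_ON() moves
inside the branch because only eBPF programs can be tail-call targets, so
only they must keep the fixed PROLOGUE_SIZE that bpf_tail_call() jumps
over.)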
@@ -356,6 +354,86 @@ static void emit_load_skb_data_hlen(u8 **pprog)
 	*pprog = prog;
 }
 
+static void emit_mov_imm32(u8 **pprog, bool sign_propagate,
+			   u32 dst_reg, const u32 imm32)
+{
+	u8 *prog = *pprog;
+	u8 b1, b2, b3;
+	int cnt = 0;
+
+	/* optimization: if imm32 is positive, use 'mov %eax, imm32'
+	 * (which zero-extends imm32) to save 2 bytes.
+	 */
+	if (sign_propagate && (s32)imm32 < 0) {
+		/* 'mov %rax, imm32' sign extends imm32 */
+		b1 = add_1mod(0x48, dst_reg);
+		b2 = 0xC7;
+		b3 = 0xC0;
+		EMIT3_off32(b1, b2, add_1reg(b3, dst_reg), imm32);
+		goto done;
+	}
+
+	/* optimization: if imm32 is zero, use 'xor %eax, %eax'
+	 * to save 3 bytes.
+	 */
+	if (imm32 == 0) {
+		if (is_ereg(dst_reg))
+			EMIT1(add_2mod(0x40, dst_reg, dst_reg));
+		b2 = 0x31; /* xor */
+		b3 = 0xC0;
+		EMIT2(b2, add_2reg(b3, dst_reg, dst_reg));
+		goto done;
+	}
+
+	/* mov %eax, imm32 */
+	if (is_ereg(dst_reg))
+		EMIT1(add_1mod(0x40, dst_reg));
+	EMIT1_off32(add_1reg(0xB8, dst_reg), imm32);
+done:
+	*pprog = prog;
+}
+
+static void emit_mov_imm64(u8 **pprog, u32 dst_reg,
+			   const u32 imm32_hi, const u32 imm32_lo)
+{
+	u8 *prog = *pprog;
+	int cnt = 0;
+
+	if (is_uimm32(((u64)imm32_hi << 32) | (u32)imm32_lo)) {
+		/* For emitting plain u32, where sign bit must not be
+		 * propagated LLVM tends to load imm64 over mov32
+		 * directly, so save couple of bytes by just doing
+		 * 'mov %eax, imm32' instead.
+		 */
+		emit_mov_imm32(&prog, false, dst_reg, imm32_lo);
+	} else {
+		/* movabsq %rax, imm64 */
+		EMIT2(add_1mod(0x48, dst_reg), add_1reg(0xB8, dst_reg));
+		EMIT(imm32_lo, 4);
+		EMIT(imm32_hi, 4);
+	}
+
+	*pprog = prog;
+}
+
+static void emit_mov_reg(u8 **pprog, bool is64, u32 dst_reg, u32 src_reg)
+{
+	u8 *prog = *pprog;
+	int cnt = 0;
+
+	if (is64) {
+		/* mov dst, src */
+		EMIT_mov(dst_reg, src_reg);
+	} else {
+		/* mov32 dst, src */
+		if (is_ereg(dst_reg) || is_ereg(src_reg))
+			EMIT1(add_2mod(0x40, dst_reg, src_reg));
+		EMIT2(0x89, add_2reg(0xC0, dst_reg, src_reg));
+	}
+
+	*pprog = prog;
+}
+
 static int do_jit(struct bpf_prog *bpf_prog, int *addrs, u8 *image,
 		  int oldproglen, struct jit_context *ctx)
 {
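
To make the size trade-offs above concrete, here is a hedged user-space
sketch (rax only, standard C types, checks reordered and the sign_propagate
flag dropped for brevity; emit_load_rax is a hypothetical name, not the
kernel's emitter) of the shortest-encoding choice that
emit_mov_imm32()/emit_mov_imm64() make:

#include <stdint.h>
#include <stdio.h>

/* Returns the number of bytes emitted into buf for 'load imm into rax'. */
static int emit_load_rax(uint8_t *buf, uint64_t imm)
{
	int n = 0, i;

	if (imm == 0) {
		/* xor eax, eax: 2 bytes, upper half cleared for free */
		buf[n++] = 0x31; buf[n++] = 0xC0;
	} else if (imm == (uint64_t)(uint32_t)imm) {
		/* mov eax, imm32 (B8+rd): zero-extends, 5 bytes */
		buf[n++] = 0xB8;
		for (i = 0; i < 4; i++)
			buf[n++] = imm >> (8 * i);
	} else if (imm == (uint64_t)(int64_t)(int32_t)imm) {
		/* mov rax, imm32 (REX.W C7 /0): sign-extends, 7 bytes */
		buf[n++] = 0x48; buf[n++] = 0xC7; buf[n++] = 0xC0;
		for (i = 0; i < 4; i++)
			buf[n++] = imm >> (8 * i);
	} else {
		/* movabs rax, imm64 (REX.W B8+rd): full 10 bytes */
		buf[n++] = 0x48; buf[n++] = 0xB8;
		for (i = 0; i < 8; i++)
			buf[n++] = imm >> (8 * i);
	}
	return n;
}

int main(void)
{
	uint64_t samples[] = { 0, 1, 0xffffffffull, (uint64_t)-1, 1ull << 40 };
	uint8_t buf[16];
	unsigned i;

	for (i = 0; i < sizeof(samples) / sizeof(samples[0]); i++)
		printf("%#llx -> %d bytes\n",
		       (unsigned long long)samples[i],
		       emit_load_rax(buf, samples[i]));
	return 0;	/* 2, 5, 5, 7 and 10 bytes respectively */
}
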
@@ -369,7 +447,8 @@ static int do_jit(struct bpf_prog *bpf_prog, int *addrs, u8 *image,
 	int proglen = 0;
 	u8 *prog = temp;
 
-	emit_prologue(&prog, bpf_prog->aux->stack_depth);
+	emit_prologue(&prog, bpf_prog->aux->stack_depth,
+		      bpf_prog_was_classic(bpf_prog));
 
 	if (seen_ld_abs)
 		emit_load_skb_data_hlen(&prog);
@@ -378,7 +457,7 @@ static int do_jit(struct bpf_prog *bpf_prog, int *addrs, u8 *image,
 		const s32 imm32 = insn->imm;
 		u32 dst_reg = insn->dst_reg;
 		u32 src_reg = insn->src_reg;
-		u8 b1 = 0, b2 = 0, b3 = 0;
+		u8 b2 = 0, b3 = 0;
 		s64 jmp_offset;
 		u8 jmp_cond;
 		bool reload_skb_data;
@@ -414,16 +493,11 @@ static int do_jit(struct bpf_prog *bpf_prog, int *addrs, u8 *image,
 			EMIT2(b2, add_2reg(0xC0, dst_reg, src_reg));
 			break;
 
-			/* mov dst, src */
 		case BPF_ALU64 | BPF_MOV | BPF_X:
-			EMIT_mov(dst_reg, src_reg);
-			break;
-
-			/* mov32 dst, src */
 		case BPF_ALU | BPF_MOV | BPF_X:
-			if (is_ereg(dst_reg) || is_ereg(src_reg))
-				EMIT1(add_2mod(0x40, dst_reg, src_reg));
-			EMIT2(0x89, add_2reg(0xC0, dst_reg, src_reg));
+			emit_mov_reg(&prog,
+				     BPF_CLASS(insn->code) == BPF_ALU64,
+				     dst_reg, src_reg);
 			break;
 
 			/* neg dst */
@@ -486,58 +560,13 @@ static int do_jit(struct bpf_prog *bpf_prog, int *addrs, u8 *image,
 			break;
 
 		case BPF_ALU64 | BPF_MOV | BPF_K:
-			/* optimization: if imm32 is positive,
-			 * use 'mov eax, imm32' (which zero-extends imm32)
-			 * to save 2 bytes
-			 */
-			if (imm32 < 0) {
-				/* 'mov rax, imm32' sign extends imm32 */
-				b1 = add_1mod(0x48, dst_reg);
-				b2 = 0xC7;
-				b3 = 0xC0;
-				EMIT3_off32(b1, b2, add_1reg(b3, dst_reg), imm32);
-				break;
-			}
-
 		case BPF_ALU | BPF_MOV | BPF_K:
-			/* optimization: if imm32 is zero, use 'xor <dst>,<dst>'
-			 * to save 3 bytes.
-			 */
-			if (imm32 == 0) {
-				if (is_ereg(dst_reg))
-					EMIT1(add_2mod(0x40, dst_reg, dst_reg));
-				b2 = 0x31; /* xor */
-				b3 = 0xC0;
-				EMIT2(b2, add_2reg(b3, dst_reg, dst_reg));
-				break;
-			}
-
-			/* mov %eax, imm32 */
-			if (is_ereg(dst_reg))
-				EMIT1(add_1mod(0x40, dst_reg));
-			EMIT1_off32(add_1reg(0xB8, dst_reg), imm32);
+			emit_mov_imm32(&prog, BPF_CLASS(insn->code) == BPF_ALU64,
+				       dst_reg, imm32);
 			break;
 
 		case BPF_LD | BPF_IMM | BPF_DW:
-			/* optimization: if imm64 is zero, use 'xor <dst>,<dst>'
-			 * to save 7 bytes.
-			 */
-			if (insn[0].imm == 0 && insn[1].imm == 0) {
-				b1 = add_2mod(0x48, dst_reg, dst_reg);
-				b2 = 0x31; /* xor */
-				b3 = 0xC0;
-				EMIT3(b1, b2, add_2reg(b3, dst_reg, dst_reg));
-
-				insn++;
-				i++;
-				break;
-			}
-
-			/* movabsq %rax, imm64 */
-			EMIT2(add_1mod(0x48, dst_reg), add_1reg(0xB8, dst_reg));
-			EMIT(insn[0].imm, 4);
-			EMIT(insn[1].imm, 4);
-
+			emit_mov_imm64(&prog, dst_reg, insn[1].imm, insn[0].imm);
 			insn++;
 			i++;
 			break;
@@ -594,36 +623,38 @@ static int do_jit(struct bpf_prog *bpf_prog, int *addrs, u8 *image,
 		case BPF_ALU | BPF_MUL | BPF_X:
 		case BPF_ALU64 | BPF_MUL | BPF_K:
 		case BPF_ALU64 | BPF_MUL | BPF_X:
-			EMIT1(0x50); /* push rax */
-			EMIT1(0x52); /* push rdx */
+		{
+			bool is64 = BPF_CLASS(insn->code) == BPF_ALU64;
+
+			if (dst_reg != BPF_REG_0)
+				EMIT1(0x50); /* push rax */
+			if (dst_reg != BPF_REG_3)
+				EMIT1(0x52); /* push rdx */
 
 			/* mov r11, dst_reg */
 			EMIT_mov(AUX_REG, dst_reg);
 
 			if (BPF_SRC(insn->code) == BPF_X)
-				/* mov rax, src_reg */
-				EMIT_mov(BPF_REG_0, src_reg);
+				emit_mov_reg(&prog, is64, BPF_REG_0, src_reg);
 			else
-				/* mov rax, imm32 */
-				EMIT3_off32(0x48, 0xC7, 0xC0, imm32);
+				emit_mov_imm32(&prog, is64, BPF_REG_0, imm32);
 
-			if (BPF_CLASS(insn->code) == BPF_ALU64)
+			if (is64)
 				EMIT1(add_1mod(0x48, AUX_REG));
 			else if (is_ereg(AUX_REG))
 				EMIT1(add_1mod(0x40, AUX_REG));
 			/* mul(q) r11 */
 			EMIT2(0xF7, add_1reg(0xE0, AUX_REG));
 
-			/* mov r11, rax */
-			EMIT_mov(AUX_REG, BPF_REG_0);
-
-			EMIT1(0x5A); /* pop rdx */
-			EMIT1(0x58); /* pop rax */
-
-			/* mov dst_reg, r11 */
-			EMIT_mov(dst_reg, AUX_REG);
+			if (dst_reg != BPF_REG_3)
+				EMIT1(0x5A); /* pop rdx */
+			if (dst_reg != BPF_REG_0) {
+				/* mov dst_reg, rax */
+				EMIT_mov(dst_reg, BPF_REG_0);
+				EMIT1(0x58); /* pop rax */
+			}
 			break;
-
+		}
 			/* shifts */
 		case BPF_ALU | BPF_LSH | BPF_K:
 		case BPF_ALU | BPF_RSH | BPF_K:
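
Background for the conditional push/pop: x86's one-operand mul implicitly
multiplies by rax and writes the 128-bit product to rdx:rax, and this JIT
maps BPF_REG_0 to rax and BPF_REG_3 to rdx, so the save/restore pair is
only needed when dst_reg does not already live in a clobbered register.
A small GNU C illustration of that implicit rdx:rax contract (user space,
x86-64, not kernel code):

#include <stdint.h>
#include <stdio.h>

static void mul_u64(uint64_t a, uint64_t b, uint64_t *hi, uint64_t *lo)
{
	uint64_t h, l;

	/* mulq: rdx:rax = rax * operand; both registers are clobbered */
	asm("mulq %3"
	    : "=a" (l), "=d" (h)
	    : "a" (a), "r" (b)
	    : "cc");
	*hi = h;
	*lo = l;
}

int main(void)
{
	uint64_t hi, lo;

	mul_u64((uint64_t)-1, 2, &hi, &lo);	/* (2^64 - 1) * 2 */
	printf("hi=%#llx lo=%#llx\n",		/* hi=0x1 lo=0xfffffffffffffffe */
	       (unsigned long long)hi, (unsigned long long)lo);
	return 0;
}
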
@@ -641,7 +672,11 @@ static int do_jit(struct bpf_prog *bpf_prog, int *addrs, u8 *image,
 			case BPF_RSH: b3 = 0xE8; break;
 			case BPF_ARSH: b3 = 0xF8; break;
 			}
-			EMIT3(0xC1, add_1reg(b3, dst_reg), imm32);
+
+			if (imm32 == 1)
+				EMIT2(0xD1, add_1reg(b3, dst_reg));
+			else
+				EMIT3(0xC1, add_1reg(b3, dst_reg), imm32);
 			break;
 
 		case BPF_ALU | BPF_LSH | BPF_X:
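
The shift change exploits x86's dedicated shift-by-one form: opcode 0xD1 /r
is one byte shorter than the generic 0xC1 /r imm8. A hedged user-space
sketch (x86-64 Linux; assumes the platform allows PROT_EXEC anonymous
mappings) that JITs and runs the 2-byte encoding:

#include <stdint.h>
#include <stdio.h>
#include <string.h>
#include <sys/mman.h>

int main(void)
{
	const uint8_t code[] = {
		0xB8, 0x06, 0x00, 0x00, 0x00,	/* mov eax, 6              */
		0xD1, 0xE0,			/* shl eax, 1: 2 bytes vs  */
						/* C1 E0 01 (3 bytes)      */
		0xC3,				/* ret                     */
	};
	void *buf = mmap(NULL, 4096, PROT_READ | PROT_WRITE | PROT_EXEC,
			 MAP_PRIVATE | MAP_ANONYMOUS, -1, 0);

	if (buf == MAP_FAILED)
		return 1;
	memcpy(buf, code, sizeof(code));
	printf("6 << 1 = %d\n", ((int (*)(void))buf)());	/* prints 12 */
	return 0;
}
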
@@ -1222,7 +1257,6 @@ skip_init_addrs:
 		bpf_jit_dump(prog->len, proglen, pass + 1, image);
 
 	if (image) {
-		bpf_flush_icache(header, image + proglen);
 		if (!prog->is_func || extra_pass) {
 			bpf_jit_binary_lock_ro(header);
 		} else {