diff options
Diffstat (limited to 'arch/x86/net')
-rw-r--r-- | arch/x86/net/bpf_jit_comp.c | 80 | ||||
-rw-r--r-- | arch/x86/net/bpf_jit_comp32.c | 367 |
2 files changed, 107 insertions, 340 deletions
diff --git a/arch/x86/net/bpf_jit_comp.c b/arch/x86/net/bpf_jit_comp.c index afabf597c855..eaaed5bfc4a4 100644 --- a/arch/x86/net/bpf_jit_comp.c +++ b/arch/x86/net/bpf_jit_comp.c @@ -1,13 +1,9 @@ +// SPDX-License-Identifier: GPL-2.0-only /* * bpf_jit_comp.c: BPF JIT compiler * * Copyright (C) 2011-2013 Eric Dumazet (eric.dumazet@gmail.com) * Internal BPF Copyright (c) 2011-2014 PLUMgrid, http://plumgrid.com - * - * This program is free software; you can redistribute it and/or - * modify it under the terms of the GNU General Public License - * as published by the Free Software Foundation; version 2 - * of the License. */ #include <linux/netdevice.h> #include <linux/filter.h> @@ -190,9 +186,7 @@ struct jit_context { #define BPF_MAX_INSN_SIZE 128 #define BPF_INSN_SAFETY 64 -#define AUX_STACK_SPACE 40 /* Space for RBX, R13, R14, R15, tailcnt */ - -#define PROLOGUE_SIZE 37 +#define PROLOGUE_SIZE 20 /* * Emit x86-64 prologue code for BPF program and check its size. @@ -203,44 +197,19 @@ static void emit_prologue(u8 **pprog, u32 stack_depth, bool ebpf_from_cbpf) u8 *prog = *pprog; int cnt = 0; - /* push rbp */ - EMIT1(0x55); - - /* mov rbp,rsp */ - EMIT3(0x48, 0x89, 0xE5); - - /* sub rsp, rounded_stack_depth + AUX_STACK_SPACE */ - EMIT3_off32(0x48, 0x81, 0xEC, - round_up(stack_depth, 8) + AUX_STACK_SPACE); - - /* sub rbp, AUX_STACK_SPACE */ - EMIT4(0x48, 0x83, 0xED, AUX_STACK_SPACE); - - /* mov qword ptr [rbp+0],rbx */ - EMIT4(0x48, 0x89, 0x5D, 0); - /* mov qword ptr [rbp+8],r13 */ - EMIT4(0x4C, 0x89, 0x6D, 8); - /* mov qword ptr [rbp+16],r14 */ - EMIT4(0x4C, 0x89, 0x75, 16); - /* mov qword ptr [rbp+24],r15 */ - EMIT4(0x4C, 0x89, 0x7D, 24); - + EMIT1(0x55); /* push rbp */ + EMIT3(0x48, 0x89, 0xE5); /* mov rbp, rsp */ + /* sub rsp, rounded_stack_depth */ + EMIT3_off32(0x48, 0x81, 0xEC, round_up(stack_depth, 8)); + EMIT1(0x53); /* push rbx */ + EMIT2(0x41, 0x55); /* push r13 */ + EMIT2(0x41, 0x56); /* push r14 */ + EMIT2(0x41, 0x57); /* push r15 */ if (!ebpf_from_cbpf) { - /* - * Clear the tail call counter (tail_call_cnt): for eBPF tail - * calls we need to reset the counter to 0. It's done in two - * instructions, resetting RAX register to 0, and moving it - * to the counter location. - */ - - /* xor eax, eax */ - EMIT2(0x31, 0xc0); - /* mov qword ptr [rbp+32], rax */ - EMIT4(0x48, 0x89, 0x45, 32); - + /* zero init tail_call_cnt */ + EMIT2(0x6a, 0x00); BUILD_BUG_ON(cnt != PROLOGUE_SIZE); } - *pprog = prog; } @@ -285,13 +254,13 @@ static void emit_bpf_tail_call(u8 **pprog) * if (tail_call_cnt > MAX_TAIL_CALL_CNT) * goto out; */ - EMIT2_off32(0x8B, 0x85, 36); /* mov eax, dword ptr [rbp + 36] */ + EMIT2_off32(0x8B, 0x85, -36 - MAX_BPF_STACK); /* mov eax, dword ptr [rbp - 548] */ EMIT3(0x83, 0xF8, MAX_TAIL_CALL_CNT); /* cmp eax, MAX_TAIL_CALL_CNT */ #define OFFSET2 (30 + RETPOLINE_RAX_BPF_JIT_SIZE) EMIT2(X86_JA, OFFSET2); /* ja out */ label2 = cnt; EMIT3(0x83, 0xC0, 0x01); /* add eax, 1 */ - EMIT2_off32(0x89, 0x85, 36); /* mov dword ptr [rbp + 36], eax */ + EMIT2_off32(0x89, 0x85, -36 - MAX_BPF_STACK); /* mov dword ptr [rbp -548], eax */ /* prog = array->ptrs[index]; */ EMIT4_off32(0x48, 0x8B, 0x84, 0xD6, /* mov rax, [rsi + rdx * 8 + offsetof(...)] */ @@ -1040,19 +1009,14 @@ emit_jmp: seen_exit = true; /* Update cleanup_addr */ ctx->cleanup_addr = proglen; - /* mov rbx, qword ptr [rbp+0] */ - EMIT4(0x48, 0x8B, 0x5D, 0); - /* mov r13, qword ptr [rbp+8] */ - EMIT4(0x4C, 0x8B, 0x6D, 8); - /* mov r14, qword ptr [rbp+16] */ - EMIT4(0x4C, 0x8B, 0x75, 16); - /* mov r15, qword ptr [rbp+24] */ - EMIT4(0x4C, 0x8B, 0x7D, 24); - - /* add rbp, AUX_STACK_SPACE */ - EMIT4(0x48, 0x83, 0xC5, AUX_STACK_SPACE); - EMIT1(0xC9); /* leave */ - EMIT1(0xC3); /* ret */ + if (!bpf_prog_was_classic(bpf_prog)) + EMIT1(0x5B); /* get rid of tail_call_cnt */ + EMIT2(0x41, 0x5F); /* pop r15 */ + EMIT2(0x41, 0x5E); /* pop r14 */ + EMIT2(0x41, 0x5D); /* pop r13 */ + EMIT1(0x5B); /* pop rbx */ + EMIT1(0xC9); /* leave */ + EMIT1(0xC3); /* ret */ break; default: diff --git a/arch/x86/net/bpf_jit_comp32.c b/arch/x86/net/bpf_jit_comp32.c index b29e82f190c7..393d251798c0 100644 --- a/arch/x86/net/bpf_jit_comp32.c +++ b/arch/x86/net/bpf_jit_comp32.c @@ -253,13 +253,14 @@ static inline void emit_ia32_mov_r(const u8 dst, const u8 src, bool dstk, /* dst = src */ static inline void emit_ia32_mov_r64(const bool is64, const u8 dst[], const u8 src[], bool dstk, - bool sstk, u8 **pprog) + bool sstk, u8 **pprog, + const struct bpf_prog_aux *aux) { emit_ia32_mov_r(dst_lo, src_lo, dstk, sstk, pprog); if (is64) /* complete 8 byte move */ emit_ia32_mov_r(dst_hi, src_hi, dstk, sstk, pprog); - else + else if (!aux->verifier_zext) /* zero out high 4 bytes */ emit_ia32_mov_i(dst_hi, 0, dstk, pprog); } @@ -313,7 +314,8 @@ static inline void emit_ia32_mul_r(const u8 dst, const u8 src, bool dstk, } static inline void emit_ia32_to_le_r64(const u8 dst[], s32 val, - bool dstk, u8 **pprog) + bool dstk, u8 **pprog, + const struct bpf_prog_aux *aux) { u8 *prog = *pprog; int cnt = 0; @@ -334,12 +336,14 @@ static inline void emit_ia32_to_le_r64(const u8 dst[], s32 val, */ EMIT2(0x0F, 0xB7); EMIT1(add_2reg(0xC0, dreg_lo, dreg_lo)); - /* xor dreg_hi,dreg_hi */ - EMIT2(0x33, add_2reg(0xC0, dreg_hi, dreg_hi)); + if (!aux->verifier_zext) + /* xor dreg_hi,dreg_hi */ + EMIT2(0x33, add_2reg(0xC0, dreg_hi, dreg_hi)); break; case 32: - /* xor dreg_hi,dreg_hi */ - EMIT2(0x33, add_2reg(0xC0, dreg_hi, dreg_hi)); + if (!aux->verifier_zext) + /* xor dreg_hi,dreg_hi */ + EMIT2(0x33, add_2reg(0xC0, dreg_hi, dreg_hi)); break; case 64: /* nop */ @@ -358,7 +362,8 @@ static inline void emit_ia32_to_le_r64(const u8 dst[], s32 val, } static inline void emit_ia32_to_be_r64(const u8 dst[], s32 val, - bool dstk, u8 **pprog) + bool dstk, u8 **pprog, + const struct bpf_prog_aux *aux) { u8 *prog = *pprog; int cnt = 0; @@ -380,16 +385,18 @@ static inline void emit_ia32_to_be_r64(const u8 dst[], s32 val, EMIT2(0x0F, 0xB7); EMIT1(add_2reg(0xC0, dreg_lo, dreg_lo)); - /* xor dreg_hi,dreg_hi */ - EMIT2(0x33, add_2reg(0xC0, dreg_hi, dreg_hi)); + if (!aux->verifier_zext) + /* xor dreg_hi,dreg_hi */ + EMIT2(0x33, add_2reg(0xC0, dreg_hi, dreg_hi)); break; case 32: /* Emit 'bswap eax' to swap lower 4 bytes */ EMIT1(0x0F); EMIT1(add_1reg(0xC8, dreg_lo)); - /* xor dreg_hi,dreg_hi */ - EMIT2(0x33, add_2reg(0xC0, dreg_hi, dreg_hi)); + if (!aux->verifier_zext) + /* xor dreg_hi,dreg_hi */ + EMIT2(0x33, add_2reg(0xC0, dreg_hi, dreg_hi)); break; case 64: /* Emit 'bswap eax' to swap lower 4 bytes */ @@ -569,7 +576,7 @@ static inline void emit_ia32_alu_r(const bool is64, const bool hi, const u8 op, static inline void emit_ia32_alu_r64(const bool is64, const u8 op, const u8 dst[], const u8 src[], bool dstk, bool sstk, - u8 **pprog) + u8 **pprog, const struct bpf_prog_aux *aux) { u8 *prog = *pprog; @@ -577,7 +584,7 @@ static inline void emit_ia32_alu_r64(const bool is64, const u8 op, if (is64) emit_ia32_alu_r(is64, true, op, dst_hi, src_hi, dstk, sstk, &prog); - else + else if (!aux->verifier_zext) emit_ia32_mov_i(dst_hi, 0, dstk, &prog); *pprog = prog; } @@ -668,7 +675,8 @@ static inline void emit_ia32_alu_i(const bool is64, const bool hi, const u8 op, /* ALU operation (64 bit) */ static inline void emit_ia32_alu_i64(const bool is64, const u8 op, const u8 dst[], const u32 val, - bool dstk, u8 **pprog) + bool dstk, u8 **pprog, + const struct bpf_prog_aux *aux) { u8 *prog = *pprog; u32 hi = 0; @@ -679,7 +687,7 @@ static inline void emit_ia32_alu_i64(const bool is64, const u8 op, emit_ia32_alu_i(is64, false, op, dst_lo, val, dstk, &prog); if (is64) emit_ia32_alu_i(is64, true, op, dst_hi, hi, dstk, &prog); - else + else if (!aux->verifier_zext) emit_ia32_mov_i(dst_hi, 0, dstk, &prog); *pprog = prog; @@ -724,9 +732,6 @@ static inline void emit_ia32_lsh_r64(const u8 dst[], const u8 src[], { u8 *prog = *pprog; int cnt = 0; - static int jmp_label1 = -1; - static int jmp_label2 = -1; - static int jmp_label3 = -1; u8 dreg_lo = dstk ? IA32_EAX : dst_lo; u8 dreg_hi = dstk ? IA32_EDX : dst_hi; @@ -745,78 +750,22 @@ static inline void emit_ia32_lsh_r64(const u8 dst[], const u8 src[], /* mov ecx,src_lo */ EMIT2(0x8B, add_2reg(0xC0, src_lo, IA32_ECX)); - /* cmp ecx,32 */ - EMIT3(0x83, add_1reg(0xF8, IA32_ECX), 32); - /* Jumps when >= 32 */ - if (is_imm8(jmp_label(jmp_label1, 2))) - EMIT2(IA32_JAE, jmp_label(jmp_label1, 2)); - else - EMIT2_off32(0x0F, IA32_JAE + 0x10, jmp_label(jmp_label1, 6)); - - /* < 32 */ - /* shl dreg_hi,cl */ - EMIT2(0xD3, add_1reg(0xE0, dreg_hi)); - /* mov ebx,dreg_lo */ - EMIT2(0x8B, add_2reg(0xC0, dreg_lo, IA32_EBX)); + /* shld dreg_hi,dreg_lo,cl */ + EMIT3(0x0F, 0xA5, add_2reg(0xC0, dreg_hi, dreg_lo)); /* shl dreg_lo,cl */ EMIT2(0xD3, add_1reg(0xE0, dreg_lo)); - /* IA32_ECX = -IA32_ECX + 32 */ - /* neg ecx */ - EMIT2(0xF7, add_1reg(0xD8, IA32_ECX)); - /* add ecx,32 */ - EMIT3(0x83, add_1reg(0xC0, IA32_ECX), 32); + /* if ecx >= 32, mov dreg_lo into dreg_hi and clear dreg_lo */ - /* shr ebx,cl */ - EMIT2(0xD3, add_1reg(0xE8, IA32_EBX)); - /* or dreg_hi,ebx */ - EMIT2(0x09, add_2reg(0xC0, dreg_hi, IA32_EBX)); - - /* goto out; */ - if (is_imm8(jmp_label(jmp_label3, 2))) - EMIT2(0xEB, jmp_label(jmp_label3, 2)); - else - EMIT1_off32(0xE9, jmp_label(jmp_label3, 5)); - - /* >= 32 */ - if (jmp_label1 == -1) - jmp_label1 = cnt; - - /* cmp ecx,64 */ - EMIT3(0x83, add_1reg(0xF8, IA32_ECX), 64); - /* Jumps when >= 64 */ - if (is_imm8(jmp_label(jmp_label2, 2))) - EMIT2(IA32_JAE, jmp_label(jmp_label2, 2)); - else - EMIT2_off32(0x0F, IA32_JAE + 0x10, jmp_label(jmp_label2, 6)); + /* cmp ecx,32 */ + EMIT3(0x83, add_1reg(0xF8, IA32_ECX), 32); + /* skip the next two instructions (4 bytes) when < 32 */ + EMIT2(IA32_JB, 4); - /* >= 32 && < 64 */ - /* sub ecx,32 */ - EMIT3(0x83, add_1reg(0xE8, IA32_ECX), 32); - /* shl dreg_lo,cl */ - EMIT2(0xD3, add_1reg(0xE0, dreg_lo)); /* mov dreg_hi,dreg_lo */ EMIT2(0x89, add_2reg(0xC0, dreg_hi, dreg_lo)); - - /* xor dreg_lo,dreg_lo */ - EMIT2(0x33, add_2reg(0xC0, dreg_lo, dreg_lo)); - - /* goto out; */ - if (is_imm8(jmp_label(jmp_label3, 2))) - EMIT2(0xEB, jmp_label(jmp_label3, 2)); - else - EMIT1_off32(0xE9, jmp_label(jmp_label3, 5)); - - /* >= 64 */ - if (jmp_label2 == -1) - jmp_label2 = cnt; /* xor dreg_lo,dreg_lo */ EMIT2(0x33, add_2reg(0xC0, dreg_lo, dreg_lo)); - /* xor dreg_hi,dreg_hi */ - EMIT2(0x33, add_2reg(0xC0, dreg_hi, dreg_hi)); - - if (jmp_label3 == -1) - jmp_label3 = cnt; if (dstk) { /* mov dword ptr [ebp+off],dreg_lo */ @@ -836,9 +785,6 @@ static inline void emit_ia32_arsh_r64(const u8 dst[], const u8 src[], { u8 *prog = *pprog; int cnt = 0; - static int jmp_label1 = -1; - static int jmp_label2 = -1; - static int jmp_label3 = -1; u8 dreg_lo = dstk ? IA32_EAX : dst_lo; u8 dreg_hi = dstk ? IA32_EDX : dst_hi; @@ -857,79 +803,23 @@ static inline void emit_ia32_arsh_r64(const u8 dst[], const u8 src[], /* mov ecx,src_lo */ EMIT2(0x8B, add_2reg(0xC0, src_lo, IA32_ECX)); - /* cmp ecx,32 */ - EMIT3(0x83, add_1reg(0xF8, IA32_ECX), 32); - /* Jumps when >= 32 */ - if (is_imm8(jmp_label(jmp_label1, 2))) - EMIT2(IA32_JAE, jmp_label(jmp_label1, 2)); - else - EMIT2_off32(0x0F, IA32_JAE + 0x10, jmp_label(jmp_label1, 6)); - - /* < 32 */ - /* lshr dreg_lo,cl */ - EMIT2(0xD3, add_1reg(0xE8, dreg_lo)); - /* mov ebx,dreg_hi */ - EMIT2(0x8B, add_2reg(0xC0, dreg_hi, IA32_EBX)); - /* ashr dreg_hi,cl */ + /* shrd dreg_lo,dreg_hi,cl */ + EMIT3(0x0F, 0xAD, add_2reg(0xC0, dreg_lo, dreg_hi)); + /* sar dreg_hi,cl */ EMIT2(0xD3, add_1reg(0xF8, dreg_hi)); - /* IA32_ECX = -IA32_ECX + 32 */ - /* neg ecx */ - EMIT2(0xF7, add_1reg(0xD8, IA32_ECX)); - /* add ecx,32 */ - EMIT3(0x83, add_1reg(0xC0, IA32_ECX), 32); - - /* shl ebx,cl */ - EMIT2(0xD3, add_1reg(0xE0, IA32_EBX)); - /* or dreg_lo,ebx */ - EMIT2(0x09, add_2reg(0xC0, dreg_lo, IA32_EBX)); - - /* goto out; */ - if (is_imm8(jmp_label(jmp_label3, 2))) - EMIT2(0xEB, jmp_label(jmp_label3, 2)); - else - EMIT1_off32(0xE9, jmp_label(jmp_label3, 5)); - - /* >= 32 */ - if (jmp_label1 == -1) - jmp_label1 = cnt; + /* if ecx >= 32, mov dreg_hi to dreg_lo and set/clear dreg_hi depending on sign */ - /* cmp ecx,64 */ - EMIT3(0x83, add_1reg(0xF8, IA32_ECX), 64); - /* Jumps when >= 64 */ - if (is_imm8(jmp_label(jmp_label2, 2))) - EMIT2(IA32_JAE, jmp_label(jmp_label2, 2)); - else - EMIT2_off32(0x0F, IA32_JAE + 0x10, jmp_label(jmp_label2, 6)); + /* cmp ecx,32 */ + EMIT3(0x83, add_1reg(0xF8, IA32_ECX), 32); + /* skip the next two instructions (5 bytes) when < 32 */ + EMIT2(IA32_JB, 5); - /* >= 32 && < 64 */ - /* sub ecx,32 */ - EMIT3(0x83, add_1reg(0xE8, IA32_ECX), 32); - /* ashr dreg_hi,cl */ - EMIT2(0xD3, add_1reg(0xF8, dreg_hi)); /* mov dreg_lo,dreg_hi */ EMIT2(0x89, add_2reg(0xC0, dreg_lo, dreg_hi)); - - /* ashr dreg_hi,imm8 */ + /* sar dreg_hi,31 */ EMIT3(0xC1, add_1reg(0xF8, dreg_hi), 31); - /* goto out; */ - if (is_imm8(jmp_label(jmp_label3, 2))) - EMIT2(0xEB, jmp_label(jmp_label3, 2)); - else - EMIT1_off32(0xE9, jmp_label(jmp_label3, 5)); - - /* >= 64 */ - if (jmp_label2 == -1) - jmp_label2 = cnt; - /* ashr dreg_hi,imm8 */ - EMIT3(0xC1, add_1reg(0xF8, dreg_hi), 31); - /* mov dreg_lo,dreg_hi */ - EMIT2(0x89, add_2reg(0xC0, dreg_lo, dreg_hi)); - - if (jmp_label3 == -1) - jmp_label3 = cnt; - if (dstk) { /* mov dword ptr [ebp+off],dreg_lo */ EMIT3(0x89, add_2reg(0x40, IA32_EBP, dreg_lo), @@ -948,9 +838,6 @@ static inline void emit_ia32_rsh_r64(const u8 dst[], const u8 src[], bool dstk, { u8 *prog = *pprog; int cnt = 0; - static int jmp_label1 = -1; - static int jmp_label2 = -1; - static int jmp_label3 = -1; u8 dreg_lo = dstk ? IA32_EAX : dst_lo; u8 dreg_hi = dstk ? IA32_EDX : dst_hi; @@ -969,77 +856,23 @@ static inline void emit_ia32_rsh_r64(const u8 dst[], const u8 src[], bool dstk, /* mov ecx,src_lo */ EMIT2(0x8B, add_2reg(0xC0, src_lo, IA32_ECX)); - /* cmp ecx,32 */ - EMIT3(0x83, add_1reg(0xF8, IA32_ECX), 32); - /* Jumps when >= 32 */ - if (is_imm8(jmp_label(jmp_label1, 2))) - EMIT2(IA32_JAE, jmp_label(jmp_label1, 2)); - else - EMIT2_off32(0x0F, IA32_JAE + 0x10, jmp_label(jmp_label1, 6)); - - /* < 32 */ - /* lshr dreg_lo,cl */ - EMIT2(0xD3, add_1reg(0xE8, dreg_lo)); - /* mov ebx,dreg_hi */ - EMIT2(0x8B, add_2reg(0xC0, dreg_hi, IA32_EBX)); + /* shrd dreg_lo,dreg_hi,cl */ + EMIT3(0x0F, 0xAD, add_2reg(0xC0, dreg_lo, dreg_hi)); /* shr dreg_hi,cl */ EMIT2(0xD3, add_1reg(0xE8, dreg_hi)); - /* IA32_ECX = -IA32_ECX + 32 */ - /* neg ecx */ - EMIT2(0xF7, add_1reg(0xD8, IA32_ECX)); - /* add ecx,32 */ - EMIT3(0x83, add_1reg(0xC0, IA32_ECX), 32); - - /* shl ebx,cl */ - EMIT2(0xD3, add_1reg(0xE0, IA32_EBX)); - /* or dreg_lo,ebx */ - EMIT2(0x09, add_2reg(0xC0, dreg_lo, IA32_EBX)); - - /* goto out; */ - if (is_imm8(jmp_label(jmp_label3, 2))) - EMIT2(0xEB, jmp_label(jmp_label3, 2)); - else - EMIT1_off32(0xE9, jmp_label(jmp_label3, 5)); + /* if ecx >= 32, mov dreg_hi to dreg_lo and clear dreg_hi */ - /* >= 32 */ - if (jmp_label1 == -1) - jmp_label1 = cnt; - /* cmp ecx,64 */ - EMIT3(0x83, add_1reg(0xF8, IA32_ECX), 64); - /* Jumps when >= 64 */ - if (is_imm8(jmp_label(jmp_label2, 2))) - EMIT2(IA32_JAE, jmp_label(jmp_label2, 2)); - else - EMIT2_off32(0x0F, IA32_JAE + 0x10, jmp_label(jmp_label2, 6)); + /* cmp ecx,32 */ + EMIT3(0x83, add_1reg(0xF8, IA32_ECX), 32); + /* skip the next two instructions (4 bytes) when < 32 */ + EMIT2(IA32_JB, 4); - /* >= 32 && < 64 */ - /* sub ecx,32 */ - EMIT3(0x83, add_1reg(0xE8, IA32_ECX), 32); - /* shr dreg_hi,cl */ - EMIT2(0xD3, add_1reg(0xE8, dreg_hi)); /* mov dreg_lo,dreg_hi */ EMIT2(0x89, add_2reg(0xC0, dreg_lo, dreg_hi)); /* xor dreg_hi,dreg_hi */ EMIT2(0x33, add_2reg(0xC0, dreg_hi, dreg_hi)); - /* goto out; */ - if (is_imm8(jmp_label(jmp_label3, 2))) - EMIT2(0xEB, jmp_label(jmp_label3, 2)); - else - EMIT1_off32(0xE9, jmp_label(jmp_label3, 5)); - - /* >= 64 */ - if (jmp_label2 == -1) - jmp_label2 = cnt; - /* xor dreg_lo,dreg_lo */ - EMIT2(0x33, add_2reg(0xC0, dreg_lo, dreg_lo)); - /* xor dreg_hi,dreg_hi */ - EMIT2(0x33, add_2reg(0xC0, dreg_hi, dreg_hi)); - - if (jmp_label3 == -1) - jmp_label3 = cnt; - if (dstk) { /* mov dword ptr [ebp+off],dreg_lo */ EMIT3(0x89, add_2reg(0x40, IA32_EBP, dreg_lo), @@ -1069,27 +902,10 @@ static inline void emit_ia32_lsh_i64(const u8 dst[], const u32 val, } /* Do LSH operation */ if (val < 32) { - /* shl dreg_hi,imm8 */ - EMIT3(0xC1, add_1reg(0xE0, dreg_hi), val); - /* mov ebx,dreg_lo */ - EMIT2(0x8B, add_2reg(0xC0, dreg_lo, IA32_EBX)); + /* shld dreg_hi,dreg_lo,imm8 */ + EMIT4(0x0F, 0xA4, add_2reg(0xC0, dreg_hi, dreg_lo), val); /* shl dreg_lo,imm8 */ EMIT3(0xC1, add_1reg(0xE0, dreg_lo), val); - - /* IA32_ECX = 32 - val */ - /* mov ecx,val */ - EMIT2(0xB1, val); - /* movzx ecx,ecx */ - EMIT3(0x0F, 0xB6, add_2reg(0xC0, IA32_ECX, IA32_ECX)); - /* neg ecx */ - EMIT2(0xF7, add_1reg(0xD8, IA32_ECX)); - /* add ecx,32 */ - EMIT3(0x83, add_1reg(0xC0, IA32_ECX), 32); - - /* shr ebx,cl */ - EMIT2(0xD3, add_1reg(0xE8, IA32_EBX)); - /* or dreg_hi,ebx */ - EMIT2(0x09, add_2reg(0xC0, dreg_hi, IA32_EBX)); } else if (val >= 32 && val < 64) { u32 value = val - 32; @@ -1135,27 +951,10 @@ static inline void emit_ia32_rsh_i64(const u8 dst[], const u32 val, /* Do RSH operation */ if (val < 32) { - /* shr dreg_lo,imm8 */ - EMIT3(0xC1, add_1reg(0xE8, dreg_lo), val); - /* mov ebx,dreg_hi */ - EMIT2(0x8B, add_2reg(0xC0, dreg_hi, IA32_EBX)); + /* shrd dreg_lo,dreg_hi,imm8 */ + EMIT4(0x0F, 0xAC, add_2reg(0xC0, dreg_lo, dreg_hi), val); /* shr dreg_hi,imm8 */ EMIT3(0xC1, add_1reg(0xE8, dreg_hi), val); - - /* IA32_ECX = 32 - val */ - /* mov ecx,val */ - EMIT2(0xB1, val); - /* movzx ecx,ecx */ - EMIT3(0x0F, 0xB6, add_2reg(0xC0, IA32_ECX, IA32_ECX)); - /* neg ecx */ - EMIT2(0xF7, add_1reg(0xD8, IA32_ECX)); - /* add ecx,32 */ - EMIT3(0x83, add_1reg(0xC0, IA32_ECX), 32); - - /* shl ebx,cl */ - EMIT2(0xD3, add_1reg(0xE0, IA32_EBX)); - /* or dreg_lo,ebx */ - EMIT2(0x09, add_2reg(0xC0, dreg_lo, IA32_EBX)); } else if (val >= 32 && val < 64) { u32 value = val - 32; @@ -1200,27 +999,10 @@ static inline void emit_ia32_arsh_i64(const u8 dst[], const u32 val, } /* Do RSH operation */ if (val < 32) { - /* shr dreg_lo,imm8 */ - EMIT3(0xC1, add_1reg(0xE8, dreg_lo), val); - /* mov ebx,dreg_hi */ - EMIT2(0x8B, add_2reg(0xC0, dreg_hi, IA32_EBX)); + /* shrd dreg_lo,dreg_hi,imm8 */ + EMIT4(0x0F, 0xAC, add_2reg(0xC0, dreg_lo, dreg_hi), val); /* ashr dreg_hi,imm8 */ EMIT3(0xC1, add_1reg(0xF8, dreg_hi), val); - - /* IA32_ECX = 32 - val */ - /* mov ecx,val */ - EMIT2(0xB1, val); - /* movzx ecx,ecx */ - EMIT3(0x0F, 0xB6, add_2reg(0xC0, IA32_ECX, IA32_ECX)); - /* neg ecx */ - EMIT2(0xF7, add_1reg(0xD8, IA32_ECX)); - /* add ecx,32 */ - EMIT3(0x83, add_1reg(0xC0, IA32_ECX), 32); - - /* shl ebx,cl */ - EMIT2(0xD3, add_1reg(0xE0, IA32_EBX)); - /* or dreg_lo,ebx */ - EMIT2(0x09, add_2reg(0xC0, dreg_lo, IA32_EBX)); } else if (val >= 32 && val < 64) { u32 value = val - 32; @@ -1713,8 +1495,13 @@ static int do_jit(struct bpf_prog *bpf_prog, int *addrs, u8 *image, case BPF_ALU64 | BPF_MOV | BPF_X: switch (BPF_SRC(code)) { case BPF_X: - emit_ia32_mov_r64(is64, dst, src, dstk, - sstk, &prog); + if (imm32 == 1) { + /* Special mov32 for zext. */ + emit_ia32_mov_i(dst_hi, 0, dstk, &prog); + break; + } + emit_ia32_mov_r64(is64, dst, src, dstk, sstk, + &prog, bpf_prog->aux); break; case BPF_K: /* Sign-extend immediate value to dst reg */ @@ -1754,11 +1541,13 @@ static int do_jit(struct bpf_prog *bpf_prog, int *addrs, u8 *image, switch (BPF_SRC(code)) { case BPF_X: emit_ia32_alu_r64(is64, BPF_OP(code), dst, - src, dstk, sstk, &prog); + src, dstk, sstk, &prog, + bpf_prog->aux); break; case BPF_K: emit_ia32_alu_i64(is64, BPF_OP(code), dst, - imm32, dstk, &prog); + imm32, dstk, &prog, + bpf_prog->aux); break; } break; @@ -1777,7 +1566,8 @@ static int do_jit(struct bpf_prog *bpf_prog, int *addrs, u8 *image, false, &prog); break; } - emit_ia32_mov_i(dst_hi, 0, dstk, &prog); + if (!bpf_prog->aux->verifier_zext) + emit_ia32_mov_i(dst_hi, 0, dstk, &prog); break; case BPF_ALU | BPF_LSH | BPF_X: case BPF_ALU | BPF_RSH | BPF_X: @@ -1797,7 +1587,8 @@ static int do_jit(struct bpf_prog *bpf_prog, int *addrs, u8 *image, &prog); break; } - emit_ia32_mov_i(dst_hi, 0, dstk, &prog); + if (!bpf_prog->aux->verifier_zext) + emit_ia32_mov_i(dst_hi, 0, dstk, &prog); break; /* dst = dst / src(imm) */ /* dst = dst % src(imm) */ @@ -1819,7 +1610,8 @@ static int do_jit(struct bpf_prog *bpf_prog, int *addrs, u8 *image, &prog); break; } - emit_ia32_mov_i(dst_hi, 0, dstk, &prog); + if (!bpf_prog->aux->verifier_zext) + emit_ia32_mov_i(dst_hi, 0, dstk, &prog); break; case BPF_ALU64 | BPF_DIV | BPF_K: case BPF_ALU64 | BPF_DIV | BPF_X: @@ -1836,7 +1628,8 @@ static int do_jit(struct bpf_prog *bpf_prog, int *addrs, u8 *image, EMIT2_off32(0xC7, add_1reg(0xC0, IA32_ECX), imm32); emit_ia32_shift_r(BPF_OP(code), dst_lo, IA32_ECX, dstk, false, &prog); - emit_ia32_mov_i(dst_hi, 0, dstk, &prog); + if (!bpf_prog->aux->verifier_zext) + emit_ia32_mov_i(dst_hi, 0, dstk, &prog); break; /* dst = dst << imm */ case BPF_ALU64 | BPF_LSH | BPF_K: @@ -1872,7 +1665,8 @@ static int do_jit(struct bpf_prog *bpf_prog, int *addrs, u8 *image, case BPF_ALU | BPF_NEG: emit_ia32_alu_i(is64, false, BPF_OP(code), dst_lo, 0, dstk, &prog); - emit_ia32_mov_i(dst_hi, 0, dstk, &prog); + if (!bpf_prog->aux->verifier_zext) + emit_ia32_mov_i(dst_hi, 0, dstk, &prog); break; /* dst = ~dst (64 bit) */ case BPF_ALU64 | BPF_NEG: @@ -1892,11 +1686,13 @@ static int do_jit(struct bpf_prog *bpf_prog, int *addrs, u8 *image, break; /* dst = htole(dst) */ case BPF_ALU | BPF_END | BPF_FROM_LE: - emit_ia32_to_le_r64(dst, imm32, dstk, &prog); + emit_ia32_to_le_r64(dst, imm32, dstk, &prog, + bpf_prog->aux); break; /* dst = htobe(dst) */ case BPF_ALU | BPF_END | BPF_FROM_BE: - emit_ia32_to_be_r64(dst, imm32, dstk, &prog); + emit_ia32_to_be_r64(dst, imm32, dstk, &prog, + bpf_prog->aux); break; /* dst = imm64 */ case BPF_LD | BPF_IMM | BPF_DW: { @@ -2051,6 +1847,8 @@ static int do_jit(struct bpf_prog *bpf_prog, int *addrs, u8 *image, case BPF_B: case BPF_H: case BPF_W: + if (!bpf_prog->aux->verifier_zext) + break; if (dstk) { EMIT3(0xC7, add_1reg(0x40, IA32_EBP), STACK_VAR(dst_hi)); @@ -2475,6 +2273,11 @@ notyet: return proglen; } +bool bpf_jit_needs_zext(void) +{ + return true; +} + struct bpf_prog *bpf_int_jit_compile(struct bpf_prog *prog) { struct bpf_binary_header *header = NULL; |