From e7423d39f936f801697f70516c4f5c778d912f9a Mon Sep 17 00:00:00 2001 From: JackDoan Date: Thu, 6 Nov 2025 09:18:33 -0600 Subject: [PATCH] cursed --- go.mod | 1 + go.sum | 7 +- main.go | 5 +- udp/udp_linux.go | 2 +- udp/wireguard_conn_linux.go | 9 +- wgstack/conn/bind_std.go | 24 +++- wgstack/conn/controlfns.go | 182 ++++++++++++++++++++++++++++++- wgstack/conn/controlfns_linux.go | 1 + wgstack/conn/features_linux.go | 8 +- 9 files changed, 223 insertions(+), 16 deletions(-) diff --git a/go.mod b/go.mod index b060abc..2a33359 100644 --- a/go.mod +++ b/go.mod @@ -6,6 +6,7 @@ require ( dario.cat/mergo v1.0.2 github.com/anmitsu/go-shlex v0.0.0-20200514113438-38f4b401e2be github.com/armon/go-radix v1.0.0 + github.com/cilium/ebpf v0.12.3 github.com/cyberdelia/go-metrics-graphite v0.0.0-20161219230853-39f87cc3b432 github.com/flynn/noise v1.1.0 github.com/gaissmai/bart v0.25.0 diff --git a/go.sum b/go.sum index 3aee30a..577ca16 100644 --- a/go.sum +++ b/go.sum @@ -17,6 +17,8 @@ github.com/beorn7/perks v1.0.1/go.mod h1:G2ZrVWU2WbWT9wwq4/hrbKbnv/1ERSJQ0ibhJ6r github.com/cespare/xxhash/v2 v2.1.1/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs= github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= +github.com/cilium/ebpf v0.12.3 h1:8ht6F9MquybnY97at+VDZb3eQQr8ev79RueWeVaEcG4= +github.com/cilium/ebpf v0.12.3/go.mod h1:TctK1ivibvI3znr66ljgi4hqOT8EYQjz1KWBfb1UVgM= github.com/cyberdelia/go-metrics-graphite v0.0.0-20161219230853-39f87cc3b432 h1:M5QgkYacWj0Xs8MhpIK/5uwU02icXpEoSo9sM2aRCps= github.com/cyberdelia/go-metrics-graphite v0.0.0-20161219230853-39f87cc3b432/go.mod h1:xwIwAxMvYnVrGJPe2FKx5prTrnAjGOD8zvDOnxnrrkM= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= @@ -24,6 +26,8 @@ github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/flynn/noise v1.1.0 h1:KjPQoQCEFdZDiP03phOvGi11+SVVhBG2wOWAorLsstg= github.com/flynn/noise v1.1.0/go.mod h1:xbMo+0i6+IGbYdJhF31t2eR1BIU0CYc12+BNAKwUTag= +github.com/frankban/quicktest v1.14.5 h1:dfYrrRyLtiqT9GyKXgdh+k4inNeTvmGbuSgZ3lx3GhA= +github.com/frankban/quicktest v1.14.5/go.mod h1:4ptaffx2x8+WTWXmUCuVU6aPUX1/Mz7zb5vbUoiM6w0= github.com/gaissmai/bart v0.25.0 h1:eqiokVPqM3F94vJ0bTHXHtH91S8zkKL+bKh+BsGOsJM= github.com/gaissmai/bart v0.25.0/go.mod h1:GREWQfTLRWz/c5FTOsIw+KkscuFkIV5t8Rp7Nd1Td5c= github.com/go-kit/kit v0.8.0/go.mod h1:xBxKIO96dXMWWy0MnWVtmwkA9/13aqxPnvrjFYMA2as= @@ -78,8 +82,9 @@ github.com/kr/pretty v0.2.1/go.mod h1:ipq/a2n7PKx3OHsz4KJII5eveXtPO4qwEXGdVfWzfn github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE= github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk= github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= -github.com/kr/text v0.1.0 h1:45sCR5RtlFHMR4UwH9sdQ5TC8v0qDQCHnXt+kaKSTVE= github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= +github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= +github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= github.com/kylelemons/godebug v1.1.0 h1:RPNrshWIDI6G2gRW9EHilWtl7Z6Sb1BR0xunSBf0SNc= github.com/kylelemons/godebug v1.1.0/go.mod h1:9/0rRGxNHcop5bhtWyNeEfOS8JIWk580+fNqagV/RAw= github.com/matttproud/golang_protobuf_extensions v1.0.1/go.mod h1:D8He9yQNgCq6Z5Ld7szi9bcBfOoFv/3dc6xSMkL2PC0= diff --git a/main.go b/main.go index 9ef8ab8..1dc5851 100644 --- a/main.go +++ b/main.go @@ -179,7 +179,7 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg useWGDefault := runtime.GOOS == "linux" useWG := c.GetBool("listen.use_wireguard_stack", useWGDefault) - var mkListener func(*logrus.Logger, netip.Addr, int, bool, int) (udp.Conn, error) + var mkListener func(*logrus.Logger, netip.Addr, int, bool, int, int) (udp.Conn, error) if useWG { mkListener = udp.NewWireguardListener } else { @@ -188,10 +188,11 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg for i := 0; i < routines; i++ { l.Infof("listening on %v", netip.AddrPortFrom(listenHost, uint16(port))) - udpServer, err := mkListener(l, listenHost, port, routines > 1, c.GetInt("listen.batch", 64)) + udpServer, err := mkListener(l, listenHost, port, routines > 1, c.GetInt("listen.batch", 64), i) if err != nil { return nil, util.NewContextualError("Failed to open udp listener", m{"queue": i}, err) } + //todo set bpf on zeroth socket udpServer.ReloadConfig(c) if cfg, ok := udpServer.(interface { ConfigureOffload(bool, bool, int) diff --git a/udp/udp_linux.go b/udp/udp_linux.go index 4dbdf3a..2571515 100644 --- a/udp/udp_linux.go +++ b/udp/udp_linux.go @@ -32,7 +32,7 @@ func maybeIPV4(ip net.IP) (net.IP, bool) { return ip, false } -func NewListener(l *logrus.Logger, ip netip.Addr, port int, multi bool, batch int) (Conn, error) { +func NewListener(l *logrus.Logger, ip netip.Addr, port int, multi bool, batch int, q int) (Conn, error) { af := unix.AF_INET6 if ip.Is4() { af = unix.AF_INET diff --git a/udp/wireguard_conn_linux.go b/udp/wireguard_conn_linux.go index c3f9e9a..934b658 100644 --- a/udp/wireguard_conn_linux.go +++ b/udp/wireguard_conn_linux.go @@ -27,13 +27,13 @@ type WGConn struct { enableGRO bool gsoMaxSeg int closed atomic.Bool - + q int closeOnce sync.Once } // NewWireguardListener creates a UDP listener backed by WireGuard's StdNetBind. -func NewWireguardListener(l *logrus.Logger, ip netip.Addr, port int, multi bool, batch int) (Conn, error) { - bind := wgconn.NewStdNetBindForAddr(ip, multi) +func NewWireguardListener(l *logrus.Logger, ip netip.Addr, port int, multi bool, batch int, q int) (Conn, error) { + bind := wgconn.NewStdNetBindForAddr(ip, multi, q) recvers, actualPort, err := bind.Open(uint16(port)) if err != nil { return nil, err @@ -51,6 +51,7 @@ func NewWireguardListener(l *logrus.Logger, ip netip.Addr, port int, multi bool, reqBatch: batch, localIP: ip, localPort: actualPort, + q: q, }, nil } @@ -71,7 +72,7 @@ func (c *WGConn) listen(fn wgconn.ReceiveFunc, r EncReader) { batchSize := c.batch packets := make([][]byte, batchSize) for i := range packets { - packets[i] = make([]byte, MTU) + packets[i] = make([]byte, 0xffff) } sizes := make([]int, batchSize) endpoints := make([]wgconn.Endpoint, batchSize) diff --git a/wgstack/conn/bind_std.go b/wgstack/conn/bind_std.go index d63466d..fc6e9ed 100644 --- a/wgstack/conn/bind_std.go +++ b/wgstack/conn/bind_std.go @@ -46,6 +46,7 @@ type StdNetBind struct { blackhole4 bool blackhole6 bool + q int } // NewStdNetBind creates a bind that listens on all interfaces. @@ -56,8 +57,9 @@ func NewStdNetBind() *StdNetBind { // NewStdNetBindForAddr creates a bind that listens on a specific address. // If addr is IPv4, only the IPv4 socket will be created. For IPv6, only the // IPv6 socket will be created. -func NewStdNetBindForAddr(addr netip.Addr, reusePort bool) *StdNetBind { +func NewStdNetBindForAddr(addr netip.Addr, reusePort bool, q int) *StdNetBind { b := NewStdNetBind() + b.q = q //if addr.IsValid() { // if addr.IsUnspecified() { // // keep dual-stack defaults with empty listen addresses @@ -147,12 +149,24 @@ func (e *StdNetEndpoint) DstToString() string { return e.AddrPort.String() } -func listenNet(network string, port int) (*net.UDPConn, int, error) { - conn, err := listenConfig().ListenPacket(context.Background(), network, ":"+strconv.Itoa(port)) +func listenNet(network string, port int, q int) (*net.UDPConn, int, error) { + lc := listenConfig(q) + + conn, err := lc.ListenPacket(context.Background(), network, ":"+strconv.Itoa(port)) if err != nil { return nil, 0, err } + if q == 0 { + if EvilFdZero == 0 { + panic("fuck") + } + err = reusePortHax(EvilFdZero) + if err != nil { + return nil, 0, fmt.Errorf("reuse port hax: %v", err) + } + } + // Retrieve port. laddr := conn.LocalAddr() uaddr, err := net.ResolveUDPAddr( @@ -185,13 +199,13 @@ again: var v4pc *ipv4.PacketConn var v6pc *ipv6.PacketConn - v4conn, port, err = listenNet("udp4", port) + v4conn, port, err = listenNet("udp4", port, s.q) if err != nil && !errors.Is(err, syscall.EAFNOSUPPORT) { return nil, 0, err } // Listen on the same port as we're using for ipv4. - v6conn, port, err = listenNet("udp6", port) + v6conn, port, err = listenNet("udp6", port, s.q) if uport == 0 && errors.Is(err, syscall.EADDRINUSE) && tries < 100 { v4conn.Close() tries++ diff --git a/wgstack/conn/controlfns.go b/wgstack/conn/controlfns.go index eb2c7a7..4a03f74 100644 --- a/wgstack/conn/controlfns.go +++ b/wgstack/conn/controlfns.go @@ -5,8 +5,12 @@ package conn import ( + "fmt" "net" "syscall" + + "github.com/cilium/ebpf" + "github.com/cilium/ebpf/asm" ) // UDP socket read/write buffer size (7MB). The value of 7MB is chosen as it is @@ -25,10 +29,169 @@ type controlFn func(network, address string, c syscall.RawConn) error // that can apply socket options. var controlFns = []controlFn{} +const SO_ATTACH_REUSEPORT_EBPF = 52 + +//Create eBPF program that returns a hash to distribute packets + +func createReuseportProgram() (*ebpf.Program, error) { + // This program uses the packet's hash and returns it modulo number of sockets + // Simple version: just return a counter-based distribution + //instructions := asm.Instructions{ + // // Load the skb->hash value (already computed by kernel) + // asm.LoadMem(asm.R0, asm.R1, int16(unsafe.Offsetof(unix.XDPMd{}.RxQueueIndex)), asm.Word), + // asm.Return(), + //} + // + //// Alternative: simpler round-robin approach + //// This returns the CPU number, effectively round-robin + //instructions := asm.Instructions{ + // asm.Mov.Reg(asm.R0, asm.R1), // Move ctx to R0 + // asm.LoadMem(asm.R0, asm.R1, 0, asm.Word), // Load some field + // asm.Return(), + //} + + // Better: Use BPF helper to get random/hash value + //instructions := asm.Instructions{ + // // Call get_prandom_u32() to get random value for distribution + // asm.Mov.Imm(asm.R0, 0), + // asm.Call.Label("get_prandom_u32"), + // asm.Return(), + //} + // + //prog, err := ebpf.NewProgram(&ebpf.ProgramSpec{ + // Type: ebpf.SocketFilter, + // Instructions: instructions, + // License: "GPL", + //}) + + //instructions := asm.Instructions{ + // // R1 contains pointer to skb + // // Load skb->hash at offset 0x20 (may vary by kernel, but 0x20 is common) + // asm.LoadMem(asm.R0, asm.R1, 0x20, asm.Word), + // + // // If hash is 0, use rxhash instead (fallback) + // asm.JEq.Imm(asm.R0, 0, "use_rxhash"), + // asm.Return().Sym("return"), + // + // // Fallback: load rxhash + // asm.LoadMem(asm.R0, asm.R1, 0x24, asm.Word).Sym("use_rxhash"), + // asm.Return(), + //} + // + //prog, err := ebpf.NewProgram(&ebpf.ProgramSpec{ + // Type: ebpf.SkReuseport, + // Instructions: instructions, + // License: "GPL", + //}) + + //instructions := asm.Instructions{ + // // R1 = ctx (sk_reuseport_md) + // // R2 = sk_reuseport map (we'll use NULL/0 for default behavior) + // // R3 = key (select socket index) + // // R4 = flags + // + // // Simple approach: use the hash field from sk_reuseport_md + // // struct sk_reuseport_md { ... __u32 hash; ... } at offset 24 + // asm.Mov.Reg(asm.R6, asm.R1), // Save ctx + // + // // Load the hash value at offset 24 + // asm.LoadMem(asm.R2, asm.R6, 24, asm.Word), + // + // // Call bpf_sk_select_reuseport(ctx, map, key, flags) + // asm.Mov.Reg(asm.R1, asm.R6), // ctx + // asm.Mov.Imm(asm.R2, 0), // map (NULL = use default) + // asm.Mov.Reg(asm.R3, asm.R2), // key = hash we loaded (in R2) + // asm.Mov.Imm(asm.R4, 0), // flags + // asm.Call.Label("sk_select_reuseport"), + // + // // Return 0 + // asm.Mov.Imm(asm.R0, 0), + // asm.Return(), + //} + // + //prog, err := ebpf.NewProgram(&ebpf.ProgramSpec{ + // Type: ebpf.SkReuseport, + // Instructions: instructions, + // License: "GPL", + //}) + + instructions := asm.Instructions{ + // R1 = ctx (sk_reuseport_md pointer) + // Load hash from sk_reuseport_md at offset 24 + //asm.LoadMem(asm.R0, asm.R1, 20, asm.Word), + + // R1 = ctx (save it) + asm.Mov.Reg(asm.R6, asm.R1), + + // Prepare string on stack: "BPF called!\n" + // We need to build the format string on the stack + asm.Mov.Reg(asm.R1, asm.R10), // R1 = frame pointer + asm.Add.Imm(asm.R1, -16), // R1 = stack location for string + + // Write "BPF called!\n" to stack (we'll use a simpler version) + // Store immediate 64-bit values + asm.StoreImm(asm.R1, 0, 0x2066706220, asm.DWord), // "bpf " + asm.StoreImm(asm.R1, 8, 0x0a21, asm.DWord), // "!\n" + + // Call bpf_trace_printk(fmt, fmt_size) + // R1 already points to format string + asm.Mov.Imm(asm.R2, 16), // R2 = format size + asm.Call.Label("bpf_printk"), + + // Return 0 (send to socket 0 for testing) + asm.Mov.Imm(asm.R0, 0), + asm.Return(), + + //asm.Mov.Imm(asm.R0, 0), + //// Just return the hash directly + //// The kernel will automatically modulo by number of sockets + //asm.Return(), + } + + prog, err := ebpf.NewProgram(&ebpf.ProgramSpec{ + Type: ebpf.SkReuseport, + Instructions: instructions, + License: "GPL", + }) + + return prog, err +} + +//func createReuseportProgram() (*ebpf.Program, error) { +// // Try offset 20 (common in newer kernels) +// instructions := asm.Instructions{ +// asm.LoadMem(asm.R0, asm.R1, 20, asm.Word), +// asm.Return(), +// } +// +// prog, err := ebpf.NewProgram(&ebpf.ProgramSpec{ +// Type: ebpf.SkReuseport, +// Instructions: instructions, +// License: "GPL", +// }) +// +// return prog, err +//} + +func reusePortHax(fd uintptr) error { + prog, err := createReuseportProgram() + if err != nil { + return fmt.Errorf("failed to create eBPF program: %w", err) + } + //defer prog.Close() + sockErr := syscall.SetsockoptInt(int(fd), syscall.SOL_SOCKET, SO_ATTACH_REUSEPORT_EBPF, prog.FD()) + if sockErr != nil { + return sockErr + } + return nil +} + +var EvilFdZero uintptr + // listenConfig returns a net.ListenConfig that applies the controlFns to the // socket prior to bind. This is used to apply socket buffer sizing and packet // information OOB configuration for sticky sockets. -func listenConfig() *net.ListenConfig { +func listenConfig(q int) *net.ListenConfig { return &net.ListenConfig{ Control: func(network, address string, c syscall.RawConn) error { for _, fn := range controlFns { @@ -36,6 +199,23 @@ func listenConfig() *net.ListenConfig { return err } } + + if q == 0 { + c.Control(func(fd uintptr) { + EvilFdZero = fd + }) + // var e error + // err := c.Control(func(fd uintptr) { + // e = reusePortHax(fd) + // }) + // if err != nil { + // return err + // } + // if e != nil { + // return e + // } + } + return nil }, } diff --git a/wgstack/conn/controlfns_linux.go b/wgstack/conn/controlfns_linux.go index e765d7a..3b1142b 100644 --- a/wgstack/conn/controlfns_linux.go +++ b/wgstack/conn/controlfns_linux.go @@ -30,6 +30,7 @@ func init() { _ = unix.SetsockoptInt(int(fd), unix.SOL_SOCKET, unix.SO_RCVBUFFORCE, socketBufferSize) _ = unix.SetsockoptInt(int(fd), unix.SOL_SOCKET, unix.SO_SNDBUFFORCE, socketBufferSize) _ = unix.SetsockoptInt(int(fd), unix.SOL_SOCKET, unix.SO_REUSEPORT, 1) //todo!!! + _ = unix.SetsockoptInt(int(fd), unix.IPPROTO_UDP, unix.UDP_GRO, 1) //todo!!! _ = unix.SetsockoptInt(int(fd), unix.SOL_UDP, unix.UDP_SEGMENT, 0xffff) //todo!!! //print(err.Error()) }) diff --git a/wgstack/conn/features_linux.go b/wgstack/conn/features_linux.go index 8959d93..ce652bd 100644 --- a/wgstack/conn/features_linux.go +++ b/wgstack/conn/features_linux.go @@ -6,6 +6,7 @@ package conn import ( + "fmt" "net" "golang.org/x/sys/unix" @@ -16,12 +17,15 @@ func supportsUDPOffload(conn *net.UDPConn) (txOffload, rxOffload bool) { if err != nil { return } + a := 0 err = rc.Control(func(fd uintptr) { - _, errSyscall := unix.GetsockoptInt(int(fd), unix.IPPROTO_UDP, unix.UDP_SEGMENT) - txOffload = errSyscall == nil + a, err = unix.GetsockoptInt(int(fd), unix.IPPROTO_UDP, unix.UDP_SEGMENT) + + txOffload = err == nil opt, errSyscall := unix.GetsockoptInt(int(fd), unix.IPPROTO_UDP, unix.UDP_GRO) rxOffload = errSyscall == nil && opt == 1 }) + fmt.Printf("%d", a) if err != nil { return false, false }