From befba57366d0c5e6548dd9ed97275fa516c3637b Mon Sep 17 00:00:00 2001 From: JackDoan Date: Wed, 5 Nov 2025 15:38:47 -0600 Subject: [PATCH] hmmm --- wgstack/conn/bind_std.go | 616 ++++++++++++++++--------------- wgstack/conn/controlfns_linux.go | 3 + wgstack/conn/sticky_linux.go | 115 +++--- 3 files changed, 380 insertions(+), 354 deletions(-) diff --git a/wgstack/conn/bind_std.go b/wgstack/conn/bind_std.go index 1e103de..d63466d 100644 --- a/wgstack/conn/bind_std.go +++ b/wgstack/conn/bind_std.go @@ -1,12 +1,14 @@ -// SPDX-License-Identifier: MIT -// -// Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. +/* SPDX-License-Identifier: MIT + * + * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. + */ package conn import ( "context" "errors" + "fmt" "net" "net/netip" "runtime" @@ -16,7 +18,6 @@ import ( "golang.org/x/net/ipv4" "golang.org/x/net/ipv6" - "golang.org/x/sys/unix" ) var ( @@ -29,28 +30,53 @@ var ( // methods for sending and receiving multiple datagrams per-syscall. See the // proposal in https://github.com/golang/go/issues/45886#issuecomment-1218301564. type StdNetBind struct { - mu sync.Mutex // protects all fields except as specified - ipv4 *net.UDPConn - ipv6 *net.UDPConn - ipv4PC *ipv4.PacketConn // will be nil on non-Linux - ipv6PC *ipv6.PacketConn // will be nil on non-Linux + mu sync.Mutex // protects all fields except as specified + ipv4 *net.UDPConn + ipv6 *net.UDPConn + ipv4PC *ipv4.PacketConn // will be nil on non-Linux + ipv6PC *ipv6.PacketConn // will be nil on non-Linux + ipv4TxOffload bool + ipv4RxOffload bool + ipv6TxOffload bool + ipv6RxOffload bool - // these three fields are not guarded by mu - udpAddrPool sync.Pool - ipv4MsgsPool sync.Pool - ipv6MsgsPool sync.Pool + // these two fields are not guarded by mu + udpAddrPool sync.Pool + msgsPool sync.Pool blackhole4 bool blackhole6 bool - - listenAddr4 string - listenAddr6 string - bindV4 bool - bindV6 bool - reusePort bool } -func newStdNetBind() *StdNetBind { +// NewStdNetBind creates a bind that listens on all interfaces. +func NewStdNetBind() *StdNetBind { + return newStdNetBind().(*StdNetBind) +} + +// NewStdNetBindForAddr creates a bind that listens on a specific address. +// If addr is IPv4, only the IPv4 socket will be created. For IPv6, only the +// IPv6 socket will be created. +func NewStdNetBindForAddr(addr netip.Addr, reusePort bool) *StdNetBind { + b := NewStdNetBind() + //if addr.IsValid() { + // if addr.IsUnspecified() { + // // keep dual-stack defaults with empty listen addresses + // } else if addr.Is4() { + // b.listenAddr4 = addr.Unmap().String() + // b.bindV4 = true + // b.bindV6 = false + // } else { + // b.listenAddr6 = addr.Unmap().String() + // b.bindV6 = true + // b.bindV4 = false + // } + //} + //b.reusePort = reusePort + + return b +} + +func newStdNetBind() Bind { return &StdNetBind{ udpAddrPool: sync.Pool{ New: func() any { @@ -60,68 +86,28 @@ func newStdNetBind() *StdNetBind { }, }, - ipv4MsgsPool: sync.Pool{ - New: func() any { - msgs := make([]ipv4.Message, IdealBatchSize) - for i := range msgs { - msgs[i].Buffers = make(net.Buffers, 1) - msgs[i].OOB = make([]byte, srcControlSize) - } - return &msgs - }, - }, - - ipv6MsgsPool: sync.Pool{ + msgsPool: sync.Pool{ New: func() any { + // ipv6.Message and ipv4.Message are interchangeable as they are + // both aliases for x/net/internal/socket.Message. msgs := make([]ipv6.Message, IdealBatchSize) for i := range msgs { msgs[i].Buffers = make(net.Buffers, 1) - msgs[i].OOB = make([]byte, srcControlSize) + msgs[i].OOB = make([]byte, 0, stickyControlSize+gsoControlSize) } return &msgs }, }, - bindV4: true, - bindV6: true, - reusePort: false, } } -// NewStdNetBind creates a bind that listens on all interfaces. -func NewStdNetBind() *StdNetBind { - return newStdNetBind() -} - -// NewStdNetBindForAddr creates a bind that listens on a specific address. -// If addr is IPv4, only the IPv4 socket will be created. For IPv6, only the -// IPv6 socket will be created. -func NewStdNetBindForAddr(addr netip.Addr, reusePort bool) *StdNetBind { - b := newStdNetBind() - if addr.IsValid() { - if addr.IsUnspecified() { - // keep dual-stack defaults with empty listen addresses - } else if addr.Is4() { - b.listenAddr4 = addr.Unmap().String() - b.bindV4 = true - b.bindV6 = false - } else { - b.listenAddr6 = addr.Unmap().String() - b.bindV6 = true - b.bindV4 = false - } - } - b.reusePort = reusePort - return b -} - type StdNetEndpoint struct { // AddrPort is the endpoint destination. netip.AddrPort - // src is the current sticky source address and interface index, if supported. - src struct { - netip.Addr - ifidx int32 - } + // src is the current sticky source address and interface index, if + // supported. Typically this is a PKTINFO structure from/for control + // messages, see unix.PKTINFO for an example. + src []byte } var ( @@ -140,21 +126,17 @@ func (*StdNetBind) ParseEndpoint(s string) (Endpoint, error) { } func (e *StdNetEndpoint) ClearSrc() { - e.src.ifidx = 0 - e.src.Addr = netip.Addr{} + if e.src != nil { + // Truncate src, no need to reallocate. + e.src = e.src[:0] + } } func (e *StdNetEndpoint) DstIP() netip.Addr { return e.AddrPort.Addr() } -func (e *StdNetEndpoint) SrcIP() netip.Addr { - return e.src.Addr -} - -func (e *StdNetEndpoint) SrcIfidx() int32 { - return e.src.ifidx -} +// See control_default,linux, etc for implementations of SrcIP and SrcIfidx. func (e *StdNetEndpoint) DstToBytes() []byte { b, _ := e.AddrPort.MarshalBinary() @@ -165,32 +147,8 @@ func (e *StdNetEndpoint) DstToString() string { return e.AddrPort.String() } -func (e *StdNetEndpoint) SrcToString() string { - return e.src.Addr.String() -} - -func (s *StdNetBind) listenNet(network string, host string, port int) (*net.UDPConn, int, error) { - lc := listenConfig() - if s.reusePort { - base := lc.Control - lc.Control = func(network, address string, c syscall.RawConn) error { - if base != nil { - if err := base(network, address, c); err != nil { - return err - } - } - return c.Control(func(fd uintptr) { - _ = unix.SetsockoptInt(int(fd), unix.SOL_SOCKET, unix.SO_REUSEPORT, 1) - }) - } - } - - addr := ":" + strconv.Itoa(port) - if host != "" { - addr = net.JoinHostPort(host, strconv.Itoa(port)) - } - - conn, err := lc.ListenPacket(context.Background(), network, addr) +func listenNet(network string, port int) (*net.UDPConn, int, error) { + conn, err := listenConfig().ListenPacket(context.Background(), network, ":"+strconv.Itoa(port)) if err != nil { return nil, 0, err } @@ -204,47 +162,10 @@ func (s *StdNetBind) listenNet(network string, host string, port int) (*net.UDPC if err != nil { return nil, 0, err } + return conn.(*net.UDPConn), uaddr.Port, nil } -func (s *StdNetBind) openIPv4(port int) (*net.UDPConn, *ipv4.PacketConn, int, error) { - if !s.bindV4 { - return nil, nil, port, nil - } - host := s.listenAddr4 - conn, actualPort, err := s.listenNet("udp4", host, port) - if err != nil { - if errors.Is(err, syscall.EAFNOSUPPORT) { - return nil, nil, port, nil - } - return nil, nil, port, err - } - if runtime.GOOS != "linux" { - return conn, nil, actualPort, nil - } - pc := ipv4.NewPacketConn(conn) - return conn, pc, actualPort, nil -} - -func (s *StdNetBind) openIPv6(port int) (*net.UDPConn, *ipv6.PacketConn, int, error) { - if !s.bindV6 { - return nil, nil, port, nil - } - host := s.listenAddr6 - conn, actualPort, err := s.listenNet("udp6", host, port) - if err != nil { - if errors.Is(err, syscall.EAFNOSUPPORT) { - return nil, nil, port, nil - } - return nil, nil, port, err - } - if runtime.GOOS != "linux" { - return conn, nil, actualPort, nil - } - pc := ipv6.NewPacketConn(conn) - return conn, pc, actualPort, nil -} - func (s *StdNetBind) Open(uport uint16) ([]ReceiveFunc, uint16, error) { s.mu.Lock() defer s.mu.Unlock() @@ -260,46 +181,44 @@ func (s *StdNetBind) Open(uport uint16) ([]ReceiveFunc, uint16, error) { // If uport is 0, we can retry on failure. again: port := int(uport) - var v4conn *net.UDPConn - var v6conn *net.UDPConn + var v4conn, v6conn *net.UDPConn var v4pc *ipv4.PacketConn var v6pc *ipv6.PacketConn - v4conn, v4pc, port, err = s.openIPv4(port) - if err != nil { + v4conn, port, err = listenNet("udp4", port) + if err != nil && !errors.Is(err, syscall.EAFNOSUPPORT) { return nil, 0, err } // Listen on the same port as we're using for ipv4. - v6conn, v6pc, port, err = s.openIPv6(port) + v6conn, port, err = listenNet("udp6", port) if uport == 0 && errors.Is(err, syscall.EADDRINUSE) && tries < 100 { - if v4conn != nil { - v4conn.Close() - } + v4conn.Close() tries++ goto again } - if err != nil { - if v4conn != nil { - v4conn.Close() - } + if err != nil && !errors.Is(err, syscall.EAFNOSUPPORT) { + v4conn.Close() return nil, 0, err } - var fns []ReceiveFunc if v4conn != nil { - s.ipv4 = v4conn - if v4pc != nil { + s.ipv4TxOffload, s.ipv4RxOffload = supportsUDPOffload(v4conn) + if runtime.GOOS == "linux" || runtime.GOOS == "android" { + v4pc = ipv4.NewPacketConn(v4conn) s.ipv4PC = v4pc } - fns = append(fns, s.makeReceiveIPv4(v4pc, v4conn)) + fns = append(fns, s.makeReceiveIPv4(v4pc, v4conn, s.ipv4RxOffload)) + s.ipv4 = v4conn } if v6conn != nil { - s.ipv6 = v6conn - if v6pc != nil { + s.ipv6TxOffload, s.ipv6RxOffload = supportsUDPOffload(v6conn) + if runtime.GOOS == "linux" || runtime.GOOS == "android" { + v6pc = ipv6.NewPacketConn(v6conn) s.ipv6PC = v6pc } - fns = append(fns, s.makeReceiveIPv6(v6pc, v6conn)) + fns = append(fns, s.makeReceiveIPv6(v6pc, v6conn, s.ipv6RxOffload)) + s.ipv6 = v6conn } if len(fns) == 0 { return nil, 0, syscall.EAFNOSUPPORT @@ -308,76 +227,101 @@ again: return fns, uint16(port), nil } -func (s *StdNetBind) makeReceiveIPv4(pc *ipv4.PacketConn, conn *net.UDPConn) ReceiveFunc { - return func(bufs [][]byte, sizes []int, eps []Endpoint) (n int, err error) { - msgs := s.ipv4MsgsPool.Get().(*[]ipv4.Message) - defer s.ipv4MsgsPool.Put(msgs) - for i := range bufs { - (*msgs)[i].Buffers[0] = bufs[i] - } - var numMsgs int - if runtime.GOOS == "linux" && pc != nil { - numMsgs, err = pc.ReadBatch(*msgs, 0) +func (s *StdNetBind) putMessages(msgs *[]ipv6.Message) { + for i := range *msgs { + (*msgs)[i].OOB = (*msgs)[i].OOB[:0] + (*msgs)[i] = ipv6.Message{Buffers: (*msgs)[i].Buffers, OOB: (*msgs)[i].OOB} + } + s.msgsPool.Put(msgs) +} + +func (s *StdNetBind) getMessages() *[]ipv6.Message { + return s.msgsPool.Get().(*[]ipv6.Message) +} + +var ( + // If compilation fails here these are no longer the same underlying type. + _ ipv6.Message = ipv4.Message{} +) + +type batchReader interface { + ReadBatch([]ipv6.Message, int) (int, error) +} + +type batchWriter interface { + WriteBatch([]ipv6.Message, int) (int, error) +} + +func (s *StdNetBind) receiveIP( + br batchReader, + conn *net.UDPConn, + rxOffload bool, + bufs [][]byte, + sizes []int, + eps []Endpoint, +) (n int, err error) { + msgs := s.getMessages() + for i := range bufs { + (*msgs)[i].Buffers[0] = bufs[i] + (*msgs)[i].OOB = (*msgs)[i].OOB[:cap((*msgs)[i].OOB)] + } + defer s.putMessages(msgs) + var numMsgs int + if runtime.GOOS == "linux" || runtime.GOOS == "android" { + if rxOffload { + readAt := len(*msgs) - (IdealBatchSize / udpSegmentMaxDatagrams) + numMsgs, err = br.ReadBatch((*msgs)[readAt:], 0) + if err != nil { + return 0, err + } + numMsgs, err = splitCoalescedMessages(*msgs, readAt, getGSOSize) if err != nil { return 0, err } } else { - msg := &(*msgs)[0] - msg.N, msg.NN, _, msg.Addr, err = conn.ReadMsgUDP(msg.Buffers[0], msg.OOB) + numMsgs, err = br.ReadBatch(*msgs, 0) if err != nil { return 0, err } - numMsgs = 1 } - for i := 0; i < numMsgs; i++ { - msg := &(*msgs)[i] - sizes[i] = msg.N - addrPort := msg.Addr.(*net.UDPAddr).AddrPort() - ep := &StdNetEndpoint{AddrPort: addrPort} // TODO: remove allocation - getSrcFromControl(msg.OOB[:msg.NN], ep) - eps[i] = ep + } else { + msg := &(*msgs)[0] + msg.N, msg.NN, _, msg.Addr, err = conn.ReadMsgUDP(msg.Buffers[0], msg.OOB) + if err != nil { + return 0, err } - return numMsgs, nil + numMsgs = 1 + } + for i := 0; i < numMsgs; i++ { + msg := &(*msgs)[i] + sizes[i] = msg.N + if sizes[i] == 0 { + continue + } + addrPort := msg.Addr.(*net.UDPAddr).AddrPort() + ep := &StdNetEndpoint{AddrPort: addrPort} // TODO: remove allocation + getSrcFromControl(msg.OOB[:msg.NN], ep) + eps[i] = ep + } + return numMsgs, nil +} + +func (s *StdNetBind) makeReceiveIPv4(pc *ipv4.PacketConn, conn *net.UDPConn, rxOffload bool) ReceiveFunc { + return func(bufs [][]byte, sizes []int, eps []Endpoint) (n int, err error) { + return s.receiveIP(pc, conn, rxOffload, bufs, sizes, eps) } } -func (s *StdNetBind) makeReceiveIPv6(pc *ipv6.PacketConn, conn *net.UDPConn) ReceiveFunc { +func (s *StdNetBind) makeReceiveIPv6(pc *ipv6.PacketConn, conn *net.UDPConn, rxOffload bool) ReceiveFunc { return func(bufs [][]byte, sizes []int, eps []Endpoint) (n int, err error) { - msgs := s.ipv6MsgsPool.Get().(*[]ipv6.Message) - defer s.ipv6MsgsPool.Put(msgs) - for i := range bufs { - (*msgs)[i].Buffers[0] = bufs[i] - } - var numMsgs int - if runtime.GOOS == "linux" && pc != nil { - numMsgs, err = pc.ReadBatch(*msgs, 0) - if err != nil { - return 0, err - } - } else { - msg := &(*msgs)[0] - msg.N, msg.NN, _, msg.Addr, err = conn.ReadMsgUDP(msg.Buffers[0], msg.OOB) - if err != nil { - return 0, err - } - numMsgs = 1 - } - for i := 0; i < numMsgs; i++ { - msg := &(*msgs)[i] - sizes[i] = msg.N - addrPort := msg.Addr.(*net.UDPAddr).AddrPort() - ep := &StdNetEndpoint{AddrPort: addrPort} // TODO: remove allocation - getSrcFromControl(msg.OOB[:msg.NN], ep) - eps[i] = ep - } - return numMsgs, nil + return s.receiveIP(pc, conn, rxOffload, bufs, sizes, eps) } } // TODO: When all Binds handle IdealBatchSize, remove this dynamic function and // rename the IdealBatchSize constant to BatchSize. func (s *StdNetBind) BatchSize() int { - if runtime.GOOS == "linux" { + if runtime.GOOS == "linux" || runtime.GOOS == "android" { return IdealBatchSize } return 1 @@ -400,28 +344,42 @@ func (s *StdNetBind) Close() error { } s.blackhole4 = false s.blackhole6 = false + s.ipv4TxOffload = false + s.ipv4RxOffload = false + s.ipv6TxOffload = false + s.ipv6RxOffload = false if err1 != nil { return err1 } return err2 } +type ErrUDPGSODisabled struct { + onLaddr string + RetryErr error +} + +func (e ErrUDPGSODisabled) Error() string { + return fmt.Sprintf("disabled UDP GSO on %s, NIC(s) may not support checksum offload", e.onLaddr) +} + +func (e ErrUDPGSODisabled) Unwrap() error { + return e.RetryErr +} + func (s *StdNetBind) Send(bufs [][]byte, endpoint Endpoint) error { s.mu.Lock() blackhole := s.blackhole4 conn := s.ipv4 - var ( - pc4 *ipv4.PacketConn - pc6 *ipv6.PacketConn - ) + offload := s.ipv4TxOffload + br := batchWriter(s.ipv4PC) is6 := false if endpoint.DstIP().Is6() { blackhole = s.blackhole6 conn = s.ipv6 - pc6 = s.ipv6PC + br = s.ipv6PC is6 = true - } else { - pc4 = s.ipv4PC + offload = s.ipv6TxOffload } s.mu.Unlock() @@ -431,109 +389,185 @@ func (s *StdNetBind) Send(bufs [][]byte, endpoint Endpoint) error { if conn == nil { return syscall.EAFNOSUPPORT } + + msgs := s.getMessages() + defer s.putMessages(msgs) + ua := s.udpAddrPool.Get().(*net.UDPAddr) + defer s.udpAddrPool.Put(ua) if is6 { - return s.send6(conn, pc6, endpoint, bufs) + as16 := endpoint.DstIP().As16() + copy(ua.IP, as16[:]) + ua.IP = ua.IP[:16] } else { - return s.send4(conn, pc4, endpoint, bufs) + as4 := endpoint.DstIP().As4() + copy(ua.IP, as4[:]) + ua.IP = ua.IP[:4] } + ua.Port = int(endpoint.(*StdNetEndpoint).Port()) + var ( + retried bool + err error + ) +retry: + if offload { + n := coalesceMessages(ua, endpoint.(*StdNetEndpoint), bufs, *msgs, setGSOSize) + err = s.send(conn, br, (*msgs)[:n]) + if err != nil && offload && errShouldDisableUDPGSO(err) { + offload = false + s.mu.Lock() + if is6 { + s.ipv6TxOffload = false + } else { + s.ipv4TxOffload = false + } + s.mu.Unlock() + retried = true + goto retry + } + } else { + for i := range bufs { + (*msgs)[i].Addr = ua + (*msgs)[i].Buffers[0] = bufs[i] + setSrcControl(&(*msgs)[i].OOB, endpoint.(*StdNetEndpoint)) + } + err = s.send(conn, br, (*msgs)[:len(bufs)]) + } + if retried { + return ErrUDPGSODisabled{onLaddr: conn.LocalAddr().String(), RetryErr: err} + } + return err } -func (s *StdNetBind) send4(conn *net.UDPConn, pc *ipv4.PacketConn, ep Endpoint, bufs [][]byte) error { - ua := s.udpAddrPool.Get().(*net.UDPAddr) - as4 := ep.DstIP().As4() - copy(ua.IP, as4[:]) - ua.IP = ua.IP[:4] - ua.Port = int(ep.(*StdNetEndpoint).Port()) - msgs := s.ipv4MsgsPool.Get().(*[]ipv4.Message) - for i, buf := range bufs { - (*msgs)[i].Buffers[0] = buf - (*msgs)[i].Addr = ua - setSrcControl(&(*msgs)[i].OOB, ep.(*StdNetEndpoint)) - } +func (s *StdNetBind) send(conn *net.UDPConn, pc batchWriter, msgs []ipv6.Message) error { var ( n int err error start int ) - if runtime.GOOS == "linux" && pc != nil { + if runtime.GOOS == "linux" || runtime.GOOS == "android" { for { - n, err = pc.WriteBatch((*msgs)[start:len(bufs)], 0) - if err != nil { - if errors.Is(err, syscall.EAFNOSUPPORT) { - for j := start; j < len(bufs); j++ { - _, _, werr := conn.WriteMsgUDP(bufs[j], (*msgs)[j].OOB, ua) - if werr != nil { - err = werr - break - } - } - } - break - } - if n == len((*msgs)[start:len(bufs)]) { + n, err = pc.WriteBatch(msgs[start:], 0) + if err != nil || n == len(msgs[start:]) { break } start += n } } else { - for i, buf := range bufs { - _, _, err = conn.WriteMsgUDP(buf, (*msgs)[i].OOB, ua) + for _, msg := range msgs { + _, _, err = conn.WriteMsgUDP(msg.Buffers[0], msg.OOB, msg.Addr.(*net.UDPAddr)) if err != nil { break } } } - s.udpAddrPool.Put(ua) - s.ipv4MsgsPool.Put(msgs) return err } -func (s *StdNetBind) send6(conn *net.UDPConn, pc *ipv6.PacketConn, ep Endpoint, bufs [][]byte) error { - ua := s.udpAddrPool.Get().(*net.UDPAddr) - as16 := ep.DstIP().As16() - copy(ua.IP, as16[:]) - ua.IP = ua.IP[:16] - ua.Port = int(ep.(*StdNetEndpoint).Port()) - msgs := s.ipv6MsgsPool.Get().(*[]ipv6.Message) - for i, buf := range bufs { - (*msgs)[i].Buffers[0] = buf - (*msgs)[i].Addr = ua - setSrcControl(&(*msgs)[i].OOB, ep.(*StdNetEndpoint)) - } +const ( + // Exceeding these values results in EMSGSIZE. They account for layer3 and + // layer4 headers. IPv6 does not need to account for itself as the payload + // length field is self excluding. + maxIPv4PayloadLen = 1<<16 - 1 - 20 - 8 + maxIPv6PayloadLen = 1<<16 - 1 - 8 + + // This is a hard limit imposed by the kernel. + udpSegmentMaxDatagrams = 64 +) + +type setGSOFunc func(control *[]byte, gsoSize uint16) + +func coalesceMessages(addr *net.UDPAddr, ep *StdNetEndpoint, bufs [][]byte, msgs []ipv6.Message, setGSO setGSOFunc) int { var ( - n int - err error - start int + base = -1 // index of msg we are currently coalescing into + gsoSize int // segmentation size of msgs[base] + dgramCnt int // number of dgrams coalesced into msgs[base] + endBatch bool // tracking flag to start a new batch on next iteration of bufs ) - if runtime.GOOS == "linux" && pc != nil { - for { - n, err = pc.WriteBatch((*msgs)[start:len(bufs)], 0) - if err != nil { - if errors.Is(err, syscall.EAFNOSUPPORT) { - for j := start; j < len(bufs); j++ { - _, _, werr := conn.WriteMsgUDP(bufs[j], (*msgs)[j].OOB, ua) - if werr != nil { - err = werr - break - } - } + maxPayloadLen := maxIPv4PayloadLen + if ep.DstIP().Is6() { + maxPayloadLen = maxIPv6PayloadLen + } + for i, buf := range bufs { + if i > 0 { + msgLen := len(buf) + baseLenBefore := len(msgs[base].Buffers[0]) + freeBaseCap := cap(msgs[base].Buffers[0]) - baseLenBefore + if msgLen+baseLenBefore <= maxPayloadLen && + msgLen <= gsoSize && + msgLen <= freeBaseCap && + dgramCnt < udpSegmentMaxDatagrams && + !endBatch { + msgs[base].Buffers[0] = append(msgs[base].Buffers[0], buf...) + if i == len(bufs)-1 { + setGSO(&msgs[base].OOB, uint16(gsoSize)) } - break + dgramCnt++ + if msgLen < gsoSize { + // A smaller than gsoSize packet on the tail is legal, but + // it must end the batch. + endBatch = true + } + continue } - if n == len((*msgs)[start:len(bufs)]) { - break - } - start += n } - } else { - for i, buf := range bufs { - _, _, err = conn.WriteMsgUDP(buf, (*msgs)[i].OOB, ua) - if err != nil { - break + if dgramCnt > 1 { + setGSO(&msgs[base].OOB, uint16(gsoSize)) + } + // Reset prior to incrementing base since we are preparing to start a + // new potential batch. + endBatch = false + base++ + gsoSize = len(buf) + setSrcControl(&msgs[base].OOB, ep) + msgs[base].Buffers[0] = buf + msgs[base].Addr = addr + dgramCnt = 1 + } + return base + 1 +} + +type getGSOFunc func(control []byte) (int, error) + +func splitCoalescedMessages(msgs []ipv6.Message, firstMsgAt int, getGSO getGSOFunc) (n int, err error) { + for i := firstMsgAt; i < len(msgs); i++ { + msg := &msgs[i] + if msg.N == 0 { + return n, err + } + var ( + gsoSize int + start int + end = msg.N + numToSplit = 1 + ) + gsoSize, err = getGSO(msg.OOB[:msg.NN]) + if err != nil { + return n, err + } + if gsoSize > 0 { + numToSplit = (msg.N + gsoSize - 1) / gsoSize + end = gsoSize + } + for j := 0; j < numToSplit; j++ { + if n > i { + return n, errors.New("splitting coalesced packet resulted in overflow") } + copied := copy(msgs[n].Buffers[0], msg.Buffers[0][start:end]) + msgs[n].N = copied + msgs[n].Addr = msg.Addr + start = end + end += gsoSize + if end > msg.N { + end = msg.N + } + n++ + } + if i != n-1 { + // It is legal for bytes to move within msg.Buffers[0] as a result + // of splitting, so we only zero the source msg len when it is not + // the destination of the last split operation above. + msg.N = 0 } } - s.udpAddrPool.Put(ua) - s.ipv6MsgsPool.Put(msgs) - return err + return n, nil } diff --git a/wgstack/conn/controlfns_linux.go b/wgstack/conn/controlfns_linux.go index cc25d25..e765d7a 100644 --- a/wgstack/conn/controlfns_linux.go +++ b/wgstack/conn/controlfns_linux.go @@ -29,6 +29,9 @@ func init() { // Set beyond *mem_max if CAP_NET_ADMIN _ = unix.SetsockoptInt(int(fd), unix.SOL_SOCKET, unix.SO_RCVBUFFORCE, socketBufferSize) _ = unix.SetsockoptInt(int(fd), unix.SOL_SOCKET, unix.SO_SNDBUFFORCE, socketBufferSize) + _ = unix.SetsockoptInt(int(fd), unix.SOL_SOCKET, unix.SO_REUSEPORT, 1) //todo!!! + _ = unix.SetsockoptInt(int(fd), unix.SOL_UDP, unix.UDP_SEGMENT, 0xffff) //todo!!! + //print(err.Error()) }) }, diff --git a/wgstack/conn/sticky_linux.go b/wgstack/conn/sticky_linux.go index b00a73e..e3a4f04 100644 --- a/wgstack/conn/sticky_linux.go +++ b/wgstack/conn/sticky_linux.go @@ -1,9 +1,3 @@ -//go:build linux && !android - -// SPDX-License-Identifier: MIT -// -// Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. - package conn import ( @@ -13,6 +7,37 @@ import ( "golang.org/x/sys/unix" ) +func (e *StdNetEndpoint) SrcIP() netip.Addr { + switch len(e.src) { + case unix.CmsgSpace(unix.SizeofInet4Pktinfo): + info := (*unix.Inet4Pktinfo)(unsafe.Pointer(&e.src[unix.CmsgLen(0)])) + return netip.AddrFrom4(info.Spec_dst) + case unix.CmsgSpace(unix.SizeofInet6Pktinfo): + info := (*unix.Inet6Pktinfo)(unsafe.Pointer(&e.src[unix.CmsgLen(0)])) + // TODO: set zone. in order to do so we need to check if the address is + // link local, and if it is perform a syscall to turn the ifindex into a + // zone string because netip uses string zones. + return netip.AddrFrom16(info.Addr) + } + return netip.Addr{} +} + +func (e *StdNetEndpoint) SrcIfidx() int32 { + switch len(e.src) { + case unix.CmsgSpace(unix.SizeofInet4Pktinfo): + info := (*unix.Inet4Pktinfo)(unsafe.Pointer(&e.src[unix.CmsgLen(0)])) + return info.Ifindex + case unix.CmsgSpace(unix.SizeofInet6Pktinfo): + info := (*unix.Inet6Pktinfo)(unsafe.Pointer(&e.src[unix.CmsgLen(0)])) + return int32(info.Ifindex) + } + return 0 +} + +func (e *StdNetEndpoint) SrcToString() string { + return e.SrcIP().String() +} + // getSrcFromControl parses the control for PKTINFO and if found updates ep with // the source information found. func getSrcFromControl(control []byte, ep *StdNetEndpoint) { @@ -34,83 +59,47 @@ func getSrcFromControl(control []byte, ep *StdNetEndpoint) { if hdr.Level == unix.IPPROTO_IP && hdr.Type == unix.IP_PKTINFO { - info := pktInfoFromBuf[unix.Inet4Pktinfo](data) - ep.src.Addr = netip.AddrFrom4(info.Spec_dst) - ep.src.ifidx = info.Ifindex + if ep.src == nil || cap(ep.src) < unix.CmsgSpace(unix.SizeofInet4Pktinfo) { + ep.src = make([]byte, 0, unix.CmsgSpace(unix.SizeofInet4Pktinfo)) + } + ep.src = ep.src[:unix.CmsgSpace(unix.SizeofInet4Pktinfo)] + hdrBuf := unsafe.Slice((*byte)(unsafe.Pointer(&hdr)), unix.SizeofCmsghdr) + copy(ep.src, hdrBuf) + copy(ep.src[unix.CmsgLen(0):], data) return } if hdr.Level == unix.IPPROTO_IPV6 && hdr.Type == unix.IPV6_PKTINFO { - info := pktInfoFromBuf[unix.Inet6Pktinfo](data) - ep.src.Addr = netip.AddrFrom16(info.Addr) - ep.src.ifidx = int32(info.Ifindex) + if ep.src == nil || cap(ep.src) < unix.CmsgSpace(unix.SizeofInet6Pktinfo) { + ep.src = make([]byte, 0, unix.CmsgSpace(unix.SizeofInet6Pktinfo)) + } + ep.src = ep.src[:unix.CmsgSpace(unix.SizeofInet6Pktinfo)] + + hdrBuf := unsafe.Slice((*byte)(unsafe.Pointer(&hdr)), unix.SizeofCmsghdr) + copy(ep.src, hdrBuf) + copy(ep.src[unix.CmsgLen(0):], data) return } } } -// pktInfoFromBuf returns type T populated from the provided buf via copy(). It -// panics if buf is of insufficient size. -func pktInfoFromBuf[T unix.Inet4Pktinfo | unix.Inet6Pktinfo](buf []byte) (t T) { - size := int(unsafe.Sizeof(t)) - if len(buf) < size { - panic("pktInfoFromBuf: buffer too small") - } - copy(unsafe.Slice((*byte)(unsafe.Pointer(&t)), size), buf) - return t -} - // setSrcControl sets an IP{V6}_PKTINFO in control based on the source address // and source ifindex found in ep. control's len will be set to 0 in the event // that ep is a default value. func setSrcControl(control *[]byte, ep *StdNetEndpoint) { - *control = (*control)[:cap(*control)] - if len(*control) < int(unsafe.Sizeof(unix.Cmsghdr{})) { - *control = (*control)[:0] + if cap(*control) < len(ep.src) { return } - - if ep.src.ifidx == 0 && !ep.SrcIP().IsValid() { - *control = (*control)[:0] - return - } - - if len(*control) < srcControlSize { - *control = (*control)[:0] - return - } - - hdr := (*unix.Cmsghdr)(unsafe.Pointer(&(*control)[0])) - if ep.SrcIP().Is4() { - hdr.Level = unix.IPPROTO_IP - hdr.Type = unix.IP_PKTINFO - hdr.SetLen(unix.CmsgLen(unix.SizeofInet4Pktinfo)) - - info := (*unix.Inet4Pktinfo)(unsafe.Pointer(&(*control)[unix.SizeofCmsghdr])) - info.Ifindex = ep.src.ifidx - if ep.SrcIP().IsValid() { - info.Spec_dst = ep.SrcIP().As4() - } - *control = (*control)[:unix.CmsgSpace(unix.SizeofInet4Pktinfo)] - } else { - hdr.Level = unix.IPPROTO_IPV6 - hdr.Type = unix.IPV6_PKTINFO - hdr.SetLen(unix.CmsgLen(unix.SizeofInet6Pktinfo)) - - info := (*unix.Inet6Pktinfo)(unsafe.Pointer(&(*control)[unix.SizeofCmsghdr])) - info.Ifindex = uint32(ep.src.ifidx) - if ep.SrcIP().IsValid() { - info.Addr = ep.SrcIP().As16() - } - *control = (*control)[:unix.CmsgSpace(unix.SizeofInet6Pktinfo)] - } - + *control = (*control)[:0] + *control = append(*control, ep.src...) } -var srcControlSize = unix.CmsgSpace(unix.SizeofInet6Pktinfo) +// stickyControlSize returns the recommended buffer size for pooling sticky +// offloading control data. +var stickyControlSize = unix.CmsgSpace(unix.SizeofInet6Pktinfo) const StdNetSupportsStickySockets = true