diff --git a/cmd/gso/gso.go b/cmd/gso/gso.go new file mode 100644 index 0000000..ed68941 --- /dev/null +++ b/cmd/gso/gso.go @@ -0,0 +1,191 @@ +package main + +import ( + "encoding/binary" + "errors" + "flag" + "fmt" + "log" + "net" + "net/netip" + "time" + "unsafe" + + "golang.org/x/sys/unix" +) + +const ( + // UDP_SEGMENT enables GSO segmentation + UDP_SEGMENT = 103 + // Maximum GSO segment size (typical MTU - headers) + maxGSOSize = 1400 +) + +func main() { + destAddr := flag.String("dest", "10.4.0.16:4202", "Destination address") + gsoSize := flag.Int("gso", 1400, "GSO segment size") + totalSize := flag.Int("size", 14000, "Total payload size to send") + count := flag.Int("count", 1, "Number of packets to send") + flag.Parse() + + if *gsoSize > maxGSOSize { + log.Fatalf("GSO size %d exceeds maximum %d", *gsoSize, maxGSOSize) + } + + // Resolve destination address + _, err := net.ResolveUDPAddr("udp", *destAddr) + if err != nil { + log.Fatalf("Failed to resolve address: %v", err) + } + + // Create a raw UDP socket with GSO support + fd, err := unix.Socket(unix.AF_INET, unix.SOCK_DGRAM, unix.IPPROTO_UDP) + if err != nil { + log.Fatalf("Failed to create socket: %v", err) + } + defer unix.Close(fd) + + // Bind to a local address + localAddr := &unix.SockaddrInet4{ + Port: 0, // Let the system choose a port + } + if err := unix.Bind(fd, localAddr); err != nil { + log.Fatalf("Failed to bind socket: %v", err) + } + + fmt.Printf("Sending UDP packets with GSO enabled\n") + fmt.Printf("Destination: %s\n", *destAddr) + fmt.Printf("GSO segment size: %d bytes\n", *gsoSize) + fmt.Printf("Total payload size: %d bytes\n", *totalSize) + fmt.Printf("Number of packets: %d\n\n", *count) + + // Create payload + payload := make([]byte, *totalSize) + for i := range payload { + payload[i] = byte(i % 256) + } + + dest := netip.MustParseAddrPort(*destAddr) + + //if err := unix.SetsockoptInt(fd, unix.SOL_UDP, unix.UDP_SEGMENT, 1400); err != nil { + // panic(err) + //} + + for i := 0; i < *count; i++ { + err := WriteBatch(fd, payload, dest, uint16(*gsoSize), true) + if err != nil { + log.Printf("Send error on packet %d: %v", i, err) + continue + } + + if (i+1)%100 == 0 || i == *count-1 { + fmt.Printf("Sent %d packets\n", i+1) + } + } + fmt.Printf("now, let's send without the correct ctrl header\n") + time.Sleep(time.Second) + for i := 0; i < *count; i++ { + err := WriteBatch(fd, payload, dest, uint16(*gsoSize), false) + if err != nil { + log.Printf("Send error on packet %d: %v", i, err) + continue + } + + if (i+1)%100 == 0 || i == *count-1 { + fmt.Printf("Sent %d packets\n", i+1) + } + } + +} + +func WriteBatch(fd int, payload []byte, addr netip.AddrPort, segSize uint16, withHeader bool) error { + msgs := make([]rawMessage, 0, 1) + iovs := make([]iovec, 0, 1) + names := make([][unix.SizeofSockaddrInet6]byte, 0, 1) + + sent := 0 + + pkts := []BatchPacket{ + { + Payload: payload, + Addr: addr, + }, + } + + for _, pkt := range pkts { + if len(pkt.Payload) == 0 { + sent++ + continue + } + + msgs = append(msgs, rawMessage{}) + iovs = append(iovs, iovec{}) + names = append(names, [unix.SizeofSockaddrInet6]byte{}) + + idx := len(msgs) - 1 + msg := &msgs[idx] + iov := &iovs[idx] + name := &names[idx] + + setIovecSlice(iov, pkt.Payload) + msg.Hdr.Iov = iov + msg.Hdr.Iovlen = 1 + + if withHeader { + setRawMessageControl(msg, buildGSOControlMessage(segSize)) // + } else { + setRawMessageControl(msg, nil) // + } + + msg.Hdr.Flags = 0 + + nameLen, err := encodeSockaddr(name[:], pkt.Addr) + if err != nil { + return err + } + msg.Hdr.Name = &name[0] + msg.Hdr.Namelen = nameLen + } + + if len(msgs) == 0 { + return errors.New("nothing to write") + } + + offset := 0 + for offset < len(msgs) { + n, _, errno := unix.Syscall6( + unix.SYS_SENDMMSG, + uintptr(fd), + uintptr(unsafe.Pointer(&msgs[offset])), + uintptr(len(msgs)-offset), + 0, + 0, + 0, + ) + + if errno != 0 { + if errno == unix.EINTR { + continue + } + return &net.OpError{Op: "sendmmsg", Err: errno} + } + + if n == 0 { + break + } + offset += int(n) + } + + return nil +} + +func buildGSOControlMessage(segSize uint16) []byte { + control := make([]byte, unix.CmsgSpace(2)) + hdr := (*unix.Cmsghdr)(unsafe.Pointer(&control[0])) + hdr.Level = unix.SOL_UDP + hdr.Type = unix.UDP_SEGMENT + setCmsgLen(hdr, unix.CmsgLen(2)) + binary.NativeEndian.PutUint16(control[unix.CmsgLen(0):unix.CmsgLen(0)+2], uint16(segSize)) + + return control +} diff --git a/cmd/gso/helper.go b/cmd/gso/helper.go new file mode 100644 index 0000000..e4f4f05 --- /dev/null +++ b/cmd/gso/helper.go @@ -0,0 +1,85 @@ +package main + +import ( + "encoding/binary" + "fmt" + "net/netip" + "unsafe" + + "golang.org/x/sys/unix" +) + +type iovec struct { + Base *byte + Len uint64 +} + +type msghdr struct { + Name *byte + Namelen uint32 + Pad0 [4]byte + Iov *iovec + Iovlen uint64 + Control *byte + Controllen uint64 + Flags int32 + Pad1 [4]byte +} + +type rawMessage struct { + Hdr msghdr + Len uint32 + Pad0 [4]byte +} + +type BatchPacket struct { + Payload []byte + Addr netip.AddrPort +} + +func encodeSockaddr(dst []byte, addr netip.AddrPort) (uint32, error) { + if addr.Addr().Is4() { + if !addr.Addr().Is4() { + return 0, fmt.Errorf("Listener is IPv4, but writing to IPv6 remote") + } + var sa unix.RawSockaddrInet4 + sa.Family = unix.AF_INET + sa.Addr = addr.Addr().As4() + binary.BigEndian.PutUint16((*[2]byte)(unsafe.Pointer(&sa.Port))[:], addr.Port()) + size := unix.SizeofSockaddrInet4 + copy(dst[:size], (*(*[unix.SizeofSockaddrInet4]byte)(unsafe.Pointer(&sa)))[:]) + return uint32(size), nil + } + + var sa unix.RawSockaddrInet6 + sa.Family = unix.AF_INET6 + sa.Addr = addr.Addr().As16() + binary.BigEndian.PutUint16((*[2]byte)(unsafe.Pointer(&sa.Port))[:], addr.Port()) + size := unix.SizeofSockaddrInet6 + copy(dst[:size], (*(*[unix.SizeofSockaddrInet6]byte)(unsafe.Pointer(&sa)))[:]) + return uint32(size), nil +} + +func setRawMessageControl(msg *rawMessage, buf []byte) { + if len(buf) == 0 { + msg.Hdr.Control = nil + msg.Hdr.Controllen = 0 + return + } + msg.Hdr.Control = &buf[0] + msg.Hdr.Controllen = uint64(len(buf)) +} + +func setCmsgLen(h *unix.Cmsghdr, l int) { + h.Len = uint64(l) +} + +func setIovecSlice(iov *iovec, b []byte) { + if len(b) == 0 { + iov.Base = nil + iov.Len = 0 + return + } + iov.Base = &b[0] + iov.Len = uint64(len(b)) +} diff --git a/connection_manager.go b/connection_manager.go index 7242c72..caf4d33 100644 --- a/connection_manager.go +++ b/connection_manager.go @@ -518,12 +518,12 @@ func (cm *connectionManager) sendPunch(hostinfo *HostInfo) { if cm.punchy.GetTargetEverything() { hostinfo.remotes.ForEach(cm.hostMap.GetPreferredRanges(), func(addr netip.AddrPort, preferred bool) { cm.metricsTxPunchy.Inc(1) - cm.intf.outside.WriteTo([]byte{1}, addr) + cm.intf.outside.WriteDirect([]byte{1}, addr) }) } else if hostinfo.remote.IsValid() { cm.metricsTxPunchy.Inc(1) - cm.intf.outside.WriteTo([]byte{1}, hostinfo.remote) + cm.intf.outside.WriteDirect([]byte{1}, hostinfo.remote) } } diff --git a/connection_state.go b/connection_state.go index 485f6fd..c00c9af 100644 --- a/connection_state.go +++ b/connection_state.go @@ -15,7 +15,8 @@ import ( // TODO: In a 5Gbps test, 1024 is not sufficient. With a 1400 MTU this is about 1.4Gbps of window, assuming full packets. // 4092 should be sufficient for 5Gbps -const ReplayWindow = 8192 +// TODO this is a horrible amount of RAM to waste per-tunnel +const ReplayWindow = 0xffff / 2 type ConnectionState struct { eKey *NebulaCipherState diff --git a/handshake_ix.go b/handshake_ix.go index 026bfbd..2b7e6dd 100644 --- a/handshake_ix.go +++ b/handshake_ix.go @@ -348,7 +348,7 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet msg = existing.HandshakePacket[2] f.messageMetrics.Tx(header.Handshake, header.MessageSubType(msg[1]), 1) if addr.IsValid() { - err := f.outside.WriteTo(msg, addr) + err := f.outside.WriteDirect(msg, addr) if err != nil { f.l.WithField("vpnAddrs", existing.vpnAddrs).WithField("udpAddr", addr). WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).WithField("cached", true). @@ -417,7 +417,7 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet // Do the send f.messageMetrics.Tx(header.Handshake, header.MessageSubType(msg[1]), 1) if addr.IsValid() { - err = f.outside.WriteTo(msg, addr) + err = f.outside.WriteDirect(msg, addr) if err != nil { f.l.WithField("vpnAddrs", vpnAddrs).WithField("udpAddr", addr). WithField("certName", certName). diff --git a/handshake_manager.go b/handshake_manager.go index f92e72d..7264ae8 100644 --- a/handshake_manager.go +++ b/handshake_manager.go @@ -238,7 +238,7 @@ func (hm *HandshakeManager) handleOutbound(vpnIp netip.Addr, lighthouseTriggered var sentTo []netip.AddrPort hostinfo.remotes.ForEach(hm.mainHostMap.GetPreferredRanges(), func(addr netip.AddrPort, _ bool) { hm.messageMetrics.Tx(header.Handshake, header.MessageSubType(hostinfo.HandshakePacket[0][1]), 1) - err := hm.outside.WriteTo(hostinfo.HandshakePacket[0], addr) + err := hm.outside.WriteDirect(hostinfo.HandshakePacket[0], addr) if err != nil { hostinfo.logger(hm.l).WithField("udpAddr", addr). WithField("initiatorIndex", hostinfo.localIndexId). diff --git a/inside.go b/inside.go index d24ed31..81fc161 100644 --- a/inside.go +++ b/inside.go @@ -8,6 +8,7 @@ import ( "github.com/slackhq/nebula/header" "github.com/slackhq/nebula/iputil" "github.com/slackhq/nebula/noiseutil" + "github.com/slackhq/nebula/packet" "github.com/slackhq/nebula/routing" ) @@ -324,7 +325,7 @@ func (f *Interface) SendVia(via *HostInfo, via.logger(f.l).WithError(err).Info("Failed to EncryptDanger in sendVia") return } - err = f.writers[0].WriteTo(out, via.remote) + err = f.writers[0].WriteDirect(out, via.remote) if err != nil { via.logger(f.l).WithError(err).Info("Failed to WriteTo in sendVia") } @@ -384,19 +385,29 @@ func (f *Interface) sendNoMetrics(t header.MessageType, st header.MessageSubType } if remote.IsValid() { - err = f.writers[q].WriteTo(out, remote) + pkt := packet.GetPool().Get() + copy(pkt.Payload[:], out) + pkt.Payload = pkt.Payload[:len(out)] + pkt.Addr = remote + err = f.writers[q].WriteTo(pkt) if err != nil { hostinfo.logger(f.l).WithError(err). WithField("udpAddr", remote).Error("Failed to write outgoing packet") } } else if hostinfo.remote.IsValid() { - err = f.writers[q].WriteTo(out, hostinfo.remote) + pkt := packet.GetPool().Get() + copy(pkt.Payload, out) + pkt.Payload = pkt.Payload[:len(out)] + pkt.Addr = hostinfo.remote + err = f.writers[q].WriteTo(pkt) if err != nil { hostinfo.logger(f.l).WithError(err). WithField("udpAddr", remote).Error("Failed to write outgoing packet") } } else { // Try to send via a relay + + //todo relay is slow sorryyy for _, relayIP := range hostinfo.relayState.CopyRelayIps() { relayHostInfo, relay, err := f.hostMap.QueryVpnAddrsRelayFor(hostinfo.vpnAddrs, relayIP) if err != nil { diff --git a/interface.go b/interface.go index 424a66f..6dde026 100644 --- a/interface.go +++ b/interface.go @@ -207,7 +207,7 @@ func NewInterface(ctx context.Context, c *InterfaceConfig) (*Interface, error) { l: c.l, } - ifce.pktPool = packet.NewPool() + ifce.pktPool = packet.GetPool() ifce.tryPromoteEvery.Store(c.tryPromoteEvery) ifce.reQueryEvery.Store(c.reQueryEvery) @@ -327,6 +327,26 @@ func (f *Interface) listenIn(reader io.ReadWriteCloser, i int) { f.wg.Done() } +//// todo why? understand! +//func normalizeGROSegSize(segSize, total int) int { +// if segCount > 1 && total > 0 { +// avg := total / segCount +// if avg > 0 { +// if segSize > avg { +// if segSize-8 == avg { +// segSize = avg +// } else if segSize > total { +// segSize = avg +// } +// } +// } +// } +// if segSize > total { +// segSize = total +// } +// return segSize +//} + func (f *Interface) workerIn(i int, ctx context.Context) { lhh := f.lightHouse.NewRequestHandler() conntrackCache := firewall.NewConntrackCacheTicker(f.conntrackCacheTimeout) @@ -338,7 +358,18 @@ func (f *Interface) workerIn(i int, ctx context.Context) { for { select { case p := <-f.inbound: - f.readOutsidePackets(p.Addr, nil, result2[:0], p.Payload, h, fwPacket2, lhh, nb2, i, conntrackCache.Get(f.l)) + if p.SegSize > 0 && p.SegSize < len(p.Payload) { + for offset := 0; offset < len(p.Payload); offset += p.SegSize { + end := offset + p.SegSize + if end > len(p.Payload) { + end = len(p.Payload) + } + f.readOutsidePackets(p.Addr, nil, result2[:0], p.Payload[offset:end], h, fwPacket2, lhh, nb2, i, conntrackCache.Get(f.l)) + } + } else { + f.readOutsidePackets(p.Addr, nil, result2[:0], p.Payload, h, fwPacket2, lhh, nb2, i, conntrackCache.Get(f.l)) + } + f.pktPool.Put(p) case <-ctx.Done(): f.wg.Done() @@ -357,7 +388,7 @@ func (f *Interface) workerOut(i int, ctx context.Context) { select { case data := <-f.outbound: f.consumeInsidePacket(data.Payload, fwPacket1, nb1, result1, i, conntrackCache.Get(f.l)) - f.pktPool.Put(data) + //f.pktPool.Put(data) //todo if err pls put packet back case <-ctx.Done(): f.wg.Done() return diff --git a/lighthouse.go b/lighthouse.go index 9f00c39..cadb6e0 100644 --- a/lighthouse.go +++ b/lighthouse.go @@ -1329,7 +1329,7 @@ func (lhh *LightHouseHandler) handleHostPunchNotification(n *NebulaMeta, fromVpn go func() { time.Sleep(lhh.lh.punchy.GetDelay()) lhh.lh.metricHolepunchTx.Inc(1) - lhh.lh.punchConn.WriteTo(empty, vpnPeer) + lhh.lh.punchConn.WriteDirect(empty, vpnPeer) }() if lhh.l.Level >= logrus.DebugLevel { diff --git a/outside.go b/outside.go index a6dcb5c..513c0a8 100644 --- a/outside.go +++ b/outside.go @@ -519,7 +519,7 @@ func (f *Interface) sendRecvError(endpoint netip.AddrPort, index uint32) { f.messageMetrics.Tx(header.RecvError, 0, 1) b := header.Encode(make([]byte, header.Len), header.Version, header.RecvError, 0, index, 0) - _ = f.outside.WriteTo(b, endpoint) + _ = f.outside.WriteDirect(b, endpoint) if f.l.Level >= logrus.DebugLevel { f.l.WithField("index", index). WithField("udpAddr", endpoint). diff --git a/packet/packet.go b/packet/packet.go index d6dd01c..d2f12ed 100644 --- a/packet/packet.go +++ b/packet/packet.go @@ -3,27 +3,36 @@ package packet import ( "net/netip" "sync" + + "golang.org/x/sys/unix" ) -const Size = 9001 +const Size = 0xffff type Packet struct { Payload []byte + Control []byte + SegSize int Addr netip.AddrPort } func New() *Packet { - return &Packet{Payload: make([]byte, Size)} + return &Packet{ + Payload: make([]byte, Size), + Control: make([]byte, unix.CmsgSpace(2)), + } } type Pool struct { pool sync.Pool } -func NewPool() *Pool { - return &Pool{ - pool: sync.Pool{New: func() any { return New() }}, - } +var bigPool = &Pool{ + pool: sync.Pool{New: func() any { return New() }}, +} + +func GetPool() *Pool { + return bigPool } func (p *Pool) Get() *Packet { diff --git a/udp/conn.go b/udp/conn.go index 6d0b79e..dcc3421 100644 --- a/udp/conn.go +++ b/udp/conn.go @@ -17,7 +17,8 @@ type Conn interface { Rebind() error LocalAddr() (netip.AddrPort, error) ListenOut(pg PacketBufferGetter, pc chan *packet.Packet) error - WriteTo(b []byte, addr netip.AddrPort) error + WriteTo(p *packet.Packet) error + WriteDirect(b []byte, port netip.AddrPort) error ReloadConfig(c *config.C) Close() error } diff --git a/udp/udp_linux.go b/udp/udp_linux.go index 34aaa3f..cc0d735 100644 --- a/udp/udp_linux.go +++ b/udp/udp_linux.go @@ -5,9 +5,11 @@ package udp import ( "encoding/binary" + "errors" "fmt" "net" "net/netip" + "sync" "syscall" "time" "unsafe" @@ -19,13 +21,136 @@ import ( "golang.org/x/sys/unix" ) +const ( + defaultGSOMaxSegments = 16 + defaultGSOFlushTimeout = 150 * time.Microsecond + maxGSOBatchBytes = 0xFFFF +) + +var ( + errGSOFallback = errors.New("udp gso fallback") + errGSODisabled = errors.New("udp gso disabled") +) + var readTimeout = unix.NsecToTimeval(int64(time.Millisecond * 500)) +type gsoState struct { + m sync.Mutex + Buf []byte + Addr netip.AddrPort + SegSize int + MaxSegments int + MaxBytes int + FlushTimeout time.Duration + Timer *time.Timer + + packets []*packet.Packet + msg rawMessage + name [unix.SizeofSockaddrInet6]byte + iov []iovec + ctrl []byte +} + +func (g *gsoState) Init() { + g.iov = make([]iovec, g.MaxSegments) + for i := 0; i < g.MaxSegments; i++ { + g.iov[i] = iovec{} + } + g.msg.Hdr.Iov = &g.iov[0] + g.msg.Hdr.Iovlen = 1 + + g.packets = make([]*packet.Packet, 0, g.MaxSegments) + g.ctrl = make([]byte, unix.CmsgSpace(2)) + hdr := (*unix.Cmsghdr)(unsafe.Pointer(&g.ctrl[0])) + hdr.Level = unix.SOL_UDP + hdr.Type = unix.UDP_SEGMENT + setCmsgLen(hdr, unix.CmsgLen(2)) + g.msg.Hdr.Control = &g.ctrl[0] + g.msg.Hdr.Controllen = uint64(len(g.ctrl)) + + g.name = [unix.SizeofSockaddrInet6]byte{} + g.msg.Hdr.Name = &g.name[0] + +} + +func (g *gsoState) setSegSizeLocked(segSize int) { + g.SegSize = segSize + x := unix.CmsgLen(0) + binary.LittleEndian.PutUint16(g.ctrl[x:x+2], uint16(segSize)) +} + +func (g *gsoState) setNameLocked(x netip.AddrPort, isV4 bool) { + g.Addr = x + nameLen := encodeSockaddr(g.name[:], g.Addr, isV4) + g.msg.Hdr.Name = &g.name[0] + g.msg.Hdr.Namelen = nameLen +} + +func encodeSockaddr(dst []byte, addr netip.AddrPort, isV4 bool) uint32 { + if isV4 { + //todo? + //if !addr.Addr().Is4() { + // return 0, fmt.Errorf("Listener is IPv4, but writing to IPv6 remote") + //} + var sa unix.RawSockaddrInet4 + sa.Family = unix.AF_INET + sa.Addr = addr.Addr().As4() + binary.BigEndian.PutUint16((*[2]byte)(unsafe.Pointer(&sa.Port))[:], addr.Port()) + size := unix.SizeofSockaddrInet4 + copy(dst[:size], (*(*[unix.SizeofSockaddrInet4]byte)(unsafe.Pointer(&sa)))[:]) + return uint32(size) + } + + var sa unix.RawSockaddrInet6 + sa.Family = unix.AF_INET6 + sa.Addr = addr.Addr().As16() + binary.BigEndian.PutUint16((*[2]byte)(unsafe.Pointer(&sa.Port))[:], addr.Port()) + size := unix.SizeofSockaddrInet6 + copy(dst[:size], (*(*[unix.SizeofSockaddrInet6]byte)(unsafe.Pointer(&sa)))[:]) + return uint32(size) +} + +func (g *gsoState) sendmsgLocked(fd int) error { + //name already set + //ctrl already set + //g.iov = g.iov[:0] + g.msg.Hdr.Iovlen = uint64(len(g.packets)) + for i := range g.packets { + g.iov[i].Base = &g.packets[i].Payload[0] + g.iov[i].Len = uint64(len(g.packets[i].Payload)) + } + + const flags = 0 + for { + _, _, err := unix.Syscall( + unix.SYS_SENDMSG, + uintptr(fd), + uintptr(unsafe.Pointer(&g.msg)), + uintptr(flags), + ) + //todo no matter what, reset things + for i := range g.packets { + pool := packet.GetPool() + pool.Put(g.packets[i]) + } + g.packets = g.packets[:0] + + if err != 0 { + return &net.OpError{Op: "sendmsg", Err: err} + } + + return nil + } +} + type StdConn struct { - sysFd int - isV4 bool - l *logrus.Logger - batch int + sysFd int + isV4 bool + l *logrus.Logger + batch int + enableGRO bool + enableGSO bool + gso gsoState } func NewListener(l *logrus.Logger, ip netip.Addr, port int, multi bool, batch int) (Conn, error) { @@ -145,15 +270,47 @@ func (u *StdConn) ListenOut(pg PacketBufferGetter, pc chan *packet.Packet) error ip, _ = netip.AddrFromSlice(names[i][8:24]) } out.Addr = netip.AddrPortFrom(ip.Unmap(), binary.BigEndian.Uint16(names[i][2:4])) + ctrlLen := getRawMessageControlLen(&msgs[i]) + if ctrlLen > 0 { + packets[i].SegSize = parseGROControl(packets[i].Control[:ctrlLen]) + } else { + packets[i].SegSize = 0 + } + pc <- out //rotate this packet out so we don't overwrite it packets[i] = pg() msgs[i].Hdr.Iov.Base = &packets[i].Payload[0] + if u.enableGRO { + msgs[i].Hdr.Control = &packets[i].Control[0] + msgs[i].Hdr.Controllen = uint64(cap(packets[i].Control)) + } + } } } +func parseGROControl(control []byte) int { + if len(control) == 0 { + return 0 + } + + cmsgs, err := unix.ParseSocketControlMessage(control) + if err != nil { + return 0 + } + + for _, c := range cmsgs { + if c.Header.Level == unix.SOL_UDP && c.Header.Type == unix.UDP_GRO && len(c.Data) >= 2 { + segSize := int(binary.LittleEndian.Uint16(c.Data[:2])) + return segSize + } + } + + return 0 +} + func (u *StdConn) ReadSingle(msgs []rawMessage) (int, error) { for { n, _, err := unix.Syscall6( @@ -201,11 +358,123 @@ func (u *StdConn) ReadMulti(msgs []rawMessage) (int, error) { } } -func (u *StdConn) WriteTo(b []byte, ip netip.AddrPort) error { - if u.isV4 { - return u.writeTo4(b, ip) +func (u *StdConn) WriteTo(p *packet.Packet) error { + if u.enableGSO && p.Addr.IsValid() { + if err := u.queueGSOPacket(p); err == nil { + return nil + } else if !errors.Is(err, errGSOFallback) { + return err + } } - return u.writeTo6(b, ip) + + var err error + if u.isV4 { + err = u.writeTo4(p.Payload, p.Addr) + } else { + err = u.writeTo4(p.Payload, p.Addr) + } + packet.GetPool().Put(p) + return err +} + +func (u *StdConn) WriteDirect(b []byte, addr netip.AddrPort) error { + if u.isV4 { + return u.writeTo4(b, addr) + } + return u.writeTo6(b, addr) +} + +func (u *StdConn) scheduleGSOFlushLocked() { + if u.gso.Timer == nil { + u.gso.Timer = time.AfterFunc(u.gso.FlushTimeout, u.gsoFlushTimer) + return + } + u.gso.Timer.Reset(u.gso.FlushTimeout) +} + +func (u *StdConn) stopGSOTimerLocked() { + if u.gso.Timer != nil { + u.gso.Timer.Stop() + u.gso.Timer = nil //todo I also don't like this + } +} + +func (u *StdConn) queueGSOPacket(p *packet.Packet) error { + if len(p.Payload) == 0 { + return nil + } + + u.gso.m.Lock() + defer u.gso.m.Unlock() + + if !u.enableGSO || !p.Addr.IsValid() || len(p.Payload) > u.gso.MaxBytes { + if err := u.flushGSOlocked(); err != nil { + return err + } + return errGSOFallback + } + + if len(u.gso.packets) == 0 { + u.gso.setNameLocked(p.Addr, u.isV4) + u.gso.SegSize = len(p.Payload) + u.gso.packets = append(u.gso.packets, p) + } else if p.Addr != u.gso.Addr || len(p.Payload) != u.gso.SegSize { + if err := u.flushGSOlocked(); err != nil { + return err + } //todo deal with "one small packet" case + u.gso.setNameLocked(p.Addr, u.isV4) + u.gso.SegSize = len(p.Payload) + u.gso.packets = append(u.gso.packets, p) + } else { + u.gso.packets = append(u.gso.packets, p) + } + + //big todo + //if len(u.gso.Buf)+len(p.Payload) > u.gso.MaxBytes { + // if err := u.flushGSOlocked(); err != nil { + // return err + // } + // u.gso.setNameLocked(p.Addr, u.isV4) + // u.gso.SegSize = len(p.Payload) + // u.gso.packets = append(u.gso.packets, p) + //} + + if len(u.gso.packets) >= u.gso.MaxSegments || u.gso.FlushTimeout <= 0 { + return u.flushGSOlocked() + } + + u.scheduleGSOFlushLocked() + return nil +} + +func (u *StdConn) flushGSOlocked() error { + if len(u.gso.packets) == 0 { + u.stopGSOTimerLocked() + return nil + } + + u.stopGSOTimerLocked() + + if u.gso.SegSize <= 0 { + return errGSOFallback + } + + err := u.gso.sendmsgLocked(u.sysFd) + if errors.Is(err, errGSODisabled) { + u.l.WithField("addr", u.gso.Addr).Warn("UDP GSO disabled by kernel, falling back to sendto") + u.enableGSO = false + //todo! + //return u.sendSegmentsIndividually(payload, addr, segSize) + } + u.gso.SegSize = 0 + + return err +} + +func (u *StdConn) gsoFlushTimer() { + u.gso.m.Lock() + _ = u.flushGSOlocked() + u.gso.m.Unlock() } func (u *StdConn) writeTo6(b []byte, ip netip.AddrPort) error { @@ -308,6 +577,72 @@ func (u *StdConn) ReloadConfig(c *config.C) { u.l.WithError(err).Error("Failed to set listen.so_mark") } } + u.configureGRO(true) + u.configureGSO(c) +} + +func (u *StdConn) configureGRO(enable bool) { + if enable == u.enableGRO { + return + } + + if enable { + if err := unix.SetsockoptInt(u.sysFd, unix.SOL_UDP, unix.UDP_GRO, 1); err != nil { + u.l.WithError(err).Warn("Failed to enable UDP GRO") + return + } + u.enableGRO = true + u.l.Info("UDP GRO enabled") + return + } + + if err := unix.SetsockoptInt(u.sysFd, unix.SOL_UDP, unix.UDP_GRO, 0); err != nil && err != unix.ENOPROTOOPT { + u.l.WithError(err).Warn("Failed to disable UDP GRO") + } + u.enableGRO = false +} + +func (u *StdConn) configureGSO(c *config.C) { + enable := c.GetBool("listen.enable_gso", true) + if !enable { + u.disableGSO() + } else { + u.enableGSO = true + } + + segments := c.GetInt("listen.gso_max_segments", defaultGSOMaxSegments) + if segments < 1 { + segments = 1 + } + u.gso.MaxSegments = segments + + maxBytes := c.GetInt("listen.gso_max_bytes", 0) + if maxBytes <= 0 { + maxBytes = MTU * segments + } + if maxBytes > maxGSOBatchBytes { + u.l.WithField("requested", maxBytes).Warn("listen.gso_max_bytes larger than UDP limit; clamping") + maxBytes = maxGSOBatchBytes + } + u.gso.MaxBytes = maxBytes + + timeout := c.GetDuration("listen.gso_flush_timeout", defaultGSOFlushTimeout) + if timeout < 0 { + timeout = 0 + } + u.gso.FlushTimeout = timeout + u.gso.Init() +} + +func (u *StdConn) disableGSO() { + u.gso.m.Lock() + defer u.gso.m.Unlock() + u.enableGSO = false + _ = u.flushGSOlocked() + u.gso.Buf = nil + u.gso.packets = u.gso.packets[:0] + u.gso.SegSize = 0 + u.stopGSOTimerLocked() } func (u *StdConn) getMemInfo(meminfo *[unix.SK_MEMINFO_VARS]uint32) error { diff --git a/udp/udp_linux_64.go b/udp/udp_linux_64.go index 0550db7..4b0ec0c 100644 --- a/udp/udp_linux_64.go +++ b/udp/udp_linux_64.go @@ -34,6 +34,24 @@ type rawMessage struct { Pad0 [4]byte } +func setRawMessageControl(msg *rawMessage, buf []byte) { + if len(buf) == 0 { + msg.Hdr.Control = nil + msg.Hdr.Controllen = 0 + return + } + msg.Hdr.Control = &buf[0] + msg.Hdr.Controllen = uint64(len(buf)) +} + +func getRawMessageControlLen(msg *rawMessage) int { + return int(msg.Hdr.Controllen) +} + +func setCmsgLen(h *unix.Cmsghdr, l int) { + h.Len = uint64(l) +} + func (u *StdConn) PrepareRawMessages(n int, pg PacketBufferGetter) ([]rawMessage, []*packet.Packet, [][]byte) { msgs := make([]rawMessage, n) names := make([][]byte, n) @@ -42,6 +60,7 @@ func (u *StdConn) PrepareRawMessages(n int, pg PacketBufferGetter) ([]rawMessage for i := range packets { packets[i] = pg() } + //todo? for i := range msgs { names[i] = make([]byte, unix.SizeofSockaddrInet6) @@ -55,6 +74,13 @@ func (u *StdConn) PrepareRawMessages(n int, pg PacketBufferGetter) ([]rawMessage msgs[i].Hdr.Name = &names[i][0] msgs[i].Hdr.Namelen = uint32(len(names[i])) + if u.enableGRO { + msgs[i].Hdr.Control = &packets[i].Control[0] + msgs[i].Hdr.Controllen = uint64(len(packets[i].Control)) + } else { + msgs[i].Hdr.Control = nil + msgs[i].Hdr.Controllen = 0 + } } return msgs, packets, names