From d18d1aea67bfc35cbd0a9cca806e7ca04321510e Mon Sep 17 00:00:00 2001
From: Ryan
Date: Wed, 5 Nov 2025 20:34:02 -0500
Subject: [PATCH] first

---
 udp/udp_linux.go    | 424 +++++++++++++++++++++++++++++++++++++++++++-
 udp/udp_linux_32.go |  22 +++
 udp/udp_linux_64.go |  22 +++
 3 files changed, 465 insertions(+), 3 deletions(-)

diff --git a/udp/udp_linux.go b/udp/udp_linux.go
index e3df48f..ab52ec4 100644
--- a/udp/udp_linux.go
+++ b/udp/udp_linux.go
@@ -5,9 +5,12 @@ package udp
 
 import (
 	"encoding/binary"
+	"errors"
 	"fmt"
 	"net"
 	"net/netip"
+	"sync"
+	"sync/atomic"
 	"syscall"
 	"time"
 	"unsafe"
@@ -20,11 +23,46 @@ import (
 
 var readTimeout = unix.NsecToTimeval(int64(time.Millisecond * 500))
 
+const (
+	defaultGSOMaxSegments  = 8
+	defaultGSOFlushTimeout = 150 * time.Microsecond
+
+	// maxGSOBatchBytes caps a batch at the 16-bit UDP length limit.
+	maxGSOBatchBytes = 0xFFFF
+
+	// maxGSOSegments matches UDP_MAX_SEGMENTS as it was when Linux gained UDP
+	// GSO; staying within it keeps batches valid on older kernels.
+	maxGSOSegments = 64
+)
+
+var (
+	// errGSOFallback tells WriteTo to send the packet with a plain write.
+	errGSOFallback = errors.New("udp gso fallback")
+	// errGSODisabled reports that the kernel rejected a UDP_SEGMENT send.
+	errGSODisabled = errors.New("udp gso disabled")
+)
+
 type StdConn struct {
 	sysFd int
 	isV4  bool
 	l     *logrus.Logger
 	batch int
+
+	// enableGRO and enableGSO are written by the config and flush paths and
+	// read on the packet paths without gsoMu, so they are atomic.
+	enableGRO atomic.Bool
+	enableGSO atomic.Bool
+
+	// gsoMu guards every gso* field below.
+	gsoMu           sync.Mutex
+	gsoBuf          []byte
+	gsoAddr         netip.AddrPort
+	gsoSegSize      int
+	gsoSegments     int
+	gsoMaxSegments  int
+	gsoMaxBytes     int
+	gsoFlushTimeout time.Duration
+	gsoTimer        *time.Timer
 }
 
 func NewListener(l *logrus.Logger, ip netip.Addr, port int, multi bool, batch int) (Conn, error) {
@@ -69,7 +107,15 @@ func NewListener(l *logrus.Logger, ip netip.Addr, port int, multi bool, batch in
 		return nil, fmt.Errorf("unable to bind to socket: %s", err)
 	}
 
-	return &StdConn{sysFd: fd, isV4: ip.Is4(), l: l, batch: batch}, err
+	return &StdConn{
+		sysFd:           fd,
+		isV4:            ip.Is4(),
+		l:               l,
+		batch:           batch,
+		gsoMaxSegments:  defaultGSOMaxSegments,
+		gsoMaxBytes:     MTU * defaultGSOMaxSegments,
+		gsoFlushTimeout: defaultGSOFlushTimeout,
+	}, err
 }
 
 func (u *StdConn) Rebind() error {
@@ -119,7 +165,10 @@ func (u *StdConn) LocalAddr() (netip.AddrPort, error) {
 }
 
 func (u *StdConn) ListenOut(r EncReader) error {
-	var ip netip.Addr
+	var (
+		ip       netip.Addr
+		controls [][]byte
+	)
 
 	msgs, buffers, names := u.PrepareRawMessages(u.batch)
 	read := u.ReadMulti
@@ -128,6 +177,26 @@ func (u *StdConn) ListenOut(r EncReader) error {
 	}
 
 	for {
+		// The kernel rewrites each header's control length on receive, so the
+		// buffers are re-attached before every read; this also lets a runtime
+		// GRO toggle take effect on the next pass.
+		if u.enableGRO.Load() {
+			if controls == nil {
+				controls = make([][]byte, len(msgs))
+				for i := range controls {
+					controls[i] = make([]byte, unix.CmsgSpace(4))
+				}
+			}
+			for i := range msgs {
+				setRawMessageControl(&msgs[i], controls[i])
+			}
+		} else if controls != nil {
+			for i := range msgs {
+				setRawMessageControl(&msgs[i], nil)
+			}
+			controls = nil
+		}
+
 		n, err := read(msgs)
 		if err != nil {
 			return err
@@ -140,7 +209,21 @@ func (u *StdConn) ListenOut(r EncReader) error {
 			} else {
 				ip, _ = netip.AddrFromSlice(names[i][8:24])
 			}
-			r(netip.AddrPortFrom(ip.Unmap(), binary.BigEndian.Uint16(names[i][2:4])), buffers[i][:msgs[i].Len])
+			addr := netip.AddrPortFrom(ip.Unmap(), binary.BigEndian.Uint16(names[i][2:4]))
+			payload := buffers[i][:msgs[i].Len]
+
+			if controls != nil {
+				if ctrlLen := getRawMessageControlLen(&msgs[i]); ctrlLen > 0 {
+					// A UDP_GRO message means the kernel coalesced several
+					// datagrams; hand each one to the reader separately.
+					segSize := parseGROControl(controls[i][:ctrlLen])
+					if u.emitGROSegments(r, addr, payload, segSize) {
+						continue
+					}
+				}
+			}
+
+			r(addr, payload)
 		}
 	}
 }
@@ -193,6 +276,14 @@ func (u *StdConn) ReadMulti(msgs []rawMessage) (int, error) {
 }
 
 func (u *StdConn) WriteTo(b []byte, ip netip.AddrPort) error {
+	if u.enableGSO.Load() && ip.IsValid() {
+		// nil means the packet was queued or sent; any error other than the
+		// fallback sentinel is returned as-is.
+		if err := u.queueGSOPacket(b, ip); !errors.Is(err, errGSOFallback) {
+			return err
+		}
+	}
+
 	if u.isV4 {
 		return u.writeTo4(b, ip)
 	}
@@ -299,6 +390,101 @@ func (u *StdConn) ReloadConfig(c *config.C) {
 			u.l.WithError(err).Error("Failed to set listen.so_mark")
 		}
 	}
+
+	u.configureGRO(c.GetBool("listen.enable_gro", false))
+	u.configureGSO(c)
+}
+
+// configureGRO switches the UDP_GRO socket option to match enable and records
+// the result in enableGRO. A failure to enable leaves GRO off.
+//
+// NOTE(review): the receive buffers come from PrepareRawMessages, which is not
+// part of this patch; confirm they can hold a coalesced datagram (up to 64 KiB),
+// otherwise the kernel will truncate what GRO delivers.
+func (u *StdConn) configureGRO(enable bool) {
+	if enable == u.enableGRO.Load() {
+		return
+	}
+
+	if enable {
+		if err := unix.SetsockoptInt(u.sysFd, unix.SOL_UDP, unix.UDP_GRO, 1); err != nil {
+			u.l.WithError(err).Warn("Failed to enable UDP GRO")
+			return
+		}
+		u.enableGRO.Store(true)
+		u.l.Info("UDP GRO enabled")
+		return
+	}
+
+	// ENOPROTOOPT means the kernel has no UDP_GRO, so there is nothing to undo.
+	if err := unix.SetsockoptInt(u.sysFd, unix.SOL_UDP, unix.UDP_GRO, 0); err != nil && err != unix.ENOPROTOOPT {
+		u.l.WithError(err).Warn("Failed to disable UDP GRO")
+	}
+	u.enableGRO.Store(false)
+}
+
+// configureGSO applies the listen.enable_gso, listen.gso_max_segments,
+// listen.gso_max_bytes and listen.gso_flush_timeout settings.
+func (u *StdConn) configureGSO(c *config.C) {
+	enable := c.GetBool("listen.enable_gso", false)
+
+	segments := c.GetInt("listen.gso_max_segments", defaultGSOMaxSegments)
+	if segments < 1 {
+		segments = 1
+	}
+	if segments > maxGSOSegments {
+		u.l.WithField("requested", segments).Warn("listen.gso_max_segments larger than kernel limit; clamping")
+		segments = maxGSOSegments
+	}
+
+	maxBytes := c.GetInt("listen.gso_max_bytes", 0)
+	switch {
+	case maxBytes <= 0:
+		// Derived default: clamp quietly, the operator did not ask for it.
+		maxBytes = MTU * segments
+		if maxBytes > maxGSOBatchBytes {
+			maxBytes = maxGSOBatchBytes
+		}
+	case maxBytes > maxGSOBatchBytes:
+		u.l.WithField("requested", maxBytes).Warn("listen.gso_max_bytes larger than UDP limit; clamping")
+		maxBytes = maxGSOBatchBytes
+	}
+
+	timeout := c.GetDuration("listen.gso_flush_timeout", defaultGSOFlushTimeout)
+	if timeout < 0 {
+		timeout = 0
+	}
+
+	if !enable {
+		u.disableGSO()
+	}
+
+	// queueGSOPacket reads these under gsoMu, so they are published under it.
+	u.gsoMu.Lock()
+	u.gsoMaxSegments = segments
+	u.gsoMaxBytes = maxBytes
+	u.gsoFlushTimeout = timeout
+	u.gsoMu.Unlock()
+
+	if enable {
+		u.enableGSO.Store(true)
+	}
+}
+
+// disableGSO turns batching off, sends whatever is still queued and releases
+// the batch buffer.
+func (u *StdConn) disableGSO() {
+	u.enableGSO.Store(false)
+
+	u.gsoMu.Lock()
+	defer u.gsoMu.Unlock()
+	if err := u.flushGSOlocked(); err != nil {
+		u.l.WithError(err).Debug("Failed to flush pending UDP GSO batch")
+	}
+	u.gsoBuf = nil
+	u.gsoSegments = 0
+	u.gsoSegSize = 0
+	u.stopGSOTimerLocked()
 }
 
 func (u *StdConn) getMemInfo(meminfo *[unix.SK_MEMINFO_VARS]uint32) error {
@@ -310,7 +496,239 @@ func (u *StdConn) getMemInfo(meminfo *[unix.SK_MEMINFO_VARS]uint32) error {
 	return nil
 }
 
+// queueGSOPacket copies b into the pending batch for addr. The batch is sent
+// once it is full, or by the flush timer if it does not fill in time. It
+// returns errGSOFallback when b was not queued and the caller should send it
+// directly.
+func (u *StdConn) queueGSOPacket(b []byte, addr netip.AddrPort) error {
+	if len(b) == 0 {
+		return nil
+	}
+
+	u.gsoMu.Lock()
+	defer u.gsoMu.Unlock()
+
+	if !u.enableGSO.Load() || !addr.IsValid() || len(b) > u.gsoMaxBytes {
+		// Drain first so the caller's direct send follows the queued packets.
+		if err := u.flushGSOlocked(); err != nil {
+			return err
+		}
+		return errGSOFallback
+	}
+
+	// A batch holds equally sized segments for one destination and must stay
+	// within gsoMaxBytes; anything else starts a new batch.
+	if u.gsoSegments > 0 && (addr != u.gsoAddr || len(b) != u.gsoSegSize || len(u.gsoBuf)+len(b) > u.gsoMaxBytes) {
+		if err := u.flushGSOlocked(); err != nil {
+			return err
+		}
+		if !u.enableGSO.Load() {
+			// That flush found the kernel rejecting GSO; do not queue more.
+			return errGSOFallback
+		}
+	}
+
+	if u.gsoSegments == 0 {
+		if cap(u.gsoBuf) < u.gsoMaxBytes {
+			u.gsoBuf = make([]byte, 0, u.gsoMaxBytes)
+		}
+		u.gsoAddr = addr
+		u.gsoSegSize = len(b)
+	}
+
+	u.gsoBuf = append(u.gsoBuf, b...)
+	u.gsoSegments++
+
+	if u.gsoSegments >= u.gsoMaxSegments || u.gsoFlushTimeout <= 0 {
+		return u.flushGSOlocked()
+	}
+
+	u.scheduleGSOFlushLocked()
+	return nil
+}
+
+// flushGSOlocked sends the pending batch, if any, and resets the batch state.
+// The caller must hold gsoMu. If the kernel rejects the segmented send, GSO is
+// switched off and the segments are sent one at a time instead.
+func (u *StdConn) flushGSOlocked() error {
+	u.stopGSOTimerLocked()
+	if u.gsoSegments == 0 {
+		return nil
+	}
+
+	// gsoMu stays held until the send returns, so the batch buffer can be
+	// handed to the kernel as-is rather than copied first.
+	payload := u.gsoBuf
+	addr := u.gsoAddr
+	segSize := u.gsoSegSize
+
+	u.gsoBuf = u.gsoBuf[:0]
+	u.gsoSegments = 0
+	u.gsoSegSize = 0
+
+	if segSize <= 0 {
+		return errGSOFallback
+	}
+
+	err := u.sendSegmented(payload, addr, segSize)
+	if errors.Is(err, errGSODisabled) {
+		u.l.WithField("addr", addr).Warn("UDP GSO disabled by kernel, falling back to sendto")
+		u.enableGSO.Store(false)
+		return u.sendSegmentsIndividually(payload, addr, segSize)
+	}
+
+	return err
+}
+
+// sendSegmented sends payload to addr in a single sendmsg call, attaching a
+// UDP_SEGMENT control message so the kernel splits it into segSize-byte
+// datagrams. It returns errGSODisabled when the kernel refuses the request.
+func (u *StdConn) sendSegmented(payload []byte, addr netip.AddrPort, segSize int) error {
+	if len(payload) == 0 {
+		return nil
+	}
+
+	control := make([]byte, unix.CmsgSpace(2))
+	hdr := (*unix.Cmsghdr)(unsafe.Pointer(&control[0]))
+	hdr.Level = unix.SOL_UDP
+	hdr.Type = unix.UDP_SEGMENT
+	setCmsgLen(hdr, unix.CmsgLen(2))
+	// The kernel reads the size as a host-order uint16, so copy the raw bytes
+	// instead of assuming little-endian.
+	gsoSize := uint16(segSize)
+	copy(control[unix.CmsgLen(0):], (*[2]byte)(unsafe.Pointer(&gsoSize))[:])
+
+	var sa unix.Sockaddr
+	if addr.Addr().Is4() {
+		var sa4 unix.SockaddrInet4
+		sa4.Port = int(addr.Port())
+		sa4.Addr = addr.Addr().As4()
+		sa = &sa4
+	} else {
+		var sa6 unix.SockaddrInet6
+		sa6.Port = int(addr.Port())
+		sa6.Addr = addr.Addr().As16()
+		sa = &sa6
+	}
+
+	if _, err := unix.SendmsgN(u.sysFd, payload, control, sa, 0); err != nil {
+		if errno, ok := err.(syscall.Errno); ok && (errno == unix.EINVAL || errno == unix.ENOTSUP || errno == unix.EOPNOTSUPP) {
+			return errGSODisabled
+		}
+		return &net.OpError{Op: "sendmsg", Err: err}
+	}
+	return nil
+}
+
+// sendSegmentsIndividually writes buf to addr as separate datagrams of at most
+// segSize bytes each, stopping at the first error.
+func (u *StdConn) sendSegmentsIndividually(buf []byte, addr netip.AddrPort, segSize int) error {
+	if segSize <= 0 {
+		return errGSOFallback
+	}
+
+	for offset := 0; offset < len(buf); offset += segSize {
+		end := offset + segSize
+		if end > len(buf) {
+			end = len(buf)
+		}
+		var err error
+		if u.isV4 {
+			err = u.writeTo4(buf[offset:end], addr)
+		} else {
+			err = u.writeTo6(buf[offset:end], addr)
+		}
+		if err != nil {
+			return err
+		}
+	}
+	return nil
+}
+
+// scheduleGSOFlushLocked (re)arms the flush timer for gsoFlushTimeout from
+// now. The caller must hold gsoMu.
+func (u *StdConn) scheduleGSOFlushLocked() {
+	if u.gsoTimer == nil {
+		u.gsoTimer = time.AfterFunc(u.gsoFlushTimeout, u.gsoFlushTimer)
+		return
+	}
+	u.gsoTimer.Reset(u.gsoFlushTimeout)
+}
+
+// stopGSOTimerLocked cancels and forgets the flush timer. The caller must
+// hold gsoMu.
+func (u *StdConn) stopGSOTimerLocked() {
+	if u.gsoTimer != nil {
+		u.gsoTimer.Stop()
+		u.gsoTimer = nil
+	}
+}
+
+// gsoFlushTimer is the timer callback that sends a batch which did not fill
+// up before gsoFlushTimeout elapsed.
+func (u *StdConn) gsoFlushTimer() {
+	u.gsoMu.Lock()
+	defer u.gsoMu.Unlock()
+	if err := u.flushGSOlocked(); err != nil {
+		u.l.WithError(err).Debug("Failed to flush UDP GSO batch")
+	}
+}
+
+// parseGROControl returns the segment size carried by the UDP_GRO control
+// message in control, or 0 when there is none. The kernel reports only the
+// size; the number of coalesced datagrams follows from the payload length.
+func parseGROControl(control []byte) int {
+	if len(control) == 0 {
+		return 0
+	}
+
+	cmsgs, err := unix.ParseSocketControlMessage(control)
+	if err != nil {
+		return 0
+	}
+
+	for _, c := range cmsgs {
+		if c.Header.Level != unix.SOL_UDP || c.Header.Type != unix.UDP_GRO {
+			continue
+		}
+		// The value is a host-order C int; a 2-byte form is accepted as well
+		// in case the data was shortened.
+		switch {
+		case len(c.Data) >= 4:
+			var size int32
+			copy((*[4]byte)(unsafe.Pointer(&size))[:], c.Data[:4])
+			return int(size)
+		case len(c.Data) >= 2:
+			var size uint16
+			copy((*[2]byte)(unsafe.Pointer(&size))[:], c.Data[:2])
+			return int(size)
+		}
+	}
+
+	return 0
+}
+
+// emitGROSegments passes payload to r in segSize-byte pieces, the last of
+// which may be shorter. It returns false without calling r when segSize does
+// not split the payload.
+func (u *StdConn) emitGROSegments(r EncReader, addr netip.AddrPort, payload []byte, segSize int) bool {
+	if segSize <= 0 || segSize >= len(payload) {
+		return false
+	}
+
+	for offset := 0; offset < len(payload); offset += segSize {
+		end := offset + segSize
+		if end > len(payload) {
+			end = len(payload)
+		}
+		r(addr, payload[offset:end])
+	}
+	return true
+}
+
 func (u *StdConn) Close() error {
+	// Send anything still queued before the descriptor is closed.
+	u.disableGSO()
 	return syscall.Close(u.sysFd)
 }
 
diff --git a/udp/udp_linux_32.go b/udp/udp_linux_32.go
index de8f1cd..3204776 100644
--- a/udp/udp_linux_32.go
+++ b/udp/udp_linux_32.go
@@ -52,3 +52,25 @@ func (u *StdConn) PrepareRawMessages(n int) ([]rawMessage, [][]byte, [][]byte) {
 
 	return msgs, buffers, names
 }
+
+// setRawMessageControl points msg at buf for ancillary data; an empty buf
+// detaches any control buffer.
+func setRawMessageControl(msg *rawMessage, buf []byte) {
+	if len(buf) == 0 {
+		msg.Hdr.Control = nil
+		msg.Hdr.Controllen = 0
+		return
+	}
+	msg.Hdr.Control = &buf[0]
+	msg.Hdr.Controllen = uint32(len(buf))
+}
+
+// getRawMessageControlLen returns the control length recorded in msg.
+func getRawMessageControlLen(msg *rawMessage) int {
+	return int(msg.Hdr.Controllen)
+}
+
+// setCmsgLen stores l in h.Len using this platform's field width.
+func setCmsgLen(h *unix.Cmsghdr, l int) {
+	h.Len = uint32(l)
+}
diff --git a/udp/udp_linux_64.go b/udp/udp_linux_64.go
index 48c5a97..a09173d 100644
--- a/udp/udp_linux_64.go
+++ b/udp/udp_linux_64.go
@@ -55,3 +55,25 @@ func (u *StdConn) PrepareRawMessages(n int) ([]rawMessage, [][]byte, [][]byte) {
 
 	return msgs, buffers, names
 }
+
+// setRawMessageControl points msg at buf for ancillary data; an empty buf
+// detaches any control buffer.
+func setRawMessageControl(msg *rawMessage, buf []byte) {
+	if len(buf) == 0 {
+		msg.Hdr.Control = nil
+		msg.Hdr.Controllen = 0
+		return
+	}
+	msg.Hdr.Control = &buf[0]
+	msg.Hdr.Controllen = uint64(len(buf))
+}
+
+// getRawMessageControlLen returns the control length recorded in msg.
+func getRawMessageControlLen(msg *rawMessage) int {
+	return int(msg.Hdr.Controllen)
+}
+
+// setCmsgLen stores l in h.Len using this platform's field width.
+func setCmsgLen(h *unix.Cmsghdr, l int) {
+	h.Len = uint64(l)
+}