diff --git a/udp/udp_linux.go b/udp/udp_linux.go
index 1c628836..03fd99d8 100644
--- a/udp/udp_linux.go
+++ b/udp/udp_linux.go
@@ -54,6 +54,13 @@ type StdConn struct {
 	// probed once at socket creation. When true, WriteSegmented takes a
 	// single-syscall GSO path; otherwise it falls back to a WriteTo loop.
 	gsoSupported bool
+
+	// UDP GRO (recvmsg with UDP_GRO cmsg) support. groSupported is probed
+	// once at socket creation. When true, listenOutBatch allocates larger
+	// RX buffers and a per-entry cmsg slot so the kernel can coalesce
+	// consecutive same-flow datagrams into a single recvmmsg entry; the
+	// delivered cmsg carries the gso_size used to split them back apart.
+	groSupported bool
 }
 
 func setReusePort(network, address string, c syscall.RawConn) error {
@@ -104,6 +111,12 @@ func NewListener(l *logrus.Logger, ip netip.Addr, port int, multi bool, batch in
 
 	out.writeFunc = out.sendmmsgRawWrite
 	out.prepareGSO()
+	// GRO delivers coalesced superpackets that need a cmsg to split back
+	// into segments. The single-packet RX path uses ReadFromUDPAddrPort
+	// and cannot see that cmsg, so only enable GRO for the batch path.
+	if batch > 1 {
+		out.prepareGRO()
+	}
 	return out, nil
 }
 
@@ -162,6 +175,34 @@ func (u *StdConn) prepareGSO() {
 	u.gsoSupported = true
 }
 
+// udpGROBufferSize sizes the per-entry recvmmsg buffer when UDP_GRO is on.
+// The kernel stitches a run of same-flow datagrams into a single skb whose
+// length is bounded by sk_gso_max_size (typically 65535); anything larger
+// would be MSG_TRUNCed. We use the maximum representable UDP length so a
+// full superpacket always lands intact.
+const udpGROBufferSize = 65535
+
+// udpGROCmsgPayload is the size of the UDP_GRO cmsg data delivered by the
+// kernel: a single int (gso_size in bytes). See udp_cmsg_recv() in
+// net/ipv4/udp.c.
+const udpGROCmsgPayload = 4
+
+// prepareGRO turns on UDP_GRO so the kernel coalesces consecutive same-flow
+// datagrams into one recvmmsg entry, with a cmsg carrying the gso_size used
+// to split them back apart on the application side.
+func (u *StdConn) prepareGRO() {
+	var probeErr error
+	if err := u.rawConn.Control(func(fd uintptr) {
+		probeErr = unix.SetsockoptInt(int(fd), unix.IPPROTO_UDP, unix.UDP_GRO, 1)
+	}); err != nil {
+		return
+	}
+	if probeErr != nil {
+		return
+	}
+	u.groSupported = true
+}
+
 func (u *StdConn) SupportsMultipleReaders() bool {
 	return true
 }
@@ -282,7 +323,13 @@ func (u *StdConn) listenOutBatch(r EncReader, flush func()) error {
 	var n int
 	var operr error
 
-	msgs, buffers, names := u.PrepareRawMessages(u.batch)
+	bufSize := MTU
+	cmsgSpace := 0
+	if u.groSupported {
+		bufSize = udpGROBufferSize
+		cmsgSpace = unix.CmsgSpace(udpGROCmsgPayload)
+	}
+	msgs, buffers, names, _ := u.PrepareRawMessages(u.batch, bufSize, cmsgSpace)
 
 	//reader needs to capture variables from this function, since it's used as a lambda with rawConn.Read
 	//defining it outside the loop so it gets re-used
@@ -292,6 +339,11 @@
 	}
 	for {
+		if cmsgSpace > 0 {
+			for i := range msgs {
+				setMsgControllen(&msgs[i].Hdr, cmsgSpace)
+			}
+		}
 		err := u.rawConn.Read(reader)
 		if err != nil {
 			return err
 		}
@@ -307,7 +359,28 @@
 			} else {
 				ip, _ = netip.AddrFromSlice(names[i][8:24])
 			}
-			r(netip.AddrPortFrom(ip.Unmap(), binary.BigEndian.Uint16(names[i][2:4])), buffers[i][:msgs[i].Len])
+			from := netip.AddrPortFrom(ip.Unmap(), binary.BigEndian.Uint16(names[i][2:4]))
+			payload := buffers[i][:msgs[i].Len]
+
+			segSize := 0
+			if u.groSupported {
+				segSize = parseUDPGRO(&msgs[i].Hdr)
+			}
+			if segSize <= 0 || segSize >= len(payload) {
+				// No coalescing happened (or a lone datagram).
+				r(from, payload)
+				continue
+			}
+			// GRO superpacket: the kernel guarantees every segment is
+			// exactly segSize bytes except for the final one, which may be
+			// short.
+			for off := 0; off < len(payload); off += segSize {
+				end := off + segSize
+				if end > len(payload) {
+					end = len(payload)
+				}
+				r(from, payload[off:end])
+			}
 		}
 		// End-of-batch: let callers (e.g. TUN write coalescer) flush any
 		// state they accumulated across this batch.
@@ -315,6 +388,38 @@
 		flush()
 	}
 }
+// parseUDPGRO walks the control buffer on hdr looking for a SOL_UDP/UDP_GRO
+// cmsg and returns the gso_size (bytes per coalesced segment) it carries.
+// Returns 0 when no UDP_GRO cmsg is present, which is the normal case for
+// lone datagrams that the kernel did not coalesce.
+func parseUDPGRO(hdr *msghdr) int {
+	controllen := int(hdr.Controllen)
+	if controllen < unix.SizeofCmsghdr || hdr.Control == nil {
+		return 0
+	}
+	ctrl := unsafe.Slice(hdr.Control, controllen)
+	off := 0
+	for off+unix.SizeofCmsghdr <= len(ctrl) {
+		ch := (*unix.Cmsghdr)(unsafe.Pointer(&ctrl[off]))
+		clen := int(ch.Len)
+		if clen < unix.SizeofCmsghdr || off+clen > len(ctrl) {
+			return 0
+		}
+		if ch.Level == unix.SOL_UDP && ch.Type == unix.UDP_GRO {
+			dataOff := off + unix.CmsgLen(0)
+			if dataOff+udpGROCmsgPayload <= len(ctrl) {
+				return int(int32(binary.NativeEndian.Uint32(ctrl[dataOff : dataOff+udpGROCmsgPayload])))
+			}
+			return 0
+		}
+		// Advance by the aligned cmsg space. CmsgSpace(n) is the stride
+		// from one header to the next (len aligned up to the platform's
+		// cmsg alignment).
+		off += unix.CmsgSpace(clen - unix.CmsgLen(0))
+	}
+	return 0
+}
+
 func (u *StdConn) ListenOut(r EncReader, flush func()) error {
 	if u.batch == 1 {
 		return u.listenOutSingle(r, flush)
diff --git a/udp/udp_linux_32.go b/udp/udp_linux_32.go
index 1d4c7fbb..0f153a49 100644
--- a/udp/udp_linux_32.go
+++ b/udp/udp_linux_32.go
@@ -30,13 +30,18 @@ type rawMessage struct {
 	Len uint32
 }
 
-func (u *StdConn) PrepareRawMessages(n int) ([]rawMessage, [][]byte, [][]byte) {
+func (u *StdConn) PrepareRawMessages(n, bufSize, cmsgSpace int) ([]rawMessage, [][]byte, [][]byte, []byte) {
 	msgs := make([]rawMessage, n)
 	buffers := make([][]byte, n)
 	names := make([][]byte, n)
 
+	var cmsgs []byte
+	if cmsgSpace > 0 {
+		cmsgs = make([]byte, n*cmsgSpace)
+	}
+
 	for i := range msgs {
-		buffers[i] = make([]byte, MTU)
+		buffers[i] = make([]byte, bufSize)
 		names[i] = make([]byte, unix.SizeofSockaddrInet6)
 
 		vs := []iovec{
@@ -48,9 +53,14 @@
 
 		msgs[i].Hdr.Name = &names[i][0]
 		msgs[i].Hdr.Namelen = uint32(len(names[i]))
+
+		if cmsgSpace > 0 {
+			msgs[i].Hdr.Control = &cmsgs[i*cmsgSpace]
+			msgs[i].Hdr.Controllen = uint32(cmsgSpace)
+		}
 	}
 
-	return msgs, buffers, names
+	return msgs, buffers, names, cmsgs
 }
 
 func setIovLen(v *iovec, n int) {
diff --git a/udp/udp_linux_64.go b/udp/udp_linux_64.go
index 0293b1a4..dc373538 100644
--- a/udp/udp_linux_64.go
+++ b/udp/udp_linux_64.go
@@ -33,13 +33,18 @@ type rawMessage struct {
 	Pad0 [4]byte
 }
 
-func (u *StdConn) PrepareRawMessages(n int) ([]rawMessage, [][]byte, [][]byte) {
+func (u *StdConn) PrepareRawMessages(n, bufSize, cmsgSpace int) ([]rawMessage, [][]byte, [][]byte, []byte) {
 	msgs := make([]rawMessage, n)
 	buffers := make([][]byte, n)
 	names := make([][]byte, n)
 
+	var cmsgs []byte
+	if cmsgSpace > 0 {
+		cmsgs = make([]byte, n*cmsgSpace)
+	}
+
 	for i := range msgs {
-		buffers[i] = make([]byte, MTU)
+		buffers[i] = make([]byte, bufSize)
 		names[i] = make([]byte, unix.SizeofSockaddrInet6)
 
 		vs := []iovec{
@@ -51,9 +56,14 @@
 
 		msgs[i].Hdr.Name = &names[i][0]
 		msgs[i].Hdr.Namelen = uint32(len(names[i]))
+
+		if cmsgSpace > 0 {
+			msgs[i].Hdr.Control = &cmsgs[i*cmsgSpace]
+			msgs[i].Hdr.Controllen = uint64(cmsgSpace)
+		}
 	}
 
-	return msgs, buffers, names
+	return msgs, buffers, names, cmsgs
 }
 
 func setIovLen(v *iovec, n int) {