cursed gso

This commit is contained in:
JackDoan
2025-11-06 17:56:46 -06:00
parent 2ab75709ad
commit 7999b62147
14 changed files with 719 additions and 29 deletions

191
cmd/gso/gso.go Normal file
View File

@@ -0,0 +1,191 @@
package main
import (
"encoding/binary"
"errors"
"flag"
"fmt"
"log"
"net"
"net/netip"
"time"
"unsafe"
"golang.org/x/sys/unix"
)
const (
// UDP_SEGMENT is the numeric value of the UDP-level option / control message
// type used for generic segmentation offload (GSO).
// NOTE(review): the visible code in this program uses unix.UDP_SEGMENT rather
// than this constant — confirm whether this declaration is still needed.
UDP_SEGMENT = 103
// maxGSOSize is the largest value main accepts for the -gso flag
// (typical MTU minus headers).
maxGSOSize = 1400
)
// main parses the command-line flags, opens and binds an IPv4 UDP socket and
// sends the test payload -count times with a UDP_SEGMENT control message
// attached. After a one second pause it sends the same payload -count more
// times without the control message so the two behaviours can be compared on
// the wire.
func main() {
	destAddr := flag.String("dest", "10.4.0.16:4202", "Destination address")
	gsoSize := flag.Int("gso", 1400, "GSO segment size")
	totalSize := flag.Int("size", 14000, "Total payload size to send")
	count := flag.Int("count", 1, "Number of packets to send")
	flag.Parse()

	if *gsoSize > maxGSOSize {
		log.Fatalf("GSO size %d exceeds maximum %d", *gsoSize, maxGSOSize)
	}
	// The segment size is narrowed to uint16 below; zero or negative values
	// would silently turn into nonsense.
	if *gsoSize <= 0 {
		log.Fatalf("GSO size %d must be positive", *gsoSize)
	}
	if *totalSize <= 0 {
		log.Fatalf("Total payload size %d must be positive", *totalSize)
	}

	// Resolve the destination once and use the resolved value, so anything
	// the resolver accepts (including host names) works instead of panicking
	// in a second, stricter parse.
	resolved, err := net.ResolveUDPAddr("udp", *destAddr)
	if err != nil {
		log.Fatalf("Failed to resolve address: %v", err)
	}
	ap := resolved.AddrPort()
	// Unmap so an IPv4-mapped IPv6 form is encoded as a plain IPv4 sockaddr.
	dest := netip.AddrPortFrom(ap.Addr().Unmap(), ap.Port())
	if !dest.Addr().Is4() {
		log.Fatalf("Destination %q is not an IPv4 address; this tool only opens an AF_INET socket", *destAddr)
	}

	// Create a UDP socket
	fd, err := unix.Socket(unix.AF_INET, unix.SOCK_DGRAM, unix.IPPROTO_UDP)
	if err != nil {
		log.Fatalf("Failed to create socket: %v", err)
	}
	defer unix.Close(fd)

	// Bind to a local address
	localAddr := &unix.SockaddrInet4{
		Port: 0, // Let the system choose a port
	}
	if err := unix.Bind(fd, localAddr); err != nil {
		log.Fatalf("Failed to bind socket: %v", err)
	}

	fmt.Printf("Sending UDP packets with GSO enabled\n")
	fmt.Printf("Destination: %s\n", *destAddr)
	fmt.Printf("GSO segment size: %d bytes\n", *gsoSize)
	fmt.Printf("Total payload size: %d bytes\n", *totalSize)
	fmt.Printf("Number of packets: %d\n\n", *count)

	// Create payload: a repeating 0..255 byte pattern.
	payload := make([]byte, *totalSize)
	for i := range payload {
		payload[i] = byte(i % 256)
	}

	// The socket-wide option is intentionally left unset; segmentation is
	// requested per message through the control header instead.
	//if err := unix.SetsockoptInt(fd, unix.SOL_UDP, unix.UDP_SEGMENT, 1400); err != nil {
	//	panic(err)
	//}

	segSize := uint16(*gsoSize)
	sendBurst(fd, payload, dest, segSize, *count, true)

	fmt.Printf("now, let's send without the correct ctrl header\n")
	time.Sleep(time.Second)
	sendBurst(fd, payload, dest, segSize, *count, false)
}

// sendBurst calls WriteBatch count times, logging each failure and printing a
// progress line every 100 sends and after the final one.
func sendBurst(fd int, payload []byte, dest netip.AddrPort, segSize uint16, count int, withHeader bool) {
	for i := 0; i < count; i++ {
		if err := WriteBatch(fd, payload, dest, segSize, withHeader); err != nil {
			log.Printf("Send error on packet %d: %v", i, err)
			continue
		}
		if (i+1)%100 == 0 || i == count-1 {
			fmt.Printf("Sent %d packets\n", i+1)
		}
	}
}
func WriteBatch(fd int, payload []byte, addr netip.AddrPort, segSize uint16, withHeader bool) error {
msgs := make([]rawMessage, 0, 1)
iovs := make([]iovec, 0, 1)
names := make([][unix.SizeofSockaddrInet6]byte, 0, 1)
sent := 0
pkts := []BatchPacket{
{
Payload: payload,
Addr: addr,
},
}
for _, pkt := range pkts {
if len(pkt.Payload) == 0 {
sent++
continue
}
msgs = append(msgs, rawMessage{})
iovs = append(iovs, iovec{})
names = append(names, [unix.SizeofSockaddrInet6]byte{})
idx := len(msgs) - 1
msg := &msgs[idx]
iov := &iovs[idx]
name := &names[idx]
setIovecSlice(iov, pkt.Payload)
msg.Hdr.Iov = iov
msg.Hdr.Iovlen = 1
if withHeader {
setRawMessageControl(msg, buildGSOControlMessage(segSize)) //
} else {
setRawMessageControl(msg, nil) //
}
msg.Hdr.Flags = 0
nameLen, err := encodeSockaddr(name[:], pkt.Addr)
if err != nil {
return err
}
msg.Hdr.Name = &name[0]
msg.Hdr.Namelen = nameLen
}
if len(msgs) == 0 {
return errors.New("nothing to write")
}
offset := 0
for offset < len(msgs) {
n, _, errno := unix.Syscall6(
unix.SYS_SENDMMSG,
uintptr(fd),
uintptr(unsafe.Pointer(&msgs[offset])),
uintptr(len(msgs)-offset),
0,
0,
0,
)
if errno != 0 {
if errno == unix.EINTR {
continue
}
return &net.OpError{Op: "sendmmsg", Err: errno}
}
if n == 0 {
break
}
offset += int(n)
}
return nil
}
// buildGSOControlMessage returns a freshly allocated control buffer holding a
// single SOL_UDP/UDP_SEGMENT control message whose two-byte payload is segSize
// in host byte order.
func buildGSOControlMessage(segSize uint16) []byte {
	buf := make([]byte, unix.CmsgSpace(2))

	h := (*unix.Cmsghdr)(unsafe.Pointer(&buf[0]))
	h.Level, h.Type = unix.SOL_UDP, unix.UDP_SEGMENT
	setCmsgLen(h, unix.CmsgLen(2))

	// The payload starts right after the (aligned) header.
	dataOff := unix.CmsgLen(0)
	binary.NativeEndian.PutUint16(buf[dataOff:dataOff+2], segSize)
	return buf
}

85
cmd/gso/helper.go Normal file
View File

@@ -0,0 +1,85 @@
package main
import (
"encoding/binary"
"fmt"
"net/netip"
"unsafe"
"golang.org/x/sys/unix"
)
// iovec describes one buffer of a scatter/gather message: a pointer to its
// first byte and its length. Pointers to values of this type are stored in
// msghdr.Iov and handed to the kernel by WriteBatch.
// NOTE(review): the field widths presumably match struct iovec on 64-bit
// Linux — confirm before building for other targets.
type iovec struct {
Base *byte
Len uint64
}
// msghdr is the per-message header passed to sendmmsg: destination address
// (Name/Namelen), payload buffers (Iov/Iovlen), optional control data
// (Control/Controllen) and flags.
// NOTE(review): the explicit Pad0/Pad1 fields presumably reproduce the padding
// of struct msghdr on 64-bit Linux — confirm; do not reorder fields.
type msghdr struct {
Name *byte
Namelen uint32
Pad0 [4]byte
Iov *iovec
Iovlen uint64
Control *byte
Controllen uint64
Flags int32
Pad1 [4]byte
}
// rawMessage is one element of the array passed to sendmmsg by WriteBatch: a
// msghdr followed by a length field.
// NOTE(review): presumably mirrors struct mmsghdr on 64-bit Linux, where the
// kernel writes the transmitted byte count into Len — confirm.
type rawMessage struct {
Hdr msghdr
Len uint32
Pad0 [4]byte
}
// BatchPacket pairs a payload with the address it should be sent to.
type BatchPacket struct {
// Payload is the data to send; WriteBatch skips packets whose payload is empty.
Payload []byte
// Addr is the destination address and port.
Addr netip.AddrPort
}
// encodeSockaddr writes addr into dst as a raw socket address and returns the
// number of bytes written.
//
// IPv4 addresses are encoded as a RawSockaddrInet4; everything else is encoded
// as a RawSockaddrInet6 using the 16-byte form of the address. The port is
// stored in network byte order. An error is returned, and dst is left
// untouched, when dst is too small for the chosen encoding.
func encodeSockaddr(dst []byte, addr netip.AddrPort) (uint32, error) {
	if addr.Addr().Is4() {
		const size = unix.SizeofSockaddrInet4
		if len(dst) < size {
			return 0, fmt.Errorf("sockaddr buffer too small: have %d bytes, need %d", len(dst), size)
		}
		var sa unix.RawSockaddrInet4
		sa.Family = unix.AF_INET
		sa.Addr = addr.Addr().As4()
		// Write the port byte-by-byte so it ends up big-endian regardless
		// of the host byte order.
		binary.BigEndian.PutUint16((*[2]byte)(unsafe.Pointer(&sa.Port))[:], addr.Port())
		copy(dst[:size], (*(*[unix.SizeofSockaddrInet4]byte)(unsafe.Pointer(&sa)))[:])
		return uint32(size), nil
	}

	const size = unix.SizeofSockaddrInet6
	if len(dst) < size {
		return 0, fmt.Errorf("sockaddr buffer too small: have %d bytes, need %d", len(dst), size)
	}
	var sa unix.RawSockaddrInet6
	sa.Family = unix.AF_INET6
	sa.Addr = addr.Addr().As16()
	binary.BigEndian.PutUint16((*[2]byte)(unsafe.Pointer(&sa.Port))[:], addr.Port())
	copy(dst[:size], (*(*[unix.SizeofSockaddrInet6]byte)(unsafe.Pointer(&sa)))[:])
	return uint32(size), nil
}
// setRawMessageControl points msg's control fields at buf. An empty or nil buf
// clears both the pointer and the length.
func setRawMessageControl(msg *rawMessage, buf []byte) {
	var (
		ptr  *byte
		size uint64
	)
	if len(buf) > 0 {
		ptr = &buf[0]
		size = uint64(len(buf))
	}
	msg.Hdr.Control = ptr
	msg.Hdr.Controllen = size
}
// setCmsgLen stores l in the length field of the control message header h.
func setCmsgLen(h *unix.Cmsghdr, l int) {
	n := uint64(l)
	h.Len = n
}
// setIovecSlice makes iov describe the bytes of b. An empty or nil b results
// in a nil base pointer and a zero length.
func setIovecSlice(iov *iovec, b []byte) {
	var (
		base *byte
		size uint64
	)
	if len(b) > 0 {
		base = &b[0]
		size = uint64(len(b))
	}
	iov.Base = base
	iov.Len = size
}

View File

@@ -518,12 +518,12 @@ func (cm *connectionManager) sendPunch(hostinfo *HostInfo) {
if cm.punchy.GetTargetEverything() {
hostinfo.remotes.ForEach(cm.hostMap.GetPreferredRanges(), func(addr netip.AddrPort, preferred bool) {
cm.metricsTxPunchy.Inc(1)
cm.intf.outside.WriteTo([]byte{1}, addr)
cm.intf.outside.WriteDirect([]byte{1}, addr)
})
} else if hostinfo.remote.IsValid() {
cm.metricsTxPunchy.Inc(1)
cm.intf.outside.WriteTo([]byte{1}, hostinfo.remote)
cm.intf.outside.WriteDirect([]byte{1}, hostinfo.remote)
}
}

View File

@@ -15,7 +15,8 @@ import (
// TODO: In a 5Gbps test, 1024 is not sufficient. With a 1400 MTU this is about 1.4Gbps of window, assuming full packets.
// 4092 should be sufficient for 5Gbps
const ReplayWindow = 8192
// TODO this is a horrible amount of RAM to waste per-tunnel
const ReplayWindow = 0xffff / 2
type ConnectionState struct {
eKey *NebulaCipherState

View File

@@ -348,7 +348,7 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet
msg = existing.HandshakePacket[2]
f.messageMetrics.Tx(header.Handshake, header.MessageSubType(msg[1]), 1)
if addr.IsValid() {
err := f.outside.WriteTo(msg, addr)
err := f.outside.WriteDirect(msg, addr)
if err != nil {
f.l.WithField("vpnAddrs", existing.vpnAddrs).WithField("udpAddr", addr).
WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).WithField("cached", true).
@@ -417,7 +417,7 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet
// Do the send
f.messageMetrics.Tx(header.Handshake, header.MessageSubType(msg[1]), 1)
if addr.IsValid() {
err = f.outside.WriteTo(msg, addr)
err = f.outside.WriteDirect(msg, addr)
if err != nil {
f.l.WithField("vpnAddrs", vpnAddrs).WithField("udpAddr", addr).
WithField("certName", certName).

View File

@@ -238,7 +238,7 @@ func (hm *HandshakeManager) handleOutbound(vpnIp netip.Addr, lighthouseTriggered
var sentTo []netip.AddrPort
hostinfo.remotes.ForEach(hm.mainHostMap.GetPreferredRanges(), func(addr netip.AddrPort, _ bool) {
hm.messageMetrics.Tx(header.Handshake, header.MessageSubType(hostinfo.HandshakePacket[0][1]), 1)
err := hm.outside.WriteTo(hostinfo.HandshakePacket[0], addr)
err := hm.outside.WriteDirect(hostinfo.HandshakePacket[0], addr)
if err != nil {
hostinfo.logger(hm.l).WithField("udpAddr", addr).
WithField("initiatorIndex", hostinfo.localIndexId).

View File

@@ -8,6 +8,7 @@ import (
"github.com/slackhq/nebula/header"
"github.com/slackhq/nebula/iputil"
"github.com/slackhq/nebula/noiseutil"
"github.com/slackhq/nebula/packet"
"github.com/slackhq/nebula/routing"
)
@@ -324,7 +325,7 @@ func (f *Interface) SendVia(via *HostInfo,
via.logger(f.l).WithError(err).Info("Failed to EncryptDanger in sendVia")
return
}
err = f.writers[0].WriteTo(out, via.remote)
err = f.writers[0].WriteDirect(out, via.remote)
if err != nil {
via.logger(f.l).WithError(err).Info("Failed to WriteTo in sendVia")
}
@@ -384,19 +385,29 @@ func (f *Interface) sendNoMetrics(t header.MessageType, st header.MessageSubType
}
if remote.IsValid() {
err = f.writers[q].WriteTo(out, remote)
pkt := packet.GetPool().Get()
copy(pkt.Payload[:], out)
pkt.Payload = pkt.Payload[:len(out)]
pkt.Addr = remote
err = f.writers[q].WriteTo(pkt)
if err != nil {
hostinfo.logger(f.l).WithError(err).
WithField("udpAddr", remote).Error("Failed to write outgoing packet")
}
} else if hostinfo.remote.IsValid() {
err = f.writers[q].WriteTo(out, hostinfo.remote)
pkt := packet.GetPool().Get()
copy(pkt.Payload, out)
pkt.Payload = pkt.Payload[:len(out)]
pkt.Addr = hostinfo.remote
err = f.writers[q].WriteTo(pkt)
if err != nil {
hostinfo.logger(f.l).WithError(err).
WithField("udpAddr", remote).Error("Failed to write outgoing packet")
}
} else {
// Try to send via a relay
//todo relay is slow sorryyy
for _, relayIP := range hostinfo.relayState.CopyRelayIps() {
relayHostInfo, relay, err := f.hostMap.QueryVpnAddrsRelayFor(hostinfo.vpnAddrs, relayIP)
if err != nil {

View File

@@ -207,7 +207,7 @@ func NewInterface(ctx context.Context, c *InterfaceConfig) (*Interface, error) {
l: c.l,
}
ifce.pktPool = packet.NewPool()
ifce.pktPool = packet.GetPool()
ifce.tryPromoteEvery.Store(c.tryPromoteEvery)
ifce.reQueryEvery.Store(c.reQueryEvery)
@@ -327,6 +327,26 @@ func (f *Interface) listenIn(reader io.ReadWriteCloser, i int) {
f.wg.Done()
}
//// todo why? understand!
//func normalizeGROSegSize(segSize, total int) int {
// if segCount > 1 && total > 0 {
// avg := total / segCount
// if avg > 0 {
// if segSize > avg {
// if segSize-8 == avg {
// segSize = avg
// } else if segSize > total {
// segSize = avg
// }
// }
// }
// }
// if segSize > total {
// segSize = total
// }
// return segSize
//}
func (f *Interface) workerIn(i int, ctx context.Context) {
lhh := f.lightHouse.NewRequestHandler()
conntrackCache := firewall.NewConntrackCacheTicker(f.conntrackCacheTimeout)
@@ -338,7 +358,18 @@ func (f *Interface) workerIn(i int, ctx context.Context) {
for {
select {
case p := <-f.inbound:
f.readOutsidePackets(p.Addr, nil, result2[:0], p.Payload, h, fwPacket2, lhh, nb2, i, conntrackCache.Get(f.l))
if p.SegSize > 0 && p.SegSize < len(p.Payload) {
for offset := 0; offset < len(p.Payload); offset += p.SegSize {
end := offset + p.SegSize
if end > len(p.Payload) {
end = len(p.Payload)
}
f.readOutsidePackets(p.Addr, nil, result2[:0], p.Payload[offset:end], h, fwPacket2, lhh, nb2, i, conntrackCache.Get(f.l))
}
} else {
f.readOutsidePackets(p.Addr, nil, result2[:0], p.Payload, h, fwPacket2, lhh, nb2, i, conntrackCache.Get(f.l))
}
f.pktPool.Put(p)
case <-ctx.Done():
f.wg.Done()
@@ -357,7 +388,7 @@ func (f *Interface) workerOut(i int, ctx context.Context) {
select {
case data := <-f.outbound:
f.consumeInsidePacket(data.Payload, fwPacket1, nb1, result1, i, conntrackCache.Get(f.l))
f.pktPool.Put(data)
//f.pktPool.Put(data) //todo if err pls put packet back
case <-ctx.Done():
f.wg.Done()
return

View File

@@ -1329,7 +1329,7 @@ func (lhh *LightHouseHandler) handleHostPunchNotification(n *NebulaMeta, fromVpn
go func() {
time.Sleep(lhh.lh.punchy.GetDelay())
lhh.lh.metricHolepunchTx.Inc(1)
lhh.lh.punchConn.WriteTo(empty, vpnPeer)
lhh.lh.punchConn.WriteDirect(empty, vpnPeer)
}()
if lhh.l.Level >= logrus.DebugLevel {

View File

@@ -519,7 +519,7 @@ func (f *Interface) sendRecvError(endpoint netip.AddrPort, index uint32) {
f.messageMetrics.Tx(header.RecvError, 0, 1)
b := header.Encode(make([]byte, header.Len), header.Version, header.RecvError, 0, index, 0)
_ = f.outside.WriteTo(b, endpoint)
_ = f.outside.WriteDirect(b, endpoint)
if f.l.Level >= logrus.DebugLevel {
f.l.WithField("index", index).
WithField("udpAddr", endpoint).

View File

@@ -3,27 +3,36 @@ package packet
import (
"net/netip"
"sync"
"golang.org/x/sys/unix"
)
const Size = 9001
const Size = 0xffff
type Packet struct {
Payload []byte
Control []byte
SegSize int
Addr netip.AddrPort
}
func New() *Packet {
return &Packet{Payload: make([]byte, Size)}
return &Packet{
Payload: make([]byte, Size),
Control: make([]byte, unix.CmsgSpace(2)),
}
}
type Pool struct {
pool sync.Pool
}
func NewPool() *Pool {
return &Pool{
pool: sync.Pool{New: func() any { return New() }},
}
var bigPool = &Pool{
pool: sync.Pool{New: func() any { return New() }},
}
func GetPool() *Pool {
return bigPool
}
func (p *Pool) Get() *Packet {

View File

@@ -17,7 +17,8 @@ type Conn interface {
Rebind() error
LocalAddr() (netip.AddrPort, error)
ListenOut(pg PacketBufferGetter, pc chan *packet.Packet) error
WriteTo(b []byte, addr netip.AddrPort) error
WriteTo(p *packet.Packet) error
WriteDirect(b []byte, port netip.AddrPort) error
ReloadConfig(c *config.C)
Close() error
}

View File

@@ -5,9 +5,11 @@ package udp
import (
"encoding/binary"
"errors"
"fmt"
"net"
"net/netip"
"sync"
"syscall"
"time"
"unsafe"
@@ -19,13 +21,136 @@ import (
"golang.org/x/sys/unix"
)
const (
// defaultGSOMaxSegments is the fallback for the listen.gso_max_segments setting.
defaultGSOMaxSegments = 16
// defaultGSOFlushTimeout is the fallback for the listen.gso_flush_timeout setting.
defaultGSOFlushTimeout = 150 * time.Microsecond
// maxGSOBatchBytes is the upper bound configureGSO clamps listen.gso_max_bytes to.
maxGSOBatchBytes = 0xFFFF
)
var (
// errGSOFallback tells the caller to send the packet through the regular,
// non-GSO write path instead.
errGSOFallback = errors.New("udp gso fallback")
// errGSODisabled is checked for by flushGSOlocked to switch GSO off.
// NOTE(review): no visible code returns this error — confirm where it is produced.
errGSODisabled = errors.New("udp gso disabled")
)
// readTimeout is a 500ms timeval.
// NOTE(review): presumably applied as the socket receive timeout — its use is
// outside the visible code.
var readTimeout = unix.NsecToTimeval(int64(time.Millisecond * 500))
// gsoState holds the pending batch of outgoing packets for one socket together
// with the pre-built sendmsg header used to flush them. Methods with a
// "Locked" suffix are called with m held.
type gsoState struct {
// m guards the rest of the struct.
m sync.Mutex
// Buf is only ever cleared in the visible code.
// NOTE(review): appears to be a leftover from a copy-based design — confirm.
Buf []byte
// Addr is the destination shared by every queued packet.
Addr netip.AddrPort
// SegSize is the payload length shared by every queued packet; 0 when idle.
SegSize int
// MaxSegments is the number of queued packets that triggers a flush.
MaxSegments int
// MaxBytes is the largest payload accepted for batching.
MaxBytes int
// FlushTimeout is how long a partial batch may wait; <= 0 flushes immediately.
FlushTimeout time.Duration
// Timer is the pending flush timer, or nil when none is armed.
Timer *time.Timer
// packets are the queued packets, in send order.
packets []*packet.Packet
// msg is the reusable message header passed to sendmsg.
msg rawMessage
// name backs msg.Hdr.Name (the encoded destination sockaddr).
name [unix.SizeofSockaddrInet6]byte
// iov backs msg.Hdr.Iov, one entry per queued packet.
iov []iovec
// ctrl backs msg.Hdr.Control (the UDP_SEGMENT control message).
ctrl []byte
}
// Init (re)allocates the iovec array, the packet queue and the control buffer
// sized from g.MaxSegments, and wires the reusable message header to them. The
// control buffer is pre-filled with a SOL_UDP/UDP_SEGMENT header; the segment
// size itself is written later. MaxSegments must be at least 1.
func (g *gsoState) Init() {
	// make returns zeroed iovecs, so no explicit reset is required.
	g.iov = make([]iovec, g.MaxSegments)
	g.msg.Hdr.Iov = &g.iov[0]
	g.msg.Hdr.Iovlen = 1

	g.packets = make([]*packet.Packet, 0, g.MaxSegments)

	ctrl := make([]byte, unix.CmsgSpace(2))
	cm := (*unix.Cmsghdr)(unsafe.Pointer(&ctrl[0]))
	cm.Level, cm.Type = unix.SOL_UDP, unix.UDP_SEGMENT
	setCmsgLen(cm, unix.CmsgLen(2))
	g.ctrl = ctrl
	g.msg.Hdr.Control = &ctrl[0]
	g.msg.Hdr.Controllen = uint64(len(ctrl))

	g.name = [unix.SizeofSockaddrInet6]byte{}
	g.msg.Hdr.Name = &g.name[0]
}
// setSegSizeLocked records segSize and writes it into the payload of the
// pre-built UDP_SEGMENT control message. The caller must hold g.m.
//
// The value is stored in host byte order because the kernel reads the control
// payload as a native 16-bit integer; a fixed little-endian encoding would be
// wrong on big-endian hosts.
func (g *gsoState) setSegSizeLocked(segSize int) {
	g.SegSize = segSize
	off := unix.CmsgLen(0)
	binary.NativeEndian.PutUint16(g.ctrl[off:off+2], uint16(segSize))
}
// setNameLocked records x as the batch destination and encodes it into the
// name buffer referenced by the reusable message header. The caller must hold
// g.m.
func (g *gsoState) setNameLocked(x netip.AddrPort, isV4 bool) {
	g.Addr = x
	encoded := encodeSockaddr(g.name[:], x, isV4)
	g.msg.Hdr.Name = &g.name[0]
	g.msg.Hdr.Namelen = encoded
}
// encodeSockaddr serialises addr into dst and returns the encoded length. With
// isV4 set the address is written as a RawSockaddrInet4, otherwise as a
// RawSockaddrInet6 using the 16-byte form of the address. The port is always
// stored big-endian. dst must be large enough for the chosen form.
func encodeSockaddr(dst []byte, addr netip.AddrPort, isV4 bool) uint32 {
	if !isV4 {
		var sa6 unix.RawSockaddrInet6
		sa6.Family = unix.AF_INET6
		sa6.Addr = addr.Addr().As16()
		portBytes := (*[2]byte)(unsafe.Pointer(&sa6.Port))
		binary.BigEndian.PutUint16(portBytes[:], addr.Port())
		raw := (*[unix.SizeofSockaddrInet6]byte)(unsafe.Pointer(&sa6))
		n := unix.SizeofSockaddrInet6
		copy(dst[:n], raw[:])
		return uint32(n)
	}

	//todo?
	//if !addr.Addr().Is4() {
	//	return 0, fmt.Errorf("Listener is IPv4, but writing to IPv6 remote")
	//}
	var sa4 unix.RawSockaddrInet4
	sa4.Family = unix.AF_INET
	sa4.Addr = addr.Addr().As4()
	portBytes := (*[2]byte)(unsafe.Pointer(&sa4.Port))
	binary.BigEndian.PutUint16(portBytes[:], addr.Port())
	raw := (*[unix.SizeofSockaddrInet4]byte)(unsafe.Pointer(&sa4))
	n := unix.SizeofSockaddrInet4
	copy(dst[:n], raw[:])
	return uint32(n)
}
// sendmsgLocked transmits every queued packet with one sendmsg call, using one
// iovec per packet. The destination name and the control message must already
// be set on g.msg, and the caller must hold g.m.
//
// A call interrupted by EINTR is retried. Once the call has completed —
// successfully or not — every queued packet is returned to the packet pool and
// the queue is emptied. Any errno is returned wrapped in a *net.OpError.
func (g *gsoState) sendmsgLocked(fd int) error {
	//name already set
	//ctrl already set
	g.msg.Hdr.Iovlen = uint64(len(g.packets))
	for i, p := range g.packets {
		g.iov[i].Base = &p.Payload[0]
		g.iov[i].Len = uint64(len(p.Payload))
	}

	const flags = 0
	for {
		_, _, errno := unix.Syscall(
			unix.SYS_SENDMSG,
			uintptr(fd),
			uintptr(unsafe.Pointer(&g.msg)),
			uintptr(flags),
		)
		if errno == unix.EINTR {
			// Nothing was sent; try again with the batch intact.
			continue
		}

		// No matter the outcome, the batch is finished: hand the buffers
		// back and reset the queue.
		pool := packet.GetPool()
		for i := range g.packets {
			pool.Put(g.packets[i])
		}
		g.packets = g.packets[:0]

		if errno != 0 {
			return &net.OpError{Op: "sendmsg", Err: errno}
		}
		return nil
	}
}
type StdConn struct {
sysFd int
isV4 bool
l *logrus.Logger
batch int
sysFd int
isV4 bool
l *logrus.Logger
batch int
enableGRO bool
enableGSO bool
gso gsoState
}
func NewListener(l *logrus.Logger, ip netip.Addr, port int, multi bool, batch int) (Conn, error) {
@@ -145,15 +270,47 @@ func (u *StdConn) ListenOut(pg PacketBufferGetter, pc chan *packet.Packet) error
ip, _ = netip.AddrFromSlice(names[i][8:24])
}
out.Addr = netip.AddrPortFrom(ip.Unmap(), binary.BigEndian.Uint16(names[i][2:4]))
ctrlLen := getRawMessageControlLen(&msgs[i])
if ctrlLen > 0 {
packets[i].SegSize = parseGROControl(packets[i].Control[:ctrlLen])
} else {
packets[i].SegSize = 0
}
pc <- out
//rotate this packet out so we don't overwrite it
packets[i] = pg()
msgs[i].Hdr.Iov.Base = &packets[i].Payload[0]
if u.enableGRO {
msgs[i].Hdr.Control = &packets[i].Control[0]
msgs[i].Hdr.Controllen = uint64(cap(packets[i].Control))
}
}
}
}
// parseGROControl extracts the segment size from a received control buffer. It
// returns the value carried by the first SOL_UDP/UDP_GRO control message, or 0
// when control is empty, cannot be parsed, or carries no such message.
//
// The value is decoded in host byte order. When at least four bytes are
// present they are read as one 32-bit integer, otherwise the first two bytes
// are read as a 16-bit integer.
// NOTE(review): the kernel appears to emit this value as a C int — confirm
// against the udp(7) documentation.
func parseGROControl(control []byte) int {
	if len(control) == 0 {
		return 0
	}

	cmsgs, err := unix.ParseSocketControlMessage(control)
	if err != nil {
		return 0
	}

	for _, c := range cmsgs {
		if c.Header.Level != unix.SOL_UDP || c.Header.Type != unix.UDP_GRO {
			continue
		}
		switch {
		case len(c.Data) >= 4:
			return int(binary.NativeEndian.Uint32(c.Data[:4]))
		case len(c.Data) >= 2:
			return int(binary.NativeEndian.Uint16(c.Data[:2]))
		}
	}
	return 0
}
func (u *StdConn) ReadSingle(msgs []rawMessage) (int, error) {
for {
n, _, err := unix.Syscall6(
@@ -201,11 +358,123 @@ func (u *StdConn) ReadMulti(msgs []rawMessage) (int, error) {
}
}
func (u *StdConn) WriteTo(b []byte, ip netip.AddrPort) error {
if u.isV4 {
return u.writeTo4(b, ip)
func (u *StdConn) WriteTo(p *packet.Packet) error {
if u.enableGSO && p.Addr.IsValid() {
if err := u.queueGSOPacket(p); err == nil {
return nil
} else if !errors.Is(err, errGSOFallback) {
return err
}
}
return u.writeTo6(b, ip)
var err error
if u.isV4 {
err = u.writeTo4(p.Payload, p.Addr)
} else {
err = u.writeTo4(p.Payload, p.Addr)
}
packet.GetPool().Put(p)
return err
}
// WriteDirect sends b to addr immediately, choosing the IPv4 or IPv6 write
// path according to the socket's address family.
func (u *StdConn) WriteDirect(b []byte, addr netip.AddrPort) error {
	if !u.isV4 {
		return u.writeTo6(b, addr)
	}
	return u.writeTo4(b, addr)
}
// scheduleGSOFlushLocked arms the flush timer for the configured timeout,
// creating it on first use and resetting it otherwise. The caller must hold
// u.gso.m.
func (u *StdConn) scheduleGSOFlushLocked() {
	if t := u.gso.Timer; t != nil {
		t.Reset(u.gso.FlushTimeout)
		return
	}
	u.gso.Timer = time.AfterFunc(u.gso.FlushTimeout, u.gsoFlushTimer)
}
// stopGSOTimerLocked stops and discards the flush timer, if one exists. The
// caller must hold u.gso.m.
func (u *StdConn) stopGSOTimerLocked() {
	t := u.gso.Timer
	if t == nil {
		return
	}
	t.Stop()
	u.gso.Timer = nil //todo I also don't like this
}
// queueGSOPacket adds p to the pending GSO batch, flushing the batch first
// when p cannot join it.
//
// Packets in one batch share a destination and a payload length. A packet that
// differs in either, or that would push the batch past MaxBytes, causes the
// current batch to be flushed and a new one to be started with p. The batch is
// flushed immediately once it holds MaxSegments packets or when no flush
// timeout is configured; otherwise the flush timer is (re)armed.
//
// It returns errGSOFallback when p must be sent through the regular write path
// (GSO disabled, invalid address, or payload larger than MaxBytes), and any
// error produced by a flush. An empty payload is ignored and nil is returned.
func (u *StdConn) queueGSOPacket(p *packet.Packet) error {
	if len(p.Payload) == 0 {
		return nil
	}

	u.gso.m.Lock()
	defer u.gso.m.Unlock()

	if !u.enableGSO || !p.Addr.IsValid() || len(p.Payload) > u.gso.MaxBytes {
		if err := u.flushGSOlocked(); err != nil {
			return err
		}
		return errGSOFallback
	}

	if queued := len(u.gso.packets); queued > 0 {
		sameFlow := p.Addr == u.gso.Addr && len(p.Payload) == u.gso.SegSize
		// Every queued packet has the same length, so the batch size after
		// adding p is simply (queued+1) * len(p.Payload).
		fits := (queued+1)*len(p.Payload) <= u.gso.MaxBytes
		if !sameFlow || !fits {
			if err := u.flushGSOlocked(); err != nil {
				return err
			} //todo deal with "one small packet" case
		}
	}

	if len(u.gso.packets) == 0 {
		// Starting a new batch: set the destination and write the segment
		// size into the control message. Only updating the SegSize field
		// would leave the control message carrying a stale or zero size.
		u.gso.setNameLocked(p.Addr, u.isV4)
		u.gso.setSegSizeLocked(len(p.Payload))
	}
	u.gso.packets = append(u.gso.packets, p)

	if len(u.gso.packets) >= u.gso.MaxSegments || u.gso.FlushTimeout <= 0 {
		return u.flushGSOlocked()
	}

	u.scheduleGSOFlushLocked()
	return nil
}
// flushGSOlocked stops the flush timer and sends the pending batch, if any.
// It returns nil when nothing is queued, errGSOFallback when packets are
// queued but no segment size is recorded, and otherwise the result of the
// send. The recorded segment size is cleared after a send attempt. The caller
// must hold u.gso.m.
func (u *StdConn) flushGSOlocked() error {
	// The timer is cancelled on every path, so do it up front.
	u.stopGSOTimerLocked()

	if len(u.gso.packets) == 0 {
		return nil
	}
	if u.gso.SegSize <= 0 {
		return errGSOFallback
	}

	sendErr := u.gso.sendmsgLocked(u.sysFd)
	if errors.Is(sendErr, errGSODisabled) {
		u.l.WithField("addr", u.gso.Addr).Warn("UDP GSO disabled by kernel, falling back to sendto")
		u.enableGSO = false
		//todo!
		//return u.sendSegmentsIndividually(payload, addr, segSize)
	}

	u.gso.SegSize = 0
	return sendErr
}
// gsoFlushTimer is the flush timer callback: it takes the GSO lock and flushes
// whatever is pending, discarding the result.
func (u *StdConn) gsoFlushTimer() {
	u.gso.m.Lock()
	defer u.gso.m.Unlock()
	_ = u.flushGSOlocked()
}
func (u *StdConn) writeTo6(b []byte, ip netip.AddrPort) error {
@@ -308,6 +577,72 @@ func (u *StdConn) ReloadConfig(c *config.C) {
u.l.WithError(err).Error("Failed to set listen.so_mark")
}
}
u.configureGRO(true)
u.configureGSO(c)
}
// configureGRO switches the UDP_GRO socket option to the requested state. It
// does nothing when the socket is already in that state. A failure to enable
// is logged and leaves GRO off; a failure to disable is logged (unless the
// option is simply unsupported) and GRO is still recorded as off.
func (u *StdConn) configureGRO(enable bool) {
	if u.enableGRO == enable {
		return
	}

	if !enable {
		err := unix.SetsockoptInt(u.sysFd, unix.SOL_UDP, unix.UDP_GRO, 0)
		if err != nil && err != unix.ENOPROTOOPT {
			u.l.WithError(err).Warn("Failed to disable UDP GRO")
		}
		u.enableGRO = false
		return
	}

	if err := unix.SetsockoptInt(u.sysFd, unix.SOL_UDP, unix.UDP_GRO, 1); err != nil {
		u.l.WithError(err).Warn("Failed to enable UDP GRO")
		return
	}
	u.enableGRO = true
	u.l.Info("UDP GRO enabled")
}
// configureGSO applies the listen.enable_gso, listen.gso_max_segments,
// listen.gso_max_bytes and listen.gso_flush_timeout settings.
//
// gso_max_segments is raised to at least 1. gso_max_bytes defaults to
// MTU * segments when unset or non-positive and is clamped to
// maxGSOBatchBytes. A negative flush timeout is treated as 0.
//
// The new settings are installed while holding the GSO lock, after flushing
// anything still queued. Init replaces the buffers that the flush timer and
// concurrent writers use, so swapping them unlocked would race with those
// users and drop queued packets.
func (u *StdConn) configureGSO(c *config.C) {
	enable := c.GetBool("listen.enable_gso", true)
	if !enable {
		u.disableGSO()
	}

	segments := c.GetInt("listen.gso_max_segments", defaultGSOMaxSegments)
	if segments < 1 {
		segments = 1
	}

	maxBytes := c.GetInt("listen.gso_max_bytes", 0)
	if maxBytes <= 0 {
		maxBytes = MTU * segments
	}
	if maxBytes > maxGSOBatchBytes {
		u.l.WithField("requested", maxBytes).Warn("listen.gso_max_bytes larger than UDP limit; clamping")
		maxBytes = maxGSOBatchBytes
	}

	timeout := c.GetDuration("listen.gso_flush_timeout", defaultGSOFlushTimeout)
	if timeout < 0 {
		timeout = 0
	}

	u.gso.m.Lock()
	defer u.gso.m.Unlock()

	// Best effort: send what was queued under the previous settings before
	// its buffers are replaced. A failure here has no caller to report to.
	_ = u.flushGSOlocked()

	u.enableGSO = enable
	u.gso.MaxSegments = segments
	u.gso.MaxBytes = maxBytes
	u.gso.FlushTimeout = timeout
	u.gso.Init()
}
// disableGSO turns GSO off under the GSO lock: it flushes what is pending
// (ignoring the result), then clears the buffer, queue and segment size and
// stops the flush timer.
func (u *StdConn) disableGSO() {
	state := &u.gso
	state.m.Lock()
	defer state.m.Unlock()

	u.enableGSO = false
	_ = u.flushGSOlocked()

	state.Buf = nil
	state.packets = state.packets[:0]
	state.SegSize = 0
	u.stopGSOTimerLocked()
}
func (u *StdConn) getMemInfo(meminfo *[unix.SK_MEMINFO_VARS]uint32) error {

View File

@@ -34,6 +34,24 @@ type rawMessage struct {
Pad0 [4]byte
}
// setRawMessageControl attaches buf as the control buffer of msg. Passing an
// empty or nil buf detaches any control buffer.
func setRawMessageControl(msg *rawMessage, buf []byte) {
	var (
		first *byte
		total uint64
	)
	if len(buf) != 0 {
		first = &buf[0]
		total = uint64(len(buf))
	}
	msg.Hdr.Control = first
	msg.Hdr.Controllen = total
}
// getRawMessageControlLen reports the control length currently stored in msg's
// header.
func getRawMessageControlLen(msg *rawMessage) int {
	n := msg.Hdr.Controllen
	return int(n)
}
// setCmsgLen writes l into the length field of the control message header h.
func setCmsgLen(h *unix.Cmsghdr, l int) {
	length := uint64(l)
	h.Len = length
}
func (u *StdConn) PrepareRawMessages(n int, pg PacketBufferGetter) ([]rawMessage, []*packet.Packet, [][]byte) {
msgs := make([]rawMessage, n)
names := make([][]byte, n)
@@ -42,6 +60,7 @@ func (u *StdConn) PrepareRawMessages(n int, pg PacketBufferGetter) ([]rawMessage
for i := range packets {
packets[i] = pg()
}
//todo?
for i := range msgs {
names[i] = make([]byte, unix.SizeofSockaddrInet6)
@@ -55,6 +74,13 @@ func (u *StdConn) PrepareRawMessages(n int, pg PacketBufferGetter) ([]rawMessage
msgs[i].Hdr.Name = &names[i][0]
msgs[i].Hdr.Namelen = uint32(len(names[i]))
if u.enableGRO {
msgs[i].Hdr.Control = &packets[i].Control[0]
msgs[i].Hdr.Controllen = uint64(len(packets[i].Control))
} else {
msgs[i].Hdr.Control = nil
msgs[i].Hdr.Controllen = 0
}
}
return msgs, packets, names