mirror of
https://github.com/slackhq/nebula.git
synced 2025-11-22 16:34:25 +01:00
cursed gso
This commit is contained in:
191
cmd/gso/gso.go
Normal file
191
cmd/gso/gso.go
Normal file
@@ -0,0 +1,191 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"errors"
|
||||
"flag"
|
||||
"fmt"
|
||||
"log"
|
||||
"net"
|
||||
"net/netip"
|
||||
"time"
|
||||
"unsafe"
|
||||
|
||||
"golang.org/x/sys/unix"
|
||||
)
|
||||
|
||||
// UDP_SEGMENT is the numeric value of the UDP segmentation (GSO) option.
const UDP_SEGMENT = 103

// maxGSOSize is the largest segment size accepted by the -gso flag
// (typical MTU minus headers).
const maxGSOSize = 1400
|
||||
|
||||
func main() {
|
||||
destAddr := flag.String("dest", "10.4.0.16:4202", "Destination address")
|
||||
gsoSize := flag.Int("gso", 1400, "GSO segment size")
|
||||
totalSize := flag.Int("size", 14000, "Total payload size to send")
|
||||
count := flag.Int("count", 1, "Number of packets to send")
|
||||
flag.Parse()
|
||||
|
||||
if *gsoSize > maxGSOSize {
|
||||
log.Fatalf("GSO size %d exceeds maximum %d", *gsoSize, maxGSOSize)
|
||||
}
|
||||
|
||||
// Resolve destination address
|
||||
_, err := net.ResolveUDPAddr("udp", *destAddr)
|
||||
if err != nil {
|
||||
log.Fatalf("Failed to resolve address: %v", err)
|
||||
}
|
||||
|
||||
// Create a raw UDP socket with GSO support
|
||||
fd, err := unix.Socket(unix.AF_INET, unix.SOCK_DGRAM, unix.IPPROTO_UDP)
|
||||
if err != nil {
|
||||
log.Fatalf("Failed to create socket: %v", err)
|
||||
}
|
||||
defer unix.Close(fd)
|
||||
|
||||
// Bind to a local address
|
||||
localAddr := &unix.SockaddrInet4{
|
||||
Port: 0, // Let the system choose a port
|
||||
}
|
||||
if err := unix.Bind(fd, localAddr); err != nil {
|
||||
log.Fatalf("Failed to bind socket: %v", err)
|
||||
}
|
||||
|
||||
fmt.Printf("Sending UDP packets with GSO enabled\n")
|
||||
fmt.Printf("Destination: %s\n", *destAddr)
|
||||
fmt.Printf("GSO segment size: %d bytes\n", *gsoSize)
|
||||
fmt.Printf("Total payload size: %d bytes\n", *totalSize)
|
||||
fmt.Printf("Number of packets: %d\n\n", *count)
|
||||
|
||||
// Create payload
|
||||
payload := make([]byte, *totalSize)
|
||||
for i := range payload {
|
||||
payload[i] = byte(i % 256)
|
||||
}
|
||||
|
||||
dest := netip.MustParseAddrPort(*destAddr)
|
||||
|
||||
//if err := unix.SetsockoptInt(fd, unix.SOL_UDP, unix.UDP_SEGMENT, 1400); err != nil {
|
||||
// panic(err)
|
||||
//}
|
||||
|
||||
for i := 0; i < *count; i++ {
|
||||
err := WriteBatch(fd, payload, dest, uint16(*gsoSize), true)
|
||||
if err != nil {
|
||||
log.Printf("Send error on packet %d: %v", i, err)
|
||||
continue
|
||||
}
|
||||
|
||||
if (i+1)%100 == 0 || i == *count-1 {
|
||||
fmt.Printf("Sent %d packets\n", i+1)
|
||||
}
|
||||
}
|
||||
fmt.Printf("now, let's send without the correct ctrl header\n")
|
||||
time.Sleep(time.Second)
|
||||
for i := 0; i < *count; i++ {
|
||||
err := WriteBatch(fd, payload, dest, uint16(*gsoSize), false)
|
||||
if err != nil {
|
||||
log.Printf("Send error on packet %d: %v", i, err)
|
||||
continue
|
||||
}
|
||||
|
||||
if (i+1)%100 == 0 || i == *count-1 {
|
||||
fmt.Printf("Sent %d packets\n", i+1)
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
func WriteBatch(fd int, payload []byte, addr netip.AddrPort, segSize uint16, withHeader bool) error {
|
||||
msgs := make([]rawMessage, 0, 1)
|
||||
iovs := make([]iovec, 0, 1)
|
||||
names := make([][unix.SizeofSockaddrInet6]byte, 0, 1)
|
||||
|
||||
sent := 0
|
||||
|
||||
pkts := []BatchPacket{
|
||||
{
|
||||
Payload: payload,
|
||||
Addr: addr,
|
||||
},
|
||||
}
|
||||
|
||||
for _, pkt := range pkts {
|
||||
if len(pkt.Payload) == 0 {
|
||||
sent++
|
||||
continue
|
||||
}
|
||||
|
||||
msgs = append(msgs, rawMessage{})
|
||||
iovs = append(iovs, iovec{})
|
||||
names = append(names, [unix.SizeofSockaddrInet6]byte{})
|
||||
|
||||
idx := len(msgs) - 1
|
||||
msg := &msgs[idx]
|
||||
iov := &iovs[idx]
|
||||
name := &names[idx]
|
||||
|
||||
setIovecSlice(iov, pkt.Payload)
|
||||
msg.Hdr.Iov = iov
|
||||
msg.Hdr.Iovlen = 1
|
||||
|
||||
if withHeader {
|
||||
setRawMessageControl(msg, buildGSOControlMessage(segSize)) //
|
||||
} else {
|
||||
setRawMessageControl(msg, nil) //
|
||||
}
|
||||
|
||||
msg.Hdr.Flags = 0
|
||||
|
||||
nameLen, err := encodeSockaddr(name[:], pkt.Addr)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
msg.Hdr.Name = &name[0]
|
||||
msg.Hdr.Namelen = nameLen
|
||||
}
|
||||
|
||||
if len(msgs) == 0 {
|
||||
return errors.New("nothing to write")
|
||||
}
|
||||
|
||||
offset := 0
|
||||
for offset < len(msgs) {
|
||||
n, _, errno := unix.Syscall6(
|
||||
unix.SYS_SENDMMSG,
|
||||
uintptr(fd),
|
||||
uintptr(unsafe.Pointer(&msgs[offset])),
|
||||
uintptr(len(msgs)-offset),
|
||||
0,
|
||||
0,
|
||||
0,
|
||||
)
|
||||
|
||||
if errno != 0 {
|
||||
if errno == unix.EINTR {
|
||||
continue
|
||||
}
|
||||
return &net.OpError{Op: "sendmmsg", Err: errno}
|
||||
}
|
||||
|
||||
if n == 0 {
|
||||
break
|
||||
}
|
||||
offset += int(n)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func buildGSOControlMessage(segSize uint16) []byte {
|
||||
control := make([]byte, unix.CmsgSpace(2))
|
||||
hdr := (*unix.Cmsghdr)(unsafe.Pointer(&control[0]))
|
||||
hdr.Level = unix.SOL_UDP
|
||||
hdr.Type = unix.UDP_SEGMENT
|
||||
setCmsgLen(hdr, unix.CmsgLen(2))
|
||||
binary.NativeEndian.PutUint16(control[unix.CmsgLen(0):unix.CmsgLen(0)+2], uint16(segSize))
|
||||
|
||||
return control
|
||||
}
|
||||
85
cmd/gso/helper.go
Normal file
85
cmd/gso/helper.go
Normal file
@@ -0,0 +1,85 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"fmt"
|
||||
"net/netip"
|
||||
"unsafe"
|
||||
|
||||
"golang.org/x/sys/unix"
|
||||
)
|
||||
|
||||
type (
	// iovec describes one contiguous buffer: a pointer to its first byte
	// and its length in bytes.
	iovec struct {
		Base *byte
		Len  uint64
	}

	// msghdr is the per-message header handed to the kernel: destination
	// name, scatter/gather list, control (ancillary) data and flags. The
	// Pad fields are explicit padding.
	msghdr struct {
		Name       *byte
		Namelen    uint32
		Pad0       [4]byte
		Iov        *iovec
		Iovlen     uint64
		Control    *byte
		Controllen uint64
		Flags      int32
		Pad1       [4]byte
	}

	// rawMessage is one element of the array WriteBatch passes to
	// sendmmsg: a msghdr followed by a length field.
	rawMessage struct {
		Hdr  msghdr
		Len  uint32
		Pad0 [4]byte
	}

	// BatchPacket pairs a payload with the address it should be sent to.
	BatchPacket struct {
		Payload []byte
		Addr    netip.AddrPort
	}
)
|
||||
|
||||
func encodeSockaddr(dst []byte, addr netip.AddrPort) (uint32, error) {
|
||||
if addr.Addr().Is4() {
|
||||
if !addr.Addr().Is4() {
|
||||
return 0, fmt.Errorf("Listener is IPv4, but writing to IPv6 remote")
|
||||
}
|
||||
var sa unix.RawSockaddrInet4
|
||||
sa.Family = unix.AF_INET
|
||||
sa.Addr = addr.Addr().As4()
|
||||
binary.BigEndian.PutUint16((*[2]byte)(unsafe.Pointer(&sa.Port))[:], addr.Port())
|
||||
size := unix.SizeofSockaddrInet4
|
||||
copy(dst[:size], (*(*[unix.SizeofSockaddrInet4]byte)(unsafe.Pointer(&sa)))[:])
|
||||
return uint32(size), nil
|
||||
}
|
||||
|
||||
var sa unix.RawSockaddrInet6
|
||||
sa.Family = unix.AF_INET6
|
||||
sa.Addr = addr.Addr().As16()
|
||||
binary.BigEndian.PutUint16((*[2]byte)(unsafe.Pointer(&sa.Port))[:], addr.Port())
|
||||
size := unix.SizeofSockaddrInet6
|
||||
copy(dst[:size], (*(*[unix.SizeofSockaddrInet6]byte)(unsafe.Pointer(&sa)))[:])
|
||||
return uint32(size), nil
|
||||
}
|
||||
|
||||
func setRawMessageControl(msg *rawMessage, buf []byte) {
|
||||
if len(buf) == 0 {
|
||||
msg.Hdr.Control = nil
|
||||
msg.Hdr.Controllen = 0
|
||||
return
|
||||
}
|
||||
msg.Hdr.Control = &buf[0]
|
||||
msg.Hdr.Controllen = uint64(len(buf))
|
||||
}
|
||||
|
||||
func setCmsgLen(h *unix.Cmsghdr, l int) {
|
||||
h.Len = uint64(l)
|
||||
}
|
||||
|
||||
func setIovecSlice(iov *iovec, b []byte) {
|
||||
if len(b) == 0 {
|
||||
iov.Base = nil
|
||||
iov.Len = 0
|
||||
return
|
||||
}
|
||||
iov.Base = &b[0]
|
||||
iov.Len = uint64(len(b))
|
||||
}
|
||||
@@ -518,12 +518,12 @@ func (cm *connectionManager) sendPunch(hostinfo *HostInfo) {
|
||||
if cm.punchy.GetTargetEverything() {
|
||||
hostinfo.remotes.ForEach(cm.hostMap.GetPreferredRanges(), func(addr netip.AddrPort, preferred bool) {
|
||||
cm.metricsTxPunchy.Inc(1)
|
||||
cm.intf.outside.WriteTo([]byte{1}, addr)
|
||||
cm.intf.outside.WriteDirect([]byte{1}, addr)
|
||||
})
|
||||
|
||||
} else if hostinfo.remote.IsValid() {
|
||||
cm.metricsTxPunchy.Inc(1)
|
||||
cm.intf.outside.WriteTo([]byte{1}, hostinfo.remote)
|
||||
cm.intf.outside.WriteDirect([]byte{1}, hostinfo.remote)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -15,7 +15,8 @@ import (
|
||||
|
||||
// TODO: In a 5Gbps test, 1024 is not sufficient. With a 1400 MTU this is about 1.4Gbps of window, assuming full packets.
|
||||
// 4092 should be sufficient for 5Gbps
|
||||
const ReplayWindow = 8192
|
||||
// TODO this is a horrible amount of RAM to waste per-tunnel
|
||||
const ReplayWindow = 0xffff / 2
|
||||
|
||||
type ConnectionState struct {
|
||||
eKey *NebulaCipherState
|
||||
|
||||
@@ -348,7 +348,7 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet
|
||||
msg = existing.HandshakePacket[2]
|
||||
f.messageMetrics.Tx(header.Handshake, header.MessageSubType(msg[1]), 1)
|
||||
if addr.IsValid() {
|
||||
err := f.outside.WriteTo(msg, addr)
|
||||
err := f.outside.WriteDirect(msg, addr)
|
||||
if err != nil {
|
||||
f.l.WithField("vpnAddrs", existing.vpnAddrs).WithField("udpAddr", addr).
|
||||
WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).WithField("cached", true).
|
||||
@@ -417,7 +417,7 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet
|
||||
// Do the send
|
||||
f.messageMetrics.Tx(header.Handshake, header.MessageSubType(msg[1]), 1)
|
||||
if addr.IsValid() {
|
||||
err = f.outside.WriteTo(msg, addr)
|
||||
err = f.outside.WriteDirect(msg, addr)
|
||||
if err != nil {
|
||||
f.l.WithField("vpnAddrs", vpnAddrs).WithField("udpAddr", addr).
|
||||
WithField("certName", certName).
|
||||
|
||||
@@ -238,7 +238,7 @@ func (hm *HandshakeManager) handleOutbound(vpnIp netip.Addr, lighthouseTriggered
|
||||
var sentTo []netip.AddrPort
|
||||
hostinfo.remotes.ForEach(hm.mainHostMap.GetPreferredRanges(), func(addr netip.AddrPort, _ bool) {
|
||||
hm.messageMetrics.Tx(header.Handshake, header.MessageSubType(hostinfo.HandshakePacket[0][1]), 1)
|
||||
err := hm.outside.WriteTo(hostinfo.HandshakePacket[0], addr)
|
||||
err := hm.outside.WriteDirect(hostinfo.HandshakePacket[0], addr)
|
||||
if err != nil {
|
||||
hostinfo.logger(hm.l).WithField("udpAddr", addr).
|
||||
WithField("initiatorIndex", hostinfo.localIndexId).
|
||||
|
||||
17
inside.go
17
inside.go
@@ -8,6 +8,7 @@ import (
|
||||
"github.com/slackhq/nebula/header"
|
||||
"github.com/slackhq/nebula/iputil"
|
||||
"github.com/slackhq/nebula/noiseutil"
|
||||
"github.com/slackhq/nebula/packet"
|
||||
"github.com/slackhq/nebula/routing"
|
||||
)
|
||||
|
||||
@@ -324,7 +325,7 @@ func (f *Interface) SendVia(via *HostInfo,
|
||||
via.logger(f.l).WithError(err).Info("Failed to EncryptDanger in sendVia")
|
||||
return
|
||||
}
|
||||
err = f.writers[0].WriteTo(out, via.remote)
|
||||
err = f.writers[0].WriteDirect(out, via.remote)
|
||||
if err != nil {
|
||||
via.logger(f.l).WithError(err).Info("Failed to WriteTo in sendVia")
|
||||
}
|
||||
@@ -384,19 +385,29 @@ func (f *Interface) sendNoMetrics(t header.MessageType, st header.MessageSubType
|
||||
}
|
||||
|
||||
if remote.IsValid() {
|
||||
err = f.writers[q].WriteTo(out, remote)
|
||||
pkt := packet.GetPool().Get()
|
||||
copy(pkt.Payload[:], out)
|
||||
pkt.Payload = pkt.Payload[:len(out)]
|
||||
pkt.Addr = remote
|
||||
err = f.writers[q].WriteTo(pkt)
|
||||
if err != nil {
|
||||
hostinfo.logger(f.l).WithError(err).
|
||||
WithField("udpAddr", remote).Error("Failed to write outgoing packet")
|
||||
}
|
||||
} else if hostinfo.remote.IsValid() {
|
||||
err = f.writers[q].WriteTo(out, hostinfo.remote)
|
||||
pkt := packet.GetPool().Get()
|
||||
copy(pkt.Payload, out)
|
||||
pkt.Payload = pkt.Payload[:len(out)]
|
||||
pkt.Addr = hostinfo.remote
|
||||
err = f.writers[q].WriteTo(pkt)
|
||||
if err != nil {
|
||||
hostinfo.logger(f.l).WithError(err).
|
||||
WithField("udpAddr", remote).Error("Failed to write outgoing packet")
|
||||
}
|
||||
} else {
|
||||
// Try to send via a relay
|
||||
|
||||
//todo relay is slow sorryyy
|
||||
for _, relayIP := range hostinfo.relayState.CopyRelayIps() {
|
||||
relayHostInfo, relay, err := f.hostMap.QueryVpnAddrsRelayFor(hostinfo.vpnAddrs, relayIP)
|
||||
if err != nil {
|
||||
|
||||
37
interface.go
37
interface.go
@@ -207,7 +207,7 @@ func NewInterface(ctx context.Context, c *InterfaceConfig) (*Interface, error) {
|
||||
l: c.l,
|
||||
}
|
||||
|
||||
ifce.pktPool = packet.NewPool()
|
||||
ifce.pktPool = packet.GetPool()
|
||||
|
||||
ifce.tryPromoteEvery.Store(c.tryPromoteEvery)
|
||||
ifce.reQueryEvery.Store(c.reQueryEvery)
|
||||
@@ -327,6 +327,26 @@ func (f *Interface) listenIn(reader io.ReadWriteCloser, i int) {
|
||||
f.wg.Done()
|
||||
}
|
||||
|
||||
//// todo why? understand!
|
||||
//func normalizeGROSegSize(segSize, total int) int {
|
||||
// if segCount > 1 && total > 0 {
|
||||
// avg := total / segCount
|
||||
// if avg > 0 {
|
||||
// if segSize > avg {
|
||||
// if segSize-8 == avg {
|
||||
// segSize = avg
|
||||
// } else if segSize > total {
|
||||
// segSize = avg
|
||||
// }
|
||||
// }
|
||||
// }
|
||||
// }
|
||||
// if segSize > total {
|
||||
// segSize = total
|
||||
// }
|
||||
// return segSize
|
||||
//}
|
||||
|
||||
func (f *Interface) workerIn(i int, ctx context.Context) {
|
||||
lhh := f.lightHouse.NewRequestHandler()
|
||||
conntrackCache := firewall.NewConntrackCacheTicker(f.conntrackCacheTimeout)
|
||||
@@ -338,7 +358,18 @@ func (f *Interface) workerIn(i int, ctx context.Context) {
|
||||
for {
|
||||
select {
|
||||
case p := <-f.inbound:
|
||||
f.readOutsidePackets(p.Addr, nil, result2[:0], p.Payload, h, fwPacket2, lhh, nb2, i, conntrackCache.Get(f.l))
|
||||
if p.SegSize > 0 && p.SegSize < len(p.Payload) {
|
||||
for offset := 0; offset < len(p.Payload); offset += p.SegSize {
|
||||
end := offset + p.SegSize
|
||||
if end > len(p.Payload) {
|
||||
end = len(p.Payload)
|
||||
}
|
||||
f.readOutsidePackets(p.Addr, nil, result2[:0], p.Payload[offset:end], h, fwPacket2, lhh, nb2, i, conntrackCache.Get(f.l))
|
||||
}
|
||||
} else {
|
||||
f.readOutsidePackets(p.Addr, nil, result2[:0], p.Payload, h, fwPacket2, lhh, nb2, i, conntrackCache.Get(f.l))
|
||||
}
|
||||
|
||||
f.pktPool.Put(p)
|
||||
case <-ctx.Done():
|
||||
f.wg.Done()
|
||||
@@ -357,7 +388,7 @@ func (f *Interface) workerOut(i int, ctx context.Context) {
|
||||
select {
|
||||
case data := <-f.outbound:
|
||||
f.consumeInsidePacket(data.Payload, fwPacket1, nb1, result1, i, conntrackCache.Get(f.l))
|
||||
f.pktPool.Put(data)
|
||||
//f.pktPool.Put(data) //todo if err pls put packet back
|
||||
case <-ctx.Done():
|
||||
f.wg.Done()
|
||||
return
|
||||
|
||||
@@ -1329,7 +1329,7 @@ func (lhh *LightHouseHandler) handleHostPunchNotification(n *NebulaMeta, fromVpn
|
||||
go func() {
|
||||
time.Sleep(lhh.lh.punchy.GetDelay())
|
||||
lhh.lh.metricHolepunchTx.Inc(1)
|
||||
lhh.lh.punchConn.WriteTo(empty, vpnPeer)
|
||||
lhh.lh.punchConn.WriteDirect(empty, vpnPeer)
|
||||
}()
|
||||
|
||||
if lhh.l.Level >= logrus.DebugLevel {
|
||||
|
||||
@@ -519,7 +519,7 @@ func (f *Interface) sendRecvError(endpoint netip.AddrPort, index uint32) {
|
||||
f.messageMetrics.Tx(header.RecvError, 0, 1)
|
||||
|
||||
b := header.Encode(make([]byte, header.Len), header.Version, header.RecvError, 0, index, 0)
|
||||
_ = f.outside.WriteTo(b, endpoint)
|
||||
_ = f.outside.WriteDirect(b, endpoint)
|
||||
if f.l.Level >= logrus.DebugLevel {
|
||||
f.l.WithField("index", index).
|
||||
WithField("udpAddr", endpoint).
|
||||
|
||||
@@ -3,27 +3,36 @@ package packet
|
||||
import (
|
||||
"net/netip"
|
||||
"sync"
|
||||
|
||||
"golang.org/x/sys/unix"
|
||||
)
|
||||
|
||||
const Size = 9001
|
||||
const Size = 0xffff
|
||||
|
||||
// Packet is a reusable buffer for one UDP read or write.
type Packet struct {
	// Payload holds the datagram bytes.
	Payload []byte
	// Control is scratch space for socket control (ancillary) data.
	Control []byte
	// SegSize is the segment size associated with Payload; zero means the
	// payload is not split into segments.
	SegSize int
	// Addr is the remote address of the packet.
	Addr netip.AddrPort
}
|
||||
|
||||
func New() *Packet {
|
||||
return &Packet{Payload: make([]byte, Size)}
|
||||
return &Packet{
|
||||
Payload: make([]byte, Size),
|
||||
Control: make([]byte, unix.CmsgSpace(2)),
|
||||
}
|
||||
}
|
||||
|
||||
type Pool struct {
|
||||
pool sync.Pool
|
||||
}
|
||||
|
||||
func NewPool() *Pool {
|
||||
return &Pool{
|
||||
pool: sync.Pool{New: func() any { return New() }},
|
||||
}
|
||||
var bigPool = &Pool{
|
||||
pool: sync.Pool{New: func() any { return New() }},
|
||||
}
|
||||
|
||||
func GetPool() *Pool {
|
||||
return bigPool
|
||||
}
|
||||
|
||||
func (p *Pool) Get() *Packet {
|
||||
|
||||
@@ -17,7 +17,8 @@ type Conn interface {
|
||||
Rebind() error
|
||||
LocalAddr() (netip.AddrPort, error)
|
||||
ListenOut(pg PacketBufferGetter, pc chan *packet.Packet) error
|
||||
WriteTo(b []byte, addr netip.AddrPort) error
|
||||
WriteTo(p *packet.Packet) error
|
||||
WriteDirect(b []byte, port netip.AddrPort) error
|
||||
ReloadConfig(c *config.C)
|
||||
Close() error
|
||||
}
|
||||
|
||||
351
udp/udp_linux.go
351
udp/udp_linux.go
@@ -5,9 +5,11 @@ package udp
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
"sync"
|
||||
"syscall"
|
||||
"time"
|
||||
"unsafe"
|
||||
@@ -19,13 +21,136 @@ import (
|
||||
"golang.org/x/sys/unix"
|
||||
)
|
||||
|
||||
// Defaults and limits for GSO batching.
const (
	// defaultGSOMaxSegments is the default for listen.gso_max_segments.
	defaultGSOMaxSegments = 16
	// defaultGSOFlushTimeout is the default for listen.gso_flush_timeout.
	defaultGSOFlushTimeout = 150 * time.Microsecond
	// maxGSOBatchBytes is the ceiling applied to listen.gso_max_bytes.
	maxGSOBatchBytes = 0xFFFF
)

// Sentinel errors used by the GSO send path.
var (
	// errGSOFallback tells the caller to send the packet without GSO.
	errGSOFallback = errors.New("udp gso fallback")
	// errGSODisabled marks a send failure after which GSO is switched off.
	errGSODisabled = errors.New("udp gso disabled")
)
|
||||
|
||||
var readTimeout = unix.NsecToTimeval(int64(time.Millisecond * 500))
|
||||
|
||||
type gsoState struct {
|
||||
m sync.Mutex
|
||||
Buf []byte
|
||||
Addr netip.AddrPort
|
||||
SegSize int
|
||||
MaxSegments int
|
||||
MaxBytes int
|
||||
FlushTimeout time.Duration
|
||||
Timer *time.Timer
|
||||
|
||||
packets []*packet.Packet
|
||||
msg rawMessage
|
||||
name [unix.SizeofSockaddrInet6]byte
|
||||
iov []iovec
|
||||
ctrl []byte
|
||||
}
|
||||
|
||||
func (g *gsoState) Init() {
|
||||
g.iov = make([]iovec, g.MaxSegments)
|
||||
for i := 0; i < g.MaxSegments; i++ {
|
||||
g.iov[i] = iovec{}
|
||||
}
|
||||
g.msg.Hdr.Iov = &g.iov[0]
|
||||
g.msg.Hdr.Iovlen = 1
|
||||
|
||||
g.packets = make([]*packet.Packet, 0, g.MaxSegments)
|
||||
g.ctrl = make([]byte, unix.CmsgSpace(2))
|
||||
hdr := (*unix.Cmsghdr)(unsafe.Pointer(&g.ctrl[0]))
|
||||
hdr.Level = unix.SOL_UDP
|
||||
hdr.Type = unix.UDP_SEGMENT
|
||||
setCmsgLen(hdr, unix.CmsgLen(2))
|
||||
g.msg.Hdr.Control = &g.ctrl[0]
|
||||
g.msg.Hdr.Controllen = uint64(len(g.ctrl))
|
||||
|
||||
g.name = [unix.SizeofSockaddrInet6]byte{}
|
||||
g.msg.Hdr.Name = &g.name[0]
|
||||
|
||||
}
|
||||
|
||||
func (g *gsoState) setSegSizeLocked(segSize int) {
|
||||
g.SegSize = segSize
|
||||
x := unix.CmsgLen(0)
|
||||
binary.LittleEndian.PutUint16(g.ctrl[x:x+2], uint16(segSize))
|
||||
}
|
||||
|
||||
func (g *gsoState) setNameLocked(x netip.AddrPort, isV4 bool) {
|
||||
g.Addr = x
|
||||
nameLen := encodeSockaddr(g.name[:], g.Addr, isV4)
|
||||
g.msg.Hdr.Name = &g.name[0]
|
||||
g.msg.Hdr.Namelen = nameLen
|
||||
}
|
||||
|
||||
func encodeSockaddr(dst []byte, addr netip.AddrPort, isV4 bool) uint32 {
|
||||
if isV4 {
|
||||
//todo?
|
||||
//if !addr.Addr().Is4() {
|
||||
// return 0, fmt.Errorf("Listener is IPv4, but writing to IPv6 remote")
|
||||
//}
|
||||
var sa unix.RawSockaddrInet4
|
||||
sa.Family = unix.AF_INET
|
||||
sa.Addr = addr.Addr().As4()
|
||||
binary.BigEndian.PutUint16((*[2]byte)(unsafe.Pointer(&sa.Port))[:], addr.Port())
|
||||
size := unix.SizeofSockaddrInet4
|
||||
copy(dst[:size], (*(*[unix.SizeofSockaddrInet4]byte)(unsafe.Pointer(&sa)))[:])
|
||||
return uint32(size)
|
||||
}
|
||||
|
||||
var sa unix.RawSockaddrInet6
|
||||
sa.Family = unix.AF_INET6
|
||||
sa.Addr = addr.Addr().As16()
|
||||
binary.BigEndian.PutUint16((*[2]byte)(unsafe.Pointer(&sa.Port))[:], addr.Port())
|
||||
size := unix.SizeofSockaddrInet6
|
||||
copy(dst[:size], (*(*[unix.SizeofSockaddrInet6]byte)(unsafe.Pointer(&sa)))[:])
|
||||
return uint32(size)
|
||||
}
|
||||
|
||||
func (g *gsoState) sendmsgLocked(fd int) error {
|
||||
//name already set
|
||||
//ctrl already set
|
||||
//g.iov = g.iov[:0]
|
||||
g.msg.Hdr.Iovlen = uint64(len(g.packets))
|
||||
for i := range g.packets {
|
||||
g.iov[i].Base = &g.packets[i].Payload[0]
|
||||
g.iov[i].Len = uint64(len(g.packets[i].Payload))
|
||||
}
|
||||
|
||||
const flags = 0
|
||||
for {
|
||||
_, _, err := unix.Syscall(
|
||||
unix.SYS_SENDMSG,
|
||||
uintptr(fd),
|
||||
uintptr(unsafe.Pointer(&g.msg)),
|
||||
uintptr(flags),
|
||||
)
|
||||
//todo no matter what, reset things
|
||||
for i := range g.packets {
|
||||
pool := packet.GetPool()
|
||||
pool.Put(g.packets[i])
|
||||
}
|
||||
g.packets = g.packets[:0]
|
||||
|
||||
if err != 0 {
|
||||
return &net.OpError{Op: "sendmsg", Err: err}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
type StdConn struct {
|
||||
sysFd int
|
||||
isV4 bool
|
||||
l *logrus.Logger
|
||||
batch int
|
||||
sysFd int
|
||||
isV4 bool
|
||||
l *logrus.Logger
|
||||
batch int
|
||||
enableGRO bool
|
||||
enableGSO bool
|
||||
gso gsoState
|
||||
}
|
||||
|
||||
func NewListener(l *logrus.Logger, ip netip.Addr, port int, multi bool, batch int) (Conn, error) {
|
||||
@@ -145,15 +270,47 @@ func (u *StdConn) ListenOut(pg PacketBufferGetter, pc chan *packet.Packet) error
|
||||
ip, _ = netip.AddrFromSlice(names[i][8:24])
|
||||
}
|
||||
out.Addr = netip.AddrPortFrom(ip.Unmap(), binary.BigEndian.Uint16(names[i][2:4]))
|
||||
ctrlLen := getRawMessageControlLen(&msgs[i])
|
||||
if ctrlLen > 0 {
|
||||
packets[i].SegSize = parseGROControl(packets[i].Control[:ctrlLen])
|
||||
} else {
|
||||
packets[i].SegSize = 0
|
||||
}
|
||||
|
||||
pc <- out
|
||||
|
||||
//rotate this packet out so we don't overwrite it
|
||||
packets[i] = pg()
|
||||
msgs[i].Hdr.Iov.Base = &packets[i].Payload[0]
|
||||
if u.enableGRO {
|
||||
msgs[i].Hdr.Control = &packets[i].Control[0]
|
||||
msgs[i].Hdr.Controllen = uint64(cap(packets[i].Control))
|
||||
}
|
||||
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func parseGROControl(control []byte) int {
|
||||
if len(control) == 0 {
|
||||
return 0
|
||||
}
|
||||
|
||||
cmsgs, err := unix.ParseSocketControlMessage(control)
|
||||
if err != nil {
|
||||
return 0
|
||||
}
|
||||
|
||||
for _, c := range cmsgs {
|
||||
if c.Header.Level == unix.SOL_UDP && c.Header.Type == unix.UDP_GRO && len(c.Data) >= 2 {
|
||||
segSize := int(binary.LittleEndian.Uint16(c.Data[:2]))
|
||||
return segSize
|
||||
}
|
||||
}
|
||||
|
||||
return 0
|
||||
}
|
||||
|
||||
func (u *StdConn) ReadSingle(msgs []rawMessage) (int, error) {
|
||||
for {
|
||||
n, _, err := unix.Syscall6(
|
||||
@@ -201,11 +358,123 @@ func (u *StdConn) ReadMulti(msgs []rawMessage) (int, error) {
|
||||
}
|
||||
}
|
||||
|
||||
func (u *StdConn) WriteTo(b []byte, ip netip.AddrPort) error {
|
||||
if u.isV4 {
|
||||
return u.writeTo4(b, ip)
|
||||
func (u *StdConn) WriteTo(p *packet.Packet) error {
|
||||
if u.enableGSO && p.Addr.IsValid() {
|
||||
if err := u.queueGSOPacket(p); err == nil {
|
||||
return nil
|
||||
} else if !errors.Is(err, errGSOFallback) {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return u.writeTo6(b, ip)
|
||||
|
||||
var err error
|
||||
if u.isV4 {
|
||||
err = u.writeTo4(p.Payload, p.Addr)
|
||||
} else {
|
||||
err = u.writeTo4(p.Payload, p.Addr)
|
||||
}
|
||||
packet.GetPool().Put(p)
|
||||
return err
|
||||
}
|
||||
|
||||
func (u *StdConn) WriteDirect(b []byte, addr netip.AddrPort) error {
|
||||
if u.isV4 {
|
||||
return u.writeTo4(b, addr)
|
||||
}
|
||||
return u.writeTo6(b, addr)
|
||||
}
|
||||
|
||||
func (u *StdConn) scheduleGSOFlushLocked() {
|
||||
if u.gso.Timer == nil {
|
||||
u.gso.Timer = time.AfterFunc(u.gso.FlushTimeout, u.gsoFlushTimer)
|
||||
return
|
||||
}
|
||||
u.gso.Timer.Reset(u.gso.FlushTimeout)
|
||||
}
|
||||
|
||||
func (u *StdConn) stopGSOTimerLocked() {
|
||||
if u.gso.Timer != nil {
|
||||
u.gso.Timer.Stop()
|
||||
u.gso.Timer = nil //todo I also don't like this
|
||||
}
|
||||
}
|
||||
|
||||
func (u *StdConn) queueGSOPacket(p *packet.Packet) error {
|
||||
if len(p.Payload) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
u.gso.m.Lock()
|
||||
defer u.gso.m.Unlock()
|
||||
|
||||
if !u.enableGSO || !p.Addr.IsValid() || len(p.Payload) > u.gso.MaxBytes {
|
||||
if err := u.flushGSOlocked(); err != nil {
|
||||
return err
|
||||
}
|
||||
return errGSOFallback
|
||||
}
|
||||
|
||||
if len(u.gso.packets) == 0 {
|
||||
u.gso.setNameLocked(p.Addr, u.isV4)
|
||||
u.gso.SegSize = len(p.Payload)
|
||||
u.gso.packets = append(u.gso.packets, p)
|
||||
} else if p.Addr != u.gso.Addr || len(p.Payload) != u.gso.SegSize {
|
||||
if err := u.flushGSOlocked(); err != nil {
|
||||
return err
|
||||
} //todo deal with "one small packet" case
|
||||
u.gso.setNameLocked(p.Addr, u.isV4)
|
||||
u.gso.SegSize = len(p.Payload)
|
||||
u.gso.packets = append(u.gso.packets, p)
|
||||
} else {
|
||||
u.gso.packets = append(u.gso.packets, p)
|
||||
}
|
||||
|
||||
//big todo
|
||||
//if len(u.gso.Buf)+len(p.Payload) > u.gso.MaxBytes {
|
||||
// if err := u.flushGSOlocked(); err != nil {
|
||||
// return err
|
||||
// }
|
||||
// u.gso.setNameLocked(p.Addr, u.isV4)
|
||||
// u.gso.SegSize = len(p.Payload)
|
||||
// u.gso.packets = append(u.gso.packets, p)
|
||||
//}
|
||||
|
||||
if len(u.gso.packets) >= u.gso.MaxSegments || u.gso.FlushTimeout <= 0 {
|
||||
return u.flushGSOlocked()
|
||||
}
|
||||
|
||||
u.scheduleGSOFlushLocked()
|
||||
return nil
|
||||
}
|
||||
|
||||
func (u *StdConn) flushGSOlocked() error {
|
||||
if len(u.gso.packets) == 0 {
|
||||
u.stopGSOTimerLocked()
|
||||
return nil
|
||||
}
|
||||
|
||||
u.stopGSOTimerLocked()
|
||||
|
||||
if u.gso.SegSize <= 0 {
|
||||
return errGSOFallback
|
||||
}
|
||||
|
||||
err := u.gso.sendmsgLocked(u.sysFd)
|
||||
if errors.Is(err, errGSODisabled) {
|
||||
u.l.WithField("addr", u.gso.Addr).Warn("UDP GSO disabled by kernel, falling back to sendto")
|
||||
u.enableGSO = false
|
||||
//todo!
|
||||
//return u.sendSegmentsIndividually(payload, addr, segSize)
|
||||
}
|
||||
u.gso.SegSize = 0
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
func (u *StdConn) gsoFlushTimer() {
|
||||
u.gso.m.Lock()
|
||||
_ = u.flushGSOlocked()
|
||||
u.gso.m.Unlock()
|
||||
}
|
||||
|
||||
func (u *StdConn) writeTo6(b []byte, ip netip.AddrPort) error {
|
||||
@@ -308,6 +577,72 @@ func (u *StdConn) ReloadConfig(c *config.C) {
|
||||
u.l.WithError(err).Error("Failed to set listen.so_mark")
|
||||
}
|
||||
}
|
||||
u.configureGRO(true)
|
||||
u.configureGSO(c)
|
||||
}
|
||||
|
||||
func (u *StdConn) configureGRO(enable bool) {
|
||||
if enable == u.enableGRO {
|
||||
return
|
||||
}
|
||||
|
||||
if enable {
|
||||
if err := unix.SetsockoptInt(u.sysFd, unix.SOL_UDP, unix.UDP_GRO, 1); err != nil {
|
||||
u.l.WithError(err).Warn("Failed to enable UDP GRO")
|
||||
return
|
||||
}
|
||||
u.enableGRO = true
|
||||
u.l.Info("UDP GRO enabled")
|
||||
return
|
||||
}
|
||||
|
||||
if err := unix.SetsockoptInt(u.sysFd, unix.SOL_UDP, unix.UDP_GRO, 0); err != nil && err != unix.ENOPROTOOPT {
|
||||
u.l.WithError(err).Warn("Failed to disable UDP GRO")
|
||||
}
|
||||
u.enableGRO = false
|
||||
}
|
||||
|
||||
func (u *StdConn) configureGSO(c *config.C) {
|
||||
enable := c.GetBool("listen.enable_gso", true)
|
||||
if !enable {
|
||||
u.disableGSO()
|
||||
} else {
|
||||
u.enableGSO = true
|
||||
}
|
||||
|
||||
segments := c.GetInt("listen.gso_max_segments", defaultGSOMaxSegments)
|
||||
if segments < 1 {
|
||||
segments = 1
|
||||
}
|
||||
u.gso.MaxSegments = segments
|
||||
|
||||
maxBytes := c.GetInt("listen.gso_max_bytes", 0)
|
||||
if maxBytes <= 0 {
|
||||
maxBytes = MTU * segments
|
||||
}
|
||||
if maxBytes > maxGSOBatchBytes {
|
||||
u.l.WithField("requested", maxBytes).Warn("listen.gso_max_bytes larger than UDP limit; clamping")
|
||||
maxBytes = maxGSOBatchBytes
|
||||
}
|
||||
u.gso.MaxBytes = maxBytes
|
||||
|
||||
timeout := c.GetDuration("listen.gso_flush_timeout", defaultGSOFlushTimeout)
|
||||
if timeout < 0 {
|
||||
timeout = 0
|
||||
}
|
||||
u.gso.FlushTimeout = timeout
|
||||
u.gso.Init()
|
||||
}
|
||||
|
||||
func (u *StdConn) disableGSO() {
|
||||
u.gso.m.Lock()
|
||||
defer u.gso.m.Unlock()
|
||||
u.enableGSO = false
|
||||
_ = u.flushGSOlocked()
|
||||
u.gso.Buf = nil
|
||||
u.gso.packets = u.gso.packets[:0]
|
||||
u.gso.SegSize = 0
|
||||
u.stopGSOTimerLocked()
|
||||
}
|
||||
|
||||
func (u *StdConn) getMemInfo(meminfo *[unix.SK_MEMINFO_VARS]uint32) error {
|
||||
|
||||
@@ -34,6 +34,24 @@ type rawMessage struct {
|
||||
Pad0 [4]byte
|
||||
}
|
||||
|
||||
func setRawMessageControl(msg *rawMessage, buf []byte) {
|
||||
if len(buf) == 0 {
|
||||
msg.Hdr.Control = nil
|
||||
msg.Hdr.Controllen = 0
|
||||
return
|
||||
}
|
||||
msg.Hdr.Control = &buf[0]
|
||||
msg.Hdr.Controllen = uint64(len(buf))
|
||||
}
|
||||
|
||||
func getRawMessageControlLen(msg *rawMessage) int {
|
||||
return int(msg.Hdr.Controllen)
|
||||
}
|
||||
|
||||
func setCmsgLen(h *unix.Cmsghdr, l int) {
|
||||
h.Len = uint64(l)
|
||||
}
|
||||
|
||||
func (u *StdConn) PrepareRawMessages(n int, pg PacketBufferGetter) ([]rawMessage, []*packet.Packet, [][]byte) {
|
||||
msgs := make([]rawMessage, n)
|
||||
names := make([][]byte, n)
|
||||
@@ -42,6 +60,7 @@ func (u *StdConn) PrepareRawMessages(n int, pg PacketBufferGetter) ([]rawMessage
|
||||
for i := range packets {
|
||||
packets[i] = pg()
|
||||
}
|
||||
//todo?
|
||||
|
||||
for i := range msgs {
|
||||
names[i] = make([]byte, unix.SizeofSockaddrInet6)
|
||||
@@ -55,6 +74,13 @@ func (u *StdConn) PrepareRawMessages(n int, pg PacketBufferGetter) ([]rawMessage
|
||||
|
||||
msgs[i].Hdr.Name = &names[i][0]
|
||||
msgs[i].Hdr.Namelen = uint32(len(names[i]))
|
||||
if u.enableGRO {
|
||||
msgs[i].Hdr.Control = &packets[i].Control[0]
|
||||
msgs[i].Hdr.Controllen = uint64(len(packets[i].Control))
|
||||
} else {
|
||||
msgs[i].Hdr.Control = nil
|
||||
msgs[i].Hdr.Controllen = 0
|
||||
}
|
||||
}
|
||||
|
||||
return msgs, packets, names
|
||||
|
||||
Reference in New Issue
Block a user