Compare commits

..

2 Commits

Author SHA1 Message Date
Ryan
9253f36a3c tweak defaults and turn on gsogro on linux by default 2025-11-06 13:34:58 -05:00
Ryan
c9a695c2bf try with sendmmsg merged back 2025-11-06 10:56:53 -05:00
12 changed files with 388 additions and 60 deletions

View File

@@ -15,7 +15,7 @@ import (
// TODO: In a 5Gbps test, 1024 is not sufficient. With a 1400 MTU this is about 1.4Gbps of window, assuming full packets. // TODO: In a 5Gbps test, 1024 is not sufficient. With a 1400 MTU this is about 1.4Gbps of window, assuming full packets.
// 4092 should be sufficient for 5Gbps // 4092 should be sufficient for 5Gbps
const ReplayWindow = 1024 const ReplayWindow = 4096
type ConnectionState struct { type ConnectionState struct {
eKey *NebulaCipherState eKey *NebulaCipherState

106
inside.go
View File

@@ -11,19 +11,19 @@ import (
"github.com/slackhq/nebula/routing" "github.com/slackhq/nebula/routing"
) )
func (f *Interface) consumeInsidePacket(packet []byte, fwPacket *firewall.Packet, nb, out []byte, q int, localCache firewall.ConntrackCache) { func (f *Interface) consumeInsidePacket(packet []byte, fwPacket *firewall.Packet, nb, out []byte, queue func(netip.AddrPort, int), q int, localCache firewall.ConntrackCache) bool {
err := newPacket(packet, false, fwPacket) err := newPacket(packet, false, fwPacket)
if err != nil { if err != nil {
if f.l.Level >= logrus.DebugLevel { if f.l.Level >= logrus.DebugLevel {
f.l.WithField("packet", packet).Debugf("Error while validating outbound packet: %s", err) f.l.WithField("packet", packet).Debugf("Error while validating outbound packet: %s", err)
} }
return return false
} }
// Ignore local broadcast packets // Ignore local broadcast packets
if f.dropLocalBroadcast { if f.dropLocalBroadcast {
if f.myBroadcastAddrsTable.Contains(fwPacket.RemoteAddr) { if f.myBroadcastAddrsTable.Contains(fwPacket.RemoteAddr) {
return return false
} }
} }
@@ -40,12 +40,12 @@ func (f *Interface) consumeInsidePacket(packet []byte, fwPacket *firewall.Packet
} }
// Otherwise, drop. On linux, we should never see these packets - Linux // Otherwise, drop. On linux, we should never see these packets - Linux
// routes packets from the nebula addr to the nebula addr through the loopback device. // routes packets from the nebula addr to the nebula addr through the loopback device.
return return false
} }
// Ignore multicast packets // Ignore multicast packets
if f.dropMulticast && fwPacket.RemoteAddr.IsMulticast() { if f.dropMulticast && fwPacket.RemoteAddr.IsMulticast() {
return return false
} }
hostinfo, ready := f.getOrHandshakeConsiderRouting(fwPacket, func(hh *HandshakeHostInfo) { hostinfo, ready := f.getOrHandshakeConsiderRouting(fwPacket, func(hh *HandshakeHostInfo) {
@@ -59,26 +59,26 @@ func (f *Interface) consumeInsidePacket(packet []byte, fwPacket *firewall.Packet
WithField("fwPacket", fwPacket). WithField("fwPacket", fwPacket).
Debugln("dropping outbound packet, vpnAddr not in our vpn networks or in unsafe networks") Debugln("dropping outbound packet, vpnAddr not in our vpn networks or in unsafe networks")
} }
return return false
} }
if !ready { if !ready {
return return false
} }
dropReason := f.firewall.Drop(*fwPacket, false, hostinfo, f.pki.GetCAPool(), localCache) dropReason := f.firewall.Drop(*fwPacket, false, hostinfo, f.pki.GetCAPool(), localCache)
if dropReason == nil { if dropReason == nil {
f.sendNoMetrics(header.Message, 0, hostinfo.ConnectionState, hostinfo, netip.AddrPort{}, packet, nb, out, q) return f.sendNoMetrics(header.Message, 0, hostinfo.ConnectionState, hostinfo, netip.AddrPort{}, packet, nb, out, queue, q)
} else {
f.rejectInside(packet, out, q)
if f.l.Level >= logrus.DebugLevel {
hostinfo.logger(f.l).
WithField("fwPacket", fwPacket).
WithField("reason", dropReason).
Debugln("dropping outbound packet")
}
} }
f.rejectInside(packet, out, q)
if f.l.Level >= logrus.DebugLevel {
hostinfo.logger(f.l).
WithField("fwPacket", fwPacket).
WithField("reason", dropReason).
Debugln("dropping outbound packet")
}
return false
} }
func (f *Interface) rejectInside(packet []byte, out []byte, q int) { func (f *Interface) rejectInside(packet []byte, out []byte, q int) {
@@ -117,7 +117,7 @@ func (f *Interface) rejectOutside(packet []byte, ci *ConnectionState, hostinfo *
return return
} }
f.sendNoMetrics(header.Message, 0, ci, hostinfo, netip.AddrPort{}, out, nb, packet, q) _ = f.sendNoMetrics(header.Message, 0, ci, hostinfo, netip.AddrPort{}, out, nb, packet, nil, q)
} }
// Handshake will attempt to initiate a tunnel with the provided vpn address if it is within our vpn networks. This is a no-op if the tunnel is already established or being established // Handshake will attempt to initiate a tunnel with the provided vpn address if it is within our vpn networks. This is a no-op if the tunnel is already established or being established
@@ -228,7 +228,7 @@ func (f *Interface) sendMessageNow(t header.MessageType, st header.MessageSubTyp
return return
} }
f.sendNoMetrics(header.Message, st, hostinfo.ConnectionState, hostinfo, netip.AddrPort{}, p, nb, out, 0) _ = f.sendNoMetrics(header.Message, st, hostinfo.ConnectionState, hostinfo, netip.AddrPort{}, p, nb, out, nil, 0)
} }
// SendMessageToVpnAddr handles real addr:port lookup and sends to the current best known address for vpnAddr // SendMessageToVpnAddr handles real addr:port lookup and sends to the current best known address for vpnAddr
@@ -258,12 +258,12 @@ func (f *Interface) SendMessageToHostInfo(t header.MessageType, st header.Messag
func (f *Interface) send(t header.MessageType, st header.MessageSubType, ci *ConnectionState, hostinfo *HostInfo, p, nb, out []byte) { func (f *Interface) send(t header.MessageType, st header.MessageSubType, ci *ConnectionState, hostinfo *HostInfo, p, nb, out []byte) {
f.messageMetrics.Tx(t, st, 1) f.messageMetrics.Tx(t, st, 1)
f.sendNoMetrics(t, st, ci, hostinfo, netip.AddrPort{}, p, nb, out, 0) _ = f.sendNoMetrics(t, st, ci, hostinfo, netip.AddrPort{}, p, nb, out, nil, 0)
} }
func (f *Interface) sendTo(t header.MessageType, st header.MessageSubType, ci *ConnectionState, hostinfo *HostInfo, remote netip.AddrPort, p, nb, out []byte) { func (f *Interface) sendTo(t header.MessageType, st header.MessageSubType, ci *ConnectionState, hostinfo *HostInfo, remote netip.AddrPort, p, nb, out []byte) {
f.messageMetrics.Tx(t, st, 1) f.messageMetrics.Tx(t, st, 1)
f.sendNoMetrics(t, st, ci, hostinfo, remote, p, nb, out, 0) _ = f.sendNoMetrics(t, st, ci, hostinfo, remote, p, nb, out, nil, 0)
} }
// SendVia sends a payload through a Relay tunnel. No authentication or encryption is done // SendVia sends a payload through a Relay tunnel. No authentication or encryption is done
@@ -331,9 +331,12 @@ func (f *Interface) SendVia(via *HostInfo,
f.connectionManager.RelayUsed(relay.LocalIndex) f.connectionManager.RelayUsed(relay.LocalIndex)
} }
func (f *Interface) sendNoMetrics(t header.MessageType, st header.MessageSubType, ci *ConnectionState, hostinfo *HostInfo, remote netip.AddrPort, p, nb, out []byte, q int) { // sendNoMetrics encrypts and writes/queues an outbound packet. It returns true
// when the payload has been handed to a caller-provided queue (meaning the
// caller is responsible for flushing it later).
func (f *Interface) sendNoMetrics(t header.MessageType, st header.MessageSubType, ci *ConnectionState, hostinfo *HostInfo, remote netip.AddrPort, p, nb, out []byte, queue func(netip.AddrPort, int), q int) bool {
if ci.eKey == nil { if ci.eKey == nil {
return return false
} }
useRelay := !remote.IsValid() && !hostinfo.remote.IsValid() useRelay := !remote.IsValid() && !hostinfo.remote.IsValid()
fullOut := out fullOut := out
@@ -380,32 +383,39 @@ func (f *Interface) sendNoMetrics(t header.MessageType, st header.MessageSubType
WithField("udpAddr", remote).WithField("counter", c). WithField("udpAddr", remote).WithField("counter", c).
WithField("attemptedCounter", c). WithField("attemptedCounter", c).
Error("Failed to encrypt outgoing packet") Error("Failed to encrypt outgoing packet")
return return false
} }
if remote.IsValid() { dest := remote
err = f.writers[q].WriteTo(out, remote) if !dest.IsValid() {
if err != nil { dest = hostinfo.remote
hostinfo.logger(f.l).WithError(err).
WithField("udpAddr", remote).Error("Failed to write outgoing packet")
}
} else if hostinfo.remote.IsValid() {
err = f.writers[q].WriteTo(out, hostinfo.remote)
if err != nil {
hostinfo.logger(f.l).WithError(err).
WithField("udpAddr", remote).Error("Failed to write outgoing packet")
}
} else {
// Try to send via a relay
for _, relayIP := range hostinfo.relayState.CopyRelayIps() {
relayHostInfo, relay, err := f.hostMap.QueryVpnAddrsRelayFor(hostinfo.vpnAddrs, relayIP)
if err != nil {
hostinfo.relayState.DeleteRelay(relayIP)
hostinfo.logger(f.l).WithField("relay", relayIP).WithError(err).Info("sendNoMetrics failed to find HostInfo")
continue
}
f.SendVia(relayHostInfo, relay, out, nb, fullOut[:header.Len+len(out)], true)
break
}
} }
if dest.IsValid() {
if queue != nil {
queue(dest, len(out))
return true
}
err = f.writers[q].WriteTo(out, dest)
if err != nil {
hostinfo.logger(f.l).WithError(err).
WithField("udpAddr", dest).Error("Failed to write outgoing packet")
}
return false
}
// Try to send via a relay
for _, relayIP := range hostinfo.relayState.CopyRelayIps() {
relayHostInfo, relay, err := f.hostMap.QueryVpnAddrsRelayFor(hostinfo.vpnAddrs, relayIP)
if err != nil {
hostinfo.relayState.DeleteRelay(relayIP)
hostinfo.logger(f.l).WithField("relay", relayIP).WithError(err).Info("sendNoMetrics failed to find HostInfo")
continue
}
f.SendVia(relayHostInfo, relay, out, nb, fullOut[:header.Len+len(out)], true)
break
}
return false
} }

View File

@@ -25,10 +25,14 @@ import (
const ( const (
mtu = 9001 mtu = 9001
inboundBatchSizeDefault = 32 inboundBatchSizeDefault = 128
outboundBatchSizeDefault = 32 outboundBatchSizeDefault = 64
batchFlushIntervalDefault = 50 * time.Microsecond batchFlushIntervalDefault = 12 * time.Microsecond
maxOutstandingBatchesDefault = 1028 maxOutstandingBatchesDefault = 8
sendBatchSizeDefault = 64
maxPendingPacketsDefault = 32
maxPendingBytesDefault = 64 * 1024
maxSendBufPerRoutineDefault = 16
) )
type InterfaceConfig struct { type InterfaceConfig struct {
@@ -64,6 +68,9 @@ type BatchConfig struct {
OutboundBatchSize int OutboundBatchSize int
FlushInterval time.Duration FlushInterval time.Duration
MaxOutstandingPerChan int MaxOutstandingPerChan int
MaxPendingPackets int
MaxPendingBytes int
MaxSendBuffersPerChan int
} }
type Interface struct { type Interface struct {
@@ -120,10 +127,23 @@ type Interface struct {
packetBatchPool sync.Pool packetBatchPool sync.Pool
outboundBatchPool sync.Pool outboundBatchPool sync.Pool
sendPool sync.Pool
sendBufCache [][]*[]byte
sendBatchSize int
inboundBatchSize int inboundBatchSize int
outboundBatchSize int outboundBatchSize int
batchFlushInterval time.Duration batchFlushInterval time.Duration
maxOutstandingPerChan int maxOutstandingPerChan int
maxPendingPackets int
maxPendingBytes int
maxSendBufPerRoutine int
}
type outboundSend struct {
buf *[]byte
length int
addr netip.AddrPort
} }
type packetBatch struct { type packetBatch struct {
@@ -194,6 +214,63 @@ func (f *Interface) releaseOutboundBatch(b *outboundBatch) {
f.outboundBatchPool.Put(b) f.outboundBatchPool.Put(b)
} }
func (f *Interface) getSendBuffer(q int) *[]byte {
cache := f.sendBufCache[q]
if n := len(cache); n > 0 {
buf := cache[n-1]
f.sendBufCache[q] = cache[:n-1]
*buf = (*buf)[:0]
return buf
}
if v := f.sendPool.Get(); v != nil {
buf := v.(*[]byte)
*buf = (*buf)[:0]
return buf
}
b := make([]byte, mtu)
return &b
}
func (f *Interface) releaseSendBuffer(q int, buf *[]byte) {
if buf == nil {
return
}
*buf = (*buf)[:0]
cache := f.sendBufCache[q]
if len(cache) < f.maxSendBufPerRoutine {
f.sendBufCache[q] = append(cache, buf)
return
}
f.sendPool.Put(buf)
}
func (f *Interface) flushSendQueue(q int, pending *[]outboundSend, pendingBytes *int) {
if len(*pending) == 0 {
return
}
batch := make([]udp.BatchPacket, len(*pending))
for i, entry := range *pending {
batch[i] = udp.BatchPacket{
Payload: (*entry.buf)[:entry.length],
Addr: entry.addr,
}
}
sent, err := f.writers[q].WriteBatch(batch)
if err != nil {
f.l.WithError(err).WithField("sent", sent).Error("Failed to batch send packets")
}
for _, entry := range *pending {
f.releaseSendBuffer(q, entry.buf)
}
*pending = (*pending)[:0]
if pendingBytes != nil {
*pendingBytes = 0
}
}
type EncWriter interface { type EncWriter interface {
SendVia(via *HostInfo, SendVia(via *HostInfo,
relay *Relay, relay *Relay,
@@ -275,6 +352,15 @@ func NewInterface(ctx context.Context, c *InterfaceConfig) (*Interface, error) {
if bc.MaxOutstandingPerChan <= 0 { if bc.MaxOutstandingPerChan <= 0 {
bc.MaxOutstandingPerChan = maxOutstandingBatchesDefault bc.MaxOutstandingPerChan = maxOutstandingBatchesDefault
} }
if bc.MaxPendingPackets <= 0 {
bc.MaxPendingPackets = maxPendingPacketsDefault
}
if bc.MaxPendingBytes <= 0 {
bc.MaxPendingBytes = maxPendingBytesDefault
}
if bc.MaxSendBuffersPerChan <= 0 {
bc.MaxSendBuffersPerChan = maxSendBufPerRoutineDefault
}
ifce := &Interface{ ifce := &Interface{
pki: c.pki, pki: c.pki,
hostMap: c.HostMap, hostMap: c.HostMap,
@@ -316,6 +402,10 @@ func NewInterface(ctx context.Context, c *InterfaceConfig) (*Interface, error) {
outboundBatchSize: bc.OutboundBatchSize, outboundBatchSize: bc.OutboundBatchSize,
batchFlushInterval: bc.FlushInterval, batchFlushInterval: bc.FlushInterval,
maxOutstandingPerChan: bc.MaxOutstandingPerChan, maxOutstandingPerChan: bc.MaxOutstandingPerChan,
maxPendingPackets: bc.MaxPendingPackets,
maxPendingBytes: bc.MaxPendingBytes,
maxSendBufPerRoutine: bc.MaxSendBuffersPerChan,
sendBatchSize: bc.OutboundBatchSize,
} }
for i := 0; i < c.routines; i++ { for i := 0; i < c.routines; i++ {
@@ -340,6 +430,12 @@ func NewInterface(ctx context.Context, c *InterfaceConfig) (*Interface, error) {
return newOutboundBatch(ifce.outboundBatchSize) return newOutboundBatch(ifce.outboundBatchSize)
}} }}
ifce.sendPool = sync.Pool{New: func() any {
buf := make([]byte, mtu)
return &buf
}}
ifce.sendBufCache = make([][]*[]byte, c.routines)
ifce.tryPromoteEvery.Store(c.tryPromoteEvery) ifce.tryPromoteEvery.Store(c.tryPromoteEvery)
ifce.reQueryEvery.Store(c.reQueryEvery) ifce.reQueryEvery.Store(c.reQueryEvery)
ifce.reQueryWait.Store(int64(c.reQueryWait)) ifce.reQueryWait.Store(int64(c.reQueryWait))
@@ -539,18 +635,52 @@ func (f *Interface) workerOut(i int, ctx context.Context) {
conntrackCache := firewall.NewConntrackCacheTicker(f.conntrackCacheTimeout) conntrackCache := firewall.NewConntrackCacheTicker(f.conntrackCacheTimeout)
fwPacket1 := &firewall.Packet{} fwPacket1 := &firewall.Packet{}
nb1 := make([]byte, 12, 12) nb1 := make([]byte, 12, 12)
result1 := make([]byte, mtu) pending := make([]outboundSend, 0, f.sendBatchSize)
pendingBytes := 0
maxPendingPackets := f.maxPendingPackets
if maxPendingPackets <= 0 {
maxPendingPackets = f.sendBatchSize
}
maxPendingBytes := f.maxPendingBytes
if maxPendingBytes <= 0 {
maxPendingBytes = f.sendBatchSize * mtu
}
for { for {
select { select {
case batch := <-f.outbound[i]: case batch := <-f.outbound[i]:
for _, data := range batch.payloads { for _, data := range batch.payloads {
f.consumeInsidePacket(*data, fwPacket1, nb1, result1, i, conntrackCache.Get(f.l)) sendBuf := f.getSendBuffer(i)
buf := (*sendBuf)[:0]
queue := func(addr netip.AddrPort, length int) {
if len(pending) >= maxPendingPackets || pendingBytes+length > maxPendingBytes {
f.flushSendQueue(i, &pending, &pendingBytes)
}
pending = append(pending, outboundSend{
buf: sendBuf,
length: length,
addr: addr,
})
pendingBytes += length
if len(pending) >= f.sendBatchSize || pendingBytes >= maxPendingBytes {
f.flushSendQueue(i, &pending, &pendingBytes)
}
}
sent := f.consumeInsidePacket(*data, fwPacket1, nb1, buf, queue, i, conntrackCache.Get(f.l))
if !sent {
f.releaseSendBuffer(i, sendBuf)
}
*data = (*data)[:mtu] *data = (*data)[:mtu]
f.outPool.Put(data) f.outPool.Put(data)
} }
f.releaseOutboundBatch(batch) f.releaseOutboundBatch(batch)
if len(pending) > 0 {
f.flushSendQueue(i, &pending, &pendingBytes)
}
case <-ctx.Done(): case <-ctx.Done():
if len(pending) > 0 {
f.flushSendQueue(i, &pending, &pendingBytes)
}
f.wg.Done() f.wg.Done()
return return
} }

View File

@@ -164,7 +164,7 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
for i := 0; i < routines; i++ { for i := 0; i < routines; i++ {
l.Infof("listening on %v", netip.AddrPortFrom(listenHost, uint16(port))) l.Infof("listening on %v", netip.AddrPortFrom(listenHost, uint16(port)))
udpServer, err := udp.NewListener(l, listenHost, port, routines > 1, c.GetInt("listen.batch", 64)) udpServer, err := udp.NewListener(l, listenHost, port, routines > 1, c.GetInt("listen.batch", 128))
if err != nil { if err != nil {
return nil, util.NewContextualError("Failed to open udp listener", m{"queue": i}, err) return nil, util.NewContextualError("Failed to open udp listener", m{"queue": i}, err)
} }
@@ -226,6 +226,9 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
OutboundBatchSize: c.GetInt("batch.outbound_size", outboundBatchSizeDefault), OutboundBatchSize: c.GetInt("batch.outbound_size", outboundBatchSizeDefault),
FlushInterval: c.GetDuration("batch.flush_interval", batchFlushIntervalDefault), FlushInterval: c.GetDuration("batch.flush_interval", batchFlushIntervalDefault),
MaxOutstandingPerChan: c.GetInt("batch.max_outstanding", maxOutstandingBatchesDefault), MaxOutstandingPerChan: c.GetInt("batch.max_outstanding", maxOutstandingBatchesDefault),
MaxPendingPackets: c.GetInt("batch.max_pending_packets", 0),
MaxPendingBytes: c.GetInt("batch.max_pending_bytes", 0),
MaxSendBuffersPerChan: c.GetInt("batch.max_send_buffers_per_routine", 0),
} }
ifConfig := &InterfaceConfig{ ifConfig := &InterfaceConfig{

View File

@@ -18,10 +18,16 @@ type Conn interface {
LocalAddr() (netip.AddrPort, error) LocalAddr() (netip.AddrPort, error)
ListenOut(r EncReader) error ListenOut(r EncReader) error
WriteTo(b []byte, addr netip.AddrPort) error WriteTo(b []byte, addr netip.AddrPort) error
WriteBatch(pkts []BatchPacket) (int, error)
ReloadConfig(c *config.C) ReloadConfig(c *config.C)
Close() error Close() error
} }
type BatchPacket struct {
Payload []byte
Addr netip.AddrPort
}
type NoopConn struct{} type NoopConn struct{}
func (NoopConn) Rebind() error { func (NoopConn) Rebind() error {
@@ -36,6 +42,9 @@ func (NoopConn) ListenOut(_ EncReader) error {
func (NoopConn) WriteTo(_ []byte, _ netip.AddrPort) error { func (NoopConn) WriteTo(_ []byte, _ netip.AddrPort) error {
return nil return nil
} }
func (NoopConn) WriteBatch(_ []BatchPacket) (int, error) {
return 0, nil
}
func (NoopConn) ReloadConfig(_ *config.C) { func (NoopConn) ReloadConfig(_ *config.C) {
return return
} }

View File

@@ -140,6 +140,17 @@ func (u *StdConn) WriteTo(b []byte, ap netip.AddrPort) error {
} }
} }
func (u *StdConn) WriteBatch(pkts []BatchPacket) (int, error) {
sent := 0
for _, pkt := range pkts {
if err := u.WriteTo(pkt.Payload, pkt.Addr); err != nil {
return sent, err
}
sent++
}
return sent, nil
}
func (u *StdConn) LocalAddr() (netip.AddrPort, error) { func (u *StdConn) LocalAddr() (netip.AddrPort, error) {
a := u.UDPConn.LocalAddr() a := u.UDPConn.LocalAddr()

View File

@@ -42,6 +42,17 @@ func (u *GenericConn) WriteTo(b []byte, addr netip.AddrPort) error {
return err return err
} }
func (u *GenericConn) WriteBatch(pkts []BatchPacket) (int, error) {
sent := 0
for _, pkt := range pkts {
if err := u.WriteTo(pkt.Payload, pkt.Addr); err != nil {
return sent, err
}
sent++
}
return sent, nil
}
func (u *GenericConn) LocalAddr() (netip.AddrPort, error) { func (u *GenericConn) LocalAddr() (netip.AddrPort, error) {
a := u.UDPConn.LocalAddr() a := u.UDPConn.LocalAddr()

View File

@@ -23,8 +23,8 @@ import (
var readTimeout = unix.NsecToTimeval(int64(time.Millisecond * 500)) var readTimeout = unix.NsecToTimeval(int64(time.Millisecond * 500))
const ( const (
defaultGSOMaxSegments = 8 defaultGSOMaxSegments = 128
defaultGSOFlushTimeout = 150 * time.Microsecond defaultGSOFlushTimeout = 80 * time.Microsecond
defaultGROReadBufferSize = MTU * defaultGSOMaxSegments defaultGROReadBufferSize = MTU * defaultGSOMaxSegments
maxGSOBatchBytes = 0xFFFF maxGSOBatchBytes = 0xFFFF
) )
@@ -343,6 +343,118 @@ func (u *StdConn) WriteTo(b []byte, ip netip.AddrPort) error {
return u.writeTo6(b, ip) return u.writeTo6(b, ip)
} }
func (u *StdConn) WriteBatch(pkts []BatchPacket) (int, error) {
if len(pkts) == 0 {
return 0, nil
}
msgs := make([]rawMessage, 0, len(pkts))
iovs := make([]iovec, 0, len(pkts))
names := make([][unix.SizeofSockaddrInet6]byte, 0, len(pkts))
sent := 0
for _, pkt := range pkts {
if len(pkt.Payload) == 0 {
sent++
continue
}
if u.enableGSO && pkt.Addr.IsValid() {
if err := u.queueGSOPacket(pkt.Payload, pkt.Addr); err == nil {
sent++
continue
} else if !errors.Is(err, errGSOFallback) {
return sent, err
}
}
if !pkt.Addr.IsValid() {
if err := u.WriteTo(pkt.Payload, pkt.Addr); err != nil {
return sent, err
}
sent++
continue
}
msgs = append(msgs, rawMessage{})
iovs = append(iovs, iovec{})
names = append(names, [unix.SizeofSockaddrInet6]byte{})
idx := len(msgs) - 1
msg := &msgs[idx]
iov := &iovs[idx]
name := &names[idx]
setIovecSlice(iov, pkt.Payload)
msg.Hdr.Iov = iov
msg.Hdr.Iovlen = 1
setRawMessageControl(msg, nil)
msg.Hdr.Flags = 0
nameLen, err := u.encodeSockaddr(name[:], pkt.Addr)
if err != nil {
return sent, err
}
msg.Hdr.Name = &name[0]
msg.Hdr.Namelen = nameLen
}
if len(msgs) == 0 {
return sent, nil
}
offset := 0
for offset < len(msgs) {
n, _, errno := unix.Syscall6(
unix.SYS_SENDMMSG,
uintptr(u.sysFd),
uintptr(unsafe.Pointer(&msgs[offset])),
uintptr(len(msgs)-offset),
0,
0,
0,
)
if errno != 0 {
if errno == unix.EINTR {
continue
}
return sent + offset, &net.OpError{Op: "sendmmsg", Err: errno}
}
if n == 0 {
break
}
offset += int(n)
}
return sent + len(msgs), nil
}
func (u *StdConn) encodeSockaddr(dst []byte, addr netip.AddrPort) (uint32, error) {
if u.isV4 {
if !addr.Addr().Is4() {
return 0, fmt.Errorf("Listener is IPv4, but writing to IPv6 remote")
}
var sa unix.RawSockaddrInet4
sa.Family = unix.AF_INET
sa.Addr = addr.Addr().As4()
binary.BigEndian.PutUint16((*[2]byte)(unsafe.Pointer(&sa.Port))[:], addr.Port())
size := unix.SizeofSockaddrInet4
copy(dst[:size], (*(*[unix.SizeofSockaddrInet4]byte)(unsafe.Pointer(&sa)))[:])
return uint32(size), nil
}
var sa unix.RawSockaddrInet6
sa.Family = unix.AF_INET6
sa.Addr = addr.Addr().As16()
binary.BigEndian.PutUint16((*[2]byte)(unsafe.Pointer(&sa.Port))[:], addr.Port())
size := unix.SizeofSockaddrInet6
copy(dst[:size], (*(*[unix.SizeofSockaddrInet6]byte)(unsafe.Pointer(&sa)))[:])
return uint32(size), nil
}
func (u *StdConn) writeTo6(b []byte, ip netip.AddrPort) error { func (u *StdConn) writeTo6(b []byte, ip netip.AddrPort) error {
var rsa unix.RawSockaddrInet6 var rsa unix.RawSockaddrInet6
rsa.Family = unix.AF_INET6 rsa.Family = unix.AF_INET6
@@ -453,7 +565,7 @@ func (u *StdConn) configureGRO(c *config.C) {
return return
} }
enable := c.GetBool("listen.enable_gro", false) enable := c.GetBool("listen.enable_gro", true)
if enable == u.enableGRO { if enable == u.enableGRO {
if enable { if enable {
if size := c.GetInt("listen.gro_read_buffer", 0); size > 0 { if size := c.GetInt("listen.gro_read_buffer", 0); size > 0 {
@@ -482,7 +594,7 @@ func (u *StdConn) configureGRO(c *config.C) {
} }
func (u *StdConn) configureGSO(c *config.C) { func (u *StdConn) configureGSO(c *config.C) {
enable := c.GetBool("listen.enable_gso", false) enable := c.GetBool("listen.enable_gso", true)
if !enable { if !enable {
u.disableGSO() u.disableGSO()
} else { } else {

View File

@@ -77,3 +77,13 @@ func getRawMessageFlags(msg *rawMessage) int {
func setCmsgLen(h *unix.Cmsghdr, l int) { func setCmsgLen(h *unix.Cmsghdr, l int) {
h.Len = uint32(l) h.Len = uint32(l)
} }
func setIovecSlice(iov *iovec, b []byte) {
if len(b) == 0 {
iov.Base = nil
iov.Len = 0
return
}
iov.Base = &b[0]
iov.Len = uint32(len(b))
}

View File

@@ -80,3 +80,13 @@ func getRawMessageFlags(msg *rawMessage) int {
func setCmsgLen(h *unix.Cmsghdr, l int) { func setCmsgLen(h *unix.Cmsghdr, l int) {
h.Len = uint64(l) h.Len = uint64(l)
} }
func setIovecSlice(iov *iovec, b []byte) {
if len(b) == 0 {
iov.Base = nil
iov.Len = 0
return
}
iov.Base = &b[0]
iov.Len = uint64(len(b))
}

View File

@@ -304,6 +304,17 @@ func (u *RIOConn) WriteTo(buf []byte, ip netip.AddrPort) error {
return winrio.SendEx(u.rq, dataBuffer, 1, nil, addressBuffer, nil, nil, 0, 0) return winrio.SendEx(u.rq, dataBuffer, 1, nil, addressBuffer, nil, nil, 0, 0)
} }
func (u *RIOConn) WriteBatch(pkts []BatchPacket) (int, error) {
sent := 0
for _, pkt := range pkts {
if err := u.WriteTo(pkt.Payload, pkt.Addr); err != nil {
return sent, err
}
sent++
}
return sent, nil
}
func (u *RIOConn) LocalAddr() (netip.AddrPort, error) { func (u *RIOConn) LocalAddr() (netip.AddrPort, error) {
sa, err := windows.Getsockname(u.sock) sa, err := windows.Getsockname(u.sock)
if err != nil { if err != nil {

View File

@@ -106,6 +106,17 @@ func (u *TesterConn) WriteTo(b []byte, addr netip.AddrPort) error {
return nil return nil
} }
func (u *TesterConn) WriteBatch(pkts []BatchPacket) (int, error) {
sent := 0
for _, pkt := range pkts {
if err := u.WriteTo(pkt.Payload, pkt.Addr); err != nil {
return sent, err
}
sent++
}
return sent, nil
}
func (u *TesterConn) ListenOut(r EncReader) { func (u *TesterConn) ListenOut(r EncReader) {
for { for {
p, ok := <-u.RxPackets p, ok := <-u.RxPackets