mirror of
https://github.com/slackhq/nebula.git
synced 2025-11-23 08:54:25 +01:00
write in batches
This commit is contained in:
139
inside.go
139
inside.go
@@ -11,6 +11,145 @@ import (
|
|||||||
"github.com/slackhq/nebula/routing"
|
"github.com/slackhq/nebula/routing"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// consumeInsidePackets processes multiple packets in a batch for improved performance
|
||||||
|
// packets: slice of packet buffers to process
|
||||||
|
// sizes: slice of packet sizes
|
||||||
|
// count: number of packets to process
|
||||||
|
// outs: slice of output buffers (one per packet) with virtio headroom
|
||||||
|
// q: queue index
|
||||||
|
// localCache: firewall conntrack cache
|
||||||
|
func (f *Interface) consumeInsidePackets(packets [][]byte, sizes []int, count int, outs [][]byte, q int, localCache firewall.ConntrackCache) {
|
||||||
|
// Reusable per-packet state
|
||||||
|
fwPacket := &firewall.Packet{}
|
||||||
|
nb := make([]byte, 12, 12)
|
||||||
|
|
||||||
|
// Accumulate encrypted packets for batch sending
|
||||||
|
batchPackets := make([][]byte, 0, count)
|
||||||
|
batchAddrs := make([]netip.AddrPort, 0, count)
|
||||||
|
|
||||||
|
// Process each packet in the batch
|
||||||
|
for i := 0; i < count; i++ {
|
||||||
|
packet := packets[i][:sizes[i]]
|
||||||
|
out := outs[i]
|
||||||
|
|
||||||
|
// Inline the consumeInsidePacket logic for better performance
|
||||||
|
err := newPacket(packet, false, fwPacket)
|
||||||
|
if err != nil {
|
||||||
|
if f.l.Level >= logrus.DebugLevel {
|
||||||
|
f.l.WithField("packet", packet).Debugf("Error while validating outbound packet: %s", err)
|
||||||
|
}
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// Ignore local broadcast packets
|
||||||
|
if f.dropLocalBroadcast {
|
||||||
|
if f.myBroadcastAddrsTable.Contains(fwPacket.RemoteAddr) {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if f.myVpnAddrsTable.Contains(fwPacket.RemoteAddr) {
|
||||||
|
// Immediately forward packets from self to self.
|
||||||
|
if immediatelyForwardToSelf {
|
||||||
|
_, err := f.readers[q].Write(packet)
|
||||||
|
if err != nil {
|
||||||
|
f.l.WithError(err).Error("Failed to forward to tun")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// Ignore multicast packets
|
||||||
|
if f.dropMulticast && fwPacket.RemoteAddr.IsMulticast() {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
hostinfo, ready := f.getOrHandshakeConsiderRouting(fwPacket, func(hh *HandshakeHostInfo) {
|
||||||
|
hh.cachePacket(f.l, header.Message, 0, packet, f.sendMessageNow, f.cachedPacketMetrics)
|
||||||
|
})
|
||||||
|
|
||||||
|
if hostinfo == nil {
|
||||||
|
f.rejectInside(packet, out, q)
|
||||||
|
if f.l.Level >= logrus.DebugLevel {
|
||||||
|
f.l.WithField("vpnAddr", fwPacket.RemoteAddr).
|
||||||
|
WithField("fwPacket", fwPacket).
|
||||||
|
Debugln("dropping outbound packet, vpnAddr not in our vpn networks or in unsafe networks")
|
||||||
|
}
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
if !ready {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
dropReason := f.firewall.Drop(*fwPacket, false, hostinfo, f.pki.GetCAPool(), localCache)
|
||||||
|
if dropReason != nil {
|
||||||
|
f.rejectInside(packet, out, q)
|
||||||
|
if f.l.Level >= logrus.DebugLevel {
|
||||||
|
hostinfo.logger(f.l).
|
||||||
|
WithField("fwPacket", fwPacket).
|
||||||
|
WithField("reason", dropReason).
|
||||||
|
Debugln("dropping outbound packet")
|
||||||
|
}
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// Encrypt and prepare packet for batch sending
|
||||||
|
ci := hostinfo.ConnectionState
|
||||||
|
if ci.eKey == nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check if this needs relay - if so, send immediately and skip batching
|
||||||
|
useRelay := !hostinfo.remote.IsValid()
|
||||||
|
if useRelay {
|
||||||
|
// Handle relay sends individually (less common path)
|
||||||
|
f.sendNoMetrics(header.Message, 0, ci, hostinfo, netip.AddrPort{}, packet, nb, out, q)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// Encrypt the packet for batch sending
|
||||||
|
if noiseutil.EncryptLockNeeded {
|
||||||
|
ci.writeLock.Lock()
|
||||||
|
}
|
||||||
|
c := ci.messageCounter.Add(1)
|
||||||
|
out = header.Encode(out, header.Version, header.Message, 0, hostinfo.remoteIndexId, c)
|
||||||
|
f.connectionManager.Out(hostinfo)
|
||||||
|
|
||||||
|
// Query lighthouse if needed
|
||||||
|
if hostinfo.lastRebindCount != f.rebindCount {
|
||||||
|
f.lightHouse.QueryServer(hostinfo.vpnAddrs[0])
|
||||||
|
hostinfo.lastRebindCount = f.rebindCount
|
||||||
|
if f.l.Level >= logrus.DebugLevel {
|
||||||
|
f.l.WithField("vpnAddrs", hostinfo.vpnAddrs).Debug("Lighthouse update triggered for punch due to rebind counter")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
out, err = ci.eKey.EncryptDanger(out, out, packet, c, nb)
|
||||||
|
if noiseutil.EncryptLockNeeded {
|
||||||
|
ci.writeLock.Unlock()
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
hostinfo.logger(f.l).WithError(err).
|
||||||
|
WithField("counter", c).
|
||||||
|
Error("Failed to encrypt outgoing packet")
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add to batch
|
||||||
|
batchPackets = append(batchPackets, out)
|
||||||
|
batchAddrs = append(batchAddrs, hostinfo.remote)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Send all accumulated packets in one batch
|
||||||
|
if len(batchPackets) > 0 {
|
||||||
|
n, err := f.writers[q].WriteMulti(batchPackets, batchAddrs)
|
||||||
|
if err != nil {
|
||||||
|
f.l.WithError(err).WithField("sent", n).WithField("total", len(batchPackets)).Error("Failed to send batch")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func (f *Interface) consumeInsidePacket(packet []byte, fwPacket *firewall.Packet, nb, out []byte, q int, localCache firewall.ConntrackCache) {
|
func (f *Interface) consumeInsidePacket(packet []byte, fwPacket *firewall.Packet, nb, out []byte, q int, localCache firewall.ConntrackCache) {
|
||||||
err := newPacket(packet, false, fwPacket)
|
err := newPacket(packet, false, fwPacket)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|||||||
17
interface.go
17
interface.go
@@ -333,12 +333,13 @@ func (f *Interface) listenInBatch(reader io.ReadWriteCloser, batchReader BatchRe
|
|||||||
}
|
}
|
||||||
sizes := make([]int, batchSize)
|
sizes := make([]int, batchSize)
|
||||||
|
|
||||||
// Per-packet state (reused across batches)
|
// Allocate output buffers for batch processing (one per packet)
|
||||||
// Allocate out buffer with virtio header headroom to avoid copies on write
|
// Each has virtio header headroom to avoid copies on write
|
||||||
|
outs := make([][]byte, batchSize)
|
||||||
|
for idx := range outs {
|
||||||
outBuf := make([]byte, virtioNetHdrLen+mtu)
|
outBuf := make([]byte, virtioNetHdrLen+mtu)
|
||||||
out := outBuf[virtioNetHdrLen:]
|
outs[idx] = outBuf[virtioNetHdrLen:] // Slice starting after headroom
|
||||||
fwPacket := &firewall.Packet{}
|
}
|
||||||
nb := make([]byte, 12, 12)
|
|
||||||
|
|
||||||
conntrackCache := firewall.NewConntrackCacheTicker(f.conntrackCacheTimeout)
|
conntrackCache := firewall.NewConntrackCacheTicker(f.conntrackCacheTimeout)
|
||||||
|
|
||||||
@@ -354,10 +355,8 @@ func (f *Interface) listenInBatch(reader io.ReadWriteCloser, batchReader BatchRe
|
|||||||
os.Exit(2)
|
os.Exit(2)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Process each packet in the batch
|
// Process all packets in the batch at once
|
||||||
for j := 0; j < n; j++ {
|
f.consumeInsidePackets(bufs, sizes, n, outs, i, conntrackCache.Get(f.l))
|
||||||
f.consumeInsidePacket(bufs[j][:sizes[j]], fwPacket, nb, out, i, conntrackCache.Get(f.l))
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -18,6 +18,7 @@ type Conn interface {
|
|||||||
LocalAddr() (netip.AddrPort, error)
|
LocalAddr() (netip.AddrPort, error)
|
||||||
ListenOut(r EncReader)
|
ListenOut(r EncReader)
|
||||||
WriteTo(b []byte, addr netip.AddrPort) error
|
WriteTo(b []byte, addr netip.AddrPort) error
|
||||||
|
WriteMulti(packets [][]byte, addrs []netip.AddrPort) (int, error)
|
||||||
ReloadConfig(c *config.C)
|
ReloadConfig(c *config.C)
|
||||||
Close() error
|
Close() error
|
||||||
}
|
}
|
||||||
@@ -36,6 +37,9 @@ func (NoopConn) ListenOut(_ EncReader) {
|
|||||||
func (NoopConn) WriteTo(_ []byte, _ netip.AddrPort) error {
|
func (NoopConn) WriteTo(_ []byte, _ netip.AddrPort) error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
func (NoopConn) WriteMulti(_ [][]byte, _ []netip.AddrPort) (int, error) {
|
||||||
|
return 0, nil
|
||||||
|
}
|
||||||
func (NoopConn) ReloadConfig(_ *config.C) {
|
func (NoopConn) ReloadConfig(_ *config.C) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -140,6 +140,17 @@ func (u *StdConn) WriteTo(b []byte, ap netip.AddrPort) error {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// WriteMulti sends multiple packets - fallback implementation without sendmmsg
|
||||||
|
func (u *StdConn) WriteMulti(packets [][]byte, addrs []netip.AddrPort) (int, error) {
|
||||||
|
for i := range packets {
|
||||||
|
err := u.WriteTo(packets[i], addrs[i])
|
||||||
|
if err != nil {
|
||||||
|
return i, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return len(packets), nil
|
||||||
|
}
|
||||||
|
|
||||||
func (u *StdConn) LocalAddr() (netip.AddrPort, error) {
|
func (u *StdConn) LocalAddr() (netip.AddrPort, error) {
|
||||||
a := u.UDPConn.LocalAddr()
|
a := u.UDPConn.LocalAddr()
|
||||||
|
|
||||||
|
|||||||
@@ -194,6 +194,19 @@ func (u *StdConn) WriteTo(b []byte, ip netip.AddrPort) error {
|
|||||||
return u.writeTo6(b, ip)
|
return u.writeTo6(b, ip)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (u *StdConn) WriteMulti(packets [][]byte, addrs []netip.AddrPort) (int, error) {
|
||||||
|
if len(packets) != len(addrs) {
|
||||||
|
return 0, fmt.Errorf("packets and addrs length mismatch")
|
||||||
|
}
|
||||||
|
if len(packets) == 0 {
|
||||||
|
return 0, nil
|
||||||
|
}
|
||||||
|
if u.isV4 {
|
||||||
|
return u.writeMulti4(packets, addrs)
|
||||||
|
}
|
||||||
|
return u.writeMulti6(packets, addrs)
|
||||||
|
}
|
||||||
|
|
||||||
func (u *StdConn) writeTo6(b []byte, ip netip.AddrPort) error {
|
func (u *StdConn) writeTo6(b []byte, ip netip.AddrPort) error {
|
||||||
var rsa unix.RawSockaddrInet6
|
var rsa unix.RawSockaddrInet6
|
||||||
rsa.Family = unix.AF_INET6
|
rsa.Family = unix.AF_INET6
|
||||||
@@ -248,6 +261,78 @@ func (u *StdConn) writeTo4(b []byte, ip netip.AddrPort) error {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (u *StdConn) writeMulti4(packets [][]byte, addrs []netip.AddrPort) (int, error) {
|
||||||
|
msgs, iovecs, names := u.PrepareWriteMessages4(len(packets))
|
||||||
|
|
||||||
|
for i := range packets {
|
||||||
|
if !addrs[i].Addr().Is4() {
|
||||||
|
return i, ErrInvalidIPv6RemoteForSocket
|
||||||
|
}
|
||||||
|
|
||||||
|
// Setup the packet buffer
|
||||||
|
iovecs[i].Base = &packets[i][0]
|
||||||
|
iovecs[i].Len = uint64(len(packets[i]))
|
||||||
|
|
||||||
|
// Setup the destination address
|
||||||
|
rsa := (*unix.RawSockaddrInet4)(unsafe.Pointer(&names[i][0]))
|
||||||
|
rsa.Family = unix.AF_INET
|
||||||
|
rsa.Addr = addrs[i].Addr().As4()
|
||||||
|
binary.BigEndian.PutUint16((*[2]byte)(unsafe.Pointer(&rsa.Port))[:], addrs[i].Port())
|
||||||
|
}
|
||||||
|
|
||||||
|
for {
|
||||||
|
n, _, err := unix.Syscall6(
|
||||||
|
unix.SYS_SENDMMSG,
|
||||||
|
uintptr(u.sysFd),
|
||||||
|
uintptr(unsafe.Pointer(&msgs[0])),
|
||||||
|
uintptr(len(msgs)),
|
||||||
|
0,
|
||||||
|
0,
|
||||||
|
0,
|
||||||
|
)
|
||||||
|
|
||||||
|
if err != 0 {
|
||||||
|
return int(n), &net.OpError{Op: "sendmmsg", Err: err}
|
||||||
|
}
|
||||||
|
|
||||||
|
return int(n), nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (u *StdConn) writeMulti6(packets [][]byte, addrs []netip.AddrPort) (int, error) {
|
||||||
|
msgs, iovecs, names := u.PrepareWriteMessages6(len(packets))
|
||||||
|
|
||||||
|
for i := range packets {
|
||||||
|
// Setup the packet buffer
|
||||||
|
iovecs[i].Base = &packets[i][0]
|
||||||
|
iovecs[i].Len = uint64(len(packets[i]))
|
||||||
|
|
||||||
|
// Setup the destination address
|
||||||
|
rsa := (*unix.RawSockaddrInet6)(unsafe.Pointer(&names[i][0]))
|
||||||
|
rsa.Family = unix.AF_INET6
|
||||||
|
rsa.Addr = addrs[i].Addr().As16()
|
||||||
|
binary.BigEndian.PutUint16((*[2]byte)(unsafe.Pointer(&rsa.Port))[:], addrs[i].Port())
|
||||||
|
}
|
||||||
|
|
||||||
|
for {
|
||||||
|
n, _, err := unix.Syscall6(
|
||||||
|
unix.SYS_SENDMMSG,
|
||||||
|
uintptr(u.sysFd),
|
||||||
|
uintptr(unsafe.Pointer(&msgs[0])),
|
||||||
|
uintptr(len(msgs)),
|
||||||
|
0,
|
||||||
|
0,
|
||||||
|
0,
|
||||||
|
)
|
||||||
|
|
||||||
|
if err != 0 {
|
||||||
|
return int(n), &net.OpError{Op: "sendmmsg", Err: err}
|
||||||
|
}
|
||||||
|
|
||||||
|
return int(n), nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func (u *StdConn) ReloadConfig(c *config.C) {
|
func (u *StdConn) ReloadConfig(c *config.C) {
|
||||||
b := c.GetInt("listen.read_buffer", 0)
|
b := c.GetInt("listen.read_buffer", 0)
|
||||||
if b > 0 {
|
if b > 0 {
|
||||||
|
|||||||
@@ -55,3 +55,41 @@ func (u *StdConn) PrepareRawMessages(n int) ([]rawMessage, [][]byte, [][]byte) {
|
|||||||
|
|
||||||
return msgs, buffers, names
|
return msgs, buffers, names
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (u *StdConn) PrepareWriteMessages4(n int) ([]rawMessage, []iovec, [][]byte) {
|
||||||
|
msgs := make([]rawMessage, n)
|
||||||
|
iovecs := make([]iovec, n)
|
||||||
|
names := make([][]byte, n)
|
||||||
|
|
||||||
|
for i := range msgs {
|
||||||
|
names[i] = make([]byte, unix.SizeofSockaddrInet4)
|
||||||
|
|
||||||
|
// Point to the iovec in the slice
|
||||||
|
msgs[i].Hdr.Iov = &iovecs[i]
|
||||||
|
msgs[i].Hdr.Iovlen = 1
|
||||||
|
|
||||||
|
msgs[i].Hdr.Name = &names[i][0]
|
||||||
|
msgs[i].Hdr.Namelen = unix.SizeofSockaddrInet4
|
||||||
|
}
|
||||||
|
|
||||||
|
return msgs, iovecs, names
|
||||||
|
}
|
||||||
|
|
||||||
|
func (u *StdConn) PrepareWriteMessages6(n int) ([]rawMessage, []iovec, [][]byte) {
|
||||||
|
msgs := make([]rawMessage, n)
|
||||||
|
iovecs := make([]iovec, n)
|
||||||
|
names := make([][]byte, n)
|
||||||
|
|
||||||
|
for i := range msgs {
|
||||||
|
names[i] = make([]byte, unix.SizeofSockaddrInet6)
|
||||||
|
|
||||||
|
// Point to the iovec in the slice
|
||||||
|
msgs[i].Hdr.Iov = &iovecs[i]
|
||||||
|
msgs[i].Hdr.Iovlen = 1
|
||||||
|
|
||||||
|
msgs[i].Hdr.Name = &names[i][0]
|
||||||
|
msgs[i].Hdr.Namelen = unix.SizeofSockaddrInet6
|
||||||
|
}
|
||||||
|
|
||||||
|
return msgs, iovecs, names
|
||||||
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user