diff --git a/inside.go b/inside.go
index 0d53f952..69503abf 100644
--- a/inside.go
+++ b/inside.go
@@ -11,7 +11,7 @@ import (
 	"github.com/slackhq/nebula/routing"
 )
 
-func (f *Interface) consumeInsidePacket(packet []byte, fwPacket *firewall.Packet, nb, out []byte, q int, localCache firewall.ConntrackCache) {
+func (f *Interface) consumeInsidePacket(packet []byte, fwPacket *firewall.Packet, nb []byte, batch *sendBatch, rejectBuf []byte, q int, localCache firewall.ConntrackCache) {
 	err := newPacket(packet, false, fwPacket)
 	if err != nil {
 		if f.l.Level >= logrus.DebugLevel {
@@ -53,7 +53,7 @@ func (f *Interface) consumeInsidePacket(packet []byte, fwPacket *firewall.Packet
 	})
 
 	if hostinfo == nil {
-		f.rejectInside(packet, out, q)
+		f.rejectInside(packet, rejectBuf, q)
 		if f.l.Level >= logrus.DebugLevel {
 			f.l.WithField("vpnAddr", fwPacket.RemoteAddr).
 				WithField("fwPacket", fwPacket).
@@ -68,9 +68,9 @@ func (f *Interface) consumeInsidePacket(packet []byte, fwPacket *firewall.Packet
 
 	dropReason := f.firewall.Drop(*fwPacket, false, hostinfo, f.pki.GetCAPool(), localCache)
 	if dropReason == nil {
-		f.sendNoMetrics(header.Message, 0, hostinfo.ConnectionState, hostinfo, netip.AddrPort{}, packet, nb, out, q)
+		f.sendInsideMessage(hostinfo, packet, nb, batch, rejectBuf, q)
 	} else {
-		f.rejectInside(packet, out, q)
+		f.rejectInside(packet, rejectBuf, q)
 		if f.l.Level >= logrus.DebugLevel {
 			hostinfo.logger(f.l).
 				WithField("fwPacket", fwPacket).
@@ -81,6 +81,63 @@ func (f *Interface) consumeInsidePacket(packet []byte, fwPacket *firewall.Packet
 	}
 }
 
+// sendInsideMessage encrypts a firewall-approved inside packet into the
+// caller's batch slot for later sendmmsg flush. When hostinfo.remote is not
+// valid we fall through to the relay slow path via the unbatched sendNoMetrics
+// so relay behavior is unchanged.
+func (f *Interface) sendInsideMessage(hostinfo *HostInfo, p, nb []byte, batch *sendBatch, rejectBuf []byte, q int) {
+	ci := hostinfo.ConnectionState
+	if ci.eKey == nil {
+		return
+	}
+
+	if !hostinfo.remote.IsValid() {
+		// Slow path: relay fallback. Reuse rejectBuf as the ciphertext
+		// scratch; sendNoMetrics arranges header space for SendVia.
+		f.sendNoMetrics(header.Message, 0, ci, hostinfo, netip.AddrPort{}, p, nb, rejectBuf, q)
+		return
+	}
+
+	scratch := batch.Next()
+	if scratch == nil {
+		// Batch full: bypass batching and send this packet directly so we
+		// never drop traffic on over-subscribed iterations.
+		f.sendNoMetrics(header.Message, 0, ci, hostinfo, netip.AddrPort{}, p, nb, rejectBuf, q)
+		return
+	}
+
+	if noiseutil.EncryptLockNeeded {
+		ci.writeLock.Lock()
+	}
+	c := ci.messageCounter.Add(1)
+
+	out := header.Encode(scratch, header.Version, header.Message, 0, hostinfo.remoteIndexId, c)
+	f.connectionManager.Out(hostinfo)
+
+	if hostinfo.lastRebindCount != f.rebindCount {
+		//NOTE: there is an update hole if a tunnel isn't used and exactly 256 rebinds occur before the tunnel is
+		// finally used again. This tunnel would eventually be torn down and recreated if this action didn't help.
+		f.lightHouse.QueryServer(hostinfo.vpnAddrs[0])
+		hostinfo.lastRebindCount = f.rebindCount
+		if f.l.Level >= logrus.DebugLevel {
+			f.l.WithField("vpnAddrs", hostinfo.vpnAddrs).Debug("Lighthouse update triggered for punch due to rebind counter")
+		}
+	}
+
+	out, err := ci.eKey.EncryptDanger(out, out, p, c, nb)
+	if noiseutil.EncryptLockNeeded {
+		ci.writeLock.Unlock()
+	}
+	if err != nil {
+		hostinfo.logger(f.l).WithError(err).
+			WithField("udpAddr", hostinfo.remote).WithField("counter", c).
+			Error("Failed to encrypt outgoing packet")
+		return
+	}
+
+	batch.Commit(len(out), hostinfo.remote)
+}
+
 func (f *Interface) rejectInside(packet []byte, out []byte, q int) {
 	if !f.firewall.InSendReject {
 		return
diff --git a/interface.go b/interface.go
index 456ae00e..943d30e9 100644
--- a/interface.go
+++ b/interface.go
@@ -321,14 +321,15 @@ func (f *Interface) listenOut(i int) {
 }
 
 func (f *Interface) listenIn(reader overlay.Queue, i int) {
-	out := make([]byte, mtu)
+	rejectBuf := make([]byte, mtu)
+	batch := newSendBatch(sendBatchCap, udp.MTU+32)
 	fwPacket := &firewall.Packet{}
 	nb := make([]byte, 12, 12)
 	conntrackCache := firewall.NewConntrackCacheTicker(f.conntrackCacheTimeout)
 
 	for {
-		batch, err := reader.ReadBatch()
+		pkts, err := reader.ReadBatch()
 		if err != nil {
 			if !f.closed.Load() {
 				f.l.WithError(err).WithField("reader", i).Error("Error while reading outbound packet, closing")
 			}
@@ -337,14 +338,29 @@ func (f *Interface) listenIn(reader overlay.Queue, i int) {
 			break
 		}
 
-		for _, pkt := range batch {
-			f.consumeInsidePacket(pkt, fwPacket, nb, out, i, conntrackCache.Get(f.l))
+		batch.Reset()
+		for _, pkt := range pkts {
+			if batch.Len() >= batch.Cap() {
+				f.flushBatch(batch, i)
+				batch.Reset()
+			}
+			f.consumeInsidePacket(pkt, fwPacket, nb, batch, rejectBuf, i, conntrackCache.Get(f.l))
+		}
+		if batch.Len() > 0 {
+			f.flushBatch(batch, i)
 		}
 	}
 
 	f.l.Infof("overlay reader %v is done", i)
 }
 
+func (f *Interface) flushBatch(batch *sendBatch, q int) {
+	if err := f.writers[q].WriteBatch(batch.bufs, batch.dsts); err != nil {
+		f.l.WithError(err).WithField("writer", q).Error("Failed to write outgoing batch")
+	}
+}
+
 func (f *Interface) RegisterConfigChangeCallbacks(c *config.C) {
 	c.RegisterReloadCallback(f.reloadFirewall)
 	c.RegisterReloadCallback(f.reloadSendRecvError)
diff --git a/udp/conn.go b/udp/conn.go
index 30d89dec..652ff79d 100644
--- a/udp/conn.go
+++ b/udp/conn.go
@@ -8,6 +8,12 @@ import (
 
 const MTU = 9001
 
+// MaxWriteBatch is the largest batch any Conn.WriteBatch implementation is
+// required to accept. Callers SHOULD NOT pass more than this per call; Linux
+// backends preallocate sendmmsg scratch sized to this value, so exceeding it
+// only costs a chunked retry.
+const MaxWriteBatch = 128
+
 type EncReader func(
 	addr netip.AddrPort,
 	payload []byte,
@@ -18,6 +24,12 @@ type Conn interface {
 	LocalAddr() (netip.AddrPort, error)
 	ListenOut(r EncReader) error
 	WriteTo(b []byte, addr netip.AddrPort) error
+	// WriteBatch sends a contiguous batch of packets, each with its own
+	// destination. bufs and addrs must have the same length. Linux uses
+	// sendmmsg(2) for a single syscall; other backends fall back to a
+	// WriteTo loop. Returns on the first error; callers may observe a
+	// partial send if some packets went out before the error.
+	WriteBatch(bufs [][]byte, addrs []netip.AddrPort) error
 	ReloadConfig(c *config.C)
 	SupportsMultipleReaders() bool
 	Close() error
@@ -40,6 +52,9 @@ func (NoopConn) SupportsMultipleReaders() bool {
 func (NoopConn) WriteTo(_ []byte, _ netip.AddrPort) error {
 	return nil
 }
+func (NoopConn) WriteBatch(_ [][]byte, _ []netip.AddrPort) error {
+	return nil
+}
 func (NoopConn) ReloadConfig(_ *config.C) {
 	return
 }
diff --git a/udp/udp_generic.go b/udp/udp_generic.go
index 44632fed..9e39436f 100644
--- a/udp/udp_generic.go
+++ b/udp/udp_generic.go
@@ -42,6 +42,15 @@ func (u *GenericConn) WriteTo(b []byte, addr netip.AddrPort) error {
 	return err
 }
 
+func (u *GenericConn) WriteBatch(bufs [][]byte, addrs []netip.AddrPort) error {
+	for i, b := range bufs {
+		if _, err := u.UDPConn.WriteToUDPAddrPort(b, addrs[i]); err != nil {
+			return err
+		}
+	}
+	return nil
+}
+
 func (u *GenericConn) LocalAddr() (netip.AddrPort, error) {
 	a := u.UDPConn.LocalAddr()
 
diff --git a/udp/udp_linux_64.go b/udp/udp_linux_64.go
index 48c5a978..96c548e7 100644
--- a/udp/udp_linux_64.go
+++ b/udp/udp_linux_64.go
@@ -55,3 +55,23 @@ func (u *StdConn) PrepareRawMessages(n int) ([]rawMessage, [][]byte, [][]byte) {
 
 	return msgs, buffers, names
 }
+
+// prepareWriteMessages allocates one Mmsghdr/iovec/sockaddr scratch per slot,
+// wired up so each writeMsgs[i] already points at writeIovs[i] and
+// writeNames[i]. Callers fill in the iovec Base/Len, the sockaddr bytes, and
+// Namelen before each sendmmsg.
+func (u *StdConn) prepareWriteMessages(n int) {
+	u.writeMsgs = make([]rawMessage, n)
+	u.writeIovs = make([]iovec, n)
+	u.writeNames = make([][]byte, n)
+	for i := range u.writeMsgs {
+		u.writeNames[i] = make([]byte, unix.SizeofSockaddrInet6)
+		u.writeMsgs[i].Hdr.Iov = &u.writeIovs[i]
+		u.writeMsgs[i].Hdr.Iovlen = 1
+		u.writeMsgs[i].Hdr.Name = &u.writeNames[i][0]
+	}
+}
+
+func setIovLen(v *iovec, n int) {
+	v.Len = uint64(n)
+}