Experimenting

This commit is contained in:
Nate Brown
2026-05-11 11:51:46 -05:00
parent b7e9939e92
commit 86cef88744
33 changed files with 691 additions and 560 deletions

View File

@@ -153,8 +153,8 @@ func (cm *connectionManager) Start(ctx context.Context) {
defer clockSource.Stop() defer clockSource.Stop()
p := []byte("") p := []byte("")
nb := make([]byte, 12, 12) // Long-lived buf for the traffic-check goroutine; never released.
out := make([]byte, mtu) buf := cm.intf.bufAlloc.Acquire()
for { for {
select { select {
@@ -169,13 +169,13 @@ func (cm *connectionManager) Start(ctx context.Context) {
break break
} }
cm.doTrafficCheck(localIndex, p, nb, out, now) cm.doTrafficCheck(localIndex, p, buf, now)
} }
} }
} }
} }
func (cm *connectionManager) doTrafficCheck(localIndex uint32, p, nb, out []byte, now time.Time) { func (cm *connectionManager) doTrafficCheck(localIndex uint32, p []byte, buf *WireBuffer, now time.Time) {
decision, hostinfo, primary := cm.makeTrafficDecision(localIndex, now) decision, hostinfo, primary := cm.makeTrafficDecision(localIndex, now)
switch decision { switch decision {
@@ -199,7 +199,7 @@ func (cm *connectionManager) doTrafficCheck(localIndex uint32, p, nb, out []byte
cm.tryRehandshake(hostinfo) cm.tryRehandshake(hostinfo)
case sendTestPacket: case sendTestPacket:
cm.intf.SendMessageToHostInfo(header.Test, header.TestRequest, hostinfo, p, nb, out) cm.intf.SendMessageToHostInfo(header.Test, header.TestRequest, hostinfo, p, buf)
} }
cm.resetRelayTrafficCheck(hostinfo) cm.resetRelayTrafficCheck(hostinfo)
@@ -308,7 +308,9 @@ func (cm *connectionManager) migrateRelayUsed(oldhostinfo, newhostinfo *HostInfo
if err != nil { if err != nil {
cm.l.Error("failed to marshal Control message to migrate relay", "error", err) cm.l.Error("failed to marshal Control message to migrate relay", "error", err)
} else { } else {
cm.intf.SendMessageToHostInfo(header.Control, 0, newhostinfo, msg, make([]byte, 12), make([]byte, mtu)) migBuf := cm.intf.bufAlloc.Acquire()
cm.intf.SendMessageToHostInfo(header.Control, 0, newhostinfo, msg, migBuf)
cm.intf.bufAlloc.Release(migBuf)
cm.l.Info("send CreateRelayRequest", cm.l.Info("send CreateRelayRequest",
"relayFrom", req.RelayFromAddr, "relayFrom", req.RelayFromAddr,
"relayTo", req.RelayToAddr, "relayTo", req.RelayToAddr,

View File

@@ -67,9 +67,9 @@ func Test_NewConnectionManagerTest(t *testing.T) {
punchy := NewPunchyFromConfig(test.NewLogger(), conf) punchy := NewPunchyFromConfig(test.NewLogger(), conf)
nc := newConnectionManagerFromConfig(test.NewLogger(), conf, hostMap, punchy) nc := newConnectionManagerFromConfig(test.NewLogger(), conf, hostMap, punchy)
nc.intf = ifce nc.intf = ifce
p := []byte("") p := []byte("")
nb := make([]byte, 12, 12) buf := NewWireBuffer(mtu, 0)
out := make([]byte, mtu)
// Add an ip we have established a connection w/ to hostmap // Add an ip we have established a connection w/ to hostmap
hostinfo := &HostInfo{ hostinfo := &HostInfo{
@@ -92,7 +92,7 @@ func Test_NewConnectionManagerTest(t *testing.T) {
assert.True(t, hostinfo.in.Load()) assert.True(t, hostinfo.in.Load())
// Do a traffic check tick, should not be pending deletion but should not have any in/out packets recorded // Do a traffic check tick, should not be pending deletion but should not have any in/out packets recorded
nc.doTrafficCheck(hostinfo.localIndexId, p, nb, out, time.Now()) nc.doTrafficCheck(hostinfo.localIndexId, p, buf, time.Now())
assert.False(t, hostinfo.pendingDeletion.Load()) assert.False(t, hostinfo.pendingDeletion.Load())
assert.False(t, hostinfo.out.Load()) assert.False(t, hostinfo.out.Load())
assert.False(t, hostinfo.in.Load()) assert.False(t, hostinfo.in.Load())
@@ -100,7 +100,7 @@ func Test_NewConnectionManagerTest(t *testing.T) {
// Do another traffic check tick, this host should be pending deletion now // Do another traffic check tick, this host should be pending deletion now
nc.Out(hostinfo) nc.Out(hostinfo)
assert.True(t, hostinfo.out.Load()) assert.True(t, hostinfo.out.Load())
nc.doTrafficCheck(hostinfo.localIndexId, p, nb, out, time.Now()) nc.doTrafficCheck(hostinfo.localIndexId, p, buf, time.Now())
assert.True(t, hostinfo.pendingDeletion.Load()) assert.True(t, hostinfo.pendingDeletion.Load())
assert.False(t, hostinfo.out.Load()) assert.False(t, hostinfo.out.Load())
assert.False(t, hostinfo.in.Load()) assert.False(t, hostinfo.in.Load())
@@ -108,7 +108,7 @@ func Test_NewConnectionManagerTest(t *testing.T) {
assert.Contains(t, nc.hostMap.Hosts, hostinfo.vpnAddrs[0]) assert.Contains(t, nc.hostMap.Hosts, hostinfo.vpnAddrs[0])
// Do a final traffic check tick, the host should now be removed // Do a final traffic check tick, the host should now be removed
nc.doTrafficCheck(hostinfo.localIndexId, p, nb, out, time.Now()) nc.doTrafficCheck(hostinfo.localIndexId, p, buf, time.Now())
assert.NotContains(t, nc.hostMap.Hosts, hostinfo.vpnAddrs) assert.NotContains(t, nc.hostMap.Hosts, hostinfo.vpnAddrs)
assert.NotContains(t, nc.hostMap.Indexes, hostinfo.localIndexId) assert.NotContains(t, nc.hostMap.Indexes, hostinfo.localIndexId)
} }
@@ -149,9 +149,9 @@ func Test_NewConnectionManagerTest2(t *testing.T) {
punchy := NewPunchyFromConfig(test.NewLogger(), conf) punchy := NewPunchyFromConfig(test.NewLogger(), conf)
nc := newConnectionManagerFromConfig(test.NewLogger(), conf, hostMap, punchy) nc := newConnectionManagerFromConfig(test.NewLogger(), conf, hostMap, punchy)
nc.intf = ifce nc.intf = ifce
p := []byte("") p := []byte("")
nb := make([]byte, 12, 12) buf := NewWireBuffer(mtu, 0)
out := make([]byte, mtu)
// Add an ip we have established a connection w/ to hostmap // Add an ip we have established a connection w/ to hostmap
hostinfo := &HostInfo{ hostinfo := &HostInfo{
@@ -174,14 +174,14 @@ func Test_NewConnectionManagerTest2(t *testing.T) {
assert.Contains(t, nc.hostMap.Indexes, hostinfo.localIndexId) assert.Contains(t, nc.hostMap.Indexes, hostinfo.localIndexId)
// Do a traffic check tick, should not be pending deletion but should not have any in/out packets recorded // Do a traffic check tick, should not be pending deletion but should not have any in/out packets recorded
nc.doTrafficCheck(hostinfo.localIndexId, p, nb, out, time.Now()) nc.doTrafficCheck(hostinfo.localIndexId, p, buf, time.Now())
assert.False(t, hostinfo.pendingDeletion.Load()) assert.False(t, hostinfo.pendingDeletion.Load())
assert.False(t, hostinfo.out.Load()) assert.False(t, hostinfo.out.Load())
assert.False(t, hostinfo.in.Load()) assert.False(t, hostinfo.in.Load())
// Do another traffic check tick, this host should be pending deletion now // Do another traffic check tick, this host should be pending deletion now
nc.Out(hostinfo) nc.Out(hostinfo)
nc.doTrafficCheck(hostinfo.localIndexId, p, nb, out, time.Now()) nc.doTrafficCheck(hostinfo.localIndexId, p, buf, time.Now())
assert.True(t, hostinfo.pendingDeletion.Load()) assert.True(t, hostinfo.pendingDeletion.Load())
assert.False(t, hostinfo.out.Load()) assert.False(t, hostinfo.out.Load())
assert.False(t, hostinfo.in.Load()) assert.False(t, hostinfo.in.Load())
@@ -190,7 +190,7 @@ func Test_NewConnectionManagerTest2(t *testing.T) {
// We saw traffic, should no longer be pending deletion // We saw traffic, should no longer be pending deletion
nc.In(hostinfo) nc.In(hostinfo)
nc.doTrafficCheck(hostinfo.localIndexId, p, nb, out, time.Now()) nc.doTrafficCheck(hostinfo.localIndexId, p, buf, time.Now())
assert.False(t, hostinfo.pendingDeletion.Load()) assert.False(t, hostinfo.pendingDeletion.Load())
assert.False(t, hostinfo.out.Load()) assert.False(t, hostinfo.out.Load())
assert.False(t, hostinfo.in.Load()) assert.False(t, hostinfo.in.Load())

View File

@@ -278,15 +278,9 @@ func (c *Control) CloseTunnel(vpnIp netip.Addr, localOnly bool) bool {
} }
if !localOnly { if !localOnly {
c.f.send( buf := c.f.bufAlloc.Acquire()
header.CloseTunnel, c.f.send(header.CloseTunnel, 0, hostInfo.ConnectionState, hostInfo, []byte{}, buf)
0, c.f.bufAlloc.Release(buf)
hostInfo.ConnectionState,
hostInfo,
[]byte{},
make([]byte, 12, 12),
make([]byte, mtu),
)
} }
c.f.closeTunnel(hostInfo) c.f.closeTunnel(hostInfo)
@@ -296,11 +290,14 @@ func (c *Control) CloseTunnel(vpnIp netip.Addr, localOnly bool) bool {
// CloseAllTunnels is just like CloseTunnel except it goes through and shuts them all down, optionally you can avoid shutting down lighthouse tunnels // CloseAllTunnels is just like CloseTunnel except it goes through and shuts them all down, optionally you can avoid shutting down lighthouse tunnels
// the int returned is a count of tunnels closed // the int returned is a count of tunnels closed
func (c *Control) CloseAllTunnels(excludeLighthouses bool) (closed int) { func (c *Control) CloseAllTunnels(excludeLighthouses bool) (closed int) {
// One WireBuffer for the whole shutdown loop.
buf := c.f.bufAlloc.Acquire()
defer c.f.bufAlloc.Release(buf)
shutdown := func(h *HostInfo) { shutdown := func(h *HostInfo) {
if excludeLighthouses && c.f.lightHouse.IsAnyLighthouseAddr(h.vpnAddrs) { if excludeLighthouses && c.f.lightHouse.IsAnyLighthouseAddr(h.vpnAddrs) {
return return
} }
c.f.send(header.CloseTunnel, 0, h.ConnectionState, h, []byte{}, make([]byte, 12, 12), make([]byte, mtu)) c.f.send(header.CloseTunnel, 0, h.ConnectionState, h, []byte{}, buf)
c.f.closeTunnel(h) c.f.closeTunnel(h)
c.l.Debug("Sending close tunnel message", c.l.Debug("Sending close tunnel message",

68
e2e/bench_test.go Normal file
View File

@@ -0,0 +1,68 @@
//go:build e2e_testing
// +build e2e_testing
package e2e
import (
"testing"
"time"
"github.com/slackhq/nebula/cert"
"github.com/slackhq/nebula/cert_test"
"github.com/slackhq/nebula/e2e/router"
)
// BenchmarkHandshake measures end-to-end tunnel establishment time. The two
// nodes and the router are constructed once before the loop so the timed window
// is just the handshake itself: trigger packet -> handshake1 -> handshake2 ->
// cached packet replay -> arrival on the remote TUN. Between iterations we
// tear down both sides locally (no CloseTunnel notification on the wire) and
// re-inject the lighthouse address that closeTunnel cleared, so the next
// iteration runs through a fresh handshake against the same harness.
func BenchmarkHandshake(b *testing.B) {
ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version2, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
// Default try_interval is 100ms. The handshake manager schedules handshake1
// on its OutboundHandshakeTimer rather than firing immediately on trigger
// (the trigger channel only fast-paths static hosts), so a 100ms default
// drowns the actual handshake cost. Drop it to 1ms so the bench reflects
// the computation, not the wheel cadence.
bovr := m{"handshakes": m{"try_interval": "1ms"}}
myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(cert.Version2, ca, caKey, "me", "10.128.0.1/24", bovr)
theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(cert.Version2, ca, caKey, "them", "10.128.0.2/24", bovr)
myControl.InjectLightHouseAddr(theirVpnIpNet[0].Addr(), theirUdpAddr)
theirControl.InjectLightHouseAddr(myVpnIpNet[0].Addr(), myUdpAddr)
myControl.Start()
theirControl.Start()
defer myControl.Stop()
defer theirControl.Stop()
r := router.NewR(b, myControl, theirControl)
r.CancelFlowLogs()
r.EnableFanIn()
trigger := BuildTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi from me"))
b.ReportAllocs()
b.ResetTimer()
for n := 0; n < b.N; n++ {
myControl.InjectTunPacket(trigger)
// RouteForAllUntilTxTun returns the moment the cached packet arrives at
// the remote TUN, which is also when both sides are fully established.
_ = r.RouteForAllUntilTxTun(theirControl)
b.StopTimer()
// Local-only close removes hostmap state on both sides without putting a
// CloseTunnel packet on the wire that we'd then have to drain. The
// closeTunnel path also clears learned lighthouse state for the peer
// when the last hostinfo for that addr goes away, so we re-inject.
myControl.CloseTunnel(theirVpnIpNet[0].Addr(), true)
theirControl.CloseTunnel(myVpnIpNet[0].Addr(), true)
myControl.InjectLightHouseAddr(theirVpnIpNet[0].Addr(), theirUdpAddr)
theirControl.InjectLightHouseAddr(myVpnIpNet[0].Addr(), myUdpAddr)
b.StartTimer()
}
}

View File

@@ -165,7 +165,7 @@ func TestGoodHandshakeNoOverlap(t *testing.T) {
empty := []byte{} empty := []byte{}
t.Log("do something to cause a handshake") t.Log("do something to cause a handshake")
myControl.GetF().SendMessageToVpnAddr(header.Test, header.MessageNone, theirVpnIpNet[0].Addr(), empty, empty, empty) myControl.GetF().SendMessageToVpnAddr(header.Test, header.MessageNone, theirVpnIpNet[0].Addr(), empty, nebula.NewWireBuffer(9001, 0))
t.Log("Have them consume my stage 0 packet. They have a tunnel now") t.Log("Have them consume my stage 0 packet. They have a tunnel now")
theirControl.InjectUDPPacket(myControl.GetFromUDP(true)) theirControl.InjectUDPPacket(myControl.GetFromUDP(true))

View File

@@ -971,11 +971,11 @@ func (hm *HandshakeManager) continueHandshake(via ViaSender, hh *HandshakeHostIn
if f.l.Enabled(context.Background(), slog.LevelDebug) { if f.l.Enabled(context.Background(), slog.LevelDebug) {
hostinfo.logger(f.l).Debug("Sending stored packets", "count", len(hh.packetStore)) hostinfo.logger(f.l).Debug("Sending stored packets", "count", len(hh.packetStore))
} }
nb := make([]byte, 12, 12) buf := f.bufAlloc.Acquire()
out := make([]byte, mtu)
for _, cp := range hh.packetStore { for _, cp := range hh.packetStore {
cp.callback(cp.messageType, cp.messageSubType, hostinfo, cp.packet, nb, out) cp.callback(cp.messageType, cp.messageSubType, hostinfo, cp.packet, buf)
} }
f.bufAlloc.Release(buf)
f.cachedPacketMetrics.sent.Inc(int64(len(hh.packetStore))) f.cachedPacketMetrics.sent.Inc(int64(len(hh.packetStore)))
} }
@@ -1085,7 +1085,9 @@ func (hm *HandshakeManager) sendHandshakeResponse(via ViaSender, msg []byte, hos
// We received a valid handshake on this relay, so make sure the relay // We received a valid handshake on this relay, so make sure the relay
// state reflects that, in case it had been marked Disestablished. // state reflects that, in case it had been marked Disestablished.
via.relayHI.relayState.UpdateRelayForByIdxState(via.remoteIdx, Established) via.relayHI.relayState.UpdateRelayForByIdxState(via.remoteIdx, Established)
f.SendVia(via.relayHI, via.relay, msg, make([]byte, 12), make([]byte, mtu), false) buf := f.bufAlloc.Acquire()
f.SendVia(via.relayHI, via.relay, msg, buf)
f.bufAlloc.Release(buf)
f.l.Info("Handshake message sent", append(logFields, "relay", via.relayHI.vpnAddrs[0])...) f.l.Info("Handshake message sent", append(logFields, "relay", via.relayHI.vpnAddrs[0])...)
} }
} }
@@ -1102,7 +1104,9 @@ func (hm *HandshakeManager) handleCheckAndCompleteError(err error, existing, hos
switch err { switch err {
case ErrAlreadySeen: case ErrAlreadySeen:
if existing.SetRemoteIfPreferred(f.hostMap, via) { if existing.SetRemoteIfPreferred(f.hostMap, via) {
f.SendMessageToVpnAddr(header.Test, header.TestRequest, hostinfo.vpnAddrs[0], []byte(""), make([]byte, 12, 12), make([]byte, mtu)) buf := f.bufAlloc.Acquire()
f.SendMessageToVpnAddr(header.Test, header.TestRequest, hostinfo.vpnAddrs[0], []byte(""), buf)
f.bufAlloc.Release(buf)
} }
// Resend the original response. The peer is committed to that response's // Resend the original response. The peer is committed to that response's
// ephemeral keys; a freshly-built one would have different keys and break // ephemeral keys; a freshly-built one would have different keys and break
@@ -1125,7 +1129,9 @@ func (hm *HandshakeManager) handleCheckAndCompleteError(err error, existing, hos
"responderIndex", hostinfo.localIndexId, "responderIndex", hostinfo.localIndexId,
"handshake", hsFields, "handshake", hsFields,
) )
f.SendMessageToVpnAddr(header.Test, header.TestRequest, hostinfo.vpnAddrs[0], []byte(""), make([]byte, 12, 12), make([]byte, mtu)) buf := f.bufAlloc.Acquire()
f.SendMessageToVpnAddr(header.Test, header.TestRequest, hostinfo.vpnAddrs[0], []byte(""), buf)
f.bufAlloc.Release(buf)
case ErrLocalIndexCollision: case ErrLocalIndexCollision:
f.l.Error("Failed to add HostInfo due to localIndex collision", f.l.Error("Failed to add HostInfo due to localIndex collision",

View File

@@ -80,15 +80,15 @@ func testCountTimerWheelEntries(tw *LockingTimerWheel[netip.Addr]) (c int) {
type mockEncWriter struct { type mockEncWriter struct {
} }
func (mw *mockEncWriter) SendMessageToVpnAddr(_ header.MessageType, _ header.MessageSubType, _ netip.Addr, _, _, _ []byte) { func (mw *mockEncWriter) SendMessageToVpnAddr(_ header.MessageType, _ header.MessageSubType, _ netip.Addr, _ []byte, _ *WireBuffer) {
return return
} }
func (mw *mockEncWriter) SendVia(_ *HostInfo, _ *Relay, _, _, _ []byte, _ bool) { func (mw *mockEncWriter) SendVia(_ *HostInfo, _ *Relay, _ []byte, _ *WireBuffer) {
return return
} }
func (mw *mockEncWriter) SendMessageToHostInfo(_ header.MessageType, _ header.MessageSubType, _ *HostInfo, _, _, _ []byte) { func (mw *mockEncWriter) SendMessageToHostInfo(_ header.MessageType, _ header.MessageSubType, _ *HostInfo, _ []byte, _ *WireBuffer) {
return return
} }

View File

@@ -308,7 +308,7 @@ type cachedPacket struct {
packet []byte packet []byte
} }
type packetCallback func(t header.MessageType, st header.MessageSubType, h *HostInfo, p, nb, out []byte) type packetCallback func(t header.MessageType, st header.MessageSubType, h *HostInfo, p []byte, buf *WireBuffer)
type cachedPacketMetrics struct { type cachedPacketMetrics struct {
sent metrics.Counter sent metrics.Counter
@@ -691,6 +691,7 @@ func (i *HostInfo) TryPromoteBest(preferredRanges []netip.Prefix, ifce *Interfac
} }
} }
buf := ifce.bufAlloc.Acquire()
i.remotes.ForEach(preferredRanges, func(addr netip.AddrPort, preferred bool) { i.remotes.ForEach(preferredRanges, func(addr netip.AddrPort, preferred bool) {
if remote.IsValid() && (!addr.IsValid() || !preferred) { if remote.IsValid() && (!addr.IsValid() || !preferred) {
return return
@@ -698,8 +699,9 @@ func (i *HostInfo) TryPromoteBest(preferredRanges []netip.Prefix, ifce *Interfac
// Try to send a test packet to that host, this should // Try to send a test packet to that host, this should
// cause it to detect a roaming event and switch remotes // cause it to detect a roaming event and switch remotes
ifce.sendTo(header.Test, header.TestRequest, i.ConnectionState, i, addr, []byte(""), make([]byte, 12, 12), make([]byte, mtu)) ifce.sendTo(header.Test, header.TestRequest, i.ConnectionState, i, addr, []byte(""), buf)
}) })
ifce.bufAlloc.Release(buf)
} }
// Re query our lighthouses for new remotes occasionally // Re query our lighthouses for new remotes occasionally

197
inside.go
View File

@@ -8,12 +8,13 @@ import (
"github.com/slackhq/nebula/firewall" "github.com/slackhq/nebula/firewall"
"github.com/slackhq/nebula/header" "github.com/slackhq/nebula/header"
"github.com/slackhq/nebula/iputil" "github.com/slackhq/nebula/iputil"
"github.com/slackhq/nebula/noiseutil"
"github.com/slackhq/nebula/routing" "github.com/slackhq/nebula/routing"
) )
func (f *Interface) consumeInsidePacket(packet []byte, fwPacket *firewall.Packet, nb, out []byte, q int, localCache firewall.ConntrackCache) { func (f *Interface) consumeInsidePacket(buf *WireBuffer, q int, localCache firewall.ConntrackCache) {
err := newPacket(packet, false, fwPacket) packet := buf.IPPacket()
err := newPacket(packet, false, buf.FwPacket)
if err != nil { if err != nil {
if f.l.Enabled(context.Background(), slog.LevelDebug) { if f.l.Enabled(context.Background(), slog.LevelDebug) {
f.l.Debug("Error while validating outbound packet", f.l.Debug("Error while validating outbound packet",
@@ -26,12 +27,12 @@ func (f *Interface) consumeInsidePacket(packet []byte, fwPacket *firewall.Packet
// Ignore local broadcast packets // Ignore local broadcast packets
if f.dropLocalBroadcast { if f.dropLocalBroadcast {
if f.myBroadcastAddrsTable.Contains(fwPacket.RemoteAddr) { if f.myBroadcastAddrsTable.Contains(buf.FwPacket.RemoteAddr) {
return return
} }
} }
if f.myVpnAddrsTable.Contains(fwPacket.RemoteAddr) { if f.myVpnAddrsTable.Contains(buf.FwPacket.RemoteAddr) {
// Immediately forward packets from self to self. // Immediately forward packets from self to self.
// This should only happen on Darwin-based and FreeBSD hosts, which // This should only happen on Darwin-based and FreeBSD hosts, which
// routes packets from the Nebula addr to the Nebula addr through the Nebula // routes packets from the Nebula addr to the Nebula addr through the Nebula
@@ -48,20 +49,20 @@ func (f *Interface) consumeInsidePacket(packet []byte, fwPacket *firewall.Packet
} }
// Ignore multicast packets // Ignore multicast packets
if f.dropMulticast && fwPacket.RemoteAddr.IsMulticast() { if f.dropMulticast && buf.FwPacket.RemoteAddr.IsMulticast() {
return return
} }
hostinfo, ready := f.getOrHandshakeConsiderRouting(fwPacket, func(hh *HandshakeHostInfo) { hostinfo, ready := f.getOrHandshakeConsiderRouting(buf.FwPacket, func(hh *HandshakeHostInfo) {
hh.cachePacket(f.l, header.Message, 0, packet, f.sendMessageNow, f.cachedPacketMetrics) hh.cachePacket(f.l, header.Message, 0, packet, f.sendMessageNow, f.cachedPacketMetrics)
}) })
if hostinfo == nil { if hostinfo == nil {
f.rejectInside(packet, out, q) f.rejectInside(packet, buf.Out, q)
if f.l.Enabled(context.Background(), slog.LevelDebug) { if f.l.Enabled(context.Background(), slog.LevelDebug) {
f.l.Debug("dropping outbound packet, vpnAddr not in our vpn networks or in unsafe networks", f.l.Debug("dropping outbound packet, vpnAddr not in our vpn networks or in unsafe networks",
"vpnAddr", fwPacket.RemoteAddr, "vpnAddr", buf.FwPacket.RemoteAddr,
"fwPacket", fwPacket, "fwPacket", buf.FwPacket,
) )
} }
return return
@@ -71,15 +72,15 @@ func (f *Interface) consumeInsidePacket(packet []byte, fwPacket *firewall.Packet
return return
} }
dropReason := f.firewall.Drop(*fwPacket, false, hostinfo, f.pki.GetCAPool(), localCache) dropReason := f.firewall.Drop(*buf.FwPacket, false, hostinfo, f.pki.GetCAPool(), localCache)
if dropReason == nil { if dropReason == nil {
f.sendNoMetrics(header.Message, 0, hostinfo.ConnectionState, hostinfo, netip.AddrPort{}, packet, nb, out, q) f.sendNoMetrics(header.Message, 0, hostinfo.ConnectionState, hostinfo, netip.AddrPort{}, packet, buf, q)
} else { } else {
f.rejectInside(packet, out, q) f.rejectInside(packet, buf.Out, q)
if f.l.Enabled(context.Background(), slog.LevelDebug) { if f.l.Enabled(context.Background(), slog.LevelDebug) {
hostinfo.logger(f.l).Debug("dropping outbound packet", hostinfo.logger(f.l).Debug("dropping outbound packet",
"fwPacket", fwPacket, "fwPacket", buf.FwPacket,
"reason", dropReason, "reason", dropReason,
) )
} }
@@ -102,27 +103,27 @@ func (f *Interface) rejectInside(packet []byte, out []byte, q int) {
} }
} }
func (f *Interface) rejectOutside(packet []byte, ci *ConnectionState, hostinfo *HostInfo, nb, out []byte, q int) { func (f *Interface) rejectOutside(packet []byte, ci *ConnectionState, hostinfo *HostInfo, scratch []byte, buf *WireBuffer, q int) {
if !f.firewall.OutSendReject { if !f.firewall.OutSendReject {
return return
} }
out = iputil.CreateRejectPacket(packet, out) rejectIP := iputil.CreateRejectPacket(packet, scratch)
if len(out) == 0 { if len(rejectIP) == 0 {
return return
} }
if len(out) > iputil.MaxRejectPacketSize { if len(rejectIP) > iputil.MaxRejectPacketSize {
if f.l.Enabled(context.Background(), slog.LevelInfo) { if f.l.Enabled(context.Background(), slog.LevelInfo) {
f.l.Info("rejectOutside: packet too big, not sending", f.l.Info("rejectOutside: packet too big, not sending",
"packet", packet, "packet", packet,
"outPacket", out, "outPacket", rejectIP,
) )
} }
return return
} }
f.sendNoMetrics(header.Message, 0, ci, hostinfo, netip.AddrPort{}, out, nb, packet, q) f.sendNoMetrics(header.Message, 0, ci, hostinfo, netip.AddrPort{}, rejectIP, buf, q)
} }
// Handshake will attempt to initiate a tunnel with the provided vpn address. This is a no-op if the tunnel is already established or being established // Handshake will attempt to initiate a tunnel with the provided vpn address. This is a no-op if the tunnel is already established or being established
@@ -215,7 +216,7 @@ func (f *Interface) getOrHandshakeConsiderRouting(fwPacket *firewall.Packet, cac
} }
func (f *Interface) sendMessageNow(t header.MessageType, st header.MessageSubType, hostinfo *HostInfo, p, nb, out []byte) { func (f *Interface) sendMessageNow(t header.MessageType, st header.MessageSubType, hostinfo *HostInfo, p []byte, buf *WireBuffer) {
fp := &firewall.Packet{} fp := &firewall.Packet{}
err := newPacket(p, false, fp) err := newPacket(p, false, fp)
if err != nil { if err != nil {
@@ -235,12 +236,12 @@ func (f *Interface) sendMessageNow(t header.MessageType, st header.MessageSubTyp
return return
} }
f.sendNoMetrics(header.Message, st, hostinfo.ConnectionState, hostinfo, netip.AddrPort{}, p, nb, out, 0) f.sendNoMetrics(header.Message, st, hostinfo.ConnectionState, hostinfo, netip.AddrPort{}, p, buf, 0)
} }
// SendMessageToVpnAddr handles real addr:port lookup and sends to the current best known address for vpnAddr. // SendMessageToVpnAddr handles real addr:port lookup and sends to the current best known address for vpnAddr.
// This function ignores myVpnNetworksTable, and will always attempt to treat the address as a vpnAddr // This function ignores myVpnNetworksTable, and will always attempt to treat the address as a vpnAddr
func (f *Interface) SendMessageToVpnAddr(t header.MessageType, st header.MessageSubType, vpnAddr netip.Addr, p, nb, out []byte) { func (f *Interface) SendMessageToVpnAddr(t header.MessageType, st header.MessageSubType, vpnAddr netip.Addr, p []byte, buf *WireBuffer) {
hostInfo, ready := f.handshakeManager.GetOrHandshake(vpnAddr, func(hh *HandshakeHostInfo) { hostInfo, ready := f.handshakeManager.GetOrHandshake(vpnAddr, func(hh *HandshakeHostInfo) {
hh.cachePacket(f.l, t, st, p, f.SendMessageToHostInfo, f.cachedPacketMetrics) hh.cachePacket(f.l, t, st, p, f.SendMessageToHostInfo, f.cachedPacketMetrics)
}) })
@@ -258,113 +259,73 @@ func (f *Interface) SendMessageToVpnAddr(t header.MessageType, st header.Message
return return
} }
f.SendMessageToHostInfo(t, st, hostInfo, p, nb, out) f.SendMessageToHostInfo(t, st, hostInfo, p, buf)
} }
func (f *Interface) SendMessageToHostInfo(t header.MessageType, st header.MessageSubType, hi *HostInfo, p, nb, out []byte) { func (f *Interface) SendMessageToHostInfo(t header.MessageType, st header.MessageSubType, hi *HostInfo, p []byte, buf *WireBuffer) {
f.send(t, st, hi.ConnectionState, hi, p, nb, out) f.send(t, st, hi.ConnectionState, hi, p, buf)
} }
func (f *Interface) send(t header.MessageType, st header.MessageSubType, ci *ConnectionState, hostinfo *HostInfo, p, nb, out []byte) { func (f *Interface) send(t header.MessageType, st header.MessageSubType, ci *ConnectionState, hostinfo *HostInfo, p []byte, buf *WireBuffer) {
f.messageMetrics.Tx(t, st, 1) f.messageMetrics.Tx(t, st, 1)
f.sendNoMetrics(t, st, ci, hostinfo, netip.AddrPort{}, p, nb, out, 0) f.sendNoMetrics(t, st, ci, hostinfo, netip.AddrPort{}, p, buf, 0)
} }
func (f *Interface) sendTo(t header.MessageType, st header.MessageSubType, ci *ConnectionState, hostinfo *HostInfo, remote netip.AddrPort, p, nb, out []byte) { func (f *Interface) sendTo(t header.MessageType, st header.MessageSubType, ci *ConnectionState, hostinfo *HostInfo, remote netip.AddrPort, p []byte, buf *WireBuffer) {
f.messageMetrics.Tx(t, st, 1) f.messageMetrics.Tx(t, st, 1)
f.sendNoMetrics(t, st, ci, hostinfo, remote, p, nb, out, 0) f.sendNoMetrics(t, st, ci, hostinfo, remote, p, buf, 0)
} }
// SendVia sends a payload through a Relay tunnel. No authentication or encryption is done // SendVia sends a payload through a Relay tunnel. No authentication or encryption is done
// to the payload for the ultimate target host, making this a useful method for sending // to the payload for the ultimate target host, making this a useful method for sending
// handshake messages to peers through relay tunnels. // handshake messages to peers through relay tunnels.
// via is the HostInfo through which the message is relayed. //
// ad is the plaintext data to authenticate, but not encrypt // via is the HostInfo through which the message is relayed. ad is staged into
// nb is a buffer used to store the nonce value, re-used for performance reasons. // the inner-payload slot of buf and then AAD-only sealed under via's key by
// out is a buffer used to store the result of the Encrypt operation // SealRelayInPlace. The sendNoMetrics relay-forward path skips this entry
// q indicates which writer to use to send the packet. // point and calls sendViaInPlace directly because its inner ciphertext is
func (f *Interface) SendVia(via *HostInfo, // already in place from the encrypt step.
relay *Relay, func (f *Interface) SendVia(via *HostInfo, relay *Relay, ad []byte, buf *WireBuffer) {
ad, if header.Len+len(ad)+via.ConnectionState.eKey.Overhead() > cap(buf.Out) {
nb,
out []byte,
nocopy bool,
) {
if noiseutil.EncryptLockNeeded {
// NOTE: for goboring AESGCMTLS we need to lock because of the nonce check
via.ConnectionState.writeLock.Lock()
}
c := via.ConnectionState.messageCounter.Add(1)
out = header.Encode(out, header.Version, header.Message, header.MessageRelay, relay.RemoteIndex, c)
f.connectionManager.Out(via)
// Authenticate the header and payload, but do not encrypt for this message type.
// The payload consists of the inner, unencrypted Nebula header, as well as the end-to-end encrypted payload.
if len(out)+len(ad)+via.ConnectionState.eKey.Overhead() > cap(out) {
if noiseutil.EncryptLockNeeded {
via.ConnectionState.writeLock.Unlock()
}
via.logger(f.l).Error("SendVia out buffer not large enough for relay", via.logger(f.l).Error("SendVia out buffer not large enough for relay",
"outCap", cap(out), "outCap", cap(buf.Out),
"payloadLen", len(ad), "payloadLen", len(ad),
"headerLen", len(out), "headerLen", header.Len,
"cipherOverhead", via.ConnectionState.eKey.Overhead(), "cipherOverhead", via.ConnectionState.eKey.Overhead(),
) )
return return
} }
buf.StageRelayInner(ad)
f.sendViaInPlace(via, relay, len(ad), buf)
}
// The header bytes are written to the 'out' slice; Grow the slice to hold the header and associated data payload. // sendViaInPlace stamps the outer relay header, AAD-seals over the [outer
offset := len(out) // header | inner-already-staged] region, and writes the result to via.remote.
out = out[:offset+len(ad)] // Called from SendVia (after staging ad) and from sendNoMetrics' relay-forward
// path (where the inner ciphertext is already in place from SealForRelay).
// In one call path, the associated data _is_ already stored in out. In other call paths, the associated data must func (f *Interface) sendViaInPlace(via *HostInfo, relay *Relay, innerLen int, buf *WireBuffer) {
// be copied into 'out'. f.connectionManager.Out(via)
if !nocopy { out, err := buf.SealRelayInPlace(via.ConnectionState, relay.RemoteIndex, innerLen)
copy(out[offset:], ad)
}
var err error
out, err = via.ConnectionState.eKey.EncryptDanger(out, out, nil, c, nb)
if noiseutil.EncryptLockNeeded {
via.ConnectionState.writeLock.Unlock()
}
if err != nil { if err != nil {
via.logger(f.l).Info("Failed to EncryptDanger in sendVia", "error", err) via.logger(f.l).Info("Failed to EncryptDanger in sendVia", "error", err)
return return
} }
err = f.writers[0].WriteTo(out, via.remote) if err := f.writers[0].WriteTo(out, via.remote); err != nil {
if err != nil {
via.logger(f.l).Info("Failed to WriteTo in sendVia", "error", err) via.logger(f.l).Info("Failed to WriteTo in sendVia", "error", err)
} }
f.connectionManager.RelayUsed(relay.LocalIndex) f.connectionManager.RelayUsed(relay.LocalIndex)
} }
func (f *Interface) sendNoMetrics(t header.MessageType, st header.MessageSubType, ci *ConnectionState, hostinfo *HostInfo, remote netip.AddrPort, p, nb, out []byte, q int) { // sendNoMetrics encrypts and writes one outbound nebula packet (data, control,
// lighthouse, etc) using buf as the per-call wire scratch. When the hostinfo
// has no direct remote we encrypt into the relay-reserved slot via
// SealForRelay so sendViaInPlace can wrap it without an extra copy.
func (f *Interface) sendNoMetrics(t header.MessageType, st header.MessageSubType, ci *ConnectionState, hostinfo *HostInfo, remote netip.AddrPort, p []byte, buf *WireBuffer, q int) {
if ci.eKey == nil { if ci.eKey == nil {
return return
} }
useRelay := !remote.IsValid() && !hostinfo.remote.IsValid() useRelay := !remote.IsValid() && !hostinfo.remote.IsValid()
fullOut := out
if useRelay {
if len(out) < header.Len {
// out always has a capacity of mtu, but not always a length greater than the header.Len.
// Grow it to make sure the next operation works.
out = out[:header.Len]
}
// Save a header's worth of data at the front of the 'out' buffer.
out = out[header.Len:]
}
if noiseutil.EncryptLockNeeded {
// NOTE: for goboring AESGCMTLS we need to lock because of the nonce check
ci.writeLock.Lock()
}
c := ci.messageCounter.Add(1)
//l.WithField("trace", string(debug.Stack())).Error("out Header ", &Header{Version, t, st, 0, hostinfo.remoteIndexId, c}, p)
out = header.Encode(out, header.Version, t, st, hostinfo.remoteIndexId, c)
f.connectionManager.Out(hostinfo) f.connectionManager.Out(hostinfo)
// Query our LH if we haven't since the last time we've been rebound, this will cause the remote to punch against // Query our LH if we haven't since the last time we've been rebound, this will cause the remote to punch against
@@ -381,50 +342,42 @@ func (f *Interface) sendNoMetrics(t header.MessageType, st header.MessageSubType
} }
} }
var out []byte
var err error var err error
out, err = ci.eKey.EncryptDanger(out, out, p, c, nb) if useRelay {
if noiseutil.EncryptLockNeeded { out, err = buf.SealForRelay(ci, t, st, hostinfo.remoteIndexId, p)
ci.writeLock.Unlock() } else {
out, err = buf.Seal(ci, t, st, hostinfo.remoteIndexId, p)
} }
if err != nil { if err != nil {
hostinfo.logger(f.l).Error("Failed to encrypt outgoing packet", hostinfo.logger(f.l).Error("Failed to encrypt outgoing packet",
"error", err, "error", err,
"udpAddr", remote, "udpAddr", remote,
"counter", c,
"attemptedCounter", c,
) )
return return
} }
if remote.IsValid() { switch {
err = f.writers[q].WriteTo(out, remote) case remote.IsValid():
if err != nil { if err := f.writers[q].WriteTo(out, remote); err != nil {
hostinfo.logger(f.l).Error("Failed to write outgoing packet", hostinfo.logger(f.l).Error("Failed to write outgoing packet", "error", err, "udpAddr", remote)
"error", err,
"udpAddr", remote,
)
} }
} else if hostinfo.remote.IsValid() { case hostinfo.remote.IsValid():
err = f.writers[q].WriteTo(out, hostinfo.remote) if err := f.writers[q].WriteTo(out, hostinfo.remote); err != nil {
if err != nil { hostinfo.logger(f.l).Error("Failed to write outgoing packet", "error", err, "udpAddr", hostinfo.remote)
hostinfo.logger(f.l).Error("Failed to write outgoing packet",
"error", err,
"udpAddr", remote,
)
} }
} else { default:
// Try to send via a relay // SealForRelay placed the inner ciphertext at buf.Out[header.Len:],
// so sendViaInPlace can wrap it with the outer relay header without
// an extra copy.
for _, relayIP := range hostinfo.relayState.CopyRelayIps() { for _, relayIP := range hostinfo.relayState.CopyRelayIps() {
relayHostInfo, relay, err := f.hostMap.QueryVpnAddrsRelayFor(hostinfo.vpnAddrs, relayIP) relayHostInfo, relay, err := f.hostMap.QueryVpnAddrsRelayFor(hostinfo.vpnAddrs, relayIP)
if err != nil { if err != nil {
hostinfo.relayState.DeleteRelay(relayIP) hostinfo.relayState.DeleteRelay(relayIP)
hostinfo.logger(f.l).Info("sendNoMetrics failed to find HostInfo", hostinfo.logger(f.l).Info("sendNoMetrics failed to find HostInfo", "relay", relayIP, "error", err)
"relay", relayIP,
"error", err,
)
continue continue
} }
f.SendVia(relayHostInfo, relay, out, nb, fullOut[:header.Len+len(out)], true) f.sendViaInPlace(relayHostInfo, relay, len(out), buf)
break break
} }
} }

View File

@@ -101,19 +101,19 @@ type Interface struct {
messageMetrics *MessageMetrics messageMetrics *MessageMetrics
cachedPacketMetrics *cachedPacketMetrics cachedPacketMetrics *cachedPacketMetrics
// bufAlloc hands out reusable WireBuffers sized for this interface's
// inside Device. All buf consumers (hot-path data-plane goroutines,
// long-lived workers, and cold callers) acquire from here so sizing
// is centralized and consistent. Long-lived owners just don't release.
bufAlloc WireBufferAllocator
l *slog.Logger l *slog.Logger
} }
type EncWriter interface { type EncWriter interface {
SendVia(via *HostInfo, SendVia(via *HostInfo, relay *Relay, ad []byte, buf *WireBuffer)
relay *Relay, SendMessageToVpnAddr(t header.MessageType, st header.MessageSubType, vpnAddr netip.Addr, p []byte, buf *WireBuffer)
ad, SendMessageToHostInfo(t header.MessageType, st header.MessageSubType, hostinfo *HostInfo, p []byte, buf *WireBuffer)
nb,
out []byte,
nocopy bool,
)
SendMessageToVpnAddr(t header.MessageType, st header.MessageSubType, vpnAddr netip.Addr, p, nb, out []byte)
SendMessageToHostInfo(t header.MessageType, st header.MessageSubType, hostinfo *HostInfo, p, nb, out []byte)
Handshake(vpnAddr netip.Addr) Handshake(vpnAddr netip.Addr)
GetHostInfo(vpnAddr netip.Addr) *HostInfo GetHostInfo(vpnAddr netip.Addr) *HostInfo
GetCertState() *CertState GetCertState() *CertState
@@ -204,6 +204,8 @@ func NewInterface(ctx context.Context, c *InterfaceConfig) (*Interface, error) {
dropped: metrics.GetOrRegisterCounter("hostinfo.cached_packets.dropped", nil), dropped: metrics.GetOrRegisterCounter("hostinfo.cached_packets.dropped", nil),
}, },
bufAlloc: NewWireBufferPool(mtu, c.Inside.TunPrefixLen()),
l: c.l, l: c.l,
} }
@@ -311,13 +313,11 @@ func (f *Interface) listenOut(i int) {
ctCache := firewall.NewConntrackCacheTicker(f.ctx, f.l, f.conntrackCacheTimeout) ctCache := firewall.NewConntrackCacheTicker(f.ctx, f.l, f.conntrackCacheTimeout)
lhh := f.lightHouse.NewRequestHandler() lhh := f.lightHouse.NewRequestHandler()
plaintext := make([]byte, udp.MTU) // Long-lived per-receive-goroutine buf; never released back to the pool.
h := &header.H{} buf := f.bufAlloc.Acquire()
fwPacket := &firewall.Packet{}
nb := make([]byte, 12, 12)
err := li.ListenOut(func(fromUdpAddr netip.AddrPort, payload []byte) { err := li.ListenOut(func(fromUdpAddr netip.AddrPort, payload []byte) {
f.readOutsidePackets(ViaSender{UdpAddr: fromUdpAddr}, plaintext[:0], payload, h, fwPacket, lhh, nb, i, ctCache.Get()) f.readOutsidePackets(ViaSender{UdpAddr: fromUdpAddr}, buf, payload, lhh, i, ctCache.Get())
}) })
if err != nil && !f.closed.Load() { if err != nil && !f.closed.Load() {
@@ -329,15 +329,12 @@ func (f *Interface) listenOut(i int) {
} }
func (f *Interface) listenIn(reader io.ReadWriteCloser, i int) { func (f *Interface) listenIn(reader io.ReadWriteCloser, i int) {
packet := make([]byte, mtu) // Long-lived per-tun-reader buf; never released back to the pool.
out := make([]byte, mtu) buf := f.bufAlloc.Acquire()
fwPacket := &firewall.Packet{}
nb := make([]byte, 12, 12)
conntrackCache := firewall.NewConntrackCacheTicker(f.ctx, f.l, f.conntrackCacheTimeout) conntrackCache := firewall.NewConntrackCacheTicker(f.ctx, f.l, f.conntrackCacheTimeout)
for { for {
n, err := reader.Read(packet) _, err := buf.ReadIPFromTUN(reader)
if err != nil { if err != nil {
if !f.closed.Load() { if !f.closed.Load() {
f.l.Error("Error while reading outbound packet, closing", "error", err, "reader", i) f.l.Error("Error while reading outbound packet, closing", "error", err, "reader", i)
@@ -346,7 +343,7 @@ func (f *Interface) listenIn(reader io.ReadWriteCloser, i int) {
break break
} }
f.consumeInsidePacket(packet[:n], fwPacket, nb, out, i, conntrackCache.Get()) f.consumeInsidePacket(buf, i, conntrackCache.Get())
} }
f.l.Debug("overlay reader is done", "reader", i) f.l.Debug("overlay reader is done", "reader", i)

View File

@@ -63,6 +63,10 @@ type LightHouse struct {
interval atomic.Int64 interval atomic.Int64
updateCancel context.CancelFunc updateCancel context.CancelFunc
ifce EncWriter ifce EncWriter
// bufAlloc lets the lighthouse query/update workers, request handlers
// and punchback goroutines acquire correctly sized WireBuffers from
// the same pool as the data plane. Set by main.go alongside ifce.
bufAlloc WireBufferAllocator
nebulaPort uint32 // 32 bits because protobuf does not have a uint16 nebulaPort uint32 // 32 bits because protobuf does not have a uint16
advertiseAddrs atomic.Pointer[[]netip.AddrPort] advertiseAddrs atomic.Pointer[[]netip.AddrPort]
@@ -109,6 +113,10 @@ func NewLightHouseFromConfig(ctx context.Context, l *slog.Logger, c *config.C, c
punchy: p, punchy: p,
updateTrigger: make(chan struct{}, 1), updateTrigger: make(chan struct{}, 1),
queryChan: make(chan netip.Addr, c.GetUint32("handshakes.query_buffer", 64)), queryChan: make(chan netip.Addr, c.GetUint32("handshakes.query_buffer", 64)),
// Default to a no-prefix pool so the query/update workers and
// request handlers have a working WireBufferAllocator before
// main.go wires up the real one from the Interface.
bufAlloc: NewWireBufferPool(mtu, 0),
l: l, l: l,
} }
lighthouses := make([]netip.Addr, 0) lighthouses := make([]netip.Addr, 0)
@@ -758,21 +766,22 @@ func (lh *LightHouse) startQueryWorker() {
} }
go func() { go func() {
nb := make([]byte, 12, 12) // Long-lived per-worker WireBuffer; reused for every lighthouse query
out := make([]byte, mtu) // this worker issues for the life of the goroutine.
buf := lh.bufAlloc.Acquire()
for { for {
select { select {
case <-lh.ctx.Done(): case <-lh.ctx.Done():
return return
case addr := <-lh.queryChan: case addr := <-lh.queryChan:
lh.innerQueryServer(addr, nb, out) lh.innerQueryServer(addr, buf)
} }
} }
}() }()
} }
func (lh *LightHouse) innerQueryServer(addr netip.Addr, nb, out []byte) { func (lh *LightHouse) innerQueryServer(addr netip.Addr, buf *WireBuffer) {
if lh.IsLighthouseAddr(addr) { if lh.IsLighthouseAddr(addr) {
return return
} }
@@ -821,7 +830,7 @@ func (lh *LightHouse) innerQueryServer(addr netip.Addr, nb, out []byte) {
} }
} }
lh.ifce.SendMessageToVpnAddr(header.LightHouse, 0, lhVpnAddr, v1Query, nb, out) lh.ifce.SendMessageToVpnAddr(header.LightHouse, 0, lhVpnAddr, v1Query, buf)
queried++ queried++
} else if v == cert.Version2 { } else if v == cert.Version2 {
@@ -840,7 +849,7 @@ func (lh *LightHouse) innerQueryServer(addr netip.Addr, nb, out []byte) {
} }
} }
lh.ifce.SendMessageToVpnAddr(header.LightHouse, 0, lhVpnAddr, v2Query, nb, out) lh.ifce.SendMessageToVpnAddr(header.LightHouse, 0, lhVpnAddr, v2Query, buf)
queried++ queried++
} else { } else {
@@ -869,8 +878,12 @@ func (lh *LightHouse) StartUpdateWorker() {
go func() { go func() {
defer clockSource.Stop() defer clockSource.Stop()
// Long-lived per-worker WireBuffer; reused across every periodic
// update for the life of this goroutine.
buf := lh.bufAlloc.Acquire()
for { for {
lh.SendUpdate() lh.sendUpdate(buf)
select { select {
case <-updateCtx.Done(): case <-updateCtx.Done():
@@ -884,6 +897,15 @@ func (lh *LightHouse) StartUpdateWorker() {
}() }()
} }
// SendUpdate is the public entry point that triggers a one-shot lighthouse
// update outside the worker loop (e.g. tests or reload paths). It allocates
// its own WireBuffer since callers don't already own one.
func (lh *LightHouse) SendUpdate() {
buf := lh.bufAlloc.Acquire()
defer lh.bufAlloc.Release(buf)
lh.sendUpdate(buf)
}
// TriggerUpdate requests an immediate lighthouse update. This is a non-blocking // TriggerUpdate requests an immediate lighthouse update. This is a non-blocking
// operation intended to be called after a handshake completes with a lighthouse, // operation intended to be called after a handshake completes with a lighthouse,
// so the lighthouse has our current addresses without waiting for the next // so the lighthouse has our current addresses without waiting for the next
@@ -895,7 +917,7 @@ func (lh *LightHouse) TriggerUpdate() {
} }
} }
func (lh *LightHouse) SendUpdate() { func (lh *LightHouse) sendUpdate(buf *WireBuffer) {
var v4 []*V4AddrPort var v4 []*V4AddrPort
var v6 []*V6AddrPort var v6 []*V6AddrPort
@@ -921,9 +943,6 @@ func (lh *LightHouse) SendUpdate() {
} }
} }
nb := make([]byte, 12, 12)
out := make([]byte, mtu)
var v1Update, v2Update []byte var v1Update, v2Update []byte
var err error var err error
updated := 0 updated := 0
@@ -974,7 +993,7 @@ func (lh *LightHouse) SendUpdate() {
} }
} }
lh.ifce.SendMessageToVpnAddr(header.LightHouse, 0, lhVpnAddr, v1Update, nb, out) lh.ifce.SendMessageToVpnAddr(header.LightHouse, 0, lhVpnAddr, v1Update, buf)
updated++ updated++
} else if v == cert.Version2 { } else if v == cert.Version2 {
@@ -1003,7 +1022,7 @@ func (lh *LightHouse) SendUpdate() {
} }
} }
lh.ifce.SendMessageToVpnAddr(header.LightHouse, 0, lhVpnAddr, v2Update, nb, out) lh.ifce.SendMessageToVpnAddr(header.LightHouse, 0, lhVpnAddr, v2Update, buf)
updated++ updated++
} else { } else {
@@ -1020,8 +1039,10 @@ func (lh *LightHouse) SendUpdate() {
type LightHouseHandler struct { type LightHouseHandler struct {
lh *LightHouse lh *LightHouse
nb []byte // buf is the long-lived per-handler wire scratch. NewRequestHandler is
out []byte // called once per data-plane receive goroutine, so buf is owned by that
// goroutine and reused for every lighthouse send the handler issues.
buf *WireBuffer
pb []byte pb []byte
meta *NebulaMeta meta *NebulaMeta
l *slog.Logger l *slog.Logger
@@ -1030,8 +1051,7 @@ type LightHouseHandler struct {
func (lh *LightHouse) NewRequestHandler() *LightHouseHandler { func (lh *LightHouse) NewRequestHandler() *LightHouseHandler {
lhh := &LightHouseHandler{ lhh := &LightHouseHandler{
lh: lh, lh: lh,
nb: make([]byte, 12, 12), buf: lh.bufAlloc.Acquire(),
out: make([]byte, mtu),
l: lh.l, l: lh.l,
pb: make([]byte, mtu), pb: make([]byte, mtu),
@@ -1168,7 +1188,7 @@ func (lhh *LightHouseHandler) handleHostQuery(n *NebulaMeta, fromVpnAddrs []neti
} }
lhh.lh.metricTx(NebulaMeta_HostQueryReply, 1) lhh.lh.metricTx(NebulaMeta_HostQueryReply, 1)
w.SendMessageToVpnAddr(header.LightHouse, 0, fromVpnAddrs[0], lhh.pb[:ln], lhh.nb, lhh.out[:0]) w.SendMessageToVpnAddr(header.LightHouse, 0, fromVpnAddrs[0], lhh.pb[:ln], lhh.buf)
lhh.sendHostPunchNotification(n, fromVpnAddrs, queryVpnAddr, w) lhh.sendHostPunchNotification(n, fromVpnAddrs, queryVpnAddr, w)
} }
@@ -1228,7 +1248,7 @@ func (lhh *LightHouseHandler) sendHostPunchNotification(n *NebulaMeta, fromVpnAd
} }
lhh.lh.metricTx(NebulaMeta_HostPunchNotification, 1) lhh.lh.metricTx(NebulaMeta_HostPunchNotification, 1)
w.SendMessageToVpnAddr(header.LightHouse, 0, punchNotifDest, lhh.pb[:ln], lhh.nb, lhh.out[:0]) w.SendMessageToVpnAddr(header.LightHouse, 0, punchNotifDest, lhh.pb[:ln], lhh.buf)
} }
func (lhh *LightHouseHandler) coalesceAnswers(v cert.Version, c *cache, n *NebulaMeta) { func (lhh *LightHouseHandler) coalesceAnswers(v cert.Version, c *cache, n *NebulaMeta) {
@@ -1385,7 +1405,7 @@ func (lhh *LightHouseHandler) handleHostUpdateNotification(n *NebulaMeta, fromVp
} }
lhh.lh.metricTx(NebulaMeta_HostUpdateNotificationAck, 1) lhh.lh.metricTx(NebulaMeta_HostUpdateNotificationAck, 1)
w.SendMessageToVpnAddr(header.LightHouse, 0, fromVpnAddrs[0], lhh.pb[:ln], lhh.nb, lhh.out[:0]) w.SendMessageToVpnAddr(header.LightHouse, 0, fromVpnAddrs[0], lhh.pb[:ln], lhh.buf)
} }
func (lhh *LightHouseHandler) handleHostPunchNotification(n *NebulaMeta, fromVpnAddrs []netip.Addr, w EncWriter) { func (lhh *LightHouseHandler) handleHostPunchNotification(n *NebulaMeta, fromVpnAddrs []netip.Addr, w EncWriter) {
@@ -1452,10 +1472,13 @@ func (lhh *LightHouseHandler) handleHostPunchNotification(n *NebulaMeta, fromVpn
"vpnAddr", detailsVpnAddr, "vpnAddr", detailsVpnAddr,
) )
} }
//NOTE: we have to allocate a new output buffer here since we are spawning a new goroutine // We acquire and release a fresh buf within this goroutine so it
// for each punchBack packet. We should move this into a timerwheel or a single goroutine // returns to the pool once the punchback send completes. We
// should move this into a timerwheel or a single goroutine
// managed by a channel. // managed by a channel.
w.SendMessageToVpnAddr(header.Test, header.TestRequest, detailsVpnAddr, []byte(""), make([]byte, 12, 12), make([]byte, mtu)) pbuf := lhh.lh.bufAlloc.Acquire()
defer lhh.lh.bufAlloc.Release(pbuf)
w.SendMessageToVpnAddr(header.Test, header.TestRequest, detailsVpnAddr, []byte(""), pbuf)
}() }()
} }
} }

View File

@@ -372,12 +372,12 @@ type testEncWriter struct {
protocolVersion cert.Version protocolVersion cert.Version
} }
func (tw *testEncWriter) SendVia(via *HostInfo, relay *Relay, ad, nb, out []byte, nocopy bool) { func (tw *testEncWriter) SendVia(via *HostInfo, relay *Relay, ad []byte, buf *WireBuffer) {
} }
func (tw *testEncWriter) Handshake(vpnIp netip.Addr) { func (tw *testEncWriter) Handshake(vpnIp netip.Addr) {
} }
func (tw *testEncWriter) SendMessageToHostInfo(t header.MessageType, st header.MessageSubType, hostinfo *HostInfo, p, _, _ []byte) { func (tw *testEncWriter) SendMessageToHostInfo(t header.MessageType, st header.MessageSubType, hostinfo *HostInfo, p []byte, _ *WireBuffer) {
msg := &NebulaMeta{} msg := &NebulaMeta{}
err := msg.Unmarshal(p) err := msg.Unmarshal(p)
if tw.metaFilter == nil || msg.Type == *tw.metaFilter { if tw.metaFilter == nil || msg.Type == *tw.metaFilter {
@@ -394,7 +394,7 @@ func (tw *testEncWriter) SendMessageToHostInfo(t header.MessageType, st header.M
} }
} }
func (tw *testEncWriter) SendMessageToVpnAddr(t header.MessageType, st header.MessageSubType, vpnIp netip.Addr, p, _, _ []byte) { func (tw *testEncWriter) SendMessageToVpnAddr(t header.MessageType, st header.MessageSubType, vpnIp netip.Addr, p []byte, _ *WireBuffer) {
msg := &NebulaMeta{} msg := &NebulaMeta{}
err := msg.Unmarshal(p) err := msg.Unmarshal(p)
if tw.metaFilter == nil || msg.Type == *tw.metaFilter { if tw.metaFilter == nil || msg.Type == *tw.metaFilter {

View File

@@ -232,6 +232,7 @@ func Main(c *config.C, configTest bool, buildVersion string, l *slog.Logger, dev
ifce.writers = udpConns ifce.writers = udpConns
lightHouse.ifce = ifce lightHouse.ifce = ifce
lightHouse.bufAlloc = ifce.bufAlloc
ifce.RegisterConfigChangeCallbacks(c) ifce.RegisterConfigChangeCallbacks(c)
ifce.reloadDisconnectInvalid(c) ifce.reloadDisconnectInvalid(c)

View File

@@ -14,6 +14,19 @@ type endianness interface {
var noiseEndianness endianness = binary.BigEndian var noiseEndianness endianness = binary.BigEndian
// NonceSize is the AEAD nonce length used by all ciphers nebula supports
// today (AES-GCM and ChaCha20-Poly1305 both use 96-bit nonces). Encrypt-
// and DecryptDanger lay out the nonce as 4 zero bytes followed by an 8-byte
// big-endian counter; if a future cipher with a different nonce size is
// added, this constant and those layouts must change together.
const NonceSize = 12
// AEADOverhead is the AEAD authentication tag length the ciphers nebula
// supports append to ciphertext. Both AES-GCM and ChaCha20-Poly1305 use
// 128-bit tags. NebulaCipherState.Overhead() returns this dynamically from
// the cipher; the constant is for sizing buffers at construction time.
const AEADOverhead = 16
type NebulaCipherState struct { type NebulaCipherState struct {
c cipher.AEAD c cipher.AEAD
} }

View File

@@ -20,7 +20,8 @@ const (
minFwPacketLen = 4 minFwPacketLen = 4
) )
func (f *Interface) readOutsidePackets(via ViaSender, out []byte, packet []byte, h *header.H, fwPacket *firewall.Packet, lhf *LightHouseHandler, nb []byte, q int, localCache firewall.ConntrackCache) { func (f *Interface) readOutsidePackets(via ViaSender, buf *WireBuffer, packet []byte, lhf *LightHouseHandler, q int, localCache firewall.ConntrackCache) {
h := buf.H
err := h.Parse(packet) err := h.Parse(packet)
if err != nil { if err != nil {
// Hole punch packets are 0 or 1 byte big, so lets ignore printing those errors // Hole punch packets are 0 or 1 byte big, so lets ignore printing those errors
@@ -65,7 +66,7 @@ func (f *Interface) readOutsidePackets(via ViaSender, out []byte, packet []byte,
switch h.Subtype { switch h.Subtype {
case header.MessageNone: case header.MessageNone:
if !f.decryptToTun(hostinfo, h.MessageCounter, out, packet, fwPacket, nb, q, localCache) { if !f.decryptToTun(hostinfo, h.MessageCounter, buf, packet, q, localCache) {
return return
} }
case header.MessageRelay: case header.MessageRelay:
@@ -76,8 +77,9 @@ func (f *Interface) readOutsidePackets(via ViaSender, out []byte, packet []byte,
// which will gracefully fail in the DecryptDanger call. // which will gracefully fail in the DecryptDanger call.
signedPayload := packet[:len(packet)-hostinfo.ConnectionState.dKey.Overhead()] signedPayload := packet[:len(packet)-hostinfo.ConnectionState.dKey.Overhead()]
signatureValue := packet[len(packet)-hostinfo.ConnectionState.dKey.Overhead():] signatureValue := packet[len(packet)-hostinfo.ConnectionState.dKey.Overhead():]
out, err = hostinfo.ConnectionState.dKey.DecryptDanger(out, signedPayload, signatureValue, h.MessageCounter, nb) // AAD-only validation: passing dst=nil since there's no plaintext
if err != nil { // to recover (ciphertext is just the trailing AEAD tag).
if _, err = hostinfo.ConnectionState.dKey.DecryptDanger(nil, signedPayload, signatureValue, h.MessageCounter, buf.NB); err != nil {
return return
} }
// Successfully validated the thing. Get rid of the Relay header. // Successfully validated the thing. Get rid of the Relay header.
@@ -110,7 +112,8 @@ func (f *Interface) readOutsidePackets(via ViaSender, out []byte, packet []byte,
relay: relay, relay: relay,
IsRelayed: true, IsRelayed: true,
} }
f.readOutsidePackets(via, out[:0], signedPayload, h, fwPacket, lhf, nb, q, localCache) buf.Reset()
f.readOutsidePackets(via, buf, signedPayload, lhf, q, localCache)
return return
case ForwardingType: case ForwardingType:
// Find the target HostInfo relay object // Find the target HostInfo relay object
@@ -130,7 +133,7 @@ func (f *Interface) readOutsidePackets(via ViaSender, out []byte, packet []byte,
case ForwardingType: case ForwardingType:
// Forward this packet through the relay tunnel // Forward this packet through the relay tunnel
// Find the target HostInfo // Find the target HostInfo
f.SendVia(targetHI, targetRelay, signedPayload, nb, out, false) f.SendVia(targetHI, targetRelay, signedPayload, buf)
return return
case TerminalType: case TerminalType:
hostinfo.logger(f.l).Error("Unexpected Relay Type of Terminal") hostinfo.logger(f.l).Error("Unexpected Relay Type of Terminal")
@@ -152,7 +155,7 @@ func (f *Interface) readOutsidePackets(via ViaSender, out []byte, packet []byte,
return return
} }
d, err := f.decrypt(hostinfo, h.MessageCounter, out, packet, h, nb) d, err := f.decrypt(hostinfo, h.MessageCounter, buf, packet, h)
if err != nil { if err != nil {
hostinfo.logger(f.l).Error("Failed to decrypt lighthouse packet", hostinfo.logger(f.l).Error("Failed to decrypt lighthouse packet",
"error", err, "error", err,
@@ -173,7 +176,7 @@ func (f *Interface) readOutsidePackets(via ViaSender, out []byte, packet []byte,
return return
} }
d, err := f.decrypt(hostinfo, h.MessageCounter, out, packet, h, nb) d, err := f.decrypt(hostinfo, h.MessageCounter, buf, packet, h)
if err != nil { if err != nil {
hostinfo.logger(f.l).Error("Failed to decrypt test packet", hostinfo.logger(f.l).Error("Failed to decrypt test packet",
"error", err, "error", err,
@@ -185,9 +188,9 @@ func (f *Interface) readOutsidePackets(via ViaSender, out []byte, packet []byte,
if h.Subtype == header.TestRequest { if h.Subtype == header.TestRequest {
// This testRequest might be from TryPromoteBest, so we should roam // This testRequest might be from TryPromoteBest, so we should roam
// to the new IP address before responding // to the new IP address before responding.
f.handleHostRoaming(hostinfo, via) f.handleHostRoaming(hostinfo, via)
f.send(header.Test, header.TestReply, ci, hostinfo, d, nb, out) f.send(header.Test, header.TestReply, ci, hostinfo, d, buf)
} }
// Fallthrough to the bottom to record incoming traffic // Fallthrough to the bottom to record incoming traffic
@@ -210,7 +213,7 @@ func (f *Interface) readOutsidePackets(via ViaSender, out []byte, packet []byte,
if !f.handleEncrypted(ci, via, h) { if !f.handleEncrypted(ci, via, h) {
return return
} }
_, err = f.decrypt(hostinfo, h.MessageCounter, out, packet, h, nb) _, err = f.decrypt(hostinfo, h.MessageCounter, buf, packet, h)
if err != nil { if err != nil {
hostinfo.logger(f.l).Error("Failed to decrypt CloseTunnel packet", hostinfo.logger(f.l).Error("Failed to decrypt CloseTunnel packet",
"error", err, "error", err,
@@ -230,7 +233,7 @@ func (f *Interface) readOutsidePackets(via ViaSender, out []byte, packet []byte,
return return
} }
d, err := f.decrypt(hostinfo, h.MessageCounter, out, packet, h, nb) d, err := f.decrypt(hostinfo, h.MessageCounter, buf, packet, h)
if err != nil { if err != nil {
hostinfo.logger(f.l).Error("Failed to decrypt Control packet", hostinfo.logger(f.l).Error("Failed to decrypt Control packet",
"error", err, "error", err,
@@ -266,7 +269,9 @@ func (f *Interface) closeTunnel(hostInfo *HostInfo) {
// sendCloseTunnel is a helper function to send a proper close tunnel packet to a remote // sendCloseTunnel is a helper function to send a proper close tunnel packet to a remote
func (f *Interface) sendCloseTunnel(h *HostInfo) { func (f *Interface) sendCloseTunnel(h *HostInfo) {
f.send(header.CloseTunnel, 0, h.ConnectionState, h, []byte{}, make([]byte, 12, 12), make([]byte, mtu)) buf := f.bufAlloc.Acquire()
defer f.bufAlloc.Release(buf)
f.send(header.CloseTunnel, 0, h.ConnectionState, h, []byte{}, buf)
} }
func (f *Interface) handleHostRoaming(hostinfo *HostInfo, via ViaSender) { func (f *Interface) handleHostRoaming(hostinfo *HostInfo, via ViaSender) {
@@ -515,9 +520,8 @@ func parseV4(data []byte, incoming bool, fp *firewall.Packet) error {
return nil return nil
} }
func (f *Interface) decrypt(hostinfo *HostInfo, mc uint64, out []byte, packet []byte, h *header.H, nb []byte) ([]byte, error) { func (f *Interface) decrypt(hostinfo *HostInfo, mc uint64, buf *WireBuffer, packet []byte, h *header.H) ([]byte, error) {
var err error plaintext, err := buf.DecryptForHandler(hostinfo.ConnectionState, packet, mc)
out, err = hostinfo.ConnectionState.dKey.DecryptDanger(out, packet[:header.Len], packet[header.Len:], mc, nb)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -529,42 +533,41 @@ func (f *Interface) decrypt(hostinfo *HostInfo, mc uint64, out []byte, packet []
return nil, errors.New("out of window packet") return nil, errors.New("out of window packet")
} }
return out, nil return plaintext, nil
} }
func (f *Interface) decryptToTun(hostinfo *HostInfo, messageCounter uint64, out []byte, packet []byte, fwPacket *firewall.Packet, nb []byte, q int, localCache firewall.ConntrackCache) bool { func (f *Interface) decryptToTun(hostinfo *HostInfo, messageCounter uint64, buf *WireBuffer, packet []byte, q int, localCache firewall.ConntrackCache) bool {
var err error if err := buf.DecryptDatagram(hostinfo.ConnectionState, packet, messageCounter); err != nil {
out, err = hostinfo.ConnectionState.dKey.DecryptDanger(out, packet[:header.Len], packet[header.Len:], messageCounter, nb)
if err != nil {
hostinfo.logger(f.l).Error("Failed to decrypt packet", "error", err) hostinfo.logger(f.l).Error("Failed to decrypt packet", "error", err)
return false return false
} }
err = newPacket(out, true, fwPacket) ipPacket := buf.IPPacket()
if err != nil { if err := newPacket(ipPacket, true, buf.FwPacket); err != nil {
hostinfo.logger(f.l).Warn("Error while validating inbound packet", hostinfo.logger(f.l).Warn("Error while validating inbound packet",
"error", err, "error", err,
"packet", out, "packet", ipPacket,
) )
return false return false
} }
if !hostinfo.ConnectionState.window.Update(f.l, messageCounter) { if !hostinfo.ConnectionState.window.Update(f.l, messageCounter) {
if f.l.Enabled(context.Background(), slog.LevelDebug) { if f.l.Enabled(context.Background(), slog.LevelDebug) {
hostinfo.logger(f.l).Debug("dropping out of window packet", "fwPacket", fwPacket) hostinfo.logger(f.l).Debug("dropping out of window packet", "fwPacket", buf.FwPacket)
} }
return false return false
} }
dropReason := f.firewall.Drop(*fwPacket, true, hostinfo, f.pki.GetCAPool(), localCache) dropReason := f.firewall.Drop(*buf.FwPacket, true, hostinfo, f.pki.GetCAPool(), localCache)
if dropReason != nil { if dropReason != nil {
// NOTE: We give `packet` as the `out` here since we already decrypted from it and we don't need it anymore // NOTE: We hand `packet` (the original UDP ciphertext we already
// This gives us a buffer to build the reject packet in // decrypted from) as the reject-IP scratch since we no longer
f.rejectOutside(out, hostinfo.ConnectionState, hostinfo, nb, packet, q) // need its ciphertext, and it's disjoint from buf.Out where
// sendNoMetrics will encrypt the wire packet.
f.rejectOutside(ipPacket, hostinfo.ConnectionState, hostinfo, packet, buf, q)
if f.l.Enabled(context.Background(), slog.LevelDebug) { if f.l.Enabled(context.Background(), slog.LevelDebug) {
hostinfo.logger(f.l).Debug("dropping inbound packet", hostinfo.logger(f.l).Debug("dropping inbound packet",
"fwPacket", fwPacket, "fwPacket", buf.FwPacket,
"reason", dropReason, "reason", dropReason,
) )
} }
@@ -572,8 +575,7 @@ func (f *Interface) decryptToTun(hostinfo *HostInfo, messageCounter uint64, out
} }
f.connectionManager.In(hostinfo) f.connectionManager.In(hostinfo)
_, err = f.readers[q].Write(out) if _, err := buf.WriteIPToTUN(f.readers[q]); err != nil {
if err != nil {
f.l.Error("Failed to write to tun", "error", err) f.l.Error("Failed to write to tun", "error", err)
} }
return true return true

View File

@@ -15,4 +15,7 @@ type Device interface {
RoutesFor(netip.Addr) routing.Gateways RoutesFor(netip.Addr) routing.Gateways
SupportsMultiqueue() bool SupportsMultiqueue() bool
NewMultiQueueReader() (io.ReadWriteCloser, error) NewMultiQueueReader() (io.ReadWriteCloser, error)
// TunPrefixLen reports the number of bytes the device prepends to every IP packet on the wire.
// Currently only non zero for the BSD tun devices.
TunPrefixLen() int
} }

View File

@@ -50,3 +50,5 @@ func (NoopTun) NewMultiQueueReader() (io.ReadWriteCloser, error) {
func (NoopTun) Close() error { func (NoopTun) Close() error {
return nil return nil
} }
func (NoopTun) TunPrefixLen() int { return 0 }

View File

@@ -102,3 +102,5 @@ func (t *tun) SupportsMultiqueue() bool {
func (t *tun) NewMultiQueueReader() (io.ReadWriteCloser, error) { func (t *tun) NewMultiQueueReader() (io.ReadWriteCloser, error) {
return nil, fmt.Errorf("TODO: multiqueue not implemented for android") return nil, fmt.Errorf("TODO: multiqueue not implemented for android")
} }
func (t *tun) TunPrefixLen() int { return 0 }

29
overlay/tun_bsd.go Normal file
View File

@@ -0,0 +1,29 @@
//go:build (darwin || ios || freebsd || openbsd || netbsd) && !e2e_testing
package overlay
import (
"fmt"
"syscall"
)
// StampTunPrefix writes the 4-byte AF_INET / AF_INET6 protocol-family marker into buf[0:4] in place,
// picking the family from the first byte of the IP packet at buf[4].
func StampTunPrefix(buf []byte) error {
if len(buf) < 5 {
return fmt.Errorf("tun write buffer too small for prefix")
}
ipVer := buf[4] >> 4
buf[0] = 0
buf[1] = 0
buf[2] = 0
switch ipVer {
case 4:
buf[3] = syscall.AF_INET
case 6:
buf[3] = syscall.AF_INET6
default:
return fmt.Errorf("unable to determine IP version from packet")
}
return nil
}

View File

@@ -11,7 +11,6 @@ import (
"net/netip" "net/netip"
"os" "os"
"sync/atomic" "sync/atomic"
"syscall"
"unsafe" "unsafe"
"github.com/gaissmai/bart" "github.com/gaissmai/bart"
@@ -31,9 +30,6 @@ type tun struct {
routeTree atomic.Pointer[bart.Table[routing.Gateways]] routeTree atomic.Pointer[bart.Table[routing.Gateways]]
linkAddr *netroute.LinkAddr linkAddr *netroute.LinkAddr
l *slog.Logger l *slog.Logger
// cache out buffer since we need to prepend 4 bytes for tun metadata
out []byte
} }
type ifReq struct { type ifReq struct {
@@ -502,44 +498,6 @@ func delRoute(prefix netip.Prefix, gateway netroute.Addr) error {
return nil return nil
} }
func (t *tun) Read(to []byte) (int, error) {
buf := make([]byte, len(to)+4)
n, err := t.ReadWriteCloser.Read(buf)
copy(to, buf[4:])
return n - 4, err
}
// Write is only valid for single threaded use
func (t *tun) Write(from []byte) (int, error) {
buf := t.out
if cap(buf) < len(from)+4 {
buf = make([]byte, len(from)+4)
t.out = buf
}
buf = buf[:len(from)+4]
if len(from) == 0 {
return 0, syscall.EIO
}
// Determine the IP Family for the NULL L2 Header
ipVer := from[0] >> 4
if ipVer == 4 {
buf[3] = syscall.AF_INET
} else if ipVer == 6 {
buf[3] = syscall.AF_INET6
} else {
return 0, fmt.Errorf("unable to determine IP version from packet")
}
copy(buf[4:], from)
n, err := t.ReadWriteCloser.Write(buf)
return n - 4, err
}
func (t *tun) Networks() []netip.Prefix { func (t *tun) Networks() []netip.Prefix {
return t.vpnNetworks return t.vpnNetworks
} }
@@ -555,3 +513,7 @@ func (t *tun) SupportsMultiqueue() bool {
func (t *tun) NewMultiQueueReader() (io.ReadWriteCloser, error) { func (t *tun) NewMultiQueueReader() (io.ReadWriteCloser, error) {
return nil, fmt.Errorf("TODO: multiqueue not implemented for darwin") return nil, fmt.Errorf("TODO: multiqueue not implemented for darwin")
} }
// TunPrefixLen reports the 4-byte BSD AF_INET / AF_INET6 protocol-family
// marker the kernel prepends on read and expects on write.
func (t *tun) TunPrefixLen() int { return 4 }

View File

@@ -136,3 +136,5 @@ func (p prettyPacket) String() string {
return s.String() return s.String()
} }
func (t *disabledTun) TunPrefixLen() int { return 0 }

View File

@@ -158,74 +158,43 @@ func (t *tun) blockOnWrite() error {
} }
func (t *tun) Read(to []byte) (int, error) { func (t *tun) Read(to []byte) (int, error) {
// first 4 bytes is protocol family, in network byte order
var head [4]byte
iovecs := [2]syscall.Iovec{
{&head[0], 4},
{&to[0], uint64(len(to))},
}
for { for {
n, _, errno := syscall.Syscall(syscall.SYS_READV, uintptr(t.fd), uintptr(unsafe.Pointer(&iovecs[0])), 2) n, err := unix.Read(t.fd, to)
if errno == 0 { if err == nil {
bytesRead := int(n) return n, nil
if bytesRead < 4 {
return 0, nil
} }
return bytesRead - 4, nil switch err {
}
switch errno {
case unix.EAGAIN: case unix.EAGAIN:
if err := t.blockOnRead(); err != nil { if berr := t.blockOnRead(); berr != nil {
return 0, err return 0, berr
} }
case unix.EINTR: case unix.EINTR:
// retry // retry
case unix.EBADF: case unix.EBADF:
return 0, os.ErrClosed return 0, os.ErrClosed
default: default:
return 0, errno return 0, err
} }
} }
} }
// Write is only valid for single threaded use
func (t *tun) Write(from []byte) (int, error) { func (t *tun) Write(from []byte) (int, error) {
if len(from) <= 1 {
return 0, syscall.EIO
}
ipVer := from[0] >> 4
var head [4]byte
// first 4 bytes is protocol family, in network byte order
switch ipVer {
case 4:
head[3] = syscall.AF_INET
case 6:
head[3] = syscall.AF_INET6
default:
return 0, fmt.Errorf("unable to determine IP version from packet")
}
iovecs := [2]syscall.Iovec{
{&head[0], 4},
{&from[0], uint64(len(from))},
}
for { for {
n, _, errno := syscall.Syscall(syscall.SYS_WRITEV, uintptr(t.fd), uintptr(unsafe.Pointer(&iovecs[0])), 2) n, err := unix.Write(t.fd, from)
if errno == 0 { if err == nil {
return int(n) - 4, nil return n, nil
} }
switch errno { switch err {
case unix.EAGAIN: case unix.EAGAIN:
if err := t.blockOnWrite(); err != nil { if berr := t.blockOnWrite(); berr != nil {
return 0, err return 0, berr
} }
case unix.EINTR: case unix.EINTR:
// retry // retry
case unix.EBADF: case unix.EBADF:
return 0, os.ErrClosed return 0, os.ErrClosed
default: default:
return 0, errno return 0, err
} }
} }
} }
@@ -732,3 +701,7 @@ func getLinkAddr(name string) (*netroute.LinkAddr, error) {
return nil, nil return nil, nil
} }
// TunPrefixLen reports the 4-byte BSD AF_INET / AF_INET6 protocol-family
// marker the kernel prepends on read and expects on write.
func (t *tun) TunPrefixLen() int { return 4 }

View File

@@ -4,15 +4,12 @@
package overlay package overlay
import ( import (
"errors"
"fmt" "fmt"
"io" "io"
"log/slog" "log/slog"
"net/netip" "net/netip"
"os" "os"
"sync"
"sync/atomic" "sync/atomic"
"syscall"
"github.com/gaissmai/bart" "github.com/gaissmai/bart"
"github.com/slackhq/nebula/config" "github.com/slackhq/nebula/config"
@@ -36,7 +33,7 @@ func newTunFromFd(c *config.C, l *slog.Logger, deviceFd int, vpnNetworks []netip
file := os.NewFile(uintptr(deviceFd), "/dev/tun") file := os.NewFile(uintptr(deviceFd), "/dev/tun")
t := &tun{ t := &tun{
vpnNetworks: vpnNetworks, vpnNetworks: vpnNetworks,
ReadWriteCloser: &tunReadCloser{f: file}, ReadWriteCloser: file,
l: l, l: l,
} }
@@ -85,64 +82,6 @@ func (t *tun) RoutesFor(ip netip.Addr) routing.Gateways {
return r return r
} }
// The following is hoisted up from water, we do this so we can inject our own fd on iOS
type tunReadCloser struct {
f io.ReadWriteCloser
rMu sync.Mutex
rBuf []byte
wMu sync.Mutex
wBuf []byte
}
func (tr *tunReadCloser) Read(to []byte) (int, error) {
tr.rMu.Lock()
defer tr.rMu.Unlock()
if cap(tr.rBuf) < len(to)+4 {
tr.rBuf = make([]byte, len(to)+4)
}
tr.rBuf = tr.rBuf[:len(to)+4]
n, err := tr.f.Read(tr.rBuf)
copy(to, tr.rBuf[4:])
return n - 4, err
}
func (tr *tunReadCloser) Write(from []byte) (int, error) {
if len(from) == 0 {
return 0, syscall.EIO
}
tr.wMu.Lock()
defer tr.wMu.Unlock()
if cap(tr.wBuf) < len(from)+4 {
tr.wBuf = make([]byte, len(from)+4)
}
tr.wBuf = tr.wBuf[:len(from)+4]
// Determine the IP Family for the NULL L2 Header
ipVer := from[0] >> 4
if ipVer == 4 {
tr.wBuf[3] = syscall.AF_INET
} else if ipVer == 6 {
tr.wBuf[3] = syscall.AF_INET6
} else {
return 0, errors.New("unable to determine IP version from packet")
}
copy(tr.wBuf[4:], from)
n, err := tr.f.Write(tr.wBuf)
return n - 4, err
}
func (tr *tunReadCloser) Close() error {
return tr.f.Close()
}
func (t *tun) Networks() []netip.Prefix { func (t *tun) Networks() []netip.Prefix {
return t.vpnNetworks return t.vpnNetworks
} }
@@ -158,3 +97,7 @@ func (t *tun) SupportsMultiqueue() bool {
func (t *tun) NewMultiQueueReader() (io.ReadWriteCloser, error) { func (t *tun) NewMultiQueueReader() (io.ReadWriteCloser, error) {
return nil, fmt.Errorf("TODO: multiqueue not implemented for ios") return nil, fmt.Errorf("TODO: multiqueue not implemented for ios")
} }
// TunPrefixLen reports the 4-byte BSD AF_INET / AF_INET6 protocol-family
// marker the kernel prepends on read and expects on write.
func (t *tun) TunPrefixLen() int { return 4 }

View File

@@ -907,3 +907,5 @@ func (t *tun) Close() error {
} }
return err return err
} }
func (t *tun) TunPrefixLen() int { return 0 }

View File

@@ -58,13 +58,13 @@ type addrLifetime struct {
} }
type tun struct { type tun struct {
io.ReadWriteCloser
Device string Device string
vpnNetworks []netip.Prefix vpnNetworks []netip.Prefix
MTU int MTU int
Routes atomic.Pointer[[]Route] Routes atomic.Pointer[[]Route]
routeTree atomic.Pointer[bart.Table[routing.Gateways]] routeTree atomic.Pointer[bart.Table[routing.Gateways]]
l *slog.Logger l *slog.Logger
f *os.File
fd int fd int
} }
@@ -96,7 +96,7 @@ func newTun(c *config.C, l *slog.Logger, vpnNetworks []netip.Prefix, _ bool) (*t
} }
t := &tun{ t := &tun{
f: os.NewFile(uintptr(fd), ""), ReadWriteCloser: os.NewFile(uintptr(fd), ""),
fd: fd, fd: fd,
Device: deviceName, Device: deviceName,
vpnNetworks: vpnNetworks, vpnNetworks: vpnNetworks,
@@ -120,12 +120,12 @@ func newTun(c *config.C, l *slog.Logger, vpnNetworks []netip.Prefix, _ bool) (*t
} }
func (t *tun) Close() error { func (t *tun) Close() error {
if t.f != nil { if t.ReadWriteCloser != nil {
if err := t.f.Close(); err != nil { if err := t.ReadWriteCloser.Close(); err != nil {
return fmt.Errorf("error closing tun file: %w", err) return fmt.Errorf("error closing tun file: %w", err)
} }
// t.f.Close should have handled it for us but let's be extra sure // Close on the os.File should have handled the fd for us but let's be extra sure
_ = unix.Close(t.fd) _ = unix.Close(t.fd)
s, err := syscall.Socket(syscall.AF_INET, syscall.SOCK_DGRAM, syscall.IPPROTO_IP) s, err := syscall.Socket(syscall.AF_INET, syscall.SOCK_DGRAM, syscall.IPPROTO_IP)
@@ -141,99 +141,6 @@ func (t *tun) Close() error {
return nil return nil
} }
func (t *tun) Read(to []byte) (int, error) {
rc, err := t.f.SyscallConn()
if err != nil {
return 0, fmt.Errorf("failed to get syscall conn for tun: %w", err)
}
var errno syscall.Errno
var n uintptr
err = rc.Read(func(fd uintptr) bool {
// first 4 bytes is protocol family, in network byte order
head := [4]byte{}
iovecs := []syscall.Iovec{
{&head[0], 4},
{&to[0], uint64(len(to))},
}
n, _, errno = syscall.Syscall(syscall.SYS_READV, fd, uintptr(unsafe.Pointer(&iovecs[0])), uintptr(2))
if errno.Temporary() {
// We got an EAGAIN, EINTR, or EWOULDBLOCK, go again
return false
}
return true
})
if err != nil {
if err == syscall.EBADF || err.Error() == "use of closed file" {
// Go doesn't export poll.ErrFileClosing but happily reports it to us so here we are
// https://github.com/golang/go/blob/master/src/internal/poll/fd_poll_runtime.go#L121
return 0, os.ErrClosed
}
return 0, fmt.Errorf("failed to make read call for tun: %w", err)
}
if errno != 0 {
return 0, fmt.Errorf("failed to make inner read call for tun: %w", errno)
}
// fix bytes read number to exclude header
bytesRead := int(n)
if bytesRead < 0 {
return bytesRead, nil
} else if bytesRead < 4 {
return 0, nil
} else {
return bytesRead - 4, nil
}
}
// Write is only valid for single threaded use
func (t *tun) Write(from []byte) (int, error) {
if len(from) <= 1 {
return 0, syscall.EIO
}
ipVer := from[0] >> 4
var head [4]byte
// first 4 bytes is protocol family, in network byte order
if ipVer == 4 {
head[3] = syscall.AF_INET
} else if ipVer == 6 {
head[3] = syscall.AF_INET6
} else {
return 0, fmt.Errorf("unable to determine IP version from packet")
}
rc, err := t.f.SyscallConn()
if err != nil {
return 0, err
}
var errno syscall.Errno
var n uintptr
err = rc.Write(func(fd uintptr) bool {
iovecs := []syscall.Iovec{
{&head[0], 4},
{&from[0], uint64(len(from))},
}
n, _, errno = syscall.Syscall(syscall.SYS_WRITEV, fd, uintptr(unsafe.Pointer(&iovecs[0])), uintptr(2))
// According to NetBSD documentation for TUN, writes will only return errors in which
// this packet will never be delivered so just go on living life.
return true
})
if err != nil {
return 0, err
}
if errno != 0 {
return 0, errno
}
return int(n) - 4, err
}
func (t *tun) addIp(cidr netip.Prefix) error { func (t *tun) addIp(cidr netip.Prefix) error {
if cidr.Addr().Is4() { if cidr.Addr().Is4() {
var req ifreqAlias4 var req ifreqAlias4
@@ -551,3 +458,7 @@ func delRoute(prefix netip.Prefix, gateways []netip.Prefix) error {
return nil return nil
} }
// TunPrefixLen reports the 4-byte BSD AF_INET / AF_INET6 protocol-family
// marker the kernel prepends on read and expects on write.
func (t *tun) TunPrefixLen() int { return 4 }

10
overlay/tun_no_prefix.go Normal file
View File

@@ -0,0 +1,10 @@
//go:build (!darwin && !ios && !freebsd && !openbsd && !netbsd) || e2e_testing
package overlay
// StampTunPrefix is a no-op on platforms whose tun devices have no
// protocol-family marker. WireBuffer only invokes it when its prefixLen
// is non-zero, so this should never be reached on these platforms.
func StampTunPrefix(buf []byte) error {
return nil
}

View File

@@ -49,16 +49,14 @@ type ifreq struct {
} }
type tun struct { type tun struct {
io.ReadWriteCloser
Device string Device string
vpnNetworks []netip.Prefix vpnNetworks []netip.Prefix
MTU int MTU int
Routes atomic.Pointer[[]Route] Routes atomic.Pointer[[]Route]
routeTree atomic.Pointer[bart.Table[routing.Gateways]] routeTree atomic.Pointer[bart.Table[routing.Gateways]]
l *slog.Logger l *slog.Logger
f *os.File
fd int fd int
// cache out buffer since we need to prepend 4 bytes for tun metadata
out []byte
} }
var deviceNameRE = regexp.MustCompile(`^tun[0-9]+$`) var deviceNameRE = regexp.MustCompile(`^tun[0-9]+$`)
@@ -89,7 +87,7 @@ func newTun(c *config.C, l *slog.Logger, vpnNetworks []netip.Prefix, _ bool) (*t
} }
t := &tun{ t := &tun{
f: os.NewFile(uintptr(fd), ""), ReadWriteCloser: os.NewFile(uintptr(fd), ""),
fd: fd, fd: fd,
Device: deviceName, Device: deviceName,
vpnNetworks: vpnNetworks, vpnNetworks: vpnNetworks,
@@ -113,55 +111,17 @@ func newTun(c *config.C, l *slog.Logger, vpnNetworks []netip.Prefix, _ bool) (*t
} }
func (t *tun) Close() error { func (t *tun) Close() error {
if t.f != nil { if t.ReadWriteCloser != nil {
if err := t.f.Close(); err != nil { if err := t.ReadWriteCloser.Close(); err != nil {
return fmt.Errorf("error closing tun file: %w", err) return fmt.Errorf("error closing tun file: %w", err)
} }
// t.f.Close should have handled it for us but let's be extra sure // Close on the os.File should have handled the fd for us but let's be extra sure
_ = unix.Close(t.fd) _ = unix.Close(t.fd)
} }
return nil return nil
} }
func (t *tun) Read(to []byte) (int, error) {
buf := make([]byte, len(to)+4)
n, err := t.f.Read(buf)
copy(to, buf[4:])
return n - 4, err
}
// Write is only valid for single threaded use
func (t *tun) Write(from []byte) (int, error) {
buf := t.out
if cap(buf) < len(from)+4 {
buf = make([]byte, len(from)+4)
t.out = buf
}
buf = buf[:len(from)+4]
if len(from) == 0 {
return 0, syscall.EIO
}
// Determine the IP Family for the NULL L2 Header
ipVer := from[0] >> 4
if ipVer == 4 {
buf[3] = syscall.AF_INET
} else if ipVer == 6 {
buf[3] = syscall.AF_INET6
} else {
return 0, fmt.Errorf("unable to determine IP version from packet")
}
copy(buf[4:], from)
n, err := t.f.Write(buf)
return n - 4, err
}
func (t *tun) addIp(cidr netip.Prefix) error { func (t *tun) addIp(cidr netip.Prefix) error {
if cidr.Addr().Is4() { if cidr.Addr().Is4() {
var req ifreqAlias4 var req ifreqAlias4
@@ -471,3 +431,7 @@ func delRoute(prefix netip.Prefix, gateways []netip.Prefix) error {
return nil return nil
} }
// TunPrefixLen reports the 4-byte BSD AF_INET / AF_INET6 protocol-family
// marker the kernel prepends on read and expects on write.
func (t *tun) TunPrefixLen() int { return 4 }

View File

@@ -184,3 +184,5 @@ func (t *TestTun) SupportsMultiqueue() bool {
func (t *TestTun) NewMultiQueueReader() (io.ReadWriteCloser, error) { func (t *TestTun) NewMultiQueueReader() (io.ReadWriteCloser, error) {
return nil, fmt.Errorf("TODO: multiqueue not implemented") return nil, fmt.Errorf("TODO: multiqueue not implemented")
} }
func (t *TestTun) TunPrefixLen() int { return 0 }

View File

@@ -296,3 +296,5 @@ func checkWinTunExists() error {
_, err = syscall.LoadDLL(filepath.Join(filepath.Dir(myPath), "dist", "windows", "wintun", "bin", arch, "wintun.dll")) _, err = syscall.LoadDLL(filepath.Join(filepath.Dir(myPath), "dist", "windows", "wintun", "bin", arch, "wintun.dll"))
return err return err
} }
func (t *winTun) TunPrefixLen() int { return 0 }

View File

@@ -69,3 +69,5 @@ func (d *UserDevice) Close() error {
d.outboundWriter.Close() d.outboundWriter.Close()
return nil return nil
} }
func (d *UserDevice) TunPrefixLen() int { return 0 }

View File

@@ -63,6 +63,9 @@ func (rm *relayManager) StartRelays(f *Interface, vpnIp netip.Addr, hostinfo *Ho
} }
hostinfo.logger(rm.l).Info("Attempt to relay through hosts", "relays", hostinfo.remotes.relays) hostinfo.logger(rm.l).Info("Attempt to relay through hosts", "relays", hostinfo.remotes.relays)
// One WireBuffer for the whole relay-fanout loop.
buf := f.bufAlloc.Acquire()
defer f.bufAlloc.Release(buf)
// Send a RelayRequest to all known Relay IP's // Send a RelayRequest to all known Relay IP's
for _, relay := range hostinfo.remotes.relays { for _, relay := range hostinfo.remotes.relays {
// Don't relay through the host I'm trying to connect to // Don't relay through the host I'm trying to connect to
@@ -124,7 +127,7 @@ func (rm *relayManager) StartRelays(f *Interface, vpnIp netip.Addr, hostinfo *Ho
if err != nil { if err != nil {
hostinfo.logger(rm.l).Error("Failed to marshal Control message to create relay", "error", err) hostinfo.logger(rm.l).Error("Failed to marshal Control message to create relay", "error", err)
} else { } else {
f.SendMessageToHostInfo(header.Control, 0, relayHostInfo, msg, make([]byte, 12), make([]byte, mtu)) f.SendMessageToHostInfo(header.Control, 0, relayHostInfo, msg, buf)
rm.l.Info("send CreateRelayRequest", rm.l.Info("send CreateRelayRequest",
"relayFrom", f.myVpnAddrs[0], "relayFrom", f.myVpnAddrs[0],
"relayTo", vpnIp, "relayTo", vpnIp,
@@ -139,7 +142,7 @@ func (rm *relayManager) StartRelays(f *Interface, vpnIp netip.Addr, hostinfo *Ho
switch existingRelay.State { switch existingRelay.State {
case Established: case Established:
hostinfo.logger(rm.l).Info("Send handshake via relay", "relay", relay.String()) hostinfo.logger(rm.l).Info("Send handshake via relay", "relay", relay.String())
f.SendVia(relayHostInfo, existingRelay, stage0, make([]byte, 12), make([]byte, mtu), false) f.SendVia(relayHostInfo, existingRelay, stage0, buf)
case Disestablished: case Disestablished:
// Mark this relay as 'requested' // Mark this relay as 'requested'
relayHostInfo.relayState.UpdateRelayForByIpState(vpnIp, Requested) relayHostInfo.relayState.UpdateRelayForByIpState(vpnIp, Requested)
@@ -180,7 +183,7 @@ func (rm *relayManager) StartRelays(f *Interface, vpnIp netip.Addr, hostinfo *Ho
hostinfo.logger(rm.l).Error("Failed to marshal Control message to create relay", "error", err) hostinfo.logger(rm.l).Error("Failed to marshal Control message to create relay", "error", err)
} else { } else {
// This must send over the hostinfo, not over hm.Hosts[ip] // This must send over the hostinfo, not over hm.Hosts[ip]
f.SendMessageToHostInfo(header.Control, 0, relayHostInfo, msg, make([]byte, 12), make([]byte, mtu)) f.SendMessageToHostInfo(header.Control, 0, relayHostInfo, msg, buf)
rm.l.Info("send CreateRelayRequest", rm.l.Info("send CreateRelayRequest",
"relayFrom", f.myVpnAddrs[0], "relayFrom", f.myVpnAddrs[0],
"relayTo", vpnIp, "relayTo", vpnIp,
@@ -368,7 +371,9 @@ func (rm *relayManager) handleCreateRelayResponse(v cert.Version, h *HostInfo, f
if err != nil { if err != nil {
rm.l.Error("relayManager Failed to marshal Control CreateRelayResponse message to create relay", "error", err) rm.l.Error("relayManager Failed to marshal Control CreateRelayResponse message to create relay", "error", err)
} else { } else {
f.SendMessageToHostInfo(header.Control, 0, peerHostInfo, msg, make([]byte, 12), make([]byte, mtu)) buf := f.bufAlloc.Acquire()
f.SendMessageToHostInfo(header.Control, 0, peerHostInfo, msg, buf)
f.bufAlloc.Release(buf)
rm.l.Info("send CreateRelayResponse", rm.l.Info("send CreateRelayResponse",
"relayFrom", resp.RelayFromAddr, "relayFrom", resp.RelayFromAddr,
"relayTo", resp.RelayToAddr, "relayTo", resp.RelayToAddr,
@@ -468,7 +473,9 @@ func (rm *relayManager) handleCreateRelayRequest(v cert.Version, h *HostInfo, f
if err != nil { if err != nil {
logMsg.Error("relayManager Failed to marshal Control CreateRelayResponse message to create relay", "error", err) logMsg.Error("relayManager Failed to marshal Control CreateRelayResponse message to create relay", "error", err)
} else { } else {
f.SendMessageToHostInfo(header.Control, 0, h, msg, make([]byte, 12), make([]byte, mtu)) buf := f.bufAlloc.Acquire()
f.SendMessageToHostInfo(header.Control, 0, h, msg, buf)
f.bufAlloc.Release(buf)
rm.l.Info("send CreateRelayResponse", rm.l.Info("send CreateRelayResponse",
"relayFrom", from, "relayFrom", from,
"relayTo", target, "relayTo", target,
@@ -538,7 +545,9 @@ func (rm *relayManager) handleCreateRelayRequest(v cert.Version, h *HostInfo, f
if err != nil { if err != nil {
logMsg.Error("relayManager Failed to marshal Control message to create relay", "error", err) logMsg.Error("relayManager Failed to marshal Control message to create relay", "error", err)
} else { } else {
f.SendMessageToHostInfo(header.Control, 0, peer, msg, make([]byte, 12), make([]byte, mtu)) buf := f.bufAlloc.Acquire()
f.SendMessageToHostInfo(header.Control, 0, peer, msg, buf)
f.bufAlloc.Release(buf)
rm.l.Info("send CreateRelayRequest", rm.l.Info("send CreateRelayRequest",
"relayFrom", h.vpnAddrs[0], "relayFrom", h.vpnAddrs[0],
"relayTo", target, "relayTo", target,

12
ssh.go
View File

@@ -632,15 +632,9 @@ func sshCloseTunnel(ifce *Interface, fs any, a []string, w sshd.StringWriter) er
} }
if !flags.LocalOnly { if !flags.LocalOnly {
ifce.send( buf := ifce.bufAlloc.Acquire()
header.CloseTunnel, ifce.send(header.CloseTunnel, 0, hostInfo.ConnectionState, hostInfo, []byte{}, buf)
0, ifce.bufAlloc.Release(buf)
hostInfo.ConnectionState,
hostInfo,
[]byte{},
make([]byte, 12, 12),
make([]byte, mtu),
)
} }
ifce.closeTunnel(hostInfo) ifce.closeTunnel(hostInfo)

255
wire_buffer.go Normal file
View File

@@ -0,0 +1,255 @@
package nebula
import (
"io"
"sync"
"github.com/slackhq/nebula/firewall"
"github.com/slackhq/nebula/header"
"github.com/slackhq/nebula/noiseutil"
"github.com/slackhq/nebula/overlay"
)
// WireBuffer is the per-goroutine working set for processing one IP packet
// through the data plane. It owns:
//
// - The IP-payload byte buffer used to hold the current inbound or
// outbound packet, with prefixLen bytes of slack at the front for
// the BSD AF_INET protocol-family marker.
// - The fwPacket scratch parsed by newPacket().
// - The 12-byte AEAD nonce scratch.
// - The header.H parse target used by the receive path.
// - An mtu-sized wire-output scratch for sendNoMetrics and for building
// reject packets.
//
// One WireBuffer is allocated per data-plane goroutine (listenIn for the
// TUN-side, listenOut for the UDP-side) and reused for every packet. No
// per-packet allocation. Future GRO/GSO/TSO and reliable-transport work
// will likely extend this to carry batch state and fragment metadata.
//
// The TUN protocol-family prefix is handled here, not in the overlay
// package. On BSDs the kernel writes the 4-byte marker into the slack on
// read, and we stamp it into the slack before write. On linux/windows
// /userspace devices prefixLen is 0 and the slack is empty.
type WireBuffer struct {
// FwPacket is the parsed IP packet metadata (5-tuple, fragment flags,
// etc.) populated by newPacket().
FwPacket *firewall.Packet
// NB is a 12-byte scratch the AEAD uses for the nonce; reused so we
// don't allocate one per encrypt/decrypt.
NB []byte
// H is the parse target for inbound nebula headers. Receive path only.
H *header.H
// Out is an mtu-sized wire-output scratch passed to sendNoMetrics and
// rejectInside / rejectOutside. Sized to fit any single wire packet.
Out []byte
// ip is the IP-payload region: a slice of len 0, cap linkMTU sliced
// from raw at offset prefixLen. The current packet (if any) is
// ip[:bodyN]. The TUN prefix slack lives at raw[0:prefixLen] just
// before ip.
ip []byte
// raw is the backing slab. Layout:
// [prefixLen bytes prefix slack | linkMTU bytes IP region | outSize bytes Out scratch]
// Holding it lets ReadIPFromTUN / WriteIPToTUN address the slack
// region directly.
raw []byte
prefixLen int
bodyN int
}
// NewWireBuffer returns a buffer sized to hold any single IP packet up to
// linkMTU, plus a disjoint wire-output scratch sliced from the same backing
// slab (the AEAD's Seal contract requires plaintext and dst not to partially
// overlap, and keeping them in one slab gives a single allocation per
// goroutine). Out is sized for the relay worst case
// (linkMTU + 2*header.Len + 2*AEADOverhead).
//
// prefixLen is the number of bytes the destination tun device prepends/
// expects on each IP packet (overlay.Device.TunPrefixLen). On BSDs this
// is 4 (AF_INET marker); on linux/windows/userspace devices it is 0.
func NewWireBuffer(linkMTU, prefixLen int) *WireBuffer {
outSize := linkMTU + 2*header.Len + 2*AEADOverhead
raw := make([]byte, prefixLen+linkMTU+outSize)
outStart := prefixLen + linkMTU
return &WireBuffer{
FwPacket: &firewall.Packet{},
NB: make([]byte, NonceSize),
H: &header.H{},
Out: raw[outStart : outStart : outStart+outSize],
ip: raw[prefixLen:prefixLen:outStart],
raw: raw,
prefixLen: prefixLen,
}
}
// Reset clears the body-length record so the buffer is ready for another
// recv (e.g. relay-receive recursion before a nested decrypt).
func (b *WireBuffer) Reset() { b.bodyN = 0 }
// IPPacket returns the IP packet currently held in the payload region (after
// a successful ReadIPFromTUN or DecryptDatagram). The slice aliases the
// buffer; do not retain past the next operation.
func (b *WireBuffer) IPPacket() []byte {
return b.ip[:b.bodyN]
}
// Seal stamps a nebula header at the front of buf.Out and AEAD-seals p as the
// payload, treating the header as additional authenticated data. The lock
// scope around counter increment + encrypt matches what goboring AESGCMTLS
// requires; non-boring builds skip the lock.
//
// Returns the wire bytes (header || ciphertext || tag), aliased to buf.Out.
// The slice is invalidated by the next Seal* call on this buffer.
func (b *WireBuffer) Seal(ci *ConnectionState, t header.MessageType, st header.MessageSubType, remoteIndex uint32, p []byte) ([]byte, error) {
return b.sealInto(b.Out[:cap(b.Out)], ci, t, st, remoteIndex, p)
}
// SealForRelay is like Seal but reserves header.Len bytes of slack at the front
// of buf.Out for an outer relay header. The inner header + ciphertext lands at
// offset header.Len so a follow-up SealRelayInPlace can stamp the outer header
// without copying. Use this when the caller may need to wrap the result in a
// relay envelope after the fact.
func (b *WireBuffer) SealForRelay(ci *ConnectionState, t header.MessageType, st header.MessageSubType, remoteIndex uint32, p []byte) ([]byte, error) {
return b.sealInto(b.Out[header.Len:cap(b.Out)], ci, t, st, remoteIndex, p)
}
func (b *WireBuffer) sealInto(out []byte, ci *ConnectionState, t header.MessageType, st header.MessageSubType, remoteIndex uint32, p []byte) ([]byte, error) {
if noiseutil.EncryptLockNeeded {
ci.writeLock.Lock()
}
c := ci.messageCounter.Add(1)
out = header.Encode(out, header.Version, t, st, remoteIndex, c)
out, err := ci.eKey.EncryptDanger(out, out, p, c, b.NB)
if noiseutil.EncryptLockNeeded {
ci.writeLock.Unlock()
}
return out, err
}
// SealRelayInPlace wraps an inner message that is already staged at
// buf.Out[header.Len:header.Len+innerLen] (either from a SealForRelay encrypt
// or from a copy via the SendVia entry point). It stamps the outer relay
// header into buf.Out[:header.Len] and AAD-only seals over the entire region,
// producing the wire bytes for the relay tunnel.
//
// Returns the wire bytes aliased to buf.Out; invalidated by the next Seal*
// call on this buffer.
func (b *WireBuffer) SealRelayInPlace(ci *ConnectionState, remoteIndex uint32, innerLen int) ([]byte, error) {
if noiseutil.EncryptLockNeeded {
ci.writeLock.Lock()
}
c := ci.messageCounter.Add(1)
out := b.Out[:cap(b.Out)]
out = header.Encode(out, header.Version, header.Message, header.MessageRelay, remoteIndex, c)
out = out[:header.Len+innerLen]
out, err := ci.eKey.EncryptDanger(out, out, nil, c, b.NB)
if noiseutil.EncryptLockNeeded {
ci.writeLock.Unlock()
}
return out, err
}
// StageRelayInner copies ad into the inner-payload slot at buf.Out[header.Len:]
// so SealRelayInPlace can wrap it on the next call. Used by SendVia when ad
// did not come from a prior SealForRelay (e.g. a handshake message being
// forwarded through a relay tunnel without our own encryption).
func (b *WireBuffer) StageRelayInner(ad []byte) int {
return copy(b.Out[header.Len:cap(b.Out)], ad)
}
// ReadIPFromTUN reads one IP packet from r into the payload region and
// updates bodyN. On BSDs the kernel writes its 4-byte protocol-family
// marker into the slack at raw[0:prefixLen] and the IP packet at
// raw[prefixLen:prefixLen+n]; we hand it the slack-prefixed slice so
// the kernel can do this in one syscall with no copy. On linux/windows/
// userspace devices prefixLen is 0 and the slack is empty.
func (b *WireBuffer) ReadIPFromTUN(r io.Reader) (int, error) {
n, err := r.Read(b.raw[:b.prefixLen+cap(b.ip)])
if err != nil {
b.bodyN = 0
return 0, err
}
if n < b.prefixLen {
b.bodyN = 0
return 0, nil
}
b.bodyN = n - b.prefixLen
return b.bodyN, nil
}
// WriteIPToTUN writes the IP packet currently in the payload region to w.
// On BSDs we stamp the protocol-family marker into the slack at
// raw[0:prefixLen] in place and write the entire slack+IP region in a
// single syscall, so the kernel sees [marker][ip] back to back without a
// userspace copy. On linux/windows/userspace devices the slack is empty
// and we just write the IP region.
func (b *WireBuffer) WriteIPToTUN(w io.Writer) (int, error) {
out := b.raw[:b.prefixLen+b.bodyN]
if b.prefixLen > 0 {
if err := overlay.StampTunPrefix(out); err != nil {
return 0, err
}
}
return w.Write(out)
}
// DecryptDatagram decrypts an inbound UDP packet into the payload region.
func (b *WireBuffer) DecryptDatagram(ci *ConnectionState, packet []byte, mc uint64) error {
dst, err := ci.dKey.DecryptDanger(b.ip[:0], packet[:header.Len], packet[header.Len:], mc, b.NB)
if err != nil {
b.bodyN = 0
return err
}
b.bodyN = len(dst)
return nil
}
// DecryptForHandler decrypts an inbound UDP packet (lighthouse, test,
// control, close-tunnel) into the payload region and returns the plaintext
// slice for the in-process handler. Returned slice aliases the buffer.
func (b *WireBuffer) DecryptForHandler(ci *ConnectionState, packet []byte, mc uint64) ([]byte, error) {
dst, err := ci.dKey.DecryptDanger(b.ip[:0], packet[:header.Len], packet[header.Len:], mc, b.NB)
if err != nil {
b.bodyN = 0
return nil, err
}
b.bodyN = len(dst)
return dst, nil
}
// WireBufferAllocator hands out reusable WireBuffers for cold callers that
// don't own a long-lived per-goroutine buffer (control plane, relay manager,
// connection manager teardown, etc.). Hot-path goroutines hold their own
// buffer for the life of the goroutine and don't need to acquire one.
type WireBufferAllocator interface {
Acquire() *WireBuffer
Release(*WireBuffer)
}
// wireBufferPool is a sync.Pool-backed WireBufferAllocator. The pool is
// keyed off a single linkMTU and prefixLen; cold callers send across the
// data-plane mtu and target the same Device, so we size the pool's
// buffers the same way.
type wireBufferPool struct {
pool sync.Pool
}
func NewWireBufferPool(linkMTU, prefixLen int) *wireBufferPool {
return &wireBufferPool{
pool: sync.Pool{
New: func() any {
return NewWireBuffer(linkMTU, prefixLen)
},
},
}
}
func (p *wireBufferPool) Acquire() *WireBuffer {
return p.pool.Get().(*WireBuffer)
}
func (p *wireBufferPool) Release(b *WireBuffer) {
b.Reset()
p.pool.Put(b)
}