diff --git a/connection_manager.go b/connection_manager.go index e7fc04cd..4b0b8896 100644 --- a/connection_manager.go +++ b/connection_manager.go @@ -153,8 +153,8 @@ func (cm *connectionManager) Start(ctx context.Context) { defer clockSource.Stop() p := []byte("") - nb := make([]byte, 12, 12) - out := make([]byte, mtu) + // Long-lived buf for the traffic-check goroutine; never released. + buf := cm.intf.bufAlloc.Acquire() for { select { @@ -169,13 +169,13 @@ func (cm *connectionManager) Start(ctx context.Context) { break } - cm.doTrafficCheck(localIndex, p, nb, out, now) + cm.doTrafficCheck(localIndex, p, buf, now) } } } } -func (cm *connectionManager) doTrafficCheck(localIndex uint32, p, nb, out []byte, now time.Time) { +func (cm *connectionManager) doTrafficCheck(localIndex uint32, p []byte, buf *WireBuffer, now time.Time) { decision, hostinfo, primary := cm.makeTrafficDecision(localIndex, now) switch decision { @@ -199,7 +199,7 @@ func (cm *connectionManager) doTrafficCheck(localIndex uint32, p, nb, out []byte cm.tryRehandshake(hostinfo) case sendTestPacket: - cm.intf.SendMessageToHostInfo(header.Test, header.TestRequest, hostinfo, p, nb, out) + cm.intf.SendMessageToHostInfo(header.Test, header.TestRequest, hostinfo, p, buf) } cm.resetRelayTrafficCheck(hostinfo) @@ -308,7 +308,9 @@ func (cm *connectionManager) migrateRelayUsed(oldhostinfo, newhostinfo *HostInfo if err != nil { cm.l.Error("failed to marshal Control message to migrate relay", "error", err) } else { - cm.intf.SendMessageToHostInfo(header.Control, 0, newhostinfo, msg, make([]byte, 12), make([]byte, mtu)) + migBuf := cm.intf.bufAlloc.Acquire() + cm.intf.SendMessageToHostInfo(header.Control, 0, newhostinfo, msg, migBuf) + cm.intf.bufAlloc.Release(migBuf) cm.l.Info("send CreateRelayRequest", "relayFrom", req.RelayFromAddr, "relayTo", req.RelayToAddr, diff --git a/connection_manager_test.go b/connection_manager_test.go index 7dc08a45..3fe22dd6 100644 --- a/connection_manager_test.go +++ b/connection_manager_test.go @@ -67,9 +67,9 @@ func Test_NewConnectionManagerTest(t *testing.T) { punchy := NewPunchyFromConfig(test.NewLogger(), conf) nc := newConnectionManagerFromConfig(test.NewLogger(), conf, hostMap, punchy) nc.intf = ifce + p := []byte("") - nb := make([]byte, 12, 12) - out := make([]byte, mtu) + buf := NewWireBuffer(mtu, 0) // Add an ip we have established a connection w/ to hostmap hostinfo := &HostInfo{ @@ -92,7 +92,7 @@ func Test_NewConnectionManagerTest(t *testing.T) { assert.True(t, hostinfo.in.Load()) // Do a traffic check tick, should not be pending deletion but should not have any in/out packets recorded - nc.doTrafficCheck(hostinfo.localIndexId, p, nb, out, time.Now()) + nc.doTrafficCheck(hostinfo.localIndexId, p, buf, time.Now()) assert.False(t, hostinfo.pendingDeletion.Load()) assert.False(t, hostinfo.out.Load()) assert.False(t, hostinfo.in.Load()) @@ -100,7 +100,7 @@ func Test_NewConnectionManagerTest(t *testing.T) { // Do another traffic check tick, this host should be pending deletion now nc.Out(hostinfo) assert.True(t, hostinfo.out.Load()) - nc.doTrafficCheck(hostinfo.localIndexId, p, nb, out, time.Now()) + nc.doTrafficCheck(hostinfo.localIndexId, p, buf, time.Now()) assert.True(t, hostinfo.pendingDeletion.Load()) assert.False(t, hostinfo.out.Load()) assert.False(t, hostinfo.in.Load()) @@ -108,7 +108,7 @@ func Test_NewConnectionManagerTest(t *testing.T) { assert.Contains(t, nc.hostMap.Hosts, hostinfo.vpnAddrs[0]) // Do a final traffic check tick, the host should now be removed - nc.doTrafficCheck(hostinfo.localIndexId, p, nb, out, time.Now()) + nc.doTrafficCheck(hostinfo.localIndexId, p, buf, time.Now()) assert.NotContains(t, nc.hostMap.Hosts, hostinfo.vpnAddrs) assert.NotContains(t, nc.hostMap.Indexes, hostinfo.localIndexId) } @@ -149,9 +149,9 @@ func Test_NewConnectionManagerTest2(t *testing.T) { punchy := NewPunchyFromConfig(test.NewLogger(), conf) nc := newConnectionManagerFromConfig(test.NewLogger(), conf, hostMap, punchy) nc.intf = ifce + p := []byte("") - nb := make([]byte, 12, 12) - out := make([]byte, mtu) + buf := NewWireBuffer(mtu, 0) // Add an ip we have established a connection w/ to hostmap hostinfo := &HostInfo{ @@ -174,14 +174,14 @@ func Test_NewConnectionManagerTest2(t *testing.T) { assert.Contains(t, nc.hostMap.Indexes, hostinfo.localIndexId) // Do a traffic check tick, should not be pending deletion but should not have any in/out packets recorded - nc.doTrafficCheck(hostinfo.localIndexId, p, nb, out, time.Now()) + nc.doTrafficCheck(hostinfo.localIndexId, p, buf, time.Now()) assert.False(t, hostinfo.pendingDeletion.Load()) assert.False(t, hostinfo.out.Load()) assert.False(t, hostinfo.in.Load()) // Do another traffic check tick, this host should be pending deletion now nc.Out(hostinfo) - nc.doTrafficCheck(hostinfo.localIndexId, p, nb, out, time.Now()) + nc.doTrafficCheck(hostinfo.localIndexId, p, buf, time.Now()) assert.True(t, hostinfo.pendingDeletion.Load()) assert.False(t, hostinfo.out.Load()) assert.False(t, hostinfo.in.Load()) @@ -190,7 +190,7 @@ func Test_NewConnectionManagerTest2(t *testing.T) { // We saw traffic, should no longer be pending deletion nc.In(hostinfo) - nc.doTrafficCheck(hostinfo.localIndexId, p, nb, out, time.Now()) + nc.doTrafficCheck(hostinfo.localIndexId, p, buf, time.Now()) assert.False(t, hostinfo.pendingDeletion.Load()) assert.False(t, hostinfo.out.Load()) assert.False(t, hostinfo.in.Load()) diff --git a/control.go b/control.go index ef58988b..7eac47e2 100644 --- a/control.go +++ b/control.go @@ -278,15 +278,9 @@ func (c *Control) CloseTunnel(vpnIp netip.Addr, localOnly bool) bool { } if !localOnly { - c.f.send( - header.CloseTunnel, - 0, - hostInfo.ConnectionState, - hostInfo, - []byte{}, - make([]byte, 12, 12), - make([]byte, mtu), - ) + buf := c.f.bufAlloc.Acquire() + c.f.send(header.CloseTunnel, 0, hostInfo.ConnectionState, hostInfo, []byte{}, buf) + c.f.bufAlloc.Release(buf) } c.f.closeTunnel(hostInfo) @@ -296,11 +290,14 @@ func (c *Control) CloseTunnel(vpnIp netip.Addr, localOnly bool) bool { // CloseAllTunnels is just like CloseTunnel except it goes through and shuts them all down, optionally you can avoid shutting down lighthouse tunnels // the int returned is a count of tunnels closed func (c *Control) CloseAllTunnels(excludeLighthouses bool) (closed int) { + // One WireBuffer for the whole shutdown loop. + buf := c.f.bufAlloc.Acquire() + defer c.f.bufAlloc.Release(buf) shutdown := func(h *HostInfo) { if excludeLighthouses && c.f.lightHouse.IsAnyLighthouseAddr(h.vpnAddrs) { return } - c.f.send(header.CloseTunnel, 0, h.ConnectionState, h, []byte{}, make([]byte, 12, 12), make([]byte, mtu)) + c.f.send(header.CloseTunnel, 0, h.ConnectionState, h, []byte{}, buf) c.f.closeTunnel(h) c.l.Debug("Sending close tunnel message", diff --git a/e2e/bench_test.go b/e2e/bench_test.go new file mode 100644 index 00000000..421e5e29 --- /dev/null +++ b/e2e/bench_test.go @@ -0,0 +1,68 @@ +//go:build e2e_testing +// +build e2e_testing + +package e2e + +import ( + "testing" + "time" + + "github.com/slackhq/nebula/cert" + "github.com/slackhq/nebula/cert_test" + "github.com/slackhq/nebula/e2e/router" +) + +// BenchmarkHandshake measures end-to-end tunnel establishment time. The two +// nodes and the router are constructed once before the loop so the timed window +// is just the handshake itself: trigger packet -> handshake1 -> handshake2 -> +// cached packet replay -> arrival on the remote TUN. Between iterations we +// tear down both sides locally (no CloseTunnel notification on the wire) and +// re-inject the lighthouse address that closeTunnel cleared, so the next +// iteration runs through a fresh handshake against the same harness. +func BenchmarkHandshake(b *testing.B) { + ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version2, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{}) + + // Default try_interval is 100ms. The handshake manager schedules handshake1 + // on its OutboundHandshakeTimer rather than firing immediately on trigger + // (the trigger channel only fast-paths static hosts), so a 100ms default + // drowns the actual handshake cost. Drop it to 1ms so the bench reflects + // the computation, not the wheel cadence. + bovr := m{"handshakes": m{"try_interval": "1ms"}} + myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(cert.Version2, ca, caKey, "me", "10.128.0.1/24", bovr) + theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(cert.Version2, ca, caKey, "them", "10.128.0.2/24", bovr) + + myControl.InjectLightHouseAddr(theirVpnIpNet[0].Addr(), theirUdpAddr) + theirControl.InjectLightHouseAddr(myVpnIpNet[0].Addr(), myUdpAddr) + + myControl.Start() + theirControl.Start() + defer myControl.Stop() + defer theirControl.Stop() + + r := router.NewR(b, myControl, theirControl) + r.CancelFlowLogs() + r.EnableFanIn() + + trigger := BuildTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi from me")) + + b.ReportAllocs() + b.ResetTimer() + + for n := 0; n < b.N; n++ { + myControl.InjectTunPacket(trigger) + // RouteForAllUntilTxTun returns the moment the cached packet arrives at + // the remote TUN, which is also when both sides are fully established. + _ = r.RouteForAllUntilTxTun(theirControl) + + b.StopTimer() + // Local-only close removes hostmap state on both sides without putting a + // CloseTunnel packet on the wire that we'd then have to drain. The + // closeTunnel path also clears learned lighthouse state for the peer + // when the last hostinfo for that addr goes away, so we re-inject. + myControl.CloseTunnel(theirVpnIpNet[0].Addr(), true) + theirControl.CloseTunnel(myVpnIpNet[0].Addr(), true) + myControl.InjectLightHouseAddr(theirVpnIpNet[0].Addr(), theirUdpAddr) + theirControl.InjectLightHouseAddr(myVpnIpNet[0].Addr(), myUdpAddr) + b.StartTimer() + } +} diff --git a/e2e/handshakes_test.go b/e2e/handshakes_test.go index d0b9543c..afad12c8 100644 --- a/e2e/handshakes_test.go +++ b/e2e/handshakes_test.go @@ -165,7 +165,7 @@ func TestGoodHandshakeNoOverlap(t *testing.T) { empty := []byte{} t.Log("do something to cause a handshake") - myControl.GetF().SendMessageToVpnAddr(header.Test, header.MessageNone, theirVpnIpNet[0].Addr(), empty, empty, empty) + myControl.GetF().SendMessageToVpnAddr(header.Test, header.MessageNone, theirVpnIpNet[0].Addr(), empty, nebula.NewWireBuffer(9001, 0)) t.Log("Have them consume my stage 0 packet. They have a tunnel now") theirControl.InjectUDPPacket(myControl.GetFromUDP(true)) diff --git a/handshake_manager.go b/handshake_manager.go index 87257028..f8f32043 100644 --- a/handshake_manager.go +++ b/handshake_manager.go @@ -971,11 +971,11 @@ func (hm *HandshakeManager) continueHandshake(via ViaSender, hh *HandshakeHostIn if f.l.Enabled(context.Background(), slog.LevelDebug) { hostinfo.logger(f.l).Debug("Sending stored packets", "count", len(hh.packetStore)) } - nb := make([]byte, 12, 12) - out := make([]byte, mtu) + buf := f.bufAlloc.Acquire() for _, cp := range hh.packetStore { - cp.callback(cp.messageType, cp.messageSubType, hostinfo, cp.packet, nb, out) + cp.callback(cp.messageType, cp.messageSubType, hostinfo, cp.packet, buf) } + f.bufAlloc.Release(buf) f.cachedPacketMetrics.sent.Inc(int64(len(hh.packetStore))) } @@ -1085,7 +1085,9 @@ func (hm *HandshakeManager) sendHandshakeResponse(via ViaSender, msg []byte, hos // We received a valid handshake on this relay, so make sure the relay // state reflects that, in case it had been marked Disestablished. via.relayHI.relayState.UpdateRelayForByIdxState(via.remoteIdx, Established) - f.SendVia(via.relayHI, via.relay, msg, make([]byte, 12), make([]byte, mtu), false) + buf := f.bufAlloc.Acquire() + f.SendVia(via.relayHI, via.relay, msg, buf) + f.bufAlloc.Release(buf) f.l.Info("Handshake message sent", append(logFields, "relay", via.relayHI.vpnAddrs[0])...) } } @@ -1102,7 +1104,9 @@ func (hm *HandshakeManager) handleCheckAndCompleteError(err error, existing, hos switch err { case ErrAlreadySeen: if existing.SetRemoteIfPreferred(f.hostMap, via) { - f.SendMessageToVpnAddr(header.Test, header.TestRequest, hostinfo.vpnAddrs[0], []byte(""), make([]byte, 12, 12), make([]byte, mtu)) + buf := f.bufAlloc.Acquire() + f.SendMessageToVpnAddr(header.Test, header.TestRequest, hostinfo.vpnAddrs[0], []byte(""), buf) + f.bufAlloc.Release(buf) } // Resend the original response. The peer is committed to that response's // ephemeral keys; a freshly-built one would have different keys and break @@ -1125,7 +1129,9 @@ func (hm *HandshakeManager) handleCheckAndCompleteError(err error, existing, hos "responderIndex", hostinfo.localIndexId, "handshake", hsFields, ) - f.SendMessageToVpnAddr(header.Test, header.TestRequest, hostinfo.vpnAddrs[0], []byte(""), make([]byte, 12, 12), make([]byte, mtu)) + buf := f.bufAlloc.Acquire() + f.SendMessageToVpnAddr(header.Test, header.TestRequest, hostinfo.vpnAddrs[0], []byte(""), buf) + f.bufAlloc.Release(buf) case ErrLocalIndexCollision: f.l.Error("Failed to add HostInfo due to localIndex collision", diff --git a/handshake_manager_test.go b/handshake_manager_test.go index 5f8383e4..f6283342 100644 --- a/handshake_manager_test.go +++ b/handshake_manager_test.go @@ -80,15 +80,15 @@ func testCountTimerWheelEntries(tw *LockingTimerWheel[netip.Addr]) (c int) { type mockEncWriter struct { } -func (mw *mockEncWriter) SendMessageToVpnAddr(_ header.MessageType, _ header.MessageSubType, _ netip.Addr, _, _, _ []byte) { +func (mw *mockEncWriter) SendMessageToVpnAddr(_ header.MessageType, _ header.MessageSubType, _ netip.Addr, _ []byte, _ *WireBuffer) { return } -func (mw *mockEncWriter) SendVia(_ *HostInfo, _ *Relay, _, _, _ []byte, _ bool) { +func (mw *mockEncWriter) SendVia(_ *HostInfo, _ *Relay, _ []byte, _ *WireBuffer) { return } -func (mw *mockEncWriter) SendMessageToHostInfo(_ header.MessageType, _ header.MessageSubType, _ *HostInfo, _, _, _ []byte) { +func (mw *mockEncWriter) SendMessageToHostInfo(_ header.MessageType, _ header.MessageSubType, _ *HostInfo, _ []byte, _ *WireBuffer) { return } diff --git a/hostmap.go b/hostmap.go index 08acd1be..4e301181 100644 --- a/hostmap.go +++ b/hostmap.go @@ -308,7 +308,7 @@ type cachedPacket struct { packet []byte } -type packetCallback func(t header.MessageType, st header.MessageSubType, h *HostInfo, p, nb, out []byte) +type packetCallback func(t header.MessageType, st header.MessageSubType, h *HostInfo, p []byte, buf *WireBuffer) type cachedPacketMetrics struct { sent metrics.Counter @@ -691,6 +691,7 @@ func (i *HostInfo) TryPromoteBest(preferredRanges []netip.Prefix, ifce *Interfac } } + buf := ifce.bufAlloc.Acquire() i.remotes.ForEach(preferredRanges, func(addr netip.AddrPort, preferred bool) { if remote.IsValid() && (!addr.IsValid() || !preferred) { return @@ -698,8 +699,9 @@ func (i *HostInfo) TryPromoteBest(preferredRanges []netip.Prefix, ifce *Interfac // Try to send a test packet to that host, this should // cause it to detect a roaming event and switch remotes - ifce.sendTo(header.Test, header.TestRequest, i.ConnectionState, i, addr, []byte(""), make([]byte, 12, 12), make([]byte, mtu)) + ifce.sendTo(header.Test, header.TestRequest, i.ConnectionState, i, addr, []byte(""), buf) }) + ifce.bufAlloc.Release(buf) } // Re query our lighthouses for new remotes occasionally diff --git a/inside.go b/inside.go index 68cb38ec..abdd6f4d 100644 --- a/inside.go +++ b/inside.go @@ -8,12 +8,13 @@ import ( "github.com/slackhq/nebula/firewall" "github.com/slackhq/nebula/header" "github.com/slackhq/nebula/iputil" - "github.com/slackhq/nebula/noiseutil" "github.com/slackhq/nebula/routing" ) -func (f *Interface) consumeInsidePacket(packet []byte, fwPacket *firewall.Packet, nb, out []byte, q int, localCache firewall.ConntrackCache) { - err := newPacket(packet, false, fwPacket) +func (f *Interface) consumeInsidePacket(buf *WireBuffer, q int, localCache firewall.ConntrackCache) { + packet := buf.IPPacket() + + err := newPacket(packet, false, buf.FwPacket) if err != nil { if f.l.Enabled(context.Background(), slog.LevelDebug) { f.l.Debug("Error while validating outbound packet", @@ -26,12 +27,12 @@ func (f *Interface) consumeInsidePacket(packet []byte, fwPacket *firewall.Packet // Ignore local broadcast packets if f.dropLocalBroadcast { - if f.myBroadcastAddrsTable.Contains(fwPacket.RemoteAddr) { + if f.myBroadcastAddrsTable.Contains(buf.FwPacket.RemoteAddr) { return } } - if f.myVpnAddrsTable.Contains(fwPacket.RemoteAddr) { + if f.myVpnAddrsTable.Contains(buf.FwPacket.RemoteAddr) { // Immediately forward packets from self to self. // This should only happen on Darwin-based and FreeBSD hosts, which // routes packets from the Nebula addr to the Nebula addr through the Nebula @@ -48,20 +49,20 @@ func (f *Interface) consumeInsidePacket(packet []byte, fwPacket *firewall.Packet } // Ignore multicast packets - if f.dropMulticast && fwPacket.RemoteAddr.IsMulticast() { + if f.dropMulticast && buf.FwPacket.RemoteAddr.IsMulticast() { return } - hostinfo, ready := f.getOrHandshakeConsiderRouting(fwPacket, func(hh *HandshakeHostInfo) { + hostinfo, ready := f.getOrHandshakeConsiderRouting(buf.FwPacket, func(hh *HandshakeHostInfo) { hh.cachePacket(f.l, header.Message, 0, packet, f.sendMessageNow, f.cachedPacketMetrics) }) if hostinfo == nil { - f.rejectInside(packet, out, q) + f.rejectInside(packet, buf.Out, q) if f.l.Enabled(context.Background(), slog.LevelDebug) { f.l.Debug("dropping outbound packet, vpnAddr not in our vpn networks or in unsafe networks", - "vpnAddr", fwPacket.RemoteAddr, - "fwPacket", fwPacket, + "vpnAddr", buf.FwPacket.RemoteAddr, + "fwPacket", buf.FwPacket, ) } return @@ -71,15 +72,15 @@ func (f *Interface) consumeInsidePacket(packet []byte, fwPacket *firewall.Packet return } - dropReason := f.firewall.Drop(*fwPacket, false, hostinfo, f.pki.GetCAPool(), localCache) + dropReason := f.firewall.Drop(*buf.FwPacket, false, hostinfo, f.pki.GetCAPool(), localCache) if dropReason == nil { - f.sendNoMetrics(header.Message, 0, hostinfo.ConnectionState, hostinfo, netip.AddrPort{}, packet, nb, out, q) + f.sendNoMetrics(header.Message, 0, hostinfo.ConnectionState, hostinfo, netip.AddrPort{}, packet, buf, q) } else { - f.rejectInside(packet, out, q) + f.rejectInside(packet, buf.Out, q) if f.l.Enabled(context.Background(), slog.LevelDebug) { hostinfo.logger(f.l).Debug("dropping outbound packet", - "fwPacket", fwPacket, + "fwPacket", buf.FwPacket, "reason", dropReason, ) } @@ -102,27 +103,27 @@ func (f *Interface) rejectInside(packet []byte, out []byte, q int) { } } -func (f *Interface) rejectOutside(packet []byte, ci *ConnectionState, hostinfo *HostInfo, nb, out []byte, q int) { +func (f *Interface) rejectOutside(packet []byte, ci *ConnectionState, hostinfo *HostInfo, scratch []byte, buf *WireBuffer, q int) { if !f.firewall.OutSendReject { return } - out = iputil.CreateRejectPacket(packet, out) - if len(out) == 0 { + rejectIP := iputil.CreateRejectPacket(packet, scratch) + if len(rejectIP) == 0 { return } - if len(out) > iputil.MaxRejectPacketSize { + if len(rejectIP) > iputil.MaxRejectPacketSize { if f.l.Enabled(context.Background(), slog.LevelInfo) { f.l.Info("rejectOutside: packet too big, not sending", "packet", packet, - "outPacket", out, + "outPacket", rejectIP, ) } return } - f.sendNoMetrics(header.Message, 0, ci, hostinfo, netip.AddrPort{}, out, nb, packet, q) + f.sendNoMetrics(header.Message, 0, ci, hostinfo, netip.AddrPort{}, rejectIP, buf, q) } // Handshake will attempt to initiate a tunnel with the provided vpn address. This is a no-op if the tunnel is already established or being established @@ -215,7 +216,7 @@ func (f *Interface) getOrHandshakeConsiderRouting(fwPacket *firewall.Packet, cac } -func (f *Interface) sendMessageNow(t header.MessageType, st header.MessageSubType, hostinfo *HostInfo, p, nb, out []byte) { +func (f *Interface) sendMessageNow(t header.MessageType, st header.MessageSubType, hostinfo *HostInfo, p []byte, buf *WireBuffer) { fp := &firewall.Packet{} err := newPacket(p, false, fp) if err != nil { @@ -235,12 +236,12 @@ func (f *Interface) sendMessageNow(t header.MessageType, st header.MessageSubTyp return } - f.sendNoMetrics(header.Message, st, hostinfo.ConnectionState, hostinfo, netip.AddrPort{}, p, nb, out, 0) + f.sendNoMetrics(header.Message, st, hostinfo.ConnectionState, hostinfo, netip.AddrPort{}, p, buf, 0) } // SendMessageToVpnAddr handles real addr:port lookup and sends to the current best known address for vpnAddr. // This function ignores myVpnNetworksTable, and will always attempt to treat the address as a vpnAddr -func (f *Interface) SendMessageToVpnAddr(t header.MessageType, st header.MessageSubType, vpnAddr netip.Addr, p, nb, out []byte) { +func (f *Interface) SendMessageToVpnAddr(t header.MessageType, st header.MessageSubType, vpnAddr netip.Addr, p []byte, buf *WireBuffer) { hostInfo, ready := f.handshakeManager.GetOrHandshake(vpnAddr, func(hh *HandshakeHostInfo) { hh.cachePacket(f.l, t, st, p, f.SendMessageToHostInfo, f.cachedPacketMetrics) }) @@ -258,113 +259,73 @@ func (f *Interface) SendMessageToVpnAddr(t header.MessageType, st header.Message return } - f.SendMessageToHostInfo(t, st, hostInfo, p, nb, out) + f.SendMessageToHostInfo(t, st, hostInfo, p, buf) } -func (f *Interface) SendMessageToHostInfo(t header.MessageType, st header.MessageSubType, hi *HostInfo, p, nb, out []byte) { - f.send(t, st, hi.ConnectionState, hi, p, nb, out) +func (f *Interface) SendMessageToHostInfo(t header.MessageType, st header.MessageSubType, hi *HostInfo, p []byte, buf *WireBuffer) { + f.send(t, st, hi.ConnectionState, hi, p, buf) } -func (f *Interface) send(t header.MessageType, st header.MessageSubType, ci *ConnectionState, hostinfo *HostInfo, p, nb, out []byte) { +func (f *Interface) send(t header.MessageType, st header.MessageSubType, ci *ConnectionState, hostinfo *HostInfo, p []byte, buf *WireBuffer) { f.messageMetrics.Tx(t, st, 1) - f.sendNoMetrics(t, st, ci, hostinfo, netip.AddrPort{}, p, nb, out, 0) + f.sendNoMetrics(t, st, ci, hostinfo, netip.AddrPort{}, p, buf, 0) } -func (f *Interface) sendTo(t header.MessageType, st header.MessageSubType, ci *ConnectionState, hostinfo *HostInfo, remote netip.AddrPort, p, nb, out []byte) { +func (f *Interface) sendTo(t header.MessageType, st header.MessageSubType, ci *ConnectionState, hostinfo *HostInfo, remote netip.AddrPort, p []byte, buf *WireBuffer) { f.messageMetrics.Tx(t, st, 1) - f.sendNoMetrics(t, st, ci, hostinfo, remote, p, nb, out, 0) + f.sendNoMetrics(t, st, ci, hostinfo, remote, p, buf, 0) } // SendVia sends a payload through a Relay tunnel. No authentication or encryption is done // to the payload for the ultimate target host, making this a useful method for sending // handshake messages to peers through relay tunnels. -// via is the HostInfo through which the message is relayed. -// ad is the plaintext data to authenticate, but not encrypt -// nb is a buffer used to store the nonce value, re-used for performance reasons. -// out is a buffer used to store the result of the Encrypt operation -// q indicates which writer to use to send the packet. -func (f *Interface) SendVia(via *HostInfo, - relay *Relay, - ad, - nb, - out []byte, - nocopy bool, -) { - if noiseutil.EncryptLockNeeded { - // NOTE: for goboring AESGCMTLS we need to lock because of the nonce check - via.ConnectionState.writeLock.Lock() - } - c := via.ConnectionState.messageCounter.Add(1) - - out = header.Encode(out, header.Version, header.Message, header.MessageRelay, relay.RemoteIndex, c) - f.connectionManager.Out(via) - - // Authenticate the header and payload, but do not encrypt for this message type. - // The payload consists of the inner, unencrypted Nebula header, as well as the end-to-end encrypted payload. - if len(out)+len(ad)+via.ConnectionState.eKey.Overhead() > cap(out) { - if noiseutil.EncryptLockNeeded { - via.ConnectionState.writeLock.Unlock() - } +// +// via is the HostInfo through which the message is relayed. ad is staged into +// the inner-payload slot of buf and then AAD-only sealed under via's key by +// SealRelayInPlace. The sendNoMetrics relay-forward path skips this entry +// point and calls sendViaInPlace directly because its inner ciphertext is +// already in place from the encrypt step. +func (f *Interface) SendVia(via *HostInfo, relay *Relay, ad []byte, buf *WireBuffer) { + if header.Len+len(ad)+via.ConnectionState.eKey.Overhead() > cap(buf.Out) { via.logger(f.l).Error("SendVia out buffer not large enough for relay", - "outCap", cap(out), + "outCap", cap(buf.Out), "payloadLen", len(ad), - "headerLen", len(out), + "headerLen", header.Len, "cipherOverhead", via.ConnectionState.eKey.Overhead(), ) return } + buf.StageRelayInner(ad) + f.sendViaInPlace(via, relay, len(ad), buf) +} - // The header bytes are written to the 'out' slice; Grow the slice to hold the header and associated data payload. - offset := len(out) - out = out[:offset+len(ad)] - - // In one call path, the associated data _is_ already stored in out. In other call paths, the associated data must - // be copied into 'out'. - if !nocopy { - copy(out[offset:], ad) - } - - var err error - out, err = via.ConnectionState.eKey.EncryptDanger(out, out, nil, c, nb) - if noiseutil.EncryptLockNeeded { - via.ConnectionState.writeLock.Unlock() - } +// sendViaInPlace stamps the outer relay header, AAD-seals over the [outer +// header | inner-already-staged] region, and writes the result to via.remote. +// Called from SendVia (after staging ad) and from sendNoMetrics' relay-forward +// path (where the inner ciphertext is already in place from SealForRelay). +func (f *Interface) sendViaInPlace(via *HostInfo, relay *Relay, innerLen int, buf *WireBuffer) { + f.connectionManager.Out(via) + out, err := buf.SealRelayInPlace(via.ConnectionState, relay.RemoteIndex, innerLen) if err != nil { via.logger(f.l).Info("Failed to EncryptDanger in sendVia", "error", err) return } - err = f.writers[0].WriteTo(out, via.remote) - if err != nil { + if err := f.writers[0].WriteTo(out, via.remote); err != nil { via.logger(f.l).Info("Failed to WriteTo in sendVia", "error", err) } f.connectionManager.RelayUsed(relay.LocalIndex) } -func (f *Interface) sendNoMetrics(t header.MessageType, st header.MessageSubType, ci *ConnectionState, hostinfo *HostInfo, remote netip.AddrPort, p, nb, out []byte, q int) { +// sendNoMetrics encrypts and writes one outbound nebula packet (data, control, +// lighthouse, etc) using buf as the per-call wire scratch. When the hostinfo +// has no direct remote we encrypt into the relay-reserved slot via +// SealForRelay so sendViaInPlace can wrap it without an extra copy. +func (f *Interface) sendNoMetrics(t header.MessageType, st header.MessageSubType, ci *ConnectionState, hostinfo *HostInfo, remote netip.AddrPort, p []byte, buf *WireBuffer, q int) { if ci.eKey == nil { return } useRelay := !remote.IsValid() && !hostinfo.remote.IsValid() - fullOut := out - if useRelay { - if len(out) < header.Len { - // out always has a capacity of mtu, but not always a length greater than the header.Len. - // Grow it to make sure the next operation works. - out = out[:header.Len] - } - // Save a header's worth of data at the front of the 'out' buffer. - out = out[header.Len:] - } - - if noiseutil.EncryptLockNeeded { - // NOTE: for goboring AESGCMTLS we need to lock because of the nonce check - ci.writeLock.Lock() - } - c := ci.messageCounter.Add(1) - - //l.WithField("trace", string(debug.Stack())).Error("out Header ", &Header{Version, t, st, 0, hostinfo.remoteIndexId, c}, p) - out = header.Encode(out, header.Version, t, st, hostinfo.remoteIndexId, c) f.connectionManager.Out(hostinfo) // Query our LH if we haven't since the last time we've been rebound, this will cause the remote to punch against @@ -381,50 +342,42 @@ func (f *Interface) sendNoMetrics(t header.MessageType, st header.MessageSubType } } + var out []byte var err error - out, err = ci.eKey.EncryptDanger(out, out, p, c, nb) - if noiseutil.EncryptLockNeeded { - ci.writeLock.Unlock() + if useRelay { + out, err = buf.SealForRelay(ci, t, st, hostinfo.remoteIndexId, p) + } else { + out, err = buf.Seal(ci, t, st, hostinfo.remoteIndexId, p) } if err != nil { hostinfo.logger(f.l).Error("Failed to encrypt outgoing packet", "error", err, "udpAddr", remote, - "counter", c, - "attemptedCounter", c, ) return } - if remote.IsValid() { - err = f.writers[q].WriteTo(out, remote) - if err != nil { - hostinfo.logger(f.l).Error("Failed to write outgoing packet", - "error", err, - "udpAddr", remote, - ) + switch { + case remote.IsValid(): + if err := f.writers[q].WriteTo(out, remote); err != nil { + hostinfo.logger(f.l).Error("Failed to write outgoing packet", "error", err, "udpAddr", remote) } - } else if hostinfo.remote.IsValid() { - err = f.writers[q].WriteTo(out, hostinfo.remote) - if err != nil { - hostinfo.logger(f.l).Error("Failed to write outgoing packet", - "error", err, - "udpAddr", remote, - ) + case hostinfo.remote.IsValid(): + if err := f.writers[q].WriteTo(out, hostinfo.remote); err != nil { + hostinfo.logger(f.l).Error("Failed to write outgoing packet", "error", err, "udpAddr", hostinfo.remote) } - } else { - // Try to send via a relay + default: + // SealForRelay placed the inner ciphertext at buf.Out[header.Len:], + // so sendViaInPlace can wrap it with the outer relay header without + // an extra copy. for _, relayIP := range hostinfo.relayState.CopyRelayIps() { relayHostInfo, relay, err := f.hostMap.QueryVpnAddrsRelayFor(hostinfo.vpnAddrs, relayIP) if err != nil { hostinfo.relayState.DeleteRelay(relayIP) - hostinfo.logger(f.l).Info("sendNoMetrics failed to find HostInfo", - "relay", relayIP, - "error", err, - ) + hostinfo.logger(f.l).Info("sendNoMetrics failed to find HostInfo", "relay", relayIP, "error", err) continue } - f.SendVia(relayHostInfo, relay, out, nb, fullOut[:header.Len+len(out)], true) + f.sendViaInPlace(relayHostInfo, relay, len(out), buf) break } } diff --git a/interface.go b/interface.go index 5fedcdd3..1e18436b 100644 --- a/interface.go +++ b/interface.go @@ -101,19 +101,19 @@ type Interface struct { messageMetrics *MessageMetrics cachedPacketMetrics *cachedPacketMetrics + // bufAlloc hands out reusable WireBuffers sized for this interface's + // inside Device. All buf consumers (hot-path data-plane goroutines, + // long-lived workers, and cold callers) acquire from here so sizing + // is centralized and consistent. Long-lived owners just don't release. + bufAlloc WireBufferAllocator + l *slog.Logger } type EncWriter interface { - SendVia(via *HostInfo, - relay *Relay, - ad, - nb, - out []byte, - nocopy bool, - ) - SendMessageToVpnAddr(t header.MessageType, st header.MessageSubType, vpnAddr netip.Addr, p, nb, out []byte) - SendMessageToHostInfo(t header.MessageType, st header.MessageSubType, hostinfo *HostInfo, p, nb, out []byte) + SendVia(via *HostInfo, relay *Relay, ad []byte, buf *WireBuffer) + SendMessageToVpnAddr(t header.MessageType, st header.MessageSubType, vpnAddr netip.Addr, p []byte, buf *WireBuffer) + SendMessageToHostInfo(t header.MessageType, st header.MessageSubType, hostinfo *HostInfo, p []byte, buf *WireBuffer) Handshake(vpnAddr netip.Addr) GetHostInfo(vpnAddr netip.Addr) *HostInfo GetCertState() *CertState @@ -204,6 +204,8 @@ func NewInterface(ctx context.Context, c *InterfaceConfig) (*Interface, error) { dropped: metrics.GetOrRegisterCounter("hostinfo.cached_packets.dropped", nil), }, + bufAlloc: NewWireBufferPool(mtu, c.Inside.TunPrefixLen()), + l: c.l, } @@ -311,13 +313,11 @@ func (f *Interface) listenOut(i int) { ctCache := firewall.NewConntrackCacheTicker(f.ctx, f.l, f.conntrackCacheTimeout) lhh := f.lightHouse.NewRequestHandler() - plaintext := make([]byte, udp.MTU) - h := &header.H{} - fwPacket := &firewall.Packet{} - nb := make([]byte, 12, 12) + // Long-lived per-receive-goroutine buf; never released back to the pool. + buf := f.bufAlloc.Acquire() err := li.ListenOut(func(fromUdpAddr netip.AddrPort, payload []byte) { - f.readOutsidePackets(ViaSender{UdpAddr: fromUdpAddr}, plaintext[:0], payload, h, fwPacket, lhh, nb, i, ctCache.Get()) + f.readOutsidePackets(ViaSender{UdpAddr: fromUdpAddr}, buf, payload, lhh, i, ctCache.Get()) }) if err != nil && !f.closed.Load() { @@ -329,15 +329,12 @@ func (f *Interface) listenOut(i int) { } func (f *Interface) listenIn(reader io.ReadWriteCloser, i int) { - packet := make([]byte, mtu) - out := make([]byte, mtu) - fwPacket := &firewall.Packet{} - nb := make([]byte, 12, 12) - + // Long-lived per-tun-reader buf; never released back to the pool. + buf := f.bufAlloc.Acquire() conntrackCache := firewall.NewConntrackCacheTicker(f.ctx, f.l, f.conntrackCacheTimeout) for { - n, err := reader.Read(packet) + _, err := buf.ReadIPFromTUN(reader) if err != nil { if !f.closed.Load() { f.l.Error("Error while reading outbound packet, closing", "error", err, "reader", i) @@ -346,7 +343,7 @@ func (f *Interface) listenIn(reader io.ReadWriteCloser, i int) { break } - f.consumeInsidePacket(packet[:n], fwPacket, nb, out, i, conntrackCache.Get()) + f.consumeInsidePacket(buf, i, conntrackCache.Get()) } f.l.Debug("overlay reader is done", "reader", i) diff --git a/lighthouse.go b/lighthouse.go index 6034e68c..08484237 100644 --- a/lighthouse.go +++ b/lighthouse.go @@ -63,7 +63,11 @@ type LightHouse struct { interval atomic.Int64 updateCancel context.CancelFunc ifce EncWriter - nebulaPort uint32 // 32 bits because protobuf does not have a uint16 + // bufAlloc lets the lighthouse query/update workers, request handlers + // and punchback goroutines acquire correctly sized WireBuffers from + // the same pool as the data plane. Set by main.go alongside ifce. + bufAlloc WireBufferAllocator + nebulaPort uint32 // 32 bits because protobuf does not have a uint16 advertiseAddrs atomic.Pointer[[]netip.AddrPort] @@ -109,7 +113,11 @@ func NewLightHouseFromConfig(ctx context.Context, l *slog.Logger, c *config.C, c punchy: p, updateTrigger: make(chan struct{}, 1), queryChan: make(chan netip.Addr, c.GetUint32("handshakes.query_buffer", 64)), - l: l, + // Default to a no-prefix pool so the query/update workers and + // request handlers have a working WireBufferAllocator before + // main.go wires up the real one from the Interface. + bufAlloc: NewWireBufferPool(mtu, 0), + l: l, } lighthouses := make([]netip.Addr, 0) h.lighthouses.Store(&lighthouses) @@ -758,21 +766,22 @@ func (lh *LightHouse) startQueryWorker() { } go func() { - nb := make([]byte, 12, 12) - out := make([]byte, mtu) + // Long-lived per-worker WireBuffer; reused for every lighthouse query + // this worker issues for the life of the goroutine. + buf := lh.bufAlloc.Acquire() for { select { case <-lh.ctx.Done(): return case addr := <-lh.queryChan: - lh.innerQueryServer(addr, nb, out) + lh.innerQueryServer(addr, buf) } } }() } -func (lh *LightHouse) innerQueryServer(addr netip.Addr, nb, out []byte) { +func (lh *LightHouse) innerQueryServer(addr netip.Addr, buf *WireBuffer) { if lh.IsLighthouseAddr(addr) { return } @@ -821,7 +830,7 @@ func (lh *LightHouse) innerQueryServer(addr netip.Addr, nb, out []byte) { } } - lh.ifce.SendMessageToVpnAddr(header.LightHouse, 0, lhVpnAddr, v1Query, nb, out) + lh.ifce.SendMessageToVpnAddr(header.LightHouse, 0, lhVpnAddr, v1Query, buf) queried++ } else if v == cert.Version2 { @@ -840,7 +849,7 @@ func (lh *LightHouse) innerQueryServer(addr netip.Addr, nb, out []byte) { } } - lh.ifce.SendMessageToVpnAddr(header.LightHouse, 0, lhVpnAddr, v2Query, nb, out) + lh.ifce.SendMessageToVpnAddr(header.LightHouse, 0, lhVpnAddr, v2Query, buf) queried++ } else { @@ -869,8 +878,12 @@ func (lh *LightHouse) StartUpdateWorker() { go func() { defer clockSource.Stop() + // Long-lived per-worker WireBuffer; reused across every periodic + // update for the life of this goroutine. + buf := lh.bufAlloc.Acquire() + for { - lh.SendUpdate() + lh.sendUpdate(buf) select { case <-updateCtx.Done(): @@ -884,6 +897,15 @@ func (lh *LightHouse) StartUpdateWorker() { }() } +// SendUpdate is the public entry point that triggers a one-shot lighthouse +// update outside the worker loop (e.g. tests or reload paths). It allocates +// its own WireBuffer since callers don't already own one. +func (lh *LightHouse) SendUpdate() { + buf := lh.bufAlloc.Acquire() + defer lh.bufAlloc.Release(buf) + lh.sendUpdate(buf) +} + // TriggerUpdate requests an immediate lighthouse update. This is a non-blocking // operation intended to be called after a handshake completes with a lighthouse, // so the lighthouse has our current addresses without waiting for the next @@ -895,7 +917,7 @@ func (lh *LightHouse) TriggerUpdate() { } } -func (lh *LightHouse) SendUpdate() { +func (lh *LightHouse) sendUpdate(buf *WireBuffer) { var v4 []*V4AddrPort var v6 []*V6AddrPort @@ -921,9 +943,6 @@ func (lh *LightHouse) SendUpdate() { } } - nb := make([]byte, 12, 12) - out := make([]byte, mtu) - var v1Update, v2Update []byte var err error updated := 0 @@ -974,7 +993,7 @@ func (lh *LightHouse) SendUpdate() { } } - lh.ifce.SendMessageToVpnAddr(header.LightHouse, 0, lhVpnAddr, v1Update, nb, out) + lh.ifce.SendMessageToVpnAddr(header.LightHouse, 0, lhVpnAddr, v1Update, buf) updated++ } else if v == cert.Version2 { @@ -1003,7 +1022,7 @@ func (lh *LightHouse) SendUpdate() { } } - lh.ifce.SendMessageToVpnAddr(header.LightHouse, 0, lhVpnAddr, v2Update, nb, out) + lh.ifce.SendMessageToVpnAddr(header.LightHouse, 0, lhVpnAddr, v2Update, buf) updated++ } else { @@ -1019,9 +1038,11 @@ func (lh *LightHouse) SendUpdate() { } type LightHouseHandler struct { - lh *LightHouse - nb []byte - out []byte + lh *LightHouse + // buf is the long-lived per-handler wire scratch. NewRequestHandler is + // called once per data-plane receive goroutine, so buf is owned by that + // goroutine and reused for every lighthouse send the handler issues. + buf *WireBuffer pb []byte meta *NebulaMeta l *slog.Logger @@ -1030,8 +1051,7 @@ type LightHouseHandler struct { func (lh *LightHouse) NewRequestHandler() *LightHouseHandler { lhh := &LightHouseHandler{ lh: lh, - nb: make([]byte, 12, 12), - out: make([]byte, mtu), + buf: lh.bufAlloc.Acquire(), l: lh.l, pb: make([]byte, mtu), @@ -1168,7 +1188,7 @@ func (lhh *LightHouseHandler) handleHostQuery(n *NebulaMeta, fromVpnAddrs []neti } lhh.lh.metricTx(NebulaMeta_HostQueryReply, 1) - w.SendMessageToVpnAddr(header.LightHouse, 0, fromVpnAddrs[0], lhh.pb[:ln], lhh.nb, lhh.out[:0]) + w.SendMessageToVpnAddr(header.LightHouse, 0, fromVpnAddrs[0], lhh.pb[:ln], lhh.buf) lhh.sendHostPunchNotification(n, fromVpnAddrs, queryVpnAddr, w) } @@ -1228,7 +1248,7 @@ func (lhh *LightHouseHandler) sendHostPunchNotification(n *NebulaMeta, fromVpnAd } lhh.lh.metricTx(NebulaMeta_HostPunchNotification, 1) - w.SendMessageToVpnAddr(header.LightHouse, 0, punchNotifDest, lhh.pb[:ln], lhh.nb, lhh.out[:0]) + w.SendMessageToVpnAddr(header.LightHouse, 0, punchNotifDest, lhh.pb[:ln], lhh.buf) } func (lhh *LightHouseHandler) coalesceAnswers(v cert.Version, c *cache, n *NebulaMeta) { @@ -1385,7 +1405,7 @@ func (lhh *LightHouseHandler) handleHostUpdateNotification(n *NebulaMeta, fromVp } lhh.lh.metricTx(NebulaMeta_HostUpdateNotificationAck, 1) - w.SendMessageToVpnAddr(header.LightHouse, 0, fromVpnAddrs[0], lhh.pb[:ln], lhh.nb, lhh.out[:0]) + w.SendMessageToVpnAddr(header.LightHouse, 0, fromVpnAddrs[0], lhh.pb[:ln], lhh.buf) } func (lhh *LightHouseHandler) handleHostPunchNotification(n *NebulaMeta, fromVpnAddrs []netip.Addr, w EncWriter) { @@ -1452,10 +1472,13 @@ func (lhh *LightHouseHandler) handleHostPunchNotification(n *NebulaMeta, fromVpn "vpnAddr", detailsVpnAddr, ) } - //NOTE: we have to allocate a new output buffer here since we are spawning a new goroutine - // for each punchBack packet. We should move this into a timerwheel or a single goroutine + // We acquire and release a fresh buf within this goroutine so it + // returns to the pool once the punchback send completes. We + // should move this into a timerwheel or a single goroutine // managed by a channel. - w.SendMessageToVpnAddr(header.Test, header.TestRequest, detailsVpnAddr, []byte(""), make([]byte, 12, 12), make([]byte, mtu)) + pbuf := lhh.lh.bufAlloc.Acquire() + defer lhh.lh.bufAlloc.Release(pbuf) + w.SendMessageToVpnAddr(header.Test, header.TestRequest, detailsVpnAddr, []byte(""), pbuf) }() } } diff --git a/lighthouse_test.go b/lighthouse_test.go index c57c44ec..9467d338 100644 --- a/lighthouse_test.go +++ b/lighthouse_test.go @@ -372,12 +372,12 @@ type testEncWriter struct { protocolVersion cert.Version } -func (tw *testEncWriter) SendVia(via *HostInfo, relay *Relay, ad, nb, out []byte, nocopy bool) { +func (tw *testEncWriter) SendVia(via *HostInfo, relay *Relay, ad []byte, buf *WireBuffer) { } func (tw *testEncWriter) Handshake(vpnIp netip.Addr) { } -func (tw *testEncWriter) SendMessageToHostInfo(t header.MessageType, st header.MessageSubType, hostinfo *HostInfo, p, _, _ []byte) { +func (tw *testEncWriter) SendMessageToHostInfo(t header.MessageType, st header.MessageSubType, hostinfo *HostInfo, p []byte, _ *WireBuffer) { msg := &NebulaMeta{} err := msg.Unmarshal(p) if tw.metaFilter == nil || msg.Type == *tw.metaFilter { @@ -394,7 +394,7 @@ func (tw *testEncWriter) SendMessageToHostInfo(t header.MessageType, st header.M } } -func (tw *testEncWriter) SendMessageToVpnAddr(t header.MessageType, st header.MessageSubType, vpnIp netip.Addr, p, _, _ []byte) { +func (tw *testEncWriter) SendMessageToVpnAddr(t header.MessageType, st header.MessageSubType, vpnIp netip.Addr, p []byte, _ *WireBuffer) { msg := &NebulaMeta{} err := msg.Unmarshal(p) if tw.metaFilter == nil || msg.Type == *tw.metaFilter { diff --git a/main.go b/main.go index d5e5dcc8..099eab25 100644 --- a/main.go +++ b/main.go @@ -232,6 +232,7 @@ func Main(c *config.C, configTest bool, buildVersion string, l *slog.Logger, dev ifce.writers = udpConns lightHouse.ifce = ifce + lightHouse.bufAlloc = ifce.bufAlloc ifce.RegisterConfigChangeCallbacks(c) ifce.reloadDisconnectInvalid(c) diff --git a/noise.go b/noise.go index 0491da17..0c6ce22a 100644 --- a/noise.go +++ b/noise.go @@ -14,6 +14,19 @@ type endianness interface { var noiseEndianness endianness = binary.BigEndian +// NonceSize is the AEAD nonce length used by all ciphers nebula supports +// today (AES-GCM and ChaCha20-Poly1305 both use 96-bit nonces). Encrypt- +// and DecryptDanger lay out the nonce as 4 zero bytes followed by an 8-byte +// big-endian counter; if a future cipher with a different nonce size is +// added, this constant and those layouts must change together. +const NonceSize = 12 + +// AEADOverhead is the AEAD authentication tag length the ciphers nebula +// supports append to ciphertext. Both AES-GCM and ChaCha20-Poly1305 use +// 128-bit tags. NebulaCipherState.Overhead() returns this dynamically from +// the cipher; the constant is for sizing buffers at construction time. +const AEADOverhead = 16 + type NebulaCipherState struct { c cipher.AEAD } diff --git a/outside.go b/outside.go index 1e00a0a9..6860dd3d 100644 --- a/outside.go +++ b/outside.go @@ -20,7 +20,8 @@ const ( minFwPacketLen = 4 ) -func (f *Interface) readOutsidePackets(via ViaSender, out []byte, packet []byte, h *header.H, fwPacket *firewall.Packet, lhf *LightHouseHandler, nb []byte, q int, localCache firewall.ConntrackCache) { +func (f *Interface) readOutsidePackets(via ViaSender, buf *WireBuffer, packet []byte, lhf *LightHouseHandler, q int, localCache firewall.ConntrackCache) { + h := buf.H err := h.Parse(packet) if err != nil { // Hole punch packets are 0 or 1 byte big, so lets ignore printing those errors @@ -65,7 +66,7 @@ func (f *Interface) readOutsidePackets(via ViaSender, out []byte, packet []byte, switch h.Subtype { case header.MessageNone: - if !f.decryptToTun(hostinfo, h.MessageCounter, out, packet, fwPacket, nb, q, localCache) { + if !f.decryptToTun(hostinfo, h.MessageCounter, buf, packet, q, localCache) { return } case header.MessageRelay: @@ -76,8 +77,9 @@ func (f *Interface) readOutsidePackets(via ViaSender, out []byte, packet []byte, // which will gracefully fail in the DecryptDanger call. signedPayload := packet[:len(packet)-hostinfo.ConnectionState.dKey.Overhead()] signatureValue := packet[len(packet)-hostinfo.ConnectionState.dKey.Overhead():] - out, err = hostinfo.ConnectionState.dKey.DecryptDanger(out, signedPayload, signatureValue, h.MessageCounter, nb) - if err != nil { + // AAD-only validation: passing dst=nil since there's no plaintext + // to recover (ciphertext is just the trailing AEAD tag). + if _, err = hostinfo.ConnectionState.dKey.DecryptDanger(nil, signedPayload, signatureValue, h.MessageCounter, buf.NB); err != nil { return } // Successfully validated the thing. Get rid of the Relay header. @@ -110,7 +112,8 @@ func (f *Interface) readOutsidePackets(via ViaSender, out []byte, packet []byte, relay: relay, IsRelayed: true, } - f.readOutsidePackets(via, out[:0], signedPayload, h, fwPacket, lhf, nb, q, localCache) + buf.Reset() + f.readOutsidePackets(via, buf, signedPayload, lhf, q, localCache) return case ForwardingType: // Find the target HostInfo relay object @@ -130,7 +133,7 @@ func (f *Interface) readOutsidePackets(via ViaSender, out []byte, packet []byte, case ForwardingType: // Forward this packet through the relay tunnel // Find the target HostInfo - f.SendVia(targetHI, targetRelay, signedPayload, nb, out, false) + f.SendVia(targetHI, targetRelay, signedPayload, buf) return case TerminalType: hostinfo.logger(f.l).Error("Unexpected Relay Type of Terminal") @@ -152,7 +155,7 @@ func (f *Interface) readOutsidePackets(via ViaSender, out []byte, packet []byte, return } - d, err := f.decrypt(hostinfo, h.MessageCounter, out, packet, h, nb) + d, err := f.decrypt(hostinfo, h.MessageCounter, buf, packet, h) if err != nil { hostinfo.logger(f.l).Error("Failed to decrypt lighthouse packet", "error", err, @@ -173,7 +176,7 @@ func (f *Interface) readOutsidePackets(via ViaSender, out []byte, packet []byte, return } - d, err := f.decrypt(hostinfo, h.MessageCounter, out, packet, h, nb) + d, err := f.decrypt(hostinfo, h.MessageCounter, buf, packet, h) if err != nil { hostinfo.logger(f.l).Error("Failed to decrypt test packet", "error", err, @@ -185,9 +188,9 @@ func (f *Interface) readOutsidePackets(via ViaSender, out []byte, packet []byte, if h.Subtype == header.TestRequest { // This testRequest might be from TryPromoteBest, so we should roam - // to the new IP address before responding + // to the new IP address before responding. f.handleHostRoaming(hostinfo, via) - f.send(header.Test, header.TestReply, ci, hostinfo, d, nb, out) + f.send(header.Test, header.TestReply, ci, hostinfo, d, buf) } // Fallthrough to the bottom to record incoming traffic @@ -210,7 +213,7 @@ func (f *Interface) readOutsidePackets(via ViaSender, out []byte, packet []byte, if !f.handleEncrypted(ci, via, h) { return } - _, err = f.decrypt(hostinfo, h.MessageCounter, out, packet, h, nb) + _, err = f.decrypt(hostinfo, h.MessageCounter, buf, packet, h) if err != nil { hostinfo.logger(f.l).Error("Failed to decrypt CloseTunnel packet", "error", err, @@ -230,7 +233,7 @@ func (f *Interface) readOutsidePackets(via ViaSender, out []byte, packet []byte, return } - d, err := f.decrypt(hostinfo, h.MessageCounter, out, packet, h, nb) + d, err := f.decrypt(hostinfo, h.MessageCounter, buf, packet, h) if err != nil { hostinfo.logger(f.l).Error("Failed to decrypt Control packet", "error", err, @@ -266,7 +269,9 @@ func (f *Interface) closeTunnel(hostInfo *HostInfo) { // sendCloseTunnel is a helper function to send a proper close tunnel packet to a remote func (f *Interface) sendCloseTunnel(h *HostInfo) { - f.send(header.CloseTunnel, 0, h.ConnectionState, h, []byte{}, make([]byte, 12, 12), make([]byte, mtu)) + buf := f.bufAlloc.Acquire() + defer f.bufAlloc.Release(buf) + f.send(header.CloseTunnel, 0, h.ConnectionState, h, []byte{}, buf) } func (f *Interface) handleHostRoaming(hostinfo *HostInfo, via ViaSender) { @@ -515,9 +520,8 @@ func parseV4(data []byte, incoming bool, fp *firewall.Packet) error { return nil } -func (f *Interface) decrypt(hostinfo *HostInfo, mc uint64, out []byte, packet []byte, h *header.H, nb []byte) ([]byte, error) { - var err error - out, err = hostinfo.ConnectionState.dKey.DecryptDanger(out, packet[:header.Len], packet[header.Len:], mc, nb) +func (f *Interface) decrypt(hostinfo *HostInfo, mc uint64, buf *WireBuffer, packet []byte, h *header.H) ([]byte, error) { + plaintext, err := buf.DecryptForHandler(hostinfo.ConnectionState, packet, mc) if err != nil { return nil, err } @@ -529,42 +533,41 @@ func (f *Interface) decrypt(hostinfo *HostInfo, mc uint64, out []byte, packet [] return nil, errors.New("out of window packet") } - return out, nil + return plaintext, nil } -func (f *Interface) decryptToTun(hostinfo *HostInfo, messageCounter uint64, out []byte, packet []byte, fwPacket *firewall.Packet, nb []byte, q int, localCache firewall.ConntrackCache) bool { - var err error - - out, err = hostinfo.ConnectionState.dKey.DecryptDanger(out, packet[:header.Len], packet[header.Len:], messageCounter, nb) - if err != nil { +func (f *Interface) decryptToTun(hostinfo *HostInfo, messageCounter uint64, buf *WireBuffer, packet []byte, q int, localCache firewall.ConntrackCache) bool { + if err := buf.DecryptDatagram(hostinfo.ConnectionState, packet, messageCounter); err != nil { hostinfo.logger(f.l).Error("Failed to decrypt packet", "error", err) return false } - err = newPacket(out, true, fwPacket) - if err != nil { + ipPacket := buf.IPPacket() + if err := newPacket(ipPacket, true, buf.FwPacket); err != nil { hostinfo.logger(f.l).Warn("Error while validating inbound packet", "error", err, - "packet", out, + "packet", ipPacket, ) return false } if !hostinfo.ConnectionState.window.Update(f.l, messageCounter) { if f.l.Enabled(context.Background(), slog.LevelDebug) { - hostinfo.logger(f.l).Debug("dropping out of window packet", "fwPacket", fwPacket) + hostinfo.logger(f.l).Debug("dropping out of window packet", "fwPacket", buf.FwPacket) } return false } - dropReason := f.firewall.Drop(*fwPacket, true, hostinfo, f.pki.GetCAPool(), localCache) + dropReason := f.firewall.Drop(*buf.FwPacket, true, hostinfo, f.pki.GetCAPool(), localCache) if dropReason != nil { - // NOTE: We give `packet` as the `out` here since we already decrypted from it and we don't need it anymore - // This gives us a buffer to build the reject packet in - f.rejectOutside(out, hostinfo.ConnectionState, hostinfo, nb, packet, q) + // NOTE: We hand `packet` (the original UDP ciphertext we already + // decrypted from) as the reject-IP scratch since we no longer + // need its ciphertext, and it's disjoint from buf.Out where + // sendNoMetrics will encrypt the wire packet. + f.rejectOutside(ipPacket, hostinfo.ConnectionState, hostinfo, packet, buf, q) if f.l.Enabled(context.Background(), slog.LevelDebug) { hostinfo.logger(f.l).Debug("dropping inbound packet", - "fwPacket", fwPacket, + "fwPacket", buf.FwPacket, "reason", dropReason, ) } @@ -572,8 +575,7 @@ func (f *Interface) decryptToTun(hostinfo *HostInfo, messageCounter uint64, out } f.connectionManager.In(hostinfo) - _, err = f.readers[q].Write(out) - if err != nil { + if _, err := buf.WriteIPToTUN(f.readers[q]); err != nil { f.l.Error("Failed to write to tun", "error", err) } return true diff --git a/overlay/device.go b/overlay/device.go index b6077aba..cfba8f7e 100644 --- a/overlay/device.go +++ b/overlay/device.go @@ -15,4 +15,7 @@ type Device interface { RoutesFor(netip.Addr) routing.Gateways SupportsMultiqueue() bool NewMultiQueueReader() (io.ReadWriteCloser, error) + // TunPrefixLen reports the number of bytes the device prepends to every IP packet on the wire. + // Currently only non zero for the BSD tun devices. + TunPrefixLen() int } diff --git a/overlay/overlaytest/noop.go b/overlay/overlaytest/noop.go index 956da7dd..8f1059c6 100644 --- a/overlay/overlaytest/noop.go +++ b/overlay/overlaytest/noop.go @@ -50,3 +50,5 @@ func (NoopTun) NewMultiQueueReader() (io.ReadWriteCloser, error) { func (NoopTun) Close() error { return nil } + +func (NoopTun) TunPrefixLen() int { return 0 } diff --git a/overlay/tun_android.go b/overlay/tun_android.go index 9cbb64be..0cc697dc 100644 --- a/overlay/tun_android.go +++ b/overlay/tun_android.go @@ -102,3 +102,5 @@ func (t *tun) SupportsMultiqueue() bool { func (t *tun) NewMultiQueueReader() (io.ReadWriteCloser, error) { return nil, fmt.Errorf("TODO: multiqueue not implemented for android") } + +func (t *tun) TunPrefixLen() int { return 0 } diff --git a/overlay/tun_bsd.go b/overlay/tun_bsd.go new file mode 100644 index 00000000..6c8e218d --- /dev/null +++ b/overlay/tun_bsd.go @@ -0,0 +1,29 @@ +//go:build (darwin || ios || freebsd || openbsd || netbsd) && !e2e_testing + +package overlay + +import ( + "fmt" + "syscall" +) + +// StampTunPrefix writes the 4-byte AF_INET / AF_INET6 protocol-family marker into buf[0:4] in place, +// picking the family from the first byte of the IP packet at buf[4]. +func StampTunPrefix(buf []byte) error { + if len(buf) < 5 { + return fmt.Errorf("tun write buffer too small for prefix") + } + ipVer := buf[4] >> 4 + buf[0] = 0 + buf[1] = 0 + buf[2] = 0 + switch ipVer { + case 4: + buf[3] = syscall.AF_INET + case 6: + buf[3] = syscall.AF_INET6 + default: + return fmt.Errorf("unable to determine IP version from packet") + } + return nil +} diff --git a/overlay/tun_darwin.go b/overlay/tun_darwin.go index 524ef0cd..86520a8f 100644 --- a/overlay/tun_darwin.go +++ b/overlay/tun_darwin.go @@ -11,7 +11,6 @@ import ( "net/netip" "os" "sync/atomic" - "syscall" "unsafe" "github.com/gaissmai/bart" @@ -31,9 +30,6 @@ type tun struct { routeTree atomic.Pointer[bart.Table[routing.Gateways]] linkAddr *netroute.LinkAddr l *slog.Logger - - // cache out buffer since we need to prepend 4 bytes for tun metadata - out []byte } type ifReq struct { @@ -502,44 +498,6 @@ func delRoute(prefix netip.Prefix, gateway netroute.Addr) error { return nil } -func (t *tun) Read(to []byte) (int, error) { - buf := make([]byte, len(to)+4) - - n, err := t.ReadWriteCloser.Read(buf) - - copy(to, buf[4:]) - return n - 4, err -} - -// Write is only valid for single threaded use -func (t *tun) Write(from []byte) (int, error) { - buf := t.out - if cap(buf) < len(from)+4 { - buf = make([]byte, len(from)+4) - t.out = buf - } - buf = buf[:len(from)+4] - - if len(from) == 0 { - return 0, syscall.EIO - } - - // Determine the IP Family for the NULL L2 Header - ipVer := from[0] >> 4 - if ipVer == 4 { - buf[3] = syscall.AF_INET - } else if ipVer == 6 { - buf[3] = syscall.AF_INET6 - } else { - return 0, fmt.Errorf("unable to determine IP version from packet") - } - - copy(buf[4:], from) - - n, err := t.ReadWriteCloser.Write(buf) - return n - 4, err -} - func (t *tun) Networks() []netip.Prefix { return t.vpnNetworks } @@ -555,3 +513,7 @@ func (t *tun) SupportsMultiqueue() bool { func (t *tun) NewMultiQueueReader() (io.ReadWriteCloser, error) { return nil, fmt.Errorf("TODO: multiqueue not implemented for darwin") } + +// TunPrefixLen reports the 4-byte BSD AF_INET / AF_INET6 protocol-family +// marker the kernel prepends on read and expects on write. +func (t *tun) TunPrefixLen() int { return 4 } diff --git a/overlay/tun_disabled.go b/overlay/tun_disabled.go index f47880dd..620a9f18 100644 --- a/overlay/tun_disabled.go +++ b/overlay/tun_disabled.go @@ -136,3 +136,5 @@ func (p prettyPacket) String() string { return s.String() } + +func (t *disabledTun) TunPrefixLen() int { return 0 } diff --git a/overlay/tun_freebsd.go b/overlay/tun_freebsd.go index 3d995553..2dff2014 100644 --- a/overlay/tun_freebsd.go +++ b/overlay/tun_freebsd.go @@ -158,74 +158,43 @@ func (t *tun) blockOnWrite() error { } func (t *tun) Read(to []byte) (int, error) { - // first 4 bytes is protocol family, in network byte order - var head [4]byte - iovecs := [2]syscall.Iovec{ - {&head[0], 4}, - {&to[0], uint64(len(to))}, - } for { - n, _, errno := syscall.Syscall(syscall.SYS_READV, uintptr(t.fd), uintptr(unsafe.Pointer(&iovecs[0])), 2) - if errno == 0 { - bytesRead := int(n) - if bytesRead < 4 { - return 0, nil - } - return bytesRead - 4, nil + n, err := unix.Read(t.fd, to) + if err == nil { + return n, nil } - switch errno { + switch err { case unix.EAGAIN: - if err := t.blockOnRead(); err != nil { - return 0, err + if berr := t.blockOnRead(); berr != nil { + return 0, berr } case unix.EINTR: // retry case unix.EBADF: return 0, os.ErrClosed default: - return 0, errno + return 0, err } } } -// Write is only valid for single threaded use func (t *tun) Write(from []byte) (int, error) { - if len(from) <= 1 { - return 0, syscall.EIO - } - - ipVer := from[0] >> 4 - var head [4]byte - // first 4 bytes is protocol family, in network byte order - switch ipVer { - case 4: - head[3] = syscall.AF_INET - case 6: - head[3] = syscall.AF_INET6 - default: - return 0, fmt.Errorf("unable to determine IP version from packet") - } - - iovecs := [2]syscall.Iovec{ - {&head[0], 4}, - {&from[0], uint64(len(from))}, - } for { - n, _, errno := syscall.Syscall(syscall.SYS_WRITEV, uintptr(t.fd), uintptr(unsafe.Pointer(&iovecs[0])), 2) - if errno == 0 { - return int(n) - 4, nil + n, err := unix.Write(t.fd, from) + if err == nil { + return n, nil } - switch errno { + switch err { case unix.EAGAIN: - if err := t.blockOnWrite(); err != nil { - return 0, err + if berr := t.blockOnWrite(); berr != nil { + return 0, berr } case unix.EINTR: // retry case unix.EBADF: return 0, os.ErrClosed default: - return 0, errno + return 0, err } } } @@ -732,3 +701,7 @@ func getLinkAddr(name string) (*netroute.LinkAddr, error) { return nil, nil } + +// TunPrefixLen reports the 4-byte BSD AF_INET / AF_INET6 protocol-family +// marker the kernel prepends on read and expects on write. +func (t *tun) TunPrefixLen() int { return 4 } diff --git a/overlay/tun_ios.go b/overlay/tun_ios.go index 6bfcbdfb..678231ab 100644 --- a/overlay/tun_ios.go +++ b/overlay/tun_ios.go @@ -4,15 +4,12 @@ package overlay import ( - "errors" "fmt" "io" "log/slog" "net/netip" "os" - "sync" "sync/atomic" - "syscall" "github.com/gaissmai/bart" "github.com/slackhq/nebula/config" @@ -36,7 +33,7 @@ func newTunFromFd(c *config.C, l *slog.Logger, deviceFd int, vpnNetworks []netip file := os.NewFile(uintptr(deviceFd), "/dev/tun") t := &tun{ vpnNetworks: vpnNetworks, - ReadWriteCloser: &tunReadCloser{f: file}, + ReadWriteCloser: file, l: l, } @@ -85,64 +82,6 @@ func (t *tun) RoutesFor(ip netip.Addr) routing.Gateways { return r } -// The following is hoisted up from water, we do this so we can inject our own fd on iOS -type tunReadCloser struct { - f io.ReadWriteCloser - - rMu sync.Mutex - rBuf []byte - - wMu sync.Mutex - wBuf []byte -} - -func (tr *tunReadCloser) Read(to []byte) (int, error) { - tr.rMu.Lock() - defer tr.rMu.Unlock() - - if cap(tr.rBuf) < len(to)+4 { - tr.rBuf = make([]byte, len(to)+4) - } - tr.rBuf = tr.rBuf[:len(to)+4] - - n, err := tr.f.Read(tr.rBuf) - copy(to, tr.rBuf[4:]) - return n - 4, err -} - -func (tr *tunReadCloser) Write(from []byte) (int, error) { - if len(from) == 0 { - return 0, syscall.EIO - } - - tr.wMu.Lock() - defer tr.wMu.Unlock() - - if cap(tr.wBuf) < len(from)+4 { - tr.wBuf = make([]byte, len(from)+4) - } - tr.wBuf = tr.wBuf[:len(from)+4] - - // Determine the IP Family for the NULL L2 Header - ipVer := from[0] >> 4 - if ipVer == 4 { - tr.wBuf[3] = syscall.AF_INET - } else if ipVer == 6 { - tr.wBuf[3] = syscall.AF_INET6 - } else { - return 0, errors.New("unable to determine IP version from packet") - } - - copy(tr.wBuf[4:], from) - - n, err := tr.f.Write(tr.wBuf) - return n - 4, err -} - -func (tr *tunReadCloser) Close() error { - return tr.f.Close() -} - func (t *tun) Networks() []netip.Prefix { return t.vpnNetworks } @@ -158,3 +97,7 @@ func (t *tun) SupportsMultiqueue() bool { func (t *tun) NewMultiQueueReader() (io.ReadWriteCloser, error) { return nil, fmt.Errorf("TODO: multiqueue not implemented for ios") } + +// TunPrefixLen reports the 4-byte BSD AF_INET / AF_INET6 protocol-family +// marker the kernel prepends on read and expects on write. +func (t *tun) TunPrefixLen() int { return 4 } diff --git a/overlay/tun_linux.go b/overlay/tun_linux.go index c6cfb686..96cefa4c 100644 --- a/overlay/tun_linux.go +++ b/overlay/tun_linux.go @@ -907,3 +907,5 @@ func (t *tun) Close() error { } return err } + +func (t *tun) TunPrefixLen() int { return 0 } diff --git a/overlay/tun_netbsd.go b/overlay/tun_netbsd.go index c971bb6e..51a18eec 100644 --- a/overlay/tun_netbsd.go +++ b/overlay/tun_netbsd.go @@ -58,13 +58,13 @@ type addrLifetime struct { } type tun struct { + io.ReadWriteCloser Device string vpnNetworks []netip.Prefix MTU int Routes atomic.Pointer[[]Route] routeTree atomic.Pointer[bart.Table[routing.Gateways]] l *slog.Logger - f *os.File fd int } @@ -96,12 +96,12 @@ func newTun(c *config.C, l *slog.Logger, vpnNetworks []netip.Prefix, _ bool) (*t } t := &tun{ - f: os.NewFile(uintptr(fd), ""), - fd: fd, - Device: deviceName, - vpnNetworks: vpnNetworks, - MTU: c.GetInt("tun.mtu", DefaultMTU), - l: l, + ReadWriteCloser: os.NewFile(uintptr(fd), ""), + fd: fd, + Device: deviceName, + vpnNetworks: vpnNetworks, + MTU: c.GetInt("tun.mtu", DefaultMTU), + l: l, } err = t.reload(c, true) @@ -120,12 +120,12 @@ func newTun(c *config.C, l *slog.Logger, vpnNetworks []netip.Prefix, _ bool) (*t } func (t *tun) Close() error { - if t.f != nil { - if err := t.f.Close(); err != nil { + if t.ReadWriteCloser != nil { + if err := t.ReadWriteCloser.Close(); err != nil { return fmt.Errorf("error closing tun file: %w", err) } - // t.f.Close should have handled it for us but let's be extra sure + // Close on the os.File should have handled the fd for us but let's be extra sure _ = unix.Close(t.fd) s, err := syscall.Socket(syscall.AF_INET, syscall.SOCK_DGRAM, syscall.IPPROTO_IP) @@ -141,99 +141,6 @@ func (t *tun) Close() error { return nil } -func (t *tun) Read(to []byte) (int, error) { - rc, err := t.f.SyscallConn() - if err != nil { - return 0, fmt.Errorf("failed to get syscall conn for tun: %w", err) - } - - var errno syscall.Errno - var n uintptr - err = rc.Read(func(fd uintptr) bool { - // first 4 bytes is protocol family, in network byte order - head := [4]byte{} - iovecs := []syscall.Iovec{ - {&head[0], 4}, - {&to[0], uint64(len(to))}, - } - - n, _, errno = syscall.Syscall(syscall.SYS_READV, fd, uintptr(unsafe.Pointer(&iovecs[0])), uintptr(2)) - if errno.Temporary() { - // We got an EAGAIN, EINTR, or EWOULDBLOCK, go again - return false - } - return true - }) - if err != nil { - if err == syscall.EBADF || err.Error() == "use of closed file" { - // Go doesn't export poll.ErrFileClosing but happily reports it to us so here we are - // https://github.com/golang/go/blob/master/src/internal/poll/fd_poll_runtime.go#L121 - return 0, os.ErrClosed - } - return 0, fmt.Errorf("failed to make read call for tun: %w", err) - } - - if errno != 0 { - return 0, fmt.Errorf("failed to make inner read call for tun: %w", errno) - } - - // fix bytes read number to exclude header - bytesRead := int(n) - if bytesRead < 0 { - return bytesRead, nil - } else if bytesRead < 4 { - return 0, nil - } else { - return bytesRead - 4, nil - } -} - -// Write is only valid for single threaded use -func (t *tun) Write(from []byte) (int, error) { - if len(from) <= 1 { - return 0, syscall.EIO - } - - ipVer := from[0] >> 4 - var head [4]byte - // first 4 bytes is protocol family, in network byte order - if ipVer == 4 { - head[3] = syscall.AF_INET - } else if ipVer == 6 { - head[3] = syscall.AF_INET6 - } else { - return 0, fmt.Errorf("unable to determine IP version from packet") - } - - rc, err := t.f.SyscallConn() - if err != nil { - return 0, err - } - - var errno syscall.Errno - var n uintptr - err = rc.Write(func(fd uintptr) bool { - iovecs := []syscall.Iovec{ - {&head[0], 4}, - {&from[0], uint64(len(from))}, - } - - n, _, errno = syscall.Syscall(syscall.SYS_WRITEV, fd, uintptr(unsafe.Pointer(&iovecs[0])), uintptr(2)) - // According to NetBSD documentation for TUN, writes will only return errors in which - // this packet will never be delivered so just go on living life. - return true - }) - if err != nil { - return 0, err - } - - if errno != 0 { - return 0, errno - } - - return int(n) - 4, err -} - func (t *tun) addIp(cidr netip.Prefix) error { if cidr.Addr().Is4() { var req ifreqAlias4 @@ -551,3 +458,7 @@ func delRoute(prefix netip.Prefix, gateways []netip.Prefix) error { return nil } + +// TunPrefixLen reports the 4-byte BSD AF_INET / AF_INET6 protocol-family +// marker the kernel prepends on read and expects on write. +func (t *tun) TunPrefixLen() int { return 4 } diff --git a/overlay/tun_no_prefix.go b/overlay/tun_no_prefix.go new file mode 100644 index 00000000..865aac4f --- /dev/null +++ b/overlay/tun_no_prefix.go @@ -0,0 +1,10 @@ +//go:build (!darwin && !ios && !freebsd && !openbsd && !netbsd) || e2e_testing + +package overlay + +// StampTunPrefix is a no-op on platforms whose tun devices have no +// protocol-family marker. WireBuffer only invokes it when its prefixLen +// is non-zero, so this should never be reached on these platforms. +func StampTunPrefix(buf []byte) error { + return nil +} diff --git a/overlay/tun_openbsd.go b/overlay/tun_openbsd.go index 81362184..5e38245d 100644 --- a/overlay/tun_openbsd.go +++ b/overlay/tun_openbsd.go @@ -49,16 +49,14 @@ type ifreq struct { } type tun struct { + io.ReadWriteCloser Device string vpnNetworks []netip.Prefix MTU int Routes atomic.Pointer[[]Route] routeTree atomic.Pointer[bart.Table[routing.Gateways]] l *slog.Logger - f *os.File fd int - // cache out buffer since we need to prepend 4 bytes for tun metadata - out []byte } var deviceNameRE = regexp.MustCompile(`^tun[0-9]+$`) @@ -89,12 +87,12 @@ func newTun(c *config.C, l *slog.Logger, vpnNetworks []netip.Prefix, _ bool) (*t } t := &tun{ - f: os.NewFile(uintptr(fd), ""), - fd: fd, - Device: deviceName, - vpnNetworks: vpnNetworks, - MTU: c.GetInt("tun.mtu", DefaultMTU), - l: l, + ReadWriteCloser: os.NewFile(uintptr(fd), ""), + fd: fd, + Device: deviceName, + vpnNetworks: vpnNetworks, + MTU: c.GetInt("tun.mtu", DefaultMTU), + l: l, } err = t.reload(c, true) @@ -113,55 +111,17 @@ func newTun(c *config.C, l *slog.Logger, vpnNetworks []netip.Prefix, _ bool) (*t } func (t *tun) Close() error { - if t.f != nil { - if err := t.f.Close(); err != nil { + if t.ReadWriteCloser != nil { + if err := t.ReadWriteCloser.Close(); err != nil { return fmt.Errorf("error closing tun file: %w", err) } - // t.f.Close should have handled it for us but let's be extra sure + // Close on the os.File should have handled the fd for us but let's be extra sure _ = unix.Close(t.fd) } return nil } -func (t *tun) Read(to []byte) (int, error) { - buf := make([]byte, len(to)+4) - - n, err := t.f.Read(buf) - - copy(to, buf[4:]) - return n - 4, err -} - -// Write is only valid for single threaded use -func (t *tun) Write(from []byte) (int, error) { - buf := t.out - if cap(buf) < len(from)+4 { - buf = make([]byte, len(from)+4) - t.out = buf - } - buf = buf[:len(from)+4] - - if len(from) == 0 { - return 0, syscall.EIO - } - - // Determine the IP Family for the NULL L2 Header - ipVer := from[0] >> 4 - if ipVer == 4 { - buf[3] = syscall.AF_INET - } else if ipVer == 6 { - buf[3] = syscall.AF_INET6 - } else { - return 0, fmt.Errorf("unable to determine IP version from packet") - } - - copy(buf[4:], from) - - n, err := t.f.Write(buf) - return n - 4, err -} - func (t *tun) addIp(cidr netip.Prefix) error { if cidr.Addr().Is4() { var req ifreqAlias4 @@ -471,3 +431,7 @@ func delRoute(prefix netip.Prefix, gateways []netip.Prefix) error { return nil } + +// TunPrefixLen reports the 4-byte BSD AF_INET / AF_INET6 protocol-family +// marker the kernel prepends on read and expects on write. +func (t *tun) TunPrefixLen() int { return 4 } diff --git a/overlay/tun_tester.go b/overlay/tun_tester.go index 8acd83f0..270c139f 100644 --- a/overlay/tun_tester.go +++ b/overlay/tun_tester.go @@ -184,3 +184,5 @@ func (t *TestTun) SupportsMultiqueue() bool { func (t *TestTun) NewMultiQueueReader() (io.ReadWriteCloser, error) { return nil, fmt.Errorf("TODO: multiqueue not implemented") } + +func (t *TestTun) TunPrefixLen() int { return 0 } diff --git a/overlay/tun_windows.go b/overlay/tun_windows.go index 680dddb3..649e79c5 100644 --- a/overlay/tun_windows.go +++ b/overlay/tun_windows.go @@ -296,3 +296,5 @@ func checkWinTunExists() error { _, err = syscall.LoadDLL(filepath.Join(filepath.Dir(myPath), "dist", "windows", "wintun", "bin", arch, "wintun.dll")) return err } + +func (t *winTun) TunPrefixLen() int { return 0 } diff --git a/overlay/user.go b/overlay/user.go index e5f27f37..a8039b3f 100644 --- a/overlay/user.go +++ b/overlay/user.go @@ -69,3 +69,5 @@ func (d *UserDevice) Close() error { d.outboundWriter.Close() return nil } + +func (d *UserDevice) TunPrefixLen() int { return 0 } diff --git a/relay_manager.go b/relay_manager.go index 25e65871..7a5e8a22 100644 --- a/relay_manager.go +++ b/relay_manager.go @@ -63,6 +63,9 @@ func (rm *relayManager) StartRelays(f *Interface, vpnIp netip.Addr, hostinfo *Ho } hostinfo.logger(rm.l).Info("Attempt to relay through hosts", "relays", hostinfo.remotes.relays) + // One WireBuffer for the whole relay-fanout loop. + buf := f.bufAlloc.Acquire() + defer f.bufAlloc.Release(buf) // Send a RelayRequest to all known Relay IP's for _, relay := range hostinfo.remotes.relays { // Don't relay through the host I'm trying to connect to @@ -124,7 +127,7 @@ func (rm *relayManager) StartRelays(f *Interface, vpnIp netip.Addr, hostinfo *Ho if err != nil { hostinfo.logger(rm.l).Error("Failed to marshal Control message to create relay", "error", err) } else { - f.SendMessageToHostInfo(header.Control, 0, relayHostInfo, msg, make([]byte, 12), make([]byte, mtu)) + f.SendMessageToHostInfo(header.Control, 0, relayHostInfo, msg, buf) rm.l.Info("send CreateRelayRequest", "relayFrom", f.myVpnAddrs[0], "relayTo", vpnIp, @@ -139,7 +142,7 @@ func (rm *relayManager) StartRelays(f *Interface, vpnIp netip.Addr, hostinfo *Ho switch existingRelay.State { case Established: hostinfo.logger(rm.l).Info("Send handshake via relay", "relay", relay.String()) - f.SendVia(relayHostInfo, existingRelay, stage0, make([]byte, 12), make([]byte, mtu), false) + f.SendVia(relayHostInfo, existingRelay, stage0, buf) case Disestablished: // Mark this relay as 'requested' relayHostInfo.relayState.UpdateRelayForByIpState(vpnIp, Requested) @@ -180,7 +183,7 @@ func (rm *relayManager) StartRelays(f *Interface, vpnIp netip.Addr, hostinfo *Ho hostinfo.logger(rm.l).Error("Failed to marshal Control message to create relay", "error", err) } else { // This must send over the hostinfo, not over hm.Hosts[ip] - f.SendMessageToHostInfo(header.Control, 0, relayHostInfo, msg, make([]byte, 12), make([]byte, mtu)) + f.SendMessageToHostInfo(header.Control, 0, relayHostInfo, msg, buf) rm.l.Info("send CreateRelayRequest", "relayFrom", f.myVpnAddrs[0], "relayTo", vpnIp, @@ -368,7 +371,9 @@ func (rm *relayManager) handleCreateRelayResponse(v cert.Version, h *HostInfo, f if err != nil { rm.l.Error("relayManager Failed to marshal Control CreateRelayResponse message to create relay", "error", err) } else { - f.SendMessageToHostInfo(header.Control, 0, peerHostInfo, msg, make([]byte, 12), make([]byte, mtu)) + buf := f.bufAlloc.Acquire() + f.SendMessageToHostInfo(header.Control, 0, peerHostInfo, msg, buf) + f.bufAlloc.Release(buf) rm.l.Info("send CreateRelayResponse", "relayFrom", resp.RelayFromAddr, "relayTo", resp.RelayToAddr, @@ -468,7 +473,9 @@ func (rm *relayManager) handleCreateRelayRequest(v cert.Version, h *HostInfo, f if err != nil { logMsg.Error("relayManager Failed to marshal Control CreateRelayResponse message to create relay", "error", err) } else { - f.SendMessageToHostInfo(header.Control, 0, h, msg, make([]byte, 12), make([]byte, mtu)) + buf := f.bufAlloc.Acquire() + f.SendMessageToHostInfo(header.Control, 0, h, msg, buf) + f.bufAlloc.Release(buf) rm.l.Info("send CreateRelayResponse", "relayFrom", from, "relayTo", target, @@ -538,7 +545,9 @@ func (rm *relayManager) handleCreateRelayRequest(v cert.Version, h *HostInfo, f if err != nil { logMsg.Error("relayManager Failed to marshal Control message to create relay", "error", err) } else { - f.SendMessageToHostInfo(header.Control, 0, peer, msg, make([]byte, 12), make([]byte, mtu)) + buf := f.bufAlloc.Acquire() + f.SendMessageToHostInfo(header.Control, 0, peer, msg, buf) + f.bufAlloc.Release(buf) rm.l.Info("send CreateRelayRequest", "relayFrom", h.vpnAddrs[0], "relayTo", target, diff --git a/ssh.go b/ssh.go index 3863b5ec..75c3c6de 100644 --- a/ssh.go +++ b/ssh.go @@ -632,15 +632,9 @@ func sshCloseTunnel(ifce *Interface, fs any, a []string, w sshd.StringWriter) er } if !flags.LocalOnly { - ifce.send( - header.CloseTunnel, - 0, - hostInfo.ConnectionState, - hostInfo, - []byte{}, - make([]byte, 12, 12), - make([]byte, mtu), - ) + buf := ifce.bufAlloc.Acquire() + ifce.send(header.CloseTunnel, 0, hostInfo.ConnectionState, hostInfo, []byte{}, buf) + ifce.bufAlloc.Release(buf) } ifce.closeTunnel(hostInfo) diff --git a/wire_buffer.go b/wire_buffer.go new file mode 100644 index 00000000..6fbd74d9 --- /dev/null +++ b/wire_buffer.go @@ -0,0 +1,255 @@ +package nebula + +import ( + "io" + "sync" + + "github.com/slackhq/nebula/firewall" + "github.com/slackhq/nebula/header" + "github.com/slackhq/nebula/noiseutil" + "github.com/slackhq/nebula/overlay" +) + +// WireBuffer is the per-goroutine working set for processing one IP packet +// through the data plane. It owns: +// +// - The IP-payload byte buffer used to hold the current inbound or +// outbound packet, with prefixLen bytes of slack at the front for +// the BSD AF_INET protocol-family marker. +// - The fwPacket scratch parsed by newPacket(). +// - The 12-byte AEAD nonce scratch. +// - The header.H parse target used by the receive path. +// - An mtu-sized wire-output scratch for sendNoMetrics and for building +// reject packets. +// +// One WireBuffer is allocated per data-plane goroutine (listenIn for the +// TUN-side, listenOut for the UDP-side) and reused for every packet. No +// per-packet allocation. Future GRO/GSO/TSO and reliable-transport work +// will likely extend this to carry batch state and fragment metadata. +// +// The TUN protocol-family prefix is handled here, not in the overlay +// package. On BSDs the kernel writes the 4-byte marker into the slack on +// read, and we stamp it into the slack before write. On linux/windows +// /userspace devices prefixLen is 0 and the slack is empty. +type WireBuffer struct { + // FwPacket is the parsed IP packet metadata (5-tuple, fragment flags, + // etc.) populated by newPacket(). + FwPacket *firewall.Packet + // NB is a 12-byte scratch the AEAD uses for the nonce; reused so we + // don't allocate one per encrypt/decrypt. + NB []byte + // H is the parse target for inbound nebula headers. Receive path only. + H *header.H + // Out is an mtu-sized wire-output scratch passed to sendNoMetrics and + // rejectInside / rejectOutside. Sized to fit any single wire packet. + Out []byte + + // ip is the IP-payload region: a slice of len 0, cap linkMTU sliced + // from raw at offset prefixLen. The current packet (if any) is + // ip[:bodyN]. The TUN prefix slack lives at raw[0:prefixLen] just + // before ip. + ip []byte + // raw is the backing slab. Layout: + // [prefixLen bytes prefix slack | linkMTU bytes IP region | outSize bytes Out scratch] + // Holding it lets ReadIPFromTUN / WriteIPToTUN address the slack + // region directly. + raw []byte + prefixLen int + bodyN int +} + +// NewWireBuffer returns a buffer sized to hold any single IP packet up to +// linkMTU, plus a disjoint wire-output scratch sliced from the same backing +// slab (the AEAD's Seal contract requires plaintext and dst not to partially +// overlap, and keeping them in one slab gives a single allocation per +// goroutine). Out is sized for the relay worst case +// (linkMTU + 2*header.Len + 2*AEADOverhead). +// +// prefixLen is the number of bytes the destination tun device prepends/ +// expects on each IP packet (overlay.Device.TunPrefixLen). On BSDs this +// is 4 (AF_INET marker); on linux/windows/userspace devices it is 0. +func NewWireBuffer(linkMTU, prefixLen int) *WireBuffer { + outSize := linkMTU + 2*header.Len + 2*AEADOverhead + raw := make([]byte, prefixLen+linkMTU+outSize) + outStart := prefixLen + linkMTU + return &WireBuffer{ + FwPacket: &firewall.Packet{}, + NB: make([]byte, NonceSize), + H: &header.H{}, + Out: raw[outStart : outStart : outStart+outSize], + ip: raw[prefixLen:prefixLen:outStart], + raw: raw, + prefixLen: prefixLen, + } +} + +// Reset clears the body-length record so the buffer is ready for another +// recv (e.g. relay-receive recursion before a nested decrypt). +func (b *WireBuffer) Reset() { b.bodyN = 0 } + +// IPPacket returns the IP packet currently held in the payload region (after +// a successful ReadIPFromTUN or DecryptDatagram). The slice aliases the +// buffer; do not retain past the next operation. +func (b *WireBuffer) IPPacket() []byte { + return b.ip[:b.bodyN] +} + +// Seal stamps a nebula header at the front of buf.Out and AEAD-seals p as the +// payload, treating the header as additional authenticated data. The lock +// scope around counter increment + encrypt matches what goboring AESGCMTLS +// requires; non-boring builds skip the lock. +// +// Returns the wire bytes (header || ciphertext || tag), aliased to buf.Out. +// The slice is invalidated by the next Seal* call on this buffer. +func (b *WireBuffer) Seal(ci *ConnectionState, t header.MessageType, st header.MessageSubType, remoteIndex uint32, p []byte) ([]byte, error) { + return b.sealInto(b.Out[:cap(b.Out)], ci, t, st, remoteIndex, p) +} + +// SealForRelay is like Seal but reserves header.Len bytes of slack at the front +// of buf.Out for an outer relay header. The inner header + ciphertext lands at +// offset header.Len so a follow-up SealRelayInPlace can stamp the outer header +// without copying. Use this when the caller may need to wrap the result in a +// relay envelope after the fact. +func (b *WireBuffer) SealForRelay(ci *ConnectionState, t header.MessageType, st header.MessageSubType, remoteIndex uint32, p []byte) ([]byte, error) { + return b.sealInto(b.Out[header.Len:cap(b.Out)], ci, t, st, remoteIndex, p) +} + +func (b *WireBuffer) sealInto(out []byte, ci *ConnectionState, t header.MessageType, st header.MessageSubType, remoteIndex uint32, p []byte) ([]byte, error) { + if noiseutil.EncryptLockNeeded { + ci.writeLock.Lock() + } + c := ci.messageCounter.Add(1) + out = header.Encode(out, header.Version, t, st, remoteIndex, c) + out, err := ci.eKey.EncryptDanger(out, out, p, c, b.NB) + if noiseutil.EncryptLockNeeded { + ci.writeLock.Unlock() + } + return out, err +} + +// SealRelayInPlace wraps an inner message that is already staged at +// buf.Out[header.Len:header.Len+innerLen] (either from a SealForRelay encrypt +// or from a copy via the SendVia entry point). It stamps the outer relay +// header into buf.Out[:header.Len] and AAD-only seals over the entire region, +// producing the wire bytes for the relay tunnel. +// +// Returns the wire bytes aliased to buf.Out; invalidated by the next Seal* +// call on this buffer. +func (b *WireBuffer) SealRelayInPlace(ci *ConnectionState, remoteIndex uint32, innerLen int) ([]byte, error) { + if noiseutil.EncryptLockNeeded { + ci.writeLock.Lock() + } + c := ci.messageCounter.Add(1) + out := b.Out[:cap(b.Out)] + out = header.Encode(out, header.Version, header.Message, header.MessageRelay, remoteIndex, c) + out = out[:header.Len+innerLen] + out, err := ci.eKey.EncryptDanger(out, out, nil, c, b.NB) + if noiseutil.EncryptLockNeeded { + ci.writeLock.Unlock() + } + return out, err +} + +// StageRelayInner copies ad into the inner-payload slot at buf.Out[header.Len:] +// so SealRelayInPlace can wrap it on the next call. Used by SendVia when ad +// did not come from a prior SealForRelay (e.g. a handshake message being +// forwarded through a relay tunnel without our own encryption). +func (b *WireBuffer) StageRelayInner(ad []byte) int { + return copy(b.Out[header.Len:cap(b.Out)], ad) +} + +// ReadIPFromTUN reads one IP packet from r into the payload region and +// updates bodyN. On BSDs the kernel writes its 4-byte protocol-family +// marker into the slack at raw[0:prefixLen] and the IP packet at +// raw[prefixLen:prefixLen+n]; we hand it the slack-prefixed slice so +// the kernel can do this in one syscall with no copy. On linux/windows/ +// userspace devices prefixLen is 0 and the slack is empty. +func (b *WireBuffer) ReadIPFromTUN(r io.Reader) (int, error) { + n, err := r.Read(b.raw[:b.prefixLen+cap(b.ip)]) + if err != nil { + b.bodyN = 0 + return 0, err + } + if n < b.prefixLen { + b.bodyN = 0 + return 0, nil + } + b.bodyN = n - b.prefixLen + return b.bodyN, nil +} + +// WriteIPToTUN writes the IP packet currently in the payload region to w. +// On BSDs we stamp the protocol-family marker into the slack at +// raw[0:prefixLen] in place and write the entire slack+IP region in a +// single syscall, so the kernel sees [marker][ip] back to back without a +// userspace copy. On linux/windows/userspace devices the slack is empty +// and we just write the IP region. +func (b *WireBuffer) WriteIPToTUN(w io.Writer) (int, error) { + out := b.raw[:b.prefixLen+b.bodyN] + if b.prefixLen > 0 { + if err := overlay.StampTunPrefix(out); err != nil { + return 0, err + } + } + return w.Write(out) +} + +// DecryptDatagram decrypts an inbound UDP packet into the payload region. +func (b *WireBuffer) DecryptDatagram(ci *ConnectionState, packet []byte, mc uint64) error { + dst, err := ci.dKey.DecryptDanger(b.ip[:0], packet[:header.Len], packet[header.Len:], mc, b.NB) + if err != nil { + b.bodyN = 0 + return err + } + b.bodyN = len(dst) + return nil +} + +// DecryptForHandler decrypts an inbound UDP packet (lighthouse, test, +// control, close-tunnel) into the payload region and returns the plaintext +// slice for the in-process handler. Returned slice aliases the buffer. +func (b *WireBuffer) DecryptForHandler(ci *ConnectionState, packet []byte, mc uint64) ([]byte, error) { + dst, err := ci.dKey.DecryptDanger(b.ip[:0], packet[:header.Len], packet[header.Len:], mc, b.NB) + if err != nil { + b.bodyN = 0 + return nil, err + } + b.bodyN = len(dst) + return dst, nil +} + +// WireBufferAllocator hands out reusable WireBuffers for cold callers that +// don't own a long-lived per-goroutine buffer (control plane, relay manager, +// connection manager teardown, etc.). Hot-path goroutines hold their own +// buffer for the life of the goroutine and don't need to acquire one. +type WireBufferAllocator interface { + Acquire() *WireBuffer + Release(*WireBuffer) +} + +// wireBufferPool is a sync.Pool-backed WireBufferAllocator. The pool is +// keyed off a single linkMTU and prefixLen; cold callers send across the +// data-plane mtu and target the same Device, so we size the pool's +// buffers the same way. +type wireBufferPool struct { + pool sync.Pool +} + +func NewWireBufferPool(linkMTU, prefixLen int) *wireBufferPool { + return &wireBufferPool{ + pool: sync.Pool{ + New: func() any { + return NewWireBuffer(linkMTU, prefixLen) + }, + }, + } +} + +func (p *wireBufferPool) Acquire() *WireBuffer { + return p.pool.Get().(*WireBuffer) +} + +func (p *wireBufferPool) Release(b *WireBuffer) { + b.Reset() + p.pool.Put(b) +}