Compare commits

...

3 Commits

Author SHA1 Message Date
Nate Brown
4eb808c829 Pass pointer to ViaSender 2025-11-22 00:42:25 -06:00
Nate Brown
04f47eabf7 Better hot path benchmarks, minor fixups 2025-11-21 23:38:26 -06:00
Nate Brown
6df963dd00 TODO 2025-11-21 22:43:31 -06:00
8 changed files with 180 additions and 122 deletions

View File

@@ -25,11 +25,12 @@ import (
func BenchmarkHotPath(b *testing.B) {
ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version1, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
myControl, myVpnIpNet, _, _ := newSimpleServer(cert.Version1, ca, caKey, "me", "10.128.0.1/24", nil)
myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "me", "10.128.0.1/24", nil)
theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "them", "10.128.0.2/24", nil)
// Put their info in our lighthouse
myControl.InjectLightHouseAddr(theirVpnIpNet[0].Addr(), theirUdpAddr)
theirControl.InjectLightHouseAddr(myVpnIpNet[0].Addr(), myUdpAddr)
// Start the servers
myControl.Start()
@@ -38,6 +39,9 @@ func BenchmarkHotPath(b *testing.B) {
r := router.NewR(b, myControl, theirControl)
r.CancelFlowLogs()
assertTunnel(b, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), myControl, theirControl, r)
b.ResetTimer()
for n := 0; n < b.N; n++ {
myControl.InjectTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi from me"))
_ = r.RouteForAllUntilTxTun(theirControl)
@@ -47,6 +51,39 @@ func BenchmarkHotPath(b *testing.B) {
theirControl.Stop()
}
func BenchmarkHotPathRelay(b *testing.B) {
ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version1, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
myControl, myVpnIpNet, _, _ := newSimpleServer(cert.Version1, ca, caKey, "me ", "10.128.0.1/24", m{"relay": m{"use_relays": true}})
relayControl, relayVpnIpNet, relayUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "relay ", "10.128.0.128/24", m{"relay": m{"am_relay": true}})
theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "them ", "10.128.0.2/24", m{"relay": m{"use_relays": true}})
// Teach my how to get to the relay and that their can be reached via the relay
myControl.InjectLightHouseAddr(relayVpnIpNet[0].Addr(), relayUdpAddr)
myControl.InjectRelays(theirVpnIpNet[0].Addr(), []netip.Addr{relayVpnIpNet[0].Addr()})
relayControl.InjectLightHouseAddr(theirVpnIpNet[0].Addr(), theirUdpAddr)
// Build a router so we don't have to reason who gets which packet
r := router.NewR(b, myControl, relayControl, theirControl)
r.CancelFlowLogs()
// Start the servers
myControl.Start()
relayControl.Start()
theirControl.Start()
assertTunnel(b, theirVpnIpNet[0].Addr(), myVpnIpNet[0].Addr(), theirControl, myControl, r)
b.ResetTimer()
for n := 0; n < b.N; n++ {
myControl.InjectTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi from me"))
_ = r.RouteForAllUntilTxTun(theirControl)
}
myControl.Stop()
theirControl.Stop()
relayControl.Stop()
}
func TestGoodHandshake(t *testing.T) {
ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version1, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "me", "10.128.0.1/24", nil)

View File

@@ -292,7 +292,7 @@ func deadline(t *testing.T, seconds time.Duration) doneCb {
}
}
func assertTunnel(t *testing.T, vpnIpA, vpnIpB netip.Addr, controlA, controlB *nebula.Control, r *router.R) {
func assertTunnel(t testing.TB, vpnIpA, vpnIpB netip.Addr, controlA, controlB *nebula.Control, r *router.R) {
// Send a packet from them to me
controlB.InjectTunUDPPacket(vpnIpA, 80, vpnIpB, 90, []byte("Hi from B"))
bPacket := r.RouteForAllUntilTxTun(controlA)
@@ -325,7 +325,7 @@ func assertHostInfoPair(t *testing.T, addrA, addrB netip.AddrPort, vpnNetsA, vpn
assert.Equal(t, hBinA.RemoteIndex, hAinB.LocalIndex, "Host B remote index does not match host A local index")
}
func assertUdpPacket(t *testing.T, expected, b []byte, fromIp, toIp netip.Addr, fromPort, toPort uint16) {
func assertUdpPacket(t testing.TB, expected, b []byte, fromIp, toIp netip.Addr, fromPort, toPort uint16) {
if toIp.Is6() {
assertUdpPacket6(t, expected, b, fromIp, toIp, fromPort, toPort)
} else {
@@ -333,7 +333,7 @@ func assertUdpPacket(t *testing.T, expected, b []byte, fromIp, toIp netip.Addr,
}
}
func assertUdpPacket6(t *testing.T, expected, b []byte, fromIp, toIp netip.Addr, fromPort, toPort uint16) {
func assertUdpPacket6(t testing.TB, expected, b []byte, fromIp, toIp netip.Addr, fromPort, toPort uint16) {
packet := gopacket.NewPacket(b, layers.LayerTypeIPv6, gopacket.Lazy)
v6 := packet.Layer(layers.LayerTypeIPv6).(*layers.IPv6)
assert.NotNil(t, v6, "No ipv6 data found")
@@ -352,7 +352,7 @@ func assertUdpPacket6(t *testing.T, expected, b []byte, fromIp, toIp netip.Addr,
assert.Equal(t, expected, data.Payload(), "Data was incorrect")
}
func assertUdpPacket4(t *testing.T, expected, b []byte, fromIp, toIp netip.Addr, fromPort, toPort uint16) {
func assertUdpPacket4(t testing.TB, expected, b []byte, fromIp, toIp netip.Addr, fromPort, toPort uint16) {
packet := gopacket.NewPacket(b, layers.LayerTypeIPv4, gopacket.Lazy)
v4 := packet.Layer(layers.LayerTypeIPv4).(*layers.IPv4)
assert.NotNil(t, v4, "No ipv4 data found")

View File

@@ -99,11 +99,11 @@ func ixHandshakeStage0(f *Interface, hh *HandshakeHostInfo) bool {
return true
}
func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet []byte, h *header.H) {
func ixHandshakeStage1(f *Interface, via *ViaSender, packet []byte, h *header.H) {
cs := f.pki.getCertState()
crt := cs.GetDefaultCertificate()
if crt == nil {
f.l.WithField("udpAddr", addr).
f.l.WithField("from", via).
WithField("handshake", m{"stage": 0, "style": "ix_psk0"}).
WithField("certVersion", cs.initiatingVersion).
Error("Unable to handshake with host because no certificate is available")
@@ -112,7 +112,7 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet
ci, err := NewConnectionState(f.l, cs, crt, false, noise.HandshakeIX)
if err != nil {
f.l.WithError(err).WithField("udpAddr", addr).
f.l.WithError(err).WithField("from", via).
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).
Error("Failed to create connection state")
return
@@ -123,7 +123,7 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet
msg, _, _, err := ci.H.ReadMessage(nil, packet[header.Len:])
if err != nil {
f.l.WithError(err).WithField("udpAddr", addr).
f.l.WithError(err).WithField("from", via).
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).
Error("Failed to call noise.ReadMessage")
return
@@ -132,7 +132,7 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet
hs := &NebulaHandshake{}
err = hs.Unmarshal(msg)
if err != nil || hs.Details == nil {
f.l.WithError(err).WithField("udpAddr", addr).
f.l.WithError(err).WithField("from", via).
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).
Error("Failed unmarshal handshake message")
return
@@ -140,7 +140,7 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet
rc, err := cert.Recombine(cert.Version(hs.Details.CertVersion), hs.Details.Cert, ci.H.PeerStatic(), ci.Curve())
if err != nil {
f.l.WithError(err).WithField("udpAddr", addr).
f.l.WithError(err).WithField("from", via).
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).
Info("Handshake did not contain a certificate")
return
@@ -153,7 +153,7 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet
fp = "<error generating certificate fingerprint>"
}
e := f.l.WithError(err).WithField("udpAddr", addr).
e := f.l.WithError(err).WithField("from", via).
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).
WithField("certVpnNetworks", rc.Networks()).
WithField("certFingerprint", fp)
@@ -172,7 +172,7 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet
if myCertOtherVersion == nil {
if f.l.Level >= logrus.DebugLevel {
f.l.WithError(err).WithFields(m{
"udpAddr": addr,
"from": via,
"handshake": m{"stage": 1, "style": "ix_psk0"},
"cert": remoteCert,
}).Debug("Might be unable to handshake with host due to missing certificate version")
@@ -184,7 +184,7 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet
}
if len(remoteCert.Certificate.Networks()) == 0 {
f.l.WithError(err).WithField("udpAddr", addr).
f.l.WithError(err).WithField("from", via).
WithField("cert", remoteCert).
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).
Info("No networks in certificate")
@@ -201,7 +201,7 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet
vpnAddrs := make([]netip.Addr, len(vpnNetworks))
for i, network := range vpnNetworks {
if f.myVpnAddrsTable.Contains(network.Addr()) {
f.l.WithField("vpnNetworks", vpnNetworks).WithField("udpAddr", addr).
f.l.WithField("vpnNetworks", vpnNetworks).WithField("from", via).
WithField("certName", certName).
WithField("certVersion", certVersion).
WithField("fingerprint", fingerprint).
@@ -215,18 +215,18 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet
}
}
if addr.IsValid() {
// addr can be invalid when the tunnel is being relayed.
if !via.IsRelayed {
// We only want to apply the remote allow list for direct tunnels here
if !f.lightHouse.GetRemoteAllowList().AllowAll(vpnAddrs, addr.Addr()) {
f.l.WithField("vpnAddrs", vpnAddrs).WithField("udpAddr", addr).Debug("lighthouse.remote_allow_list denied incoming handshake")
if !f.lightHouse.GetRemoteAllowList().AllowAll(vpnAddrs, via.UdpAddr.Addr()) {
f.l.WithField("vpnAddrs", vpnAddrs).WithField("from", via).
Debug("lighthouse.remote_allow_list denied incoming handshake")
return
}
}
myIndex, err := generateIndex(f.l)
if err != nil {
f.l.WithError(err).WithField("vpnAddrs", vpnAddrs).WithField("udpAddr", addr).
f.l.WithError(err).WithField("vpnAddrs", vpnAddrs).WithField("from", via).
WithField("certName", certName).
WithField("certVersion", certVersion).
WithField("fingerprint", fingerprint).
@@ -251,7 +251,7 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet
msgRxL := f.l.WithFields(m{
"vpnAddrs": vpnAddrs,
"udpAddr": addr,
"from": via,
"certName": certName,
"certVersion": certVersion,
"fingerprint": fingerprint,
@@ -283,7 +283,7 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet
hsBytes, err := hs.Marshal()
if err != nil {
f.l.WithError(err).WithField("vpnAddrs", hostinfo.vpnAddrs).WithField("udpAddr", addr).
f.l.WithError(err).WithField("vpnAddrs", hostinfo.vpnAddrs).WithField("from", via).
WithField("certName", certName).
WithField("certVersion", certVersion).
WithField("fingerprint", fingerprint).
@@ -295,7 +295,7 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet
nh := header.Encode(make([]byte, header.Len), header.Version, header.Handshake, header.HandshakeIXPSK0, hs.Details.InitiatorIndex, 2)
msg, dKey, eKey, err := ci.H.WriteMessage(nh, hsBytes)
if err != nil {
f.l.WithError(err).WithField("vpnAddrs", hostinfo.vpnAddrs).WithField("udpAddr", addr).
f.l.WithError(err).WithField("vpnAddrs", hostinfo.vpnAddrs).WithField("from", via).
WithField("certName", certName).
WithField("certVersion", certVersion).
WithField("fingerprint", fingerprint).
@@ -303,7 +303,7 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).Error("Failed to call noise.WriteMessage")
return
} else if dKey == nil || eKey == nil {
f.l.WithField("vpnAddrs", hostinfo.vpnAddrs).WithField("udpAddr", addr).
f.l.WithField("vpnAddrs", hostinfo.vpnAddrs).WithField("from", via).
WithField("certName", certName).
WithField("certVersion", certVersion).
WithField("fingerprint", fingerprint).
@@ -329,7 +329,9 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet
ci.eKey = NewNebulaCipherState(eKey)
hostinfo.remotes = f.lightHouse.QueryCache(vpnAddrs)
hostinfo.SetRemote(addr)
if !via.IsRelayed {
hostinfo.SetRemote(via.UdpAddr)
}
hostinfo.buildNetworks(f.myVpnNetworksTable, remoteCert.Certificate)
existing, err := f.handshakeManager.CheckAndComplete(hostinfo, 0, f)
@@ -337,7 +339,7 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet
switch err {
case ErrAlreadySeen:
// Update remote if preferred
if existing.SetRemoteIfPreferred(f.hostMap, addr) {
if existing.SetRemoteIfPreferred(f.hostMap, via) {
// Send a test packet to ensure the other side has also switched to
// the preferred remote
f.SendMessageToVpnAddr(header.Test, header.TestRequest, vpnAddrs[0], []byte(""), make([]byte, 12, 12), make([]byte, mtu))
@@ -345,21 +347,21 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet
msg = existing.HandshakePacket[2]
f.messageMetrics.Tx(header.Handshake, header.MessageSubType(msg[1]), 1)
if addr.IsValid() {
err := f.outside.WriteTo(msg, addr)
if !via.IsRelayed {
err := f.outside.WriteTo(msg, via.UdpAddr)
if err != nil {
f.l.WithField("vpnAddrs", existing.vpnAddrs).WithField("udpAddr", addr).
f.l.WithField("vpnAddrs", existing.vpnAddrs).WithField("from", via).
WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).WithField("cached", true).
WithError(err).Error("Failed to send handshake message")
} else {
f.l.WithField("vpnAddrs", existing.vpnAddrs).WithField("udpAddr", addr).
f.l.WithField("vpnAddrs", existing.vpnAddrs).WithField("from", via).
WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).WithField("cached", true).
Info("Handshake message sent")
}
return
} else {
if via == nil {
f.l.Error("Handshake send failed: both addr and via are nil.")
if via.relay == nil {
f.l.Error("Handshake send failed: both addr and via.relay are nil.")
return
}
hostinfo.relayState.InsertRelayTo(via.relayHI.vpnAddrs[0])
@@ -371,7 +373,7 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet
}
case ErrExistingHostInfo:
// This means there was an existing tunnel and this handshake was older than the one we are currently based on
f.l.WithField("vpnAddrs", vpnAddrs).WithField("udpAddr", addr).
f.l.WithField("vpnAddrs", vpnAddrs).WithField("from", via).
WithField("certName", certName).
WithField("certVersion", certVersion).
WithField("oldHandshakeTime", existing.lastHandshakeTime).
@@ -387,7 +389,7 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet
return
case ErrLocalIndexCollision:
// This means we failed to insert because of collision on localIndexId. Just let the next handshake packet retry
f.l.WithField("vpnAddrs", vpnAddrs).WithField("udpAddr", addr).
f.l.WithField("vpnAddrs", vpnAddrs).WithField("from", via).
WithField("certName", certName).
WithField("certVersion", certVersion).
WithField("fingerprint", fingerprint).
@@ -400,7 +402,7 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet
default:
// Shouldn't happen, but just in case someone adds a new error type to CheckAndComplete
// And we forget to update it here
f.l.WithError(err).WithField("vpnAddrs", vpnAddrs).WithField("udpAddr", addr).
f.l.WithError(err).WithField("vpnAddrs", vpnAddrs).WithField("from", via).
WithField("certName", certName).
WithField("certVersion", certVersion).
WithField("fingerprint", fingerprint).
@@ -414,30 +416,23 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet
// Do the send
f.messageMetrics.Tx(header.Handshake, header.MessageSubType(msg[1]), 1)
if addr.IsValid() {
err = f.outside.WriteTo(msg, addr)
if !via.IsRelayed {
err = f.outside.WriteTo(msg, via.UdpAddr)
log := f.l.WithField("vpnAddrs", vpnAddrs).WithField("from", via).
WithField("certName", certName).
WithField("certVersion", certVersion).
WithField("fingerprint", fingerprint).
WithField("issuer", issuer).
WithField("initiatorIndex", hs.Details.InitiatorIndex).WithField("responderIndex", hs.Details.ResponderIndex).
WithField("remoteIndex", h.RemoteIndex).WithField("handshake", m{"stage": 2, "style": "ix_psk0"})
if err != nil {
f.l.WithField("vpnAddrs", vpnAddrs).WithField("udpAddr", addr).
WithField("certName", certName).
WithField("certVersion", certVersion).
WithField("fingerprint", fingerprint).
WithField("issuer", issuer).
WithField("initiatorIndex", hs.Details.InitiatorIndex).WithField("responderIndex", hs.Details.ResponderIndex).
WithField("remoteIndex", h.RemoteIndex).WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).
WithError(err).Error("Failed to send handshake")
log.WithError(err).Error("Failed to send handshake")
} else {
f.l.WithField("vpnAddrs", vpnAddrs).WithField("udpAddr", addr).
WithField("certName", certName).
WithField("certVersion", certVersion).
WithField("fingerprint", fingerprint).
WithField("issuer", issuer).
WithField("initiatorIndex", hs.Details.InitiatorIndex).WithField("responderIndex", hs.Details.ResponderIndex).
WithField("remoteIndex", h.RemoteIndex).WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).
Info("Handshake message sent")
log.Info("Handshake message sent")
}
} else {
if via == nil {
f.l.Error("Handshake send failed: both addr and via are nil.")
if via.relay == nil {
f.l.Error("Handshake send failed: both addr and via.relay are nil.")
return
}
hostinfo.relayState.InsertRelayTo(via.relayHI.vpnAddrs[0])
@@ -462,7 +457,7 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet
return
}
func ixHandshakeStage2(f *Interface, addr netip.AddrPort, via *ViaSender, hh *HandshakeHostInfo, packet []byte, h *header.H) bool {
func ixHandshakeStage2(f *Interface, via *ViaSender, hh *HandshakeHostInfo, packet []byte, h *header.H) bool {
if hh == nil {
// Nothing here to tear down, got a bogus stage 2 packet
return true
@@ -472,10 +467,10 @@ func ixHandshakeStage2(f *Interface, addr netip.AddrPort, via *ViaSender, hh *Ha
defer hh.Unlock()
hostinfo := hh.hostinfo
if addr.IsValid() {
if !via.IsRelayed {
// The vpnAddr we know about is the one we tried to handshake with, use it to apply the remote allow list.
if !f.lightHouse.GetRemoteAllowList().AllowAll(hostinfo.vpnAddrs, addr.Addr()) {
f.l.WithField("vpnAddrs", hostinfo.vpnAddrs).WithField("udpAddr", addr).Debug("lighthouse.remote_allow_list denied incoming handshake")
if !f.lightHouse.GetRemoteAllowList().AllowAll(hostinfo.vpnAddrs, via.UdpAddr.Addr()) {
f.l.WithField("vpnAddrs", hostinfo.vpnAddrs).WithField("from", via).Debug("lighthouse.remote_allow_list denied incoming handshake")
return false
}
}
@@ -483,7 +478,7 @@ func ixHandshakeStage2(f *Interface, addr netip.AddrPort, via *ViaSender, hh *Ha
ci := hostinfo.ConnectionState
msg, eKey, dKey, err := ci.H.ReadMessage(nil, packet[header.Len:])
if err != nil {
f.l.WithError(err).WithField("vpnAddrs", hostinfo.vpnAddrs).WithField("udpAddr", addr).
f.l.WithError(err).WithField("vpnAddrs", hostinfo.vpnAddrs).WithField("from", via).
WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).WithField("header", h).
Error("Failed to call noise.ReadMessage")
@@ -492,7 +487,7 @@ func ixHandshakeStage2(f *Interface, addr netip.AddrPort, via *ViaSender, hh *Ha
// near future
return false
} else if dKey == nil || eKey == nil {
f.l.WithField("vpnAddrs", hostinfo.vpnAddrs).WithField("udpAddr", addr).
f.l.WithField("vpnAddrs", hostinfo.vpnAddrs).WithField("from", via).
WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).
Error("Noise did not arrive at a key")
@@ -504,7 +499,7 @@ func ixHandshakeStage2(f *Interface, addr netip.AddrPort, via *ViaSender, hh *Ha
hs := &NebulaHandshake{}
err = hs.Unmarshal(msg)
if err != nil || hs.Details == nil {
f.l.WithError(err).WithField("vpnAddrs", hostinfo.vpnAddrs).WithField("udpAddr", addr).
f.l.WithError(err).WithField("vpnAddrs", hostinfo.vpnAddrs).WithField("from", via).
WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).Error("Failed unmarshal handshake message")
// The handshake state machine is complete, if things break now there is no chance to recover. Tear down and start again
@@ -513,7 +508,7 @@ func ixHandshakeStage2(f *Interface, addr netip.AddrPort, via *ViaSender, hh *Ha
rc, err := cert.Recombine(cert.Version(hs.Details.CertVersion), hs.Details.Cert, ci.H.PeerStatic(), ci.Curve())
if err != nil {
f.l.WithError(err).WithField("udpAddr", addr).
f.l.WithError(err).WithField("from", via).
WithField("vpnAddrs", hostinfo.vpnAddrs).
WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).
Info("Handshake did not contain a certificate")
@@ -527,7 +522,7 @@ func ixHandshakeStage2(f *Interface, addr netip.AddrPort, via *ViaSender, hh *Ha
fp = "<error generating certificate fingerprint>"
}
e := f.l.WithError(err).WithField("udpAddr", addr).
e := f.l.WithError(err).WithField("from", via).
WithField("vpnAddrs", hostinfo.vpnAddrs).
WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).
WithField("certFingerprint", fp).
@@ -542,7 +537,7 @@ func ixHandshakeStage2(f *Interface, addr netip.AddrPort, via *ViaSender, hh *Ha
}
if len(remoteCert.Certificate.Networks()) == 0 {
f.l.WithError(err).WithField("udpAddr", addr).
f.l.WithError(err).WithField("from", via).
WithField("vpnAddrs", hostinfo.vpnAddrs).
WithField("cert", remoteCert).
WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).
@@ -565,8 +560,8 @@ func ixHandshakeStage2(f *Interface, addr netip.AddrPort, via *ViaSender, hh *Ha
ci.eKey = NewNebulaCipherState(eKey)
// Make sure the current udpAddr being used is set for responding
if addr.IsValid() {
hostinfo.SetRemote(addr)
if !via.IsRelayed {
hostinfo.SetRemote(via.UdpAddr)
} else {
hostinfo.relayState.InsertRelayTo(via.relayHI.vpnAddrs[0])
}
@@ -588,7 +583,7 @@ func ixHandshakeStage2(f *Interface, addr netip.AddrPort, via *ViaSender, hh *Ha
// Ensure the right host responded
if !correctHostResponded {
f.l.WithField("intendedVpnAddrs", hostinfo.vpnAddrs).WithField("haveVpnNetworks", vpnNetworks).
WithField("udpAddr", addr).
WithField("from", via).
WithField("certName", certName).
WithField("certVersion", certVersion).
WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).
@@ -602,7 +597,7 @@ func ixHandshakeStage2(f *Interface, addr netip.AddrPort, via *ViaSender, hh *Ha
f.handshakeManager.StartHandshake(hostinfo.vpnAddrs[0], func(newHH *HandshakeHostInfo) {
// Block the current used address
newHH.hostinfo.remotes = hostinfo.remotes
newHH.hostinfo.remotes.BlockRemote(addr)
newHH.hostinfo.remotes.BlockRemote(via)
f.l.WithField("blockedUdpAddrs", newHH.hostinfo.remotes.CopyBlockedRemotes()).
WithField("vpnNetworks", vpnNetworks).
@@ -625,7 +620,7 @@ func ixHandshakeStage2(f *Interface, addr netip.AddrPort, via *ViaSender, hh *Ha
ci.window.Update(f.l, 2)
duration := time.Since(hh.startTime).Nanoseconds()
msgRxL := f.l.WithField("vpnAddrs", vpnAddrs).WithField("udpAddr", addr).
msgRxL := f.l.WithField("vpnAddrs", vpnAddrs).WithField("from", via).
WithField("certName", certName).
WithField("certVersion", certVersion).
WithField("fingerprint", fingerprint).

View File

@@ -136,11 +136,11 @@ func (hm *HandshakeManager) Run(ctx context.Context) {
}
}
func (hm *HandshakeManager) HandleIncoming(addr netip.AddrPort, via *ViaSender, packet []byte, h *header.H) {
func (hm *HandshakeManager) HandleIncoming(via *ViaSender, packet []byte, h *header.H) {
// First remote allow list check before we know the vpnIp
if addr.IsValid() {
if !hm.lightHouse.GetRemoteAllowList().AllowUnknownVpnAddr(addr.Addr()) {
hm.l.WithField("udpAddr", addr).Debug("lighthouse.remote_allow_list denied incoming handshake")
if !via.IsRelayed {
if !hm.lightHouse.GetRemoteAllowList().AllowUnknownVpnAddr(via.UdpAddr.Addr()) {
hm.l.WithField("from", via).Debug("lighthouse.remote_allow_list denied incoming handshake")
return
}
}
@@ -149,11 +149,11 @@ func (hm *HandshakeManager) HandleIncoming(addr netip.AddrPort, via *ViaSender,
case header.HandshakeIXPSK0:
switch h.MessageCounter {
case 1:
ixHandshakeStage1(hm.f, addr, via, packet, h)
ixHandshakeStage1(hm.f, via, packet, h)
case 2:
newHostinfo := hm.queryIndex(h.RemoteIndex)
tearDown := ixHandshakeStage2(hm.f, addr, via, newHostinfo, packet, h)
tearDown := ixHandshakeStage2(hm.f, via, newHostinfo, packet, h)
if tearDown && newHostinfo != nil {
hm.DeleteHostInfo(newHostinfo.hostinfo)
}

View File

@@ -1,7 +1,9 @@
package nebula
import (
"encoding/json"
"errors"
"fmt"
"net"
"net/netip"
"slices"
@@ -276,9 +278,25 @@ type HostInfo struct {
}
type ViaSender struct {
UdpAddr netip.AddrPort
relayHI *HostInfo // relayHI is the host info object of the relay
remoteIdx uint32 // remoteIdx is the index included in the header of the received packet
relay *Relay // relay contains the rest of the relay information, including the PeerIP of the host trying to communicate with us.
IsRelayed bool // IsRelayed is true if the packet was sent through a relay
}
func (v *ViaSender) String() string {
if v.IsRelayed {
return fmt.Sprintf("%s (relayed)", v.UdpAddr)
}
return v.UdpAddr.String()
}
func (v *ViaSender) MarshalJSON() ([]byte, error) {
if v.IsRelayed {
return json.Marshal(m{"direct": v.UdpAddr})
}
return json.Marshal(m{"relay": v.UdpAddr})
}
type cachedPacket struct {
@@ -694,6 +712,7 @@ func (i *HostInfo) GetCert() *cert.CachedCertificate {
return nil
}
// TODO: Maybe use ViaSender here?
func (i *HostInfo) SetRemote(remote netip.AddrPort) {
// We copy here because we likely got this remote from a source that reuses the object
if i.remote != remote {
@@ -704,14 +723,14 @@ func (i *HostInfo) SetRemote(remote netip.AddrPort) {
// SetRemoteIfPreferred returns true if the remote was changed. The lastRoam
// time on the HostInfo will also be updated.
func (i *HostInfo) SetRemoteIfPreferred(hm *HostMap, newRemote netip.AddrPort) bool {
if !newRemote.IsValid() {
// relays have nil udp Addrs
func (i *HostInfo) SetRemoteIfPreferred(hm *HostMap, via *ViaSender) bool {
if via.IsRelayed {
return false
}
currentRemote := i.remote
if !currentRemote.IsValid() {
i.SetRemote(newRemote)
i.SetRemote(via.UdpAddr)
return true
}
@@ -724,7 +743,7 @@ func (i *HostInfo) SetRemoteIfPreferred(hm *HostMap, newRemote netip.AddrPort) b
return false
}
if l.Contains(newRemote.Addr()) {
if l.Contains(via.UdpAddr.Addr()) {
newIsPreferred = true
}
}
@@ -734,7 +753,7 @@ func (i *HostInfo) SetRemoteIfPreferred(hm *HostMap, newRemote netip.AddrPort) b
i.lastRoam = time.Now()
i.lastRoamRemote = currentRemote
i.SetRemote(newRemote)
i.SetRemote(via.UdpAddr)
return true
}

View File

@@ -279,7 +279,7 @@ func (f *Interface) listenOut(i int) {
nb := make([]byte, 12, 12)
li.ListenOut(func(fromUdpAddr netip.AddrPort, payload []byte) {
f.readOutsidePackets(fromUdpAddr, nil, plaintext[:0], payload, h, fwPacket, lhh, nb, i, ctCache.Get(f.l))
f.readOutsidePackets(&ViaSender{UdpAddr: fromUdpAddr}, plaintext[:0], payload, h, fwPacket, lhh, nb, i, ctCache.Get(f.l))
})
}

View File

@@ -19,21 +19,21 @@ const (
minFwPacketLen = 4
)
func (f *Interface) readOutsidePackets(ip netip.AddrPort, via *ViaSender, out []byte, packet []byte, h *header.H, fwPacket *firewall.Packet, lhf *LightHouseHandler, nb []byte, q int, localCache firewall.ConntrackCache) {
func (f *Interface) readOutsidePackets(via *ViaSender, out []byte, packet []byte, h *header.H, fwPacket *firewall.Packet, lhf *LightHouseHandler, nb []byte, q int, localCache firewall.ConntrackCache) {
err := h.Parse(packet)
if err != nil {
// Hole punch packets are 0 or 1 byte big, so lets ignore printing those errors
if len(packet) > 1 {
f.l.WithField("packet", packet).Infof("Error while parsing inbound packet from %s: %s", ip, err)
f.l.WithField("packet", packet).Infof("Error while parsing inbound packet from %s: %s", via, err)
}
return
}
//l.Error("in packet ", header, packet[HeaderLen:])
if ip.IsValid() {
if f.myVpnNetworksTable.Contains(ip.Addr()) {
if !via.IsRelayed {
if f.myVpnNetworksTable.Contains(via.UdpAddr.Addr()) {
if f.l.Level >= logrus.DebugLevel {
f.l.WithField("udpAddr", ip).Debug("Refusing to process double encrypted packet")
f.l.WithField("from", via).Debug("Refusing to process double encrypted packet")
}
return
}
@@ -54,8 +54,7 @@ func (f *Interface) readOutsidePackets(ip netip.AddrPort, via *ViaSender, out []
switch h.Type {
case header.Message:
// TODO handleEncrypted sends directly to addr on error. Handle this in the tunneling case.
if !f.handleEncrypted(ci, ip, h) {
if !f.handleEncrypted(ci, via, h) {
return
}
@@ -79,7 +78,7 @@ func (f *Interface) readOutsidePackets(ip netip.AddrPort, via *ViaSender, out []
// Successfully validated the thing. Get rid of the Relay header.
signedPayload = signedPayload[header.Len:]
// Pull the Roaming parts up here, and return in all call paths.
f.handleHostRoaming(hostinfo, ip)
f.handleHostRoaming(hostinfo, via)
// Track usage of both the HostInfo and the Relay for the received & authenticated packet
f.connectionManager.In(hostinfo)
f.connectionManager.RelayUsed(h.RemoteIndex)
@@ -96,7 +95,14 @@ func (f *Interface) readOutsidePackets(ip netip.AddrPort, via *ViaSender, out []
case TerminalType:
// If I am the target of this relay, process the unwrapped packet
// From this recursive point, all these variables are 'burned'. We shouldn't rely on them again.
f.readOutsidePackets(netip.AddrPort{}, &ViaSender{relayHI: hostinfo, remoteIdx: relay.RemoteIndex, relay: relay}, out[:0], signedPayload, h, fwPacket, lhf, nb, q, localCache)
rVia := ViaSender{
UdpAddr: via.UdpAddr,
relayHI: hostinfo,
remoteIdx: relay.RemoteIndex,
relay: relay,
IsRelayed: true,
}
f.readOutsidePackets(&rVia, out[:0], signedPayload, h, fwPacket, lhf, nb, q, localCache)
return
case ForwardingType:
// Find the target HostInfo relay object
@@ -126,31 +132,32 @@ func (f *Interface) readOutsidePackets(ip netip.AddrPort, via *ViaSender, out []
case header.LightHouse:
f.messageMetrics.Rx(h.Type, h.Subtype, 1)
if !f.handleEncrypted(ci, ip, h) {
if !f.handleEncrypted(ci, via, h) {
return
}
d, err := f.decrypt(hostinfo, h.MessageCounter, out, packet, h, nb)
if err != nil {
hostinfo.logger(f.l).WithError(err).WithField("udpAddr", ip).
hostinfo.logger(f.l).WithError(err).WithField("from", via).
WithField("packet", packet).
Error("Failed to decrypt lighthouse packet")
return
}
lhf.HandleRequest(ip, hostinfo.vpnAddrs, d, f)
//TODO: assert via is not relayed
lhf.HandleRequest(via.UdpAddr, hostinfo.vpnAddrs, d, f)
// Fallthrough to the bottom to record incoming traffic
case header.Test:
f.messageMetrics.Rx(h.Type, h.Subtype, 1)
if !f.handleEncrypted(ci, ip, h) {
if !f.handleEncrypted(ci, via, h) {
return
}
d, err := f.decrypt(hostinfo, h.MessageCounter, out, packet, h, nb)
if err != nil {
hostinfo.logger(f.l).WithError(err).WithField("udpAddr", ip).
hostinfo.logger(f.l).WithError(err).WithField("from", via).
WithField("packet", packet).
Error("Failed to decrypt test packet")
return
@@ -159,7 +166,7 @@ func (f *Interface) readOutsidePackets(ip netip.AddrPort, via *ViaSender, out []
if h.Subtype == header.TestRequest {
// This testRequest might be from TryPromoteBest, so we should roam
// to the new IP address before responding
f.handleHostRoaming(hostinfo, ip)
f.handleHostRoaming(hostinfo, via)
f.send(header.Test, header.TestReply, ci, hostinfo, d, nb, out)
}
@@ -170,34 +177,34 @@ func (f *Interface) readOutsidePackets(ip netip.AddrPort, via *ViaSender, out []
case header.Handshake:
f.messageMetrics.Rx(h.Type, h.Subtype, 1)
f.handshakeManager.HandleIncoming(ip, via, packet, h)
f.handshakeManager.HandleIncoming(via, packet, h)
return
case header.RecvError:
f.messageMetrics.Rx(h.Type, h.Subtype, 1)
f.handleRecvError(ip, h)
f.handleRecvError(via.UdpAddr, h)
return
case header.CloseTunnel:
f.messageMetrics.Rx(h.Type, h.Subtype, 1)
if !f.handleEncrypted(ci, ip, h) {
if !f.handleEncrypted(ci, via, h) {
return
}
hostinfo.logger(f.l).WithField("udpAddr", ip).
hostinfo.logger(f.l).WithField("from", via).
Info("Close tunnel received, tearing down.")
f.closeTunnel(hostinfo)
return
case header.Control:
if !f.handleEncrypted(ci, ip, h) {
if !f.handleEncrypted(ci, via, h) {
return
}
d, err := f.decrypt(hostinfo, h.MessageCounter, out, packet, h, nb)
if err != nil {
hostinfo.logger(f.l).WithError(err).WithField("udpAddr", ip).
hostinfo.logger(f.l).WithError(err).WithField("from", via).
WithField("packet", packet).
Error("Failed to decrypt Control packet")
return
@@ -207,11 +214,11 @@ func (f *Interface) readOutsidePackets(ip netip.AddrPort, via *ViaSender, out []
default:
f.messageMetrics.Rx(h.Type, h.Subtype, 1)
hostinfo.logger(f.l).Debugf("Unexpected packet received from %s", ip)
hostinfo.logger(f.l).Debugf("Unexpected packet received from %s", via)
return
}
f.handleHostRoaming(hostinfo, ip)
f.handleHostRoaming(hostinfo, via)
f.connectionManager.In(hostinfo)
}
@@ -230,36 +237,36 @@ func (f *Interface) sendCloseTunnel(h *HostInfo) {
f.send(header.CloseTunnel, 0, h.ConnectionState, h, []byte{}, make([]byte, 12, 12), make([]byte, mtu))
}
func (f *Interface) handleHostRoaming(hostinfo *HostInfo, udpAddr netip.AddrPort) {
if udpAddr.IsValid() && hostinfo.remote != udpAddr {
if !f.lightHouse.GetRemoteAllowList().AllowAll(hostinfo.vpnAddrs, udpAddr.Addr()) {
hostinfo.logger(f.l).WithField("newAddr", udpAddr).Debug("lighthouse.remote_allow_list denied roaming")
func (f *Interface) handleHostRoaming(hostinfo *HostInfo, via *ViaSender) {
if !via.IsRelayed && hostinfo.remote != via.UdpAddr {
if !f.lightHouse.GetRemoteAllowList().AllowAll(hostinfo.vpnAddrs, via.UdpAddr.Addr()) {
hostinfo.logger(f.l).WithField("newAddr", via.UdpAddr).Debug("lighthouse.remote_allow_list denied roaming")
return
}
if !hostinfo.lastRoam.IsZero() && udpAddr == hostinfo.lastRoamRemote && time.Since(hostinfo.lastRoam) < RoamingSuppressSeconds*time.Second {
if !hostinfo.lastRoam.IsZero() && via.UdpAddr == hostinfo.lastRoamRemote && time.Since(hostinfo.lastRoam) < RoamingSuppressSeconds*time.Second {
if f.l.Level >= logrus.DebugLevel {
hostinfo.logger(f.l).WithField("udpAddr", hostinfo.remote).WithField("newAddr", udpAddr).
hostinfo.logger(f.l).WithField("udpAddr", hostinfo.remote).WithField("newAddr", via.UdpAddr).
Debugf("Suppressing roam back to previous remote for %d seconds", RoamingSuppressSeconds)
}
return
}
hostinfo.logger(f.l).WithField("udpAddr", hostinfo.remote).WithField("newAddr", udpAddr).
hostinfo.logger(f.l).WithField("udpAddr", hostinfo.remote).WithField("newAddr", via.UdpAddr).
Info("Host roamed to new udp ip/port.")
hostinfo.lastRoam = time.Now()
hostinfo.lastRoamRemote = hostinfo.remote
hostinfo.SetRemote(udpAddr)
hostinfo.SetRemote(via.UdpAddr)
}
}
// handleEncrypted returns true if a packet should be processed, false otherwise
func (f *Interface) handleEncrypted(ci *ConnectionState, addr netip.AddrPort, h *header.H) bool {
func (f *Interface) handleEncrypted(ci *ConnectionState, via *ViaSender, h *header.H) bool {
// If connectionstate does not exist, send a recv error, if possible, to encourage a fast reconnect
if ci == nil {
if addr.IsValid() {
f.maybeSendRecvError(addr, h.RemoteIndex)
if !via.IsRelayed {
f.maybeSendRecvError(via.UdpAddr, h.RemoteIndex)
}
return false
}

View File

@@ -338,21 +338,21 @@ func (r *RemoteList) CopyCache() *CacheMap {
}
// BlockRemote locks and records the address as bad, it will be excluded from the deduplicated address list
func (r *RemoteList) BlockRemote(bad netip.AddrPort) {
if !bad.IsValid() {
// relays can have nil udp Addrs
func (r *RemoteList) BlockRemote(bad *ViaSender) {
if bad.IsRelayed {
return
}
r.Lock()
defer r.Unlock()
// Check if we already blocked this addr
if r.unlockedIsBad(bad) {
if r.unlockedIsBad(bad.UdpAddr) {
return
}
// We copy here because we are taking something else's memory and we can't trust everything
r.badRemotes = append(r.badRemotes, bad)
r.badRemotes = append(r.badRemotes, bad.UdpAddr)
// Mark the next interaction must recollect/dedupe
r.shouldRebuild = true