mirror of
https://github.com/slackhq/nebula.git
synced 2025-11-11 18:33:57 +01:00
Remove unusable networks from remote tunnels at handshake time (#1318)
This commit is contained in:
parent
84dd988f01
commit
1ad0f57c1e
@ -8,7 +8,6 @@ import (
|
||||
"hash/fnv"
|
||||
"net/netip"
|
||||
"reflect"
|
||||
"slices"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
@ -438,9 +437,8 @@ func (f *Firewall) Drop(fp firewall.Packet, incoming bool, h *HostInfo, caPool *
|
||||
return ErrInvalidRemoteIP
|
||||
}
|
||||
} else {
|
||||
// Simple case: Certificate has one IP and no subnets
|
||||
//TODO: we can make this more performant
|
||||
if !slices.Contains(h.vpnAddrs, fp.RemoteAddr) {
|
||||
// Simple case: Certificate has one address and no unsafe networks
|
||||
if h.vpnAddrs[0] != fp.RemoteAddr {
|
||||
f.metrics(incoming).droppedRemoteAddr.Inc(1)
|
||||
return ErrInvalidRemoteIP
|
||||
}
|
||||
|
||||
@ -152,7 +152,7 @@ func TestFirewall_Drop(t *testing.T) {
|
||||
},
|
||||
vpnAddrs: []netip.Addr{netip.MustParseAddr("1.2.3.4")},
|
||||
}
|
||||
h.buildNetworks(&c)
|
||||
h.buildNetworks(c.networks, c.unsafeNetworks)
|
||||
|
||||
fw := NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
|
||||
assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"any"}, "", netip.Prefix{}, netip.Prefix{}, "", ""))
|
||||
@ -332,7 +332,7 @@ func TestFirewall_Drop2(t *testing.T) {
|
||||
},
|
||||
vpnAddrs: []netip.Addr{network.Addr()},
|
||||
}
|
||||
h.buildNetworks(c.Certificate)
|
||||
h.buildNetworks(c.Certificate.Networks(), c.Certificate.UnsafeNetworks())
|
||||
|
||||
c1 := cert.CachedCertificate{
|
||||
Certificate: &dummyCert{
|
||||
@ -342,11 +342,12 @@ func TestFirewall_Drop2(t *testing.T) {
|
||||
InvertedGroups: map[string]struct{}{"default-group": {}, "test-group-not": {}},
|
||||
}
|
||||
h1 := HostInfo{
|
||||
vpnAddrs: []netip.Addr{network.Addr()},
|
||||
ConnectionState: &ConnectionState{
|
||||
peerCert: &c1,
|
||||
},
|
||||
}
|
||||
h1.buildNetworks(c1.Certificate)
|
||||
h1.buildNetworks(c1.Certificate.Networks(), c1.Certificate.UnsafeNetworks())
|
||||
|
||||
fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate)
|
||||
assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group", "test-group"}, "", netip.Prefix{}, netip.Prefix{}, "", ""))
|
||||
@ -394,7 +395,7 @@ func TestFirewall_Drop3(t *testing.T) {
|
||||
},
|
||||
vpnAddrs: []netip.Addr{network.Addr()},
|
||||
}
|
||||
h1.buildNetworks(c1.Certificate)
|
||||
h1.buildNetworks(c1.Certificate.Networks(), c1.Certificate.UnsafeNetworks())
|
||||
|
||||
c2 := cert.CachedCertificate{
|
||||
Certificate: &dummyCert{
|
||||
@ -409,7 +410,7 @@ func TestFirewall_Drop3(t *testing.T) {
|
||||
},
|
||||
vpnAddrs: []netip.Addr{network.Addr()},
|
||||
}
|
||||
h2.buildNetworks(c2.Certificate)
|
||||
h2.buildNetworks(c2.Certificate.Networks(), c2.Certificate.UnsafeNetworks())
|
||||
|
||||
c3 := cert.CachedCertificate{
|
||||
Certificate: &dummyCert{
|
||||
@ -424,7 +425,7 @@ func TestFirewall_Drop3(t *testing.T) {
|
||||
},
|
||||
vpnAddrs: []netip.Addr{network.Addr()},
|
||||
}
|
||||
h3.buildNetworks(c3.Certificate)
|
||||
h3.buildNetworks(c3.Certificate.Networks(), c3.Certificate.UnsafeNetworks())
|
||||
|
||||
fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate)
|
||||
assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 1, 1, []string{}, "host1", netip.Prefix{}, netip.Prefix{}, "", ""))
|
||||
@ -471,7 +472,7 @@ func TestFirewall_DropConntrackReload(t *testing.T) {
|
||||
},
|
||||
vpnAddrs: []netip.Addr{network.Addr()},
|
||||
}
|
||||
h.buildNetworks(c.Certificate)
|
||||
h.buildNetworks(c.Certificate.Networks(), c.Certificate.UnsafeNetworks())
|
||||
|
||||
fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate)
|
||||
assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"any"}, "", netip.Prefix{}, netip.Prefix{}, "", ""))
|
||||
|
||||
@ -172,6 +172,7 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet
|
||||
}
|
||||
|
||||
var vpnAddrs []netip.Addr
|
||||
var filteredNetworks []netip.Prefix
|
||||
certName := remoteCert.Certificate.Name()
|
||||
fingerprint := remoteCert.Fingerprint
|
||||
issuer := remoteCert.Certificate.Issuer()
|
||||
@ -189,15 +190,32 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet
|
||||
}
|
||||
|
||||
if addr.IsValid() {
|
||||
// addr can be invalid when the tunnel is being relayed.
|
||||
// We only want to apply the remote allow list for direct tunnels here
|
||||
if !f.lightHouse.GetRemoteAllowList().Allow(vpnAddr, addr.Addr()) {
|
||||
f.l.WithField("vpnAddr", vpnAddr).WithField("udpAddr", addr).Debug("lighthouse.remote_allow_list denied incoming handshake")
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// vpnAddrs outside our vpn networks are of no use to us, filter them out
|
||||
if _, ok := f.myVpnNetworksTable.Lookup(vpnAddr); !ok {
|
||||
continue
|
||||
}
|
||||
|
||||
filteredNetworks = append(filteredNetworks, network)
|
||||
vpnAddrs = append(vpnAddrs, vpnAddr)
|
||||
}
|
||||
|
||||
if len(vpnAddrs) == 0 {
|
||||
f.l.WithError(err).WithField("udpAddr", addr).
|
||||
WithField("certName", certName).
|
||||
WithField("fingerprint", fingerprint).
|
||||
WithField("issuer", issuer).
|
||||
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).Error("No usable vpn addresses from host, refusing handshake")
|
||||
return
|
||||
}
|
||||
|
||||
myIndex, err := generateIndex(f.l)
|
||||
if err != nil {
|
||||
f.l.WithError(err).WithField("vpnAddrs", vpnAddrs).WithField("udpAddr", addr).
|
||||
@ -294,7 +312,7 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet
|
||||
|
||||
hostinfo.remotes = f.lightHouse.QueryCache(vpnAddrs)
|
||||
hostinfo.SetRemote(addr)
|
||||
hostinfo.buildNetworks(remoteCert.Certificate)
|
||||
hostinfo.buildNetworks(filteredNetworks, remoteCert.Certificate.UnsafeNetworks())
|
||||
|
||||
existing, err := f.handshakeManager.CheckAndComplete(hostinfo, 0, f)
|
||||
if err != nil {
|
||||
@ -431,7 +449,7 @@ func ixHandshakeStage2(f *Interface, addr netip.AddrPort, via *ViaSender, hh *Ha
|
||||
|
||||
hostinfo := hh.hostinfo
|
||||
if addr.IsValid() {
|
||||
//TODO: this is kind of nonsense now
|
||||
// The vpnAddr we know about is the one we tried to handshake with, use it to apply the remote allow list.
|
||||
if !f.lightHouse.GetRemoteAllowList().Allow(hostinfo.vpnAddrs[0], addr.Addr()) {
|
||||
f.l.WithField("vpnIp", hostinfo.vpnAddrs).WithField("udpAddr", addr).Debug("lighthouse.remote_allow_list denied incoming handshake")
|
||||
return false
|
||||
@ -492,7 +510,7 @@ func ixHandshakeStage2(f *Interface, addr netip.AddrPort, via *ViaSender, hh *Ha
|
||||
e = e.WithField("cert", remoteCert)
|
||||
}
|
||||
|
||||
e.Info("Invalid vpn ip from host")
|
||||
e.Info("Empty networks from host")
|
||||
return true
|
||||
}
|
||||
|
||||
@ -516,9 +534,26 @@ func ixHandshakeStage2(f *Interface, addr netip.AddrPort, via *ViaSender, hh *Ha
|
||||
hostinfo.relayState.InsertRelayTo(via.relayHI.vpnAddrs[0])
|
||||
}
|
||||
|
||||
vpnAddrs := make([]netip.Addr, len(vpnNetworks))
|
||||
for i, n := range vpnNetworks {
|
||||
vpnAddrs[i] = n.Addr()
|
||||
var vpnAddrs []netip.Addr
|
||||
var filteredNetworks []netip.Prefix
|
||||
for _, network := range vpnNetworks {
|
||||
// vpnAddrs outside our vpn networks are of no use to us, filter them out
|
||||
vpnAddr := network.Addr()
|
||||
if _, ok := f.myVpnNetworksTable.Lookup(vpnAddr); !ok {
|
||||
continue
|
||||
}
|
||||
|
||||
filteredNetworks = append(filteredNetworks, network)
|
||||
vpnAddrs = append(vpnAddrs, vpnAddr)
|
||||
}
|
||||
|
||||
if len(vpnAddrs) == 0 {
|
||||
f.l.WithError(err).WithField("udpAddr", addr).
|
||||
WithField("certName", certName).
|
||||
WithField("fingerprint", fingerprint).
|
||||
WithField("issuer", issuer).
|
||||
WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).Error("No usable vpn addresses from host, refusing handshake")
|
||||
return true
|
||||
}
|
||||
|
||||
// Ensure the right host responded
|
||||
@ -558,7 +593,7 @@ func ixHandshakeStage2(f *Interface, addr netip.AddrPort, via *ViaSender, hh *Ha
|
||||
ci.window.Update(f.l, 2)
|
||||
|
||||
duration := time.Since(hh.startTime).Nanoseconds()
|
||||
f.l.WithField("vpnNetworks", vpnNetworks).WithField("udpAddr", addr).
|
||||
f.l.WithField("vpnAddrs", vpnAddrs).WithField("udpAddr", addr).
|
||||
WithField("certName", certName).
|
||||
WithField("fingerprint", fingerprint).
|
||||
WithField("issuer", issuer).
|
||||
@ -569,9 +604,10 @@ func ixHandshakeStage2(f *Interface, addr netip.AddrPort, via *ViaSender, hh *Ha
|
||||
Info("Handshake message received")
|
||||
|
||||
// Build up the radix for the firewall if we have subnets in the cert
|
||||
hostinfo.buildNetworks(remoteCert.Certificate)
|
||||
hostinfo.vpnAddrs = vpnAddrs
|
||||
hostinfo.buildNetworks(filteredNetworks, remoteCert.Certificate.UnsafeNetworks())
|
||||
|
||||
// Complete our handshake and update metrics, this will replace any existing tunnels for this vpnIp
|
||||
// Complete our handshake and update metrics, this will replace any existing tunnels for the vpnAddrs here
|
||||
f.handshakeManager.Complete(hostinfo, f)
|
||||
f.connectionManager.AddTrafficWatch(hostinfo.localIndexId)
|
||||
|
||||
|
||||
16
hostmap.go
16
hostmap.go
@ -215,8 +215,12 @@ type HostInfo struct {
|
||||
ConnectionState *ConnectionState
|
||||
remoteIndexId uint32
|
||||
localIndexId uint32
|
||||
vpnAddrs []netip.Addr
|
||||
recvError atomic.Uint32
|
||||
|
||||
// vpnAddrs is a list of vpn addresses assigned to this host that are within our own vpn networks
|
||||
// The host may have other vpn addresses that are outside our
|
||||
// vpn networks but were removed because they are not usable
|
||||
vpnAddrs []netip.Addr
|
||||
recvError atomic.Uint32
|
||||
|
||||
// networks are both all vpn and unsafe networks assigned to this host
|
||||
networks *bart.Table[struct{}]
|
||||
@ -712,18 +716,18 @@ func (i *HostInfo) RecvErrorExceeded() bool {
|
||||
return true
|
||||
}
|
||||
|
||||
func (i *HostInfo) buildNetworks(c cert.Certificate) {
|
||||
if len(c.Networks()) == 1 && len(c.UnsafeNetworks()) == 0 {
|
||||
func (i *HostInfo) buildNetworks(networks, unsafeNetworks []netip.Prefix) {
|
||||
if len(networks) == 1 && len(unsafeNetworks) == 0 {
|
||||
// Simple case, no CIDRTree needed
|
||||
return
|
||||
}
|
||||
|
||||
i.networks = new(bart.Table[struct{}])
|
||||
for _, network := range c.Networks() {
|
||||
for _, network := range networks {
|
||||
i.networks.Insert(network, struct{}{})
|
||||
}
|
||||
|
||||
for _, network := range c.UnsafeNetworks() {
|
||||
for _, network := range unsafeNetworks {
|
||||
i.networks.Insert(network, struct{}{})
|
||||
}
|
||||
}
|
||||
|
||||
@ -64,7 +64,7 @@ type Interface struct {
|
||||
myBroadcastAddrsTable *bart.Table[struct{}]
|
||||
myVpnAddrs []netip.Addr // A list of addresses assigned to us via our certificate
|
||||
myVpnAddrsTable *bart.Table[struct{}] // A table of addresses assigned to us via our certificate
|
||||
myVpnNetworks []netip.Prefix // A table of networks assigned to us via our certificate
|
||||
myVpnNetworks []netip.Prefix // A list of networks assigned to us via our certificate
|
||||
myVpnNetworksTable *bart.Table[struct{}] // A table of networks assigned to us via our certificate
|
||||
dropLocalBroadcast bool
|
||||
dropMulticast bool
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user