diff --git a/control.go b/control.go index 20dd7fe..016a79b 100644 --- a/control.go +++ b/control.go @@ -131,8 +131,7 @@ func (c *Control) ListHostmapIndexes(pendingMap bool) []ControlHostInfo { // GetCertByVpnIp returns the authenticated certificate of the given vpn IP, or nil if not found func (c *Control) GetCertByVpnIp(vpnIp netip.Addr) cert.Certificate { - _, found := c.f.myVpnAddrsTable.Lookup(vpnIp) - if found { + if c.f.myVpnAddrsTable.Contains(vpnIp) { // Only returning the default certificate since its impossible // for any other host but ourselves to have more than 1 return c.f.pki.getCertState().GetDefaultCertificate().Copy() diff --git a/dns_server.go b/dns_server.go index 710f6ed..7357654 100644 --- a/dns_server.go +++ b/dns_server.go @@ -26,7 +26,7 @@ type dnsRecords struct { dnsMap4 map[string]netip.Addr dnsMap6 map[string]netip.Addr hostMap *HostMap - myVpnAddrsTable *bart.Table[struct{}] + myVpnAddrsTable *bart.Lite } func newDnsRecords(l *logrus.Logger, cs *CertState, hostMap *HostMap) *dnsRecords { @@ -112,8 +112,8 @@ func (d *dnsRecords) isSelfNebulaOrLocalhost(addr string) bool { return true } - _, found := d.myVpnAddrsTable.Lookup(b) - return found //if we found it in this table, it's good + //if we found it in this table, it's good + return d.myVpnAddrsTable.Contains(b) } func (d *dnsRecords) parseQuery(m *dns.Msg, w dns.ResponseWriter) { diff --git a/firewall.go b/firewall.go index e730114..971c156 100644 --- a/firewall.go +++ b/firewall.go @@ -53,7 +53,7 @@ type Firewall struct { // routableNetworks describes the vpn addresses as well as any unsafe networks issued to us in the certificate. // The vpn addresses are a full bit match while the unsafe networks only match the prefix - routableNetworks *bart.Table[struct{}] + routableNetworks *bart.Lite // assignedNetworks is a list of vpn networks assigned to us in the certificate. assignedNetworks []netip.Prefix @@ -125,7 +125,7 @@ type firewallPort map[int32]*FirewallCA type firewallLocalCIDR struct { Any bool - LocalCIDR *bart.Table[struct{}] + LocalCIDR *bart.Lite } // NewFirewall creates a new Firewall object. A TimerWheel is created for you from the provided timeouts. @@ -148,17 +148,17 @@ func NewFirewall(l *logrus.Logger, tcpTimeout, UDPTimeout, defaultTimeout time.D tmax = defaultTimeout } - routableNetworks := new(bart.Table[struct{}]) + routableNetworks := new(bart.Lite) var assignedNetworks []netip.Prefix for _, network := range c.Networks() { nprefix := netip.PrefixFrom(network.Addr(), network.Addr().BitLen()) - routableNetworks.Insert(nprefix, struct{}{}) + routableNetworks.Insert(nprefix) assignedNetworks = append(assignedNetworks, network) } hasUnsafeNetworks := false for _, n := range c.UnsafeNetworks() { - routableNetworks.Insert(n, struct{}{}) + routableNetworks.Insert(n) hasUnsafeNetworks = true } @@ -431,8 +431,7 @@ func (f *Firewall) Drop(fp firewall.Packet, incoming bool, h *HostInfo, caPool * // Make sure remote address matches nebula certificate if h.networks != nil { - _, ok := h.networks.Lookup(fp.RemoteAddr) - if !ok { + if !h.networks.Contains(fp.RemoteAddr) { f.metrics(incoming).droppedRemoteAddr.Inc(1) return ErrInvalidRemoteIP } @@ -445,8 +444,7 @@ func (f *Firewall) Drop(fp firewall.Packet, incoming bool, h *HostInfo, caPool * } // Make sure we are supposed to be handling this local ip address - _, ok := f.routableNetworks.Lookup(fp.LocalAddr) - if !ok { + if !f.routableNetworks.Contains(fp.LocalAddr) { f.metrics(incoming).droppedLocalAddr.Inc(1) return ErrInvalidLocalIP } @@ -752,7 +750,7 @@ func (fc *FirewallCA) match(p firewall.Packet, c *cert.CachedCertificate, caPool func (fr *FirewallRule) addRule(f *Firewall, groups []string, host string, ip, localCIDR netip.Prefix) error { flc := func() *firewallLocalCIDR { return &firewallLocalCIDR{ - LocalCIDR: new(bart.Table[struct{}]), + LocalCIDR: new(bart.Lite), } } @@ -879,7 +877,7 @@ func (flc *firewallLocalCIDR) addRule(f *Firewall, localIp netip.Prefix) error { } for _, network := range f.assignedNetworks { - flc.LocalCIDR.Insert(network, struct{}{}) + flc.LocalCIDR.Insert(network) } return nil @@ -888,7 +886,7 @@ func (flc *firewallLocalCIDR) addRule(f *Firewall, localIp netip.Prefix) error { return nil } - flc.LocalCIDR.Insert(localIp, struct{}{}) + flc.LocalCIDR.Insert(localIp) return nil } @@ -901,8 +899,7 @@ func (flc *firewallLocalCIDR) match(p firewall.Packet, c *cert.CachedCertificate return true } - _, ok := flc.LocalCIDR.Lookup(p.LocalAddr) - return ok + return flc.LocalCIDR.Contains(p.LocalAddr) } type rule struct { diff --git a/handshake_ix.go b/handshake_ix.go index 571a19a..cf422b9 100644 --- a/handshake_ix.go +++ b/handshake_ix.go @@ -192,8 +192,7 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet for _, network := range remoteCert.Certificate.Networks() { vpnAddr := network.Addr() - _, found := f.myVpnAddrsTable.Lookup(vpnAddr) - if found { + if f.myVpnAddrsTable.Contains(vpnAddr) { f.l.WithField("vpnAddr", vpnAddr).WithField("udpAddr", addr). WithField("certName", certName). WithField("certVersion", certVersion). @@ -204,7 +203,7 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet } // vpnAddrs outside our vpn networks are of no use to us, filter them out - if _, ok := f.myVpnNetworksTable.Lookup(vpnAddr); !ok { + if !f.myVpnNetworksTable.Contains(vpnAddr) { continue } @@ -579,7 +578,7 @@ func ixHandshakeStage2(f *Interface, addr netip.AddrPort, via *ViaSender, hh *Ha for _, network := range vpnNetworks { // vpnAddrs outside our vpn networks are of no use to us, filter them out vpnAddr := network.Addr() - if _, ok := f.myVpnNetworksTable.Lookup(vpnAddr); !ok { + if !f.myVpnNetworksTable.Contains(vpnAddr) { continue } diff --git a/handshake_manager.go b/handshake_manager.go index 6f95402..486541b 100644 --- a/handshake_manager.go +++ b/handshake_manager.go @@ -274,8 +274,7 @@ func (hm *HandshakeManager) handleOutbound(vpnIp netip.Addr, lighthouseTriggered } // Don't relay through the host I'm trying to connect to - _, found := hm.f.myVpnAddrsTable.Lookup(relay) - if found { + if hm.f.myVpnAddrsTable.Contains(relay) { continue } diff --git a/hostmap.go b/hostmap.go index f9e3c4e..359749b 100644 --- a/hostmap.go +++ b/hostmap.go @@ -223,7 +223,7 @@ type HostInfo struct { recvError atomic.Uint32 // networks are both all vpn and unsafe networks assigned to this host - networks *bart.Table[struct{}] + networks *bart.Lite relayState RelayState // HandshakePacket records the packets used to create this hostinfo @@ -732,13 +732,13 @@ func (i *HostInfo) buildNetworks(networks, unsafeNetworks []netip.Prefix) { return } - i.networks = new(bart.Table[struct{}]) + i.networks = new(bart.Lite) for _, network := range networks { - i.networks.Insert(network, struct{}{}) + i.networks.Insert(network) } for _, network := range unsafeNetworks { - i.networks.Insert(network, struct{}{}) + i.networks.Insert(network) } } diff --git a/inside.go b/inside.go index 0af350d..239ea6a 100644 --- a/inside.go +++ b/inside.go @@ -22,14 +22,12 @@ func (f *Interface) consumeInsidePacket(packet []byte, fwPacket *firewall.Packet // Ignore local broadcast packets if f.dropLocalBroadcast { - _, found := f.myBroadcastAddrsTable.Lookup(fwPacket.RemoteAddr) - if found { + if f.myBroadcastAddrsTable.Contains(fwPacket.RemoteAddr) { return } } - _, found := f.myVpnAddrsTable.Lookup(fwPacket.RemoteAddr) - if found { + if f.myVpnAddrsTable.Contains(fwPacket.RemoteAddr) { // Immediately forward packets from self to self. // This should only happen on Darwin-based and FreeBSD hosts, which // routes packets from the Nebula addr to the Nebula addr through the Nebula @@ -130,8 +128,7 @@ func (f *Interface) Handshake(vpnAddr netip.Addr) { // getOrHandshakeNoRouting returns nil if the vpnAddr is not routable. // If the 2nd return var is false then the hostinfo is not ready to be used in a tunnel func (f *Interface) getOrHandshakeNoRouting(vpnAddr netip.Addr, cacheCallback func(*HandshakeHostInfo)) (*HostInfo, bool) { - _, found := f.myVpnNetworksTable.Lookup(vpnAddr) - if found { + if f.myVpnNetworksTable.Contains(vpnAddr) { return f.handshakeManager.GetOrHandshake(vpnAddr, cacheCallback) } diff --git a/interface.go b/interface.go index a15e2c2..ddd0681 100644 --- a/interface.go +++ b/interface.go @@ -61,11 +61,11 @@ type Interface struct { serveDns bool createTime time.Time lightHouse *LightHouse - myBroadcastAddrsTable *bart.Table[struct{}] - myVpnAddrs []netip.Addr // A list of addresses assigned to us via our certificate - myVpnAddrsTable *bart.Table[struct{}] // A table of addresses assigned to us via our certificate - myVpnNetworks []netip.Prefix // A list of networks assigned to us via our certificate - myVpnNetworksTable *bart.Table[struct{}] // A table of networks assigned to us via our certificate + myBroadcastAddrsTable *bart.Lite + myVpnAddrs []netip.Addr // A list of addresses assigned to us via our certificate + myVpnAddrsTable *bart.Lite + myVpnNetworks []netip.Prefix // A list of networks assigned to us via our certificate + myVpnNetworksTable *bart.Lite dropLocalBroadcast bool dropMulticast bool routines int diff --git a/lighthouse.go b/lighthouse.go index eb09a39..7a679c7 100644 --- a/lighthouse.go +++ b/lighthouse.go @@ -32,7 +32,7 @@ type LightHouse struct { amLighthouse bool myVpnNetworks []netip.Prefix - myVpnNetworksTable *bart.Table[struct{}] + myVpnNetworksTable *bart.Lite punchConn udp.Conn punchy *Punchy @@ -201,8 +201,7 @@ func (lh *LightHouse) reload(c *config.C, initial bool) error { //TODO: we could technically insert all returned addrs instead of just the first one if a dns lookup was used addr := addrs[0].Unmap() - _, found := lh.myVpnNetworksTable.Lookup(addr) - if found { + if lh.myVpnNetworksTable.Contains(addr) { lh.l.WithField("addr", rawAddr).WithField("entry", i+1). Warn("Ignoring lighthouse.advertise_addrs report because it is within the nebula network range") continue @@ -359,8 +358,7 @@ func (lh *LightHouse) parseLighthouses(c *config.C, lhMap map[netip.Addr]struct{ return util.NewContextualError("Unable to parse lighthouse host entry", m{"host": host, "entry": i + 1}, err) } - _, found := lh.myVpnNetworksTable.Lookup(addr) - if !found { + if !lh.myVpnNetworksTable.Contains(addr) { return util.NewContextualError("lighthouse host is not in our networks, invalid", m{"vpnAddr": addr, "networks": lh.myVpnNetworks}, nil) } lhMap[addr] = struct{}{} @@ -431,8 +429,7 @@ func (lh *LightHouse) loadStaticMap(c *config.C, staticList map[netip.Addr]struc return util.NewContextualError("Unable to parse static_host_map entry", m{"host": k, "entry": i + 1}, err) } - _, found := lh.myVpnNetworksTable.Lookup(vpnAddr) - if !found { + if !lh.myVpnNetworksTable.Contains(vpnAddr) { return util.NewContextualError("static_host_map key is not in our network, invalid", m{"vpnAddr": vpnAddr, "networks": lh.myVpnNetworks, "entry": i + 1}, nil) } @@ -653,8 +650,7 @@ func (lh *LightHouse) shouldAdd(vpnAddr netip.Addr, to netip.Addr) bool { return false } - _, found := lh.myVpnNetworksTable.Lookup(to) - if found { + if lh.myVpnNetworksTable.Contains(to) { return false } @@ -674,8 +670,7 @@ func (lh *LightHouse) unlockedShouldAddV4(vpnAddr netip.Addr, to *V4AddrPort) bo return false } - _, found := lh.myVpnNetworksTable.Lookup(udpAddr.Addr()) - if found { + if lh.myVpnNetworksTable.Contains(udpAddr.Addr()) { return false } @@ -695,8 +690,7 @@ func (lh *LightHouse) unlockedShouldAddV6(vpnAddr netip.Addr, to *V6AddrPort) bo return false } - _, found := lh.myVpnNetworksTable.Lookup(udpAddr.Addr()) - if found { + if lh.myVpnNetworksTable.Contains(udpAddr.Addr()) { return false } @@ -856,8 +850,7 @@ func (lh *LightHouse) SendUpdate() { lal := lh.GetLocalAllowList() for _, e := range localAddrs(lh.l, lal) { - _, found := lh.myVpnNetworksTable.Lookup(e) - if found { + if lh.myVpnNetworksTable.Contains(e) { continue } diff --git a/lighthouse_test.go b/lighthouse_test.go index c49615c..eb2d26e 100644 --- a/lighthouse_test.go +++ b/lighthouse_test.go @@ -31,8 +31,8 @@ func TestOldIPv4Only(t *testing.T) { func Test_lhStaticMapping(t *testing.T) { l := test.NewLogger() myVpnNet := netip.MustParsePrefix("10.128.0.1/16") - nt := new(bart.Table[struct{}]) - nt.Insert(myVpnNet, struct{}{}) + nt := new(bart.Lite) + nt.Insert(myVpnNet) cs := &CertState{ myVpnNetworks: []netip.Prefix{myVpnNet}, myVpnNetworksTable: nt, @@ -56,8 +56,8 @@ func Test_lhStaticMapping(t *testing.T) { func TestReloadLighthouseInterval(t *testing.T) { l := test.NewLogger() myVpnNet := netip.MustParsePrefix("10.128.0.1/16") - nt := new(bart.Table[struct{}]) - nt.Insert(myVpnNet, struct{}{}) + nt := new(bart.Lite) + nt.Insert(myVpnNet) cs := &CertState{ myVpnNetworks: []netip.Prefix{myVpnNet}, myVpnNetworksTable: nt, @@ -91,8 +91,8 @@ func TestReloadLighthouseInterval(t *testing.T) { func BenchmarkLighthouseHandleRequest(b *testing.B) { l := test.NewLogger() myVpnNet := netip.MustParsePrefix("10.128.0.1/0") - nt := new(bart.Table[struct{}]) - nt.Insert(myVpnNet, struct{}{}) + nt := new(bart.Lite) + nt.Insert(myVpnNet) cs := &CertState{ myVpnNetworks: []netip.Prefix{myVpnNet}, myVpnNetworksTable: nt, @@ -196,8 +196,8 @@ func TestLighthouse_Memory(t *testing.T) { c.Settings["listen"] = map[string]any{"port": 4242} myVpnNet := netip.MustParsePrefix("10.128.0.1/24") - nt := new(bart.Table[struct{}]) - nt.Insert(myVpnNet, struct{}{}) + nt := new(bart.Lite) + nt.Insert(myVpnNet) cs := &CertState{ myVpnNetworks: []netip.Prefix{myVpnNet}, myVpnNetworksTable: nt, @@ -281,8 +281,8 @@ func TestLighthouse_reload(t *testing.T) { c.Settings["listen"] = map[string]any{"port": 4242} myVpnNet := netip.MustParsePrefix("10.128.0.1/24") - nt := new(bart.Table[struct{}]) - nt.Insert(myVpnNet, struct{}{}) + nt := new(bart.Lite) + nt.Insert(myVpnNet) cs := &CertState{ myVpnNetworks: []netip.Prefix{myVpnNet}, myVpnNetworksTable: nt, diff --git a/outside.go b/outside.go index 1e9cde1..3a7b3a7 100644 --- a/outside.go +++ b/outside.go @@ -31,8 +31,7 @@ func (f *Interface) readOutsidePackets(ip netip.AddrPort, via *ViaSender, out [] //l.Error("in packet ", header, packet[HeaderLen:]) if ip.IsValid() { - _, found := f.myVpnNetworksTable.Lookup(ip.Addr()) - if found { + if f.myVpnNetworksTable.Contains(ip.Addr()) { if f.l.Level >= logrus.DebugLevel { f.l.WithField("udpAddr", ip).Debug("Refusing to process double encrypted packet") } diff --git a/pki.go b/pki.go index c9f8d89..9cab491 100644 --- a/pki.go +++ b/pki.go @@ -39,10 +39,10 @@ type CertState struct { cipher string myVpnNetworks []netip.Prefix - myVpnNetworksTable *bart.Table[struct{}] + myVpnNetworksTable *bart.Lite myVpnAddrs []netip.Addr - myVpnAddrsTable *bart.Table[struct{}] - myVpnBroadcastAddrsTable *bart.Table[struct{}] + myVpnAddrsTable *bart.Lite + myVpnBroadcastAddrsTable *bart.Lite } func NewPKIFromConfig(l *logrus.Logger, c *config.C) (*PKI, error) { @@ -345,9 +345,9 @@ func newCertState(dv cert.Version, v1, v2 cert.Certificate, pkcs11backed bool, p cs := CertState{ privateKey: privateKey, pkcs11Backed: pkcs11backed, - myVpnNetworksTable: new(bart.Table[struct{}]), - myVpnAddrsTable: new(bart.Table[struct{}]), - myVpnBroadcastAddrsTable: new(bart.Table[struct{}]), + myVpnNetworksTable: new(bart.Lite), + myVpnAddrsTable: new(bart.Lite), + myVpnBroadcastAddrsTable: new(bart.Lite), } if v1 != nil && v2 != nil { @@ -415,16 +415,16 @@ func newCertState(dv cert.Version, v1, v2 cert.Certificate, pkcs11backed bool, p for _, network := range crt.Networks() { cs.myVpnNetworks = append(cs.myVpnNetworks, network) - cs.myVpnNetworksTable.Insert(network, struct{}{}) + cs.myVpnNetworksTable.Insert(network) cs.myVpnAddrs = append(cs.myVpnAddrs, network.Addr()) - cs.myVpnAddrsTable.Insert(netip.PrefixFrom(network.Addr(), network.Addr().BitLen()), struct{}{}) + cs.myVpnAddrsTable.Insert(netip.PrefixFrom(network.Addr(), network.Addr().BitLen())) if network.Addr().Is4() { addr := network.Masked().Addr().As4() mask := net.CIDRMask(network.Bits(), network.Addr().BitLen()) binary.BigEndian.PutUint32(addr[:], binary.BigEndian.Uint32(addr[:])|^binary.BigEndian.Uint32(mask)) - cs.myVpnBroadcastAddrsTable.Insert(netip.PrefixFrom(netip.AddrFrom4(addr), network.Addr().BitLen()), struct{}{}) + cs.myVpnBroadcastAddrsTable.Insert(netip.PrefixFrom(netip.AddrFrom4(addr), network.Addr().BitLen())) } } diff --git a/relay_manager.go b/relay_manager.go index 7565350..5dd355c 100644 --- a/relay_manager.go +++ b/relay_manager.go @@ -241,15 +241,13 @@ func (rm *relayManager) handleCreateRelayRequest(v cert.Version, h *HostInfo, f logMsg.Info("handleCreateRelayRequest") // Is the source of the relay me? This should never happen, but did happen due to // an issue migrating relays over to newly re-handshaked host info objects. - _, found := f.myVpnAddrsTable.Lookup(from) - if found { + if f.myVpnAddrsTable.Contains(from) { logMsg.WithField("myIP", from).Error("Discarding relay request from myself") return } // Is the target of the relay me? - _, found = f.myVpnAddrsTable.Lookup(target) - if found { + if f.myVpnAddrsTable.Contains(target) { existingRelay, ok := h.relayState.QueryRelayForByIp(from) if ok { switch existingRelay.State {