From bd9cc01d624a6f1feb9c4096a8eb79794b12472c Mon Sep 17 00:00:00 2001 From: brad-defined <77982333+brad-defined@users.noreply.github.com> Date: Tue, 9 May 2023 11:22:08 -0400 Subject: [PATCH 1/6] Dns static lookerupper (#796) * Support lighthouse DNS names, and regularly resolve the name in a background goroutine to discover DNS updates. --- control_test.go | 2 +- handshake_manager_test.go | 2 +- lighthouse.go | 175 +++++++++++++++++++++++++++++++------- lighthouse_test.go | 15 ++-- main.go | 2 +- remote_list.go | 170 +++++++++++++++++++++++++++++++++++- remote_list_test.go | 6 +- 7 files changed, 324 insertions(+), 48 deletions(-) diff --git a/control_test.go b/control_test.go index ec469b4..de46991 100644 --- a/control_test.go +++ b/control_test.go @@ -47,7 +47,7 @@ func TestControl_GetHostInfoByVpnIp(t *testing.T) { Signature: []byte{1, 2, 1, 2, 1, 3}, } - remotes := NewRemoteList() + remotes := NewRemoteList(nil) remotes.unlockedPrependV4(0, NewIp4AndPort(remote1.IP, uint32(remote1.Port))) remotes.unlockedPrependV6(0, NewIp6AndPort(remote2.IP, uint32(remote2.Port))) hm.Add(iputil.Ip2VpnIp(ipNet.IP), &HostInfo{ diff --git a/handshake_manager_test.go b/handshake_manager_test.go index 5635c40..3e39e48 100644 --- a/handshake_manager_test.go +++ b/handshake_manager_test.go @@ -41,7 +41,7 @@ func Test_NewHandshakeManagerVpnIp(t *testing.T) { assert.False(t, initCalled) assert.Same(t, i, i2) - i.remotes = NewRemoteList() + i.remotes = NewRemoteList(nil) i.HandshakeReady = true // Adding something to pending should not affect the main hostmap diff --git a/lighthouse.go b/lighthouse.go index 2532fc4..460a1cb 100644 --- a/lighthouse.go +++ b/lighthouse.go @@ -6,6 +6,7 @@ import ( "errors" "fmt" "net" + "net/netip" "sync" "sync/atomic" "time" @@ -33,6 +34,7 @@ type netIpAndPort struct { type LightHouse struct { //TODO: We need a timer wheel to kick out vpnIps that haven't reported in a long time sync.RWMutex //Because we concurrently read and write to our maps + ctx context.Context amLighthouse bool myVpnIp iputil.VpnIp myVpnZeros iputil.VpnIp @@ -82,7 +84,7 @@ type LightHouse struct { // NewLightHouseFromConfig will build a Lighthouse struct from the values provided in the config object // addrMap should be nil unless this is during a config reload -func NewLightHouseFromConfig(l *logrus.Logger, c *config.C, myVpnNet *net.IPNet, pc *udp.Conn, p *Punchy) (*LightHouse, error) { +func NewLightHouseFromConfig(ctx context.Context, l *logrus.Logger, c *config.C, myVpnNet *net.IPNet, pc *udp.Conn, p *Punchy) (*LightHouse, error) { amLighthouse := c.GetBool("lighthouse.am_lighthouse", false) nebulaPort := uint32(c.GetInt("listen.port", 0)) if amLighthouse && nebulaPort == 0 { @@ -100,6 +102,7 @@ func NewLightHouseFromConfig(l *logrus.Logger, c *config.C, myVpnNet *net.IPNet, ones, _ := myVpnNet.Mask.Size() h := LightHouse{ + ctx: ctx, amLighthouse: amLighthouse, myVpnIp: iputil.Ip2VpnIp(myVpnNet.IP), myVpnZeros: iputil.VpnIp(32 - ones), @@ -258,7 +261,7 @@ func (lh *LightHouse) reload(c *config.C, initial bool) error { } //NOTE: many things will get much simpler when we combine static_host_map and lighthouse.hosts in config - if initial || c.HasChanged("static_host_map") { + if initial || c.HasChanged("static_host_map") || c.HasChanged("static_map.cadence") || c.HasChanged("static_map.network") || c.HasChanged("static_map.lookup_timeout") { staticList := make(map[iputil.VpnIp]struct{}) err := lh.loadStaticMap(c, lh.myVpnNet, staticList) if err != nil { @@ -268,9 +271,19 @@ func (lh *LightHouse) reload(c *config.C, initial bool) error { lh.staticList.Store(&staticList) if !initial { //TODO: we should remove any remote list entries for static hosts that were removed/modified? - lh.l.Info("static_host_map has changed") + if c.HasChanged("static_host_map") { + lh.l.Info("static_host_map has changed") + } + if c.HasChanged("static_map.cadence") { + lh.l.Info("static_map.cadence has changed") + } + if c.HasChanged("static_map.network") { + lh.l.Info("static_map.network has changed") + } + if c.HasChanged("static_map.lookup_timeout") { + lh.l.Info("static_map.lookup_timeout has changed") + } } - } if initial || c.HasChanged("lighthouse.hosts") { @@ -344,7 +357,48 @@ func (lh *LightHouse) parseLighthouses(c *config.C, tunCidr *net.IPNet, lhMap ma return nil } +func getStaticMapCadence(c *config.C) (time.Duration, error) { + cadence := c.GetString("static_map.cadence", "30s") + d, err := time.ParseDuration(cadence) + if err != nil { + return 0, err + } + return d, nil +} + +func getStaticMapLookupTimeout(c *config.C) (time.Duration, error) { + lookupTimeout := c.GetString("static_map.lookup_timeout", "250ms") + d, err := time.ParseDuration(lookupTimeout) + if err != nil { + return 0, err + } + return d, nil +} + +func getStaticMapNetwork(c *config.C) (string, error) { + network := c.GetString("static_map.network", "ip4") + if network != "ip" && network != "ip4" && network != "ip6" { + return "", fmt.Errorf("static_map.network must be one of ip, ip4, or ip6") + } + return network, nil +} + func (lh *LightHouse) loadStaticMap(c *config.C, tunCidr *net.IPNet, staticList map[iputil.VpnIp]struct{}) error { + d, err := getStaticMapCadence(c) + if err != nil { + return err + } + + network, err := getStaticMapNetwork(c) + if err != nil { + return err + } + + lookup_timeout, err := getStaticMapLookupTimeout(c) + if err != nil { + return err + } + shm := c.GetMap("static_host_map", map[interface{}]interface{}{}) i := 0 @@ -360,21 +414,17 @@ func (lh *LightHouse) loadStaticMap(c *config.C, tunCidr *net.IPNet, staticList vpnIp := iputil.Ip2VpnIp(rip) vals, ok := v.([]interface{}) - if ok { - for _, v := range vals { - ip, port, err := udp.ParseIPAndPort(fmt.Sprintf("%v", v)) - if err != nil { - return util.NewContextualError("Static host address could not be parsed", m{"vpnIp": vpnIp, "entry": i + 1}, err) - } - lh.addStaticRemote(vpnIp, udp.NewAddr(ip, port), staticList) - } + if !ok { + vals = []interface{}{v} + } + remoteAddrs := []string{} + for _, v := range vals { + remoteAddrs = append(remoteAddrs, fmt.Sprintf("%v", v)) + } - } else { - ip, port, err := udp.ParseIPAndPort(fmt.Sprintf("%v", v)) - if err != nil { - return util.NewContextualError("Static host address could not be parsed", m{"vpnIp": vpnIp, "entry": i + 1}, err) - } - lh.addStaticRemote(vpnIp, udp.NewAddr(ip, port), staticList) + err := lh.addStaticRemotes(i, d, network, lookup_timeout, vpnIp, remoteAddrs, staticList) + if err != nil { + return err } i++ } @@ -482,30 +532,47 @@ func (lh *LightHouse) DeleteVpnIp(vpnIp iputil.VpnIp) { // We are the owner because we don't want a lighthouse server to advertise for static hosts it was configured with // And we don't want a lighthouse query reply to interfere with our learned cache if we are a client // NOTE: this function should not interact with any hot path objects, like lh.staticList, the caller should handle it -func (lh *LightHouse) addStaticRemote(vpnIp iputil.VpnIp, toAddr *udp.Addr, staticList map[iputil.VpnIp]struct{}) { +func (lh *LightHouse) addStaticRemotes(i int, d time.Duration, network string, timeout time.Duration, vpnIp iputil.VpnIp, toAddrs []string, staticList map[iputil.VpnIp]struct{}) error { lh.Lock() am := lh.unlockedGetRemoteList(vpnIp) am.Lock() defer am.Unlock() + ctx := lh.ctx lh.Unlock() - if ipv4 := toAddr.IP.To4(); ipv4 != nil { - to := NewIp4AndPort(ipv4, uint32(toAddr.Port)) - if !lh.unlockedShouldAddV4(vpnIp, to) { - return - } - am.unlockedPrependV4(lh.myVpnIp, to) + hr, err := NewHostnameResults(ctx, lh.l, d, network, timeout, toAddrs, func() { + // This callback runs whenever the DNS hostname resolver finds a different set of IP's + // in its resolution for hostnames. + am.Lock() + defer am.Unlock() + am.shouldRebuild = true + }) + if err != nil { + return util.NewContextualError("Static host address could not be parsed", m{"vpnIp": vpnIp, "entry": i + 1}, err) + } + am.unlockedSetHostnamesResults(hr) - } else { - to := NewIp6AndPort(toAddr.IP, uint32(toAddr.Port)) - if !lh.unlockedShouldAddV6(vpnIp, to) { - return + for _, addrPort := range hr.GetIPs() { + + switch { + case addrPort.Addr().Is4(): + to := NewIp4AndPortFromNetIP(addrPort.Addr(), addrPort.Port()) + if !lh.unlockedShouldAddV4(vpnIp, to) { + continue + } + am.unlockedPrependV4(lh.myVpnIp, to) + case addrPort.Addr().Is6(): + to := NewIp6AndPortFromNetIP(addrPort.Addr(), addrPort.Port()) + if !lh.unlockedShouldAddV6(vpnIp, to) { + continue + } + am.unlockedPrependV6(lh.myVpnIp, to) } - am.unlockedPrependV6(lh.myVpnIp, to) } // Mark it as static in the caller provided map staticList[vpnIp] = struct{}{} + return nil } // addCalculatedRemotes adds any calculated remotes based on the @@ -545,12 +612,42 @@ func (lh *LightHouse) addCalculatedRemotes(vpnIp iputil.VpnIp) bool { func (lh *LightHouse) unlockedGetRemoteList(vpnIp iputil.VpnIp) *RemoteList { am, ok := lh.addrMap[vpnIp] if !ok { - am = NewRemoteList() + am = NewRemoteList(func(a netip.Addr) bool { return lh.shouldAdd(vpnIp, a) }) lh.addrMap[vpnIp] = am } return am } +func (lh *LightHouse) shouldAdd(vpnIp iputil.VpnIp, to netip.Addr) bool { + switch { + case to.Is4(): + ipBytes := to.As4() + ip := iputil.Ip2VpnIp(ipBytes[:]) + allow := lh.GetRemoteAllowList().AllowIpV4(vpnIp, ip) + if lh.l.Level >= logrus.TraceLevel { + lh.l.WithField("remoteIp", vpnIp).WithField("allow", allow).Trace("remoteAllowList.Allow") + } + if !allow || ipMaskContains(lh.myVpnIp, lh.myVpnZeros, ip) { + return false + } + case to.Is6(): + ipBytes := to.As16() + + hi := binary.BigEndian.Uint64(ipBytes[:8]) + lo := binary.BigEndian.Uint64(ipBytes[8:]) + allow := lh.GetRemoteAllowList().AllowIpV6(vpnIp, hi, lo) + if lh.l.Level >= logrus.TraceLevel { + lh.l.WithField("remoteIp", to).WithField("allow", allow).Trace("remoteAllowList.Allow") + } + + // We don't check our vpn network here because nebula does not support ipv6 on the inside + if !allow { + return false + } + } + return true +} + // unlockedShouldAddV4 checks if to is allowed by our allow list func (lh *LightHouse) unlockedShouldAddV4(vpnIp iputil.VpnIp, to *Ip4AndPort) bool { allow := lh.GetRemoteAllowList().AllowIpV4(vpnIp, iputil.VpnIp(to.Ip)) @@ -609,6 +706,14 @@ func NewIp4AndPort(ip net.IP, port uint32) *Ip4AndPort { return &ipp } +func NewIp4AndPortFromNetIP(ip netip.Addr, port uint16) *Ip4AndPort { + v4Addr := ip.As4() + return &Ip4AndPort{ + Ip: binary.BigEndian.Uint32(v4Addr[:]), + Port: uint32(port), + } +} + func NewIp6AndPort(ip net.IP, port uint32) *Ip6AndPort { return &Ip6AndPort{ Hi: binary.BigEndian.Uint64(ip[:8]), @@ -617,6 +722,14 @@ func NewIp6AndPort(ip net.IP, port uint32) *Ip6AndPort { } } +func NewIp6AndPortFromNetIP(ip netip.Addr, port uint16) *Ip6AndPort { + ip6Addr := ip.As16() + return &Ip6AndPort{ + Hi: binary.BigEndian.Uint64(ip6Addr[:8]), + Lo: binary.BigEndian.Uint64(ip6Addr[8:]), + Port: uint32(port), + } +} func NewUDPAddrFromLH4(ipp *Ip4AndPort) *udp.Addr { ip := ipp.Ip return udp.NewAddr( diff --git a/lighthouse_test.go b/lighthouse_test.go index 658c087..aa4da4c 100644 --- a/lighthouse_test.go +++ b/lighthouse_test.go @@ -1,6 +1,7 @@ package nebula import ( + "context" "fmt" "net" "testing" @@ -53,14 +54,14 @@ func Test_lhStaticMapping(t *testing.T) { c := config.NewC(l) c.Settings["lighthouse"] = map[interface{}]interface{}{"hosts": []interface{}{lh1}} c.Settings["static_host_map"] = map[interface{}]interface{}{lh1: []interface{}{"1.1.1.1:4242"}} - _, err := NewLightHouseFromConfig(l, c, myVpnNet, nil, nil) + _, err := NewLightHouseFromConfig(context.Background(), l, c, myVpnNet, nil, nil) assert.Nil(t, err) lh2 := "10.128.0.3" c = config.NewC(l) c.Settings["lighthouse"] = map[interface{}]interface{}{"hosts": []interface{}{lh1, lh2}} c.Settings["static_host_map"] = map[interface{}]interface{}{lh1: []interface{}{"100.1.1.1:4242"}} - _, err = NewLightHouseFromConfig(l, c, myVpnNet, nil, nil) + _, err = NewLightHouseFromConfig(context.Background(), l, c, myVpnNet, nil, nil) assert.EqualError(t, err, "lighthouse 10.128.0.3 does not have a static_host_map entry") } @@ -69,14 +70,14 @@ func BenchmarkLighthouseHandleRequest(b *testing.B) { _, myVpnNet, _ := net.ParseCIDR("10.128.0.1/0") c := config.NewC(l) - lh, err := NewLightHouseFromConfig(l, c, myVpnNet, nil, nil) + lh, err := NewLightHouseFromConfig(context.Background(), l, c, myVpnNet, nil, nil) if !assert.NoError(b, err) { b.Fatal() } hAddr := udp.NewAddrFromString("4.5.6.7:12345") hAddr2 := udp.NewAddrFromString("4.5.6.7:12346") - lh.addrMap[3] = NewRemoteList() + lh.addrMap[3] = NewRemoteList(nil) lh.addrMap[3].unlockedSetV4( 3, 3, @@ -89,7 +90,7 @@ func BenchmarkLighthouseHandleRequest(b *testing.B) { rAddr := udp.NewAddrFromString("1.2.2.3:12345") rAddr2 := udp.NewAddrFromString("1.2.2.3:12346") - lh.addrMap[2] = NewRemoteList() + lh.addrMap[2] = NewRemoteList(nil) lh.addrMap[2].unlockedSetV4( 3, 3, @@ -162,7 +163,7 @@ func TestLighthouse_Memory(t *testing.T) { c := config.NewC(l) c.Settings["lighthouse"] = map[interface{}]interface{}{"am_lighthouse": true} c.Settings["listen"] = map[interface{}]interface{}{"port": 4242} - lh, err := NewLightHouseFromConfig(l, c, &net.IPNet{IP: net.IP{10, 128, 0, 1}, Mask: net.IPMask{255, 255, 255, 0}}, nil, nil) + lh, err := NewLightHouseFromConfig(context.Background(), l, c, &net.IPNet{IP: net.IP{10, 128, 0, 1}, Mask: net.IPMask{255, 255, 255, 0}}, nil, nil) assert.NoError(t, err) lhh := lh.NewRequestHandler() @@ -238,7 +239,7 @@ func TestLighthouse_reload(t *testing.T) { c := config.NewC(l) c.Settings["lighthouse"] = map[interface{}]interface{}{"am_lighthouse": true} c.Settings["listen"] = map[interface{}]interface{}{"port": 4242} - lh, err := NewLightHouseFromConfig(l, c, &net.IPNet{IP: net.IP{10, 128, 0, 1}, Mask: net.IPMask{255, 255, 255, 0}}, nil, nil) + lh, err := NewLightHouseFromConfig(context.Background(), l, c, &net.IPNet{IP: net.IP{10, 128, 0, 1}, Mask: net.IPMask{255, 255, 255, 0}}, nil, nil) assert.NoError(t, err) c.Settings["static_host_map"] = map[interface{}]interface{}{"10.128.0.2": []interface{}{"1.1.1.1:4242"}} diff --git a/main.go b/main.go index bbf831a..4d604f5 100644 --- a/main.go +++ b/main.go @@ -226,7 +226,7 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg */ punchy := NewPunchyFromConfig(l, c) - lightHouse, err := NewLightHouseFromConfig(l, c, tunCidr, udpConns[0], punchy) + lightHouse, err := NewLightHouseFromConfig(ctx, l, c, tunCidr, udpConns[0], punchy) switch { case errors.As(err, &util.ContextualError{}): return nil, err diff --git a/remote_list.go b/remote_list.go index 4b544f6..4540714 100644 --- a/remote_list.go +++ b/remote_list.go @@ -2,10 +2,16 @@ package nebula import ( "bytes" + "context" "net" + "net/netip" "sort" + "strconv" "sync" + "sync/atomic" + "time" + "github.com/sirupsen/logrus" "github.com/slackhq/nebula/iputil" "github.com/slackhq/nebula/udp" ) @@ -55,6 +61,132 @@ type cacheV6 struct { reported []*Ip6AndPort } +type hostnamePort struct { + name string + port uint16 +} + +type hostnamesResults struct { + hostnames []hostnamePort + network string + lookupTimeout time.Duration + stop chan struct{} + l *logrus.Logger + ips atomic.Pointer[map[netip.AddrPort]struct{}] +} + +func NewHostnameResults(ctx context.Context, l *logrus.Logger, d time.Duration, network string, timeout time.Duration, hostPorts []string, onUpdate func()) (*hostnamesResults, error) { + r := &hostnamesResults{ + hostnames: make([]hostnamePort, len(hostPorts)), + network: network, + lookupTimeout: timeout, + stop: make(chan (struct{})), + l: l, + } + + // Fastrack IP addresses to ensure they're immediately available for use. + // DNS lookups for hostnames that aren't hardcoded IP's will happen in a background goroutine. + performBackgroundLookup := false + ips := map[netip.AddrPort]struct{}{} + for idx, hostPort := range hostPorts { + + rIp, sPort, err := net.SplitHostPort(hostPort) + if err != nil { + return nil, err + } + + iPort, err := strconv.Atoi(sPort) + if err != nil { + return nil, err + } + + r.hostnames[idx] = hostnamePort{name: rIp, port: uint16(iPort)} + addr, err := netip.ParseAddr(rIp) + if err != nil { + // This address is a hostname, not an IP address + performBackgroundLookup = true + continue + } + + // Save the IP address immediately + ips[netip.AddrPortFrom(addr, uint16(iPort))] = struct{}{} + } + r.ips.Store(&ips) + + // Time for the DNS lookup goroutine + if performBackgroundLookup { + ticker := time.NewTicker(d) + go func() { + defer ticker.Stop() + for { + netipAddrs := map[netip.AddrPort]struct{}{} + for _, hostPort := range r.hostnames { + timeoutCtx, timeoutCancel := context.WithTimeout(ctx, r.lookupTimeout) + addrs, err := net.DefaultResolver.LookupNetIP(timeoutCtx, r.network, hostPort.name) + timeoutCancel() + if err != nil { + l.WithFields(logrus.Fields{"hostname": hostPort.name, "network": r.network}).WithError(err).Error("DNS resolution failed for static_map host") + continue + } + for _, a := range addrs { + netipAddrs[netip.AddrPortFrom(a, hostPort.port)] = struct{}{} + } + } + origSet := r.ips.Load() + different := false + for a := range *origSet { + if _, ok := netipAddrs[a]; !ok { + different = true + break + } + } + if !different { + for a := range netipAddrs { + if _, ok := (*origSet)[a]; !ok { + different = true + break + } + } + } + if different { + l.WithFields(logrus.Fields{"origSet": origSet, "newSet": netipAddrs}).Info("DNS results changed for host list") + r.ips.Store(&netipAddrs) + onUpdate() + } + select { + case <-ctx.Done(): + return + case <-r.stop: + return + case <-ticker.C: + continue + } + } + }() + } + + return r, nil +} + +func (hr *hostnamesResults) Cancel() { + if hr != nil { + hr.stop <- struct{}{} + } +} + +func (hr *hostnamesResults) GetIPs() []netip.AddrPort { + var retSlice []netip.AddrPort + if hr != nil { + p := hr.ips.Load() + if p != nil { + for k := range *p { + retSlice = append(retSlice, k) + } + } + } + return retSlice +} + // RemoteList is a unifying concept for lighthouse servers and clients as well as hostinfos. // It serves as a local cache of query replies, host update notifications, and locally learned addresses type RemoteList struct { @@ -72,6 +204,9 @@ type RemoteList struct { // For learned addresses, this is the vpnIp that sent the packet cache map[iputil.VpnIp]*cache + hr *hostnamesResults + shouldAdd func(netip.Addr) bool + // This is a list of remotes that we have tried to handshake with and have returned from the wrong vpn ip. // They should not be tried again during a handshake badRemotes []*udp.Addr @@ -81,14 +216,21 @@ type RemoteList struct { } // NewRemoteList creates a new empty RemoteList -func NewRemoteList() *RemoteList { +func NewRemoteList(shouldAdd func(netip.Addr) bool) *RemoteList { return &RemoteList{ - addrs: make([]*udp.Addr, 0), - relays: make([]*iputil.VpnIp, 0), - cache: make(map[iputil.VpnIp]*cache), + addrs: make([]*udp.Addr, 0), + relays: make([]*iputil.VpnIp, 0), + cache: make(map[iputil.VpnIp]*cache), + shouldAdd: shouldAdd, } } +func (r *RemoteList) unlockedSetHostnamesResults(hr *hostnamesResults) { + // Cancel any existing hostnamesResults DNS goroutine to release resources + r.hr.Cancel() + r.hr = hr +} + // Len locks and reports the size of the deduplicated address list // The deduplication work may need to occur here, so you must pass preferredRanges func (r *RemoteList) Len(preferredRanges []*net.IPNet) int { @@ -437,6 +579,26 @@ func (r *RemoteList) unlockedCollect() { } } + dnsAddrs := r.hr.GetIPs() + for _, addr := range dnsAddrs { + if r.shouldAdd == nil || r.shouldAdd(addr.Addr()) { + switch { + case addr.Addr().Is4(): + v4 := addr.Addr().As4() + addrs = append(addrs, &udp.Addr{ + IP: v4[:], + Port: addr.Port(), + }) + case addr.Addr().Is6(): + v6 := addr.Addr().As16() + addrs = append(addrs, &udp.Addr{ + IP: v6[:], + Port: addr.Port(), + }) + } + } + } + r.addrs = addrs r.relays = relays diff --git a/remote_list_test.go b/remote_list_test.go index 2170930..49aa171 100644 --- a/remote_list_test.go +++ b/remote_list_test.go @@ -9,7 +9,7 @@ import ( ) func TestRemoteList_Rebuild(t *testing.T) { - rl := NewRemoteList() + rl := NewRemoteList(nil) rl.unlockedSetV4( 0, 0, @@ -102,7 +102,7 @@ func TestRemoteList_Rebuild(t *testing.T) { } func BenchmarkFullRebuild(b *testing.B) { - rl := NewRemoteList() + rl := NewRemoteList(nil) rl.unlockedSetV4( 0, 0, @@ -167,7 +167,7 @@ func BenchmarkFullRebuild(b *testing.B) { } func BenchmarkSortRebuild(b *testing.B) { - rl := NewRemoteList() + rl := NewRemoteList(nil) rl.unlockedSetV4( 0, 0, From 0707caedb413c068dcc5e8029b08faa38c990a15 Mon Sep 17 00:00:00 2001 From: Wade Simmons Date: Tue, 9 May 2023 11:24:52 -0400 Subject: [PATCH 2/6] document P256 and BoringCrypto (#865) * document P256 and BoringCrypto Some basic descriptions of P256 and BoringCrypto added to the bottom of README.md so that their prupose is not a mystery. * typo --- README.md | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/README.md b/README.md index ba4e997..925aa61 100644 --- a/README.md +++ b/README.md @@ -118,6 +118,17 @@ To build nebula for a specific platform (ex, Windows): See the [Makefile](Makefile) for more details on build targets +## Curve P256 and BoringCrypto + +The default curve used for cryptographic handshakes and signatures is Curve25519. This is the recommended setting for most users. If your deployment has certain compliance requirements, you have the option of creating your CA using `nebula-cert ca -curve P256` to use NIST Curve P256. The CA will then sign certificates using ECDSA P256, and any hosts using these certificates will use P256 for ECDH handshakes. + +In addition, Nebula can be built using the [BoringCrypto GOEXPERIMENT](https://github.com/golang/go/blob/go1.20/src/crypto/internal/boring/README.md) by running either of the following make targets: + + make bin-boringcrypto + make release-boringcrypto + +This is not the recommended default deployment, but may be useful based on your compliance requirements. + ## Credits Nebula was created at Slack Technologies, Inc by Nate Brown and Ryan Huber, with contributions from Oliver Fross, Alan Lam, Wade Simmons, and Lining Wang. From 115b4b70b1c3bf07606b9939b0c92b4264ab2cf4 Mon Sep 17 00:00:00 2001 From: Wade Simmons Date: Tue, 9 May 2023 11:25:21 -0400 Subject: [PATCH 3/6] add SECURITY.md (#864) * add SECURITY.md Fixes: #699 * add Security mention to New issue template * cleanup --- .github/ISSUE_TEMPLATE/config.yml | 4 ++++ SECURITY.md | 12 ++++++++++++ 2 files changed, 16 insertions(+) create mode 100644 SECURITY.md diff --git a/.github/ISSUE_TEMPLATE/config.yml b/.github/ISSUE_TEMPLATE/config.yml index 94e2c6b..9c675ca 100644 --- a/.github/ISSUE_TEMPLATE/config.yml +++ b/.github/ISSUE_TEMPLATE/config.yml @@ -11,3 +11,7 @@ contact_links: - name: 📱 Mobile Nebula url: https://github.com/definednet/mobile_nebula about: 'This issue tracker is not for mobile support. Try the Mobile Nebula repo instead!' + + - name: 🔒 Report Security Vulnerability + url: https://github.com/slackhq/nebula/blob/master/SECURITY.md + about: 'Please view SECURITY.md to learn how to report security vulnerabilities.' diff --git a/SECURITY.md b/SECURITY.md new file mode 100644 index 0000000..bfff621 --- /dev/null +++ b/SECURITY.md @@ -0,0 +1,12 @@ +Security Policy +=============== + +Reporting a Vulnerability +------------------------- + +If you believe you have found a security vulnerability with Nebula, please let +us know right away. We will investigate all reports and do our best to quickly +fix valid issues. + +You can submit your report on [HackerOne](https://hackerone.com/slack) and our +security team will respond as soon as possible. From a9cb2e06f40dd5212f1f83efc239948f97797e98 Mon Sep 17 00:00:00 2001 From: Nate Brown Date: Tue, 9 May 2023 10:36:55 -0500 Subject: [PATCH 4/6] Add ability to respect the system route table for unsafe route on linux (#839) --- cidr/tree4.go | 28 ++++++++- cidr/tree4_test.go | 14 +++++ examples/config.yml | 4 ++ overlay/tun.go | 2 + overlay/tun_android.go | 4 +- overlay/tun_darwin.go | 4 +- overlay/tun_freebsd.go | 4 +- overlay/tun_ios.go | 4 +- overlay/tun_linux.go | 124 +++++++++++++++++++++++++++++++++----- overlay/tun_linux_test.go | 14 ++--- overlay/tun_tester.go | 4 +- overlay/tun_windows.go | 4 +- 12 files changed, 173 insertions(+), 37 deletions(-) diff --git a/cidr/tree4.go b/cidr/tree4.go index 28d0e78..0839c90 100644 --- a/cidr/tree4.go +++ b/cidr/tree4.go @@ -13,8 +13,14 @@ type Node struct { value interface{} } +type entry struct { + CIDR *net.IPNet + Value *interface{} +} + type Tree4 struct { root *Node + list []entry } const ( @@ -24,6 +30,7 @@ const ( func NewTree4() *Tree4 { tree := new(Tree4) tree.root = &Node{} + tree.list = []entry{} return tree } @@ -53,6 +60,15 @@ func (tree *Tree4) AddCIDR(cidr *net.IPNet, val interface{}) { // We already have this range so update the value if next != nil { + addCIDR := cidr.String() + for i, v := range tree.list { + if addCIDR == v.CIDR.String() { + tree.list = append(tree.list[:i], tree.list[i+1:]...) + break + } + } + + tree.list = append(tree.list, entry{CIDR: cidr, Value: &val}) node.value = val return } @@ -74,9 +90,10 @@ func (tree *Tree4) AddCIDR(cidr *net.IPNet, val interface{}) { // Final node marks our cidr, set the value node.value = val + tree.list = append(tree.list, entry{CIDR: cidr, Value: &val}) } -// Finds the first match, which may be the least specific +// Contains finds the first match, which may be the least specific func (tree *Tree4) Contains(ip iputil.VpnIp) (value interface{}) { bit := startbit node := tree.root @@ -99,7 +116,7 @@ func (tree *Tree4) Contains(ip iputil.VpnIp) (value interface{}) { return value } -// Finds the most specific match +// MostSpecificContains finds the most specific match func (tree *Tree4) MostSpecificContains(ip iputil.VpnIp) (value interface{}) { bit := startbit node := tree.root @@ -121,7 +138,7 @@ func (tree *Tree4) MostSpecificContains(ip iputil.VpnIp) (value interface{}) { return value } -// Finds the most specific match +// Match finds the most specific match func (tree *Tree4) Match(ip iputil.VpnIp) (value interface{}) { bit := startbit node := tree.root @@ -143,3 +160,8 @@ func (tree *Tree4) Match(ip iputil.VpnIp) (value interface{}) { } return value } + +// List will return all CIDRs and their current values. Do not modify the contents! +func (tree *Tree4) List() []entry { + return tree.list +} diff --git a/cidr/tree4_test.go b/cidr/tree4_test.go index 07f2b0a..dce8d54 100644 --- a/cidr/tree4_test.go +++ b/cidr/tree4_test.go @@ -8,6 +8,20 @@ import ( "github.com/stretchr/testify/assert" ) +func TestCIDRTree_List(t *testing.T) { + tree := NewTree4() + tree.AddCIDR(Parse("1.0.0.0/16"), "1") + tree.AddCIDR(Parse("1.0.0.0/8"), "2") + tree.AddCIDR(Parse("1.0.0.0/16"), "3") + tree.AddCIDR(Parse("1.0.0.0/16"), "4") + list := tree.List() + assert.Len(t, list, 2) + assert.Equal(t, "1.0.0.0/8", list[0].CIDR.String()) + assert.Equal(t, "2", *list[0].Value) + assert.Equal(t, "1.0.0.0/16", list[1].CIDR.String()) + assert.Equal(t, "4", *list[1].Value) +} + func TestCIDRTree_Contains(t *testing.T) { tree := NewTree4() tree.AddCIDR(Parse("1.0.0.0/8"), "1") diff --git a/examples/config.yml b/examples/config.yml index db5d0e3..9356b3a 100644 --- a/examples/config.yml +++ b/examples/config.yml @@ -223,6 +223,10 @@ tun: # metric: 100 # install: true + # On linux only, set to true to manage unsafe routes directly on the system route table with gateway routes instead of + # in nebula configuration files. Default false, not reloadable. + #use_system_route_table: false + # TODO # Configure logging level logging: diff --git a/overlay/tun.go b/overlay/tun.go index 3da50b8..5eccec9 100644 --- a/overlay/tun.go +++ b/overlay/tun.go @@ -35,6 +35,7 @@ func NewDeviceFromConfig(c *config.C, l *logrus.Logger, tunCidr *net.IPNet, fd * c.GetInt("tun.mtu", DefaultMTU), routes, c.GetInt("tun.tx_queue", 500), + c.GetBool("tun.use_system_route_table", false), ) default: @@ -46,6 +47,7 @@ func NewDeviceFromConfig(c *config.C, l *logrus.Logger, tunCidr *net.IPNet, fd * routes, c.GetInt("tun.tx_queue", 500), routines > 1, + c.GetBool("tun.use_system_route_table", false), ) } } diff --git a/overlay/tun_android.go b/overlay/tun_android.go index 321aec8..c731d78 100644 --- a/overlay/tun_android.go +++ b/overlay/tun_android.go @@ -22,7 +22,7 @@ type tun struct { l *logrus.Logger } -func newTunFromFd(l *logrus.Logger, deviceFd int, cidr *net.IPNet, _ int, routes []Route, _ int) (*tun, error) { +func newTunFromFd(l *logrus.Logger, deviceFd int, cidr *net.IPNet, _ int, routes []Route, _ int, _ bool) (*tun, error) { routeTree, err := makeRouteTree(l, routes, false) if err != nil { return nil, err @@ -41,7 +41,7 @@ func newTunFromFd(l *logrus.Logger, deviceFd int, cidr *net.IPNet, _ int, routes }, nil } -func newTun(_ *logrus.Logger, _ string, _ *net.IPNet, _ int, _ []Route, _ int, _ bool) (*tun, error) { +func newTun(_ *logrus.Logger, _ string, _ *net.IPNet, _ int, _ []Route, _ int, _ bool, _ bool) (*tun, error) { return nil, fmt.Errorf("newTun not supported in Android") } diff --git a/overlay/tun_darwin.go b/overlay/tun_darwin.go index 6320570..fd3429d 100644 --- a/overlay/tun_darwin.go +++ b/overlay/tun_darwin.go @@ -77,7 +77,7 @@ type ifreqMTU struct { pad [8]byte } -func newTun(l *logrus.Logger, name string, cidr *net.IPNet, defaultMTU int, routes []Route, _ int, _ bool) (*tun, error) { +func newTun(l *logrus.Logger, name string, cidr *net.IPNet, defaultMTU int, routes []Route, _ int, _ bool, _ bool) (*tun, error) { routeTree, err := makeRouteTree(l, routes, false) if err != nil { return nil, err @@ -170,7 +170,7 @@ func (t *tun) deviceBytes() (o [16]byte) { return } -func newTunFromFd(_ *logrus.Logger, _ int, _ *net.IPNet, _ int, _ []Route, _ int) (*tun, error) { +func newTunFromFd(_ *logrus.Logger, _ int, _ *net.IPNet, _ int, _ []Route, _ int, _ bool) (*tun, error) { return nil, fmt.Errorf("newTunFromFd not supported in Darwin") } diff --git a/overlay/tun_freebsd.go b/overlay/tun_freebsd.go index 1054228..99cbdb0 100644 --- a/overlay/tun_freebsd.go +++ b/overlay/tun_freebsd.go @@ -38,11 +38,11 @@ func (t *tun) Close() error { return nil } -func newTunFromFd(_ *logrus.Logger, _ int, _ *net.IPNet, _ int, _ []Route, _ int) (*tun, error) { +func newTunFromFd(_ *logrus.Logger, _ int, _ *net.IPNet, _ int, _ []Route, _ int, _ bool) (*tun, error) { return nil, fmt.Errorf("newTunFromFd not supported in FreeBSD") } -func newTun(l *logrus.Logger, deviceName string, cidr *net.IPNet, defaultMTU int, routes []Route, _ int, _ bool) (*tun, error) { +func newTun(l *logrus.Logger, deviceName string, cidr *net.IPNet, defaultMTU int, routes []Route, _ int, _ bool, _ bool) (*tun, error) { routeTree, err := makeRouteTree(l, routes, false) if err != nil { return nil, err diff --git a/overlay/tun_ios.go b/overlay/tun_ios.go index 59c190e..26f34ec 100644 --- a/overlay/tun_ios.go +++ b/overlay/tun_ios.go @@ -23,11 +23,11 @@ type tun struct { routeTree *cidr.Tree4 } -func newTun(_ *logrus.Logger, _ string, _ *net.IPNet, _ int, _ []Route, _ int, _ bool) (*tun, error) { +func newTun(_ *logrus.Logger, _ string, _ *net.IPNet, _ int, _ []Route, _ int, _ bool, _ bool) (*tun, error) { return nil, fmt.Errorf("newTun not supported in iOS") } -func newTunFromFd(l *logrus.Logger, deviceFd int, cidr *net.IPNet, _ int, routes []Route, _ int) (*tun, error) { +func newTunFromFd(l *logrus.Logger, deviceFd int, cidr *net.IPNet, _ int, routes []Route, _ int, _ bool) (*tun, error) { routeTree, err := makeRouteTree(l, routes, false) if err != nil { return nil, err diff --git a/overlay/tun_linux.go b/overlay/tun_linux.go index 932b585..7833186 100644 --- a/overlay/tun_linux.go +++ b/overlay/tun_linux.go @@ -4,11 +4,13 @@ package overlay import ( + "bytes" "fmt" "io" "net" "os" "strings" + "sync/atomic" "unsafe" "github.com/sirupsen/logrus" @@ -26,9 +28,13 @@ type tun struct { MaxMTU int DefaultMTU int TXQueueLen int - Routes []Route - routeTree *cidr.Tree4 - l *logrus.Logger + + Routes []Route + routeTree atomic.Pointer[cidr.Tree4] + routeChan chan struct{} + useSystemRoutes bool + + l *logrus.Logger } type ifReq struct { @@ -63,7 +69,7 @@ type ifreqQLEN struct { pad [8]byte } -func newTunFromFd(l *logrus.Logger, deviceFd int, cidr *net.IPNet, defaultMTU int, routes []Route, txQueueLen int) (*tun, error) { +func newTunFromFd(l *logrus.Logger, deviceFd int, cidr *net.IPNet, defaultMTU int, routes []Route, txQueueLen int, useSystemRoutes bool) (*tun, error) { routeTree, err := makeRouteTree(l, routes, true) if err != nil { return nil, err @@ -71,7 +77,7 @@ func newTunFromFd(l *logrus.Logger, deviceFd int, cidr *net.IPNet, defaultMTU in file := os.NewFile(uintptr(deviceFd), "/dev/net/tun") - return &tun{ + t := &tun{ ReadWriteCloser: file, fd: int(file.Fd()), Device: "tun0", @@ -79,12 +85,14 @@ func newTunFromFd(l *logrus.Logger, deviceFd int, cidr *net.IPNet, defaultMTU in DefaultMTU: defaultMTU, TXQueueLen: txQueueLen, Routes: routes, - routeTree: routeTree, + useSystemRoutes: useSystemRoutes, l: l, - }, nil + } + t.routeTree.Store(routeTree) + return t, nil } -func newTun(l *logrus.Logger, deviceName string, cidr *net.IPNet, defaultMTU int, routes []Route, txQueueLen int, multiqueue bool) (*tun, error) { +func newTun(l *logrus.Logger, deviceName string, cidr *net.IPNet, defaultMTU int, routes []Route, txQueueLen int, multiqueue bool, useSystemRoutes bool) (*tun, error) { fd, err := unix.Open("/dev/net/tun", os.O_RDWR, 0) if err != nil { return nil, err @@ -119,7 +127,7 @@ func newTun(l *logrus.Logger, deviceName string, cidr *net.IPNet, defaultMTU int return nil, err } - return &tun{ + t := &tun{ ReadWriteCloser: file, fd: int(file.Fd()), Device: name, @@ -128,9 +136,11 @@ func newTun(l *logrus.Logger, deviceName string, cidr *net.IPNet, defaultMTU int DefaultMTU: defaultMTU, TXQueueLen: txQueueLen, Routes: routes, - routeTree: routeTree, + useSystemRoutes: useSystemRoutes, l: l, - }, nil + } + t.routeTree.Store(routeTree) + return t, nil } func (t *tun) NewMultiQueueReader() (io.ReadWriteCloser, error) { @@ -152,7 +162,7 @@ func (t *tun) NewMultiQueueReader() (io.ReadWriteCloser, error) { } func (t *tun) RouteFor(ip iputil.VpnIp) iputil.VpnIp { - r := t.routeTree.MostSpecificContains(ip) + r := t.routeTree.Load().MostSpecificContains(ip) if r != nil { return r.(iputil.VpnIp) } @@ -183,16 +193,20 @@ func (t *tun) Write(b []byte) (int, error) { } } -func (t tun) deviceBytes() (o [16]byte) { +func (t *tun) deviceBytes() (o [16]byte) { for i, c := range t.Device { o[i] = byte(c) } return } -func (t tun) Activate() error { +func (t *tun) Activate() error { devName := t.deviceBytes() + if t.useSystemRoutes { + t.watchRoutes() + } + var addr, mask [4]byte copy(addr[:], t.cidr.IP.To4()) @@ -318,7 +332,7 @@ func (t *tun) Name() string { return t.Device } -func (t tun) advMSS(r Route) int { +func (t *tun) advMSS(r Route) int { mtu := r.MTU if r.MTU == 0 { mtu = t.DefaultMTU @@ -330,3 +344,83 @@ func (t tun) advMSS(r Route) int { } return 0 } + +func (t *tun) watchRoutes() { + rch := make(chan netlink.RouteUpdate) + doneChan := make(chan struct{}) + + if err := netlink.RouteSubscribe(rch, doneChan); err != nil { + t.l.WithError(err).Errorf("failed to subscribe to system route changes") + return + } + + t.routeChan = doneChan + + go func() { + for { + select { + case r := <-rch: + t.updateRoutes(r) + case <-doneChan: + // netlink.RouteSubscriber will close the rch for us + return + } + } + }() +} + +func (t *tun) updateRoutes(r netlink.RouteUpdate) { + if r.Gw == nil { + // Not a gateway route, ignore + t.l.WithField("route", r).Debug("Ignoring route update, not a gateway route") + return + } + + if !t.cidr.Contains(r.Gw) { + // Gateway isn't in our overlay network, ignore + t.l.WithField("route", r).Debug("Ignoring route update, not in our network") + return + } + + if x := r.Dst.IP.To4(); x == nil { + // Nebula only handles ipv4 on the overlay currently + t.l.WithField("route", r).Debug("Ignoring route update, destination is not ipv4") + return + } + + newTree := cidr.NewTree4() + if r.Type == unix.RTM_NEWROUTE { + for _, oldR := range t.routeTree.Load().List() { + newTree.AddCIDR(oldR.CIDR, oldR.Value) + } + + t.l.WithField("destination", r.Dst).WithField("via", r.Gw).Info("Adding route") + newTree.AddCIDR(r.Dst, iputil.Ip2VpnIp(r.Gw)) + + } else { + gw := iputil.Ip2VpnIp(r.Gw) + for _, oldR := range t.routeTree.Load().List() { + if bytes.Equal(oldR.CIDR.IP, r.Dst.IP) && bytes.Equal(oldR.CIDR.Mask, r.Dst.Mask) && *oldR.Value != nil && (*oldR.Value).(iputil.VpnIp) == gw { + // This is the record to delete + t.l.WithField("destination", r.Dst).WithField("via", r.Gw).Info("Removing route") + continue + } + + newTree.AddCIDR(oldR.CIDR, oldR.Value) + } + } + + t.routeTree.Store(newTree) +} + +func (t *tun) Close() error { + if t.routeChan != nil { + close(t.routeChan) + } + + if t.ReadWriteCloser != nil { + t.ReadWriteCloser.Close() + } + + return nil +} diff --git a/overlay/tun_linux_test.go b/overlay/tun_linux_test.go index 6c2043d..1c1842d 100644 --- a/overlay/tun_linux_test.go +++ b/overlay/tun_linux_test.go @@ -7,19 +7,19 @@ import "testing" var runAdvMSSTests = []struct { name string - tun tun + tun *tun r Route expected int }{ // Standard case, default MTU is the device max MTU - {"default", tun{DefaultMTU: 1440, MaxMTU: 1440}, Route{}, 0}, - {"default-min", tun{DefaultMTU: 1440, MaxMTU: 1440}, Route{MTU: 1440}, 0}, - {"default-low", tun{DefaultMTU: 1440, MaxMTU: 1440}, Route{MTU: 1200}, 1160}, + {"default", &tun{DefaultMTU: 1440, MaxMTU: 1440}, Route{}, 0}, + {"default-min", &tun{DefaultMTU: 1440, MaxMTU: 1440}, Route{MTU: 1440}, 0}, + {"default-low", &tun{DefaultMTU: 1440, MaxMTU: 1440}, Route{MTU: 1200}, 1160}, // Case where we have a route MTU set higher than the default - {"route", tun{DefaultMTU: 1440, MaxMTU: 8941}, Route{}, 1400}, - {"route-min", tun{DefaultMTU: 1440, MaxMTU: 8941}, Route{MTU: 1440}, 1400}, - {"route-high", tun{DefaultMTU: 1440, MaxMTU: 8941}, Route{MTU: 8941}, 0}, + {"route", &tun{DefaultMTU: 1440, MaxMTU: 8941}, Route{}, 1400}, + {"route-min", &tun{DefaultMTU: 1440, MaxMTU: 8941}, Route{MTU: 1440}, 1400}, + {"route-high", &tun{DefaultMTU: 1440, MaxMTU: 8941}, Route{MTU: 8941}, 0}, } func TestTunAdvMSS(t *testing.T) { diff --git a/overlay/tun_tester.go b/overlay/tun_tester.go index 38c11a6..3a49dcb 100644 --- a/overlay/tun_tester.go +++ b/overlay/tun_tester.go @@ -25,7 +25,7 @@ type TestTun struct { TxPackets chan []byte // Packets transmitted outside by nebula } -func newTun(l *logrus.Logger, deviceName string, cidr *net.IPNet, _ int, routes []Route, _ int, _ bool) (*TestTun, error) { +func newTun(l *logrus.Logger, deviceName string, cidr *net.IPNet, _ int, routes []Route, _ int, _ bool, _ bool) (*TestTun, error) { routeTree, err := makeRouteTree(l, routes, false) if err != nil { return nil, err @@ -42,7 +42,7 @@ func newTun(l *logrus.Logger, deviceName string, cidr *net.IPNet, _ int, routes }, nil } -func newTunFromFd(_ *logrus.Logger, _ int, _ *net.IPNet, _ int, _ []Route, _ int) (*TestTun, error) { +func newTunFromFd(_ *logrus.Logger, _ int, _ *net.IPNet, _ int, _ []Route, _ int, _ bool) (*TestTun, error) { return nil, fmt.Errorf("newTunFromFd not supported") } diff --git a/overlay/tun_windows.go b/overlay/tun_windows.go index e35e98b..57d90cb 100644 --- a/overlay/tun_windows.go +++ b/overlay/tun_windows.go @@ -14,11 +14,11 @@ import ( "github.com/sirupsen/logrus" ) -func newTunFromFd(_ *logrus.Logger, _ int, _ *net.IPNet, _ int, _ []Route, _ int) (Device, error) { +func newTunFromFd(_ *logrus.Logger, _ int, _ *net.IPNet, _ int, _ []Route, _ int, _ bool) (Device, error) { return nil, fmt.Errorf("newTunFromFd not supported in Windows") } -func newTun(l *logrus.Logger, deviceName string, cidr *net.IPNet, defaultMTU int, routes []Route, _ int, _ bool) (Device, error) { +func newTun(l *logrus.Logger, deviceName string, cidr *net.IPNet, defaultMTU int, routes []Route, _ int, _ bool, _ bool) (Device, error) { useWintun := true if err := checkWinTunExists(); err != nil { l.WithError(err).Warn("Check Wintun driver failed, fallback to wintap driver") From 1701087035b142676d3fde5a1b40607c503c68da Mon Sep 17 00:00:00 2001 From: Ilya Lukyanov <73885545+jilyaluk@users.noreply.github.com> Date: Tue, 9 May 2023 16:37:23 +0100 Subject: [PATCH 5/6] Add destination CIDR checking (#507) --- examples/config.yml | 3 +- firewall.go | 101 +++++++++++++++++++----------- firewall_test.go | 147 ++++++++++++++++++++++++++++++-------------- 3 files changed, 169 insertions(+), 82 deletions(-) diff --git a/examples/config.yml b/examples/config.yml index 9356b3a..ad49a3c 100644 --- a/examples/config.yml +++ b/examples/config.yml @@ -305,7 +305,8 @@ firewall: # host: `any` or a literal hostname, ie `test-host` # group: `any` or a literal group name, ie `default-group` # groups: Same as group but accepts a list of values. Multiple values are AND'd together and a certificate would have to contain all groups to pass - # cidr: a CIDR, `0.0.0.0/0` is any. + # cidr: a remote CIDR, `0.0.0.0/0` is any. + # local_cidr: a local CIDR, `0.0.0.0/0` is any. This could be used to filter destinations when using unsafe_routes. # ca_name: An issuing CA name # ca_sha: An issuing CA shasum diff --git a/firewall.go b/firewall.go index 061d9e6..93d940d 100644 --- a/firewall.go +++ b/firewall.go @@ -25,7 +25,7 @@ const tcpACK = 0x10 const tcpFIN = 0x01 type FirewallInterface interface { - AddRule(incoming bool, proto uint8, startPort int32, endPort int32, groups []string, host string, ip *net.IPNet, caName string, caSha string) error + AddRule(incoming bool, proto uint8, startPort int32, endPort int32, groups []string, host string, ip *net.IPNet, localIp *net.IPNet, caName string, caSha string) error } type conn struct { @@ -106,11 +106,12 @@ type FirewallCA struct { } type FirewallRule struct { - // Any makes Hosts, Groups, and CIDR irrelevant - Any bool - Hosts map[string]struct{} - Groups [][]string - CIDR *cidr.Tree4 + // Any makes Hosts, Groups, CIDR and LocalCIDR irrelevant + Any bool + Hosts map[string]struct{} + Groups [][]string + CIDR *cidr.Tree4 + LocalCIDR *cidr.Tree4 } // Even though ports are uint16, int32 maps are faster for lookup @@ -218,18 +219,22 @@ func NewFirewallFromConfig(l *logrus.Logger, nc *cert.NebulaCertificate, c *conf } // AddRule properly creates the in memory rule structure for a firewall table. -func (f *Firewall) AddRule(incoming bool, proto uint8, startPort int32, endPort int32, groups []string, host string, ip *net.IPNet, caName string, caSha string) error { +func (f *Firewall) AddRule(incoming bool, proto uint8, startPort int32, endPort int32, groups []string, host string, ip *net.IPNet, localIp *net.IPNet, caName string, caSha string) error { // Under gomobile, stringing a nil pointer with fmt causes an abort in debug mode for iOS // https://github.com/golang/go/issues/14131 sIp := "" if ip != nil { sIp = ip.String() } + lIp := "" + if localIp != nil { + lIp = localIp.String() + } // We need this rule string because we generate a hash. Removing this will break firewall reload. ruleString := fmt.Sprintf( - "incoming: %v, proto: %v, startPort: %v, endPort: %v, groups: %v, host: %v, ip: %v, caName: %v, caSha: %s", - incoming, proto, startPort, endPort, groups, host, sIp, caName, caSha, + "incoming: %v, proto: %v, startPort: %v, endPort: %v, groups: %v, host: %v, ip: %v, localIp: %v, caName: %v, caSha: %s", + incoming, proto, startPort, endPort, groups, host, sIp, lIp, caName, caSha, ) f.rules += ruleString + "\n" @@ -237,7 +242,7 @@ func (f *Firewall) AddRule(incoming bool, proto uint8, startPort int32, endPort if !incoming { direction = "outgoing" } - f.l.WithField("firewallRule", m{"direction": direction, "proto": proto, "startPort": startPort, "endPort": endPort, "groups": groups, "host": host, "ip": sIp, "caName": caName, "caSha": caSha}). + f.l.WithField("firewallRule", m{"direction": direction, "proto": proto, "startPort": startPort, "endPort": endPort, "groups": groups, "host": host, "ip": sIp, "localIp": lIp, "caName": caName, "caSha": caSha}). Info("Firewall rule added") var ( @@ -264,7 +269,7 @@ func (f *Firewall) AddRule(incoming bool, proto uint8, startPort int32, endPort return fmt.Errorf("unknown protocol %v", proto) } - return fp.addRule(startPort, endPort, groups, host, ip, caName, caSha) + return fp.addRule(startPort, endPort, groups, host, ip, localIp, caName, caSha) } // GetRuleHash returns a hash representation of all inbound and outbound rules @@ -302,8 +307,8 @@ func AddFirewallRulesFromConfig(l *logrus.Logger, inbound bool, c *config.C, fw return fmt.Errorf("%s rule #%v; only one of port or code should be provided", table, i) } - if r.Host == "" && len(r.Groups) == 0 && r.Group == "" && r.Cidr == "" && r.CAName == "" && r.CASha == "" { - return fmt.Errorf("%s rule #%v; at least one of host, group, cidr, ca_name, or ca_sha must be provided", table, i) + if r.Host == "" && len(r.Groups) == 0 && r.Group == "" && r.Cidr == "" && r.LocalCidr == "" && r.CAName == "" && r.CASha == "" { + return fmt.Errorf("%s rule #%v; at least one of host, group, cidr, local_cidr, ca_name, or ca_sha must be provided", table, i) } if len(r.Groups) > 0 { @@ -355,7 +360,15 @@ func AddFirewallRulesFromConfig(l *logrus.Logger, inbound bool, c *config.C, fw } } - err = fw.AddRule(inbound, proto, startPort, endPort, groups, r.Host, cidr, r.CAName, r.CASha) + var localCidr *net.IPNet + if r.LocalCidr != "" { + _, localCidr, err = net.ParseCIDR(r.LocalCidr) + if err != nil { + return fmt.Errorf("%s rule #%v; local_cidr did not parse; %s", table, i, err) + } + } + + err = fw.AddRule(inbound, proto, startPort, endPort, groups, r.Host, cidr, localCidr, r.CAName, r.CASha) if err != nil { return fmt.Errorf("%s rule #%v; `%s`", table, i, err) } @@ -595,7 +608,7 @@ func (ft *FirewallTable) match(p firewall.Packet, incoming bool, c *cert.NebulaC return false } -func (fp firewallPort) addRule(startPort int32, endPort int32, groups []string, host string, ip *net.IPNet, caName string, caSha string) error { +func (fp firewallPort) addRule(startPort int32, endPort int32, groups []string, host string, ip *net.IPNet, localIp *net.IPNet, caName string, caSha string) error { if startPort > endPort { return fmt.Errorf("start port was lower than end port") } @@ -608,7 +621,7 @@ func (fp firewallPort) addRule(startPort int32, endPort int32, groups []string, } } - if err := fp[i].addRule(groups, host, ip, caName, caSha); err != nil { + if err := fp[i].addRule(groups, host, ip, localIp, caName, caSha); err != nil { return err } } @@ -639,12 +652,13 @@ func (fp firewallPort) match(p firewall.Packet, incoming bool, c *cert.NebulaCer return fp[firewall.PortAny].match(p, c, caPool) } -func (fc *FirewallCA) addRule(groups []string, host string, ip *net.IPNet, caName, caSha string) error { +func (fc *FirewallCA) addRule(groups []string, host string, ip, localIp *net.IPNet, caName, caSha string) error { fr := func() *FirewallRule { return &FirewallRule{ - Hosts: make(map[string]struct{}), - Groups: make([][]string, 0), - CIDR: cidr.NewTree4(), + Hosts: make(map[string]struct{}), + Groups: make([][]string, 0), + CIDR: cidr.NewTree4(), + LocalCIDR: cidr.NewTree4(), } } @@ -653,14 +667,14 @@ func (fc *FirewallCA) addRule(groups []string, host string, ip *net.IPNet, caNam fc.Any = fr() } - return fc.Any.addRule(groups, host, ip) + return fc.Any.addRule(groups, host, ip, localIp) } if caSha != "" { if _, ok := fc.CAShas[caSha]; !ok { fc.CAShas[caSha] = fr() } - err := fc.CAShas[caSha].addRule(groups, host, ip) + err := fc.CAShas[caSha].addRule(groups, host, ip, localIp) if err != nil { return err } @@ -670,7 +684,7 @@ func (fc *FirewallCA) addRule(groups []string, host string, ip *net.IPNet, caNam if _, ok := fc.CANames[caName]; !ok { fc.CANames[caName] = fr() } - err := fc.CANames[caName].addRule(groups, host, ip) + err := fc.CANames[caName].addRule(groups, host, ip, localIp) if err != nil { return err } @@ -702,17 +716,18 @@ func (fc *FirewallCA) match(p firewall.Packet, c *cert.NebulaCertificate, caPool return fc.CANames[s.Details.Name].match(p, c) } -func (fr *FirewallRule) addRule(groups []string, host string, ip *net.IPNet) error { +func (fr *FirewallRule) addRule(groups []string, host string, ip *net.IPNet, localIp *net.IPNet) error { if fr.Any { return nil } - if fr.isAny(groups, host, ip) { + if fr.isAny(groups, host, ip, localIp) { fr.Any = true // If it's any we need to wipe out any pre-existing rules to save on memory fr.Groups = make([][]string, 0) fr.Hosts = make(map[string]struct{}) fr.CIDR = cidr.NewTree4() + fr.LocalCIDR = cidr.NewTree4() } else { if len(groups) > 0 { fr.Groups = append(fr.Groups, groups) @@ -725,13 +740,17 @@ func (fr *FirewallRule) addRule(groups []string, host string, ip *net.IPNet) err if ip != nil { fr.CIDR.AddCIDR(ip, struct{}{}) } + + if localIp != nil { + fr.LocalCIDR.AddCIDR(localIp, struct{}{}) + } } return nil } -func (fr *FirewallRule) isAny(groups []string, host string, ip *net.IPNet) bool { - if len(groups) == 0 && host == "" && ip == nil { +func (fr *FirewallRule) isAny(groups []string, host string, ip, localIp *net.IPNet) bool { + if len(groups) == 0 && host == "" && ip == nil && localIp == nil { return true } @@ -749,6 +768,10 @@ func (fr *FirewallRule) isAny(groups []string, host string, ip *net.IPNet) bool return true } + if localIp != nil && localIp.Contains(net.IPv4(0, 0, 0, 0)) { + return true + } + return false } @@ -790,20 +813,25 @@ func (fr *FirewallRule) match(p firewall.Packet, c *cert.NebulaCertificate) bool return true } + if fr.LocalCIDR != nil && fr.LocalCIDR.Contains(p.LocalIP) != nil { + return true + } + // No host, group, or cidr matched, bye bye return false } type rule struct { - Port string - Code string - Proto string - Host string - Group string - Groups []string - Cidr string - CAName string - CASha string + Port string + Code string + Proto string + Host string + Group string + Groups []string + Cidr string + LocalCidr string + CAName string + CASha string } func convertRule(l *logrus.Logger, p interface{}, table string, i int) (rule, error) { @@ -827,6 +855,7 @@ func convertRule(l *logrus.Logger, p interface{}, table string, i int) (rule, er r.Proto = toString("proto", m) r.Host = toString("host", m) r.Cidr = toString("cidr", m) + r.LocalCidr = toString("local_cidr", m) r.CAName = toString("ca_name", m) r.CASha = toString("ca_sha", m) diff --git a/firewall_test.go b/firewall_test.go index d824192..fe3d2e0 100644 --- a/firewall_test.go +++ b/firewall_test.go @@ -69,67 +69,75 @@ func TestFirewall_AddRule(t *testing.T) { _, ti, _ := net.ParseCIDR("1.2.3.4/32") - assert.Nil(t, fw.AddRule(true, firewall.ProtoTCP, 1, 1, []string{}, "", nil, "", "")) + assert.Nil(t, fw.AddRule(true, firewall.ProtoTCP, 1, 1, []string{}, "", nil, nil, "", "")) // An empty rule is any assert.True(t, fw.InRules.TCP[1].Any.Any) assert.Empty(t, fw.InRules.TCP[1].Any.Groups) assert.Empty(t, fw.InRules.TCP[1].Any.Hosts) fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c) - assert.Nil(t, fw.AddRule(true, firewall.ProtoUDP, 1, 1, []string{"g1"}, "", nil, "", "")) + assert.Nil(t, fw.AddRule(true, firewall.ProtoUDP, 1, 1, []string{"g1"}, "", nil, nil, "", "")) assert.False(t, fw.InRules.UDP[1].Any.Any) assert.Contains(t, fw.InRules.UDP[1].Any.Groups[0], "g1") assert.Empty(t, fw.InRules.UDP[1].Any.Hosts) fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c) - assert.Nil(t, fw.AddRule(true, firewall.ProtoICMP, 1, 1, []string{}, "h1", nil, "", "")) + assert.Nil(t, fw.AddRule(true, firewall.ProtoICMP, 1, 1, []string{}, "h1", nil, nil, "", "")) assert.False(t, fw.InRules.ICMP[1].Any.Any) assert.Empty(t, fw.InRules.ICMP[1].Any.Groups) assert.Contains(t, fw.InRules.ICMP[1].Any.Hosts, "h1") fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c) - assert.Nil(t, fw.AddRule(false, firewall.ProtoAny, 1, 1, []string{}, "", ti, "", "")) + assert.Nil(t, fw.AddRule(false, firewall.ProtoAny, 1, 1, []string{}, "", ti, nil, "", "")) assert.False(t, fw.OutRules.AnyProto[1].Any.Any) assert.Empty(t, fw.OutRules.AnyProto[1].Any.Groups) assert.Empty(t, fw.OutRules.AnyProto[1].Any.Hosts) assert.NotNil(t, fw.OutRules.AnyProto[1].Any.CIDR.Match(iputil.Ip2VpnIp(ti.IP))) fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c) - assert.Nil(t, fw.AddRule(true, firewall.ProtoUDP, 1, 1, []string{"g1"}, "", nil, "ca-name", "")) + assert.Nil(t, fw.AddRule(false, firewall.ProtoAny, 1, 1, []string{}, "", nil, ti, "", "")) + assert.False(t, fw.OutRules.AnyProto[1].Any.Any) + assert.Empty(t, fw.OutRules.AnyProto[1].Any.Groups) + assert.Empty(t, fw.OutRules.AnyProto[1].Any.Hosts) + assert.NotNil(t, fw.OutRules.AnyProto[1].Any.LocalCIDR.Match(iputil.Ip2VpnIp(ti.IP))) + + fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c) + assert.Nil(t, fw.AddRule(true, firewall.ProtoUDP, 1, 1, []string{"g1"}, "", nil, nil, "ca-name", "")) assert.Contains(t, fw.InRules.UDP[1].CANames, "ca-name") fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c) - assert.Nil(t, fw.AddRule(true, firewall.ProtoUDP, 1, 1, []string{"g1"}, "", nil, "", "ca-sha")) + assert.Nil(t, fw.AddRule(true, firewall.ProtoUDP, 1, 1, []string{"g1"}, "", nil, nil, "", "ca-sha")) assert.Contains(t, fw.InRules.UDP[1].CAShas, "ca-sha") // Set any and clear fields fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c) - assert.Nil(t, fw.AddRule(false, firewall.ProtoAny, 0, 0, []string{"g1", "g2"}, "h1", ti, "", "")) + assert.Nil(t, fw.AddRule(false, firewall.ProtoAny, 0, 0, []string{"g1", "g2"}, "h1", ti, ti, "", "")) assert.Equal(t, []string{"g1", "g2"}, fw.OutRules.AnyProto[0].Any.Groups[0]) assert.Contains(t, fw.OutRules.AnyProto[0].Any.Hosts, "h1") assert.NotNil(t, fw.OutRules.AnyProto[0].Any.CIDR.Match(iputil.Ip2VpnIp(ti.IP))) + assert.NotNil(t, fw.OutRules.AnyProto[0].Any.LocalCIDR.Match(iputil.Ip2VpnIp(ti.IP))) // run twice just to make sure //TODO: these ANY rules should clear the CA firewall portion - assert.Nil(t, fw.AddRule(false, firewall.ProtoAny, 0, 0, []string{"any"}, "", nil, "", "")) - assert.Nil(t, fw.AddRule(false, firewall.ProtoAny, 0, 0, []string{}, "any", nil, "", "")) + assert.Nil(t, fw.AddRule(false, firewall.ProtoAny, 0, 0, []string{"any"}, "", nil, nil, "", "")) + assert.Nil(t, fw.AddRule(false, firewall.ProtoAny, 0, 0, []string{}, "any", nil, nil, "", "")) assert.True(t, fw.OutRules.AnyProto[0].Any.Any) assert.Empty(t, fw.OutRules.AnyProto[0].Any.Groups) assert.Empty(t, fw.OutRules.AnyProto[0].Any.Hosts) fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c) - assert.Nil(t, fw.AddRule(false, firewall.ProtoAny, 0, 0, []string{}, "any", nil, "", "")) + assert.Nil(t, fw.AddRule(false, firewall.ProtoAny, 0, 0, []string{}, "any", nil, nil, "", "")) assert.True(t, fw.OutRules.AnyProto[0].Any.Any) fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c) _, anyIp, _ := net.ParseCIDR("0.0.0.0/0") - assert.Nil(t, fw.AddRule(false, firewall.ProtoAny, 0, 0, []string{}, "", anyIp, "", "")) + assert.Nil(t, fw.AddRule(false, firewall.ProtoAny, 0, 0, []string{}, "", anyIp, nil, "", "")) assert.True(t, fw.OutRules.AnyProto[0].Any.Any) // Test error conditions fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c) - assert.Error(t, fw.AddRule(true, math.MaxUint8, 0, 0, []string{}, "", nil, "", "")) - assert.Error(t, fw.AddRule(true, firewall.ProtoAny, 10, 0, []string{}, "", nil, "", "")) + assert.Error(t, fw.AddRule(true, math.MaxUint8, 0, 0, []string{}, "", nil, nil, "", "")) + assert.Error(t, fw.AddRule(true, firewall.ProtoAny, 10, 0, []string{}, "", nil, nil, "", "")) } func TestFirewall_Drop(t *testing.T) { @@ -169,7 +177,7 @@ func TestFirewall_Drop(t *testing.T) { h.CreateRemoteCIDR(&c) fw := NewFirewall(l, time.Second, time.Minute, time.Hour, &c) - assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"any"}, "", nil, "", "")) + assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"any"}, "", nil, nil, "", "")) cp := cert.NewCAPool() // Drop outbound @@ -188,28 +196,28 @@ func TestFirewall_Drop(t *testing.T) { // ensure signer doesn't get in the way of group checks fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c) - assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", nil, "", "signer-shasum")) - assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", nil, "", "signer-shasum-bad")) + assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", nil, nil, "", "signer-shasum")) + assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", nil, nil, "", "signer-shasum-bad")) assert.Equal(t, fw.Drop([]byte{}, p, true, &h, cp, nil), ErrNoMatchingRule) // test caSha doesn't drop on match fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c) - assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", nil, "", "signer-shasum-bad")) - assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", nil, "", "signer-shasum")) + assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", nil, nil, "", "signer-shasum-bad")) + assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", nil, nil, "", "signer-shasum")) assert.NoError(t, fw.Drop([]byte{}, p, true, &h, cp, nil)) // ensure ca name doesn't get in the way of group checks cp.CAs["signer-shasum"] = &cert.NebulaCertificate{Details: cert.NebulaCertificateDetails{Name: "ca-good"}} fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c) - assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", nil, "ca-good", "")) - assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", nil, "ca-good-bad", "")) + assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", nil, nil, "ca-good", "")) + assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", nil, nil, "ca-good-bad", "")) assert.Equal(t, fw.Drop([]byte{}, p, true, &h, cp, nil), ErrNoMatchingRule) // test caName doesn't drop on match cp.CAs["signer-shasum"] = &cert.NebulaCertificate{Details: cert.NebulaCertificateDetails{Name: "ca-good"}} fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c) - assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", nil, "ca-good-bad", "")) - assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", nil, "ca-good", "")) + assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", nil, nil, "ca-good-bad", "")) + assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", nil, nil, "ca-good", "")) assert.NoError(t, fw.Drop([]byte{}, p, true, &h, cp, nil)) } @@ -219,11 +227,11 @@ func BenchmarkFirewallTable_match(b *testing.B) { } _, n, _ := net.ParseCIDR("172.1.1.1/32") - _ = ft.TCP.addRule(10, 10, []string{"good-group"}, "good-host", n, "", "") - _ = ft.TCP.addRule(10, 10, []string{"good-group2"}, "good-host", n, "", "") - _ = ft.TCP.addRule(10, 10, []string{"good-group3"}, "good-host", n, "", "") - _ = ft.TCP.addRule(10, 10, []string{"good-group4"}, "good-host", n, "", "") - _ = ft.TCP.addRule(10, 10, []string{"good-group, good-group1"}, "good-host", n, "", "") + _ = ft.TCP.addRule(10, 10, []string{"good-group"}, "good-host", n, n, "", "") + _ = ft.TCP.addRule(10, 10, []string{"good-group2"}, "good-host", n, n, "", "") + _ = ft.TCP.addRule(10, 10, []string{"good-group3"}, "good-host", n, n, "", "") + _ = ft.TCP.addRule(10, 10, []string{"good-group4"}, "good-host", n, n, "", "") + _ = ft.TCP.addRule(10, 10, []string{"good-group, good-group1"}, "good-host", n, n, "", "") cp := cert.NewCAPool() b.Run("fail on proto", func(b *testing.B) { @@ -291,7 +299,20 @@ func BenchmarkFirewallTable_match(b *testing.B) { } }) - _ = ft.TCP.addRule(0, 0, []string{"good-group"}, "good-host", n, "", "") + b.Run("pass on local ip", func(b *testing.B) { + ip := iputil.Ip2VpnIp(net.IPv4(172, 1, 1, 1)) + c := &cert.NebulaCertificate{ + Details: cert.NebulaCertificateDetails{ + InvertedGroups: map[string]struct{}{"nope": {}}, + Name: "good-host", + }, + } + for n := 0; n < b.N; n++ { + ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 10, LocalIP: ip}, true, c, cp) + } + }) + + _ = ft.TCP.addRule(0, 0, []string{"good-group"}, "good-host", n, n, "", "") b.Run("pass on ip with any port", func(b *testing.B) { ip := iputil.Ip2VpnIp(net.IPv4(172, 1, 1, 1)) @@ -305,6 +326,19 @@ func BenchmarkFirewallTable_match(b *testing.B) { ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 100, RemoteIP: ip}, true, c, cp) } }) + + b.Run("pass on local ip with any port", func(b *testing.B) { + ip := iputil.Ip2VpnIp(net.IPv4(172, 1, 1, 1)) + c := &cert.NebulaCertificate{ + Details: cert.NebulaCertificateDetails{ + InvertedGroups: map[string]struct{}{"nope": {}}, + Name: "good-host", + }, + } + for n := 0; n < b.N; n++ { + ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 100, LocalIP: ip}, true, c, cp) + } + }) } func TestFirewall_Drop2(t *testing.T) { @@ -356,7 +390,7 @@ func TestFirewall_Drop2(t *testing.T) { h1.CreateRemoteCIDR(&c1) fw := NewFirewall(l, time.Second, time.Minute, time.Hour, &c) - assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group", "test-group"}, "", nil, "", "")) + assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group", "test-group"}, "", nil, nil, "", "")) cp := cert.NewCAPool() // h1/c1 lacks the proper groups @@ -438,8 +472,8 @@ func TestFirewall_Drop3(t *testing.T) { h3.CreateRemoteCIDR(&c3) fw := NewFirewall(l, time.Second, time.Minute, time.Hour, &c) - assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 1, 1, []string{}, "host1", nil, "", "")) - assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 1, 1, []string{}, "", nil, "", "signer-sha")) + assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 1, 1, []string{}, "host1", nil, nil, "", "")) + assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 1, 1, []string{}, "", nil, nil, "", "signer-sha")) cp := cert.NewCAPool() // c1 should pass because host match @@ -489,7 +523,7 @@ func TestFirewall_DropConntrackReload(t *testing.T) { h.CreateRemoteCIDR(&c) fw := NewFirewall(l, time.Second, time.Minute, time.Hour, &c) - assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"any"}, "", nil, "", "")) + assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"any"}, "", nil, nil, "", "")) cp := cert.NewCAPool() // Drop outbound @@ -502,7 +536,7 @@ func TestFirewall_DropConntrackReload(t *testing.T) { oldFw := fw fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c) - assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 10, 10, []string{"any"}, "", nil, "", "")) + assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 10, 10, []string{"any"}, "", nil, nil, "", "")) fw.Conntrack = oldFw.Conntrack fw.rulesVersion = oldFw.rulesVersion + 1 @@ -511,7 +545,7 @@ func TestFirewall_DropConntrackReload(t *testing.T) { oldFw = fw fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c) - assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 11, 11, []string{"any"}, "", nil, "", "")) + assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 11, 11, []string{"any"}, "", nil, nil, "", "")) fw.Conntrack = oldFw.Conntrack fw.rulesVersion = oldFw.rulesVersion + 1 @@ -653,7 +687,7 @@ func TestNewFirewallFromConfig(t *testing.T) { conf = config.NewC(l) conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{}}} _, err = NewFirewallFromConfig(l, c, conf) - assert.EqualError(t, err, "firewall.outbound rule #0; at least one of host, group, cidr, ca_name, or ca_sha must be provided") + assert.EqualError(t, err, "firewall.outbound rule #0; at least one of host, group, cidr, local_cidr, ca_name, or ca_sha must be provided") // Test code/port error conf = config.NewC(l) @@ -677,6 +711,12 @@ func TestNewFirewallFromConfig(t *testing.T) { _, err = NewFirewallFromConfig(l, c, conf) assert.EqualError(t, err, "firewall.outbound rule #0; cidr did not parse; invalid CIDR address: testh") + // Test local_cidr parse error + conf = config.NewC(l) + conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"code": "1", "local_cidr": "testh", "proto": "any"}}} + _, err = NewFirewallFromConfig(l, c, conf) + assert.EqualError(t, err, "firewall.outbound rule #0; local_cidr did not parse; invalid CIDR address: testh") + // Test both group and groups conf = config.NewC(l) conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "group": "a", "groups": []string{"b", "c"}}}} @@ -691,63 +731,78 @@ func TestAddFirewallRulesFromConfig(t *testing.T) { mf := &mockFirewall{} conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "tcp", "host": "a"}}} assert.Nil(t, AddFirewallRulesFromConfig(l, false, conf, mf)) - assert.Equal(t, addRuleCall{incoming: false, proto: firewall.ProtoTCP, startPort: 1, endPort: 1, groups: nil, host: "a", ip: nil}, mf.lastCall) + assert.Equal(t, addRuleCall{incoming: false, proto: firewall.ProtoTCP, startPort: 1, endPort: 1, groups: nil, host: "a", ip: nil, localIp: nil}, mf.lastCall) // Test adding udp rule conf = config.NewC(l) mf = &mockFirewall{} conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "udp", "host": "a"}}} assert.Nil(t, AddFirewallRulesFromConfig(l, false, conf, mf)) - assert.Equal(t, addRuleCall{incoming: false, proto: firewall.ProtoUDP, startPort: 1, endPort: 1, groups: nil, host: "a", ip: nil}, mf.lastCall) + assert.Equal(t, addRuleCall{incoming: false, proto: firewall.ProtoUDP, startPort: 1, endPort: 1, groups: nil, host: "a", ip: nil, localIp: nil}, mf.lastCall) // Test adding icmp rule conf = config.NewC(l) mf = &mockFirewall{} conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "icmp", "host": "a"}}} assert.Nil(t, AddFirewallRulesFromConfig(l, false, conf, mf)) - assert.Equal(t, addRuleCall{incoming: false, proto: firewall.ProtoICMP, startPort: 1, endPort: 1, groups: nil, host: "a", ip: nil}, mf.lastCall) + assert.Equal(t, addRuleCall{incoming: false, proto: firewall.ProtoICMP, startPort: 1, endPort: 1, groups: nil, host: "a", ip: nil, localIp: nil}, mf.lastCall) // Test adding any rule conf = config.NewC(l) mf = &mockFirewall{} conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "host": "a"}}} assert.Nil(t, AddFirewallRulesFromConfig(l, true, conf, mf)) - assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, host: "a", ip: nil}, mf.lastCall) + assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, host: "a", ip: nil, localIp: nil}, mf.lastCall) + + // Test adding rule with cidr + cidr := &net.IPNet{net.ParseIP("10.0.0.0").To4(), net.IPv4Mask(255, 0, 0, 0)} + conf = config.NewC(l) + mf = &mockFirewall{} + conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "cidr": cidr.String()}}} + assert.Nil(t, AddFirewallRulesFromConfig(l, true, conf, mf)) + assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: cidr, localIp: nil}, mf.lastCall) + + // Test adding rule with local_cidr + conf = config.NewC(l) + mf = &mockFirewall{} + conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "local_cidr": cidr.String()}}} + assert.Nil(t, AddFirewallRulesFromConfig(l, true, conf, mf)) + assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: nil, localIp: cidr}, mf.lastCall) // Test adding rule with ca_sha conf = config.NewC(l) mf = &mockFirewall{} conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "ca_sha": "12312313123"}}} assert.Nil(t, AddFirewallRulesFromConfig(l, true, conf, mf)) - assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: nil, caSha: "12312313123"}, mf.lastCall) + assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: nil, localIp: nil, caSha: "12312313123"}, mf.lastCall) // Test adding rule with ca_name conf = config.NewC(l) mf = &mockFirewall{} conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "ca_name": "root01"}}} assert.Nil(t, AddFirewallRulesFromConfig(l, true, conf, mf)) - assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: nil, caName: "root01"}, mf.lastCall) + assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: nil, localIp: nil, caName: "root01"}, mf.lastCall) // Test single group conf = config.NewC(l) mf = &mockFirewall{} conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "group": "a"}}} assert.Nil(t, AddFirewallRulesFromConfig(l, true, conf, mf)) - assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: []string{"a"}, ip: nil}, mf.lastCall) + assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: []string{"a"}, ip: nil, localIp: nil}, mf.lastCall) // Test single groups conf = config.NewC(l) mf = &mockFirewall{} conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "groups": "a"}}} assert.Nil(t, AddFirewallRulesFromConfig(l, true, conf, mf)) - assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: []string{"a"}, ip: nil}, mf.lastCall) + assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: []string{"a"}, ip: nil, localIp: nil}, mf.lastCall) // Test multiple AND groups conf = config.NewC(l) mf = &mockFirewall{} conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "groups": []string{"a", "b"}}}} assert.Nil(t, AddFirewallRulesFromConfig(l, true, conf, mf)) - assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: []string{"a", "b"}, ip: nil}, mf.lastCall) + assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: []string{"a", "b"}, ip: nil, localIp: nil}, mf.lastCall) // Test Add error conf = config.NewC(l) @@ -892,6 +947,7 @@ type addRuleCall struct { groups []string host string ip *net.IPNet + localIp *net.IPNet caName string caSha string } @@ -901,7 +957,7 @@ type mockFirewall struct { nextCallReturn error } -func (mf *mockFirewall) AddRule(incoming bool, proto uint8, startPort int32, endPort int32, groups []string, host string, ip *net.IPNet, caName string, caSha string) error { +func (mf *mockFirewall) AddRule(incoming bool, proto uint8, startPort int32, endPort int32, groups []string, host string, ip *net.IPNet, localIp *net.IPNet, caName string, caSha string) error { mf.lastCall = addRuleCall{ incoming: incoming, proto: proto, @@ -910,6 +966,7 @@ func (mf *mockFirewall) AddRule(incoming bool, proto uint8, startPort int32, end groups: groups, host: host, ip: ip, + localIp: localIp, caName: caName, caSha: caSha, } From 419aaf2e362d5feb9e1208b9c50fc801d85e269f Mon Sep 17 00:00:00 2001 From: Wade Simmons Date: Tue, 9 May 2023 11:37:48 -0400 Subject: [PATCH 6/6] issue templates: remove Report Security Vulnerability (#867) This is redundant as Github automatically adds a section for this near the top. --- .github/ISSUE_TEMPLATE/config.yml | 4 ---- 1 file changed, 4 deletions(-) diff --git a/.github/ISSUE_TEMPLATE/config.yml b/.github/ISSUE_TEMPLATE/config.yml index 9c675ca..94e2c6b 100644 --- a/.github/ISSUE_TEMPLATE/config.yml +++ b/.github/ISSUE_TEMPLATE/config.yml @@ -11,7 +11,3 @@ contact_links: - name: 📱 Mobile Nebula url: https://github.com/definednet/mobile_nebula about: 'This issue tracker is not for mobile support. Try the Mobile Nebula repo instead!' - - - name: 🔒 Report Security Vulnerability - url: https://github.com/slackhq/nebula/blob/master/SECURITY.md - about: 'Please view SECURITY.md to learn how to report security vulnerabilities.'