Merge remote-tracking branch 'origin/master' into mutex-debug

This commit is contained in:
Wade Simmons 2023-05-09 11:42:05 -04:00
commit a83f0ca470
23 changed files with 689 additions and 167 deletions

View File

@ -118,6 +118,17 @@ To build nebula for a specific platform (ex, Windows):
See the [Makefile](Makefile) for more details on build targets See the [Makefile](Makefile) for more details on build targets
## Curve P256 and BoringCrypto
The default curve used for cryptographic handshakes and signatures is Curve25519. This is the recommended setting for most users. If your deployment has certain compliance requirements, you have the option of creating your CA using `nebula-cert ca -curve P256` to use NIST Curve P256. The CA will then sign certificates using ECDSA P256, and any hosts using these certificates will use P256 for ECDH handshakes.
In addition, Nebula can be built using the [BoringCrypto GOEXPERIMENT](https://github.com/golang/go/blob/go1.20/src/crypto/internal/boring/README.md) by running either of the following make targets:
make bin-boringcrypto
make release-boringcrypto
This is not the recommended default deployment, but may be useful based on your compliance requirements.
## Credits ## Credits
Nebula was created at Slack Technologies, Inc by Nate Brown and Ryan Huber, with contributions from Oliver Fross, Alan Lam, Wade Simmons, and Lining Wang. Nebula was created at Slack Technologies, Inc by Nate Brown and Ryan Huber, with contributions from Oliver Fross, Alan Lam, Wade Simmons, and Lining Wang.

12
SECURITY.md Normal file
View File

@ -0,0 +1,12 @@
Security Policy
===============
Reporting a Vulnerability
-------------------------
If you believe you have found a security vulnerability with Nebula, please let
us know right away. We will investigate all reports and do our best to quickly
fix valid issues.
You can submit your report on [HackerOne](https://hackerone.com/slack) and our
security team will respond as soon as possible.

View File

@ -13,8 +13,14 @@ type Node struct {
value interface{} value interface{}
} }
type entry struct {
CIDR *net.IPNet
Value *interface{}
}
type Tree4 struct { type Tree4 struct {
root *Node root *Node
list []entry
} }
const ( const (
@ -24,6 +30,7 @@ const (
func NewTree4() *Tree4 { func NewTree4() *Tree4 {
tree := new(Tree4) tree := new(Tree4)
tree.root = &Node{} tree.root = &Node{}
tree.list = []entry{}
return tree return tree
} }
@ -53,6 +60,15 @@ func (tree *Tree4) AddCIDR(cidr *net.IPNet, val interface{}) {
// We already have this range so update the value // We already have this range so update the value
if next != nil { if next != nil {
addCIDR := cidr.String()
for i, v := range tree.list {
if addCIDR == v.CIDR.String() {
tree.list = append(tree.list[:i], tree.list[i+1:]...)
break
}
}
tree.list = append(tree.list, entry{CIDR: cidr, Value: &val})
node.value = val node.value = val
return return
} }
@ -74,9 +90,10 @@ func (tree *Tree4) AddCIDR(cidr *net.IPNet, val interface{}) {
// Final node marks our cidr, set the value // Final node marks our cidr, set the value
node.value = val node.value = val
tree.list = append(tree.list, entry{CIDR: cidr, Value: &val})
} }
// Finds the first match, which may be the least specific // Contains finds the first match, which may be the least specific
func (tree *Tree4) Contains(ip iputil.VpnIp) (value interface{}) { func (tree *Tree4) Contains(ip iputil.VpnIp) (value interface{}) {
bit := startbit bit := startbit
node := tree.root node := tree.root
@ -99,7 +116,7 @@ func (tree *Tree4) Contains(ip iputil.VpnIp) (value interface{}) {
return value return value
} }
// Finds the most specific match // MostSpecificContains finds the most specific match
func (tree *Tree4) MostSpecificContains(ip iputil.VpnIp) (value interface{}) { func (tree *Tree4) MostSpecificContains(ip iputil.VpnIp) (value interface{}) {
bit := startbit bit := startbit
node := tree.root node := tree.root
@ -121,7 +138,7 @@ func (tree *Tree4) MostSpecificContains(ip iputil.VpnIp) (value interface{}) {
return value return value
} }
// Finds the most specific match // Match finds the most specific match
func (tree *Tree4) Match(ip iputil.VpnIp) (value interface{}) { func (tree *Tree4) Match(ip iputil.VpnIp) (value interface{}) {
bit := startbit bit := startbit
node := tree.root node := tree.root
@ -143,3 +160,8 @@ func (tree *Tree4) Match(ip iputil.VpnIp) (value interface{}) {
} }
return value return value
} }
// List will return all CIDRs and their current values. Do not modify the contents!
func (tree *Tree4) List() []entry {
return tree.list
}

View File

@ -8,6 +8,20 @@ import (
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
) )
func TestCIDRTree_List(t *testing.T) {
tree := NewTree4()
tree.AddCIDR(Parse("1.0.0.0/16"), "1")
tree.AddCIDR(Parse("1.0.0.0/8"), "2")
tree.AddCIDR(Parse("1.0.0.0/16"), "3")
tree.AddCIDR(Parse("1.0.0.0/16"), "4")
list := tree.List()
assert.Len(t, list, 2)
assert.Equal(t, "1.0.0.0/8", list[0].CIDR.String())
assert.Equal(t, "2", *list[0].Value)
assert.Equal(t, "1.0.0.0/16", list[1].CIDR.String())
assert.Equal(t, "4", *list[1].Value)
}
func TestCIDRTree_Contains(t *testing.T) { func TestCIDRTree_Contains(t *testing.T) {
tree := NewTree4() tree := NewTree4()
tree.AddCIDR(Parse("1.0.0.0/8"), "1") tree.AddCIDR(Parse("1.0.0.0/8"), "1")

View File

@ -47,7 +47,7 @@ func TestControl_GetHostInfoByVpnIp(t *testing.T) {
Signature: []byte{1, 2, 1, 2, 1, 3}, Signature: []byte{1, 2, 1, 2, 1, 3},
} }
remotes := NewRemoteList() remotes := NewRemoteList(nil)
remotes.unlockedPrependV4(0, NewIp4AndPort(remote1.IP, uint32(remote1.Port))) remotes.unlockedPrependV4(0, NewIp4AndPort(remote1.IP, uint32(remote1.Port)))
remotes.unlockedPrependV6(0, NewIp6AndPort(remote2.IP, uint32(remote2.Port))) remotes.unlockedPrependV6(0, NewIp6AndPort(remote2.IP, uint32(remote2.Port)))
hm.Add(iputil.Ip2VpnIp(ipNet.IP), &HostInfo{ hm.Add(iputil.Ip2VpnIp(ipNet.IP), &HostInfo{

View File

@ -223,6 +223,10 @@ tun:
# metric: 100 # metric: 100
# install: true # install: true
# On linux only, set to true to manage unsafe routes directly on the system route table with gateway routes instead of
# in nebula configuration files. Default false, not reloadable.
#use_system_route_table: false
# TODO # TODO
# Configure logging level # Configure logging level
logging: logging:
@ -301,7 +305,8 @@ firewall:
# host: `any` or a literal hostname, ie `test-host` # host: `any` or a literal hostname, ie `test-host`
# group: `any` or a literal group name, ie `default-group` # group: `any` or a literal group name, ie `default-group`
# groups: Same as group but accepts a list of values. Multiple values are AND'd together and a certificate would have to contain all groups to pass # groups: Same as group but accepts a list of values. Multiple values are AND'd together and a certificate would have to contain all groups to pass
# cidr: a CIDR, `0.0.0.0/0` is any. # cidr: a remote CIDR, `0.0.0.0/0` is any.
# local_cidr: a local CIDR, `0.0.0.0/0` is any. This could be used to filter destinations when using unsafe_routes.
# ca_name: An issuing CA name # ca_name: An issuing CA name
# ca_sha: An issuing CA shasum # ca_sha: An issuing CA shasum

View File

@ -25,7 +25,7 @@ const tcpACK = 0x10
const tcpFIN = 0x01 const tcpFIN = 0x01
type FirewallInterface interface { type FirewallInterface interface {
AddRule(incoming bool, proto uint8, startPort int32, endPort int32, groups []string, host string, ip *net.IPNet, caName string, caSha string) error AddRule(incoming bool, proto uint8, startPort int32, endPort int32, groups []string, host string, ip *net.IPNet, localIp *net.IPNet, caName string, caSha string) error
} }
type conn struct { type conn struct {
@ -106,11 +106,12 @@ type FirewallCA struct {
} }
type FirewallRule struct { type FirewallRule struct {
// Any makes Hosts, Groups, and CIDR irrelevant // Any makes Hosts, Groups, CIDR and LocalCIDR irrelevant
Any bool Any bool
Hosts map[string]struct{} Hosts map[string]struct{}
Groups [][]string Groups [][]string
CIDR *cidr.Tree4 CIDR *cidr.Tree4
LocalCIDR *cidr.Tree4
} }
// Even though ports are uint16, int32 maps are faster for lookup // Even though ports are uint16, int32 maps are faster for lookup
@ -218,18 +219,22 @@ func NewFirewallFromConfig(l *logrus.Logger, nc *cert.NebulaCertificate, c *conf
} }
// AddRule properly creates the in memory rule structure for a firewall table. // AddRule properly creates the in memory rule structure for a firewall table.
func (f *Firewall) AddRule(incoming bool, proto uint8, startPort int32, endPort int32, groups []string, host string, ip *net.IPNet, caName string, caSha string) error { func (f *Firewall) AddRule(incoming bool, proto uint8, startPort int32, endPort int32, groups []string, host string, ip *net.IPNet, localIp *net.IPNet, caName string, caSha string) error {
// Under gomobile, stringing a nil pointer with fmt causes an abort in debug mode for iOS // Under gomobile, stringing a nil pointer with fmt causes an abort in debug mode for iOS
// https://github.com/golang/go/issues/14131 // https://github.com/golang/go/issues/14131
sIp := "" sIp := ""
if ip != nil { if ip != nil {
sIp = ip.String() sIp = ip.String()
} }
lIp := ""
if localIp != nil {
lIp = localIp.String()
}
// We need this rule string because we generate a hash. Removing this will break firewall reload. // We need this rule string because we generate a hash. Removing this will break firewall reload.
ruleString := fmt.Sprintf( ruleString := fmt.Sprintf(
"incoming: %v, proto: %v, startPort: %v, endPort: %v, groups: %v, host: %v, ip: %v, caName: %v, caSha: %s", "incoming: %v, proto: %v, startPort: %v, endPort: %v, groups: %v, host: %v, ip: %v, localIp: %v, caName: %v, caSha: %s",
incoming, proto, startPort, endPort, groups, host, sIp, caName, caSha, incoming, proto, startPort, endPort, groups, host, sIp, lIp, caName, caSha,
) )
f.rules += ruleString + "\n" f.rules += ruleString + "\n"
@ -237,7 +242,7 @@ func (f *Firewall) AddRule(incoming bool, proto uint8, startPort int32, endPort
if !incoming { if !incoming {
direction = "outgoing" direction = "outgoing"
} }
f.l.WithField("firewallRule", m{"direction": direction, "proto": proto, "startPort": startPort, "endPort": endPort, "groups": groups, "host": host, "ip": sIp, "caName": caName, "caSha": caSha}). f.l.WithField("firewallRule", m{"direction": direction, "proto": proto, "startPort": startPort, "endPort": endPort, "groups": groups, "host": host, "ip": sIp, "localIp": lIp, "caName": caName, "caSha": caSha}).
Info("Firewall rule added") Info("Firewall rule added")
var ( var (
@ -264,7 +269,7 @@ func (f *Firewall) AddRule(incoming bool, proto uint8, startPort int32, endPort
return fmt.Errorf("unknown protocol %v", proto) return fmt.Errorf("unknown protocol %v", proto)
} }
return fp.addRule(startPort, endPort, groups, host, ip, caName, caSha) return fp.addRule(startPort, endPort, groups, host, ip, localIp, caName, caSha)
} }
// GetRuleHash returns a hash representation of all inbound and outbound rules // GetRuleHash returns a hash representation of all inbound and outbound rules
@ -302,8 +307,8 @@ func AddFirewallRulesFromConfig(l *logrus.Logger, inbound bool, c *config.C, fw
return fmt.Errorf("%s rule #%v; only one of port or code should be provided", table, i) return fmt.Errorf("%s rule #%v; only one of port or code should be provided", table, i)
} }
if r.Host == "" && len(r.Groups) == 0 && r.Group == "" && r.Cidr == "" && r.CAName == "" && r.CASha == "" { if r.Host == "" && len(r.Groups) == 0 && r.Group == "" && r.Cidr == "" && r.LocalCidr == "" && r.CAName == "" && r.CASha == "" {
return fmt.Errorf("%s rule #%v; at least one of host, group, cidr, ca_name, or ca_sha must be provided", table, i) return fmt.Errorf("%s rule #%v; at least one of host, group, cidr, local_cidr, ca_name, or ca_sha must be provided", table, i)
} }
if len(r.Groups) > 0 { if len(r.Groups) > 0 {
@ -355,7 +360,15 @@ func AddFirewallRulesFromConfig(l *logrus.Logger, inbound bool, c *config.C, fw
} }
} }
err = fw.AddRule(inbound, proto, startPort, endPort, groups, r.Host, cidr, r.CAName, r.CASha) var localCidr *net.IPNet
if r.LocalCidr != "" {
_, localCidr, err = net.ParseCIDR(r.LocalCidr)
if err != nil {
return fmt.Errorf("%s rule #%v; local_cidr did not parse; %s", table, i, err)
}
}
err = fw.AddRule(inbound, proto, startPort, endPort, groups, r.Host, cidr, localCidr, r.CAName, r.CASha)
if err != nil { if err != nil {
return fmt.Errorf("%s rule #%v; `%s`", table, i, err) return fmt.Errorf("%s rule #%v; `%s`", table, i, err)
} }
@ -595,7 +608,7 @@ func (ft *FirewallTable) match(p firewall.Packet, incoming bool, c *cert.NebulaC
return false return false
} }
func (fp firewallPort) addRule(startPort int32, endPort int32, groups []string, host string, ip *net.IPNet, caName string, caSha string) error { func (fp firewallPort) addRule(startPort int32, endPort int32, groups []string, host string, ip *net.IPNet, localIp *net.IPNet, caName string, caSha string) error {
if startPort > endPort { if startPort > endPort {
return fmt.Errorf("start port was lower than end port") return fmt.Errorf("start port was lower than end port")
} }
@ -608,7 +621,7 @@ func (fp firewallPort) addRule(startPort int32, endPort int32, groups []string,
} }
} }
if err := fp[i].addRule(groups, host, ip, caName, caSha); err != nil { if err := fp[i].addRule(groups, host, ip, localIp, caName, caSha); err != nil {
return err return err
} }
} }
@ -639,12 +652,13 @@ func (fp firewallPort) match(p firewall.Packet, incoming bool, c *cert.NebulaCer
return fp[firewall.PortAny].match(p, c, caPool) return fp[firewall.PortAny].match(p, c, caPool)
} }
func (fc *FirewallCA) addRule(groups []string, host string, ip *net.IPNet, caName, caSha string) error { func (fc *FirewallCA) addRule(groups []string, host string, ip, localIp *net.IPNet, caName, caSha string) error {
fr := func() *FirewallRule { fr := func() *FirewallRule {
return &FirewallRule{ return &FirewallRule{
Hosts: make(map[string]struct{}), Hosts: make(map[string]struct{}),
Groups: make([][]string, 0), Groups: make([][]string, 0),
CIDR: cidr.NewTree4(), CIDR: cidr.NewTree4(),
LocalCIDR: cidr.NewTree4(),
} }
} }
@ -653,14 +667,14 @@ func (fc *FirewallCA) addRule(groups []string, host string, ip *net.IPNet, caNam
fc.Any = fr() fc.Any = fr()
} }
return fc.Any.addRule(groups, host, ip) return fc.Any.addRule(groups, host, ip, localIp)
} }
if caSha != "" { if caSha != "" {
if _, ok := fc.CAShas[caSha]; !ok { if _, ok := fc.CAShas[caSha]; !ok {
fc.CAShas[caSha] = fr() fc.CAShas[caSha] = fr()
} }
err := fc.CAShas[caSha].addRule(groups, host, ip) err := fc.CAShas[caSha].addRule(groups, host, ip, localIp)
if err != nil { if err != nil {
return err return err
} }
@ -670,7 +684,7 @@ func (fc *FirewallCA) addRule(groups []string, host string, ip *net.IPNet, caNam
if _, ok := fc.CANames[caName]; !ok { if _, ok := fc.CANames[caName]; !ok {
fc.CANames[caName] = fr() fc.CANames[caName] = fr()
} }
err := fc.CANames[caName].addRule(groups, host, ip) err := fc.CANames[caName].addRule(groups, host, ip, localIp)
if err != nil { if err != nil {
return err return err
} }
@ -702,17 +716,18 @@ func (fc *FirewallCA) match(p firewall.Packet, c *cert.NebulaCertificate, caPool
return fc.CANames[s.Details.Name].match(p, c) return fc.CANames[s.Details.Name].match(p, c)
} }
func (fr *FirewallRule) addRule(groups []string, host string, ip *net.IPNet) error { func (fr *FirewallRule) addRule(groups []string, host string, ip *net.IPNet, localIp *net.IPNet) error {
if fr.Any { if fr.Any {
return nil return nil
} }
if fr.isAny(groups, host, ip) { if fr.isAny(groups, host, ip, localIp) {
fr.Any = true fr.Any = true
// If it's any we need to wipe out any pre-existing rules to save on memory // If it's any we need to wipe out any pre-existing rules to save on memory
fr.Groups = make([][]string, 0) fr.Groups = make([][]string, 0)
fr.Hosts = make(map[string]struct{}) fr.Hosts = make(map[string]struct{})
fr.CIDR = cidr.NewTree4() fr.CIDR = cidr.NewTree4()
fr.LocalCIDR = cidr.NewTree4()
} else { } else {
if len(groups) > 0 { if len(groups) > 0 {
fr.Groups = append(fr.Groups, groups) fr.Groups = append(fr.Groups, groups)
@ -725,13 +740,17 @@ func (fr *FirewallRule) addRule(groups []string, host string, ip *net.IPNet) err
if ip != nil { if ip != nil {
fr.CIDR.AddCIDR(ip, struct{}{}) fr.CIDR.AddCIDR(ip, struct{}{})
} }
if localIp != nil {
fr.LocalCIDR.AddCIDR(localIp, struct{}{})
}
} }
return nil return nil
} }
func (fr *FirewallRule) isAny(groups []string, host string, ip *net.IPNet) bool { func (fr *FirewallRule) isAny(groups []string, host string, ip, localIp *net.IPNet) bool {
if len(groups) == 0 && host == "" && ip == nil { if len(groups) == 0 && host == "" && ip == nil && localIp == nil {
return true return true
} }
@ -749,6 +768,10 @@ func (fr *FirewallRule) isAny(groups []string, host string, ip *net.IPNet) bool
return true return true
} }
if localIp != nil && localIp.Contains(net.IPv4(0, 0, 0, 0)) {
return true
}
return false return false
} }
@ -790,6 +813,10 @@ func (fr *FirewallRule) match(p firewall.Packet, c *cert.NebulaCertificate) bool
return true return true
} }
if fr.LocalCIDR != nil && fr.LocalCIDR.Contains(p.LocalIP) != nil {
return true
}
// No host, group, or cidr matched, bye bye // No host, group, or cidr matched, bye bye
return false return false
} }
@ -802,6 +829,7 @@ type rule struct {
Group string Group string
Groups []string Groups []string
Cidr string Cidr string
LocalCidr string
CAName string CAName string
CASha string CASha string
} }
@ -827,6 +855,7 @@ func convertRule(l *logrus.Logger, p interface{}, table string, i int) (rule, er
r.Proto = toString("proto", m) r.Proto = toString("proto", m)
r.Host = toString("host", m) r.Host = toString("host", m)
r.Cidr = toString("cidr", m) r.Cidr = toString("cidr", m)
r.LocalCidr = toString("local_cidr", m)
r.CAName = toString("ca_name", m) r.CAName = toString("ca_name", m)
r.CASha = toString("ca_sha", m) r.CASha = toString("ca_sha", m)

View File

@ -69,67 +69,75 @@ func TestFirewall_AddRule(t *testing.T) {
_, ti, _ := net.ParseCIDR("1.2.3.4/32") _, ti, _ := net.ParseCIDR("1.2.3.4/32")
assert.Nil(t, fw.AddRule(true, firewall.ProtoTCP, 1, 1, []string{}, "", nil, "", "")) assert.Nil(t, fw.AddRule(true, firewall.ProtoTCP, 1, 1, []string{}, "", nil, nil, "", ""))
// An empty rule is any // An empty rule is any
assert.True(t, fw.InRules.TCP[1].Any.Any) assert.True(t, fw.InRules.TCP[1].Any.Any)
assert.Empty(t, fw.InRules.TCP[1].Any.Groups) assert.Empty(t, fw.InRules.TCP[1].Any.Groups)
assert.Empty(t, fw.InRules.TCP[1].Any.Hosts) assert.Empty(t, fw.InRules.TCP[1].Any.Hosts)
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c) fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
assert.Nil(t, fw.AddRule(true, firewall.ProtoUDP, 1, 1, []string{"g1"}, "", nil, "", "")) assert.Nil(t, fw.AddRule(true, firewall.ProtoUDP, 1, 1, []string{"g1"}, "", nil, nil, "", ""))
assert.False(t, fw.InRules.UDP[1].Any.Any) assert.False(t, fw.InRules.UDP[1].Any.Any)
assert.Contains(t, fw.InRules.UDP[1].Any.Groups[0], "g1") assert.Contains(t, fw.InRules.UDP[1].Any.Groups[0], "g1")
assert.Empty(t, fw.InRules.UDP[1].Any.Hosts) assert.Empty(t, fw.InRules.UDP[1].Any.Hosts)
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c) fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
assert.Nil(t, fw.AddRule(true, firewall.ProtoICMP, 1, 1, []string{}, "h1", nil, "", "")) assert.Nil(t, fw.AddRule(true, firewall.ProtoICMP, 1, 1, []string{}, "h1", nil, nil, "", ""))
assert.False(t, fw.InRules.ICMP[1].Any.Any) assert.False(t, fw.InRules.ICMP[1].Any.Any)
assert.Empty(t, fw.InRules.ICMP[1].Any.Groups) assert.Empty(t, fw.InRules.ICMP[1].Any.Groups)
assert.Contains(t, fw.InRules.ICMP[1].Any.Hosts, "h1") assert.Contains(t, fw.InRules.ICMP[1].Any.Hosts, "h1")
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c) fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
assert.Nil(t, fw.AddRule(false, firewall.ProtoAny, 1, 1, []string{}, "", ti, "", "")) assert.Nil(t, fw.AddRule(false, firewall.ProtoAny, 1, 1, []string{}, "", ti, nil, "", ""))
assert.False(t, fw.OutRules.AnyProto[1].Any.Any) assert.False(t, fw.OutRules.AnyProto[1].Any.Any)
assert.Empty(t, fw.OutRules.AnyProto[1].Any.Groups) assert.Empty(t, fw.OutRules.AnyProto[1].Any.Groups)
assert.Empty(t, fw.OutRules.AnyProto[1].Any.Hosts) assert.Empty(t, fw.OutRules.AnyProto[1].Any.Hosts)
assert.NotNil(t, fw.OutRules.AnyProto[1].Any.CIDR.Match(iputil.Ip2VpnIp(ti.IP))) assert.NotNil(t, fw.OutRules.AnyProto[1].Any.CIDR.Match(iputil.Ip2VpnIp(ti.IP)))
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c) fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
assert.Nil(t, fw.AddRule(true, firewall.ProtoUDP, 1, 1, []string{"g1"}, "", nil, "ca-name", "")) assert.Nil(t, fw.AddRule(false, firewall.ProtoAny, 1, 1, []string{}, "", nil, ti, "", ""))
assert.False(t, fw.OutRules.AnyProto[1].Any.Any)
assert.Empty(t, fw.OutRules.AnyProto[1].Any.Groups)
assert.Empty(t, fw.OutRules.AnyProto[1].Any.Hosts)
assert.NotNil(t, fw.OutRules.AnyProto[1].Any.LocalCIDR.Match(iputil.Ip2VpnIp(ti.IP)))
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
assert.Nil(t, fw.AddRule(true, firewall.ProtoUDP, 1, 1, []string{"g1"}, "", nil, nil, "ca-name", ""))
assert.Contains(t, fw.InRules.UDP[1].CANames, "ca-name") assert.Contains(t, fw.InRules.UDP[1].CANames, "ca-name")
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c) fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
assert.Nil(t, fw.AddRule(true, firewall.ProtoUDP, 1, 1, []string{"g1"}, "", nil, "", "ca-sha")) assert.Nil(t, fw.AddRule(true, firewall.ProtoUDP, 1, 1, []string{"g1"}, "", nil, nil, "", "ca-sha"))
assert.Contains(t, fw.InRules.UDP[1].CAShas, "ca-sha") assert.Contains(t, fw.InRules.UDP[1].CAShas, "ca-sha")
// Set any and clear fields // Set any and clear fields
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c) fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
assert.Nil(t, fw.AddRule(false, firewall.ProtoAny, 0, 0, []string{"g1", "g2"}, "h1", ti, "", "")) assert.Nil(t, fw.AddRule(false, firewall.ProtoAny, 0, 0, []string{"g1", "g2"}, "h1", ti, ti, "", ""))
assert.Equal(t, []string{"g1", "g2"}, fw.OutRules.AnyProto[0].Any.Groups[0]) assert.Equal(t, []string{"g1", "g2"}, fw.OutRules.AnyProto[0].Any.Groups[0])
assert.Contains(t, fw.OutRules.AnyProto[0].Any.Hosts, "h1") assert.Contains(t, fw.OutRules.AnyProto[0].Any.Hosts, "h1")
assert.NotNil(t, fw.OutRules.AnyProto[0].Any.CIDR.Match(iputil.Ip2VpnIp(ti.IP))) assert.NotNil(t, fw.OutRules.AnyProto[0].Any.CIDR.Match(iputil.Ip2VpnIp(ti.IP)))
assert.NotNil(t, fw.OutRules.AnyProto[0].Any.LocalCIDR.Match(iputil.Ip2VpnIp(ti.IP)))
// run twice just to make sure // run twice just to make sure
//TODO: these ANY rules should clear the CA firewall portion //TODO: these ANY rules should clear the CA firewall portion
assert.Nil(t, fw.AddRule(false, firewall.ProtoAny, 0, 0, []string{"any"}, "", nil, "", "")) assert.Nil(t, fw.AddRule(false, firewall.ProtoAny, 0, 0, []string{"any"}, "", nil, nil, "", ""))
assert.Nil(t, fw.AddRule(false, firewall.ProtoAny, 0, 0, []string{}, "any", nil, "", "")) assert.Nil(t, fw.AddRule(false, firewall.ProtoAny, 0, 0, []string{}, "any", nil, nil, "", ""))
assert.True(t, fw.OutRules.AnyProto[0].Any.Any) assert.True(t, fw.OutRules.AnyProto[0].Any.Any)
assert.Empty(t, fw.OutRules.AnyProto[0].Any.Groups) assert.Empty(t, fw.OutRules.AnyProto[0].Any.Groups)
assert.Empty(t, fw.OutRules.AnyProto[0].Any.Hosts) assert.Empty(t, fw.OutRules.AnyProto[0].Any.Hosts)
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c) fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
assert.Nil(t, fw.AddRule(false, firewall.ProtoAny, 0, 0, []string{}, "any", nil, "", "")) assert.Nil(t, fw.AddRule(false, firewall.ProtoAny, 0, 0, []string{}, "any", nil, nil, "", ""))
assert.True(t, fw.OutRules.AnyProto[0].Any.Any) assert.True(t, fw.OutRules.AnyProto[0].Any.Any)
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c) fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
_, anyIp, _ := net.ParseCIDR("0.0.0.0/0") _, anyIp, _ := net.ParseCIDR("0.0.0.0/0")
assert.Nil(t, fw.AddRule(false, firewall.ProtoAny, 0, 0, []string{}, "", anyIp, "", "")) assert.Nil(t, fw.AddRule(false, firewall.ProtoAny, 0, 0, []string{}, "", anyIp, nil, "", ""))
assert.True(t, fw.OutRules.AnyProto[0].Any.Any) assert.True(t, fw.OutRules.AnyProto[0].Any.Any)
// Test error conditions // Test error conditions
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c) fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
assert.Error(t, fw.AddRule(true, math.MaxUint8, 0, 0, []string{}, "", nil, "", "")) assert.Error(t, fw.AddRule(true, math.MaxUint8, 0, 0, []string{}, "", nil, nil, "", ""))
assert.Error(t, fw.AddRule(true, firewall.ProtoAny, 10, 0, []string{}, "", nil, "", "")) assert.Error(t, fw.AddRule(true, firewall.ProtoAny, 10, 0, []string{}, "", nil, nil, "", ""))
} }
func TestFirewall_Drop(t *testing.T) { func TestFirewall_Drop(t *testing.T) {
@ -169,7 +177,7 @@ func TestFirewall_Drop(t *testing.T) {
h.CreateRemoteCIDR(&c) h.CreateRemoteCIDR(&c)
fw := NewFirewall(l, time.Second, time.Minute, time.Hour, &c) fw := NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"any"}, "", nil, "", "")) assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"any"}, "", nil, nil, "", ""))
cp := cert.NewCAPool() cp := cert.NewCAPool()
// Drop outbound // Drop outbound
@ -188,28 +196,28 @@ func TestFirewall_Drop(t *testing.T) {
// ensure signer doesn't get in the way of group checks // ensure signer doesn't get in the way of group checks
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c) fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", nil, "", "signer-shasum")) assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", nil, nil, "", "signer-shasum"))
assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", nil, "", "signer-shasum-bad")) assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", nil, nil, "", "signer-shasum-bad"))
assert.Equal(t, fw.Drop([]byte{}, p, true, &h, cp, nil), ErrNoMatchingRule) assert.Equal(t, fw.Drop([]byte{}, p, true, &h, cp, nil), ErrNoMatchingRule)
// test caSha doesn't drop on match // test caSha doesn't drop on match
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c) fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", nil, "", "signer-shasum-bad")) assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", nil, nil, "", "signer-shasum-bad"))
assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", nil, "", "signer-shasum")) assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", nil, nil, "", "signer-shasum"))
assert.NoError(t, fw.Drop([]byte{}, p, true, &h, cp, nil)) assert.NoError(t, fw.Drop([]byte{}, p, true, &h, cp, nil))
// ensure ca name doesn't get in the way of group checks // ensure ca name doesn't get in the way of group checks
cp.CAs["signer-shasum"] = &cert.NebulaCertificate{Details: cert.NebulaCertificateDetails{Name: "ca-good"}} cp.CAs["signer-shasum"] = &cert.NebulaCertificate{Details: cert.NebulaCertificateDetails{Name: "ca-good"}}
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c) fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", nil, "ca-good", "")) assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", nil, nil, "ca-good", ""))
assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", nil, "ca-good-bad", "")) assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", nil, nil, "ca-good-bad", ""))
assert.Equal(t, fw.Drop([]byte{}, p, true, &h, cp, nil), ErrNoMatchingRule) assert.Equal(t, fw.Drop([]byte{}, p, true, &h, cp, nil), ErrNoMatchingRule)
// test caName doesn't drop on match // test caName doesn't drop on match
cp.CAs["signer-shasum"] = &cert.NebulaCertificate{Details: cert.NebulaCertificateDetails{Name: "ca-good"}} cp.CAs["signer-shasum"] = &cert.NebulaCertificate{Details: cert.NebulaCertificateDetails{Name: "ca-good"}}
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c) fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", nil, "ca-good-bad", "")) assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", nil, nil, "ca-good-bad", ""))
assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", nil, "ca-good", "")) assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", nil, nil, "ca-good", ""))
assert.NoError(t, fw.Drop([]byte{}, p, true, &h, cp, nil)) assert.NoError(t, fw.Drop([]byte{}, p, true, &h, cp, nil))
} }
@ -219,11 +227,11 @@ func BenchmarkFirewallTable_match(b *testing.B) {
} }
_, n, _ := net.ParseCIDR("172.1.1.1/32") _, n, _ := net.ParseCIDR("172.1.1.1/32")
_ = ft.TCP.addRule(10, 10, []string{"good-group"}, "good-host", n, "", "") _ = ft.TCP.addRule(10, 10, []string{"good-group"}, "good-host", n, n, "", "")
_ = ft.TCP.addRule(10, 10, []string{"good-group2"}, "good-host", n, "", "") _ = ft.TCP.addRule(10, 10, []string{"good-group2"}, "good-host", n, n, "", "")
_ = ft.TCP.addRule(10, 10, []string{"good-group3"}, "good-host", n, "", "") _ = ft.TCP.addRule(10, 10, []string{"good-group3"}, "good-host", n, n, "", "")
_ = ft.TCP.addRule(10, 10, []string{"good-group4"}, "good-host", n, "", "") _ = ft.TCP.addRule(10, 10, []string{"good-group4"}, "good-host", n, n, "", "")
_ = ft.TCP.addRule(10, 10, []string{"good-group, good-group1"}, "good-host", n, "", "") _ = ft.TCP.addRule(10, 10, []string{"good-group, good-group1"}, "good-host", n, n, "", "")
cp := cert.NewCAPool() cp := cert.NewCAPool()
b.Run("fail on proto", func(b *testing.B) { b.Run("fail on proto", func(b *testing.B) {
@ -291,7 +299,20 @@ func BenchmarkFirewallTable_match(b *testing.B) {
} }
}) })
_ = ft.TCP.addRule(0, 0, []string{"good-group"}, "good-host", n, "", "") b.Run("pass on local ip", func(b *testing.B) {
ip := iputil.Ip2VpnIp(net.IPv4(172, 1, 1, 1))
c := &cert.NebulaCertificate{
Details: cert.NebulaCertificateDetails{
InvertedGroups: map[string]struct{}{"nope": {}},
Name: "good-host",
},
}
for n := 0; n < b.N; n++ {
ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 10, LocalIP: ip}, true, c, cp)
}
})
_ = ft.TCP.addRule(0, 0, []string{"good-group"}, "good-host", n, n, "", "")
b.Run("pass on ip with any port", func(b *testing.B) { b.Run("pass on ip with any port", func(b *testing.B) {
ip := iputil.Ip2VpnIp(net.IPv4(172, 1, 1, 1)) ip := iputil.Ip2VpnIp(net.IPv4(172, 1, 1, 1))
@ -305,6 +326,19 @@ func BenchmarkFirewallTable_match(b *testing.B) {
ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 100, RemoteIP: ip}, true, c, cp) ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 100, RemoteIP: ip}, true, c, cp)
} }
}) })
b.Run("pass on local ip with any port", func(b *testing.B) {
ip := iputil.Ip2VpnIp(net.IPv4(172, 1, 1, 1))
c := &cert.NebulaCertificate{
Details: cert.NebulaCertificateDetails{
InvertedGroups: map[string]struct{}{"nope": {}},
Name: "good-host",
},
}
for n := 0; n < b.N; n++ {
ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 100, LocalIP: ip}, true, c, cp)
}
})
} }
func TestFirewall_Drop2(t *testing.T) { func TestFirewall_Drop2(t *testing.T) {
@ -356,7 +390,7 @@ func TestFirewall_Drop2(t *testing.T) {
h1.CreateRemoteCIDR(&c1) h1.CreateRemoteCIDR(&c1)
fw := NewFirewall(l, time.Second, time.Minute, time.Hour, &c) fw := NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group", "test-group"}, "", nil, "", "")) assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group", "test-group"}, "", nil, nil, "", ""))
cp := cert.NewCAPool() cp := cert.NewCAPool()
// h1/c1 lacks the proper groups // h1/c1 lacks the proper groups
@ -438,8 +472,8 @@ func TestFirewall_Drop3(t *testing.T) {
h3.CreateRemoteCIDR(&c3) h3.CreateRemoteCIDR(&c3)
fw := NewFirewall(l, time.Second, time.Minute, time.Hour, &c) fw := NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 1, 1, []string{}, "host1", nil, "", "")) assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 1, 1, []string{}, "host1", nil, nil, "", ""))
assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 1, 1, []string{}, "", nil, "", "signer-sha")) assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 1, 1, []string{}, "", nil, nil, "", "signer-sha"))
cp := cert.NewCAPool() cp := cert.NewCAPool()
// c1 should pass because host match // c1 should pass because host match
@ -489,7 +523,7 @@ func TestFirewall_DropConntrackReload(t *testing.T) {
h.CreateRemoteCIDR(&c) h.CreateRemoteCIDR(&c)
fw := NewFirewall(l, time.Second, time.Minute, time.Hour, &c) fw := NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"any"}, "", nil, "", "")) assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"any"}, "", nil, nil, "", ""))
cp := cert.NewCAPool() cp := cert.NewCAPool()
// Drop outbound // Drop outbound
@ -502,7 +536,7 @@ func TestFirewall_DropConntrackReload(t *testing.T) {
oldFw := fw oldFw := fw
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c) fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 10, 10, []string{"any"}, "", nil, "", "")) assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 10, 10, []string{"any"}, "", nil, nil, "", ""))
fw.Conntrack = oldFw.Conntrack fw.Conntrack = oldFw.Conntrack
fw.rulesVersion = oldFw.rulesVersion + 1 fw.rulesVersion = oldFw.rulesVersion + 1
@ -511,7 +545,7 @@ func TestFirewall_DropConntrackReload(t *testing.T) {
oldFw = fw oldFw = fw
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c) fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 11, 11, []string{"any"}, "", nil, "", "")) assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 11, 11, []string{"any"}, "", nil, nil, "", ""))
fw.Conntrack = oldFw.Conntrack fw.Conntrack = oldFw.Conntrack
fw.rulesVersion = oldFw.rulesVersion + 1 fw.rulesVersion = oldFw.rulesVersion + 1
@ -653,7 +687,7 @@ func TestNewFirewallFromConfig(t *testing.T) {
conf = config.NewC(l) conf = config.NewC(l)
conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{}}} conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{}}}
_, err = NewFirewallFromConfig(l, c, conf) _, err = NewFirewallFromConfig(l, c, conf)
assert.EqualError(t, err, "firewall.outbound rule #0; at least one of host, group, cidr, ca_name, or ca_sha must be provided") assert.EqualError(t, err, "firewall.outbound rule #0; at least one of host, group, cidr, local_cidr, ca_name, or ca_sha must be provided")
// Test code/port error // Test code/port error
conf = config.NewC(l) conf = config.NewC(l)
@ -677,6 +711,12 @@ func TestNewFirewallFromConfig(t *testing.T) {
_, err = NewFirewallFromConfig(l, c, conf) _, err = NewFirewallFromConfig(l, c, conf)
assert.EqualError(t, err, "firewall.outbound rule #0; cidr did not parse; invalid CIDR address: testh") assert.EqualError(t, err, "firewall.outbound rule #0; cidr did not parse; invalid CIDR address: testh")
// Test local_cidr parse error
conf = config.NewC(l)
conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"code": "1", "local_cidr": "testh", "proto": "any"}}}
_, err = NewFirewallFromConfig(l, c, conf)
assert.EqualError(t, err, "firewall.outbound rule #0; local_cidr did not parse; invalid CIDR address: testh")
// Test both group and groups // Test both group and groups
conf = config.NewC(l) conf = config.NewC(l)
conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "group": "a", "groups": []string{"b", "c"}}}} conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "group": "a", "groups": []string{"b", "c"}}}}
@ -691,63 +731,78 @@ func TestAddFirewallRulesFromConfig(t *testing.T) {
mf := &mockFirewall{} mf := &mockFirewall{}
conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "tcp", "host": "a"}}} conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "tcp", "host": "a"}}}
assert.Nil(t, AddFirewallRulesFromConfig(l, false, conf, mf)) assert.Nil(t, AddFirewallRulesFromConfig(l, false, conf, mf))
assert.Equal(t, addRuleCall{incoming: false, proto: firewall.ProtoTCP, startPort: 1, endPort: 1, groups: nil, host: "a", ip: nil}, mf.lastCall) assert.Equal(t, addRuleCall{incoming: false, proto: firewall.ProtoTCP, startPort: 1, endPort: 1, groups: nil, host: "a", ip: nil, localIp: nil}, mf.lastCall)
// Test adding udp rule // Test adding udp rule
conf = config.NewC(l) conf = config.NewC(l)
mf = &mockFirewall{} mf = &mockFirewall{}
conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "udp", "host": "a"}}} conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "udp", "host": "a"}}}
assert.Nil(t, AddFirewallRulesFromConfig(l, false, conf, mf)) assert.Nil(t, AddFirewallRulesFromConfig(l, false, conf, mf))
assert.Equal(t, addRuleCall{incoming: false, proto: firewall.ProtoUDP, startPort: 1, endPort: 1, groups: nil, host: "a", ip: nil}, mf.lastCall) assert.Equal(t, addRuleCall{incoming: false, proto: firewall.ProtoUDP, startPort: 1, endPort: 1, groups: nil, host: "a", ip: nil, localIp: nil}, mf.lastCall)
// Test adding icmp rule // Test adding icmp rule
conf = config.NewC(l) conf = config.NewC(l)
mf = &mockFirewall{} mf = &mockFirewall{}
conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "icmp", "host": "a"}}} conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "icmp", "host": "a"}}}
assert.Nil(t, AddFirewallRulesFromConfig(l, false, conf, mf)) assert.Nil(t, AddFirewallRulesFromConfig(l, false, conf, mf))
assert.Equal(t, addRuleCall{incoming: false, proto: firewall.ProtoICMP, startPort: 1, endPort: 1, groups: nil, host: "a", ip: nil}, mf.lastCall) assert.Equal(t, addRuleCall{incoming: false, proto: firewall.ProtoICMP, startPort: 1, endPort: 1, groups: nil, host: "a", ip: nil, localIp: nil}, mf.lastCall)
// Test adding any rule // Test adding any rule
conf = config.NewC(l) conf = config.NewC(l)
mf = &mockFirewall{} mf = &mockFirewall{}
conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "host": "a"}}} conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "host": "a"}}}
assert.Nil(t, AddFirewallRulesFromConfig(l, true, conf, mf)) assert.Nil(t, AddFirewallRulesFromConfig(l, true, conf, mf))
assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, host: "a", ip: nil}, mf.lastCall) assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, host: "a", ip: nil, localIp: nil}, mf.lastCall)
// Test adding rule with cidr
cidr := &net.IPNet{net.ParseIP("10.0.0.0").To4(), net.IPv4Mask(255, 0, 0, 0)}
conf = config.NewC(l)
mf = &mockFirewall{}
conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "cidr": cidr.String()}}}
assert.Nil(t, AddFirewallRulesFromConfig(l, true, conf, mf))
assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: cidr, localIp: nil}, mf.lastCall)
// Test adding rule with local_cidr
conf = config.NewC(l)
mf = &mockFirewall{}
conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "local_cidr": cidr.String()}}}
assert.Nil(t, AddFirewallRulesFromConfig(l, true, conf, mf))
assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: nil, localIp: cidr}, mf.lastCall)
// Test adding rule with ca_sha // Test adding rule with ca_sha
conf = config.NewC(l) conf = config.NewC(l)
mf = &mockFirewall{} mf = &mockFirewall{}
conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "ca_sha": "12312313123"}}} conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "ca_sha": "12312313123"}}}
assert.Nil(t, AddFirewallRulesFromConfig(l, true, conf, mf)) assert.Nil(t, AddFirewallRulesFromConfig(l, true, conf, mf))
assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: nil, caSha: "12312313123"}, mf.lastCall) assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: nil, localIp: nil, caSha: "12312313123"}, mf.lastCall)
// Test adding rule with ca_name // Test adding rule with ca_name
conf = config.NewC(l) conf = config.NewC(l)
mf = &mockFirewall{} mf = &mockFirewall{}
conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "ca_name": "root01"}}} conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "ca_name": "root01"}}}
assert.Nil(t, AddFirewallRulesFromConfig(l, true, conf, mf)) assert.Nil(t, AddFirewallRulesFromConfig(l, true, conf, mf))
assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: nil, caName: "root01"}, mf.lastCall) assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: nil, localIp: nil, caName: "root01"}, mf.lastCall)
// Test single group // Test single group
conf = config.NewC(l) conf = config.NewC(l)
mf = &mockFirewall{} mf = &mockFirewall{}
conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "group": "a"}}} conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "group": "a"}}}
assert.Nil(t, AddFirewallRulesFromConfig(l, true, conf, mf)) assert.Nil(t, AddFirewallRulesFromConfig(l, true, conf, mf))
assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: []string{"a"}, ip: nil}, mf.lastCall) assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: []string{"a"}, ip: nil, localIp: nil}, mf.lastCall)
// Test single groups // Test single groups
conf = config.NewC(l) conf = config.NewC(l)
mf = &mockFirewall{} mf = &mockFirewall{}
conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "groups": "a"}}} conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "groups": "a"}}}
assert.Nil(t, AddFirewallRulesFromConfig(l, true, conf, mf)) assert.Nil(t, AddFirewallRulesFromConfig(l, true, conf, mf))
assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: []string{"a"}, ip: nil}, mf.lastCall) assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: []string{"a"}, ip: nil, localIp: nil}, mf.lastCall)
// Test multiple AND groups // Test multiple AND groups
conf = config.NewC(l) conf = config.NewC(l)
mf = &mockFirewall{} mf = &mockFirewall{}
conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "groups": []string{"a", "b"}}}} conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "groups": []string{"a", "b"}}}}
assert.Nil(t, AddFirewallRulesFromConfig(l, true, conf, mf)) assert.Nil(t, AddFirewallRulesFromConfig(l, true, conf, mf))
assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: []string{"a", "b"}, ip: nil}, mf.lastCall) assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: []string{"a", "b"}, ip: nil, localIp: nil}, mf.lastCall)
// Test Add error // Test Add error
conf = config.NewC(l) conf = config.NewC(l)
@ -892,6 +947,7 @@ type addRuleCall struct {
groups []string groups []string
host string host string
ip *net.IPNet ip *net.IPNet
localIp *net.IPNet
caName string caName string
caSha string caSha string
} }
@ -901,7 +957,7 @@ type mockFirewall struct {
nextCallReturn error nextCallReturn error
} }
func (mf *mockFirewall) AddRule(incoming bool, proto uint8, startPort int32, endPort int32, groups []string, host string, ip *net.IPNet, caName string, caSha string) error { func (mf *mockFirewall) AddRule(incoming bool, proto uint8, startPort int32, endPort int32, groups []string, host string, ip *net.IPNet, localIp *net.IPNet, caName string, caSha string) error {
mf.lastCall = addRuleCall{ mf.lastCall = addRuleCall{
incoming: incoming, incoming: incoming,
proto: proto, proto: proto,
@ -910,6 +966,7 @@ func (mf *mockFirewall) AddRule(incoming bool, proto uint8, startPort int32, end
groups: groups, groups: groups,
host: host, host: host,
ip: ip, ip: ip,
localIp: localIp,
caName: caName, caName: caName,
caSha: caSha, caSha: caSha,
} }

View File

@ -41,7 +41,7 @@ func Test_NewHandshakeManagerVpnIp(t *testing.T) {
assert.False(t, initCalled) assert.False(t, initCalled)
assert.Same(t, i, i2) assert.Same(t, i, i2)
i.remotes = NewRemoteList() i.remotes = NewRemoteList(nil)
i.HandshakeReady = true i.HandshakeReady = true
// Adding something to pending should not affect the main hostmap // Adding something to pending should not affect the main hostmap

View File

@ -6,6 +6,7 @@ import (
"errors" "errors"
"fmt" "fmt"
"net" "net"
"net/netip"
"sync" "sync"
"sync/atomic" "sync/atomic"
"time" "time"
@ -33,6 +34,7 @@ type netIpAndPort struct {
type LightHouse struct { type LightHouse struct {
//TODO: We need a timer wheel to kick out vpnIps that haven't reported in a long time //TODO: We need a timer wheel to kick out vpnIps that haven't reported in a long time
sync.RWMutex //Because we concurrently read and write to our maps sync.RWMutex //Because we concurrently read and write to our maps
ctx context.Context
amLighthouse bool amLighthouse bool
myVpnIp iputil.VpnIp myVpnIp iputil.VpnIp
myVpnZeros iputil.VpnIp myVpnZeros iputil.VpnIp
@ -82,7 +84,7 @@ type LightHouse struct {
// NewLightHouseFromConfig will build a Lighthouse struct from the values provided in the config object // NewLightHouseFromConfig will build a Lighthouse struct from the values provided in the config object
// addrMap should be nil unless this is during a config reload // addrMap should be nil unless this is during a config reload
func NewLightHouseFromConfig(l *logrus.Logger, c *config.C, myVpnNet *net.IPNet, pc *udp.Conn, p *Punchy) (*LightHouse, error) { func NewLightHouseFromConfig(ctx context.Context, l *logrus.Logger, c *config.C, myVpnNet *net.IPNet, pc *udp.Conn, p *Punchy) (*LightHouse, error) {
amLighthouse := c.GetBool("lighthouse.am_lighthouse", false) amLighthouse := c.GetBool("lighthouse.am_lighthouse", false)
nebulaPort := uint32(c.GetInt("listen.port", 0)) nebulaPort := uint32(c.GetInt("listen.port", 0))
if amLighthouse && nebulaPort == 0 { if amLighthouse && nebulaPort == 0 {
@ -100,6 +102,7 @@ func NewLightHouseFromConfig(l *logrus.Logger, c *config.C, myVpnNet *net.IPNet,
ones, _ := myVpnNet.Mask.Size() ones, _ := myVpnNet.Mask.Size()
h := LightHouse{ h := LightHouse{
ctx: ctx,
amLighthouse: amLighthouse, amLighthouse: amLighthouse,
myVpnIp: iputil.Ip2VpnIp(myVpnNet.IP), myVpnIp: iputil.Ip2VpnIp(myVpnNet.IP),
myVpnZeros: iputil.VpnIp(32 - ones), myVpnZeros: iputil.VpnIp(32 - ones),
@ -258,7 +261,7 @@ func (lh *LightHouse) reload(c *config.C, initial bool) error {
} }
//NOTE: many things will get much simpler when we combine static_host_map and lighthouse.hosts in config //NOTE: many things will get much simpler when we combine static_host_map and lighthouse.hosts in config
if initial || c.HasChanged("static_host_map") { if initial || c.HasChanged("static_host_map") || c.HasChanged("static_map.cadence") || c.HasChanged("static_map.network") || c.HasChanged("static_map.lookup_timeout") {
staticList := make(map[iputil.VpnIp]struct{}) staticList := make(map[iputil.VpnIp]struct{})
err := lh.loadStaticMap(c, lh.myVpnNet, staticList) err := lh.loadStaticMap(c, lh.myVpnNet, staticList)
if err != nil { if err != nil {
@ -268,9 +271,19 @@ func (lh *LightHouse) reload(c *config.C, initial bool) error {
lh.staticList.Store(&staticList) lh.staticList.Store(&staticList)
if !initial { if !initial {
//TODO: we should remove any remote list entries for static hosts that were removed/modified? //TODO: we should remove any remote list entries for static hosts that were removed/modified?
if c.HasChanged("static_host_map") {
lh.l.Info("static_host_map has changed") lh.l.Info("static_host_map has changed")
} }
if c.HasChanged("static_map.cadence") {
lh.l.Info("static_map.cadence has changed")
}
if c.HasChanged("static_map.network") {
lh.l.Info("static_map.network has changed")
}
if c.HasChanged("static_map.lookup_timeout") {
lh.l.Info("static_map.lookup_timeout has changed")
}
}
} }
if initial || c.HasChanged("lighthouse.hosts") { if initial || c.HasChanged("lighthouse.hosts") {
@ -344,7 +357,48 @@ func (lh *LightHouse) parseLighthouses(c *config.C, tunCidr *net.IPNet, lhMap ma
return nil return nil
} }
func getStaticMapCadence(c *config.C) (time.Duration, error) {
cadence := c.GetString("static_map.cadence", "30s")
d, err := time.ParseDuration(cadence)
if err != nil {
return 0, err
}
return d, nil
}
func getStaticMapLookupTimeout(c *config.C) (time.Duration, error) {
lookupTimeout := c.GetString("static_map.lookup_timeout", "250ms")
d, err := time.ParseDuration(lookupTimeout)
if err != nil {
return 0, err
}
return d, nil
}
func getStaticMapNetwork(c *config.C) (string, error) {
network := c.GetString("static_map.network", "ip4")
if network != "ip" && network != "ip4" && network != "ip6" {
return "", fmt.Errorf("static_map.network must be one of ip, ip4, or ip6")
}
return network, nil
}
func (lh *LightHouse) loadStaticMap(c *config.C, tunCidr *net.IPNet, staticList map[iputil.VpnIp]struct{}) error { func (lh *LightHouse) loadStaticMap(c *config.C, tunCidr *net.IPNet, staticList map[iputil.VpnIp]struct{}) error {
d, err := getStaticMapCadence(c)
if err != nil {
return err
}
network, err := getStaticMapNetwork(c)
if err != nil {
return err
}
lookup_timeout, err := getStaticMapLookupTimeout(c)
if err != nil {
return err
}
shm := c.GetMap("static_host_map", map[interface{}]interface{}{}) shm := c.GetMap("static_host_map", map[interface{}]interface{}{})
i := 0 i := 0
@ -360,21 +414,17 @@ func (lh *LightHouse) loadStaticMap(c *config.C, tunCidr *net.IPNet, staticList
vpnIp := iputil.Ip2VpnIp(rip) vpnIp := iputil.Ip2VpnIp(rip)
vals, ok := v.([]interface{}) vals, ok := v.([]interface{})
if ok { if !ok {
for _, v := range vals { vals = []interface{}{v}
ip, port, err := udp.ParseIPAndPort(fmt.Sprintf("%v", v))
if err != nil {
return util.NewContextualError("Static host address could not be parsed", m{"vpnIp": vpnIp, "entry": i + 1}, err)
} }
lh.addStaticRemote(vpnIp, udp.NewAddr(ip, port), staticList) remoteAddrs := []string{}
for _, v := range vals {
remoteAddrs = append(remoteAddrs, fmt.Sprintf("%v", v))
} }
} else { err := lh.addStaticRemotes(i, d, network, lookup_timeout, vpnIp, remoteAddrs, staticList)
ip, port, err := udp.ParseIPAndPort(fmt.Sprintf("%v", v))
if err != nil { if err != nil {
return util.NewContextualError("Static host address could not be parsed", m{"vpnIp": vpnIp, "entry": i + 1}, err) return err
}
lh.addStaticRemote(vpnIp, udp.NewAddr(ip, port), staticList)
} }
i++ i++
} }
@ -482,30 +532,47 @@ func (lh *LightHouse) DeleteVpnIp(vpnIp iputil.VpnIp) {
// We are the owner because we don't want a lighthouse server to advertise for static hosts it was configured with // We are the owner because we don't want a lighthouse server to advertise for static hosts it was configured with
// And we don't want a lighthouse query reply to interfere with our learned cache if we are a client // And we don't want a lighthouse query reply to interfere with our learned cache if we are a client
// NOTE: this function should not interact with any hot path objects, like lh.staticList, the caller should handle it // NOTE: this function should not interact with any hot path objects, like lh.staticList, the caller should handle it
func (lh *LightHouse) addStaticRemote(vpnIp iputil.VpnIp, toAddr *udp.Addr, staticList map[iputil.VpnIp]struct{}) { func (lh *LightHouse) addStaticRemotes(i int, d time.Duration, network string, timeout time.Duration, vpnIp iputil.VpnIp, toAddrs []string, staticList map[iputil.VpnIp]struct{}) error {
lh.Lock() lh.Lock()
am := lh.unlockedGetRemoteList(vpnIp) am := lh.unlockedGetRemoteList(vpnIp)
am.Lock() am.Lock()
defer am.Unlock() defer am.Unlock()
ctx := lh.ctx
lh.Unlock() lh.Unlock()
if ipv4 := toAddr.IP.To4(); ipv4 != nil { hr, err := NewHostnameResults(ctx, lh.l, d, network, timeout, toAddrs, func() {
to := NewIp4AndPort(ipv4, uint32(toAddr.Port)) // This callback runs whenever the DNS hostname resolver finds a different set of IP's
// in its resolution for hostnames.
am.Lock()
defer am.Unlock()
am.shouldRebuild = true
})
if err != nil {
return util.NewContextualError("Static host address could not be parsed", m{"vpnIp": vpnIp, "entry": i + 1}, err)
}
am.unlockedSetHostnamesResults(hr)
for _, addrPort := range hr.GetIPs() {
switch {
case addrPort.Addr().Is4():
to := NewIp4AndPortFromNetIP(addrPort.Addr(), addrPort.Port())
if !lh.unlockedShouldAddV4(vpnIp, to) { if !lh.unlockedShouldAddV4(vpnIp, to) {
return continue
} }
am.unlockedPrependV4(lh.myVpnIp, to) am.unlockedPrependV4(lh.myVpnIp, to)
case addrPort.Addr().Is6():
} else { to := NewIp6AndPortFromNetIP(addrPort.Addr(), addrPort.Port())
to := NewIp6AndPort(toAddr.IP, uint32(toAddr.Port))
if !lh.unlockedShouldAddV6(vpnIp, to) { if !lh.unlockedShouldAddV6(vpnIp, to) {
return continue
} }
am.unlockedPrependV6(lh.myVpnIp, to) am.unlockedPrependV6(lh.myVpnIp, to)
} }
}
// Mark it as static in the caller provided map // Mark it as static in the caller provided map
staticList[vpnIp] = struct{}{} staticList[vpnIp] = struct{}{}
return nil
} }
// addCalculatedRemotes adds any calculated remotes based on the // addCalculatedRemotes adds any calculated remotes based on the
@ -545,12 +612,42 @@ func (lh *LightHouse) addCalculatedRemotes(vpnIp iputil.VpnIp) bool {
func (lh *LightHouse) unlockedGetRemoteList(vpnIp iputil.VpnIp) *RemoteList { func (lh *LightHouse) unlockedGetRemoteList(vpnIp iputil.VpnIp) *RemoteList {
am, ok := lh.addrMap[vpnIp] am, ok := lh.addrMap[vpnIp]
if !ok { if !ok {
am = NewRemoteList() am = NewRemoteList(func(a netip.Addr) bool { return lh.shouldAdd(vpnIp, a) })
lh.addrMap[vpnIp] = am lh.addrMap[vpnIp] = am
} }
return am return am
} }
func (lh *LightHouse) shouldAdd(vpnIp iputil.VpnIp, to netip.Addr) bool {
switch {
case to.Is4():
ipBytes := to.As4()
ip := iputil.Ip2VpnIp(ipBytes[:])
allow := lh.GetRemoteAllowList().AllowIpV4(vpnIp, ip)
if lh.l.Level >= logrus.TraceLevel {
lh.l.WithField("remoteIp", vpnIp).WithField("allow", allow).Trace("remoteAllowList.Allow")
}
if !allow || ipMaskContains(lh.myVpnIp, lh.myVpnZeros, ip) {
return false
}
case to.Is6():
ipBytes := to.As16()
hi := binary.BigEndian.Uint64(ipBytes[:8])
lo := binary.BigEndian.Uint64(ipBytes[8:])
allow := lh.GetRemoteAllowList().AllowIpV6(vpnIp, hi, lo)
if lh.l.Level >= logrus.TraceLevel {
lh.l.WithField("remoteIp", to).WithField("allow", allow).Trace("remoteAllowList.Allow")
}
// We don't check our vpn network here because nebula does not support ipv6 on the inside
if !allow {
return false
}
}
return true
}
// unlockedShouldAddV4 checks if to is allowed by our allow list // unlockedShouldAddV4 checks if to is allowed by our allow list
func (lh *LightHouse) unlockedShouldAddV4(vpnIp iputil.VpnIp, to *Ip4AndPort) bool { func (lh *LightHouse) unlockedShouldAddV4(vpnIp iputil.VpnIp, to *Ip4AndPort) bool {
allow := lh.GetRemoteAllowList().AllowIpV4(vpnIp, iputil.VpnIp(to.Ip)) allow := lh.GetRemoteAllowList().AllowIpV4(vpnIp, iputil.VpnIp(to.Ip))
@ -609,6 +706,14 @@ func NewIp4AndPort(ip net.IP, port uint32) *Ip4AndPort {
return &ipp return &ipp
} }
func NewIp4AndPortFromNetIP(ip netip.Addr, port uint16) *Ip4AndPort {
v4Addr := ip.As4()
return &Ip4AndPort{
Ip: binary.BigEndian.Uint32(v4Addr[:]),
Port: uint32(port),
}
}
func NewIp6AndPort(ip net.IP, port uint32) *Ip6AndPort { func NewIp6AndPort(ip net.IP, port uint32) *Ip6AndPort {
return &Ip6AndPort{ return &Ip6AndPort{
Hi: binary.BigEndian.Uint64(ip[:8]), Hi: binary.BigEndian.Uint64(ip[:8]),
@ -617,6 +722,14 @@ func NewIp6AndPort(ip net.IP, port uint32) *Ip6AndPort {
} }
} }
func NewIp6AndPortFromNetIP(ip netip.Addr, port uint16) *Ip6AndPort {
ip6Addr := ip.As16()
return &Ip6AndPort{
Hi: binary.BigEndian.Uint64(ip6Addr[:8]),
Lo: binary.BigEndian.Uint64(ip6Addr[8:]),
Port: uint32(port),
}
}
func NewUDPAddrFromLH4(ipp *Ip4AndPort) *udp.Addr { func NewUDPAddrFromLH4(ipp *Ip4AndPort) *udp.Addr {
ip := ipp.Ip ip := ipp.Ip
return udp.NewAddr( return udp.NewAddr(

View File

@ -1,6 +1,7 @@
package nebula package nebula
import ( import (
"context"
"fmt" "fmt"
"net" "net"
"testing" "testing"
@ -53,14 +54,14 @@ func Test_lhStaticMapping(t *testing.T) {
c := config.NewC(l) c := config.NewC(l)
c.Settings["lighthouse"] = map[interface{}]interface{}{"hosts": []interface{}{lh1}} c.Settings["lighthouse"] = map[interface{}]interface{}{"hosts": []interface{}{lh1}}
c.Settings["static_host_map"] = map[interface{}]interface{}{lh1: []interface{}{"1.1.1.1:4242"}} c.Settings["static_host_map"] = map[interface{}]interface{}{lh1: []interface{}{"1.1.1.1:4242"}}
_, err := NewLightHouseFromConfig(l, c, myVpnNet, nil, nil) _, err := NewLightHouseFromConfig(context.Background(), l, c, myVpnNet, nil, nil)
assert.Nil(t, err) assert.Nil(t, err)
lh2 := "10.128.0.3" lh2 := "10.128.0.3"
c = config.NewC(l) c = config.NewC(l)
c.Settings["lighthouse"] = map[interface{}]interface{}{"hosts": []interface{}{lh1, lh2}} c.Settings["lighthouse"] = map[interface{}]interface{}{"hosts": []interface{}{lh1, lh2}}
c.Settings["static_host_map"] = map[interface{}]interface{}{lh1: []interface{}{"100.1.1.1:4242"}} c.Settings["static_host_map"] = map[interface{}]interface{}{lh1: []interface{}{"100.1.1.1:4242"}}
_, err = NewLightHouseFromConfig(l, c, myVpnNet, nil, nil) _, err = NewLightHouseFromConfig(context.Background(), l, c, myVpnNet, nil, nil)
assert.EqualError(t, err, "lighthouse 10.128.0.3 does not have a static_host_map entry") assert.EqualError(t, err, "lighthouse 10.128.0.3 does not have a static_host_map entry")
} }
@ -69,14 +70,14 @@ func BenchmarkLighthouseHandleRequest(b *testing.B) {
_, myVpnNet, _ := net.ParseCIDR("10.128.0.1/0") _, myVpnNet, _ := net.ParseCIDR("10.128.0.1/0")
c := config.NewC(l) c := config.NewC(l)
lh, err := NewLightHouseFromConfig(l, c, myVpnNet, nil, nil) lh, err := NewLightHouseFromConfig(context.Background(), l, c, myVpnNet, nil, nil)
if !assert.NoError(b, err) { if !assert.NoError(b, err) {
b.Fatal() b.Fatal()
} }
hAddr := udp.NewAddrFromString("4.5.6.7:12345") hAddr := udp.NewAddrFromString("4.5.6.7:12345")
hAddr2 := udp.NewAddrFromString("4.5.6.7:12346") hAddr2 := udp.NewAddrFromString("4.5.6.7:12346")
lh.addrMap[3] = NewRemoteList() lh.addrMap[3] = NewRemoteList(nil)
lh.addrMap[3].unlockedSetV4( lh.addrMap[3].unlockedSetV4(
3, 3,
3, 3,
@ -89,7 +90,7 @@ func BenchmarkLighthouseHandleRequest(b *testing.B) {
rAddr := udp.NewAddrFromString("1.2.2.3:12345") rAddr := udp.NewAddrFromString("1.2.2.3:12345")
rAddr2 := udp.NewAddrFromString("1.2.2.3:12346") rAddr2 := udp.NewAddrFromString("1.2.2.3:12346")
lh.addrMap[2] = NewRemoteList() lh.addrMap[2] = NewRemoteList(nil)
lh.addrMap[2].unlockedSetV4( lh.addrMap[2].unlockedSetV4(
3, 3,
3, 3,
@ -162,7 +163,7 @@ func TestLighthouse_Memory(t *testing.T) {
c := config.NewC(l) c := config.NewC(l)
c.Settings["lighthouse"] = map[interface{}]interface{}{"am_lighthouse": true} c.Settings["lighthouse"] = map[interface{}]interface{}{"am_lighthouse": true}
c.Settings["listen"] = map[interface{}]interface{}{"port": 4242} c.Settings["listen"] = map[interface{}]interface{}{"port": 4242}
lh, err := NewLightHouseFromConfig(l, c, &net.IPNet{IP: net.IP{10, 128, 0, 1}, Mask: net.IPMask{255, 255, 255, 0}}, nil, nil) lh, err := NewLightHouseFromConfig(context.Background(), l, c, &net.IPNet{IP: net.IP{10, 128, 0, 1}, Mask: net.IPMask{255, 255, 255, 0}}, nil, nil)
assert.NoError(t, err) assert.NoError(t, err)
lhh := lh.NewRequestHandler() lhh := lh.NewRequestHandler()
@ -238,7 +239,7 @@ func TestLighthouse_reload(t *testing.T) {
c := config.NewC(l) c := config.NewC(l)
c.Settings["lighthouse"] = map[interface{}]interface{}{"am_lighthouse": true} c.Settings["lighthouse"] = map[interface{}]interface{}{"am_lighthouse": true}
c.Settings["listen"] = map[interface{}]interface{}{"port": 4242} c.Settings["listen"] = map[interface{}]interface{}{"port": 4242}
lh, err := NewLightHouseFromConfig(l, c, &net.IPNet{IP: net.IP{10, 128, 0, 1}, Mask: net.IPMask{255, 255, 255, 0}}, nil, nil) lh, err := NewLightHouseFromConfig(context.Background(), l, c, &net.IPNet{IP: net.IP{10, 128, 0, 1}, Mask: net.IPMask{255, 255, 255, 0}}, nil, nil)
assert.NoError(t, err) assert.NoError(t, err)
c.Settings["static_host_map"] = map[interface{}]interface{}{"10.128.0.2": []interface{}{"1.1.1.1:4242"}} c.Settings["static_host_map"] = map[interface{}]interface{}{"10.128.0.2": []interface{}{"1.1.1.1:4242"}}

View File

@ -226,7 +226,7 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
*/ */
punchy := NewPunchyFromConfig(l, c) punchy := NewPunchyFromConfig(l, c)
lightHouse, err := NewLightHouseFromConfig(l, c, tunCidr, udpConns[0], punchy) lightHouse, err := NewLightHouseFromConfig(ctx, l, c, tunCidr, udpConns[0], punchy)
switch { switch {
case errors.As(err, &util.ContextualError{}): case errors.As(err, &util.ContextualError{}):
return nil, err return nil, err

View File

@ -35,6 +35,7 @@ func NewDeviceFromConfig(c *config.C, l *logrus.Logger, tunCidr *net.IPNet, fd *
c.GetInt("tun.mtu", DefaultMTU), c.GetInt("tun.mtu", DefaultMTU),
routes, routes,
c.GetInt("tun.tx_queue", 500), c.GetInt("tun.tx_queue", 500),
c.GetBool("tun.use_system_route_table", false),
) )
default: default:
@ -46,6 +47,7 @@ func NewDeviceFromConfig(c *config.C, l *logrus.Logger, tunCidr *net.IPNet, fd *
routes, routes,
c.GetInt("tun.tx_queue", 500), c.GetInt("tun.tx_queue", 500),
routines > 1, routines > 1,
c.GetBool("tun.use_system_route_table", false),
) )
} }
} }

View File

@ -22,7 +22,7 @@ type tun struct {
l *logrus.Logger l *logrus.Logger
} }
func newTunFromFd(l *logrus.Logger, deviceFd int, cidr *net.IPNet, _ int, routes []Route, _ int) (*tun, error) { func newTunFromFd(l *logrus.Logger, deviceFd int, cidr *net.IPNet, _ int, routes []Route, _ int, _ bool) (*tun, error) {
routeTree, err := makeRouteTree(l, routes, false) routeTree, err := makeRouteTree(l, routes, false)
if err != nil { if err != nil {
return nil, err return nil, err
@ -41,7 +41,7 @@ func newTunFromFd(l *logrus.Logger, deviceFd int, cidr *net.IPNet, _ int, routes
}, nil }, nil
} }
func newTun(_ *logrus.Logger, _ string, _ *net.IPNet, _ int, _ []Route, _ int, _ bool) (*tun, error) { func newTun(_ *logrus.Logger, _ string, _ *net.IPNet, _ int, _ []Route, _ int, _ bool, _ bool) (*tun, error) {
return nil, fmt.Errorf("newTun not supported in Android") return nil, fmt.Errorf("newTun not supported in Android")
} }

View File

@ -77,7 +77,7 @@ type ifreqMTU struct {
pad [8]byte pad [8]byte
} }
func newTun(l *logrus.Logger, name string, cidr *net.IPNet, defaultMTU int, routes []Route, _ int, _ bool) (*tun, error) { func newTun(l *logrus.Logger, name string, cidr *net.IPNet, defaultMTU int, routes []Route, _ int, _ bool, _ bool) (*tun, error) {
routeTree, err := makeRouteTree(l, routes, false) routeTree, err := makeRouteTree(l, routes, false)
if err != nil { if err != nil {
return nil, err return nil, err
@ -170,7 +170,7 @@ func (t *tun) deviceBytes() (o [16]byte) {
return return
} }
func newTunFromFd(_ *logrus.Logger, _ int, _ *net.IPNet, _ int, _ []Route, _ int) (*tun, error) { func newTunFromFd(_ *logrus.Logger, _ int, _ *net.IPNet, _ int, _ []Route, _ int, _ bool) (*tun, error) {
return nil, fmt.Errorf("newTunFromFd not supported in Darwin") return nil, fmt.Errorf("newTunFromFd not supported in Darwin")
} }

View File

@ -38,11 +38,11 @@ func (t *tun) Close() error {
return nil return nil
} }
func newTunFromFd(_ *logrus.Logger, _ int, _ *net.IPNet, _ int, _ []Route, _ int) (*tun, error) { func newTunFromFd(_ *logrus.Logger, _ int, _ *net.IPNet, _ int, _ []Route, _ int, _ bool) (*tun, error) {
return nil, fmt.Errorf("newTunFromFd not supported in FreeBSD") return nil, fmt.Errorf("newTunFromFd not supported in FreeBSD")
} }
func newTun(l *logrus.Logger, deviceName string, cidr *net.IPNet, defaultMTU int, routes []Route, _ int, _ bool) (*tun, error) { func newTun(l *logrus.Logger, deviceName string, cidr *net.IPNet, defaultMTU int, routes []Route, _ int, _ bool, _ bool) (*tun, error) {
routeTree, err := makeRouteTree(l, routes, false) routeTree, err := makeRouteTree(l, routes, false)
if err != nil { if err != nil {
return nil, err return nil, err

View File

@ -23,11 +23,11 @@ type tun struct {
routeTree *cidr.Tree4 routeTree *cidr.Tree4
} }
func newTun(_ *logrus.Logger, _ string, _ *net.IPNet, _ int, _ []Route, _ int, _ bool) (*tun, error) { func newTun(_ *logrus.Logger, _ string, _ *net.IPNet, _ int, _ []Route, _ int, _ bool, _ bool) (*tun, error) {
return nil, fmt.Errorf("newTun not supported in iOS") return nil, fmt.Errorf("newTun not supported in iOS")
} }
func newTunFromFd(l *logrus.Logger, deviceFd int, cidr *net.IPNet, _ int, routes []Route, _ int) (*tun, error) { func newTunFromFd(l *logrus.Logger, deviceFd int, cidr *net.IPNet, _ int, routes []Route, _ int, _ bool) (*tun, error) {
routeTree, err := makeRouteTree(l, routes, false) routeTree, err := makeRouteTree(l, routes, false)
if err != nil { if err != nil {
return nil, err return nil, err

View File

@ -4,11 +4,13 @@
package overlay package overlay
import ( import (
"bytes"
"fmt" "fmt"
"io" "io"
"net" "net"
"os" "os"
"strings" "strings"
"sync/atomic"
"unsafe" "unsafe"
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
@ -26,8 +28,12 @@ type tun struct {
MaxMTU int MaxMTU int
DefaultMTU int DefaultMTU int
TXQueueLen int TXQueueLen int
Routes []Route Routes []Route
routeTree *cidr.Tree4 routeTree atomic.Pointer[cidr.Tree4]
routeChan chan struct{}
useSystemRoutes bool
l *logrus.Logger l *logrus.Logger
} }
@ -63,7 +69,7 @@ type ifreqQLEN struct {
pad [8]byte pad [8]byte
} }
func newTunFromFd(l *logrus.Logger, deviceFd int, cidr *net.IPNet, defaultMTU int, routes []Route, txQueueLen int) (*tun, error) { func newTunFromFd(l *logrus.Logger, deviceFd int, cidr *net.IPNet, defaultMTU int, routes []Route, txQueueLen int, useSystemRoutes bool) (*tun, error) {
routeTree, err := makeRouteTree(l, routes, true) routeTree, err := makeRouteTree(l, routes, true)
if err != nil { if err != nil {
return nil, err return nil, err
@ -71,7 +77,7 @@ func newTunFromFd(l *logrus.Logger, deviceFd int, cidr *net.IPNet, defaultMTU in
file := os.NewFile(uintptr(deviceFd), "/dev/net/tun") file := os.NewFile(uintptr(deviceFd), "/dev/net/tun")
return &tun{ t := &tun{
ReadWriteCloser: file, ReadWriteCloser: file,
fd: int(file.Fd()), fd: int(file.Fd()),
Device: "tun0", Device: "tun0",
@ -79,12 +85,14 @@ func newTunFromFd(l *logrus.Logger, deviceFd int, cidr *net.IPNet, defaultMTU in
DefaultMTU: defaultMTU, DefaultMTU: defaultMTU,
TXQueueLen: txQueueLen, TXQueueLen: txQueueLen,
Routes: routes, Routes: routes,
routeTree: routeTree, useSystemRoutes: useSystemRoutes,
l: l, l: l,
}, nil }
t.routeTree.Store(routeTree)
return t, nil
} }
func newTun(l *logrus.Logger, deviceName string, cidr *net.IPNet, defaultMTU int, routes []Route, txQueueLen int, multiqueue bool) (*tun, error) { func newTun(l *logrus.Logger, deviceName string, cidr *net.IPNet, defaultMTU int, routes []Route, txQueueLen int, multiqueue bool, useSystemRoutes bool) (*tun, error) {
fd, err := unix.Open("/dev/net/tun", os.O_RDWR, 0) fd, err := unix.Open("/dev/net/tun", os.O_RDWR, 0)
if err != nil { if err != nil {
return nil, err return nil, err
@ -119,7 +127,7 @@ func newTun(l *logrus.Logger, deviceName string, cidr *net.IPNet, defaultMTU int
return nil, err return nil, err
} }
return &tun{ t := &tun{
ReadWriteCloser: file, ReadWriteCloser: file,
fd: int(file.Fd()), fd: int(file.Fd()),
Device: name, Device: name,
@ -128,9 +136,11 @@ func newTun(l *logrus.Logger, deviceName string, cidr *net.IPNet, defaultMTU int
DefaultMTU: defaultMTU, DefaultMTU: defaultMTU,
TXQueueLen: txQueueLen, TXQueueLen: txQueueLen,
Routes: routes, Routes: routes,
routeTree: routeTree, useSystemRoutes: useSystemRoutes,
l: l, l: l,
}, nil }
t.routeTree.Store(routeTree)
return t, nil
} }
func (t *tun) NewMultiQueueReader() (io.ReadWriteCloser, error) { func (t *tun) NewMultiQueueReader() (io.ReadWriteCloser, error) {
@ -152,7 +162,7 @@ func (t *tun) NewMultiQueueReader() (io.ReadWriteCloser, error) {
} }
func (t *tun) RouteFor(ip iputil.VpnIp) iputil.VpnIp { func (t *tun) RouteFor(ip iputil.VpnIp) iputil.VpnIp {
r := t.routeTree.MostSpecificContains(ip) r := t.routeTree.Load().MostSpecificContains(ip)
if r != nil { if r != nil {
return r.(iputil.VpnIp) return r.(iputil.VpnIp)
} }
@ -183,16 +193,20 @@ func (t *tun) Write(b []byte) (int, error) {
} }
} }
func (t tun) deviceBytes() (o [16]byte) { func (t *tun) deviceBytes() (o [16]byte) {
for i, c := range t.Device { for i, c := range t.Device {
o[i] = byte(c) o[i] = byte(c)
} }
return return
} }
func (t tun) Activate() error { func (t *tun) Activate() error {
devName := t.deviceBytes() devName := t.deviceBytes()
if t.useSystemRoutes {
t.watchRoutes()
}
var addr, mask [4]byte var addr, mask [4]byte
copy(addr[:], t.cidr.IP.To4()) copy(addr[:], t.cidr.IP.To4())
@ -318,7 +332,7 @@ func (t *tun) Name() string {
return t.Device return t.Device
} }
func (t tun) advMSS(r Route) int { func (t *tun) advMSS(r Route) int {
mtu := r.MTU mtu := r.MTU
if r.MTU == 0 { if r.MTU == 0 {
mtu = t.DefaultMTU mtu = t.DefaultMTU
@ -330,3 +344,83 @@ func (t tun) advMSS(r Route) int {
} }
return 0 return 0
} }
func (t *tun) watchRoutes() {
rch := make(chan netlink.RouteUpdate)
doneChan := make(chan struct{})
if err := netlink.RouteSubscribe(rch, doneChan); err != nil {
t.l.WithError(err).Errorf("failed to subscribe to system route changes")
return
}
t.routeChan = doneChan
go func() {
for {
select {
case r := <-rch:
t.updateRoutes(r)
case <-doneChan:
// netlink.RouteSubscriber will close the rch for us
return
}
}
}()
}
func (t *tun) updateRoutes(r netlink.RouteUpdate) {
if r.Gw == nil {
// Not a gateway route, ignore
t.l.WithField("route", r).Debug("Ignoring route update, not a gateway route")
return
}
if !t.cidr.Contains(r.Gw) {
// Gateway isn't in our overlay network, ignore
t.l.WithField("route", r).Debug("Ignoring route update, not in our network")
return
}
if x := r.Dst.IP.To4(); x == nil {
// Nebula only handles ipv4 on the overlay currently
t.l.WithField("route", r).Debug("Ignoring route update, destination is not ipv4")
return
}
newTree := cidr.NewTree4()
if r.Type == unix.RTM_NEWROUTE {
for _, oldR := range t.routeTree.Load().List() {
newTree.AddCIDR(oldR.CIDR, oldR.Value)
}
t.l.WithField("destination", r.Dst).WithField("via", r.Gw).Info("Adding route")
newTree.AddCIDR(r.Dst, iputil.Ip2VpnIp(r.Gw))
} else {
gw := iputil.Ip2VpnIp(r.Gw)
for _, oldR := range t.routeTree.Load().List() {
if bytes.Equal(oldR.CIDR.IP, r.Dst.IP) && bytes.Equal(oldR.CIDR.Mask, r.Dst.Mask) && *oldR.Value != nil && (*oldR.Value).(iputil.VpnIp) == gw {
// This is the record to delete
t.l.WithField("destination", r.Dst).WithField("via", r.Gw).Info("Removing route")
continue
}
newTree.AddCIDR(oldR.CIDR, oldR.Value)
}
}
t.routeTree.Store(newTree)
}
func (t *tun) Close() error {
if t.routeChan != nil {
close(t.routeChan)
}
if t.ReadWriteCloser != nil {
t.ReadWriteCloser.Close()
}
return nil
}

View File

@ -7,19 +7,19 @@ import "testing"
var runAdvMSSTests = []struct { var runAdvMSSTests = []struct {
name string name string
tun tun tun *tun
r Route r Route
expected int expected int
}{ }{
// Standard case, default MTU is the device max MTU // Standard case, default MTU is the device max MTU
{"default", tun{DefaultMTU: 1440, MaxMTU: 1440}, Route{}, 0}, {"default", &tun{DefaultMTU: 1440, MaxMTU: 1440}, Route{}, 0},
{"default-min", tun{DefaultMTU: 1440, MaxMTU: 1440}, Route{MTU: 1440}, 0}, {"default-min", &tun{DefaultMTU: 1440, MaxMTU: 1440}, Route{MTU: 1440}, 0},
{"default-low", tun{DefaultMTU: 1440, MaxMTU: 1440}, Route{MTU: 1200}, 1160}, {"default-low", &tun{DefaultMTU: 1440, MaxMTU: 1440}, Route{MTU: 1200}, 1160},
// Case where we have a route MTU set higher than the default // Case where we have a route MTU set higher than the default
{"route", tun{DefaultMTU: 1440, MaxMTU: 8941}, Route{}, 1400}, {"route", &tun{DefaultMTU: 1440, MaxMTU: 8941}, Route{}, 1400},
{"route-min", tun{DefaultMTU: 1440, MaxMTU: 8941}, Route{MTU: 1440}, 1400}, {"route-min", &tun{DefaultMTU: 1440, MaxMTU: 8941}, Route{MTU: 1440}, 1400},
{"route-high", tun{DefaultMTU: 1440, MaxMTU: 8941}, Route{MTU: 8941}, 0}, {"route-high", &tun{DefaultMTU: 1440, MaxMTU: 8941}, Route{MTU: 8941}, 0},
} }
func TestTunAdvMSS(t *testing.T) { func TestTunAdvMSS(t *testing.T) {

View File

@ -25,7 +25,7 @@ type TestTun struct {
TxPackets chan []byte // Packets transmitted outside by nebula TxPackets chan []byte // Packets transmitted outside by nebula
} }
func newTun(l *logrus.Logger, deviceName string, cidr *net.IPNet, _ int, routes []Route, _ int, _ bool) (*TestTun, error) { func newTun(l *logrus.Logger, deviceName string, cidr *net.IPNet, _ int, routes []Route, _ int, _ bool, _ bool) (*TestTun, error) {
routeTree, err := makeRouteTree(l, routes, false) routeTree, err := makeRouteTree(l, routes, false)
if err != nil { if err != nil {
return nil, err return nil, err
@ -42,7 +42,7 @@ func newTun(l *logrus.Logger, deviceName string, cidr *net.IPNet, _ int, routes
}, nil }, nil
} }
func newTunFromFd(_ *logrus.Logger, _ int, _ *net.IPNet, _ int, _ []Route, _ int) (*TestTun, error) { func newTunFromFd(_ *logrus.Logger, _ int, _ *net.IPNet, _ int, _ []Route, _ int, _ bool) (*TestTun, error) {
return nil, fmt.Errorf("newTunFromFd not supported") return nil, fmt.Errorf("newTunFromFd not supported")
} }

View File

@ -14,11 +14,11 @@ import (
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
) )
func newTunFromFd(_ *logrus.Logger, _ int, _ *net.IPNet, _ int, _ []Route, _ int) (Device, error) { func newTunFromFd(_ *logrus.Logger, _ int, _ *net.IPNet, _ int, _ []Route, _ int, _ bool) (Device, error) {
return nil, fmt.Errorf("newTunFromFd not supported in Windows") return nil, fmt.Errorf("newTunFromFd not supported in Windows")
} }
func newTun(l *logrus.Logger, deviceName string, cidr *net.IPNet, defaultMTU int, routes []Route, _ int, _ bool) (Device, error) { func newTun(l *logrus.Logger, deviceName string, cidr *net.IPNet, defaultMTU int, routes []Route, _ int, _ bool, _ bool) (Device, error) {
useWintun := true useWintun := true
if err := checkWinTunExists(); err != nil { if err := checkWinTunExists(); err != nil {
l.WithError(err).Warn("Check Wintun driver failed, fallback to wintap driver") l.WithError(err).Warn("Check Wintun driver failed, fallback to wintap driver")

View File

@ -2,10 +2,16 @@ package nebula
import ( import (
"bytes" "bytes"
"context"
"net" "net"
"net/netip"
"sort" "sort"
"strconv"
"sync" "sync"
"sync/atomic"
"time"
"github.com/sirupsen/logrus"
"github.com/slackhq/nebula/iputil" "github.com/slackhq/nebula/iputil"
"github.com/slackhq/nebula/udp" "github.com/slackhq/nebula/udp"
) )
@ -55,6 +61,132 @@ type cacheV6 struct {
reported []*Ip6AndPort reported []*Ip6AndPort
} }
type hostnamePort struct {
name string
port uint16
}
type hostnamesResults struct {
hostnames []hostnamePort
network string
lookupTimeout time.Duration
stop chan struct{}
l *logrus.Logger
ips atomic.Pointer[map[netip.AddrPort]struct{}]
}
func NewHostnameResults(ctx context.Context, l *logrus.Logger, d time.Duration, network string, timeout time.Duration, hostPorts []string, onUpdate func()) (*hostnamesResults, error) {
r := &hostnamesResults{
hostnames: make([]hostnamePort, len(hostPorts)),
network: network,
lookupTimeout: timeout,
stop: make(chan (struct{})),
l: l,
}
// Fastrack IP addresses to ensure they're immediately available for use.
// DNS lookups for hostnames that aren't hardcoded IP's will happen in a background goroutine.
performBackgroundLookup := false
ips := map[netip.AddrPort]struct{}{}
for idx, hostPort := range hostPorts {
rIp, sPort, err := net.SplitHostPort(hostPort)
if err != nil {
return nil, err
}
iPort, err := strconv.Atoi(sPort)
if err != nil {
return nil, err
}
r.hostnames[idx] = hostnamePort{name: rIp, port: uint16(iPort)}
addr, err := netip.ParseAddr(rIp)
if err != nil {
// This address is a hostname, not an IP address
performBackgroundLookup = true
continue
}
// Save the IP address immediately
ips[netip.AddrPortFrom(addr, uint16(iPort))] = struct{}{}
}
r.ips.Store(&ips)
// Time for the DNS lookup goroutine
if performBackgroundLookup {
ticker := time.NewTicker(d)
go func() {
defer ticker.Stop()
for {
netipAddrs := map[netip.AddrPort]struct{}{}
for _, hostPort := range r.hostnames {
timeoutCtx, timeoutCancel := context.WithTimeout(ctx, r.lookupTimeout)
addrs, err := net.DefaultResolver.LookupNetIP(timeoutCtx, r.network, hostPort.name)
timeoutCancel()
if err != nil {
l.WithFields(logrus.Fields{"hostname": hostPort.name, "network": r.network}).WithError(err).Error("DNS resolution failed for static_map host")
continue
}
for _, a := range addrs {
netipAddrs[netip.AddrPortFrom(a, hostPort.port)] = struct{}{}
}
}
origSet := r.ips.Load()
different := false
for a := range *origSet {
if _, ok := netipAddrs[a]; !ok {
different = true
break
}
}
if !different {
for a := range netipAddrs {
if _, ok := (*origSet)[a]; !ok {
different = true
break
}
}
}
if different {
l.WithFields(logrus.Fields{"origSet": origSet, "newSet": netipAddrs}).Info("DNS results changed for host list")
r.ips.Store(&netipAddrs)
onUpdate()
}
select {
case <-ctx.Done():
return
case <-r.stop:
return
case <-ticker.C:
continue
}
}
}()
}
return r, nil
}
func (hr *hostnamesResults) Cancel() {
if hr != nil {
hr.stop <- struct{}{}
}
}
func (hr *hostnamesResults) GetIPs() []netip.AddrPort {
var retSlice []netip.AddrPort
if hr != nil {
p := hr.ips.Load()
if p != nil {
for k := range *p {
retSlice = append(retSlice, k)
}
}
}
return retSlice
}
// RemoteList is a unifying concept for lighthouse servers and clients as well as hostinfos. // RemoteList is a unifying concept for lighthouse servers and clients as well as hostinfos.
// It serves as a local cache of query replies, host update notifications, and locally learned addresses // It serves as a local cache of query replies, host update notifications, and locally learned addresses
type RemoteList struct { type RemoteList struct {
@ -72,6 +204,9 @@ type RemoteList struct {
// For learned addresses, this is the vpnIp that sent the packet // For learned addresses, this is the vpnIp that sent the packet
cache map[iputil.VpnIp]*cache cache map[iputil.VpnIp]*cache
hr *hostnamesResults
shouldAdd func(netip.Addr) bool
// This is a list of remotes that we have tried to handshake with and have returned from the wrong vpn ip. // This is a list of remotes that we have tried to handshake with and have returned from the wrong vpn ip.
// They should not be tried again during a handshake // They should not be tried again during a handshake
badRemotes []*udp.Addr badRemotes []*udp.Addr
@ -81,14 +216,21 @@ type RemoteList struct {
} }
// NewRemoteList creates a new empty RemoteList // NewRemoteList creates a new empty RemoteList
func NewRemoteList() *RemoteList { func NewRemoteList(shouldAdd func(netip.Addr) bool) *RemoteList {
return &RemoteList{ return &RemoteList{
addrs: make([]*udp.Addr, 0), addrs: make([]*udp.Addr, 0),
relays: make([]*iputil.VpnIp, 0), relays: make([]*iputil.VpnIp, 0),
cache: make(map[iputil.VpnIp]*cache), cache: make(map[iputil.VpnIp]*cache),
shouldAdd: shouldAdd,
} }
} }
func (r *RemoteList) unlockedSetHostnamesResults(hr *hostnamesResults) {
// Cancel any existing hostnamesResults DNS goroutine to release resources
r.hr.Cancel()
r.hr = hr
}
// Len locks and reports the size of the deduplicated address list // Len locks and reports the size of the deduplicated address list
// The deduplication work may need to occur here, so you must pass preferredRanges // The deduplication work may need to occur here, so you must pass preferredRanges
func (r *RemoteList) Len(preferredRanges []*net.IPNet) int { func (r *RemoteList) Len(preferredRanges []*net.IPNet) int {
@ -437,6 +579,26 @@ func (r *RemoteList) unlockedCollect() {
} }
} }
dnsAddrs := r.hr.GetIPs()
for _, addr := range dnsAddrs {
if r.shouldAdd == nil || r.shouldAdd(addr.Addr()) {
switch {
case addr.Addr().Is4():
v4 := addr.Addr().As4()
addrs = append(addrs, &udp.Addr{
IP: v4[:],
Port: addr.Port(),
})
case addr.Addr().Is6():
v6 := addr.Addr().As16()
addrs = append(addrs, &udp.Addr{
IP: v6[:],
Port: addr.Port(),
})
}
}
}
r.addrs = addrs r.addrs = addrs
r.relays = relays r.relays = relays

View File

@ -9,7 +9,7 @@ import (
) )
func TestRemoteList_Rebuild(t *testing.T) { func TestRemoteList_Rebuild(t *testing.T) {
rl := NewRemoteList() rl := NewRemoteList(nil)
rl.unlockedSetV4( rl.unlockedSetV4(
0, 0,
0, 0,
@ -102,7 +102,7 @@ func TestRemoteList_Rebuild(t *testing.T) {
} }
func BenchmarkFullRebuild(b *testing.B) { func BenchmarkFullRebuild(b *testing.B) {
rl := NewRemoteList() rl := NewRemoteList(nil)
rl.unlockedSetV4( rl.unlockedSetV4(
0, 0,
0, 0,
@ -167,7 +167,7 @@ func BenchmarkFullRebuild(b *testing.B) {
} }
func BenchmarkSortRebuild(b *testing.B) { func BenchmarkSortRebuild(b *testing.B) {
rl := NewRemoteList() rl := NewRemoteList(nil)
rl.unlockedSetV4( rl.unlockedSetV4(
0, 0,
0, 0,