Merge remote-tracking branch 'origin/master' into mutex-debug

This commit is contained in:
Wade Simmons 2023-05-09 11:42:05 -04:00
commit a83f0ca470
23 changed files with 689 additions and 167 deletions

View File

@ -118,6 +118,17 @@ To build nebula for a specific platform (ex, Windows):
See the [Makefile](Makefile) for more details on build targets
## Curve P256 and BoringCrypto
The default curve used for cryptographic handshakes and signatures is Curve25519. This is the recommended setting for most users. If your deployment has certain compliance requirements, you have the option of creating your CA using `nebula-cert ca -curve P256` to use NIST Curve P256. The CA will then sign certificates using ECDSA P256, and any hosts using these certificates will use P256 for ECDH handshakes.
In addition, Nebula can be built using the [BoringCrypto GOEXPERIMENT](https://github.com/golang/go/blob/go1.20/src/crypto/internal/boring/README.md) by running either of the following make targets:
make bin-boringcrypto
make release-boringcrypto
This is not the recommended default deployment, but may be useful based on your compliance requirements.
## Credits
Nebula was created at Slack Technologies, Inc by Nate Brown and Ryan Huber, with contributions from Oliver Fross, Alan Lam, Wade Simmons, and Lining Wang.

12
SECURITY.md Normal file
View File

@ -0,0 +1,12 @@
Security Policy
===============
Reporting a Vulnerability
-------------------------
If you believe you have found a security vulnerability with Nebula, please let
us know right away. We will investigate all reports and do our best to quickly
fix valid issues.
You can submit your report on [HackerOne](https://hackerone.com/slack) and our
security team will respond as soon as possible.

View File

@ -13,8 +13,14 @@ type Node struct {
value interface{}
}
type entry struct {
CIDR *net.IPNet
Value *interface{}
}
type Tree4 struct {
root *Node
list []entry
}
const (
@ -24,6 +30,7 @@ const (
func NewTree4() *Tree4 {
tree := new(Tree4)
tree.root = &Node{}
tree.list = []entry{}
return tree
}
@ -53,6 +60,15 @@ func (tree *Tree4) AddCIDR(cidr *net.IPNet, val interface{}) {
// We already have this range so update the value
if next != nil {
addCIDR := cidr.String()
for i, v := range tree.list {
if addCIDR == v.CIDR.String() {
tree.list = append(tree.list[:i], tree.list[i+1:]...)
break
}
}
tree.list = append(tree.list, entry{CIDR: cidr, Value: &val})
node.value = val
return
}
@ -74,9 +90,10 @@ func (tree *Tree4) AddCIDR(cidr *net.IPNet, val interface{}) {
// Final node marks our cidr, set the value
node.value = val
tree.list = append(tree.list, entry{CIDR: cidr, Value: &val})
}
// Finds the first match, which may be the least specific
// Contains finds the first match, which may be the least specific
func (tree *Tree4) Contains(ip iputil.VpnIp) (value interface{}) {
bit := startbit
node := tree.root
@ -99,7 +116,7 @@ func (tree *Tree4) Contains(ip iputil.VpnIp) (value interface{}) {
return value
}
// Finds the most specific match
// MostSpecificContains finds the most specific match
func (tree *Tree4) MostSpecificContains(ip iputil.VpnIp) (value interface{}) {
bit := startbit
node := tree.root
@ -121,7 +138,7 @@ func (tree *Tree4) MostSpecificContains(ip iputil.VpnIp) (value interface{}) {
return value
}
// Finds the most specific match
// Match finds the most specific match
func (tree *Tree4) Match(ip iputil.VpnIp) (value interface{}) {
bit := startbit
node := tree.root
@ -143,3 +160,8 @@ func (tree *Tree4) Match(ip iputil.VpnIp) (value interface{}) {
}
return value
}
// List will return all CIDRs and their current values. Do not modify the contents!
func (tree *Tree4) List() []entry {
return tree.list
}

View File

@ -8,6 +8,20 @@ import (
"github.com/stretchr/testify/assert"
)
func TestCIDRTree_List(t *testing.T) {
tree := NewTree4()
tree.AddCIDR(Parse("1.0.0.0/16"), "1")
tree.AddCIDR(Parse("1.0.0.0/8"), "2")
tree.AddCIDR(Parse("1.0.0.0/16"), "3")
tree.AddCIDR(Parse("1.0.0.0/16"), "4")
list := tree.List()
assert.Len(t, list, 2)
assert.Equal(t, "1.0.0.0/8", list[0].CIDR.String())
assert.Equal(t, "2", *list[0].Value)
assert.Equal(t, "1.0.0.0/16", list[1].CIDR.String())
assert.Equal(t, "4", *list[1].Value)
}
func TestCIDRTree_Contains(t *testing.T) {
tree := NewTree4()
tree.AddCIDR(Parse("1.0.0.0/8"), "1")

View File

@ -47,7 +47,7 @@ func TestControl_GetHostInfoByVpnIp(t *testing.T) {
Signature: []byte{1, 2, 1, 2, 1, 3},
}
remotes := NewRemoteList()
remotes := NewRemoteList(nil)
remotes.unlockedPrependV4(0, NewIp4AndPort(remote1.IP, uint32(remote1.Port)))
remotes.unlockedPrependV6(0, NewIp6AndPort(remote2.IP, uint32(remote2.Port)))
hm.Add(iputil.Ip2VpnIp(ipNet.IP), &HostInfo{

View File

@ -223,6 +223,10 @@ tun:
# metric: 100
# install: true
# On linux only, set to true to manage unsafe routes directly on the system route table with gateway routes instead of
# in nebula configuration files. Default false, not reloadable.
#use_system_route_table: false
# TODO
# Configure logging level
logging:
@ -301,7 +305,8 @@ firewall:
# host: `any` or a literal hostname, ie `test-host`
# group: `any` or a literal group name, ie `default-group`
# groups: Same as group but accepts a list of values. Multiple values are AND'd together and a certificate would have to contain all groups to pass
# cidr: a CIDR, `0.0.0.0/0` is any.
# cidr: a remote CIDR, `0.0.0.0/0` is any.
# local_cidr: a local CIDR, `0.0.0.0/0` is any. This could be used to filter destinations when using unsafe_routes.
# ca_name: An issuing CA name
# ca_sha: An issuing CA shasum

View File

@ -25,7 +25,7 @@ const tcpACK = 0x10
const tcpFIN = 0x01
type FirewallInterface interface {
AddRule(incoming bool, proto uint8, startPort int32, endPort int32, groups []string, host string, ip *net.IPNet, caName string, caSha string) error
AddRule(incoming bool, proto uint8, startPort int32, endPort int32, groups []string, host string, ip *net.IPNet, localIp *net.IPNet, caName string, caSha string) error
}
type conn struct {
@ -106,11 +106,12 @@ type FirewallCA struct {
}
type FirewallRule struct {
// Any makes Hosts, Groups, and CIDR irrelevant
// Any makes Hosts, Groups, CIDR and LocalCIDR irrelevant
Any bool
Hosts map[string]struct{}
Groups [][]string
CIDR *cidr.Tree4
LocalCIDR *cidr.Tree4
}
// Even though ports are uint16, int32 maps are faster for lookup
@ -218,18 +219,22 @@ func NewFirewallFromConfig(l *logrus.Logger, nc *cert.NebulaCertificate, c *conf
}
// AddRule properly creates the in memory rule structure for a firewall table.
func (f *Firewall) AddRule(incoming bool, proto uint8, startPort int32, endPort int32, groups []string, host string, ip *net.IPNet, caName string, caSha string) error {
func (f *Firewall) AddRule(incoming bool, proto uint8, startPort int32, endPort int32, groups []string, host string, ip *net.IPNet, localIp *net.IPNet, caName string, caSha string) error {
// Under gomobile, stringing a nil pointer with fmt causes an abort in debug mode for iOS
// https://github.com/golang/go/issues/14131
sIp := ""
if ip != nil {
sIp = ip.String()
}
lIp := ""
if localIp != nil {
lIp = localIp.String()
}
// We need this rule string because we generate a hash. Removing this will break firewall reload.
ruleString := fmt.Sprintf(
"incoming: %v, proto: %v, startPort: %v, endPort: %v, groups: %v, host: %v, ip: %v, caName: %v, caSha: %s",
incoming, proto, startPort, endPort, groups, host, sIp, caName, caSha,
"incoming: %v, proto: %v, startPort: %v, endPort: %v, groups: %v, host: %v, ip: %v, localIp: %v, caName: %v, caSha: %s",
incoming, proto, startPort, endPort, groups, host, sIp, lIp, caName, caSha,
)
f.rules += ruleString + "\n"
@ -237,7 +242,7 @@ func (f *Firewall) AddRule(incoming bool, proto uint8, startPort int32, endPort
if !incoming {
direction = "outgoing"
}
f.l.WithField("firewallRule", m{"direction": direction, "proto": proto, "startPort": startPort, "endPort": endPort, "groups": groups, "host": host, "ip": sIp, "caName": caName, "caSha": caSha}).
f.l.WithField("firewallRule", m{"direction": direction, "proto": proto, "startPort": startPort, "endPort": endPort, "groups": groups, "host": host, "ip": sIp, "localIp": lIp, "caName": caName, "caSha": caSha}).
Info("Firewall rule added")
var (
@ -264,7 +269,7 @@ func (f *Firewall) AddRule(incoming bool, proto uint8, startPort int32, endPort
return fmt.Errorf("unknown protocol %v", proto)
}
return fp.addRule(startPort, endPort, groups, host, ip, caName, caSha)
return fp.addRule(startPort, endPort, groups, host, ip, localIp, caName, caSha)
}
// GetRuleHash returns a hash representation of all inbound and outbound rules
@ -302,8 +307,8 @@ func AddFirewallRulesFromConfig(l *logrus.Logger, inbound bool, c *config.C, fw
return fmt.Errorf("%s rule #%v; only one of port or code should be provided", table, i)
}
if r.Host == "" && len(r.Groups) == 0 && r.Group == "" && r.Cidr == "" && r.CAName == "" && r.CASha == "" {
return fmt.Errorf("%s rule #%v; at least one of host, group, cidr, ca_name, or ca_sha must be provided", table, i)
if r.Host == "" && len(r.Groups) == 0 && r.Group == "" && r.Cidr == "" && r.LocalCidr == "" && r.CAName == "" && r.CASha == "" {
return fmt.Errorf("%s rule #%v; at least one of host, group, cidr, local_cidr, ca_name, or ca_sha must be provided", table, i)
}
if len(r.Groups) > 0 {
@ -355,7 +360,15 @@ func AddFirewallRulesFromConfig(l *logrus.Logger, inbound bool, c *config.C, fw
}
}
err = fw.AddRule(inbound, proto, startPort, endPort, groups, r.Host, cidr, r.CAName, r.CASha)
var localCidr *net.IPNet
if r.LocalCidr != "" {
_, localCidr, err = net.ParseCIDR(r.LocalCidr)
if err != nil {
return fmt.Errorf("%s rule #%v; local_cidr did not parse; %s", table, i, err)
}
}
err = fw.AddRule(inbound, proto, startPort, endPort, groups, r.Host, cidr, localCidr, r.CAName, r.CASha)
if err != nil {
return fmt.Errorf("%s rule #%v; `%s`", table, i, err)
}
@ -595,7 +608,7 @@ func (ft *FirewallTable) match(p firewall.Packet, incoming bool, c *cert.NebulaC
return false
}
func (fp firewallPort) addRule(startPort int32, endPort int32, groups []string, host string, ip *net.IPNet, caName string, caSha string) error {
func (fp firewallPort) addRule(startPort int32, endPort int32, groups []string, host string, ip *net.IPNet, localIp *net.IPNet, caName string, caSha string) error {
if startPort > endPort {
return fmt.Errorf("start port was lower than end port")
}
@ -608,7 +621,7 @@ func (fp firewallPort) addRule(startPort int32, endPort int32, groups []string,
}
}
if err := fp[i].addRule(groups, host, ip, caName, caSha); err != nil {
if err := fp[i].addRule(groups, host, ip, localIp, caName, caSha); err != nil {
return err
}
}
@ -639,12 +652,13 @@ func (fp firewallPort) match(p firewall.Packet, incoming bool, c *cert.NebulaCer
return fp[firewall.PortAny].match(p, c, caPool)
}
func (fc *FirewallCA) addRule(groups []string, host string, ip *net.IPNet, caName, caSha string) error {
func (fc *FirewallCA) addRule(groups []string, host string, ip, localIp *net.IPNet, caName, caSha string) error {
fr := func() *FirewallRule {
return &FirewallRule{
Hosts: make(map[string]struct{}),
Groups: make([][]string, 0),
CIDR: cidr.NewTree4(),
LocalCIDR: cidr.NewTree4(),
}
}
@ -653,14 +667,14 @@ func (fc *FirewallCA) addRule(groups []string, host string, ip *net.IPNet, caNam
fc.Any = fr()
}
return fc.Any.addRule(groups, host, ip)
return fc.Any.addRule(groups, host, ip, localIp)
}
if caSha != "" {
if _, ok := fc.CAShas[caSha]; !ok {
fc.CAShas[caSha] = fr()
}
err := fc.CAShas[caSha].addRule(groups, host, ip)
err := fc.CAShas[caSha].addRule(groups, host, ip, localIp)
if err != nil {
return err
}
@ -670,7 +684,7 @@ func (fc *FirewallCA) addRule(groups []string, host string, ip *net.IPNet, caNam
if _, ok := fc.CANames[caName]; !ok {
fc.CANames[caName] = fr()
}
err := fc.CANames[caName].addRule(groups, host, ip)
err := fc.CANames[caName].addRule(groups, host, ip, localIp)
if err != nil {
return err
}
@ -702,17 +716,18 @@ func (fc *FirewallCA) match(p firewall.Packet, c *cert.NebulaCertificate, caPool
return fc.CANames[s.Details.Name].match(p, c)
}
func (fr *FirewallRule) addRule(groups []string, host string, ip *net.IPNet) error {
func (fr *FirewallRule) addRule(groups []string, host string, ip *net.IPNet, localIp *net.IPNet) error {
if fr.Any {
return nil
}
if fr.isAny(groups, host, ip) {
if fr.isAny(groups, host, ip, localIp) {
fr.Any = true
// If it's any we need to wipe out any pre-existing rules to save on memory
fr.Groups = make([][]string, 0)
fr.Hosts = make(map[string]struct{})
fr.CIDR = cidr.NewTree4()
fr.LocalCIDR = cidr.NewTree4()
} else {
if len(groups) > 0 {
fr.Groups = append(fr.Groups, groups)
@ -725,13 +740,17 @@ func (fr *FirewallRule) addRule(groups []string, host string, ip *net.IPNet) err
if ip != nil {
fr.CIDR.AddCIDR(ip, struct{}{})
}
if localIp != nil {
fr.LocalCIDR.AddCIDR(localIp, struct{}{})
}
}
return nil
}
func (fr *FirewallRule) isAny(groups []string, host string, ip *net.IPNet) bool {
if len(groups) == 0 && host == "" && ip == nil {
func (fr *FirewallRule) isAny(groups []string, host string, ip, localIp *net.IPNet) bool {
if len(groups) == 0 && host == "" && ip == nil && localIp == nil {
return true
}
@ -749,6 +768,10 @@ func (fr *FirewallRule) isAny(groups []string, host string, ip *net.IPNet) bool
return true
}
if localIp != nil && localIp.Contains(net.IPv4(0, 0, 0, 0)) {
return true
}
return false
}
@ -790,6 +813,10 @@ func (fr *FirewallRule) match(p firewall.Packet, c *cert.NebulaCertificate) bool
return true
}
if fr.LocalCIDR != nil && fr.LocalCIDR.Contains(p.LocalIP) != nil {
return true
}
// No host, group, or cidr matched, bye bye
return false
}
@ -802,6 +829,7 @@ type rule struct {
Group string
Groups []string
Cidr string
LocalCidr string
CAName string
CASha string
}
@ -827,6 +855,7 @@ func convertRule(l *logrus.Logger, p interface{}, table string, i int) (rule, er
r.Proto = toString("proto", m)
r.Host = toString("host", m)
r.Cidr = toString("cidr", m)
r.LocalCidr = toString("local_cidr", m)
r.CAName = toString("ca_name", m)
r.CASha = toString("ca_sha", m)

View File

@ -69,67 +69,75 @@ func TestFirewall_AddRule(t *testing.T) {
_, ti, _ := net.ParseCIDR("1.2.3.4/32")
assert.Nil(t, fw.AddRule(true, firewall.ProtoTCP, 1, 1, []string{}, "", nil, "", ""))
assert.Nil(t, fw.AddRule(true, firewall.ProtoTCP, 1, 1, []string{}, "", nil, nil, "", ""))
// An empty rule is any
assert.True(t, fw.InRules.TCP[1].Any.Any)
assert.Empty(t, fw.InRules.TCP[1].Any.Groups)
assert.Empty(t, fw.InRules.TCP[1].Any.Hosts)
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
assert.Nil(t, fw.AddRule(true, firewall.ProtoUDP, 1, 1, []string{"g1"}, "", nil, "", ""))
assert.Nil(t, fw.AddRule(true, firewall.ProtoUDP, 1, 1, []string{"g1"}, "", nil, nil, "", ""))
assert.False(t, fw.InRules.UDP[1].Any.Any)
assert.Contains(t, fw.InRules.UDP[1].Any.Groups[0], "g1")
assert.Empty(t, fw.InRules.UDP[1].Any.Hosts)
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
assert.Nil(t, fw.AddRule(true, firewall.ProtoICMP, 1, 1, []string{}, "h1", nil, "", ""))
assert.Nil(t, fw.AddRule(true, firewall.ProtoICMP, 1, 1, []string{}, "h1", nil, nil, "", ""))
assert.False(t, fw.InRules.ICMP[1].Any.Any)
assert.Empty(t, fw.InRules.ICMP[1].Any.Groups)
assert.Contains(t, fw.InRules.ICMP[1].Any.Hosts, "h1")
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
assert.Nil(t, fw.AddRule(false, firewall.ProtoAny, 1, 1, []string{}, "", ti, "", ""))
assert.Nil(t, fw.AddRule(false, firewall.ProtoAny, 1, 1, []string{}, "", ti, nil, "", ""))
assert.False(t, fw.OutRules.AnyProto[1].Any.Any)
assert.Empty(t, fw.OutRules.AnyProto[1].Any.Groups)
assert.Empty(t, fw.OutRules.AnyProto[1].Any.Hosts)
assert.NotNil(t, fw.OutRules.AnyProto[1].Any.CIDR.Match(iputil.Ip2VpnIp(ti.IP)))
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
assert.Nil(t, fw.AddRule(true, firewall.ProtoUDP, 1, 1, []string{"g1"}, "", nil, "ca-name", ""))
assert.Nil(t, fw.AddRule(false, firewall.ProtoAny, 1, 1, []string{}, "", nil, ti, "", ""))
assert.False(t, fw.OutRules.AnyProto[1].Any.Any)
assert.Empty(t, fw.OutRules.AnyProto[1].Any.Groups)
assert.Empty(t, fw.OutRules.AnyProto[1].Any.Hosts)
assert.NotNil(t, fw.OutRules.AnyProto[1].Any.LocalCIDR.Match(iputil.Ip2VpnIp(ti.IP)))
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
assert.Nil(t, fw.AddRule(true, firewall.ProtoUDP, 1, 1, []string{"g1"}, "", nil, nil, "ca-name", ""))
assert.Contains(t, fw.InRules.UDP[1].CANames, "ca-name")
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
assert.Nil(t, fw.AddRule(true, firewall.ProtoUDP, 1, 1, []string{"g1"}, "", nil, "", "ca-sha"))
assert.Nil(t, fw.AddRule(true, firewall.ProtoUDP, 1, 1, []string{"g1"}, "", nil, nil, "", "ca-sha"))
assert.Contains(t, fw.InRules.UDP[1].CAShas, "ca-sha")
// Set any and clear fields
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
assert.Nil(t, fw.AddRule(false, firewall.ProtoAny, 0, 0, []string{"g1", "g2"}, "h1", ti, "", ""))
assert.Nil(t, fw.AddRule(false, firewall.ProtoAny, 0, 0, []string{"g1", "g2"}, "h1", ti, ti, "", ""))
assert.Equal(t, []string{"g1", "g2"}, fw.OutRules.AnyProto[0].Any.Groups[0])
assert.Contains(t, fw.OutRules.AnyProto[0].Any.Hosts, "h1")
assert.NotNil(t, fw.OutRules.AnyProto[0].Any.CIDR.Match(iputil.Ip2VpnIp(ti.IP)))
assert.NotNil(t, fw.OutRules.AnyProto[0].Any.LocalCIDR.Match(iputil.Ip2VpnIp(ti.IP)))
// run twice just to make sure
//TODO: these ANY rules should clear the CA firewall portion
assert.Nil(t, fw.AddRule(false, firewall.ProtoAny, 0, 0, []string{"any"}, "", nil, "", ""))
assert.Nil(t, fw.AddRule(false, firewall.ProtoAny, 0, 0, []string{}, "any", nil, "", ""))
assert.Nil(t, fw.AddRule(false, firewall.ProtoAny, 0, 0, []string{"any"}, "", nil, nil, "", ""))
assert.Nil(t, fw.AddRule(false, firewall.ProtoAny, 0, 0, []string{}, "any", nil, nil, "", ""))
assert.True(t, fw.OutRules.AnyProto[0].Any.Any)
assert.Empty(t, fw.OutRules.AnyProto[0].Any.Groups)
assert.Empty(t, fw.OutRules.AnyProto[0].Any.Hosts)
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
assert.Nil(t, fw.AddRule(false, firewall.ProtoAny, 0, 0, []string{}, "any", nil, "", ""))
assert.Nil(t, fw.AddRule(false, firewall.ProtoAny, 0, 0, []string{}, "any", nil, nil, "", ""))
assert.True(t, fw.OutRules.AnyProto[0].Any.Any)
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
_, anyIp, _ := net.ParseCIDR("0.0.0.0/0")
assert.Nil(t, fw.AddRule(false, firewall.ProtoAny, 0, 0, []string{}, "", anyIp, "", ""))
assert.Nil(t, fw.AddRule(false, firewall.ProtoAny, 0, 0, []string{}, "", anyIp, nil, "", ""))
assert.True(t, fw.OutRules.AnyProto[0].Any.Any)
// Test error conditions
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
assert.Error(t, fw.AddRule(true, math.MaxUint8, 0, 0, []string{}, "", nil, "", ""))
assert.Error(t, fw.AddRule(true, firewall.ProtoAny, 10, 0, []string{}, "", nil, "", ""))
assert.Error(t, fw.AddRule(true, math.MaxUint8, 0, 0, []string{}, "", nil, nil, "", ""))
assert.Error(t, fw.AddRule(true, firewall.ProtoAny, 10, 0, []string{}, "", nil, nil, "", ""))
}
func TestFirewall_Drop(t *testing.T) {
@ -169,7 +177,7 @@ func TestFirewall_Drop(t *testing.T) {
h.CreateRemoteCIDR(&c)
fw := NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"any"}, "", nil, "", ""))
assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"any"}, "", nil, nil, "", ""))
cp := cert.NewCAPool()
// Drop outbound
@ -188,28 +196,28 @@ func TestFirewall_Drop(t *testing.T) {
// ensure signer doesn't get in the way of group checks
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", nil, "", "signer-shasum"))
assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", nil, "", "signer-shasum-bad"))
assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", nil, nil, "", "signer-shasum"))
assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", nil, nil, "", "signer-shasum-bad"))
assert.Equal(t, fw.Drop([]byte{}, p, true, &h, cp, nil), ErrNoMatchingRule)
// test caSha doesn't drop on match
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", nil, "", "signer-shasum-bad"))
assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", nil, "", "signer-shasum"))
assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", nil, nil, "", "signer-shasum-bad"))
assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", nil, nil, "", "signer-shasum"))
assert.NoError(t, fw.Drop([]byte{}, p, true, &h, cp, nil))
// ensure ca name doesn't get in the way of group checks
cp.CAs["signer-shasum"] = &cert.NebulaCertificate{Details: cert.NebulaCertificateDetails{Name: "ca-good"}}
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", nil, "ca-good", ""))
assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", nil, "ca-good-bad", ""))
assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", nil, nil, "ca-good", ""))
assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", nil, nil, "ca-good-bad", ""))
assert.Equal(t, fw.Drop([]byte{}, p, true, &h, cp, nil), ErrNoMatchingRule)
// test caName doesn't drop on match
cp.CAs["signer-shasum"] = &cert.NebulaCertificate{Details: cert.NebulaCertificateDetails{Name: "ca-good"}}
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", nil, "ca-good-bad", ""))
assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", nil, "ca-good", ""))
assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", nil, nil, "ca-good-bad", ""))
assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", nil, nil, "ca-good", ""))
assert.NoError(t, fw.Drop([]byte{}, p, true, &h, cp, nil))
}
@ -219,11 +227,11 @@ func BenchmarkFirewallTable_match(b *testing.B) {
}
_, n, _ := net.ParseCIDR("172.1.1.1/32")
_ = ft.TCP.addRule(10, 10, []string{"good-group"}, "good-host", n, "", "")
_ = ft.TCP.addRule(10, 10, []string{"good-group2"}, "good-host", n, "", "")
_ = ft.TCP.addRule(10, 10, []string{"good-group3"}, "good-host", n, "", "")
_ = ft.TCP.addRule(10, 10, []string{"good-group4"}, "good-host", n, "", "")
_ = ft.TCP.addRule(10, 10, []string{"good-group, good-group1"}, "good-host", n, "", "")
_ = ft.TCP.addRule(10, 10, []string{"good-group"}, "good-host", n, n, "", "")
_ = ft.TCP.addRule(10, 10, []string{"good-group2"}, "good-host", n, n, "", "")
_ = ft.TCP.addRule(10, 10, []string{"good-group3"}, "good-host", n, n, "", "")
_ = ft.TCP.addRule(10, 10, []string{"good-group4"}, "good-host", n, n, "", "")
_ = ft.TCP.addRule(10, 10, []string{"good-group, good-group1"}, "good-host", n, n, "", "")
cp := cert.NewCAPool()
b.Run("fail on proto", func(b *testing.B) {
@ -291,7 +299,20 @@ func BenchmarkFirewallTable_match(b *testing.B) {
}
})
_ = ft.TCP.addRule(0, 0, []string{"good-group"}, "good-host", n, "", "")
b.Run("pass on local ip", func(b *testing.B) {
ip := iputil.Ip2VpnIp(net.IPv4(172, 1, 1, 1))
c := &cert.NebulaCertificate{
Details: cert.NebulaCertificateDetails{
InvertedGroups: map[string]struct{}{"nope": {}},
Name: "good-host",
},
}
for n := 0; n < b.N; n++ {
ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 10, LocalIP: ip}, true, c, cp)
}
})
_ = ft.TCP.addRule(0, 0, []string{"good-group"}, "good-host", n, n, "", "")
b.Run("pass on ip with any port", func(b *testing.B) {
ip := iputil.Ip2VpnIp(net.IPv4(172, 1, 1, 1))
@ -305,6 +326,19 @@ func BenchmarkFirewallTable_match(b *testing.B) {
ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 100, RemoteIP: ip}, true, c, cp)
}
})
b.Run("pass on local ip with any port", func(b *testing.B) {
ip := iputil.Ip2VpnIp(net.IPv4(172, 1, 1, 1))
c := &cert.NebulaCertificate{
Details: cert.NebulaCertificateDetails{
InvertedGroups: map[string]struct{}{"nope": {}},
Name: "good-host",
},
}
for n := 0; n < b.N; n++ {
ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 100, LocalIP: ip}, true, c, cp)
}
})
}
func TestFirewall_Drop2(t *testing.T) {
@ -356,7 +390,7 @@ func TestFirewall_Drop2(t *testing.T) {
h1.CreateRemoteCIDR(&c1)
fw := NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group", "test-group"}, "", nil, "", ""))
assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group", "test-group"}, "", nil, nil, "", ""))
cp := cert.NewCAPool()
// h1/c1 lacks the proper groups
@ -438,8 +472,8 @@ func TestFirewall_Drop3(t *testing.T) {
h3.CreateRemoteCIDR(&c3)
fw := NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 1, 1, []string{}, "host1", nil, "", ""))
assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 1, 1, []string{}, "", nil, "", "signer-sha"))
assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 1, 1, []string{}, "host1", nil, nil, "", ""))
assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 1, 1, []string{}, "", nil, nil, "", "signer-sha"))
cp := cert.NewCAPool()
// c1 should pass because host match
@ -489,7 +523,7 @@ func TestFirewall_DropConntrackReload(t *testing.T) {
h.CreateRemoteCIDR(&c)
fw := NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"any"}, "", nil, "", ""))
assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"any"}, "", nil, nil, "", ""))
cp := cert.NewCAPool()
// Drop outbound
@ -502,7 +536,7 @@ func TestFirewall_DropConntrackReload(t *testing.T) {
oldFw := fw
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 10, 10, []string{"any"}, "", nil, "", ""))
assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 10, 10, []string{"any"}, "", nil, nil, "", ""))
fw.Conntrack = oldFw.Conntrack
fw.rulesVersion = oldFw.rulesVersion + 1
@ -511,7 +545,7 @@ func TestFirewall_DropConntrackReload(t *testing.T) {
oldFw = fw
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 11, 11, []string{"any"}, "", nil, "", ""))
assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 11, 11, []string{"any"}, "", nil, nil, "", ""))
fw.Conntrack = oldFw.Conntrack
fw.rulesVersion = oldFw.rulesVersion + 1
@ -653,7 +687,7 @@ func TestNewFirewallFromConfig(t *testing.T) {
conf = config.NewC(l)
conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{}}}
_, err = NewFirewallFromConfig(l, c, conf)
assert.EqualError(t, err, "firewall.outbound rule #0; at least one of host, group, cidr, ca_name, or ca_sha must be provided")
assert.EqualError(t, err, "firewall.outbound rule #0; at least one of host, group, cidr, local_cidr, ca_name, or ca_sha must be provided")
// Test code/port error
conf = config.NewC(l)
@ -677,6 +711,12 @@ func TestNewFirewallFromConfig(t *testing.T) {
_, err = NewFirewallFromConfig(l, c, conf)
assert.EqualError(t, err, "firewall.outbound rule #0; cidr did not parse; invalid CIDR address: testh")
// Test local_cidr parse error
conf = config.NewC(l)
conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"code": "1", "local_cidr": "testh", "proto": "any"}}}
_, err = NewFirewallFromConfig(l, c, conf)
assert.EqualError(t, err, "firewall.outbound rule #0; local_cidr did not parse; invalid CIDR address: testh")
// Test both group and groups
conf = config.NewC(l)
conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "group": "a", "groups": []string{"b", "c"}}}}
@ -691,63 +731,78 @@ func TestAddFirewallRulesFromConfig(t *testing.T) {
mf := &mockFirewall{}
conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "tcp", "host": "a"}}}
assert.Nil(t, AddFirewallRulesFromConfig(l, false, conf, mf))
assert.Equal(t, addRuleCall{incoming: false, proto: firewall.ProtoTCP, startPort: 1, endPort: 1, groups: nil, host: "a", ip: nil}, mf.lastCall)
assert.Equal(t, addRuleCall{incoming: false, proto: firewall.ProtoTCP, startPort: 1, endPort: 1, groups: nil, host: "a", ip: nil, localIp: nil}, mf.lastCall)
// Test adding udp rule
conf = config.NewC(l)
mf = &mockFirewall{}
conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "udp", "host": "a"}}}
assert.Nil(t, AddFirewallRulesFromConfig(l, false, conf, mf))
assert.Equal(t, addRuleCall{incoming: false, proto: firewall.ProtoUDP, startPort: 1, endPort: 1, groups: nil, host: "a", ip: nil}, mf.lastCall)
assert.Equal(t, addRuleCall{incoming: false, proto: firewall.ProtoUDP, startPort: 1, endPort: 1, groups: nil, host: "a", ip: nil, localIp: nil}, mf.lastCall)
// Test adding icmp rule
conf = config.NewC(l)
mf = &mockFirewall{}
conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "icmp", "host": "a"}}}
assert.Nil(t, AddFirewallRulesFromConfig(l, false, conf, mf))
assert.Equal(t, addRuleCall{incoming: false, proto: firewall.ProtoICMP, startPort: 1, endPort: 1, groups: nil, host: "a", ip: nil}, mf.lastCall)
assert.Equal(t, addRuleCall{incoming: false, proto: firewall.ProtoICMP, startPort: 1, endPort: 1, groups: nil, host: "a", ip: nil, localIp: nil}, mf.lastCall)
// Test adding any rule
conf = config.NewC(l)
mf = &mockFirewall{}
conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "host": "a"}}}
assert.Nil(t, AddFirewallRulesFromConfig(l, true, conf, mf))
assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, host: "a", ip: nil}, mf.lastCall)
assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, host: "a", ip: nil, localIp: nil}, mf.lastCall)
// Test adding rule with cidr
cidr := &net.IPNet{net.ParseIP("10.0.0.0").To4(), net.IPv4Mask(255, 0, 0, 0)}
conf = config.NewC(l)
mf = &mockFirewall{}
conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "cidr": cidr.String()}}}
assert.Nil(t, AddFirewallRulesFromConfig(l, true, conf, mf))
assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: cidr, localIp: nil}, mf.lastCall)
// Test adding rule with local_cidr
conf = config.NewC(l)
mf = &mockFirewall{}
conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "local_cidr": cidr.String()}}}
assert.Nil(t, AddFirewallRulesFromConfig(l, true, conf, mf))
assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: nil, localIp: cidr}, mf.lastCall)
// Test adding rule with ca_sha
conf = config.NewC(l)
mf = &mockFirewall{}
conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "ca_sha": "12312313123"}}}
assert.Nil(t, AddFirewallRulesFromConfig(l, true, conf, mf))
assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: nil, caSha: "12312313123"}, mf.lastCall)
assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: nil, localIp: nil, caSha: "12312313123"}, mf.lastCall)
// Test adding rule with ca_name
conf = config.NewC(l)
mf = &mockFirewall{}
conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "ca_name": "root01"}}}
assert.Nil(t, AddFirewallRulesFromConfig(l, true, conf, mf))
assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: nil, caName: "root01"}, mf.lastCall)
assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: nil, localIp: nil, caName: "root01"}, mf.lastCall)
// Test single group
conf = config.NewC(l)
mf = &mockFirewall{}
conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "group": "a"}}}
assert.Nil(t, AddFirewallRulesFromConfig(l, true, conf, mf))
assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: []string{"a"}, ip: nil}, mf.lastCall)
assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: []string{"a"}, ip: nil, localIp: nil}, mf.lastCall)
// Test single groups
conf = config.NewC(l)
mf = &mockFirewall{}
conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "groups": "a"}}}
assert.Nil(t, AddFirewallRulesFromConfig(l, true, conf, mf))
assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: []string{"a"}, ip: nil}, mf.lastCall)
assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: []string{"a"}, ip: nil, localIp: nil}, mf.lastCall)
// Test multiple AND groups
conf = config.NewC(l)
mf = &mockFirewall{}
conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "groups": []string{"a", "b"}}}}
assert.Nil(t, AddFirewallRulesFromConfig(l, true, conf, mf))
assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: []string{"a", "b"}, ip: nil}, mf.lastCall)
assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: []string{"a", "b"}, ip: nil, localIp: nil}, mf.lastCall)
// Test Add error
conf = config.NewC(l)
@ -892,6 +947,7 @@ type addRuleCall struct {
groups []string
host string
ip *net.IPNet
localIp *net.IPNet
caName string
caSha string
}
@ -901,7 +957,7 @@ type mockFirewall struct {
nextCallReturn error
}
func (mf *mockFirewall) AddRule(incoming bool, proto uint8, startPort int32, endPort int32, groups []string, host string, ip *net.IPNet, caName string, caSha string) error {
func (mf *mockFirewall) AddRule(incoming bool, proto uint8, startPort int32, endPort int32, groups []string, host string, ip *net.IPNet, localIp *net.IPNet, caName string, caSha string) error {
mf.lastCall = addRuleCall{
incoming: incoming,
proto: proto,
@ -910,6 +966,7 @@ func (mf *mockFirewall) AddRule(incoming bool, proto uint8, startPort int32, end
groups: groups,
host: host,
ip: ip,
localIp: localIp,
caName: caName,
caSha: caSha,
}

View File

@ -41,7 +41,7 @@ func Test_NewHandshakeManagerVpnIp(t *testing.T) {
assert.False(t, initCalled)
assert.Same(t, i, i2)
i.remotes = NewRemoteList()
i.remotes = NewRemoteList(nil)
i.HandshakeReady = true
// Adding something to pending should not affect the main hostmap

View File

@ -6,6 +6,7 @@ import (
"errors"
"fmt"
"net"
"net/netip"
"sync"
"sync/atomic"
"time"
@ -33,6 +34,7 @@ type netIpAndPort struct {
type LightHouse struct {
//TODO: We need a timer wheel to kick out vpnIps that haven't reported in a long time
sync.RWMutex //Because we concurrently read and write to our maps
ctx context.Context
amLighthouse bool
myVpnIp iputil.VpnIp
myVpnZeros iputil.VpnIp
@ -82,7 +84,7 @@ type LightHouse struct {
// NewLightHouseFromConfig will build a Lighthouse struct from the values provided in the config object
// addrMap should be nil unless this is during a config reload
func NewLightHouseFromConfig(l *logrus.Logger, c *config.C, myVpnNet *net.IPNet, pc *udp.Conn, p *Punchy) (*LightHouse, error) {
func NewLightHouseFromConfig(ctx context.Context, l *logrus.Logger, c *config.C, myVpnNet *net.IPNet, pc *udp.Conn, p *Punchy) (*LightHouse, error) {
amLighthouse := c.GetBool("lighthouse.am_lighthouse", false)
nebulaPort := uint32(c.GetInt("listen.port", 0))
if amLighthouse && nebulaPort == 0 {
@ -100,6 +102,7 @@ func NewLightHouseFromConfig(l *logrus.Logger, c *config.C, myVpnNet *net.IPNet,
ones, _ := myVpnNet.Mask.Size()
h := LightHouse{
ctx: ctx,
amLighthouse: amLighthouse,
myVpnIp: iputil.Ip2VpnIp(myVpnNet.IP),
myVpnZeros: iputil.VpnIp(32 - ones),
@ -258,7 +261,7 @@ func (lh *LightHouse) reload(c *config.C, initial bool) error {
}
//NOTE: many things will get much simpler when we combine static_host_map and lighthouse.hosts in config
if initial || c.HasChanged("static_host_map") {
if initial || c.HasChanged("static_host_map") || c.HasChanged("static_map.cadence") || c.HasChanged("static_map.network") || c.HasChanged("static_map.lookup_timeout") {
staticList := make(map[iputil.VpnIp]struct{})
err := lh.loadStaticMap(c, lh.myVpnNet, staticList)
if err != nil {
@ -268,9 +271,19 @@ func (lh *LightHouse) reload(c *config.C, initial bool) error {
lh.staticList.Store(&staticList)
if !initial {
//TODO: we should remove any remote list entries for static hosts that were removed/modified?
if c.HasChanged("static_host_map") {
lh.l.Info("static_host_map has changed")
}
if c.HasChanged("static_map.cadence") {
lh.l.Info("static_map.cadence has changed")
}
if c.HasChanged("static_map.network") {
lh.l.Info("static_map.network has changed")
}
if c.HasChanged("static_map.lookup_timeout") {
lh.l.Info("static_map.lookup_timeout has changed")
}
}
}
if initial || c.HasChanged("lighthouse.hosts") {
@ -344,7 +357,48 @@ func (lh *LightHouse) parseLighthouses(c *config.C, tunCidr *net.IPNet, lhMap ma
return nil
}
func getStaticMapCadence(c *config.C) (time.Duration, error) {
cadence := c.GetString("static_map.cadence", "30s")
d, err := time.ParseDuration(cadence)
if err != nil {
return 0, err
}
return d, nil
}
func getStaticMapLookupTimeout(c *config.C) (time.Duration, error) {
lookupTimeout := c.GetString("static_map.lookup_timeout", "250ms")
d, err := time.ParseDuration(lookupTimeout)
if err != nil {
return 0, err
}
return d, nil
}
func getStaticMapNetwork(c *config.C) (string, error) {
network := c.GetString("static_map.network", "ip4")
if network != "ip" && network != "ip4" && network != "ip6" {
return "", fmt.Errorf("static_map.network must be one of ip, ip4, or ip6")
}
return network, nil
}
func (lh *LightHouse) loadStaticMap(c *config.C, tunCidr *net.IPNet, staticList map[iputil.VpnIp]struct{}) error {
d, err := getStaticMapCadence(c)
if err != nil {
return err
}
network, err := getStaticMapNetwork(c)
if err != nil {
return err
}
lookup_timeout, err := getStaticMapLookupTimeout(c)
if err != nil {
return err
}
shm := c.GetMap("static_host_map", map[interface{}]interface{}{})
i := 0
@ -360,21 +414,17 @@ func (lh *LightHouse) loadStaticMap(c *config.C, tunCidr *net.IPNet, staticList
vpnIp := iputil.Ip2VpnIp(rip)
vals, ok := v.([]interface{})
if ok {
for _, v := range vals {
ip, port, err := udp.ParseIPAndPort(fmt.Sprintf("%v", v))
if err != nil {
return util.NewContextualError("Static host address could not be parsed", m{"vpnIp": vpnIp, "entry": i + 1}, err)
if !ok {
vals = []interface{}{v}
}
lh.addStaticRemote(vpnIp, udp.NewAddr(ip, port), staticList)
remoteAddrs := []string{}
for _, v := range vals {
remoteAddrs = append(remoteAddrs, fmt.Sprintf("%v", v))
}
} else {
ip, port, err := udp.ParseIPAndPort(fmt.Sprintf("%v", v))
err := lh.addStaticRemotes(i, d, network, lookup_timeout, vpnIp, remoteAddrs, staticList)
if err != nil {
return util.NewContextualError("Static host address could not be parsed", m{"vpnIp": vpnIp, "entry": i + 1}, err)
}
lh.addStaticRemote(vpnIp, udp.NewAddr(ip, port), staticList)
return err
}
i++
}
@ -482,30 +532,47 @@ func (lh *LightHouse) DeleteVpnIp(vpnIp iputil.VpnIp) {
// We are the owner because we don't want a lighthouse server to advertise for static hosts it was configured with
// And we don't want a lighthouse query reply to interfere with our learned cache if we are a client
// NOTE: this function should not interact with any hot path objects, like lh.staticList, the caller should handle it
func (lh *LightHouse) addStaticRemote(vpnIp iputil.VpnIp, toAddr *udp.Addr, staticList map[iputil.VpnIp]struct{}) {
func (lh *LightHouse) addStaticRemotes(i int, d time.Duration, network string, timeout time.Duration, vpnIp iputil.VpnIp, toAddrs []string, staticList map[iputil.VpnIp]struct{}) error {
lh.Lock()
am := lh.unlockedGetRemoteList(vpnIp)
am.Lock()
defer am.Unlock()
ctx := lh.ctx
lh.Unlock()
if ipv4 := toAddr.IP.To4(); ipv4 != nil {
to := NewIp4AndPort(ipv4, uint32(toAddr.Port))
hr, err := NewHostnameResults(ctx, lh.l, d, network, timeout, toAddrs, func() {
// This callback runs whenever the DNS hostname resolver finds a different set of IP's
// in its resolution for hostnames.
am.Lock()
defer am.Unlock()
am.shouldRebuild = true
})
if err != nil {
return util.NewContextualError("Static host address could not be parsed", m{"vpnIp": vpnIp, "entry": i + 1}, err)
}
am.unlockedSetHostnamesResults(hr)
for _, addrPort := range hr.GetIPs() {
switch {
case addrPort.Addr().Is4():
to := NewIp4AndPortFromNetIP(addrPort.Addr(), addrPort.Port())
if !lh.unlockedShouldAddV4(vpnIp, to) {
return
continue
}
am.unlockedPrependV4(lh.myVpnIp, to)
} else {
to := NewIp6AndPort(toAddr.IP, uint32(toAddr.Port))
case addrPort.Addr().Is6():
to := NewIp6AndPortFromNetIP(addrPort.Addr(), addrPort.Port())
if !lh.unlockedShouldAddV6(vpnIp, to) {
return
continue
}
am.unlockedPrependV6(lh.myVpnIp, to)
}
}
// Mark it as static in the caller provided map
staticList[vpnIp] = struct{}{}
return nil
}
// addCalculatedRemotes adds any calculated remotes based on the
@ -545,12 +612,42 @@ func (lh *LightHouse) addCalculatedRemotes(vpnIp iputil.VpnIp) bool {
func (lh *LightHouse) unlockedGetRemoteList(vpnIp iputil.VpnIp) *RemoteList {
am, ok := lh.addrMap[vpnIp]
if !ok {
am = NewRemoteList()
am = NewRemoteList(func(a netip.Addr) bool { return lh.shouldAdd(vpnIp, a) })
lh.addrMap[vpnIp] = am
}
return am
}
func (lh *LightHouse) shouldAdd(vpnIp iputil.VpnIp, to netip.Addr) bool {
switch {
case to.Is4():
ipBytes := to.As4()
ip := iputil.Ip2VpnIp(ipBytes[:])
allow := lh.GetRemoteAllowList().AllowIpV4(vpnIp, ip)
if lh.l.Level >= logrus.TraceLevel {
lh.l.WithField("remoteIp", vpnIp).WithField("allow", allow).Trace("remoteAllowList.Allow")
}
if !allow || ipMaskContains(lh.myVpnIp, lh.myVpnZeros, ip) {
return false
}
case to.Is6():
ipBytes := to.As16()
hi := binary.BigEndian.Uint64(ipBytes[:8])
lo := binary.BigEndian.Uint64(ipBytes[8:])
allow := lh.GetRemoteAllowList().AllowIpV6(vpnIp, hi, lo)
if lh.l.Level >= logrus.TraceLevel {
lh.l.WithField("remoteIp", to).WithField("allow", allow).Trace("remoteAllowList.Allow")
}
// We don't check our vpn network here because nebula does not support ipv6 on the inside
if !allow {
return false
}
}
return true
}
// unlockedShouldAddV4 checks if to is allowed by our allow list
func (lh *LightHouse) unlockedShouldAddV4(vpnIp iputil.VpnIp, to *Ip4AndPort) bool {
allow := lh.GetRemoteAllowList().AllowIpV4(vpnIp, iputil.VpnIp(to.Ip))
@ -609,6 +706,14 @@ func NewIp4AndPort(ip net.IP, port uint32) *Ip4AndPort {
return &ipp
}
func NewIp4AndPortFromNetIP(ip netip.Addr, port uint16) *Ip4AndPort {
v4Addr := ip.As4()
return &Ip4AndPort{
Ip: binary.BigEndian.Uint32(v4Addr[:]),
Port: uint32(port),
}
}
func NewIp6AndPort(ip net.IP, port uint32) *Ip6AndPort {
return &Ip6AndPort{
Hi: binary.BigEndian.Uint64(ip[:8]),
@ -617,6 +722,14 @@ func NewIp6AndPort(ip net.IP, port uint32) *Ip6AndPort {
}
}
func NewIp6AndPortFromNetIP(ip netip.Addr, port uint16) *Ip6AndPort {
ip6Addr := ip.As16()
return &Ip6AndPort{
Hi: binary.BigEndian.Uint64(ip6Addr[:8]),
Lo: binary.BigEndian.Uint64(ip6Addr[8:]),
Port: uint32(port),
}
}
func NewUDPAddrFromLH4(ipp *Ip4AndPort) *udp.Addr {
ip := ipp.Ip
return udp.NewAddr(

View File

@ -1,6 +1,7 @@
package nebula
import (
"context"
"fmt"
"net"
"testing"
@ -53,14 +54,14 @@ func Test_lhStaticMapping(t *testing.T) {
c := config.NewC(l)
c.Settings["lighthouse"] = map[interface{}]interface{}{"hosts": []interface{}{lh1}}
c.Settings["static_host_map"] = map[interface{}]interface{}{lh1: []interface{}{"1.1.1.1:4242"}}
_, err := NewLightHouseFromConfig(l, c, myVpnNet, nil, nil)
_, err := NewLightHouseFromConfig(context.Background(), l, c, myVpnNet, nil, nil)
assert.Nil(t, err)
lh2 := "10.128.0.3"
c = config.NewC(l)
c.Settings["lighthouse"] = map[interface{}]interface{}{"hosts": []interface{}{lh1, lh2}}
c.Settings["static_host_map"] = map[interface{}]interface{}{lh1: []interface{}{"100.1.1.1:4242"}}
_, err = NewLightHouseFromConfig(l, c, myVpnNet, nil, nil)
_, err = NewLightHouseFromConfig(context.Background(), l, c, myVpnNet, nil, nil)
assert.EqualError(t, err, "lighthouse 10.128.0.3 does not have a static_host_map entry")
}
@ -69,14 +70,14 @@ func BenchmarkLighthouseHandleRequest(b *testing.B) {
_, myVpnNet, _ := net.ParseCIDR("10.128.0.1/0")
c := config.NewC(l)
lh, err := NewLightHouseFromConfig(l, c, myVpnNet, nil, nil)
lh, err := NewLightHouseFromConfig(context.Background(), l, c, myVpnNet, nil, nil)
if !assert.NoError(b, err) {
b.Fatal()
}
hAddr := udp.NewAddrFromString("4.5.6.7:12345")
hAddr2 := udp.NewAddrFromString("4.5.6.7:12346")
lh.addrMap[3] = NewRemoteList()
lh.addrMap[3] = NewRemoteList(nil)
lh.addrMap[3].unlockedSetV4(
3,
3,
@ -89,7 +90,7 @@ func BenchmarkLighthouseHandleRequest(b *testing.B) {
rAddr := udp.NewAddrFromString("1.2.2.3:12345")
rAddr2 := udp.NewAddrFromString("1.2.2.3:12346")
lh.addrMap[2] = NewRemoteList()
lh.addrMap[2] = NewRemoteList(nil)
lh.addrMap[2].unlockedSetV4(
3,
3,
@ -162,7 +163,7 @@ func TestLighthouse_Memory(t *testing.T) {
c := config.NewC(l)
c.Settings["lighthouse"] = map[interface{}]interface{}{"am_lighthouse": true}
c.Settings["listen"] = map[interface{}]interface{}{"port": 4242}
lh, err := NewLightHouseFromConfig(l, c, &net.IPNet{IP: net.IP{10, 128, 0, 1}, Mask: net.IPMask{255, 255, 255, 0}}, nil, nil)
lh, err := NewLightHouseFromConfig(context.Background(), l, c, &net.IPNet{IP: net.IP{10, 128, 0, 1}, Mask: net.IPMask{255, 255, 255, 0}}, nil, nil)
assert.NoError(t, err)
lhh := lh.NewRequestHandler()
@ -238,7 +239,7 @@ func TestLighthouse_reload(t *testing.T) {
c := config.NewC(l)
c.Settings["lighthouse"] = map[interface{}]interface{}{"am_lighthouse": true}
c.Settings["listen"] = map[interface{}]interface{}{"port": 4242}
lh, err := NewLightHouseFromConfig(l, c, &net.IPNet{IP: net.IP{10, 128, 0, 1}, Mask: net.IPMask{255, 255, 255, 0}}, nil, nil)
lh, err := NewLightHouseFromConfig(context.Background(), l, c, &net.IPNet{IP: net.IP{10, 128, 0, 1}, Mask: net.IPMask{255, 255, 255, 0}}, nil, nil)
assert.NoError(t, err)
c.Settings["static_host_map"] = map[interface{}]interface{}{"10.128.0.2": []interface{}{"1.1.1.1:4242"}}

View File

@ -226,7 +226,7 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
*/
punchy := NewPunchyFromConfig(l, c)
lightHouse, err := NewLightHouseFromConfig(l, c, tunCidr, udpConns[0], punchy)
lightHouse, err := NewLightHouseFromConfig(ctx, l, c, tunCidr, udpConns[0], punchy)
switch {
case errors.As(err, &util.ContextualError{}):
return nil, err

View File

@ -35,6 +35,7 @@ func NewDeviceFromConfig(c *config.C, l *logrus.Logger, tunCidr *net.IPNet, fd *
c.GetInt("tun.mtu", DefaultMTU),
routes,
c.GetInt("tun.tx_queue", 500),
c.GetBool("tun.use_system_route_table", false),
)
default:
@ -46,6 +47,7 @@ func NewDeviceFromConfig(c *config.C, l *logrus.Logger, tunCidr *net.IPNet, fd *
routes,
c.GetInt("tun.tx_queue", 500),
routines > 1,
c.GetBool("tun.use_system_route_table", false),
)
}
}

View File

@ -22,7 +22,7 @@ type tun struct {
l *logrus.Logger
}
func newTunFromFd(l *logrus.Logger, deviceFd int, cidr *net.IPNet, _ int, routes []Route, _ int) (*tun, error) {
func newTunFromFd(l *logrus.Logger, deviceFd int, cidr *net.IPNet, _ int, routes []Route, _ int, _ bool) (*tun, error) {
routeTree, err := makeRouteTree(l, routes, false)
if err != nil {
return nil, err
@ -41,7 +41,7 @@ func newTunFromFd(l *logrus.Logger, deviceFd int, cidr *net.IPNet, _ int, routes
}, nil
}
func newTun(_ *logrus.Logger, _ string, _ *net.IPNet, _ int, _ []Route, _ int, _ bool) (*tun, error) {
func newTun(_ *logrus.Logger, _ string, _ *net.IPNet, _ int, _ []Route, _ int, _ bool, _ bool) (*tun, error) {
return nil, fmt.Errorf("newTun not supported in Android")
}

View File

@ -77,7 +77,7 @@ type ifreqMTU struct {
pad [8]byte
}
func newTun(l *logrus.Logger, name string, cidr *net.IPNet, defaultMTU int, routes []Route, _ int, _ bool) (*tun, error) {
func newTun(l *logrus.Logger, name string, cidr *net.IPNet, defaultMTU int, routes []Route, _ int, _ bool, _ bool) (*tun, error) {
routeTree, err := makeRouteTree(l, routes, false)
if err != nil {
return nil, err
@ -170,7 +170,7 @@ func (t *tun) deviceBytes() (o [16]byte) {
return
}
func newTunFromFd(_ *logrus.Logger, _ int, _ *net.IPNet, _ int, _ []Route, _ int) (*tun, error) {
func newTunFromFd(_ *logrus.Logger, _ int, _ *net.IPNet, _ int, _ []Route, _ int, _ bool) (*tun, error) {
return nil, fmt.Errorf("newTunFromFd not supported in Darwin")
}

View File

@ -38,11 +38,11 @@ func (t *tun) Close() error {
return nil
}
func newTunFromFd(_ *logrus.Logger, _ int, _ *net.IPNet, _ int, _ []Route, _ int) (*tun, error) {
func newTunFromFd(_ *logrus.Logger, _ int, _ *net.IPNet, _ int, _ []Route, _ int, _ bool) (*tun, error) {
return nil, fmt.Errorf("newTunFromFd not supported in FreeBSD")
}
func newTun(l *logrus.Logger, deviceName string, cidr *net.IPNet, defaultMTU int, routes []Route, _ int, _ bool) (*tun, error) {
func newTun(l *logrus.Logger, deviceName string, cidr *net.IPNet, defaultMTU int, routes []Route, _ int, _ bool, _ bool) (*tun, error) {
routeTree, err := makeRouteTree(l, routes, false)
if err != nil {
return nil, err

View File

@ -23,11 +23,11 @@ type tun struct {
routeTree *cidr.Tree4
}
func newTun(_ *logrus.Logger, _ string, _ *net.IPNet, _ int, _ []Route, _ int, _ bool) (*tun, error) {
func newTun(_ *logrus.Logger, _ string, _ *net.IPNet, _ int, _ []Route, _ int, _ bool, _ bool) (*tun, error) {
return nil, fmt.Errorf("newTun not supported in iOS")
}
func newTunFromFd(l *logrus.Logger, deviceFd int, cidr *net.IPNet, _ int, routes []Route, _ int) (*tun, error) {
func newTunFromFd(l *logrus.Logger, deviceFd int, cidr *net.IPNet, _ int, routes []Route, _ int, _ bool) (*tun, error) {
routeTree, err := makeRouteTree(l, routes, false)
if err != nil {
return nil, err

View File

@ -4,11 +4,13 @@
package overlay
import (
"bytes"
"fmt"
"io"
"net"
"os"
"strings"
"sync/atomic"
"unsafe"
"github.com/sirupsen/logrus"
@ -26,8 +28,12 @@ type tun struct {
MaxMTU int
DefaultMTU int
TXQueueLen int
Routes []Route
routeTree *cidr.Tree4
routeTree atomic.Pointer[cidr.Tree4]
routeChan chan struct{}
useSystemRoutes bool
l *logrus.Logger
}
@ -63,7 +69,7 @@ type ifreqQLEN struct {
pad [8]byte
}
func newTunFromFd(l *logrus.Logger, deviceFd int, cidr *net.IPNet, defaultMTU int, routes []Route, txQueueLen int) (*tun, error) {
func newTunFromFd(l *logrus.Logger, deviceFd int, cidr *net.IPNet, defaultMTU int, routes []Route, txQueueLen int, useSystemRoutes bool) (*tun, error) {
routeTree, err := makeRouteTree(l, routes, true)
if err != nil {
return nil, err
@ -71,7 +77,7 @@ func newTunFromFd(l *logrus.Logger, deviceFd int, cidr *net.IPNet, defaultMTU in
file := os.NewFile(uintptr(deviceFd), "/dev/net/tun")
return &tun{
t := &tun{
ReadWriteCloser: file,
fd: int(file.Fd()),
Device: "tun0",
@ -79,12 +85,14 @@ func newTunFromFd(l *logrus.Logger, deviceFd int, cidr *net.IPNet, defaultMTU in
DefaultMTU: defaultMTU,
TXQueueLen: txQueueLen,
Routes: routes,
routeTree: routeTree,
useSystemRoutes: useSystemRoutes,
l: l,
}, nil
}
t.routeTree.Store(routeTree)
return t, nil
}
func newTun(l *logrus.Logger, deviceName string, cidr *net.IPNet, defaultMTU int, routes []Route, txQueueLen int, multiqueue bool) (*tun, error) {
func newTun(l *logrus.Logger, deviceName string, cidr *net.IPNet, defaultMTU int, routes []Route, txQueueLen int, multiqueue bool, useSystemRoutes bool) (*tun, error) {
fd, err := unix.Open("/dev/net/tun", os.O_RDWR, 0)
if err != nil {
return nil, err
@ -119,7 +127,7 @@ func newTun(l *logrus.Logger, deviceName string, cidr *net.IPNet, defaultMTU int
return nil, err
}
return &tun{
t := &tun{
ReadWriteCloser: file,
fd: int(file.Fd()),
Device: name,
@ -128,9 +136,11 @@ func newTun(l *logrus.Logger, deviceName string, cidr *net.IPNet, defaultMTU int
DefaultMTU: defaultMTU,
TXQueueLen: txQueueLen,
Routes: routes,
routeTree: routeTree,
useSystemRoutes: useSystemRoutes,
l: l,
}, nil
}
t.routeTree.Store(routeTree)
return t, nil
}
func (t *tun) NewMultiQueueReader() (io.ReadWriteCloser, error) {
@ -152,7 +162,7 @@ func (t *tun) NewMultiQueueReader() (io.ReadWriteCloser, error) {
}
func (t *tun) RouteFor(ip iputil.VpnIp) iputil.VpnIp {
r := t.routeTree.MostSpecificContains(ip)
r := t.routeTree.Load().MostSpecificContains(ip)
if r != nil {
return r.(iputil.VpnIp)
}
@ -183,16 +193,20 @@ func (t *tun) Write(b []byte) (int, error) {
}
}
func (t tun) deviceBytes() (o [16]byte) {
func (t *tun) deviceBytes() (o [16]byte) {
for i, c := range t.Device {
o[i] = byte(c)
}
return
}
func (t tun) Activate() error {
func (t *tun) Activate() error {
devName := t.deviceBytes()
if t.useSystemRoutes {
t.watchRoutes()
}
var addr, mask [4]byte
copy(addr[:], t.cidr.IP.To4())
@ -318,7 +332,7 @@ func (t *tun) Name() string {
return t.Device
}
func (t tun) advMSS(r Route) int {
func (t *tun) advMSS(r Route) int {
mtu := r.MTU
if r.MTU == 0 {
mtu = t.DefaultMTU
@ -330,3 +344,83 @@ func (t tun) advMSS(r Route) int {
}
return 0
}
func (t *tun) watchRoutes() {
rch := make(chan netlink.RouteUpdate)
doneChan := make(chan struct{})
if err := netlink.RouteSubscribe(rch, doneChan); err != nil {
t.l.WithError(err).Errorf("failed to subscribe to system route changes")
return
}
t.routeChan = doneChan
go func() {
for {
select {
case r := <-rch:
t.updateRoutes(r)
case <-doneChan:
// netlink.RouteSubscriber will close the rch for us
return
}
}
}()
}
func (t *tun) updateRoutes(r netlink.RouteUpdate) {
if r.Gw == nil {
// Not a gateway route, ignore
t.l.WithField("route", r).Debug("Ignoring route update, not a gateway route")
return
}
if !t.cidr.Contains(r.Gw) {
// Gateway isn't in our overlay network, ignore
t.l.WithField("route", r).Debug("Ignoring route update, not in our network")
return
}
if x := r.Dst.IP.To4(); x == nil {
// Nebula only handles ipv4 on the overlay currently
t.l.WithField("route", r).Debug("Ignoring route update, destination is not ipv4")
return
}
newTree := cidr.NewTree4()
if r.Type == unix.RTM_NEWROUTE {
for _, oldR := range t.routeTree.Load().List() {
newTree.AddCIDR(oldR.CIDR, oldR.Value)
}
t.l.WithField("destination", r.Dst).WithField("via", r.Gw).Info("Adding route")
newTree.AddCIDR(r.Dst, iputil.Ip2VpnIp(r.Gw))
} else {
gw := iputil.Ip2VpnIp(r.Gw)
for _, oldR := range t.routeTree.Load().List() {
if bytes.Equal(oldR.CIDR.IP, r.Dst.IP) && bytes.Equal(oldR.CIDR.Mask, r.Dst.Mask) && *oldR.Value != nil && (*oldR.Value).(iputil.VpnIp) == gw {
// This is the record to delete
t.l.WithField("destination", r.Dst).WithField("via", r.Gw).Info("Removing route")
continue
}
newTree.AddCIDR(oldR.CIDR, oldR.Value)
}
}
t.routeTree.Store(newTree)
}
func (t *tun) Close() error {
if t.routeChan != nil {
close(t.routeChan)
}
if t.ReadWriteCloser != nil {
t.ReadWriteCloser.Close()
}
return nil
}

View File

@ -7,19 +7,19 @@ import "testing"
var runAdvMSSTests = []struct {
name string
tun tun
tun *tun
r Route
expected int
}{
// Standard case, default MTU is the device max MTU
{"default", tun{DefaultMTU: 1440, MaxMTU: 1440}, Route{}, 0},
{"default-min", tun{DefaultMTU: 1440, MaxMTU: 1440}, Route{MTU: 1440}, 0},
{"default-low", tun{DefaultMTU: 1440, MaxMTU: 1440}, Route{MTU: 1200}, 1160},
{"default", &tun{DefaultMTU: 1440, MaxMTU: 1440}, Route{}, 0},
{"default-min", &tun{DefaultMTU: 1440, MaxMTU: 1440}, Route{MTU: 1440}, 0},
{"default-low", &tun{DefaultMTU: 1440, MaxMTU: 1440}, Route{MTU: 1200}, 1160},
// Case where we have a route MTU set higher than the default
{"route", tun{DefaultMTU: 1440, MaxMTU: 8941}, Route{}, 1400},
{"route-min", tun{DefaultMTU: 1440, MaxMTU: 8941}, Route{MTU: 1440}, 1400},
{"route-high", tun{DefaultMTU: 1440, MaxMTU: 8941}, Route{MTU: 8941}, 0},
{"route", &tun{DefaultMTU: 1440, MaxMTU: 8941}, Route{}, 1400},
{"route-min", &tun{DefaultMTU: 1440, MaxMTU: 8941}, Route{MTU: 1440}, 1400},
{"route-high", &tun{DefaultMTU: 1440, MaxMTU: 8941}, Route{MTU: 8941}, 0},
}
func TestTunAdvMSS(t *testing.T) {

View File

@ -25,7 +25,7 @@ type TestTun struct {
TxPackets chan []byte // Packets transmitted outside by nebula
}
func newTun(l *logrus.Logger, deviceName string, cidr *net.IPNet, _ int, routes []Route, _ int, _ bool) (*TestTun, error) {
func newTun(l *logrus.Logger, deviceName string, cidr *net.IPNet, _ int, routes []Route, _ int, _ bool, _ bool) (*TestTun, error) {
routeTree, err := makeRouteTree(l, routes, false)
if err != nil {
return nil, err
@ -42,7 +42,7 @@ func newTun(l *logrus.Logger, deviceName string, cidr *net.IPNet, _ int, routes
}, nil
}
func newTunFromFd(_ *logrus.Logger, _ int, _ *net.IPNet, _ int, _ []Route, _ int) (*TestTun, error) {
func newTunFromFd(_ *logrus.Logger, _ int, _ *net.IPNet, _ int, _ []Route, _ int, _ bool) (*TestTun, error) {
return nil, fmt.Errorf("newTunFromFd not supported")
}

View File

@ -14,11 +14,11 @@ import (
"github.com/sirupsen/logrus"
)
func newTunFromFd(_ *logrus.Logger, _ int, _ *net.IPNet, _ int, _ []Route, _ int) (Device, error) {
func newTunFromFd(_ *logrus.Logger, _ int, _ *net.IPNet, _ int, _ []Route, _ int, _ bool) (Device, error) {
return nil, fmt.Errorf("newTunFromFd not supported in Windows")
}
func newTun(l *logrus.Logger, deviceName string, cidr *net.IPNet, defaultMTU int, routes []Route, _ int, _ bool) (Device, error) {
func newTun(l *logrus.Logger, deviceName string, cidr *net.IPNet, defaultMTU int, routes []Route, _ int, _ bool, _ bool) (Device, error) {
useWintun := true
if err := checkWinTunExists(); err != nil {
l.WithError(err).Warn("Check Wintun driver failed, fallback to wintap driver")

View File

@ -2,10 +2,16 @@ package nebula
import (
"bytes"
"context"
"net"
"net/netip"
"sort"
"strconv"
"sync"
"sync/atomic"
"time"
"github.com/sirupsen/logrus"
"github.com/slackhq/nebula/iputil"
"github.com/slackhq/nebula/udp"
)
@ -55,6 +61,132 @@ type cacheV6 struct {
reported []*Ip6AndPort
}
type hostnamePort struct {
name string
port uint16
}
type hostnamesResults struct {
hostnames []hostnamePort
network string
lookupTimeout time.Duration
stop chan struct{}
l *logrus.Logger
ips atomic.Pointer[map[netip.AddrPort]struct{}]
}
func NewHostnameResults(ctx context.Context, l *logrus.Logger, d time.Duration, network string, timeout time.Duration, hostPorts []string, onUpdate func()) (*hostnamesResults, error) {
r := &hostnamesResults{
hostnames: make([]hostnamePort, len(hostPorts)),
network: network,
lookupTimeout: timeout,
stop: make(chan (struct{})),
l: l,
}
// Fastrack IP addresses to ensure they're immediately available for use.
// DNS lookups for hostnames that aren't hardcoded IP's will happen in a background goroutine.
performBackgroundLookup := false
ips := map[netip.AddrPort]struct{}{}
for idx, hostPort := range hostPorts {
rIp, sPort, err := net.SplitHostPort(hostPort)
if err != nil {
return nil, err
}
iPort, err := strconv.Atoi(sPort)
if err != nil {
return nil, err
}
r.hostnames[idx] = hostnamePort{name: rIp, port: uint16(iPort)}
addr, err := netip.ParseAddr(rIp)
if err != nil {
// This address is a hostname, not an IP address
performBackgroundLookup = true
continue
}
// Save the IP address immediately
ips[netip.AddrPortFrom(addr, uint16(iPort))] = struct{}{}
}
r.ips.Store(&ips)
// Time for the DNS lookup goroutine
if performBackgroundLookup {
ticker := time.NewTicker(d)
go func() {
defer ticker.Stop()
for {
netipAddrs := map[netip.AddrPort]struct{}{}
for _, hostPort := range r.hostnames {
timeoutCtx, timeoutCancel := context.WithTimeout(ctx, r.lookupTimeout)
addrs, err := net.DefaultResolver.LookupNetIP(timeoutCtx, r.network, hostPort.name)
timeoutCancel()
if err != nil {
l.WithFields(logrus.Fields{"hostname": hostPort.name, "network": r.network}).WithError(err).Error("DNS resolution failed for static_map host")
continue
}
for _, a := range addrs {
netipAddrs[netip.AddrPortFrom(a, hostPort.port)] = struct{}{}
}
}
origSet := r.ips.Load()
different := false
for a := range *origSet {
if _, ok := netipAddrs[a]; !ok {
different = true
break
}
}
if !different {
for a := range netipAddrs {
if _, ok := (*origSet)[a]; !ok {
different = true
break
}
}
}
if different {
l.WithFields(logrus.Fields{"origSet": origSet, "newSet": netipAddrs}).Info("DNS results changed for host list")
r.ips.Store(&netipAddrs)
onUpdate()
}
select {
case <-ctx.Done():
return
case <-r.stop:
return
case <-ticker.C:
continue
}
}
}()
}
return r, nil
}
func (hr *hostnamesResults) Cancel() {
if hr != nil {
hr.stop <- struct{}{}
}
}
func (hr *hostnamesResults) GetIPs() []netip.AddrPort {
var retSlice []netip.AddrPort
if hr != nil {
p := hr.ips.Load()
if p != nil {
for k := range *p {
retSlice = append(retSlice, k)
}
}
}
return retSlice
}
// RemoteList is a unifying concept for lighthouse servers and clients as well as hostinfos.
// It serves as a local cache of query replies, host update notifications, and locally learned addresses
type RemoteList struct {
@ -72,6 +204,9 @@ type RemoteList struct {
// For learned addresses, this is the vpnIp that sent the packet
cache map[iputil.VpnIp]*cache
hr *hostnamesResults
shouldAdd func(netip.Addr) bool
// This is a list of remotes that we have tried to handshake with and have returned from the wrong vpn ip.
// They should not be tried again during a handshake
badRemotes []*udp.Addr
@ -81,14 +216,21 @@ type RemoteList struct {
}
// NewRemoteList creates a new empty RemoteList
func NewRemoteList() *RemoteList {
func NewRemoteList(shouldAdd func(netip.Addr) bool) *RemoteList {
return &RemoteList{
addrs: make([]*udp.Addr, 0),
relays: make([]*iputil.VpnIp, 0),
cache: make(map[iputil.VpnIp]*cache),
shouldAdd: shouldAdd,
}
}
func (r *RemoteList) unlockedSetHostnamesResults(hr *hostnamesResults) {
// Cancel any existing hostnamesResults DNS goroutine to release resources
r.hr.Cancel()
r.hr = hr
}
// Len locks and reports the size of the deduplicated address list
// The deduplication work may need to occur here, so you must pass preferredRanges
func (r *RemoteList) Len(preferredRanges []*net.IPNet) int {
@ -437,6 +579,26 @@ func (r *RemoteList) unlockedCollect() {
}
}
dnsAddrs := r.hr.GetIPs()
for _, addr := range dnsAddrs {
if r.shouldAdd == nil || r.shouldAdd(addr.Addr()) {
switch {
case addr.Addr().Is4():
v4 := addr.Addr().As4()
addrs = append(addrs, &udp.Addr{
IP: v4[:],
Port: addr.Port(),
})
case addr.Addr().Is6():
v6 := addr.Addr().As16()
addrs = append(addrs, &udp.Addr{
IP: v6[:],
Port: addr.Port(),
})
}
}
}
r.addrs = addrs
r.relays = relays

View File

@ -9,7 +9,7 @@ import (
)
func TestRemoteList_Rebuild(t *testing.T) {
rl := NewRemoteList()
rl := NewRemoteList(nil)
rl.unlockedSetV4(
0,
0,
@ -102,7 +102,7 @@ func TestRemoteList_Rebuild(t *testing.T) {
}
func BenchmarkFullRebuild(b *testing.B) {
rl := NewRemoteList()
rl := NewRemoteList(nil)
rl.unlockedSetV4(
0,
0,
@ -167,7 +167,7 @@ func BenchmarkFullRebuild(b *testing.B) {
}
func BenchmarkSortRebuild(b *testing.B) {
rl := NewRemoteList()
rl := NewRemoteList(nil)
rl.unlockedSetV4(
0,
0,