Merge remote-tracking branch 'origin/master' into mutex-debug

This commit is contained in:
Wade Simmons 2024-04-11 12:15:52 -04:00
commit 0ccfad1a1e
33 changed files with 1550 additions and 640 deletions

View File

@ -8,13 +8,13 @@ body:
attributes: attributes:
value: | value: |
### Thank you for taking the time to file a bug report! ### Thank you for taking the time to file a bug report!
Please fill out this form as completely as possible. Please fill out this form as completely as possible.
- type: input - type: input
id: version id: version
attributes: attributes:
label: What version of `nebula` are you using? label: What version of `nebula` are you using? (`nebula -version`)
placeholder: 0.0.0 placeholder: 0.0.0
validations: validations:
required: true required: true
@ -41,10 +41,17 @@ body:
attributes: attributes:
label: Logs from affected hosts label: Logs from affected hosts
description: | description: |
Provide logs from all affected hosts during the time of the issue. Please provide logs from ALL affected hosts during the time of the issue. If you do not provide logs we will be unable to assist you!
[Learn how to find Nebula logs here.](https://nebula.defined.net/docs/guides/viewing-nebula-logs/)
Improve formatting by using <code>```</code> at the beginning and end of each log block. Improve formatting by using <code>```</code> at the beginning and end of each log block.
value: |
```
```
validations: validations:
required: false required: true
- type: textarea - type: textarea
id: configs id: configs
@ -52,6 +59,11 @@ body:
label: Config files from affected hosts label: Config files from affected hosts
description: | description: |
Provide config files for all affected hosts. Provide config files for all affected hosts.
Improve formatting by using <code>```</code> at the beginning and end of each config file. Improve formatting by using <code>```</code> at the beginning and end of each config file.
value: |
```
```
validations: validations:
required: false required: true

View File

@ -142,15 +142,22 @@ func (tree *Tree4[T]) MostSpecificContains(ip iputil.VpnIp) (ok bool, value T) {
return ok, value return ok, value
} }
// Match finds the most specific match type eachFunc[T any] func(T) bool
// TODO this is exact match
func (tree *Tree4[T]) Match(ip iputil.VpnIp) (ok bool, value T) { // EachContains will call a function, passing the value, for each entry until the function returns true or the search is complete
// The final return value will be true if the provided function returned true
func (tree *Tree4[T]) EachContains(ip iputil.VpnIp, each eachFunc[T]) bool {
bit := startbit bit := startbit
node := tree.root node := tree.root
lastNode := node
for node != nil { for node != nil {
lastNode = node if node.hasValue {
// If the each func returns true then we can exit the loop
if each(node.value) {
return true
}
}
if ip&bit != 0 { if ip&bit != 0 {
node = node.right node = node.right
} else { } else {
@ -160,10 +167,33 @@ func (tree *Tree4[T]) Match(ip iputil.VpnIp) (ok bool, value T) {
bit >>= 1 bit >>= 1
} }
if bit == 0 && lastNode != nil { return false
value = lastNode.value }
ok = true
// GetCIDR returns the entry added by the most recent matching AddCIDR call
func (tree *Tree4[T]) GetCIDR(cidr *net.IPNet) (ok bool, value T) {
bit := startbit
node := tree.root
ip := iputil.Ip2VpnIp(cidr.IP)
mask := iputil.Ip2VpnIp(cidr.Mask)
// Find our last ancestor in the tree
for node != nil && bit&mask != 0 {
if ip&bit != 0 {
node = node.right
} else {
node = node.left
}
bit = bit >> 1
} }
if bit&mask == 0 && node != nil {
value = node.value
ok = node.hasValue
}
return ok, value return ok, value
} }

View File

@ -115,35 +115,36 @@ func TestCIDRTree_MostSpecificContains(t *testing.T) {
assert.Equal(t, "cool", r) assert.Equal(t, "cool", r)
} }
func TestCIDRTree_Match(t *testing.T) { func TestTree4_GetCIDR(t *testing.T) {
tree := NewTree4[string]() tree := NewTree4[string]()
tree.AddCIDR(Parse("4.1.1.0/32"), "1a") tree.AddCIDR(Parse("1.0.0.0/8"), "1")
tree.AddCIDR(Parse("4.1.1.1/32"), "1b") tree.AddCIDR(Parse("2.1.0.0/16"), "2")
tree.AddCIDR(Parse("3.1.1.0/24"), "3")
tree.AddCIDR(Parse("4.1.1.0/24"), "4a")
tree.AddCIDR(Parse("4.1.1.1/32"), "4b")
tree.AddCIDR(Parse("4.1.2.1/32"), "4c")
tree.AddCIDR(Parse("254.0.0.0/4"), "5")
tests := []struct { tests := []struct {
Found bool Found bool
Result interface{} Result interface{}
IP string IPNet *net.IPNet
}{ }{
{true, "1a", "4.1.1.0"}, {true, "1", Parse("1.0.0.0/8")},
{true, "1b", "4.1.1.1"}, {true, "2", Parse("2.1.0.0/16")},
{true, "3", Parse("3.1.1.0/24")},
{true, "4a", Parse("4.1.1.0/24")},
{true, "4b", Parse("4.1.1.1/32")},
{true, "4c", Parse("4.1.2.1/32")},
{true, "5", Parse("254.0.0.0/4")},
{false, "", Parse("2.0.0.0/8")},
} }
for _, tt := range tests { for _, tt := range tests {
ok, r := tree.Match(iputil.Ip2VpnIp(net.ParseIP(tt.IP))) ok, r := tree.GetCIDR(tt.IPNet)
assert.Equal(t, tt.Found, ok) assert.Equal(t, tt.Found, ok)
assert.Equal(t, tt.Result, r) assert.Equal(t, tt.Result, r)
} }
tree = NewTree4[string]()
tree.AddCIDR(Parse("1.1.1.1/0"), "cool")
ok, r := tree.Contains(iputil.Ip2VpnIp(net.ParseIP("0.0.0.0")))
assert.True(t, ok)
assert.Equal(t, "cool", r)
ok, r = tree.Contains(iputil.Ip2VpnIp(net.ParseIP("255.255.255.255")))
assert.True(t, ok)
assert.Equal(t, "cool", r)
} }
func BenchmarkCIDRTree_Contains(b *testing.B) { func BenchmarkCIDRTree_Contains(b *testing.B) {
@ -167,25 +168,3 @@ func BenchmarkCIDRTree_Contains(b *testing.B) {
} }
}) })
} }
func BenchmarkCIDRTree_Match(b *testing.B) {
tree := NewTree4[string]()
tree.AddCIDR(Parse("1.1.0.0/16"), "1")
tree.AddCIDR(Parse("1.2.1.1/32"), "1")
tree.AddCIDR(Parse("192.2.1.1/32"), "1")
tree.AddCIDR(Parse("172.2.1.1/32"), "1")
ip := iputil.Ip2VpnIp(net.ParseIP("1.2.1.1"))
b.Run("found", func(b *testing.B) {
for i := 0; i < b.N; i++ {
tree.Match(ip)
}
})
ip = iputil.Ip2VpnIp(net.ParseIP("1.2.1.255"))
b.Run("not found", func(b *testing.B) {
for i := 0; i < b.N; i++ {
tree.Match(ip)
}
})
}

View File

@ -456,7 +456,7 @@ func (n *connectionManager) sendPunch(hostinfo *HostInfo) {
} }
if n.punchy.GetTargetEverything() { if n.punchy.GetTargetEverything() {
hostinfo.remotes.ForEach(n.hostMap.preferredRanges, func(addr *udp.Addr, preferred bool) { hostinfo.remotes.ForEach(n.hostMap.GetPreferredRanges(), func(addr *udp.Addr, preferred bool) {
n.metricsTxPunchy.Inc(1) n.metricsTxPunchy.Inc(1)
n.intf.outside.WriteTo([]byte{1}, addr) n.intf.outside.WriteTo([]byte{1}, addr)
}) })

View File

@ -43,7 +43,9 @@ func Test_NewConnectionManagerTest(t *testing.T) {
preferredRanges := []*net.IPNet{localrange} preferredRanges := []*net.IPNet{localrange}
// Very incomplete mock objects // Very incomplete mock objects
hostMap := NewHostMap(l, vpncidr, preferredRanges) hostMap := newHostMap(l, vpncidr)
hostMap.preferredRanges.Store(&preferredRanges)
cs := &CertState{ cs := &CertState{
RawCertificate: []byte{}, RawCertificate: []byte{},
PrivateKey: []byte{}, PrivateKey: []byte{},
@ -123,7 +125,9 @@ func Test_NewConnectionManagerTest2(t *testing.T) {
preferredRanges := []*net.IPNet{localrange} preferredRanges := []*net.IPNet{localrange}
// Very incomplete mock objects // Very incomplete mock objects
hostMap := NewHostMap(l, vpncidr, preferredRanges) hostMap := newHostMap(l, vpncidr)
hostMap.preferredRanges.Store(&preferredRanges)
cs := &CertState{ cs := &CertState{
RawCertificate: []byte{}, RawCertificate: []byte{},
PrivateKey: []byte{}, PrivateKey: []byte{},
@ -210,7 +214,8 @@ func Test_NewConnectionManagerTest_DisconnectInvalid(t *testing.T) {
_, vpncidr, _ := net.ParseCIDR("172.1.1.1/24") _, vpncidr, _ := net.ParseCIDR("172.1.1.1/24")
_, localrange, _ := net.ParseCIDR("10.1.1.1/24") _, localrange, _ := net.ParseCIDR("10.1.1.1/24")
preferredRanges := []*net.IPNet{localrange} preferredRanges := []*net.IPNet{localrange}
hostMap := NewHostMap(l, vpncidr, preferredRanges) hostMap := newHostMap(l, vpncidr)
hostMap.preferredRanges.Store(&preferredRanges)
// Generate keys for CA and peer's cert. // Generate keys for CA and peer's cert.
pubCA, privCA, _ := ed25519.GenerateKey(rand.Reader) pubCA, privCA, _ := ed25519.GenerateKey(rand.Reader)

View File

@ -145,7 +145,7 @@ func (c *Control) GetHostInfoByVpnIp(vpnIp iputil.VpnIp, pending bool) *ControlH
return nil return nil
} }
ch := copyHostInfo(h, c.f.hostMap.preferredRanges) ch := copyHostInfo(h, c.f.hostMap.GetPreferredRanges())
return &ch return &ch
} }
@ -157,7 +157,7 @@ func (c *Control) SetRemoteForTunnel(vpnIp iputil.VpnIp, addr udp.Addr) *Control
} }
hostInfo.SetRemote(addr.Copy()) hostInfo.SetRemote(addr.Copy())
ch := copyHostInfo(hostInfo, c.f.hostMap.preferredRanges) ch := copyHostInfo(hostInfo, c.f.hostMap.GetPreferredRanges())
return &ch return &ch
} }

View File

@ -18,7 +18,9 @@ func TestControl_GetHostInfoByVpnIp(t *testing.T) {
l := test.NewLogger() l := test.NewLogger()
// Special care must be taken to re-use all objects provided to the hostmap and certificate in the expectedInfo object // Special care must be taken to re-use all objects provided to the hostmap and certificate in the expectedInfo object
// To properly ensure we are not exposing core memory to the caller // To properly ensure we are not exposing core memory to the caller
hm := NewHostMap(l, &net.IPNet{}, make([]*net.IPNet, 0)) hm := newHostMap(l, &net.IPNet{})
hm.preferredRanges.Store(&[]*net.IPNet{})
remote1 := udp.NewAddr(net.ParseIP("0.0.0.100"), 4444) remote1 := udp.NewAddr(net.ParseIP("0.0.0.100"), 4444)
remote2 := udp.NewAddr(net.ParseIP("1:2:3:4:5:6:7:8"), 4444) remote2 := udp.NewAddr(net.ParseIP("1:2:3:4:5:6:7:8"), 4444)
ipNet := net.IPNet{ ipNet := net.IPNet{

View File

@ -309,6 +309,13 @@ firewall:
outbound_action: drop outbound_action: drop
inbound_action: drop inbound_action: drop
# Controls the default value for local_cidr. Default is true, will be deprecated after v1.9 and defaulted to false.
# This setting only affects nebula hosts with subnets encoded in their certificate. A nebula host acting as an
# unsafe router with `default_local_cidr_any: true` will expose their unsafe routes to every inbound rule regardless
# of the actual destination for the packet. Setting this to false requires each inbound rule to contain a `local_cidr`
# if the intention is to allow traffic to flow to an unsafe route.
#default_local_cidr_any: false
conntrack: conntrack:
tcp_timeout: 12m tcp_timeout: 12m
udp_timeout: 3m udp_timeout: 3m
@ -316,7 +323,7 @@ firewall:
# The firewall is default deny. There is no way to write a deny rule. # The firewall is default deny. There is no way to write a deny rule.
# Rules are comprised of a protocol, port, and one or more of host, group, or CIDR # Rules are comprised of a protocol, port, and one or more of host, group, or CIDR
# Logical evaluation is roughly: port AND proto AND (ca_sha OR ca_name) AND (host OR group OR groups OR cidr) # Logical evaluation is roughly: port AND proto AND (ca_sha OR ca_name) AND (host OR group OR groups OR cidr) AND (local cidr)
# - port: Takes `0` or `any` as any, a single number `80`, a range `200-901`, or `fragment` to match second and further fragments of fragmented packets (since there is no port available). # - port: Takes `0` or `any` as any, a single number `80`, a range `200-901`, or `fragment` to match second and further fragments of fragmented packets (since there is no port available).
# code: same as port but makes more sense when talking about ICMP, TODO: this is not currently implemented in a way that works, use `any` # code: same as port but makes more sense when talking about ICMP, TODO: this is not currently implemented in a way that works, use `any`
# proto: `any`, `tcp`, `udp`, or `icmp` # proto: `any`, `tcp`, `udp`, or `icmp`
@ -325,6 +332,8 @@ firewall:
# groups: Same as group but accepts a list of values. Multiple values are AND'd together and a certificate would have to contain all groups to pass # groups: Same as group but accepts a list of values. Multiple values are AND'd together and a certificate would have to contain all groups to pass
# cidr: a remote CIDR, `0.0.0.0/0` is any. # cidr: a remote CIDR, `0.0.0.0/0` is any.
# local_cidr: a local CIDR, `0.0.0.0/0` is any. This could be used to filter destinations when using unsafe_routes. # local_cidr: a local CIDR, `0.0.0.0/0` is any. This could be used to filter destinations when using unsafe_routes.
# Default is `any` unless the certificate contains subnets and then the default is the ip issued in the certificate
# if `default_local_cidr_any` is false, otherwise its `any`.
# ca_name: An issuing CA name # ca_name: An issuing CA name
# ca_sha: An issuing CA shasum # ca_sha: An issuing CA shasum
@ -346,3 +355,10 @@ firewall:
groups: groups:
- laptop - laptop
- home - home
# Expose a subnet (unsafe route) to hosts with the group remote_client
# This example assume you have a subnet of 192.168.100.1/24 or larger encoded in the certificate
- port: 8080
proto: tcp
group: remote_client
local_cidr: 192.168.100.1/24

View File

@ -57,15 +57,18 @@ type Firewall struct {
DefaultTimeout time.Duration //linux: 600s DefaultTimeout time.Duration //linux: 600s
// Used to ensure we don't emit local packets for ips we don't own // Used to ensure we don't emit local packets for ips we don't own
localIps *cidr.Tree4[struct{}] localIps *cidr.Tree4[struct{}]
assignedCIDR *net.IPNet
hasSubnets bool
rules string rules string
rulesVersion uint16 rulesVersion uint16
trackTCPRTT bool defaultLocalCIDRAny bool
metricTCPRTT metrics.Histogram trackTCPRTT bool
incomingMetrics firewallMetrics metricTCPRTT metrics.Histogram
outgoingMetrics firewallMetrics incomingMetrics firewallMetrics
outgoingMetrics firewallMetrics
l *logrus.Logger l *logrus.Logger
} }
@ -83,6 +86,8 @@ type FirewallConntrack struct {
TimerWheel *TimerWheel[firewall.Packet] TimerWheel *TimerWheel[firewall.Packet]
} }
// FirewallTable is the entry point for a rule, the evaluation order is:
// Proto AND port AND (CA SHA or CA name) AND local CIDR AND (group OR groups OR name OR remote CIDR)
type FirewallTable struct { type FirewallTable struct {
TCP firewallPort TCP firewallPort
UDP firewallPort UDP firewallPort
@ -106,18 +111,27 @@ type FirewallCA struct {
} }
type FirewallRule struct { type FirewallRule struct {
// Any makes Hosts, Groups, CIDR and LocalCIDR irrelevant // Any makes Hosts, Groups, and CIDR irrelevant
Any bool Any *firewallLocalCIDR
Hosts map[string]struct{} Hosts map[string]*firewallLocalCIDR
Groups [][]string Groups []*firewallGroups
CIDR *cidr.Tree4[struct{}] CIDR *cidr.Tree4[*firewallLocalCIDR]
LocalCIDR *cidr.Tree4[struct{}] }
type firewallGroups struct {
Groups []string
LocalCIDR *firewallLocalCIDR
} }
// Even though ports are uint16, int32 maps are faster for lookup // Even though ports are uint16, int32 maps are faster for lookup
// Plus we can use `-1` for fragment rules // Plus we can use `-1` for fragment rules
type firewallPort map[int32]*FirewallCA type firewallPort map[int32]*FirewallCA
type firewallLocalCIDR struct {
Any bool
LocalCIDR *cidr.Tree4[struct{}]
}
// NewFirewall creates a new Firewall object. A TimerWheel is created for you from the provided timeouts. // NewFirewall creates a new Firewall object. A TimerWheel is created for you from the provided timeouts.
func NewFirewall(l *logrus.Logger, tcpTimeout, UDPTimeout, defaultTimeout time.Duration, c *cert.NebulaCertificate) *Firewall { func NewFirewall(l *logrus.Logger, tcpTimeout, UDPTimeout, defaultTimeout time.Duration, c *cert.NebulaCertificate) *Firewall {
//TODO: error on 0 duration //TODO: error on 0 duration
@ -138,8 +152,15 @@ func NewFirewall(l *logrus.Logger, tcpTimeout, UDPTimeout, defaultTimeout time.D
} }
localIps := cidr.NewTree4[struct{}]() localIps := cidr.NewTree4[struct{}]()
var assignedCIDR *net.IPNet
for _, ip := range c.Details.Ips { for _, ip := range c.Details.Ips {
localIps.AddCIDR(&net.IPNet{IP: ip.IP, Mask: net.IPMask{255, 255, 255, 255}}, struct{}{}) ipNet := &net.IPNet{IP: ip.IP, Mask: net.IPMask{255, 255, 255, 255}}
localIps.AddCIDR(ipNet, struct{}{})
if assignedCIDR == nil {
// Only grabbing the first one in the cert since any more than that currently has undefined behavior
assignedCIDR = ipNet
}
} }
for _, n := range c.Details.Subnets { for _, n := range c.Details.Subnets {
@ -158,6 +179,8 @@ func NewFirewall(l *logrus.Logger, tcpTimeout, UDPTimeout, defaultTimeout time.D
UDPTimeout: UDPTimeout, UDPTimeout: UDPTimeout,
DefaultTimeout: defaultTimeout, DefaultTimeout: defaultTimeout,
localIps: localIps, localIps: localIps,
assignedCIDR: assignedCIDR,
hasSubnets: len(c.Details.Subnets) > 0,
l: l, l: l,
metricTCPRTT: metrics.GetOrRegisterHistogram("network.tcp.rtt", nil, metrics.NewExpDecaySample(1028, 0.015)), metricTCPRTT: metrics.GetOrRegisterHistogram("network.tcp.rtt", nil, metrics.NewExpDecaySample(1028, 0.015)),
@ -184,6 +207,9 @@ func NewFirewallFromConfig(l *logrus.Logger, nc *cert.NebulaCertificate, c *conf
//TODO: max_connections //TODO: max_connections
) )
//TODO: Flip to false after v1.9 release
fw.defaultLocalCIDRAny = c.GetBool("firewall.default_local_cidr_any", true)
inboundAction := c.GetString("firewall.inbound_action", "drop") inboundAction := c.GetString("firewall.inbound_action", "drop")
switch inboundAction { switch inboundAction {
case "reject": case "reject":
@ -270,7 +296,7 @@ func (f *Firewall) AddRule(incoming bool, proto uint8, startPort int32, endPort
return fmt.Errorf("unknown protocol %v", proto) return fmt.Errorf("unknown protocol %v", proto)
} }
return fp.addRule(startPort, endPort, groups, host, ip, localIp, caName, caSha) return fp.addRule(f, startPort, endPort, groups, host, ip, localIp, caName, caSha)
} }
// GetRuleHash returns a hash representation of all inbound and outbound rules // GetRuleHash returns a hash representation of all inbound and outbound rules
@ -624,7 +650,7 @@ func (ft *FirewallTable) match(p firewall.Packet, incoming bool, c *cert.NebulaC
return false return false
} }
func (fp firewallPort) addRule(startPort int32, endPort int32, groups []string, host string, ip *net.IPNet, localIp *net.IPNet, caName string, caSha string) error { func (fp firewallPort) addRule(f *Firewall, startPort int32, endPort int32, groups []string, host string, ip *net.IPNet, localIp *net.IPNet, caName string, caSha string) error {
if startPort > endPort { if startPort > endPort {
return fmt.Errorf("start port was lower than end port") return fmt.Errorf("start port was lower than end port")
} }
@ -637,7 +663,7 @@ func (fp firewallPort) addRule(startPort int32, endPort int32, groups []string,
} }
} }
if err := fp[i].addRule(groups, host, ip, localIp, caName, caSha); err != nil { if err := fp[i].addRule(f, groups, host, ip, localIp, caName, caSha); err != nil {
return err return err
} }
} }
@ -668,13 +694,12 @@ func (fp firewallPort) match(p firewall.Packet, incoming bool, c *cert.NebulaCer
return fp[firewall.PortAny].match(p, c, caPool) return fp[firewall.PortAny].match(p, c, caPool)
} }
func (fc *FirewallCA) addRule(groups []string, host string, ip, localIp *net.IPNet, caName, caSha string) error { func (fc *FirewallCA) addRule(f *Firewall, groups []string, host string, ip, localIp *net.IPNet, caName, caSha string) error {
fr := func() *FirewallRule { fr := func() *FirewallRule {
return &FirewallRule{ return &FirewallRule{
Hosts: make(map[string]struct{}), Hosts: make(map[string]*firewallLocalCIDR),
Groups: make([][]string, 0), Groups: make([]*firewallGroups, 0),
CIDR: cidr.NewTree4[struct{}](), CIDR: cidr.NewTree4[*firewallLocalCIDR](),
LocalCIDR: cidr.NewTree4[struct{}](),
} }
} }
@ -683,14 +708,14 @@ func (fc *FirewallCA) addRule(groups []string, host string, ip, localIp *net.IPN
fc.Any = fr() fc.Any = fr()
} }
return fc.Any.addRule(groups, host, ip, localIp) return fc.Any.addRule(f, groups, host, ip, localIp)
} }
if caSha != "" { if caSha != "" {
if _, ok := fc.CAShas[caSha]; !ok { if _, ok := fc.CAShas[caSha]; !ok {
fc.CAShas[caSha] = fr() fc.CAShas[caSha] = fr()
} }
err := fc.CAShas[caSha].addRule(groups, host, ip, localIp) err := fc.CAShas[caSha].addRule(f, groups, host, ip, localIp)
if err != nil { if err != nil {
return err return err
} }
@ -700,7 +725,7 @@ func (fc *FirewallCA) addRule(groups []string, host string, ip, localIp *net.IPN
if _, ok := fc.CANames[caName]; !ok { if _, ok := fc.CANames[caName]; !ok {
fc.CANames[caName] = fr() fc.CANames[caName] = fr()
} }
err := fc.CANames[caName].addRule(groups, host, ip, localIp) err := fc.CANames[caName].addRule(f, groups, host, ip, localIp)
if err != nil { if err != nil {
return err return err
} }
@ -732,41 +757,63 @@ func (fc *FirewallCA) match(p firewall.Packet, c *cert.NebulaCertificate, caPool
return fc.CANames[s.Details.Name].match(p, c) return fc.CANames[s.Details.Name].match(p, c)
} }
func (fr *FirewallRule) addRule(groups []string, host string, ip *net.IPNet, localIp *net.IPNet) error { func (fr *FirewallRule) addRule(f *Firewall, groups []string, host string, ip *net.IPNet, localCIDR *net.IPNet) error {
if fr.Any { flc := func() *firewallLocalCIDR {
return nil return &firewallLocalCIDR{
LocalCIDR: cidr.NewTree4[struct{}](),
}
} }
if fr.isAny(groups, host, ip, localIp) { if fr.isAny(groups, host, ip) {
fr.Any = true if fr.Any == nil {
// If it's any we need to wipe out any pre-existing rules to save on memory fr.Any = flc()
fr.Groups = make([][]string, 0)
fr.Hosts = make(map[string]struct{})
fr.CIDR = cidr.NewTree4[struct{}]()
fr.LocalCIDR = cidr.NewTree4[struct{}]()
} else {
if len(groups) > 0 {
fr.Groups = append(fr.Groups, groups)
} }
if host != "" { return fr.Any.addRule(f, localCIDR)
fr.Hosts[host] = struct{}{} }
if len(groups) > 0 {
nlc := flc()
err := nlc.addRule(f, localCIDR)
if err != nil {
return err
} }
if ip != nil { fr.Groups = append(fr.Groups, &firewallGroups{
fr.CIDR.AddCIDR(ip, struct{}{}) Groups: groups,
} LocalCIDR: nlc,
})
}
if localIp != nil { if host != "" {
fr.LocalCIDR.AddCIDR(localIp, struct{}{}) nlc := fr.Hosts[host]
if nlc == nil {
nlc = flc()
} }
err := nlc.addRule(f, localCIDR)
if err != nil {
return err
}
fr.Hosts[host] = nlc
}
if ip != nil {
_, nlc := fr.CIDR.GetCIDR(ip)
if nlc == nil {
nlc = flc()
}
err := nlc.addRule(f, localCIDR)
if err != nil {
return err
}
fr.CIDR.AddCIDR(ip, nlc)
} }
return nil return nil
} }
func (fr *FirewallRule) isAny(groups []string, host string, ip, localIp *net.IPNet) bool { func (fr *FirewallRule) isAny(groups []string, host string, ip *net.IPNet) bool {
if len(groups) == 0 && host == "" && ip == nil && localIp == nil { if len(groups) == 0 && host == "" && ip == nil {
return true return true
} }
@ -784,10 +831,6 @@ func (fr *FirewallRule) isAny(groups []string, host string, ip, localIp *net.IPN
return true return true
} }
if localIp != nil && localIp.Contains(net.IPv4(0, 0, 0, 0)) {
return true
}
return false return false
} }
@ -797,7 +840,7 @@ func (fr *FirewallRule) match(p firewall.Packet, c *cert.NebulaCertificate) bool
} }
// Shortcut path for if groups, hosts, or cidr contained an `any` // Shortcut path for if groups, hosts, or cidr contained an `any`
if fr.Any { if fr.Any.match(p, c) {
return true return true
} }
@ -805,7 +848,7 @@ func (fr *FirewallRule) match(p firewall.Packet, c *cert.NebulaCertificate) bool
for _, sg := range fr.Groups { for _, sg := range fr.Groups {
found := false found := false
for _, g := range sg { for _, g := range sg.Groups {
if _, ok := c.Details.InvertedGroups[g]; !ok { if _, ok := c.Details.InvertedGroups[g]; !ok {
found = false found = false
break break
@ -814,33 +857,51 @@ func (fr *FirewallRule) match(p firewall.Packet, c *cert.NebulaCertificate) bool
found = true found = true
} }
if found { if found && sg.LocalCIDR.match(p, c) {
return true return true
} }
} }
if fr.Hosts != nil { if fr.Hosts != nil {
if _, ok := fr.Hosts[c.Details.Name]; ok { if flc, ok := fr.Hosts[c.Details.Name]; ok {
return true if flc.match(p, c) {
return true
}
} }
} }
if fr.CIDR != nil { return fr.CIDR.EachContains(p.RemoteIP, func(flc *firewallLocalCIDR) bool {
ok, _ := fr.CIDR.Contains(p.RemoteIP) return flc.match(p, c)
if ok { })
return true }
func (flc *firewallLocalCIDR) addRule(f *Firewall, localIp *net.IPNet) error {
if localIp == nil {
if !f.hasSubnets || f.defaultLocalCIDRAny {
flc.Any = true
return nil
} }
localIp = f.assignedCIDR
} else if localIp.Contains(net.IPv4(0, 0, 0, 0)) {
flc.Any = true
} }
if fr.LocalCIDR != nil { flc.LocalCIDR.AddCIDR(localIp, struct{}{})
ok, _ := fr.LocalCIDR.Contains(p.LocalIP) return nil
if ok { }
return true
} func (flc *firewallLocalCIDR) match(p firewall.Packet, c *cert.NebulaCertificate) bool {
if flc == nil {
return false
} }
// No host, group, or cidr matched, bye bye if flc.Any {
return false return true
}
ok, _ := flc.LocalCIDR.Contains(p.LocalIP)
return ok
} }
type rule struct { type rule struct {

View File

@ -71,36 +71,32 @@ func TestFirewall_AddRule(t *testing.T) {
assert.Nil(t, fw.AddRule(true, firewall.ProtoTCP, 1, 1, []string{}, "", nil, nil, "", "")) assert.Nil(t, fw.AddRule(true, firewall.ProtoTCP, 1, 1, []string{}, "", nil, nil, "", ""))
// An empty rule is any // An empty rule is any
assert.True(t, fw.InRules.TCP[1].Any.Any) assert.True(t, fw.InRules.TCP[1].Any.Any.Any)
assert.Empty(t, fw.InRules.TCP[1].Any.Groups) assert.Empty(t, fw.InRules.TCP[1].Any.Groups)
assert.Empty(t, fw.InRules.TCP[1].Any.Hosts) assert.Empty(t, fw.InRules.TCP[1].Any.Hosts)
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c) fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
assert.Nil(t, fw.AddRule(true, firewall.ProtoUDP, 1, 1, []string{"g1"}, "", nil, nil, "", "")) assert.Nil(t, fw.AddRule(true, firewall.ProtoUDP, 1, 1, []string{"g1"}, "", nil, nil, "", ""))
assert.False(t, fw.InRules.UDP[1].Any.Any) assert.Nil(t, fw.InRules.UDP[1].Any.Any)
assert.Contains(t, fw.InRules.UDP[1].Any.Groups[0], "g1") assert.Contains(t, fw.InRules.UDP[1].Any.Groups[0].Groups, "g1")
assert.Empty(t, fw.InRules.UDP[1].Any.Hosts) assert.Empty(t, fw.InRules.UDP[1].Any.Hosts)
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c) fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
assert.Nil(t, fw.AddRule(true, firewall.ProtoICMP, 1, 1, []string{}, "h1", nil, nil, "", "")) assert.Nil(t, fw.AddRule(true, firewall.ProtoICMP, 1, 1, []string{}, "h1", nil, nil, "", ""))
assert.False(t, fw.InRules.ICMP[1].Any.Any) assert.Nil(t, fw.InRules.ICMP[1].Any.Any)
assert.Empty(t, fw.InRules.ICMP[1].Any.Groups) assert.Empty(t, fw.InRules.ICMP[1].Any.Groups)
assert.Contains(t, fw.InRules.ICMP[1].Any.Hosts, "h1") assert.Contains(t, fw.InRules.ICMP[1].Any.Hosts, "h1")
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c) fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
assert.Nil(t, fw.AddRule(false, firewall.ProtoAny, 1, 1, []string{}, "", ti, nil, "", "")) assert.Nil(t, fw.AddRule(false, firewall.ProtoAny, 1, 1, []string{}, "", ti, nil, "", ""))
assert.False(t, fw.OutRules.AnyProto[1].Any.Any) assert.Nil(t, fw.OutRules.AnyProto[1].Any.Any)
assert.Empty(t, fw.OutRules.AnyProto[1].Any.Groups) ok, _ := fw.OutRules.AnyProto[1].Any.CIDR.GetCIDR(ti)
assert.Empty(t, fw.OutRules.AnyProto[1].Any.Hosts)
ok, _ := fw.OutRules.AnyProto[1].Any.CIDR.Match(iputil.Ip2VpnIp(ti.IP))
assert.True(t, ok) assert.True(t, ok)
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c) fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
assert.Nil(t, fw.AddRule(false, firewall.ProtoAny, 1, 1, []string{}, "", nil, ti, "", "")) assert.Nil(t, fw.AddRule(false, firewall.ProtoAny, 1, 1, []string{}, "", nil, ti, "", ""))
assert.False(t, fw.OutRules.AnyProto[1].Any.Any) assert.NotNil(t, fw.OutRules.AnyProto[1].Any.Any)
assert.Empty(t, fw.OutRules.AnyProto[1].Any.Groups) ok, _ = fw.OutRules.AnyProto[1].Any.Any.LocalCIDR.GetCIDR(ti)
assert.Empty(t, fw.OutRules.AnyProto[1].Any.Hosts)
ok, _ = fw.OutRules.AnyProto[1].Any.LocalCIDR.Match(iputil.Ip2VpnIp(ti.IP))
assert.True(t, ok) assert.True(t, ok)
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c) fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
@ -111,32 +107,14 @@ func TestFirewall_AddRule(t *testing.T) {
assert.Nil(t, fw.AddRule(true, firewall.ProtoUDP, 1, 1, []string{"g1"}, "", nil, nil, "", "ca-sha")) assert.Nil(t, fw.AddRule(true, firewall.ProtoUDP, 1, 1, []string{"g1"}, "", nil, nil, "", "ca-sha"))
assert.Contains(t, fw.InRules.UDP[1].CAShas, "ca-sha") assert.Contains(t, fw.InRules.UDP[1].CAShas, "ca-sha")
// Set any and clear fields
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
assert.Nil(t, fw.AddRule(false, firewall.ProtoAny, 0, 0, []string{"g1", "g2"}, "h1", ti, ti, "", ""))
assert.Equal(t, []string{"g1", "g2"}, fw.OutRules.AnyProto[0].Any.Groups[0])
assert.Contains(t, fw.OutRules.AnyProto[0].Any.Hosts, "h1")
ok, _ = fw.OutRules.AnyProto[0].Any.CIDR.Match(iputil.Ip2VpnIp(ti.IP))
assert.True(t, ok)
ok, _ = fw.OutRules.AnyProto[0].Any.LocalCIDR.Match(iputil.Ip2VpnIp(ti.IP))
assert.True(t, ok)
// run twice just to make sure
//TODO: these ANY rules should clear the CA firewall portion
assert.Nil(t, fw.AddRule(false, firewall.ProtoAny, 0, 0, []string{"any"}, "", nil, nil, "", ""))
assert.Nil(t, fw.AddRule(false, firewall.ProtoAny, 0, 0, []string{}, "any", nil, nil, "", ""))
assert.True(t, fw.OutRules.AnyProto[0].Any.Any)
assert.Empty(t, fw.OutRules.AnyProto[0].Any.Groups)
assert.Empty(t, fw.OutRules.AnyProto[0].Any.Hosts)
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c) fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
assert.Nil(t, fw.AddRule(false, firewall.ProtoAny, 0, 0, []string{}, "any", nil, nil, "", "")) assert.Nil(t, fw.AddRule(false, firewall.ProtoAny, 0, 0, []string{}, "any", nil, nil, "", ""))
assert.True(t, fw.OutRules.AnyProto[0].Any.Any) assert.True(t, fw.OutRules.AnyProto[0].Any.Any.Any)
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c) fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
_, anyIp, _ := net.ParseCIDR("0.0.0.0/0") _, anyIp, _ := net.ParseCIDR("0.0.0.0/0")
assert.Nil(t, fw.AddRule(false, firewall.ProtoAny, 0, 0, []string{}, "", anyIp, nil, "", "")) assert.Nil(t, fw.AddRule(false, firewall.ProtoAny, 0, 0, []string{}, "", anyIp, nil, "", ""))
assert.True(t, fw.OutRules.AnyProto[0].Any.Any) assert.True(t, fw.OutRules.AnyProto[0].Any.Any.Any)
// Test error conditions // Test error conditions
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c) fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
@ -226,33 +204,43 @@ func TestFirewall_Drop(t *testing.T) {
} }
func BenchmarkFirewallTable_match(b *testing.B) { func BenchmarkFirewallTable_match(b *testing.B) {
f := &Firewall{}
ft := FirewallTable{ ft := FirewallTable{
TCP: firewallPort{}, TCP: firewallPort{},
} }
_, n, _ := net.ParseCIDR("172.1.1.1/32") _, n, _ := net.ParseCIDR("172.1.1.1/32")
_ = ft.TCP.addRule(10, 10, []string{"good-group"}, "good-host", n, n, "", "") goodLocalCIDRIP := iputil.Ip2VpnIp(n.IP)
_ = ft.TCP.addRule(10, 10, []string{"good-group2"}, "good-host", n, n, "", "") _ = ft.TCP.addRule(f, 10, 10, []string{"good-group"}, "good-host", n, nil, "", "")
_ = ft.TCP.addRule(10, 10, []string{"good-group3"}, "good-host", n, n, "", "") _ = ft.TCP.addRule(f, 100, 100, []string{"good-group"}, "good-host", nil, n, "", "")
_ = ft.TCP.addRule(10, 10, []string{"good-group4"}, "good-host", n, n, "", "")
_ = ft.TCP.addRule(10, 10, []string{"good-group, good-group1"}, "good-host", n, n, "", "")
cp := cert.NewCAPool() cp := cert.NewCAPool()
b.Run("fail on proto", func(b *testing.B) { b.Run("fail on proto", func(b *testing.B) {
// This benchmark is showing us the cost of failing to match the protocol
c := &cert.NebulaCertificate{} c := &cert.NebulaCertificate{}
for n := 0; n < b.N; n++ { for n := 0; n < b.N; n++ {
ft.match(firewall.Packet{Protocol: firewall.ProtoUDP}, true, c, cp) assert.False(b, ft.match(firewall.Packet{Protocol: firewall.ProtoUDP}, true, c, cp))
} }
}) })
b.Run("fail on port", func(b *testing.B) { b.Run("pass proto, fail on port", func(b *testing.B) {
// This benchmark is showing us the cost of matching a specific protocol but failing to match the port
c := &cert.NebulaCertificate{} c := &cert.NebulaCertificate{}
for n := 0; n < b.N; n++ { for n := 0; n < b.N; n++ {
ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 1}, true, c, cp) assert.False(b, ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 1}, true, c, cp))
} }
}) })
b.Run("fail all group, name, and cidr", func(b *testing.B) { b.Run("pass proto, port, fail on local CIDR", func(b *testing.B) {
c := &cert.NebulaCertificate{}
ip, _, _ := net.ParseCIDR("9.254.254.254/32")
lip := iputil.Ip2VpnIp(ip)
for n := 0; n < b.N; n++ {
assert.False(b, ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 100, LocalIP: lip}, true, c, cp))
}
})
b.Run("pass proto, port, any local CIDR, fail all group, name, and cidr", func(b *testing.B) {
_, ip, _ := net.ParseCIDR("9.254.254.254/32") _, ip, _ := net.ParseCIDR("9.254.254.254/32")
c := &cert.NebulaCertificate{ c := &cert.NebulaCertificate{
Details: cert.NebulaCertificateDetails{ Details: cert.NebulaCertificateDetails{
@ -262,11 +250,25 @@ func BenchmarkFirewallTable_match(b *testing.B) {
}, },
} }
for n := 0; n < b.N; n++ { for n := 0; n < b.N; n++ {
ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 10}, true, c, cp) assert.False(b, ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 10}, true, c, cp))
} }
}) })
b.Run("pass on group", func(b *testing.B) { b.Run("pass proto, port, specific local CIDR, fail all group, name, and cidr", func(b *testing.B) {
_, ip, _ := net.ParseCIDR("9.254.254.254/32")
c := &cert.NebulaCertificate{
Details: cert.NebulaCertificateDetails{
InvertedGroups: map[string]struct{}{"nope": {}},
Name: "nope",
Ips: []*net.IPNet{ip},
},
}
for n := 0; n < b.N; n++ {
assert.False(b, ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 100, LocalIP: goodLocalCIDRIP}, true, c, cp))
}
})
b.Run("pass on group on any local cidr", func(b *testing.B) {
c := &cert.NebulaCertificate{ c := &cert.NebulaCertificate{
Details: cert.NebulaCertificateDetails{ Details: cert.NebulaCertificateDetails{
InvertedGroups: map[string]struct{}{"good-group": {}}, InvertedGroups: map[string]struct{}{"good-group": {}},
@ -274,7 +276,19 @@ func BenchmarkFirewallTable_match(b *testing.B) {
}, },
} }
for n := 0; n < b.N; n++ { for n := 0; n < b.N; n++ {
ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 10}, true, c, cp) assert.True(b, ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 10}, true, c, cp))
}
})
b.Run("pass on group on specific local cidr", func(b *testing.B) {
c := &cert.NebulaCertificate{
Details: cert.NebulaCertificateDetails{
InvertedGroups: map[string]struct{}{"good-group": {}},
Name: "nope",
},
}
for n := 0; n < b.N; n++ {
assert.True(b, ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 100, LocalIP: goodLocalCIDRIP}, true, c, cp))
} }
}) })
@ -289,60 +303,60 @@ func BenchmarkFirewallTable_match(b *testing.B) {
ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 10}, true, c, cp) ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 10}, true, c, cp)
} }
}) })
//
b.Run("pass on ip", func(b *testing.B) { //b.Run("pass on ip", func(b *testing.B) {
ip := iputil.Ip2VpnIp(net.IPv4(172, 1, 1, 1)) // ip := iputil.Ip2VpnIp(net.IPv4(172, 1, 1, 1))
c := &cert.NebulaCertificate{ // c := &cert.NebulaCertificate{
Details: cert.NebulaCertificateDetails{ // Details: cert.NebulaCertificateDetails{
InvertedGroups: map[string]struct{}{"nope": {}}, // InvertedGroups: map[string]struct{}{"nope": {}},
Name: "good-host", // Name: "good-host",
}, // },
} // }
for n := 0; n < b.N; n++ { // for n := 0; n < b.N; n++ {
ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 10, RemoteIP: ip}, true, c, cp) // ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 10, RemoteIP: ip}, true, c, cp)
} // }
}) //})
//
b.Run("pass on local ip", func(b *testing.B) { //b.Run("pass on local ip", func(b *testing.B) {
ip := iputil.Ip2VpnIp(net.IPv4(172, 1, 1, 1)) // ip := iputil.Ip2VpnIp(net.IPv4(172, 1, 1, 1))
c := &cert.NebulaCertificate{ // c := &cert.NebulaCertificate{
Details: cert.NebulaCertificateDetails{ // Details: cert.NebulaCertificateDetails{
InvertedGroups: map[string]struct{}{"nope": {}}, // InvertedGroups: map[string]struct{}{"nope": {}},
Name: "good-host", // Name: "good-host",
}, // },
} // }
for n := 0; n < b.N; n++ { // for n := 0; n < b.N; n++ {
ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 10, LocalIP: ip}, true, c, cp) // ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 10, LocalIP: ip}, true, c, cp)
} // }
}) //})
//
_ = ft.TCP.addRule(0, 0, []string{"good-group"}, "good-host", n, n, "", "") //_ = ft.TCP.addRule(0, 0, []string{"good-group"}, "good-host", n, n, "", "")
//
b.Run("pass on ip with any port", func(b *testing.B) { //b.Run("pass on ip with any port", func(b *testing.B) {
ip := iputil.Ip2VpnIp(net.IPv4(172, 1, 1, 1)) // ip := iputil.Ip2VpnIp(net.IPv4(172, 1, 1, 1))
c := &cert.NebulaCertificate{ // c := &cert.NebulaCertificate{
Details: cert.NebulaCertificateDetails{ // Details: cert.NebulaCertificateDetails{
InvertedGroups: map[string]struct{}{"nope": {}}, // InvertedGroups: map[string]struct{}{"nope": {}},
Name: "good-host", // Name: "good-host",
}, // },
} // }
for n := 0; n < b.N; n++ { // for n := 0; n < b.N; n++ {
ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 100, RemoteIP: ip}, true, c, cp) // ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 100, RemoteIP: ip}, true, c, cp)
} // }
}) //})
//
b.Run("pass on local ip with any port", func(b *testing.B) { //b.Run("pass on local ip with any port", func(b *testing.B) {
ip := iputil.Ip2VpnIp(net.IPv4(172, 1, 1, 1)) // ip := iputil.Ip2VpnIp(net.IPv4(172, 1, 1, 1))
c := &cert.NebulaCertificate{ // c := &cert.NebulaCertificate{
Details: cert.NebulaCertificateDetails{ // Details: cert.NebulaCertificateDetails{
InvertedGroups: map[string]struct{}{"nope": {}}, // InvertedGroups: map[string]struct{}{"nope": {}},
Name: "good-host", // Name: "good-host",
}, // },
} // }
for n := 0; n < b.N; n++ { // for n := 0; n < b.N; n++ {
ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 100, LocalIP: ip}, true, c, cp) // ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 100, LocalIP: ip}, true, c, cp)
} // }
}) //})
} }
func TestFirewall_Drop2(t *testing.T) { func TestFirewall_Drop2(t *testing.T) {

14
go.mod
View File

@ -8,7 +8,7 @@ require (
github.com/armon/go-radix v1.0.0 github.com/armon/go-radix v1.0.0
github.com/clarkmcc/go-dag v0.0.0-20220908000337-9c3ba5b365fc github.com/clarkmcc/go-dag v0.0.0-20220908000337-9c3ba5b365fc
github.com/cyberdelia/go-metrics-graphite v0.0.0-20161219230853-39f87cc3b432 github.com/cyberdelia/go-metrics-graphite v0.0.0-20161219230853-39f87cc3b432
github.com/flynn/noise v1.0.1 github.com/flynn/noise v1.1.0
github.com/gogo/protobuf v1.3.2 github.com/gogo/protobuf v1.3.2
github.com/google/gopacket v1.1.19 github.com/google/gopacket v1.1.19
github.com/kardianos/service v1.2.2 github.com/kardianos/service v1.2.2
@ -19,19 +19,19 @@ require (
github.com/sirupsen/logrus v1.9.3 github.com/sirupsen/logrus v1.9.3
github.com/skip2/go-qrcode v0.0.0-20200617195104-da1b6568686e github.com/skip2/go-qrcode v0.0.0-20200617195104-da1b6568686e
github.com/songgao/water v0.0.0-20200317203138-2b4b6d7c09d8 github.com/songgao/water v0.0.0-20200317203138-2b4b6d7c09d8
github.com/stretchr/testify v1.8.4 github.com/stretchr/testify v1.9.0
github.com/timandy/routine v1.1.1 github.com/timandy/routine v1.1.1
github.com/vishvananda/netlink v1.2.1-beta.2 github.com/vishvananda/netlink v1.2.1-beta.2
golang.org/x/crypto v0.18.0 golang.org/x/crypto v0.21.0
golang.org/x/exp v0.0.0-20230425010034-47ecfdc1ba53 golang.org/x/exp v0.0.0-20230425010034-47ecfdc1ba53
golang.org/x/net v0.20.0 golang.org/x/net v0.22.0
golang.org/x/sync v0.6.0 golang.org/x/sync v0.6.0
golang.org/x/sys v0.16.0 golang.org/x/sys v0.18.0
golang.org/x/term v0.16.0 golang.org/x/term v0.18.0
golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2 golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2
golang.zx2c4.com/wireguard v0.0.0-20230325221338-052af4a8072b golang.zx2c4.com/wireguard v0.0.0-20230325221338-052af4a8072b
golang.zx2c4.com/wireguard/windows v0.5.3 golang.zx2c4.com/wireguard/windows v0.5.3
google.golang.org/protobuf v1.32.0 google.golang.org/protobuf v1.33.0
gopkg.in/yaml.v2 v2.4.0 gopkg.in/yaml.v2 v2.4.0
gvisor.dev/gvisor v0.0.0-20230504175454-7b0a1988a28f gvisor.dev/gvisor v0.0.0-20230504175454-7b0a1988a28f
) )

28
go.sum
View File

@ -24,8 +24,8 @@ github.com/cyberdelia/go-metrics-graphite v0.0.0-20161219230853-39f87cc3b432/go.
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/flynn/noise v1.0.1 h1:vPp/jdQLXC6ppsXSj/pM3W1BIJ5FEHE2TulSJBpb43Y= github.com/flynn/noise v1.1.0 h1:KjPQoQCEFdZDiP03phOvGi11+SVVhBG2wOWAorLsstg=
github.com/flynn/noise v1.0.1/go.mod h1:xbMo+0i6+IGbYdJhF31t2eR1BIU0CYc12+BNAKwUTag= github.com/flynn/noise v1.1.0/go.mod h1:xbMo+0i6+IGbYdJhF31t2eR1BIU0CYc12+BNAKwUTag=
github.com/go-kit/kit v0.8.0/go.mod h1:xBxKIO96dXMWWy0MnWVtmwkA9/13aqxPnvrjFYMA2as= github.com/go-kit/kit v0.8.0/go.mod h1:xBxKIO96dXMWWy0MnWVtmwkA9/13aqxPnvrjFYMA2as=
github.com/go-kit/kit v0.9.0/go.mod h1:xBxKIO96dXMWWy0MnWVtmwkA9/13aqxPnvrjFYMA2as= github.com/go-kit/kit v0.9.0/go.mod h1:xBxKIO96dXMWWy0MnWVtmwkA9/13aqxPnvrjFYMA2as=
github.com/go-kit/log v0.1.0/go.mod h1:zbhenjAZHb184qTLMA9ZjW7ThYL0H2mk7Q6pNt4vbaY= github.com/go-kit/log v0.1.0/go.mod h1:zbhenjAZHb184qTLMA9ZjW7ThYL0H2mk7Q6pNt4vbaY=
@ -134,10 +134,10 @@ github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXf
github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI=
github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4= github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4=
github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk=
github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo=
github.com/timandy/routine v1.1.1 h1:6/Z7qLFZj3GrzuRksBFzIG8YGUh8CLhjnnMePBQTrEI= github.com/timandy/routine v1.1.1 h1:6/Z7qLFZj3GrzuRksBFzIG8YGUh8CLhjnnMePBQTrEI=
github.com/timandy/routine v1.1.1/go.mod h1:OZHPOKSvqL/ZvqXFkNZyit0xIVelERptYXdAHH00adQ= github.com/timandy/routine v1.1.1/go.mod h1:OZHPOKSvqL/ZvqXFkNZyit0xIVelERptYXdAHH00adQ=
github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg=
github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
github.com/vishvananda/netlink v1.2.1-beta.2 h1:Llsql0lnQEbHj0I1OuKyp8otXp0r3q0mPkuhwHfStVs= github.com/vishvananda/netlink v1.2.1-beta.2 h1:Llsql0lnQEbHj0I1OuKyp8otXp0r3q0mPkuhwHfStVs=
github.com/vishvananda/netlink v1.2.1-beta.2/go.mod h1:twkDnbuQxJYemMlGd4JFIcuhgX83tXhKS2B/PRMpOho= github.com/vishvananda/netlink v1.2.1-beta.2/go.mod h1:twkDnbuQxJYemMlGd4JFIcuhgX83tXhKS2B/PRMpOho=
github.com/vishvananda/netns v0.0.0-20200728191858-db3c7e526aae/go.mod h1:DD4vA1DwXk04H54A1oHXtwZmA0grkVMdPxx/VGLCah0= github.com/vishvananda/netns v0.0.0-20200728191858-db3c7e526aae/go.mod h1:DD4vA1DwXk04H54A1oHXtwZmA0grkVMdPxx/VGLCah0=
@ -150,8 +150,8 @@ golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACk
golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto=
golang.org/x/crypto v0.0.0-20210322153248-0c34fe9e7dc2/go.mod h1:T9bdIzuCu7OtxOm1hfPfRQxPLYneinmdGuTeoZ9dtd4= golang.org/x/crypto v0.0.0-20210322153248-0c34fe9e7dc2/go.mod h1:T9bdIzuCu7OtxOm1hfPfRQxPLYneinmdGuTeoZ9dtd4=
golang.org/x/crypto v0.18.0 h1:PGVlW0xEltQnzFZ55hkuX5+KLyrMYhHld1YHO4AKcdc= golang.org/x/crypto v0.21.0 h1:X31++rzVUdKhX5sWmSOFZxx8UW/ldWx55cbf08iNAMA=
golang.org/x/crypto v0.18.0/go.mod h1:R0j02AL6hcrfOiy9T4ZYp/rcWeMxM3L6QYxlOuEG1mg= golang.org/x/crypto v0.21.0/go.mod h1:0BP7YvVV9gBbVKyeTG0Gyn+gZm94bibOW5BjDEYAOMs=
golang.org/x/exp v0.0.0-20230425010034-47ecfdc1ba53 h1:5llv2sWeaMSnA3w2kS57ouQQ4pudlXrR0dCgw51QK9o= golang.org/x/exp v0.0.0-20230425010034-47ecfdc1ba53 h1:5llv2sWeaMSnA3w2kS57ouQQ4pudlXrR0dCgw51QK9o=
golang.org/x/exp v0.0.0-20230425010034-47ecfdc1ba53/go.mod h1:V1LtkGg67GoY2N1AnLN78QLrzxkLyJw7RJb1gzOOz9w= golang.org/x/exp v0.0.0-20230425010034-47ecfdc1ba53/go.mod h1:V1LtkGg67GoY2N1AnLN78QLrzxkLyJw7RJb1gzOOz9w=
golang.org/x/lint v0.0.0-20200302205851-738671d3881b/go.mod h1:3xt1FjdF8hUf6vQPIChWIBhFzV8gjjsPE/fR3IyQdNY= golang.org/x/lint v0.0.0-20200302205851-738671d3881b/go.mod h1:3xt1FjdF8hUf6vQPIChWIBhFzV8gjjsPE/fR3IyQdNY=
@ -170,8 +170,8 @@ golang.org/x/net v0.0.0-20200226121028-0de0cce0169b/go.mod h1:z5CRVTTTmAJ677TzLL
golang.org/x/net v0.0.0-20200625001655-4c5254603344/go.mod h1:/O7V0waA8r7cgGh81Ro3o1hOxt32SMVPicZroKQ2sZA= golang.org/x/net v0.0.0-20200625001655-4c5254603344/go.mod h1:/O7V0waA8r7cgGh81Ro3o1hOxt32SMVPicZroKQ2sZA=
golang.org/x/net v0.0.0-20201021035429-f5854403a974/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU= golang.org/x/net v0.0.0-20201021035429-f5854403a974/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU=
golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg= golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg=
golang.org/x/net v0.20.0 h1:aCL9BSgETF1k+blQaYUBx9hJ9LOGP3gAVemcZlf1Kpo= golang.org/x/net v0.22.0 h1:9sGLhx7iRIHEiX0oAJ3MRZMUCElJgy7Br1nO+AMN3Tc=
golang.org/x/net v0.20.0/go.mod h1:z8BVo6PvndSri0LbOE3hAn0apkU+1YvI6E70E9jsnvY= golang.org/x/net v0.22.0/go.mod h1:JKghWKKOSdJwpW2GEx0Ja7fmaKnMsbu+MWVZTokSYmg=
golang.org/x/oauth2 v0.0.0-20190226205417-e64efc72b421/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw= golang.org/x/oauth2 v0.0.0-20190226205417-e64efc72b421/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw=
golang.org/x/sync v0.0.0-20181108010431-42b317875d0f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20181108010431-42b317875d0f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20181221193216-37e7f081c4d4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20181221193216-37e7f081c4d4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
@ -198,11 +198,11 @@ golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7w
golang.org/x/sys v0.0.0-20210124154548-22da62e12c0c/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20210124154548-22da62e12c0c/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20210603081109-ebe580a85c40/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20210603081109-ebe580a85c40/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.16.0 h1:xWw16ngr6ZMtmxDyKyIgsE93KNKz5HKmMa3b8ALHidU= golang.org/x/sys v0.18.0 h1:DBdB3niSjOA/O0blCZBqDefyWNYveAYMNF1Wum0DYQ4=
golang.org/x/sys v0.16.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= golang.org/x/sys v0.18.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
golang.org/x/term v0.16.0 h1:m+B6fahuftsE9qjo0VWp2FW0mB3MTJvR0BaMQrq0pmE= golang.org/x/term v0.18.0 h1:FcHjZXDMxI8mM3nwhX9HlKop4C0YQvCVCdwYl2wOtE8=
golang.org/x/term v0.16.0/go.mod h1:yn7UURbUtPyrVJPGPq404EukNFxcm/foM+bV/bfcDsY= golang.org/x/term v0.18.0/go.mod h1:ILwASektA3OnRv7amZ1xhE/KTR+u50pbXfZ03+6Nx58=
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk= golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk=
golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
@ -233,8 +233,8 @@ google.golang.org/protobuf v1.20.1-0.20200309200217-e05f789c0967/go.mod h1:A+miE
google.golang.org/protobuf v1.21.0/go.mod h1:47Nbq4nVaFHyn7ilMalzfO3qCViNmqZ2kzikPIcrTAo= google.golang.org/protobuf v1.21.0/go.mod h1:47Nbq4nVaFHyn7ilMalzfO3qCViNmqZ2kzikPIcrTAo=
google.golang.org/protobuf v1.23.0/go.mod h1:EGpADcykh3NcUnDUJcl1+ZksZNG86OlYog2l/sGQquU= google.golang.org/protobuf v1.23.0/go.mod h1:EGpADcykh3NcUnDUJcl1+ZksZNG86OlYog2l/sGQquU=
google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp09yW+WbY/TyQbw= google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp09yW+WbY/TyQbw=
google.golang.org/protobuf v1.32.0 h1:pPC6BG5ex8PDFnkbrGU3EixyhKcQ2aDuBS36lqK/C7I= google.golang.org/protobuf v1.33.0 h1:uNO2rsAINq/JlFpSdYEKIZ0uKD/R9cpdv0T+yoGwGmI=
google.golang.org/protobuf v1.32.0/go.mod h1:c6P6GXX6sHbq/GpV6MGZEdwhWPcYBgnhAHhKbcUYpos= google.golang.org/protobuf v1.33.0/go.mod h1:c6P6GXX6sHbq/GpV6MGZEdwhWPcYBgnhAHhKbcUYpos=
gopkg.in/alecthomas/kingpin.v2 v2.2.6/go.mod h1:FMv+mEhP44yOT+4EoQTLFTRgOQ1FBLkstjWtayDeSgw= gopkg.in/alecthomas/kingpin.v2 v2.2.6/go.mod h1:FMv+mEhP44yOT+4EoQTLFTRgOQ1FBLkstjWtayDeSgw=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=

View File

@ -408,7 +408,7 @@ func ixHandshakeStage2(f *Interface, addr *udp.Addr, via *ViaSender, hh *Handsha
hostinfo.remotes = f.lightHouse.QueryCache(vpnIp) hostinfo.remotes = f.lightHouse.QueryCache(vpnIp)
f.l.WithField("blockedUdpAddrs", newHH.hostinfo.remotes.CopyBlockedRemotes()).WithField("vpnIp", vpnIp). f.l.WithField("blockedUdpAddrs", newHH.hostinfo.remotes.CopyBlockedRemotes()).WithField("vpnIp", vpnIp).
WithField("remotes", newHH.hostinfo.remotes.CopyAddrs(f.hostMap.preferredRanges)). WithField("remotes", newHH.hostinfo.remotes.CopyAddrs(f.hostMap.GetPreferredRanges())).
Info("Blocked addresses for handshakes") Info("Blocked addresses for handshakes")
// Swap the packet store to benefit the original intended recipient // Swap the packet store to benefit the original intended recipient

View File

@ -181,7 +181,7 @@ func (hm *HandshakeManager) handleOutbound(vpnIp iputil.VpnIp, lighthouseTrigger
hostinfo := hh.hostinfo hostinfo := hh.hostinfo
// If we are out of time, clean up // If we are out of time, clean up
if hh.counter >= hm.config.retries { if hh.counter >= hm.config.retries {
hh.hostinfo.logger(hm.l).WithField("udpAddrs", hh.hostinfo.remotes.CopyAddrs(hm.mainHostMap.preferredRanges)). hh.hostinfo.logger(hm.l).WithField("udpAddrs", hh.hostinfo.remotes.CopyAddrs(hm.mainHostMap.GetPreferredRanges())).
WithField("initiatorIndex", hh.hostinfo.localIndexId). WithField("initiatorIndex", hh.hostinfo.localIndexId).
WithField("remoteIndex", hh.hostinfo.remoteIndexId). WithField("remoteIndex", hh.hostinfo.remoteIndexId).
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}). WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).
@ -211,7 +211,7 @@ func (hm *HandshakeManager) handleOutbound(vpnIp iputil.VpnIp, lighthouseTrigger
hostinfo.remotes = hm.lightHouse.QueryCache(vpnIp) hostinfo.remotes = hm.lightHouse.QueryCache(vpnIp)
} }
remotes := hostinfo.remotes.CopyAddrs(hm.mainHostMap.preferredRanges) remotes := hostinfo.remotes.CopyAddrs(hm.mainHostMap.GetPreferredRanges())
remotesHaveChanged := !udp.AddrSlice(remotes).Equal(hh.lastRemotes) remotesHaveChanged := !udp.AddrSlice(remotes).Equal(hh.lastRemotes)
// We only care about a lighthouse trigger if we have new remotes to send to. // We only care about a lighthouse trigger if we have new remotes to send to.
@ -235,7 +235,7 @@ func (hm *HandshakeManager) handleOutbound(vpnIp iputil.VpnIp, lighthouseTrigger
// Send the handshake to all known ips, stage 2 takes care of assigning the hostinfo.remote based on the first to reply // Send the handshake to all known ips, stage 2 takes care of assigning the hostinfo.remote based on the first to reply
var sentTo []*udp.Addr var sentTo []*udp.Addr
hostinfo.remotes.ForEach(hm.mainHostMap.preferredRanges, func(addr *udp.Addr, _ bool) { hostinfo.remotes.ForEach(hm.mainHostMap.GetPreferredRanges(), func(addr *udp.Addr, _ bool) {
hm.messageMetrics.Tx(header.Handshake, header.MessageSubType(hostinfo.HandshakePacket[0][1]), 1) hm.messageMetrics.Tx(header.Handshake, header.MessageSubType(hostinfo.HandshakePacket[0][1]), 1)
err := hm.outside.WriteTo(hostinfo.HandshakePacket[0], addr) err := hm.outside.WriteTo(hostinfo.HandshakePacket[0], addr)
if err != nil { if err != nil {
@ -362,7 +362,7 @@ func (hm *HandshakeManager) GetOrHandshake(vpnIp iputil.VpnIp, cacheCb func(*Han
hm.mainHostMap.RUnlock() hm.mainHostMap.RUnlock()
// Do not attempt promotion if you are a lighthouse // Do not attempt promotion if you are a lighthouse
if !hm.lightHouse.amLighthouse { if !hm.lightHouse.amLighthouse {
h.TryPromoteBest(hm.mainHostMap.preferredRanges, hm.f) h.TryPromoteBest(hm.mainHostMap.GetPreferredRanges(), hm.f)
} }
return h, true return h, true
} }
@ -602,7 +602,7 @@ func (hm *HandshakeManager) queryIndex(index uint32) *HandshakeHostInfo {
} }
func (c *HandshakeManager) GetPreferredRanges() []*net.IPNet { func (c *HandshakeManager) GetPreferredRanges() []*net.IPNet {
return c.mainHostMap.preferredRanges return c.mainHostMap.GetPreferredRanges()
} }
func (c *HandshakeManager) ForEachVpnIp(f controlEach) { func (c *HandshakeManager) ForEachVpnIp(f controlEach) {

View File

@ -19,7 +19,9 @@ func Test_NewHandshakeManagerVpnIp(t *testing.T) {
_, localrange, _ := net.ParseCIDR("10.1.1.1/24") _, localrange, _ := net.ParseCIDR("10.1.1.1/24")
ip := iputil.Ip2VpnIp(net.ParseIP("172.1.1.2")) ip := iputil.Ip2VpnIp(net.ParseIP("172.1.1.2"))
preferredRanges := []*net.IPNet{localrange} preferredRanges := []*net.IPNet{localrange}
mainHM := NewHostMap(l, vpncidr, preferredRanges) mainHM := newHostMap(l, vpncidr)
mainHM.preferredRanges.Store(&preferredRanges)
lh := newTestLighthouse() lh := newTestLighthouse()
cs := &CertState{ cs := &CertState{

View File

@ -10,6 +10,7 @@ import (
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
"github.com/slackhq/nebula/cert" "github.com/slackhq/nebula/cert"
"github.com/slackhq/nebula/cidr" "github.com/slackhq/nebula/cidr"
"github.com/slackhq/nebula/config"
"github.com/slackhq/nebula/header" "github.com/slackhq/nebula/header"
"github.com/slackhq/nebula/iputil" "github.com/slackhq/nebula/iputil"
"github.com/slackhq/nebula/udp" "github.com/slackhq/nebula/udp"
@ -56,9 +57,8 @@ type HostMap struct {
Relays map[uint32]*HostInfo // Maps a Relay IDX to a Relay HostInfo object Relays map[uint32]*HostInfo // Maps a Relay IDX to a Relay HostInfo object
RemoteIndexes map[uint32]*HostInfo RemoteIndexes map[uint32]*HostInfo
Hosts map[iputil.VpnIp]*HostInfo Hosts map[iputil.VpnIp]*HostInfo
preferredRanges []*net.IPNet preferredRanges atomic.Pointer[[]*net.IPNet]
vpnCIDR *net.IPNet vpnCIDR *net.IPNet
metricsEnabled bool
l *logrus.Logger l *logrus.Logger
} }
@ -254,22 +254,54 @@ type cachedPacketMetrics struct {
dropped metrics.Counter dropped metrics.Counter
} }
func NewHostMap(l *logrus.Logger, vpnCIDR *net.IPNet, preferredRanges []*net.IPNet) *HostMap { func NewHostMapFromConfig(l *logrus.Logger, vpnCIDR *net.IPNet, c *config.C) *HostMap {
h := map[iputil.VpnIp]*HostInfo{} hm := newHostMap(l, vpnCIDR)
i := map[uint32]*HostInfo{}
r := map[uint32]*HostInfo{} hm.reload(c, true)
relays := map[uint32]*HostInfo{} c.RegisterReloadCallback(func(c *config.C) {
m := HostMap{ hm.reload(c, false)
syncRWMutex: newSyncRWMutex("hostmap"), })
Indexes: i,
Relays: relays, l.WithField("network", hm.vpnCIDR.String()).
RemoteIndexes: r, WithField("preferredRanges", hm.GetPreferredRanges()).
Hosts: h, Info("Main HostMap created")
preferredRanges: preferredRanges,
vpnCIDR: vpnCIDR, return hm
l: l, }
func newHostMap(l *logrus.Logger, vpnCIDR *net.IPNet) *HostMap {
return &HostMap{
syncRWMutex: newSyncRWMutex("hostmap"),
Indexes: map[uint32]*HostInfo{},
Relays: map[uint32]*HostInfo{},
RemoteIndexes: map[uint32]*HostInfo{},
Hosts: map[iputil.VpnIp]*HostInfo{},
vpnCIDR: vpnCIDR,
l: l,
}
}
func (hm *HostMap) reload(c *config.C, initial bool) {
if initial || c.HasChanged("preferred_ranges") {
var preferredRanges []*net.IPNet
rawPreferredRanges := c.GetStringSlice("preferred_ranges", []string{})
for _, rawPreferredRange := range rawPreferredRanges {
_, preferredRange, err := net.ParseCIDR(rawPreferredRange)
if err != nil {
hm.l.WithError(err).WithField("range", rawPreferredRanges).Warn("Failed to parse preferred ranges, ignoring")
continue
}
preferredRanges = append(preferredRanges, preferredRange)
}
oldRanges := hm.preferredRanges.Swap(&preferredRanges)
if !initial {
hm.l.WithField("oldPreferredRanges", *oldRanges).WithField("newPreferredRanges", preferredRanges).Info("preferred_ranges changed")
}
} }
return &m
} }
// EmitStats reports host, index, and relay counts to the stats collection system // EmitStats reports host, index, and relay counts to the stats collection system
@ -458,7 +490,7 @@ func (hm *HostMap) queryVpnIp(vpnIp iputil.VpnIp, promoteIfce *Interface) *HostI
hm.RUnlock() hm.RUnlock()
// Do not attempt promotion if you are a lighthouse // Do not attempt promotion if you are a lighthouse
if promoteIfce != nil && !promoteIfce.lightHouse.amLighthouse { if promoteIfce != nil && !promoteIfce.lightHouse.amLighthouse {
h.TryPromoteBest(hm.preferredRanges, promoteIfce) h.TryPromoteBest(hm.GetPreferredRanges(), promoteIfce)
} }
return h return h
@ -505,7 +537,8 @@ func (hm *HostMap) unlockedAddHostInfo(hostinfo *HostInfo, f *Interface) {
} }
func (hm *HostMap) GetPreferredRanges() []*net.IPNet { func (hm *HostMap) GetPreferredRanges() []*net.IPNet {
return hm.preferredRanges //NOTE: if preferredRanges is ever not stored before a load this will fail to dereference a nil pointer
return *hm.preferredRanges.Load()
} }
func (hm *HostMap) ForEachVpnIp(f controlEach) { func (hm *HostMap) ForEachVpnIp(f controlEach) {
@ -597,7 +630,7 @@ func (i *HostInfo) SetRemoteIfPreferred(hm *HostMap, newRemote *udp.Addr) bool {
// NOTE: We do this loop here instead of calling `isPreferred` in // NOTE: We do this loop here instead of calling `isPreferred` in
// remote_list.go so that we only have to loop over preferredRanges once. // remote_list.go so that we only have to loop over preferredRanges once.
newIsPreferred := false newIsPreferred := false
for _, l := range hm.preferredRanges { for _, l := range hm.GetPreferredRanges() {
// return early if we are already on a preferred remote // return early if we are already on a preferred remote
if l.Contains(currentRemote.IP) { if l.Contains(currentRemote.IP) {
return false return false

View File

@ -4,19 +4,19 @@ import (
"net" "net"
"testing" "testing"
"github.com/slackhq/nebula/config"
"github.com/slackhq/nebula/test" "github.com/slackhq/nebula/test"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
) )
func TestHostMap_MakePrimary(t *testing.T) { func TestHostMap_MakePrimary(t *testing.T) {
l := test.NewLogger() l := test.NewLogger()
hm := NewHostMap( hm := newHostMap(
l, l,
&net.IPNet{ &net.IPNet{
IP: net.IP{10, 0, 0, 1}, IP: net.IP{10, 0, 0, 1},
Mask: net.IPMask{255, 255, 255, 0}, Mask: net.IPMask{255, 255, 255, 0},
}, },
[]*net.IPNet{},
) )
f := &Interface{} f := &Interface{}
@ -91,13 +91,12 @@ func TestHostMap_MakePrimary(t *testing.T) {
func TestHostMap_DeleteHostInfo(t *testing.T) { func TestHostMap_DeleteHostInfo(t *testing.T) {
l := test.NewLogger() l := test.NewLogger()
hm := NewHostMap( hm := newHostMap(
l, l,
&net.IPNet{ &net.IPNet{
IP: net.IP{10, 0, 0, 1}, IP: net.IP{10, 0, 0, 1},
Mask: net.IPMask{255, 255, 255, 0}, Mask: net.IPMask{255, 255, 255, 0},
}, },
[]*net.IPNet{},
) )
f := &Interface{} f := &Interface{}
@ -205,3 +204,33 @@ func TestHostMap_DeleteHostInfo(t *testing.T) {
prim = hm.QueryVpnIp(1) prim = hm.QueryVpnIp(1)
assert.Nil(t, prim) assert.Nil(t, prim)
} }
func TestHostMap_reload(t *testing.T) {
l := test.NewLogger()
c := config.NewC(l)
hm := NewHostMapFromConfig(
l,
&net.IPNet{
IP: net.IP{10, 0, 0, 1},
Mask: net.IPMask{255, 255, 255, 0},
},
c,
)
toS := func(ipn []*net.IPNet) []string {
var s []string
for _, n := range ipn {
s = append(s, n.String())
}
return s
}
assert.Empty(t, hm.GetPreferredRanges())
c.ReloadConfigString("preferred_ranges: [1.1.1.0/24, 10.1.1.0/24]")
assert.EqualValues(t, []string{"1.1.1.0/24", "10.1.1.0/24"}, toS(hm.GetPreferredRanges()))
c.ReloadConfigString("preferred_ranges: [1.1.1.1/32]")
assert.EqualValues(t, []string{"1.1.1.1/32"}, toS(hm.GetPreferredRanges()))
}

47
main.go
View File

@ -183,52 +183,7 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
} }
} }
// Set up my internal host map hostMap := NewHostMapFromConfig(l, tunCidr, c)
var preferredRanges []*net.IPNet
rawPreferredRanges := c.GetStringSlice("preferred_ranges", []string{})
// First, check if 'preferred_ranges' is set and fallback to 'local_range'
if len(rawPreferredRanges) > 0 {
for _, rawPreferredRange := range rawPreferredRanges {
_, preferredRange, err := net.ParseCIDR(rawPreferredRange)
if err != nil {
return nil, util.ContextualizeIfNeeded("Failed to parse preferred ranges", err)
}
preferredRanges = append(preferredRanges, preferredRange)
}
}
// local_range was superseded by preferred_ranges. If it is still present,
// merge the local_range setting into preferred_ranges. We will probably
// deprecate local_range and remove in the future.
rawLocalRange := c.GetString("local_range", "")
if rawLocalRange != "" {
_, localRange, err := net.ParseCIDR(rawLocalRange)
if err != nil {
return nil, util.ContextualizeIfNeeded("Failed to parse local_range", err)
}
// Check if the entry for local_range was already specified in
// preferred_ranges. Don't put it into the slice twice if so.
var found bool
for _, r := range preferredRanges {
if r.String() == localRange.String() {
found = true
break
}
}
if !found {
preferredRanges = append(preferredRanges, localRange)
}
}
hostMap := NewHostMap(l, tunCidr, preferredRanges)
hostMap.metricsEnabled = c.GetBool("stats.message_metrics", false)
l.
WithField("network", hostMap.vpnCIDR.String()).
WithField("preferredRanges", hostMap.preferredRanges).
Info("Main HostMap created")
punchy := NewPunchyFromConfig(l, c) punchy := NewPunchyFromConfig(l, c)
lightHouse, err := NewLightHouseFromConfig(ctx, l, c, tunCidr, udpConns[0], punchy) lightHouse, err := NewLightHouseFromConfig(ctx, l, c, tunCidr, udpConns[0], punchy)
if err != nil { if err != nil {

View File

@ -1,6 +1,7 @@
package overlay package overlay
import ( import (
"bytes"
"fmt" "fmt"
"math" "math"
"net" "net"
@ -21,6 +22,35 @@ type Route struct {
Install bool Install bool
} }
// Equal determines if a route that could be installed in the system route table is equal to another
// Via is ignored since that is only consumed within nebula itself
func (r Route) Equal(t Route) bool {
if !r.Cidr.IP.Equal(t.Cidr.IP) {
return false
}
if !bytes.Equal(r.Cidr.Mask, t.Cidr.Mask) {
return false
}
if r.Metric != t.Metric {
return false
}
if r.MTU != t.MTU {
return false
}
if r.Install != t.Install {
return false
}
return true
}
func (r Route) String() string {
s := r.Cidr.String()
if r.Metric != 0 {
s += fmt.Sprintf(" metric: %v", r.Metric)
}
return s
}
func makeRouteTree(l *logrus.Logger, routes []Route, allowMTU bool) (*cidr.Tree4[iputil.VpnIp], error) { func makeRouteTree(l *logrus.Logger, routes []Route, allowMTU bool) (*cidr.Tree4[iputil.VpnIp], error) {
routeTree := cidr.NewTree4[iputil.VpnIp]() routeTree := cidr.NewTree4[iputil.VpnIp]()
for _, r := range routes { for _, r := range routes {

View File

@ -10,60 +10,63 @@ import (
const DefaultMTU = 1300 const DefaultMTU = 1300
// TODO: We may be able to remove routines
type DeviceFactory func(c *config.C, l *logrus.Logger, tunCidr *net.IPNet, routines int) (Device, error) type DeviceFactory func(c *config.C, l *logrus.Logger, tunCidr *net.IPNet, routines int) (Device, error)
func NewDeviceFromConfig(c *config.C, l *logrus.Logger, tunCidr *net.IPNet, routines int) (Device, error) { func NewDeviceFromConfig(c *config.C, l *logrus.Logger, tunCidr *net.IPNet, routines int) (Device, error) {
routes, err := parseRoutes(c, tunCidr)
if err != nil {
return nil, util.NewContextualError("Could not parse tun.routes", nil, err)
}
unsafeRoutes, err := parseUnsafeRoutes(c, tunCidr)
if err != nil {
return nil, util.NewContextualError("Could not parse tun.unsafe_routes", nil, err)
}
routes = append(routes, unsafeRoutes...)
switch { switch {
case c.GetBool("tun.disabled", false): case c.GetBool("tun.disabled", false):
tun := newDisabledTun(tunCidr, c.GetInt("tun.tx_queue", 500), c.GetBool("stats.message_metrics", false), l) tun := newDisabledTun(tunCidr, c.GetInt("tun.tx_queue", 500), c.GetBool("stats.message_metrics", false), l)
return tun, nil return tun, nil
default: default:
return newTun( return newTun(c, l, tunCidr, routines > 1)
l,
c.GetString("tun.dev", ""),
tunCidr,
c.GetInt("tun.mtu", DefaultMTU),
routes,
c.GetInt("tun.tx_queue", 500),
routines > 1,
c.GetBool("tun.use_system_route_table", false),
)
} }
} }
func NewFdDeviceFromConfig(fd *int) DeviceFactory { func NewFdDeviceFromConfig(fd *int) DeviceFactory {
return func(c *config.C, l *logrus.Logger, tunCidr *net.IPNet, routines int) (Device, error) { return func(c *config.C, l *logrus.Logger, tunCidr *net.IPNet, routines int) (Device, error) {
routes, err := parseRoutes(c, tunCidr) return newTunFromFd(c, l, *fd, tunCidr)
if err != nil {
return nil, util.NewContextualError("Could not parse tun.routes", nil, err)
}
unsafeRoutes, err := parseUnsafeRoutes(c, tunCidr)
if err != nil {
return nil, util.NewContextualError("Could not parse tun.unsafe_routes", nil, err)
}
routes = append(routes, unsafeRoutes...)
return newTunFromFd(
l,
*fd,
tunCidr,
c.GetInt("tun.mtu", DefaultMTU),
routes,
c.GetInt("tun.tx_queue", 500),
c.GetBool("tun.use_system_route_table", false),
)
} }
} }
func getAllRoutesFromConfig(c *config.C, cidr *net.IPNet, initial bool) (bool, []Route, error) {
if !initial && !c.HasChanged("tun.routes") && !c.HasChanged("tun.unsafe_routes") {
return false, nil, nil
}
routes, err := parseRoutes(c, cidr)
if err != nil {
return true, nil, util.NewContextualError("Could not parse tun.routes", nil, err)
}
unsafeRoutes, err := parseUnsafeRoutes(c, cidr)
if err != nil {
return true, nil, util.NewContextualError("Could not parse tun.unsafe_routes", nil, err)
}
routes = append(routes, unsafeRoutes...)
return true, routes, nil
}
// findRemovedRoutes will return all routes that are not present in the newRoutes list and would affect the system route table.
// Via is not used to evaluate since it does not affect the system route table.
func findRemovedRoutes(newRoutes, oldRoutes []Route) []Route {
var removed []Route
has := func(entry Route) bool {
for _, check := range newRoutes {
if check.Equal(entry) {
return true
}
}
return false
}
for _, oldEntry := range oldRoutes {
if !has(oldEntry) {
removed = append(removed, oldEntry)
}
}
return removed
}

View File

@ -8,45 +8,57 @@ import (
"io" "io"
"net" "net"
"os" "os"
"sync/atomic"
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
"github.com/slackhq/nebula/cidr" "github.com/slackhq/nebula/cidr"
"github.com/slackhq/nebula/config"
"github.com/slackhq/nebula/iputil" "github.com/slackhq/nebula/iputil"
"github.com/slackhq/nebula/util"
) )
type tun struct { type tun struct {
io.ReadWriteCloser io.ReadWriteCloser
fd int fd int
cidr *net.IPNet cidr *net.IPNet
routeTree *cidr.Tree4[iputil.VpnIp] Routes atomic.Pointer[[]Route]
routeTree atomic.Pointer[cidr.Tree4[iputil.VpnIp]]
l *logrus.Logger l *logrus.Logger
} }
func newTunFromFd(l *logrus.Logger, deviceFd int, cidr *net.IPNet, _ int, routes []Route, _ int, _ bool) (*tun, error) { func newTunFromFd(c *config.C, l *logrus.Logger, deviceFd int, cidr *net.IPNet) (*tun, error) {
routeTree, err := makeRouteTree(l, routes, false)
if err != nil {
return nil, err
}
// XXX Android returns an fd in non-blocking mode which is necessary for shutdown to work properly. // XXX Android returns an fd in non-blocking mode which is necessary for shutdown to work properly.
// Be sure not to call file.Fd() as it will set the fd to blocking mode. // Be sure not to call file.Fd() as it will set the fd to blocking mode.
file := os.NewFile(uintptr(deviceFd), "/dev/net/tun") file := os.NewFile(uintptr(deviceFd), "/dev/net/tun")
return &tun{ t := &tun{
ReadWriteCloser: file, ReadWriteCloser: file,
fd: deviceFd, fd: deviceFd,
cidr: cidr, cidr: cidr,
l: l, l: l,
routeTree: routeTree, }
}, nil
err := t.reload(c, true)
if err != nil {
return nil, err
}
c.RegisterReloadCallback(func(c *config.C) {
err := t.reload(c, false)
if err != nil {
util.LogWithContextIfNeeded("failed to reload tun device", err, t.l)
}
})
return t, nil
} }
func newTun(_ *logrus.Logger, _ string, _ *net.IPNet, _ int, _ []Route, _ int, _ bool, _ bool) (*tun, error) { func newTun(_ *config.C, _ *logrus.Logger, _ *net.IPNet, _ bool) (*tun, error) {
return nil, fmt.Errorf("newTun not supported in Android") return nil, fmt.Errorf("newTun not supported in Android")
} }
func (t *tun) RouteFor(ip iputil.VpnIp) iputil.VpnIp { func (t *tun) RouteFor(ip iputil.VpnIp) iputil.VpnIp {
_, r := t.routeTree.MostSpecificContains(ip) _, r := t.routeTree.Load().MostSpecificContains(ip)
return r return r
} }
@ -54,6 +66,27 @@ func (t tun) Activate() error {
return nil return nil
} }
func (t *tun) reload(c *config.C, initial bool) error {
change, routes, err := getAllRoutesFromConfig(c, t.cidr, initial)
if err != nil {
return err
}
if !initial && !change {
return nil
}
routeTree, err := makeRouteTree(t.l, routes, false)
if err != nil {
return err
}
// Teach nebula how to handle the routes
t.Routes.Store(&routes)
t.routeTree.Store(routeTree)
return nil
}
func (t *tun) Cidr() *net.IPNet { func (t *tun) Cidr() *net.IPNet {
return t.cidr return t.cidr
} }

View File

@ -9,12 +9,15 @@ import (
"io" "io"
"net" "net"
"os" "os"
"sync/atomic"
"syscall" "syscall"
"unsafe" "unsafe"
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
"github.com/slackhq/nebula/cidr" "github.com/slackhq/nebula/cidr"
"github.com/slackhq/nebula/config"
"github.com/slackhq/nebula/iputil" "github.com/slackhq/nebula/iputil"
"github.com/slackhq/nebula/util"
netroute "golang.org/x/net/route" netroute "golang.org/x/net/route"
"golang.org/x/sys/unix" "golang.org/x/sys/unix"
) )
@ -24,8 +27,9 @@ type tun struct {
Device string Device string
cidr *net.IPNet cidr *net.IPNet
DefaultMTU int DefaultMTU int
Routes []Route Routes atomic.Pointer[[]Route]
routeTree *cidr.Tree4[iputil.VpnIp] routeTree atomic.Pointer[cidr.Tree4[iputil.VpnIp]]
linkAddr *netroute.LinkAddr
l *logrus.Logger l *logrus.Logger
// cache out buffer since we need to prepend 4 bytes for tun metadata // cache out buffer since we need to prepend 4 bytes for tun metadata
@ -69,12 +73,8 @@ type ifreqMTU struct {
pad [8]byte pad [8]byte
} }
func newTun(l *logrus.Logger, name string, cidr *net.IPNet, defaultMTU int, routes []Route, _ int, _ bool, _ bool) (*tun, error) { func newTun(c *config.C, l *logrus.Logger, cidr *net.IPNet, _ bool) (*tun, error) {
routeTree, err := makeRouteTree(l, routes, false) name := c.GetString("tun.dev", "")
if err != nil {
return nil, err
}
ifIndex := -1 ifIndex := -1
if name != "" && name != "utun" { if name != "" && name != "utun" {
_, err := fmt.Sscanf(name, "utun%d", &ifIndex) _, err := fmt.Sscanf(name, "utun%d", &ifIndex)
@ -142,17 +142,27 @@ func newTun(l *logrus.Logger, name string, cidr *net.IPNet, defaultMTU int, rout
file := os.NewFile(uintptr(fd), "") file := os.NewFile(uintptr(fd), "")
tun := &tun{ t := &tun{
ReadWriteCloser: file, ReadWriteCloser: file,
Device: name, Device: name,
cidr: cidr, cidr: cidr,
DefaultMTU: defaultMTU, DefaultMTU: c.GetInt("tun.mtu", DefaultMTU),
Routes: routes,
routeTree: routeTree,
l: l, l: l,
} }
return tun, nil err = t.reload(c, true)
if err != nil {
return nil, err
}
c.RegisterReloadCallback(func(c *config.C) {
err := t.reload(c, false)
if err != nil {
util.LogWithContextIfNeeded("failed to reload tun device", err, t.l)
}
})
return t, nil
} }
func (t *tun) deviceBytes() (o [16]byte) { func (t *tun) deviceBytes() (o [16]byte) {
@ -162,7 +172,7 @@ func (t *tun) deviceBytes() (o [16]byte) {
return return
} }
func newTunFromFd(_ *logrus.Logger, _ int, _ *net.IPNet, _ int, _ []Route, _ int, _ bool) (*tun, error) { func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ *net.IPNet) (*tun, error) {
return nil, fmt.Errorf("newTunFromFd not supported in Darwin") return nil, fmt.Errorf("newTunFromFd not supported in Darwin")
} }
@ -260,6 +270,7 @@ func (t *tun) Activate() error {
if linkAddr == nil { if linkAddr == nil {
return fmt.Errorf("unable to discover link_addr for tun interface") return fmt.Errorf("unable to discover link_addr for tun interface")
} }
t.linkAddr = linkAddr
copy(routeAddr.IP[:], addr[:]) copy(routeAddr.IP[:], addr[:])
copy(maskAddr.IP[:], mask[:]) copy(maskAddr.IP[:], mask[:])
@ -278,33 +289,48 @@ func (t *tun) Activate() error {
} }
// Unsafe path routes // Unsafe path routes
for _, r := range t.Routes { return t.addRoutes(false)
if r.Via == nil || !r.Install { }
// We don't allow route MTUs so only install routes with a via
continue
}
copy(routeAddr.IP[:], r.Cidr.IP.To4()) func (t *tun) reload(c *config.C, initial bool) error {
copy(maskAddr.IP[:], net.IP(r.Cidr.Mask).To4()) change, routes, err := getAllRoutesFromConfig(c, t.cidr, initial)
if err != nil {
return err
}
err = addRoute(routeSock, routeAddr, maskAddr, linkAddr) if !initial && !change {
return nil
}
routeTree, err := makeRouteTree(t.l, routes, false)
if err != nil {
return err
}
// Teach nebula how to handle the routes before establishing them in the system table
oldRoutes := t.Routes.Swap(&routes)
t.routeTree.Store(routeTree)
if !initial {
// Remove first, if the system removes a wanted route hopefully it will be re-added next
err := t.removeRoutes(findRemovedRoutes(routes, *oldRoutes))
if err != nil { if err != nil {
if errors.Is(err, unix.EEXIST) { util.LogWithContextIfNeeded("Failed to remove routes", err, t.l)
t.l.WithField("route", r.Cidr).
Warnf("unable to add unsafe_route, identical route already exists")
} else {
return err
}
} }
// TODO how to set metric // Ensure any routes we actually want are installed
err = t.addRoutes(true)
if err != nil {
// Catch any stray logs
util.LogWithContextIfNeeded("Failed to add routes", err, t.l)
}
} }
return nil return nil
} }
func (t *tun) RouteFor(ip iputil.VpnIp) iputil.VpnIp { func (t *tun) RouteFor(ip iputil.VpnIp) iputil.VpnIp {
ok, r := t.routeTree.MostSpecificContains(ip) ok, r := t.routeTree.Load().MostSpecificContains(ip)
if ok { if ok {
return r return r
} }
@ -340,6 +366,88 @@ func getLinkAddr(name string) (*netroute.LinkAddr, error) {
return nil, nil return nil, nil
} }
func (t *tun) addRoutes(logErrors bool) error {
routeSock, err := unix.Socket(unix.AF_ROUTE, unix.SOCK_RAW, unix.AF_UNSPEC)
if err != nil {
return fmt.Errorf("unable to create AF_ROUTE socket: %v", err)
}
defer func() {
unix.Shutdown(routeSock, unix.SHUT_RDWR)
err := unix.Close(routeSock)
if err != nil {
t.l.WithError(err).Error("failed to close AF_ROUTE socket")
}
}()
routeAddr := &netroute.Inet4Addr{}
maskAddr := &netroute.Inet4Addr{}
routes := *t.Routes.Load()
for _, r := range routes {
if r.Via == nil || !r.Install {
// We don't allow route MTUs so only install routes with a via
continue
}
copy(routeAddr.IP[:], r.Cidr.IP.To4())
copy(maskAddr.IP[:], net.IP(r.Cidr.Mask).To4())
err := addRoute(routeSock, routeAddr, maskAddr, t.linkAddr)
if err != nil {
if errors.Is(err, unix.EEXIST) {
t.l.WithField("route", r.Cidr).
Warnf("unable to add unsafe_route, identical route already exists")
} else {
retErr := util.NewContextualError("Failed to add route", map[string]interface{}{"route": r}, err)
if logErrors {
retErr.Log(t.l)
} else {
return retErr
}
}
} else {
t.l.WithField("route", r).Info("Added route")
}
}
return nil
}
func (t *tun) removeRoutes(routes []Route) error {
routeSock, err := unix.Socket(unix.AF_ROUTE, unix.SOCK_RAW, unix.AF_UNSPEC)
if err != nil {
return fmt.Errorf("unable to create AF_ROUTE socket: %v", err)
}
defer func() {
unix.Shutdown(routeSock, unix.SHUT_RDWR)
err := unix.Close(routeSock)
if err != nil {
t.l.WithError(err).Error("failed to close AF_ROUTE socket")
}
}()
routeAddr := &netroute.Inet4Addr{}
maskAddr := &netroute.Inet4Addr{}
for _, r := range routes {
if !r.Install {
continue
}
copy(routeAddr.IP[:], r.Cidr.IP.To4())
copy(maskAddr.IP[:], net.IP(r.Cidr.Mask).To4())
err := delRoute(routeSock, routeAddr, maskAddr, t.linkAddr)
if err != nil {
t.l.WithError(err).WithField("route", r).Error("Failed to remove route")
} else {
t.l.WithField("route", r).Info("Removed route")
}
}
return nil
}
func addRoute(sock int, addr, mask *netroute.Inet4Addr, link *netroute.LinkAddr) error { func addRoute(sock int, addr, mask *netroute.Inet4Addr, link *netroute.LinkAddr) error {
r := netroute.RouteMessage{ r := netroute.RouteMessage{
Version: unix.RTM_VERSION, Version: unix.RTM_VERSION,
@ -365,6 +473,30 @@ func addRoute(sock int, addr, mask *netroute.Inet4Addr, link *netroute.LinkAddr)
return nil return nil
} }
func delRoute(sock int, addr, mask *netroute.Inet4Addr, link *netroute.LinkAddr) error {
r := netroute.RouteMessage{
Version: unix.RTM_VERSION,
Type: unix.RTM_DELETE,
Seq: 1,
Addrs: []netroute.Addr{
unix.RTAX_DST: addr,
unix.RTAX_GATEWAY: link,
unix.RTAX_NETMASK: mask,
},
}
data, err := r.Marshal()
if err != nil {
return fmt.Errorf("failed to create route.RouteMessage: %w", err)
}
_, err = unix.Write(sock, data[:])
if err != nil {
return fmt.Errorf("failed to write route.RouteMessage to socket: %w", err)
}
return nil
}
func (t *tun) Read(to []byte) (int, error) { func (t *tun) Read(to []byte) (int, error) {
buf := make([]byte, len(to)+4) buf := make([]byte, len(to)+4)

View File

@ -13,12 +13,15 @@ import (
"os" "os"
"os/exec" "os/exec"
"strconv" "strconv"
"sync/atomic"
"syscall" "syscall"
"unsafe" "unsafe"
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
"github.com/slackhq/nebula/cidr" "github.com/slackhq/nebula/cidr"
"github.com/slackhq/nebula/config"
"github.com/slackhq/nebula/iputil" "github.com/slackhq/nebula/iputil"
"github.com/slackhq/nebula/util"
) )
const ( const (
@ -47,8 +50,8 @@ type tun struct {
Device string Device string
cidr *net.IPNet cidr *net.IPNet
MTU int MTU int
Routes []Route Routes atomic.Pointer[[]Route]
routeTree *cidr.Tree4[iputil.VpnIp] routeTree atomic.Pointer[cidr.Tree4[iputil.VpnIp]]
l *logrus.Logger l *logrus.Logger
io.ReadWriteCloser io.ReadWriteCloser
@ -76,14 +79,15 @@ func (t *tun) Close() error {
return nil return nil
} }
func newTunFromFd(_ *logrus.Logger, _ int, _ *net.IPNet, _ int, _ []Route, _ int, _ bool) (*tun, error) { func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ *net.IPNet) (*tun, error) {
return nil, fmt.Errorf("newTunFromFd not supported in FreeBSD") return nil, fmt.Errorf("newTunFromFd not supported in FreeBSD")
} }
func newTun(l *logrus.Logger, deviceName string, cidr *net.IPNet, defaultMTU int, routes []Route, _ int, _ bool, _ bool) (*tun, error) { func newTun(c *config.C, l *logrus.Logger, cidr *net.IPNet, _ bool) (*tun, error) {
// Try to open existing tun device // Try to open existing tun device
var file *os.File var file *os.File
var err error var err error
deviceName := c.GetString("tun.dev", "")
if deviceName != "" { if deviceName != "" {
file, err = os.OpenFile("/dev/"+deviceName, os.O_RDWR, 0) file, err = os.OpenFile("/dev/"+deviceName, os.O_RDWR, 0)
} }
@ -144,47 +148,85 @@ func newTun(l *logrus.Logger, deviceName string, cidr *net.IPNet, defaultMTU int
ioctl(fd, syscall.SIOCSIFNAME, uintptr(unsafe.Pointer(&ifrr))) ioctl(fd, syscall.SIOCSIFNAME, uintptr(unsafe.Pointer(&ifrr)))
} }
routeTree, err := makeRouteTree(l, routes, false) t := &tun{
ReadWriteCloser: file,
Device: deviceName,
cidr: cidr,
MTU: c.GetInt("tun.mtu", DefaultMTU),
l: l,
}
err = t.reload(c, true)
if err != nil { if err != nil {
return nil, err return nil, err
} }
return &tun{ c.RegisterReloadCallback(func(c *config.C) {
ReadWriteCloser: file, err := t.reload(c, false)
Device: deviceName, if err != nil {
cidr: cidr, util.LogWithContextIfNeeded("failed to reload tun device", err, t.l)
MTU: defaultMTU, }
Routes: routes, })
routeTree: routeTree,
l: l, return t, nil
}, nil
} }
func (t *tun) Activate() error { func (t *tun) Activate() error {
var err error var err error
// TODO use syscalls instead of exec.Command // TODO use syscalls instead of exec.Command
t.l.Debug("command: ifconfig", t.Device, t.cidr.String(), t.cidr.IP.String()) cmd := exec.Command("/sbin/ifconfig", t.Device, t.cidr.String(), t.cidr.IP.String())
if err = exec.Command("/sbin/ifconfig", t.Device, t.cidr.String(), t.cidr.IP.String()).Run(); err != nil { t.l.Debug("command: ", cmd.String())
if err = cmd.Run(); err != nil {
return fmt.Errorf("failed to run 'ifconfig': %s", err) return fmt.Errorf("failed to run 'ifconfig': %s", err)
} }
t.l.Debug("command: route", "-n", "add", "-net", t.cidr.String(), "-interface", t.Device)
if err = exec.Command("/sbin/route", "-n", "add", "-net", t.cidr.String(), "-interface", t.Device).Run(); err != nil { cmd = exec.Command("/sbin/route", "-n", "add", "-net", t.cidr.String(), "-interface", t.Device)
t.l.Debug("command: ", cmd.String())
if err = cmd.Run(); err != nil {
return fmt.Errorf("failed to run 'route add': %s", err) return fmt.Errorf("failed to run 'route add': %s", err)
} }
t.l.Debug("command: ifconfig", t.Device, "mtu", strconv.Itoa(t.MTU))
if err = exec.Command("/sbin/ifconfig", t.Device, "mtu", strconv.Itoa(t.MTU)).Run(); err != nil { cmd = exec.Command("/sbin/ifconfig", t.Device, "mtu", strconv.Itoa(t.MTU))
t.l.Debug("command: ", cmd.String())
if err = cmd.Run(); err != nil {
return fmt.Errorf("failed to run 'ifconfig': %s", err) return fmt.Errorf("failed to run 'ifconfig': %s", err)
} }
// Unsafe path routes // Unsafe path routes
for _, r := range t.Routes { return t.addRoutes(false)
if r.Via == nil || !r.Install { }
// We don't allow route MTUs so only install routes with a via
continue func (t *tun) reload(c *config.C, initial bool) error {
change, routes, err := getAllRoutesFromConfig(c, t.cidr, initial)
if err != nil {
return err
}
if !initial && !change {
return nil
}
routeTree, err := makeRouteTree(t.l, routes, false)
if err != nil {
return err
}
// Teach nebula how to handle the routes before establishing them in the system table
oldRoutes := t.Routes.Swap(&routes)
t.routeTree.Store(routeTree)
if !initial {
// Remove first, if the system removes a wanted route hopefully it will be re-added next
err := t.removeRoutes(findRemovedRoutes(routes, *oldRoutes))
if err != nil {
util.LogWithContextIfNeeded("Failed to remove routes", err, t.l)
} }
t.l.Debug("command: route", "-n", "add", "-net", r.Cidr.String(), "-interface", t.Device) // Ensure any routes we actually want are installed
if err = exec.Command("/sbin/route", "-n", "add", "-net", r.Cidr.String(), "-interface", t.Device).Run(); err != nil { err = t.addRoutes(true)
return fmt.Errorf("failed to run 'route add' for unsafe_route %s: %s", r.Cidr.String(), err) if err != nil {
// Catch any stray logs
util.LogWithContextIfNeeded("Failed to add routes", err, t.l)
} }
} }
@ -192,7 +234,7 @@ func (t *tun) Activate() error {
} }
func (t *tun) RouteFor(ip iputil.VpnIp) iputil.VpnIp { func (t *tun) RouteFor(ip iputil.VpnIp) iputil.VpnIp {
_, r := t.routeTree.MostSpecificContains(ip) _, r := t.routeTree.Load().MostSpecificContains(ip)
return r return r
} }
@ -208,6 +250,46 @@ func (t *tun) NewMultiQueueReader() (io.ReadWriteCloser, error) {
return nil, fmt.Errorf("TODO: multiqueue not implemented for freebsd") return nil, fmt.Errorf("TODO: multiqueue not implemented for freebsd")
} }
func (t *tun) addRoutes(logErrors bool) error {
routes := *t.Routes.Load()
for _, r := range routes {
if r.Via == nil || !r.Install {
// We don't allow route MTUs so only install routes with a via
continue
}
cmd := exec.Command("/sbin/route", "-n", "add", "-net", r.Cidr.String(), "-interface", t.Device)
t.l.Debug("command: ", cmd.String())
if err := cmd.Run(); err != nil {
retErr := util.NewContextualError("failed to run 'route add' for unsafe_route", map[string]interface{}{"route": r}, err)
if logErrors {
retErr.Log(t.l)
} else {
return retErr
}
}
}
return nil
}
func (t *tun) removeRoutes(routes []Route) error {
for _, r := range routes {
if !r.Install {
continue
}
cmd := exec.Command("/sbin/route", "-n", "delete", "-net", r.Cidr.String(), "-interface", t.Device)
t.l.Debug("command: ", cmd.String())
if err := cmd.Run(); err != nil {
t.l.WithError(err).WithField("route", r).Error("Failed to remove route")
} else {
t.l.WithField("route", r).Info("Removed route")
}
}
return nil
}
func (t *tun) deviceBytes() (o [16]byte) { func (t *tun) deviceBytes() (o [16]byte) {
for i, c := range t.Device { for i, c := range t.Device {
o[i] = byte(c) o[i] = byte(c)

View File

@ -10,43 +10,78 @@ import (
"net" "net"
"os" "os"
"sync" "sync"
"sync/atomic"
"syscall" "syscall"
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
"github.com/slackhq/nebula/cidr" "github.com/slackhq/nebula/cidr"
"github.com/slackhq/nebula/config"
"github.com/slackhq/nebula/iputil" "github.com/slackhq/nebula/iputil"
"github.com/slackhq/nebula/util"
) )
type tun struct { type tun struct {
io.ReadWriteCloser io.ReadWriteCloser
cidr *net.IPNet cidr *net.IPNet
routeTree *cidr.Tree4[iputil.VpnIp] Routes atomic.Pointer[[]Route]
routeTree atomic.Pointer[cidr.Tree4[iputil.VpnIp]]
l *logrus.Logger
} }
func newTun(_ *logrus.Logger, _ string, _ *net.IPNet, _ int, _ []Route, _ int, _ bool, _ bool) (*tun, error) { func newTun(_ *config.C, _ *logrus.Logger, _ *net.IPNet, _ bool) (*tun, error) {
return nil, fmt.Errorf("newTun not supported in iOS") return nil, fmt.Errorf("newTun not supported in iOS")
} }
func newTunFromFd(l *logrus.Logger, deviceFd int, cidr *net.IPNet, _ int, routes []Route, _ int, _ bool) (*tun, error) { func newTunFromFd(c *config.C, l *logrus.Logger, deviceFd int, cidr *net.IPNet) (*tun, error) {
routeTree, err := makeRouteTree(l, routes, false) file := os.NewFile(uintptr(deviceFd), "/dev/tun")
t := &tun{
cidr: cidr,
ReadWriteCloser: &tunReadCloser{f: file},
l: l,
}
err := t.reload(c, true)
if err != nil { if err != nil {
return nil, err return nil, err
} }
file := os.NewFile(uintptr(deviceFd), "/dev/tun") c.RegisterReloadCallback(func(c *config.C) {
return &tun{ err := t.reload(c, false)
cidr: cidr, if err != nil {
ReadWriteCloser: &tunReadCloser{f: file}, util.LogWithContextIfNeeded("failed to reload tun device", err, t.l)
routeTree: routeTree, }
}, nil })
return t, nil
} }
func (t *tun) Activate() error { func (t *tun) Activate() error {
return nil return nil
} }
func (t *tun) reload(c *config.C, initial bool) error {
change, routes, err := getAllRoutesFromConfig(c, t.cidr, initial)
if err != nil {
return err
}
if !initial && !change {
return nil
}
routeTree, err := makeRouteTree(t.l, routes, false)
if err != nil {
return err
}
// Teach nebula how to handle the routes
t.Routes.Store(&routes)
t.routeTree.Store(routeTree)
return nil
}
func (t *tun) RouteFor(ip iputil.VpnIp) iputil.VpnIp { func (t *tun) RouteFor(ip iputil.VpnIp) iputil.VpnIp {
_, r := t.routeTree.MostSpecificContains(ip) _, r := t.routeTree.Load().MostSpecificContains(ip)
return r return r
} }

View File

@ -15,21 +15,25 @@ import (
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
"github.com/slackhq/nebula/cidr" "github.com/slackhq/nebula/cidr"
"github.com/slackhq/nebula/config"
"github.com/slackhq/nebula/iputil" "github.com/slackhq/nebula/iputil"
"github.com/slackhq/nebula/util"
"github.com/vishvananda/netlink" "github.com/vishvananda/netlink"
"golang.org/x/sys/unix" "golang.org/x/sys/unix"
) )
type tun struct { type tun struct {
io.ReadWriteCloser io.ReadWriteCloser
fd int fd int
Device string Device string
cidr *net.IPNet cidr *net.IPNet
MaxMTU int MaxMTU int
DefaultMTU int DefaultMTU int
TXQueueLen int TXQueueLen int
deviceIndex int
ioctlFd uintptr
Routes []Route Routes atomic.Pointer[[]Route]
routeTree atomic.Pointer[cidr.Tree4[iputil.VpnIp]] routeTree atomic.Pointer[cidr.Tree4[iputil.VpnIp]]
routeChan chan struct{} routeChan chan struct{}
useSystemRoutes bool useSystemRoutes bool
@ -61,30 +65,20 @@ type ifreqQLEN struct {
pad [8]byte pad [8]byte
} }
func newTunFromFd(l *logrus.Logger, deviceFd int, cidr *net.IPNet, defaultMTU int, routes []Route, txQueueLen int, useSystemRoutes bool) (*tun, error) { func newTunFromFd(c *config.C, l *logrus.Logger, deviceFd int, cidr *net.IPNet) (*tun, error) {
routeTree, err := makeRouteTree(l, routes, true) file := os.NewFile(uintptr(deviceFd), "/dev/net/tun")
t, err := newTunGeneric(c, l, file, cidr)
if err != nil { if err != nil {
return nil, err return nil, err
} }
file := os.NewFile(uintptr(deviceFd), "/dev/net/tun") t.Device = "tun0"
t := &tun{
ReadWriteCloser: file,
fd: int(file.Fd()),
Device: "tun0",
cidr: cidr,
DefaultMTU: defaultMTU,
TXQueueLen: txQueueLen,
Routes: routes,
useSystemRoutes: useSystemRoutes,
l: l,
}
t.routeTree.Store(routeTree)
return t, nil return t, nil
} }
func newTun(l *logrus.Logger, deviceName string, cidr *net.IPNet, defaultMTU int, routes []Route, txQueueLen int, multiqueue bool, useSystemRoutes bool) (*tun, error) { func newTun(c *config.C, l *logrus.Logger, cidr *net.IPNet, multiqueue bool) (*tun, error) {
fd, err := unix.Open("/dev/net/tun", os.O_RDWR, 0) fd, err := unix.Open("/dev/net/tun", os.O_RDWR, 0)
if err != nil { if err != nil {
return nil, err return nil, err
@ -95,46 +89,113 @@ func newTun(l *logrus.Logger, deviceName string, cidr *net.IPNet, defaultMTU int
if multiqueue { if multiqueue {
req.Flags |= unix.IFF_MULTI_QUEUE req.Flags |= unix.IFF_MULTI_QUEUE
} }
copy(req.Name[:], deviceName) copy(req.Name[:], c.GetString("tun.dev", ""))
if err = ioctl(uintptr(fd), uintptr(unix.TUNSETIFF), uintptr(unsafe.Pointer(&req))); err != nil { if err = ioctl(uintptr(fd), uintptr(unix.TUNSETIFF), uintptr(unsafe.Pointer(&req))); err != nil {
return nil, err return nil, err
} }
name := strings.Trim(string(req.Name[:]), "\x00") name := strings.Trim(string(req.Name[:]), "\x00")
file := os.NewFile(uintptr(fd), "/dev/net/tun") file := os.NewFile(uintptr(fd), "/dev/net/tun")
t, err := newTunGeneric(c, l, file, cidr)
maxMTU := defaultMTU
for _, r := range routes {
if r.MTU == 0 {
r.MTU = defaultMTU
}
if r.MTU > maxMTU {
maxMTU = r.MTU
}
}
routeTree, err := makeRouteTree(l, routes, true)
if err != nil { if err != nil {
return nil, err return nil, err
} }
t.Device = name
return t, nil
}
func newTunGeneric(c *config.C, l *logrus.Logger, file *os.File, cidr *net.IPNet) (*tun, error) {
t := &tun{ t := &tun{
ReadWriteCloser: file, ReadWriteCloser: file,
fd: int(file.Fd()), fd: int(file.Fd()),
Device: name,
cidr: cidr, cidr: cidr,
MaxMTU: maxMTU, TXQueueLen: c.GetInt("tun.tx_queue", 500),
DefaultMTU: defaultMTU, useSystemRoutes: c.GetBool("tun.use_system_route_table", false),
TXQueueLen: txQueueLen,
Routes: routes,
useSystemRoutes: useSystemRoutes,
l: l, l: l,
} }
t.routeTree.Store(routeTree)
err := t.reload(c, true)
if err != nil {
return nil, err
}
c.RegisterReloadCallback(func(c *config.C) {
err := t.reload(c, false)
if err != nil {
util.LogWithContextIfNeeded("failed to reload tun device", err, t.l)
}
})
return t, nil return t, nil
} }
func (t *tun) reload(c *config.C, initial bool) error {
routeChange, routes, err := getAllRoutesFromConfig(c, t.cidr, initial)
if err != nil {
return err
}
if !initial && !routeChange && !c.HasChanged("tun.mtu") {
return nil
}
routeTree, err := makeRouteTree(t.l, routes, true)
if err != nil {
return err
}
oldDefaultMTU := t.DefaultMTU
oldMaxMTU := t.MaxMTU
newDefaultMTU := c.GetInt("tun.mtu", DefaultMTU)
newMaxMTU := newDefaultMTU
for i, r := range routes {
if r.MTU == 0 {
routes[i].MTU = newDefaultMTU
}
if r.MTU > t.MaxMTU {
newMaxMTU = r.MTU
}
}
t.MaxMTU = newMaxMTU
t.DefaultMTU = newDefaultMTU
// Teach nebula how to handle the routes before establishing them in the system table
oldRoutes := t.Routes.Swap(&routes)
t.routeTree.Store(routeTree)
if !initial {
if oldMaxMTU != newMaxMTU {
t.setMTU()
t.l.Infof("Set max MTU to %v was %v", t.MaxMTU, oldMaxMTU)
}
if oldDefaultMTU != newDefaultMTU {
err := t.setDefaultRoute()
if err != nil {
t.l.Warn(err)
} else {
t.l.Infof("Set default MTU to %v was %v", t.DefaultMTU, oldDefaultMTU)
}
}
// Remove first, if the system removes a wanted route hopefully it will be re-added next
t.removeRoutes(findRemovedRoutes(routes, *oldRoutes))
// Ensure any routes we actually want are installed
err = t.addRoutes(true)
if err != nil {
// This should never be called since addRoutes should log its own errors in a reload condition
util.LogWithContextIfNeeded("Failed to refresh routes", err, t.l)
}
}
return nil
}
func (t *tun) NewMultiQueueReader() (io.ReadWriteCloser, error) { func (t *tun) NewMultiQueueReader() (io.ReadWriteCloser, error) {
fd, err := unix.Open("/dev/net/tun", os.O_RDWR, 0) fd, err := unix.Open("/dev/net/tun", os.O_RDWR, 0)
if err != nil { if err != nil {
@ -208,7 +269,7 @@ func (t *tun) Activate() error {
if err != nil { if err != nil {
return err return err
} }
fd := uintptr(s) t.ioctlFd = uintptr(s)
ifra := ifreqAddr{ ifra := ifreqAddr{
Name: devName, Name: devName,
@ -219,52 +280,76 @@ func (t *tun) Activate() error {
} }
// Set the device ip address // Set the device ip address
if err = ioctl(fd, unix.SIOCSIFADDR, uintptr(unsafe.Pointer(&ifra))); err != nil { if err = ioctl(t.ioctlFd, unix.SIOCSIFADDR, uintptr(unsafe.Pointer(&ifra))); err != nil {
return fmt.Errorf("failed to set tun address: %s", err) return fmt.Errorf("failed to set tun address: %s", err)
} }
// Set the device network // Set the device network
ifra.Addr.Addr = mask ifra.Addr.Addr = mask
if err = ioctl(fd, unix.SIOCSIFNETMASK, uintptr(unsafe.Pointer(&ifra))); err != nil { if err = ioctl(t.ioctlFd, unix.SIOCSIFNETMASK, uintptr(unsafe.Pointer(&ifra))); err != nil {
return fmt.Errorf("failed to set tun netmask: %s", err) return fmt.Errorf("failed to set tun netmask: %s", err)
} }
// Set the device name // Set the device name
ifrf := ifReq{Name: devName} ifrf := ifReq{Name: devName}
if err = ioctl(fd, unix.SIOCGIFFLAGS, uintptr(unsafe.Pointer(&ifrf))); err != nil { if err = ioctl(t.ioctlFd, unix.SIOCGIFFLAGS, uintptr(unsafe.Pointer(&ifrf))); err != nil {
return fmt.Errorf("failed to set tun device name: %s", err) return fmt.Errorf("failed to set tun device name: %s", err)
} }
// Set the MTU on the device // Setup our default MTU
ifm := ifreqMTU{Name: devName, MTU: int32(t.MaxMTU)} t.setMTU()
if err = ioctl(fd, unix.SIOCSIFMTU, uintptr(unsafe.Pointer(&ifm))); err != nil {
// This is currently a non fatal condition because the route table must have the MTU set appropriately as well
t.l.WithError(err).Error("Failed to set tun mtu")
}
// Set the transmit queue length // Set the transmit queue length
ifrq := ifreqQLEN{Name: devName, Value: int32(t.TXQueueLen)} ifrq := ifreqQLEN{Name: devName, Value: int32(t.TXQueueLen)}
if err = ioctl(fd, unix.SIOCSIFTXQLEN, uintptr(unsafe.Pointer(&ifrq))); err != nil { if err = ioctl(t.ioctlFd, unix.SIOCSIFTXQLEN, uintptr(unsafe.Pointer(&ifrq))); err != nil {
// If we can't set the queue length nebula will still work but it may lead to packet loss // If we can't set the queue length nebula will still work but it may lead to packet loss
t.l.WithError(err).Error("Failed to set tun tx queue length") t.l.WithError(err).Error("Failed to set tun tx queue length")
} }
// Bring up the interface // Bring up the interface
ifrf.Flags = ifrf.Flags | unix.IFF_UP ifrf.Flags = ifrf.Flags | unix.IFF_UP
if err = ioctl(fd, unix.SIOCSIFFLAGS, uintptr(unsafe.Pointer(&ifrf))); err != nil { if err = ioctl(t.ioctlFd, unix.SIOCSIFFLAGS, uintptr(unsafe.Pointer(&ifrf))); err != nil {
return fmt.Errorf("failed to bring the tun device up: %s", err) return fmt.Errorf("failed to bring the tun device up: %s", err)
} }
// Set the routes
link, err := netlink.LinkByName(t.Device) link, err := netlink.LinkByName(t.Device)
if err != nil { if err != nil {
return fmt.Errorf("failed to get tun device link: %s", err) return fmt.Errorf("failed to get tun device link: %s", err)
} }
t.deviceIndex = link.Attrs().Index
if err = t.setDefaultRoute(); err != nil {
return err
}
// Set the routes
if err = t.addRoutes(false); err != nil {
return err
}
// Run the interface
ifrf.Flags = ifrf.Flags | unix.IFF_UP | unix.IFF_RUNNING
if err = ioctl(t.ioctlFd, unix.SIOCSIFFLAGS, uintptr(unsafe.Pointer(&ifrf))); err != nil {
return fmt.Errorf("failed to run tun device: %s", err)
}
return nil
}
func (t *tun) setMTU() {
// Set the MTU on the device
ifm := ifreqMTU{Name: t.deviceBytes(), MTU: int32(t.MaxMTU)}
if err := ioctl(t.ioctlFd, unix.SIOCSIFMTU, uintptr(unsafe.Pointer(&ifm))); err != nil {
// This is currently a non fatal condition because the route table must have the MTU set appropriately as well
t.l.WithError(err).Error("Failed to set tun mtu")
}
}
func (t *tun) setDefaultRoute() error {
// Default route // Default route
dr := &net.IPNet{IP: t.cidr.IP.Mask(t.cidr.Mask), Mask: t.cidr.Mask} dr := &net.IPNet{IP: t.cidr.IP.Mask(t.cidr.Mask), Mask: t.cidr.Mask}
nr := netlink.Route{ nr := netlink.Route{
LinkIndex: link.Attrs().Index, LinkIndex: t.deviceIndex,
Dst: dr, Dst: dr,
MTU: t.DefaultMTU, MTU: t.DefaultMTU,
AdvMSS: t.advMSS(Route{}), AdvMSS: t.advMSS(Route{}),
@ -274,19 +359,24 @@ func (t *tun) Activate() error {
Table: unix.RT_TABLE_MAIN, Table: unix.RT_TABLE_MAIN,
Type: unix.RTN_UNICAST, Type: unix.RTN_UNICAST,
} }
err = netlink.RouteReplace(&nr) err := netlink.RouteReplace(&nr)
if err != nil { if err != nil {
return fmt.Errorf("failed to set mtu %v on the default route %v; %v", t.DefaultMTU, dr, err) return fmt.Errorf("failed to set mtu %v on the default route %v; %v", t.DefaultMTU, dr, err)
} }
return nil
}
func (t *tun) addRoutes(logErrors bool) error {
// Path routes // Path routes
for _, r := range t.Routes { routes := *t.Routes.Load()
for _, r := range routes {
if !r.Install { if !r.Install {
continue continue
} }
nr := netlink.Route{ nr := netlink.Route{
LinkIndex: link.Attrs().Index, LinkIndex: t.deviceIndex,
Dst: r.Cidr, Dst: r.Cidr,
MTU: r.MTU, MTU: r.MTU,
AdvMSS: t.advMSS(r), AdvMSS: t.advMSS(r),
@ -297,21 +387,49 @@ func (t *tun) Activate() error {
nr.Priority = r.Metric nr.Priority = r.Metric
} }
err = netlink.RouteAdd(&nr) err := netlink.RouteReplace(&nr)
if err != nil { if err != nil {
return fmt.Errorf("failed to set mtu %v on route %v; %v", r.MTU, r.Cidr, err) retErr := util.NewContextualError("Failed to add route", map[string]interface{}{"route": r}, err)
if logErrors {
retErr.Log(t.l)
} else {
return retErr
}
} else {
t.l.WithField("route", r).Info("Added route")
} }
} }
// Run the interface
ifrf.Flags = ifrf.Flags | unix.IFF_UP | unix.IFF_RUNNING
if err = ioctl(fd, unix.SIOCSIFFLAGS, uintptr(unsafe.Pointer(&ifrf))); err != nil {
return fmt.Errorf("failed to run tun device: %s", err)
}
return nil return nil
} }
func (t *tun) removeRoutes(routes []Route) {
for _, r := range routes {
if !r.Install {
continue
}
nr := netlink.Route{
LinkIndex: t.deviceIndex,
Dst: r.Cidr,
MTU: r.MTU,
AdvMSS: t.advMSS(r),
Scope: unix.RT_SCOPE_LINK,
}
if r.Metric > 0 {
nr.Priority = r.Metric
}
err := netlink.RouteDel(&nr)
if err != nil {
t.l.WithError(err).WithField("route", r).Error("Failed to remove route")
} else {
t.l.WithField("route", r).Info("Removed route")
}
}
}
func (t *tun) Cidr() *net.IPNet { func (t *tun) Cidr() *net.IPNet {
return t.cidr return t.cidr
} }
@ -410,5 +528,9 @@ func (t *tun) Close() error {
t.ReadWriteCloser.Close() t.ReadWriteCloser.Close()
} }
if t.ioctlFd > 0 {
os.NewFile(t.ioctlFd, "ioctlFd").Close()
}
return nil return nil
} }

View File

@ -11,12 +11,15 @@ import (
"os/exec" "os/exec"
"regexp" "regexp"
"strconv" "strconv"
"sync/atomic"
"syscall" "syscall"
"unsafe" "unsafe"
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
"github.com/slackhq/nebula/cidr" "github.com/slackhq/nebula/cidr"
"github.com/slackhq/nebula/config"
"github.com/slackhq/nebula/iputil" "github.com/slackhq/nebula/iputil"
"github.com/slackhq/nebula/util"
) )
type ifreqDestroy struct { type ifreqDestroy struct {
@ -28,8 +31,8 @@ type tun struct {
Device string Device string
cidr *net.IPNet cidr *net.IPNet
MTU int MTU int
Routes []Route Routes atomic.Pointer[[]Route]
routeTree *cidr.Tree4[iputil.VpnIp] routeTree atomic.Pointer[cidr.Tree4[iputil.VpnIp]]
l *logrus.Logger l *logrus.Logger
io.ReadWriteCloser io.ReadWriteCloser
@ -56,43 +59,50 @@ func (t *tun) Close() error {
return nil return nil
} }
func newTunFromFd(_ *logrus.Logger, _ int, _ *net.IPNet, _ int, _ []Route, _ int, _ bool) (*tun, error) { func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ *net.IPNet) (*tun, error) {
return nil, fmt.Errorf("newTunFromFd not supported in NetBSD") return nil, fmt.Errorf("newTunFromFd not supported in NetBSD")
} }
var deviceNameRE = regexp.MustCompile(`^tun[0-9]+$`) var deviceNameRE = regexp.MustCompile(`^tun[0-9]+$`)
func newTun(l *logrus.Logger, deviceName string, cidr *net.IPNet, defaultMTU int, routes []Route, _ int, _ bool, _ bool) (*tun, error) { func newTun(c *config.C, l *logrus.Logger, cidr *net.IPNet, _ bool) (*tun, error) {
// Try to open tun device // Try to open tun device
var file *os.File var file *os.File
var err error var err error
deviceName := c.GetString("tun.dev", "")
if deviceName == "" { if deviceName == "" {
return nil, fmt.Errorf("a device name in the format of /dev/tunN must be specified") return nil, fmt.Errorf("a device name in the format of /dev/tunN must be specified")
} }
if !deviceNameRE.MatchString(deviceName) { if !deviceNameRE.MatchString(deviceName) {
return nil, fmt.Errorf("a device name in the format of /dev/tunN must be specified") return nil, fmt.Errorf("a device name in the format of /dev/tunN must be specified")
} }
file, err = os.OpenFile("/dev/"+deviceName, os.O_RDWR, 0) file, err = os.OpenFile("/dev/"+deviceName, os.O_RDWR, 0)
if err != nil { if err != nil {
return nil, err return nil, err
} }
routeTree, err := makeRouteTree(l, routes, false) t := &tun{
if err != nil {
return nil, err
}
return &tun{
ReadWriteCloser: file, ReadWriteCloser: file,
Device: deviceName, Device: deviceName,
cidr: cidr, cidr: cidr,
MTU: defaultMTU, MTU: c.GetInt("tun.mtu", DefaultMTU),
Routes: routes,
routeTree: routeTree,
l: l, l: l,
}, nil }
err = t.reload(c, true)
if err != nil {
return nil, err
}
c.RegisterReloadCallback(func(c *config.C) {
err := t.reload(c, false)
if err != nil {
util.LogWithContextIfNeeded("failed to reload tun device", err, t.l)
}
})
return t, nil
} }
func (t *tun) Activate() error { func (t *tun) Activate() error {
@ -116,17 +126,42 @@ func (t *tun) Activate() error {
if err = cmd.Run(); err != nil { if err = cmd.Run(); err != nil {
return fmt.Errorf("failed to run 'ifconfig': %s", err) return fmt.Errorf("failed to run 'ifconfig': %s", err)
} }
// Unsafe path routes // Unsafe path routes
for _, r := range t.Routes { return t.addRoutes(false)
if r.Via == nil || !r.Install { }
// We don't allow route MTUs so only install routes with a via
continue func (t *tun) reload(c *config.C, initial bool) error {
change, routes, err := getAllRoutesFromConfig(c, t.cidr, initial)
if err != nil {
return err
}
if !initial && !change {
return nil
}
routeTree, err := makeRouteTree(t.l, routes, false)
if err != nil {
return err
}
// Teach nebula how to handle the routes before establishing them in the system table
oldRoutes := t.Routes.Swap(&routes)
t.routeTree.Store(routeTree)
if !initial {
// Remove first, if the system removes a wanted route hopefully it will be re-added next
err := t.removeRoutes(findRemovedRoutes(routes, *oldRoutes))
if err != nil {
util.LogWithContextIfNeeded("Failed to remove routes", err, t.l)
} }
cmd = exec.Command("/sbin/route", "-n", "add", "-net", r.Cidr.String(), t.cidr.IP.String()) // Ensure any routes we actually want are installed
t.l.Debug("command: ", cmd.String()) err = t.addRoutes(true)
if err = cmd.Run(); err != nil { if err != nil {
return fmt.Errorf("failed to run 'route add' for unsafe_route %s: %s", r.Cidr.String(), err) // Catch any stray logs
util.LogWithContextIfNeeded("Failed to add routes", err, t.l)
} }
} }
@ -134,7 +169,7 @@ func (t *tun) Activate() error {
} }
func (t *tun) RouteFor(ip iputil.VpnIp) iputil.VpnIp { func (t *tun) RouteFor(ip iputil.VpnIp) iputil.VpnIp {
_, r := t.routeTree.MostSpecificContains(ip) _, r := t.routeTree.Load().MostSpecificContains(ip)
return r return r
} }
@ -150,6 +185,46 @@ func (t *tun) NewMultiQueueReader() (io.ReadWriteCloser, error) {
return nil, fmt.Errorf("TODO: multiqueue not implemented for netbsd") return nil, fmt.Errorf("TODO: multiqueue not implemented for netbsd")
} }
func (t *tun) addRoutes(logErrors bool) error {
routes := *t.Routes.Load()
for _, r := range routes {
if r.Via == nil || !r.Install {
// We don't allow route MTUs so only install routes with a via
continue
}
cmd := exec.Command("/sbin/route", "-n", "add", "-net", r.Cidr.String(), t.cidr.IP.String())
t.l.Debug("command: ", cmd.String())
if err := cmd.Run(); err != nil {
retErr := util.NewContextualError("failed to run 'route add' for unsafe_route", map[string]interface{}{"route": r}, err)
if logErrors {
retErr.Log(t.l)
} else {
return retErr
}
}
}
return nil
}
func (t *tun) removeRoutes(routes []Route) error {
for _, r := range routes {
if !r.Install {
continue
}
cmd := exec.Command("/sbin/route", "-n", "delete", "-net", r.Cidr.String(), t.cidr.IP.String())
t.l.Debug("command: ", cmd.String())
if err := cmd.Run(); err != nil {
t.l.WithError(err).WithField("route", r).Error("Failed to remove route")
} else {
t.l.WithField("route", r).Info("Removed route")
}
}
return nil
}
func (t *tun) deviceBytes() (o [16]byte) { func (t *tun) deviceBytes() (o [16]byte) {
for i, c := range t.Device { for i, c := range t.Device {
o[i] = byte(c) o[i] = byte(c)

View File

@ -11,19 +11,22 @@ import (
"os/exec" "os/exec"
"regexp" "regexp"
"strconv" "strconv"
"sync/atomic"
"syscall" "syscall"
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
"github.com/slackhq/nebula/cidr" "github.com/slackhq/nebula/cidr"
"github.com/slackhq/nebula/config"
"github.com/slackhq/nebula/iputil" "github.com/slackhq/nebula/iputil"
"github.com/slackhq/nebula/util"
) )
type tun struct { type tun struct {
Device string Device string
cidr *net.IPNet cidr *net.IPNet
MTU int MTU int
Routes []Route Routes atomic.Pointer[[]Route]
routeTree *cidr.Tree4[iputil.VpnIp] routeTree atomic.Pointer[cidr.Tree4[iputil.VpnIp]]
l *logrus.Logger l *logrus.Logger
io.ReadWriteCloser io.ReadWriteCloser
@ -40,13 +43,14 @@ func (t *tun) Close() error {
return nil return nil
} }
func newTunFromFd(_ *logrus.Logger, _ int, _ *net.IPNet, _ int, _ []Route, _ int, _ bool) (*tun, error) { func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ *net.IPNet) (*tun, error) {
return nil, fmt.Errorf("newTunFromFd not supported in OpenBSD") return nil, fmt.Errorf("newTunFromFd not supported in OpenBSD")
} }
var deviceNameRE = regexp.MustCompile(`^tun[0-9]+$`) var deviceNameRE = regexp.MustCompile(`^tun[0-9]+$`)
func newTun(l *logrus.Logger, deviceName string, cidr *net.IPNet, defaultMTU int, routes []Route, _ int, _ bool, _ bool) (*tun, error) { func newTun(c *config.C, l *logrus.Logger, cidr *net.IPNet, _ bool) (*tun, error) {
deviceName := c.GetString("tun.dev", "")
if deviceName == "" { if deviceName == "" {
return nil, fmt.Errorf("a device name in the format of tunN must be specified") return nil, fmt.Errorf("a device name in the format of tunN must be specified")
} }
@ -60,20 +64,64 @@ func newTun(l *logrus.Logger, deviceName string, cidr *net.IPNet, defaultMTU int
return nil, err return nil, err
} }
routeTree, err := makeRouteTree(l, routes, false) t := &tun{
ReadWriteCloser: file,
Device: deviceName,
cidr: cidr,
MTU: c.GetInt("tun.mtu", DefaultMTU),
l: l,
}
err = t.reload(c, true)
if err != nil { if err != nil {
return nil, err return nil, err
} }
return &tun{ c.RegisterReloadCallback(func(c *config.C) {
ReadWriteCloser: file, err := t.reload(c, false)
Device: deviceName, if err != nil {
cidr: cidr, util.LogWithContextIfNeeded("failed to reload tun device", err, t.l)
MTU: defaultMTU, }
Routes: routes, })
routeTree: routeTree,
l: l, return t, nil
}, nil }
func (t *tun) reload(c *config.C, initial bool) error {
change, routes, err := getAllRoutesFromConfig(c, t.cidr, initial)
if err != nil {
return err
}
if !initial && !change {
return nil
}
routeTree, err := makeRouteTree(t.l, routes, false)
if err != nil {
return err
}
// Teach nebula how to handle the routes before establishing them in the system table
oldRoutes := t.Routes.Swap(&routes)
t.routeTree.Store(routeTree)
if !initial {
// Remove first, if the system removes a wanted route hopefully it will be re-added next
err := t.removeRoutes(findRemovedRoutes(routes, *oldRoutes))
if err != nil {
util.LogWithContextIfNeeded("Failed to remove routes", err, t.l)
}
// Ensure any routes we actually want are installed
err = t.addRoutes(true)
if err != nil {
// Catch any stray logs
util.LogWithContextIfNeeded("Failed to add routes", err, t.l)
}
}
return nil
} }
func (t *tun) Activate() error { func (t *tun) Activate() error {
@ -98,25 +146,52 @@ func (t *tun) Activate() error {
} }
// Unsafe path routes // Unsafe path routes
for _, r := range t.Routes { return t.addRoutes(false)
}
func (t *tun) RouteFor(ip iputil.VpnIp) iputil.VpnIp {
_, r := t.routeTree.Load().MostSpecificContains(ip)
return r
}
func (t *tun) addRoutes(logErrors bool) error {
routes := *t.Routes.Load()
for _, r := range routes {
if r.Via == nil || !r.Install { if r.Via == nil || !r.Install {
// We don't allow route MTUs so only install routes with a via // We don't allow route MTUs so only install routes with a via
continue continue
} }
cmd = exec.Command("/sbin/route", "-n", "add", "-inet", r.Cidr.String(), t.cidr.IP.String()) cmd := exec.Command("/sbin/route", "-n", "add", "-inet", r.Cidr.String(), t.cidr.IP.String())
t.l.Debug("command: ", cmd.String()) t.l.Debug("command: ", cmd.String())
if err = cmd.Run(); err != nil { if err := cmd.Run(); err != nil {
return fmt.Errorf("failed to run 'route add' for unsafe_route %s: %s", r.Cidr.String(), err) retErr := util.NewContextualError("failed to run 'route add' for unsafe_route", map[string]interface{}{"route": r}, err)
if logErrors {
retErr.Log(t.l)
} else {
return retErr
}
} }
} }
return nil return nil
} }
func (t *tun) RouteFor(ip iputil.VpnIp) iputil.VpnIp { func (t *tun) removeRoutes(routes []Route) error {
_, r := t.routeTree.MostSpecificContains(ip) for _, r := range routes {
return r if !r.Install {
continue
}
cmd := exec.Command("/sbin/route", "-n", "delete", "-inet", r.Cidr.String(), t.cidr.IP.String())
t.l.Debug("command: ", cmd.String())
if err := cmd.Run(); err != nil {
t.l.WithError(err).WithField("route", r).Error("Failed to remove route")
} else {
t.l.WithField("route", r).Info("Removed route")
}
}
return nil
} }
func (t *tun) Cidr() *net.IPNet { func (t *tun) Cidr() *net.IPNet {

View File

@ -12,6 +12,7 @@ import (
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
"github.com/slackhq/nebula/cidr" "github.com/slackhq/nebula/cidr"
"github.com/slackhq/nebula/config"
"github.com/slackhq/nebula/iputil" "github.com/slackhq/nebula/iputil"
) )
@ -27,14 +28,18 @@ type TestTun struct {
TxPackets chan []byte // Packets transmitted outside by nebula TxPackets chan []byte // Packets transmitted outside by nebula
} }
func newTun(l *logrus.Logger, deviceName string, cidr *net.IPNet, _ int, routes []Route, _ int, _ bool, _ bool) (*TestTun, error) { func newTun(c *config.C, l *logrus.Logger, cidr *net.IPNet, _ bool) (*TestTun, error) {
_, routes, err := getAllRoutesFromConfig(c, cidr, true)
if err != nil {
return nil, err
}
routeTree, err := makeRouteTree(l, routes, false) routeTree, err := makeRouteTree(l, routes, false)
if err != nil { if err != nil {
return nil, err return nil, err
} }
return &TestTun{ return &TestTun{
Device: deviceName, Device: c.GetString("tun.dev", ""),
cidr: cidr, cidr: cidr,
Routes: routes, Routes: routes,
routeTree: routeTree, routeTree: routeTree,
@ -44,7 +49,7 @@ func newTun(l *logrus.Logger, deviceName string, cidr *net.IPNet, _ int, routes
}, nil }, nil
} }
func newTunFromFd(_ *logrus.Logger, _ int, _ *net.IPNet, _ int, _ []Route, _ int, _ bool) (*TestTun, error) { func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ *net.IPNet) (*TestTun, error) {
return nil, fmt.Errorf("newTunFromFd not supported") return nil, fmt.Errorf("newTunFromFd not supported")
} }

View File

@ -6,10 +6,13 @@ import (
"net" "net"
"os/exec" "os/exec"
"strconv" "strconv"
"sync/atomic"
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
"github.com/slackhq/nebula/cidr" "github.com/slackhq/nebula/cidr"
"github.com/slackhq/nebula/config"
"github.com/slackhq/nebula/iputil" "github.com/slackhq/nebula/iputil"
"github.com/slackhq/nebula/util"
"github.com/songgao/water" "github.com/songgao/water"
) )
@ -17,25 +20,34 @@ type waterTun struct {
Device string Device string
cidr *net.IPNet cidr *net.IPNet
MTU int MTU int
Routes []Route Routes atomic.Pointer[[]Route]
routeTree *cidr.Tree4[iputil.VpnIp] routeTree atomic.Pointer[cidr.Tree4[iputil.VpnIp]]
l *logrus.Logger
f *net.Interface
*water.Interface *water.Interface
} }
func newWaterTun(l *logrus.Logger, cidr *net.IPNet, defaultMTU int, routes []Route) (*waterTun, error) { func newWaterTun(c *config.C, l *logrus.Logger, cidr *net.IPNet, _ bool) (*waterTun, error) {
routeTree, err := makeRouteTree(l, routes, false) // NOTE: You cannot set the deviceName under Windows, so you must check tun.Device after calling .Activate()
t := &waterTun{
cidr: cidr,
MTU: c.GetInt("tun.mtu", DefaultMTU),
l: l,
}
err := t.reload(c, true)
if err != nil { if err != nil {
return nil, err return nil, err
} }
// NOTE: You cannot set the deviceName under Windows, so you must check tun.Device after calling .Activate() c.RegisterReloadCallback(func(c *config.C) {
return &waterTun{ err := t.reload(c, false)
cidr: cidr, if err != nil {
MTU: defaultMTU, util.LogWithContextIfNeeded("failed to reload tun device", err, t.l)
Routes: routes, }
routeTree: routeTree, })
}, nil
return t, nil
} }
func (t *waterTun) Activate() error { func (t *waterTun) Activate() error {
@ -74,30 +86,104 @@ func (t *waterTun) Activate() error {
return fmt.Errorf("failed to run 'netsh' to set MTU: %s", err) return fmt.Errorf("failed to run 'netsh' to set MTU: %s", err)
} }
iface, err := net.InterfaceByName(t.Device) t.f, err = net.InterfaceByName(t.Device)
if err != nil { if err != nil {
return fmt.Errorf("failed to find interface named %s: %v", t.Device, err) return fmt.Errorf("failed to find interface named %s: %v", t.Device, err)
} }
for _, r := range t.Routes { err = t.addRoutes(false)
if r.Via == nil || !r.Install { if err != nil {
// We don't allow route MTUs so only install routes with a via return err
continue }
}
err = exec.Command( return nil
"C:\\Windows\\System32\\route.exe", "add", r.Cidr.String(), r.Via.String(), "IF", strconv.Itoa(iface.Index), "METRIC", strconv.Itoa(r.Metric), }
).Run()
func (t *waterTun) reload(c *config.C, initial bool) error {
change, routes, err := getAllRoutesFromConfig(c, t.cidr, initial)
if err != nil {
return err
}
if !initial && !change {
return nil
}
routeTree, err := makeRouteTree(t.l, routes, false)
if err != nil {
return err
}
// Teach nebula how to handle the routes before establishing them in the system table
oldRoutes := t.Routes.Swap(&routes)
t.routeTree.Store(routeTree)
if !initial {
// Remove first, if the system removes a wanted route hopefully it will be re-added next
t.removeRoutes(findRemovedRoutes(routes, *oldRoutes))
// Ensure any routes we actually want are installed
err = t.addRoutes(true)
if err != nil { if err != nil {
return fmt.Errorf("failed to add the unsafe_route %s: %v", r.Cidr.String(), err) // Catch any stray logs
util.LogWithContextIfNeeded("Failed to set routes", err, t.l)
} else {
for _, r := range findRemovedRoutes(routes, *oldRoutes) {
t.l.WithField("route", r).Info("Removed route")
}
} }
} }
return nil return nil
} }
func (t *waterTun) addRoutes(logErrors bool) error {
// Path routes
routes := *t.Routes.Load()
for _, r := range routes {
if r.Via == nil || !r.Install {
// We don't allow route MTUs so only install routes with a via
continue
}
err := exec.Command(
"C:\\Windows\\System32\\route.exe", "add", r.Cidr.String(), r.Via.String(), "IF", strconv.Itoa(t.f.Index), "METRIC", strconv.Itoa(r.Metric),
).Run()
if err != nil {
retErr := util.NewContextualError("Failed to add route", map[string]interface{}{"route": r}, err)
if logErrors {
retErr.Log(t.l)
} else {
return retErr
}
} else {
t.l.WithField("route", r).Info("Added route")
}
}
return nil
}
func (t *waterTun) removeRoutes(routes []Route) {
for _, r := range routes {
if !r.Install {
continue
}
err := exec.Command(
"C:\\Windows\\System32\\route.exe", "delete", r.Cidr.String(), r.Via.String(), "IF", strconv.Itoa(t.f.Index), "METRIC", strconv.Itoa(r.Metric),
).Run()
if err != nil {
t.l.WithError(err).WithField("route", r).Error("Failed to remove route")
} else {
t.l.WithField("route", r).Info("Removed route")
}
}
}
func (t *waterTun) RouteFor(ip iputil.VpnIp) iputil.VpnIp { func (t *waterTun) RouteFor(ip iputil.VpnIp) iputil.VpnIp {
_, r := t.routeTree.MostSpecificContains(ip) _, r := t.routeTree.Load().MostSpecificContains(ip)
return r return r
} }

View File

@ -12,13 +12,14 @@ import (
"syscall" "syscall"
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
"github.com/slackhq/nebula/config"
) )
func newTunFromFd(_ *logrus.Logger, _ int, _ *net.IPNet, _ int, _ []Route, _ int, _ bool) (Device, error) { func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ *net.IPNet) (Device, error) {
return nil, fmt.Errorf("newTunFromFd not supported in Windows") return nil, fmt.Errorf("newTunFromFd not supported in Windows")
} }
func newTun(l *logrus.Logger, deviceName string, cidr *net.IPNet, defaultMTU int, routes []Route, _ int, _ bool, _ bool) (Device, error) { func newTun(c *config.C, l *logrus.Logger, cidr *net.IPNet, multiqueue bool) (Device, error) {
useWintun := true useWintun := true
if err := checkWinTunExists(); err != nil { if err := checkWinTunExists(); err != nil {
l.WithError(err).Warn("Check Wintun driver failed, fallback to wintap driver") l.WithError(err).Warn("Check Wintun driver failed, fallback to wintap driver")
@ -26,14 +27,14 @@ func newTun(l *logrus.Logger, deviceName string, cidr *net.IPNet, defaultMTU int
} }
if useWintun { if useWintun {
device, err := newWinTun(l, deviceName, cidr, defaultMTU, routes) device, err := newWinTun(c, l, cidr, multiqueue)
if err != nil { if err != nil {
return nil, fmt.Errorf("create Wintun interface failed, %w", err) return nil, fmt.Errorf("create Wintun interface failed, %w", err)
} }
return device, nil return device, nil
} }
device, err := newWaterTun(l, cidr, defaultMTU, routes) device, err := newWaterTun(c, l, cidr, multiqueue)
if err != nil { if err != nil {
return nil, fmt.Errorf("create wintap driver failed, %w", err) return nil, fmt.Errorf("create wintap driver failed, %w", err)
} }

View File

@ -6,11 +6,14 @@ import (
"io" "io"
"net" "net"
"net/netip" "net/netip"
"sync/atomic"
"unsafe" "unsafe"
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
"github.com/slackhq/nebula/cidr" "github.com/slackhq/nebula/cidr"
"github.com/slackhq/nebula/config"
"github.com/slackhq/nebula/iputil" "github.com/slackhq/nebula/iputil"
"github.com/slackhq/nebula/util"
"github.com/slackhq/nebula/wintun" "github.com/slackhq/nebula/wintun"
"golang.org/x/sys/windows" "golang.org/x/sys/windows"
"golang.zx2c4.com/wireguard/windows/tunnel/winipcfg" "golang.zx2c4.com/wireguard/windows/tunnel/winipcfg"
@ -23,8 +26,9 @@ type winTun struct {
cidr *net.IPNet cidr *net.IPNet
prefix netip.Prefix prefix netip.Prefix
MTU int MTU int
Routes []Route Routes atomic.Pointer[[]Route]
routeTree *cidr.Tree4[iputil.VpnIp] routeTree atomic.Pointer[cidr.Tree4[iputil.VpnIp]]
l *logrus.Logger
tun *wintun.NativeTun tun *wintun.NativeTun
} }
@ -48,83 +52,148 @@ func generateGUIDByDeviceName(name string) (*windows.GUID, error) {
return (*windows.GUID)(unsafe.Pointer(&sum[0])), nil return (*windows.GUID)(unsafe.Pointer(&sum[0])), nil
} }
func newWinTun(l *logrus.Logger, deviceName string, cidr *net.IPNet, defaultMTU int, routes []Route) (*winTun, error) { func newWinTun(c *config.C, l *logrus.Logger, cidr *net.IPNet, _ bool) (*winTun, error) {
deviceName := c.GetString("tun.dev", "")
guid, err := generateGUIDByDeviceName(deviceName) guid, err := generateGUIDByDeviceName(deviceName)
if err != nil { if err != nil {
return nil, fmt.Errorf("generate GUID failed: %w", err) return nil, fmt.Errorf("generate GUID failed: %w", err)
} }
var tunDevice wintun.Device
tunDevice, err = wintun.CreateTUNWithRequestedGUID(deviceName, guid, defaultMTU)
if err != nil {
// Windows 10 has an issue with unclean shutdowns not fully cleaning up the wintun device.
// Trying a second time resolves the issue.
l.WithError(err).Debug("Failed to create wintun device, retrying")
tunDevice, err = wintun.CreateTUNWithRequestedGUID(deviceName, guid, defaultMTU)
if err != nil {
return nil, fmt.Errorf("create TUN device failed: %w", err)
}
}
routeTree, err := makeRouteTree(l, routes, false)
if err != nil {
return nil, err
}
prefix, err := iputil.ToNetIpPrefix(*cidr) prefix, err := iputil.ToNetIpPrefix(*cidr)
if err != nil { if err != nil {
return nil, err return nil, err
} }
return &winTun{ t := &winTun{
Device: deviceName, Device: deviceName,
cidr: cidr, cidr: cidr,
prefix: prefix, prefix: prefix,
MTU: defaultMTU, MTU: c.GetInt("tun.mtu", DefaultMTU),
Routes: routes, l: l,
routeTree: routeTree, }
tun: tunDevice.(*wintun.NativeTun), err = t.reload(c, true)
}, nil if err != nil {
return nil, err
}
var tunDevice wintun.Device
tunDevice, err = wintun.CreateTUNWithRequestedGUID(deviceName, guid, t.MTU)
if err != nil {
// Windows 10 has an issue with unclean shutdowns not fully cleaning up the wintun device.
// Trying a second time resolves the issue.
l.WithError(err).Debug("Failed to create wintun device, retrying")
tunDevice, err = wintun.CreateTUNWithRequestedGUID(deviceName, guid, t.MTU)
if err != nil {
return nil, fmt.Errorf("create TUN device failed: %w", err)
}
}
t.tun = tunDevice.(*wintun.NativeTun)
c.RegisterReloadCallback(func(c *config.C) {
err := t.reload(c, false)
if err != nil {
util.LogWithContextIfNeeded("failed to reload tun device", err, t.l)
}
})
return t, nil
}
func (t *winTun) reload(c *config.C, initial bool) error {
change, routes, err := getAllRoutesFromConfig(c, t.cidr, initial)
if err != nil {
return err
}
if !initial && !change {
return nil
}
routeTree, err := makeRouteTree(t.l, routes, false)
if err != nil {
return err
}
// Teach nebula how to handle the routes before establishing them in the system table
oldRoutes := t.Routes.Swap(&routes)
t.routeTree.Store(routeTree)
if !initial {
// Remove first, if the system removes a wanted route hopefully it will be re-added next
err := t.removeRoutes(findRemovedRoutes(routes, *oldRoutes))
if err != nil {
util.LogWithContextIfNeeded("Failed to remove routes", err, t.l)
}
// Ensure any routes we actually want are installed
err = t.addRoutes(true)
if err != nil {
// Catch any stray logs
util.LogWithContextIfNeeded("Failed to add routes", err, t.l)
}
}
return nil
} }
func (t *winTun) Activate() error { func (t *winTun) Activate() error {
luid := winipcfg.LUID(t.tun.LUID()) luid := winipcfg.LUID(t.tun.LUID())
if err := luid.SetIPAddresses([]netip.Prefix{t.prefix}); err != nil { err := luid.SetIPAddresses([]netip.Prefix{t.prefix})
if err != nil {
return fmt.Errorf("failed to set address: %w", err) return fmt.Errorf("failed to set address: %w", err)
} }
foundDefault4 := false err = t.addRoutes(false)
routes := make([]*winipcfg.RouteData, 0, len(t.Routes)+1) if err != nil {
return err
}
for _, r := range t.Routes { return nil
}
func (t *winTun) addRoutes(logErrors bool) error {
luid := winipcfg.LUID(t.tun.LUID())
routes := *t.Routes.Load()
foundDefault4 := false
for _, r := range routes {
if r.Via == nil || !r.Install { if r.Via == nil || !r.Install {
// We don't allow route MTUs so only install routes with a via // We don't allow route MTUs so only install routes with a via
continue continue
} }
prefix, err := iputil.ToNetIpPrefix(*r.Cidr)
if err != nil {
retErr := util.NewContextualError("Failed to parse cidr to netip prefix, ignoring route", map[string]interface{}{"route": r}, err)
if logErrors {
retErr.Log(t.l)
continue
} else {
return retErr
}
}
// Add our unsafe route
err = luid.AddRoute(prefix, r.Via.ToNetIpAddr(), uint32(r.Metric))
if err != nil {
retErr := util.NewContextualError("Failed to add route", map[string]interface{}{"route": r}, err)
if logErrors {
retErr.Log(t.l)
continue
} else {
return retErr
}
} else {
t.l.WithField("route", r).Info("Added route")
}
if !foundDefault4 { if !foundDefault4 {
if ones, bits := r.Cidr.Mask.Size(); ones == 0 && bits != 0 { if ones, bits := r.Cidr.Mask.Size(); ones == 0 && bits != 0 {
foundDefault4 = true foundDefault4 = true
} }
} }
prefix, err := iputil.ToNetIpPrefix(*r.Cidr)
if err != nil {
return err
}
// Add our unsafe route
routes = append(routes, &winipcfg.RouteData{
Destination: prefix,
NextHop: r.Via.ToNetIpAddr(),
Metric: uint32(r.Metric),
})
}
if err := luid.AddRoutes(routes); err != nil {
return fmt.Errorf("failed to add routes: %w", err)
} }
ipif, err := luid.IPInterface(windows.AF_INET) ipif, err := luid.IPInterface(windows.AF_INET)
@ -141,12 +210,35 @@ func (t *winTun) Activate() error {
if err := ipif.Set(); err != nil { if err := ipif.Set(); err != nil {
return fmt.Errorf("failed to set ip interface: %w", err) return fmt.Errorf("failed to set ip interface: %w", err)
} }
return nil
}
func (t *winTun) removeRoutes(routes []Route) error {
luid := winipcfg.LUID(t.tun.LUID())
for _, r := range routes {
if !r.Install {
continue
}
prefix, err := iputil.ToNetIpPrefix(*r.Cidr)
if err != nil {
t.l.WithError(err).WithField("route", r).Info("Failed to convert cidr to netip prefix")
continue
}
err = luid.DeleteRoute(prefix, r.Via.ToNetIpAddr())
if err != nil {
t.l.WithError(err).WithField("route", r).Error("Failed to remove route")
} else {
t.l.WithField("route", r).Info("Removed route")
}
}
return nil return nil
} }
func (t *winTun) RouteFor(ip iputil.VpnIp) iputil.VpnIp { func (t *winTun) RouteFor(ip iputil.VpnIp) iputil.VpnIp {
_, r := t.routeTree.MostSpecificContains(ip) _, r := t.routeTree.Load().MostSpecificContains(ip)
return r return r
} }

2
ssh.go
View File

@ -939,7 +939,7 @@ func sshPrintTunnel(ifce *Interface, fs interface{}, a []string, w sshd.StringWr
enc.SetIndent("", " ") enc.SetIndent("", " ")
} }
return enc.Encode(copyHostInfo(hostInfo, ifce.hostMap.preferredRanges)) return enc.Encode(copyHostInfo(hostInfo, ifce.hostMap.GetPreferredRanges()))
} }
func sshReload(c *config.C, w sshd.StringWriter) error { func sshReload(c *config.C, w sshd.StringWriter) error {

View File

@ -2,6 +2,7 @@ package util
import ( import (
"errors" "errors"
"fmt"
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
) )
@ -40,7 +41,7 @@ func (ce *ContextualError) Error() string {
if ce.RealError == nil { if ce.RealError == nil {
return ce.Context return ce.Context
} }
return ce.RealError.Error() return fmt.Errorf("%s (%v): %w", ce.Context, ce.Fields, ce.RealError).Error()
} }
func (ce *ContextualError) Unwrap() error { func (ce *ContextualError) Unwrap() error {