Merge remote-tracking branch 'origin/master' into mutex-debug

This commit is contained in:
Wade Simmons 2024-04-11 12:15:52 -04:00
commit 0ccfad1a1e
33 changed files with 1550 additions and 640 deletions

View File

@ -14,7 +14,7 @@ body:
- type: input
id: version
attributes:
label: What version of `nebula` are you using?
label: What version of `nebula` are you using? (`nebula -version`)
placeholder: 0.0.0
validations:
required: true
@ -41,10 +41,17 @@ body:
attributes:
label: Logs from affected hosts
description: |
Provide logs from all affected hosts during the time of the issue.
Please provide logs from ALL affected hosts during the time of the issue. If you do not provide logs we will be unable to assist you!
[Learn how to find Nebula logs here.](https://nebula.defined.net/docs/guides/viewing-nebula-logs/)
Improve formatting by using <code>```</code> at the beginning and end of each log block.
value: |
```
```
validations:
required: false
required: true
- type: textarea
id: configs
@ -52,6 +59,11 @@ body:
label: Config files from affected hosts
description: |
Provide config files for all affected hosts.
Improve formatting by using <code>```</code> at the beginning and end of each config file.
value: |
```
```
validations:
required: false
required: true

View File

@ -142,15 +142,22 @@ func (tree *Tree4[T]) MostSpecificContains(ip iputil.VpnIp) (ok bool, value T) {
return ok, value
}
// Match finds the most specific match
// TODO this is exact match
func (tree *Tree4[T]) Match(ip iputil.VpnIp) (ok bool, value T) {
type eachFunc[T any] func(T) bool
// EachContains will call a function, passing the value, for each entry until the function returns true or the search is complete
// The final return value will be true if the provided function returned true
func (tree *Tree4[T]) EachContains(ip iputil.VpnIp, each eachFunc[T]) bool {
bit := startbit
node := tree.root
lastNode := node
for node != nil {
lastNode = node
if node.hasValue {
// If the each func returns true then we can exit the loop
if each(node.value) {
return true
}
}
if ip&bit != 0 {
node = node.right
} else {
@ -160,10 +167,33 @@ func (tree *Tree4[T]) Match(ip iputil.VpnIp) (ok bool, value T) {
bit >>= 1
}
if bit == 0 && lastNode != nil {
value = lastNode.value
ok = true
return false
}
// GetCIDR returns the entry added by the most recent matching AddCIDR call
func (tree *Tree4[T]) GetCIDR(cidr *net.IPNet) (ok bool, value T) {
bit := startbit
node := tree.root
ip := iputil.Ip2VpnIp(cidr.IP)
mask := iputil.Ip2VpnIp(cidr.Mask)
// Find our last ancestor in the tree
for node != nil && bit&mask != 0 {
if ip&bit != 0 {
node = node.right
} else {
node = node.left
}
bit = bit >> 1
}
if bit&mask == 0 && node != nil {
value = node.value
ok = node.hasValue
}
return ok, value
}

View File

@ -115,35 +115,36 @@ func TestCIDRTree_MostSpecificContains(t *testing.T) {
assert.Equal(t, "cool", r)
}
func TestCIDRTree_Match(t *testing.T) {
func TestTree4_GetCIDR(t *testing.T) {
tree := NewTree4[string]()
tree.AddCIDR(Parse("4.1.1.0/32"), "1a")
tree.AddCIDR(Parse("4.1.1.1/32"), "1b")
tree.AddCIDR(Parse("1.0.0.0/8"), "1")
tree.AddCIDR(Parse("2.1.0.0/16"), "2")
tree.AddCIDR(Parse("3.1.1.0/24"), "3")
tree.AddCIDR(Parse("4.1.1.0/24"), "4a")
tree.AddCIDR(Parse("4.1.1.1/32"), "4b")
tree.AddCIDR(Parse("4.1.2.1/32"), "4c")
tree.AddCIDR(Parse("254.0.0.0/4"), "5")
tests := []struct {
Found bool
Result interface{}
IP string
IPNet *net.IPNet
}{
{true, "1a", "4.1.1.0"},
{true, "1b", "4.1.1.1"},
{true, "1", Parse("1.0.0.0/8")},
{true, "2", Parse("2.1.0.0/16")},
{true, "3", Parse("3.1.1.0/24")},
{true, "4a", Parse("4.1.1.0/24")},
{true, "4b", Parse("4.1.1.1/32")},
{true, "4c", Parse("4.1.2.1/32")},
{true, "5", Parse("254.0.0.0/4")},
{false, "", Parse("2.0.0.0/8")},
}
for _, tt := range tests {
ok, r := tree.Match(iputil.Ip2VpnIp(net.ParseIP(tt.IP)))
ok, r := tree.GetCIDR(tt.IPNet)
assert.Equal(t, tt.Found, ok)
assert.Equal(t, tt.Result, r)
}
tree = NewTree4[string]()
tree.AddCIDR(Parse("1.1.1.1/0"), "cool")
ok, r := tree.Contains(iputil.Ip2VpnIp(net.ParseIP("0.0.0.0")))
assert.True(t, ok)
assert.Equal(t, "cool", r)
ok, r = tree.Contains(iputil.Ip2VpnIp(net.ParseIP("255.255.255.255")))
assert.True(t, ok)
assert.Equal(t, "cool", r)
}
func BenchmarkCIDRTree_Contains(b *testing.B) {
@ -167,25 +168,3 @@ func BenchmarkCIDRTree_Contains(b *testing.B) {
}
})
}
func BenchmarkCIDRTree_Match(b *testing.B) {
tree := NewTree4[string]()
tree.AddCIDR(Parse("1.1.0.0/16"), "1")
tree.AddCIDR(Parse("1.2.1.1/32"), "1")
tree.AddCIDR(Parse("192.2.1.1/32"), "1")
tree.AddCIDR(Parse("172.2.1.1/32"), "1")
ip := iputil.Ip2VpnIp(net.ParseIP("1.2.1.1"))
b.Run("found", func(b *testing.B) {
for i := 0; i < b.N; i++ {
tree.Match(ip)
}
})
ip = iputil.Ip2VpnIp(net.ParseIP("1.2.1.255"))
b.Run("not found", func(b *testing.B) {
for i := 0; i < b.N; i++ {
tree.Match(ip)
}
})
}

View File

@ -456,7 +456,7 @@ func (n *connectionManager) sendPunch(hostinfo *HostInfo) {
}
if n.punchy.GetTargetEverything() {
hostinfo.remotes.ForEach(n.hostMap.preferredRanges, func(addr *udp.Addr, preferred bool) {
hostinfo.remotes.ForEach(n.hostMap.GetPreferredRanges(), func(addr *udp.Addr, preferred bool) {
n.metricsTxPunchy.Inc(1)
n.intf.outside.WriteTo([]byte{1}, addr)
})

View File

@ -43,7 +43,9 @@ func Test_NewConnectionManagerTest(t *testing.T) {
preferredRanges := []*net.IPNet{localrange}
// Very incomplete mock objects
hostMap := NewHostMap(l, vpncidr, preferredRanges)
hostMap := newHostMap(l, vpncidr)
hostMap.preferredRanges.Store(&preferredRanges)
cs := &CertState{
RawCertificate: []byte{},
PrivateKey: []byte{},
@ -123,7 +125,9 @@ func Test_NewConnectionManagerTest2(t *testing.T) {
preferredRanges := []*net.IPNet{localrange}
// Very incomplete mock objects
hostMap := NewHostMap(l, vpncidr, preferredRanges)
hostMap := newHostMap(l, vpncidr)
hostMap.preferredRanges.Store(&preferredRanges)
cs := &CertState{
RawCertificate: []byte{},
PrivateKey: []byte{},
@ -210,7 +214,8 @@ func Test_NewConnectionManagerTest_DisconnectInvalid(t *testing.T) {
_, vpncidr, _ := net.ParseCIDR("172.1.1.1/24")
_, localrange, _ := net.ParseCIDR("10.1.1.1/24")
preferredRanges := []*net.IPNet{localrange}
hostMap := NewHostMap(l, vpncidr, preferredRanges)
hostMap := newHostMap(l, vpncidr)
hostMap.preferredRanges.Store(&preferredRanges)
// Generate keys for CA and peer's cert.
pubCA, privCA, _ := ed25519.GenerateKey(rand.Reader)

View File

@ -145,7 +145,7 @@ func (c *Control) GetHostInfoByVpnIp(vpnIp iputil.VpnIp, pending bool) *ControlH
return nil
}
ch := copyHostInfo(h, c.f.hostMap.preferredRanges)
ch := copyHostInfo(h, c.f.hostMap.GetPreferredRanges())
return &ch
}
@ -157,7 +157,7 @@ func (c *Control) SetRemoteForTunnel(vpnIp iputil.VpnIp, addr udp.Addr) *Control
}
hostInfo.SetRemote(addr.Copy())
ch := copyHostInfo(hostInfo, c.f.hostMap.preferredRanges)
ch := copyHostInfo(hostInfo, c.f.hostMap.GetPreferredRanges())
return &ch
}

View File

@ -18,7 +18,9 @@ func TestControl_GetHostInfoByVpnIp(t *testing.T) {
l := test.NewLogger()
// Special care must be taken to re-use all objects provided to the hostmap and certificate in the expectedInfo object
// To properly ensure we are not exposing core memory to the caller
hm := NewHostMap(l, &net.IPNet{}, make([]*net.IPNet, 0))
hm := newHostMap(l, &net.IPNet{})
hm.preferredRanges.Store(&[]*net.IPNet{})
remote1 := udp.NewAddr(net.ParseIP("0.0.0.100"), 4444)
remote2 := udp.NewAddr(net.ParseIP("1:2:3:4:5:6:7:8"), 4444)
ipNet := net.IPNet{

View File

@ -309,6 +309,13 @@ firewall:
outbound_action: drop
inbound_action: drop
# Controls the default value for local_cidr. Default is true, will be deprecated after v1.9 and defaulted to false.
# This setting only affects nebula hosts with subnets encoded in their certificate. A nebula host acting as an
# unsafe router with `default_local_cidr_any: true` will expose their unsafe routes to every inbound rule regardless
# of the actual destination for the packet. Setting this to false requires each inbound rule to contain a `local_cidr`
# if the intention is to allow traffic to flow to an unsafe route.
#default_local_cidr_any: false
conntrack:
tcp_timeout: 12m
udp_timeout: 3m
@ -316,7 +323,7 @@ firewall:
# The firewall is default deny. There is no way to write a deny rule.
# Rules are comprised of a protocol, port, and one or more of host, group, or CIDR
# Logical evaluation is roughly: port AND proto AND (ca_sha OR ca_name) AND (host OR group OR groups OR cidr)
# Logical evaluation is roughly: port AND proto AND (ca_sha OR ca_name) AND (host OR group OR groups OR cidr) AND (local cidr)
# - port: Takes `0` or `any` as any, a single number `80`, a range `200-901`, or `fragment` to match second and further fragments of fragmented packets (since there is no port available).
# code: same as port but makes more sense when talking about ICMP, TODO: this is not currently implemented in a way that works, use `any`
# proto: `any`, `tcp`, `udp`, or `icmp`
@ -325,6 +332,8 @@ firewall:
# groups: Same as group but accepts a list of values. Multiple values are AND'd together and a certificate would have to contain all groups to pass
# cidr: a remote CIDR, `0.0.0.0/0` is any.
# local_cidr: a local CIDR, `0.0.0.0/0` is any. This could be used to filter destinations when using unsafe_routes.
# Default is `any` unless the certificate contains subnets and then the default is the ip issued in the certificate
# if `default_local_cidr_any` is false, otherwise its `any`.
# ca_name: An issuing CA name
# ca_sha: An issuing CA shasum
@ -346,3 +355,10 @@ firewall:
groups:
- laptop
- home
# Expose a subnet (unsafe route) to hosts with the group remote_client
# This example assume you have a subnet of 192.168.100.1/24 or larger encoded in the certificate
- port: 8080
proto: tcp
group: remote_client
local_cidr: 192.168.100.1/24

View File

@ -57,15 +57,18 @@ type Firewall struct {
DefaultTimeout time.Duration //linux: 600s
// Used to ensure we don't emit local packets for ips we don't own
localIps *cidr.Tree4[struct{}]
localIps *cidr.Tree4[struct{}]
assignedCIDR *net.IPNet
hasSubnets bool
rules string
rulesVersion uint16
trackTCPRTT bool
metricTCPRTT metrics.Histogram
incomingMetrics firewallMetrics
outgoingMetrics firewallMetrics
defaultLocalCIDRAny bool
trackTCPRTT bool
metricTCPRTT metrics.Histogram
incomingMetrics firewallMetrics
outgoingMetrics firewallMetrics
l *logrus.Logger
}
@ -83,6 +86,8 @@ type FirewallConntrack struct {
TimerWheel *TimerWheel[firewall.Packet]
}
// FirewallTable is the entry point for a rule, the evaluation order is:
// Proto AND port AND (CA SHA or CA name) AND local CIDR AND (group OR groups OR name OR remote CIDR)
type FirewallTable struct {
TCP firewallPort
UDP firewallPort
@ -106,18 +111,27 @@ type FirewallCA struct {
}
type FirewallRule struct {
// Any makes Hosts, Groups, CIDR and LocalCIDR irrelevant
Any bool
Hosts map[string]struct{}
Groups [][]string
CIDR *cidr.Tree4[struct{}]
LocalCIDR *cidr.Tree4[struct{}]
// Any makes Hosts, Groups, and CIDR irrelevant
Any *firewallLocalCIDR
Hosts map[string]*firewallLocalCIDR
Groups []*firewallGroups
CIDR *cidr.Tree4[*firewallLocalCIDR]
}
type firewallGroups struct {
Groups []string
LocalCIDR *firewallLocalCIDR
}
// Even though ports are uint16, int32 maps are faster for lookup
// Plus we can use `-1` for fragment rules
type firewallPort map[int32]*FirewallCA
type firewallLocalCIDR struct {
Any bool
LocalCIDR *cidr.Tree4[struct{}]
}
// NewFirewall creates a new Firewall object. A TimerWheel is created for you from the provided timeouts.
func NewFirewall(l *logrus.Logger, tcpTimeout, UDPTimeout, defaultTimeout time.Duration, c *cert.NebulaCertificate) *Firewall {
//TODO: error on 0 duration
@ -138,8 +152,15 @@ func NewFirewall(l *logrus.Logger, tcpTimeout, UDPTimeout, defaultTimeout time.D
}
localIps := cidr.NewTree4[struct{}]()
var assignedCIDR *net.IPNet
for _, ip := range c.Details.Ips {
localIps.AddCIDR(&net.IPNet{IP: ip.IP, Mask: net.IPMask{255, 255, 255, 255}}, struct{}{})
ipNet := &net.IPNet{IP: ip.IP, Mask: net.IPMask{255, 255, 255, 255}}
localIps.AddCIDR(ipNet, struct{}{})
if assignedCIDR == nil {
// Only grabbing the first one in the cert since any more than that currently has undefined behavior
assignedCIDR = ipNet
}
}
for _, n := range c.Details.Subnets {
@ -158,6 +179,8 @@ func NewFirewall(l *logrus.Logger, tcpTimeout, UDPTimeout, defaultTimeout time.D
UDPTimeout: UDPTimeout,
DefaultTimeout: defaultTimeout,
localIps: localIps,
assignedCIDR: assignedCIDR,
hasSubnets: len(c.Details.Subnets) > 0,
l: l,
metricTCPRTT: metrics.GetOrRegisterHistogram("network.tcp.rtt", nil, metrics.NewExpDecaySample(1028, 0.015)),
@ -184,6 +207,9 @@ func NewFirewallFromConfig(l *logrus.Logger, nc *cert.NebulaCertificate, c *conf
//TODO: max_connections
)
//TODO: Flip to false after v1.9 release
fw.defaultLocalCIDRAny = c.GetBool("firewall.default_local_cidr_any", true)
inboundAction := c.GetString("firewall.inbound_action", "drop")
switch inboundAction {
case "reject":
@ -270,7 +296,7 @@ func (f *Firewall) AddRule(incoming bool, proto uint8, startPort int32, endPort
return fmt.Errorf("unknown protocol %v", proto)
}
return fp.addRule(startPort, endPort, groups, host, ip, localIp, caName, caSha)
return fp.addRule(f, startPort, endPort, groups, host, ip, localIp, caName, caSha)
}
// GetRuleHash returns a hash representation of all inbound and outbound rules
@ -624,7 +650,7 @@ func (ft *FirewallTable) match(p firewall.Packet, incoming bool, c *cert.NebulaC
return false
}
func (fp firewallPort) addRule(startPort int32, endPort int32, groups []string, host string, ip *net.IPNet, localIp *net.IPNet, caName string, caSha string) error {
func (fp firewallPort) addRule(f *Firewall, startPort int32, endPort int32, groups []string, host string, ip *net.IPNet, localIp *net.IPNet, caName string, caSha string) error {
if startPort > endPort {
return fmt.Errorf("start port was lower than end port")
}
@ -637,7 +663,7 @@ func (fp firewallPort) addRule(startPort int32, endPort int32, groups []string,
}
}
if err := fp[i].addRule(groups, host, ip, localIp, caName, caSha); err != nil {
if err := fp[i].addRule(f, groups, host, ip, localIp, caName, caSha); err != nil {
return err
}
}
@ -668,13 +694,12 @@ func (fp firewallPort) match(p firewall.Packet, incoming bool, c *cert.NebulaCer
return fp[firewall.PortAny].match(p, c, caPool)
}
func (fc *FirewallCA) addRule(groups []string, host string, ip, localIp *net.IPNet, caName, caSha string) error {
func (fc *FirewallCA) addRule(f *Firewall, groups []string, host string, ip, localIp *net.IPNet, caName, caSha string) error {
fr := func() *FirewallRule {
return &FirewallRule{
Hosts: make(map[string]struct{}),
Groups: make([][]string, 0),
CIDR: cidr.NewTree4[struct{}](),
LocalCIDR: cidr.NewTree4[struct{}](),
Hosts: make(map[string]*firewallLocalCIDR),
Groups: make([]*firewallGroups, 0),
CIDR: cidr.NewTree4[*firewallLocalCIDR](),
}
}
@ -683,14 +708,14 @@ func (fc *FirewallCA) addRule(groups []string, host string, ip, localIp *net.IPN
fc.Any = fr()
}
return fc.Any.addRule(groups, host, ip, localIp)
return fc.Any.addRule(f, groups, host, ip, localIp)
}
if caSha != "" {
if _, ok := fc.CAShas[caSha]; !ok {
fc.CAShas[caSha] = fr()
}
err := fc.CAShas[caSha].addRule(groups, host, ip, localIp)
err := fc.CAShas[caSha].addRule(f, groups, host, ip, localIp)
if err != nil {
return err
}
@ -700,7 +725,7 @@ func (fc *FirewallCA) addRule(groups []string, host string, ip, localIp *net.IPN
if _, ok := fc.CANames[caName]; !ok {
fc.CANames[caName] = fr()
}
err := fc.CANames[caName].addRule(groups, host, ip, localIp)
err := fc.CANames[caName].addRule(f, groups, host, ip, localIp)
if err != nil {
return err
}
@ -732,41 +757,63 @@ func (fc *FirewallCA) match(p firewall.Packet, c *cert.NebulaCertificate, caPool
return fc.CANames[s.Details.Name].match(p, c)
}
func (fr *FirewallRule) addRule(groups []string, host string, ip *net.IPNet, localIp *net.IPNet) error {
if fr.Any {
return nil
func (fr *FirewallRule) addRule(f *Firewall, groups []string, host string, ip *net.IPNet, localCIDR *net.IPNet) error {
flc := func() *firewallLocalCIDR {
return &firewallLocalCIDR{
LocalCIDR: cidr.NewTree4[struct{}](),
}
}
if fr.isAny(groups, host, ip, localIp) {
fr.Any = true
// If it's any we need to wipe out any pre-existing rules to save on memory
fr.Groups = make([][]string, 0)
fr.Hosts = make(map[string]struct{})
fr.CIDR = cidr.NewTree4[struct{}]()
fr.LocalCIDR = cidr.NewTree4[struct{}]()
} else {
if len(groups) > 0 {
fr.Groups = append(fr.Groups, groups)
if fr.isAny(groups, host, ip) {
if fr.Any == nil {
fr.Any = flc()
}
if host != "" {
fr.Hosts[host] = struct{}{}
return fr.Any.addRule(f, localCIDR)
}
if len(groups) > 0 {
nlc := flc()
err := nlc.addRule(f, localCIDR)
if err != nil {
return err
}
if ip != nil {
fr.CIDR.AddCIDR(ip, struct{}{})
}
fr.Groups = append(fr.Groups, &firewallGroups{
Groups: groups,
LocalCIDR: nlc,
})
}
if localIp != nil {
fr.LocalCIDR.AddCIDR(localIp, struct{}{})
if host != "" {
nlc := fr.Hosts[host]
if nlc == nil {
nlc = flc()
}
err := nlc.addRule(f, localCIDR)
if err != nil {
return err
}
fr.Hosts[host] = nlc
}
if ip != nil {
_, nlc := fr.CIDR.GetCIDR(ip)
if nlc == nil {
nlc = flc()
}
err := nlc.addRule(f, localCIDR)
if err != nil {
return err
}
fr.CIDR.AddCIDR(ip, nlc)
}
return nil
}
func (fr *FirewallRule) isAny(groups []string, host string, ip, localIp *net.IPNet) bool {
if len(groups) == 0 && host == "" && ip == nil && localIp == nil {
func (fr *FirewallRule) isAny(groups []string, host string, ip *net.IPNet) bool {
if len(groups) == 0 && host == "" && ip == nil {
return true
}
@ -784,10 +831,6 @@ func (fr *FirewallRule) isAny(groups []string, host string, ip, localIp *net.IPN
return true
}
if localIp != nil && localIp.Contains(net.IPv4(0, 0, 0, 0)) {
return true
}
return false
}
@ -797,7 +840,7 @@ func (fr *FirewallRule) match(p firewall.Packet, c *cert.NebulaCertificate) bool
}
// Shortcut path for if groups, hosts, or cidr contained an `any`
if fr.Any {
if fr.Any.match(p, c) {
return true
}
@ -805,7 +848,7 @@ func (fr *FirewallRule) match(p firewall.Packet, c *cert.NebulaCertificate) bool
for _, sg := range fr.Groups {
found := false
for _, g := range sg {
for _, g := range sg.Groups {
if _, ok := c.Details.InvertedGroups[g]; !ok {
found = false
break
@ -814,33 +857,51 @@ func (fr *FirewallRule) match(p firewall.Packet, c *cert.NebulaCertificate) bool
found = true
}
if found {
if found && sg.LocalCIDR.match(p, c) {
return true
}
}
if fr.Hosts != nil {
if _, ok := fr.Hosts[c.Details.Name]; ok {
return true
if flc, ok := fr.Hosts[c.Details.Name]; ok {
if flc.match(p, c) {
return true
}
}
}
if fr.CIDR != nil {
ok, _ := fr.CIDR.Contains(p.RemoteIP)
if ok {
return true
return fr.CIDR.EachContains(p.RemoteIP, func(flc *firewallLocalCIDR) bool {
return flc.match(p, c)
})
}
func (flc *firewallLocalCIDR) addRule(f *Firewall, localIp *net.IPNet) error {
if localIp == nil {
if !f.hasSubnets || f.defaultLocalCIDRAny {
flc.Any = true
return nil
}
localIp = f.assignedCIDR
} else if localIp.Contains(net.IPv4(0, 0, 0, 0)) {
flc.Any = true
}
if fr.LocalCIDR != nil {
ok, _ := fr.LocalCIDR.Contains(p.LocalIP)
if ok {
return true
}
flc.LocalCIDR.AddCIDR(localIp, struct{}{})
return nil
}
func (flc *firewallLocalCIDR) match(p firewall.Packet, c *cert.NebulaCertificate) bool {
if flc == nil {
return false
}
// No host, group, or cidr matched, bye bye
return false
if flc.Any {
return true
}
ok, _ := flc.LocalCIDR.Contains(p.LocalIP)
return ok
}
type rule struct {

View File

@ -71,36 +71,32 @@ func TestFirewall_AddRule(t *testing.T) {
assert.Nil(t, fw.AddRule(true, firewall.ProtoTCP, 1, 1, []string{}, "", nil, nil, "", ""))
// An empty rule is any
assert.True(t, fw.InRules.TCP[1].Any.Any)
assert.True(t, fw.InRules.TCP[1].Any.Any.Any)
assert.Empty(t, fw.InRules.TCP[1].Any.Groups)
assert.Empty(t, fw.InRules.TCP[1].Any.Hosts)
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
assert.Nil(t, fw.AddRule(true, firewall.ProtoUDP, 1, 1, []string{"g1"}, "", nil, nil, "", ""))
assert.False(t, fw.InRules.UDP[1].Any.Any)
assert.Contains(t, fw.InRules.UDP[1].Any.Groups[0], "g1")
assert.Nil(t, fw.InRules.UDP[1].Any.Any)
assert.Contains(t, fw.InRules.UDP[1].Any.Groups[0].Groups, "g1")
assert.Empty(t, fw.InRules.UDP[1].Any.Hosts)
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
assert.Nil(t, fw.AddRule(true, firewall.ProtoICMP, 1, 1, []string{}, "h1", nil, nil, "", ""))
assert.False(t, fw.InRules.ICMP[1].Any.Any)
assert.Nil(t, fw.InRules.ICMP[1].Any.Any)
assert.Empty(t, fw.InRules.ICMP[1].Any.Groups)
assert.Contains(t, fw.InRules.ICMP[1].Any.Hosts, "h1")
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
assert.Nil(t, fw.AddRule(false, firewall.ProtoAny, 1, 1, []string{}, "", ti, nil, "", ""))
assert.False(t, fw.OutRules.AnyProto[1].Any.Any)
assert.Empty(t, fw.OutRules.AnyProto[1].Any.Groups)
assert.Empty(t, fw.OutRules.AnyProto[1].Any.Hosts)
ok, _ := fw.OutRules.AnyProto[1].Any.CIDR.Match(iputil.Ip2VpnIp(ti.IP))
assert.Nil(t, fw.OutRules.AnyProto[1].Any.Any)
ok, _ := fw.OutRules.AnyProto[1].Any.CIDR.GetCIDR(ti)
assert.True(t, ok)
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
assert.Nil(t, fw.AddRule(false, firewall.ProtoAny, 1, 1, []string{}, "", nil, ti, "", ""))
assert.False(t, fw.OutRules.AnyProto[1].Any.Any)
assert.Empty(t, fw.OutRules.AnyProto[1].Any.Groups)
assert.Empty(t, fw.OutRules.AnyProto[1].Any.Hosts)
ok, _ = fw.OutRules.AnyProto[1].Any.LocalCIDR.Match(iputil.Ip2VpnIp(ti.IP))
assert.NotNil(t, fw.OutRules.AnyProto[1].Any.Any)
ok, _ = fw.OutRules.AnyProto[1].Any.Any.LocalCIDR.GetCIDR(ti)
assert.True(t, ok)
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
@ -111,32 +107,14 @@ func TestFirewall_AddRule(t *testing.T) {
assert.Nil(t, fw.AddRule(true, firewall.ProtoUDP, 1, 1, []string{"g1"}, "", nil, nil, "", "ca-sha"))
assert.Contains(t, fw.InRules.UDP[1].CAShas, "ca-sha")
// Set any and clear fields
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
assert.Nil(t, fw.AddRule(false, firewall.ProtoAny, 0, 0, []string{"g1", "g2"}, "h1", ti, ti, "", ""))
assert.Equal(t, []string{"g1", "g2"}, fw.OutRules.AnyProto[0].Any.Groups[0])
assert.Contains(t, fw.OutRules.AnyProto[0].Any.Hosts, "h1")
ok, _ = fw.OutRules.AnyProto[0].Any.CIDR.Match(iputil.Ip2VpnIp(ti.IP))
assert.True(t, ok)
ok, _ = fw.OutRules.AnyProto[0].Any.LocalCIDR.Match(iputil.Ip2VpnIp(ti.IP))
assert.True(t, ok)
// run twice just to make sure
//TODO: these ANY rules should clear the CA firewall portion
assert.Nil(t, fw.AddRule(false, firewall.ProtoAny, 0, 0, []string{"any"}, "", nil, nil, "", ""))
assert.Nil(t, fw.AddRule(false, firewall.ProtoAny, 0, 0, []string{}, "any", nil, nil, "", ""))
assert.True(t, fw.OutRules.AnyProto[0].Any.Any)
assert.Empty(t, fw.OutRules.AnyProto[0].Any.Groups)
assert.Empty(t, fw.OutRules.AnyProto[0].Any.Hosts)
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
assert.Nil(t, fw.AddRule(false, firewall.ProtoAny, 0, 0, []string{}, "any", nil, nil, "", ""))
assert.True(t, fw.OutRules.AnyProto[0].Any.Any)
assert.True(t, fw.OutRules.AnyProto[0].Any.Any.Any)
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
_, anyIp, _ := net.ParseCIDR("0.0.0.0/0")
assert.Nil(t, fw.AddRule(false, firewall.ProtoAny, 0, 0, []string{}, "", anyIp, nil, "", ""))
assert.True(t, fw.OutRules.AnyProto[0].Any.Any)
assert.True(t, fw.OutRules.AnyProto[0].Any.Any.Any)
// Test error conditions
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
@ -226,33 +204,43 @@ func TestFirewall_Drop(t *testing.T) {
}
func BenchmarkFirewallTable_match(b *testing.B) {
f := &Firewall{}
ft := FirewallTable{
TCP: firewallPort{},
}
_, n, _ := net.ParseCIDR("172.1.1.1/32")
_ = ft.TCP.addRule(10, 10, []string{"good-group"}, "good-host", n, n, "", "")
_ = ft.TCP.addRule(10, 10, []string{"good-group2"}, "good-host", n, n, "", "")
_ = ft.TCP.addRule(10, 10, []string{"good-group3"}, "good-host", n, n, "", "")
_ = ft.TCP.addRule(10, 10, []string{"good-group4"}, "good-host", n, n, "", "")
_ = ft.TCP.addRule(10, 10, []string{"good-group, good-group1"}, "good-host", n, n, "", "")
goodLocalCIDRIP := iputil.Ip2VpnIp(n.IP)
_ = ft.TCP.addRule(f, 10, 10, []string{"good-group"}, "good-host", n, nil, "", "")
_ = ft.TCP.addRule(f, 100, 100, []string{"good-group"}, "good-host", nil, n, "", "")
cp := cert.NewCAPool()
b.Run("fail on proto", func(b *testing.B) {
// This benchmark is showing us the cost of failing to match the protocol
c := &cert.NebulaCertificate{}
for n := 0; n < b.N; n++ {
ft.match(firewall.Packet{Protocol: firewall.ProtoUDP}, true, c, cp)
assert.False(b, ft.match(firewall.Packet{Protocol: firewall.ProtoUDP}, true, c, cp))
}
})
b.Run("fail on port", func(b *testing.B) {
b.Run("pass proto, fail on port", func(b *testing.B) {
// This benchmark is showing us the cost of matching a specific protocol but failing to match the port
c := &cert.NebulaCertificate{}
for n := 0; n < b.N; n++ {
ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 1}, true, c, cp)
assert.False(b, ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 1}, true, c, cp))
}
})
b.Run("fail all group, name, and cidr", func(b *testing.B) {
b.Run("pass proto, port, fail on local CIDR", func(b *testing.B) {
c := &cert.NebulaCertificate{}
ip, _, _ := net.ParseCIDR("9.254.254.254/32")
lip := iputil.Ip2VpnIp(ip)
for n := 0; n < b.N; n++ {
assert.False(b, ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 100, LocalIP: lip}, true, c, cp))
}
})
b.Run("pass proto, port, any local CIDR, fail all group, name, and cidr", func(b *testing.B) {
_, ip, _ := net.ParseCIDR("9.254.254.254/32")
c := &cert.NebulaCertificate{
Details: cert.NebulaCertificateDetails{
@ -262,11 +250,25 @@ func BenchmarkFirewallTable_match(b *testing.B) {
},
}
for n := 0; n < b.N; n++ {
ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 10}, true, c, cp)
assert.False(b, ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 10}, true, c, cp))
}
})
b.Run("pass on group", func(b *testing.B) {
b.Run("pass proto, port, specific local CIDR, fail all group, name, and cidr", func(b *testing.B) {
_, ip, _ := net.ParseCIDR("9.254.254.254/32")
c := &cert.NebulaCertificate{
Details: cert.NebulaCertificateDetails{
InvertedGroups: map[string]struct{}{"nope": {}},
Name: "nope",
Ips: []*net.IPNet{ip},
},
}
for n := 0; n < b.N; n++ {
assert.False(b, ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 100, LocalIP: goodLocalCIDRIP}, true, c, cp))
}
})
b.Run("pass on group on any local cidr", func(b *testing.B) {
c := &cert.NebulaCertificate{
Details: cert.NebulaCertificateDetails{
InvertedGroups: map[string]struct{}{"good-group": {}},
@ -274,7 +276,19 @@ func BenchmarkFirewallTable_match(b *testing.B) {
},
}
for n := 0; n < b.N; n++ {
ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 10}, true, c, cp)
assert.True(b, ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 10}, true, c, cp))
}
})
b.Run("pass on group on specific local cidr", func(b *testing.B) {
c := &cert.NebulaCertificate{
Details: cert.NebulaCertificateDetails{
InvertedGroups: map[string]struct{}{"good-group": {}},
Name: "nope",
},
}
for n := 0; n < b.N; n++ {
assert.True(b, ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 100, LocalIP: goodLocalCIDRIP}, true, c, cp))
}
})
@ -289,60 +303,60 @@ func BenchmarkFirewallTable_match(b *testing.B) {
ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 10}, true, c, cp)
}
})
b.Run("pass on ip", func(b *testing.B) {
ip := iputil.Ip2VpnIp(net.IPv4(172, 1, 1, 1))
c := &cert.NebulaCertificate{
Details: cert.NebulaCertificateDetails{
InvertedGroups: map[string]struct{}{"nope": {}},
Name: "good-host",
},
}
for n := 0; n < b.N; n++ {
ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 10, RemoteIP: ip}, true, c, cp)
}
})
b.Run("pass on local ip", func(b *testing.B) {
ip := iputil.Ip2VpnIp(net.IPv4(172, 1, 1, 1))
c := &cert.NebulaCertificate{
Details: cert.NebulaCertificateDetails{
InvertedGroups: map[string]struct{}{"nope": {}},
Name: "good-host",
},
}
for n := 0; n < b.N; n++ {
ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 10, LocalIP: ip}, true, c, cp)
}
})
_ = ft.TCP.addRule(0, 0, []string{"good-group"}, "good-host", n, n, "", "")
b.Run("pass on ip with any port", func(b *testing.B) {
ip := iputil.Ip2VpnIp(net.IPv4(172, 1, 1, 1))
c := &cert.NebulaCertificate{
Details: cert.NebulaCertificateDetails{
InvertedGroups: map[string]struct{}{"nope": {}},
Name: "good-host",
},
}
for n := 0; n < b.N; n++ {
ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 100, RemoteIP: ip}, true, c, cp)
}
})
b.Run("pass on local ip with any port", func(b *testing.B) {
ip := iputil.Ip2VpnIp(net.IPv4(172, 1, 1, 1))
c := &cert.NebulaCertificate{
Details: cert.NebulaCertificateDetails{
InvertedGroups: map[string]struct{}{"nope": {}},
Name: "good-host",
},
}
for n := 0; n < b.N; n++ {
ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 100, LocalIP: ip}, true, c, cp)
}
})
//
//b.Run("pass on ip", func(b *testing.B) {
// ip := iputil.Ip2VpnIp(net.IPv4(172, 1, 1, 1))
// c := &cert.NebulaCertificate{
// Details: cert.NebulaCertificateDetails{
// InvertedGroups: map[string]struct{}{"nope": {}},
// Name: "good-host",
// },
// }
// for n := 0; n < b.N; n++ {
// ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 10, RemoteIP: ip}, true, c, cp)
// }
//})
//
//b.Run("pass on local ip", func(b *testing.B) {
// ip := iputil.Ip2VpnIp(net.IPv4(172, 1, 1, 1))
// c := &cert.NebulaCertificate{
// Details: cert.NebulaCertificateDetails{
// InvertedGroups: map[string]struct{}{"nope": {}},
// Name: "good-host",
// },
// }
// for n := 0; n < b.N; n++ {
// ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 10, LocalIP: ip}, true, c, cp)
// }
//})
//
//_ = ft.TCP.addRule(0, 0, []string{"good-group"}, "good-host", n, n, "", "")
//
//b.Run("pass on ip with any port", func(b *testing.B) {
// ip := iputil.Ip2VpnIp(net.IPv4(172, 1, 1, 1))
// c := &cert.NebulaCertificate{
// Details: cert.NebulaCertificateDetails{
// InvertedGroups: map[string]struct{}{"nope": {}},
// Name: "good-host",
// },
// }
// for n := 0; n < b.N; n++ {
// ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 100, RemoteIP: ip}, true, c, cp)
// }
//})
//
//b.Run("pass on local ip with any port", func(b *testing.B) {
// ip := iputil.Ip2VpnIp(net.IPv4(172, 1, 1, 1))
// c := &cert.NebulaCertificate{
// Details: cert.NebulaCertificateDetails{
// InvertedGroups: map[string]struct{}{"nope": {}},
// Name: "good-host",
// },
// }
// for n := 0; n < b.N; n++ {
// ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 100, LocalIP: ip}, true, c, cp)
// }
//})
}
func TestFirewall_Drop2(t *testing.T) {

14
go.mod
View File

@ -8,7 +8,7 @@ require (
github.com/armon/go-radix v1.0.0
github.com/clarkmcc/go-dag v0.0.0-20220908000337-9c3ba5b365fc
github.com/cyberdelia/go-metrics-graphite v0.0.0-20161219230853-39f87cc3b432
github.com/flynn/noise v1.0.1
github.com/flynn/noise v1.1.0
github.com/gogo/protobuf v1.3.2
github.com/google/gopacket v1.1.19
github.com/kardianos/service v1.2.2
@ -19,19 +19,19 @@ require (
github.com/sirupsen/logrus v1.9.3
github.com/skip2/go-qrcode v0.0.0-20200617195104-da1b6568686e
github.com/songgao/water v0.0.0-20200317203138-2b4b6d7c09d8
github.com/stretchr/testify v1.8.4
github.com/stretchr/testify v1.9.0
github.com/timandy/routine v1.1.1
github.com/vishvananda/netlink v1.2.1-beta.2
golang.org/x/crypto v0.18.0
golang.org/x/crypto v0.21.0
golang.org/x/exp v0.0.0-20230425010034-47ecfdc1ba53
golang.org/x/net v0.20.0
golang.org/x/net v0.22.0
golang.org/x/sync v0.6.0
golang.org/x/sys v0.16.0
golang.org/x/term v0.16.0
golang.org/x/sys v0.18.0
golang.org/x/term v0.18.0
golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2
golang.zx2c4.com/wireguard v0.0.0-20230325221338-052af4a8072b
golang.zx2c4.com/wireguard/windows v0.5.3
google.golang.org/protobuf v1.32.0
google.golang.org/protobuf v1.33.0
gopkg.in/yaml.v2 v2.4.0
gvisor.dev/gvisor v0.0.0-20230504175454-7b0a1988a28f
)

28
go.sum
View File

@ -24,8 +24,8 @@ github.com/cyberdelia/go-metrics-graphite v0.0.0-20161219230853-39f87cc3b432/go.
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/flynn/noise v1.0.1 h1:vPp/jdQLXC6ppsXSj/pM3W1BIJ5FEHE2TulSJBpb43Y=
github.com/flynn/noise v1.0.1/go.mod h1:xbMo+0i6+IGbYdJhF31t2eR1BIU0CYc12+BNAKwUTag=
github.com/flynn/noise v1.1.0 h1:KjPQoQCEFdZDiP03phOvGi11+SVVhBG2wOWAorLsstg=
github.com/flynn/noise v1.1.0/go.mod h1:xbMo+0i6+IGbYdJhF31t2eR1BIU0CYc12+BNAKwUTag=
github.com/go-kit/kit v0.8.0/go.mod h1:xBxKIO96dXMWWy0MnWVtmwkA9/13aqxPnvrjFYMA2as=
github.com/go-kit/kit v0.9.0/go.mod h1:xBxKIO96dXMWWy0MnWVtmwkA9/13aqxPnvrjFYMA2as=
github.com/go-kit/log v0.1.0/go.mod h1:zbhenjAZHb184qTLMA9ZjW7ThYL0H2mk7Q6pNt4vbaY=
@ -134,10 +134,10 @@ github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXf
github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI=
github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4=
github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk=
github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo=
github.com/timandy/routine v1.1.1 h1:6/Z7qLFZj3GrzuRksBFzIG8YGUh8CLhjnnMePBQTrEI=
github.com/timandy/routine v1.1.1/go.mod h1:OZHPOKSvqL/ZvqXFkNZyit0xIVelERptYXdAHH00adQ=
github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg=
github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
github.com/vishvananda/netlink v1.2.1-beta.2 h1:Llsql0lnQEbHj0I1OuKyp8otXp0r3q0mPkuhwHfStVs=
github.com/vishvananda/netlink v1.2.1-beta.2/go.mod h1:twkDnbuQxJYemMlGd4JFIcuhgX83tXhKS2B/PRMpOho=
github.com/vishvananda/netns v0.0.0-20200728191858-db3c7e526aae/go.mod h1:DD4vA1DwXk04H54A1oHXtwZmA0grkVMdPxx/VGLCah0=
@ -150,8 +150,8 @@ golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACk
golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto=
golang.org/x/crypto v0.0.0-20210322153248-0c34fe9e7dc2/go.mod h1:T9bdIzuCu7OtxOm1hfPfRQxPLYneinmdGuTeoZ9dtd4=
golang.org/x/crypto v0.18.0 h1:PGVlW0xEltQnzFZ55hkuX5+KLyrMYhHld1YHO4AKcdc=
golang.org/x/crypto v0.18.0/go.mod h1:R0j02AL6hcrfOiy9T4ZYp/rcWeMxM3L6QYxlOuEG1mg=
golang.org/x/crypto v0.21.0 h1:X31++rzVUdKhX5sWmSOFZxx8UW/ldWx55cbf08iNAMA=
golang.org/x/crypto v0.21.0/go.mod h1:0BP7YvVV9gBbVKyeTG0Gyn+gZm94bibOW5BjDEYAOMs=
golang.org/x/exp v0.0.0-20230425010034-47ecfdc1ba53 h1:5llv2sWeaMSnA3w2kS57ouQQ4pudlXrR0dCgw51QK9o=
golang.org/x/exp v0.0.0-20230425010034-47ecfdc1ba53/go.mod h1:V1LtkGg67GoY2N1AnLN78QLrzxkLyJw7RJb1gzOOz9w=
golang.org/x/lint v0.0.0-20200302205851-738671d3881b/go.mod h1:3xt1FjdF8hUf6vQPIChWIBhFzV8gjjsPE/fR3IyQdNY=
@ -170,8 +170,8 @@ golang.org/x/net v0.0.0-20200226121028-0de0cce0169b/go.mod h1:z5CRVTTTmAJ677TzLL
golang.org/x/net v0.0.0-20200625001655-4c5254603344/go.mod h1:/O7V0waA8r7cgGh81Ro3o1hOxt32SMVPicZroKQ2sZA=
golang.org/x/net v0.0.0-20201021035429-f5854403a974/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU=
golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg=
golang.org/x/net v0.20.0 h1:aCL9BSgETF1k+blQaYUBx9hJ9LOGP3gAVemcZlf1Kpo=
golang.org/x/net v0.20.0/go.mod h1:z8BVo6PvndSri0LbOE3hAn0apkU+1YvI6E70E9jsnvY=
golang.org/x/net v0.22.0 h1:9sGLhx7iRIHEiX0oAJ3MRZMUCElJgy7Br1nO+AMN3Tc=
golang.org/x/net v0.22.0/go.mod h1:JKghWKKOSdJwpW2GEx0Ja7fmaKnMsbu+MWVZTokSYmg=
golang.org/x/oauth2 v0.0.0-20190226205417-e64efc72b421/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw=
golang.org/x/sync v0.0.0-20181108010431-42b317875d0f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20181221193216-37e7f081c4d4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
@ -198,11 +198,11 @@ golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7w
golang.org/x/sys v0.0.0-20210124154548-22da62e12c0c/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20210603081109-ebe580a85c40/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.16.0 h1:xWw16ngr6ZMtmxDyKyIgsE93KNKz5HKmMa3b8ALHidU=
golang.org/x/sys v0.16.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
golang.org/x/sys v0.18.0 h1:DBdB3niSjOA/O0blCZBqDefyWNYveAYMNF1Wum0DYQ4=
golang.org/x/sys v0.18.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
golang.org/x/term v0.16.0 h1:m+B6fahuftsE9qjo0VWp2FW0mB3MTJvR0BaMQrq0pmE=
golang.org/x/term v0.16.0/go.mod h1:yn7UURbUtPyrVJPGPq404EukNFxcm/foM+bV/bfcDsY=
golang.org/x/term v0.18.0 h1:FcHjZXDMxI8mM3nwhX9HlKop4C0YQvCVCdwYl2wOtE8=
golang.org/x/term v0.18.0/go.mod h1:ILwASektA3OnRv7amZ1xhE/KTR+u50pbXfZ03+6Nx58=
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk=
golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
@ -233,8 +233,8 @@ google.golang.org/protobuf v1.20.1-0.20200309200217-e05f789c0967/go.mod h1:A+miE
google.golang.org/protobuf v1.21.0/go.mod h1:47Nbq4nVaFHyn7ilMalzfO3qCViNmqZ2kzikPIcrTAo=
google.golang.org/protobuf v1.23.0/go.mod h1:EGpADcykh3NcUnDUJcl1+ZksZNG86OlYog2l/sGQquU=
google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp09yW+WbY/TyQbw=
google.golang.org/protobuf v1.32.0 h1:pPC6BG5ex8PDFnkbrGU3EixyhKcQ2aDuBS36lqK/C7I=
google.golang.org/protobuf v1.32.0/go.mod h1:c6P6GXX6sHbq/GpV6MGZEdwhWPcYBgnhAHhKbcUYpos=
google.golang.org/protobuf v1.33.0 h1:uNO2rsAINq/JlFpSdYEKIZ0uKD/R9cpdv0T+yoGwGmI=
google.golang.org/protobuf v1.33.0/go.mod h1:c6P6GXX6sHbq/GpV6MGZEdwhWPcYBgnhAHhKbcUYpos=
gopkg.in/alecthomas/kingpin.v2 v2.2.6/go.mod h1:FMv+mEhP44yOT+4EoQTLFTRgOQ1FBLkstjWtayDeSgw=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=

View File

@ -408,7 +408,7 @@ func ixHandshakeStage2(f *Interface, addr *udp.Addr, via *ViaSender, hh *Handsha
hostinfo.remotes = f.lightHouse.QueryCache(vpnIp)
f.l.WithField("blockedUdpAddrs", newHH.hostinfo.remotes.CopyBlockedRemotes()).WithField("vpnIp", vpnIp).
WithField("remotes", newHH.hostinfo.remotes.CopyAddrs(f.hostMap.preferredRanges)).
WithField("remotes", newHH.hostinfo.remotes.CopyAddrs(f.hostMap.GetPreferredRanges())).
Info("Blocked addresses for handshakes")
// Swap the packet store to benefit the original intended recipient

View File

@ -181,7 +181,7 @@ func (hm *HandshakeManager) handleOutbound(vpnIp iputil.VpnIp, lighthouseTrigger
hostinfo := hh.hostinfo
// If we are out of time, clean up
if hh.counter >= hm.config.retries {
hh.hostinfo.logger(hm.l).WithField("udpAddrs", hh.hostinfo.remotes.CopyAddrs(hm.mainHostMap.preferredRanges)).
hh.hostinfo.logger(hm.l).WithField("udpAddrs", hh.hostinfo.remotes.CopyAddrs(hm.mainHostMap.GetPreferredRanges())).
WithField("initiatorIndex", hh.hostinfo.localIndexId).
WithField("remoteIndex", hh.hostinfo.remoteIndexId).
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).
@ -211,7 +211,7 @@ func (hm *HandshakeManager) handleOutbound(vpnIp iputil.VpnIp, lighthouseTrigger
hostinfo.remotes = hm.lightHouse.QueryCache(vpnIp)
}
remotes := hostinfo.remotes.CopyAddrs(hm.mainHostMap.preferredRanges)
remotes := hostinfo.remotes.CopyAddrs(hm.mainHostMap.GetPreferredRanges())
remotesHaveChanged := !udp.AddrSlice(remotes).Equal(hh.lastRemotes)
// We only care about a lighthouse trigger if we have new remotes to send to.
@ -235,7 +235,7 @@ func (hm *HandshakeManager) handleOutbound(vpnIp iputil.VpnIp, lighthouseTrigger
// Send the handshake to all known ips, stage 2 takes care of assigning the hostinfo.remote based on the first to reply
var sentTo []*udp.Addr
hostinfo.remotes.ForEach(hm.mainHostMap.preferredRanges, func(addr *udp.Addr, _ bool) {
hostinfo.remotes.ForEach(hm.mainHostMap.GetPreferredRanges(), func(addr *udp.Addr, _ bool) {
hm.messageMetrics.Tx(header.Handshake, header.MessageSubType(hostinfo.HandshakePacket[0][1]), 1)
err := hm.outside.WriteTo(hostinfo.HandshakePacket[0], addr)
if err != nil {
@ -362,7 +362,7 @@ func (hm *HandshakeManager) GetOrHandshake(vpnIp iputil.VpnIp, cacheCb func(*Han
hm.mainHostMap.RUnlock()
// Do not attempt promotion if you are a lighthouse
if !hm.lightHouse.amLighthouse {
h.TryPromoteBest(hm.mainHostMap.preferredRanges, hm.f)
h.TryPromoteBest(hm.mainHostMap.GetPreferredRanges(), hm.f)
}
return h, true
}
@ -602,7 +602,7 @@ func (hm *HandshakeManager) queryIndex(index uint32) *HandshakeHostInfo {
}
func (c *HandshakeManager) GetPreferredRanges() []*net.IPNet {
return c.mainHostMap.preferredRanges
return c.mainHostMap.GetPreferredRanges()
}
func (c *HandshakeManager) ForEachVpnIp(f controlEach) {

View File

@ -19,7 +19,9 @@ func Test_NewHandshakeManagerVpnIp(t *testing.T) {
_, localrange, _ := net.ParseCIDR("10.1.1.1/24")
ip := iputil.Ip2VpnIp(net.ParseIP("172.1.1.2"))
preferredRanges := []*net.IPNet{localrange}
mainHM := NewHostMap(l, vpncidr, preferredRanges)
mainHM := newHostMap(l, vpncidr)
mainHM.preferredRanges.Store(&preferredRanges)
lh := newTestLighthouse()
cs := &CertState{

View File

@ -10,6 +10,7 @@ import (
"github.com/sirupsen/logrus"
"github.com/slackhq/nebula/cert"
"github.com/slackhq/nebula/cidr"
"github.com/slackhq/nebula/config"
"github.com/slackhq/nebula/header"
"github.com/slackhq/nebula/iputil"
"github.com/slackhq/nebula/udp"
@ -56,9 +57,8 @@ type HostMap struct {
Relays map[uint32]*HostInfo // Maps a Relay IDX to a Relay HostInfo object
RemoteIndexes map[uint32]*HostInfo
Hosts map[iputil.VpnIp]*HostInfo
preferredRanges []*net.IPNet
preferredRanges atomic.Pointer[[]*net.IPNet]
vpnCIDR *net.IPNet
metricsEnabled bool
l *logrus.Logger
}
@ -254,22 +254,54 @@ type cachedPacketMetrics struct {
dropped metrics.Counter
}
func NewHostMap(l *logrus.Logger, vpnCIDR *net.IPNet, preferredRanges []*net.IPNet) *HostMap {
h := map[iputil.VpnIp]*HostInfo{}
i := map[uint32]*HostInfo{}
r := map[uint32]*HostInfo{}
relays := map[uint32]*HostInfo{}
m := HostMap{
syncRWMutex: newSyncRWMutex("hostmap"),
Indexes: i,
Relays: relays,
RemoteIndexes: r,
Hosts: h,
preferredRanges: preferredRanges,
vpnCIDR: vpnCIDR,
l: l,
func NewHostMapFromConfig(l *logrus.Logger, vpnCIDR *net.IPNet, c *config.C) *HostMap {
hm := newHostMap(l, vpnCIDR)
hm.reload(c, true)
c.RegisterReloadCallback(func(c *config.C) {
hm.reload(c, false)
})
l.WithField("network", hm.vpnCIDR.String()).
WithField("preferredRanges", hm.GetPreferredRanges()).
Info("Main HostMap created")
return hm
}
func newHostMap(l *logrus.Logger, vpnCIDR *net.IPNet) *HostMap {
return &HostMap{
syncRWMutex: newSyncRWMutex("hostmap"),
Indexes: map[uint32]*HostInfo{},
Relays: map[uint32]*HostInfo{},
RemoteIndexes: map[uint32]*HostInfo{},
Hosts: map[iputil.VpnIp]*HostInfo{},
vpnCIDR: vpnCIDR,
l: l,
}
}
func (hm *HostMap) reload(c *config.C, initial bool) {
if initial || c.HasChanged("preferred_ranges") {
var preferredRanges []*net.IPNet
rawPreferredRanges := c.GetStringSlice("preferred_ranges", []string{})
for _, rawPreferredRange := range rawPreferredRanges {
_, preferredRange, err := net.ParseCIDR(rawPreferredRange)
if err != nil {
hm.l.WithError(err).WithField("range", rawPreferredRanges).Warn("Failed to parse preferred ranges, ignoring")
continue
}
preferredRanges = append(preferredRanges, preferredRange)
}
oldRanges := hm.preferredRanges.Swap(&preferredRanges)
if !initial {
hm.l.WithField("oldPreferredRanges", *oldRanges).WithField("newPreferredRanges", preferredRanges).Info("preferred_ranges changed")
}
}
return &m
}
// EmitStats reports host, index, and relay counts to the stats collection system
@ -458,7 +490,7 @@ func (hm *HostMap) queryVpnIp(vpnIp iputil.VpnIp, promoteIfce *Interface) *HostI
hm.RUnlock()
// Do not attempt promotion if you are a lighthouse
if promoteIfce != nil && !promoteIfce.lightHouse.amLighthouse {
h.TryPromoteBest(hm.preferredRanges, promoteIfce)
h.TryPromoteBest(hm.GetPreferredRanges(), promoteIfce)
}
return h
@ -505,7 +537,8 @@ func (hm *HostMap) unlockedAddHostInfo(hostinfo *HostInfo, f *Interface) {
}
func (hm *HostMap) GetPreferredRanges() []*net.IPNet {
return hm.preferredRanges
//NOTE: if preferredRanges is ever not stored before a load this will fail to dereference a nil pointer
return *hm.preferredRanges.Load()
}
func (hm *HostMap) ForEachVpnIp(f controlEach) {
@ -597,7 +630,7 @@ func (i *HostInfo) SetRemoteIfPreferred(hm *HostMap, newRemote *udp.Addr) bool {
// NOTE: We do this loop here instead of calling `isPreferred` in
// remote_list.go so that we only have to loop over preferredRanges once.
newIsPreferred := false
for _, l := range hm.preferredRanges {
for _, l := range hm.GetPreferredRanges() {
// return early if we are already on a preferred remote
if l.Contains(currentRemote.IP) {
return false

View File

@ -4,19 +4,19 @@ import (
"net"
"testing"
"github.com/slackhq/nebula/config"
"github.com/slackhq/nebula/test"
"github.com/stretchr/testify/assert"
)
func TestHostMap_MakePrimary(t *testing.T) {
l := test.NewLogger()
hm := NewHostMap(
hm := newHostMap(
l,
&net.IPNet{
IP: net.IP{10, 0, 0, 1},
Mask: net.IPMask{255, 255, 255, 0},
},
[]*net.IPNet{},
)
f := &Interface{}
@ -91,13 +91,12 @@ func TestHostMap_MakePrimary(t *testing.T) {
func TestHostMap_DeleteHostInfo(t *testing.T) {
l := test.NewLogger()
hm := NewHostMap(
hm := newHostMap(
l,
&net.IPNet{
IP: net.IP{10, 0, 0, 1},
Mask: net.IPMask{255, 255, 255, 0},
},
[]*net.IPNet{},
)
f := &Interface{}
@ -205,3 +204,33 @@ func TestHostMap_DeleteHostInfo(t *testing.T) {
prim = hm.QueryVpnIp(1)
assert.Nil(t, prim)
}
func TestHostMap_reload(t *testing.T) {
l := test.NewLogger()
c := config.NewC(l)
hm := NewHostMapFromConfig(
l,
&net.IPNet{
IP: net.IP{10, 0, 0, 1},
Mask: net.IPMask{255, 255, 255, 0},
},
c,
)
toS := func(ipn []*net.IPNet) []string {
var s []string
for _, n := range ipn {
s = append(s, n.String())
}
return s
}
assert.Empty(t, hm.GetPreferredRanges())
c.ReloadConfigString("preferred_ranges: [1.1.1.0/24, 10.1.1.0/24]")
assert.EqualValues(t, []string{"1.1.1.0/24", "10.1.1.0/24"}, toS(hm.GetPreferredRanges()))
c.ReloadConfigString("preferred_ranges: [1.1.1.1/32]")
assert.EqualValues(t, []string{"1.1.1.1/32"}, toS(hm.GetPreferredRanges()))
}

47
main.go
View File

@ -183,52 +183,7 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
}
}
// Set up my internal host map
var preferredRanges []*net.IPNet
rawPreferredRanges := c.GetStringSlice("preferred_ranges", []string{})
// First, check if 'preferred_ranges' is set and fallback to 'local_range'
if len(rawPreferredRanges) > 0 {
for _, rawPreferredRange := range rawPreferredRanges {
_, preferredRange, err := net.ParseCIDR(rawPreferredRange)
if err != nil {
return nil, util.ContextualizeIfNeeded("Failed to parse preferred ranges", err)
}
preferredRanges = append(preferredRanges, preferredRange)
}
}
// local_range was superseded by preferred_ranges. If it is still present,
// merge the local_range setting into preferred_ranges. We will probably
// deprecate local_range and remove in the future.
rawLocalRange := c.GetString("local_range", "")
if rawLocalRange != "" {
_, localRange, err := net.ParseCIDR(rawLocalRange)
if err != nil {
return nil, util.ContextualizeIfNeeded("Failed to parse local_range", err)
}
// Check if the entry for local_range was already specified in
// preferred_ranges. Don't put it into the slice twice if so.
var found bool
for _, r := range preferredRanges {
if r.String() == localRange.String() {
found = true
break
}
}
if !found {
preferredRanges = append(preferredRanges, localRange)
}
}
hostMap := NewHostMap(l, tunCidr, preferredRanges)
hostMap.metricsEnabled = c.GetBool("stats.message_metrics", false)
l.
WithField("network", hostMap.vpnCIDR.String()).
WithField("preferredRanges", hostMap.preferredRanges).
Info("Main HostMap created")
hostMap := NewHostMapFromConfig(l, tunCidr, c)
punchy := NewPunchyFromConfig(l, c)
lightHouse, err := NewLightHouseFromConfig(ctx, l, c, tunCidr, udpConns[0], punchy)
if err != nil {

View File

@ -1,6 +1,7 @@
package overlay
import (
"bytes"
"fmt"
"math"
"net"
@ -21,6 +22,35 @@ type Route struct {
Install bool
}
// Equal determines if a route that could be installed in the system route table is equal to another
// Via is ignored since that is only consumed within nebula itself
func (r Route) Equal(t Route) bool {
if !r.Cidr.IP.Equal(t.Cidr.IP) {
return false
}
if !bytes.Equal(r.Cidr.Mask, t.Cidr.Mask) {
return false
}
if r.Metric != t.Metric {
return false
}
if r.MTU != t.MTU {
return false
}
if r.Install != t.Install {
return false
}
return true
}
func (r Route) String() string {
s := r.Cidr.String()
if r.Metric != 0 {
s += fmt.Sprintf(" metric: %v", r.Metric)
}
return s
}
func makeRouteTree(l *logrus.Logger, routes []Route, allowMTU bool) (*cidr.Tree4[iputil.VpnIp], error) {
routeTree := cidr.NewTree4[iputil.VpnIp]()
for _, r := range routes {

View File

@ -10,60 +10,63 @@ import (
const DefaultMTU = 1300
// TODO: We may be able to remove routines
type DeviceFactory func(c *config.C, l *logrus.Logger, tunCidr *net.IPNet, routines int) (Device, error)
func NewDeviceFromConfig(c *config.C, l *logrus.Logger, tunCidr *net.IPNet, routines int) (Device, error) {
routes, err := parseRoutes(c, tunCidr)
if err != nil {
return nil, util.NewContextualError("Could not parse tun.routes", nil, err)
}
unsafeRoutes, err := parseUnsafeRoutes(c, tunCidr)
if err != nil {
return nil, util.NewContextualError("Could not parse tun.unsafe_routes", nil, err)
}
routes = append(routes, unsafeRoutes...)
switch {
case c.GetBool("tun.disabled", false):
tun := newDisabledTun(tunCidr, c.GetInt("tun.tx_queue", 500), c.GetBool("stats.message_metrics", false), l)
return tun, nil
default:
return newTun(
l,
c.GetString("tun.dev", ""),
tunCidr,
c.GetInt("tun.mtu", DefaultMTU),
routes,
c.GetInt("tun.tx_queue", 500),
routines > 1,
c.GetBool("tun.use_system_route_table", false),
)
return newTun(c, l, tunCidr, routines > 1)
}
}
func NewFdDeviceFromConfig(fd *int) DeviceFactory {
return func(c *config.C, l *logrus.Logger, tunCidr *net.IPNet, routines int) (Device, error) {
routes, err := parseRoutes(c, tunCidr)
if err != nil {
return nil, util.NewContextualError("Could not parse tun.routes", nil, err)
}
unsafeRoutes, err := parseUnsafeRoutes(c, tunCidr)
if err != nil {
return nil, util.NewContextualError("Could not parse tun.unsafe_routes", nil, err)
}
routes = append(routes, unsafeRoutes...)
return newTunFromFd(
l,
*fd,
tunCidr,
c.GetInt("tun.mtu", DefaultMTU),
routes,
c.GetInt("tun.tx_queue", 500),
c.GetBool("tun.use_system_route_table", false),
)
return newTunFromFd(c, l, *fd, tunCidr)
}
}
func getAllRoutesFromConfig(c *config.C, cidr *net.IPNet, initial bool) (bool, []Route, error) {
if !initial && !c.HasChanged("tun.routes") && !c.HasChanged("tun.unsafe_routes") {
return false, nil, nil
}
routes, err := parseRoutes(c, cidr)
if err != nil {
return true, nil, util.NewContextualError("Could not parse tun.routes", nil, err)
}
unsafeRoutes, err := parseUnsafeRoutes(c, cidr)
if err != nil {
return true, nil, util.NewContextualError("Could not parse tun.unsafe_routes", nil, err)
}
routes = append(routes, unsafeRoutes...)
return true, routes, nil
}
// findRemovedRoutes will return all routes that are not present in the newRoutes list and would affect the system route table.
// Via is not used to evaluate since it does not affect the system route table.
func findRemovedRoutes(newRoutes, oldRoutes []Route) []Route {
var removed []Route
has := func(entry Route) bool {
for _, check := range newRoutes {
if check.Equal(entry) {
return true
}
}
return false
}
for _, oldEntry := range oldRoutes {
if !has(oldEntry) {
removed = append(removed, oldEntry)
}
}
return removed
}

View File

@ -8,45 +8,57 @@ import (
"io"
"net"
"os"
"sync/atomic"
"github.com/sirupsen/logrus"
"github.com/slackhq/nebula/cidr"
"github.com/slackhq/nebula/config"
"github.com/slackhq/nebula/iputil"
"github.com/slackhq/nebula/util"
)
type tun struct {
io.ReadWriteCloser
fd int
cidr *net.IPNet
routeTree *cidr.Tree4[iputil.VpnIp]
Routes atomic.Pointer[[]Route]
routeTree atomic.Pointer[cidr.Tree4[iputil.VpnIp]]
l *logrus.Logger
}
func newTunFromFd(l *logrus.Logger, deviceFd int, cidr *net.IPNet, _ int, routes []Route, _ int, _ bool) (*tun, error) {
routeTree, err := makeRouteTree(l, routes, false)
if err != nil {
return nil, err
}
func newTunFromFd(c *config.C, l *logrus.Logger, deviceFd int, cidr *net.IPNet) (*tun, error) {
// XXX Android returns an fd in non-blocking mode which is necessary for shutdown to work properly.
// Be sure not to call file.Fd() as it will set the fd to blocking mode.
file := os.NewFile(uintptr(deviceFd), "/dev/net/tun")
return &tun{
t := &tun{
ReadWriteCloser: file,
fd: deviceFd,
cidr: cidr,
l: l,
routeTree: routeTree,
}, nil
}
err := t.reload(c, true)
if err != nil {
return nil, err
}
c.RegisterReloadCallback(func(c *config.C) {
err := t.reload(c, false)
if err != nil {
util.LogWithContextIfNeeded("failed to reload tun device", err, t.l)
}
})
return t, nil
}
func newTun(_ *logrus.Logger, _ string, _ *net.IPNet, _ int, _ []Route, _ int, _ bool, _ bool) (*tun, error) {
func newTun(_ *config.C, _ *logrus.Logger, _ *net.IPNet, _ bool) (*tun, error) {
return nil, fmt.Errorf("newTun not supported in Android")
}
func (t *tun) RouteFor(ip iputil.VpnIp) iputil.VpnIp {
_, r := t.routeTree.MostSpecificContains(ip)
_, r := t.routeTree.Load().MostSpecificContains(ip)
return r
}
@ -54,6 +66,27 @@ func (t tun) Activate() error {
return nil
}
func (t *tun) reload(c *config.C, initial bool) error {
change, routes, err := getAllRoutesFromConfig(c, t.cidr, initial)
if err != nil {
return err
}
if !initial && !change {
return nil
}
routeTree, err := makeRouteTree(t.l, routes, false)
if err != nil {
return err
}
// Teach nebula how to handle the routes
t.Routes.Store(&routes)
t.routeTree.Store(routeTree)
return nil
}
func (t *tun) Cidr() *net.IPNet {
return t.cidr
}

View File

@ -9,12 +9,15 @@ import (
"io"
"net"
"os"
"sync/atomic"
"syscall"
"unsafe"
"github.com/sirupsen/logrus"
"github.com/slackhq/nebula/cidr"
"github.com/slackhq/nebula/config"
"github.com/slackhq/nebula/iputil"
"github.com/slackhq/nebula/util"
netroute "golang.org/x/net/route"
"golang.org/x/sys/unix"
)
@ -24,8 +27,9 @@ type tun struct {
Device string
cidr *net.IPNet
DefaultMTU int
Routes []Route
routeTree *cidr.Tree4[iputil.VpnIp]
Routes atomic.Pointer[[]Route]
routeTree atomic.Pointer[cidr.Tree4[iputil.VpnIp]]
linkAddr *netroute.LinkAddr
l *logrus.Logger
// cache out buffer since we need to prepend 4 bytes for tun metadata
@ -69,12 +73,8 @@ type ifreqMTU struct {
pad [8]byte
}
func newTun(l *logrus.Logger, name string, cidr *net.IPNet, defaultMTU int, routes []Route, _ int, _ bool, _ bool) (*tun, error) {
routeTree, err := makeRouteTree(l, routes, false)
if err != nil {
return nil, err
}
func newTun(c *config.C, l *logrus.Logger, cidr *net.IPNet, _ bool) (*tun, error) {
name := c.GetString("tun.dev", "")
ifIndex := -1
if name != "" && name != "utun" {
_, err := fmt.Sscanf(name, "utun%d", &ifIndex)
@ -142,17 +142,27 @@ func newTun(l *logrus.Logger, name string, cidr *net.IPNet, defaultMTU int, rout
file := os.NewFile(uintptr(fd), "")
tun := &tun{
t := &tun{
ReadWriteCloser: file,
Device: name,
cidr: cidr,
DefaultMTU: defaultMTU,
Routes: routes,
routeTree: routeTree,
DefaultMTU: c.GetInt("tun.mtu", DefaultMTU),
l: l,
}
return tun, nil
err = t.reload(c, true)
if err != nil {
return nil, err
}
c.RegisterReloadCallback(func(c *config.C) {
err := t.reload(c, false)
if err != nil {
util.LogWithContextIfNeeded("failed to reload tun device", err, t.l)
}
})
return t, nil
}
func (t *tun) deviceBytes() (o [16]byte) {
@ -162,7 +172,7 @@ func (t *tun) deviceBytes() (o [16]byte) {
return
}
func newTunFromFd(_ *logrus.Logger, _ int, _ *net.IPNet, _ int, _ []Route, _ int, _ bool) (*tun, error) {
func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ *net.IPNet) (*tun, error) {
return nil, fmt.Errorf("newTunFromFd not supported in Darwin")
}
@ -260,6 +270,7 @@ func (t *tun) Activate() error {
if linkAddr == nil {
return fmt.Errorf("unable to discover link_addr for tun interface")
}
t.linkAddr = linkAddr
copy(routeAddr.IP[:], addr[:])
copy(maskAddr.IP[:], mask[:])
@ -278,33 +289,48 @@ func (t *tun) Activate() error {
}
// Unsafe path routes
for _, r := range t.Routes {
if r.Via == nil || !r.Install {
// We don't allow route MTUs so only install routes with a via
continue
}
return t.addRoutes(false)
}
copy(routeAddr.IP[:], r.Cidr.IP.To4())
copy(maskAddr.IP[:], net.IP(r.Cidr.Mask).To4())
func (t *tun) reload(c *config.C, initial bool) error {
change, routes, err := getAllRoutesFromConfig(c, t.cidr, initial)
if err != nil {
return err
}
err = addRoute(routeSock, routeAddr, maskAddr, linkAddr)
if !initial && !change {
return nil
}
routeTree, err := makeRouteTree(t.l, routes, false)
if err != nil {
return err
}
// Teach nebula how to handle the routes before establishing them in the system table
oldRoutes := t.Routes.Swap(&routes)
t.routeTree.Store(routeTree)
if !initial {
// Remove first, if the system removes a wanted route hopefully it will be re-added next
err := t.removeRoutes(findRemovedRoutes(routes, *oldRoutes))
if err != nil {
if errors.Is(err, unix.EEXIST) {
t.l.WithField("route", r.Cidr).
Warnf("unable to add unsafe_route, identical route already exists")
} else {
return err
}
util.LogWithContextIfNeeded("Failed to remove routes", err, t.l)
}
// TODO how to set metric
// Ensure any routes we actually want are installed
err = t.addRoutes(true)
if err != nil {
// Catch any stray logs
util.LogWithContextIfNeeded("Failed to add routes", err, t.l)
}
}
return nil
}
func (t *tun) RouteFor(ip iputil.VpnIp) iputil.VpnIp {
ok, r := t.routeTree.MostSpecificContains(ip)
ok, r := t.routeTree.Load().MostSpecificContains(ip)
if ok {
return r
}
@ -340,6 +366,88 @@ func getLinkAddr(name string) (*netroute.LinkAddr, error) {
return nil, nil
}
func (t *tun) addRoutes(logErrors bool) error {
routeSock, err := unix.Socket(unix.AF_ROUTE, unix.SOCK_RAW, unix.AF_UNSPEC)
if err != nil {
return fmt.Errorf("unable to create AF_ROUTE socket: %v", err)
}
defer func() {
unix.Shutdown(routeSock, unix.SHUT_RDWR)
err := unix.Close(routeSock)
if err != nil {
t.l.WithError(err).Error("failed to close AF_ROUTE socket")
}
}()
routeAddr := &netroute.Inet4Addr{}
maskAddr := &netroute.Inet4Addr{}
routes := *t.Routes.Load()
for _, r := range routes {
if r.Via == nil || !r.Install {
// We don't allow route MTUs so only install routes with a via
continue
}
copy(routeAddr.IP[:], r.Cidr.IP.To4())
copy(maskAddr.IP[:], net.IP(r.Cidr.Mask).To4())
err := addRoute(routeSock, routeAddr, maskAddr, t.linkAddr)
if err != nil {
if errors.Is(err, unix.EEXIST) {
t.l.WithField("route", r.Cidr).
Warnf("unable to add unsafe_route, identical route already exists")
} else {
retErr := util.NewContextualError("Failed to add route", map[string]interface{}{"route": r}, err)
if logErrors {
retErr.Log(t.l)
} else {
return retErr
}
}
} else {
t.l.WithField("route", r).Info("Added route")
}
}
return nil
}
func (t *tun) removeRoutes(routes []Route) error {
routeSock, err := unix.Socket(unix.AF_ROUTE, unix.SOCK_RAW, unix.AF_UNSPEC)
if err != nil {
return fmt.Errorf("unable to create AF_ROUTE socket: %v", err)
}
defer func() {
unix.Shutdown(routeSock, unix.SHUT_RDWR)
err := unix.Close(routeSock)
if err != nil {
t.l.WithError(err).Error("failed to close AF_ROUTE socket")
}
}()
routeAddr := &netroute.Inet4Addr{}
maskAddr := &netroute.Inet4Addr{}
for _, r := range routes {
if !r.Install {
continue
}
copy(routeAddr.IP[:], r.Cidr.IP.To4())
copy(maskAddr.IP[:], net.IP(r.Cidr.Mask).To4())
err := delRoute(routeSock, routeAddr, maskAddr, t.linkAddr)
if err != nil {
t.l.WithError(err).WithField("route", r).Error("Failed to remove route")
} else {
t.l.WithField("route", r).Info("Removed route")
}
}
return nil
}
func addRoute(sock int, addr, mask *netroute.Inet4Addr, link *netroute.LinkAddr) error {
r := netroute.RouteMessage{
Version: unix.RTM_VERSION,
@ -365,6 +473,30 @@ func addRoute(sock int, addr, mask *netroute.Inet4Addr, link *netroute.LinkAddr)
return nil
}
func delRoute(sock int, addr, mask *netroute.Inet4Addr, link *netroute.LinkAddr) error {
r := netroute.RouteMessage{
Version: unix.RTM_VERSION,
Type: unix.RTM_DELETE,
Seq: 1,
Addrs: []netroute.Addr{
unix.RTAX_DST: addr,
unix.RTAX_GATEWAY: link,
unix.RTAX_NETMASK: mask,
},
}
data, err := r.Marshal()
if err != nil {
return fmt.Errorf("failed to create route.RouteMessage: %w", err)
}
_, err = unix.Write(sock, data[:])
if err != nil {
return fmt.Errorf("failed to write route.RouteMessage to socket: %w", err)
}
return nil
}
func (t *tun) Read(to []byte) (int, error) {
buf := make([]byte, len(to)+4)

View File

@ -13,12 +13,15 @@ import (
"os"
"os/exec"
"strconv"
"sync/atomic"
"syscall"
"unsafe"
"github.com/sirupsen/logrus"
"github.com/slackhq/nebula/cidr"
"github.com/slackhq/nebula/config"
"github.com/slackhq/nebula/iputil"
"github.com/slackhq/nebula/util"
)
const (
@ -47,8 +50,8 @@ type tun struct {
Device string
cidr *net.IPNet
MTU int
Routes []Route
routeTree *cidr.Tree4[iputil.VpnIp]
Routes atomic.Pointer[[]Route]
routeTree atomic.Pointer[cidr.Tree4[iputil.VpnIp]]
l *logrus.Logger
io.ReadWriteCloser
@ -76,14 +79,15 @@ func (t *tun) Close() error {
return nil
}
func newTunFromFd(_ *logrus.Logger, _ int, _ *net.IPNet, _ int, _ []Route, _ int, _ bool) (*tun, error) {
func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ *net.IPNet) (*tun, error) {
return nil, fmt.Errorf("newTunFromFd not supported in FreeBSD")
}
func newTun(l *logrus.Logger, deviceName string, cidr *net.IPNet, defaultMTU int, routes []Route, _ int, _ bool, _ bool) (*tun, error) {
func newTun(c *config.C, l *logrus.Logger, cidr *net.IPNet, _ bool) (*tun, error) {
// Try to open existing tun device
var file *os.File
var err error
deviceName := c.GetString("tun.dev", "")
if deviceName != "" {
file, err = os.OpenFile("/dev/"+deviceName, os.O_RDWR, 0)
}
@ -144,47 +148,85 @@ func newTun(l *logrus.Logger, deviceName string, cidr *net.IPNet, defaultMTU int
ioctl(fd, syscall.SIOCSIFNAME, uintptr(unsafe.Pointer(&ifrr)))
}
routeTree, err := makeRouteTree(l, routes, false)
t := &tun{
ReadWriteCloser: file,
Device: deviceName,
cidr: cidr,
MTU: c.GetInt("tun.mtu", DefaultMTU),
l: l,
}
err = t.reload(c, true)
if err != nil {
return nil, err
}
return &tun{
ReadWriteCloser: file,
Device: deviceName,
cidr: cidr,
MTU: defaultMTU,
Routes: routes,
routeTree: routeTree,
l: l,
}, nil
c.RegisterReloadCallback(func(c *config.C) {
err := t.reload(c, false)
if err != nil {
util.LogWithContextIfNeeded("failed to reload tun device", err, t.l)
}
})
return t, nil
}
func (t *tun) Activate() error {
var err error
// TODO use syscalls instead of exec.Command
t.l.Debug("command: ifconfig", t.Device, t.cidr.String(), t.cidr.IP.String())
if err = exec.Command("/sbin/ifconfig", t.Device, t.cidr.String(), t.cidr.IP.String()).Run(); err != nil {
cmd := exec.Command("/sbin/ifconfig", t.Device, t.cidr.String(), t.cidr.IP.String())
t.l.Debug("command: ", cmd.String())
if err = cmd.Run(); err != nil {
return fmt.Errorf("failed to run 'ifconfig': %s", err)
}
t.l.Debug("command: route", "-n", "add", "-net", t.cidr.String(), "-interface", t.Device)
if err = exec.Command("/sbin/route", "-n", "add", "-net", t.cidr.String(), "-interface", t.Device).Run(); err != nil {
cmd = exec.Command("/sbin/route", "-n", "add", "-net", t.cidr.String(), "-interface", t.Device)
t.l.Debug("command: ", cmd.String())
if err = cmd.Run(); err != nil {
return fmt.Errorf("failed to run 'route add': %s", err)
}
t.l.Debug("command: ifconfig", t.Device, "mtu", strconv.Itoa(t.MTU))
if err = exec.Command("/sbin/ifconfig", t.Device, "mtu", strconv.Itoa(t.MTU)).Run(); err != nil {
cmd = exec.Command("/sbin/ifconfig", t.Device, "mtu", strconv.Itoa(t.MTU))
t.l.Debug("command: ", cmd.String())
if err = cmd.Run(); err != nil {
return fmt.Errorf("failed to run 'ifconfig': %s", err)
}
// Unsafe path routes
for _, r := range t.Routes {
if r.Via == nil || !r.Install {
// We don't allow route MTUs so only install routes with a via
continue
return t.addRoutes(false)
}
func (t *tun) reload(c *config.C, initial bool) error {
change, routes, err := getAllRoutesFromConfig(c, t.cidr, initial)
if err != nil {
return err
}
if !initial && !change {
return nil
}
routeTree, err := makeRouteTree(t.l, routes, false)
if err != nil {
return err
}
// Teach nebula how to handle the routes before establishing them in the system table
oldRoutes := t.Routes.Swap(&routes)
t.routeTree.Store(routeTree)
if !initial {
// Remove first, if the system removes a wanted route hopefully it will be re-added next
err := t.removeRoutes(findRemovedRoutes(routes, *oldRoutes))
if err != nil {
util.LogWithContextIfNeeded("Failed to remove routes", err, t.l)
}
t.l.Debug("command: route", "-n", "add", "-net", r.Cidr.String(), "-interface", t.Device)
if err = exec.Command("/sbin/route", "-n", "add", "-net", r.Cidr.String(), "-interface", t.Device).Run(); err != nil {
return fmt.Errorf("failed to run 'route add' for unsafe_route %s: %s", r.Cidr.String(), err)
// Ensure any routes we actually want are installed
err = t.addRoutes(true)
if err != nil {
// Catch any stray logs
util.LogWithContextIfNeeded("Failed to add routes", err, t.l)
}
}
@ -192,7 +234,7 @@ func (t *tun) Activate() error {
}
func (t *tun) RouteFor(ip iputil.VpnIp) iputil.VpnIp {
_, r := t.routeTree.MostSpecificContains(ip)
_, r := t.routeTree.Load().MostSpecificContains(ip)
return r
}
@ -208,6 +250,46 @@ func (t *tun) NewMultiQueueReader() (io.ReadWriteCloser, error) {
return nil, fmt.Errorf("TODO: multiqueue not implemented for freebsd")
}
func (t *tun) addRoutes(logErrors bool) error {
routes := *t.Routes.Load()
for _, r := range routes {
if r.Via == nil || !r.Install {
// We don't allow route MTUs so only install routes with a via
continue
}
cmd := exec.Command("/sbin/route", "-n", "add", "-net", r.Cidr.String(), "-interface", t.Device)
t.l.Debug("command: ", cmd.String())
if err := cmd.Run(); err != nil {
retErr := util.NewContextualError("failed to run 'route add' for unsafe_route", map[string]interface{}{"route": r}, err)
if logErrors {
retErr.Log(t.l)
} else {
return retErr
}
}
}
return nil
}
func (t *tun) removeRoutes(routes []Route) error {
for _, r := range routes {
if !r.Install {
continue
}
cmd := exec.Command("/sbin/route", "-n", "delete", "-net", r.Cidr.String(), "-interface", t.Device)
t.l.Debug("command: ", cmd.String())
if err := cmd.Run(); err != nil {
t.l.WithError(err).WithField("route", r).Error("Failed to remove route")
} else {
t.l.WithField("route", r).Info("Removed route")
}
}
return nil
}
func (t *tun) deviceBytes() (o [16]byte) {
for i, c := range t.Device {
o[i] = byte(c)

View File

@ -10,43 +10,78 @@ import (
"net"
"os"
"sync"
"sync/atomic"
"syscall"
"github.com/sirupsen/logrus"
"github.com/slackhq/nebula/cidr"
"github.com/slackhq/nebula/config"
"github.com/slackhq/nebula/iputil"
"github.com/slackhq/nebula/util"
)
type tun struct {
io.ReadWriteCloser
cidr *net.IPNet
routeTree *cidr.Tree4[iputil.VpnIp]
Routes atomic.Pointer[[]Route]
routeTree atomic.Pointer[cidr.Tree4[iputil.VpnIp]]
l *logrus.Logger
}
func newTun(_ *logrus.Logger, _ string, _ *net.IPNet, _ int, _ []Route, _ int, _ bool, _ bool) (*tun, error) {
func newTun(_ *config.C, _ *logrus.Logger, _ *net.IPNet, _ bool) (*tun, error) {
return nil, fmt.Errorf("newTun not supported in iOS")
}
func newTunFromFd(l *logrus.Logger, deviceFd int, cidr *net.IPNet, _ int, routes []Route, _ int, _ bool) (*tun, error) {
routeTree, err := makeRouteTree(l, routes, false)
func newTunFromFd(c *config.C, l *logrus.Logger, deviceFd int, cidr *net.IPNet) (*tun, error) {
file := os.NewFile(uintptr(deviceFd), "/dev/tun")
t := &tun{
cidr: cidr,
ReadWriteCloser: &tunReadCloser{f: file},
l: l,
}
err := t.reload(c, true)
if err != nil {
return nil, err
}
file := os.NewFile(uintptr(deviceFd), "/dev/tun")
return &tun{
cidr: cidr,
ReadWriteCloser: &tunReadCloser{f: file},
routeTree: routeTree,
}, nil
c.RegisterReloadCallback(func(c *config.C) {
err := t.reload(c, false)
if err != nil {
util.LogWithContextIfNeeded("failed to reload tun device", err, t.l)
}
})
return t, nil
}
func (t *tun) Activate() error {
return nil
}
func (t *tun) reload(c *config.C, initial bool) error {
change, routes, err := getAllRoutesFromConfig(c, t.cidr, initial)
if err != nil {
return err
}
if !initial && !change {
return nil
}
routeTree, err := makeRouteTree(t.l, routes, false)
if err != nil {
return err
}
// Teach nebula how to handle the routes
t.Routes.Store(&routes)
t.routeTree.Store(routeTree)
return nil
}
func (t *tun) RouteFor(ip iputil.VpnIp) iputil.VpnIp {
_, r := t.routeTree.MostSpecificContains(ip)
_, r := t.routeTree.Load().MostSpecificContains(ip)
return r
}

View File

@ -15,21 +15,25 @@ import (
"github.com/sirupsen/logrus"
"github.com/slackhq/nebula/cidr"
"github.com/slackhq/nebula/config"
"github.com/slackhq/nebula/iputil"
"github.com/slackhq/nebula/util"
"github.com/vishvananda/netlink"
"golang.org/x/sys/unix"
)
type tun struct {
io.ReadWriteCloser
fd int
Device string
cidr *net.IPNet
MaxMTU int
DefaultMTU int
TXQueueLen int
fd int
Device string
cidr *net.IPNet
MaxMTU int
DefaultMTU int
TXQueueLen int
deviceIndex int
ioctlFd uintptr
Routes []Route
Routes atomic.Pointer[[]Route]
routeTree atomic.Pointer[cidr.Tree4[iputil.VpnIp]]
routeChan chan struct{}
useSystemRoutes bool
@ -61,30 +65,20 @@ type ifreqQLEN struct {
pad [8]byte
}
func newTunFromFd(l *logrus.Logger, deviceFd int, cidr *net.IPNet, defaultMTU int, routes []Route, txQueueLen int, useSystemRoutes bool) (*tun, error) {
routeTree, err := makeRouteTree(l, routes, true)
func newTunFromFd(c *config.C, l *logrus.Logger, deviceFd int, cidr *net.IPNet) (*tun, error) {
file := os.NewFile(uintptr(deviceFd), "/dev/net/tun")
t, err := newTunGeneric(c, l, file, cidr)
if err != nil {
return nil, err
}
file := os.NewFile(uintptr(deviceFd), "/dev/net/tun")
t.Device = "tun0"
t := &tun{
ReadWriteCloser: file,
fd: int(file.Fd()),
Device: "tun0",
cidr: cidr,
DefaultMTU: defaultMTU,
TXQueueLen: txQueueLen,
Routes: routes,
useSystemRoutes: useSystemRoutes,
l: l,
}
t.routeTree.Store(routeTree)
return t, nil
}
func newTun(l *logrus.Logger, deviceName string, cidr *net.IPNet, defaultMTU int, routes []Route, txQueueLen int, multiqueue bool, useSystemRoutes bool) (*tun, error) {
func newTun(c *config.C, l *logrus.Logger, cidr *net.IPNet, multiqueue bool) (*tun, error) {
fd, err := unix.Open("/dev/net/tun", os.O_RDWR, 0)
if err != nil {
return nil, err
@ -95,46 +89,113 @@ func newTun(l *logrus.Logger, deviceName string, cidr *net.IPNet, defaultMTU int
if multiqueue {
req.Flags |= unix.IFF_MULTI_QUEUE
}
copy(req.Name[:], deviceName)
copy(req.Name[:], c.GetString("tun.dev", ""))
if err = ioctl(uintptr(fd), uintptr(unix.TUNSETIFF), uintptr(unsafe.Pointer(&req))); err != nil {
return nil, err
}
name := strings.Trim(string(req.Name[:]), "\x00")
file := os.NewFile(uintptr(fd), "/dev/net/tun")
maxMTU := defaultMTU
for _, r := range routes {
if r.MTU == 0 {
r.MTU = defaultMTU
}
if r.MTU > maxMTU {
maxMTU = r.MTU
}
}
routeTree, err := makeRouteTree(l, routes, true)
t, err := newTunGeneric(c, l, file, cidr)
if err != nil {
return nil, err
}
t.Device = name
return t, nil
}
func newTunGeneric(c *config.C, l *logrus.Logger, file *os.File, cidr *net.IPNet) (*tun, error) {
t := &tun{
ReadWriteCloser: file,
fd: int(file.Fd()),
Device: name,
cidr: cidr,
MaxMTU: maxMTU,
DefaultMTU: defaultMTU,
TXQueueLen: txQueueLen,
Routes: routes,
useSystemRoutes: useSystemRoutes,
TXQueueLen: c.GetInt("tun.tx_queue", 500),
useSystemRoutes: c.GetBool("tun.use_system_route_table", false),
l: l,
}
t.routeTree.Store(routeTree)
err := t.reload(c, true)
if err != nil {
return nil, err
}
c.RegisterReloadCallback(func(c *config.C) {
err := t.reload(c, false)
if err != nil {
util.LogWithContextIfNeeded("failed to reload tun device", err, t.l)
}
})
return t, nil
}
func (t *tun) reload(c *config.C, initial bool) error {
routeChange, routes, err := getAllRoutesFromConfig(c, t.cidr, initial)
if err != nil {
return err
}
if !initial && !routeChange && !c.HasChanged("tun.mtu") {
return nil
}
routeTree, err := makeRouteTree(t.l, routes, true)
if err != nil {
return err
}
oldDefaultMTU := t.DefaultMTU
oldMaxMTU := t.MaxMTU
newDefaultMTU := c.GetInt("tun.mtu", DefaultMTU)
newMaxMTU := newDefaultMTU
for i, r := range routes {
if r.MTU == 0 {
routes[i].MTU = newDefaultMTU
}
if r.MTU > t.MaxMTU {
newMaxMTU = r.MTU
}
}
t.MaxMTU = newMaxMTU
t.DefaultMTU = newDefaultMTU
// Teach nebula how to handle the routes before establishing them in the system table
oldRoutes := t.Routes.Swap(&routes)
t.routeTree.Store(routeTree)
if !initial {
if oldMaxMTU != newMaxMTU {
t.setMTU()
t.l.Infof("Set max MTU to %v was %v", t.MaxMTU, oldMaxMTU)
}
if oldDefaultMTU != newDefaultMTU {
err := t.setDefaultRoute()
if err != nil {
t.l.Warn(err)
} else {
t.l.Infof("Set default MTU to %v was %v", t.DefaultMTU, oldDefaultMTU)
}
}
// Remove first, if the system removes a wanted route hopefully it will be re-added next
t.removeRoutes(findRemovedRoutes(routes, *oldRoutes))
// Ensure any routes we actually want are installed
err = t.addRoutes(true)
if err != nil {
// This should never be called since addRoutes should log its own errors in a reload condition
util.LogWithContextIfNeeded("Failed to refresh routes", err, t.l)
}
}
return nil
}
func (t *tun) NewMultiQueueReader() (io.ReadWriteCloser, error) {
fd, err := unix.Open("/dev/net/tun", os.O_RDWR, 0)
if err != nil {
@ -208,7 +269,7 @@ func (t *tun) Activate() error {
if err != nil {
return err
}
fd := uintptr(s)
t.ioctlFd = uintptr(s)
ifra := ifreqAddr{
Name: devName,
@ -219,52 +280,76 @@ func (t *tun) Activate() error {
}
// Set the device ip address
if err = ioctl(fd, unix.SIOCSIFADDR, uintptr(unsafe.Pointer(&ifra))); err != nil {
if err = ioctl(t.ioctlFd, unix.SIOCSIFADDR, uintptr(unsafe.Pointer(&ifra))); err != nil {
return fmt.Errorf("failed to set tun address: %s", err)
}
// Set the device network
ifra.Addr.Addr = mask
if err = ioctl(fd, unix.SIOCSIFNETMASK, uintptr(unsafe.Pointer(&ifra))); err != nil {
if err = ioctl(t.ioctlFd, unix.SIOCSIFNETMASK, uintptr(unsafe.Pointer(&ifra))); err != nil {
return fmt.Errorf("failed to set tun netmask: %s", err)
}
// Set the device name
ifrf := ifReq{Name: devName}
if err = ioctl(fd, unix.SIOCGIFFLAGS, uintptr(unsafe.Pointer(&ifrf))); err != nil {
if err = ioctl(t.ioctlFd, unix.SIOCGIFFLAGS, uintptr(unsafe.Pointer(&ifrf))); err != nil {
return fmt.Errorf("failed to set tun device name: %s", err)
}
// Set the MTU on the device
ifm := ifreqMTU{Name: devName, MTU: int32(t.MaxMTU)}
if err = ioctl(fd, unix.SIOCSIFMTU, uintptr(unsafe.Pointer(&ifm))); err != nil {
// This is currently a non fatal condition because the route table must have the MTU set appropriately as well
t.l.WithError(err).Error("Failed to set tun mtu")
}
// Setup our default MTU
t.setMTU()
// Set the transmit queue length
ifrq := ifreqQLEN{Name: devName, Value: int32(t.TXQueueLen)}
if err = ioctl(fd, unix.SIOCSIFTXQLEN, uintptr(unsafe.Pointer(&ifrq))); err != nil {
if err = ioctl(t.ioctlFd, unix.SIOCSIFTXQLEN, uintptr(unsafe.Pointer(&ifrq))); err != nil {
// If we can't set the queue length nebula will still work but it may lead to packet loss
t.l.WithError(err).Error("Failed to set tun tx queue length")
}
// Bring up the interface
ifrf.Flags = ifrf.Flags | unix.IFF_UP
if err = ioctl(fd, unix.SIOCSIFFLAGS, uintptr(unsafe.Pointer(&ifrf))); err != nil {
if err = ioctl(t.ioctlFd, unix.SIOCSIFFLAGS, uintptr(unsafe.Pointer(&ifrf))); err != nil {
return fmt.Errorf("failed to bring the tun device up: %s", err)
}
// Set the routes
link, err := netlink.LinkByName(t.Device)
if err != nil {
return fmt.Errorf("failed to get tun device link: %s", err)
}
t.deviceIndex = link.Attrs().Index
if err = t.setDefaultRoute(); err != nil {
return err
}
// Set the routes
if err = t.addRoutes(false); err != nil {
return err
}
// Run the interface
ifrf.Flags = ifrf.Flags | unix.IFF_UP | unix.IFF_RUNNING
if err = ioctl(t.ioctlFd, unix.SIOCSIFFLAGS, uintptr(unsafe.Pointer(&ifrf))); err != nil {
return fmt.Errorf("failed to run tun device: %s", err)
}
return nil
}
func (t *tun) setMTU() {
// Set the MTU on the device
ifm := ifreqMTU{Name: t.deviceBytes(), MTU: int32(t.MaxMTU)}
if err := ioctl(t.ioctlFd, unix.SIOCSIFMTU, uintptr(unsafe.Pointer(&ifm))); err != nil {
// This is currently a non fatal condition because the route table must have the MTU set appropriately as well
t.l.WithError(err).Error("Failed to set tun mtu")
}
}
func (t *tun) setDefaultRoute() error {
// Default route
dr := &net.IPNet{IP: t.cidr.IP.Mask(t.cidr.Mask), Mask: t.cidr.Mask}
nr := netlink.Route{
LinkIndex: link.Attrs().Index,
LinkIndex: t.deviceIndex,
Dst: dr,
MTU: t.DefaultMTU,
AdvMSS: t.advMSS(Route{}),
@ -274,19 +359,24 @@ func (t *tun) Activate() error {
Table: unix.RT_TABLE_MAIN,
Type: unix.RTN_UNICAST,
}
err = netlink.RouteReplace(&nr)
err := netlink.RouteReplace(&nr)
if err != nil {
return fmt.Errorf("failed to set mtu %v on the default route %v; %v", t.DefaultMTU, dr, err)
}
return nil
}
func (t *tun) addRoutes(logErrors bool) error {
// Path routes
for _, r := range t.Routes {
routes := *t.Routes.Load()
for _, r := range routes {
if !r.Install {
continue
}
nr := netlink.Route{
LinkIndex: link.Attrs().Index,
LinkIndex: t.deviceIndex,
Dst: r.Cidr,
MTU: r.MTU,
AdvMSS: t.advMSS(r),
@ -297,21 +387,49 @@ func (t *tun) Activate() error {
nr.Priority = r.Metric
}
err = netlink.RouteAdd(&nr)
err := netlink.RouteReplace(&nr)
if err != nil {
return fmt.Errorf("failed to set mtu %v on route %v; %v", r.MTU, r.Cidr, err)
retErr := util.NewContextualError("Failed to add route", map[string]interface{}{"route": r}, err)
if logErrors {
retErr.Log(t.l)
} else {
return retErr
}
} else {
t.l.WithField("route", r).Info("Added route")
}
}
// Run the interface
ifrf.Flags = ifrf.Flags | unix.IFF_UP | unix.IFF_RUNNING
if err = ioctl(fd, unix.SIOCSIFFLAGS, uintptr(unsafe.Pointer(&ifrf))); err != nil {
return fmt.Errorf("failed to run tun device: %s", err)
}
return nil
}
func (t *tun) removeRoutes(routes []Route) {
for _, r := range routes {
if !r.Install {
continue
}
nr := netlink.Route{
LinkIndex: t.deviceIndex,
Dst: r.Cidr,
MTU: r.MTU,
AdvMSS: t.advMSS(r),
Scope: unix.RT_SCOPE_LINK,
}
if r.Metric > 0 {
nr.Priority = r.Metric
}
err := netlink.RouteDel(&nr)
if err != nil {
t.l.WithError(err).WithField("route", r).Error("Failed to remove route")
} else {
t.l.WithField("route", r).Info("Removed route")
}
}
}
func (t *tun) Cidr() *net.IPNet {
return t.cidr
}
@ -410,5 +528,9 @@ func (t *tun) Close() error {
t.ReadWriteCloser.Close()
}
if t.ioctlFd > 0 {
os.NewFile(t.ioctlFd, "ioctlFd").Close()
}
return nil
}

View File

@ -11,12 +11,15 @@ import (
"os/exec"
"regexp"
"strconv"
"sync/atomic"
"syscall"
"unsafe"
"github.com/sirupsen/logrus"
"github.com/slackhq/nebula/cidr"
"github.com/slackhq/nebula/config"
"github.com/slackhq/nebula/iputil"
"github.com/slackhq/nebula/util"
)
type ifreqDestroy struct {
@ -28,8 +31,8 @@ type tun struct {
Device string
cidr *net.IPNet
MTU int
Routes []Route
routeTree *cidr.Tree4[iputil.VpnIp]
Routes atomic.Pointer[[]Route]
routeTree atomic.Pointer[cidr.Tree4[iputil.VpnIp]]
l *logrus.Logger
io.ReadWriteCloser
@ -56,43 +59,50 @@ func (t *tun) Close() error {
return nil
}
func newTunFromFd(_ *logrus.Logger, _ int, _ *net.IPNet, _ int, _ []Route, _ int, _ bool) (*tun, error) {
func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ *net.IPNet) (*tun, error) {
return nil, fmt.Errorf("newTunFromFd not supported in NetBSD")
}
var deviceNameRE = regexp.MustCompile(`^tun[0-9]+$`)
func newTun(l *logrus.Logger, deviceName string, cidr *net.IPNet, defaultMTU int, routes []Route, _ int, _ bool, _ bool) (*tun, error) {
func newTun(c *config.C, l *logrus.Logger, cidr *net.IPNet, _ bool) (*tun, error) {
// Try to open tun device
var file *os.File
var err error
deviceName := c.GetString("tun.dev", "")
if deviceName == "" {
return nil, fmt.Errorf("a device name in the format of /dev/tunN must be specified")
}
if !deviceNameRE.MatchString(deviceName) {
return nil, fmt.Errorf("a device name in the format of /dev/tunN must be specified")
}
file, err = os.OpenFile("/dev/"+deviceName, os.O_RDWR, 0)
if err != nil {
return nil, err
}
routeTree, err := makeRouteTree(l, routes, false)
if err != nil {
return nil, err
}
return &tun{
t := &tun{
ReadWriteCloser: file,
Device: deviceName,
cidr: cidr,
MTU: defaultMTU,
Routes: routes,
routeTree: routeTree,
MTU: c.GetInt("tun.mtu", DefaultMTU),
l: l,
}, nil
}
err = t.reload(c, true)
if err != nil {
return nil, err
}
c.RegisterReloadCallback(func(c *config.C) {
err := t.reload(c, false)
if err != nil {
util.LogWithContextIfNeeded("failed to reload tun device", err, t.l)
}
})
return t, nil
}
func (t *tun) Activate() error {
@ -116,17 +126,42 @@ func (t *tun) Activate() error {
if err = cmd.Run(); err != nil {
return fmt.Errorf("failed to run 'ifconfig': %s", err)
}
// Unsafe path routes
for _, r := range t.Routes {
if r.Via == nil || !r.Install {
// We don't allow route MTUs so only install routes with a via
continue
return t.addRoutes(false)
}
func (t *tun) reload(c *config.C, initial bool) error {
change, routes, err := getAllRoutesFromConfig(c, t.cidr, initial)
if err != nil {
return err
}
if !initial && !change {
return nil
}
routeTree, err := makeRouteTree(t.l, routes, false)
if err != nil {
return err
}
// Teach nebula how to handle the routes before establishing them in the system table
oldRoutes := t.Routes.Swap(&routes)
t.routeTree.Store(routeTree)
if !initial {
// Remove first, if the system removes a wanted route hopefully it will be re-added next
err := t.removeRoutes(findRemovedRoutes(routes, *oldRoutes))
if err != nil {
util.LogWithContextIfNeeded("Failed to remove routes", err, t.l)
}
cmd = exec.Command("/sbin/route", "-n", "add", "-net", r.Cidr.String(), t.cidr.IP.String())
t.l.Debug("command: ", cmd.String())
if err = cmd.Run(); err != nil {
return fmt.Errorf("failed to run 'route add' for unsafe_route %s: %s", r.Cidr.String(), err)
// Ensure any routes we actually want are installed
err = t.addRoutes(true)
if err != nil {
// Catch any stray logs
util.LogWithContextIfNeeded("Failed to add routes", err, t.l)
}
}
@ -134,7 +169,7 @@ func (t *tun) Activate() error {
}
func (t *tun) RouteFor(ip iputil.VpnIp) iputil.VpnIp {
_, r := t.routeTree.MostSpecificContains(ip)
_, r := t.routeTree.Load().MostSpecificContains(ip)
return r
}
@ -150,6 +185,46 @@ func (t *tun) NewMultiQueueReader() (io.ReadWriteCloser, error) {
return nil, fmt.Errorf("TODO: multiqueue not implemented for netbsd")
}
func (t *tun) addRoutes(logErrors bool) error {
routes := *t.Routes.Load()
for _, r := range routes {
if r.Via == nil || !r.Install {
// We don't allow route MTUs so only install routes with a via
continue
}
cmd := exec.Command("/sbin/route", "-n", "add", "-net", r.Cidr.String(), t.cidr.IP.String())
t.l.Debug("command: ", cmd.String())
if err := cmd.Run(); err != nil {
retErr := util.NewContextualError("failed to run 'route add' for unsafe_route", map[string]interface{}{"route": r}, err)
if logErrors {
retErr.Log(t.l)
} else {
return retErr
}
}
}
return nil
}
func (t *tun) removeRoutes(routes []Route) error {
for _, r := range routes {
if !r.Install {
continue
}
cmd := exec.Command("/sbin/route", "-n", "delete", "-net", r.Cidr.String(), t.cidr.IP.String())
t.l.Debug("command: ", cmd.String())
if err := cmd.Run(); err != nil {
t.l.WithError(err).WithField("route", r).Error("Failed to remove route")
} else {
t.l.WithField("route", r).Info("Removed route")
}
}
return nil
}
func (t *tun) deviceBytes() (o [16]byte) {
for i, c := range t.Device {
o[i] = byte(c)

View File

@ -11,19 +11,22 @@ import (
"os/exec"
"regexp"
"strconv"
"sync/atomic"
"syscall"
"github.com/sirupsen/logrus"
"github.com/slackhq/nebula/cidr"
"github.com/slackhq/nebula/config"
"github.com/slackhq/nebula/iputil"
"github.com/slackhq/nebula/util"
)
type tun struct {
Device string
cidr *net.IPNet
MTU int
Routes []Route
routeTree *cidr.Tree4[iputil.VpnIp]
Routes atomic.Pointer[[]Route]
routeTree atomic.Pointer[cidr.Tree4[iputil.VpnIp]]
l *logrus.Logger
io.ReadWriteCloser
@ -40,13 +43,14 @@ func (t *tun) Close() error {
return nil
}
func newTunFromFd(_ *logrus.Logger, _ int, _ *net.IPNet, _ int, _ []Route, _ int, _ bool) (*tun, error) {
func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ *net.IPNet) (*tun, error) {
return nil, fmt.Errorf("newTunFromFd not supported in OpenBSD")
}
var deviceNameRE = regexp.MustCompile(`^tun[0-9]+$`)
func newTun(l *logrus.Logger, deviceName string, cidr *net.IPNet, defaultMTU int, routes []Route, _ int, _ bool, _ bool) (*tun, error) {
func newTun(c *config.C, l *logrus.Logger, cidr *net.IPNet, _ bool) (*tun, error) {
deviceName := c.GetString("tun.dev", "")
if deviceName == "" {
return nil, fmt.Errorf("a device name in the format of tunN must be specified")
}
@ -60,20 +64,64 @@ func newTun(l *logrus.Logger, deviceName string, cidr *net.IPNet, defaultMTU int
return nil, err
}
routeTree, err := makeRouteTree(l, routes, false)
t := &tun{
ReadWriteCloser: file,
Device: deviceName,
cidr: cidr,
MTU: c.GetInt("tun.mtu", DefaultMTU),
l: l,
}
err = t.reload(c, true)
if err != nil {
return nil, err
}
return &tun{
ReadWriteCloser: file,
Device: deviceName,
cidr: cidr,
MTU: defaultMTU,
Routes: routes,
routeTree: routeTree,
l: l,
}, nil
c.RegisterReloadCallback(func(c *config.C) {
err := t.reload(c, false)
if err != nil {
util.LogWithContextIfNeeded("failed to reload tun device", err, t.l)
}
})
return t, nil
}
func (t *tun) reload(c *config.C, initial bool) error {
change, routes, err := getAllRoutesFromConfig(c, t.cidr, initial)
if err != nil {
return err
}
if !initial && !change {
return nil
}
routeTree, err := makeRouteTree(t.l, routes, false)
if err != nil {
return err
}
// Teach nebula how to handle the routes before establishing them in the system table
oldRoutes := t.Routes.Swap(&routes)
t.routeTree.Store(routeTree)
if !initial {
// Remove first, if the system removes a wanted route hopefully it will be re-added next
err := t.removeRoutes(findRemovedRoutes(routes, *oldRoutes))
if err != nil {
util.LogWithContextIfNeeded("Failed to remove routes", err, t.l)
}
// Ensure any routes we actually want are installed
err = t.addRoutes(true)
if err != nil {
// Catch any stray logs
util.LogWithContextIfNeeded("Failed to add routes", err, t.l)
}
}
return nil
}
func (t *tun) Activate() error {
@ -98,25 +146,52 @@ func (t *tun) Activate() error {
}
// Unsafe path routes
for _, r := range t.Routes {
return t.addRoutes(false)
}
func (t *tun) RouteFor(ip iputil.VpnIp) iputil.VpnIp {
_, r := t.routeTree.Load().MostSpecificContains(ip)
return r
}
func (t *tun) addRoutes(logErrors bool) error {
routes := *t.Routes.Load()
for _, r := range routes {
if r.Via == nil || !r.Install {
// We don't allow route MTUs so only install routes with a via
continue
}
cmd = exec.Command("/sbin/route", "-n", "add", "-inet", r.Cidr.String(), t.cidr.IP.String())
cmd := exec.Command("/sbin/route", "-n", "add", "-inet", r.Cidr.String(), t.cidr.IP.String())
t.l.Debug("command: ", cmd.String())
if err = cmd.Run(); err != nil {
return fmt.Errorf("failed to run 'route add' for unsafe_route %s: %s", r.Cidr.String(), err)
if err := cmd.Run(); err != nil {
retErr := util.NewContextualError("failed to run 'route add' for unsafe_route", map[string]interface{}{"route": r}, err)
if logErrors {
retErr.Log(t.l)
} else {
return retErr
}
}
}
return nil
}
func (t *tun) RouteFor(ip iputil.VpnIp) iputil.VpnIp {
_, r := t.routeTree.MostSpecificContains(ip)
return r
func (t *tun) removeRoutes(routes []Route) error {
for _, r := range routes {
if !r.Install {
continue
}
cmd := exec.Command("/sbin/route", "-n", "delete", "-inet", r.Cidr.String(), t.cidr.IP.String())
t.l.Debug("command: ", cmd.String())
if err := cmd.Run(); err != nil {
t.l.WithError(err).WithField("route", r).Error("Failed to remove route")
} else {
t.l.WithField("route", r).Info("Removed route")
}
}
return nil
}
func (t *tun) Cidr() *net.IPNet {

View File

@ -12,6 +12,7 @@ import (
"github.com/sirupsen/logrus"
"github.com/slackhq/nebula/cidr"
"github.com/slackhq/nebula/config"
"github.com/slackhq/nebula/iputil"
)
@ -27,14 +28,18 @@ type TestTun struct {
TxPackets chan []byte // Packets transmitted outside by nebula
}
func newTun(l *logrus.Logger, deviceName string, cidr *net.IPNet, _ int, routes []Route, _ int, _ bool, _ bool) (*TestTun, error) {
func newTun(c *config.C, l *logrus.Logger, cidr *net.IPNet, _ bool) (*TestTun, error) {
_, routes, err := getAllRoutesFromConfig(c, cidr, true)
if err != nil {
return nil, err
}
routeTree, err := makeRouteTree(l, routes, false)
if err != nil {
return nil, err
}
return &TestTun{
Device: deviceName,
Device: c.GetString("tun.dev", ""),
cidr: cidr,
Routes: routes,
routeTree: routeTree,
@ -44,7 +49,7 @@ func newTun(l *logrus.Logger, deviceName string, cidr *net.IPNet, _ int, routes
}, nil
}
func newTunFromFd(_ *logrus.Logger, _ int, _ *net.IPNet, _ int, _ []Route, _ int, _ bool) (*TestTun, error) {
func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ *net.IPNet) (*TestTun, error) {
return nil, fmt.Errorf("newTunFromFd not supported")
}

View File

@ -6,10 +6,13 @@ import (
"net"
"os/exec"
"strconv"
"sync/atomic"
"github.com/sirupsen/logrus"
"github.com/slackhq/nebula/cidr"
"github.com/slackhq/nebula/config"
"github.com/slackhq/nebula/iputil"
"github.com/slackhq/nebula/util"
"github.com/songgao/water"
)
@ -17,25 +20,34 @@ type waterTun struct {
Device string
cidr *net.IPNet
MTU int
Routes []Route
routeTree *cidr.Tree4[iputil.VpnIp]
Routes atomic.Pointer[[]Route]
routeTree atomic.Pointer[cidr.Tree4[iputil.VpnIp]]
l *logrus.Logger
f *net.Interface
*water.Interface
}
func newWaterTun(l *logrus.Logger, cidr *net.IPNet, defaultMTU int, routes []Route) (*waterTun, error) {
routeTree, err := makeRouteTree(l, routes, false)
func newWaterTun(c *config.C, l *logrus.Logger, cidr *net.IPNet, _ bool) (*waterTun, error) {
// NOTE: You cannot set the deviceName under Windows, so you must check tun.Device after calling .Activate()
t := &waterTun{
cidr: cidr,
MTU: c.GetInt("tun.mtu", DefaultMTU),
l: l,
}
err := t.reload(c, true)
if err != nil {
return nil, err
}
// NOTE: You cannot set the deviceName under Windows, so you must check tun.Device after calling .Activate()
return &waterTun{
cidr: cidr,
MTU: defaultMTU,
Routes: routes,
routeTree: routeTree,
}, nil
c.RegisterReloadCallback(func(c *config.C) {
err := t.reload(c, false)
if err != nil {
util.LogWithContextIfNeeded("failed to reload tun device", err, t.l)
}
})
return t, nil
}
func (t *waterTun) Activate() error {
@ -74,30 +86,104 @@ func (t *waterTun) Activate() error {
return fmt.Errorf("failed to run 'netsh' to set MTU: %s", err)
}
iface, err := net.InterfaceByName(t.Device)
t.f, err = net.InterfaceByName(t.Device)
if err != nil {
return fmt.Errorf("failed to find interface named %s: %v", t.Device, err)
}
for _, r := range t.Routes {
if r.Via == nil || !r.Install {
// We don't allow route MTUs so only install routes with a via
continue
}
err = t.addRoutes(false)
if err != nil {
return err
}
err = exec.Command(
"C:\\Windows\\System32\\route.exe", "add", r.Cidr.String(), r.Via.String(), "IF", strconv.Itoa(iface.Index), "METRIC", strconv.Itoa(r.Metric),
).Run()
return nil
}
func (t *waterTun) reload(c *config.C, initial bool) error {
change, routes, err := getAllRoutesFromConfig(c, t.cidr, initial)
if err != nil {
return err
}
if !initial && !change {
return nil
}
routeTree, err := makeRouteTree(t.l, routes, false)
if err != nil {
return err
}
// Teach nebula how to handle the routes before establishing them in the system table
oldRoutes := t.Routes.Swap(&routes)
t.routeTree.Store(routeTree)
if !initial {
// Remove first, if the system removes a wanted route hopefully it will be re-added next
t.removeRoutes(findRemovedRoutes(routes, *oldRoutes))
// Ensure any routes we actually want are installed
err = t.addRoutes(true)
if err != nil {
return fmt.Errorf("failed to add the unsafe_route %s: %v", r.Cidr.String(), err)
// Catch any stray logs
util.LogWithContextIfNeeded("Failed to set routes", err, t.l)
} else {
for _, r := range findRemovedRoutes(routes, *oldRoutes) {
t.l.WithField("route", r).Info("Removed route")
}
}
}
return nil
}
func (t *waterTun) addRoutes(logErrors bool) error {
// Path routes
routes := *t.Routes.Load()
for _, r := range routes {
if r.Via == nil || !r.Install {
// We don't allow route MTUs so only install routes with a via
continue
}
err := exec.Command(
"C:\\Windows\\System32\\route.exe", "add", r.Cidr.String(), r.Via.String(), "IF", strconv.Itoa(t.f.Index), "METRIC", strconv.Itoa(r.Metric),
).Run()
if err != nil {
retErr := util.NewContextualError("Failed to add route", map[string]interface{}{"route": r}, err)
if logErrors {
retErr.Log(t.l)
} else {
return retErr
}
} else {
t.l.WithField("route", r).Info("Added route")
}
}
return nil
}
func (t *waterTun) removeRoutes(routes []Route) {
for _, r := range routes {
if !r.Install {
continue
}
err := exec.Command(
"C:\\Windows\\System32\\route.exe", "delete", r.Cidr.String(), r.Via.String(), "IF", strconv.Itoa(t.f.Index), "METRIC", strconv.Itoa(r.Metric),
).Run()
if err != nil {
t.l.WithError(err).WithField("route", r).Error("Failed to remove route")
} else {
t.l.WithField("route", r).Info("Removed route")
}
}
}
func (t *waterTun) RouteFor(ip iputil.VpnIp) iputil.VpnIp {
_, r := t.routeTree.MostSpecificContains(ip)
_, r := t.routeTree.Load().MostSpecificContains(ip)
return r
}

View File

@ -12,13 +12,14 @@ import (
"syscall"
"github.com/sirupsen/logrus"
"github.com/slackhq/nebula/config"
)
func newTunFromFd(_ *logrus.Logger, _ int, _ *net.IPNet, _ int, _ []Route, _ int, _ bool) (Device, error) {
func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ *net.IPNet) (Device, error) {
return nil, fmt.Errorf("newTunFromFd not supported in Windows")
}
func newTun(l *logrus.Logger, deviceName string, cidr *net.IPNet, defaultMTU int, routes []Route, _ int, _ bool, _ bool) (Device, error) {
func newTun(c *config.C, l *logrus.Logger, cidr *net.IPNet, multiqueue bool) (Device, error) {
useWintun := true
if err := checkWinTunExists(); err != nil {
l.WithError(err).Warn("Check Wintun driver failed, fallback to wintap driver")
@ -26,14 +27,14 @@ func newTun(l *logrus.Logger, deviceName string, cidr *net.IPNet, defaultMTU int
}
if useWintun {
device, err := newWinTun(l, deviceName, cidr, defaultMTU, routes)
device, err := newWinTun(c, l, cidr, multiqueue)
if err != nil {
return nil, fmt.Errorf("create Wintun interface failed, %w", err)
}
return device, nil
}
device, err := newWaterTun(l, cidr, defaultMTU, routes)
device, err := newWaterTun(c, l, cidr, multiqueue)
if err != nil {
return nil, fmt.Errorf("create wintap driver failed, %w", err)
}

View File

@ -6,11 +6,14 @@ import (
"io"
"net"
"net/netip"
"sync/atomic"
"unsafe"
"github.com/sirupsen/logrus"
"github.com/slackhq/nebula/cidr"
"github.com/slackhq/nebula/config"
"github.com/slackhq/nebula/iputil"
"github.com/slackhq/nebula/util"
"github.com/slackhq/nebula/wintun"
"golang.org/x/sys/windows"
"golang.zx2c4.com/wireguard/windows/tunnel/winipcfg"
@ -23,8 +26,9 @@ type winTun struct {
cidr *net.IPNet
prefix netip.Prefix
MTU int
Routes []Route
routeTree *cidr.Tree4[iputil.VpnIp]
Routes atomic.Pointer[[]Route]
routeTree atomic.Pointer[cidr.Tree4[iputil.VpnIp]]
l *logrus.Logger
tun *wintun.NativeTun
}
@ -48,83 +52,148 @@ func generateGUIDByDeviceName(name string) (*windows.GUID, error) {
return (*windows.GUID)(unsafe.Pointer(&sum[0])), nil
}
func newWinTun(l *logrus.Logger, deviceName string, cidr *net.IPNet, defaultMTU int, routes []Route) (*winTun, error) {
func newWinTun(c *config.C, l *logrus.Logger, cidr *net.IPNet, _ bool) (*winTun, error) {
deviceName := c.GetString("tun.dev", "")
guid, err := generateGUIDByDeviceName(deviceName)
if err != nil {
return nil, fmt.Errorf("generate GUID failed: %w", err)
}
var tunDevice wintun.Device
tunDevice, err = wintun.CreateTUNWithRequestedGUID(deviceName, guid, defaultMTU)
if err != nil {
// Windows 10 has an issue with unclean shutdowns not fully cleaning up the wintun device.
// Trying a second time resolves the issue.
l.WithError(err).Debug("Failed to create wintun device, retrying")
tunDevice, err = wintun.CreateTUNWithRequestedGUID(deviceName, guid, defaultMTU)
if err != nil {
return nil, fmt.Errorf("create TUN device failed: %w", err)
}
}
routeTree, err := makeRouteTree(l, routes, false)
if err != nil {
return nil, err
}
prefix, err := iputil.ToNetIpPrefix(*cidr)
if err != nil {
return nil, err
}
return &winTun{
Device: deviceName,
cidr: cidr,
prefix: prefix,
MTU: defaultMTU,
Routes: routes,
routeTree: routeTree,
t := &winTun{
Device: deviceName,
cidr: cidr,
prefix: prefix,
MTU: c.GetInt("tun.mtu", DefaultMTU),
l: l,
}
tun: tunDevice.(*wintun.NativeTun),
}, nil
err = t.reload(c, true)
if err != nil {
return nil, err
}
var tunDevice wintun.Device
tunDevice, err = wintun.CreateTUNWithRequestedGUID(deviceName, guid, t.MTU)
if err != nil {
// Windows 10 has an issue with unclean shutdowns not fully cleaning up the wintun device.
// Trying a second time resolves the issue.
l.WithError(err).Debug("Failed to create wintun device, retrying")
tunDevice, err = wintun.CreateTUNWithRequestedGUID(deviceName, guid, t.MTU)
if err != nil {
return nil, fmt.Errorf("create TUN device failed: %w", err)
}
}
t.tun = tunDevice.(*wintun.NativeTun)
c.RegisterReloadCallback(func(c *config.C) {
err := t.reload(c, false)
if err != nil {
util.LogWithContextIfNeeded("failed to reload tun device", err, t.l)
}
})
return t, nil
}
func (t *winTun) reload(c *config.C, initial bool) error {
change, routes, err := getAllRoutesFromConfig(c, t.cidr, initial)
if err != nil {
return err
}
if !initial && !change {
return nil
}
routeTree, err := makeRouteTree(t.l, routes, false)
if err != nil {
return err
}
// Teach nebula how to handle the routes before establishing them in the system table
oldRoutes := t.Routes.Swap(&routes)
t.routeTree.Store(routeTree)
if !initial {
// Remove first, if the system removes a wanted route hopefully it will be re-added next
err := t.removeRoutes(findRemovedRoutes(routes, *oldRoutes))
if err != nil {
util.LogWithContextIfNeeded("Failed to remove routes", err, t.l)
}
// Ensure any routes we actually want are installed
err = t.addRoutes(true)
if err != nil {
// Catch any stray logs
util.LogWithContextIfNeeded("Failed to add routes", err, t.l)
}
}
return nil
}
func (t *winTun) Activate() error {
luid := winipcfg.LUID(t.tun.LUID())
if err := luid.SetIPAddresses([]netip.Prefix{t.prefix}); err != nil {
err := luid.SetIPAddresses([]netip.Prefix{t.prefix})
if err != nil {
return fmt.Errorf("failed to set address: %w", err)
}
foundDefault4 := false
routes := make([]*winipcfg.RouteData, 0, len(t.Routes)+1)
err = t.addRoutes(false)
if err != nil {
return err
}
for _, r := range t.Routes {
return nil
}
func (t *winTun) addRoutes(logErrors bool) error {
luid := winipcfg.LUID(t.tun.LUID())
routes := *t.Routes.Load()
foundDefault4 := false
for _, r := range routes {
if r.Via == nil || !r.Install {
// We don't allow route MTUs so only install routes with a via
continue
}
prefix, err := iputil.ToNetIpPrefix(*r.Cidr)
if err != nil {
retErr := util.NewContextualError("Failed to parse cidr to netip prefix, ignoring route", map[string]interface{}{"route": r}, err)
if logErrors {
retErr.Log(t.l)
continue
} else {
return retErr
}
}
// Add our unsafe route
err = luid.AddRoute(prefix, r.Via.ToNetIpAddr(), uint32(r.Metric))
if err != nil {
retErr := util.NewContextualError("Failed to add route", map[string]interface{}{"route": r}, err)
if logErrors {
retErr.Log(t.l)
continue
} else {
return retErr
}
} else {
t.l.WithField("route", r).Info("Added route")
}
if !foundDefault4 {
if ones, bits := r.Cidr.Mask.Size(); ones == 0 && bits != 0 {
foundDefault4 = true
}
}
prefix, err := iputil.ToNetIpPrefix(*r.Cidr)
if err != nil {
return err
}
// Add our unsafe route
routes = append(routes, &winipcfg.RouteData{
Destination: prefix,
NextHop: r.Via.ToNetIpAddr(),
Metric: uint32(r.Metric),
})
}
if err := luid.AddRoutes(routes); err != nil {
return fmt.Errorf("failed to add routes: %w", err)
}
ipif, err := luid.IPInterface(windows.AF_INET)
@ -141,12 +210,35 @@ func (t *winTun) Activate() error {
if err := ipif.Set(); err != nil {
return fmt.Errorf("failed to set ip interface: %w", err)
}
return nil
}
func (t *winTun) removeRoutes(routes []Route) error {
luid := winipcfg.LUID(t.tun.LUID())
for _, r := range routes {
if !r.Install {
continue
}
prefix, err := iputil.ToNetIpPrefix(*r.Cidr)
if err != nil {
t.l.WithError(err).WithField("route", r).Info("Failed to convert cidr to netip prefix")
continue
}
err = luid.DeleteRoute(prefix, r.Via.ToNetIpAddr())
if err != nil {
t.l.WithError(err).WithField("route", r).Error("Failed to remove route")
} else {
t.l.WithField("route", r).Info("Removed route")
}
}
return nil
}
func (t *winTun) RouteFor(ip iputil.VpnIp) iputil.VpnIp {
_, r := t.routeTree.MostSpecificContains(ip)
_, r := t.routeTree.Load().MostSpecificContains(ip)
return r
}

2
ssh.go
View File

@ -939,7 +939,7 @@ func sshPrintTunnel(ifce *Interface, fs interface{}, a []string, w sshd.StringWr
enc.SetIndent("", " ")
}
return enc.Encode(copyHostInfo(hostInfo, ifce.hostMap.preferredRanges))
return enc.Encode(copyHostInfo(hostInfo, ifce.hostMap.GetPreferredRanges()))
}
func sshReload(c *config.C, w sshd.StringWriter) error {

View File

@ -2,6 +2,7 @@ package util
import (
"errors"
"fmt"
"github.com/sirupsen/logrus"
)
@ -40,7 +41,7 @@ func (ce *ContextualError) Error() string {
if ce.RealError == nil {
return ce.Context
}
return ce.RealError.Error()
return fmt.Errorf("%s (%v): %w", ce.Context, ce.Fields, ce.RealError).Error()
}
func (ce *ContextualError) Unwrap() error {