mirror of
https://github.com/slackhq/nebula.git
synced 2025-11-11 19:03:57 +01:00
Merge remote-tracking branch 'origin/master' into mutex-debug
This commit is contained in:
commit
0ccfad1a1e
20
.github/ISSUE_TEMPLATE/bug-report.yml
vendored
20
.github/ISSUE_TEMPLATE/bug-report.yml
vendored
@ -14,7 +14,7 @@ body:
|
||||
- type: input
|
||||
id: version
|
||||
attributes:
|
||||
label: What version of `nebula` are you using?
|
||||
label: What version of `nebula` are you using? (`nebula -version`)
|
||||
placeholder: 0.0.0
|
||||
validations:
|
||||
required: true
|
||||
@ -41,10 +41,17 @@ body:
|
||||
attributes:
|
||||
label: Logs from affected hosts
|
||||
description: |
|
||||
Provide logs from all affected hosts during the time of the issue.
|
||||
Please provide logs from ALL affected hosts during the time of the issue. If you do not provide logs we will be unable to assist you!
|
||||
|
||||
[Learn how to find Nebula logs here.](https://nebula.defined.net/docs/guides/viewing-nebula-logs/)
|
||||
|
||||
Improve formatting by using <code>```</code> at the beginning and end of each log block.
|
||||
value: |
|
||||
```
|
||||
|
||||
```
|
||||
validations:
|
||||
required: false
|
||||
required: true
|
||||
|
||||
- type: textarea
|
||||
id: configs
|
||||
@ -52,6 +59,11 @@ body:
|
||||
label: Config files from affected hosts
|
||||
description: |
|
||||
Provide config files for all affected hosts.
|
||||
|
||||
Improve formatting by using <code>```</code> at the beginning and end of each config file.
|
||||
value: |
|
||||
```
|
||||
|
||||
```
|
||||
validations:
|
||||
required: false
|
||||
required: true
|
||||
|
||||
@ -142,15 +142,22 @@ func (tree *Tree4[T]) MostSpecificContains(ip iputil.VpnIp) (ok bool, value T) {
|
||||
return ok, value
|
||||
}
|
||||
|
||||
// Match finds the most specific match
|
||||
// TODO this is exact match
|
||||
func (tree *Tree4[T]) Match(ip iputil.VpnIp) (ok bool, value T) {
|
||||
type eachFunc[T any] func(T) bool
|
||||
|
||||
// EachContains will call a function, passing the value, for each entry until the function returns true or the search is complete
|
||||
// The final return value will be true if the provided function returned true
|
||||
func (tree *Tree4[T]) EachContains(ip iputil.VpnIp, each eachFunc[T]) bool {
|
||||
bit := startbit
|
||||
node := tree.root
|
||||
lastNode := node
|
||||
|
||||
for node != nil {
|
||||
lastNode = node
|
||||
if node.hasValue {
|
||||
// If the each func returns true then we can exit the loop
|
||||
if each(node.value) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
if ip&bit != 0 {
|
||||
node = node.right
|
||||
} else {
|
||||
@ -160,10 +167,33 @@ func (tree *Tree4[T]) Match(ip iputil.VpnIp) (ok bool, value T) {
|
||||
bit >>= 1
|
||||
}
|
||||
|
||||
if bit == 0 && lastNode != nil {
|
||||
value = lastNode.value
|
||||
ok = true
|
||||
return false
|
||||
}
|
||||
|
||||
// GetCIDR returns the entry added by the most recent matching AddCIDR call
|
||||
func (tree *Tree4[T]) GetCIDR(cidr *net.IPNet) (ok bool, value T) {
|
||||
bit := startbit
|
||||
node := tree.root
|
||||
|
||||
ip := iputil.Ip2VpnIp(cidr.IP)
|
||||
mask := iputil.Ip2VpnIp(cidr.Mask)
|
||||
|
||||
// Find our last ancestor in the tree
|
||||
for node != nil && bit&mask != 0 {
|
||||
if ip&bit != 0 {
|
||||
node = node.right
|
||||
} else {
|
||||
node = node.left
|
||||
}
|
||||
|
||||
bit = bit >> 1
|
||||
}
|
||||
|
||||
if bit&mask == 0 && node != nil {
|
||||
value = node.value
|
||||
ok = node.hasValue
|
||||
}
|
||||
|
||||
return ok, value
|
||||
}
|
||||
|
||||
|
||||
@ -115,35 +115,36 @@ func TestCIDRTree_MostSpecificContains(t *testing.T) {
|
||||
assert.Equal(t, "cool", r)
|
||||
}
|
||||
|
||||
func TestCIDRTree_Match(t *testing.T) {
|
||||
func TestTree4_GetCIDR(t *testing.T) {
|
||||
tree := NewTree4[string]()
|
||||
tree.AddCIDR(Parse("4.1.1.0/32"), "1a")
|
||||
tree.AddCIDR(Parse("4.1.1.1/32"), "1b")
|
||||
tree.AddCIDR(Parse("1.0.0.0/8"), "1")
|
||||
tree.AddCIDR(Parse("2.1.0.0/16"), "2")
|
||||
tree.AddCIDR(Parse("3.1.1.0/24"), "3")
|
||||
tree.AddCIDR(Parse("4.1.1.0/24"), "4a")
|
||||
tree.AddCIDR(Parse("4.1.1.1/32"), "4b")
|
||||
tree.AddCIDR(Parse("4.1.2.1/32"), "4c")
|
||||
tree.AddCIDR(Parse("254.0.0.0/4"), "5")
|
||||
|
||||
tests := []struct {
|
||||
Found bool
|
||||
Result interface{}
|
||||
IP string
|
||||
IPNet *net.IPNet
|
||||
}{
|
||||
{true, "1a", "4.1.1.0"},
|
||||
{true, "1b", "4.1.1.1"},
|
||||
{true, "1", Parse("1.0.0.0/8")},
|
||||
{true, "2", Parse("2.1.0.0/16")},
|
||||
{true, "3", Parse("3.1.1.0/24")},
|
||||
{true, "4a", Parse("4.1.1.0/24")},
|
||||
{true, "4b", Parse("4.1.1.1/32")},
|
||||
{true, "4c", Parse("4.1.2.1/32")},
|
||||
{true, "5", Parse("254.0.0.0/4")},
|
||||
{false, "", Parse("2.0.0.0/8")},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
ok, r := tree.Match(iputil.Ip2VpnIp(net.ParseIP(tt.IP)))
|
||||
ok, r := tree.GetCIDR(tt.IPNet)
|
||||
assert.Equal(t, tt.Found, ok)
|
||||
assert.Equal(t, tt.Result, r)
|
||||
}
|
||||
|
||||
tree = NewTree4[string]()
|
||||
tree.AddCIDR(Parse("1.1.1.1/0"), "cool")
|
||||
ok, r := tree.Contains(iputil.Ip2VpnIp(net.ParseIP("0.0.0.0")))
|
||||
assert.True(t, ok)
|
||||
assert.Equal(t, "cool", r)
|
||||
|
||||
ok, r = tree.Contains(iputil.Ip2VpnIp(net.ParseIP("255.255.255.255")))
|
||||
assert.True(t, ok)
|
||||
assert.Equal(t, "cool", r)
|
||||
}
|
||||
|
||||
func BenchmarkCIDRTree_Contains(b *testing.B) {
|
||||
@ -167,25 +168,3 @@ func BenchmarkCIDRTree_Contains(b *testing.B) {
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func BenchmarkCIDRTree_Match(b *testing.B) {
|
||||
tree := NewTree4[string]()
|
||||
tree.AddCIDR(Parse("1.1.0.0/16"), "1")
|
||||
tree.AddCIDR(Parse("1.2.1.1/32"), "1")
|
||||
tree.AddCIDR(Parse("192.2.1.1/32"), "1")
|
||||
tree.AddCIDR(Parse("172.2.1.1/32"), "1")
|
||||
|
||||
ip := iputil.Ip2VpnIp(net.ParseIP("1.2.1.1"))
|
||||
b.Run("found", func(b *testing.B) {
|
||||
for i := 0; i < b.N; i++ {
|
||||
tree.Match(ip)
|
||||
}
|
||||
})
|
||||
|
||||
ip = iputil.Ip2VpnIp(net.ParseIP("1.2.1.255"))
|
||||
b.Run("not found", func(b *testing.B) {
|
||||
for i := 0; i < b.N; i++ {
|
||||
tree.Match(ip)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
@ -456,7 +456,7 @@ func (n *connectionManager) sendPunch(hostinfo *HostInfo) {
|
||||
}
|
||||
|
||||
if n.punchy.GetTargetEverything() {
|
||||
hostinfo.remotes.ForEach(n.hostMap.preferredRanges, func(addr *udp.Addr, preferred bool) {
|
||||
hostinfo.remotes.ForEach(n.hostMap.GetPreferredRanges(), func(addr *udp.Addr, preferred bool) {
|
||||
n.metricsTxPunchy.Inc(1)
|
||||
n.intf.outside.WriteTo([]byte{1}, addr)
|
||||
})
|
||||
|
||||
@ -43,7 +43,9 @@ func Test_NewConnectionManagerTest(t *testing.T) {
|
||||
preferredRanges := []*net.IPNet{localrange}
|
||||
|
||||
// Very incomplete mock objects
|
||||
hostMap := NewHostMap(l, vpncidr, preferredRanges)
|
||||
hostMap := newHostMap(l, vpncidr)
|
||||
hostMap.preferredRanges.Store(&preferredRanges)
|
||||
|
||||
cs := &CertState{
|
||||
RawCertificate: []byte{},
|
||||
PrivateKey: []byte{},
|
||||
@ -123,7 +125,9 @@ func Test_NewConnectionManagerTest2(t *testing.T) {
|
||||
preferredRanges := []*net.IPNet{localrange}
|
||||
|
||||
// Very incomplete mock objects
|
||||
hostMap := NewHostMap(l, vpncidr, preferredRanges)
|
||||
hostMap := newHostMap(l, vpncidr)
|
||||
hostMap.preferredRanges.Store(&preferredRanges)
|
||||
|
||||
cs := &CertState{
|
||||
RawCertificate: []byte{},
|
||||
PrivateKey: []byte{},
|
||||
@ -210,7 +214,8 @@ func Test_NewConnectionManagerTest_DisconnectInvalid(t *testing.T) {
|
||||
_, vpncidr, _ := net.ParseCIDR("172.1.1.1/24")
|
||||
_, localrange, _ := net.ParseCIDR("10.1.1.1/24")
|
||||
preferredRanges := []*net.IPNet{localrange}
|
||||
hostMap := NewHostMap(l, vpncidr, preferredRanges)
|
||||
hostMap := newHostMap(l, vpncidr)
|
||||
hostMap.preferredRanges.Store(&preferredRanges)
|
||||
|
||||
// Generate keys for CA and peer's cert.
|
||||
pubCA, privCA, _ := ed25519.GenerateKey(rand.Reader)
|
||||
|
||||
@ -145,7 +145,7 @@ func (c *Control) GetHostInfoByVpnIp(vpnIp iputil.VpnIp, pending bool) *ControlH
|
||||
return nil
|
||||
}
|
||||
|
||||
ch := copyHostInfo(h, c.f.hostMap.preferredRanges)
|
||||
ch := copyHostInfo(h, c.f.hostMap.GetPreferredRanges())
|
||||
return &ch
|
||||
}
|
||||
|
||||
@ -157,7 +157,7 @@ func (c *Control) SetRemoteForTunnel(vpnIp iputil.VpnIp, addr udp.Addr) *Control
|
||||
}
|
||||
|
||||
hostInfo.SetRemote(addr.Copy())
|
||||
ch := copyHostInfo(hostInfo, c.f.hostMap.preferredRanges)
|
||||
ch := copyHostInfo(hostInfo, c.f.hostMap.GetPreferredRanges())
|
||||
return &ch
|
||||
}
|
||||
|
||||
|
||||
@ -18,7 +18,9 @@ func TestControl_GetHostInfoByVpnIp(t *testing.T) {
|
||||
l := test.NewLogger()
|
||||
// Special care must be taken to re-use all objects provided to the hostmap and certificate in the expectedInfo object
|
||||
// To properly ensure we are not exposing core memory to the caller
|
||||
hm := NewHostMap(l, &net.IPNet{}, make([]*net.IPNet, 0))
|
||||
hm := newHostMap(l, &net.IPNet{})
|
||||
hm.preferredRanges.Store(&[]*net.IPNet{})
|
||||
|
||||
remote1 := udp.NewAddr(net.ParseIP("0.0.0.100"), 4444)
|
||||
remote2 := udp.NewAddr(net.ParseIP("1:2:3:4:5:6:7:8"), 4444)
|
||||
ipNet := net.IPNet{
|
||||
|
||||
@ -309,6 +309,13 @@ firewall:
|
||||
outbound_action: drop
|
||||
inbound_action: drop
|
||||
|
||||
# Controls the default value for local_cidr. Default is true, will be deprecated after v1.9 and defaulted to false.
|
||||
# This setting only affects nebula hosts with subnets encoded in their certificate. A nebula host acting as an
|
||||
# unsafe router with `default_local_cidr_any: true` will expose their unsafe routes to every inbound rule regardless
|
||||
# of the actual destination for the packet. Setting this to false requires each inbound rule to contain a `local_cidr`
|
||||
# if the intention is to allow traffic to flow to an unsafe route.
|
||||
#default_local_cidr_any: false
|
||||
|
||||
conntrack:
|
||||
tcp_timeout: 12m
|
||||
udp_timeout: 3m
|
||||
@ -316,7 +323,7 @@ firewall:
|
||||
|
||||
# The firewall is default deny. There is no way to write a deny rule.
|
||||
# Rules are comprised of a protocol, port, and one or more of host, group, or CIDR
|
||||
# Logical evaluation is roughly: port AND proto AND (ca_sha OR ca_name) AND (host OR group OR groups OR cidr)
|
||||
# Logical evaluation is roughly: port AND proto AND (ca_sha OR ca_name) AND (host OR group OR groups OR cidr) AND (local cidr)
|
||||
# - port: Takes `0` or `any` as any, a single number `80`, a range `200-901`, or `fragment` to match second and further fragments of fragmented packets (since there is no port available).
|
||||
# code: same as port but makes more sense when talking about ICMP, TODO: this is not currently implemented in a way that works, use `any`
|
||||
# proto: `any`, `tcp`, `udp`, or `icmp`
|
||||
@ -325,6 +332,8 @@ firewall:
|
||||
# groups: Same as group but accepts a list of values. Multiple values are AND'd together and a certificate would have to contain all groups to pass
|
||||
# cidr: a remote CIDR, `0.0.0.0/0` is any.
|
||||
# local_cidr: a local CIDR, `0.0.0.0/0` is any. This could be used to filter destinations when using unsafe_routes.
|
||||
# Default is `any` unless the certificate contains subnets and then the default is the ip issued in the certificate
|
||||
# if `default_local_cidr_any` is false, otherwise its `any`.
|
||||
# ca_name: An issuing CA name
|
||||
# ca_sha: An issuing CA shasum
|
||||
|
||||
@ -346,3 +355,10 @@ firewall:
|
||||
groups:
|
||||
- laptop
|
||||
- home
|
||||
|
||||
# Expose a subnet (unsafe route) to hosts with the group remote_client
|
||||
# This example assume you have a subnet of 192.168.100.1/24 or larger encoded in the certificate
|
||||
- port: 8080
|
||||
proto: tcp
|
||||
group: remote_client
|
||||
local_cidr: 192.168.100.1/24
|
||||
|
||||
191
firewall.go
191
firewall.go
@ -57,15 +57,18 @@ type Firewall struct {
|
||||
DefaultTimeout time.Duration //linux: 600s
|
||||
|
||||
// Used to ensure we don't emit local packets for ips we don't own
|
||||
localIps *cidr.Tree4[struct{}]
|
||||
localIps *cidr.Tree4[struct{}]
|
||||
assignedCIDR *net.IPNet
|
||||
hasSubnets bool
|
||||
|
||||
rules string
|
||||
rulesVersion uint16
|
||||
|
||||
trackTCPRTT bool
|
||||
metricTCPRTT metrics.Histogram
|
||||
incomingMetrics firewallMetrics
|
||||
outgoingMetrics firewallMetrics
|
||||
defaultLocalCIDRAny bool
|
||||
trackTCPRTT bool
|
||||
metricTCPRTT metrics.Histogram
|
||||
incomingMetrics firewallMetrics
|
||||
outgoingMetrics firewallMetrics
|
||||
|
||||
l *logrus.Logger
|
||||
}
|
||||
@ -83,6 +86,8 @@ type FirewallConntrack struct {
|
||||
TimerWheel *TimerWheel[firewall.Packet]
|
||||
}
|
||||
|
||||
// FirewallTable is the entry point for a rule, the evaluation order is:
|
||||
// Proto AND port AND (CA SHA or CA name) AND local CIDR AND (group OR groups OR name OR remote CIDR)
|
||||
type FirewallTable struct {
|
||||
TCP firewallPort
|
||||
UDP firewallPort
|
||||
@ -106,18 +111,27 @@ type FirewallCA struct {
|
||||
}
|
||||
|
||||
type FirewallRule struct {
|
||||
// Any makes Hosts, Groups, CIDR and LocalCIDR irrelevant
|
||||
Any bool
|
||||
Hosts map[string]struct{}
|
||||
Groups [][]string
|
||||
CIDR *cidr.Tree4[struct{}]
|
||||
LocalCIDR *cidr.Tree4[struct{}]
|
||||
// Any makes Hosts, Groups, and CIDR irrelevant
|
||||
Any *firewallLocalCIDR
|
||||
Hosts map[string]*firewallLocalCIDR
|
||||
Groups []*firewallGroups
|
||||
CIDR *cidr.Tree4[*firewallLocalCIDR]
|
||||
}
|
||||
|
||||
type firewallGroups struct {
|
||||
Groups []string
|
||||
LocalCIDR *firewallLocalCIDR
|
||||
}
|
||||
|
||||
// Even though ports are uint16, int32 maps are faster for lookup
|
||||
// Plus we can use `-1` for fragment rules
|
||||
type firewallPort map[int32]*FirewallCA
|
||||
|
||||
type firewallLocalCIDR struct {
|
||||
Any bool
|
||||
LocalCIDR *cidr.Tree4[struct{}]
|
||||
}
|
||||
|
||||
// NewFirewall creates a new Firewall object. A TimerWheel is created for you from the provided timeouts.
|
||||
func NewFirewall(l *logrus.Logger, tcpTimeout, UDPTimeout, defaultTimeout time.Duration, c *cert.NebulaCertificate) *Firewall {
|
||||
//TODO: error on 0 duration
|
||||
@ -138,8 +152,15 @@ func NewFirewall(l *logrus.Logger, tcpTimeout, UDPTimeout, defaultTimeout time.D
|
||||
}
|
||||
|
||||
localIps := cidr.NewTree4[struct{}]()
|
||||
var assignedCIDR *net.IPNet
|
||||
for _, ip := range c.Details.Ips {
|
||||
localIps.AddCIDR(&net.IPNet{IP: ip.IP, Mask: net.IPMask{255, 255, 255, 255}}, struct{}{})
|
||||
ipNet := &net.IPNet{IP: ip.IP, Mask: net.IPMask{255, 255, 255, 255}}
|
||||
localIps.AddCIDR(ipNet, struct{}{})
|
||||
|
||||
if assignedCIDR == nil {
|
||||
// Only grabbing the first one in the cert since any more than that currently has undefined behavior
|
||||
assignedCIDR = ipNet
|
||||
}
|
||||
}
|
||||
|
||||
for _, n := range c.Details.Subnets {
|
||||
@ -158,6 +179,8 @@ func NewFirewall(l *logrus.Logger, tcpTimeout, UDPTimeout, defaultTimeout time.D
|
||||
UDPTimeout: UDPTimeout,
|
||||
DefaultTimeout: defaultTimeout,
|
||||
localIps: localIps,
|
||||
assignedCIDR: assignedCIDR,
|
||||
hasSubnets: len(c.Details.Subnets) > 0,
|
||||
l: l,
|
||||
|
||||
metricTCPRTT: metrics.GetOrRegisterHistogram("network.tcp.rtt", nil, metrics.NewExpDecaySample(1028, 0.015)),
|
||||
@ -184,6 +207,9 @@ func NewFirewallFromConfig(l *logrus.Logger, nc *cert.NebulaCertificate, c *conf
|
||||
//TODO: max_connections
|
||||
)
|
||||
|
||||
//TODO: Flip to false after v1.9 release
|
||||
fw.defaultLocalCIDRAny = c.GetBool("firewall.default_local_cidr_any", true)
|
||||
|
||||
inboundAction := c.GetString("firewall.inbound_action", "drop")
|
||||
switch inboundAction {
|
||||
case "reject":
|
||||
@ -270,7 +296,7 @@ func (f *Firewall) AddRule(incoming bool, proto uint8, startPort int32, endPort
|
||||
return fmt.Errorf("unknown protocol %v", proto)
|
||||
}
|
||||
|
||||
return fp.addRule(startPort, endPort, groups, host, ip, localIp, caName, caSha)
|
||||
return fp.addRule(f, startPort, endPort, groups, host, ip, localIp, caName, caSha)
|
||||
}
|
||||
|
||||
// GetRuleHash returns a hash representation of all inbound and outbound rules
|
||||
@ -624,7 +650,7 @@ func (ft *FirewallTable) match(p firewall.Packet, incoming bool, c *cert.NebulaC
|
||||
return false
|
||||
}
|
||||
|
||||
func (fp firewallPort) addRule(startPort int32, endPort int32, groups []string, host string, ip *net.IPNet, localIp *net.IPNet, caName string, caSha string) error {
|
||||
func (fp firewallPort) addRule(f *Firewall, startPort int32, endPort int32, groups []string, host string, ip *net.IPNet, localIp *net.IPNet, caName string, caSha string) error {
|
||||
if startPort > endPort {
|
||||
return fmt.Errorf("start port was lower than end port")
|
||||
}
|
||||
@ -637,7 +663,7 @@ func (fp firewallPort) addRule(startPort int32, endPort int32, groups []string,
|
||||
}
|
||||
}
|
||||
|
||||
if err := fp[i].addRule(groups, host, ip, localIp, caName, caSha); err != nil {
|
||||
if err := fp[i].addRule(f, groups, host, ip, localIp, caName, caSha); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
@ -668,13 +694,12 @@ func (fp firewallPort) match(p firewall.Packet, incoming bool, c *cert.NebulaCer
|
||||
return fp[firewall.PortAny].match(p, c, caPool)
|
||||
}
|
||||
|
||||
func (fc *FirewallCA) addRule(groups []string, host string, ip, localIp *net.IPNet, caName, caSha string) error {
|
||||
func (fc *FirewallCA) addRule(f *Firewall, groups []string, host string, ip, localIp *net.IPNet, caName, caSha string) error {
|
||||
fr := func() *FirewallRule {
|
||||
return &FirewallRule{
|
||||
Hosts: make(map[string]struct{}),
|
||||
Groups: make([][]string, 0),
|
||||
CIDR: cidr.NewTree4[struct{}](),
|
||||
LocalCIDR: cidr.NewTree4[struct{}](),
|
||||
Hosts: make(map[string]*firewallLocalCIDR),
|
||||
Groups: make([]*firewallGroups, 0),
|
||||
CIDR: cidr.NewTree4[*firewallLocalCIDR](),
|
||||
}
|
||||
}
|
||||
|
||||
@ -683,14 +708,14 @@ func (fc *FirewallCA) addRule(groups []string, host string, ip, localIp *net.IPN
|
||||
fc.Any = fr()
|
||||
}
|
||||
|
||||
return fc.Any.addRule(groups, host, ip, localIp)
|
||||
return fc.Any.addRule(f, groups, host, ip, localIp)
|
||||
}
|
||||
|
||||
if caSha != "" {
|
||||
if _, ok := fc.CAShas[caSha]; !ok {
|
||||
fc.CAShas[caSha] = fr()
|
||||
}
|
||||
err := fc.CAShas[caSha].addRule(groups, host, ip, localIp)
|
||||
err := fc.CAShas[caSha].addRule(f, groups, host, ip, localIp)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@ -700,7 +725,7 @@ func (fc *FirewallCA) addRule(groups []string, host string, ip, localIp *net.IPN
|
||||
if _, ok := fc.CANames[caName]; !ok {
|
||||
fc.CANames[caName] = fr()
|
||||
}
|
||||
err := fc.CANames[caName].addRule(groups, host, ip, localIp)
|
||||
err := fc.CANames[caName].addRule(f, groups, host, ip, localIp)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@ -732,41 +757,63 @@ func (fc *FirewallCA) match(p firewall.Packet, c *cert.NebulaCertificate, caPool
|
||||
return fc.CANames[s.Details.Name].match(p, c)
|
||||
}
|
||||
|
||||
func (fr *FirewallRule) addRule(groups []string, host string, ip *net.IPNet, localIp *net.IPNet) error {
|
||||
if fr.Any {
|
||||
return nil
|
||||
func (fr *FirewallRule) addRule(f *Firewall, groups []string, host string, ip *net.IPNet, localCIDR *net.IPNet) error {
|
||||
flc := func() *firewallLocalCIDR {
|
||||
return &firewallLocalCIDR{
|
||||
LocalCIDR: cidr.NewTree4[struct{}](),
|
||||
}
|
||||
}
|
||||
|
||||
if fr.isAny(groups, host, ip, localIp) {
|
||||
fr.Any = true
|
||||
// If it's any we need to wipe out any pre-existing rules to save on memory
|
||||
fr.Groups = make([][]string, 0)
|
||||
fr.Hosts = make(map[string]struct{})
|
||||
fr.CIDR = cidr.NewTree4[struct{}]()
|
||||
fr.LocalCIDR = cidr.NewTree4[struct{}]()
|
||||
} else {
|
||||
if len(groups) > 0 {
|
||||
fr.Groups = append(fr.Groups, groups)
|
||||
if fr.isAny(groups, host, ip) {
|
||||
if fr.Any == nil {
|
||||
fr.Any = flc()
|
||||
}
|
||||
|
||||
if host != "" {
|
||||
fr.Hosts[host] = struct{}{}
|
||||
return fr.Any.addRule(f, localCIDR)
|
||||
}
|
||||
|
||||
if len(groups) > 0 {
|
||||
nlc := flc()
|
||||
err := nlc.addRule(f, localCIDR)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if ip != nil {
|
||||
fr.CIDR.AddCIDR(ip, struct{}{})
|
||||
}
|
||||
fr.Groups = append(fr.Groups, &firewallGroups{
|
||||
Groups: groups,
|
||||
LocalCIDR: nlc,
|
||||
})
|
||||
}
|
||||
|
||||
if localIp != nil {
|
||||
fr.LocalCIDR.AddCIDR(localIp, struct{}{})
|
||||
if host != "" {
|
||||
nlc := fr.Hosts[host]
|
||||
if nlc == nil {
|
||||
nlc = flc()
|
||||
}
|
||||
err := nlc.addRule(f, localCIDR)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
fr.Hosts[host] = nlc
|
||||
}
|
||||
|
||||
if ip != nil {
|
||||
_, nlc := fr.CIDR.GetCIDR(ip)
|
||||
if nlc == nil {
|
||||
nlc = flc()
|
||||
}
|
||||
err := nlc.addRule(f, localCIDR)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
fr.CIDR.AddCIDR(ip, nlc)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (fr *FirewallRule) isAny(groups []string, host string, ip, localIp *net.IPNet) bool {
|
||||
if len(groups) == 0 && host == "" && ip == nil && localIp == nil {
|
||||
func (fr *FirewallRule) isAny(groups []string, host string, ip *net.IPNet) bool {
|
||||
if len(groups) == 0 && host == "" && ip == nil {
|
||||
return true
|
||||
}
|
||||
|
||||
@ -784,10 +831,6 @@ func (fr *FirewallRule) isAny(groups []string, host string, ip, localIp *net.IPN
|
||||
return true
|
||||
}
|
||||
|
||||
if localIp != nil && localIp.Contains(net.IPv4(0, 0, 0, 0)) {
|
||||
return true
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
@ -797,7 +840,7 @@ func (fr *FirewallRule) match(p firewall.Packet, c *cert.NebulaCertificate) bool
|
||||
}
|
||||
|
||||
// Shortcut path for if groups, hosts, or cidr contained an `any`
|
||||
if fr.Any {
|
||||
if fr.Any.match(p, c) {
|
||||
return true
|
||||
}
|
||||
|
||||
@ -805,7 +848,7 @@ func (fr *FirewallRule) match(p firewall.Packet, c *cert.NebulaCertificate) bool
|
||||
for _, sg := range fr.Groups {
|
||||
found := false
|
||||
|
||||
for _, g := range sg {
|
||||
for _, g := range sg.Groups {
|
||||
if _, ok := c.Details.InvertedGroups[g]; !ok {
|
||||
found = false
|
||||
break
|
||||
@ -814,33 +857,51 @@ func (fr *FirewallRule) match(p firewall.Packet, c *cert.NebulaCertificate) bool
|
||||
found = true
|
||||
}
|
||||
|
||||
if found {
|
||||
if found && sg.LocalCIDR.match(p, c) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
if fr.Hosts != nil {
|
||||
if _, ok := fr.Hosts[c.Details.Name]; ok {
|
||||
return true
|
||||
if flc, ok := fr.Hosts[c.Details.Name]; ok {
|
||||
if flc.match(p, c) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if fr.CIDR != nil {
|
||||
ok, _ := fr.CIDR.Contains(p.RemoteIP)
|
||||
if ok {
|
||||
return true
|
||||
return fr.CIDR.EachContains(p.RemoteIP, func(flc *firewallLocalCIDR) bool {
|
||||
return flc.match(p, c)
|
||||
})
|
||||
}
|
||||
|
||||
func (flc *firewallLocalCIDR) addRule(f *Firewall, localIp *net.IPNet) error {
|
||||
if localIp == nil {
|
||||
if !f.hasSubnets || f.defaultLocalCIDRAny {
|
||||
flc.Any = true
|
||||
return nil
|
||||
}
|
||||
|
||||
localIp = f.assignedCIDR
|
||||
} else if localIp.Contains(net.IPv4(0, 0, 0, 0)) {
|
||||
flc.Any = true
|
||||
}
|
||||
|
||||
if fr.LocalCIDR != nil {
|
||||
ok, _ := fr.LocalCIDR.Contains(p.LocalIP)
|
||||
if ok {
|
||||
return true
|
||||
}
|
||||
flc.LocalCIDR.AddCIDR(localIp, struct{}{})
|
||||
return nil
|
||||
}
|
||||
|
||||
func (flc *firewallLocalCIDR) match(p firewall.Packet, c *cert.NebulaCertificate) bool {
|
||||
if flc == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
// No host, group, or cidr matched, bye bye
|
||||
return false
|
||||
if flc.Any {
|
||||
return true
|
||||
}
|
||||
|
||||
ok, _ := flc.LocalCIDR.Contains(p.LocalIP)
|
||||
return ok
|
||||
}
|
||||
|
||||
type rule struct {
|
||||
|
||||
210
firewall_test.go
210
firewall_test.go
@ -71,36 +71,32 @@ func TestFirewall_AddRule(t *testing.T) {
|
||||
|
||||
assert.Nil(t, fw.AddRule(true, firewall.ProtoTCP, 1, 1, []string{}, "", nil, nil, "", ""))
|
||||
// An empty rule is any
|
||||
assert.True(t, fw.InRules.TCP[1].Any.Any)
|
||||
assert.True(t, fw.InRules.TCP[1].Any.Any.Any)
|
||||
assert.Empty(t, fw.InRules.TCP[1].Any.Groups)
|
||||
assert.Empty(t, fw.InRules.TCP[1].Any.Hosts)
|
||||
|
||||
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
|
||||
assert.Nil(t, fw.AddRule(true, firewall.ProtoUDP, 1, 1, []string{"g1"}, "", nil, nil, "", ""))
|
||||
assert.False(t, fw.InRules.UDP[1].Any.Any)
|
||||
assert.Contains(t, fw.InRules.UDP[1].Any.Groups[0], "g1")
|
||||
assert.Nil(t, fw.InRules.UDP[1].Any.Any)
|
||||
assert.Contains(t, fw.InRules.UDP[1].Any.Groups[0].Groups, "g1")
|
||||
assert.Empty(t, fw.InRules.UDP[1].Any.Hosts)
|
||||
|
||||
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
|
||||
assert.Nil(t, fw.AddRule(true, firewall.ProtoICMP, 1, 1, []string{}, "h1", nil, nil, "", ""))
|
||||
assert.False(t, fw.InRules.ICMP[1].Any.Any)
|
||||
assert.Nil(t, fw.InRules.ICMP[1].Any.Any)
|
||||
assert.Empty(t, fw.InRules.ICMP[1].Any.Groups)
|
||||
assert.Contains(t, fw.InRules.ICMP[1].Any.Hosts, "h1")
|
||||
|
||||
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
|
||||
assert.Nil(t, fw.AddRule(false, firewall.ProtoAny, 1, 1, []string{}, "", ti, nil, "", ""))
|
||||
assert.False(t, fw.OutRules.AnyProto[1].Any.Any)
|
||||
assert.Empty(t, fw.OutRules.AnyProto[1].Any.Groups)
|
||||
assert.Empty(t, fw.OutRules.AnyProto[1].Any.Hosts)
|
||||
ok, _ := fw.OutRules.AnyProto[1].Any.CIDR.Match(iputil.Ip2VpnIp(ti.IP))
|
||||
assert.Nil(t, fw.OutRules.AnyProto[1].Any.Any)
|
||||
ok, _ := fw.OutRules.AnyProto[1].Any.CIDR.GetCIDR(ti)
|
||||
assert.True(t, ok)
|
||||
|
||||
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
|
||||
assert.Nil(t, fw.AddRule(false, firewall.ProtoAny, 1, 1, []string{}, "", nil, ti, "", ""))
|
||||
assert.False(t, fw.OutRules.AnyProto[1].Any.Any)
|
||||
assert.Empty(t, fw.OutRules.AnyProto[1].Any.Groups)
|
||||
assert.Empty(t, fw.OutRules.AnyProto[1].Any.Hosts)
|
||||
ok, _ = fw.OutRules.AnyProto[1].Any.LocalCIDR.Match(iputil.Ip2VpnIp(ti.IP))
|
||||
assert.NotNil(t, fw.OutRules.AnyProto[1].Any.Any)
|
||||
ok, _ = fw.OutRules.AnyProto[1].Any.Any.LocalCIDR.GetCIDR(ti)
|
||||
assert.True(t, ok)
|
||||
|
||||
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
|
||||
@ -111,32 +107,14 @@ func TestFirewall_AddRule(t *testing.T) {
|
||||
assert.Nil(t, fw.AddRule(true, firewall.ProtoUDP, 1, 1, []string{"g1"}, "", nil, nil, "", "ca-sha"))
|
||||
assert.Contains(t, fw.InRules.UDP[1].CAShas, "ca-sha")
|
||||
|
||||
// Set any and clear fields
|
||||
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
|
||||
assert.Nil(t, fw.AddRule(false, firewall.ProtoAny, 0, 0, []string{"g1", "g2"}, "h1", ti, ti, "", ""))
|
||||
assert.Equal(t, []string{"g1", "g2"}, fw.OutRules.AnyProto[0].Any.Groups[0])
|
||||
assert.Contains(t, fw.OutRules.AnyProto[0].Any.Hosts, "h1")
|
||||
ok, _ = fw.OutRules.AnyProto[0].Any.CIDR.Match(iputil.Ip2VpnIp(ti.IP))
|
||||
assert.True(t, ok)
|
||||
ok, _ = fw.OutRules.AnyProto[0].Any.LocalCIDR.Match(iputil.Ip2VpnIp(ti.IP))
|
||||
assert.True(t, ok)
|
||||
|
||||
// run twice just to make sure
|
||||
//TODO: these ANY rules should clear the CA firewall portion
|
||||
assert.Nil(t, fw.AddRule(false, firewall.ProtoAny, 0, 0, []string{"any"}, "", nil, nil, "", ""))
|
||||
assert.Nil(t, fw.AddRule(false, firewall.ProtoAny, 0, 0, []string{}, "any", nil, nil, "", ""))
|
||||
assert.True(t, fw.OutRules.AnyProto[0].Any.Any)
|
||||
assert.Empty(t, fw.OutRules.AnyProto[0].Any.Groups)
|
||||
assert.Empty(t, fw.OutRules.AnyProto[0].Any.Hosts)
|
||||
|
||||
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
|
||||
assert.Nil(t, fw.AddRule(false, firewall.ProtoAny, 0, 0, []string{}, "any", nil, nil, "", ""))
|
||||
assert.True(t, fw.OutRules.AnyProto[0].Any.Any)
|
||||
assert.True(t, fw.OutRules.AnyProto[0].Any.Any.Any)
|
||||
|
||||
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
|
||||
_, anyIp, _ := net.ParseCIDR("0.0.0.0/0")
|
||||
assert.Nil(t, fw.AddRule(false, firewall.ProtoAny, 0, 0, []string{}, "", anyIp, nil, "", ""))
|
||||
assert.True(t, fw.OutRules.AnyProto[0].Any.Any)
|
||||
assert.True(t, fw.OutRules.AnyProto[0].Any.Any.Any)
|
||||
|
||||
// Test error conditions
|
||||
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
|
||||
@ -226,33 +204,43 @@ func TestFirewall_Drop(t *testing.T) {
|
||||
}
|
||||
|
||||
func BenchmarkFirewallTable_match(b *testing.B) {
|
||||
f := &Firewall{}
|
||||
ft := FirewallTable{
|
||||
TCP: firewallPort{},
|
||||
}
|
||||
|
||||
_, n, _ := net.ParseCIDR("172.1.1.1/32")
|
||||
_ = ft.TCP.addRule(10, 10, []string{"good-group"}, "good-host", n, n, "", "")
|
||||
_ = ft.TCP.addRule(10, 10, []string{"good-group2"}, "good-host", n, n, "", "")
|
||||
_ = ft.TCP.addRule(10, 10, []string{"good-group3"}, "good-host", n, n, "", "")
|
||||
_ = ft.TCP.addRule(10, 10, []string{"good-group4"}, "good-host", n, n, "", "")
|
||||
_ = ft.TCP.addRule(10, 10, []string{"good-group, good-group1"}, "good-host", n, n, "", "")
|
||||
goodLocalCIDRIP := iputil.Ip2VpnIp(n.IP)
|
||||
_ = ft.TCP.addRule(f, 10, 10, []string{"good-group"}, "good-host", n, nil, "", "")
|
||||
_ = ft.TCP.addRule(f, 100, 100, []string{"good-group"}, "good-host", nil, n, "", "")
|
||||
cp := cert.NewCAPool()
|
||||
|
||||
b.Run("fail on proto", func(b *testing.B) {
|
||||
// This benchmark is showing us the cost of failing to match the protocol
|
||||
c := &cert.NebulaCertificate{}
|
||||
for n := 0; n < b.N; n++ {
|
||||
ft.match(firewall.Packet{Protocol: firewall.ProtoUDP}, true, c, cp)
|
||||
assert.False(b, ft.match(firewall.Packet{Protocol: firewall.ProtoUDP}, true, c, cp))
|
||||
}
|
||||
})
|
||||
|
||||
b.Run("fail on port", func(b *testing.B) {
|
||||
b.Run("pass proto, fail on port", func(b *testing.B) {
|
||||
// This benchmark is showing us the cost of matching a specific protocol but failing to match the port
|
||||
c := &cert.NebulaCertificate{}
|
||||
for n := 0; n < b.N; n++ {
|
||||
ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 1}, true, c, cp)
|
||||
assert.False(b, ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 1}, true, c, cp))
|
||||
}
|
||||
})
|
||||
|
||||
b.Run("fail all group, name, and cidr", func(b *testing.B) {
|
||||
b.Run("pass proto, port, fail on local CIDR", func(b *testing.B) {
|
||||
c := &cert.NebulaCertificate{}
|
||||
ip, _, _ := net.ParseCIDR("9.254.254.254/32")
|
||||
lip := iputil.Ip2VpnIp(ip)
|
||||
for n := 0; n < b.N; n++ {
|
||||
assert.False(b, ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 100, LocalIP: lip}, true, c, cp))
|
||||
}
|
||||
})
|
||||
|
||||
b.Run("pass proto, port, any local CIDR, fail all group, name, and cidr", func(b *testing.B) {
|
||||
_, ip, _ := net.ParseCIDR("9.254.254.254/32")
|
||||
c := &cert.NebulaCertificate{
|
||||
Details: cert.NebulaCertificateDetails{
|
||||
@ -262,11 +250,25 @@ func BenchmarkFirewallTable_match(b *testing.B) {
|
||||
},
|
||||
}
|
||||
for n := 0; n < b.N; n++ {
|
||||
ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 10}, true, c, cp)
|
||||
assert.False(b, ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 10}, true, c, cp))
|
||||
}
|
||||
})
|
||||
|
||||
b.Run("pass on group", func(b *testing.B) {
|
||||
b.Run("pass proto, port, specific local CIDR, fail all group, name, and cidr", func(b *testing.B) {
|
||||
_, ip, _ := net.ParseCIDR("9.254.254.254/32")
|
||||
c := &cert.NebulaCertificate{
|
||||
Details: cert.NebulaCertificateDetails{
|
||||
InvertedGroups: map[string]struct{}{"nope": {}},
|
||||
Name: "nope",
|
||||
Ips: []*net.IPNet{ip},
|
||||
},
|
||||
}
|
||||
for n := 0; n < b.N; n++ {
|
||||
assert.False(b, ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 100, LocalIP: goodLocalCIDRIP}, true, c, cp))
|
||||
}
|
||||
})
|
||||
|
||||
b.Run("pass on group on any local cidr", func(b *testing.B) {
|
||||
c := &cert.NebulaCertificate{
|
||||
Details: cert.NebulaCertificateDetails{
|
||||
InvertedGroups: map[string]struct{}{"good-group": {}},
|
||||
@ -274,7 +276,19 @@ func BenchmarkFirewallTable_match(b *testing.B) {
|
||||
},
|
||||
}
|
||||
for n := 0; n < b.N; n++ {
|
||||
ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 10}, true, c, cp)
|
||||
assert.True(b, ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 10}, true, c, cp))
|
||||
}
|
||||
})
|
||||
|
||||
b.Run("pass on group on specific local cidr", func(b *testing.B) {
|
||||
c := &cert.NebulaCertificate{
|
||||
Details: cert.NebulaCertificateDetails{
|
||||
InvertedGroups: map[string]struct{}{"good-group": {}},
|
||||
Name: "nope",
|
||||
},
|
||||
}
|
||||
for n := 0; n < b.N; n++ {
|
||||
assert.True(b, ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 100, LocalIP: goodLocalCIDRIP}, true, c, cp))
|
||||
}
|
||||
})
|
||||
|
||||
@ -289,60 +303,60 @@ func BenchmarkFirewallTable_match(b *testing.B) {
|
||||
ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 10}, true, c, cp)
|
||||
}
|
||||
})
|
||||
|
||||
b.Run("pass on ip", func(b *testing.B) {
|
||||
ip := iputil.Ip2VpnIp(net.IPv4(172, 1, 1, 1))
|
||||
c := &cert.NebulaCertificate{
|
||||
Details: cert.NebulaCertificateDetails{
|
||||
InvertedGroups: map[string]struct{}{"nope": {}},
|
||||
Name: "good-host",
|
||||
},
|
||||
}
|
||||
for n := 0; n < b.N; n++ {
|
||||
ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 10, RemoteIP: ip}, true, c, cp)
|
||||
}
|
||||
})
|
||||
|
||||
b.Run("pass on local ip", func(b *testing.B) {
|
||||
ip := iputil.Ip2VpnIp(net.IPv4(172, 1, 1, 1))
|
||||
c := &cert.NebulaCertificate{
|
||||
Details: cert.NebulaCertificateDetails{
|
||||
InvertedGroups: map[string]struct{}{"nope": {}},
|
||||
Name: "good-host",
|
||||
},
|
||||
}
|
||||
for n := 0; n < b.N; n++ {
|
||||
ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 10, LocalIP: ip}, true, c, cp)
|
||||
}
|
||||
})
|
||||
|
||||
_ = ft.TCP.addRule(0, 0, []string{"good-group"}, "good-host", n, n, "", "")
|
||||
|
||||
b.Run("pass on ip with any port", func(b *testing.B) {
|
||||
ip := iputil.Ip2VpnIp(net.IPv4(172, 1, 1, 1))
|
||||
c := &cert.NebulaCertificate{
|
||||
Details: cert.NebulaCertificateDetails{
|
||||
InvertedGroups: map[string]struct{}{"nope": {}},
|
||||
Name: "good-host",
|
||||
},
|
||||
}
|
||||
for n := 0; n < b.N; n++ {
|
||||
ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 100, RemoteIP: ip}, true, c, cp)
|
||||
}
|
||||
})
|
||||
|
||||
b.Run("pass on local ip with any port", func(b *testing.B) {
|
||||
ip := iputil.Ip2VpnIp(net.IPv4(172, 1, 1, 1))
|
||||
c := &cert.NebulaCertificate{
|
||||
Details: cert.NebulaCertificateDetails{
|
||||
InvertedGroups: map[string]struct{}{"nope": {}},
|
||||
Name: "good-host",
|
||||
},
|
||||
}
|
||||
for n := 0; n < b.N; n++ {
|
||||
ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 100, LocalIP: ip}, true, c, cp)
|
||||
}
|
||||
})
|
||||
//
|
||||
//b.Run("pass on ip", func(b *testing.B) {
|
||||
// ip := iputil.Ip2VpnIp(net.IPv4(172, 1, 1, 1))
|
||||
// c := &cert.NebulaCertificate{
|
||||
// Details: cert.NebulaCertificateDetails{
|
||||
// InvertedGroups: map[string]struct{}{"nope": {}},
|
||||
// Name: "good-host",
|
||||
// },
|
||||
// }
|
||||
// for n := 0; n < b.N; n++ {
|
||||
// ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 10, RemoteIP: ip}, true, c, cp)
|
||||
// }
|
||||
//})
|
||||
//
|
||||
//b.Run("pass on local ip", func(b *testing.B) {
|
||||
// ip := iputil.Ip2VpnIp(net.IPv4(172, 1, 1, 1))
|
||||
// c := &cert.NebulaCertificate{
|
||||
// Details: cert.NebulaCertificateDetails{
|
||||
// InvertedGroups: map[string]struct{}{"nope": {}},
|
||||
// Name: "good-host",
|
||||
// },
|
||||
// }
|
||||
// for n := 0; n < b.N; n++ {
|
||||
// ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 10, LocalIP: ip}, true, c, cp)
|
||||
// }
|
||||
//})
|
||||
//
|
||||
//_ = ft.TCP.addRule(0, 0, []string{"good-group"}, "good-host", n, n, "", "")
|
||||
//
|
||||
//b.Run("pass on ip with any port", func(b *testing.B) {
|
||||
// ip := iputil.Ip2VpnIp(net.IPv4(172, 1, 1, 1))
|
||||
// c := &cert.NebulaCertificate{
|
||||
// Details: cert.NebulaCertificateDetails{
|
||||
// InvertedGroups: map[string]struct{}{"nope": {}},
|
||||
// Name: "good-host",
|
||||
// },
|
||||
// }
|
||||
// for n := 0; n < b.N; n++ {
|
||||
// ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 100, RemoteIP: ip}, true, c, cp)
|
||||
// }
|
||||
//})
|
||||
//
|
||||
//b.Run("pass on local ip with any port", func(b *testing.B) {
|
||||
// ip := iputil.Ip2VpnIp(net.IPv4(172, 1, 1, 1))
|
||||
// c := &cert.NebulaCertificate{
|
||||
// Details: cert.NebulaCertificateDetails{
|
||||
// InvertedGroups: map[string]struct{}{"nope": {}},
|
||||
// Name: "good-host",
|
||||
// },
|
||||
// }
|
||||
// for n := 0; n < b.N; n++ {
|
||||
// ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 100, LocalIP: ip}, true, c, cp)
|
||||
// }
|
||||
//})
|
||||
}
|
||||
|
||||
func TestFirewall_Drop2(t *testing.T) {
|
||||
|
||||
14
go.mod
14
go.mod
@ -8,7 +8,7 @@ require (
|
||||
github.com/armon/go-radix v1.0.0
|
||||
github.com/clarkmcc/go-dag v0.0.0-20220908000337-9c3ba5b365fc
|
||||
github.com/cyberdelia/go-metrics-graphite v0.0.0-20161219230853-39f87cc3b432
|
||||
github.com/flynn/noise v1.0.1
|
||||
github.com/flynn/noise v1.1.0
|
||||
github.com/gogo/protobuf v1.3.2
|
||||
github.com/google/gopacket v1.1.19
|
||||
github.com/kardianos/service v1.2.2
|
||||
@ -19,19 +19,19 @@ require (
|
||||
github.com/sirupsen/logrus v1.9.3
|
||||
github.com/skip2/go-qrcode v0.0.0-20200617195104-da1b6568686e
|
||||
github.com/songgao/water v0.0.0-20200317203138-2b4b6d7c09d8
|
||||
github.com/stretchr/testify v1.8.4
|
||||
github.com/stretchr/testify v1.9.0
|
||||
github.com/timandy/routine v1.1.1
|
||||
github.com/vishvananda/netlink v1.2.1-beta.2
|
||||
golang.org/x/crypto v0.18.0
|
||||
golang.org/x/crypto v0.21.0
|
||||
golang.org/x/exp v0.0.0-20230425010034-47ecfdc1ba53
|
||||
golang.org/x/net v0.20.0
|
||||
golang.org/x/net v0.22.0
|
||||
golang.org/x/sync v0.6.0
|
||||
golang.org/x/sys v0.16.0
|
||||
golang.org/x/term v0.16.0
|
||||
golang.org/x/sys v0.18.0
|
||||
golang.org/x/term v0.18.0
|
||||
golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2
|
||||
golang.zx2c4.com/wireguard v0.0.0-20230325221338-052af4a8072b
|
||||
golang.zx2c4.com/wireguard/windows v0.5.3
|
||||
google.golang.org/protobuf v1.32.0
|
||||
google.golang.org/protobuf v1.33.0
|
||||
gopkg.in/yaml.v2 v2.4.0
|
||||
gvisor.dev/gvisor v0.0.0-20230504175454-7b0a1988a28f
|
||||
)
|
||||
|
||||
28
go.sum
28
go.sum
@ -24,8 +24,8 @@ github.com/cyberdelia/go-metrics-graphite v0.0.0-20161219230853-39f87cc3b432/go.
|
||||
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
|
||||
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||
github.com/flynn/noise v1.0.1 h1:vPp/jdQLXC6ppsXSj/pM3W1BIJ5FEHE2TulSJBpb43Y=
|
||||
github.com/flynn/noise v1.0.1/go.mod h1:xbMo+0i6+IGbYdJhF31t2eR1BIU0CYc12+BNAKwUTag=
|
||||
github.com/flynn/noise v1.1.0 h1:KjPQoQCEFdZDiP03phOvGi11+SVVhBG2wOWAorLsstg=
|
||||
github.com/flynn/noise v1.1.0/go.mod h1:xbMo+0i6+IGbYdJhF31t2eR1BIU0CYc12+BNAKwUTag=
|
||||
github.com/go-kit/kit v0.8.0/go.mod h1:xBxKIO96dXMWWy0MnWVtmwkA9/13aqxPnvrjFYMA2as=
|
||||
github.com/go-kit/kit v0.9.0/go.mod h1:xBxKIO96dXMWWy0MnWVtmwkA9/13aqxPnvrjFYMA2as=
|
||||
github.com/go-kit/log v0.1.0/go.mod h1:zbhenjAZHb184qTLMA9ZjW7ThYL0H2mk7Q6pNt4vbaY=
|
||||
@ -134,10 +134,10 @@ github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXf
|
||||
github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI=
|
||||
github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4=
|
||||
github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
|
||||
github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk=
|
||||
github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo=
|
||||
github.com/timandy/routine v1.1.1 h1:6/Z7qLFZj3GrzuRksBFzIG8YGUh8CLhjnnMePBQTrEI=
|
||||
github.com/timandy/routine v1.1.1/go.mod h1:OZHPOKSvqL/ZvqXFkNZyit0xIVelERptYXdAHH00adQ=
|
||||
github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg=
|
||||
github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
|
||||
github.com/vishvananda/netlink v1.2.1-beta.2 h1:Llsql0lnQEbHj0I1OuKyp8otXp0r3q0mPkuhwHfStVs=
|
||||
github.com/vishvananda/netlink v1.2.1-beta.2/go.mod h1:twkDnbuQxJYemMlGd4JFIcuhgX83tXhKS2B/PRMpOho=
|
||||
github.com/vishvananda/netns v0.0.0-20200728191858-db3c7e526aae/go.mod h1:DD4vA1DwXk04H54A1oHXtwZmA0grkVMdPxx/VGLCah0=
|
||||
@ -150,8 +150,8 @@ golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACk
|
||||
golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
|
||||
golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto=
|
||||
golang.org/x/crypto v0.0.0-20210322153248-0c34fe9e7dc2/go.mod h1:T9bdIzuCu7OtxOm1hfPfRQxPLYneinmdGuTeoZ9dtd4=
|
||||
golang.org/x/crypto v0.18.0 h1:PGVlW0xEltQnzFZ55hkuX5+KLyrMYhHld1YHO4AKcdc=
|
||||
golang.org/x/crypto v0.18.0/go.mod h1:R0j02AL6hcrfOiy9T4ZYp/rcWeMxM3L6QYxlOuEG1mg=
|
||||
golang.org/x/crypto v0.21.0 h1:X31++rzVUdKhX5sWmSOFZxx8UW/ldWx55cbf08iNAMA=
|
||||
golang.org/x/crypto v0.21.0/go.mod h1:0BP7YvVV9gBbVKyeTG0Gyn+gZm94bibOW5BjDEYAOMs=
|
||||
golang.org/x/exp v0.0.0-20230425010034-47ecfdc1ba53 h1:5llv2sWeaMSnA3w2kS57ouQQ4pudlXrR0dCgw51QK9o=
|
||||
golang.org/x/exp v0.0.0-20230425010034-47ecfdc1ba53/go.mod h1:V1LtkGg67GoY2N1AnLN78QLrzxkLyJw7RJb1gzOOz9w=
|
||||
golang.org/x/lint v0.0.0-20200302205851-738671d3881b/go.mod h1:3xt1FjdF8hUf6vQPIChWIBhFzV8gjjsPE/fR3IyQdNY=
|
||||
@ -170,8 +170,8 @@ golang.org/x/net v0.0.0-20200226121028-0de0cce0169b/go.mod h1:z5CRVTTTmAJ677TzLL
|
||||
golang.org/x/net v0.0.0-20200625001655-4c5254603344/go.mod h1:/O7V0waA8r7cgGh81Ro3o1hOxt32SMVPicZroKQ2sZA=
|
||||
golang.org/x/net v0.0.0-20201021035429-f5854403a974/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU=
|
||||
golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg=
|
||||
golang.org/x/net v0.20.0 h1:aCL9BSgETF1k+blQaYUBx9hJ9LOGP3gAVemcZlf1Kpo=
|
||||
golang.org/x/net v0.20.0/go.mod h1:z8BVo6PvndSri0LbOE3hAn0apkU+1YvI6E70E9jsnvY=
|
||||
golang.org/x/net v0.22.0 h1:9sGLhx7iRIHEiX0oAJ3MRZMUCElJgy7Br1nO+AMN3Tc=
|
||||
golang.org/x/net v0.22.0/go.mod h1:JKghWKKOSdJwpW2GEx0Ja7fmaKnMsbu+MWVZTokSYmg=
|
||||
golang.org/x/oauth2 v0.0.0-20190226205417-e64efc72b421/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw=
|
||||
golang.org/x/sync v0.0.0-20181108010431-42b317875d0f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||
golang.org/x/sync v0.0.0-20181221193216-37e7f081c4d4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||
@ -198,11 +198,11 @@ golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7w
|
||||
golang.org/x/sys v0.0.0-20210124154548-22da62e12c0c/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
golang.org/x/sys v0.0.0-20210603081109-ebe580a85c40/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.16.0 h1:xWw16ngr6ZMtmxDyKyIgsE93KNKz5HKmMa3b8ALHidU=
|
||||
golang.org/x/sys v0.16.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
|
||||
golang.org/x/sys v0.18.0 h1:DBdB3niSjOA/O0blCZBqDefyWNYveAYMNF1Wum0DYQ4=
|
||||
golang.org/x/sys v0.18.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
|
||||
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
|
||||
golang.org/x/term v0.16.0 h1:m+B6fahuftsE9qjo0VWp2FW0mB3MTJvR0BaMQrq0pmE=
|
||||
golang.org/x/term v0.16.0/go.mod h1:yn7UURbUtPyrVJPGPq404EukNFxcm/foM+bV/bfcDsY=
|
||||
golang.org/x/term v0.18.0 h1:FcHjZXDMxI8mM3nwhX9HlKop4C0YQvCVCdwYl2wOtE8=
|
||||
golang.org/x/term v0.18.0/go.mod h1:ILwASektA3OnRv7amZ1xhE/KTR+u50pbXfZ03+6Nx58=
|
||||
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
|
||||
golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk=
|
||||
golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
|
||||
@ -233,8 +233,8 @@ google.golang.org/protobuf v1.20.1-0.20200309200217-e05f789c0967/go.mod h1:A+miE
|
||||
google.golang.org/protobuf v1.21.0/go.mod h1:47Nbq4nVaFHyn7ilMalzfO3qCViNmqZ2kzikPIcrTAo=
|
||||
google.golang.org/protobuf v1.23.0/go.mod h1:EGpADcykh3NcUnDUJcl1+ZksZNG86OlYog2l/sGQquU=
|
||||
google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp09yW+WbY/TyQbw=
|
||||
google.golang.org/protobuf v1.32.0 h1:pPC6BG5ex8PDFnkbrGU3EixyhKcQ2aDuBS36lqK/C7I=
|
||||
google.golang.org/protobuf v1.32.0/go.mod h1:c6P6GXX6sHbq/GpV6MGZEdwhWPcYBgnhAHhKbcUYpos=
|
||||
google.golang.org/protobuf v1.33.0 h1:uNO2rsAINq/JlFpSdYEKIZ0uKD/R9cpdv0T+yoGwGmI=
|
||||
google.golang.org/protobuf v1.33.0/go.mod h1:c6P6GXX6sHbq/GpV6MGZEdwhWPcYBgnhAHhKbcUYpos=
|
||||
gopkg.in/alecthomas/kingpin.v2 v2.2.6/go.mod h1:FMv+mEhP44yOT+4EoQTLFTRgOQ1FBLkstjWtayDeSgw=
|
||||
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
|
||||
gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
|
||||
|
||||
@ -408,7 +408,7 @@ func ixHandshakeStage2(f *Interface, addr *udp.Addr, via *ViaSender, hh *Handsha
|
||||
hostinfo.remotes = f.lightHouse.QueryCache(vpnIp)
|
||||
|
||||
f.l.WithField("blockedUdpAddrs", newHH.hostinfo.remotes.CopyBlockedRemotes()).WithField("vpnIp", vpnIp).
|
||||
WithField("remotes", newHH.hostinfo.remotes.CopyAddrs(f.hostMap.preferredRanges)).
|
||||
WithField("remotes", newHH.hostinfo.remotes.CopyAddrs(f.hostMap.GetPreferredRanges())).
|
||||
Info("Blocked addresses for handshakes")
|
||||
|
||||
// Swap the packet store to benefit the original intended recipient
|
||||
|
||||
@ -181,7 +181,7 @@ func (hm *HandshakeManager) handleOutbound(vpnIp iputil.VpnIp, lighthouseTrigger
|
||||
hostinfo := hh.hostinfo
|
||||
// If we are out of time, clean up
|
||||
if hh.counter >= hm.config.retries {
|
||||
hh.hostinfo.logger(hm.l).WithField("udpAddrs", hh.hostinfo.remotes.CopyAddrs(hm.mainHostMap.preferredRanges)).
|
||||
hh.hostinfo.logger(hm.l).WithField("udpAddrs", hh.hostinfo.remotes.CopyAddrs(hm.mainHostMap.GetPreferredRanges())).
|
||||
WithField("initiatorIndex", hh.hostinfo.localIndexId).
|
||||
WithField("remoteIndex", hh.hostinfo.remoteIndexId).
|
||||
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).
|
||||
@ -211,7 +211,7 @@ func (hm *HandshakeManager) handleOutbound(vpnIp iputil.VpnIp, lighthouseTrigger
|
||||
hostinfo.remotes = hm.lightHouse.QueryCache(vpnIp)
|
||||
}
|
||||
|
||||
remotes := hostinfo.remotes.CopyAddrs(hm.mainHostMap.preferredRanges)
|
||||
remotes := hostinfo.remotes.CopyAddrs(hm.mainHostMap.GetPreferredRanges())
|
||||
remotesHaveChanged := !udp.AddrSlice(remotes).Equal(hh.lastRemotes)
|
||||
|
||||
// We only care about a lighthouse trigger if we have new remotes to send to.
|
||||
@ -235,7 +235,7 @@ func (hm *HandshakeManager) handleOutbound(vpnIp iputil.VpnIp, lighthouseTrigger
|
||||
|
||||
// Send the handshake to all known ips, stage 2 takes care of assigning the hostinfo.remote based on the first to reply
|
||||
var sentTo []*udp.Addr
|
||||
hostinfo.remotes.ForEach(hm.mainHostMap.preferredRanges, func(addr *udp.Addr, _ bool) {
|
||||
hostinfo.remotes.ForEach(hm.mainHostMap.GetPreferredRanges(), func(addr *udp.Addr, _ bool) {
|
||||
hm.messageMetrics.Tx(header.Handshake, header.MessageSubType(hostinfo.HandshakePacket[0][1]), 1)
|
||||
err := hm.outside.WriteTo(hostinfo.HandshakePacket[0], addr)
|
||||
if err != nil {
|
||||
@ -362,7 +362,7 @@ func (hm *HandshakeManager) GetOrHandshake(vpnIp iputil.VpnIp, cacheCb func(*Han
|
||||
hm.mainHostMap.RUnlock()
|
||||
// Do not attempt promotion if you are a lighthouse
|
||||
if !hm.lightHouse.amLighthouse {
|
||||
h.TryPromoteBest(hm.mainHostMap.preferredRanges, hm.f)
|
||||
h.TryPromoteBest(hm.mainHostMap.GetPreferredRanges(), hm.f)
|
||||
}
|
||||
return h, true
|
||||
}
|
||||
@ -602,7 +602,7 @@ func (hm *HandshakeManager) queryIndex(index uint32) *HandshakeHostInfo {
|
||||
}
|
||||
|
||||
func (c *HandshakeManager) GetPreferredRanges() []*net.IPNet {
|
||||
return c.mainHostMap.preferredRanges
|
||||
return c.mainHostMap.GetPreferredRanges()
|
||||
}
|
||||
|
||||
func (c *HandshakeManager) ForEachVpnIp(f controlEach) {
|
||||
|
||||
@ -19,7 +19,9 @@ func Test_NewHandshakeManagerVpnIp(t *testing.T) {
|
||||
_, localrange, _ := net.ParseCIDR("10.1.1.1/24")
|
||||
ip := iputil.Ip2VpnIp(net.ParseIP("172.1.1.2"))
|
||||
preferredRanges := []*net.IPNet{localrange}
|
||||
mainHM := NewHostMap(l, vpncidr, preferredRanges)
|
||||
mainHM := newHostMap(l, vpncidr)
|
||||
mainHM.preferredRanges.Store(&preferredRanges)
|
||||
|
||||
lh := newTestLighthouse()
|
||||
|
||||
cs := &CertState{
|
||||
|
||||
73
hostmap.go
73
hostmap.go
@ -10,6 +10,7 @@ import (
|
||||
"github.com/sirupsen/logrus"
|
||||
"github.com/slackhq/nebula/cert"
|
||||
"github.com/slackhq/nebula/cidr"
|
||||
"github.com/slackhq/nebula/config"
|
||||
"github.com/slackhq/nebula/header"
|
||||
"github.com/slackhq/nebula/iputil"
|
||||
"github.com/slackhq/nebula/udp"
|
||||
@ -56,9 +57,8 @@ type HostMap struct {
|
||||
Relays map[uint32]*HostInfo // Maps a Relay IDX to a Relay HostInfo object
|
||||
RemoteIndexes map[uint32]*HostInfo
|
||||
Hosts map[iputil.VpnIp]*HostInfo
|
||||
preferredRanges []*net.IPNet
|
||||
preferredRanges atomic.Pointer[[]*net.IPNet]
|
||||
vpnCIDR *net.IPNet
|
||||
metricsEnabled bool
|
||||
l *logrus.Logger
|
||||
}
|
||||
|
||||
@ -254,22 +254,54 @@ type cachedPacketMetrics struct {
|
||||
dropped metrics.Counter
|
||||
}
|
||||
|
||||
func NewHostMap(l *logrus.Logger, vpnCIDR *net.IPNet, preferredRanges []*net.IPNet) *HostMap {
|
||||
h := map[iputil.VpnIp]*HostInfo{}
|
||||
i := map[uint32]*HostInfo{}
|
||||
r := map[uint32]*HostInfo{}
|
||||
relays := map[uint32]*HostInfo{}
|
||||
m := HostMap{
|
||||
syncRWMutex: newSyncRWMutex("hostmap"),
|
||||
Indexes: i,
|
||||
Relays: relays,
|
||||
RemoteIndexes: r,
|
||||
Hosts: h,
|
||||
preferredRanges: preferredRanges,
|
||||
vpnCIDR: vpnCIDR,
|
||||
l: l,
|
||||
func NewHostMapFromConfig(l *logrus.Logger, vpnCIDR *net.IPNet, c *config.C) *HostMap {
|
||||
hm := newHostMap(l, vpnCIDR)
|
||||
|
||||
hm.reload(c, true)
|
||||
c.RegisterReloadCallback(func(c *config.C) {
|
||||
hm.reload(c, false)
|
||||
})
|
||||
|
||||
l.WithField("network", hm.vpnCIDR.String()).
|
||||
WithField("preferredRanges", hm.GetPreferredRanges()).
|
||||
Info("Main HostMap created")
|
||||
|
||||
return hm
|
||||
}
|
||||
|
||||
func newHostMap(l *logrus.Logger, vpnCIDR *net.IPNet) *HostMap {
|
||||
return &HostMap{
|
||||
syncRWMutex: newSyncRWMutex("hostmap"),
|
||||
Indexes: map[uint32]*HostInfo{},
|
||||
Relays: map[uint32]*HostInfo{},
|
||||
RemoteIndexes: map[uint32]*HostInfo{},
|
||||
Hosts: map[iputil.VpnIp]*HostInfo{},
|
||||
vpnCIDR: vpnCIDR,
|
||||
l: l,
|
||||
}
|
||||
}
|
||||
|
||||
func (hm *HostMap) reload(c *config.C, initial bool) {
|
||||
if initial || c.HasChanged("preferred_ranges") {
|
||||
var preferredRanges []*net.IPNet
|
||||
rawPreferredRanges := c.GetStringSlice("preferred_ranges", []string{})
|
||||
|
||||
for _, rawPreferredRange := range rawPreferredRanges {
|
||||
_, preferredRange, err := net.ParseCIDR(rawPreferredRange)
|
||||
|
||||
if err != nil {
|
||||
hm.l.WithError(err).WithField("range", rawPreferredRanges).Warn("Failed to parse preferred ranges, ignoring")
|
||||
continue
|
||||
}
|
||||
|
||||
preferredRanges = append(preferredRanges, preferredRange)
|
||||
}
|
||||
|
||||
oldRanges := hm.preferredRanges.Swap(&preferredRanges)
|
||||
if !initial {
|
||||
hm.l.WithField("oldPreferredRanges", *oldRanges).WithField("newPreferredRanges", preferredRanges).Info("preferred_ranges changed")
|
||||
}
|
||||
}
|
||||
return &m
|
||||
}
|
||||
|
||||
// EmitStats reports host, index, and relay counts to the stats collection system
|
||||
@ -458,7 +490,7 @@ func (hm *HostMap) queryVpnIp(vpnIp iputil.VpnIp, promoteIfce *Interface) *HostI
|
||||
hm.RUnlock()
|
||||
// Do not attempt promotion if you are a lighthouse
|
||||
if promoteIfce != nil && !promoteIfce.lightHouse.amLighthouse {
|
||||
h.TryPromoteBest(hm.preferredRanges, promoteIfce)
|
||||
h.TryPromoteBest(hm.GetPreferredRanges(), promoteIfce)
|
||||
}
|
||||
return h
|
||||
|
||||
@ -505,7 +537,8 @@ func (hm *HostMap) unlockedAddHostInfo(hostinfo *HostInfo, f *Interface) {
|
||||
}
|
||||
|
||||
func (hm *HostMap) GetPreferredRanges() []*net.IPNet {
|
||||
return hm.preferredRanges
|
||||
//NOTE: if preferredRanges is ever not stored before a load this will fail to dereference a nil pointer
|
||||
return *hm.preferredRanges.Load()
|
||||
}
|
||||
|
||||
func (hm *HostMap) ForEachVpnIp(f controlEach) {
|
||||
@ -597,7 +630,7 @@ func (i *HostInfo) SetRemoteIfPreferred(hm *HostMap, newRemote *udp.Addr) bool {
|
||||
// NOTE: We do this loop here instead of calling `isPreferred` in
|
||||
// remote_list.go so that we only have to loop over preferredRanges once.
|
||||
newIsPreferred := false
|
||||
for _, l := range hm.preferredRanges {
|
||||
for _, l := range hm.GetPreferredRanges() {
|
||||
// return early if we are already on a preferred remote
|
||||
if l.Contains(currentRemote.IP) {
|
||||
return false
|
||||
|
||||
@ -4,19 +4,19 @@ import (
|
||||
"net"
|
||||
"testing"
|
||||
|
||||
"github.com/slackhq/nebula/config"
|
||||
"github.com/slackhq/nebula/test"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestHostMap_MakePrimary(t *testing.T) {
|
||||
l := test.NewLogger()
|
||||
hm := NewHostMap(
|
||||
hm := newHostMap(
|
||||
l,
|
||||
&net.IPNet{
|
||||
IP: net.IP{10, 0, 0, 1},
|
||||
Mask: net.IPMask{255, 255, 255, 0},
|
||||
},
|
||||
[]*net.IPNet{},
|
||||
)
|
||||
|
||||
f := &Interface{}
|
||||
@ -91,13 +91,12 @@ func TestHostMap_MakePrimary(t *testing.T) {
|
||||
|
||||
func TestHostMap_DeleteHostInfo(t *testing.T) {
|
||||
l := test.NewLogger()
|
||||
hm := NewHostMap(
|
||||
hm := newHostMap(
|
||||
l,
|
||||
&net.IPNet{
|
||||
IP: net.IP{10, 0, 0, 1},
|
||||
Mask: net.IPMask{255, 255, 255, 0},
|
||||
},
|
||||
[]*net.IPNet{},
|
||||
)
|
||||
|
||||
f := &Interface{}
|
||||
@ -205,3 +204,33 @@ func TestHostMap_DeleteHostInfo(t *testing.T) {
|
||||
prim = hm.QueryVpnIp(1)
|
||||
assert.Nil(t, prim)
|
||||
}
|
||||
|
||||
func TestHostMap_reload(t *testing.T) {
|
||||
l := test.NewLogger()
|
||||
c := config.NewC(l)
|
||||
|
||||
hm := NewHostMapFromConfig(
|
||||
l,
|
||||
&net.IPNet{
|
||||
IP: net.IP{10, 0, 0, 1},
|
||||
Mask: net.IPMask{255, 255, 255, 0},
|
||||
},
|
||||
c,
|
||||
)
|
||||
|
||||
toS := func(ipn []*net.IPNet) []string {
|
||||
var s []string
|
||||
for _, n := range ipn {
|
||||
s = append(s, n.String())
|
||||
}
|
||||
return s
|
||||
}
|
||||
|
||||
assert.Empty(t, hm.GetPreferredRanges())
|
||||
|
||||
c.ReloadConfigString("preferred_ranges: [1.1.1.0/24, 10.1.1.0/24]")
|
||||
assert.EqualValues(t, []string{"1.1.1.0/24", "10.1.1.0/24"}, toS(hm.GetPreferredRanges()))
|
||||
|
||||
c.ReloadConfigString("preferred_ranges: [1.1.1.1/32]")
|
||||
assert.EqualValues(t, []string{"1.1.1.1/32"}, toS(hm.GetPreferredRanges()))
|
||||
}
|
||||
|
||||
47
main.go
47
main.go
@ -183,52 +183,7 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
|
||||
}
|
||||
}
|
||||
|
||||
// Set up my internal host map
|
||||
var preferredRanges []*net.IPNet
|
||||
rawPreferredRanges := c.GetStringSlice("preferred_ranges", []string{})
|
||||
// First, check if 'preferred_ranges' is set and fallback to 'local_range'
|
||||
if len(rawPreferredRanges) > 0 {
|
||||
for _, rawPreferredRange := range rawPreferredRanges {
|
||||
_, preferredRange, err := net.ParseCIDR(rawPreferredRange)
|
||||
if err != nil {
|
||||
return nil, util.ContextualizeIfNeeded("Failed to parse preferred ranges", err)
|
||||
}
|
||||
preferredRanges = append(preferredRanges, preferredRange)
|
||||
}
|
||||
}
|
||||
|
||||
// local_range was superseded by preferred_ranges. If it is still present,
|
||||
// merge the local_range setting into preferred_ranges. We will probably
|
||||
// deprecate local_range and remove in the future.
|
||||
rawLocalRange := c.GetString("local_range", "")
|
||||
if rawLocalRange != "" {
|
||||
_, localRange, err := net.ParseCIDR(rawLocalRange)
|
||||
if err != nil {
|
||||
return nil, util.ContextualizeIfNeeded("Failed to parse local_range", err)
|
||||
}
|
||||
|
||||
// Check if the entry for local_range was already specified in
|
||||
// preferred_ranges. Don't put it into the slice twice if so.
|
||||
var found bool
|
||||
for _, r := range preferredRanges {
|
||||
if r.String() == localRange.String() {
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
preferredRanges = append(preferredRanges, localRange)
|
||||
}
|
||||
}
|
||||
|
||||
hostMap := NewHostMap(l, tunCidr, preferredRanges)
|
||||
hostMap.metricsEnabled = c.GetBool("stats.message_metrics", false)
|
||||
|
||||
l.
|
||||
WithField("network", hostMap.vpnCIDR.String()).
|
||||
WithField("preferredRanges", hostMap.preferredRanges).
|
||||
Info("Main HostMap created")
|
||||
|
||||
hostMap := NewHostMapFromConfig(l, tunCidr, c)
|
||||
punchy := NewPunchyFromConfig(l, c)
|
||||
lightHouse, err := NewLightHouseFromConfig(ctx, l, c, tunCidr, udpConns[0], punchy)
|
||||
if err != nil {
|
||||
|
||||
@ -1,6 +1,7 @@
|
||||
package overlay
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"math"
|
||||
"net"
|
||||
@ -21,6 +22,35 @@ type Route struct {
|
||||
Install bool
|
||||
}
|
||||
|
||||
// Equal determines if a route that could be installed in the system route table is equal to another
|
||||
// Via is ignored since that is only consumed within nebula itself
|
||||
func (r Route) Equal(t Route) bool {
|
||||
if !r.Cidr.IP.Equal(t.Cidr.IP) {
|
||||
return false
|
||||
}
|
||||
if !bytes.Equal(r.Cidr.Mask, t.Cidr.Mask) {
|
||||
return false
|
||||
}
|
||||
if r.Metric != t.Metric {
|
||||
return false
|
||||
}
|
||||
if r.MTU != t.MTU {
|
||||
return false
|
||||
}
|
||||
if r.Install != t.Install {
|
||||
return false
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
func (r Route) String() string {
|
||||
s := r.Cidr.String()
|
||||
if r.Metric != 0 {
|
||||
s += fmt.Sprintf(" metric: %v", r.Metric)
|
||||
}
|
||||
return s
|
||||
}
|
||||
|
||||
func makeRouteTree(l *logrus.Logger, routes []Route, allowMTU bool) (*cidr.Tree4[iputil.VpnIp], error) {
|
||||
routeTree := cidr.NewTree4[iputil.VpnIp]()
|
||||
for _, r := range routes {
|
||||
|
||||
@ -10,60 +10,63 @@ import (
|
||||
|
||||
const DefaultMTU = 1300
|
||||
|
||||
// TODO: We may be able to remove routines
|
||||
type DeviceFactory func(c *config.C, l *logrus.Logger, tunCidr *net.IPNet, routines int) (Device, error)
|
||||
|
||||
func NewDeviceFromConfig(c *config.C, l *logrus.Logger, tunCidr *net.IPNet, routines int) (Device, error) {
|
||||
routes, err := parseRoutes(c, tunCidr)
|
||||
if err != nil {
|
||||
return nil, util.NewContextualError("Could not parse tun.routes", nil, err)
|
||||
}
|
||||
|
||||
unsafeRoutes, err := parseUnsafeRoutes(c, tunCidr)
|
||||
if err != nil {
|
||||
return nil, util.NewContextualError("Could not parse tun.unsafe_routes", nil, err)
|
||||
}
|
||||
routes = append(routes, unsafeRoutes...)
|
||||
|
||||
switch {
|
||||
case c.GetBool("tun.disabled", false):
|
||||
tun := newDisabledTun(tunCidr, c.GetInt("tun.tx_queue", 500), c.GetBool("stats.message_metrics", false), l)
|
||||
return tun, nil
|
||||
|
||||
default:
|
||||
return newTun(
|
||||
l,
|
||||
c.GetString("tun.dev", ""),
|
||||
tunCidr,
|
||||
c.GetInt("tun.mtu", DefaultMTU),
|
||||
routes,
|
||||
c.GetInt("tun.tx_queue", 500),
|
||||
routines > 1,
|
||||
c.GetBool("tun.use_system_route_table", false),
|
||||
)
|
||||
return newTun(c, l, tunCidr, routines > 1)
|
||||
}
|
||||
}
|
||||
|
||||
func NewFdDeviceFromConfig(fd *int) DeviceFactory {
|
||||
return func(c *config.C, l *logrus.Logger, tunCidr *net.IPNet, routines int) (Device, error) {
|
||||
routes, err := parseRoutes(c, tunCidr)
|
||||
if err != nil {
|
||||
return nil, util.NewContextualError("Could not parse tun.routes", nil, err)
|
||||
}
|
||||
|
||||
unsafeRoutes, err := parseUnsafeRoutes(c, tunCidr)
|
||||
if err != nil {
|
||||
return nil, util.NewContextualError("Could not parse tun.unsafe_routes", nil, err)
|
||||
}
|
||||
routes = append(routes, unsafeRoutes...)
|
||||
return newTunFromFd(
|
||||
l,
|
||||
*fd,
|
||||
tunCidr,
|
||||
c.GetInt("tun.mtu", DefaultMTU),
|
||||
routes,
|
||||
c.GetInt("tun.tx_queue", 500),
|
||||
c.GetBool("tun.use_system_route_table", false),
|
||||
)
|
||||
|
||||
return newTunFromFd(c, l, *fd, tunCidr)
|
||||
}
|
||||
}
|
||||
|
||||
func getAllRoutesFromConfig(c *config.C, cidr *net.IPNet, initial bool) (bool, []Route, error) {
|
||||
if !initial && !c.HasChanged("tun.routes") && !c.HasChanged("tun.unsafe_routes") {
|
||||
return false, nil, nil
|
||||
}
|
||||
|
||||
routes, err := parseRoutes(c, cidr)
|
||||
if err != nil {
|
||||
return true, nil, util.NewContextualError("Could not parse tun.routes", nil, err)
|
||||
}
|
||||
|
||||
unsafeRoutes, err := parseUnsafeRoutes(c, cidr)
|
||||
if err != nil {
|
||||
return true, nil, util.NewContextualError("Could not parse tun.unsafe_routes", nil, err)
|
||||
}
|
||||
|
||||
routes = append(routes, unsafeRoutes...)
|
||||
return true, routes, nil
|
||||
}
|
||||
|
||||
// findRemovedRoutes will return all routes that are not present in the newRoutes list and would affect the system route table.
|
||||
// Via is not used to evaluate since it does not affect the system route table.
|
||||
func findRemovedRoutes(newRoutes, oldRoutes []Route) []Route {
|
||||
var removed []Route
|
||||
has := func(entry Route) bool {
|
||||
for _, check := range newRoutes {
|
||||
if check.Equal(entry) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
for _, oldEntry := range oldRoutes {
|
||||
if !has(oldEntry) {
|
||||
removed = append(removed, oldEntry)
|
||||
}
|
||||
}
|
||||
|
||||
return removed
|
||||
}
|
||||
|
||||
@ -8,45 +8,57 @@ import (
|
||||
"io"
|
||||
"net"
|
||||
"os"
|
||||
"sync/atomic"
|
||||
|
||||
"github.com/sirupsen/logrus"
|
||||
"github.com/slackhq/nebula/cidr"
|
||||
"github.com/slackhq/nebula/config"
|
||||
"github.com/slackhq/nebula/iputil"
|
||||
"github.com/slackhq/nebula/util"
|
||||
)
|
||||
|
||||
type tun struct {
|
||||
io.ReadWriteCloser
|
||||
fd int
|
||||
cidr *net.IPNet
|
||||
routeTree *cidr.Tree4[iputil.VpnIp]
|
||||
Routes atomic.Pointer[[]Route]
|
||||
routeTree atomic.Pointer[cidr.Tree4[iputil.VpnIp]]
|
||||
l *logrus.Logger
|
||||
}
|
||||
|
||||
func newTunFromFd(l *logrus.Logger, deviceFd int, cidr *net.IPNet, _ int, routes []Route, _ int, _ bool) (*tun, error) {
|
||||
routeTree, err := makeRouteTree(l, routes, false)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
func newTunFromFd(c *config.C, l *logrus.Logger, deviceFd int, cidr *net.IPNet) (*tun, error) {
|
||||
// XXX Android returns an fd in non-blocking mode which is necessary for shutdown to work properly.
|
||||
// Be sure not to call file.Fd() as it will set the fd to blocking mode.
|
||||
file := os.NewFile(uintptr(deviceFd), "/dev/net/tun")
|
||||
|
||||
return &tun{
|
||||
t := &tun{
|
||||
ReadWriteCloser: file,
|
||||
fd: deviceFd,
|
||||
cidr: cidr,
|
||||
l: l,
|
||||
routeTree: routeTree,
|
||||
}, nil
|
||||
}
|
||||
|
||||
err := t.reload(c, true)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
c.RegisterReloadCallback(func(c *config.C) {
|
||||
err := t.reload(c, false)
|
||||
if err != nil {
|
||||
util.LogWithContextIfNeeded("failed to reload tun device", err, t.l)
|
||||
}
|
||||
})
|
||||
|
||||
return t, nil
|
||||
}
|
||||
|
||||
func newTun(_ *logrus.Logger, _ string, _ *net.IPNet, _ int, _ []Route, _ int, _ bool, _ bool) (*tun, error) {
|
||||
func newTun(_ *config.C, _ *logrus.Logger, _ *net.IPNet, _ bool) (*tun, error) {
|
||||
return nil, fmt.Errorf("newTun not supported in Android")
|
||||
}
|
||||
|
||||
func (t *tun) RouteFor(ip iputil.VpnIp) iputil.VpnIp {
|
||||
_, r := t.routeTree.MostSpecificContains(ip)
|
||||
_, r := t.routeTree.Load().MostSpecificContains(ip)
|
||||
return r
|
||||
}
|
||||
|
||||
@ -54,6 +66,27 @@ func (t tun) Activate() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (t *tun) reload(c *config.C, initial bool) error {
|
||||
change, routes, err := getAllRoutesFromConfig(c, t.cidr, initial)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if !initial && !change {
|
||||
return nil
|
||||
}
|
||||
|
||||
routeTree, err := makeRouteTree(t.l, routes, false)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Teach nebula how to handle the routes
|
||||
t.Routes.Store(&routes)
|
||||
t.routeTree.Store(routeTree)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (t *tun) Cidr() *net.IPNet {
|
||||
return t.cidr
|
||||
}
|
||||
|
||||
@ -9,12 +9,15 @@ import (
|
||||
"io"
|
||||
"net"
|
||||
"os"
|
||||
"sync/atomic"
|
||||
"syscall"
|
||||
"unsafe"
|
||||
|
||||
"github.com/sirupsen/logrus"
|
||||
"github.com/slackhq/nebula/cidr"
|
||||
"github.com/slackhq/nebula/config"
|
||||
"github.com/slackhq/nebula/iputil"
|
||||
"github.com/slackhq/nebula/util"
|
||||
netroute "golang.org/x/net/route"
|
||||
"golang.org/x/sys/unix"
|
||||
)
|
||||
@ -24,8 +27,9 @@ type tun struct {
|
||||
Device string
|
||||
cidr *net.IPNet
|
||||
DefaultMTU int
|
||||
Routes []Route
|
||||
routeTree *cidr.Tree4[iputil.VpnIp]
|
||||
Routes atomic.Pointer[[]Route]
|
||||
routeTree atomic.Pointer[cidr.Tree4[iputil.VpnIp]]
|
||||
linkAddr *netroute.LinkAddr
|
||||
l *logrus.Logger
|
||||
|
||||
// cache out buffer since we need to prepend 4 bytes for tun metadata
|
||||
@ -69,12 +73,8 @@ type ifreqMTU struct {
|
||||
pad [8]byte
|
||||
}
|
||||
|
||||
func newTun(l *logrus.Logger, name string, cidr *net.IPNet, defaultMTU int, routes []Route, _ int, _ bool, _ bool) (*tun, error) {
|
||||
routeTree, err := makeRouteTree(l, routes, false)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
func newTun(c *config.C, l *logrus.Logger, cidr *net.IPNet, _ bool) (*tun, error) {
|
||||
name := c.GetString("tun.dev", "")
|
||||
ifIndex := -1
|
||||
if name != "" && name != "utun" {
|
||||
_, err := fmt.Sscanf(name, "utun%d", &ifIndex)
|
||||
@ -142,17 +142,27 @@ func newTun(l *logrus.Logger, name string, cidr *net.IPNet, defaultMTU int, rout
|
||||
|
||||
file := os.NewFile(uintptr(fd), "")
|
||||
|
||||
tun := &tun{
|
||||
t := &tun{
|
||||
ReadWriteCloser: file,
|
||||
Device: name,
|
||||
cidr: cidr,
|
||||
DefaultMTU: defaultMTU,
|
||||
Routes: routes,
|
||||
routeTree: routeTree,
|
||||
DefaultMTU: c.GetInt("tun.mtu", DefaultMTU),
|
||||
l: l,
|
||||
}
|
||||
|
||||
return tun, nil
|
||||
err = t.reload(c, true)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
c.RegisterReloadCallback(func(c *config.C) {
|
||||
err := t.reload(c, false)
|
||||
if err != nil {
|
||||
util.LogWithContextIfNeeded("failed to reload tun device", err, t.l)
|
||||
}
|
||||
})
|
||||
|
||||
return t, nil
|
||||
}
|
||||
|
||||
func (t *tun) deviceBytes() (o [16]byte) {
|
||||
@ -162,7 +172,7 @@ func (t *tun) deviceBytes() (o [16]byte) {
|
||||
return
|
||||
}
|
||||
|
||||
func newTunFromFd(_ *logrus.Logger, _ int, _ *net.IPNet, _ int, _ []Route, _ int, _ bool) (*tun, error) {
|
||||
func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ *net.IPNet) (*tun, error) {
|
||||
return nil, fmt.Errorf("newTunFromFd not supported in Darwin")
|
||||
}
|
||||
|
||||
@ -260,6 +270,7 @@ func (t *tun) Activate() error {
|
||||
if linkAddr == nil {
|
||||
return fmt.Errorf("unable to discover link_addr for tun interface")
|
||||
}
|
||||
t.linkAddr = linkAddr
|
||||
|
||||
copy(routeAddr.IP[:], addr[:])
|
||||
copy(maskAddr.IP[:], mask[:])
|
||||
@ -278,33 +289,48 @@ func (t *tun) Activate() error {
|
||||
}
|
||||
|
||||
// Unsafe path routes
|
||||
for _, r := range t.Routes {
|
||||
if r.Via == nil || !r.Install {
|
||||
// We don't allow route MTUs so only install routes with a via
|
||||
continue
|
||||
}
|
||||
return t.addRoutes(false)
|
||||
}
|
||||
|
||||
copy(routeAddr.IP[:], r.Cidr.IP.To4())
|
||||
copy(maskAddr.IP[:], net.IP(r.Cidr.Mask).To4())
|
||||
func (t *tun) reload(c *config.C, initial bool) error {
|
||||
change, routes, err := getAllRoutesFromConfig(c, t.cidr, initial)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
err = addRoute(routeSock, routeAddr, maskAddr, linkAddr)
|
||||
if !initial && !change {
|
||||
return nil
|
||||
}
|
||||
|
||||
routeTree, err := makeRouteTree(t.l, routes, false)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Teach nebula how to handle the routes before establishing them in the system table
|
||||
oldRoutes := t.Routes.Swap(&routes)
|
||||
t.routeTree.Store(routeTree)
|
||||
|
||||
if !initial {
|
||||
// Remove first, if the system removes a wanted route hopefully it will be re-added next
|
||||
err := t.removeRoutes(findRemovedRoutes(routes, *oldRoutes))
|
||||
if err != nil {
|
||||
if errors.Is(err, unix.EEXIST) {
|
||||
t.l.WithField("route", r.Cidr).
|
||||
Warnf("unable to add unsafe_route, identical route already exists")
|
||||
} else {
|
||||
return err
|
||||
}
|
||||
util.LogWithContextIfNeeded("Failed to remove routes", err, t.l)
|
||||
}
|
||||
|
||||
// TODO how to set metric
|
||||
// Ensure any routes we actually want are installed
|
||||
err = t.addRoutes(true)
|
||||
if err != nil {
|
||||
// Catch any stray logs
|
||||
util.LogWithContextIfNeeded("Failed to add routes", err, t.l)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (t *tun) RouteFor(ip iputil.VpnIp) iputil.VpnIp {
|
||||
ok, r := t.routeTree.MostSpecificContains(ip)
|
||||
ok, r := t.routeTree.Load().MostSpecificContains(ip)
|
||||
if ok {
|
||||
return r
|
||||
}
|
||||
@ -340,6 +366,88 @@ func getLinkAddr(name string) (*netroute.LinkAddr, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (t *tun) addRoutes(logErrors bool) error {
|
||||
routeSock, err := unix.Socket(unix.AF_ROUTE, unix.SOCK_RAW, unix.AF_UNSPEC)
|
||||
if err != nil {
|
||||
return fmt.Errorf("unable to create AF_ROUTE socket: %v", err)
|
||||
}
|
||||
|
||||
defer func() {
|
||||
unix.Shutdown(routeSock, unix.SHUT_RDWR)
|
||||
err := unix.Close(routeSock)
|
||||
if err != nil {
|
||||
t.l.WithError(err).Error("failed to close AF_ROUTE socket")
|
||||
}
|
||||
}()
|
||||
|
||||
routeAddr := &netroute.Inet4Addr{}
|
||||
maskAddr := &netroute.Inet4Addr{}
|
||||
routes := *t.Routes.Load()
|
||||
for _, r := range routes {
|
||||
if r.Via == nil || !r.Install {
|
||||
// We don't allow route MTUs so only install routes with a via
|
||||
continue
|
||||
}
|
||||
|
||||
copy(routeAddr.IP[:], r.Cidr.IP.To4())
|
||||
copy(maskAddr.IP[:], net.IP(r.Cidr.Mask).To4())
|
||||
|
||||
err := addRoute(routeSock, routeAddr, maskAddr, t.linkAddr)
|
||||
if err != nil {
|
||||
if errors.Is(err, unix.EEXIST) {
|
||||
t.l.WithField("route", r.Cidr).
|
||||
Warnf("unable to add unsafe_route, identical route already exists")
|
||||
} else {
|
||||
retErr := util.NewContextualError("Failed to add route", map[string]interface{}{"route": r}, err)
|
||||
if logErrors {
|
||||
retErr.Log(t.l)
|
||||
} else {
|
||||
return retErr
|
||||
}
|
||||
}
|
||||
} else {
|
||||
t.l.WithField("route", r).Info("Added route")
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (t *tun) removeRoutes(routes []Route) error {
|
||||
routeSock, err := unix.Socket(unix.AF_ROUTE, unix.SOCK_RAW, unix.AF_UNSPEC)
|
||||
if err != nil {
|
||||
return fmt.Errorf("unable to create AF_ROUTE socket: %v", err)
|
||||
}
|
||||
|
||||
defer func() {
|
||||
unix.Shutdown(routeSock, unix.SHUT_RDWR)
|
||||
err := unix.Close(routeSock)
|
||||
if err != nil {
|
||||
t.l.WithError(err).Error("failed to close AF_ROUTE socket")
|
||||
}
|
||||
}()
|
||||
|
||||
routeAddr := &netroute.Inet4Addr{}
|
||||
maskAddr := &netroute.Inet4Addr{}
|
||||
|
||||
for _, r := range routes {
|
||||
if !r.Install {
|
||||
continue
|
||||
}
|
||||
|
||||
copy(routeAddr.IP[:], r.Cidr.IP.To4())
|
||||
copy(maskAddr.IP[:], net.IP(r.Cidr.Mask).To4())
|
||||
|
||||
err := delRoute(routeSock, routeAddr, maskAddr, t.linkAddr)
|
||||
if err != nil {
|
||||
t.l.WithError(err).WithField("route", r).Error("Failed to remove route")
|
||||
} else {
|
||||
t.l.WithField("route", r).Info("Removed route")
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func addRoute(sock int, addr, mask *netroute.Inet4Addr, link *netroute.LinkAddr) error {
|
||||
r := netroute.RouteMessage{
|
||||
Version: unix.RTM_VERSION,
|
||||
@ -365,6 +473,30 @@ func addRoute(sock int, addr, mask *netroute.Inet4Addr, link *netroute.LinkAddr)
|
||||
return nil
|
||||
}
|
||||
|
||||
func delRoute(sock int, addr, mask *netroute.Inet4Addr, link *netroute.LinkAddr) error {
|
||||
r := netroute.RouteMessage{
|
||||
Version: unix.RTM_VERSION,
|
||||
Type: unix.RTM_DELETE,
|
||||
Seq: 1,
|
||||
Addrs: []netroute.Addr{
|
||||
unix.RTAX_DST: addr,
|
||||
unix.RTAX_GATEWAY: link,
|
||||
unix.RTAX_NETMASK: mask,
|
||||
},
|
||||
}
|
||||
|
||||
data, err := r.Marshal()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create route.RouteMessage: %w", err)
|
||||
}
|
||||
_, err = unix.Write(sock, data[:])
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to write route.RouteMessage to socket: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (t *tun) Read(to []byte) (int, error) {
|
||||
|
||||
buf := make([]byte, len(to)+4)
|
||||
|
||||
@ -13,12 +13,15 @@ import (
|
||||
"os"
|
||||
"os/exec"
|
||||
"strconv"
|
||||
"sync/atomic"
|
||||
"syscall"
|
||||
"unsafe"
|
||||
|
||||
"github.com/sirupsen/logrus"
|
||||
"github.com/slackhq/nebula/cidr"
|
||||
"github.com/slackhq/nebula/config"
|
||||
"github.com/slackhq/nebula/iputil"
|
||||
"github.com/slackhq/nebula/util"
|
||||
)
|
||||
|
||||
const (
|
||||
@ -47,8 +50,8 @@ type tun struct {
|
||||
Device string
|
||||
cidr *net.IPNet
|
||||
MTU int
|
||||
Routes []Route
|
||||
routeTree *cidr.Tree4[iputil.VpnIp]
|
||||
Routes atomic.Pointer[[]Route]
|
||||
routeTree atomic.Pointer[cidr.Tree4[iputil.VpnIp]]
|
||||
l *logrus.Logger
|
||||
|
||||
io.ReadWriteCloser
|
||||
@ -76,14 +79,15 @@ func (t *tun) Close() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func newTunFromFd(_ *logrus.Logger, _ int, _ *net.IPNet, _ int, _ []Route, _ int, _ bool) (*tun, error) {
|
||||
func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ *net.IPNet) (*tun, error) {
|
||||
return nil, fmt.Errorf("newTunFromFd not supported in FreeBSD")
|
||||
}
|
||||
|
||||
func newTun(l *logrus.Logger, deviceName string, cidr *net.IPNet, defaultMTU int, routes []Route, _ int, _ bool, _ bool) (*tun, error) {
|
||||
func newTun(c *config.C, l *logrus.Logger, cidr *net.IPNet, _ bool) (*tun, error) {
|
||||
// Try to open existing tun device
|
||||
var file *os.File
|
||||
var err error
|
||||
deviceName := c.GetString("tun.dev", "")
|
||||
if deviceName != "" {
|
||||
file, err = os.OpenFile("/dev/"+deviceName, os.O_RDWR, 0)
|
||||
}
|
||||
@ -144,47 +148,85 @@ func newTun(l *logrus.Logger, deviceName string, cidr *net.IPNet, defaultMTU int
|
||||
ioctl(fd, syscall.SIOCSIFNAME, uintptr(unsafe.Pointer(&ifrr)))
|
||||
}
|
||||
|
||||
routeTree, err := makeRouteTree(l, routes, false)
|
||||
t := &tun{
|
||||
ReadWriteCloser: file,
|
||||
Device: deviceName,
|
||||
cidr: cidr,
|
||||
MTU: c.GetInt("tun.mtu", DefaultMTU),
|
||||
l: l,
|
||||
}
|
||||
|
||||
err = t.reload(c, true)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &tun{
|
||||
ReadWriteCloser: file,
|
||||
Device: deviceName,
|
||||
cidr: cidr,
|
||||
MTU: defaultMTU,
|
||||
Routes: routes,
|
||||
routeTree: routeTree,
|
||||
l: l,
|
||||
}, nil
|
||||
c.RegisterReloadCallback(func(c *config.C) {
|
||||
err := t.reload(c, false)
|
||||
if err != nil {
|
||||
util.LogWithContextIfNeeded("failed to reload tun device", err, t.l)
|
||||
}
|
||||
})
|
||||
|
||||
return t, nil
|
||||
}
|
||||
|
||||
func (t *tun) Activate() error {
|
||||
var err error
|
||||
// TODO use syscalls instead of exec.Command
|
||||
t.l.Debug("command: ifconfig", t.Device, t.cidr.String(), t.cidr.IP.String())
|
||||
if err = exec.Command("/sbin/ifconfig", t.Device, t.cidr.String(), t.cidr.IP.String()).Run(); err != nil {
|
||||
cmd := exec.Command("/sbin/ifconfig", t.Device, t.cidr.String(), t.cidr.IP.String())
|
||||
t.l.Debug("command: ", cmd.String())
|
||||
if err = cmd.Run(); err != nil {
|
||||
return fmt.Errorf("failed to run 'ifconfig': %s", err)
|
||||
}
|
||||
t.l.Debug("command: route", "-n", "add", "-net", t.cidr.String(), "-interface", t.Device)
|
||||
if err = exec.Command("/sbin/route", "-n", "add", "-net", t.cidr.String(), "-interface", t.Device).Run(); err != nil {
|
||||
|
||||
cmd = exec.Command("/sbin/route", "-n", "add", "-net", t.cidr.String(), "-interface", t.Device)
|
||||
t.l.Debug("command: ", cmd.String())
|
||||
if err = cmd.Run(); err != nil {
|
||||
return fmt.Errorf("failed to run 'route add': %s", err)
|
||||
}
|
||||
t.l.Debug("command: ifconfig", t.Device, "mtu", strconv.Itoa(t.MTU))
|
||||
if err = exec.Command("/sbin/ifconfig", t.Device, "mtu", strconv.Itoa(t.MTU)).Run(); err != nil {
|
||||
|
||||
cmd = exec.Command("/sbin/ifconfig", t.Device, "mtu", strconv.Itoa(t.MTU))
|
||||
t.l.Debug("command: ", cmd.String())
|
||||
if err = cmd.Run(); err != nil {
|
||||
return fmt.Errorf("failed to run 'ifconfig': %s", err)
|
||||
}
|
||||
|
||||
// Unsafe path routes
|
||||
for _, r := range t.Routes {
|
||||
if r.Via == nil || !r.Install {
|
||||
// We don't allow route MTUs so only install routes with a via
|
||||
continue
|
||||
return t.addRoutes(false)
|
||||
}
|
||||
|
||||
func (t *tun) reload(c *config.C, initial bool) error {
|
||||
change, routes, err := getAllRoutesFromConfig(c, t.cidr, initial)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if !initial && !change {
|
||||
return nil
|
||||
}
|
||||
|
||||
routeTree, err := makeRouteTree(t.l, routes, false)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Teach nebula how to handle the routes before establishing them in the system table
|
||||
oldRoutes := t.Routes.Swap(&routes)
|
||||
t.routeTree.Store(routeTree)
|
||||
|
||||
if !initial {
|
||||
// Remove first, if the system removes a wanted route hopefully it will be re-added next
|
||||
err := t.removeRoutes(findRemovedRoutes(routes, *oldRoutes))
|
||||
if err != nil {
|
||||
util.LogWithContextIfNeeded("Failed to remove routes", err, t.l)
|
||||
}
|
||||
|
||||
t.l.Debug("command: route", "-n", "add", "-net", r.Cidr.String(), "-interface", t.Device)
|
||||
if err = exec.Command("/sbin/route", "-n", "add", "-net", r.Cidr.String(), "-interface", t.Device).Run(); err != nil {
|
||||
return fmt.Errorf("failed to run 'route add' for unsafe_route %s: %s", r.Cidr.String(), err)
|
||||
// Ensure any routes we actually want are installed
|
||||
err = t.addRoutes(true)
|
||||
if err != nil {
|
||||
// Catch any stray logs
|
||||
util.LogWithContextIfNeeded("Failed to add routes", err, t.l)
|
||||
}
|
||||
}
|
||||
|
||||
@ -192,7 +234,7 @@ func (t *tun) Activate() error {
|
||||
}
|
||||
|
||||
func (t *tun) RouteFor(ip iputil.VpnIp) iputil.VpnIp {
|
||||
_, r := t.routeTree.MostSpecificContains(ip)
|
||||
_, r := t.routeTree.Load().MostSpecificContains(ip)
|
||||
return r
|
||||
}
|
||||
|
||||
@ -208,6 +250,46 @@ func (t *tun) NewMultiQueueReader() (io.ReadWriteCloser, error) {
|
||||
return nil, fmt.Errorf("TODO: multiqueue not implemented for freebsd")
|
||||
}
|
||||
|
||||
func (t *tun) addRoutes(logErrors bool) error {
|
||||
routes := *t.Routes.Load()
|
||||
for _, r := range routes {
|
||||
if r.Via == nil || !r.Install {
|
||||
// We don't allow route MTUs so only install routes with a via
|
||||
continue
|
||||
}
|
||||
|
||||
cmd := exec.Command("/sbin/route", "-n", "add", "-net", r.Cidr.String(), "-interface", t.Device)
|
||||
t.l.Debug("command: ", cmd.String())
|
||||
if err := cmd.Run(); err != nil {
|
||||
retErr := util.NewContextualError("failed to run 'route add' for unsafe_route", map[string]interface{}{"route": r}, err)
|
||||
if logErrors {
|
||||
retErr.Log(t.l)
|
||||
} else {
|
||||
return retErr
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (t *tun) removeRoutes(routes []Route) error {
|
||||
for _, r := range routes {
|
||||
if !r.Install {
|
||||
continue
|
||||
}
|
||||
|
||||
cmd := exec.Command("/sbin/route", "-n", "delete", "-net", r.Cidr.String(), "-interface", t.Device)
|
||||
t.l.Debug("command: ", cmd.String())
|
||||
if err := cmd.Run(); err != nil {
|
||||
t.l.WithError(err).WithField("route", r).Error("Failed to remove route")
|
||||
} else {
|
||||
t.l.WithField("route", r).Info("Removed route")
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (t *tun) deviceBytes() (o [16]byte) {
|
||||
for i, c := range t.Device {
|
||||
o[i] = byte(c)
|
||||
|
||||
@ -10,43 +10,78 @@ import (
|
||||
"net"
|
||||
"os"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"syscall"
|
||||
|
||||
"github.com/sirupsen/logrus"
|
||||
"github.com/slackhq/nebula/cidr"
|
||||
"github.com/slackhq/nebula/config"
|
||||
"github.com/slackhq/nebula/iputil"
|
||||
"github.com/slackhq/nebula/util"
|
||||
)
|
||||
|
||||
type tun struct {
|
||||
io.ReadWriteCloser
|
||||
cidr *net.IPNet
|
||||
routeTree *cidr.Tree4[iputil.VpnIp]
|
||||
Routes atomic.Pointer[[]Route]
|
||||
routeTree atomic.Pointer[cidr.Tree4[iputil.VpnIp]]
|
||||
l *logrus.Logger
|
||||
}
|
||||
|
||||
func newTun(_ *logrus.Logger, _ string, _ *net.IPNet, _ int, _ []Route, _ int, _ bool, _ bool) (*tun, error) {
|
||||
func newTun(_ *config.C, _ *logrus.Logger, _ *net.IPNet, _ bool) (*tun, error) {
|
||||
return nil, fmt.Errorf("newTun not supported in iOS")
|
||||
}
|
||||
|
||||
func newTunFromFd(l *logrus.Logger, deviceFd int, cidr *net.IPNet, _ int, routes []Route, _ int, _ bool) (*tun, error) {
|
||||
routeTree, err := makeRouteTree(l, routes, false)
|
||||
func newTunFromFd(c *config.C, l *logrus.Logger, deviceFd int, cidr *net.IPNet) (*tun, error) {
|
||||
file := os.NewFile(uintptr(deviceFd), "/dev/tun")
|
||||
t := &tun{
|
||||
cidr: cidr,
|
||||
ReadWriteCloser: &tunReadCloser{f: file},
|
||||
l: l,
|
||||
}
|
||||
|
||||
err := t.reload(c, true)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
file := os.NewFile(uintptr(deviceFd), "/dev/tun")
|
||||
return &tun{
|
||||
cidr: cidr,
|
||||
ReadWriteCloser: &tunReadCloser{f: file},
|
||||
routeTree: routeTree,
|
||||
}, nil
|
||||
c.RegisterReloadCallback(func(c *config.C) {
|
||||
err := t.reload(c, false)
|
||||
if err != nil {
|
||||
util.LogWithContextIfNeeded("failed to reload tun device", err, t.l)
|
||||
}
|
||||
})
|
||||
|
||||
return t, nil
|
||||
}
|
||||
|
||||
func (t *tun) Activate() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (t *tun) reload(c *config.C, initial bool) error {
|
||||
change, routes, err := getAllRoutesFromConfig(c, t.cidr, initial)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if !initial && !change {
|
||||
return nil
|
||||
}
|
||||
|
||||
routeTree, err := makeRouteTree(t.l, routes, false)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Teach nebula how to handle the routes
|
||||
t.Routes.Store(&routes)
|
||||
t.routeTree.Store(routeTree)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (t *tun) RouteFor(ip iputil.VpnIp) iputil.VpnIp {
|
||||
_, r := t.routeTree.MostSpecificContains(ip)
|
||||
_, r := t.routeTree.Load().MostSpecificContains(ip)
|
||||
return r
|
||||
}
|
||||
|
||||
|
||||
@ -15,21 +15,25 @@ import (
|
||||
|
||||
"github.com/sirupsen/logrus"
|
||||
"github.com/slackhq/nebula/cidr"
|
||||
"github.com/slackhq/nebula/config"
|
||||
"github.com/slackhq/nebula/iputil"
|
||||
"github.com/slackhq/nebula/util"
|
||||
"github.com/vishvananda/netlink"
|
||||
"golang.org/x/sys/unix"
|
||||
)
|
||||
|
||||
type tun struct {
|
||||
io.ReadWriteCloser
|
||||
fd int
|
||||
Device string
|
||||
cidr *net.IPNet
|
||||
MaxMTU int
|
||||
DefaultMTU int
|
||||
TXQueueLen int
|
||||
fd int
|
||||
Device string
|
||||
cidr *net.IPNet
|
||||
MaxMTU int
|
||||
DefaultMTU int
|
||||
TXQueueLen int
|
||||
deviceIndex int
|
||||
ioctlFd uintptr
|
||||
|
||||
Routes []Route
|
||||
Routes atomic.Pointer[[]Route]
|
||||
routeTree atomic.Pointer[cidr.Tree4[iputil.VpnIp]]
|
||||
routeChan chan struct{}
|
||||
useSystemRoutes bool
|
||||
@ -61,30 +65,20 @@ type ifreqQLEN struct {
|
||||
pad [8]byte
|
||||
}
|
||||
|
||||
func newTunFromFd(l *logrus.Logger, deviceFd int, cidr *net.IPNet, defaultMTU int, routes []Route, txQueueLen int, useSystemRoutes bool) (*tun, error) {
|
||||
routeTree, err := makeRouteTree(l, routes, true)
|
||||
func newTunFromFd(c *config.C, l *logrus.Logger, deviceFd int, cidr *net.IPNet) (*tun, error) {
|
||||
file := os.NewFile(uintptr(deviceFd), "/dev/net/tun")
|
||||
|
||||
t, err := newTunGeneric(c, l, file, cidr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
file := os.NewFile(uintptr(deviceFd), "/dev/net/tun")
|
||||
t.Device = "tun0"
|
||||
|
||||
t := &tun{
|
||||
ReadWriteCloser: file,
|
||||
fd: int(file.Fd()),
|
||||
Device: "tun0",
|
||||
cidr: cidr,
|
||||
DefaultMTU: defaultMTU,
|
||||
TXQueueLen: txQueueLen,
|
||||
Routes: routes,
|
||||
useSystemRoutes: useSystemRoutes,
|
||||
l: l,
|
||||
}
|
||||
t.routeTree.Store(routeTree)
|
||||
return t, nil
|
||||
}
|
||||
|
||||
func newTun(l *logrus.Logger, deviceName string, cidr *net.IPNet, defaultMTU int, routes []Route, txQueueLen int, multiqueue bool, useSystemRoutes bool) (*tun, error) {
|
||||
func newTun(c *config.C, l *logrus.Logger, cidr *net.IPNet, multiqueue bool) (*tun, error) {
|
||||
fd, err := unix.Open("/dev/net/tun", os.O_RDWR, 0)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@ -95,46 +89,113 @@ func newTun(l *logrus.Logger, deviceName string, cidr *net.IPNet, defaultMTU int
|
||||
if multiqueue {
|
||||
req.Flags |= unix.IFF_MULTI_QUEUE
|
||||
}
|
||||
copy(req.Name[:], deviceName)
|
||||
copy(req.Name[:], c.GetString("tun.dev", ""))
|
||||
if err = ioctl(uintptr(fd), uintptr(unix.TUNSETIFF), uintptr(unsafe.Pointer(&req))); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
name := strings.Trim(string(req.Name[:]), "\x00")
|
||||
|
||||
file := os.NewFile(uintptr(fd), "/dev/net/tun")
|
||||
|
||||
maxMTU := defaultMTU
|
||||
for _, r := range routes {
|
||||
if r.MTU == 0 {
|
||||
r.MTU = defaultMTU
|
||||
}
|
||||
|
||||
if r.MTU > maxMTU {
|
||||
maxMTU = r.MTU
|
||||
}
|
||||
}
|
||||
|
||||
routeTree, err := makeRouteTree(l, routes, true)
|
||||
t, err := newTunGeneric(c, l, file, cidr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
t.Device = name
|
||||
|
||||
return t, nil
|
||||
}
|
||||
|
||||
func newTunGeneric(c *config.C, l *logrus.Logger, file *os.File, cidr *net.IPNet) (*tun, error) {
|
||||
t := &tun{
|
||||
ReadWriteCloser: file,
|
||||
fd: int(file.Fd()),
|
||||
Device: name,
|
||||
cidr: cidr,
|
||||
MaxMTU: maxMTU,
|
||||
DefaultMTU: defaultMTU,
|
||||
TXQueueLen: txQueueLen,
|
||||
Routes: routes,
|
||||
useSystemRoutes: useSystemRoutes,
|
||||
TXQueueLen: c.GetInt("tun.tx_queue", 500),
|
||||
useSystemRoutes: c.GetBool("tun.use_system_route_table", false),
|
||||
l: l,
|
||||
}
|
||||
t.routeTree.Store(routeTree)
|
||||
|
||||
err := t.reload(c, true)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
c.RegisterReloadCallback(func(c *config.C) {
|
||||
err := t.reload(c, false)
|
||||
if err != nil {
|
||||
util.LogWithContextIfNeeded("failed to reload tun device", err, t.l)
|
||||
}
|
||||
})
|
||||
|
||||
return t, nil
|
||||
}
|
||||
|
||||
func (t *tun) reload(c *config.C, initial bool) error {
|
||||
routeChange, routes, err := getAllRoutesFromConfig(c, t.cidr, initial)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if !initial && !routeChange && !c.HasChanged("tun.mtu") {
|
||||
return nil
|
||||
}
|
||||
|
||||
routeTree, err := makeRouteTree(t.l, routes, true)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
oldDefaultMTU := t.DefaultMTU
|
||||
oldMaxMTU := t.MaxMTU
|
||||
newDefaultMTU := c.GetInt("tun.mtu", DefaultMTU)
|
||||
newMaxMTU := newDefaultMTU
|
||||
for i, r := range routes {
|
||||
if r.MTU == 0 {
|
||||
routes[i].MTU = newDefaultMTU
|
||||
}
|
||||
|
||||
if r.MTU > t.MaxMTU {
|
||||
newMaxMTU = r.MTU
|
||||
}
|
||||
}
|
||||
|
||||
t.MaxMTU = newMaxMTU
|
||||
t.DefaultMTU = newDefaultMTU
|
||||
|
||||
// Teach nebula how to handle the routes before establishing them in the system table
|
||||
oldRoutes := t.Routes.Swap(&routes)
|
||||
t.routeTree.Store(routeTree)
|
||||
|
||||
if !initial {
|
||||
if oldMaxMTU != newMaxMTU {
|
||||
t.setMTU()
|
||||
t.l.Infof("Set max MTU to %v was %v", t.MaxMTU, oldMaxMTU)
|
||||
}
|
||||
|
||||
if oldDefaultMTU != newDefaultMTU {
|
||||
err := t.setDefaultRoute()
|
||||
if err != nil {
|
||||
t.l.Warn(err)
|
||||
} else {
|
||||
t.l.Infof("Set default MTU to %v was %v", t.DefaultMTU, oldDefaultMTU)
|
||||
}
|
||||
}
|
||||
|
||||
// Remove first, if the system removes a wanted route hopefully it will be re-added next
|
||||
t.removeRoutes(findRemovedRoutes(routes, *oldRoutes))
|
||||
|
||||
// Ensure any routes we actually want are installed
|
||||
err = t.addRoutes(true)
|
||||
if err != nil {
|
||||
// This should never be called since addRoutes should log its own errors in a reload condition
|
||||
util.LogWithContextIfNeeded("Failed to refresh routes", err, t.l)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (t *tun) NewMultiQueueReader() (io.ReadWriteCloser, error) {
|
||||
fd, err := unix.Open("/dev/net/tun", os.O_RDWR, 0)
|
||||
if err != nil {
|
||||
@ -208,7 +269,7 @@ func (t *tun) Activate() error {
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
fd := uintptr(s)
|
||||
t.ioctlFd = uintptr(s)
|
||||
|
||||
ifra := ifreqAddr{
|
||||
Name: devName,
|
||||
@ -219,52 +280,76 @@ func (t *tun) Activate() error {
|
||||
}
|
||||
|
||||
// Set the device ip address
|
||||
if err = ioctl(fd, unix.SIOCSIFADDR, uintptr(unsafe.Pointer(&ifra))); err != nil {
|
||||
if err = ioctl(t.ioctlFd, unix.SIOCSIFADDR, uintptr(unsafe.Pointer(&ifra))); err != nil {
|
||||
return fmt.Errorf("failed to set tun address: %s", err)
|
||||
}
|
||||
|
||||
// Set the device network
|
||||
ifra.Addr.Addr = mask
|
||||
if err = ioctl(fd, unix.SIOCSIFNETMASK, uintptr(unsafe.Pointer(&ifra))); err != nil {
|
||||
if err = ioctl(t.ioctlFd, unix.SIOCSIFNETMASK, uintptr(unsafe.Pointer(&ifra))); err != nil {
|
||||
return fmt.Errorf("failed to set tun netmask: %s", err)
|
||||
}
|
||||
|
||||
// Set the device name
|
||||
ifrf := ifReq{Name: devName}
|
||||
if err = ioctl(fd, unix.SIOCGIFFLAGS, uintptr(unsafe.Pointer(&ifrf))); err != nil {
|
||||
if err = ioctl(t.ioctlFd, unix.SIOCGIFFLAGS, uintptr(unsafe.Pointer(&ifrf))); err != nil {
|
||||
return fmt.Errorf("failed to set tun device name: %s", err)
|
||||
}
|
||||
|
||||
// Set the MTU on the device
|
||||
ifm := ifreqMTU{Name: devName, MTU: int32(t.MaxMTU)}
|
||||
if err = ioctl(fd, unix.SIOCSIFMTU, uintptr(unsafe.Pointer(&ifm))); err != nil {
|
||||
// This is currently a non fatal condition because the route table must have the MTU set appropriately as well
|
||||
t.l.WithError(err).Error("Failed to set tun mtu")
|
||||
}
|
||||
// Setup our default MTU
|
||||
t.setMTU()
|
||||
|
||||
// Set the transmit queue length
|
||||
ifrq := ifreqQLEN{Name: devName, Value: int32(t.TXQueueLen)}
|
||||
if err = ioctl(fd, unix.SIOCSIFTXQLEN, uintptr(unsafe.Pointer(&ifrq))); err != nil {
|
||||
if err = ioctl(t.ioctlFd, unix.SIOCSIFTXQLEN, uintptr(unsafe.Pointer(&ifrq))); err != nil {
|
||||
// If we can't set the queue length nebula will still work but it may lead to packet loss
|
||||
t.l.WithError(err).Error("Failed to set tun tx queue length")
|
||||
}
|
||||
|
||||
// Bring up the interface
|
||||
ifrf.Flags = ifrf.Flags | unix.IFF_UP
|
||||
if err = ioctl(fd, unix.SIOCSIFFLAGS, uintptr(unsafe.Pointer(&ifrf))); err != nil {
|
||||
if err = ioctl(t.ioctlFd, unix.SIOCSIFFLAGS, uintptr(unsafe.Pointer(&ifrf))); err != nil {
|
||||
return fmt.Errorf("failed to bring the tun device up: %s", err)
|
||||
}
|
||||
|
||||
// Set the routes
|
||||
link, err := netlink.LinkByName(t.Device)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get tun device link: %s", err)
|
||||
}
|
||||
t.deviceIndex = link.Attrs().Index
|
||||
|
||||
if err = t.setDefaultRoute(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Set the routes
|
||||
if err = t.addRoutes(false); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Run the interface
|
||||
ifrf.Flags = ifrf.Flags | unix.IFF_UP | unix.IFF_RUNNING
|
||||
if err = ioctl(t.ioctlFd, unix.SIOCSIFFLAGS, uintptr(unsafe.Pointer(&ifrf))); err != nil {
|
||||
return fmt.Errorf("failed to run tun device: %s", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (t *tun) setMTU() {
|
||||
// Set the MTU on the device
|
||||
ifm := ifreqMTU{Name: t.deviceBytes(), MTU: int32(t.MaxMTU)}
|
||||
if err := ioctl(t.ioctlFd, unix.SIOCSIFMTU, uintptr(unsafe.Pointer(&ifm))); err != nil {
|
||||
// This is currently a non fatal condition because the route table must have the MTU set appropriately as well
|
||||
t.l.WithError(err).Error("Failed to set tun mtu")
|
||||
}
|
||||
}
|
||||
|
||||
func (t *tun) setDefaultRoute() error {
|
||||
// Default route
|
||||
dr := &net.IPNet{IP: t.cidr.IP.Mask(t.cidr.Mask), Mask: t.cidr.Mask}
|
||||
nr := netlink.Route{
|
||||
LinkIndex: link.Attrs().Index,
|
||||
LinkIndex: t.deviceIndex,
|
||||
Dst: dr,
|
||||
MTU: t.DefaultMTU,
|
||||
AdvMSS: t.advMSS(Route{}),
|
||||
@ -274,19 +359,24 @@ func (t *tun) Activate() error {
|
||||
Table: unix.RT_TABLE_MAIN,
|
||||
Type: unix.RTN_UNICAST,
|
||||
}
|
||||
err = netlink.RouteReplace(&nr)
|
||||
err := netlink.RouteReplace(&nr)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to set mtu %v on the default route %v; %v", t.DefaultMTU, dr, err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (t *tun) addRoutes(logErrors bool) error {
|
||||
// Path routes
|
||||
for _, r := range t.Routes {
|
||||
routes := *t.Routes.Load()
|
||||
for _, r := range routes {
|
||||
if !r.Install {
|
||||
continue
|
||||
}
|
||||
|
||||
nr := netlink.Route{
|
||||
LinkIndex: link.Attrs().Index,
|
||||
LinkIndex: t.deviceIndex,
|
||||
Dst: r.Cidr,
|
||||
MTU: r.MTU,
|
||||
AdvMSS: t.advMSS(r),
|
||||
@ -297,21 +387,49 @@ func (t *tun) Activate() error {
|
||||
nr.Priority = r.Metric
|
||||
}
|
||||
|
||||
err = netlink.RouteAdd(&nr)
|
||||
err := netlink.RouteReplace(&nr)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to set mtu %v on route %v; %v", r.MTU, r.Cidr, err)
|
||||
retErr := util.NewContextualError("Failed to add route", map[string]interface{}{"route": r}, err)
|
||||
if logErrors {
|
||||
retErr.Log(t.l)
|
||||
} else {
|
||||
return retErr
|
||||
}
|
||||
} else {
|
||||
t.l.WithField("route", r).Info("Added route")
|
||||
}
|
||||
}
|
||||
|
||||
// Run the interface
|
||||
ifrf.Flags = ifrf.Flags | unix.IFF_UP | unix.IFF_RUNNING
|
||||
if err = ioctl(fd, unix.SIOCSIFFLAGS, uintptr(unsafe.Pointer(&ifrf))); err != nil {
|
||||
return fmt.Errorf("failed to run tun device: %s", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (t *tun) removeRoutes(routes []Route) {
|
||||
for _, r := range routes {
|
||||
if !r.Install {
|
||||
continue
|
||||
}
|
||||
|
||||
nr := netlink.Route{
|
||||
LinkIndex: t.deviceIndex,
|
||||
Dst: r.Cidr,
|
||||
MTU: r.MTU,
|
||||
AdvMSS: t.advMSS(r),
|
||||
Scope: unix.RT_SCOPE_LINK,
|
||||
}
|
||||
|
||||
if r.Metric > 0 {
|
||||
nr.Priority = r.Metric
|
||||
}
|
||||
|
||||
err := netlink.RouteDel(&nr)
|
||||
if err != nil {
|
||||
t.l.WithError(err).WithField("route", r).Error("Failed to remove route")
|
||||
} else {
|
||||
t.l.WithField("route", r).Info("Removed route")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (t *tun) Cidr() *net.IPNet {
|
||||
return t.cidr
|
||||
}
|
||||
@ -410,5 +528,9 @@ func (t *tun) Close() error {
|
||||
t.ReadWriteCloser.Close()
|
||||
}
|
||||
|
||||
if t.ioctlFd > 0 {
|
||||
os.NewFile(t.ioctlFd, "ioctlFd").Close()
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@ -11,12 +11,15 @@ import (
|
||||
"os/exec"
|
||||
"regexp"
|
||||
"strconv"
|
||||
"sync/atomic"
|
||||
"syscall"
|
||||
"unsafe"
|
||||
|
||||
"github.com/sirupsen/logrus"
|
||||
"github.com/slackhq/nebula/cidr"
|
||||
"github.com/slackhq/nebula/config"
|
||||
"github.com/slackhq/nebula/iputil"
|
||||
"github.com/slackhq/nebula/util"
|
||||
)
|
||||
|
||||
type ifreqDestroy struct {
|
||||
@ -28,8 +31,8 @@ type tun struct {
|
||||
Device string
|
||||
cidr *net.IPNet
|
||||
MTU int
|
||||
Routes []Route
|
||||
routeTree *cidr.Tree4[iputil.VpnIp]
|
||||
Routes atomic.Pointer[[]Route]
|
||||
routeTree atomic.Pointer[cidr.Tree4[iputil.VpnIp]]
|
||||
l *logrus.Logger
|
||||
|
||||
io.ReadWriteCloser
|
||||
@ -56,43 +59,50 @@ func (t *tun) Close() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func newTunFromFd(_ *logrus.Logger, _ int, _ *net.IPNet, _ int, _ []Route, _ int, _ bool) (*tun, error) {
|
||||
func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ *net.IPNet) (*tun, error) {
|
||||
return nil, fmt.Errorf("newTunFromFd not supported in NetBSD")
|
||||
}
|
||||
|
||||
var deviceNameRE = regexp.MustCompile(`^tun[0-9]+$`)
|
||||
|
||||
func newTun(l *logrus.Logger, deviceName string, cidr *net.IPNet, defaultMTU int, routes []Route, _ int, _ bool, _ bool) (*tun, error) {
|
||||
func newTun(c *config.C, l *logrus.Logger, cidr *net.IPNet, _ bool) (*tun, error) {
|
||||
// Try to open tun device
|
||||
var file *os.File
|
||||
var err error
|
||||
deviceName := c.GetString("tun.dev", "")
|
||||
if deviceName == "" {
|
||||
return nil, fmt.Errorf("a device name in the format of /dev/tunN must be specified")
|
||||
}
|
||||
if !deviceNameRE.MatchString(deviceName) {
|
||||
return nil, fmt.Errorf("a device name in the format of /dev/tunN must be specified")
|
||||
}
|
||||
|
||||
file, err = os.OpenFile("/dev/"+deviceName, os.O_RDWR, 0)
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
routeTree, err := makeRouteTree(l, routes, false)
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &tun{
|
||||
t := &tun{
|
||||
ReadWriteCloser: file,
|
||||
Device: deviceName,
|
||||
cidr: cidr,
|
||||
MTU: defaultMTU,
|
||||
Routes: routes,
|
||||
routeTree: routeTree,
|
||||
MTU: c.GetInt("tun.mtu", DefaultMTU),
|
||||
l: l,
|
||||
}, nil
|
||||
}
|
||||
|
||||
err = t.reload(c, true)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
c.RegisterReloadCallback(func(c *config.C) {
|
||||
err := t.reload(c, false)
|
||||
if err != nil {
|
||||
util.LogWithContextIfNeeded("failed to reload tun device", err, t.l)
|
||||
}
|
||||
})
|
||||
|
||||
return t, nil
|
||||
}
|
||||
|
||||
func (t *tun) Activate() error {
|
||||
@ -116,17 +126,42 @@ func (t *tun) Activate() error {
|
||||
if err = cmd.Run(); err != nil {
|
||||
return fmt.Errorf("failed to run 'ifconfig': %s", err)
|
||||
}
|
||||
|
||||
// Unsafe path routes
|
||||
for _, r := range t.Routes {
|
||||
if r.Via == nil || !r.Install {
|
||||
// We don't allow route MTUs so only install routes with a via
|
||||
continue
|
||||
return t.addRoutes(false)
|
||||
}
|
||||
|
||||
func (t *tun) reload(c *config.C, initial bool) error {
|
||||
change, routes, err := getAllRoutesFromConfig(c, t.cidr, initial)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if !initial && !change {
|
||||
return nil
|
||||
}
|
||||
|
||||
routeTree, err := makeRouteTree(t.l, routes, false)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Teach nebula how to handle the routes before establishing them in the system table
|
||||
oldRoutes := t.Routes.Swap(&routes)
|
||||
t.routeTree.Store(routeTree)
|
||||
|
||||
if !initial {
|
||||
// Remove first, if the system removes a wanted route hopefully it will be re-added next
|
||||
err := t.removeRoutes(findRemovedRoutes(routes, *oldRoutes))
|
||||
if err != nil {
|
||||
util.LogWithContextIfNeeded("Failed to remove routes", err, t.l)
|
||||
}
|
||||
|
||||
cmd = exec.Command("/sbin/route", "-n", "add", "-net", r.Cidr.String(), t.cidr.IP.String())
|
||||
t.l.Debug("command: ", cmd.String())
|
||||
if err = cmd.Run(); err != nil {
|
||||
return fmt.Errorf("failed to run 'route add' for unsafe_route %s: %s", r.Cidr.String(), err)
|
||||
// Ensure any routes we actually want are installed
|
||||
err = t.addRoutes(true)
|
||||
if err != nil {
|
||||
// Catch any stray logs
|
||||
util.LogWithContextIfNeeded("Failed to add routes", err, t.l)
|
||||
}
|
||||
}
|
||||
|
||||
@ -134,7 +169,7 @@ func (t *tun) Activate() error {
|
||||
}
|
||||
|
||||
func (t *tun) RouteFor(ip iputil.VpnIp) iputil.VpnIp {
|
||||
_, r := t.routeTree.MostSpecificContains(ip)
|
||||
_, r := t.routeTree.Load().MostSpecificContains(ip)
|
||||
return r
|
||||
}
|
||||
|
||||
@ -150,6 +185,46 @@ func (t *tun) NewMultiQueueReader() (io.ReadWriteCloser, error) {
|
||||
return nil, fmt.Errorf("TODO: multiqueue not implemented for netbsd")
|
||||
}
|
||||
|
||||
func (t *tun) addRoutes(logErrors bool) error {
|
||||
routes := *t.Routes.Load()
|
||||
for _, r := range routes {
|
||||
if r.Via == nil || !r.Install {
|
||||
// We don't allow route MTUs so only install routes with a via
|
||||
continue
|
||||
}
|
||||
|
||||
cmd := exec.Command("/sbin/route", "-n", "add", "-net", r.Cidr.String(), t.cidr.IP.String())
|
||||
t.l.Debug("command: ", cmd.String())
|
||||
if err := cmd.Run(); err != nil {
|
||||
retErr := util.NewContextualError("failed to run 'route add' for unsafe_route", map[string]interface{}{"route": r}, err)
|
||||
if logErrors {
|
||||
retErr.Log(t.l)
|
||||
} else {
|
||||
return retErr
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (t *tun) removeRoutes(routes []Route) error {
|
||||
for _, r := range routes {
|
||||
if !r.Install {
|
||||
continue
|
||||
}
|
||||
|
||||
cmd := exec.Command("/sbin/route", "-n", "delete", "-net", r.Cidr.String(), t.cidr.IP.String())
|
||||
t.l.Debug("command: ", cmd.String())
|
||||
if err := cmd.Run(); err != nil {
|
||||
t.l.WithError(err).WithField("route", r).Error("Failed to remove route")
|
||||
} else {
|
||||
t.l.WithField("route", r).Info("Removed route")
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (t *tun) deviceBytes() (o [16]byte) {
|
||||
for i, c := range t.Device {
|
||||
o[i] = byte(c)
|
||||
|
||||
@ -11,19 +11,22 @@ import (
|
||||
"os/exec"
|
||||
"regexp"
|
||||
"strconv"
|
||||
"sync/atomic"
|
||||
"syscall"
|
||||
|
||||
"github.com/sirupsen/logrus"
|
||||
"github.com/slackhq/nebula/cidr"
|
||||
"github.com/slackhq/nebula/config"
|
||||
"github.com/slackhq/nebula/iputil"
|
||||
"github.com/slackhq/nebula/util"
|
||||
)
|
||||
|
||||
type tun struct {
|
||||
Device string
|
||||
cidr *net.IPNet
|
||||
MTU int
|
||||
Routes []Route
|
||||
routeTree *cidr.Tree4[iputil.VpnIp]
|
||||
Routes atomic.Pointer[[]Route]
|
||||
routeTree atomic.Pointer[cidr.Tree4[iputil.VpnIp]]
|
||||
l *logrus.Logger
|
||||
|
||||
io.ReadWriteCloser
|
||||
@ -40,13 +43,14 @@ func (t *tun) Close() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func newTunFromFd(_ *logrus.Logger, _ int, _ *net.IPNet, _ int, _ []Route, _ int, _ bool) (*tun, error) {
|
||||
func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ *net.IPNet) (*tun, error) {
|
||||
return nil, fmt.Errorf("newTunFromFd not supported in OpenBSD")
|
||||
}
|
||||
|
||||
var deviceNameRE = regexp.MustCompile(`^tun[0-9]+$`)
|
||||
|
||||
func newTun(l *logrus.Logger, deviceName string, cidr *net.IPNet, defaultMTU int, routes []Route, _ int, _ bool, _ bool) (*tun, error) {
|
||||
func newTun(c *config.C, l *logrus.Logger, cidr *net.IPNet, _ bool) (*tun, error) {
|
||||
deviceName := c.GetString("tun.dev", "")
|
||||
if deviceName == "" {
|
||||
return nil, fmt.Errorf("a device name in the format of tunN must be specified")
|
||||
}
|
||||
@ -60,20 +64,64 @@ func newTun(l *logrus.Logger, deviceName string, cidr *net.IPNet, defaultMTU int
|
||||
return nil, err
|
||||
}
|
||||
|
||||
routeTree, err := makeRouteTree(l, routes, false)
|
||||
t := &tun{
|
||||
ReadWriteCloser: file,
|
||||
Device: deviceName,
|
||||
cidr: cidr,
|
||||
MTU: c.GetInt("tun.mtu", DefaultMTU),
|
||||
l: l,
|
||||
}
|
||||
|
||||
err = t.reload(c, true)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &tun{
|
||||
ReadWriteCloser: file,
|
||||
Device: deviceName,
|
||||
cidr: cidr,
|
||||
MTU: defaultMTU,
|
||||
Routes: routes,
|
||||
routeTree: routeTree,
|
||||
l: l,
|
||||
}, nil
|
||||
c.RegisterReloadCallback(func(c *config.C) {
|
||||
err := t.reload(c, false)
|
||||
if err != nil {
|
||||
util.LogWithContextIfNeeded("failed to reload tun device", err, t.l)
|
||||
}
|
||||
})
|
||||
|
||||
return t, nil
|
||||
}
|
||||
|
||||
func (t *tun) reload(c *config.C, initial bool) error {
|
||||
change, routes, err := getAllRoutesFromConfig(c, t.cidr, initial)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if !initial && !change {
|
||||
return nil
|
||||
}
|
||||
|
||||
routeTree, err := makeRouteTree(t.l, routes, false)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Teach nebula how to handle the routes before establishing them in the system table
|
||||
oldRoutes := t.Routes.Swap(&routes)
|
||||
t.routeTree.Store(routeTree)
|
||||
|
||||
if !initial {
|
||||
// Remove first, if the system removes a wanted route hopefully it will be re-added next
|
||||
err := t.removeRoutes(findRemovedRoutes(routes, *oldRoutes))
|
||||
if err != nil {
|
||||
util.LogWithContextIfNeeded("Failed to remove routes", err, t.l)
|
||||
}
|
||||
|
||||
// Ensure any routes we actually want are installed
|
||||
err = t.addRoutes(true)
|
||||
if err != nil {
|
||||
// Catch any stray logs
|
||||
util.LogWithContextIfNeeded("Failed to add routes", err, t.l)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (t *tun) Activate() error {
|
||||
@ -98,25 +146,52 @@ func (t *tun) Activate() error {
|
||||
}
|
||||
|
||||
// Unsafe path routes
|
||||
for _, r := range t.Routes {
|
||||
return t.addRoutes(false)
|
||||
}
|
||||
|
||||
func (t *tun) RouteFor(ip iputil.VpnIp) iputil.VpnIp {
|
||||
_, r := t.routeTree.Load().MostSpecificContains(ip)
|
||||
return r
|
||||
}
|
||||
|
||||
func (t *tun) addRoutes(logErrors bool) error {
|
||||
routes := *t.Routes.Load()
|
||||
for _, r := range routes {
|
||||
if r.Via == nil || !r.Install {
|
||||
// We don't allow route MTUs so only install routes with a via
|
||||
continue
|
||||
}
|
||||
|
||||
cmd = exec.Command("/sbin/route", "-n", "add", "-inet", r.Cidr.String(), t.cidr.IP.String())
|
||||
cmd := exec.Command("/sbin/route", "-n", "add", "-inet", r.Cidr.String(), t.cidr.IP.String())
|
||||
t.l.Debug("command: ", cmd.String())
|
||||
if err = cmd.Run(); err != nil {
|
||||
return fmt.Errorf("failed to run 'route add' for unsafe_route %s: %s", r.Cidr.String(), err)
|
||||
if err := cmd.Run(); err != nil {
|
||||
retErr := util.NewContextualError("failed to run 'route add' for unsafe_route", map[string]interface{}{"route": r}, err)
|
||||
if logErrors {
|
||||
retErr.Log(t.l)
|
||||
} else {
|
||||
return retErr
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (t *tun) RouteFor(ip iputil.VpnIp) iputil.VpnIp {
|
||||
_, r := t.routeTree.MostSpecificContains(ip)
|
||||
return r
|
||||
func (t *tun) removeRoutes(routes []Route) error {
|
||||
for _, r := range routes {
|
||||
if !r.Install {
|
||||
continue
|
||||
}
|
||||
|
||||
cmd := exec.Command("/sbin/route", "-n", "delete", "-inet", r.Cidr.String(), t.cidr.IP.String())
|
||||
t.l.Debug("command: ", cmd.String())
|
||||
if err := cmd.Run(); err != nil {
|
||||
t.l.WithError(err).WithField("route", r).Error("Failed to remove route")
|
||||
} else {
|
||||
t.l.WithField("route", r).Info("Removed route")
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (t *tun) Cidr() *net.IPNet {
|
||||
|
||||
@ -12,6 +12,7 @@ import (
|
||||
|
||||
"github.com/sirupsen/logrus"
|
||||
"github.com/slackhq/nebula/cidr"
|
||||
"github.com/slackhq/nebula/config"
|
||||
"github.com/slackhq/nebula/iputil"
|
||||
)
|
||||
|
||||
@ -27,14 +28,18 @@ type TestTun struct {
|
||||
TxPackets chan []byte // Packets transmitted outside by nebula
|
||||
}
|
||||
|
||||
func newTun(l *logrus.Logger, deviceName string, cidr *net.IPNet, _ int, routes []Route, _ int, _ bool, _ bool) (*TestTun, error) {
|
||||
func newTun(c *config.C, l *logrus.Logger, cidr *net.IPNet, _ bool) (*TestTun, error) {
|
||||
_, routes, err := getAllRoutesFromConfig(c, cidr, true)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
routeTree, err := makeRouteTree(l, routes, false)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &TestTun{
|
||||
Device: deviceName,
|
||||
Device: c.GetString("tun.dev", ""),
|
||||
cidr: cidr,
|
||||
Routes: routes,
|
||||
routeTree: routeTree,
|
||||
@ -44,7 +49,7 @@ func newTun(l *logrus.Logger, deviceName string, cidr *net.IPNet, _ int, routes
|
||||
}, nil
|
||||
}
|
||||
|
||||
func newTunFromFd(_ *logrus.Logger, _ int, _ *net.IPNet, _ int, _ []Route, _ int, _ bool) (*TestTun, error) {
|
||||
func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ *net.IPNet) (*TestTun, error) {
|
||||
return nil, fmt.Errorf("newTunFromFd not supported")
|
||||
}
|
||||
|
||||
|
||||
@ -6,10 +6,13 @@ import (
|
||||
"net"
|
||||
"os/exec"
|
||||
"strconv"
|
||||
"sync/atomic"
|
||||
|
||||
"github.com/sirupsen/logrus"
|
||||
"github.com/slackhq/nebula/cidr"
|
||||
"github.com/slackhq/nebula/config"
|
||||
"github.com/slackhq/nebula/iputil"
|
||||
"github.com/slackhq/nebula/util"
|
||||
"github.com/songgao/water"
|
||||
)
|
||||
|
||||
@ -17,25 +20,34 @@ type waterTun struct {
|
||||
Device string
|
||||
cidr *net.IPNet
|
||||
MTU int
|
||||
Routes []Route
|
||||
routeTree *cidr.Tree4[iputil.VpnIp]
|
||||
|
||||
Routes atomic.Pointer[[]Route]
|
||||
routeTree atomic.Pointer[cidr.Tree4[iputil.VpnIp]]
|
||||
l *logrus.Logger
|
||||
f *net.Interface
|
||||
*water.Interface
|
||||
}
|
||||
|
||||
func newWaterTun(l *logrus.Logger, cidr *net.IPNet, defaultMTU int, routes []Route) (*waterTun, error) {
|
||||
routeTree, err := makeRouteTree(l, routes, false)
|
||||
func newWaterTun(c *config.C, l *logrus.Logger, cidr *net.IPNet, _ bool) (*waterTun, error) {
|
||||
// NOTE: You cannot set the deviceName under Windows, so you must check tun.Device after calling .Activate()
|
||||
t := &waterTun{
|
||||
cidr: cidr,
|
||||
MTU: c.GetInt("tun.mtu", DefaultMTU),
|
||||
l: l,
|
||||
}
|
||||
|
||||
err := t.reload(c, true)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// NOTE: You cannot set the deviceName under Windows, so you must check tun.Device after calling .Activate()
|
||||
return &waterTun{
|
||||
cidr: cidr,
|
||||
MTU: defaultMTU,
|
||||
Routes: routes,
|
||||
routeTree: routeTree,
|
||||
}, nil
|
||||
c.RegisterReloadCallback(func(c *config.C) {
|
||||
err := t.reload(c, false)
|
||||
if err != nil {
|
||||
util.LogWithContextIfNeeded("failed to reload tun device", err, t.l)
|
||||
}
|
||||
})
|
||||
|
||||
return t, nil
|
||||
}
|
||||
|
||||
func (t *waterTun) Activate() error {
|
||||
@ -74,30 +86,104 @@ func (t *waterTun) Activate() error {
|
||||
return fmt.Errorf("failed to run 'netsh' to set MTU: %s", err)
|
||||
}
|
||||
|
||||
iface, err := net.InterfaceByName(t.Device)
|
||||
t.f, err = net.InterfaceByName(t.Device)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to find interface named %s: %v", t.Device, err)
|
||||
}
|
||||
|
||||
for _, r := range t.Routes {
|
||||
if r.Via == nil || !r.Install {
|
||||
// We don't allow route MTUs so only install routes with a via
|
||||
continue
|
||||
}
|
||||
err = t.addRoutes(false)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
err = exec.Command(
|
||||
"C:\\Windows\\System32\\route.exe", "add", r.Cidr.String(), r.Via.String(), "IF", strconv.Itoa(iface.Index), "METRIC", strconv.Itoa(r.Metric),
|
||||
).Run()
|
||||
return nil
|
||||
}
|
||||
|
||||
func (t *waterTun) reload(c *config.C, initial bool) error {
|
||||
change, routes, err := getAllRoutesFromConfig(c, t.cidr, initial)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if !initial && !change {
|
||||
return nil
|
||||
}
|
||||
|
||||
routeTree, err := makeRouteTree(t.l, routes, false)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Teach nebula how to handle the routes before establishing them in the system table
|
||||
oldRoutes := t.Routes.Swap(&routes)
|
||||
t.routeTree.Store(routeTree)
|
||||
|
||||
if !initial {
|
||||
// Remove first, if the system removes a wanted route hopefully it will be re-added next
|
||||
t.removeRoutes(findRemovedRoutes(routes, *oldRoutes))
|
||||
|
||||
// Ensure any routes we actually want are installed
|
||||
err = t.addRoutes(true)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to add the unsafe_route %s: %v", r.Cidr.String(), err)
|
||||
// Catch any stray logs
|
||||
util.LogWithContextIfNeeded("Failed to set routes", err, t.l)
|
||||
} else {
|
||||
for _, r := range findRemovedRoutes(routes, *oldRoutes) {
|
||||
t.l.WithField("route", r).Info("Removed route")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (t *waterTun) addRoutes(logErrors bool) error {
|
||||
// Path routes
|
||||
routes := *t.Routes.Load()
|
||||
for _, r := range routes {
|
||||
if r.Via == nil || !r.Install {
|
||||
// We don't allow route MTUs so only install routes with a via
|
||||
continue
|
||||
}
|
||||
|
||||
err := exec.Command(
|
||||
"C:\\Windows\\System32\\route.exe", "add", r.Cidr.String(), r.Via.String(), "IF", strconv.Itoa(t.f.Index), "METRIC", strconv.Itoa(r.Metric),
|
||||
).Run()
|
||||
|
||||
if err != nil {
|
||||
retErr := util.NewContextualError("Failed to add route", map[string]interface{}{"route": r}, err)
|
||||
if logErrors {
|
||||
retErr.Log(t.l)
|
||||
} else {
|
||||
return retErr
|
||||
}
|
||||
} else {
|
||||
t.l.WithField("route", r).Info("Added route")
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (t *waterTun) removeRoutes(routes []Route) {
|
||||
for _, r := range routes {
|
||||
if !r.Install {
|
||||
continue
|
||||
}
|
||||
|
||||
err := exec.Command(
|
||||
"C:\\Windows\\System32\\route.exe", "delete", r.Cidr.String(), r.Via.String(), "IF", strconv.Itoa(t.f.Index), "METRIC", strconv.Itoa(r.Metric),
|
||||
).Run()
|
||||
if err != nil {
|
||||
t.l.WithError(err).WithField("route", r).Error("Failed to remove route")
|
||||
} else {
|
||||
t.l.WithField("route", r).Info("Removed route")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (t *waterTun) RouteFor(ip iputil.VpnIp) iputil.VpnIp {
|
||||
_, r := t.routeTree.MostSpecificContains(ip)
|
||||
_, r := t.routeTree.Load().MostSpecificContains(ip)
|
||||
return r
|
||||
}
|
||||
|
||||
|
||||
@ -12,13 +12,14 @@ import (
|
||||
"syscall"
|
||||
|
||||
"github.com/sirupsen/logrus"
|
||||
"github.com/slackhq/nebula/config"
|
||||
)
|
||||
|
||||
func newTunFromFd(_ *logrus.Logger, _ int, _ *net.IPNet, _ int, _ []Route, _ int, _ bool) (Device, error) {
|
||||
func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ *net.IPNet) (Device, error) {
|
||||
return nil, fmt.Errorf("newTunFromFd not supported in Windows")
|
||||
}
|
||||
|
||||
func newTun(l *logrus.Logger, deviceName string, cidr *net.IPNet, defaultMTU int, routes []Route, _ int, _ bool, _ bool) (Device, error) {
|
||||
func newTun(c *config.C, l *logrus.Logger, cidr *net.IPNet, multiqueue bool) (Device, error) {
|
||||
useWintun := true
|
||||
if err := checkWinTunExists(); err != nil {
|
||||
l.WithError(err).Warn("Check Wintun driver failed, fallback to wintap driver")
|
||||
@ -26,14 +27,14 @@ func newTun(l *logrus.Logger, deviceName string, cidr *net.IPNet, defaultMTU int
|
||||
}
|
||||
|
||||
if useWintun {
|
||||
device, err := newWinTun(l, deviceName, cidr, defaultMTU, routes)
|
||||
device, err := newWinTun(c, l, cidr, multiqueue)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("create Wintun interface failed, %w", err)
|
||||
}
|
||||
return device, nil
|
||||
}
|
||||
|
||||
device, err := newWaterTun(l, cidr, defaultMTU, routes)
|
||||
device, err := newWaterTun(c, l, cidr, multiqueue)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("create wintap driver failed, %w", err)
|
||||
}
|
||||
|
||||
@ -6,11 +6,14 @@ import (
|
||||
"io"
|
||||
"net"
|
||||
"net/netip"
|
||||
"sync/atomic"
|
||||
"unsafe"
|
||||
|
||||
"github.com/sirupsen/logrus"
|
||||
"github.com/slackhq/nebula/cidr"
|
||||
"github.com/slackhq/nebula/config"
|
||||
"github.com/slackhq/nebula/iputil"
|
||||
"github.com/slackhq/nebula/util"
|
||||
"github.com/slackhq/nebula/wintun"
|
||||
"golang.org/x/sys/windows"
|
||||
"golang.zx2c4.com/wireguard/windows/tunnel/winipcfg"
|
||||
@ -23,8 +26,9 @@ type winTun struct {
|
||||
cidr *net.IPNet
|
||||
prefix netip.Prefix
|
||||
MTU int
|
||||
Routes []Route
|
||||
routeTree *cidr.Tree4[iputil.VpnIp]
|
||||
Routes atomic.Pointer[[]Route]
|
||||
routeTree atomic.Pointer[cidr.Tree4[iputil.VpnIp]]
|
||||
l *logrus.Logger
|
||||
|
||||
tun *wintun.NativeTun
|
||||
}
|
||||
@ -48,83 +52,148 @@ func generateGUIDByDeviceName(name string) (*windows.GUID, error) {
|
||||
return (*windows.GUID)(unsafe.Pointer(&sum[0])), nil
|
||||
}
|
||||
|
||||
func newWinTun(l *logrus.Logger, deviceName string, cidr *net.IPNet, defaultMTU int, routes []Route) (*winTun, error) {
|
||||
func newWinTun(c *config.C, l *logrus.Logger, cidr *net.IPNet, _ bool) (*winTun, error) {
|
||||
deviceName := c.GetString("tun.dev", "")
|
||||
guid, err := generateGUIDByDeviceName(deviceName)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("generate GUID failed: %w", err)
|
||||
}
|
||||
|
||||
var tunDevice wintun.Device
|
||||
tunDevice, err = wintun.CreateTUNWithRequestedGUID(deviceName, guid, defaultMTU)
|
||||
if err != nil {
|
||||
// Windows 10 has an issue with unclean shutdowns not fully cleaning up the wintun device.
|
||||
// Trying a second time resolves the issue.
|
||||
l.WithError(err).Debug("Failed to create wintun device, retrying")
|
||||
tunDevice, err = wintun.CreateTUNWithRequestedGUID(deviceName, guid, defaultMTU)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("create TUN device failed: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
routeTree, err := makeRouteTree(l, routes, false)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
prefix, err := iputil.ToNetIpPrefix(*cidr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &winTun{
|
||||
Device: deviceName,
|
||||
cidr: cidr,
|
||||
prefix: prefix,
|
||||
MTU: defaultMTU,
|
||||
Routes: routes,
|
||||
routeTree: routeTree,
|
||||
t := &winTun{
|
||||
Device: deviceName,
|
||||
cidr: cidr,
|
||||
prefix: prefix,
|
||||
MTU: c.GetInt("tun.mtu", DefaultMTU),
|
||||
l: l,
|
||||
}
|
||||
|
||||
tun: tunDevice.(*wintun.NativeTun),
|
||||
}, nil
|
||||
err = t.reload(c, true)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var tunDevice wintun.Device
|
||||
tunDevice, err = wintun.CreateTUNWithRequestedGUID(deviceName, guid, t.MTU)
|
||||
if err != nil {
|
||||
// Windows 10 has an issue with unclean shutdowns not fully cleaning up the wintun device.
|
||||
// Trying a second time resolves the issue.
|
||||
l.WithError(err).Debug("Failed to create wintun device, retrying")
|
||||
tunDevice, err = wintun.CreateTUNWithRequestedGUID(deviceName, guid, t.MTU)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("create TUN device failed: %w", err)
|
||||
}
|
||||
}
|
||||
t.tun = tunDevice.(*wintun.NativeTun)
|
||||
|
||||
c.RegisterReloadCallback(func(c *config.C) {
|
||||
err := t.reload(c, false)
|
||||
if err != nil {
|
||||
util.LogWithContextIfNeeded("failed to reload tun device", err, t.l)
|
||||
}
|
||||
})
|
||||
|
||||
return t, nil
|
||||
}
|
||||
|
||||
func (t *winTun) reload(c *config.C, initial bool) error {
|
||||
change, routes, err := getAllRoutesFromConfig(c, t.cidr, initial)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if !initial && !change {
|
||||
return nil
|
||||
}
|
||||
|
||||
routeTree, err := makeRouteTree(t.l, routes, false)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Teach nebula how to handle the routes before establishing them in the system table
|
||||
oldRoutes := t.Routes.Swap(&routes)
|
||||
t.routeTree.Store(routeTree)
|
||||
|
||||
if !initial {
|
||||
// Remove first, if the system removes a wanted route hopefully it will be re-added next
|
||||
err := t.removeRoutes(findRemovedRoutes(routes, *oldRoutes))
|
||||
if err != nil {
|
||||
util.LogWithContextIfNeeded("Failed to remove routes", err, t.l)
|
||||
}
|
||||
|
||||
// Ensure any routes we actually want are installed
|
||||
err = t.addRoutes(true)
|
||||
if err != nil {
|
||||
// Catch any stray logs
|
||||
util.LogWithContextIfNeeded("Failed to add routes", err, t.l)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (t *winTun) Activate() error {
|
||||
luid := winipcfg.LUID(t.tun.LUID())
|
||||
|
||||
if err := luid.SetIPAddresses([]netip.Prefix{t.prefix}); err != nil {
|
||||
err := luid.SetIPAddresses([]netip.Prefix{t.prefix})
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to set address: %w", err)
|
||||
}
|
||||
|
||||
foundDefault4 := false
|
||||
routes := make([]*winipcfg.RouteData, 0, len(t.Routes)+1)
|
||||
err = t.addRoutes(false)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for _, r := range t.Routes {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (t *winTun) addRoutes(logErrors bool) error {
|
||||
luid := winipcfg.LUID(t.tun.LUID())
|
||||
routes := *t.Routes.Load()
|
||||
foundDefault4 := false
|
||||
|
||||
for _, r := range routes {
|
||||
if r.Via == nil || !r.Install {
|
||||
// We don't allow route MTUs so only install routes with a via
|
||||
continue
|
||||
}
|
||||
|
||||
prefix, err := iputil.ToNetIpPrefix(*r.Cidr)
|
||||
if err != nil {
|
||||
retErr := util.NewContextualError("Failed to parse cidr to netip prefix, ignoring route", map[string]interface{}{"route": r}, err)
|
||||
if logErrors {
|
||||
retErr.Log(t.l)
|
||||
continue
|
||||
} else {
|
||||
return retErr
|
||||
}
|
||||
}
|
||||
|
||||
// Add our unsafe route
|
||||
err = luid.AddRoute(prefix, r.Via.ToNetIpAddr(), uint32(r.Metric))
|
||||
if err != nil {
|
||||
retErr := util.NewContextualError("Failed to add route", map[string]interface{}{"route": r}, err)
|
||||
if logErrors {
|
||||
retErr.Log(t.l)
|
||||
continue
|
||||
} else {
|
||||
return retErr
|
||||
}
|
||||
} else {
|
||||
t.l.WithField("route", r).Info("Added route")
|
||||
}
|
||||
|
||||
if !foundDefault4 {
|
||||
if ones, bits := r.Cidr.Mask.Size(); ones == 0 && bits != 0 {
|
||||
foundDefault4 = true
|
||||
}
|
||||
}
|
||||
|
||||
prefix, err := iputil.ToNetIpPrefix(*r.Cidr)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Add our unsafe route
|
||||
routes = append(routes, &winipcfg.RouteData{
|
||||
Destination: prefix,
|
||||
NextHop: r.Via.ToNetIpAddr(),
|
||||
Metric: uint32(r.Metric),
|
||||
})
|
||||
}
|
||||
|
||||
if err := luid.AddRoutes(routes); err != nil {
|
||||
return fmt.Errorf("failed to add routes: %w", err)
|
||||
}
|
||||
|
||||
ipif, err := luid.IPInterface(windows.AF_INET)
|
||||
@ -141,12 +210,35 @@ func (t *winTun) Activate() error {
|
||||
if err := ipif.Set(); err != nil {
|
||||
return fmt.Errorf("failed to set ip interface: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (t *winTun) removeRoutes(routes []Route) error {
|
||||
luid := winipcfg.LUID(t.tun.LUID())
|
||||
|
||||
for _, r := range routes {
|
||||
if !r.Install {
|
||||
continue
|
||||
}
|
||||
|
||||
prefix, err := iputil.ToNetIpPrefix(*r.Cidr)
|
||||
if err != nil {
|
||||
t.l.WithError(err).WithField("route", r).Info("Failed to convert cidr to netip prefix")
|
||||
continue
|
||||
}
|
||||
|
||||
err = luid.DeleteRoute(prefix, r.Via.ToNetIpAddr())
|
||||
if err != nil {
|
||||
t.l.WithError(err).WithField("route", r).Error("Failed to remove route")
|
||||
} else {
|
||||
t.l.WithField("route", r).Info("Removed route")
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (t *winTun) RouteFor(ip iputil.VpnIp) iputil.VpnIp {
|
||||
_, r := t.routeTree.MostSpecificContains(ip)
|
||||
_, r := t.routeTree.Load().MostSpecificContains(ip)
|
||||
return r
|
||||
}
|
||||
|
||||
|
||||
2
ssh.go
2
ssh.go
@ -939,7 +939,7 @@ func sshPrintTunnel(ifce *Interface, fs interface{}, a []string, w sshd.StringWr
|
||||
enc.SetIndent("", " ")
|
||||
}
|
||||
|
||||
return enc.Encode(copyHostInfo(hostInfo, ifce.hostMap.preferredRanges))
|
||||
return enc.Encode(copyHostInfo(hostInfo, ifce.hostMap.GetPreferredRanges()))
|
||||
}
|
||||
|
||||
func sshReload(c *config.C, w sshd.StringWriter) error {
|
||||
|
||||
@ -2,6 +2,7 @@ package util
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
|
||||
"github.com/sirupsen/logrus"
|
||||
)
|
||||
@ -40,7 +41,7 @@ func (ce *ContextualError) Error() string {
|
||||
if ce.RealError == nil {
|
||||
return ce.Context
|
||||
}
|
||||
return ce.RealError.Error()
|
||||
return fmt.Errorf("%s (%v): %w", ce.Context, ce.Fields, ce.RealError).Error()
|
||||
}
|
||||
|
||||
func (ce *ContextualError) Unwrap() error {
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user