mirror of
https://github.com/slackhq/nebula.git
synced 2025-11-22 16:34:25 +01:00
Compare commits
6 Commits
jay.wren-w
...
cross-stac
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
f597aa71e3 | ||
|
|
20b7219fbe | ||
|
|
3b53c27170 | ||
|
|
526236c5fa | ||
|
|
0ab2882b78 | ||
|
|
889d49ff82 |
@@ -29,8 +29,6 @@ type m = map[string]any
|
|||||||
|
|
||||||
// newSimpleServer creates a nebula instance with many assumptions
|
// newSimpleServer creates a nebula instance with many assumptions
|
||||||
func newSimpleServer(v cert.Version, caCrt cert.Certificate, caKey []byte, name string, sVpnNetworks string, overrides m) (*nebula.Control, []netip.Prefix, netip.AddrPort, *config.C) {
|
func newSimpleServer(v cert.Version, caCrt cert.Certificate, caKey []byte, name string, sVpnNetworks string, overrides m) (*nebula.Control, []netip.Prefix, netip.AddrPort, *config.C) {
|
||||||
l := NewTestLogger()
|
|
||||||
|
|
||||||
var vpnNetworks []netip.Prefix
|
var vpnNetworks []netip.Prefix
|
||||||
for _, sn := range strings.Split(sVpnNetworks, ",") {
|
for _, sn := range strings.Split(sVpnNetworks, ",") {
|
||||||
vpnIpNet, err := netip.ParsePrefix(strings.TrimSpace(sn))
|
vpnIpNet, err := netip.ParsePrefix(strings.TrimSpace(sn))
|
||||||
@@ -56,6 +54,25 @@ func newSimpleServer(v cert.Version, caCrt cert.Certificate, caKey []byte, name
|
|||||||
budpIp[3] = 239
|
budpIp[3] = 239
|
||||||
udpAddr = netip.AddrPortFrom(netip.AddrFrom16(budpIp), 4242)
|
udpAddr = netip.AddrPortFrom(netip.AddrFrom16(budpIp), 4242)
|
||||||
}
|
}
|
||||||
|
return newSimpleServerWithUdp(v, caCrt, caKey, name, sVpnNetworks, udpAddr, overrides)
|
||||||
|
}
|
||||||
|
|
||||||
|
func newSimpleServerWithUdp(v cert.Version, caCrt cert.Certificate, caKey []byte, name string, sVpnNetworks string, udpAddr netip.AddrPort, overrides m) (*nebula.Control, []netip.Prefix, netip.AddrPort, *config.C) {
|
||||||
|
l := NewTestLogger()
|
||||||
|
|
||||||
|
var vpnNetworks []netip.Prefix
|
||||||
|
for _, sn := range strings.Split(sVpnNetworks, ",") {
|
||||||
|
vpnIpNet, err := netip.ParsePrefix(strings.TrimSpace(sn))
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
vpnNetworks = append(vpnNetworks, vpnIpNet)
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(vpnNetworks) == 0 {
|
||||||
|
panic("no vpn networks")
|
||||||
|
}
|
||||||
|
|
||||||
_, _, myPrivKey, myPEM := cert_test.NewTestCert(v, cert.Curve_CURVE25519, caCrt, caKey, name, time.Now(), time.Now().Add(5*time.Minute), vpnNetworks, nil, []string{})
|
_, _, myPrivKey, myPEM := cert_test.NewTestCert(v, cert.Curve_CURVE25519, caCrt, caKey, name, time.Now(), time.Now().Add(5*time.Minute), vpnNetworks, nil, []string{})
|
||||||
|
|
||||||
caB, err := caCrt.MarshalPEM()
|
caB, err := caCrt.MarshalPEM()
|
||||||
|
|||||||
@@ -4,6 +4,7 @@
|
|||||||
package e2e
|
package e2e
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"net/netip"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
@@ -55,3 +56,50 @@ func TestDropInactiveTunnels(t *testing.T) {
|
|||||||
myControl.Stop()
|
myControl.Stop()
|
||||||
theirControl.Stop()
|
theirControl.Stop()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestCrossStackRelaysWork(t *testing.T) {
|
||||||
|
ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version2, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
|
||||||
|
myControl, myVpnIpNet, _, _ := newSimpleServer(cert.Version2, ca, caKey, "me ", "10.128.0.1/24,fc00::1/64", m{"relay": m{"use_relays": true}})
|
||||||
|
relayControl, relayVpnIpNet, relayUdpAddr, _ := newSimpleServer(cert.Version2, ca, caKey, "relay ", "10.128.0.128/24,fc00::128/64", m{"relay": m{"am_relay": true}})
|
||||||
|
theirUdp := netip.MustParseAddrPort("10.0.0.2:4242")
|
||||||
|
theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServerWithUdp(cert.Version2, ca, caKey, "them ", "fc00::2/64", theirUdp, m{"relay": m{"use_relays": true}})
|
||||||
|
|
||||||
|
//myVpnV4 := myVpnIpNet[0]
|
||||||
|
myVpnV6 := myVpnIpNet[1]
|
||||||
|
relayVpnV4 := relayVpnIpNet[0]
|
||||||
|
relayVpnV6 := relayVpnIpNet[1]
|
||||||
|
theirVpnV6 := theirVpnIpNet[0]
|
||||||
|
|
||||||
|
// Teach my how to get to the relay and that their can be reached via the relay
|
||||||
|
myControl.InjectLightHouseAddr(relayVpnV4.Addr(), relayUdpAddr)
|
||||||
|
myControl.InjectLightHouseAddr(relayVpnV6.Addr(), relayUdpAddr)
|
||||||
|
myControl.InjectRelays(theirVpnV6.Addr(), []netip.Addr{relayVpnV6.Addr()})
|
||||||
|
relayControl.InjectLightHouseAddr(theirVpnV6.Addr(), theirUdpAddr)
|
||||||
|
|
||||||
|
// Build a router so we don't have to reason who gets which packet
|
||||||
|
r := router.NewR(t, myControl, relayControl, theirControl)
|
||||||
|
defer r.RenderFlow()
|
||||||
|
|
||||||
|
// Start the servers
|
||||||
|
myControl.Start()
|
||||||
|
relayControl.Start()
|
||||||
|
theirControl.Start()
|
||||||
|
|
||||||
|
t.Log("Trigger a handshake from me to them via the relay")
|
||||||
|
myControl.InjectTunUDPPacket(theirVpnV6.Addr(), 80, myVpnV6.Addr(), 80, []byte("Hi from me"))
|
||||||
|
|
||||||
|
p := r.RouteForAllUntilTxTun(theirControl)
|
||||||
|
r.Log("Assert the tunnel works")
|
||||||
|
assertUdpPacket(t, []byte("Hi from me"), p, myVpnV6.Addr(), theirVpnV6.Addr(), 80, 80)
|
||||||
|
|
||||||
|
t.Log("reply?")
|
||||||
|
theirControl.InjectTunUDPPacket(myVpnV6.Addr(), 80, theirVpnV6.Addr(), 80, []byte("Hi from them"))
|
||||||
|
p = r.RouteForAllUntilTxTun(myControl)
|
||||||
|
assertUdpPacket(t, []byte("Hi from them"), p, theirVpnV6.Addr(), myVpnV6.Addr(), 80, 80)
|
||||||
|
|
||||||
|
r.RenderHostmaps("Final hostmaps", myControl, relayControl, theirControl)
|
||||||
|
//t.Log("finish up")
|
||||||
|
//myControl.Stop()
|
||||||
|
//theirControl.Stop()
|
||||||
|
//relayControl.Stop()
|
||||||
|
}
|
||||||
|
|||||||
29
firewall.go
29
firewall.go
@@ -417,6 +417,8 @@ func AddFirewallRulesFromConfig(l *logrus.Logger, inbound bool, c *config.C, fw
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
var ErrUnknownNetworkType = errors.New("unknown network type")
|
||||||
|
var ErrPeerRejected = errors.New("remote IP is not within a subnet that we handle")
|
||||||
var ErrInvalidRemoteIP = errors.New("remote IP is not in remote certificate subnets")
|
var ErrInvalidRemoteIP = errors.New("remote IP is not in remote certificate subnets")
|
||||||
var ErrInvalidLocalIP = errors.New("local IP is not in list of handled local IPs")
|
var ErrInvalidLocalIP = errors.New("local IP is not in list of handled local IPs")
|
||||||
var ErrNoMatchingRule = errors.New("no matching rule in firewall table")
|
var ErrNoMatchingRule = errors.New("no matching rule in firewall table")
|
||||||
@@ -429,18 +431,31 @@ func (f *Firewall) Drop(fp firewall.Packet, incoming bool, h *HostInfo, caPool *
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Make sure remote address matches nebula certificate
|
// Make sure remote address matches nebula certificate, and determine how to treat it
|
||||||
if h.networks != nil {
|
if h.networks == nil {
|
||||||
if !h.networks.Contains(fp.RemoteAddr) {
|
|
||||||
f.metrics(incoming).droppedRemoteAddr.Inc(1)
|
|
||||||
return ErrInvalidRemoteIP
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
// Simple case: Certificate has one address and no unsafe networks
|
// Simple case: Certificate has one address and no unsafe networks
|
||||||
if h.vpnAddrs[0] != fp.RemoteAddr {
|
if h.vpnAddrs[0] != fp.RemoteAddr {
|
||||||
f.metrics(incoming).droppedRemoteAddr.Inc(1)
|
f.metrics(incoming).droppedRemoteAddr.Inc(1)
|
||||||
return ErrInvalidRemoteIP
|
return ErrInvalidRemoteIP
|
||||||
}
|
}
|
||||||
|
} else {
|
||||||
|
nwType, ok := h.networks.Lookup(fp.RemoteAddr)
|
||||||
|
if !ok {
|
||||||
|
f.metrics(incoming).droppedRemoteAddr.Inc(1)
|
||||||
|
return ErrInvalidRemoteIP
|
||||||
|
}
|
||||||
|
switch nwType {
|
||||||
|
case NetworkTypeVPN:
|
||||||
|
break // nothing special
|
||||||
|
case NetworkTypeVPNPeer:
|
||||||
|
f.metrics(incoming).droppedRemoteAddr.Inc(1)
|
||||||
|
return ErrPeerRejected // reject for now, one day this may have different FW rules
|
||||||
|
case NetworkTypeUnsafe:
|
||||||
|
break // nothing special, one day this may have different FW rules
|
||||||
|
default:
|
||||||
|
f.metrics(incoming).droppedRemoteAddr.Inc(1)
|
||||||
|
return ErrUnknownNetworkType //should never happen
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Make sure we are supposed to be handling this local ip address
|
// Make sure we are supposed to be handling this local ip address
|
||||||
|
|||||||
@@ -8,6 +8,7 @@ import (
|
|||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/gaissmai/bart"
|
||||||
"github.com/slackhq/nebula/cert"
|
"github.com/slackhq/nebula/cert"
|
||||||
"github.com/slackhq/nebula/config"
|
"github.com/slackhq/nebula/config"
|
||||||
"github.com/slackhq/nebula/firewall"
|
"github.com/slackhq/nebula/firewall"
|
||||||
@@ -149,7 +150,8 @@ func TestFirewall_Drop(t *testing.T) {
|
|||||||
l := test.NewLogger()
|
l := test.NewLogger()
|
||||||
ob := &bytes.Buffer{}
|
ob := &bytes.Buffer{}
|
||||||
l.SetOutput(ob)
|
l.SetOutput(ob)
|
||||||
|
myVpnNetworksTable := new(bart.Lite)
|
||||||
|
myVpnNetworksTable.Insert(netip.MustParsePrefix("1.1.1.1/8"))
|
||||||
p := firewall.Packet{
|
p := firewall.Packet{
|
||||||
LocalAddr: netip.MustParseAddr("1.2.3.4"),
|
LocalAddr: netip.MustParseAddr("1.2.3.4"),
|
||||||
RemoteAddr: netip.MustParseAddr("1.2.3.4"),
|
RemoteAddr: netip.MustParseAddr("1.2.3.4"),
|
||||||
@@ -174,7 +176,7 @@ func TestFirewall_Drop(t *testing.T) {
|
|||||||
},
|
},
|
||||||
vpnAddrs: []netip.Addr{netip.MustParseAddr("1.2.3.4")},
|
vpnAddrs: []netip.Addr{netip.MustParseAddr("1.2.3.4")},
|
||||||
}
|
}
|
||||||
h.buildNetworks(c.networks, c.unsafeNetworks)
|
h.buildNetworks(myVpnNetworksTable, c.networks, c.unsafeNetworks)
|
||||||
|
|
||||||
fw := NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
|
fw := NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
|
||||||
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"any"}, "", netip.Prefix{}, netip.Prefix{}, "", ""))
|
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"any"}, "", netip.Prefix{}, netip.Prefix{}, "", ""))
|
||||||
@@ -226,6 +228,9 @@ func TestFirewall_DropV6(t *testing.T) {
|
|||||||
ob := &bytes.Buffer{}
|
ob := &bytes.Buffer{}
|
||||||
l.SetOutput(ob)
|
l.SetOutput(ob)
|
||||||
|
|
||||||
|
myVpnNetworksTable := new(bart.Lite)
|
||||||
|
myVpnNetworksTable.Insert(netip.MustParsePrefix("fd00::/7"))
|
||||||
|
|
||||||
p := firewall.Packet{
|
p := firewall.Packet{
|
||||||
LocalAddr: netip.MustParseAddr("fd12::34"),
|
LocalAddr: netip.MustParseAddr("fd12::34"),
|
||||||
RemoteAddr: netip.MustParseAddr("fd12::34"),
|
RemoteAddr: netip.MustParseAddr("fd12::34"),
|
||||||
@@ -250,7 +255,7 @@ func TestFirewall_DropV6(t *testing.T) {
|
|||||||
},
|
},
|
||||||
vpnAddrs: []netip.Addr{netip.MustParseAddr("fd12::34")},
|
vpnAddrs: []netip.Addr{netip.MustParseAddr("fd12::34")},
|
||||||
}
|
}
|
||||||
h.buildNetworks(c.networks, c.unsafeNetworks)
|
h.buildNetworks(myVpnNetworksTable, c.networks, c.unsafeNetworks)
|
||||||
|
|
||||||
fw := NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
|
fw := NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
|
||||||
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"any"}, "", netip.Prefix{}, netip.Prefix{}, "", ""))
|
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"any"}, "", netip.Prefix{}, netip.Prefix{}, "", ""))
|
||||||
@@ -453,6 +458,8 @@ func TestFirewall_Drop2(t *testing.T) {
|
|||||||
l := test.NewLogger()
|
l := test.NewLogger()
|
||||||
ob := &bytes.Buffer{}
|
ob := &bytes.Buffer{}
|
||||||
l.SetOutput(ob)
|
l.SetOutput(ob)
|
||||||
|
myVpnNetworksTable := new(bart.Lite)
|
||||||
|
myVpnNetworksTable.Insert(netip.MustParsePrefix("1.1.1.1/8"))
|
||||||
|
|
||||||
p := firewall.Packet{
|
p := firewall.Packet{
|
||||||
LocalAddr: netip.MustParseAddr("1.2.3.4"),
|
LocalAddr: netip.MustParseAddr("1.2.3.4"),
|
||||||
@@ -478,7 +485,7 @@ func TestFirewall_Drop2(t *testing.T) {
|
|||||||
},
|
},
|
||||||
vpnAddrs: []netip.Addr{network.Addr()},
|
vpnAddrs: []netip.Addr{network.Addr()},
|
||||||
}
|
}
|
||||||
h.buildNetworks(c.Certificate.Networks(), c.Certificate.UnsafeNetworks())
|
h.buildNetworks(myVpnNetworksTable, c.Certificate.Networks(), c.Certificate.UnsafeNetworks())
|
||||||
|
|
||||||
c1 := cert.CachedCertificate{
|
c1 := cert.CachedCertificate{
|
||||||
Certificate: &dummyCert{
|
Certificate: &dummyCert{
|
||||||
@@ -493,7 +500,7 @@ func TestFirewall_Drop2(t *testing.T) {
|
|||||||
peerCert: &c1,
|
peerCert: &c1,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
h1.buildNetworks(c1.Certificate.Networks(), c1.Certificate.UnsafeNetworks())
|
h1.buildNetworks(myVpnNetworksTable, c1.Certificate.Networks(), c1.Certificate.UnsafeNetworks())
|
||||||
|
|
||||||
fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate)
|
fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate)
|
||||||
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group", "test-group"}, "", netip.Prefix{}, netip.Prefix{}, "", ""))
|
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group", "test-group"}, "", netip.Prefix{}, netip.Prefix{}, "", ""))
|
||||||
@@ -510,6 +517,8 @@ func TestFirewall_Drop3(t *testing.T) {
|
|||||||
l := test.NewLogger()
|
l := test.NewLogger()
|
||||||
ob := &bytes.Buffer{}
|
ob := &bytes.Buffer{}
|
||||||
l.SetOutput(ob)
|
l.SetOutput(ob)
|
||||||
|
myVpnNetworksTable := new(bart.Lite)
|
||||||
|
myVpnNetworksTable.Insert(netip.MustParsePrefix("1.1.1.1/8"))
|
||||||
|
|
||||||
p := firewall.Packet{
|
p := firewall.Packet{
|
||||||
LocalAddr: netip.MustParseAddr("1.2.3.4"),
|
LocalAddr: netip.MustParseAddr("1.2.3.4"),
|
||||||
@@ -541,7 +550,7 @@ func TestFirewall_Drop3(t *testing.T) {
|
|||||||
},
|
},
|
||||||
vpnAddrs: []netip.Addr{network.Addr()},
|
vpnAddrs: []netip.Addr{network.Addr()},
|
||||||
}
|
}
|
||||||
h1.buildNetworks(c1.Certificate.Networks(), c1.Certificate.UnsafeNetworks())
|
h1.buildNetworks(myVpnNetworksTable, c1.Certificate.Networks(), c1.Certificate.UnsafeNetworks())
|
||||||
|
|
||||||
c2 := cert.CachedCertificate{
|
c2 := cert.CachedCertificate{
|
||||||
Certificate: &dummyCert{
|
Certificate: &dummyCert{
|
||||||
@@ -556,7 +565,7 @@ func TestFirewall_Drop3(t *testing.T) {
|
|||||||
},
|
},
|
||||||
vpnAddrs: []netip.Addr{network.Addr()},
|
vpnAddrs: []netip.Addr{network.Addr()},
|
||||||
}
|
}
|
||||||
h2.buildNetworks(c2.Certificate.Networks(), c2.Certificate.UnsafeNetworks())
|
h2.buildNetworks(myVpnNetworksTable, c2.Certificate.Networks(), c2.Certificate.UnsafeNetworks())
|
||||||
|
|
||||||
c3 := cert.CachedCertificate{
|
c3 := cert.CachedCertificate{
|
||||||
Certificate: &dummyCert{
|
Certificate: &dummyCert{
|
||||||
@@ -571,7 +580,7 @@ func TestFirewall_Drop3(t *testing.T) {
|
|||||||
},
|
},
|
||||||
vpnAddrs: []netip.Addr{network.Addr()},
|
vpnAddrs: []netip.Addr{network.Addr()},
|
||||||
}
|
}
|
||||||
h3.buildNetworks(c3.Certificate.Networks(), c3.Certificate.UnsafeNetworks())
|
h3.buildNetworks(myVpnNetworksTable, c3.Certificate.Networks(), c3.Certificate.UnsafeNetworks())
|
||||||
|
|
||||||
fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate)
|
fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate)
|
||||||
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 1, 1, []string{}, "host1", netip.Prefix{}, netip.Prefix{}, "", ""))
|
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 1, 1, []string{}, "host1", netip.Prefix{}, netip.Prefix{}, "", ""))
|
||||||
@@ -597,6 +606,8 @@ func TestFirewall_Drop3V6(t *testing.T) {
|
|||||||
l := test.NewLogger()
|
l := test.NewLogger()
|
||||||
ob := &bytes.Buffer{}
|
ob := &bytes.Buffer{}
|
||||||
l.SetOutput(ob)
|
l.SetOutput(ob)
|
||||||
|
myVpnNetworksTable := new(bart.Lite)
|
||||||
|
myVpnNetworksTable.Insert(netip.MustParsePrefix("fd00::/7"))
|
||||||
|
|
||||||
p := firewall.Packet{
|
p := firewall.Packet{
|
||||||
LocalAddr: netip.MustParseAddr("fd12::34"),
|
LocalAddr: netip.MustParseAddr("fd12::34"),
|
||||||
@@ -620,7 +631,7 @@ func TestFirewall_Drop3V6(t *testing.T) {
|
|||||||
},
|
},
|
||||||
vpnAddrs: []netip.Addr{network.Addr()},
|
vpnAddrs: []netip.Addr{network.Addr()},
|
||||||
}
|
}
|
||||||
h.buildNetworks(c.Certificate.Networks(), c.Certificate.UnsafeNetworks())
|
h.buildNetworks(myVpnNetworksTable, c.Certificate.Networks(), c.Certificate.UnsafeNetworks())
|
||||||
|
|
||||||
// Test a remote address match
|
// Test a remote address match
|
||||||
fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate)
|
fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate)
|
||||||
@@ -633,6 +644,8 @@ func TestFirewall_DropConntrackReload(t *testing.T) {
|
|||||||
l := test.NewLogger()
|
l := test.NewLogger()
|
||||||
ob := &bytes.Buffer{}
|
ob := &bytes.Buffer{}
|
||||||
l.SetOutput(ob)
|
l.SetOutput(ob)
|
||||||
|
myVpnNetworksTable := new(bart.Lite)
|
||||||
|
myVpnNetworksTable.Insert(netip.MustParsePrefix("1.1.1.1/8"))
|
||||||
|
|
||||||
p := firewall.Packet{
|
p := firewall.Packet{
|
||||||
LocalAddr: netip.MustParseAddr("1.2.3.4"),
|
LocalAddr: netip.MustParseAddr("1.2.3.4"),
|
||||||
@@ -659,7 +672,7 @@ func TestFirewall_DropConntrackReload(t *testing.T) {
|
|||||||
},
|
},
|
||||||
vpnAddrs: []netip.Addr{network.Addr()},
|
vpnAddrs: []netip.Addr{network.Addr()},
|
||||||
}
|
}
|
||||||
h.buildNetworks(c.Certificate.Networks(), c.Certificate.UnsafeNetworks())
|
h.buildNetworks(myVpnNetworksTable, c.Certificate.Networks(), c.Certificate.UnsafeNetworks())
|
||||||
|
|
||||||
fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate)
|
fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate)
|
||||||
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"any"}, "", netip.Prefix{}, netip.Prefix{}, "", ""))
|
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"any"}, "", netip.Prefix{}, netip.Prefix{}, "", ""))
|
||||||
@@ -696,6 +709,8 @@ func TestFirewall_DropIPSpoofing(t *testing.T) {
|
|||||||
l := test.NewLogger()
|
l := test.NewLogger()
|
||||||
ob := &bytes.Buffer{}
|
ob := &bytes.Buffer{}
|
||||||
l.SetOutput(ob)
|
l.SetOutput(ob)
|
||||||
|
myVpnNetworksTable := new(bart.Lite)
|
||||||
|
myVpnNetworksTable.Insert(netip.MustParsePrefix("192.0.2.1/24"))
|
||||||
|
|
||||||
c := cert.CachedCertificate{
|
c := cert.CachedCertificate{
|
||||||
Certificate: &dummyCert{
|
Certificate: &dummyCert{
|
||||||
@@ -717,7 +732,7 @@ func TestFirewall_DropIPSpoofing(t *testing.T) {
|
|||||||
},
|
},
|
||||||
vpnAddrs: []netip.Addr{c1.Certificate.Networks()[0].Addr()},
|
vpnAddrs: []netip.Addr{c1.Certificate.Networks()[0].Addr()},
|
||||||
}
|
}
|
||||||
h1.buildNetworks(c1.Certificate.Networks(), c1.Certificate.UnsafeNetworks())
|
h1.buildNetworks(myVpnNetworksTable, c1.Certificate.Networks(), c1.Certificate.UnsafeNetworks())
|
||||||
|
|
||||||
fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate)
|
fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate)
|
||||||
|
|
||||||
|
|||||||
6
go.mod
6
go.mod
@@ -29,11 +29,11 @@ require (
|
|||||||
golang.org/x/sys v0.37.0
|
golang.org/x/sys v0.37.0
|
||||||
golang.org/x/term v0.36.0
|
golang.org/x/term v0.36.0
|
||||||
golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2
|
golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2
|
||||||
golang.zx2c4.com/wireguard v0.0.0-20250521234502-f333402bd9cb
|
golang.zx2c4.com/wireguard v0.0.0-20230325221338-052af4a8072b
|
||||||
golang.zx2c4.com/wireguard/windows v0.5.3
|
golang.zx2c4.com/wireguard/windows v0.5.3
|
||||||
google.golang.org/protobuf v1.36.8
|
google.golang.org/protobuf v1.36.8
|
||||||
gopkg.in/yaml.v3 v3.0.1
|
gopkg.in/yaml.v3 v3.0.1
|
||||||
gvisor.dev/gvisor v0.0.0-20250503011706-39ed1f5ac29c
|
gvisor.dev/gvisor v0.0.0-20240423190808-9d7a357edefe
|
||||||
)
|
)
|
||||||
|
|
||||||
require (
|
require (
|
||||||
@@ -49,6 +49,6 @@ require (
|
|||||||
github.com/vishvananda/netns v0.0.5 // indirect
|
github.com/vishvananda/netns v0.0.5 // indirect
|
||||||
go.yaml.in/yaml/v2 v2.4.2 // indirect
|
go.yaml.in/yaml/v2 v2.4.2 // indirect
|
||||||
golang.org/x/mod v0.24.0 // indirect
|
golang.org/x/mod v0.24.0 // indirect
|
||||||
golang.org/x/time v0.7.0 // indirect
|
golang.org/x/time v0.5.0 // indirect
|
||||||
golang.org/x/tools v0.33.0 // indirect
|
golang.org/x/tools v0.33.0 // indirect
|
||||||
)
|
)
|
||||||
|
|||||||
12
go.sum
12
go.sum
@@ -215,8 +215,8 @@ golang.org/x/term v0.36.0/go.mod h1:Qu394IJq6V6dCBRgwqshf3mPF85AqzYEzofzRdZkWss=
|
|||||||
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
|
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
|
||||||
golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk=
|
golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk=
|
||||||
golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
|
golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
|
||||||
golang.org/x/time v0.7.0 h1:ntUhktv3OPE6TgYxXWv9vKvUSJyIFJlyohwbkEwPrKQ=
|
golang.org/x/time v0.5.0 h1:o7cqy6amK/52YcAKIPlM3a+Fpj35zvRj2TP+e1xFSfk=
|
||||||
golang.org/x/time v0.7.0/go.mod h1:3BpzKBy/shNhVucY/MWOyx10tF3SFh9QdLuxbVysPQM=
|
golang.org/x/time v0.5.0/go.mod h1:3BpzKBy/shNhVucY/MWOyx10tF3SFh9QdLuxbVysPQM=
|
||||||
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
|
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
|
||||||
golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo=
|
golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo=
|
||||||
golang.org/x/tools v0.0.0-20200130002326-2f3ba24bd6e7/go.mod h1:TB2adYChydJhpapKDTa4BR/hXlZSLoq2Wpct/0txZ28=
|
golang.org/x/tools v0.0.0-20200130002326-2f3ba24bd6e7/go.mod h1:TB2adYChydJhpapKDTa4BR/hXlZSLoq2Wpct/0txZ28=
|
||||||
@@ -230,8 +230,8 @@ golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8T
|
|||||||
golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
|
golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
|
||||||
golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2 h1:B82qJJgjvYKsXS9jeunTOisW56dUokqW/FOteYJJ/yg=
|
golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2 h1:B82qJJgjvYKsXS9jeunTOisW56dUokqW/FOteYJJ/yg=
|
||||||
golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2/go.mod h1:deeaetjYA+DHMHg+sMSMI58GrEteJUUzzw7en6TJQcI=
|
golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2/go.mod h1:deeaetjYA+DHMHg+sMSMI58GrEteJUUzzw7en6TJQcI=
|
||||||
golang.zx2c4.com/wireguard v0.0.0-20250521234502-f333402bd9cb h1:whnFRlWMcXI9d+ZbWg+4sHnLp52d5yiIPUxMBSt4X9A=
|
golang.zx2c4.com/wireguard v0.0.0-20230325221338-052af4a8072b h1:J1CaxgLerRR5lgx3wnr6L04cJFbWoceSK9JWBdglINo=
|
||||||
golang.zx2c4.com/wireguard v0.0.0-20250521234502-f333402bd9cb/go.mod h1:rpwXGsirqLqN2L0JDJQlwOboGHmptD5ZD6T2VmcqhTw=
|
golang.zx2c4.com/wireguard v0.0.0-20230325221338-052af4a8072b/go.mod h1:tqur9LnfstdR9ep2LaJT4lFUl0EjlHtge+gAjmsHUG4=
|
||||||
golang.zx2c4.com/wireguard/windows v0.5.3 h1:On6j2Rpn3OEMXqBq00QEDC7bWSZrPIHKIus8eIuExIE=
|
golang.zx2c4.com/wireguard/windows v0.5.3 h1:On6j2Rpn3OEMXqBq00QEDC7bWSZrPIHKIus8eIuExIE=
|
||||||
golang.zx2c4.com/wireguard/windows v0.5.3/go.mod h1:9TEe8TJmtwyQebdFwAkEWOPr3prrtqm+REGFifP60hI=
|
golang.zx2c4.com/wireguard/windows v0.5.3/go.mod h1:9TEe8TJmtwyQebdFwAkEWOPr3prrtqm+REGFifP60hI=
|
||||||
google.golang.org/appengine v1.4.0/go.mod h1:xpcJRLb0r/rnEns0DIKYYv+WjYCduHsrkT7/EB5XEv4=
|
google.golang.org/appengine v1.4.0/go.mod h1:xpcJRLb0r/rnEns0DIKYYv+WjYCduHsrkT7/EB5XEv4=
|
||||||
@@ -257,5 +257,5 @@ gopkg.in/yaml.v2 v2.3.0/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
|
|||||||
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
|
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
|
||||||
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
|
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
|
||||||
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
|
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
|
||||||
gvisor.dev/gvisor v0.0.0-20250503011706-39ed1f5ac29c h1:m/r7OM+Y2Ty1sgBQ7Qb27VgIMBW8ZZhT4gLnUyDIhzI=
|
gvisor.dev/gvisor v0.0.0-20240423190808-9d7a357edefe h1:fre4i6mv4iBuz5lCMOzHD1rH1ljqHWSICFmZRbbgp3g=
|
||||||
gvisor.dev/gvisor v0.0.0-20250503011706-39ed1f5ac29c/go.mod h1:3r5CMtNQMKIvBlrmM9xWUNamjKBYPOWyXOjmg5Kts3g=
|
gvisor.dev/gvisor v0.0.0-20240423190808-9d7a357edefe/go.mod h1:sxc3Uvk/vHcd3tj7/DHVBoR5wvWT/MmRq2pj7HRJnwU=
|
||||||
|
|||||||
108
handshake_ix.go
108
handshake_ix.go
@@ -183,17 +183,18 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
var vpnAddrs []netip.Addr
|
|
||||||
var filteredNetworks []netip.Prefix
|
|
||||||
certName := remoteCert.Certificate.Name()
|
certName := remoteCert.Certificate.Name()
|
||||||
certVersion := remoteCert.Certificate.Version()
|
certVersion := remoteCert.Certificate.Version()
|
||||||
fingerprint := remoteCert.Fingerprint
|
fingerprint := remoteCert.Fingerprint
|
||||||
issuer := remoteCert.Certificate.Issuer()
|
issuer := remoteCert.Certificate.Issuer()
|
||||||
|
vpnNetworks := remoteCert.Certificate.Networks()
|
||||||
|
|
||||||
for _, network := range remoteCert.Certificate.Networks() {
|
anyVpnAddrsInCommon := false
|
||||||
|
vpnAddrs := make([]netip.Addr, len(vpnNetworks))
|
||||||
|
for i, network := range vpnNetworks {
|
||||||
vpnAddr := network.Addr()
|
vpnAddr := network.Addr()
|
||||||
if f.myVpnAddrsTable.Contains(vpnAddr) {
|
if f.myVpnAddrsTable.Contains(vpnAddr) {
|
||||||
f.l.WithField("vpnAddr", vpnAddr).WithField("udpAddr", addr).
|
f.l.WithField("vpnNetworks", vpnNetworks).WithField("udpAddr", addr).
|
||||||
WithField("certName", certName).
|
WithField("certName", certName).
|
||||||
WithField("certVersion", certVersion).
|
WithField("certVersion", certVersion).
|
||||||
WithField("fingerprint", fingerprint).
|
WithField("fingerprint", fingerprint).
|
||||||
@@ -201,24 +202,10 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet
|
|||||||
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).Error("Refusing to handshake with myself")
|
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).Error("Refusing to handshake with myself")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
vpnAddrs[i] = network.Addr()
|
||||||
// vpnAddrs outside our vpn networks are of no use to us, filter them out
|
if f.myVpnNetworksTable.Contains(vpnAddr) {
|
||||||
if !f.myVpnNetworksTable.Contains(vpnAddr) {
|
anyVpnAddrsInCommon = true
|
||||||
continue
|
|
||||||
}
|
}
|
||||||
|
|
||||||
filteredNetworks = append(filteredNetworks, network)
|
|
||||||
vpnAddrs = append(vpnAddrs, vpnAddr)
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(vpnAddrs) == 0 {
|
|
||||||
f.l.WithError(err).WithField("udpAddr", addr).
|
|
||||||
WithField("certName", certName).
|
|
||||||
WithField("certVersion", certVersion).
|
|
||||||
WithField("fingerprint", fingerprint).
|
|
||||||
WithField("issuer", issuer).
|
|
||||||
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).Error("No usable vpn addresses from host, refusing handshake")
|
|
||||||
return
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if addr.IsValid() {
|
if addr.IsValid() {
|
||||||
@@ -255,26 +242,30 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
f.l.WithField("vpnAddrs", vpnAddrs).WithField("udpAddr", addr).
|
msgRxL := f.l.WithFields(m{
|
||||||
WithField("certName", certName).
|
"vpnAddrs": vpnAddrs,
|
||||||
WithField("certVersion", certVersion).
|
"udpAddr": addr,
|
||||||
WithField("fingerprint", fingerprint).
|
"certName": certName,
|
||||||
WithField("issuer", issuer).
|
"certVersion": certVersion,
|
||||||
WithField("initiatorIndex", hs.Details.InitiatorIndex).WithField("responderIndex", hs.Details.ResponderIndex).
|
"fingerprint": fingerprint,
|
||||||
WithField("remoteIndex", h.RemoteIndex).WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).
|
"issuer": issuer,
|
||||||
Info("Handshake message received")
|
"initiatorIndex": hs.Details.InitiatorIndex,
|
||||||
|
"responderIndex": hs.Details.ResponderIndex,
|
||||||
|
"remoteIndex": h.RemoteIndex,
|
||||||
|
"handshake": m{"stage": 1, "style": "ix_psk0"},
|
||||||
|
})
|
||||||
|
|
||||||
|
if anyVpnAddrsInCommon {
|
||||||
|
msgRxL.Info("Handshake message received")
|
||||||
|
} else {
|
||||||
|
//todo warn if not lighthouse or relay?
|
||||||
|
msgRxL.Info("Handshake message received, but no vpnNetworks in common.")
|
||||||
|
}
|
||||||
|
|
||||||
hs.Details.ResponderIndex = myIndex
|
hs.Details.ResponderIndex = myIndex
|
||||||
hs.Details.Cert = cs.getHandshakeBytes(ci.myCert.Version())
|
hs.Details.Cert = cs.getHandshakeBytes(ci.myCert.Version())
|
||||||
if hs.Details.Cert == nil {
|
if hs.Details.Cert == nil {
|
||||||
f.l.WithField("vpnAddrs", vpnAddrs).WithField("udpAddr", addr).
|
msgRxL.WithField("myCertVersion", ci.myCert.Version()).
|
||||||
WithField("certName", certName).
|
|
||||||
WithField("certVersion", certVersion).
|
|
||||||
WithField("fingerprint", fingerprint).
|
|
||||||
WithField("issuer", issuer).
|
|
||||||
WithField("initiatorIndex", hs.Details.InitiatorIndex).WithField("responderIndex", hs.Details.ResponderIndex).
|
|
||||||
WithField("remoteIndex", h.RemoteIndex).WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).
|
|
||||||
WithField("certVersion", ci.myCert.Version()).
|
|
||||||
Error("Unable to handshake with host because no certificate handshake bytes is available")
|
Error("Unable to handshake with host because no certificate handshake bytes is available")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -332,7 +323,7 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet
|
|||||||
|
|
||||||
hostinfo.remotes = f.lightHouse.QueryCache(vpnAddrs)
|
hostinfo.remotes = f.lightHouse.QueryCache(vpnAddrs)
|
||||||
hostinfo.SetRemote(addr)
|
hostinfo.SetRemote(addr)
|
||||||
hostinfo.buildNetworks(filteredNetworks, remoteCert.Certificate.UnsafeNetworks())
|
hostinfo.buildNetworks(f.myVpnNetworksTable, remoteCert.Certificate.Networks(), remoteCert.Certificate.UnsafeNetworks())
|
||||||
|
|
||||||
existing, err := f.handshakeManager.CheckAndComplete(hostinfo, 0, f)
|
existing, err := f.handshakeManager.CheckAndComplete(hostinfo, 0, f)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -573,30 +564,17 @@ func ixHandshakeStage2(f *Interface, addr netip.AddrPort, via *ViaSender, hh *Ha
|
|||||||
hostinfo.relayState.InsertRelayTo(via.relayHI.vpnAddrs[0])
|
hostinfo.relayState.InsertRelayTo(via.relayHI.vpnAddrs[0])
|
||||||
}
|
}
|
||||||
|
|
||||||
var vpnAddrs []netip.Addr
|
anyVpnAddrsInCommon := false
|
||||||
var filteredNetworks []netip.Prefix
|
vpnAddrs := make([]netip.Addr, len(vpnNetworks))
|
||||||
for _, network := range vpnNetworks {
|
for i, network := range vpnNetworks {
|
||||||
// vpnAddrs outside our vpn networks are of no use to us, filter them out
|
vpnAddrs[i] = network.Addr()
|
||||||
vpnAddr := network.Addr()
|
if f.myVpnNetworksTable.Contains(network.Addr()) {
|
||||||
if !f.myVpnNetworksTable.Contains(vpnAddr) {
|
anyVpnAddrsInCommon = true
|
||||||
continue
|
|
||||||
}
|
}
|
||||||
|
|
||||||
filteredNetworks = append(filteredNetworks, network)
|
|
||||||
vpnAddrs = append(vpnAddrs, vpnAddr)
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(vpnAddrs) == 0 {
|
|
||||||
f.l.WithError(err).WithField("udpAddr", addr).
|
|
||||||
WithField("certName", certName).
|
|
||||||
WithField("certVersion", certVersion).
|
|
||||||
WithField("fingerprint", fingerprint).
|
|
||||||
WithField("issuer", issuer).
|
|
||||||
WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).Error("No usable vpn addresses from host, refusing handshake")
|
|
||||||
return true
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Ensure the right host responded
|
// Ensure the right host responded
|
||||||
|
// todo is it more correct to see if any of hostinfo.vpnAddrs are in the cert? it should have len==1, but one day it might not?
|
||||||
if !slices.Contains(vpnAddrs, hostinfo.vpnAddrs[0]) {
|
if !slices.Contains(vpnAddrs, hostinfo.vpnAddrs[0]) {
|
||||||
f.l.WithField("intendedVpnAddrs", hostinfo.vpnAddrs).WithField("haveVpnNetworks", vpnNetworks).
|
f.l.WithField("intendedVpnAddrs", hostinfo.vpnAddrs).WithField("haveVpnNetworks", vpnNetworks).
|
||||||
WithField("udpAddr", addr).
|
WithField("udpAddr", addr).
|
||||||
@@ -609,6 +587,7 @@ func ixHandshakeStage2(f *Interface, addr netip.AddrPort, via *ViaSender, hh *Ha
|
|||||||
f.handshakeManager.DeleteHostInfo(hostinfo)
|
f.handshakeManager.DeleteHostInfo(hostinfo)
|
||||||
|
|
||||||
// Create a new hostinfo/handshake for the intended vpn ip
|
// Create a new hostinfo/handshake for the intended vpn ip
|
||||||
|
//TODO is hostinfo.vpnAddrs[0] always the address to use?
|
||||||
f.handshakeManager.StartHandshake(hostinfo.vpnAddrs[0], func(newHH *HandshakeHostInfo) {
|
f.handshakeManager.StartHandshake(hostinfo.vpnAddrs[0], func(newHH *HandshakeHostInfo) {
|
||||||
// Block the current used address
|
// Block the current used address
|
||||||
newHH.hostinfo.remotes = hostinfo.remotes
|
newHH.hostinfo.remotes = hostinfo.remotes
|
||||||
@@ -635,7 +614,7 @@ func ixHandshakeStage2(f *Interface, addr netip.AddrPort, via *ViaSender, hh *Ha
|
|||||||
ci.window.Update(f.l, 2)
|
ci.window.Update(f.l, 2)
|
||||||
|
|
||||||
duration := time.Since(hh.startTime).Nanoseconds()
|
duration := time.Since(hh.startTime).Nanoseconds()
|
||||||
f.l.WithField("vpnAddrs", vpnAddrs).WithField("udpAddr", addr).
|
msgRxL := f.l.WithField("vpnAddrs", vpnAddrs).WithField("udpAddr", addr).
|
||||||
WithField("certName", certName).
|
WithField("certName", certName).
|
||||||
WithField("certVersion", certVersion).
|
WithField("certVersion", certVersion).
|
||||||
WithField("fingerprint", fingerprint).
|
WithField("fingerprint", fingerprint).
|
||||||
@@ -643,12 +622,17 @@ func ixHandshakeStage2(f *Interface, addr netip.AddrPort, via *ViaSender, hh *Ha
|
|||||||
WithField("initiatorIndex", hs.Details.InitiatorIndex).WithField("responderIndex", hs.Details.ResponderIndex).
|
WithField("initiatorIndex", hs.Details.InitiatorIndex).WithField("responderIndex", hs.Details.ResponderIndex).
|
||||||
WithField("remoteIndex", h.RemoteIndex).WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).
|
WithField("remoteIndex", h.RemoteIndex).WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).
|
||||||
WithField("durationNs", duration).
|
WithField("durationNs", duration).
|
||||||
WithField("sentCachedPackets", len(hh.packetStore)).
|
WithField("sentCachedPackets", len(hh.packetStore))
|
||||||
Info("Handshake message received")
|
if anyVpnAddrsInCommon {
|
||||||
|
msgRxL.Info("Handshake message received")
|
||||||
|
} else {
|
||||||
|
//todo warn if not lighthouse or relay?
|
||||||
|
msgRxL.Info("Handshake message received, but no vpnNetworks in common.")
|
||||||
|
}
|
||||||
|
|
||||||
// Build up the radix for the firewall if we have subnets in the cert
|
// Build up the radix for the firewall if we have subnets in the cert
|
||||||
hostinfo.vpnAddrs = vpnAddrs
|
hostinfo.vpnAddrs = vpnAddrs
|
||||||
hostinfo.buildNetworks(filteredNetworks, remoteCert.Certificate.UnsafeNetworks())
|
hostinfo.buildNetworks(f.myVpnNetworksTable, remoteCert.Certificate.Networks(), remoteCert.Certificate.UnsafeNetworks())
|
||||||
|
|
||||||
// Complete our handshake and update metrics, this will replace any existing tunnels for the vpnAddrs here
|
// Complete our handshake and update metrics, this will replace any existing tunnels for the vpnAddrs here
|
||||||
f.handshakeManager.Complete(hostinfo, f)
|
f.handshakeManager.Complete(hostinfo, f)
|
||||||
|
|||||||
39
hostmap.go
39
hostmap.go
@@ -212,6 +212,18 @@ func (rs *RelayState) InsertRelay(ip netip.Addr, idx uint32, r *Relay) {
|
|||||||
rs.relayForByIdx[idx] = r
|
rs.relayForByIdx[idx] = r
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type NetworkType uint8
|
||||||
|
|
||||||
|
const (
|
||||||
|
NetworkTypeUnknown NetworkType = iota
|
||||||
|
// NetworkTypeVPN is a network that overlaps one or more of the vpnNetworks in our certificate
|
||||||
|
NetworkTypeVPN
|
||||||
|
// NetworkTypeVPNPeer is a network that does not overlap one of our networks
|
||||||
|
NetworkTypeVPNPeer
|
||||||
|
// NetworkTypeUnsafe is a network from Certificate.UnsafeNetworks()
|
||||||
|
NetworkTypeUnsafe
|
||||||
|
)
|
||||||
|
|
||||||
type HostInfo struct {
|
type HostInfo struct {
|
||||||
remote netip.AddrPort
|
remote netip.AddrPort
|
||||||
remotes *RemoteList
|
remotes *RemoteList
|
||||||
@@ -220,13 +232,11 @@ type HostInfo struct {
|
|||||||
remoteIndexId uint32
|
remoteIndexId uint32
|
||||||
localIndexId uint32
|
localIndexId uint32
|
||||||
|
|
||||||
// vpnAddrs is a list of vpn addresses assigned to this host that are within our own vpn networks
|
// vpnAddrs is a list of vpn addresses assigned to this host
|
||||||
// The host may have other vpn addresses that are outside our
|
|
||||||
// vpn networks but were removed because they are not usable
|
|
||||||
vpnAddrs []netip.Addr
|
vpnAddrs []netip.Addr
|
||||||
|
|
||||||
// networks are both all vpn and unsafe networks assigned to this host
|
// networks is a combination of specific vpn addresses (not prefixes!) and full unsafe networks assigned to this host.
|
||||||
networks *bart.Lite
|
networks *bart.Table[NetworkType]
|
||||||
relayState RelayState
|
relayState RelayState
|
||||||
|
|
||||||
// HandshakePacket records the packets used to create this hostinfo
|
// HandshakePacket records the packets used to create this hostinfo
|
||||||
@@ -730,20 +740,27 @@ func (i *HostInfo) SetRemoteIfPreferred(hm *HostMap, newRemote netip.AddrPort) b
|
|||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
func (i *HostInfo) buildNetworks(networks, unsafeNetworks []netip.Prefix) {
|
func (i *HostInfo) buildNetworks(myVpnNetworksTable *bart.Lite, networks, unsafeNetworks []netip.Prefix) {
|
||||||
if len(networks) == 1 && len(unsafeNetworks) == 0 {
|
if len(networks) == 1 && len(unsafeNetworks) == 0 {
|
||||||
// Simple case, no CIDRTree needed
|
if myVpnNetworksTable.Contains(networks[0].Addr()) {
|
||||||
return
|
return // Simple case, no CIDRTree needed
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
i.networks = new(bart.Lite)
|
i.networks = new(bart.Table[NetworkType])
|
||||||
for _, network := range networks {
|
for _, network := range networks {
|
||||||
|
var nwType NetworkType
|
||||||
|
if myVpnNetworksTable.Contains(network.Addr()) {
|
||||||
|
nwType = NetworkTypeVPN
|
||||||
|
} else {
|
||||||
|
nwType = NetworkTypeVPNPeer
|
||||||
|
}
|
||||||
nprefix := netip.PrefixFrom(network.Addr(), network.Addr().BitLen())
|
nprefix := netip.PrefixFrom(network.Addr(), network.Addr().BitLen())
|
||||||
i.networks.Insert(nprefix)
|
i.networks.Insert(nprefix, nwType)
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, network := range unsafeNetworks {
|
for _, network := range unsafeNetworks {
|
||||||
i.networks.Insert(network)
|
i.networks.Insert(network, NetworkTypeUnsafe)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -33,7 +33,8 @@ func (f *Interface) consumeInsidePacket(packet []byte, fwPacket *firewall.Packet
|
|||||||
// routes packets from the Nebula addr to the Nebula addr through the Nebula
|
// routes packets from the Nebula addr to the Nebula addr through the Nebula
|
||||||
// TUN device.
|
// TUN device.
|
||||||
if immediatelyForwardToSelf {
|
if immediatelyForwardToSelf {
|
||||||
if err := f.writeTun(q, packet); err != nil {
|
_, err := f.readers[q].Write(packet)
|
||||||
|
if err != nil {
|
||||||
f.l.WithError(err).Error("Failed to forward to tun")
|
f.l.WithError(err).Error("Failed to forward to tun")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -90,7 +91,8 @@ func (f *Interface) rejectInside(packet []byte, out []byte, q int) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := f.writeTun(q, out); err != nil {
|
_, err := f.readers[q].Write(out)
|
||||||
|
if err != nil {
|
||||||
f.l.WithError(err).Error("Failed to write to tun")
|
f.l.WithError(err).Error("Failed to write to tun")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
101
interface.go
101
interface.go
@@ -47,7 +47,6 @@ type InterfaceConfig struct {
|
|||||||
reQueryWait time.Duration
|
reQueryWait time.Duration
|
||||||
|
|
||||||
ConntrackCacheTimeout time.Duration
|
ConntrackCacheTimeout time.Duration
|
||||||
batchSize int
|
|
||||||
l *logrus.Logger
|
l *logrus.Logger
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -85,7 +84,6 @@ type Interface struct {
|
|||||||
version string
|
version string
|
||||||
|
|
||||||
conntrackCacheTimeout time.Duration
|
conntrackCacheTimeout time.Duration
|
||||||
batchSize int
|
|
||||||
|
|
||||||
writers []udp.Conn
|
writers []udp.Conn
|
||||||
readers []io.ReadWriteCloser
|
readers []io.ReadWriteCloser
|
||||||
@@ -112,16 +110,6 @@ type EncWriter interface {
|
|||||||
GetCertState() *CertState
|
GetCertState() *CertState
|
||||||
}
|
}
|
||||||
|
|
||||||
// BatchReader is an interface for readers that support vectorized packet reading
|
|
||||||
type BatchReader interface {
|
|
||||||
BatchRead(buffers [][]byte, sizes []int) (int, error)
|
|
||||||
}
|
|
||||||
|
|
||||||
// BatchWriter is an interface for writers that support vectorized packet writing
|
|
||||||
type BatchWriter interface {
|
|
||||||
BatchWrite([][]byte) (int, error)
|
|
||||||
}
|
|
||||||
|
|
||||||
type sendRecvErrorConfig uint8
|
type sendRecvErrorConfig uint8
|
||||||
|
|
||||||
const (
|
const (
|
||||||
@@ -198,7 +186,6 @@ func NewInterface(ctx context.Context, c *InterfaceConfig) (*Interface, error) {
|
|||||||
relayManager: c.relayManager,
|
relayManager: c.relayManager,
|
||||||
connectionManager: c.connectionManager,
|
connectionManager: c.connectionManager,
|
||||||
conntrackCacheTimeout: c.ConntrackCacheTimeout,
|
conntrackCacheTimeout: c.ConntrackCacheTimeout,
|
||||||
batchSize: c.batchSize,
|
|
||||||
|
|
||||||
metricHandshakes: metrics.GetOrRegisterHistogram("handshakes", nil, metrics.NewExpDecaySample(1028, 0.015)),
|
metricHandshakes: metrics.GetOrRegisterHistogram("handshakes", nil, metrics.NewExpDecaySample(1028, 0.015)),
|
||||||
messageMetrics: c.MessageMetrics,
|
messageMetrics: c.MessageMetrics,
|
||||||
@@ -282,7 +269,7 @@ func (f *Interface) listenOut(i int) {
|
|||||||
plaintext := make([]byte, udp.MTU)
|
plaintext := make([]byte, udp.MTU)
|
||||||
h := &header.H{}
|
h := &header.H{}
|
||||||
fwPacket := &firewall.Packet{}
|
fwPacket := &firewall.Packet{}
|
||||||
nb := make([]byte, 12)
|
nb := make([]byte, 12, 12)
|
||||||
|
|
||||||
li.ListenOut(func(fromUdpAddr netip.AddrPort, payload []byte) {
|
li.ListenOut(func(fromUdpAddr netip.AddrPort, payload []byte) {
|
||||||
f.readOutsidePackets(fromUdpAddr, nil, plaintext[:0], payload, h, fwPacket, lhh, nb, i, ctCache.Get(f.l))
|
f.readOutsidePackets(fromUdpAddr, nil, plaintext[:0], payload, h, fwPacket, lhh, nb, i, ctCache.Get(f.l))
|
||||||
@@ -292,16 +279,6 @@ func (f *Interface) listenOut(i int) {
|
|||||||
func (f *Interface) listenIn(reader io.ReadWriteCloser, i int) {
|
func (f *Interface) listenIn(reader io.ReadWriteCloser, i int) {
|
||||||
runtime.LockOSThread()
|
runtime.LockOSThread()
|
||||||
|
|
||||||
// Check if reader supports batch operations
|
|
||||||
if batchReader, ok := reader.(BatchReader); ok {
|
|
||||||
err := f.listenInBatch(batchReader, i)
|
|
||||||
if err != nil {
|
|
||||||
f.l.WithError(err).Error("Fatal error in batch packet reader, exiting goroutine")
|
|
||||||
}
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// Fall back to single-packet mode
|
|
||||||
packet := make([]byte, mtu)
|
packet := make([]byte, mtu)
|
||||||
out := make([]byte, mtu)
|
out := make([]byte, mtu)
|
||||||
fwPacket := &firewall.Packet{}
|
fwPacket := &firewall.Packet{}
|
||||||
@@ -316,85 +293,15 @@ func (f *Interface) listenIn(reader io.ReadWriteCloser, i int) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
f.l.WithError(err).Error("Fatal error while reading outbound packet, exiting goroutine")
|
f.l.WithError(err).Error("Error while reading outbound packet")
|
||||||
return
|
// This only seems to happen when something fatal happens to the fd, so exit.
|
||||||
|
os.Exit(2)
|
||||||
}
|
}
|
||||||
|
|
||||||
f.consumeInsidePacket(packet[:n], fwPacket, nb, out, i, conntrackCache.Get(f.l))
|
f.consumeInsidePacket(packet[:n], fwPacket, nb, out, i, conntrackCache.Get(f.l))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// listenInBatch handles vectorized packet reading for improved performance
|
|
||||||
func (f *Interface) listenInBatch(reader BatchReader, i int) error {
|
|
||||||
// Allocate per-packet state and buffers for batch reading
|
|
||||||
batchSize := f.batchSize
|
|
||||||
if batchSize <= 0 {
|
|
||||||
batchSize = 64 // Fallback to default if not configured
|
|
||||||
}
|
|
||||||
fwPackets := make([]*firewall.Packet, batchSize)
|
|
||||||
outBuffers := make([][]byte, batchSize)
|
|
||||||
nbBuffers := make([][]byte, batchSize)
|
|
||||||
packets := make([][]byte, batchSize)
|
|
||||||
sizes := make([]int, batchSize)
|
|
||||||
|
|
||||||
for j := 0; j < batchSize; j++ {
|
|
||||||
fwPackets[j] = &firewall.Packet{}
|
|
||||||
outBuffers[j] = make([]byte, mtu)
|
|
||||||
nbBuffers[j] = make([]byte, 12)
|
|
||||||
packets[j] = make([]byte, mtu)
|
|
||||||
}
|
|
||||||
|
|
||||||
conntrackCache := firewall.NewConntrackCacheTicker(f.conntrackCacheTimeout)
|
|
||||||
|
|
||||||
for {
|
|
||||||
n, err := reader.BatchRead(packets, sizes)
|
|
||||||
if err != nil {
|
|
||||||
if errors.Is(err, os.ErrClosed) && f.closed.Load() {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
return fmt.Errorf("error while batch reading outbound packets: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Process each packet in the batch
|
|
||||||
cache := conntrackCache.Get(f.l)
|
|
||||||
for idx := 0; idx < n; idx++ {
|
|
||||||
if sizes[idx] > 0 {
|
|
||||||
// Use modulo to reuse fw packet state if batch is larger than our pre-allocated state
|
|
||||||
stateIdx := idx % len(fwPackets)
|
|
||||||
f.consumeInsidePacket(packets[idx][:sizes[idx]], fwPackets[stateIdx], nbBuffers[stateIdx], outBuffers[stateIdx], i, cache)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// writeTunBatch attempts to write multiple packets to the TUN device using batch operations if supported
|
|
||||||
func (f *Interface) writeTunBatch(q int, packets [][]byte) error {
|
|
||||||
if len(packets) == 0 {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Check if the reader/writer supports batch operations
|
|
||||||
if batchWriter, ok := f.readers[q].(BatchWriter); ok {
|
|
||||||
_, err := batchWriter.BatchWrite(packets)
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
// Fall back to writing packets individually
|
|
||||||
for _, packet := range packets {
|
|
||||||
if _, err := f.readers[q].Write(packet); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// writeTun writes a single packet to the TUN device
|
|
||||||
func (f *Interface) writeTun(q int, packet []byte) error {
|
|
||||||
_, err := f.readers[q].Write(packet)
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
func (f *Interface) RegisterConfigChangeCallbacks(c *config.C) {
|
func (f *Interface) RegisterConfigChangeCallbacks(c *config.C) {
|
||||||
c.RegisterReloadCallback(f.reloadFirewall)
|
c.RegisterReloadCallback(f.reloadFirewall)
|
||||||
c.RegisterReloadCallback(f.reloadSendRecvError)
|
c.RegisterReloadCallback(f.reloadSendRecvError)
|
||||||
|
|||||||
@@ -1017,17 +1017,17 @@ func (lhh *LightHouseHandler) resetMeta() *NebulaMeta {
|
|||||||
return lhh.meta
|
return lhh.meta
|
||||||
}
|
}
|
||||||
|
|
||||||
func (lhh *LightHouseHandler) HandleRequest(rAddr netip.AddrPort, fromVpnAddrs []netip.Addr, p []byte, w EncWriter) {
|
func (lhh *LightHouseHandler) HandleRequest(rAddr netip.AddrPort, hostinfo *HostInfo, p []byte, w EncWriter) {
|
||||||
n := lhh.resetMeta()
|
n := lhh.resetMeta()
|
||||||
err := n.Unmarshal(p)
|
err := n.Unmarshal(p)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
lhh.l.WithError(err).WithField("vpnAddrs", fromVpnAddrs).WithField("udpAddr", rAddr).
|
lhh.l.WithError(err).WithField("vpnAddrs", hostinfo.vpnAddrs).WithField("udpAddr", rAddr).
|
||||||
Error("Failed to unmarshal lighthouse packet")
|
Error("Failed to unmarshal lighthouse packet")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if n.Details == nil {
|
if n.Details == nil {
|
||||||
lhh.l.WithField("vpnAddrs", fromVpnAddrs).WithField("udpAddr", rAddr).
|
lhh.l.WithField("vpnAddrs", hostinfo.vpnAddrs).WithField("udpAddr", rAddr).
|
||||||
Error("Invalid lighthouse update")
|
Error("Invalid lighthouse update")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -1036,24 +1036,24 @@ func (lhh *LightHouseHandler) HandleRequest(rAddr netip.AddrPort, fromVpnAddrs [
|
|||||||
|
|
||||||
switch n.Type {
|
switch n.Type {
|
||||||
case NebulaMeta_HostQuery:
|
case NebulaMeta_HostQuery:
|
||||||
lhh.handleHostQuery(n, fromVpnAddrs, rAddr, w)
|
lhh.handleHostQuery(n, hostinfo, rAddr, w)
|
||||||
|
|
||||||
case NebulaMeta_HostQueryReply:
|
case NebulaMeta_HostQueryReply:
|
||||||
lhh.handleHostQueryReply(n, fromVpnAddrs)
|
lhh.handleHostQueryReply(n, hostinfo.vpnAddrs)
|
||||||
|
|
||||||
case NebulaMeta_HostUpdateNotification:
|
case NebulaMeta_HostUpdateNotification:
|
||||||
lhh.handleHostUpdateNotification(n, fromVpnAddrs, w)
|
lhh.handleHostUpdateNotification(n, hostinfo, w)
|
||||||
|
|
||||||
case NebulaMeta_HostMovedNotification:
|
case NebulaMeta_HostMovedNotification:
|
||||||
case NebulaMeta_HostPunchNotification:
|
case NebulaMeta_HostPunchNotification:
|
||||||
lhh.handleHostPunchNotification(n, fromVpnAddrs, w)
|
lhh.handleHostPunchNotification(n, hostinfo.vpnAddrs, w)
|
||||||
|
|
||||||
case NebulaMeta_HostUpdateNotificationAck:
|
case NebulaMeta_HostUpdateNotificationAck:
|
||||||
// noop
|
// noop
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (lhh *LightHouseHandler) handleHostQuery(n *NebulaMeta, fromVpnAddrs []netip.Addr, addr netip.AddrPort, w EncWriter) {
|
func (lhh *LightHouseHandler) handleHostQuery(n *NebulaMeta, hostinfo *HostInfo, addr netip.AddrPort, w EncWriter) {
|
||||||
// Exit if we don't answer queries
|
// Exit if we don't answer queries
|
||||||
if !lhh.lh.amLighthouse {
|
if !lhh.lh.amLighthouse {
|
||||||
if lhh.l.Level >= logrus.DebugLevel {
|
if lhh.l.Level >= logrus.DebugLevel {
|
||||||
@@ -1065,7 +1065,7 @@ func (lhh *LightHouseHandler) handleHostQuery(n *NebulaMeta, fromVpnAddrs []neti
|
|||||||
queryVpnAddr, useVersion, err := n.Details.GetVpnAddrAndVersion()
|
queryVpnAddr, useVersion, err := n.Details.GetVpnAddrAndVersion()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if lhh.l.Level >= logrus.DebugLevel {
|
if lhh.l.Level >= logrus.DebugLevel {
|
||||||
lhh.l.WithField("from", fromVpnAddrs).WithField("details", n.Details).
|
lhh.l.WithField("from", hostinfo.vpnAddrs).WithField("details", n.Details).
|
||||||
Debugln("Dropping malformed HostQuery")
|
Debugln("Dropping malformed HostQuery")
|
||||||
}
|
}
|
||||||
return
|
return
|
||||||
@@ -1073,7 +1073,7 @@ func (lhh *LightHouseHandler) handleHostQuery(n *NebulaMeta, fromVpnAddrs []neti
|
|||||||
if useVersion == cert.Version1 && queryVpnAddr.Is6() {
|
if useVersion == cert.Version1 && queryVpnAddr.Is6() {
|
||||||
// this case really shouldn't be possible to represent, but reject it anyway.
|
// this case really shouldn't be possible to represent, but reject it anyway.
|
||||||
if lhh.l.Level >= logrus.DebugLevel {
|
if lhh.l.Level >= logrus.DebugLevel {
|
||||||
lhh.l.WithField("vpnAddrs", fromVpnAddrs).WithField("queryVpnAddr", queryVpnAddr).
|
lhh.l.WithField("vpnAddrs", hostinfo.vpnAddrs).WithField("queryVpnAddr", queryVpnAddr).
|
||||||
Debugln("invalid vpn addr for v1 handleHostQuery")
|
Debugln("invalid vpn addr for v1 handleHostQuery")
|
||||||
}
|
}
|
||||||
return
|
return
|
||||||
@@ -1099,14 +1099,14 @@ func (lhh *LightHouseHandler) handleHostQuery(n *NebulaMeta, fromVpnAddrs []neti
|
|||||||
}
|
}
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
lhh.l.WithError(err).WithField("vpnAddrs", fromVpnAddrs).Error("Failed to marshal lighthouse host query reply")
|
lhh.l.WithError(err).WithField("vpnAddrs", hostinfo.vpnAddrs).Error("Failed to marshal lighthouse host query reply")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
lhh.lh.metricTx(NebulaMeta_HostQueryReply, 1)
|
lhh.lh.metricTx(NebulaMeta_HostQueryReply, 1)
|
||||||
w.SendMessageToVpnAddr(header.LightHouse, 0, fromVpnAddrs[0], lhh.pb[:ln], lhh.nb, lhh.out[:0])
|
w.SendMessageToHostInfo(header.LightHouse, 0, hostinfo, lhh.pb[:ln], lhh.nb, lhh.out[:0])
|
||||||
|
|
||||||
lhh.sendHostPunchNotification(n, fromVpnAddrs, queryVpnAddr, w)
|
lhh.sendHostPunchNotification(n, hostinfo.vpnAddrs, queryVpnAddr, w)
|
||||||
}
|
}
|
||||||
|
|
||||||
// sendHostPunchNotification signals the other side to punch some zero byte udp packets
|
// sendHostPunchNotification signals the other side to punch some zero byte udp packets
|
||||||
@@ -1115,20 +1115,34 @@ func (lhh *LightHouseHandler) sendHostPunchNotification(n *NebulaMeta, fromVpnAd
|
|||||||
found, ln, err := lhh.lh.queryAndPrepMessage(whereToPunch, func(c *cache) (int, error) {
|
found, ln, err := lhh.lh.queryAndPrepMessage(whereToPunch, func(c *cache) (int, error) {
|
||||||
n = lhh.resetMeta()
|
n = lhh.resetMeta()
|
||||||
n.Type = NebulaMeta_HostPunchNotification
|
n.Type = NebulaMeta_HostPunchNotification
|
||||||
targetHI := lhh.lh.ifce.GetHostInfo(punchNotifDest)
|
punchNotifDestHI := lhh.lh.ifce.GetHostInfo(punchNotifDest)
|
||||||
var useVersion cert.Version
|
var useVersion cert.Version
|
||||||
if targetHI == nil {
|
if punchNotifDestHI == nil {
|
||||||
useVersion = lhh.lh.ifce.GetCertState().initiatingVersion
|
useVersion = lhh.lh.ifce.GetCertState().initiatingVersion
|
||||||
} else {
|
} else {
|
||||||
crt := targetHI.GetCert().Certificate
|
|
||||||
useVersion = crt.Version()
|
|
||||||
// we can only retarget if we have a hostinfo
|
// we can only retarget if we have a hostinfo
|
||||||
newDest, ok := findNetworkUnion(crt.Networks(), fromVpnAddrs)
|
punchNotifDestCrt := punchNotifDestHI.GetCert().Certificate
|
||||||
|
useVersion = punchNotifDestCrt.Version()
|
||||||
|
punchNotifDestNetworks := punchNotifDestCrt.Networks()
|
||||||
|
|
||||||
|
//if we (the lighthouse) don't have a network in common with punchNotifDest, try to find one
|
||||||
|
if !lhh.lh.myVpnNetworksTable.Contains(punchNotifDest) {
|
||||||
|
newPunchNotifDest, ok := findNetworkUnion(lhh.lh.myVpnNetworks, punchNotifDestHI.vpnAddrs)
|
||||||
|
if ok {
|
||||||
|
punchNotifDest = newPunchNotifDest
|
||||||
|
} else {
|
||||||
|
if lhh.l.Level >= logrus.DebugLevel {
|
||||||
|
lhh.l.WithField("to", punchNotifDestNetworks).Debugln("unable to notify host to host, no addresses in common")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
newWhereToPunch, ok := findNetworkUnion(punchNotifDestNetworks, fromVpnAddrs)
|
||||||
if ok {
|
if ok {
|
||||||
whereToPunch = newDest
|
whereToPunch = newWhereToPunch
|
||||||
} else {
|
} else {
|
||||||
if lhh.l.Level >= logrus.DebugLevel {
|
if lhh.l.Level >= logrus.DebugLevel {
|
||||||
lhh.l.WithField("to", crt.Networks()).Debugln("unable to punch to host, no addresses in common")
|
lhh.l.WithFields(m{"from": fromVpnAddrs, "to": punchNotifDestNetworks}).Debugln("unable to punch to host, no addresses in common with requestor")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -1234,7 +1248,8 @@ func (lhh *LightHouseHandler) handleHostQueryReply(n *NebulaMeta, fromVpnAddrs [
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (lhh *LightHouseHandler) handleHostUpdateNotification(n *NebulaMeta, fromVpnAddrs []netip.Addr, w EncWriter) {
|
func (lhh *LightHouseHandler) handleHostUpdateNotification(n *NebulaMeta, hostinfo *HostInfo, w EncWriter) {
|
||||||
|
fromVpnAddrs := hostinfo.vpnAddrs
|
||||||
if !lhh.lh.amLighthouse {
|
if !lhh.lh.amLighthouse {
|
||||||
if lhh.l.Level >= logrus.DebugLevel {
|
if lhh.l.Level >= logrus.DebugLevel {
|
||||||
lhh.l.Debugln("I am not a lighthouse, do not take host updates: ", fromVpnAddrs)
|
lhh.l.Debugln("I am not a lighthouse, do not take host updates: ", fromVpnAddrs)
|
||||||
@@ -1302,7 +1317,7 @@ func (lhh *LightHouseHandler) handleHostUpdateNotification(n *NebulaMeta, fromVp
|
|||||||
}
|
}
|
||||||
|
|
||||||
lhh.lh.metricTx(NebulaMeta_HostUpdateNotificationAck, 1)
|
lhh.lh.metricTx(NebulaMeta_HostUpdateNotificationAck, 1)
|
||||||
w.SendMessageToVpnAddr(header.LightHouse, 0, fromVpnAddrs[0], lhh.pb[:ln], lhh.nb, lhh.out[:0])
|
w.SendMessageToHostInfo(header.LightHouse, 0, hostinfo, lhh.pb[:ln], lhh.nb, lhh.out[:0])
|
||||||
}
|
}
|
||||||
|
|
||||||
func (lhh *LightHouseHandler) handleHostPunchNotification(n *NebulaMeta, fromVpnAddrs []netip.Addr, w EncWriter) {
|
func (lhh *LightHouseHandler) handleHostPunchNotification(n *NebulaMeta, fromVpnAddrs []netip.Addr, w EncWriter) {
|
||||||
|
|||||||
@@ -132,8 +132,13 @@ func BenchmarkLighthouseHandleRequest(b *testing.B) {
|
|||||||
)
|
)
|
||||||
|
|
||||||
mw := &mockEncWriter{}
|
mw := &mockEncWriter{}
|
||||||
|
hostinfo := &HostInfo{
|
||||||
hi := []netip.Addr{vpnIp2}
|
ConnectionState: &ConnectionState{
|
||||||
|
eKey: nil,
|
||||||
|
dKey: nil,
|
||||||
|
},
|
||||||
|
vpnAddrs: []netip.Addr{vpnIp2},
|
||||||
|
}
|
||||||
b.Run("notfound", func(b *testing.B) {
|
b.Run("notfound", func(b *testing.B) {
|
||||||
lhh := lh.NewRequestHandler()
|
lhh := lh.NewRequestHandler()
|
||||||
req := &NebulaMeta{
|
req := &NebulaMeta{
|
||||||
@@ -146,7 +151,7 @@ func BenchmarkLighthouseHandleRequest(b *testing.B) {
|
|||||||
p, err := req.Marshal()
|
p, err := req.Marshal()
|
||||||
require.NoError(b, err)
|
require.NoError(b, err)
|
||||||
for n := 0; n < b.N; n++ {
|
for n := 0; n < b.N; n++ {
|
||||||
lhh.HandleRequest(rAddr, hi, p, mw)
|
lhh.HandleRequest(rAddr, hostinfo, p, mw)
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
b.Run("found", func(b *testing.B) {
|
b.Run("found", func(b *testing.B) {
|
||||||
@@ -162,7 +167,7 @@ func BenchmarkLighthouseHandleRequest(b *testing.B) {
|
|||||||
require.NoError(b, err)
|
require.NoError(b, err)
|
||||||
|
|
||||||
for n := 0; n < b.N; n++ {
|
for n := 0; n < b.N; n++ {
|
||||||
lhh.HandleRequest(rAddr, hi, p, mw)
|
lhh.HandleRequest(rAddr, hostinfo, p, mw)
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
@@ -326,7 +331,14 @@ func newLHHostRequest(fromAddr netip.AddrPort, myVpnIp, queryVpnIp netip.Addr, l
|
|||||||
w := &testEncWriter{
|
w := &testEncWriter{
|
||||||
metaFilter: &filter,
|
metaFilter: &filter,
|
||||||
}
|
}
|
||||||
lhh.HandleRequest(fromAddr, []netip.Addr{myVpnIp}, b, w)
|
hostinfo := &HostInfo{
|
||||||
|
ConnectionState: &ConnectionState{
|
||||||
|
eKey: nil,
|
||||||
|
dKey: nil,
|
||||||
|
},
|
||||||
|
vpnAddrs: []netip.Addr{myVpnIp},
|
||||||
|
}
|
||||||
|
lhh.HandleRequest(fromAddr, hostinfo, b, w)
|
||||||
return w.lastReply
|
return w.lastReply
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -355,9 +367,15 @@ func newLHHostUpdate(fromAddr netip.AddrPort, vpnIp netip.Addr, addrs []netip.Ad
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
panic(err)
|
panic(err)
|
||||||
}
|
}
|
||||||
|
hostinfo := &HostInfo{
|
||||||
|
ConnectionState: &ConnectionState{
|
||||||
|
eKey: nil,
|
||||||
|
dKey: nil,
|
||||||
|
},
|
||||||
|
vpnAddrs: []netip.Addr{vpnIp},
|
||||||
|
}
|
||||||
w := &testEncWriter{}
|
w := &testEncWriter{}
|
||||||
lhh.HandleRequest(fromAddr, []netip.Addr{vpnIp}, b, w)
|
lhh.HandleRequest(fromAddr, hostinfo, b, w)
|
||||||
}
|
}
|
||||||
|
|
||||||
type testLhReply struct {
|
type testLhReply struct {
|
||||||
|
|||||||
1
main.go
1
main.go
@@ -242,7 +242,6 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
|
|||||||
relayManager: NewRelayManager(ctx, l, hostMap, c),
|
relayManager: NewRelayManager(ctx, l, hostMap, c),
|
||||||
punchy: punchy,
|
punchy: punchy,
|
||||||
ConntrackCacheTimeout: conntrackCacheTimeout,
|
ConntrackCacheTimeout: conntrackCacheTimeout,
|
||||||
batchSize: c.GetInt("tun.batch_size", 64),
|
|
||||||
l: l,
|
l: l,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
11
outside.go
11
outside.go
@@ -138,7 +138,7 @@ func (f *Interface) readOutsidePackets(ip netip.AddrPort, via *ViaSender, out []
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
lhf.HandleRequest(ip, hostinfo.vpnAddrs, d, f)
|
lhf.HandleRequest(ip, hostinfo, d, f)
|
||||||
|
|
||||||
// Fallthrough to the bottom to record incoming traffic
|
// Fallthrough to the bottom to record incoming traffic
|
||||||
|
|
||||||
@@ -333,13 +333,12 @@ func parseV6(data []byte, incoming bool, fp *firewall.Packet) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
fp.Protocol = uint8(proto)
|
fp.Protocol = uint8(proto)
|
||||||
ports := data[offset : offset+4]
|
|
||||||
if incoming {
|
if incoming {
|
||||||
fp.RemotePort = binary.BigEndian.Uint16(ports[0:2])
|
fp.RemotePort = binary.BigEndian.Uint16(data[offset : offset+2])
|
||||||
fp.LocalPort = binary.BigEndian.Uint16(ports[2:4])
|
fp.LocalPort = binary.BigEndian.Uint16(data[offset+2 : offset+4])
|
||||||
} else {
|
} else {
|
||||||
fp.LocalPort = binary.BigEndian.Uint16(ports[0:2])
|
fp.LocalPort = binary.BigEndian.Uint16(data[offset : offset+2])
|
||||||
fp.RemotePort = binary.BigEndian.Uint16(ports[2:4])
|
fp.RemotePort = binary.BigEndian.Uint16(data[offset+2 : offset+4])
|
||||||
}
|
}
|
||||||
|
|
||||||
fp.Fragment = false
|
fp.Fragment = false
|
||||||
|
|||||||
@@ -3,6 +3,7 @@ package overlay
|
|||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"math"
|
"math"
|
||||||
|
"net"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
"runtime"
|
"runtime"
|
||||||
"strconv"
|
"strconv"
|
||||||
@@ -304,3 +305,29 @@ func parseUnsafeRoutes(c *config.C, networks []netip.Prefix) ([]Route, error) {
|
|||||||
|
|
||||||
return routes, nil
|
return routes, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func ipWithin(o *net.IPNet, i *net.IPNet) bool {
|
||||||
|
// Make sure o contains the lowest form of i
|
||||||
|
if !o.Contains(i.IP.Mask(i.Mask)) {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
// Find the max ip in i
|
||||||
|
ip4 := i.IP.To4()
|
||||||
|
if ip4 == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
last := make(net.IP, len(ip4))
|
||||||
|
copy(last, ip4)
|
||||||
|
for x := range ip4 {
|
||||||
|
last[x] |= ^i.Mask[x]
|
||||||
|
}
|
||||||
|
|
||||||
|
// Make sure o contains the max
|
||||||
|
if !o.Contains(last) {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|||||||
@@ -225,7 +225,6 @@ func Test_parseUnsafeRoutes(t *testing.T) {
|
|||||||
// no mtu
|
// no mtu
|
||||||
c.Settings["tun"] = map[string]any{"unsafe_routes": []any{map[string]any{"via": "127.0.0.1", "route": "1.0.0.0/8"}}}
|
c.Settings["tun"] = map[string]any{"unsafe_routes": []any{map[string]any{"via": "127.0.0.1", "route": "1.0.0.0/8"}}}
|
||||||
routes, err = parseUnsafeRoutes(c, []netip.Prefix{n})
|
routes, err = parseUnsafeRoutes(c, []netip.Prefix{n})
|
||||||
require.NoError(t, err)
|
|
||||||
assert.Len(t, routes, 1)
|
assert.Len(t, routes, 1)
|
||||||
assert.Equal(t, 0, routes[0].MTU)
|
assert.Equal(t, 0, routes[0].MTU)
|
||||||
|
|
||||||
@@ -319,7 +318,7 @@ func Test_makeRouteTree(t *testing.T) {
|
|||||||
|
|
||||||
ip, err = netip.ParseAddr("1.1.0.1")
|
ip, err = netip.ParseAddr("1.1.0.1")
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
_, ok = routeTree.Lookup(ip)
|
r, ok = routeTree.Lookup(ip)
|
||||||
assert.False(t, ok)
|
assert.False(t, ok)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -1,6 +1,8 @@
|
|||||||
package overlay
|
package overlay
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"fmt"
|
||||||
|
"net"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
|
|
||||||
"github.com/sirupsen/logrus"
|
"github.com/sirupsen/logrus"
|
||||||
@@ -70,3 +72,51 @@ func findRemovedRoutes(newRoutes, oldRoutes []Route) []Route {
|
|||||||
|
|
||||||
return removed
|
return removed
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func prefixToMask(prefix netip.Prefix) netip.Addr {
|
||||||
|
pLen := 128
|
||||||
|
if prefix.Addr().Is4() {
|
||||||
|
pLen = 32
|
||||||
|
}
|
||||||
|
|
||||||
|
addr, _ := netip.AddrFromSlice(net.CIDRMask(prefix.Bits(), pLen))
|
||||||
|
return addr
|
||||||
|
}
|
||||||
|
|
||||||
|
func flipBytes(b []byte) []byte {
|
||||||
|
for i := 0; i < len(b); i++ {
|
||||||
|
b[i] ^= 0xFF
|
||||||
|
}
|
||||||
|
return b
|
||||||
|
}
|
||||||
|
func orBytes(a []byte, b []byte) []byte {
|
||||||
|
ret := make([]byte, len(a))
|
||||||
|
for i := 0; i < len(a); i++ {
|
||||||
|
ret[i] = a[i] | b[i]
|
||||||
|
}
|
||||||
|
return ret
|
||||||
|
}
|
||||||
|
|
||||||
|
func getBroadcast(cidr netip.Prefix) netip.Addr {
|
||||||
|
broadcast, _ := netip.AddrFromSlice(
|
||||||
|
orBytes(
|
||||||
|
cidr.Addr().AsSlice(),
|
||||||
|
flipBytes(prefixToMask(cidr).AsSlice()),
|
||||||
|
),
|
||||||
|
)
|
||||||
|
return broadcast
|
||||||
|
}
|
||||||
|
|
||||||
|
func selectGateway(dest netip.Prefix, gateways []netip.Prefix) (netip.Prefix, error) {
|
||||||
|
for _, gateway := range gateways {
|
||||||
|
if dest.Addr().Is4() && gateway.Addr().Is4() {
|
||||||
|
return gateway, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if dest.Addr().Is6() && gateway.Addr().Is6() {
|
||||||
|
return gateway, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return netip.Prefix{}, fmt.Errorf("no gateway found for %v in the list of vpn networks", dest)
|
||||||
|
}
|
||||||
|
|||||||
@@ -1,5 +1,5 @@
|
|||||||
//go:build darwin && !ios && !e2e_testing
|
//go:build !ios && !e2e_testing
|
||||||
// +build darwin,!ios,!e2e_testing
|
// +build !ios,!e2e_testing
|
||||||
|
|
||||||
package overlay
|
package overlay
|
||||||
|
|
||||||
@@ -8,27 +8,48 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
|
"os"
|
||||||
|
"sync/atomic"
|
||||||
|
"syscall"
|
||||||
"unsafe"
|
"unsafe"
|
||||||
|
|
||||||
|
"github.com/gaissmai/bart"
|
||||||
"github.com/sirupsen/logrus"
|
"github.com/sirupsen/logrus"
|
||||||
"github.com/slackhq/nebula/config"
|
"github.com/slackhq/nebula/config"
|
||||||
|
"github.com/slackhq/nebula/routing"
|
||||||
"github.com/slackhq/nebula/util"
|
"github.com/slackhq/nebula/util"
|
||||||
netroute "golang.org/x/net/route"
|
netroute "golang.org/x/net/route"
|
||||||
"golang.org/x/sys/unix"
|
"golang.org/x/sys/unix"
|
||||||
wgtun "golang.zx2c4.com/wireguard/tun"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
type tun struct {
|
type tun struct {
|
||||||
linkAddr *netroute.LinkAddr
|
io.ReadWriteCloser
|
||||||
|
Device string
|
||||||
|
vpnNetworks []netip.Prefix
|
||||||
|
DefaultMTU int
|
||||||
|
Routes atomic.Pointer[[]Route]
|
||||||
|
routeTree atomic.Pointer[bart.Table[routing.Gateways]]
|
||||||
|
linkAddr *netroute.LinkAddr
|
||||||
|
l *logrus.Logger
|
||||||
|
|
||||||
|
// cache out buffer since we need to prepend 4 bytes for tun metadata
|
||||||
|
out []byte
|
||||||
}
|
}
|
||||||
|
|
||||||
// ioctl structures for Darwin network configuration
|
|
||||||
type ifReq struct {
|
type ifReq struct {
|
||||||
Name [unix.IFNAMSIZ]byte
|
Name [unix.IFNAMSIZ]byte
|
||||||
Flags uint16
|
Flags uint16
|
||||||
pad [8]byte
|
pad [8]byte
|
||||||
}
|
}
|
||||||
|
|
||||||
|
const (
|
||||||
|
_SIOCAIFADDR_IN6 = 2155899162
|
||||||
|
_UTUN_OPT_IFNAME = 2
|
||||||
|
_IN6_IFF_NODAD = 0x0020
|
||||||
|
_IN6_IFF_SECURED = 0x0400
|
||||||
|
utunControlName = "com.apple.net.utun_control"
|
||||||
|
)
|
||||||
|
|
||||||
type ifreqMTU struct {
|
type ifreqMTU struct {
|
||||||
Name [16]byte
|
Name [16]byte
|
||||||
MTU int32
|
MTU int32
|
||||||
@@ -58,61 +79,60 @@ type ifreqAlias6 struct {
|
|||||||
Lifetime addrLifetime
|
Lifetime addrLifetime
|
||||||
}
|
}
|
||||||
|
|
||||||
const (
|
func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, _ bool) (*tun, error) {
|
||||||
_SIOCAIFADDR_IN6 = 2155899162
|
|
||||||
_IN6_IFF_NODAD = 0x0020
|
|
||||||
)
|
|
||||||
|
|
||||||
func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ []netip.Prefix) (*wgTun, error) {
|
|
||||||
return nil, fmt.Errorf("newTunFromFd not supported on Darwin")
|
|
||||||
}
|
|
||||||
|
|
||||||
func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, _ bool) (*wgTun, error) {
|
|
||||||
name := c.GetString("tun.dev", "")
|
name := c.GetString("tun.dev", "")
|
||||||
deviceName := "utun"
|
ifIndex := -1
|
||||||
|
|
||||||
// Parse device name to handle utun[0-9]+ format
|
|
||||||
if name != "" && name != "utun" {
|
if name != "" && name != "utun" {
|
||||||
ifIndex := -1
|
|
||||||
_, err := fmt.Sscanf(name, "utun%d", &ifIndex)
|
_, err := fmt.Sscanf(name, "utun%d", &ifIndex)
|
||||||
if err != nil || ifIndex < 0 {
|
if err != nil || ifIndex < 0 {
|
||||||
// NOTE: we don't make this error so we don't break existing
|
// NOTE: we don't make this error so we don't break existing
|
||||||
// configs that set a name before it was used.
|
// configs that set a name before it was used.
|
||||||
l.Warn("interface name must be utun[0-9]+ on Darwin, ignoring")
|
l.Warn("interface name must be utun[0-9]+ on Darwin, ignoring")
|
||||||
} else {
|
ifIndex = -1
|
||||||
deviceName = name
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
mtu := c.GetInt("tun.mtu", DefaultMTU)
|
fd, err := unix.Socket(unix.AF_SYSTEM, unix.SOCK_DGRAM, unix.AF_SYS_CONTROL)
|
||||||
|
|
||||||
// Create WireGuard TUN device
|
|
||||||
tunDevice, err := wgtun.CreateTUN(deviceName, mtu)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("failed to create TUN device: %w", err)
|
return nil, fmt.Errorf("system socket: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Get the actual device name
|
var ctlInfo = &unix.CtlInfo{}
|
||||||
actualName, err := tunDevice.Name()
|
copy(ctlInfo.Name[:], utunControlName)
|
||||||
|
|
||||||
|
err = unix.IoctlCtlInfo(fd, ctlInfo)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
tunDevice.Close()
|
return nil, fmt.Errorf("CTLIOCGINFO: %v", err)
|
||||||
return nil, fmt.Errorf("failed to get TUN device name: %w", err)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
t := &wgTun{
|
err = unix.Connect(fd, &unix.SockaddrCtl{
|
||||||
tunDevice: tunDevice,
|
ID: ctlInfo.Id,
|
||||||
vpnNetworks: vpnNetworks,
|
Unit: uint32(ifIndex) + 1,
|
||||||
MaxMTU: mtu,
|
})
|
||||||
DefaultMTU: mtu,
|
if err != nil {
|
||||||
l: l,
|
return nil, fmt.Errorf("SYS_CONNECT: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Create Darwin-specific route manager
|
name, err = unix.GetsockoptString(fd, unix.AF_SYS_CONTROL, _UTUN_OPT_IFNAME)
|
||||||
t.routeManager = &tun{}
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to retrieve tun name: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
err = unix.SetNonblock(fd, true)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("SetNonblock: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
t := &tun{
|
||||||
|
ReadWriteCloser: os.NewFile(uintptr(fd), ""),
|
||||||
|
Device: name,
|
||||||
|
vpnNetworks: vpnNetworks,
|
||||||
|
DefaultMTU: c.GetInt("tun.mtu", DefaultMTU),
|
||||||
|
l: l,
|
||||||
|
}
|
||||||
|
|
||||||
err = t.reload(c, true)
|
err = t.reload(c, true)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
tunDevice.Close()
|
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -123,251 +143,215 @@ func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, _ bool) (
|
|||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
l.WithField("name", actualName).Info("Created WireGuard TUN device")
|
|
||||||
|
|
||||||
return t, nil
|
return t, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (rm *tun) Activate(t *wgTun) error {
|
func (t *tun) deviceBytes() (o [16]byte) {
|
||||||
name, err := t.tunDevice.Name()
|
for i, c := range t.Device {
|
||||||
|
o[i] = byte(c)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ []netip.Prefix) (*tun, error) {
|
||||||
|
return nil, fmt.Errorf("newTunFromFd not supported in Darwin")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *tun) Close() error {
|
||||||
|
if t.ReadWriteCloser != nil {
|
||||||
|
return t.ReadWriteCloser.Close()
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *tun) Activate() error {
|
||||||
|
devName := t.deviceBytes()
|
||||||
|
|
||||||
|
s, err := unix.Socket(
|
||||||
|
unix.AF_INET,
|
||||||
|
unix.SOCK_DGRAM,
|
||||||
|
unix.IPPROTO_IP,
|
||||||
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed to get device name: %w", err)
|
return err
|
||||||
|
}
|
||||||
|
defer unix.Close(s)
|
||||||
|
|
||||||
|
fd := uintptr(s)
|
||||||
|
|
||||||
|
// Set the MTU on the device
|
||||||
|
ifm := ifreqMTU{Name: devName, MTU: int32(t.DefaultMTU)}
|
||||||
|
if err = ioctl(fd, unix.SIOCSIFMTU, uintptr(unsafe.Pointer(&ifm))); err != nil {
|
||||||
|
return fmt.Errorf("failed to set tun mtu: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Set the MTU
|
// Get the device flags
|
||||||
rm.SetMTU(t, t.MaxMTU)
|
ifrf := ifReq{Name: devName}
|
||||||
|
if err = ioctl(fd, unix.SIOCGIFFLAGS, uintptr(unsafe.Pointer(&ifrf))); err != nil {
|
||||||
// Add IP addresses
|
return fmt.Errorf("failed to get tun flags: %s", err)
|
||||||
for _, network := range t.vpnNetworks {
|
|
||||||
if err := rm.addIP(t, name, network); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Bring up the interface using ioctl
|
linkAddr, err := getLinkAddr(t.Device)
|
||||||
if err := rm.bringUpInterface(name); err != nil {
|
|
||||||
return fmt.Errorf("failed to bring up interface: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Get the link address for routing
|
|
||||||
linkAddr, err := getLinkAddr(name)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed to get link address: %w", err)
|
return err
|
||||||
}
|
}
|
||||||
if linkAddr == nil {
|
if linkAddr == nil {
|
||||||
return fmt.Errorf("unable to discover link_addr for tun interface")
|
return fmt.Errorf("unable to discover link_addr for tun interface")
|
||||||
}
|
}
|
||||||
rm.linkAddr = linkAddr
|
t.linkAddr = linkAddr
|
||||||
|
|
||||||
// Set the routes
|
for _, network := range t.vpnNetworks {
|
||||||
if err := rm.AddRoutes(t, false); err != nil {
|
if network.Addr().Is4() {
|
||||||
|
err = t.activate4(network)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
err = t.activate6(network)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Run the interface
|
||||||
|
ifrf.Flags = ifrf.Flags | unix.IFF_UP | unix.IFF_RUNNING
|
||||||
|
if err = ioctl(fd, unix.SIOCSIFFLAGS, uintptr(unsafe.Pointer(&ifrf))); err != nil {
|
||||||
|
return fmt.Errorf("failed to run tun device: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Unsafe path routes
|
||||||
|
return t.addRoutes(false)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *tun) activate4(network netip.Prefix) error {
|
||||||
|
s, err := unix.Socket(
|
||||||
|
unix.AF_INET,
|
||||||
|
unix.SOCK_DGRAM,
|
||||||
|
unix.IPPROTO_IP,
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
defer unix.Close(s)
|
||||||
|
|
||||||
|
ifr := ifreqAlias4{
|
||||||
|
Name: t.deviceBytes(),
|
||||||
|
Addr: unix.RawSockaddrInet4{
|
||||||
|
Len: unix.SizeofSockaddrInet4,
|
||||||
|
Family: unix.AF_INET,
|
||||||
|
Addr: network.Addr().As4(),
|
||||||
|
},
|
||||||
|
DstAddr: unix.RawSockaddrInet4{
|
||||||
|
Len: unix.SizeofSockaddrInet4,
|
||||||
|
Family: unix.AF_INET,
|
||||||
|
Addr: network.Addr().As4(),
|
||||||
|
},
|
||||||
|
MaskAddr: unix.RawSockaddrInet4{
|
||||||
|
Len: unix.SizeofSockaddrInet4,
|
||||||
|
Family: unix.AF_INET,
|
||||||
|
Addr: prefixToMask(network).As4(),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := ioctl(uintptr(s), unix.SIOCAIFADDR, uintptr(unsafe.Pointer(&ifr))); err != nil {
|
||||||
|
return fmt.Errorf("failed to set tun v4 address: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
err = addRoute(network, t.linkAddr)
|
||||||
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (rm *tun) bringUpInterface(name string) error {
|
func (t *tun) activate6(network netip.Prefix) error {
|
||||||
// Open a socket for ioctl
|
s, err := unix.Socket(
|
||||||
fd, err := unix.Socket(unix.AF_INET, unix.SOCK_DGRAM, 0)
|
unix.AF_INET6,
|
||||||
|
unix.SOCK_DGRAM,
|
||||||
|
unix.IPPROTO_IP,
|
||||||
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed to create socket: %w", err)
|
return err
|
||||||
}
|
|
||||||
defer unix.Close(fd)
|
|
||||||
|
|
||||||
// Get current flags
|
|
||||||
var ifrf ifReq
|
|
||||||
copy(ifrf.Name[:], name)
|
|
||||||
|
|
||||||
if err := ioctl(uintptr(fd), unix.SIOCGIFFLAGS, uintptr(unsafe.Pointer(&ifrf))); err != nil {
|
|
||||||
return fmt.Errorf("failed to get interface flags: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Set IFF_UP and IFF_RUNNING flags
|
|
||||||
ifrf.Flags = ifrf.Flags | unix.IFF_UP | unix.IFF_RUNNING
|
|
||||||
|
|
||||||
if err := ioctl(uintptr(fd), unix.SIOCSIFFLAGS, uintptr(unsafe.Pointer(&ifrf))); err != nil {
|
|
||||||
return fmt.Errorf("failed to set interface flags: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (rm *tun) SetMTU(t *wgTun, mtu int) {
|
|
||||||
name, err := t.tunDevice.Name()
|
|
||||||
if err != nil {
|
|
||||||
t.l.WithError(err).Error("Failed to get device name for MTU set")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// Open a socket for ioctl
|
|
||||||
fd, err := unix.Socket(unix.AF_INET, unix.SOCK_DGRAM, 0)
|
|
||||||
if err != nil {
|
|
||||||
t.l.WithError(err).Error("Failed to create socket for MTU set")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
defer unix.Close(fd)
|
|
||||||
|
|
||||||
// Prepare the ioctl request
|
|
||||||
var ifr ifreqMTU
|
|
||||||
copy(ifr.Name[:], name)
|
|
||||||
ifr.MTU = int32(mtu)
|
|
||||||
|
|
||||||
// Set the MTU using ioctl
|
|
||||||
if err := ioctl(uintptr(fd), unix.SIOCSIFMTU, uintptr(unsafe.Pointer(&ifr))); err != nil {
|
|
||||||
t.l.WithError(err).Error("Failed to set tun mtu via ioctl")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (rm *tun) SetDefaultRoute(t *wgTun, cidr netip.Prefix) error {
|
|
||||||
// On Darwin, routes are set via ifconfig and route commands
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (rm *tun) AddRoutes(t *wgTun, logErrors bool) error {
|
|
||||||
routes := *t.Routes.Load()
|
|
||||||
for _, r := range routes {
|
|
||||||
if !r.Install {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
err := rm.addRoute(r.Cidr)
|
|
||||||
if err != nil {
|
|
||||||
if errors.Is(err, unix.EEXIST) {
|
|
||||||
t.l.WithField("route", r.Cidr).
|
|
||||||
Warnf("unable to add unsafe_route, identical route already exists")
|
|
||||||
} else {
|
|
||||||
retErr := util.NewContextualError("Failed to add route", map[string]any{"route": r}, err)
|
|
||||||
if logErrors {
|
|
||||||
retErr.Log(t.l)
|
|
||||||
} else {
|
|
||||||
return retErr
|
|
||||||
}
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
t.l.WithField("route", r).Info("Added route")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (rm *tun) RemoveRoutes(t *wgTun, routes []Route) {
|
|
||||||
for _, r := range routes {
|
|
||||||
if !r.Install {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
err := rm.delRoute(r.Cidr)
|
|
||||||
if err != nil {
|
|
||||||
t.l.WithError(err).WithField("route", r).Error("Failed to remove route")
|
|
||||||
} else {
|
|
||||||
t.l.WithField("route", r).Info("Removed route")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (rm *tun) NewMultiQueueReader(t *wgTun) (io.ReadWriteCloser, error) {
|
|
||||||
// Darwin doesn't support multi-queue TUN devices in the same way as Linux
|
|
||||||
// Return a reader that wraps the same device
|
|
||||||
return &wgTunReader{
|
|
||||||
parent: t,
|
|
||||||
tunDevice: t.tunDevice,
|
|
||||||
offset: 0,
|
|
||||||
l: t.l,
|
|
||||||
}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (rm *tun) addIP(t *wgTun, name string, network netip.Prefix) error {
|
|
||||||
addr := network.Addr()
|
|
||||||
|
|
||||||
if addr.Is4() {
|
|
||||||
return rm.addIPv4(name, network)
|
|
||||||
} else {
|
|
||||||
return rm.addIPv6(name, network)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (rm *tun) addIPv4(name string, network netip.Prefix) error {
|
|
||||||
// Open an IPv4 socket for ioctl
|
|
||||||
s, err := unix.Socket(unix.AF_INET, unix.SOCK_DGRAM, unix.IPPROTO_IP)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("failed to create IPv4 socket: %w", err)
|
|
||||||
}
|
}
|
||||||
defer unix.Close(s)
|
defer unix.Close(s)
|
||||||
|
|
||||||
var ifr ifreqAlias4
|
ifr := ifreqAlias6{
|
||||||
copy(ifr.Name[:], name)
|
Name: t.deviceBytes(),
|
||||||
|
Addr: unix.RawSockaddrInet6{
|
||||||
// Set the address
|
Len: unix.SizeofSockaddrInet6,
|
||||||
ifr.Addr = unix.RawSockaddrInet4{
|
Family: unix.AF_INET6,
|
||||||
Len: unix.SizeofSockaddrInet4,
|
Addr: network.Addr().As16(),
|
||||||
Family: unix.AF_INET,
|
},
|
||||||
Addr: network.Addr().As4(),
|
PrefixMask: unix.RawSockaddrInet6{
|
||||||
|
Len: unix.SizeofSockaddrInet6,
|
||||||
|
Family: unix.AF_INET6,
|
||||||
|
Addr: prefixToMask(network).As16(),
|
||||||
|
},
|
||||||
|
Lifetime: addrLifetime{
|
||||||
|
// never expires
|
||||||
|
Vltime: 0xffffffff,
|
||||||
|
Pltime: 0xffffffff,
|
||||||
|
},
|
||||||
|
Flags: _IN6_IFF_NODAD,
|
||||||
}
|
}
|
||||||
|
|
||||||
// Set the destination address (same as address for point-to-point)
|
|
||||||
ifr.DstAddr = unix.RawSockaddrInet4{
|
|
||||||
Len: unix.SizeofSockaddrInet4,
|
|
||||||
Family: unix.AF_INET,
|
|
||||||
Addr: network.Addr().As4(),
|
|
||||||
}
|
|
||||||
|
|
||||||
// Set the netmask
|
|
||||||
ifr.MaskAddr = unix.RawSockaddrInet4{
|
|
||||||
Len: unix.SizeofSockaddrInet4,
|
|
||||||
Family: unix.AF_INET,
|
|
||||||
Addr: prefixToMask(network).As4(),
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := ioctl(uintptr(s), unix.SIOCAIFADDR, uintptr(unsafe.Pointer(&ifr))); err != nil {
|
|
||||||
return fmt.Errorf("failed to set IPv4 address via ioctl: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (rm *tun) addIPv6(name string, network netip.Prefix) error {
|
|
||||||
// Open an IPv6 socket for ioctl
|
|
||||||
s, err := unix.Socket(unix.AF_INET6, unix.SOCK_DGRAM, unix.IPPROTO_IP)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("failed to create IPv6 socket: %w", err)
|
|
||||||
}
|
|
||||||
defer unix.Close(s)
|
|
||||||
|
|
||||||
var ifr ifreqAlias6
|
|
||||||
copy(ifr.Name[:], name)
|
|
||||||
|
|
||||||
// Set the address
|
|
||||||
ifr.Addr = unix.RawSockaddrInet6{
|
|
||||||
Len: unix.SizeofSockaddrInet6,
|
|
||||||
Family: unix.AF_INET6,
|
|
||||||
Addr: network.Addr().As16(),
|
|
||||||
}
|
|
||||||
|
|
||||||
// Set the prefix mask
|
|
||||||
ifr.PrefixMask = unix.RawSockaddrInet6{
|
|
||||||
Len: unix.SizeofSockaddrInet6,
|
|
||||||
Family: unix.AF_INET6,
|
|
||||||
Addr: prefixToMask(network).As16(),
|
|
||||||
}
|
|
||||||
|
|
||||||
// Set lifetime (never expires)
|
|
||||||
ifr.Lifetime = addrLifetime{
|
|
||||||
Vltime: 0xffffffff,
|
|
||||||
Pltime: 0xffffffff,
|
|
||||||
}
|
|
||||||
|
|
||||||
// Set flags (no DAD - Duplicate Address Detection)
|
|
||||||
ifr.Flags = _IN6_IFF_NODAD
|
|
||||||
|
|
||||||
if err := ioctl(uintptr(s), _SIOCAIFADDR_IN6, uintptr(unsafe.Pointer(&ifr))); err != nil {
|
if err := ioctl(uintptr(s), _SIOCAIFADDR_IN6, uintptr(unsafe.Pointer(&ifr))); err != nil {
|
||||||
return fmt.Errorf("failed to set IPv6 address via ioctl: %w", err)
|
return fmt.Errorf("failed to set tun address: %s", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (t *tun) reload(c *config.C, initial bool) error {
|
||||||
|
change, routes, err := getAllRoutesFromConfig(c, t.vpnNetworks, initial)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
if !initial && !change {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
routeTree, err := makeRouteTree(t.l, routes, false)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Teach nebula how to handle the routes before establishing them in the system table
|
||||||
|
oldRoutes := t.Routes.Swap(&routes)
|
||||||
|
t.routeTree.Store(routeTree)
|
||||||
|
|
||||||
|
if !initial {
|
||||||
|
// Remove first, if the system removes a wanted route hopefully it will be re-added next
|
||||||
|
err := t.removeRoutes(findRemovedRoutes(routes, *oldRoutes))
|
||||||
|
if err != nil {
|
||||||
|
util.LogWithContextIfNeeded("Failed to remove routes", err, t.l)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Ensure any routes we actually want are installed
|
||||||
|
err = t.addRoutes(true)
|
||||||
|
if err != nil {
|
||||||
|
// Catch any stray logs
|
||||||
|
util.LogWithContextIfNeeded("Failed to add routes", err, t.l)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *tun) RoutesFor(ip netip.Addr) routing.Gateways {
|
||||||
|
r, ok := t.routeTree.Load().Lookup(ip)
|
||||||
|
if ok {
|
||||||
|
return r
|
||||||
|
}
|
||||||
|
return routing.Gateways{}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get the LinkAddr for the interface of the given name
|
||||||
|
// Is there an easier way to fetch this when we create the interface?
|
||||||
|
// Maybe SIOCGIFINDEX? but this doesn't appear to exist in the darwin headers.
|
||||||
func getLinkAddr(name string) (*netroute.LinkAddr, error) {
|
func getLinkAddr(name string) (*netroute.LinkAddr, error) {
|
||||||
rib, err := netroute.FetchRIB(unix.AF_UNSPEC, unix.NET_RT_IFLIST, 0)
|
rib, err := netroute.FetchRIB(unix.AF_UNSPEC, unix.NET_RT_IFLIST, 0)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -393,7 +377,53 @@ func getLinkAddr(name string) (*netroute.LinkAddr, error) {
|
|||||||
return nil, nil
|
return nil, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (rm *tun) addRoute(prefix netip.Prefix) error {
|
func (t *tun) addRoutes(logErrors bool) error {
|
||||||
|
routes := *t.Routes.Load()
|
||||||
|
|
||||||
|
for _, r := range routes {
|
||||||
|
if len(r.Via) == 0 || !r.Install {
|
||||||
|
// We don't allow route MTUs so only install routes with a via
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
err := addRoute(r.Cidr, t.linkAddr)
|
||||||
|
if err != nil {
|
||||||
|
if errors.Is(err, unix.EEXIST) {
|
||||||
|
t.l.WithField("route", r.Cidr).
|
||||||
|
Warnf("unable to add unsafe_route, identical route already exists")
|
||||||
|
} else {
|
||||||
|
retErr := util.NewContextualError("Failed to add route", map[string]any{"route": r}, err)
|
||||||
|
if logErrors {
|
||||||
|
retErr.Log(t.l)
|
||||||
|
} else {
|
||||||
|
return retErr
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
t.l.WithField("route", r).Info("Added route")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *tun) removeRoutes(routes []Route) error {
|
||||||
|
for _, r := range routes {
|
||||||
|
if !r.Install {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
err := delRoute(r.Cidr, t.linkAddr)
|
||||||
|
if err != nil {
|
||||||
|
t.l.WithError(err).WithField("route", r).Error("Failed to remove route")
|
||||||
|
} else {
|
||||||
|
t.l.WithField("route", r).Info("Removed route")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func addRoute(prefix netip.Prefix, gateway netroute.Addr) error {
|
||||||
sock, err := unix.Socket(unix.AF_ROUTE, unix.SOCK_RAW, unix.AF_UNSPEC)
|
sock, err := unix.Socket(unix.AF_ROUTE, unix.SOCK_RAW, unix.AF_UNSPEC)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("unable to create AF_ROUTE socket: %v", err)
|
return fmt.Errorf("unable to create AF_ROUTE socket: %v", err)
|
||||||
@@ -411,13 +441,13 @@ func (rm *tun) addRoute(prefix netip.Prefix) error {
|
|||||||
route.Addrs = []netroute.Addr{
|
route.Addrs = []netroute.Addr{
|
||||||
unix.RTAX_DST: &netroute.Inet4Addr{IP: prefix.Masked().Addr().As4()},
|
unix.RTAX_DST: &netroute.Inet4Addr{IP: prefix.Masked().Addr().As4()},
|
||||||
unix.RTAX_NETMASK: &netroute.Inet4Addr{IP: prefixToMask(prefix).As4()},
|
unix.RTAX_NETMASK: &netroute.Inet4Addr{IP: prefixToMask(prefix).As4()},
|
||||||
unix.RTAX_GATEWAY: rm.linkAddr,
|
unix.RTAX_GATEWAY: gateway,
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
route.Addrs = []netroute.Addr{
|
route.Addrs = []netroute.Addr{
|
||||||
unix.RTAX_DST: &netroute.Inet6Addr{IP: prefix.Masked().Addr().As16()},
|
unix.RTAX_DST: &netroute.Inet6Addr{IP: prefix.Masked().Addr().As16()},
|
||||||
unix.RTAX_NETMASK: &netroute.Inet6Addr{IP: prefixToMask(prefix).As16()},
|
unix.RTAX_NETMASK: &netroute.Inet6Addr{IP: prefixToMask(prefix).As16()},
|
||||||
unix.RTAX_GATEWAY: rm.linkAddr,
|
unix.RTAX_GATEWAY: gateway,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -434,7 +464,7 @@ func (rm *tun) addRoute(prefix netip.Prefix) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (rm *tun) delRoute(prefix netip.Prefix) error {
|
func delRoute(prefix netip.Prefix, gateway netroute.Addr) error {
|
||||||
sock, err := unix.Socket(unix.AF_ROUTE, unix.SOCK_RAW, unix.AF_UNSPEC)
|
sock, err := unix.Socket(unix.AF_ROUTE, unix.SOCK_RAW, unix.AF_UNSPEC)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("unable to create AF_ROUTE socket: %v", err)
|
return fmt.Errorf("unable to create AF_ROUTE socket: %v", err)
|
||||||
@@ -451,13 +481,13 @@ func (rm *tun) delRoute(prefix netip.Prefix) error {
|
|||||||
route.Addrs = []netroute.Addr{
|
route.Addrs = []netroute.Addr{
|
||||||
unix.RTAX_DST: &netroute.Inet4Addr{IP: prefix.Masked().Addr().As4()},
|
unix.RTAX_DST: &netroute.Inet4Addr{IP: prefix.Masked().Addr().As4()},
|
||||||
unix.RTAX_NETMASK: &netroute.Inet4Addr{IP: prefixToMask(prefix).As4()},
|
unix.RTAX_NETMASK: &netroute.Inet4Addr{IP: prefixToMask(prefix).As4()},
|
||||||
unix.RTAX_GATEWAY: rm.linkAddr,
|
unix.RTAX_GATEWAY: gateway,
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
route.Addrs = []netroute.Addr{
|
route.Addrs = []netroute.Addr{
|
||||||
unix.RTAX_DST: &netroute.Inet6Addr{IP: prefix.Masked().Addr().As16()},
|
unix.RTAX_DST: &netroute.Inet6Addr{IP: prefix.Masked().Addr().As16()},
|
||||||
unix.RTAX_NETMASK: &netroute.Inet6Addr{IP: prefixToMask(prefix).As16()},
|
unix.RTAX_NETMASK: &netroute.Inet6Addr{IP: prefixToMask(prefix).As16()},
|
||||||
unix.RTAX_GATEWAY: rm.linkAddr,
|
unix.RTAX_GATEWAY: gateway,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -465,7 +495,6 @@ func (rm *tun) delRoute(prefix netip.Prefix) error {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed to create route.RouteMessage: %w", err)
|
return fmt.Errorf("failed to create route.RouteMessage: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
_, err = unix.Write(sock, data[:])
|
_, err = unix.Write(sock, data[:])
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed to write route.RouteMessage to socket: %w", err)
|
return fmt.Errorf("failed to write route.RouteMessage to socket: %w", err)
|
||||||
@@ -474,34 +503,52 @@ func (rm *tun) delRoute(prefix netip.Prefix) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func ioctl(a1, a2, a3 uintptr) error {
|
func (t *tun) Read(to []byte) (int, error) {
|
||||||
_, _, errno := unix.Syscall(unix.SYS_IOCTL, a1, a2, a3)
|
buf := make([]byte, len(to)+4)
|
||||||
if errno != 0 {
|
|
||||||
return errno
|
n, err := t.ReadWriteCloser.Read(buf)
|
||||||
}
|
|
||||||
return nil
|
copy(to, buf[4:])
|
||||||
|
return n - 4, err
|
||||||
}
|
}
|
||||||
|
|
||||||
func prefixToMask(prefix netip.Prefix) netip.Addr {
|
// Write is only valid for single threaded use
|
||||||
bits := prefix.Bits()
|
func (t *tun) Write(from []byte) (int, error) {
|
||||||
if prefix.Addr().Is4() {
|
buf := t.out
|
||||||
// Create IPv4 netmask from prefix length
|
if cap(buf) < len(from)+4 {
|
||||||
mask := ^uint32(0) << (32 - bits)
|
buf = make([]byte, len(from)+4)
|
||||||
return netip.AddrFrom4([4]byte{
|
t.out = buf
|
||||||
byte(mask >> 24),
|
|
||||||
byte(mask >> 16),
|
|
||||||
byte(mask >> 8),
|
|
||||||
byte(mask),
|
|
||||||
})
|
|
||||||
} else {
|
|
||||||
// Create IPv6 netmask from prefix length
|
|
||||||
var mask [16]byte
|
|
||||||
for i := 0; i < bits/8; i++ {
|
|
||||||
mask[i] = 0xff
|
|
||||||
}
|
|
||||||
if bits%8 != 0 {
|
|
||||||
mask[bits/8] = ^byte(0) << (8 - bits%8)
|
|
||||||
}
|
|
||||||
return netip.AddrFrom16(mask)
|
|
||||||
}
|
}
|
||||||
|
buf = buf[:len(from)+4]
|
||||||
|
|
||||||
|
if len(from) == 0 {
|
||||||
|
return 0, syscall.EIO
|
||||||
|
}
|
||||||
|
|
||||||
|
// Determine the IP Family for the NULL L2 Header
|
||||||
|
ipVer := from[0] >> 4
|
||||||
|
if ipVer == 4 {
|
||||||
|
buf[3] = syscall.AF_INET
|
||||||
|
} else if ipVer == 6 {
|
||||||
|
buf[3] = syscall.AF_INET6
|
||||||
|
} else {
|
||||||
|
return 0, fmt.Errorf("unable to determine IP version from packet")
|
||||||
|
}
|
||||||
|
|
||||||
|
copy(buf[4:], from)
|
||||||
|
|
||||||
|
n, err := t.ReadWriteCloser.Write(buf)
|
||||||
|
return n - 4, err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *tun) Networks() []netip.Prefix {
|
||||||
|
return t.vpnNetworks
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *tun) Name() string {
|
||||||
|
return t.Device
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *tun) NewMultiQueueReader() (io.ReadWriteCloser, error) {
|
||||||
|
return nil, fmt.Errorf("TODO: multiqueue not implemented for darwin")
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,77 +1,284 @@
|
|||||||
//go:build freebsd && !e2e_testing
|
//go:build !e2e_testing
|
||||||
// +build freebsd,!e2e_testing
|
// +build !e2e_testing
|
||||||
|
|
||||||
package overlay
|
package overlay
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"bytes"
|
||||||
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
|
"io/fs"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
"os/exec"
|
"sync/atomic"
|
||||||
"strconv"
|
|
||||||
"strings"
|
|
||||||
"syscall"
|
"syscall"
|
||||||
|
"time"
|
||||||
"unsafe"
|
"unsafe"
|
||||||
|
|
||||||
|
"github.com/gaissmai/bart"
|
||||||
"github.com/sirupsen/logrus"
|
"github.com/sirupsen/logrus"
|
||||||
"github.com/slackhq/nebula/config"
|
"github.com/slackhq/nebula/config"
|
||||||
|
"github.com/slackhq/nebula/routing"
|
||||||
"github.com/slackhq/nebula/util"
|
"github.com/slackhq/nebula/util"
|
||||||
|
netroute "golang.org/x/net/route"
|
||||||
"golang.org/x/sys/unix"
|
"golang.org/x/sys/unix"
|
||||||
wgtun "golang.zx2c4.com/wireguard/tun"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
type tun struct{}
|
const (
|
||||||
|
// FIODGNAME is defined in sys/sys/filio.h on FreeBSD
|
||||||
|
// For 32-bit systems, use FIODGNAME_32 (not defined in this file: 0x80086678)
|
||||||
|
FIODGNAME = 0x80106678
|
||||||
|
TUNSIFMODE = 0x8004745e
|
||||||
|
TUNSIFHEAD = 0x80047460
|
||||||
|
OSIOCAIFADDR_IN6 = 0x8088691b
|
||||||
|
IN6_IFF_NODAD = 0x0020
|
||||||
|
)
|
||||||
|
|
||||||
|
type fiodgnameArg struct {
|
||||||
|
length int32
|
||||||
|
pad [4]byte
|
||||||
|
buf unsafe.Pointer
|
||||||
|
}
|
||||||
|
|
||||||
// ifreqRename is used for renaming network interfaces on FreeBSD
|
|
||||||
type ifreqRename struct {
|
type ifreqRename struct {
|
||||||
Name [unix.IFNAMSIZ]byte
|
Name [unix.IFNAMSIZ]byte
|
||||||
Data uintptr
|
Data uintptr
|
||||||
}
|
}
|
||||||
|
|
||||||
func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ []netip.Prefix) (*wgTun, error) {
|
type ifreqDestroy struct {
|
||||||
return nil, fmt.Errorf("newTunFromFd not supported on FreeBSD")
|
Name [unix.IFNAMSIZ]byte
|
||||||
|
pad [16]byte
|
||||||
}
|
}
|
||||||
|
|
||||||
func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, _ bool) (*wgTun, error) {
|
type ifReq struct {
|
||||||
deviceName := c.GetString("tun.dev", "tun")
|
Name [unix.IFNAMSIZ]byte
|
||||||
mtu := c.GetInt("tun.mtu", DefaultMTU)
|
Flags uint16
|
||||||
|
}
|
||||||
|
|
||||||
// Create WireGuard TUN device
|
type ifreqMTU struct {
|
||||||
tunDevice, err := wgtun.CreateTUN(deviceName, mtu)
|
Name [unix.IFNAMSIZ]byte
|
||||||
if err != nil {
|
MTU int32
|
||||||
return nil, fmt.Errorf("failed to create TUN device: %w", err)
|
}
|
||||||
|
|
||||||
|
type addrLifetime struct {
|
||||||
|
Expire uint64
|
||||||
|
Preferred uint64
|
||||||
|
Vltime uint32
|
||||||
|
Pltime uint32
|
||||||
|
}
|
||||||
|
|
||||||
|
type ifreqAlias4 struct {
|
||||||
|
Name [unix.IFNAMSIZ]byte
|
||||||
|
Addr unix.RawSockaddrInet4
|
||||||
|
DstAddr unix.RawSockaddrInet4
|
||||||
|
MaskAddr unix.RawSockaddrInet4
|
||||||
|
VHid uint32
|
||||||
|
}
|
||||||
|
|
||||||
|
type ifreqAlias6 struct {
|
||||||
|
Name [unix.IFNAMSIZ]byte
|
||||||
|
Addr unix.RawSockaddrInet6
|
||||||
|
DstAddr unix.RawSockaddrInet6
|
||||||
|
PrefixMask unix.RawSockaddrInet6
|
||||||
|
Flags uint32
|
||||||
|
Lifetime addrLifetime
|
||||||
|
VHid uint32
|
||||||
|
}
|
||||||
|
|
||||||
|
type tun struct {
|
||||||
|
Device string
|
||||||
|
vpnNetworks []netip.Prefix
|
||||||
|
MTU int
|
||||||
|
Routes atomic.Pointer[[]Route]
|
||||||
|
routeTree atomic.Pointer[bart.Table[routing.Gateways]]
|
||||||
|
linkAddr *netroute.LinkAddr
|
||||||
|
l *logrus.Logger
|
||||||
|
devFd int
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *tun) Read(to []byte) (int, error) {
|
||||||
|
// use readv() to read from the tunnel device, to eliminate the need for copying the buffer
|
||||||
|
if t.devFd < 0 {
|
||||||
|
return -1, syscall.EINVAL
|
||||||
}
|
}
|
||||||
|
|
||||||
// Get the actual device name
|
// first 4 bytes is protocol family, in network byte order
|
||||||
actualName, err := tunDevice.Name()
|
head := make([]byte, 4)
|
||||||
|
|
||||||
|
iovecs := []syscall.Iovec{
|
||||||
|
{&head[0], 4},
|
||||||
|
{&to[0], uint64(len(to))},
|
||||||
|
}
|
||||||
|
|
||||||
|
n, _, errno := syscall.Syscall(syscall.SYS_READV, uintptr(t.devFd), uintptr(unsafe.Pointer(&iovecs[0])), uintptr(2))
|
||||||
|
|
||||||
|
var err error
|
||||||
|
if errno != 0 {
|
||||||
|
err = syscall.Errno(errno)
|
||||||
|
} else {
|
||||||
|
err = nil
|
||||||
|
}
|
||||||
|
// fix bytes read number to exclude header
|
||||||
|
bytesRead := int(n)
|
||||||
|
if bytesRead < 0 {
|
||||||
|
return bytesRead, err
|
||||||
|
} else if bytesRead < 4 {
|
||||||
|
return 0, err
|
||||||
|
} else {
|
||||||
|
return bytesRead - 4, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Write is only valid for single threaded use
|
||||||
|
func (t *tun) Write(from []byte) (int, error) {
|
||||||
|
// use writev() to write to the tunnel device, to eliminate the need for copying the buffer
|
||||||
|
if t.devFd < 0 {
|
||||||
|
return -1, syscall.EINVAL
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(from) <= 1 {
|
||||||
|
return 0, syscall.EIO
|
||||||
|
}
|
||||||
|
ipVer := from[0] >> 4
|
||||||
|
var head []byte
|
||||||
|
// first 4 bytes is protocol family, in network byte order
|
||||||
|
if ipVer == 4 {
|
||||||
|
head = []byte{0, 0, 0, syscall.AF_INET}
|
||||||
|
} else if ipVer == 6 {
|
||||||
|
head = []byte{0, 0, 0, syscall.AF_INET6}
|
||||||
|
} else {
|
||||||
|
return 0, fmt.Errorf("unable to determine IP version from packet")
|
||||||
|
}
|
||||||
|
iovecs := []syscall.Iovec{
|
||||||
|
{&head[0], 4},
|
||||||
|
{&from[0], uint64(len(from))},
|
||||||
|
}
|
||||||
|
|
||||||
|
n, _, errno := syscall.Syscall(syscall.SYS_WRITEV, uintptr(t.devFd), uintptr(unsafe.Pointer(&iovecs[0])), uintptr(2))
|
||||||
|
|
||||||
|
var err error
|
||||||
|
if errno != 0 {
|
||||||
|
err = syscall.Errno(errno)
|
||||||
|
} else {
|
||||||
|
err = nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return int(n) - 4, err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *tun) Close() error {
|
||||||
|
if t.devFd >= 0 {
|
||||||
|
err := syscall.Close(t.devFd)
|
||||||
|
if err != nil {
|
||||||
|
t.l.WithError(err).Error("Error closing device")
|
||||||
|
}
|
||||||
|
t.devFd = -1
|
||||||
|
|
||||||
|
c := make(chan struct{})
|
||||||
|
go func() {
|
||||||
|
// destroying the interface can block if a read() is still pending. Do this asynchronously.
|
||||||
|
defer close(c)
|
||||||
|
s, err := syscall.Socket(syscall.AF_INET, syscall.SOCK_DGRAM, syscall.IPPROTO_IP)
|
||||||
|
if err == nil {
|
||||||
|
defer syscall.Close(s)
|
||||||
|
ifreq := ifreqDestroy{Name: t.deviceBytes()}
|
||||||
|
err = ioctl(uintptr(s), syscall.SIOCIFDESTROY, uintptr(unsafe.Pointer(&ifreq)))
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
t.l.WithError(err).Error("Error destroying tunnel")
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
// wait up to 1 second so we start blocking at the ioctl
|
||||||
|
select {
|
||||||
|
case <-c:
|
||||||
|
case <-time.After(1 * time.Second):
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ []netip.Prefix) (*tun, error) {
|
||||||
|
return nil, fmt.Errorf("newTunFromFd not supported in FreeBSD")
|
||||||
|
}
|
||||||
|
|
||||||
|
func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, _ bool) (*tun, error) {
|
||||||
|
// Try to open existing tun device
|
||||||
|
var fd int
|
||||||
|
var err error
|
||||||
|
deviceName := c.GetString("tun.dev", "")
|
||||||
|
if deviceName != "" {
|
||||||
|
fd, err = syscall.Open("/dev/"+deviceName, syscall.O_RDWR, 0)
|
||||||
|
}
|
||||||
|
if errors.Is(err, fs.ErrNotExist) || deviceName == "" {
|
||||||
|
// If the device doesn't already exist, request a new one and rename it
|
||||||
|
fd, err = syscall.Open("/dev/tun", syscall.O_RDWR, 0)
|
||||||
|
}
|
||||||
if err != nil {
|
if err != nil {
|
||||||
tunDevice.Close()
|
return nil, err
|
||||||
return nil, fmt.Errorf("failed to get TUN device name: %w", err)
|
}
|
||||||
|
|
||||||
|
// Read the name of the interface
|
||||||
|
var name [16]byte
|
||||||
|
arg := fiodgnameArg{length: 16, buf: unsafe.Pointer(&name)}
|
||||||
|
ctrlErr := ioctl(uintptr(fd), FIODGNAME, uintptr(unsafe.Pointer(&arg)))
|
||||||
|
|
||||||
|
if ctrlErr == nil {
|
||||||
|
// set broadcast mode and multicast
|
||||||
|
ifmode := uint32(unix.IFF_BROADCAST | unix.IFF_MULTICAST)
|
||||||
|
ctrlErr = ioctl(uintptr(fd), TUNSIFMODE, uintptr(unsafe.Pointer(&ifmode)))
|
||||||
|
}
|
||||||
|
|
||||||
|
if ctrlErr == nil {
|
||||||
|
// turn on link-layer mode, to support ipv6
|
||||||
|
ifhead := uint32(1)
|
||||||
|
ctrlErr = ioctl(uintptr(fd), TUNSIFHEAD, uintptr(unsafe.Pointer(&ifhead)))
|
||||||
|
}
|
||||||
|
|
||||||
|
if ctrlErr != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
ifName := string(bytes.TrimRight(name[:], "\x00"))
|
||||||
|
if deviceName == "" {
|
||||||
|
deviceName = ifName
|
||||||
}
|
}
|
||||||
|
|
||||||
// If the name doesn't match the desired interface name, rename it now
|
// If the name doesn't match the desired interface name, rename it now
|
||||||
if actualName != deviceName && deviceName != "" && deviceName != "tun" {
|
if ifName != deviceName {
|
||||||
if err := renameInterface(actualName, deviceName); err != nil {
|
s, err := unix.Socket(unix.AF_INET, unix.SOCK_DGRAM, unix.IPPROTO_IP)
|
||||||
tunDevice.Close()
|
if err != nil {
|
||||||
return nil, fmt.Errorf("failed to rename interface from %s to %s: %w", actualName, deviceName, err)
|
return nil, err
|
||||||
}
|
}
|
||||||
actualName = deviceName
|
defer syscall.Close(s)
|
||||||
|
|
||||||
|
fd := uintptr(s)
|
||||||
|
|
||||||
|
var fromName [16]byte
|
||||||
|
var toName [16]byte
|
||||||
|
copy(fromName[:], ifName)
|
||||||
|
copy(toName[:], deviceName)
|
||||||
|
|
||||||
|
ifrr := ifreqRename{
|
||||||
|
Name: fromName,
|
||||||
|
Data: uintptr(unsafe.Pointer(&toName)),
|
||||||
|
}
|
||||||
|
|
||||||
|
// Set the device name
|
||||||
|
ioctl(fd, syscall.SIOCSIFNAME, uintptr(unsafe.Pointer(&ifrr)))
|
||||||
}
|
}
|
||||||
|
|
||||||
t := &wgTun{
|
t := &tun{
|
||||||
tunDevice: tunDevice,
|
Device: deviceName,
|
||||||
vpnNetworks: vpnNetworks,
|
vpnNetworks: vpnNetworks,
|
||||||
MaxMTU: mtu,
|
MTU: c.GetInt("tun.mtu", DefaultMTU),
|
||||||
DefaultMTU: mtu,
|
|
||||||
l: l,
|
l: l,
|
||||||
|
devFd: fd,
|
||||||
}
|
}
|
||||||
|
|
||||||
// Create FreeBSD-specific route manager
|
|
||||||
t.routeManager = &tun{}
|
|
||||||
|
|
||||||
err = t.reload(c, true)
|
err = t.reload(c, true)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
tunDevice.Close()
|
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -82,86 +289,180 @@ func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, _ bool) (
|
|||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
l.WithField("name", actualName).Info("Created WireGuard TUN device")
|
|
||||||
|
|
||||||
return t, nil
|
return t, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (rm *tun) Activate(t *wgTun) error {
|
func (t *tun) addIp(cidr netip.Prefix) error {
|
||||||
name, err := t.tunDevice.Name()
|
if cidr.Addr().Is4() {
|
||||||
if err != nil {
|
ifr := ifreqAlias4{
|
||||||
return fmt.Errorf("failed to get device name: %w", err)
|
Name: t.deviceBytes(),
|
||||||
|
Addr: unix.RawSockaddrInet4{
|
||||||
|
Len: unix.SizeofSockaddrInet4,
|
||||||
|
Family: unix.AF_INET,
|
||||||
|
Addr: cidr.Addr().As4(),
|
||||||
|
},
|
||||||
|
DstAddr: unix.RawSockaddrInet4{
|
||||||
|
Len: unix.SizeofSockaddrInet4,
|
||||||
|
Family: unix.AF_INET,
|
||||||
|
Addr: getBroadcast(cidr).As4(),
|
||||||
|
},
|
||||||
|
MaskAddr: unix.RawSockaddrInet4{
|
||||||
|
Len: unix.SizeofSockaddrInet4,
|
||||||
|
Family: unix.AF_INET,
|
||||||
|
Addr: prefixToMask(cidr).As4(),
|
||||||
|
},
|
||||||
|
VHid: 0,
|
||||||
|
}
|
||||||
|
s, err := unix.Socket(unix.AF_INET, unix.SOCK_DGRAM, unix.IPPROTO_IP)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
defer syscall.Close(s)
|
||||||
|
// Note: unix.SIOCAIFADDR corresponds to FreeBSD's OSIOCAIFADDR
|
||||||
|
if err := ioctl(uintptr(s), unix.SIOCAIFADDR, uintptr(unsafe.Pointer(&ifr))); err != nil {
|
||||||
|
return fmt.Errorf("failed to set tun address %s: %s", cidr.Addr().String(), err)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Set the MTU
|
if cidr.Addr().Is6() {
|
||||||
rm.SetMTU(t, t.MaxMTU)
|
ifr := ifreqAlias6{
|
||||||
|
Name: t.deviceBytes(),
|
||||||
|
Addr: unix.RawSockaddrInet6{
|
||||||
|
Len: unix.SizeofSockaddrInet6,
|
||||||
|
Family: unix.AF_INET6,
|
||||||
|
Addr: cidr.Addr().As16(),
|
||||||
|
},
|
||||||
|
PrefixMask: unix.RawSockaddrInet6{
|
||||||
|
Len: unix.SizeofSockaddrInet6,
|
||||||
|
Family: unix.AF_INET6,
|
||||||
|
Addr: prefixToMask(cidr).As16(),
|
||||||
|
},
|
||||||
|
Lifetime: addrLifetime{
|
||||||
|
Expire: 0,
|
||||||
|
Preferred: 0,
|
||||||
|
Vltime: 0xffffffff,
|
||||||
|
Pltime: 0xffffffff,
|
||||||
|
},
|
||||||
|
Flags: IN6_IFF_NODAD,
|
||||||
|
}
|
||||||
|
s, err := syscall.Socket(syscall.AF_INET6, syscall.SOCK_DGRAM, syscall.IPPROTO_IP)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
defer syscall.Close(s)
|
||||||
|
|
||||||
// Add IP addresses
|
if err := ioctl(uintptr(s), OSIOCAIFADDR_IN6, uintptr(unsafe.Pointer(&ifr))); err != nil {
|
||||||
for _, network := range t.vpnNetworks {
|
return fmt.Errorf("failed to set tun address %s: %s", cidr.Addr().String(), err)
|
||||||
if err := rm.addIP(t, name, network); err != nil {
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return fmt.Errorf("unknown address type %v", cidr)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *tun) Activate() error {
|
||||||
|
// Setup our default MTU
|
||||||
|
err := t.setMTU()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
linkAddr, err := getLinkAddr(t.Device)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if linkAddr == nil {
|
||||||
|
return fmt.Errorf("unable to discover link_addr for tun interface")
|
||||||
|
}
|
||||||
|
t.linkAddr = linkAddr
|
||||||
|
|
||||||
|
for i := range t.vpnNetworks {
|
||||||
|
err := t.addIp(t.vpnNetworks[i])
|
||||||
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Bring up the interface
|
return t.addRoutes(false)
|
||||||
if err := runCommandBSD("ifconfig", name, "up"); err != nil {
|
}
|
||||||
return fmt.Errorf("failed to bring up interface: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Set the routes
|
func (t *tun) setMTU() error {
|
||||||
if err := rm.AddRoutes(t, false); err != nil {
|
// Set the MTU on the device
|
||||||
|
s, err := unix.Socket(unix.AF_INET, unix.SOCK_DGRAM, unix.IPPROTO_IP)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
defer syscall.Close(s)
|
||||||
|
|
||||||
|
ifm := ifreqMTU{Name: t.deviceBytes(), MTU: int32(t.MTU)}
|
||||||
|
err = ioctl(uintptr(s), unix.SIOCSIFMTU, uintptr(unsafe.Pointer(&ifm)))
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *tun) reload(c *config.C, initial bool) error {
|
||||||
|
change, routes, err := getAllRoutesFromConfig(c, t.vpnNetworks, initial)
|
||||||
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if !initial && !change {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
routeTree, err := makeRouteTree(t.l, routes, false)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Teach nebula how to handle the routes before establishing them in the system table
|
||||||
|
oldRoutes := t.Routes.Swap(&routes)
|
||||||
|
t.routeTree.Store(routeTree)
|
||||||
|
|
||||||
|
if !initial {
|
||||||
|
// Remove first, if the system removes a wanted route hopefully it will be re-added next
|
||||||
|
err := t.removeRoutes(findRemovedRoutes(routes, *oldRoutes))
|
||||||
|
if err != nil {
|
||||||
|
util.LogWithContextIfNeeded("Failed to remove routes", err, t.l)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Ensure any routes we actually want are installed
|
||||||
|
err = t.addRoutes(true)
|
||||||
|
if err != nil {
|
||||||
|
// Catch any stray logs
|
||||||
|
util.LogWithContextIfNeeded("Failed to add routes", err, t.l)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (rm *tun) SetMTU(t *wgTun, mtu int) {
|
func (t *tun) RoutesFor(ip netip.Addr) routing.Gateways {
|
||||||
name, err := t.tunDevice.Name()
|
r, _ := t.routeTree.Load().Lookup(ip)
|
||||||
if err != nil {
|
return r
|
||||||
t.l.WithError(err).Error("Failed to get device name for MTU set")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := runCommandBSD("ifconfig", name, "mtu", strconv.Itoa(mtu)); err != nil {
|
|
||||||
t.l.WithError(err).Error("Failed to set tun mtu")
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (rm *tun) SetDefaultRoute(t *wgTun, cidr netip.Prefix) error {
|
func (t *tun) Networks() []netip.Prefix {
|
||||||
// On FreeBSD, routes are set via ifconfig and route commands
|
return t.vpnNetworks
|
||||||
return nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (rm *tun) AddRoutes(t *wgTun, logErrors bool) error {
|
func (t *tun) Name() string {
|
||||||
name, err := t.tunDevice.Name()
|
return t.Device
|
||||||
if err != nil {
|
}
|
||||||
return fmt.Errorf("failed to get device name: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
|
func (t *tun) NewMultiQueueReader() (io.ReadWriteCloser, error) {
|
||||||
|
return nil, fmt.Errorf("TODO: multiqueue not implemented for freebsd")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *tun) addRoutes(logErrors bool) error {
|
||||||
routes := *t.Routes.Load()
|
routes := *t.Routes.Load()
|
||||||
for _, r := range routes {
|
for _, r := range routes {
|
||||||
if !r.Install {
|
if len(r.Via) == 0 || !r.Install {
|
||||||
|
// We don't allow route MTUs so only install routes with a via
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
// Add route using route command
|
err := addRoute(r.Cidr, t.linkAddr)
|
||||||
args := []string{"add"}
|
|
||||||
|
|
||||||
if r.Cidr.Addr().Is6() {
|
|
||||||
args = append(args, "-inet6")
|
|
||||||
} else {
|
|
||||||
args = append(args, "-inet")
|
|
||||||
}
|
|
||||||
|
|
||||||
args = append(args, r.Cidr.String(), "-interface", name)
|
|
||||||
|
|
||||||
if r.Metric > 0 {
|
|
||||||
// FreeBSD doesn't support route metrics directly like Linux
|
|
||||||
t.l.WithField("route", r).Warn("Route metrics are not fully supported on FreeBSD")
|
|
||||||
}
|
|
||||||
|
|
||||||
err := runCommandBSD("route", args...)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
retErr := util.NewContextualError("Failed to add route", map[string]any{"route": r}, err)
|
retErr := util.NewContextualError("Failed to add route", map[string]any{"route": r}, err)
|
||||||
if logErrors {
|
if logErrors {
|
||||||
@@ -177,99 +478,142 @@ func (rm *tun) AddRoutes(t *wgTun, logErrors bool) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (rm *tun) RemoveRoutes(t *wgTun, routes []Route) {
|
func (t *tun) removeRoutes(routes []Route) error {
|
||||||
name, err := t.tunDevice.Name()
|
|
||||||
if err != nil {
|
|
||||||
t.l.WithError(err).Error("Failed to get device name for route removal")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, r := range routes {
|
for _, r := range routes {
|
||||||
if !r.Install {
|
if !r.Install {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
args := []string{"delete"}
|
err := delRoute(r.Cidr, t.linkAddr)
|
||||||
|
|
||||||
if r.Cidr.Addr().Is6() {
|
|
||||||
args = append(args, "-inet6")
|
|
||||||
} else {
|
|
||||||
args = append(args, "-inet")
|
|
||||||
}
|
|
||||||
|
|
||||||
args = append(args, r.Cidr.String(), "-interface", name)
|
|
||||||
|
|
||||||
err := runCommandBSD("route", args...)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.l.WithError(err).WithField("route", r).Error("Failed to remove route")
|
t.l.WithError(err).WithField("route", r).Error("Failed to remove route")
|
||||||
} else {
|
} else {
|
||||||
t.l.WithField("route", r).Info("Removed route")
|
t.l.WithField("route", r).Info("Removed route")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (rm *tun) NewMultiQueueReader(t *wgTun) (io.ReadWriteCloser, error) {
|
func (t *tun) deviceBytes() (o [16]byte) {
|
||||||
// FreeBSD doesn't support multi-queue TUN devices in the same way as Linux
|
for i, c := range t.Device {
|
||||||
// Return a reader that wraps the same device
|
o[i] = byte(c)
|
||||||
return &wgTunReader{
|
}
|
||||||
parent: t,
|
return
|
||||||
tunDevice: t.tunDevice,
|
|
||||||
offset: 0,
|
|
||||||
l: t.l,
|
|
||||||
}, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (rm *tun) addIP(t *wgTun, name string, network netip.Prefix) error {
|
func addRoute(prefix netip.Prefix, gateway netroute.Addr) error {
|
||||||
addr := network.Addr()
|
sock, err := unix.Socket(unix.AF_ROUTE, unix.SOCK_RAW, unix.AF_UNSPEC)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("unable to create AF_ROUTE socket: %v", err)
|
||||||
|
}
|
||||||
|
defer unix.Close(sock)
|
||||||
|
|
||||||
if addr.Is4() {
|
route := &netroute.RouteMessage{
|
||||||
// For IPv4: ifconfig tun0 10.0.0.1/24
|
Version: unix.RTM_VERSION,
|
||||||
if err := runCommandBSD("ifconfig", name, network.String()); err != nil {
|
Type: unix.RTM_ADD,
|
||||||
return fmt.Errorf("failed to add IPv4 address: %w", err)
|
Flags: unix.RTF_UP,
|
||||||
|
Seq: 1,
|
||||||
|
}
|
||||||
|
|
||||||
|
if prefix.Addr().Is4() {
|
||||||
|
route.Addrs = []netroute.Addr{
|
||||||
|
unix.RTAX_DST: &netroute.Inet4Addr{IP: prefix.Masked().Addr().As4()},
|
||||||
|
unix.RTAX_NETMASK: &netroute.Inet4Addr{IP: prefixToMask(prefix).As4()},
|
||||||
|
unix.RTAX_GATEWAY: gateway,
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
// For IPv6: ifconfig tun0 inet6 add 2001:db8::1/64
|
route.Addrs = []netroute.Addr{
|
||||||
if err := runCommandBSD("ifconfig", name, "inet6", "add", network.String()); err != nil {
|
unix.RTAX_DST: &netroute.Inet6Addr{IP: prefix.Masked().Addr().As16()},
|
||||||
return fmt.Errorf("failed to add IPv6 address: %w", err)
|
unix.RTAX_NETMASK: &netroute.Inet6Addr{IP: prefixToMask(prefix).As16()},
|
||||||
|
unix.RTAX_GATEWAY: gateway,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
data, err := route.Marshal()
|
||||||
}
|
|
||||||
|
|
||||||
func runCommandBSD(name string, args ...string) error {
|
|
||||||
cmd := exec.Command(name, args...)
|
|
||||||
output, err := cmd.CombinedOutput()
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("%s %s failed: %w\nOutput: %s", name, strings.Join(args, " "), err, string(output))
|
return fmt.Errorf("failed to create route.RouteMessage: %w", err)
|
||||||
}
|
}
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func renameInterface(fromName, toName string) error {
|
_, err = unix.Write(sock, data[:])
|
||||||
s, err := unix.Socket(unix.AF_INET, unix.SOCK_DGRAM, unix.IPPROTO_IP)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed to create socket: %w", err)
|
if errors.Is(err, unix.EEXIST) {
|
||||||
}
|
// Try to do a change
|
||||||
defer syscall.Close(s)
|
route.Type = unix.RTM_CHANGE
|
||||||
|
data, err = route.Marshal()
|
||||||
fd := uintptr(s)
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to create route.RouteMessage for change: %w", err)
|
||||||
var fromNameBytes [unix.IFNAMSIZ]byte
|
}
|
||||||
var toNameBytes [unix.IFNAMSIZ]byte
|
_, err = unix.Write(sock, data[:])
|
||||||
copy(fromNameBytes[:], fromName)
|
fmt.Println("DOING CHANGE")
|
||||||
copy(toNameBytes[:], toName)
|
return err
|
||||||
|
}
|
||||||
ifrr := ifreqRename{
|
return fmt.Errorf("failed to write route.RouteMessage to socket: %w", err)
|
||||||
Name: fromNameBytes,
|
|
||||||
Data: uintptr(unsafe.Pointer(&toNameBytes)),
|
|
||||||
}
|
|
||||||
|
|
||||||
// Set the device name using SIOCSIFNAME ioctl
|
|
||||||
_, _, errno := syscall.Syscall(syscall.SYS_IOCTL, fd, syscall.SIOCSIFNAME, uintptr(unsafe.Pointer(&ifrr)))
|
|
||||||
if errno != 0 {
|
|
||||||
return fmt.Errorf("SIOCSIFNAME ioctl failed: %w", errno)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func delRoute(prefix netip.Prefix, gateway netroute.Addr) error {
|
||||||
|
sock, err := unix.Socket(unix.AF_ROUTE, unix.SOCK_RAW, unix.AF_UNSPEC)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("unable to create AF_ROUTE socket: %v", err)
|
||||||
|
}
|
||||||
|
defer unix.Close(sock)
|
||||||
|
|
||||||
|
route := netroute.RouteMessage{
|
||||||
|
Version: unix.RTM_VERSION,
|
||||||
|
Type: unix.RTM_DELETE,
|
||||||
|
Seq: 1,
|
||||||
|
}
|
||||||
|
|
||||||
|
if prefix.Addr().Is4() {
|
||||||
|
route.Addrs = []netroute.Addr{
|
||||||
|
unix.RTAX_DST: &netroute.Inet4Addr{IP: prefix.Masked().Addr().As4()},
|
||||||
|
unix.RTAX_NETMASK: &netroute.Inet4Addr{IP: prefixToMask(prefix).As4()},
|
||||||
|
unix.RTAX_GATEWAY: gateway,
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
route.Addrs = []netroute.Addr{
|
||||||
|
unix.RTAX_DST: &netroute.Inet6Addr{IP: prefix.Masked().Addr().As16()},
|
||||||
|
unix.RTAX_NETMASK: &netroute.Inet6Addr{IP: prefixToMask(prefix).As16()},
|
||||||
|
unix.RTAX_GATEWAY: gateway,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
data, err := route.Marshal()
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to create route.RouteMessage: %w", err)
|
||||||
|
}
|
||||||
|
_, err = unix.Write(sock, data[:])
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to write route.RouteMessage to socket: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// getLinkAddr Gets the link address for the interface of the given name
|
||||||
|
func getLinkAddr(name string) (*netroute.LinkAddr, error) {
|
||||||
|
rib, err := netroute.FetchRIB(unix.AF_UNSPEC, unix.NET_RT_IFLIST, 0)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
msgs, err := netroute.ParseRIB(unix.NET_RT_IFLIST, rib)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, m := range msgs {
|
||||||
|
switch m := m.(type) {
|
||||||
|
case *netroute.InterfaceMessage:
|
||||||
|
if m.Name == name {
|
||||||
|
sa, ok := m.Addrs[unix.RTAX_IFP].(*netroute.LinkAddr)
|
||||||
|
if ok {
|
||||||
|
return sa, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
|||||||
@@ -1,5 +1,5 @@
|
|||||||
//go:build linux && !android && !e2e_testing
|
//go:build !android && !e2e_testing
|
||||||
// +build linux,!android,!e2e_testing
|
// +build !android,!e2e_testing
|
||||||
|
|
||||||
package overlay
|
package overlay
|
||||||
|
|
||||||
@@ -9,105 +9,133 @@ import (
|
|||||||
"net"
|
"net"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
"os"
|
"os"
|
||||||
|
"strings"
|
||||||
|
"sync/atomic"
|
||||||
"time"
|
"time"
|
||||||
"unsafe"
|
"unsafe"
|
||||||
|
|
||||||
|
"github.com/gaissmai/bart"
|
||||||
"github.com/sirupsen/logrus"
|
"github.com/sirupsen/logrus"
|
||||||
"github.com/slackhq/nebula/config"
|
"github.com/slackhq/nebula/config"
|
||||||
"github.com/slackhq/nebula/routing"
|
"github.com/slackhq/nebula/routing"
|
||||||
"github.com/slackhq/nebula/util"
|
"github.com/slackhq/nebula/util"
|
||||||
"github.com/vishvananda/netlink"
|
"github.com/vishvananda/netlink"
|
||||||
"golang.org/x/sys/unix"
|
"golang.org/x/sys/unix"
|
||||||
wgtun "golang.zx2c4.com/wireguard/tun"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
type tun struct {
|
type tun struct {
|
||||||
deviceIndex int
|
io.ReadWriteCloser
|
||||||
ioctlFd uintptr
|
fd int
|
||||||
txQueueLen int
|
Device string
|
||||||
|
vpnNetworks []netip.Prefix
|
||||||
|
MaxMTU int
|
||||||
|
DefaultMTU int
|
||||||
|
TXQueueLen int
|
||||||
|
deviceIndex int
|
||||||
|
ioctlFd uintptr
|
||||||
|
|
||||||
|
Routes atomic.Pointer[[]Route]
|
||||||
|
routeTree atomic.Pointer[bart.Table[routing.Gateways]]
|
||||||
|
routeChan chan struct{}
|
||||||
useSystemRoutes bool
|
useSystemRoutes bool
|
||||||
useSystemRoutesBufferSize int
|
useSystemRoutesBufferSize int
|
||||||
|
|
||||||
|
l *logrus.Logger
|
||||||
}
|
}
|
||||||
|
|
||||||
func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, multiqueue bool) (*wgTun, error) {
|
func (t *tun) Networks() []netip.Prefix {
|
||||||
deviceName := c.GetString("tun.dev", "")
|
return t.vpnNetworks
|
||||||
mtu := c.GetInt("tun.mtu", DefaultMTU)
|
|
||||||
|
|
||||||
// Create WireGuard TUN device
|
|
||||||
tunDevice, err := wgtun.CreateTUN(deviceName, mtu)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("failed to create TUN device: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Get the actual device name
|
|
||||||
actualName, err := tunDevice.Name()
|
|
||||||
if err != nil {
|
|
||||||
tunDevice.Close()
|
|
||||||
return nil, fmt.Errorf("failed to get TUN device name: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
t := &wgTun{
|
|
||||||
tunDevice: tunDevice,
|
|
||||||
vpnNetworks: vpnNetworks,
|
|
||||||
MaxMTU: mtu,
|
|
||||||
DefaultMTU: mtu,
|
|
||||||
l: l,
|
|
||||||
}
|
|
||||||
|
|
||||||
// Create Linux-specific route manager
|
|
||||||
routeManager := &tun{
|
|
||||||
txQueueLen: c.GetInt("tun.tx_queue", 500),
|
|
||||||
useSystemRoutes: c.GetBool("tun.use_system_route_table", false),
|
|
||||||
useSystemRoutesBufferSize: c.GetInt("tun.use_system_route_table_buffer_size", 0),
|
|
||||||
}
|
|
||||||
t.routeManager = routeManager
|
|
||||||
|
|
||||||
err = t.reload(c, true)
|
|
||||||
if err != nil {
|
|
||||||
tunDevice.Close()
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
c.RegisterReloadCallback(func(c *config.C) {
|
|
||||||
err := t.reload(c, false)
|
|
||||||
if err != nil {
|
|
||||||
util.LogWithContextIfNeeded("failed to reload tun device", err, t.l)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
l.WithField("name", actualName).Info("Created WireGuard TUN device")
|
|
||||||
|
|
||||||
return t, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func newTunFromFd(c *config.C, l *logrus.Logger, deviceFd int, vpnNetworks []netip.Prefix) (*wgTun, error) {
|
type ifReq struct {
|
||||||
// Create TUN device from file descriptor
|
Name [16]byte
|
||||||
|
Flags uint16
|
||||||
|
pad [8]byte
|
||||||
|
}
|
||||||
|
|
||||||
|
type ifreqMTU struct {
|
||||||
|
Name [16]byte
|
||||||
|
MTU int32
|
||||||
|
pad [8]byte
|
||||||
|
}
|
||||||
|
|
||||||
|
type ifreqQLEN struct {
|
||||||
|
Name [16]byte
|
||||||
|
Value int32
|
||||||
|
pad [8]byte
|
||||||
|
}
|
||||||
|
|
||||||
|
func newTunFromFd(c *config.C, l *logrus.Logger, deviceFd int, vpnNetworks []netip.Prefix) (*tun, error) {
|
||||||
file := os.NewFile(uintptr(deviceFd), "/dev/net/tun")
|
file := os.NewFile(uintptr(deviceFd), "/dev/net/tun")
|
||||||
mtu := c.GetInt("tun.mtu", DefaultMTU)
|
|
||||||
tunDevice, err := wgtun.CreateTUNFromFile(file, mtu)
|
t, err := newTunGeneric(c, l, file, vpnNetworks)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("failed to create TUN device from fd: %w", err)
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
t := &wgTun{
|
t.Device = "tun0"
|
||||||
tunDevice: tunDevice,
|
|
||||||
vpnNetworks: vpnNetworks,
|
return t, nil
|
||||||
MaxMTU: mtu,
|
}
|
||||||
DefaultMTU: mtu,
|
|
||||||
l: l,
|
func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, multiqueue bool) (*tun, error) {
|
||||||
|
fd, err := unix.Open("/dev/net/tun", os.O_RDWR, 0)
|
||||||
|
if err != nil {
|
||||||
|
// If /dev/net/tun doesn't exist, try to create it (will happen in docker)
|
||||||
|
if os.IsNotExist(err) {
|
||||||
|
err = os.MkdirAll("/dev/net", 0755)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("/dev/net/tun doesn't exist, failed to mkdir -p /dev/net: %w", err)
|
||||||
|
}
|
||||||
|
err = unix.Mknod("/dev/net/tun", unix.S_IFCHR|0600, int(unix.Mkdev(10, 200)))
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to create /dev/net/tun: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
fd, err = unix.Open("/dev/net/tun", os.O_RDWR, 0)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("created /dev/net/tun, but still failed: %w", err)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Create Linux-specific route manager
|
var req ifReq
|
||||||
routeManager := &tun{
|
req.Flags = uint16(unix.IFF_TUN | unix.IFF_NO_PI)
|
||||||
txQueueLen: c.GetInt("tun.tx_queue", 500),
|
if multiqueue {
|
||||||
|
req.Flags |= unix.IFF_MULTI_QUEUE
|
||||||
|
}
|
||||||
|
copy(req.Name[:], c.GetString("tun.dev", ""))
|
||||||
|
if err = ioctl(uintptr(fd), uintptr(unix.TUNSETIFF), uintptr(unsafe.Pointer(&req))); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
name := strings.Trim(string(req.Name[:]), "\x00")
|
||||||
|
|
||||||
|
file := os.NewFile(uintptr(fd), "/dev/net/tun")
|
||||||
|
t, err := newTunGeneric(c, l, file, vpnNetworks)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
t.Device = name
|
||||||
|
|
||||||
|
return t, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func newTunGeneric(c *config.C, l *logrus.Logger, file *os.File, vpnNetworks []netip.Prefix) (*tun, error) {
|
||||||
|
t := &tun{
|
||||||
|
ReadWriteCloser: file,
|
||||||
|
fd: int(file.Fd()),
|
||||||
|
vpnNetworks: vpnNetworks,
|
||||||
|
TXQueueLen: c.GetInt("tun.tx_queue", 500),
|
||||||
useSystemRoutes: c.GetBool("tun.use_system_route_table", false),
|
useSystemRoutes: c.GetBool("tun.use_system_route_table", false),
|
||||||
useSystemRoutesBufferSize: c.GetInt("tun.use_system_route_table_buffer_size", 0),
|
useSystemRoutesBufferSize: c.GetInt("tun.use_system_route_table_buffer_size", 0),
|
||||||
|
l: l,
|
||||||
}
|
}
|
||||||
t.routeManager = routeManager
|
|
||||||
|
|
||||||
err = t.reload(c, true)
|
err := t.reload(c, true)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
tunDevice.Close()
|
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -121,105 +149,273 @@ func newTunFromFd(c *config.C, l *logrus.Logger, deviceFd int, vpnNetworks []net
|
|||||||
return t, nil
|
return t, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (rm *tun) Activate(t *wgTun) error {
|
func (t *tun) reload(c *config.C, initial bool) error {
|
||||||
name, err := t.tunDevice.Name()
|
routeChange, routes, err := getAllRoutesFromConfig(c, t.vpnNetworks, initial)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed to get device name: %w", err)
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
if t.routeManager.useSystemRoutes {
|
if !initial && !routeChange && !c.HasChanged("tun.mtu") {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
routeTree, err := makeRouteTree(t.l, routes, true)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
oldDefaultMTU := t.DefaultMTU
|
||||||
|
oldMaxMTU := t.MaxMTU
|
||||||
|
newDefaultMTU := c.GetInt("tun.mtu", DefaultMTU)
|
||||||
|
newMaxMTU := newDefaultMTU
|
||||||
|
for i, r := range routes {
|
||||||
|
if r.MTU == 0 {
|
||||||
|
routes[i].MTU = newDefaultMTU
|
||||||
|
}
|
||||||
|
|
||||||
|
if r.MTU > t.MaxMTU {
|
||||||
|
newMaxMTU = r.MTU
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
t.MaxMTU = newMaxMTU
|
||||||
|
t.DefaultMTU = newDefaultMTU
|
||||||
|
|
||||||
|
// Teach nebula how to handle the routes before establishing them in the system table
|
||||||
|
oldRoutes := t.Routes.Swap(&routes)
|
||||||
|
t.routeTree.Store(routeTree)
|
||||||
|
|
||||||
|
if !initial {
|
||||||
|
if oldMaxMTU != newMaxMTU {
|
||||||
|
t.setMTU()
|
||||||
|
t.l.Infof("Set max MTU to %v was %v", t.MaxMTU, oldMaxMTU)
|
||||||
|
}
|
||||||
|
|
||||||
|
if oldDefaultMTU != newDefaultMTU {
|
||||||
|
for i := range t.vpnNetworks {
|
||||||
|
err := t.setDefaultRoute(t.vpnNetworks[i])
|
||||||
|
if err != nil {
|
||||||
|
t.l.Warn(err)
|
||||||
|
} else {
|
||||||
|
t.l.Infof("Set default MTU to %v was %v", t.DefaultMTU, oldDefaultMTU)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Remove first, if the system removes a wanted route hopefully it will be re-added next
|
||||||
|
t.removeRoutes(findRemovedRoutes(routes, *oldRoutes))
|
||||||
|
|
||||||
|
// Ensure any routes we actually want are installed
|
||||||
|
err = t.addRoutes(true)
|
||||||
|
if err != nil {
|
||||||
|
// This should never be called since addRoutes should log its own errors in a reload condition
|
||||||
|
util.LogWithContextIfNeeded("Failed to refresh routes", err, t.l)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *tun) NewMultiQueueReader() (io.ReadWriteCloser, error) {
|
||||||
|
fd, err := unix.Open("/dev/net/tun", os.O_RDWR, 0)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
var req ifReq
|
||||||
|
req.Flags = uint16(unix.IFF_TUN | unix.IFF_NO_PI | unix.IFF_MULTI_QUEUE)
|
||||||
|
copy(req.Name[:], t.Device)
|
||||||
|
if err = ioctl(uintptr(fd), uintptr(unix.TUNSETIFF), uintptr(unsafe.Pointer(&req))); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
file := os.NewFile(uintptr(fd), "/dev/net/tun")
|
||||||
|
|
||||||
|
return file, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *tun) RoutesFor(ip netip.Addr) routing.Gateways {
|
||||||
|
r, _ := t.routeTree.Load().Lookup(ip)
|
||||||
|
return r
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *tun) Write(b []byte) (int, error) {
|
||||||
|
var nn int
|
||||||
|
maximum := len(b)
|
||||||
|
|
||||||
|
for {
|
||||||
|
n, err := unix.Write(t.fd, b[nn:maximum])
|
||||||
|
if n > 0 {
|
||||||
|
nn += n
|
||||||
|
}
|
||||||
|
if nn == len(b) {
|
||||||
|
return nn, err
|
||||||
|
}
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
return nn, err
|
||||||
|
}
|
||||||
|
|
||||||
|
if n == 0 {
|
||||||
|
return nn, io.ErrUnexpectedEOF
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *tun) deviceBytes() (o [16]byte) {
|
||||||
|
for i, c := range t.Device {
|
||||||
|
o[i] = byte(c)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func hasNetlinkAddr(al []*netlink.Addr, x netlink.Addr) bool {
|
||||||
|
for i := range al {
|
||||||
|
if al[i].Equal(x) {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
// addIPs uses netlink to add all addresses that don't exist, then it removes ones that should not be there
|
||||||
|
func (t *tun) addIPs(link netlink.Link) error {
|
||||||
|
newAddrs := make([]*netlink.Addr, len(t.vpnNetworks))
|
||||||
|
for i := range t.vpnNetworks {
|
||||||
|
newAddrs[i] = &netlink.Addr{
|
||||||
|
IPNet: &net.IPNet{
|
||||||
|
IP: t.vpnNetworks[i].Addr().AsSlice(),
|
||||||
|
Mask: net.CIDRMask(t.vpnNetworks[i].Bits(), t.vpnNetworks[i].Addr().BitLen()),
|
||||||
|
},
|
||||||
|
Label: t.vpnNetworks[i].Addr().Zone(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
//add all new addresses
|
||||||
|
for i := range newAddrs {
|
||||||
|
//AddrReplace still adds new IPs, but if their properties change it will change them as well
|
||||||
|
if err := netlink.AddrReplace(link, newAddrs[i]); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
//iterate over remainder, remove whoever shouldn't be there
|
||||||
|
al, err := netlink.AddrList(link, netlink.FAMILY_ALL)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to get tun address list: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
for i := range al {
|
||||||
|
if hasNetlinkAddr(newAddrs, al[i]) {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
err = netlink.AddrDel(link, &al[i])
|
||||||
|
if err != nil {
|
||||||
|
t.l.WithError(err).Error("failed to remove address from tun address list")
|
||||||
|
} else {
|
||||||
|
t.l.WithField("removed", al[i].String()).Info("removed address not listed in cert(s)")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *tun) Activate() error {
|
||||||
|
devName := t.deviceBytes()
|
||||||
|
|
||||||
|
if t.useSystemRoutes {
|
||||||
t.watchRoutes()
|
t.watchRoutes()
|
||||||
}
|
}
|
||||||
|
|
||||||
// Get the netlink device
|
|
||||||
link, err := netlink.LinkByName(name)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("failed to get tun device link: %s", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
rm.deviceIndex = link.Attrs().Index
|
|
||||||
|
|
||||||
// Open socket for ioctl operations
|
|
||||||
s, err := unix.Socket(
|
s, err := unix.Socket(
|
||||||
unix.AF_INET,
|
unix.AF_INET, //because everything we use t.ioctlFd for is address family independent, this is fine
|
||||||
unix.SOCK_DGRAM,
|
unix.SOCK_DGRAM,
|
||||||
unix.IPPROTO_IP,
|
unix.IPPROTO_IP,
|
||||||
)
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
rm.ioctlFd = uintptr(s)
|
t.ioctlFd = uintptr(s)
|
||||||
|
|
||||||
rm.SetMTU(t, t.MaxMTU)
|
// Set the device name
|
||||||
|
ifrf := ifReq{Name: devName}
|
||||||
|
if err = ioctl(t.ioctlFd, unix.SIOCGIFFLAGS, uintptr(unsafe.Pointer(&ifrf))); err != nil {
|
||||||
|
return fmt.Errorf("failed to set tun device name: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
link, err := netlink.LinkByName(t.Device)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to get tun device link: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
t.deviceIndex = link.Attrs().Index
|
||||||
|
|
||||||
|
// Setup our default MTU
|
||||||
|
t.setMTU()
|
||||||
|
|
||||||
// Set the transmit queue length
|
// Set the transmit queue length
|
||||||
devName := deviceBytes(name)
|
ifrq := ifreqQLEN{Name: devName, Value: int32(t.TXQueueLen)}
|
||||||
ifrq := ifreqQLEN{Name: devName, Value: int32(rm.txQueueLen)}
|
if err = ioctl(t.ioctlFd, unix.SIOCSIFTXQLEN, uintptr(unsafe.Pointer(&ifrq))); err != nil {
|
||||||
if err = ioctl(t.routeManager.ioctlFd, unix.SIOCSIFTXQLEN, uintptr(unsafe.Pointer(&ifrq))); err != nil {
|
|
||||||
// If we can't set the queue length nebula will still work but it may lead to packet loss
|
// If we can't set the queue length nebula will still work but it may lead to packet loss
|
||||||
t.l.WithError(err).Error("Failed to set tun tx queue length")
|
t.l.WithError(err).Error("Failed to set tun tx queue length")
|
||||||
}
|
}
|
||||||
|
|
||||||
// Disable IPv6 link-local address generation
|
|
||||||
const modeNone = 1
|
const modeNone = 1
|
||||||
if err = netlink.LinkSetIP6AddrGenMode(link, modeNone); err != nil {
|
if err = netlink.LinkSetIP6AddrGenMode(link, modeNone); err != nil {
|
||||||
t.l.WithError(err).Warn("Failed to disable link local address generation")
|
t.l.WithError(err).Warn("Failed to disable link local address generation")
|
||||||
}
|
}
|
||||||
|
|
||||||
// Add IP addresses
|
if err = t.addIPs(link); err != nil {
|
||||||
if err = t.routeManager.addIPs(t, link); err != nil {
|
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
// Bring up the interface
|
// Bring up the interface
|
||||||
if err = netlink.LinkSetUp(link); err != nil {
|
ifrf.Flags = ifrf.Flags | unix.IFF_UP
|
||||||
|
if err = ioctl(t.ioctlFd, unix.SIOCSIFFLAGS, uintptr(unsafe.Pointer(&ifrf))); err != nil {
|
||||||
return fmt.Errorf("failed to bring the tun device up: %s", err)
|
return fmt.Errorf("failed to bring the tun device up: %s", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Set route MTU
|
//set route MTU
|
||||||
for i := range t.vpnNetworks {
|
for i := range t.vpnNetworks {
|
||||||
if err = t.routeManager.SetDefaultRoute(t, t.vpnNetworks[i]); err != nil {
|
if err = t.setDefaultRoute(t.vpnNetworks[i]); err != nil {
|
||||||
return fmt.Errorf("failed to set default route MTU: %w", err)
|
return fmt.Errorf("failed to set default route MTU: %w", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Set the routes
|
// Set the routes
|
||||||
if err = t.routeManager.AddRoutes(t, false); err != nil {
|
if err = t.addRoutes(false); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Run the interface
|
||||||
|
ifrf.Flags = ifrf.Flags | unix.IFF_UP | unix.IFF_RUNNING
|
||||||
|
if err = ioctl(t.ioctlFd, unix.SIOCSIFFLAGS, uintptr(unsafe.Pointer(&ifrf))); err != nil {
|
||||||
|
return fmt.Errorf("failed to run tun device: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (rm *tun) SetMTU(t *wgTun, mtu int) {
|
func (t *tun) setMTU() {
|
||||||
name, err := t.tunDevice.Name()
|
// Set the MTU on the device
|
||||||
if err != nil {
|
ifm := ifreqMTU{Name: t.deviceBytes(), MTU: int32(t.MaxMTU)}
|
||||||
t.l.WithError(err).Error("Failed to get device name for MTU set")
|
if err := ioctl(t.ioctlFd, unix.SIOCSIFMTU, uintptr(unsafe.Pointer(&ifm))); err != nil {
|
||||||
return
|
// This is currently a non fatal condition because the route table must have the MTU set appropriately as well
|
||||||
}
|
|
||||||
|
|
||||||
link, err := netlink.LinkByName(name)
|
|
||||||
if err != nil {
|
|
||||||
t.l.WithError(err).Error("Failed to get link for MTU set")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := netlink.LinkSetMTU(link, mtu); err != nil {
|
|
||||||
t.l.WithError(err).Error("Failed to set tun mtu")
|
t.l.WithError(err).Error("Failed to set tun mtu")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (rm *tun) SetDefaultRoute(t *wgTun, cidr netip.Prefix) error {
|
func (t *tun) setDefaultRoute(cidr netip.Prefix) error {
|
||||||
dr := &net.IPNet{
|
dr := &net.IPNet{
|
||||||
IP: cidr.Masked().Addr().AsSlice(),
|
IP: cidr.Masked().Addr().AsSlice(),
|
||||||
Mask: net.CIDRMask(cidr.Bits(), cidr.Addr().BitLen()),
|
Mask: net.CIDRMask(cidr.Bits(), cidr.Addr().BitLen()),
|
||||||
}
|
}
|
||||||
|
|
||||||
nr := netlink.Route{
|
nr := netlink.Route{
|
||||||
LinkIndex: t.routeManager.deviceIndex,
|
LinkIndex: t.deviceIndex,
|
||||||
Dst: dr,
|
Dst: dr,
|
||||||
MTU: t.DefaultMTU,
|
MTU: t.DefaultMTU,
|
||||||
AdvMSS: advMSS(Route{}, t.DefaultMTU, t.MaxMTU),
|
AdvMSS: t.advMSS(Route{}),
|
||||||
Scope: unix.RT_SCOPE_LINK,
|
Scope: unix.RT_SCOPE_LINK,
|
||||||
Src: net.IP(cidr.Addr().AsSlice()),
|
Src: net.IP(cidr.Addr().AsSlice()),
|
||||||
Protocol: unix.RTPROT_KERNEL,
|
Protocol: unix.RTPROT_KERNEL,
|
||||||
@@ -229,7 +425,7 @@ func (rm *tun) SetDefaultRoute(t *wgTun, cidr netip.Prefix) error {
|
|||||||
err := netlink.RouteReplace(&nr)
|
err := netlink.RouteReplace(&nr)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.l.WithError(err).WithField("cidr", cidr).Warn("Failed to set default route MTU, retrying")
|
t.l.WithError(err).WithField("cidr", cidr).Warn("Failed to set default route MTU, retrying")
|
||||||
// Retry twice more
|
//retry twice more -- on some systems there appears to be a race condition where if we set routes too soon, netlink says `invalid argument`
|
||||||
for i := 0; i < 2; i++ {
|
for i := 0; i < 2; i++ {
|
||||||
time.Sleep(100 * time.Millisecond)
|
time.Sleep(100 * time.Millisecond)
|
||||||
err = netlink.RouteReplace(&nr)
|
err = netlink.RouteReplace(&nr)
|
||||||
@@ -247,7 +443,8 @@ func (rm *tun) SetDefaultRoute(t *wgTun, cidr netip.Prefix) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (rm *tun) AddRoutes(t *wgTun, logErrors bool) error {
|
func (t *tun) addRoutes(logErrors bool) error {
|
||||||
|
// Path routes
|
||||||
routes := *t.Routes.Load()
|
routes := *t.Routes.Load()
|
||||||
for _, r := range routes {
|
for _, r := range routes {
|
||||||
if !r.Install {
|
if !r.Install {
|
||||||
@@ -260,10 +457,10 @@ func (rm *tun) AddRoutes(t *wgTun, logErrors bool) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
nr := netlink.Route{
|
nr := netlink.Route{
|
||||||
LinkIndex: t.routeManager.deviceIndex,
|
LinkIndex: t.deviceIndex,
|
||||||
Dst: dr,
|
Dst: dr,
|
||||||
MTU: r.MTU,
|
MTU: r.MTU,
|
||||||
AdvMSS: advMSS(r, t.DefaultMTU, t.MaxMTU),
|
AdvMSS: t.advMSS(r),
|
||||||
Scope: unix.RT_SCOPE_LINK,
|
Scope: unix.RT_SCOPE_LINK,
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -287,7 +484,7 @@ func (rm *tun) AddRoutes(t *wgTun, logErrors bool) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (rm *tun) RemoveRoutes(t *wgTun, routes []Route) {
|
func (t *tun) removeRoutes(routes []Route) {
|
||||||
for _, r := range routes {
|
for _, r := range routes {
|
||||||
if !r.Install {
|
if !r.Install {
|
||||||
continue
|
continue
|
||||||
@@ -299,10 +496,10 @@ func (rm *tun) RemoveRoutes(t *wgTun, routes []Route) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
nr := netlink.Route{
|
nr := netlink.Route{
|
||||||
LinkIndex: t.routeManager.deviceIndex,
|
LinkIndex: t.deviceIndex,
|
||||||
Dst: dr,
|
Dst: dr,
|
||||||
MTU: r.MTU,
|
MTU: r.MTU,
|
||||||
AdvMSS: advMSS(r, t.DefaultMTU, t.MaxMTU),
|
AdvMSS: t.advMSS(r),
|
||||||
Scope: unix.RT_SCOPE_LINK,
|
Scope: unix.RT_SCOPE_LINK,
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -319,105 +516,30 @@ func (rm *tun) RemoveRoutes(t *wgTun, routes []Route) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (rm *tun) NewMultiQueueReader(t *wgTun) (io.ReadWriteCloser, error) {
|
func (t *tun) Name() string {
|
||||||
// For Linux with WireGuard TUN, we can reuse the same device
|
return t.Device
|
||||||
// The vectorized I/O will handle batching
|
|
||||||
return &wgTunReader{
|
|
||||||
parent: t,
|
|
||||||
tunDevice: t.tunDevice,
|
|
||||||
offset: 0,
|
|
||||||
l: t.l,
|
|
||||||
}, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func deviceBytes(name string) [16]byte {
|
func (t *tun) advMSS(r Route) int {
|
||||||
var o [16]byte
|
|
||||||
for i, c := range name {
|
|
||||||
if i >= 16 {
|
|
||||||
break
|
|
||||||
}
|
|
||||||
o[i] = byte(c)
|
|
||||||
}
|
|
||||||
return o
|
|
||||||
}
|
|
||||||
|
|
||||||
func advMSS(r Route, defaultMTU, maxMTU int) int {
|
|
||||||
mtu := r.MTU
|
mtu := r.MTU
|
||||||
if r.MTU == 0 {
|
if r.MTU == 0 {
|
||||||
mtu = defaultMTU
|
mtu = t.DefaultMTU
|
||||||
}
|
}
|
||||||
|
|
||||||
// We only need to set advmss if the route MTU does not match the device MTU
|
// We only need to set advmss if the route MTU does not match the device MTU
|
||||||
if mtu != maxMTU {
|
if mtu != t.MaxMTU {
|
||||||
return mtu - 40
|
return mtu - 40
|
||||||
}
|
}
|
||||||
return 0
|
return 0
|
||||||
}
|
}
|
||||||
|
|
||||||
type ifreqQLEN struct {
|
func (t *tun) watchRoutes() {
|
||||||
Name [16]byte
|
|
||||||
Value int32
|
|
||||||
pad [8]byte
|
|
||||||
}
|
|
||||||
|
|
||||||
func hasNetlinkAddr(al []*netlink.Addr, x netlink.Addr) bool {
|
|
||||||
for i := range al {
|
|
||||||
if al[i].Equal(x) {
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
func (rm *tun) addIPs(t *wgTun, link netlink.Link) error {
|
|
||||||
newAddrs := make([]*netlink.Addr, len(t.vpnNetworks))
|
|
||||||
for i := range t.vpnNetworks {
|
|
||||||
newAddrs[i] = &netlink.Addr{
|
|
||||||
IPNet: &net.IPNet{
|
|
||||||
IP: t.vpnNetworks[i].Addr().AsSlice(),
|
|
||||||
Mask: net.CIDRMask(t.vpnNetworks[i].Bits(), t.vpnNetworks[i].Addr().BitLen()),
|
|
||||||
},
|
|
||||||
Label: t.vpnNetworks[i].Addr().Zone(),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Add all new addresses
|
|
||||||
for i := range newAddrs {
|
|
||||||
if err := netlink.AddrReplace(link, newAddrs[i]); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Iterate over remainder, remove whoever shouldn't be there
|
|
||||||
al, err := netlink.AddrList(link, netlink.FAMILY_ALL)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("failed to get tun address list: %s", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
for i := range al {
|
|
||||||
if hasNetlinkAddr(newAddrs, al[i]) {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
err = netlink.AddrDel(link, &al[i])
|
|
||||||
if err != nil {
|
|
||||||
t.l.WithError(err).Error("failed to remove address from tun address list")
|
|
||||||
} else {
|
|
||||||
t.l.WithField("removed", al[i].String()).Info("removed address not listed in cert(s)")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// watchRoutes monitors system route changes
|
|
||||||
func (t *wgTun) watchRoutes() {
|
|
||||||
|
|
||||||
rch := make(chan netlink.RouteUpdate)
|
rch := make(chan netlink.RouteUpdate)
|
||||||
doneChan := make(chan struct{})
|
doneChan := make(chan struct{})
|
||||||
|
|
||||||
netlinkOptions := netlink.RouteSubscribeOptions{
|
netlinkOptions := netlink.RouteSubscribeOptions{
|
||||||
ReceiveBufferSize: t.routeManager.useSystemRoutesBufferSize,
|
ReceiveBufferSize: t.useSystemRoutesBufferSize,
|
||||||
ReceiveBufferForceSize: t.routeManager.useSystemRoutesBufferSize != 0,
|
ReceiveBufferForceSize: t.useSystemRoutesBufferSize != 0,
|
||||||
ErrorCallback: func(e error) { t.l.WithError(e).Errorf("netlink error") },
|
ErrorCallback: func(e error) { t.l.WithError(e).Errorf("netlink error") },
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -435,19 +557,87 @@ func (t *wgTun) watchRoutes() {
|
|||||||
if ok {
|
if ok {
|
||||||
t.updateRoutes(r)
|
t.updateRoutes(r)
|
||||||
} else {
|
} else {
|
||||||
|
// may be should do something here as
|
||||||
|
// netlink stops sending updates
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
case <-doneChan:
|
case <-doneChan:
|
||||||
|
// netlink.RouteSubscriber will close the rch for us
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t *wgTun) updateRoutes(r netlink.RouteUpdate) {
|
func (t *tun) isGatewayInVpnNetworks(gwAddr netip.Addr) bool {
|
||||||
gateways := t.getGatewaysFromRoute(&r.Route, t.routeManager.deviceIndex)
|
withinNetworks := false
|
||||||
|
for i := range t.vpnNetworks {
|
||||||
|
if t.vpnNetworks[i].Contains(gwAddr) {
|
||||||
|
withinNetworks = true
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return withinNetworks
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *tun) getGatewaysFromRoute(r *netlink.Route) routing.Gateways {
|
||||||
|
|
||||||
|
var gateways routing.Gateways
|
||||||
|
|
||||||
|
link, err := netlink.LinkByName(t.Device)
|
||||||
|
if err != nil {
|
||||||
|
t.l.WithField("Devicename", t.Device).Error("Ignoring route update: failed to get link by name")
|
||||||
|
return gateways
|
||||||
|
}
|
||||||
|
|
||||||
|
// If this route is relevant to our interface and there is a gateway then add it
|
||||||
|
if r.LinkIndex == link.Attrs().Index && len(r.Gw) > 0 {
|
||||||
|
gwAddr, ok := netip.AddrFromSlice(r.Gw)
|
||||||
|
if !ok {
|
||||||
|
t.l.WithField("route", r).Debug("Ignoring route update, invalid gateway address")
|
||||||
|
} else {
|
||||||
|
gwAddr = gwAddr.Unmap()
|
||||||
|
|
||||||
|
if !t.isGatewayInVpnNetworks(gwAddr) {
|
||||||
|
// Gateway isn't in our overlay network, ignore
|
||||||
|
t.l.WithField("route", r).Debug("Ignoring route update, not in our network")
|
||||||
|
} else {
|
||||||
|
gateways = append(gateways, routing.NewGateway(gwAddr, 1))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, p := range r.MultiPath {
|
||||||
|
// If this route is relevant to our interface and there is a gateway then add it
|
||||||
|
if p.LinkIndex == link.Attrs().Index && len(p.Gw) > 0 {
|
||||||
|
gwAddr, ok := netip.AddrFromSlice(p.Gw)
|
||||||
|
if !ok {
|
||||||
|
t.l.WithField("route", r).Debug("Ignoring multipath route update, invalid gateway address")
|
||||||
|
} else {
|
||||||
|
gwAddr = gwAddr.Unmap()
|
||||||
|
|
||||||
|
if !t.isGatewayInVpnNetworks(gwAddr) {
|
||||||
|
// Gateway isn't in our overlay network, ignore
|
||||||
|
t.l.WithField("route", r).Debug("Ignoring route update, not in our network")
|
||||||
|
} else {
|
||||||
|
// p.Hops+1 = weight of the route
|
||||||
|
gateways = append(gateways, routing.NewGateway(gwAddr, p.Hops+1))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
routing.CalculateBucketsForGateways(gateways)
|
||||||
|
return gateways
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *tun) updateRoutes(r netlink.RouteUpdate) {
|
||||||
|
|
||||||
|
gateways := t.getGatewaysFromRoute(&r.Route)
|
||||||
|
|
||||||
if len(gateways) == 0 {
|
if len(gateways) == 0 {
|
||||||
|
// No gateways relevant to our network, no routing changes required.
|
||||||
t.l.WithField("route", r).Debug("Ignoring route update, no gateways")
|
t.l.WithField("route", r).Debug("Ignoring route update, no gateways")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -471,6 +661,7 @@ func (t *wgTun) updateRoutes(r netlink.RouteUpdate) {
|
|||||||
if r.Type == unix.RTM_NEWROUTE {
|
if r.Type == unix.RTM_NEWROUTE {
|
||||||
t.l.WithField("destination", dst).WithField("via", gateways).Info("Adding route")
|
t.l.WithField("destination", dst).WithField("via", gateways).Info("Adding route")
|
||||||
newTree.Insert(dst, gateways)
|
newTree.Insert(dst, gateways)
|
||||||
|
|
||||||
} else {
|
} else {
|
||||||
t.l.WithField("destination", dst).WithField("via", gateways).Info("Removing route")
|
t.l.WithField("destination", dst).WithField("via", gateways).Info("Removing route")
|
||||||
newTree.Delete(dst)
|
newTree.Delete(dst)
|
||||||
@@ -478,71 +669,18 @@ func (t *wgTun) updateRoutes(r netlink.RouteUpdate) {
|
|||||||
t.routeTree.Store(newTree)
|
t.routeTree.Store(newTree)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t *wgTun) getGatewaysFromRoute(r *netlink.Route, deviceIndex int) routing.Gateways {
|
func (t *tun) Close() error {
|
||||||
var gateways routing.Gateways
|
if t.routeChan != nil {
|
||||||
|
close(t.routeChan)
|
||||||
name, err := t.tunDevice.Name()
|
|
||||||
if err != nil {
|
|
||||||
t.l.Error("Ignoring route update: failed to get device name")
|
|
||||||
return gateways
|
|
||||||
}
|
}
|
||||||
|
|
||||||
link, err := netlink.LinkByName(name)
|
if t.ReadWriteCloser != nil {
|
||||||
if err != nil {
|
_ = t.ReadWriteCloser.Close()
|
||||||
t.l.WithField("DeviceName", name).Error("Ignoring route update: failed to get link by name")
|
|
||||||
return gateways
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// If this route is relevant to our interface and there is a gateway then add it
|
if t.ioctlFd > 0 {
|
||||||
if r.LinkIndex == link.Attrs().Index && len(r.Gw) > 0 {
|
_ = os.NewFile(t.ioctlFd, "ioctlFd").Close()
|
||||||
gwAddr, ok := netip.AddrFromSlice(r.Gw)
|
|
||||||
if !ok {
|
|
||||||
t.l.WithField("route", r).Debug("Ignoring route update, invalid gateway address")
|
|
||||||
} else {
|
|
||||||
gwAddr = gwAddr.Unmap()
|
|
||||||
|
|
||||||
if !t.isGatewayInVpnNetworks(gwAddr) {
|
|
||||||
t.l.WithField("route", r).Debug("Ignoring route update, not in our network")
|
|
||||||
} else {
|
|
||||||
gateways = append(gateways, routing.NewGateway(gwAddr, 1))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, p := range r.MultiPath {
|
|
||||||
if p.LinkIndex == link.Attrs().Index && len(p.Gw) > 0 {
|
|
||||||
gwAddr, ok := netip.AddrFromSlice(p.Gw)
|
|
||||||
if !ok {
|
|
||||||
t.l.WithField("route", r).Debug("Ignoring multipath route update, invalid gateway address")
|
|
||||||
} else {
|
|
||||||
gwAddr = gwAddr.Unmap()
|
|
||||||
|
|
||||||
if !t.isGatewayInVpnNetworks(gwAddr) {
|
|
||||||
t.l.WithField("route", r).Debug("Ignoring route update, not in our network")
|
|
||||||
} else {
|
|
||||||
gateways = append(gateways, routing.NewGateway(gwAddr, p.Hops+1))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
routing.CalculateBucketsForGateways(gateways)
|
|
||||||
return gateways
|
|
||||||
}
|
|
||||||
|
|
||||||
func (t *wgTun) isGatewayInVpnNetworks(gwAddr netip.Addr) bool {
|
|
||||||
for i := range t.vpnNetworks {
|
|
||||||
if t.vpnNetworks[i].Contains(gwAddr) {
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
func ioctl(a1, a2, a3 uintptr) error {
|
|
||||||
_, _, errno := unix.Syscall(unix.SYS_IOCTL, a1, a2, a3)
|
|
||||||
if errno != 0 {
|
|
||||||
return errno
|
|
||||||
}
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -6,27 +6,26 @@ package overlay
|
|||||||
import "testing"
|
import "testing"
|
||||||
|
|
||||||
var runAdvMSSTests = []struct {
|
var runAdvMSSTests = []struct {
|
||||||
name string
|
name string
|
||||||
defaultMTU int
|
tun *tun
|
||||||
maxMTU int
|
r Route
|
||||||
r Route
|
expected int
|
||||||
expected int
|
|
||||||
}{
|
}{
|
||||||
// Standard case, default MTU is the device max MTU
|
// Standard case, default MTU is the device max MTU
|
||||||
{"default", 1440, 1440, Route{}, 0},
|
{"default", &tun{DefaultMTU: 1440, MaxMTU: 1440}, Route{}, 0},
|
||||||
{"default-min", 1440, 1440, Route{MTU: 1440}, 0},
|
{"default-min", &tun{DefaultMTU: 1440, MaxMTU: 1440}, Route{MTU: 1440}, 0},
|
||||||
{"default-low", 1440, 1440, Route{MTU: 1200}, 1160},
|
{"default-low", &tun{DefaultMTU: 1440, MaxMTU: 1440}, Route{MTU: 1200}, 1160},
|
||||||
|
|
||||||
// Case where we have a route MTU set higher than the default
|
// Case where we have a route MTU set higher than the default
|
||||||
{"route", 1440, 8941, Route{}, 1400},
|
{"route", &tun{DefaultMTU: 1440, MaxMTU: 8941}, Route{}, 1400},
|
||||||
{"route-min", 1440, 8941, Route{MTU: 1440}, 1400},
|
{"route-min", &tun{DefaultMTU: 1440, MaxMTU: 8941}, Route{MTU: 1440}, 1400},
|
||||||
{"route-high", 1440, 8941, Route{MTU: 8941}, 0},
|
{"route-high", &tun{DefaultMTU: 1440, MaxMTU: 8941}, Route{MTU: 8941}, 0},
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestTunAdvMSS(t *testing.T) {
|
func TestTunAdvMSS(t *testing.T) {
|
||||||
for _, tt := range runAdvMSSTests {
|
for _, tt := range runAdvMSSTests {
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
o := advMSS(tt.r, tt.defaultMTU, tt.maxMTU)
|
o := tt.tun.advMSS(tt.r)
|
||||||
if o != tt.expected {
|
if o != tt.expected {
|
||||||
t.Errorf("got %d, want %d", o, tt.expected)
|
t.Errorf("got %d, want %d", o, tt.expected)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -547,41 +547,3 @@ func delRoute(prefix netip.Prefix, gateways []netip.Prefix) error {
|
|||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func ioctl(a1, a2, a3 uintptr) error {
|
|
||||||
_, _, errno := syscall.Syscall(syscall.SYS_IOCTL, a1, a2, a3)
|
|
||||||
if errno != 0 {
|
|
||||||
return errno
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func prefixToMask(prefix netip.Prefix) netip.Addr {
|
|
||||||
bits := prefix.Bits()
|
|
||||||
if prefix.Addr().Is4() {
|
|
||||||
mask := ^uint32(0) << (32 - bits)
|
|
||||||
return netip.AddrFrom4([4]byte{
|
|
||||||
byte(mask >> 24),
|
|
||||||
byte(mask >> 16),
|
|
||||||
byte(mask >> 8),
|
|
||||||
byte(mask),
|
|
||||||
})
|
|
||||||
}
|
|
||||||
var mask [16]byte
|
|
||||||
for i := 0; i < bits/8; i++ {
|
|
||||||
mask[i] = 0xff
|
|
||||||
}
|
|
||||||
if bits%8 != 0 {
|
|
||||||
mask[bits/8] = ^byte(0) << (8 - bits%8)
|
|
||||||
}
|
|
||||||
return netip.AddrFrom16(mask)
|
|
||||||
}
|
|
||||||
|
|
||||||
func selectGateway(prefix netip.Prefix, gateways []netip.Prefix) (netip.Prefix, error) {
|
|
||||||
for _, gw := range gateways {
|
|
||||||
if prefix.Addr().Is4() == gw.Addr().Is4() {
|
|
||||||
return gw, nil
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return netip.Prefix{}, fmt.Errorf("no suitable gateway found for prefix %v", prefix)
|
|
||||||
}
|
|
||||||
|
|||||||
14
overlay/tun_notwin.go
Normal file
14
overlay/tun_notwin.go
Normal file
@@ -0,0 +1,14 @@
|
|||||||
|
//go:build !windows
|
||||||
|
// +build !windows
|
||||||
|
|
||||||
|
package overlay
|
||||||
|
|
||||||
|
import "syscall"
|
||||||
|
|
||||||
|
func ioctl(a1, a2, a3 uintptr) error {
|
||||||
|
_, _, errno := syscall.Syscall(syscall.SYS_IOCTL, a1, a2, a3)
|
||||||
|
if errno != 0 {
|
||||||
|
return errno
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
@@ -1,59 +1,104 @@
|
|||||||
//go:build openbsd && !e2e_testing
|
//go:build !e2e_testing
|
||||||
// +build openbsd,!e2e_testing
|
// +build !e2e_testing
|
||||||
|
|
||||||
package overlay
|
package overlay
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
"os/exec"
|
"os"
|
||||||
"strconv"
|
"regexp"
|
||||||
"strings"
|
"sync/atomic"
|
||||||
|
"syscall"
|
||||||
|
"unsafe"
|
||||||
|
|
||||||
|
"github.com/gaissmai/bart"
|
||||||
"github.com/sirupsen/logrus"
|
"github.com/sirupsen/logrus"
|
||||||
"github.com/slackhq/nebula/config"
|
"github.com/slackhq/nebula/config"
|
||||||
|
"github.com/slackhq/nebula/routing"
|
||||||
"github.com/slackhq/nebula/util"
|
"github.com/slackhq/nebula/util"
|
||||||
wgtun "golang.zx2c4.com/wireguard/tun"
|
netroute "golang.org/x/net/route"
|
||||||
|
"golang.org/x/sys/unix"
|
||||||
)
|
)
|
||||||
|
|
||||||
type tun struct{}
|
const (
|
||||||
|
SIOCAIFADDR_IN6 = 0x8080691a
|
||||||
|
)
|
||||||
|
|
||||||
func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ []netip.Prefix) (*wgTun, error) {
|
type ifreqAlias4 struct {
|
||||||
return nil, fmt.Errorf("newTunFromFd not supported on OpenBSD")
|
Name [unix.IFNAMSIZ]byte
|
||||||
|
Addr unix.RawSockaddrInet4
|
||||||
|
DstAddr unix.RawSockaddrInet4
|
||||||
|
MaskAddr unix.RawSockaddrInet4
|
||||||
}
|
}
|
||||||
|
|
||||||
func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, _ bool) (*wgTun, error) {
|
type ifreqAlias6 struct {
|
||||||
deviceName := c.GetString("tun.dev", "tun")
|
Name [unix.IFNAMSIZ]byte
|
||||||
mtu := c.GetInt("tun.mtu", DefaultMTU)
|
Addr unix.RawSockaddrInet6
|
||||||
|
DstAddr unix.RawSockaddrInet6
|
||||||
|
PrefixMask unix.RawSockaddrInet6
|
||||||
|
Flags uint32
|
||||||
|
Lifetime [2]uint32
|
||||||
|
}
|
||||||
|
|
||||||
// Create WireGuard TUN device
|
type ifreq struct {
|
||||||
tunDevice, err := wgtun.CreateTUN(deviceName, mtu)
|
Name [unix.IFNAMSIZ]byte
|
||||||
if err != nil {
|
data int
|
||||||
return nil, fmt.Errorf("failed to create TUN device: %w", err)
|
}
|
||||||
|
|
||||||
|
type tun struct {
|
||||||
|
Device string
|
||||||
|
vpnNetworks []netip.Prefix
|
||||||
|
MTU int
|
||||||
|
Routes atomic.Pointer[[]Route]
|
||||||
|
routeTree atomic.Pointer[bart.Table[routing.Gateways]]
|
||||||
|
l *logrus.Logger
|
||||||
|
f *os.File
|
||||||
|
fd int
|
||||||
|
// cache out buffer since we need to prepend 4 bytes for tun metadata
|
||||||
|
out []byte
|
||||||
|
}
|
||||||
|
|
||||||
|
var deviceNameRE = regexp.MustCompile(`^tun[0-9]+$`)
|
||||||
|
|
||||||
|
func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ []netip.Prefix) (*tun, error) {
|
||||||
|
return nil, fmt.Errorf("newTunFromFd not supported in openbsd")
|
||||||
|
}
|
||||||
|
|
||||||
|
func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, _ bool) (*tun, error) {
|
||||||
|
// Try to open tun device
|
||||||
|
var err error
|
||||||
|
deviceName := c.GetString("tun.dev", "")
|
||||||
|
if deviceName == "" {
|
||||||
|
return nil, fmt.Errorf("a device name in the format of /dev/tunN must be specified")
|
||||||
|
}
|
||||||
|
if !deviceNameRE.MatchString(deviceName) {
|
||||||
|
return nil, fmt.Errorf("a device name in the format of /dev/tunN must be specified")
|
||||||
}
|
}
|
||||||
|
|
||||||
// Get the actual device name
|
fd, err := unix.Open("/dev/"+deviceName, os.O_RDWR, 0)
|
||||||
actualName, err := tunDevice.Name()
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
tunDevice.Close()
|
return nil, err
|
||||||
return nil, fmt.Errorf("failed to get TUN device name: %w", err)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
t := &wgTun{
|
err = unix.SetNonblock(fd, true)
|
||||||
tunDevice: tunDevice,
|
if err != nil {
|
||||||
|
l.WithError(err).Warn("Failed to set the tun device as nonblocking")
|
||||||
|
}
|
||||||
|
|
||||||
|
t := &tun{
|
||||||
|
f: os.NewFile(uintptr(fd), ""),
|
||||||
|
fd: fd,
|
||||||
|
Device: deviceName,
|
||||||
vpnNetworks: vpnNetworks,
|
vpnNetworks: vpnNetworks,
|
||||||
MaxMTU: mtu,
|
MTU: c.GetInt("tun.mtu", DefaultMTU),
|
||||||
DefaultMTU: mtu,
|
|
||||||
l: l,
|
l: l,
|
||||||
}
|
}
|
||||||
|
|
||||||
// Create OpenBSD-specific route manager
|
|
||||||
t.routeManager = &tun{}
|
|
||||||
|
|
||||||
err = t.reload(c, true)
|
err = t.reload(c, true)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
tunDevice.Close()
|
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -64,86 +109,221 @@ func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, _ bool) (
|
|||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
l.WithField("name", actualName).Info("Created WireGuard TUN device")
|
|
||||||
|
|
||||||
return t, nil
|
return t, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (rm *tun) Activate(t *wgTun) error {
|
func (t *tun) Close() error {
|
||||||
name, err := t.tunDevice.Name()
|
if t.f != nil {
|
||||||
if err != nil {
|
if err := t.f.Close(); err != nil {
|
||||||
return fmt.Errorf("failed to get device name: %w", err)
|
return fmt.Errorf("error closing tun file: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// t.f.Close should have handled it for us but let's be extra sure
|
||||||
|
_ = unix.Close(t.fd)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *tun) Read(to []byte) (int, error) {
|
||||||
|
buf := make([]byte, len(to)+4)
|
||||||
|
|
||||||
|
n, err := t.f.Read(buf)
|
||||||
|
|
||||||
|
copy(to, buf[4:])
|
||||||
|
return n - 4, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Write is only valid for single threaded use
|
||||||
|
func (t *tun) Write(from []byte) (int, error) {
|
||||||
|
buf := t.out
|
||||||
|
if cap(buf) < len(from)+4 {
|
||||||
|
buf = make([]byte, len(from)+4)
|
||||||
|
t.out = buf
|
||||||
|
}
|
||||||
|
buf = buf[:len(from)+4]
|
||||||
|
|
||||||
|
if len(from) == 0 {
|
||||||
|
return 0, syscall.EIO
|
||||||
}
|
}
|
||||||
|
|
||||||
// Set the MTU
|
// Determine the IP Family for the NULL L2 Header
|
||||||
rm.SetMTU(t, t.MaxMTU)
|
ipVer := from[0] >> 4
|
||||||
|
if ipVer == 4 {
|
||||||
|
buf[3] = syscall.AF_INET
|
||||||
|
} else if ipVer == 6 {
|
||||||
|
buf[3] = syscall.AF_INET6
|
||||||
|
} else {
|
||||||
|
return 0, fmt.Errorf("unable to determine IP version from packet")
|
||||||
|
}
|
||||||
|
|
||||||
// Add IP addresses
|
copy(buf[4:], from)
|
||||||
for _, network := range t.vpnNetworks {
|
|
||||||
if err := rm.addIP(t, name, network); err != nil {
|
n, err := t.f.Write(buf)
|
||||||
|
return n - 4, err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *tun) addIp(cidr netip.Prefix) error {
|
||||||
|
if cidr.Addr().Is4() {
|
||||||
|
var req ifreqAlias4
|
||||||
|
req.Name = t.deviceBytes()
|
||||||
|
req.Addr = unix.RawSockaddrInet4{
|
||||||
|
Len: unix.SizeofSockaddrInet4,
|
||||||
|
Family: unix.AF_INET,
|
||||||
|
Addr: cidr.Addr().As4(),
|
||||||
|
}
|
||||||
|
req.DstAddr = unix.RawSockaddrInet4{
|
||||||
|
Len: unix.SizeofSockaddrInet4,
|
||||||
|
Family: unix.AF_INET,
|
||||||
|
Addr: cidr.Addr().As4(),
|
||||||
|
}
|
||||||
|
req.MaskAddr = unix.RawSockaddrInet4{
|
||||||
|
Len: unix.SizeofSockaddrInet4,
|
||||||
|
Family: unix.AF_INET,
|
||||||
|
Addr: prefixToMask(cidr).As4(),
|
||||||
|
}
|
||||||
|
|
||||||
|
s, err := unix.Socket(unix.AF_INET, unix.SOCK_DGRAM, unix.IPPROTO_IP)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
defer syscall.Close(s)
|
||||||
|
|
||||||
|
if err := ioctl(uintptr(s), unix.SIOCAIFADDR, uintptr(unsafe.Pointer(&req))); err != nil {
|
||||||
|
return fmt.Errorf("failed to set tun address %s: %s", cidr.Addr(), err)
|
||||||
|
}
|
||||||
|
|
||||||
|
err = addRoute(cidr, t.vpnNetworks)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to set route for vpn network %v: %w", cidr, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if cidr.Addr().Is6() {
|
||||||
|
var req ifreqAlias6
|
||||||
|
req.Name = t.deviceBytes()
|
||||||
|
req.Addr = unix.RawSockaddrInet6{
|
||||||
|
Len: unix.SizeofSockaddrInet6,
|
||||||
|
Family: unix.AF_INET6,
|
||||||
|
Addr: cidr.Addr().As16(),
|
||||||
|
}
|
||||||
|
req.PrefixMask = unix.RawSockaddrInet6{
|
||||||
|
Len: unix.SizeofSockaddrInet6,
|
||||||
|
Family: unix.AF_INET6,
|
||||||
|
Addr: prefixToMask(cidr).As16(),
|
||||||
|
}
|
||||||
|
req.Lifetime[0] = 0xffffffff
|
||||||
|
req.Lifetime[1] = 0xffffffff
|
||||||
|
|
||||||
|
s, err := unix.Socket(unix.AF_INET6, unix.SOCK_DGRAM, unix.IPPROTO_IP)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
defer syscall.Close(s)
|
||||||
|
|
||||||
|
if err := ioctl(uintptr(s), SIOCAIFADDR_IN6, uintptr(unsafe.Pointer(&req))); err != nil {
|
||||||
|
return fmt.Errorf("failed to set tun address %s: %s", cidr.Addr().String(), err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return fmt.Errorf("unknown address type %v", cidr)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *tun) Activate() error {
|
||||||
|
err := t.doIoctlByName(unix.SIOCSIFMTU, uint32(t.MTU))
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to set tun mtu: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
for i := range t.vpnNetworks {
|
||||||
|
err = t.addIp(t.vpnNetworks[i])
|
||||||
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Bring up the interface
|
return t.addRoutes(false)
|
||||||
if err := runCommandBSD("ifconfig", name, "up"); err != nil {
|
}
|
||||||
return fmt.Errorf("failed to bring up interface: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Set the routes
|
func (t *tun) doIoctlByName(ctl uintptr, value uint32) error {
|
||||||
if err := rm.AddRoutes(t, false); err != nil {
|
s, err := unix.Socket(unix.AF_INET, unix.SOCK_DGRAM, unix.IPPROTO_IP)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
defer syscall.Close(s)
|
||||||
|
|
||||||
|
ir := ifreq{Name: t.deviceBytes(), data: int(value)}
|
||||||
|
err = ioctl(uintptr(s), ctl, uintptr(unsafe.Pointer(&ir)))
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *tun) reload(c *config.C, initial bool) error {
|
||||||
|
change, routes, err := getAllRoutesFromConfig(c, t.vpnNetworks, initial)
|
||||||
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if !initial && !change {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
routeTree, err := makeRouteTree(t.l, routes, false)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Teach nebula how to handle the routes before establishing them in the system table
|
||||||
|
oldRoutes := t.Routes.Swap(&routes)
|
||||||
|
t.routeTree.Store(routeTree)
|
||||||
|
|
||||||
|
if !initial {
|
||||||
|
// Remove first, if the system removes a wanted route hopefully it will be re-added next
|
||||||
|
err := t.removeRoutes(findRemovedRoutes(routes, *oldRoutes))
|
||||||
|
if err != nil {
|
||||||
|
util.LogWithContextIfNeeded("Failed to remove routes", err, t.l)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Ensure any routes we actually want are installed
|
||||||
|
err = t.addRoutes(true)
|
||||||
|
if err != nil {
|
||||||
|
// Catch any stray logs
|
||||||
|
util.LogWithContextIfNeeded("Failed to add routes", err, t.l)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (rm *tun) SetMTU(t *wgTun, mtu int) {
|
func (t *tun) RoutesFor(ip netip.Addr) routing.Gateways {
|
||||||
name, err := t.tunDevice.Name()
|
r, _ := t.routeTree.Load().Lookup(ip)
|
||||||
if err != nil {
|
return r
|
||||||
t.l.WithError(err).Error("Failed to get device name for MTU set")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := runCommandBSD("ifconfig", name, "mtu", strconv.Itoa(mtu)); err != nil {
|
|
||||||
t.l.WithError(err).Error("Failed to set tun mtu")
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (rm *tun) SetDefaultRoute(t *wgTun, cidr netip.Prefix) error {
|
func (t *tun) Networks() []netip.Prefix {
|
||||||
// On OpenBSD, routes are set via ifconfig and route commands
|
return t.vpnNetworks
|
||||||
return nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (rm *tun) AddRoutes(t *wgTun, logErrors bool) error {
|
func (t *tun) Name() string {
|
||||||
name, err := t.tunDevice.Name()
|
return t.Device
|
||||||
if err != nil {
|
}
|
||||||
return fmt.Errorf("failed to get device name: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
|
func (t *tun) NewMultiQueueReader() (io.ReadWriteCloser, error) {
|
||||||
|
return nil, fmt.Errorf("TODO: multiqueue not implemented for openbsd")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *tun) addRoutes(logErrors bool) error {
|
||||||
routes := *t.Routes.Load()
|
routes := *t.Routes.Load()
|
||||||
|
|
||||||
for _, r := range routes {
|
for _, r := range routes {
|
||||||
if !r.Install {
|
if len(r.Via) == 0 || !r.Install {
|
||||||
|
// We don't allow route MTUs so only install routes with a via
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
// Add route using route command
|
err := addRoute(r.Cidr, t.vpnNetworks)
|
||||||
args := []string{"add"}
|
|
||||||
|
|
||||||
if r.Cidr.Addr().Is6() {
|
|
||||||
args = append(args, "-inet6")
|
|
||||||
} else {
|
|
||||||
args = append(args, "-inet")
|
|
||||||
}
|
|
||||||
|
|
||||||
args = append(args, r.Cidr.String(), "-interface", name)
|
|
||||||
|
|
||||||
if r.Metric > 0 {
|
|
||||||
// OpenBSD doesn't support route metrics directly like Linux
|
|
||||||
t.l.WithField("route", r).Warn("Route metrics are not fully supported on OpenBSD")
|
|
||||||
}
|
|
||||||
|
|
||||||
err := runCommandBSD("route", args...)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
retErr := util.NewContextualError("Failed to add route", map[string]any{"route": r}, err)
|
retErr := util.NewContextualError("Failed to add route", map[string]any{"route": r}, err)
|
||||||
if logErrors {
|
if logErrors {
|
||||||
@@ -159,71 +339,131 @@ func (rm *tun) AddRoutes(t *wgTun, logErrors bool) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (rm *tun) RemoveRoutes(t *wgTun, routes []Route) {
|
func (t *tun) removeRoutes(routes []Route) error {
|
||||||
name, err := t.tunDevice.Name()
|
|
||||||
if err != nil {
|
|
||||||
t.l.WithError(err).Error("Failed to get device name for route removal")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, r := range routes {
|
for _, r := range routes {
|
||||||
if !r.Install {
|
if !r.Install {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
args := []string{"delete"}
|
err := delRoute(r.Cidr, t.vpnNetworks)
|
||||||
|
|
||||||
if r.Cidr.Addr().Is6() {
|
|
||||||
args = append(args, "-inet6")
|
|
||||||
} else {
|
|
||||||
args = append(args, "-inet")
|
|
||||||
}
|
|
||||||
|
|
||||||
args = append(args, r.Cidr.String(), "-interface", name)
|
|
||||||
|
|
||||||
err := runCommandBSD("route", args...)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.l.WithError(err).WithField("route", r).Error("Failed to remove route")
|
t.l.WithError(err).WithField("route", r).Error("Failed to remove route")
|
||||||
} else {
|
} else {
|
||||||
t.l.WithField("route", r).Info("Removed route")
|
t.l.WithField("route", r).Info("Removed route")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (rm *tun) NewMultiQueueReader(t *wgTun) (io.ReadWriteCloser, error) {
|
func (t *tun) deviceBytes() (o [16]byte) {
|
||||||
// OpenBSD doesn't support multi-queue TUN devices in the same way as Linux
|
for i, c := range t.Device {
|
||||||
// Return a reader that wraps the same device
|
o[i] = byte(c)
|
||||||
return &wgTunReader{
|
}
|
||||||
parent: t,
|
return
|
||||||
tunDevice: t.tunDevice,
|
|
||||||
offset: 0,
|
|
||||||
l: t.l,
|
|
||||||
}, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (rm *tun) addIP(t *wgTun, name string, network netip.Prefix) error {
|
func addRoute(prefix netip.Prefix, gateways []netip.Prefix) error {
|
||||||
addr := network.Addr()
|
sock, err := unix.Socket(unix.AF_ROUTE, unix.SOCK_RAW, unix.AF_UNSPEC)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("unable to create AF_ROUTE socket: %v", err)
|
||||||
|
}
|
||||||
|
defer unix.Close(sock)
|
||||||
|
|
||||||
if addr.Is4() {
|
route := &netroute.RouteMessage{
|
||||||
// For IPv4: ifconfig tun0 10.0.0.1/24
|
Version: unix.RTM_VERSION,
|
||||||
if err := runCommandBSD("ifconfig", name, network.String()); err != nil {
|
Type: unix.RTM_ADD,
|
||||||
return fmt.Errorf("failed to add IPv4 address: %w", err)
|
Flags: unix.RTF_UP | unix.RTF_GATEWAY,
|
||||||
|
Seq: 1,
|
||||||
|
}
|
||||||
|
|
||||||
|
if prefix.Addr().Is4() {
|
||||||
|
gw, err := selectGateway(prefix, gateways)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
route.Addrs = []netroute.Addr{
|
||||||
|
unix.RTAX_DST: &netroute.Inet4Addr{IP: prefix.Masked().Addr().As4()},
|
||||||
|
unix.RTAX_NETMASK: &netroute.Inet4Addr{IP: prefixToMask(prefix).As4()},
|
||||||
|
unix.RTAX_GATEWAY: &netroute.Inet4Addr{IP: gw.Addr().As4()},
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
// For IPv6: ifconfig tun0 inet6 add 2001:db8::1/64
|
gw, err := selectGateway(prefix, gateways)
|
||||||
if err := runCommandBSD("ifconfig", name, "inet6", "add", network.String()); err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed to add IPv6 address: %w", err)
|
return err
|
||||||
|
}
|
||||||
|
route.Addrs = []netroute.Addr{
|
||||||
|
unix.RTAX_DST: &netroute.Inet6Addr{IP: prefix.Masked().Addr().As16()},
|
||||||
|
unix.RTAX_NETMASK: &netroute.Inet6Addr{IP: prefixToMask(prefix).As16()},
|
||||||
|
unix.RTAX_GATEWAY: &netroute.Inet6Addr{IP: gw.Addr().As16()},
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
data, err := route.Marshal()
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to create route.RouteMessage: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err = unix.Write(sock, data[:])
|
||||||
|
if err != nil {
|
||||||
|
if errors.Is(err, unix.EEXIST) {
|
||||||
|
// Try to do a change
|
||||||
|
route.Type = unix.RTM_CHANGE
|
||||||
|
data, err = route.Marshal()
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to create route.RouteMessage for change: %w", err)
|
||||||
|
}
|
||||||
|
_, err = unix.Write(sock, data[:])
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return fmt.Errorf("failed to write route.RouteMessage to socket: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func runCommandBSD(name string, args ...string) error {
|
func delRoute(prefix netip.Prefix, gateways []netip.Prefix) error {
|
||||||
cmd := exec.Command(name, args...)
|
sock, err := unix.Socket(unix.AF_ROUTE, unix.SOCK_RAW, unix.AF_UNSPEC)
|
||||||
output, err := cmd.CombinedOutput()
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("%s %s failed: %w\nOutput: %s", name, strings.Join(args, " "), err, string(output))
|
return fmt.Errorf("unable to create AF_ROUTE socket: %v", err)
|
||||||
}
|
}
|
||||||
|
defer unix.Close(sock)
|
||||||
|
|
||||||
|
route := netroute.RouteMessage{
|
||||||
|
Version: unix.RTM_VERSION,
|
||||||
|
Type: unix.RTM_DELETE,
|
||||||
|
Seq: 1,
|
||||||
|
}
|
||||||
|
|
||||||
|
if prefix.Addr().Is4() {
|
||||||
|
gw, err := selectGateway(prefix, gateways)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
route.Addrs = []netroute.Addr{
|
||||||
|
unix.RTAX_DST: &netroute.Inet4Addr{IP: prefix.Masked().Addr().As4()},
|
||||||
|
unix.RTAX_NETMASK: &netroute.Inet4Addr{IP: prefixToMask(prefix).As4()},
|
||||||
|
unix.RTAX_GATEWAY: &netroute.Inet4Addr{IP: gw.Addr().As4()},
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
gw, err := selectGateway(prefix, gateways)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
route.Addrs = []netroute.Addr{
|
||||||
|
unix.RTAX_DST: &netroute.Inet6Addr{IP: prefix.Masked().Addr().As16()},
|
||||||
|
unix.RTAX_NETMASK: &netroute.Inet6Addr{IP: prefixToMask(prefix).As16()},
|
||||||
|
unix.RTAX_GATEWAY: &netroute.Inet6Addr{IP: gw.Addr().As16()},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
data, err := route.Marshal()
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to create route.RouteMessage: %w", err)
|
||||||
|
}
|
||||||
|
_, err = unix.Write(sock, data[:])
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to write route.RouteMessage to socket: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,242 +0,0 @@
|
|||||||
//go:build !android && !netbsd && !e2e_testing
|
|
||||||
// +build !android,!netbsd,!e2e_testing
|
|
||||||
|
|
||||||
package overlay
|
|
||||||
|
|
||||||
import (
|
|
||||||
"fmt"
|
|
||||||
"io"
|
|
||||||
"net/netip"
|
|
||||||
"sync/atomic"
|
|
||||||
|
|
||||||
"github.com/gaissmai/bart"
|
|
||||||
"github.com/sirupsen/logrus"
|
|
||||||
"github.com/slackhq/nebula/config"
|
|
||||||
"github.com/slackhq/nebula/routing"
|
|
||||||
"github.com/slackhq/nebula/util"
|
|
||||||
wgtun "golang.zx2c4.com/wireguard/tun"
|
|
||||||
)
|
|
||||||
|
|
||||||
// wgTun wraps a WireGuard TUN device and implements the overlay.Device interface
|
|
||||||
type wgTun struct {
|
|
||||||
tunDevice wgtun.Device
|
|
||||||
vpnNetworks []netip.Prefix
|
|
||||||
MaxMTU int
|
|
||||||
DefaultMTU int
|
|
||||||
|
|
||||||
Routes atomic.Pointer[[]Route]
|
|
||||||
routeTree atomic.Pointer[bart.Table[routing.Gateways]]
|
|
||||||
routeChan chan struct{}
|
|
||||||
|
|
||||||
// Platform-specific route management
|
|
||||||
routeManager *tun
|
|
||||||
|
|
||||||
l *logrus.Logger
|
|
||||||
}
|
|
||||||
|
|
||||||
// BatchReader interface for readers that support vectorized I/O
|
|
||||||
type BatchReader interface {
|
|
||||||
BatchRead(buffers [][]byte, sizes []int) (int, error)
|
|
||||||
}
|
|
||||||
|
|
||||||
// BatchWriter interface for writers that support vectorized I/O
|
|
||||||
type BatchWriter interface {
|
|
||||||
BatchWrite(packets [][]byte) (int, error)
|
|
||||||
}
|
|
||||||
|
|
||||||
// wgTunReader wraps a single TUN queue for multi-queue support
|
|
||||||
type wgTunReader struct {
|
|
||||||
parent *wgTun
|
|
||||||
tunDevice wgtun.Device
|
|
||||||
offset int
|
|
||||||
l *logrus.Logger
|
|
||||||
}
|
|
||||||
|
|
||||||
func (t *wgTun) Networks() []netip.Prefix {
|
|
||||||
return t.vpnNetworks
|
|
||||||
}
|
|
||||||
|
|
||||||
func (t *wgTun) Name() string {
|
|
||||||
name, err := t.tunDevice.Name()
|
|
||||||
if err != nil {
|
|
||||||
t.l.WithError(err).Error("Failed to get TUN device name")
|
|
||||||
return "unknown"
|
|
||||||
}
|
|
||||||
return name
|
|
||||||
}
|
|
||||||
|
|
||||||
func (t *wgTun) RoutesFor(ip netip.Addr) routing.Gateways {
|
|
||||||
r, _ := t.routeTree.Load().Lookup(ip)
|
|
||||||
return r
|
|
||||||
}
|
|
||||||
|
|
||||||
func (t *wgTun) Activate() error {
|
|
||||||
if t.routeManager == nil {
|
|
||||||
return fmt.Errorf("route manager not initialized")
|
|
||||||
}
|
|
||||||
return t.routeManager.Activate(t)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Read implements single-packet read for backward compatibility
|
|
||||||
func (t *wgTun) Read(b []byte) (int, error) {
|
|
||||||
bufs := [][]byte{b}
|
|
||||||
sizes := []int{0}
|
|
||||||
n, err := t.tunDevice.Read(bufs, sizes, 0)
|
|
||||||
if err != nil {
|
|
||||||
return 0, err
|
|
||||||
}
|
|
||||||
if n == 0 {
|
|
||||||
return 0, io.ErrNoProgress
|
|
||||||
}
|
|
||||||
return sizes[0], nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Write implements single-packet write for backward compatibility
|
|
||||||
func (t *wgTun) Write(b []byte) (int, error) {
|
|
||||||
bufs := [][]byte{b}
|
|
||||||
offset := 0
|
|
||||||
|
|
||||||
// WireGuard TUN expects the packet data to start at offset 0
|
|
||||||
n, err := t.tunDevice.Write(bufs, offset)
|
|
||||||
if err != nil {
|
|
||||||
return 0, err
|
|
||||||
}
|
|
||||||
if n == 0 {
|
|
||||||
return 0, io.ErrShortWrite
|
|
||||||
}
|
|
||||||
return len(b), nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (t *wgTun) Close() error {
|
|
||||||
if t.routeChan != nil {
|
|
||||||
close(t.routeChan)
|
|
||||||
}
|
|
||||||
|
|
||||||
if t.tunDevice != nil {
|
|
||||||
return t.tunDevice.Close()
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (t *wgTun) NewMultiQueueReader() (io.ReadWriteCloser, error) {
|
|
||||||
// For WireGuard TUN, we need to create separate TUN device instances for multi-queue
|
|
||||||
// The platform-specific implementation will handle this
|
|
||||||
if t.routeManager == nil {
|
|
||||||
return nil, fmt.Errorf("route manager not initialized for multi-queue reader")
|
|
||||||
}
|
|
||||||
|
|
||||||
return t.routeManager.NewMultiQueueReader(t)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (t *wgTun) reload(c *config.C, initial bool) error {
|
|
||||||
routeChange, routes, err := getAllRoutesFromConfig(c, t.vpnNetworks, initial)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
if !initial && !routeChange && !c.HasChanged("tun.mtu") {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
routeTree, err := makeRouteTree(t.l, routes, true)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
oldDefaultMTU := t.DefaultMTU
|
|
||||||
oldMaxMTU := t.MaxMTU
|
|
||||||
newDefaultMTU := c.GetInt("tun.mtu", DefaultMTU)
|
|
||||||
newMaxMTU := newDefaultMTU
|
|
||||||
for i, r := range routes {
|
|
||||||
if r.MTU == 0 {
|
|
||||||
routes[i].MTU = newDefaultMTU
|
|
||||||
}
|
|
||||||
|
|
||||||
if r.MTU > t.MaxMTU {
|
|
||||||
newMaxMTU = r.MTU
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
t.MaxMTU = newMaxMTU
|
|
||||||
t.DefaultMTU = newDefaultMTU
|
|
||||||
|
|
||||||
// Teach nebula how to handle the routes before establishing them in the system table
|
|
||||||
oldRoutes := t.Routes.Swap(&routes)
|
|
||||||
t.routeTree.Store(routeTree)
|
|
||||||
|
|
||||||
if !initial && t.routeManager != nil {
|
|
||||||
if oldMaxMTU != newMaxMTU {
|
|
||||||
t.routeManager.SetMTU(t, t.MaxMTU)
|
|
||||||
t.l.Infof("Set max MTU to %v was %v", t.MaxMTU, oldMaxMTU)
|
|
||||||
}
|
|
||||||
|
|
||||||
if oldDefaultMTU != newDefaultMTU {
|
|
||||||
for i := range t.vpnNetworks {
|
|
||||||
err := t.routeManager.SetDefaultRoute(t, t.vpnNetworks[i])
|
|
||||||
if err != nil {
|
|
||||||
t.l.Warn(err)
|
|
||||||
} else {
|
|
||||||
t.l.Infof("Set default MTU to %v was %v", t.DefaultMTU, oldDefaultMTU)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Remove first, if the system removes a wanted route hopefully it will be re-added next
|
|
||||||
t.routeManager.RemoveRoutes(t, findRemovedRoutes(routes, *oldRoutes))
|
|
||||||
|
|
||||||
// Ensure any routes we actually want are installed
|
|
||||||
err = t.routeManager.AddRoutes(t, true)
|
|
||||||
if err != nil {
|
|
||||||
// This should never be called since AddRoutes should log its own errors in a reload condition
|
|
||||||
util.LogWithContextIfNeeded("Failed to refresh routes", err, t.l)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// BatchRead reads multiple packets from the TUN device using vectorized I/O
|
|
||||||
// The caller provides buffers and sizes slices, and this function returns the number of packets read.
|
|
||||||
func (r *wgTunReader) BatchRead(buffers [][]byte, sizes []int) (int, error) {
|
|
||||||
return r.tunDevice.Read(buffers, sizes, r.offset)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Read implements io.Reader for wgTunReader (single packet for compatibility)
|
|
||||||
func (r *wgTunReader) Read(b []byte) (int, error) {
|
|
||||||
bufs := [][]byte{b}
|
|
||||||
sizes := []int{0}
|
|
||||||
n, err := r.tunDevice.Read(bufs, sizes, r.offset)
|
|
||||||
if err != nil {
|
|
||||||
return 0, err
|
|
||||||
}
|
|
||||||
if n == 0 {
|
|
||||||
return 0, io.ErrNoProgress
|
|
||||||
}
|
|
||||||
return sizes[0], nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Write implements io.Writer for wgTunReader
|
|
||||||
func (r *wgTunReader) Write(b []byte) (int, error) {
|
|
||||||
bufs := [][]byte{b}
|
|
||||||
n, err := r.tunDevice.Write(bufs, r.offset)
|
|
||||||
if err != nil {
|
|
||||||
return 0, err
|
|
||||||
}
|
|
||||||
if n == 0 {
|
|
||||||
return 0, io.ErrShortWrite
|
|
||||||
}
|
|
||||||
return len(b), nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// BatchWrite writes multiple packets to the TUN device using vectorized I/O
|
|
||||||
func (r *wgTunReader) BatchWrite(packets [][]byte) (int, error) {
|
|
||||||
return r.tunDevice.Write(packets, r.offset)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (r *wgTunReader) Close() error {
|
|
||||||
if r.tunDevice != nil {
|
|
||||||
return r.tunDevice.Close()
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
@@ -1,77 +1,84 @@
|
|||||||
//go:build windows && !e2e_testing
|
//go:build !e2e_testing
|
||||||
// +build windows,!e2e_testing
|
// +build !e2e_testing
|
||||||
|
|
||||||
package overlay
|
package overlay
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"crypto"
|
"crypto"
|
||||||
"encoding/binary"
|
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
"runtime"
|
||||||
|
"sync/atomic"
|
||||||
|
"syscall"
|
||||||
|
"unsafe"
|
||||||
|
|
||||||
|
"github.com/gaissmai/bart"
|
||||||
"github.com/sirupsen/logrus"
|
"github.com/sirupsen/logrus"
|
||||||
"github.com/slackhq/nebula/config"
|
"github.com/slackhq/nebula/config"
|
||||||
|
"github.com/slackhq/nebula/routing"
|
||||||
"github.com/slackhq/nebula/util"
|
"github.com/slackhq/nebula/util"
|
||||||
|
"github.com/slackhq/nebula/wintun"
|
||||||
"golang.org/x/sys/windows"
|
"golang.org/x/sys/windows"
|
||||||
wgtun "golang.zx2c4.com/wireguard/tun"
|
|
||||||
"golang.zx2c4.com/wireguard/windows/tunnel/winipcfg"
|
"golang.zx2c4.com/wireguard/windows/tunnel/winipcfg"
|
||||||
)
|
)
|
||||||
|
|
||||||
const tunGUIDLabel = "Fixed Nebula Windows GUID v1"
|
const tunGUIDLabel = "Fixed Nebula Windows GUID v1"
|
||||||
|
|
||||||
type tun struct {
|
type winTun struct {
|
||||||
luid winipcfg.LUID
|
Device string
|
||||||
|
vpnNetworks []netip.Prefix
|
||||||
|
MTU int
|
||||||
|
Routes atomic.Pointer[[]Route]
|
||||||
|
routeTree atomic.Pointer[bart.Table[routing.Gateways]]
|
||||||
|
l *logrus.Logger
|
||||||
|
|
||||||
|
tun *wintun.NativeTun
|
||||||
}
|
}
|
||||||
|
|
||||||
func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ []netip.Prefix) (*wgTun, error) {
|
func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ []netip.Prefix) (Device, error) {
|
||||||
return nil, fmt.Errorf("newTunFromFd not supported in Windows")
|
return nil, fmt.Errorf("newTunFromFd not supported in Windows")
|
||||||
}
|
}
|
||||||
|
|
||||||
func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, _ bool) (*wgTun, error) {
|
func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, _ bool) (*winTun, error) {
|
||||||
deviceName := c.GetString("tun.dev", "Nebula")
|
err := checkWinTunExists()
|
||||||
mtu := c.GetInt("tun.mtu", DefaultMTU)
|
|
||||||
|
|
||||||
// Create WireGuard TUN device
|
|
||||||
tunDevice, err := wgtun.CreateTUN(deviceName, mtu)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("failed to create TUN device: %w", err)
|
return nil, fmt.Errorf("can not load the wintun driver: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Get the actual device name
|
deviceName := c.GetString("tun.dev", "")
|
||||||
actualName, err := tunDevice.Name()
|
guid, err := generateGUIDByDeviceName(deviceName)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
tunDevice.Close()
|
return nil, fmt.Errorf("generate GUID failed: %w", err)
|
||||||
return nil, fmt.Errorf("failed to get TUN device name: %w", err)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
t := &wgTun{
|
t := &winTun{
|
||||||
tunDevice: tunDevice,
|
Device: deviceName,
|
||||||
vpnNetworks: vpnNetworks,
|
vpnNetworks: vpnNetworks,
|
||||||
MaxMTU: mtu,
|
MTU: c.GetInt("tun.mtu", DefaultMTU),
|
||||||
DefaultMTU: mtu,
|
|
||||||
l: l,
|
l: l,
|
||||||
}
|
}
|
||||||
|
|
||||||
// Create Windows-specific route manager
|
|
||||||
rm := &tun{}
|
|
||||||
|
|
||||||
// Get LUID from the TUN device
|
|
||||||
// The WireGuard TUN device on Windows should provide a LUID() method
|
|
||||||
if nativeTun, ok := tunDevice.(interface{ LUID() uint64 }); ok {
|
|
||||||
rm.luid = winipcfg.LUID(nativeTun.LUID())
|
|
||||||
} else {
|
|
||||||
tunDevice.Close()
|
|
||||||
return nil, fmt.Errorf("failed to get LUID from TUN device")
|
|
||||||
}
|
|
||||||
t.routeManager = rm
|
|
||||||
|
|
||||||
err = t.reload(c, true)
|
err = t.reload(c, true)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
tunDevice.Close()
|
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
var tunDevice wintun.Device
|
||||||
|
tunDevice, err = wintun.CreateTUNWithRequestedGUID(deviceName, guid, t.MTU)
|
||||||
|
if err != nil {
|
||||||
|
// Windows 10 has an issue with unclean shutdowns not fully cleaning up the wintun device.
|
||||||
|
// Trying a second time resolves the issue.
|
||||||
|
l.WithError(err).Debug("Failed to create wintun device, retrying")
|
||||||
|
tunDevice, err = wintun.CreateTUNWithRequestedGUID(deviceName, guid, t.MTU)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("create TUN device failed: %w", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
t.tun = tunDevice.(*wintun.NativeTun)
|
||||||
|
|
||||||
c.RegisterReloadCallback(func(c *config.C) {
|
c.RegisterReloadCallback(func(c *config.C) {
|
||||||
err := t.reload(c, false)
|
err := t.reload(c, false)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -79,140 +86,206 @@ func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, _ bool) (
|
|||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
l.WithField("name", actualName).Info("Created WireGuard TUN device")
|
|
||||||
|
|
||||||
return t, nil
|
return t, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (rm *tun) Activate(t *wgTun) error {
|
func (t *winTun) reload(c *config.C, initial bool) error {
|
||||||
// Set MTU
|
change, routes, err := getAllRoutesFromConfig(c, t.vpnNetworks, initial)
|
||||||
err := rm.setMTU(t, t.MaxMTU)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed to set MTU: %w", err)
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
// Add IP addresses
|
if !initial && !change {
|
||||||
for _, network := range t.vpnNetworks {
|
return nil
|
||||||
if err := rm.addIP(t, network); err != nil {
|
}
|
||||||
return err
|
|
||||||
|
routeTree, err := makeRouteTree(t.l, routes, false)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Teach nebula how to handle the routes before establishing them in the system table
|
||||||
|
oldRoutes := t.Routes.Swap(&routes)
|
||||||
|
t.routeTree.Store(routeTree)
|
||||||
|
|
||||||
|
if !initial {
|
||||||
|
// Remove first, if the system removes a wanted route hopefully it will be re-added next
|
||||||
|
err := t.removeRoutes(findRemovedRoutes(routes, *oldRoutes))
|
||||||
|
if err != nil {
|
||||||
|
util.LogWithContextIfNeeded("Failed to remove routes", err, t.l)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Ensure any routes we actually want are installed
|
||||||
|
err = t.addRoutes(true)
|
||||||
|
if err != nil {
|
||||||
|
// Catch any stray logs
|
||||||
|
util.LogWithContextIfNeeded("Failed to add routes", err, t.l)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Add routes
|
return nil
|
||||||
if err := rm.AddRoutes(t, false); err != nil {
|
}
|
||||||
|
|
||||||
|
func (t *winTun) Activate() error {
|
||||||
|
luid := winipcfg.LUID(t.tun.LUID())
|
||||||
|
|
||||||
|
err := luid.SetIPAddresses(t.vpnNetworks)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to set address: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
err = t.addRoutes(false)
|
||||||
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (rm *tun) SetMTU(t *wgTun, mtu int) {
|
func (t *winTun) addRoutes(logErrors bool) error {
|
||||||
if err := rm.setMTU(t, mtu); err != nil {
|
luid := winipcfg.LUID(t.tun.LUID())
|
||||||
t.l.WithError(err).Error("Failed to set MTU")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (rm *tun) setMTU(t *wgTun, mtu int) error {
|
|
||||||
// Set MTU using winipcfg
|
|
||||||
// Note: MTU setting on Windows TUN devices may be handled by the driver
|
|
||||||
// For now, we'll skip explicit MTU setting as the WireGuard TUN handles it
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (rm *tun) SetDefaultRoute(t *wgTun, cidr netip.Prefix) error {
|
|
||||||
// On Windows, routes are managed differently
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (rm *tun) AddRoutes(t *wgTun, logErrors bool) error {
|
|
||||||
routes := *t.Routes.Load()
|
routes := *t.Routes.Load()
|
||||||
|
foundDefault4 := false
|
||||||
|
|
||||||
for _, r := range routes {
|
for _, r := range routes {
|
||||||
if !r.Install {
|
if len(r.Via) == 0 || !r.Install {
|
||||||
|
// We don't allow route MTUs so only install routes with a via
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
if r.MTU > 0 {
|
// Add our unsafe route
|
||||||
// Windows route MTU is not directly supported
|
// Windows does not support multipath routes natively, so we install only a single route.
|
||||||
t.l.WithField("route", r).Debug("Route MTU is not supported on Windows")
|
// This is not a problem as traffic will always be sent to Nebula which handles the multipath routing internally.
|
||||||
}
|
// In effect this provides multipath routing support to windows supporting loadbalancing and redundancy.
|
||||||
|
err := luid.AddRoute(r.Cidr, r.Via[0].Addr(), uint32(r.Metric))
|
||||||
// Use winipcfg to add the route
|
|
||||||
// The rm.luid should have the AddRoute method from winipcfg
|
|
||||||
if len(r.Via) == 0 {
|
|
||||||
t.l.WithField("route", r).Warn("Route has no via address, skipping")
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
err := rm.luid.AddRoute(r.Cidr, r.Via[0].Addr(), uint32(r.Metric))
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
retErr := util.NewContextualError("Failed to add route", map[string]any{"route": r}, err)
|
retErr := util.NewContextualError("Failed to add route", map[string]any{"route": r}, err)
|
||||||
if logErrors {
|
if logErrors {
|
||||||
retErr.Log(t.l)
|
retErr.Log(t.l)
|
||||||
|
continue
|
||||||
} else {
|
} else {
|
||||||
return retErr
|
return retErr
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
t.l.WithField("route", r).Info("Added route")
|
t.l.WithField("route", r).Info("Added route")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if !foundDefault4 {
|
||||||
|
if r.Cidr.Bits() == 0 && r.Cidr.Addr().BitLen() == 32 {
|
||||||
|
foundDefault4 = true
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
ipif, err := luid.IPInterface(windows.AF_INET)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to get ip interface: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
ipif.NLMTU = uint32(t.MTU)
|
||||||
|
if foundDefault4 {
|
||||||
|
ipif.UseAutomaticMetric = false
|
||||||
|
ipif.Metric = 0
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := ipif.Set(); err != nil {
|
||||||
|
return fmt.Errorf("failed to set ip interface: %w", err)
|
||||||
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (rm *tun) RemoveRoutes(t *wgTun, routes []Route) {
|
func (t *winTun) removeRoutes(routes []Route) error {
|
||||||
|
luid := winipcfg.LUID(t.tun.LUID())
|
||||||
|
|
||||||
for _, r := range routes {
|
for _, r := range routes {
|
||||||
if !r.Install {
|
if !r.Install {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(r.Via) == 0 {
|
// See comment on luid.AddRoute
|
||||||
continue
|
err := luid.DeleteRoute(r.Cidr, r.Via[0].Addr())
|
||||||
}
|
|
||||||
|
|
||||||
err := rm.luid.DeleteRoute(r.Cidr, r.Via[0].Addr())
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.l.WithError(err).WithField("route", r).Error("Failed to remove route")
|
t.l.WithError(err).WithField("route", r).Error("Failed to remove route")
|
||||||
} else {
|
} else {
|
||||||
t.l.WithField("route", r).Info("Removed route")
|
t.l.WithField("route", r).Info("Removed route")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
|
||||||
|
|
||||||
func (rm *tun) NewMultiQueueReader(t *wgTun) (io.ReadWriteCloser, error) {
|
|
||||||
// Windows doesn't support multi-queue TUN devices
|
|
||||||
// Return a reader that wraps the same device
|
|
||||||
return &wgTunReader{
|
|
||||||
parent: t,
|
|
||||||
tunDevice: t.tunDevice,
|
|
||||||
offset: 0,
|
|
||||||
l: t.l,
|
|
||||||
}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (rm *tun) addIP(t *wgTun, network netip.Prefix) error {
|
|
||||||
// Add IP address using winipcfg
|
|
||||||
// SetIPAddresses expects a slice of prefixes
|
|
||||||
err := rm.luid.SetIPAddresses([]netip.Prefix{network})
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("failed to add IP address %s: %w", network, err)
|
|
||||||
}
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// generateGUIDByDeviceName generates a GUID based on the device name
|
func (t *winTun) RoutesFor(ip netip.Addr) routing.Gateways {
|
||||||
func generateGUIDByDeviceName(deviceName string) (*windows.GUID, error) {
|
r, _ := t.routeTree.Load().Lookup(ip)
|
||||||
// Hash the device name to create a deterministic GUID
|
return r
|
||||||
h := crypto.SHA256.New()
|
}
|
||||||
h.Write([]byte(tunGUIDLabel))
|
|
||||||
h.Write([]byte(deviceName))
|
func (t *winTun) Networks() []netip.Prefix {
|
||||||
sum := h.Sum(nil)
|
return t.vpnNetworks
|
||||||
|
}
|
||||||
guid := &windows.GUID{
|
|
||||||
Data1: binary.LittleEndian.Uint32(sum[0:4]),
|
func (t *winTun) Name() string {
|
||||||
Data2: binary.LittleEndian.Uint16(sum[4:6]),
|
return t.Device
|
||||||
Data3: binary.LittleEndian.Uint16(sum[6:8]),
|
}
|
||||||
}
|
|
||||||
copy(guid.Data4[:], sum[8:16])
|
func (t *winTun) Read(b []byte) (int, error) {
|
||||||
|
return t.tun.Read(b, 0)
|
||||||
return guid, nil
|
}
|
||||||
|
|
||||||
|
func (t *winTun) Write(b []byte) (int, error) {
|
||||||
|
return t.tun.Write(b, 0)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *winTun) NewMultiQueueReader() (io.ReadWriteCloser, error) {
|
||||||
|
return nil, fmt.Errorf("TODO: multiqueue not implemented for windows")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *winTun) Close() error {
|
||||||
|
// It seems that the Windows networking stack doesn't like it when we destroy interfaces that have active routes,
|
||||||
|
// so to be certain, just remove everything before destroying.
|
||||||
|
luid := winipcfg.LUID(t.tun.LUID())
|
||||||
|
_ = luid.FlushRoutes(windows.AF_INET)
|
||||||
|
_ = luid.FlushIPAddresses(windows.AF_INET)
|
||||||
|
|
||||||
|
_ = luid.FlushRoutes(windows.AF_INET6)
|
||||||
|
_ = luid.FlushIPAddresses(windows.AF_INET6)
|
||||||
|
|
||||||
|
_ = luid.FlushDNS(windows.AF_INET)
|
||||||
|
_ = luid.FlushDNS(windows.AF_INET6)
|
||||||
|
|
||||||
|
return t.tun.Close()
|
||||||
|
}
|
||||||
|
|
||||||
|
func generateGUIDByDeviceName(name string) (*windows.GUID, error) {
|
||||||
|
// GUID is 128 bit
|
||||||
|
hash := crypto.MD5.New()
|
||||||
|
|
||||||
|
_, err := hash.Write([]byte(tunGUIDLabel))
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err = hash.Write([]byte(name))
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
sum := hash.Sum(nil)
|
||||||
|
|
||||||
|
return (*windows.GUID)(unsafe.Pointer(&sum[0])), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func checkWinTunExists() error {
|
||||||
|
myPath, err := os.Executable()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
arch := runtime.GOARCH
|
||||||
|
switch arch {
|
||||||
|
case "386":
|
||||||
|
//NOTE: wintun bundles 386 as x86
|
||||||
|
arch = "x86"
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err = syscall.LoadDLL(filepath.Join(filepath.Dir(myPath), "dist", "windows", "wintun", "bin", arch, "wintun.dll"))
|
||||||
|
return err
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user