From f86953ca56e623ec7629ef7753024d8ced944a72 Mon Sep 17 00:00:00 2001 From: dioss-Machiel Date: Mon, 24 Mar 2025 23:15:59 +0100 Subject: [PATCH] Implement ECMP for unsafe_routes (#1332) --- examples/config.yml | 23 ++++++- inside.go | 95 +++++++++++++++++++++++--- overlay/device.go | 4 +- overlay/route.go | 78 ++++++++++++++++++---- overlay/route_test.go | 112 ++++++++++++++++++++++++++++++- overlay/tun_android.go | 5 +- overlay/tun_darwin.go | 9 +-- overlay/tun_disabled.go | 5 +- overlay/tun_freebsd.go | 7 +- overlay/tun_ios.go | 5 +- overlay/tun_linux.go | 91 +++++++++++++++++++------ overlay/tun_netbsd.go | 7 +- overlay/tun_openbsd.go | 7 +- overlay/tun_tester.go | 5 +- overlay/tun_windows.go | 15 +++-- overlay/user.go | 11 ++- routing/balance.go | 39 +++++++++++ routing/balance_test.go | 144 ++++++++++++++++++++++++++++++++++++++++ routing/gateway.go | 70 +++++++++++++++++++ routing/gateway_test.go | 34 ++++++++++ test/tun.go | 6 +- 21 files changed, 690 insertions(+), 82 deletions(-) create mode 100644 routing/balance.go create mode 100644 routing/balance_test.go create mode 100644 routing/gateway.go create mode 100644 routing/gateway_test.go diff --git a/examples/config.yml b/examples/config.yml index 4e7a4ae..3b7c38b 100644 --- a/examples/config.yml +++ b/examples/config.yml @@ -239,7 +239,28 @@ tun: # Unsafe routes allows you to route traffic over nebula to non-nebula nodes # Unsafe routes should be avoided unless you have hosts/services that cannot run nebula - # NOTE: The nebula certificate of the "via" node *MUST* have the "route" defined as a subnet in its certificate + # Supports weighted ECMP if you define a list of gateways, this can be used for load balancing or redundancy to hosts outside of nebula + # NOTES: + # * You will only see a single gateway in the routing table if you are not on linux + # * If a gateway is not reachable through the overlay another gateway will be selected to send the traffic through, ignoring weights + # + # unsafe_routes: + # # Multiple gateways without defining a weight defaults to a weight of 1, this will balance traffic equally between the three gateways + # - route: 192.168.87.0/24 + # via: + # - gateway: 10.0.0.1 + # - gateway: 10.0.0.2 + # - gateway: 10.0.0.3 + # # Multiple gateways with a weight, this will balance traffic accordingly + # - route: 192.168.87.0/24 + # via: + # - gateway: 10.0.0.1 + # weight: 10 + # - gateway: 10.0.0.2 + # weight: 5 + # + # NOTE: The nebula certificate of the "via" node(s) *MUST* have the "route" defined as a subnet in its certificate + # `via`: single node or list of gateways to use for this route # `mtu`: will default to tun mtu if this option is not specified # `metric`: will default to 0 if this option is not specified # `install`: will default to true, controls whether this route is installed in the systems routing table. diff --git a/inside.go b/inside.go index 9629947..0af350d 100644 --- a/inside.go +++ b/inside.go @@ -8,6 +8,7 @@ import ( "github.com/slackhq/nebula/header" "github.com/slackhq/nebula/iputil" "github.com/slackhq/nebula/noiseutil" + "github.com/slackhq/nebula/routing" ) func (f *Interface) consumeInsidePacket(packet []byte, fwPacket *firewall.Packet, nb, out []byte, q int, localCache firewall.ConntrackCache) { @@ -49,7 +50,7 @@ func (f *Interface) consumeInsidePacket(packet []byte, fwPacket *firewall.Packet return } - hostinfo, ready := f.getOrHandshake(fwPacket.RemoteAddr, func(hh *HandshakeHostInfo) { + hostinfo, ready := f.getOrHandshakeConsiderRouting(fwPacket, func(hh *HandshakeHostInfo) { hh.cachePacket(f.l, header.Message, 0, packet, f.sendMessageNow, f.cachedPacketMetrics) }) @@ -121,22 +122,94 @@ func (f *Interface) rejectOutside(packet []byte, ci *ConnectionState, hostinfo * f.sendNoMetrics(header.Message, 0, ci, hostinfo, netip.AddrPort{}, out, nb, packet, q) } +// Handshake will attempt to initiate a tunnel with the provided vpn address if it is within our vpn networks. This is a no-op if the tunnel is already established or being established func (f *Interface) Handshake(vpnAddr netip.Addr) { - f.getOrHandshake(vpnAddr, nil) + f.getOrHandshakeNoRouting(vpnAddr, nil) } -// getOrHandshake returns nil if the vpnAddr is not routable. +// getOrHandshakeNoRouting returns nil if the vpnAddr is not routable. // If the 2nd return var is false then the hostinfo is not ready to be used in a tunnel -func (f *Interface) getOrHandshake(vpnAddr netip.Addr, cacheCallback func(*HandshakeHostInfo)) (*HostInfo, bool) { +func (f *Interface) getOrHandshakeNoRouting(vpnAddr netip.Addr, cacheCallback func(*HandshakeHostInfo)) (*HostInfo, bool) { _, found := f.myVpnNetworksTable.Lookup(vpnAddr) - if !found { - vpnAddr = f.inside.RouteFor(vpnAddr) - if !vpnAddr.IsValid() { - return nil, false - } + if found { + return f.handshakeManager.GetOrHandshake(vpnAddr, cacheCallback) + } + + return nil, false +} + +// getOrHandshakeConsiderRouting will try to find the HostInfo to handle this packet, starting a handshake if necessary. +// If the 2nd return var is false then the hostinfo is not ready to be used in a tunnel. +func (f *Interface) getOrHandshakeConsiderRouting(fwPacket *firewall.Packet, cacheCallback func(*HandshakeHostInfo)) (*HostInfo, bool) { + + destinationAddr := fwPacket.RemoteAddr + + hostinfo, ready := f.getOrHandshakeNoRouting(destinationAddr, cacheCallback) + + // Host is inside the mesh, no routing required + if hostinfo != nil { + return hostinfo, ready + } + + gateways := f.inside.RoutesFor(destinationAddr) + + switch len(gateways) { + case 0: + return nil, false + case 1: + // Single gateway route + return f.handshakeManager.GetOrHandshake(gateways[0].Addr(), cacheCallback) + default: + // Multi gateway route, perform ECMP categorization + gatewayAddr, balancingOk := routing.BalancePacket(fwPacket, gateways) + + if !balancingOk { + // This happens if the gateway buckets were not calculated, this _should_ never happen + f.l.Error("Gateway buckets not calculated, fallback from ECMP to random routing. Please report this bug.") + } + + var handshakeInfoForChosenGateway *HandshakeHostInfo + var hhReceiver = func(hh *HandshakeHostInfo) { + handshakeInfoForChosenGateway = hh + } + + // Store the handshakeHostInfo for later. + // If this node is not reachable we will attempt other nodes, if none are reachable we will + // cache the packet for this gateway. + if hostinfo, ready = f.handshakeManager.GetOrHandshake(gatewayAddr, hhReceiver); ready { + return hostinfo, true + } + + // It appears the selected gateway cannot be reached, find another gateway to fallback on. + // The current implementation breaks ECMP but that seems better than no connectivity. + // If ECMP is also required when a gateway is down then connectivity status + // for each gateway needs to be kept and the weights recalculated when they go up or down. + // This would also need to interact with unsafe_route updates through reloading the config or + // use of the use_system_route_table option + + if f.l.Level >= logrus.DebugLevel { + f.l.WithField("destination", destinationAddr). + WithField("originalGateway", gatewayAddr). + Debugln("Calculated gateway for ECMP not available, attempting other gateways") + } + + for i := range gateways { + // Skip the gateway that failed previously + if gateways[i].Addr() == gatewayAddr { + continue + } + + // We do not need the HandshakeHostInfo since we cache the packet in the originally chosen gateway + if hostinfo, ready = f.handshakeManager.GetOrHandshake(gateways[i].Addr(), nil); ready { + return hostinfo, true + } + } + + // No gateways reachable, cache the packet in the originally chosen gateway + cacheCallback(handshakeInfoForChosenGateway) + return hostinfo, false } - return f.handshakeManager.GetOrHandshake(vpnAddr, cacheCallback) } func (f *Interface) sendMessageNow(t header.MessageType, st header.MessageSubType, hostinfo *HostInfo, p, nb, out []byte) { @@ -163,7 +236,7 @@ func (f *Interface) sendMessageNow(t header.MessageType, st header.MessageSubTyp // SendMessageToVpnAddr handles real addr:port lookup and sends to the current best known address for vpnAddr func (f *Interface) SendMessageToVpnAddr(t header.MessageType, st header.MessageSubType, vpnAddr netip.Addr, p, nb, out []byte) { - hostInfo, ready := f.getOrHandshake(vpnAddr, func(hh *HandshakeHostInfo) { + hostInfo, ready := f.getOrHandshakeNoRouting(vpnAddr, func(hh *HandshakeHostInfo) { hh.cachePacket(f.l, t, st, p, f.SendMessageToHostInfo, f.cachedPacketMetrics) }) diff --git a/overlay/device.go b/overlay/device.go index da8cbe9..07146ab 100644 --- a/overlay/device.go +++ b/overlay/device.go @@ -3,6 +3,8 @@ package overlay import ( "io" "net/netip" + + "github.com/slackhq/nebula/routing" ) type Device interface { @@ -10,6 +12,6 @@ type Device interface { Activate() error Networks() []netip.Prefix Name() string - RouteFor(netip.Addr) netip.Addr + RoutesFor(netip.Addr) routing.Gateways NewMultiQueueReader() (io.ReadWriteCloser, error) } diff --git a/overlay/route.go b/overlay/route.go index 687cc11..12364ec 100644 --- a/overlay/route.go +++ b/overlay/route.go @@ -11,13 +11,14 @@ import ( "github.com/gaissmai/bart" "github.com/sirupsen/logrus" "github.com/slackhq/nebula/config" + "github.com/slackhq/nebula/routing" ) type Route struct { MTU int Metric int Cidr netip.Prefix - Via netip.Addr + Via routing.Gateways Install bool } @@ -47,15 +48,17 @@ func (r Route) String() string { return s } -func makeRouteTree(l *logrus.Logger, routes []Route, allowMTU bool) (*bart.Table[netip.Addr], error) { - routeTree := new(bart.Table[netip.Addr]) +func makeRouteTree(l *logrus.Logger, routes []Route, allowMTU bool) (*bart.Table[routing.Gateways], error) { + routeTree := new(bart.Table[routing.Gateways]) for _, r := range routes { if !allowMTU && r.MTU > 0 { l.WithField("route", r).Warnf("route MTU is not supported in %s", runtime.GOOS) } - if r.Via.IsValid() { - routeTree.Insert(r.Cidr, r.Via) + gateways := r.Via + if len(gateways) > 0 { + routing.CalculateBucketsForGateways(gateways) + routeTree.Insert(r.Cidr, gateways) } } return routeTree, nil @@ -201,14 +204,63 @@ func parseUnsafeRoutes(c *config.C, networks []netip.Prefix) ([]Route, error) { return nil, fmt.Errorf("entry %v.via in tun.unsafe_routes is not present", i+1) } - via, ok := rVia.(string) - if !ok { - return nil, fmt.Errorf("entry %v.via in tun.unsafe_routes is not a string: found %T", i+1, rVia) - } + var gateways routing.Gateways - viaVpnIp, err := netip.ParseAddr(via) - if err != nil { - return nil, fmt.Errorf("entry %v.via in tun.unsafe_routes failed to parse address: %v", i+1, err) + switch via := rVia.(type) { + case string: + viaIp, err := netip.ParseAddr(via) + if err != nil { + return nil, fmt.Errorf("entry %v.via in tun.unsafe_routes failed to parse address: %v", i+1, err) + } + + gateways = routing.Gateways{routing.NewGateway(viaIp, 1)} + + case []interface{}: + gateways = make(routing.Gateways, len(via)) + for ig, v := range via { + gatewayMap, ok := v.(map[interface{}]interface{}) + if !ok { + return nil, fmt.Errorf("entry %v in tun.unsafe_routes[%v].via is invalid", i+1, ig+1) + } + + rGateway, ok := gatewayMap["gateway"] + if !ok { + return nil, fmt.Errorf("entry .gateway in tun.unsafe_routes[%v].via[%v] is not present", i+1, ig+1) + } + + parsedGateway, ok := rGateway.(string) + if !ok { + return nil, fmt.Errorf("entry .gateway in tun.unsafe_routes[%v].via[%v] is not a string", i+1, ig+1) + } + + gatewayIp, err := netip.ParseAddr(parsedGateway) + if err != nil { + return nil, fmt.Errorf("entry .gateway in tun.unsafe_routes[%v].via[%v] failed to parse address: %v", i+1, ig+1, err) + } + + rGatewayWeight, ok := gatewayMap["weight"] + if !ok { + rGatewayWeight = 1 + } + + gatewayWeight, ok := rGatewayWeight.(int) + if !ok { + _, err = strconv.ParseInt(rGatewayWeight.(string), 10, 32) + if err != nil { + return nil, fmt.Errorf("entry .weight in tun.unsafe_routes[%v].via[%v] is not an integer", i+1, ig+1) + } + } + + if gatewayWeight < 1 || gatewayWeight > math.MaxInt32 { + return nil, fmt.Errorf("entry .weight in tun.unsafe_routes[%v].via[%v] is not in range (1-%d) : %v", i+1, ig+1, math.MaxInt32, gatewayWeight) + } + + gateways[ig] = routing.NewGateway(gatewayIp, gatewayWeight) + + } + + default: + return nil, fmt.Errorf("entry %v.via in tun.unsafe_routes is not a string or list of gateways: found %T", i+1, rVia) } rRoute, ok := m["route"] @@ -226,7 +278,7 @@ func parseUnsafeRoutes(c *config.C, networks []netip.Prefix) ([]Route, error) { } r := Route{ - Via: viaVpnIp, + Via: gateways, MTU: mtu, Metric: metric, Install: install, diff --git a/overlay/route_test.go b/overlay/route_test.go index 8f2c094..eb5e914 100644 --- a/overlay/route_test.go +++ b/overlay/route_test.go @@ -6,6 +6,7 @@ import ( "testing" "github.com/slackhq/nebula/config" + "github.com/slackhq/nebula/routing" "github.com/slackhq/nebula/test" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -158,15 +159,39 @@ func Test_parseUnsafeRoutes(t *testing.T) { c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{"via": invalidValue}}} routes, err = parseUnsafeRoutes(c, []netip.Prefix{n}) assert.Nil(t, routes) - require.EqualError(t, err, fmt.Sprintf("entry 1.via in tun.unsafe_routes is not a string: found %T", invalidValue)) + require.EqualError(t, err, fmt.Sprintf("entry 1.via in tun.unsafe_routes is not a string or list of gateways: found %T", invalidValue)) } + // Unparsable list of via + c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{"via": []string{"1", "2"}}}} + routes, err = parseUnsafeRoutes(c, []netip.Prefix{n}) + assert.Nil(t, routes) + require.EqualError(t, err, "entry 1.via in tun.unsafe_routes is not a string or list of gateways: found []string") + // unparsable via c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{"mtu": "500", "via": "nope"}}} routes, err = parseUnsafeRoutes(c, []netip.Prefix{n}) assert.Nil(t, routes) require.EqualError(t, err, "entry 1.via in tun.unsafe_routes failed to parse address: ParseAddr(\"nope\"): unable to parse IP") + // unparsable gateway + c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{"mtu": "500", "via": []interface{}{map[interface{}]interface{}{"gateway": "1"}}}}} + routes, err = parseUnsafeRoutes(c, []netip.Prefix{n}) + assert.Nil(t, routes) + require.EqualError(t, err, "entry .gateway in tun.unsafe_routes[1].via[1] failed to parse address: ParseAddr(\"1\"): unable to parse IP") + + // missing gateway element + c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{"mtu": "500", "via": []interface{}{map[interface{}]interface{}{"weight": "1"}}}}} + routes, err = parseUnsafeRoutes(c, []netip.Prefix{n}) + assert.Nil(t, routes) + require.EqualError(t, err, "entry .gateway in tun.unsafe_routes[1].via[1] is not present") + + // unparsable weight element + c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{"mtu": "500", "via": []interface{}{map[interface{}]interface{}{"gateway": "10.0.0.1", "weight": "a"}}}}} + routes, err = parseUnsafeRoutes(c, []netip.Prefix{n}) + assert.Nil(t, routes) + require.EqualError(t, err, "entry .weight in tun.unsafe_routes[1].via[1] is not an integer") + // missing route c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{"via": "127.0.0.1", "mtu": "500"}}} routes, err = parseUnsafeRoutes(c, []netip.Prefix{n}) @@ -280,7 +305,7 @@ func Test_makeRouteTree(t *testing.T) { nip, err := netip.ParseAddr("192.168.0.1") require.NoError(t, err) - assert.Equal(t, nip, r) + assert.Equal(t, nip, r[0].Addr()) ip, err = netip.ParseAddr("1.0.0.1") require.NoError(t, err) @@ -289,10 +314,91 @@ func Test_makeRouteTree(t *testing.T) { nip, err = netip.ParseAddr("192.168.0.2") require.NoError(t, err) - assert.Equal(t, nip, r) + assert.Equal(t, nip, r[0].Addr()) ip, err = netip.ParseAddr("1.1.0.1") require.NoError(t, err) r, ok = routeTree.Lookup(ip) assert.False(t, ok) } + +func Test_makeMultipathUnsafeRouteTree(t *testing.T) { + l := test.NewLogger() + c := config.NewC(l) + n, err := netip.ParsePrefix("10.0.0.0/24") + require.NoError(t, err) + + c.Settings["tun"] = map[interface{}]interface{}{ + "unsafe_routes": []interface{}{ + map[interface{}]interface{}{ + "route": "192.168.86.0/24", + "via": "192.168.100.10", + }, + map[interface{}]interface{}{ + "route": "192.168.87.0/24", + "via": []interface{}{ + map[interface{}]interface{}{ + "gateway": "10.0.0.1", + }, + map[interface{}]interface{}{ + "gateway": "10.0.0.2", + }, + map[interface{}]interface{}{ + "gateway": "10.0.0.3", + }, + }, + }, + map[interface{}]interface{}{ + "route": "192.168.89.0/24", + "via": []interface{}{ + map[interface{}]interface{}{ + "gateway": "10.0.0.1", + "weight": 10, + }, + map[interface{}]interface{}{ + "gateway": "10.0.0.2", + "weight": 5, + }, + }, + }, + }, + } + + routes, err := parseUnsafeRoutes(c, []netip.Prefix{n}) + require.NoError(t, err) + assert.Len(t, routes, 3) + routeTree, err := makeRouteTree(l, routes, true) + require.NoError(t, err) + + ip, err := netip.ParseAddr("192.168.86.1") + require.NoError(t, err) + r, ok := routeTree.Lookup(ip) + assert.True(t, ok) + + nip, err := netip.ParseAddr("192.168.100.10") + require.NoError(t, err) + assert.Equal(t, nip, r[0].Addr()) + + ip, err = netip.ParseAddr("192.168.87.1") + require.NoError(t, err) + r, ok = routeTree.Lookup(ip) + assert.True(t, ok) + + expectedGateways := routing.Gateways{routing.NewGateway(netip.MustParseAddr("10.0.0.1"), 1), + routing.NewGateway(netip.MustParseAddr("10.0.0.2"), 1), + routing.NewGateway(netip.MustParseAddr("10.0.0.3"), 1)} + + routing.CalculateBucketsForGateways(expectedGateways) + assert.ElementsMatch(t, expectedGateways, r) + + ip, err = netip.ParseAddr("192.168.89.1") + require.NoError(t, err) + r, ok = routeTree.Lookup(ip) + assert.True(t, ok) + + expectedGateways = routing.Gateways{routing.NewGateway(netip.MustParseAddr("10.0.0.1"), 10), + routing.NewGateway(netip.MustParseAddr("10.0.0.2"), 5)} + + routing.CalculateBucketsForGateways(expectedGateways) + assert.ElementsMatch(t, expectedGateways, r) +} diff --git a/overlay/tun_android.go b/overlay/tun_android.go index 72a6565..df1ed8d 100644 --- a/overlay/tun_android.go +++ b/overlay/tun_android.go @@ -13,6 +13,7 @@ import ( "github.com/gaissmai/bart" "github.com/sirupsen/logrus" "github.com/slackhq/nebula/config" + "github.com/slackhq/nebula/routing" "github.com/slackhq/nebula/util" ) @@ -21,7 +22,7 @@ type tun struct { fd int vpnNetworks []netip.Prefix Routes atomic.Pointer[[]Route] - routeTree atomic.Pointer[bart.Table[netip.Addr]] + routeTree atomic.Pointer[bart.Table[routing.Gateways]] l *logrus.Logger } @@ -56,7 +57,7 @@ func newTun(_ *config.C, _ *logrus.Logger, _ []netip.Prefix, _ bool) (*tun, erro return nil, fmt.Errorf("newTun not supported in Android") } -func (t *tun) RouteFor(ip netip.Addr) netip.Addr { +func (t *tun) RoutesFor(ip netip.Addr) routing.Gateways { r, _ := t.routeTree.Load().Lookup(ip) return r } diff --git a/overlay/tun_darwin.go b/overlay/tun_darwin.go index 1a02b49..d2b2896 100644 --- a/overlay/tun_darwin.go +++ b/overlay/tun_darwin.go @@ -17,6 +17,7 @@ import ( "github.com/gaissmai/bart" "github.com/sirupsen/logrus" "github.com/slackhq/nebula/config" + "github.com/slackhq/nebula/routing" "github.com/slackhq/nebula/util" netroute "golang.org/x/net/route" "golang.org/x/sys/unix" @@ -28,7 +29,7 @@ type tun struct { vpnNetworks []netip.Prefix DefaultMTU int Routes atomic.Pointer[[]Route] - routeTree atomic.Pointer[bart.Table[netip.Addr]] + routeTree atomic.Pointer[bart.Table[routing.Gateways]] linkAddr *netroute.LinkAddr l *logrus.Logger @@ -342,12 +343,12 @@ func (t *tun) reload(c *config.C, initial bool) error { return nil } -func (t *tun) RouteFor(ip netip.Addr) netip.Addr { +func (t *tun) RoutesFor(ip netip.Addr) routing.Gateways { r, ok := t.routeTree.Load().Lookup(ip) if ok { return r } - return netip.Addr{} + return routing.Gateways{} } // Get the LinkAddr for the interface of the given name @@ -382,7 +383,7 @@ func (t *tun) addRoutes(logErrors bool) error { routes := *t.Routes.Load() for _, r := range routes { - if !r.Via.IsValid() || !r.Install { + if len(r.Via) == 0 || !r.Install { // We don't allow route MTUs so only install routes with a via continue } diff --git a/overlay/tun_disabled.go b/overlay/tun_disabled.go index cfbf17d..131879d 100644 --- a/overlay/tun_disabled.go +++ b/overlay/tun_disabled.go @@ -9,6 +9,7 @@ import ( "github.com/rcrowley/go-metrics" "github.com/sirupsen/logrus" "github.com/slackhq/nebula/iputil" + "github.com/slackhq/nebula/routing" ) type disabledTun struct { @@ -43,8 +44,8 @@ func (*disabledTun) Activate() error { return nil } -func (*disabledTun) RouteFor(addr netip.Addr) netip.Addr { - return netip.Addr{} +func (*disabledTun) RoutesFor(addr netip.Addr) routing.Gateways { + return routing.Gateways{} } func (t *disabledTun) Networks() []netip.Prefix { diff --git a/overlay/tun_freebsd.go b/overlay/tun_freebsd.go index 69690e9..bcb82b3 100644 --- a/overlay/tun_freebsd.go +++ b/overlay/tun_freebsd.go @@ -20,6 +20,7 @@ import ( "github.com/gaissmai/bart" "github.com/sirupsen/logrus" "github.com/slackhq/nebula/config" + "github.com/slackhq/nebula/routing" "github.com/slackhq/nebula/util" ) @@ -50,7 +51,7 @@ type tun struct { vpnNetworks []netip.Prefix MTU int Routes atomic.Pointer[[]Route] - routeTree atomic.Pointer[bart.Table[netip.Addr]] + routeTree atomic.Pointer[bart.Table[routing.Gateways]] l *logrus.Logger io.ReadWriteCloser @@ -242,7 +243,7 @@ func (t *tun) reload(c *config.C, initial bool) error { return nil } -func (t *tun) RouteFor(ip netip.Addr) netip.Addr { +func (t *tun) RoutesFor(ip netip.Addr) routing.Gateways { r, _ := t.routeTree.Load().Lookup(ip) return r } @@ -262,7 +263,7 @@ func (t *tun) NewMultiQueueReader() (io.ReadWriteCloser, error) { func (t *tun) addRoutes(logErrors bool) error { routes := *t.Routes.Load() for _, r := range routes { - if !r.Via.IsValid() || !r.Install { + if len(r.Via) == 0 || !r.Install { // We don't allow route MTUs so only install routes with a via continue } diff --git a/overlay/tun_ios.go b/overlay/tun_ios.go index e99d447..e51e112 100644 --- a/overlay/tun_ios.go +++ b/overlay/tun_ios.go @@ -16,6 +16,7 @@ import ( "github.com/gaissmai/bart" "github.com/sirupsen/logrus" "github.com/slackhq/nebula/config" + "github.com/slackhq/nebula/routing" "github.com/slackhq/nebula/util" ) @@ -23,7 +24,7 @@ type tun struct { io.ReadWriteCloser vpnNetworks []netip.Prefix Routes atomic.Pointer[[]Route] - routeTree atomic.Pointer[bart.Table[netip.Addr]] + routeTree atomic.Pointer[bart.Table[routing.Gateways]] l *logrus.Logger } @@ -79,7 +80,7 @@ func (t *tun) reload(c *config.C, initial bool) error { return nil } -func (t *tun) RouteFor(ip netip.Addr) netip.Addr { +func (t *tun) RoutesFor(ip netip.Addr) routing.Gateways { r, _ := t.routeTree.Load().Lookup(ip) return r } diff --git a/overlay/tun_linux.go b/overlay/tun_linux.go index 993bd4a..809536f 100644 --- a/overlay/tun_linux.go +++ b/overlay/tun_linux.go @@ -17,6 +17,7 @@ import ( "github.com/gaissmai/bart" "github.com/sirupsen/logrus" "github.com/slackhq/nebula/config" + "github.com/slackhq/nebula/routing" "github.com/slackhq/nebula/util" "github.com/vishvananda/netlink" "golang.org/x/sys/unix" @@ -34,7 +35,7 @@ type tun struct { ioctlFd uintptr Routes atomic.Pointer[[]Route] - routeTree atomic.Pointer[bart.Table[netip.Addr]] + routeTree atomic.Pointer[bart.Table[routing.Gateways]] routeChan chan struct{} useSystemRoutes bool @@ -231,7 +232,7 @@ func (t *tun) NewMultiQueueReader() (io.ReadWriteCloser, error) { return file, nil } -func (t *tun) RouteFor(ip netip.Addr) netip.Addr { +func (t *tun) RoutesFor(ip netip.Addr) routing.Gateways { r, _ := t.routeTree.Load().Lookup(ip) return r } @@ -550,20 +551,7 @@ func (t *tun) watchRoutes() { }() } -func (t *tun) updateRoutes(r netlink.RouteUpdate) { - if r.Gw == nil { - // Not a gateway route, ignore - t.l.WithField("route", r).Debug("Ignoring route update, not a gateway route") - return - } - - gwAddr, ok := netip.AddrFromSlice(r.Gw) - if !ok { - t.l.WithField("route", r).Debug("Ignoring route update, invalid gateway address") - return - } - - gwAddr = gwAddr.Unmap() +func (t *tun) isGatewayInVpnNetworks(gwAddr netip.Addr) bool { withinNetworks := false for i := range t.vpnNetworks { if t.vpnNetworks[i].Contains(gwAddr) { @@ -571,9 +559,68 @@ func (t *tun) updateRoutes(r netlink.RouteUpdate) { break } } - if !withinNetworks { - // Gateway isn't in our overlay network, ignore - t.l.WithField("route", r).Debug("Ignoring route update, not in our networks") + + return withinNetworks +} + +func (t *tun) getGatewaysFromRoute(r *netlink.Route) routing.Gateways { + + var gateways routing.Gateways + + link, err := netlink.LinkByName(t.Device) + if err != nil { + t.l.WithField("Devicename", t.Device).Error("Ignoring route update: failed to get link by name") + return gateways + } + + // If this route is relevant to our interface and there is a gateway then add it + if r.LinkIndex == link.Attrs().Index && len(r.Gw) > 0 { + gwAddr, ok := netip.AddrFromSlice(r.Gw) + if !ok { + t.l.WithField("route", r).Debug("Ignoring route update, invalid gateway address") + } else { + gwAddr = gwAddr.Unmap() + + if !t.isGatewayInVpnNetworks(gwAddr) { + // Gateway isn't in our overlay network, ignore + t.l.WithField("route", r).Debug("Ignoring route update, not in our network") + } else { + gateways = append(gateways, routing.NewGateway(gwAddr, 1)) + } + } + } + + for _, p := range r.MultiPath { + // If this route is relevant to our interface and there is a gateway then add it + if p.LinkIndex == link.Attrs().Index && len(p.Gw) > 0 { + gwAddr, ok := netip.AddrFromSlice(p.Gw) + if !ok { + t.l.WithField("route", r).Debug("Ignoring multipath route update, invalid gateway address") + } else { + gwAddr = gwAddr.Unmap() + + if !t.isGatewayInVpnNetworks(gwAddr) { + // Gateway isn't in our overlay network, ignore + t.l.WithField("route", r).Debug("Ignoring route update, not in our network") + } else { + // p.Hops+1 = weight of the route + gateways = append(gateways, routing.NewGateway(gwAddr, p.Hops+1)) + } + } + } + } + + routing.CalculateBucketsForGateways(gateways) + return gateways +} + +func (t *tun) updateRoutes(r netlink.RouteUpdate) { + + gateways := t.getGatewaysFromRoute(&r.Route) + + if len(gateways) == 0 { + // No gateways relevant to our network, no routing changes required. + t.l.WithField("route", r).Debug("Ignoring route update, no gateways") return } @@ -589,12 +636,12 @@ func (t *tun) updateRoutes(r netlink.RouteUpdate) { newTree := t.routeTree.Load().Clone() if r.Type == unix.RTM_NEWROUTE { - t.l.WithField("destination", r.Dst).WithField("via", r.Gw).Info("Adding route") - newTree.Insert(dst, gwAddr) + t.l.WithField("destination", dst).WithField("via", gateways).Info("Adding route") + newTree.Insert(dst, gateways) } else { + t.l.WithField("destination", dst).WithField("via", gateways).Info("Removing route") newTree.Delete(dst) - t.l.WithField("destination", r.Dst).WithField("via", r.Gw).Info("Removing route") } t.routeTree.Store(newTree) } diff --git a/overlay/tun_netbsd.go b/overlay/tun_netbsd.go index f7586cb..847f1b5 100644 --- a/overlay/tun_netbsd.go +++ b/overlay/tun_netbsd.go @@ -18,6 +18,7 @@ import ( "github.com/gaissmai/bart" "github.com/sirupsen/logrus" "github.com/slackhq/nebula/config" + "github.com/slackhq/nebula/routing" "github.com/slackhq/nebula/util" ) @@ -31,7 +32,7 @@ type tun struct { vpnNetworks []netip.Prefix MTU int Routes atomic.Pointer[[]Route] - routeTree atomic.Pointer[bart.Table[netip.Addr]] + routeTree atomic.Pointer[bart.Table[routing.Gateways]] l *logrus.Logger io.ReadWriteCloser @@ -177,7 +178,7 @@ func (t *tun) reload(c *config.C, initial bool) error { return nil } -func (t *tun) RouteFor(ip netip.Addr) netip.Addr { +func (t *tun) RoutesFor(ip netip.Addr) routing.Gateways { r, _ := t.routeTree.Load().Lookup(ip) return r } @@ -197,7 +198,7 @@ func (t *tun) NewMultiQueueReader() (io.ReadWriteCloser, error) { func (t *tun) addRoutes(logErrors bool) error { routes := *t.Routes.Load() for _, r := range routes { - if !r.Via.IsValid() || !r.Install { + if len(r.Via) == 0 || !r.Install { // We don't allow route MTUs so only install routes with a via continue } diff --git a/overlay/tun_openbsd.go b/overlay/tun_openbsd.go index a2fd184..03fb3a0 100644 --- a/overlay/tun_openbsd.go +++ b/overlay/tun_openbsd.go @@ -17,6 +17,7 @@ import ( "github.com/gaissmai/bart" "github.com/sirupsen/logrus" "github.com/slackhq/nebula/config" + "github.com/slackhq/nebula/routing" "github.com/slackhq/nebula/util" ) @@ -25,7 +26,7 @@ type tun struct { vpnNetworks []netip.Prefix MTU int Routes atomic.Pointer[[]Route] - routeTree atomic.Pointer[bart.Table[netip.Addr]] + routeTree atomic.Pointer[bart.Table[routing.Gateways]] l *logrus.Logger io.ReadWriteCloser @@ -158,7 +159,7 @@ func (t *tun) Activate() error { return nil } -func (t *tun) RouteFor(ip netip.Addr) netip.Addr { +func (t *tun) RoutesFor(ip netip.Addr) routing.Gateways { r, _ := t.routeTree.Load().Lookup(ip) return r } @@ -166,7 +167,7 @@ func (t *tun) RouteFor(ip netip.Addr) netip.Addr { func (t *tun) addRoutes(logErrors bool) error { routes := *t.Routes.Load() for _, r := range routes { - if !r.Via.IsValid() || !r.Install { + if len(r.Via) == 0 || !r.Install { // We don't allow route MTUs so only install routes with a via continue } diff --git a/overlay/tun_tester.go b/overlay/tun_tester.go index cc3942f..b6712fb 100644 --- a/overlay/tun_tester.go +++ b/overlay/tun_tester.go @@ -13,13 +13,14 @@ import ( "github.com/gaissmai/bart" "github.com/sirupsen/logrus" "github.com/slackhq/nebula/config" + "github.com/slackhq/nebula/routing" ) type TestTun struct { Device string vpnNetworks []netip.Prefix Routes []Route - routeTree *bart.Table[netip.Addr] + routeTree *bart.Table[routing.Gateways] l *logrus.Logger closed atomic.Bool @@ -86,7 +87,7 @@ func (t *TestTun) Get(block bool) []byte { // Below this is boilerplate implementation to make nebula actually work //********************************************************************************************************************// -func (t *TestTun) RouteFor(ip netip.Addr) netip.Addr { +func (t *TestTun) RoutesFor(ip netip.Addr) routing.Gateways { r, _ := t.routeTree.Lookup(ip) return r } diff --git a/overlay/tun_windows.go b/overlay/tun_windows.go index 289999d..1d66eac 100644 --- a/overlay/tun_windows.go +++ b/overlay/tun_windows.go @@ -18,6 +18,7 @@ import ( "github.com/gaissmai/bart" "github.com/sirupsen/logrus" "github.com/slackhq/nebula/config" + "github.com/slackhq/nebula/routing" "github.com/slackhq/nebula/util" "github.com/slackhq/nebula/wintun" "golang.org/x/sys/windows" @@ -31,7 +32,7 @@ type winTun struct { vpnNetworks []netip.Prefix MTU int Routes atomic.Pointer[[]Route] - routeTree atomic.Pointer[bart.Table[netip.Addr]] + routeTree atomic.Pointer[bart.Table[routing.Gateways]] l *logrus.Logger tun *wintun.NativeTun @@ -147,13 +148,16 @@ func (t *winTun) addRoutes(logErrors bool) error { foundDefault4 := false for _, r := range routes { - if !r.Via.IsValid() || !r.Install { + if len(r.Via) == 0 || !r.Install { // We don't allow route MTUs so only install routes with a via continue } // Add our unsafe route - err := luid.AddRoute(r.Cidr, r.Via, uint32(r.Metric)) + // Windows does not support multipath routes natively, so we install only a single route. + // This is not a problem as traffic will always be sent to Nebula which handles the multipath routing internally. + // In effect this provides multipath routing support to windows supporting loadbalancing and redundancy. + err := luid.AddRoute(r.Cidr, r.Via[0].Addr(), uint32(r.Metric)) if err != nil { retErr := util.NewContextualError("Failed to add route", map[string]interface{}{"route": r}, err) if logErrors { @@ -198,7 +202,8 @@ func (t *winTun) removeRoutes(routes []Route) error { continue } - err := luid.DeleteRoute(r.Cidr, r.Via) + // See comment on luid.AddRoute + err := luid.DeleteRoute(r.Cidr, r.Via[0].Addr()) if err != nil { t.l.WithError(err).WithField("route", r).Error("Failed to remove route") } else { @@ -208,7 +213,7 @@ func (t *winTun) removeRoutes(routes []Route) error { return nil } -func (t *winTun) RouteFor(ip netip.Addr) netip.Addr { +func (t *winTun) RoutesFor(ip netip.Addr) routing.Gateways { r, _ := t.routeTree.Load().Lookup(ip) return r } diff --git a/overlay/user.go b/overlay/user.go index ae665f3..8a56d66 100644 --- a/overlay/user.go +++ b/overlay/user.go @@ -6,6 +6,7 @@ import ( "github.com/sirupsen/logrus" "github.com/slackhq/nebula/config" + "github.com/slackhq/nebula/routing" ) func NewUserDeviceFromConfig(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, routines int) (Device, error) { @@ -38,9 +39,13 @@ type UserDevice struct { func (d *UserDevice) Activate() error { return nil } -func (d *UserDevice) Networks() []netip.Prefix { return d.vpnNetworks } -func (d *UserDevice) Name() string { return "faketun0" } -func (d *UserDevice) RouteFor(ip netip.Addr) netip.Addr { return ip } + +func (d *UserDevice) Networks() []netip.Prefix { return d.vpnNetworks } +func (d *UserDevice) Name() string { return "faketun0" } +func (d *UserDevice) RoutesFor(ip netip.Addr) routing.Gateways { + return routing.Gateways{routing.NewGateway(ip, 1)} +} + func (d *UserDevice) NewMultiQueueReader() (io.ReadWriteCloser, error) { return d, nil } diff --git a/routing/balance.go b/routing/balance.go new file mode 100644 index 0000000..6f52497 --- /dev/null +++ b/routing/balance.go @@ -0,0 +1,39 @@ +package routing + +import ( + "net/netip" + + "github.com/slackhq/nebula/firewall" +) + +// Hashes the packet source and destination port and always returns a positive integer +// Based on 'Prospecting for Hash Functions' +// - https://nullprogram.com/blog/2018/07/31/ +// - https://github.com/skeeto/hash-prospector +// [16 21f0aaad 15 d35a2d97 15] = 0.10760229515479501 +func hashPacket(p *firewall.Packet) int { + x := (uint32(p.LocalPort) << 16) | uint32(p.RemotePort) + x ^= x >> 16 + x *= 0x21f0aaad + x ^= x >> 15 + x *= 0xd35a2d97 + x ^= x >> 15 + + return int(x) & 0x7FFFFFFF +} + +// For this function to work correctly it requires that the buckets for the gateways have been calculated +// If the contract is violated balancing will not work properly and the second return value will return false +func BalancePacket(fwPacket *firewall.Packet, gateways []Gateway) (netip.Addr, bool) { + hash := hashPacket(fwPacket) + + for i := range gateways { + if hash <= gateways[i].BucketUpperBound() { + return gateways[i].Addr(), true + } + } + + // If you land here then the buckets for the gateways are not properly calculated + // Fallback to random routing and let the caller know + return gateways[hash%len(gateways)].Addr(), false +} diff --git a/routing/balance_test.go b/routing/balance_test.go new file mode 100644 index 0000000..bbfcb22 --- /dev/null +++ b/routing/balance_test.go @@ -0,0 +1,144 @@ +package routing + +import ( + "net/netip" + "testing" + + "github.com/slackhq/nebula/firewall" + "github.com/stretchr/testify/assert" +) + +func TestPacketsAreBalancedEqually(t *testing.T) { + + gateways := []Gateway{} + + gw1Addr := netip.MustParseAddr("1.0.0.1") + gw2Addr := netip.MustParseAddr("1.0.0.2") + gw3Addr := netip.MustParseAddr("1.0.0.3") + + gateways = append(gateways, NewGateway(gw1Addr, 1)) + gateways = append(gateways, NewGateway(gw2Addr, 1)) + gateways = append(gateways, NewGateway(gw3Addr, 1)) + + CalculateBucketsForGateways(gateways) + + gw1count := 0 + gw2count := 0 + gw3count := 0 + + iterationCount := uint16(65535) + for i := uint16(0); i < iterationCount; i++ { + packet := firewall.Packet{ + LocalAddr: netip.MustParseAddr("192.168.1.1"), + RemoteAddr: netip.MustParseAddr("10.0.0.1"), + LocalPort: i, + RemotePort: 65535 - i, + Protocol: 6, // TCP + Fragment: false, + } + + selectedGw, ok := BalancePacket(&packet, gateways) + assert.True(t, ok) + + switch selectedGw { + case gw1Addr: + gw1count += 1 + case gw2Addr: + gw2count += 1 + case gw3Addr: + gw3count += 1 + } + + } + + // Assert packets are balanced, allow variation of up to 100 packets per gateway + assert.InDeltaf(t, iterationCount/3, gw1count, 100, "Expected %d +/- 100, but got %d", iterationCount/3, gw1count) + assert.InDeltaf(t, iterationCount/3, gw2count, 100, "Expected %d +/- 100, but got %d", iterationCount/3, gw1count) + assert.InDeltaf(t, iterationCount/3, gw3count, 100, "Expected %d +/- 100, but got %d", iterationCount/3, gw1count) + +} + +func TestPacketsAreBalancedByPriority(t *testing.T) { + + gateways := []Gateway{} + + gw1Addr := netip.MustParseAddr("1.0.0.1") + gw2Addr := netip.MustParseAddr("1.0.0.2") + + gateways = append(gateways, NewGateway(gw1Addr, 10)) + gateways = append(gateways, NewGateway(gw2Addr, 5)) + + CalculateBucketsForGateways(gateways) + + gw1count := 0 + gw2count := 0 + + iterationCount := uint16(65535) + for i := uint16(0); i < iterationCount; i++ { + packet := firewall.Packet{ + LocalAddr: netip.MustParseAddr("192.168.1.1"), + RemoteAddr: netip.MustParseAddr("10.0.0.1"), + LocalPort: i, + RemotePort: 65535 - i, + Protocol: 6, // TCP + Fragment: false, + } + + selectedGw, ok := BalancePacket(&packet, gateways) + assert.True(t, ok) + + switch selectedGw { + case gw1Addr: + gw1count += 1 + case gw2Addr: + gw2count += 1 + } + + } + + iterationCountAsFloat := float32(iterationCount) + + assert.InDeltaf(t, iterationCountAsFloat*(2.0/3.0), gw1count, 100, "Expected %d +/- 100, but got %d", iterationCountAsFloat*(2.0/3.0), gw1count) + assert.InDeltaf(t, iterationCountAsFloat*(1.0/3.0), gw2count, 100, "Expected %d +/- 100, but got %d", iterationCountAsFloat*(1.0/3.0), gw2count) +} + +func TestBalancePacketDistributsRandomlyAndReturnsFalseIfBucketsNotCalculated(t *testing.T) { + gateways := []Gateway{} + + gw1Addr := netip.MustParseAddr("1.0.0.1") + gw2Addr := netip.MustParseAddr("1.0.0.2") + + gateways = append(gateways, NewGateway(gw1Addr, 10)) + gateways = append(gateways, NewGateway(gw2Addr, 5)) + + iterationCount := uint16(65535) + gw1count := 0 + gw2count := 0 + + for i := uint16(0); i < iterationCount; i++ { + packet := firewall.Packet{ + LocalAddr: netip.MustParseAddr("192.168.1.1"), + RemoteAddr: netip.MustParseAddr("10.0.0.1"), + LocalPort: i, + RemotePort: 65535 - i, + Protocol: 6, // TCP + Fragment: false, + } + + selectedGw, ok := BalancePacket(&packet, gateways) + assert.False(t, ok) + + switch selectedGw { + case gw1Addr: + gw1count += 1 + case gw2Addr: + gw2count += 1 + } + + } + + assert.Equal(t, int(iterationCount), (gw1count + gw2count)) + assert.NotEqual(t, 0, gw1count) + assert.NotEqual(t, 0, gw2count) + +} diff --git a/routing/gateway.go b/routing/gateway.go new file mode 100644 index 0000000..59d38a9 --- /dev/null +++ b/routing/gateway.go @@ -0,0 +1,70 @@ +package routing + +import ( + "fmt" + "net/netip" +) + +const ( + // Sentinal value + BucketNotCalculated = -1 +) + +type Gateways []Gateway + +func (g Gateways) String() string { + str := "" + for i, gw := range g { + str += gw.String() + if i < len(g)-1 { + str += ", " + } + } + return str +} + +type Gateway struct { + addr netip.Addr + weight int + bucketUpperBound int +} + +func NewGateway(addr netip.Addr, weight int) Gateway { + return Gateway{addr: addr, weight: weight, bucketUpperBound: BucketNotCalculated} +} + +func (g *Gateway) BucketUpperBound() int { + return g.bucketUpperBound +} + +func (g *Gateway) Addr() netip.Addr { + return g.addr +} + +func (g *Gateway) String() string { + return fmt.Sprintf("{addr: %s, weight: %d}", g.addr, g.weight) +} + +// Divide and round to nearest integer +func divideAndRound(v uint64, d uint64) uint64 { + var tmp uint64 = v + d/2 + return tmp / d +} + +// Implements Hash-Threshold mapping, equivalent to the implementation in the linux kernel. +// After this function returns each gateway will have a +// positive bucketUpperBound with a maximum value of 2147483647 (INT_MAX) +func CalculateBucketsForGateways(gateways []Gateway) { + + var totalWeight int = 0 + for i := range gateways { + totalWeight += gateways[i].weight + } + + var loopWeight int = 0 + for i := range gateways { + loopWeight += gateways[i].weight + gateways[i].bucketUpperBound = int(divideAndRound(uint64(loopWeight)<<31, uint64(totalWeight))) - 1 + } + +} diff --git a/routing/gateway_test.go b/routing/gateway_test.go new file mode 100644 index 0000000..8ae78f3 --- /dev/null +++ b/routing/gateway_test.go @@ -0,0 +1,34 @@ +package routing + +import ( + "net/netip" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestRebalance3_2Split(t *testing.T) { + gateways := []Gateway{} + + gateways = append(gateways, Gateway{addr: netip.Addr{}, weight: 10}) + gateways = append(gateways, Gateway{addr: netip.Addr{}, weight: 5}) + + CalculateBucketsForGateways(gateways) + + assert.Equal(t, 1431655764, gateways[0].bucketUpperBound) // INT_MAX/3*2 + assert.Equal(t, 2147483647, gateways[1].bucketUpperBound) // INT_MAX +} + +func TestRebalanceEqualSplit(t *testing.T) { + gateways := []Gateway{} + + gateways = append(gateways, Gateway{addr: netip.Addr{}, weight: 1}) + gateways = append(gateways, Gateway{addr: netip.Addr{}, weight: 1}) + gateways = append(gateways, Gateway{addr: netip.Addr{}, weight: 1}) + + CalculateBucketsForGateways(gateways) + + assert.Equal(t, 715827882, gateways[0].bucketUpperBound) // INT_MAX/3 + assert.Equal(t, 1431655764, gateways[1].bucketUpperBound) // INT_MAX/3*2 + assert.Equal(t, 2147483647, gateways[2].bucketUpperBound) // INT_MAX +} diff --git a/test/tun.go b/test/tun.go index b29d61c..ca65805 100644 --- a/test/tun.go +++ b/test/tun.go @@ -4,12 +4,14 @@ import ( "errors" "io" "net/netip" + + "github.com/slackhq/nebula/routing" ) type NoopTun struct{} -func (NoopTun) RouteFor(addr netip.Addr) netip.Addr { - return netip.Addr{} +func (NoopTun) RoutesFor(addr netip.Addr) routing.Gateways { + return routing.Gateways{} } func (NoopTun) Activate() error {