From fb7f0c36572ce4eb5bb915a8d70344fc24e07993 Mon Sep 17 00:00:00 2001 From: Nate Brown Date: Fri, 3 Oct 2025 11:59:53 -0400 Subject: [PATCH] Use x/net/route to manage routes directly (#1488) --- overlay/tun.go | 11 +++ overlay/tun_darwin.go | 11 --- overlay/tun_freebsd.go | 174 ++++++++++++++++++++++++++++++++++------- 3 files changed, 158 insertions(+), 38 deletions(-) diff --git a/overlay/tun.go b/overlay/tun.go index 4a6377d..ddf44a3 100644 --- a/overlay/tun.go +++ b/overlay/tun.go @@ -1,6 +1,7 @@ package overlay import ( + "net" "net/netip" "github.com/sirupsen/logrus" @@ -70,3 +71,13 @@ func findRemovedRoutes(newRoutes, oldRoutes []Route) []Route { return removed } + +func prefixToMask(prefix netip.Prefix) netip.Addr { + pLen := 128 + if prefix.Addr().Is4() { + pLen = 32 + } + + addr, _ := netip.AddrFromSlice(net.CIDRMask(prefix.Bits(), pLen)) + return addr +} diff --git a/overlay/tun_darwin.go b/overlay/tun_darwin.go index 7f6ba4f..5ecbeb8 100644 --- a/overlay/tun_darwin.go +++ b/overlay/tun_darwin.go @@ -7,7 +7,6 @@ import ( "errors" "fmt" "io" - "net" "net/netip" "os" "sync/atomic" @@ -554,13 +553,3 @@ func (t *tun) Name() string { func (t *tun) NewMultiQueueReader() (io.ReadWriteCloser, error) { return nil, fmt.Errorf("TODO: multiqueue not implemented for darwin") } - -func prefixToMask(prefix netip.Prefix) netip.Addr { - pLen := 128 - if prefix.Addr().Is4() { - pLen = 32 - } - - addr, _ := netip.AddrFromSlice(net.CIDRMask(prefix.Bits(), pLen)) - return addr -} diff --git a/overlay/tun_freebsd.go b/overlay/tun_freebsd.go index 8e0e4f5..f597881 100644 --- a/overlay/tun_freebsd.go +++ b/overlay/tun_freebsd.go @@ -9,9 +9,7 @@ import ( "fmt" "io" "io/fs" - "net" "net/netip" - "os/exec" "sync/atomic" "syscall" "time" @@ -22,6 +20,7 @@ import ( "github.com/slackhq/nebula/config" "github.com/slackhq/nebula/routing" "github.com/slackhq/nebula/util" + netroute "golang.org/x/net/route" "golang.org/x/sys/unix" ) @@ -92,6 +91,7 @@ type tun struct { MTU int Routes atomic.Pointer[[]Route] routeTree atomic.Pointer[bart.Table[routing.Gateways]] + linkAddr *netroute.LinkAddr l *logrus.Logger devFd int } @@ -162,6 +162,7 @@ func (t *tun) Write(from []byte) (int, error) { } else { err = nil } + return int(n) - 4, err } @@ -308,7 +309,7 @@ func (t *tun) addIp(cidr netip.Prefix) error { MaskAddr: unix.RawSockaddrInet4{ Len: unix.SizeofSockaddrInet4, Family: unix.AF_INET, - Addr: getNetmask(cidr).As4(), + Addr: prefixToMask(cidr).As4(), }, VHid: 0, } @@ -321,7 +322,10 @@ func (t *tun) addIp(cidr netip.Prefix) error { if err := ioctl(uintptr(s), unix.SIOCAIFADDR, uintptr(unsafe.Pointer(&ifr))); err != nil { return fmt.Errorf("failed to set tun address %s: %s", cidr.Addr().String(), err) } - } else if cidr.Addr().Is6() { + return nil + } + + if cidr.Addr().Is6() { ifr := ifreqAlias6{ Name: t.deviceBytes(), Addr: unix.RawSockaddrInet6{ @@ -332,7 +336,7 @@ func (t *tun) addIp(cidr netip.Prefix) error { PrefixMask: unix.RawSockaddrInet6{ Len: unix.SizeofSockaddrInet6, Family: unix.AF_INET6, - Addr: getNetmask(cidr).As16(), + Addr: prefixToMask(cidr).As16(), }, Lifetime: addrLifetime{ Expire: 0, @@ -351,11 +355,10 @@ func (t *tun) addIp(cidr netip.Prefix) error { if err := ioctl(uintptr(s), OSIOCAIFADDR_IN6, uintptr(unsafe.Pointer(&ifr))); err != nil { return fmt.Errorf("failed to set tun address %s: %s", cidr.Addr().String(), err) } - } else { - return fmt.Errorf("Unknown address type") + return nil } - return t.addRoutes(false) + return fmt.Errorf("unknown address type %v", cidr) } func (t *tun) Activate() error { @@ -365,13 +368,23 @@ func (t *tun) Activate() error { return err } + linkAddr, err := getLinkAddr(t.Device) + if err != nil { + return err + } + if linkAddr == nil { + return fmt.Errorf("unable to discover link_addr for tun interface") + } + t.linkAddr = linkAddr + for i := range t.vpnNetworks { err := t.addIp(t.vpnNetworks[i]) if err != nil { return err } } - return nil + + return t.addRoutes(false) } func (t *tun) setMTU() error { @@ -449,15 +462,16 @@ func (t *tun) addRoutes(logErrors bool) error { continue } - cmd := exec.Command("/sbin/route", "-n", "add", "-net", r.Cidr.String(), "-interface", t.Device) - t.l.Debug("command: ", cmd.String()) - if err := cmd.Run(); err != nil { - retErr := util.NewContextualError("failed to run 'route add' for unsafe_route", map[string]any{"route": r}, err) + err := addRoute(r.Cidr, t.linkAddr) + if err != nil { + retErr := util.NewContextualError("Failed to add route", map[string]any{"route": r}, err) if logErrors { retErr.Log(t.l) } else { return retErr } + } else { + t.l.WithField("route", r).Info("Added route") } } @@ -470,9 +484,8 @@ func (t *tun) removeRoutes(routes []Route) error { continue } - cmd := exec.Command("/sbin/route", "-n", "delete", "-net", r.Cidr.String(), "-interface", t.Device) - t.l.Debug("command: ", cmd.String()) - if err := cmd.Run(); err != nil { + err := delRoute(r.Cidr, t.linkAddr) + if err != nil { t.l.WithError(err).WithField("route", r).Error("Failed to remove route") } else { t.l.WithField("route", r).Info("Removed route") @@ -502,22 +515,129 @@ func orBytes(a []byte, b []byte) []byte { return ret } -func getNetmask(cidr netip.Prefix) netip.Addr { - pLen := 128 - if cidr.Addr().Is4() { - pLen = 32 - } - - addr, _ := netip.AddrFromSlice(net.CIDRMask(cidr.Bits(), pLen)) - return addr -} - func getBroadcast(cidr netip.Prefix) netip.Addr { broadcast, _ := netip.AddrFromSlice( orBytes( cidr.Addr().AsSlice(), - flipBytes(getNetmask(cidr).AsSlice()), + flipBytes(prefixToMask(cidr).AsSlice()), ), ) return broadcast } + +func addRoute(prefix netip.Prefix, gateway netroute.Addr) error { + sock, err := unix.Socket(unix.AF_ROUTE, unix.SOCK_RAW, unix.AF_UNSPEC) + if err != nil { + return fmt.Errorf("unable to create AF_ROUTE socket: %v", err) + } + defer unix.Close(sock) + + route := &netroute.RouteMessage{ + Version: unix.RTM_VERSION, + Type: unix.RTM_ADD, + Flags: unix.RTF_UP, + Seq: 1, + } + + if prefix.Addr().Is4() { + route.Addrs = []netroute.Addr{ + unix.RTAX_DST: &netroute.Inet4Addr{IP: prefix.Masked().Addr().As4()}, + unix.RTAX_NETMASK: &netroute.Inet4Addr{IP: prefixToMask(prefix).As4()}, + unix.RTAX_GATEWAY: gateway, + } + } else { + route.Addrs = []netroute.Addr{ + unix.RTAX_DST: &netroute.Inet6Addr{IP: prefix.Masked().Addr().As16()}, + unix.RTAX_NETMASK: &netroute.Inet6Addr{IP: prefixToMask(prefix).As16()}, + unix.RTAX_GATEWAY: gateway, + } + } + + data, err := route.Marshal() + if err != nil { + return fmt.Errorf("failed to create route.RouteMessage: %w", err) + } + + _, err = unix.Write(sock, data[:]) + if err != nil { + if errors.Is(err, unix.EEXIST) { + // Try to do a change + route.Type = unix.RTM_CHANGE + data, err = route.Marshal() + if err != nil { + return fmt.Errorf("failed to create route.RouteMessage for change: %w", err) + } + _, err = unix.Write(sock, data[:]) + fmt.Println("DOING CHANGE") + return err + } + return fmt.Errorf("failed to write route.RouteMessage to socket: %w", err) + } + + return nil +} + +func delRoute(prefix netip.Prefix, gateway netroute.Addr) error { + sock, err := unix.Socket(unix.AF_ROUTE, unix.SOCK_RAW, unix.AF_UNSPEC) + if err != nil { + return fmt.Errorf("unable to create AF_ROUTE socket: %v", err) + } + defer unix.Close(sock) + + route := netroute.RouteMessage{ + Version: unix.RTM_VERSION, + Type: unix.RTM_DELETE, + Seq: 1, + } + + if prefix.Addr().Is4() { + route.Addrs = []netroute.Addr{ + unix.RTAX_DST: &netroute.Inet4Addr{IP: prefix.Masked().Addr().As4()}, + unix.RTAX_NETMASK: &netroute.Inet4Addr{IP: prefixToMask(prefix).As4()}, + unix.RTAX_GATEWAY: gateway, + } + } else { + route.Addrs = []netroute.Addr{ + unix.RTAX_DST: &netroute.Inet6Addr{IP: prefix.Masked().Addr().As16()}, + unix.RTAX_NETMASK: &netroute.Inet6Addr{IP: prefixToMask(prefix).As16()}, + unix.RTAX_GATEWAY: gateway, + } + } + + data, err := route.Marshal() + if err != nil { + return fmt.Errorf("failed to create route.RouteMessage: %w", err) + } + _, err = unix.Write(sock, data[:]) + if err != nil { + return fmt.Errorf("failed to write route.RouteMessage to socket: %w", err) + } + + return nil +} + +// getLinkAddr Gets the link address for the interface of the given name +func getLinkAddr(name string) (*netroute.LinkAddr, error) { + rib, err := netroute.FetchRIB(unix.AF_UNSPEC, unix.NET_RT_IFLIST, 0) + if err != nil { + return nil, err + } + msgs, err := netroute.ParseRIB(unix.NET_RT_IFLIST, rib) + if err != nil { + return nil, err + } + + for _, m := range msgs { + switch m := m.(type) { + case *netroute.InterfaceMessage: + if m.Name == name { + sa, ok := m.Addrs[unix.RTAX_IFP].(*netroute.LinkAddr) + if ok { + return sa, nil + } + } + } + } + + return nil, nil +}