Use x/net/route to manage routes directly (#1488)

This commit is contained in:
Nate Brown
2025-10-03 11:59:53 -04:00
committed by GitHub
parent b1f53d8d25
commit fb7f0c3657
3 changed files with 158 additions and 38 deletions

View File

@@ -1,6 +1,7 @@
package overlay package overlay
import ( import (
"net"
"net/netip" "net/netip"
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
@@ -70,3 +71,13 @@ func findRemovedRoutes(newRoutes, oldRoutes []Route) []Route {
return removed return removed
} }
func prefixToMask(prefix netip.Prefix) netip.Addr {
pLen := 128
if prefix.Addr().Is4() {
pLen = 32
}
addr, _ := netip.AddrFromSlice(net.CIDRMask(prefix.Bits(), pLen))
return addr
}

View File

@@ -7,7 +7,6 @@ import (
"errors" "errors"
"fmt" "fmt"
"io" "io"
"net"
"net/netip" "net/netip"
"os" "os"
"sync/atomic" "sync/atomic"
@@ -554,13 +553,3 @@ func (t *tun) Name() string {
func (t *tun) NewMultiQueueReader() (io.ReadWriteCloser, error) { func (t *tun) NewMultiQueueReader() (io.ReadWriteCloser, error) {
return nil, fmt.Errorf("TODO: multiqueue not implemented for darwin") return nil, fmt.Errorf("TODO: multiqueue not implemented for darwin")
} }
func prefixToMask(prefix netip.Prefix) netip.Addr {
pLen := 128
if prefix.Addr().Is4() {
pLen = 32
}
addr, _ := netip.AddrFromSlice(net.CIDRMask(prefix.Bits(), pLen))
return addr
}

View File

@@ -9,9 +9,7 @@ import (
"fmt" "fmt"
"io" "io"
"io/fs" "io/fs"
"net"
"net/netip" "net/netip"
"os/exec"
"sync/atomic" "sync/atomic"
"syscall" "syscall"
"time" "time"
@@ -22,6 +20,7 @@ import (
"github.com/slackhq/nebula/config" "github.com/slackhq/nebula/config"
"github.com/slackhq/nebula/routing" "github.com/slackhq/nebula/routing"
"github.com/slackhq/nebula/util" "github.com/slackhq/nebula/util"
netroute "golang.org/x/net/route"
"golang.org/x/sys/unix" "golang.org/x/sys/unix"
) )
@@ -92,6 +91,7 @@ type tun struct {
MTU int MTU int
Routes atomic.Pointer[[]Route] Routes atomic.Pointer[[]Route]
routeTree atomic.Pointer[bart.Table[routing.Gateways]] routeTree atomic.Pointer[bart.Table[routing.Gateways]]
linkAddr *netroute.LinkAddr
l *logrus.Logger l *logrus.Logger
devFd int devFd int
} }
@@ -162,6 +162,7 @@ func (t *tun) Write(from []byte) (int, error) {
} else { } else {
err = nil err = nil
} }
return int(n) - 4, err return int(n) - 4, err
} }
@@ -308,7 +309,7 @@ func (t *tun) addIp(cidr netip.Prefix) error {
MaskAddr: unix.RawSockaddrInet4{ MaskAddr: unix.RawSockaddrInet4{
Len: unix.SizeofSockaddrInet4, Len: unix.SizeofSockaddrInet4,
Family: unix.AF_INET, Family: unix.AF_INET,
Addr: getNetmask(cidr).As4(), Addr: prefixToMask(cidr).As4(),
}, },
VHid: 0, VHid: 0,
} }
@@ -321,7 +322,10 @@ func (t *tun) addIp(cidr netip.Prefix) error {
if err := ioctl(uintptr(s), unix.SIOCAIFADDR, uintptr(unsafe.Pointer(&ifr))); err != nil { if err := ioctl(uintptr(s), unix.SIOCAIFADDR, uintptr(unsafe.Pointer(&ifr))); err != nil {
return fmt.Errorf("failed to set tun address %s: %s", cidr.Addr().String(), err) return fmt.Errorf("failed to set tun address %s: %s", cidr.Addr().String(), err)
} }
} else if cidr.Addr().Is6() { return nil
}
if cidr.Addr().Is6() {
ifr := ifreqAlias6{ ifr := ifreqAlias6{
Name: t.deviceBytes(), Name: t.deviceBytes(),
Addr: unix.RawSockaddrInet6{ Addr: unix.RawSockaddrInet6{
@@ -332,7 +336,7 @@ func (t *tun) addIp(cidr netip.Prefix) error {
PrefixMask: unix.RawSockaddrInet6{ PrefixMask: unix.RawSockaddrInet6{
Len: unix.SizeofSockaddrInet6, Len: unix.SizeofSockaddrInet6,
Family: unix.AF_INET6, Family: unix.AF_INET6,
Addr: getNetmask(cidr).As16(), Addr: prefixToMask(cidr).As16(),
}, },
Lifetime: addrLifetime{ Lifetime: addrLifetime{
Expire: 0, Expire: 0,
@@ -351,11 +355,10 @@ func (t *tun) addIp(cidr netip.Prefix) error {
if err := ioctl(uintptr(s), OSIOCAIFADDR_IN6, uintptr(unsafe.Pointer(&ifr))); err != nil { if err := ioctl(uintptr(s), OSIOCAIFADDR_IN6, uintptr(unsafe.Pointer(&ifr))); err != nil {
return fmt.Errorf("failed to set tun address %s: %s", cidr.Addr().String(), err) return fmt.Errorf("failed to set tun address %s: %s", cidr.Addr().String(), err)
} }
} else { return nil
return fmt.Errorf("Unknown address type")
} }
return t.addRoutes(false) return fmt.Errorf("unknown address type %v", cidr)
} }
func (t *tun) Activate() error { func (t *tun) Activate() error {
@@ -365,13 +368,23 @@ func (t *tun) Activate() error {
return err return err
} }
linkAddr, err := getLinkAddr(t.Device)
if err != nil {
return err
}
if linkAddr == nil {
return fmt.Errorf("unable to discover link_addr for tun interface")
}
t.linkAddr = linkAddr
for i := range t.vpnNetworks { for i := range t.vpnNetworks {
err := t.addIp(t.vpnNetworks[i]) err := t.addIp(t.vpnNetworks[i])
if err != nil { if err != nil {
return err return err
} }
} }
return nil
return t.addRoutes(false)
} }
func (t *tun) setMTU() error { func (t *tun) setMTU() error {
@@ -449,15 +462,16 @@ func (t *tun) addRoutes(logErrors bool) error {
continue continue
} }
cmd := exec.Command("/sbin/route", "-n", "add", "-net", r.Cidr.String(), "-interface", t.Device) err := addRoute(r.Cidr, t.linkAddr)
t.l.Debug("command: ", cmd.String()) if err != nil {
if err := cmd.Run(); err != nil { retErr := util.NewContextualError("Failed to add route", map[string]any{"route": r}, err)
retErr := util.NewContextualError("failed to run 'route add' for unsafe_route", map[string]any{"route": r}, err)
if logErrors { if logErrors {
retErr.Log(t.l) retErr.Log(t.l)
} else { } else {
return retErr return retErr
} }
} else {
t.l.WithField("route", r).Info("Added route")
} }
} }
@@ -470,9 +484,8 @@ func (t *tun) removeRoutes(routes []Route) error {
continue continue
} }
cmd := exec.Command("/sbin/route", "-n", "delete", "-net", r.Cidr.String(), "-interface", t.Device) err := delRoute(r.Cidr, t.linkAddr)
t.l.Debug("command: ", cmd.String()) if err != nil {
if err := cmd.Run(); err != nil {
t.l.WithError(err).WithField("route", r).Error("Failed to remove route") t.l.WithError(err).WithField("route", r).Error("Failed to remove route")
} else { } else {
t.l.WithField("route", r).Info("Removed route") t.l.WithField("route", r).Info("Removed route")
@@ -502,22 +515,129 @@ func orBytes(a []byte, b []byte) []byte {
return ret return ret
} }
func getNetmask(cidr netip.Prefix) netip.Addr {
pLen := 128
if cidr.Addr().Is4() {
pLen = 32
}
addr, _ := netip.AddrFromSlice(net.CIDRMask(cidr.Bits(), pLen))
return addr
}
func getBroadcast(cidr netip.Prefix) netip.Addr { func getBroadcast(cidr netip.Prefix) netip.Addr {
broadcast, _ := netip.AddrFromSlice( broadcast, _ := netip.AddrFromSlice(
orBytes( orBytes(
cidr.Addr().AsSlice(), cidr.Addr().AsSlice(),
flipBytes(getNetmask(cidr).AsSlice()), flipBytes(prefixToMask(cidr).AsSlice()),
), ),
) )
return broadcast return broadcast
} }
func addRoute(prefix netip.Prefix, gateway netroute.Addr) error {
sock, err := unix.Socket(unix.AF_ROUTE, unix.SOCK_RAW, unix.AF_UNSPEC)
if err != nil {
return fmt.Errorf("unable to create AF_ROUTE socket: %v", err)
}
defer unix.Close(sock)
route := &netroute.RouteMessage{
Version: unix.RTM_VERSION,
Type: unix.RTM_ADD,
Flags: unix.RTF_UP,
Seq: 1,
}
if prefix.Addr().Is4() {
route.Addrs = []netroute.Addr{
unix.RTAX_DST: &netroute.Inet4Addr{IP: prefix.Masked().Addr().As4()},
unix.RTAX_NETMASK: &netroute.Inet4Addr{IP: prefixToMask(prefix).As4()},
unix.RTAX_GATEWAY: gateway,
}
} else {
route.Addrs = []netroute.Addr{
unix.RTAX_DST: &netroute.Inet6Addr{IP: prefix.Masked().Addr().As16()},
unix.RTAX_NETMASK: &netroute.Inet6Addr{IP: prefixToMask(prefix).As16()},
unix.RTAX_GATEWAY: gateway,
}
}
data, err := route.Marshal()
if err != nil {
return fmt.Errorf("failed to create route.RouteMessage: %w", err)
}
_, err = unix.Write(sock, data[:])
if err != nil {
if errors.Is(err, unix.EEXIST) {
// Try to do a change
route.Type = unix.RTM_CHANGE
data, err = route.Marshal()
if err != nil {
return fmt.Errorf("failed to create route.RouteMessage for change: %w", err)
}
_, err = unix.Write(sock, data[:])
fmt.Println("DOING CHANGE")
return err
}
return fmt.Errorf("failed to write route.RouteMessage to socket: %w", err)
}
return nil
}
func delRoute(prefix netip.Prefix, gateway netroute.Addr) error {
sock, err := unix.Socket(unix.AF_ROUTE, unix.SOCK_RAW, unix.AF_UNSPEC)
if err != nil {
return fmt.Errorf("unable to create AF_ROUTE socket: %v", err)
}
defer unix.Close(sock)
route := netroute.RouteMessage{
Version: unix.RTM_VERSION,
Type: unix.RTM_DELETE,
Seq: 1,
}
if prefix.Addr().Is4() {
route.Addrs = []netroute.Addr{
unix.RTAX_DST: &netroute.Inet4Addr{IP: prefix.Masked().Addr().As4()},
unix.RTAX_NETMASK: &netroute.Inet4Addr{IP: prefixToMask(prefix).As4()},
unix.RTAX_GATEWAY: gateway,
}
} else {
route.Addrs = []netroute.Addr{
unix.RTAX_DST: &netroute.Inet6Addr{IP: prefix.Masked().Addr().As16()},
unix.RTAX_NETMASK: &netroute.Inet6Addr{IP: prefixToMask(prefix).As16()},
unix.RTAX_GATEWAY: gateway,
}
}
data, err := route.Marshal()
if err != nil {
return fmt.Errorf("failed to create route.RouteMessage: %w", err)
}
_, err = unix.Write(sock, data[:])
if err != nil {
return fmt.Errorf("failed to write route.RouteMessage to socket: %w", err)
}
return nil
}
// getLinkAddr Gets the link address for the interface of the given name
func getLinkAddr(name string) (*netroute.LinkAddr, error) {
rib, err := netroute.FetchRIB(unix.AF_UNSPEC, unix.NET_RT_IFLIST, 0)
if err != nil {
return nil, err
}
msgs, err := netroute.ParseRIB(unix.NET_RT_IFLIST, rib)
if err != nil {
return nil, err
}
for _, m := range msgs {
switch m := m.(type) {
case *netroute.InterfaceMessage:
if m.Name == name {
sa, ok := m.Addrs[unix.RTAX_IFP].(*netroute.LinkAddr)
if ok {
return sa, nil
}
}
}
}
return nil, nil
}