mirror of
https://github.com/slackhq/nebula.git
synced 2025-11-22 16:34:25 +01:00
Use generics for CIDRTrees to avoid casting issues (#1004)
This commit is contained in:
@@ -21,8 +21,8 @@ type Route struct {
|
||||
Install bool
|
||||
}
|
||||
|
||||
func makeRouteTree(l *logrus.Logger, routes []Route, allowMTU bool) (*cidr.Tree4, error) {
|
||||
routeTree := cidr.NewTree4()
|
||||
func makeRouteTree(l *logrus.Logger, routes []Route, allowMTU bool) (*cidr.Tree4[iputil.VpnIp], error) {
|
||||
routeTree := cidr.NewTree4[iputil.VpnIp]()
|
||||
for _, r := range routes {
|
||||
if !allowMTU && r.MTU > 0 {
|
||||
l.WithField("route", r).Warnf("route MTU is not supported in %s", runtime.GOOS)
|
||||
|
||||
@@ -265,18 +265,16 @@ func Test_makeRouteTree(t *testing.T) {
|
||||
assert.NoError(t, err)
|
||||
|
||||
ip := iputil.Ip2VpnIp(net.ParseIP("1.0.0.2"))
|
||||
r := routeTree.MostSpecificContains(ip)
|
||||
assert.NotNil(t, r)
|
||||
assert.IsType(t, iputil.VpnIp(0), r)
|
||||
assert.EqualValues(t, iputil.Ip2VpnIp(net.ParseIP("192.168.0.1")), r)
|
||||
ok, r := routeTree.MostSpecificContains(ip)
|
||||
assert.True(t, ok)
|
||||
assert.Equal(t, iputil.Ip2VpnIp(net.ParseIP("192.168.0.1")), r)
|
||||
|
||||
ip = iputil.Ip2VpnIp(net.ParseIP("1.0.0.1"))
|
||||
r = routeTree.MostSpecificContains(ip)
|
||||
assert.NotNil(t, r)
|
||||
assert.IsType(t, iputil.VpnIp(0), r)
|
||||
assert.EqualValues(t, iputil.Ip2VpnIp(net.ParseIP("192.168.0.2")), r)
|
||||
ok, r = routeTree.MostSpecificContains(ip)
|
||||
assert.True(t, ok)
|
||||
assert.Equal(t, iputil.Ip2VpnIp(net.ParseIP("192.168.0.2")), r)
|
||||
|
||||
ip = iputil.Ip2VpnIp(net.ParseIP("1.1.0.1"))
|
||||
r = routeTree.MostSpecificContains(ip)
|
||||
assert.Nil(t, r)
|
||||
ok, r = routeTree.MostSpecificContains(ip)
|
||||
assert.False(t, ok)
|
||||
}
|
||||
|
||||
@@ -25,7 +25,7 @@ type tun struct {
|
||||
cidr *net.IPNet
|
||||
DefaultMTU int
|
||||
Routes []Route
|
||||
routeTree *cidr.Tree4
|
||||
routeTree *cidr.Tree4[iputil.VpnIp]
|
||||
l *logrus.Logger
|
||||
|
||||
// cache out buffer since we need to prepend 4 bytes for tun metadata
|
||||
@@ -304,9 +304,9 @@ func (t *tun) Activate() error {
|
||||
}
|
||||
|
||||
func (t *tun) RouteFor(ip iputil.VpnIp) iputil.VpnIp {
|
||||
r := t.routeTree.MostSpecificContains(ip)
|
||||
if r != nil {
|
||||
return r.(iputil.VpnIp)
|
||||
ok, r := t.routeTree.MostSpecificContains(ip)
|
||||
if ok {
|
||||
return r
|
||||
}
|
||||
|
||||
return 0
|
||||
|
||||
@@ -48,7 +48,7 @@ type tun struct {
|
||||
cidr *net.IPNet
|
||||
MTU int
|
||||
Routes []Route
|
||||
routeTree *cidr.Tree4
|
||||
routeTree *cidr.Tree4[iputil.VpnIp]
|
||||
l *logrus.Logger
|
||||
|
||||
io.ReadWriteCloser
|
||||
@@ -192,12 +192,8 @@ func (t *tun) Activate() error {
|
||||
}
|
||||
|
||||
func (t *tun) RouteFor(ip iputil.VpnIp) iputil.VpnIp {
|
||||
r := t.routeTree.MostSpecificContains(ip)
|
||||
if r != nil {
|
||||
return r.(iputil.VpnIp)
|
||||
}
|
||||
|
||||
return 0
|
||||
_, r := t.routeTree.MostSpecificContains(ip)
|
||||
return r
|
||||
}
|
||||
|
||||
func (t *tun) Cidr() *net.IPNet {
|
||||
|
||||
@@ -30,7 +30,7 @@ type tun struct {
|
||||
TXQueueLen int
|
||||
|
||||
Routes []Route
|
||||
routeTree atomic.Pointer[cidr.Tree4]
|
||||
routeTree atomic.Pointer[cidr.Tree4[iputil.VpnIp]]
|
||||
routeChan chan struct{}
|
||||
useSystemRoutes bool
|
||||
|
||||
@@ -154,12 +154,8 @@ func (t *tun) NewMultiQueueReader() (io.ReadWriteCloser, error) {
|
||||
}
|
||||
|
||||
func (t *tun) RouteFor(ip iputil.VpnIp) iputil.VpnIp {
|
||||
r := t.routeTree.Load().MostSpecificContains(ip)
|
||||
if r != nil {
|
||||
return r.(iputil.VpnIp)
|
||||
}
|
||||
|
||||
return 0
|
||||
_, r := t.routeTree.Load().MostSpecificContains(ip)
|
||||
return r
|
||||
}
|
||||
|
||||
func (t *tun) Write(b []byte) (int, error) {
|
||||
@@ -380,7 +376,7 @@ func (t *tun) updateRoutes(r netlink.RouteUpdate) {
|
||||
return
|
||||
}
|
||||
|
||||
newTree := cidr.NewTree4()
|
||||
newTree := cidr.NewTree4[iputil.VpnIp]()
|
||||
if r.Type == unix.RTM_NEWROUTE {
|
||||
for _, oldR := range t.routeTree.Load().List() {
|
||||
newTree.AddCIDR(oldR.CIDR, oldR.Value)
|
||||
@@ -392,7 +388,7 @@ func (t *tun) updateRoutes(r netlink.RouteUpdate) {
|
||||
} else {
|
||||
gw := iputil.Ip2VpnIp(r.Gw)
|
||||
for _, oldR := range t.routeTree.Load().List() {
|
||||
if bytes.Equal(oldR.CIDR.IP, r.Dst.IP) && bytes.Equal(oldR.CIDR.Mask, r.Dst.Mask) && *oldR.Value != nil && (*oldR.Value).(iputil.VpnIp) == gw {
|
||||
if bytes.Equal(oldR.CIDR.IP, r.Dst.IP) && bytes.Equal(oldR.CIDR.Mask, r.Dst.Mask) && oldR.Value == gw {
|
||||
// This is the record to delete
|
||||
t.l.WithField("destination", r.Dst).WithField("via", r.Gw).Info("Removing route")
|
||||
continue
|
||||
|
||||
@@ -29,7 +29,7 @@ type tun struct {
|
||||
cidr *net.IPNet
|
||||
MTU int
|
||||
Routes []Route
|
||||
routeTree *cidr.Tree4
|
||||
routeTree *cidr.Tree4[iputil.VpnIp]
|
||||
l *logrus.Logger
|
||||
|
||||
io.ReadWriteCloser
|
||||
@@ -134,12 +134,8 @@ func (t *tun) Activate() error {
|
||||
}
|
||||
|
||||
func (t *tun) RouteFor(ip iputil.VpnIp) iputil.VpnIp {
|
||||
r := t.routeTree.MostSpecificContains(ip)
|
||||
if r != nil {
|
||||
return r.(iputil.VpnIp)
|
||||
}
|
||||
|
||||
return 0
|
||||
_, r := t.routeTree.MostSpecificContains(ip)
|
||||
return r
|
||||
}
|
||||
|
||||
func (t *tun) Cidr() *net.IPNet {
|
||||
|
||||
@@ -23,7 +23,7 @@ type tun struct {
|
||||
cidr *net.IPNet
|
||||
MTU int
|
||||
Routes []Route
|
||||
routeTree *cidr.Tree4
|
||||
routeTree *cidr.Tree4[iputil.VpnIp]
|
||||
l *logrus.Logger
|
||||
|
||||
io.ReadWriteCloser
|
||||
@@ -115,12 +115,8 @@ func (t *tun) Activate() error {
|
||||
}
|
||||
|
||||
func (t *tun) RouteFor(ip iputil.VpnIp) iputil.VpnIp {
|
||||
r := t.routeTree.MostSpecificContains(ip)
|
||||
if r != nil {
|
||||
return r.(iputil.VpnIp)
|
||||
}
|
||||
|
||||
return 0
|
||||
_, r := t.routeTree.MostSpecificContains(ip)
|
||||
return r
|
||||
}
|
||||
|
||||
func (t *tun) Cidr() *net.IPNet {
|
||||
|
||||
@@ -19,7 +19,7 @@ type TestTun struct {
|
||||
Device string
|
||||
cidr *net.IPNet
|
||||
Routes []Route
|
||||
routeTree *cidr.Tree4
|
||||
routeTree *cidr.Tree4[iputil.VpnIp]
|
||||
l *logrus.Logger
|
||||
|
||||
closed atomic.Bool
|
||||
@@ -83,12 +83,8 @@ func (t *TestTun) Get(block bool) []byte {
|
||||
//********************************************************************************************************************//
|
||||
|
||||
func (t *TestTun) RouteFor(ip iputil.VpnIp) iputil.VpnIp {
|
||||
r := t.routeTree.MostSpecificContains(ip)
|
||||
if r != nil {
|
||||
return r.(iputil.VpnIp)
|
||||
}
|
||||
|
||||
return 0
|
||||
_, r := t.routeTree.MostSpecificContains(ip)
|
||||
return r
|
||||
}
|
||||
|
||||
func (t *TestTun) Activate() error {
|
||||
|
||||
@@ -18,7 +18,7 @@ type waterTun struct {
|
||||
cidr *net.IPNet
|
||||
MTU int
|
||||
Routes []Route
|
||||
routeTree *cidr.Tree4
|
||||
routeTree *cidr.Tree4[iputil.VpnIp]
|
||||
|
||||
*water.Interface
|
||||
}
|
||||
@@ -97,12 +97,8 @@ func (t *waterTun) Activate() error {
|
||||
}
|
||||
|
||||
func (t *waterTun) RouteFor(ip iputil.VpnIp) iputil.VpnIp {
|
||||
r := t.routeTree.MostSpecificContains(ip)
|
||||
if r != nil {
|
||||
return r.(iputil.VpnIp)
|
||||
}
|
||||
|
||||
return 0
|
||||
_, r := t.routeTree.MostSpecificContains(ip)
|
||||
return r
|
||||
}
|
||||
|
||||
func (t *waterTun) Cidr() *net.IPNet {
|
||||
|
||||
@@ -24,7 +24,7 @@ type winTun struct {
|
||||
prefix netip.Prefix
|
||||
MTU int
|
||||
Routes []Route
|
||||
routeTree *cidr.Tree4
|
||||
routeTree *cidr.Tree4[iputil.VpnIp]
|
||||
|
||||
tun *wintun.NativeTun
|
||||
}
|
||||
@@ -146,12 +146,8 @@ func (t *winTun) Activate() error {
|
||||
}
|
||||
|
||||
func (t *winTun) RouteFor(ip iputil.VpnIp) iputil.VpnIp {
|
||||
r := t.routeTree.MostSpecificContains(ip)
|
||||
if r != nil {
|
||||
return r.(iputil.VpnIp)
|
||||
}
|
||||
|
||||
return 0
|
||||
_, r := t.routeTree.MostSpecificContains(ip)
|
||||
return r
|
||||
}
|
||||
|
||||
func (t *winTun) Cidr() *net.IPNet {
|
||||
|
||||
Reference in New Issue
Block a user