Use generics for CIDRTrees to avoid casting issues (#1004)

This commit is contained in:
Nate Brown
2023-11-02 17:05:08 -05:00
committed by GitHub
parent a44e1b8b05
commit 5181cb0474
21 changed files with 264 additions and 247 deletions

View File

@@ -21,8 +21,8 @@ type Route struct {
Install bool
}
func makeRouteTree(l *logrus.Logger, routes []Route, allowMTU bool) (*cidr.Tree4, error) {
routeTree := cidr.NewTree4()
func makeRouteTree(l *logrus.Logger, routes []Route, allowMTU bool) (*cidr.Tree4[iputil.VpnIp], error) {
routeTree := cidr.NewTree4[iputil.VpnIp]()
for _, r := range routes {
if !allowMTU && r.MTU > 0 {
l.WithField("route", r).Warnf("route MTU is not supported in %s", runtime.GOOS)

View File

@@ -265,18 +265,16 @@ func Test_makeRouteTree(t *testing.T) {
assert.NoError(t, err)
ip := iputil.Ip2VpnIp(net.ParseIP("1.0.0.2"))
r := routeTree.MostSpecificContains(ip)
assert.NotNil(t, r)
assert.IsType(t, iputil.VpnIp(0), r)
assert.EqualValues(t, iputil.Ip2VpnIp(net.ParseIP("192.168.0.1")), r)
ok, r := routeTree.MostSpecificContains(ip)
assert.True(t, ok)
assert.Equal(t, iputil.Ip2VpnIp(net.ParseIP("192.168.0.1")), r)
ip = iputil.Ip2VpnIp(net.ParseIP("1.0.0.1"))
r = routeTree.MostSpecificContains(ip)
assert.NotNil(t, r)
assert.IsType(t, iputil.VpnIp(0), r)
assert.EqualValues(t, iputil.Ip2VpnIp(net.ParseIP("192.168.0.2")), r)
ok, r = routeTree.MostSpecificContains(ip)
assert.True(t, ok)
assert.Equal(t, iputil.Ip2VpnIp(net.ParseIP("192.168.0.2")), r)
ip = iputil.Ip2VpnIp(net.ParseIP("1.1.0.1"))
r = routeTree.MostSpecificContains(ip)
assert.Nil(t, r)
ok, r = routeTree.MostSpecificContains(ip)
assert.False(t, ok)
}

View File

@@ -25,7 +25,7 @@ type tun struct {
cidr *net.IPNet
DefaultMTU int
Routes []Route
routeTree *cidr.Tree4
routeTree *cidr.Tree4[iputil.VpnIp]
l *logrus.Logger
// cache out buffer since we need to prepend 4 bytes for tun metadata
@@ -304,9 +304,9 @@ func (t *tun) Activate() error {
}
func (t *tun) RouteFor(ip iputil.VpnIp) iputil.VpnIp {
r := t.routeTree.MostSpecificContains(ip)
if r != nil {
return r.(iputil.VpnIp)
ok, r := t.routeTree.MostSpecificContains(ip)
if ok {
return r
}
return 0

View File

@@ -48,7 +48,7 @@ type tun struct {
cidr *net.IPNet
MTU int
Routes []Route
routeTree *cidr.Tree4
routeTree *cidr.Tree4[iputil.VpnIp]
l *logrus.Logger
io.ReadWriteCloser
@@ -192,12 +192,8 @@ func (t *tun) Activate() error {
}
func (t *tun) RouteFor(ip iputil.VpnIp) iputil.VpnIp {
r := t.routeTree.MostSpecificContains(ip)
if r != nil {
return r.(iputil.VpnIp)
}
return 0
_, r := t.routeTree.MostSpecificContains(ip)
return r
}
func (t *tun) Cidr() *net.IPNet {

View File

@@ -30,7 +30,7 @@ type tun struct {
TXQueueLen int
Routes []Route
routeTree atomic.Pointer[cidr.Tree4]
routeTree atomic.Pointer[cidr.Tree4[iputil.VpnIp]]
routeChan chan struct{}
useSystemRoutes bool
@@ -154,12 +154,8 @@ func (t *tun) NewMultiQueueReader() (io.ReadWriteCloser, error) {
}
func (t *tun) RouteFor(ip iputil.VpnIp) iputil.VpnIp {
r := t.routeTree.Load().MostSpecificContains(ip)
if r != nil {
return r.(iputil.VpnIp)
}
return 0
_, r := t.routeTree.Load().MostSpecificContains(ip)
return r
}
func (t *tun) Write(b []byte) (int, error) {
@@ -380,7 +376,7 @@ func (t *tun) updateRoutes(r netlink.RouteUpdate) {
return
}
newTree := cidr.NewTree4()
newTree := cidr.NewTree4[iputil.VpnIp]()
if r.Type == unix.RTM_NEWROUTE {
for _, oldR := range t.routeTree.Load().List() {
newTree.AddCIDR(oldR.CIDR, oldR.Value)
@@ -392,7 +388,7 @@ func (t *tun) updateRoutes(r netlink.RouteUpdate) {
} else {
gw := iputil.Ip2VpnIp(r.Gw)
for _, oldR := range t.routeTree.Load().List() {
if bytes.Equal(oldR.CIDR.IP, r.Dst.IP) && bytes.Equal(oldR.CIDR.Mask, r.Dst.Mask) && *oldR.Value != nil && (*oldR.Value).(iputil.VpnIp) == gw {
if bytes.Equal(oldR.CIDR.IP, r.Dst.IP) && bytes.Equal(oldR.CIDR.Mask, r.Dst.Mask) && oldR.Value == gw {
// This is the record to delete
t.l.WithField("destination", r.Dst).WithField("via", r.Gw).Info("Removing route")
continue

View File

@@ -29,7 +29,7 @@ type tun struct {
cidr *net.IPNet
MTU int
Routes []Route
routeTree *cidr.Tree4
routeTree *cidr.Tree4[iputil.VpnIp]
l *logrus.Logger
io.ReadWriteCloser
@@ -134,12 +134,8 @@ func (t *tun) Activate() error {
}
func (t *tun) RouteFor(ip iputil.VpnIp) iputil.VpnIp {
r := t.routeTree.MostSpecificContains(ip)
if r != nil {
return r.(iputil.VpnIp)
}
return 0
_, r := t.routeTree.MostSpecificContains(ip)
return r
}
func (t *tun) Cidr() *net.IPNet {

View File

@@ -23,7 +23,7 @@ type tun struct {
cidr *net.IPNet
MTU int
Routes []Route
routeTree *cidr.Tree4
routeTree *cidr.Tree4[iputil.VpnIp]
l *logrus.Logger
io.ReadWriteCloser
@@ -115,12 +115,8 @@ func (t *tun) Activate() error {
}
func (t *tun) RouteFor(ip iputil.VpnIp) iputil.VpnIp {
r := t.routeTree.MostSpecificContains(ip)
if r != nil {
return r.(iputil.VpnIp)
}
return 0
_, r := t.routeTree.MostSpecificContains(ip)
return r
}
func (t *tun) Cidr() *net.IPNet {

View File

@@ -19,7 +19,7 @@ type TestTun struct {
Device string
cidr *net.IPNet
Routes []Route
routeTree *cidr.Tree4
routeTree *cidr.Tree4[iputil.VpnIp]
l *logrus.Logger
closed atomic.Bool
@@ -83,12 +83,8 @@ func (t *TestTun) Get(block bool) []byte {
//********************************************************************************************************************//
func (t *TestTun) RouteFor(ip iputil.VpnIp) iputil.VpnIp {
r := t.routeTree.MostSpecificContains(ip)
if r != nil {
return r.(iputil.VpnIp)
}
return 0
_, r := t.routeTree.MostSpecificContains(ip)
return r
}
func (t *TestTun) Activate() error {

View File

@@ -18,7 +18,7 @@ type waterTun struct {
cidr *net.IPNet
MTU int
Routes []Route
routeTree *cidr.Tree4
routeTree *cidr.Tree4[iputil.VpnIp]
*water.Interface
}
@@ -97,12 +97,8 @@ func (t *waterTun) Activate() error {
}
func (t *waterTun) RouteFor(ip iputil.VpnIp) iputil.VpnIp {
r := t.routeTree.MostSpecificContains(ip)
if r != nil {
return r.(iputil.VpnIp)
}
return 0
_, r := t.routeTree.MostSpecificContains(ip)
return r
}
func (t *waterTun) Cidr() *net.IPNet {

View File

@@ -24,7 +24,7 @@ type winTun struct {
prefix netip.Prefix
MTU int
Routes []Route
routeTree *cidr.Tree4
routeTree *cidr.Tree4[iputil.VpnIp]
tun *wintun.NativeTun
}
@@ -146,12 +146,8 @@ func (t *winTun) Activate() error {
}
func (t *winTun) RouteFor(ip iputil.VpnIp) iputil.VpnIp {
r := t.routeTree.MostSpecificContains(ip)
if r != nil {
return r.(iputil.VpnIp)
}
return 0
_, r := t.routeTree.MostSpecificContains(ip)
return r
}
func (t *winTun) Cidr() *net.IPNet {