diff --git a/lighthouse.go b/lighthouse.go index 9f00c39..9ca5837 100644 --- a/lighthouse.go +++ b/lighthouse.go @@ -1017,17 +1017,17 @@ func (lhh *LightHouseHandler) resetMeta() *NebulaMeta { return lhh.meta } -func (lhh *LightHouseHandler) HandleRequest(rAddr netip.AddrPort, fromVpnAddrs []netip.Addr, p []byte, w EncWriter) { +func (lhh *LightHouseHandler) HandleRequest(rAddr netip.AddrPort, hostinfo *HostInfo, p []byte, w EncWriter) { n := lhh.resetMeta() err := n.Unmarshal(p) if err != nil { - lhh.l.WithError(err).WithField("vpnAddrs", fromVpnAddrs).WithField("udpAddr", rAddr). + lhh.l.WithError(err).WithField("vpnAddrs", hostinfo.vpnAddrs).WithField("udpAddr", rAddr). Error("Failed to unmarshal lighthouse packet") return } if n.Details == nil { - lhh.l.WithField("vpnAddrs", fromVpnAddrs).WithField("udpAddr", rAddr). + lhh.l.WithField("vpnAddrs", hostinfo.vpnAddrs).WithField("udpAddr", rAddr). Error("Invalid lighthouse update") return } @@ -1036,24 +1036,24 @@ func (lhh *LightHouseHandler) HandleRequest(rAddr netip.AddrPort, fromVpnAddrs [ switch n.Type { case NebulaMeta_HostQuery: - lhh.handleHostQuery(n, fromVpnAddrs, rAddr, w) + lhh.handleHostQuery(n, hostinfo, rAddr, w) case NebulaMeta_HostQueryReply: - lhh.handleHostQueryReply(n, fromVpnAddrs) + lhh.handleHostQueryReply(n, hostinfo.vpnAddrs) case NebulaMeta_HostUpdateNotification: - lhh.handleHostUpdateNotification(n, fromVpnAddrs, w) + lhh.handleHostUpdateNotification(n, hostinfo, w) case NebulaMeta_HostMovedNotification: case NebulaMeta_HostPunchNotification: - lhh.handleHostPunchNotification(n, fromVpnAddrs, w) + lhh.handleHostPunchNotification(n, hostinfo.vpnAddrs, w) case NebulaMeta_HostUpdateNotificationAck: // noop } } -func (lhh *LightHouseHandler) handleHostQuery(n *NebulaMeta, fromVpnAddrs []netip.Addr, addr netip.AddrPort, w EncWriter) { +func (lhh *LightHouseHandler) handleHostQuery(n *NebulaMeta, hostinfo *HostInfo, addr netip.AddrPort, w EncWriter) { // Exit if we don't answer queries if !lhh.lh.amLighthouse { if lhh.l.Level >= logrus.DebugLevel { @@ -1065,7 +1065,7 @@ func (lhh *LightHouseHandler) handleHostQuery(n *NebulaMeta, fromVpnAddrs []neti queryVpnAddr, useVersion, err := n.Details.GetVpnAddrAndVersion() if err != nil { if lhh.l.Level >= logrus.DebugLevel { - lhh.l.WithField("from", fromVpnAddrs).WithField("details", n.Details). + lhh.l.WithField("from", hostinfo.vpnAddrs).WithField("details", n.Details). Debugln("Dropping malformed HostQuery") } return @@ -1073,7 +1073,7 @@ func (lhh *LightHouseHandler) handleHostQuery(n *NebulaMeta, fromVpnAddrs []neti if useVersion == cert.Version1 && queryVpnAddr.Is6() { // this case really shouldn't be possible to represent, but reject it anyway. if lhh.l.Level >= logrus.DebugLevel { - lhh.l.WithField("vpnAddrs", fromVpnAddrs).WithField("queryVpnAddr", queryVpnAddr). + lhh.l.WithField("vpnAddrs", hostinfo.vpnAddrs).WithField("queryVpnAddr", queryVpnAddr). Debugln("invalid vpn addr for v1 handleHostQuery") } return @@ -1099,14 +1099,14 @@ func (lhh *LightHouseHandler) handleHostQuery(n *NebulaMeta, fromVpnAddrs []neti } if err != nil { - lhh.l.WithError(err).WithField("vpnAddrs", fromVpnAddrs).Error("Failed to marshal lighthouse host query reply") + lhh.l.WithError(err).WithField("vpnAddrs", hostinfo.vpnAddrs).Error("Failed to marshal lighthouse host query reply") return } lhh.lh.metricTx(NebulaMeta_HostQueryReply, 1) - w.SendMessageToVpnAddr(header.LightHouse, 0, fromVpnAddrs[0], lhh.pb[:ln], lhh.nb, lhh.out[:0]) + w.SendMessageToHostInfo(header.LightHouse, 0, hostinfo, lhh.pb[:ln], lhh.nb, lhh.out[:0]) - lhh.sendHostPunchNotification(n, fromVpnAddrs, queryVpnAddr, w) + lhh.sendHostPunchNotification(n, hostinfo.vpnAddrs, queryVpnAddr, w) } // sendHostPunchNotification signals the other side to punch some zero byte udp packets @@ -1234,7 +1234,8 @@ func (lhh *LightHouseHandler) handleHostQueryReply(n *NebulaMeta, fromVpnAddrs [ } } -func (lhh *LightHouseHandler) handleHostUpdateNotification(n *NebulaMeta, fromVpnAddrs []netip.Addr, w EncWriter) { +func (lhh *LightHouseHandler) handleHostUpdateNotification(n *NebulaMeta, hostinfo *HostInfo, w EncWriter) { + fromVpnAddrs := hostinfo.vpnAddrs if !lhh.lh.amLighthouse { if lhh.l.Level >= logrus.DebugLevel { lhh.l.Debugln("I am not a lighthouse, do not take host updates: ", fromVpnAddrs) @@ -1302,7 +1303,7 @@ func (lhh *LightHouseHandler) handleHostUpdateNotification(n *NebulaMeta, fromVp } lhh.lh.metricTx(NebulaMeta_HostUpdateNotificationAck, 1) - w.SendMessageToVpnAddr(header.LightHouse, 0, fromVpnAddrs[0], lhh.pb[:ln], lhh.nb, lhh.out[:0]) + w.SendMessageToHostInfo(header.LightHouse, 0, hostinfo, lhh.pb[:ln], lhh.nb, lhh.out[:0]) } func (lhh *LightHouseHandler) handleHostPunchNotification(n *NebulaMeta, fromVpnAddrs []netip.Addr, w EncWriter) { diff --git a/lighthouse_test.go b/lighthouse_test.go index d8a1188..8535269 100644 --- a/lighthouse_test.go +++ b/lighthouse_test.go @@ -132,8 +132,13 @@ func BenchmarkLighthouseHandleRequest(b *testing.B) { ) mw := &mockEncWriter{} - - hi := []netip.Addr{vpnIp2} + hostinfo := &HostInfo{ + ConnectionState: &ConnectionState{ + eKey: nil, + dKey: nil, + }, + vpnAddrs: []netip.Addr{vpnIp2}, + } b.Run("notfound", func(b *testing.B) { lhh := lh.NewRequestHandler() req := &NebulaMeta{ @@ -146,7 +151,7 @@ func BenchmarkLighthouseHandleRequest(b *testing.B) { p, err := req.Marshal() require.NoError(b, err) for n := 0; n < b.N; n++ { - lhh.HandleRequest(rAddr, hi, p, mw) + lhh.HandleRequest(rAddr, hostinfo, p, mw) } }) b.Run("found", func(b *testing.B) { @@ -162,7 +167,7 @@ func BenchmarkLighthouseHandleRequest(b *testing.B) { require.NoError(b, err) for n := 0; n < b.N; n++ { - lhh.HandleRequest(rAddr, hi, p, mw) + lhh.HandleRequest(rAddr, hostinfo, p, mw) } }) } @@ -326,7 +331,14 @@ func newLHHostRequest(fromAddr netip.AddrPort, myVpnIp, queryVpnIp netip.Addr, l w := &testEncWriter{ metaFilter: &filter, } - lhh.HandleRequest(fromAddr, []netip.Addr{myVpnIp}, b, w) + hostinfo := &HostInfo{ + ConnectionState: &ConnectionState{ + eKey: nil, + dKey: nil, + }, + vpnAddrs: []netip.Addr{myVpnIp}, + } + lhh.HandleRequest(fromAddr, hostinfo, b, w) return w.lastReply } @@ -355,9 +367,15 @@ func newLHHostUpdate(fromAddr netip.AddrPort, vpnIp netip.Addr, addrs []netip.Ad if err != nil { panic(err) } - + hostinfo := &HostInfo{ + ConnectionState: &ConnectionState{ + eKey: nil, + dKey: nil, + }, + vpnAddrs: []netip.Addr{vpnIp}, + } w := &testEncWriter{} - lhh.HandleRequest(fromAddr, []netip.Addr{vpnIp}, b, w) + lhh.HandleRequest(fromAddr, hostinfo, b, w) } type testLhReply struct { diff --git a/outside.go b/outside.go index 5ff87bd..dbbbb70 100644 --- a/outside.go +++ b/outside.go @@ -138,7 +138,7 @@ func (f *Interface) readOutsidePackets(ip netip.AddrPort, via *ViaSender, out [] return } - lhf.HandleRequest(ip, hostinfo.vpnAddrs, d, f) + lhf.HandleRequest(ip, hostinfo, d, f) // Fallthrough to the bottom to record incoming traffic