[cert-v2] punchy-respond on an address in common with the querying host (#1261)

This commit is contained in:
Jack Doan 2024-11-12 22:14:44 -05:00 committed by GitHub
parent 602dca8508
commit 5380fef7b0
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 99 additions and 15 deletions

View File

@ -1108,32 +1108,44 @@ func (lhh *LightHouseHandler) handleHostQuery(n *NebulaMeta, fromVpnAddrs []neti
lhh.lh.metricTx(NebulaMeta_HostQueryReply, 1) lhh.lh.metricTx(NebulaMeta_HostQueryReply, 1)
w.SendMessageToVpnAddr(header.LightHouse, 0, fromVpnAddrs[0], lhh.pb[:ln], lhh.nb, lhh.out[:0]) w.SendMessageToVpnAddr(header.LightHouse, 0, fromVpnAddrs[0], lhh.pb[:ln], lhh.nb, lhh.out[:0])
// This signals the other side to punch some zero byte udp packets lhh.sendHostPunchNotification(n, fromVpnAddrs, queryVpnAddr, w)
found, ln, err = lhh.lh.queryAndPrepMessage(fromVpnAddrs[0], func(c *cache) (int, error) { }
// sendHostPunchNotification signals the other side to punch some zero byte udp packets
func (lhh *LightHouseHandler) sendHostPunchNotification(n *NebulaMeta, fromVpnAddrs []netip.Addr, punchNotifDest netip.Addr, w EncWriter) {
whereToPunch := fromVpnAddrs[0]
found, ln, err := lhh.lh.queryAndPrepMessage(whereToPunch, func(c *cache) (int, error) {
n = lhh.resetMeta() n = lhh.resetMeta()
n.Type = NebulaMeta_HostPunchNotification n.Type = NebulaMeta_HostPunchNotification
targetHI := lhh.lh.ifce.GetHostInfo(queryVpnAddr) targetHI := lhh.lh.ifce.GetHostInfo(punchNotifDest)
var useVersion cert.Version
if targetHI == nil { if targetHI == nil {
useVersion = lhh.lh.ifce.GetCertState().defaultVersion useVersion = lhh.lh.ifce.GetCertState().defaultVersion
} else { } else {
useVersion = targetHI.GetCert().Certificate.Version() crt := targetHI.GetCert().Certificate
useVersion = crt.Version()
// we can only retarget if we have a hostinfo
newDest, ok := findNetworkUnion(crt.Networks(), fromVpnAddrs)
if ok {
whereToPunch = newDest
} else {
//TODO this means the destination will have no addresses in common with the punch-ee
//choosing to do nothing for now, but maybe we return an error?
}
} }
if useVersion == cert.Version1 { if useVersion == cert.Version1 {
if !fromVpnAddrs[0].Is4() { if !whereToPunch.Is4() {
return 0, fmt.Errorf("invalid vpn addr for v1 handleHostQuery") return 0, fmt.Errorf("invalid vpn addr for v1 handleHostQuery")
} }
b := fromVpnAddrs[0].As4() b := whereToPunch.As4()
n.Details.OldVpnAddr = binary.BigEndian.Uint32(b[:]) n.Details.OldVpnAddr = binary.BigEndian.Uint32(b[:])
lhh.coalesceAnswers(useVersion, c, n)
} else if useVersion == cert.Version2 { } else if useVersion == cert.Version2 {
n.Details.VpnAddr = netAddrToProtoAddr(fromVpnAddrs[0]) n.Details.VpnAddr = netAddrToProtoAddr(whereToPunch)
lhh.coalesceAnswers(useVersion, c, n)
} else { } else {
panic("unsupported version") return 0, errors.New("unsupported version")
} }
lhh.coalesceAnswers(useVersion, c, n)
return n.MarshalTo(lhh.pb) return n.MarshalTo(lhh.pb)
}) })
@ -1148,7 +1160,7 @@ func (lhh *LightHouseHandler) handleHostQuery(n *NebulaMeta, fromVpnAddrs []neti
} }
lhh.lh.metricTx(NebulaMeta_HostPunchNotification, 1) lhh.lh.metricTx(NebulaMeta_HostPunchNotification, 1)
w.SendMessageToVpnAddr(header.LightHouse, 0, queryVpnAddr, lhh.pb[:ln], lhh.nb, lhh.out[:0]) w.SendMessageToVpnAddr(header.LightHouse, 0, punchNotifDest, lhh.pb[:ln], lhh.nb, lhh.out[:0])
} }
func (lhh *LightHouseHandler) coalesceAnswers(v cert.Version, c *cache, n *NebulaMeta) { func (lhh *LightHouseHandler) coalesceAnswers(v cert.Version, c *cache, n *NebulaMeta) {
@ -1429,3 +1441,15 @@ func (d *NebulaMetaDetails) GetRelays() []netip.Addr {
} }
return relays return relays
} }
// FindNetworkUnion returns the first netip.Addr contained in the list of provided netip.Prefix, if able
func findNetworkUnion(prefixes []netip.Prefix, addrs []netip.Addr) (netip.Addr, bool) {
for i := range prefixes {
for j := range addrs {
if prefixes[i].Contains(addrs[j]) {
return addrs[j], true
}
}
}
return netip.Addr{}, false
}

View File

@ -494,3 +494,63 @@ func assertIp4InArray(t *testing.T, have []*V4AddrPort, want ...netip.AddrPort)
} }
} }
} }
func Test_findNetworkUnion(t *testing.T) {
var out netip.Addr
var ok bool
tenDot := netip.MustParsePrefix("10.0.0.0/8")
oneSevenTwo := netip.MustParsePrefix("172.16.0.0/16")
fe80 := netip.MustParsePrefix("fe80::/8")
fc00 := netip.MustParsePrefix("fc00::/7")
a1 := netip.MustParseAddr("10.0.0.1")
afe81 := netip.MustParseAddr("fe80::1")
//simple
out, ok = findNetworkUnion([]netip.Prefix{tenDot}, []netip.Addr{a1})
assert.True(t, ok)
assert.Equal(t, out, a1)
//mixed lengths
out, ok = findNetworkUnion([]netip.Prefix{tenDot}, []netip.Addr{a1, afe81})
assert.True(t, ok)
assert.Equal(t, out, a1)
out, ok = findNetworkUnion([]netip.Prefix{tenDot, oneSevenTwo}, []netip.Addr{a1})
assert.True(t, ok)
assert.Equal(t, out, a1)
//mixed family
out, ok = findNetworkUnion([]netip.Prefix{tenDot, oneSevenTwo, fe80}, []netip.Addr{a1})
assert.True(t, ok)
assert.Equal(t, out, a1)
out, ok = findNetworkUnion([]netip.Prefix{tenDot, oneSevenTwo, fe80}, []netip.Addr{a1, afe81})
assert.True(t, ok)
assert.Equal(t, out, a1)
//ordering
out, ok = findNetworkUnion([]netip.Prefix{tenDot, oneSevenTwo, fe80}, []netip.Addr{afe81, a1})
assert.True(t, ok)
assert.Equal(t, out, a1)
out, ok = findNetworkUnion([]netip.Prefix{fe80, tenDot, oneSevenTwo}, []netip.Addr{afe81, a1})
assert.True(t, ok)
assert.Equal(t, out, afe81)
//some mismatches
out, ok = findNetworkUnion([]netip.Prefix{tenDot, oneSevenTwo, fe80}, []netip.Addr{afe81})
assert.True(t, ok)
assert.Equal(t, out, afe81)
out, ok = findNetworkUnion([]netip.Prefix{oneSevenTwo, fe80}, []netip.Addr{a1, afe81})
assert.True(t, ok)
assert.Equal(t, out, afe81)
//falsey cases
out, ok = findNetworkUnion([]netip.Prefix{oneSevenTwo, fe80}, []netip.Addr{a1})
assert.False(t, ok)
out, ok = findNetworkUnion([]netip.Prefix{fc00, fe80}, []netip.Addr{a1})
assert.False(t, ok)
out, ok = findNetworkUnion([]netip.Prefix{oneSevenTwo, fc00}, []netip.Addr{a1, afe81})
assert.False(t, ok)
out, ok = findNetworkUnion([]netip.Prefix{fc00}, []netip.Addr{a1, afe81})
assert.False(t, ok)
}

View File

@ -137,8 +137,8 @@ func (rm *relayManager) HandleControlMsg(h *HostInfo, d []byte, f *Interface) {
func (rm *relayManager) handleCreateRelayResponse(v cert.Version, h *HostInfo, f *Interface, m *NebulaControl) { func (rm *relayManager) handleCreateRelayResponse(v cert.Version, h *HostInfo, f *Interface, m *NebulaControl) {
rm.l.WithFields(logrus.Fields{ rm.l.WithFields(logrus.Fields{
"relayFrom": m.RelayFromAddr, "relayFrom": protoAddrToNetAddr(m.RelayFromAddr),
"relayTo": m.RelayToAddr, "relayTo": protoAddrToNetAddr(m.RelayToAddr),
"initiatorRelayIndex": m.InitiatorRelayIndex, "initiatorRelayIndex": m.InitiatorRelayIndex,
"responderRelayIndex": m.ResponderRelayIndex, "responderRelayIndex": m.ResponderRelayIndex,
"vpnAddrs": h.vpnAddrs}). "vpnAddrs": h.vpnAddrs}).