diff --git a/e2e/helpers_test.go b/e2e/helpers_test.go index 3014096..fe3e6f1 100644 --- a/e2e/helpers_test.go +++ b/e2e/helpers_test.go @@ -29,8 +29,6 @@ type m = map[string]any // newSimpleServer creates a nebula instance with many assumptions func newSimpleServer(v cert.Version, caCrt cert.Certificate, caKey []byte, name string, sVpnNetworks string, overrides m) (*nebula.Control, []netip.Prefix, netip.AddrPort, *config.C) { - l := NewTestLogger() - var vpnNetworks []netip.Prefix for _, sn := range strings.Split(sVpnNetworks, ",") { vpnIpNet, err := netip.ParsePrefix(strings.TrimSpace(sn)) @@ -56,6 +54,25 @@ func newSimpleServer(v cert.Version, caCrt cert.Certificate, caKey []byte, name budpIp[3] = 239 udpAddr = netip.AddrPortFrom(netip.AddrFrom16(budpIp), 4242) } + return newSimpleServerWithUdp(v, caCrt, caKey, name, sVpnNetworks, udpAddr, overrides) +} + +func newSimpleServerWithUdp(v cert.Version, caCrt cert.Certificate, caKey []byte, name string, sVpnNetworks string, udpAddr netip.AddrPort, overrides m) (*nebula.Control, []netip.Prefix, netip.AddrPort, *config.C) { + l := NewTestLogger() + + var vpnNetworks []netip.Prefix + for _, sn := range strings.Split(sVpnNetworks, ",") { + vpnIpNet, err := netip.ParsePrefix(strings.TrimSpace(sn)) + if err != nil { + panic(err) + } + vpnNetworks = append(vpnNetworks, vpnIpNet) + } + + if len(vpnNetworks) == 0 { + panic("no vpn networks") + } + _, _, myPrivKey, myPEM := cert_test.NewTestCert(v, cert.Curve_CURVE25519, caCrt, caKey, name, time.Now(), time.Now().Add(5*time.Minute), vpnNetworks, nil, []string{}) caB, err := caCrt.MarshalPEM() diff --git a/e2e/tunnels_test.go b/e2e/tunnels_test.go index 27ea3d1..62027b7 100644 --- a/e2e/tunnels_test.go +++ b/e2e/tunnels_test.go @@ -318,3 +318,50 @@ func TestCertMismatchCorrection(t *testing.T) { myControl.Stop() theirControl.Stop() } + +func TestCrossStackRelaysWork(t *testing.T) { + ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version2, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{}) + myControl, myVpnIpNet, _, _ := newSimpleServer(cert.Version2, ca, caKey, "me ", "10.128.0.1/24,fc00::1/64", m{"relay": m{"use_relays": true}}) + relayControl, relayVpnIpNet, relayUdpAddr, _ := newSimpleServer(cert.Version2, ca, caKey, "relay ", "10.128.0.128/24,fc00::128/64", m{"relay": m{"am_relay": true}}) + theirUdp := netip.MustParseAddrPort("10.0.0.2:4242") + theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServerWithUdp(cert.Version2, ca, caKey, "them ", "fc00::2/64", theirUdp, m{"relay": m{"use_relays": true}}) + + //myVpnV4 := myVpnIpNet[0] + myVpnV6 := myVpnIpNet[1] + relayVpnV4 := relayVpnIpNet[0] + relayVpnV6 := relayVpnIpNet[1] + theirVpnV6 := theirVpnIpNet[0] + + // Teach my how to get to the relay and that their can be reached via the relay + myControl.InjectLightHouseAddr(relayVpnV4.Addr(), relayUdpAddr) + myControl.InjectLightHouseAddr(relayVpnV6.Addr(), relayUdpAddr) + myControl.InjectRelays(theirVpnV6.Addr(), []netip.Addr{relayVpnV6.Addr()}) + relayControl.InjectLightHouseAddr(theirVpnV6.Addr(), theirUdpAddr) + + // Build a router so we don't have to reason who gets which packet + r := router.NewR(t, myControl, relayControl, theirControl) + defer r.RenderFlow() + + // Start the servers + myControl.Start() + relayControl.Start() + theirControl.Start() + + t.Log("Trigger a handshake from me to them via the relay") + myControl.InjectTunUDPPacket(theirVpnV6.Addr(), 80, myVpnV6.Addr(), 80, []byte("Hi from me")) + + p := r.RouteForAllUntilTxTun(theirControl) + r.Log("Assert the tunnel works") + assertUdpPacket(t, []byte("Hi from me"), p, myVpnV6.Addr(), theirVpnV6.Addr(), 80, 80) + + t.Log("reply?") + theirControl.InjectTunUDPPacket(myVpnV6.Addr(), 80, theirVpnV6.Addr(), 80, []byte("Hi from them")) + p = r.RouteForAllUntilTxTun(myControl) + assertUdpPacket(t, []byte("Hi from them"), p, theirVpnV6.Addr(), myVpnV6.Addr(), 80, 80) + + r.RenderHostmaps("Final hostmaps", myControl, relayControl, theirControl) + //t.Log("finish up") + //myControl.Stop() + //theirControl.Stop() + //relayControl.Stop() +} diff --git a/handshake_manager.go b/handshake_manager.go index ee72d71..3165a21 100644 --- a/handshake_manager.go +++ b/handshake_manager.go @@ -300,6 +300,8 @@ func (hm *HandshakeManager) handleOutbound(vpnIp netip.Addr, lighthouseTriggered InitiatorRelayIndex: idx, } + relayFrom := hm.f.myVpnAddrs[0] + switch relayHostInfo.GetCert().Certificate.Version() { case cert.Version1: if !hm.f.myVpnAddrs[0].Is4() { @@ -317,7 +319,13 @@ func (hm *HandshakeManager) handleOutbound(vpnIp netip.Addr, lighthouseTriggered b = vpnIp.As4() m.OldRelayToAddr = binary.BigEndian.Uint32(b[:]) case cert.Version2: - m.RelayFromAddr = netAddrToProtoAddr(hm.f.myVpnAddrs[0]) + if vpnIp.Is4() { + relayFrom = hm.f.myVpnAddrs[0] + } else { + //todo do this smarter + relayFrom = hm.f.myVpnAddrs[len(hm.f.myVpnAddrs)-1] + } + m.RelayFromAddr = netAddrToProtoAddr(relayFrom) m.RelayToAddr = netAddrToProtoAddr(vpnIp) default: hostinfo.logger(hm.l).Error("Unknown certificate version found while creating relay") @@ -332,7 +340,7 @@ func (hm *HandshakeManager) handleOutbound(vpnIp netip.Addr, lighthouseTriggered } else { hm.f.SendMessageToHostInfo(header.Control, 0, relayHostInfo, msg, make([]byte, 12), make([]byte, mtu)) hm.l.WithFields(logrus.Fields{ - "relayFrom": hm.f.myVpnAddrs[0], + "relayFrom": relayFrom, "relayTo": vpnIp, "initiatorRelayIndex": idx, "relay": relay}). @@ -358,6 +366,8 @@ func (hm *HandshakeManager) handleOutbound(vpnIp netip.Addr, lighthouseTriggered InitiatorRelayIndex: existingRelay.LocalIndex, } + relayFrom := hm.f.myVpnAddrs[0] + switch relayHostInfo.GetCert().Certificate.Version() { case cert.Version1: if !hm.f.myVpnAddrs[0].Is4() { @@ -375,7 +385,14 @@ func (hm *HandshakeManager) handleOutbound(vpnIp netip.Addr, lighthouseTriggered b = vpnIp.As4() m.OldRelayToAddr = binary.BigEndian.Uint32(b[:]) case cert.Version2: - m.RelayFromAddr = netAddrToProtoAddr(hm.f.myVpnAddrs[0]) + if vpnIp.Is4() { + relayFrom = hm.f.myVpnAddrs[0] + } else { + //todo do this smarter + relayFrom = hm.f.myVpnAddrs[len(hm.f.myVpnAddrs)-1] + } + + m.RelayFromAddr = netAddrToProtoAddr(relayFrom) m.RelayToAddr = netAddrToProtoAddr(vpnIp) default: hostinfo.logger(hm.l).Error("Unknown certificate version found while creating relay") @@ -390,7 +407,7 @@ func (hm *HandshakeManager) handleOutbound(vpnIp netip.Addr, lighthouseTriggered // This must send over the hostinfo, not over hm.Hosts[ip] hm.f.SendMessageToHostInfo(header.Control, 0, relayHostInfo, msg, make([]byte, 12), make([]byte, mtu)) hm.l.WithFields(logrus.Fields{ - "relayFrom": hm.f.myVpnAddrs[0], + "relayFrom": relayFrom, "relayTo": vpnIp, "initiatorRelayIndex": existingRelay.LocalIndex, "relay": relay}). diff --git a/hostmap.go b/hostmap.go index cd2e696..bd905cc 100644 --- a/hostmap.go +++ b/hostmap.go @@ -2,6 +2,7 @@ package nebula import ( "errors" + "fmt" "net" "net/netip" "slices" @@ -521,6 +522,7 @@ func (hm *HostMap) QueryVpnAddrsRelayFor(targetIps []netip.Addr, relayHostIp net return nil, nil, errors.New("unable to find host") } + lastH := h for h != nil { for _, targetIp := range targetIps { r, ok := h.relayState.QueryRelayForByIp(targetIp) @@ -528,10 +530,12 @@ func (hm *HostMap) QueryVpnAddrsRelayFor(targetIps []netip.Addr, relayHostIp net return h, r, nil } } + lastH = h h = h.next } - return nil, nil, errors.New("unable to find host with relay") + //todo no merge + return nil, nil, fmt.Errorf("unable to find host with relay: %v", lastH) } func (hm *HostMap) unlockedDisestablishVpnAddrRelayFor(hi *HostInfo) { diff --git a/relay_manager.go b/relay_manager.go index 5dd355c..2e9a7c7 100644 --- a/relay_manager.go +++ b/relay_manager.go @@ -190,6 +190,7 @@ func (rm *relayManager) handleCreateRelayResponse(v cert.Version, h *HostInfo, f InitiatorRelayIndex: peerRelay.RemoteIndex, } + relayFrom := h.vpnAddrs[0] if v == cert.Version1 { peer := peerHostInfo.vpnAddrs[0] if !peer.Is4() { @@ -207,7 +208,13 @@ func (rm *relayManager) handleCreateRelayResponse(v cert.Version, h *HostInfo, f b = targetAddr.As4() resp.OldRelayToAddr = binary.BigEndian.Uint32(b[:]) } else { - resp.RelayFromAddr = netAddrToProtoAddr(peerHostInfo.vpnAddrs[0]) + if targetAddr.Is4() { + relayFrom = h.vpnAddrs[0] + } else { + //todo do this smarter + relayFrom = h.vpnAddrs[len(h.vpnAddrs)-1] + } + resp.RelayFromAddr = netAddrToProtoAddr(relayFrom) resp.RelayToAddr = target } @@ -360,7 +367,7 @@ func (rm *relayManager) handleCreateRelayRequest(v cert.Version, h *HostInfo, f Type: NebulaControl_CreateRelayRequest, InitiatorRelayIndex: index, } - + relayFrom := h.vpnAddrs[0] if v == cert.Version1 { if !h.vpnAddrs[0].Is4() { rm.l.WithField("relayFrom", h.vpnAddrs[0]). @@ -377,7 +384,13 @@ func (rm *relayManager) handleCreateRelayRequest(v cert.Version, h *HostInfo, f b = target.As4() req.OldRelayToAddr = binary.BigEndian.Uint32(b[:]) } else { - req.RelayFromAddr = netAddrToProtoAddr(h.vpnAddrs[0]) + if target.Is4() { + relayFrom = h.vpnAddrs[0] + } else { + //todo do this smarter + relayFrom = h.vpnAddrs[len(h.vpnAddrs)-1] + } + req.RelayFromAddr = netAddrToProtoAddr(relayFrom) req.RelayToAddr = netAddrToProtoAddr(target) } @@ -388,7 +401,7 @@ func (rm *relayManager) handleCreateRelayRequest(v cert.Version, h *HostInfo, f } else { f.SendMessageToHostInfo(header.Control, 0, peer, msg, make([]byte, 12), make([]byte, mtu)) rm.l.WithFields(logrus.Fields{ - "relayFrom": h.vpnAddrs[0], + "relayFrom": relayFrom, "relayTo": target, "initiatorRelayIndex": req.InitiatorRelayIndex, "responderRelayIndex": req.ResponderRelayIndex,