diff --git a/e2e/helpers_test.go b/e2e/helpers_test.go index a63b3d0..406a037 100644 --- a/e2e/helpers_test.go +++ b/e2e/helpers_test.go @@ -29,8 +29,6 @@ type m = map[string]any // newSimpleServer creates a nebula instance with many assumptions func newSimpleServer(v cert.Version, caCrt cert.Certificate, caKey []byte, name string, sVpnNetworks string, overrides m) (*nebula.Control, []netip.Prefix, netip.AddrPort, *config.C) { - l := NewTestLogger() - var vpnNetworks []netip.Prefix for _, sn := range strings.Split(sVpnNetworks, ",") { vpnIpNet, err := netip.ParsePrefix(strings.TrimSpace(sn)) @@ -56,6 +54,25 @@ func newSimpleServer(v cert.Version, caCrt cert.Certificate, caKey []byte, name budpIp[3] = 239 udpAddr = netip.AddrPortFrom(netip.AddrFrom16(budpIp), 4242) } + return newSimpleServerWithUdp(v, caCrt, caKey, name, sVpnNetworks, udpAddr, overrides) +} + +func newSimpleServerWithUdp(v cert.Version, caCrt cert.Certificate, caKey []byte, name string, sVpnNetworks string, udpAddr netip.AddrPort, overrides m) (*nebula.Control, []netip.Prefix, netip.AddrPort, *config.C) { + l := NewTestLogger() + + var vpnNetworks []netip.Prefix + for _, sn := range strings.Split(sVpnNetworks, ",") { + vpnIpNet, err := netip.ParsePrefix(strings.TrimSpace(sn)) + if err != nil { + panic(err) + } + vpnNetworks = append(vpnNetworks, vpnIpNet) + } + + if len(vpnNetworks) == 0 { + panic("no vpn networks") + } + _, _, myPrivKey, myPEM := cert_test.NewTestCert(v, cert.Curve_CURVE25519, caCrt, caKey, name, time.Now(), time.Now().Add(5*time.Minute), vpnNetworks, nil, []string{}) caB, err := caCrt.MarshalPEM() diff --git a/e2e/tunnels_test.go b/e2e/tunnels_test.go index 55974f0..5e23bfd 100644 --- a/e2e/tunnels_test.go +++ b/e2e/tunnels_test.go @@ -4,6 +4,7 @@ package e2e import ( + "net/netip" "testing" "time" @@ -55,3 +56,50 @@ func TestDropInactiveTunnels(t *testing.T) { myControl.Stop() theirControl.Stop() } + +func TestCrossStackRelaysWork(t *testing.T) { + ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version2, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{}) + myControl, myVpnIpNet, _, _ := newSimpleServer(cert.Version2, ca, caKey, "me ", "10.128.0.1/24,fc00::1/64", m{"relay": m{"use_relays": true}}) + relayControl, relayVpnIpNet, relayUdpAddr, _ := newSimpleServer(cert.Version2, ca, caKey, "relay ", "10.128.0.128/24,fc00::128/64", m{"relay": m{"am_relay": true}}) + theirUdp := netip.MustParseAddrPort("10.0.0.2:4242") + theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServerWithUdp(cert.Version2, ca, caKey, "them ", "fc00::2/64", theirUdp, m{"relay": m{"use_relays": true}}) + + //myVpnV4 := myVpnIpNet[0] + myVpnV6 := myVpnIpNet[1] + relayVpnV4 := relayVpnIpNet[0] + relayVpnV6 := relayVpnIpNet[1] + theirVpnV6 := theirVpnIpNet[0] + + // Teach my how to get to the relay and that their can be reached via the relay + myControl.InjectLightHouseAddr(relayVpnV4.Addr(), relayUdpAddr) + myControl.InjectLightHouseAddr(relayVpnV6.Addr(), relayUdpAddr) + myControl.InjectRelays(theirVpnV6.Addr(), []netip.Addr{relayVpnV6.Addr()}) + relayControl.InjectLightHouseAddr(theirVpnV6.Addr(), theirUdpAddr) + + // Build a router so we don't have to reason who gets which packet + r := router.NewR(t, myControl, relayControl, theirControl) + defer r.RenderFlow() + + // Start the servers + myControl.Start() + relayControl.Start() + theirControl.Start() + + t.Log("Trigger a handshake from me to them via the relay") + myControl.InjectTunUDPPacket(theirVpnV6.Addr(), 80, myVpnV6.Addr(), 80, []byte("Hi from me")) + + p := r.RouteForAllUntilTxTun(theirControl) + r.Log("Assert the tunnel works") + assertUdpPacket(t, []byte("Hi from me"), p, myVpnV6.Addr(), theirVpnV6.Addr(), 80, 80) + + t.Log("reply?") + theirControl.InjectTunUDPPacket(myVpnV6.Addr(), 80, theirVpnV6.Addr(), 80, []byte("Hi from them")) + p = r.RouteForAllUntilTxTun(myControl) + assertUdpPacket(t, []byte("Hi from them"), p, theirVpnV6.Addr(), myVpnV6.Addr(), 80, 80) + + r.RenderHostmaps("Final hostmaps", myControl, relayControl, theirControl) + //t.Log("finish up") + //myControl.Stop() + //theirControl.Stop() + //relayControl.Stop() +} diff --git a/firewall.go b/firewall.go index 971c156..3359082 100644 --- a/firewall.go +++ b/firewall.go @@ -429,6 +429,7 @@ func (f *Firewall) Drop(fp firewall.Packet, incoming bool, h *HostInfo, caPool * return nil } + // TODO if we don't have a network in common with this packet's source IP, (and it's not for an unsafe_network), do we reject it? // Make sure remote address matches nebula certificate if h.networks != nil { if !h.networks.Contains(fp.RemoteAddr) { diff --git a/handshake_ix.go b/handshake_ix.go index 026bfbd..752b0a7 100644 --- a/handshake_ix.go +++ b/handshake_ix.go @@ -183,17 +183,18 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet return } - var vpnAddrs []netip.Addr - var filteredNetworks []netip.Prefix certName := remoteCert.Certificate.Name() certVersion := remoteCert.Certificate.Version() fingerprint := remoteCert.Fingerprint issuer := remoteCert.Certificate.Issuer() + vpnNetworks := remoteCert.Certificate.Networks() - for _, network := range remoteCert.Certificate.Networks() { + anyVpnAddrsInCommon := false + vpnAddrs := make([]netip.Addr, len(vpnNetworks)) + for i, network := range vpnNetworks { vpnAddr := network.Addr() if f.myVpnAddrsTable.Contains(vpnAddr) { - f.l.WithField("vpnAddr", vpnAddr).WithField("udpAddr", addr). + f.l.WithField("vpnNetworks", vpnNetworks).WithField("udpAddr", addr). WithField("certName", certName). WithField("certVersion", certVersion). WithField("fingerprint", fingerprint). @@ -201,18 +202,15 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).Error("Refusing to handshake with myself") return } - - // vpnAddrs outside our vpn networks are of no use to us, filter them out - if !f.myVpnNetworksTable.Contains(vpnAddr) { - continue + vpnAddrs[i] = network.Addr() + if f.myVpnNetworksTable.Contains(vpnAddr) { + anyVpnAddrsInCommon = true } - - filteredNetworks = append(filteredNetworks, network) - vpnAddrs = append(vpnAddrs, vpnAddr) } - if len(vpnAddrs) == 0 { - f.l.WithError(err).WithField("udpAddr", addr). + if !anyVpnAddrsInCommon { + f.l.WithField("vpnNetworks", vpnNetworks). + WithField("udpAddr", addr). WithField("certName", certName). WithField("certVersion", certVersion). WithField("fingerprint", fingerprint). @@ -332,7 +330,7 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet hostinfo.remotes = f.lightHouse.QueryCache(vpnAddrs) hostinfo.SetRemote(addr) - hostinfo.buildNetworks(filteredNetworks, remoteCert.Certificate.UnsafeNetworks()) + hostinfo.buildNetworks(remoteCert.Certificate.Networks(), remoteCert.Certificate.UnsafeNetworks()) existing, err := f.handshakeManager.CheckAndComplete(hostinfo, 0, f) if err != nil { @@ -573,20 +571,16 @@ func ixHandshakeStage2(f *Interface, addr netip.AddrPort, via *ViaSender, hh *Ha hostinfo.relayState.InsertRelayTo(via.relayHI.vpnAddrs[0]) } - var vpnAddrs []netip.Addr - var filteredNetworks []netip.Prefix - for _, network := range vpnNetworks { - // vpnAddrs outside our vpn networks are of no use to us, filter them out - vpnAddr := network.Addr() - if !f.myVpnNetworksTable.Contains(vpnAddr) { - continue + anyVpnAddrsInCommon := false + vpnAddrs := make([]netip.Addr, len(vpnNetworks)) + for i, network := range vpnNetworks { + vpnAddrs[i] = network.Addr() + if f.myVpnNetworksTable.Contains(network.Addr()) { + anyVpnAddrsInCommon = true } - - filteredNetworks = append(filteredNetworks, network) - vpnAddrs = append(vpnAddrs, vpnAddr) } - if len(vpnAddrs) == 0 { + if !anyVpnAddrsInCommon { f.l.WithError(err).WithField("udpAddr", addr). WithField("certName", certName). WithField("certVersion", certVersion). @@ -609,6 +603,7 @@ func ixHandshakeStage2(f *Interface, addr netip.AddrPort, via *ViaSender, hh *Ha f.handshakeManager.DeleteHostInfo(hostinfo) // Create a new hostinfo/handshake for the intended vpn ip + //TODO is hostinfo.vpnAddrs[0] always the address to use? f.handshakeManager.StartHandshake(hostinfo.vpnAddrs[0], func(newHH *HandshakeHostInfo) { // Block the current used address newHH.hostinfo.remotes = hostinfo.remotes @@ -648,7 +643,7 @@ func ixHandshakeStage2(f *Interface, addr netip.AddrPort, via *ViaSender, hh *Ha // Build up the radix for the firewall if we have subnets in the cert hostinfo.vpnAddrs = vpnAddrs - hostinfo.buildNetworks(filteredNetworks, remoteCert.Certificate.UnsafeNetworks()) + hostinfo.buildNetworks(remoteCert.Certificate.Networks(), remoteCert.Certificate.UnsafeNetworks()) // Complete our handshake and update metrics, this will replace any existing tunnels for the vpnAddrs here f.handshakeManager.Complete(hostinfo, f) diff --git a/hostmap.go b/hostmap.go index 66b4851..ea221a8 100644 --- a/hostmap.go +++ b/hostmap.go @@ -220,12 +220,10 @@ type HostInfo struct { remoteIndexId uint32 localIndexId uint32 - // vpnAddrs is a list of vpn addresses assigned to this host that are within our own vpn networks - // The host may have other vpn addresses that are outside our - // vpn networks but were removed because they are not usable + // vpnAddrs is a list of vpn addresses assigned to this host vpnAddrs []netip.Addr - // networks are both all vpn and unsafe networks assigned to this host + // networks is a combination of specific vpn addresses (not prefixes!) and full unsafe networks assigned to this host. networks *bart.Lite relayState RelayState