mirror of
https://github.com/slackhq/nebula.git
synced 2025-11-22 16:34:25 +01:00
Compare commits
6 Commits
jay.wren-w
...
cross-stac
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
f597aa71e3 | ||
|
|
20b7219fbe | ||
|
|
3b53c27170 | ||
|
|
526236c5fa | ||
|
|
0ab2882b78 | ||
|
|
889d49ff82 |
@@ -114,33 +114,6 @@ func NewTestCert(v cert.Version, curve cert.Curve, ca cert.Certificate, key []by
|
||||
return c, pub, cert.MarshalPrivateKeyToPEM(curve, priv), pem
|
||||
}
|
||||
|
||||
func NewTestCertDifferentVersion(c cert.Certificate, v cert.Version, ca cert.Certificate, key []byte) (cert.Certificate, []byte) {
|
||||
nc := &cert.TBSCertificate{
|
||||
Version: v,
|
||||
Curve: c.Curve(),
|
||||
Name: c.Name(),
|
||||
Networks: c.Networks(),
|
||||
UnsafeNetworks: c.UnsafeNetworks(),
|
||||
Groups: c.Groups(),
|
||||
NotBefore: time.Unix(c.NotBefore().Unix(), 0),
|
||||
NotAfter: time.Unix(c.NotAfter().Unix(), 0),
|
||||
PublicKey: c.PublicKey(),
|
||||
IsCA: false,
|
||||
}
|
||||
|
||||
c, err := nc.Sign(ca, ca.Curve(), key)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
pem, err := c.MarshalPEM()
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
return c, pem
|
||||
}
|
||||
|
||||
func X25519Keypair() ([]byte, []byte) {
|
||||
privkey := make([]byte, 32)
|
||||
if _, err := io.ReadFull(rand.Reader, privkey); err != nil {
|
||||
|
||||
@@ -354,6 +354,7 @@ func (cm *connectionManager) makeTrafficDecision(localIndex uint32, now time.Tim
|
||||
|
||||
if mainHostInfo {
|
||||
decision = tryRehandshake
|
||||
|
||||
} else {
|
||||
if cm.shouldSwapPrimary(hostinfo) {
|
||||
decision = swapPrimary
|
||||
@@ -460,10 +461,6 @@ func (cm *connectionManager) shouldSwapPrimary(current *HostInfo) bool {
|
||||
}
|
||||
|
||||
crt := cm.intf.pki.getCertState().getCertificate(current.ConnectionState.myCert.Version())
|
||||
if crt == nil {
|
||||
//my cert was reloaded away. We should definitely swap from this tunnel
|
||||
return true
|
||||
}
|
||||
// If this tunnel is using the latest certificate then we should swap it to primary for a bit and see if things
|
||||
// settle down.
|
||||
return bytes.Equal(current.ConnectionState.myCert.Signature(), crt.Signature())
|
||||
@@ -478,34 +475,31 @@ func (cm *connectionManager) swapPrimary(current, primary *HostInfo) {
|
||||
cm.hostMap.Unlock()
|
||||
}
|
||||
|
||||
// isInvalidCertificate decides if we should destroy a tunnel.
|
||||
// returns true if pki.disconnect_invalid is true and the certificate is no longer valid.
|
||||
// Blocklisted certificates will skip the pki.disconnect_invalid check and return true.
|
||||
// isInvalidCertificate will check if we should destroy a tunnel if pki.disconnect_invalid is true and
|
||||
// the certificate is no longer valid. Block listed certificates will skip the pki.disconnect_invalid
|
||||
// check and return true.
|
||||
func (cm *connectionManager) isInvalidCertificate(now time.Time, hostinfo *HostInfo) bool {
|
||||
remoteCert := hostinfo.GetCert()
|
||||
if remoteCert == nil {
|
||||
return false //don't tear down tunnels for handshakes in progress
|
||||
return false
|
||||
}
|
||||
|
||||
caPool := cm.intf.pki.GetCAPool()
|
||||
err := caPool.VerifyCachedCertificate(now, remoteCert)
|
||||
if err == nil {
|
||||
return false //cert is still valid! yay!
|
||||
} else if err == cert.ErrBlockListed { //avoiding errors.Is for speed
|
||||
// Block listed certificates should always be disconnected
|
||||
hostinfo.logger(cm.l).WithError(err).
|
||||
WithField("fingerprint", remoteCert.Fingerprint).
|
||||
Info("Remote certificate is blocked, tearing down the tunnel")
|
||||
return true
|
||||
} else if cm.intf.disconnectInvalid.Load() {
|
||||
hostinfo.logger(cm.l).WithError(err).
|
||||
WithField("fingerprint", remoteCert.Fingerprint).
|
||||
Info("Remote certificate is no longer valid, tearing down the tunnel")
|
||||
return true
|
||||
} else {
|
||||
//if we reach here, the cert is no longer valid, but we're configured to keep tunnels from now-invalid certs open
|
||||
return false
|
||||
}
|
||||
|
||||
if !cm.intf.disconnectInvalid.Load() && err != cert.ErrBlockListed {
|
||||
// Block listed certificates should always be disconnected
|
||||
return false
|
||||
}
|
||||
|
||||
hostinfo.logger(cm.l).WithError(err).
|
||||
WithField("fingerprint", remoteCert.Fingerprint).
|
||||
Info("Remote certificate is no longer valid, tearing down the tunnel")
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
func (cm *connectionManager) sendPunch(hostinfo *HostInfo) {
|
||||
@@ -536,45 +530,15 @@ func (cm *connectionManager) sendPunch(hostinfo *HostInfo) {
|
||||
func (cm *connectionManager) tryRehandshake(hostinfo *HostInfo) {
|
||||
cs := cm.intf.pki.getCertState()
|
||||
curCrt := hostinfo.ConnectionState.myCert
|
||||
curCrtVersion := curCrt.Version()
|
||||
myCrt := cs.getCertificate(curCrtVersion)
|
||||
if myCrt == nil {
|
||||
cm.l.WithField("vpnAddrs", hostinfo.vpnAddrs).
|
||||
WithField("version", curCrtVersion).
|
||||
WithField("reason", "local certificate removed").
|
||||
Info("Re-handshaking with remote")
|
||||
cm.intf.handshakeManager.StartHandshake(hostinfo.vpnAddrs[0], nil)
|
||||
myCrt := cs.getCertificate(curCrt.Version())
|
||||
if curCrt.Version() >= cs.initiatingVersion && bytes.Equal(curCrt.Signature(), myCrt.Signature()) == true {
|
||||
// The current tunnel is using the latest certificate and version, no need to rehandshake.
|
||||
return
|
||||
}
|
||||
peerCrt := hostinfo.ConnectionState.peerCert
|
||||
if peerCrt != nil && curCrtVersion < peerCrt.Certificate.Version() {
|
||||
// if our certificate version is less than theirs, and we have a matching version available, rehandshake?
|
||||
if cs.getCertificate(peerCrt.Certificate.Version()) != nil {
|
||||
cm.l.WithField("vpnAddrs", hostinfo.vpnAddrs).
|
||||
WithField("version", curCrtVersion).
|
||||
WithField("peerVersion", peerCrt.Certificate.Version()).
|
||||
WithField("reason", "local certificate version lower than peer, attempting to correct").
|
||||
Info("Re-handshaking with remote")
|
||||
cm.intf.handshakeManager.StartHandshake(hostinfo.vpnAddrs[0], func(hh *HandshakeHostInfo) {
|
||||
hh.initiatingVersionOverride = peerCrt.Certificate.Version()
|
||||
})
|
||||
return
|
||||
}
|
||||
}
|
||||
if !bytes.Equal(curCrt.Signature(), myCrt.Signature()) {
|
||||
cm.l.WithField("vpnAddrs", hostinfo.vpnAddrs).
|
||||
WithField("reason", "local certificate is not current").
|
||||
Info("Re-handshaking with remote")
|
||||
|
||||
cm.intf.handshakeManager.StartHandshake(hostinfo.vpnAddrs[0], nil)
|
||||
return
|
||||
}
|
||||
if curCrtVersion < cs.initiatingVersion {
|
||||
cm.l.WithField("vpnAddrs", hostinfo.vpnAddrs).
|
||||
WithField("reason", "current cert version < pki.initiatingVersion").
|
||||
Info("Re-handshaking with remote")
|
||||
cm.l.WithField("vpnAddrs", hostinfo.vpnAddrs).
|
||||
WithField("reason", "local certificate is not current").
|
||||
Info("Re-handshaking with remote")
|
||||
|
||||
cm.intf.handshakeManager.StartHandshake(hostinfo.vpnAddrs[0], nil)
|
||||
return
|
||||
}
|
||||
cm.intf.handshakeManager.StartHandshake(hostinfo.vpnAddrs[0], nil)
|
||||
}
|
||||
|
||||
@@ -29,8 +29,6 @@ type m = map[string]any
|
||||
|
||||
// newSimpleServer creates a nebula instance with many assumptions
|
||||
func newSimpleServer(v cert.Version, caCrt cert.Certificate, caKey []byte, name string, sVpnNetworks string, overrides m) (*nebula.Control, []netip.Prefix, netip.AddrPort, *config.C) {
|
||||
l := NewTestLogger()
|
||||
|
||||
var vpnNetworks []netip.Prefix
|
||||
for _, sn := range strings.Split(sVpnNetworks, ",") {
|
||||
vpnIpNet, err := netip.ParsePrefix(strings.TrimSpace(sn))
|
||||
@@ -56,6 +54,25 @@ func newSimpleServer(v cert.Version, caCrt cert.Certificate, caKey []byte, name
|
||||
budpIp[3] = 239
|
||||
udpAddr = netip.AddrPortFrom(netip.AddrFrom16(budpIp), 4242)
|
||||
}
|
||||
return newSimpleServerWithUdp(v, caCrt, caKey, name, sVpnNetworks, udpAddr, overrides)
|
||||
}
|
||||
|
||||
func newSimpleServerWithUdp(v cert.Version, caCrt cert.Certificate, caKey []byte, name string, sVpnNetworks string, udpAddr netip.AddrPort, overrides m) (*nebula.Control, []netip.Prefix, netip.AddrPort, *config.C) {
|
||||
l := NewTestLogger()
|
||||
|
||||
var vpnNetworks []netip.Prefix
|
||||
for _, sn := range strings.Split(sVpnNetworks, ",") {
|
||||
vpnIpNet, err := netip.ParsePrefix(strings.TrimSpace(sn))
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
vpnNetworks = append(vpnNetworks, vpnIpNet)
|
||||
}
|
||||
|
||||
if len(vpnNetworks) == 0 {
|
||||
panic("no vpn networks")
|
||||
}
|
||||
|
||||
_, _, myPrivKey, myPEM := cert_test.NewTestCert(v, cert.Curve_CURVE25519, caCrt, caKey, name, time.Now(), time.Now().Add(5*time.Minute), vpnNetworks, nil, []string{})
|
||||
|
||||
caB, err := caCrt.MarshalPEM()
|
||||
@@ -129,109 +146,6 @@ func newSimpleServer(v cert.Version, caCrt cert.Certificate, caKey []byte, name
|
||||
return control, vpnNetworks, udpAddr, c
|
||||
}
|
||||
|
||||
// newServer creates a nebula instance with fewer assumptions
|
||||
func newServer(caCrt []cert.Certificate, certs []cert.Certificate, key []byte, overrides m) (*nebula.Control, []netip.Prefix, netip.AddrPort, *config.C) {
|
||||
l := NewTestLogger()
|
||||
|
||||
vpnNetworks := certs[len(certs)-1].Networks()
|
||||
|
||||
var udpAddr netip.AddrPort
|
||||
if vpnNetworks[0].Addr().Is4() {
|
||||
budpIp := vpnNetworks[0].Addr().As4()
|
||||
budpIp[1] -= 128
|
||||
udpAddr = netip.AddrPortFrom(netip.AddrFrom4(budpIp), 4242)
|
||||
} else {
|
||||
budpIp := vpnNetworks[0].Addr().As16()
|
||||
// beef for funsies
|
||||
budpIp[2] = 190
|
||||
budpIp[3] = 239
|
||||
udpAddr = netip.AddrPortFrom(netip.AddrFrom16(budpIp), 4242)
|
||||
}
|
||||
|
||||
caStr := ""
|
||||
for _, ca := range caCrt {
|
||||
x, err := ca.MarshalPEM()
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
caStr += string(x)
|
||||
}
|
||||
certStr := ""
|
||||
for _, c := range certs {
|
||||
x, err := c.MarshalPEM()
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
certStr += string(x)
|
||||
}
|
||||
|
||||
mc := m{
|
||||
"pki": m{
|
||||
"ca": caStr,
|
||||
"cert": certStr,
|
||||
"key": string(key),
|
||||
},
|
||||
//"tun": m{"disabled": true},
|
||||
"firewall": m{
|
||||
"outbound": []m{{
|
||||
"proto": "any",
|
||||
"port": "any",
|
||||
"host": "any",
|
||||
}},
|
||||
"inbound": []m{{
|
||||
"proto": "any",
|
||||
"port": "any",
|
||||
"host": "any",
|
||||
}},
|
||||
},
|
||||
//"handshakes": m{
|
||||
// "try_interval": "1s",
|
||||
//},
|
||||
"listen": m{
|
||||
"host": udpAddr.Addr().String(),
|
||||
"port": udpAddr.Port(),
|
||||
},
|
||||
"logging": m{
|
||||
"timestamp_format": fmt.Sprintf("%v 15:04:05.000000", certs[0].Name()),
|
||||
"level": l.Level.String(),
|
||||
},
|
||||
"timers": m{
|
||||
"pending_deletion_interval": 2,
|
||||
"connection_alive_interval": 2,
|
||||
},
|
||||
}
|
||||
|
||||
if overrides != nil {
|
||||
final := m{}
|
||||
err := mergo.Merge(&final, overrides, mergo.WithAppendSlice)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
err = mergo.Merge(&final, mc, mergo.WithAppendSlice)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
mc = final
|
||||
}
|
||||
|
||||
cb, err := yaml.Marshal(mc)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
c := config.NewC(l)
|
||||
cStr := string(cb)
|
||||
c.LoadString(cStr)
|
||||
|
||||
control, err := nebula.Main(c, false, "e2e-test", l, nil)
|
||||
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
return control, vpnNetworks, udpAddr, c
|
||||
}
|
||||
|
||||
type doneCb func()
|
||||
|
||||
func deadline(t *testing.T, seconds time.Duration) doneCb {
|
||||
|
||||
@@ -4,7 +4,6 @@
|
||||
package e2e
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/netip"
|
||||
"testing"
|
||||
"time"
|
||||
@@ -12,8 +11,6 @@ import (
|
||||
"github.com/slackhq/nebula/cert"
|
||||
"github.com/slackhq/nebula/cert_test"
|
||||
"github.com/slackhq/nebula/e2e/router"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"gopkg.in/yaml.v3"
|
||||
)
|
||||
|
||||
func TestDropInactiveTunnels(t *testing.T) {
|
||||
@@ -60,261 +57,49 @@ func TestDropInactiveTunnels(t *testing.T) {
|
||||
theirControl.Stop()
|
||||
}
|
||||
|
||||
func TestCertUpgrade(t *testing.T) {
|
||||
// The goal of this test is to ensure the shortest inactivity timeout will close the tunnel on both sides
|
||||
// under ideal conditions
|
||||
ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version1, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
|
||||
caB, err := ca.MarshalPEM()
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
ca2, _, caKey2, _ := cert_test.NewTestCaCert(cert.Version2, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
|
||||
func TestCrossStackRelaysWork(t *testing.T) {
|
||||
ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version2, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
|
||||
myControl, myVpnIpNet, _, _ := newSimpleServer(cert.Version2, ca, caKey, "me ", "10.128.0.1/24,fc00::1/64", m{"relay": m{"use_relays": true}})
|
||||
relayControl, relayVpnIpNet, relayUdpAddr, _ := newSimpleServer(cert.Version2, ca, caKey, "relay ", "10.128.0.128/24,fc00::128/64", m{"relay": m{"am_relay": true}})
|
||||
theirUdp := netip.MustParseAddrPort("10.0.0.2:4242")
|
||||
theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServerWithUdp(cert.Version2, ca, caKey, "them ", "fc00::2/64", theirUdp, m{"relay": m{"use_relays": true}})
|
||||
|
||||
ca2B, err := ca2.MarshalPEM()
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
caStr := fmt.Sprintf("%s\n%s", caB, ca2B)
|
||||
//myVpnV4 := myVpnIpNet[0]
|
||||
myVpnV6 := myVpnIpNet[1]
|
||||
relayVpnV4 := relayVpnIpNet[0]
|
||||
relayVpnV6 := relayVpnIpNet[1]
|
||||
theirVpnV6 := theirVpnIpNet[0]
|
||||
|
||||
myCert, _, myPrivKey, _ := cert_test.NewTestCert(cert.Version1, cert.Curve_CURVE25519, ca, caKey, "me", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{netip.MustParsePrefix("10.128.0.1/24")}, nil, []string{})
|
||||
_, myCert2Pem := cert_test.NewTestCertDifferentVersion(myCert, cert.Version2, ca2, caKey2)
|
||||
// Teach my how to get to the relay and that their can be reached via the relay
|
||||
myControl.InjectLightHouseAddr(relayVpnV4.Addr(), relayUdpAddr)
|
||||
myControl.InjectLightHouseAddr(relayVpnV6.Addr(), relayUdpAddr)
|
||||
myControl.InjectRelays(theirVpnV6.Addr(), []netip.Addr{relayVpnV6.Addr()})
|
||||
relayControl.InjectLightHouseAddr(theirVpnV6.Addr(), theirUdpAddr)
|
||||
|
||||
theirCert, _, theirPrivKey, _ := cert_test.NewTestCert(cert.Version1, cert.Curve_CURVE25519, ca, caKey, "them", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{netip.MustParsePrefix("10.128.0.2/24")}, nil, []string{})
|
||||
theirCert2, _ := cert_test.NewTestCertDifferentVersion(theirCert, cert.Version2, ca2, caKey2)
|
||||
|
||||
myControl, myVpnIpNet, myUdpAddr, myC := newServer([]cert.Certificate{ca, ca2}, []cert.Certificate{myCert}, myPrivKey, m{})
|
||||
theirControl, theirVpnIpNet, theirUdpAddr, _ := newServer([]cert.Certificate{ca, ca2}, []cert.Certificate{theirCert, theirCert2}, theirPrivKey, m{})
|
||||
|
||||
// Share our underlay information
|
||||
myControl.InjectLightHouseAddr(theirVpnIpNet[0].Addr(), theirUdpAddr)
|
||||
theirControl.InjectLightHouseAddr(myVpnIpNet[0].Addr(), myUdpAddr)
|
||||
// Build a router so we don't have to reason who gets which packet
|
||||
r := router.NewR(t, myControl, relayControl, theirControl)
|
||||
defer r.RenderFlow()
|
||||
|
||||
// Start the servers
|
||||
myControl.Start()
|
||||
relayControl.Start()
|
||||
theirControl.Start()
|
||||
|
||||
r := router.NewR(t, myControl, theirControl)
|
||||
defer r.RenderFlow()
|
||||
t.Log("Trigger a handshake from me to them via the relay")
|
||||
myControl.InjectTunUDPPacket(theirVpnV6.Addr(), 80, myVpnV6.Addr(), 80, []byte("Hi from me"))
|
||||
|
||||
r.Log("Assert the tunnel between me and them works")
|
||||
assertTunnel(t, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), myControl, theirControl, r)
|
||||
r.Log("yay")
|
||||
//todo ???
|
||||
time.Sleep(1 * time.Second)
|
||||
r.FlushAll()
|
||||
p := r.RouteForAllUntilTxTun(theirControl)
|
||||
r.Log("Assert the tunnel works")
|
||||
assertUdpPacket(t, []byte("Hi from me"), p, myVpnV6.Addr(), theirVpnV6.Addr(), 80, 80)
|
||||
|
||||
mc := m{
|
||||
"pki": m{
|
||||
"ca": caStr,
|
||||
"cert": string(myCert2Pem),
|
||||
"key": string(myPrivKey),
|
||||
},
|
||||
//"tun": m{"disabled": true},
|
||||
"firewall": myC.Settings["firewall"],
|
||||
//"handshakes": m{
|
||||
// "try_interval": "1s",
|
||||
//},
|
||||
"listen": myC.Settings["listen"],
|
||||
"logging": myC.Settings["logging"],
|
||||
"timers": myC.Settings["timers"],
|
||||
}
|
||||
t.Log("reply?")
|
||||
theirControl.InjectTunUDPPacket(myVpnV6.Addr(), 80, theirVpnV6.Addr(), 80, []byte("Hi from them"))
|
||||
p = r.RouteForAllUntilTxTun(myControl)
|
||||
assertUdpPacket(t, []byte("Hi from them"), p, theirVpnV6.Addr(), myVpnV6.Addr(), 80, 80)
|
||||
|
||||
cb, err := yaml.Marshal(mc)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
r.Logf("reload new v2-only config")
|
||||
err = myC.ReloadConfigString(string(cb))
|
||||
assert.NoError(t, err)
|
||||
r.Log("yay, spin until their sees it")
|
||||
waitStart := time.Now()
|
||||
for {
|
||||
assertTunnel(t, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), myControl, theirControl, r)
|
||||
c := theirControl.GetHostInfoByVpnAddr(myVpnIpNet[0].Addr(), false)
|
||||
if c == nil {
|
||||
r.Log("nil")
|
||||
} else {
|
||||
version := c.Cert.Version()
|
||||
r.Logf("version %d", version)
|
||||
if version == cert.Version2 {
|
||||
break
|
||||
}
|
||||
}
|
||||
since := time.Since(waitStart)
|
||||
if since > time.Second*10 {
|
||||
t.Fatal("Cert should be new by now")
|
||||
}
|
||||
time.Sleep(time.Second)
|
||||
}
|
||||
|
||||
r.RenderHostmaps("Final hostmaps", myControl, theirControl)
|
||||
|
||||
myControl.Stop()
|
||||
theirControl.Stop()
|
||||
}
|
||||
|
||||
func TestCertDowngrade(t *testing.T) {
|
||||
// The goal of this test is to ensure the shortest inactivity timeout will close the tunnel on both sides
|
||||
// under ideal conditions
|
||||
ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version1, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
|
||||
caB, err := ca.MarshalPEM()
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
ca2, _, caKey2, _ := cert_test.NewTestCaCert(cert.Version2, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
|
||||
|
||||
ca2B, err := ca2.MarshalPEM()
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
caStr := fmt.Sprintf("%s\n%s", caB, ca2B)
|
||||
|
||||
myCert, _, myPrivKey, myCertPem := cert_test.NewTestCert(cert.Version1, cert.Curve_CURVE25519, ca, caKey, "me", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{netip.MustParsePrefix("10.128.0.1/24")}, nil, []string{})
|
||||
myCert2, _ := cert_test.NewTestCertDifferentVersion(myCert, cert.Version2, ca2, caKey2)
|
||||
|
||||
theirCert, _, theirPrivKey, _ := cert_test.NewTestCert(cert.Version1, cert.Curve_CURVE25519, ca, caKey, "them", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{netip.MustParsePrefix("10.128.0.2/24")}, nil, []string{})
|
||||
theirCert2, _ := cert_test.NewTestCertDifferentVersion(theirCert, cert.Version2, ca2, caKey2)
|
||||
|
||||
myControl, myVpnIpNet, myUdpAddr, myC := newServer([]cert.Certificate{ca, ca2}, []cert.Certificate{myCert2}, myPrivKey, m{})
|
||||
theirControl, theirVpnIpNet, theirUdpAddr, _ := newServer([]cert.Certificate{ca, ca2}, []cert.Certificate{theirCert, theirCert2}, theirPrivKey, m{})
|
||||
|
||||
// Share our underlay information
|
||||
myControl.InjectLightHouseAddr(theirVpnIpNet[0].Addr(), theirUdpAddr)
|
||||
theirControl.InjectLightHouseAddr(myVpnIpNet[0].Addr(), myUdpAddr)
|
||||
|
||||
// Start the servers
|
||||
myControl.Start()
|
||||
theirControl.Start()
|
||||
|
||||
r := router.NewR(t, myControl, theirControl)
|
||||
defer r.RenderFlow()
|
||||
|
||||
r.Log("Assert the tunnel between me and them works")
|
||||
//assertTunnel(t, theirVpnIpNet[0].Addr(), myVpnIpNet[0].Addr(), theirControl, myControl, r)
|
||||
//r.Log("yay")
|
||||
assertTunnel(t, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), myControl, theirControl, r)
|
||||
r.Log("yay")
|
||||
//todo ???
|
||||
time.Sleep(1 * time.Second)
|
||||
r.FlushAll()
|
||||
|
||||
mc := m{
|
||||
"pki": m{
|
||||
"ca": caStr,
|
||||
"cert": string(myCertPem),
|
||||
"key": string(myPrivKey),
|
||||
},
|
||||
"firewall": myC.Settings["firewall"],
|
||||
"listen": myC.Settings["listen"],
|
||||
"logging": myC.Settings["logging"],
|
||||
"timers": myC.Settings["timers"],
|
||||
}
|
||||
|
||||
cb, err := yaml.Marshal(mc)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
r.Logf("reload new v1-only config")
|
||||
err = myC.ReloadConfigString(string(cb))
|
||||
assert.NoError(t, err)
|
||||
r.Log("yay, spin until their sees it")
|
||||
waitStart := time.Now()
|
||||
for {
|
||||
assertTunnel(t, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), myControl, theirControl, r)
|
||||
c := theirControl.GetHostInfoByVpnAddr(myVpnIpNet[0].Addr(), false)
|
||||
c2 := myControl.GetHostInfoByVpnAddr(theirVpnIpNet[0].Addr(), false)
|
||||
if c == nil || c2 == nil {
|
||||
r.Log("nil")
|
||||
} else {
|
||||
version := c.Cert.Version()
|
||||
theirVersion := c2.Cert.Version()
|
||||
r.Logf("version %d,%d", version, theirVersion)
|
||||
if version == cert.Version1 {
|
||||
break
|
||||
}
|
||||
}
|
||||
since := time.Since(waitStart)
|
||||
if since > time.Second*5 {
|
||||
r.Log("it is unusual that the cert is not new yet, but not a failure yet")
|
||||
}
|
||||
if since > time.Second*10 {
|
||||
r.Log("wtf")
|
||||
t.Fatal("Cert should be new by now")
|
||||
}
|
||||
time.Sleep(time.Second)
|
||||
}
|
||||
|
||||
r.RenderHostmaps("Final hostmaps", myControl, theirControl)
|
||||
|
||||
myControl.Stop()
|
||||
theirControl.Stop()
|
||||
}
|
||||
|
||||
func TestCertMismatchCorrection(t *testing.T) {
|
||||
// The goal of this test is to ensure the shortest inactivity timeout will close the tunnel on both sides
|
||||
// under ideal conditions
|
||||
ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version1, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
|
||||
ca2, _, caKey2, _ := cert_test.NewTestCaCert(cert.Version2, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
|
||||
|
||||
myCert, _, myPrivKey, _ := cert_test.NewTestCert(cert.Version1, cert.Curve_CURVE25519, ca, caKey, "me", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{netip.MustParsePrefix("10.128.0.1/24")}, nil, []string{})
|
||||
myCert2, _ := cert_test.NewTestCertDifferentVersion(myCert, cert.Version2, ca2, caKey2)
|
||||
|
||||
theirCert, _, theirPrivKey, _ := cert_test.NewTestCert(cert.Version1, cert.Curve_CURVE25519, ca, caKey, "them", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{netip.MustParsePrefix("10.128.0.2/24")}, nil, []string{})
|
||||
theirCert2, _ := cert_test.NewTestCertDifferentVersion(theirCert, cert.Version2, ca2, caKey2)
|
||||
|
||||
myControl, myVpnIpNet, myUdpAddr, _ := newServer([]cert.Certificate{ca, ca2}, []cert.Certificate{myCert2}, myPrivKey, m{})
|
||||
theirControl, theirVpnIpNet, theirUdpAddr, _ := newServer([]cert.Certificate{ca, ca2}, []cert.Certificate{theirCert, theirCert2}, theirPrivKey, m{})
|
||||
|
||||
// Share our underlay information
|
||||
myControl.InjectLightHouseAddr(theirVpnIpNet[0].Addr(), theirUdpAddr)
|
||||
theirControl.InjectLightHouseAddr(myVpnIpNet[0].Addr(), myUdpAddr)
|
||||
|
||||
// Start the servers
|
||||
myControl.Start()
|
||||
theirControl.Start()
|
||||
|
||||
r := router.NewR(t, myControl, theirControl)
|
||||
defer r.RenderFlow()
|
||||
|
||||
r.Log("Assert the tunnel between me and them works")
|
||||
//assertTunnel(t, theirVpnIpNet[0].Addr(), myVpnIpNet[0].Addr(), theirControl, myControl, r)
|
||||
//r.Log("yay")
|
||||
assertTunnel(t, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), myControl, theirControl, r)
|
||||
r.Log("yay")
|
||||
//todo ???
|
||||
time.Sleep(1 * time.Second)
|
||||
r.FlushAll()
|
||||
|
||||
waitStart := time.Now()
|
||||
for {
|
||||
assertTunnel(t, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), myControl, theirControl, r)
|
||||
c := theirControl.GetHostInfoByVpnAddr(myVpnIpNet[0].Addr(), false)
|
||||
c2 := myControl.GetHostInfoByVpnAddr(theirVpnIpNet[0].Addr(), false)
|
||||
if c == nil || c2 == nil {
|
||||
r.Log("nil")
|
||||
} else {
|
||||
version := c.Cert.Version()
|
||||
theirVersion := c2.Cert.Version()
|
||||
r.Logf("version %d,%d", version, theirVersion)
|
||||
if version == theirVersion {
|
||||
break
|
||||
}
|
||||
}
|
||||
since := time.Since(waitStart)
|
||||
if since > time.Second*5 {
|
||||
r.Log("wtf")
|
||||
}
|
||||
if since > time.Second*10 {
|
||||
r.Log("wtf")
|
||||
t.Fatal("Cert should be new by now")
|
||||
}
|
||||
time.Sleep(time.Second)
|
||||
}
|
||||
|
||||
r.RenderHostmaps("Final hostmaps", myControl, theirControl)
|
||||
|
||||
myControl.Stop()
|
||||
theirControl.Stop()
|
||||
r.RenderHostmaps("Final hostmaps", myControl, relayControl, theirControl)
|
||||
//t.Log("finish up")
|
||||
//myControl.Stop()
|
||||
//theirControl.Stop()
|
||||
//relayControl.Stop()
|
||||
}
|
||||
|
||||
29
firewall.go
29
firewall.go
@@ -417,6 +417,8 @@ func AddFirewallRulesFromConfig(l *logrus.Logger, inbound bool, c *config.C, fw
|
||||
return nil
|
||||
}
|
||||
|
||||
var ErrUnknownNetworkType = errors.New("unknown network type")
|
||||
var ErrPeerRejected = errors.New("remote IP is not within a subnet that we handle")
|
||||
var ErrInvalidRemoteIP = errors.New("remote IP is not in remote certificate subnets")
|
||||
var ErrInvalidLocalIP = errors.New("local IP is not in list of handled local IPs")
|
||||
var ErrNoMatchingRule = errors.New("no matching rule in firewall table")
|
||||
@@ -429,18 +431,31 @@ func (f *Firewall) Drop(fp firewall.Packet, incoming bool, h *HostInfo, caPool *
|
||||
return nil
|
||||
}
|
||||
|
||||
// Make sure remote address matches nebula certificate
|
||||
if h.networks != nil {
|
||||
if !h.networks.Contains(fp.RemoteAddr) {
|
||||
f.metrics(incoming).droppedRemoteAddr.Inc(1)
|
||||
return ErrInvalidRemoteIP
|
||||
}
|
||||
} else {
|
||||
// Make sure remote address matches nebula certificate, and determine how to treat it
|
||||
if h.networks == nil {
|
||||
// Simple case: Certificate has one address and no unsafe networks
|
||||
if h.vpnAddrs[0] != fp.RemoteAddr {
|
||||
f.metrics(incoming).droppedRemoteAddr.Inc(1)
|
||||
return ErrInvalidRemoteIP
|
||||
}
|
||||
} else {
|
||||
nwType, ok := h.networks.Lookup(fp.RemoteAddr)
|
||||
if !ok {
|
||||
f.metrics(incoming).droppedRemoteAddr.Inc(1)
|
||||
return ErrInvalidRemoteIP
|
||||
}
|
||||
switch nwType {
|
||||
case NetworkTypeVPN:
|
||||
break // nothing special
|
||||
case NetworkTypeVPNPeer:
|
||||
f.metrics(incoming).droppedRemoteAddr.Inc(1)
|
||||
return ErrPeerRejected // reject for now, one day this may have different FW rules
|
||||
case NetworkTypeUnsafe:
|
||||
break // nothing special, one day this may have different FW rules
|
||||
default:
|
||||
f.metrics(incoming).droppedRemoteAddr.Inc(1)
|
||||
return ErrUnknownNetworkType //should never happen
|
||||
}
|
||||
}
|
||||
|
||||
// Make sure we are supposed to be handling this local ip address
|
||||
|
||||
@@ -8,6 +8,7 @@ import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/gaissmai/bart"
|
||||
"github.com/slackhq/nebula/cert"
|
||||
"github.com/slackhq/nebula/config"
|
||||
"github.com/slackhq/nebula/firewall"
|
||||
@@ -149,7 +150,8 @@ func TestFirewall_Drop(t *testing.T) {
|
||||
l := test.NewLogger()
|
||||
ob := &bytes.Buffer{}
|
||||
l.SetOutput(ob)
|
||||
|
||||
myVpnNetworksTable := new(bart.Lite)
|
||||
myVpnNetworksTable.Insert(netip.MustParsePrefix("1.1.1.1/8"))
|
||||
p := firewall.Packet{
|
||||
LocalAddr: netip.MustParseAddr("1.2.3.4"),
|
||||
RemoteAddr: netip.MustParseAddr("1.2.3.4"),
|
||||
@@ -174,7 +176,7 @@ func TestFirewall_Drop(t *testing.T) {
|
||||
},
|
||||
vpnAddrs: []netip.Addr{netip.MustParseAddr("1.2.3.4")},
|
||||
}
|
||||
h.buildNetworks(c.networks, c.unsafeNetworks)
|
||||
h.buildNetworks(myVpnNetworksTable, c.networks, c.unsafeNetworks)
|
||||
|
||||
fw := NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
|
||||
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"any"}, "", netip.Prefix{}, netip.Prefix{}, "", ""))
|
||||
@@ -226,6 +228,9 @@ func TestFirewall_DropV6(t *testing.T) {
|
||||
ob := &bytes.Buffer{}
|
||||
l.SetOutput(ob)
|
||||
|
||||
myVpnNetworksTable := new(bart.Lite)
|
||||
myVpnNetworksTable.Insert(netip.MustParsePrefix("fd00::/7"))
|
||||
|
||||
p := firewall.Packet{
|
||||
LocalAddr: netip.MustParseAddr("fd12::34"),
|
||||
RemoteAddr: netip.MustParseAddr("fd12::34"),
|
||||
@@ -250,7 +255,7 @@ func TestFirewall_DropV6(t *testing.T) {
|
||||
},
|
||||
vpnAddrs: []netip.Addr{netip.MustParseAddr("fd12::34")},
|
||||
}
|
||||
h.buildNetworks(c.networks, c.unsafeNetworks)
|
||||
h.buildNetworks(myVpnNetworksTable, c.networks, c.unsafeNetworks)
|
||||
|
||||
fw := NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
|
||||
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"any"}, "", netip.Prefix{}, netip.Prefix{}, "", ""))
|
||||
@@ -453,6 +458,8 @@ func TestFirewall_Drop2(t *testing.T) {
|
||||
l := test.NewLogger()
|
||||
ob := &bytes.Buffer{}
|
||||
l.SetOutput(ob)
|
||||
myVpnNetworksTable := new(bart.Lite)
|
||||
myVpnNetworksTable.Insert(netip.MustParsePrefix("1.1.1.1/8"))
|
||||
|
||||
p := firewall.Packet{
|
||||
LocalAddr: netip.MustParseAddr("1.2.3.4"),
|
||||
@@ -478,7 +485,7 @@ func TestFirewall_Drop2(t *testing.T) {
|
||||
},
|
||||
vpnAddrs: []netip.Addr{network.Addr()},
|
||||
}
|
||||
h.buildNetworks(c.Certificate.Networks(), c.Certificate.UnsafeNetworks())
|
||||
h.buildNetworks(myVpnNetworksTable, c.Certificate.Networks(), c.Certificate.UnsafeNetworks())
|
||||
|
||||
c1 := cert.CachedCertificate{
|
||||
Certificate: &dummyCert{
|
||||
@@ -493,7 +500,7 @@ func TestFirewall_Drop2(t *testing.T) {
|
||||
peerCert: &c1,
|
||||
},
|
||||
}
|
||||
h1.buildNetworks(c1.Certificate.Networks(), c1.Certificate.UnsafeNetworks())
|
||||
h1.buildNetworks(myVpnNetworksTable, c1.Certificate.Networks(), c1.Certificate.UnsafeNetworks())
|
||||
|
||||
fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate)
|
||||
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group", "test-group"}, "", netip.Prefix{}, netip.Prefix{}, "", ""))
|
||||
@@ -510,6 +517,8 @@ func TestFirewall_Drop3(t *testing.T) {
|
||||
l := test.NewLogger()
|
||||
ob := &bytes.Buffer{}
|
||||
l.SetOutput(ob)
|
||||
myVpnNetworksTable := new(bart.Lite)
|
||||
myVpnNetworksTable.Insert(netip.MustParsePrefix("1.1.1.1/8"))
|
||||
|
||||
p := firewall.Packet{
|
||||
LocalAddr: netip.MustParseAddr("1.2.3.4"),
|
||||
@@ -541,7 +550,7 @@ func TestFirewall_Drop3(t *testing.T) {
|
||||
},
|
||||
vpnAddrs: []netip.Addr{network.Addr()},
|
||||
}
|
||||
h1.buildNetworks(c1.Certificate.Networks(), c1.Certificate.UnsafeNetworks())
|
||||
h1.buildNetworks(myVpnNetworksTable, c1.Certificate.Networks(), c1.Certificate.UnsafeNetworks())
|
||||
|
||||
c2 := cert.CachedCertificate{
|
||||
Certificate: &dummyCert{
|
||||
@@ -556,7 +565,7 @@ func TestFirewall_Drop3(t *testing.T) {
|
||||
},
|
||||
vpnAddrs: []netip.Addr{network.Addr()},
|
||||
}
|
||||
h2.buildNetworks(c2.Certificate.Networks(), c2.Certificate.UnsafeNetworks())
|
||||
h2.buildNetworks(myVpnNetworksTable, c2.Certificate.Networks(), c2.Certificate.UnsafeNetworks())
|
||||
|
||||
c3 := cert.CachedCertificate{
|
||||
Certificate: &dummyCert{
|
||||
@@ -571,7 +580,7 @@ func TestFirewall_Drop3(t *testing.T) {
|
||||
},
|
||||
vpnAddrs: []netip.Addr{network.Addr()},
|
||||
}
|
||||
h3.buildNetworks(c3.Certificate.Networks(), c3.Certificate.UnsafeNetworks())
|
||||
h3.buildNetworks(myVpnNetworksTable, c3.Certificate.Networks(), c3.Certificate.UnsafeNetworks())
|
||||
|
||||
fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate)
|
||||
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 1, 1, []string{}, "host1", netip.Prefix{}, netip.Prefix{}, "", ""))
|
||||
@@ -597,6 +606,8 @@ func TestFirewall_Drop3V6(t *testing.T) {
|
||||
l := test.NewLogger()
|
||||
ob := &bytes.Buffer{}
|
||||
l.SetOutput(ob)
|
||||
myVpnNetworksTable := new(bart.Lite)
|
||||
myVpnNetworksTable.Insert(netip.MustParsePrefix("fd00::/7"))
|
||||
|
||||
p := firewall.Packet{
|
||||
LocalAddr: netip.MustParseAddr("fd12::34"),
|
||||
@@ -620,7 +631,7 @@ func TestFirewall_Drop3V6(t *testing.T) {
|
||||
},
|
||||
vpnAddrs: []netip.Addr{network.Addr()},
|
||||
}
|
||||
h.buildNetworks(c.Certificate.Networks(), c.Certificate.UnsafeNetworks())
|
||||
h.buildNetworks(myVpnNetworksTable, c.Certificate.Networks(), c.Certificate.UnsafeNetworks())
|
||||
|
||||
// Test a remote address match
|
||||
fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate)
|
||||
@@ -633,6 +644,8 @@ func TestFirewall_DropConntrackReload(t *testing.T) {
|
||||
l := test.NewLogger()
|
||||
ob := &bytes.Buffer{}
|
||||
l.SetOutput(ob)
|
||||
myVpnNetworksTable := new(bart.Lite)
|
||||
myVpnNetworksTable.Insert(netip.MustParsePrefix("1.1.1.1/8"))
|
||||
|
||||
p := firewall.Packet{
|
||||
LocalAddr: netip.MustParseAddr("1.2.3.4"),
|
||||
@@ -659,7 +672,7 @@ func TestFirewall_DropConntrackReload(t *testing.T) {
|
||||
},
|
||||
vpnAddrs: []netip.Addr{network.Addr()},
|
||||
}
|
||||
h.buildNetworks(c.Certificate.Networks(), c.Certificate.UnsafeNetworks())
|
||||
h.buildNetworks(myVpnNetworksTable, c.Certificate.Networks(), c.Certificate.UnsafeNetworks())
|
||||
|
||||
fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate)
|
||||
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"any"}, "", netip.Prefix{}, netip.Prefix{}, "", ""))
|
||||
@@ -696,6 +709,8 @@ func TestFirewall_DropIPSpoofing(t *testing.T) {
|
||||
l := test.NewLogger()
|
||||
ob := &bytes.Buffer{}
|
||||
l.SetOutput(ob)
|
||||
myVpnNetworksTable := new(bart.Lite)
|
||||
myVpnNetworksTable.Insert(netip.MustParsePrefix("192.0.2.1/24"))
|
||||
|
||||
c := cert.CachedCertificate{
|
||||
Certificate: &dummyCert{
|
||||
@@ -717,7 +732,7 @@ func TestFirewall_DropIPSpoofing(t *testing.T) {
|
||||
},
|
||||
vpnAddrs: []netip.Addr{c1.Certificate.Networks()[0].Addr()},
|
||||
}
|
||||
h1.buildNetworks(c1.Certificate.Networks(), c1.Certificate.UnsafeNetworks())
|
||||
h1.buildNetworks(myVpnNetworksTable, c1.Certificate.Networks(), c1.Certificate.UnsafeNetworks())
|
||||
|
||||
fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate)
|
||||
|
||||
|
||||
149
handshake_ix.go
149
handshake_ix.go
@@ -23,17 +23,13 @@ func ixHandshakeStage0(f *Interface, hh *HandshakeHostInfo) bool {
|
||||
return false
|
||||
}
|
||||
|
||||
// If we're connecting to a v6 address we must use a v2 cert
|
||||
cs := f.pki.getCertState()
|
||||
v := cs.initiatingVersion
|
||||
if hh.initiatingVersionOverride != cert.VersionPre1 {
|
||||
v = hh.initiatingVersionOverride
|
||||
} else if v < cert.Version2 {
|
||||
// If we're connecting to a v6 address we should encourage use of a V2 cert
|
||||
for _, a := range hh.hostinfo.vpnAddrs {
|
||||
if a.Is6() {
|
||||
v = cert.Version2
|
||||
break
|
||||
}
|
||||
for _, a := range hh.hostinfo.vpnAddrs {
|
||||
if a.Is6() {
|
||||
v = cert.Version2
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
@@ -52,7 +48,6 @@ func ixHandshakeStage0(f *Interface, hh *HandshakeHostInfo) bool {
|
||||
WithField("handshake", m{"stage": 0, "style": "ix_psk0"}).
|
||||
WithField("certVersion", v).
|
||||
Error("Unable to handshake with host because no certificate handshake bytes is available")
|
||||
return false
|
||||
}
|
||||
|
||||
ci, err := NewConnectionState(f.l, cs, crt, true, noise.HandshakeIX)
|
||||
@@ -108,7 +103,6 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet
|
||||
WithField("handshake", m{"stage": 0, "style": "ix_psk0"}).
|
||||
WithField("certVersion", cs.initiatingVersion).
|
||||
Error("Unable to handshake with host because no certificate is available")
|
||||
return
|
||||
}
|
||||
|
||||
ci, err := NewConnectionState(f.l, cs, crt, false, noise.HandshakeIX)
|
||||
@@ -149,8 +143,8 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet
|
||||
|
||||
remoteCert, err := f.pki.GetCAPool().VerifyCertificate(time.Now(), rc)
|
||||
if err != nil {
|
||||
fp, fperr := rc.Fingerprint()
|
||||
if fperr != nil {
|
||||
fp, err := rc.Fingerprint()
|
||||
if err != nil {
|
||||
fp = "<error generating certificate fingerprint>"
|
||||
}
|
||||
|
||||
@@ -169,19 +163,16 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet
|
||||
|
||||
if remoteCert.Certificate.Version() != ci.myCert.Version() {
|
||||
// We started off using the wrong certificate version, lets see if we can match the version that was sent to us
|
||||
myCertOtherVersion := cs.getCertificate(remoteCert.Certificate.Version())
|
||||
if myCertOtherVersion == nil {
|
||||
if f.l.Level >= logrus.DebugLevel {
|
||||
f.l.WithError(err).WithFields(m{
|
||||
"udpAddr": addr,
|
||||
"handshake": m{"stage": 1, "style": "ix_psk0"},
|
||||
"cert": remoteCert,
|
||||
}).Debug("Might be unable to handshake with host due to missing certificate version")
|
||||
}
|
||||
} else {
|
||||
// Record the certificate we are actually using
|
||||
ci.myCert = myCertOtherVersion
|
||||
rc := cs.getCertificate(remoteCert.Certificate.Version())
|
||||
if rc == nil {
|
||||
f.l.WithError(err).WithField("udpAddr", addr).
|
||||
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).WithField("cert", remoteCert).
|
||||
Info("Unable to handshake with host due to missing certificate version")
|
||||
return
|
||||
}
|
||||
|
||||
// Record the certificate we are actually using
|
||||
ci.myCert = rc
|
||||
}
|
||||
|
||||
if len(remoteCert.Certificate.Networks()) == 0 {
|
||||
@@ -192,17 +183,18 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet
|
||||
return
|
||||
}
|
||||
|
||||
var vpnAddrs []netip.Addr
|
||||
var filteredNetworks []netip.Prefix
|
||||
certName := remoteCert.Certificate.Name()
|
||||
certVersion := remoteCert.Certificate.Version()
|
||||
fingerprint := remoteCert.Fingerprint
|
||||
issuer := remoteCert.Certificate.Issuer()
|
||||
vpnNetworks := remoteCert.Certificate.Networks()
|
||||
|
||||
for _, network := range remoteCert.Certificate.Networks() {
|
||||
anyVpnAddrsInCommon := false
|
||||
vpnAddrs := make([]netip.Addr, len(vpnNetworks))
|
||||
for i, network := range vpnNetworks {
|
||||
vpnAddr := network.Addr()
|
||||
if f.myVpnAddrsTable.Contains(vpnAddr) {
|
||||
f.l.WithField("vpnAddr", vpnAddr).WithField("udpAddr", addr).
|
||||
f.l.WithField("vpnNetworks", vpnNetworks).WithField("udpAddr", addr).
|
||||
WithField("certName", certName).
|
||||
WithField("certVersion", certVersion).
|
||||
WithField("fingerprint", fingerprint).
|
||||
@@ -210,24 +202,10 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet
|
||||
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).Error("Refusing to handshake with myself")
|
||||
return
|
||||
}
|
||||
|
||||
// vpnAddrs outside our vpn networks are of no use to us, filter them out
|
||||
if !f.myVpnNetworksTable.Contains(vpnAddr) {
|
||||
continue
|
||||
vpnAddrs[i] = network.Addr()
|
||||
if f.myVpnNetworksTable.Contains(vpnAddr) {
|
||||
anyVpnAddrsInCommon = true
|
||||
}
|
||||
|
||||
filteredNetworks = append(filteredNetworks, network)
|
||||
vpnAddrs = append(vpnAddrs, vpnAddr)
|
||||
}
|
||||
|
||||
if len(vpnAddrs) == 0 {
|
||||
f.l.WithError(err).WithField("udpAddr", addr).
|
||||
WithField("certName", certName).
|
||||
WithField("certVersion", certVersion).
|
||||
WithField("fingerprint", fingerprint).
|
||||
WithField("issuer", issuer).
|
||||
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).Error("No usable vpn addresses from host, refusing handshake")
|
||||
return
|
||||
}
|
||||
|
||||
if addr.IsValid() {
|
||||
@@ -264,26 +242,30 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet
|
||||
},
|
||||
}
|
||||
|
||||
f.l.WithField("vpnAddrs", vpnAddrs).WithField("udpAddr", addr).
|
||||
WithField("certName", certName).
|
||||
WithField("certVersion", certVersion).
|
||||
WithField("fingerprint", fingerprint).
|
||||
WithField("issuer", issuer).
|
||||
WithField("initiatorIndex", hs.Details.InitiatorIndex).WithField("responderIndex", hs.Details.ResponderIndex).
|
||||
WithField("remoteIndex", h.RemoteIndex).WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).
|
||||
Info("Handshake message received")
|
||||
msgRxL := f.l.WithFields(m{
|
||||
"vpnAddrs": vpnAddrs,
|
||||
"udpAddr": addr,
|
||||
"certName": certName,
|
||||
"certVersion": certVersion,
|
||||
"fingerprint": fingerprint,
|
||||
"issuer": issuer,
|
||||
"initiatorIndex": hs.Details.InitiatorIndex,
|
||||
"responderIndex": hs.Details.ResponderIndex,
|
||||
"remoteIndex": h.RemoteIndex,
|
||||
"handshake": m{"stage": 1, "style": "ix_psk0"},
|
||||
})
|
||||
|
||||
if anyVpnAddrsInCommon {
|
||||
msgRxL.Info("Handshake message received")
|
||||
} else {
|
||||
//todo warn if not lighthouse or relay?
|
||||
msgRxL.Info("Handshake message received, but no vpnNetworks in common.")
|
||||
}
|
||||
|
||||
hs.Details.ResponderIndex = myIndex
|
||||
hs.Details.Cert = cs.getHandshakeBytes(ci.myCert.Version())
|
||||
if hs.Details.Cert == nil {
|
||||
f.l.WithField("vpnAddrs", vpnAddrs).WithField("udpAddr", addr).
|
||||
WithField("certName", certName).
|
||||
WithField("certVersion", certVersion).
|
||||
WithField("fingerprint", fingerprint).
|
||||
WithField("issuer", issuer).
|
||||
WithField("initiatorIndex", hs.Details.InitiatorIndex).WithField("responderIndex", hs.Details.ResponderIndex).
|
||||
WithField("remoteIndex", h.RemoteIndex).WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).
|
||||
WithField("certVersion", ci.myCert.Version()).
|
||||
msgRxL.WithField("myCertVersion", ci.myCert.Version()).
|
||||
Error("Unable to handshake with host because no certificate handshake bytes is available")
|
||||
return
|
||||
}
|
||||
@@ -341,7 +323,7 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet
|
||||
|
||||
hostinfo.remotes = f.lightHouse.QueryCache(vpnAddrs)
|
||||
hostinfo.SetRemote(addr)
|
||||
hostinfo.buildNetworks(filteredNetworks, remoteCert.Certificate.UnsafeNetworks())
|
||||
hostinfo.buildNetworks(f.myVpnNetworksTable, remoteCert.Certificate.Networks(), remoteCert.Certificate.UnsafeNetworks())
|
||||
|
||||
existing, err := f.handshakeManager.CheckAndComplete(hostinfo, 0, f)
|
||||
if err != nil {
|
||||
@@ -582,30 +564,17 @@ func ixHandshakeStage2(f *Interface, addr netip.AddrPort, via *ViaSender, hh *Ha
|
||||
hostinfo.relayState.InsertRelayTo(via.relayHI.vpnAddrs[0])
|
||||
}
|
||||
|
||||
var vpnAddrs []netip.Addr
|
||||
var filteredNetworks []netip.Prefix
|
||||
for _, network := range vpnNetworks {
|
||||
// vpnAddrs outside our vpn networks are of no use to us, filter them out
|
||||
vpnAddr := network.Addr()
|
||||
if !f.myVpnNetworksTable.Contains(vpnAddr) {
|
||||
continue
|
||||
anyVpnAddrsInCommon := false
|
||||
vpnAddrs := make([]netip.Addr, len(vpnNetworks))
|
||||
for i, network := range vpnNetworks {
|
||||
vpnAddrs[i] = network.Addr()
|
||||
if f.myVpnNetworksTable.Contains(network.Addr()) {
|
||||
anyVpnAddrsInCommon = true
|
||||
}
|
||||
|
||||
filteredNetworks = append(filteredNetworks, network)
|
||||
vpnAddrs = append(vpnAddrs, vpnAddr)
|
||||
}
|
||||
|
||||
if len(vpnAddrs) == 0 {
|
||||
f.l.WithError(err).WithField("udpAddr", addr).
|
||||
WithField("certName", certName).
|
||||
WithField("certVersion", certVersion).
|
||||
WithField("fingerprint", fingerprint).
|
||||
WithField("issuer", issuer).
|
||||
WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).Error("No usable vpn addresses from host, refusing handshake")
|
||||
return true
|
||||
}
|
||||
|
||||
// Ensure the right host responded
|
||||
// todo is it more correct to see if any of hostinfo.vpnAddrs are in the cert? it should have len==1, but one day it might not?
|
||||
if !slices.Contains(vpnAddrs, hostinfo.vpnAddrs[0]) {
|
||||
f.l.WithField("intendedVpnAddrs", hostinfo.vpnAddrs).WithField("haveVpnNetworks", vpnNetworks).
|
||||
WithField("udpAddr", addr).
|
||||
@@ -618,6 +587,7 @@ func ixHandshakeStage2(f *Interface, addr netip.AddrPort, via *ViaSender, hh *Ha
|
||||
f.handshakeManager.DeleteHostInfo(hostinfo)
|
||||
|
||||
// Create a new hostinfo/handshake for the intended vpn ip
|
||||
//TODO is hostinfo.vpnAddrs[0] always the address to use?
|
||||
f.handshakeManager.StartHandshake(hostinfo.vpnAddrs[0], func(newHH *HandshakeHostInfo) {
|
||||
// Block the current used address
|
||||
newHH.hostinfo.remotes = hostinfo.remotes
|
||||
@@ -644,7 +614,7 @@ func ixHandshakeStage2(f *Interface, addr netip.AddrPort, via *ViaSender, hh *Ha
|
||||
ci.window.Update(f.l, 2)
|
||||
|
||||
duration := time.Since(hh.startTime).Nanoseconds()
|
||||
f.l.WithField("vpnAddrs", vpnAddrs).WithField("udpAddr", addr).
|
||||
msgRxL := f.l.WithField("vpnAddrs", vpnAddrs).WithField("udpAddr", addr).
|
||||
WithField("certName", certName).
|
||||
WithField("certVersion", certVersion).
|
||||
WithField("fingerprint", fingerprint).
|
||||
@@ -652,12 +622,17 @@ func ixHandshakeStage2(f *Interface, addr netip.AddrPort, via *ViaSender, hh *Ha
|
||||
WithField("initiatorIndex", hs.Details.InitiatorIndex).WithField("responderIndex", hs.Details.ResponderIndex).
|
||||
WithField("remoteIndex", h.RemoteIndex).WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).
|
||||
WithField("durationNs", duration).
|
||||
WithField("sentCachedPackets", len(hh.packetStore)).
|
||||
Info("Handshake message received")
|
||||
WithField("sentCachedPackets", len(hh.packetStore))
|
||||
if anyVpnAddrsInCommon {
|
||||
msgRxL.Info("Handshake message received")
|
||||
} else {
|
||||
//todo warn if not lighthouse or relay?
|
||||
msgRxL.Info("Handshake message received, but no vpnNetworks in common.")
|
||||
}
|
||||
|
||||
// Build up the radix for the firewall if we have subnets in the cert
|
||||
hostinfo.vpnAddrs = vpnAddrs
|
||||
hostinfo.buildNetworks(filteredNetworks, remoteCert.Certificate.UnsafeNetworks())
|
||||
hostinfo.buildNetworks(f.myVpnNetworksTable, remoteCert.Certificate.Networks(), remoteCert.Certificate.UnsafeNetworks())
|
||||
|
||||
// Complete our handshake and update metrics, this will replace any existing tunnels for the vpnAddrs here
|
||||
f.handshakeManager.Complete(hostinfo, f)
|
||||
|
||||
@@ -68,12 +68,11 @@ type HandshakeManager struct {
|
||||
type HandshakeHostInfo struct {
|
||||
sync.Mutex
|
||||
|
||||
startTime time.Time // Time that we first started trying with this handshake
|
||||
ready bool // Is the handshake ready
|
||||
initiatingVersionOverride cert.Version // Should we use a non-default cert version for this handshake?
|
||||
counter int64 // How many attempts have we made so far
|
||||
lastRemotes []netip.AddrPort // Remotes that we sent to during the previous attempt
|
||||
packetStore []*cachedPacket // A set of packets to be transmitted once the handshake completes
|
||||
startTime time.Time // Time that we first started trying with this handshake
|
||||
ready bool // Is the handshake ready
|
||||
counter int64 // How many attempts have we made so far
|
||||
lastRemotes []netip.AddrPort // Remotes that we sent to during the previous attempt
|
||||
packetStore []*cachedPacket // A set of packets to be transmitted once the handshake completes
|
||||
|
||||
hostinfo *HostInfo
|
||||
}
|
||||
|
||||
39
hostmap.go
39
hostmap.go
@@ -212,6 +212,18 @@ func (rs *RelayState) InsertRelay(ip netip.Addr, idx uint32, r *Relay) {
|
||||
rs.relayForByIdx[idx] = r
|
||||
}
|
||||
|
||||
type NetworkType uint8
|
||||
|
||||
const (
|
||||
NetworkTypeUnknown NetworkType = iota
|
||||
// NetworkTypeVPN is a network that overlaps one or more of the vpnNetworks in our certificate
|
||||
NetworkTypeVPN
|
||||
// NetworkTypeVPNPeer is a network that does not overlap one of our networks
|
||||
NetworkTypeVPNPeer
|
||||
// NetworkTypeUnsafe is a network from Certificate.UnsafeNetworks()
|
||||
NetworkTypeUnsafe
|
||||
)
|
||||
|
||||
type HostInfo struct {
|
||||
remote netip.AddrPort
|
||||
remotes *RemoteList
|
||||
@@ -220,13 +232,11 @@ type HostInfo struct {
|
||||
remoteIndexId uint32
|
||||
localIndexId uint32
|
||||
|
||||
// vpnAddrs is a list of vpn addresses assigned to this host that are within our own vpn networks
|
||||
// The host may have other vpn addresses that are outside our
|
||||
// vpn networks but were removed because they are not usable
|
||||
// vpnAddrs is a list of vpn addresses assigned to this host
|
||||
vpnAddrs []netip.Addr
|
||||
|
||||
// networks are both all vpn and unsafe networks assigned to this host
|
||||
networks *bart.Lite
|
||||
// networks is a combination of specific vpn addresses (not prefixes!) and full unsafe networks assigned to this host.
|
||||
networks *bart.Table[NetworkType]
|
||||
relayState RelayState
|
||||
|
||||
// HandshakePacket records the packets used to create this hostinfo
|
||||
@@ -730,20 +740,27 @@ func (i *HostInfo) SetRemoteIfPreferred(hm *HostMap, newRemote netip.AddrPort) b
|
||||
return false
|
||||
}
|
||||
|
||||
func (i *HostInfo) buildNetworks(networks, unsafeNetworks []netip.Prefix) {
|
||||
func (i *HostInfo) buildNetworks(myVpnNetworksTable *bart.Lite, networks, unsafeNetworks []netip.Prefix) {
|
||||
if len(networks) == 1 && len(unsafeNetworks) == 0 {
|
||||
// Simple case, no CIDRTree needed
|
||||
return
|
||||
if myVpnNetworksTable.Contains(networks[0].Addr()) {
|
||||
return // Simple case, no CIDRTree needed
|
||||
}
|
||||
}
|
||||
|
||||
i.networks = new(bart.Lite)
|
||||
i.networks = new(bart.Table[NetworkType])
|
||||
for _, network := range networks {
|
||||
var nwType NetworkType
|
||||
if myVpnNetworksTable.Contains(network.Addr()) {
|
||||
nwType = NetworkTypeVPN
|
||||
} else {
|
||||
nwType = NetworkTypeVPNPeer
|
||||
}
|
||||
nprefix := netip.PrefixFrom(network.Addr(), network.Addr().BitLen())
|
||||
i.networks.Insert(nprefix)
|
||||
i.networks.Insert(nprefix, nwType)
|
||||
}
|
||||
|
||||
for _, network := range unsafeNetworks {
|
||||
i.networks.Insert(network)
|
||||
i.networks.Insert(network, NetworkTypeUnsafe)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
143
inside.go
143
inside.go
@@ -11,149 +11,6 @@ import (
|
||||
"github.com/slackhq/nebula/routing"
|
||||
)
|
||||
|
||||
// consumeInsidePackets processes multiple packets in a batch for improved performance
|
||||
// packets: slice of packet buffers to process
|
||||
// sizes: slice of packet sizes
|
||||
// count: number of packets to process
|
||||
// outs: slice of output buffers (one per packet) with virtio headroom
|
||||
// q: queue index
|
||||
// localCache: firewall conntrack cache
|
||||
// batchPackets: pre-allocated slice for accumulating encrypted packets
|
||||
// batchAddrs: pre-allocated slice for accumulating destination addresses
|
||||
func (f *Interface) consumeInsidePackets(packets [][]byte, sizes []int, count int, outs [][]byte, nb []byte, q int, localCache firewall.ConntrackCache, batchPackets *[][]byte, batchAddrs *[]netip.AddrPort) {
|
||||
// Reusable per-packet state
|
||||
fwPacket := &firewall.Packet{}
|
||||
|
||||
// Reset batch accumulation slices (reuse capacity)
|
||||
*batchPackets = (*batchPackets)[:0]
|
||||
*batchAddrs = (*batchAddrs)[:0]
|
||||
|
||||
// Process each packet in the batch
|
||||
for i := 0; i < count; i++ {
|
||||
packet := packets[i][:sizes[i]]
|
||||
out := outs[i]
|
||||
|
||||
// Inline the consumeInsidePacket logic for better performance
|
||||
err := newPacket(packet, false, fwPacket)
|
||||
if err != nil {
|
||||
if f.l.Level >= logrus.DebugLevel {
|
||||
f.l.WithField("packet", packet).Debugf("Error while validating outbound packet: %s", err)
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
// Ignore local broadcast packets
|
||||
if f.dropLocalBroadcast {
|
||||
if f.myBroadcastAddrsTable.Contains(fwPacket.RemoteAddr) {
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
if f.myVpnAddrsTable.Contains(fwPacket.RemoteAddr) {
|
||||
// Immediately forward packets from self to self.
|
||||
if immediatelyForwardToSelf {
|
||||
_, err := f.readers[q].Write(packet)
|
||||
if err != nil {
|
||||
f.l.WithError(err).Error("Failed to forward to tun")
|
||||
}
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
// Ignore multicast packets
|
||||
if f.dropMulticast && fwPacket.RemoteAddr.IsMulticast() {
|
||||
continue
|
||||
}
|
||||
|
||||
hostinfo, ready := f.getOrHandshakeConsiderRouting(fwPacket, func(hh *HandshakeHostInfo) {
|
||||
hh.cachePacket(f.l, header.Message, 0, packet, f.sendMessageNow, f.cachedPacketMetrics)
|
||||
})
|
||||
|
||||
if hostinfo == nil {
|
||||
f.rejectInside(packet, out, q)
|
||||
if f.l.Level >= logrus.DebugLevel {
|
||||
f.l.WithField("vpnAddr", fwPacket.RemoteAddr).
|
||||
WithField("fwPacket", fwPacket).
|
||||
Debugln("dropping outbound packet, vpnAddr not in our vpn networks or in unsafe networks")
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
if !ready {
|
||||
continue
|
||||
}
|
||||
|
||||
dropReason := f.firewall.Drop(*fwPacket, false, hostinfo, f.pki.GetCAPool(), localCache)
|
||||
if dropReason != nil {
|
||||
f.rejectInside(packet, out, q)
|
||||
if f.l.Level >= logrus.DebugLevel {
|
||||
hostinfo.logger(f.l).
|
||||
WithField("fwPacket", fwPacket).
|
||||
WithField("reason", dropReason).
|
||||
Debugln("dropping outbound packet")
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
// Encrypt and prepare packet for batch sending
|
||||
ci := hostinfo.ConnectionState
|
||||
if ci.eKey == nil {
|
||||
continue
|
||||
}
|
||||
|
||||
// Check if this needs relay - if so, send immediately and skip batching
|
||||
useRelay := !hostinfo.remote.IsValid()
|
||||
if useRelay {
|
||||
// Handle relay sends individually (less common path)
|
||||
f.sendNoMetrics(header.Message, 0, ci, hostinfo, netip.AddrPort{}, packet, nb, out, q)
|
||||
continue
|
||||
}
|
||||
|
||||
// Encrypt the packet for batch sending
|
||||
if noiseutil.EncryptLockNeeded {
|
||||
ci.writeLock.Lock()
|
||||
}
|
||||
c := ci.messageCounter.Add(1)
|
||||
out = header.Encode(out, header.Version, header.Message, 0, hostinfo.remoteIndexId, c)
|
||||
f.connectionManager.Out(hostinfo)
|
||||
|
||||
// Query lighthouse if needed
|
||||
if hostinfo.lastRebindCount != f.rebindCount {
|
||||
f.lightHouse.QueryServer(hostinfo.vpnAddrs[0])
|
||||
hostinfo.lastRebindCount = f.rebindCount
|
||||
if f.l.Level >= logrus.DebugLevel {
|
||||
f.l.WithField("vpnAddrs", hostinfo.vpnAddrs).Debug("Lighthouse update triggered for punch due to rebind counter")
|
||||
}
|
||||
}
|
||||
|
||||
out, err = ci.eKey.EncryptDanger(out, out, packet, c, nb)
|
||||
if noiseutil.EncryptLockNeeded {
|
||||
ci.writeLock.Unlock()
|
||||
}
|
||||
if err != nil {
|
||||
hostinfo.logger(f.l).WithError(err).
|
||||
WithField("counter", c).
|
||||
Error("Failed to encrypt outgoing packet")
|
||||
continue
|
||||
}
|
||||
|
||||
// Add to batch
|
||||
*batchPackets = append(*batchPackets, out)
|
||||
*batchAddrs = append(*batchAddrs, hostinfo.remote)
|
||||
}
|
||||
|
||||
// Send all accumulated packets in one batch
|
||||
if len(*batchPackets) > 0 {
|
||||
batchSize := len(*batchPackets)
|
||||
f.batchMetrics.udpWriteSize.Update(int64(batchSize))
|
||||
|
||||
n, err := f.writers[q].WriteMulti(*batchPackets, *batchAddrs)
|
||||
if err != nil {
|
||||
f.l.WithError(err).WithField("sent", n).WithField("total", batchSize).Error("Failed to send batch")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (f *Interface) consumeInsidePacket(packet []byte, fwPacket *firewall.Packet, nb, out []byte, q int, localCache firewall.ConntrackCache) {
|
||||
err := newPacket(packet, false, fwPacket)
|
||||
if err != nil {
|
||||
|
||||
76
interface.go
76
interface.go
@@ -4,6 +4,7 @@ import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/netip"
|
||||
"os"
|
||||
"runtime"
|
||||
@@ -21,7 +22,6 @@ import (
|
||||
)
|
||||
|
||||
const mtu = 9001
|
||||
const virtioNetHdrLen = overlay.VirtioNetHdrLen
|
||||
|
||||
type InterfaceConfig struct {
|
||||
HostMap *HostMap
|
||||
@@ -50,13 +50,6 @@ type InterfaceConfig struct {
|
||||
l *logrus.Logger
|
||||
}
|
||||
|
||||
type batchMetrics struct {
|
||||
udpReadSize metrics.Histogram
|
||||
tunReadSize metrics.Histogram
|
||||
udpWriteSize metrics.Histogram
|
||||
tunWriteSize metrics.Histogram
|
||||
}
|
||||
|
||||
type Interface struct {
|
||||
hostMap *HostMap
|
||||
outside udp.Conn
|
||||
@@ -93,12 +86,11 @@ type Interface struct {
|
||||
conntrackCacheTimeout time.Duration
|
||||
|
||||
writers []udp.Conn
|
||||
readers []overlay.BatchReadWriter
|
||||
readers []io.ReadWriteCloser
|
||||
|
||||
metricHandshakes metrics.Histogram
|
||||
messageMetrics *MessageMetrics
|
||||
cachedPacketMetrics *cachedPacketMetrics
|
||||
batchMetrics *batchMetrics
|
||||
|
||||
l *logrus.Logger
|
||||
}
|
||||
@@ -185,7 +177,7 @@ func NewInterface(ctx context.Context, c *InterfaceConfig) (*Interface, error) {
|
||||
routines: c.routines,
|
||||
version: c.version,
|
||||
writers: make([]udp.Conn, c.routines),
|
||||
readers: make([]overlay.BatchReadWriter, c.routines),
|
||||
readers: make([]io.ReadWriteCloser, c.routines),
|
||||
myVpnNetworks: cs.myVpnNetworks,
|
||||
myVpnNetworksTable: cs.myVpnNetworksTable,
|
||||
myVpnAddrs: cs.myVpnAddrs,
|
||||
@@ -201,12 +193,6 @@ func NewInterface(ctx context.Context, c *InterfaceConfig) (*Interface, error) {
|
||||
sent: metrics.GetOrRegisterCounter("hostinfo.cached_packets.sent", nil),
|
||||
dropped: metrics.GetOrRegisterCounter("hostinfo.cached_packets.dropped", nil),
|
||||
},
|
||||
batchMetrics: &batchMetrics{
|
||||
udpReadSize: metrics.GetOrRegisterHistogram("batch.udp_read_size", nil, metrics.NewUniformSample(1024)),
|
||||
tunReadSize: metrics.GetOrRegisterHistogram("batch.tun_read_size", nil, metrics.NewUniformSample(1024)),
|
||||
udpWriteSize: metrics.GetOrRegisterHistogram("batch.udp_write_size", nil, metrics.NewUniformSample(1024)),
|
||||
tunWriteSize: metrics.GetOrRegisterHistogram("batch.tun_write_size", nil, metrics.NewUniformSample(1024)),
|
||||
},
|
||||
|
||||
l: c.l,
|
||||
}
|
||||
@@ -239,7 +225,7 @@ func (f *Interface) activate() {
|
||||
metrics.GetOrRegisterGauge("routines", nil).Update(int64(f.routines))
|
||||
|
||||
// Prepare n tun queues
|
||||
var reader overlay.BatchReadWriter = f.inside
|
||||
var reader io.ReadWriteCloser = f.inside
|
||||
for i := 0; i < f.routines; i++ {
|
||||
if i > 0 {
|
||||
reader, err = f.inside.NewMultiQueueReader()
|
||||
@@ -280,69 +266,39 @@ func (f *Interface) listenOut(i int) {
|
||||
|
||||
ctCache := firewall.NewConntrackCacheTicker(f.conntrackCacheTimeout)
|
||||
lhh := f.lightHouse.NewRequestHandler()
|
||||
|
||||
// Pre-allocate output buffers for batch processing
|
||||
batchSize := li.BatchSize()
|
||||
outs := make([][]byte, batchSize)
|
||||
for idx := range outs {
|
||||
// Allocate full buffer with virtio header space
|
||||
outs[idx] = make([]byte, virtioNetHdrLen, virtioNetHdrLen+udp.MTU)
|
||||
}
|
||||
|
||||
plaintext := make([]byte, udp.MTU)
|
||||
h := &header.H{}
|
||||
fwPacket := &firewall.Packet{}
|
||||
nb := make([]byte, 12)
|
||||
nb := make([]byte, 12, 12)
|
||||
|
||||
li.ListenOutBatch(func(addrs []netip.AddrPort, payloads [][]byte, count int) {
|
||||
f.readOutsidePacketsBatch(addrs, payloads, count, outs[:count], nb, i, h, fwPacket, lhh, ctCache.Get(f.l))
|
||||
li.ListenOut(func(fromUdpAddr netip.AddrPort, payload []byte) {
|
||||
f.readOutsidePackets(fromUdpAddr, nil, plaintext[:0], payload, h, fwPacket, lhh, nb, i, ctCache.Get(f.l))
|
||||
})
|
||||
}
|
||||
|
||||
func (f *Interface) listenIn(reader overlay.BatchReadWriter, i int) {
|
||||
func (f *Interface) listenIn(reader io.ReadWriteCloser, i int) {
|
||||
runtime.LockOSThread()
|
||||
|
||||
batchSize := reader.BatchSize()
|
||||
|
||||
// Allocate buffers for batch reading
|
||||
bufs := make([][]byte, batchSize)
|
||||
for idx := range bufs {
|
||||
bufs[idx] = make([]byte, mtu)
|
||||
}
|
||||
sizes := make([]int, batchSize)
|
||||
|
||||
// Allocate output buffers for batch processing (one per packet)
|
||||
// Each has virtio header headroom to avoid copies on write
|
||||
outs := make([][]byte, batchSize)
|
||||
for idx := range outs {
|
||||
outBuf := make([]byte, virtioNetHdrLen+mtu)
|
||||
outs[idx] = outBuf[virtioNetHdrLen:] // Slice starting after headroom
|
||||
}
|
||||
|
||||
// Pre-allocate batch accumulation buffers for sending
|
||||
batchPackets := make([][]byte, 0, batchSize)
|
||||
batchAddrs := make([]netip.AddrPort, 0, batchSize)
|
||||
|
||||
// Pre-allocate nonce buffer (reused for all encryptions)
|
||||
nb := make([]byte, 12)
|
||||
packet := make([]byte, mtu)
|
||||
out := make([]byte, mtu)
|
||||
fwPacket := &firewall.Packet{}
|
||||
nb := make([]byte, 12, 12)
|
||||
|
||||
conntrackCache := firewall.NewConntrackCacheTicker(f.conntrackCacheTimeout)
|
||||
|
||||
for {
|
||||
n, err := reader.BatchRead(bufs, sizes)
|
||||
n, err := reader.Read(packet)
|
||||
if err != nil {
|
||||
if errors.Is(err, os.ErrClosed) && f.closed.Load() {
|
||||
return
|
||||
}
|
||||
|
||||
f.l.WithError(err).Error("Error while batch reading outbound packets")
|
||||
f.l.WithError(err).Error("Error while reading outbound packet")
|
||||
// This only seems to happen when something fatal happens to the fd, so exit.
|
||||
os.Exit(2)
|
||||
}
|
||||
|
||||
f.batchMetrics.tunReadSize.Update(int64(n))
|
||||
|
||||
// Process all packets in the batch at once
|
||||
f.consumeInsidePackets(bufs, sizes, n, outs, nb, i, conntrackCache.Get(f.l), &batchPackets, &batchAddrs)
|
||||
f.consumeInsidePacket(packet[:n], fwPacket, nb, out, i, conntrackCache.Get(f.l))
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -1017,17 +1017,17 @@ func (lhh *LightHouseHandler) resetMeta() *NebulaMeta {
|
||||
return lhh.meta
|
||||
}
|
||||
|
||||
func (lhh *LightHouseHandler) HandleRequest(rAddr netip.AddrPort, fromVpnAddrs []netip.Addr, p []byte, w EncWriter) {
|
||||
func (lhh *LightHouseHandler) HandleRequest(rAddr netip.AddrPort, hostinfo *HostInfo, p []byte, w EncWriter) {
|
||||
n := lhh.resetMeta()
|
||||
err := n.Unmarshal(p)
|
||||
if err != nil {
|
||||
lhh.l.WithError(err).WithField("vpnAddrs", fromVpnAddrs).WithField("udpAddr", rAddr).
|
||||
lhh.l.WithError(err).WithField("vpnAddrs", hostinfo.vpnAddrs).WithField("udpAddr", rAddr).
|
||||
Error("Failed to unmarshal lighthouse packet")
|
||||
return
|
||||
}
|
||||
|
||||
if n.Details == nil {
|
||||
lhh.l.WithField("vpnAddrs", fromVpnAddrs).WithField("udpAddr", rAddr).
|
||||
lhh.l.WithField("vpnAddrs", hostinfo.vpnAddrs).WithField("udpAddr", rAddr).
|
||||
Error("Invalid lighthouse update")
|
||||
return
|
||||
}
|
||||
@@ -1036,24 +1036,24 @@ func (lhh *LightHouseHandler) HandleRequest(rAddr netip.AddrPort, fromVpnAddrs [
|
||||
|
||||
switch n.Type {
|
||||
case NebulaMeta_HostQuery:
|
||||
lhh.handleHostQuery(n, fromVpnAddrs, rAddr, w)
|
||||
lhh.handleHostQuery(n, hostinfo, rAddr, w)
|
||||
|
||||
case NebulaMeta_HostQueryReply:
|
||||
lhh.handleHostQueryReply(n, fromVpnAddrs)
|
||||
lhh.handleHostQueryReply(n, hostinfo.vpnAddrs)
|
||||
|
||||
case NebulaMeta_HostUpdateNotification:
|
||||
lhh.handleHostUpdateNotification(n, fromVpnAddrs, w)
|
||||
lhh.handleHostUpdateNotification(n, hostinfo, w)
|
||||
|
||||
case NebulaMeta_HostMovedNotification:
|
||||
case NebulaMeta_HostPunchNotification:
|
||||
lhh.handleHostPunchNotification(n, fromVpnAddrs, w)
|
||||
lhh.handleHostPunchNotification(n, hostinfo.vpnAddrs, w)
|
||||
|
||||
case NebulaMeta_HostUpdateNotificationAck:
|
||||
// noop
|
||||
}
|
||||
}
|
||||
|
||||
func (lhh *LightHouseHandler) handleHostQuery(n *NebulaMeta, fromVpnAddrs []netip.Addr, addr netip.AddrPort, w EncWriter) {
|
||||
func (lhh *LightHouseHandler) handleHostQuery(n *NebulaMeta, hostinfo *HostInfo, addr netip.AddrPort, w EncWriter) {
|
||||
// Exit if we don't answer queries
|
||||
if !lhh.lh.amLighthouse {
|
||||
if lhh.l.Level >= logrus.DebugLevel {
|
||||
@@ -1065,7 +1065,7 @@ func (lhh *LightHouseHandler) handleHostQuery(n *NebulaMeta, fromVpnAddrs []neti
|
||||
queryVpnAddr, useVersion, err := n.Details.GetVpnAddrAndVersion()
|
||||
if err != nil {
|
||||
if lhh.l.Level >= logrus.DebugLevel {
|
||||
lhh.l.WithField("from", fromVpnAddrs).WithField("details", n.Details).
|
||||
lhh.l.WithField("from", hostinfo.vpnAddrs).WithField("details", n.Details).
|
||||
Debugln("Dropping malformed HostQuery")
|
||||
}
|
||||
return
|
||||
@@ -1073,7 +1073,7 @@ func (lhh *LightHouseHandler) handleHostQuery(n *NebulaMeta, fromVpnAddrs []neti
|
||||
if useVersion == cert.Version1 && queryVpnAddr.Is6() {
|
||||
// this case really shouldn't be possible to represent, but reject it anyway.
|
||||
if lhh.l.Level >= logrus.DebugLevel {
|
||||
lhh.l.WithField("vpnAddrs", fromVpnAddrs).WithField("queryVpnAddr", queryVpnAddr).
|
||||
lhh.l.WithField("vpnAddrs", hostinfo.vpnAddrs).WithField("queryVpnAddr", queryVpnAddr).
|
||||
Debugln("invalid vpn addr for v1 handleHostQuery")
|
||||
}
|
||||
return
|
||||
@@ -1099,14 +1099,14 @@ func (lhh *LightHouseHandler) handleHostQuery(n *NebulaMeta, fromVpnAddrs []neti
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
lhh.l.WithError(err).WithField("vpnAddrs", fromVpnAddrs).Error("Failed to marshal lighthouse host query reply")
|
||||
lhh.l.WithError(err).WithField("vpnAddrs", hostinfo.vpnAddrs).Error("Failed to marshal lighthouse host query reply")
|
||||
return
|
||||
}
|
||||
|
||||
lhh.lh.metricTx(NebulaMeta_HostQueryReply, 1)
|
||||
w.SendMessageToVpnAddr(header.LightHouse, 0, fromVpnAddrs[0], lhh.pb[:ln], lhh.nb, lhh.out[:0])
|
||||
w.SendMessageToHostInfo(header.LightHouse, 0, hostinfo, lhh.pb[:ln], lhh.nb, lhh.out[:0])
|
||||
|
||||
lhh.sendHostPunchNotification(n, fromVpnAddrs, queryVpnAddr, w)
|
||||
lhh.sendHostPunchNotification(n, hostinfo.vpnAddrs, queryVpnAddr, w)
|
||||
}
|
||||
|
||||
// sendHostPunchNotification signals the other side to punch some zero byte udp packets
|
||||
@@ -1115,20 +1115,34 @@ func (lhh *LightHouseHandler) sendHostPunchNotification(n *NebulaMeta, fromVpnAd
|
||||
found, ln, err := lhh.lh.queryAndPrepMessage(whereToPunch, func(c *cache) (int, error) {
|
||||
n = lhh.resetMeta()
|
||||
n.Type = NebulaMeta_HostPunchNotification
|
||||
targetHI := lhh.lh.ifce.GetHostInfo(punchNotifDest)
|
||||
punchNotifDestHI := lhh.lh.ifce.GetHostInfo(punchNotifDest)
|
||||
var useVersion cert.Version
|
||||
if targetHI == nil {
|
||||
if punchNotifDestHI == nil {
|
||||
useVersion = lhh.lh.ifce.GetCertState().initiatingVersion
|
||||
} else {
|
||||
crt := targetHI.GetCert().Certificate
|
||||
useVersion = crt.Version()
|
||||
// we can only retarget if we have a hostinfo
|
||||
newDest, ok := findNetworkUnion(crt.Networks(), fromVpnAddrs)
|
||||
punchNotifDestCrt := punchNotifDestHI.GetCert().Certificate
|
||||
useVersion = punchNotifDestCrt.Version()
|
||||
punchNotifDestNetworks := punchNotifDestCrt.Networks()
|
||||
|
||||
//if we (the lighthouse) don't have a network in common with punchNotifDest, try to find one
|
||||
if !lhh.lh.myVpnNetworksTable.Contains(punchNotifDest) {
|
||||
newPunchNotifDest, ok := findNetworkUnion(lhh.lh.myVpnNetworks, punchNotifDestHI.vpnAddrs)
|
||||
if ok {
|
||||
punchNotifDest = newPunchNotifDest
|
||||
} else {
|
||||
if lhh.l.Level >= logrus.DebugLevel {
|
||||
lhh.l.WithField("to", punchNotifDestNetworks).Debugln("unable to notify host to host, no addresses in common")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
newWhereToPunch, ok := findNetworkUnion(punchNotifDestNetworks, fromVpnAddrs)
|
||||
if ok {
|
||||
whereToPunch = newDest
|
||||
whereToPunch = newWhereToPunch
|
||||
} else {
|
||||
if lhh.l.Level >= logrus.DebugLevel {
|
||||
lhh.l.WithField("to", crt.Networks()).Debugln("unable to punch to host, no addresses in common")
|
||||
lhh.l.WithFields(m{"from": fromVpnAddrs, "to": punchNotifDestNetworks}).Debugln("unable to punch to host, no addresses in common with requestor")
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1234,7 +1248,8 @@ func (lhh *LightHouseHandler) handleHostQueryReply(n *NebulaMeta, fromVpnAddrs [
|
||||
}
|
||||
}
|
||||
|
||||
func (lhh *LightHouseHandler) handleHostUpdateNotification(n *NebulaMeta, fromVpnAddrs []netip.Addr, w EncWriter) {
|
||||
func (lhh *LightHouseHandler) handleHostUpdateNotification(n *NebulaMeta, hostinfo *HostInfo, w EncWriter) {
|
||||
fromVpnAddrs := hostinfo.vpnAddrs
|
||||
if !lhh.lh.amLighthouse {
|
||||
if lhh.l.Level >= logrus.DebugLevel {
|
||||
lhh.l.Debugln("I am not a lighthouse, do not take host updates: ", fromVpnAddrs)
|
||||
@@ -1302,7 +1317,7 @@ func (lhh *LightHouseHandler) handleHostUpdateNotification(n *NebulaMeta, fromVp
|
||||
}
|
||||
|
||||
lhh.lh.metricTx(NebulaMeta_HostUpdateNotificationAck, 1)
|
||||
w.SendMessageToVpnAddr(header.LightHouse, 0, fromVpnAddrs[0], lhh.pb[:ln], lhh.nb, lhh.out[:0])
|
||||
w.SendMessageToHostInfo(header.LightHouse, 0, hostinfo, lhh.pb[:ln], lhh.nb, lhh.out[:0])
|
||||
}
|
||||
|
||||
func (lhh *LightHouseHandler) handleHostPunchNotification(n *NebulaMeta, fromVpnAddrs []netip.Addr, w EncWriter) {
|
||||
@@ -1337,19 +1352,12 @@ func (lhh *LightHouseHandler) handleHostPunchNotification(n *NebulaMeta, fromVpn
|
||||
}
|
||||
}
|
||||
|
||||
remoteAllowList := lhh.lh.GetRemoteAllowList()
|
||||
for _, a := range n.Details.V4AddrPorts {
|
||||
b := protoV4AddrPortToNetAddrPort(a)
|
||||
if remoteAllowList.Allow(detailsVpnAddr, b.Addr()) {
|
||||
punch(b, detailsVpnAddr)
|
||||
}
|
||||
punch(protoV4AddrPortToNetAddrPort(a), detailsVpnAddr)
|
||||
}
|
||||
|
||||
for _, a := range n.Details.V6AddrPorts {
|
||||
b := protoV6AddrPortToNetAddrPort(a)
|
||||
if remoteAllowList.Allow(detailsVpnAddr, b.Addr()) {
|
||||
punch(b, detailsVpnAddr)
|
||||
}
|
||||
punch(protoV6AddrPortToNetAddrPort(a), detailsVpnAddr)
|
||||
}
|
||||
|
||||
// This sends a nebula test packet to the host trying to contact us. In the case
|
||||
|
||||
@@ -132,8 +132,13 @@ func BenchmarkLighthouseHandleRequest(b *testing.B) {
|
||||
)
|
||||
|
||||
mw := &mockEncWriter{}
|
||||
|
||||
hi := []netip.Addr{vpnIp2}
|
||||
hostinfo := &HostInfo{
|
||||
ConnectionState: &ConnectionState{
|
||||
eKey: nil,
|
||||
dKey: nil,
|
||||
},
|
||||
vpnAddrs: []netip.Addr{vpnIp2},
|
||||
}
|
||||
b.Run("notfound", func(b *testing.B) {
|
||||
lhh := lh.NewRequestHandler()
|
||||
req := &NebulaMeta{
|
||||
@@ -146,7 +151,7 @@ func BenchmarkLighthouseHandleRequest(b *testing.B) {
|
||||
p, err := req.Marshal()
|
||||
require.NoError(b, err)
|
||||
for n := 0; n < b.N; n++ {
|
||||
lhh.HandleRequest(rAddr, hi, p, mw)
|
||||
lhh.HandleRequest(rAddr, hostinfo, p, mw)
|
||||
}
|
||||
})
|
||||
b.Run("found", func(b *testing.B) {
|
||||
@@ -162,7 +167,7 @@ func BenchmarkLighthouseHandleRequest(b *testing.B) {
|
||||
require.NoError(b, err)
|
||||
|
||||
for n := 0; n < b.N; n++ {
|
||||
lhh.HandleRequest(rAddr, hi, p, mw)
|
||||
lhh.HandleRequest(rAddr, hostinfo, p, mw)
|
||||
}
|
||||
})
|
||||
}
|
||||
@@ -326,7 +331,14 @@ func newLHHostRequest(fromAddr netip.AddrPort, myVpnIp, queryVpnIp netip.Addr, l
|
||||
w := &testEncWriter{
|
||||
metaFilter: &filter,
|
||||
}
|
||||
lhh.HandleRequest(fromAddr, []netip.Addr{myVpnIp}, b, w)
|
||||
hostinfo := &HostInfo{
|
||||
ConnectionState: &ConnectionState{
|
||||
eKey: nil,
|
||||
dKey: nil,
|
||||
},
|
||||
vpnAddrs: []netip.Addr{myVpnIp},
|
||||
}
|
||||
lhh.HandleRequest(fromAddr, hostinfo, b, w)
|
||||
return w.lastReply
|
||||
}
|
||||
|
||||
@@ -355,9 +367,15 @@ func newLHHostUpdate(fromAddr netip.AddrPort, vpnIp netip.Addr, addrs []netip.Ad
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
hostinfo := &HostInfo{
|
||||
ConnectionState: &ConnectionState{
|
||||
eKey: nil,
|
||||
dKey: nil,
|
||||
},
|
||||
vpnAddrs: []netip.Addr{vpnIp},
|
||||
}
|
||||
w := &testEncWriter{}
|
||||
lhh.HandleRequest(fromAddr, []netip.Addr{vpnIp}, b, w)
|
||||
lhh.HandleRequest(fromAddr, hostinfo, b, w)
|
||||
}
|
||||
|
||||
type testLhReply struct {
|
||||
|
||||
5
main.go
5
main.go
@@ -75,8 +75,7 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
|
||||
if c.GetBool("sshd.enabled", false) {
|
||||
sshStart, err = configSSH(l, ssh, c)
|
||||
if err != nil {
|
||||
l.WithError(err).Warn("Failed to configure sshd, ssh debugging will not be available")
|
||||
sshStart = nil
|
||||
return nil, util.ContextualizeIfNeeded("Error while configuring the sshd", err)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -165,7 +164,7 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
|
||||
|
||||
for i := 0; i < routines; i++ {
|
||||
l.Infof("listening on %v", netip.AddrPortFrom(listenHost, uint16(port)))
|
||||
udpServer, err := udp.NewListener(l, listenHost, port, routines > 1, c.GetInt("listen.batch", 128))
|
||||
udpServer, err := udp.NewListener(l, listenHost, port, routines > 1, c.GetInt("listen.batch", 64))
|
||||
if err != nil {
|
||||
return nil, util.NewContextualError("Failed to open udp listener", m{"queue": i}, err)
|
||||
}
|
||||
|
||||
122
outside.go
122
outside.go
@@ -95,7 +95,8 @@ func (f *Interface) readOutsidePackets(ip netip.AddrPort, via *ViaSender, out []
|
||||
switch relay.Type {
|
||||
case TerminalType:
|
||||
// If I am the target of this relay, process the unwrapped packet
|
||||
f.readOutsidePackets(netip.AddrPort{}, &ViaSender{relayHI: hostinfo, remoteIdx: relay.RemoteIndex, relay: relay}, out[:virtioNetHdrLen], signedPayload, h, fwPacket, lhf, nb, q, localCache)
|
||||
// From this recursive point, all these variables are 'burned'. We shouldn't rely on them again.
|
||||
f.readOutsidePackets(netip.AddrPort{}, &ViaSender{relayHI: hostinfo, remoteIdx: relay.RemoteIndex, relay: relay}, out[:0], signedPayload, h, fwPacket, lhf, nb, q, localCache)
|
||||
return
|
||||
case ForwardingType:
|
||||
// Find the target HostInfo relay object
|
||||
@@ -137,7 +138,7 @@ func (f *Interface) readOutsidePackets(ip netip.AddrPort, via *ViaSender, out []
|
||||
return
|
||||
}
|
||||
|
||||
lhf.HandleRequest(ip, hostinfo.vpnAddrs, d[virtioNetHdrLen:], f)
|
||||
lhf.HandleRequest(ip, hostinfo, d, f)
|
||||
|
||||
// Fallthrough to the bottom to record incoming traffic
|
||||
|
||||
@@ -159,7 +160,7 @@ func (f *Interface) readOutsidePackets(ip netip.AddrPort, via *ViaSender, out []
|
||||
// This testRequest might be from TryPromoteBest, so we should roam
|
||||
// to the new IP address before responding
|
||||
f.handleHostRoaming(hostinfo, ip)
|
||||
f.send(header.Test, header.TestReply, ci, hostinfo, d[virtioNetHdrLen:], nb, out)
|
||||
f.send(header.Test, header.TestReply, ci, hostinfo, d, nb, out)
|
||||
}
|
||||
|
||||
// Fallthrough to the bottom to record incoming traffic
|
||||
@@ -202,7 +203,7 @@ func (f *Interface) readOutsidePackets(ip netip.AddrPort, via *ViaSender, out []
|
||||
return
|
||||
}
|
||||
|
||||
f.relayManager.HandleControlMsg(hostinfo, d[virtioNetHdrLen:], f)
|
||||
f.relayManager.HandleControlMsg(hostinfo, d, f)
|
||||
|
||||
default:
|
||||
f.messageMetrics.Rx(h.Type, h.Subtype, 1)
|
||||
@@ -473,11 +474,9 @@ func (f *Interface) decryptToTun(hostinfo *HostInfo, messageCounter uint64, out
|
||||
return false
|
||||
}
|
||||
|
||||
packetData := out[virtioNetHdrLen:]
|
||||
|
||||
err = newPacket(packetData, true, fwPacket)
|
||||
err = newPacket(out, true, fwPacket)
|
||||
if err != nil {
|
||||
hostinfo.logger(f.l).WithError(err).WithField("packet", packetData).
|
||||
hostinfo.logger(f.l).WithError(err).WithField("packet", out).
|
||||
Warnf("Error while validating inbound packet")
|
||||
return false
|
||||
}
|
||||
@@ -492,7 +491,7 @@ func (f *Interface) decryptToTun(hostinfo *HostInfo, messageCounter uint64, out
|
||||
if dropReason != nil {
|
||||
// NOTE: We give `packet` as the `out` here since we already decrypted from it and we don't need it anymore
|
||||
// This gives us a buffer to build the reject packet in
|
||||
f.rejectOutside(packetData, hostinfo.ConnectionState, hostinfo, nb, packet, q)
|
||||
f.rejectOutside(out, hostinfo.ConnectionState, hostinfo, nb, packet, q)
|
||||
if f.l.Level >= logrus.DebugLevel {
|
||||
hostinfo.logger(f.l).WithField("fwPacket", fwPacket).
|
||||
WithField("reason", dropReason).
|
||||
@@ -549,108 +548,3 @@ func (f *Interface) handleRecvError(addr netip.AddrPort, h *header.H) {
|
||||
// We also delete it from pending hostmap to allow for fast reconnect.
|
||||
f.handshakeManager.DeleteHostInfo(hostinfo)
|
||||
}
|
||||
|
||||
// readOutsidePacketsBatch processes multiple packets received from UDP in a batch
|
||||
// and writes all successfully decrypted packets to TUN in a single operation
|
||||
func (f *Interface) readOutsidePacketsBatch(addrs []netip.AddrPort, payloads [][]byte, count int, outs [][]byte, nb []byte, q int, h *header.H, fwPacket *firewall.Packet, lhf *LightHouseHandler, localCache firewall.ConntrackCache) {
|
||||
// Pre-allocate slice for accumulating successful decryptions
|
||||
tunPackets := make([][]byte, 0, count)
|
||||
|
||||
for i := 0; i < count; i++ {
|
||||
payload := payloads[i]
|
||||
addr := addrs[i]
|
||||
out := outs[i]
|
||||
|
||||
// Parse header
|
||||
err := h.Parse(payload)
|
||||
if err != nil {
|
||||
if len(payload) > 1 {
|
||||
f.l.WithField("packet", payload).Infof("Error while parsing inbound packet from %s: %s", addr, err)
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
if addr.IsValid() {
|
||||
if f.myVpnNetworksTable.Contains(addr.Addr()) {
|
||||
if f.l.Level >= logrus.DebugLevel {
|
||||
f.l.WithField("udpAddr", addr).Debug("Refusing to process double encrypted packet")
|
||||
}
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
var hostinfo *HostInfo
|
||||
if h.Type == header.Message && h.Subtype == header.MessageRelay {
|
||||
hostinfo = f.hostMap.QueryRelayIndex(h.RemoteIndex)
|
||||
} else {
|
||||
hostinfo = f.hostMap.QueryIndex(h.RemoteIndex)
|
||||
}
|
||||
|
||||
var ci *ConnectionState
|
||||
if hostinfo != nil {
|
||||
ci = hostinfo.ConnectionState
|
||||
}
|
||||
|
||||
switch h.Type {
|
||||
case header.Message:
|
||||
if !f.handleEncrypted(ci, addr, h) {
|
||||
continue
|
||||
}
|
||||
|
||||
switch h.Subtype {
|
||||
case header.MessageNone:
|
||||
// Decrypt packet
|
||||
out, err = hostinfo.ConnectionState.dKey.DecryptDanger(out, payload[:header.Len], payload[header.Len:], h.MessageCounter, nb)
|
||||
if err != nil {
|
||||
hostinfo.logger(f.l).WithError(err).Error("Failed to decrypt packet")
|
||||
continue
|
||||
}
|
||||
|
||||
packetData := out[virtioNetHdrLen:]
|
||||
|
||||
err = newPacket(packetData, true, fwPacket)
|
||||
if err != nil {
|
||||
hostinfo.logger(f.l).WithError(err).WithField("packet", packetData).Warnf("Error while validating inbound packet")
|
||||
continue
|
||||
}
|
||||
|
||||
if !hostinfo.ConnectionState.window.Update(f.l, h.MessageCounter) {
|
||||
hostinfo.logger(f.l).WithField("fwPacket", fwPacket).Debugln("dropping out of window packet")
|
||||
continue
|
||||
}
|
||||
|
||||
dropReason := f.firewall.Drop(*fwPacket, true, hostinfo, f.pki.GetCAPool(), localCache)
|
||||
if dropReason != nil {
|
||||
f.rejectOutside(packetData, hostinfo.ConnectionState, hostinfo, nb, payload, q)
|
||||
if f.l.Level >= logrus.DebugLevel {
|
||||
hostinfo.logger(f.l).WithField("fwPacket", fwPacket).WithField("reason", dropReason).Debugln("dropping inbound packet")
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
f.connectionManager.In(hostinfo)
|
||||
// Add to batch for TUN write
|
||||
tunPackets = append(tunPackets, out)
|
||||
|
||||
case header.MessageRelay:
|
||||
// Skip relay packets in batch mode for now (less common path)
|
||||
f.readOutsidePackets(addr, nil, out, payload, h, fwPacket, lhf, nb, q, localCache)
|
||||
|
||||
default:
|
||||
hostinfo.logger(f.l).Debugf("unexpected message subtype %d", h.Subtype)
|
||||
}
|
||||
|
||||
default:
|
||||
// Handle non-Message types using single-packet path
|
||||
f.readOutsidePackets(addr, nil, out, payload, h, fwPacket, lhf, nb, q, localCache)
|
||||
}
|
||||
}
|
||||
|
||||
if len(tunPackets) > 0 {
|
||||
n, err := f.readers[q].WriteBatch(tunPackets, virtioNetHdrLen)
|
||||
if err != nil {
|
||||
f.l.WithError(err).WithField("sent", n).WithField("total", len(tunPackets)).Error("Failed to batch write to tun")
|
||||
}
|
||||
f.batchMetrics.tunWriteSize.Update(int64(len(tunPackets)))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -7,25 +7,11 @@ import (
|
||||
"github.com/slackhq/nebula/routing"
|
||||
)
|
||||
|
||||
// BatchReadWriter extends io.ReadWriteCloser with batch I/O operations
|
||||
type BatchReadWriter interface {
|
||||
io.ReadWriteCloser
|
||||
|
||||
// BatchRead reads multiple packets at once
|
||||
BatchRead(bufs [][]byte, sizes []int) (int, error)
|
||||
|
||||
// WriteBatch writes multiple packets at once
|
||||
WriteBatch(bufs [][]byte, offset int) (int, error)
|
||||
|
||||
// BatchSize returns the optimal batch size for this device
|
||||
BatchSize() int
|
||||
}
|
||||
|
||||
type Device interface {
|
||||
BatchReadWriter
|
||||
io.ReadWriteCloser
|
||||
Activate() error
|
||||
Networks() []netip.Prefix
|
||||
Name() string
|
||||
RoutesFor(netip.Addr) routing.Gateways
|
||||
NewMultiQueueReader() (BatchReadWriter, error)
|
||||
NewMultiQueueReader() (io.ReadWriteCloser, error)
|
||||
}
|
||||
|
||||
@@ -11,7 +11,6 @@ import (
|
||||
)
|
||||
|
||||
const DefaultMTU = 1300
|
||||
const VirtioNetHdrLen = 10 // Size of virtio_net_hdr structure
|
||||
|
||||
// TODO: We may be able to remove routines
|
||||
type DeviceFactory func(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, routines int) (Device, error)
|
||||
|
||||
@@ -95,29 +95,6 @@ func (t *tun) Name() string {
|
||||
return "android"
|
||||
}
|
||||
|
||||
func (t *tun) NewMultiQueueReader() (BatchReadWriter, error) {
|
||||
func (t *tun) NewMultiQueueReader() (io.ReadWriteCloser, error) {
|
||||
return nil, fmt.Errorf("TODO: multiqueue not implemented for android")
|
||||
}
|
||||
|
||||
func (t *tun) BatchRead(bufs [][]byte, sizes []int) (int, error) {
|
||||
n, err := t.Read(bufs[0])
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
sizes[0] = n
|
||||
return 1, nil
|
||||
}
|
||||
|
||||
func (t *tun) WriteBatch(bufs [][]byte, offset int) (int, error) {
|
||||
for i, buf := range bufs {
|
||||
_, err := t.Write(buf[offset:])
|
||||
if err != nil {
|
||||
return i, err
|
||||
}
|
||||
}
|
||||
return len(bufs), nil
|
||||
}
|
||||
|
||||
func (t *tun) BatchSize() int {
|
||||
return 1
|
||||
}
|
||||
|
||||
@@ -549,32 +549,6 @@ func (t *tun) Name() string {
|
||||
return t.Device
|
||||
}
|
||||
|
||||
func (t *tun) NewMultiQueueReader() (BatchReadWriter, error) {
|
||||
func (t *tun) NewMultiQueueReader() (io.ReadWriteCloser, error) {
|
||||
return nil, fmt.Errorf("TODO: multiqueue not implemented for darwin")
|
||||
}
|
||||
|
||||
// BatchRead reads a single packet (batch size 1 for non-Linux platforms)
|
||||
func (t *tun) BatchRead(bufs [][]byte, sizes []int) (int, error) {
|
||||
n, err := t.Read(bufs[0])
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
sizes[0] = n
|
||||
return 1, nil
|
||||
}
|
||||
|
||||
// WriteBatch writes packets individually (no batching for non-Linux platforms)
|
||||
func (t *tun) WriteBatch(bufs [][]byte, offset int) (int, error) {
|
||||
for i, buf := range bufs {
|
||||
_, err := t.Write(buf[offset:])
|
||||
if err != nil {
|
||||
return i, err
|
||||
}
|
||||
}
|
||||
return len(bufs), nil
|
||||
}
|
||||
|
||||
// BatchSize returns 1 for non-Linux platforms (no batching)
|
||||
func (t *tun) BatchSize() int {
|
||||
return 1
|
||||
}
|
||||
|
||||
@@ -105,36 +105,10 @@ func (t *disabledTun) Write(b []byte) (int, error) {
|
||||
return len(b), nil
|
||||
}
|
||||
|
||||
func (t *disabledTun) NewMultiQueueReader() (BatchReadWriter, error) {
|
||||
func (t *disabledTun) NewMultiQueueReader() (io.ReadWriteCloser, error) {
|
||||
return t, nil
|
||||
}
|
||||
|
||||
// BatchRead reads a single packet (batch size 1 for disabled tun)
|
||||
func (t *disabledTun) BatchRead(bufs [][]byte, sizes []int) (int, error) {
|
||||
n, err := t.Read(bufs[0])
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
sizes[0] = n
|
||||
return 1, nil
|
||||
}
|
||||
|
||||
// WriteBatch writes packets individually (no batching for disabled tun)
|
||||
func (t *disabledTun) WriteBatch(bufs [][]byte, offset int) (int, error) {
|
||||
for i, buf := range bufs {
|
||||
_, err := t.Write(buf[offset:])
|
||||
if err != nil {
|
||||
return i, err
|
||||
}
|
||||
}
|
||||
return len(bufs), nil
|
||||
}
|
||||
|
||||
// BatchSize returns 1 for disabled tun (no batching)
|
||||
func (t *disabledTun) BatchSize() int {
|
||||
return 1
|
||||
}
|
||||
|
||||
func (t *disabledTun) Close() error {
|
||||
if t.read != nil {
|
||||
close(t.read)
|
||||
|
||||
@@ -450,36 +450,10 @@ func (t *tun) Name() string {
|
||||
return t.Device
|
||||
}
|
||||
|
||||
func (t *tun) NewMultiQueueReader() (BatchReadWriter, error) {
|
||||
func (t *tun) NewMultiQueueReader() (io.ReadWriteCloser, error) {
|
||||
return nil, fmt.Errorf("TODO: multiqueue not implemented for freebsd")
|
||||
}
|
||||
|
||||
// BatchRead reads a single packet (batch size 1 for FreeBSD)
|
||||
func (t *tun) BatchRead(bufs [][]byte, sizes []int) (int, error) {
|
||||
n, err := t.Read(bufs[0])
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
sizes[0] = n
|
||||
return 1, nil
|
||||
}
|
||||
|
||||
// WriteBatch writes packets individually (no batching for FreeBSD)
|
||||
func (t *tun) WriteBatch(bufs [][]byte, offset int) (int, error) {
|
||||
for i, buf := range bufs {
|
||||
_, err := t.Write(buf[offset:])
|
||||
if err != nil {
|
||||
return i, err
|
||||
}
|
||||
}
|
||||
return len(bufs), nil
|
||||
}
|
||||
|
||||
// BatchSize returns 1 for FreeBSD (no batching)
|
||||
func (t *tun) BatchSize() int {
|
||||
return 1
|
||||
}
|
||||
|
||||
func (t *tun) addRoutes(logErrors bool) error {
|
||||
routes := *t.Routes.Load()
|
||||
for _, r := range routes {
|
||||
|
||||
@@ -151,29 +151,6 @@ func (t *tun) Name() string {
|
||||
return "iOS"
|
||||
}
|
||||
|
||||
func (t *tun) NewMultiQueueReader() (BatchReadWriter, error) {
|
||||
func (t *tun) NewMultiQueueReader() (io.ReadWriteCloser, error) {
|
||||
return nil, fmt.Errorf("TODO: multiqueue not implemented for ios")
|
||||
}
|
||||
|
||||
func (t *tun) BatchRead(bufs [][]byte, sizes []int) (int, error) {
|
||||
n, err := t.Read(bufs[0])
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
sizes[0] = n
|
||||
return 1, nil
|
||||
}
|
||||
|
||||
func (t *tun) WriteBatch(bufs [][]byte, offset int) (int, error) {
|
||||
for i, buf := range bufs {
|
||||
_, err := t.Write(buf[offset:])
|
||||
if err != nil {
|
||||
return i, err
|
||||
}
|
||||
}
|
||||
return len(bufs), nil
|
||||
}
|
||||
|
||||
func (t *tun) BatchSize() int {
|
||||
return 1
|
||||
}
|
||||
|
||||
@@ -9,6 +9,7 @@ import (
|
||||
"net"
|
||||
"net/netip"
|
||||
"os"
|
||||
"strings"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
"unsafe"
|
||||
@@ -20,12 +21,10 @@ import (
|
||||
"github.com/slackhq/nebula/util"
|
||||
"github.com/vishvananda/netlink"
|
||||
"golang.org/x/sys/unix"
|
||||
wgtun "golang.zx2c4.com/wireguard/tun"
|
||||
)
|
||||
|
||||
type tun struct {
|
||||
io.ReadWriteCloser
|
||||
wgDevice wgtun.Device
|
||||
fd int
|
||||
Device string
|
||||
vpnNetworks []netip.Prefix
|
||||
@@ -66,154 +65,59 @@ type ifreqQLEN struct {
|
||||
pad [8]byte
|
||||
}
|
||||
|
||||
// wgDeviceWrapper wraps a wireguard Device to implement io.ReadWriteCloser
|
||||
// This allows multiqueue readers to use the same wireguard Device batching as the main device
|
||||
type wgDeviceWrapper struct {
|
||||
dev wgtun.Device
|
||||
buf []byte // Reusable buffer for single packet reads
|
||||
}
|
||||
|
||||
func (w *wgDeviceWrapper) Read(b []byte) (int, error) {
|
||||
// Use wireguard Device's batch API for single packet
|
||||
bufs := [][]byte{b}
|
||||
sizes := make([]int, 1)
|
||||
n, err := w.dev.Read(bufs, sizes, 0)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
if n == 0 {
|
||||
return 0, io.EOF
|
||||
}
|
||||
return sizes[0], nil
|
||||
}
|
||||
|
||||
func (w *wgDeviceWrapper) Write(b []byte) (int, error) {
|
||||
// Buffer b should have virtio header space (10 bytes) at the beginning
|
||||
// The decrypted packet data starts at offset 10
|
||||
// Pass the full buffer to WireGuard with offset=virtioNetHdrLen
|
||||
bufs := [][]byte{b}
|
||||
n, err := w.dev.Write(bufs, VirtioNetHdrLen)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
if n == 0 {
|
||||
return 0, io.ErrShortWrite
|
||||
}
|
||||
return len(b), nil
|
||||
}
|
||||
|
||||
func (w *wgDeviceWrapper) WriteBatch(bufs [][]byte, offset int) (int, error) {
|
||||
// Pass all buffers to WireGuard's batch write
|
||||
return w.dev.Write(bufs, offset)
|
||||
}
|
||||
|
||||
func (w *wgDeviceWrapper) Close() error {
|
||||
return w.dev.Close()
|
||||
}
|
||||
|
||||
// BatchRead implements batching for multiqueue readers
|
||||
func (w *wgDeviceWrapper) BatchRead(bufs [][]byte, sizes []int) (int, error) {
|
||||
// The zero here is offset.
|
||||
return w.dev.Read(bufs, sizes, 0)
|
||||
}
|
||||
|
||||
// BatchSize returns the optimal batch size
|
||||
func (w *wgDeviceWrapper) BatchSize() int {
|
||||
return w.dev.BatchSize()
|
||||
}
|
||||
|
||||
func newTunFromFd(c *config.C, l *logrus.Logger, deviceFd int, vpnNetworks []netip.Prefix) (*tun, error) {
|
||||
wgDev, name, err := wgtun.CreateUnmonitoredTUNFromFD(deviceFd)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create TUN from FD: %w", err)
|
||||
}
|
||||
file := os.NewFile(uintptr(deviceFd), "/dev/net/tun")
|
||||
|
||||
file := wgDev.File()
|
||||
t, err := newTunGeneric(c, l, file, vpnNetworks)
|
||||
if err != nil {
|
||||
_ = wgDev.Close()
|
||||
return nil, err
|
||||
}
|
||||
|
||||
t.wgDevice = wgDev
|
||||
t.Device = name
|
||||
t.Device = "tun0"
|
||||
|
||||
return t, nil
|
||||
}
|
||||
|
||||
func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, multiqueue bool) (*tun, error) {
|
||||
// Check if /dev/net/tun exists, create if needed (for docker containers)
|
||||
if _, err := os.Stat("/dev/net/tun"); os.IsNotExist(err) {
|
||||
if err := os.MkdirAll("/dev/net", 0755); err != nil {
|
||||
return nil, fmt.Errorf("/dev/net/tun doesn't exist, failed to mkdir -p /dev/net: %w", err)
|
||||
}
|
||||
if err := unix.Mknod("/dev/net/tun", unix.S_IFCHR|0600, int(unix.Mkdev(10, 200))); err != nil {
|
||||
return nil, fmt.Errorf("failed to create /dev/net/tun: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
devName := c.GetString("tun.dev", "")
|
||||
mtu := c.GetInt("tun.mtu", DefaultMTU)
|
||||
|
||||
// Create TUN device manually to support multiqueue
|
||||
fd, err := unix.Open("/dev/net/tun", os.O_RDWR, 0)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
// If /dev/net/tun doesn't exist, try to create it (will happen in docker)
|
||||
if os.IsNotExist(err) {
|
||||
err = os.MkdirAll("/dev/net", 0755)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("/dev/net/tun doesn't exist, failed to mkdir -p /dev/net: %w", err)
|
||||
}
|
||||
err = unix.Mknod("/dev/net/tun", unix.S_IFCHR|0600, int(unix.Mkdev(10, 200)))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create /dev/net/tun: %w", err)
|
||||
}
|
||||
|
||||
fd, err = unix.Open("/dev/net/tun", os.O_RDWR, 0)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("created /dev/net/tun, but still failed: %w", err)
|
||||
}
|
||||
} else {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
var req ifReq
|
||||
req.Flags = uint16(unix.IFF_TUN | unix.IFF_NO_PI | unix.IFF_VNET_HDR)
|
||||
req.Flags = uint16(unix.IFF_TUN | unix.IFF_NO_PI)
|
||||
if multiqueue {
|
||||
req.Flags |= unix.IFF_MULTI_QUEUE
|
||||
}
|
||||
copy(req.Name[:], devName)
|
||||
copy(req.Name[:], c.GetString("tun.dev", ""))
|
||||
if err = ioctl(uintptr(fd), uintptr(unix.TUNSETIFF), uintptr(unsafe.Pointer(&req))); err != nil {
|
||||
unix.Close(fd)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Set nonblocking
|
||||
if err = unix.SetNonblock(fd, true); err != nil {
|
||||
unix.Close(fd)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Enable TCP and UDP offload (TSO/GRO) for performance
|
||||
// This allows the kernel to handle segmentation/coalescing
|
||||
const (
|
||||
tunTCPOffloads = unix.TUN_F_CSUM | unix.TUN_F_TSO4 | unix.TUN_F_TSO6
|
||||
tunUDPOffloads = unix.TUN_F_USO4 | unix.TUN_F_USO6
|
||||
)
|
||||
offloads := tunTCPOffloads | tunUDPOffloads
|
||||
if err = unix.IoctlSetInt(fd, unix.TUNSETOFFLOAD, offloads); err != nil {
|
||||
// Log warning but don't fail - offload is optional
|
||||
l.WithError(err).Warn("Failed to enable TUN offload (TSO/GRO), performance may be reduced")
|
||||
}
|
||||
name := strings.Trim(string(req.Name[:]), "\x00")
|
||||
|
||||
file := os.NewFile(uintptr(fd), "/dev/net/tun")
|
||||
|
||||
// Create wireguard device from file descriptor
|
||||
wgDev, err := wgtun.CreateTUNFromFile(file, mtu)
|
||||
if err != nil {
|
||||
file.Close()
|
||||
return nil, fmt.Errorf("failed to create TUN from file: %w", err)
|
||||
}
|
||||
|
||||
name, err := wgDev.Name()
|
||||
if err != nil {
|
||||
_ = wgDev.Close()
|
||||
return nil, fmt.Errorf("failed to get TUN device name: %w", err)
|
||||
}
|
||||
|
||||
// file is now owned by wgDev, get a new reference
|
||||
file = wgDev.File()
|
||||
t, err := newTunGeneric(c, l, file, vpnNetworks)
|
||||
if err != nil {
|
||||
_ = wgDev.Close()
|
||||
return nil, err
|
||||
}
|
||||
|
||||
t.wgDevice = wgDev
|
||||
t.Device = name
|
||||
|
||||
return t, nil
|
||||
@@ -312,44 +216,22 @@ func (t *tun) reload(c *config.C, initial bool) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (t *tun) NewMultiQueueReader() (BatchReadWriter, error) {
|
||||
func (t *tun) NewMultiQueueReader() (io.ReadWriteCloser, error) {
|
||||
fd, err := unix.Open("/dev/net/tun", os.O_RDWR, 0)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var req ifReq
|
||||
// MUST match the flags used in newTun - includes IFF_VNET_HDR
|
||||
req.Flags = uint16(unix.IFF_TUN | unix.IFF_NO_PI | unix.IFF_VNET_HDR | unix.IFF_MULTI_QUEUE)
|
||||
req.Flags = uint16(unix.IFF_TUN | unix.IFF_NO_PI | unix.IFF_MULTI_QUEUE)
|
||||
copy(req.Name[:], t.Device)
|
||||
if err = ioctl(uintptr(fd), uintptr(unix.TUNSETIFF), uintptr(unsafe.Pointer(&req))); err != nil {
|
||||
unix.Close(fd)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Set nonblocking mode - CRITICAL for proper netpoller integration
|
||||
if err = unix.SetNonblock(fd, true); err != nil {
|
||||
unix.Close(fd)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Get MTU from main device
|
||||
mtu := t.MaxMTU
|
||||
if mtu == 0 {
|
||||
mtu = DefaultMTU
|
||||
}
|
||||
|
||||
file := os.NewFile(uintptr(fd), "/dev/net/tun")
|
||||
|
||||
// Create wireguard Device from the file descriptor (just like the main device)
|
||||
wgDev, err := wgtun.CreateTUNFromFile(file, mtu)
|
||||
if err != nil {
|
||||
file.Close()
|
||||
return nil, fmt.Errorf("failed to create multiqueue TUN device: %w", err)
|
||||
}
|
||||
|
||||
// Return a wrapper that uses the wireguard Device for all I/O
|
||||
return &wgDeviceWrapper{dev: wgDev}, nil
|
||||
return file, nil
|
||||
}
|
||||
|
||||
func (t *tun) RoutesFor(ip netip.Addr) routing.Gateways {
|
||||
@@ -357,68 +239,7 @@ func (t *tun) RoutesFor(ip netip.Addr) routing.Gateways {
|
||||
return r
|
||||
}
|
||||
|
||||
func (t *tun) Read(b []byte) (int, error) {
|
||||
if t.wgDevice != nil {
|
||||
// Use wireguard device which handles virtio headers internally
|
||||
bufs := [][]byte{b}
|
||||
sizes := make([]int, 1)
|
||||
n, err := t.wgDevice.Read(bufs, sizes, 0)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
if n == 0 {
|
||||
return 0, io.EOF
|
||||
}
|
||||
return sizes[0], nil
|
||||
}
|
||||
|
||||
// Fallback: direct read from file (shouldn't happen in normal operation)
|
||||
return t.ReadWriteCloser.Read(b)
|
||||
}
|
||||
|
||||
// BatchRead reads multiple packets at once for improved performance
|
||||
// bufs: slice of buffers to read into
|
||||
// sizes: slice that will be filled with packet sizes
|
||||
// Returns number of packets read
|
||||
func (t *tun) BatchRead(bufs [][]byte, sizes []int) (int, error) {
|
||||
if t.wgDevice != nil {
|
||||
return t.wgDevice.Read(bufs, sizes, 0)
|
||||
}
|
||||
|
||||
// Fallback: single packet read
|
||||
n, err := t.ReadWriteCloser.Read(bufs[0])
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
sizes[0] = n
|
||||
return 1, nil
|
||||
}
|
||||
|
||||
// BatchSize returns the optimal number of packets to read/write in a batch
|
||||
func (t *tun) BatchSize() int {
|
||||
if t.wgDevice != nil {
|
||||
return t.wgDevice.BatchSize()
|
||||
}
|
||||
return 1
|
||||
}
|
||||
|
||||
func (t *tun) Write(b []byte) (int, error) {
|
||||
if t.wgDevice != nil {
|
||||
// Buffer b should have virtio header space (10 bytes) at the beginning
|
||||
// The decrypted packet data starts at offset 10
|
||||
// Pass the full buffer to WireGuard with offset=virtioNetHdrLen
|
||||
bufs := [][]byte{b}
|
||||
n, err := t.wgDevice.Write(bufs, VirtioNetHdrLen)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
if n == 0 {
|
||||
return 0, io.ErrShortWrite
|
||||
}
|
||||
return len(b), nil
|
||||
}
|
||||
|
||||
// Fallback: direct write (shouldn't happen in normal operation)
|
||||
var nn int
|
||||
maximum := len(b)
|
||||
|
||||
@@ -441,22 +262,6 @@ func (t *tun) Write(b []byte) (int, error) {
|
||||
}
|
||||
}
|
||||
|
||||
// WriteBatch writes multiple packets to the TUN device in a single syscall
|
||||
func (t *tun) WriteBatch(bufs [][]byte, offset int) (int, error) {
|
||||
if t.wgDevice != nil {
|
||||
return t.wgDevice.Write(bufs, offset)
|
||||
}
|
||||
|
||||
// Fallback: write individually (shouldn't happen in normal operation)
|
||||
for i, buf := range bufs {
|
||||
_, err := t.Write(buf)
|
||||
if err != nil {
|
||||
return i, err
|
||||
}
|
||||
}
|
||||
return len(bufs), nil
|
||||
}
|
||||
|
||||
func (t *tun) deviceBytes() (o [16]byte) {
|
||||
for i, c := range t.Device {
|
||||
o[i] = byte(c)
|
||||
@@ -869,10 +674,6 @@ func (t *tun) Close() error {
|
||||
close(t.routeChan)
|
||||
}
|
||||
|
||||
if t.wgDevice != nil {
|
||||
_ = t.wgDevice.Close()
|
||||
}
|
||||
|
||||
if t.ReadWriteCloser != nil {
|
||||
_ = t.ReadWriteCloser.Close()
|
||||
}
|
||||
|
||||
@@ -390,33 +390,10 @@ func (t *tun) Name() string {
|
||||
return t.Device
|
||||
}
|
||||
|
||||
func (t *tun) NewMultiQueueReader() (BatchReadWriter, error) {
|
||||
func (t *tun) NewMultiQueueReader() (io.ReadWriteCloser, error) {
|
||||
return nil, fmt.Errorf("TODO: multiqueue not implemented for netbsd")
|
||||
}
|
||||
|
||||
func (t *tun) BatchRead(bufs [][]byte, sizes []int) (int, error) {
|
||||
n, err := t.Read(bufs[0])
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
sizes[0] = n
|
||||
return 1, nil
|
||||
}
|
||||
|
||||
func (t *tun) WriteBatch(bufs [][]byte, offset int) (int, error) {
|
||||
for i, buf := range bufs {
|
||||
_, err := t.Write(buf[offset:])
|
||||
if err != nil {
|
||||
return i, err
|
||||
}
|
||||
}
|
||||
return len(bufs), nil
|
||||
}
|
||||
|
||||
func (t *tun) BatchSize() int {
|
||||
return 1
|
||||
}
|
||||
|
||||
func (t *tun) addRoutes(logErrors bool) error {
|
||||
routes := *t.Routes.Load()
|
||||
|
||||
|
||||
@@ -310,33 +310,10 @@ func (t *tun) Name() string {
|
||||
return t.Device
|
||||
}
|
||||
|
||||
func (t *tun) NewMultiQueueReader() (BatchReadWriter, error) {
|
||||
func (t *tun) NewMultiQueueReader() (io.ReadWriteCloser, error) {
|
||||
return nil, fmt.Errorf("TODO: multiqueue not implemented for openbsd")
|
||||
}
|
||||
|
||||
func (t *tun) BatchRead(bufs [][]byte, sizes []int) (int, error) {
|
||||
n, err := t.Read(bufs[0])
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
sizes[0] = n
|
||||
return 1, nil
|
||||
}
|
||||
|
||||
func (t *tun) WriteBatch(bufs [][]byte, offset int) (int, error) {
|
||||
for i, buf := range bufs {
|
||||
_, err := t.Write(buf[offset:])
|
||||
if err != nil {
|
||||
return i, err
|
||||
}
|
||||
}
|
||||
return len(bufs), nil
|
||||
}
|
||||
|
||||
func (t *tun) BatchSize() int {
|
||||
return 1
|
||||
}
|
||||
|
||||
func (t *tun) addRoutes(logErrors bool) error {
|
||||
routes := *t.Routes.Load()
|
||||
|
||||
|
||||
@@ -132,29 +132,6 @@ func (t *TestTun) Read(b []byte) (int, error) {
|
||||
return len(p), nil
|
||||
}
|
||||
|
||||
func (t *TestTun) NewMultiQueueReader() (BatchReadWriter, error) {
|
||||
func (t *TestTun) NewMultiQueueReader() (io.ReadWriteCloser, error) {
|
||||
return nil, fmt.Errorf("TODO: multiqueue not implemented")
|
||||
}
|
||||
|
||||
func (t *TestTun) BatchRead(bufs [][]byte, sizes []int) (int, error) {
|
||||
n, err := t.Read(bufs[0])
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
sizes[0] = n
|
||||
return 1, nil
|
||||
}
|
||||
|
||||
func (t *TestTun) WriteBatch(bufs [][]byte, offset int) (int, error) {
|
||||
for i, buf := range bufs {
|
||||
_, err := t.Write(buf[offset:])
|
||||
if err != nil {
|
||||
return i, err
|
||||
}
|
||||
}
|
||||
return len(bufs), nil
|
||||
}
|
||||
|
||||
func (t *TestTun) BatchSize() int {
|
||||
return 1
|
||||
}
|
||||
|
||||
@@ -6,6 +6,7 @@ package overlay
|
||||
import (
|
||||
"crypto"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/netip"
|
||||
"os"
|
||||
"path/filepath"
|
||||
@@ -233,36 +234,10 @@ func (t *winTun) Write(b []byte) (int, error) {
|
||||
return t.tun.Write(b, 0)
|
||||
}
|
||||
|
||||
func (t *winTun) NewMultiQueueReader() (BatchReadWriter, error) {
|
||||
func (t *winTun) NewMultiQueueReader() (io.ReadWriteCloser, error) {
|
||||
return nil, fmt.Errorf("TODO: multiqueue not implemented for windows")
|
||||
}
|
||||
|
||||
// BatchRead reads a single packet (batch size 1 for Windows)
|
||||
func (t *winTun) BatchRead(bufs [][]byte, sizes []int) (int, error) {
|
||||
n, err := t.Read(bufs[0])
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
sizes[0] = n
|
||||
return 1, nil
|
||||
}
|
||||
|
||||
// WriteBatch writes packets individually (no batching for Windows)
|
||||
func (t *winTun) WriteBatch(bufs [][]byte, offset int) (int, error) {
|
||||
for i, buf := range bufs {
|
||||
_, err := t.Write(buf[offset:])
|
||||
if err != nil {
|
||||
return i, err
|
||||
}
|
||||
}
|
||||
return len(bufs), nil
|
||||
}
|
||||
|
||||
// BatchSize returns 1 for Windows (no batching)
|
||||
func (t *winTun) BatchSize() int {
|
||||
return 1
|
||||
}
|
||||
|
||||
func (t *winTun) Close() error {
|
||||
// It seems that the Windows networking stack doesn't like it when we destroy interfaces that have active routes,
|
||||
// so to be certain, just remove everything before destroying.
|
||||
|
||||
@@ -46,36 +46,10 @@ func (d *UserDevice) RoutesFor(ip netip.Addr) routing.Gateways {
|
||||
return routing.Gateways{routing.NewGateway(ip, 1)}
|
||||
}
|
||||
|
||||
func (d *UserDevice) NewMultiQueueReader() (BatchReadWriter, error) {
|
||||
func (d *UserDevice) NewMultiQueueReader() (io.ReadWriteCloser, error) {
|
||||
return d, nil
|
||||
}
|
||||
|
||||
// BatchRead reads a single packet (batch size 1 for UserDevice)
|
||||
func (d *UserDevice) BatchRead(bufs [][]byte, sizes []int) (int, error) {
|
||||
n, err := d.Read(bufs[0])
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
sizes[0] = n
|
||||
return 1, nil
|
||||
}
|
||||
|
||||
// WriteBatch writes packets individually (no batching for UserDevice)
|
||||
func (d *UserDevice) WriteBatch(bufs [][]byte, offset int) (int, error) {
|
||||
for i, buf := range bufs {
|
||||
_, err := d.Write(buf[offset:])
|
||||
if err != nil {
|
||||
return i, err
|
||||
}
|
||||
}
|
||||
return len(bufs), nil
|
||||
}
|
||||
|
||||
// BatchSize returns 1 for UserDevice (no batching)
|
||||
func (d *UserDevice) BatchSize() int {
|
||||
return 1
|
||||
}
|
||||
|
||||
func (d *UserDevice) Pipe() (*io.PipeReader, *io.PipeWriter) {
|
||||
return d.inboundReader, d.outboundWriter
|
||||
}
|
||||
|
||||
83
pki.go
83
pki.go
@@ -100,62 +100,55 @@ func (p *PKI) reloadCerts(c *config.C, initial bool) *util.ContextualError {
|
||||
currentState := p.cs.Load()
|
||||
if newState.v1Cert != nil {
|
||||
if currentState.v1Cert == nil {
|
||||
//adding certs is fine, actually. Networks-in-common confirmed in newCertState().
|
||||
} else {
|
||||
// did IP in cert change? if so, don't set
|
||||
if !slices.Equal(currentState.v1Cert.Networks(), newState.v1Cert.Networks()) {
|
||||
return util.NewContextualError(
|
||||
"Networks in new cert was different from old",
|
||||
m{"new_networks": newState.v1Cert.Networks(), "old_networks": currentState.v1Cert.Networks(), "cert_version": cert.Version1},
|
||||
nil,
|
||||
)
|
||||
}
|
||||
|
||||
if currentState.v1Cert.Curve() != newState.v1Cert.Curve() {
|
||||
return util.NewContextualError(
|
||||
"Curve in new v1 cert was different from old",
|
||||
m{"new_curve": newState.v1Cert.Curve(), "old_curve": currentState.v1Cert.Curve(), "cert_version": cert.Version1},
|
||||
nil,
|
||||
)
|
||||
}
|
||||
return util.NewContextualError("v1 certificate was added, restart required", nil, err)
|
||||
}
|
||||
|
||||
// did IP in cert change? if so, don't set
|
||||
if !slices.Equal(currentState.v1Cert.Networks(), newState.v1Cert.Networks()) {
|
||||
return util.NewContextualError(
|
||||
"Networks in new cert was different from old",
|
||||
m{"new_networks": newState.v1Cert.Networks(), "old_networks": currentState.v1Cert.Networks()},
|
||||
nil,
|
||||
)
|
||||
}
|
||||
|
||||
if currentState.v1Cert.Curve() != newState.v1Cert.Curve() {
|
||||
return util.NewContextualError(
|
||||
"Curve in new cert was different from old",
|
||||
m{"new_curve": newState.v1Cert.Curve(), "old_curve": currentState.v1Cert.Curve()},
|
||||
nil,
|
||||
)
|
||||
}
|
||||
|
||||
} else if currentState.v1Cert != nil {
|
||||
//TODO: CERT-V2 we should be able to tear this down
|
||||
return util.NewContextualError("v1 certificate was removed, restart required", nil, err)
|
||||
}
|
||||
|
||||
if newState.v2Cert != nil {
|
||||
if currentState.v2Cert == nil {
|
||||
//adding certs is fine, actually
|
||||
} else {
|
||||
// did IP in cert change? if so, don't set
|
||||
if !slices.Equal(currentState.v2Cert.Networks(), newState.v2Cert.Networks()) {
|
||||
return util.NewContextualError(
|
||||
"Networks in new cert was different from old",
|
||||
m{"new_networks": newState.v2Cert.Networks(), "old_networks": currentState.v2Cert.Networks(), "cert_version": cert.Version2},
|
||||
nil,
|
||||
)
|
||||
}
|
||||
|
||||
if currentState.v2Cert.Curve() != newState.v2Cert.Curve() {
|
||||
return util.NewContextualError(
|
||||
"Curve in new cert was different from old",
|
||||
m{"new_curve": newState.v2Cert.Curve(), "old_curve": currentState.v2Cert.Curve(), "cert_version": cert.Version2},
|
||||
nil,
|
||||
)
|
||||
}
|
||||
return util.NewContextualError("v2 certificate was added, restart required", nil, err)
|
||||
}
|
||||
|
||||
} else if currentState.v2Cert != nil {
|
||||
//newState.v1Cert is non-nil bc empty certstates aren't permitted
|
||||
if newState.v1Cert == nil {
|
||||
return util.NewContextualError("v1 and v2 certs are nil, this should be impossible", nil, err)
|
||||
}
|
||||
//if we're going to v1-only, we need to make sure we didn't orphan any v2-cert vpnaddrs
|
||||
if !slices.Equal(currentState.v2Cert.Networks(), newState.v1Cert.Networks()) {
|
||||
// did IP in cert change? if so, don't set
|
||||
if !slices.Equal(currentState.v2Cert.Networks(), newState.v2Cert.Networks()) {
|
||||
return util.NewContextualError(
|
||||
"Removing a V2 cert is not permitted unless it has identical networks to the new V1 cert",
|
||||
m{"new_v1_networks": newState.v1Cert.Networks(), "old_v2_networks": currentState.v2Cert.Networks()},
|
||||
"Networks in new cert was different from old",
|
||||
m{"new_networks": newState.v2Cert.Networks(), "old_networks": currentState.v2Cert.Networks()},
|
||||
nil,
|
||||
)
|
||||
}
|
||||
|
||||
if currentState.v2Cert.Curve() != newState.v2Cert.Curve() {
|
||||
return util.NewContextualError(
|
||||
"Curve in new cert was different from old",
|
||||
m{"new_curve": newState.v2Cert.Curve(), "old_curve": currentState.v2Cert.Curve()},
|
||||
nil,
|
||||
)
|
||||
}
|
||||
|
||||
} else if currentState.v2Cert != nil {
|
||||
return util.NewContextualError("v2 certificate was removed, restart required", nil, err)
|
||||
}
|
||||
|
||||
// Cipher cant be hot swapped so just leave it at what it was before
|
||||
|
||||
1
stats.go
1
stats.go
@@ -6,7 +6,6 @@ import (
|
||||
"log"
|
||||
"net"
|
||||
"net/http"
|
||||
_ "net/http/pprof"
|
||||
"runtime"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
18
udp/conn.go
18
udp/conn.go
@@ -13,21 +13,12 @@ type EncReader func(
|
||||
payload []byte,
|
||||
)
|
||||
|
||||
type EncBatchReader func(
|
||||
addrs []netip.AddrPort,
|
||||
payloads [][]byte,
|
||||
count int,
|
||||
)
|
||||
|
||||
type Conn interface {
|
||||
Rebind() error
|
||||
LocalAddr() (netip.AddrPort, error)
|
||||
ListenOut(r EncReader)
|
||||
ListenOutBatch(r EncBatchReader)
|
||||
WriteTo(b []byte, addr netip.AddrPort) error
|
||||
WriteMulti(packets [][]byte, addrs []netip.AddrPort) (int, error)
|
||||
ReloadConfig(c *config.C)
|
||||
BatchSize() int
|
||||
Close() error
|
||||
}
|
||||
|
||||
@@ -42,21 +33,12 @@ func (NoopConn) LocalAddr() (netip.AddrPort, error) {
|
||||
func (NoopConn) ListenOut(_ EncReader) {
|
||||
return
|
||||
}
|
||||
func (NoopConn) ListenOutBatch(_ EncBatchReader) {
|
||||
return
|
||||
}
|
||||
func (NoopConn) WriteTo(_ []byte, _ netip.AddrPort) error {
|
||||
return nil
|
||||
}
|
||||
func (NoopConn) WriteMulti(_ [][]byte, _ []netip.AddrPort) (int, error) {
|
||||
return 0, nil
|
||||
}
|
||||
func (NoopConn) ReloadConfig(_ *config.C) {
|
||||
return
|
||||
}
|
||||
func (NoopConn) BatchSize() int {
|
||||
return 1
|
||||
}
|
||||
func (NoopConn) Close() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -140,17 +140,6 @@ func (u *StdConn) WriteTo(b []byte, ap netip.AddrPort) error {
|
||||
}
|
||||
}
|
||||
|
||||
// WriteMulti sends multiple packets - fallback implementation without sendmmsg
|
||||
func (u *StdConn) WriteMulti(packets [][]byte, addrs []netip.AddrPort) (int, error) {
|
||||
for i := range packets {
|
||||
err := u.WriteTo(packets[i], addrs[i])
|
||||
if err != nil {
|
||||
return i, err
|
||||
}
|
||||
}
|
||||
return len(packets), nil
|
||||
}
|
||||
|
||||
func (u *StdConn) LocalAddr() (netip.AddrPort, error) {
|
||||
a := u.UDPConn.LocalAddr()
|
||||
|
||||
@@ -195,34 +184,6 @@ func (u *StdConn) ListenOut(r EncReader) {
|
||||
}
|
||||
}
|
||||
|
||||
// ListenOutBatch - fallback to single-packet reads for Darwin
|
||||
func (u *StdConn) ListenOutBatch(r EncBatchReader) {
|
||||
buffer := make([]byte, MTU)
|
||||
addrs := make([]netip.AddrPort, 1)
|
||||
payloads := make([][]byte, 1)
|
||||
|
||||
for {
|
||||
// Just read one packet at a time and call batch callback with count=1
|
||||
n, rua, err := u.ReadFromUDPAddrPort(buffer)
|
||||
if err != nil {
|
||||
if errors.Is(err, net.ErrClosed) {
|
||||
u.l.WithError(err).Debug("udp socket is closed, exiting read loop")
|
||||
return
|
||||
}
|
||||
|
||||
u.l.WithError(err).Error("unexpected udp socket receive error")
|
||||
}
|
||||
|
||||
addrs[0] = netip.AddrPortFrom(rua.Addr().Unmap(), rua.Port())
|
||||
payloads[0] = buffer[:n]
|
||||
r(addrs, payloads, 1)
|
||||
}
|
||||
}
|
||||
|
||||
func (u *StdConn) BatchSize() int {
|
||||
return 1
|
||||
}
|
||||
|
||||
func (u *StdConn) Rebind() error {
|
||||
var err error
|
||||
if u.isV4 {
|
||||
|
||||
@@ -85,42 +85,3 @@ func (u *GenericConn) ListenOut(r EncReader) {
|
||||
r(netip.AddrPortFrom(rua.Addr().Unmap(), rua.Port()), buffer[:n])
|
||||
}
|
||||
}
|
||||
|
||||
// ListenOutBatch - fallback to single-packet reads for generic platforms
|
||||
func (u *GenericConn) ListenOutBatch(r EncBatchReader) {
|
||||
buffer := make([]byte, MTU)
|
||||
addrs := make([]netip.AddrPort, 1)
|
||||
payloads := make([][]byte, 1)
|
||||
|
||||
for {
|
||||
// Just read one packet at a time and call batch callback with count=1
|
||||
n, rua, err := u.ReadFromUDPAddrPort(buffer)
|
||||
if err != nil {
|
||||
u.l.WithError(err).Debug("udp socket is closed, exiting read loop")
|
||||
return
|
||||
}
|
||||
|
||||
addrs[0] = netip.AddrPortFrom(rua.Addr().Unmap(), rua.Port())
|
||||
payloads[0] = buffer[:n]
|
||||
r(addrs, payloads, 1)
|
||||
}
|
||||
}
|
||||
|
||||
// WriteMulti sends multiple packets - fallback implementation
|
||||
func (u *GenericConn) WriteMulti(packets [][]byte, addrs []netip.AddrPort) (int, error) {
|
||||
for i := range packets {
|
||||
err := u.WriteTo(packets[i], addrs[i])
|
||||
if err != nil {
|
||||
return i, err
|
||||
}
|
||||
}
|
||||
return len(packets), nil
|
||||
}
|
||||
|
||||
func (u *GenericConn) BatchSize() int {
|
||||
return 1
|
||||
}
|
||||
|
||||
func (u *GenericConn) Rebind() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
204
udp/udp_linux.go
204
udp/udp_linux.go
@@ -22,11 +22,6 @@ type StdConn struct {
|
||||
isV4 bool
|
||||
l *logrus.Logger
|
||||
batch int
|
||||
|
||||
// Pre-allocated buffers for batch writes (sized for IPv6, works for both)
|
||||
writeMsgs []rawMessage
|
||||
writeIovecs []iovec
|
||||
writeNames [][]byte
|
||||
}
|
||||
|
||||
func maybeIPV4(ip net.IP) (net.IP, bool) {
|
||||
@@ -74,26 +69,7 @@ func NewListener(l *logrus.Logger, ip netip.Addr, port int, multi bool, batch in
|
||||
return nil, fmt.Errorf("unable to bind to socket: %s", err)
|
||||
}
|
||||
|
||||
c := &StdConn{sysFd: fd, isV4: ip.Is4(), l: l, batch: batch}
|
||||
|
||||
// Pre-allocate write message structures for batching (sized for IPv6, works for both)
|
||||
c.writeMsgs = make([]rawMessage, batch)
|
||||
c.writeIovecs = make([]iovec, batch)
|
||||
c.writeNames = make([][]byte, batch)
|
||||
|
||||
for i := range c.writeMsgs {
|
||||
// Allocate for IPv6 size (larger than IPv4, works for both)
|
||||
c.writeNames[i] = make([]byte, unix.SizeofSockaddrInet6)
|
||||
|
||||
// Point to the iovec in the slice
|
||||
c.writeMsgs[i].Hdr.Iov = &c.writeIovecs[i]
|
||||
c.writeMsgs[i].Hdr.Iovlen = 1
|
||||
|
||||
c.writeMsgs[i].Hdr.Name = &c.writeNames[i][0]
|
||||
// Namelen will be set appropriately in writeMulti4/writeMulti6
|
||||
}
|
||||
|
||||
return c, err
|
||||
return &StdConn{sysFd: fd, isV4: ip.Is4(), l: l, batch: batch}, err
|
||||
}
|
||||
|
||||
func (u *StdConn) Rebind() error {
|
||||
@@ -151,8 +127,6 @@ func (u *StdConn) ListenOut(r EncReader) {
|
||||
read = u.ReadSingle
|
||||
}
|
||||
|
||||
udpBatchHist := metrics.GetOrRegisterHistogram("batch.udp_read_size", nil, metrics.NewUniformSample(1024))
|
||||
|
||||
for {
|
||||
n, err := read(msgs)
|
||||
if err != nil {
|
||||
@@ -160,8 +134,6 @@ func (u *StdConn) ListenOut(r EncReader) {
|
||||
return
|
||||
}
|
||||
|
||||
udpBatchHist.Update(int64(n))
|
||||
|
||||
for i := 0; i < n; i++ {
|
||||
// Its ok to skip the ok check here, the slicing is the only error that can occur and it will panic
|
||||
if u.isV4 {
|
||||
@@ -174,46 +146,6 @@ func (u *StdConn) ListenOut(r EncReader) {
|
||||
}
|
||||
}
|
||||
|
||||
func (u *StdConn) ListenOutBatch(r EncBatchReader) {
|
||||
var ip netip.Addr
|
||||
|
||||
msgs, buffers, names := u.PrepareRawMessages(u.batch)
|
||||
read := u.ReadMulti
|
||||
if u.batch == 1 {
|
||||
read = u.ReadSingle
|
||||
}
|
||||
|
||||
udpBatchHist := metrics.GetOrRegisterHistogram("batch.udp_read_size", nil, metrics.NewUniformSample(1024))
|
||||
|
||||
// Pre-allocate slices for batch callback
|
||||
addrs := make([]netip.AddrPort, u.batch)
|
||||
payloads := make([][]byte, u.batch)
|
||||
|
||||
for {
|
||||
n, err := read(msgs)
|
||||
if err != nil {
|
||||
u.l.WithError(err).Debug("udp socket is closed, exiting read loop")
|
||||
return
|
||||
}
|
||||
|
||||
udpBatchHist.Update(int64(n))
|
||||
|
||||
// Prepare batch data
|
||||
for i := 0; i < n; i++ {
|
||||
if u.isV4 {
|
||||
ip, _ = netip.AddrFromSlice(names[i][4:8])
|
||||
} else {
|
||||
ip, _ = netip.AddrFromSlice(names[i][8:24])
|
||||
}
|
||||
addrs[i] = netip.AddrPortFrom(ip.Unmap(), binary.BigEndian.Uint16(names[i][2:4]))
|
||||
payloads[i] = buffers[i][:msgs[i].Len]
|
||||
}
|
||||
|
||||
// Call batch callback with all packets
|
||||
r(addrs, payloads, n)
|
||||
}
|
||||
}
|
||||
|
||||
func (u *StdConn) ReadSingle(msgs []rawMessage) (int, error) {
|
||||
for {
|
||||
n, _, err := unix.Syscall6(
|
||||
@@ -262,19 +194,6 @@ func (u *StdConn) WriteTo(b []byte, ip netip.AddrPort) error {
|
||||
return u.writeTo6(b, ip)
|
||||
}
|
||||
|
||||
func (u *StdConn) WriteMulti(packets [][]byte, addrs []netip.AddrPort) (int, error) {
|
||||
if len(packets) != len(addrs) {
|
||||
return 0, fmt.Errorf("packets and addrs length mismatch")
|
||||
}
|
||||
if len(packets) == 0 {
|
||||
return 0, nil
|
||||
}
|
||||
if u.isV4 {
|
||||
return u.writeMulti4(packets, addrs)
|
||||
}
|
||||
return u.writeMulti6(packets, addrs)
|
||||
}
|
||||
|
||||
func (u *StdConn) writeTo6(b []byte, ip netip.AddrPort) error {
|
||||
var rsa unix.RawSockaddrInet6
|
||||
rsa.Family = unix.AF_INET6
|
||||
@@ -329,123 +248,6 @@ func (u *StdConn) writeTo4(b []byte, ip netip.AddrPort) error {
|
||||
}
|
||||
}
|
||||
|
||||
func (u *StdConn) writeMulti4(packets [][]byte, addrs []netip.AddrPort) (int, error) {
|
||||
sent := 0
|
||||
for sent < len(packets) {
|
||||
// Determine batch size based on remaining packets and buffer capacity
|
||||
batchSize := len(packets) - sent
|
||||
if batchSize > len(u.writeMsgs) {
|
||||
batchSize = len(u.writeMsgs)
|
||||
}
|
||||
|
||||
// Use pre-allocated buffers
|
||||
msgs := u.writeMsgs[:batchSize]
|
||||
iovecs := u.writeIovecs[:batchSize]
|
||||
names := u.writeNames[:batchSize]
|
||||
|
||||
// Setup message structures for this batch
|
||||
for i := 0; i < batchSize; i++ {
|
||||
pktIdx := sent + i
|
||||
if !addrs[pktIdx].Addr().Is4() {
|
||||
return sent + i, ErrInvalidIPv6RemoteForSocket
|
||||
}
|
||||
|
||||
// Setup the packet buffer
|
||||
iovecs[i].Base = &packets[pktIdx][0]
|
||||
iovecs[i].Len = uint(len(packets[pktIdx]))
|
||||
|
||||
// Setup the destination address
|
||||
rsa := (*unix.RawSockaddrInet4)(unsafe.Pointer(&names[i][0]))
|
||||
rsa.Family = unix.AF_INET
|
||||
rsa.Addr = addrs[pktIdx].Addr().As4()
|
||||
binary.BigEndian.PutUint16((*[2]byte)(unsafe.Pointer(&rsa.Port))[:], addrs[pktIdx].Port())
|
||||
|
||||
// Set the appropriate address length for IPv4
|
||||
msgs[i].Hdr.Namelen = unix.SizeofSockaddrInet4
|
||||
}
|
||||
|
||||
// Send this batch
|
||||
nsent, _, err := unix.Syscall6(
|
||||
unix.SYS_SENDMMSG,
|
||||
uintptr(u.sysFd),
|
||||
uintptr(unsafe.Pointer(&msgs[0])),
|
||||
uintptr(batchSize),
|
||||
0,
|
||||
0,
|
||||
0,
|
||||
)
|
||||
|
||||
if err != 0 {
|
||||
return sent + int(nsent), &net.OpError{Op: "sendmmsg", Err: err}
|
||||
}
|
||||
|
||||
sent += int(nsent)
|
||||
if int(nsent) < batchSize {
|
||||
// Couldn't send all packets in batch, return what we sent
|
||||
return sent, nil
|
||||
}
|
||||
}
|
||||
|
||||
return sent, nil
|
||||
}
|
||||
|
||||
func (u *StdConn) writeMulti6(packets [][]byte, addrs []netip.AddrPort) (int, error) {
|
||||
sent := 0
|
||||
for sent < len(packets) {
|
||||
// Determine batch size based on remaining packets and buffer capacity
|
||||
batchSize := len(packets) - sent
|
||||
if batchSize > len(u.writeMsgs) {
|
||||
batchSize = len(u.writeMsgs)
|
||||
}
|
||||
|
||||
// Use pre-allocated buffers
|
||||
msgs := u.writeMsgs[:batchSize]
|
||||
iovecs := u.writeIovecs[:batchSize]
|
||||
names := u.writeNames[:batchSize]
|
||||
|
||||
// Setup message structures for this batch
|
||||
for i := 0; i < batchSize; i++ {
|
||||
pktIdx := sent + i
|
||||
|
||||
// Setup the packet buffer
|
||||
iovecs[i].Base = &packets[pktIdx][0]
|
||||
iovecs[i].Len = uint(len(packets[pktIdx]))
|
||||
|
||||
// Setup the destination address
|
||||
rsa := (*unix.RawSockaddrInet6)(unsafe.Pointer(&names[i][0]))
|
||||
rsa.Family = unix.AF_INET6
|
||||
rsa.Addr = addrs[pktIdx].Addr().As16()
|
||||
binary.BigEndian.PutUint16((*[2]byte)(unsafe.Pointer(&rsa.Port))[:], addrs[pktIdx].Port())
|
||||
|
||||
// Set the appropriate address length for IPv6
|
||||
msgs[i].Hdr.Namelen = unix.SizeofSockaddrInet6
|
||||
}
|
||||
|
||||
// Send this batch
|
||||
nsent, _, err := unix.Syscall6(
|
||||
unix.SYS_SENDMMSG,
|
||||
uintptr(u.sysFd),
|
||||
uintptr(unsafe.Pointer(&msgs[0])),
|
||||
uintptr(batchSize),
|
||||
0,
|
||||
0,
|
||||
0,
|
||||
)
|
||||
|
||||
if err != 0 {
|
||||
return sent + int(nsent), &net.OpError{Op: "sendmmsg", Err: err}
|
||||
}
|
||||
|
||||
sent += int(nsent)
|
||||
if int(nsent) < batchSize {
|
||||
// Couldn't send all packets in batch, return what we sent
|
||||
return sent, nil
|
||||
}
|
||||
}
|
||||
|
||||
return sent, nil
|
||||
}
|
||||
|
||||
func (u *StdConn) ReloadConfig(c *config.C) {
|
||||
b := c.GetInt("listen.read_buffer", 0)
|
||||
if b > 0 {
|
||||
@@ -503,10 +305,6 @@ func (u *StdConn) getMemInfo(meminfo *[unix.SK_MEMINFO_VARS]uint32) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (u *StdConn) BatchSize() int {
|
||||
return u.batch
|
||||
}
|
||||
|
||||
func (u *StdConn) Close() error {
|
||||
return syscall.Close(u.sysFd)
|
||||
}
|
||||
|
||||
@@ -12,7 +12,7 @@ import (
|
||||
|
||||
type iovec struct {
|
||||
Base *byte
|
||||
Len uint
|
||||
Len uint32
|
||||
}
|
||||
|
||||
type msghdr struct {
|
||||
@@ -40,7 +40,7 @@ func (u *StdConn) PrepareRawMessages(n int) ([]rawMessage, [][]byte, [][]byte) {
|
||||
names[i] = make([]byte, unix.SizeofSockaddrInet6)
|
||||
|
||||
vs := []iovec{
|
||||
{Base: &buffers[i][0], Len: uint(len(buffers[i]))},
|
||||
{Base: &buffers[i][0], Len: uint32(len(buffers[i]))},
|
||||
}
|
||||
|
||||
msgs[i].Hdr.Iov = &vs[0]
|
||||
|
||||
@@ -12,7 +12,7 @@ import (
|
||||
|
||||
type iovec struct {
|
||||
Base *byte
|
||||
Len uint
|
||||
Len uint64
|
||||
}
|
||||
|
||||
type msghdr struct {
|
||||
@@ -43,7 +43,7 @@ func (u *StdConn) PrepareRawMessages(n int) ([]rawMessage, [][]byte, [][]byte) {
|
||||
names[i] = make([]byte, unix.SizeofSockaddrInet6)
|
||||
|
||||
vs := []iovec{
|
||||
{Base: &buffers[i][0], Len: uint(len(buffers[i]))},
|
||||
{Base: &buffers[i][0], Len: uint64(len(buffers[i]))},
|
||||
}
|
||||
|
||||
msgs[i].Hdr.Iov = &vs[0]
|
||||
|
||||
@@ -116,31 +116,6 @@ func (u *TesterConn) ListenOut(r EncReader) {
|
||||
}
|
||||
}
|
||||
|
||||
func (u *TesterConn) ListenOutBatch(r EncBatchReader) {
|
||||
addrs := make([]netip.AddrPort, 1)
|
||||
payloads := make([][]byte, 1)
|
||||
|
||||
for {
|
||||
p, ok := <-u.RxPackets
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
addrs[0] = p.From
|
||||
payloads[0] = p.Data
|
||||
r(addrs, payloads, 1)
|
||||
}
|
||||
}
|
||||
|
||||
func (u *TesterConn) WriteMulti(packets [][]byte, addrs []netip.AddrPort) (int, error) {
|
||||
for i := range packets {
|
||||
err := u.WriteTo(packets[i], addrs[i])
|
||||
if err != nil {
|
||||
return i, err
|
||||
}
|
||||
}
|
||||
return len(packets), nil
|
||||
}
|
||||
|
||||
func (u *TesterConn) ReloadConfig(*config.C) {}
|
||||
|
||||
func NewUDPStatsEmitter(_ []Conn) func() {
|
||||
@@ -152,10 +127,6 @@ func (u *TesterConn) LocalAddr() (netip.AddrPort, error) {
|
||||
return u.Addr, nil
|
||||
}
|
||||
|
||||
func (u *TesterConn) BatchSize() int {
|
||||
return 1
|
||||
}
|
||||
|
||||
func (u *TesterConn) Rebind() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user