mirror of
https://github.com/slackhq/nebula.git
synced 2025-11-22 16:34:25 +01:00
Compare commits
11 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
064831cf21 | ||
|
|
9ded90c6e8 | ||
|
|
19600f257f | ||
|
|
3e2a6e0a5d | ||
|
|
bc62f5ec82 | ||
|
|
012fcf40fe | ||
|
|
ad319b964d | ||
|
|
f42878c5fc | ||
|
|
f2b3ef4b3e | ||
|
|
c3ec96d9c2 | ||
|
|
01909f4715 |
18
CHANGELOG.md
18
CHANGELOG.md
@@ -7,30 +7,12 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
|
||||
|
||||
## [Unreleased]
|
||||
|
||||
### Added
|
||||
|
||||
- Experimental Linux UDP offload support: enable `listen.enable_gso` and
|
||||
`listen.enable_gro` to activate UDP_SEGMENT batching and GRO receive
|
||||
splitting. Includes automatic capability probing, per-packet fallbacks, and
|
||||
runtime metrics/logs for visibility.
|
||||
- Optional Linux TUN `virtio_net_hdr` support: set `tun.enable_vnet_hdr` to
|
||||
have Nebula negotiate VNET headers and offload flags so future batches can
|
||||
be delivered to the kernel with metadata instead of per-packet writes.
|
||||
- Linux UDP send sharding can now be tuned with `listen.send_shards`; defaults
|
||||
to `GOMAXPROCS` but can be increased to stripe heavy peers across more
|
||||
goroutines.
|
||||
|
||||
### Changed
|
||||
|
||||
- `default_local_cidr_any` now defaults to false, meaning that any firewall rule
|
||||
intended to target an `unsafe_routes` entry must explicitly declare it via the
|
||||
`local_cidr` field. This is almost always the intended behavior. This flag is
|
||||
deprecated and will be removed in a future release.
|
||||
- UDP receive path now enqueues into per-worker lock-free rings, restoring the
|
||||
`listen.decrypt_workers`/`listen.decrypt_queue_depth` tuning knobs while
|
||||
eliminating the mutex contention from the old shared channel.
|
||||
- Increased replay protection window to 32k packets so high-throughput links
|
||||
tolerate larger bursts of reordering without tripping the anti-replay logic.
|
||||
|
||||
## [1.9.4] - 2024-09-09
|
||||
|
||||
|
||||
@@ -114,6 +114,33 @@ func NewTestCert(v cert.Version, curve cert.Curve, ca cert.Certificate, key []by
|
||||
return c, pub, cert.MarshalPrivateKeyToPEM(curve, priv), pem
|
||||
}
|
||||
|
||||
func NewTestCertDifferentVersion(c cert.Certificate, v cert.Version, ca cert.Certificate, key []byte) (cert.Certificate, []byte) {
|
||||
nc := &cert.TBSCertificate{
|
||||
Version: v,
|
||||
Curve: c.Curve(),
|
||||
Name: c.Name(),
|
||||
Networks: c.Networks(),
|
||||
UnsafeNetworks: c.UnsafeNetworks(),
|
||||
Groups: c.Groups(),
|
||||
NotBefore: time.Unix(c.NotBefore().Unix(), 0),
|
||||
NotAfter: time.Unix(c.NotAfter().Unix(), 0),
|
||||
PublicKey: c.PublicKey(),
|
||||
IsCA: false,
|
||||
}
|
||||
|
||||
c, err := nc.Sign(ca, ca.Curve(), key)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
pem, err := c.MarshalPEM()
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
return c, pem
|
||||
}
|
||||
|
||||
func X25519Keypair() ([]byte, []byte) {
|
||||
privkey := make([]byte, 32)
|
||||
if _, err := io.ReadFull(rand.Reader, privkey); err != nil {
|
||||
|
||||
@@ -65,8 +65,16 @@ func main() {
|
||||
}
|
||||
|
||||
if !*configTest {
|
||||
ctrl.Start()
|
||||
ctrl.ShutdownBlock()
|
||||
wait, err := ctrl.Start()
|
||||
if err != nil {
|
||||
util.LogWithContextIfNeeded("Error while running", err, l)
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
go ctrl.ShutdownBlock()
|
||||
wait()
|
||||
|
||||
l.Info("Goodbye")
|
||||
}
|
||||
|
||||
os.Exit(0)
|
||||
|
||||
@@ -3,6 +3,9 @@ package main
|
||||
import (
|
||||
"flag"
|
||||
"fmt"
|
||||
"log"
|
||||
"net/http"
|
||||
_ "net/http/pprof"
|
||||
"os"
|
||||
|
||||
"github.com/sirupsen/logrus"
|
||||
@@ -58,10 +61,22 @@ func main() {
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
go func() {
|
||||
log.Println(http.ListenAndServe("0.0.0.0:6060", nil))
|
||||
}()
|
||||
|
||||
if !*configTest {
|
||||
ctrl.Start()
|
||||
wait, err := ctrl.Start()
|
||||
if err != nil {
|
||||
util.LogWithContextIfNeeded("Error while running", err, l)
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
go ctrl.ShutdownBlock()
|
||||
notifyReady(l)
|
||||
ctrl.ShutdownBlock()
|
||||
wait()
|
||||
|
||||
l.Info("Goodbye")
|
||||
}
|
||||
|
||||
os.Exit(0)
|
||||
|
||||
@@ -354,7 +354,6 @@ func (cm *connectionManager) makeTrafficDecision(localIndex uint32, now time.Tim
|
||||
|
||||
if mainHostInfo {
|
||||
decision = tryRehandshake
|
||||
|
||||
} else {
|
||||
if cm.shouldSwapPrimary(hostinfo) {
|
||||
decision = swapPrimary
|
||||
@@ -461,6 +460,10 @@ func (cm *connectionManager) shouldSwapPrimary(current *HostInfo) bool {
|
||||
}
|
||||
|
||||
crt := cm.intf.pki.getCertState().getCertificate(current.ConnectionState.myCert.Version())
|
||||
if crt == nil {
|
||||
//my cert was reloaded away. We should definitely swap from this tunnel
|
||||
return true
|
||||
}
|
||||
// If this tunnel is using the latest certificate then we should swap it to primary for a bit and see if things
|
||||
// settle down.
|
||||
return bytes.Equal(current.ConnectionState.myCert.Signature(), crt.Signature())
|
||||
@@ -475,31 +478,34 @@ func (cm *connectionManager) swapPrimary(current, primary *HostInfo) {
|
||||
cm.hostMap.Unlock()
|
||||
}
|
||||
|
||||
// isInvalidCertificate will check if we should destroy a tunnel if pki.disconnect_invalid is true and
|
||||
// the certificate is no longer valid. Block listed certificates will skip the pki.disconnect_invalid
|
||||
// check and return true.
|
||||
// isInvalidCertificate decides if we should destroy a tunnel.
|
||||
// returns true if pki.disconnect_invalid is true and the certificate is no longer valid.
|
||||
// Blocklisted certificates will skip the pki.disconnect_invalid check and return true.
|
||||
func (cm *connectionManager) isInvalidCertificate(now time.Time, hostinfo *HostInfo) bool {
|
||||
remoteCert := hostinfo.GetCert()
|
||||
if remoteCert == nil {
|
||||
return false
|
||||
return false //don't tear down tunnels for handshakes in progress
|
||||
}
|
||||
|
||||
caPool := cm.intf.pki.GetCAPool()
|
||||
err := caPool.VerifyCachedCertificate(now, remoteCert)
|
||||
if err == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
if !cm.intf.disconnectInvalid.Load() && err != cert.ErrBlockListed {
|
||||
return false //cert is still valid! yay!
|
||||
} else if err == cert.ErrBlockListed { //avoiding errors.Is for speed
|
||||
// Block listed certificates should always be disconnected
|
||||
hostinfo.logger(cm.l).WithError(err).
|
||||
WithField("fingerprint", remoteCert.Fingerprint).
|
||||
Info("Remote certificate is blocked, tearing down the tunnel")
|
||||
return true
|
||||
} else if cm.intf.disconnectInvalid.Load() {
|
||||
hostinfo.logger(cm.l).WithError(err).
|
||||
WithField("fingerprint", remoteCert.Fingerprint).
|
||||
Info("Remote certificate is no longer valid, tearing down the tunnel")
|
||||
return true
|
||||
} else {
|
||||
//if we reach here, the cert is no longer valid, but we're configured to keep tunnels from now-invalid certs open
|
||||
return false
|
||||
}
|
||||
|
||||
hostinfo.logger(cm.l).WithError(err).
|
||||
WithField("fingerprint", remoteCert.Fingerprint).
|
||||
Info("Remote certificate is no longer valid, tearing down the tunnel")
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
func (cm *connectionManager) sendPunch(hostinfo *HostInfo) {
|
||||
@@ -530,15 +536,45 @@ func (cm *connectionManager) sendPunch(hostinfo *HostInfo) {
|
||||
func (cm *connectionManager) tryRehandshake(hostinfo *HostInfo) {
|
||||
cs := cm.intf.pki.getCertState()
|
||||
curCrt := hostinfo.ConnectionState.myCert
|
||||
myCrt := cs.getCertificate(curCrt.Version())
|
||||
if curCrt.Version() >= cs.initiatingVersion && bytes.Equal(curCrt.Signature(), myCrt.Signature()) == true {
|
||||
// The current tunnel is using the latest certificate and version, no need to rehandshake.
|
||||
curCrtVersion := curCrt.Version()
|
||||
myCrt := cs.getCertificate(curCrtVersion)
|
||||
if myCrt == nil {
|
||||
cm.l.WithField("vpnAddrs", hostinfo.vpnAddrs).
|
||||
WithField("version", curCrtVersion).
|
||||
WithField("reason", "local certificate removed").
|
||||
Info("Re-handshaking with remote")
|
||||
cm.intf.handshakeManager.StartHandshake(hostinfo.vpnAddrs[0], nil)
|
||||
return
|
||||
}
|
||||
peerCrt := hostinfo.ConnectionState.peerCert
|
||||
if peerCrt != nil && curCrtVersion < peerCrt.Certificate.Version() {
|
||||
// if our certificate version is less than theirs, and we have a matching version available, rehandshake?
|
||||
if cs.getCertificate(peerCrt.Certificate.Version()) != nil {
|
||||
cm.l.WithField("vpnAddrs", hostinfo.vpnAddrs).
|
||||
WithField("version", curCrtVersion).
|
||||
WithField("peerVersion", peerCrt.Certificate.Version()).
|
||||
WithField("reason", "local certificate version lower than peer, attempting to correct").
|
||||
Info("Re-handshaking with remote")
|
||||
cm.intf.handshakeManager.StartHandshake(hostinfo.vpnAddrs[0], func(hh *HandshakeHostInfo) {
|
||||
hh.initiatingVersionOverride = peerCrt.Certificate.Version()
|
||||
})
|
||||
return
|
||||
}
|
||||
}
|
||||
if !bytes.Equal(curCrt.Signature(), myCrt.Signature()) {
|
||||
cm.l.WithField("vpnAddrs", hostinfo.vpnAddrs).
|
||||
WithField("reason", "local certificate is not current").
|
||||
Info("Re-handshaking with remote")
|
||||
|
||||
cm.l.WithField("vpnAddrs", hostinfo.vpnAddrs).
|
||||
WithField("reason", "local certificate is not current").
|
||||
Info("Re-handshaking with remote")
|
||||
cm.intf.handshakeManager.StartHandshake(hostinfo.vpnAddrs[0], nil)
|
||||
return
|
||||
}
|
||||
if curCrtVersion < cs.initiatingVersion {
|
||||
cm.l.WithField("vpnAddrs", hostinfo.vpnAddrs).
|
||||
WithField("reason", "current cert version < pki.initiatingVersion").
|
||||
Info("Re-handshaking with remote")
|
||||
|
||||
cm.intf.handshakeManager.StartHandshake(hostinfo.vpnAddrs[0], nil)
|
||||
cm.intf.handshakeManager.StartHandshake(hostinfo.vpnAddrs[0], nil)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
@@ -13,10 +13,7 @@ import (
|
||||
"github.com/slackhq/nebula/noiseutil"
|
||||
)
|
||||
|
||||
// ReplayWindow controls the size of the sliding window used to detect replays.
|
||||
// High-bandwidth links with GRO/GSO can reorder more than a thousand packets in
|
||||
// flight, so keep this comfortably above the largest expected burst.
|
||||
const ReplayWindow = 32768
|
||||
const ReplayWindow = 1024
|
||||
|
||||
type ConnectionState struct {
|
||||
eKey *NebulaCipherState
|
||||
|
||||
56
control.go
56
control.go
@@ -2,9 +2,11 @@ package nebula
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"net/netip"
|
||||
"os"
|
||||
"os/signal"
|
||||
"sync"
|
||||
"syscall"
|
||||
|
||||
"github.com/sirupsen/logrus"
|
||||
@@ -13,6 +15,16 @@ import (
|
||||
"github.com/slackhq/nebula/overlay"
|
||||
)
|
||||
|
||||
type RunState int
|
||||
|
||||
const (
|
||||
Stopped RunState = 0 // The control has yet to be started
|
||||
Started RunState = 1 // The control has been started
|
||||
Stopping RunState = 2 // The control is stopping
|
||||
)
|
||||
|
||||
var ErrAlreadyStarted = errors.New("nebula is already started")
|
||||
|
||||
// Every interaction here needs to take extra care to copy memory and not return or use arguments "as is" when touching
|
||||
// core. This means copying IP objects, slices, de-referencing pointers and taking the actual value, etc
|
||||
|
||||
@@ -26,6 +38,9 @@ type controlHostLister interface {
|
||||
}
|
||||
|
||||
type Control struct {
|
||||
stateLock sync.Mutex
|
||||
state RunState
|
||||
|
||||
f *Interface
|
||||
l *logrus.Logger
|
||||
ctx context.Context
|
||||
@@ -49,10 +64,21 @@ type ControlHostInfo struct {
|
||||
CurrentRelaysThroughMe []netip.Addr `json:"currentRelaysThroughMe"`
|
||||
}
|
||||
|
||||
// Start actually runs nebula, this is a nonblocking call. To block use Control.ShutdownBlock()
|
||||
func (c *Control) Start() {
|
||||
// Start actually runs nebula, this is a nonblocking call.
|
||||
// The returned function can be used to wait for nebula to fully stop.
|
||||
func (c *Control) Start() (func(), error) {
|
||||
c.stateLock.Lock()
|
||||
if c.state != Stopped {
|
||||
c.stateLock.Unlock()
|
||||
return nil, ErrAlreadyStarted
|
||||
}
|
||||
|
||||
// Activate the interface
|
||||
c.f.activate()
|
||||
err := c.f.activate()
|
||||
if err != nil {
|
||||
c.stateLock.Unlock()
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Call all the delayed funcs that waited patiently for the interface to be created.
|
||||
if c.sshStart != nil {
|
||||
@@ -72,15 +98,33 @@ func (c *Control) Start() {
|
||||
}
|
||||
|
||||
// Start reading packets.
|
||||
c.f.run()
|
||||
c.state = Started
|
||||
c.stateLock.Unlock()
|
||||
return c.f.run()
|
||||
}
|
||||
|
||||
func (c *Control) State() RunState {
|
||||
c.stateLock.Lock()
|
||||
defer c.stateLock.Unlock()
|
||||
return c.state
|
||||
}
|
||||
|
||||
func (c *Control) Context() context.Context {
|
||||
return c.ctx
|
||||
}
|
||||
|
||||
// Stop signals nebula to shutdown and close all tunnels, returns after the shutdown is complete
|
||||
// Stop is a non-blocking call that signals nebula to close all tunnels and shut down
|
||||
func (c *Control) Stop() {
|
||||
c.stateLock.Lock()
|
||||
if c.state != Started {
|
||||
c.stateLock.Unlock()
|
||||
// We are stopping or stopped already
|
||||
return
|
||||
}
|
||||
|
||||
c.state = Stopping
|
||||
c.stateLock.Unlock()
|
||||
|
||||
// Stop the handshakeManager (and other services), to prevent new tunnels from
|
||||
// being created while we're shutting them all down.
|
||||
c.cancel()
|
||||
@@ -89,7 +133,7 @@ func (c *Control) Stop() {
|
||||
if err := c.f.Close(); err != nil {
|
||||
c.l.WithError(err).Error("Close interface failed")
|
||||
}
|
||||
c.l.Info("Goodbye")
|
||||
c.state = Stopped
|
||||
}
|
||||
|
||||
// ShutdownBlock will listen for and block on term and interrupt signals, calling Control.Stop() once signalled
|
||||
|
||||
@@ -129,6 +129,109 @@ func newSimpleServer(v cert.Version, caCrt cert.Certificate, caKey []byte, name
|
||||
return control, vpnNetworks, udpAddr, c
|
||||
}
|
||||
|
||||
// newServer creates a nebula instance with fewer assumptions
|
||||
func newServer(caCrt []cert.Certificate, certs []cert.Certificate, key []byte, overrides m) (*nebula.Control, []netip.Prefix, netip.AddrPort, *config.C) {
|
||||
l := NewTestLogger()
|
||||
|
||||
vpnNetworks := certs[len(certs)-1].Networks()
|
||||
|
||||
var udpAddr netip.AddrPort
|
||||
if vpnNetworks[0].Addr().Is4() {
|
||||
budpIp := vpnNetworks[0].Addr().As4()
|
||||
budpIp[1] -= 128
|
||||
udpAddr = netip.AddrPortFrom(netip.AddrFrom4(budpIp), 4242)
|
||||
} else {
|
||||
budpIp := vpnNetworks[0].Addr().As16()
|
||||
// beef for funsies
|
||||
budpIp[2] = 190
|
||||
budpIp[3] = 239
|
||||
udpAddr = netip.AddrPortFrom(netip.AddrFrom16(budpIp), 4242)
|
||||
}
|
||||
|
||||
caStr := ""
|
||||
for _, ca := range caCrt {
|
||||
x, err := ca.MarshalPEM()
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
caStr += string(x)
|
||||
}
|
||||
certStr := ""
|
||||
for _, c := range certs {
|
||||
x, err := c.MarshalPEM()
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
certStr += string(x)
|
||||
}
|
||||
|
||||
mc := m{
|
||||
"pki": m{
|
||||
"ca": caStr,
|
||||
"cert": certStr,
|
||||
"key": string(key),
|
||||
},
|
||||
//"tun": m{"disabled": true},
|
||||
"firewall": m{
|
||||
"outbound": []m{{
|
||||
"proto": "any",
|
||||
"port": "any",
|
||||
"host": "any",
|
||||
}},
|
||||
"inbound": []m{{
|
||||
"proto": "any",
|
||||
"port": "any",
|
||||
"host": "any",
|
||||
}},
|
||||
},
|
||||
//"handshakes": m{
|
||||
// "try_interval": "1s",
|
||||
//},
|
||||
"listen": m{
|
||||
"host": udpAddr.Addr().String(),
|
||||
"port": udpAddr.Port(),
|
||||
},
|
||||
"logging": m{
|
||||
"timestamp_format": fmt.Sprintf("%v 15:04:05.000000", certs[0].Name()),
|
||||
"level": l.Level.String(),
|
||||
},
|
||||
"timers": m{
|
||||
"pending_deletion_interval": 2,
|
||||
"connection_alive_interval": 2,
|
||||
},
|
||||
}
|
||||
|
||||
if overrides != nil {
|
||||
final := m{}
|
||||
err := mergo.Merge(&final, overrides, mergo.WithAppendSlice)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
err = mergo.Merge(&final, mc, mergo.WithAppendSlice)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
mc = final
|
||||
}
|
||||
|
||||
cb, err := yaml.Marshal(mc)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
c := config.NewC(l)
|
||||
cStr := string(cb)
|
||||
c.LoadString(cStr)
|
||||
|
||||
control, err := nebula.Main(c, false, "e2e-test", l, nil)
|
||||
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
return control, vpnNetworks, udpAddr, c
|
||||
}
|
||||
|
||||
type doneCb func()
|
||||
|
||||
func deadline(t *testing.T, seconds time.Duration) doneCb {
|
||||
|
||||
@@ -4,12 +4,16 @@
|
||||
package e2e
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/netip"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/slackhq/nebula/cert"
|
||||
"github.com/slackhq/nebula/cert_test"
|
||||
"github.com/slackhq/nebula/e2e/router"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"gopkg.in/yaml.v3"
|
||||
)
|
||||
|
||||
func TestDropInactiveTunnels(t *testing.T) {
|
||||
@@ -55,3 +59,262 @@ func TestDropInactiveTunnels(t *testing.T) {
|
||||
myControl.Stop()
|
||||
theirControl.Stop()
|
||||
}
|
||||
|
||||
func TestCertUpgrade(t *testing.T) {
|
||||
// The goal of this test is to ensure the shortest inactivity timeout will close the tunnel on both sides
|
||||
// under ideal conditions
|
||||
ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version1, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
|
||||
caB, err := ca.MarshalPEM()
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
ca2, _, caKey2, _ := cert_test.NewTestCaCert(cert.Version2, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
|
||||
|
||||
ca2B, err := ca2.MarshalPEM()
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
caStr := fmt.Sprintf("%s\n%s", caB, ca2B)
|
||||
|
||||
myCert, _, myPrivKey, _ := cert_test.NewTestCert(cert.Version1, cert.Curve_CURVE25519, ca, caKey, "me", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{netip.MustParsePrefix("10.128.0.1/24")}, nil, []string{})
|
||||
_, myCert2Pem := cert_test.NewTestCertDifferentVersion(myCert, cert.Version2, ca2, caKey2)
|
||||
|
||||
theirCert, _, theirPrivKey, _ := cert_test.NewTestCert(cert.Version1, cert.Curve_CURVE25519, ca, caKey, "them", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{netip.MustParsePrefix("10.128.0.2/24")}, nil, []string{})
|
||||
theirCert2, _ := cert_test.NewTestCertDifferentVersion(theirCert, cert.Version2, ca2, caKey2)
|
||||
|
||||
myControl, myVpnIpNet, myUdpAddr, myC := newServer([]cert.Certificate{ca, ca2}, []cert.Certificate{myCert}, myPrivKey, m{})
|
||||
theirControl, theirVpnIpNet, theirUdpAddr, _ := newServer([]cert.Certificate{ca, ca2}, []cert.Certificate{theirCert, theirCert2}, theirPrivKey, m{})
|
||||
|
||||
// Share our underlay information
|
||||
myControl.InjectLightHouseAddr(theirVpnIpNet[0].Addr(), theirUdpAddr)
|
||||
theirControl.InjectLightHouseAddr(myVpnIpNet[0].Addr(), myUdpAddr)
|
||||
|
||||
// Start the servers
|
||||
myControl.Start()
|
||||
theirControl.Start()
|
||||
|
||||
r := router.NewR(t, myControl, theirControl)
|
||||
defer r.RenderFlow()
|
||||
|
||||
r.Log("Assert the tunnel between me and them works")
|
||||
assertTunnel(t, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), myControl, theirControl, r)
|
||||
r.Log("yay")
|
||||
//todo ???
|
||||
time.Sleep(1 * time.Second)
|
||||
r.FlushAll()
|
||||
|
||||
mc := m{
|
||||
"pki": m{
|
||||
"ca": caStr,
|
||||
"cert": string(myCert2Pem),
|
||||
"key": string(myPrivKey),
|
||||
},
|
||||
//"tun": m{"disabled": true},
|
||||
"firewall": myC.Settings["firewall"],
|
||||
//"handshakes": m{
|
||||
// "try_interval": "1s",
|
||||
//},
|
||||
"listen": myC.Settings["listen"],
|
||||
"logging": myC.Settings["logging"],
|
||||
"timers": myC.Settings["timers"],
|
||||
}
|
||||
|
||||
cb, err := yaml.Marshal(mc)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
r.Logf("reload new v2-only config")
|
||||
err = myC.ReloadConfigString(string(cb))
|
||||
assert.NoError(t, err)
|
||||
r.Log("yay, spin until their sees it")
|
||||
waitStart := time.Now()
|
||||
for {
|
||||
assertTunnel(t, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), myControl, theirControl, r)
|
||||
c := theirControl.GetHostInfoByVpnAddr(myVpnIpNet[0].Addr(), false)
|
||||
if c == nil {
|
||||
r.Log("nil")
|
||||
} else {
|
||||
version := c.Cert.Version()
|
||||
r.Logf("version %d", version)
|
||||
if version == cert.Version2 {
|
||||
break
|
||||
}
|
||||
}
|
||||
since := time.Since(waitStart)
|
||||
if since > time.Second*10 {
|
||||
t.Fatal("Cert should be new by now")
|
||||
}
|
||||
time.Sleep(time.Second)
|
||||
}
|
||||
|
||||
r.RenderHostmaps("Final hostmaps", myControl, theirControl)
|
||||
|
||||
myControl.Stop()
|
||||
theirControl.Stop()
|
||||
}
|
||||
|
||||
func TestCertDowngrade(t *testing.T) {
|
||||
// The goal of this test is to ensure the shortest inactivity timeout will close the tunnel on both sides
|
||||
// under ideal conditions
|
||||
ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version1, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
|
||||
caB, err := ca.MarshalPEM()
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
ca2, _, caKey2, _ := cert_test.NewTestCaCert(cert.Version2, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
|
||||
|
||||
ca2B, err := ca2.MarshalPEM()
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
caStr := fmt.Sprintf("%s\n%s", caB, ca2B)
|
||||
|
||||
myCert, _, myPrivKey, myCertPem := cert_test.NewTestCert(cert.Version1, cert.Curve_CURVE25519, ca, caKey, "me", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{netip.MustParsePrefix("10.128.0.1/24")}, nil, []string{})
|
||||
myCert2, _ := cert_test.NewTestCertDifferentVersion(myCert, cert.Version2, ca2, caKey2)
|
||||
|
||||
theirCert, _, theirPrivKey, _ := cert_test.NewTestCert(cert.Version1, cert.Curve_CURVE25519, ca, caKey, "them", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{netip.MustParsePrefix("10.128.0.2/24")}, nil, []string{})
|
||||
theirCert2, _ := cert_test.NewTestCertDifferentVersion(theirCert, cert.Version2, ca2, caKey2)
|
||||
|
||||
myControl, myVpnIpNet, myUdpAddr, myC := newServer([]cert.Certificate{ca, ca2}, []cert.Certificate{myCert2}, myPrivKey, m{})
|
||||
theirControl, theirVpnIpNet, theirUdpAddr, _ := newServer([]cert.Certificate{ca, ca2}, []cert.Certificate{theirCert, theirCert2}, theirPrivKey, m{})
|
||||
|
||||
// Share our underlay information
|
||||
myControl.InjectLightHouseAddr(theirVpnIpNet[0].Addr(), theirUdpAddr)
|
||||
theirControl.InjectLightHouseAddr(myVpnIpNet[0].Addr(), myUdpAddr)
|
||||
|
||||
// Start the servers
|
||||
myControl.Start()
|
||||
theirControl.Start()
|
||||
|
||||
r := router.NewR(t, myControl, theirControl)
|
||||
defer r.RenderFlow()
|
||||
|
||||
r.Log("Assert the tunnel between me and them works")
|
||||
//assertTunnel(t, theirVpnIpNet[0].Addr(), myVpnIpNet[0].Addr(), theirControl, myControl, r)
|
||||
//r.Log("yay")
|
||||
assertTunnel(t, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), myControl, theirControl, r)
|
||||
r.Log("yay")
|
||||
//todo ???
|
||||
time.Sleep(1 * time.Second)
|
||||
r.FlushAll()
|
||||
|
||||
mc := m{
|
||||
"pki": m{
|
||||
"ca": caStr,
|
||||
"cert": string(myCertPem),
|
||||
"key": string(myPrivKey),
|
||||
},
|
||||
"firewall": myC.Settings["firewall"],
|
||||
"listen": myC.Settings["listen"],
|
||||
"logging": myC.Settings["logging"],
|
||||
"timers": myC.Settings["timers"],
|
||||
}
|
||||
|
||||
cb, err := yaml.Marshal(mc)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
r.Logf("reload new v1-only config")
|
||||
err = myC.ReloadConfigString(string(cb))
|
||||
assert.NoError(t, err)
|
||||
r.Log("yay, spin until their sees it")
|
||||
waitStart := time.Now()
|
||||
for {
|
||||
assertTunnel(t, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), myControl, theirControl, r)
|
||||
c := theirControl.GetHostInfoByVpnAddr(myVpnIpNet[0].Addr(), false)
|
||||
c2 := myControl.GetHostInfoByVpnAddr(theirVpnIpNet[0].Addr(), false)
|
||||
if c == nil || c2 == nil {
|
||||
r.Log("nil")
|
||||
} else {
|
||||
version := c.Cert.Version()
|
||||
theirVersion := c2.Cert.Version()
|
||||
r.Logf("version %d,%d", version, theirVersion)
|
||||
if version == cert.Version1 {
|
||||
break
|
||||
}
|
||||
}
|
||||
since := time.Since(waitStart)
|
||||
if since > time.Second*5 {
|
||||
r.Log("it is unusual that the cert is not new yet, but not a failure yet")
|
||||
}
|
||||
if since > time.Second*10 {
|
||||
r.Log("wtf")
|
||||
t.Fatal("Cert should be new by now")
|
||||
}
|
||||
time.Sleep(time.Second)
|
||||
}
|
||||
|
||||
r.RenderHostmaps("Final hostmaps", myControl, theirControl)
|
||||
|
||||
myControl.Stop()
|
||||
theirControl.Stop()
|
||||
}
|
||||
|
||||
func TestCertMismatchCorrection(t *testing.T) {
|
||||
// The goal of this test is to ensure the shortest inactivity timeout will close the tunnel on both sides
|
||||
// under ideal conditions
|
||||
ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version1, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
|
||||
ca2, _, caKey2, _ := cert_test.NewTestCaCert(cert.Version2, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
|
||||
|
||||
myCert, _, myPrivKey, _ := cert_test.NewTestCert(cert.Version1, cert.Curve_CURVE25519, ca, caKey, "me", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{netip.MustParsePrefix("10.128.0.1/24")}, nil, []string{})
|
||||
myCert2, _ := cert_test.NewTestCertDifferentVersion(myCert, cert.Version2, ca2, caKey2)
|
||||
|
||||
theirCert, _, theirPrivKey, _ := cert_test.NewTestCert(cert.Version1, cert.Curve_CURVE25519, ca, caKey, "them", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{netip.MustParsePrefix("10.128.0.2/24")}, nil, []string{})
|
||||
theirCert2, _ := cert_test.NewTestCertDifferentVersion(theirCert, cert.Version2, ca2, caKey2)
|
||||
|
||||
myControl, myVpnIpNet, myUdpAddr, _ := newServer([]cert.Certificate{ca, ca2}, []cert.Certificate{myCert2}, myPrivKey, m{})
|
||||
theirControl, theirVpnIpNet, theirUdpAddr, _ := newServer([]cert.Certificate{ca, ca2}, []cert.Certificate{theirCert, theirCert2}, theirPrivKey, m{})
|
||||
|
||||
// Share our underlay information
|
||||
myControl.InjectLightHouseAddr(theirVpnIpNet[0].Addr(), theirUdpAddr)
|
||||
theirControl.InjectLightHouseAddr(myVpnIpNet[0].Addr(), myUdpAddr)
|
||||
|
||||
// Start the servers
|
||||
myControl.Start()
|
||||
theirControl.Start()
|
||||
|
||||
r := router.NewR(t, myControl, theirControl)
|
||||
defer r.RenderFlow()
|
||||
|
||||
r.Log("Assert the tunnel between me and them works")
|
||||
//assertTunnel(t, theirVpnIpNet[0].Addr(), myVpnIpNet[0].Addr(), theirControl, myControl, r)
|
||||
//r.Log("yay")
|
||||
assertTunnel(t, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), myControl, theirControl, r)
|
||||
r.Log("yay")
|
||||
//todo ???
|
||||
time.Sleep(1 * time.Second)
|
||||
r.FlushAll()
|
||||
|
||||
waitStart := time.Now()
|
||||
for {
|
||||
assertTunnel(t, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), myControl, theirControl, r)
|
||||
c := theirControl.GetHostInfoByVpnAddr(myVpnIpNet[0].Addr(), false)
|
||||
c2 := myControl.GetHostInfoByVpnAddr(theirVpnIpNet[0].Addr(), false)
|
||||
if c == nil || c2 == nil {
|
||||
r.Log("nil")
|
||||
} else {
|
||||
version := c.Cert.Version()
|
||||
theirVersion := c2.Cert.Version()
|
||||
r.Logf("version %d,%d", version, theirVersion)
|
||||
if version == theirVersion {
|
||||
break
|
||||
}
|
||||
}
|
||||
since := time.Since(waitStart)
|
||||
if since > time.Second*5 {
|
||||
r.Log("wtf")
|
||||
}
|
||||
if since > time.Second*10 {
|
||||
r.Log("wtf")
|
||||
t.Fatal("Cert should be new by now")
|
||||
}
|
||||
time.Sleep(time.Second)
|
||||
}
|
||||
|
||||
r.RenderHostmaps("Final hostmaps", myControl, theirControl)
|
||||
|
||||
myControl.Stop()
|
||||
theirControl.Stop()
|
||||
}
|
||||
|
||||
@@ -23,13 +23,17 @@ func ixHandshakeStage0(f *Interface, hh *HandshakeHostInfo) bool {
|
||||
return false
|
||||
}
|
||||
|
||||
// If we're connecting to a v6 address we must use a v2 cert
|
||||
cs := f.pki.getCertState()
|
||||
v := cs.initiatingVersion
|
||||
for _, a := range hh.hostinfo.vpnAddrs {
|
||||
if a.Is6() {
|
||||
v = cert.Version2
|
||||
break
|
||||
if hh.initiatingVersionOverride != cert.VersionPre1 {
|
||||
v = hh.initiatingVersionOverride
|
||||
} else if v < cert.Version2 {
|
||||
// If we're connecting to a v6 address we should encourage use of a V2 cert
|
||||
for _, a := range hh.hostinfo.vpnAddrs {
|
||||
if a.Is6() {
|
||||
v = cert.Version2
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -48,6 +52,7 @@ func ixHandshakeStage0(f *Interface, hh *HandshakeHostInfo) bool {
|
||||
WithField("handshake", m{"stage": 0, "style": "ix_psk0"}).
|
||||
WithField("certVersion", v).
|
||||
Error("Unable to handshake with host because no certificate handshake bytes is available")
|
||||
return false
|
||||
}
|
||||
|
||||
ci, err := NewConnectionState(f.l, cs, crt, true, noise.HandshakeIX)
|
||||
@@ -103,6 +108,7 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet
|
||||
WithField("handshake", m{"stage": 0, "style": "ix_psk0"}).
|
||||
WithField("certVersion", cs.initiatingVersion).
|
||||
Error("Unable to handshake with host because no certificate is available")
|
||||
return
|
||||
}
|
||||
|
||||
ci, err := NewConnectionState(f.l, cs, crt, false, noise.HandshakeIX)
|
||||
@@ -143,8 +149,8 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet
|
||||
|
||||
remoteCert, err := f.pki.GetCAPool().VerifyCertificate(time.Now(), rc)
|
||||
if err != nil {
|
||||
fp, err := rc.Fingerprint()
|
||||
if err != nil {
|
||||
fp, fperr := rc.Fingerprint()
|
||||
if fperr != nil {
|
||||
fp = "<error generating certificate fingerprint>"
|
||||
}
|
||||
|
||||
@@ -163,16 +169,19 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet
|
||||
|
||||
if remoteCert.Certificate.Version() != ci.myCert.Version() {
|
||||
// We started off using the wrong certificate version, lets see if we can match the version that was sent to us
|
||||
rc := cs.getCertificate(remoteCert.Certificate.Version())
|
||||
if rc == nil {
|
||||
f.l.WithError(err).WithField("udpAddr", addr).
|
||||
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).WithField("cert", remoteCert).
|
||||
Info("Unable to handshake with host due to missing certificate version")
|
||||
return
|
||||
myCertOtherVersion := cs.getCertificate(remoteCert.Certificate.Version())
|
||||
if myCertOtherVersion == nil {
|
||||
if f.l.Level >= logrus.DebugLevel {
|
||||
f.l.WithError(err).WithFields(m{
|
||||
"udpAddr": addr,
|
||||
"handshake": m{"stage": 1, "style": "ix_psk0"},
|
||||
"cert": remoteCert,
|
||||
}).Debug("Might be unable to handshake with host due to missing certificate version")
|
||||
}
|
||||
} else {
|
||||
// Record the certificate we are actually using
|
||||
ci.myCert = myCertOtherVersion
|
||||
}
|
||||
|
||||
// Record the certificate we are actually using
|
||||
ci.myCert = rc
|
||||
}
|
||||
|
||||
if len(remoteCert.Certificate.Networks()) == 0 {
|
||||
|
||||
@@ -68,11 +68,12 @@ type HandshakeManager struct {
|
||||
type HandshakeHostInfo struct {
|
||||
sync.Mutex
|
||||
|
||||
startTime time.Time // Time that we first started trying with this handshake
|
||||
ready bool // Is the handshake ready
|
||||
counter int64 // How many attempts have we made so far
|
||||
lastRemotes []netip.AddrPort // Remotes that we sent to during the previous attempt
|
||||
packetStore []*cachedPacket // A set of packets to be transmitted once the handshake completes
|
||||
startTime time.Time // Time that we first started trying with this handshake
|
||||
ready bool // Is the handshake ready
|
||||
initiatingVersionOverride cert.Version // Should we use a non-default cert version for this handshake?
|
||||
counter int64 // How many attempts have we made so far
|
||||
lastRemotes []netip.AddrPort // Remotes that we sent to during the previous attempt
|
||||
packetStore []*cachedPacket // A set of packets to be transmitted once the handshake completes
|
||||
|
||||
hostinfo *HostInfo
|
||||
}
|
||||
|
||||
413
interface.go
413
interface.go
@@ -5,9 +5,7 @@ import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"math/bits"
|
||||
"net/netip"
|
||||
"os"
|
||||
"runtime"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
@@ -23,12 +21,7 @@ import (
|
||||
"github.com/slackhq/nebula/udp"
|
||||
)
|
||||
|
||||
const (
|
||||
mtu = 9001
|
||||
tunReadBufferSize = mtu * 8
|
||||
defaultDecryptWorkerFactor = 2
|
||||
defaultInboundQueueDepth = 1024
|
||||
)
|
||||
const mtu = 9001
|
||||
|
||||
type InterfaceConfig struct {
|
||||
HostMap *HostMap
|
||||
@@ -55,8 +48,6 @@ type InterfaceConfig struct {
|
||||
|
||||
ConntrackCacheTimeout time.Duration
|
||||
l *logrus.Logger
|
||||
DecryptWorkers int
|
||||
DecryptQueueDepth int
|
||||
}
|
||||
|
||||
type Interface struct {
|
||||
@@ -96,172 +87,13 @@ type Interface struct {
|
||||
|
||||
writers []udp.Conn
|
||||
readers []io.ReadWriteCloser
|
||||
wg sync.WaitGroup
|
||||
|
||||
metricHandshakes metrics.Histogram
|
||||
messageMetrics *MessageMetrics
|
||||
cachedPacketMetrics *cachedPacketMetrics
|
||||
|
||||
l *logrus.Logger
|
||||
ctx context.Context
|
||||
udpListenWG sync.WaitGroup
|
||||
inboundPool sync.Pool
|
||||
decryptWG sync.WaitGroup
|
||||
decryptQueues []*inboundRing
|
||||
decryptWorkers int
|
||||
decryptStates []decryptWorkerState
|
||||
decryptCounter atomic.Uint32
|
||||
}
|
||||
|
||||
type inboundPacket struct {
|
||||
addr netip.AddrPort
|
||||
payload []byte
|
||||
release func()
|
||||
queue int
|
||||
}
|
||||
|
||||
type decryptWorkerState struct {
|
||||
queue *inboundRing
|
||||
notify chan struct{}
|
||||
}
|
||||
|
||||
type decryptContext struct {
|
||||
ctTicker *firewall.ConntrackCacheTicker
|
||||
plain []byte
|
||||
head header.H
|
||||
fwPacket firewall.Packet
|
||||
light *LightHouseHandler
|
||||
nebula []byte
|
||||
}
|
||||
|
||||
type inboundCell struct {
|
||||
seq atomic.Uint64
|
||||
pkt *inboundPacket
|
||||
}
|
||||
|
||||
type inboundRing struct {
|
||||
mask uint64
|
||||
cells []inboundCell
|
||||
enqueuePos atomic.Uint64
|
||||
dequeuePos atomic.Uint64
|
||||
}
|
||||
|
||||
func newInboundRing(capacity int) *inboundRing {
|
||||
if capacity < 2 {
|
||||
capacity = 2
|
||||
}
|
||||
size := nextPowerOfTwo(uint32(capacity))
|
||||
if size < 2 {
|
||||
size = 2
|
||||
}
|
||||
ring := &inboundRing{
|
||||
mask: uint64(size - 1),
|
||||
cells: make([]inboundCell, size),
|
||||
}
|
||||
for i := range ring.cells {
|
||||
ring.cells[i].seq.Store(uint64(i))
|
||||
}
|
||||
return ring
|
||||
}
|
||||
|
||||
func nextPowerOfTwo(v uint32) uint32 {
|
||||
if v == 0 {
|
||||
return 1
|
||||
}
|
||||
return 1 << (32 - bits.LeadingZeros32(v-1))
|
||||
}
|
||||
|
||||
func (r *inboundRing) Enqueue(pkt *inboundPacket) bool {
|
||||
var cell *inboundCell
|
||||
pos := r.enqueuePos.Load()
|
||||
for {
|
||||
cell = &r.cells[pos&r.mask]
|
||||
seq := cell.seq.Load()
|
||||
diff := int64(seq) - int64(pos)
|
||||
if diff == 0 {
|
||||
if r.enqueuePos.CompareAndSwap(pos, pos+1) {
|
||||
break
|
||||
}
|
||||
} else if diff < 0 {
|
||||
return false
|
||||
} else {
|
||||
pos = r.enqueuePos.Load()
|
||||
}
|
||||
}
|
||||
cell.pkt = pkt
|
||||
cell.seq.Store(pos + 1)
|
||||
return true
|
||||
}
|
||||
|
||||
func (r *inboundRing) Dequeue() (*inboundPacket, bool) {
|
||||
var cell *inboundCell
|
||||
pos := r.dequeuePos.Load()
|
||||
for {
|
||||
cell = &r.cells[pos&r.mask]
|
||||
seq := cell.seq.Load()
|
||||
diff := int64(seq) - int64(pos+1)
|
||||
if diff == 0 {
|
||||
if r.dequeuePos.CompareAndSwap(pos, pos+1) {
|
||||
break
|
||||
}
|
||||
} else if diff < 0 {
|
||||
return nil, false
|
||||
} else {
|
||||
pos = r.dequeuePos.Load()
|
||||
}
|
||||
}
|
||||
pkt := cell.pkt
|
||||
cell.pkt = nil
|
||||
cell.seq.Store(pos + r.mask + 1)
|
||||
return pkt, true
|
||||
}
|
||||
|
||||
func (f *Interface) getInboundPacket() *inboundPacket {
|
||||
if pkt, ok := f.inboundPool.Get().(*inboundPacket); ok && pkt != nil {
|
||||
return pkt
|
||||
}
|
||||
return &inboundPacket{}
|
||||
}
|
||||
|
||||
func (f *Interface) putInboundPacket(pkt *inboundPacket) {
|
||||
if pkt == nil {
|
||||
return
|
||||
}
|
||||
pkt.addr = netip.AddrPort{}
|
||||
pkt.payload = nil
|
||||
pkt.release = nil
|
||||
pkt.queue = 0
|
||||
f.inboundPool.Put(pkt)
|
||||
}
|
||||
|
||||
func newDecryptContext(f *Interface) *decryptContext {
|
||||
return &decryptContext{
|
||||
ctTicker: firewall.NewConntrackCacheTicker(f.conntrackCacheTimeout),
|
||||
plain: make([]byte, udp.MTU),
|
||||
head: header.H{},
|
||||
fwPacket: firewall.Packet{},
|
||||
light: f.lightHouse.NewRequestHandler(),
|
||||
nebula: make([]byte, 12, 12),
|
||||
}
|
||||
}
|
||||
|
||||
func (f *Interface) processInboundPacket(pkt *inboundPacket, ctx *decryptContext) {
|
||||
if pkt == nil {
|
||||
return
|
||||
}
|
||||
defer func() {
|
||||
if pkt.release != nil {
|
||||
pkt.release()
|
||||
}
|
||||
f.putInboundPacket(pkt)
|
||||
}()
|
||||
|
||||
ctx.head = header.H{}
|
||||
ctx.fwPacket = firewall.Packet{}
|
||||
var cache firewall.ConntrackCache
|
||||
if ctx.ctTicker != nil {
|
||||
cache = ctx.ctTicker.Get(f.l)
|
||||
}
|
||||
f.readOutsidePackets(pkt.addr, nil, ctx.plain[:0], pkt.payload, &ctx.head, &ctx.fwPacket, ctx.light, ctx.nebula, pkt.queue, cache)
|
||||
l *logrus.Logger
|
||||
}
|
||||
|
||||
type EncWriter interface {
|
||||
@@ -331,35 +163,6 @@ func NewInterface(ctx context.Context, c *InterfaceConfig) (*Interface, error) {
|
||||
}
|
||||
|
||||
cs := c.pki.getCertState()
|
||||
decryptWorkers := c.DecryptWorkers
|
||||
if decryptWorkers < 0 {
|
||||
decryptWorkers = 0
|
||||
}
|
||||
if decryptWorkers == 0 {
|
||||
decryptWorkers = c.routines * defaultDecryptWorkerFactor
|
||||
if decryptWorkers < c.routines {
|
||||
decryptWorkers = c.routines
|
||||
}
|
||||
}
|
||||
if decryptWorkers < 0 {
|
||||
decryptWorkers = 0
|
||||
}
|
||||
if runtime.GOOS != "linux" {
|
||||
decryptWorkers = 0
|
||||
}
|
||||
|
||||
queueDepth := c.DecryptQueueDepth
|
||||
if queueDepth <= 0 {
|
||||
queueDepth = defaultInboundQueueDepth
|
||||
}
|
||||
minDepth := c.routines * 64
|
||||
if minDepth <= 0 {
|
||||
minDepth = 64
|
||||
}
|
||||
if queueDepth < minDepth {
|
||||
queueDepth = minDepth
|
||||
}
|
||||
|
||||
ifce := &Interface{
|
||||
pki: c.pki,
|
||||
hostMap: c.HostMap,
|
||||
@@ -392,10 +195,7 @@ func NewInterface(ctx context.Context, c *InterfaceConfig) (*Interface, error) {
|
||||
dropped: metrics.GetOrRegisterCounter("hostinfo.cached_packets.dropped", nil),
|
||||
},
|
||||
|
||||
l: c.l,
|
||||
ctx: ctx,
|
||||
inboundPool: sync.Pool{New: func() any { return &inboundPacket{} }},
|
||||
decryptWorkers: decryptWorkers,
|
||||
l: c.l,
|
||||
}
|
||||
|
||||
ifce.tryPromoteEvery.Store(c.tryPromoteEvery)
|
||||
@@ -404,26 +204,13 @@ func NewInterface(ctx context.Context, c *InterfaceConfig) (*Interface, error) {
|
||||
|
||||
ifce.connectionManager.intf = ifce
|
||||
|
||||
if decryptWorkers > 0 {
|
||||
ifce.decryptQueues = make([]*inboundRing, decryptWorkers)
|
||||
ifce.decryptStates = make([]decryptWorkerState, decryptWorkers)
|
||||
for i := 0; i < decryptWorkers; i++ {
|
||||
queue := newInboundRing(queueDepth)
|
||||
ifce.decryptQueues[i] = queue
|
||||
ifce.decryptStates[i] = decryptWorkerState{
|
||||
queue: queue,
|
||||
notify: make(chan struct{}, 1),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return ifce, nil
|
||||
}
|
||||
|
||||
// activate creates the interface on the host. After the interface is created, any
|
||||
// other services that want to bind listeners to its IP may do so successfully. However,
|
||||
// the interface isn't going to process anything until run() is called.
|
||||
func (f *Interface) activate() {
|
||||
func (f *Interface) activate() error {
|
||||
// actually turn on tun dev
|
||||
|
||||
addr, err := f.outside.LocalAddr()
|
||||
@@ -444,94 +231,38 @@ func (f *Interface) activate() {
|
||||
if i > 0 {
|
||||
reader, err = f.inside.NewMultiQueueReader()
|
||||
if err != nil {
|
||||
f.l.Fatal(err)
|
||||
return err
|
||||
}
|
||||
}
|
||||
f.readers[i] = reader
|
||||
}
|
||||
|
||||
if err := f.inside.Activate(); err != nil {
|
||||
if err = f.inside.Activate(); err != nil {
|
||||
f.inside.Close()
|
||||
f.l.Fatal(err)
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (f *Interface) startDecryptWorkers() {
|
||||
if f.decryptWorkers <= 0 || len(f.decryptQueues) == 0 {
|
||||
return
|
||||
}
|
||||
f.decryptWG.Add(f.decryptWorkers)
|
||||
for i := 0; i < f.decryptWorkers; i++ {
|
||||
go f.decryptWorker(i)
|
||||
}
|
||||
}
|
||||
|
||||
func (f *Interface) decryptWorker(id int) {
|
||||
defer f.decryptWG.Done()
|
||||
if id < 0 || id >= len(f.decryptStates) {
|
||||
return
|
||||
}
|
||||
state := f.decryptStates[id]
|
||||
if state.queue == nil {
|
||||
return
|
||||
}
|
||||
ctx := newDecryptContext(f)
|
||||
for {
|
||||
for {
|
||||
pkt, ok := state.queue.Dequeue()
|
||||
if !ok {
|
||||
break
|
||||
}
|
||||
f.processInboundPacket(pkt, ctx)
|
||||
}
|
||||
if f.closed.Load() || f.ctx.Err() != nil {
|
||||
for {
|
||||
pkt, ok := state.queue.Dequeue()
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
f.processInboundPacket(pkt, ctx)
|
||||
}
|
||||
}
|
||||
select {
|
||||
case <-f.ctx.Done():
|
||||
case <-state.notify:
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (f *Interface) notifyDecryptWorker(idx int) {
|
||||
if idx < 0 || idx >= len(f.decryptStates) {
|
||||
return
|
||||
}
|
||||
state := f.decryptStates[idx]
|
||||
if state.notify == nil {
|
||||
return
|
||||
}
|
||||
select {
|
||||
case state.notify <- struct{}{}:
|
||||
default:
|
||||
}
|
||||
}
|
||||
|
||||
func (f *Interface) run() {
|
||||
f.startDecryptWorkers()
|
||||
func (f *Interface) run() (func(), error) {
|
||||
// Launch n queues to read packets from udp
|
||||
f.udpListenWG.Add(f.routines)
|
||||
for i := 0; i < f.routines; i++ {
|
||||
f.wg.Add(1)
|
||||
go f.listenOut(i)
|
||||
}
|
||||
|
||||
// Launch n queues to read packets from tun dev
|
||||
for i := 0; i < f.routines; i++ {
|
||||
f.wg.Add(1)
|
||||
go f.listenIn(f.readers[i], i)
|
||||
}
|
||||
|
||||
return f.wg.Wait, nil
|
||||
}
|
||||
|
||||
func (f *Interface) listenOut(i int) {
|
||||
runtime.LockOSThread()
|
||||
defer f.udpListenWG.Done()
|
||||
|
||||
var li udp.Conn
|
||||
if i > 0 {
|
||||
li = f.writers[i]
|
||||
@@ -539,78 +270,30 @@ func (f *Interface) listenOut(i int) {
|
||||
li = f.outside
|
||||
}
|
||||
|
||||
useWorkers := f.decryptWorkers > 0 && len(f.decryptQueues) > 0
|
||||
var (
|
||||
inlineTicker *firewall.ConntrackCacheTicker
|
||||
inlineHandler *LightHouseHandler
|
||||
inlinePlain []byte
|
||||
inlineHeader header.H
|
||||
inlinePacket firewall.Packet
|
||||
inlineNB []byte
|
||||
inlineCtx *decryptContext
|
||||
)
|
||||
ctCache := firewall.NewConntrackCacheTicker(f.conntrackCacheTimeout)
|
||||
lhh := f.lightHouse.NewRequestHandler()
|
||||
plaintext := make([]byte, udp.MTU)
|
||||
h := &header.H{}
|
||||
fwPacket := &firewall.Packet{}
|
||||
nb := make([]byte, 12, 12)
|
||||
|
||||
if useWorkers {
|
||||
inlineCtx = newDecryptContext(f)
|
||||
} else {
|
||||
inlineTicker = firewall.NewConntrackCacheTicker(f.conntrackCacheTimeout)
|
||||
inlineHandler = f.lightHouse.NewRequestHandler()
|
||||
inlinePlain = make([]byte, udp.MTU)
|
||||
inlineNB = make([]byte, 12, 12)
|
||||
err := li.ListenOut(func(fromUdpAddr netip.AddrPort, payload []byte) {
|
||||
f.readOutsidePackets(fromUdpAddr, nil, plaintext[:0], payload, h, fwPacket, lhh, nb, i, ctCache.Get(f.l))
|
||||
})
|
||||
|
||||
if err != nil && !f.closed.Load() {
|
||||
f.l.WithError(err).Error("Error while reading packet inbound packet, closing")
|
||||
//TODO: Trigger Control to close
|
||||
}
|
||||
|
||||
li.ListenOut(func(fromUdpAddr netip.AddrPort, payload []byte, release func()) {
|
||||
if !useWorkers {
|
||||
if release != nil {
|
||||
defer release()
|
||||
}
|
||||
select {
|
||||
case <-f.ctx.Done():
|
||||
return
|
||||
default:
|
||||
}
|
||||
inlineHeader = header.H{}
|
||||
inlinePacket = firewall.Packet{}
|
||||
var cache firewall.ConntrackCache
|
||||
if inlineTicker != nil {
|
||||
cache = inlineTicker.Get(f.l)
|
||||
}
|
||||
f.readOutsidePackets(fromUdpAddr, nil, inlinePlain[:0], payload, &inlineHeader, &inlinePacket, inlineHandler, inlineNB, i, cache)
|
||||
return
|
||||
}
|
||||
|
||||
if f.ctx.Err() != nil {
|
||||
if release != nil {
|
||||
release()
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
pkt := f.getInboundPacket()
|
||||
pkt.addr = fromUdpAddr
|
||||
pkt.payload = payload
|
||||
pkt.release = release
|
||||
pkt.queue = i
|
||||
|
||||
queueCount := len(f.decryptQueues)
|
||||
if queueCount == 0 {
|
||||
f.processInboundPacket(pkt, inlineCtx)
|
||||
return
|
||||
}
|
||||
w := int(f.decryptCounter.Add(1)-1) % queueCount
|
||||
if w < 0 || w >= queueCount || !f.decryptQueues[w].Enqueue(pkt) {
|
||||
f.processInboundPacket(pkt, inlineCtx)
|
||||
return
|
||||
}
|
||||
f.notifyDecryptWorker(w)
|
||||
})
|
||||
f.l.Debugf("underlay reader %v is done", i)
|
||||
f.wg.Done()
|
||||
}
|
||||
|
||||
func (f *Interface) listenIn(reader io.ReadWriteCloser, i int) {
|
||||
runtime.LockOSThread()
|
||||
|
||||
packet := make([]byte, tunReadBufferSize)
|
||||
out := make([]byte, tunReadBufferSize)
|
||||
packet := make([]byte, mtu)
|
||||
out := make([]byte, mtu)
|
||||
fwPacket := &firewall.Packet{}
|
||||
nb := make([]byte, 12, 12)
|
||||
|
||||
@@ -619,17 +302,18 @@ func (f *Interface) listenIn(reader io.ReadWriteCloser, i int) {
|
||||
for {
|
||||
n, err := reader.Read(packet)
|
||||
if err != nil {
|
||||
if errors.Is(err, os.ErrClosed) && f.closed.Load() {
|
||||
return
|
||||
if !f.closed.Load() {
|
||||
f.l.WithError(err).Error("Error while reading outbound packet, closing")
|
||||
//TODO: Trigger Control to close
|
||||
}
|
||||
|
||||
f.l.WithError(err).Error("Error while reading outbound packet")
|
||||
// This only seems to happen when something fatal happens to the fd, so exit.
|
||||
os.Exit(2)
|
||||
break
|
||||
}
|
||||
|
||||
f.consumeInsidePacket(packet[:n], fwPacket, nb, out, i, conntrackCache.Get(f.l))
|
||||
}
|
||||
|
||||
f.l.Debugf("overlay reader %v is done", i)
|
||||
f.wg.Done()
|
||||
}
|
||||
|
||||
func (f *Interface) RegisterConfigChangeCallbacks(c *config.C) {
|
||||
@@ -781,6 +465,7 @@ func (f *Interface) GetCertState() *CertState {
|
||||
func (f *Interface) Close() error {
|
||||
f.closed.Store(true)
|
||||
|
||||
// Release the udp readers
|
||||
for _, u := range f.writers {
|
||||
err := u.Close()
|
||||
if err != nil {
|
||||
@@ -788,19 +473,13 @@ func (f *Interface) Close() error {
|
||||
}
|
||||
}
|
||||
|
||||
f.udpListenWG.Wait()
|
||||
if f.decryptWorkers > 0 {
|
||||
for _, state := range f.decryptStates {
|
||||
if state.notify != nil {
|
||||
select {
|
||||
case state.notify <- struct{}{}:
|
||||
default:
|
||||
}
|
||||
}
|
||||
// Release the tun readers
|
||||
for _, u := range f.readers {
|
||||
err := u.Close()
|
||||
if err != nil {
|
||||
f.l.WithError(err).Error("Error while closing tun device")
|
||||
}
|
||||
f.decryptWG.Wait()
|
||||
}
|
||||
|
||||
// Release the tun device
|
||||
return f.inside.Close()
|
||||
return nil
|
||||
}
|
||||
|
||||
25
main.go
25
main.go
@@ -120,8 +120,6 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
|
||||
l.WithField("duration", conntrackCacheTimeout).Info("Using routine-local conntrack cache")
|
||||
}
|
||||
|
||||
udp.SetDisableUDPCsum(c.GetBool("listen.disable_udp_checksum", false))
|
||||
|
||||
var tun overlay.Device
|
||||
if !configTest {
|
||||
c.CatchHUP(ctx)
|
||||
@@ -223,9 +221,6 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
|
||||
}
|
||||
}
|
||||
|
||||
decryptWorkers := c.GetInt("listen.decrypt_workers", 0)
|
||||
decryptQueueDepth := c.GetInt("listen.decrypt_queue_depth", 0)
|
||||
|
||||
ifConfig := &InterfaceConfig{
|
||||
HostMap: hostMap,
|
||||
Inside: tun,
|
||||
@@ -248,8 +243,6 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
|
||||
punchy: punchy,
|
||||
ConntrackCacheTimeout: conntrackCacheTimeout,
|
||||
l: l,
|
||||
DecryptWorkers: decryptWorkers,
|
||||
DecryptQueueDepth: decryptQueueDepth,
|
||||
}
|
||||
|
||||
var ifce *Interface
|
||||
@@ -291,14 +284,14 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
|
||||
}
|
||||
|
||||
return &Control{
|
||||
ifce,
|
||||
l,
|
||||
ctx,
|
||||
cancel,
|
||||
sshStart,
|
||||
statsStart,
|
||||
dnsStart,
|
||||
lightHouse.StartUpdateWorker,
|
||||
connManager.Start,
|
||||
f: ifce,
|
||||
l: l,
|
||||
ctx: ctx,
|
||||
cancel: cancel,
|
||||
sshStart: sshStart,
|
||||
statsStart: statsStart,
|
||||
dnsStart: dnsStart,
|
||||
lighthouseStart: lightHouse.StartUpdateWorker,
|
||||
connectionManagerStart: connManager.Start,
|
||||
}, nil
|
||||
}
|
||||
|
||||
10
outside.go
10
outside.go
@@ -29,7 +29,7 @@ func (f *Interface) readOutsidePackets(ip netip.AddrPort, via *ViaSender, out []
|
||||
return
|
||||
}
|
||||
|
||||
//l.Error("in packet ", header, packet[HeaderLen:])
|
||||
//f.l.Error("in packet ", h)
|
||||
if ip.IsValid() {
|
||||
if f.myVpnNetworksTable.Contains(ip.Addr()) {
|
||||
if f.l.Level >= logrus.DebugLevel {
|
||||
@@ -470,13 +470,7 @@ func (f *Interface) decryptToTun(hostinfo *HostInfo, messageCounter uint64, out
|
||||
|
||||
out, err = hostinfo.ConnectionState.dKey.DecryptDanger(out, packet[:header.Len], packet[header.Len:], messageCounter, nb)
|
||||
if err != nil {
|
||||
hostinfo.logger(f.l).
|
||||
WithError(err).
|
||||
WithField("tag", "decrypt-debug").
|
||||
WithField("remoteIndexLocal", hostinfo.localIndexId).
|
||||
WithField("messageCounter", messageCounter).
|
||||
WithField("packet_len", len(packet)).
|
||||
Error("Failed to decrypt packet")
|
||||
hostinfo.logger(f.l).WithError(err).Error("Failed to decrypt packet")
|
||||
return false
|
||||
}
|
||||
|
||||
|
||||
@@ -25,17 +25,14 @@ import (
|
||||
|
||||
type tun struct {
|
||||
io.ReadWriteCloser
|
||||
fd int
|
||||
Device string
|
||||
vpnNetworks []netip.Prefix
|
||||
MaxMTU int
|
||||
DefaultMTU int
|
||||
TXQueueLen int
|
||||
deviceIndex int
|
||||
ioctlFd uintptr
|
||||
enableVnetHdr bool
|
||||
vnetHdrLen int
|
||||
queues []*tunQueue
|
||||
fd int
|
||||
Device string
|
||||
vpnNetworks []netip.Prefix
|
||||
MaxMTU int
|
||||
DefaultMTU int
|
||||
TXQueueLen int
|
||||
deviceIndex int
|
||||
ioctlFd uintptr
|
||||
|
||||
Routes atomic.Pointer[[]Route]
|
||||
routeTree atomic.Pointer[bart.Table[routing.Gateways]]
|
||||
@@ -68,90 +65,10 @@ type ifreqQLEN struct {
|
||||
pad [8]byte
|
||||
}
|
||||
|
||||
const (
|
||||
virtioNetHdrLen = 12
|
||||
tunDefaultMaxPacket = 65536
|
||||
)
|
||||
|
||||
type tunQueue struct {
|
||||
file *os.File
|
||||
fd int
|
||||
enableVnetHdr bool
|
||||
vnetHdrLen int
|
||||
maxPacket int
|
||||
writeScratch []byte
|
||||
readScratch []byte
|
||||
l *logrus.Logger
|
||||
}
|
||||
|
||||
func newTunQueue(file *os.File, enableVnetHdr bool, vnetHdrLen, maxPacket int, l *logrus.Logger) *tunQueue {
|
||||
if maxPacket <= 0 {
|
||||
maxPacket = tunDefaultMaxPacket
|
||||
}
|
||||
q := &tunQueue{
|
||||
file: file,
|
||||
fd: int(file.Fd()),
|
||||
enableVnetHdr: enableVnetHdr,
|
||||
vnetHdrLen: vnetHdrLen,
|
||||
maxPacket: maxPacket,
|
||||
l: l,
|
||||
}
|
||||
if enableVnetHdr {
|
||||
q.growReadScratch(maxPacket)
|
||||
}
|
||||
return q
|
||||
}
|
||||
|
||||
func (q *tunQueue) growReadScratch(packetSize int) {
|
||||
needed := q.vnetHdrLen + packetSize
|
||||
if needed < q.vnetHdrLen+DefaultMTU {
|
||||
needed = q.vnetHdrLen + DefaultMTU
|
||||
}
|
||||
if q.readScratch == nil || cap(q.readScratch) < needed {
|
||||
q.readScratch = make([]byte, needed)
|
||||
} else {
|
||||
q.readScratch = q.readScratch[:needed]
|
||||
}
|
||||
}
|
||||
|
||||
func (q *tunQueue) setMaxPacket(packet int) {
|
||||
if packet <= 0 {
|
||||
packet = DefaultMTU
|
||||
}
|
||||
q.maxPacket = packet
|
||||
if q.enableVnetHdr {
|
||||
q.growReadScratch(packet)
|
||||
}
|
||||
}
|
||||
|
||||
func configureVnetHdr(fd int, hdrLen int, l *logrus.Logger) error {
|
||||
features, err := unix.IoctlGetInt(fd, unix.TUNGETFEATURES)
|
||||
if err == nil && features&unix.IFF_VNET_HDR == 0 {
|
||||
return fmt.Errorf("kernel does not support IFF_VNET_HDR")
|
||||
}
|
||||
if err := unix.IoctlSetInt(fd, unix.TUNSETVNETHDRSZ, hdrLen); err != nil {
|
||||
return err
|
||||
}
|
||||
offload := unix.TUN_F_CSUM | unix.TUN_F_UFO
|
||||
if err := unix.IoctlSetInt(fd, unix.TUNSETOFFLOAD, offload); err != nil {
|
||||
if l != nil {
|
||||
l.WithError(err).Warn("Failed to enable TUN offload features")
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func newTunFromFd(c *config.C, l *logrus.Logger, deviceFd int, vpnNetworks []netip.Prefix) (*tun, error) {
|
||||
file := os.NewFile(uintptr(deviceFd), "/dev/net/tun")
|
||||
enableVnetHdr := c.GetBool("tun.enable_vnet_hdr", false)
|
||||
if enableVnetHdr {
|
||||
if err := configureVnetHdr(deviceFd, virtioNetHdrLen, l); err != nil {
|
||||
l.WithError(err).Warn("Failed to configure VNET header support on provided tun fd; disabling")
|
||||
enableVnetHdr = false
|
||||
}
|
||||
}
|
||||
|
||||
t, err := newTunGeneric(c, l, file, vpnNetworks, enableVnetHdr)
|
||||
t, err := newTunGeneric(c, l, file, vpnNetworks)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -189,25 +106,14 @@ func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, multiqueu
|
||||
if multiqueue {
|
||||
req.Flags |= unix.IFF_MULTI_QUEUE
|
||||
}
|
||||
enableVnetHdr := c.GetBool("tun.enable_vnet_hdr", false)
|
||||
if enableVnetHdr {
|
||||
req.Flags |= unix.IFF_VNET_HDR
|
||||
}
|
||||
copy(req.Name[:], c.GetString("tun.dev", ""))
|
||||
if err = ioctl(uintptr(fd), uintptr(unix.TUNSETIFF), uintptr(unsafe.Pointer(&req))); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
name := strings.Trim(string(req.Name[:]), "\x00")
|
||||
|
||||
if enableVnetHdr {
|
||||
if err := configureVnetHdr(fd, virtioNetHdrLen, l); err != nil {
|
||||
l.WithError(err).Warn("Failed to configure VNET header support on tun device; disabling")
|
||||
enableVnetHdr = false
|
||||
}
|
||||
}
|
||||
|
||||
file := os.NewFile(uintptr(fd), "/dev/net/tun")
|
||||
t, err := newTunGeneric(c, l, file, vpnNetworks, enableVnetHdr)
|
||||
t, err := newTunGeneric(c, l, file, vpnNetworks)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -217,30 +123,21 @@ func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, multiqueu
|
||||
return t, nil
|
||||
}
|
||||
|
||||
func newTunGeneric(c *config.C, l *logrus.Logger, file *os.File, vpnNetworks []netip.Prefix, enableVnetHdr bool) (*tun, error) {
|
||||
queue := newTunQueue(file, enableVnetHdr, virtioNetHdrLen, tunDefaultMaxPacket, l)
|
||||
func newTunGeneric(c *config.C, l *logrus.Logger, file *os.File, vpnNetworks []netip.Prefix) (*tun, error) {
|
||||
t := &tun{
|
||||
ReadWriteCloser: queue,
|
||||
ReadWriteCloser: file,
|
||||
fd: int(file.Fd()),
|
||||
vpnNetworks: vpnNetworks,
|
||||
TXQueueLen: c.GetInt("tun.tx_queue", 500),
|
||||
useSystemRoutes: c.GetBool("tun.use_system_route_table", false),
|
||||
useSystemRoutesBufferSize: c.GetInt("tun.use_system_route_table_buffer_size", 0),
|
||||
l: l,
|
||||
enableVnetHdr: enableVnetHdr,
|
||||
vnetHdrLen: virtioNetHdrLen,
|
||||
queues: []*tunQueue{queue},
|
||||
}
|
||||
|
||||
err := t.reload(c, true)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if enableVnetHdr {
|
||||
for _, q := range t.queues {
|
||||
q.setMaxPacket(t.MaxMTU)
|
||||
}
|
||||
}
|
||||
|
||||
c.RegisterReloadCallback(func(c *config.C) {
|
||||
err := t.reload(c, false)
|
||||
@@ -283,11 +180,6 @@ func (t *tun) reload(c *config.C, initial bool) error {
|
||||
|
||||
t.MaxMTU = newMaxMTU
|
||||
t.DefaultMTU = newDefaultMTU
|
||||
if t.enableVnetHdr {
|
||||
for _, q := range t.queues {
|
||||
q.setMaxPacket(t.MaxMTU)
|
||||
}
|
||||
}
|
||||
|
||||
// Teach nebula how to handle the routes before establishing them in the system table
|
||||
oldRoutes := t.Routes.Swap(&routes)
|
||||
@@ -332,87 +224,14 @@ func (t *tun) NewMultiQueueReader() (io.ReadWriteCloser, error) {
|
||||
|
||||
var req ifReq
|
||||
req.Flags = uint16(unix.IFF_TUN | unix.IFF_NO_PI | unix.IFF_MULTI_QUEUE)
|
||||
if t.enableVnetHdr {
|
||||
req.Flags |= unix.IFF_VNET_HDR
|
||||
}
|
||||
copy(req.Name[:], t.Device)
|
||||
if err = ioctl(uintptr(fd), uintptr(unix.TUNSETIFF), uintptr(unsafe.Pointer(&req))); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
file := os.NewFile(uintptr(fd), "/dev/net/tun")
|
||||
queue := newTunQueue(file, t.enableVnetHdr, t.vnetHdrLen, t.MaxMTU, t.l)
|
||||
if t.enableVnetHdr {
|
||||
if err := configureVnetHdr(fd, t.vnetHdrLen, t.l); err != nil {
|
||||
queue.enableVnetHdr = false
|
||||
}
|
||||
}
|
||||
t.queues = append(t.queues, queue)
|
||||
|
||||
return queue, nil
|
||||
}
|
||||
|
||||
func (q *tunQueue) Read(p []byte) (int, error) {
|
||||
if !q.enableVnetHdr {
|
||||
return q.file.Read(p)
|
||||
}
|
||||
|
||||
if len(p)+q.vnetHdrLen > cap(q.readScratch) {
|
||||
q.growReadScratch(len(p))
|
||||
}
|
||||
|
||||
buf := q.readScratch[:cap(q.readScratch)]
|
||||
n, err := q.file.Read(buf)
|
||||
if n <= 0 {
|
||||
return n, err
|
||||
}
|
||||
if n < q.vnetHdrLen {
|
||||
if err == nil {
|
||||
err = io.ErrUnexpectedEOF
|
||||
}
|
||||
return 0, err
|
||||
}
|
||||
|
||||
payload := buf[q.vnetHdrLen:n]
|
||||
if len(payload) > len(p) {
|
||||
copy(p, payload[:len(p)])
|
||||
if err == nil {
|
||||
err = io.ErrShortBuffer
|
||||
}
|
||||
return len(p), err
|
||||
}
|
||||
copy(p, payload)
|
||||
return len(payload), err
|
||||
}
|
||||
|
||||
func (q *tunQueue) Write(b []byte) (int, error) {
|
||||
if !q.enableVnetHdr {
|
||||
return unix.Write(q.fd, b)
|
||||
}
|
||||
|
||||
total := q.vnetHdrLen + len(b)
|
||||
if cap(q.writeScratch) < total {
|
||||
q.writeScratch = make([]byte, total)
|
||||
} else {
|
||||
q.writeScratch = q.writeScratch[:total]
|
||||
}
|
||||
|
||||
for i := 0; i < q.vnetHdrLen; i++ {
|
||||
q.writeScratch[i] = 0
|
||||
}
|
||||
copy(q.writeScratch[q.vnetHdrLen:], b)
|
||||
|
||||
n, err := unix.Write(q.fd, q.writeScratch)
|
||||
if n >= q.vnetHdrLen {
|
||||
n -= q.vnetHdrLen
|
||||
} else {
|
||||
n = 0
|
||||
}
|
||||
return n, err
|
||||
}
|
||||
|
||||
func (q *tunQueue) Close() error {
|
||||
return q.file.Close()
|
||||
return file, nil
|
||||
}
|
||||
|
||||
func (t *tun) RoutesFor(ip netip.Addr) routing.Gateways {
|
||||
|
||||
85
pki.go
85
pki.go
@@ -100,55 +100,62 @@ func (p *PKI) reloadCerts(c *config.C, initial bool) *util.ContextualError {
|
||||
currentState := p.cs.Load()
|
||||
if newState.v1Cert != nil {
|
||||
if currentState.v1Cert == nil {
|
||||
return util.NewContextualError("v1 certificate was added, restart required", nil, err)
|
||||
}
|
||||
//adding certs is fine, actually. Networks-in-common confirmed in newCertState().
|
||||
} else {
|
||||
// did IP in cert change? if so, don't set
|
||||
if !slices.Equal(currentState.v1Cert.Networks(), newState.v1Cert.Networks()) {
|
||||
return util.NewContextualError(
|
||||
"Networks in new cert was different from old",
|
||||
m{"new_networks": newState.v1Cert.Networks(), "old_networks": currentState.v1Cert.Networks(), "cert_version": cert.Version1},
|
||||
nil,
|
||||
)
|
||||
}
|
||||
|
||||
// did IP in cert change? if so, don't set
|
||||
if !slices.Equal(currentState.v1Cert.Networks(), newState.v1Cert.Networks()) {
|
||||
return util.NewContextualError(
|
||||
"Networks in new cert was different from old",
|
||||
m{"new_networks": newState.v1Cert.Networks(), "old_networks": currentState.v1Cert.Networks()},
|
||||
nil,
|
||||
)
|
||||
if currentState.v1Cert.Curve() != newState.v1Cert.Curve() {
|
||||
return util.NewContextualError(
|
||||
"Curve in new v1 cert was different from old",
|
||||
m{"new_curve": newState.v1Cert.Curve(), "old_curve": currentState.v1Cert.Curve(), "cert_version": cert.Version1},
|
||||
nil,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
if currentState.v1Cert.Curve() != newState.v1Cert.Curve() {
|
||||
return util.NewContextualError(
|
||||
"Curve in new cert was different from old",
|
||||
m{"new_curve": newState.v1Cert.Curve(), "old_curve": currentState.v1Cert.Curve()},
|
||||
nil,
|
||||
)
|
||||
}
|
||||
|
||||
} else if currentState.v1Cert != nil {
|
||||
//TODO: CERT-V2 we should be able to tear this down
|
||||
return util.NewContextualError("v1 certificate was removed, restart required", nil, err)
|
||||
}
|
||||
|
||||
if newState.v2Cert != nil {
|
||||
if currentState.v2Cert == nil {
|
||||
return util.NewContextualError("v2 certificate was added, restart required", nil, err)
|
||||
}
|
||||
//adding certs is fine, actually
|
||||
} else {
|
||||
// did IP in cert change? if so, don't set
|
||||
if !slices.Equal(currentState.v2Cert.Networks(), newState.v2Cert.Networks()) {
|
||||
return util.NewContextualError(
|
||||
"Networks in new cert was different from old",
|
||||
m{"new_networks": newState.v2Cert.Networks(), "old_networks": currentState.v2Cert.Networks(), "cert_version": cert.Version2},
|
||||
nil,
|
||||
)
|
||||
}
|
||||
|
||||
// did IP in cert change? if so, don't set
|
||||
if !slices.Equal(currentState.v2Cert.Networks(), newState.v2Cert.Networks()) {
|
||||
return util.NewContextualError(
|
||||
"Networks in new cert was different from old",
|
||||
m{"new_networks": newState.v2Cert.Networks(), "old_networks": currentState.v2Cert.Networks()},
|
||||
nil,
|
||||
)
|
||||
}
|
||||
|
||||
if currentState.v2Cert.Curve() != newState.v2Cert.Curve() {
|
||||
return util.NewContextualError(
|
||||
"Curve in new cert was different from old",
|
||||
m{"new_curve": newState.v2Cert.Curve(), "old_curve": currentState.v2Cert.Curve()},
|
||||
nil,
|
||||
)
|
||||
if currentState.v2Cert.Curve() != newState.v2Cert.Curve() {
|
||||
return util.NewContextualError(
|
||||
"Curve in new cert was different from old",
|
||||
m{"new_curve": newState.v2Cert.Curve(), "old_curve": currentState.v2Cert.Curve(), "cert_version": cert.Version2},
|
||||
nil,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
} else if currentState.v2Cert != nil {
|
||||
return util.NewContextualError("v2 certificate was removed, restart required", nil, err)
|
||||
//newState.v1Cert is non-nil bc empty certstates aren't permitted
|
||||
if newState.v1Cert == nil {
|
||||
return util.NewContextualError("v1 and v2 certs are nil, this should be impossible", nil, err)
|
||||
}
|
||||
//if we're going to v1-only, we need to make sure we didn't orphan any v2-cert vpnaddrs
|
||||
if !slices.Equal(currentState.v2Cert.Networks(), newState.v1Cert.Networks()) {
|
||||
return util.NewContextualError(
|
||||
"Removing a V2 cert is not permitted unless it has identical networks to the new V1 cert",
|
||||
m{"new_v1_networks": newState.v1Cert.Networks(), "old_v2_networks": currentState.v2Cert.Networks()},
|
||||
nil,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
// Cipher cant be hot swapped so just leave it at what it was before
|
||||
|
||||
@@ -44,7 +44,10 @@ type Service struct {
|
||||
}
|
||||
|
||||
func New(control *nebula.Control) (*Service, error) {
|
||||
control.Start()
|
||||
wait, err := control.Start()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
ctx := control.Context()
|
||||
eg, ctx := errgroup.WithContext(ctx)
|
||||
@@ -141,6 +144,12 @@ func New(control *nebula.Control) (*Service, error) {
|
||||
}
|
||||
})
|
||||
|
||||
// Add the nebula wait function to the group
|
||||
eg.Go(func() error {
|
||||
wait()
|
||||
return nil
|
||||
})
|
||||
|
||||
return &s, nil
|
||||
}
|
||||
|
||||
|
||||
@@ -1,16 +0,0 @@
|
||||
package udp
|
||||
|
||||
import "sync/atomic"
|
||||
|
||||
var disableUDPCsum atomic.Bool
|
||||
|
||||
// SetDisableUDPCsum controls whether IPv4 UDP sockets opt out of kernel
|
||||
// checksum calculation via SO_NO_CHECK. Only applicable on platforms that
|
||||
// support the option (Linux). IPv6 always keeps the checksum enabled.
|
||||
func SetDisableUDPCsum(disable bool) {
|
||||
disableUDPCsum.Store(disable)
|
||||
}
|
||||
|
||||
func udpChecksumDisabled() bool {
|
||||
return disableUDPCsum.Load()
|
||||
}
|
||||
@@ -11,13 +11,12 @@ const MTU = 9001
|
||||
type EncReader func(
|
||||
addr netip.AddrPort,
|
||||
payload []byte,
|
||||
release func(),
|
||||
)
|
||||
|
||||
type Conn interface {
|
||||
Rebind() error
|
||||
LocalAddr() (netip.AddrPort, error)
|
||||
ListenOut(r EncReader)
|
||||
ListenOut(r EncReader) error
|
||||
WriteTo(b []byte, addr netip.AddrPort) error
|
||||
ReloadConfig(c *config.C)
|
||||
Close() error
|
||||
@@ -31,8 +30,8 @@ func (NoopConn) Rebind() error {
|
||||
func (NoopConn) LocalAddr() (netip.AddrPort, error) {
|
||||
return netip.AddrPort{}, nil
|
||||
}
|
||||
func (NoopConn) ListenOut(_ EncReader) {
|
||||
return
|
||||
func (NoopConn) ListenOut(_ EncReader) error {
|
||||
return nil
|
||||
}
|
||||
func (NoopConn) WriteTo(_ []byte, _ netip.AddrPort) error {
|
||||
return nil
|
||||
|
||||
@@ -1,25 +0,0 @@
|
||||
//go:build linux && (386 || amd64p32 || arm || mips || mipsle) && !android && !e2e_testing
|
||||
// +build linux
|
||||
// +build 386 amd64p32 arm mips mipsle
|
||||
// +build !android
|
||||
// +build !e2e_testing
|
||||
|
||||
package udp
|
||||
|
||||
import "golang.org/x/sys/unix"
|
||||
|
||||
func controllen(n int) uint32 {
|
||||
return uint32(n)
|
||||
}
|
||||
|
||||
func setCmsgLen(h *unix.Cmsghdr, n int) {
|
||||
h.Len = uint32(unix.CmsgLen(n))
|
||||
}
|
||||
|
||||
func setIovecLen(v *unix.Iovec, n int) {
|
||||
v.Len = uint32(n)
|
||||
}
|
||||
|
||||
func setMsghdrIovlen(m *unix.Msghdr, n int) {
|
||||
m.Iovlen = uint32(n)
|
||||
}
|
||||
@@ -1,25 +0,0 @@
|
||||
//go:build linux && (amd64 || arm64 || ppc64 || ppc64le || mips64 || mips64le || s390x || riscv64 || loong64) && !android && !e2e_testing
|
||||
// +build linux
|
||||
// +build amd64 arm64 ppc64 ppc64le mips64 mips64le s390x riscv64 loong64
|
||||
// +build !android
|
||||
// +build !e2e_testing
|
||||
|
||||
package udp
|
||||
|
||||
import "golang.org/x/sys/unix"
|
||||
|
||||
func controllen(n int) uint64 {
|
||||
return uint64(n)
|
||||
}
|
||||
|
||||
func setCmsgLen(h *unix.Cmsghdr, n int) {
|
||||
h.Len = uint64(unix.CmsgLen(n))
|
||||
}
|
||||
|
||||
func setIovecLen(v *unix.Iovec, n int) {
|
||||
v.Len = uint64(n)
|
||||
}
|
||||
|
||||
func setMsghdrIovlen(m *unix.Msghdr, n int) {
|
||||
m.Iovlen = uint64(n)
|
||||
}
|
||||
@@ -1,25 +0,0 @@
|
||||
//go:build linux && (386 || amd64p32 || arm || mips || mipsle) && !android && !e2e_testing
|
||||
|
||||
package udp
|
||||
|
||||
import (
|
||||
"unsafe"
|
||||
|
||||
"golang.org/x/sys/unix"
|
||||
)
|
||||
|
||||
type linuxMmsgHdr struct {
|
||||
Hdr unix.Msghdr
|
||||
Len uint32
|
||||
}
|
||||
|
||||
func sendmmsg(fd int, hdrs []linuxMmsgHdr, flags int) (int, error) {
|
||||
if len(hdrs) == 0 {
|
||||
return 0, nil
|
||||
}
|
||||
n, _, errno := unix.Syscall6(unix.SYS_SENDMMSG, uintptr(fd), uintptr(unsafe.Pointer(&hdrs[0])), uintptr(len(hdrs)), uintptr(flags), 0, 0)
|
||||
if errno != 0 {
|
||||
return int(n), errno
|
||||
}
|
||||
return int(n), nil
|
||||
}
|
||||
@@ -1,26 +0,0 @@
|
||||
//go:build linux && (amd64 || arm64 || ppc64 || ppc64le || mips64 || mips64le || s390x || riscv64 || loong64) && !android && !e2e_testing
|
||||
|
||||
package udp
|
||||
|
||||
import (
|
||||
"unsafe"
|
||||
|
||||
"golang.org/x/sys/unix"
|
||||
)
|
||||
|
||||
type linuxMmsgHdr struct {
|
||||
Hdr unix.Msghdr
|
||||
Len uint32
|
||||
_ uint32
|
||||
}
|
||||
|
||||
func sendmmsg(fd int, hdrs []linuxMmsgHdr, flags int) (int, error) {
|
||||
if len(hdrs) == 0 {
|
||||
return 0, nil
|
||||
}
|
||||
n, _, errno := unix.Syscall6(unix.SYS_SENDMMSG, uintptr(fd), uintptr(unsafe.Pointer(&hdrs[0])), uintptr(len(hdrs)), uintptr(flags), 0, 0)
|
||||
if errno != 0 {
|
||||
return int(n), errno
|
||||
}
|
||||
return int(n), nil
|
||||
}
|
||||
@@ -165,7 +165,7 @@ func NewUDPStatsEmitter(udpConns []Conn) func() {
|
||||
return func() {}
|
||||
}
|
||||
|
||||
func (u *StdConn) ListenOut(r EncReader) {
|
||||
func (u *StdConn) ListenOut(r EncReader) error {
|
||||
buffer := make([]byte, MTU)
|
||||
|
||||
for {
|
||||
@@ -173,14 +173,13 @@ func (u *StdConn) ListenOut(r EncReader) {
|
||||
n, rua, err := u.ReadFromUDPAddrPort(buffer)
|
||||
if err != nil {
|
||||
if errors.Is(err, net.ErrClosed) {
|
||||
u.l.WithError(err).Debug("udp socket is closed, exiting read loop")
|
||||
return
|
||||
return err
|
||||
}
|
||||
|
||||
u.l.WithError(err).Error("unexpected udp socket receive error")
|
||||
}
|
||||
|
||||
r(netip.AddrPortFrom(rua.Addr().Unmap(), rua.Port()), buffer[:n], nil)
|
||||
r(netip.AddrPortFrom(rua.Addr().Unmap(), rua.Port()), buffer[:n])
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -71,17 +71,16 @@ type rawMessage struct {
|
||||
Len uint32
|
||||
}
|
||||
|
||||
func (u *GenericConn) ListenOut(r EncReader) {
|
||||
func (u *GenericConn) ListenOut(r EncReader) error {
|
||||
buffer := make([]byte, MTU)
|
||||
|
||||
for {
|
||||
// Just read one packet at a time
|
||||
n, rua, err := u.ReadFromUDPAddrPort(buffer)
|
||||
if err != nil {
|
||||
u.l.WithError(err).Debug("udp socket is closed, exiting read loop")
|
||||
return
|
||||
return err
|
||||
}
|
||||
|
||||
r(netip.AddrPortFrom(rua.Addr().Unmap(), rua.Port()), buffer[:n], nil)
|
||||
r(netip.AddrPortFrom(rua.Addr().Unmap(), rua.Port()), buffer[:n])
|
||||
}
|
||||
}
|
||||
|
||||
1123
udp/udp_linux.go
1123
udp/udp_linux.go
File diff suppressed because it is too large
Load Diff
@@ -30,29 +30,17 @@ type rawMessage struct {
|
||||
Len uint32
|
||||
}
|
||||
|
||||
func (u *StdConn) PrepareRawMessages(n int) ([]rawMessage, [][]byte, [][]byte, [][]byte) {
|
||||
controlLen := int(u.controlLen.Load())
|
||||
|
||||
func (u *StdConn) PrepareRawMessages(n int) ([]rawMessage, [][]byte, [][]byte) {
|
||||
msgs := make([]rawMessage, n)
|
||||
buffers := make([][]byte, n)
|
||||
names := make([][]byte, n)
|
||||
|
||||
var controls [][]byte
|
||||
if controlLen > 0 {
|
||||
controls = make([][]byte, n)
|
||||
}
|
||||
|
||||
for i := range msgs {
|
||||
size := int(u.groBufSize.Load())
|
||||
if size < MTU {
|
||||
size = MTU
|
||||
}
|
||||
buf := u.borrowRxBuffer(size)
|
||||
buffers[i] = buf
|
||||
buffers[i] = make([]byte, MTU)
|
||||
names[i] = make([]byte, unix.SizeofSockaddrInet6)
|
||||
|
||||
vs := []iovec{
|
||||
{Base: &buf[0], Len: uint32(len(buf))},
|
||||
{Base: &buffers[i][0], Len: uint32(len(buffers[i]))},
|
||||
}
|
||||
|
||||
msgs[i].Hdr.Iov = &vs[0]
|
||||
@@ -60,22 +48,7 @@ func (u *StdConn) PrepareRawMessages(n int) ([]rawMessage, [][]byte, [][]byte, [
|
||||
|
||||
msgs[i].Hdr.Name = &names[i][0]
|
||||
msgs[i].Hdr.Namelen = uint32(len(names[i]))
|
||||
|
||||
if controlLen > 0 {
|
||||
controls[i] = make([]byte, controlLen)
|
||||
msgs[i].Hdr.Control = &controls[i][0]
|
||||
msgs[i].Hdr.Controllen = controllen(len(controls[i]))
|
||||
} else {
|
||||
msgs[i].Hdr.Control = nil
|
||||
msgs[i].Hdr.Controllen = controllen(0)
|
||||
}
|
||||
}
|
||||
|
||||
return msgs, buffers, names, controls
|
||||
}
|
||||
|
||||
func setIovecBase(msg *rawMessage, buf []byte) {
|
||||
iov := (*iovec)(msg.Hdr.Iov)
|
||||
iov.Base = &buf[0]
|
||||
iov.Len = uint32(len(buf))
|
||||
return msgs, buffers, names
|
||||
}
|
||||
|
||||
@@ -33,50 +33,25 @@ type rawMessage struct {
|
||||
Pad0 [4]byte
|
||||
}
|
||||
|
||||
func (u *StdConn) PrepareRawMessages(n int) ([]rawMessage, [][]byte, [][]byte, [][]byte) {
|
||||
controlLen := int(u.controlLen.Load())
|
||||
|
||||
func (u *StdConn) PrepareRawMessages(n int) ([]rawMessage, [][]byte, [][]byte) {
|
||||
msgs := make([]rawMessage, n)
|
||||
buffers := make([][]byte, n)
|
||||
names := make([][]byte, n)
|
||||
|
||||
var controls [][]byte
|
||||
if controlLen > 0 {
|
||||
controls = make([][]byte, n)
|
||||
}
|
||||
|
||||
for i := range msgs {
|
||||
size := int(u.groBufSize.Load())
|
||||
if size < MTU {
|
||||
size = MTU
|
||||
}
|
||||
buf := u.borrowRxBuffer(size)
|
||||
buffers[i] = buf
|
||||
buffers[i] = make([]byte, MTU)
|
||||
names[i] = make([]byte, unix.SizeofSockaddrInet6)
|
||||
|
||||
vs := []iovec{{Base: &buf[0], Len: uint64(len(buf))}}
|
||||
vs := []iovec{
|
||||
{Base: &buffers[i][0], Len: uint64(len(buffers[i]))},
|
||||
}
|
||||
|
||||
msgs[i].Hdr.Iov = &vs[0]
|
||||
msgs[i].Hdr.Iovlen = uint64(len(vs))
|
||||
|
||||
msgs[i].Hdr.Name = &names[i][0]
|
||||
msgs[i].Hdr.Namelen = uint32(len(names[i]))
|
||||
|
||||
if controlLen > 0 {
|
||||
controls[i] = make([]byte, controlLen)
|
||||
msgs[i].Hdr.Control = &controls[i][0]
|
||||
msgs[i].Hdr.Controllen = controllen(len(controls[i]))
|
||||
} else {
|
||||
msgs[i].Hdr.Control = nil
|
||||
msgs[i].Hdr.Controllen = controllen(0)
|
||||
}
|
||||
}
|
||||
|
||||
return msgs, buffers, names, controls
|
||||
}
|
||||
|
||||
func setIovecBase(msg *rawMessage, buf []byte) {
|
||||
iov := (*iovec)(msg.Hdr.Iov)
|
||||
iov.Base = &buf[0]
|
||||
iov.Len = uint64(len(buf))
|
||||
return msgs, buffers, names
|
||||
}
|
||||
|
||||
@@ -134,7 +134,7 @@ func (u *RIOConn) bind(sa windows.Sockaddr) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (u *RIOConn) ListenOut(r EncReader) {
|
||||
func (u *RIOConn) ListenOut(r EncReader) error {
|
||||
buffer := make([]byte, MTU)
|
||||
|
||||
for {
|
||||
@@ -142,14 +142,13 @@ func (u *RIOConn) ListenOut(r EncReader) {
|
||||
n, rua, err := u.receive(buffer)
|
||||
if err != nil {
|
||||
if errors.Is(err, net.ErrClosed) {
|
||||
u.l.WithError(err).Debug("udp socket is closed, exiting read loop")
|
||||
return
|
||||
return err
|
||||
}
|
||||
u.l.WithError(err).Error("unexpected udp socket receive error")
|
||||
continue
|
||||
}
|
||||
|
||||
r(netip.AddrPortFrom(netip.AddrFrom16(rua.Addr).Unmap(), (rua.Port>>8)|((rua.Port&0xff)<<8)), buffer[:n], nil)
|
||||
r(netip.AddrPortFrom(netip.AddrFrom16(rua.Addr).Unmap(), (rua.Port>>8)|((rua.Port&0xff)<<8)), buffer[:n])
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -6,6 +6,7 @@ package udp
|
||||
import (
|
||||
"io"
|
||||
"net/netip"
|
||||
"os"
|
||||
"sync/atomic"
|
||||
|
||||
"github.com/sirupsen/logrus"
|
||||
@@ -106,13 +107,13 @@ func (u *TesterConn) WriteTo(b []byte, addr netip.AddrPort) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (u *TesterConn) ListenOut(r EncReader) {
|
||||
func (u *TesterConn) ListenOut(r EncReader) error {
|
||||
for {
|
||||
p, ok := <-u.RxPackets
|
||||
if !ok {
|
||||
return
|
||||
return os.ErrClosed
|
||||
}
|
||||
r(p.From, p.Data, func() {})
|
||||
r(p.From, p.Data)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user