mirror of
https://github.com/slackhq/nebula.git
synced 2025-11-22 08:24:25 +01:00
Compare commits
11 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
064831cf21 | ||
|
|
9ded90c6e8 | ||
|
|
19600f257f | ||
|
|
3e2a6e0a5d | ||
|
|
bc62f5ec82 | ||
|
|
012fcf40fe | ||
|
|
ad319b964d | ||
|
|
f42878c5fc | ||
|
|
f2b3ef4b3e | ||
|
|
c3ec96d9c2 | ||
|
|
01909f4715 |
@@ -7,13 +7,6 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
|
|||||||
|
|
||||||
## [Unreleased]
|
## [Unreleased]
|
||||||
|
|
||||||
### Added
|
|
||||||
|
|
||||||
- Experimental Linux UDP offload support: enable `listen.enable_gso` and
|
|
||||||
`listen.enable_gro` to activate UDP_SEGMENT batching and GRO receive
|
|
||||||
splitting. Includes automatic capability probing, per-packet fallbacks, and
|
|
||||||
runtime metrics/logs for visibility.
|
|
||||||
|
|
||||||
### Changed
|
### Changed
|
||||||
|
|
||||||
- `default_local_cidr_any` now defaults to false, meaning that any firewall rule
|
- `default_local_cidr_any` now defaults to false, meaning that any firewall rule
|
||||||
|
|||||||
@@ -114,6 +114,33 @@ func NewTestCert(v cert.Version, curve cert.Curve, ca cert.Certificate, key []by
|
|||||||
return c, pub, cert.MarshalPrivateKeyToPEM(curve, priv), pem
|
return c, pub, cert.MarshalPrivateKeyToPEM(curve, priv), pem
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func NewTestCertDifferentVersion(c cert.Certificate, v cert.Version, ca cert.Certificate, key []byte) (cert.Certificate, []byte) {
|
||||||
|
nc := &cert.TBSCertificate{
|
||||||
|
Version: v,
|
||||||
|
Curve: c.Curve(),
|
||||||
|
Name: c.Name(),
|
||||||
|
Networks: c.Networks(),
|
||||||
|
UnsafeNetworks: c.UnsafeNetworks(),
|
||||||
|
Groups: c.Groups(),
|
||||||
|
NotBefore: time.Unix(c.NotBefore().Unix(), 0),
|
||||||
|
NotAfter: time.Unix(c.NotAfter().Unix(), 0),
|
||||||
|
PublicKey: c.PublicKey(),
|
||||||
|
IsCA: false,
|
||||||
|
}
|
||||||
|
|
||||||
|
c, err := nc.Sign(ca, ca.Curve(), key)
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
pem, err := c.MarshalPEM()
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return c, pem
|
||||||
|
}
|
||||||
|
|
||||||
func X25519Keypair() ([]byte, []byte) {
|
func X25519Keypair() ([]byte, []byte) {
|
||||||
privkey := make([]byte, 32)
|
privkey := make([]byte, 32)
|
||||||
if _, err := io.ReadFull(rand.Reader, privkey); err != nil {
|
if _, err := io.ReadFull(rand.Reader, privkey); err != nil {
|
||||||
|
|||||||
@@ -65,8 +65,16 @@ func main() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if !*configTest {
|
if !*configTest {
|
||||||
ctrl.Start()
|
wait, err := ctrl.Start()
|
||||||
ctrl.ShutdownBlock()
|
if err != nil {
|
||||||
|
util.LogWithContextIfNeeded("Error while running", err, l)
|
||||||
|
os.Exit(1)
|
||||||
|
}
|
||||||
|
|
||||||
|
go ctrl.ShutdownBlock()
|
||||||
|
wait()
|
||||||
|
|
||||||
|
l.Info("Goodbye")
|
||||||
}
|
}
|
||||||
|
|
||||||
os.Exit(0)
|
os.Exit(0)
|
||||||
|
|||||||
@@ -3,6 +3,9 @@ package main
|
|||||||
import (
|
import (
|
||||||
"flag"
|
"flag"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"log"
|
||||||
|
"net/http"
|
||||||
|
_ "net/http/pprof"
|
||||||
"os"
|
"os"
|
||||||
|
|
||||||
"github.com/sirupsen/logrus"
|
"github.com/sirupsen/logrus"
|
||||||
@@ -58,10 +61,22 @@ func main() {
|
|||||||
os.Exit(1)
|
os.Exit(1)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
log.Println(http.ListenAndServe("0.0.0.0:6060", nil))
|
||||||
|
}()
|
||||||
|
|
||||||
if !*configTest {
|
if !*configTest {
|
||||||
ctrl.Start()
|
wait, err := ctrl.Start()
|
||||||
|
if err != nil {
|
||||||
|
util.LogWithContextIfNeeded("Error while running", err, l)
|
||||||
|
os.Exit(1)
|
||||||
|
}
|
||||||
|
|
||||||
|
go ctrl.ShutdownBlock()
|
||||||
notifyReady(l)
|
notifyReady(l)
|
||||||
ctrl.ShutdownBlock()
|
wait()
|
||||||
|
|
||||||
|
l.Info("Goodbye")
|
||||||
}
|
}
|
||||||
|
|
||||||
os.Exit(0)
|
os.Exit(0)
|
||||||
|
|||||||
@@ -354,7 +354,6 @@ func (cm *connectionManager) makeTrafficDecision(localIndex uint32, now time.Tim
|
|||||||
|
|
||||||
if mainHostInfo {
|
if mainHostInfo {
|
||||||
decision = tryRehandshake
|
decision = tryRehandshake
|
||||||
|
|
||||||
} else {
|
} else {
|
||||||
if cm.shouldSwapPrimary(hostinfo) {
|
if cm.shouldSwapPrimary(hostinfo) {
|
||||||
decision = swapPrimary
|
decision = swapPrimary
|
||||||
@@ -461,6 +460,10 @@ func (cm *connectionManager) shouldSwapPrimary(current *HostInfo) bool {
|
|||||||
}
|
}
|
||||||
|
|
||||||
crt := cm.intf.pki.getCertState().getCertificate(current.ConnectionState.myCert.Version())
|
crt := cm.intf.pki.getCertState().getCertificate(current.ConnectionState.myCert.Version())
|
||||||
|
if crt == nil {
|
||||||
|
//my cert was reloaded away. We should definitely swap from this tunnel
|
||||||
|
return true
|
||||||
|
}
|
||||||
// If this tunnel is using the latest certificate then we should swap it to primary for a bit and see if things
|
// If this tunnel is using the latest certificate then we should swap it to primary for a bit and see if things
|
||||||
// settle down.
|
// settle down.
|
||||||
return bytes.Equal(current.ConnectionState.myCert.Signature(), crt.Signature())
|
return bytes.Equal(current.ConnectionState.myCert.Signature(), crt.Signature())
|
||||||
@@ -475,31 +478,34 @@ func (cm *connectionManager) swapPrimary(current, primary *HostInfo) {
|
|||||||
cm.hostMap.Unlock()
|
cm.hostMap.Unlock()
|
||||||
}
|
}
|
||||||
|
|
||||||
// isInvalidCertificate will check if we should destroy a tunnel if pki.disconnect_invalid is true and
|
// isInvalidCertificate decides if we should destroy a tunnel.
|
||||||
// the certificate is no longer valid. Block listed certificates will skip the pki.disconnect_invalid
|
// returns true if pki.disconnect_invalid is true and the certificate is no longer valid.
|
||||||
// check and return true.
|
// Blocklisted certificates will skip the pki.disconnect_invalid check and return true.
|
||||||
func (cm *connectionManager) isInvalidCertificate(now time.Time, hostinfo *HostInfo) bool {
|
func (cm *connectionManager) isInvalidCertificate(now time.Time, hostinfo *HostInfo) bool {
|
||||||
remoteCert := hostinfo.GetCert()
|
remoteCert := hostinfo.GetCert()
|
||||||
if remoteCert == nil {
|
if remoteCert == nil {
|
||||||
return false
|
return false //don't tear down tunnels for handshakes in progress
|
||||||
}
|
}
|
||||||
|
|
||||||
caPool := cm.intf.pki.GetCAPool()
|
caPool := cm.intf.pki.GetCAPool()
|
||||||
err := caPool.VerifyCachedCertificate(now, remoteCert)
|
err := caPool.VerifyCachedCertificate(now, remoteCert)
|
||||||
if err == nil {
|
if err == nil {
|
||||||
return false
|
return false //cert is still valid! yay!
|
||||||
}
|
} else if err == cert.ErrBlockListed { //avoiding errors.Is for speed
|
||||||
|
|
||||||
if !cm.intf.disconnectInvalid.Load() && err != cert.ErrBlockListed {
|
|
||||||
// Block listed certificates should always be disconnected
|
// Block listed certificates should always be disconnected
|
||||||
return false
|
hostinfo.logger(cm.l).WithError(err).
|
||||||
}
|
WithField("fingerprint", remoteCert.Fingerprint).
|
||||||
|
Info("Remote certificate is blocked, tearing down the tunnel")
|
||||||
|
return true
|
||||||
|
} else if cm.intf.disconnectInvalid.Load() {
|
||||||
hostinfo.logger(cm.l).WithError(err).
|
hostinfo.logger(cm.l).WithError(err).
|
||||||
WithField("fingerprint", remoteCert.Fingerprint).
|
WithField("fingerprint", remoteCert.Fingerprint).
|
||||||
Info("Remote certificate is no longer valid, tearing down the tunnel")
|
Info("Remote certificate is no longer valid, tearing down the tunnel")
|
||||||
|
|
||||||
return true
|
return true
|
||||||
|
} else {
|
||||||
|
//if we reach here, the cert is no longer valid, but we're configured to keep tunnels from now-invalid certs open
|
||||||
|
return false
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (cm *connectionManager) sendPunch(hostinfo *HostInfo) {
|
func (cm *connectionManager) sendPunch(hostinfo *HostInfo) {
|
||||||
@@ -530,15 +536,45 @@ func (cm *connectionManager) sendPunch(hostinfo *HostInfo) {
|
|||||||
func (cm *connectionManager) tryRehandshake(hostinfo *HostInfo) {
|
func (cm *connectionManager) tryRehandshake(hostinfo *HostInfo) {
|
||||||
cs := cm.intf.pki.getCertState()
|
cs := cm.intf.pki.getCertState()
|
||||||
curCrt := hostinfo.ConnectionState.myCert
|
curCrt := hostinfo.ConnectionState.myCert
|
||||||
myCrt := cs.getCertificate(curCrt.Version())
|
curCrtVersion := curCrt.Version()
|
||||||
if curCrt.Version() >= cs.initiatingVersion && bytes.Equal(curCrt.Signature(), myCrt.Signature()) == true {
|
myCrt := cs.getCertificate(curCrtVersion)
|
||||||
// The current tunnel is using the latest certificate and version, no need to rehandshake.
|
if myCrt == nil {
|
||||||
|
cm.l.WithField("vpnAddrs", hostinfo.vpnAddrs).
|
||||||
|
WithField("version", curCrtVersion).
|
||||||
|
WithField("reason", "local certificate removed").
|
||||||
|
Info("Re-handshaking with remote")
|
||||||
|
cm.intf.handshakeManager.StartHandshake(hostinfo.vpnAddrs[0], nil)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
peerCrt := hostinfo.ConnectionState.peerCert
|
||||||
|
if peerCrt != nil && curCrtVersion < peerCrt.Certificate.Version() {
|
||||||
|
// if our certificate version is less than theirs, and we have a matching version available, rehandshake?
|
||||||
|
if cs.getCertificate(peerCrt.Certificate.Version()) != nil {
|
||||||
|
cm.l.WithField("vpnAddrs", hostinfo.vpnAddrs).
|
||||||
|
WithField("version", curCrtVersion).
|
||||||
|
WithField("peerVersion", peerCrt.Certificate.Version()).
|
||||||
|
WithField("reason", "local certificate version lower than peer, attempting to correct").
|
||||||
|
Info("Re-handshaking with remote")
|
||||||
|
cm.intf.handshakeManager.StartHandshake(hostinfo.vpnAddrs[0], func(hh *HandshakeHostInfo) {
|
||||||
|
hh.initiatingVersionOverride = peerCrt.Certificate.Version()
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if !bytes.Equal(curCrt.Signature(), myCrt.Signature()) {
|
||||||
cm.l.WithField("vpnAddrs", hostinfo.vpnAddrs).
|
cm.l.WithField("vpnAddrs", hostinfo.vpnAddrs).
|
||||||
WithField("reason", "local certificate is not current").
|
WithField("reason", "local certificate is not current").
|
||||||
Info("Re-handshaking with remote")
|
Info("Re-handshaking with remote")
|
||||||
|
|
||||||
cm.intf.handshakeManager.StartHandshake(hostinfo.vpnAddrs[0], nil)
|
cm.intf.handshakeManager.StartHandshake(hostinfo.vpnAddrs[0], nil)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if curCrtVersion < cs.initiatingVersion {
|
||||||
|
cm.l.WithField("vpnAddrs", hostinfo.vpnAddrs).
|
||||||
|
WithField("reason", "current cert version < pki.initiatingVersion").
|
||||||
|
Info("Re-handshaking with remote")
|
||||||
|
|
||||||
|
cm.intf.handshakeManager.StartHandshake(hostinfo.vpnAddrs[0], nil)
|
||||||
|
return
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
56
control.go
56
control.go
@@ -2,9 +2,11 @@ package nebula
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"errors"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
"os"
|
"os"
|
||||||
"os/signal"
|
"os/signal"
|
||||||
|
"sync"
|
||||||
"syscall"
|
"syscall"
|
||||||
|
|
||||||
"github.com/sirupsen/logrus"
|
"github.com/sirupsen/logrus"
|
||||||
@@ -13,6 +15,16 @@ import (
|
|||||||
"github.com/slackhq/nebula/overlay"
|
"github.com/slackhq/nebula/overlay"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
type RunState int
|
||||||
|
|
||||||
|
const (
|
||||||
|
Stopped RunState = 0 // The control has yet to be started
|
||||||
|
Started RunState = 1 // The control has been started
|
||||||
|
Stopping RunState = 2 // The control is stopping
|
||||||
|
)
|
||||||
|
|
||||||
|
var ErrAlreadyStarted = errors.New("nebula is already started")
|
||||||
|
|
||||||
// Every interaction here needs to take extra care to copy memory and not return or use arguments "as is" when touching
|
// Every interaction here needs to take extra care to copy memory and not return or use arguments "as is" when touching
|
||||||
// core. This means copying IP objects, slices, de-referencing pointers and taking the actual value, etc
|
// core. This means copying IP objects, slices, de-referencing pointers and taking the actual value, etc
|
||||||
|
|
||||||
@@ -26,6 +38,9 @@ type controlHostLister interface {
|
|||||||
}
|
}
|
||||||
|
|
||||||
type Control struct {
|
type Control struct {
|
||||||
|
stateLock sync.Mutex
|
||||||
|
state RunState
|
||||||
|
|
||||||
f *Interface
|
f *Interface
|
||||||
l *logrus.Logger
|
l *logrus.Logger
|
||||||
ctx context.Context
|
ctx context.Context
|
||||||
@@ -49,10 +64,21 @@ type ControlHostInfo struct {
|
|||||||
CurrentRelaysThroughMe []netip.Addr `json:"currentRelaysThroughMe"`
|
CurrentRelaysThroughMe []netip.Addr `json:"currentRelaysThroughMe"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// Start actually runs nebula, this is a nonblocking call. To block use Control.ShutdownBlock()
|
// Start actually runs nebula, this is a nonblocking call.
|
||||||
func (c *Control) Start() {
|
// The returned function can be used to wait for nebula to fully stop.
|
||||||
|
func (c *Control) Start() (func(), error) {
|
||||||
|
c.stateLock.Lock()
|
||||||
|
if c.state != Stopped {
|
||||||
|
c.stateLock.Unlock()
|
||||||
|
return nil, ErrAlreadyStarted
|
||||||
|
}
|
||||||
|
|
||||||
// Activate the interface
|
// Activate the interface
|
||||||
c.f.activate()
|
err := c.f.activate()
|
||||||
|
if err != nil {
|
||||||
|
c.stateLock.Unlock()
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
// Call all the delayed funcs that waited patiently for the interface to be created.
|
// Call all the delayed funcs that waited patiently for the interface to be created.
|
||||||
if c.sshStart != nil {
|
if c.sshStart != nil {
|
||||||
@@ -72,15 +98,33 @@ func (c *Control) Start() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Start reading packets.
|
// Start reading packets.
|
||||||
c.f.run()
|
c.state = Started
|
||||||
|
c.stateLock.Unlock()
|
||||||
|
return c.f.run()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Control) State() RunState {
|
||||||
|
c.stateLock.Lock()
|
||||||
|
defer c.stateLock.Unlock()
|
||||||
|
return c.state
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Control) Context() context.Context {
|
func (c *Control) Context() context.Context {
|
||||||
return c.ctx
|
return c.ctx
|
||||||
}
|
}
|
||||||
|
|
||||||
// Stop signals nebula to shutdown and close all tunnels, returns after the shutdown is complete
|
// Stop is a non-blocking call that signals nebula to close all tunnels and shut down
|
||||||
func (c *Control) Stop() {
|
func (c *Control) Stop() {
|
||||||
|
c.stateLock.Lock()
|
||||||
|
if c.state != Started {
|
||||||
|
c.stateLock.Unlock()
|
||||||
|
// We are stopping or stopped already
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
c.state = Stopping
|
||||||
|
c.stateLock.Unlock()
|
||||||
|
|
||||||
// Stop the handshakeManager (and other services), to prevent new tunnels from
|
// Stop the handshakeManager (and other services), to prevent new tunnels from
|
||||||
// being created while we're shutting them all down.
|
// being created while we're shutting them all down.
|
||||||
c.cancel()
|
c.cancel()
|
||||||
@@ -89,7 +133,7 @@ func (c *Control) Stop() {
|
|||||||
if err := c.f.Close(); err != nil {
|
if err := c.f.Close(); err != nil {
|
||||||
c.l.WithError(err).Error("Close interface failed")
|
c.l.WithError(err).Error("Close interface failed")
|
||||||
}
|
}
|
||||||
c.l.Info("Goodbye")
|
c.state = Stopped
|
||||||
}
|
}
|
||||||
|
|
||||||
// ShutdownBlock will listen for and block on term and interrupt signals, calling Control.Stop() once signalled
|
// ShutdownBlock will listen for and block on term and interrupt signals, calling Control.Stop() once signalled
|
||||||
|
|||||||
@@ -129,6 +129,109 @@ func newSimpleServer(v cert.Version, caCrt cert.Certificate, caKey []byte, name
|
|||||||
return control, vpnNetworks, udpAddr, c
|
return control, vpnNetworks, udpAddr, c
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// newServer creates a nebula instance with fewer assumptions
|
||||||
|
func newServer(caCrt []cert.Certificate, certs []cert.Certificate, key []byte, overrides m) (*nebula.Control, []netip.Prefix, netip.AddrPort, *config.C) {
|
||||||
|
l := NewTestLogger()
|
||||||
|
|
||||||
|
vpnNetworks := certs[len(certs)-1].Networks()
|
||||||
|
|
||||||
|
var udpAddr netip.AddrPort
|
||||||
|
if vpnNetworks[0].Addr().Is4() {
|
||||||
|
budpIp := vpnNetworks[0].Addr().As4()
|
||||||
|
budpIp[1] -= 128
|
||||||
|
udpAddr = netip.AddrPortFrom(netip.AddrFrom4(budpIp), 4242)
|
||||||
|
} else {
|
||||||
|
budpIp := vpnNetworks[0].Addr().As16()
|
||||||
|
// beef for funsies
|
||||||
|
budpIp[2] = 190
|
||||||
|
budpIp[3] = 239
|
||||||
|
udpAddr = netip.AddrPortFrom(netip.AddrFrom16(budpIp), 4242)
|
||||||
|
}
|
||||||
|
|
||||||
|
caStr := ""
|
||||||
|
for _, ca := range caCrt {
|
||||||
|
x, err := ca.MarshalPEM()
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
caStr += string(x)
|
||||||
|
}
|
||||||
|
certStr := ""
|
||||||
|
for _, c := range certs {
|
||||||
|
x, err := c.MarshalPEM()
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
certStr += string(x)
|
||||||
|
}
|
||||||
|
|
||||||
|
mc := m{
|
||||||
|
"pki": m{
|
||||||
|
"ca": caStr,
|
||||||
|
"cert": certStr,
|
||||||
|
"key": string(key),
|
||||||
|
},
|
||||||
|
//"tun": m{"disabled": true},
|
||||||
|
"firewall": m{
|
||||||
|
"outbound": []m{{
|
||||||
|
"proto": "any",
|
||||||
|
"port": "any",
|
||||||
|
"host": "any",
|
||||||
|
}},
|
||||||
|
"inbound": []m{{
|
||||||
|
"proto": "any",
|
||||||
|
"port": "any",
|
||||||
|
"host": "any",
|
||||||
|
}},
|
||||||
|
},
|
||||||
|
//"handshakes": m{
|
||||||
|
// "try_interval": "1s",
|
||||||
|
//},
|
||||||
|
"listen": m{
|
||||||
|
"host": udpAddr.Addr().String(),
|
||||||
|
"port": udpAddr.Port(),
|
||||||
|
},
|
||||||
|
"logging": m{
|
||||||
|
"timestamp_format": fmt.Sprintf("%v 15:04:05.000000", certs[0].Name()),
|
||||||
|
"level": l.Level.String(),
|
||||||
|
},
|
||||||
|
"timers": m{
|
||||||
|
"pending_deletion_interval": 2,
|
||||||
|
"connection_alive_interval": 2,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
if overrides != nil {
|
||||||
|
final := m{}
|
||||||
|
err := mergo.Merge(&final, overrides, mergo.WithAppendSlice)
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
err = mergo.Merge(&final, mc, mergo.WithAppendSlice)
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
mc = final
|
||||||
|
}
|
||||||
|
|
||||||
|
cb, err := yaml.Marshal(mc)
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
c := config.NewC(l)
|
||||||
|
cStr := string(cb)
|
||||||
|
c.LoadString(cStr)
|
||||||
|
|
||||||
|
control, err := nebula.Main(c, false, "e2e-test", l, nil)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return control, vpnNetworks, udpAddr, c
|
||||||
|
}
|
||||||
|
|
||||||
type doneCb func()
|
type doneCb func()
|
||||||
|
|
||||||
func deadline(t *testing.T, seconds time.Duration) doneCb {
|
func deadline(t *testing.T, seconds time.Duration) doneCb {
|
||||||
|
|||||||
@@ -4,12 +4,16 @@
|
|||||||
package e2e
|
package e2e
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"fmt"
|
||||||
|
"net/netip"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/slackhq/nebula/cert"
|
"github.com/slackhq/nebula/cert"
|
||||||
"github.com/slackhq/nebula/cert_test"
|
"github.com/slackhq/nebula/cert_test"
|
||||||
"github.com/slackhq/nebula/e2e/router"
|
"github.com/slackhq/nebula/e2e/router"
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"gopkg.in/yaml.v3"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestDropInactiveTunnels(t *testing.T) {
|
func TestDropInactiveTunnels(t *testing.T) {
|
||||||
@@ -55,3 +59,262 @@ func TestDropInactiveTunnels(t *testing.T) {
|
|||||||
myControl.Stop()
|
myControl.Stop()
|
||||||
theirControl.Stop()
|
theirControl.Stop()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestCertUpgrade(t *testing.T) {
|
||||||
|
// The goal of this test is to ensure the shortest inactivity timeout will close the tunnel on both sides
|
||||||
|
// under ideal conditions
|
||||||
|
ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version1, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
|
||||||
|
caB, err := ca.MarshalPEM()
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
ca2, _, caKey2, _ := cert_test.NewTestCaCert(cert.Version2, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
|
||||||
|
|
||||||
|
ca2B, err := ca2.MarshalPEM()
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
caStr := fmt.Sprintf("%s\n%s", caB, ca2B)
|
||||||
|
|
||||||
|
myCert, _, myPrivKey, _ := cert_test.NewTestCert(cert.Version1, cert.Curve_CURVE25519, ca, caKey, "me", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{netip.MustParsePrefix("10.128.0.1/24")}, nil, []string{})
|
||||||
|
_, myCert2Pem := cert_test.NewTestCertDifferentVersion(myCert, cert.Version2, ca2, caKey2)
|
||||||
|
|
||||||
|
theirCert, _, theirPrivKey, _ := cert_test.NewTestCert(cert.Version1, cert.Curve_CURVE25519, ca, caKey, "them", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{netip.MustParsePrefix("10.128.0.2/24")}, nil, []string{})
|
||||||
|
theirCert2, _ := cert_test.NewTestCertDifferentVersion(theirCert, cert.Version2, ca2, caKey2)
|
||||||
|
|
||||||
|
myControl, myVpnIpNet, myUdpAddr, myC := newServer([]cert.Certificate{ca, ca2}, []cert.Certificate{myCert}, myPrivKey, m{})
|
||||||
|
theirControl, theirVpnIpNet, theirUdpAddr, _ := newServer([]cert.Certificate{ca, ca2}, []cert.Certificate{theirCert, theirCert2}, theirPrivKey, m{})
|
||||||
|
|
||||||
|
// Share our underlay information
|
||||||
|
myControl.InjectLightHouseAddr(theirVpnIpNet[0].Addr(), theirUdpAddr)
|
||||||
|
theirControl.InjectLightHouseAddr(myVpnIpNet[0].Addr(), myUdpAddr)
|
||||||
|
|
||||||
|
// Start the servers
|
||||||
|
myControl.Start()
|
||||||
|
theirControl.Start()
|
||||||
|
|
||||||
|
r := router.NewR(t, myControl, theirControl)
|
||||||
|
defer r.RenderFlow()
|
||||||
|
|
||||||
|
r.Log("Assert the tunnel between me and them works")
|
||||||
|
assertTunnel(t, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), myControl, theirControl, r)
|
||||||
|
r.Log("yay")
|
||||||
|
//todo ???
|
||||||
|
time.Sleep(1 * time.Second)
|
||||||
|
r.FlushAll()
|
||||||
|
|
||||||
|
mc := m{
|
||||||
|
"pki": m{
|
||||||
|
"ca": caStr,
|
||||||
|
"cert": string(myCert2Pem),
|
||||||
|
"key": string(myPrivKey),
|
||||||
|
},
|
||||||
|
//"tun": m{"disabled": true},
|
||||||
|
"firewall": myC.Settings["firewall"],
|
||||||
|
//"handshakes": m{
|
||||||
|
// "try_interval": "1s",
|
||||||
|
//},
|
||||||
|
"listen": myC.Settings["listen"],
|
||||||
|
"logging": myC.Settings["logging"],
|
||||||
|
"timers": myC.Settings["timers"],
|
||||||
|
}
|
||||||
|
|
||||||
|
cb, err := yaml.Marshal(mc)
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
r.Logf("reload new v2-only config")
|
||||||
|
err = myC.ReloadConfigString(string(cb))
|
||||||
|
assert.NoError(t, err)
|
||||||
|
r.Log("yay, spin until their sees it")
|
||||||
|
waitStart := time.Now()
|
||||||
|
for {
|
||||||
|
assertTunnel(t, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), myControl, theirControl, r)
|
||||||
|
c := theirControl.GetHostInfoByVpnAddr(myVpnIpNet[0].Addr(), false)
|
||||||
|
if c == nil {
|
||||||
|
r.Log("nil")
|
||||||
|
} else {
|
||||||
|
version := c.Cert.Version()
|
||||||
|
r.Logf("version %d", version)
|
||||||
|
if version == cert.Version2 {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
since := time.Since(waitStart)
|
||||||
|
if since > time.Second*10 {
|
||||||
|
t.Fatal("Cert should be new by now")
|
||||||
|
}
|
||||||
|
time.Sleep(time.Second)
|
||||||
|
}
|
||||||
|
|
||||||
|
r.RenderHostmaps("Final hostmaps", myControl, theirControl)
|
||||||
|
|
||||||
|
myControl.Stop()
|
||||||
|
theirControl.Stop()
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCertDowngrade(t *testing.T) {
|
||||||
|
// The goal of this test is to ensure the shortest inactivity timeout will close the tunnel on both sides
|
||||||
|
// under ideal conditions
|
||||||
|
ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version1, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
|
||||||
|
caB, err := ca.MarshalPEM()
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
ca2, _, caKey2, _ := cert_test.NewTestCaCert(cert.Version2, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
|
||||||
|
|
||||||
|
ca2B, err := ca2.MarshalPEM()
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
caStr := fmt.Sprintf("%s\n%s", caB, ca2B)
|
||||||
|
|
||||||
|
myCert, _, myPrivKey, myCertPem := cert_test.NewTestCert(cert.Version1, cert.Curve_CURVE25519, ca, caKey, "me", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{netip.MustParsePrefix("10.128.0.1/24")}, nil, []string{})
|
||||||
|
myCert2, _ := cert_test.NewTestCertDifferentVersion(myCert, cert.Version2, ca2, caKey2)
|
||||||
|
|
||||||
|
theirCert, _, theirPrivKey, _ := cert_test.NewTestCert(cert.Version1, cert.Curve_CURVE25519, ca, caKey, "them", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{netip.MustParsePrefix("10.128.0.2/24")}, nil, []string{})
|
||||||
|
theirCert2, _ := cert_test.NewTestCertDifferentVersion(theirCert, cert.Version2, ca2, caKey2)
|
||||||
|
|
||||||
|
myControl, myVpnIpNet, myUdpAddr, myC := newServer([]cert.Certificate{ca, ca2}, []cert.Certificate{myCert2}, myPrivKey, m{})
|
||||||
|
theirControl, theirVpnIpNet, theirUdpAddr, _ := newServer([]cert.Certificate{ca, ca2}, []cert.Certificate{theirCert, theirCert2}, theirPrivKey, m{})
|
||||||
|
|
||||||
|
// Share our underlay information
|
||||||
|
myControl.InjectLightHouseAddr(theirVpnIpNet[0].Addr(), theirUdpAddr)
|
||||||
|
theirControl.InjectLightHouseAddr(myVpnIpNet[0].Addr(), myUdpAddr)
|
||||||
|
|
||||||
|
// Start the servers
|
||||||
|
myControl.Start()
|
||||||
|
theirControl.Start()
|
||||||
|
|
||||||
|
r := router.NewR(t, myControl, theirControl)
|
||||||
|
defer r.RenderFlow()
|
||||||
|
|
||||||
|
r.Log("Assert the tunnel between me and them works")
|
||||||
|
//assertTunnel(t, theirVpnIpNet[0].Addr(), myVpnIpNet[0].Addr(), theirControl, myControl, r)
|
||||||
|
//r.Log("yay")
|
||||||
|
assertTunnel(t, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), myControl, theirControl, r)
|
||||||
|
r.Log("yay")
|
||||||
|
//todo ???
|
||||||
|
time.Sleep(1 * time.Second)
|
||||||
|
r.FlushAll()
|
||||||
|
|
||||||
|
mc := m{
|
||||||
|
"pki": m{
|
||||||
|
"ca": caStr,
|
||||||
|
"cert": string(myCertPem),
|
||||||
|
"key": string(myPrivKey),
|
||||||
|
},
|
||||||
|
"firewall": myC.Settings["firewall"],
|
||||||
|
"listen": myC.Settings["listen"],
|
||||||
|
"logging": myC.Settings["logging"],
|
||||||
|
"timers": myC.Settings["timers"],
|
||||||
|
}
|
||||||
|
|
||||||
|
cb, err := yaml.Marshal(mc)
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
r.Logf("reload new v1-only config")
|
||||||
|
err = myC.ReloadConfigString(string(cb))
|
||||||
|
assert.NoError(t, err)
|
||||||
|
r.Log("yay, spin until their sees it")
|
||||||
|
waitStart := time.Now()
|
||||||
|
for {
|
||||||
|
assertTunnel(t, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), myControl, theirControl, r)
|
||||||
|
c := theirControl.GetHostInfoByVpnAddr(myVpnIpNet[0].Addr(), false)
|
||||||
|
c2 := myControl.GetHostInfoByVpnAddr(theirVpnIpNet[0].Addr(), false)
|
||||||
|
if c == nil || c2 == nil {
|
||||||
|
r.Log("nil")
|
||||||
|
} else {
|
||||||
|
version := c.Cert.Version()
|
||||||
|
theirVersion := c2.Cert.Version()
|
||||||
|
r.Logf("version %d,%d", version, theirVersion)
|
||||||
|
if version == cert.Version1 {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
since := time.Since(waitStart)
|
||||||
|
if since > time.Second*5 {
|
||||||
|
r.Log("it is unusual that the cert is not new yet, but not a failure yet")
|
||||||
|
}
|
||||||
|
if since > time.Second*10 {
|
||||||
|
r.Log("wtf")
|
||||||
|
t.Fatal("Cert should be new by now")
|
||||||
|
}
|
||||||
|
time.Sleep(time.Second)
|
||||||
|
}
|
||||||
|
|
||||||
|
r.RenderHostmaps("Final hostmaps", myControl, theirControl)
|
||||||
|
|
||||||
|
myControl.Stop()
|
||||||
|
theirControl.Stop()
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCertMismatchCorrection(t *testing.T) {
|
||||||
|
// The goal of this test is to ensure the shortest inactivity timeout will close the tunnel on both sides
|
||||||
|
// under ideal conditions
|
||||||
|
ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version1, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
|
||||||
|
ca2, _, caKey2, _ := cert_test.NewTestCaCert(cert.Version2, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
|
||||||
|
|
||||||
|
myCert, _, myPrivKey, _ := cert_test.NewTestCert(cert.Version1, cert.Curve_CURVE25519, ca, caKey, "me", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{netip.MustParsePrefix("10.128.0.1/24")}, nil, []string{})
|
||||||
|
myCert2, _ := cert_test.NewTestCertDifferentVersion(myCert, cert.Version2, ca2, caKey2)
|
||||||
|
|
||||||
|
theirCert, _, theirPrivKey, _ := cert_test.NewTestCert(cert.Version1, cert.Curve_CURVE25519, ca, caKey, "them", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{netip.MustParsePrefix("10.128.0.2/24")}, nil, []string{})
|
||||||
|
theirCert2, _ := cert_test.NewTestCertDifferentVersion(theirCert, cert.Version2, ca2, caKey2)
|
||||||
|
|
||||||
|
myControl, myVpnIpNet, myUdpAddr, _ := newServer([]cert.Certificate{ca, ca2}, []cert.Certificate{myCert2}, myPrivKey, m{})
|
||||||
|
theirControl, theirVpnIpNet, theirUdpAddr, _ := newServer([]cert.Certificate{ca, ca2}, []cert.Certificate{theirCert, theirCert2}, theirPrivKey, m{})
|
||||||
|
|
||||||
|
// Share our underlay information
|
||||||
|
myControl.InjectLightHouseAddr(theirVpnIpNet[0].Addr(), theirUdpAddr)
|
||||||
|
theirControl.InjectLightHouseAddr(myVpnIpNet[0].Addr(), myUdpAddr)
|
||||||
|
|
||||||
|
// Start the servers
|
||||||
|
myControl.Start()
|
||||||
|
theirControl.Start()
|
||||||
|
|
||||||
|
r := router.NewR(t, myControl, theirControl)
|
||||||
|
defer r.RenderFlow()
|
||||||
|
|
||||||
|
r.Log("Assert the tunnel between me and them works")
|
||||||
|
//assertTunnel(t, theirVpnIpNet[0].Addr(), myVpnIpNet[0].Addr(), theirControl, myControl, r)
|
||||||
|
//r.Log("yay")
|
||||||
|
assertTunnel(t, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), myControl, theirControl, r)
|
||||||
|
r.Log("yay")
|
||||||
|
//todo ???
|
||||||
|
time.Sleep(1 * time.Second)
|
||||||
|
r.FlushAll()
|
||||||
|
|
||||||
|
waitStart := time.Now()
|
||||||
|
for {
|
||||||
|
assertTunnel(t, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), myControl, theirControl, r)
|
||||||
|
c := theirControl.GetHostInfoByVpnAddr(myVpnIpNet[0].Addr(), false)
|
||||||
|
c2 := myControl.GetHostInfoByVpnAddr(theirVpnIpNet[0].Addr(), false)
|
||||||
|
if c == nil || c2 == nil {
|
||||||
|
r.Log("nil")
|
||||||
|
} else {
|
||||||
|
version := c.Cert.Version()
|
||||||
|
theirVersion := c2.Cert.Version()
|
||||||
|
r.Logf("version %d,%d", version, theirVersion)
|
||||||
|
if version == theirVersion {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
since := time.Since(waitStart)
|
||||||
|
if since > time.Second*5 {
|
||||||
|
r.Log("wtf")
|
||||||
|
}
|
||||||
|
if since > time.Second*10 {
|
||||||
|
r.Log("wtf")
|
||||||
|
t.Fatal("Cert should be new by now")
|
||||||
|
}
|
||||||
|
time.Sleep(time.Second)
|
||||||
|
}
|
||||||
|
|
||||||
|
r.RenderHostmaps("Final hostmaps", myControl, theirControl)
|
||||||
|
|
||||||
|
myControl.Stop()
|
||||||
|
theirControl.Stop()
|
||||||
|
}
|
||||||
|
|||||||
@@ -23,15 +23,19 @@ func ixHandshakeStage0(f *Interface, hh *HandshakeHostInfo) bool {
|
|||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
// If we're connecting to a v6 address we must use a v2 cert
|
|
||||||
cs := f.pki.getCertState()
|
cs := f.pki.getCertState()
|
||||||
v := cs.initiatingVersion
|
v := cs.initiatingVersion
|
||||||
|
if hh.initiatingVersionOverride != cert.VersionPre1 {
|
||||||
|
v = hh.initiatingVersionOverride
|
||||||
|
} else if v < cert.Version2 {
|
||||||
|
// If we're connecting to a v6 address we should encourage use of a V2 cert
|
||||||
for _, a := range hh.hostinfo.vpnAddrs {
|
for _, a := range hh.hostinfo.vpnAddrs {
|
||||||
if a.Is6() {
|
if a.Is6() {
|
||||||
v = cert.Version2
|
v = cert.Version2
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
crt := cs.getCertificate(v)
|
crt := cs.getCertificate(v)
|
||||||
if crt == nil {
|
if crt == nil {
|
||||||
@@ -48,6 +52,7 @@ func ixHandshakeStage0(f *Interface, hh *HandshakeHostInfo) bool {
|
|||||||
WithField("handshake", m{"stage": 0, "style": "ix_psk0"}).
|
WithField("handshake", m{"stage": 0, "style": "ix_psk0"}).
|
||||||
WithField("certVersion", v).
|
WithField("certVersion", v).
|
||||||
Error("Unable to handshake with host because no certificate handshake bytes is available")
|
Error("Unable to handshake with host because no certificate handshake bytes is available")
|
||||||
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
ci, err := NewConnectionState(f.l, cs, crt, true, noise.HandshakeIX)
|
ci, err := NewConnectionState(f.l, cs, crt, true, noise.HandshakeIX)
|
||||||
@@ -103,6 +108,7 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet
|
|||||||
WithField("handshake", m{"stage": 0, "style": "ix_psk0"}).
|
WithField("handshake", m{"stage": 0, "style": "ix_psk0"}).
|
||||||
WithField("certVersion", cs.initiatingVersion).
|
WithField("certVersion", cs.initiatingVersion).
|
||||||
Error("Unable to handshake with host because no certificate is available")
|
Error("Unable to handshake with host because no certificate is available")
|
||||||
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
ci, err := NewConnectionState(f.l, cs, crt, false, noise.HandshakeIX)
|
ci, err := NewConnectionState(f.l, cs, crt, false, noise.HandshakeIX)
|
||||||
@@ -143,8 +149,8 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet
|
|||||||
|
|
||||||
remoteCert, err := f.pki.GetCAPool().VerifyCertificate(time.Now(), rc)
|
remoteCert, err := f.pki.GetCAPool().VerifyCertificate(time.Now(), rc)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
fp, err := rc.Fingerprint()
|
fp, fperr := rc.Fingerprint()
|
||||||
if err != nil {
|
if fperr != nil {
|
||||||
fp = "<error generating certificate fingerprint>"
|
fp = "<error generating certificate fingerprint>"
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -163,16 +169,19 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet
|
|||||||
|
|
||||||
if remoteCert.Certificate.Version() != ci.myCert.Version() {
|
if remoteCert.Certificate.Version() != ci.myCert.Version() {
|
||||||
// We started off using the wrong certificate version, lets see if we can match the version that was sent to us
|
// We started off using the wrong certificate version, lets see if we can match the version that was sent to us
|
||||||
rc := cs.getCertificate(remoteCert.Certificate.Version())
|
myCertOtherVersion := cs.getCertificate(remoteCert.Certificate.Version())
|
||||||
if rc == nil {
|
if myCertOtherVersion == nil {
|
||||||
f.l.WithError(err).WithField("udpAddr", addr).
|
if f.l.Level >= logrus.DebugLevel {
|
||||||
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).WithField("cert", remoteCert).
|
f.l.WithError(err).WithFields(m{
|
||||||
Info("Unable to handshake with host due to missing certificate version")
|
"udpAddr": addr,
|
||||||
return
|
"handshake": m{"stage": 1, "style": "ix_psk0"},
|
||||||
|
"cert": remoteCert,
|
||||||
|
}).Debug("Might be unable to handshake with host due to missing certificate version")
|
||||||
}
|
}
|
||||||
|
} else {
|
||||||
// Record the certificate we are actually using
|
// Record the certificate we are actually using
|
||||||
ci.myCert = rc
|
ci.myCert = myCertOtherVersion
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(remoteCert.Certificate.Networks()) == 0 {
|
if len(remoteCert.Certificate.Networks()) == 0 {
|
||||||
|
|||||||
@@ -70,6 +70,7 @@ type HandshakeHostInfo struct {
|
|||||||
|
|
||||||
startTime time.Time // Time that we first started trying with this handshake
|
startTime time.Time // Time that we first started trying with this handshake
|
||||||
ready bool // Is the handshake ready
|
ready bool // Is the handshake ready
|
||||||
|
initiatingVersionOverride cert.Version // Should we use a non-default cert version for this handshake?
|
||||||
counter int64 // How many attempts have we made so far
|
counter int64 // How many attempts have we made so far
|
||||||
lastRemotes []netip.AddrPort // Remotes that we sent to during the previous attempt
|
lastRemotes []netip.AddrPort // Remotes that we sent to during the previous attempt
|
||||||
packetStore []*cachedPacket // A set of packets to be transmitted once the handshake completes
|
packetStore []*cachedPacket // A set of packets to be transmitted once the handshake completes
|
||||||
|
|||||||
56
interface.go
56
interface.go
@@ -6,8 +6,8 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
"os"
|
|
||||||
"runtime"
|
"runtime"
|
||||||
|
"sync"
|
||||||
"sync/atomic"
|
"sync/atomic"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
@@ -87,6 +87,7 @@ type Interface struct {
|
|||||||
|
|
||||||
writers []udp.Conn
|
writers []udp.Conn
|
||||||
readers []io.ReadWriteCloser
|
readers []io.ReadWriteCloser
|
||||||
|
wg sync.WaitGroup
|
||||||
|
|
||||||
metricHandshakes metrics.Histogram
|
metricHandshakes metrics.Histogram
|
||||||
messageMetrics *MessageMetrics
|
messageMetrics *MessageMetrics
|
||||||
@@ -209,7 +210,7 @@ func NewInterface(ctx context.Context, c *InterfaceConfig) (*Interface, error) {
|
|||||||
// activate creates the interface on the host. After the interface is created, any
|
// activate creates the interface on the host. After the interface is created, any
|
||||||
// other services that want to bind listeners to its IP may do so successfully. However,
|
// other services that want to bind listeners to its IP may do so successfully. However,
|
||||||
// the interface isn't going to process anything until run() is called.
|
// the interface isn't going to process anything until run() is called.
|
||||||
func (f *Interface) activate() {
|
func (f *Interface) activate() error {
|
||||||
// actually turn on tun dev
|
// actually turn on tun dev
|
||||||
|
|
||||||
addr, err := f.outside.LocalAddr()
|
addr, err := f.outside.LocalAddr()
|
||||||
@@ -230,33 +231,38 @@ func (f *Interface) activate() {
|
|||||||
if i > 0 {
|
if i > 0 {
|
||||||
reader, err = f.inside.NewMultiQueueReader()
|
reader, err = f.inside.NewMultiQueueReader()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
f.l.Fatal(err)
|
return err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
f.readers[i] = reader
|
f.readers[i] = reader
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := f.inside.Activate(); err != nil {
|
if err = f.inside.Activate(); err != nil {
|
||||||
f.inside.Close()
|
f.inside.Close()
|
||||||
f.l.Fatal(err)
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (f *Interface) run() {
|
func (f *Interface) run() (func(), error) {
|
||||||
// Launch n queues to read packets from udp
|
// Launch n queues to read packets from udp
|
||||||
for i := 0; i < f.routines; i++ {
|
for i := 0; i < f.routines; i++ {
|
||||||
|
f.wg.Add(1)
|
||||||
go f.listenOut(i)
|
go f.listenOut(i)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Launch n queues to read packets from tun dev
|
// Launch n queues to read packets from tun dev
|
||||||
for i := 0; i < f.routines; i++ {
|
for i := 0; i < f.routines; i++ {
|
||||||
|
f.wg.Add(1)
|
||||||
go f.listenIn(f.readers[i], i)
|
go f.listenIn(f.readers[i], i)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
return f.wg.Wait, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (f *Interface) listenOut(i int) {
|
func (f *Interface) listenOut(i int) {
|
||||||
runtime.LockOSThread()
|
runtime.LockOSThread()
|
||||||
|
|
||||||
var li udp.Conn
|
var li udp.Conn
|
||||||
if i > 0 {
|
if i > 0 {
|
||||||
li = f.writers[i]
|
li = f.writers[i]
|
||||||
@@ -271,14 +277,21 @@ func (f *Interface) listenOut(i int) {
|
|||||||
fwPacket := &firewall.Packet{}
|
fwPacket := &firewall.Packet{}
|
||||||
nb := make([]byte, 12, 12)
|
nb := make([]byte, 12, 12)
|
||||||
|
|
||||||
li.ListenOut(func(fromUdpAddr netip.AddrPort, payload []byte) {
|
err := li.ListenOut(func(fromUdpAddr netip.AddrPort, payload []byte) {
|
||||||
f.readOutsidePackets(fromUdpAddr, nil, plaintext[:0], payload, h, fwPacket, lhh, nb, i, ctCache.Get(f.l))
|
f.readOutsidePackets(fromUdpAddr, nil, plaintext[:0], payload, h, fwPacket, lhh, nb, i, ctCache.Get(f.l))
|
||||||
})
|
})
|
||||||
|
|
||||||
|
if err != nil && !f.closed.Load() {
|
||||||
|
f.l.WithError(err).Error("Error while reading packet inbound packet, closing")
|
||||||
|
//TODO: Trigger Control to close
|
||||||
|
}
|
||||||
|
|
||||||
|
f.l.Debugf("underlay reader %v is done", i)
|
||||||
|
f.wg.Done()
|
||||||
}
|
}
|
||||||
|
|
||||||
func (f *Interface) listenIn(reader io.ReadWriteCloser, i int) {
|
func (f *Interface) listenIn(reader io.ReadWriteCloser, i int) {
|
||||||
runtime.LockOSThread()
|
runtime.LockOSThread()
|
||||||
|
|
||||||
packet := make([]byte, mtu)
|
packet := make([]byte, mtu)
|
||||||
out := make([]byte, mtu)
|
out := make([]byte, mtu)
|
||||||
fwPacket := &firewall.Packet{}
|
fwPacket := &firewall.Packet{}
|
||||||
@@ -289,17 +302,18 @@ func (f *Interface) listenIn(reader io.ReadWriteCloser, i int) {
|
|||||||
for {
|
for {
|
||||||
n, err := reader.Read(packet)
|
n, err := reader.Read(packet)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if errors.Is(err, os.ErrClosed) && f.closed.Load() {
|
if !f.closed.Load() {
|
||||||
return
|
f.l.WithError(err).Error("Error while reading outbound packet, closing")
|
||||||
|
//TODO: Trigger Control to close
|
||||||
}
|
}
|
||||||
|
break
|
||||||
f.l.WithError(err).Error("Error while reading outbound packet")
|
|
||||||
// This only seems to happen when something fatal happens to the fd, so exit.
|
|
||||||
os.Exit(2)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
f.consumeInsidePacket(packet[:n], fwPacket, nb, out, i, conntrackCache.Get(f.l))
|
f.consumeInsidePacket(packet[:n], fwPacket, nb, out, i, conntrackCache.Get(f.l))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
f.l.Debugf("overlay reader %v is done", i)
|
||||||
|
f.wg.Done()
|
||||||
}
|
}
|
||||||
|
|
||||||
func (f *Interface) RegisterConfigChangeCallbacks(c *config.C) {
|
func (f *Interface) RegisterConfigChangeCallbacks(c *config.C) {
|
||||||
@@ -451,6 +465,7 @@ func (f *Interface) GetCertState() *CertState {
|
|||||||
func (f *Interface) Close() error {
|
func (f *Interface) Close() error {
|
||||||
f.closed.Store(true)
|
f.closed.Store(true)
|
||||||
|
|
||||||
|
// Release the udp readers
|
||||||
for _, u := range f.writers {
|
for _, u := range f.writers {
|
||||||
err := u.Close()
|
err := u.Close()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -458,6 +473,13 @@ func (f *Interface) Close() error {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Release the tun device
|
// Release the tun readers
|
||||||
return f.inside.Close()
|
for _, u := range f.readers {
|
||||||
|
err := u.Close()
|
||||||
|
if err != nil {
|
||||||
|
f.l.WithError(err).Error("Error while closing tun device")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|||||||
18
main.go
18
main.go
@@ -284,14 +284,14 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
|
|||||||
}
|
}
|
||||||
|
|
||||||
return &Control{
|
return &Control{
|
||||||
ifce,
|
f: ifce,
|
||||||
l,
|
l: l,
|
||||||
ctx,
|
ctx: ctx,
|
||||||
cancel,
|
cancel: cancel,
|
||||||
sshStart,
|
sshStart: sshStart,
|
||||||
statsStart,
|
statsStart: statsStart,
|
||||||
dnsStart,
|
dnsStart: dnsStart,
|
||||||
lightHouse.StartUpdateWorker,
|
lighthouseStart: lightHouse.StartUpdateWorker,
|
||||||
connManager.Start,
|
connectionManagerStart: connManager.Start,
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -29,7 +29,7 @@ func (f *Interface) readOutsidePackets(ip netip.AddrPort, via *ViaSender, out []
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
//l.Error("in packet ", header, packet[HeaderLen:])
|
//f.l.Error("in packet ", h)
|
||||||
if ip.IsValid() {
|
if ip.IsValid() {
|
||||||
if f.myVpnNetworksTable.Contains(ip.Addr()) {
|
if f.myVpnNetworksTable.Contains(ip.Addr()) {
|
||||||
if f.l.Level >= logrus.DebugLevel {
|
if f.l.Level >= logrus.DebugLevel {
|
||||||
|
|||||||
39
pki.go
39
pki.go
@@ -100,41 +100,36 @@ func (p *PKI) reloadCerts(c *config.C, initial bool) *util.ContextualError {
|
|||||||
currentState := p.cs.Load()
|
currentState := p.cs.Load()
|
||||||
if newState.v1Cert != nil {
|
if newState.v1Cert != nil {
|
||||||
if currentState.v1Cert == nil {
|
if currentState.v1Cert == nil {
|
||||||
return util.NewContextualError("v1 certificate was added, restart required", nil, err)
|
//adding certs is fine, actually. Networks-in-common confirmed in newCertState().
|
||||||
}
|
} else {
|
||||||
|
|
||||||
// did IP in cert change? if so, don't set
|
// did IP in cert change? if so, don't set
|
||||||
if !slices.Equal(currentState.v1Cert.Networks(), newState.v1Cert.Networks()) {
|
if !slices.Equal(currentState.v1Cert.Networks(), newState.v1Cert.Networks()) {
|
||||||
return util.NewContextualError(
|
return util.NewContextualError(
|
||||||
"Networks in new cert was different from old",
|
"Networks in new cert was different from old",
|
||||||
m{"new_networks": newState.v1Cert.Networks(), "old_networks": currentState.v1Cert.Networks()},
|
m{"new_networks": newState.v1Cert.Networks(), "old_networks": currentState.v1Cert.Networks(), "cert_version": cert.Version1},
|
||||||
nil,
|
nil,
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
if currentState.v1Cert.Curve() != newState.v1Cert.Curve() {
|
if currentState.v1Cert.Curve() != newState.v1Cert.Curve() {
|
||||||
return util.NewContextualError(
|
return util.NewContextualError(
|
||||||
"Curve in new cert was different from old",
|
"Curve in new v1 cert was different from old",
|
||||||
m{"new_curve": newState.v1Cert.Curve(), "old_curve": currentState.v1Cert.Curve()},
|
m{"new_curve": newState.v1Cert.Curve(), "old_curve": currentState.v1Cert.Curve(), "cert_version": cert.Version1},
|
||||||
nil,
|
nil,
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
}
|
||||||
} else if currentState.v1Cert != nil {
|
|
||||||
//TODO: CERT-V2 we should be able to tear this down
|
|
||||||
return util.NewContextualError("v1 certificate was removed, restart required", nil, err)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if newState.v2Cert != nil {
|
if newState.v2Cert != nil {
|
||||||
if currentState.v2Cert == nil {
|
if currentState.v2Cert == nil {
|
||||||
return util.NewContextualError("v2 certificate was added, restart required", nil, err)
|
//adding certs is fine, actually
|
||||||
}
|
} else {
|
||||||
|
|
||||||
// did IP in cert change? if so, don't set
|
// did IP in cert change? if so, don't set
|
||||||
if !slices.Equal(currentState.v2Cert.Networks(), newState.v2Cert.Networks()) {
|
if !slices.Equal(currentState.v2Cert.Networks(), newState.v2Cert.Networks()) {
|
||||||
return util.NewContextualError(
|
return util.NewContextualError(
|
||||||
"Networks in new cert was different from old",
|
"Networks in new cert was different from old",
|
||||||
m{"new_networks": newState.v2Cert.Networks(), "old_networks": currentState.v2Cert.Networks()},
|
m{"new_networks": newState.v2Cert.Networks(), "old_networks": currentState.v2Cert.Networks(), "cert_version": cert.Version2},
|
||||||
nil,
|
nil,
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
@@ -142,13 +137,25 @@ func (p *PKI) reloadCerts(c *config.C, initial bool) *util.ContextualError {
|
|||||||
if currentState.v2Cert.Curve() != newState.v2Cert.Curve() {
|
if currentState.v2Cert.Curve() != newState.v2Cert.Curve() {
|
||||||
return util.NewContextualError(
|
return util.NewContextualError(
|
||||||
"Curve in new cert was different from old",
|
"Curve in new cert was different from old",
|
||||||
m{"new_curve": newState.v2Cert.Curve(), "old_curve": currentState.v2Cert.Curve()},
|
m{"new_curve": newState.v2Cert.Curve(), "old_curve": currentState.v2Cert.Curve(), "cert_version": cert.Version2},
|
||||||
nil,
|
nil,
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
} else if currentState.v2Cert != nil {
|
} else if currentState.v2Cert != nil {
|
||||||
return util.NewContextualError("v2 certificate was removed, restart required", nil, err)
|
//newState.v1Cert is non-nil bc empty certstates aren't permitted
|
||||||
|
if newState.v1Cert == nil {
|
||||||
|
return util.NewContextualError("v1 and v2 certs are nil, this should be impossible", nil, err)
|
||||||
|
}
|
||||||
|
//if we're going to v1-only, we need to make sure we didn't orphan any v2-cert vpnaddrs
|
||||||
|
if !slices.Equal(currentState.v2Cert.Networks(), newState.v1Cert.Networks()) {
|
||||||
|
return util.NewContextualError(
|
||||||
|
"Removing a V2 cert is not permitted unless it has identical networks to the new V1 cert",
|
||||||
|
m{"new_v1_networks": newState.v1Cert.Networks(), "old_v2_networks": currentState.v2Cert.Networks()},
|
||||||
|
nil,
|
||||||
|
)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Cipher cant be hot swapped so just leave it at what it was before
|
// Cipher cant be hot swapped so just leave it at what it was before
|
||||||
|
|||||||
@@ -44,7 +44,10 @@ type Service struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func New(control *nebula.Control) (*Service, error) {
|
func New(control *nebula.Control) (*Service, error) {
|
||||||
control.Start()
|
wait, err := control.Start()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
ctx := control.Context()
|
ctx := control.Context()
|
||||||
eg, ctx := errgroup.WithContext(ctx)
|
eg, ctx := errgroup.WithContext(ctx)
|
||||||
@@ -141,6 +144,12 @@ func New(control *nebula.Control) (*Service, error) {
|
|||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
|
// Add the nebula wait function to the group
|
||||||
|
eg.Go(func() error {
|
||||||
|
wait()
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
|
||||||
return &s, nil
|
return &s, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -16,7 +16,7 @@ type EncReader func(
|
|||||||
type Conn interface {
|
type Conn interface {
|
||||||
Rebind() error
|
Rebind() error
|
||||||
LocalAddr() (netip.AddrPort, error)
|
LocalAddr() (netip.AddrPort, error)
|
||||||
ListenOut(r EncReader)
|
ListenOut(r EncReader) error
|
||||||
WriteTo(b []byte, addr netip.AddrPort) error
|
WriteTo(b []byte, addr netip.AddrPort) error
|
||||||
ReloadConfig(c *config.C)
|
ReloadConfig(c *config.C)
|
||||||
Close() error
|
Close() error
|
||||||
@@ -30,8 +30,8 @@ func (NoopConn) Rebind() error {
|
|||||||
func (NoopConn) LocalAddr() (netip.AddrPort, error) {
|
func (NoopConn) LocalAddr() (netip.AddrPort, error) {
|
||||||
return netip.AddrPort{}, nil
|
return netip.AddrPort{}, nil
|
||||||
}
|
}
|
||||||
func (NoopConn) ListenOut(_ EncReader) {
|
func (NoopConn) ListenOut(_ EncReader) error {
|
||||||
return
|
return nil
|
||||||
}
|
}
|
||||||
func (NoopConn) WriteTo(_ []byte, _ netip.AddrPort) error {
|
func (NoopConn) WriteTo(_ []byte, _ netip.AddrPort) error {
|
||||||
return nil
|
return nil
|
||||||
|
|||||||
@@ -1,17 +0,0 @@
|
|||||||
//go:build linux && (386 || amd64p32 || arm || mips || mipsle) && !android && !e2e_testing
|
|
||||||
// +build linux
|
|
||||||
// +build 386 amd64p32 arm mips mipsle
|
|
||||||
// +build !android
|
|
||||||
// +build !e2e_testing
|
|
||||||
|
|
||||||
package udp
|
|
||||||
|
|
||||||
import "golang.org/x/sys/unix"
|
|
||||||
|
|
||||||
func controllen(n int) uint32 {
|
|
||||||
return uint32(n)
|
|
||||||
}
|
|
||||||
|
|
||||||
func setCmsgLen(h *unix.Cmsghdr, n int) {
|
|
||||||
h.Len = uint32(unix.CmsgLen(n))
|
|
||||||
}
|
|
||||||
@@ -1,17 +0,0 @@
|
|||||||
//go:build linux && (amd64 || arm64 || ppc64 || ppc64le || mips64 || mips64le || s390x || riscv64 || loong64) && !android && !e2e_testing
|
|
||||||
// +build linux
|
|
||||||
// +build amd64 arm64 ppc64 ppc64le mips64 mips64le s390x riscv64 loong64
|
|
||||||
// +build !android
|
|
||||||
// +build !e2e_testing
|
|
||||||
|
|
||||||
package udp
|
|
||||||
|
|
||||||
import "golang.org/x/sys/unix"
|
|
||||||
|
|
||||||
func controllen(n int) uint64 {
|
|
||||||
return uint64(n)
|
|
||||||
}
|
|
||||||
|
|
||||||
func setCmsgLen(h *unix.Cmsghdr, n int) {
|
|
||||||
h.Len = uint64(unix.CmsgLen(n))
|
|
||||||
}
|
|
||||||
@@ -165,7 +165,7 @@ func NewUDPStatsEmitter(udpConns []Conn) func() {
|
|||||||
return func() {}
|
return func() {}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (u *StdConn) ListenOut(r EncReader) {
|
func (u *StdConn) ListenOut(r EncReader) error {
|
||||||
buffer := make([]byte, MTU)
|
buffer := make([]byte, MTU)
|
||||||
|
|
||||||
for {
|
for {
|
||||||
@@ -173,8 +173,7 @@ func (u *StdConn) ListenOut(r EncReader) {
|
|||||||
n, rua, err := u.ReadFromUDPAddrPort(buffer)
|
n, rua, err := u.ReadFromUDPAddrPort(buffer)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if errors.Is(err, net.ErrClosed) {
|
if errors.Is(err, net.ErrClosed) {
|
||||||
u.l.WithError(err).Debug("udp socket is closed, exiting read loop")
|
return err
|
||||||
return
|
|
||||||
}
|
}
|
||||||
|
|
||||||
u.l.WithError(err).Error("unexpected udp socket receive error")
|
u.l.WithError(err).Error("unexpected udp socket receive error")
|
||||||
|
|||||||
@@ -71,15 +71,14 @@ type rawMessage struct {
|
|||||||
Len uint32
|
Len uint32
|
||||||
}
|
}
|
||||||
|
|
||||||
func (u *GenericConn) ListenOut(r EncReader) {
|
func (u *GenericConn) ListenOut(r EncReader) error {
|
||||||
buffer := make([]byte, MTU)
|
buffer := make([]byte, MTU)
|
||||||
|
|
||||||
for {
|
for {
|
||||||
// Just read one packet at a time
|
// Just read one packet at a time
|
||||||
n, rua, err := u.ReadFromUDPAddrPort(buffer)
|
n, rua, err := u.ReadFromUDPAddrPort(buffer)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
u.l.WithError(err).Debug("udp socket is closed, exiting read loop")
|
return err
|
||||||
return
|
|
||||||
}
|
}
|
||||||
|
|
||||||
r(netip.AddrPortFrom(rua.Addr().Unmap(), rua.Port()), buffer[:n])
|
r(netip.AddrPortFrom(rua.Addr().Unmap(), rua.Port()), buffer[:n])
|
||||||
|
|||||||
618
udp/udp_linux.go
618
udp/udp_linux.go
@@ -5,12 +5,9 @@ package udp
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"encoding/binary"
|
"encoding/binary"
|
||||||
"errors"
|
|
||||||
"fmt"
|
"fmt"
|
||||||
"net"
|
"net"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
"sync"
|
|
||||||
"sync/atomic"
|
|
||||||
"syscall"
|
"syscall"
|
||||||
"time"
|
"time"
|
||||||
"unsafe"
|
"unsafe"
|
||||||
@@ -21,46 +18,13 @@ import (
|
|||||||
"golang.org/x/sys/unix"
|
"golang.org/x/sys/unix"
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
var readTimeout = unix.NsecToTimeval(int64(time.Millisecond * 500))
|
||||||
defaultGSOMaxSegments = 64
|
|
||||||
defaultGSOMaxBytes = 64000
|
|
||||||
defaultGROReadBufferSize = 2 * defaultGSOMaxBytes
|
|
||||||
defaultGSOFlushTimeout = 100 * time.Microsecond
|
|
||||||
)
|
|
||||||
|
|
||||||
type StdConn struct {
|
type StdConn struct {
|
||||||
sysFd int
|
sysFd int
|
||||||
isV4 bool
|
isV4 bool
|
||||||
l *logrus.Logger
|
l *logrus.Logger
|
||||||
batch int
|
batch int
|
||||||
|
|
||||||
enableGRO bool
|
|
||||||
enableGSO bool
|
|
||||||
|
|
||||||
controlLen atomic.Int32
|
|
||||||
|
|
||||||
gsoMu sync.Mutex
|
|
||||||
gsoPendingBuf []byte
|
|
||||||
gsoPendingSegments int
|
|
||||||
gsoPendingAddr netip.AddrPort
|
|
||||||
gsoPendingSegSize int
|
|
||||||
gsoMaxSegments int
|
|
||||||
gsoMaxBytes int
|
|
||||||
gsoFlushTimeout time.Duration
|
|
||||||
gsoFlushTimer *time.Timer
|
|
||||||
gsoControlBuf []byte
|
|
||||||
|
|
||||||
gsoBatches metrics.Counter
|
|
||||||
gsoSegments metrics.Counter
|
|
||||||
groSegments metrics.Counter
|
|
||||||
}
|
|
||||||
|
|
||||||
func maybeIPV4(ip net.IP) (net.IP, bool) {
|
|
||||||
ip4 := ip.To4()
|
|
||||||
if ip4 != nil {
|
|
||||||
return ip4, true
|
|
||||||
}
|
|
||||||
return ip, false
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewListener(l *logrus.Logger, ip netip.Addr, port int, multi bool, batch int) (Conn, error) {
|
func NewListener(l *logrus.Logger, ip netip.Addr, port int, multi bool, batch int) (Conn, error) {
|
||||||
@@ -86,6 +50,11 @@ func NewListener(l *logrus.Logger, ip netip.Addr, port int, multi bool, batch in
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Set a read timeout
|
||||||
|
if err = unix.SetsockoptTimeval(fd, unix.SOL_SOCKET, unix.SO_RCVTIMEO, &readTimeout); err != nil {
|
||||||
|
return nil, fmt.Errorf("unable to set SO_RCVTIMEO: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
var sa unix.Sockaddr
|
var sa unix.Sockaddr
|
||||||
if ip.Is4() {
|
if ip.Is4() {
|
||||||
sa4 := &unix.SockaddrInet4{Port: port}
|
sa4 := &unix.SockaddrInet4{Port: port}
|
||||||
@@ -100,18 +69,7 @@ func NewListener(l *logrus.Logger, ip netip.Addr, port int, multi bool, batch in
|
|||||||
return nil, fmt.Errorf("unable to bind to socket: %s", err)
|
return nil, fmt.Errorf("unable to bind to socket: %s", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
return &StdConn{
|
return &StdConn{sysFd: fd, isV4: ip.Is4(), l: l, batch: batch}, err
|
||||||
sysFd: fd,
|
|
||||||
isV4: ip.Is4(),
|
|
||||||
l: l,
|
|
||||||
batch: batch,
|
|
||||||
gsoMaxSegments: defaultGSOMaxSegments,
|
|
||||||
gsoMaxBytes: defaultGSOMaxBytes,
|
|
||||||
gsoFlushTimeout: defaultGSOFlushTimeout,
|
|
||||||
gsoBatches: metrics.GetOrRegisterCounter("udp.gso.batches", nil),
|
|
||||||
gsoSegments: metrics.GetOrRegisterCounter("udp.gso.segments", nil),
|
|
||||||
groSegments: metrics.GetOrRegisterCounter("udp.gro.segments", nil),
|
|
||||||
}, err
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (u *StdConn) Rebind() error {
|
func (u *StdConn) Rebind() error {
|
||||||
@@ -160,69 +118,29 @@ func (u *StdConn) LocalAddr() (netip.AddrPort, error) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (u *StdConn) ListenOut(r EncReader) {
|
func (u *StdConn) ListenOut(r EncReader) error {
|
||||||
var ip netip.Addr
|
var ip netip.Addr
|
||||||
|
|
||||||
msgs, buffers, names, controls := u.PrepareRawMessages(u.batch)
|
msgs, buffers, names := u.PrepareRawMessages(u.batch)
|
||||||
read := u.ReadMulti
|
read := u.ReadMulti
|
||||||
if u.batch == 1 {
|
if u.batch == 1 {
|
||||||
read = u.ReadSingle
|
read = u.ReadSingle
|
||||||
}
|
}
|
||||||
|
|
||||||
for {
|
for {
|
||||||
//desiredControl := int(u.controlLen.Load())
|
|
||||||
//hasControl := len(controls) > 0
|
|
||||||
//if (desiredControl > 0) != hasControl || (desiredControl > 0 && hasControl && len(controls[0]) != desiredControl) {
|
|
||||||
// msgs, buffers, names, controls = u.PrepareRawMessages(u.batch)
|
|
||||||
// hasControl = len(controls) > 0
|
|
||||||
//}
|
|
||||||
//
|
|
||||||
for i := range msgs {
|
|
||||||
if len(controls) <= i || len(controls[i]) == 0 {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
msgs[i].Hdr.Controllen = controllen(len(controls[i]))
|
|
||||||
}
|
|
||||||
|
|
||||||
n, err := read(msgs)
|
n, err := read(msgs)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
u.l.WithError(err).Debug("udp socket is closed, exiting read loop")
|
return err
|
||||||
return
|
|
||||||
}
|
}
|
||||||
|
|
||||||
for i := 0; i < n; i++ {
|
for i := 0; i < n; i++ {
|
||||||
payloadLen := int(msgs[i].Len)
|
|
||||||
if payloadLen == 0 {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
// Its ok to skip the ok check here, the slicing is the only error that can occur and it will panic
|
// Its ok to skip the ok check here, the slicing is the only error that can occur and it will panic
|
||||||
if u.isV4 {
|
if u.isV4 {
|
||||||
ip, _ = netip.AddrFromSlice(names[i][4:8])
|
ip, _ = netip.AddrFromSlice(names[i][4:8])
|
||||||
} else {
|
} else {
|
||||||
ip, _ = netip.AddrFromSlice(names[i][8:24])
|
ip, _ = netip.AddrFromSlice(names[i][8:24])
|
||||||
}
|
}
|
||||||
addr := netip.AddrPortFrom(ip.Unmap(), binary.BigEndian.Uint16(names[i][2:4]))
|
r(netip.AddrPortFrom(ip.Unmap(), binary.BigEndian.Uint16(names[i][2:4])), buffers[i][:msgs[i].Len])
|
||||||
|
|
||||||
if len(controls) > i && len(controls[i]) > 0 {
|
|
||||||
if segSize, segCount := u.parseGROSegment(&msgs[i], controls[i]); segSize > 0 && segSize < payloadLen {
|
|
||||||
if u.emitSegments(r, addr, buffers[i][:payloadLen], segSize, segCount) {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
if segCount > 1 {
|
|
||||||
u.l.WithFields(logrus.Fields{
|
|
||||||
"tag": "gro-debug",
|
|
||||||
"stage": "listen_out",
|
|
||||||
"reason": "emit_failed",
|
|
||||||
"payload_len": payloadLen,
|
|
||||||
"seg_size": segSize,
|
|
||||||
"seg_count": segCount,
|
|
||||||
}).Debug("gro-debug fallback to single packet")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
r(addr, buffers[i][:payloadLen])
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -240,6 +158,9 @@ func (u *StdConn) ReadSingle(msgs []rawMessage) (int, error) {
|
|||||||
)
|
)
|
||||||
|
|
||||||
if err != 0 {
|
if err != 0 {
|
||||||
|
if err == unix.EAGAIN || err == unix.EINTR {
|
||||||
|
continue
|
||||||
|
}
|
||||||
return 0, &net.OpError{Op: "recvmsg", Err: err}
|
return 0, &net.OpError{Op: "recvmsg", Err: err}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -261,6 +182,9 @@ func (u *StdConn) ReadMulti(msgs []rawMessage) (int, error) {
|
|||||||
)
|
)
|
||||||
|
|
||||||
if err != 0 {
|
if err != 0 {
|
||||||
|
if err == unix.EAGAIN || err == unix.EINTR {
|
||||||
|
continue
|
||||||
|
}
|
||||||
return 0, &net.OpError{Op: "recvmmsg", Err: err}
|
return 0, &net.OpError{Op: "recvmmsg", Err: err}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -269,13 +193,6 @@ func (u *StdConn) ReadMulti(msgs []rawMessage) (int, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (u *StdConn) WriteTo(b []byte, ip netip.AddrPort) error {
|
func (u *StdConn) WriteTo(b []byte, ip netip.AddrPort) error {
|
||||||
if u.enableGSO {
|
|
||||||
if err := u.writeToGSO(b, ip); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
if u.isV4 {
|
if u.isV4 {
|
||||||
return u.writeTo4(b, ip)
|
return u.writeTo4(b, ip)
|
||||||
}
|
}
|
||||||
@@ -336,494 +253,6 @@ func (u *StdConn) writeTo4(b []byte, ip netip.AddrPort) error {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (u *StdConn) writeToGSO(b []byte, addr netip.AddrPort) error {
|
|
||||||
if len(b) == 0 {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
if !addr.IsValid() {
|
|
||||||
return u.directWrite(b, addr)
|
|
||||||
}
|
|
||||||
|
|
||||||
u.gsoMu.Lock()
|
|
||||||
defer u.gsoMu.Unlock()
|
|
||||||
|
|
||||||
if cap(u.gsoPendingBuf) < u.gsoMaxBytes { //I feel like this is bad?
|
|
||||||
u.gsoPendingBuf = make([]byte, 0, u.gsoMaxBytes)
|
|
||||||
}
|
|
||||||
|
|
||||||
if u.gsoPendingSegments > 0 && u.gsoPendingAddr != addr {
|
|
||||||
if err := u.flushPendingLocked(); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(b) > u.gsoMaxBytes || u.gsoMaxSegments <= 1 {
|
|
||||||
if err := u.flushPendingLocked(); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
return u.directWrite(b, addr)
|
|
||||||
}
|
|
||||||
|
|
||||||
if u.gsoPendingSegments == 0 {
|
|
||||||
u.gsoPendingAddr = addr
|
|
||||||
u.gsoPendingSegSize = len(b)
|
|
||||||
} else {
|
|
||||||
if len(b) > u.gsoPendingSegSize {
|
|
||||||
if err := u.flushPendingLocked(); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
u.gsoPendingAddr = addr
|
|
||||||
u.gsoPendingSegSize = len(b)
|
|
||||||
} else if len(b) < u.gsoPendingSegSize {
|
|
||||||
if err := u.flushPendingLocked(); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
u.gsoPendingAddr = addr
|
|
||||||
u.gsoPendingSegSize = len(b)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
inBuf := len(u.gsoPendingBuf) + len(b)
|
|
||||||
if len(u.gsoPendingBuf)+len(b) > u.gsoMaxBytes {
|
|
||||||
if err := u.flushPendingLocked(); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
u.gsoPendingAddr = addr
|
|
||||||
u.gsoPendingSegSize = len(b)
|
|
||||||
}
|
|
||||||
|
|
||||||
u.gsoPendingBuf = append(u.gsoPendingBuf, b...)
|
|
||||||
u.gsoPendingSegments++
|
|
||||||
|
|
||||||
if u.gsoPendingSegments >= u.gsoMaxSegments {
|
|
||||||
return u.flushPendingLocked()
|
|
||||||
}
|
|
||||||
|
|
||||||
if u.gsoFlushTimeout <= 0 {
|
|
||||||
return u.flushPendingLocked()
|
|
||||||
}
|
|
||||||
|
|
||||||
u.scheduleFlushLocked(inBuf)
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (u *StdConn) flushPendingLocked() error {
|
|
||||||
if u.gsoPendingSegments == 0 {
|
|
||||||
u.stopFlushTimerLocked()
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
buf := u.gsoPendingBuf[:len(u.gsoPendingBuf)]
|
|
||||||
addr := u.gsoPendingAddr
|
|
||||||
segSize := u.gsoPendingSegSize
|
|
||||||
segments := u.gsoPendingSegments
|
|
||||||
|
|
||||||
u.stopFlushTimerLocked()
|
|
||||||
|
|
||||||
var err error
|
|
||||||
if segments <= 1 || !u.enableGSO {
|
|
||||||
err = u.directWrite(buf, addr)
|
|
||||||
} else {
|
|
||||||
err = u.sendSegmentedLocked(buf, addr, segSize)
|
|
||||||
if err != nil && (errors.Is(err, unix.EOPNOTSUPP) || errors.Is(err, unix.ENOTSUP)) {
|
|
||||||
u.enableGSO = false
|
|
||||||
u.l.WithError(err).Warn("UDP GSO not supported, disabling")
|
|
||||||
err = u.sendSequentialLocked(buf, addr, segSize)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if err == nil && segments > 1 && u.enableGSO {
|
|
||||||
if u.gsoBatches != nil {
|
|
||||||
u.gsoBatches.Inc(1)
|
|
||||||
}
|
|
||||||
if u.gsoSegments != nil {
|
|
||||||
u.gsoSegments.Inc(int64(segments))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
u.gsoPendingBuf = u.gsoPendingBuf[:0]
|
|
||||||
u.gsoPendingSegments = 0
|
|
||||||
u.gsoPendingSegSize = 0
|
|
||||||
u.gsoPendingAddr = netip.AddrPort{}
|
|
||||||
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
func (u *StdConn) sendSegmentedLocked(buf []byte, addr netip.AddrPort, segSize int) error {
|
|
||||||
if len(buf) == 0 {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
if segSize <= 0 {
|
|
||||||
segSize = len(buf)
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(u.gsoControlBuf) < unix.CmsgSpace(2) {
|
|
||||||
u.gsoControlBuf = make([]byte, unix.CmsgSpace(2))
|
|
||||||
}
|
|
||||||
control := u.gsoControlBuf[:unix.CmsgSpace(2)]
|
|
||||||
for i := range control {
|
|
||||||
control[i] = 0
|
|
||||||
}
|
|
||||||
|
|
||||||
hdr := (*unix.Cmsghdr)(unsafe.Pointer(&control[0]))
|
|
||||||
setCmsgLen(hdr, 2)
|
|
||||||
hdr.Level = unix.SOL_UDP
|
|
||||||
hdr.Type = unix.UDP_SEGMENT
|
|
||||||
|
|
||||||
dataOff := unix.CmsgLen(0)
|
|
||||||
binary.NativeEndian.PutUint16(control[dataOff:dataOff+2], uint16(segSize))
|
|
||||||
|
|
||||||
var sa unix.Sockaddr
|
|
||||||
if u.isV4 {
|
|
||||||
sa4 := &unix.SockaddrInet4{Port: int(addr.Port())}
|
|
||||||
sa4.Addr = addr.Addr().As4()
|
|
||||||
sa = sa4
|
|
||||||
} else {
|
|
||||||
sa6 := &unix.SockaddrInet6{Port: int(addr.Port())}
|
|
||||||
sa6.Addr = addr.Addr().As16()
|
|
||||||
sa = sa6
|
|
||||||
}
|
|
||||||
|
|
||||||
for {
|
|
||||||
n, err := unix.SendmsgN(u.sysFd, buf, control[:unix.CmsgSpace(2)], sa, 0)
|
|
||||||
if err != nil {
|
|
||||||
if err == unix.EINTR {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
return &net.OpError{Op: "sendmsg", Err: err}
|
|
||||||
}
|
|
||||||
if n != len(buf) {
|
|
||||||
return &net.OpError{Op: "sendmsg", Err: unix.EIO}
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (u *StdConn) sendSequentialLocked(buf []byte, addr netip.AddrPort, segSize int) error {
|
|
||||||
if len(buf) == 0 {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
if segSize <= 0 {
|
|
||||||
segSize = len(buf)
|
|
||||||
}
|
|
||||||
|
|
||||||
for offset := 0; offset < len(buf); offset += segSize {
|
|
||||||
end := offset + segSize
|
|
||||||
if end > len(buf) {
|
|
||||||
end = len(buf)
|
|
||||||
}
|
|
||||||
var err error
|
|
||||||
if u.isV4 {
|
|
||||||
err = u.writeTo4(buf[offset:end], addr)
|
|
||||||
} else {
|
|
||||||
err = u.writeTo6(buf[offset:end], addr)
|
|
||||||
}
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
if end == len(buf) {
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (u *StdConn) scheduleFlushLocked(inBuf int) {
|
|
||||||
if u.gsoFlushTimeout <= 0 {
|
|
||||||
_ = u.flushPendingLocked()
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
t := u.gsoFlushTimeout
|
|
||||||
if inBuf > u.gsoMaxBytes/2 {
|
|
||||||
t = t / 2
|
|
||||||
}
|
|
||||||
|
|
||||||
if u.gsoFlushTimer == nil {
|
|
||||||
u.gsoFlushTimer = time.AfterFunc(t, u.flushTimerHandler)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if !u.gsoFlushTimer.Stop() {
|
|
||||||
// timer already fired or running; allow handler to exit if no data
|
|
||||||
}
|
|
||||||
u.gsoFlushTimer.Reset(t)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (u *StdConn) stopFlushTimerLocked() {
|
|
||||||
if u.gsoFlushTimer != nil {
|
|
||||||
u.gsoFlushTimer.Stop()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (u *StdConn) flushTimerHandler() {
|
|
||||||
//u.l.Warn("timer hit")
|
|
||||||
u.gsoMu.Lock()
|
|
||||||
defer u.gsoMu.Unlock()
|
|
||||||
|
|
||||||
if u.gsoPendingSegments == 0 {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := u.flushPendingLocked(); err != nil {
|
|
||||||
u.l.WithError(err).Warn("Failed to flush GSO batch")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (u *StdConn) directWrite(b []byte, addr netip.AddrPort) error {
|
|
||||||
if u.isV4 {
|
|
||||||
return u.writeTo4(b, addr)
|
|
||||||
}
|
|
||||||
return u.writeTo6(b, addr)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (u *StdConn) emitSegments(r EncReader, addr netip.AddrPort, payload []byte, segSize, segCount int) bool {
|
|
||||||
if segSize <= 0 || segSize >= len(payload) {
|
|
||||||
if u.l.Level >= logrus.DebugLevel {
|
|
||||||
u.l.WithFields(logrus.Fields{
|
|
||||||
"tag": "gro-debug",
|
|
||||||
"stage": "emit",
|
|
||||||
"reason": "invalid_seg_size",
|
|
||||||
"payload_len": len(payload),
|
|
||||||
"seg_size": segSize,
|
|
||||||
"seg_count": segCount,
|
|
||||||
}).Debug("gro-debug skip emit")
|
|
||||||
}
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
totalLen := len(payload)
|
|
||||||
if segCount <= 0 {
|
|
||||||
segCount = (totalLen + segSize - 1) / segSize
|
|
||||||
}
|
|
||||||
if segCount <= 1 {
|
|
||||||
if u.l.Level >= logrus.DebugLevel {
|
|
||||||
u.l.WithFields(logrus.Fields{
|
|
||||||
"tag": "gro-debug",
|
|
||||||
"stage": "emit",
|
|
||||||
"reason": "single_segment",
|
|
||||||
"payload_len": totalLen,
|
|
||||||
"seg_size": segSize,
|
|
||||||
"seg_count": segCount,
|
|
||||||
}).Debug("gro-debug skip emit")
|
|
||||||
}
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
//segments := make([][]byte, 0, segCount)
|
|
||||||
start := 0
|
|
||||||
//var firstHeader header.H
|
|
||||||
//firstParsed := false
|
|
||||||
//var firstCounter uint64
|
|
||||||
//var firstRemote uint32
|
|
||||||
numSegments := 0
|
|
||||||
//for start < totalLen && len(segments) < segCount {
|
|
||||||
for start < totalLen && numSegments < segCount {
|
|
||||||
end := start + segSize
|
|
||||||
if end > totalLen {
|
|
||||||
end = totalLen
|
|
||||||
}
|
|
||||||
|
|
||||||
//segment := append([]byte(nil), payload[start:end]...)
|
|
||||||
//q := numSegments % 4 //TODO
|
|
||||||
r(addr, payload[start:end])
|
|
||||||
numSegments++
|
|
||||||
//segments = append(segments, segment)
|
|
||||||
start = end
|
|
||||||
|
|
||||||
//if !firstParsed {
|
|
||||||
// if err := firstHeader.Parse(segment); err == nil {
|
|
||||||
// firstParsed = true
|
|
||||||
// firstCounter = firstHeader.MessageCounter
|
|
||||||
// firstRemote = firstHeader.RemoteIndex
|
|
||||||
// } else if u.l.IsLevelEnabled(logrus.DebugLevel) {
|
|
||||||
// u.l.WithFields(logrus.Fields{
|
|
||||||
// "tag": "gro-debug",
|
|
||||||
// "stage": "emit",
|
|
||||||
// "event": "parse_fail",
|
|
||||||
// "seg_index": len(segments) - 1,
|
|
||||||
// "seg_size": segSize,
|
|
||||||
// "seg_count": segCount,
|
|
||||||
// "payload_len": totalLen,
|
|
||||||
// "err": err,
|
|
||||||
// }).Debug("gro-debug segment parse failed")
|
|
||||||
// }
|
|
||||||
//}
|
|
||||||
}
|
|
||||||
|
|
||||||
//for idx, segment := range segments {
|
|
||||||
// r(addr, segment)
|
|
||||||
//if idx == len(segments)-1 && len(segment) < segSize && u.l.IsLevelEnabled(logrus.DebugLevel) {
|
|
||||||
// var tail header.H
|
|
||||||
// if err := tail.Parse(segment); err == nil {
|
|
||||||
// u.l.WithFields(logrus.Fields{
|
|
||||||
// "tag": "gro-debug",
|
|
||||||
// "stage": "emit",
|
|
||||||
// "event": "tail_segment",
|
|
||||||
// "segment_len": len(segment),
|
|
||||||
// "remote_index": tail.RemoteIndex,
|
|
||||||
// "message_counter": tail.MessageCounter,
|
|
||||||
// }).Debug("gro-debug tail segment metadata")
|
|
||||||
// } else {
|
|
||||||
// u.l.WithError(err).Warn("Failed to parse tail segment")
|
|
||||||
// }
|
|
||||||
//}
|
|
||||||
//}
|
|
||||||
|
|
||||||
if u.groSegments != nil {
|
|
||||||
//u.groSegments.Inc(int64(len(segments)))
|
|
||||||
u.groSegments.Inc(int64(numSegments))
|
|
||||||
}
|
|
||||||
|
|
||||||
//if len(segments) > 0 {
|
|
||||||
// lastLen := len(segments[len(segments)-1])
|
|
||||||
// if u.l.IsLevelEnabled(logrus.DebugLevel) {
|
|
||||||
// u.l.WithFields(logrus.Fields{
|
|
||||||
// "tag": "gro-debug",
|
|
||||||
// "stage": "emit",
|
|
||||||
// "event": "success",
|
|
||||||
// "payload_len": totalLen,
|
|
||||||
// "seg_size": segSize,
|
|
||||||
// "seg_count": segCount,
|
|
||||||
// "actual_segs": len(segments),
|
|
||||||
// "last_seg_len": lastLen,
|
|
||||||
// "addr": addr.String(),
|
|
||||||
// "first_remote": firstRemote,
|
|
||||||
// "first_counter": firstCounter,
|
|
||||||
// }).Debug("gro-debug emit")
|
|
||||||
// }
|
|
||||||
//}
|
|
||||||
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
|
|
||||||
func (u *StdConn) parseGROSegment(msg *rawMessage, control []byte) (int, int) {
|
|
||||||
ctrlLen := int(msg.Hdr.Controllen)
|
|
||||||
if ctrlLen <= 0 {
|
|
||||||
return 0, 0
|
|
||||||
}
|
|
||||||
if ctrlLen > len(control) {
|
|
||||||
ctrlLen = len(control)
|
|
||||||
}
|
|
||||||
|
|
||||||
cmsgs, err := unix.ParseSocketControlMessage(control[:ctrlLen])
|
|
||||||
if err != nil {
|
|
||||||
u.l.WithError(err).Debug("failed to parse UDP GRO control message")
|
|
||||||
return 0, 0
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, c := range cmsgs {
|
|
||||||
if c.Header.Level == unix.SOL_UDP && c.Header.Type == unix.UDP_GRO && len(c.Data) >= 2 {
|
|
||||||
segSize := int(binary.NativeEndian.Uint16(c.Data[:2]))
|
|
||||||
segCount := 0
|
|
||||||
if len(c.Data) >= 4 {
|
|
||||||
segCount = int(binary.NativeEndian.Uint16(c.Data[2:4]))
|
|
||||||
}
|
|
||||||
if u.l.Level >= logrus.DebugLevel {
|
|
||||||
u.l.WithFields(logrus.Fields{
|
|
||||||
"tag": "gro-debug",
|
|
||||||
"stage": "parse",
|
|
||||||
"seg_size": segSize,
|
|
||||||
"seg_count": segCount,
|
|
||||||
}).Debug("gro-debug control parsed")
|
|
||||||
}
|
|
||||||
return segSize, segCount
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return 0, 0
|
|
||||||
}
|
|
||||||
|
|
||||||
func (u *StdConn) configureGRO(enable bool) {
|
|
||||||
if enable == u.enableGRO {
|
|
||||||
if enable {
|
|
||||||
u.controlLen.Store(int32(unix.CmsgSpace(2)))
|
|
||||||
} else {
|
|
||||||
u.controlLen.Store(0)
|
|
||||||
}
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
if enable {
|
|
||||||
if err := unix.SetsockoptInt(u.sysFd, unix.SOL_UDP, unix.UDP_GRO, 1); err != nil {
|
|
||||||
u.l.WithError(err).Warn("Failed to enable UDP GRO")
|
|
||||||
u.enableGRO = false
|
|
||||||
u.controlLen.Store(0)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
u.enableGRO = true
|
|
||||||
u.controlLen.Store(int32(unix.CmsgSpace(2)))
|
|
||||||
u.l.Info("UDP GRO enabled")
|
|
||||||
} else {
|
|
||||||
if u.enableGRO {
|
|
||||||
if err := unix.SetsockoptInt(u.sysFd, unix.SOL_UDP, unix.UDP_GRO, 0); err != nil {
|
|
||||||
u.l.WithError(err).Warn("Failed to disable UDP GRO")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
u.enableGRO = false
|
|
||||||
u.controlLen.Store(0)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (u *StdConn) configureGSO(enable bool, c *config.C) {
|
|
||||||
u.gsoMu.Lock()
|
|
||||||
defer u.gsoMu.Unlock()
|
|
||||||
|
|
||||||
if !enable {
|
|
||||||
if u.enableGSO {
|
|
||||||
if err := u.flushPendingLocked(); err != nil {
|
|
||||||
u.l.WithError(err).Warn("Failed to flush GSO buffers while disabling")
|
|
||||||
}
|
|
||||||
u.enableGSO = false
|
|
||||||
if u.gsoFlushTimer != nil {
|
|
||||||
u.gsoFlushTimer.Stop()
|
|
||||||
}
|
|
||||||
u.l.Info("UDP GSO disabled")
|
|
||||||
}
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
maxSegments := c.GetInt("listen.gso_max_segments", defaultGSOMaxSegments)
|
|
||||||
if maxSegments < 2 {
|
|
||||||
maxSegments = 2
|
|
||||||
}
|
|
||||||
|
|
||||||
maxBytes := c.GetInt("listen.gso_max_bytes", 0)
|
|
||||||
if maxBytes <= 0 {
|
|
||||||
maxBytes = defaultGSOMaxBytes
|
|
||||||
}
|
|
||||||
if maxBytes < MTU {
|
|
||||||
maxBytes = MTU
|
|
||||||
}
|
|
||||||
|
|
||||||
flushTimeout := c.GetDuration("listen.gso_flush_timeout", defaultGSOFlushTimeout)
|
|
||||||
if flushTimeout < 0 {
|
|
||||||
flushTimeout = 0
|
|
||||||
}
|
|
||||||
|
|
||||||
u.enableGSO = true
|
|
||||||
u.gsoMaxSegments = maxSegments
|
|
||||||
u.gsoMaxBytes = maxBytes
|
|
||||||
u.gsoFlushTimeout = flushTimeout
|
|
||||||
|
|
||||||
if cap(u.gsoPendingBuf) < u.gsoMaxBytes {
|
|
||||||
u.gsoPendingBuf = make([]byte, 0, u.gsoMaxBytes)
|
|
||||||
} else {
|
|
||||||
u.gsoPendingBuf = u.gsoPendingBuf[:0]
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(u.gsoControlBuf) < unix.CmsgSpace(2) {
|
|
||||||
u.gsoControlBuf = make([]byte, unix.CmsgSpace(2))
|
|
||||||
}
|
|
||||||
|
|
||||||
u.l.WithFields(logrus.Fields{
|
|
||||||
"segments": u.gsoMaxSegments,
|
|
||||||
"bytes": u.gsoMaxBytes,
|
|
||||||
"flush_timeout": u.gsoFlushTimeout,
|
|
||||||
}).Info("UDP GSO configured")
|
|
||||||
}
|
|
||||||
|
|
||||||
func (u *StdConn) ReloadConfig(c *config.C) {
|
func (u *StdConn) ReloadConfig(c *config.C) {
|
||||||
b := c.GetInt("listen.read_buffer", 0)
|
b := c.GetInt("listen.read_buffer", 0)
|
||||||
if b > 0 {
|
if b > 0 {
|
||||||
@@ -870,9 +299,6 @@ func (u *StdConn) ReloadConfig(c *config.C) {
|
|||||||
u.l.WithError(err).Error("Failed to set listen.so_mark")
|
u.l.WithError(err).Error("Failed to set listen.so_mark")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
u.configureGRO(c.GetBool("listen.enable_gro", false))
|
|
||||||
u.configureGSO(c.GetBool("listen.enable_gso", false), c)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (u *StdConn) getMemInfo(meminfo *[unix.SK_MEMINFO_VARS]uint32) error {
|
func (u *StdConn) getMemInfo(meminfo *[unix.SK_MEMINFO_VARS]uint32) error {
|
||||||
@@ -885,15 +311,7 @@ func (u *StdConn) getMemInfo(meminfo *[unix.SK_MEMINFO_VARS]uint32) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (u *StdConn) Close() error {
|
func (u *StdConn) Close() error {
|
||||||
u.gsoMu.Lock()
|
return syscall.Close(u.sysFd)
|
||||||
flushErr := u.flushPendingLocked()
|
|
||||||
u.gsoMu.Unlock()
|
|
||||||
|
|
||||||
closeErr := syscall.Close(u.sysFd)
|
|
||||||
if flushErr != nil {
|
|
||||||
return flushErr
|
|
||||||
}
|
|
||||||
return closeErr
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewUDPStatsEmitter(udpConns []Conn) func() {
|
func NewUDPStatsEmitter(udpConns []Conn) func() {
|
||||||
|
|||||||
@@ -30,24 +30,13 @@ type rawMessage struct {
|
|||||||
Len uint32
|
Len uint32
|
||||||
}
|
}
|
||||||
|
|
||||||
func (u *StdConn) PrepareRawMessages(n int) ([]rawMessage, [][]byte, [][]byte, [][]byte) {
|
func (u *StdConn) PrepareRawMessages(n int) ([]rawMessage, [][]byte, [][]byte) {
|
||||||
controlLen := int(u.controlLen.Load())
|
|
||||||
|
|
||||||
msgs := make([]rawMessage, n)
|
msgs := make([]rawMessage, n)
|
||||||
buffers := make([][]byte, n)
|
buffers := make([][]byte, n)
|
||||||
names := make([][]byte, n)
|
names := make([][]byte, n)
|
||||||
|
|
||||||
var controls [][]byte
|
|
||||||
if controlLen > 0 {
|
|
||||||
controls = make([][]byte, n)
|
|
||||||
}
|
|
||||||
|
|
||||||
for i := range msgs {
|
for i := range msgs {
|
||||||
size := MTU
|
buffers[i] = make([]byte, MTU)
|
||||||
if defaultGROReadBufferSize > size {
|
|
||||||
size = defaultGROReadBufferSize
|
|
||||||
}
|
|
||||||
buffers[i] = make([]byte, size)
|
|
||||||
names[i] = make([]byte, unix.SizeofSockaddrInet6)
|
names[i] = make([]byte, unix.SizeofSockaddrInet6)
|
||||||
|
|
||||||
vs := []iovec{
|
vs := []iovec{
|
||||||
@@ -59,16 +48,7 @@ func (u *StdConn) PrepareRawMessages(n int) ([]rawMessage, [][]byte, [][]byte, [
|
|||||||
|
|
||||||
msgs[i].Hdr.Name = &names[i][0]
|
msgs[i].Hdr.Name = &names[i][0]
|
||||||
msgs[i].Hdr.Namelen = uint32(len(names[i]))
|
msgs[i].Hdr.Namelen = uint32(len(names[i]))
|
||||||
|
|
||||||
if controlLen > 0 {
|
|
||||||
controls[i] = make([]byte, controlLen)
|
|
||||||
msgs[i].Hdr.Control = &controls[i][0]
|
|
||||||
msgs[i].Hdr.Controllen = controllen(len(controls[i]))
|
|
||||||
} else {
|
|
||||||
msgs[i].Hdr.Control = nil
|
|
||||||
msgs[i].Hdr.Controllen = controllen(0)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return msgs, buffers, names, controls
|
return msgs, buffers, names
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -33,43 +33,25 @@ type rawMessage struct {
|
|||||||
Pad0 [4]byte
|
Pad0 [4]byte
|
||||||
}
|
}
|
||||||
|
|
||||||
func (u *StdConn) PrepareRawMessages(n int) ([]rawMessage, [][]byte, [][]byte, [][]byte) {
|
func (u *StdConn) PrepareRawMessages(n int) ([]rawMessage, [][]byte, [][]byte) {
|
||||||
controlLen := int(u.controlLen.Load())
|
|
||||||
|
|
||||||
msgs := make([]rawMessage, n)
|
msgs := make([]rawMessage, n)
|
||||||
buffers := make([][]byte, n)
|
buffers := make([][]byte, n)
|
||||||
names := make([][]byte, n)
|
names := make([][]byte, n)
|
||||||
|
|
||||||
var controls [][]byte
|
|
||||||
if controlLen > 0 {
|
|
||||||
controls = make([][]byte, n)
|
|
||||||
}
|
|
||||||
|
|
||||||
for i := range msgs {
|
for i := range msgs {
|
||||||
size := MTU
|
buffers[i] = make([]byte, MTU)
|
||||||
if defaultGROReadBufferSize > size {
|
|
||||||
size = defaultGROReadBufferSize
|
|
||||||
}
|
|
||||||
buffers[i] = make([]byte, size)
|
|
||||||
names[i] = make([]byte, unix.SizeofSockaddrInet6)
|
names[i] = make([]byte, unix.SizeofSockaddrInet6)
|
||||||
|
|
||||||
vs := []iovec{{Base: &buffers[i][0], Len: uint64(len(buffers[i]))}}
|
vs := []iovec{
|
||||||
|
{Base: &buffers[i][0], Len: uint64(len(buffers[i]))},
|
||||||
|
}
|
||||||
|
|
||||||
msgs[i].Hdr.Iov = &vs[0]
|
msgs[i].Hdr.Iov = &vs[0]
|
||||||
msgs[i].Hdr.Iovlen = uint64(len(vs))
|
msgs[i].Hdr.Iovlen = uint64(len(vs))
|
||||||
|
|
||||||
msgs[i].Hdr.Name = &names[i][0]
|
msgs[i].Hdr.Name = &names[i][0]
|
||||||
msgs[i].Hdr.Namelen = uint32(len(names[i]))
|
msgs[i].Hdr.Namelen = uint32(len(names[i]))
|
||||||
|
|
||||||
if controlLen > 0 {
|
|
||||||
controls[i] = make([]byte, controlLen)
|
|
||||||
msgs[i].Hdr.Control = &controls[i][0]
|
|
||||||
msgs[i].Hdr.Controllen = controllen(len(controls[i]))
|
|
||||||
} else {
|
|
||||||
msgs[i].Hdr.Control = nil
|
|
||||||
msgs[i].Hdr.Controllen = controllen(0)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return msgs, buffers, names, controls
|
return msgs, buffers, names
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -134,7 +134,7 @@ func (u *RIOConn) bind(sa windows.Sockaddr) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (u *RIOConn) ListenOut(r EncReader) {
|
func (u *RIOConn) ListenOut(r EncReader) error {
|
||||||
buffer := make([]byte, MTU)
|
buffer := make([]byte, MTU)
|
||||||
|
|
||||||
for {
|
for {
|
||||||
@@ -142,8 +142,7 @@ func (u *RIOConn) ListenOut(r EncReader) {
|
|||||||
n, rua, err := u.receive(buffer)
|
n, rua, err := u.receive(buffer)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if errors.Is(err, net.ErrClosed) {
|
if errors.Is(err, net.ErrClosed) {
|
||||||
u.l.WithError(err).Debug("udp socket is closed, exiting read loop")
|
return err
|
||||||
return
|
|
||||||
}
|
}
|
||||||
u.l.WithError(err).Error("unexpected udp socket receive error")
|
u.l.WithError(err).Error("unexpected udp socket receive error")
|
||||||
continue
|
continue
|
||||||
|
|||||||
@@ -6,6 +6,7 @@ package udp
|
|||||||
import (
|
import (
|
||||||
"io"
|
"io"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
|
"os"
|
||||||
"sync/atomic"
|
"sync/atomic"
|
||||||
|
|
||||||
"github.com/sirupsen/logrus"
|
"github.com/sirupsen/logrus"
|
||||||
@@ -106,11 +107,11 @@ func (u *TesterConn) WriteTo(b []byte, addr netip.AddrPort) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (u *TesterConn) ListenOut(r EncReader) {
|
func (u *TesterConn) ListenOut(r EncReader) error {
|
||||||
for {
|
for {
|
||||||
p, ok := <-u.RxPackets
|
p, ok := <-u.RxPackets
|
||||||
if !ok {
|
if !ok {
|
||||||
return
|
return os.ErrClosed
|
||||||
}
|
}
|
||||||
r(p.From, p.Data)
|
r(p.From, p.Data)
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user