mirror of
https://github.com/slackhq/nebula.git
synced 2026-05-16 04:47:38 +02:00
Remove more os.Exit calls and give a more reliable wait for stop function (attempt 3) (#1661)
This commit is contained in:
@@ -78,8 +78,20 @@ func main() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if !*configTest {
|
if !*configTest {
|
||||||
ctrl.Start()
|
wait, err := ctrl.Start()
|
||||||
ctrl.ShutdownBlock()
|
if err != nil {
|
||||||
|
util.LogWithContextIfNeeded("Error while running", err, l)
|
||||||
|
os.Exit(1)
|
||||||
|
}
|
||||||
|
|
||||||
|
go ctrl.ShutdownBlock()
|
||||||
|
|
||||||
|
if err := wait(); err != nil {
|
||||||
|
l.WithError(err).Error("Nebula stopped due to fatal error")
|
||||||
|
os.Exit(2)
|
||||||
|
}
|
||||||
|
|
||||||
|
l.Info("Goodbye")
|
||||||
}
|
}
|
||||||
|
|
||||||
os.Exit(0)
|
os.Exit(0)
|
||||||
|
|||||||
@@ -72,9 +72,21 @@ func main() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if !*configTest {
|
if !*configTest {
|
||||||
ctrl.Start()
|
wait, err := ctrl.Start()
|
||||||
|
if err != nil {
|
||||||
|
util.LogWithContextIfNeeded("Error while running", err, l)
|
||||||
|
os.Exit(1)
|
||||||
|
}
|
||||||
|
|
||||||
|
go ctrl.ShutdownBlock()
|
||||||
notifyReady(l)
|
notifyReady(l)
|
||||||
ctrl.ShutdownBlock()
|
|
||||||
|
if err := wait(); err != nil {
|
||||||
|
l.WithError(err).Error("Nebula stopped due to fatal error")
|
||||||
|
os.Exit(2)
|
||||||
|
}
|
||||||
|
|
||||||
|
l.Info("Goodbye")
|
||||||
}
|
}
|
||||||
|
|
||||||
os.Exit(0)
|
os.Exit(0)
|
||||||
|
|||||||
78
control.go
78
control.go
@@ -2,9 +2,11 @@ package nebula
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"errors"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
"os"
|
"os"
|
||||||
"os/signal"
|
"os/signal"
|
||||||
|
"sync"
|
||||||
"syscall"
|
"syscall"
|
||||||
|
|
||||||
"github.com/sirupsen/logrus"
|
"github.com/sirupsen/logrus"
|
||||||
@@ -13,6 +15,20 @@ import (
|
|||||||
"github.com/slackhq/nebula/overlay"
|
"github.com/slackhq/nebula/overlay"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
type RunState int
|
||||||
|
|
||||||
|
const (
|
||||||
|
StateUnknown RunState = iota
|
||||||
|
StateReady
|
||||||
|
StateStarted
|
||||||
|
StateStopping
|
||||||
|
StateStopped
|
||||||
|
)
|
||||||
|
|
||||||
|
var ErrAlreadyStarted = errors.New("nebula is already started")
|
||||||
|
var ErrAlreadyStopped = errors.New("nebula cannot be restarted")
|
||||||
|
var ErrUnknownState = errors.New("nebula state is invalid")
|
||||||
|
|
||||||
// Every interaction here needs to take extra care to copy memory and not return or use arguments "as is" when touching
|
// Every interaction here needs to take extra care to copy memory and not return or use arguments "as is" when touching
|
||||||
// core. This means copying IP objects, slices, de-referencing pointers and taking the actual value, etc
|
// core. This means copying IP objects, slices, de-referencing pointers and taking the actual value, etc
|
||||||
|
|
||||||
@@ -26,6 +42,9 @@ type controlHostLister interface {
|
|||||||
}
|
}
|
||||||
|
|
||||||
type Control struct {
|
type Control struct {
|
||||||
|
stateLock sync.Mutex
|
||||||
|
state RunState
|
||||||
|
|
||||||
f *Interface
|
f *Interface
|
||||||
l *logrus.Logger
|
l *logrus.Logger
|
||||||
ctx context.Context
|
ctx context.Context
|
||||||
@@ -49,10 +68,31 @@ type ControlHostInfo struct {
|
|||||||
CurrentRelaysThroughMe []netip.Addr `json:"currentRelaysThroughMe"`
|
CurrentRelaysThroughMe []netip.Addr `json:"currentRelaysThroughMe"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// Start actually runs nebula, this is a nonblocking call. To block use Control.ShutdownBlock()
|
// Start actually runs nebula, this is a nonblocking call.
|
||||||
func (c *Control) Start() {
|
// The returned function blocks until nebula has fully stopped and returns the
|
||||||
|
// first fatal reader error (if any). A nil error means nebula shut down
|
||||||
|
// gracefully; a non-nil error means a reader hit an unexpected failure that
|
||||||
|
// triggered the shutdown.
|
||||||
|
func (c *Control) Start() (func() error, error) {
|
||||||
|
c.stateLock.Lock()
|
||||||
|
defer c.stateLock.Unlock()
|
||||||
|
switch c.state {
|
||||||
|
case StateReady:
|
||||||
|
//yay!
|
||||||
|
case StateStopped, StateStopping:
|
||||||
|
return nil, ErrAlreadyStopped
|
||||||
|
case StateStarted:
|
||||||
|
return nil, ErrAlreadyStarted
|
||||||
|
default:
|
||||||
|
return nil, ErrUnknownState
|
||||||
|
}
|
||||||
|
|
||||||
// Activate the interface
|
// Activate the interface
|
||||||
c.f.activate()
|
err := c.f.activate()
|
||||||
|
if err != nil {
|
||||||
|
c.state = StateStopped
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
// Call all the delayed funcs that waited patiently for the interface to be created.
|
// Call all the delayed funcs that waited patiently for the interface to be created.
|
||||||
if c.sshStart != nil {
|
if c.sshStart != nil {
|
||||||
@@ -71,16 +111,40 @@ func (c *Control) Start() {
|
|||||||
c.lighthouseStart()
|
c.lighthouseStart()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
c.f.triggerShutdown = c.Stop
|
||||||
|
|
||||||
// Start reading packets.
|
// Start reading packets.
|
||||||
c.f.run()
|
out, err := c.f.run()
|
||||||
|
if err != nil {
|
||||||
|
c.state = StateStopped
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
c.state = StateStarted
|
||||||
|
return out, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Control) State() RunState {
|
||||||
|
c.stateLock.Lock()
|
||||||
|
defer c.stateLock.Unlock()
|
||||||
|
return c.state
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Control) Context() context.Context {
|
func (c *Control) Context() context.Context {
|
||||||
return c.ctx
|
return c.ctx
|
||||||
}
|
}
|
||||||
|
|
||||||
// Stop signals nebula to shutdown and close all tunnels, returns after the shutdown is complete
|
// Stop is a non-blocking call that signals nebula to close all tunnels and shut down
|
||||||
func (c *Control) Stop() {
|
func (c *Control) Stop() {
|
||||||
|
c.stateLock.Lock()
|
||||||
|
if c.state != StateStarted {
|
||||||
|
c.stateLock.Unlock()
|
||||||
|
// We are stopping or stopped already
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
c.state = StateStopping
|
||||||
|
c.stateLock.Unlock()
|
||||||
|
|
||||||
// Stop the handshakeManager (and other services), to prevent new tunnels from
|
// Stop the handshakeManager (and other services), to prevent new tunnels from
|
||||||
// being created while we're shutting them all down.
|
// being created while we're shutting them all down.
|
||||||
c.cancel()
|
c.cancel()
|
||||||
@@ -89,7 +153,9 @@ func (c *Control) Stop() {
|
|||||||
if err := c.f.Close(); err != nil {
|
if err := c.f.Close(); err != nil {
|
||||||
c.l.WithError(err).Error("Close interface failed")
|
c.l.WithError(err).Error("Close interface failed")
|
||||||
}
|
}
|
||||||
c.l.Info("Goodbye")
|
c.stateLock.Lock()
|
||||||
|
c.state = StateStopped
|
||||||
|
c.stateLock.Unlock()
|
||||||
}
|
}
|
||||||
|
|
||||||
// ShutdownBlock will listen for and block on term and interrupt signals, calling Control.Stop() once signalled
|
// ShutdownBlock will listen for and block on term and interrupt signals, calling Control.Stop() once signalled
|
||||||
|
|||||||
@@ -79,6 +79,7 @@ func TestControl_GetHostInfoByVpnIp(t *testing.T) {
|
|||||||
}, &Interface{})
|
}, &Interface{})
|
||||||
|
|
||||||
c := Control{
|
c := Control{
|
||||||
|
state: StateReady,
|
||||||
f: &Interface{
|
f: &Interface{
|
||||||
hostMap: hm,
|
hostMap: hm,
|
||||||
},
|
},
|
||||||
|
|||||||
95
interface.go
95
interface.go
@@ -6,7 +6,7 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
"os"
|
"sync"
|
||||||
"sync/atomic"
|
"sync/atomic"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
@@ -87,6 +87,13 @@ type Interface struct {
|
|||||||
|
|
||||||
writers []udp.Conn
|
writers []udp.Conn
|
||||||
readers []io.ReadWriteCloser
|
readers []io.ReadWriteCloser
|
||||||
|
wg sync.WaitGroup
|
||||||
|
|
||||||
|
// fatalErr holds the first unexpected reader error that caused shutdown.
|
||||||
|
// nil means "no fatal error" (yet)
|
||||||
|
fatalErr atomic.Pointer[error]
|
||||||
|
// triggerShutdown is a function that will be run exactly once, when onFatal swaps something non-nil into fatalErr
|
||||||
|
triggerShutdown func()
|
||||||
|
|
||||||
metricHandshakes metrics.Histogram
|
metricHandshakes metrics.Histogram
|
||||||
messageMetrics *MessageMetrics
|
messageMetrics *MessageMetrics
|
||||||
@@ -209,7 +216,7 @@ func NewInterface(ctx context.Context, c *InterfaceConfig) (*Interface, error) {
|
|||||||
// activate creates the interface on the host. After the interface is created, any
|
// activate creates the interface on the host. After the interface is created, any
|
||||||
// other services that want to bind listeners to its IP may do so successfully. However,
|
// other services that want to bind listeners to its IP may do so successfully. However,
|
||||||
// the interface isn't going to process anything until run() is called.
|
// the interface isn't going to process anything until run() is called.
|
||||||
func (f *Interface) activate() {
|
func (f *Interface) activate() error {
|
||||||
// actually turn on tun dev
|
// actually turn on tun dev
|
||||||
|
|
||||||
addr, err := f.outside.LocalAddr()
|
addr, err := f.outside.LocalAddr()
|
||||||
@@ -237,27 +244,54 @@ func (f *Interface) activate() {
|
|||||||
if i > 0 {
|
if i > 0 {
|
||||||
reader, err = f.inside.NewMultiQueueReader()
|
reader, err = f.inside.NewMultiQueueReader()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
f.l.Fatal(err)
|
return err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
f.readers[i] = reader
|
f.readers[i] = reader
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := f.inside.Activate(); err != nil {
|
f.wg.Add(1) // for us to wait on Close() to return
|
||||||
|
if err = f.inside.Activate(); err != nil {
|
||||||
|
f.wg.Done()
|
||||||
f.inside.Close()
|
f.inside.Close()
|
||||||
f.l.Fatal(err)
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (f *Interface) run() {
|
func (f *Interface) run() (func() error, error) {
|
||||||
// Launch n queues to read packets from udp
|
// Launch n queues to read packets from udp
|
||||||
for i := 0; i < f.routines; i++ {
|
for i := 0; i < f.routines; i++ {
|
||||||
go f.listenOut(i)
|
f.wg.Go(func() {
|
||||||
|
f.listenOut(i)
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
// Launch n queues to read packets from tun dev
|
// Launch n queues to read packets from tun dev
|
||||||
for i := 0; i < f.routines; i++ {
|
for i := 0; i < f.routines; i++ {
|
||||||
go f.listenIn(f.readers[i], i)
|
f.wg.Go(func() {
|
||||||
|
f.listenIn(f.readers[i], i)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
return func() error {
|
||||||
|
f.wg.Wait()
|
||||||
|
if e := f.fatalErr.Load(); e != nil {
|
||||||
|
return *e
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// onFatal stores the first fatal reader error, and calls triggerShutdown if it was the first one
|
||||||
|
func (f *Interface) onFatal(err error) {
|
||||||
|
swapped := f.fatalErr.CompareAndSwap(nil, &err)
|
||||||
|
if !swapped {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if f.triggerShutdown != nil {
|
||||||
|
f.triggerShutdown()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -276,9 +310,16 @@ func (f *Interface) listenOut(i int) {
|
|||||||
fwPacket := &firewall.Packet{}
|
fwPacket := &firewall.Packet{}
|
||||||
nb := make([]byte, 12, 12)
|
nb := make([]byte, 12, 12)
|
||||||
|
|
||||||
li.ListenOut(func(fromUdpAddr netip.AddrPort, payload []byte) {
|
err := li.ListenOut(func(fromUdpAddr netip.AddrPort, payload []byte) {
|
||||||
f.readOutsidePackets(ViaSender{UdpAddr: fromUdpAddr}, plaintext[:0], payload, h, fwPacket, lhh, nb, i, ctCache.Get(f.l))
|
f.readOutsidePackets(ViaSender{UdpAddr: fromUdpAddr}, plaintext[:0], payload, h, fwPacket, lhh, nb, i, ctCache.Get(f.l))
|
||||||
})
|
})
|
||||||
|
|
||||||
|
if err != nil && !f.closed.Load() {
|
||||||
|
f.l.WithError(err).Error("Error while reading inbound packet, closing")
|
||||||
|
f.onFatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
f.l.Debugf("underlay reader %v is done", i)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (f *Interface) listenIn(reader io.ReadWriteCloser, i int) {
|
func (f *Interface) listenIn(reader io.ReadWriteCloser, i int) {
|
||||||
@@ -292,17 +333,17 @@ func (f *Interface) listenIn(reader io.ReadWriteCloser, i int) {
|
|||||||
for {
|
for {
|
||||||
n, err := reader.Read(packet)
|
n, err := reader.Read(packet)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if errors.Is(err, os.ErrClosed) && f.closed.Load() {
|
if !f.closed.Load() {
|
||||||
return
|
f.l.WithError(err).WithField("reader", i).Error("Error while reading outbound packet, closing")
|
||||||
|
f.onFatal(err)
|
||||||
}
|
}
|
||||||
|
break
|
||||||
f.l.WithError(err).Error("Error while reading outbound packet")
|
|
||||||
// This only seems to happen when something fatal happens to the fd, so exit.
|
|
||||||
os.Exit(2)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
f.consumeInsidePacket(packet[:n], fwPacket, nb, out, i, conntrackCache.Get(f.l))
|
f.consumeInsidePacket(packet[:n], fwPacket, nb, out, i, conntrackCache.Get(f.l))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
f.l.Debugf("overlay reader %v is done", i)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (f *Interface) RegisterConfigChangeCallbacks(c *config.C) {
|
func (f *Interface) RegisterConfigChangeCallbacks(c *config.C) {
|
||||||
@@ -477,23 +518,23 @@ func (f *Interface) GetCertState() *CertState {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (f *Interface) Close() error {
|
func (f *Interface) Close() error {
|
||||||
|
var errs []error
|
||||||
f.closed.Store(true)
|
f.closed.Store(true)
|
||||||
|
|
||||||
for _, u := range f.writers {
|
// Release the udp readers
|
||||||
|
for i, u := range f.writers {
|
||||||
err := u.Close()
|
err := u.Close()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
f.l.WithError(err).Error("Error while closing udp socket")
|
f.l.WithError(err).WithField("writer", i).Error("Error while closing udp socket")
|
||||||
}
|
errs = append(errs, err)
|
||||||
}
|
|
||||||
for i, r := range f.readers {
|
|
||||||
if i == 0 {
|
|
||||||
continue // f.readers[0] is f.inside, which we want to save for last
|
|
||||||
}
|
|
||||||
if err := r.Close(); err != nil {
|
|
||||||
f.l.WithError(err).Error("Error while closing tun reader")
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Release the tun device
|
// Release the tun device (closing the tun also closes all readers)
|
||||||
return f.inside.Close()
|
closeErr := f.inside.Close()
|
||||||
|
if closeErr != nil {
|
||||||
|
errs = append(errs, closeErr)
|
||||||
|
}
|
||||||
|
f.wg.Done()
|
||||||
|
return errors.Join(errs...)
|
||||||
}
|
}
|
||||||
|
|||||||
19
main.go
19
main.go
@@ -288,15 +288,16 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
|
|||||||
}
|
}
|
||||||
|
|
||||||
return &Control{
|
return &Control{
|
||||||
ifce,
|
state: StateReady,
|
||||||
l,
|
f: ifce,
|
||||||
ctx,
|
l: l,
|
||||||
cancel,
|
ctx: ctx,
|
||||||
sshStart,
|
cancel: cancel,
|
||||||
statsStart,
|
sshStart: sshStart,
|
||||||
dnsStart,
|
statsStart: statsStart,
|
||||||
lightHouse.StartUpdateWorker,
|
dnsStart: dnsStart,
|
||||||
connManager.Start,
|
lighthouseStart: lightHouse.StartUpdateWorker,
|
||||||
|
connectionManagerStart: connManager.Start,
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
120
overlay/tun_file_linux_test.go
Normal file
120
overlay/tun_file_linux_test.go
Normal file
@@ -0,0 +1,120 @@
|
|||||||
|
//go:build linux && !android && !e2e_testing
|
||||||
|
// +build linux,!android,!e2e_testing
|
||||||
|
|
||||||
|
package overlay
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
"os"
|
||||||
|
"sync"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"golang.org/x/sys/unix"
|
||||||
|
)
|
||||||
|
|
||||||
|
// newReadPipe returns a read fd. The matching write fd is registered for cleanup.
|
||||||
|
// The caller takes ownership of the read fd (pass it to newTunFd / newFriend).
|
||||||
|
func newReadPipe(t *testing.T) int {
|
||||||
|
t.Helper()
|
||||||
|
var fds [2]int
|
||||||
|
if err := unix.Pipe2(fds[:], unix.O_CLOEXEC); err != nil {
|
||||||
|
t.Fatalf("pipe2: %v", err)
|
||||||
|
}
|
||||||
|
t.Cleanup(func() { _ = unix.Close(fds[1]) })
|
||||||
|
return fds[0]
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestTunFile_WakeForShutdown_UnblocksRead(t *testing.T) {
|
||||||
|
tf, err := newTunFd(newReadPipe(t))
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("newTunFd: %v", err)
|
||||||
|
}
|
||||||
|
t.Cleanup(func() { _ = tf.Close() })
|
||||||
|
|
||||||
|
done := make(chan error, 1)
|
||||||
|
go func() {
|
||||||
|
_, err := tf.Read(make([]byte, 64))
|
||||||
|
done <- err
|
||||||
|
}()
|
||||||
|
|
||||||
|
// Verify Read is actually blocked in poll.
|
||||||
|
select {
|
||||||
|
case err := <-done:
|
||||||
|
t.Fatalf("Read returned before shutdown signal: %v", err)
|
||||||
|
case <-time.After(50 * time.Millisecond):
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := tf.wakeForShutdown(); err != nil {
|
||||||
|
t.Fatalf("wakeForShutdown: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
select {
|
||||||
|
case err := <-done:
|
||||||
|
if !errors.Is(err, os.ErrClosed) {
|
||||||
|
t.Fatalf("expected os.ErrClosed, got %v", err)
|
||||||
|
}
|
||||||
|
case <-time.After(2 * time.Second):
|
||||||
|
t.Fatal("Read did not wake on shutdown")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestTunFile_WakeForShutdown_WakesFriends(t *testing.T) {
|
||||||
|
parent, err := newTunFd(newReadPipe(t))
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("newTunFd: %v", err)
|
||||||
|
}
|
||||||
|
friend, err := parent.newFriend(newReadPipe(t))
|
||||||
|
if err != nil {
|
||||||
|
_ = parent.Close()
|
||||||
|
t.Fatalf("newFriend: %v", err)
|
||||||
|
}
|
||||||
|
t.Cleanup(func() {
|
||||||
|
_ = friend.Close()
|
||||||
|
_ = parent.Close()
|
||||||
|
})
|
||||||
|
|
||||||
|
readers := []*tunFile{parent, friend}
|
||||||
|
errs := make([]error, len(readers))
|
||||||
|
var wg sync.WaitGroup
|
||||||
|
for i, r := range readers {
|
||||||
|
wg.Add(1)
|
||||||
|
go func(i int, r *tunFile) {
|
||||||
|
defer wg.Done()
|
||||||
|
_, errs[i] = r.Read(make([]byte, 64))
|
||||||
|
}(i, r)
|
||||||
|
}
|
||||||
|
|
||||||
|
time.Sleep(50 * time.Millisecond)
|
||||||
|
|
||||||
|
if err := parent.wakeForShutdown(); err != nil {
|
||||||
|
t.Fatalf("wakeForShutdown: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
done := make(chan struct{})
|
||||||
|
go func() { wg.Wait(); close(done) }()
|
||||||
|
select {
|
||||||
|
case <-done:
|
||||||
|
case <-time.After(2 * time.Second):
|
||||||
|
t.Fatal("readers did not wake")
|
||||||
|
}
|
||||||
|
|
||||||
|
for i, err := range errs {
|
||||||
|
if !errors.Is(err, os.ErrClosed) {
|
||||||
|
t.Errorf("reader %d: expected os.ErrClosed, got %v", i, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestTunFile_Close_Idempotent(t *testing.T) {
|
||||||
|
tf, err := newTunFd(newReadPipe(t))
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("newTunFd: %v", err)
|
||||||
|
}
|
||||||
|
if err := tf.Close(); err != nil {
|
||||||
|
t.Fatalf("first Close: %v", err)
|
||||||
|
}
|
||||||
|
if err := tf.Close(); err != nil {
|
||||||
|
t.Fatalf("second Close should be a no-op, got %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -4,6 +4,7 @@
|
|||||||
package overlay
|
package overlay
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"encoding/binary"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"net"
|
"net"
|
||||||
@@ -24,9 +25,175 @@ import (
|
|||||||
"golang.org/x/sys/unix"
|
"golang.org/x/sys/unix"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// tunFile wraps a TUN file descriptor with poll-based reads. The FD provided will be changed to non-blocking.
|
||||||
|
// A shared eventfd allows Close to wake all readers blocked in poll.
|
||||||
|
type tunFile struct {
|
||||||
|
fd int
|
||||||
|
shutdownFd int
|
||||||
|
lastOne bool
|
||||||
|
readPoll [2]unix.PollFd
|
||||||
|
writePoll [2]unix.PollFd
|
||||||
|
closed bool
|
||||||
|
}
|
||||||
|
|
||||||
|
// newFriend makes a tunFile for a MultiQueueReader that copies the shutdown eventfd from the parent tun
|
||||||
|
func (r *tunFile) newFriend(fd int) (*tunFile, error) {
|
||||||
|
if err := unix.SetNonblock(fd, true); err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to set tun fd non-blocking: %w", err)
|
||||||
|
}
|
||||||
|
return &tunFile{
|
||||||
|
fd: fd,
|
||||||
|
shutdownFd: r.shutdownFd,
|
||||||
|
readPoll: [2]unix.PollFd{
|
||||||
|
{Fd: int32(fd), Events: unix.POLLIN},
|
||||||
|
{Fd: int32(r.shutdownFd), Events: unix.POLLIN},
|
||||||
|
},
|
||||||
|
writePoll: [2]unix.PollFd{
|
||||||
|
{Fd: int32(fd), Events: unix.POLLOUT},
|
||||||
|
{Fd: int32(r.shutdownFd), Events: unix.POLLIN},
|
||||||
|
},
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func newTunFd(fd int) (*tunFile, error) {
|
||||||
|
if err := unix.SetNonblock(fd, true); err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to set tun fd non-blocking: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
shutdownFd, err := unix.Eventfd(0, unix.EFD_NONBLOCK|unix.EFD_CLOEXEC)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to create eventfd: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
out := &tunFile{
|
||||||
|
fd: fd,
|
||||||
|
shutdownFd: shutdownFd,
|
||||||
|
lastOne: true,
|
||||||
|
readPoll: [2]unix.PollFd{
|
||||||
|
{Fd: int32(fd), Events: unix.POLLIN},
|
||||||
|
{Fd: int32(shutdownFd), Events: unix.POLLIN},
|
||||||
|
},
|
||||||
|
writePoll: [2]unix.PollFd{
|
||||||
|
{Fd: int32(fd), Events: unix.POLLOUT},
|
||||||
|
{Fd: int32(shutdownFd), Events: unix.POLLIN},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
return out, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *tunFile) blockOnRead() error {
|
||||||
|
const problemFlags = unix.POLLHUP | unix.POLLNVAL | unix.POLLERR
|
||||||
|
var err error
|
||||||
|
for {
|
||||||
|
_, err = unix.Poll(r.readPoll[:], -1)
|
||||||
|
if err != unix.EINTR {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
//always reset these!
|
||||||
|
tunEvents := r.readPoll[0].Revents
|
||||||
|
shutdownEvents := r.readPoll[1].Revents
|
||||||
|
r.readPoll[0].Revents = 0
|
||||||
|
r.readPoll[1].Revents = 0
|
||||||
|
//do the err check before trusting the potentially bogus bits we just got
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if shutdownEvents&(unix.POLLIN|problemFlags) != 0 {
|
||||||
|
return os.ErrClosed
|
||||||
|
} else if tunEvents&problemFlags != 0 {
|
||||||
|
return os.ErrClosed
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *tunFile) blockOnWrite() error {
|
||||||
|
const problemFlags = unix.POLLHUP | unix.POLLNVAL | unix.POLLERR
|
||||||
|
var err error
|
||||||
|
for {
|
||||||
|
_, err = unix.Poll(r.writePoll[:], -1)
|
||||||
|
if err != unix.EINTR {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
//always reset these!
|
||||||
|
tunEvents := r.writePoll[0].Revents
|
||||||
|
shutdownEvents := r.writePoll[1].Revents
|
||||||
|
r.writePoll[0].Revents = 0
|
||||||
|
r.writePoll[1].Revents = 0
|
||||||
|
//do the err check before trusting the potentially bogus bits we just got
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if shutdownEvents&(unix.POLLIN|problemFlags) != 0 {
|
||||||
|
return os.ErrClosed
|
||||||
|
} else if tunEvents&problemFlags != 0 {
|
||||||
|
return os.ErrClosed
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *tunFile) Read(buf []byte) (int, error) {
|
||||||
|
for {
|
||||||
|
if n, err := unix.Read(r.fd, buf); err == nil {
|
||||||
|
return n, nil
|
||||||
|
} else if err == unix.EAGAIN {
|
||||||
|
if err = r.blockOnRead(); err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
continue
|
||||||
|
} else if err == unix.EINTR {
|
||||||
|
continue
|
||||||
|
} else if err == unix.EBADF {
|
||||||
|
return 0, os.ErrClosed
|
||||||
|
} else {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *tunFile) Write(buf []byte) (int, error) {
|
||||||
|
for {
|
||||||
|
if n, err := unix.Write(r.fd, buf); err == nil {
|
||||||
|
return n, nil
|
||||||
|
} else if err == unix.EAGAIN {
|
||||||
|
if err = r.blockOnWrite(); err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
continue
|
||||||
|
} else if err == unix.EINTR {
|
||||||
|
continue
|
||||||
|
} else if err == unix.EBADF {
|
||||||
|
return 0, os.ErrClosed
|
||||||
|
} else {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *tunFile) wakeForShutdown() error {
|
||||||
|
var buf [8]byte
|
||||||
|
binary.NativeEndian.PutUint64(buf[:], 1)
|
||||||
|
_, err := unix.Write(int(r.readPoll[1].Fd), buf[:])
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *tunFile) Close() error {
|
||||||
|
if r.closed { // avoid closing more than once. Technically a fd could get re-used, which would be a problem
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
r.closed = true
|
||||||
|
if r.lastOne {
|
||||||
|
_ = unix.Close(r.shutdownFd)
|
||||||
|
}
|
||||||
|
return unix.Close(r.fd)
|
||||||
|
}
|
||||||
|
|
||||||
type tun struct {
|
type tun struct {
|
||||||
io.ReadWriteCloser
|
*tunFile
|
||||||
fd int
|
readers []*tunFile
|
||||||
|
closeLock sync.Mutex
|
||||||
Device string
|
Device string
|
||||||
vpnNetworks []netip.Prefix
|
vpnNetworks []netip.Prefix
|
||||||
MaxMTU int
|
MaxMTU int
|
||||||
@@ -72,9 +239,7 @@ type ifreqQLEN struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func newTunFromFd(c *config.C, l *logrus.Logger, deviceFd int, vpnNetworks []netip.Prefix) (*tun, error) {
|
func newTunFromFd(c *config.C, l *logrus.Logger, deviceFd int, vpnNetworks []netip.Prefix) (*tun, error) {
|
||||||
file := os.NewFile(uintptr(deviceFd), "/dev/net/tun")
|
t, err := newTunGeneric(c, l, deviceFd, vpnNetworks)
|
||||||
|
|
||||||
t, err := newTunGeneric(c, l, file, vpnNetworks)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@@ -115,6 +280,7 @@ func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, multiqueu
|
|||||||
nameStr := c.GetString("tun.dev", "")
|
nameStr := c.GetString("tun.dev", "")
|
||||||
copy(req.Name[:], nameStr)
|
copy(req.Name[:], nameStr)
|
||||||
if err = ioctl(uintptr(fd), uintptr(unix.TUNSETIFF), uintptr(unsafe.Pointer(&req))); err != nil {
|
if err = ioctl(uintptr(fd), uintptr(unix.TUNSETIFF), uintptr(unsafe.Pointer(&req))); err != nil {
|
||||||
|
_ = unix.Close(fd)
|
||||||
return nil, &NameError{
|
return nil, &NameError{
|
||||||
Name: nameStr,
|
Name: nameStr,
|
||||||
Underlying: err,
|
Underlying: err,
|
||||||
@@ -122,8 +288,7 @@ func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, multiqueu
|
|||||||
}
|
}
|
||||||
name := strings.Trim(string(req.Name[:]), "\x00")
|
name := strings.Trim(string(req.Name[:]), "\x00")
|
||||||
|
|
||||||
file := os.NewFile(uintptr(fd), "/dev/net/tun")
|
t, err := newTunGeneric(c, l, fd, vpnNetworks)
|
||||||
t, err := newTunGeneric(c, l, file, vpnNetworks)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@@ -133,10 +298,17 @@ func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, multiqueu
|
|||||||
return t, nil
|
return t, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func newTunGeneric(c *config.C, l *logrus.Logger, file *os.File, vpnNetworks []netip.Prefix) (*tun, error) {
|
// newTunGeneric does all the stuff common to different tun initialization paths. It will close your files on error.
|
||||||
|
func newTunGeneric(c *config.C, l *logrus.Logger, fd int, vpnNetworks []netip.Prefix) (*tun, error) {
|
||||||
|
tfd, err := newTunFd(fd)
|
||||||
|
if err != nil {
|
||||||
|
_ = unix.Close(fd)
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
t := &tun{
|
t := &tun{
|
||||||
ReadWriteCloser: file,
|
tunFile: tfd,
|
||||||
fd: int(file.Fd()),
|
readers: []*tunFile{tfd},
|
||||||
|
closeLock: sync.Mutex{},
|
||||||
vpnNetworks: vpnNetworks,
|
vpnNetworks: vpnNetworks,
|
||||||
TXQueueLen: c.GetInt("tun.tx_queue", 500),
|
TXQueueLen: c.GetInt("tun.tx_queue", 500),
|
||||||
useSystemRoutes: c.GetBool("tun.use_system_route_table", false),
|
useSystemRoutes: c.GetBool("tun.use_system_route_table", false),
|
||||||
@@ -145,8 +317,8 @@ func newTunGeneric(c *config.C, l *logrus.Logger, file *os.File, vpnNetworks []n
|
|||||||
l: l,
|
l: l,
|
||||||
}
|
}
|
||||||
|
|
||||||
err := t.reload(c, true)
|
if err = t.reload(c, true); err != nil {
|
||||||
if err != nil {
|
_ = t.Close()
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -239,6 +411,9 @@ func (t *tun) SupportsMultiqueue() bool {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (t *tun) NewMultiQueueReader() (io.ReadWriteCloser, error) {
|
func (t *tun) NewMultiQueueReader() (io.ReadWriteCloser, error) {
|
||||||
|
t.closeLock.Lock()
|
||||||
|
defer t.closeLock.Unlock()
|
||||||
|
|
||||||
fd, err := unix.Open("/dev/net/tun", os.O_RDWR, 0)
|
fd, err := unix.Open("/dev/net/tun", os.O_RDWR, 0)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
@@ -248,12 +423,19 @@ func (t *tun) NewMultiQueueReader() (io.ReadWriteCloser, error) {
|
|||||||
req.Flags = uint16(unix.IFF_TUN | unix.IFF_NO_PI | unix.IFF_MULTI_QUEUE)
|
req.Flags = uint16(unix.IFF_TUN | unix.IFF_NO_PI | unix.IFF_MULTI_QUEUE)
|
||||||
copy(req.Name[:], t.Device)
|
copy(req.Name[:], t.Device)
|
||||||
if err = ioctl(uintptr(fd), uintptr(unix.TUNSETIFF), uintptr(unsafe.Pointer(&req))); err != nil {
|
if err = ioctl(uintptr(fd), uintptr(unix.TUNSETIFF), uintptr(unsafe.Pointer(&req))); err != nil {
|
||||||
|
_ = unix.Close(fd)
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
file := os.NewFile(uintptr(fd), "/dev/net/tun")
|
out, err := t.tunFile.newFriend(fd)
|
||||||
|
if err != nil {
|
||||||
|
_ = unix.Close(fd)
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
return file, nil
|
t.readers = append(t.readers, out)
|
||||||
|
|
||||||
|
return out, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t *tun) RoutesFor(ip netip.Addr) routing.Gateways {
|
func (t *tun) RoutesFor(ip netip.Addr) routing.Gateways {
|
||||||
@@ -684,18 +866,40 @@ func (t *tun) updateRoutes(r netlink.RouteUpdate) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (t *tun) Close() error {
|
func (t *tun) Close() error {
|
||||||
|
t.closeLock.Lock()
|
||||||
|
defer t.closeLock.Unlock()
|
||||||
|
|
||||||
if t.routeChan != nil {
|
if t.routeChan != nil {
|
||||||
close(t.routeChan)
|
close(t.routeChan)
|
||||||
|
t.routeChan = nil
|
||||||
}
|
}
|
||||||
|
|
||||||
if t.ReadWriteCloser != nil {
|
// Signal all readers blocked in poll to wake up and exit
|
||||||
_ = t.ReadWriteCloser.Close()
|
_ = t.tunFile.wakeForShutdown()
|
||||||
}
|
|
||||||
|
|
||||||
if t.ioctlFd > 0 {
|
if t.ioctlFd > 0 {
|
||||||
_ = os.NewFile(t.ioctlFd, "ioctlFd").Close()
|
_ = unix.Close(int(t.ioctlFd))
|
||||||
t.ioctlFd = 0
|
t.ioctlFd = 0
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
for i := range t.readers {
|
||||||
|
if i == 0 {
|
||||||
|
continue //we want to close the zeroth reader last
|
||||||
|
}
|
||||||
|
err := t.readers[i].Close()
|
||||||
|
if err != nil {
|
||||||
|
t.l.WithField("reader", i).WithError(err).Error("error closing tun reader")
|
||||||
|
} else {
|
||||||
|
t.l.WithField("reader", i).Info("closed tun reader")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
//this is t.readers[0] too
|
||||||
|
err := t.tunFile.Close()
|
||||||
|
if err != nil {
|
||||||
|
t.l.WithField("reader", 0).WithError(err).Error("error closing tun reader")
|
||||||
|
} else {
|
||||||
|
t.l.WithField("reader", 0).Info("closed tun reader")
|
||||||
|
}
|
||||||
|
return err
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -44,7 +44,10 @@ type Service struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func New(control *nebula.Control) (*Service, error) {
|
func New(control *nebula.Control) (*Service, error) {
|
||||||
control.Start()
|
wait, err := control.Start()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
ctx := control.Context()
|
ctx := control.Context()
|
||||||
eg, ctx := errgroup.WithContext(ctx)
|
eg, ctx := errgroup.WithContext(ctx)
|
||||||
@@ -141,6 +144,12 @@ func New(control *nebula.Control) (*Service, error) {
|
|||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
|
// Add the nebula wait function to the group so a fatal reader error
|
||||||
|
// propagates out through errgroup.Wait().
|
||||||
|
eg.Go(func() error {
|
||||||
|
return wait()
|
||||||
|
})
|
||||||
|
|
||||||
return &s, nil
|
return &s, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -16,7 +16,7 @@ type EncReader func(
|
|||||||
type Conn interface {
|
type Conn interface {
|
||||||
Rebind() error
|
Rebind() error
|
||||||
LocalAddr() (netip.AddrPort, error)
|
LocalAddr() (netip.AddrPort, error)
|
||||||
ListenOut(r EncReader)
|
ListenOut(r EncReader) error
|
||||||
WriteTo(b []byte, addr netip.AddrPort) error
|
WriteTo(b []byte, addr netip.AddrPort) error
|
||||||
ReloadConfig(c *config.C)
|
ReloadConfig(c *config.C)
|
||||||
SupportsMultipleReaders() bool
|
SupportsMultipleReaders() bool
|
||||||
@@ -31,8 +31,8 @@ func (NoopConn) Rebind() error {
|
|||||||
func (NoopConn) LocalAddr() (netip.AddrPort, error) {
|
func (NoopConn) LocalAddr() (netip.AddrPort, error) {
|
||||||
return netip.AddrPort{}, nil
|
return netip.AddrPort{}, nil
|
||||||
}
|
}
|
||||||
func (NoopConn) ListenOut(_ EncReader) {
|
func (NoopConn) ListenOut(_ EncReader) error {
|
||||||
return
|
return nil
|
||||||
}
|
}
|
||||||
func (NoopConn) SupportsMultipleReaders() bool {
|
func (NoopConn) SupportsMultipleReaders() bool {
|
||||||
return false
|
return false
|
||||||
|
|||||||
@@ -165,7 +165,7 @@ func NewUDPStatsEmitter(udpConns []Conn) func() {
|
|||||||
return func() {}
|
return func() {}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (u *StdConn) ListenOut(r EncReader) {
|
func (u *StdConn) ListenOut(r EncReader) error {
|
||||||
buffer := make([]byte, MTU)
|
buffer := make([]byte, MTU)
|
||||||
|
|
||||||
for {
|
for {
|
||||||
@@ -173,8 +173,7 @@ func (u *StdConn) ListenOut(r EncReader) {
|
|||||||
n, rua, err := u.ReadFromUDPAddrPort(buffer)
|
n, rua, err := u.ReadFromUDPAddrPort(buffer)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if errors.Is(err, net.ErrClosed) {
|
if errors.Is(err, net.ErrClosed) {
|
||||||
u.l.WithError(err).Debug("udp socket is closed, exiting read loop")
|
return err
|
||||||
return
|
|
||||||
}
|
}
|
||||||
|
|
||||||
u.l.WithError(err).Error("unexpected udp socket receive error")
|
u.l.WithError(err).Error("unexpected udp socket receive error")
|
||||||
|
|||||||
@@ -73,7 +73,7 @@ type rawMessage struct {
|
|||||||
Len uint32
|
Len uint32
|
||||||
}
|
}
|
||||||
|
|
||||||
func (u *GenericConn) ListenOut(r EncReader) {
|
func (u *GenericConn) ListenOut(r EncReader) error {
|
||||||
buffer := make([]byte, MTU)
|
buffer := make([]byte, MTU)
|
||||||
|
|
||||||
var lastRecvErr time.Time
|
var lastRecvErr time.Time
|
||||||
@@ -83,8 +83,7 @@ func (u *GenericConn) ListenOut(r EncReader) {
|
|||||||
n, rua, err := u.ReadFromUDPAddrPort(buffer)
|
n, rua, err := u.ReadFromUDPAddrPort(buffer)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if errors.Is(err, net.ErrClosed) {
|
if errors.Is(err, net.ErrClosed) {
|
||||||
u.l.WithError(err).Debug("udp socket is closed, exiting read loop")
|
return err
|
||||||
return
|
|
||||||
}
|
}
|
||||||
// Dampen unexpected message warns to once per minute
|
// Dampen unexpected message warns to once per minute
|
||||||
if lastRecvErr.IsZero() || time.Since(lastRecvErr) > time.Minute {
|
if lastRecvErr.IsZero() || time.Since(lastRecvErr) > time.Minute {
|
||||||
|
|||||||
@@ -171,7 +171,7 @@ func recvmmsg(fd uintptr, msgs []rawMessage) (int, bool, error) {
|
|||||||
return int(n), true, nil
|
return int(n), true, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (u *StdConn) listenOutSingle(r EncReader) {
|
func (u *StdConn) listenOutSingle(r EncReader) error {
|
||||||
var err error
|
var err error
|
||||||
var n int
|
var n int
|
||||||
var from netip.AddrPort
|
var from netip.AddrPort
|
||||||
@@ -180,15 +180,14 @@ func (u *StdConn) listenOutSingle(r EncReader) {
|
|||||||
for {
|
for {
|
||||||
n, from, err = u.udpConn.ReadFromUDPAddrPort(buffer)
|
n, from, err = u.udpConn.ReadFromUDPAddrPort(buffer)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
u.l.WithError(err).Debug("udp socket is closed, exiting read loop")
|
return err
|
||||||
return
|
|
||||||
}
|
}
|
||||||
from = netip.AddrPortFrom(from.Addr().Unmap(), from.Port())
|
from = netip.AddrPortFrom(from.Addr().Unmap(), from.Port())
|
||||||
r(from, buffer[:n])
|
r(from, buffer[:n])
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (u *StdConn) listenOutBatch(r EncReader) {
|
func (u *StdConn) listenOutBatch(r EncReader) error {
|
||||||
var ip netip.Addr
|
var ip netip.Addr
|
||||||
var n int
|
var n int
|
||||||
var operr error
|
var operr error
|
||||||
@@ -205,12 +204,10 @@ func (u *StdConn) listenOutBatch(r EncReader) {
|
|||||||
for {
|
for {
|
||||||
err := u.rawConn.Read(reader)
|
err := u.rawConn.Read(reader)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
u.l.WithError(err).Debug("udp socket is closed, exiting read loop")
|
return err
|
||||||
return
|
|
||||||
}
|
}
|
||||||
if operr != nil {
|
if operr != nil {
|
||||||
u.l.WithError(operr).Debug("operr: udp socket is closed, exiting read loop")
|
return operr
|
||||||
return
|
|
||||||
}
|
}
|
||||||
|
|
||||||
for i := 0; i < n; i++ {
|
for i := 0; i < n; i++ {
|
||||||
@@ -225,14 +222,11 @@ func (u *StdConn) listenOutBatch(r EncReader) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (u *StdConn) ListenOut(r EncReader) {
|
func (u *StdConn) ListenOut(r EncReader) error {
|
||||||
if u.batch == 1 {
|
if u.batch == 1 {
|
||||||
//save some ram by not calling PrepareRawMessages for fields we won't use
|
return u.listenOutSingle(r)
|
||||||
//we could also make this path more common by calling recvmmsg with msgs[:1],
|
|
||||||
//but that's still the recvmmsg syscall, which would be a change
|
|
||||||
u.listenOutSingle(r)
|
|
||||||
} else {
|
} else {
|
||||||
u.listenOutBatch(r)
|
return u.listenOutBatch(r)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -140,7 +140,7 @@ func (u *RIOConn) bind(l *logrus.Logger, sa windows.Sockaddr) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (u *RIOConn) ListenOut(r EncReader) {
|
func (u *RIOConn) ListenOut(r EncReader) error {
|
||||||
buffer := make([]byte, MTU)
|
buffer := make([]byte, MTU)
|
||||||
|
|
||||||
var lastRecvErr time.Time
|
var lastRecvErr time.Time
|
||||||
@@ -151,8 +151,7 @@ func (u *RIOConn) ListenOut(r EncReader) {
|
|||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if errors.Is(err, net.ErrClosed) {
|
if errors.Is(err, net.ErrClosed) {
|
||||||
u.l.WithError(err).Debug("udp socket is closed, exiting read loop")
|
return err
|
||||||
return
|
|
||||||
}
|
}
|
||||||
// Dampen unexpected message warns to once per minute
|
// Dampen unexpected message warns to once per minute
|
||||||
if lastRecvErr.IsZero() || time.Since(lastRecvErr) > time.Minute {
|
if lastRecvErr.IsZero() || time.Since(lastRecvErr) > time.Minute {
|
||||||
|
|||||||
@@ -6,6 +6,7 @@ package udp
|
|||||||
import (
|
import (
|
||||||
"io"
|
"io"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
|
"os"
|
||||||
"sync/atomic"
|
"sync/atomic"
|
||||||
|
|
||||||
"github.com/sirupsen/logrus"
|
"github.com/sirupsen/logrus"
|
||||||
@@ -106,11 +107,11 @@ func (u *TesterConn) WriteTo(b []byte, addr netip.AddrPort) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (u *TesterConn) ListenOut(r EncReader) {
|
func (u *TesterConn) ListenOut(r EncReader) error {
|
||||||
for {
|
for {
|
||||||
p, ok := <-u.RxPackets
|
p, ok := <-u.RxPackets
|
||||||
if !ok {
|
if !ok {
|
||||||
return
|
return os.ErrClosed
|
||||||
}
|
}
|
||||||
r(p.From, p.Data)
|
r(p.From, p.Data)
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user