Compare commits

...

10 Commits

Author SHA1 Message Date
JackDoan
2ab75709ad hmm 2025-11-04 15:40:33 -06:00
Nate Brown
2ea8a72d5c dunno 2025-10-05 23:23:30 -05:00
Nate Brown
663232e1fc Testing the concept 2025-10-05 23:23:10 -05:00
Nate Brown
2f48529e8b Cleanup and note more work 2025-10-05 23:23:08 -05:00
Nate Brown
f3e1ad64cd Try the timeout 2025-10-05 23:22:29 -05:00
Nate Brown
1d8112a329 Revert "More playing" way too much garbage emitted
This reverts commit fa098c551a.
2025-10-05 23:22:29 -05:00
Nate Brown
31eea0cc94 More playing 2025-10-05 23:22:29 -05:00
Nate Brown
dbba4a4c77 Playing 2025-10-05 23:22:29 -05:00
Nate Brown
194fde45da non-blocking io for linux 2025-10-05 23:22:27 -05:00
Nate Brown
f46b83f2c4 Remove more os.Exit calls and give a more reliable wait for stop function 2025-10-05 23:20:43 -05:00
15 changed files with 306 additions and 91 deletions

View File

@@ -5,6 +5,7 @@ import (
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
) )
// TODO: Pretty sure this is just all sorts of racy now, we need it to be atomic
type Bits struct { type Bits struct {
length uint64 length uint64
current uint64 current uint64
@@ -43,7 +44,7 @@ func (b *Bits) Check(l logrus.FieldLogger, i uint64) bool {
} }
// Not within the window // Not within the window
l.Debugf("rejected a packet (top) %d %d\n", b.current, i) l.Error("rejected a packet (top) %d %d\n", b.current, i)
return false return false
} }

View File

@@ -65,8 +65,16 @@ func main() {
} }
if !*configTest { if !*configTest {
ctrl.Start() wait, err := ctrl.Start()
ctrl.ShutdownBlock() if err != nil {
util.LogWithContextIfNeeded("Error while running", err, l)
os.Exit(1)
}
go ctrl.ShutdownBlock()
wait()
l.Info("Goodbye")
} }
os.Exit(0) os.Exit(0)

View File

@@ -3,6 +3,9 @@ package main
import ( import (
"flag" "flag"
"fmt" "fmt"
"log"
"net/http"
_ "net/http/pprof"
"os" "os"
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
@@ -58,10 +61,22 @@ func main() {
os.Exit(1) os.Exit(1)
} }
go func() {
log.Println(http.ListenAndServe("0.0.0.0:6060", nil))
}()
if !*configTest { if !*configTest {
ctrl.Start() wait, err := ctrl.Start()
if err != nil {
util.LogWithContextIfNeeded("Error while running", err, l)
os.Exit(1)
}
go ctrl.ShutdownBlock()
notifyReady(l) notifyReady(l)
ctrl.ShutdownBlock() wait()
l.Info("Goodbye")
} }
os.Exit(0) os.Exit(0)

View File

@@ -13,7 +13,9 @@ import (
"github.com/slackhq/nebula/noiseutil" "github.com/slackhq/nebula/noiseutil"
) )
const ReplayWindow = 1024 // TODO: In a 5Gbps test, 1024 is not sufficient. With a 1400 MTU this is about 1.4Gbps of window, assuming full packets.
// 4092 should be sufficient for 5Gbps
const ReplayWindow = 8192
type ConnectionState struct { type ConnectionState struct {
eKey *NebulaCipherState eKey *NebulaCipherState

View File

@@ -2,9 +2,11 @@ package nebula
import ( import (
"context" "context"
"errors"
"net/netip" "net/netip"
"os" "os"
"os/signal" "os/signal"
"sync"
"syscall" "syscall"
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
@@ -13,6 +15,16 @@ import (
"github.com/slackhq/nebula/overlay" "github.com/slackhq/nebula/overlay"
) )
type RunState int
const (
Stopped RunState = 0 // The control has yet to be started
Started RunState = 1 // The control has been started
Stopping RunState = 2 // The control is stopping
)
var ErrAlreadyStarted = errors.New("nebula is already started")
// Every interaction here needs to take extra care to copy memory and not return or use arguments "as is" when touching // Every interaction here needs to take extra care to copy memory and not return or use arguments "as is" when touching
// core. This means copying IP objects, slices, de-referencing pointers and taking the actual value, etc // core. This means copying IP objects, slices, de-referencing pointers and taking the actual value, etc
@@ -26,6 +38,9 @@ type controlHostLister interface {
} }
type Control struct { type Control struct {
stateLock sync.Mutex
state RunState
f *Interface f *Interface
l *logrus.Logger l *logrus.Logger
ctx context.Context ctx context.Context
@@ -49,10 +64,21 @@ type ControlHostInfo struct {
CurrentRelaysThroughMe []netip.Addr `json:"currentRelaysThroughMe"` CurrentRelaysThroughMe []netip.Addr `json:"currentRelaysThroughMe"`
} }
// Start actually runs nebula, this is a nonblocking call. To block use Control.ShutdownBlock() // Start actually runs nebula, this is a nonblocking call.
func (c *Control) Start() { // The returned function can be used to wait for nebula to fully stop.
func (c *Control) Start() (func(), error) {
c.stateLock.Lock()
if c.state != Stopped {
c.stateLock.Unlock()
return nil, ErrAlreadyStarted
}
// Activate the interface // Activate the interface
c.f.activate() err := c.f.activate()
if err != nil {
c.stateLock.Unlock()
return nil, err
}
// Call all the delayed funcs that waited patiently for the interface to be created. // Call all the delayed funcs that waited patiently for the interface to be created.
if c.sshStart != nil { if c.sshStart != nil {
@@ -72,15 +98,33 @@ func (c *Control) Start() {
} }
// Start reading packets. // Start reading packets.
c.f.run() c.state = Started
c.stateLock.Unlock()
return c.f.run(c.ctx)
}
func (c *Control) State() RunState {
c.stateLock.Lock()
defer c.stateLock.Unlock()
return c.state
} }
func (c *Control) Context() context.Context { func (c *Control) Context() context.Context {
return c.ctx return c.ctx
} }
// Stop signals nebula to shutdown and close all tunnels, returns after the shutdown is complete // Stop is a non-blocking call that signals nebula to close all tunnels and shut down
func (c *Control) Stop() { func (c *Control) Stop() {
c.stateLock.Lock()
if c.state != Started {
c.stateLock.Unlock()
// We are stopping or stopped already
return
}
c.state = Stopping
c.stateLock.Unlock()
// Stop the handshakeManager (and other services), to prevent new tunnels from // Stop the handshakeManager (and other services), to prevent new tunnels from
// being created while we're shutting them all down. // being created while we're shutting them all down.
c.cancel() c.cancel()
@@ -89,7 +133,7 @@ func (c *Control) Stop() {
if err := c.f.Close(); err != nil { if err := c.f.Close(); err != nil {
c.l.WithError(err).Error("Close interface failed") c.l.WithError(err).Error("Close interface failed")
} }
c.l.Info("Goodbye") c.state = Stopped
} }
// ShutdownBlock will listen for and block on term and interrupt signals, calling Control.Stop() once signalled // ShutdownBlock will listen for and block on term and interrupt signals, calling Control.Stop() once signalled

View File

@@ -6,8 +6,8 @@ import (
"fmt" "fmt"
"io" "io"
"net/netip" "net/netip"
"os"
"runtime" "runtime"
"sync"
"sync/atomic" "sync/atomic"
"time" "time"
@@ -18,6 +18,7 @@ import (
"github.com/slackhq/nebula/firewall" "github.com/slackhq/nebula/firewall"
"github.com/slackhq/nebula/header" "github.com/slackhq/nebula/header"
"github.com/slackhq/nebula/overlay" "github.com/slackhq/nebula/overlay"
"github.com/slackhq/nebula/packet"
"github.com/slackhq/nebula/udp" "github.com/slackhq/nebula/udp"
) )
@@ -87,12 +88,17 @@ type Interface struct {
writers []udp.Conn writers []udp.Conn
readers []io.ReadWriteCloser readers []io.ReadWriteCloser
wg sync.WaitGroup
metricHandshakes metrics.Histogram metricHandshakes metrics.Histogram
messageMetrics *MessageMetrics messageMetrics *MessageMetrics
cachedPacketMetrics *cachedPacketMetrics cachedPacketMetrics *cachedPacketMetrics
l *logrus.Logger l *logrus.Logger
pktPool *packet.Pool
inbound chan *packet.Packet
outbound chan *packet.Packet
} }
type EncWriter interface { type EncWriter interface {
@@ -194,9 +200,15 @@ func NewInterface(ctx context.Context, c *InterfaceConfig) (*Interface, error) {
dropped: metrics.GetOrRegisterCounter("hostinfo.cached_packets.dropped", nil), dropped: metrics.GetOrRegisterCounter("hostinfo.cached_packets.dropped", nil),
}, },
//TODO: configurable size
inbound: make(chan *packet.Packet, 2048),
outbound: make(chan *packet.Packet, 2048),
l: c.l, l: c.l,
} }
ifce.pktPool = packet.NewPool()
ifce.tryPromoteEvery.Store(c.tryPromoteEvery) ifce.tryPromoteEvery.Store(c.tryPromoteEvery)
ifce.reQueryEvery.Store(c.reQueryEvery) ifce.reQueryEvery.Store(c.reQueryEvery)
ifce.reQueryWait.Store(int64(c.reQueryWait)) ifce.reQueryWait.Store(int64(c.reQueryWait))
@@ -209,7 +221,7 @@ func NewInterface(ctx context.Context, c *InterfaceConfig) (*Interface, error) {
// activate creates the interface on the host. After the interface is created, any // activate creates the interface on the host. After the interface is created, any
// other services that want to bind listeners to its IP may do so successfully. However, // other services that want to bind listeners to its IP may do so successfully. However,
// the interface isn't going to process anything until run() is called. // the interface isn't going to process anything until run() is called.
func (f *Interface) activate() { func (f *Interface) activate() error {
// actually turn on tun dev // actually turn on tun dev
addr, err := f.outside.LocalAddr() addr, err := f.outside.LocalAddr()
@@ -230,33 +242,46 @@ func (f *Interface) activate() {
if i > 0 { if i > 0 {
reader, err = f.inside.NewMultiQueueReader() reader, err = f.inside.NewMultiQueueReader()
if err != nil { if err != nil {
f.l.Fatal(err) return err
} }
} }
f.readers[i] = reader f.readers[i] = reader
} }
if err := f.inside.Activate(); err != nil { if err = f.inside.Activate(); err != nil {
f.inside.Close() f.inside.Close()
f.l.Fatal(err) return err
}
} }
func (f *Interface) run() { return nil
// Launch n queues to read packets from udp }
func (f *Interface) run(c context.Context) (func(), error) {
for i := 0; i < f.routines; i++ { for i := 0; i < f.routines; i++ {
// read packets from udp and queue to f.inbound
f.wg.Add(1)
go f.listenOut(i) go f.listenOut(i)
// Launch n queues to read packets from inside tun dev and queue to f.outbound
//todo this never stops f.wg.Add(1)
go f.listenIn(f.readers[i], i)
// Launch n workers to process traffic from f.inbound and smash it onto the inside of the tun
f.wg.Add(1)
go f.workerIn(i, c)
f.wg.Add(1)
go f.workerIn(i, c)
// read from f.outbound and write to UDP (outside the tun)
f.wg.Add(1)
go f.workerOut(i, c)
} }
// Launch n queues to read packets from tun dev return f.wg.Wait, nil
for i := 0; i < f.routines; i++ {
go f.listenIn(f.readers[i], i)
}
} }
func (f *Interface) listenOut(i int) { func (f *Interface) listenOut(i int) {
runtime.LockOSThread() runtime.LockOSThread()
var li udp.Conn var li udp.Conn
if i > 0 { if i > 0 {
li = f.writers[i] li = f.writers[i]
@@ -264,41 +289,79 @@ func (f *Interface) listenOut(i int) {
li = f.outside li = f.outside
} }
ctCache := firewall.NewConntrackCacheTicker(f.conntrackCacheTimeout) err := li.ListenOut(f.pktPool.Get, f.inbound)
lhh := f.lightHouse.NewRequestHandler() if err != nil && !f.closed.Load() {
plaintext := make([]byte, udp.MTU) f.l.WithError(err).Error("Error while reading packet inbound packet, closing")
h := &header.H{} //TODO: Trigger Control to close
fwPacket := &firewall.Packet{} }
nb := make([]byte, 12, 12)
li.ListenOut(func(fromUdpAddr netip.AddrPort, payload []byte) { f.l.Debugf("underlay reader %v is done", i)
f.readOutsidePackets(fromUdpAddr, nil, plaintext[:0], payload, h, fwPacket, lhh, nb, i, ctCache.Get(f.l)) f.wg.Done()
})
} }
func (f *Interface) listenIn(reader io.ReadWriteCloser, i int) { func (f *Interface) listenIn(reader io.ReadWriteCloser, i int) {
runtime.LockOSThread() runtime.LockOSThread()
packet := make([]byte, mtu) for {
out := make([]byte, mtu) p := f.pktPool.Get()
fwPacket := &firewall.Packet{} n, err := reader.Read(p.Payload)
nb := make([]byte, 12, 12) if err != nil {
if !f.closed.Load() {
f.l.WithError(err).Error("Error while reading outbound packet, closing")
//TODO: Trigger Control to close
}
break
}
p.Payload = (p.Payload)[:n]
//TODO: nonblocking channel write
f.outbound <- p
//select {
//case f.outbound <- p:
//default:
// f.l.Error("Dropped packet from outbound channel")
//}
}
f.l.Debugf("overlay reader %v is done", i)
f.wg.Done()
}
func (f *Interface) workerIn(i int, ctx context.Context) {
lhh := f.lightHouse.NewRequestHandler()
conntrackCache := firewall.NewConntrackCacheTicker(f.conntrackCacheTimeout) conntrackCache := firewall.NewConntrackCacheTicker(f.conntrackCacheTimeout)
fwPacket2 := &firewall.Packet{}
nb2 := make([]byte, 12, 12)
result2 := make([]byte, mtu)
h := &header.H{}
for { for {
n, err := reader.Read(packet) select {
if err != nil { case p := <-f.inbound:
if errors.Is(err, os.ErrClosed) && f.closed.Load() { f.readOutsidePackets(p.Addr, nil, result2[:0], p.Payload, h, fwPacket2, lhh, nb2, i, conntrackCache.Get(f.l))
f.pktPool.Put(p)
case <-ctx.Done():
f.wg.Done()
return return
} }
}
f.l.WithError(err).Error("Error while reading outbound packet")
// This only seems to happen when something fatal happens to the fd, so exit.
os.Exit(2)
} }
f.consumeInsidePacket(packet[:n], fwPacket, nb, out, i, conntrackCache.Get(f.l)) func (f *Interface) workerOut(i int, ctx context.Context) {
conntrackCache := firewall.NewConntrackCacheTicker(f.conntrackCacheTimeout)
fwPacket1 := &firewall.Packet{}
nb1 := make([]byte, 12, 12)
result1 := make([]byte, mtu)
for {
select {
case data := <-f.outbound:
f.consumeInsidePacket(data.Payload, fwPacket1, nb1, result1, i, conntrackCache.Get(f.l))
f.pktPool.Put(data)
case <-ctx.Done():
f.wg.Done()
return
}
} }
} }
@@ -451,6 +514,7 @@ func (f *Interface) GetCertState() *CertState {
func (f *Interface) Close() error { func (f *Interface) Close() error {
f.closed.Store(true) f.closed.Store(true)
// Release the udp readers
for _, u := range f.writers { for _, u := range f.writers {
err := u.Close() err := u.Close()
if err != nil { if err != nil {
@@ -458,6 +522,13 @@ func (f *Interface) Close() error {
} }
} }
// Release the tun device // Release the tun readers
return f.inside.Close() for _, u := range f.readers {
err := u.Close()
if err != nil {
f.l.WithError(err).Error("Error while closing tun device")
}
}
return nil
} }

18
main.go
View File

@@ -284,14 +284,14 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
} }
return &Control{ return &Control{
ifce, f: ifce,
l, l: l,
ctx, ctx: ctx,
cancel, cancel: cancel,
sshStart, sshStart: sshStart,
statsStart, statsStart: statsStart,
dnsStart, dnsStart: dnsStart,
lightHouse.StartUpdateWorker, lighthouseStart: lightHouse.StartUpdateWorker,
connManager.Start, connectionManagerStart: connManager.Start,
}, nil }, nil
} }

View File

@@ -29,7 +29,7 @@ func (f *Interface) readOutsidePackets(ip netip.AddrPort, via *ViaSender, out []
return return
} }
//l.Error("in packet ", header, packet[HeaderLen:]) //f.l.Error("in packet ", h)
if ip.IsValid() { if ip.IsValid() {
if f.myVpnNetworksTable.Contains(ip.Addr()) { if f.myVpnNetworksTable.Contains(ip.Addr()) {
if f.l.Level >= logrus.DebugLevel { if f.l.Level >= logrus.DebugLevel {
@@ -245,6 +245,7 @@ func (f *Interface) handleHostRoaming(hostinfo *HostInfo, udpAddr netip.AddrPort
return return
} }
//TODO: Seems we have a bunch of stuff racing here, since we don't have a lock on hostinfo anymore we announce roaming in bursts
hostinfo.logger(f.l).WithField("udpAddr", hostinfo.remote).WithField("newAddr", udpAddr). hostinfo.logger(f.l).WithField("udpAddr", hostinfo.remote).WithField("newAddr", udpAddr).
Info("Host roamed to new udp ip/port.") Info("Host roamed to new udp ip/port.")
hostinfo.lastRoam = time.Now() hostinfo.lastRoam = time.Now()
@@ -470,7 +471,7 @@ func (f *Interface) decryptToTun(hostinfo *HostInfo, messageCounter uint64, out
out, err = hostinfo.ConnectionState.dKey.DecryptDanger(out, packet[:header.Len], packet[header.Len:], messageCounter, nb) out, err = hostinfo.ConnectionState.dKey.DecryptDanger(out, packet[:header.Len], packet[header.Len:], messageCounter, nb)
if err != nil { if err != nil {
hostinfo.logger(f.l).WithError(err).Error("Failed to decrypt packet") hostinfo.logger(f.l).WithError(err).WithField("fwPacket", fwPacket).Error("Failed to decrypt packet")
return false return false
} }

36
packet/packet.go Normal file
View File

@@ -0,0 +1,36 @@
package packet
import (
"net/netip"
"sync"
)
const Size = 9001
type Packet struct {
Payload []byte
Addr netip.AddrPort
}
func New() *Packet {
return &Packet{Payload: make([]byte, Size)}
}
type Pool struct {
pool sync.Pool
}
func NewPool() *Pool {
return &Pool{
pool: sync.Pool{New: func() any { return New() }},
}
}
func (p *Pool) Get() *Packet {
return p.pool.Get().(*Packet)
}
func (p *Pool) Put(x *Packet) {
x.Payload = x.Payload[:Size]
p.pool.Put(x)
}

View File

@@ -9,10 +9,13 @@ import (
"math" "math"
"net" "net"
"net/netip" "net/netip"
"os"
"strings" "strings"
"sync" "sync"
"github.com/sirupsen/logrus"
"github.com/slackhq/nebula" "github.com/slackhq/nebula"
"github.com/slackhq/nebula/config"
"github.com/slackhq/nebula/overlay" "github.com/slackhq/nebula/overlay"
"golang.org/x/sync/errgroup" "golang.org/x/sync/errgroup"
"gvisor.dev/gvisor/pkg/buffer" "gvisor.dev/gvisor/pkg/buffer"
@@ -43,8 +46,19 @@ type Service struct {
} }
} }
func New(control *nebula.Control) (*Service, error) { func New(config *config.C) (*Service, error) {
control.Start() logger := logrus.New()
logger.Out = os.Stdout
control, err := nebula.Main(config, false, "custom-app", logger, overlay.NewUserDeviceFromConfig)
if err != nil {
return nil, err
}
wait, err := control.Start()
if err != nil {
return nil, err
}
ctx := control.Context() ctx := control.Context()
eg, ctx := errgroup.WithContext(ctx) eg, ctx := errgroup.WithContext(ctx)
@@ -141,6 +155,12 @@ func New(control *nebula.Control) (*Service, error) {
} }
}) })
// Add the nebula wait function to the group
eg.Go(func() error {
wait()
return nil
})
return &s, nil return &s, nil
} }

View File

@@ -4,19 +4,19 @@ import (
"net/netip" "net/netip"
"github.com/slackhq/nebula/config" "github.com/slackhq/nebula/config"
"github.com/slackhq/nebula/packet"
) )
const MTU = 9001 const MTU = 9001
type EncReader func( type EncReader func(*packet.Packet)
addr netip.AddrPort,
payload []byte, type PacketBufferGetter func() *packet.Packet
)
type Conn interface { type Conn interface {
Rebind() error Rebind() error
LocalAddr() (netip.AddrPort, error) LocalAddr() (netip.AddrPort, error)
ListenOut(r EncReader) ListenOut(pg PacketBufferGetter, pc chan *packet.Packet) error
WriteTo(b []byte, addr netip.AddrPort) error WriteTo(b []byte, addr netip.AddrPort) error
ReloadConfig(c *config.C) ReloadConfig(c *config.C)
Close() error Close() error

View File

@@ -71,15 +71,14 @@ type rawMessage struct {
Len uint32 Len uint32
} }
func (u *GenericConn) ListenOut(r EncReader) { func (u *GenericConn) ListenOut(r EncReader) error {
buffer := make([]byte, MTU) buffer := make([]byte, MTU)
for { for {
// Just read one packet at a time // Just read one packet at a time
n, rua, err := u.ReadFromUDPAddrPort(buffer) n, rua, err := u.ReadFromUDPAddrPort(buffer)
if err != nil { if err != nil {
u.l.WithError(err).Debug("udp socket is closed, exiting read loop") return err
return
} }
r(netip.AddrPortFrom(rua.Addr().Unmap(), rua.Port()), buffer[:n]) r(netip.AddrPortFrom(rua.Addr().Unmap(), rua.Port()), buffer[:n])

View File

@@ -9,14 +9,18 @@ import (
"net" "net"
"net/netip" "net/netip"
"syscall" "syscall"
"time"
"unsafe" "unsafe"
"github.com/rcrowley/go-metrics" "github.com/rcrowley/go-metrics"
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
"github.com/slackhq/nebula/config" "github.com/slackhq/nebula/config"
"github.com/slackhq/nebula/packet"
"golang.org/x/sys/unix" "golang.org/x/sys/unix"
) )
var readTimeout = unix.NsecToTimeval(int64(time.Millisecond * 500))
type StdConn struct { type StdConn struct {
sysFd int sysFd int
isV4 bool isV4 bool
@@ -24,14 +28,6 @@ type StdConn struct {
batch int batch int
} }
func maybeIPV4(ip net.IP) (net.IP, bool) {
ip4 := ip.To4()
if ip4 != nil {
return ip4, true
}
return ip, false
}
func NewListener(l *logrus.Logger, ip netip.Addr, port int, multi bool, batch int) (Conn, error) { func NewListener(l *logrus.Logger, ip netip.Addr, port int, multi bool, batch int) (Conn, error) {
af := unix.AF_INET6 af := unix.AF_INET6
if ip.Is4() { if ip.Is4() {
@@ -55,6 +51,11 @@ func NewListener(l *logrus.Logger, ip netip.Addr, port int, multi bool, batch in
} }
} }
// Set a read timeout
if err = unix.SetsockoptTimeval(fd, unix.SOL_SOCKET, unix.SO_RCVTIMEO, &readTimeout); err != nil {
return nil, fmt.Errorf("unable to set SO_RCVTIMEO: %s", err)
}
var sa unix.Sockaddr var sa unix.Sockaddr
if ip.Is4() { if ip.Is4() {
sa4 := &unix.SockaddrInet4{Port: port} sa4 := &unix.SockaddrInet4{Port: port}
@@ -118,10 +119,10 @@ func (u *StdConn) LocalAddr() (netip.AddrPort, error) {
} }
} }
func (u *StdConn) ListenOut(r EncReader) { func (u *StdConn) ListenOut(pg PacketBufferGetter, pc chan *packet.Packet) error {
var ip netip.Addr var ip netip.Addr
msgs, buffers, names := u.PrepareRawMessages(u.batch) msgs, packets, names := u.PrepareRawMessages(u.batch, pg)
read := u.ReadMulti read := u.ReadMulti
if u.batch == 1 { if u.batch == 1 {
read = u.ReadSingle read = u.ReadSingle
@@ -130,18 +131,25 @@ func (u *StdConn) ListenOut(r EncReader) {
for { for {
n, err := read(msgs) n, err := read(msgs)
if err != nil { if err != nil {
u.l.WithError(err).Debug("udp socket is closed, exiting read loop") return err
return
} }
for i := 0; i < n; i++ { for i := 0; i < n; i++ {
out := packets[i]
out.Payload = out.Payload[:msgs[i].Len]
// Its ok to skip the ok check here, the slicing is the only error that can occur and it will panic // Its ok to skip the ok check here, the slicing is the only error that can occur and it will panic
if u.isV4 { if u.isV4 {
ip, _ = netip.AddrFromSlice(names[i][4:8]) ip, _ = netip.AddrFromSlice(names[i][4:8])
} else { } else {
ip, _ = netip.AddrFromSlice(names[i][8:24]) ip, _ = netip.AddrFromSlice(names[i][8:24])
} }
r(netip.AddrPortFrom(ip.Unmap(), binary.BigEndian.Uint16(names[i][2:4])), buffers[i][:msgs[i].Len]) out.Addr = netip.AddrPortFrom(ip.Unmap(), binary.BigEndian.Uint16(names[i][2:4]))
pc <- out
//rotate this packet out so we don't overwrite it
packets[i] = pg()
msgs[i].Hdr.Iov.Base = &packets[i].Payload[0]
} }
} }
} }
@@ -159,6 +167,9 @@ func (u *StdConn) ReadSingle(msgs []rawMessage) (int, error) {
) )
if err != 0 { if err != 0 {
if err == unix.EAGAIN || err == unix.EINTR {
continue
}
return 0, &net.OpError{Op: "recvmsg", Err: err} return 0, &net.OpError{Op: "recvmsg", Err: err}
} }
@@ -180,6 +191,9 @@ func (u *StdConn) ReadMulti(msgs []rawMessage) (int, error) {
) )
if err != 0 { if err != 0 {
if err == unix.EAGAIN || err == unix.EINTR {
continue
}
return 0, &net.OpError{Op: "recvmmsg", Err: err} return 0, &net.OpError{Op: "recvmmsg", Err: err}
} }
@@ -221,7 +235,7 @@ func (u *StdConn) writeTo6(b []byte, ip netip.AddrPort) error {
func (u *StdConn) writeTo4(b []byte, ip netip.AddrPort) error { func (u *StdConn) writeTo4(b []byte, ip netip.AddrPort) error {
if !ip.Addr().Is4() { if !ip.Addr().Is4() {
return ErrInvalidIPv6RemoteForSocket return fmt.Errorf("Listener is IPv4, but writing to IPv6 remote")
} }
var rsa unix.RawSockaddrInet4 var rsa unix.RawSockaddrInet4

View File

@@ -7,6 +7,7 @@
package udp package udp
import ( import (
"github.com/slackhq/nebula/packet"
"golang.org/x/sys/unix" "golang.org/x/sys/unix"
) )
@@ -33,17 +34,20 @@ type rawMessage struct {
Pad0 [4]byte Pad0 [4]byte
} }
func (u *StdConn) PrepareRawMessages(n int) ([]rawMessage, [][]byte, [][]byte) { func (u *StdConn) PrepareRawMessages(n int, pg PacketBufferGetter) ([]rawMessage, []*packet.Packet, [][]byte) {
msgs := make([]rawMessage, n) msgs := make([]rawMessage, n)
buffers := make([][]byte, n)
names := make([][]byte, n) names := make([][]byte, n)
packets := make([]*packet.Packet, n)
for i := range packets {
packets[i] = pg()
}
for i := range msgs { for i := range msgs {
buffers[i] = make([]byte, MTU)
names[i] = make([]byte, unix.SizeofSockaddrInet6) names[i] = make([]byte, unix.SizeofSockaddrInet6)
vs := []iovec{ vs := []iovec{
{Base: &buffers[i][0], Len: uint64(len(buffers[i]))}, {Base: &packets[i].Payload[0], Len: uint64(packet.Size)},
} }
msgs[i].Hdr.Iov = &vs[0] msgs[i].Hdr.Iov = &vs[0]
@@ -53,5 +57,5 @@ func (u *StdConn) PrepareRawMessages(n int) ([]rawMessage, [][]byte, [][]byte) {
msgs[i].Hdr.Namelen = uint32(len(names[i])) msgs[i].Hdr.Namelen = uint32(len(names[i]))
} }
return msgs, buffers, names return msgs, packets, names
} }

View File

@@ -134,7 +134,7 @@ func (u *RIOConn) bind(sa windows.Sockaddr) error {
return nil return nil
} }
func (u *RIOConn) ListenOut(r EncReader) { func (u *RIOConn) ListenOut(r EncReader) error {
buffer := make([]byte, MTU) buffer := make([]byte, MTU)
for { for {