Merge remote-tracking branch 'origin/master' into multiport

This commit is contained in:
Wade Simmons
2023-10-27 08:48:13 -04:00
74 changed files with 2540 additions and 1402 deletions

View File

@@ -1,6 +1,7 @@
package udp
import (
"github.com/slackhq/nebula/config"
"github.com/slackhq/nebula/firewall"
"github.com/slackhq/nebula/header"
)
@@ -18,3 +19,33 @@ type EncReader func(
q int,
localCache firewall.ConntrackCache,
)
type Conn interface {
Rebind() error
LocalAddr() (*Addr, error)
ListenOut(r EncReader, lhf LightHouseHandlerFunc, cache *firewall.ConntrackCacheTicker, q int)
WriteTo(b []byte, addr *Addr) error
ReloadConfig(c *config.C)
Close() error
}
type NoopConn struct{}
func (NoopConn) Rebind() error {
return nil
}
func (NoopConn) LocalAddr() (*Addr, error) {
return nil, nil
}
func (NoopConn) ListenOut(_ EncReader, _ LightHouseHandlerFunc, _ *firewall.ConntrackCacheTicker, _ int) {
return
}
func (NoopConn) WriteTo(_ []byte, _ *Addr) error {
return nil
}
func (NoopConn) ReloadConfig(_ *config.C) {
return
}
func (NoopConn) Close() error {
return nil
}

View File

@@ -8,9 +8,14 @@ import (
"net"
"syscall"
"github.com/sirupsen/logrus"
"golang.org/x/sys/unix"
)
func NewListener(l *logrus.Logger, ip net.IP, port int, multi bool, batch int) (Conn, error) {
return NewGenericListener(l, ip, port, multi, batch)
}
func NewListenConfig(multi bool) net.ListenConfig {
return net.ListenConfig{
Control: func(network, address string, c syscall.RawConn) error {
@@ -34,6 +39,6 @@ func NewListenConfig(multi bool) net.ListenConfig {
}
}
func (u *Conn) Rebind() error {
func (u *GenericConn) Rebind() error {
return nil
}

47
udp/udp_bsd.go Normal file
View File

@@ -0,0 +1,47 @@
//go:build (openbsd || freebsd) && !e2e_testing
// +build openbsd freebsd
// +build !e2e_testing
package udp
// FreeBSD support is primarily implemented in udp_generic, besides NewListenConfig
import (
"fmt"
"net"
"syscall"
"github.com/sirupsen/logrus"
"golang.org/x/sys/unix"
)
func NewListener(l *logrus.Logger, ip net.IP, port int, multi bool, batch int) (Conn, error) {
return NewGenericListener(l, ip, port, multi, batch)
}
func NewListenConfig(multi bool) net.ListenConfig {
return net.ListenConfig{
Control: func(network, address string, c syscall.RawConn) error {
if multi {
var controlErr error
err := c.Control(func(fd uintptr) {
if err := syscall.SetsockoptInt(int(fd), syscall.SOL_SOCKET, unix.SO_REUSEPORT, 1); err != nil {
controlErr = fmt.Errorf("SO_REUSEPORT failed: %v", err)
return
}
})
if err != nil {
return err
}
if controlErr != nil {
return controlErr
}
}
return nil
},
}
}
func (u *GenericConn) Rebind() error {
return nil
}

View File

@@ -10,9 +10,14 @@ import (
"net"
"syscall"
"github.com/sirupsen/logrus"
"golang.org/x/sys/unix"
)
func NewListener(l *logrus.Logger, ip net.IP, port int, multi bool, batch int) (Conn, error) {
return NewGenericListener(l, ip, port, multi, batch)
}
func NewListenConfig(multi bool) net.ListenConfig {
return net.ListenConfig{
Control: func(network, address string, c syscall.RawConn) error {
@@ -37,11 +42,16 @@ func NewListenConfig(multi bool) net.ListenConfig {
}
}
func (u *Conn) Rebind() error {
file, err := u.File()
func (u *GenericConn) Rebind() error {
rc, err := u.UDPConn.SyscallConn()
if err != nil {
return err
}
return syscall.SetsockoptInt(int(file.Fd()), unix.IPPROTO_IPV6, unix.IPV6_BOUND_IF, 0)
return rc.Control(func(fd uintptr) {
err := syscall.SetsockoptInt(int(fd), unix.IPPROTO_IPV6, unix.IPV6_BOUND_IF, 0)
if err != nil {
u.l.WithError(err).Error("Failed to rebind udp socket")
}
})
}

View File

@@ -18,30 +18,32 @@ import (
"github.com/slackhq/nebula/header"
)
type Conn struct {
type GenericConn struct {
*net.UDPConn
l *logrus.Logger
}
func NewListener(l *logrus.Logger, ip net.IP, port int, multi bool, batch int) (*Conn, error) {
var _ Conn = &GenericConn{}
func NewGenericListener(l *logrus.Logger, ip net.IP, port int, multi bool, batch int) (Conn, error) {
lc := NewListenConfig(multi)
pc, err := lc.ListenPacket(context.TODO(), "udp", net.JoinHostPort(ip.String(), fmt.Sprintf("%v", port)))
if err != nil {
return nil, err
}
if uc, ok := pc.(*net.UDPConn); ok {
return &Conn{UDPConn: uc, l: l}, nil
return &GenericConn{UDPConn: uc, l: l}, nil
}
return nil, fmt.Errorf("Unexpected PacketConn: %T %#v", pc, pc)
}
func (uc *Conn) WriteTo(b []byte, addr *Addr) error {
_, err := uc.UDPConn.WriteToUDP(b, &net.UDPAddr{IP: addr.IP, Port: int(addr.Port)})
func (u *GenericConn) WriteTo(b []byte, addr *Addr) error {
_, err := u.UDPConn.WriteToUDP(b, &net.UDPAddr{IP: addr.IP, Port: int(addr.Port)})
return err
}
func (uc *Conn) LocalAddr() (*Addr, error) {
a := uc.UDPConn.LocalAddr()
func (u *GenericConn) LocalAddr() (*Addr, error) {
a := u.UDPConn.LocalAddr()
switch v := a.(type) {
case *net.UDPAddr:
@@ -55,11 +57,11 @@ func (uc *Conn) LocalAddr() (*Addr, error) {
}
}
func (u *Conn) ReloadConfig(c *config.C) {
func (u *GenericConn) ReloadConfig(c *config.C) {
// TODO
}
func NewUDPStatsEmitter(udpConns []*Conn) func() {
func NewUDPStatsEmitter(udpConns []Conn) func() {
// No UDP stats for non-linux
return func() {}
}
@@ -68,7 +70,7 @@ type rawMessage struct {
Len uint32
}
func (u *Conn) ListenOut(r EncReader, lhf LightHouseHandlerFunc, cache *firewall.ConntrackCacheTicker, q int) {
func (u *GenericConn) ListenOut(r EncReader, lhf LightHouseHandlerFunc, cache *firewall.ConntrackCacheTicker, q int) {
plaintext := make([]byte, MTU)
buffer := make([]byte, MTU)
h := &header.H{}
@@ -80,8 +82,8 @@ func (u *Conn) ListenOut(r EncReader, lhf LightHouseHandlerFunc, cache *firewall
// Just read one packet at a time
n, rua, err := u.ReadFromUDP(buffer)
if err != nil {
u.l.WithError(err).Error("Failed to read packets")
continue
u.l.WithError(err).Debug("udp socket is closed, exiting read loop")
return
}
udpAddr.IP = rua.IP

View File

@@ -20,7 +20,7 @@ import (
//TODO: make it support reload as best you can!
type Conn struct {
type StdConn struct {
sysFd int
l *logrus.Logger
batch int
@@ -45,7 +45,7 @@ const (
type _SK_MEMINFO [_SK_MEMINFO_VARS]uint32
func NewListener(l *logrus.Logger, ip net.IP, port int, multi bool, batch int) (*Conn, error) {
func NewListener(l *logrus.Logger, ip net.IP, port int, multi bool, batch int) (Conn, error) {
syscall.ForkLock.RLock()
fd, err := unix.Socket(unix.AF_INET6, unix.SOCK_DGRAM, unix.IPPROTO_UDP)
if err == nil {
@@ -77,30 +77,30 @@ func NewListener(l *logrus.Logger, ip net.IP, port int, multi bool, batch int) (
//v, err := unix.GetsockoptInt(fd, unix.SOL_SOCKET, unix.SO_INCOMING_CPU)
//l.Println(v, err)
return &Conn{sysFd: fd, l: l, batch: batch}, err
return &StdConn{sysFd: fd, l: l, batch: batch}, err
}
func (u *Conn) Rebind() error {
func (u *StdConn) Rebind() error {
return nil
}
func (u *Conn) SetRecvBuffer(n int) error {
func (u *StdConn) SetRecvBuffer(n int) error {
return unix.SetsockoptInt(u.sysFd, unix.SOL_SOCKET, unix.SO_RCVBUFFORCE, n)
}
func (u *Conn) SetSendBuffer(n int) error {
func (u *StdConn) SetSendBuffer(n int) error {
return unix.SetsockoptInt(u.sysFd, unix.SOL_SOCKET, unix.SO_SNDBUFFORCE, n)
}
func (u *Conn) GetRecvBuffer() (int, error) {
func (u *StdConn) GetRecvBuffer() (int, error) {
return unix.GetsockoptInt(int(u.sysFd), unix.SOL_SOCKET, unix.SO_RCVBUF)
}
func (u *Conn) GetSendBuffer() (int, error) {
func (u *StdConn) GetSendBuffer() (int, error) {
return unix.GetsockoptInt(int(u.sysFd), unix.SOL_SOCKET, unix.SO_SNDBUF)
}
func (u *Conn) LocalAddr() (*Addr, error) {
func (u *StdConn) LocalAddr() (*Addr, error) {
sa, err := unix.Getsockname(u.sysFd)
if err != nil {
return nil, err
@@ -119,7 +119,7 @@ func (u *Conn) LocalAddr() (*Addr, error) {
return addr, nil
}
func (u *Conn) ListenOut(r EncReader, lhf LightHouseHandlerFunc, cache *firewall.ConntrackCacheTicker, q int) {
func (u *StdConn) ListenOut(r EncReader, lhf LightHouseHandlerFunc, cache *firewall.ConntrackCacheTicker, q int) {
plaintext := make([]byte, MTU)
h := &header.H{}
fwPacket := &firewall.Packet{}
@@ -137,8 +137,8 @@ func (u *Conn) ListenOut(r EncReader, lhf LightHouseHandlerFunc, cache *firewall
for {
n, err := read(msgs)
if err != nil {
u.l.WithError(err).Error("Failed to read packets")
continue
u.l.WithError(err).Debug("udp socket is closed, exiting read loop")
return
}
//metric.Update(int64(n))
@@ -150,7 +150,7 @@ func (u *Conn) ListenOut(r EncReader, lhf LightHouseHandlerFunc, cache *firewall
}
}
func (u *Conn) ReadSingle(msgs []rawMessage) (int, error) {
func (u *StdConn) ReadSingle(msgs []rawMessage) (int, error) {
for {
n, _, err := unix.Syscall6(
unix.SYS_RECVMSG,
@@ -171,7 +171,7 @@ func (u *Conn) ReadSingle(msgs []rawMessage) (int, error) {
}
}
func (u *Conn) ReadMulti(msgs []rawMessage) (int, error) {
func (u *StdConn) ReadMulti(msgs []rawMessage) (int, error) {
for {
n, _, err := unix.Syscall6(
unix.SYS_RECVMMSG,
@@ -191,7 +191,7 @@ func (u *Conn) ReadMulti(msgs []rawMessage) (int, error) {
}
}
func (u *Conn) WriteTo(b []byte, addr *Addr) error {
func (u *StdConn) WriteTo(b []byte, addr *Addr) error {
var rsa unix.RawSockaddrInet6
rsa.Family = unix.AF_INET6
@@ -221,7 +221,7 @@ func (u *Conn) WriteTo(b []byte, addr *Addr) error {
}
}
func (u *Conn) ReloadConfig(c *config.C) {
func (u *StdConn) ReloadConfig(c *config.C) {
b := c.GetInt("listen.read_buffer", 0)
if b > 0 {
err := u.SetRecvBuffer(b)
@@ -253,7 +253,7 @@ func (u *Conn) ReloadConfig(c *config.C) {
}
}
func (u *Conn) getMemInfo(meminfo *_SK_MEMINFO) error {
func (u *StdConn) getMemInfo(meminfo *_SK_MEMINFO) error {
var vallen uint32 = 4 * _SK_MEMINFO_VARS
_, _, err := unix.Syscall6(unix.SYS_GETSOCKOPT, uintptr(u.sysFd), uintptr(unix.SOL_SOCKET), uintptr(unix.SO_MEMINFO), uintptr(unsafe.Pointer(meminfo)), uintptr(unsafe.Pointer(&vallen)), 0)
if err != 0 {
@@ -262,11 +262,16 @@ func (u *Conn) getMemInfo(meminfo *_SK_MEMINFO) error {
return nil
}
func NewUDPStatsEmitter(udpConns []*Conn) func() {
func (u *StdConn) Close() error {
//TODO: this will not interrupt the read loop
return syscall.Close(u.sysFd)
}
func NewUDPStatsEmitter(udpConns []Conn) func() {
// Check if our kernel supports SO_MEMINFO before registering the gauges
var udpGauges [][_SK_MEMINFO_VARS]metrics.Gauge
var meminfo _SK_MEMINFO
if err := udpConns[0].getMemInfo(&meminfo); err == nil {
if err := udpConns[0].(*StdConn).getMemInfo(&meminfo); err == nil {
udpGauges = make([][_SK_MEMINFO_VARS]metrics.Gauge, len(udpConns))
for i := range udpConns {
udpGauges[i] = [_SK_MEMINFO_VARS]metrics.Gauge{
@@ -285,7 +290,7 @@ func NewUDPStatsEmitter(udpConns []*Conn) func() {
return func() {
for i, gauges := range udpGauges {
if err := udpConns[i].getMemInfo(&meminfo); err == nil {
if err := udpConns[i].(*StdConn).getMemInfo(&meminfo); err == nil {
for j := 0; j < _SK_MEMINFO_VARS; j++ {
gauges[j].Update(int64(meminfo[j]))
}

View File

@@ -30,7 +30,7 @@ type rawMessage struct {
Len uint32
}
func (u *Conn) PrepareRawMessages(n int) ([]rawMessage, [][]byte, [][]byte) {
func (u *StdConn) PrepareRawMessages(n int) ([]rawMessage, [][]byte, [][]byte) {
msgs := make([]rawMessage, n)
buffers := make([][]byte, n)
names := make([][]byte, n)

View File

@@ -33,7 +33,7 @@ type rawMessage struct {
Pad0 [4]byte
}
func (u *Conn) PrepareRawMessages(n int) ([]rawMessage, [][]byte, [][]byte) {
func (u *StdConn) PrepareRawMessages(n int) ([]rawMessage, [][]byte, [][]byte) {
msgs := make([]rawMessage, n)
buffers := make([][]byte, n)
names := make([][]byte, n)

View File

@@ -10,9 +10,14 @@ import (
"net"
"syscall"
"github.com/sirupsen/logrus"
"golang.org/x/sys/unix"
)
func NewListener(l *logrus.Logger, ip net.IP, port int, multi bool, batch int) (Conn, error) {
return NewGenericListener(l, ip, port, multi, batch)
}
func NewListenConfig(multi bool) net.ListenConfig {
return net.ListenConfig{
Control: func(network, address string, c syscall.RawConn) error {
@@ -36,6 +41,6 @@ func NewListenConfig(multi bool) net.ListenConfig {
}
}
func (u *Conn) Rebind() error {
func (u *GenericConn) Rebind() error {
return nil
}

403
udp/udp_rio_windows.go Normal file
View File

@@ -0,0 +1,403 @@
//go:build !e2e_testing
// +build !e2e_testing
// Inspired by https://git.zx2c4.com/wireguard-go/tree/conn/bind_windows.go
package udp
import (
"errors"
"fmt"
"io"
"net"
"sync"
"sync/atomic"
"syscall"
"unsafe"
"github.com/sirupsen/logrus"
"github.com/slackhq/nebula/config"
"github.com/slackhq/nebula/firewall"
"github.com/slackhq/nebula/header"
"golang.org/x/sys/windows"
"golang.zx2c4.com/wireguard/conn/winrio"
)
// Assert we meet the standard conn interface
var _ Conn = &RIOConn{}
//go:linkname procyield runtime.procyield
func procyield(cycles uint32)
const (
packetsPerRing = 1024
bytesPerPacket = 2048 - 32
receiveSpins = 15
)
type ringPacket struct {
addr windows.RawSockaddrInet6
data [bytesPerPacket]byte
}
type ringBuffer struct {
packets uintptr
head, tail uint32
id winrio.BufferId
iocp windows.Handle
isFull bool
cq winrio.Cq
mu sync.Mutex
overlapped windows.Overlapped
}
type RIOConn struct {
isOpen atomic.Bool
l *logrus.Logger
sock windows.Handle
rx, tx ringBuffer
rq winrio.Rq
results [packetsPerRing]winrio.Result
}
func NewRIOListener(l *logrus.Logger, ip net.IP, port int) (*RIOConn, error) {
if !winrio.Initialize() {
return nil, errors.New("could not initialize winrio")
}
u := &RIOConn{l: l}
addr := [16]byte{}
copy(addr[:], ip.To16())
err := u.bind(&windows.SockaddrInet6{Addr: addr, Port: port})
if err != nil {
return nil, fmt.Errorf("bind: %w", err)
}
for i := 0; i < packetsPerRing; i++ {
err = u.insertReceiveRequest()
if err != nil {
return nil, fmt.Errorf("init rx ring: %w", err)
}
}
u.isOpen.Store(true)
return u, nil
}
func (u *RIOConn) bind(sa windows.Sockaddr) error {
var err error
u.sock, err = winrio.Socket(windows.AF_INET6, windows.SOCK_DGRAM, windows.IPPROTO_UDP)
if err != nil {
return err
}
// Enable v4 for this socket
syscall.SetsockoptInt(syscall.Handle(u.sock), syscall.IPPROTO_IPV6, syscall.IPV6_V6ONLY, 0)
err = u.rx.Open()
if err != nil {
return err
}
err = u.tx.Open()
if err != nil {
return err
}
u.rq, err = winrio.CreateRequestQueue(u.sock, packetsPerRing, 1, packetsPerRing, 1, u.rx.cq, u.tx.cq, 0)
if err != nil {
return err
}
err = windows.Bind(u.sock, sa)
if err != nil {
return err
}
return nil
}
func (u *RIOConn) ListenOut(r EncReader, lhf LightHouseHandlerFunc, cache *firewall.ConntrackCacheTicker, q int) {
plaintext := make([]byte, MTU)
buffer := make([]byte, MTU)
h := &header.H{}
fwPacket := &firewall.Packet{}
udpAddr := &Addr{IP: make([]byte, 16)}
nb := make([]byte, 12, 12)
for {
// Just read one packet at a time
n, rua, err := u.receive(buffer)
if err != nil {
u.l.WithError(err).Debug("udp socket is closed, exiting read loop")
return
}
udpAddr.IP = rua.Addr[:]
p := (*[2]byte)(unsafe.Pointer(&udpAddr.Port))
p[0] = byte(rua.Port >> 8)
p[1] = byte(rua.Port)
r(udpAddr, plaintext[:0], buffer[:n], h, fwPacket, lhf, nb, q, cache.Get(u.l))
}
}
func (u *RIOConn) insertReceiveRequest() error {
packet := u.rx.Push()
dataBuffer := &winrio.Buffer{
Id: u.rx.id,
Offset: uint32(uintptr(unsafe.Pointer(&packet.data[0])) - u.rx.packets),
Length: uint32(len(packet.data)),
}
addressBuffer := &winrio.Buffer{
Id: u.rx.id,
Offset: uint32(uintptr(unsafe.Pointer(&packet.addr)) - u.rx.packets),
Length: uint32(unsafe.Sizeof(packet.addr)),
}
return winrio.ReceiveEx(u.rq, dataBuffer, 1, nil, addressBuffer, nil, nil, 0, uintptr(unsafe.Pointer(packet)))
}
func (u *RIOConn) receive(buf []byte) (int, windows.RawSockaddrInet6, error) {
if !u.isOpen.Load() {
return 0, windows.RawSockaddrInet6{}, net.ErrClosed
}
u.rx.mu.Lock()
defer u.rx.mu.Unlock()
var err error
var count uint32
var results [1]winrio.Result
retry:
count = 0
for tries := 0; count == 0 && tries < receiveSpins; tries++ {
if tries > 0 {
if !u.isOpen.Load() {
return 0, windows.RawSockaddrInet6{}, net.ErrClosed
}
procyield(1)
}
count = winrio.DequeueCompletion(u.rx.cq, results[:])
}
if count == 0 {
err = winrio.Notify(u.rx.cq)
if err != nil {
return 0, windows.RawSockaddrInet6{}, err
}
var bytes uint32
var key uintptr
var overlapped *windows.Overlapped
err = windows.GetQueuedCompletionStatus(u.rx.iocp, &bytes, &key, &overlapped, windows.INFINITE)
if err != nil {
return 0, windows.RawSockaddrInet6{}, err
}
if !u.isOpen.Load() {
return 0, windows.RawSockaddrInet6{}, net.ErrClosed
}
count = winrio.DequeueCompletion(u.rx.cq, results[:])
if count == 0 {
return 0, windows.RawSockaddrInet6{}, io.ErrNoProgress
}
}
u.rx.Return(1)
err = u.insertReceiveRequest()
if err != nil {
return 0, windows.RawSockaddrInet6{}, err
}
// We limit the MTU well below the 65k max for practicality, but this means a remote host can still send us
// huge packets. Just try again when this happens. The infinite loop this could cause is still limited to
// attacker bandwidth, just like the rest of the receive path.
if windows.Errno(results[0].Status) == windows.WSAEMSGSIZE {
goto retry
}
if results[0].Status != 0 {
return 0, windows.RawSockaddrInet6{}, windows.Errno(results[0].Status)
}
packet := (*ringPacket)(unsafe.Pointer(uintptr(results[0].RequestContext)))
ep := packet.addr
n := copy(buf, packet.data[:results[0].BytesTransferred])
return n, ep, nil
}
func (u *RIOConn) WriteTo(buf []byte, addr *Addr) error {
if !u.isOpen.Load() {
return net.ErrClosed
}
if len(buf) > bytesPerPacket {
return io.ErrShortBuffer
}
u.tx.mu.Lock()
defer u.tx.mu.Unlock()
count := winrio.DequeueCompletion(u.tx.cq, u.results[:])
if count == 0 && u.tx.isFull {
err := winrio.Notify(u.tx.cq)
if err != nil {
return err
}
var bytes uint32
var key uintptr
var overlapped *windows.Overlapped
err = windows.GetQueuedCompletionStatus(u.tx.iocp, &bytes, &key, &overlapped, windows.INFINITE)
if err != nil {
return err
}
if !u.isOpen.Load() {
return net.ErrClosed
}
count = winrio.DequeueCompletion(u.tx.cq, u.results[:])
if count == 0 {
return io.ErrNoProgress
}
}
if count > 0 {
u.tx.Return(count)
}
packet := u.tx.Push()
packet.addr.Family = windows.AF_INET6
p := (*[2]byte)(unsafe.Pointer(&packet.addr.Port))
p[0] = byte(addr.Port >> 8)
p[1] = byte(addr.Port)
copy(packet.addr.Addr[:], addr.IP.To16())
copy(packet.data[:], buf)
dataBuffer := &winrio.Buffer{
Id: u.tx.id,
Offset: uint32(uintptr(unsafe.Pointer(&packet.data[0])) - u.tx.packets),
Length: uint32(len(buf)),
}
addressBuffer := &winrio.Buffer{
Id: u.tx.id,
Offset: uint32(uintptr(unsafe.Pointer(&packet.addr)) - u.tx.packets),
Length: uint32(unsafe.Sizeof(packet.addr)),
}
return winrio.SendEx(u.rq, dataBuffer, 1, nil, addressBuffer, nil, nil, 0, 0)
}
func (u *RIOConn) LocalAddr() (*Addr, error) {
sa, err := windows.Getsockname(u.sock)
if err != nil {
return nil, err
}
v6 := sa.(*windows.SockaddrInet6)
return &Addr{
IP: v6.Addr[:],
Port: uint16(v6.Port),
}, nil
}
func (u *RIOConn) Rebind() error {
return nil
}
func (u *RIOConn) ReloadConfig(*config.C) {}
func (u *RIOConn) Close() error {
if !u.isOpen.CompareAndSwap(true, false) {
return nil
}
windows.PostQueuedCompletionStatus(u.rx.iocp, 0, 0, nil)
windows.PostQueuedCompletionStatus(u.tx.iocp, 0, 0, nil)
u.rx.CloseAndZero()
u.tx.CloseAndZero()
if u.sock != 0 {
windows.CloseHandle(u.sock)
}
return nil
}
func (ring *ringBuffer) Push() *ringPacket {
for ring.isFull {
panic("ring is full")
}
ret := (*ringPacket)(unsafe.Pointer(ring.packets + (uintptr(ring.tail%packetsPerRing) * unsafe.Sizeof(ringPacket{}))))
ring.tail += 1
if ring.tail%packetsPerRing == ring.head%packetsPerRing {
ring.isFull = true
}
return ret
}
func (ring *ringBuffer) Return(count uint32) {
if ring.head%packetsPerRing == ring.tail%packetsPerRing && !ring.isFull {
return
}
ring.head += count
ring.isFull = false
}
func (ring *ringBuffer) CloseAndZero() {
if ring.cq != 0 {
winrio.CloseCompletionQueue(ring.cq)
ring.cq = 0
}
if ring.iocp != 0 {
windows.CloseHandle(ring.iocp)
ring.iocp = 0
}
if ring.id != 0 {
winrio.DeregisterBuffer(ring.id)
ring.id = 0
}
if ring.packets != 0 {
windows.VirtualFree(ring.packets, 0, windows.MEM_RELEASE)
ring.packets = 0
}
ring.head = 0
ring.tail = 0
ring.isFull = false
}
func (ring *ringBuffer) Open() error {
var err error
packetsLen := unsafe.Sizeof(ringPacket{}) * packetsPerRing
ring.packets, err = windows.VirtualAlloc(0, packetsLen, windows.MEM_COMMIT|windows.MEM_RESERVE, windows.PAGE_READWRITE)
if err != nil {
return err
}
ring.id, err = winrio.RegisterPointer(unsafe.Pointer(ring.packets), uint32(packetsLen))
if err != nil {
return err
}
ring.iocp, err = windows.CreateIoCompletionPort(windows.InvalidHandle, 0, 0, 0)
if err != nil {
return err
}
ring.cq, err = winrio.CreateIOCPCompletionQueue(packetsPerRing, ring.iocp, 0, &ring.overlapped)
if err != nil {
return err
}
return nil
}

View File

@@ -5,7 +5,9 @@ package udp
import (
"fmt"
"io"
"net"
"sync/atomic"
"github.com/sirupsen/logrus"
"github.com/slackhq/nebula/config"
@@ -36,17 +38,18 @@ func (u *Packet) Copy() *Packet {
return n
}
type Conn struct {
type TesterConn struct {
Addr *Addr
RxPackets chan *Packet // Packets to receive into nebula
TxPackets chan *Packet // Packets transmitted outside by nebula
l *logrus.Logger
closed atomic.Bool
l *logrus.Logger
}
func NewListener(l *logrus.Logger, ip net.IP, port int, _ bool, _ int) (*Conn, error) {
return &Conn{
func NewListener(l *logrus.Logger, ip net.IP, port int, _ bool, _ int) (Conn, error) {
return &TesterConn{
Addr: &Addr{ip, uint16(port)},
RxPackets: make(chan *Packet, 10),
TxPackets: make(chan *Packet, 10),
@@ -57,7 +60,11 @@ func NewListener(l *logrus.Logger, ip net.IP, port int, _ bool, _ int) (*Conn, e
// Send will place a UdpPacket onto the receive queue for nebula to consume
// this is an encrypted packet or a handshake message in most cases
// packets were transmitted from another nebula node, you can send them with Tun.Send
func (u *Conn) Send(packet *Packet) {
func (u *TesterConn) Send(packet *Packet) {
if u.closed.Load() {
return
}
h := &header.H{}
if err := h.Parse(packet.Data); err != nil {
panic(err)
@@ -74,7 +81,7 @@ func (u *Conn) Send(packet *Packet) {
// Get will pull a UdpPacket from the transmit queue
// nebula meant to send this message on the network, it will be encrypted
// packets were ingested from the tun side (in most cases), you can send them with Tun.Send
func (u *Conn) Get(block bool) *Packet {
func (u *TesterConn) Get(block bool) *Packet {
if block {
return <-u.TxPackets
}
@@ -91,7 +98,11 @@ func (u *Conn) Get(block bool) *Packet {
// Below this is boilerplate implementation to make nebula actually work
//********************************************************************************************************************//
func (u *Conn) WriteTo(b []byte, addr *Addr) error {
func (u *TesterConn) WriteTo(b []byte, addr *Addr) error {
if u.closed.Load() {
return io.ErrClosedPipe
}
p := &Packet{
Data: make([]byte, len(b), len(b)),
FromIp: make([]byte, 16),
@@ -108,7 +119,7 @@ func (u *Conn) WriteTo(b []byte, addr *Addr) error {
return nil
}
func (u *Conn) ListenOut(r EncReader, lhf LightHouseHandlerFunc, cache *firewall.ConntrackCacheTicker, q int) {
func (u *TesterConn) ListenOut(r EncReader, lhf LightHouseHandlerFunc, cache *firewall.ConntrackCacheTicker, q int) {
plaintext := make([]byte, MTU)
h := &header.H{}
fwPacket := &firewall.Packet{}
@@ -126,17 +137,25 @@ func (u *Conn) ListenOut(r EncReader, lhf LightHouseHandlerFunc, cache *firewall
}
}
func (u *Conn) ReloadConfig(*config.C) {}
func (u *TesterConn) ReloadConfig(*config.C) {}
func NewUDPStatsEmitter(_ []*Conn) func() {
func NewUDPStatsEmitter(_ []Conn) func() {
// No UDP stats for non-linux
return func() {}
}
func (u *Conn) LocalAddr() (*Addr, error) {
func (u *TesterConn) LocalAddr() (*Addr, error) {
return u.Addr, nil
}
func (u *Conn) Rebind() error {
func (u *TesterConn) Rebind() error {
return nil
}
func (u *TesterConn) Close() error {
if u.closed.CompareAndSwap(false, true) {
close(u.RxPackets)
close(u.TxPackets)
}
return nil
}

View File

@@ -3,14 +3,31 @@
package udp
// Windows support is primarily implemented in udp_generic, besides NewListenConfig
import (
"fmt"
"net"
"syscall"
"github.com/sirupsen/logrus"
)
func NewListener(l *logrus.Logger, ip net.IP, port int, multi bool, batch int) (Conn, error) {
if multi {
//NOTE: Technically we can support it with RIO but it wouldn't be at the socket level
// The udp stack would need to be reworked to hide away the implementation differences between
// Windows and Linux
return nil, fmt.Errorf("multiple udp listeners not supported on windows")
}
rc, err := NewRIOListener(l, ip, port)
if err == nil {
return rc, nil
}
l.WithError(err).Error("Falling back to standard udp sockets")
return NewGenericListener(l, ip, port, multi, batch)
}
func NewListenConfig(multi bool) net.ListenConfig {
return net.ListenConfig{
Control: func(network, address string, c syscall.RawConn) error {
@@ -24,6 +41,6 @@ func NewListenConfig(multi bool) net.ListenConfig {
}
}
func (u *Conn) Rebind() error {
func (u *GenericConn) Rebind() error {
return nil
}