mirror of
https://github.com/slackhq/nebula.git
synced 2026-02-15 09:14:23 +01:00
use wg tun library; batching & locking improvements
This commit is contained in:
178
udp/udp_linux.go
178
udp/udp_linux.go
@@ -22,6 +22,11 @@ type StdConn struct {
|
||||
isV4 bool
|
||||
l *logrus.Logger
|
||||
batch int
|
||||
|
||||
// Pre-allocated buffers for batch writes (sized for IPv6, works for both)
|
||||
writeMsgs []rawMessage
|
||||
writeIovecs []iovec
|
||||
writeNames [][]byte
|
||||
}
|
||||
|
||||
func maybeIPV4(ip net.IP) (net.IP, bool) {
|
||||
@@ -69,7 +74,26 @@ func NewListener(l *logrus.Logger, ip netip.Addr, port int, multi bool, batch in
|
||||
return nil, fmt.Errorf("unable to bind to socket: %s", err)
|
||||
}
|
||||
|
||||
return &StdConn{sysFd: fd, isV4: ip.Is4(), l: l, batch: batch}, err
|
||||
c := &StdConn{sysFd: fd, isV4: ip.Is4(), l: l, batch: batch}
|
||||
|
||||
// Pre-allocate write message structures for batching (sized for IPv6, works for both)
|
||||
c.writeMsgs = make([]rawMessage, batch)
|
||||
c.writeIovecs = make([]iovec, batch)
|
||||
c.writeNames = make([][]byte, batch)
|
||||
|
||||
for i := range c.writeMsgs {
|
||||
// Allocate for IPv6 size (larger than IPv4, works for both)
|
||||
c.writeNames[i] = make([]byte, unix.SizeofSockaddrInet6)
|
||||
|
||||
// Point to the iovec in the slice
|
||||
c.writeMsgs[i].Hdr.Iov = &c.writeIovecs[i]
|
||||
c.writeMsgs[i].Hdr.Iovlen = 1
|
||||
|
||||
c.writeMsgs[i].Hdr.Name = &c.writeNames[i][0]
|
||||
// Namelen will be set appropriately in writeMulti4/writeMulti6
|
||||
}
|
||||
|
||||
return c, err
|
||||
}
|
||||
|
||||
func (u *StdConn) SupportsMultipleReaders() bool {
|
||||
@@ -122,7 +146,7 @@ func (u *StdConn) LocalAddr() (netip.AddrPort, error) {
|
||||
}
|
||||
}
|
||||
|
||||
func (u *StdConn) ListenOut(r EncReader) {
|
||||
func (u *StdConn) ListenOutBatch(r EncBatchReader) {
|
||||
var ip netip.Addr
|
||||
|
||||
msgs, buffers, names := u.PrepareRawMessages(u.batch)
|
||||
@@ -131,6 +155,12 @@ func (u *StdConn) ListenOut(r EncReader) {
|
||||
read = u.ReadSingle
|
||||
}
|
||||
|
||||
udpBatchHist := metrics.GetOrRegisterHistogram("batch.udp_read_size", nil, metrics.NewUniformSample(1024))
|
||||
|
||||
// Pre-allocate slices for batch callback
|
||||
addrs := make([]netip.AddrPort, u.batch)
|
||||
payloads := make([][]byte, u.batch)
|
||||
|
||||
for {
|
||||
n, err := read(msgs)
|
||||
if err != nil {
|
||||
@@ -138,15 +168,21 @@ func (u *StdConn) ListenOut(r EncReader) {
|
||||
return
|
||||
}
|
||||
|
||||
udpBatchHist.Update(int64(n))
|
||||
|
||||
// Prepare batch data
|
||||
for i := 0; i < n; i++ {
|
||||
// Its ok to skip the ok check here, the slicing is the only error that can occur and it will panic
|
||||
if u.isV4 {
|
||||
ip, _ = netip.AddrFromSlice(names[i][4:8])
|
||||
} else {
|
||||
ip, _ = netip.AddrFromSlice(names[i][8:24])
|
||||
}
|
||||
r(netip.AddrPortFrom(ip.Unmap(), binary.BigEndian.Uint16(names[i][2:4])), buffers[i][:msgs[i].Len])
|
||||
addrs[i] = netip.AddrPortFrom(ip.Unmap(), binary.BigEndian.Uint16(names[i][2:4]))
|
||||
payloads[i] = buffers[i][:msgs[i].Len]
|
||||
}
|
||||
|
||||
// Call batch callback with all packets
|
||||
r(addrs, payloads, n)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -198,6 +234,19 @@ func (u *StdConn) WriteTo(b []byte, ip netip.AddrPort) error {
|
||||
return u.writeTo6(b, ip)
|
||||
}
|
||||
|
||||
func (u *StdConn) WriteMulti(packets [][]byte, addrs []netip.AddrPort) (int, error) {
|
||||
if len(packets) != len(addrs) {
|
||||
return 0, fmt.Errorf("packets and addrs length mismatch")
|
||||
}
|
||||
if len(packets) == 0 {
|
||||
return 0, nil
|
||||
}
|
||||
if u.isV4 {
|
||||
return u.writeMulti4(packets, addrs)
|
||||
}
|
||||
return u.writeMulti6(packets, addrs)
|
||||
}
|
||||
|
||||
func (u *StdConn) writeTo6(b []byte, ip netip.AddrPort) error {
|
||||
var rsa unix.RawSockaddrInet6
|
||||
rsa.Family = unix.AF_INET6
|
||||
@@ -252,6 +301,123 @@ func (u *StdConn) writeTo4(b []byte, ip netip.AddrPort) error {
|
||||
}
|
||||
}
|
||||
|
||||
func (u *StdConn) writeMulti4(packets [][]byte, addrs []netip.AddrPort) (int, error) {
|
||||
sent := 0
|
||||
for sent < len(packets) {
|
||||
// Determine batch size based on remaining packets and buffer capacity
|
||||
batchSize := len(packets) - sent
|
||||
if batchSize > len(u.writeMsgs) {
|
||||
batchSize = len(u.writeMsgs)
|
||||
}
|
||||
|
||||
// Use pre-allocated buffers
|
||||
msgs := u.writeMsgs[:batchSize]
|
||||
iovecs := u.writeIovecs[:batchSize]
|
||||
names := u.writeNames[:batchSize]
|
||||
|
||||
// Setup message structures for this batch
|
||||
for i := 0; i < batchSize; i++ {
|
||||
pktIdx := sent + i
|
||||
if !addrs[pktIdx].Addr().Is4() {
|
||||
return sent + i, ErrInvalidIPv6RemoteForSocket
|
||||
}
|
||||
|
||||
// Setup the packet buffer
|
||||
iovecs[i].Base = &packets[pktIdx][0]
|
||||
iovecs[i].Len = uint(len(packets[pktIdx]))
|
||||
|
||||
// Setup the destination address
|
||||
rsa := (*unix.RawSockaddrInet4)(unsafe.Pointer(&names[i][0]))
|
||||
rsa.Family = unix.AF_INET
|
||||
rsa.Addr = addrs[pktIdx].Addr().As4()
|
||||
binary.BigEndian.PutUint16((*[2]byte)(unsafe.Pointer(&rsa.Port))[:], addrs[pktIdx].Port())
|
||||
|
||||
// Set the appropriate address length for IPv4
|
||||
msgs[i].Hdr.Namelen = unix.SizeofSockaddrInet4
|
||||
}
|
||||
|
||||
// Send this batch
|
||||
nsent, _, err := unix.Syscall6(
|
||||
unix.SYS_SENDMMSG,
|
||||
uintptr(u.sysFd),
|
||||
uintptr(unsafe.Pointer(&msgs[0])),
|
||||
uintptr(batchSize),
|
||||
0,
|
||||
0,
|
||||
0,
|
||||
)
|
||||
|
||||
if err != 0 {
|
||||
return sent + int(nsent), &net.OpError{Op: "sendmmsg", Err: err}
|
||||
}
|
||||
|
||||
sent += int(nsent)
|
||||
if int(nsent) < batchSize {
|
||||
// Couldn't send all packets in batch, return what we sent
|
||||
return sent, nil
|
||||
}
|
||||
}
|
||||
|
||||
return sent, nil
|
||||
}
|
||||
|
||||
func (u *StdConn) writeMulti6(packets [][]byte, addrs []netip.AddrPort) (int, error) {
|
||||
sent := 0
|
||||
for sent < len(packets) {
|
||||
// Determine batch size based on remaining packets and buffer capacity
|
||||
batchSize := len(packets) - sent
|
||||
if batchSize > len(u.writeMsgs) {
|
||||
batchSize = len(u.writeMsgs)
|
||||
}
|
||||
|
||||
// Use pre-allocated buffers
|
||||
msgs := u.writeMsgs[:batchSize]
|
||||
iovecs := u.writeIovecs[:batchSize]
|
||||
names := u.writeNames[:batchSize]
|
||||
|
||||
// Setup message structures for this batch
|
||||
for i := 0; i < batchSize; i++ {
|
||||
pktIdx := sent + i
|
||||
|
||||
// Setup the packet buffer
|
||||
iovecs[i].Base = &packets[pktIdx][0]
|
||||
iovecs[i].Len = uint(len(packets[pktIdx]))
|
||||
|
||||
// Setup the destination address
|
||||
rsa := (*unix.RawSockaddrInet6)(unsafe.Pointer(&names[i][0]))
|
||||
rsa.Family = unix.AF_INET6
|
||||
rsa.Addr = addrs[pktIdx].Addr().As16()
|
||||
binary.BigEndian.PutUint16((*[2]byte)(unsafe.Pointer(&rsa.Port))[:], addrs[pktIdx].Port())
|
||||
|
||||
// Set the appropriate address length for IPv6
|
||||
msgs[i].Hdr.Namelen = unix.SizeofSockaddrInet6
|
||||
}
|
||||
|
||||
// Send this batch
|
||||
nsent, _, err := unix.Syscall6(
|
||||
unix.SYS_SENDMMSG,
|
||||
uintptr(u.sysFd),
|
||||
uintptr(unsafe.Pointer(&msgs[0])),
|
||||
uintptr(batchSize),
|
||||
0,
|
||||
0,
|
||||
0,
|
||||
)
|
||||
|
||||
if err != 0 {
|
||||
return sent + int(nsent), &net.OpError{Op: "sendmmsg", Err: err}
|
||||
}
|
||||
|
||||
sent += int(nsent)
|
||||
if int(nsent) < batchSize {
|
||||
// Couldn't send all packets in batch, return what we sent
|
||||
return sent, nil
|
||||
}
|
||||
}
|
||||
|
||||
return sent, nil
|
||||
}
|
||||
|
||||
func (u *StdConn) ReloadConfig(c *config.C) {
|
||||
b := c.GetInt("listen.read_buffer", 0)
|
||||
if b > 0 {
|
||||
@@ -309,6 +475,10 @@ func (u *StdConn) getMemInfo(meminfo *[unix.SK_MEMINFO_VARS]uint32) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (u *StdConn) BatchSize() int {
|
||||
return u.batch
|
||||
}
|
||||
|
||||
func (u *StdConn) Close() error {
|
||||
return syscall.Close(u.sysFd)
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user