mirror of
https://github.com/slackhq/nebula.git
synced 2026-02-15 17:24:23 +01:00
batch tun reads
This commit is contained in:
@@ -34,6 +34,7 @@ type tun struct {
|
||||
TXQueueLen int
|
||||
deviceIndex int
|
||||
ioctlFd uintptr
|
||||
nonBlocking bool // true if fd is in non-blocking mode
|
||||
|
||||
Routes atomic.Pointer[[]Route]
|
||||
routeTree atomic.Pointer[bart.Table[routing.Gateways]]
|
||||
@@ -239,7 +240,7 @@ func (t *tun) SupportsMultiqueue() bool {
|
||||
}
|
||||
|
||||
func (t *tun) NewMultiQueueReader() (io.ReadWriteCloser, error) {
|
||||
fd, err := unix.Open("/dev/net/tun", os.O_RDWR, 0)
|
||||
fd, err := unix.Open("/dev/net/tun", os.O_RDWR|unix.O_NONBLOCK, 0)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -251,9 +252,99 @@ func (t *tun) NewMultiQueueReader() (io.ReadWriteCloser, error) {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
file := os.NewFile(uintptr(fd), "/dev/net/tun")
|
||||
return &tunBatchReader{fd: fd, device: t.Device}, nil
|
||||
}
|
||||
|
||||
return file, nil
|
||||
// tunBatchReader implements BatchReader for efficient batch packet reading
|
||||
type tunBatchReader struct {
|
||||
fd int
|
||||
device string
|
||||
}
|
||||
|
||||
func (r *tunBatchReader) Read(b []byte) (int, error) {
|
||||
// Use poll to wait for data, then read
|
||||
for {
|
||||
n, err := unix.Read(r.fd, b)
|
||||
if err == nil {
|
||||
return n, nil
|
||||
}
|
||||
if err == unix.EAGAIN || err == unix.EWOULDBLOCK {
|
||||
// Wait for data
|
||||
pfds := []unix.PollFd{{Fd: int32(r.fd), Events: unix.POLLIN}}
|
||||
_, err = unix.Poll(pfds, -1)
|
||||
if err != nil {
|
||||
if err == unix.EINTR {
|
||||
continue
|
||||
}
|
||||
return 0, err
|
||||
}
|
||||
continue
|
||||
}
|
||||
return n, err
|
||||
}
|
||||
}
|
||||
|
||||
func (r *tunBatchReader) Write(b []byte) (int, error) {
|
||||
return unix.Write(r.fd, b)
|
||||
}
|
||||
|
||||
func (r *tunBatchReader) Close() error {
|
||||
return unix.Close(r.fd)
|
||||
}
|
||||
|
||||
// ReadBatch reads up to len(packets) packets from the TUN device.
|
||||
// It drains all available packets without blocking, using poll() only
|
||||
// when no packets have been read yet.
|
||||
func (r *tunBatchReader) ReadBatch(packets [][]byte, sizes []int) (int, error) {
|
||||
count := 0
|
||||
maxPackets := len(packets)
|
||||
if len(sizes) < maxPackets {
|
||||
maxPackets = len(sizes)
|
||||
}
|
||||
|
||||
for count < maxPackets {
|
||||
n, err := unix.Read(r.fd, packets[count])
|
||||
if err == nil && n > 0 {
|
||||
sizes[count] = n
|
||||
count++
|
||||
continue
|
||||
}
|
||||
|
||||
if err == unix.EAGAIN || err == unix.EWOULDBLOCK {
|
||||
// No more packets available
|
||||
if count > 0 {
|
||||
// We have some packets, return them
|
||||
return count, nil
|
||||
}
|
||||
// No packets yet, wait for at least one
|
||||
pfds := []unix.PollFd{{Fd: int32(r.fd), Events: unix.POLLIN}}
|
||||
_, err = unix.Poll(pfds, -1)
|
||||
if err != nil {
|
||||
if err == unix.EINTR {
|
||||
continue
|
||||
}
|
||||
return 0, err
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
if count > 0 {
|
||||
// Return what we have
|
||||
return count, nil
|
||||
}
|
||||
return 0, err
|
||||
}
|
||||
|
||||
if n == 0 {
|
||||
if count > 0 {
|
||||
return count, nil
|
||||
}
|
||||
return 0, io.EOF
|
||||
}
|
||||
}
|
||||
|
||||
return count, nil
|
||||
}
|
||||
|
||||
func (t *tun) RoutesFor(ip netip.Addr) routing.Gateways {
|
||||
@@ -284,6 +375,111 @@ func (t *tun) Write(b []byte) (int, error) {
|
||||
}
|
||||
}
|
||||
|
||||
// EnableBatchReading sets the TUN fd to non-blocking mode to enable batch reading.
|
||||
// This should be called before using ReadBatch.
|
||||
func (t *tun) EnableBatchReading() error {
|
||||
if t.nonBlocking {
|
||||
return nil
|
||||
}
|
||||
err := unix.SetNonblock(t.fd, true)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
t.nonBlocking = true
|
||||
return nil
|
||||
}
|
||||
|
||||
// Read overrides the default Read to handle non-blocking mode
|
||||
func (t *tun) Read(b []byte) (int, error) {
|
||||
if !t.nonBlocking {
|
||||
// Use the embedded ReadWriteCloser for blocking reads
|
||||
return t.ReadWriteCloser.Read(b)
|
||||
}
|
||||
|
||||
// Non-blocking read with poll
|
||||
for {
|
||||
n, err := unix.Read(t.fd, b)
|
||||
if err == nil {
|
||||
return n, nil
|
||||
}
|
||||
if err == unix.EAGAIN || err == unix.EWOULDBLOCK {
|
||||
// Wait for data
|
||||
pfds := []unix.PollFd{{Fd: int32(t.fd), Events: unix.POLLIN}}
|
||||
_, err = unix.Poll(pfds, -1)
|
||||
if err != nil {
|
||||
if err == unix.EINTR {
|
||||
continue
|
||||
}
|
||||
return 0, err
|
||||
}
|
||||
continue
|
||||
}
|
||||
return n, err
|
||||
}
|
||||
}
|
||||
|
||||
// ReadBatch reads up to len(packets) packets from the TUN device.
|
||||
// EnableBatchReading must be called first.
|
||||
func (t *tun) ReadBatch(packets [][]byte, sizes []int) (int, error) {
|
||||
if !t.nonBlocking {
|
||||
// Fallback to single read if non-blocking not enabled
|
||||
n, err := t.ReadWriteCloser.Read(packets[0])
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
sizes[0] = n
|
||||
return 1, nil
|
||||
}
|
||||
|
||||
count := 0
|
||||
maxPackets := len(packets)
|
||||
if len(sizes) < maxPackets {
|
||||
maxPackets = len(sizes)
|
||||
}
|
||||
|
||||
for count < maxPackets {
|
||||
n, err := unix.Read(t.fd, packets[count])
|
||||
if err == nil && n > 0 {
|
||||
sizes[count] = n
|
||||
count++
|
||||
continue
|
||||
}
|
||||
|
||||
if err == unix.EAGAIN || err == unix.EWOULDBLOCK {
|
||||
// No more packets available
|
||||
if count > 0 {
|
||||
return count, nil
|
||||
}
|
||||
// No packets yet, wait for at least one
|
||||
pfds := []unix.PollFd{{Fd: int32(t.fd), Events: unix.POLLIN}}
|
||||
_, err = unix.Poll(pfds, -1)
|
||||
if err != nil {
|
||||
if err == unix.EINTR {
|
||||
continue
|
||||
}
|
||||
return 0, err
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
if count > 0 {
|
||||
return count, nil
|
||||
}
|
||||
return 0, err
|
||||
}
|
||||
|
||||
if n == 0 {
|
||||
if count > 0 {
|
||||
return count, nil
|
||||
}
|
||||
return 0, io.EOF
|
||||
}
|
||||
}
|
||||
|
||||
return count, nil
|
||||
}
|
||||
|
||||
func (t *tun) deviceBytes() (o [16]byte) {
|
||||
for i, c := range t.Device {
|
||||
o[i] = byte(c)
|
||||
|
||||
Reference in New Issue
Block a user