mirror of
https://github.com/slackhq/nebula.git
synced 2026-02-14 00:34:22 +01:00
claude does TUN virtio header support
This commit is contained in:
@@ -24,6 +24,11 @@ import (
|
||||
"golang.org/x/sys/unix"
|
||||
)
|
||||
|
||||
const (
|
||||
// virtioNetHdrLen is the length of virtio_net_hdr (without mergeable buffers)
|
||||
virtioNetHdrLen = 10
|
||||
)
|
||||
|
||||
type tun struct {
|
||||
io.ReadWriteCloser
|
||||
fd int
|
||||
@@ -35,6 +40,12 @@ type tun struct {
|
||||
deviceIndex int
|
||||
ioctlFd uintptr
|
||||
nonBlocking bool // true if fd is in non-blocking mode
|
||||
vnetHdr bool // true if IFF_VNET_HDR is enabled on the TUN device
|
||||
|
||||
// readBuf is used when vnetHdr is enabled to read the full packet+header
|
||||
// before stripping the header. This is needed because caller-provided
|
||||
// buffers are sized for MTU but kernel writes MTU+10 with virtio header.
|
||||
readBuf []byte
|
||||
|
||||
Routes atomic.Pointer[[]Route]
|
||||
routeTree atomic.Pointer[bart.Table[routing.Gateways]]
|
||||
@@ -54,6 +65,23 @@ func (t *tun) Networks() []netip.Prefix {
|
||||
return t.vpnNetworks
|
||||
}
|
||||
|
||||
// tunVnetHdrSupported checks if the kernel supports IFF_VNET_HDR on TUN devices
|
||||
func tunVnetHdrSupported() bool {
|
||||
fd, err := unix.Open("/dev/net/tun", unix.O_RDONLY, 0)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
defer unix.Close(fd)
|
||||
|
||||
var features uint32
|
||||
err = ioctl(uintptr(fd), uintptr(unix.TUNGETFEATURES), uintptr(unsafe.Pointer(&features)))
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
|
||||
return features&unix.IFF_VNET_HDR != 0
|
||||
}
|
||||
|
||||
type ifReq struct {
|
||||
Name [16]byte
|
||||
Flags uint16
|
||||
@@ -108,11 +136,18 @@ func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, multiqueu
|
||||
}
|
||||
}
|
||||
|
||||
// Check if VNET_HDR is supported before trying to use it
|
||||
useVnetHdr := tunVnetHdrSupported()
|
||||
|
||||
var req ifReq
|
||||
req.Flags = uint16(unix.IFF_TUN | unix.IFF_NO_PI)
|
||||
if multiqueue {
|
||||
req.Flags |= unix.IFF_MULTI_QUEUE
|
||||
}
|
||||
if useVnetHdr {
|
||||
req.Flags |= unix.IFF_VNET_HDR
|
||||
}
|
||||
|
||||
nameStr := c.GetString("tun.dev", "")
|
||||
copy(req.Name[:], nameStr)
|
||||
if err = ioctl(uintptr(fd), uintptr(unix.TUNSETIFF), uintptr(unsafe.Pointer(&req))); err != nil {
|
||||
@@ -123,6 +158,13 @@ func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, multiqueu
|
||||
}
|
||||
name := strings.Trim(string(req.Name[:]), "\x00")
|
||||
|
||||
// Track if VNET_HDR is in use
|
||||
// Note: We don't call TUNSETOFFLOAD - just handle the headers manually
|
||||
vnetHdrEnabled := useVnetHdr
|
||||
if vnetHdrEnabled {
|
||||
l.Info("TUN VNET_HDR enabled")
|
||||
}
|
||||
|
||||
file := os.NewFile(uintptr(fd), "/dev/net/tun")
|
||||
t, err := newTunGeneric(c, l, file, vpnNetworks)
|
||||
if err != nil {
|
||||
@@ -130,6 +172,13 @@ func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, multiqueu
|
||||
}
|
||||
|
||||
t.Device = name
|
||||
t.vnetHdr = vnetHdrEnabled
|
||||
|
||||
// Allocate read buffer for virtio header handling
|
||||
// Buffer needs to be large enough for virtio header + max packet
|
||||
if t.vnetHdr {
|
||||
t.readBuf = make([]byte, t.MaxMTU+virtioNetHdrLen)
|
||||
}
|
||||
|
||||
return t, nil
|
||||
}
|
||||
@@ -247,25 +296,48 @@ func (t *tun) NewMultiQueueReader() (io.ReadWriteCloser, error) {
|
||||
|
||||
var req ifReq
|
||||
req.Flags = uint16(unix.IFF_TUN | unix.IFF_NO_PI | unix.IFF_MULTI_QUEUE)
|
||||
if t.vnetHdr {
|
||||
req.Flags |= unix.IFF_VNET_HDR
|
||||
}
|
||||
copy(req.Name[:], t.Device)
|
||||
if err = ioctl(uintptr(fd), uintptr(unix.TUNSETIFF), uintptr(unsafe.Pointer(&req))); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &tunBatchReader{fd: fd, device: t.Device}, nil
|
||||
reader := &tunBatchReader{fd: fd, device: t.Device, vnetHdr: t.vnetHdr}
|
||||
if t.vnetHdr {
|
||||
reader.readBuf = make([]byte, t.MaxMTU+virtioNetHdrLen)
|
||||
}
|
||||
return reader, nil
|
||||
}
|
||||
|
||||
// tunBatchReader implements BatchReader for efficient batch packet reading
|
||||
type tunBatchReader struct {
|
||||
fd int
|
||||
device string
|
||||
fd int
|
||||
device string
|
||||
vnetHdr bool
|
||||
readBuf []byte // internal buffer for virtio header handling
|
||||
}
|
||||
|
||||
func (r *tunBatchReader) Read(b []byte) (int, error) {
|
||||
// Choose buffer: use internal buffer for vnetHdr, caller's buffer otherwise
|
||||
readBuf := b
|
||||
if r.vnetHdr {
|
||||
readBuf = r.readBuf
|
||||
}
|
||||
|
||||
// Use poll to wait for data, then read
|
||||
for {
|
||||
n, err := unix.Read(r.fd, b)
|
||||
n, err := unix.Read(r.fd, readBuf)
|
||||
if err == nil {
|
||||
if r.vnetHdr && n > virtioNetHdrLen {
|
||||
packetLen := n - virtioNetHdrLen
|
||||
copy(b, readBuf[virtioNetHdrLen:n])
|
||||
return packetLen, nil
|
||||
}
|
||||
if r.vnetHdr {
|
||||
return 0, nil // No packet data
|
||||
}
|
||||
return n, nil
|
||||
}
|
||||
if err == unix.EAGAIN || err == unix.EWOULDBLOCK {
|
||||
@@ -285,7 +357,24 @@ func (r *tunBatchReader) Read(b []byte) (int, error) {
|
||||
}
|
||||
|
||||
func (r *tunBatchReader) Write(b []byte) (int, error) {
|
||||
return unix.Write(r.fd, b)
|
||||
if !r.vnetHdr {
|
||||
return unix.Write(r.fd, b)
|
||||
}
|
||||
|
||||
// Use writev to prepend virtio header without copying the packet data
|
||||
// Header is all zeros = no GSO, no checksum offload
|
||||
var hdr [virtioNetHdrLen]byte
|
||||
bufs := [][]byte{hdr[:], b}
|
||||
|
||||
n, err := unix.Writev(r.fd, bufs)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
// Return only the packet bytes written (exclude header)
|
||||
if n > virtioNetHdrLen {
|
||||
return n - virtioNetHdrLen, nil
|
||||
}
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
func (r *tunBatchReader) Close() error {
|
||||
@@ -302,10 +391,31 @@ func (r *tunBatchReader) ReadBatch(packets [][]byte, sizes []int) (int, error) {
|
||||
maxPackets = len(sizes)
|
||||
}
|
||||
|
||||
// Choose read buffer based on vnetHdr
|
||||
readBuf := packets[0] // Will be updated in loop for non-vnetHdr
|
||||
if r.vnetHdr {
|
||||
readBuf = r.readBuf
|
||||
}
|
||||
|
||||
for count < maxPackets {
|
||||
n, err := unix.Read(r.fd, packets[count])
|
||||
if !r.vnetHdr {
|
||||
readBuf = packets[count]
|
||||
}
|
||||
|
||||
n, err := unix.Read(r.fd, readBuf)
|
||||
if err == nil && n > 0 {
|
||||
sizes[count] = n
|
||||
if r.vnetHdr {
|
||||
if n > virtioNetHdrLen {
|
||||
packetLen := n - virtioNetHdrLen
|
||||
copy(packets[count], readBuf[virtioNetHdrLen:n])
|
||||
sizes[count] = packetLen
|
||||
} else {
|
||||
// Malformed packet (no data after header), skip
|
||||
continue
|
||||
}
|
||||
} else {
|
||||
sizes[count] = n
|
||||
}
|
||||
count++
|
||||
continue
|
||||
}
|
||||
@@ -353,6 +463,27 @@ func (t *tun) RoutesFor(ip netip.Addr) routing.Gateways {
|
||||
}
|
||||
|
||||
func (t *tun) Write(b []byte) (int, error) {
|
||||
if !t.vnetHdr {
|
||||
return t.writeSimple(b)
|
||||
}
|
||||
|
||||
// Use writev to prepend virtio header without copying the packet data
|
||||
// Header is all zeros = no GSO, no checksum offload
|
||||
var hdr [virtioNetHdrLen]byte
|
||||
bufs := [][]byte{hdr[:], b}
|
||||
|
||||
n, err := unix.Writev(t.fd, bufs)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
// Return only the packet bytes written (exclude header)
|
||||
if n > virtioNetHdrLen {
|
||||
return n - virtioNetHdrLen, nil
|
||||
}
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
func (t *tun) writeSimple(b []byte) (int, error) {
|
||||
var nn int
|
||||
maximum := len(b)
|
||||
|
||||
@@ -389,21 +520,64 @@ func (t *tun) EnableBatchReading() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Read overrides the default Read to handle non-blocking mode
|
||||
// Read overrides the default Read to handle non-blocking mode and virtio headers
|
||||
func (t *tun) Read(b []byte) (int, error) {
|
||||
if !t.vnetHdr {
|
||||
return t.readSimple(b)
|
||||
}
|
||||
|
||||
// With VNET_HDR, read into internal buffer (which has space for header)
|
||||
// then copy packet data to caller's buffer
|
||||
if !t.nonBlocking {
|
||||
// Use the embedded ReadWriteCloser for blocking reads
|
||||
return t.ReadWriteCloser.Read(b)
|
||||
n, err := t.ReadWriteCloser.Read(t.readBuf)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
if n <= virtioNetHdrLen {
|
||||
return 0, nil // No packet data
|
||||
}
|
||||
packetLen := n - virtioNetHdrLen
|
||||
copy(b, t.readBuf[virtioNetHdrLen:n])
|
||||
return packetLen, nil
|
||||
}
|
||||
|
||||
// Non-blocking read with poll
|
||||
for {
|
||||
n, err := unix.Read(t.fd, t.readBuf)
|
||||
if err == nil {
|
||||
if n <= virtioNetHdrLen {
|
||||
return 0, nil // No packet data
|
||||
}
|
||||
packetLen := n - virtioNetHdrLen
|
||||
copy(b, t.readBuf[virtioNetHdrLen:n])
|
||||
return packetLen, nil
|
||||
}
|
||||
if err == unix.EAGAIN || err == unix.EWOULDBLOCK {
|
||||
pfds := []unix.PollFd{{Fd: int32(t.fd), Events: unix.POLLIN}}
|
||||
_, err = unix.Poll(pfds, -1)
|
||||
if err != nil {
|
||||
if err == unix.EINTR {
|
||||
continue
|
||||
}
|
||||
return 0, err
|
||||
}
|
||||
continue
|
||||
}
|
||||
return n, err
|
||||
}
|
||||
}
|
||||
|
||||
func (t *tun) readSimple(b []byte) (int, error) {
|
||||
if !t.nonBlocking {
|
||||
return t.ReadWriteCloser.Read(b)
|
||||
}
|
||||
|
||||
for {
|
||||
n, err := unix.Read(t.fd, b)
|
||||
if err == nil {
|
||||
return n, nil
|
||||
}
|
||||
if err == unix.EAGAIN || err == unix.EWOULDBLOCK {
|
||||
// Wait for data
|
||||
pfds := []unix.PollFd{{Fd: int32(t.fd), Events: unix.POLLIN}}
|
||||
_, err = unix.Poll(pfds, -1)
|
||||
if err != nil {
|
||||
@@ -423,7 +597,7 @@ func (t *tun) Read(b []byte) (int, error) {
|
||||
func (t *tun) ReadBatch(packets [][]byte, sizes []int) (int, error) {
|
||||
if !t.nonBlocking {
|
||||
// Fallback to single read if non-blocking not enabled
|
||||
n, err := t.ReadWriteCloser.Read(packets[0])
|
||||
n, err := t.Read(packets[0])
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
@@ -437,10 +611,33 @@ func (t *tun) ReadBatch(packets [][]byte, sizes []int) (int, error) {
|
||||
maxPackets = len(sizes)
|
||||
}
|
||||
|
||||
// Choose read buffer based on vnetHdr
|
||||
// With vnetHdr, we need to read into internal buffer (has space for header)
|
||||
// then copy packet data to caller's buffer
|
||||
readBuf := packets[0] // Will be updated in the loop
|
||||
if t.vnetHdr {
|
||||
readBuf = t.readBuf
|
||||
}
|
||||
|
||||
for count < maxPackets {
|
||||
n, err := unix.Read(t.fd, packets[count])
|
||||
if !t.vnetHdr {
|
||||
readBuf = packets[count]
|
||||
}
|
||||
|
||||
n, err := unix.Read(t.fd, readBuf)
|
||||
if err == nil && n > 0 {
|
||||
sizes[count] = n
|
||||
if t.vnetHdr {
|
||||
if n > virtioNetHdrLen {
|
||||
packetLen := n - virtioNetHdrLen
|
||||
copy(packets[count], readBuf[virtioNetHdrLen:n])
|
||||
sizes[count] = packetLen
|
||||
} else {
|
||||
// Malformed packet (no data after header), skip
|
||||
continue
|
||||
}
|
||||
} else {
|
||||
sizes[count] = n
|
||||
}
|
||||
count++
|
||||
continue
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user