claude does TUN virtio header support

This commit is contained in:
Jay Wren
2026-02-04 13:13:41 -05:00
parent 030b7e2763
commit ef1739bec4

View File

@@ -24,6 +24,11 @@ import (
"golang.org/x/sys/unix" "golang.org/x/sys/unix"
) )
const (
// virtioNetHdrLen is the length of virtio_net_hdr (without mergeable buffers)
virtioNetHdrLen = 10
)
type tun struct { type tun struct {
io.ReadWriteCloser io.ReadWriteCloser
fd int fd int
@@ -35,6 +40,12 @@ type tun struct {
deviceIndex int deviceIndex int
ioctlFd uintptr ioctlFd uintptr
nonBlocking bool // true if fd is in non-blocking mode nonBlocking bool // true if fd is in non-blocking mode
vnetHdr bool // true if IFF_VNET_HDR is enabled on the TUN device
// readBuf is used when vnetHdr is enabled to read the full packet+header
// before stripping the header. This is needed because caller-provided
// buffers are sized for MTU but kernel writes MTU+10 with virtio header.
readBuf []byte
Routes atomic.Pointer[[]Route] Routes atomic.Pointer[[]Route]
routeTree atomic.Pointer[bart.Table[routing.Gateways]] routeTree atomic.Pointer[bart.Table[routing.Gateways]]
@@ -54,6 +65,23 @@ func (t *tun) Networks() []netip.Prefix {
return t.vpnNetworks return t.vpnNetworks
} }
// tunVnetHdrSupported checks if the kernel supports IFF_VNET_HDR on TUN devices
func tunVnetHdrSupported() bool {
fd, err := unix.Open("/dev/net/tun", unix.O_RDONLY, 0)
if err != nil {
return false
}
defer unix.Close(fd)
var features uint32
err = ioctl(uintptr(fd), uintptr(unix.TUNGETFEATURES), uintptr(unsafe.Pointer(&features)))
if err != nil {
return false
}
return features&unix.IFF_VNET_HDR != 0
}
type ifReq struct { type ifReq struct {
Name [16]byte Name [16]byte
Flags uint16 Flags uint16
@@ -108,11 +136,18 @@ func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, multiqueu
} }
} }
// Check if VNET_HDR is supported before trying to use it
useVnetHdr := tunVnetHdrSupported()
var req ifReq var req ifReq
req.Flags = uint16(unix.IFF_TUN | unix.IFF_NO_PI) req.Flags = uint16(unix.IFF_TUN | unix.IFF_NO_PI)
if multiqueue { if multiqueue {
req.Flags |= unix.IFF_MULTI_QUEUE req.Flags |= unix.IFF_MULTI_QUEUE
} }
if useVnetHdr {
req.Flags |= unix.IFF_VNET_HDR
}
nameStr := c.GetString("tun.dev", "") nameStr := c.GetString("tun.dev", "")
copy(req.Name[:], nameStr) copy(req.Name[:], nameStr)
if err = ioctl(uintptr(fd), uintptr(unix.TUNSETIFF), uintptr(unsafe.Pointer(&req))); err != nil { if err = ioctl(uintptr(fd), uintptr(unix.TUNSETIFF), uintptr(unsafe.Pointer(&req))); err != nil {
@@ -123,6 +158,13 @@ func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, multiqueu
} }
name := strings.Trim(string(req.Name[:]), "\x00") name := strings.Trim(string(req.Name[:]), "\x00")
// Track if VNET_HDR is in use
// Note: We don't call TUNSETOFFLOAD - just handle the headers manually
vnetHdrEnabled := useVnetHdr
if vnetHdrEnabled {
l.Info("TUN VNET_HDR enabled")
}
file := os.NewFile(uintptr(fd), "/dev/net/tun") file := os.NewFile(uintptr(fd), "/dev/net/tun")
t, err := newTunGeneric(c, l, file, vpnNetworks) t, err := newTunGeneric(c, l, file, vpnNetworks)
if err != nil { if err != nil {
@@ -130,6 +172,13 @@ func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, multiqueu
} }
t.Device = name t.Device = name
t.vnetHdr = vnetHdrEnabled
// Allocate read buffer for virtio header handling
// Buffer needs to be large enough for virtio header + max packet
if t.vnetHdr {
t.readBuf = make([]byte, t.MaxMTU+virtioNetHdrLen)
}
return t, nil return t, nil
} }
@@ -247,25 +296,48 @@ func (t *tun) NewMultiQueueReader() (io.ReadWriteCloser, error) {
var req ifReq var req ifReq
req.Flags = uint16(unix.IFF_TUN | unix.IFF_NO_PI | unix.IFF_MULTI_QUEUE) req.Flags = uint16(unix.IFF_TUN | unix.IFF_NO_PI | unix.IFF_MULTI_QUEUE)
if t.vnetHdr {
req.Flags |= unix.IFF_VNET_HDR
}
copy(req.Name[:], t.Device) copy(req.Name[:], t.Device)
if err = ioctl(uintptr(fd), uintptr(unix.TUNSETIFF), uintptr(unsafe.Pointer(&req))); err != nil { if err = ioctl(uintptr(fd), uintptr(unix.TUNSETIFF), uintptr(unsafe.Pointer(&req))); err != nil {
return nil, err return nil, err
} }
return &tunBatchReader{fd: fd, device: t.Device}, nil reader := &tunBatchReader{fd: fd, device: t.Device, vnetHdr: t.vnetHdr}
if t.vnetHdr {
reader.readBuf = make([]byte, t.MaxMTU+virtioNetHdrLen)
}
return reader, nil
} }
// tunBatchReader implements BatchReader for efficient batch packet reading // tunBatchReader implements BatchReader for efficient batch packet reading
type tunBatchReader struct { type tunBatchReader struct {
fd int fd int
device string device string
vnetHdr bool
readBuf []byte // internal buffer for virtio header handling
} }
func (r *tunBatchReader) Read(b []byte) (int, error) { func (r *tunBatchReader) Read(b []byte) (int, error) {
// Choose buffer: use internal buffer for vnetHdr, caller's buffer otherwise
readBuf := b
if r.vnetHdr {
readBuf = r.readBuf
}
// Use poll to wait for data, then read // Use poll to wait for data, then read
for { for {
n, err := unix.Read(r.fd, b) n, err := unix.Read(r.fd, readBuf)
if err == nil { if err == nil {
if r.vnetHdr && n > virtioNetHdrLen {
packetLen := n - virtioNetHdrLen
copy(b, readBuf[virtioNetHdrLen:n])
return packetLen, nil
}
if r.vnetHdr {
return 0, nil // No packet data
}
return n, nil return n, nil
} }
if err == unix.EAGAIN || err == unix.EWOULDBLOCK { if err == unix.EAGAIN || err == unix.EWOULDBLOCK {
@@ -285,7 +357,24 @@ func (r *tunBatchReader) Read(b []byte) (int, error) {
} }
func (r *tunBatchReader) Write(b []byte) (int, error) { func (r *tunBatchReader) Write(b []byte) (int, error) {
return unix.Write(r.fd, b) if !r.vnetHdr {
return unix.Write(r.fd, b)
}
// Use writev to prepend virtio header without copying the packet data
// Header is all zeros = no GSO, no checksum offload
var hdr [virtioNetHdrLen]byte
bufs := [][]byte{hdr[:], b}
n, err := unix.Writev(r.fd, bufs)
if err != nil {
return 0, err
}
// Return only the packet bytes written (exclude header)
if n > virtioNetHdrLen {
return n - virtioNetHdrLen, nil
}
return 0, nil
} }
func (r *tunBatchReader) Close() error { func (r *tunBatchReader) Close() error {
@@ -302,10 +391,31 @@ func (r *tunBatchReader) ReadBatch(packets [][]byte, sizes []int) (int, error) {
maxPackets = len(sizes) maxPackets = len(sizes)
} }
// Choose read buffer based on vnetHdr
readBuf := packets[0] // Will be updated in loop for non-vnetHdr
if r.vnetHdr {
readBuf = r.readBuf
}
for count < maxPackets { for count < maxPackets {
n, err := unix.Read(r.fd, packets[count]) if !r.vnetHdr {
readBuf = packets[count]
}
n, err := unix.Read(r.fd, readBuf)
if err == nil && n > 0 { if err == nil && n > 0 {
sizes[count] = n if r.vnetHdr {
if n > virtioNetHdrLen {
packetLen := n - virtioNetHdrLen
copy(packets[count], readBuf[virtioNetHdrLen:n])
sizes[count] = packetLen
} else {
// Malformed packet (no data after header), skip
continue
}
} else {
sizes[count] = n
}
count++ count++
continue continue
} }
@@ -353,6 +463,27 @@ func (t *tun) RoutesFor(ip netip.Addr) routing.Gateways {
} }
func (t *tun) Write(b []byte) (int, error) { func (t *tun) Write(b []byte) (int, error) {
if !t.vnetHdr {
return t.writeSimple(b)
}
// Use writev to prepend virtio header without copying the packet data
// Header is all zeros = no GSO, no checksum offload
var hdr [virtioNetHdrLen]byte
bufs := [][]byte{hdr[:], b}
n, err := unix.Writev(t.fd, bufs)
if err != nil {
return 0, err
}
// Return only the packet bytes written (exclude header)
if n > virtioNetHdrLen {
return n - virtioNetHdrLen, nil
}
return 0, nil
}
func (t *tun) writeSimple(b []byte) (int, error) {
var nn int var nn int
maximum := len(b) maximum := len(b)
@@ -389,21 +520,64 @@ func (t *tun) EnableBatchReading() error {
return nil return nil
} }
// Read overrides the default Read to handle non-blocking mode // Read overrides the default Read to handle non-blocking mode and virtio headers
func (t *tun) Read(b []byte) (int, error) { func (t *tun) Read(b []byte) (int, error) {
if !t.vnetHdr {
return t.readSimple(b)
}
// With VNET_HDR, read into internal buffer (which has space for header)
// then copy packet data to caller's buffer
if !t.nonBlocking { if !t.nonBlocking {
// Use the embedded ReadWriteCloser for blocking reads n, err := t.ReadWriteCloser.Read(t.readBuf)
return t.ReadWriteCloser.Read(b) if err != nil {
return 0, err
}
if n <= virtioNetHdrLen {
return 0, nil // No packet data
}
packetLen := n - virtioNetHdrLen
copy(b, t.readBuf[virtioNetHdrLen:n])
return packetLen, nil
} }
// Non-blocking read with poll // Non-blocking read with poll
for {
n, err := unix.Read(t.fd, t.readBuf)
if err == nil {
if n <= virtioNetHdrLen {
return 0, nil // No packet data
}
packetLen := n - virtioNetHdrLen
copy(b, t.readBuf[virtioNetHdrLen:n])
return packetLen, nil
}
if err == unix.EAGAIN || err == unix.EWOULDBLOCK {
pfds := []unix.PollFd{{Fd: int32(t.fd), Events: unix.POLLIN}}
_, err = unix.Poll(pfds, -1)
if err != nil {
if err == unix.EINTR {
continue
}
return 0, err
}
continue
}
return n, err
}
}
func (t *tun) readSimple(b []byte) (int, error) {
if !t.nonBlocking {
return t.ReadWriteCloser.Read(b)
}
for { for {
n, err := unix.Read(t.fd, b) n, err := unix.Read(t.fd, b)
if err == nil { if err == nil {
return n, nil return n, nil
} }
if err == unix.EAGAIN || err == unix.EWOULDBLOCK { if err == unix.EAGAIN || err == unix.EWOULDBLOCK {
// Wait for data
pfds := []unix.PollFd{{Fd: int32(t.fd), Events: unix.POLLIN}} pfds := []unix.PollFd{{Fd: int32(t.fd), Events: unix.POLLIN}}
_, err = unix.Poll(pfds, -1) _, err = unix.Poll(pfds, -1)
if err != nil { if err != nil {
@@ -423,7 +597,7 @@ func (t *tun) Read(b []byte) (int, error) {
func (t *tun) ReadBatch(packets [][]byte, sizes []int) (int, error) { func (t *tun) ReadBatch(packets [][]byte, sizes []int) (int, error) {
if !t.nonBlocking { if !t.nonBlocking {
// Fallback to single read if non-blocking not enabled // Fallback to single read if non-blocking not enabled
n, err := t.ReadWriteCloser.Read(packets[0]) n, err := t.Read(packets[0])
if err != nil { if err != nil {
return 0, err return 0, err
} }
@@ -437,10 +611,33 @@ func (t *tun) ReadBatch(packets [][]byte, sizes []int) (int, error) {
maxPackets = len(sizes) maxPackets = len(sizes)
} }
// Choose read buffer based on vnetHdr
// With vnetHdr, we need to read into internal buffer (has space for header)
// then copy packet data to caller's buffer
readBuf := packets[0] // Will be updated in the loop
if t.vnetHdr {
readBuf = t.readBuf
}
for count < maxPackets { for count < maxPackets {
n, err := unix.Read(t.fd, packets[count]) if !t.vnetHdr {
readBuf = packets[count]
}
n, err := unix.Read(t.fd, readBuf)
if err == nil && n > 0 { if err == nil && n > 0 {
sizes[count] = n if t.vnetHdr {
if n > virtioNetHdrLen {
packetLen := n - virtioNetHdrLen
copy(packets[count], readBuf[virtioNetHdrLen:n])
sizes[count] = packetLen
} else {
// Malformed packet (no data after header), skip
continue
}
} else {
sizes[count] = n
}
count++ count++
continue continue
} }