Make sure standard read/write path doesn't heap allocate

This commit is contained in:
Nate Brown
2026-05-04 15:14:09 -05:00
parent 3954d9af34
commit 32669e9568
2 changed files with 27 additions and 24 deletions

View File

@@ -24,7 +24,6 @@ import (
type tun struct { type tun struct {
f *os.File f *os.File
rc syscall.RawConn
Device string Device string
vpnNetworks []netip.Prefix vpnNetworks []netip.Prefix
DefaultMTU int DefaultMTU int
@@ -121,15 +120,8 @@ func newTun(c *config.C, l *slog.Logger, vpnNetworks []netip.Prefix, _ bool) (*t
return nil, fmt.Errorf("SetNonblock: %v", err) return nil, fmt.Errorf("SetNonblock: %v", err)
} }
f := os.NewFile(uintptr(fd), "")
rc, err := f.SyscallConn()
if err != nil {
return nil, fmt.Errorf("failed to get syscall conn for tun: %w", err)
}
t := &tun{ t := &tun{
f: f, f: os.NewFile(uintptr(fd), ""),
rc: rc,
Device: name, Device: name,
vpnNetworks: vpnNetworks, vpnNetworks: vpnNetworks,
DefaultMTU: c.GetInt("tun.mtu", DefaultMTU), DefaultMTU: c.GetInt("tun.mtu", DefaultMTU),
@@ -509,9 +501,15 @@ func delRoute(prefix netip.Prefix, gateway netroute.Addr) error {
// Read pulls one IP packet off the utun device. // Read pulls one IP packet off the utun device.
func (t *tun) Read(to []byte) (int, error) { func (t *tun) Read(to []byte) (int, error) {
// Grab rc as a local so the compiler can devirtualize the call and keep the closure on the stack.
rc, err := t.f.SyscallConn()
if err != nil {
return 0, err
}
var errno syscall.Errno var errno syscall.Errno
var n uintptr var n uintptr
err := t.rc.Read(func(fd uintptr) bool { err = rc.Read(func(fd uintptr) bool {
var head [4]byte var head [4]byte
iovecs := [2]syscall.Iovec{ iovecs := [2]syscall.Iovec{
{Base: &head[0], Len: 4}, {Base: &head[0], Len: 4},
@@ -529,10 +527,10 @@ func (t *tun) Read(to []byte) (int, error) {
if err == syscall.EBADF || err.Error() == "use of closed file" { if err == syscall.EBADF || err.Error() == "use of closed file" {
return 0, os.ErrClosed return 0, os.ErrClosed
} }
return 0, fmt.Errorf("failed to make read call for tun: %w", err) return 0, err
} }
if errno != 0 { if errno != 0 {
return 0, fmt.Errorf("failed to make inner read call for tun: %w", errno) return 0, errno
} }
bytesRead := int(n) bytesRead := int(n)
@@ -559,9 +557,14 @@ func (t *tun) Write(from []byte) (int, error) {
return 0, fmt.Errorf("unable to determine IP version from packet") return 0, fmt.Errorf("unable to determine IP version from packet")
} }
rc, err := t.f.SyscallConn()
if err != nil {
return 0, err
}
var errno syscall.Errno var errno syscall.Errno
var n uintptr var n uintptr
err := t.rc.Write(func(fd uintptr) bool { err = rc.Write(func(fd uintptr) bool {
iovecs := [2]syscall.Iovec{ iovecs := [2]syscall.Iovec{
{Base: &head[0], Len: 4}, {Base: &head[0], Len: 4},
{Base: &from[0], Len: uint64(len(from))}, {Base: &from[0], Len: uint64(len(from))},

View File

@@ -57,7 +57,6 @@ type tun struct {
l *slog.Logger l *slog.Logger
f *os.File f *os.File
fd int fd int
rc syscall.RawConn
// readBuf is the per-tun read scratch reused across calls so we don't allocate per Read. // readBuf is the per-tun read scratch reused across calls so we don't allocate per Read.
// OpenBSD's pinsyscall protection forbids raw syscall.Syscall(SYS_READV, ...) and stdlib doesn't keep syscall.readv // OpenBSD's pinsyscall protection forbids raw syscall.Syscall(SYS_READV, ...) and stdlib doesn't keep syscall.readv
@@ -95,16 +94,9 @@ func newTun(c *config.C, l *slog.Logger, vpnNetworks []netip.Prefix, _ bool) (*t
} }
mtu := c.GetInt("tun.mtu", DefaultMTU) mtu := c.GetInt("tun.mtu", DefaultMTU)
f := os.NewFile(uintptr(fd), "")
rc, err := f.SyscallConn()
if err != nil {
return nil, fmt.Errorf("failed to get syscall conn for tun: %w", err)
}
t := &tun{ t := &tun{
f: f, f: os.NewFile(uintptr(fd), ""),
fd: fd, fd: fd,
rc: rc,
Device: deviceName, Device: deviceName,
vpnNetworks: vpnNetworks, vpnNetworks: vpnNetworks,
MTU: mtu, MTU: mtu,
@@ -186,15 +178,23 @@ func (t *tun) Write(from []byte) (int, error) {
return 0, fmt.Errorf("unable to determine IP version from packet") return 0, fmt.Errorf("unable to determine IP version from packet")
} }
// Grab rc as a local so the compiler can devirtualize the call and keep the closure on the stack.
rc, err := t.f.SyscallConn()
if err != nil {
return 0, err
}
var n uintptr var n uintptr
var callErr error var callErr error
err := t.rc.Write(func(fd uintptr) bool { err = rc.Write(func(fd uintptr) bool {
iovecs := []syscall.Iovec{ iovecs := []syscall.Iovec{
{Base: &head[0], Len: 4}, {Base: &head[0], Len: 4},
{Base: &from[0], Len: uint64(len(from))}, {Base: &from[0], Len: uint64(len(from))},
} }
n, callErr = tunWritev(int(fd), iovecs) n, callErr = tunWritev(int(fd), iovecs)
if errors.Is(callErr, syscall.EAGAIN) || errors.Is(callErr, syscall.EWOULDBLOCK) || errors.Is(callErr, syscall.EINTR) { // Type-assert to syscall.Errno so the EAGAIN/EWOULDBLOCK/EINTR check doesn't box the errno
// constants into error interfaces on every call.
if errno, ok := callErr.(syscall.Errno); ok && errno.Temporary() {
return false return false
} }
return true return true