From a46222a2fe6ce143eb69b0ba22846778b1304974 Mon Sep 17 00:00:00 2001 From: Nate Brown Date: Wed, 16 Apr 2025 17:11:11 -0500 Subject: [PATCH] Try the timeout --- interface.go | 2 + overlay/tun_linux.go | 45 +++--- udp/udp_linux.go | 347 +++++++++++++++++++++---------------------- 3 files changed, 196 insertions(+), 198 deletions(-) diff --git a/interface.go b/interface.go index 5844986..79e642e 100644 --- a/interface.go +++ b/interface.go @@ -6,6 +6,7 @@ import ( "fmt" "io" "net/netip" + "runtime" "sync" "sync/atomic" "time" @@ -258,6 +259,7 @@ func (f *Interface) run() (func(), error) { } func (f *Interface) listenOut(i int) { + runtime.LockOSThread() var li udp.Conn if i > 0 { li = f.writers[i] diff --git a/overlay/tun_linux.go b/overlay/tun_linux.go index 84632e6..7d19c85 100644 --- a/overlay/tun_linux.go +++ b/overlay/tun_linux.go @@ -65,11 +65,6 @@ type ifreqQLEN struct { } func newTunFromFd(c *config.C, l *logrus.Logger, deviceFd int, vpnNetworks []netip.Prefix) (*tun, error) { - err := unix.SetNonblock(deviceFd, true) - if err != nil { - return nil, err - } - file := os.NewFile(uintptr(deviceFd), "/dev/net/tun") t, err := newTunGeneric(c, l, file, vpnNetworks) @@ -116,11 +111,6 @@ func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, multiqueu } name := strings.Trim(string(req.Name[:]), "\x00") - err = unix.SetNonblock(fd, true) - if err != nil { - return nil, err - } - file := os.NewFile(uintptr(fd), "/dev/net/tun") t, err := newTunGeneric(c, l, file, vpnNetworks) if err != nil { @@ -142,12 +132,7 @@ func newTunGeneric(c *config.C, l *logrus.Logger, file *os.File, vpnNetworks []n l: l, } - err := unix.SetNonblock(t.fd, true) - if err != nil { - return nil, err - } - - err = t.reload(c, true) + err := t.reload(c, true) if err != nil { return nil, err } @@ -242,11 +227,6 @@ func (t *tun) NewMultiQueueReader() (io.ReadWriteCloser, error) { return nil, err } - err = unix.SetNonblock(fd, true) - if err != nil { - return nil, err - } - file := os.NewFile(uintptr(fd), "/dev/net/tun") return file, nil @@ -257,6 +237,29 @@ func (t *tun) RoutesFor(ip netip.Addr) routing.Gateways { return r } +func (t *tun) Write(b []byte) (int, error) { + var nn int + maximum := len(b) + + for { + n, err := unix.Write(t.fd, b[nn:maximum]) + if n > 0 { + nn += n + } + if nn == len(b) { + return nn, err + } + + if err != nil { + return nn, err + } + + if n == 0 { + return nn, io.ErrUnexpectedEOF + } + } +} + func (t *tun) deviceBytes() (o [16]byte) { for i, c := range t.Device { o[i] = byte(c) diff --git a/udp/udp_linux.go b/udp/udp_linux.go index 2a747e2..4bfcc3b 100644 --- a/udp/udp_linux.go +++ b/udp/udp_linux.go @@ -4,13 +4,12 @@ package udp import ( - "context" "encoding/binary" "fmt" "net" "net/netip" - "strconv" "syscall" + "time" "unsafe" "github.com/rcrowley/go-metrics" @@ -19,54 +18,58 @@ import ( "golang.org/x/sys/unix" ) +var readTimeout = unix.NsecToTimeval(int64(time.Millisecond * 500)) + type StdConn struct { - c *net.UDPConn - rc syscall.RawConn + sysFd int isV4 bool l *logrus.Logger batch int - - // cached fields for reading packets - msgs []rawMessage - buffers [][]byte - names [][]byte - n uintptr - err error } func NewListener(l *logrus.Logger, ip netip.Addr, port int, multi bool, batch int) (Conn, error) { - lc := net.ListenConfig{ - Control: func(network, address string, c syscall.RawConn) error { - if multi { - var err error - oErr := c.Control(func(fd uintptr) { - err = unix.SetsockoptInt(int(fd), unix.SOL_SOCKET, unix.SO_REUSEPORT, 1) - }) - if oErr != nil { - return fmt.Errorf("error while setting SO_REUSEPORT: %w", oErr) - } - if err != nil { - return fmt.Errorf("unable to set SO_REUSEPORT: %w", err) - } - } - - return nil - }, + af := unix.AF_INET6 + if ip.Is4() { + af = unix.AF_INET } + syscall.ForkLock.RLock() + fd, err := unix.Socket(af, unix.SOCK_DGRAM, unix.IPPROTO_UDP) + if err == nil { + unix.CloseOnExec(fd) + } + syscall.ForkLock.RUnlock() - c, err := lc.ListenPacket(context.Background(), "udp", net.JoinHostPort(ip.String(), strconv.Itoa(port))) if err != nil { - return nil, fmt.Errorf("unable to open socket: %w", err) + unix.Close(fd) + return nil, fmt.Errorf("unable to open socket: %s", err) } - uc := c.(*net.UDPConn) - rc, err := uc.SyscallConn() - if err != nil { - _ = c.Close() - return nil, fmt.Errorf("unable to open sysfd: %w", err) + if multi { + if err = unix.SetsockoptInt(fd, unix.SOL_SOCKET, unix.SO_REUSEPORT, 1); err != nil { + return nil, fmt.Errorf("unable to set SO_REUSEPORT: %s", err) + } } - return &StdConn{c: uc, rc: rc, isV4: ip.Is4(), l: l, batch: batch}, err + // Set a read timeout + if err = unix.SetsockoptTimeval(fd, unix.SOL_SOCKET, unix.SO_RCVTIMEO, &readTimeout); err != nil { + return nil, fmt.Errorf("unable to set SO_RCVTIMEO: %s", err) + } + + var sa unix.Sockaddr + if ip.Is4() { + sa4 := &unix.SockaddrInet4{Port: port} + sa4.Addr = ip.As4() + sa = sa4 + } else { + sa6 := &unix.SockaddrInet6{Port: port} + sa6.Addr = ip.As16() + sa = sa6 + } + if err = unix.Bind(fd, sa); err != nil { + return nil, fmt.Errorf("unable to bind to socket: %s", err) + } + + return &StdConn{sysFd: fd, isV4: ip.Is4(), l: l, batch: batch}, err } func (u *StdConn) Rebind() error { @@ -74,179 +77,181 @@ func (u *StdConn) Rebind() error { } func (u *StdConn) SetRecvBuffer(n int) error { - var err error - oErr := u.rc.Control(func(fd uintptr) { - err = unix.SetsockoptInt(int(fd), unix.SOL_SOCKET, unix.SO_RCVBUFFORCE, n) - }) - if oErr != nil { - return oErr - } - return err + return unix.SetsockoptInt(u.sysFd, unix.SOL_SOCKET, unix.SO_RCVBUFFORCE, n) } func (u *StdConn) SetSendBuffer(n int) error { - var err error - oErr := u.rc.Control(func(fd uintptr) { - err = unix.SetsockoptInt(int(fd), unix.SOL_SOCKET, unix.SO_SNDBUFFORCE, n) - }) - if oErr != nil { - return oErr - } - return err + return unix.SetsockoptInt(u.sysFd, unix.SOL_SOCKET, unix.SO_SNDBUFFORCE, n) } func (u *StdConn) SetSoMark(mark int) error { - var err error - oErr := u.rc.Control(func(fd uintptr) { - err = unix.SetsockoptInt(int(fd), unix.SOL_SOCKET, unix.SO_MARK, mark) - }) - if oErr != nil { - return oErr - } - return err + return unix.SetsockoptInt(u.sysFd, unix.SOL_SOCKET, unix.SO_MARK, mark) } func (u *StdConn) GetRecvBuffer() (int, error) { - var err error - var n int - oErr := u.rc.Control(func(fd uintptr) { - n, err = unix.GetsockoptInt(int(fd), unix.SOL_SOCKET, unix.SO_RCVBUF) - }) - if oErr != nil { - return n, oErr - } - return n, err + return unix.GetsockoptInt(int(u.sysFd), unix.SOL_SOCKET, unix.SO_RCVBUF) } func (u *StdConn) GetSendBuffer() (int, error) { - var err error - var n int - oErr := u.rc.Control(func(fd uintptr) { - n, err = unix.GetsockoptInt(int(fd), unix.SOL_SOCKET, unix.SO_SNDBUF) - }) - if oErr != nil { - return n, oErr - } - return n, err + return unix.GetsockoptInt(int(u.sysFd), unix.SOL_SOCKET, unix.SO_SNDBUF) } func (u *StdConn) GetSoMark() (int, error) { - var err error - var n int - oErr := u.rc.Control(func(fd uintptr) { - n, err = unix.GetsockoptInt(int(fd), unix.SOL_SOCKET, unix.SO_MARK) - }) - if oErr != nil { - return n, oErr - } - return n, err + return unix.GetsockoptInt(int(u.sysFd), unix.SOL_SOCKET, unix.SO_MARK) } func (u *StdConn) LocalAddr() (netip.AddrPort, error) { - sa := u.c.LocalAddr() - return netip.ParseAddrPort(sa.String()) + sa, err := unix.Getsockname(u.sysFd) + if err != nil { + return netip.AddrPort{}, err + } + + switch sa := sa.(type) { + case *unix.SockaddrInet4: + return netip.AddrPortFrom(netip.AddrFrom4(sa.Addr), uint16(sa.Port)), nil + + case *unix.SockaddrInet6: + return netip.AddrPortFrom(netip.AddrFrom16(sa.Addr), uint16(sa.Port)), nil + + default: + return netip.AddrPort{}, fmt.Errorf("unsupported sock type: %T", sa) + } } func (u *StdConn) ListenOut(r EncReader) { var ip netip.Addr - u.msgs, u.buffers, u.names = u.PrepareRawMessages(u.batch) + msgs, buffers, names := u.PrepareRawMessages(u.batch) read := u.ReadMulti if u.batch == 1 { read = u.ReadSingle } for { - read() - if u.err != nil { - //TODO: remove logging, return error - u.l.WithError(u.err).Error("udp socket is closed, exiting read loop") + n, err := read(msgs) + if err != nil { + u.l.WithError(err).Debug("udp socket is closed, exiting read loop") return } - for i := 0; i < int(u.n); i++ { + for i := 0; i < n; i++ { // Its ok to skip the ok check here, the slicing is the only error that can occur and it will panic if u.isV4 { - ip, _ = netip.AddrFromSlice(u.names[i][4:8]) + ip, _ = netip.AddrFromSlice(names[i][4:8]) } else { - ip, _ = netip.AddrFromSlice(u.names[i][8:24]) + ip, _ = netip.AddrFromSlice(names[i][8:24]) } - //u.l.Error("GOT A PACKET", msgs[i].Len) - r(netip.AddrPortFrom(ip.Unmap(), binary.BigEndian.Uint16(u.names[i][2:4])), u.buffers[i][:u.msgs[i].Len]) + r(netip.AddrPortFrom(ip.Unmap(), binary.BigEndian.Uint16(names[i][2:4])), buffers[i][:msgs[i].Len]) } } } -func (u *StdConn) ReadSingle() { - err := u.rc.Read(u.innerReadSingle) - if u.err == nil && err != nil { - u.err = err - u.n = 0 - return +func (u *StdConn) ReadSingle(msgs []rawMessage) (int, error) { + for { + n, _, err := unix.Syscall6( + unix.SYS_RECVMSG, + uintptr(u.sysFd), + uintptr(unsafe.Pointer(&(msgs[0].Hdr))), + 0, + 0, + 0, + 0, + ) + + if err != 0 { + if err == unix.EAGAIN || err == unix.EINTR { + continue + } + return 0, &net.OpError{Op: "recvmsg", Err: err} + } + + msgs[0].Len = uint32(n) + return 1, nil } } -func (u *StdConn) innerReadSingle(fd uintptr) bool { - in, _, err := unix.Syscall6( - unix.SYS_RECVMSG, - fd, - uintptr(unsafe.Pointer(&(u.msgs[0].Hdr))), - 0, 0, 0, 0, - ) +func (u *StdConn) ReadMulti(msgs []rawMessage) (int, error) { + for { + n, _, err := unix.Syscall6( + unix.SYS_RECVMMSG, + uintptr(u.sysFd), + uintptr(unsafe.Pointer(&msgs[0])), + uintptr(len(msgs)), + unix.MSG_WAITFORONE, + 0, + 0, + ) - if err == syscall.EAGAIN || err == syscall.EINTR { - // Retry read - return false + if err != 0 { + if err == unix.EAGAIN || err == unix.EINTR { + continue + } + return 0, &net.OpError{Op: "recvmmsg", Err: err} + } - } else if err != 0 { - u.l.Errorf("READING FROM UDP SINGLE had an errno %d", err) - u.err = &net.OpError{Op: "recvmsg", Err: err} - u.n = 0 - return true + return int(n), nil } - - u.msgs[0].Len = uint32(in) - u.n = 1 - return true -} - -func (u *StdConn) ReadMulti() { - err := u.rc.Read(u.innerReadMulti) - if u.err == nil && err != nil { - u.err = err - u.n = 0 - return - } -} - -func (u *StdConn) innerReadMulti(fd uintptr) bool { - var err syscall.Errno - u.n, _, err = unix.Syscall6( - unix.SYS_RECVMMSG, - fd, - uintptr(unsafe.Pointer(&u.msgs[0])), - uintptr(len(u.msgs)), - unix.MSG_WAITFORONE, - 0, 0, - ) - - if err == syscall.EAGAIN || err == syscall.EINTR { - // Retry read - return false - - } else if err != 0 { - u.l.Errorf("READING FROM UDP MULTI had an errno %d", err) - u.err = &net.OpError{Op: "recvmmsg", Err: err} - u.n = 0 - return true - } - - return true } func (u *StdConn) WriteTo(b []byte, ip netip.AddrPort) error { - _, err := u.c.WriteToUDPAddrPort(b, ip) - return err + if u.isV4 { + return u.writeTo4(b, ip) + } + return u.writeTo6(b, ip) +} + +func (u *StdConn) writeTo6(b []byte, ip netip.AddrPort) error { + var rsa unix.RawSockaddrInet6 + rsa.Family = unix.AF_INET6 + rsa.Addr = ip.Addr().As16() + binary.BigEndian.PutUint16((*[2]byte)(unsafe.Pointer(&rsa.Port))[:], ip.Port()) + + for { + _, _, err := unix.Syscall6( + unix.SYS_SENDTO, + uintptr(u.sysFd), + uintptr(unsafe.Pointer(&b[0])), + uintptr(len(b)), + uintptr(0), + uintptr(unsafe.Pointer(&rsa)), + uintptr(unix.SizeofSockaddrInet6), + ) + + if err != 0 { + return &net.OpError{Op: "sendto", Err: err} + } + + return nil + } +} + +func (u *StdConn) writeTo4(b []byte, ip netip.AddrPort) error { + if !ip.Addr().Is4() { + return fmt.Errorf("Listener is IPv4, but writing to IPv6 remote") + } + + var rsa unix.RawSockaddrInet4 + rsa.Family = unix.AF_INET + rsa.Addr = ip.Addr().As4() + binary.BigEndian.PutUint16((*[2]byte)(unsafe.Pointer(&rsa.Port))[:], ip.Port()) + + for { + _, _, err := unix.Syscall6( + unix.SYS_SENDTO, + uintptr(u.sysFd), + uintptr(unsafe.Pointer(&b[0])), + uintptr(len(b)), + uintptr(0), + uintptr(unsafe.Pointer(&rsa)), + uintptr(unix.SizeofSockaddrInet4), + ) + + if err != 0 { + return &net.OpError{Op: "sendto", Err: err} + } + + return nil + } } func (u *StdConn) ReloadConfig(c *config.C) { @@ -299,27 +304,15 @@ func (u *StdConn) ReloadConfig(c *config.C) { func (u *StdConn) getMemInfo(meminfo *[unix.SK_MEMINFO_VARS]uint32) error { var vallen uint32 = 4 * unix.SK_MEMINFO_VARS - var err error - oErr := u.rc.Control(func(fd uintptr) { - _, _, err = unix.Syscall6( - unix.SYS_GETSOCKOPT, - fd, - uintptr(unix.SOL_SOCKET), - uintptr(unix.SO_MEMINFO), - uintptr(unsafe.Pointer(meminfo)), - uintptr(unsafe.Pointer(&vallen)), - 0, - ) - }) - if oErr != nil { - return oErr + _, _, err := unix.Syscall6(unix.SYS_GETSOCKOPT, uintptr(u.sysFd), uintptr(unix.SOL_SOCKET), uintptr(unix.SO_MEMINFO), uintptr(unsafe.Pointer(meminfo)), uintptr(unsafe.Pointer(&vallen)), 0) + if err != 0 { + return err } - return err + return nil } func (u *StdConn) Close() error { - err := u.c.Close() - return err + return syscall.Close(u.sysFd) } func NewUDPStatsEmitter(udpConns []Conn) func() {