diff --git a/interface.go b/interface.go index d03e313..5844986 100644 --- a/interface.go +++ b/interface.go @@ -6,7 +6,6 @@ import ( "fmt" "io" "net/netip" - "runtime" "sync" "sync/atomic" "time" @@ -259,8 +258,6 @@ func (f *Interface) run() (func(), error) { } func (f *Interface) listenOut(i int) { - runtime.LockOSThread() - var li udp.Conn if i > 0 { li = f.writers[i] @@ -284,8 +281,6 @@ func (f *Interface) listenOut(i int) { } func (f *Interface) listenIn(reader io.ReadWriteCloser, i int) { - runtime.LockOSThread() - packet := make([]byte, mtu) out := make([]byte, mtu) fwPacket := &firewall.Packet{} diff --git a/udp/udp_linux.go b/udp/udp_linux.go index 661fe54..2a747e2 100644 --- a/udp/udp_linux.go +++ b/udp/udp_linux.go @@ -25,6 +25,13 @@ type StdConn struct { isV4 bool l *logrus.Logger batch int + + // cached fields for reading packets + msgs []rawMessage + buffers [][]byte + names [][]byte + n uintptr + err error } func NewListener(l *logrus.Logger, ip netip.Addr, port int, multi bool, batch int) (Conn, error) { @@ -142,99 +149,101 @@ func (u *StdConn) LocalAddr() (netip.AddrPort, error) { func (u *StdConn) ListenOut(r EncReader) { var ip netip.Addr - var n uintptr - var err error - msgs, buffers, names := u.PrepareRawMessages(u.batch) + + u.msgs, u.buffers, u.names = u.PrepareRawMessages(u.batch) read := u.ReadMulti if u.batch == 1 { read = u.ReadSingle } for { - read(msgs, &n, &err) - if err != nil { - u.l.WithError(err).Error("udp socket is closed, exiting read loop") + read() + if u.err != nil { + //TODO: remove logging, return error + u.l.WithError(u.err).Error("udp socket is closed, exiting read loop") return } - for i := 0; i < int(n); i++ { + for i := 0; i < int(u.n); i++ { // Its ok to skip the ok check here, the slicing is the only error that can occur and it will panic if u.isV4 { - ip, _ = netip.AddrFromSlice(names[i][4:8]) + ip, _ = netip.AddrFromSlice(u.names[i][4:8]) } else { - ip, _ = netip.AddrFromSlice(names[i][8:24]) + ip, _ = netip.AddrFromSlice(u.names[i][8:24]) } //u.l.Error("GOT A PACKET", msgs[i].Len) - r(netip.AddrPortFrom(ip.Unmap(), binary.BigEndian.Uint16(names[i][2:4])), buffers[i][:msgs[i].Len]) + r(netip.AddrPortFrom(ip.Unmap(), binary.BigEndian.Uint16(u.names[i][2:4])), u.buffers[i][:u.msgs[i].Len]) } } } -func (u *StdConn) ReadSingle(msgs []rawMessage, n *uintptr, err *error) { - oErr := u.rc.Read(func(fd uintptr) bool { - in, _, nErr := unix.Syscall6( - unix.SYS_RECVMSG, - fd, - uintptr(unsafe.Pointer(&(msgs[0].Hdr))), - 0, 0, 0, 0, - ) - - if nErr == syscall.EAGAIN || nErr == syscall.EINTR { - // Retry read - return false - - } else if nErr != 0 { - u.l.Errorf("READING FROM UDP SINGLE had an errno %d", nErr) - *err = &net.OpError{Op: "recvmsg", Err: nErr} - *n = 0 - return true - } - - msgs[0].Len = uint32(in) - *n = 1 - return true - }) - - if *err == nil && oErr != nil { - *err = oErr - *n = 0 +func (u *StdConn) ReadSingle() { + err := u.rc.Read(u.innerReadSingle) + if u.err == nil && err != nil { + u.err = err + u.n = 0 return } } -func (u *StdConn) ReadMulti(msgs []rawMessage, n *uintptr, err *error) { - oErr := u.rc.Read(func(fd uintptr) bool { - var nErr syscall.Errno - *n, _, nErr = unix.Syscall6( - unix.SYS_RECVMMSG, - fd, - uintptr(unsafe.Pointer(&(msgs[0].Hdr))), - uintptr(len(msgs)), - unix.MSG_WAITFORONE, - 0, 0, - ) +func (u *StdConn) innerReadSingle(fd uintptr) bool { + in, _, err := unix.Syscall6( + unix.SYS_RECVMSG, + fd, + uintptr(unsafe.Pointer(&(u.msgs[0].Hdr))), + 0, 0, 0, 0, + ) - if nErr == syscall.EAGAIN || nErr == syscall.EINTR { - // Retry read - return false - - } else if nErr != 0 { - u.l.Errorf("READING FROM UDP MULTI had an errno %d", nErr) - *err = &net.OpError{Op: "recvmmsg", Err: nErr} - *n = 0 - return true - } + if err == syscall.EAGAIN || err == syscall.EINTR { + // Retry read + return false + } else if err != 0 { + u.l.Errorf("READING FROM UDP SINGLE had an errno %d", err) + u.err = &net.OpError{Op: "recvmsg", Err: err} + u.n = 0 return true - }) + } - if *err == nil && oErr != nil { - *err = oErr - *n = 0 + u.msgs[0].Len = uint32(in) + u.n = 1 + return true +} + +func (u *StdConn) ReadMulti() { + err := u.rc.Read(u.innerReadMulti) + if u.err == nil && err != nil { + u.err = err + u.n = 0 return } } +func (u *StdConn) innerReadMulti(fd uintptr) bool { + var err syscall.Errno + u.n, _, err = unix.Syscall6( + unix.SYS_RECVMMSG, + fd, + uintptr(unsafe.Pointer(&u.msgs[0])), + uintptr(len(u.msgs)), + unix.MSG_WAITFORONE, + 0, 0, + ) + + if err == syscall.EAGAIN || err == syscall.EINTR { + // Retry read + return false + + } else if err != 0 { + u.l.Errorf("READING FROM UDP MULTI had an errno %d", err) + u.err = &net.OpError{Op: "recvmmsg", Err: err} + u.n = 0 + return true + } + + return true +} + func (u *StdConn) WriteTo(b []byte, ip netip.AddrPort) error { _, err := u.c.WriteToUDPAddrPort(b, ip) return err