diff --git a/cmd/nebula/main.go b/cmd/nebula/main.go index 93f3967..4b864d3 100644 --- a/cmd/nebula/main.go +++ b/cmd/nebula/main.go @@ -3,6 +3,9 @@ package main import ( "flag" "fmt" + "log" + "net/http" + _ "net/http/pprof" "os" "github.com/sirupsen/logrus" @@ -58,6 +61,10 @@ func main() { os.Exit(1) } + go func() { + log.Println(http.ListenAndServe("0.0.0.0:6060", nil)) + }() + if !*configTest { wait, err := ctrl.Start() if err != nil { diff --git a/interface.go b/interface.go index 5327548..d03e313 100644 --- a/interface.go +++ b/interface.go @@ -279,6 +279,7 @@ func (f *Interface) listenOut(i int) { f.readOutsidePackets(fromUdpAddr, nil, plaintext[:0], payload, h, fwPacket, lhh, nb, i, ctCache.Get(f.l)) }) + f.l.Errorf("udp reader %v is done", i) f.wg.Done() } @@ -296,6 +297,7 @@ func (f *Interface) listenIn(reader io.ReadWriteCloser, i int) { n, err := reader.Read(packet) if err != nil { if !f.closed.Load() { + //TODO: should we close? yes f.l.WithError(err).Error("Error while reading outbound packet") } break @@ -304,6 +306,7 @@ func (f *Interface) listenIn(reader io.ReadWriteCloser, i int) { f.consumeInsidePacket(packet[:n], fwPacket, nb, out, i, conntrackCache.Get(f.l)) } + f.l.Errorf("tun reader %v is done", i) f.wg.Done() } @@ -456,6 +459,7 @@ func (f *Interface) GetCertState() *CertState { func (f *Interface) Close() error { f.closed.Store(true) + // Release the udp readers for _, u := range f.writers { err := u.Close() if err != nil { @@ -463,6 +467,13 @@ func (f *Interface) Close() error { } } - // Release the tun device - return f.inside.Close() + // Release the tun readers + for _, u := range f.readers { + err := u.Close() + if err != nil { + f.l.WithError(err).Error("Error while closing tun device") + } + } + + return nil } diff --git a/outside.go b/outside.go index 1e9cde1..62fe146 100644 --- a/outside.go +++ b/outside.go @@ -29,7 +29,7 @@ func (f *Interface) readOutsidePackets(ip netip.AddrPort, via *ViaSender, out [] return } - //l.Error("in packet ", header, packet[HeaderLen:]) + //f.l.Error("in packet ", h) if ip.IsValid() { _, found := f.myVpnNetworksTable.Lookup(ip.Addr()) if found { diff --git a/overlay/tun_linux.go b/overlay/tun_linux.go index 7d19c85..84632e6 100644 --- a/overlay/tun_linux.go +++ b/overlay/tun_linux.go @@ -65,6 +65,11 @@ type ifreqQLEN struct { } func newTunFromFd(c *config.C, l *logrus.Logger, deviceFd int, vpnNetworks []netip.Prefix) (*tun, error) { + err := unix.SetNonblock(deviceFd, true) + if err != nil { + return nil, err + } + file := os.NewFile(uintptr(deviceFd), "/dev/net/tun") t, err := newTunGeneric(c, l, file, vpnNetworks) @@ -111,6 +116,11 @@ func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, multiqueu } name := strings.Trim(string(req.Name[:]), "\x00") + err = unix.SetNonblock(fd, true) + if err != nil { + return nil, err + } + file := os.NewFile(uintptr(fd), "/dev/net/tun") t, err := newTunGeneric(c, l, file, vpnNetworks) if err != nil { @@ -132,7 +142,12 @@ func newTunGeneric(c *config.C, l *logrus.Logger, file *os.File, vpnNetworks []n l: l, } - err := t.reload(c, true) + err := unix.SetNonblock(t.fd, true) + if err != nil { + return nil, err + } + + err = t.reload(c, true) if err != nil { return nil, err } @@ -227,6 +242,11 @@ func (t *tun) NewMultiQueueReader() (io.ReadWriteCloser, error) { return nil, err } + err = unix.SetNonblock(fd, true) + if err != nil { + return nil, err + } + file := os.NewFile(uintptr(fd), "/dev/net/tun") return file, nil @@ -237,29 +257,6 @@ func (t *tun) RoutesFor(ip netip.Addr) routing.Gateways { return r } -func (t *tun) Write(b []byte) (int, error) { - var nn int - maximum := len(b) - - for { - n, err := unix.Write(t.fd, b[nn:maximum]) - if n > 0 { - nn += n - } - if nn == len(b) { - return nn, err - } - - if err != nil { - return nn, err - } - - if n == 0 { - return nn, io.ErrUnexpectedEOF - } - } -} - func (t *tun) deviceBytes() (o [16]byte) { for i, c := range t.Device { o[i] = byte(c) diff --git a/udp/udp_linux.go b/udp/udp_linux.go index f1936b4..661fe54 100644 --- a/udp/udp_linux.go +++ b/udp/udp_linux.go @@ -4,10 +4,12 @@ package udp import ( + "context" "encoding/binary" "fmt" "net" "net/netip" + "strconv" "syscall" "unsafe" @@ -18,58 +20,46 @@ import ( ) type StdConn struct { - sysFd int + c *net.UDPConn + rc syscall.RawConn isV4 bool l *logrus.Logger batch int } -func maybeIPV4(ip net.IP) (net.IP, bool) { - ip4 := ip.To4() - if ip4 != nil { - return ip4, true - } - return ip, false -} - func NewListener(l *logrus.Logger, ip netip.Addr, port int, multi bool, batch int) (Conn, error) { - af := unix.AF_INET6 - if ip.Is4() { - af = unix.AF_INET - } - syscall.ForkLock.RLock() - fd, err := unix.Socket(af, unix.SOCK_DGRAM, unix.IPPROTO_UDP) - if err == nil { - unix.CloseOnExec(fd) - } - syscall.ForkLock.RUnlock() + lc := net.ListenConfig{ + Control: func(network, address string, c syscall.RawConn) error { + if multi { + var err error + oErr := c.Control(func(fd uintptr) { + err = unix.SetsockoptInt(int(fd), unix.SOL_SOCKET, unix.SO_REUSEPORT, 1) + }) + if oErr != nil { + return fmt.Errorf("error while setting SO_REUSEPORT: %w", oErr) + } + if err != nil { + return fmt.Errorf("unable to set SO_REUSEPORT: %w", err) + } + } + return nil + }, + } + + c, err := lc.ListenPacket(context.Background(), "udp", net.JoinHostPort(ip.String(), strconv.Itoa(port))) if err != nil { - unix.Close(fd) - return nil, fmt.Errorf("unable to open socket: %s", err) + return nil, fmt.Errorf("unable to open socket: %w", err) } - if multi { - if err = unix.SetsockoptInt(fd, unix.SOL_SOCKET, unix.SO_REUSEPORT, 1); err != nil { - return nil, fmt.Errorf("unable to set SO_REUSEPORT: %s", err) - } + uc := c.(*net.UDPConn) + rc, err := uc.SyscallConn() + if err != nil { + _ = c.Close() + return nil, fmt.Errorf("unable to open sysfd: %w", err) } - var sa unix.Sockaddr - if ip.Is4() { - sa4 := &unix.SockaddrInet4{Port: port} - sa4.Addr = ip.As4() - sa = sa4 - } else { - sa6 := &unix.SockaddrInet6{Port: port} - sa6.Addr = ip.As16() - sa = sa6 - } - if err = unix.Bind(fd, sa); err != nil { - return nil, fmt.Errorf("unable to bind to socket: %s", err) - } - - return &StdConn{sysFd: fd, isV4: ip.Is4(), l: l, batch: batch}, err + return &StdConn{c: uc, rc: rc, isV4: ip.Is4(), l: l, batch: batch}, err } func (u *StdConn) Rebind() error { @@ -77,50 +67,83 @@ func (u *StdConn) Rebind() error { } func (u *StdConn) SetRecvBuffer(n int) error { - return unix.SetsockoptInt(u.sysFd, unix.SOL_SOCKET, unix.SO_RCVBUFFORCE, n) + var err error + oErr := u.rc.Control(func(fd uintptr) { + err = unix.SetsockoptInt(int(fd), unix.SOL_SOCKET, unix.SO_RCVBUFFORCE, n) + }) + if oErr != nil { + return oErr + } + return err } func (u *StdConn) SetSendBuffer(n int) error { - return unix.SetsockoptInt(u.sysFd, unix.SOL_SOCKET, unix.SO_SNDBUFFORCE, n) + var err error + oErr := u.rc.Control(func(fd uintptr) { + err = unix.SetsockoptInt(int(fd), unix.SOL_SOCKET, unix.SO_SNDBUFFORCE, n) + }) + if oErr != nil { + return oErr + } + return err } func (u *StdConn) SetSoMark(mark int) error { - return unix.SetsockoptInt(u.sysFd, unix.SOL_SOCKET, unix.SO_MARK, mark) + var err error + oErr := u.rc.Control(func(fd uintptr) { + err = unix.SetsockoptInt(int(fd), unix.SOL_SOCKET, unix.SO_MARK, mark) + }) + if oErr != nil { + return oErr + } + return err } func (u *StdConn) GetRecvBuffer() (int, error) { - return unix.GetsockoptInt(int(u.sysFd), unix.SOL_SOCKET, unix.SO_RCVBUF) + var err error + var n int + oErr := u.rc.Control(func(fd uintptr) { + n, err = unix.GetsockoptInt(int(fd), unix.SOL_SOCKET, unix.SO_RCVBUF) + }) + if oErr != nil { + return n, oErr + } + return n, err } func (u *StdConn) GetSendBuffer() (int, error) { - return unix.GetsockoptInt(int(u.sysFd), unix.SOL_SOCKET, unix.SO_SNDBUF) + var err error + var n int + oErr := u.rc.Control(func(fd uintptr) { + n, err = unix.GetsockoptInt(int(fd), unix.SOL_SOCKET, unix.SO_SNDBUF) + }) + if oErr != nil { + return n, oErr + } + return n, err } func (u *StdConn) GetSoMark() (int, error) { - return unix.GetsockoptInt(int(u.sysFd), unix.SOL_SOCKET, unix.SO_MARK) + var err error + var n int + oErr := u.rc.Control(func(fd uintptr) { + n, err = unix.GetsockoptInt(int(fd), unix.SOL_SOCKET, unix.SO_MARK) + }) + if oErr != nil { + return n, oErr + } + return n, err } func (u *StdConn) LocalAddr() (netip.AddrPort, error) { - sa, err := unix.Getsockname(u.sysFd) - if err != nil { - return netip.AddrPort{}, err - } - - switch sa := sa.(type) { - case *unix.SockaddrInet4: - return netip.AddrPortFrom(netip.AddrFrom4(sa.Addr), uint16(sa.Port)), nil - - case *unix.SockaddrInet6: - return netip.AddrPortFrom(netip.AddrFrom16(sa.Addr), uint16(sa.Port)), nil - - default: - return netip.AddrPort{}, fmt.Errorf("unsupported sock type: %T", sa) - } + sa := u.c.LocalAddr() + return netip.ParseAddrPort(sa.String()) } func (u *StdConn) ListenOut(r EncReader) { var ip netip.Addr - + var n uintptr + var err error msgs, buffers, names := u.PrepareRawMessages(u.batch) read := u.ReadMulti if u.batch == 1 { @@ -128,124 +151,93 @@ func (u *StdConn) ListenOut(r EncReader) { } for { - n, err := read(msgs) + read(msgs, &n, &err) if err != nil { - u.l.WithError(err).Debug("udp socket is closed, exiting read loop") + u.l.WithError(err).Error("udp socket is closed, exiting read loop") return } - for i := 0; i < n; i++ { + for i := 0; i < int(n); i++ { // Its ok to skip the ok check here, the slicing is the only error that can occur and it will panic if u.isV4 { ip, _ = netip.AddrFromSlice(names[i][4:8]) } else { ip, _ = netip.AddrFromSlice(names[i][8:24]) } + //u.l.Error("GOT A PACKET", msgs[i].Len) r(netip.AddrPortFrom(ip.Unmap(), binary.BigEndian.Uint16(names[i][2:4])), buffers[i][:msgs[i].Len]) } } } -func (u *StdConn) ReadSingle(msgs []rawMessage) (int, error) { - for { - n, _, err := unix.Syscall6( +func (u *StdConn) ReadSingle(msgs []rawMessage, n *uintptr, err *error) { + oErr := u.rc.Read(func(fd uintptr) bool { + in, _, nErr := unix.Syscall6( unix.SYS_RECVMSG, - uintptr(u.sysFd), + fd, uintptr(unsafe.Pointer(&(msgs[0].Hdr))), - 0, - 0, - 0, - 0, + 0, 0, 0, 0, ) - if err != 0 { - return 0, &net.OpError{Op: "recvmsg", Err: err} + if nErr == syscall.EAGAIN || nErr == syscall.EINTR { + // Retry read + return false + + } else if nErr != 0 { + u.l.Errorf("READING FROM UDP SINGLE had an errno %d", nErr) + *err = &net.OpError{Op: "recvmsg", Err: nErr} + *n = 0 + return true } - msgs[0].Len = uint32(n) - return 1, nil + msgs[0].Len = uint32(in) + *n = 1 + return true + }) + + if *err == nil && oErr != nil { + *err = oErr + *n = 0 + return } } -func (u *StdConn) ReadMulti(msgs []rawMessage) (int, error) { - for { - n, _, err := unix.Syscall6( +func (u *StdConn) ReadMulti(msgs []rawMessage, n *uintptr, err *error) { + oErr := u.rc.Read(func(fd uintptr) bool { + var nErr syscall.Errno + *n, _, nErr = unix.Syscall6( unix.SYS_RECVMMSG, - uintptr(u.sysFd), - uintptr(unsafe.Pointer(&msgs[0])), + fd, + uintptr(unsafe.Pointer(&(msgs[0].Hdr))), uintptr(len(msgs)), unix.MSG_WAITFORONE, - 0, - 0, + 0, 0, ) - if err != 0 { - return 0, &net.OpError{Op: "recvmmsg", Err: err} + if nErr == syscall.EAGAIN || nErr == syscall.EINTR { + // Retry read + return false + + } else if nErr != 0 { + u.l.Errorf("READING FROM UDP MULTI had an errno %d", nErr) + *err = &net.OpError{Op: "recvmmsg", Err: nErr} + *n = 0 + return true } - return int(n), nil + return true + }) + + if *err == nil && oErr != nil { + *err = oErr + *n = 0 + return } } func (u *StdConn) WriteTo(b []byte, ip netip.AddrPort) error { - if u.isV4 { - return u.writeTo4(b, ip) - } - return u.writeTo6(b, ip) -} - -func (u *StdConn) writeTo6(b []byte, ip netip.AddrPort) error { - var rsa unix.RawSockaddrInet6 - rsa.Family = unix.AF_INET6 - rsa.Addr = ip.Addr().As16() - binary.BigEndian.PutUint16((*[2]byte)(unsafe.Pointer(&rsa.Port))[:], ip.Port()) - - for { - _, _, err := unix.Syscall6( - unix.SYS_SENDTO, - uintptr(u.sysFd), - uintptr(unsafe.Pointer(&b[0])), - uintptr(len(b)), - uintptr(0), - uintptr(unsafe.Pointer(&rsa)), - uintptr(unix.SizeofSockaddrInet6), - ) - - if err != 0 { - return &net.OpError{Op: "sendto", Err: err} - } - - return nil - } -} - -func (u *StdConn) writeTo4(b []byte, ip netip.AddrPort) error { - if !ip.Addr().Is4() { - return fmt.Errorf("Listener is IPv4, but writing to IPv6 remote") - } - - var rsa unix.RawSockaddrInet4 - rsa.Family = unix.AF_INET - rsa.Addr = ip.Addr().As4() - binary.BigEndian.PutUint16((*[2]byte)(unsafe.Pointer(&rsa.Port))[:], ip.Port()) - - for { - _, _, err := unix.Syscall6( - unix.SYS_SENDTO, - uintptr(u.sysFd), - uintptr(unsafe.Pointer(&b[0])), - uintptr(len(b)), - uintptr(0), - uintptr(unsafe.Pointer(&rsa)), - uintptr(unix.SizeofSockaddrInet4), - ) - - if err != 0 { - return &net.OpError{Op: "sendto", Err: err} - } - - return nil - } + _, err := u.c.WriteToUDPAddrPort(b, ip) + return err } func (u *StdConn) ReloadConfig(c *config.C) { @@ -298,15 +290,27 @@ func (u *StdConn) ReloadConfig(c *config.C) { func (u *StdConn) getMemInfo(meminfo *[unix.SK_MEMINFO_VARS]uint32) error { var vallen uint32 = 4 * unix.SK_MEMINFO_VARS - _, _, err := unix.Syscall6(unix.SYS_GETSOCKOPT, uintptr(u.sysFd), uintptr(unix.SOL_SOCKET), uintptr(unix.SO_MEMINFO), uintptr(unsafe.Pointer(meminfo)), uintptr(unsafe.Pointer(&vallen)), 0) - if err != 0 { - return err + var err error + oErr := u.rc.Control(func(fd uintptr) { + _, _, err = unix.Syscall6( + unix.SYS_GETSOCKOPT, + fd, + uintptr(unix.SOL_SOCKET), + uintptr(unix.SO_MEMINFO), + uintptr(unsafe.Pointer(meminfo)), + uintptr(unsafe.Pointer(&vallen)), + 0, + ) + }) + if oErr != nil { + return oErr } - return nil + return err } func (u *StdConn) Close() error { - return syscall.Close(u.sysFd) + err := u.c.Close() + return err } func NewUDPStatsEmitter(udpConns []Conn) func() {