diff --git a/udp/errors.go b/udp/errors.go new file mode 100644 index 0000000..12a8487 --- /dev/null +++ b/udp/errors.go @@ -0,0 +1,5 @@ +package udp + +import "errors" + +var ErrInvalidIPv6RemoteForSocket = errors.New("listener is IPv4, but writing to IPv6 remote") diff --git a/udp/udp_darwin.go b/udp/udp_darwin.go index 183ac7a..74041ca 100644 --- a/udp/udp_darwin.go +++ b/udp/udp_darwin.go @@ -6,17 +6,63 @@ package udp // Darwin support is primarily implemented in udp_generic, besides NewListenConfig import ( + "context" + "encoding/binary" + "errors" "fmt" "net" "net/netip" "syscall" + "unsafe" "github.com/sirupsen/logrus" + "github.com/slackhq/nebula/config" + "github.com/slackhq/nebula/firewall" + "github.com/slackhq/nebula/header" "golang.org/x/sys/unix" ) +type StdConn struct { + *net.UDPConn + isV4 bool + sysFd uintptr + l *logrus.Logger +} + +var _ Conn = &StdConn{} + func NewListener(l *logrus.Logger, ip netip.Addr, port int, multi bool, batch int) (Conn, error) { - return NewGenericListener(l, ip, port, multi, batch) + lc := NewListenConfig(multi) + pc, err := lc.ListenPacket(context.TODO(), "udp", net.JoinHostPort(ip.String(), fmt.Sprintf("%v", port))) + if err != nil { + return nil, err + } + + if uc, ok := pc.(*net.UDPConn); ok { + c := &StdConn{UDPConn: uc, l: l} + + rc, err := uc.SyscallConn() + if err != nil { + return nil, fmt.Errorf("failed to open udp socket: %w", err) + } + + err = rc.Control(func(fd uintptr) { + c.sysFd = fd + }) + if err != nil { + return nil, fmt.Errorf("failed to get udp fd: %w", err) + } + + la, err := c.LocalAddr() + if err != nil { + return nil, err + } + c.isV4 = la.Addr().Is4() + + return c, nil + } + + return nil, fmt.Errorf("unexpected PacketConn: %T %#v", pc, pc) } func NewListenConfig(multi bool) net.ListenConfig { @@ -43,16 +89,130 @@ func NewListenConfig(multi bool) net.ListenConfig { } } -func (u *GenericConn) Rebind() error { - rc, err := u.UDPConn.SyscallConn() - if err != nil { - return err +//go:linkname sendto golang.org/x/sys/unix.sendto +//go:noescape +func sendto(s int, buf []byte, flags int, to unsafe.Pointer, addrlen int32) (err error) + +func (u *StdConn) WriteTo(b []byte, ap netip.AddrPort) error { + var sa unsafe.Pointer + var addrLen int32 + + if u.isV4 { + if ap.Addr().Is6() { + return ErrInvalidIPv6RemoteForSocket + } + + var rsa unix.RawSockaddrInet6 + rsa.Family = unix.AF_INET6 + rsa.Addr = ap.Addr().As16() + binary.BigEndian.PutUint16((*[2]byte)(unsafe.Pointer(&rsa.Port))[:], ap.Port()) + sa = unsafe.Pointer(&rsa) + addrLen = syscall.SizeofSockaddrInet4 + } else { + var rsa unix.RawSockaddrInet6 + rsa.Family = unix.AF_INET6 + rsa.Addr = ap.Addr().As16() + binary.BigEndian.PutUint16((*[2]byte)(unsafe.Pointer(&rsa.Port))[:], ap.Port()) + sa = unsafe.Pointer(&rsa) + addrLen = syscall.SizeofSockaddrInet6 } - return rc.Control(func(fd uintptr) { - err := syscall.SetsockoptInt(int(fd), unix.IPPROTO_IPV6, unix.IPV6_BOUND_IF, 0) - if err != nil { - u.l.WithError(err).Error("Failed to rebind udp socket") + // Golang stdlib doesn't handle EAGAIN correctly in some situations so we do writes ourselves + // See https://github.com/golang/go/issues/73919 + for { + //_, _, err := unix.Syscall6(unix.SYS_SENDTO, u.sysFd, uintptr(unsafe.Pointer(&b[0])), uintptr(len(b)), 0, sa, addrLen) + err := sendto(int(u.sysFd), b, 0, sa, addrLen) + if err == nil { + // Written, get out before the error handling + return nil } - }) + + if errors.Is(err, syscall.EINTR) { + // Write was interrupted, retry + continue + } + + if errors.Is(err, syscall.EAGAIN) { + return &net.OpError{Op: "sendto", Err: unix.EWOULDBLOCK} + } + + if errors.Is(err, syscall.EBADF) { + return net.ErrClosed + } + + return &net.OpError{Op: "sendto", Err: err} + } +} + +func (u *StdConn) LocalAddr() (netip.AddrPort, error) { + a := u.UDPConn.LocalAddr() + + switch v := a.(type) { + case *net.UDPAddr: + addr, ok := netip.AddrFromSlice(v.IP) + if !ok { + return netip.AddrPort{}, fmt.Errorf("LocalAddr returned invalid IP address: %s", v.IP) + } + return netip.AddrPortFrom(addr, uint16(v.Port)), nil + + default: + return netip.AddrPort{}, fmt.Errorf("LocalAddr returned: %#v", a) + } +} + +func (u *StdConn) ReloadConfig(c *config.C) { + // TODO +} + +func NewUDPStatsEmitter(udpConns []Conn) func() { + // No UDP stats for non-linux + return func() {} +} + +func (u *StdConn) ListenOut(r EncReader, lhf LightHouseHandlerFunc, cache *firewall.ConntrackCacheTicker, q int) { + plaintext := make([]byte, MTU) + buffer := make([]byte, MTU) + h := &header.H{} + fwPacket := &firewall.Packet{} + nb := make([]byte, 12, 12) + + for { + // Just read one packet at a time + n, rua, err := u.ReadFromUDPAddrPort(buffer) + if err != nil { + if errors.Is(err, net.ErrClosed) { + u.l.WithError(err).Debug("udp socket is closed, exiting read loop") + return + } + + u.l.WithError(err).Error("unexpected udp socket receive error") + } + + r( + netip.AddrPortFrom(rua.Addr().Unmap(), rua.Port()), + plaintext[:0], + buffer[:n], + h, + fwPacket, + lhf, + nb, + q, + cache.Get(u.l), + ) + } +} + +func (u *StdConn) Rebind() error { + var err error + if u.isV4 { + err = syscall.SetsockoptInt(int(u.sysFd), syscall.IPPROTO_IP, syscall.IP_BOUND_IF, 0) + } else { + err = syscall.SetsockoptInt(int(u.sysFd), syscall.IPPROTO_IPV6, syscall.IPV6_BOUND_IF, 0) + } + + if err != nil { + u.l.WithError(err).Error("Failed to rebind udp socket") + } + + return nil } diff --git a/udp/udp_generic.go b/udp/udp_generic.go index 2d84536..74b7d29 100644 --- a/udp/udp_generic.go +++ b/udp/udp_generic.go @@ -1,6 +1,7 @@ -//go:build (!linux || android) && !e2e_testing +//go:build (!linux || android) && !e2e_testing && !darwin // +build !linux android // +build !e2e_testing +// +build !darwin // udp_generic implements the nebula UDP interface in pure Go stdlib. This // means it can be used on platforms like Darwin and Windows. diff --git a/udp/udp_linux.go b/udp/udp_linux.go index 2eee76e..eee83cf 100644 --- a/udp/udp_linux.go +++ b/udp/udp_linux.go @@ -243,7 +243,7 @@ func (u *StdConn) writeTo6(b []byte, ip netip.AddrPort) error { func (u *StdConn) writeTo4(b []byte, ip netip.AddrPort) error { if !ip.Addr().Is4() { - return fmt.Errorf("Listener is IPv4, but writing to IPv6 remote") + return ErrInvalidIPv6RemoteForSocket } var rsa unix.RawSockaddrInet4