diff --git a/connection_manager.go b/connection_manager.go index caf4d33..7242c72 100644 --- a/connection_manager.go +++ b/connection_manager.go @@ -518,12 +518,12 @@ func (cm *connectionManager) sendPunch(hostinfo *HostInfo) { if cm.punchy.GetTargetEverything() { hostinfo.remotes.ForEach(cm.hostMap.GetPreferredRanges(), func(addr netip.AddrPort, preferred bool) { cm.metricsTxPunchy.Inc(1) - cm.intf.outside.WriteDirect([]byte{1}, addr) + cm.intf.outside.WriteTo([]byte{1}, addr) }) } else if hostinfo.remote.IsValid() { cm.metricsTxPunchy.Inc(1) - cm.intf.outside.WriteDirect([]byte{1}, hostinfo.remote) + cm.intf.outside.WriteTo([]byte{1}, hostinfo.remote) } } diff --git a/connection_state.go b/connection_state.go index c00c9af..485f6fd 100644 --- a/connection_state.go +++ b/connection_state.go @@ -15,8 +15,7 @@ import ( // TODO: In a 5Gbps test, 1024 is not sufficient. With a 1400 MTU this is about 1.4Gbps of window, assuming full packets. // 4092 should be sufficient for 5Gbps -// TODO this is a horrible amount of RAM to waste per-tunnel -const ReplayWindow = 0xffff / 2 +const ReplayWindow = 8192 type ConnectionState struct { eKey *NebulaCipherState diff --git a/handshake_ix.go b/handshake_ix.go index 2b7e6dd..026bfbd 100644 --- a/handshake_ix.go +++ b/handshake_ix.go @@ -348,7 +348,7 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet msg = existing.HandshakePacket[2] f.messageMetrics.Tx(header.Handshake, header.MessageSubType(msg[1]), 1) if addr.IsValid() { - err := f.outside.WriteDirect(msg, addr) + err := f.outside.WriteTo(msg, addr) if err != nil { f.l.WithField("vpnAddrs", existing.vpnAddrs).WithField("udpAddr", addr). WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).WithField("cached", true). @@ -417,7 +417,7 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet // Do the send f.messageMetrics.Tx(header.Handshake, header.MessageSubType(msg[1]), 1) if addr.IsValid() { - err = f.outside.WriteDirect(msg, addr) + err = f.outside.WriteTo(msg, addr) if err != nil { f.l.WithField("vpnAddrs", vpnAddrs).WithField("udpAddr", addr). WithField("certName", certName). diff --git a/handshake_manager.go b/handshake_manager.go index 7264ae8..f92e72d 100644 --- a/handshake_manager.go +++ b/handshake_manager.go @@ -238,7 +238,7 @@ func (hm *HandshakeManager) handleOutbound(vpnIp netip.Addr, lighthouseTriggered var sentTo []netip.AddrPort hostinfo.remotes.ForEach(hm.mainHostMap.GetPreferredRanges(), func(addr netip.AddrPort, _ bool) { hm.messageMetrics.Tx(header.Handshake, header.MessageSubType(hostinfo.HandshakePacket[0][1]), 1) - err := hm.outside.WriteDirect(hostinfo.HandshakePacket[0], addr) + err := hm.outside.WriteTo(hostinfo.HandshakePacket[0], addr) if err != nil { hostinfo.logger(hm.l).WithField("udpAddr", addr). WithField("initiatorIndex", hostinfo.localIndexId). diff --git a/inside.go b/inside.go index 81fc161..d24ed31 100644 --- a/inside.go +++ b/inside.go @@ -8,7 +8,6 @@ import ( "github.com/slackhq/nebula/header" "github.com/slackhq/nebula/iputil" "github.com/slackhq/nebula/noiseutil" - "github.com/slackhq/nebula/packet" "github.com/slackhq/nebula/routing" ) @@ -325,7 +324,7 @@ func (f *Interface) SendVia(via *HostInfo, via.logger(f.l).WithError(err).Info("Failed to EncryptDanger in sendVia") return } - err = f.writers[0].WriteDirect(out, via.remote) + err = f.writers[0].WriteTo(out, via.remote) if err != nil { via.logger(f.l).WithError(err).Info("Failed to WriteTo in sendVia") } @@ -385,29 +384,19 @@ func (f *Interface) sendNoMetrics(t header.MessageType, st header.MessageSubType } if remote.IsValid() { - pkt := packet.GetPool().Get() - copy(pkt.Payload[:], out) - pkt.Payload = pkt.Payload[:len(out)] - pkt.Addr = remote - err = f.writers[q].WriteTo(pkt) + err = f.writers[q].WriteTo(out, remote) if err != nil { hostinfo.logger(f.l).WithError(err). WithField("udpAddr", remote).Error("Failed to write outgoing packet") } } else if hostinfo.remote.IsValid() { - pkt := packet.GetPool().Get() - copy(pkt.Payload, out) - pkt.Payload = pkt.Payload[:len(out)] - pkt.Addr = hostinfo.remote - err = f.writers[q].WriteTo(pkt) + err = f.writers[q].WriteTo(out, hostinfo.remote) if err != nil { hostinfo.logger(f.l).WithError(err). WithField("udpAddr", remote).Error("Failed to write outgoing packet") } } else { // Try to send via a relay - - //todo relay is slow sorryyy for _, relayIP := range hostinfo.relayState.CopyRelayIps() { relayHostInfo, relay, err := f.hostMap.QueryVpnAddrsRelayFor(hostinfo.vpnAddrs, relayIP) if err != nil { diff --git a/interface.go b/interface.go index 6dde026..0579707 100644 --- a/interface.go +++ b/interface.go @@ -327,26 +327,6 @@ func (f *Interface) listenIn(reader io.ReadWriteCloser, i int) { f.wg.Done() } -//// todo why? understand! -//func normalizeGROSegSize(segSize, total int) int { -// if segCount > 1 && total > 0 { -// avg := total / segCount -// if avg > 0 { -// if segSize > avg { -// if segSize-8 == avg { -// segSize = avg -// } else if segSize > total { -// segSize = avg -// } -// } -// } -// } -// if segSize > total { -// segSize = total -// } -// return segSize -//} - func (f *Interface) workerIn(i int, ctx context.Context) { lhh := f.lightHouse.NewRequestHandler() conntrackCache := firewall.NewConntrackCacheTicker(f.conntrackCacheTimeout) @@ -388,7 +368,7 @@ func (f *Interface) workerOut(i int, ctx context.Context) { select { case data := <-f.outbound: f.consumeInsidePacket(data.Payload, fwPacket1, nb1, result1, i, conntrackCache.Get(f.l)) - //f.pktPool.Put(data) //todo if err pls put packet back + f.pktPool.Put(data) case <-ctx.Done(): f.wg.Done() return diff --git a/lighthouse.go b/lighthouse.go index cadb6e0..9f00c39 100644 --- a/lighthouse.go +++ b/lighthouse.go @@ -1329,7 +1329,7 @@ func (lhh *LightHouseHandler) handleHostPunchNotification(n *NebulaMeta, fromVpn go func() { time.Sleep(lhh.lh.punchy.GetDelay()) lhh.lh.metricHolepunchTx.Inc(1) - lhh.lh.punchConn.WriteDirect(empty, vpnPeer) + lhh.lh.punchConn.WriteTo(empty, vpnPeer) }() if lhh.l.Level >= logrus.DebugLevel { diff --git a/outside.go b/outside.go index 513c0a8..a6dcb5c 100644 --- a/outside.go +++ b/outside.go @@ -519,7 +519,7 @@ func (f *Interface) sendRecvError(endpoint netip.AddrPort, index uint32) { f.messageMetrics.Tx(header.RecvError, 0, 1) b := header.Encode(make([]byte, header.Len), header.Version, header.RecvError, 0, index, 0) - _ = f.outside.WriteDirect(b, endpoint) + _ = f.outside.WriteTo(b, endpoint) if f.l.Level >= logrus.DebugLevel { f.l.WithField("index", index). WithField("udpAddr", endpoint). diff --git a/udp/conn.go b/udp/conn.go index dcc3421..6d0b79e 100644 --- a/udp/conn.go +++ b/udp/conn.go @@ -17,8 +17,7 @@ type Conn interface { Rebind() error LocalAddr() (netip.AddrPort, error) ListenOut(pg PacketBufferGetter, pc chan *packet.Packet) error - WriteTo(p *packet.Packet) error - WriteDirect(b []byte, port netip.AddrPort) error + WriteTo(b []byte, addr netip.AddrPort) error ReloadConfig(c *config.C) Close() error } diff --git a/udp/udp_linux.go b/udp/udp_linux.go index cc0d735..f5c09d4 100644 --- a/udp/udp_linux.go +++ b/udp/udp_linux.go @@ -5,11 +5,9 @@ package udp import ( "encoding/binary" - "errors" "fmt" "net" "net/netip" - "sync" "syscall" "time" "unsafe" @@ -21,128 +19,8 @@ import ( "golang.org/x/sys/unix" ) -const ( - defaultGSOMaxSegments = 16 - defaultGSOFlushTimeout = 150 * time.Microsecond - maxGSOBatchBytes = 0xFFFF -) - -var ( - errGSOFallback = errors.New("udp gso fallback") - errGSODisabled = errors.New("udp gso disabled") -) - var readTimeout = unix.NsecToTimeval(int64(time.Millisecond * 500)) -type gsoState struct { - m sync.Mutex - Buf []byte - Addr netip.AddrPort - SegSize int - MaxSegments int - MaxBytes int - FlushTimeout time.Duration - Timer *time.Timer - - packets []*packet.Packet - msg rawMessage - name [unix.SizeofSockaddrInet6]byte - iov []iovec - ctrl []byte -} - -func (g *gsoState) Init() { - g.iov = make([]iovec, g.MaxSegments) - for i := 0; i < g.MaxSegments; i++ { - g.iov[i] = iovec{} - } - g.msg.Hdr.Iov = &g.iov[0] - g.msg.Hdr.Iovlen = 1 - - g.packets = make([]*packet.Packet, 0, g.MaxSegments) - g.ctrl = make([]byte, unix.CmsgSpace(2)) - hdr := (*unix.Cmsghdr)(unsafe.Pointer(&g.ctrl[0])) - hdr.Level = unix.SOL_UDP - hdr.Type = unix.UDP_SEGMENT - setCmsgLen(hdr, unix.CmsgLen(2)) - g.msg.Hdr.Control = &g.ctrl[0] - g.msg.Hdr.Controllen = uint64(len(g.ctrl)) - - g.name = [unix.SizeofSockaddrInet6]byte{} - g.msg.Hdr.Name = &g.name[0] - -} - -func (g *gsoState) setSegSizeLocked(segSize int) { - g.SegSize = segSize - x := unix.CmsgLen(0) - binary.LittleEndian.PutUint16(g.ctrl[x:x+2], uint16(segSize)) -} - -func (g *gsoState) setNameLocked(x netip.AddrPort, isV4 bool) { - g.Addr = x - nameLen := encodeSockaddr(g.name[:], g.Addr, isV4) - g.msg.Hdr.Name = &g.name[0] - g.msg.Hdr.Namelen = nameLen -} - -func encodeSockaddr(dst []byte, addr netip.AddrPort, isV4 bool) uint32 { - if isV4 { - //todo? - //if !addr.Addr().Is4() { - // return 0, fmt.Errorf("Listener is IPv4, but writing to IPv6 remote") - //} - var sa unix.RawSockaddrInet4 - sa.Family = unix.AF_INET - sa.Addr = addr.Addr().As4() - binary.BigEndian.PutUint16((*[2]byte)(unsafe.Pointer(&sa.Port))[:], addr.Port()) - size := unix.SizeofSockaddrInet4 - copy(dst[:size], (*(*[unix.SizeofSockaddrInet4]byte)(unsafe.Pointer(&sa)))[:]) - return uint32(size) - } - - var sa unix.RawSockaddrInet6 - sa.Family = unix.AF_INET6 - sa.Addr = addr.Addr().As16() - binary.BigEndian.PutUint16((*[2]byte)(unsafe.Pointer(&sa.Port))[:], addr.Port()) - size := unix.SizeofSockaddrInet6 - copy(dst[:size], (*(*[unix.SizeofSockaddrInet6]byte)(unsafe.Pointer(&sa)))[:]) - return uint32(size) -} - -func (g *gsoState) sendmsgLocked(fd int) error { - //name already set - //ctrl already set - //g.iov = g.iov[:0] - g.msg.Hdr.Iovlen = uint64(len(g.packets)) - for i := range g.packets { - g.iov[i].Base = &g.packets[i].Payload[0] - g.iov[i].Len = uint64(len(g.packets[i].Payload)) - } - - const flags = 0 - for { - _, _, err := unix.Syscall( - unix.SYS_SENDMSG, - uintptr(fd), - uintptr(unsafe.Pointer(&g.msg)), - uintptr(flags), - ) - //todo no matter what, reset things - for i := range g.packets { - pool := packet.GetPool() - pool.Put(g.packets[i]) - } - g.packets = g.packets[:0] - - if err != 0 { - return &net.OpError{Op: "sendmsg", Err: err} - } - - return nil - } -} - type StdConn struct { sysFd int isV4 bool @@ -150,7 +28,7 @@ type StdConn struct { batch int enableGRO bool enableGSO bool - gso gsoState + //gso gsoState } func NewListener(l *logrus.Logger, ip netip.Addr, port int, multi bool, batch int) (Conn, error) { @@ -358,123 +236,11 @@ func (u *StdConn) ReadMulti(msgs []rawMessage) (int, error) { } } -func (u *StdConn) WriteTo(p *packet.Packet) error { - if u.enableGSO && p.Addr.IsValid() { - if err := u.queueGSOPacket(p); err == nil { - return nil - } else if !errors.Is(err, errGSOFallback) { - return err - } - } - - var err error +func (u *StdConn) WriteTo(b []byte, ip netip.AddrPort) error { if u.isV4 { - err = u.writeTo4(p.Payload, p.Addr) - } else { - err = u.writeTo4(p.Payload, p.Addr) + return u.writeTo4(b, ip) } - packet.GetPool().Put(p) - return err -} - -func (u *StdConn) WriteDirect(b []byte, addr netip.AddrPort) error { - if u.isV4 { - return u.writeTo4(b, addr) - } - return u.writeTo6(b, addr) -} - -func (u *StdConn) scheduleGSOFlushLocked() { - if u.gso.Timer == nil { - u.gso.Timer = time.AfterFunc(u.gso.FlushTimeout, u.gsoFlushTimer) - return - } - u.gso.Timer.Reset(u.gso.FlushTimeout) -} - -func (u *StdConn) stopGSOTimerLocked() { - if u.gso.Timer != nil { - u.gso.Timer.Stop() - u.gso.Timer = nil //todo I also don't like this - } -} - -func (u *StdConn) queueGSOPacket(p *packet.Packet) error { - if len(p.Payload) == 0 { - return nil - } - - u.gso.m.Lock() - defer u.gso.m.Unlock() - - if !u.enableGSO || !p.Addr.IsValid() || len(p.Payload) > u.gso.MaxBytes { - if err := u.flushGSOlocked(); err != nil { - return err - } - return errGSOFallback - } - - if len(u.gso.packets) == 0 { - u.gso.setNameLocked(p.Addr, u.isV4) - u.gso.SegSize = len(p.Payload) - u.gso.packets = append(u.gso.packets, p) - } else if p.Addr != u.gso.Addr || len(p.Payload) != u.gso.SegSize { - if err := u.flushGSOlocked(); err != nil { - return err - } //todo deal with "one small packet" case - u.gso.setNameLocked(p.Addr, u.isV4) - u.gso.SegSize = len(p.Payload) - u.gso.packets = append(u.gso.packets, p) - } else { - u.gso.packets = append(u.gso.packets, p) - } - - //big todo - //if len(u.gso.Buf)+len(p.Payload) > u.gso.MaxBytes { - // if err := u.flushGSOlocked(); err != nil { - // return err - // } - // u.gso.setNameLocked(p.Addr, u.isV4) - // u.gso.SegSize = len(p.Payload) - // u.gso.packets = append(u.gso.packets, p) - //} - - if len(u.gso.packets) >= u.gso.MaxSegments || u.gso.FlushTimeout <= 0 { - return u.flushGSOlocked() - } - - u.scheduleGSOFlushLocked() - return nil -} - -func (u *StdConn) flushGSOlocked() error { - if len(u.gso.packets) == 0 { - u.stopGSOTimerLocked() - return nil - } - - u.stopGSOTimerLocked() - - if u.gso.SegSize <= 0 { - return errGSOFallback - } - - err := u.gso.sendmsgLocked(u.sysFd) - if errors.Is(err, errGSODisabled) { - u.l.WithField("addr", u.gso.Addr).Warn("UDP GSO disabled by kernel, falling back to sendto") - u.enableGSO = false - //todo! - //return u.sendSegmentsIndividually(payload, addr, segSize) - } - u.gso.SegSize = 0 - - return err -} - -func (u *StdConn) gsoFlushTimer() { - u.gso.m.Lock() - _ = u.flushGSOlocked() - u.gso.m.Unlock() + return u.writeTo6(b, ip) } func (u *StdConn) writeTo6(b []byte, ip netip.AddrPort) error { @@ -578,7 +344,6 @@ func (u *StdConn) ReloadConfig(c *config.C) { } } u.configureGRO(true) - u.configureGSO(c) } func (u *StdConn) configureGRO(enable bool) { @@ -602,49 +367,6 @@ func (u *StdConn) configureGRO(enable bool) { u.enableGRO = false } -func (u *StdConn) configureGSO(c *config.C) { - enable := c.GetBool("listen.enable_gso", true) - if !enable { - u.disableGSO() - } else { - u.enableGSO = true - } - - segments := c.GetInt("listen.gso_max_segments", defaultGSOMaxSegments) - if segments < 1 { - segments = 1 - } - u.gso.MaxSegments = segments - - maxBytes := c.GetInt("listen.gso_max_bytes", 0) - if maxBytes <= 0 { - maxBytes = MTU * segments - } - if maxBytes > maxGSOBatchBytes { - u.l.WithField("requested", maxBytes).Warn("listen.gso_max_bytes larger than UDP limit; clamping") - maxBytes = maxGSOBatchBytes - } - u.gso.MaxBytes = maxBytes - - timeout := c.GetDuration("listen.gso_flush_timeout", defaultGSOFlushTimeout) - if timeout < 0 { - timeout = 0 - } - u.gso.FlushTimeout = timeout - u.gso.Init() -} - -func (u *StdConn) disableGSO() { - u.gso.m.Lock() - defer u.gso.m.Unlock() - u.enableGSO = false - _ = u.flushGSOlocked() - u.gso.Buf = nil - u.gso.packets = u.gso.packets[:0] - u.gso.SegSize = 0 - u.stopGSOTimerLocked() -} - func (u *StdConn) getMemInfo(meminfo *[unix.SK_MEMINFO_VARS]uint32) error { var vallen uint32 = 4 * unix.SK_MEMINFO_VARS _, _, err := unix.Syscall6(unix.SYS_GETSOCKOPT, uintptr(u.sysFd), uintptr(unix.SOL_SOCKET), uintptr(unix.SO_MEMINFO), uintptr(unsafe.Pointer(meminfo)), uintptr(unsafe.Pointer(&vallen)), 0)