diff --git a/inside.go b/inside.go index 0d53f95..0f6e18e 100644 --- a/inside.go +++ b/inside.go @@ -8,10 +8,11 @@ import ( "github.com/slackhq/nebula/header" "github.com/slackhq/nebula/iputil" "github.com/slackhq/nebula/noiseutil" + "github.com/slackhq/nebula/packet" "github.com/slackhq/nebula/routing" ) -func (f *Interface) consumeInsidePacket(packet []byte, fwPacket *firewall.Packet, nb, out []byte, q int, localCache firewall.ConntrackCache) { +func (f *Interface) consumeInsidePacket(packet []byte, fwPacket *firewall.Packet, nb []byte, out *packet.Packet, q int, localCache firewall.ConntrackCache) { err := newPacket(packet, false, fwPacket) if err != nil { if f.l.Level >= logrus.DebugLevel { @@ -53,7 +54,7 @@ func (f *Interface) consumeInsidePacket(packet []byte, fwPacket *firewall.Packet }) if hostinfo == nil { - f.rejectInside(packet, out, q) + f.rejectInside(packet, out.Payload, q) //todo vector? if f.l.Level >= logrus.DebugLevel { f.l.WithField("vpnAddr", fwPacket.RemoteAddr). WithField("fwPacket", fwPacket). @@ -68,10 +69,9 @@ func (f *Interface) consumeInsidePacket(packet []byte, fwPacket *firewall.Packet dropReason := f.firewall.Drop(*fwPacket, false, hostinfo, f.pki.GetCAPool(), localCache) if dropReason == nil { - f.sendNoMetrics(header.Message, 0, hostinfo.ConnectionState, hostinfo, netip.AddrPort{}, packet, nb, out, q) - + f.sendNoMetricsDelayed(header.Message, 0, hostinfo.ConnectionState, hostinfo, netip.AddrPort{}, packet, nb, out, q) } else { - f.rejectInside(packet, out, q) + f.rejectInside(packet, out.Payload, q) //todo vector? if f.l.Level >= logrus.DebugLevel { hostinfo.logger(f.l). WithField("fwPacket", fwPacket). @@ -410,3 +410,81 @@ func (f *Interface) sendNoMetrics(t header.MessageType, st header.MessageSubType } } } + +func (f *Interface) sendNoMetricsDelayed(t header.MessageType, st header.MessageSubType, ci *ConnectionState, hostinfo *HostInfo, remote netip.AddrPort, p, nb []byte, out *packet.Packet, q int) { + if ci.eKey == nil { + return + } + useRelay := !remote.IsValid() && !hostinfo.remote.IsValid() + fullOut := out.Payload + + if useRelay { + if len(out.Payload) < header.Len { + // out always has a capacity of mtu, but not always a length greater than the header.Len. + // Grow it to make sure the next operation works. + out.Payload = out.Payload[:header.Len] + } + // Save a header's worth of data at the front of the 'out' buffer. + out.Payload = out.Payload[header.Len:] + } + + if noiseutil.EncryptLockNeeded { + // NOTE: for goboring AESGCMTLS we need to lock because of the nonce check + ci.writeLock.Lock() + } + c := ci.messageCounter.Add(1) + + //l.WithField("trace", string(debug.Stack())).Error("out Header ", &Header{Version, t, st, 0, hostinfo.remoteIndexId, c}, p) + out.Payload = header.Encode(out.Payload, header.Version, t, st, hostinfo.remoteIndexId, c) + f.connectionManager.Out(hostinfo) + + // Query our LH if we haven't since the last time we've been rebound, this will cause the remote to punch against + // all our addrs and enable a faster roaming. + if t != header.CloseTunnel && hostinfo.lastRebindCount != f.rebindCount { + //NOTE: there is an update hole if a tunnel isn't used and exactly 256 rebinds occur before the tunnel is + // finally used again. This tunnel would eventually be torn down and recreated if this action didn't help. + f.lightHouse.QueryServer(hostinfo.vpnAddrs[0]) + hostinfo.lastRebindCount = f.rebindCount + if f.l.Level >= logrus.DebugLevel { + f.l.WithField("vpnAddrs", hostinfo.vpnAddrs).Debug("Lighthouse update triggered for punch due to rebind counter") + } + } + + var err error + out.Payload, err = ci.eKey.EncryptDanger(out.Payload, out.Payload, p, c, nb) + if noiseutil.EncryptLockNeeded { + ci.writeLock.Unlock() + } + if err != nil { + hostinfo.logger(f.l).WithError(err). + WithField("udpAddr", remote).WithField("counter", c). + WithField("attemptedCounter", c). + Error("Failed to encrypt outgoing packet") + return + } + + if remote.IsValid() { + err = f.writers[q].Prep(out, remote) + if err != nil { + hostinfo.logger(f.l).WithError(err).WithField("udpAddr", remote).Error("Failed to write outgoing packet") + } + } else if hostinfo.remote.IsValid() { + err = f.writers[q].Prep(out, hostinfo.remote) + if err != nil { + hostinfo.logger(f.l).WithError(err).WithField("udpAddr", remote).Error("Failed to write outgoing packet") + } + } else { + // Try to send via a relay + for _, relayIP := range hostinfo.relayState.CopyRelayIps() { + relayHostInfo, relay, err := f.hostMap.QueryVpnAddrsRelayFor(hostinfo.vpnAddrs, relayIP) + if err != nil { + hostinfo.relayState.DeleteRelay(relayIP) + hostinfo.logger(f.l).WithField("relay", relayIP).WithError(err).Info("sendNoMetrics failed to find HostInfo") + continue + } + //todo vector!! + f.SendVia(relayHostInfo, relay, out.Payload, nb, fullOut[:header.Len+len(out.Payload)], true) + break + } + } +} diff --git a/interface.go b/interface.go index e6da5b8..cefd330 100644 --- a/interface.go +++ b/interface.go @@ -318,15 +318,16 @@ func (f *Interface) listenIn(reader overlay.TunDev, queueNum int) { for i := 0; i < batch; i++ { originalPackets[i] = make([]byte, 0xffff) } - out := make([]byte, mtu) fwPacket := &firewall.Packet{} nb := make([]byte, 12, 12) conntrackCache := firewall.NewConntrackCacheTicker(f.conntrackCacheTimeout) packets := make([]*packet.VirtIOPacket, batch) + outPackets := make([]*packet.Packet, batch) for i := 0; i < batch; i++ { packets[i] = packet.NewVIO() + outPackets[i] = packet.New(false) //todo? } for { @@ -343,9 +344,13 @@ func (f *Interface) listenIn(reader overlay.TunDev, queueNum int) { os.Exit(2) } - //todo vectorize - for _, pkt := range packets[:n] { - f.consumeInsidePacket(pkt.Payload, fwPacket, nb, out, queueNum, conntrackCache.Get(f.l)) + for i, pkt := range packets[:n] { + outPackets[i].OutLen = -1 + f.consumeInsidePacket(pkt.Payload, fwPacket, nb, outPackets[i], queueNum, conntrackCache.Get(f.l)) + } + _, err = f.writers[queueNum].WriteBatch(outPackets[:n]) + if err != nil { + f.l.WithError(err).Error("Error while writing outbound packets") } } diff --git a/packet/packet.go b/packet/packet.go index 8eeb448..2b7a43b 100644 --- a/packet/packet.go +++ b/packet/packet.go @@ -4,6 +4,8 @@ import ( "encoding/binary" "iter" "net/netip" + "syscall" + "unsafe" "golang.org/x/sys/unix" ) @@ -73,6 +75,32 @@ func (p *Packet) Update(ctrlLen int) { p.updateCtrl(ctrlLen) } +func (p *Packet) SetSegSizeForTX() { + p.SegSize = len(p.Payload) + hdr := (*unix.Cmsghdr)(unsafe.Pointer(&p.Control[0])) + hdr.Level = unix.SOL_UDP + hdr.Type = unix.UDP_SEGMENT + hdr.SetLen(syscall.CmsgLen(2)) + //setCmsgLen(hdr, unix.CmsgLen(2)) + binary.NativeEndian.PutUint16(p.Control[unix.CmsgLen(0):unix.CmsgLen(0)+2], uint16(p.SegSize)) + //data := p.Control[syscall.CmsgSpace(0)-syscall.CmsgSpace(2)+syscall.SizeofCmsghdr:] + //binary.NativeEndian.PutUint16(data, uint16(p.SegSize)) +} + +func (p *Packet) CompatibleForSegmentationWith(otherP *Packet) bool { + //same dest + + if p.AddrPort() != otherP.AddrPort() { + return false //todo more efficient? + } + + //same body len + if len(p.Payload) != len(otherP.Payload) { + return false //todo technically you can cram one extra in + } + return true +} + func (p *Packet) Segments() iter.Seq[[]byte] { return func(yield func([]byte) bool) { //cursor := 0 diff --git a/udp/conn.go b/udp/conn.go index d7583f2..ba26c00 100644 --- a/udp/conn.go +++ b/udp/conn.go @@ -19,6 +19,8 @@ type Conn interface { ListenOut(r EncReader) WriteTo(b []byte, addr netip.AddrPort) error ReloadConfig(c *config.C) + Prep(pkt *packet.Packet, addr netip.AddrPort) error + WriteBatch(pkt []*packet.Packet) (int, error) Close() error } diff --git a/udp/udp_linux.go b/udp/udp_linux.go index dac89d6..187e411 100644 --- a/udp/udp_linux.go +++ b/udp/udp_linux.go @@ -14,6 +14,7 @@ import ( "github.com/rcrowley/go-metrics" "github.com/sirupsen/logrus" "github.com/slackhq/nebula/config" + "github.com/slackhq/nebula/packet" "golang.org/x/sys/unix" ) @@ -191,6 +192,124 @@ func (u *StdConn) WriteToBatch(b []byte, ip netip.AddrPort) error { return u.writeTo6(b, ip) } +func (u *StdConn) Prep(pkt *packet.Packet, addr netip.AddrPort) error { + nl, err := u.encodeSockaddr(pkt.Name, addr) + if err != nil { + return err + } + pkt.Name = pkt.Name[:nl] + pkt.OutLen = len(pkt.Payload) + return nil +} + +func (u *StdConn) WriteBatch(pkts []*packet.Packet) (int, error) { + if len(pkts) == 0 { + return 0, nil + } + + msgs := make([]rawMessage, 0, len(pkts)) //todo recycle + iovs := make([][]iovec, 0, len(pkts)) + + sent := 0 + + var mostRecentPkt *packet.Packet + //segmenting := false + idx := 0 + for _, pkt := range pkts { + if len(pkt.Payload) == 0 || pkt.OutLen == -1 { + sent++ + continue + } + lastIdx := idx - 1 + if mostRecentPkt != nil && pkt.CompatibleForSegmentationWith(mostRecentPkt) && msgs[lastIdx].Hdr.Iovlen < 4 { + + msgs[lastIdx].Hdr.Controllen = uint64(len(mostRecentPkt.Control)) + msgs[lastIdx].Hdr.Control = &mostRecentPkt.Control[0] + msgs[lastIdx].Hdr.Iovlen++ + iovs[lastIdx] = append(iovs[lastIdx], iovec{ + Base: &pkt.Payload[0], + Len: uint64(len(pkt.Payload)), + }) + mostRecentPkt.SetSegSizeForTX() + } else { + msgs = append(msgs, rawMessage{}) + iovs = append(iovs, make([]iovec, 1, 8)) //todo + iovs[idx][0] = iovec{ + Base: &pkt.Payload[0], + Len: uint64(len(pkt.Payload)), + } + + msg := &msgs[idx] + iov := &iovs[idx][0] + idx++ + + msg.Hdr.Iov = iov + msg.Hdr.Iovlen = 1 + setRawMessageControl(msg, nil) + msg.Hdr.Flags = 0 + + msg.Hdr.Name = &pkt.Name[0] + msg.Hdr.Namelen = uint32(len(pkt.Name)) + mostRecentPkt = pkt + } + + } + + if len(msgs) == 0 { + return sent, nil + } + + offset := 0 + for offset < len(msgs) { + n, _, errno := unix.Syscall6( + unix.SYS_SENDMMSG, + uintptr(u.sysFd), + uintptr(unsafe.Pointer(&msgs[offset])), + uintptr(len(msgs)-offset), + 0, + 0, + 0, + ) + + if errno != 0 { + if errno == unix.EINTR { + continue + } + return sent + offset, &net.OpError{Op: "sendmmsg", Err: errno} + } + + if n == 0 { + break + } + offset += int(n) + } + + return sent + len(msgs), nil +} + +func (u *StdConn) encodeSockaddr(dst []byte, addr netip.AddrPort) (uint32, error) { + if u.isV4 { + if !addr.Addr().Is4() { + return 0, fmt.Errorf("Listener is IPv4, but writing to IPv6 remote") + } + var sa unix.RawSockaddrInet4 + sa.Family = unix.AF_INET + sa.Addr = addr.Addr().As4() + binary.BigEndian.PutUint16((*[2]byte)(unsafe.Pointer(&sa.Port))[:], addr.Port()) + size := unix.SizeofSockaddrInet4 + copy(dst[:size], (*(*[unix.SizeofSockaddrInet4]byte)(unsafe.Pointer(&sa)))[:]) + return uint32(size), nil + } + + var sa unix.RawSockaddrInet6 + sa.Family = unix.AF_INET6 + sa.Addr = addr.Addr().As16() + binary.BigEndian.PutUint16((*[2]byte)(unsafe.Pointer(&sa.Port))[:], addr.Port()) + size := unix.SizeofSockaddrInet6 + copy(dst[:size], (*(*[unix.SizeofSockaddrInet6]byte)(unsafe.Pointer(&sa)))[:]) + return uint32(size), nil +} + func (u *StdConn) writeTo6(b []byte, ip netip.AddrPort) error { var rsa unix.RawSockaddrInet6 rsa.Family = unix.AF_INET6 diff --git a/udp/udp_linux_64.go b/udp/udp_linux_64.go index 8e778f0..e9e3ccb 100644 --- a/udp/udp_linux_64.go +++ b/udp/udp_linux_64.go @@ -80,3 +80,13 @@ func (u *StdConn) PrepareRawMessages(n int, isV4 bool) ([]rawMessage, []*packet. return msgs, packets } + +func setIovecSlice(iov *iovec, b []byte) { + if len(b) == 0 { + iov.Base = nil + iov.Len = 0 + return + } + iov.Base = &b[0] + iov.Len = uint64(len(b)) +}