why does it work

This commit is contained in:
JackDoan
2025-11-11 19:02:16 -06:00
parent 400fdace9d
commit 17a6917428
6 changed files with 251 additions and 9 deletions

View File

@@ -8,10 +8,11 @@ import (
"github.com/slackhq/nebula/header"
"github.com/slackhq/nebula/iputil"
"github.com/slackhq/nebula/noiseutil"
"github.com/slackhq/nebula/packet"
"github.com/slackhq/nebula/routing"
)
func (f *Interface) consumeInsidePacket(packet []byte, fwPacket *firewall.Packet, nb, out []byte, q int, localCache firewall.ConntrackCache) {
func (f *Interface) consumeInsidePacket(packet []byte, fwPacket *firewall.Packet, nb []byte, out *packet.Packet, q int, localCache firewall.ConntrackCache) {
err := newPacket(packet, false, fwPacket)
if err != nil {
if f.l.Level >= logrus.DebugLevel {
@@ -53,7 +54,7 @@ func (f *Interface) consumeInsidePacket(packet []byte, fwPacket *firewall.Packet
})
if hostinfo == nil {
f.rejectInside(packet, out, q)
f.rejectInside(packet, out.Payload, q) //todo vector?
if f.l.Level >= logrus.DebugLevel {
f.l.WithField("vpnAddr", fwPacket.RemoteAddr).
WithField("fwPacket", fwPacket).
@@ -68,10 +69,9 @@ func (f *Interface) consumeInsidePacket(packet []byte, fwPacket *firewall.Packet
dropReason := f.firewall.Drop(*fwPacket, false, hostinfo, f.pki.GetCAPool(), localCache)
if dropReason == nil {
f.sendNoMetrics(header.Message, 0, hostinfo.ConnectionState, hostinfo, netip.AddrPort{}, packet, nb, out, q)
f.sendNoMetricsDelayed(header.Message, 0, hostinfo.ConnectionState, hostinfo, netip.AddrPort{}, packet, nb, out, q)
} else {
f.rejectInside(packet, out, q)
f.rejectInside(packet, out.Payload, q) //todo vector?
if f.l.Level >= logrus.DebugLevel {
hostinfo.logger(f.l).
WithField("fwPacket", fwPacket).
@@ -410,3 +410,81 @@ func (f *Interface) sendNoMetrics(t header.MessageType, st header.MessageSubType
}
}
}
func (f *Interface) sendNoMetricsDelayed(t header.MessageType, st header.MessageSubType, ci *ConnectionState, hostinfo *HostInfo, remote netip.AddrPort, p, nb []byte, out *packet.Packet, q int) {
if ci.eKey == nil {
return
}
useRelay := !remote.IsValid() && !hostinfo.remote.IsValid()
fullOut := out.Payload
if useRelay {
if len(out.Payload) < header.Len {
// out always has a capacity of mtu, but not always a length greater than the header.Len.
// Grow it to make sure the next operation works.
out.Payload = out.Payload[:header.Len]
}
// Save a header's worth of data at the front of the 'out' buffer.
out.Payload = out.Payload[header.Len:]
}
if noiseutil.EncryptLockNeeded {
// NOTE: for goboring AESGCMTLS we need to lock because of the nonce check
ci.writeLock.Lock()
}
c := ci.messageCounter.Add(1)
//l.WithField("trace", string(debug.Stack())).Error("out Header ", &Header{Version, t, st, 0, hostinfo.remoteIndexId, c}, p)
out.Payload = header.Encode(out.Payload, header.Version, t, st, hostinfo.remoteIndexId, c)
f.connectionManager.Out(hostinfo)
// Query our LH if we haven't since the last time we've been rebound, this will cause the remote to punch against
// all our addrs and enable a faster roaming.
if t != header.CloseTunnel && hostinfo.lastRebindCount != f.rebindCount {
//NOTE: there is an update hole if a tunnel isn't used and exactly 256 rebinds occur before the tunnel is
// finally used again. This tunnel would eventually be torn down and recreated if this action didn't help.
f.lightHouse.QueryServer(hostinfo.vpnAddrs[0])
hostinfo.lastRebindCount = f.rebindCount
if f.l.Level >= logrus.DebugLevel {
f.l.WithField("vpnAddrs", hostinfo.vpnAddrs).Debug("Lighthouse update triggered for punch due to rebind counter")
}
}
var err error
out.Payload, err = ci.eKey.EncryptDanger(out.Payload, out.Payload, p, c, nb)
if noiseutil.EncryptLockNeeded {
ci.writeLock.Unlock()
}
if err != nil {
hostinfo.logger(f.l).WithError(err).
WithField("udpAddr", remote).WithField("counter", c).
WithField("attemptedCounter", c).
Error("Failed to encrypt outgoing packet")
return
}
if remote.IsValid() {
err = f.writers[q].Prep(out, remote)
if err != nil {
hostinfo.logger(f.l).WithError(err).WithField("udpAddr", remote).Error("Failed to write outgoing packet")
}
} else if hostinfo.remote.IsValid() {
err = f.writers[q].Prep(out, hostinfo.remote)
if err != nil {
hostinfo.logger(f.l).WithError(err).WithField("udpAddr", remote).Error("Failed to write outgoing packet")
}
} else {
// Try to send via a relay
for _, relayIP := range hostinfo.relayState.CopyRelayIps() {
relayHostInfo, relay, err := f.hostMap.QueryVpnAddrsRelayFor(hostinfo.vpnAddrs, relayIP)
if err != nil {
hostinfo.relayState.DeleteRelay(relayIP)
hostinfo.logger(f.l).WithField("relay", relayIP).WithError(err).Info("sendNoMetrics failed to find HostInfo")
continue
}
//todo vector!!
f.SendVia(relayHostInfo, relay, out.Payload, nb, fullOut[:header.Len+len(out.Payload)], true)
break
}
}
}

View File

@@ -318,15 +318,16 @@ func (f *Interface) listenIn(reader overlay.TunDev, queueNum int) {
for i := 0; i < batch; i++ {
originalPackets[i] = make([]byte, 0xffff)
}
out := make([]byte, mtu)
fwPacket := &firewall.Packet{}
nb := make([]byte, 12, 12)
conntrackCache := firewall.NewConntrackCacheTicker(f.conntrackCacheTimeout)
packets := make([]*packet.VirtIOPacket, batch)
outPackets := make([]*packet.Packet, batch)
for i := 0; i < batch; i++ {
packets[i] = packet.NewVIO()
outPackets[i] = packet.New(false) //todo?
}
for {
@@ -343,9 +344,13 @@ func (f *Interface) listenIn(reader overlay.TunDev, queueNum int) {
os.Exit(2)
}
//todo vectorize
for _, pkt := range packets[:n] {
f.consumeInsidePacket(pkt.Payload, fwPacket, nb, out, queueNum, conntrackCache.Get(f.l))
for i, pkt := range packets[:n] {
outPackets[i].OutLen = -1
f.consumeInsidePacket(pkt.Payload, fwPacket, nb, outPackets[i], queueNum, conntrackCache.Get(f.l))
}
_, err = f.writers[queueNum].WriteBatch(outPackets[:n])
if err != nil {
f.l.WithError(err).Error("Error while writing outbound packets")
}
}

View File

@@ -4,6 +4,8 @@ import (
"encoding/binary"
"iter"
"net/netip"
"syscall"
"unsafe"
"golang.org/x/sys/unix"
)
@@ -73,6 +75,32 @@ func (p *Packet) Update(ctrlLen int) {
p.updateCtrl(ctrlLen)
}
func (p *Packet) SetSegSizeForTX() {
p.SegSize = len(p.Payload)
hdr := (*unix.Cmsghdr)(unsafe.Pointer(&p.Control[0]))
hdr.Level = unix.SOL_UDP
hdr.Type = unix.UDP_SEGMENT
hdr.SetLen(syscall.CmsgLen(2))
//setCmsgLen(hdr, unix.CmsgLen(2))
binary.NativeEndian.PutUint16(p.Control[unix.CmsgLen(0):unix.CmsgLen(0)+2], uint16(p.SegSize))
//data := p.Control[syscall.CmsgSpace(0)-syscall.CmsgSpace(2)+syscall.SizeofCmsghdr:]
//binary.NativeEndian.PutUint16(data, uint16(p.SegSize))
}
func (p *Packet) CompatibleForSegmentationWith(otherP *Packet) bool {
//same dest
if p.AddrPort() != otherP.AddrPort() {
return false //todo more efficient?
}
//same body len
if len(p.Payload) != len(otherP.Payload) {
return false //todo technically you can cram one extra in
}
return true
}
func (p *Packet) Segments() iter.Seq[[]byte] {
return func(yield func([]byte) bool) {
//cursor := 0

View File

@@ -19,6 +19,8 @@ type Conn interface {
ListenOut(r EncReader)
WriteTo(b []byte, addr netip.AddrPort) error
ReloadConfig(c *config.C)
Prep(pkt *packet.Packet, addr netip.AddrPort) error
WriteBatch(pkt []*packet.Packet) (int, error)
Close() error
}

View File

@@ -14,6 +14,7 @@ import (
"github.com/rcrowley/go-metrics"
"github.com/sirupsen/logrus"
"github.com/slackhq/nebula/config"
"github.com/slackhq/nebula/packet"
"golang.org/x/sys/unix"
)
@@ -191,6 +192,124 @@ func (u *StdConn) WriteToBatch(b []byte, ip netip.AddrPort) error {
return u.writeTo6(b, ip)
}
func (u *StdConn) Prep(pkt *packet.Packet, addr netip.AddrPort) error {
nl, err := u.encodeSockaddr(pkt.Name, addr)
if err != nil {
return err
}
pkt.Name = pkt.Name[:nl]
pkt.OutLen = len(pkt.Payload)
return nil
}
func (u *StdConn) WriteBatch(pkts []*packet.Packet) (int, error) {
if len(pkts) == 0 {
return 0, nil
}
msgs := make([]rawMessage, 0, len(pkts)) //todo recycle
iovs := make([][]iovec, 0, len(pkts))
sent := 0
var mostRecentPkt *packet.Packet
//segmenting := false
idx := 0
for _, pkt := range pkts {
if len(pkt.Payload) == 0 || pkt.OutLen == -1 {
sent++
continue
}
lastIdx := idx - 1
if mostRecentPkt != nil && pkt.CompatibleForSegmentationWith(mostRecentPkt) && msgs[lastIdx].Hdr.Iovlen < 4 {
msgs[lastIdx].Hdr.Controllen = uint64(len(mostRecentPkt.Control))
msgs[lastIdx].Hdr.Control = &mostRecentPkt.Control[0]
msgs[lastIdx].Hdr.Iovlen++
iovs[lastIdx] = append(iovs[lastIdx], iovec{
Base: &pkt.Payload[0],
Len: uint64(len(pkt.Payload)),
})
mostRecentPkt.SetSegSizeForTX()
} else {
msgs = append(msgs, rawMessage{})
iovs = append(iovs, make([]iovec, 1, 8)) //todo
iovs[idx][0] = iovec{
Base: &pkt.Payload[0],
Len: uint64(len(pkt.Payload)),
}
msg := &msgs[idx]
iov := &iovs[idx][0]
idx++
msg.Hdr.Iov = iov
msg.Hdr.Iovlen = 1
setRawMessageControl(msg, nil)
msg.Hdr.Flags = 0
msg.Hdr.Name = &pkt.Name[0]
msg.Hdr.Namelen = uint32(len(pkt.Name))
mostRecentPkt = pkt
}
}
if len(msgs) == 0 {
return sent, nil
}
offset := 0
for offset < len(msgs) {
n, _, errno := unix.Syscall6(
unix.SYS_SENDMMSG,
uintptr(u.sysFd),
uintptr(unsafe.Pointer(&msgs[offset])),
uintptr(len(msgs)-offset),
0,
0,
0,
)
if errno != 0 {
if errno == unix.EINTR {
continue
}
return sent + offset, &net.OpError{Op: "sendmmsg", Err: errno}
}
if n == 0 {
break
}
offset += int(n)
}
return sent + len(msgs), nil
}
func (u *StdConn) encodeSockaddr(dst []byte, addr netip.AddrPort) (uint32, error) {
if u.isV4 {
if !addr.Addr().Is4() {
return 0, fmt.Errorf("Listener is IPv4, but writing to IPv6 remote")
}
var sa unix.RawSockaddrInet4
sa.Family = unix.AF_INET
sa.Addr = addr.Addr().As4()
binary.BigEndian.PutUint16((*[2]byte)(unsafe.Pointer(&sa.Port))[:], addr.Port())
size := unix.SizeofSockaddrInet4
copy(dst[:size], (*(*[unix.SizeofSockaddrInet4]byte)(unsafe.Pointer(&sa)))[:])
return uint32(size), nil
}
var sa unix.RawSockaddrInet6
sa.Family = unix.AF_INET6
sa.Addr = addr.Addr().As16()
binary.BigEndian.PutUint16((*[2]byte)(unsafe.Pointer(&sa.Port))[:], addr.Port())
size := unix.SizeofSockaddrInet6
copy(dst[:size], (*(*[unix.SizeofSockaddrInet6]byte)(unsafe.Pointer(&sa)))[:])
return uint32(size), nil
}
func (u *StdConn) writeTo6(b []byte, ip netip.AddrPort) error {
var rsa unix.RawSockaddrInet6
rsa.Family = unix.AF_INET6

View File

@@ -80,3 +80,13 @@ func (u *StdConn) PrepareRawMessages(n int, isV4 bool) ([]rawMessage, []*packet.
return msgs, packets
}
func setIovecSlice(iov *iovec, b []byte) {
if len(b) == 0 {
iov.Base = nil
iov.Len = 0
return
}
iov.Base = &b[0]
iov.Len = uint64(len(b))
}