mirror of
https://github.com/slackhq/nebula.git
synced 2025-11-22 08:24:25 +01:00
why does it work
This commit is contained in:
88
inside.go
88
inside.go
@@ -8,10 +8,11 @@ import (
|
||||
"github.com/slackhq/nebula/header"
|
||||
"github.com/slackhq/nebula/iputil"
|
||||
"github.com/slackhq/nebula/noiseutil"
|
||||
"github.com/slackhq/nebula/packet"
|
||||
"github.com/slackhq/nebula/routing"
|
||||
)
|
||||
|
||||
func (f *Interface) consumeInsidePacket(packet []byte, fwPacket *firewall.Packet, nb, out []byte, q int, localCache firewall.ConntrackCache) {
|
||||
func (f *Interface) consumeInsidePacket(packet []byte, fwPacket *firewall.Packet, nb []byte, out *packet.Packet, q int, localCache firewall.ConntrackCache) {
|
||||
err := newPacket(packet, false, fwPacket)
|
||||
if err != nil {
|
||||
if f.l.Level >= logrus.DebugLevel {
|
||||
@@ -53,7 +54,7 @@ func (f *Interface) consumeInsidePacket(packet []byte, fwPacket *firewall.Packet
|
||||
})
|
||||
|
||||
if hostinfo == nil {
|
||||
f.rejectInside(packet, out, q)
|
||||
f.rejectInside(packet, out.Payload, q) //todo vector?
|
||||
if f.l.Level >= logrus.DebugLevel {
|
||||
f.l.WithField("vpnAddr", fwPacket.RemoteAddr).
|
||||
WithField("fwPacket", fwPacket).
|
||||
@@ -68,10 +69,9 @@ func (f *Interface) consumeInsidePacket(packet []byte, fwPacket *firewall.Packet
|
||||
|
||||
dropReason := f.firewall.Drop(*fwPacket, false, hostinfo, f.pki.GetCAPool(), localCache)
|
||||
if dropReason == nil {
|
||||
f.sendNoMetrics(header.Message, 0, hostinfo.ConnectionState, hostinfo, netip.AddrPort{}, packet, nb, out, q)
|
||||
|
||||
f.sendNoMetricsDelayed(header.Message, 0, hostinfo.ConnectionState, hostinfo, netip.AddrPort{}, packet, nb, out, q)
|
||||
} else {
|
||||
f.rejectInside(packet, out, q)
|
||||
f.rejectInside(packet, out.Payload, q) //todo vector?
|
||||
if f.l.Level >= logrus.DebugLevel {
|
||||
hostinfo.logger(f.l).
|
||||
WithField("fwPacket", fwPacket).
|
||||
@@ -410,3 +410,81 @@ func (f *Interface) sendNoMetrics(t header.MessageType, st header.MessageSubType
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (f *Interface) sendNoMetricsDelayed(t header.MessageType, st header.MessageSubType, ci *ConnectionState, hostinfo *HostInfo, remote netip.AddrPort, p, nb []byte, out *packet.Packet, q int) {
|
||||
if ci.eKey == nil {
|
||||
return
|
||||
}
|
||||
useRelay := !remote.IsValid() && !hostinfo.remote.IsValid()
|
||||
fullOut := out.Payload
|
||||
|
||||
if useRelay {
|
||||
if len(out.Payload) < header.Len {
|
||||
// out always has a capacity of mtu, but not always a length greater than the header.Len.
|
||||
// Grow it to make sure the next operation works.
|
||||
out.Payload = out.Payload[:header.Len]
|
||||
}
|
||||
// Save a header's worth of data at the front of the 'out' buffer.
|
||||
out.Payload = out.Payload[header.Len:]
|
||||
}
|
||||
|
||||
if noiseutil.EncryptLockNeeded {
|
||||
// NOTE: for goboring AESGCMTLS we need to lock because of the nonce check
|
||||
ci.writeLock.Lock()
|
||||
}
|
||||
c := ci.messageCounter.Add(1)
|
||||
|
||||
//l.WithField("trace", string(debug.Stack())).Error("out Header ", &Header{Version, t, st, 0, hostinfo.remoteIndexId, c}, p)
|
||||
out.Payload = header.Encode(out.Payload, header.Version, t, st, hostinfo.remoteIndexId, c)
|
||||
f.connectionManager.Out(hostinfo)
|
||||
|
||||
// Query our LH if we haven't since the last time we've been rebound, this will cause the remote to punch against
|
||||
// all our addrs and enable a faster roaming.
|
||||
if t != header.CloseTunnel && hostinfo.lastRebindCount != f.rebindCount {
|
||||
//NOTE: there is an update hole if a tunnel isn't used and exactly 256 rebinds occur before the tunnel is
|
||||
// finally used again. This tunnel would eventually be torn down and recreated if this action didn't help.
|
||||
f.lightHouse.QueryServer(hostinfo.vpnAddrs[0])
|
||||
hostinfo.lastRebindCount = f.rebindCount
|
||||
if f.l.Level >= logrus.DebugLevel {
|
||||
f.l.WithField("vpnAddrs", hostinfo.vpnAddrs).Debug("Lighthouse update triggered for punch due to rebind counter")
|
||||
}
|
||||
}
|
||||
|
||||
var err error
|
||||
out.Payload, err = ci.eKey.EncryptDanger(out.Payload, out.Payload, p, c, nb)
|
||||
if noiseutil.EncryptLockNeeded {
|
||||
ci.writeLock.Unlock()
|
||||
}
|
||||
if err != nil {
|
||||
hostinfo.logger(f.l).WithError(err).
|
||||
WithField("udpAddr", remote).WithField("counter", c).
|
||||
WithField("attemptedCounter", c).
|
||||
Error("Failed to encrypt outgoing packet")
|
||||
return
|
||||
}
|
||||
|
||||
if remote.IsValid() {
|
||||
err = f.writers[q].Prep(out, remote)
|
||||
if err != nil {
|
||||
hostinfo.logger(f.l).WithError(err).WithField("udpAddr", remote).Error("Failed to write outgoing packet")
|
||||
}
|
||||
} else if hostinfo.remote.IsValid() {
|
||||
err = f.writers[q].Prep(out, hostinfo.remote)
|
||||
if err != nil {
|
||||
hostinfo.logger(f.l).WithError(err).WithField("udpAddr", remote).Error("Failed to write outgoing packet")
|
||||
}
|
||||
} else {
|
||||
// Try to send via a relay
|
||||
for _, relayIP := range hostinfo.relayState.CopyRelayIps() {
|
||||
relayHostInfo, relay, err := f.hostMap.QueryVpnAddrsRelayFor(hostinfo.vpnAddrs, relayIP)
|
||||
if err != nil {
|
||||
hostinfo.relayState.DeleteRelay(relayIP)
|
||||
hostinfo.logger(f.l).WithField("relay", relayIP).WithError(err).Info("sendNoMetrics failed to find HostInfo")
|
||||
continue
|
||||
}
|
||||
//todo vector!!
|
||||
f.SendVia(relayHostInfo, relay, out.Payload, nb, fullOut[:header.Len+len(out.Payload)], true)
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
13
interface.go
13
interface.go
@@ -318,15 +318,16 @@ func (f *Interface) listenIn(reader overlay.TunDev, queueNum int) {
|
||||
for i := 0; i < batch; i++ {
|
||||
originalPackets[i] = make([]byte, 0xffff)
|
||||
}
|
||||
out := make([]byte, mtu)
|
||||
fwPacket := &firewall.Packet{}
|
||||
nb := make([]byte, 12, 12)
|
||||
|
||||
conntrackCache := firewall.NewConntrackCacheTicker(f.conntrackCacheTimeout)
|
||||
|
||||
packets := make([]*packet.VirtIOPacket, batch)
|
||||
outPackets := make([]*packet.Packet, batch)
|
||||
for i := 0; i < batch; i++ {
|
||||
packets[i] = packet.NewVIO()
|
||||
outPackets[i] = packet.New(false) //todo?
|
||||
}
|
||||
|
||||
for {
|
||||
@@ -343,9 +344,13 @@ func (f *Interface) listenIn(reader overlay.TunDev, queueNum int) {
|
||||
os.Exit(2)
|
||||
}
|
||||
|
||||
//todo vectorize
|
||||
for _, pkt := range packets[:n] {
|
||||
f.consumeInsidePacket(pkt.Payload, fwPacket, nb, out, queueNum, conntrackCache.Get(f.l))
|
||||
for i, pkt := range packets[:n] {
|
||||
outPackets[i].OutLen = -1
|
||||
f.consumeInsidePacket(pkt.Payload, fwPacket, nb, outPackets[i], queueNum, conntrackCache.Get(f.l))
|
||||
}
|
||||
_, err = f.writers[queueNum].WriteBatch(outPackets[:n])
|
||||
if err != nil {
|
||||
f.l.WithError(err).Error("Error while writing outbound packets")
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
@@ -4,6 +4,8 @@ import (
|
||||
"encoding/binary"
|
||||
"iter"
|
||||
"net/netip"
|
||||
"syscall"
|
||||
"unsafe"
|
||||
|
||||
"golang.org/x/sys/unix"
|
||||
)
|
||||
@@ -73,6 +75,32 @@ func (p *Packet) Update(ctrlLen int) {
|
||||
p.updateCtrl(ctrlLen)
|
||||
}
|
||||
|
||||
func (p *Packet) SetSegSizeForTX() {
|
||||
p.SegSize = len(p.Payload)
|
||||
hdr := (*unix.Cmsghdr)(unsafe.Pointer(&p.Control[0]))
|
||||
hdr.Level = unix.SOL_UDP
|
||||
hdr.Type = unix.UDP_SEGMENT
|
||||
hdr.SetLen(syscall.CmsgLen(2))
|
||||
//setCmsgLen(hdr, unix.CmsgLen(2))
|
||||
binary.NativeEndian.PutUint16(p.Control[unix.CmsgLen(0):unix.CmsgLen(0)+2], uint16(p.SegSize))
|
||||
//data := p.Control[syscall.CmsgSpace(0)-syscall.CmsgSpace(2)+syscall.SizeofCmsghdr:]
|
||||
//binary.NativeEndian.PutUint16(data, uint16(p.SegSize))
|
||||
}
|
||||
|
||||
func (p *Packet) CompatibleForSegmentationWith(otherP *Packet) bool {
|
||||
//same dest
|
||||
|
||||
if p.AddrPort() != otherP.AddrPort() {
|
||||
return false //todo more efficient?
|
||||
}
|
||||
|
||||
//same body len
|
||||
if len(p.Payload) != len(otherP.Payload) {
|
||||
return false //todo technically you can cram one extra in
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
func (p *Packet) Segments() iter.Seq[[]byte] {
|
||||
return func(yield func([]byte) bool) {
|
||||
//cursor := 0
|
||||
|
||||
@@ -19,6 +19,8 @@ type Conn interface {
|
||||
ListenOut(r EncReader)
|
||||
WriteTo(b []byte, addr netip.AddrPort) error
|
||||
ReloadConfig(c *config.C)
|
||||
Prep(pkt *packet.Packet, addr netip.AddrPort) error
|
||||
WriteBatch(pkt []*packet.Packet) (int, error)
|
||||
Close() error
|
||||
}
|
||||
|
||||
|
||||
119
udp/udp_linux.go
119
udp/udp_linux.go
@@ -14,6 +14,7 @@ import (
|
||||
"github.com/rcrowley/go-metrics"
|
||||
"github.com/sirupsen/logrus"
|
||||
"github.com/slackhq/nebula/config"
|
||||
"github.com/slackhq/nebula/packet"
|
||||
"golang.org/x/sys/unix"
|
||||
)
|
||||
|
||||
@@ -191,6 +192,124 @@ func (u *StdConn) WriteToBatch(b []byte, ip netip.AddrPort) error {
|
||||
return u.writeTo6(b, ip)
|
||||
}
|
||||
|
||||
func (u *StdConn) Prep(pkt *packet.Packet, addr netip.AddrPort) error {
|
||||
nl, err := u.encodeSockaddr(pkt.Name, addr)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
pkt.Name = pkt.Name[:nl]
|
||||
pkt.OutLen = len(pkt.Payload)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (u *StdConn) WriteBatch(pkts []*packet.Packet) (int, error) {
|
||||
if len(pkts) == 0 {
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
msgs := make([]rawMessage, 0, len(pkts)) //todo recycle
|
||||
iovs := make([][]iovec, 0, len(pkts))
|
||||
|
||||
sent := 0
|
||||
|
||||
var mostRecentPkt *packet.Packet
|
||||
//segmenting := false
|
||||
idx := 0
|
||||
for _, pkt := range pkts {
|
||||
if len(pkt.Payload) == 0 || pkt.OutLen == -1 {
|
||||
sent++
|
||||
continue
|
||||
}
|
||||
lastIdx := idx - 1
|
||||
if mostRecentPkt != nil && pkt.CompatibleForSegmentationWith(mostRecentPkt) && msgs[lastIdx].Hdr.Iovlen < 4 {
|
||||
|
||||
msgs[lastIdx].Hdr.Controllen = uint64(len(mostRecentPkt.Control))
|
||||
msgs[lastIdx].Hdr.Control = &mostRecentPkt.Control[0]
|
||||
msgs[lastIdx].Hdr.Iovlen++
|
||||
iovs[lastIdx] = append(iovs[lastIdx], iovec{
|
||||
Base: &pkt.Payload[0],
|
||||
Len: uint64(len(pkt.Payload)),
|
||||
})
|
||||
mostRecentPkt.SetSegSizeForTX()
|
||||
} else {
|
||||
msgs = append(msgs, rawMessage{})
|
||||
iovs = append(iovs, make([]iovec, 1, 8)) //todo
|
||||
iovs[idx][0] = iovec{
|
||||
Base: &pkt.Payload[0],
|
||||
Len: uint64(len(pkt.Payload)),
|
||||
}
|
||||
|
||||
msg := &msgs[idx]
|
||||
iov := &iovs[idx][0]
|
||||
idx++
|
||||
|
||||
msg.Hdr.Iov = iov
|
||||
msg.Hdr.Iovlen = 1
|
||||
setRawMessageControl(msg, nil)
|
||||
msg.Hdr.Flags = 0
|
||||
|
||||
msg.Hdr.Name = &pkt.Name[0]
|
||||
msg.Hdr.Namelen = uint32(len(pkt.Name))
|
||||
mostRecentPkt = pkt
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
if len(msgs) == 0 {
|
||||
return sent, nil
|
||||
}
|
||||
|
||||
offset := 0
|
||||
for offset < len(msgs) {
|
||||
n, _, errno := unix.Syscall6(
|
||||
unix.SYS_SENDMMSG,
|
||||
uintptr(u.sysFd),
|
||||
uintptr(unsafe.Pointer(&msgs[offset])),
|
||||
uintptr(len(msgs)-offset),
|
||||
0,
|
||||
0,
|
||||
0,
|
||||
)
|
||||
|
||||
if errno != 0 {
|
||||
if errno == unix.EINTR {
|
||||
continue
|
||||
}
|
||||
return sent + offset, &net.OpError{Op: "sendmmsg", Err: errno}
|
||||
}
|
||||
|
||||
if n == 0 {
|
||||
break
|
||||
}
|
||||
offset += int(n)
|
||||
}
|
||||
|
||||
return sent + len(msgs), nil
|
||||
}
|
||||
|
||||
func (u *StdConn) encodeSockaddr(dst []byte, addr netip.AddrPort) (uint32, error) {
|
||||
if u.isV4 {
|
||||
if !addr.Addr().Is4() {
|
||||
return 0, fmt.Errorf("Listener is IPv4, but writing to IPv6 remote")
|
||||
}
|
||||
var sa unix.RawSockaddrInet4
|
||||
sa.Family = unix.AF_INET
|
||||
sa.Addr = addr.Addr().As4()
|
||||
binary.BigEndian.PutUint16((*[2]byte)(unsafe.Pointer(&sa.Port))[:], addr.Port())
|
||||
size := unix.SizeofSockaddrInet4
|
||||
copy(dst[:size], (*(*[unix.SizeofSockaddrInet4]byte)(unsafe.Pointer(&sa)))[:])
|
||||
return uint32(size), nil
|
||||
}
|
||||
|
||||
var sa unix.RawSockaddrInet6
|
||||
sa.Family = unix.AF_INET6
|
||||
sa.Addr = addr.Addr().As16()
|
||||
binary.BigEndian.PutUint16((*[2]byte)(unsafe.Pointer(&sa.Port))[:], addr.Port())
|
||||
size := unix.SizeofSockaddrInet6
|
||||
copy(dst[:size], (*(*[unix.SizeofSockaddrInet6]byte)(unsafe.Pointer(&sa)))[:])
|
||||
return uint32(size), nil
|
||||
}
|
||||
|
||||
func (u *StdConn) writeTo6(b []byte, ip netip.AddrPort) error {
|
||||
var rsa unix.RawSockaddrInet6
|
||||
rsa.Family = unix.AF_INET6
|
||||
|
||||
@@ -80,3 +80,13 @@ func (u *StdConn) PrepareRawMessages(n int, isV4 bool) ([]rawMessage, []*packet.
|
||||
|
||||
return msgs, packets
|
||||
}
|
||||
|
||||
func setIovecSlice(iov *iovec, b []byte) {
|
||||
if len(b) == 0 {
|
||||
iov.Base = nil
|
||||
iov.Len = 0
|
||||
return
|
||||
}
|
||||
iov.Base = &b[0]
|
||||
iov.Len = uint64(len(b))
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user