hmm yes time

This commit is contained in:
JackDoan
2025-11-11 21:14:24 -06:00
parent 685ac3e112
commit c6bee8e981
6 changed files with 46 additions and 83 deletions

View File

@@ -425,9 +425,9 @@ var ErrNoMatchingRule = errors.New("no matching rule in firewall table")
// Drop returns an error if the packet should be dropped, explaining why. It
// returns nil if the packet should not be dropped.
func (f *Firewall) Drop(fp firewall.Packet, incoming bool, h *HostInfo, caPool *cert.CAPool, localCache firewall.ConntrackCache) error {
func (f *Firewall) Drop(fp firewall.Packet, incoming bool, h *HostInfo, caPool *cert.CAPool, localCache firewall.ConntrackCache, now time.Time) error {
// Check if we spoke to this tuple, if we did then allow this packet
if f.inConns(fp, h, caPool, localCache) {
if f.inConns(fp, h, caPool, localCache, now) {
return nil
}
@@ -476,7 +476,7 @@ func (f *Firewall) Drop(fp firewall.Packet, incoming bool, h *HostInfo, caPool *
}
// We always want to conntrack since it is a faster operation
f.addConn(fp, incoming)
f.addConn(fp, incoming, now)
return nil
}
@@ -505,7 +505,7 @@ func (f *Firewall) EmitStats() {
metrics.GetOrRegisterGauge("firewall.rules.hash", nil).Update(int64(f.GetRuleHashFNV()))
}
func (f *Firewall) inConns(fp firewall.Packet, h *HostInfo, caPool *cert.CAPool, localCache firewall.ConntrackCache) bool {
func (f *Firewall) inConns(fp firewall.Packet, h *HostInfo, caPool *cert.CAPool, localCache firewall.ConntrackCache, now time.Time) bool {
if localCache != nil {
if _, ok := localCache[fp]; ok {
return true
@@ -517,7 +517,7 @@ func (f *Firewall) inConns(fp firewall.Packet, h *HostInfo, caPool *cert.CAPool,
// Purge every time we test
ep, has := conntrack.TimerWheel.Purge()
if has {
f.evict(ep)
f.evict(ep, now)
}
c, ok := conntrack.Conns[fp]
@@ -564,11 +564,11 @@ func (f *Firewall) inConns(fp firewall.Packet, h *HostInfo, caPool *cert.CAPool,
switch fp.Protocol {
case firewall.ProtoTCP:
c.Expires = time.Now().Add(f.TCPTimeout)
c.Expires = now.Add(f.TCPTimeout)
case firewall.ProtoUDP:
c.Expires = time.Now().Add(f.UDPTimeout)
c.Expires = now.Add(f.UDPTimeout)
default:
c.Expires = time.Now().Add(f.DefaultTimeout)
c.Expires = now.Add(f.DefaultTimeout)
}
conntrack.Unlock()
@@ -580,7 +580,7 @@ func (f *Firewall) inConns(fp firewall.Packet, h *HostInfo, caPool *cert.CAPool,
return true
}
func (f *Firewall) addConn(fp firewall.Packet, incoming bool) {
func (f *Firewall) addConn(fp firewall.Packet, incoming bool, now time.Time) {
var timeout time.Duration
c := &conn{}
@@ -596,7 +596,7 @@ func (f *Firewall) addConn(fp firewall.Packet, incoming bool) {
conntrack := f.Conntrack
conntrack.Lock()
if _, ok := conntrack.Conns[fp]; !ok {
conntrack.TimerWheel.Advance(time.Now())
conntrack.TimerWheel.Advance(now)
conntrack.TimerWheel.Add(fp, timeout)
}
@@ -604,14 +604,14 @@ func (f *Firewall) addConn(fp firewall.Packet, incoming bool) {
// firewall reload
c.incoming = incoming
c.rulesVersion = f.rulesVersion
c.Expires = time.Now().Add(timeout)
c.Expires = now.Add(timeout)
conntrack.Conns[fp] = c
conntrack.Unlock()
}
// Evict checks if a conntrack entry has expired, if so it is removed, if not it is re-added to the wheel
// Caller must own the connMutex lock!
func (f *Firewall) evict(p firewall.Packet) {
func (f *Firewall) evict(p firewall.Packet, now time.Time) {
// Are we still tracking this conn?
conntrack := f.Conntrack
t, ok := conntrack.Conns[p]
@@ -619,11 +619,11 @@ func (f *Firewall) evict(p firewall.Packet) {
return
}
newT := t.Expires.Sub(time.Now())
newT := t.Expires.Sub(now)
// Timeout is in the future, re-add the timer
if newT > 0 {
conntrack.TimerWheel.Advance(time.Now())
conntrack.TimerWheel.Advance(now)
conntrack.TimerWheel.Add(p, newT)
return
}

View File

@@ -2,6 +2,7 @@ package nebula
import (
"net/netip"
"time"
"github.com/sirupsen/logrus"
"github.com/slackhq/nebula/firewall"
@@ -12,7 +13,7 @@ import (
"github.com/slackhq/nebula/routing"
)
func (f *Interface) consumeInsidePacket(packet []byte, fwPacket *firewall.Packet, nb []byte, out *packet.Packet, q int, localCache firewall.ConntrackCache) {
func (f *Interface) consumeInsidePacket(packet []byte, fwPacket *firewall.Packet, nb []byte, out *packet.Packet, q int, localCache firewall.ConntrackCache, now time.Time) {
err := newPacket(packet, false, fwPacket)
if err != nil {
if f.l.Level >= logrus.DebugLevel {
@@ -67,7 +68,7 @@ func (f *Interface) consumeInsidePacket(packet []byte, fwPacket *firewall.Packet
return
}
dropReason := f.firewall.Drop(*fwPacket, false, hostinfo, f.pki.GetCAPool(), localCache)
dropReason := f.firewall.Drop(*fwPacket, false, hostinfo, f.pki.GetCAPool(), localCache, now)
if dropReason == nil {
f.sendNoMetricsDelayed(header.Message, 0, hostinfo.ConnectionState, hostinfo, netip.AddrPort{}, packet, nb, out, q)
} else {
@@ -218,7 +219,7 @@ func (f *Interface) sendMessageNow(t header.MessageType, st header.MessageSubTyp
}
// check if packet is in outbound fw rules
dropReason := f.firewall.Drop(*fp, false, hostinfo, f.pki.GetCAPool(), nil)
dropReason := f.firewall.Drop(*fp, false, hostinfo, f.pki.GetCAPool(), nil, time.Now())
if dropReason != nil {
if f.l.Level >= logrus.DebugLevel {
f.l.WithField("fwPacket", fp).

View File

@@ -287,7 +287,7 @@ func (f *Interface) listenOut(q int) {
outPackets[i].SegCounter = 0
}
f.readOutsidePacketsMany(pkts, outPackets, h, fwPacket, lhh, nb, q, ctCache.Get(f.l))
f.readOutsidePacketsMany(pkts, outPackets, h, fwPacket, lhh, nb, q, ctCache.Get(f.l), time.Now())
for i := range pkts {
if pkts[i].OutLen != -1 {
for j := 0; j < outPackets[i].SegCounter; j++ {
@@ -344,9 +344,10 @@ func (f *Interface) listenIn(reader overlay.TunDev, queueNum int) {
os.Exit(2)
}
now := time.Now()
for i, pkt := range packets[:n] {
outPackets[i].OutLen = -1
f.consumeInsidePacket(pkt.Payload, fwPacket, nb, outPackets[i], queueNum, conntrackCache.Get(f.l))
f.consumeInsidePacket(pkt.Payload, fwPacket, nb, outPackets[i], queueNum, conntrackCache.Get(f.l), now)
}
_, err = f.writers[queueNum].WriteBatch(outPackets[:n])
if err != nil {

View File

@@ -20,7 +20,7 @@ const (
minFwPacketLen = 4
)
func (f *Interface) readOutsidePackets(ip netip.AddrPort, via *ViaSender, out []byte, packet []byte, h *header.H, fwPacket *firewall.Packet, lhf *LightHouseHandler, nb []byte, q int, localCache firewall.ConntrackCache) {
func (f *Interface) readOutsidePackets(ip netip.AddrPort, via *ViaSender, out []byte, packet []byte, h *header.H, fwPacket *firewall.Packet, lhf *LightHouseHandler, nb []byte, q int, localCache firewall.ConntrackCache, now time.Time) {
err := h.Parse(packet)
if err != nil {
// Hole punch packets are 0 or 1 byte big, so lets ignore printing those errors
@@ -62,7 +62,7 @@ func (f *Interface) readOutsidePackets(ip netip.AddrPort, via *ViaSender, out []
switch h.Subtype {
case header.MessageNone:
if !f.decryptToTun(hostinfo, h.MessageCounter, out, packet, fwPacket, nb, q, localCache) {
if !f.decryptToTun(hostinfo, h.MessageCounter, out, packet, fwPacket, nb, q, localCache, now) {
return
}
case header.MessageRelay:
@@ -97,7 +97,7 @@ func (f *Interface) readOutsidePackets(ip netip.AddrPort, via *ViaSender, out []
case TerminalType:
// If I am the target of this relay, process the unwrapped packet
// From this recursive point, all these variables are 'burned'. We shouldn't rely on them again.
f.readOutsidePackets(netip.AddrPort{}, &ViaSender{relayHI: hostinfo, remoteIdx: relay.RemoteIndex, relay: relay}, out[:0], signedPayload, h, fwPacket, lhf, nb, q, localCache)
f.readOutsidePackets(netip.AddrPort{}, &ViaSender{relayHI: hostinfo, remoteIdx: relay.RemoteIndex, relay: relay}, out[:0], signedPayload, h, fwPacket, lhf, nb, q, localCache, now)
return
case ForwardingType:
// Find the target HostInfo relay object
@@ -217,7 +217,7 @@ func (f *Interface) readOutsidePackets(ip netip.AddrPort, via *ViaSender, out []
f.connectionManager.In(hostinfo)
}
func (f *Interface) readOutsidePacketsMany(packets []*packet.Packet, out []*packet.OutPacket, h *header.H, fwPacket *firewall.Packet, lhf *LightHouseHandler, nb []byte, q int, localCache firewall.ConntrackCache) {
func (f *Interface) readOutsidePacketsMany(packets []*packet.Packet, out []*packet.OutPacket, h *header.H, fwPacket *firewall.Packet, lhf *LightHouseHandler, nb []byte, q int, localCache firewall.ConntrackCache, now time.Time) {
for i, pkt := range packets {
out[i].Scratch = out[i].Scratch[:0]
ip := pkt.AddrPort()
@@ -266,7 +266,7 @@ func (f *Interface) readOutsidePacketsMany(packets []*packet.Packet, out []*pack
switch h.Subtype {
case header.MessageNone:
if !f.decryptToTunDelayWrite(hostinfo, h.MessageCounter, out[i], pkt, segment, fwPacket, nb, q, localCache) {
if !f.decryptToTunDelayWrite(hostinfo, h.MessageCounter, out[i], pkt, segment, fwPacket, nb, q, localCache, now) {
return
}
case header.MessageRelay:
@@ -301,7 +301,7 @@ func (f *Interface) readOutsidePacketsMany(packets []*packet.Packet, out []*pack
case TerminalType:
// If I am the target of this relay, process the unwrapped packet
// From this recursive point, all these variables are 'burned'. We shouldn't rely on them again.
f.readOutsidePackets(netip.AddrPort{}, &ViaSender{relayHI: hostinfo, remoteIdx: relay.RemoteIndex, relay: relay}, out[i].Scratch[:0], signedPayload, h, fwPacket, lhf, nb, q, localCache)
f.readOutsidePackets(netip.AddrPort{}, &ViaSender{relayHI: hostinfo, remoteIdx: relay.RemoteIndex, relay: relay}, out[i].Scratch[:0], signedPayload, h, fwPacket, lhf, nb, q, localCache, now)
return
case ForwardingType:
// Find the target HostInfo relay object
@@ -672,7 +672,7 @@ func (f *Interface) decrypt(hostinfo *HostInfo, mc uint64, out []byte, packet []
return out, nil
}
func (f *Interface) decryptToTunDelayWrite(hostinfo *HostInfo, messageCounter uint64, out *packet.OutPacket, pkt *packet.Packet, inSegment []byte, fwPacket *firewall.Packet, nb []byte, q int, localCache firewall.ConntrackCache) bool {
func (f *Interface) decryptToTunDelayWrite(hostinfo *HostInfo, messageCounter uint64, out *packet.OutPacket, pkt *packet.Packet, inSegment []byte, fwPacket *firewall.Packet, nb []byte, q int, localCache firewall.ConntrackCache, now time.Time) bool {
var err error
out.Segments[out.SegCounter] = out.Segments[out.SegCounter][:0]
@@ -695,7 +695,7 @@ func (f *Interface) decryptToTunDelayWrite(hostinfo *HostInfo, messageCounter ui
return false
}
dropReason := f.firewall.Drop(*fwPacket, true, hostinfo, f.pki.GetCAPool(), localCache)
dropReason := f.firewall.Drop(*fwPacket, true, hostinfo, f.pki.GetCAPool(), localCache, now)
if dropReason != nil {
// NOTE: We give `packet` as the `out` here since we already decrypted from it and we don't need it anymore
// This gives us a buffer to build the reject packet in
@@ -714,7 +714,7 @@ func (f *Interface) decryptToTunDelayWrite(hostinfo *HostInfo, messageCounter ui
return true
}
func (f *Interface) decryptToTun(hostinfo *HostInfo, messageCounter uint64, out []byte, packet []byte, fwPacket *firewall.Packet, nb []byte, q int, localCache firewall.ConntrackCache) bool {
func (f *Interface) decryptToTun(hostinfo *HostInfo, messageCounter uint64, out []byte, packet []byte, fwPacket *firewall.Packet, nb []byte, q int, localCache firewall.ConntrackCache, now time.Time) bool {
var err error
out, err = hostinfo.ConnectionState.dKey.DecryptDanger(out, packet[:header.Len], packet[header.Len:], messageCounter, nb)
@@ -736,7 +736,7 @@ func (f *Interface) decryptToTun(hostinfo *HostInfo, messageCounter uint64, out
return false
}
dropReason := f.firewall.Drop(*fwPacket, true, hostinfo, f.pki.GetCAPool(), localCache)
dropReason := f.firewall.Drop(*fwPacket, true, hostinfo, f.pki.GetCAPool(), localCache, now)
if dropReason != nil {
// NOTE: We give `packet` as the `out` here since we already decrypted from it and we don't need it anymore
// This gives us a buffer to build the reject packet in

View File

@@ -6,6 +6,7 @@ import (
"fmt"
"os"
"runtime"
"slices"
"github.com/slackhq/nebula/overlay/vhost"
"github.com/slackhq/nebula/overlay/virtqueue"
@@ -249,10 +250,6 @@ func (dev *Device) TransmitPackets(vnethdr virtio.NetHdr, packets [][]byte) erro
return fmt.Errorf("offer descriptor chain: %w", err)
}
//todo surely there's something better to do here
doneYet := map[uint16]bool{}
for _, chain := range chainIndexes {
doneYet[chain] = false
}
for {
txedChains, err := dev.TransmitQueue.BlockAndGetHeads(context.TODO())
@@ -261,31 +258,27 @@ func (dev *Device) TransmitPackets(vnethdr virtio.NetHdr, packets [][]byte) erro
} else if len(txedChains) == 0 {
continue //todo will this ever exit?
}
for c := range txedChains {
doneYet[txedChains[c].GetHead()] = true
for _, c := range txedChains {
idx := slices.Index(chainIndexes, c.GetHead())
if idx < 0 {
continue
} else {
_ = dev.TransmitQueue.FreeDescriptorChain(chainIndexes[idx])
chainIndexes[idx] = 0 //todo I hope this works
}
}
done := true //optimism!
for _, x := range doneYet {
if !x {
for _, x := range chainIndexes {
if x != 0 {
done = false
break
}
}
if done {
break
}
}
// Wait for the packet to have been transmitted.
for i := range chainIndexes {
if err = dev.TransmitQueue.FreeDescriptorChain(chainIndexes[i]); err != nil {
return fmt.Errorf("free descriptor chain: %w", err)
}
}
return nil
}
}
}
// TODO: Make above methods cancelable by taking a context.Context argument?

View File

@@ -4,6 +4,7 @@ import (
"encoding/binary"
"iter"
"net/netip"
"slices"
"syscall"
"unsafe"
@@ -23,7 +24,6 @@ type Packet struct {
wasSegmented bool
isV4 bool
//Addr netip.AddrPort
}
func New(isV4 bool) *Packet {
@@ -81,17 +81,13 @@ func (p *Packet) SetSegSizeForTX() {
hdr.Level = unix.SOL_UDP
hdr.Type = unix.UDP_SEGMENT
hdr.SetLen(syscall.CmsgLen(2))
//setCmsgLen(hdr, unix.CmsgLen(2))
binary.NativeEndian.PutUint16(p.Control[unix.CmsgLen(0):unix.CmsgLen(0)+2], uint16(p.SegSize))
//data := p.Control[syscall.CmsgSpace(0)-syscall.CmsgSpace(2)+syscall.SizeofCmsghdr:]
//binary.NativeEndian.PutUint16(data, uint16(p.SegSize))
}
func (p *Packet) CompatibleForSegmentationWith(otherP *Packet) bool {
//same dest
if p.AddrPort() != otherP.AddrPort() {
return false //todo more efficient?
if !slices.Equal(p.Name, otherP.Name) {
return false
}
//same body len
@@ -113,33 +109,5 @@ func (p *Packet) Segments() iter.Seq[[]byte] {
return
}
}
//if p.SegSize > 0 && p.SegSize < len(p.Payload) {
//
//} else {
// f.readOutsidePackets(p.Addr, nil, result2[:0], p.Payload, h, fwPacket2, lhh, nb2, i, conntrackCache.Get(f.l))
//}
}
}
//type Pool struct {
// pool sync.Pool
//}
//
//var bigPool = &Pool{
// pool: sync.Pool{New: func() any { return New() }},
//}
//
//func GetPool() *Pool {
// return bigPool
//}
//
//func (p *Pool) Get() *Packet {
// return p.pool.Get().(*Packet)
//}
//
//func (p *Pool) Put(x *Packet) {
// x.Payload = x.Payload[:Size]
// p.pool.Put(x)
//}