hmm yes time

This commit is contained in:
JackDoan
2025-11-11 21:14:24 -06:00
parent 685ac3e112
commit c6bee8e981
6 changed files with 46 additions and 83 deletions

View File

@@ -425,9 +425,9 @@ var ErrNoMatchingRule = errors.New("no matching rule in firewall table")
// Drop returns an error if the packet should be dropped, explaining why. It // Drop returns an error if the packet should be dropped, explaining why. It
// returns nil if the packet should not be dropped. // returns nil if the packet should not be dropped.
func (f *Firewall) Drop(fp firewall.Packet, incoming bool, h *HostInfo, caPool *cert.CAPool, localCache firewall.ConntrackCache) error { func (f *Firewall) Drop(fp firewall.Packet, incoming bool, h *HostInfo, caPool *cert.CAPool, localCache firewall.ConntrackCache, now time.Time) error {
// Check if we spoke to this tuple, if we did then allow this packet // Check if we spoke to this tuple, if we did then allow this packet
if f.inConns(fp, h, caPool, localCache) { if f.inConns(fp, h, caPool, localCache, now) {
return nil return nil
} }
@@ -476,7 +476,7 @@ func (f *Firewall) Drop(fp firewall.Packet, incoming bool, h *HostInfo, caPool *
} }
// We always want to conntrack since it is a faster operation // We always want to conntrack since it is a faster operation
f.addConn(fp, incoming) f.addConn(fp, incoming, now)
return nil return nil
} }
@@ -505,7 +505,7 @@ func (f *Firewall) EmitStats() {
metrics.GetOrRegisterGauge("firewall.rules.hash", nil).Update(int64(f.GetRuleHashFNV())) metrics.GetOrRegisterGauge("firewall.rules.hash", nil).Update(int64(f.GetRuleHashFNV()))
} }
func (f *Firewall) inConns(fp firewall.Packet, h *HostInfo, caPool *cert.CAPool, localCache firewall.ConntrackCache) bool { func (f *Firewall) inConns(fp firewall.Packet, h *HostInfo, caPool *cert.CAPool, localCache firewall.ConntrackCache, now time.Time) bool {
if localCache != nil { if localCache != nil {
if _, ok := localCache[fp]; ok { if _, ok := localCache[fp]; ok {
return true return true
@@ -517,7 +517,7 @@ func (f *Firewall) inConns(fp firewall.Packet, h *HostInfo, caPool *cert.CAPool,
// Purge every time we test // Purge every time we test
ep, has := conntrack.TimerWheel.Purge() ep, has := conntrack.TimerWheel.Purge()
if has { if has {
f.evict(ep) f.evict(ep, now)
} }
c, ok := conntrack.Conns[fp] c, ok := conntrack.Conns[fp]
@@ -564,11 +564,11 @@ func (f *Firewall) inConns(fp firewall.Packet, h *HostInfo, caPool *cert.CAPool,
switch fp.Protocol { switch fp.Protocol {
case firewall.ProtoTCP: case firewall.ProtoTCP:
c.Expires = time.Now().Add(f.TCPTimeout) c.Expires = now.Add(f.TCPTimeout)
case firewall.ProtoUDP: case firewall.ProtoUDP:
c.Expires = time.Now().Add(f.UDPTimeout) c.Expires = now.Add(f.UDPTimeout)
default: default:
c.Expires = time.Now().Add(f.DefaultTimeout) c.Expires = now.Add(f.DefaultTimeout)
} }
conntrack.Unlock() conntrack.Unlock()
@@ -580,7 +580,7 @@ func (f *Firewall) inConns(fp firewall.Packet, h *HostInfo, caPool *cert.CAPool,
return true return true
} }
func (f *Firewall) addConn(fp firewall.Packet, incoming bool) { func (f *Firewall) addConn(fp firewall.Packet, incoming bool, now time.Time) {
var timeout time.Duration var timeout time.Duration
c := &conn{} c := &conn{}
@@ -596,7 +596,7 @@ func (f *Firewall) addConn(fp firewall.Packet, incoming bool) {
conntrack := f.Conntrack conntrack := f.Conntrack
conntrack.Lock() conntrack.Lock()
if _, ok := conntrack.Conns[fp]; !ok { if _, ok := conntrack.Conns[fp]; !ok {
conntrack.TimerWheel.Advance(time.Now()) conntrack.TimerWheel.Advance(now)
conntrack.TimerWheel.Add(fp, timeout) conntrack.TimerWheel.Add(fp, timeout)
} }
@@ -604,14 +604,14 @@ func (f *Firewall) addConn(fp firewall.Packet, incoming bool) {
// firewall reload // firewall reload
c.incoming = incoming c.incoming = incoming
c.rulesVersion = f.rulesVersion c.rulesVersion = f.rulesVersion
c.Expires = time.Now().Add(timeout) c.Expires = now.Add(timeout)
conntrack.Conns[fp] = c conntrack.Conns[fp] = c
conntrack.Unlock() conntrack.Unlock()
} }
// Evict checks if a conntrack entry has expired, if so it is removed, if not it is re-added to the wheel // Evict checks if a conntrack entry has expired, if so it is removed, if not it is re-added to the wheel
// Caller must own the connMutex lock! // Caller must own the connMutex lock!
func (f *Firewall) evict(p firewall.Packet) { func (f *Firewall) evict(p firewall.Packet, now time.Time) {
// Are we still tracking this conn? // Are we still tracking this conn?
conntrack := f.Conntrack conntrack := f.Conntrack
t, ok := conntrack.Conns[p] t, ok := conntrack.Conns[p]
@@ -619,11 +619,11 @@ func (f *Firewall) evict(p firewall.Packet) {
return return
} }
newT := t.Expires.Sub(time.Now()) newT := t.Expires.Sub(now)
// Timeout is in the future, re-add the timer // Timeout is in the future, re-add the timer
if newT > 0 { if newT > 0 {
conntrack.TimerWheel.Advance(time.Now()) conntrack.TimerWheel.Advance(now)
conntrack.TimerWheel.Add(p, newT) conntrack.TimerWheel.Add(p, newT)
return return
} }

View File

@@ -2,6 +2,7 @@ package nebula
import ( import (
"net/netip" "net/netip"
"time"
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
"github.com/slackhq/nebula/firewall" "github.com/slackhq/nebula/firewall"
@@ -12,7 +13,7 @@ import (
"github.com/slackhq/nebula/routing" "github.com/slackhq/nebula/routing"
) )
func (f *Interface) consumeInsidePacket(packet []byte, fwPacket *firewall.Packet, nb []byte, out *packet.Packet, q int, localCache firewall.ConntrackCache) { func (f *Interface) consumeInsidePacket(packet []byte, fwPacket *firewall.Packet, nb []byte, out *packet.Packet, q int, localCache firewall.ConntrackCache, now time.Time) {
err := newPacket(packet, false, fwPacket) err := newPacket(packet, false, fwPacket)
if err != nil { if err != nil {
if f.l.Level >= logrus.DebugLevel { if f.l.Level >= logrus.DebugLevel {
@@ -67,7 +68,7 @@ func (f *Interface) consumeInsidePacket(packet []byte, fwPacket *firewall.Packet
return return
} }
dropReason := f.firewall.Drop(*fwPacket, false, hostinfo, f.pki.GetCAPool(), localCache) dropReason := f.firewall.Drop(*fwPacket, false, hostinfo, f.pki.GetCAPool(), localCache, now)
if dropReason == nil { if dropReason == nil {
f.sendNoMetricsDelayed(header.Message, 0, hostinfo.ConnectionState, hostinfo, netip.AddrPort{}, packet, nb, out, q) f.sendNoMetricsDelayed(header.Message, 0, hostinfo.ConnectionState, hostinfo, netip.AddrPort{}, packet, nb, out, q)
} else { } else {
@@ -218,7 +219,7 @@ func (f *Interface) sendMessageNow(t header.MessageType, st header.MessageSubTyp
} }
// check if packet is in outbound fw rules // check if packet is in outbound fw rules
dropReason := f.firewall.Drop(*fp, false, hostinfo, f.pki.GetCAPool(), nil) dropReason := f.firewall.Drop(*fp, false, hostinfo, f.pki.GetCAPool(), nil, time.Now())
if dropReason != nil { if dropReason != nil {
if f.l.Level >= logrus.DebugLevel { if f.l.Level >= logrus.DebugLevel {
f.l.WithField("fwPacket", fp). f.l.WithField("fwPacket", fp).

View File

@@ -287,7 +287,7 @@ func (f *Interface) listenOut(q int) {
outPackets[i].SegCounter = 0 outPackets[i].SegCounter = 0
} }
f.readOutsidePacketsMany(pkts, outPackets, h, fwPacket, lhh, nb, q, ctCache.Get(f.l)) f.readOutsidePacketsMany(pkts, outPackets, h, fwPacket, lhh, nb, q, ctCache.Get(f.l), time.Now())
for i := range pkts { for i := range pkts {
if pkts[i].OutLen != -1 { if pkts[i].OutLen != -1 {
for j := 0; j < outPackets[i].SegCounter; j++ { for j := 0; j < outPackets[i].SegCounter; j++ {
@@ -344,9 +344,10 @@ func (f *Interface) listenIn(reader overlay.TunDev, queueNum int) {
os.Exit(2) os.Exit(2)
} }
now := time.Now()
for i, pkt := range packets[:n] { for i, pkt := range packets[:n] {
outPackets[i].OutLen = -1 outPackets[i].OutLen = -1
f.consumeInsidePacket(pkt.Payload, fwPacket, nb, outPackets[i], queueNum, conntrackCache.Get(f.l)) f.consumeInsidePacket(pkt.Payload, fwPacket, nb, outPackets[i], queueNum, conntrackCache.Get(f.l), now)
} }
_, err = f.writers[queueNum].WriteBatch(outPackets[:n]) _, err = f.writers[queueNum].WriteBatch(outPackets[:n])
if err != nil { if err != nil {

View File

@@ -20,7 +20,7 @@ const (
minFwPacketLen = 4 minFwPacketLen = 4
) )
func (f *Interface) readOutsidePackets(ip netip.AddrPort, via *ViaSender, out []byte, packet []byte, h *header.H, fwPacket *firewall.Packet, lhf *LightHouseHandler, nb []byte, q int, localCache firewall.ConntrackCache) { func (f *Interface) readOutsidePackets(ip netip.AddrPort, via *ViaSender, out []byte, packet []byte, h *header.H, fwPacket *firewall.Packet, lhf *LightHouseHandler, nb []byte, q int, localCache firewall.ConntrackCache, now time.Time) {
err := h.Parse(packet) err := h.Parse(packet)
if err != nil { if err != nil {
// Hole punch packets are 0 or 1 byte big, so lets ignore printing those errors // Hole punch packets are 0 or 1 byte big, so lets ignore printing those errors
@@ -62,7 +62,7 @@ func (f *Interface) readOutsidePackets(ip netip.AddrPort, via *ViaSender, out []
switch h.Subtype { switch h.Subtype {
case header.MessageNone: case header.MessageNone:
if !f.decryptToTun(hostinfo, h.MessageCounter, out, packet, fwPacket, nb, q, localCache) { if !f.decryptToTun(hostinfo, h.MessageCounter, out, packet, fwPacket, nb, q, localCache, now) {
return return
} }
case header.MessageRelay: case header.MessageRelay:
@@ -97,7 +97,7 @@ func (f *Interface) readOutsidePackets(ip netip.AddrPort, via *ViaSender, out []
case TerminalType: case TerminalType:
// If I am the target of this relay, process the unwrapped packet // If I am the target of this relay, process the unwrapped packet
// From this recursive point, all these variables are 'burned'. We shouldn't rely on them again. // From this recursive point, all these variables are 'burned'. We shouldn't rely on them again.
f.readOutsidePackets(netip.AddrPort{}, &ViaSender{relayHI: hostinfo, remoteIdx: relay.RemoteIndex, relay: relay}, out[:0], signedPayload, h, fwPacket, lhf, nb, q, localCache) f.readOutsidePackets(netip.AddrPort{}, &ViaSender{relayHI: hostinfo, remoteIdx: relay.RemoteIndex, relay: relay}, out[:0], signedPayload, h, fwPacket, lhf, nb, q, localCache, now)
return return
case ForwardingType: case ForwardingType:
// Find the target HostInfo relay object // Find the target HostInfo relay object
@@ -217,7 +217,7 @@ func (f *Interface) readOutsidePackets(ip netip.AddrPort, via *ViaSender, out []
f.connectionManager.In(hostinfo) f.connectionManager.In(hostinfo)
} }
func (f *Interface) readOutsidePacketsMany(packets []*packet.Packet, out []*packet.OutPacket, h *header.H, fwPacket *firewall.Packet, lhf *LightHouseHandler, nb []byte, q int, localCache firewall.ConntrackCache) { func (f *Interface) readOutsidePacketsMany(packets []*packet.Packet, out []*packet.OutPacket, h *header.H, fwPacket *firewall.Packet, lhf *LightHouseHandler, nb []byte, q int, localCache firewall.ConntrackCache, now time.Time) {
for i, pkt := range packets { for i, pkt := range packets {
out[i].Scratch = out[i].Scratch[:0] out[i].Scratch = out[i].Scratch[:0]
ip := pkt.AddrPort() ip := pkt.AddrPort()
@@ -266,7 +266,7 @@ func (f *Interface) readOutsidePacketsMany(packets []*packet.Packet, out []*pack
switch h.Subtype { switch h.Subtype {
case header.MessageNone: case header.MessageNone:
if !f.decryptToTunDelayWrite(hostinfo, h.MessageCounter, out[i], pkt, segment, fwPacket, nb, q, localCache) { if !f.decryptToTunDelayWrite(hostinfo, h.MessageCounter, out[i], pkt, segment, fwPacket, nb, q, localCache, now) {
return return
} }
case header.MessageRelay: case header.MessageRelay:
@@ -301,7 +301,7 @@ func (f *Interface) readOutsidePacketsMany(packets []*packet.Packet, out []*pack
case TerminalType: case TerminalType:
// If I am the target of this relay, process the unwrapped packet // If I am the target of this relay, process the unwrapped packet
// From this recursive point, all these variables are 'burned'. We shouldn't rely on them again. // From this recursive point, all these variables are 'burned'. We shouldn't rely on them again.
f.readOutsidePackets(netip.AddrPort{}, &ViaSender{relayHI: hostinfo, remoteIdx: relay.RemoteIndex, relay: relay}, out[i].Scratch[:0], signedPayload, h, fwPacket, lhf, nb, q, localCache) f.readOutsidePackets(netip.AddrPort{}, &ViaSender{relayHI: hostinfo, remoteIdx: relay.RemoteIndex, relay: relay}, out[i].Scratch[:0], signedPayload, h, fwPacket, lhf, nb, q, localCache, now)
return return
case ForwardingType: case ForwardingType:
// Find the target HostInfo relay object // Find the target HostInfo relay object
@@ -672,7 +672,7 @@ func (f *Interface) decrypt(hostinfo *HostInfo, mc uint64, out []byte, packet []
return out, nil return out, nil
} }
func (f *Interface) decryptToTunDelayWrite(hostinfo *HostInfo, messageCounter uint64, out *packet.OutPacket, pkt *packet.Packet, inSegment []byte, fwPacket *firewall.Packet, nb []byte, q int, localCache firewall.ConntrackCache) bool { func (f *Interface) decryptToTunDelayWrite(hostinfo *HostInfo, messageCounter uint64, out *packet.OutPacket, pkt *packet.Packet, inSegment []byte, fwPacket *firewall.Packet, nb []byte, q int, localCache firewall.ConntrackCache, now time.Time) bool {
var err error var err error
out.Segments[out.SegCounter] = out.Segments[out.SegCounter][:0] out.Segments[out.SegCounter] = out.Segments[out.SegCounter][:0]
@@ -695,7 +695,7 @@ func (f *Interface) decryptToTunDelayWrite(hostinfo *HostInfo, messageCounter ui
return false return false
} }
dropReason := f.firewall.Drop(*fwPacket, true, hostinfo, f.pki.GetCAPool(), localCache) dropReason := f.firewall.Drop(*fwPacket, true, hostinfo, f.pki.GetCAPool(), localCache, now)
if dropReason != nil { if dropReason != nil {
// NOTE: We give `packet` as the `out` here since we already decrypted from it and we don't need it anymore // NOTE: We give `packet` as the `out` here since we already decrypted from it and we don't need it anymore
// This gives us a buffer to build the reject packet in // This gives us a buffer to build the reject packet in
@@ -714,7 +714,7 @@ func (f *Interface) decryptToTunDelayWrite(hostinfo *HostInfo, messageCounter ui
return true return true
} }
func (f *Interface) decryptToTun(hostinfo *HostInfo, messageCounter uint64, out []byte, packet []byte, fwPacket *firewall.Packet, nb []byte, q int, localCache firewall.ConntrackCache) bool { func (f *Interface) decryptToTun(hostinfo *HostInfo, messageCounter uint64, out []byte, packet []byte, fwPacket *firewall.Packet, nb []byte, q int, localCache firewall.ConntrackCache, now time.Time) bool {
var err error var err error
out, err = hostinfo.ConnectionState.dKey.DecryptDanger(out, packet[:header.Len], packet[header.Len:], messageCounter, nb) out, err = hostinfo.ConnectionState.dKey.DecryptDanger(out, packet[:header.Len], packet[header.Len:], messageCounter, nb)
@@ -736,7 +736,7 @@ func (f *Interface) decryptToTun(hostinfo *HostInfo, messageCounter uint64, out
return false return false
} }
dropReason := f.firewall.Drop(*fwPacket, true, hostinfo, f.pki.GetCAPool(), localCache) dropReason := f.firewall.Drop(*fwPacket, true, hostinfo, f.pki.GetCAPool(), localCache, now)
if dropReason != nil { if dropReason != nil {
// NOTE: We give `packet` as the `out` here since we already decrypted from it and we don't need it anymore // NOTE: We give `packet` as the `out` here since we already decrypted from it and we don't need it anymore
// This gives us a buffer to build the reject packet in // This gives us a buffer to build the reject packet in

View File

@@ -6,6 +6,7 @@ import (
"fmt" "fmt"
"os" "os"
"runtime" "runtime"
"slices"
"github.com/slackhq/nebula/overlay/vhost" "github.com/slackhq/nebula/overlay/vhost"
"github.com/slackhq/nebula/overlay/virtqueue" "github.com/slackhq/nebula/overlay/virtqueue"
@@ -249,10 +250,6 @@ func (dev *Device) TransmitPackets(vnethdr virtio.NetHdr, packets [][]byte) erro
return fmt.Errorf("offer descriptor chain: %w", err) return fmt.Errorf("offer descriptor chain: %w", err)
} }
//todo surely there's something better to do here //todo surely there's something better to do here
doneYet := map[uint16]bool{}
for _, chain := range chainIndexes {
doneYet[chain] = false
}
for { for {
txedChains, err := dev.TransmitQueue.BlockAndGetHeads(context.TODO()) txedChains, err := dev.TransmitQueue.BlockAndGetHeads(context.TODO())
@@ -261,31 +258,27 @@ func (dev *Device) TransmitPackets(vnethdr virtio.NetHdr, packets [][]byte) erro
} else if len(txedChains) == 0 { } else if len(txedChains) == 0 {
continue //todo will this ever exit? continue //todo will this ever exit?
} }
for c := range txedChains { for _, c := range txedChains {
doneYet[txedChains[c].GetHead()] = true idx := slices.Index(chainIndexes, c.GetHead())
if idx < 0 {
continue
} else {
_ = dev.TransmitQueue.FreeDescriptorChain(chainIndexes[idx])
chainIndexes[idx] = 0 //todo I hope this works
}
} }
done := true //optimism! done := true //optimism!
for _, x := range doneYet { for _, x := range chainIndexes {
if !x { if x != 0 {
done = false done = false
break break
} }
} }
if done { if done {
break
}
}
// Wait for the packet to have been transmitted.
for i := range chainIndexes {
if err = dev.TransmitQueue.FreeDescriptorChain(chainIndexes[i]); err != nil {
return fmt.Errorf("free descriptor chain: %w", err)
}
}
return nil return nil
}
}
} }
// TODO: Make above methods cancelable by taking a context.Context argument? // TODO: Make above methods cancelable by taking a context.Context argument?

View File

@@ -4,6 +4,7 @@ import (
"encoding/binary" "encoding/binary"
"iter" "iter"
"net/netip" "net/netip"
"slices"
"syscall" "syscall"
"unsafe" "unsafe"
@@ -23,7 +24,6 @@ type Packet struct {
wasSegmented bool wasSegmented bool
isV4 bool isV4 bool
//Addr netip.AddrPort
} }
func New(isV4 bool) *Packet { func New(isV4 bool) *Packet {
@@ -81,17 +81,13 @@ func (p *Packet) SetSegSizeForTX() {
hdr.Level = unix.SOL_UDP hdr.Level = unix.SOL_UDP
hdr.Type = unix.UDP_SEGMENT hdr.Type = unix.UDP_SEGMENT
hdr.SetLen(syscall.CmsgLen(2)) hdr.SetLen(syscall.CmsgLen(2))
//setCmsgLen(hdr, unix.CmsgLen(2))
binary.NativeEndian.PutUint16(p.Control[unix.CmsgLen(0):unix.CmsgLen(0)+2], uint16(p.SegSize)) binary.NativeEndian.PutUint16(p.Control[unix.CmsgLen(0):unix.CmsgLen(0)+2], uint16(p.SegSize))
//data := p.Control[syscall.CmsgSpace(0)-syscall.CmsgSpace(2)+syscall.SizeofCmsghdr:]
//binary.NativeEndian.PutUint16(data, uint16(p.SegSize))
} }
func (p *Packet) CompatibleForSegmentationWith(otherP *Packet) bool { func (p *Packet) CompatibleForSegmentationWith(otherP *Packet) bool {
//same dest //same dest
if !slices.Equal(p.Name, otherP.Name) {
if p.AddrPort() != otherP.AddrPort() { return false
return false //todo more efficient?
} }
//same body len //same body len
@@ -113,33 +109,5 @@ func (p *Packet) Segments() iter.Seq[[]byte] {
return return
} }
} }
//if p.SegSize > 0 && p.SegSize < len(p.Payload) {
//
//} else {
// f.readOutsidePackets(p.Addr, nil, result2[:0], p.Payload, h, fwPacket2, lhh, nb2, i, conntrackCache.Get(f.l))
//}
} }
} }
//type Pool struct {
// pool sync.Pool
//}
//
//var bigPool = &Pool{
// pool: sync.Pool{New: func() any { return New() }},
//}
//
//func GetPool() *Pool {
// return bigPool
//}
//
//func (p *Pool) Get() *Packet {
// return p.pool.Get().(*Packet)
//}
//
//func (p *Pool) Put(x *Packet) {
// x.Payload = x.Payload[:Size]
// p.pool.Put(x)
//}