mirror of
https://github.com/slackhq/nebula.git
synced 2025-11-23 00:44:25 +01:00
hmm yes time
This commit is contained in:
28
firewall.go
28
firewall.go
@@ -425,9 +425,9 @@ var ErrNoMatchingRule = errors.New("no matching rule in firewall table")
|
||||
|
||||
// Drop returns an error if the packet should be dropped, explaining why. It
|
||||
// returns nil if the packet should not be dropped.
|
||||
func (f *Firewall) Drop(fp firewall.Packet, incoming bool, h *HostInfo, caPool *cert.CAPool, localCache firewall.ConntrackCache) error {
|
||||
func (f *Firewall) Drop(fp firewall.Packet, incoming bool, h *HostInfo, caPool *cert.CAPool, localCache firewall.ConntrackCache, now time.Time) error {
|
||||
// Check if we spoke to this tuple, if we did then allow this packet
|
||||
if f.inConns(fp, h, caPool, localCache) {
|
||||
if f.inConns(fp, h, caPool, localCache, now) {
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -476,7 +476,7 @@ func (f *Firewall) Drop(fp firewall.Packet, incoming bool, h *HostInfo, caPool *
|
||||
}
|
||||
|
||||
// We always want to conntrack since it is a faster operation
|
||||
f.addConn(fp, incoming)
|
||||
f.addConn(fp, incoming, now)
|
||||
|
||||
return nil
|
||||
}
|
||||
@@ -505,7 +505,7 @@ func (f *Firewall) EmitStats() {
|
||||
metrics.GetOrRegisterGauge("firewall.rules.hash", nil).Update(int64(f.GetRuleHashFNV()))
|
||||
}
|
||||
|
||||
func (f *Firewall) inConns(fp firewall.Packet, h *HostInfo, caPool *cert.CAPool, localCache firewall.ConntrackCache) bool {
|
||||
func (f *Firewall) inConns(fp firewall.Packet, h *HostInfo, caPool *cert.CAPool, localCache firewall.ConntrackCache, now time.Time) bool {
|
||||
if localCache != nil {
|
||||
if _, ok := localCache[fp]; ok {
|
||||
return true
|
||||
@@ -517,7 +517,7 @@ func (f *Firewall) inConns(fp firewall.Packet, h *HostInfo, caPool *cert.CAPool,
|
||||
// Purge every time we test
|
||||
ep, has := conntrack.TimerWheel.Purge()
|
||||
if has {
|
||||
f.evict(ep)
|
||||
f.evict(ep, now)
|
||||
}
|
||||
|
||||
c, ok := conntrack.Conns[fp]
|
||||
@@ -564,11 +564,11 @@ func (f *Firewall) inConns(fp firewall.Packet, h *HostInfo, caPool *cert.CAPool,
|
||||
|
||||
switch fp.Protocol {
|
||||
case firewall.ProtoTCP:
|
||||
c.Expires = time.Now().Add(f.TCPTimeout)
|
||||
c.Expires = now.Add(f.TCPTimeout)
|
||||
case firewall.ProtoUDP:
|
||||
c.Expires = time.Now().Add(f.UDPTimeout)
|
||||
c.Expires = now.Add(f.UDPTimeout)
|
||||
default:
|
||||
c.Expires = time.Now().Add(f.DefaultTimeout)
|
||||
c.Expires = now.Add(f.DefaultTimeout)
|
||||
}
|
||||
|
||||
conntrack.Unlock()
|
||||
@@ -580,7 +580,7 @@ func (f *Firewall) inConns(fp firewall.Packet, h *HostInfo, caPool *cert.CAPool,
|
||||
return true
|
||||
}
|
||||
|
||||
func (f *Firewall) addConn(fp firewall.Packet, incoming bool) {
|
||||
func (f *Firewall) addConn(fp firewall.Packet, incoming bool, now time.Time) {
|
||||
var timeout time.Duration
|
||||
c := &conn{}
|
||||
|
||||
@@ -596,7 +596,7 @@ func (f *Firewall) addConn(fp firewall.Packet, incoming bool) {
|
||||
conntrack := f.Conntrack
|
||||
conntrack.Lock()
|
||||
if _, ok := conntrack.Conns[fp]; !ok {
|
||||
conntrack.TimerWheel.Advance(time.Now())
|
||||
conntrack.TimerWheel.Advance(now)
|
||||
conntrack.TimerWheel.Add(fp, timeout)
|
||||
}
|
||||
|
||||
@@ -604,14 +604,14 @@ func (f *Firewall) addConn(fp firewall.Packet, incoming bool) {
|
||||
// firewall reload
|
||||
c.incoming = incoming
|
||||
c.rulesVersion = f.rulesVersion
|
||||
c.Expires = time.Now().Add(timeout)
|
||||
c.Expires = now.Add(timeout)
|
||||
conntrack.Conns[fp] = c
|
||||
conntrack.Unlock()
|
||||
}
|
||||
|
||||
// Evict checks if a conntrack entry has expired, if so it is removed, if not it is re-added to the wheel
|
||||
// Caller must own the connMutex lock!
|
||||
func (f *Firewall) evict(p firewall.Packet) {
|
||||
func (f *Firewall) evict(p firewall.Packet, now time.Time) {
|
||||
// Are we still tracking this conn?
|
||||
conntrack := f.Conntrack
|
||||
t, ok := conntrack.Conns[p]
|
||||
@@ -619,11 +619,11 @@ func (f *Firewall) evict(p firewall.Packet) {
|
||||
return
|
||||
}
|
||||
|
||||
newT := t.Expires.Sub(time.Now())
|
||||
newT := t.Expires.Sub(now)
|
||||
|
||||
// Timeout is in the future, re-add the timer
|
||||
if newT > 0 {
|
||||
conntrack.TimerWheel.Advance(time.Now())
|
||||
conntrack.TimerWheel.Advance(now)
|
||||
conntrack.TimerWheel.Add(p, newT)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -2,6 +2,7 @@ package nebula
|
||||
|
||||
import (
|
||||
"net/netip"
|
||||
"time"
|
||||
|
||||
"github.com/sirupsen/logrus"
|
||||
"github.com/slackhq/nebula/firewall"
|
||||
@@ -12,7 +13,7 @@ import (
|
||||
"github.com/slackhq/nebula/routing"
|
||||
)
|
||||
|
||||
func (f *Interface) consumeInsidePacket(packet []byte, fwPacket *firewall.Packet, nb []byte, out *packet.Packet, q int, localCache firewall.ConntrackCache) {
|
||||
func (f *Interface) consumeInsidePacket(packet []byte, fwPacket *firewall.Packet, nb []byte, out *packet.Packet, q int, localCache firewall.ConntrackCache, now time.Time) {
|
||||
err := newPacket(packet, false, fwPacket)
|
||||
if err != nil {
|
||||
if f.l.Level >= logrus.DebugLevel {
|
||||
@@ -67,7 +68,7 @@ func (f *Interface) consumeInsidePacket(packet []byte, fwPacket *firewall.Packet
|
||||
return
|
||||
}
|
||||
|
||||
dropReason := f.firewall.Drop(*fwPacket, false, hostinfo, f.pki.GetCAPool(), localCache)
|
||||
dropReason := f.firewall.Drop(*fwPacket, false, hostinfo, f.pki.GetCAPool(), localCache, now)
|
||||
if dropReason == nil {
|
||||
f.sendNoMetricsDelayed(header.Message, 0, hostinfo.ConnectionState, hostinfo, netip.AddrPort{}, packet, nb, out, q)
|
||||
} else {
|
||||
@@ -218,7 +219,7 @@ func (f *Interface) sendMessageNow(t header.MessageType, st header.MessageSubTyp
|
||||
}
|
||||
|
||||
// check if packet is in outbound fw rules
|
||||
dropReason := f.firewall.Drop(*fp, false, hostinfo, f.pki.GetCAPool(), nil)
|
||||
dropReason := f.firewall.Drop(*fp, false, hostinfo, f.pki.GetCAPool(), nil, time.Now())
|
||||
if dropReason != nil {
|
||||
if f.l.Level >= logrus.DebugLevel {
|
||||
f.l.WithField("fwPacket", fp).
|
||||
|
||||
@@ -287,7 +287,7 @@ func (f *Interface) listenOut(q int) {
|
||||
outPackets[i].SegCounter = 0
|
||||
}
|
||||
|
||||
f.readOutsidePacketsMany(pkts, outPackets, h, fwPacket, lhh, nb, q, ctCache.Get(f.l))
|
||||
f.readOutsidePacketsMany(pkts, outPackets, h, fwPacket, lhh, nb, q, ctCache.Get(f.l), time.Now())
|
||||
for i := range pkts {
|
||||
if pkts[i].OutLen != -1 {
|
||||
for j := 0; j < outPackets[i].SegCounter; j++ {
|
||||
@@ -344,9 +344,10 @@ func (f *Interface) listenIn(reader overlay.TunDev, queueNum int) {
|
||||
os.Exit(2)
|
||||
}
|
||||
|
||||
now := time.Now()
|
||||
for i, pkt := range packets[:n] {
|
||||
outPackets[i].OutLen = -1
|
||||
f.consumeInsidePacket(pkt.Payload, fwPacket, nb, outPackets[i], queueNum, conntrackCache.Get(f.l))
|
||||
f.consumeInsidePacket(pkt.Payload, fwPacket, nb, outPackets[i], queueNum, conntrackCache.Get(f.l), now)
|
||||
}
|
||||
_, err = f.writers[queueNum].WriteBatch(outPackets[:n])
|
||||
if err != nil {
|
||||
|
||||
20
outside.go
20
outside.go
@@ -20,7 +20,7 @@ const (
|
||||
minFwPacketLen = 4
|
||||
)
|
||||
|
||||
func (f *Interface) readOutsidePackets(ip netip.AddrPort, via *ViaSender, out []byte, packet []byte, h *header.H, fwPacket *firewall.Packet, lhf *LightHouseHandler, nb []byte, q int, localCache firewall.ConntrackCache) {
|
||||
func (f *Interface) readOutsidePackets(ip netip.AddrPort, via *ViaSender, out []byte, packet []byte, h *header.H, fwPacket *firewall.Packet, lhf *LightHouseHandler, nb []byte, q int, localCache firewall.ConntrackCache, now time.Time) {
|
||||
err := h.Parse(packet)
|
||||
if err != nil {
|
||||
// Hole punch packets are 0 or 1 byte big, so lets ignore printing those errors
|
||||
@@ -62,7 +62,7 @@ func (f *Interface) readOutsidePackets(ip netip.AddrPort, via *ViaSender, out []
|
||||
|
||||
switch h.Subtype {
|
||||
case header.MessageNone:
|
||||
if !f.decryptToTun(hostinfo, h.MessageCounter, out, packet, fwPacket, nb, q, localCache) {
|
||||
if !f.decryptToTun(hostinfo, h.MessageCounter, out, packet, fwPacket, nb, q, localCache, now) {
|
||||
return
|
||||
}
|
||||
case header.MessageRelay:
|
||||
@@ -97,7 +97,7 @@ func (f *Interface) readOutsidePackets(ip netip.AddrPort, via *ViaSender, out []
|
||||
case TerminalType:
|
||||
// If I am the target of this relay, process the unwrapped packet
|
||||
// From this recursive point, all these variables are 'burned'. We shouldn't rely on them again.
|
||||
f.readOutsidePackets(netip.AddrPort{}, &ViaSender{relayHI: hostinfo, remoteIdx: relay.RemoteIndex, relay: relay}, out[:0], signedPayload, h, fwPacket, lhf, nb, q, localCache)
|
||||
f.readOutsidePackets(netip.AddrPort{}, &ViaSender{relayHI: hostinfo, remoteIdx: relay.RemoteIndex, relay: relay}, out[:0], signedPayload, h, fwPacket, lhf, nb, q, localCache, now)
|
||||
return
|
||||
case ForwardingType:
|
||||
// Find the target HostInfo relay object
|
||||
@@ -217,7 +217,7 @@ func (f *Interface) readOutsidePackets(ip netip.AddrPort, via *ViaSender, out []
|
||||
f.connectionManager.In(hostinfo)
|
||||
}
|
||||
|
||||
func (f *Interface) readOutsidePacketsMany(packets []*packet.Packet, out []*packet.OutPacket, h *header.H, fwPacket *firewall.Packet, lhf *LightHouseHandler, nb []byte, q int, localCache firewall.ConntrackCache) {
|
||||
func (f *Interface) readOutsidePacketsMany(packets []*packet.Packet, out []*packet.OutPacket, h *header.H, fwPacket *firewall.Packet, lhf *LightHouseHandler, nb []byte, q int, localCache firewall.ConntrackCache, now time.Time) {
|
||||
for i, pkt := range packets {
|
||||
out[i].Scratch = out[i].Scratch[:0]
|
||||
ip := pkt.AddrPort()
|
||||
@@ -266,7 +266,7 @@ func (f *Interface) readOutsidePacketsMany(packets []*packet.Packet, out []*pack
|
||||
|
||||
switch h.Subtype {
|
||||
case header.MessageNone:
|
||||
if !f.decryptToTunDelayWrite(hostinfo, h.MessageCounter, out[i], pkt, segment, fwPacket, nb, q, localCache) {
|
||||
if !f.decryptToTunDelayWrite(hostinfo, h.MessageCounter, out[i], pkt, segment, fwPacket, nb, q, localCache, now) {
|
||||
return
|
||||
}
|
||||
case header.MessageRelay:
|
||||
@@ -301,7 +301,7 @@ func (f *Interface) readOutsidePacketsMany(packets []*packet.Packet, out []*pack
|
||||
case TerminalType:
|
||||
// If I am the target of this relay, process the unwrapped packet
|
||||
// From this recursive point, all these variables are 'burned'. We shouldn't rely on them again.
|
||||
f.readOutsidePackets(netip.AddrPort{}, &ViaSender{relayHI: hostinfo, remoteIdx: relay.RemoteIndex, relay: relay}, out[i].Scratch[:0], signedPayload, h, fwPacket, lhf, nb, q, localCache)
|
||||
f.readOutsidePackets(netip.AddrPort{}, &ViaSender{relayHI: hostinfo, remoteIdx: relay.RemoteIndex, relay: relay}, out[i].Scratch[:0], signedPayload, h, fwPacket, lhf, nb, q, localCache, now)
|
||||
return
|
||||
case ForwardingType:
|
||||
// Find the target HostInfo relay object
|
||||
@@ -672,7 +672,7 @@ func (f *Interface) decrypt(hostinfo *HostInfo, mc uint64, out []byte, packet []
|
||||
return out, nil
|
||||
}
|
||||
|
||||
func (f *Interface) decryptToTunDelayWrite(hostinfo *HostInfo, messageCounter uint64, out *packet.OutPacket, pkt *packet.Packet, inSegment []byte, fwPacket *firewall.Packet, nb []byte, q int, localCache firewall.ConntrackCache) bool {
|
||||
func (f *Interface) decryptToTunDelayWrite(hostinfo *HostInfo, messageCounter uint64, out *packet.OutPacket, pkt *packet.Packet, inSegment []byte, fwPacket *firewall.Packet, nb []byte, q int, localCache firewall.ConntrackCache, now time.Time) bool {
|
||||
var err error
|
||||
|
||||
out.Segments[out.SegCounter] = out.Segments[out.SegCounter][:0]
|
||||
@@ -695,7 +695,7 @@ func (f *Interface) decryptToTunDelayWrite(hostinfo *HostInfo, messageCounter ui
|
||||
return false
|
||||
}
|
||||
|
||||
dropReason := f.firewall.Drop(*fwPacket, true, hostinfo, f.pki.GetCAPool(), localCache)
|
||||
dropReason := f.firewall.Drop(*fwPacket, true, hostinfo, f.pki.GetCAPool(), localCache, now)
|
||||
if dropReason != nil {
|
||||
// NOTE: We give `packet` as the `out` here since we already decrypted from it and we don't need it anymore
|
||||
// This gives us a buffer to build the reject packet in
|
||||
@@ -714,7 +714,7 @@ func (f *Interface) decryptToTunDelayWrite(hostinfo *HostInfo, messageCounter ui
|
||||
return true
|
||||
}
|
||||
|
||||
func (f *Interface) decryptToTun(hostinfo *HostInfo, messageCounter uint64, out []byte, packet []byte, fwPacket *firewall.Packet, nb []byte, q int, localCache firewall.ConntrackCache) bool {
|
||||
func (f *Interface) decryptToTun(hostinfo *HostInfo, messageCounter uint64, out []byte, packet []byte, fwPacket *firewall.Packet, nb []byte, q int, localCache firewall.ConntrackCache, now time.Time) bool {
|
||||
var err error
|
||||
|
||||
out, err = hostinfo.ConnectionState.dKey.DecryptDanger(out, packet[:header.Len], packet[header.Len:], messageCounter, nb)
|
||||
@@ -736,7 +736,7 @@ func (f *Interface) decryptToTun(hostinfo *HostInfo, messageCounter uint64, out
|
||||
return false
|
||||
}
|
||||
|
||||
dropReason := f.firewall.Drop(*fwPacket, true, hostinfo, f.pki.GetCAPool(), localCache)
|
||||
dropReason := f.firewall.Drop(*fwPacket, true, hostinfo, f.pki.GetCAPool(), localCache, now)
|
||||
if dropReason != nil {
|
||||
// NOTE: We give `packet` as the `out` here since we already decrypted from it and we don't need it anymore
|
||||
// This gives us a buffer to build the reject packet in
|
||||
|
||||
@@ -6,6 +6,7 @@ import (
|
||||
"fmt"
|
||||
"os"
|
||||
"runtime"
|
||||
"slices"
|
||||
|
||||
"github.com/slackhq/nebula/overlay/vhost"
|
||||
"github.com/slackhq/nebula/overlay/virtqueue"
|
||||
@@ -249,10 +250,6 @@ func (dev *Device) TransmitPackets(vnethdr virtio.NetHdr, packets [][]byte) erro
|
||||
return fmt.Errorf("offer descriptor chain: %w", err)
|
||||
}
|
||||
//todo surely there's something better to do here
|
||||
doneYet := map[uint16]bool{}
|
||||
for _, chain := range chainIndexes {
|
||||
doneYet[chain] = false
|
||||
}
|
||||
|
||||
for {
|
||||
txedChains, err := dev.TransmitQueue.BlockAndGetHeads(context.TODO())
|
||||
@@ -261,32 +258,28 @@ func (dev *Device) TransmitPackets(vnethdr virtio.NetHdr, packets [][]byte) erro
|
||||
} else if len(txedChains) == 0 {
|
||||
continue //todo will this ever exit?
|
||||
}
|
||||
for c := range txedChains {
|
||||
doneYet[txedChains[c].GetHead()] = true
|
||||
for _, c := range txedChains {
|
||||
idx := slices.Index(chainIndexes, c.GetHead())
|
||||
if idx < 0 {
|
||||
continue
|
||||
} else {
|
||||
_ = dev.TransmitQueue.FreeDescriptorChain(chainIndexes[idx])
|
||||
chainIndexes[idx] = 0 //todo I hope this works
|
||||
}
|
||||
}
|
||||
done := true //optimism!
|
||||
for _, x := range doneYet {
|
||||
if !x {
|
||||
for _, x := range chainIndexes {
|
||||
if x != 0 {
|
||||
done = false
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if done {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
// Wait for the packet to have been transmitted.
|
||||
for i := range chainIndexes {
|
||||
|
||||
if err = dev.TransmitQueue.FreeDescriptorChain(chainIndexes[i]); err != nil {
|
||||
return fmt.Errorf("free descriptor chain: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// TODO: Make above methods cancelable by taking a context.Context argument?
|
||||
// TODO: Implement zero-copy variants to transmit and receive packets?
|
||||
|
||||
@@ -4,6 +4,7 @@ import (
|
||||
"encoding/binary"
|
||||
"iter"
|
||||
"net/netip"
|
||||
"slices"
|
||||
"syscall"
|
||||
"unsafe"
|
||||
|
||||
@@ -23,7 +24,6 @@ type Packet struct {
|
||||
|
||||
wasSegmented bool
|
||||
isV4 bool
|
||||
//Addr netip.AddrPort
|
||||
}
|
||||
|
||||
func New(isV4 bool) *Packet {
|
||||
@@ -81,17 +81,13 @@ func (p *Packet) SetSegSizeForTX() {
|
||||
hdr.Level = unix.SOL_UDP
|
||||
hdr.Type = unix.UDP_SEGMENT
|
||||
hdr.SetLen(syscall.CmsgLen(2))
|
||||
//setCmsgLen(hdr, unix.CmsgLen(2))
|
||||
binary.NativeEndian.PutUint16(p.Control[unix.CmsgLen(0):unix.CmsgLen(0)+2], uint16(p.SegSize))
|
||||
//data := p.Control[syscall.CmsgSpace(0)-syscall.CmsgSpace(2)+syscall.SizeofCmsghdr:]
|
||||
//binary.NativeEndian.PutUint16(data, uint16(p.SegSize))
|
||||
}
|
||||
|
||||
func (p *Packet) CompatibleForSegmentationWith(otherP *Packet) bool {
|
||||
//same dest
|
||||
|
||||
if p.AddrPort() != otherP.AddrPort() {
|
||||
return false //todo more efficient?
|
||||
if !slices.Equal(p.Name, otherP.Name) {
|
||||
return false
|
||||
}
|
||||
|
||||
//same body len
|
||||
@@ -113,33 +109,5 @@ func (p *Packet) Segments() iter.Seq[[]byte] {
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
//if p.SegSize > 0 && p.SegSize < len(p.Payload) {
|
||||
//
|
||||
//} else {
|
||||
// f.readOutsidePackets(p.Addr, nil, result2[:0], p.Payload, h, fwPacket2, lhh, nb2, i, conntrackCache.Get(f.l))
|
||||
//}
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
//type Pool struct {
|
||||
// pool sync.Pool
|
||||
//}
|
||||
//
|
||||
//var bigPool = &Pool{
|
||||
// pool: sync.Pool{New: func() any { return New() }},
|
||||
//}
|
||||
//
|
||||
//func GetPool() *Pool {
|
||||
// return bigPool
|
||||
//}
|
||||
//
|
||||
//func (p *Pool) Get() *Packet {
|
||||
// return p.pool.Get().(*Packet)
|
||||
//}
|
||||
//
|
||||
//func (p *Pool) Put(x *Packet) {
|
||||
// x.Payload = x.Payload[:Size]
|
||||
// p.pool.Put(x)
|
||||
//}
|
||||
|
||||
Reference in New Issue
Block a user