diff --git a/firewall.go b/firewall.go index 6bf470d..8e9cd71 100644 --- a/firewall.go +++ b/firewall.go @@ -425,9 +425,9 @@ var ErrNoMatchingRule = errors.New("no matching rule in firewall table") // Drop returns an error if the packet should be dropped, explaining why. It // returns nil if the packet should not be dropped. -func (f *Firewall) Drop(fp firewall.Packet, incoming bool, h *HostInfo, caPool *cert.CAPool, localCache firewall.ConntrackCache) error { +func (f *Firewall) Drop(fp firewall.Packet, incoming bool, h *HostInfo, caPool *cert.CAPool, localCache firewall.ConntrackCache, now time.Time) error { // Check if we spoke to this tuple, if we did then allow this packet - if f.inConns(fp, h, caPool, localCache) { + if f.inConns(fp, h, caPool, localCache, now) { return nil } @@ -476,7 +476,7 @@ func (f *Firewall) Drop(fp firewall.Packet, incoming bool, h *HostInfo, caPool * } // We always want to conntrack since it is a faster operation - f.addConn(fp, incoming) + f.addConn(fp, incoming, now) return nil } @@ -505,7 +505,7 @@ func (f *Firewall) EmitStats() { metrics.GetOrRegisterGauge("firewall.rules.hash", nil).Update(int64(f.GetRuleHashFNV())) } -func (f *Firewall) inConns(fp firewall.Packet, h *HostInfo, caPool *cert.CAPool, localCache firewall.ConntrackCache) bool { +func (f *Firewall) inConns(fp firewall.Packet, h *HostInfo, caPool *cert.CAPool, localCache firewall.ConntrackCache, now time.Time) bool { if localCache != nil { if _, ok := localCache[fp]; ok { return true @@ -517,7 +517,7 @@ func (f *Firewall) inConns(fp firewall.Packet, h *HostInfo, caPool *cert.CAPool, // Purge every time we test ep, has := conntrack.TimerWheel.Purge() if has { - f.evict(ep) + f.evict(ep, now) } c, ok := conntrack.Conns[fp] @@ -564,11 +564,11 @@ func (f *Firewall) inConns(fp firewall.Packet, h *HostInfo, caPool *cert.CAPool, switch fp.Protocol { case firewall.ProtoTCP: - c.Expires = time.Now().Add(f.TCPTimeout) + c.Expires = now.Add(f.TCPTimeout) case firewall.ProtoUDP: - c.Expires = time.Now().Add(f.UDPTimeout) + c.Expires = now.Add(f.UDPTimeout) default: - c.Expires = time.Now().Add(f.DefaultTimeout) + c.Expires = now.Add(f.DefaultTimeout) } conntrack.Unlock() @@ -580,7 +580,7 @@ func (f *Firewall) inConns(fp firewall.Packet, h *HostInfo, caPool *cert.CAPool, return true } -func (f *Firewall) addConn(fp firewall.Packet, incoming bool) { +func (f *Firewall) addConn(fp firewall.Packet, incoming bool, now time.Time) { var timeout time.Duration c := &conn{} @@ -596,7 +596,7 @@ func (f *Firewall) addConn(fp firewall.Packet, incoming bool) { conntrack := f.Conntrack conntrack.Lock() if _, ok := conntrack.Conns[fp]; !ok { - conntrack.TimerWheel.Advance(time.Now()) + conntrack.TimerWheel.Advance(now) conntrack.TimerWheel.Add(fp, timeout) } @@ -604,14 +604,14 @@ func (f *Firewall) addConn(fp firewall.Packet, incoming bool) { // firewall reload c.incoming = incoming c.rulesVersion = f.rulesVersion - c.Expires = time.Now().Add(timeout) + c.Expires = now.Add(timeout) conntrack.Conns[fp] = c conntrack.Unlock() } // Evict checks if a conntrack entry has expired, if so it is removed, if not it is re-added to the wheel // Caller must own the connMutex lock! -func (f *Firewall) evict(p firewall.Packet) { +func (f *Firewall) evict(p firewall.Packet, now time.Time) { // Are we still tracking this conn? conntrack := f.Conntrack t, ok := conntrack.Conns[p] @@ -619,11 +619,11 @@ func (f *Firewall) evict(p firewall.Packet) { return } - newT := t.Expires.Sub(time.Now()) + newT := t.Expires.Sub(now) // Timeout is in the future, re-add the timer if newT > 0 { - conntrack.TimerWheel.Advance(time.Now()) + conntrack.TimerWheel.Advance(now) conntrack.TimerWheel.Add(p, newT) return } diff --git a/inside.go b/inside.go index 0f6e18e..1896df8 100644 --- a/inside.go +++ b/inside.go @@ -2,6 +2,7 @@ package nebula import ( "net/netip" + "time" "github.com/sirupsen/logrus" "github.com/slackhq/nebula/firewall" @@ -12,7 +13,7 @@ import ( "github.com/slackhq/nebula/routing" ) -func (f *Interface) consumeInsidePacket(packet []byte, fwPacket *firewall.Packet, nb []byte, out *packet.Packet, q int, localCache firewall.ConntrackCache) { +func (f *Interface) consumeInsidePacket(packet []byte, fwPacket *firewall.Packet, nb []byte, out *packet.Packet, q int, localCache firewall.ConntrackCache, now time.Time) { err := newPacket(packet, false, fwPacket) if err != nil { if f.l.Level >= logrus.DebugLevel { @@ -67,7 +68,7 @@ func (f *Interface) consumeInsidePacket(packet []byte, fwPacket *firewall.Packet return } - dropReason := f.firewall.Drop(*fwPacket, false, hostinfo, f.pki.GetCAPool(), localCache) + dropReason := f.firewall.Drop(*fwPacket, false, hostinfo, f.pki.GetCAPool(), localCache, now) if dropReason == nil { f.sendNoMetricsDelayed(header.Message, 0, hostinfo.ConnectionState, hostinfo, netip.AddrPort{}, packet, nb, out, q) } else { @@ -218,7 +219,7 @@ func (f *Interface) sendMessageNow(t header.MessageType, st header.MessageSubTyp } // check if packet is in outbound fw rules - dropReason := f.firewall.Drop(*fp, false, hostinfo, f.pki.GetCAPool(), nil) + dropReason := f.firewall.Drop(*fp, false, hostinfo, f.pki.GetCAPool(), nil, time.Now()) if dropReason != nil { if f.l.Level >= logrus.DebugLevel { f.l.WithField("fwPacket", fp). diff --git a/interface.go b/interface.go index c68f3f9..6d6f819 100644 --- a/interface.go +++ b/interface.go @@ -287,7 +287,7 @@ func (f *Interface) listenOut(q int) { outPackets[i].SegCounter = 0 } - f.readOutsidePacketsMany(pkts, outPackets, h, fwPacket, lhh, nb, q, ctCache.Get(f.l)) + f.readOutsidePacketsMany(pkts, outPackets, h, fwPacket, lhh, nb, q, ctCache.Get(f.l), time.Now()) for i := range pkts { if pkts[i].OutLen != -1 { for j := 0; j < outPackets[i].SegCounter; j++ { @@ -344,9 +344,10 @@ func (f *Interface) listenIn(reader overlay.TunDev, queueNum int) { os.Exit(2) } + now := time.Now() for i, pkt := range packets[:n] { outPackets[i].OutLen = -1 - f.consumeInsidePacket(pkt.Payload, fwPacket, nb, outPackets[i], queueNum, conntrackCache.Get(f.l)) + f.consumeInsidePacket(pkt.Payload, fwPacket, nb, outPackets[i], queueNum, conntrackCache.Get(f.l), now) } _, err = f.writers[queueNum].WriteBatch(outPackets[:n]) if err != nil { diff --git a/outside.go b/outside.go index b9ddac7..8cc9a5e 100644 --- a/outside.go +++ b/outside.go @@ -20,7 +20,7 @@ const ( minFwPacketLen = 4 ) -func (f *Interface) readOutsidePackets(ip netip.AddrPort, via *ViaSender, out []byte, packet []byte, h *header.H, fwPacket *firewall.Packet, lhf *LightHouseHandler, nb []byte, q int, localCache firewall.ConntrackCache) { +func (f *Interface) readOutsidePackets(ip netip.AddrPort, via *ViaSender, out []byte, packet []byte, h *header.H, fwPacket *firewall.Packet, lhf *LightHouseHandler, nb []byte, q int, localCache firewall.ConntrackCache, now time.Time) { err := h.Parse(packet) if err != nil { // Hole punch packets are 0 or 1 byte big, so lets ignore printing those errors @@ -62,7 +62,7 @@ func (f *Interface) readOutsidePackets(ip netip.AddrPort, via *ViaSender, out [] switch h.Subtype { case header.MessageNone: - if !f.decryptToTun(hostinfo, h.MessageCounter, out, packet, fwPacket, nb, q, localCache) { + if !f.decryptToTun(hostinfo, h.MessageCounter, out, packet, fwPacket, nb, q, localCache, now) { return } case header.MessageRelay: @@ -97,7 +97,7 @@ func (f *Interface) readOutsidePackets(ip netip.AddrPort, via *ViaSender, out [] case TerminalType: // If I am the target of this relay, process the unwrapped packet // From this recursive point, all these variables are 'burned'. We shouldn't rely on them again. - f.readOutsidePackets(netip.AddrPort{}, &ViaSender{relayHI: hostinfo, remoteIdx: relay.RemoteIndex, relay: relay}, out[:0], signedPayload, h, fwPacket, lhf, nb, q, localCache) + f.readOutsidePackets(netip.AddrPort{}, &ViaSender{relayHI: hostinfo, remoteIdx: relay.RemoteIndex, relay: relay}, out[:0], signedPayload, h, fwPacket, lhf, nb, q, localCache, now) return case ForwardingType: // Find the target HostInfo relay object @@ -217,7 +217,7 @@ func (f *Interface) readOutsidePackets(ip netip.AddrPort, via *ViaSender, out [] f.connectionManager.In(hostinfo) } -func (f *Interface) readOutsidePacketsMany(packets []*packet.Packet, out []*packet.OutPacket, h *header.H, fwPacket *firewall.Packet, lhf *LightHouseHandler, nb []byte, q int, localCache firewall.ConntrackCache) { +func (f *Interface) readOutsidePacketsMany(packets []*packet.Packet, out []*packet.OutPacket, h *header.H, fwPacket *firewall.Packet, lhf *LightHouseHandler, nb []byte, q int, localCache firewall.ConntrackCache, now time.Time) { for i, pkt := range packets { out[i].Scratch = out[i].Scratch[:0] ip := pkt.AddrPort() @@ -266,7 +266,7 @@ func (f *Interface) readOutsidePacketsMany(packets []*packet.Packet, out []*pack switch h.Subtype { case header.MessageNone: - if !f.decryptToTunDelayWrite(hostinfo, h.MessageCounter, out[i], pkt, segment, fwPacket, nb, q, localCache) { + if !f.decryptToTunDelayWrite(hostinfo, h.MessageCounter, out[i], pkt, segment, fwPacket, nb, q, localCache, now) { return } case header.MessageRelay: @@ -301,7 +301,7 @@ func (f *Interface) readOutsidePacketsMany(packets []*packet.Packet, out []*pack case TerminalType: // If I am the target of this relay, process the unwrapped packet // From this recursive point, all these variables are 'burned'. We shouldn't rely on them again. - f.readOutsidePackets(netip.AddrPort{}, &ViaSender{relayHI: hostinfo, remoteIdx: relay.RemoteIndex, relay: relay}, out[i].Scratch[:0], signedPayload, h, fwPacket, lhf, nb, q, localCache) + f.readOutsidePackets(netip.AddrPort{}, &ViaSender{relayHI: hostinfo, remoteIdx: relay.RemoteIndex, relay: relay}, out[i].Scratch[:0], signedPayload, h, fwPacket, lhf, nb, q, localCache, now) return case ForwardingType: // Find the target HostInfo relay object @@ -672,7 +672,7 @@ func (f *Interface) decrypt(hostinfo *HostInfo, mc uint64, out []byte, packet [] return out, nil } -func (f *Interface) decryptToTunDelayWrite(hostinfo *HostInfo, messageCounter uint64, out *packet.OutPacket, pkt *packet.Packet, inSegment []byte, fwPacket *firewall.Packet, nb []byte, q int, localCache firewall.ConntrackCache) bool { +func (f *Interface) decryptToTunDelayWrite(hostinfo *HostInfo, messageCounter uint64, out *packet.OutPacket, pkt *packet.Packet, inSegment []byte, fwPacket *firewall.Packet, nb []byte, q int, localCache firewall.ConntrackCache, now time.Time) bool { var err error out.Segments[out.SegCounter] = out.Segments[out.SegCounter][:0] @@ -695,7 +695,7 @@ func (f *Interface) decryptToTunDelayWrite(hostinfo *HostInfo, messageCounter ui return false } - dropReason := f.firewall.Drop(*fwPacket, true, hostinfo, f.pki.GetCAPool(), localCache) + dropReason := f.firewall.Drop(*fwPacket, true, hostinfo, f.pki.GetCAPool(), localCache, now) if dropReason != nil { // NOTE: We give `packet` as the `out` here since we already decrypted from it and we don't need it anymore // This gives us a buffer to build the reject packet in @@ -714,7 +714,7 @@ func (f *Interface) decryptToTunDelayWrite(hostinfo *HostInfo, messageCounter ui return true } -func (f *Interface) decryptToTun(hostinfo *HostInfo, messageCounter uint64, out []byte, packet []byte, fwPacket *firewall.Packet, nb []byte, q int, localCache firewall.ConntrackCache) bool { +func (f *Interface) decryptToTun(hostinfo *HostInfo, messageCounter uint64, out []byte, packet []byte, fwPacket *firewall.Packet, nb []byte, q int, localCache firewall.ConntrackCache, now time.Time) bool { var err error out, err = hostinfo.ConnectionState.dKey.DecryptDanger(out, packet[:header.Len], packet[header.Len:], messageCounter, nb) @@ -736,7 +736,7 @@ func (f *Interface) decryptToTun(hostinfo *HostInfo, messageCounter uint64, out return false } - dropReason := f.firewall.Drop(*fwPacket, true, hostinfo, f.pki.GetCAPool(), localCache) + dropReason := f.firewall.Drop(*fwPacket, true, hostinfo, f.pki.GetCAPool(), localCache, now) if dropReason != nil { // NOTE: We give `packet` as the `out` here since we already decrypted from it and we don't need it anymore // This gives us a buffer to build the reject packet in diff --git a/overlay/vhostnet/device.go b/overlay/vhostnet/device.go index cb0dc6e..8a9489a 100644 --- a/overlay/vhostnet/device.go +++ b/overlay/vhostnet/device.go @@ -6,6 +6,7 @@ import ( "fmt" "os" "runtime" + "slices" "github.com/slackhq/nebula/overlay/vhost" "github.com/slackhq/nebula/overlay/virtqueue" @@ -249,10 +250,6 @@ func (dev *Device) TransmitPackets(vnethdr virtio.NetHdr, packets [][]byte) erro return fmt.Errorf("offer descriptor chain: %w", err) } //todo surely there's something better to do here - doneYet := map[uint16]bool{} - for _, chain := range chainIndexes { - doneYet[chain] = false - } for { txedChains, err := dev.TransmitQueue.BlockAndGetHeads(context.TODO()) @@ -261,31 +258,27 @@ func (dev *Device) TransmitPackets(vnethdr virtio.NetHdr, packets [][]byte) erro } else if len(txedChains) == 0 { continue //todo will this ever exit? } - for c := range txedChains { - doneYet[txedChains[c].GetHead()] = true + for _, c := range txedChains { + idx := slices.Index(chainIndexes, c.GetHead()) + if idx < 0 { + continue + } else { + _ = dev.TransmitQueue.FreeDescriptorChain(chainIndexes[idx]) + chainIndexes[idx] = 0 //todo I hope this works + } } done := true //optimism! - for _, x := range doneYet { - if !x { + for _, x := range chainIndexes { + if x != 0 { done = false break } } if done { - break + return nil } } - - // Wait for the packet to have been transmitted. - for i := range chainIndexes { - - if err = dev.TransmitQueue.FreeDescriptorChain(chainIndexes[i]); err != nil { - return fmt.Errorf("free descriptor chain: %w", err) - } - } - - return nil } // TODO: Make above methods cancelable by taking a context.Context argument? diff --git a/packet/packet.go b/packet/packet.go index 2b7a43b..0d096ec 100644 --- a/packet/packet.go +++ b/packet/packet.go @@ -4,6 +4,7 @@ import ( "encoding/binary" "iter" "net/netip" + "slices" "syscall" "unsafe" @@ -23,7 +24,6 @@ type Packet struct { wasSegmented bool isV4 bool - //Addr netip.AddrPort } func New(isV4 bool) *Packet { @@ -81,17 +81,13 @@ func (p *Packet) SetSegSizeForTX() { hdr.Level = unix.SOL_UDP hdr.Type = unix.UDP_SEGMENT hdr.SetLen(syscall.CmsgLen(2)) - //setCmsgLen(hdr, unix.CmsgLen(2)) binary.NativeEndian.PutUint16(p.Control[unix.CmsgLen(0):unix.CmsgLen(0)+2], uint16(p.SegSize)) - //data := p.Control[syscall.CmsgSpace(0)-syscall.CmsgSpace(2)+syscall.SizeofCmsghdr:] - //binary.NativeEndian.PutUint16(data, uint16(p.SegSize)) } func (p *Packet) CompatibleForSegmentationWith(otherP *Packet) bool { //same dest - - if p.AddrPort() != otherP.AddrPort() { - return false //todo more efficient? + if !slices.Equal(p.Name, otherP.Name) { + return false } //same body len @@ -113,33 +109,5 @@ func (p *Packet) Segments() iter.Seq[[]byte] { return } } - - //if p.SegSize > 0 && p.SegSize < len(p.Payload) { - // - //} else { - // f.readOutsidePackets(p.Addr, nil, result2[:0], p.Payload, h, fwPacket2, lhh, nb2, i, conntrackCache.Get(f.l)) - //} - } } - -//type Pool struct { -// pool sync.Pool -//} -// -//var bigPool = &Pool{ -// pool: sync.Pool{New: func() any { return New() }}, -//} -// -//func GetPool() *Pool { -// return bigPool -//} -// -//func (p *Pool) Get() *Packet { -// return p.pool.Get().(*Packet) -//} -// -//func (p *Pool) Put(x *Packet) { -// x.Payload = x.Payload[:Size] -// p.pool.Put(x) -//}