add more tests around bits counters (#1441)

Co-authored-by: Nate Brown <nbrown.us@gmail.com>
This commit is contained in:
Wade Simmons
2025-11-18 17:42:21 -05:00
committed by GitHub
parent 4df8bcb1f5
commit 27ea667aee
2 changed files with 118 additions and 100 deletions

109
bits.go
View File

@@ -9,14 +9,13 @@ type Bits struct {
length uint64
current uint64
bits []bool
firstSeen bool
lostCounter metrics.Counter
dupeCounter metrics.Counter
outOfWindowCounter metrics.Counter
}
func NewBits(bits uint64) *Bits {
return &Bits{
b := &Bits{
length: bits,
bits: make([]bool, bits, bits),
current: 0,
@@ -24,34 +23,37 @@ func NewBits(bits uint64) *Bits {
dupeCounter: metrics.GetOrRegisterCounter("network.packets.duplicate", nil),
outOfWindowCounter: metrics.GetOrRegisterCounter("network.packets.out_of_window", nil),
}
// There is no counter value 0, mark it to avoid counting a lost packet later.
b.bits[0] = true
b.current = 0
return b
}
func (b *Bits) Check(l logrus.FieldLogger, i uint64) bool {
func (b *Bits) Check(l *logrus.Logger, i uint64) bool {
// If i is the next number, return true.
if i > b.current || (i == 0 && b.firstSeen == false && b.current < b.length) {
if i > b.current {
return true
}
// If i is within the window, check if it's been set already. The first window will fail this check
if i > b.current-b.length {
return !b.bits[i%b.length]
}
// If i is within the first window
if i < b.length {
// If i is within the window, check if it's been set already.
if i > b.current-b.length || i < b.length && b.current < b.length {
return !b.bits[i%b.length]
}
// Not within the window
l.Debugf("rejected a packet (top) %d %d\n", b.current, i)
if l.Level >= logrus.DebugLevel {
l.Debugf("rejected a packet (top) %d %d\n", b.current, i)
}
return false
}
func (b *Bits) Update(l *logrus.Logger, i uint64) bool {
// If i is the next number, return true and update current.
if i == b.current+1 {
// Report missed packets, we can only understand what was missed after the first window has been gone through
if i > b.length && b.bits[i%b.length] == false {
// Check if the oldest bit was lost since we are shifting the window by 1 and occupying it with this counter
// The very first window can only be tracked as lost once we are on the 2nd window or greater
if b.bits[i%b.length] == false && i > b.length {
b.lostCounter.Inc(1)
}
b.bits[i%b.length] = true
@@ -59,61 +61,32 @@ func (b *Bits) Update(l *logrus.Logger, i uint64) bool {
return true
}
// If i packet is greater than current but less than the maximum length of our bitmap,
// flip everything in between to false and move ahead.
if i > b.current && i < b.current+b.length {
// In between current and i need to be zero'd to allow those packets to come in later
for n := b.current + 1; n < i; n++ {
// If i is a jump, adjust the window, record lost, update current, and return true
if i > b.current {
lost := int64(0)
// Zero out the bits between the current and the new counter value, limited by the window size,
// since the window is shifting
for n := b.current + 1; n <= min(i, b.current+b.length); n++ {
if b.bits[n%b.length] == false && n > b.length {
lost++
}
b.bits[n%b.length] = false
}
b.bits[i%b.length] = true
b.current = i
//l.Debugf("missed %d packets between %d and %d\n", i-b.current, i, b.current)
return true
}
// If i is greater than the delta between current and the total length of our bitmap,
// just flip everything in the map and move ahead.
if i >= b.current+b.length {
// The current window loss will be accounted for later, only record the jump as loss up until then
lost := maxInt64(0, int64(i-b.current-b.length))
//TODO: explain this
if b.current == 0 {
lost++
}
for n := range b.bits {
// Don't want to count the first window as a loss
//TODO: this is likely wrong, we are wanting to track only the bit slots that we aren't going to track anymore and this is marking everything as missed
//if b.bits[n] == false {
// lost++
//}
b.bits[n] = false
}
// Only record any skipped packets as a result of the window moving further than the window length
// Any loss within the new window will be accounted for in future calls
lost += max(0, int64(i-b.current-b.length))
b.lostCounter.Inc(lost)
if l.Level >= logrus.DebugLevel {
l.WithField("receiveWindow", m{"accepted": true, "currentCounter": b.current, "incomingCounter": i, "reason": "window shifting"}).
Debug("Receive window")
}
b.bits[i%b.length] = true
b.current = i
return true
}
// Allow for the 0 packet to come in within the first window
if i == 0 && b.firstSeen == false && b.current < b.length {
b.firstSeen = true
b.bits[i%b.length] = true
return true
}
// If i is within the window of current minus length (the total pat window size),
// allow it and flip to true but to NOT change current. We also have to account for the first window
if ((b.current >= b.length && i > b.current-b.length) || (b.current < b.length && i < b.length)) && i <= b.current {
if b.current == i {
// If i is within the current window but below the current counter,
// Check to see if it's a duplicate
if i > b.current-b.length || i < b.length && b.current < b.length {
if b.current == i || b.bits[i%b.length] == true {
if l.Level >= logrus.DebugLevel {
l.WithField("receiveWindow", m{"accepted": false, "currentCounter": b.current, "incomingCounter": i, "reason": "duplicate"}).
Debug("Receive window")
@@ -122,18 +95,8 @@ func (b *Bits) Update(l *logrus.Logger, i uint64) bool {
return false
}
if b.bits[i%b.length] == true {
if l.Level >= logrus.DebugLevel {
l.WithField("receiveWindow", m{"accepted": false, "currentCounter": b.current, "incomingCounter": i, "reason": "old duplicate"}).
Debug("Receive window")
}
b.dupeCounter.Inc(1)
return false
}
b.bits[i%b.length] = true
return true
}
// In all other cases, fail and don't change current.
@@ -147,11 +110,3 @@ func (b *Bits) Update(l *logrus.Logger, i uint64) bool {
}
return false
}
func maxInt64(a, b int64) int64 {
if a > b {
return a
}
return b
}

View File

@@ -15,48 +15,41 @@ func TestBits(t *testing.T) {
assert.Len(t, b.bits, 10)
// This is initialized to zero - receive one. This should work.
assert.True(t, b.Check(l, 1))
u := b.Update(l, 1)
assert.True(t, u)
assert.True(t, b.Update(l, 1))
assert.EqualValues(t, 1, b.current)
g := []bool{false, true, false, false, false, false, false, false, false, false}
g := []bool{true, true, false, false, false, false, false, false, false, false}
assert.Equal(t, g, b.bits)
// Receive two
assert.True(t, b.Check(l, 2))
u = b.Update(l, 2)
assert.True(t, u)
assert.True(t, b.Update(l, 2))
assert.EqualValues(t, 2, b.current)
g = []bool{false, true, true, false, false, false, false, false, false, false}
g = []bool{true, true, true, false, false, false, false, false, false, false}
assert.Equal(t, g, b.bits)
// Receive two again - it will fail
assert.False(t, b.Check(l, 2))
u = b.Update(l, 2)
assert.False(t, u)
assert.False(t, b.Update(l, 2))
assert.EqualValues(t, 2, b.current)
// Jump ahead to 15, which should clear everything and set the 6th element
assert.True(t, b.Check(l, 15))
u = b.Update(l, 15)
assert.True(t, u)
assert.True(t, b.Update(l, 15))
assert.EqualValues(t, 15, b.current)
g = []bool{false, false, false, false, false, true, false, false, false, false}
assert.Equal(t, g, b.bits)
// Mark 14, which is allowed because it is in the window
assert.True(t, b.Check(l, 14))
u = b.Update(l, 14)
assert.True(t, u)
assert.True(t, b.Update(l, 14))
assert.EqualValues(t, 15, b.current)
g = []bool{false, false, false, false, true, true, false, false, false, false}
assert.Equal(t, g, b.bits)
// Mark 5, which is not allowed because it is not in the window
assert.False(t, b.Check(l, 5))
u = b.Update(l, 5)
assert.False(t, u)
assert.False(t, b.Update(l, 5))
assert.EqualValues(t, 15, b.current)
g = []bool{false, false, false, false, true, true, false, false, false, false}
assert.Equal(t, g, b.bits)
@@ -69,10 +62,29 @@ func TestBits(t *testing.T) {
// Walk through a few windows in order
b = NewBits(10)
for i := uint64(0); i <= 100; i++ {
for i := uint64(1); i <= 100; i++ {
assert.True(t, b.Check(l, i), "Error while checking %v", i)
assert.True(t, b.Update(l, i), "Error while updating %v", i)
}
assert.False(t, b.Check(l, 1), "Out of window check")
}
func TestBitsLargeJumps(t *testing.T) {
l := test.NewLogger()
b := NewBits(10)
b.lostCounter.Clear()
b = NewBits(10)
b.lostCounter.Clear()
assert.True(t, b.Update(l, 55)) // We saw packet 55 and can still track 45,46,47,48,49,50,51,52,53,54
assert.Equal(t, int64(45), b.lostCounter.Count())
assert.True(t, b.Update(l, 100)) // We saw packet 55 and 100 and can still track 90,91,92,93,94,95,96,97,98,99
assert.Equal(t, int64(89), b.lostCounter.Count())
assert.True(t, b.Update(l, 200)) // We saw packet 55, 100, and 200 and can still track 190,191,192,193,194,195,196,197,198,199
assert.Equal(t, int64(188), b.lostCounter.Count())
}
func TestBitsDupeCounter(t *testing.T) {
@@ -124,8 +136,7 @@ func TestBitsOutOfWindowCounter(t *testing.T) {
assert.False(t, b.Update(l, 0))
assert.Equal(t, int64(1), b.outOfWindowCounter.Count())
//tODO: make sure lostcounter doesn't increase in orderly increment
assert.Equal(t, int64(20), b.lostCounter.Count())
assert.Equal(t, int64(19), b.lostCounter.Count()) // packet 0 wasn't lost
assert.Equal(t, int64(0), b.dupeCounter.Count())
assert.Equal(t, int64(1), b.outOfWindowCounter.Count())
}
@@ -137,8 +148,6 @@ func TestBitsLostCounter(t *testing.T) {
b.dupeCounter.Clear()
b.outOfWindowCounter.Clear()
//assert.True(t, b.Update(0))
assert.True(t, b.Update(l, 0))
assert.True(t, b.Update(l, 20))
assert.True(t, b.Update(l, 21))
assert.True(t, b.Update(l, 22))
@@ -149,7 +158,7 @@ func TestBitsLostCounter(t *testing.T) {
assert.True(t, b.Update(l, 27))
assert.True(t, b.Update(l, 28))
assert.True(t, b.Update(l, 29))
assert.Equal(t, int64(20), b.lostCounter.Count())
assert.Equal(t, int64(19), b.lostCounter.Count()) // packet 0 wasn't lost
assert.Equal(t, int64(0), b.dupeCounter.Count())
assert.Equal(t, int64(0), b.outOfWindowCounter.Count())
@@ -158,8 +167,6 @@ func TestBitsLostCounter(t *testing.T) {
b.dupeCounter.Clear()
b.outOfWindowCounter.Clear()
assert.True(t, b.Update(l, 0))
assert.Equal(t, int64(0), b.lostCounter.Count())
assert.True(t, b.Update(l, 9))
assert.Equal(t, int64(0), b.lostCounter.Count())
// 10 will set 0 index, 0 was already set, no lost packets
@@ -214,6 +221,62 @@ func TestBitsLostCounter(t *testing.T) {
assert.Equal(t, int64(0), b.outOfWindowCounter.Count())
}
func TestBitsLostCounterIssue1(t *testing.T) {
l := test.NewLogger()
b := NewBits(10)
b.lostCounter.Clear()
b.dupeCounter.Clear()
b.outOfWindowCounter.Clear()
assert.True(t, b.Update(l, 4))
assert.Equal(t, int64(0), b.lostCounter.Count())
assert.True(t, b.Update(l, 1))
assert.Equal(t, int64(0), b.lostCounter.Count())
assert.True(t, b.Update(l, 9))
assert.Equal(t, int64(0), b.lostCounter.Count())
assert.True(t, b.Update(l, 2))
assert.Equal(t, int64(0), b.lostCounter.Count())
assert.True(t, b.Update(l, 3))
assert.Equal(t, int64(0), b.lostCounter.Count())
assert.True(t, b.Update(l, 5))
assert.Equal(t, int64(0), b.lostCounter.Count())
assert.True(t, b.Update(l, 6))
assert.Equal(t, int64(0), b.lostCounter.Count())
assert.True(t, b.Update(l, 7))
assert.Equal(t, int64(0), b.lostCounter.Count())
// assert.True(t, b.Update(l, 8))
assert.True(t, b.Update(l, 10))
assert.Equal(t, int64(0), b.lostCounter.Count())
assert.True(t, b.Update(l, 11))
assert.Equal(t, int64(0), b.lostCounter.Count())
assert.True(t, b.Update(l, 14))
assert.Equal(t, int64(0), b.lostCounter.Count())
// Issue seems to be here, we reset missing packet 8 to false here and don't increment the lost counter
assert.True(t, b.Update(l, 19))
assert.Equal(t, int64(1), b.lostCounter.Count())
assert.True(t, b.Update(l, 12))
assert.Equal(t, int64(1), b.lostCounter.Count())
assert.True(t, b.Update(l, 13))
assert.Equal(t, int64(1), b.lostCounter.Count())
assert.True(t, b.Update(l, 15))
assert.Equal(t, int64(1), b.lostCounter.Count())
assert.True(t, b.Update(l, 16))
assert.Equal(t, int64(1), b.lostCounter.Count())
assert.True(t, b.Update(l, 17))
assert.Equal(t, int64(1), b.lostCounter.Count())
assert.True(t, b.Update(l, 18))
assert.Equal(t, int64(1), b.lostCounter.Count())
assert.True(t, b.Update(l, 20))
assert.Equal(t, int64(1), b.lostCounter.Count())
assert.True(t, b.Update(l, 21))
// We missed packet 8 above
assert.Equal(t, int64(1), b.lostCounter.Count())
assert.Equal(t, int64(0), b.dupeCounter.Count())
assert.Equal(t, int64(0), b.outOfWindowCounter.Count())
}
func BenchmarkBits(b *testing.B) {
z := NewBits(10)
for n := 0; n < b.N; n++ {