add more tests around bits counters (#1441)

Co-authored-by: Nate Brown <nbrown.us@gmail.com>
This commit is contained in:
Wade Simmons
2025-11-18 17:42:21 -05:00
committed by GitHub
parent 4df8bcb1f5
commit 27ea667aee
2 changed files with 118 additions and 100 deletions

109
bits.go
View File

@@ -9,14 +9,13 @@ type Bits struct {
length uint64 length uint64
current uint64 current uint64
bits []bool bits []bool
firstSeen bool
lostCounter metrics.Counter lostCounter metrics.Counter
dupeCounter metrics.Counter dupeCounter metrics.Counter
outOfWindowCounter metrics.Counter outOfWindowCounter metrics.Counter
} }
func NewBits(bits uint64) *Bits { func NewBits(bits uint64) *Bits {
return &Bits{ b := &Bits{
length: bits, length: bits,
bits: make([]bool, bits, bits), bits: make([]bool, bits, bits),
current: 0, current: 0,
@@ -24,34 +23,37 @@ func NewBits(bits uint64) *Bits {
dupeCounter: metrics.GetOrRegisterCounter("network.packets.duplicate", nil), dupeCounter: metrics.GetOrRegisterCounter("network.packets.duplicate", nil),
outOfWindowCounter: metrics.GetOrRegisterCounter("network.packets.out_of_window", nil), outOfWindowCounter: metrics.GetOrRegisterCounter("network.packets.out_of_window", nil),
} }
// There is no counter value 0, mark it to avoid counting a lost packet later.
b.bits[0] = true
b.current = 0
return b
} }
func (b *Bits) Check(l logrus.FieldLogger, i uint64) bool { func (b *Bits) Check(l *logrus.Logger, i uint64) bool {
// If i is the next number, return true. // If i is the next number, return true.
if i > b.current || (i == 0 && b.firstSeen == false && b.current < b.length) { if i > b.current {
return true return true
} }
// If i is within the window, check if it's been set already. The first window will fail this check // If i is within the window, check if it's been set already.
if i > b.current-b.length { if i > b.current-b.length || i < b.length && b.current < b.length {
return !b.bits[i%b.length]
}
// If i is within the first window
if i < b.length {
return !b.bits[i%b.length] return !b.bits[i%b.length]
} }
// Not within the window // Not within the window
l.Debugf("rejected a packet (top) %d %d\n", b.current, i) if l.Level >= logrus.DebugLevel {
l.Debugf("rejected a packet (top) %d %d\n", b.current, i)
}
return false return false
} }
func (b *Bits) Update(l *logrus.Logger, i uint64) bool { func (b *Bits) Update(l *logrus.Logger, i uint64) bool {
// If i is the next number, return true and update current. // If i is the next number, return true and update current.
if i == b.current+1 { if i == b.current+1 {
// Report missed packets, we can only understand what was missed after the first window has been gone through // Check if the oldest bit was lost since we are shifting the window by 1 and occupying it with this counter
if i > b.length && b.bits[i%b.length] == false { // The very first window can only be tracked as lost once we are on the 2nd window or greater
if b.bits[i%b.length] == false && i > b.length {
b.lostCounter.Inc(1) b.lostCounter.Inc(1)
} }
b.bits[i%b.length] = true b.bits[i%b.length] = true
@@ -59,61 +61,32 @@ func (b *Bits) Update(l *logrus.Logger, i uint64) bool {
return true return true
} }
// If i packet is greater than current but less than the maximum length of our bitmap, // If i is a jump, adjust the window, record lost, update current, and return true
// flip everything in between to false and move ahead. if i > b.current {
if i > b.current && i < b.current+b.length { lost := int64(0)
// In between current and i need to be zero'd to allow those packets to come in later // Zero out the bits between the current and the new counter value, limited by the window size,
for n := b.current + 1; n < i; n++ { // since the window is shifting
for n := b.current + 1; n <= min(i, b.current+b.length); n++ {
if b.bits[n%b.length] == false && n > b.length {
lost++
}
b.bits[n%b.length] = false b.bits[n%b.length] = false
} }
b.bits[i%b.length] = true // Only record any skipped packets as a result of the window moving further than the window length
b.current = i // Any loss within the new window will be accounted for in future calls
//l.Debugf("missed %d packets between %d and %d\n", i-b.current, i, b.current) lost += max(0, int64(i-b.current-b.length))
return true
}
// If i is greater than the delta between current and the total length of our bitmap,
// just flip everything in the map and move ahead.
if i >= b.current+b.length {
// The current window loss will be accounted for later, only record the jump as loss up until then
lost := maxInt64(0, int64(i-b.current-b.length))
//TODO: explain this
if b.current == 0 {
lost++
}
for n := range b.bits {
// Don't want to count the first window as a loss
//TODO: this is likely wrong, we are wanting to track only the bit slots that we aren't going to track anymore and this is marking everything as missed
//if b.bits[n] == false {
// lost++
//}
b.bits[n] = false
}
b.lostCounter.Inc(lost) b.lostCounter.Inc(lost)
if l.Level >= logrus.DebugLevel {
l.WithField("receiveWindow", m{"accepted": true, "currentCounter": b.current, "incomingCounter": i, "reason": "window shifting"}).
Debug("Receive window")
}
b.bits[i%b.length] = true b.bits[i%b.length] = true
b.current = i b.current = i
return true return true
} }
// Allow for the 0 packet to come in within the first window // If i is within the current window but below the current counter,
if i == 0 && b.firstSeen == false && b.current < b.length { // Check to see if it's a duplicate
b.firstSeen = true if i > b.current-b.length || i < b.length && b.current < b.length {
b.bits[i%b.length] = true if b.current == i || b.bits[i%b.length] == true {
return true
}
// If i is within the window of current minus length (the total pat window size),
// allow it and flip to true but to NOT change current. We also have to account for the first window
if ((b.current >= b.length && i > b.current-b.length) || (b.current < b.length && i < b.length)) && i <= b.current {
if b.current == i {
if l.Level >= logrus.DebugLevel { if l.Level >= logrus.DebugLevel {
l.WithField("receiveWindow", m{"accepted": false, "currentCounter": b.current, "incomingCounter": i, "reason": "duplicate"}). l.WithField("receiveWindow", m{"accepted": false, "currentCounter": b.current, "incomingCounter": i, "reason": "duplicate"}).
Debug("Receive window") Debug("Receive window")
@@ -122,18 +95,8 @@ func (b *Bits) Update(l *logrus.Logger, i uint64) bool {
return false return false
} }
if b.bits[i%b.length] == true {
if l.Level >= logrus.DebugLevel {
l.WithField("receiveWindow", m{"accepted": false, "currentCounter": b.current, "incomingCounter": i, "reason": "old duplicate"}).
Debug("Receive window")
}
b.dupeCounter.Inc(1)
return false
}
b.bits[i%b.length] = true b.bits[i%b.length] = true
return true return true
} }
// In all other cases, fail and don't change current. // In all other cases, fail and don't change current.
@@ -147,11 +110,3 @@ func (b *Bits) Update(l *logrus.Logger, i uint64) bool {
} }
return false return false
} }
func maxInt64(a, b int64) int64 {
if a > b {
return a
}
return b
}

View File

@@ -15,48 +15,41 @@ func TestBits(t *testing.T) {
assert.Len(t, b.bits, 10) assert.Len(t, b.bits, 10)
// This is initialized to zero - receive one. This should work. // This is initialized to zero - receive one. This should work.
assert.True(t, b.Check(l, 1)) assert.True(t, b.Check(l, 1))
u := b.Update(l, 1) assert.True(t, b.Update(l, 1))
assert.True(t, u)
assert.EqualValues(t, 1, b.current) assert.EqualValues(t, 1, b.current)
g := []bool{false, true, false, false, false, false, false, false, false, false} g := []bool{true, true, false, false, false, false, false, false, false, false}
assert.Equal(t, g, b.bits) assert.Equal(t, g, b.bits)
// Receive two // Receive two
assert.True(t, b.Check(l, 2)) assert.True(t, b.Check(l, 2))
u = b.Update(l, 2) assert.True(t, b.Update(l, 2))
assert.True(t, u)
assert.EqualValues(t, 2, b.current) assert.EqualValues(t, 2, b.current)
g = []bool{false, true, true, false, false, false, false, false, false, false} g = []bool{true, true, true, false, false, false, false, false, false, false}
assert.Equal(t, g, b.bits) assert.Equal(t, g, b.bits)
// Receive two again - it will fail // Receive two again - it will fail
assert.False(t, b.Check(l, 2)) assert.False(t, b.Check(l, 2))
u = b.Update(l, 2) assert.False(t, b.Update(l, 2))
assert.False(t, u)
assert.EqualValues(t, 2, b.current) assert.EqualValues(t, 2, b.current)
// Jump ahead to 15, which should clear everything and set the 6th element // Jump ahead to 15, which should clear everything and set the 6th element
assert.True(t, b.Check(l, 15)) assert.True(t, b.Check(l, 15))
u = b.Update(l, 15) assert.True(t, b.Update(l, 15))
assert.True(t, u)
assert.EqualValues(t, 15, b.current) assert.EqualValues(t, 15, b.current)
g = []bool{false, false, false, false, false, true, false, false, false, false} g = []bool{false, false, false, false, false, true, false, false, false, false}
assert.Equal(t, g, b.bits) assert.Equal(t, g, b.bits)
// Mark 14, which is allowed because it is in the window // Mark 14, which is allowed because it is in the window
assert.True(t, b.Check(l, 14)) assert.True(t, b.Check(l, 14))
u = b.Update(l, 14) assert.True(t, b.Update(l, 14))
assert.True(t, u)
assert.EqualValues(t, 15, b.current) assert.EqualValues(t, 15, b.current)
g = []bool{false, false, false, false, true, true, false, false, false, false} g = []bool{false, false, false, false, true, true, false, false, false, false}
assert.Equal(t, g, b.bits) assert.Equal(t, g, b.bits)
// Mark 5, which is not allowed because it is not in the window // Mark 5, which is not allowed because it is not in the window
assert.False(t, b.Check(l, 5)) assert.False(t, b.Check(l, 5))
u = b.Update(l, 5) assert.False(t, b.Update(l, 5))
assert.False(t, u)
assert.EqualValues(t, 15, b.current) assert.EqualValues(t, 15, b.current)
g = []bool{false, false, false, false, true, true, false, false, false, false} g = []bool{false, false, false, false, true, true, false, false, false, false}
assert.Equal(t, g, b.bits) assert.Equal(t, g, b.bits)
@@ -69,10 +62,29 @@ func TestBits(t *testing.T) {
// Walk through a few windows in order // Walk through a few windows in order
b = NewBits(10) b = NewBits(10)
for i := uint64(0); i <= 100; i++ { for i := uint64(1); i <= 100; i++ {
assert.True(t, b.Check(l, i), "Error while checking %v", i) assert.True(t, b.Check(l, i), "Error while checking %v", i)
assert.True(t, b.Update(l, i), "Error while updating %v", i) assert.True(t, b.Update(l, i), "Error while updating %v", i)
} }
assert.False(t, b.Check(l, 1), "Out of window check")
}
func TestBitsLargeJumps(t *testing.T) {
l := test.NewLogger()
b := NewBits(10)
b.lostCounter.Clear()
b = NewBits(10)
b.lostCounter.Clear()
assert.True(t, b.Update(l, 55)) // We saw packet 55 and can still track 45,46,47,48,49,50,51,52,53,54
assert.Equal(t, int64(45), b.lostCounter.Count())
assert.True(t, b.Update(l, 100)) // We saw packet 55 and 100 and can still track 90,91,92,93,94,95,96,97,98,99
assert.Equal(t, int64(89), b.lostCounter.Count())
assert.True(t, b.Update(l, 200)) // We saw packet 55, 100, and 200 and can still track 190,191,192,193,194,195,196,197,198,199
assert.Equal(t, int64(188), b.lostCounter.Count())
} }
func TestBitsDupeCounter(t *testing.T) { func TestBitsDupeCounter(t *testing.T) {
@@ -124,8 +136,7 @@ func TestBitsOutOfWindowCounter(t *testing.T) {
assert.False(t, b.Update(l, 0)) assert.False(t, b.Update(l, 0))
assert.Equal(t, int64(1), b.outOfWindowCounter.Count()) assert.Equal(t, int64(1), b.outOfWindowCounter.Count())
//tODO: make sure lostcounter doesn't increase in orderly increment assert.Equal(t, int64(19), b.lostCounter.Count()) // packet 0 wasn't lost
assert.Equal(t, int64(20), b.lostCounter.Count())
assert.Equal(t, int64(0), b.dupeCounter.Count()) assert.Equal(t, int64(0), b.dupeCounter.Count())
assert.Equal(t, int64(1), b.outOfWindowCounter.Count()) assert.Equal(t, int64(1), b.outOfWindowCounter.Count())
} }
@@ -137,8 +148,6 @@ func TestBitsLostCounter(t *testing.T) {
b.dupeCounter.Clear() b.dupeCounter.Clear()
b.outOfWindowCounter.Clear() b.outOfWindowCounter.Clear()
//assert.True(t, b.Update(0))
assert.True(t, b.Update(l, 0))
assert.True(t, b.Update(l, 20)) assert.True(t, b.Update(l, 20))
assert.True(t, b.Update(l, 21)) assert.True(t, b.Update(l, 21))
assert.True(t, b.Update(l, 22)) assert.True(t, b.Update(l, 22))
@@ -149,7 +158,7 @@ func TestBitsLostCounter(t *testing.T) {
assert.True(t, b.Update(l, 27)) assert.True(t, b.Update(l, 27))
assert.True(t, b.Update(l, 28)) assert.True(t, b.Update(l, 28))
assert.True(t, b.Update(l, 29)) assert.True(t, b.Update(l, 29))
assert.Equal(t, int64(20), b.lostCounter.Count()) assert.Equal(t, int64(19), b.lostCounter.Count()) // packet 0 wasn't lost
assert.Equal(t, int64(0), b.dupeCounter.Count()) assert.Equal(t, int64(0), b.dupeCounter.Count())
assert.Equal(t, int64(0), b.outOfWindowCounter.Count()) assert.Equal(t, int64(0), b.outOfWindowCounter.Count())
@@ -158,8 +167,6 @@ func TestBitsLostCounter(t *testing.T) {
b.dupeCounter.Clear() b.dupeCounter.Clear()
b.outOfWindowCounter.Clear() b.outOfWindowCounter.Clear()
assert.True(t, b.Update(l, 0))
assert.Equal(t, int64(0), b.lostCounter.Count())
assert.True(t, b.Update(l, 9)) assert.True(t, b.Update(l, 9))
assert.Equal(t, int64(0), b.lostCounter.Count()) assert.Equal(t, int64(0), b.lostCounter.Count())
// 10 will set 0 index, 0 was already set, no lost packets // 10 will set 0 index, 0 was already set, no lost packets
@@ -214,6 +221,62 @@ func TestBitsLostCounter(t *testing.T) {
assert.Equal(t, int64(0), b.outOfWindowCounter.Count()) assert.Equal(t, int64(0), b.outOfWindowCounter.Count())
} }
func TestBitsLostCounterIssue1(t *testing.T) {
l := test.NewLogger()
b := NewBits(10)
b.lostCounter.Clear()
b.dupeCounter.Clear()
b.outOfWindowCounter.Clear()
assert.True(t, b.Update(l, 4))
assert.Equal(t, int64(0), b.lostCounter.Count())
assert.True(t, b.Update(l, 1))
assert.Equal(t, int64(0), b.lostCounter.Count())
assert.True(t, b.Update(l, 9))
assert.Equal(t, int64(0), b.lostCounter.Count())
assert.True(t, b.Update(l, 2))
assert.Equal(t, int64(0), b.lostCounter.Count())
assert.True(t, b.Update(l, 3))
assert.Equal(t, int64(0), b.lostCounter.Count())
assert.True(t, b.Update(l, 5))
assert.Equal(t, int64(0), b.lostCounter.Count())
assert.True(t, b.Update(l, 6))
assert.Equal(t, int64(0), b.lostCounter.Count())
assert.True(t, b.Update(l, 7))
assert.Equal(t, int64(0), b.lostCounter.Count())
// assert.True(t, b.Update(l, 8))
assert.True(t, b.Update(l, 10))
assert.Equal(t, int64(0), b.lostCounter.Count())
assert.True(t, b.Update(l, 11))
assert.Equal(t, int64(0), b.lostCounter.Count())
assert.True(t, b.Update(l, 14))
assert.Equal(t, int64(0), b.lostCounter.Count())
// Issue seems to be here, we reset missing packet 8 to false here and don't increment the lost counter
assert.True(t, b.Update(l, 19))
assert.Equal(t, int64(1), b.lostCounter.Count())
assert.True(t, b.Update(l, 12))
assert.Equal(t, int64(1), b.lostCounter.Count())
assert.True(t, b.Update(l, 13))
assert.Equal(t, int64(1), b.lostCounter.Count())
assert.True(t, b.Update(l, 15))
assert.Equal(t, int64(1), b.lostCounter.Count())
assert.True(t, b.Update(l, 16))
assert.Equal(t, int64(1), b.lostCounter.Count())
assert.True(t, b.Update(l, 17))
assert.Equal(t, int64(1), b.lostCounter.Count())
assert.True(t, b.Update(l, 18))
assert.Equal(t, int64(1), b.lostCounter.Count())
assert.True(t, b.Update(l, 20))
assert.Equal(t, int64(1), b.lostCounter.Count())
assert.True(t, b.Update(l, 21))
// We missed packet 8 above
assert.Equal(t, int64(1), b.lostCounter.Count())
assert.Equal(t, int64(0), b.dupeCounter.Count())
assert.Equal(t, int64(0), b.outOfWindowCounter.Count())
}
func BenchmarkBits(b *testing.B) { func BenchmarkBits(b *testing.B) {
z := NewBits(10) z := NewBits(10)
for n := 0; n < b.N; n++ { for n := 0; n < b.N; n++ {