diff --git a/firewall_test.go b/firewall_test.go index ce6ba18..4f24ac0 100644 --- a/firewall_test.go +++ b/firewall_test.go @@ -34,27 +34,27 @@ func TestNewFirewall(t *testing.T) { assert.Equal(t, time.Hour, conntrack.TimerWheel.wheelDuration) assert.Equal(t, time.Hour, conntrack.TimerWheel.wheelDuration) - assert.Equal(t, 3601, conntrack.TimerWheel.wheelLen) + assert.Equal(t, 3602, conntrack.TimerWheel.wheelLen) fw = NewFirewall(l, time.Second, time.Hour, time.Minute, c) assert.Equal(t, time.Hour, conntrack.TimerWheel.wheelDuration) - assert.Equal(t, 3601, conntrack.TimerWheel.wheelLen) + assert.Equal(t, 3602, conntrack.TimerWheel.wheelLen) fw = NewFirewall(l, time.Hour, time.Second, time.Minute, c) assert.Equal(t, time.Hour, conntrack.TimerWheel.wheelDuration) - assert.Equal(t, 3601, conntrack.TimerWheel.wheelLen) + assert.Equal(t, 3602, conntrack.TimerWheel.wheelLen) fw = NewFirewall(l, time.Hour, time.Minute, time.Second, c) assert.Equal(t, time.Hour, conntrack.TimerWheel.wheelDuration) - assert.Equal(t, 3601, conntrack.TimerWheel.wheelLen) + assert.Equal(t, 3602, conntrack.TimerWheel.wheelLen) fw = NewFirewall(l, time.Minute, time.Hour, time.Second, c) assert.Equal(t, time.Hour, conntrack.TimerWheel.wheelDuration) - assert.Equal(t, 3601, conntrack.TimerWheel.wheelLen) + assert.Equal(t, 3602, conntrack.TimerWheel.wheelLen) fw = NewFirewall(l, time.Minute, time.Second, time.Hour, c) assert.Equal(t, time.Hour, conntrack.TimerWheel.wheelDuration) - assert.Equal(t, 3601, conntrack.TimerWheel.wheelLen) + assert.Equal(t, 3602, conntrack.TimerWheel.wheelLen) } func TestFirewall_AddRule(t *testing.T) { diff --git a/timeout.go b/timeout.go index fe63f3e..6d8f68b 100644 --- a/timeout.go +++ b/timeout.go @@ -36,19 +36,19 @@ type TimerWheel struct { itemsCached int } -// Represents a tick in the wheel +// TimeoutList Represents a tick in the wheel type TimeoutList struct { Head *TimeoutItem Tail *TimeoutItem } -// Represents an item within a tick +// TimeoutItem Represents an item within a tick type TimeoutItem struct { Packet firewall.Packet Next *TimeoutItem } -// Builds a timer wheel and identifies the tick duration and wheel duration from the provided values +// NewTimerWheel Builds a timer wheel and identifies the tick duration and wheel duration from the provided values // Purge must be called once per entry to actually remove anything func NewTimerWheel(min, max time.Duration) *TimerWheel { //TODO provide an error @@ -56,9 +56,10 @@ func NewTimerWheel(min, max time.Duration) *TimerWheel { // return nil //} - // Round down and add 1 so we can have the smallest # of ticks in the wheel and still account for a full - // max duration - wLen := int((max / min) + 1) + // Round down and add 2 so we can have the smallest # of ticks in the wheel and still account for a full + // max duration, even if our current tick is at the maximum position and the next item to be added is at maximum + // timeout + wLen := int((max / min) + 2) tw := TimerWheel{ wheelLen: wLen, diff --git a/timeout_system.go b/timeout_system.go index 72f6af9..c39d9cd 100644 --- a/timeout_system.go +++ b/timeout_system.go @@ -37,19 +37,19 @@ type SystemTimerWheel struct { lock sync.Mutex } -// Represents a tick in the wheel +// SystemTimeoutList Represents a tick in the wheel type SystemTimeoutList struct { Head *SystemTimeoutItem Tail *SystemTimeoutItem } -// Represents an item within a tick +// SystemTimeoutItem Represents an item within a tick type SystemTimeoutItem struct { Item iputil.VpnIp Next *SystemTimeoutItem } -// Builds a timer wheel and identifies the tick duration and wheel duration from the provided values +// NewSystemTimerWheel Builds a timer wheel and identifies the tick duration and wheel duration from the provided values // Purge must be called once per entry to actually remove anything func NewSystemTimerWheel(min, max time.Duration) *SystemTimerWheel { //TODO provide an error @@ -57,9 +57,10 @@ func NewSystemTimerWheel(min, max time.Duration) *SystemTimerWheel { // return nil //} - // Round down and add 1 so we can have the smallest # of ticks in the wheel and still account for a full - // max duration - wLen := int((max / min) + 1) + // Round down and add 2 so we can have the smallest # of ticks in the wheel and still account for a full + // max duration, even if our current tick is at the maximum position and the next item to be added is at maximum + // timeout + wLen := int((max / min) + 2) tw := SystemTimerWheel{ wheelLen: wLen, diff --git a/timeout_system_test.go b/timeout_system_test.go index 41c64a0..ba3c22b 100644 --- a/timeout_system_test.go +++ b/timeout_system_test.go @@ -12,24 +12,24 @@ import ( func TestNewSystemTimerWheel(t *testing.T) { // Make sure we get an object we expect tw := NewSystemTimerWheel(time.Second, time.Second*10) - assert.Equal(t, 11, tw.wheelLen) + assert.Equal(t, 12, tw.wheelLen) assert.Equal(t, 0, tw.current) assert.Nil(t, tw.lastTick) assert.Equal(t, time.Second*1, tw.tickDuration) assert.Equal(t, time.Second*10, tw.wheelDuration) - assert.Len(t, tw.wheel, 11) + assert.Len(t, tw.wheel, 12) // Assert the math is correct tw = NewSystemTimerWheel(time.Second*3, time.Second*10) - assert.Equal(t, 4, tw.wheelLen) + assert.Equal(t, 5, tw.wheelLen) tw = NewSystemTimerWheel(time.Second*120, time.Minute*10) - assert.Equal(t, 6, tw.wheelLen) + assert.Equal(t, 7, tw.wheelLen) } func TestSystemTimerWheel_findWheel(t *testing.T) { tw := NewSystemTimerWheel(time.Second, time.Second*10) - assert.Len(t, tw.wheel, 11) + assert.Len(t, tw.wheel, 12) // Current + tick + 1 since we don't know how far into current we are assert.Equal(t, 2, tw.findWheel(time.Second*1)) @@ -38,15 +38,32 @@ func TestSystemTimerWheel_findWheel(t *testing.T) { assert.Equal(t, 2, tw.findWheel(time.Millisecond*1)) // Make sure we hit that last index - assert.Equal(t, 0, tw.findWheel(time.Second*10)) + assert.Equal(t, 11, tw.findWheel(time.Second*10)) // Scale down to max duration - assert.Equal(t, 0, tw.findWheel(time.Second*11)) + assert.Equal(t, 11, tw.findWheel(time.Second*11)) tw.current = 1 // Make sure we account for the current position properly assert.Equal(t, 3, tw.findWheel(time.Second*1)) - assert.Equal(t, 1, tw.findWheel(time.Second*10)) + assert.Equal(t, 0, tw.findWheel(time.Second*10)) + + // Ensure that all configurations of a wheel does not result in calculating an overflow of the wheel + for min := time.Duration(1); min < 100; min++ { + for max := min; max < 100; max++ { + tw = NewSystemTimerWheel(min, max) + + for current := 0; current < tw.wheelLen; current++ { + tw.current = current + for timeout := time.Duration(0); timeout <= tw.wheelDuration; timeout++ { + tick := tw.findWheel(timeout) + if tick >= tw.wheelLen { + t.Errorf("Min: %v; Max: %v; Wheel len: %v; Current Tick: %v; Insert timeout: %v; Calc tick: %v", min, max, tw.wheelLen, current, timeout, tick) + } + } + } + } + } } func TestSystemTimerWheel_Add(t *testing.T) { @@ -129,6 +146,10 @@ func TestSystemTimerWheel_Purge(t *testing.T) { tw.advance(ta) assert.Equal(t, 10, tw.current) + ta = ta.Add(time.Second * 1) + tw.advance(ta) + assert.Equal(t, 11, tw.current) + ta = ta.Add(time.Second * 1) tw.advance(ta) assert.Equal(t, 0, tw.current) diff --git a/timeout_test.go b/timeout_test.go index 9678b35..70b107c 100644 --- a/timeout_test.go +++ b/timeout_test.go @@ -11,24 +11,24 @@ import ( func TestNewTimerWheel(t *testing.T) { // Make sure we get an object we expect tw := NewTimerWheel(time.Second, time.Second*10) - assert.Equal(t, 11, tw.wheelLen) + assert.Equal(t, 12, tw.wheelLen) assert.Equal(t, 0, tw.current) assert.Nil(t, tw.lastTick) assert.Equal(t, time.Second*1, tw.tickDuration) assert.Equal(t, time.Second*10, tw.wheelDuration) - assert.Len(t, tw.wheel, 11) + assert.Len(t, tw.wheel, 12) // Assert the math is correct tw = NewTimerWheel(time.Second*3, time.Second*10) - assert.Equal(t, 4, tw.wheelLen) + assert.Equal(t, 5, tw.wheelLen) tw = NewTimerWheel(time.Second*120, time.Minute*10) - assert.Equal(t, 6, tw.wheelLen) + assert.Equal(t, 7, tw.wheelLen) } func TestTimerWheel_findWheel(t *testing.T) { tw := NewTimerWheel(time.Second, time.Second*10) - assert.Len(t, tw.wheel, 11) + assert.Len(t, tw.wheel, 12) // Current + tick + 1 since we don't know how far into current we are assert.Equal(t, 2, tw.findWheel(time.Second*1)) @@ -37,15 +37,15 @@ func TestTimerWheel_findWheel(t *testing.T) { assert.Equal(t, 2, tw.findWheel(time.Millisecond*1)) // Make sure we hit that last index - assert.Equal(t, 0, tw.findWheel(time.Second*10)) + assert.Equal(t, 11, tw.findWheel(time.Second*10)) // Scale down to max duration - assert.Equal(t, 0, tw.findWheel(time.Second*11)) + assert.Equal(t, 11, tw.findWheel(time.Second*11)) tw.current = 1 // Make sure we account for the current position properly assert.Equal(t, 3, tw.findWheel(time.Second*1)) - assert.Equal(t, 1, tw.findWheel(time.Second*10)) + assert.Equal(t, 0, tw.findWheel(time.Second*10)) } func TestTimerWheel_Add(t *testing.T) { @@ -75,6 +75,23 @@ func TestTimerWheel_Add(t *testing.T) { tw.Add(fp2, time.Second*1) assert.Nil(t, tw.itemCache) assert.Equal(t, 0, tw.itemsCached) + + // Ensure that all configurations of a wheel does not result in calculating an overflow of the wheel + for min := time.Duration(1); min < 100; min++ { + for max := min; max < 100; max++ { + tw = NewTimerWheel(min, max) + + for current := 0; current < tw.wheelLen; current++ { + tw.current = current + for timeout := time.Duration(0); timeout <= tw.wheelDuration; timeout++ { + tick := tw.findWheel(timeout) + if tick >= tw.wheelLen { + t.Errorf("Min: %v; Max: %v; Wheel len: %v; Current Tick: %v; Insert timeout: %v; Calc tick: %v", min, max, tw.wheelLen, current, timeout, tick) + } + } + } + } + } } func TestTimerWheel_Purge(t *testing.T) { @@ -134,6 +151,10 @@ func TestTimerWheel_Purge(t *testing.T) { tw.advance(ta) assert.Equal(t, 10, tw.current) + ta = ta.Add(time.Second * 1) + tw.advance(ta) + assert.Equal(t, 11, tw.current) + ta = ta.Add(time.Second * 1) tw.advance(ta) assert.Equal(t, 0, tw.current)