mirror of
https://github.com/slackhq/nebula.git
synced 2026-05-15 20:37:36 +02:00
Handshake state machine (#1656)
This commit is contained in:
361
handshake/payload_test.go
Normal file
361
handshake/payload_test.go
Normal file
@@ -0,0 +1,361 @@
|
||||
package handshake
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"math"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"google.golang.org/protobuf/encoding/protowire"
|
||||
)
|
||||
|
||||
func TestPayloadRoundTrip(t *testing.T) {
|
||||
t.Run("all fields set", func(t *testing.T) {
|
||||
data := MarshalPayload(nil, Payload{
|
||||
Cert: []byte("test-cert-bytes"),
|
||||
CertVersion: 2,
|
||||
InitiatorIndex: 12345,
|
||||
ResponderIndex: 67890,
|
||||
Time: 1234567890,
|
||||
})
|
||||
|
||||
got, err := UnmarshalPayload(data)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, []byte("test-cert-bytes"), got.Cert)
|
||||
assert.Equal(t, uint32(12345), got.InitiatorIndex)
|
||||
assert.Equal(t, uint32(67890), got.ResponderIndex)
|
||||
assert.Equal(t, uint64(1234567890), got.Time)
|
||||
assert.Equal(t, uint32(2), got.CertVersion)
|
||||
})
|
||||
|
||||
t.Run("minimal fields", func(t *testing.T) {
|
||||
data := MarshalPayload(nil, Payload{InitiatorIndex: 1})
|
||||
|
||||
got, err := UnmarshalPayload(data)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, uint32(1), got.InitiatorIndex)
|
||||
assert.Equal(t, uint32(0), got.ResponderIndex)
|
||||
assert.Equal(t, uint64(0), got.Time)
|
||||
assert.Nil(t, got.Cert)
|
||||
})
|
||||
|
||||
t.Run("empty payload", func(t *testing.T) {
|
||||
data := MarshalPayload(nil, Payload{})
|
||||
|
||||
got, err := UnmarshalPayload(data)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, uint32(0), got.InitiatorIndex)
|
||||
})
|
||||
|
||||
t.Run("large cert bytes", func(t *testing.T) {
|
||||
bigCert := make([]byte, 4096)
|
||||
for i := range bigCert {
|
||||
bigCert[i] = byte(i % 256)
|
||||
}
|
||||
|
||||
data := MarshalPayload(nil, Payload{
|
||||
Cert: bigCert,
|
||||
CertVersion: 2,
|
||||
InitiatorIndex: 999,
|
||||
})
|
||||
|
||||
got, err := UnmarshalPayload(data)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, bigCert, got.Cert)
|
||||
assert.Equal(t, uint32(999), got.InitiatorIndex)
|
||||
})
|
||||
|
||||
t.Run("append to existing buffer", func(t *testing.T) {
|
||||
prefix := []byte("prefix")
|
||||
data := MarshalPayload(prefix, Payload{InitiatorIndex: 42})
|
||||
|
||||
assert.Equal(t, []byte("prefix"), data[:6])
|
||||
|
||||
got, err := UnmarshalPayload(data[6:])
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, uint32(42), got.InitiatorIndex)
|
||||
})
|
||||
}
|
||||
|
||||
func TestPayloadUnknownFields(t *testing.T) {
|
||||
t.Run("unknown field in outer message is skipped", func(t *testing.T) {
|
||||
// Marshal a normal payload then append an unknown field (field 99, varint)
|
||||
data := MarshalPayload(nil, Payload{InitiatorIndex: 42})
|
||||
data = protowire.AppendTag(data, 99, protowire.VarintType)
|
||||
data = protowire.AppendVarint(data, 12345)
|
||||
|
||||
got, err := UnmarshalPayload(data)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, uint32(42), got.InitiatorIndex)
|
||||
})
|
||||
|
||||
t.Run("unknown field in details is skipped", func(t *testing.T) {
|
||||
// Build details with a known field + unknown field
|
||||
var details []byte
|
||||
details = protowire.AppendTag(details, fieldInitiatorIndex, protowire.VarintType)
|
||||
details = protowire.AppendVarint(details, 77)
|
||||
// Unknown field 50, varint
|
||||
details = protowire.AppendTag(details, 50, protowire.VarintType)
|
||||
details = protowire.AppendVarint(details, 9999)
|
||||
// Another known field after the unknown one
|
||||
details = protowire.AppendTag(details, fieldResponderIndex, protowire.VarintType)
|
||||
details = protowire.AppendVarint(details, 88)
|
||||
|
||||
// Wrap in outer message
|
||||
var data []byte
|
||||
data = protowire.AppendTag(data, 1, protowire.BytesType)
|
||||
data = protowire.AppendBytes(data, details)
|
||||
|
||||
got, err := UnmarshalPayload(data)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, uint32(77), got.InitiatorIndex)
|
||||
assert.Equal(t, uint32(88), got.ResponderIndex)
|
||||
})
|
||||
|
||||
t.Run("reserved fields 6 and 7 are skipped", func(t *testing.T) {
|
||||
// Fields 6 and 7 are reserved in the proto definition
|
||||
var details []byte
|
||||
details = protowire.AppendTag(details, fieldInitiatorIndex, protowire.VarintType)
|
||||
details = protowire.AppendVarint(details, 100)
|
||||
details = protowire.AppendTag(details, 6, protowire.VarintType)
|
||||
details = protowire.AppendVarint(details, 1)
|
||||
details = protowire.AppendTag(details, 7, protowire.VarintType)
|
||||
details = protowire.AppendVarint(details, 2)
|
||||
|
||||
var data []byte
|
||||
data = protowire.AppendTag(data, 1, protowire.BytesType)
|
||||
data = protowire.AppendBytes(data, details)
|
||||
|
||||
got, err := UnmarshalPayload(data)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, uint32(100), got.InitiatorIndex)
|
||||
})
|
||||
}
|
||||
|
||||
func TestPayloadBytesConsumed(t *testing.T) {
|
||||
t.Run("all bytes consumed on valid input", func(t *testing.T) {
|
||||
original := Payload{
|
||||
Cert: []byte("cert"),
|
||||
CertVersion: 2,
|
||||
InitiatorIndex: 100,
|
||||
ResponderIndex: 200,
|
||||
Time: 999,
|
||||
}
|
||||
data := MarshalPayload(nil, original)
|
||||
|
||||
got, err := UnmarshalPayload(data)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Re-marshal and compare — proves we consumed and reproduced all fields
|
||||
remarshaled := MarshalPayload(nil, got)
|
||||
assert.Equal(t, data, remarshaled)
|
||||
})
|
||||
}
|
||||
|
||||
// wrapDetails wraps raw detail bytes in the outer NebulaHandshake envelope
|
||||
// so UnmarshalPayload can reach unmarshalPayloadDetails.
|
||||
func wrapDetails(details []byte) []byte {
|
||||
var out []byte
|
||||
out = protowire.AppendTag(out, 1, protowire.BytesType)
|
||||
out = protowire.AppendBytes(out, details)
|
||||
return out
|
||||
}
|
||||
|
||||
func TestPayloadUnmarshalErrors(t *testing.T) {
|
||||
t.Run("nil input", func(t *testing.T) {
|
||||
got, err := UnmarshalPayload(nil)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, uint32(0), got.InitiatorIndex)
|
||||
})
|
||||
|
||||
t.Run("truncated outer tag", func(t *testing.T) {
|
||||
_, err := UnmarshalPayload([]byte{0x80})
|
||||
assert.Error(t, err)
|
||||
})
|
||||
|
||||
t.Run("truncated outer details field", func(t *testing.T) {
|
||||
_, err := UnmarshalPayload([]byte{0x0a, 0x64, 0x01, 0x02, 0x03, 0x04, 0x05})
|
||||
assert.Error(t, err)
|
||||
})
|
||||
|
||||
t.Run("truncated outer unknown field", func(t *testing.T) {
|
||||
// Valid tag for unknown field 99 varint, but no value follows
|
||||
var data []byte
|
||||
data = protowire.AppendTag(data, 99, protowire.VarintType)
|
||||
_, err := UnmarshalPayload(data)
|
||||
assert.Error(t, err)
|
||||
})
|
||||
|
||||
t.Run("truncated details tag", func(t *testing.T) {
|
||||
_, err := UnmarshalPayload(wrapDetails([]byte{0x80}))
|
||||
assert.Error(t, err)
|
||||
})
|
||||
|
||||
t.Run("truncated cert bytes", func(t *testing.T) {
|
||||
// Field 1 (cert), bytes type, length 10 but only 2 bytes
|
||||
var details []byte
|
||||
details = protowire.AppendTag(details, fieldCert, protowire.BytesType)
|
||||
details = append(details, 0x0a, 0x01, 0x02) // length 10, only 2 bytes
|
||||
_, err := UnmarshalPayload(wrapDetails(details))
|
||||
assert.Error(t, err)
|
||||
})
|
||||
|
||||
t.Run("truncated initiator index varint", func(t *testing.T) {
|
||||
var details []byte
|
||||
details = protowire.AppendTag(details, fieldInitiatorIndex, protowire.VarintType)
|
||||
details = append(details, 0x80) // incomplete varint
|
||||
_, err := UnmarshalPayload(wrapDetails(details))
|
||||
assert.Error(t, err)
|
||||
})
|
||||
|
||||
t.Run("truncated responder index varint", func(t *testing.T) {
|
||||
var details []byte
|
||||
details = protowire.AppendTag(details, fieldResponderIndex, protowire.VarintType)
|
||||
details = append(details, 0x80)
|
||||
_, err := UnmarshalPayload(wrapDetails(details))
|
||||
assert.Error(t, err)
|
||||
})
|
||||
|
||||
t.Run("truncated time varint", func(t *testing.T) {
|
||||
var details []byte
|
||||
details = protowire.AppendTag(details, fieldTime, protowire.VarintType)
|
||||
details = append(details, 0x80)
|
||||
_, err := UnmarshalPayload(wrapDetails(details))
|
||||
assert.Error(t, err)
|
||||
})
|
||||
|
||||
t.Run("truncated cert version varint", func(t *testing.T) {
|
||||
var details []byte
|
||||
details = protowire.AppendTag(details, fieldCertVersion, protowire.VarintType)
|
||||
details = append(details, 0x80)
|
||||
_, err := UnmarshalPayload(wrapDetails(details))
|
||||
assert.Error(t, err)
|
||||
})
|
||||
|
||||
t.Run("truncated unknown field in details", func(t *testing.T) {
|
||||
var details []byte
|
||||
details = protowire.AppendTag(details, 50, protowire.VarintType)
|
||||
details = append(details, 0x80) // incomplete varint
|
||||
_, err := UnmarshalPayload(wrapDetails(details))
|
||||
assert.Error(t, err)
|
||||
})
|
||||
|
||||
t.Run("cert with wrong wire type rejected", func(t *testing.T) {
|
||||
// fieldCert as Varint instead of Bytes.
|
||||
var details []byte
|
||||
details = protowire.AppendTag(details, fieldCert, protowire.VarintType)
|
||||
details = protowire.AppendVarint(details, 42)
|
||||
_, err := UnmarshalPayload(wrapDetails(details))
|
||||
assert.Error(t, err)
|
||||
})
|
||||
|
||||
t.Run("initiator index with wrong wire type rejected", func(t *testing.T) {
|
||||
// fieldInitiatorIndex as Bytes instead of Varint.
|
||||
var details []byte
|
||||
details = protowire.AppendTag(details, fieldInitiatorIndex, protowire.BytesType)
|
||||
details = protowire.AppendBytes(details, []byte{1, 2, 3})
|
||||
_, err := UnmarshalPayload(wrapDetails(details))
|
||||
assert.Error(t, err)
|
||||
})
|
||||
|
||||
t.Run("time with wrong wire type rejected", func(t *testing.T) {
|
||||
var details []byte
|
||||
details = protowire.AppendTag(details, fieldTime, protowire.BytesType)
|
||||
details = protowire.AppendBytes(details, []byte{1, 2, 3})
|
||||
_, err := UnmarshalPayload(wrapDetails(details))
|
||||
assert.Error(t, err)
|
||||
})
|
||||
|
||||
t.Run("cert version with wrong wire type rejected", func(t *testing.T) {
|
||||
var details []byte
|
||||
details = protowire.AppendTag(details, fieldCertVersion, protowire.BytesType)
|
||||
details = protowire.AppendBytes(details, []byte{1, 2, 3})
|
||||
_, err := UnmarshalPayload(wrapDetails(details))
|
||||
assert.Error(t, err)
|
||||
})
|
||||
|
||||
t.Run("repeated singular field follows proto3 last-wins", func(t *testing.T) {
|
||||
// Per proto3, multiple instances of a singular field are accepted and
|
||||
// the last value wins. We keep this behavior so that peers using
|
||||
// alternative encoders aren't rejected.
|
||||
var details []byte
|
||||
details = protowire.AppendTag(details, fieldInitiatorIndex, protowire.VarintType)
|
||||
details = protowire.AppendVarint(details, 1)
|
||||
details = protowire.AppendTag(details, fieldInitiatorIndex, protowire.VarintType)
|
||||
details = protowire.AppendVarint(details, 42)
|
||||
got, err := UnmarshalPayload(wrapDetails(details))
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, uint32(42), got.InitiatorIndex)
|
||||
})
|
||||
|
||||
t.Run("initiator index varint overflow rejected", func(t *testing.T) {
|
||||
var details []byte
|
||||
details = protowire.AppendTag(details, fieldInitiatorIndex, protowire.VarintType)
|
||||
details = protowire.AppendVarint(details, math.MaxUint32+1)
|
||||
_, err := UnmarshalPayload(wrapDetails(details))
|
||||
assert.Error(t, err)
|
||||
})
|
||||
|
||||
t.Run("cert version varint overflow rejected", func(t *testing.T) {
|
||||
var details []byte
|
||||
details = protowire.AppendTag(details, fieldCertVersion, protowire.VarintType)
|
||||
details = protowire.AppendVarint(details, math.MaxUint32+1)
|
||||
_, err := UnmarshalPayload(wrapDetails(details))
|
||||
assert.Error(t, err)
|
||||
})
|
||||
|
||||
}
|
||||
|
||||
// FuzzPayload feeds arbitrary bytes through UnmarshalPayload to confirm it
|
||||
// never panics, and for any input that parses cleanly, that re-marshal +
|
||||
// re-parse is a fix-point. Inputs come from an authenticated peer (post-
|
||||
// noise-decrypt), so the threat model is "valid peer behaving arbitrarily,"
|
||||
// not "unauthenticated injection."
|
||||
func FuzzPayload(f *testing.F) {
|
||||
// Seed corpus with a handful of known-good shapes.
|
||||
f.Add(MarshalPayload(nil, Payload{}))
|
||||
f.Add(MarshalPayload(nil, Payload{Cert: []byte{1, 2, 3}, CertVersion: 2}))
|
||||
f.Add(MarshalPayload(nil, Payload{InitiatorIndex: 42, Time: 1}))
|
||||
f.Add(MarshalPayload(nil, Payload{
|
||||
Cert: []byte("seed-cert"),
|
||||
InitiatorIndex: 1,
|
||||
ResponderIndex: 2,
|
||||
Time: 3,
|
||||
CertVersion: 2,
|
||||
}))
|
||||
f.Add([]byte{})
|
||||
f.Add([]byte{0xff})
|
||||
|
||||
f.Fuzz(func(t *testing.T, data []byte) {
|
||||
p1, err := UnmarshalPayload(data)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
// For any input that parses, re-marshaling and re-parsing must
|
||||
// yield an equivalent Payload. This catches dispatch bugs (e.g.
|
||||
// emitting a field on marshal that we don't accept on parse) and
|
||||
// any non-idempotent parsing behavior.
|
||||
b2 := MarshalPayload(nil, p1)
|
||||
p2, err := UnmarshalPayload(b2)
|
||||
if err != nil {
|
||||
t.Fatalf("re-parse of self-marshaled payload failed: %v\nintermediate: %x\n", err, b2)
|
||||
}
|
||||
if !payloadsEqual(p1, p2) {
|
||||
t.Fatalf("re-marshal not idempotent\nfirst: %+v\nsecond: %+v", p1, p2)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func payloadsEqual(a, b Payload) bool {
|
||||
return bytes.Equal(a.Cert, b.Cert) &&
|
||||
a.InitiatorIndex == b.InitiatorIndex &&
|
||||
a.ResponderIndex == b.ResponderIndex &&
|
||||
a.Time == b.Time &&
|
||||
a.CertVersion == b.CertVersion
|
||||
}
|
||||
Reference in New Issue
Block a user