diff --git a/handshake/errors.go b/handshake/errors.go index bb8a5893..3bdcc947 100644 --- a/handshake/errors.go +++ b/handshake/errors.go @@ -13,6 +13,7 @@ var ( ErrUnknownSubtype = errors.New("unknown handshake subtype") ErrMissingContent = errors.New("expected handshake content but message was empty") ErrUnexpectedContent = errors.New("received unexpected handshake content") + ErrInvalidRemoteIndex = errors.New("peer sent an invalid index in handshake payload") ErrIndexAllocation = errors.New("failed to allocate local index") ErrNoCredential = errors.New("no handshake credential available for cert version") ErrAsymmetricCipherKeys = errors.New("noise produced only one cipher key") diff --git a/handshake/machine.go b/handshake/machine.go index 737358dc..baf61589 100644 --- a/handshake/machine.go +++ b/handshake/machine.go @@ -312,11 +312,19 @@ func (m *Machine) processPayload(msg []byte, flags msgFlags) error { // Process payload if flags.expectsPayload { + var remoteIndex uint32 if m.result.Initiator { - m.result.RemoteIndex = payload.ResponderIndex + remoteIndex = payload.ResponderIndex } else { - m.result.RemoteIndex = payload.InitiatorIndex + remoteIndex = payload.InitiatorIndex } + // The payload presence check above can be satisfied by Time alone, so a payload + // could still carry a zero index here. We need to reject it. + if remoteIndex == 0 { + m.failed = true + return ErrInvalidRemoteIndex + } + m.result.RemoteIndex = remoteIndex m.result.HandshakeTime = payload.Time m.payloadSet = true } diff --git a/handshake/machine_test.go b/handshake/machine_test.go index 722a39e1..01c968ed 100644 --- a/handshake/machine_test.go +++ b/handshake/machine_test.go @@ -229,6 +229,24 @@ func TestMachineProcessPayload(t *testing.T) { require.ErrorIs(t, err, ErrUnexpectedContent) assert.True(t, m.Failed()) }) + + t.Run("zero initiator index on responder is fatal", func(t *testing.T) { + m := newTestMachine(t, cs, v, false, 100) + bytes := MarshalPayload(nil, Payload{InitiatorIndex: 0, Time: 1}) + err := m.processPayload(bytes, msgFlags{expectsPayload: true}) + require.ErrorIs(t, err, ErrInvalidRemoteIndex) + assert.True(t, m.Failed()) + assert.Zero(t, m.result.RemoteIndex) + }) + + t.Run("zero responder index on initiator is fatal", func(t *testing.T) { + m := newTestMachine(t, cs, v, true, 100) + bytes := MarshalPayload(nil, Payload{InitiatorIndex: 100, ResponderIndex: 0, Time: 1}) + err := m.processPayload(bytes, msgFlags{expectsPayload: true}) + require.ErrorIs(t, err, ErrInvalidRemoteIndex) + assert.True(t, m.Failed()) + assert.Zero(t, m.result.RemoteIndex) + }) } // TestMachineRequireComplete checks the fail-on-incomplete-handshake path