Support for ipv6 in the overlay with v2 certificates

---------

Co-authored-by: Jack Doan <jackdoan@rivian.com>
This commit is contained in:
Nate Brown 2024-10-23 22:02:10 -05:00
parent 3e6c75573f
commit f2c32421c4
86 changed files with 5747 additions and 3335 deletions

View File

@ -196,7 +196,7 @@ bench-cpu-long:
go test -bench=. -benchtime=60s -cpuprofile=cpu.pprof
go tool pprof go-audit.test cpu.pprof
proto: nebula.pb.go cert/cert.pb.go
proto: nebula.pb.go cert/cert_v1.pb.go
nebula.pb.go: nebula.proto .FORCE
go build github.com/gogo/protobuf/protoc-gen-gogofaster

View File

@ -21,7 +21,11 @@ type calculatedRemote struct {
port uint32
}
func newCalculatedRemote(maskCidr netip.Prefix, port int) (*calculatedRemote, error) {
func newCalculatedRemote(cidr, maskCidr netip.Prefix, port int) (*calculatedRemote, error) {
if maskCidr.Addr().BitLen() != cidr.Addr().BitLen() {
return nil, fmt.Errorf("invalid mask: %s for cidr: %s", maskCidr, cidr)
}
masked := maskCidr.Masked()
if port < 0 || port > math.MaxUint16 {
return nil, fmt.Errorf("invalid port: %d", port)
@ -38,32 +42,38 @@ func (c *calculatedRemote) String() string {
return fmt.Sprintf("CalculatedRemote(mask=%v port=%d)", c.ipNet, c.port)
}
func (c *calculatedRemote) Apply(ip netip.Addr) *Ip4AndPort {
// Combine the masked bytes of the "mask" IP with the unmasked bytes
// of the overlay IP
if c.ipNet.Addr().Is4() {
return c.apply4(ip)
}
return c.apply6(ip)
}
func (c *calculatedRemote) apply4(ip netip.Addr) *Ip4AndPort {
//TODO: IPV6-WORK this can be less crappy
func (c *calculatedRemote) ApplyV4(addr netip.Addr) *V4AddrPort {
// Combine the masked bytes of the "mask" IP with the unmasked bytes of the overlay IP
maskb := net.CIDRMask(c.mask.Bits(), c.mask.Addr().BitLen())
mask := binary.BigEndian.Uint32(maskb[:])
b := c.mask.Addr().As4()
maskIp := binary.BigEndian.Uint32(b[:])
maskAddr := binary.BigEndian.Uint32(b[:])
b = ip.As4()
intIp := binary.BigEndian.Uint32(b[:])
b = addr.As4()
intAddr := binary.BigEndian.Uint32(b[:])
return &Ip4AndPort{(maskIp & mask) | (intIp & ^mask), c.port}
return &V4AddrPort{(maskAddr & mask) | (intAddr & ^mask), c.port}
}
func (c *calculatedRemote) apply6(ip netip.Addr) *Ip4AndPort {
//TODO: IPV6-WORK
panic("Can not calculate ipv6 remote addresses")
func (c *calculatedRemote) ApplyV6(addr netip.Addr) *V6AddrPort {
mask := net.CIDRMask(c.mask.Bits(), c.mask.Addr().BitLen())
maskAddr := c.mask.Addr().As16()
calcAddr := addr.As16()
ap := V6AddrPort{Port: c.port}
maskb := binary.BigEndian.Uint64(mask[:8])
maskAddrb := binary.BigEndian.Uint64(maskAddr[:8])
calcAddrb := binary.BigEndian.Uint64(calcAddr[:8])
ap.Hi = (maskAddrb & maskb) | (calcAddrb & ^maskb)
maskb = binary.BigEndian.Uint64(mask[8:])
maskAddrb = binary.BigEndian.Uint64(maskAddr[8:])
calcAddrb = binary.BigEndian.Uint64(calcAddr[8:])
ap.Lo = (maskAddrb & maskb) | (calcAddrb & ^maskb)
return &ap
}
func NewCalculatedRemotesFromConfig(c *config.C, k string) (*bart.Table[[]*calculatedRemote], error) {
@ -89,8 +99,7 @@ func NewCalculatedRemotesFromConfig(c *config.C, k string) (*bart.Table[[]*calcu
return nil, fmt.Errorf("config `%s` has invalid CIDR: %s", k, rawCIDR)
}
//TODO: IPV6-WORK this does not verify that rawValue contains the same bits as cidr here
entry, err := newCalculatedRemotesListFromConfig(rawValue)
entry, err := newCalculatedRemotesListFromConfig(cidr, rawValue)
if err != nil {
return nil, fmt.Errorf("config '%s.%s': %w", k, rawCIDR, err)
}
@ -101,7 +110,7 @@ func NewCalculatedRemotesFromConfig(c *config.C, k string) (*bart.Table[[]*calcu
return calculatedRemotes, nil
}
func newCalculatedRemotesListFromConfig(raw any) ([]*calculatedRemote, error) {
func newCalculatedRemotesListFromConfig(cidr netip.Prefix, raw any) ([]*calculatedRemote, error) {
rawList, ok := raw.([]any)
if !ok {
return nil, fmt.Errorf("calculated_remotes entry has invalid type: %T", raw)
@ -109,7 +118,7 @@ func newCalculatedRemotesListFromConfig(raw any) ([]*calculatedRemote, error) {
var l []*calculatedRemote
for _, e := range rawList {
c, err := newCalculatedRemotesEntryFromConfig(e)
c, err := newCalculatedRemotesEntryFromConfig(cidr, e)
if err != nil {
return nil, fmt.Errorf("calculated_remotes entry: %w", err)
}
@ -119,7 +128,7 @@ func newCalculatedRemotesListFromConfig(raw any) ([]*calculatedRemote, error) {
return l, nil
}
func newCalculatedRemotesEntryFromConfig(raw any) (*calculatedRemote, error) {
func newCalculatedRemotesEntryFromConfig(cidr netip.Prefix, raw any) (*calculatedRemote, error) {
rawMap, ok := raw.(map[any]any)
if !ok {
return nil, fmt.Errorf("invalid type: %T", raw)
@ -155,5 +164,5 @@ func newCalculatedRemotesEntryFromConfig(raw any) (*calculatedRemote, error) {
return nil, fmt.Errorf("invalid port (type %T): %v", rawValue, rawValue)
}
return newCalculatedRemote(maskCidr, port)
return newCalculatedRemote(cidr, maskCidr, port)
}

View File

@ -9,10 +9,9 @@ import (
)
func TestCalculatedRemoteApply(t *testing.T) {
ipNet, err := netip.ParsePrefix("192.168.1.0/24")
require.NoError(t, err)
c, err := newCalculatedRemote(ipNet, 4242)
// Test v4 addresses
ipNet := netip.MustParsePrefix("192.168.1.0/24")
c, err := newCalculatedRemote(ipNet, ipNet, 4242)
require.NoError(t, err)
input, err := netip.ParseAddr("10.0.10.182")
@ -21,5 +20,62 @@ func TestCalculatedRemoteApply(t *testing.T) {
expected, err := netip.ParseAddr("192.168.1.182")
assert.NoError(t, err)
assert.Equal(t, NewIp4AndPortFromNetIP(expected, 4242), c.Apply(input))
assert.Equal(t, netAddrToProtoV4AddrPort(expected, 4242), c.ApplyV4(input))
// Test v6 addresses
ipNet = netip.MustParsePrefix("ffff:ffff:ffff:ffff::0/64")
c, err = newCalculatedRemote(ipNet, ipNet, 4242)
require.NoError(t, err)
input, err = netip.ParseAddr("beef:beef:beef:beef:beef:beef:beef:beef")
assert.NoError(t, err)
expected, err = netip.ParseAddr("ffff:ffff:ffff:ffff:beef:beef:beef:beef")
assert.NoError(t, err)
assert.Equal(t, netAddrToProtoV6AddrPort(expected, 4242), c.ApplyV6(input))
// Test v6 addresses part 2
ipNet = netip.MustParsePrefix("ffff:ffff:ffff:ffff:ffff::0/80")
c, err = newCalculatedRemote(ipNet, ipNet, 4242)
require.NoError(t, err)
input, err = netip.ParseAddr("beef:beef:beef:beef:beef:beef:beef:beef")
assert.NoError(t, err)
expected, err = netip.ParseAddr("ffff:ffff:ffff:ffff:ffff:beef:beef:beef")
assert.NoError(t, err)
assert.Equal(t, netAddrToProtoV6AddrPort(expected, 4242), c.ApplyV6(input))
// Test v6 addresses part 2
ipNet = netip.MustParsePrefix("ffff:ffff:ffff::0/48")
c, err = newCalculatedRemote(ipNet, ipNet, 4242)
require.NoError(t, err)
input, err = netip.ParseAddr("beef:beef:beef:beef:beef:beef:beef:beef")
assert.NoError(t, err)
expected, err = netip.ParseAddr("ffff:ffff:ffff:beef:beef:beef:beef:beef")
assert.NoError(t, err)
assert.Equal(t, netAddrToProtoV6AddrPort(expected, 4242), c.ApplyV6(input))
}
func Test_newCalculatedRemote(t *testing.T) {
c, err := newCalculatedRemote(netip.MustParsePrefix("1::1/128"), netip.MustParsePrefix("1.0.0.0/32"), 4242)
require.EqualError(t, err, "invalid mask: 1.0.0.0/32 for cidr: 1::1/128")
require.Nil(t, c)
c, err = newCalculatedRemote(netip.MustParsePrefix("1.0.0.0/32"), netip.MustParsePrefix("1::1/128"), 4242)
require.EqualError(t, err, "invalid mask: 1::1/128 for cidr: 1.0.0.0/32")
require.Nil(t, c)
c, err = newCalculatedRemote(netip.MustParsePrefix("1.0.0.0/32"), netip.MustParsePrefix("1.0.0.0/32"), 4242)
require.NoError(t, err)
require.NotNil(t, c)
c, err = newCalculatedRemote(netip.MustParsePrefix("1::1/128"), netip.MustParsePrefix("1::1/128"), 4242)
require.NoError(t, err)
require.NotNil(t, c)
}

View File

@ -2,14 +2,25 @@
This is a library for interacting with `nebula` style certificates and authorities.
A `protobuf` definition of the certificate format is also included
There are now 2 versions of `nebula` certificates:
### Compiling the protobuf definition
## v1
Make sure you have `protoc` installed.
This version is deprecated.
A `protobuf` definition of the certificate format is included at `cert_v1.proto`
To compile the definition you will need `protoc` installed.
To compile for `go` with the same version of protobuf specified in go.mod:
```bash
make
make proto
```
## v2
This is the latest version which uses asn.1 DER encoding. It can support ipv4 and ipv6 and tolerate
future certificate changes better than v1.
`cert_v2.asn1` defines the wire format and can be used to compile marshalers.

52
cert/asn1.go Normal file
View File

@ -0,0 +1,52 @@
package cert
import (
"golang.org/x/crypto/cryptobyte"
"golang.org/x/crypto/cryptobyte/asn1"
)
// readOptionalASN1Boolean reads an asn.1 boolean with a specific tag instead of a asn.1 tag wrapping a boolean with a value
// https://github.com/golang/go/issues/64811#issuecomment-1944446920
func readOptionalASN1Boolean(b *cryptobyte.String, out *bool, tag asn1.Tag, defaultValue bool) bool {
var present bool
var child cryptobyte.String
if !b.ReadOptionalASN1(&child, &present, tag) {
return false
}
if !present {
*out = defaultValue
return true
}
// Ensure we have 1 byte
if len(child) == 1 {
*out = child[0] > 0
return true
}
return false
}
// readOptionalASN1Byte reads an asn.1 uint8 with a specific tag instead of a asn.1 tag wrapping a uint8 with a value
// Similar issue as with readOptionalASN1Boolean
func readOptionalASN1Byte(b *cryptobyte.String, out *byte, tag asn1.Tag, defaultValue byte) bool {
var present bool
var child cryptobyte.String
if !b.ReadOptionalASN1(&child, &present, tag) {
return false
}
if !present {
*out = defaultValue
return true
}
// Ensure we have 1 byte
if len(child) == 1 {
*out = child[0]
return true
}
return false
}

View File

@ -63,31 +63,31 @@ IBNWYMep3ysx9zCgknfG5dKtwGTaqF++BWKDYdyl34KX
rootCA := certificateV1{
details: detailsV1{
Name: "nebula root ca",
name: "nebula root ca",
},
}
rootCA01 := certificateV1{
details: detailsV1{
Name: "nebula root ca 01",
name: "nebula root ca 01",
},
}
rootCAP256 := certificateV1{
details: detailsV1{
Name: "nebula P256 test",
name: "nebula P256 test",
},
}
p, err := NewCAPoolFromPEM([]byte(noNewLines))
assert.Nil(t, err)
assert.Equal(t, p.CAs[string("c9bfaf7ce8e84b2eeda2e27b469f4b9617bde192efd214b68891ecda6ed49522")].Certificate.Name(), rootCA.details.Name)
assert.Equal(t, p.CAs[string("5c9c3f23e7ee7fe97637cbd3a0a5b854154d1d9aaaf7b566a51f4a88f76b64cd")].Certificate.Name(), rootCA01.details.Name)
assert.Equal(t, p.CAs[string("c9bfaf7ce8e84b2eeda2e27b469f4b9617bde192efd214b68891ecda6ed49522")].Certificate.Name(), rootCA.details.name)
assert.Equal(t, p.CAs[string("5c9c3f23e7ee7fe97637cbd3a0a5b854154d1d9aaaf7b566a51f4a88f76b64cd")].Certificate.Name(), rootCA01.details.name)
pp, err := NewCAPoolFromPEM([]byte(withNewLines))
assert.Nil(t, err)
assert.Equal(t, pp.CAs[string("c9bfaf7ce8e84b2eeda2e27b469f4b9617bde192efd214b68891ecda6ed49522")].Certificate.Name(), rootCA.details.Name)
assert.Equal(t, pp.CAs[string("5c9c3f23e7ee7fe97637cbd3a0a5b854154d1d9aaaf7b566a51f4a88f76b64cd")].Certificate.Name(), rootCA01.details.Name)
assert.Equal(t, pp.CAs[string("c9bfaf7ce8e84b2eeda2e27b469f4b9617bde192efd214b68891ecda6ed49522")].Certificate.Name(), rootCA.details.name)
assert.Equal(t, pp.CAs[string("5c9c3f23e7ee7fe97637cbd3a0a5b854154d1d9aaaf7b566a51f4a88f76b64cd")].Certificate.Name(), rootCA01.details.name)
// expired cert, no valid certs
ppp, err := NewCAPoolFromPEM([]byte(expired))
@ -97,13 +97,13 @@ IBNWYMep3ysx9zCgknfG5dKtwGTaqF++BWKDYdyl34KX
// expired cert, with valid certs
pppp, err := NewCAPoolFromPEM(append([]byte(expired), noNewLines...))
assert.Equal(t, ErrExpired, err)
assert.Equal(t, pppp.CAs[string("c9bfaf7ce8e84b2eeda2e27b469f4b9617bde192efd214b68891ecda6ed49522")].Certificate.Name(), rootCA.details.Name)
assert.Equal(t, pppp.CAs[string("5c9c3f23e7ee7fe97637cbd3a0a5b854154d1d9aaaf7b566a51f4a88f76b64cd")].Certificate.Name(), rootCA01.details.Name)
assert.Equal(t, pppp.CAs[string("c9bfaf7ce8e84b2eeda2e27b469f4b9617bde192efd214b68891ecda6ed49522")].Certificate.Name(), rootCA.details.name)
assert.Equal(t, pppp.CAs[string("5c9c3f23e7ee7fe97637cbd3a0a5b854154d1d9aaaf7b566a51f4a88f76b64cd")].Certificate.Name(), rootCA01.details.name)
assert.Equal(t, pppp.CAs[string("152070be6bb19bc9e3bde4c2f0e7d8f4ff5448b4c9856b8eccb314fade0229b0")].Certificate.Name(), "expired")
assert.Equal(t, len(pppp.CAs), 3)
ppppp, err := NewCAPoolFromPEM([]byte(p256))
assert.Nil(t, err)
assert.Equal(t, ppppp.CAs[string("a7938893ec8c4ef769b06d7f425e5e46f7a7f5ffa49c3bcf4a86b608caba9159")].Certificate.Name(), rootCAP256.details.Name)
assert.Equal(t, ppppp.CAs[string("a7938893ec8c4ef769b06d7f425e5e46f7a7f5ffa49c3bcf4a86b608caba9159")].Certificate.Name(), rootCAP256.details.name)
assert.Equal(t, len(ppppp.CAs), 1)
}

View File

@ -1,15 +1,17 @@
package cert
import (
"fmt"
"net/netip"
"time"
)
type Version int
type Version uint8
const (
Version1 Version = 1
Version2 Version = 2
VersionPre1 Version = 0
Version1 Version = 1
Version2 Version = 2
)
type Certificate interface {
@ -107,23 +109,56 @@ type CachedCertificate struct {
signerFingerprint string
}
// UnmarshalCertificate will attempt to unmarshal a wire protocol level certificate.
func UnmarshalCertificate(b []byte) (Certificate, error) {
c, err := unmarshalCertificateV1(b, true)
if err != nil {
return nil, err
}
return c, nil
func (cc *CachedCertificate) String() string {
return cc.Certificate.String()
}
// UnmarshalCertificateFromHandshake will attempt to unmarshal a certificate received in a handshake.
// RecombineAndValidate will attempt to unmarshal a certificate received in a handshake.
// Handshakes save space by placing the peers public key in a different part of the packet, we have to
// reassemble the actual certificate structure with that in mind.
func UnmarshalCertificateFromHandshake(b []byte, publicKey []byte) (Certificate, error) {
c, err := unmarshalCertificateV1(b, false)
func RecombineAndValidate(v Version, rawCertBytes, publicKey []byte, curve Curve, caPool *CAPool) (*CachedCertificate, error) {
if publicKey == nil {
return nil, ErrNoPeerStaticKey
}
if rawCertBytes == nil {
return nil, ErrNoPayload
}
c, err := unmarshalCertificateFromHandshake(v, rawCertBytes, publicKey, curve)
if err != nil {
return nil, fmt.Errorf("error unmarshaling cert: %w", err)
}
cc, err := caPool.VerifyCertificate(time.Now(), c)
if err != nil {
return nil, fmt.Errorf("certificate validation failed: %w", err)
}
return cc, nil
}
func unmarshalCertificateFromHandshake(v Version, b []byte, publicKey []byte, curve Curve) (Certificate, error) {
var c Certificate
var err error
switch v {
case VersionPre1, Version1:
c, err = unmarshalCertificateV1(b, publicKey)
case Version2:
c, err = unmarshalCertificateV2(b, publicKey, curve)
default:
//TODO: make a static var
return nil, fmt.Errorf("unknown certificate version %d", v)
}
if err != nil {
return nil, err
}
c.details.PublicKey = publicKey
if c.Curve() != curve {
return nil, fmt.Errorf("certificate curve %s does not match expected %s", c.Curve().String(), curve.String())
}
return c, nil
}

View File

@ -24,21 +24,21 @@ func TestMarshalingNebulaCertificate(t *testing.T) {
nc := certificateV1{
details: detailsV1{
Name: "testing",
Ips: []netip.Prefix{
name: "testing",
networks: []netip.Prefix{
mustParsePrefixUnmapped("10.1.1.1/24"),
mustParsePrefixUnmapped("10.1.1.2/16"),
},
Subnets: []netip.Prefix{
unsafeNetworks: []netip.Prefix{
mustParsePrefixUnmapped("9.1.1.2/24"),
mustParsePrefixUnmapped("9.1.1.3/16"),
},
Groups: []string{"test-group1", "test-group2", "test-group3"},
NotBefore: before,
NotAfter: after,
PublicKey: pubKey,
IsCA: false,
Issuer: "1234567890abcedfghij1234567890ab",
groups: []string{"test-group1", "test-group2", "test-group3"},
notBefore: before,
notAfter: after,
publicKey: pubKey,
isCA: false,
issuer: "1234567890abcedfghij1234567890ab",
},
signature: []byte("1234567890abcedfghij1234567890ab"),
}
@ -47,20 +47,20 @@ func TestMarshalingNebulaCertificate(t *testing.T) {
assert.Nil(t, err)
//t.Log("Cert size:", len(b))
nc2, err := unmarshalCertificateV1(b, true)
nc2, err := unmarshalCertificateV1(b, nil)
assert.Nil(t, err)
assert.Equal(t, nc.signature, nc2.Signature())
assert.Equal(t, nc.details.Name, nc2.Name())
assert.Equal(t, nc.details.NotBefore, nc2.NotBefore())
assert.Equal(t, nc.details.NotAfter, nc2.NotAfter())
assert.Equal(t, nc.details.PublicKey, nc2.PublicKey())
assert.Equal(t, nc.details.IsCA, nc2.IsCA())
assert.Equal(t, nc.details.name, nc2.Name())
assert.Equal(t, nc.details.notBefore, nc2.NotBefore())
assert.Equal(t, nc.details.notAfter, nc2.NotAfter())
assert.Equal(t, nc.details.publicKey, nc2.PublicKey())
assert.Equal(t, nc.details.isCA, nc2.IsCA())
assert.Equal(t, nc.details.Ips, nc2.Networks())
assert.Equal(t, nc.details.Subnets, nc2.UnsafeNetworks())
assert.Equal(t, nc.details.networks, nc2.Networks())
assert.Equal(t, nc.details.unsafeNetworks, nc2.UnsafeNetworks())
assert.Equal(t, nc.details.Groups, nc2.Groups())
assert.Equal(t, nc.details.groups, nc2.Groups())
}
//func TestNebulaCertificate_Sign(t *testing.T) {
@ -150,8 +150,8 @@ func TestMarshalingNebulaCertificate(t *testing.T) {
func TestNebulaCertificate_Expired(t *testing.T) {
nc := certificateV1{
details: detailsV1{
NotBefore: time.Now().Add(time.Second * -60).Round(time.Second),
NotAfter: time.Now().Add(time.Second * 60).Round(time.Second),
notBefore: time.Now().Add(time.Second * -60).Round(time.Second),
notAfter: time.Now().Add(time.Second * 60).Round(time.Second),
},
}
@ -166,21 +166,21 @@ func TestNebulaCertificate_MarshalJSON(t *testing.T) {
nc := certificateV1{
details: detailsV1{
Name: "testing",
Ips: []netip.Prefix{
name: "testing",
networks: []netip.Prefix{
mustParsePrefixUnmapped("10.1.1.1/24"),
mustParsePrefixUnmapped("10.1.1.2/16"),
},
Subnets: []netip.Prefix{
unsafeNetworks: []netip.Prefix{
mustParsePrefixUnmapped("9.1.1.2/24"),
mustParsePrefixUnmapped("9.1.1.3/16"),
},
Groups: []string{"test-group1", "test-group2", "test-group3"},
NotBefore: time.Date(1, 0, 0, 1, 0, 0, 0, time.UTC),
NotAfter: time.Date(1, 0, 0, 2, 0, 0, 0, time.UTC),
PublicKey: pubKey,
IsCA: false,
Issuer: "1234567890abcedfghij1234567890ab",
groups: []string{"test-group1", "test-group2", "test-group3"},
notBefore: time.Date(1, 0, 0, 1, 0, 0, 0, time.UTC),
notAfter: time.Date(1, 0, 0, 2, 0, 0, 0, time.UTC),
publicKey: pubKey,
isCA: false,
issuer: "1234567890abcedfghij1234567890ab",
},
signature: []byte("1234567890abcedfghij1234567890ab"),
}
@ -189,7 +189,7 @@ func TestNebulaCertificate_MarshalJSON(t *testing.T) {
assert.Nil(t, err)
assert.Equal(
t,
"{\"details\":{\"curve\":\"CURVE25519\",\"groups\":[\"test-group1\",\"test-group2\",\"test-group3\"],\"ips\":[\"10.1.1.1/24\",\"10.1.1.2/16\"],\"isCa\":false,\"issuer\":\"1234567890abcedfghij1234567890ab\",\"name\":\"testing\",\"notAfter\":\"0000-11-30T02:00:00Z\",\"notBefore\":\"0000-11-30T01:00:00Z\",\"publicKey\":\"313233343536373839306162636564666768696a313233343536373839306162\",\"subnets\":[\"9.1.1.2/24\",\"9.1.1.3/16\"]},\"fingerprint\":\"3944c53d4267a229295b56cb2d27d459164c010ac97d655063ba421e0670f4ba\",\"signature\":\"313233343536373839306162636564666768696a313233343536373839306162\"}",
"{\"details\":{\"curve\":\"CURVE25519\",\"groups\":[\"test-group1\",\"test-group2\",\"test-group3\"],\"isCa\":false,\"issuer\":\"1234567890abcedfghij1234567890ab\",\"name\":\"testing\",\"networks\":[\"10.1.1.1/24\",\"10.1.1.2/16\"],\"notAfter\":\"0000-11-30T02:00:00Z\",\"notBefore\":\"0000-11-30T01:00:00Z\",\"publicKey\":\"313233343536373839306162636564666768696a313233343536373839306162\",\"unsafeNetworks\":[\"9.1.1.2/24\",\"9.1.1.3/16\"]},\"fingerprint\":\"3944c53d4267a229295b56cb2d27d459164c010ac97d655063ba421e0670f4ba\",\"signature\":\"313233343536373839306162636564666768696a313233343536373839306162\",\"version\":1}",
string(b),
)
}
@ -526,7 +526,7 @@ func TestNebulaCertificate_Copy(t *testing.T) {
func TestUnmarshalNebulaCertificate(t *testing.T) {
// Test that we don't panic with an invalid certificate (#332)
data := []byte("\x98\x00\x00")
_, err := unmarshalCertificateV1(data, true)
_, err := unmarshalCertificateV1(data, nil)
assert.EqualError(t, err, "encoded Details was nil")
}

View File

@ -6,19 +6,16 @@ import (
"crypto/ecdsa"
"crypto/ed25519"
"crypto/elliptic"
"crypto/rand"
"crypto/sha256"
"encoding/binary"
"encoding/hex"
"encoding/json"
"encoding/pem"
"fmt"
"math/big"
"net"
"net/netip"
"time"
"github.com/slackhq/nebula/pkclient"
"golang.org/x/crypto/curve25519"
"google.golang.org/protobuf/proto"
)
@ -31,71 +28,71 @@ type certificateV1 struct {
}
type detailsV1 struct {
Name string
Ips []netip.Prefix
Subnets []netip.Prefix
Groups []string
NotBefore time.Time
NotAfter time.Time
PublicKey []byte
IsCA bool
Issuer string
name string
networks []netip.Prefix
unsafeNetworks []netip.Prefix
groups []string
notBefore time.Time
notAfter time.Time
publicKey []byte
isCA bool
issuer string
Curve Curve
curve Curve
}
type m map[string]interface{}
func (nc *certificateV1) Version() Version {
func (c *certificateV1) Version() Version {
return Version1
}
func (nc *certificateV1) Curve() Curve {
return nc.details.Curve
func (c *certificateV1) Curve() Curve {
return c.details.curve
}
func (nc *certificateV1) Groups() []string {
return nc.details.Groups
func (c *certificateV1) Groups() []string {
return c.details.groups
}
func (nc *certificateV1) IsCA() bool {
return nc.details.IsCA
func (c *certificateV1) IsCA() bool {
return c.details.isCA
}
func (nc *certificateV1) Issuer() string {
return nc.details.Issuer
func (c *certificateV1) Issuer() string {
return c.details.issuer
}
func (nc *certificateV1) Name() string {
return nc.details.Name
func (c *certificateV1) Name() string {
return c.details.name
}
func (nc *certificateV1) Networks() []netip.Prefix {
return nc.details.Ips
func (c *certificateV1) Networks() []netip.Prefix {
return c.details.networks
}
func (nc *certificateV1) NotAfter() time.Time {
return nc.details.NotAfter
func (c *certificateV1) NotAfter() time.Time {
return c.details.notAfter
}
func (nc *certificateV1) NotBefore() time.Time {
return nc.details.NotBefore
func (c *certificateV1) NotBefore() time.Time {
return c.details.notBefore
}
func (nc *certificateV1) PublicKey() []byte {
return nc.details.PublicKey
func (c *certificateV1) PublicKey() []byte {
return c.details.publicKey
}
func (nc *certificateV1) Signature() []byte {
return nc.signature
func (c *certificateV1) Signature() []byte {
return c.signature
}
func (nc *certificateV1) UnsafeNetworks() []netip.Prefix {
return nc.details.Subnets
func (c *certificateV1) UnsafeNetworks() []netip.Prefix {
return c.details.unsafeNetworks
}
func (nc *certificateV1) Fingerprint() (string, error) {
b, err := nc.Marshal()
func (c *certificateV1) Fingerprint() (string, error) {
b, err := c.Marshal()
if err != nil {
return "", err
}
@ -104,33 +101,33 @@ func (nc *certificateV1) Fingerprint() (string, error) {
return hex.EncodeToString(sum[:]), nil
}
func (nc *certificateV1) CheckSignature(key []byte) bool {
b, err := proto.Marshal(nc.getRawDetails())
func (c *certificateV1) CheckSignature(key []byte) bool {
b, err := proto.Marshal(c.getRawDetails())
if err != nil {
return false
}
switch nc.details.Curve {
switch c.details.curve {
case Curve_CURVE25519:
return ed25519.Verify(key, b, nc.signature)
return ed25519.Verify(key, b, c.signature)
case Curve_P256:
x, y := elliptic.Unmarshal(elliptic.P256(), key)
pubKey := &ecdsa.PublicKey{Curve: elliptic.P256(), X: x, Y: y}
hashed := sha256.Sum256(b)
return ecdsa.VerifyASN1(pubKey, hashed[:], nc.signature)
return ecdsa.VerifyASN1(pubKey, hashed[:], c.signature)
default:
return false
}
}
func (nc *certificateV1) Expired(t time.Time) bool {
return nc.details.NotBefore.After(t) || nc.details.NotAfter.Before(t)
func (c *certificateV1) Expired(t time.Time) bool {
return c.details.notBefore.After(t) || c.details.notAfter.Before(t)
}
func (nc *certificateV1) VerifyPrivateKey(curve Curve, key []byte) error {
if curve != nc.details.Curve {
func (c *certificateV1) VerifyPrivateKey(curve Curve, key []byte) error {
if curve != c.details.curve {
return fmt.Errorf("curve in cert and private key supplied don't match")
}
if nc.details.IsCA {
if c.details.isCA {
switch curve {
case Curve_CURVE25519:
// the call to PublicKey below will panic slice bounds out of range otherwise
@ -138,7 +135,7 @@ func (nc *certificateV1) VerifyPrivateKey(curve Curve, key []byte) error {
return fmt.Errorf("key was not 64 bytes, is invalid ed25519 private key")
}
if !ed25519.PublicKey(nc.details.PublicKey).Equal(ed25519.PrivateKey(key).Public()) {
if !ed25519.PublicKey(c.details.publicKey).Equal(ed25519.PrivateKey(key).Public()) {
return fmt.Errorf("public key in cert and private key supplied don't match")
}
case Curve_P256:
@ -147,7 +144,7 @@ func (nc *certificateV1) VerifyPrivateKey(curve Curve, key []byte) error {
return fmt.Errorf("cannot parse private key as P256: %w", err)
}
pub := privkey.PublicKey().Bytes()
if !bytes.Equal(pub, nc.details.PublicKey) {
if !bytes.Equal(pub, c.details.publicKey) {
return fmt.Errorf("public key in cert and private key supplied don't match")
}
default:
@ -173,7 +170,7 @@ func (nc *certificateV1) VerifyPrivateKey(curve Curve, key []byte) error {
default:
return fmt.Errorf("invalid curve: %s", curve)
}
if !bytes.Equal(pub, nc.details.PublicKey) {
if !bytes.Equal(pub, c.details.publicKey) {
return fmt.Errorf("public key in cert and private key supplied don't match")
}
@ -181,173 +178,155 @@ func (nc *certificateV1) VerifyPrivateKey(curve Curve, key []byte) error {
}
// getRawDetails marshals the raw details into protobuf ready struct
func (nc *certificateV1) getRawDetails() *RawNebulaCertificateDetails {
func (c *certificateV1) getRawDetails() *RawNebulaCertificateDetails {
rd := &RawNebulaCertificateDetails{
Name: nc.details.Name,
Groups: nc.details.Groups,
NotBefore: nc.details.NotBefore.Unix(),
NotAfter: nc.details.NotAfter.Unix(),
PublicKey: make([]byte, len(nc.details.PublicKey)),
IsCA: nc.details.IsCA,
Curve: nc.details.Curve,
Name: c.details.name,
Groups: c.details.groups,
NotBefore: c.details.notBefore.Unix(),
NotAfter: c.details.notAfter.Unix(),
PublicKey: make([]byte, len(c.details.publicKey)),
IsCA: c.details.isCA,
Curve: c.details.curve,
}
for _, ipNet := range nc.details.Ips {
for _, ipNet := range c.details.networks {
mask := net.CIDRMask(ipNet.Bits(), ipNet.Addr().BitLen())
rd.Ips = append(rd.Ips, addr2int(ipNet.Addr()), ip2int(mask))
}
for _, ipNet := range nc.details.Subnets {
for _, ipNet := range c.details.unsafeNetworks {
mask := net.CIDRMask(ipNet.Bits(), ipNet.Addr().BitLen())
rd.Subnets = append(rd.Subnets, addr2int(ipNet.Addr()), ip2int(mask))
}
copy(rd.PublicKey, nc.details.PublicKey[:])
copy(rd.PublicKey, c.details.publicKey[:])
// I know, this is terrible
rd.Issuer, _ = hex.DecodeString(nc.details.Issuer)
rd.Issuer, _ = hex.DecodeString(c.details.issuer)
return rd
}
func (nc *certificateV1) String() string {
if nc == nil {
return "Certificate {}\n"
func (c *certificateV1) String() string {
b, err := json.MarshalIndent(c.marshalJSON(), "", "\t")
if err != nil {
return "<error marshalling certificate>"
}
s := "NebulaCertificate {\n"
s += "\tDetails {\n"
s += fmt.Sprintf("\t\tName: %v\n", nc.details.Name)
if len(nc.details.Ips) > 0 {
s += "\t\tIps: [\n"
for _, ip := range nc.details.Ips {
s += fmt.Sprintf("\t\t\t%v\n", ip.String())
}
s += "\t\t]\n"
} else {
s += "\t\tIps: []\n"
}
if len(nc.details.Subnets) > 0 {
s += "\t\tSubnets: [\n"
for _, ip := range nc.details.Subnets {
s += fmt.Sprintf("\t\t\t%v\n", ip.String())
}
s += "\t\t]\n"
} else {
s += "\t\tSubnets: []\n"
}
if len(nc.details.Groups) > 0 {
s += "\t\tGroups: [\n"
for _, g := range nc.details.Groups {
s += fmt.Sprintf("\t\t\t\"%v\"\n", g)
}
s += "\t\t]\n"
} else {
s += "\t\tGroups: []\n"
}
s += fmt.Sprintf("\t\tNot before: %v\n", nc.details.NotBefore)
s += fmt.Sprintf("\t\tNot After: %v\n", nc.details.NotAfter)
s += fmt.Sprintf("\t\tIs CA: %v\n", nc.details.IsCA)
s += fmt.Sprintf("\t\tIssuer: %s\n", nc.details.Issuer)
s += fmt.Sprintf("\t\tPublic key: %x\n", nc.details.PublicKey)
s += fmt.Sprintf("\t\tCurve: %s\n", nc.details.Curve)
s += "\t}\n"
fp, err := nc.Fingerprint()
if err == nil {
s += fmt.Sprintf("\tFingerprint: %s\n", fp)
}
s += fmt.Sprintf("\tSignature: %x\n", nc.Signature())
s += "}"
return s
return string(b)
}
func (nc *certificateV1) MarshalForHandshakes() ([]byte, error) {
pubKey := nc.details.PublicKey
nc.details.PublicKey = nil
rawCertNoKey, err := nc.Marshal()
func (c *certificateV1) MarshalForHandshakes() ([]byte, error) {
pubKey := c.details.publicKey
c.details.publicKey = nil
rawCertNoKey, err := c.Marshal()
if err != nil {
return nil, err
}
nc.details.PublicKey = pubKey
c.details.publicKey = pubKey
return rawCertNoKey, nil
}
func (nc *certificateV1) Marshal() ([]byte, error) {
func (c *certificateV1) Marshal() ([]byte, error) {
rc := RawNebulaCertificate{
Details: nc.getRawDetails(),
Signature: nc.signature,
Details: c.getRawDetails(),
Signature: c.signature,
}
return proto.Marshal(&rc)
}
func (nc *certificateV1) MarshalPEM() ([]byte, error) {
b, err := nc.Marshal()
func (c *certificateV1) MarshalPEM() ([]byte, error) {
b, err := c.Marshal()
if err != nil {
return nil, err
}
return pem.EncodeToMemory(&pem.Block{Type: CertificateBanner, Bytes: b}), nil
}
func (nc *certificateV1) MarshalJSON() ([]byte, error) {
fp, _ := nc.Fingerprint()
jc := m{
"details": m{
"name": nc.details.Name,
"ips": nc.details.Ips,
"subnets": nc.details.Subnets,
"groups": nc.details.Groups,
"notBefore": nc.details.NotBefore,
"notAfter": nc.details.NotAfter,
"publicKey": fmt.Sprintf("%x", nc.details.PublicKey),
"isCa": nc.details.IsCA,
"issuer": nc.details.Issuer,
"curve": nc.details.Curve.String(),
},
"fingerprint": fp,
"signature": fmt.Sprintf("%x", nc.Signature()),
}
return json.Marshal(jc)
func (c *certificateV1) MarshalJSON() ([]byte, error) {
return json.Marshal(c.marshalJSON())
}
func (nc *certificateV1) Copy() Certificate {
c := &certificateV1{
details: detailsV1{
Name: nc.details.Name,
Groups: make([]string, len(nc.details.Groups)),
Ips: make([]netip.Prefix, len(nc.details.Ips)),
Subnets: make([]netip.Prefix, len(nc.details.Subnets)),
NotBefore: nc.details.NotBefore,
NotAfter: nc.details.NotAfter,
PublicKey: make([]byte, len(nc.details.PublicKey)),
IsCA: nc.details.IsCA,
Issuer: nc.details.Issuer,
func (c *certificateV1) marshalJSON() m {
fp, _ := c.Fingerprint()
return m{
"version": Version1,
"details": m{
"name": c.details.name,
"networks": c.details.networks,
"unsafeNetworks": c.details.unsafeNetworks,
"groups": c.details.groups,
"notBefore": c.details.notBefore,
"notAfter": c.details.notAfter,
"publicKey": fmt.Sprintf("%x", c.details.publicKey),
"isCa": c.details.isCA,
"issuer": c.details.issuer,
"curve": c.details.curve.String(),
},
signature: make([]byte, len(nc.signature)),
"fingerprint": fp,
"signature": fmt.Sprintf("%x", c.Signature()),
}
}
func (c *certificateV1) Copy() Certificate {
nc := &certificateV1{
details: detailsV1{
name: c.details.name,
groups: make([]string, len(c.details.groups)),
networks: make([]netip.Prefix, len(c.details.networks)),
unsafeNetworks: make([]netip.Prefix, len(c.details.unsafeNetworks)),
notBefore: c.details.notBefore,
notAfter: c.details.notAfter,
publicKey: make([]byte, len(c.details.publicKey)),
isCA: c.details.isCA,
issuer: c.details.issuer,
curve: c.details.curve,
},
signature: make([]byte, len(c.signature)),
}
copy(c.signature, nc.signature)
copy(c.details.Groups, nc.details.Groups)
copy(c.details.PublicKey, nc.details.PublicKey)
copy(nc.signature, c.signature)
copy(nc.details.groups, c.details.groups)
copy(nc.details.publicKey, c.details.publicKey)
copy(nc.details.networks, c.details.networks)
copy(nc.details.unsafeNetworks, c.details.unsafeNetworks)
for i, p := range nc.details.Ips {
c.details.Ips[i] = p
return nc
}
func (c *certificateV1) fromTBSCertificate(t *TBSCertificate) error {
c.details = detailsV1{
name: t.Name,
networks: t.Networks,
unsafeNetworks: t.UnsafeNetworks,
groups: t.Groups,
notBefore: t.NotBefore,
notAfter: t.NotAfter,
publicKey: t.PublicKey,
isCA: t.IsCA,
curve: t.Curve,
issuer: t.issuer,
}
for i, p := range nc.details.Subnets {
c.details.Subnets[i] = p
}
return nil
}
return c
func (c *certificateV1) marshalForSigning() ([]byte, error) {
b, err := proto.Marshal(c.getRawDetails())
if err != nil {
return nil, err
}
return b, nil
}
func (c *certificateV1) setSignature(b []byte) error {
c.signature = b
return nil
}
// unmarshalCertificateV1 will unmarshal a protobuf byte representation of a nebula cert
func unmarshalCertificateV1(b []byte, assertPublicKey bool) (*certificateV1, error) {
// if the publicKey is provided here then it is not required to be present in `b`
func unmarshalCertificateV1(b []byte, publicKey []byte) (*certificateV1, error) {
if len(b) == 0 {
return nil, fmt.Errorf("nil byte array")
}
@ -371,27 +350,28 @@ func unmarshalCertificateV1(b []byte, assertPublicKey bool) (*certificateV1, err
nc := certificateV1{
details: detailsV1{
Name: rc.Details.Name,
Groups: make([]string, len(rc.Details.Groups)),
Ips: make([]netip.Prefix, len(rc.Details.Ips)/2),
Subnets: make([]netip.Prefix, len(rc.Details.Subnets)/2),
NotBefore: time.Unix(rc.Details.NotBefore, 0),
NotAfter: time.Unix(rc.Details.NotAfter, 0),
PublicKey: make([]byte, len(rc.Details.PublicKey)),
IsCA: rc.Details.IsCA,
Curve: rc.Details.Curve,
name: rc.Details.Name,
groups: make([]string, len(rc.Details.Groups)),
networks: make([]netip.Prefix, len(rc.Details.Ips)/2),
unsafeNetworks: make([]netip.Prefix, len(rc.Details.Subnets)/2),
notBefore: time.Unix(rc.Details.NotBefore, 0),
notAfter: time.Unix(rc.Details.NotAfter, 0),
publicKey: make([]byte, len(rc.Details.PublicKey)),
isCA: rc.Details.IsCA,
curve: rc.Details.Curve,
},
signature: make([]byte, len(rc.Signature)),
}
copy(nc.signature, rc.Signature)
copy(nc.details.Groups, rc.Details.Groups)
nc.details.Issuer = hex.EncodeToString(rc.Details.Issuer)
copy(nc.details.groups, rc.Details.Groups)
nc.details.issuer = hex.EncodeToString(rc.Details.Issuer)
if len(rc.Details.PublicKey) < publicKeyLen && assertPublicKey {
return nil, fmt.Errorf("public key was fewer than 32 bytes; %v", len(rc.Details.PublicKey))
if len(publicKey) > 0 {
nc.details.publicKey = publicKey
}
copy(nc.details.PublicKey, rc.Details.PublicKey)
copy(nc.details.publicKey, rc.Details.PublicKey)
var ip netip.Addr
for i, rawIp := range rc.Details.Ips {
@ -399,7 +379,7 @@ func unmarshalCertificateV1(b []byte, assertPublicKey bool) (*certificateV1, err
ip = int2addr(rawIp)
} else {
ones, _ := net.IPMask(int2ip(rawIp)).Size()
nc.details.Ips[i/2] = netip.PrefixFrom(ip, ones)
nc.details.networks[i/2] = netip.PrefixFrom(ip, ones)
}
}
@ -408,69 +388,15 @@ func unmarshalCertificateV1(b []byte, assertPublicKey bool) (*certificateV1, err
ip = int2addr(rawIp)
} else {
ones, _ := net.IPMask(int2ip(rawIp)).Size()
nc.details.Subnets[i/2] = netip.PrefixFrom(ip, ones)
nc.details.unsafeNetworks[i/2] = netip.PrefixFrom(ip, ones)
}
}
//do not sort the subnets field for V1 certs
return &nc, nil
}
func signV1(t *TBSCertificate, curve Curve, key []byte, client *pkclient.PKClient) (*certificateV1, error) {
c := &certificateV1{
details: detailsV1{
Name: t.Name,
Ips: t.Networks,
Subnets: t.UnsafeNetworks,
Groups: t.Groups,
NotBefore: t.NotBefore,
NotAfter: t.NotAfter,
PublicKey: t.PublicKey,
IsCA: t.IsCA,
Curve: t.Curve,
Issuer: t.issuer,
},
}
b, err := proto.Marshal(c.getRawDetails())
if err != nil {
return nil, err
}
var sig []byte
switch curve {
case Curve_CURVE25519:
signer := ed25519.PrivateKey(key)
sig = ed25519.Sign(signer, b)
case Curve_P256:
if client != nil {
sig, err = client.SignASN1(b)
} else {
signer := &ecdsa.PrivateKey{
PublicKey: ecdsa.PublicKey{
Curve: elliptic.P256(),
},
// ref: https://github.com/golang/go/blob/go1.19/src/crypto/x509/sec1.go#L95
D: new(big.Int).SetBytes(key),
}
// ref: https://github.com/golang/go/blob/go1.19/src/crypto/x509/sec1.go#L119
signer.X, signer.Y = signer.Curve.ScalarBaseMult(key)
// We need to hash first for ECDSA
// - https://pkg.go.dev/crypto/ecdsa#SignASN1
hashed := sha256.Sum256(b)
sig, err = ecdsa.SignASN1(rand.Reader, signer, hashed[:])
if err != nil {
return nil, err
}
}
default:
return nil, fmt.Errorf("invalid curve: %s", c.details.Curve)
}
c.signature = sig
return c, nil
}
func ip2int(ip []byte) uint32 {
if len(ip) == 16 {
return binary.BigEndian.Uint32(ip[12:16])

37
cert/cert_v2.asn1 Normal file
View File

@ -0,0 +1,37 @@
Nebula DEFINITIONS AUTOMATIC TAGS ::= BEGIN
Name ::= UTF8String (SIZE (1..253))
Time ::= INTEGER (0..18446744073709551615) -- Seconds since unix epoch, uint64 maximum
Network ::= OCTET STRING (SIZE (5,17)) -- IP addresses are 4 or 16 bytes + 1 byte for the prefix length
Curve ::= ENUMERATED {
curve25519 (0),
p256 (1)
}
-- The maximum size of a certificate must not exceed 65536 bytes
Certificate ::= SEQUENCE {
details OCTET STRING,
curve Curve DEFAULT curve25519,
publicKey OCTET STRING,
-- signature(details + curve + publicKey) using the appropriate method for curve
signature OCTET STRING
}
Details ::= SEQUENCE {
name Name,
-- At least 1 ipv4 or ipv6 address must be present if isCA is false
networks SEQUENCE OF Network OPTIONAL,
unsafeNetworks SEQUENCE OF Network OPTIONAL,
groups SEQUENCE OF Name OPTIONAL,
isCA BOOLEAN DEFAULT false,
notBefore Time,
notAfter Time,
-- issuer is only required if isCA is false, if isCA is true then it must not be present
issuer OCTET STRING OPTIONAL,
...
-- New fields can be added below here
}
END

621
cert/cert_v2.go Normal file
View File

@ -0,0 +1,621 @@
package cert
import (
"bytes"
"crypto/ecdh"
"crypto/ecdsa"
"crypto/ed25519"
"crypto/elliptic"
"crypto/sha256"
"encoding/hex"
"encoding/json"
"encoding/pem"
"fmt"
"net/netip"
"slices"
"time"
"golang.org/x/crypto/cryptobyte"
"golang.org/x/crypto/cryptobyte/asn1"
"golang.org/x/crypto/curve25519"
)
//TODO: should we avoid hex encoding shit on output? Just let it be base64?
const (
classConstructed = 0x20
classContextSpecific = 0x80
TagCertDetails = 0 | classConstructed | classContextSpecific
TagCertCurve = 1 | classContextSpecific
TagCertPublicKey = 2 | classContextSpecific
TagCertSignature = 3 | classContextSpecific
TagDetailsName = 0 | classContextSpecific
TagDetailsNetworks = 1 | classConstructed | classContextSpecific
TagDetailsUnsafeNetworks = 2 | classConstructed | classContextSpecific
TagDetailsGroups = 3 | classConstructed | classContextSpecific
TagDetailsIsCA = 4 | classContextSpecific
TagDetailsNotBefore = 5 | classContextSpecific
TagDetailsNotAfter = 6 | classContextSpecific
TagDetailsIssuer = 7 | classContextSpecific
)
const (
// MaxCertificateSize is the maximum length a valid certificate can be
MaxCertificateSize = 65536
// MaxNameLength is limited to a maximum realistic DNS domain name to help facilitate DNS systems
MaxNameLength = 253
// MaxNetworkLength is the maximum length a network value can be.
// 16 bytes for an ipv6 address + 1 byte for the prefix length
MaxNetworkLength = 17
)
type certificateV2 struct {
details detailsV2
// RawDetails contains the entire asn.1 DER encoded Details struct
// This is to benefit forwards compatibility in signature checking.
// signature(RawDetails + Curve + PublicKey) == Signature
rawDetails []byte
curve Curve
publicKey []byte
signature []byte
}
type detailsV2 struct {
name string
networks []netip.Prefix
unsafeNetworks []netip.Prefix
groups []string
isCA bool
notBefore time.Time
notAfter time.Time
issuer string
}
func (c *certificateV2) Version() Version {
return Version2
}
func (c *certificateV2) Curve() Curve {
return c.curve
}
func (c *certificateV2) Groups() []string {
return c.details.groups
}
func (c *certificateV2) IsCA() bool {
return c.details.isCA
}
func (c *certificateV2) Issuer() string {
return c.details.issuer
}
func (c *certificateV2) Name() string {
return c.details.name
}
func (c *certificateV2) Networks() []netip.Prefix {
return c.details.networks
}
func (c *certificateV2) NotAfter() time.Time {
return c.details.notAfter
}
func (c *certificateV2) NotBefore() time.Time {
return c.details.notBefore
}
func (c *certificateV2) PublicKey() []byte {
return c.publicKey
}
func (c *certificateV2) Signature() []byte {
return c.signature
}
func (c *certificateV2) UnsafeNetworks() []netip.Prefix {
return c.details.unsafeNetworks
}
func (c *certificateV2) Fingerprint() (string, error) {
b := make([]byte, len(c.rawDetails)+1+len(c.publicKey))
//TODO: double check this, panic on empty raw details
copy(b, c.rawDetails)
b[len(c.rawDetails)] = byte(c.curve)
copy(b[len(c.rawDetails)+1:], c.publicKey)
copy(b[len(c.rawDetails)+1+len(c.publicKey):], c.signature)
sum := sha256.Sum256(b)
return hex.EncodeToString(sum[:]), nil
}
func (c *certificateV2) CheckSignature(key []byte) bool {
b := make([]byte, len(c.rawDetails)+1+len(c.publicKey))
//TODO: double check this, panic on empty raw details
copy(b, c.rawDetails)
b[len(c.rawDetails)] = byte(c.curve)
copy(b[len(c.rawDetails)+1:], c.publicKey)
switch c.curve {
case Curve_CURVE25519:
return ed25519.Verify(key, b, c.signature)
case Curve_P256:
//TODO: NewPublicKey
x, y := elliptic.Unmarshal(elliptic.P256(), key)
pubKey := &ecdsa.PublicKey{Curve: elliptic.P256(), X: x, Y: y}
hashed := sha256.Sum256(b)
return ecdsa.VerifyASN1(pubKey, hashed[:], c.signature)
default:
return false
}
}
func (c *certificateV2) Expired(t time.Time) bool {
return c.details.notBefore.After(t) || c.details.notAfter.Before(t)
}
func (c *certificateV2) VerifyPrivateKey(curve Curve, key []byte) error {
if curve != c.curve {
return fmt.Errorf("curve in cert and private key supplied don't match")
}
if c.details.isCA {
switch curve {
case Curve_CURVE25519:
// the call to PublicKey below will panic slice bounds out of range otherwise
if len(key) != ed25519.PrivateKeySize {
return fmt.Errorf("key was not 64 bytes, is invalid ed25519 private key")
}
if !ed25519.PublicKey(c.publicKey).Equal(ed25519.PrivateKey(key).Public()) {
return fmt.Errorf("public key in cert and private key supplied don't match")
}
case Curve_P256:
privkey, err := ecdh.P256().NewPrivateKey(key)
if err != nil {
return fmt.Errorf("cannot parse private key as P256")
}
pub := privkey.PublicKey().Bytes()
if !bytes.Equal(pub, c.publicKey) {
return fmt.Errorf("public key in cert and private key supplied don't match")
}
default:
return fmt.Errorf("invalid curve: %s", curve)
}
return nil
}
var pub []byte
switch curve {
case Curve_CURVE25519:
var err error
pub, err = curve25519.X25519(key, curve25519.Basepoint)
if err != nil {
return err
}
case Curve_P256:
privkey, err := ecdh.P256().NewPrivateKey(key)
if err != nil {
return err
}
pub = privkey.PublicKey().Bytes()
default:
return fmt.Errorf("invalid curve: %s", curve)
}
if !bytes.Equal(pub, c.publicKey) {
return fmt.Errorf("public key in cert and private key supplied don't match")
}
return nil
}
func (c *certificateV2) String() string {
b, err := json.MarshalIndent(c.marshalJSON(), "", "\t")
if err != nil {
return "<error marshalling certificate>"
}
return string(b)
}
func (c *certificateV2) MarshalForHandshakes() ([]byte, error) {
var b cryptobyte.Builder
// Outermost certificate
b.AddASN1(asn1.SEQUENCE, func(b *cryptobyte.Builder) {
// Add the cert details which is already marshalled
//TODO: panic on nil rawDetails
b.AddBytes(c.rawDetails)
// Skipping the curve and public key since those come across in a different part of the handshake
// Add the signature
b.AddASN1(TagCertSignature, func(b *cryptobyte.Builder) {
b.AddBytes(c.signature)
})
})
return b.Bytes()
}
func (c *certificateV2) Marshal() ([]byte, error) {
var b cryptobyte.Builder
// Outermost certificate
b.AddASN1(asn1.SEQUENCE, func(b *cryptobyte.Builder) {
// Add the cert details which is already marshalled
b.AddBytes(c.rawDetails)
// Add the curve only if its not the default value
if c.curve != Curve_CURVE25519 {
b.AddASN1(TagCertCurve, func(b *cryptobyte.Builder) {
b.AddBytes([]byte{byte(c.curve)})
})
}
// Add the public key if it is not empty
if c.publicKey != nil {
b.AddASN1(TagCertPublicKey, func(b *cryptobyte.Builder) {
b.AddBytes(c.publicKey)
})
}
// Add the signature
b.AddASN1(TagCertSignature, func(b *cryptobyte.Builder) {
b.AddBytes(c.signature)
})
})
return b.Bytes()
}
func (c *certificateV2) MarshalPEM() ([]byte, error) {
b, err := c.Marshal()
if err != nil {
return nil, err
}
return pem.EncodeToMemory(&pem.Block{Type: CertificateV2Banner, Bytes: b}), nil
}
func (c *certificateV2) MarshalJSON() ([]byte, error) {
return json.Marshal(c.marshalJSON())
}
func (c *certificateV2) marshalJSON() m {
fp, _ := c.Fingerprint()
return m{
"details": m{
"name": c.details.name,
"networks": c.details.networks,
"unsafeNetworks": c.details.unsafeNetworks,
"groups": c.details.groups,
"notBefore": c.details.notBefore,
"notAfter": c.details.notAfter,
"isCa": c.details.isCA,
"issuer": c.details.issuer,
},
"version": Version2,
"publicKey": fmt.Sprintf("%x", c.publicKey),
"curve": c.curve.String(),
"fingerprint": fp,
"signature": fmt.Sprintf("%x", c.Signature()),
}
}
func (c *certificateV2) Copy() Certificate {
nc := &certificateV2{
details: detailsV2{
name: c.details.name,
groups: make([]string, len(c.details.groups)),
networks: make([]netip.Prefix, len(c.details.networks)),
unsafeNetworks: make([]netip.Prefix, len(c.details.unsafeNetworks)),
notBefore: c.details.notBefore,
notAfter: c.details.notAfter,
isCA: c.details.isCA,
issuer: c.details.issuer,
},
curve: c.curve,
publicKey: make([]byte, len(c.publicKey)),
signature: make([]byte, len(c.signature)),
}
copy(nc.signature, c.signature)
copy(nc.details.groups, c.details.groups)
copy(nc.publicKey, c.publicKey)
copy(nc.details.networks, c.details.networks)
copy(nc.details.unsafeNetworks, c.details.unsafeNetworks)
return nc
}
func (c *certificateV2) fromTBSCertificate(t *TBSCertificate) error {
c.details = detailsV2{
name: t.Name,
networks: t.Networks,
unsafeNetworks: t.UnsafeNetworks,
groups: t.Groups,
isCA: t.IsCA,
notBefore: t.NotBefore,
notAfter: t.NotAfter,
issuer: t.issuer,
}
c.curve = t.Curve
c.publicKey = t.PublicKey
return nil
}
func (c *certificateV2) marshalForSigning() ([]byte, error) {
d, err := c.details.Marshal()
if err != nil {
//TODO: annotate?
return nil, err
}
c.rawDetails = d
b := make([]byte, len(c.rawDetails)+1+len(c.publicKey))
//TODO: double check this
copy(b, c.rawDetails)
b[len(c.rawDetails)] = byte(c.curve)
copy(b[len(c.rawDetails)+1:], c.publicKey)
return b, nil
}
func (c *certificateV2) setSignature(b []byte) error {
c.signature = b
return nil
}
func (d *detailsV2) Marshal() ([]byte, error) {
var b cryptobyte.Builder
var err error
// Details are a structure
b.AddASN1(TagCertDetails, func(b *cryptobyte.Builder) {
// Add the name
b.AddASN1(TagDetailsName, func(b *cryptobyte.Builder) {
b.AddBytes([]byte(d.name))
})
// Add the networks if any exist
if len(d.networks) > 0 {
b.AddASN1(TagDetailsNetworks, func(b *cryptobyte.Builder) {
for _, n := range d.networks {
sb, innerErr := n.MarshalBinary()
if innerErr != nil {
// MarshalBinary never returns an error
err = fmt.Errorf("unable to marshal network: %w", innerErr)
return
}
b.AddASN1OctetString(sb)
}
})
}
// Add the unsafe networks if any exist
if len(d.unsafeNetworks) > 0 {
b.AddASN1(TagDetailsUnsafeNetworks, func(b *cryptobyte.Builder) {
for _, n := range d.unsafeNetworks {
sb, innerErr := n.MarshalBinary()
if innerErr != nil {
// MarshalBinary never returns an error
err = fmt.Errorf("unable to marshal unsafe network: %w", innerErr)
return
}
b.AddASN1OctetString(sb)
}
})
}
// Add groups if any exist
if len(d.groups) > 0 {
b.AddASN1(TagDetailsGroups, func(b *cryptobyte.Builder) {
for _, group := range d.groups {
b.AddASN1(asn1.UTF8String, func(b *cryptobyte.Builder) {
b.AddBytes([]byte(group))
})
}
})
}
// Add IsCA only if true
if d.isCA {
b.AddASN1(TagDetailsIsCA, func(b *cryptobyte.Builder) {
b.AddUint8(0xff)
})
}
// Add not before
b.AddASN1Int64WithTag(d.notBefore.Unix(), TagDetailsNotBefore)
// Add not after
b.AddASN1Int64WithTag(d.notAfter.Unix(), TagDetailsNotAfter)
// Add the issuer if present
if d.issuer != "" {
issuerBytes, innerErr := hex.DecodeString(d.issuer)
if innerErr != nil {
err = fmt.Errorf("failed to decode issuer: %w", innerErr)
return
}
b.AddASN1(TagDetailsIssuer, func(b *cryptobyte.Builder) {
b.AddBytes(issuerBytes)
})
}
})
if err != nil {
return nil, err
}
return b.Bytes()
}
func unmarshalCertificateV2(b []byte, publicKey []byte, curve Curve) (*certificateV2, error) {
l := len(b)
if l == 0 || l > MaxCertificateSize {
return nil, ErrBadFormat
}
input := cryptobyte.String(b)
// Open the envelope
if !input.ReadASN1(&input, asn1.SEQUENCE) || input.Empty() {
return nil, ErrBadFormat
}
// Grab the cert details, we need to preserve the tag and length
var rawDetails cryptobyte.String
if !input.ReadASN1Element(&rawDetails, TagCertDetails) || rawDetails.Empty() {
return nil, ErrBadFormat
}
//Maybe grab the curve
var rawCurve byte
if !readOptionalASN1Byte(&input, &rawCurve, TagCertCurve, byte(curve)) {
return nil, ErrBadFormat
}
curve = Curve(rawCurve)
// Maybe grab the public key
var rawPublicKey cryptobyte.String
if len(publicKey) > 0 {
rawPublicKey = publicKey
} else if !input.ReadOptionalASN1(&rawPublicKey, nil, TagCertPublicKey) {
return nil, ErrBadFormat
}
//TODO: Assert public key length
// Grab the signature
var rawSignature cryptobyte.String
if !input.ReadASN1(&rawSignature, TagCertSignature) || rawSignature.Empty() {
return nil, ErrBadFormat
}
// Finally unmarshal the details
details, err := unmarshalDetails(rawDetails)
if err != nil {
return nil, err
}
return &certificateV2{
details: details,
rawDetails: rawDetails,
curve: curve,
publicKey: rawPublicKey,
signature: rawSignature,
}, nil
}
func unmarshalDetails(b cryptobyte.String) (detailsV2, error) {
// Open the envelope
if !b.ReadASN1(&b, TagCertDetails) || b.Empty() {
return detailsV2{}, ErrBadFormat
}
// Read the name
var name cryptobyte.String
if !b.ReadASN1(&name, TagDetailsName) || name.Empty() || len(name) > MaxNameLength {
return detailsV2{}, ErrBadFormat
}
// Read the network addresses
var subString cryptobyte.String
var found bool
if !b.ReadOptionalASN1(&subString, &found, TagDetailsNetworks) {
return detailsV2{}, ErrBadFormat
}
var networks []netip.Prefix
var val cryptobyte.String
if found {
for !subString.Empty() {
if !subString.ReadASN1(&val, asn1.OCTET_STRING) || val.Empty() || len(val) > MaxNetworkLength {
return detailsV2{}, ErrBadFormat
}
var n netip.Prefix
if err := n.UnmarshalBinary(val); err != nil {
return detailsV2{}, ErrBadFormat
}
networks = append(networks, n)
}
}
// Read out any unsafe networks
if !b.ReadOptionalASN1(&subString, &found, TagDetailsUnsafeNetworks) {
return detailsV2{}, ErrBadFormat
}
var unsafeNetworks []netip.Prefix
if found {
for !subString.Empty() {
if !subString.ReadASN1(&val, asn1.OCTET_STRING) || val.Empty() || len(val) > MaxNetworkLength {
return detailsV2{}, ErrBadFormat
}
var n netip.Prefix
if err := n.UnmarshalBinary(val); err != nil {
return detailsV2{}, ErrBadFormat
}
unsafeNetworks = append(unsafeNetworks, n)
}
}
// Read out any groups
if !b.ReadOptionalASN1(&subString, &found, TagDetailsGroups) {
return detailsV2{}, ErrBadFormat
}
var groups []string
if found {
for !subString.Empty() {
if !subString.ReadASN1(&val, asn1.UTF8String) || val.Empty() {
return detailsV2{}, ErrBadFormat
}
groups = append(groups, string(val))
}
}
// Read out IsCA
var isCa bool
if !readOptionalASN1Boolean(&b, &isCa, TagDetailsIsCA, false) {
return detailsV2{}, ErrBadFormat
}
// Read not before and not after
var notBefore int64
if !b.ReadASN1Int64WithTag(&notBefore, TagDetailsNotBefore) {
return detailsV2{}, ErrBadFormat
}
var notAfter int64
if !b.ReadASN1Int64WithTag(&notAfter, TagDetailsNotAfter) {
return detailsV2{}, ErrBadFormat
}
// Read issuer
var issuer cryptobyte.String
if !b.ReadOptionalASN1(&issuer, nil, TagDetailsIssuer) {
return detailsV2{}, ErrBadFormat
}
slices.SortFunc(networks, comparePrefix)
slices.SortFunc(unsafeNetworks, comparePrefix)
return detailsV2{
name: string(name),
networks: networks,
unsafeNetworks: unsafeNetworks,
groups: groups,
isCA: isCa,
notBefore: time.Unix(notBefore, 0),
notAfter: time.Unix(notAfter, 0),
issuer: hex.EncodeToString(issuer),
}, nil
}

View File

@ -24,4 +24,7 @@ var (
ErrInvalidPEMX25519PrivateKeyBanner = errors.New("bytes did not contain a proper X25519 private key banner")
ErrInvalidPEMEd25519PublicKeyBanner = errors.New("bytes did not contain a proper Ed25519 public key banner")
ErrInvalidPEMEd25519PrivateKeyBanner = errors.New("bytes did not contain a proper Ed25519 private key banner")
ErrNoPeerStaticKey = errors.New("no peer static key was present")
ErrNoPayload = errors.New("provided payload was empty")
)

View File

@ -30,19 +30,24 @@ func UnmarshalCertificateFromPEM(b []byte) (Certificate, []byte, error) {
return nil, r, ErrInvalidPEMBlock
}
var c Certificate
var err error
switch p.Type {
case CertificateBanner:
c, err := unmarshalCertificateV1(p.Bytes, true)
if err != nil {
return nil, nil, err
}
return c, r, nil
c, err = unmarshalCertificateV1(p.Bytes, nil)
case CertificateV2Banner:
//TODO
panic("TODO")
c, err = unmarshalCertificateV2(p.Bytes, nil, Curve_CURVE25519)
default:
return nil, r, ErrInvalidPEMCertificateBanner
}
if err != nil {
return nil, r, err
}
return c, r, nil
}
func MarshalPublicKeyToPEM(curve Curve, b []byte) []byte {

View File

@ -1,11 +1,16 @@
package cert
import (
"crypto/ecdsa"
"crypto/ed25519"
"crypto/elliptic"
"crypto/rand"
"crypto/sha256"
"fmt"
"math/big"
"net/netip"
"slices"
"time"
"github.com/slackhq/nebula/pkclient"
)
// TBSCertificate represents a certificate intended to be signed.
@ -24,27 +29,62 @@ type TBSCertificate struct {
issuer string
}
type beingSignedCertificate interface {
// fromTBSCertificate copies the values from the TBSCertificate to this versions internal representation
fromTBSCertificate(*TBSCertificate) error
// marshalForSigning returns the bytes that should be signed
marshalForSigning() ([]byte, error)
// setSignature sets the signature for the certificate that has just been signed
setSignature([]byte) error
}
type SignerLambda func(certBytes []byte) ([]byte, error)
// Sign will create a sealed certificate using details provided by the TBSCertificate as long as those
// details do not violate constraints of the signing certificate.
// If the TBSCertificate is a CA then signer must be nil.
func (t *TBSCertificate) Sign(signer Certificate, curve Curve, key []byte) (Certificate, error) {
return t.sign(signer, curve, key, nil)
}
func (t *TBSCertificate) SignPkcs11(signer Certificate, curve Curve, client *pkclient.PKClient) (Certificate, error) {
if curve != Curve_P256 {
return nil, fmt.Errorf("only P256 is supported by PKCS#11")
switch t.Curve {
case Curve_CURVE25519:
pk := ed25519.PrivateKey(key)
sp := func(certBytes []byte) ([]byte, error) {
sig := ed25519.Sign(pk, certBytes)
return sig, nil
}
return t.SignWith(signer, curve, sp)
case Curve_P256:
pk := &ecdsa.PrivateKey{
PublicKey: ecdsa.PublicKey{
Curve: elliptic.P256(),
},
// ref: https://github.com/golang/go/blob/go1.19/src/crypto/x509/sec1.go#L95
D: new(big.Int).SetBytes(key),
}
// ref: https://github.com/golang/go/blob/go1.19/src/crypto/x509/sec1.go#L119
pk.X, pk.Y = pk.Curve.ScalarBaseMult(key)
sp := func(certBytes []byte) ([]byte, error) {
// We need to hash first for ECDSA
// - https://pkg.go.dev/crypto/ecdsa#SignASN1
hashed := sha256.Sum256(certBytes)
return ecdsa.SignASN1(rand.Reader, pk, hashed[:])
}
return t.SignWith(signer, curve, sp)
default:
return nil, fmt.Errorf("invalid curve: %s", t.Curve)
}
return t.sign(signer, curve, nil, client)
}
func (t *TBSCertificate) sign(signer Certificate, curve Curve, key []byte, client *pkclient.PKClient) (Certificate, error) {
// SignWith does the same thing as sign, but uses the function in `sp` to calculate the signature.
// You should only use SignWith if you do not have direct access to your private key.
func (t *TBSCertificate) SignWith(signer Certificate, curve Curve, sp SignerLambda) (Certificate, error) {
if curve != t.Curve {
return nil, fmt.Errorf("curve in cert and private key supplied don't match")
}
//TODO: make sure we have all minimum properties to sign, like a public key
//TODO: we need to verify networks and unsafe networks (no duplicates, max of 1 of each version for v2 certs
if signer != nil {
if t.IsCA {
@ -67,10 +107,55 @@ func (t *TBSCertificate) sign(signer Certificate, curve Curve, key []byte, clien
}
}
slices.SortFunc(t.Networks, comparePrefix)
slices.SortFunc(t.UnsafeNetworks, comparePrefix)
var c beingSignedCertificate
switch t.Version {
case Version1:
return signV1(t, curve, key, client)
c = &certificateV1{}
err := c.fromTBSCertificate(t)
if err != nil {
return nil, err
}
case Version2:
c = &certificateV2{}
err := c.fromTBSCertificate(t)
if err != nil {
return nil, err
}
default:
return nil, fmt.Errorf("unknown cert version %d", t.Version)
}
certBytes, err := c.marshalForSigning()
if err != nil {
return nil, err
}
sig, err := sp(certBytes)
if err != nil {
return nil, err
}
//TODO: check if we have sig bytes?
err = c.setSignature(sig)
if err != nil {
return nil, err
}
sc, ok := c.(Certificate)
if !ok {
return nil, fmt.Errorf("invalid certificate")
}
return sc, nil
}
func comparePrefix(a, b netip.Prefix) int {
addr := a.Addr().Compare(b.Addr())
if addr == 0 {
return a.Bits() - b.Bits()
}
return addr
}

View File

@ -27,34 +27,43 @@ type caFlags struct {
outCertPath *string
outQRPath *string
groups *string
ips *string
subnets *string
networks *string
unsafeNetworks *string
argonMemory *uint
argonIterations *uint
argonParallelism *uint
encryption *bool
version *uint
curve *string
p11url *string
// Deprecated options
ips *string
subnets *string
}
func newCaFlags() *caFlags {
cf := caFlags{set: flag.NewFlagSet("ca", flag.ContinueOnError)}
cf.set.Usage = func() {}
cf.name = cf.set.String("name", "", "Required: name of the certificate authority")
cf.version = cf.set.Uint("version", uint(cert.Version2), "Optional: version of the certificate format to use")
cf.duration = cf.set.Duration("duration", time.Duration(time.Hour*8760), "Optional: amount of time the certificate should be valid for. Valid time units are seconds: \"s\", minutes: \"m\", hours: \"h\"")
cf.outKeyPath = cf.set.String("out-key", "ca.key", "Optional: path to write the private key to")
cf.outCertPath = cf.set.String("out-crt", "ca.crt", "Optional: path to write the certificate to")
cf.outQRPath = cf.set.String("out-qr", "", "Optional: output a qr code image (png) of the certificate")
cf.groups = cf.set.String("groups", "", "Optional: comma separated list of groups. This will limit which groups subordinate certs can use")
cf.ips = cf.set.String("ips", "", "Optional: comma separated list of ipv4 address and network in CIDR notation. This will limit which ipv4 addresses and networks subordinate certs can use for ip addresses")
cf.subnets = cf.set.String("subnets", "", "Optional: comma separated list of ipv4 address and network in CIDR notation. This will limit which ipv4 addresses and networks subordinate certs can use in subnets")
cf.networks = cf.set.String("networks", "", "Optional: comma separated list of ip address and network in CIDR notation. This will limit which ip addresses and networks subordinate certs can use in networks")
cf.unsafeNetworks = cf.set.String("unsafe-networks", "", "Optional: comma separated list of ip address and network in CIDR notation. This will limit which ip addresses and networks subordinate certs can use in unsafe networks")
cf.argonMemory = cf.set.Uint("argon-memory", 2*1024*1024, "Optional: Argon2 memory parameter (in KiB) used for encrypted private key passphrase")
cf.argonParallelism = cf.set.Uint("argon-parallelism", 4, "Optional: Argon2 parallelism parameter used for encrypted private key passphrase")
cf.argonIterations = cf.set.Uint("argon-iterations", 1, "Optional: Argon2 iterations parameter used for encrypted private key passphrase")
cf.encryption = cf.set.Bool("encrypt", false, "Optional: prompt for passphrase and write out-key in an encrypted format")
cf.curve = cf.set.String("curve", "25519", "EdDSA/ECDSA Curve (25519, P256)")
cf.p11url = p11Flag(cf.set)
cf.ips = cf.set.String("ips", "", "Deprecated, see -networks")
cf.subnets = cf.set.String("subnets", "", "Deprecated, see -unsafe-networks")
return &cf
}
@ -113,36 +122,51 @@ func ca(args []string, out io.Writer, errOut io.Writer, pr PasswordReader) error
}
}
var ips []netip.Prefix
if *cf.ips != "" {
for _, rs := range strings.Split(*cf.ips, ",") {
version := cert.Version(*cf.version)
if version != cert.Version1 && version != cert.Version2 {
return newHelpErrorf("-version must be either %v or %v", cert.Version1, cert.Version2)
}
var networks []netip.Prefix
if *cf.networks == "" && *cf.ips != "" {
// Pull up deprecated -ips flag if needed
*cf.networks = *cf.ips
}
if *cf.networks != "" {
for _, rs := range strings.Split(*cf.networks, ",") {
rs := strings.Trim(rs, " ")
if rs != "" {
n, err := netip.ParsePrefix(rs)
if err != nil {
return newHelpErrorf("invalid ip definition: %s", err)
return newHelpErrorf("invalid -networks definition: %s", rs)
}
if !n.Addr().Is4() {
return newHelpErrorf("invalid ip definition: can only be ipv4, have %s", rs)
if version == cert.Version1 && !n.Addr().Is4() {
return newHelpErrorf("invalid -networks definition: v1 certificates can only be ipv4, have %s", rs)
}
ips = append(ips, n)
networks = append(networks, n)
}
}
}
var subnets []netip.Prefix
if *cf.subnets != "" {
for _, rs := range strings.Split(*cf.subnets, ",") {
var unsafeNetworks []netip.Prefix
if *cf.unsafeNetworks == "" && *cf.subnets != "" {
// Pull up deprecated -subnets flag if needed
*cf.unsafeNetworks = *cf.subnets
}
if *cf.unsafeNetworks != "" {
for _, rs := range strings.Split(*cf.unsafeNetworks, ",") {
rs := strings.Trim(rs, " ")
if rs != "" {
n, err := netip.ParsePrefix(rs)
if err != nil {
return newHelpErrorf("invalid subnet definition: %s", err)
return newHelpErrorf("invalid -unsafe-networks definition: %s", rs)
}
if !n.Addr().Is4() {
return newHelpErrorf("invalid subnet definition: can only be ipv4, have %s", rs)
if version == cert.Version1 && !n.Addr().Is4() {
return newHelpErrorf("invalid -unsafe-networks definition: v1 certificates can only be ipv4, have %s", rs)
}
subnets = append(subnets, n)
unsafeNetworks = append(unsafeNetworks, n)
}
}
}
@ -222,11 +246,11 @@ func ca(args []string, out io.Writer, errOut io.Writer, pr PasswordReader) error
}
t := &cert.TBSCertificate{
Version: cert.Version1,
Version: version,
Name: *cf.name,
Groups: groups,
Networks: ips,
UnsafeNetworks: subnets,
Networks: networks,
UnsafeNetworks: unsafeNetworks,
NotBefore: time.Now(),
NotAfter: time.Now().Add(*cf.duration),
PublicKey: pub,
@ -248,7 +272,7 @@ func ca(args []string, out io.Writer, errOut io.Writer, pr PasswordReader) error
var b []byte
if isP11 {
c, err = t.SignPkcs11(nil, curve, p11Client)
c, err = t.SignWith(nil, curve, p11Client.SignASN1)
if err != nil {
return fmt.Errorf("error while signing with PKCS#11: %w", err)
}

View File

@ -43,9 +43,11 @@ func Test_caHelp(t *testing.T) {
" -groups string\n"+
" \tOptional: comma separated list of groups. This will limit which groups subordinate certs can use\n"+
" -ips string\n"+
" \tOptional: comma separated list of ipv4 address and network in CIDR notation. This will limit which ipv4 addresses and networks subordinate certs can use for ip addresses\n"+
" Deprecated, see -networks\n"+
" -name string\n"+
" \tRequired: name of the certificate authority\n"+
" -networks string\n"+
" \tOptional: comma separated list of ip address and network in CIDR notation. This will limit which ip addresses and networks subordinate certs can use in networks\n"+
" -out-crt string\n"+
" \tOptional: path to write the certificate to (default \"ca.crt\")\n"+
" -out-key string\n"+
@ -54,7 +56,11 @@ func Test_caHelp(t *testing.T) {
" \tOptional: output a qr code image (png) of the certificate\n"+
optionalPkcs11String(" -pkcs11 string\n \tOptional: PKCS#11 URI to an existing private key\n")+
" -subnets string\n"+
" \tOptional: comma separated list of ipv4 address and network in CIDR notation. This will limit which ipv4 addresses and networks subordinate certs can use in subnets\n",
" \tDeprecated, see -unsafe-networks\n"+
" -unsafe-networks string\n"+
" \tOptional: comma separated list of ip address and network in CIDR notation. This will limit which ip addresses and networks subordinate certs can use in unsafe networks\n"+
" -version uint\n"+
" \tOptional: version of the certificate format to use (default 2)\n",
ob.String(),
)
}
@ -83,25 +89,25 @@ func Test_ca(t *testing.T) {
// required args
assertHelpError(t, ca(
[]string{"-out-key", "nope", "-out-crt", "nope", "duration", "100m"}, ob, eb, nopw,
[]string{"-version", "1", "-out-key", "nope", "-out-crt", "nope", "duration", "100m"}, ob, eb, nopw,
), "-name is required")
assert.Equal(t, "", ob.String())
assert.Equal(t, "", eb.String())
// ipv4 only ips
assertHelpError(t, ca([]string{"-name", "ipv6", "-ips", "100::100/100"}, ob, eb, nopw), "invalid ip definition: can only be ipv4, have 100::100/100")
assertHelpError(t, ca([]string{"-version", "1", "-name", "ipv6", "-ips", "100::100/100"}, ob, eb, nopw), "invalid -networks definition: v1 certificates can only be ipv4, have 100::100/100")
assert.Equal(t, "", ob.String())
assert.Equal(t, "", eb.String())
// ipv4 only subnets
assertHelpError(t, ca([]string{"-name", "ipv6", "-subnets", "100::100/100"}, ob, eb, nopw), "invalid subnet definition: can only be ipv4, have 100::100/100")
assertHelpError(t, ca([]string{"-version", "1", "-name", "ipv6", "-subnets", "100::100/100"}, ob, eb, nopw), "invalid -unsafe-networks definition: v1 certificates can only be ipv4, have 100::100/100")
assert.Equal(t, "", ob.String())
assert.Equal(t, "", eb.String())
// failed key write
ob.Reset()
eb.Reset()
args := []string{"-name", "test", "-duration", "100m", "-out-crt", "/do/not/write/pleasecrt", "-out-key", "/do/not/write/pleasekey"}
args := []string{"-version", "1", "-name", "test", "-duration", "100m", "-out-crt", "/do/not/write/pleasecrt", "-out-key", "/do/not/write/pleasekey"}
assert.EqualError(t, ca(args, ob, eb, nopw), "error while writing out-key: open /do/not/write/pleasekey: "+NoSuchDirError)
assert.Equal(t, "", ob.String())
assert.Equal(t, "", eb.String())
@ -114,7 +120,7 @@ func Test_ca(t *testing.T) {
// failed cert write
ob.Reset()
eb.Reset()
args = []string{"-name", "test", "-duration", "100m", "-out-crt", "/do/not/write/pleasecrt", "-out-key", keyF.Name()}
args = []string{"-version", "1", "-name", "test", "-duration", "100m", "-out-crt", "/do/not/write/pleasecrt", "-out-key", keyF.Name()}
assert.EqualError(t, ca(args, ob, eb, nopw), "error while writing out-crt: open /do/not/write/pleasecrt: "+NoSuchDirError)
assert.Equal(t, "", ob.String())
assert.Equal(t, "", eb.String())
@ -128,7 +134,7 @@ func Test_ca(t *testing.T) {
// test proper cert with removed empty groups and subnets
ob.Reset()
eb.Reset()
args = []string{"-name", "test", "-duration", "100m", "-groups", "1,, 2 , ,,,3,4,5", "-out-crt", crtF.Name(), "-out-key", keyF.Name()}
args = []string{"-version", "1", "-name", "test", "-duration", "100m", "-groups", "1,, 2 , ,,,3,4,5", "-out-crt", crtF.Name(), "-out-key", keyF.Name()}
assert.Nil(t, ca(args, ob, eb, nopw))
assert.Equal(t, "", ob.String())
assert.Equal(t, "", eb.String())
@ -161,7 +167,7 @@ func Test_ca(t *testing.T) {
os.Remove(crtF.Name())
ob.Reset()
eb.Reset()
args = []string{"-encrypt", "-name", "test", "-duration", "100m", "-groups", "1,2,3,4,5", "-out-crt", crtF.Name(), "-out-key", keyF.Name()}
args = []string{"-version", "1", "-encrypt", "-name", "test", "-duration", "100m", "-groups", "1,2,3,4,5", "-out-crt", crtF.Name(), "-out-key", keyF.Name()}
assert.Nil(t, ca(args, ob, eb, testpw))
assert.Equal(t, pwPromptOb, ob.String())
assert.Equal(t, "", eb.String())
@ -189,7 +195,7 @@ func Test_ca(t *testing.T) {
os.Remove(crtF.Name())
ob.Reset()
eb.Reset()
args = []string{"-encrypt", "-name", "test", "-duration", "100m", "-groups", "1,2,3,4,5", "-out-crt", crtF.Name(), "-out-key", keyF.Name()}
args = []string{"-version", "1", "-encrypt", "-name", "test", "-duration", "100m", "-groups", "1,2,3,4,5", "-out-crt", crtF.Name(), "-out-key", keyF.Name()}
assert.Error(t, ca(args, ob, eb, errpw))
assert.Equal(t, pwPromptOb, ob.String())
assert.Equal(t, "", eb.String())
@ -199,7 +205,7 @@ func Test_ca(t *testing.T) {
os.Remove(crtF.Name())
ob.Reset()
eb.Reset()
args = []string{"-encrypt", "-name", "test", "-duration", "100m", "-groups", "1,2,3,4,5", "-out-crt", crtF.Name(), "-out-key", keyF.Name()}
args = []string{"-version", "1", "-encrypt", "-name", "test", "-duration", "100m", "-groups", "1,2,3,4,5", "-out-crt", crtF.Name(), "-out-key", keyF.Name()}
assert.EqualError(t, ca(args, ob, eb, nopw), "no passphrase specified, remove -encrypt flag to write out-key in plaintext")
assert.Equal(t, strings.Repeat(pwPromptOb, 5), ob.String()) // prompts 5 times before giving up
assert.Equal(t, "", eb.String())
@ -209,13 +215,13 @@ func Test_ca(t *testing.T) {
os.Remove(crtF.Name())
ob.Reset()
eb.Reset()
args = []string{"-name", "test", "-duration", "100m", "-groups", "1,, 2 , ,,,3,4,5", "-out-crt", crtF.Name(), "-out-key", keyF.Name()}
args = []string{"-version", "1", "-name", "test", "-duration", "100m", "-groups", "1,, 2 , ,,,3,4,5", "-out-crt", crtF.Name(), "-out-key", keyF.Name()}
assert.Nil(t, ca(args, ob, eb, nopw))
// test that we won't overwrite existing certificate file
ob.Reset()
eb.Reset()
args = []string{"-name", "test", "-duration", "100m", "-groups", "1,, 2 , ,,,3,4,5", "-out-crt", crtF.Name(), "-out-key", keyF.Name()}
args = []string{"-version", "1", "-name", "test", "-duration", "100m", "-groups", "1,, 2 , ,,,3,4,5", "-out-crt", crtF.Name(), "-out-key", keyF.Name()}
assert.EqualError(t, ca(args, ob, eb, nopw), "refusing to overwrite existing CA key: "+keyF.Name())
assert.Equal(t, "", ob.String())
assert.Equal(t, "", eb.String())
@ -224,7 +230,7 @@ func Test_ca(t *testing.T) {
os.Remove(keyF.Name())
ob.Reset()
eb.Reset()
args = []string{"-name", "test", "-duration", "100m", "-groups", "1,, 2 , ,,,3,4,5", "-out-crt", crtF.Name(), "-out-key", keyF.Name()}
args = []string{"-version", "1", "-name", "test", "-duration", "100m", "-groups", "1,, 2 , ,,,3,4,5", "-out-crt", crtF.Name(), "-out-key", keyF.Name()}
assert.EqualError(t, ca(args, ob, eb, nopw), "refusing to overwrite existing CA cert: "+crtF.Name())
assert.Equal(t, "", ob.String())
assert.Equal(t, "", eb.String())

View File

@ -49,6 +49,8 @@ func printCert(args []string, out io.Writer, errOut io.Writer) error {
var qrBytes []byte
part := 0
var jsonCerts []cert.Certificate
for {
c, rawCert, err = cert.UnmarshalCertificateFromPEM(rawCert)
if err != nil {
@ -56,13 +58,10 @@ func printCert(args []string, out io.Writer, errOut io.Writer) error {
}
if *pf.json {
b, _ := json.Marshal(c)
out.Write(b)
out.Write([]byte("\n"))
jsonCerts = append(jsonCerts, c)
} else {
out.Write([]byte(c.String()))
out.Write([]byte("\n"))
_, _ = out.Write([]byte(c.String()))
_, _ = out.Write([]byte("\n"))
}
if *pf.outQRPath != "" {
@ -80,6 +79,12 @@ func printCert(args []string, out io.Writer, errOut io.Writer) error {
part++
}
if *pf.json {
b, _ := json.Marshal(jsonCerts)
_, _ = out.Write(b)
_, _ = out.Write([]byte("\n"))
}
if *pf.outQRPath != "" {
b, err := qrcode.Encode(string(qrBytes), qrcode.Medium, -5)
if err != nil {

View File

@ -87,7 +87,65 @@ func Test_printCert(t *testing.T) {
assert.Nil(t, err)
assert.Equal(
t,
"NebulaCertificate {\n\tDetails {\n\t\tName: test\n\t\tIps: []\n\t\tSubnets: []\n\t\tGroups: [\n\t\t\t\"hi\"\n\t\t]\n\t\tNot before: 0001-01-01 00:00:00 +0000 UTC\n\t\tNot After: 0001-01-01 00:00:00 +0000 UTC\n\t\tIs CA: false\n\t\tIssuer: "+c.Issuer()+"\n\t\tPublic key: "+pk+"\n\t\tCurve: CURVE25519\n\t}\n\tFingerprint: "+fp+"\n\tSignature: "+sig+"\n}\nNebulaCertificate {\n\tDetails {\n\t\tName: test\n\t\tIps: []\n\t\tSubnets: []\n\t\tGroups: [\n\t\t\t\"hi\"\n\t\t]\n\t\tNot before: 0001-01-01 00:00:00 +0000 UTC\n\t\tNot After: 0001-01-01 00:00:00 +0000 UTC\n\t\tIs CA: false\n\t\tIssuer: "+c.Issuer()+"\n\t\tPublic key: "+pk+"\n\t\tCurve: CURVE25519\n\t}\n\tFingerprint: "+fp+"\n\tSignature: "+sig+"\n}\nNebulaCertificate {\n\tDetails {\n\t\tName: test\n\t\tIps: []\n\t\tSubnets: []\n\t\tGroups: [\n\t\t\t\"hi\"\n\t\t]\n\t\tNot before: 0001-01-01 00:00:00 +0000 UTC\n\t\tNot After: 0001-01-01 00:00:00 +0000 UTC\n\t\tIs CA: false\n\t\tIssuer: "+c.Issuer()+"\n\t\tPublic key: "+pk+"\n\t\tCurve: CURVE25519\n\t}\n\tFingerprint: "+fp+"\n\tSignature: "+sig+"\n}\n",
//"NebulaCertificate {\n\tDetails {\n\t\tName: test\n\t\tIps: []\n\t\tSubnets: []\n\t\tGroups: [\n\t\t\t\"hi\"\n\t\t]\n\t\tNot before: 0001-01-01 00:00:00 +0000 UTC\n\t\tNot After: 0001-01-01 00:00:00 +0000 UTC\n\t\tIs CA: false\n\t\tIssuer: "+c.Issuer()+"\n\t\tPublic key: "+pk+"\n\t\tCurve: CURVE25519\n\t}\n\tFingerprint: "+fp+"\n\tSignature: "+sig+"\n}\nNebulaCertificate {\n\tDetails {\n\t\tName: test\n\t\tIps: []\n\t\tSubnets: []\n\t\tGroups: [\n\t\t\t\"hi\"\n\t\t]\n\t\tNot before: 0001-01-01 00:00:00 +0000 UTC\n\t\tNot After: 0001-01-01 00:00:00 +0000 UTC\n\t\tIs CA: false\n\t\tIssuer: "+c.Issuer()+"\n\t\tPublic key: "+pk+"\n\t\tCurve: CURVE25519\n\t}\n\tFingerprint: "+fp+"\n\tSignature: "+sig+"\n}\nNebulaCertificate {\n\tDetails {\n\t\tName: test\n\t\tIps: []\n\t\tSubnets: []\n\t\tGroups: [\n\t\t\t\"hi\"\n\t\t]\n\t\tNot before: 0001-01-01 00:00:00 +0000 UTC\n\t\tNot After: 0001-01-01 00:00:00 +0000 UTC\n\t\tIs CA: false\n\t\tIssuer: "+c.Issuer()+"\n\t\tPublic key: "+pk+"\n\t\tCurve: CURVE25519\n\t}\n\tFingerprint: "+fp+"\n\tSignature: "+sig+"\n}\n",
`{
"details": {
"curve": "CURVE25519",
"groups": [
"hi"
],
"isCa": false,
"issuer": "`+c.Issuer()+`",
"name": "test",
"networks": [],
"notAfter": "0001-01-01T00:00:00Z",
"notBefore": "0001-01-01T00:00:00Z",
"publicKey": "`+pk+`",
"unsafeNetworks": []
},
"fingerprint": "`+fp+`",
"signature": "`+sig+`",
"version": 1
}
{
"details": {
"curve": "CURVE25519",
"groups": [
"hi"
],
"isCa": false,
"issuer": "`+c.Issuer()+`",
"name": "test",
"networks": [],
"notAfter": "0001-01-01T00:00:00Z",
"notBefore": "0001-01-01T00:00:00Z",
"publicKey": "`+pk+`",
"unsafeNetworks": []
},
"fingerprint": "`+fp+`",
"signature": "`+sig+`",
"version": 1
}
{
"details": {
"curve": "CURVE25519",
"groups": [
"hi"
],
"isCa": false,
"issuer": "`+c.Issuer()+`",
"name": "test",
"networks": [],
"notAfter": "0001-01-01T00:00:00Z",
"notBefore": "0001-01-01T00:00:00Z",
"publicKey": "`+pk+`",
"unsafeNetworks": []
},
"fingerprint": "`+fp+`",
"signature": "`+sig+`",
"version": 1
}
`,
ob.String(),
)
assert.Equal(t, "", eb.String())
@ -108,7 +166,8 @@ func Test_printCert(t *testing.T) {
assert.Nil(t, err)
assert.Equal(
t,
"{\"details\":{\"curve\":\"CURVE25519\",\"groups\":[\"hi\"],\"ips\":[],\"isCa\":false,\"issuer\":\""+c.Issuer()+"\",\"name\":\"test\",\"notAfter\":\"0001-01-01T00:00:00Z\",\"notBefore\":\"0001-01-01T00:00:00Z\",\"publicKey\":\""+pk+"\",\"subnets\":[]},\"fingerprint\":\""+fp+"\",\"signature\":\""+sig+"\"}\n{\"details\":{\"curve\":\"CURVE25519\",\"groups\":[\"hi\"],\"ips\":[],\"isCa\":false,\"issuer\":\""+c.Issuer()+"\",\"name\":\"test\",\"notAfter\":\"0001-01-01T00:00:00Z\",\"notBefore\":\"0001-01-01T00:00:00Z\",\"publicKey\":\""+pk+"\",\"subnets\":[]},\"fingerprint\":\""+fp+"\",\"signature\":\""+sig+"\"}\n{\"details\":{\"curve\":\"CURVE25519\",\"groups\":[\"hi\"],\"ips\":[],\"isCa\":false,\"issuer\":\""+c.Issuer()+"\",\"name\":\"test\",\"notAfter\":\"0001-01-01T00:00:00Z\",\"notBefore\":\"0001-01-01T00:00:00Z\",\"publicKey\":\""+pk+"\",\"subnets\":[]},\"fingerprint\":\""+fp+"\",\"signature\":\""+sig+"\"}\n",
`[{"details":{"curve":"CURVE25519","groups":["hi"],"isCa":false,"issuer":"`+c.Issuer()+`","name":"test","networks":[],"notAfter":"0001-01-01T00:00:00Z","notBefore":"0001-01-01T00:00:00Z","publicKey":"`+pk+`","unsafeNetworks":[]},"fingerprint":"`+fp+`","signature":"`+sig+`","version":1},{"details":{"curve":"CURVE25519","groups":["hi"],"isCa":false,"issuer":"`+c.Issuer()+`","name":"test","networks":[],"notAfter":"0001-01-01T00:00:00Z","notBefore":"0001-01-01T00:00:00Z","publicKey":"`+pk+`","unsafeNetworks":[]},"fingerprint":"`+fp+`","signature":"`+sig+`","version":1},{"details":{"curve":"CURVE25519","groups":["hi"],"isCa":false,"issuer":"`+c.Issuer()+`","name":"test","networks":[],"notAfter":"0001-01-01T00:00:00Z","notBefore":"0001-01-01T00:00:00Z","publicKey":"`+pk+`","unsafeNetworks":[]},"fingerprint":"`+fp+`","signature":"`+sig+`","version":1}]
`,
ob.String(),
)
assert.Equal(t, "", eb.String())

View File

@ -3,6 +3,7 @@ package main
import (
"crypto/ecdh"
"crypto/rand"
"errors"
"flag"
"fmt"
"io"
@ -18,36 +19,46 @@ import (
)
type signFlags struct {
set *flag.FlagSet
caKeyPath *string
caCertPath *string
name *string
ip *string
duration *time.Duration
inPubPath *string
outKeyPath *string
outCertPath *string
outQRPath *string
groups *string
subnets *string
p11url *string
set *flag.FlagSet
version *uint
caKeyPath *string
caCertPath *string
name *string
networks *string
unsafeNetworks *string
duration *time.Duration
inPubPath *string
outKeyPath *string
outCertPath *string
outQRPath *string
groups *string
p11url *string
// Deprecated options
ip *string
subnets *string
}
func newSignFlags() *signFlags {
sf := signFlags{set: flag.NewFlagSet("sign", flag.ContinueOnError)}
sf.set.Usage = func() {}
sf.version = sf.set.Uint("version", 0, "Optional: version of the certificate format to use, the default is to create both v1 and v2 certificates.")
sf.caKeyPath = sf.set.String("ca-key", "ca.key", "Optional: path to the signing CA key")
sf.caCertPath = sf.set.String("ca-crt", "ca.crt", "Optional: path to the signing CA cert")
sf.name = sf.set.String("name", "", "Required: name of the cert, usually a hostname")
sf.ip = sf.set.String("ip", "", "Required: ipv4 address and network in CIDR notation to assign the cert")
sf.networks = sf.set.String("networks", "", "Required: comma separated list of ip address and network in CIDR notation to assign to this cert")
sf.unsafeNetworks = sf.set.String("unsafe-networks", "", "Optional: comma separated list of ip address and network in CIDR notation. Unsafe networks this cert can route for")
sf.duration = sf.set.Duration("duration", 0, "Optional: how long the cert should be valid for. The default is 1 second before the signing cert expires. Valid time units are seconds: \"s\", minutes: \"m\", hours: \"h\"")
sf.inPubPath = sf.set.String("in-pub", "", "Optional (if out-key not set): path to read a previously generated public key")
sf.outKeyPath = sf.set.String("out-key", "", "Optional (if in-pub not set): path to write the private key to")
sf.outCertPath = sf.set.String("out-crt", "", "Optional: path to write the certificate to")
sf.outQRPath = sf.set.String("out-qr", "", "Optional: output a qr code image (png) of the certificate")
sf.groups = sf.set.String("groups", "", "Optional: comma separated list of groups")
sf.subnets = sf.set.String("subnets", "", "Optional: comma separated list of ipv4 address and network in CIDR notation. Subnets this cert can serve for")
sf.p11url = p11Flag(sf.set)
sf.ip = sf.set.String("ip", "", "Deprecated, see -networks")
sf.subnets = sf.set.String("subnets", "", "Deprecated, see -unsafe-networks")
return &sf
}
@ -71,13 +82,26 @@ func signCert(args []string, out io.Writer, errOut io.Writer, pr PasswordReader)
if err := mustFlagString("name", sf.name); err != nil {
return err
}
if err := mustFlagString("ip", sf.ip); err != nil {
return err
}
if !isP11 && *sf.inPubPath != "" && *sf.outKeyPath != "" {
return newHelpErrorf("cannot set both -in-pub and -out-key")
}
var v4Networks []netip.Prefix
var v6Networks []netip.Prefix
if *sf.networks == "" && *sf.ip != "" {
// Pull up deprecated -ip flag if needed
*sf.networks = *sf.ip
}
if len(*sf.networks) == 0 {
return newHelpErrorf("-networks is required")
}
version := cert.Version(*sf.version)
if version != 0 && version != cert.Version1 && version != cert.Version2 {
return newHelpErrorf("-version must be either %v or %v", cert.Version1, cert.Version2)
}
var curve cert.Curve
var caKey []byte
@ -91,14 +115,14 @@ func signCert(args []string, out io.Writer, errOut io.Writer, pr PasswordReader)
// naively attempt to decode the private key as though it is not encrypted
caKey, _, curve, err = cert.UnmarshalSigningPrivateKeyFromPEM(rawCAKey)
if err == cert.ErrPrivateKeyEncrypted {
if errors.Is(err, cert.ErrPrivateKeyEncrypted) {
// ask for a passphrase until we get one
var passphrase []byte
for i := 0; i < 5; i++ {
out.Write([]byte("Enter passphrase: "))
passphrase, err = pr.ReadPassword()
if err == ErrNoTerminal {
if errors.Is(err, ErrNoTerminal) {
return fmt.Errorf("ca-key is encrypted and must be decrypted interactively")
} else if err != nil {
return fmt.Errorf("error reading password: %s", err)
@ -146,12 +170,49 @@ func signCert(args []string, out io.Writer, errOut io.Writer, pr PasswordReader)
*sf.duration = time.Until(caCert.NotAfter()) - time.Second*1
}
network, err := netip.ParsePrefix(*sf.ip)
if err != nil {
return newHelpErrorf("invalid ip definition: %s", *sf.ip)
if *sf.networks != "" {
for _, rs := range strings.Split(*sf.networks, ",") {
//TODO: error on duplicates? Mainly only addr matters, having two of the same addr in the same or different prefix space is strange
rs := strings.Trim(rs, " ")
if rs != "" {
n, err := netip.ParsePrefix(rs)
if err != nil {
return newHelpErrorf("invalid -networks definition: %s", rs)
}
if n.Addr().Is4() {
v4Networks = append(v4Networks, n)
} else {
v6Networks = append(v6Networks, n)
}
}
}
}
if !network.Addr().Is4() {
return newHelpErrorf("invalid ip definition: can only be ipv4, have %s", *sf.ip)
var v4UnsafeNetworks []netip.Prefix
var v6UnsafeNetworks []netip.Prefix
if *sf.unsafeNetworks == "" && *sf.subnets != "" {
// Pull up deprecated -subnets flag if needed
*sf.unsafeNetworks = *sf.subnets
}
if *sf.unsafeNetworks != "" {
//TODO: error on duplicates?
for _, rs := range strings.Split(*sf.unsafeNetworks, ",") {
rs := strings.Trim(rs, " ")
if rs != "" {
n, err := netip.ParsePrefix(rs)
if err != nil {
return newHelpErrorf("invalid -unsafe-networks definition: %s", rs)
}
if n.Addr().Is4() {
v4UnsafeNetworks = append(v4UnsafeNetworks, n)
} else {
v6UnsafeNetworks = append(v6UnsafeNetworks, n)
}
}
}
}
var groups []string
@ -164,23 +225,6 @@ func signCert(args []string, out io.Writer, errOut io.Writer, pr PasswordReader)
}
}
var subnets []netip.Prefix
if *sf.subnets != "" {
for _, rs := range strings.Split(*sf.subnets, ",") {
rs := strings.Trim(rs, " ")
if rs != "" {
s, err := netip.ParsePrefix(rs)
if err != nil {
return newHelpErrorf("invalid subnet definition: %s", rs)
}
if !s.Addr().Is4() {
return newHelpErrorf("invalid subnet definition: can only be ipv4, have %s", rs)
}
subnets = append(subnets, s)
}
}
}
var pub, rawPriv []byte
var p11Client *pkclient.PKClient
@ -218,19 +262,6 @@ func signCert(args []string, out io.Writer, errOut io.Writer, pr PasswordReader)
pub, rawPriv = newKeypair(curve)
}
t := &cert.TBSCertificate{
Version: cert.Version1,
Name: *sf.name,
Networks: []netip.Prefix{network},
Groups: groups,
UnsafeNetworks: subnets,
NotBefore: time.Now(),
NotAfter: time.Now().Add(*sf.duration),
PublicKey: pub,
IsCA: false,
Curve: curve,
}
if *sf.outKeyPath == "" {
*sf.outKeyPath = *sf.name + ".key"
}
@ -243,18 +274,85 @@ func signCert(args []string, out io.Writer, errOut io.Writer, pr PasswordReader)
return fmt.Errorf("refusing to overwrite existing cert: %s", *sf.outCertPath)
}
var c cert.Certificate
var crts []cert.Certificate
if p11Client == nil {
c, err = t.Sign(caCert, curve, caKey)
if err != nil {
return fmt.Errorf("error while signing: %w", err)
notBefore := time.Now()
notAfter := notBefore.Add(*sf.duration)
if version == 0 || version == cert.Version1 {
// Make sure we at least have an ip
if len(v4Networks) != 1 {
return newHelpErrorf("invalid -networks definition: v1 certificates can only have a single ipv4 address")
}
} else {
c, err = t.SignPkcs11(caCert, curve, p11Client)
if err != nil {
return fmt.Errorf("error while signing with PKCS#11: %w", err)
if version == cert.Version1 {
// If we are asked to mint a v1 certificate only then we cant just ignore any v6 addresses
if len(v6Networks) > 0 {
return newHelpErrorf("invalid -networks definition: v1 certificates can only be ipv4")
}
if len(v6UnsafeNetworks) > 0 {
return newHelpErrorf("invalid -unsafe-networks definition: v1 certificates can only be ipv4")
}
}
t := &cert.TBSCertificate{
Version: cert.Version1,
Name: *sf.name,
Networks: []netip.Prefix{v4Networks[0]},
Groups: groups,
UnsafeNetworks: v4UnsafeNetworks,
NotBefore: notBefore,
NotAfter: notAfter,
PublicKey: pub,
IsCA: false,
Curve: curve,
}
var nc cert.Certificate
if p11Client == nil {
nc, err = t.Sign(caCert, curve, caKey)
if err != nil {
return fmt.Errorf("error while signing: %w", err)
}
} else {
nc, err = t.SignWith(caCert, curve, p11Client.SignASN1)
if err != nil {
return fmt.Errorf("error while signing with PKCS#11: %w", err)
}
}
crts = append(crts, nc)
}
if version == 0 || version == cert.Version2 {
t := &cert.TBSCertificate{
Version: cert.Version2,
Name: *sf.name,
Networks: append(v4Networks, v6Networks...),
Groups: groups,
UnsafeNetworks: append(v4UnsafeNetworks, v6UnsafeNetworks...),
NotBefore: notBefore,
NotAfter: notAfter,
PublicKey: pub,
IsCA: false,
Curve: curve,
}
var nc cert.Certificate
if p11Client == nil {
nc, err = t.Sign(caCert, curve, caKey)
if err != nil {
return fmt.Errorf("error while signing: %w", err)
}
} else {
nc, err = t.SignWith(caCert, curve, p11Client.SignASN1)
if err != nil {
return fmt.Errorf("error while signing with PKCS#11: %w", err)
}
}
crts = append(crts, nc)
}
if !isP11 && *sf.inPubPath == "" {
@ -268,9 +366,13 @@ func signCert(args []string, out io.Writer, errOut io.Writer, pr PasswordReader)
}
}
b, err := c.MarshalPEM()
if err != nil {
return fmt.Errorf("error while marshalling certificate: %s", err)
var b []byte
for _, c := range crts {
sb, err := c.MarshalPEM()
if err != nil {
return fmt.Errorf("error while marshalling certificate: %s", err)
}
b = append(b, sb...)
}
err = os.WriteFile(*sf.outCertPath, b, 0600)

View File

@ -39,9 +39,11 @@ func Test_signHelp(t *testing.T) {
" -in-pub string\n"+
" \tOptional (if out-key not set): path to read a previously generated public key\n"+
" -ip string\n"+
" \tRequired: ipv4 address and network in CIDR notation to assign the cert\n"+
" \tDeprecated, see -networks\n"+
" -name string\n"+
" \tRequired: name of the cert, usually a hostname\n"+
" -networks string\n"+
" \tRequired: comma separated list of ip address and network in CIDR notation to assign to this cert\n"+
" -out-crt string\n"+
" \tOptional: path to write the certificate to\n"+
" -out-key string\n"+
@ -50,7 +52,11 @@ func Test_signHelp(t *testing.T) {
" \tOptional: output a qr code image (png) of the certificate\n"+
optionalPkcs11String(" -pkcs11 string\n \tOptional: PKCS#11 URI to an existing private key\n")+
" -subnets string\n"+
" \tOptional: comma separated list of ipv4 address and network in CIDR notation. Subnets this cert can serve for\n",
" \tDeprecated, see -unsafe-networks\n"+
" -unsafe-networks string\n"+
" \tOptional: comma separated list of ip address and network in CIDR notation. Unsafe networks this cert can route for\n"+
" -version uint\n"+
" \tOptional: version of the certificate format to use, the default is to create both v1 and v2 certificates.\n",
ob.String(),
)
}
@ -77,20 +83,20 @@ func Test_signCert(t *testing.T) {
// required args
assertHelpError(t, signCert(
[]string{"-ca-crt", "./nope", "-ca-key", "./nope", "-ip", "1.1.1.1/24", "-out-key", "nope", "-out-crt", "nope"}, ob, eb, nopw,
[]string{"-version", "1", "-ca-crt", "./nope", "-ca-key", "./nope", "-ip", "1.1.1.1/24", "-out-key", "nope", "-out-crt", "nope"}, ob, eb, nopw,
), "-name is required")
assert.Empty(t, ob.String())
assert.Empty(t, eb.String())
assertHelpError(t, signCert(
[]string{"-ca-crt", "./nope", "-ca-key", "./nope", "-name", "test", "-out-key", "nope", "-out-crt", "nope"}, ob, eb, nopw,
), "-ip is required")
[]string{"-version", "1", "-ca-crt", "./nope", "-ca-key", "./nope", "-name", "test", "-out-key", "nope", "-out-crt", "nope"}, ob, eb, nopw,
), "-networks is required")
assert.Empty(t, ob.String())
assert.Empty(t, eb.String())
// cannot set -in-pub and -out-key
assertHelpError(t, signCert(
[]string{"-ca-crt", "./nope", "-ca-key", "./nope", "-name", "test", "-in-pub", "nope", "-ip", "1.1.1.1/24", "-out-crt", "nope", "-out-key", "nope"}, ob, eb, nopw,
[]string{"-version", "1", "-ca-crt", "./nope", "-ca-key", "./nope", "-name", "test", "-in-pub", "nope", "-ip", "1.1.1.1/24", "-out-crt", "nope", "-out-key", "nope"}, ob, eb, nopw,
), "cannot set both -in-pub and -out-key")
assert.Empty(t, ob.String())
assert.Empty(t, eb.String())
@ -98,7 +104,7 @@ func Test_signCert(t *testing.T) {
// failed to read key
ob.Reset()
eb.Reset()
args := []string{"-ca-crt", "./nope", "-ca-key", "./nope", "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", "nope", "-out-key", "nope", "-duration", "100m"}
args := []string{"-version", "1", "-ca-crt", "./nope", "-ca-key", "./nope", "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", "nope", "-out-key", "nope", "-duration", "100m"}
assert.EqualError(t, signCert(args, ob, eb, nopw), "error while reading ca-key: open ./nope: "+NoSuchFileError)
// failed to unmarshal key
@ -108,7 +114,7 @@ func Test_signCert(t *testing.T) {
assert.Nil(t, err)
defer os.Remove(caKeyF.Name())
args = []string{"-ca-crt", "./nope", "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", "nope", "-out-key", "nope", "-duration", "100m"}
args = []string{"-version", "1", "-ca-crt", "./nope", "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", "nope", "-out-key", "nope", "-duration", "100m"}
assert.EqualError(t, signCert(args, ob, eb, nopw), "error while parsing ca-key: input did not contain a valid PEM encoded block")
assert.Empty(t, ob.String())
assert.Empty(t, eb.String())
@ -120,7 +126,7 @@ func Test_signCert(t *testing.T) {
caKeyF.Write(cert.MarshalSigningPrivateKeyToPEM(cert.Curve_CURVE25519, caPriv))
// failed to read cert
args = []string{"-ca-crt", "./nope", "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", "nope", "-out-key", "nope", "-duration", "100m"}
args = []string{"-version", "1", "-ca-crt", "./nope", "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", "nope", "-out-key", "nope", "-duration", "100m"}
assert.EqualError(t, signCert(args, ob, eb, nopw), "error while reading ca-crt: open ./nope: "+NoSuchFileError)
assert.Empty(t, ob.String())
assert.Empty(t, eb.String())
@ -132,7 +138,7 @@ func Test_signCert(t *testing.T) {
assert.Nil(t, err)
defer os.Remove(caCrtF.Name())
args = []string{"-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", "nope", "-out-key", "nope", "-duration", "100m"}
args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", "nope", "-out-key", "nope", "-duration", "100m"}
assert.EqualError(t, signCert(args, ob, eb, nopw), "error while parsing ca-crt: input did not contain a valid PEM encoded block")
assert.Empty(t, ob.String())
assert.Empty(t, eb.String())
@ -143,7 +149,7 @@ func Test_signCert(t *testing.T) {
caCrtF.Write(b)
// failed to read pub
args = []string{"-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", "nope", "-in-pub", "./nope", "-duration", "100m"}
args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", "nope", "-in-pub", "./nope", "-duration", "100m"}
assert.EqualError(t, signCert(args, ob, eb, nopw), "error while reading in-pub: open ./nope: "+NoSuchFileError)
assert.Empty(t, ob.String())
assert.Empty(t, eb.String())
@ -155,7 +161,7 @@ func Test_signCert(t *testing.T) {
assert.Nil(t, err)
defer os.Remove(inPubF.Name())
args = []string{"-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", "nope", "-in-pub", inPubF.Name(), "-duration", "100m"}
args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", "nope", "-in-pub", inPubF.Name(), "-duration", "100m"}
assert.EqualError(t, signCert(args, ob, eb, nopw), "error while parsing in-pub: input did not contain a valid PEM encoded block")
assert.Empty(t, ob.String())
assert.Empty(t, eb.String())
@ -169,30 +175,37 @@ func Test_signCert(t *testing.T) {
// bad ip cidr
ob.Reset()
eb.Reset()
args = []string{"-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "a1.1.1.1/24", "-out-crt", "nope", "-out-key", "nope", "-duration", "100m"}
assertHelpError(t, signCert(args, ob, eb, nopw), "invalid ip definition: a1.1.1.1/24")
args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "a1.1.1.1/24", "-out-crt", "nope", "-out-key", "nope", "-duration", "100m"}
assertHelpError(t, signCert(args, ob, eb, nopw), "invalid -networks definition: a1.1.1.1/24")
assert.Empty(t, ob.String())
assert.Empty(t, eb.String())
ob.Reset()
eb.Reset()
args = []string{"-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "100::100/100", "-out-crt", "nope", "-out-key", "nope", "-duration", "100m"}
assertHelpError(t, signCert(args, ob, eb, nopw), "invalid ip definition: can only be ipv4, have 100::100/100")
args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "100::100/100", "-out-crt", "nope", "-out-key", "nope", "-duration", "100m"}
assertHelpError(t, signCert(args, ob, eb, nopw), "invalid -networks definition: v1 certificates can only have a single ipv4 address")
assert.Empty(t, ob.String())
assert.Empty(t, eb.String())
ob.Reset()
eb.Reset()
args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24,1.1.1.2/24", "-out-crt", "nope", "-out-key", "nope", "-duration", "100m"}
assertHelpError(t, signCert(args, ob, eb, nopw), "invalid -networks definition: v1 certificates can only have a single ipv4 address")
assert.Empty(t, ob.String())
assert.Empty(t, eb.String())
// bad subnet cidr
ob.Reset()
eb.Reset()
args = []string{"-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", "nope", "-out-key", "nope", "-duration", "100m", "-subnets", "a"}
assertHelpError(t, signCert(args, ob, eb, nopw), "invalid subnet definition: a")
args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", "nope", "-out-key", "nope", "-duration", "100m", "-subnets", "a"}
assertHelpError(t, signCert(args, ob, eb, nopw), "invalid -unsafe-networks definition: a")
assert.Empty(t, ob.String())
assert.Empty(t, eb.String())
ob.Reset()
eb.Reset()
args = []string{"-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", "nope", "-out-key", "nope", "-duration", "100m", "-subnets", "100::100/100"}
assertHelpError(t, signCert(args, ob, eb, nopw), "invalid subnet definition: can only be ipv4, have 100::100/100")
args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", "nope", "-out-key", "nope", "-duration", "100m", "-subnets", "100::100/100"}
assertHelpError(t, signCert(args, ob, eb, nopw), "invalid -unsafe-networks definition: v1 certificates can only be ipv4")
assert.Empty(t, ob.String())
assert.Empty(t, eb.String())
@ -205,7 +218,7 @@ func Test_signCert(t *testing.T) {
ob.Reset()
eb.Reset()
args = []string{"-ca-crt", caCrtF.Name(), "-ca-key", caKeyF2.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", "nope", "-out-key", "nope", "-duration", "100m", "-subnets", "a"}
args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF2.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", "nope", "-out-key", "nope", "-duration", "100m", "-subnets", "a"}
assert.EqualError(t, signCert(args, ob, eb, nopw), "refusing to sign, root certificate does not match private key")
assert.Empty(t, ob.String())
assert.Empty(t, eb.String())
@ -213,7 +226,7 @@ func Test_signCert(t *testing.T) {
// failed key write
ob.Reset()
eb.Reset()
args = []string{"-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", "/do/not/write/pleasecrt", "-out-key", "/do/not/write/pleasekey", "-duration", "100m", "-subnets", "10.1.1.1/32"}
args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", "/do/not/write/pleasecrt", "-out-key", "/do/not/write/pleasekey", "-duration", "100m", "-subnets", "10.1.1.1/32"}
assert.EqualError(t, signCert(args, ob, eb, nopw), "error while writing out-key: open /do/not/write/pleasekey: "+NoSuchDirError)
assert.Empty(t, ob.String())
assert.Empty(t, eb.String())
@ -226,7 +239,7 @@ func Test_signCert(t *testing.T) {
// failed cert write
ob.Reset()
eb.Reset()
args = []string{"-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", "/do/not/write/pleasecrt", "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32"}
args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", "/do/not/write/pleasecrt", "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32"}
assert.EqualError(t, signCert(args, ob, eb, nopw), "error while writing out-crt: open /do/not/write/pleasecrt: "+NoSuchDirError)
assert.Empty(t, ob.String())
assert.Empty(t, eb.String())
@ -240,7 +253,7 @@ func Test_signCert(t *testing.T) {
// test proper cert with removed empty groups and subnets
ob.Reset()
eb.Reset()
args = []string{"-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32, , 10.2.2.2/32 , , ,, 10.5.5.5/32", "-groups", "1,, 2 , ,,,3,4,5"}
args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32, , 10.2.2.2/32 , , ,, 10.5.5.5/32", "-groups", "1,, 2 , ,,,3,4,5"}
assert.Nil(t, signCert(args, ob, eb, nopw))
assert.Empty(t, ob.String())
assert.Empty(t, eb.String())
@ -283,7 +296,7 @@ func Test_signCert(t *testing.T) {
os.Remove(crtF.Name())
ob.Reset()
eb.Reset()
args = []string{"-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-in-pub", inPubF.Name(), "-duration", "100m", "-groups", "1"}
args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-in-pub", inPubF.Name(), "-duration", "100m", "-groups", "1"}
assert.Nil(t, signCert(args, ob, eb, nopw))
assert.Empty(t, ob.String())
assert.Empty(t, eb.String())
@ -300,7 +313,7 @@ func Test_signCert(t *testing.T) {
eb.Reset()
os.Remove(keyF.Name())
os.Remove(crtF.Name())
args = []string{"-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "1000m", "-subnets", "10.1.1.1/32, , 10.2.2.2/32 , , ,, 10.5.5.5/32", "-groups", "1,, 2 , ,,,3,4,5"}
args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "1000m", "-subnets", "10.1.1.1/32, , 10.2.2.2/32 , , ,, 10.5.5.5/32", "-groups", "1,, 2 , ,,,3,4,5"}
assert.EqualError(t, signCert(args, ob, eb, nopw), "error while signing: certificate expires after signing certificate")
assert.Empty(t, ob.String())
assert.Empty(t, eb.String())
@ -308,14 +321,14 @@ func Test_signCert(t *testing.T) {
// create valid cert/key for overwrite tests
os.Remove(keyF.Name())
os.Remove(crtF.Name())
args = []string{"-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32, , 10.2.2.2/32 , , ,, 10.5.5.5/32", "-groups", "1,, 2 , ,,,3,4,5"}
args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32, , 10.2.2.2/32 , , ,, 10.5.5.5/32", "-groups", "1,, 2 , ,,,3,4,5"}
assert.Nil(t, signCert(args, ob, eb, nopw))
// test that we won't overwrite existing key file
os.Remove(crtF.Name())
ob.Reset()
eb.Reset()
args = []string{"-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32, , 10.2.2.2/32 , , ,, 10.5.5.5/32", "-groups", "1,, 2 , ,,,3,4,5"}
args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32, , 10.2.2.2/32 , , ,, 10.5.5.5/32", "-groups", "1,, 2 , ,,,3,4,5"}
assert.EqualError(t, signCert(args, ob, eb, nopw), "refusing to overwrite existing key: "+keyF.Name())
assert.Empty(t, ob.String())
assert.Empty(t, eb.String())
@ -323,14 +336,14 @@ func Test_signCert(t *testing.T) {
// create valid cert/key for overwrite tests
os.Remove(keyF.Name())
os.Remove(crtF.Name())
args = []string{"-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32, , 10.2.2.2/32 , , ,, 10.5.5.5/32", "-groups", "1,, 2 , ,,,3,4,5"}
args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32, , 10.2.2.2/32 , , ,, 10.5.5.5/32", "-groups", "1,, 2 , ,,,3,4,5"}
assert.Nil(t, signCert(args, ob, eb, nopw))
// test that we won't overwrite existing certificate file
os.Remove(keyF.Name())
ob.Reset()
eb.Reset()
args = []string{"-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32, , 10.2.2.2/32 , , ,, 10.5.5.5/32", "-groups", "1,, 2 , ,,,3,4,5"}
args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32, , 10.2.2.2/32 , , ,, 10.5.5.5/32", "-groups", "1,, 2 , ,,,3,4,5"}
assert.EqualError(t, signCert(args, ob, eb, nopw), "refusing to overwrite existing cert: "+crtF.Name())
assert.Empty(t, ob.String())
assert.Empty(t, eb.String())
@ -362,7 +375,7 @@ func Test_signCert(t *testing.T) {
caCrtF.Write(b)
// test with the proper password
args = []string{"-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32, , 10.2.2.2/32 , , ,, 10.5.5.5/32", "-groups", "1,, 2 , ,,,3,4,5"}
args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32, , 10.2.2.2/32 , , ,, 10.5.5.5/32", "-groups", "1,, 2 , ,,,3,4,5"}
assert.Nil(t, signCert(args, ob, eb, testpw))
assert.Equal(t, "Enter passphrase: ", ob.String())
assert.Empty(t, eb.String())
@ -372,7 +385,7 @@ func Test_signCert(t *testing.T) {
eb.Reset()
testpw.password = []byte("invalid password")
args = []string{"-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32, , 10.2.2.2/32 , , ,, 10.5.5.5/32", "-groups", "1,, 2 , ,,,3,4,5"}
args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32, , 10.2.2.2/32 , , ,, 10.5.5.5/32", "-groups", "1,, 2 , ,,,3,4,5"}
assert.Error(t, signCert(args, ob, eb, testpw))
assert.Equal(t, "Enter passphrase: ", ob.String())
assert.Empty(t, eb.String())
@ -381,7 +394,7 @@ func Test_signCert(t *testing.T) {
ob.Reset()
eb.Reset()
args = []string{"-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32, , 10.2.2.2/32 , , ,, 10.5.5.5/32", "-groups", "1,, 2 , ,,,3,4,5"}
args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32, , 10.2.2.2/32 , , ,, 10.5.5.5/32", "-groups", "1,, 2 , ,,,3,4,5"}
assert.Error(t, signCert(args, ob, eb, nopw))
// normally the user hitting enter on the prompt would add newlines between these
assert.Equal(t, "Enter passphrase: Enter passphrase: Enter passphrase: Enter passphrase: Enter passphrase: ", ob.String())
@ -391,7 +404,7 @@ func Test_signCert(t *testing.T) {
ob.Reset()
eb.Reset()
args = []string{"-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32, , 10.2.2.2/32 , , ,, 10.5.5.5/32", "-groups", "1,, 2 , ,,,3,4,5"}
args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32, , 10.2.2.2/32 , , ,, 10.5.5.5/32", "-groups", "1,, 2 , ,,,3,4,5"}
assert.Error(t, signCert(args, ob, eb, errpw))
assert.Equal(t, "Enter passphrase: ", ob.String())
assert.Empty(t, eb.String())

View File

@ -183,7 +183,7 @@ func (n *connectionManager) doTrafficCheck(localIndex uint32, p, nb, out []byte,
case deleteTunnel:
if n.hostMap.DeleteHostInfo(hostinfo) {
// Only clearing the lighthouse cache if this is the last hostinfo for this vpn ip in the hostmap
n.intf.lightHouse.DeleteVpnIp(hostinfo.vpnIp)
n.intf.lightHouse.DeleteVpnAddrs(hostinfo.vpnAddrs)
}
case closeTunnel:
@ -221,7 +221,7 @@ func (n *connectionManager) migrateRelayUsed(oldhostinfo, newhostinfo *HostInfo)
relayFor := oldhostinfo.relayState.CopyAllRelayFor()
for _, r := range relayFor {
existing, ok := newhostinfo.relayState.QueryRelayForByIp(r.PeerIp)
existing, ok := newhostinfo.relayState.QueryRelayForByIp(r.PeerAddr)
var index uint32
var relayFrom netip.Addr
@ -235,11 +235,11 @@ func (n *connectionManager) migrateRelayUsed(oldhostinfo, newhostinfo *HostInfo)
index = existing.LocalIndex
switch r.Type {
case TerminalType:
relayFrom = n.intf.myVpnNet.Addr()
relayTo = existing.PeerIp
relayFrom = n.intf.myVpnAddrs[0]
relayTo = existing.PeerAddr
case ForwardingType:
relayFrom = existing.PeerIp
relayTo = newhostinfo.vpnIp
relayFrom = existing.PeerAddr
relayTo = newhostinfo.vpnAddrs[0]
default:
// should never happen
}
@ -253,45 +253,64 @@ func (n *connectionManager) migrateRelayUsed(oldhostinfo, newhostinfo *HostInfo)
n.relayUsedLock.RUnlock()
// The relay doesn't exist at all; create some relay state and send the request.
var err error
index, err = AddRelay(n.l, newhostinfo, n.hostMap, r.PeerIp, nil, r.Type, Requested)
index, err = AddRelay(n.l, newhostinfo, n.hostMap, r.PeerAddr, nil, r.Type, Requested)
if err != nil {
n.l.WithError(err).Error("failed to migrate relay to new hostinfo")
continue
}
switch r.Type {
case TerminalType:
relayFrom = n.intf.myVpnNet.Addr()
relayTo = r.PeerIp
relayFrom = n.intf.myVpnAddrs[0]
relayTo = r.PeerAddr
case ForwardingType:
relayFrom = r.PeerIp
relayTo = newhostinfo.vpnIp
relayFrom = r.PeerAddr
relayTo = newhostinfo.vpnAddrs[0]
default:
// should never happen
}
}
//TODO: IPV6-WORK
relayFromB := relayFrom.As4()
relayToB := relayTo.As4()
// Send a CreateRelayRequest to the peer.
req := NebulaControl{
Type: NebulaControl_CreateRelayRequest,
InitiatorRelayIndex: index,
RelayFromIp: binary.BigEndian.Uint32(relayFromB[:]),
RelayToIp: binary.BigEndian.Uint32(relayToB[:]),
}
switch newhostinfo.GetCert().Certificate.Version() {
case cert.Version1:
if !relayFrom.Is4() {
n.l.Error("can not migrate v1 relay with a v6 network because the relay is not running a current nebula version")
continue
}
if !relayTo.Is4() {
n.l.Error("can not migrate v1 relay with a v6 remote network because the relay is not running a current nebula version")
continue
}
b := relayFrom.As4()
req.OldRelayFromAddr = binary.BigEndian.Uint32(b[:])
b = relayTo.As4()
req.OldRelayToAddr = binary.BigEndian.Uint32(b[:])
case cert.Version2:
req.RelayFromAddr = netAddrToProtoAddr(relayFrom)
req.RelayToAddr = netAddrToProtoAddr(relayTo)
default:
newhostinfo.logger(n.l).Error("Unknown certificate version found while attempting to migrate relay")
continue
}
msg, err := req.Marshal()
if err != nil {
n.l.WithError(err).Error("failed to marshal Control message to migrate relay")
} else {
n.intf.SendMessageToHostInfo(header.Control, 0, newhostinfo, msg, make([]byte, 12), make([]byte, mtu))
n.l.WithFields(logrus.Fields{
"relayFrom": req.RelayFromIp,
"relayTo": req.RelayToIp,
"relayFrom": req.RelayFromAddr,
"relayTo": req.RelayToAddr,
"initiatorRelayIndex": req.InitiatorRelayIndex,
"responderRelayIndex": req.ResponderRelayIndex,
"vpnIp": newhostinfo.vpnIp}).
"vpnAddrs": newhostinfo.vpnAddrs}).
Info("send CreateRelayRequest")
}
}
@ -313,7 +332,7 @@ func (n *connectionManager) makeTrafficDecision(localIndex uint32, now time.Time
return closeTunnel, hostinfo, nil
}
primary := n.hostMap.Hosts[hostinfo.vpnIp]
primary := n.hostMap.Hosts[hostinfo.vpnAddrs[0]]
mainHostInfo := true
if primary != nil && primary != hostinfo {
mainHostInfo = false
@ -407,21 +426,24 @@ func (n *connectionManager) shouldSwapPrimary(current, primary *HostInfo) bool {
// If we are here then we have multiple tunnels for a host pair and neither side believes the same tunnel is primary.
// Let's sort this out.
if current.vpnIp.Compare(n.intf.myVpnNet.Addr()) < 0 {
//TODO: current.vpnIp should become an array of vpnIps
if current.vpnAddrs[0].Compare(n.intf.myVpnAddrs[0]) < 0 {
// Only one side should flip primary because if both flip then we may never resolve to a single tunnel.
// vpn ip is static across all tunnels for this host pair so lets use that to determine who is flipping.
// The remotes vpn ip is lower than mine. I will not flip.
return false
}
certState := n.intf.pki.GetCertState()
return bytes.Equal(current.ConnectionState.myCert.Signature(), certState.Certificate.Signature())
//TODO: we should favor v2 over v1 certificates if configured to send them
crt := n.intf.pki.getCertificate(current.ConnectionState.myCert.Version())
return bytes.Equal(current.ConnectionState.myCert.Signature(), crt.Signature())
}
func (n *connectionManager) swapPrimary(current, primary *HostInfo) {
n.hostMap.Lock()
// Make sure the primary is still the same after the write lock. This avoids a race with a rehandshake.
if n.hostMap.Hosts[current.vpnIp] == primary {
if n.hostMap.Hosts[current.vpnAddrs[0]] == primary {
n.hostMap.unlockedMakePrimary(current)
}
n.hostMap.Unlock()
@ -473,14 +495,16 @@ func (n *connectionManager) sendPunch(hostinfo *HostInfo) {
}
func (n *connectionManager) tryRehandshake(hostinfo *HostInfo) {
certState := n.intf.pki.GetCertState()
if bytes.Equal(hostinfo.ConnectionState.myCert.Signature(), certState.Certificate.Signature()) {
crt := n.intf.pki.getCertificate(hostinfo.ConnectionState.myCert.Version())
if bytes.Equal(hostinfo.ConnectionState.myCert.Signature(), crt.Signature()) {
return
}
n.l.WithField("vpnIp", hostinfo.vpnIp).
//TODO: we should favor v2 over v1 certificates if configured to send them
n.l.WithField("vpnAddrs", hostinfo.vpnAddrs).
WithField("reason", "local certificate is not current").
Info("Re-handshaking with remote")
n.intf.handshakeManager.StartHandshake(hostinfo.vpnIp, nil)
n.intf.handshakeManager.StartHandshake(hostinfo.vpnAddrs[0], nil)
}

View File

@ -34,20 +34,19 @@ func newTestLighthouse() *LightHouse {
func Test_NewConnectionManagerTest(t *testing.T) {
l := test.NewLogger()
//_, tuncidr, _ := net.ParseCIDR("1.1.1.1/24")
vpncidr := netip.MustParsePrefix("172.1.1.1/24")
localrange := netip.MustParsePrefix("10.1.1.1/24")
vpnIp := netip.MustParseAddr("172.1.1.2")
preferredRanges := []netip.Prefix{localrange}
// Very incomplete mock objects
hostMap := newHostMap(l, vpncidr)
hostMap := newHostMap(l)
hostMap.preferredRanges.Store(&preferredRanges)
cs := &CertState{
RawCertificate: []byte{},
PrivateKey: []byte{},
Certificate: &dummyCert{},
RawCertificateNoKey: []byte{},
defaultVersion: cert.Version1,
privateKey: []byte{},
v1Cert: &dummyCert{version: cert.Version1},
v1HandshakeBytes: []byte{},
}
lh := newTestLighthouse()
@ -74,12 +73,12 @@ func Test_NewConnectionManagerTest(t *testing.T) {
// Add an ip we have established a connection w/ to hostmap
hostinfo := &HostInfo{
vpnIp: vpnIp,
vpnAddrs: []netip.Addr{vpnIp},
localIndexId: 1099,
remoteIndexId: 9901,
}
hostinfo.ConnectionState = &ConnectionState{
myCert: &dummyCert{},
myCert: &dummyCert{version: cert.Version1},
H: &noise.HandshakeState{},
}
nc.hostMap.unlockedAddHostInfo(hostinfo, ifce)
@ -88,7 +87,7 @@ func Test_NewConnectionManagerTest(t *testing.T) {
nc.Out(hostinfo.localIndexId)
nc.In(hostinfo.localIndexId)
assert.NotContains(t, nc.pendingDeletion, hostinfo.localIndexId)
assert.Contains(t, nc.hostMap.Hosts, hostinfo.vpnIp)
assert.Contains(t, nc.hostMap.Hosts, hostinfo.vpnAddrs[0])
assert.Contains(t, nc.hostMap.Indexes, hostinfo.localIndexId)
assert.Contains(t, nc.out, hostinfo.localIndexId)
@ -105,32 +104,31 @@ func Test_NewConnectionManagerTest(t *testing.T) {
assert.NotContains(t, nc.out, hostinfo.localIndexId)
assert.NotContains(t, nc.in, hostinfo.localIndexId)
assert.Contains(t, nc.hostMap.Indexes, hostinfo.localIndexId)
assert.Contains(t, nc.hostMap.Hosts, hostinfo.vpnIp)
assert.Contains(t, nc.hostMap.Hosts, hostinfo.vpnAddrs[0])
// Do a final traffic check tick, the host should now be removed
nc.doTrafficCheck(hostinfo.localIndexId, p, nb, out, time.Now())
assert.NotContains(t, nc.pendingDeletion, hostinfo.localIndexId)
assert.NotContains(t, nc.hostMap.Hosts, hostinfo.vpnIp)
assert.NotContains(t, nc.hostMap.Hosts, hostinfo.vpnAddrs[0])
assert.NotContains(t, nc.hostMap.Indexes, hostinfo.localIndexId)
}
func Test_NewConnectionManagerTest2(t *testing.T) {
l := test.NewLogger()
//_, tuncidr, _ := net.ParseCIDR("1.1.1.1/24")
vpncidr := netip.MustParsePrefix("172.1.1.1/24")
localrange := netip.MustParsePrefix("10.1.1.1/24")
vpnIp := netip.MustParseAddr("172.1.1.2")
preferredRanges := []netip.Prefix{localrange}
// Very incomplete mock objects
hostMap := newHostMap(l, vpncidr)
hostMap := newHostMap(l)
hostMap.preferredRanges.Store(&preferredRanges)
cs := &CertState{
RawCertificate: []byte{},
PrivateKey: []byte{},
Certificate: &dummyCert{},
RawCertificateNoKey: []byte{},
defaultVersion: cert.Version1,
privateKey: []byte{},
v1Cert: &dummyCert{version: cert.Version1},
v1HandshakeBytes: []byte{},
}
lh := newTestLighthouse()
@ -157,12 +155,12 @@ func Test_NewConnectionManagerTest2(t *testing.T) {
// Add an ip we have established a connection w/ to hostmap
hostinfo := &HostInfo{
vpnIp: vpnIp,
vpnAddrs: []netip.Addr{vpnIp},
localIndexId: 1099,
remoteIndexId: 9901,
}
hostinfo.ConnectionState = &ConnectionState{
myCert: &dummyCert{},
myCert: &dummyCert{version: cert.Version1},
H: &noise.HandshakeState{},
}
nc.hostMap.unlockedAddHostInfo(hostinfo, ifce)
@ -170,8 +168,8 @@ func Test_NewConnectionManagerTest2(t *testing.T) {
// We saw traffic out to vpnIp
nc.Out(hostinfo.localIndexId)
nc.In(hostinfo.localIndexId)
assert.NotContains(t, nc.pendingDeletion, hostinfo.vpnIp)
assert.Contains(t, nc.hostMap.Hosts, hostinfo.vpnIp)
assert.NotContains(t, nc.pendingDeletion, hostinfo.vpnAddrs[0])
assert.Contains(t, nc.hostMap.Hosts, hostinfo.vpnAddrs[0])
assert.Contains(t, nc.hostMap.Indexes, hostinfo.localIndexId)
// Do a traffic check tick, should not be pending deletion but should not have any in/out packets recorded
@ -187,7 +185,7 @@ func Test_NewConnectionManagerTest2(t *testing.T) {
assert.NotContains(t, nc.out, hostinfo.localIndexId)
assert.NotContains(t, nc.in, hostinfo.localIndexId)
assert.Contains(t, nc.hostMap.Indexes, hostinfo.localIndexId)
assert.Contains(t, nc.hostMap.Hosts, hostinfo.vpnIp)
assert.Contains(t, nc.hostMap.Hosts, hostinfo.vpnAddrs[0])
// We saw traffic, should no longer be pending deletion
nc.In(hostinfo.localIndexId)
@ -196,7 +194,7 @@ func Test_NewConnectionManagerTest2(t *testing.T) {
assert.NotContains(t, nc.out, hostinfo.localIndexId)
assert.NotContains(t, nc.in, hostinfo.localIndexId)
assert.Contains(t, nc.hostMap.Indexes, hostinfo.localIndexId)
assert.Contains(t, nc.hostMap.Hosts, hostinfo.vpnIp)
assert.Contains(t, nc.hostMap.Hosts, hostinfo.vpnAddrs[0])
}
// Check if we can disconnect the peer.
@ -210,7 +208,7 @@ func Test_NewConnectionManagerTest_DisconnectInvalid(t *testing.T) {
localrange := netip.MustParsePrefix("10.1.1.1/24")
vpnIp := netip.MustParseAddr("172.1.1.2")
preferredRanges := []netip.Prefix{localrange}
hostMap := newHostMap(l, vpncidr)
hostMap := newHostMap(l)
hostMap.preferredRanges.Store(&preferredRanges)
// Generate keys for CA and peer's cert.
@ -244,10 +242,9 @@ func Test_NewConnectionManagerTest_DisconnectInvalid(t *testing.T) {
cachedPeerCert, err := ncp.VerifyCertificate(now.Add(time.Second), peerCert)
cs := &CertState{
RawCertificate: []byte{},
PrivateKey: []byte{},
Certificate: &dummyCert{},
RawCertificateNoKey: []byte{},
privateKey: []byte{},
v1Cert: &dummyCert{},
v1HandshakeBytes: []byte{},
}
lh := newTestLighthouse()
@ -273,7 +270,7 @@ func Test_NewConnectionManagerTest_DisconnectInvalid(t *testing.T) {
ifce.connectionManager = nc
hostinfo := &HostInfo{
vpnIp: vpnIp,
vpnAddrs: []netip.Addr{vpnIp},
ConnectionState: &ConnectionState{
myCert: &dummyCert{},
peerCert: cachedPeerCert,

View File

@ -3,6 +3,7 @@ package nebula
import (
"crypto/rand"
"encoding/json"
"fmt"
"sync"
"sync/atomic"
@ -26,46 +27,46 @@ type ConnectionState struct {
writeLock sync.Mutex
}
func NewConnectionState(l *logrus.Logger, cipher string, certState *CertState, initiator bool, pattern noise.HandshakePattern, psk []byte, pskStage int) *ConnectionState {
func NewConnectionState(l *logrus.Logger, cs *CertState, crt cert.Certificate, initiator bool, pattern noise.HandshakePattern) (*ConnectionState, error) {
var dhFunc noise.DHFunc
switch certState.Certificate.Curve() {
switch crt.Curve() {
case cert.Curve_CURVE25519:
dhFunc = noise.DH25519
case cert.Curve_P256:
if certState.pkcs11Backed {
if cs.pkcs11Backed {
dhFunc = noiseutil.DHP256PKCS11
} else {
dhFunc = noiseutil.DHP256
}
default:
l.Errorf("invalid curve: %s", certState.Certificate.Curve())
return nil
return nil, fmt.Errorf("invalid curve: %s", crt.Curve())
}
var cs noise.CipherSuite
if cipher == "chachapoly" {
cs = noise.NewCipherSuite(dhFunc, noise.CipherChaChaPoly, noise.HashSHA256)
var ncs noise.CipherSuite
if cs.cipher == "chachapoly" {
ncs = noise.NewCipherSuite(dhFunc, noise.CipherChaChaPoly, noise.HashSHA256)
} else {
cs = noise.NewCipherSuite(dhFunc, noiseutil.CipherAESGCM, noise.HashSHA256)
ncs = noise.NewCipherSuite(dhFunc, noiseutil.CipherAESGCM, noise.HashSHA256)
}
static := noise.DHKey{Private: certState.PrivateKey, Public: certState.PublicKey}
static := noise.DHKey{Private: cs.privateKey, Public: crt.PublicKey()}
b := NewBits(ReplayWindow)
// Clear out bit 0, we never transmit it and we don't want it showing as packet loss
// Clear out bit 0, we never transmit it, and we don't want it showing as packet loss
b.Update(l, 0)
hs, err := noise.NewHandshakeState(noise.Config{
CipherSuite: cs,
Random: rand.Reader,
Pattern: pattern,
Initiator: initiator,
StaticKeypair: static,
PresharedKey: psk,
PresharedKeyPlacement: pskStage,
CipherSuite: ncs,
Random: rand.Reader,
Pattern: pattern,
Initiator: initiator,
StaticKeypair: static,
//NOTE: These should come from CertState (pki.go) when we finally implement it
PresharedKey: []byte{},
PresharedKeyPlacement: 0,
})
if err != nil {
return nil
return nil, fmt.Errorf("NewConnectionState: %s", err)
}
// The queue and ready params prevent a counter race that would happen when
@ -74,12 +75,12 @@ func NewConnectionState(l *logrus.Logger, cipher string, certState *CertState, i
H: hs,
initiator: initiator,
window: b,
myCert: certState.Certificate,
myCert: crt,
}
// always start the counter from 2, as packet 1 and packet 2 are handshake packets.
ci.messageCounter.Add(2)
return ci
return ci, nil
}
func (cs *ConnectionState) MarshalJSON() ([]byte, error) {
@ -89,3 +90,7 @@ func (cs *ConnectionState) MarshalJSON() ([]byte, error) {
"message_counter": cs.messageCounter.Load(),
})
}
func (cs *ConnectionState) Curve() cert.Curve {
return cs.myCert.Curve()
}

View File

@ -19,9 +19,9 @@ import (
type controlEach func(h *HostInfo)
type controlHostLister interface {
QueryVpnIp(vpnIp netip.Addr) *HostInfo
QueryVpnAddr(vpnAddr netip.Addr) *HostInfo
ForEachIndex(each controlEach)
ForEachVpnIp(each controlEach)
ForEachVpnAddr(each controlEach)
GetPreferredRanges() []netip.Prefix
}
@ -37,7 +37,7 @@ type Control struct {
}
type ControlHostInfo struct {
VpnIp netip.Addr `json:"vpnIp"`
VpnAddrs []netip.Addr `json:"vpnAddrs"`
LocalIndex uint32 `json:"localIndex"`
RemoteIndex uint32 `json:"remoteIndex"`
RemoteAddrs []netip.AddrPort `json:"remoteAddrs"`
@ -131,10 +131,13 @@ func (c *Control) ListHostmapIndexes(pendingMap bool) []ControlHostInfo {
// GetCertByVpnIp returns the authenticated certificate of the given vpn IP, or nil if not found
func (c *Control) GetCertByVpnIp(vpnIp netip.Addr) cert.Certificate {
if c.f.myVpnNet.Addr() == vpnIp {
return c.f.pki.GetCertState().Certificate.Copy()
_, found := c.f.myVpnAddrsTable.Lookup(vpnIp)
if found {
//TODO: we might have 2 certs....
//TODO: this should return our latest version cert
return c.f.pki.getDefaultCertificate().Copy()
}
hi := c.f.hostMap.QueryVpnIp(vpnIp)
hi := c.f.hostMap.QueryVpnAddr(vpnIp)
if hi == nil {
return nil
}
@ -148,7 +151,7 @@ func (c *Control) CreateTunnel(vpnIp netip.Addr) {
// PrintTunnel creates a new tunnel to the given vpn ip.
func (c *Control) PrintTunnel(vpnIp netip.Addr) *ControlHostInfo {
hi := c.f.hostMap.QueryVpnIp(vpnIp)
hi := c.f.hostMap.QueryVpnAddr(vpnIp)
if hi == nil {
return nil
}
@ -165,9 +168,9 @@ func (c *Control) QueryLighthouse(vpnIp netip.Addr) *CacheMap {
return hi.CopyCache()
}
// GetHostInfoByVpnIp returns a single tunnels hostInfo, or nil if not found
// GetHostInfoByVpnAddr returns a single tunnels hostInfo, or nil if not found
// Caller should take care to Unmap() any 4in6 addresses prior to calling.
func (c *Control) GetHostInfoByVpnIp(vpnIp netip.Addr, pending bool) *ControlHostInfo {
func (c *Control) GetHostInfoByVpnAddr(vpnAddr netip.Addr, pending bool) *ControlHostInfo {
var hl controlHostLister
if pending {
hl = c.f.handshakeManager
@ -175,7 +178,7 @@ func (c *Control) GetHostInfoByVpnIp(vpnIp netip.Addr, pending bool) *ControlHos
hl = c.f.hostMap
}
h := hl.QueryVpnIp(vpnIp)
h := hl.QueryVpnAddr(vpnAddr)
if h == nil {
return nil
}
@ -187,7 +190,7 @@ func (c *Control) GetHostInfoByVpnIp(vpnIp netip.Addr, pending bool) *ControlHos
// SetRemoteForTunnel forces a tunnel to use a specific remote
// Caller should take care to Unmap() any 4in6 addresses prior to calling.
func (c *Control) SetRemoteForTunnel(vpnIp netip.Addr, addr netip.AddrPort) *ControlHostInfo {
hostInfo := c.f.hostMap.QueryVpnIp(vpnIp)
hostInfo := c.f.hostMap.QueryVpnAddr(vpnIp)
if hostInfo == nil {
return nil
}
@ -200,7 +203,7 @@ func (c *Control) SetRemoteForTunnel(vpnIp netip.Addr, addr netip.AddrPort) *Con
// CloseTunnel closes a fully established tunnel. If localOnly is false it will notify the remote end as well.
// Caller should take care to Unmap() any 4in6 addresses prior to calling.
func (c *Control) CloseTunnel(vpnIp netip.Addr, localOnly bool) bool {
hostInfo := c.f.hostMap.QueryVpnIp(vpnIp)
hostInfo := c.f.hostMap.QueryVpnAddr(vpnIp)
if hostInfo == nil {
return false
}
@ -229,14 +232,14 @@ func (c *Control) CloseAllTunnels(excludeLighthouses bool) (closed int) {
shutdown := func(h *HostInfo) {
if excludeLighthouses {
if _, ok := lighthouses[h.vpnIp]; ok {
if _, ok := lighthouses[h.vpnAddrs[0]]; ok {
return
}
}
c.f.send(header.CloseTunnel, 0, h.ConnectionState, h, []byte{}, make([]byte, 12, 12), make([]byte, mtu))
c.f.closeTunnel(h)
c.l.WithField("vpnIp", h.vpnIp).WithField("udpAddr", h.remote).
c.l.WithField("vpnIp", h.vpnAddrs[0]).WithField("udpAddr", h.remote).
Debug("Sending close tunnel message")
closed++
}
@ -246,7 +249,7 @@ func (c *Control) CloseAllTunnels(excludeLighthouses bool) (closed int) {
// Grab the hostMap lock to access the Relays map
c.f.hostMap.Lock()
for _, relayingHost := range c.f.hostMap.Relays {
relayingHosts[relayingHost.vpnIp] = relayingHost
relayingHosts[relayingHost.vpnAddrs[0]] = relayingHost
}
c.f.hostMap.Unlock()
@ -254,7 +257,7 @@ func (c *Control) CloseAllTunnels(excludeLighthouses bool) (closed int) {
// Grab the hostMap lock to access the Hosts map
c.f.hostMap.Lock()
for _, relayHost := range c.f.hostMap.Indexes {
if _, ok := relayingHosts[relayHost.vpnIp]; !ok {
if _, ok := relayingHosts[relayHost.vpnAddrs[0]]; !ok {
hostInfos = append(hostInfos, relayHost)
}
}
@ -274,9 +277,8 @@ func (c *Control) Device() overlay.Device {
}
func copyHostInfo(h *HostInfo, preferredRanges []netip.Prefix) ControlHostInfo {
chi := ControlHostInfo{
VpnIp: h.vpnIp,
VpnAddrs: make([]netip.Addr, len(h.vpnAddrs)),
LocalIndex: h.localIndexId,
RemoteIndex: h.remoteIndexId,
RemoteAddrs: h.remotes.CopyAddrs(preferredRanges),
@ -285,6 +287,10 @@ func copyHostInfo(h *HostInfo, preferredRanges []netip.Prefix) ControlHostInfo {
CurrentRemote: h.remote,
}
for i, a := range h.vpnAddrs {
chi.VpnAddrs[i] = a
}
if h.ConnectionState != nil {
chi.MessageCounter = h.ConnectionState.messageCounter.Load()
}
@ -299,7 +305,7 @@ func copyHostInfo(h *HostInfo, preferredRanges []netip.Prefix) ControlHostInfo {
func listHostMapHosts(hl controlHostLister) []ControlHostInfo {
hosts := make([]ControlHostInfo, 0)
pr := hl.GetPreferredRanges()
hl.ForEachVpnIp(func(hostinfo *HostInfo) {
hl.ForEachVpnAddr(func(hostinfo *HostInfo) {
hosts = append(hosts, copyHostInfo(hostinfo, pr))
})
return hosts

View File

@ -19,7 +19,7 @@ func TestControl_GetHostInfoByVpnIp(t *testing.T) {
l := test.NewLogger()
// Special care must be taken to re-use all objects provided to the hostmap and certificate in the expectedInfo object
// To properly ensure we are not exposing core memory to the caller
hm := newHostMap(l, netip.Prefix{})
hm := newHostMap(l)
hm.preferredRanges.Store(&[]netip.Prefix{})
remote1 := netip.MustParseAddrPort("0.0.0.100:4444")
@ -35,9 +35,9 @@ func TestControl_GetHostInfoByVpnIp(t *testing.T) {
Mask: net.IPMask{255, 255, 255, 0},
}
remotes := NewRemoteList(nil)
remotes.unlockedPrependV4(netip.IPv4Unspecified(), NewIp4AndPortFromNetIP(remote1.Addr(), remote1.Port()))
remotes.unlockedPrependV6(netip.IPv4Unspecified(), NewIp6AndPortFromNetIP(remote2.Addr(), remote2.Port()))
remotes := NewRemoteList([]netip.Addr{netip.IPv4Unspecified()}, nil)
remotes.unlockedPrependV4(netip.IPv4Unspecified(), netAddrToProtoV4AddrPort(remote1.Addr(), remote1.Port()))
remotes.unlockedPrependV6(netip.IPv4Unspecified(), netAddrToProtoV6AddrPort(remote2.Addr(), remote2.Port()))
vpnIp, ok := netip.AddrFromSlice(ipNet.IP)
assert.True(t, ok)
@ -51,11 +51,11 @@ func TestControl_GetHostInfoByVpnIp(t *testing.T) {
},
remoteIndexId: 200,
localIndexId: 201,
vpnIp: vpnIp,
vpnAddrs: []netip.Addr{vpnIp},
relayState: RelayState{
relays: map[netip.Addr]struct{}{},
relayForByIp: map[netip.Addr]*Relay{},
relayForByIdx: map[uint32]*Relay{},
relays: map[netip.Addr]struct{}{},
relayForByAddr: map[netip.Addr]*Relay{},
relayForByIdx: map[uint32]*Relay{},
},
}, &Interface{})
@ -70,11 +70,11 @@ func TestControl_GetHostInfoByVpnIp(t *testing.T) {
},
remoteIndexId: 200,
localIndexId: 201,
vpnIp: vpnIp2,
vpnAddrs: []netip.Addr{vpnIp2},
relayState: RelayState{
relays: map[netip.Addr]struct{}{},
relayForByIp: map[netip.Addr]*Relay{},
relayForByIdx: map[uint32]*Relay{},
relays: map[netip.Addr]struct{}{},
relayForByAddr: map[netip.Addr]*Relay{},
relayForByIdx: map[uint32]*Relay{},
},
}, &Interface{})
@ -85,10 +85,10 @@ func TestControl_GetHostInfoByVpnIp(t *testing.T) {
l: logrus.New(),
}
thi := c.GetHostInfoByVpnIp(vpnIp, false)
thi := c.GetHostInfoByVpnAddr(vpnIp, false)
expectedInfo := ControlHostInfo{
VpnIp: vpnIp,
VpnAddrs: []netip.Addr{vpnIp},
LocalIndex: 201,
RemoteIndex: 200,
RemoteAddrs: []netip.AddrPort{remote2, remote1},
@ -100,13 +100,13 @@ func TestControl_GetHostInfoByVpnIp(t *testing.T) {
}
// Make sure we don't have any unexpected fields
assertFields(t, []string{"VpnIp", "LocalIndex", "RemoteIndex", "RemoteAddrs", "Cert", "MessageCounter", "CurrentRemote", "CurrentRelaysToMe", "CurrentRelaysThroughMe"}, thi)
assertFields(t, []string{"VpnAddrs", "LocalIndex", "RemoteIndex", "RemoteAddrs", "Cert", "MessageCounter", "CurrentRemote", "CurrentRelaysToMe", "CurrentRelaysThroughMe"}, thi)
assert.EqualValues(t, &expectedInfo, thi)
test.AssertDeepCopyEqual(t, &expectedInfo, thi)
// Make sure we don't panic if the host info doesn't have a cert yet
assert.NotPanics(t, func() {
thi = c.GetHostInfoByVpnIp(vpnIp2, false)
thi = c.GetHostInfoByVpnAddr(vpnIp2, false)
})
}

View File

@ -6,8 +6,6 @@ package nebula
import (
"net/netip"
"github.com/slackhq/nebula/cert"
"github.com/google/gopacket"
"github.com/google/gopacket/layers"
"github.com/slackhq/nebula/header"
@ -51,15 +49,15 @@ func (c *Control) WaitForTypeByIndex(toIndex uint32, msgType header.MessageType,
// This is necessary if you did not configure static hosts or are not running a lighthouse
func (c *Control) InjectLightHouseAddr(vpnIp netip.Addr, toAddr netip.AddrPort) {
c.f.lightHouse.Lock()
remoteList := c.f.lightHouse.unlockedGetRemoteList(vpnIp)
remoteList := c.f.lightHouse.unlockedGetRemoteList([]netip.Addr{vpnIp})
remoteList.Lock()
defer remoteList.Unlock()
c.f.lightHouse.Unlock()
if toAddr.Addr().Is4() {
remoteList.unlockedPrependV4(vpnIp, NewIp4AndPortFromNetIP(toAddr.Addr(), toAddr.Port()))
remoteList.unlockedPrependV4(vpnIp, netAddrToProtoV4AddrPort(toAddr.Addr(), toAddr.Port()))
} else {
remoteList.unlockedPrependV6(vpnIp, NewIp6AndPortFromNetIP(toAddr.Addr(), toAddr.Port()))
remoteList.unlockedPrependV6(vpnIp, netAddrToProtoV6AddrPort(toAddr.Addr(), toAddr.Port()))
}
}
@ -67,7 +65,7 @@ func (c *Control) InjectLightHouseAddr(vpnIp netip.Addr, toAddr netip.AddrPort)
// This is necessary to inform an initiator of possible relays for communicating with a responder
func (c *Control) InjectRelays(vpnIp netip.Addr, relayVpnIps []netip.Addr) {
c.f.lightHouse.Lock()
remoteList := c.f.lightHouse.unlockedGetRemoteList(vpnIp)
remoteList := c.f.lightHouse.unlockedGetRemoteList([]netip.Addr{vpnIp})
remoteList.Lock()
defer remoteList.Unlock()
c.f.lightHouse.Unlock()
@ -99,21 +97,42 @@ func (c *Control) InjectUDPPacket(p *udp.Packet) {
}
// InjectTunUDPPacket puts a udp packet on the tun interface. Using UDP here because it's a simpler protocol
func (c *Control) InjectTunUDPPacket(toIp netip.Addr, toPort uint16, fromPort uint16, data []byte) {
//TODO: IPV6-WORK
ip := layers.IPv4{
Version: 4,
TTL: 64,
Protocol: layers.IPProtocolUDP,
SrcIP: c.f.inside.Cidr().Addr().Unmap().AsSlice(),
DstIP: toIp.Unmap().AsSlice(),
func (c *Control) InjectTunUDPPacket(toAddr netip.Addr, toPort uint16, fromAddr netip.Addr, fromPort uint16, data []byte) {
serialize := make([]gopacket.SerializableLayer, 0)
var netLayer gopacket.NetworkLayer
if toAddr.Is6() {
if !fromAddr.Is6() {
panic("Cant send ipv6 to ipv4")
}
ip := &layers.IPv6{
Version: 6,
NextHeader: layers.IPProtocolUDP,
SrcIP: fromAddr.Unmap().AsSlice(),
DstIP: toAddr.Unmap().AsSlice(),
}
serialize = append(serialize, ip)
netLayer = ip
} else {
if !fromAddr.Is4() {
panic("Cant send ipv4 to ipv6")
}
ip := &layers.IPv4{
Version: 4,
TTL: 64,
Protocol: layers.IPProtocolUDP,
SrcIP: fromAddr.Unmap().AsSlice(),
DstIP: toAddr.Unmap().AsSlice(),
}
serialize = append(serialize, ip)
netLayer = ip
}
udp := layers.UDP{
SrcPort: layers.UDPPort(fromPort),
DstPort: layers.UDPPort(toPort),
}
err := udp.SetNetworkLayerForChecksum(&ip)
err := udp.SetNetworkLayerForChecksum(netLayer)
if err != nil {
panic(err)
}
@ -123,7 +142,9 @@ func (c *Control) InjectTunUDPPacket(toIp netip.Addr, toPort uint16, fromPort ui
ComputeChecksums: true,
FixLengths: true,
}
err = gopacket.SerializeLayers(buffer, opt, &ip, &udp, gopacket.Payload(data))
serialize = append(serialize, &udp, gopacket.Payload(data))
err = gopacket.SerializeLayers(buffer, opt, serialize...)
if err != nil {
panic(err)
}
@ -131,8 +152,8 @@ func (c *Control) InjectTunUDPPacket(toIp netip.Addr, toPort uint16, fromPort ui
c.f.inside.(*overlay.TestTun).Send(buffer.Bytes())
}
func (c *Control) GetVpnIp() netip.Addr {
return c.f.myVpnNet.Addr()
func (c *Control) GetVpnAddrs() []netip.Addr {
return c.f.myVpnAddrs
}
func (c *Control) GetUDPAddr() netip.AddrPort {
@ -140,7 +161,7 @@ func (c *Control) GetUDPAddr() netip.AddrPort {
}
func (c *Control) KillPendingTunnel(vpnIp netip.Addr) bool {
hostinfo := c.f.handshakeManager.QueryVpnIp(vpnIp)
hostinfo := c.f.handshakeManager.QueryVpnAddr(vpnIp)
if hostinfo == nil {
return false
}
@ -153,8 +174,8 @@ func (c *Control) GetHostmap() *HostMap {
return c.f.hostMap
}
func (c *Control) GetCert() cert.Certificate {
return c.f.pki.GetCertState().Certificate
func (c *Control) GetCertState() *CertState {
return c.f.pki.getCertState()
}
func (c *Control) ReHandshake(vpnIp netip.Addr) {

View File

@ -8,6 +8,7 @@ import (
"strings"
"sync"
"github.com/gaissmai/bart"
"github.com/miekg/dns"
"github.com/sirupsen/logrus"
"github.com/slackhq/nebula/config"
@ -21,24 +22,39 @@ var dnsAddr string
type dnsRecords struct {
sync.RWMutex
dnsMap map[string]string
hostMap *HostMap
l *logrus.Logger
dnsMap4 map[string]netip.Addr
dnsMap6 map[string]netip.Addr
hostMap *HostMap
myVpnAddrsTable *bart.Table[struct{}]
}
func newDnsRecords(hostMap *HostMap) *dnsRecords {
func newDnsRecords(l *logrus.Logger, cs *CertState, hostMap *HostMap) *dnsRecords {
return &dnsRecords{
dnsMap: make(map[string]string),
hostMap: hostMap,
l: l,
dnsMap4: make(map[string]netip.Addr),
dnsMap6: make(map[string]netip.Addr),
hostMap: hostMap,
myVpnAddrsTable: cs.myVpnAddrsTable,
}
}
func (d *dnsRecords) Query(data string) string {
func (d *dnsRecords) Query(q uint16, data string) netip.Addr {
data = strings.ToLower(data)
d.RLock()
defer d.RUnlock()
if r, ok := d.dnsMap[strings.ToLower(data)]; ok {
return r
switch q {
case dns.TypeA:
if r, ok := d.dnsMap4[data]; ok {
return r
}
case dns.TypeAAAA:
if r, ok := d.dnsMap6[data]; ok {
return r
}
}
return ""
return netip.Addr{}
}
func (d *dnsRecords) QueryCert(data string) string {
@ -47,7 +63,7 @@ func (d *dnsRecords) QueryCert(data string) string {
return ""
}
hostinfo := d.hostMap.QueryVpnIp(ip)
hostinfo := d.hostMap.QueryVpnAddr(ip)
if hostinfo == nil {
return ""
}
@ -64,38 +80,62 @@ func (d *dnsRecords) QueryCert(data string) string {
return string(b)
}
func (d *dnsRecords) Add(host, data string) {
// Add adds the first IPv4 and IPv6 address that appears in `addresses` as the record for `host`
func (d *dnsRecords) Add(host string, addresses []netip.Addr) {
host = strings.ToLower(host)
d.Lock()
defer d.Unlock()
d.dnsMap[strings.ToLower(host)] = data
haveV4 := false
haveV6 := false
for _, addr := range addresses {
if addr.Is4() && !haveV4 {
d.dnsMap4[host] = addr
haveV4 = true
} else if addr.Is6() && !haveV6 {
d.dnsMap6[host] = addr
haveV6 = true
}
if haveV4 && haveV6 {
break
}
}
}
func parseQuery(l *logrus.Logger, m *dns.Msg, w dns.ResponseWriter) {
func (d *dnsRecords) isSelfNebulaOrLocalhost(addr string) bool {
a, _, _ := net.SplitHostPort(addr)
b, err := netip.ParseAddr(a)
if err != nil {
return false
}
if b.IsLoopback() {
return true
}
_, found := d.myVpnAddrsTable.Lookup(b)
return found //if we found it in this table, it's good
}
func (d *dnsRecords) parseQuery(m *dns.Msg, w dns.ResponseWriter) {
for _, q := range m.Question {
switch q.Qtype {
case dns.TypeA:
l.Debugf("Query for A %s", q.Name)
ip := dnsR.Query(q.Name)
if ip != "" {
rr, err := dns.NewRR(fmt.Sprintf("%s A %s", q.Name, ip))
case dns.TypeA, dns.TypeAAAA:
qType := dns.TypeToString[q.Qtype]
d.l.Debugf("Query for %s %s", qType, q.Name)
ip := d.Query(q.Qtype, q.Name)
if ip.IsValid() {
rr, err := dns.NewRR(fmt.Sprintf("%s %s %s", q.Name, qType, ip))
if err == nil {
m.Answer = append(m.Answer, rr)
}
}
case dns.TypeTXT:
a, _, _ := net.SplitHostPort(w.RemoteAddr().String())
b, err := netip.ParseAddr(a)
if err != nil {
// We only answer these queries from nebula nodes or localhost
if !d.isSelfNebulaOrLocalhost(w.RemoteAddr().String()) {
return
}
// We don't answer these queries from non nebula nodes or localhost
//l.Debugf("Does %s contain %s", b, dnsR.hostMap.vpnCIDR)
if !dnsR.hostMap.vpnCIDR.Contains(b) && a != "127.0.0.1" {
return
}
l.Debugf("Query for TXT %s", q.Name)
ip := dnsR.QueryCert(q.Name)
d.l.Debugf("Query for TXT %s", q.Name)
ip := d.QueryCert(q.Name)
if ip != "" {
rr, err := dns.NewRR(fmt.Sprintf("%s TXT %s", q.Name, ip))
if err == nil {
@ -110,26 +150,24 @@ func parseQuery(l *logrus.Logger, m *dns.Msg, w dns.ResponseWriter) {
}
}
func handleDnsRequest(l *logrus.Logger, w dns.ResponseWriter, r *dns.Msg) {
func (d *dnsRecords) handleDnsRequest(w dns.ResponseWriter, r *dns.Msg) {
m := new(dns.Msg)
m.SetReply(r)
m.Compress = false
switch r.Opcode {
case dns.OpcodeQuery:
parseQuery(l, m, w)
d.parseQuery(m, w)
}
w.WriteMsg(m)
}
func dnsMain(l *logrus.Logger, hostMap *HostMap, c *config.C) func() {
dnsR = newDnsRecords(hostMap)
func dnsMain(l *logrus.Logger, cs *CertState, hostMap *HostMap, c *config.C) func() {
dnsR = newDnsRecords(l, cs, hostMap)
// attach request handler func
dns.HandleFunc(".", func(w dns.ResponseWriter, r *dns.Msg) {
handleDnsRequest(l, w, r)
})
dns.HandleFunc(".", dnsR.handleDnsRequest)
c.RegisterReloadCallback(func(c *config.C) {
reloadDns(l, c)

View File

@ -1,23 +1,38 @@
package nebula
import (
"net/netip"
"testing"
"github.com/miekg/dns"
"github.com/sirupsen/logrus"
"github.com/slackhq/nebula/config"
"github.com/stretchr/testify/assert"
)
func TestParsequery(t *testing.T) {
//TODO: This test is basically pointless
l := logrus.New()
hostMap := &HostMap{}
ds := newDnsRecords(hostMap)
ds.Add("test.com.com", "1.2.3.4")
ds := newDnsRecords(l, &CertState{}, hostMap)
addrs := []netip.Addr{
netip.MustParseAddr("1.2.3.4"),
netip.MustParseAddr("1.2.3.5"),
netip.MustParseAddr("fd01::24"),
netip.MustParseAddr("fd01::25"),
}
ds.Add("test.com.com", addrs)
m := new(dns.Msg)
m := &dns.Msg{}
m.SetQuestion("test.com.com", dns.TypeA)
ds.parseQuery(m, nil)
assert.NotNil(t, m.Answer)
assert.Equal(t, "1.2.3.4", m.Answer[0].(*dns.A).A.String())
//parseQuery(m)
m = &dns.Msg{}
m.SetQuestion("test.com.com", dns.TypeAAAA)
ds.parseQuery(m, nil)
assert.NotNil(t, m.Answer)
assert.Equal(t, "fd01::24", m.Answer[0].(*dns.AAAA).AAAA.String())
}
func Test_getDnsServerAddr(t *testing.T) {

View File

@ -4,7 +4,6 @@
package e2e
import (
"fmt"
"net/netip"
"slices"
"testing"
@ -12,6 +11,7 @@ import (
"github.com/sirupsen/logrus"
"github.com/slackhq/nebula"
"github.com/slackhq/nebula/cert"
"github.com/slackhq/nebula/e2e/router"
"github.com/slackhq/nebula/header"
"github.com/slackhq/nebula/udp"
@ -21,11 +21,11 @@ import (
func BenchmarkHotPath(b *testing.B) {
ca, _, caKey, _ := NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
myControl, _, _, _ := newSimpleServer(ca, caKey, "me", "10.128.0.1/24", nil)
theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(ca, caKey, "them", "10.128.0.2/24", nil)
myControl, myVpnIpNet, _, _ := newSimpleServer(cert.Version1, ca, caKey, "me", "10.128.0.1/24", nil)
theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "them", "10.128.0.2/24", nil)
// Put their info in our lighthouse
myControl.InjectLightHouseAddr(theirVpnIpNet.Addr(), theirUdpAddr)
myControl.InjectLightHouseAddr(theirVpnIpNet[0].Addr(), theirUdpAddr)
// Start the servers
myControl.Start()
@ -35,7 +35,7 @@ func BenchmarkHotPath(b *testing.B) {
r.CancelFlowLogs()
for n := 0; n < b.N; n++ {
myControl.InjectTunUDPPacket(theirVpnIpNet.Addr(), 80, 80, []byte("Hi from me"))
myControl.InjectTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi from me"))
_ = r.RouteForAllUntilTxTun(theirControl)
}
@ -45,18 +45,18 @@ func BenchmarkHotPath(b *testing.B) {
func TestGoodHandshake(t *testing.T) {
ca, _, caKey, _ := NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(ca, caKey, "me", "10.128.0.1/24", nil)
theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(ca, caKey, "them", "10.128.0.2/24", nil)
myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "me", "10.128.0.1/24", nil)
theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "them", "10.128.0.2/24", nil)
// Put their info in our lighthouse
myControl.InjectLightHouseAddr(theirVpnIpNet.Addr(), theirUdpAddr)
myControl.InjectLightHouseAddr(theirVpnIpNet[0].Addr(), theirUdpAddr)
// Start the servers
myControl.Start()
theirControl.Start()
t.Log("Send a udp packet through to begin standing up the tunnel, this should come out the other side")
myControl.InjectTunUDPPacket(theirVpnIpNet.Addr(), 80, 80, []byte("Hi from me"))
myControl.InjectTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi from me"))
t.Log("Have them consume my stage 0 packet. They have a tunnel now")
theirControl.InjectUDPPacket(myControl.GetFromUDP(true))
@ -77,16 +77,16 @@ func TestGoodHandshake(t *testing.T) {
myControl.WaitForType(1, 0, theirControl)
t.Log("Make sure our host infos are correct")
assertHostInfoPair(t, myUdpAddr, theirUdpAddr, myVpnIpNet.Addr(), theirVpnIpNet.Addr(), myControl, theirControl)
assertHostInfoPair(t, myUdpAddr, theirUdpAddr, myVpnIpNet, theirVpnIpNet, myControl, theirControl)
t.Log("Get that cached packet and make sure it looks right")
myCachedPacket := theirControl.GetFromTun(true)
assertUdpPacket(t, []byte("Hi from me"), myCachedPacket, myVpnIpNet.Addr(), theirVpnIpNet.Addr(), 80, 80)
assertUdpPacket(t, []byte("Hi from me"), myCachedPacket, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), 80, 80)
t.Log("Do a bidirectional tunnel test")
r := router.NewR(t, myControl, theirControl)
defer r.RenderFlow()
assertTunnel(t, myVpnIpNet.Addr(), theirVpnIpNet.Addr(), myControl, theirControl, r)
assertTunnel(t, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), myControl, theirControl, r)
r.RenderHostmaps("Final hostmaps", myControl, theirControl)
myControl.Stop()
@ -97,12 +97,12 @@ func TestGoodHandshake(t *testing.T) {
func TestWrongResponderHandshake(t *testing.T) {
ca, _, caKey, _ := NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(ca, caKey, "me", "10.128.0.100/24", nil)
theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(ca, caKey, "them", "10.128.0.99/24", nil)
evilControl, evilVpnIp, evilUdpAddr, _ := newSimpleServer(ca, caKey, "evil", "10.128.0.2/24", nil)
myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "me", "10.128.0.100/24", nil)
theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "them", "10.128.0.99/24", nil)
evilControl, evilVpnIp, evilUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "evil", "10.128.0.2/24", nil)
// Put the evil udp addr in for their vpn Ip, this is a case of being lied to by the lighthouse.
myControl.InjectLightHouseAddr(theirVpnIpNet.Addr(), evilUdpAddr)
myControl.InjectLightHouseAddr(theirVpnIpNet[0].Addr(), evilUdpAddr)
// Build a router so we don't have to reason who gets which packet
r := router.NewR(t, myControl, theirControl, evilControl)
@ -114,7 +114,7 @@ func TestWrongResponderHandshake(t *testing.T) {
evilControl.Start()
t.Log("Start the handshake process, we will route until we see the evil tunnel closed")
myControl.InjectTunUDPPacket(theirVpnIpNet.Addr(), 80, 80, []byte("Hi from me"))
myControl.InjectTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi from me"))
h := &header.H{}
r.RouteForAllExitFunc(func(p *udp.Packet, c *nebula.Control) router.ExitType {
@ -131,8 +131,8 @@ func TestWrongResponderHandshake(t *testing.T) {
})
t.Log("Evil tunnel is closed, inject the correct udp addr for them")
myControl.InjectLightHouseAddr(theirVpnIpNet.Addr(), theirUdpAddr)
pendingHi := myControl.GetHostInfoByVpnIp(theirVpnIpNet.Addr(), true)
myControl.InjectLightHouseAddr(theirVpnIpNet[0].Addr(), theirUdpAddr)
pendingHi := myControl.GetHostInfoByVpnAddr(theirVpnIpNet[0].Addr(), true)
assert.NotContains(t, pendingHi.RemoteAddrs, evilUdpAddr)
t.Log("Route until we see the cached packet")
@ -153,18 +153,18 @@ func TestWrongResponderHandshake(t *testing.T) {
t.Log("My cached packet should be received by them")
myCachedPacket := theirControl.GetFromTun(true)
assertUdpPacket(t, []byte("Hi from me"), myCachedPacket, myVpnIpNet.Addr(), theirVpnIpNet.Addr(), 80, 80)
assertUdpPacket(t, []byte("Hi from me"), myCachedPacket, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), 80, 80)
t.Log("Test the tunnel with them")
assertHostInfoPair(t, myUdpAddr, theirUdpAddr, myVpnIpNet.Addr(), theirVpnIpNet.Addr(), myControl, theirControl)
assertTunnel(t, myVpnIpNet.Addr(), theirVpnIpNet.Addr(), myControl, theirControl, r)
assertHostInfoPair(t, myUdpAddr, theirUdpAddr, myVpnIpNet, theirVpnIpNet, myControl, theirControl)
assertTunnel(t, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), myControl, theirControl, r)
t.Log("Flush all packets from all controllers")
r.FlushAll()
t.Log("Ensure ensure I don't have any hostinfo artifacts from evil")
assert.Nil(t, myControl.GetHostInfoByVpnIp(evilVpnIp.Addr(), true), "My pending hostmap should not contain evil")
assert.Nil(t, myControl.GetHostInfoByVpnIp(evilVpnIp.Addr(), false), "My main hostmap should not contain evil")
assert.Nil(t, myControl.GetHostInfoByVpnAddr(evilVpnIp[0].Addr(), true), "My pending hostmap should not contain evil")
assert.Nil(t, myControl.GetHostInfoByVpnAddr(evilVpnIp[0].Addr(), false), "My main hostmap should not contain evil")
//TODO: assert hostmaps for everyone
r.RenderHostmaps("Final hostmaps", myControl, theirControl, evilControl)
@ -176,17 +176,17 @@ func TestWrongResponderHandshake(t *testing.T) {
func TestWrongResponderHandshakeStaticHostMap(t *testing.T) {
ca, _, caKey, _ := NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(ca, caKey, "them", "10.128.0.99/24", nil)
evilControl, evilVpnIp, evilUdpAddr, _ := newSimpleServer(ca, caKey, "evil", "10.128.0.2/24", nil)
theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "them", "10.128.0.99/24", nil)
evilControl, evilVpnIp, evilUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "evil", "10.128.0.2/24", nil)
o := m{
"static_host_map": m{
theirVpnIpNet.Addr().String(): []string{evilUdpAddr.String()},
theirVpnIpNet[0].Addr().String(): []string{evilUdpAddr.String()},
},
}
myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(ca, caKey, "me", "10.128.0.100/24", o)
myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "me", "10.128.0.100/24", o)
// Put the evil udp addr in for their vpn addr, this is a case of a remote at a static entry changing its vpn addr.
myControl.InjectLightHouseAddr(theirVpnIpNet.Addr(), evilUdpAddr)
myControl.InjectLightHouseAddr(theirVpnIpNet[0].Addr(), evilUdpAddr)
// Build a router so we don't have to reason who gets which packet
r := router.NewR(t, myControl, theirControl, evilControl)
@ -198,7 +198,7 @@ func TestWrongResponderHandshakeStaticHostMap(t *testing.T) {
evilControl.Start()
t.Log("Start the handshake process, we will route until we see the evil tunnel closed")
myControl.InjectTunUDPPacket(theirVpnIpNet.Addr(), 80, 80, []byte("Hi from me"))
myControl.InjectTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi from me"))
h := &header.H{}
r.RouteForAllExitFunc(func(p *udp.Packet, c *nebula.Control) router.ExitType {
@ -215,8 +215,8 @@ func TestWrongResponderHandshakeStaticHostMap(t *testing.T) {
})
t.Log("Evil tunnel is closed, inject the correct udp addr for them")
myControl.InjectLightHouseAddr(theirVpnIpNet.Addr(), theirUdpAddr)
pendingHi := myControl.GetHostInfoByVpnIp(theirVpnIpNet.Addr(), true)
myControl.InjectLightHouseAddr(theirVpnIpNet[0].Addr(), theirUdpAddr)
pendingHi := myControl.GetHostInfoByVpnAddr(theirVpnIpNet[0].Addr(), true)
assert.NotContains(t, pendingHi.RemoteAddrs, evilUdpAddr)
t.Log("Route until we see the cached packet")
@ -237,18 +237,19 @@ func TestWrongResponderHandshakeStaticHostMap(t *testing.T) {
t.Log("My cached packet should be received by them")
myCachedPacket := theirControl.GetFromTun(true)
assertUdpPacket(t, []byte("Hi from me"), myCachedPacket, myVpnIpNet.Addr(), theirVpnIpNet.Addr(), 80, 80)
assertUdpPacket(t, []byte("Hi from me"), myCachedPacket, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), 80, 80)
t.Log("Test the tunnel with them")
assertHostInfoPair(t, myUdpAddr, theirUdpAddr, myVpnIpNet.Addr(), theirVpnIpNet.Addr(), myControl, theirControl)
assertTunnel(t, myVpnIpNet.Addr(), theirVpnIpNet.Addr(), myControl, theirControl, r)
assertHostInfoPair(t, myUdpAddr, theirUdpAddr, myVpnIpNet, theirVpnIpNet, myControl, theirControl)
assertTunnel(t, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), myControl, theirControl, r)
t.Log("Flush all packets from all controllers")
r.FlushAll()
t.Log("Ensure ensure I don't have any hostinfo artifacts from evil")
assert.Nil(t, myControl.GetHostInfoByVpnIp(evilVpnIp.Addr(), true), "My pending hostmap should not contain evil")
assert.Nil(t, myControl.GetHostInfoByVpnIp(evilVpnIp.Addr(), false), "My main hostmap should not contain evil")
assert.Nil(t, myControl.GetHostInfoByVpnAddr(evilVpnIp[0].Addr(), true), "My pending hostmap should not contain evil")
assert.Nil(t, myControl.GetHostInfoByVpnAddr(evilVpnIp[0].Addr(), false), "My main hostmap should not contain evil")
//NOTE: if evil lost the handshake race it may still have a tunnel since me would reject the handshake since the tunnel is complete
//TODO: assert hostmaps for everyone
r.RenderHostmaps("Final hostmaps", myControl, theirControl, evilControl)
@ -262,12 +263,12 @@ func TestStage1Race(t *testing.T) {
// But will eventually collapse down to a single tunnel
ca, _, caKey, _ := NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(ca, caKey, "me ", "10.128.0.1/24", nil)
theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(ca, caKey, "them", "10.128.0.2/24", nil)
myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "me ", "10.128.0.1/24", nil)
theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "them", "10.128.0.2/24", nil)
// Put their info in our lighthouse and vice versa
myControl.InjectLightHouseAddr(theirVpnIpNet.Addr(), theirUdpAddr)
theirControl.InjectLightHouseAddr(myVpnIpNet.Addr(), myUdpAddr)
myControl.InjectLightHouseAddr(theirVpnIpNet[0].Addr(), theirUdpAddr)
theirControl.InjectLightHouseAddr(myVpnIpNet[0].Addr(), myUdpAddr)
// Build a router so we don't have to reason who gets which packet
r := router.NewR(t, myControl, theirControl)
@ -278,8 +279,8 @@ func TestStage1Race(t *testing.T) {
theirControl.Start()
t.Log("Trigger a handshake to start on both me and them")
myControl.InjectTunUDPPacket(theirVpnIpNet.Addr(), 80, 80, []byte("Hi from me"))
theirControl.InjectTunUDPPacket(myVpnIpNet.Addr(), 80, 80, []byte("Hi from them"))
myControl.InjectTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi from me"))
theirControl.InjectTunUDPPacket(myVpnIpNet[0].Addr(), 80, theirVpnIpNet[0].Addr(), 80, []byte("Hi from them"))
t.Log("Get both stage 1 handshake packets")
myHsForThem := myControl.GetFromUDP(true)
@ -291,14 +292,14 @@ func TestStage1Race(t *testing.T) {
r.Log("Route until they receive a message packet")
myCachedPacket := r.RouteForAllUntilTxTun(theirControl)
assertUdpPacket(t, []byte("Hi from me"), myCachedPacket, myVpnIpNet.Addr(), theirVpnIpNet.Addr(), 80, 80)
assertUdpPacket(t, []byte("Hi from me"), myCachedPacket, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), 80, 80)
r.Log("Their cached packet should be received by me")
theirCachedPacket := r.RouteForAllUntilTxTun(myControl)
assertUdpPacket(t, []byte("Hi from them"), theirCachedPacket, theirVpnIpNet.Addr(), myVpnIpNet.Addr(), 80, 80)
assertUdpPacket(t, []byte("Hi from them"), theirCachedPacket, theirVpnIpNet[0].Addr(), myVpnIpNet[0].Addr(), 80, 80)
r.Log("Do a bidirectional tunnel test")
assertTunnel(t, myVpnIpNet.Addr(), theirVpnIpNet.Addr(), myControl, theirControl, r)
assertTunnel(t, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), myControl, theirControl, r)
myHostmapHosts := myControl.ListHostmapHosts(false)
myHostmapIndexes := myControl.ListHostmapIndexes(false)
@ -316,7 +317,7 @@ func TestStage1Race(t *testing.T) {
r.Log("Spin until connection manager tears down a tunnel")
for len(myControl.GetHostmap().Indexes)+len(theirControl.GetHostmap().Indexes) > 2 {
assertTunnel(t, myVpnIpNet.Addr(), theirVpnIpNet.Addr(), myControl, theirControl, r)
assertTunnel(t, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), myControl, theirControl, r)
t.Log("Connection manager hasn't ticked yet")
time.Sleep(time.Second)
}
@ -339,12 +340,12 @@ func TestStage1Race(t *testing.T) {
func TestUncleanShutdownRaceLoser(t *testing.T) {
ca, _, caKey, _ := NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(ca, caKey, "me ", "10.128.0.1/24", nil)
theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(ca, caKey, "them", "10.128.0.2/24", nil)
myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "me ", "10.128.0.1/24", nil)
theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "them", "10.128.0.2/24", nil)
// Teach my how to get to the relay and that their can be reached via the relay
myControl.InjectLightHouseAddr(theirVpnIpNet.Addr(), theirUdpAddr)
theirControl.InjectLightHouseAddr(myVpnIpNet.Addr(), myUdpAddr)
myControl.InjectLightHouseAddr(theirVpnIpNet[0].Addr(), theirUdpAddr)
theirControl.InjectLightHouseAddr(myVpnIpNet[0].Addr(), myUdpAddr)
// Build a router so we don't have to reason who gets which packet
r := router.NewR(t, myControl, theirControl)
@ -355,10 +356,10 @@ func TestUncleanShutdownRaceLoser(t *testing.T) {
theirControl.Start()
r.Log("Trigger a handshake from me to them")
myControl.InjectTunUDPPacket(theirVpnIpNet.Addr(), 80, 80, []byte("Hi from me"))
myControl.InjectTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi from me"))
p := r.RouteForAllUntilTxTun(theirControl)
assertUdpPacket(t, []byte("Hi from me"), p, myVpnIpNet.Addr(), theirVpnIpNet.Addr(), 80, 80)
assertUdpPacket(t, []byte("Hi from me"), p, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), 80, 80)
r.Log("Nuke my hostmap")
myHostmap := myControl.GetHostmap()
@ -366,17 +367,17 @@ func TestUncleanShutdownRaceLoser(t *testing.T) {
myHostmap.Indexes = map[uint32]*nebula.HostInfo{}
myHostmap.RemoteIndexes = map[uint32]*nebula.HostInfo{}
myControl.InjectTunUDPPacket(theirVpnIpNet.Addr(), 80, 80, []byte("Hi from me again"))
myControl.InjectTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi from me again"))
p = r.RouteForAllUntilTxTun(theirControl)
assertUdpPacket(t, []byte("Hi from me again"), p, myVpnIpNet.Addr(), theirVpnIpNet.Addr(), 80, 80)
assertUdpPacket(t, []byte("Hi from me again"), p, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), 80, 80)
r.Log("Assert the tunnel works")
assertTunnel(t, myVpnIpNet.Addr(), theirVpnIpNet.Addr(), myControl, theirControl, r)
assertTunnel(t, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), myControl, theirControl, r)
r.Log("Wait for the dead index to go away")
start := len(theirControl.GetHostmap().Indexes)
for {
assertTunnel(t, myVpnIpNet.Addr(), theirVpnIpNet.Addr(), myControl, theirControl, r)
assertTunnel(t, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), myControl, theirControl, r)
if len(theirControl.GetHostmap().Indexes) < start {
break
}
@ -388,12 +389,12 @@ func TestUncleanShutdownRaceLoser(t *testing.T) {
func TestUncleanShutdownRaceWinner(t *testing.T) {
ca, _, caKey, _ := NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(ca, caKey, "me ", "10.128.0.1/24", nil)
theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(ca, caKey, "them", "10.128.0.2/24", nil)
myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "me ", "10.128.0.1/24", nil)
theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "them", "10.128.0.2/24", nil)
// Teach my how to get to the relay and that their can be reached via the relay
myControl.InjectLightHouseAddr(theirVpnIpNet.Addr(), theirUdpAddr)
theirControl.InjectLightHouseAddr(myVpnIpNet.Addr(), myUdpAddr)
myControl.InjectLightHouseAddr(theirVpnIpNet[0].Addr(), theirUdpAddr)
theirControl.InjectLightHouseAddr(myVpnIpNet[0].Addr(), myUdpAddr)
// Build a router so we don't have to reason who gets which packet
r := router.NewR(t, myControl, theirControl)
@ -404,10 +405,10 @@ func TestUncleanShutdownRaceWinner(t *testing.T) {
theirControl.Start()
r.Log("Trigger a handshake from me to them")
myControl.InjectTunUDPPacket(theirVpnIpNet.Addr(), 80, 80, []byte("Hi from me"))
myControl.InjectTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi from me"))
p := r.RouteForAllUntilTxTun(theirControl)
assertUdpPacket(t, []byte("Hi from me"), p, myVpnIpNet.Addr(), theirVpnIpNet.Addr(), 80, 80)
assertUdpPacket(t, []byte("Hi from me"), p, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), 80, 80)
r.RenderHostmaps("Final hostmaps", myControl, theirControl)
r.Log("Nuke my hostmap")
@ -416,18 +417,18 @@ func TestUncleanShutdownRaceWinner(t *testing.T) {
theirHostmap.Indexes = map[uint32]*nebula.HostInfo{}
theirHostmap.RemoteIndexes = map[uint32]*nebula.HostInfo{}
theirControl.InjectTunUDPPacket(myVpnIpNet.Addr(), 80, 80, []byte("Hi from them again"))
theirControl.InjectTunUDPPacket(myVpnIpNet[0].Addr(), 80, theirVpnIpNet[0].Addr(), 80, []byte("Hi from them again"))
p = r.RouteForAllUntilTxTun(myControl)
assertUdpPacket(t, []byte("Hi from them again"), p, theirVpnIpNet.Addr(), myVpnIpNet.Addr(), 80, 80)
assertUdpPacket(t, []byte("Hi from them again"), p, theirVpnIpNet[0].Addr(), myVpnIpNet[0].Addr(), 80, 80)
r.RenderHostmaps("Derp hostmaps", myControl, theirControl)
r.Log("Assert the tunnel works")
assertTunnel(t, myVpnIpNet.Addr(), theirVpnIpNet.Addr(), myControl, theirControl, r)
assertTunnel(t, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), myControl, theirControl, r)
r.Log("Wait for the dead index to go away")
start := len(myControl.GetHostmap().Indexes)
for {
assertTunnel(t, myVpnIpNet.Addr(), theirVpnIpNet.Addr(), myControl, theirControl, r)
assertTunnel(t, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), myControl, theirControl, r)
if len(myControl.GetHostmap().Indexes) < start {
break
}
@ -439,14 +440,14 @@ func TestUncleanShutdownRaceWinner(t *testing.T) {
func TestRelays(t *testing.T) {
ca, _, caKey, _ := NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
myControl, myVpnIpNet, _, _ := newSimpleServer(ca, caKey, "me ", "10.128.0.1/24", m{"relay": m{"use_relays": true}})
relayControl, relayVpnIpNet, relayUdpAddr, _ := newSimpleServer(ca, caKey, "relay ", "10.128.0.128/24", m{"relay": m{"am_relay": true}})
theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(ca, caKey, "them ", "10.128.0.2/24", m{"relay": m{"use_relays": true}})
myControl, myVpnIpNet, _, _ := newSimpleServer(cert.Version1, ca, caKey, "me ", "10.128.0.1/24", m{"relay": m{"use_relays": true}})
relayControl, relayVpnIpNet, relayUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "relay ", "10.128.0.128/24", m{"relay": m{"am_relay": true}})
theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "them ", "10.128.0.2/24", m{"relay": m{"use_relays": true}})
// Teach my how to get to the relay and that their can be reached via the relay
myControl.InjectLightHouseAddr(relayVpnIpNet.Addr(), relayUdpAddr)
myControl.InjectRelays(theirVpnIpNet.Addr(), []netip.Addr{relayVpnIpNet.Addr()})
relayControl.InjectLightHouseAddr(theirVpnIpNet.Addr(), theirUdpAddr)
myControl.InjectLightHouseAddr(relayVpnIpNet[0].Addr(), relayUdpAddr)
myControl.InjectRelays(theirVpnIpNet[0].Addr(), []netip.Addr{relayVpnIpNet[0].Addr()})
relayControl.InjectLightHouseAddr(theirVpnIpNet[0].Addr(), theirUdpAddr)
// Build a router so we don't have to reason who gets which packet
r := router.NewR(t, myControl, relayControl, theirControl)
@ -458,11 +459,11 @@ func TestRelays(t *testing.T) {
theirControl.Start()
t.Log("Trigger a handshake from me to them via the relay")
myControl.InjectTunUDPPacket(theirVpnIpNet.Addr(), 80, 80, []byte("Hi from me"))
myControl.InjectTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi from me"))
p := r.RouteForAllUntilTxTun(theirControl)
r.Log("Assert the tunnel works")
assertUdpPacket(t, []byte("Hi from me"), p, myVpnIpNet.Addr(), theirVpnIpNet.Addr(), 80, 80)
assertUdpPacket(t, []byte("Hi from me"), p, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), 80, 80)
r.RenderHostmaps("Final hostmaps", myControl, relayControl, theirControl)
//TODO: assert we actually used the relay even though it should be impossible for a tunnel to have occurred without it
}
@ -470,19 +471,19 @@ func TestRelays(t *testing.T) {
func TestStage1RaceRelays(t *testing.T) {
//NOTE: this is a race between me and relay resulting in a full tunnel from me to them via relay
ca, _, caKey, _ := NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(ca, caKey, "me ", "10.128.0.1/24", m{"relay": m{"use_relays": true}})
relayControl, relayVpnIpNet, relayUdpAddr, _ := newSimpleServer(ca, caKey, "relay ", "10.128.0.128/24", m{"relay": m{"am_relay": true}})
theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(ca, caKey, "them ", "10.128.0.2/24", m{"relay": m{"use_relays": true}})
myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "me ", "10.128.0.1/24", m{"relay": m{"use_relays": true}})
relayControl, relayVpnIpNet, relayUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "relay ", "10.128.0.128/24", m{"relay": m{"am_relay": true}})
theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "them ", "10.128.0.2/24", m{"relay": m{"use_relays": true}})
// Teach my how to get to the relay and that their can be reached via the relay
myControl.InjectLightHouseAddr(relayVpnIpNet.Addr(), relayUdpAddr)
theirControl.InjectLightHouseAddr(relayVpnIpNet.Addr(), relayUdpAddr)
myControl.InjectLightHouseAddr(relayVpnIpNet[0].Addr(), relayUdpAddr)
theirControl.InjectLightHouseAddr(relayVpnIpNet[0].Addr(), relayUdpAddr)
myControl.InjectRelays(theirVpnIpNet.Addr(), []netip.Addr{relayVpnIpNet.Addr()})
theirControl.InjectRelays(myVpnIpNet.Addr(), []netip.Addr{relayVpnIpNet.Addr()})
myControl.InjectRelays(theirVpnIpNet[0].Addr(), []netip.Addr{relayVpnIpNet[0].Addr()})
theirControl.InjectRelays(myVpnIpNet[0].Addr(), []netip.Addr{relayVpnIpNet[0].Addr()})
relayControl.InjectLightHouseAddr(theirVpnIpNet.Addr(), theirUdpAddr)
relayControl.InjectLightHouseAddr(myVpnIpNet.Addr(), myUdpAddr)
relayControl.InjectLightHouseAddr(theirVpnIpNet[0].Addr(), theirUdpAddr)
relayControl.InjectLightHouseAddr(myVpnIpNet[0].Addr(), myUdpAddr)
// Build a router so we don't have to reason who gets which packet
r := router.NewR(t, myControl, relayControl, theirControl)
@ -494,14 +495,14 @@ func TestStage1RaceRelays(t *testing.T) {
theirControl.Start()
r.Log("Get a tunnel between me and relay")
assertTunnel(t, myVpnIpNet.Addr(), relayVpnIpNet.Addr(), myControl, relayControl, r)
assertTunnel(t, myVpnIpNet[0].Addr(), relayVpnIpNet[0].Addr(), myControl, relayControl, r)
r.Log("Get a tunnel between them and relay")
assertTunnel(t, theirVpnIpNet.Addr(), relayVpnIpNet.Addr(), theirControl, relayControl, r)
assertTunnel(t, theirVpnIpNet[0].Addr(), relayVpnIpNet[0].Addr(), theirControl, relayControl, r)
r.Log("Trigger a handshake from both them and me via relay to them and me")
myControl.InjectTunUDPPacket(theirVpnIpNet.Addr(), 80, 80, []byte("Hi from me"))
theirControl.InjectTunUDPPacket(myVpnIpNet.Addr(), 80, 80, []byte("Hi from them"))
myControl.InjectTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi from me"))
theirControl.InjectTunUDPPacket(myVpnIpNet[0].Addr(), 80, theirVpnIpNet[0].Addr(), 80, []byte("Hi from them"))
r.Log("Wait for a packet from them to me")
p := r.RouteForAllUntilTxTun(myControl)
@ -519,20 +520,20 @@ func TestStage1RaceRelays(t *testing.T) {
func TestStage1RaceRelays2(t *testing.T) {
//NOTE: this is a race between me and relay resulting in a full tunnel from me to them via relay
ca, _, caKey, _ := NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(ca, caKey, "me ", "10.128.0.1/24", m{"relay": m{"use_relays": true}})
relayControl, relayVpnIpNet, relayUdpAddr, _ := newSimpleServer(ca, caKey, "relay ", "10.128.0.128/24", m{"relay": m{"am_relay": true}})
theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(ca, caKey, "them ", "10.128.0.2/24", m{"relay": m{"use_relays": true}})
myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "me ", "10.128.0.1/24", m{"relay": m{"use_relays": true}})
relayControl, relayVpnIpNet, relayUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "relay ", "10.128.0.128/24", m{"relay": m{"am_relay": true}})
theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "them ", "10.128.0.2/24", m{"relay": m{"use_relays": true}})
l := NewTestLogger()
// Teach my how to get to the relay and that their can be reached via the relay
myControl.InjectLightHouseAddr(relayVpnIpNet.Addr(), relayUdpAddr)
theirControl.InjectLightHouseAddr(relayVpnIpNet.Addr(), relayUdpAddr)
myControl.InjectLightHouseAddr(relayVpnIpNet[0].Addr(), relayUdpAddr)
theirControl.InjectLightHouseAddr(relayVpnIpNet[0].Addr(), relayUdpAddr)
myControl.InjectRelays(theirVpnIpNet.Addr(), []netip.Addr{relayVpnIpNet.Addr()})
theirControl.InjectRelays(myVpnIpNet.Addr(), []netip.Addr{relayVpnIpNet.Addr()})
myControl.InjectRelays(theirVpnIpNet[0].Addr(), []netip.Addr{relayVpnIpNet[0].Addr()})
theirControl.InjectRelays(myVpnIpNet[0].Addr(), []netip.Addr{relayVpnIpNet[0].Addr()})
relayControl.InjectLightHouseAddr(theirVpnIpNet.Addr(), theirUdpAddr)
relayControl.InjectLightHouseAddr(myVpnIpNet.Addr(), myUdpAddr)
relayControl.InjectLightHouseAddr(theirVpnIpNet[0].Addr(), theirUdpAddr)
relayControl.InjectLightHouseAddr(myVpnIpNet[0].Addr(), myUdpAddr)
// Build a router so we don't have to reason who gets which packet
r := router.NewR(t, myControl, relayControl, theirControl)
@ -545,16 +546,16 @@ func TestStage1RaceRelays2(t *testing.T) {
r.Log("Get a tunnel between me and relay")
l.Info("Get a tunnel between me and relay")
assertTunnel(t, myVpnIpNet.Addr(), relayVpnIpNet.Addr(), myControl, relayControl, r)
assertTunnel(t, myVpnIpNet[0].Addr(), relayVpnIpNet[0].Addr(), myControl, relayControl, r)
r.Log("Get a tunnel between them and relay")
l.Info("Get a tunnel between them and relay")
assertTunnel(t, theirVpnIpNet.Addr(), relayVpnIpNet.Addr(), theirControl, relayControl, r)
assertTunnel(t, theirVpnIpNet[0].Addr(), relayVpnIpNet[0].Addr(), theirControl, relayControl, r)
r.Log("Trigger a handshake from both them and me via relay to them and me")
l.Info("Trigger a handshake from both them and me via relay to them and me")
myControl.InjectTunUDPPacket(theirVpnIpNet.Addr(), 80, 80, []byte("Hi from me"))
theirControl.InjectTunUDPPacket(myVpnIpNet.Addr(), 80, 80, []byte("Hi from them"))
myControl.InjectTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi from me"))
theirControl.InjectTunUDPPacket(myVpnIpNet[0].Addr(), 80, theirVpnIpNet[0].Addr(), 80, []byte("Hi from them"))
//r.RouteUntilAfterMsgType(myControl, header.Control, header.MessageNone)
//r.RouteUntilAfterMsgType(theirControl, header.Control, header.MessageNone)
@ -567,7 +568,7 @@ func TestStage1RaceRelays2(t *testing.T) {
r.Log("Assert the tunnel works")
l.Info("Assert the tunnel works")
assertTunnel(t, theirVpnIpNet.Addr(), myVpnIpNet.Addr(), theirControl, myControl, r)
assertTunnel(t, theirVpnIpNet[0].Addr(), myVpnIpNet[0].Addr(), theirControl, myControl, r)
t.Log("Wait until we remove extra tunnels")
l.Info("Wait until we remove extra tunnels")
@ -587,7 +588,7 @@ func TestStage1RaceRelays2(t *testing.T) {
"theirControl": len(theirControl.GetHostmap().Indexes),
"relayControl": len(relayControl.GetHostmap().Indexes),
}).Info("Waiting for hostinfos to be removed...")
assertTunnel(t, myVpnIpNet.Addr(), theirVpnIpNet.Addr(), myControl, theirControl, r)
assertTunnel(t, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), myControl, theirControl, r)
t.Log("Connection manager hasn't ticked yet")
time.Sleep(time.Second)
retries--
@ -595,7 +596,7 @@ func TestStage1RaceRelays2(t *testing.T) {
r.Log("Assert the tunnel works")
l.Info("Assert the tunnel works")
assertTunnel(t, theirVpnIpNet.Addr(), myVpnIpNet.Addr(), theirControl, myControl, r)
assertTunnel(t, theirVpnIpNet[0].Addr(), myVpnIpNet[0].Addr(), theirControl, myControl, r)
myControl.Stop()
theirControl.Stop()
@ -607,14 +608,14 @@ func TestStage1RaceRelays2(t *testing.T) {
func TestRehandshakingRelays(t *testing.T) {
ca, _, caKey, _ := NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
myControl, myVpnIpNet, _, _ := newSimpleServer(ca, caKey, "me ", "10.128.0.1/24", m{"relay": m{"use_relays": true}})
relayControl, relayVpnIpNet, relayUdpAddr, relayConfig := newSimpleServer(ca, caKey, "relay ", "10.128.0.128/24", m{"relay": m{"am_relay": true}})
theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(ca, caKey, "them ", "10.128.0.2/24", m{"relay": m{"use_relays": true}})
myControl, myVpnIpNet, _, _ := newSimpleServer(cert.Version1, ca, caKey, "me ", "10.128.0.1/24", m{"relay": m{"use_relays": true}})
relayControl, relayVpnIpNet, relayUdpAddr, relayConfig := newSimpleServer(cert.Version1, ca, caKey, "relay ", "10.128.0.128/24", m{"relay": m{"am_relay": true}})
theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "them ", "10.128.0.2/24", m{"relay": m{"use_relays": true}})
// Teach my how to get to the relay and that their can be reached via the relay
myControl.InjectLightHouseAddr(relayVpnIpNet.Addr(), relayUdpAddr)
myControl.InjectRelays(theirVpnIpNet.Addr(), []netip.Addr{relayVpnIpNet.Addr()})
relayControl.InjectLightHouseAddr(theirVpnIpNet.Addr(), theirUdpAddr)
myControl.InjectLightHouseAddr(relayVpnIpNet[0].Addr(), relayUdpAddr)
myControl.InjectRelays(theirVpnIpNet[0].Addr(), []netip.Addr{relayVpnIpNet[0].Addr()})
relayControl.InjectLightHouseAddr(theirVpnIpNet[0].Addr(), theirUdpAddr)
// Build a router so we don't have to reason who gets which packet
r := router.NewR(t, myControl, relayControl, theirControl)
@ -626,17 +627,17 @@ func TestRehandshakingRelays(t *testing.T) {
theirControl.Start()
t.Log("Trigger a handshake from me to them via the relay")
myControl.InjectTunUDPPacket(theirVpnIpNet.Addr(), 80, 80, []byte("Hi from me"))
myControl.InjectTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi from me"))
p := r.RouteForAllUntilTxTun(theirControl)
r.Log("Assert the tunnel works")
assertUdpPacket(t, []byte("Hi from me"), p, myVpnIpNet.Addr(), theirVpnIpNet.Addr(), 80, 80)
assertUdpPacket(t, []byte("Hi from me"), p, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), 80, 80)
r.RenderHostmaps("working hostmaps", myControl, relayControl, theirControl)
// When I update the certificate for the relay, both me and them will have 2 host infos for the relay,
// and the main host infos will not have any relay state to handle the me<->relay<->them tunnel.
r.Log("Renew relay certificate and spin until me and them sees it")
_, _, myNextPrivKey, myNextPEM := NewTestCert(ca, caKey, "relay", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{relayVpnIpNet}, nil, []string{"new group"})
_, _, myNextPrivKey, myNextPEM := NewTestCert(cert.Version1, ca, caKey, "relay", time.Now(), time.Now().Add(5*time.Minute), relayVpnIpNet, nil, []string{"new group"})
caB, err := ca.MarshalPEM()
if err != nil {
@ -654,8 +655,8 @@ func TestRehandshakingRelays(t *testing.T) {
for {
r.Log("Assert the tunnel works between myVpnIpNet and relayVpnIpNet")
assertTunnel(t, myVpnIpNet.Addr(), relayVpnIpNet.Addr(), myControl, relayControl, r)
c := myControl.GetHostInfoByVpnIp(relayVpnIpNet.Addr(), false)
assertTunnel(t, myVpnIpNet[0].Addr(), relayVpnIpNet[0].Addr(), myControl, relayControl, r)
c := myControl.GetHostInfoByVpnAddr(relayVpnIpNet[0].Addr(), false)
if len(c.Cert.Groups()) != 0 {
// We have a new certificate now
r.Log("Certificate between my and relay is updated!")
@ -667,8 +668,8 @@ func TestRehandshakingRelays(t *testing.T) {
for {
r.Log("Assert the tunnel works between theirVpnIpNet and relayVpnIpNet")
assertTunnel(t, theirVpnIpNet.Addr(), relayVpnIpNet.Addr(), theirControl, relayControl, r)
c := theirControl.GetHostInfoByVpnIp(relayVpnIpNet.Addr(), false)
assertTunnel(t, theirVpnIpNet[0].Addr(), relayVpnIpNet[0].Addr(), theirControl, relayControl, r)
c := theirControl.GetHostInfoByVpnAddr(relayVpnIpNet[0].Addr(), false)
if len(c.Cert.Groups()) != 0 {
// We have a new certificate now
r.Log("Certificate between their and relay is updated!")
@ -679,13 +680,13 @@ func TestRehandshakingRelays(t *testing.T) {
}
r.Log("Assert the relay tunnel still works")
assertTunnel(t, theirVpnIpNet.Addr(), myVpnIpNet.Addr(), theirControl, myControl, r)
assertTunnel(t, theirVpnIpNet[0].Addr(), myVpnIpNet[0].Addr(), theirControl, myControl, r)
r.RenderHostmaps("working hostmaps", myControl, relayControl, theirControl)
// We should have two hostinfos on all sides
for len(myControl.GetHostmap().Indexes) != 2 {
t.Logf("Waiting for myControl hostinfos (%v != 2) to get cleaned up from lack of use...", len(myControl.GetHostmap().Indexes))
r.Log("Assert the relay tunnel still works")
assertTunnel(t, theirVpnIpNet.Addr(), myVpnIpNet.Addr(), theirControl, myControl, r)
assertTunnel(t, theirVpnIpNet[0].Addr(), myVpnIpNet[0].Addr(), theirControl, myControl, r)
r.Log("yupitdoes")
time.Sleep(time.Second)
}
@ -693,7 +694,7 @@ func TestRehandshakingRelays(t *testing.T) {
for len(theirControl.GetHostmap().Indexes) != 2 {
t.Logf("Waiting for theirControl hostinfos (%v != 2) to get cleaned up from lack of use...", len(theirControl.GetHostmap().Indexes))
r.Log("Assert the relay tunnel still works")
assertTunnel(t, theirVpnIpNet.Addr(), myVpnIpNet.Addr(), theirControl, myControl, r)
assertTunnel(t, theirVpnIpNet[0].Addr(), myVpnIpNet[0].Addr(), theirControl, myControl, r)
r.Log("yupitdoes")
time.Sleep(time.Second)
}
@ -701,7 +702,7 @@ func TestRehandshakingRelays(t *testing.T) {
for len(relayControl.GetHostmap().Indexes) != 2 {
t.Logf("Waiting for relayControl hostinfos (%v != 2) to get cleaned up from lack of use...", len(relayControl.GetHostmap().Indexes))
r.Log("Assert the relay tunnel still works")
assertTunnel(t, theirVpnIpNet.Addr(), myVpnIpNet.Addr(), theirControl, myControl, r)
assertTunnel(t, theirVpnIpNet[0].Addr(), myVpnIpNet[0].Addr(), theirControl, myControl, r)
r.Log("yupitdoes")
time.Sleep(time.Second)
}
@ -711,14 +712,14 @@ func TestRehandshakingRelays(t *testing.T) {
func TestRehandshakingRelaysPrimary(t *testing.T) {
// This test is the same as TestRehandshakingRelays but one of the terminal types is a primary swap winner
ca, _, caKey, _ := NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
myControl, myVpnIpNet, _, _ := newSimpleServer(ca, caKey, "me ", "10.128.0.128/24", m{"relay": m{"use_relays": true}})
relayControl, relayVpnIpNet, relayUdpAddr, relayConfig := newSimpleServer(ca, caKey, "relay ", "10.128.0.1/24", m{"relay": m{"am_relay": true}})
theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(ca, caKey, "them ", "10.128.0.2/24", m{"relay": m{"use_relays": true}})
myControl, myVpnIpNet, _, _ := newSimpleServer(cert.Version1, ca, caKey, "me ", "10.128.0.128/24", m{"relay": m{"use_relays": true}})
relayControl, relayVpnIpNet, relayUdpAddr, relayConfig := newSimpleServer(cert.Version1, ca, caKey, "relay ", "10.128.0.1/24", m{"relay": m{"am_relay": true}})
theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "them ", "10.128.0.2/24", m{"relay": m{"use_relays": true}})
// Teach my how to get to the relay and that their can be reached via the relay
myControl.InjectLightHouseAddr(relayVpnIpNet.Addr(), relayUdpAddr)
myControl.InjectRelays(theirVpnIpNet.Addr(), []netip.Addr{relayVpnIpNet.Addr()})
relayControl.InjectLightHouseAddr(theirVpnIpNet.Addr(), theirUdpAddr)
myControl.InjectLightHouseAddr(relayVpnIpNet[0].Addr(), relayUdpAddr)
myControl.InjectRelays(theirVpnIpNet[0].Addr(), []netip.Addr{relayVpnIpNet[0].Addr()})
relayControl.InjectLightHouseAddr(theirVpnIpNet[0].Addr(), theirUdpAddr)
// Build a router so we don't have to reason who gets which packet
r := router.NewR(t, myControl, relayControl, theirControl)
@ -730,17 +731,17 @@ func TestRehandshakingRelaysPrimary(t *testing.T) {
theirControl.Start()
t.Log("Trigger a handshake from me to them via the relay")
myControl.InjectTunUDPPacket(theirVpnIpNet.Addr(), 80, 80, []byte("Hi from me"))
myControl.InjectTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi from me"))
p := r.RouteForAllUntilTxTun(theirControl)
r.Log("Assert the tunnel works")
assertUdpPacket(t, []byte("Hi from me"), p, myVpnIpNet.Addr(), theirVpnIpNet.Addr(), 80, 80)
assertUdpPacket(t, []byte("Hi from me"), p, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), 80, 80)
r.RenderHostmaps("working hostmaps", myControl, relayControl, theirControl)
// When I update the certificate for the relay, both me and them will have 2 host infos for the relay,
// and the main host infos will not have any relay state to handle the me<->relay<->them tunnel.
r.Log("Renew relay certificate and spin until me and them sees it")
_, _, myNextPrivKey, myNextPEM := NewTestCert(ca, caKey, "relay", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{relayVpnIpNet}, nil, []string{"new group"})
_, _, myNextPrivKey, myNextPEM := NewTestCert(cert.Version1, ca, caKey, "relay", time.Now(), time.Now().Add(5*time.Minute), relayVpnIpNet, nil, []string{"new group"})
caB, err := ca.MarshalPEM()
if err != nil {
@ -758,8 +759,8 @@ func TestRehandshakingRelaysPrimary(t *testing.T) {
for {
r.Log("Assert the tunnel works between myVpnIpNet and relayVpnIpNet")
assertTunnel(t, myVpnIpNet.Addr(), relayVpnIpNet.Addr(), myControl, relayControl, r)
c := myControl.GetHostInfoByVpnIp(relayVpnIpNet.Addr(), false)
assertTunnel(t, myVpnIpNet[0].Addr(), relayVpnIpNet[0].Addr(), myControl, relayControl, r)
c := myControl.GetHostInfoByVpnAddr(relayVpnIpNet[0].Addr(), false)
if len(c.Cert.Groups()) != 0 {
// We have a new certificate now
r.Log("Certificate between my and relay is updated!")
@ -771,8 +772,8 @@ func TestRehandshakingRelaysPrimary(t *testing.T) {
for {
r.Log("Assert the tunnel works between theirVpnIpNet and relayVpnIpNet")
assertTunnel(t, theirVpnIpNet.Addr(), relayVpnIpNet.Addr(), theirControl, relayControl, r)
c := theirControl.GetHostInfoByVpnIp(relayVpnIpNet.Addr(), false)
assertTunnel(t, theirVpnIpNet[0].Addr(), relayVpnIpNet[0].Addr(), theirControl, relayControl, r)
c := theirControl.GetHostInfoByVpnAddr(relayVpnIpNet[0].Addr(), false)
if len(c.Cert.Groups()) != 0 {
// We have a new certificate now
r.Log("Certificate between their and relay is updated!")
@ -783,13 +784,13 @@ func TestRehandshakingRelaysPrimary(t *testing.T) {
}
r.Log("Assert the relay tunnel still works")
assertTunnel(t, theirVpnIpNet.Addr(), myVpnIpNet.Addr(), theirControl, myControl, r)
assertTunnel(t, theirVpnIpNet[0].Addr(), myVpnIpNet[0].Addr(), theirControl, myControl, r)
r.RenderHostmaps("working hostmaps", myControl, relayControl, theirControl)
// We should have two hostinfos on all sides
for len(myControl.GetHostmap().Indexes) != 2 {
t.Logf("Waiting for myControl hostinfos (%v != 2) to get cleaned up from lack of use...", len(myControl.GetHostmap().Indexes))
r.Log("Assert the relay tunnel still works")
assertTunnel(t, theirVpnIpNet.Addr(), myVpnIpNet.Addr(), theirControl, myControl, r)
assertTunnel(t, theirVpnIpNet[0].Addr(), myVpnIpNet[0].Addr(), theirControl, myControl, r)
r.Log("yupitdoes")
time.Sleep(time.Second)
}
@ -797,7 +798,7 @@ func TestRehandshakingRelaysPrimary(t *testing.T) {
for len(theirControl.GetHostmap().Indexes) != 2 {
t.Logf("Waiting for theirControl hostinfos (%v != 2) to get cleaned up from lack of use...", len(theirControl.GetHostmap().Indexes))
r.Log("Assert the relay tunnel still works")
assertTunnel(t, theirVpnIpNet.Addr(), myVpnIpNet.Addr(), theirControl, myControl, r)
assertTunnel(t, theirVpnIpNet[0].Addr(), myVpnIpNet[0].Addr(), theirControl, myControl, r)
r.Log("yupitdoes")
time.Sleep(time.Second)
}
@ -805,7 +806,7 @@ func TestRehandshakingRelaysPrimary(t *testing.T) {
for len(relayControl.GetHostmap().Indexes) != 2 {
t.Logf("Waiting for relayControl hostinfos (%v != 2) to get cleaned up from lack of use...", len(relayControl.GetHostmap().Indexes))
r.Log("Assert the relay tunnel still works")
assertTunnel(t, theirVpnIpNet.Addr(), myVpnIpNet.Addr(), theirControl, myControl, r)
assertTunnel(t, theirVpnIpNet[0].Addr(), myVpnIpNet[0].Addr(), theirControl, myControl, r)
r.Log("yupitdoes")
time.Sleep(time.Second)
}
@ -814,12 +815,12 @@ func TestRehandshakingRelaysPrimary(t *testing.T) {
func TestRehandshaking(t *testing.T) {
ca, _, caKey, _ := NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
myControl, myVpnIpNet, myUdpAddr, myConfig := newSimpleServer(ca, caKey, "me ", "10.128.0.2/24", nil)
theirControl, theirVpnIpNet, theirUdpAddr, theirConfig := newSimpleServer(ca, caKey, "them", "10.128.0.1/24", nil)
myControl, myVpnIpNet, myUdpAddr, myConfig := newSimpleServer(cert.Version1, ca, caKey, "me ", "10.128.0.2/24", nil)
theirControl, theirVpnIpNet, theirUdpAddr, theirConfig := newSimpleServer(cert.Version1, ca, caKey, "them", "10.128.0.1/24", nil)
// Put their info in our lighthouse and vice versa
myControl.InjectLightHouseAddr(theirVpnIpNet.Addr(), theirUdpAddr)
theirControl.InjectLightHouseAddr(myVpnIpNet.Addr(), myUdpAddr)
myControl.InjectLightHouseAddr(theirVpnIpNet[0].Addr(), theirUdpAddr)
theirControl.InjectLightHouseAddr(myVpnIpNet[0].Addr(), myUdpAddr)
// Build a router so we don't have to reason who gets which packet
r := router.NewR(t, myControl, theirControl)
@ -830,12 +831,12 @@ func TestRehandshaking(t *testing.T) {
theirControl.Start()
t.Log("Stand up a tunnel between me and them")
assertTunnel(t, myVpnIpNet.Addr(), theirVpnIpNet.Addr(), myControl, theirControl, r)
assertTunnel(t, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), myControl, theirControl, r)
r.RenderHostmaps("Starting hostmaps", myControl, theirControl)
r.Log("Renew my certificate and spin until their sees it")
_, _, myNextPrivKey, myNextPEM := NewTestCert(ca, caKey, "me", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{myVpnIpNet}, nil, []string{"new group"})
_, _, myNextPrivKey, myNextPEM := NewTestCert(cert.Version1, ca, caKey, "me", time.Now(), time.Now().Add(5*time.Minute), myVpnIpNet, nil, []string{"new group"})
caB, err := ca.MarshalPEM()
if err != nil {
@ -852,8 +853,8 @@ func TestRehandshaking(t *testing.T) {
myConfig.ReloadConfigString(string(rc))
for {
assertTunnel(t, myVpnIpNet.Addr(), theirVpnIpNet.Addr(), myControl, theirControl, r)
c := theirControl.GetHostInfoByVpnIp(myVpnIpNet.Addr(), false)
assertTunnel(t, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), myControl, theirControl, r)
c := theirControl.GetHostInfoByVpnAddr(myVpnIpNet[0].Addr(), false)
if len(c.Cert.Groups()) != 0 {
// We have a new certificate now
break
@ -880,19 +881,19 @@ func TestRehandshaking(t *testing.T) {
r.Log("Spin until there is only 1 tunnel")
for len(myControl.GetHostmap().Indexes)+len(theirControl.GetHostmap().Indexes) > 2 {
assertTunnel(t, myVpnIpNet.Addr(), theirVpnIpNet.Addr(), myControl, theirControl, r)
assertTunnel(t, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), myControl, theirControl, r)
t.Log("Connection manager hasn't ticked yet")
time.Sleep(time.Second)
}
assertTunnel(t, myVpnIpNet.Addr(), theirVpnIpNet.Addr(), myControl, theirControl, r)
assertTunnel(t, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), myControl, theirControl, r)
myFinalHostmapHosts := myControl.ListHostmapHosts(false)
myFinalHostmapIndexes := myControl.ListHostmapIndexes(false)
theirFinalHostmapHosts := theirControl.ListHostmapHosts(false)
theirFinalHostmapIndexes := theirControl.ListHostmapIndexes(false)
// Make sure the correct tunnel won
c := theirControl.GetHostInfoByVpnIp(myVpnIpNet.Addr(), false)
c := theirControl.GetHostInfoByVpnAddr(myVpnIpNet[0].Addr(), false)
assert.Contains(t, c.Cert.Groups(), "new group")
// We should only have a single tunnel now on both sides
@ -911,12 +912,12 @@ func TestRehandshakingLoser(t *testing.T) {
// The purpose of this test is that the race loser renews their certificate and rehandshakes. The final tunnel
// Should be the one with the new certificate
ca, _, caKey, _ := NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
myControl, myVpnIpNet, myUdpAddr, myConfig := newSimpleServer(ca, caKey, "me ", "10.128.0.2/24", nil)
theirControl, theirVpnIpNet, theirUdpAddr, theirConfig := newSimpleServer(ca, caKey, "them", "10.128.0.1/24", nil)
myControl, myVpnIpNet, myUdpAddr, myConfig := newSimpleServer(cert.Version1, ca, caKey, "me ", "10.128.0.2/24", nil)
theirControl, theirVpnIpNet, theirUdpAddr, theirConfig := newSimpleServer(cert.Version1, ca, caKey, "them", "10.128.0.1/24", nil)
// Put their info in our lighthouse and vice versa
myControl.InjectLightHouseAddr(theirVpnIpNet.Addr(), theirUdpAddr)
theirControl.InjectLightHouseAddr(myVpnIpNet.Addr(), myUdpAddr)
myControl.InjectLightHouseAddr(theirVpnIpNet[0].Addr(), theirUdpAddr)
theirControl.InjectLightHouseAddr(myVpnIpNet[0].Addr(), myUdpAddr)
// Build a router so we don't have to reason who gets which packet
r := router.NewR(t, myControl, theirControl)
@ -927,16 +928,12 @@ func TestRehandshakingLoser(t *testing.T) {
theirControl.Start()
t.Log("Stand up a tunnel between me and them")
assertTunnel(t, myVpnIpNet.Addr(), theirVpnIpNet.Addr(), myControl, theirControl, r)
tt1 := myControl.GetHostInfoByVpnIp(theirVpnIpNet.Addr(), false)
tt2 := theirControl.GetHostInfoByVpnIp(myVpnIpNet.Addr(), false)
fmt.Println(tt1.LocalIndex, tt2.LocalIndex)
assertTunnel(t, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), myControl, theirControl, r)
r.RenderHostmaps("Starting hostmaps", myControl, theirControl)
r.Log("Renew their certificate and spin until mine sees it")
_, _, theirNextPrivKey, theirNextPEM := NewTestCert(ca, caKey, "them", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{theirVpnIpNet}, nil, []string{"their new group"})
_, _, theirNextPrivKey, theirNextPEM := NewTestCert(cert.Version1, ca, caKey, "them", time.Now(), time.Now().Add(5*time.Minute), theirVpnIpNet, nil, []string{"their new group"})
caB, err := ca.MarshalPEM()
if err != nil {
@ -953,8 +950,8 @@ func TestRehandshakingLoser(t *testing.T) {
theirConfig.ReloadConfigString(string(rc))
for {
assertTunnel(t, myVpnIpNet.Addr(), theirVpnIpNet.Addr(), myControl, theirControl, r)
theirCertInMe := myControl.GetHostInfoByVpnIp(theirVpnIpNet.Addr(), false)
assertTunnel(t, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), myControl, theirControl, r)
theirCertInMe := myControl.GetHostInfoByVpnAddr(theirVpnIpNet[0].Addr(), false)
if slices.Contains(theirCertInMe.Cert.Groups(), "their new group") {
break
@ -980,19 +977,19 @@ func TestRehandshakingLoser(t *testing.T) {
r.Log("Spin until there is only 1 tunnel")
for len(myControl.GetHostmap().Indexes)+len(theirControl.GetHostmap().Indexes) > 2 {
assertTunnel(t, myVpnIpNet.Addr(), theirVpnIpNet.Addr(), myControl, theirControl, r)
assertTunnel(t, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), myControl, theirControl, r)
t.Log("Connection manager hasn't ticked yet")
time.Sleep(time.Second)
}
assertTunnel(t, myVpnIpNet.Addr(), theirVpnIpNet.Addr(), myControl, theirControl, r)
assertTunnel(t, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), myControl, theirControl, r)
myFinalHostmapHosts := myControl.ListHostmapHosts(false)
myFinalHostmapIndexes := myControl.ListHostmapIndexes(false)
theirFinalHostmapHosts := theirControl.ListHostmapHosts(false)
theirFinalHostmapIndexes := theirControl.ListHostmapIndexes(false)
// Make sure the correct tunnel won
theirCertInMe := myControl.GetHostInfoByVpnIp(theirVpnIpNet.Addr(), false)
theirCertInMe := myControl.GetHostInfoByVpnAddr(theirVpnIpNet[0].Addr(), false)
assert.Contains(t, theirCertInMe.Cert.Groups(), "their new group")
// We should only have a single tunnel now on both sides
@ -1011,12 +1008,12 @@ func TestRaceRegression(t *testing.T) {
// We had a bug where we were not finding the duplicate handshake and responding to the final stage 1 which
// caused a cross-linked hostinfo
ca, _, caKey, _ := NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(ca, caKey, "me", "10.128.0.1/24", nil)
theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(ca, caKey, "them", "10.128.0.2/24", nil)
myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "me", "10.128.0.1/24", nil)
theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "them", "10.128.0.2/24", nil)
// Put their info in our lighthouse
myControl.InjectLightHouseAddr(theirVpnIpNet.Addr(), theirUdpAddr)
theirControl.InjectLightHouseAddr(myVpnIpNet.Addr(), myUdpAddr)
myControl.InjectLightHouseAddr(theirVpnIpNet[0].Addr(), theirUdpAddr)
theirControl.InjectLightHouseAddr(myVpnIpNet[0].Addr(), myUdpAddr)
// Start the servers
myControl.Start()
@ -1030,8 +1027,8 @@ func TestRaceRegression(t *testing.T) {
//them rx stage:2 initiatorIndex=120607833 responderIndex=4209862089
t.Log("Start both handshakes")
myControl.InjectTunUDPPacket(theirVpnIpNet.Addr(), 80, 80, []byte("Hi from me"))
theirControl.InjectTunUDPPacket(myVpnIpNet.Addr(), 80, 80, []byte("Hi from them"))
myControl.InjectTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi from me"))
theirControl.InjectTunUDPPacket(myVpnIpNet[0].Addr(), 80, theirVpnIpNet[0].Addr(), 80, []byte("Hi from them"))
t.Log("Get both stage 1")
myStage1ForThem := myControl.GetFromUDP(true)
@ -1061,12 +1058,52 @@ func TestRaceRegression(t *testing.T) {
r.RenderHostmaps("Starting hostmaps", myControl, theirControl)
t.Log("Make sure the tunnel still works")
assertTunnel(t, myVpnIpNet.Addr(), theirVpnIpNet.Addr(), myControl, theirControl, r)
assertTunnel(t, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), myControl, theirControl, r)
myControl.Stop()
theirControl.Stop()
}
func TestV2NonPrimaryWithLighthouse(t *testing.T) {
ca, _, caKey, _ := NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
lhControl, lhVpnIpNet, lhUdpAddr, _ := newSimpleServer(cert.Version2, ca, caKey, "lh ", "10.128.0.1/24, ff::1/64", m{"lighthouse": m{"am_lighthouse": true}})
o := m{
"static_host_map": m{
lhVpnIpNet[1].Addr().String(): []string{lhUdpAddr.String()},
},
"lighthouse": m{
"hosts": []string{lhVpnIpNet[1].Addr().String()},
"local_allow_list": m{
// Try and block our lighthouse updates from using the actual addresses assigned to this computer
// If we start discovering addresses the test router doesn't know about then test traffic cant flow
"10.0.0.0/24": true,
"::/0": false,
},
},
}
myControl, myVpnIpNet, _, _ := newSimpleServer(cert.Version2, ca, caKey, "me ", "10.128.0.2/24, ff::2/64", o)
theirControl, theirVpnIpNet, _, _ := newSimpleServer(cert.Version2, ca, caKey, "them", "10.128.0.3/24, ff::3/64", o)
// Build a router so we don't have to reason who gets which packet
r := router.NewR(t, lhControl, myControl, theirControl)
defer r.RenderFlow()
// Start the servers
lhControl.Start()
myControl.Start()
theirControl.Start()
t.Log("Stand up an ipv6 tunnel between me and them")
assert.True(t, myVpnIpNet[1].Addr().Is6())
assert.True(t, theirVpnIpNet[1].Addr().Is6())
assertTunnel(t, myVpnIpNet[1].Addr(), theirVpnIpNet[1].Addr(), myControl, theirControl, r)
lhControl.Stop()
myControl.Stop()
theirControl.Stop()
}
//TODO: test
// Race winner renews and handshakes
// Race loser renews and handshakes

View File

@ -48,7 +48,7 @@ func NewTestCaCert(before, after time.Time, networks, unsafeNetworks []netip.Pre
// NewTestCert will generate a signed certificate with the provided details.
// Expiry times are defaulted if you do not pass them in
func NewTestCert(ca cert.Certificate, key []byte, name string, before, after time.Time, networks, unsafeNetworks []netip.Prefix, groups []string) (cert.Certificate, []byte, []byte, []byte) {
func NewTestCert(v cert.Version, ca cert.Certificate, key []byte, name string, before, after time.Time, networks, unsafeNetworks []netip.Prefix, groups []string) (cert.Certificate, []byte, []byte, []byte) {
if before.IsZero() {
before = time.Now().Add(time.Second * -60).Round(time.Second)
}
@ -59,7 +59,7 @@ func NewTestCert(ca cert.Certificate, key []byte, name string, before, after tim
pub, rawPriv := x25519Keypair()
nc := &cert.TBSCertificate{
Version: cert.Version1,
Version: v,
Name: name,
Networks: networks,
UnsafeNetworks: unsafeNetworks,

View File

@ -8,6 +8,7 @@ import (
"io"
"net/netip"
"os"
"strings"
"testing"
"time"
@ -26,25 +27,35 @@ import (
type m map[string]interface{}
// newSimpleServer creates a nebula instance with many assumptions
func newSimpleServer(caCrt cert.Certificate, caKey []byte, name string, sVpnIpNet string, overrides m) (*nebula.Control, netip.Prefix, netip.AddrPort, *config.C) {
func newSimpleServer(v cert.Version, caCrt cert.Certificate, caKey []byte, name string, sVpnNetworks string, overrides m) (*nebula.Control, []netip.Prefix, netip.AddrPort, *config.C) {
l := NewTestLogger()
vpnIpNet, err := netip.ParsePrefix(sVpnIpNet)
if err != nil {
panic(err)
var vpnNetworks []netip.Prefix
for _, sn := range strings.Split(sVpnNetworks, ",") {
vpnIpNet, err := netip.ParsePrefix(strings.TrimSpace(sn))
if err != nil {
panic(err)
}
vpnNetworks = append(vpnNetworks, vpnIpNet)
}
if len(vpnNetworks) == 0 {
panic("no vpn networks")
}
var udpAddr netip.AddrPort
if vpnIpNet.Addr().Is4() {
budpIp := vpnIpNet.Addr().As4()
if vpnNetworks[0].Addr().Is4() {
budpIp := vpnNetworks[0].Addr().As4()
budpIp[1] -= 128
udpAddr = netip.AddrPortFrom(netip.AddrFrom4(budpIp), 4242)
} else {
budpIp := vpnIpNet.Addr().As16()
budpIp[13] -= 128
budpIp := vpnNetworks[0].Addr().As16()
// beef for funsies
budpIp[2] = 190
budpIp[3] = 239
udpAddr = netip.AddrPortFrom(netip.AddrFrom16(budpIp), 4242)
}
_, _, myPrivKey, myPEM := NewTestCert(caCrt, caKey, name, time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{vpnIpNet}, nil, []string{})
_, _, myPrivKey, myPEM := NewTestCert(v, caCrt, caKey, name, time.Now(), time.Now().Add(5*time.Minute), vpnNetworks, nil, []string{})
caB, err := caCrt.MarshalPEM()
if err != nil {
@ -88,11 +99,16 @@ func newSimpleServer(caCrt cert.Certificate, caKey []byte, name string, sVpnIpNe
}
if overrides != nil {
err = mergo.Merge(&overrides, mc, mergo.WithAppendSlice)
final := m{}
err = mergo.Merge(&final, overrides, mergo.WithAppendSlice)
if err != nil {
panic(err)
}
mc = overrides
err = mergo.Merge(&final, mc, mergo.WithAppendSlice)
if err != nil {
panic(err)
}
mc = final
}
cb, err := yaml.Marshal(mc)
@ -109,7 +125,7 @@ func newSimpleServer(caCrt cert.Certificate, caKey []byte, name string, sVpnIpNe
panic(err)
}
return control, vpnIpNet, udpAddr, c
return control, vpnNetworks, udpAddr, c
}
type doneCb func()
@ -132,27 +148,28 @@ func deadline(t *testing.T, seconds time.Duration) doneCb {
func assertTunnel(t *testing.T, vpnIpA, vpnIpB netip.Addr, controlA, controlB *nebula.Control, r *router.R) {
// Send a packet from them to me
controlB.InjectTunUDPPacket(vpnIpA, 80, 90, []byte("Hi from B"))
controlB.InjectTunUDPPacket(vpnIpA, 80, vpnIpB, 90, []byte("Hi from B"))
bPacket := r.RouteForAllUntilTxTun(controlA)
assertUdpPacket(t, []byte("Hi from B"), bPacket, vpnIpB, vpnIpA, 90, 80)
// And once more from me to them
controlA.InjectTunUDPPacket(vpnIpB, 80, 90, []byte("Hello from A"))
controlA.InjectTunUDPPacket(vpnIpB, 80, vpnIpA, 90, []byte("Hello from A"))
aPacket := r.RouteForAllUntilTxTun(controlB)
assertUdpPacket(t, []byte("Hello from A"), aPacket, vpnIpA, vpnIpB, 90, 80)
}
func assertHostInfoPair(t *testing.T, addrA, addrB netip.AddrPort, vpnIpA, vpnIpB netip.Addr, controlA, controlB *nebula.Control) {
func assertHostInfoPair(t *testing.T, addrA, addrB netip.AddrPort, vpnNetsA, vpnNetsB []netip.Prefix, controlA, controlB *nebula.Control) {
// Get both host infos
hBinA := controlA.GetHostInfoByVpnIp(vpnIpB, false)
assert.NotNil(t, hBinA, "Host B was not found by vpnIp in controlA")
//TODO: we may want to loop over each vpnAddr and assert all the things
hBinA := controlA.GetHostInfoByVpnAddr(vpnNetsB[0].Addr(), false)
assert.NotNil(t, hBinA, "Host B was not found by vpnAddr in controlA")
hAinB := controlB.GetHostInfoByVpnIp(vpnIpA, false)
assert.NotNil(t, hAinB, "Host A was not found by vpnIp in controlB")
hAinB := controlB.GetHostInfoByVpnAddr(vpnNetsA[0].Addr(), false)
assert.NotNil(t, hAinB, "Host A was not found by vpnAddr in controlB")
// Check that both vpn and real addr are correct
assert.Equal(t, vpnIpB, hBinA.VpnIp, "Host B VpnIp is wrong in control A")
assert.Equal(t, vpnIpA, hAinB.VpnIp, "Host A VpnIp is wrong in control B")
assert.EqualValues(t, getAddrs(vpnNetsB), hBinA.VpnAddrs, "Host B VpnIp is wrong in control A")
assert.EqualValues(t, getAddrs(vpnNetsA), hAinB.VpnAddrs, "Host A VpnIp is wrong in control B")
assert.Equal(t, addrB, hBinA.CurrentRemote, "Host B remote is wrong in control A")
assert.Equal(t, addrA, hAinB.CurrentRemote, "Host A remote is wrong in control B")
@ -179,6 +196,33 @@ func assertHostInfoPair(t *testing.T, addrA, addrB netip.AddrPort, vpnIpA, vpnIp
}
func assertUdpPacket(t *testing.T, expected, b []byte, fromIp, toIp netip.Addr, fromPort, toPort uint16) {
if toIp.Is6() {
assertUdpPacket6(t, expected, b, fromIp, toIp, fromPort, toPort)
} else {
assertUdpPacket4(t, expected, b, fromIp, toIp, fromPort, toPort)
}
}
func assertUdpPacket6(t *testing.T, expected, b []byte, fromIp, toIp netip.Addr, fromPort, toPort uint16) {
packet := gopacket.NewPacket(b, layers.LayerTypeIPv6, gopacket.Lazy)
v6 := packet.Layer(layers.LayerTypeIPv6).(*layers.IPv6)
assert.NotNil(t, v6, "No ipv6 data found")
assert.Equal(t, fromIp.AsSlice(), []byte(v6.SrcIP), "Source ip was incorrect")
assert.Equal(t, toIp.AsSlice(), []byte(v6.DstIP), "Dest ip was incorrect")
udp := packet.Layer(layers.LayerTypeUDP).(*layers.UDP)
assert.NotNil(t, udp, "No udp data found")
assert.Equal(t, fromPort, uint16(udp.SrcPort), "Source port was incorrect")
assert.Equal(t, toPort, uint16(udp.DstPort), "Dest port was incorrect")
data := packet.ApplicationLayer()
assert.NotNil(t, data)
assert.Equal(t, expected, data.Payload(), "Data was incorrect")
}
func assertUdpPacket4(t *testing.T, expected, b []byte, fromIp, toIp netip.Addr, fromPort, toPort uint16) {
packet := gopacket.NewPacket(b, layers.LayerTypeIPv4, gopacket.Lazy)
v4 := packet.Layer(layers.LayerTypeIPv4).(*layers.IPv4)
assert.NotNil(t, v4, "No ipv4 data found")
@ -197,6 +241,14 @@ func assertUdpPacket(t *testing.T, expected, b []byte, fromIp, toIp netip.Addr,
assert.Equal(t, expected, data.Payload(), "Data was incorrect")
}
func getAddrs(ns []netip.Prefix) []netip.Addr {
var a []netip.Addr
for _, n := range ns {
a = append(a, n.Addr())
}
return a
}
func NewTestLogger() *logrus.Logger {
l := logrus.New()

View File

@ -58,8 +58,9 @@ func renderHostmap(c *nebula.Control) (string, []*edge) {
var lines []string
var globalLines []*edge
clusterName := strings.Trim(c.GetCert().Name(), " ")
clusterVpnIp := c.GetCert().Networks()[0].Addr()
crt := c.GetCertState().GetDefaultCertificate()
clusterName := strings.Trim(crt.Name(), " ")
clusterVpnIp := crt.Networks()[0].Addr()
r := fmt.Sprintf("\tsubgraph %s[\"%s (%s)\"]\n", clusterName, clusterName, clusterVpnIp)
hm := c.GetHostmap()
@ -101,7 +102,7 @@ func renderHostmap(c *nebula.Control) (string, []*edge) {
for _, idx := range indexes {
hi, ok := hm.Indexes[idx]
if ok {
r += fmt.Sprintf("\t\t\t%v.%v[\"%v (%v)\"]\n", clusterName, idx, idx, hi.GetVpnIp())
r += fmt.Sprintf("\t\t\t%v.%v[\"%v (%v)\"]\n", clusterName, idx, idx, hi.GetVpnAddrs())
remoteClusterName := strings.Trim(hi.GetCert().Certificate.Name(), " ")
globalLines = append(globalLines, &edge{from: fmt.Sprintf("%v.%v", clusterName, idx), to: fmt.Sprintf("%v.%v", remoteClusterName, hi.GetRemoteIndex())})
_ = hi

View File

@ -10,8 +10,8 @@ import (
"os"
"path/filepath"
"reflect"
"regexp"
"sort"
"strings"
"sync"
"testing"
"time"
@ -136,7 +136,10 @@ func NewR(t testing.TB, controls ...*nebula.Control) *R {
panic("Duplicate listen address: " + addr.String())
}
r.vpnControls[c.GetVpnIp()] = c
for _, vpnAddr := range c.GetVpnAddrs() {
r.vpnControls[vpnAddr] = c
}
r.controls[addr] = c
}
@ -213,11 +216,11 @@ func (r *R) renderFlow() {
continue
}
participants[addr] = struct{}{}
sanAddr := strings.Replace(addr.String(), ":", "-", 1)
sanAddr := normalizeName(addr.String())
participantsVals = append(participantsVals, sanAddr)
fmt.Fprintf(
f, " participant %s as Nebula: %s<br/>UDP: %s\n",
sanAddr, e.packet.from.GetVpnIp(), sanAddr,
sanAddr, e.packet.from.GetVpnAddrs(), sanAddr,
)
}
@ -250,9 +253,9 @@ func (r *R) renderFlow() {
fmt.Fprintf(f,
" %s%s%s: %s(%s), index %v, counter: %v\n",
strings.Replace(p.from.GetUDPAddr().String(), ":", "-", 1),
normalizeName(p.from.GetUDPAddr().String()),
line,
strings.Replace(p.to.GetUDPAddr().String(), ":", "-", 1),
normalizeName(p.to.GetUDPAddr().String()),
h.TypeName(), h.SubTypeName(), h.RemoteIndex, h.MessageCounter,
)
}
@ -267,6 +270,11 @@ func (r *R) renderFlow() {
}
}
func normalizeName(s string) string {
rx := regexp.MustCompile("[\\[\\]\\:]")
return rx.ReplaceAllLiteralString(s, "_")
}
// IgnoreFlow tells the router to stop recording future flows that matches the provided criteria.
// messageType and subType will target nebula underlay packets while tun will target nebula overlay packets
// NOTE: This is a very broad system, if you set tun to true then no more tun traffic will be rendered
@ -303,7 +311,7 @@ func (r *R) RenderHostmaps(title string, controls ...*nebula.Control) {
func (r *R) renderHostmaps(title string) {
c := maps.Values(r.controls)
sort.SliceStable(c, func(i, j int) bool {
return c[i].GetVpnIp().Compare(c[j].GetVpnIp()) > 0
return c[i].GetVpnAddrs()[0].Compare(c[j].GetVpnAddrs()[0]) > 0
})
s := renderHostmaps(c...)
@ -419,10 +427,11 @@ func (r *R) RouteUntilTxTun(sender *nebula.Control, receiver *nebula.Control) []
// Nope, lets push the sender along
case p := <-udpTx:
r.Lock()
c := r.getControl(sender.GetUDPAddr(), p.To, p)
a := sender.GetUDPAddr()
c := r.getControl(a, p.To, p)
if c == nil {
r.Unlock()
panic("No control for udp tx")
panic("No control for udp tx " + a.String())
}
fp := r.unlockedInjectFlow(sender, c, p, false)
c.InjectUDPPacket(p)
@ -475,10 +484,11 @@ func (r *R) RouteForAllUntilTxTun(receiver *nebula.Control) []byte {
} else {
// we are a udp tx, route and continue
p := rx.Interface().(*udp.Packet)
c := r.getControl(cm[x].GetUDPAddr(), p.To, p)
a := cm[x].GetUDPAddr()
c := r.getControl(a, p.To, p)
if c == nil {
r.Unlock()
panic("No control for udp tx")
panic(fmt.Sprintf("No control for udp tx %s", p.To))
}
fp := r.unlockedInjectFlow(cm[x], c, p, false)
c.InjectUDPPacket(p)
@ -711,30 +721,42 @@ func (r *R) getControl(fromAddr, toAddr netip.AddrPort, p *udp.Packet) *nebula.C
}
func (r *R) formatUdpPacket(p *packet) string {
packet := gopacket.NewPacket(p.packet.Data, layers.LayerTypeIPv4, gopacket.Lazy)
v4 := packet.Layer(layers.LayerTypeIPv4).(*layers.IPv4)
if v4 == nil {
panic("not an ipv4 packet")
var packet gopacket.Packet
var srcAddr netip.Addr
packet = gopacket.NewPacket(p.packet.Data, layers.LayerTypeIPv6, gopacket.Lazy)
if packet.ErrorLayer() == nil {
v6 := packet.Layer(layers.LayerTypeIPv6).(*layers.IPv6)
if v6 == nil {
panic("not an ipv6 packet")
}
srcAddr, _ = netip.AddrFromSlice(v6.SrcIP)
} else {
packet = gopacket.NewPacket(p.packet.Data, layers.LayerTypeIPv4, gopacket.Lazy)
v6 := packet.Layer(layers.LayerTypeIPv4).(*layers.IPv4)
if v6 == nil {
panic("not an ipv6 packet")
}
srcAddr, _ = netip.AddrFromSlice(v6.SrcIP)
}
from := "unknown"
srcAddr, _ := netip.AddrFromSlice(v4.SrcIP)
if c, ok := r.vpnControls[srcAddr]; ok {
from = c.GetUDPAddr().String()
}
udp := packet.Layer(layers.LayerTypeUDP).(*layers.UDP)
if udp == nil {
udpLayer := packet.Layer(layers.LayerTypeUDP).(*layers.UDP)
if udpLayer == nil {
panic("not a udp packet")
}
data := packet.ApplicationLayer()
return fmt.Sprintf(
" %s-->>%s: src port: %v<br/>dest port: %v<br/>data: \"%v\"\n",
strings.Replace(from, ":", "-", 1),
strings.Replace(p.to.GetUDPAddr().String(), ":", "-", 1),
udp.SrcPort,
udp.DstPort,
normalizeName(from),
normalizeName(p.to.GetUDPAddr().String()),
udpLayer.SrcPort,
udpLayer.DstPort,
string(data.Payload()),
)
}

View File

@ -13,6 +13,12 @@ pki:
# disconnect_invalid is a toggle to force a client to be disconnected if the certificate is expired or invalid.
#disconnect_invalid: true
# default_version controls which certificate version is used in handshakes.
# This setting only applies if both a v1 and a v2 certificate are configured, in which case it will default to `1`.
# Once all hosts in the mesh are configured with both a v1 and v2 certificate then this should be changed to `2`.
# After all hosts in the mesh are using a v2 certificate then v1 certificates are no longer needed.
# default_version: 1
# The static host map defines a set of hosts with fixed IP addresses on the internet (or any network).
# A host can have multiple fixed IP addresses defined here, and nebula will try each when establishing a tunnel.
# The syntax is:
@ -336,10 +342,13 @@ firewall:
# host: `any` or a literal hostname, ie `test-host`
# group: `any` or a literal group name, ie `default-group`
# groups: Same as group but accepts a list of values. Multiple values are AND'd together and a certificate would have to contain all groups to pass
# cidr: a remote CIDR, `0.0.0.0/0` is any.
# local_cidr: a local CIDR, `0.0.0.0/0` is any. This could be used to filter destinations when using unsafe_routes.
# Default is `any` unless the certificate contains subnets and then the default is the ip issued in the certificate
# if `default_local_cidr_any` is false, otherwise its `any`.
# cidr: a remote CIDR, `0.0.0.0/0` is any ipv4 and `::/0` is any ipv6. //TODO: we have a problem, firewall needs to understand this and should probably allow `any` for both
# local_cidr: a local CIDR, `0.0.0.0/0` is any ipv4 and `::/0` is any ipv6. This could be used to filter destinations when using unsafe_routes.
# //TODO: probably should have an `any` that covers both ip versions
# If no unsafe networks are present in the certificate(s) or `default_local_cidr_any` is true then the default is any ipv4 or ipv6 network.
# Otherwise the default is any vpn network assigned to via the certificate.
# `default_local_cidr_any` defaults to false and is deprecated, it will be removed in a future release.
# If there are unsafe routes present its best to set `local_cidr` to whatever best fits the situation.
# ca_name: An issuing CA name
# ca_sha: An issuing CA shasum

View File

@ -8,6 +8,7 @@ import (
"hash/fnv"
"net/netip"
"reflect"
"slices"
"strconv"
"strings"
"sync"
@ -22,7 +23,8 @@ import (
)
type FirewallInterface interface {
AddRule(incoming bool, proto uint8, startPort int32, endPort int32, groups []string, host string, ip, localIp netip.Prefix, caName string, caSha string) error
//TODO: name these better addr, localAddr. Are they vpnAddrs?
AddRule(incoming bool, proto uint8, startPort int32, endPort int32, groups []string, host string, addr, localAddr netip.Prefix, caName string, caSha string) error
}
type conn struct {
@ -51,9 +53,12 @@ type Firewall struct {
UDPTimeout time.Duration //linux: 180s max
DefaultTimeout time.Duration //linux: 600s
// Used to ensure we don't emit local packets for ips we don't own
localIps *bart.Table[struct{}]
assignedCIDR netip.Prefix
// routableNetworks describes the vpn addresses as well as any unsafe networks issued to us in the certificate.
// The vpn addresses are a full bit match while the unsafe networks only match the prefix
routableNetworks *bart.Table[struct{}]
// assignedNetworks is a list of vpn networks assigned to us in the certificate.
assignedNetworks []netip.Prefix
hasUnsafeNetworks bool
rules string
@ -67,9 +72,9 @@ type Firewall struct {
}
type firewallMetrics struct {
droppedLocalIP metrics.Counter
droppedRemoteIP metrics.Counter
droppedNoRule metrics.Counter
droppedLocalAddr metrics.Counter
droppedRemoteAddr metrics.Counter
droppedNoRule metrics.Counter
}
type FirewallConntrack struct {
@ -126,84 +131,87 @@ type firewallLocalCIDR struct {
}
// NewFirewall creates a new Firewall object. A TimerWheel is created for you from the provided timeouts.
// The certificate provided should be the highest version loaded in memory.
func NewFirewall(l *logrus.Logger, tcpTimeout, UDPTimeout, defaultTimeout time.Duration, c cert.Certificate) *Firewall {
//TODO: error on 0 duration
var min, max time.Duration
var tmin, tmax time.Duration
if tcpTimeout < UDPTimeout {
min = tcpTimeout
max = UDPTimeout
tmin = tcpTimeout
tmax = UDPTimeout
} else {
min = UDPTimeout
max = tcpTimeout
tmin = UDPTimeout
tmax = tcpTimeout
}
if defaultTimeout < min {
min = defaultTimeout
} else if defaultTimeout > max {
max = defaultTimeout
if defaultTimeout < tmin {
tmin = defaultTimeout
} else if defaultTimeout > tmax {
tmax = defaultTimeout
}
localIps := new(bart.Table[struct{}])
var assignedCIDR netip.Prefix
var assignedSet bool
routableNetworks := new(bart.Table[struct{}])
var assignedNetworks []netip.Prefix
for _, network := range c.Networks() {
nprefix := netip.PrefixFrom(network.Addr(), network.Addr().BitLen())
localIps.Insert(nprefix, struct{}{})
if !assignedSet {
// Only grabbing the first one in the cert since any more than that currently has undefined behavior
assignedCIDR = nprefix
assignedSet = true
}
routableNetworks.Insert(nprefix, struct{}{})
assignedNetworks = append(assignedNetworks, network)
}
hasUnsafeNetworks := false
for _, n := range c.UnsafeNetworks() {
localIps.Insert(n, struct{}{})
routableNetworks.Insert(n, struct{}{})
hasUnsafeNetworks = true
}
return &Firewall{
Conntrack: &FirewallConntrack{
Conns: make(map[firewall.Packet]*conn),
TimerWheel: NewTimerWheel[firewall.Packet](min, max),
TimerWheel: NewTimerWheel[firewall.Packet](tmin, tmax),
},
InRules: newFirewallTable(),
OutRules: newFirewallTable(),
TCPTimeout: tcpTimeout,
UDPTimeout: UDPTimeout,
DefaultTimeout: defaultTimeout,
localIps: localIps,
assignedCIDR: assignedCIDR,
routableNetworks: routableNetworks,
assignedNetworks: assignedNetworks,
hasUnsafeNetworks: hasUnsafeNetworks,
l: l,
incomingMetrics: firewallMetrics{
droppedLocalIP: metrics.GetOrRegisterCounter("firewall.incoming.dropped.local_ip", nil),
droppedRemoteIP: metrics.GetOrRegisterCounter("firewall.incoming.dropped.remote_ip", nil),
droppedNoRule: metrics.GetOrRegisterCounter("firewall.incoming.dropped.no_rule", nil),
droppedLocalAddr: metrics.GetOrRegisterCounter("firewall.incoming.dropped.local_addr", nil),
droppedRemoteAddr: metrics.GetOrRegisterCounter("firewall.incoming.dropped.remote_addr", nil),
droppedNoRule: metrics.GetOrRegisterCounter("firewall.incoming.dropped.no_rule", nil),
},
outgoingMetrics: firewallMetrics{
droppedLocalIP: metrics.GetOrRegisterCounter("firewall.outgoing.dropped.local_ip", nil),
droppedRemoteIP: metrics.GetOrRegisterCounter("firewall.outgoing.dropped.remote_ip", nil),
droppedNoRule: metrics.GetOrRegisterCounter("firewall.outgoing.dropped.no_rule", nil),
droppedLocalAddr: metrics.GetOrRegisterCounter("firewall.outgoing.dropped.local_addr", nil),
droppedRemoteAddr: metrics.GetOrRegisterCounter("firewall.outgoing.dropped.remote_addr", nil),
droppedNoRule: metrics.GetOrRegisterCounter("firewall.outgoing.dropped.no_rule", nil),
},
}
}
func NewFirewallFromConfig(l *logrus.Logger, nc cert.Certificate, c *config.C) (*Firewall, error) {
func NewFirewallFromConfig(l *logrus.Logger, cs *CertState, c *config.C) (*Firewall, error) {
certificate := cs.getCertificate(cert.Version2)
if certificate == nil {
certificate = cs.getCertificate(cert.Version1)
}
if certificate == nil {
panic("No certificate available to reconfigure the firewall")
}
fw := NewFirewall(
l,
c.GetDuration("firewall.conntrack.tcp_timeout", time.Minute*12),
c.GetDuration("firewall.conntrack.udp_timeout", time.Minute*3),
c.GetDuration("firewall.conntrack.default_timeout", time.Minute*10),
nc,
certificate,
//TODO: max_connections
)
//TODO: Flip to false after v1.9 release
fw.defaultLocalCIDRAny = c.GetBool("firewall.default_local_cidr_any", true)
fw.defaultLocalCIDRAny = c.GetBool("firewall.default_local_cidr_any", false)
inboundAction := c.GetString("firewall.inbound_action", "drop")
switch inboundAction {
@ -283,7 +291,7 @@ func (f *Firewall) AddRule(incoming bool, proto uint8, startPort int32, endPort
fp = ft.TCP
case firewall.ProtoUDP:
fp = ft.UDP
case firewall.ProtoICMP:
case firewall.ProtoICMP, firewall.ProtoICMPv6:
fp = ft.ICMP
case firewall.ProtoAny:
fp = ft.AnyProto
@ -424,26 +432,25 @@ func (f *Firewall) Drop(fp firewall.Packet, incoming bool, h *HostInfo, caPool *
}
// Make sure remote address matches nebula certificate
if remoteCidr := h.remoteCidr; remoteCidr != nil {
//TODO: this would be better if we had a least specific match lookup, could waste time here, need to benchmark since the algo is different
_, ok := remoteCidr.Lookup(fp.RemoteIP)
if h.networks != nil {
_, ok := h.networks.Lookup(fp.RemoteAddr)
if !ok {
f.metrics(incoming).droppedRemoteIP.Inc(1)
f.metrics(incoming).droppedRemoteAddr.Inc(1)
return ErrInvalidRemoteIP
}
} else {
// Simple case: Certificate has one IP and no subnets
if fp.RemoteIP != h.vpnIp {
f.metrics(incoming).droppedRemoteIP.Inc(1)
//TODO: we can make this more performant
if !slices.Contains(h.vpnAddrs, fp.RemoteAddr) {
f.metrics(incoming).droppedRemoteAddr.Inc(1)
return ErrInvalidRemoteIP
}
}
// Make sure we are supposed to be handling this local ip address
//TODO: this would be better if we had a least specific match lookup, could waste time here, need to benchmark since the algo is different
_, ok := f.localIps.Lookup(fp.LocalIP)
_, ok := f.routableNetworks.Lookup(fp.LocalAddr)
if !ok {
f.metrics(incoming).droppedLocalIP.Inc(1)
f.metrics(incoming).droppedLocalAddr.Inc(1)
return ErrInvalidLocalIP
}
@ -629,7 +636,7 @@ func (ft *FirewallTable) match(p firewall.Packet, incoming bool, c *cert.CachedC
if ft.UDP.match(p, incoming, c, caPool) {
return true
}
case firewall.ProtoICMP:
case firewall.ProtoICMP, firewall.ProtoICMPv6:
if ft.ICMP.match(p, incoming, c, caPool) {
return true
}
@ -859,9 +866,9 @@ func (fr *FirewallRule) match(p firewall.Packet, c *cert.CachedCertificate) bool
}
matched := false
prefix := netip.PrefixFrom(p.RemoteIP, p.RemoteIP.BitLen())
prefix := netip.PrefixFrom(p.RemoteAddr, p.RemoteAddr.BitLen())
fr.CIDR.EachLookupPrefix(prefix, func(prefix netip.Prefix, val *firewallLocalCIDR) bool {
if prefix.Contains(p.RemoteIP) && val.match(p, c) {
if prefix.Contains(p.RemoteAddr) && val.match(p, c) {
matched = true
return false
}
@ -877,9 +884,14 @@ func (flc *firewallLocalCIDR) addRule(f *Firewall, localIp netip.Prefix) error {
return nil
}
localIp = f.assignedCIDR
for _, network := range f.assignedNetworks {
flc.LocalCIDR.Insert(network, struct{}{})
}
return nil
} else if localIp.Bits() == 0 {
flc.Any = true
return nil
}
flc.LocalCIDR.Insert(localIp, struct{}{})
@ -895,7 +907,7 @@ func (flc *firewallLocalCIDR) match(p firewall.Packet, c *cert.CachedCertificate
return true
}
_, ok := flc.LocalCIDR.Lookup(p.LocalIP)
_, ok := flc.LocalCIDR.Lookup(p.LocalAddr)
return ok
}

View File

@ -9,18 +9,19 @@ import (
type m map[string]interface{}
const (
ProtoAny = 0 // When we want to handle HOPOPT (0) we can change this, if ever
ProtoTCP = 6
ProtoUDP = 17
ProtoICMP = 1
ProtoAny = 0 // When we want to handle HOPOPT (0) we can change this, if ever
ProtoTCP = 6
ProtoUDP = 17
ProtoICMP = 1
ProtoICMPv6 = 58
PortAny = 0 // Special value for matching `port: any`
PortFragment = -1 // Special value for matching `port: fragment`
)
type Packet struct {
LocalIP netip.Addr
RemoteIP netip.Addr
LocalAddr netip.Addr
RemoteAddr netip.Addr
LocalPort uint16
RemotePort uint16
Protocol uint8
@ -29,8 +30,8 @@ type Packet struct {
func (fp *Packet) Copy() *Packet {
return &Packet{
LocalIP: fp.LocalIP,
RemoteIP: fp.RemoteIP,
LocalAddr: fp.LocalAddr,
RemoteAddr: fp.RemoteAddr,
LocalPort: fp.LocalPort,
RemotePort: fp.RemotePort,
Protocol: fp.Protocol,
@ -51,8 +52,8 @@ func (fp Packet) MarshalJSON() ([]byte, error) {
proto = fmt.Sprintf("unknown %v", fp.Protocol)
}
return json.Marshal(m{
"LocalIP": fp.LocalIP.String(),
"RemoteIP": fp.RemoteIP.String(),
"LocalAddr": fp.LocalAddr.String(),
"RemoteAddr": fp.RemoteAddr.String(),
"LocalPort": fp.LocalPort,
"RemotePort": fp.RemotePort,
"Protocol": proto,

View File

@ -13,6 +13,7 @@ import (
"github.com/slackhq/nebula/firewall"
"github.com/slackhq/nebula/test"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestNewFirewall(t *testing.T) {
@ -128,8 +129,8 @@ func TestFirewall_Drop(t *testing.T) {
l.SetOutput(ob)
p := firewall.Packet{
LocalIP: netip.MustParseAddr("1.2.3.4"),
RemoteIP: netip.MustParseAddr("1.2.3.4"),
LocalAddr: netip.MustParseAddr("1.2.3.4"),
RemoteAddr: netip.MustParseAddr("1.2.3.4"),
LocalPort: 10,
RemotePort: 90,
Protocol: firewall.ProtoUDP,
@ -149,9 +150,9 @@ func TestFirewall_Drop(t *testing.T) {
InvertedGroups: map[string]struct{}{"default-group": {}},
},
},
vpnIp: netip.MustParseAddr("1.2.3.4"),
vpnAddrs: []netip.Addr{netip.MustParseAddr("1.2.3.4")},
}
h.CreateRemoteCIDR(&c)
h.buildNetworks(&c)
fw := NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"any"}, "", netip.Prefix{}, netip.Prefix{}, "", ""))
@ -166,10 +167,10 @@ func TestFirewall_Drop(t *testing.T) {
assert.NoError(t, fw.Drop(p, false, &h, cp, nil))
// test remote mismatch
oldRemote := p.RemoteIP
p.RemoteIP = netip.MustParseAddr("1.2.3.10")
oldRemote := p.RemoteAddr
p.RemoteAddr = netip.MustParseAddr("1.2.3.10")
assert.Equal(t, fw.Drop(p, false, &h, cp, nil), ErrInvalidRemoteIP)
p.RemoteIP = oldRemote
p.RemoteAddr = oldRemote
// ensure signer doesn't get in the way of group checks
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
@ -235,7 +236,7 @@ func BenchmarkFirewallTable_match(b *testing.B) {
}
ip := netip.MustParsePrefix("9.254.254.254/32")
for n := 0; n < b.N; n++ {
assert.False(b, ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 100, LocalIP: ip.Addr()}, true, c, cp))
assert.False(b, ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 100, LocalAddr: ip.Addr()}, true, c, cp))
}
})
@ -261,7 +262,7 @@ func BenchmarkFirewallTable_match(b *testing.B) {
InvertedGroups: map[string]struct{}{"nope": {}},
}
for n := 0; n < b.N; n++ {
assert.False(b, ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 100, LocalIP: pfix.Addr()}, true, c, cp))
assert.False(b, ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 100, LocalAddr: pfix.Addr()}, true, c, cp))
}
})
@ -285,7 +286,7 @@ func BenchmarkFirewallTable_match(b *testing.B) {
InvertedGroups: map[string]struct{}{"good-group": {}},
}
for n := 0; n < b.N; n++ {
assert.True(b, ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 100, LocalIP: pfix.Addr()}, true, c, cp))
assert.True(b, ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 100, LocalAddr: pfix.Addr()}, true, c, cp))
}
})
@ -308,8 +309,8 @@ func TestFirewall_Drop2(t *testing.T) {
l.SetOutput(ob)
p := firewall.Packet{
LocalIP: netip.MustParseAddr("1.2.3.4"),
RemoteIP: netip.MustParseAddr("1.2.3.4"),
LocalAddr: netip.MustParseAddr("1.2.3.4"),
RemoteAddr: netip.MustParseAddr("1.2.3.4"),
LocalPort: 10,
RemotePort: 90,
Protocol: firewall.ProtoUDP,
@ -329,9 +330,9 @@ func TestFirewall_Drop2(t *testing.T) {
ConnectionState: &ConnectionState{
peerCert: &c,
},
vpnIp: network.Addr(),
vpnAddrs: []netip.Addr{network.Addr()},
}
h.CreateRemoteCIDR(c.Certificate)
h.buildNetworks(c.Certificate)
c1 := cert.CachedCertificate{
Certificate: &dummyCert{
@ -345,7 +346,7 @@ func TestFirewall_Drop2(t *testing.T) {
peerCert: &c1,
},
}
h1.CreateRemoteCIDR(c1.Certificate)
h1.buildNetworks(c1.Certificate)
fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate)
assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group", "test-group"}, "", netip.Prefix{}, netip.Prefix{}, "", ""))
@ -364,8 +365,8 @@ func TestFirewall_Drop3(t *testing.T) {
l.SetOutput(ob)
p := firewall.Packet{
LocalIP: netip.MustParseAddr("1.2.3.4"),
RemoteIP: netip.MustParseAddr("1.2.3.4"),
LocalAddr: netip.MustParseAddr("1.2.3.4"),
RemoteAddr: netip.MustParseAddr("1.2.3.4"),
LocalPort: 1,
RemotePort: 1,
Protocol: firewall.ProtoUDP,
@ -391,9 +392,9 @@ func TestFirewall_Drop3(t *testing.T) {
ConnectionState: &ConnectionState{
peerCert: &c1,
},
vpnIp: network.Addr(),
vpnAddrs: []netip.Addr{network.Addr()},
}
h1.CreateRemoteCIDR(c1.Certificate)
h1.buildNetworks(c1.Certificate)
c2 := cert.CachedCertificate{
Certificate: &dummyCert{
@ -406,9 +407,9 @@ func TestFirewall_Drop3(t *testing.T) {
ConnectionState: &ConnectionState{
peerCert: &c2,
},
vpnIp: network.Addr(),
vpnAddrs: []netip.Addr{network.Addr()},
}
h2.CreateRemoteCIDR(c2.Certificate)
h2.buildNetworks(c2.Certificate)
c3 := cert.CachedCertificate{
Certificate: &dummyCert{
@ -421,9 +422,9 @@ func TestFirewall_Drop3(t *testing.T) {
ConnectionState: &ConnectionState{
peerCert: &c3,
},
vpnIp: network.Addr(),
vpnAddrs: []netip.Addr{network.Addr()},
}
h3.CreateRemoteCIDR(c3.Certificate)
h3.buildNetworks(c3.Certificate)
fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate)
assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 1, 1, []string{}, "host1", netip.Prefix{}, netip.Prefix{}, "", ""))
@ -446,8 +447,8 @@ func TestFirewall_DropConntrackReload(t *testing.T) {
l.SetOutput(ob)
p := firewall.Packet{
LocalIP: netip.MustParseAddr("1.2.3.4"),
RemoteIP: netip.MustParseAddr("1.2.3.4"),
LocalAddr: netip.MustParseAddr("1.2.3.4"),
RemoteAddr: netip.MustParseAddr("1.2.3.4"),
LocalPort: 10,
RemotePort: 90,
Protocol: firewall.ProtoUDP,
@ -468,9 +469,9 @@ func TestFirewall_DropConntrackReload(t *testing.T) {
ConnectionState: &ConnectionState{
peerCert: &c,
},
vpnIp: network.Addr(),
vpnAddrs: []netip.Addr{network.Addr()},
}
h.CreateRemoteCIDR(c.Certificate)
h.buildNetworks(c.Certificate)
fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate)
assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"any"}, "", netip.Prefix{}, netip.Prefix{}, "", ""))
@ -622,55 +623,58 @@ func TestNewFirewallFromConfig(t *testing.T) {
l := test.NewLogger()
// Test a bad rule definition
c := &dummyCert{}
cs, err := newCertState(cert.Version2, nil, c, false, cert.Curve_CURVE25519, nil)
require.NoError(t, err)
conf := config.NewC(l)
conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": "asdf"}
_, err := NewFirewallFromConfig(l, c, conf)
_, err = NewFirewallFromConfig(l, cs, conf)
assert.EqualError(t, err, "firewall.outbound failed to parse, should be an array of rules")
// Test both port and code
conf = config.NewC(l)
conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"port": "1", "code": "2"}}}
_, err = NewFirewallFromConfig(l, c, conf)
_, err = NewFirewallFromConfig(l, cs, conf)
assert.EqualError(t, err, "firewall.outbound rule #0; only one of port or code should be provided")
// Test missing host, group, cidr, ca_name and ca_sha
conf = config.NewC(l)
conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{}}}
_, err = NewFirewallFromConfig(l, c, conf)
_, err = NewFirewallFromConfig(l, cs, conf)
assert.EqualError(t, err, "firewall.outbound rule #0; at least one of host, group, cidr, local_cidr, ca_name, or ca_sha must be provided")
// Test code/port error
conf = config.NewC(l)
conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"code": "a", "host": "testh"}}}
_, err = NewFirewallFromConfig(l, c, conf)
_, err = NewFirewallFromConfig(l, cs, conf)
assert.EqualError(t, err, "firewall.outbound rule #0; code was not a number; `a`")
conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"port": "a", "host": "testh"}}}
_, err = NewFirewallFromConfig(l, c, conf)
_, err = NewFirewallFromConfig(l, cs, conf)
assert.EqualError(t, err, "firewall.outbound rule #0; port was not a number; `a`")
// Test proto error
conf = config.NewC(l)
conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"code": "1", "host": "testh"}}}
_, err = NewFirewallFromConfig(l, c, conf)
_, err = NewFirewallFromConfig(l, cs, conf)
assert.EqualError(t, err, "firewall.outbound rule #0; proto was not understood; ``")
// Test cidr parse error
conf = config.NewC(l)
conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"code": "1", "cidr": "testh", "proto": "any"}}}
_, err = NewFirewallFromConfig(l, c, conf)
_, err = NewFirewallFromConfig(l, cs, conf)
assert.EqualError(t, err, "firewall.outbound rule #0; cidr did not parse; netip.ParsePrefix(\"testh\"): no '/'")
// Test local_cidr parse error
conf = config.NewC(l)
conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"code": "1", "local_cidr": "testh", "proto": "any"}}}
_, err = NewFirewallFromConfig(l, c, conf)
_, err = NewFirewallFromConfig(l, cs, conf)
assert.EqualError(t, err, "firewall.outbound rule #0; local_cidr did not parse; netip.ParsePrefix(\"testh\"): no '/'")
// Test both group and groups
conf = config.NewC(l)
conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "group": "a", "groups": []string{"b", "c"}}}}
_, err = NewFirewallFromConfig(l, c, conf)
_, err = NewFirewallFromConfig(l, cs, conf)
assert.EqualError(t, err, "firewall.inbound rule #0; only one of group or groups should be defined, both provided")
}

1
go.mod
View File

@ -21,7 +21,6 @@ require (
github.com/rcrowley/go-metrics v0.0.0-20201227073835-cf1acfcdf475
github.com/sirupsen/logrus v1.9.3
github.com/skip2/go-qrcode v0.0.0-20200617195104-da1b6568686e
github.com/songgao/water v0.0.0-20200317203138-2b4b6d7c09d8
github.com/stefanberger/go-pkcs11uri v0.0.0-20230803200340-78284954bff6
github.com/stretchr/testify v1.9.0
github.com/vishvananda/netlink v1.3.0

2
go.sum
View File

@ -137,8 +137,6 @@ github.com/sirupsen/logrus v1.9.3 h1:dueUQJ1C2q9oE3F7wvmSGAaVtTmUizReu6fjN8uqzbQ
github.com/sirupsen/logrus v1.9.3/go.mod h1:naHLuLoDiP4jHNo9R0sCBMtWGeIprob74mVsIT4qYEQ=
github.com/skip2/go-qrcode v0.0.0-20200617195104-da1b6568686e h1:MRM5ITcdelLK2j1vwZ3Je0FKVCfqOLp5zO6trqMLYs0=
github.com/skip2/go-qrcode v0.0.0-20200617195104-da1b6568686e/go.mod h1:XV66xRDqSt+GTGFMVlhk3ULuV0y9ZmzeVGR4mloJI3M=
github.com/songgao/water v0.0.0-20200317203138-2b4b6d7c09d8 h1:TG/diQgUe0pntT/2D9tmUCz4VNwm9MfrtPr0SU2qSX8=
github.com/songgao/water v0.0.0-20200317203138-2b4b6d7c09d8/go.mod h1:P5HUIBuIWKbyjl083/loAegFkfbFNx5i2qEP4CNbm7E=
github.com/stefanberger/go-pkcs11uri v0.0.0-20230803200340-78284954bff6 h1:pnnLyeX7o/5aX8qUQ69P/mLojDqwda8hFOCBTmP/6hw=
github.com/stefanberger/go-pkcs11uri v0.0.0-20230803200340-78284954bff6/go.mod h1:39R/xuhNgVhi+K0/zst4TLrJrVmbm6LVgl4A0+ZFS5M=
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=

View File

@ -2,10 +2,12 @@ package nebula
import (
"net/netip"
"slices"
"time"
"github.com/flynn/noise"
"github.com/sirupsen/logrus"
"github.com/slackhq/nebula/cert"
"github.com/slackhq/nebula/header"
)
@ -16,30 +18,60 @@ import (
func ixHandshakeStage0(f *Interface, hh *HandshakeHostInfo) bool {
err := f.handshakeManager.allocateIndex(hh)
if err != nil {
f.l.WithError(err).WithField("vpnIp", hh.hostinfo.vpnIp).
f.l.WithError(err).WithField("vpnAddrs", hh.hostinfo.vpnAddrs).
WithField("handshake", m{"stage": 0, "style": "ix_psk0"}).Error("Failed to generate index")
return false
}
certState := f.pki.GetCertState()
ci := NewConnectionState(f.l, f.cipher, certState, true, noise.HandshakeIX, []byte{}, 0)
// If we're connecting to a v6 address we must use a v2 cert
cs := f.pki.getCertState()
v := cs.defaultVersion
for _, a := range hh.hostinfo.vpnAddrs {
if a.Is6() {
v = cert.Version2
break
}
}
crt := cs.getCertificate(v)
if crt == nil {
f.l.WithField("vpnAddrs", hh.hostinfo.vpnAddrs).
WithField("handshake", m{"stage": 0, "style": "ix_psk0"}).
WithField("certVersion", v).
Error("Unable to handshake with host because no certificate is available")
return false
}
crtHs := cs.getHandshakeBytes(v)
if crtHs == nil {
f.l.WithField("vpnAddrs", hh.hostinfo.vpnAddrs).
WithField("handshake", m{"stage": 0, "style": "ix_psk0"}).
WithField("certVersion", v).
Error("Unable to handshake with host because no certificate handshake bytes is available")
}
ci, err := NewConnectionState(f.l, cs, crt, true, noise.HandshakeIX)
if err != nil {
f.l.WithError(err).WithField("vpnAddrs", hh.hostinfo.vpnAddrs).
WithField("handshake", m{"stage": 0, "style": "ix_psk0"}).
WithField("certVersion", v).
Error("Failed to create connection state")
return false
}
hh.hostinfo.ConnectionState = ci
hsProto := &NebulaHandshakeDetails{
InitiatorIndex: hh.hostinfo.localIndexId,
Time: uint64(time.Now().UnixNano()),
Cert: certState.RawCertificateNoKey,
}
hsBytes := []byte{}
hs := &NebulaHandshake{
Details: hsProto,
Details: &NebulaHandshakeDetails{
InitiatorIndex: hh.hostinfo.localIndexId,
Time: uint64(time.Now().UnixNano()),
Cert: crtHs,
CertVersion: uint32(v),
},
}
hsBytes, err = hs.Marshal()
hsBytes, err := hs.Marshal()
if err != nil {
f.l.WithError(err).WithField("vpnIp", hh.hostinfo.vpnIp).
f.l.WithError(err).WithField("vpnAddrs", hh.hostinfo.vpnAddrs).WithField("certVersion", v).
WithField("handshake", m{"stage": 0, "style": "ix_psk0"}).Error("Failed to marshal handshake message")
return false
}
@ -48,7 +80,7 @@ func ixHandshakeStage0(f *Interface, hh *HandshakeHostInfo) bool {
msg, _, _, err := ci.H.WriteMessage(h, hsBytes)
if err != nil {
f.l.WithError(err).WithField("vpnIp", hh.hostinfo.vpnIp).
f.l.WithError(err).WithField("vpnAddrs", hh.hostinfo.vpnAddrs).
WithField("handshake", m{"stage": 0, "style": "ix_psk0"}).Error("Failed to call noise.WriteMessage")
return false
}
@ -63,30 +95,44 @@ func ixHandshakeStage0(f *Interface, hh *HandshakeHostInfo) bool {
}
func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet []byte, h *header.H) {
certState := f.pki.GetCertState()
ci := NewConnectionState(f.l, f.cipher, certState, false, noise.HandshakeIX, []byte{}, 0)
cs := f.pki.getCertState()
crt := cs.GetDefaultCertificate()
if crt == nil {
f.l.WithField("udpAddr", addr).
WithField("handshake", m{"stage": 0, "style": "ix_psk0"}).
WithField("certVersion", cs.defaultVersion).
Error("Unable to handshake with host because no certificate is available")
}
ci, err := NewConnectionState(f.l, cs, crt, false, noise.HandshakeIX)
if err != nil {
f.l.WithError(err).WithField("udpAddr", addr).
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).
Error("Failed to create connection state")
return
}
// Mark packet 1 as seen so it doesn't show up as missed
ci.window.Update(f.l, 1)
msg, _, _, err := ci.H.ReadMessage(nil, packet[header.Len:])
if err != nil {
f.l.WithError(err).WithField("udpAddr", addr).
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).Error("Failed to call noise.ReadMessage")
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).
Error("Failed to call noise.ReadMessage")
return
}
hs := &NebulaHandshake{}
err = hs.Unmarshal(msg)
/*
l.Debugln("GOT INDEX: ", hs.Details.InitiatorIndex)
*/
if err != nil || hs.Details == nil {
f.l.WithError(err).WithField("udpAddr", addr).
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).Error("Failed unmarshal handshake message")
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).
Error("Failed unmarshal handshake message")
return
}
remoteCert, err := RecombineCertAndValidate(ci.H, hs.Details.Cert, f.pki.GetCAPool())
remoteCert, err := cert.RecombineAndValidate(cert.Version(hs.Details.CertVersion), hs.Details.Cert, ci.H.PeerStatic(), ci.Curve(), f.pki.GetCAPool())
if err != nil {
e := f.l.WithError(err).WithField("udpAddr", addr).
WithField("handshake", m{"stage": 1, "style": "ix_psk0"})
@ -99,6 +145,20 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet
return
}
if remoteCert.Certificate.Version() != ci.myCert.Version() {
// We started off using the wrong certificate version, lets see if we can match the version that was sent to us
rc := cs.getCertificate(remoteCert.Certificate.Version())
if rc == nil {
f.l.WithError(err).WithField("udpAddr", addr).
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).WithField("cert", remoteCert).
Info("Unable to handshake with host due to missing certificate version")
return
}
// Record the certificate we are actually using
ci.myCert = rc
}
if len(remoteCert.Certificate.Networks()) == 0 {
e := f.l.WithError(err).WithField("udpAddr", addr).
WithField("handshake", m{"stage": 1, "style": "ix_psk0"})
@ -111,30 +171,36 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet
return
}
vpnIp := remoteCert.Certificate.Networks()[0].Addr().Unmap()
var vpnAddrs []netip.Addr
certName := remoteCert.Certificate.Name()
fingerprint := remoteCert.Fingerprint
issuer := remoteCert.Certificate.Issuer()
if vpnIp == f.myVpnNet.Addr() {
f.l.WithField("vpnIp", vpnIp).WithField("udpAddr", addr).
WithField("certName", certName).
WithField("fingerprint", fingerprint).
WithField("issuer", issuer).
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).Error("Refusing to handshake with myself")
return
}
if addr.IsValid() {
if !f.lightHouse.GetRemoteAllowList().Allow(vpnIp, addr.Addr()) {
f.l.WithField("vpnIp", vpnIp).WithField("udpAddr", addr).Debug("lighthouse.remote_allow_list denied incoming handshake")
for _, network := range remoteCert.Certificate.Networks() {
vpnAddr := network.Addr()
_, found := f.myVpnAddrsTable.Lookup(vpnAddr)
if found {
f.l.WithField("vpnAddr", vpnAddr).WithField("udpAddr", addr).
WithField("certName", certName).
WithField("fingerprint", fingerprint).
WithField("issuer", issuer).
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).Error("Refusing to handshake with myself")
return
}
if addr.IsValid() {
if !f.lightHouse.GetRemoteAllowList().Allow(vpnAddr, addr.Addr()) {
f.l.WithField("vpnAddr", vpnAddr).WithField("udpAddr", addr).Debug("lighthouse.remote_allow_list denied incoming handshake")
return
}
}
vpnAddrs = append(vpnAddrs, vpnAddr)
}
myIndex, err := generateIndex(f.l)
if err != nil {
f.l.WithError(err).WithField("vpnIp", vpnIp).WithField("udpAddr", addr).
f.l.WithError(err).WithField("vpnAddrs", vpnAddrs).WithField("udpAddr", addr).
WithField("certName", certName).
WithField("fingerprint", fingerprint).
WithField("issuer", issuer).
@ -146,17 +212,17 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet
ConnectionState: ci,
localIndexId: myIndex,
remoteIndexId: hs.Details.InitiatorIndex,
vpnIp: vpnIp,
vpnAddrs: vpnAddrs,
HandshakePacket: make(map[uint8][]byte, 0),
lastHandshakeTime: hs.Details.Time,
relayState: RelayState{
relays: map[netip.Addr]struct{}{},
relayForByIp: map[netip.Addr]*Relay{},
relayForByIdx: map[uint32]*Relay{},
relays: map[netip.Addr]struct{}{},
relayForByAddr: map[netip.Addr]*Relay{},
relayForByIdx: map[uint32]*Relay{},
},
}
f.l.WithField("vpnIp", vpnIp).WithField("udpAddr", addr).
f.l.WithField("vpnAddrs", vpnAddrs).WithField("udpAddr", addr).
WithField("certName", certName).
WithField("fingerprint", fingerprint).
WithField("issuer", issuer).
@ -165,13 +231,26 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet
Info("Handshake message received")
hs.Details.ResponderIndex = myIndex
hs.Details.Cert = certState.RawCertificateNoKey
hs.Details.Cert = cs.getHandshakeBytes(ci.myCert.Version())
if hs.Details.Cert == nil {
f.l.WithField("vpnAddrs", vpnAddrs).WithField("udpAddr", addr).
WithField("certName", certName).
WithField("fingerprint", fingerprint).
WithField("issuer", issuer).
WithField("initiatorIndex", hs.Details.InitiatorIndex).WithField("responderIndex", hs.Details.ResponderIndex).
WithField("remoteIndex", h.RemoteIndex).WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).
WithField("certVersion", ci.myCert.Version()).
Error("Unable to handshake with host because no certificate handshake bytes is available")
return
}
hs.Details.CertVersion = uint32(ci.myCert.Version())
// Update the time in case their clock is way off from ours
hs.Details.Time = uint64(time.Now().UnixNano())
hsBytes, err := hs.Marshal()
if err != nil {
f.l.WithError(err).WithField("vpnIp", hostinfo.vpnIp).WithField("udpAddr", addr).
f.l.WithError(err).WithField("vpnAddrs", hostinfo.vpnAddrs).WithField("udpAddr", addr).
WithField("certName", certName).
WithField("fingerprint", fingerprint).
WithField("issuer", issuer).
@ -182,14 +261,14 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet
nh := header.Encode(make([]byte, header.Len), header.Version, header.Handshake, header.HandshakeIXPSK0, hs.Details.InitiatorIndex, 2)
msg, dKey, eKey, err := ci.H.WriteMessage(nh, hsBytes)
if err != nil {
f.l.WithError(err).WithField("vpnIp", hostinfo.vpnIp).WithField("udpAddr", addr).
f.l.WithError(err).WithField("vpnAddrs", hostinfo.vpnAddrs).WithField("udpAddr", addr).
WithField("certName", certName).
WithField("fingerprint", fingerprint).
WithField("issuer", issuer).
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).Error("Failed to call noise.WriteMessage")
return
} else if dKey == nil || eKey == nil {
f.l.WithField("vpnIp", hostinfo.vpnIp).WithField("udpAddr", addr).
f.l.WithField("vpnAddrs", hostinfo.vpnAddrs).WithField("udpAddr", addr).
WithField("certName", certName).
WithField("fingerprint", fingerprint).
WithField("issuer", issuer).
@ -213,9 +292,9 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet
ci.dKey = NewNebulaCipherState(dKey)
ci.eKey = NewNebulaCipherState(eKey)
hostinfo.remotes = f.lightHouse.QueryCache(vpnIp)
hostinfo.remotes = f.lightHouse.QueryCache(vpnAddrs)
hostinfo.SetRemote(addr)
hostinfo.CreateRemoteCIDR(remoteCert.Certificate)
hostinfo.buildNetworks(remoteCert.Certificate)
existing, err := f.handshakeManager.CheckAndComplete(hostinfo, 0, f)
if err != nil {
@ -225,7 +304,7 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet
if existing.SetRemoteIfPreferred(f.hostMap, addr) {
// Send a test packet to ensure the other side has also switched to
// the preferred remote
f.SendMessageToVpnIp(header.Test, header.TestRequest, vpnIp, []byte(""), make([]byte, 12, 12), make([]byte, mtu))
f.SendMessageToVpnAddr(header.Test, header.TestRequest, vpnAddrs[0], []byte(""), make([]byte, 12, 12), make([]byte, mtu))
}
msg = existing.HandshakePacket[2]
@ -233,11 +312,11 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet
if addr.IsValid() {
err := f.outside.WriteTo(msg, addr)
if err != nil {
f.l.WithField("vpnIp", existing.vpnIp).WithField("udpAddr", addr).
f.l.WithField("vpnAddrs", existing.vpnAddrs).WithField("udpAddr", addr).
WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).WithField("cached", true).
WithError(err).Error("Failed to send handshake message")
} else {
f.l.WithField("vpnIp", existing.vpnIp).WithField("udpAddr", addr).
f.l.WithField("vpnAddrs", existing.vpnAddrs).WithField("udpAddr", addr).
WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).WithField("cached", true).
Info("Handshake message sent")
}
@ -247,16 +326,16 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet
f.l.Error("Handshake send failed: both addr and via are nil.")
return
}
hostinfo.relayState.InsertRelayTo(via.relayHI.vpnIp)
hostinfo.relayState.InsertRelayTo(via.relayHI.vpnAddrs[0])
f.SendVia(via.relayHI, via.relay, msg, make([]byte, 12), make([]byte, mtu), false)
f.l.WithField("vpnIp", existing.vpnIp).WithField("relay", via.relayHI.vpnIp).
f.l.WithField("vpnAddrs", existing.vpnAddrs).WithField("relay", via.relayHI.vpnAddrs[0]).
WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).WithField("cached", true).
Info("Handshake message sent")
return
}
case ErrExistingHostInfo:
// This means there was an existing tunnel and this handshake was older than the one we are currently based on
f.l.WithField("vpnIp", vpnIp).WithField("udpAddr", addr).
f.l.WithField("vpnAddrs", vpnAddrs).WithField("udpAddr", addr).
WithField("certName", certName).
WithField("oldHandshakeTime", existing.lastHandshakeTime).
WithField("newHandshakeTime", hostinfo.lastHandshakeTime).
@ -267,23 +346,23 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet
Info("Handshake too old")
// Send a test packet to trigger an authenticated tunnel test, this should suss out any lingering tunnel issues
f.SendMessageToVpnIp(header.Test, header.TestRequest, vpnIp, []byte(""), make([]byte, 12, 12), make([]byte, mtu))
f.SendMessageToVpnAddr(header.Test, header.TestRequest, vpnAddrs[0], []byte(""), make([]byte, 12, 12), make([]byte, mtu))
return
case ErrLocalIndexCollision:
// This means we failed to insert because of collision on localIndexId. Just let the next handshake packet retry
f.l.WithField("vpnIp", vpnIp).WithField("udpAddr", addr).
f.l.WithField("vpnAddrs", vpnAddrs).WithField("udpAddr", addr).
WithField("certName", certName).
WithField("fingerprint", fingerprint).
WithField("issuer", issuer).
WithField("initiatorIndex", hs.Details.InitiatorIndex).WithField("responderIndex", hs.Details.ResponderIndex).
WithField("remoteIndex", h.RemoteIndex).WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).
WithField("localIndex", hostinfo.localIndexId).WithField("collision", existing.vpnIp).
WithField("localIndex", hostinfo.localIndexId).WithField("collision", existing.vpnAddrs).
Error("Failed to add HostInfo due to localIndex collision")
return
default:
// Shouldn't happen, but just in case someone adds a new error type to CheckAndComplete
// And we forget to update it here
f.l.WithError(err).WithField("vpnIp", vpnIp).WithField("udpAddr", addr).
f.l.WithError(err).WithField("vpnAddrs", vpnAddrs).WithField("udpAddr", addr).
WithField("certName", certName).
WithField("fingerprint", fingerprint).
WithField("issuer", issuer).
@ -299,7 +378,7 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet
if addr.IsValid() {
err = f.outside.WriteTo(msg, addr)
if err != nil {
f.l.WithField("vpnIp", vpnIp).WithField("udpAddr", addr).
f.l.WithField("vpnAddrs", vpnAddrs).WithField("udpAddr", addr).
WithField("certName", certName).
WithField("fingerprint", fingerprint).
WithField("issuer", issuer).
@ -307,7 +386,7 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet
WithField("remoteIndex", h.RemoteIndex).WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).
WithError(err).Error("Failed to send handshake")
} else {
f.l.WithField("vpnIp", vpnIp).WithField("udpAddr", addr).
f.l.WithField("vpnAddrs", vpnAddrs).WithField("udpAddr", addr).
WithField("certName", certName).
WithField("fingerprint", fingerprint).
WithField("issuer", issuer).
@ -320,9 +399,9 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet
f.l.Error("Handshake send failed: both addr and via are nil.")
return
}
hostinfo.relayState.InsertRelayTo(via.relayHI.vpnIp)
hostinfo.relayState.InsertRelayTo(via.relayHI.vpnAddrs[0])
f.SendVia(via.relayHI, via.relay, msg, make([]byte, 12), make([]byte, mtu), false)
f.l.WithField("vpnIp", vpnIp).WithField("relay", via.relayHI.vpnIp).
f.l.WithField("vpnAddrs", vpnAddrs).WithField("relay", via.relayHI.vpnAddrs[0]).
WithField("certName", certName).
WithField("fingerprint", fingerprint).
WithField("issuer", issuer).
@ -349,8 +428,9 @@ func ixHandshakeStage2(f *Interface, addr netip.AddrPort, via *ViaSender, hh *Ha
hostinfo := hh.hostinfo
if addr.IsValid() {
if !f.lightHouse.GetRemoteAllowList().Allow(hostinfo.vpnIp, addr.Addr()) {
f.l.WithField("vpnIp", hostinfo.vpnIp).WithField("udpAddr", addr).Debug("lighthouse.remote_allow_list denied incoming handshake")
//TODO: this is kind of nonsense now
if !f.lightHouse.GetRemoteAllowList().Allow(hostinfo.vpnAddrs[0], addr.Addr()) {
f.l.WithField("vpnIp", hostinfo.vpnAddrs).WithField("udpAddr", addr).Debug("lighthouse.remote_allow_list denied incoming handshake")
return false
}
}
@ -358,7 +438,7 @@ func ixHandshakeStage2(f *Interface, addr netip.AddrPort, via *ViaSender, hh *Ha
ci := hostinfo.ConnectionState
msg, eKey, dKey, err := ci.H.ReadMessage(nil, packet[header.Len:])
if err != nil {
f.l.WithError(err).WithField("vpnIp", hostinfo.vpnIp).WithField("udpAddr", addr).
f.l.WithError(err).WithField("vpnAddrs", hostinfo.vpnAddrs).WithField("udpAddr", addr).
WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).WithField("header", h).
Error("Failed to call noise.ReadMessage")
@ -367,7 +447,7 @@ func ixHandshakeStage2(f *Interface, addr netip.AddrPort, via *ViaSender, hh *Ha
// near future
return false
} else if dKey == nil || eKey == nil {
f.l.WithField("vpnIp", hostinfo.vpnIp).WithField("udpAddr", addr).
f.l.WithField("vpnAddrs", hostinfo.vpnAddrs).WithField("udpAddr", addr).
WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).
Error("Noise did not arrive at a key")
@ -379,16 +459,16 @@ func ixHandshakeStage2(f *Interface, addr netip.AddrPort, via *ViaSender, hh *Ha
hs := &NebulaHandshake{}
err = hs.Unmarshal(msg)
if err != nil || hs.Details == nil {
f.l.WithError(err).WithField("vpnIp", hostinfo.vpnIp).WithField("udpAddr", addr).
f.l.WithError(err).WithField("vpnAddrs", hostinfo.vpnAddrs).WithField("udpAddr", addr).
WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).Error("Failed unmarshal handshake message")
// The handshake state machine is complete, if things break now there is no chance to recover. Tear down and start again
return true
}
remoteCert, err := RecombineCertAndValidate(ci.H, hs.Details.Cert, f.pki.GetCAPool())
remoteCert, err := cert.RecombineAndValidate(cert.Version(hs.Details.CertVersion), hs.Details.Cert, ci.H.PeerStatic(), ci.Curve(), f.pki.GetCAPool())
if err != nil {
e := f.l.WithError(err).WithField("vpnIp", hostinfo.vpnIp).WithField("udpAddr", addr).
e := f.l.WithError(err).WithField("vpnAddrs", hostinfo.vpnAddrs).WithField("udpAddr", addr).
WithField("handshake", m{"stage": 2, "style": "ix_psk0"})
if f.l.Level > logrus.DebugLevel {
@ -413,7 +493,7 @@ func ixHandshakeStage2(f *Interface, addr netip.AddrPort, via *ViaSender, hh *Ha
return true
}
vpnIp := remoteCert.Certificate.Networks()[0].Addr().Unmap()
vpnNetworks := remoteCert.Certificate.Networks()
certName := remoteCert.Certificate.Name()
fingerprint := remoteCert.Fingerprint
issuer := remoteCert.Certificate.Issuer()
@ -430,12 +510,17 @@ func ixHandshakeStage2(f *Interface, addr netip.AddrPort, via *ViaSender, hh *Ha
if addr.IsValid() {
hostinfo.SetRemote(addr)
} else {
hostinfo.relayState.InsertRelayTo(via.relayHI.vpnIp)
hostinfo.relayState.InsertRelayTo(via.relayHI.vpnAddrs[0])
}
vpnAddrs := make([]netip.Addr, len(vpnNetworks))
for i, n := range vpnNetworks {
vpnAddrs[i] = n.Addr()
}
// Ensure the right host responded
if vpnIp != hostinfo.vpnIp {
f.l.WithField("intendedVpnIp", hostinfo.vpnIp).WithField("haveVpnIp", vpnIp).
if !slices.Contains(vpnAddrs, hostinfo.vpnAddrs[0]) {
f.l.WithField("intendedVpnAddrs", hostinfo.vpnAddrs).WithField("haveVpnNetworks", vpnNetworks).
WithField("udpAddr", addr).WithField("certName", certName).
WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).
Info("Incorrect host responded to handshake")
@ -444,14 +529,14 @@ func ixHandshakeStage2(f *Interface, addr netip.AddrPort, via *ViaSender, hh *Ha
f.handshakeManager.DeleteHostInfo(hostinfo)
// Create a new hostinfo/handshake for the intended vpn ip
f.handshakeManager.StartHandshake(hostinfo.vpnIp, func(newHH *HandshakeHostInfo) {
f.handshakeManager.StartHandshake(hostinfo.vpnAddrs[0], func(newHH *HandshakeHostInfo) {
//TODO: this doesnt know if its being added or is being used for caching a packet
// Block the current used address
newHH.hostinfo.remotes = hostinfo.remotes
newHH.hostinfo.remotes.BlockRemote(addr)
f.l.WithField("blockedUdpAddrs", newHH.hostinfo.remotes.CopyBlockedRemotes()).
WithField("vpnIp", newHH.hostinfo.vpnIp).
WithField("vpnNetworks", vpnNetworks).
WithField("remotes", newHH.hostinfo.remotes.CopyAddrs(f.hostMap.GetPreferredRanges())).
Info("Blocked addresses for handshakes")
@ -459,11 +544,8 @@ func ixHandshakeStage2(f *Interface, addr netip.AddrPort, via *ViaSender, hh *Ha
newHH.packetStore = hh.packetStore
hh.packetStore = []*cachedPacket{}
// Get the correct remote list for the host we did handshake with
hostinfo.SetRemote(addr)
hostinfo.remotes = f.lightHouse.QueryCache(vpnIp)
// Finally, put the correct vpn ip in the host info, tell them to close the tunnel, and return true to tear down
hostinfo.vpnIp = vpnIp
// Finally, put the correct vpn addrs in the host info, tell them to close the tunnel, and return true to tear down
hostinfo.vpnAddrs = vpnAddrs
f.sendCloseTunnel(hostinfo)
})
@ -474,7 +556,7 @@ func ixHandshakeStage2(f *Interface, addr netip.AddrPort, via *ViaSender, hh *Ha
ci.window.Update(f.l, 2)
duration := time.Since(hh.startTime).Nanoseconds()
f.l.WithField("vpnIp", vpnIp).WithField("udpAddr", addr).
f.l.WithField("vpnNetworks", vpnNetworks).WithField("udpAddr", addr).
WithField("certName", certName).
WithField("fingerprint", fingerprint).
WithField("issuer", issuer).
@ -485,7 +567,7 @@ func ixHandshakeStage2(f *Interface, addr netip.AddrPort, via *ViaSender, hh *Ha
Info("Handshake message received")
// Build up the radix for the firewall if we have subnets in the cert
hostinfo.CreateRemoteCIDR(remoteCert.Certificate)
hostinfo.buildNetworks(remoteCert.Certificate)
// Complete our handshake and update metrics, this will replace any existing tunnels for this vpnIp
f.handshakeManager.Complete(hostinfo, f)

View File

@ -13,6 +13,7 @@ import (
"github.com/rcrowley/go-metrics"
"github.com/sirupsen/logrus"
"github.com/slackhq/nebula/cert"
"github.com/slackhq/nebula/header"
"github.com/slackhq/nebula/udp"
)
@ -118,18 +119,18 @@ func NewHandshakeManager(l *logrus.Logger, mainHostMap *HostMap, lightHouse *Lig
}
}
func (c *HandshakeManager) Run(ctx context.Context) {
clockSource := time.NewTicker(c.config.tryInterval)
func (hm *HandshakeManager) Run(ctx context.Context) {
clockSource := time.NewTicker(hm.config.tryInterval)
defer clockSource.Stop()
for {
select {
case <-ctx.Done():
return
case vpnIP := <-c.trigger:
c.handleOutbound(vpnIP, true)
case vpnIP := <-hm.trigger:
hm.handleOutbound(vpnIP, true)
case now := <-clockSource.C:
c.NextOutboundHandshakeTimerTick(now)
hm.NextOutboundHandshakeTimerTick(now)
}
}
}
@ -159,14 +160,14 @@ func (hm *HandshakeManager) HandleIncoming(addr netip.AddrPort, via *ViaSender,
}
}
func (c *HandshakeManager) NextOutboundHandshakeTimerTick(now time.Time) {
c.OutboundHandshakeTimer.Advance(now)
func (hm *HandshakeManager) NextOutboundHandshakeTimerTick(now time.Time) {
hm.OutboundHandshakeTimer.Advance(now)
for {
vpnIp, has := c.OutboundHandshakeTimer.Purge()
vpnIp, has := hm.OutboundHandshakeTimer.Purge()
if !has {
break
}
c.handleOutbound(vpnIp, false)
hm.handleOutbound(vpnIp, false)
}
}
@ -208,7 +209,7 @@ func (hm *HandshakeManager) handleOutbound(vpnIp netip.Addr, lighthouseTriggered
// NB ^ This comment doesn't jive. It's how the thing gets initialized.
// It's the common path. Should it update every time, in case a future LH query/queries give us more info?
if hostinfo.remotes == nil {
hostinfo.remotes = hm.lightHouse.QueryCache(vpnIp)
hostinfo.remotes = hm.lightHouse.QueryCache([]netip.Addr{vpnIp})
}
remotes := hostinfo.remotes.CopyAddrs(hm.mainHostMap.GetPreferredRanges())
@ -267,11 +268,18 @@ func (hm *HandshakeManager) handleOutbound(vpnIp netip.Addr, lighthouseTriggered
hostinfo.logger(hm.l).WithField("relays", hostinfo.remotes.relays).Info("Attempt to relay through hosts")
// Send a RelayRequest to all known Relay IP's
for _, relay := range hostinfo.remotes.relays {
// Don't relay to myself, and don't relay through the host I'm trying to connect to
if relay == vpnIp || relay == hm.lightHouse.myVpnNet.Addr() {
// Don't relay to myself
if relay == vpnIp {
continue
}
relayHostInfo := hm.mainHostMap.QueryVpnIp(relay)
// Don't relay through the host I'm trying to connect to
_, found := hm.f.myVpnAddrsTable.Lookup(relay)
if found {
continue
}
relayHostInfo := hm.mainHostMap.QueryVpnAddr(relay)
if relayHostInfo == nil || !relayHostInfo.remote.IsValid() {
hostinfo.logger(hm.l).WithField("relay", relay.String()).Info("Establish tunnel to relay target")
hm.f.Handshake(relay)
@ -286,17 +294,35 @@ func (hm *HandshakeManager) handleOutbound(vpnIp netip.Addr, lighthouseTriggered
case Requested:
hostinfo.logger(hm.l).WithField("relay", relay.String()).Info("Re-send CreateRelay request")
//TODO: IPV6-WORK
myVpnIpB := hm.f.myVpnNet.Addr().As4()
theirVpnIpB := vpnIp.As4()
// Re-send the CreateRelay request, in case the previous one was lost.
m := NebulaControl{
Type: NebulaControl_CreateRelayRequest,
InitiatorRelayIndex: existingRelay.LocalIndex,
RelayFromIp: binary.BigEndian.Uint32(myVpnIpB[:]),
RelayToIp: binary.BigEndian.Uint32(theirVpnIpB[:]),
}
switch relayHostInfo.GetCert().Certificate.Version() {
case cert.Version1:
if !hm.f.myVpnAddrs[0].Is4() {
hostinfo.logger(hm.l).Error("can not establish v1 relay with a v6 network because the relay is not running a current nebula version")
continue
}
if !vpnIp.Is4() {
hostinfo.logger(hm.l).Error("can not establish v1 relay with a v6 remote network because the relay is not running a current nebula version")
continue
}
b := hm.f.myVpnAddrs[0].As4()
m.OldRelayFromAddr = binary.BigEndian.Uint32(b[:])
b = vpnIp.As4()
m.OldRelayToAddr = binary.BigEndian.Uint32(b[:])
case cert.Version2:
m.RelayFromAddr = netAddrToProtoAddr(hm.f.myVpnAddrs[0])
m.RelayToAddr = netAddrToProtoAddr(vpnIp)
default:
hostinfo.logger(hm.l).Error("Unknown certificate version found while creating relay")
continue
}
msg, err := m.Marshal()
if err != nil {
hostinfo.logger(hm.l).
@ -306,7 +332,7 @@ func (hm *HandshakeManager) handleOutbound(vpnIp netip.Addr, lighthouseTriggered
// This must send over the hostinfo, not over hm.Hosts[ip]
hm.f.SendMessageToHostInfo(header.Control, 0, relayHostInfo, msg, make([]byte, 12), make([]byte, mtu))
hm.l.WithFields(logrus.Fields{
"relayFrom": hm.f.myVpnNet.Addr(),
"relayFrom": hm.f.myVpnAddrs[0],
"relayTo": vpnIp,
"initiatorRelayIndex": existingRelay.LocalIndex,
"relay": relay}).
@ -316,7 +342,7 @@ func (hm *HandshakeManager) handleOutbound(vpnIp netip.Addr, lighthouseTriggered
hostinfo.logger(hm.l).
WithField("vpnIp", vpnIp).
WithField("state", existingRelay.State).
WithField("relay", relayHostInfo.vpnIp).
WithField("relay", relayHostInfo.vpnAddrs[0]).
Errorf("Relay unexpected state")
}
} else {
@ -327,16 +353,35 @@ func (hm *HandshakeManager) handleOutbound(vpnIp netip.Addr, lighthouseTriggered
hostinfo.logger(hm.l).WithField("relay", relay.String()).WithError(err).Info("Failed to add relay to hostmap")
}
//TODO: IPV6-WORK
myVpnIpB := hm.f.myVpnNet.Addr().As4()
theirVpnIpB := vpnIp.As4()
m := NebulaControl{
Type: NebulaControl_CreateRelayRequest,
InitiatorRelayIndex: idx,
RelayFromIp: binary.BigEndian.Uint32(myVpnIpB[:]),
RelayToIp: binary.BigEndian.Uint32(theirVpnIpB[:]),
}
switch relayHostInfo.GetCert().Certificate.Version() {
case cert.Version1:
if !hm.f.myVpnAddrs[0].Is4() {
hostinfo.logger(hm.l).Error("can not establish v1 relay with a v6 network because the relay is not running a current nebula version")
continue
}
if !vpnIp.Is4() {
hostinfo.logger(hm.l).Error("can not establish v1 relay with a v6 remote network because the relay is not running a current nebula version")
continue
}
b := hm.f.myVpnAddrs[0].As4()
m.OldRelayFromAddr = binary.BigEndian.Uint32(b[:])
b = vpnIp.As4()
m.OldRelayToAddr = binary.BigEndian.Uint32(b[:])
case cert.Version2:
m.RelayFromAddr = netAddrToProtoAddr(hm.f.myVpnAddrs[0])
m.RelayToAddr = netAddrToProtoAddr(vpnIp)
default:
hostinfo.logger(hm.l).Error("Unknown certificate version found while creating relay")
continue
}
msg, err := m.Marshal()
if err != nil {
hostinfo.logger(hm.l).
@ -345,7 +390,7 @@ func (hm *HandshakeManager) handleOutbound(vpnIp netip.Addr, lighthouseTriggered
} else {
hm.f.SendMessageToHostInfo(header.Control, 0, relayHostInfo, msg, make([]byte, 12), make([]byte, mtu))
hm.l.WithFields(logrus.Fields{
"relayFrom": hm.f.myVpnNet.Addr(),
"relayFrom": hm.f.myVpnAddrs[0],
"relayTo": vpnIp,
"initiatorRelayIndex": idx,
"relay": relay}).
@ -381,10 +426,10 @@ func (hm *HandshakeManager) GetOrHandshake(vpnIp netip.Addr, cacheCb func(*Hands
}
// StartHandshake will ensure a handshake is currently being attempted for the provided vpn ip
func (hm *HandshakeManager) StartHandshake(vpnIp netip.Addr, cacheCb func(*HandshakeHostInfo)) *HostInfo {
func (hm *HandshakeManager) StartHandshake(vpnAddr netip.Addr, cacheCb func(*HandshakeHostInfo)) *HostInfo {
hm.Lock()
if hh, ok := hm.vpnIps[vpnIp]; ok {
if hh, ok := hm.vpnIps[vpnAddr]; ok {
// We are already trying to handshake with this vpn ip
if cacheCb != nil {
cacheCb(hh)
@ -394,12 +439,12 @@ func (hm *HandshakeManager) StartHandshake(vpnIp netip.Addr, cacheCb func(*Hands
}
hostinfo := &HostInfo{
vpnIp: vpnIp,
vpnAddrs: []netip.Addr{vpnAddr},
HandshakePacket: make(map[uint8][]byte, 0),
relayState: RelayState{
relays: map[netip.Addr]struct{}{},
relayForByIp: map[netip.Addr]*Relay{},
relayForByIdx: map[uint32]*Relay{},
relays: map[netip.Addr]struct{}{},
relayForByAddr: map[netip.Addr]*Relay{},
relayForByIdx: map[uint32]*Relay{},
},
}
@ -407,9 +452,9 @@ func (hm *HandshakeManager) StartHandshake(vpnIp netip.Addr, cacheCb func(*Hands
hostinfo: hostinfo,
startTime: time.Now(),
}
hm.vpnIps[vpnIp] = hh
hm.vpnIps[vpnAddr] = hh
hm.metricInitiated.Inc(1)
hm.OutboundHandshakeTimer.Add(vpnIp, hm.config.tryInterval)
hm.OutboundHandshakeTimer.Add(vpnAddr, hm.config.tryInterval)
if cacheCb != nil {
cacheCb(hh)
@ -417,21 +462,21 @@ func (hm *HandshakeManager) StartHandshake(vpnIp netip.Addr, cacheCb func(*Hands
// If this is a static host, we don't need to wait for the HostQueryReply
// We can trigger the handshake right now
_, doTrigger := hm.lightHouse.GetStaticHostList()[vpnIp]
_, doTrigger := hm.lightHouse.GetStaticHostList()[vpnAddr]
if !doTrigger {
// Add any calculated remotes, and trigger early handshake if one found
doTrigger = hm.lightHouse.addCalculatedRemotes(vpnIp)
doTrigger = hm.lightHouse.addCalculatedRemotes(vpnAddr)
}
if doTrigger {
select {
case hm.trigger <- vpnIp:
case hm.trigger <- vpnAddr:
default:
}
}
hm.Unlock()
hm.lightHouse.QueryServer(vpnIp)
hm.lightHouse.QueryServer(vpnAddr)
return hostinfo
}
@ -452,14 +497,14 @@ var (
//
// ErrLocalIndexCollision if we already have an entry in the main or pending
// hostmap for the hostinfo.localIndexId.
func (c *HandshakeManager) CheckAndComplete(hostinfo *HostInfo, handshakePacket uint8, f *Interface) (*HostInfo, error) {
c.mainHostMap.Lock()
defer c.mainHostMap.Unlock()
c.Lock()
defer c.Unlock()
func (hm *HandshakeManager) CheckAndComplete(hostinfo *HostInfo, handshakePacket uint8, f *Interface) (*HostInfo, error) {
hm.mainHostMap.Lock()
defer hm.mainHostMap.Unlock()
hm.Lock()
defer hm.Unlock()
// Check if we already have a tunnel with this vpn ip
existingHostInfo, found := c.mainHostMap.Hosts[hostinfo.vpnIp]
existingHostInfo, found := hm.mainHostMap.Hosts[hostinfo.vpnAddrs[0]]
if found && existingHostInfo != nil {
testHostInfo := existingHostInfo
for testHostInfo != nil {
@ -476,31 +521,31 @@ func (c *HandshakeManager) CheckAndComplete(hostinfo *HostInfo, handshakePacket
return existingHostInfo, ErrExistingHostInfo
}
existingHostInfo.logger(c.l).Info("Taking new handshake")
existingHostInfo.logger(hm.l).Info("Taking new handshake")
}
existingIndex, found := c.mainHostMap.Indexes[hostinfo.localIndexId]
existingIndex, found := hm.mainHostMap.Indexes[hostinfo.localIndexId]
if found {
// We have a collision, but for a different hostinfo
return existingIndex, ErrLocalIndexCollision
}
existingPendingIndex, found := c.indexes[hostinfo.localIndexId]
existingPendingIndex, found := hm.indexes[hostinfo.localIndexId]
if found && existingPendingIndex.hostinfo != hostinfo {
// We have a collision, but for a different hostinfo
return existingPendingIndex.hostinfo, ErrLocalIndexCollision
}
existingRemoteIndex, found := c.mainHostMap.RemoteIndexes[hostinfo.remoteIndexId]
if found && existingRemoteIndex != nil && existingRemoteIndex.vpnIp != hostinfo.vpnIp {
existingRemoteIndex, found := hm.mainHostMap.RemoteIndexes[hostinfo.remoteIndexId]
if found && existingRemoteIndex != nil && existingRemoteIndex.vpnAddrs[0] != hostinfo.vpnAddrs[0] {
// We have a collision, but this can happen since we can't control
// the remote ID. Just log about the situation as a note.
hostinfo.logger(c.l).
WithField("remoteIndex", hostinfo.remoteIndexId).WithField("collision", existingRemoteIndex.vpnIp).
hostinfo.logger(hm.l).
WithField("remoteIndex", hostinfo.remoteIndexId).WithField("collision", existingRemoteIndex.vpnAddrs).
Info("New host shadows existing host remoteIndex")
}
c.mainHostMap.unlockedAddHostInfo(hostinfo, f)
hm.mainHostMap.unlockedAddHostInfo(hostinfo, f)
return existingHostInfo, nil
}
@ -518,7 +563,7 @@ func (hm *HandshakeManager) Complete(hostinfo *HostInfo, f *Interface) {
// We have a collision, but this can happen since we can't control
// the remote ID. Just log about the situation as a note.
hostinfo.logger(hm.l).
WithField("remoteIndex", hostinfo.remoteIndexId).WithField("collision", existingRemoteIndex.vpnIp).
WithField("remoteIndex", hostinfo.remoteIndexId).WithField("collision", existingRemoteIndex.vpnAddrs).
Info("New host shadows existing host remoteIndex")
}
@ -555,31 +600,34 @@ func (hm *HandshakeManager) allocateIndex(hh *HandshakeHostInfo) error {
return errors.New("failed to generate unique localIndexId")
}
func (c *HandshakeManager) DeleteHostInfo(hostinfo *HostInfo) {
c.Lock()
defer c.Unlock()
c.unlockedDeleteHostInfo(hostinfo)
func (hm *HandshakeManager) DeleteHostInfo(hostinfo *HostInfo) {
hm.Lock()
defer hm.Unlock()
hm.unlockedDeleteHostInfo(hostinfo)
}
func (c *HandshakeManager) unlockedDeleteHostInfo(hostinfo *HostInfo) {
delete(c.vpnIps, hostinfo.vpnIp)
if len(c.vpnIps) == 0 {
c.vpnIps = map[netip.Addr]*HandshakeHostInfo{}
func (hm *HandshakeManager) unlockedDeleteHostInfo(hostinfo *HostInfo) {
for _, addr := range hostinfo.vpnAddrs {
delete(hm.vpnIps, addr)
}
delete(c.indexes, hostinfo.localIndexId)
if len(c.vpnIps) == 0 {
c.indexes = map[uint32]*HandshakeHostInfo{}
if len(hm.vpnIps) == 0 {
hm.vpnIps = map[netip.Addr]*HandshakeHostInfo{}
}
if c.l.Level >= logrus.DebugLevel {
c.l.WithField("hostMap", m{"mapTotalSize": len(c.vpnIps),
"vpnIp": hostinfo.vpnIp, "indexNumber": hostinfo.localIndexId, "remoteIndexNumber": hostinfo.remoteIndexId}).
delete(hm.indexes, hostinfo.localIndexId)
if len(hm.indexes) == 0 {
hm.indexes = map[uint32]*HandshakeHostInfo{}
}
if hm.l.Level >= logrus.DebugLevel {
hm.l.WithField("hostMap", m{"mapTotalSize": len(hm.vpnIps),
"vpnAddrs": hostinfo.vpnAddrs, "indexNumber": hostinfo.localIndexId, "remoteIndexNumber": hostinfo.remoteIndexId}).
Debug("Pending hostmap hostInfo deleted")
}
}
func (hm *HandshakeManager) QueryVpnIp(vpnIp netip.Addr) *HostInfo {
func (hm *HandshakeManager) QueryVpnAddr(vpnIp netip.Addr) *HostInfo {
hh := hm.queryVpnIp(vpnIp)
if hh != nil {
return hh.hostinfo
@ -608,37 +656,37 @@ func (hm *HandshakeManager) queryIndex(index uint32) *HandshakeHostInfo {
return hm.indexes[index]
}
func (c *HandshakeManager) GetPreferredRanges() []netip.Prefix {
return c.mainHostMap.GetPreferredRanges()
func (hm *HandshakeManager) GetPreferredRanges() []netip.Prefix {
return hm.mainHostMap.GetPreferredRanges()
}
func (c *HandshakeManager) ForEachVpnIp(f controlEach) {
c.RLock()
defer c.RUnlock()
func (hm *HandshakeManager) ForEachVpnAddr(f controlEach) {
hm.RLock()
defer hm.RUnlock()
for _, v := range c.vpnIps {
for _, v := range hm.vpnIps {
f(v.hostinfo)
}
}
func (c *HandshakeManager) ForEachIndex(f controlEach) {
c.RLock()
defer c.RUnlock()
func (hm *HandshakeManager) ForEachIndex(f controlEach) {
hm.RLock()
defer hm.RUnlock()
for _, v := range c.indexes {
for _, v := range hm.indexes {
f(v.hostinfo)
}
}
func (c *HandshakeManager) EmitStats() {
c.RLock()
hostLen := len(c.vpnIps)
indexLen := len(c.indexes)
c.RUnlock()
func (hm *HandshakeManager) EmitStats() {
hm.RLock()
hostLen := len(hm.vpnIps)
indexLen := len(hm.indexes)
hm.RUnlock()
metrics.GetOrRegisterGauge("hostmap.pending.hosts", nil).Update(int64(hostLen))
metrics.GetOrRegisterGauge("hostmap.pending.indexes", nil).Update(int64(indexLen))
c.mainHostMap.EmitStats()
hm.mainHostMap.EmitStats()
}
// Utility functions below

View File

@ -5,6 +5,7 @@ import (
"testing"
"time"
"github.com/slackhq/nebula/cert"
"github.com/slackhq/nebula/header"
"github.com/slackhq/nebula/test"
"github.com/slackhq/nebula/udp"
@ -13,21 +14,20 @@ import (
func Test_NewHandshakeManagerVpnIp(t *testing.T) {
l := test.NewLogger()
vpncidr := netip.MustParsePrefix("172.1.1.1/24")
localrange := netip.MustParsePrefix("10.1.1.1/24")
ip := netip.MustParseAddr("172.1.1.2")
preferredRanges := []netip.Prefix{localrange}
mainHM := newHostMap(l, vpncidr)
mainHM := newHostMap(l)
mainHM.preferredRanges.Store(&preferredRanges)
lh := newTestLighthouse()
cs := &CertState{
RawCertificate: []byte{},
PrivateKey: []byte{},
Certificate: &dummyCert{},
RawCertificateNoKey: []byte{},
defaultVersion: cert.Version1,
privateKey: []byte{},
v1Cert: &dummyCert{version: cert.Version1},
v1HandshakeBytes: []byte{},
}
blah := NewHandshakeManager(l, mainHM, lh, &udp.NoopConn{}, defaultHandshakeConfig)
@ -41,7 +41,7 @@ func Test_NewHandshakeManagerVpnIp(t *testing.T) {
i2 := blah.StartHandshake(ip, nil)
assert.Same(t, i, i2)
i.remotes = NewRemoteList(nil)
i.remotes = NewRemoteList([]netip.Addr{}, nil)
// Adding something to pending should not affect the main hostmap
assert.Len(t, mainHM.Hosts, 0)
@ -79,16 +79,24 @@ func testCountTimerWheelEntries(tw *LockingTimerWheel[netip.Addr]) (c int) {
type mockEncWriter struct {
}
func (mw *mockEncWriter) SendMessageToVpnIp(t header.MessageType, st header.MessageSubType, vpnIp netip.Addr, p, nb, out []byte) {
func (mw *mockEncWriter) SendMessageToVpnAddr(_ header.MessageType, _ header.MessageSubType, _ netip.Addr, _, _, _ []byte) {
return
}
func (mw *mockEncWriter) SendVia(via *HostInfo, relay *Relay, ad, nb, out []byte, nocopy bool) {
func (mw *mockEncWriter) SendVia(_ *HostInfo, _ *Relay, _, _, _ []byte, _ bool) {
return
}
func (mw *mockEncWriter) SendMessageToHostInfo(t header.MessageType, st header.MessageSubType, hostinfo *HostInfo, p, nb, out []byte) {
func (mw *mockEncWriter) SendMessageToHostInfo(_ header.MessageType, _ header.MessageSubType, _ *HostInfo, _, _, _ []byte) {
return
}
func (mw *mockEncWriter) Handshake(vpnIP netip.Addr) {}
func (mw *mockEncWriter) Handshake(_ netip.Addr) {}
func (mw *mockEncWriter) GetHostInfo(_ netip.Addr) *HostInfo {
return nil
}
func (mw *mockEncWriter) GetCertState() *CertState {
return &CertState{defaultVersion: cert.Version2}
}

View File

@ -48,7 +48,7 @@ type Relay struct {
State int
LocalIndex uint32
RemoteIndex uint32
PeerIp netip.Addr
PeerAddr netip.Addr
}
type HostMap struct {
@ -58,7 +58,6 @@ type HostMap struct {
RemoteIndexes map[uint32]*HostInfo
Hosts map[netip.Addr]*HostInfo
preferredRanges atomic.Pointer[[]netip.Prefix]
vpnCIDR netip.Prefix
l *logrus.Logger
}
@ -68,9 +67,9 @@ type HostMap struct {
type RelayState struct {
sync.RWMutex
relays map[netip.Addr]struct{} // Set of VpnIp's of Hosts to use as relays to access this peer
relayForByIp map[netip.Addr]*Relay // Maps VpnIps of peers for which this HostInfo is a relay to some Relay info
relayForByIdx map[uint32]*Relay // Maps a local index to some Relay info
relays map[netip.Addr]struct{} // Set of vpnAddr's of Hosts to use as relays to access this peer
relayForByAddr map[netip.Addr]*Relay // Maps vpnAddr of peers for which this HostInfo is a relay to some Relay info
relayForByIdx map[uint32]*Relay // Maps a local index to some Relay info
}
func (rs *RelayState) DeleteRelay(ip netip.Addr) {
@ -89,10 +88,10 @@ func (rs *RelayState) CopyAllRelayFor() []*Relay {
return ret
}
func (rs *RelayState) GetRelayForByIp(ip netip.Addr) (*Relay, bool) {
func (rs *RelayState) GetRelayForByAddr(addr netip.Addr) (*Relay, bool) {
rs.RLock()
defer rs.RUnlock()
r, ok := rs.relayForByIp[ip]
r, ok := rs.relayForByAddr[addr]
return r, ok
}
@ -115,8 +114,8 @@ func (rs *RelayState) CopyRelayIps() []netip.Addr {
func (rs *RelayState) CopyRelayForIps() []netip.Addr {
rs.RLock()
defer rs.RUnlock()
currentRelays := make([]netip.Addr, 0, len(rs.relayForByIp))
for relayIp := range rs.relayForByIp {
currentRelays := make([]netip.Addr, 0, len(rs.relayForByAddr))
for relayIp := range rs.relayForByAddr {
currentRelays = append(currentRelays, relayIp)
}
return currentRelays
@ -135,7 +134,7 @@ func (rs *RelayState) CopyRelayForIdxs() []uint32 {
func (rs *RelayState) CompleteRelayByIP(vpnIp netip.Addr, remoteIdx uint32) bool {
rs.Lock()
defer rs.Unlock()
r, ok := rs.relayForByIp[vpnIp]
r, ok := rs.relayForByAddr[vpnIp]
if !ok {
return false
}
@ -143,7 +142,7 @@ func (rs *RelayState) CompleteRelayByIP(vpnIp netip.Addr, remoteIdx uint32) bool
newRelay.State = Established
newRelay.RemoteIndex = remoteIdx
rs.relayForByIdx[r.LocalIndex] = &newRelay
rs.relayForByIp[r.PeerIp] = &newRelay
rs.relayForByAddr[r.PeerAddr] = &newRelay
return true
}
@ -158,14 +157,14 @@ func (rs *RelayState) CompleteRelayByIdx(localIdx uint32, remoteIdx uint32) (*Re
newRelay.State = Established
newRelay.RemoteIndex = remoteIdx
rs.relayForByIdx[r.LocalIndex] = &newRelay
rs.relayForByIp[r.PeerIp] = &newRelay
rs.relayForByAddr[r.PeerAddr] = &newRelay
return &newRelay, true
}
func (rs *RelayState) QueryRelayForByIp(vpnIp netip.Addr) (*Relay, bool) {
rs.RLock()
defer rs.RUnlock()
r, ok := rs.relayForByIp[vpnIp]
r, ok := rs.relayForByAddr[vpnIp]
return r, ok
}
@ -179,7 +178,7 @@ func (rs *RelayState) QueryRelayForByIdx(idx uint32) (*Relay, bool) {
func (rs *RelayState) InsertRelay(ip netip.Addr, idx uint32, r *Relay) {
rs.Lock()
defer rs.Unlock()
rs.relayForByIp[ip] = r
rs.relayForByAddr[ip] = r
rs.relayForByIdx[idx] = r
}
@ -190,10 +189,12 @@ type HostInfo struct {
ConnectionState *ConnectionState
remoteIndexId uint32
localIndexId uint32
vpnIp netip.Addr
vpnAddrs []netip.Addr
recvError atomic.Uint32
remoteCidr *bart.Table[struct{}]
relayState RelayState
// networks are both all vpn and unsafe networks assigned to this host
networks *bart.Table[struct{}]
relayState RelayState
// HandshakePacket records the packets used to create this hostinfo
// We need these to avoid replayed handshake packets creating new hostinfos which causes churn
@ -241,28 +242,26 @@ type cachedPacketMetrics struct {
dropped metrics.Counter
}
func NewHostMapFromConfig(l *logrus.Logger, vpnCIDR netip.Prefix, c *config.C) *HostMap {
hm := newHostMap(l, vpnCIDR)
func NewHostMapFromConfig(l *logrus.Logger, c *config.C) *HostMap {
hm := newHostMap(l)
hm.reload(c, true)
c.RegisterReloadCallback(func(c *config.C) {
hm.reload(c, false)
})
l.WithField("network", hm.vpnCIDR.String()).
WithField("preferredRanges", hm.GetPreferredRanges()).
l.WithField("preferredRanges", hm.GetPreferredRanges()).
Info("Main HostMap created")
return hm
}
func newHostMap(l *logrus.Logger, vpnCIDR netip.Prefix) *HostMap {
func newHostMap(l *logrus.Logger) *HostMap {
return &HostMap{
Indexes: map[uint32]*HostInfo{},
Relays: map[uint32]*HostInfo{},
RemoteIndexes: map[uint32]*HostInfo{},
Hosts: map[netip.Addr]*HostInfo{},
vpnCIDR: vpnCIDR,
l: l,
}
}
@ -305,17 +304,6 @@ func (hm *HostMap) EmitStats() {
metrics.GetOrRegisterGauge("hostmap.main.relayIndexes", nil).Update(int64(relaysLen))
}
func (hm *HostMap) RemoveRelay(localIdx uint32) {
hm.Lock()
_, ok := hm.Relays[localIdx]
if !ok {
hm.Unlock()
return
}
delete(hm.Relays, localIdx)
hm.Unlock()
}
// DeleteHostInfo will fully unlink the hostinfo and return true if it was the final hostinfo for this vpn ip
func (hm *HostMap) DeleteHostInfo(hostinfo *HostInfo) bool {
// Delete the host itself, ensuring it's not modified anymore
@ -335,7 +323,9 @@ func (hm *HostMap) MakePrimary(hostinfo *HostInfo) {
}
func (hm *HostMap) unlockedMakePrimary(hostinfo *HostInfo) {
oldHostinfo := hm.Hosts[hostinfo.vpnIp]
//TODO: we may need to promote follow on hostinfos from these vpnAddrs as well since their oldHostinfo might not be the same as this one
// this really looks like an ideal spot for memory leaks
oldHostinfo := hm.Hosts[hostinfo.vpnAddrs[0]]
if oldHostinfo == hostinfo {
return
}
@ -348,7 +338,7 @@ func (hm *HostMap) unlockedMakePrimary(hostinfo *HostInfo) {
hostinfo.next.prev = hostinfo.prev
}
hm.Hosts[hostinfo.vpnIp] = hostinfo
hm.Hosts[hostinfo.vpnAddrs[0]] = hostinfo
if oldHostinfo == nil {
return
@ -360,23 +350,35 @@ func (hm *HostMap) unlockedMakePrimary(hostinfo *HostInfo) {
}
func (hm *HostMap) unlockedDeleteHostInfo(hostinfo *HostInfo) {
primary, ok := hm.Hosts[hostinfo.vpnIp]
for _, addr := range hostinfo.vpnAddrs {
h := hm.Hosts[addr]
for h != nil {
if h == hostinfo {
hm.unlockedInnerDeleteHostInfo(h, addr)
}
h = h.next
}
}
}
func (hm *HostMap) unlockedInnerDeleteHostInfo(hostinfo *HostInfo, addr netip.Addr) {
primary, ok := hm.Hosts[addr]
if ok && primary == hostinfo {
// The vpnIp pointer points to the same hostinfo as the local index id, we can remove it
delete(hm.Hosts, hostinfo.vpnIp)
// The vpn addr pointer points to the same hostinfo as the local index id, we can remove it
delete(hm.Hosts, addr)
if len(hm.Hosts) == 0 {
hm.Hosts = map[netip.Addr]*HostInfo{}
}
if hostinfo.next != nil {
// We had more than 1 hostinfo at this vpnip, promote the next in the list to primary
hm.Hosts[hostinfo.vpnIp] = hostinfo.next
// We had more than 1 hostinfo at this vpn addr, promote the next in the list to primary
hm.Hosts[addr] = hostinfo.next
// It is primary, there is no previous hostinfo now
hostinfo.next.prev = nil
}
} else {
// Relink if we were in the middle of multiple hostinfos for this vpn ip
// Relink if we were in the middle of multiple hostinfos for this vpn addr
if hostinfo.prev != nil {
hostinfo.prev.next = hostinfo.next
}
@ -406,7 +408,7 @@ func (hm *HostMap) unlockedDeleteHostInfo(hostinfo *HostInfo) {
if hm.l.Level >= logrus.DebugLevel {
hm.l.WithField("hostMap", m{"mapTotalSize": len(hm.Hosts),
"vpnIp": hostinfo.vpnIp, "indexNumber": hostinfo.localIndexId, "remoteIndexNumber": hostinfo.remoteIndexId}).
"vpnAddrs": hostinfo.vpnAddrs, "indexNumber": hostinfo.localIndexId, "remoteIndexNumber": hostinfo.remoteIndexId}).
Debug("Hostmap hostInfo deleted")
}
@ -448,11 +450,11 @@ func (hm *HostMap) QueryReverseIndex(index uint32) *HostInfo {
}
}
func (hm *HostMap) QueryVpnIp(vpnIp netip.Addr) *HostInfo {
return hm.queryVpnIp(vpnIp, nil)
func (hm *HostMap) QueryVpnAddr(vpnIp netip.Addr) *HostInfo {
return hm.queryVpnAddr(vpnIp, nil)
}
func (hm *HostMap) QueryVpnIpRelayFor(targetIp, relayHostIp netip.Addr) (*HostInfo, *Relay, error) {
func (hm *HostMap) QueryVpnAddrsRelayFor(targetIps []netip.Addr, relayHostIp netip.Addr) (*HostInfo, *Relay, error) {
hm.RLock()
defer hm.RUnlock()
@ -460,17 +462,21 @@ func (hm *HostMap) QueryVpnIpRelayFor(targetIp, relayHostIp netip.Addr) (*HostIn
if !ok {
return nil, nil, errors.New("unable to find host")
}
for h != nil {
r, ok := h.relayState.QueryRelayForByIp(targetIp)
if ok && r.State == Established {
return h, r, nil
for _, targetIp := range targetIps {
r, ok := h.relayState.QueryRelayForByIp(targetIp)
if ok && r.State == Established {
return h, r, nil
}
}
h = h.next
}
return nil, nil, errors.New("unable to find host with relay")
}
func (hm *HostMap) queryVpnIp(vpnIp netip.Addr, promoteIfce *Interface) *HostInfo {
func (hm *HostMap) queryVpnAddr(vpnIp netip.Addr, promoteIfce *Interface) *HostInfo {
hm.RLock()
if h, ok := hm.Hosts[vpnIp]; ok {
hm.RUnlock()
@ -491,25 +497,30 @@ func (hm *HostMap) queryVpnIp(vpnIp netip.Addr, promoteIfce *Interface) *HostInf
func (hm *HostMap) unlockedAddHostInfo(hostinfo *HostInfo, f *Interface) {
if f.serveDns {
remoteCert := hostinfo.ConnectionState.peerCert
dnsR.Add(remoteCert.Certificate.Name()+".", remoteCert.Certificate.Networks()[0].Addr().String())
dnsR.Add(remoteCert.Certificate.Name()+".", hostinfo.vpnAddrs)
}
existing := hm.Hosts[hostinfo.vpnIp]
hm.Hosts[hostinfo.vpnIp] = hostinfo
if existing != nil {
hostinfo.next = existing
existing.prev = hostinfo
for _, addr := range hostinfo.vpnAddrs {
hm.unlockedInnerAddHostInfo(addr, hostinfo, f)
}
hm.Indexes[hostinfo.localIndexId] = hostinfo
hm.RemoteIndexes[hostinfo.remoteIndexId] = hostinfo
if hm.l.Level >= logrus.DebugLevel {
hm.l.WithField("hostMap", m{"vpnIp": hostinfo.vpnIp, "mapTotalSize": len(hm.Hosts),
"hostinfo": m{"existing": true, "localIndexId": hostinfo.localIndexId, "hostId": hostinfo.vpnIp}}).
hm.l.WithField("hostMap", m{"vpnAddrs": hostinfo.vpnAddrs, "mapTotalSize": len(hm.Hosts),
"hostinfo": m{"existing": true, "localIndexId": hostinfo.localIndexId, "vpnAddrs": hostinfo.vpnAddrs}}).
Debug("Hostmap vpnIp added")
}
}
func (hm *HostMap) unlockedInnerAddHostInfo(vpnAddr netip.Addr, hostinfo *HostInfo, f *Interface) {
existing := hm.Hosts[vpnAddr]
hm.Hosts[vpnAddr] = hostinfo
if existing != nil && existing != hostinfo {
hostinfo.next = existing
existing.prev = hostinfo
}
i := 1
check := hostinfo
@ -527,7 +538,7 @@ func (hm *HostMap) GetPreferredRanges() []netip.Prefix {
return *hm.preferredRanges.Load()
}
func (hm *HostMap) ForEachVpnIp(f controlEach) {
func (hm *HostMap) ForEachVpnAddr(f controlEach) {
hm.RLock()
defer hm.RUnlock()
@ -581,7 +592,7 @@ func (i *HostInfo) TryPromoteBest(preferredRanges []netip.Prefix, ifce *Interfac
}
i.nextLHQuery.Store(now + ifce.reQueryWait.Load())
ifce.lightHouse.QueryServer(i.vpnIp)
ifce.lightHouse.QueryServer(i.vpnAddrs[0])
}
}
@ -596,7 +607,7 @@ func (i *HostInfo) SetRemote(remote netip.AddrPort) {
// We copy here because we likely got this remote from a source that reuses the object
if i.remote != remote {
i.remote = remote
i.remotes.LearnRemote(i.vpnIp, remote)
i.remotes.LearnRemote(i.vpnAddrs[0], remote)
}
}
@ -647,21 +658,20 @@ func (i *HostInfo) RecvErrorExceeded() bool {
return true
}
func (i *HostInfo) CreateRemoteCIDR(c cert.Certificate) {
func (i *HostInfo) buildNetworks(c cert.Certificate) {
if len(c.Networks()) == 1 && len(c.UnsafeNetworks()) == 0 {
// Simple case, no CIDRTree needed
return
}
remoteCidr := new(bart.Table[struct{}])
i.networks = new(bart.Table[struct{}])
for _, network := range c.Networks() {
remoteCidr.Insert(network, struct{}{})
i.networks.Insert(network, struct{}{})
}
for _, network := range c.UnsafeNetworks() {
remoteCidr.Insert(network, struct{}{})
i.networks.Insert(network, struct{}{})
}
i.remoteCidr = remoteCidr
}
func (i *HostInfo) logger(l *logrus.Logger) *logrus.Entry {
@ -669,7 +679,7 @@ func (i *HostInfo) logger(l *logrus.Logger) *logrus.Entry {
return logrus.NewEntry(l)
}
li := l.WithField("vpnIp", i.vpnIp).
li := l.WithField("vpnAddrs", i.vpnAddrs).
WithField("localIndex", i.localIndexId).
WithField("remoteIndex", i.remoteIndexId)
@ -684,9 +694,9 @@ func (i *HostInfo) logger(l *logrus.Logger) *logrus.Entry {
// Utility functions
func localIps(l *logrus.Logger, allowList *LocalAllowList) []netip.Addr {
func localAddrs(l *logrus.Logger, allowList *LocalAllowList) []netip.Addr {
//FIXME: This function is pretty garbage
var ips []netip.Addr
var finalAddrs []netip.Addr
ifaces, _ := net.Interfaces()
for _, i := range ifaces {
allow := allowList.AllowName(i.Name)
@ -698,39 +708,38 @@ func localIps(l *logrus.Logger, allowList *LocalAllowList) []netip.Addr {
continue
}
addrs, _ := i.Addrs()
for _, addr := range addrs {
var ip net.IP
switch v := addr.(type) {
for _, rawAddr := range addrs {
var addr netip.Addr
switch v := rawAddr.(type) {
case *net.IPNet:
//continue
ip = v.IP
addr, _ = netip.AddrFromSlice(v.IP)
case *net.IPAddr:
ip = v.IP
addr, _ = netip.AddrFromSlice(v.IP)
}
nip, ok := netip.AddrFromSlice(ip)
if !ok {
if !addr.IsValid() {
if l.Level >= logrus.DebugLevel {
l.WithField("localIp", ip).Debug("ip was invalid for netip")
l.WithField("localAddr", rawAddr).Debug("addr was invalid")
}
continue
}
nip = nip.Unmap()
addr = addr.Unmap()
//TODO: Filtering out link local for now, this is probably the most correct thing
//TODO: Would be nice to filter out SLAAC MAC based ips as well
if nip.IsLoopback() == false && nip.IsLinkLocalUnicast() == false {
allow := allowList.Allow(nip)
if addr.IsLoopback() == false && addr.IsLinkLocalUnicast() == false {
isAllowed := allowList.Allow(addr)
if l.Level >= logrus.TraceLevel {
l.WithField("localIp", nip).WithField("allow", allow).Trace("localAllowList.Allow")
l.WithField("localAddr", addr).WithField("allowed", isAllowed).Trace("localAllowList.Allow")
}
if !allow {
if !isAllowed {
continue
}
ips = append(ips, nip)
finalAddrs = append(finalAddrs, addr)
}
}
}
return ips
return finalAddrs
}

View File

@ -11,17 +11,14 @@ import (
func TestHostMap_MakePrimary(t *testing.T) {
l := test.NewLogger()
hm := newHostMap(
l,
netip.MustParsePrefix("10.0.0.1/24"),
)
hm := newHostMap(l)
f := &Interface{}
h1 := &HostInfo{vpnIp: netip.MustParseAddr("0.0.0.1"), localIndexId: 1}
h2 := &HostInfo{vpnIp: netip.MustParseAddr("0.0.0.1"), localIndexId: 2}
h3 := &HostInfo{vpnIp: netip.MustParseAddr("0.0.0.1"), localIndexId: 3}
h4 := &HostInfo{vpnIp: netip.MustParseAddr("0.0.0.1"), localIndexId: 4}
h1 := &HostInfo{vpnAddrs: []netip.Addr{netip.MustParseAddr("0.0.0.1")}, localIndexId: 1}
h2 := &HostInfo{vpnAddrs: []netip.Addr{netip.MustParseAddr("0.0.0.1")}, localIndexId: 2}
h3 := &HostInfo{vpnAddrs: []netip.Addr{netip.MustParseAddr("0.0.0.1")}, localIndexId: 3}
h4 := &HostInfo{vpnAddrs: []netip.Addr{netip.MustParseAddr("0.0.0.1")}, localIndexId: 4}
hm.unlockedAddHostInfo(h4, f)
hm.unlockedAddHostInfo(h3, f)
@ -29,7 +26,7 @@ func TestHostMap_MakePrimary(t *testing.T) {
hm.unlockedAddHostInfo(h1, f)
// Make sure we go h1 -> h2 -> h3 -> h4
prim := hm.QueryVpnIp(netip.MustParseAddr("0.0.0.1"))
prim := hm.QueryVpnAddr(netip.MustParseAddr("0.0.0.1"))
assert.Equal(t, h1.localIndexId, prim.localIndexId)
assert.Equal(t, h2.localIndexId, prim.next.localIndexId)
assert.Nil(t, prim.prev)
@ -44,7 +41,7 @@ func TestHostMap_MakePrimary(t *testing.T) {
hm.MakePrimary(h3)
// Make sure we go h3 -> h1 -> h2 -> h4
prim = hm.QueryVpnIp(netip.MustParseAddr("0.0.0.1"))
prim = hm.QueryVpnAddr(netip.MustParseAddr("0.0.0.1"))
assert.Equal(t, h3.localIndexId, prim.localIndexId)
assert.Equal(t, h1.localIndexId, prim.next.localIndexId)
assert.Nil(t, prim.prev)
@ -59,7 +56,7 @@ func TestHostMap_MakePrimary(t *testing.T) {
hm.MakePrimary(h4)
// Make sure we go h4 -> h3 -> h1 -> h2
prim = hm.QueryVpnIp(netip.MustParseAddr("0.0.0.1"))
prim = hm.QueryVpnAddr(netip.MustParseAddr("0.0.0.1"))
assert.Equal(t, h4.localIndexId, prim.localIndexId)
assert.Equal(t, h3.localIndexId, prim.next.localIndexId)
assert.Nil(t, prim.prev)
@ -74,7 +71,7 @@ func TestHostMap_MakePrimary(t *testing.T) {
hm.MakePrimary(h4)
// Make sure we go h4 -> h3 -> h1 -> h2
prim = hm.QueryVpnIp(netip.MustParseAddr("0.0.0.1"))
prim = hm.QueryVpnAddr(netip.MustParseAddr("0.0.0.1"))
assert.Equal(t, h4.localIndexId, prim.localIndexId)
assert.Equal(t, h3.localIndexId, prim.next.localIndexId)
assert.Nil(t, prim.prev)
@ -88,19 +85,16 @@ func TestHostMap_MakePrimary(t *testing.T) {
func TestHostMap_DeleteHostInfo(t *testing.T) {
l := test.NewLogger()
hm := newHostMap(
l,
netip.MustParsePrefix("10.0.0.1/24"),
)
hm := newHostMap(l)
f := &Interface{}
h1 := &HostInfo{vpnIp: netip.MustParseAddr("0.0.0.1"), localIndexId: 1}
h2 := &HostInfo{vpnIp: netip.MustParseAddr("0.0.0.1"), localIndexId: 2}
h3 := &HostInfo{vpnIp: netip.MustParseAddr("0.0.0.1"), localIndexId: 3}
h4 := &HostInfo{vpnIp: netip.MustParseAddr("0.0.0.1"), localIndexId: 4}
h5 := &HostInfo{vpnIp: netip.MustParseAddr("0.0.0.1"), localIndexId: 5}
h6 := &HostInfo{vpnIp: netip.MustParseAddr("0.0.0.1"), localIndexId: 6}
h1 := &HostInfo{vpnAddrs: []netip.Addr{netip.MustParseAddr("0.0.0.1")}, localIndexId: 1}
h2 := &HostInfo{vpnAddrs: []netip.Addr{netip.MustParseAddr("0.0.0.1")}, localIndexId: 2}
h3 := &HostInfo{vpnAddrs: []netip.Addr{netip.MustParseAddr("0.0.0.1")}, localIndexId: 3}
h4 := &HostInfo{vpnAddrs: []netip.Addr{netip.MustParseAddr("0.0.0.1")}, localIndexId: 4}
h5 := &HostInfo{vpnAddrs: []netip.Addr{netip.MustParseAddr("0.0.0.1")}, localIndexId: 5}
h6 := &HostInfo{vpnAddrs: []netip.Addr{netip.MustParseAddr("0.0.0.1")}, localIndexId: 6}
hm.unlockedAddHostInfo(h6, f)
hm.unlockedAddHostInfo(h5, f)
@ -116,7 +110,7 @@ func TestHostMap_DeleteHostInfo(t *testing.T) {
assert.Nil(t, h)
// Make sure we go h1 -> h2 -> h3 -> h4 -> h5
prim := hm.QueryVpnIp(netip.MustParseAddr("0.0.0.1"))
prim := hm.QueryVpnAddr(netip.MustParseAddr("0.0.0.1"))
assert.Equal(t, h1.localIndexId, prim.localIndexId)
assert.Equal(t, h2.localIndexId, prim.next.localIndexId)
assert.Nil(t, prim.prev)
@ -135,7 +129,7 @@ func TestHostMap_DeleteHostInfo(t *testing.T) {
assert.Nil(t, h1.next)
// Make sure we go h2 -> h3 -> h4 -> h5
prim = hm.QueryVpnIp(netip.MustParseAddr("0.0.0.1"))
prim = hm.QueryVpnAddr(netip.MustParseAddr("0.0.0.1"))
assert.Equal(t, h2.localIndexId, prim.localIndexId)
assert.Equal(t, h3.localIndexId, prim.next.localIndexId)
assert.Nil(t, prim.prev)
@ -153,7 +147,7 @@ func TestHostMap_DeleteHostInfo(t *testing.T) {
assert.Nil(t, h3.next)
// Make sure we go h2 -> h4 -> h5
prim = hm.QueryVpnIp(netip.MustParseAddr("0.0.0.1"))
prim = hm.QueryVpnAddr(netip.MustParseAddr("0.0.0.1"))
assert.Equal(t, h2.localIndexId, prim.localIndexId)
assert.Equal(t, h4.localIndexId, prim.next.localIndexId)
assert.Nil(t, prim.prev)
@ -169,7 +163,7 @@ func TestHostMap_DeleteHostInfo(t *testing.T) {
assert.Nil(t, h5.next)
// Make sure we go h2 -> h4
prim = hm.QueryVpnIp(netip.MustParseAddr("0.0.0.1"))
prim = hm.QueryVpnAddr(netip.MustParseAddr("0.0.0.1"))
assert.Equal(t, h2.localIndexId, prim.localIndexId)
assert.Equal(t, h4.localIndexId, prim.next.localIndexId)
assert.Nil(t, prim.prev)
@ -183,7 +177,7 @@ func TestHostMap_DeleteHostInfo(t *testing.T) {
assert.Nil(t, h2.next)
// Make sure we only have h4
prim = hm.QueryVpnIp(netip.MustParseAddr("0.0.0.1"))
prim = hm.QueryVpnAddr(netip.MustParseAddr("0.0.0.1"))
assert.Equal(t, h4.localIndexId, prim.localIndexId)
assert.Nil(t, prim.prev)
assert.Nil(t, prim.next)
@ -195,7 +189,7 @@ func TestHostMap_DeleteHostInfo(t *testing.T) {
assert.Nil(t, h4.next)
// Make sure we have nil
prim = hm.QueryVpnIp(netip.MustParseAddr("0.0.0.1"))
prim = hm.QueryVpnAddr(netip.MustParseAddr("0.0.0.1"))
assert.Nil(t, prim)
}
@ -203,11 +197,7 @@ func TestHostMap_reload(t *testing.T) {
l := test.NewLogger()
c := config.NewC(l)
hm := NewHostMapFromConfig(
l,
netip.MustParsePrefix("10.0.0.1/24"),
c,
)
hm := NewHostMapFromConfig(l, c)
toS := func(ipn []netip.Prefix) []string {
var s []string

View File

@ -9,8 +9,8 @@ import (
"net/netip"
)
func (i *HostInfo) GetVpnIp() netip.Addr {
return i.vpnIp
func (i *HostInfo) GetVpnAddrs() []netip.Addr {
return i.vpnAddrs
}
func (i *HostInfo) GetLocalIndex() uint32 {

View File

@ -20,14 +20,19 @@ func (f *Interface) consumeInsidePacket(packet []byte, fwPacket *firewall.Packet
}
// Ignore local broadcast packets
if f.dropLocalBroadcast && fwPacket.RemoteIP == f.myBroadcastAddr {
return
if f.dropLocalBroadcast {
_, found := f.myBroadcastAddrsTable.Lookup(fwPacket.RemoteAddr)
if found {
return
}
}
if fwPacket.RemoteIP == f.myVpnNet.Addr() {
//TODO: seems like a huge bummer
_, found := f.myVpnAddrsTable.Lookup(fwPacket.RemoteAddr)
if found {
// Immediately forward packets from self to self.
// This should only happen on Darwin-based and FreeBSD hosts, which
// routes packets from the Nebula IP to the Nebula IP through the Nebula
// routes packets from the Nebula addr to the Nebula addr through the Nebula
// TUN device.
if immediatelyForwardToSelf {
_, err := f.readers[q].Write(packet)
@ -36,25 +41,25 @@ func (f *Interface) consumeInsidePacket(packet []byte, fwPacket *firewall.Packet
}
}
// Otherwise, drop. On linux, we should never see these packets - Linux
// routes packets from the nebula IP to the nebula IP through the loopback device.
// routes packets from the nebula addr to the nebula addr through the loopback device.
return
}
// Ignore multicast packets
if f.dropMulticast && fwPacket.RemoteIP.IsMulticast() {
if f.dropMulticast && fwPacket.RemoteAddr.IsMulticast() {
return
}
hostinfo, ready := f.getOrHandshake(fwPacket.RemoteIP, func(hh *HandshakeHostInfo) {
hostinfo, ready := f.getOrHandshake(fwPacket.RemoteAddr, func(hh *HandshakeHostInfo) {
hh.cachePacket(f.l, header.Message, 0, packet, f.sendMessageNow, f.cachedPacketMetrics)
})
if hostinfo == nil {
f.rejectInside(packet, out, q)
if f.l.Level >= logrus.DebugLevel {
f.l.WithField("vpnIp", fwPacket.RemoteIP).
f.l.WithField("vpnAddr", fwPacket.RemoteAddr).
WithField("fwPacket", fwPacket).
Debugln("dropping outbound packet, vpnIp not in our CIDR or in unsafe routes")
Debugln("dropping outbound packet, vpnAddr not in our vpn networks or in unsafe networks")
}
return
}
@ -117,21 +122,22 @@ func (f *Interface) rejectOutside(packet []byte, ci *ConnectionState, hostinfo *
f.sendNoMetrics(header.Message, 0, ci, hostinfo, netip.AddrPort{}, out, nb, packet, q)
}
func (f *Interface) Handshake(vpnIp netip.Addr) {
f.getOrHandshake(vpnIp, nil)
func (f *Interface) Handshake(vpnAddr netip.Addr) {
f.getOrHandshake(vpnAddr, nil)
}
// getOrHandshake returns nil if the vpnIp is not routable.
// getOrHandshake returns nil if the vpnAddr is not routable.
// If the 2nd return var is false then the hostinfo is not ready to be used in a tunnel
func (f *Interface) getOrHandshake(vpnIp netip.Addr, cacheCallback func(*HandshakeHostInfo)) (*HostInfo, bool) {
if !f.myVpnNet.Contains(vpnIp) {
vpnIp = f.inside.RouteFor(vpnIp)
if !vpnIp.IsValid() {
func (f *Interface) getOrHandshake(vpnAddr netip.Addr, cacheCallback func(*HandshakeHostInfo)) (*HostInfo, bool) {
_, found := f.myVpnNetworksTable.Lookup(vpnAddr)
if !found {
vpnAddr = f.inside.RouteFor(vpnAddr)
if !vpnAddr.IsValid() {
return nil, false
}
}
return f.handshakeManager.GetOrHandshake(vpnIp, cacheCallback)
return f.handshakeManager.GetOrHandshake(vpnAddr, cacheCallback)
}
func (f *Interface) sendMessageNow(t header.MessageType, st header.MessageSubType, hostinfo *HostInfo, p, nb, out []byte) {
@ -156,16 +162,16 @@ func (f *Interface) sendMessageNow(t header.MessageType, st header.MessageSubTyp
f.sendNoMetrics(header.Message, st, hostinfo.ConnectionState, hostinfo, netip.AddrPort{}, p, nb, out, 0)
}
// SendMessageToVpnIp handles real ip:port lookup and sends to the current best known address for vpnIp
func (f *Interface) SendMessageToVpnIp(t header.MessageType, st header.MessageSubType, vpnIp netip.Addr, p, nb, out []byte) {
hostInfo, ready := f.getOrHandshake(vpnIp, func(hh *HandshakeHostInfo) {
// SendMessageToVpnAddr handles real addr:port lookup and sends to the current best known address for vpnAddr
func (f *Interface) SendMessageToVpnAddr(t header.MessageType, st header.MessageSubType, vpnAddr netip.Addr, p, nb, out []byte) {
hostInfo, ready := f.getOrHandshake(vpnAddr, func(hh *HandshakeHostInfo) {
hh.cachePacket(f.l, t, st, p, f.SendMessageToHostInfo, f.cachedPacketMetrics)
})
if hostInfo == nil {
if f.l.Level >= logrus.DebugLevel {
f.l.WithField("vpnIp", vpnIp).
Debugln("dropping SendMessageToVpnIp, vpnIp not in our CIDR or in unsafe routes")
f.l.WithField("vpnAddr", vpnAddr).
Debugln("dropping SendMessageToVpnAddr, vpnAddr not in our vpn networks or in unsafe routes")
}
return
}
@ -285,14 +291,14 @@ func (f *Interface) sendNoMetrics(t header.MessageType, st header.MessageSubType
f.connectionManager.Out(hostinfo.localIndexId)
// Query our LH if we haven't since the last time we've been rebound, this will cause the remote to punch against
// all our IPs and enable a faster roaming.
// all our addrs and enable a faster roaming.
if t != header.CloseTunnel && hostinfo.lastRebindCount != f.rebindCount {
//NOTE: there is an update hole if a tunnel isn't used and exactly 256 rebinds occur before the tunnel is
// finally used again. This tunnel would eventually be torn down and recreated if this action didn't help.
f.lightHouse.QueryServer(hostinfo.vpnIp)
f.lightHouse.QueryServer(hostinfo.vpnAddrs[0])
hostinfo.lastRebindCount = f.rebindCount
if f.l.Level >= logrus.DebugLevel {
f.l.WithField("vpnIp", hostinfo.vpnIp).Debug("Lighthouse update triggered for punch due to rebind counter")
f.l.WithField("vpnAddrs", hostinfo.vpnAddrs).Debug("Lighthouse update triggered for punch due to rebind counter")
}
}
@ -324,7 +330,7 @@ func (f *Interface) sendNoMetrics(t header.MessageType, st header.MessageSubType
} else {
// Try to send via a relay
for _, relayIP := range hostinfo.relayState.CopyRelayIps() {
relayHostInfo, relay, err := f.hostMap.QueryVpnIpRelayFor(hostinfo.vpnIp, relayIP)
relayHostInfo, relay, err := f.hostMap.QueryVpnAddrsRelayFor(hostinfo.vpnAddrs, relayIP)
if err != nil {
hostinfo.relayState.DeleteRelay(relayIP)
hostinfo.logger(f.l).WithField("relay", relayIP).WithError(err).Info("sendNoMetrics failed to find HostInfo")

View File

@ -2,17 +2,16 @@ package nebula
import (
"context"
"encoding/binary"
"errors"
"fmt"
"io"
"net"
"net/netip"
"os"
"runtime"
"sync/atomic"
"time"
"github.com/gaissmai/bart"
"github.com/rcrowley/go-metrics"
"github.com/sirupsen/logrus"
"github.com/slackhq/nebula/config"
@ -29,7 +28,6 @@ type InterfaceConfig struct {
Outside udp.Conn
Inside overlay.Device
pki *PKI
Cipher string
Firewall *Firewall
ServeDns bool
HandshakeManager *HandshakeManager
@ -53,25 +51,27 @@ type InterfaceConfig struct {
}
type Interface struct {
hostMap *HostMap
outside udp.Conn
inside overlay.Device
pki *PKI
cipher string
firewall *Firewall
connectionManager *connectionManager
handshakeManager *HandshakeManager
serveDns bool
createTime time.Time
lightHouse *LightHouse
myBroadcastAddr netip.Addr
myVpnNet netip.Prefix
dropLocalBroadcast bool
dropMulticast bool
routines int
disconnectInvalid atomic.Bool
closed atomic.Bool
relayManager *relayManager
hostMap *HostMap
outside udp.Conn
inside overlay.Device
pki *PKI
firewall *Firewall
connectionManager *connectionManager
handshakeManager *HandshakeManager
serveDns bool
createTime time.Time
lightHouse *LightHouse
myBroadcastAddrsTable *bart.Table[struct{}]
myVpnAddrs []netip.Addr // A list of addresses assigned to us via our certificate
myVpnAddrsTable *bart.Table[struct{}] // A table of addresses assigned to us via our certificate
myVpnNetworks []netip.Prefix // A table of networks assigned to us via our certificate
myVpnNetworksTable *bart.Table[struct{}] // A table of networks assigned to us via our certificate
dropLocalBroadcast bool
dropMulticast bool
routines int
disconnectInvalid atomic.Bool
closed atomic.Bool
relayManager *relayManager
tryPromoteEvery atomic.Uint32
reQueryEvery atomic.Uint32
@ -103,9 +103,11 @@ type EncWriter interface {
out []byte,
nocopy bool,
)
SendMessageToVpnIp(t header.MessageType, st header.MessageSubType, vpnIp netip.Addr, p, nb, out []byte)
SendMessageToVpnAddr(t header.MessageType, st header.MessageSubType, vpnAddr netip.Addr, p, nb, out []byte)
SendMessageToHostInfo(t header.MessageType, st header.MessageSubType, hostinfo *HostInfo, p, nb, out []byte)
Handshake(vpnIp netip.Addr)
Handshake(vpnAddr netip.Addr)
GetHostInfo(vpnAddr netip.Addr) *HostInfo
GetCertState() *CertState
}
type sendRecvErrorConfig uint8
@ -116,10 +118,10 @@ const (
sendRecvErrorPrivate
)
func (s sendRecvErrorConfig) ShouldSendRecvError(ip netip.AddrPort) bool {
func (s sendRecvErrorConfig) ShouldSendRecvError(endpoint netip.AddrPort) bool {
switch s {
case sendRecvErrorPrivate:
return ip.Addr().IsPrivate()
return endpoint.Addr().IsPrivate()
case sendRecvErrorAlways:
return true
case sendRecvErrorNever:
@ -156,27 +158,29 @@ func NewInterface(ctx context.Context, c *InterfaceConfig) (*Interface, error) {
return nil, errors.New("no firewall rules")
}
certificate := c.pki.GetCertState().Certificate
cs := c.pki.getCertState()
ifce := &Interface{
pki: c.pki,
hostMap: c.HostMap,
outside: c.Outside,
inside: c.Inside,
cipher: c.Cipher,
firewall: c.Firewall,
serveDns: c.ServeDns,
handshakeManager: c.HandshakeManager,
createTime: time.Now(),
lightHouse: c.lightHouse,
dropLocalBroadcast: c.DropLocalBroadcast,
dropMulticast: c.DropMulticast,
routines: c.routines,
version: c.version,
writers: make([]udp.Conn, c.routines),
readers: make([]io.ReadWriteCloser, c.routines),
myVpnNet: certificate.Networks()[0],
relayManager: c.relayManager,
pki: c.pki,
hostMap: c.HostMap,
outside: c.Outside,
inside: c.Inside,
firewall: c.Firewall,
serveDns: c.ServeDns,
handshakeManager: c.HandshakeManager,
createTime: time.Now(),
lightHouse: c.lightHouse,
dropLocalBroadcast: c.DropLocalBroadcast,
dropMulticast: c.DropMulticast,
routines: c.routines,
version: c.version,
writers: make([]udp.Conn, c.routines),
readers: make([]io.ReadWriteCloser, c.routines),
myVpnNetworks: cs.myVpnNetworks,
myVpnNetworksTable: cs.myVpnNetworksTable,
myVpnAddrs: cs.myVpnAddrs,
myVpnAddrsTable: cs.myVpnAddrsTable,
myBroadcastAddrsTable: cs.myVpnBroadcastAddrsTable,
relayManager: c.relayManager,
conntrackCacheTimeout: c.ConntrackCacheTimeout,
@ -190,14 +194,6 @@ func NewInterface(ctx context.Context, c *InterfaceConfig) (*Interface, error) {
l: c.l,
}
if ifce.myVpnNet.Addr().Is4() {
maskedAddr := certificate.Networks()[0].Masked()
addr := maskedAddr.Addr().As4()
mask := net.CIDRMask(maskedAddr.Bits(), maskedAddr.Addr().BitLen())
binary.BigEndian.PutUint32(addr[:], binary.BigEndian.Uint32(addr[:])|^binary.BigEndian.Uint32(mask))
ifce.myBroadcastAddr = netip.AddrFrom4(addr)
}
ifce.tryPromoteEvery.Store(c.tryPromoteEvery)
ifce.reQueryEvery.Store(c.reQueryEvery)
ifce.reQueryWait.Store(int64(c.reQueryWait))
@ -218,7 +214,7 @@ func (f *Interface) activate() {
f.l.WithError(err).Error("Failed to get udp listen address")
}
f.l.WithField("interface", f.inside.Name()).WithField("network", f.inside.Cidr().String()).
f.l.WithField("interface", f.inside.Name()).WithField("networks", f.myVpnNetworks).
WithField("build", f.version).WithField("udpAddr", addr).
WithField("boringcrypto", boringEnabled()).
Info("Nebula interface is active")
@ -259,16 +255,22 @@ func (f *Interface) listenOut(i int) {
runtime.LockOSThread()
var li udp.Conn
// TODO clean this up with a coherent interface for each outside connection
if i > 0 {
li = f.writers[i]
} else {
li = f.outside
}
ctCache := firewall.NewConntrackCacheTicker(f.conntrackCacheTimeout)
lhh := f.lightHouse.NewRequestHandler()
conntrackCache := firewall.NewConntrackCacheTicker(f.conntrackCacheTimeout)
li.ListenOut(readOutsidePackets(f), lhHandleRequest(lhh, f), conntrackCache, i)
plaintext := make([]byte, udp.MTU)
h := &header.H{}
fwPacket := &firewall.Packet{}
nb := make([]byte, 12, 12)
li.ListenOut(func(fromUdpAddr netip.AddrPort, payload []byte) {
f.readOutsidePackets(fromUdpAddr, nil, plaintext[:0], payload, h, fwPacket, lhh, nb, i, ctCache.Get(f.l))
})
}
func (f *Interface) listenIn(reader io.ReadWriteCloser, i int) {
@ -325,7 +327,7 @@ func (f *Interface) reloadFirewall(c *config.C) {
return
}
fw, err := NewFirewallFromConfig(f.l, f.pki.GetCertState().Certificate, c)
fw, err := NewFirewallFromConfig(f.l, f.pki.getCertState(), c)
if err != nil {
f.l.WithError(err).Error("Error while creating firewall during reload")
return
@ -417,11 +419,20 @@ func (f *Interface) emitStats(ctx context.Context, i time.Duration) {
f.firewall.EmitStats()
f.handshakeManager.EmitStats()
udpStats()
certExpirationGauge.Update(int64(f.pki.GetCertState().Certificate.NotAfter().Sub(time.Now()) / time.Second))
certExpirationGauge.Update(int64(f.pki.getDefaultCertificate().NotAfter().Sub(time.Now()) / time.Second))
//TODO: we should also report the default certificate version
}
}
}
func (f *Interface) GetHostInfo(vpnIp netip.Addr) *HostInfo {
return f.hostMap.QueryVpnAddr(vpnIp)
}
func (f *Interface) GetCertState() *CertState {
return f.pki.getCertState()
}
func (f *Interface) Close() error {
f.closed.Store(true)

File diff suppressed because it is too large Load Diff

View File

@ -7,6 +7,8 @@ import (
"net/netip"
"testing"
"github.com/gaissmai/bart"
"github.com/slackhq/nebula/cert"
"github.com/slackhq/nebula/config"
"github.com/slackhq/nebula/header"
"github.com/slackhq/nebula/test"
@ -19,57 +21,48 @@ import (
func TestOldIPv4Only(t *testing.T) {
// This test ensures our new ipv6 enabled LH protobuf IpAndPorts works with the old style to enable backwards compatibility
b := []byte{8, 129, 130, 132, 80, 16, 10}
var m Ip4AndPort
var m V4AddrPort
err := m.Unmarshal(b)
assert.NoError(t, err)
ip := netip.MustParseAddr("10.1.1.1")
bp := ip.As4()
assert.Equal(t, binary.BigEndian.Uint32(bp[:]), m.GetIp())
}
func TestNewLhQuery(t *testing.T) {
myIp, err := netip.ParseAddr("192.1.1.1")
assert.NoError(t, err)
// Generating a new lh query should work
a := NewLhQueryByInt(myIp)
// The result should be a nebulameta protobuf
assert.IsType(t, &NebulaMeta{}, a)
// It should also Marshal fine
b, err := a.Marshal()
assert.Nil(t, err)
// and then Unmarshal fine
n := &NebulaMeta{}
err = n.Unmarshal(b)
assert.Nil(t, err)
assert.Equal(t, binary.BigEndian.Uint32(bp[:]), m.GetAddr())
}
func Test_lhStaticMapping(t *testing.T) {
l := test.NewLogger()
myVpnNet := netip.MustParsePrefix("10.128.0.1/16")
nt := new(bart.Table[struct{}])
nt.Insert(myVpnNet, struct{}{})
cs := &CertState{
myVpnNetworks: []netip.Prefix{myVpnNet},
myVpnNetworksTable: nt,
}
lh1 := "10.128.0.2"
c := config.NewC(l)
c.Settings["lighthouse"] = map[interface{}]interface{}{"hosts": []interface{}{lh1}}
c.Settings["static_host_map"] = map[interface{}]interface{}{lh1: []interface{}{"1.1.1.1:4242"}}
_, err := NewLightHouseFromConfig(context.Background(), l, c, myVpnNet, nil, nil)
_, err := NewLightHouseFromConfig(context.Background(), l, c, cs, nil, nil)
assert.Nil(t, err)
lh2 := "10.128.0.3"
c = config.NewC(l)
c.Settings["lighthouse"] = map[interface{}]interface{}{"hosts": []interface{}{lh1, lh2}}
c.Settings["static_host_map"] = map[interface{}]interface{}{lh1: []interface{}{"100.1.1.1:4242"}}
_, err = NewLightHouseFromConfig(context.Background(), l, c, myVpnNet, nil, nil)
_, err = NewLightHouseFromConfig(context.Background(), l, c, cs, nil, nil)
assert.EqualError(t, err, "lighthouse 10.128.0.3 does not have a static_host_map entry")
}
func TestReloadLighthouseInterval(t *testing.T) {
l := test.NewLogger()
myVpnNet := netip.MustParsePrefix("10.128.0.1/16")
nt := new(bart.Table[struct{}])
nt.Insert(myVpnNet, struct{}{})
cs := &CertState{
myVpnNetworks: []netip.Prefix{myVpnNet},
myVpnNetworksTable: nt,
}
lh1 := "10.128.0.2"
c := config.NewC(l)
@ -79,7 +72,7 @@ func TestReloadLighthouseInterval(t *testing.T) {
}
c.Settings["static_host_map"] = map[interface{}]interface{}{lh1: []interface{}{"1.1.1.1:4242"}}
lh, err := NewLightHouseFromConfig(context.Background(), l, c, myVpnNet, nil, nil)
lh, err := NewLightHouseFromConfig(context.Background(), l, c, cs, nil, nil)
assert.NoError(t, err)
lh.ifce = &mockEncWriter{}
@ -99,9 +92,15 @@ func TestReloadLighthouseInterval(t *testing.T) {
func BenchmarkLighthouseHandleRequest(b *testing.B) {
l := test.NewLogger()
myVpnNet := netip.MustParsePrefix("10.128.0.1/0")
nt := new(bart.Table[struct{}])
nt.Insert(myVpnNet, struct{}{})
cs := &CertState{
myVpnNetworks: []netip.Prefix{myVpnNet},
myVpnNetworksTable: nt,
}
c := config.NewC(l)
lh, err := NewLightHouseFromConfig(context.Background(), l, c, myVpnNet, nil, nil)
lh, err := NewLightHouseFromConfig(context.Background(), l, c, cs, nil, nil)
if !assert.NoError(b, err) {
b.Fatal()
}
@ -110,46 +109,47 @@ func BenchmarkLighthouseHandleRequest(b *testing.B) {
hAddr2 := netip.MustParseAddrPort("4.5.6.7:12346")
vpnIp3 := netip.MustParseAddr("0.0.0.3")
lh.addrMap[vpnIp3] = NewRemoteList(nil)
lh.addrMap[vpnIp3] = NewRemoteList([]netip.Addr{vpnIp3}, nil)
lh.addrMap[vpnIp3].unlockedSetV4(
vpnIp3,
vpnIp3,
[]*Ip4AndPort{
NewIp4AndPortFromNetIP(hAddr.Addr(), hAddr.Port()),
NewIp4AndPortFromNetIP(hAddr2.Addr(), hAddr2.Port()),
[]*V4AddrPort{
netAddrToProtoV4AddrPort(hAddr.Addr(), hAddr.Port()),
netAddrToProtoV4AddrPort(hAddr2.Addr(), hAddr2.Port()),
},
func(netip.Addr, *Ip4AndPort) bool { return true },
func(netip.Addr, *V4AddrPort) bool { return true },
)
rAddr := netip.MustParseAddrPort("1.2.2.3:12345")
rAddr2 := netip.MustParseAddrPort("1.2.2.3:12346")
vpnIp2 := netip.MustParseAddr("0.0.0.3")
lh.addrMap[vpnIp2] = NewRemoteList(nil)
lh.addrMap[vpnIp2] = NewRemoteList([]netip.Addr{vpnIp2}, nil)
lh.addrMap[vpnIp2].unlockedSetV4(
vpnIp3,
vpnIp3,
[]*Ip4AndPort{
NewIp4AndPortFromNetIP(rAddr.Addr(), rAddr.Port()),
NewIp4AndPortFromNetIP(rAddr2.Addr(), rAddr2.Port()),
[]*V4AddrPort{
netAddrToProtoV4AddrPort(rAddr.Addr(), rAddr.Port()),
netAddrToProtoV4AddrPort(rAddr2.Addr(), rAddr2.Port()),
},
func(netip.Addr, *Ip4AndPort) bool { return true },
func(netip.Addr, *V4AddrPort) bool { return true },
)
mw := &mockEncWriter{}
hi := []netip.Addr{vpnIp2}
b.Run("notfound", func(b *testing.B) {
lhh := lh.NewRequestHandler()
req := &NebulaMeta{
Type: NebulaMeta_HostQuery,
Details: &NebulaMetaDetails{
VpnIp: 4,
Ip4AndPorts: nil,
OldVpnAddr: 4,
V4AddrPorts: nil,
},
}
p, err := req.Marshal()
assert.NoError(b, err)
for n := 0; n < b.N; n++ {
lhh.HandleRequest(rAddr, vpnIp2, p, mw)
lhh.HandleRequest(rAddr, hi, p, mw)
}
})
b.Run("found", func(b *testing.B) {
@ -157,15 +157,15 @@ func BenchmarkLighthouseHandleRequest(b *testing.B) {
req := &NebulaMeta{
Type: NebulaMeta_HostQuery,
Details: &NebulaMetaDetails{
VpnIp: 3,
Ip4AndPorts: nil,
OldVpnAddr: 3,
V4AddrPorts: nil,
},
}
p, err := req.Marshal()
assert.NoError(b, err)
for n := 0; n < b.N; n++ {
lhh.HandleRequest(rAddr, vpnIp2, p, mw)
lhh.HandleRequest(rAddr, hi, p, mw)
}
})
}
@ -197,40 +197,49 @@ func TestLighthouse_Memory(t *testing.T) {
c := config.NewC(l)
c.Settings["lighthouse"] = map[interface{}]interface{}{"am_lighthouse": true}
c.Settings["listen"] = map[interface{}]interface{}{"port": 4242}
lh, err := NewLightHouseFromConfig(context.Background(), l, c, netip.MustParsePrefix("10.128.0.1/24"), nil, nil)
myVpnNet := netip.MustParsePrefix("10.128.0.1/24")
nt := new(bart.Table[struct{}])
nt.Insert(myVpnNet, struct{}{})
cs := &CertState{
myVpnNetworks: []netip.Prefix{myVpnNet},
myVpnNetworksTable: nt,
}
lh, err := NewLightHouseFromConfig(context.Background(), l, c, cs, nil, nil)
lh.ifce = &mockEncWriter{}
assert.NoError(t, err)
lhh := lh.NewRequestHandler()
// Test that my first update responds with just that
newLHHostUpdate(myUdpAddr0, myVpnIp, []netip.AddrPort{myUdpAddr1, myUdpAddr2}, lhh)
r := newLHHostRequest(myUdpAddr0, myVpnIp, myVpnIp, lhh)
assertIp4InArray(t, r.msg.Details.Ip4AndPorts, myUdpAddr1, myUdpAddr2)
assertIp4InArray(t, r.msg.Details.V4AddrPorts, myUdpAddr1, myUdpAddr2)
// Ensure we don't accumulate addresses
newLHHostUpdate(myUdpAddr0, myVpnIp, []netip.AddrPort{myUdpAddr3}, lhh)
r = newLHHostRequest(myUdpAddr0, myVpnIp, myVpnIp, lhh)
assertIp4InArray(t, r.msg.Details.Ip4AndPorts, myUdpAddr3)
assertIp4InArray(t, r.msg.Details.V4AddrPorts, myUdpAddr3)
// Grow it back to 2
newLHHostUpdate(myUdpAddr0, myVpnIp, []netip.AddrPort{myUdpAddr1, myUdpAddr4}, lhh)
r = newLHHostRequest(myUdpAddr0, myVpnIp, myVpnIp, lhh)
assertIp4InArray(t, r.msg.Details.Ip4AndPorts, myUdpAddr1, myUdpAddr4)
assertIp4InArray(t, r.msg.Details.V4AddrPorts, myUdpAddr1, myUdpAddr4)
// Update a different host and ask about it
newLHHostUpdate(theirUdpAddr0, theirVpnIp, []netip.AddrPort{theirUdpAddr1, theirUdpAddr2, theirUdpAddr3, theirUdpAddr4}, lhh)
r = newLHHostRequest(theirUdpAddr0, theirVpnIp, theirVpnIp, lhh)
assertIp4InArray(t, r.msg.Details.Ip4AndPorts, theirUdpAddr1, theirUdpAddr2, theirUdpAddr3, theirUdpAddr4)
assertIp4InArray(t, r.msg.Details.V4AddrPorts, theirUdpAddr1, theirUdpAddr2, theirUdpAddr3, theirUdpAddr4)
// Have both hosts ask about the other
r = newLHHostRequest(theirUdpAddr0, theirVpnIp, myVpnIp, lhh)
assertIp4InArray(t, r.msg.Details.Ip4AndPorts, myUdpAddr1, myUdpAddr4)
assertIp4InArray(t, r.msg.Details.V4AddrPorts, myUdpAddr1, myUdpAddr4)
r = newLHHostRequest(myUdpAddr0, myVpnIp, theirVpnIp, lhh)
assertIp4InArray(t, r.msg.Details.Ip4AndPorts, theirUdpAddr1, theirUdpAddr2, theirUdpAddr3, theirUdpAddr4)
assertIp4InArray(t, r.msg.Details.V4AddrPorts, theirUdpAddr1, theirUdpAddr2, theirUdpAddr3, theirUdpAddr4)
// Make sure we didn't get changed
r = newLHHostRequest(myUdpAddr0, myVpnIp, myVpnIp, lhh)
assertIp4InArray(t, r.msg.Details.Ip4AndPorts, myUdpAddr1, myUdpAddr4)
assertIp4InArray(t, r.msg.Details.V4AddrPorts, myUdpAddr1, myUdpAddr4)
// Ensure proper ordering and limiting
// Send 12 addrs, get 10 back, the last 2 removed, allowing the duplicate to remain (clients dedupe)
@ -255,7 +264,7 @@ func TestLighthouse_Memory(t *testing.T) {
r = newLHHostRequest(myUdpAddr0, myVpnIp, myVpnIp, lhh)
assertIp4InArray(
t,
r.msg.Details.Ip4AndPorts,
r.msg.Details.V4AddrPorts,
myUdpAddr1, myUdpAddr2, myUdpAddr3, myUdpAddr4, myUdpAddr5, myUdpAddr5, myUdpAddr6, myUdpAddr7, myUdpAddr8, myUdpAddr9,
)
@ -265,7 +274,7 @@ func TestLighthouse_Memory(t *testing.T) {
good := netip.MustParseAddrPort("1.128.0.99:4242")
newLHHostUpdate(myUdpAddr0, myVpnIp, []netip.AddrPort{bad1, bad2, good}, lhh)
r = newLHHostRequest(myUdpAddr0, myVpnIp, myVpnIp, lhh)
assertIp4InArray(t, r.msg.Details.Ip4AndPorts, good)
assertIp4InArray(t, r.msg.Details.V4AddrPorts, good)
}
func TestLighthouse_reload(t *testing.T) {
@ -273,7 +282,16 @@ func TestLighthouse_reload(t *testing.T) {
c := config.NewC(l)
c.Settings["lighthouse"] = map[interface{}]interface{}{"am_lighthouse": true}
c.Settings["listen"] = map[interface{}]interface{}{"port": 4242}
lh, err := NewLightHouseFromConfig(context.Background(), l, c, netip.MustParsePrefix("10.128.0.1/24"), nil, nil)
myVpnNet := netip.MustParsePrefix("10.128.0.1/24")
nt := new(bart.Table[struct{}])
nt.Insert(myVpnNet, struct{}{})
cs := &CertState{
myVpnNetworks: []netip.Prefix{myVpnNet},
myVpnNetworksTable: nt,
}
lh, err := NewLightHouseFromConfig(context.Background(), l, c, cs, nil, nil)
assert.NoError(t, err)
nc := map[interface{}]interface{}{
@ -295,7 +313,7 @@ func newLHHostRequest(fromAddr netip.AddrPort, myVpnIp, queryVpnIp netip.Addr, l
req := &NebulaMeta{
Type: NebulaMeta_HostQuery,
Details: &NebulaMetaDetails{
VpnIp: binary.BigEndian.Uint32(bip[:]),
OldVpnAddr: binary.BigEndian.Uint32(bip[:]),
},
}
@ -308,7 +326,7 @@ func newLHHostRequest(fromAddr netip.AddrPort, myVpnIp, queryVpnIp netip.Addr, l
w := &testEncWriter{
metaFilter: &filter,
}
lhh.HandleRequest(fromAddr, myVpnIp, b, w)
lhh.HandleRequest(fromAddr, []netip.Addr{myVpnIp}, b, w)
return w.lastReply
}
@ -318,13 +336,13 @@ func newLHHostUpdate(fromAddr netip.AddrPort, vpnIp netip.Addr, addrs []netip.Ad
req := &NebulaMeta{
Type: NebulaMeta_HostUpdateNotification,
Details: &NebulaMetaDetails{
VpnIp: binary.BigEndian.Uint32(bip[:]),
Ip4AndPorts: make([]*Ip4AndPort, len(addrs)),
OldVpnAddr: binary.BigEndian.Uint32(bip[:]),
V4AddrPorts: make([]*V4AddrPort, len(addrs)),
},
}
for k, v := range addrs {
req.Details.Ip4AndPorts[k] = NewIp4AndPortFromNetIP(v.Addr(), v.Port())
req.Details.V4AddrPorts[k] = netAddrToProtoV4AddrPort(v.Addr(), v.Port())
}
b, err := req.Marshal()
@ -333,7 +351,7 @@ func newLHHostUpdate(fromAddr netip.AddrPort, vpnIp netip.Addr, addrs []netip.Ad
}
w := &testEncWriter{}
lhh.HandleRequest(fromAddr, vpnIp, b, w)
lhh.HandleRequest(fromAddr, []netip.Addr{vpnIp}, b, w)
}
//TODO: this is a RemoteList test
@ -410,8 +428,9 @@ type testLhReply struct {
}
type testEncWriter struct {
lastReply testLhReply
metaFilter *NebulaMeta_MessageType
lastReply testLhReply
metaFilter *NebulaMeta_MessageType
protocolVersion cert.Version
}
func (tw *testEncWriter) SendVia(via *HostInfo, relay *Relay, ad, nb, out []byte, nocopy bool) {
@ -426,7 +445,7 @@ func (tw *testEncWriter) SendMessageToHostInfo(t header.MessageType, st header.M
tw.lastReply = testLhReply{
nebType: t,
nebSubType: st,
vpnIp: hostinfo.vpnIp,
vpnIp: hostinfo.vpnAddrs[0],
msg: msg,
}
}
@ -436,7 +455,7 @@ func (tw *testEncWriter) SendMessageToHostInfo(t header.MessageType, st header.M
}
}
func (tw *testEncWriter) SendMessageToVpnIp(t header.MessageType, st header.MessageSubType, vpnIp netip.Addr, p, _, _ []byte) {
func (tw *testEncWriter) SendMessageToVpnAddr(t header.MessageType, st header.MessageSubType, vpnIp netip.Addr, p, _, _ []byte) {
msg := &NebulaMeta{}
err := msg.Unmarshal(p)
if tw.metaFilter == nil || msg.Type == *tw.metaFilter {
@ -453,15 +472,23 @@ func (tw *testEncWriter) SendMessageToVpnIp(t header.MessageType, st header.Mess
}
}
func (tw *testEncWriter) GetHostInfo(vpnIp netip.Addr) *HostInfo {
return nil
}
func (tw *testEncWriter) GetCertState() *CertState {
return &CertState{defaultVersion: tw.protocolVersion}
}
// assertIp4InArray asserts every address in want is at the same position in have and that the lengths match
func assertIp4InArray(t *testing.T, have []*Ip4AndPort, want ...netip.AddrPort) {
func assertIp4InArray(t *testing.T, have []*V4AddrPort, want ...netip.AddrPort) {
if !assert.Len(t, have, len(want)) {
return
}
for k, w := range want {
//TODO: IPV6-WORK
h := AddrPortFromIp4AndPort(have[k])
h := protoV4AddrPortToNetAddrPort(have[k])
if !(h == w) {
assert.Fail(t, fmt.Sprintf("Response did not contain: %v at %v, found %v", w, k, h))
}

24
main.go
View File

@ -2,7 +2,6 @@ package nebula
import (
"context"
"encoding/binary"
"fmt"
"net"
"net/netip"
@ -61,15 +60,12 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
return nil, util.ContextualizeIfNeeded("Failed to load PKI from config", err)
}
certificate := pki.GetCertState().Certificate
fw, err := NewFirewallFromConfig(l, certificate, c)
fw, err := NewFirewallFromConfig(l, pki.getCertState(), c)
if err != nil {
return nil, util.ContextualizeIfNeeded("Error while loading firewall rules", err)
}
l.WithField("firewallHashes", fw.GetRuleHashes()).Info("Firewall started")
tunCidr := certificate.Networks()[0]
ssh, err := sshd.NewSSHServer(l.WithField("subsystem", "sshd"))
if err != nil {
return nil, util.ContextualizeIfNeeded("Error while creating SSH server", err)
@ -132,7 +128,7 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
deviceFactory = overlay.NewDeviceFromConfig
}
tun, err = deviceFactory(c, l, tunCidr, routines)
tun, err = deviceFactory(c, l, pki.getCertState().myVpnNetworks, routines)
if err != nil {
return nil, util.ContextualizeIfNeeded("Failed to get a tun/tap device", err)
}
@ -187,9 +183,9 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
}
}
hostMap := NewHostMapFromConfig(l, tunCidr, c)
hostMap := NewHostMapFromConfig(l, c)
punchy := NewPunchyFromConfig(l, c)
lightHouse, err := NewLightHouseFromConfig(ctx, l, c, tunCidr, udpConns[0], punchy)
lightHouse, err := NewLightHouseFromConfig(ctx, l, c, pki.getCertState(), udpConns[0], punchy)
if err != nil {
return nil, util.ContextualizeIfNeeded("Failed to initialize lighthouse handler", err)
}
@ -232,7 +228,6 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
Inside: tun,
Outside: udpConns[0],
pki: pki,
Cipher: c.GetString("cipher", "aes"),
Firewall: fw,
ServeDns: serveDns,
HandshakeManager: handshakeManager,
@ -254,15 +249,6 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
l: l,
}
switch ifConfig.Cipher {
case "aes":
noiseEndianness = binary.BigEndian
case "chachapoly":
noiseEndianness = binary.LittleEndian
default:
return nil, fmt.Errorf("unknown cipher: %v", ifConfig.Cipher)
}
var ifce *Interface
if !configTest {
ifce, err = NewInterface(ctx, ifConfig)
@ -303,7 +289,7 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
var dnsStart func()
if lightHouse.amLighthouse && serveDns {
l.Debugln("Starting dns server")
dnsStart = dnsMain(l, hostMap, c)
dnsStart = dnsMain(l, pki.getCertState(), hostMap, c)
}
return &Control{

File diff suppressed because it is too large Load Diff

View File

@ -23,19 +23,28 @@ message NebulaMeta {
}
message NebulaMetaDetails {
uint32 VpnIp = 1;
repeated Ip4AndPort Ip4AndPorts = 2;
repeated Ip6AndPort Ip6AndPorts = 4;
repeated uint32 RelayVpnIp = 5;
uint32 OldVpnAddr = 1 [deprecated = true];
Addr VpnAddr = 6;
repeated uint32 OldRelayVpnAddrs = 5 [deprecated = true];
repeated Addr RelayVpnAddrs = 7;
repeated V4AddrPort V4AddrPorts = 2;
repeated V6AddrPort V6AddrPorts = 4;
uint32 counter = 3;
}
message Ip4AndPort {
uint32 Ip = 1;
message Addr {
uint64 Hi = 1;
uint64 Lo = 2;
}
message V4AddrPort {
uint32 Addr = 1;
uint32 Port = 2;
}
message Ip6AndPort {
message V6AddrPort {
uint64 Hi = 1;
uint64 Lo = 2;
uint32 Port = 3;
@ -62,6 +71,7 @@ message NebulaHandshakeDetails {
uint32 ResponderIndex = 3;
uint64 Cookie = 4;
uint64 Time = 5;
uint32 CertVersion = 8;
// reserved for WIP multiport
reserved 6, 7;
}
@ -76,6 +86,10 @@ message NebulaControl {
uint32 InitiatorRelayIndex = 2;
uint32 ResponderRelayIndex = 3;
uint32 RelayToIp = 4;
uint32 RelayFromIp = 5;
uint32 OldRelayToAddr = 4 [deprecated = true];
uint32 OldRelayFromAddr = 5 [deprecated = true];
Addr RelayToAddr = 6;
Addr RelayFromAddr = 7;
}

View File

@ -7,12 +7,12 @@ import (
"net/netip"
"time"
"github.com/flynn/noise"
"github.com/google/gopacket/layers"
"golang.org/x/net/ipv6"
"github.com/sirupsen/logrus"
"github.com/slackhq/nebula/cert"
"github.com/slackhq/nebula/firewall"
"github.com/slackhq/nebula/header"
"github.com/slackhq/nebula/udp"
"golang.org/x/net/ipv4"
)
@ -20,24 +20,7 @@ const (
minFwPacketLen = 4
)
// TODO: IPV6-WORK this can likely be removed now
func readOutsidePackets(f *Interface) udp.EncReader {
return func(
addr netip.AddrPort,
out []byte,
packet []byte,
header *header.H,
fwPacket *firewall.Packet,
lhh udp.LightHouseHandlerFunc,
nb []byte,
q int,
localCache firewall.ConntrackCache,
) {
f.readOutsidePackets(addr, nil, out, packet, header, fwPacket, lhh, nb, q, localCache)
}
}
func (f *Interface) readOutsidePackets(ip netip.AddrPort, via *ViaSender, out []byte, packet []byte, h *header.H, fwPacket *firewall.Packet, lhf udp.LightHouseHandlerFunc, nb []byte, q int, localCache firewall.ConntrackCache) {
func (f *Interface) readOutsidePackets(ip netip.AddrPort, via *ViaSender, out []byte, packet []byte, h *header.H, fwPacket *firewall.Packet, lhf *LightHouseHandler, nb []byte, q int, localCache firewall.ConntrackCache) {
err := h.Parse(packet)
if err != nil {
// TODO: best if we return this and let caller log
@ -51,7 +34,8 @@ func (f *Interface) readOutsidePackets(ip netip.AddrPort, via *ViaSender, out []
//l.Error("in packet ", header, packet[HeaderLen:])
if ip.IsValid() {
if f.myVpnNet.Contains(ip.Addr()) {
_, found := f.myVpnNetworksTable.Lookup(ip.Addr())
if found {
if f.l.Level >= logrus.DebugLevel {
f.l.WithField("udpAddr", ip).Debug("Refusing to process double encrypted packet")
}
@ -108,7 +92,7 @@ func (f *Interface) readOutsidePackets(ip netip.AddrPort, via *ViaSender, out []
if !ok {
// The only way this happens is if hostmap has an index to the correct HostInfo, but the HostInfo is missing
// its internal mapping. This should never happen.
hostinfo.logger(f.l).WithFields(logrus.Fields{"vpnIp": hostinfo.vpnIp, "remoteIndex": h.RemoteIndex}).Error("HostInfo missing remote relay index")
hostinfo.logger(f.l).WithFields(logrus.Fields{"vpnAddrs": hostinfo.vpnAddrs, "remoteIndex": h.RemoteIndex}).Error("HostInfo missing remote relay index")
return
}
@ -120,9 +104,9 @@ func (f *Interface) readOutsidePackets(ip netip.AddrPort, via *ViaSender, out []
return
case ForwardingType:
// Find the target HostInfo relay object
targetHI, targetRelay, err := f.hostMap.QueryVpnIpRelayFor(hostinfo.vpnIp, relay.PeerIp)
targetHI, targetRelay, err := f.hostMap.QueryVpnAddrsRelayFor(hostinfo.vpnAddrs, relay.PeerAddr)
if err != nil {
hostinfo.logger(f.l).WithField("relayTo", relay.PeerIp).WithError(err).Info("Failed to find target host info by ip")
hostinfo.logger(f.l).WithField("relayTo", relay.PeerAddr).WithError(err).WithField("hostinfo.vpnAddrs", hostinfo.vpnAddrs).Info("Failed to find target host info by ip")
return
}
@ -138,7 +122,7 @@ func (f *Interface) readOutsidePackets(ip netip.AddrPort, via *ViaSender, out []
hostinfo.logger(f.l).Error("Unexpected Relay Type of Terminal")
}
} else {
hostinfo.logger(f.l).WithFields(logrus.Fields{"relayTo": relay.PeerIp, "relayFrom": hostinfo.vpnIp, "targetRelayState": targetRelay.State}).Info("Unexpected target relay state")
hostinfo.logger(f.l).WithFields(logrus.Fields{"relayTo": relay.PeerAddr, "relayFrom": hostinfo.vpnAddrs[0], "targetRelayState": targetRelay.State}).Info("Unexpected target relay state")
return
}
}
@ -161,7 +145,7 @@ func (f *Interface) readOutsidePackets(ip netip.AddrPort, via *ViaSender, out []
return
}
lhf(ip, hostinfo.vpnIp, d)
lhf.HandleRequest(ip, hostinfo.vpnAddrs, d, f)
// Fallthrough to the bottom to record incoming traffic
@ -228,14 +212,8 @@ func (f *Interface) readOutsidePackets(ip netip.AddrPort, via *ViaSender, out []
Error("Failed to decrypt Control packet")
return
}
m := &NebulaControl{}
err = m.Unmarshal(d)
if err != nil {
hostinfo.logger(f.l).WithError(err).Error("Failed to unmarshal control message")
break
}
f.relayManager.HandleControlMsg(hostinfo, m, f)
f.relayManager.HandleControlMsg(hostinfo, d, f)
default:
f.messageMetrics.Rx(h.Type, h.Subtype, 1)
@ -252,8 +230,8 @@ func (f *Interface) readOutsidePackets(ip netip.AddrPort, via *ViaSender, out []
func (f *Interface) closeTunnel(hostInfo *HostInfo) {
final := f.hostMap.DeleteHostInfo(hostInfo)
if final {
// We no longer have any tunnels with this vpn ip, clear learned lighthouse state to lower memory usage
f.lightHouse.DeleteVpnIp(hostInfo.vpnIp)
// We no longer have any tunnels with this vpn addr, clear learned lighthouse state to lower memory usage
f.lightHouse.DeleteVpnAddrs(hostInfo.vpnAddrs)
}
}
@ -262,25 +240,26 @@ func (f *Interface) sendCloseTunnel(h *HostInfo) {
f.send(header.CloseTunnel, 0, h.ConnectionState, h, []byte{}, make([]byte, 12, 12), make([]byte, mtu))
}
func (f *Interface) handleHostRoaming(hostinfo *HostInfo, ip netip.AddrPort) {
if ip.IsValid() && hostinfo.remote != ip {
if !f.lightHouse.GetRemoteAllowList().Allow(hostinfo.vpnIp, ip.Addr()) {
hostinfo.logger(f.l).WithField("newAddr", ip).Debug("lighthouse.remote_allow_list denied roaming")
func (f *Interface) handleHostRoaming(hostinfo *HostInfo, vpnAddr netip.AddrPort) {
if vpnAddr.IsValid() && hostinfo.remote != vpnAddr {
//TODO: this is weird now that we can have multiple vpn addrs
if !f.lightHouse.GetRemoteAllowList().Allow(hostinfo.vpnAddrs[0], vpnAddr.Addr()) {
hostinfo.logger(f.l).WithField("newAddr", vpnAddr).Debug("lighthouse.remote_allow_list denied roaming")
return
}
if !hostinfo.lastRoam.IsZero() && ip == hostinfo.lastRoamRemote && time.Since(hostinfo.lastRoam) < RoamingSuppressSeconds*time.Second {
if !hostinfo.lastRoam.IsZero() && vpnAddr == hostinfo.lastRoamRemote && time.Since(hostinfo.lastRoam) < RoamingSuppressSeconds*time.Second {
if f.l.Level >= logrus.DebugLevel {
hostinfo.logger(f.l).WithField("udpAddr", hostinfo.remote).WithField("newAddr", ip).
hostinfo.logger(f.l).WithField("udpAddr", hostinfo.remote).WithField("newAddr", vpnAddr).
Debugf("Suppressing roam back to previous remote for %d seconds", RoamingSuppressSeconds)
}
return
}
hostinfo.logger(f.l).WithField("udpAddr", hostinfo.remote).WithField("newAddr", ip).
hostinfo.logger(f.l).WithField("udpAddr", hostinfo.remote).WithField("newAddr", vpnAddr).
Info("Host roamed to new udp ip/port.")
hostinfo.lastRoam = time.Now()
hostinfo.lastRoamRemote = hostinfo.remote
hostinfo.SetRemote(ip)
hostinfo.SetRemote(vpnAddr)
}
}
@ -302,14 +281,114 @@ func (f *Interface) handleEncrypted(ci *ConnectionState, addr netip.AddrPort, h
// newPacket validates and parses the interesting bits for the firewall out of the ip and sub protocol headers
func newPacket(data []byte, incoming bool, fp *firewall.Packet) error {
// Do we at least have an ipv4 header worth of data?
if len(data) < ipv4.HeaderLen {
return fmt.Errorf("packet is less than %v bytes", ipv4.HeaderLen)
if len(data) < 1 {
return errors.New("packet too short")
}
// Is it an ipv4 packet?
if int((data[0]>>4)&0x0f) != 4 {
return fmt.Errorf("packet is not ipv4, type: %v", int((data[0]>>4)&0x0f))
version := int((data[0] >> 4) & 0x0f)
switch version {
case ipv4.Version:
return parseV4(data, incoming, fp)
case ipv6.Version:
return parseV6(data, incoming, fp)
}
return fmt.Errorf("packet is an unknown ip version: %v", version)
}
func parseV6(data []byte, incoming bool, fp *firewall.Packet) error {
dataLen := len(data)
if dataLen < ipv6.HeaderLen {
return fmt.Errorf("ipv6 packet is less than %v bytes", ipv4.HeaderLen)
}
if incoming {
fp.RemoteAddr, _ = netip.AddrFromSlice(data[8:24])
fp.LocalAddr, _ = netip.AddrFromSlice(data[24:40])
} else {
fp.LocalAddr, _ = netip.AddrFromSlice(data[8:24])
fp.RemoteAddr, _ = netip.AddrFromSlice(data[24:40])
}
//TODO: whats a reasonable number of extension headers to attempt to parse?
//https://www.ietf.org/archive/id/draft-ietf-6man-eh-limits-00.html
protoAt := 6
offset := 40
for i := 0; i < 24; i++ {
if dataLen < offset {
break
}
proto := layers.IPProtocol(data[protoAt])
//fmt.Println(proto, protoAt)
switch proto {
case layers.IPProtocolICMPv6:
//TODO: we need a new protocol in config language "icmpv6"
fp.Protocol = uint8(proto)
fp.RemotePort = 0
fp.LocalPort = 0
fp.Fragment = false
return nil
case layers.IPProtocolTCP:
if dataLen < offset+4 {
return fmt.Errorf("ipv6 packet was too small")
}
fp.Protocol = uint8(proto)
if incoming {
fp.RemotePort = binary.BigEndian.Uint16(data[offset : offset+2])
fp.LocalPort = binary.BigEndian.Uint16(data[offset+2 : offset+4])
} else {
fp.LocalPort = binary.BigEndian.Uint16(data[offset : offset+2])
fp.RemotePort = binary.BigEndian.Uint16(data[offset+2 : offset+4])
}
fp.Fragment = false
return nil
case layers.IPProtocolUDP:
if dataLen < offset+4 {
return fmt.Errorf("ipv6 packet was too small")
}
fp.Protocol = uint8(proto)
if incoming {
fp.RemotePort = binary.BigEndian.Uint16(data[offset : offset+2])
fp.LocalPort = binary.BigEndian.Uint16(data[offset+2 : offset+4])
} else {
fp.LocalPort = binary.BigEndian.Uint16(data[offset : offset+2])
fp.RemotePort = binary.BigEndian.Uint16(data[offset+2 : offset+4])
}
fp.Fragment = false
return nil
case layers.IPProtocolIPv6Fragment:
//TODO: can we determine the protocol?
fp.RemotePort = 0
fp.LocalPort = 0
fp.Fragment = true
return nil
default:
if dataLen < offset+1 {
break
}
next := int(data[offset+1]) * 8
if next == 0 {
// each extension is at least 8 bytes
next = 8
}
protoAt = offset
offset = offset + next
}
}
return fmt.Errorf("could not find payload in ipv6 packet")
}
func parseV4(data []byte, incoming bool, fp *firewall.Packet) error {
// Do we at least have an ipv4 header worth of data?
if len(data) < ipv4.HeaderLen {
return fmt.Errorf("ipv4 packet is less than %v bytes", ipv4.HeaderLen)
}
// Adjust our start position based on the advertised ip header length
@ -317,7 +396,7 @@ func newPacket(data []byte, incoming bool, fp *firewall.Packet) error {
// Well formed ip header length?
if ihl < ipv4.HeaderLen {
return fmt.Errorf("packet had an invalid header length: %v", ihl)
return fmt.Errorf("ipv4 packet had an invalid header length: %v", ihl)
}
// Check if this is the second or further fragment of a fragmented packet.
@ -333,14 +412,13 @@ func newPacket(data []byte, incoming bool, fp *firewall.Packet) error {
minLen += minFwPacketLen
}
if len(data) < minLen {
return fmt.Errorf("packet is less than %v bytes, ip header len: %v", minLen, ihl)
return fmt.Errorf("ipv4 packet is less than %v bytes, ip header len: %v", minLen, ihl)
}
// Firewall packets are locally oriented
if incoming {
//TODO: IPV6-WORK
fp.RemoteIP, _ = netip.AddrFromSlice(data[12:16])
fp.LocalIP, _ = netip.AddrFromSlice(data[16:20])
fp.RemoteAddr, _ = netip.AddrFromSlice(data[12:16])
fp.LocalAddr, _ = netip.AddrFromSlice(data[16:20])
if fp.Fragment || fp.Protocol == firewall.ProtoICMP {
fp.RemotePort = 0
fp.LocalPort = 0
@ -349,9 +427,8 @@ func newPacket(data []byte, incoming bool, fp *firewall.Packet) error {
fp.LocalPort = binary.BigEndian.Uint16(data[ihl+2 : ihl+4])
}
} else {
//TODO: IPV6-WORK
fp.LocalIP, _ = netip.AddrFromSlice(data[12:16])
fp.RemoteIP, _ = netip.AddrFromSlice(data[16:20])
fp.LocalAddr, _ = netip.AddrFromSlice(data[12:16])
fp.RemoteAddr, _ = netip.AddrFromSlice(data[16:20])
if fp.Fragment || fp.Protocol == firewall.ProtoICMP {
fp.RemotePort = 0
fp.LocalPort = 0
@ -492,27 +569,3 @@ func (f *Interface) sendMeta(ci *ConnectionState, endpoint *net.UDPAddr, meta *N
f.outside.WriteTo(msg, endpoint)
}
*/
func RecombineCertAndValidate(h *noise.HandshakeState, rawCertBytes []byte, caPool *cert.CAPool) (*cert.CachedCertificate, error) {
pk := h.PeerStatic()
if pk == nil {
return nil, errors.New("no peer static key was present")
}
if rawCertBytes == nil {
return nil, errors.New("provided payload was empty")
}
c, err := cert.UnmarshalCertificateFromHandshake(rawCertBytes, pk)
if err != nil {
return nil, fmt.Errorf("error unmarshaling cert: %w", err)
}
cc, err := caPool.VerifyCertificate(time.Now(), c)
if err != nil {
return nil, fmt.Errorf("certificate validation failed: %w", err)
}
return cc, nil
}

View File

@ -5,6 +5,9 @@ import (
"net/netip"
"testing"
"github.com/google/gopacket"
"github.com/google/gopacket/layers"
"github.com/slackhq/nebula/firewall"
"github.com/stretchr/testify/assert"
"golang.org/x/net/ipv4"
@ -13,9 +16,15 @@ import (
func Test_newPacket(t *testing.T) {
p := &firewall.Packet{}
// length fail
err := newPacket([]byte{0, 1}, true, p)
assert.EqualError(t, err, "packet is less than 20 bytes")
// length fails
err := newPacket([]byte{}, true, p)
assert.EqualError(t, err, "packet too short")
err = newPacket([]byte{0x40}, true, p)
assert.EqualError(t, err, "ipv4 packet is less than 20 bytes")
err = newPacket([]byte{0x60}, true, p)
assert.EqualError(t, err, "ipv6 packet is less than 20 bytes")
// length fail with ip options
h := ipv4.Header{
@ -29,15 +38,15 @@ func Test_newPacket(t *testing.T) {
b, _ := h.Marshal()
err = newPacket(b, true, p)
assert.EqualError(t, err, "packet is less than 28 bytes, ip header len: 24")
assert.EqualError(t, err, "ipv4 packet is less than 28 bytes, ip header len: 24")
// not an ipv4 packet
err = newPacket([]byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}, true, p)
assert.EqualError(t, err, "packet is not ipv4, type: 0")
assert.EqualError(t, err, "packet is an unknown ip version: 0")
// invalid ihl
err = newPacket([]byte{4<<4 | (8 >> 2 & 0x0f), 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}, true, p)
assert.EqualError(t, err, "packet had an invalid header length: 8")
assert.EqualError(t, err, "ipv4 packet had an invalid header length: 8")
// account for variable ip header length - incoming
h = ipv4.Header{
@ -55,8 +64,8 @@ func Test_newPacket(t *testing.T) {
assert.Nil(t, err)
assert.Equal(t, p.Protocol, uint8(firewall.ProtoTCP))
assert.Equal(t, p.LocalIP, netip.MustParseAddr("10.0.0.2"))
assert.Equal(t, p.RemoteIP, netip.MustParseAddr("10.0.0.1"))
assert.Equal(t, p.LocalAddr, netip.MustParseAddr("10.0.0.2"))
assert.Equal(t, p.RemoteAddr, netip.MustParseAddr("10.0.0.1"))
assert.Equal(t, p.RemotePort, uint16(3))
assert.Equal(t, p.LocalPort, uint16(4))
@ -76,8 +85,60 @@ func Test_newPacket(t *testing.T) {
assert.Nil(t, err)
assert.Equal(t, p.Protocol, uint8(2))
assert.Equal(t, p.LocalIP, netip.MustParseAddr("10.0.0.1"))
assert.Equal(t, p.RemoteIP, netip.MustParseAddr("10.0.0.2"))
assert.Equal(t, p.LocalAddr, netip.MustParseAddr("10.0.0.1"))
assert.Equal(t, p.RemoteAddr, netip.MustParseAddr("10.0.0.2"))
assert.Equal(t, p.RemotePort, uint16(6))
assert.Equal(t, p.LocalPort, uint16(5))
}
func Test_newPacket_v6(t *testing.T) {
p := &firewall.Packet{}
ip := layers.IPv6{
Version: 6,
NextHeader: firewall.ProtoUDP,
HopLimit: 128,
SrcIP: net.IPv6linklocalallrouters,
DstIP: net.IPv6linklocalallnodes,
}
udp := layers.UDP{
SrcPort: layers.UDPPort(36123),
DstPort: layers.UDPPort(22),
}
err := udp.SetNetworkLayerForChecksum(&ip)
if err != nil {
panic(err)
}
buffer := gopacket.NewSerializeBuffer()
opt := gopacket.SerializeOptions{
ComputeChecksums: true,
FixLengths: true,
}
err = gopacket.SerializeLayers(buffer, opt, &ip, &udp, gopacket.Payload([]byte{0xde, 0xad, 0xbe, 0xef}))
if err != nil {
panic(err)
}
b := buffer.Bytes()
//test incoming
err = newPacket(b, true, p)
assert.Nil(t, err)
assert.Equal(t, p.Protocol, uint8(firewall.ProtoUDP))
assert.Equal(t, p.RemoteAddr, netip.MustParseAddr("ff02::2"))
assert.Equal(t, p.LocalAddr, netip.MustParseAddr("ff02::1"))
assert.Equal(t, p.RemotePort, uint16(36123))
assert.Equal(t, p.LocalPort, uint16(22))
//test outgoing
err = newPacket(b, false, p)
assert.Nil(t, err)
assert.Equal(t, p.Protocol, uint8(firewall.ProtoUDP))
assert.Equal(t, p.LocalAddr, netip.MustParseAddr("ff02::2"))
assert.Equal(t, p.RemoteAddr, netip.MustParseAddr("ff02::1"))
assert.Equal(t, p.LocalPort, uint16(36123))
assert.Equal(t, p.RemotePort, uint16(22))
}

View File

@ -8,7 +8,7 @@ import (
type Device interface {
io.ReadWriteCloser
Activate() error
Cidr() netip.Prefix
Networks() []netip.Prefix
Name() string
RouteFor(netip.Addr) netip.Addr
NewMultiQueueReader() (io.ReadWriteCloser, error)

View File

@ -61,7 +61,7 @@ func makeRouteTree(l *logrus.Logger, routes []Route, allowMTU bool) (*bart.Table
return routeTree, nil
}
func parseRoutes(c *config.C, network netip.Prefix) ([]Route, error) {
func parseRoutes(c *config.C, networks []netip.Prefix) ([]Route, error) {
var err error
r := c.Get("tun.routes")
@ -117,12 +117,20 @@ func parseRoutes(c *config.C, network netip.Prefix) ([]Route, error) {
return nil, fmt.Errorf("entry %v.route in tun.routes failed to parse: %v", i+1, err)
}
if !network.Contains(r.Cidr.Addr()) || r.Cidr.Bits() < network.Bits() {
found := false
for _, network := range networks {
if network.Contains(r.Cidr.Addr()) && r.Cidr.Bits() >= network.Bits() {
found = true
break
}
}
if !found {
return nil, fmt.Errorf(
"entry %v.route in tun.routes is not contained within the network attached to the certificate; route: %v, network: %v",
"entry %v.route in tun.routes is not contained within the configured vpn networks; route: %v, networks: %v",
i+1,
r.Cidr.String(),
network.String(),
networks,
)
}
@ -132,7 +140,7 @@ func parseRoutes(c *config.C, network netip.Prefix) ([]Route, error) {
return routes, nil
}
func parseUnsafeRoutes(c *config.C, network netip.Prefix) ([]Route, error) {
func parseUnsafeRoutes(c *config.C, networks []netip.Prefix) ([]Route, error) {
var err error
r := c.Get("tun.unsafe_routes")
@ -229,13 +237,15 @@ func parseUnsafeRoutes(c *config.C, network netip.Prefix) ([]Route, error) {
return nil, fmt.Errorf("entry %v.route in tun.unsafe_routes failed to parse: %v", i+1, err)
}
if network.Contains(r.Cidr.Addr()) {
return nil, fmt.Errorf(
"entry %v.route in tun.unsafe_routes is contained within the network attached to the certificate; route: %v, network: %v",
i+1,
r.Cidr.String(),
network.String(),
)
for _, network := range networks {
if network.Contains(r.Cidr.Addr()) {
return nil, fmt.Errorf(
"entry %v.route in tun.unsafe_routes is contained within the configured vpn networks; route: %v, network: %v",
i+1,
r.Cidr.String(),
network.String(),
)
}
}
routes[i] = r

View File

@ -17,76 +17,82 @@ func Test_parseRoutes(t *testing.T) {
assert.NoError(t, err)
// test no routes config
routes, err := parseRoutes(c, n)
routes, err := parseRoutes(c, []netip.Prefix{n})
assert.Nil(t, err)
assert.Len(t, routes, 0)
// not an array
c.Settings["tun"] = map[interface{}]interface{}{"routes": "hi"}
routes, err = parseRoutes(c, n)
routes, err = parseRoutes(c, []netip.Prefix{n})
assert.Nil(t, routes)
assert.EqualError(t, err, "tun.routes is not an array")
// no routes
c.Settings["tun"] = map[interface{}]interface{}{"routes": []interface{}{}}
routes, err = parseRoutes(c, n)
routes, err = parseRoutes(c, []netip.Prefix{n})
assert.Nil(t, err)
assert.Len(t, routes, 0)
// weird route
c.Settings["tun"] = map[interface{}]interface{}{"routes": []interface{}{"asdf"}}
routes, err = parseRoutes(c, n)
routes, err = parseRoutes(c, []netip.Prefix{n})
assert.Nil(t, routes)
assert.EqualError(t, err, "entry 1 in tun.routes is invalid")
// no mtu
c.Settings["tun"] = map[interface{}]interface{}{"routes": []interface{}{map[interface{}]interface{}{}}}
routes, err = parseRoutes(c, n)
routes, err = parseRoutes(c, []netip.Prefix{n})
assert.Nil(t, routes)
assert.EqualError(t, err, "entry 1.mtu in tun.routes is not present")
// bad mtu
c.Settings["tun"] = map[interface{}]interface{}{"routes": []interface{}{map[interface{}]interface{}{"mtu": "nope"}}}
routes, err = parseRoutes(c, n)
routes, err = parseRoutes(c, []netip.Prefix{n})
assert.Nil(t, routes)
assert.EqualError(t, err, "entry 1.mtu in tun.routes is not an integer: strconv.Atoi: parsing \"nope\": invalid syntax")
// low mtu
c.Settings["tun"] = map[interface{}]interface{}{"routes": []interface{}{map[interface{}]interface{}{"mtu": "499"}}}
routes, err = parseRoutes(c, n)
routes, err = parseRoutes(c, []netip.Prefix{n})
assert.Nil(t, routes)
assert.EqualError(t, err, "entry 1.mtu in tun.routes is below 500: 499")
// missing route
c.Settings["tun"] = map[interface{}]interface{}{"routes": []interface{}{map[interface{}]interface{}{"mtu": "500"}}}
routes, err = parseRoutes(c, n)
routes, err = parseRoutes(c, []netip.Prefix{n})
assert.Nil(t, routes)
assert.EqualError(t, err, "entry 1.route in tun.routes is not present")
// unparsable route
c.Settings["tun"] = map[interface{}]interface{}{"routes": []interface{}{map[interface{}]interface{}{"mtu": "500", "route": "nope"}}}
routes, err = parseRoutes(c, n)
routes, err = parseRoutes(c, []netip.Prefix{n})
assert.Nil(t, routes)
assert.EqualError(t, err, "entry 1.route in tun.routes failed to parse: netip.ParsePrefix(\"nope\"): no '/'")
// below network range
c.Settings["tun"] = map[interface{}]interface{}{"routes": []interface{}{map[interface{}]interface{}{"mtu": "500", "route": "1.0.0.0/8"}}}
routes, err = parseRoutes(c, n)
routes, err = parseRoutes(c, []netip.Prefix{n})
assert.Nil(t, routes)
assert.EqualError(t, err, "entry 1.route in tun.routes is not contained within the network attached to the certificate; route: 1.0.0.0/8, network: 10.0.0.0/24")
assert.EqualError(t, err, "entry 1.route in tun.routes is not contained within the configured vpn networks; route: 1.0.0.0/8, networks: [10.0.0.0/24]")
// above network range
c.Settings["tun"] = map[interface{}]interface{}{"routes": []interface{}{map[interface{}]interface{}{"mtu": "500", "route": "10.0.1.0/24"}}}
routes, err = parseRoutes(c, n)
routes, err = parseRoutes(c, []netip.Prefix{n})
assert.Nil(t, routes)
assert.EqualError(t, err, "entry 1.route in tun.routes is not contained within the network attached to the certificate; route: 10.0.1.0/24, network: 10.0.0.0/24")
assert.EqualError(t, err, "entry 1.route in tun.routes is not contained within the configured vpn networks; route: 10.0.1.0/24, networks: [10.0.0.0/24]")
// Not in multiple ranges
c.Settings["tun"] = map[interface{}]interface{}{"routes": []interface{}{map[interface{}]interface{}{"mtu": "500", "route": "192.0.0.0/24"}}}
routes, err = parseRoutes(c, []netip.Prefix{n, netip.MustParsePrefix("192.1.0.0/24")})
assert.Nil(t, routes)
assert.EqualError(t, err, "entry 1.route in tun.routes is not contained within the configured vpn networks; route: 192.0.0.0/24, networks: [10.0.0.0/24 192.1.0.0/24]")
// happy case
c.Settings["tun"] = map[interface{}]interface{}{"routes": []interface{}{
map[interface{}]interface{}{"mtu": "9000", "route": "10.0.0.0/29"},
map[interface{}]interface{}{"mtu": "8000", "route": "10.0.0.1/32"},
}}
routes, err = parseRoutes(c, n)
routes, err = parseRoutes(c, []netip.Prefix{n})
assert.Nil(t, err)
assert.Len(t, routes, 2)
@ -116,31 +122,31 @@ func Test_parseUnsafeRoutes(t *testing.T) {
assert.NoError(t, err)
// test no routes config
routes, err := parseUnsafeRoutes(c, n)
routes, err := parseUnsafeRoutes(c, []netip.Prefix{n})
assert.Nil(t, err)
assert.Len(t, routes, 0)
// not an array
c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": "hi"}
routes, err = parseUnsafeRoutes(c, n)
routes, err = parseUnsafeRoutes(c, []netip.Prefix{n})
assert.Nil(t, routes)
assert.EqualError(t, err, "tun.unsafe_routes is not an array")
// no routes
c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{}}
routes, err = parseUnsafeRoutes(c, n)
routes, err = parseUnsafeRoutes(c, []netip.Prefix{n})
assert.Nil(t, err)
assert.Len(t, routes, 0)
// weird route
c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{"asdf"}}
routes, err = parseUnsafeRoutes(c, n)
routes, err = parseUnsafeRoutes(c, []netip.Prefix{n})
assert.Nil(t, routes)
assert.EqualError(t, err, "entry 1 in tun.unsafe_routes is invalid")
// no via
c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{}}}
routes, err = parseUnsafeRoutes(c, n)
routes, err = parseUnsafeRoutes(c, []netip.Prefix{n})
assert.Nil(t, routes)
assert.EqualError(t, err, "entry 1.via in tun.unsafe_routes is not present")
@ -149,68 +155,68 @@ func Test_parseUnsafeRoutes(t *testing.T) {
127, false, nil, 1.0, []string{"1", "2"},
} {
c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{"via": invalidValue}}}
routes, err = parseUnsafeRoutes(c, n)
routes, err = parseUnsafeRoutes(c, []netip.Prefix{n})
assert.Nil(t, routes)
assert.EqualError(t, err, fmt.Sprintf("entry 1.via in tun.unsafe_routes is not a string: found %T", invalidValue))
}
// unparsable via
c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{"mtu": "500", "via": "nope"}}}
routes, err = parseUnsafeRoutes(c, n)
routes, err = parseUnsafeRoutes(c, []netip.Prefix{n})
assert.Nil(t, routes)
assert.EqualError(t, err, "entry 1.via in tun.unsafe_routes failed to parse address: ParseAddr(\"nope\"): unable to parse IP")
// missing route
c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{"via": "127.0.0.1", "mtu": "500"}}}
routes, err = parseUnsafeRoutes(c, n)
routes, err = parseUnsafeRoutes(c, []netip.Prefix{n})
assert.Nil(t, routes)
assert.EqualError(t, err, "entry 1.route in tun.unsafe_routes is not present")
// unparsable route
c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{"via": "127.0.0.1", "mtu": "500", "route": "nope"}}}
routes, err = parseUnsafeRoutes(c, n)
routes, err = parseUnsafeRoutes(c, []netip.Prefix{n})
assert.Nil(t, routes)
assert.EqualError(t, err, "entry 1.route in tun.unsafe_routes failed to parse: netip.ParsePrefix(\"nope\"): no '/'")
// within network range
c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{"via": "127.0.0.1", "route": "10.0.0.0/24"}}}
routes, err = parseUnsafeRoutes(c, n)
routes, err = parseUnsafeRoutes(c, []netip.Prefix{n})
assert.Nil(t, routes)
assert.EqualError(t, err, "entry 1.route in tun.unsafe_routes is contained within the network attached to the certificate; route: 10.0.0.0/24, network: 10.0.0.0/24")
assert.EqualError(t, err, "entry 1.route in tun.unsafe_routes is contained within the configured vpn networks; route: 10.0.0.0/24, network: 10.0.0.0/24")
// below network range
c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{"via": "127.0.0.1", "route": "1.0.0.0/8"}}}
routes, err = parseUnsafeRoutes(c, n)
routes, err = parseUnsafeRoutes(c, []netip.Prefix{n})
assert.Len(t, routes, 1)
assert.Nil(t, err)
// above network range
c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{"via": "127.0.0.1", "route": "10.0.1.0/24"}}}
routes, err = parseUnsafeRoutes(c, n)
routes, err = parseUnsafeRoutes(c, []netip.Prefix{n})
assert.Len(t, routes, 1)
assert.Nil(t, err)
// no mtu
c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{"via": "127.0.0.1", "route": "1.0.0.0/8"}}}
routes, err = parseUnsafeRoutes(c, n)
routes, err = parseUnsafeRoutes(c, []netip.Prefix{n})
assert.Len(t, routes, 1)
assert.Equal(t, 0, routes[0].MTU)
// bad mtu
c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{"via": "127.0.0.1", "mtu": "nope"}}}
routes, err = parseUnsafeRoutes(c, n)
routes, err = parseUnsafeRoutes(c, []netip.Prefix{n})
assert.Nil(t, routes)
assert.EqualError(t, err, "entry 1.mtu in tun.unsafe_routes is not an integer: strconv.Atoi: parsing \"nope\": invalid syntax")
// low mtu
c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{"via": "127.0.0.1", "mtu": "499"}}}
routes, err = parseUnsafeRoutes(c, n)
routes, err = parseUnsafeRoutes(c, []netip.Prefix{n})
assert.Nil(t, routes)
assert.EqualError(t, err, "entry 1.mtu in tun.unsafe_routes is below 500: 499")
// bad install
c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{"via": "127.0.0.1", "mtu": "9000", "route": "1.0.0.0/29", "install": "nope"}}}
routes, err = parseUnsafeRoutes(c, n)
routes, err = parseUnsafeRoutes(c, []netip.Prefix{n})
assert.Nil(t, routes)
assert.EqualError(t, err, "entry 1.install in tun.unsafe_routes is not a boolean: strconv.ParseBool: parsing \"nope\": invalid syntax")
@ -221,7 +227,7 @@ func Test_parseUnsafeRoutes(t *testing.T) {
map[interface{}]interface{}{"via": "127.0.0.1", "mtu": "1500", "metric": 1234, "route": "1.0.0.2/32", "install": 1},
map[interface{}]interface{}{"via": "127.0.0.1", "mtu": "1500", "metric": 1234, "route": "1.0.0.2/32"},
}}
routes, err = parseUnsafeRoutes(c, n)
routes, err = parseUnsafeRoutes(c, []netip.Prefix{n})
assert.Nil(t, err)
assert.Len(t, routes, 4)
@ -260,7 +266,7 @@ func Test_makeRouteTree(t *testing.T) {
map[interface{}]interface{}{"via": "192.168.0.1", "route": "1.0.0.0/28"},
map[interface{}]interface{}{"via": "192.168.0.2", "route": "1.0.0.1/32"},
}}
routes, err := parseUnsafeRoutes(c, n)
routes, err := parseUnsafeRoutes(c, []netip.Prefix{n})
assert.NoError(t, err)
assert.Len(t, routes, 2)
routeTree, err := makeRouteTree(l, routes, true)

View File

@ -11,36 +11,36 @@ import (
const DefaultMTU = 1300
// TODO: We may be able to remove routines
type DeviceFactory func(c *config.C, l *logrus.Logger, tunCidr netip.Prefix, routines int) (Device, error)
type DeviceFactory func(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, routines int) (Device, error)
func NewDeviceFromConfig(c *config.C, l *logrus.Logger, tunCidr netip.Prefix, routines int) (Device, error) {
func NewDeviceFromConfig(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, routines int) (Device, error) {
switch {
case c.GetBool("tun.disabled", false):
tun := newDisabledTun(tunCidr, c.GetInt("tun.tx_queue", 500), c.GetBool("stats.message_metrics", false), l)
tun := newDisabledTun(vpnNetworks, c.GetInt("tun.tx_queue", 500), c.GetBool("stats.message_metrics", false), l)
return tun, nil
default:
return newTun(c, l, tunCidr, routines > 1)
return newTun(c, l, vpnNetworks, routines > 1)
}
}
func NewFdDeviceFromConfig(fd *int) DeviceFactory {
return func(c *config.C, l *logrus.Logger, tunCidr netip.Prefix, routines int) (Device, error) {
return newTunFromFd(c, l, *fd, tunCidr)
return func(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, routines int) (Device, error) {
return newTunFromFd(c, l, *fd, vpnNetworks)
}
}
func getAllRoutesFromConfig(c *config.C, cidr netip.Prefix, initial bool) (bool, []Route, error) {
func getAllRoutesFromConfig(c *config.C, vpnNetworks []netip.Prefix, initial bool) (bool, []Route, error) {
if !initial && !c.HasChanged("tun.routes") && !c.HasChanged("tun.unsafe_routes") {
return false, nil, nil
}
routes, err := parseRoutes(c, cidr)
routes, err := parseRoutes(c, vpnNetworks)
if err != nil {
return true, nil, util.NewContextualError("Could not parse tun.routes", nil, err)
}
unsafeRoutes, err := parseUnsafeRoutes(c, cidr)
unsafeRoutes, err := parseUnsafeRoutes(c, vpnNetworks)
if err != nil {
return true, nil, util.NewContextualError("Could not parse tun.unsafe_routes", nil, err)
}

View File

@ -18,14 +18,14 @@ import (
type tun struct {
io.ReadWriteCloser
fd int
cidr netip.Prefix
Routes atomic.Pointer[[]Route]
routeTree atomic.Pointer[bart.Table[netip.Addr]]
l *logrus.Logger
fd int
vpnNetworks []netip.Prefix
Routes atomic.Pointer[[]Route]
routeTree atomic.Pointer[bart.Table[netip.Addr]]
l *logrus.Logger
}
func newTunFromFd(c *config.C, l *logrus.Logger, deviceFd int, cidr netip.Prefix) (*tun, error) {
func newTunFromFd(c *config.C, l *logrus.Logger, deviceFd int, vpnNetworks []netip.Prefix) (*tun, error) {
// XXX Android returns an fd in non-blocking mode which is necessary for shutdown to work properly.
// Be sure not to call file.Fd() as it will set the fd to blocking mode.
file := os.NewFile(uintptr(deviceFd), "/dev/net/tun")
@ -33,7 +33,7 @@ func newTunFromFd(c *config.C, l *logrus.Logger, deviceFd int, cidr netip.Prefix
t := &tun{
ReadWriteCloser: file,
fd: deviceFd,
cidr: cidr,
vpnNetworks: vpnNetworks,
l: l,
}
@ -52,7 +52,7 @@ func newTunFromFd(c *config.C, l *logrus.Logger, deviceFd int, cidr netip.Prefix
return t, nil
}
func newTun(_ *config.C, _ *logrus.Logger, _ netip.Prefix, _ bool) (*tun, error) {
func newTun(_ *config.C, _ *logrus.Logger, _ []netip.Prefix, _ bool) (*tun, error) {
return nil, fmt.Errorf("newTun not supported in Android")
}
@ -66,7 +66,7 @@ func (t tun) Activate() error {
}
func (t *tun) reload(c *config.C, initial bool) error {
change, routes, err := getAllRoutesFromConfig(c, t.cidr, initial)
change, routes, err := getAllRoutesFromConfig(c, t.vpnNetworks, initial)
if err != nil {
return err
}
@ -86,8 +86,8 @@ func (t *tun) reload(c *config.C, initial bool) error {
return nil
}
func (t *tun) Cidr() netip.Prefix {
return t.cidr
func (t *tun) Networks() []netip.Prefix {
return t.vpnNetworks
}
func (t *tun) Name() string {

View File

@ -24,56 +24,62 @@ import (
type tun struct {
io.ReadWriteCloser
Device string
cidr netip.Prefix
DefaultMTU int
Routes atomic.Pointer[[]Route]
routeTree atomic.Pointer[bart.Table[netip.Addr]]
linkAddr *netroute.LinkAddr
l *logrus.Logger
Device string
vpnNetworks []netip.Prefix
DefaultMTU int
Routes atomic.Pointer[[]Route]
routeTree atomic.Pointer[bart.Table[netip.Addr]]
linkAddr *netroute.LinkAddr
l *logrus.Logger
// cache out buffer since we need to prepend 4 bytes for tun metadata
out []byte
}
type sockaddrCtl struct {
scLen uint8
scFamily uint8
ssSysaddr uint16
scID uint32
scUnit uint32
scReserved [5]uint32
}
type ifReq struct {
Name [16]byte
Name [unix.IFNAMSIZ]byte
Flags uint16
pad [8]byte
}
var sockaddrCtlSize uintptr = 32
const (
_SYSPROTO_CONTROL = 2 //define SYSPROTO_CONTROL 2 /* kernel control protocol */
_AF_SYS_CONTROL = 2 //#define AF_SYS_CONTROL 2 /* corresponding sub address type */
_PF_SYSTEM = unix.AF_SYSTEM //#define PF_SYSTEM AF_SYSTEM
_CTLIOCGINFO = 3227799043 //#define CTLIOCGINFO _IOWR('N', 3, struct ctl_info)
utunControlName = "com.apple.net.utun_control"
_SIOCAIFADDR_IN6 = 2155899162
_UTUN_OPT_IFNAME = 2
_IN6_IFF_NODAD = 0x0020
_IN6_IFF_SECURED = 0x0400
utunControlName = "com.apple.net.utun_control"
)
type ifreqAddr struct {
Name [16]byte
Addr unix.RawSockaddrInet4
pad [8]byte
}
type ifreqMTU struct {
Name [16]byte
MTU int32
pad [8]byte
}
func newTun(c *config.C, l *logrus.Logger, cidr netip.Prefix, _ bool) (*tun, error) {
type addrLifetime struct {
Expire float64
Preferred float64
Vltime uint32
Pltime uint32
}
type ifreqAlias4 struct {
Name [unix.IFNAMSIZ]byte
Addr unix.RawSockaddrInet4
DstAddr unix.RawSockaddrInet4
MaskAddr unix.RawSockaddrInet4
}
type ifreqAlias6 struct {
Name [unix.IFNAMSIZ]byte
Addr unix.RawSockaddrInet6
DstAddr unix.RawSockaddrInet6
PrefixMask unix.RawSockaddrInet6
Flags uint32
Lifetime addrLifetime
}
func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, _ bool) (*tun, error) {
name := c.GetString("tun.dev", "")
ifIndex := -1
if name != "" && name != "utun" {
@ -86,66 +92,41 @@ func newTun(c *config.C, l *logrus.Logger, cidr netip.Prefix, _ bool) (*tun, err
}
}
fd, err := unix.Socket(_PF_SYSTEM, unix.SOCK_DGRAM, _SYSPROTO_CONTROL)
fd, err := unix.Socket(unix.AF_SYSTEM, unix.SOCK_DGRAM, unix.AF_SYS_CONTROL)
if err != nil {
return nil, fmt.Errorf("system socket: %v", err)
}
var ctlInfo = &struct {
ctlID uint32
ctlName [96]byte
}{}
var ctlInfo = &unix.CtlInfo{}
copy(ctlInfo.Name[:], utunControlName)
copy(ctlInfo.ctlName[:], utunControlName)
err = ioctl(uintptr(fd), uintptr(_CTLIOCGINFO), uintptr(unsafe.Pointer(ctlInfo)))
err = unix.IoctlCtlInfo(fd, ctlInfo)
if err != nil {
return nil, fmt.Errorf("CTLIOCGINFO: %v", err)
}
sc := sockaddrCtl{
scLen: uint8(sockaddrCtlSize),
scFamily: unix.AF_SYSTEM,
ssSysaddr: _AF_SYS_CONTROL,
scID: ctlInfo.ctlID,
scUnit: uint32(ifIndex) + 1,
err = unix.Connect(fd, &unix.SockaddrCtl{
ID: ctlInfo.Id,
Unit: uint32(ifIndex) + 1,
})
if err != nil {
return nil, fmt.Errorf("SYS_CONNECT: %v", err)
}
_, _, errno := unix.RawSyscall(
unix.SYS_CONNECT,
uintptr(fd),
uintptr(unsafe.Pointer(&sc)),
sockaddrCtlSize,
)
if errno != 0 {
return nil, fmt.Errorf("SYS_CONNECT: %v", errno)
name, err = unix.GetsockoptString(fd, unix.AF_SYS_CONTROL, _UTUN_OPT_IFNAME)
if err != nil {
return nil, fmt.Errorf("failed to retrieve tun name: %w", err)
}
var ifName struct {
name [16]byte
}
ifNameSize := uintptr(len(ifName.name))
_, _, errno = syscall.Syscall6(syscall.SYS_GETSOCKOPT, uintptr(fd),
2, // SYSPROTO_CONTROL
2, // UTUN_OPT_IFNAME
uintptr(unsafe.Pointer(&ifName)),
uintptr(unsafe.Pointer(&ifNameSize)), 0)
if errno != 0 {
return nil, fmt.Errorf("SYS_GETSOCKOPT: %v", errno)
}
name = string(ifName.name[:ifNameSize-1])
err = syscall.SetNonblock(fd, true)
err = unix.SetNonblock(fd, true)
if err != nil {
return nil, fmt.Errorf("SetNonblock: %v", err)
}
file := os.NewFile(uintptr(fd), "")
t := &tun{
ReadWriteCloser: file,
ReadWriteCloser: os.NewFile(uintptr(fd), ""),
Device: name,
cidr: cidr,
vpnNetworks: vpnNetworks,
DefaultMTU: c.GetInt("tun.mtu", DefaultMTU),
l: l,
}
@ -172,7 +153,7 @@ func (t *tun) deviceBytes() (o [16]byte) {
return
}
func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ netip.Prefix) (*tun, error) {
func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ []netip.Prefix) (*tun, error) {
return nil, fmt.Errorf("newTunFromFd not supported in Darwin")
}
@ -186,16 +167,6 @@ func (t *tun) Close() error {
func (t *tun) Activate() error {
devName := t.deviceBytes()
var addr, mask [4]byte
if !t.cidr.Addr().Is4() {
//TODO: IPV6-WORK
panic("need ipv6")
}
addr = t.cidr.Addr().As4()
copy(mask[:], prefixToMask(t.cidr))
s, err := unix.Socket(
unix.AF_INET,
unix.SOCK_DGRAM,
@ -208,66 +179,18 @@ func (t *tun) Activate() error {
fd := uintptr(s)
ifra := ifreqAddr{
Name: devName,
Addr: unix.RawSockaddrInet4{
Family: unix.AF_INET,
Addr: addr,
},
}
// Set the device ip address
if err = ioctl(fd, unix.SIOCSIFADDR, uintptr(unsafe.Pointer(&ifra))); err != nil {
return fmt.Errorf("failed to set tun address: %s", err)
}
// Set the device network
ifra.Addr.Addr = mask
if err = ioctl(fd, unix.SIOCSIFNETMASK, uintptr(unsafe.Pointer(&ifra))); err != nil {
return fmt.Errorf("failed to set tun netmask: %s", err)
}
// Set the device name
ifrf := ifReq{Name: devName}
if err = ioctl(fd, unix.SIOCGIFFLAGS, uintptr(unsafe.Pointer(&ifrf))); err != nil {
return fmt.Errorf("failed to set tun device name: %s", err)
}
// Set the MTU on the device
ifm := ifreqMTU{Name: devName, MTU: int32(t.DefaultMTU)}
if err = ioctl(fd, unix.SIOCSIFMTU, uintptr(unsafe.Pointer(&ifm))); err != nil {
return fmt.Errorf("failed to set tun mtu: %v", err)
}
/*
// Set the transmit queue length
ifrq := ifreqQLEN{Name: devName, Value: int32(t.TXQueueLen)}
if err = ioctl(fd, unix.SIOCSIFTXQLEN, uintptr(unsafe.Pointer(&ifrq))); err != nil {
// If we can't set the queue length nebula will still work but it may lead to packet loss
l.WithError(err).Error("Failed to set tun tx queue length")
}
*/
// Bring up the interface
ifrf.Flags = ifrf.Flags | unix.IFF_UP
if err = ioctl(fd, unix.SIOCSIFFLAGS, uintptr(unsafe.Pointer(&ifrf))); err != nil {
return fmt.Errorf("failed to bring the tun device up: %s", err)
// Get the device flags
ifrf := ifReq{Name: devName}
if err = ioctl(fd, unix.SIOCGIFFLAGS, uintptr(unsafe.Pointer(&ifrf))); err != nil {
return fmt.Errorf("failed to get tun flags: %s", err)
}
routeSock, err := unix.Socket(unix.AF_ROUTE, unix.SOCK_RAW, unix.AF_UNSPEC)
if err != nil {
return fmt.Errorf("unable to create AF_ROUTE socket: %v", err)
}
defer func() {
unix.Shutdown(routeSock, unix.SHUT_RDWR)
err := unix.Close(routeSock)
if err != nil {
t.l.WithError(err).Error("failed to close AF_ROUTE socket")
}
}()
routeAddr := &netroute.Inet4Addr{}
maskAddr := &netroute.Inet4Addr{}
linkAddr, err := getLinkAddr(t.Device)
if err != nil {
return err
@ -277,14 +200,18 @@ func (t *tun) Activate() error {
}
t.linkAddr = linkAddr
copy(routeAddr.IP[:], addr[:])
copy(maskAddr.IP[:], mask[:])
err = addRoute(routeSock, routeAddr, maskAddr, linkAddr)
if err != nil {
if errors.Is(err, unix.EEXIST) {
err = fmt.Errorf("unable to add tun route, identical route already exists: %s", t.cidr)
for _, network := range t.vpnNetworks {
if network.Addr().Is4() {
err = t.activate4(network)
if err != nil {
return err
}
} else {
err = t.activate6(network)
if err != nil {
return err
}
}
return err
}
// Run the interface
@ -297,8 +224,89 @@ func (t *tun) Activate() error {
return t.addRoutes(false)
}
func (t *tun) activate4(network netip.Prefix) error {
s, err := unix.Socket(
unix.AF_INET,
unix.SOCK_DGRAM,
unix.IPPROTO_IP,
)
if err != nil {
return err
}
defer unix.Close(s)
ifr := ifreqAlias4{
Name: t.deviceBytes(),
Addr: unix.RawSockaddrInet4{
Len: unix.SizeofSockaddrInet4,
Family: unix.AF_INET,
Addr: network.Addr().As4(),
},
DstAddr: unix.RawSockaddrInet4{
Len: unix.SizeofSockaddrInet4,
Family: unix.AF_INET,
Addr: network.Addr().As4(),
},
MaskAddr: unix.RawSockaddrInet4{
Len: unix.SizeofSockaddrInet4,
Family: unix.AF_INET,
Addr: prefixToMask(network).As4(),
},
}
if err := ioctl(uintptr(s), unix.SIOCAIFADDR, uintptr(unsafe.Pointer(&ifr))); err != nil {
return fmt.Errorf("failed to set tun v4 address: %s", err)
}
err = addRoute(network, t.linkAddr)
if err != nil {
return err
}
return nil
}
func (t *tun) activate6(network netip.Prefix) error {
s, err := unix.Socket(
unix.AF_INET6,
unix.SOCK_DGRAM,
unix.IPPROTO_IP,
)
if err != nil {
return err
}
defer unix.Close(s)
ifr := ifreqAlias6{
Name: t.deviceBytes(),
Addr: unix.RawSockaddrInet6{
Len: unix.SizeofSockaddrInet6,
Family: unix.AF_INET6,
Addr: network.Addr().As16(),
},
PrefixMask: unix.RawSockaddrInet6{
Len: unix.SizeofSockaddrInet6,
Family: unix.AF_INET6,
Addr: prefixToMask(network).As16(),
},
Lifetime: addrLifetime{
// never expires
Vltime: 0xffffffff,
Pltime: 0xffffffff,
},
//TODO: should we disable DAD (duplicate address detection) and mark this as a secured address?
Flags: _IN6_IFF_NODAD,
}
if err := ioctl(uintptr(s), _SIOCAIFADDR_IN6, uintptr(unsafe.Pointer(&ifr))); err != nil {
return fmt.Errorf("failed to set tun address: %s", err)
}
return nil
}
func (t *tun) reload(c *config.C, initial bool) error {
change, routes, err := getAllRoutesFromConfig(c, t.cidr, initial)
change, routes, err := getAllRoutesFromConfig(c, t.vpnNetworks, initial)
if err != nil {
return err
}
@ -371,38 +379,15 @@ func getLinkAddr(name string) (*netroute.LinkAddr, error) {
}
func (t *tun) addRoutes(logErrors bool) error {
routeSock, err := unix.Socket(unix.AF_ROUTE, unix.SOCK_RAW, unix.AF_UNSPEC)
if err != nil {
return fmt.Errorf("unable to create AF_ROUTE socket: %v", err)
}
defer func() {
unix.Shutdown(routeSock, unix.SHUT_RDWR)
err := unix.Close(routeSock)
if err != nil {
t.l.WithError(err).Error("failed to close AF_ROUTE socket")
}
}()
routeAddr := &netroute.Inet4Addr{}
maskAddr := &netroute.Inet4Addr{}
routes := *t.Routes.Load()
for _, r := range routes {
if !r.Via.IsValid() || !r.Install {
// We don't allow route MTUs so only install routes with a via
continue
}
if !r.Cidr.Addr().Is4() {
//TODO: implement ipv6
panic("Cant handle ipv6 routes yet")
}
routeAddr.IP = r.Cidr.Addr().As4()
//TODO: we could avoid the copy
copy(maskAddr.IP[:], prefixToMask(r.Cidr))
err := addRoute(routeSock, routeAddr, maskAddr, t.linkAddr)
err := addRoute(r.Cidr, t.linkAddr)
if err != nil {
if errors.Is(err, unix.EEXIST) {
t.l.WithField("route", r.Cidr).
@ -424,36 +409,12 @@ func (t *tun) addRoutes(logErrors bool) error {
}
func (t *tun) removeRoutes(routes []Route) error {
routeSock, err := unix.Socket(unix.AF_ROUTE, unix.SOCK_RAW, unix.AF_UNSPEC)
if err != nil {
return fmt.Errorf("unable to create AF_ROUTE socket: %v", err)
}
defer func() {
unix.Shutdown(routeSock, unix.SHUT_RDWR)
err := unix.Close(routeSock)
if err != nil {
t.l.WithError(err).Error("failed to close AF_ROUTE socket")
}
}()
routeAddr := &netroute.Inet4Addr{}
maskAddr := &netroute.Inet4Addr{}
for _, r := range routes {
if !r.Install {
continue
}
if r.Cidr.Addr().Is6() {
//TODO: implement ipv6
panic("Cant handle ipv6 routes yet")
}
routeAddr.IP = r.Cidr.Addr().As4()
copy(maskAddr.IP[:], prefixToMask(r.Cidr))
err := delRoute(routeSock, routeAddr, maskAddr, t.linkAddr)
err := delRoute(r.Cidr, t.linkAddr)
if err != nil {
t.l.WithError(err).WithField("route", r).Error("Failed to remove route")
} else {
@ -463,23 +424,39 @@ func (t *tun) removeRoutes(routes []Route) error {
return nil
}
func addRoute(sock int, addr, mask *netroute.Inet4Addr, link *netroute.LinkAddr) error {
r := netroute.RouteMessage{
func addRoute(prefix netip.Prefix, gateway netroute.Addr) error {
sock, err := unix.Socket(unix.AF_ROUTE, unix.SOCK_RAW, unix.AF_UNSPEC)
if err != nil {
return fmt.Errorf("unable to create AF_ROUTE socket: %v", err)
}
defer unix.Close(sock)
route := &netroute.RouteMessage{
Version: unix.RTM_VERSION,
Type: unix.RTM_ADD,
Flags: unix.RTF_UP,
Seq: 1,
Addrs: []netroute.Addr{
unix.RTAX_DST: addr,
unix.RTAX_GATEWAY: link,
unix.RTAX_NETMASK: mask,
},
}
data, err := r.Marshal()
if prefix.Addr().Is4() {
route.Addrs = []netroute.Addr{
unix.RTAX_DST: &netroute.Inet4Addr{IP: prefix.Masked().Addr().As4()},
unix.RTAX_NETMASK: &netroute.Inet4Addr{IP: prefixToMask(prefix).As4()},
unix.RTAX_GATEWAY: gateway,
}
} else {
route.Addrs = []netroute.Addr{
unix.RTAX_DST: &netroute.Inet6Addr{IP: prefix.Masked().Addr().As16()},
unix.RTAX_NETMASK: &netroute.Inet6Addr{IP: prefixToMask(prefix).As16()},
unix.RTAX_GATEWAY: gateway,
}
}
data, err := route.Marshal()
if err != nil {
return fmt.Errorf("failed to create route.RouteMessage: %w", err)
}
_, err = unix.Write(sock, data[:])
if err != nil {
return fmt.Errorf("failed to write route.RouteMessage to socket: %w", err)
@ -488,19 +465,34 @@ func addRoute(sock int, addr, mask *netroute.Inet4Addr, link *netroute.LinkAddr)
return nil
}
func delRoute(sock int, addr, mask *netroute.Inet4Addr, link *netroute.LinkAddr) error {
r := netroute.RouteMessage{
func delRoute(prefix netip.Prefix, gateway netroute.Addr) error {
sock, err := unix.Socket(unix.AF_ROUTE, unix.SOCK_RAW, unix.AF_UNSPEC)
if err != nil {
return fmt.Errorf("unable to create AF_ROUTE socket: %v", err)
}
defer unix.Close(sock)
route := netroute.RouteMessage{
Version: unix.RTM_VERSION,
Type: unix.RTM_DELETE,
Seq: 1,
Addrs: []netroute.Addr{
unix.RTAX_DST: addr,
unix.RTAX_GATEWAY: link,
unix.RTAX_NETMASK: mask,
},
}
data, err := r.Marshal()
if prefix.Addr().Is4() {
route.Addrs = []netroute.Addr{
unix.RTAX_DST: &netroute.Inet4Addr{IP: prefix.Masked().Addr().As4()},
unix.RTAX_NETMASK: &netroute.Inet4Addr{IP: prefixToMask(prefix).As4()},
unix.RTAX_GATEWAY: gateway,
}
} else {
route.Addrs = []netroute.Addr{
unix.RTAX_DST: &netroute.Inet6Addr{IP: prefix.Masked().Addr().As16()},
unix.RTAX_NETMASK: &netroute.Inet6Addr{IP: prefixToMask(prefix).As16()},
unix.RTAX_GATEWAY: gateway,
}
}
data, err := route.Marshal()
if err != nil {
return fmt.Errorf("failed to create route.RouteMessage: %w", err)
}
@ -513,7 +505,6 @@ func delRoute(sock int, addr, mask *netroute.Inet4Addr, link *netroute.LinkAddr)
}
func (t *tun) Read(to []byte) (int, error) {
buf := make([]byte, len(to)+4)
n, err := t.ReadWriteCloser.Read(buf)
@ -551,8 +542,8 @@ func (t *tun) Write(from []byte) (int, error) {
return n - 4, err
}
func (t *tun) Cidr() netip.Prefix {
return t.cidr
func (t *tun) Networks() []netip.Prefix {
return t.vpnNetworks
}
func (t *tun) Name() string {
@ -563,10 +554,12 @@ func (t *tun) NewMultiQueueReader() (io.ReadWriteCloser, error) {
return nil, fmt.Errorf("TODO: multiqueue not implemented for darwin")
}
func prefixToMask(prefix netip.Prefix) []byte {
func prefixToMask(prefix netip.Prefix) netip.Addr {
pLen := 128
if prefix.Addr().Is4() {
pLen = 32
}
return net.CIDRMask(prefix.Bits(), pLen)
addr, _ := netip.AddrFromSlice(net.CIDRMask(prefix.Bits(), pLen))
return addr
}

View File

@ -12,8 +12,8 @@ import (
)
type disabledTun struct {
read chan []byte
cidr netip.Prefix
read chan []byte
vpnNetworks []netip.Prefix
// Track these metrics since we don't have the tun device to do it for us
tx metrics.Counter
@ -21,11 +21,11 @@ type disabledTun struct {
l *logrus.Logger
}
func newDisabledTun(cidr netip.Prefix, queueLen int, metricsEnabled bool, l *logrus.Logger) *disabledTun {
func newDisabledTun(vpnNetworks []netip.Prefix, queueLen int, metricsEnabled bool, l *logrus.Logger) *disabledTun {
tun := &disabledTun{
cidr: cidr,
read: make(chan []byte, queueLen),
l: l,
vpnNetworks: vpnNetworks,
read: make(chan []byte, queueLen),
l: l,
}
if metricsEnabled {
@ -47,8 +47,8 @@ func (*disabledTun) RouteFor(addr netip.Addr) netip.Addr {
return netip.Addr{}
}
func (t *disabledTun) Cidr() netip.Prefix {
return t.cidr
func (t *disabledTun) Networks() []netip.Prefix {
return t.vpnNetworks
}
func (*disabledTun) Name() string {

View File

@ -46,12 +46,12 @@ type ifreqDestroy struct {
}
type tun struct {
Device string
cidr netip.Prefix
MTU int
Routes atomic.Pointer[[]Route]
routeTree atomic.Pointer[bart.Table[netip.Addr]]
l *logrus.Logger
Device string
vpnNetworks []netip.Prefix
MTU int
Routes atomic.Pointer[[]Route]
routeTree atomic.Pointer[bart.Table[netip.Addr]]
l *logrus.Logger
io.ReadWriteCloser
}
@ -78,11 +78,11 @@ func (t *tun) Close() error {
return nil
}
func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ netip.Prefix) (*tun, error) {
func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ []netip.Prefix) (*tun, error) {
return nil, fmt.Errorf("newTunFromFd not supported in FreeBSD")
}
func newTun(c *config.C, l *logrus.Logger, cidr netip.Prefix, _ bool) (*tun, error) {
func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, _ bool) (*tun, error) {
// Try to open existing tun device
var file *os.File
var err error
@ -150,7 +150,7 @@ func newTun(c *config.C, l *logrus.Logger, cidr netip.Prefix, _ bool) (*tun, err
t := &tun{
ReadWriteCloser: file,
Device: deviceName,
cidr: cidr,
vpnNetworks: vpnNetworks,
MTU: c.GetInt("tun.mtu", DefaultMTU),
l: l,
}
@ -170,16 +170,16 @@ func newTun(c *config.C, l *logrus.Logger, cidr netip.Prefix, _ bool) (*tun, err
return t, nil
}
func (t *tun) Activate() error {
func (t *tun) addIp(cidr netip.Prefix) error {
var err error
// TODO use syscalls instead of exec.Command
cmd := exec.Command("/sbin/ifconfig", t.Device, t.cidr.String(), t.cidr.Addr().String())
cmd := exec.Command("/sbin/ifconfig", t.Device, cidr.String(), cidr.Addr().String())
t.l.Debug("command: ", cmd.String())
if err = cmd.Run(); err != nil {
return fmt.Errorf("failed to run 'ifconfig': %s", err)
}
cmd = exec.Command("/sbin/route", "-n", "add", "-net", t.cidr.String(), "-interface", t.Device)
cmd = exec.Command("/sbin/route", "-n", "add", "-net", cidr.String(), "-interface", t.Device)
t.l.Debug("command: ", cmd.String())
if err = cmd.Run(); err != nil {
return fmt.Errorf("failed to run 'route add': %s", err)
@ -195,8 +195,18 @@ func (t *tun) Activate() error {
return t.addRoutes(false)
}
func (t *tun) Activate() error {
for i := range t.vpnNetworks {
err := t.addIp(t.vpnNetworks[i])
if err != nil {
return err
}
}
return nil
}
func (t *tun) reload(c *config.C, initial bool) error {
change, routes, err := getAllRoutesFromConfig(c, t.cidr, initial)
change, routes, err := getAllRoutesFromConfig(c, t.vpnNetworks, initial)
if err != nil {
return err
}
@ -237,8 +247,8 @@ func (t *tun) RouteFor(ip netip.Addr) netip.Addr {
return r
}
func (t *tun) Cidr() netip.Prefix {
return t.cidr
func (t *tun) Networks() []netip.Prefix {
return t.vpnNetworks
}
func (t *tun) Name() string {

View File

@ -21,20 +21,20 @@ import (
type tun struct {
io.ReadWriteCloser
cidr netip.Prefix
Routes atomic.Pointer[[]Route]
routeTree atomic.Pointer[bart.Table[netip.Addr]]
l *logrus.Logger
vpnNetworks []netip.Prefix
Routes atomic.Pointer[[]Route]
routeTree atomic.Pointer[bart.Table[netip.Addr]]
l *logrus.Logger
}
func newTun(_ *config.C, _ *logrus.Logger, _ netip.Prefix, _ bool) (*tun, error) {
func newTun(_ *config.C, _ *logrus.Logger, _ []netip.Prefix, _ bool) (*tun, error) {
return nil, fmt.Errorf("newTun not supported in iOS")
}
func newTunFromFd(c *config.C, l *logrus.Logger, deviceFd int, cidr netip.Prefix) (*tun, error) {
func newTunFromFd(c *config.C, l *logrus.Logger, deviceFd int, vpnNetworks []netip.Prefix) (*tun, error) {
file := os.NewFile(uintptr(deviceFd), "/dev/tun")
t := &tun{
cidr: cidr,
vpnNetworks: vpnNetworks,
ReadWriteCloser: &tunReadCloser{f: file},
l: l,
}
@ -59,7 +59,7 @@ func (t *tun) Activate() error {
}
func (t *tun) reload(c *config.C, initial bool) error {
change, routes, err := getAllRoutesFromConfig(c, t.cidr, initial)
change, routes, err := getAllRoutesFromConfig(c, t.vpnNetworks, initial)
if err != nil {
return err
}
@ -142,8 +142,8 @@ func (tr *tunReadCloser) Close() error {
return tr.f.Close()
}
func (t *tun) Cidr() netip.Prefix {
return t.cidr
func (t *tun) Networks() []netip.Prefix {
return t.vpnNetworks
}
func (t *tun) Name() string {

View File

@ -25,7 +25,7 @@ type tun struct {
io.ReadWriteCloser
fd int
Device string
cidr netip.Prefix
vpnNetworks []netip.Prefix
MaxMTU int
DefaultMTU int
TXQueueLen int
@ -40,18 +40,16 @@ type tun struct {
l *logrus.Logger
}
func (t *tun) Networks() []netip.Prefix {
return t.vpnNetworks
}
type ifReq struct {
Name [16]byte
Flags uint16
pad [8]byte
}
type ifreqAddr struct {
Name [16]byte
Addr unix.RawSockaddrInet4
pad [8]byte
}
type ifreqMTU struct {
Name [16]byte
MTU int32
@ -64,10 +62,10 @@ type ifreqQLEN struct {
pad [8]byte
}
func newTunFromFd(c *config.C, l *logrus.Logger, deviceFd int, cidr netip.Prefix) (*tun, error) {
func newTunFromFd(c *config.C, l *logrus.Logger, deviceFd int, vpnNetworks []netip.Prefix) (*tun, error) {
file := os.NewFile(uintptr(deviceFd), "/dev/net/tun")
t, err := newTunGeneric(c, l, file, cidr)
t, err := newTunGeneric(c, l, file, vpnNetworks)
if err != nil {
return nil, err
}
@ -77,7 +75,7 @@ func newTunFromFd(c *config.C, l *logrus.Logger, deviceFd int, cidr netip.Prefix
return t, nil
}
func newTun(c *config.C, l *logrus.Logger, cidr netip.Prefix, multiqueue bool) (*tun, error) {
func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, multiqueue bool) (*tun, error) {
fd, err := unix.Open("/dev/net/tun", os.O_RDWR, 0)
if err != nil {
// If /dev/net/tun doesn't exist, try to create it (will happen in docker)
@ -112,7 +110,7 @@ func newTun(c *config.C, l *logrus.Logger, cidr netip.Prefix, multiqueue bool) (
name := strings.Trim(string(req.Name[:]), "\x00")
file := os.NewFile(uintptr(fd), "/dev/net/tun")
t, err := newTunGeneric(c, l, file, cidr)
t, err := newTunGeneric(c, l, file, vpnNetworks)
if err != nil {
return nil, err
}
@ -122,11 +120,11 @@ func newTun(c *config.C, l *logrus.Logger, cidr netip.Prefix, multiqueue bool) (
return t, nil
}
func newTunGeneric(c *config.C, l *logrus.Logger, file *os.File, cidr netip.Prefix) (*tun, error) {
func newTunGeneric(c *config.C, l *logrus.Logger, file *os.File, vpnNetworks []netip.Prefix) (*tun, error) {
t := &tun{
ReadWriteCloser: file,
fd: int(file.Fd()),
cidr: cidr,
vpnNetworks: vpnNetworks,
TXQueueLen: c.GetInt("tun.tx_queue", 500),
useSystemRoutes: c.GetBool("tun.use_system_route_table", false),
l: l,
@ -148,7 +146,7 @@ func newTunGeneric(c *config.C, l *logrus.Logger, file *os.File, cidr netip.Pref
}
func (t *tun) reload(c *config.C, initial bool) error {
routeChange, routes, err := getAllRoutesFromConfig(c, t.cidr, initial)
routeChange, routes, err := getAllRoutesFromConfig(c, t.vpnNetworks, initial)
if err != nil {
return err
}
@ -190,11 +188,13 @@ func (t *tun) reload(c *config.C, initial bool) error {
}
if oldDefaultMTU != newDefaultMTU {
err := t.setDefaultRoute()
if err != nil {
t.l.Warn(err)
} else {
t.l.Infof("Set default MTU to %v was %v", t.DefaultMTU, oldDefaultMTU)
for i := range t.vpnNetworks {
err := t.setDefaultRoute(t.vpnNetworks[i])
if err != nil {
t.l.Warn(err)
} else {
t.l.Infof("Set default MTU to %v was %v", t.DefaultMTU, oldDefaultMTU)
}
}
}
@ -237,10 +237,10 @@ func (t *tun) RouteFor(ip netip.Addr) netip.Addr {
func (t *tun) Write(b []byte) (int, error) {
var nn int
max := len(b)
maximum := len(b)
for {
n, err := unix.Write(t.fd, b[nn:max])
n, err := unix.Write(t.fd, b[nn:maximum])
if n > 0 {
nn += n
}
@ -265,6 +265,58 @@ func (t *tun) deviceBytes() (o [16]byte) {
return
}
func hasNetlinkAddr(al []*netlink.Addr, x netlink.Addr) bool {
for i := range al {
if al[i].Equal(x) {
return true
}
}
return false
}
// addIPs uses netlink to add all addresses that don't exist, then it removes ones that should not be there
func (t *tun) addIPs(link netlink.Link) error {
newAddrs := make([]*netlink.Addr, len(t.vpnNetworks))
for i := range t.vpnNetworks {
newAddrs[i] = &netlink.Addr{
IPNet: &net.IPNet{
IP: t.vpnNetworks[i].Addr().AsSlice(),
Mask: net.CIDRMask(t.vpnNetworks[i].Bits(), t.vpnNetworks[i].Addr().BitLen()),
},
Label: t.vpnNetworks[i].Addr().Zone(),
}
}
//add all new addresses
for i := range newAddrs {
//todo do we want to stack errors and try as many ops as possible?
//AddrReplace still adds new IPs, but if their properties change it will change them as well
if err := netlink.AddrReplace(link, newAddrs[i]); err != nil {
return err
}
}
//iterate over remainder, remove whoever shouldn't be there
al, err := netlink.AddrList(link, netlink.FAMILY_ALL)
if err != nil {
return fmt.Errorf("failed to get tun address list: %s", err)
}
for i := range al {
if hasNetlinkAddr(newAddrs, al[i]) {
continue
}
err = netlink.AddrDel(link, &al[i])
if err != nil {
t.l.WithError(err).Error("failed to remove address from tun address list")
} else {
t.l.WithField("removed", al[i].String()).Info("removed address not listed in cert(s)")
}
}
return nil
}
func (t *tun) Activate() error {
devName := t.deviceBytes()
@ -272,15 +324,8 @@ func (t *tun) Activate() error {
t.watchRoutes()
}
var addr, mask [4]byte
//TODO: IPV6-WORK
addr = t.cidr.Addr().As4()
tmask := net.CIDRMask(t.cidr.Bits(), 32)
copy(mask[:], tmask)
s, err := unix.Socket(
unix.AF_INET,
unix.AF_INET, //because everything we use t.ioctlFd for is address family independent, this is fine
unix.SOCK_DGRAM,
unix.IPPROTO_IP,
)
@ -289,31 +334,19 @@ func (t *tun) Activate() error {
}
t.ioctlFd = uintptr(s)
ifra := ifreqAddr{
Name: devName,
Addr: unix.RawSockaddrInet4{
Family: unix.AF_INET,
Addr: addr,
},
}
// Set the device ip address
if err = ioctl(t.ioctlFd, unix.SIOCSIFADDR, uintptr(unsafe.Pointer(&ifra))); err != nil {
return fmt.Errorf("failed to set tun address: %s", err)
}
// Set the device network
ifra.Addr.Addr = mask
if err = ioctl(t.ioctlFd, unix.SIOCSIFNETMASK, uintptr(unsafe.Pointer(&ifra))); err != nil {
return fmt.Errorf("failed to set tun netmask: %s", err)
}
// Set the device name
ifrf := ifReq{Name: devName}
if err = ioctl(t.ioctlFd, unix.SIOCGIFFLAGS, uintptr(unsafe.Pointer(&ifrf))); err != nil {
return fmt.Errorf("failed to set tun device name: %s", err)
}
link, err := netlink.LinkByName(t.Device)
if err != nil {
return fmt.Errorf("failed to get tun device link: %s", err)
}
t.deviceIndex = link.Attrs().Index
// Setup our default MTU
t.setMTU()
@ -324,33 +357,36 @@ func (t *tun) Activate() error {
t.l.WithError(err).Error("Failed to set tun tx queue length")
}
if err = t.addIPs(link); err != nil {
return err
}
// Bring up the interface
ifrf.Flags = ifrf.Flags | unix.IFF_UP
if err = ioctl(t.ioctlFd, unix.SIOCSIFFLAGS, uintptr(unsafe.Pointer(&ifrf))); err != nil {
return fmt.Errorf("failed to bring the tun device up: %s", err)
}
link, err := netlink.LinkByName(t.Device)
if err != nil {
return fmt.Errorf("failed to get tun device link: %s", err)
}
t.deviceIndex = link.Attrs().Index
if err = t.setDefaultRoute(); err != nil {
return err
}
// Set the routes
if err = t.addRoutes(false); err != nil {
return err
}
// Run the interface
ifrf.Flags = ifrf.Flags | unix.IFF_UP | unix.IFF_RUNNING
if err = ioctl(t.ioctlFd, unix.SIOCSIFFLAGS, uintptr(unsafe.Pointer(&ifrf))); err != nil {
return fmt.Errorf("failed to run tun device: %s", err)
}
//set route MTU
for i := range t.vpnNetworks {
if err = t.setDefaultRoute(t.vpnNetworks[i]); err != nil {
return fmt.Errorf("failed to set default route MTU: %w", err)
}
}
// Set the routes
if err = t.addRoutes(false); err != nil {
return err
}
//todo do we want to keep the link-local address?
return nil
}
@ -363,12 +399,12 @@ func (t *tun) setMTU() {
}
}
func (t *tun) setDefaultRoute() error {
func (t *tun) setDefaultRoute(cidr netip.Prefix) error {
// Default route
dr := &net.IPNet{
IP: t.cidr.Masked().Addr().AsSlice(),
Mask: net.CIDRMask(t.cidr.Bits(), t.cidr.Addr().BitLen()),
IP: cidr.Masked().Addr().AsSlice(),
Mask: net.CIDRMask(cidr.Bits(), cidr.Addr().BitLen()),
}
nr := netlink.Route{
@ -377,7 +413,7 @@ func (t *tun) setDefaultRoute() error {
MTU: t.DefaultMTU,
AdvMSS: t.advMSS(Route{}),
Scope: unix.RT_SCOPE_LINK,
Src: net.IP(t.cidr.Addr().AsSlice()),
Src: net.IP(cidr.Addr().AsSlice()),
Protocol: unix.RTPROT_KERNEL,
Table: unix.RT_TABLE_MAIN,
Type: unix.RTN_UNICAST,
@ -463,10 +499,6 @@ func (t *tun) removeRoutes(routes []Route) {
}
}
func (t *tun) Cidr() netip.Prefix {
return t.cidr
}
func (t *tun) Name() string {
return t.Device
}
@ -523,9 +555,16 @@ func (t *tun) updateRoutes(r netlink.RouteUpdate) {
}
gwAddr = gwAddr.Unmap()
if !t.cidr.Contains(gwAddr) {
withinNetworks := false
for i := range t.vpnNetworks {
if t.vpnNetworks[i].Contains(gwAddr) {
withinNetworks = true
break
}
}
if !withinNetworks {
// Gateway isn't in our overlay network, ignore
t.l.WithField("route", r).Debug("Ignoring route update, not in our network")
t.l.WithField("route", r).Debug("Ignoring route update, not in our networks")
return
}

View File

@ -27,12 +27,12 @@ type ifreqDestroy struct {
}
type tun struct {
Device string
cidr netip.Prefix
MTU int
Routes atomic.Pointer[[]Route]
routeTree atomic.Pointer[bart.Table[netip.Addr]]
l *logrus.Logger
Device string
vpnNetworks []netip.Prefix
MTU int
Routes atomic.Pointer[[]Route]
routeTree atomic.Pointer[bart.Table[netip.Addr]]
l *logrus.Logger
io.ReadWriteCloser
}
@ -58,13 +58,13 @@ func (t *tun) Close() error {
return nil
}
func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ netip.Prefix) (*tun, error) {
func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ []netip.Prefix) (*tun, error) {
return nil, fmt.Errorf("newTunFromFd not supported in NetBSD")
}
var deviceNameRE = regexp.MustCompile(`^tun[0-9]+$`)
func newTun(c *config.C, l *logrus.Logger, cidr netip.Prefix, _ bool) (*tun, error) {
func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, _ bool) (*tun, error) {
// Try to open tun device
var file *os.File
var err error
@ -84,7 +84,7 @@ func newTun(c *config.C, l *logrus.Logger, cidr netip.Prefix, _ bool) (*tun, err
t := &tun{
ReadWriteCloser: file,
Device: deviceName,
cidr: cidr,
vpnNetworks: vpnNetworks,
MTU: c.GetInt("tun.mtu", DefaultMTU),
l: l,
}
@ -104,17 +104,17 @@ func newTun(c *config.C, l *logrus.Logger, cidr netip.Prefix, _ bool) (*tun, err
return t, nil
}
func (t *tun) Activate() error {
func (t *tun) addIp(cidr netip.Prefix) error {
var err error
// TODO use syscalls instead of exec.Command
cmd := exec.Command("/sbin/ifconfig", t.Device, t.cidr.String(), t.cidr.Addr().String())
cmd := exec.Command("/sbin/ifconfig", t.Device, cidr.String(), cidr.Addr().String())
t.l.Debug("command: ", cmd.String())
if err = cmd.Run(); err != nil {
return fmt.Errorf("failed to run 'ifconfig': %s", err)
}
cmd = exec.Command("/sbin/route", "-n", "add", "-net", t.cidr.String(), t.cidr.Addr().String())
cmd = exec.Command("/sbin/route", "-n", "add", "-net", cidr.String(), cidr.Addr().String())
t.l.Debug("command: ", cmd.String())
if err = cmd.Run(); err != nil {
return fmt.Errorf("failed to run 'route add': %s", err)
@ -130,8 +130,18 @@ func (t *tun) Activate() error {
return t.addRoutes(false)
}
func (t *tun) Activate() error {
for i := range t.vpnNetworks {
err := t.addIp(t.vpnNetworks[i])
if err != nil {
return err
}
}
return nil
}
func (t *tun) reload(c *config.C, initial bool) error {
change, routes, err := getAllRoutesFromConfig(c, t.cidr, initial)
change, routes, err := getAllRoutesFromConfig(c, t.vpnNetworks, initial)
if err != nil {
return err
}
@ -172,8 +182,8 @@ func (t *tun) RouteFor(ip netip.Addr) netip.Addr {
return r
}
func (t *tun) Cidr() netip.Prefix {
return t.cidr
func (t *tun) Networks() []netip.Prefix {
return t.vpnNetworks
}
func (t *tun) Name() string {
@ -192,7 +202,7 @@ func (t *tun) addRoutes(logErrors bool) error {
continue
}
cmd := exec.Command("/sbin/route", "-n", "add", "-net", r.Cidr.String(), t.cidr.Addr().String())
cmd := exec.Command("/sbin/route", "-n", "add", "-net", r.Cidr.String(), t.vpnNetworks[0].Addr().String())
t.l.Debug("command: ", cmd.String())
if err := cmd.Run(); err != nil {
retErr := util.NewContextualError("failed to run 'route add' for unsafe_route", map[string]interface{}{"route": r}, err)
@ -213,7 +223,8 @@ func (t *tun) removeRoutes(routes []Route) error {
continue
}
cmd := exec.Command("/sbin/route", "-n", "delete", "-net", r.Cidr.String(), t.cidr.Addr().String())
//todo is this right?
cmd := exec.Command("/sbin/route", "-n", "delete", "-net", r.Cidr.String(), t.vpnNetworks[0].Addr().String())
t.l.Debug("command: ", cmd.String())
if err := cmd.Run(); err != nil {
t.l.WithError(err).WithField("route", r).Error("Failed to remove route")

View File

@ -21,12 +21,12 @@ import (
)
type tun struct {
Device string
cidr netip.Prefix
MTU int
Routes atomic.Pointer[[]Route]
routeTree atomic.Pointer[bart.Table[netip.Addr]]
l *logrus.Logger
Device string
vpnNetworks []netip.Prefix
MTU int
Routes atomic.Pointer[[]Route]
routeTree atomic.Pointer[bart.Table[netip.Addr]]
l *logrus.Logger
io.ReadWriteCloser
@ -42,13 +42,13 @@ func (t *tun) Close() error {
return nil
}
func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ netip.Prefix) (*tun, error) {
func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ []netip.Prefix) (*tun, error) {
return nil, fmt.Errorf("newTunFromFd not supported in OpenBSD")
}
var deviceNameRE = regexp.MustCompile(`^tun[0-9]+$`)
func newTun(c *config.C, l *logrus.Logger, cidr netip.Prefix, _ bool) (*tun, error) {
func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, _ bool) (*tun, error) {
deviceName := c.GetString("tun.dev", "")
if deviceName == "" {
return nil, fmt.Errorf("a device name in the format of tunN must be specified")
@ -66,7 +66,7 @@ func newTun(c *config.C, l *logrus.Logger, cidr netip.Prefix, _ bool) (*tun, err
t := &tun{
ReadWriteCloser: file,
Device: deviceName,
cidr: cidr,
vpnNetworks: vpnNetworks,
MTU: c.GetInt("tun.mtu", DefaultMTU),
l: l,
}
@ -87,7 +87,7 @@ func newTun(c *config.C, l *logrus.Logger, cidr netip.Prefix, _ bool) (*tun, err
}
func (t *tun) reload(c *config.C, initial bool) error {
change, routes, err := getAllRoutesFromConfig(c, t.cidr, initial)
change, routes, err := getAllRoutesFromConfig(c, t.vpnNetworks, initial)
if err != nil {
return err
}
@ -123,10 +123,10 @@ func (t *tun) reload(c *config.C, initial bool) error {
return nil
}
func (t *tun) Activate() error {
func (t *tun) addIp(cidr netip.Prefix) error {
var err error
// TODO use syscalls instead of exec.Command
cmd := exec.Command("/sbin/ifconfig", t.Device, t.cidr.String(), t.cidr.Addr().String())
cmd := exec.Command("/sbin/ifconfig", t.Device, cidr.String(), cidr.Addr().String())
t.l.Debug("command: ", cmd.String())
if err = cmd.Run(); err != nil {
return fmt.Errorf("failed to run 'ifconfig': %s", err)
@ -138,7 +138,7 @@ func (t *tun) Activate() error {
return fmt.Errorf("failed to run 'ifconfig': %s", err)
}
cmd = exec.Command("/sbin/route", "-n", "add", "-inet", t.cidr.String(), t.cidr.Addr().String())
cmd = exec.Command("/sbin/route", "-n", "add", "-inet", cidr.String(), cidr.Addr().String())
t.l.Debug("command: ", cmd.String())
if err = cmd.Run(); err != nil {
return fmt.Errorf("failed to run 'route add': %s", err)
@ -148,6 +148,16 @@ func (t *tun) Activate() error {
return t.addRoutes(false)
}
func (t *tun) Activate() error {
for i := range t.vpnNetworks {
err := t.addIp(t.vpnNetworks[i])
if err != nil {
return err
}
}
return nil
}
func (t *tun) RouteFor(ip netip.Addr) netip.Addr {
r, _ := t.routeTree.Load().Lookup(ip)
return r
@ -160,8 +170,8 @@ func (t *tun) addRoutes(logErrors bool) error {
// We don't allow route MTUs so only install routes with a via
continue
}
cmd := exec.Command("/sbin/route", "-n", "add", "-inet", r.Cidr.String(), t.cidr.Addr().String())
//todo is this right?
cmd := exec.Command("/sbin/route", "-n", "add", "-inet", r.Cidr.String(), t.vpnNetworks[0].Addr().String())
t.l.Debug("command: ", cmd.String())
if err := cmd.Run(); err != nil {
retErr := util.NewContextualError("failed to run 'route add' for unsafe_route", map[string]interface{}{"route": r}, err)
@ -181,8 +191,8 @@ func (t *tun) removeRoutes(routes []Route) error {
if !r.Install {
continue
}
cmd := exec.Command("/sbin/route", "-n", "delete", "-inet", r.Cidr.String(), t.cidr.Addr().String())
//todo is this right?
cmd := exec.Command("/sbin/route", "-n", "delete", "-inet", r.Cidr.String(), t.vpnNetworks[0].Addr().String())
t.l.Debug("command: ", cmd.String())
if err := cmd.Run(); err != nil {
t.l.WithError(err).WithField("route", r).Error("Failed to remove route")
@ -193,8 +203,8 @@ func (t *tun) removeRoutes(routes []Route) error {
return nil
}
func (t *tun) Cidr() netip.Prefix {
return t.cidr
func (t *tun) Networks() []netip.Prefix {
return t.vpnNetworks
}
func (t *tun) Name() string {

View File

@ -16,19 +16,19 @@ import (
)
type TestTun struct {
Device string
cidr netip.Prefix
Routes []Route
routeTree *bart.Table[netip.Addr]
l *logrus.Logger
Device string
vpnNetworks []netip.Prefix
Routes []Route
routeTree *bart.Table[netip.Addr]
l *logrus.Logger
closed atomic.Bool
rxPackets chan []byte // Packets to receive into nebula
TxPackets chan []byte // Packets transmitted outside by nebula
}
func newTun(c *config.C, l *logrus.Logger, cidr netip.Prefix, _ bool) (*TestTun, error) {
_, routes, err := getAllRoutesFromConfig(c, cidr, true)
func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, _ bool) (*TestTun, error) {
_, routes, err := getAllRoutesFromConfig(c, vpnNetworks, true)
if err != nil {
return nil, err
}
@ -38,17 +38,17 @@ func newTun(c *config.C, l *logrus.Logger, cidr netip.Prefix, _ bool) (*TestTun,
}
return &TestTun{
Device: c.GetString("tun.dev", ""),
cidr: cidr,
Routes: routes,
routeTree: routeTree,
l: l,
rxPackets: make(chan []byte, 10),
TxPackets: make(chan []byte, 10),
Device: c.GetString("tun.dev", ""),
vpnNetworks: vpnNetworks,
Routes: routes,
routeTree: routeTree,
l: l,
rxPackets: make(chan []byte, 10),
TxPackets: make(chan []byte, 10),
}, nil
}
func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ netip.Prefix) (*TestTun, error) {
func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ []netip.Prefix) (*TestTun, error) {
return nil, fmt.Errorf("newTunFromFd not supported")
}
@ -95,8 +95,8 @@ func (t *TestTun) Activate() error {
return nil
}
func (t *TestTun) Cidr() netip.Prefix {
return t.cidr
func (t *TestTun) Networks() []netip.Prefix {
return t.vpnNetworks
}
func (t *TestTun) Name() string {

View File

@ -1,208 +0,0 @@
package overlay
import (
"fmt"
"io"
"net"
"net/netip"
"os/exec"
"strconv"
"sync/atomic"
"github.com/gaissmai/bart"
"github.com/sirupsen/logrus"
"github.com/slackhq/nebula/config"
"github.com/slackhq/nebula/util"
"github.com/songgao/water"
)
type waterTun struct {
Device string
cidr netip.Prefix
MTU int
Routes atomic.Pointer[[]Route]
routeTree atomic.Pointer[bart.Table[netip.Addr]]
l *logrus.Logger
f *net.Interface
*water.Interface
}
func newWaterTun(c *config.C, l *logrus.Logger, cidr netip.Prefix, _ bool) (*waterTun, error) {
// NOTE: You cannot set the deviceName under Windows, so you must check tun.Device after calling .Activate()
t := &waterTun{
cidr: cidr,
MTU: c.GetInt("tun.mtu", DefaultMTU),
l: l,
}
err := t.reload(c, true)
if err != nil {
return nil, err
}
c.RegisterReloadCallback(func(c *config.C) {
err := t.reload(c, false)
if err != nil {
util.LogWithContextIfNeeded("failed to reload tun device", err, t.l)
}
})
return t, nil
}
func (t *waterTun) Activate() error {
var err error
t.Interface, err = water.New(water.Config{
DeviceType: water.TUN,
PlatformSpecificParams: water.PlatformSpecificParams{
ComponentID: "tap0901",
Network: t.cidr.String(),
},
})
if err != nil {
return fmt.Errorf("activate failed: %v", err)
}
t.Device = t.Interface.Name()
// TODO use syscalls instead of exec.Command
err = exec.Command(
`C:\Windows\System32\netsh.exe`, "interface", "ipv4", "set", "address",
fmt.Sprintf("name=%s", t.Device),
"source=static",
fmt.Sprintf("addr=%s", t.cidr.Addr()),
fmt.Sprintf("mask=%s", net.CIDRMask(t.cidr.Bits(), t.cidr.Addr().BitLen())),
"gateway=none",
).Run()
if err != nil {
return fmt.Errorf("failed to run 'netsh' to set address: %s", err)
}
err = exec.Command(
`C:\Windows\System32\netsh.exe`, "interface", "ipv4", "set", "interface",
t.Device,
fmt.Sprintf("mtu=%d", t.MTU),
).Run()
if err != nil {
return fmt.Errorf("failed to run 'netsh' to set MTU: %s", err)
}
t.f, err = net.InterfaceByName(t.Device)
if err != nil {
return fmt.Errorf("failed to find interface named %s: %v", t.Device, err)
}
err = t.addRoutes(false)
if err != nil {
return err
}
return nil
}
func (t *waterTun) reload(c *config.C, initial bool) error {
change, routes, err := getAllRoutesFromConfig(c, t.cidr, initial)
if err != nil {
return err
}
if !initial && !change {
return nil
}
routeTree, err := makeRouteTree(t.l, routes, false)
if err != nil {
return err
}
// Teach nebula how to handle the routes before establishing them in the system table
oldRoutes := t.Routes.Swap(&routes)
t.routeTree.Store(routeTree)
if !initial {
// Remove first, if the system removes a wanted route hopefully it will be re-added next
t.removeRoutes(findRemovedRoutes(routes, *oldRoutes))
// Ensure any routes we actually want are installed
err = t.addRoutes(true)
if err != nil {
// Catch any stray logs
util.LogWithContextIfNeeded("Failed to set routes", err, t.l)
} else {
for _, r := range findRemovedRoutes(routes, *oldRoutes) {
t.l.WithField("route", r).Info("Removed route")
}
}
}
return nil
}
func (t *waterTun) addRoutes(logErrors bool) error {
// Path routes
routes := *t.Routes.Load()
for _, r := range routes {
if !r.Via.IsValid() || !r.Install {
// We don't allow route MTUs so only install routes with a via
continue
}
err := exec.Command(
"C:\\Windows\\System32\\route.exe", "add", r.Cidr.String(), r.Via.String(), "IF", strconv.Itoa(t.f.Index), "METRIC", strconv.Itoa(r.Metric),
).Run()
if err != nil {
retErr := util.NewContextualError("Failed to add route", map[string]interface{}{"route": r}, err)
if logErrors {
retErr.Log(t.l)
} else {
return retErr
}
} else {
t.l.WithField("route", r).Info("Added route")
}
}
return nil
}
func (t *waterTun) removeRoutes(routes []Route) {
for _, r := range routes {
if !r.Install {
continue
}
err := exec.Command(
"C:\\Windows\\System32\\route.exe", "delete", r.Cidr.String(), r.Via.String(), "IF", strconv.Itoa(t.f.Index), "METRIC", strconv.Itoa(r.Metric),
).Run()
if err != nil {
t.l.WithError(err).WithField("route", r).Error("Failed to remove route")
} else {
t.l.WithField("route", r).Info("Removed route")
}
}
}
func (t *waterTun) RouteFor(ip netip.Addr) netip.Addr {
r, _ := t.routeTree.Load().Lookup(ip)
return r
}
func (t *waterTun) Cidr() netip.Prefix {
return t.cidr
}
func (t *waterTun) Name() string {
return t.Device
}
func (t *waterTun) Close() error {
if t.Interface == nil {
return nil
}
return t.Interface.Close()
}
func (t *waterTun) NewMultiQueueReader() (io.ReadWriteCloser, error) {
return nil, fmt.Errorf("TODO: multiqueue not implemented for windows")
}

View File

@ -4,41 +4,267 @@
package overlay
import (
"crypto"
"fmt"
"io"
"net/netip"
"os"
"path/filepath"
"runtime"
"sync/atomic"
"syscall"
"unsafe"
"github.com/gaissmai/bart"
"github.com/sirupsen/logrus"
"github.com/slackhq/nebula/config"
"github.com/slackhq/nebula/util"
"github.com/slackhq/nebula/wintun"
"golang.org/x/sys/windows"
"golang.zx2c4.com/wireguard/windows/tunnel/winipcfg"
)
func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ netip.Prefix) (Device, error) {
const tunGUIDLabel = "Fixed Nebula Windows GUID v1"
type winTun struct {
Device string
vpnNetworks []netip.Prefix
MTU int
Routes atomic.Pointer[[]Route]
routeTree atomic.Pointer[bart.Table[netip.Addr]]
l *logrus.Logger
tun *wintun.NativeTun
}
func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ []netip.Prefix) (Device, error) {
return nil, fmt.Errorf("newTunFromFd not supported in Windows")
}
func newTun(c *config.C, l *logrus.Logger, cidr netip.Prefix, multiqueue bool) (Device, error) {
useWintun := true
if err := checkWinTunExists(); err != nil {
l.WithError(err).Warn("Check Wintun driver failed, fallback to wintap driver")
useWintun = false
}
if useWintun {
device, err := newWinTun(c, l, cidr, multiqueue)
if err != nil {
return nil, fmt.Errorf("create Wintun interface failed, %w", err)
}
return device, nil
}
device, err := newWaterTun(c, l, cidr, multiqueue)
func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, _ bool) (*winTun, error) {
err := checkWinTunExists()
if err != nil {
return nil, fmt.Errorf("create wintap driver failed, %w", err)
return nil, fmt.Errorf("can not load the wintun driver: %w", err)
}
return device, nil
deviceName := c.GetString("tun.dev", "")
guid, err := generateGUIDByDeviceName(deviceName)
if err != nil {
return nil, fmt.Errorf("generate GUID failed: %w", err)
}
t := &winTun{
Device: deviceName,
vpnNetworks: vpnNetworks,
MTU: c.GetInt("tun.mtu", DefaultMTU),
l: l,
}
err = t.reload(c, true)
if err != nil {
return nil, err
}
var tunDevice wintun.Device
tunDevice, err = wintun.CreateTUNWithRequestedGUID(deviceName, guid, t.MTU)
if err != nil {
// Windows 10 has an issue with unclean shutdowns not fully cleaning up the wintun device.
// Trying a second time resolves the issue.
l.WithError(err).Debug("Failed to create wintun device, retrying")
tunDevice, err = wintun.CreateTUNWithRequestedGUID(deviceName, guid, t.MTU)
if err != nil {
return nil, fmt.Errorf("create TUN device failed: %w", err)
}
}
t.tun = tunDevice.(*wintun.NativeTun)
c.RegisterReloadCallback(func(c *config.C) {
err := t.reload(c, false)
if err != nil {
util.LogWithContextIfNeeded("failed to reload tun device", err, t.l)
}
})
return t, nil
}
func (t *winTun) reload(c *config.C, initial bool) error {
change, routes, err := getAllRoutesFromConfig(c, t.vpnNetworks, initial)
if err != nil {
return err
}
if !initial && !change {
return nil
}
routeTree, err := makeRouteTree(t.l, routes, false)
if err != nil {
return err
}
// Teach nebula how to handle the routes before establishing them in the system table
oldRoutes := t.Routes.Swap(&routes)
t.routeTree.Store(routeTree)
if !initial {
// Remove first, if the system removes a wanted route hopefully it will be re-added next
err := t.removeRoutes(findRemovedRoutes(routes, *oldRoutes))
if err != nil {
util.LogWithContextIfNeeded("Failed to remove routes", err, t.l)
}
// Ensure any routes we actually want are installed
err = t.addRoutes(true)
if err != nil {
// Catch any stray logs
util.LogWithContextIfNeeded("Failed to add routes", err, t.l)
}
}
return nil
}
func (t *winTun) Activate() error {
luid := winipcfg.LUID(t.tun.LUID())
err := luid.SetIPAddresses(t.vpnNetworks)
if err != nil {
return fmt.Errorf("failed to set address: %w", err)
}
err = t.addRoutes(false)
if err != nil {
return err
}
return nil
}
func (t *winTun) addRoutes(logErrors bool) error {
luid := winipcfg.LUID(t.tun.LUID())
routes := *t.Routes.Load()
foundDefault4 := false
for _, r := range routes {
if !r.Via.IsValid() || !r.Install {
// We don't allow route MTUs so only install routes with a via
continue
}
// Add our unsafe route
err := luid.AddRoute(r.Cidr, r.Via, uint32(r.Metric))
if err != nil {
retErr := util.NewContextualError("Failed to add route", map[string]interface{}{"route": r}, err)
if logErrors {
retErr.Log(t.l)
continue
} else {
return retErr
}
} else {
t.l.WithField("route", r).Info("Added route")
}
if !foundDefault4 {
if r.Cidr.Bits() == 0 && r.Cidr.Addr().BitLen() == 32 {
foundDefault4 = true
}
}
}
ipif, err := luid.IPInterface(windows.AF_INET)
if err != nil {
return fmt.Errorf("failed to get ip interface: %w", err)
}
ipif.NLMTU = uint32(t.MTU)
if foundDefault4 {
ipif.UseAutomaticMetric = false
ipif.Metric = 0
}
if err := ipif.Set(); err != nil {
return fmt.Errorf("failed to set ip interface: %w", err)
}
return nil
}
func (t *winTun) removeRoutes(routes []Route) error {
luid := winipcfg.LUID(t.tun.LUID())
for _, r := range routes {
if !r.Install {
continue
}
err := luid.DeleteRoute(r.Cidr, r.Via)
if err != nil {
t.l.WithError(err).WithField("route", r).Error("Failed to remove route")
} else {
t.l.WithField("route", r).Info("Removed route")
}
}
return nil
}
func (t *winTun) RouteFor(ip netip.Addr) netip.Addr {
r, _ := t.routeTree.Load().Lookup(ip)
return r
}
func (t *winTun) Networks() []netip.Prefix {
return t.vpnNetworks
}
func (t *winTun) Name() string {
return t.Device
}
func (t *winTun) Read(b []byte) (int, error) {
return t.tun.Read(b, 0)
}
func (t *winTun) Write(b []byte) (int, error) {
return t.tun.Write(b, 0)
}
func (t *winTun) NewMultiQueueReader() (io.ReadWriteCloser, error) {
return nil, fmt.Errorf("TODO: multiqueue not implemented for windows")
}
func (t *winTun) Close() error {
// It seems that the Windows networking stack doesn't like it when we destroy interfaces that have active routes,
// so to be certain, just remove everything before destroying.
luid := winipcfg.LUID(t.tun.LUID())
_ = luid.FlushRoutes(windows.AF_INET)
_ = luid.FlushIPAddresses(windows.AF_INET)
/* We don't support IPV6 yet
_ = luid.FlushRoutes(windows.AF_INET6)
_ = luid.FlushIPAddresses(windows.AF_INET6)
*/
_ = luid.FlushDNS(windows.AF_INET)
return t.tun.Close()
}
func generateGUIDByDeviceName(name string) (*windows.GUID, error) {
// GUID is 128 bit
hash := crypto.MD5.New()
_, err := hash.Write([]byte(tunGUIDLabel))
if err != nil {
return nil, err
}
_, err = hash.Write([]byte(name))
if err != nil {
return nil, err
}
sum := hash.Sum(nil)
return (*windows.GUID)(unsafe.Pointer(&sum[0])), nil
}
func checkWinTunExists() error {

View File

@ -1,252 +0,0 @@
package overlay
import (
"crypto"
"fmt"
"io"
"net/netip"
"sync/atomic"
"unsafe"
"github.com/gaissmai/bart"
"github.com/sirupsen/logrus"
"github.com/slackhq/nebula/config"
"github.com/slackhq/nebula/util"
"github.com/slackhq/nebula/wintun"
"golang.org/x/sys/windows"
"golang.zx2c4.com/wireguard/windows/tunnel/winipcfg"
)
const tunGUIDLabel = "Fixed Nebula Windows GUID v1"
type winTun struct {
Device string
cidr netip.Prefix
MTU int
Routes atomic.Pointer[[]Route]
routeTree atomic.Pointer[bart.Table[netip.Addr]]
l *logrus.Logger
tun *wintun.NativeTun
}
func generateGUIDByDeviceName(name string) (*windows.GUID, error) {
// GUID is 128 bit
hash := crypto.MD5.New()
_, err := hash.Write([]byte(tunGUIDLabel))
if err != nil {
return nil, err
}
_, err = hash.Write([]byte(name))
if err != nil {
return nil, err
}
sum := hash.Sum(nil)
return (*windows.GUID)(unsafe.Pointer(&sum[0])), nil
}
func newWinTun(c *config.C, l *logrus.Logger, cidr netip.Prefix, _ bool) (*winTun, error) {
deviceName := c.GetString("tun.dev", "")
guid, err := generateGUIDByDeviceName(deviceName)
if err != nil {
return nil, fmt.Errorf("generate GUID failed: %w", err)
}
t := &winTun{
Device: deviceName,
cidr: cidr,
MTU: c.GetInt("tun.mtu", DefaultMTU),
l: l,
}
err = t.reload(c, true)
if err != nil {
return nil, err
}
var tunDevice wintun.Device
tunDevice, err = wintun.CreateTUNWithRequestedGUID(deviceName, guid, t.MTU)
if err != nil {
// Windows 10 has an issue with unclean shutdowns not fully cleaning up the wintun device.
// Trying a second time resolves the issue.
l.WithError(err).Debug("Failed to create wintun device, retrying")
tunDevice, err = wintun.CreateTUNWithRequestedGUID(deviceName, guid, t.MTU)
if err != nil {
return nil, fmt.Errorf("create TUN device failed: %w", err)
}
}
t.tun = tunDevice.(*wintun.NativeTun)
c.RegisterReloadCallback(func(c *config.C) {
err := t.reload(c, false)
if err != nil {
util.LogWithContextIfNeeded("failed to reload tun device", err, t.l)
}
})
return t, nil
}
func (t *winTun) reload(c *config.C, initial bool) error {
change, routes, err := getAllRoutesFromConfig(c, t.cidr, initial)
if err != nil {
return err
}
if !initial && !change {
return nil
}
routeTree, err := makeRouteTree(t.l, routes, false)
if err != nil {
return err
}
// Teach nebula how to handle the routes before establishing them in the system table
oldRoutes := t.Routes.Swap(&routes)
t.routeTree.Store(routeTree)
if !initial {
// Remove first, if the system removes a wanted route hopefully it will be re-added next
err := t.removeRoutes(findRemovedRoutes(routes, *oldRoutes))
if err != nil {
util.LogWithContextIfNeeded("Failed to remove routes", err, t.l)
}
// Ensure any routes we actually want are installed
err = t.addRoutes(true)
if err != nil {
// Catch any stray logs
util.LogWithContextIfNeeded("Failed to add routes", err, t.l)
}
}
return nil
}
func (t *winTun) Activate() error {
luid := winipcfg.LUID(t.tun.LUID())
err := luid.SetIPAddresses([]netip.Prefix{t.cidr})
if err != nil {
return fmt.Errorf("failed to set address: %w", err)
}
err = t.addRoutes(false)
if err != nil {
return err
}
return nil
}
func (t *winTun) addRoutes(logErrors bool) error {
luid := winipcfg.LUID(t.tun.LUID())
routes := *t.Routes.Load()
foundDefault4 := false
for _, r := range routes {
if !r.Via.IsValid() || !r.Install {
// We don't allow route MTUs so only install routes with a via
continue
}
// Add our unsafe route
err := luid.AddRoute(r.Cidr, r.Via, uint32(r.Metric))
if err != nil {
retErr := util.NewContextualError("Failed to add route", map[string]interface{}{"route": r}, err)
if logErrors {
retErr.Log(t.l)
continue
} else {
return retErr
}
} else {
t.l.WithField("route", r).Info("Added route")
}
if !foundDefault4 {
if r.Cidr.Bits() == 0 && r.Cidr.Addr().BitLen() == 32 {
foundDefault4 = true
}
}
}
ipif, err := luid.IPInterface(windows.AF_INET)
if err != nil {
return fmt.Errorf("failed to get ip interface: %w", err)
}
ipif.NLMTU = uint32(t.MTU)
if foundDefault4 {
ipif.UseAutomaticMetric = false
ipif.Metric = 0
}
if err := ipif.Set(); err != nil {
return fmt.Errorf("failed to set ip interface: %w", err)
}
return nil
}
func (t *winTun) removeRoutes(routes []Route) error {
luid := winipcfg.LUID(t.tun.LUID())
for _, r := range routes {
if !r.Install {
continue
}
err := luid.DeleteRoute(r.Cidr, r.Via)
if err != nil {
t.l.WithError(err).WithField("route", r).Error("Failed to remove route")
} else {
t.l.WithField("route", r).Info("Removed route")
}
}
return nil
}
func (t *winTun) RouteFor(ip netip.Addr) netip.Addr {
r, _ := t.routeTree.Load().Lookup(ip)
return r
}
func (t *winTun) Cidr() netip.Prefix {
return t.cidr
}
func (t *winTun) Name() string {
return t.Device
}
func (t *winTun) Read(b []byte) (int, error) {
return t.tun.Read(b, 0)
}
func (t *winTun) Write(b []byte) (int, error) {
return t.tun.Write(b, 0)
}
func (t *winTun) NewMultiQueueReader() (io.ReadWriteCloser, error) {
return nil, fmt.Errorf("TODO: multiqueue not implemented for windows")
}
func (t *winTun) Close() error {
// It seems that the Windows networking stack doesn't like it when we destroy interfaces that have active routes,
// so to be certain, just remove everything before destroying.
luid := winipcfg.LUID(t.tun.LUID())
_ = luid.FlushRoutes(windows.AF_INET)
_ = luid.FlushIPAddresses(windows.AF_INET)
/* We don't support IPV6 yet
_ = luid.FlushRoutes(windows.AF_INET6)
_ = luid.FlushIPAddresses(windows.AF_INET6)
*/
_ = luid.FlushDNS(windows.AF_INET)
return t.tun.Close()
}

View File

@ -8,16 +8,16 @@ import (
"github.com/slackhq/nebula/config"
)
func NewUserDeviceFromConfig(c *config.C, l *logrus.Logger, tunCidr netip.Prefix, routines int) (Device, error) {
return NewUserDevice(tunCidr)
func NewUserDeviceFromConfig(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, routines int) (Device, error) {
return NewUserDevice(vpnNetworks)
}
func NewUserDevice(tunCidr netip.Prefix) (Device, error) {
func NewUserDevice(vpnNetworks []netip.Prefix) (Device, error) {
// these pipes guarantee each write/read will match 1:1
or, ow := io.Pipe()
ir, iw := io.Pipe()
return &UserDevice{
tunCidr: tunCidr,
vpnNetworks: vpnNetworks,
outboundReader: or,
outboundWriter: ow,
inboundReader: ir,
@ -26,7 +26,7 @@ func NewUserDevice(tunCidr netip.Prefix) (Device, error) {
}
type UserDevice struct {
tunCidr netip.Prefix
vpnNetworks []netip.Prefix
outboundReader *io.PipeReader
outboundWriter *io.PipeWriter
@ -38,7 +38,7 @@ type UserDevice struct {
func (d *UserDevice) Activate() error {
return nil
}
func (d *UserDevice) Cidr() netip.Prefix { return d.tunCidr }
func (d *UserDevice) Networks() []netip.Prefix { return d.vpnNetworks }
func (d *UserDevice) Name() string { return "faketun0" }
func (d *UserDevice) RouteFor(ip netip.Addr) netip.Addr { return ip }
func (d *UserDevice) NewMultiQueueReader() (io.ReadWriteCloser, error) {

425
pki.go
View File

@ -1,13 +1,19 @@
package nebula
import (
"encoding/binary"
"encoding/json"
"errors"
"fmt"
"net"
"net/netip"
"os"
"slices"
"strings"
"sync/atomic"
"time"
"github.com/gaissmai/bart"
"github.com/sirupsen/logrus"
"github.com/slackhq/nebula/cert"
"github.com/slackhq/nebula/config"
@ -21,12 +27,22 @@ type PKI struct {
}
type CertState struct {
Certificate cert.Certificate
RawCertificate []byte
RawCertificateNoKey []byte
PublicKey []byte
PrivateKey []byte
pkcs11Backed bool
v1Cert cert.Certificate
v1HandshakeBytes []byte
v2Cert cert.Certificate
v2HandshakeBytes []byte
defaultVersion cert.Version
privateKey []byte
pkcs11Backed bool
cipher string
myVpnNetworks []netip.Prefix
myVpnNetworksTable *bart.Table[struct{}]
myVpnAddrs []netip.Addr
myVpnAddrsTable *bart.Table[struct{}]
myVpnBroadcastAddrsTable *bart.Table[struct{}]
}
func NewPKIFromConfig(l *logrus.Logger, c *config.C) (*PKI, error) {
@ -46,16 +62,26 @@ func NewPKIFromConfig(l *logrus.Logger, c *config.C) (*PKI, error) {
return pki, nil
}
func (p *PKI) GetCertState() *CertState {
return p.cs.Load()
}
func (p *PKI) GetCAPool() *cert.CAPool {
return p.caPool.Load()
}
func (p *PKI) getCertState() *CertState {
return p.cs.Load()
}
// TODO: We should remove this
func (p *PKI) getDefaultCertificate() cert.Certificate {
return p.cs.Load().GetDefaultCertificate()
}
// TODO: We should remove this
func (p *PKI) getCertificate(v cert.Version) cert.Certificate {
return p.cs.Load().getCertificate(v)
}
func (p *PKI) reload(c *config.C, initial bool) error {
err := p.reloadCert(c, initial)
err := p.reloadCerts(c, initial)
if err != nil {
if initial {
return err
@ -74,33 +100,94 @@ func (p *PKI) reload(c *config.C, initial bool) error {
return nil
}
func (p *PKI) reloadCert(c *config.C, initial bool) *util.ContextualError {
cs, err := newCertStateFromConfig(c)
func (p *PKI) reloadCerts(c *config.C, initial bool) *util.ContextualError {
newState, err := newCertStateFromConfig(c)
if err != nil {
return util.NewContextualError("Could not load client cert", nil, err)
}
if !initial {
//TODO: include check for mask equality as well
currentState := p.cs.Load()
if newState.v1Cert != nil {
if currentState.v1Cert == nil {
return util.NewContextualError("v1 certificate was added, restart required", nil, err)
}
// did IP in cert change? if so, don't set
currentCert := p.cs.Load().Certificate
oldIPs := currentCert.Networks()
newIPs := cs.Certificate.Networks()
if len(oldIPs) > 0 && len(newIPs) > 0 && oldIPs[0].String() != newIPs[0].String() {
// did IP in cert change? if so, don't set
if !slices.Equal(currentState.v1Cert.Networks(), newState.v1Cert.Networks()) {
return util.NewContextualError(
"Networks in new cert was different from old",
m{"new_networks": newState.v1Cert.Networks(), "old_networks": currentState.v1Cert.Networks()},
nil,
)
}
if currentState.v1Cert.Curve() != newState.v1Cert.Curve() {
return util.NewContextualError(
"Curve in new cert was different from old",
m{"new_curve": newState.v1Cert.Curve(), "old_curve": currentState.v1Cert.Curve()},
nil,
)
}
} else if currentState.v1Cert != nil {
//TODO: we should be able to tear this down
return util.NewContextualError("v1 certificate was removed, restart required", nil, err)
}
if newState.v2Cert != nil {
if currentState.v2Cert == nil {
return util.NewContextualError("v2 certificate was added, restart required", nil, err)
}
// did IP in cert change? if so, don't set
if !slices.Equal(currentState.v2Cert.Networks(), newState.v2Cert.Networks()) {
return util.NewContextualError(
"Networks in new cert was different from old",
m{"new_networks": newState.v2Cert.Networks(), "old_networks": currentState.v2Cert.Networks()},
nil,
)
}
if currentState.v2Cert.Curve() != newState.v2Cert.Curve() {
return util.NewContextualError(
"Curve in new cert was different from old",
m{"new_curve": newState.v2Cert.Curve(), "old_curve": currentState.v2Cert.Curve()},
nil,
)
}
} else if currentState.v2Cert != nil {
return util.NewContextualError("v2 certificate was removed, restart required", nil, err)
}
// Cipher cant be hot swapped so just leave it at what it was before
newState.cipher = currentState.cipher
} else {
newState.cipher = c.GetString("cipher", "aes")
//TODO: this sucks and we should make it not a global
switch newState.cipher {
case "aes":
noiseEndianness = binary.BigEndian
case "chachapoly":
noiseEndianness = binary.LittleEndian
default:
return util.NewContextualError(
"Networks in new cert was different from old",
m{"new_network": newIPs[0], "old_network": oldIPs[0]},
"unknown cipher",
m{"cipher": newState.cipher},
nil,
)
}
}
p.cs.Store(cs)
p.cs.Store(newState)
//TODO: newState needs a stringer that does json
if initial {
p.l.WithField("cert", cs.Certificate).Debug("Client nebula certificate")
p.l.WithField("cert", newState).Debug("Client nebula certificate(s)")
} else {
p.l.WithField("cert", cs.Certificate).Info("Client cert refreshed from disk")
p.l.WithField("cert", newState).Info("Client certificate(s) refreshed from disk")
}
return nil
}
@ -116,55 +203,65 @@ func (p *PKI) reloadCAPool(c *config.C) *util.ContextualError {
return nil
}
func newCertState(certificate cert.Certificate, pkcs11backed bool, privateKey []byte) (*CertState, error) {
// Marshal the certificate to ensure it is valid
rawCertificate, err := certificate.Marshal()
if err != nil {
return nil, fmt.Errorf("invalid nebula certificate on interface: %s", err)
func (cs *CertState) GetDefaultCertificate() cert.Certificate {
c := cs.getCertificate(cs.defaultVersion)
if c == nil {
panic("No default certificate found")
}
publicKey := certificate.PublicKey()
cs := &CertState{
RawCertificate: rawCertificate,
Certificate: certificate,
PrivateKey: privateKey,
PublicKey: publicKey,
pkcs11Backed: pkcs11backed,
}
rawCertNoKey, err := cs.Certificate.MarshalForHandshakes()
if err != nil {
return nil, fmt.Errorf("error marshalling certificate no key: %s", err)
}
cs.RawCertificateNoKey = rawCertNoKey
return cs, nil
return c
}
func loadPrivateKey(privPathOrPEM string) (rawKey []byte, curve cert.Curve, isPkcs11 bool, err error) {
var pemPrivateKey []byte
if strings.Contains(privPathOrPEM, "-----BEGIN") {
pemPrivateKey = []byte(privPathOrPEM)
privPathOrPEM = "<inline>"
rawKey, _, curve, err = cert.UnmarshalPrivateKeyFromPEM(pemPrivateKey)
if err != nil {
return nil, curve, false, fmt.Errorf("error while unmarshaling pki.key %s: %s", privPathOrPEM, err)
}
} else if strings.HasPrefix(privPathOrPEM, "pkcs11:") {
rawKey = []byte(privPathOrPEM)
return rawKey, cert.Curve_P256, true, nil
} else {
pemPrivateKey, err = os.ReadFile(privPathOrPEM)
if err != nil {
return nil, curve, false, fmt.Errorf("unable to read pki.key file %s: %s", privPathOrPEM, err)
}
rawKey, _, curve, err = cert.UnmarshalPrivateKeyFromPEM(pemPrivateKey)
if err != nil {
return nil, curve, false, fmt.Errorf("error while unmarshaling pki.key %s: %s", privPathOrPEM, err)
}
func (cs *CertState) getCertificate(v cert.Version) cert.Certificate {
switch v {
case cert.Version1:
return cs.v1Cert
case cert.Version2:
return cs.v2Cert
}
return
return nil
}
// getHandshakeBytes returns the cached bytes to be used in a handshake message for the requested version.
// Callers must check if the return []byte is nil.
func (cs *CertState) getHandshakeBytes(v cert.Version) []byte {
switch v {
case cert.Version1:
return cs.v1HandshakeBytes
case cert.Version2:
return cs.v2HandshakeBytes
default:
return nil
}
}
func (cs *CertState) String() string {
b, err := cs.MarshalJSON()
if err != nil {
return fmt.Sprintf("error marshaling certificate state: %v", err)
}
return string(b)
}
func (cs *CertState) MarshalJSON() ([]byte, error) {
msg := []json.RawMessage{}
if cs.v1Cert != nil {
b, err := cs.v1Cert.MarshalJSON()
if err != nil {
return nil, err
}
msg = append(msg, b)
}
if cs.v2Cert != nil {
b, err := cs.v2Cert.MarshalJSON()
if err != nil {
return nil, err
}
msg = append(msg, b)
}
return json.Marshal(msg)
}
func newCertStateFromConfig(c *config.C) (*CertState, error) {
@ -198,24 +295,198 @@ func newCertStateFromConfig(c *config.C) (*CertState, error) {
}
}
nebulaCert, _, err := cert.UnmarshalCertificateFromPEM(rawCert)
var crt, v1, v2 cert.Certificate
for {
// Load the certificate
crt, rawCert, err = loadCertificate(rawCert)
if err != nil {
//TODO: check error
return nil, err
}
switch crt.Version() {
case cert.Version1:
if v1 != nil {
return nil, fmt.Errorf("v1 certificate already found in pki.cert")
}
v1 = crt
case cert.Version2:
if v2 != nil {
return nil, fmt.Errorf("v2 certificate already found in pki.cert")
}
v2 = crt
default:
return nil, fmt.Errorf("unknown certificate version %v", crt.Version())
}
if len(rawCert) == 0 || strings.TrimSpace(string(rawCert)) == "" {
break
}
}
if v1 == nil && v2 == nil {
return nil, errors.New("no certificates found in pki.cert")
}
useDefaultVersion := uint32(1)
if v1 == nil {
// The only condition that requires v2 as the default is if only a v2 certificate is present
// We do this to avoid having to configure it specifically in the config file
useDefaultVersion = 2
}
rawDefaultVersion := c.GetUint32("pki.default_version", useDefaultVersion)
var defaultVersion cert.Version
switch rawDefaultVersion {
case 1:
if v1 == nil {
return nil, fmt.Errorf("can not use pki.default_version 1 without a v1 certificate in pki.cert")
}
defaultVersion = cert.Version1
case 2:
defaultVersion = cert.Version2
default:
return nil, fmt.Errorf("unknown pki.default_version: %v", rawDefaultVersion)
}
return newCertState(defaultVersion, v1, v2, isPkcs11, curve, rawKey)
}
func newCertState(dv cert.Version, v1, v2 cert.Certificate, pkcs11backed bool, privateKeyCurve cert.Curve, privateKey []byte) (*CertState, error) {
cs := CertState{
privateKey: privateKey,
pkcs11Backed: pkcs11backed,
myVpnNetworksTable: new(bart.Table[struct{}]),
myVpnAddrsTable: new(bart.Table[struct{}]),
myVpnBroadcastAddrsTable: new(bart.Table[struct{}]),
}
if v1 != nil && v2 != nil {
if !slices.Equal(v1.PublicKey(), v2.PublicKey()) {
return nil, util.NewContextualError("v1 and v2 public keys are not the same, ignoring", nil, nil)
}
if v1.Curve() != v2.Curve() {
return nil, util.NewContextualError("v1 and v2 curve are not the same, ignoring", nil, nil)
}
//TODO: make sure v2 has v1s address
cs.defaultVersion = dv
}
if v1 != nil {
if pkcs11backed {
//TODO: We do not currently have a method to verify a public private key pair when the private key is in an hsm
} else {
if err := v1.VerifyPrivateKey(privateKeyCurve, privateKey); err != nil {
return nil, fmt.Errorf("private key is not a pair with public key in nebula cert")
}
}
v1hs, err := v1.MarshalForHandshakes()
if err != nil {
return nil, fmt.Errorf("error marshalling certificate for handshake: %w", err)
}
cs.v1Cert = v1
cs.v1HandshakeBytes = v1hs
if cs.defaultVersion == 0 {
cs.defaultVersion = cert.Version1
}
}
if v2 != nil {
if pkcs11backed {
//TODO: We do not currently have a method to verify a public private key pair when the private key is in an hsm
} else {
if err := v2.VerifyPrivateKey(privateKeyCurve, privateKey); err != nil {
return nil, fmt.Errorf("private key is not a pair with public key in nebula cert")
}
}
v2hs, err := v2.MarshalForHandshakes()
if err != nil {
return nil, fmt.Errorf("error marshalling certificate for handshake: %w", err)
}
cs.v2Cert = v2
cs.v2HandshakeBytes = v2hs
if cs.defaultVersion == 0 {
cs.defaultVersion = cert.Version2
}
}
var crt cert.Certificate
crt = cs.getCertificate(cert.Version2)
if crt == nil {
// v2 certificates are a superset, only look at v1 if its all we have
crt = cs.getCertificate(cert.Version1)
}
for _, network := range crt.Networks() {
cs.myVpnNetworks = append(cs.myVpnNetworks, network)
cs.myVpnNetworksTable.Insert(network, struct{}{})
cs.myVpnAddrs = append(cs.myVpnAddrs, network.Addr())
cs.myVpnAddrsTable.Insert(netip.PrefixFrom(network.Addr(), network.Addr().BitLen()), struct{}{})
if network.Addr().Is4() {
addr := network.Masked().Addr().As4()
mask := net.CIDRMask(network.Bits(), network.Addr().BitLen())
binary.BigEndian.PutUint32(addr[:], binary.BigEndian.Uint32(addr[:])|^binary.BigEndian.Uint32(mask))
cs.myVpnBroadcastAddrsTable.Insert(netip.PrefixFrom(netip.AddrFrom4(addr), network.Addr().BitLen()), struct{}{})
}
}
return &cs, nil
}
func loadPrivateKey(privPathOrPEM string) (rawKey []byte, curve cert.Curve, isPkcs11 bool, err error) {
var pemPrivateKey []byte
if strings.Contains(privPathOrPEM, "-----BEGIN") {
pemPrivateKey = []byte(privPathOrPEM)
privPathOrPEM = "<inline>"
rawKey, _, curve, err = cert.UnmarshalPrivateKeyFromPEM(pemPrivateKey)
if err != nil {
return nil, curve, false, fmt.Errorf("error while unmarshaling pki.key %s: %s", privPathOrPEM, err)
}
} else if strings.HasPrefix(privPathOrPEM, "pkcs11:") {
rawKey = []byte(privPathOrPEM)
return rawKey, cert.Curve_P256, true, nil
} else {
pemPrivateKey, err = os.ReadFile(privPathOrPEM)
if err != nil {
return nil, curve, false, fmt.Errorf("unable to read pki.key file %s: %s", privPathOrPEM, err)
}
rawKey, _, curve, err = cert.UnmarshalPrivateKeyFromPEM(pemPrivateKey)
if err != nil {
return nil, curve, false, fmt.Errorf("error while unmarshaling pki.key %s: %s", privPathOrPEM, err)
}
}
return
}
func loadCertificate(b []byte) (cert.Certificate, []byte, error) {
c, b, err := cert.UnmarshalCertificateFromPEM(b)
if err != nil {
return nil, fmt.Errorf("error while unmarshaling pki.cert %s: %s", pubPathOrPEM, err)
return nil, b, fmt.Errorf("error while unmarshaling pki.cert: %w", err)
}
if nebulaCert.Expired(time.Now()) {
return nil, fmt.Errorf("nebula certificate for this host is expired")
if c.Expired(time.Now()) {
return nil, b, fmt.Errorf("nebula certificate for this host is expired")
}
if len(nebulaCert.Networks()) == 0 {
return nil, fmt.Errorf("no networks encoded in certificate")
if len(c.Networks()) == 0 {
return nil, b, fmt.Errorf("no networks encoded in certificate")
}
if err = nebulaCert.VerifyPrivateKey(curve, rawKey); err != nil {
return nil, fmt.Errorf("private key is not a pair with public key in nebula cert")
if c.IsCA() {
return nil, b, fmt.Errorf("host certificate is a CA certificate")
}
return newCertState(nebulaCert, isPkcs11, rawKey)
return c, b, nil
}
func loadCAPoolFromConfig(l *logrus.Logger, c *config.C) (*cert.CAPool, error) {

View File

@ -9,6 +9,7 @@ import (
"sync/atomic"
"github.com/sirupsen/logrus"
"github.com/slackhq/nebula/cert"
"github.com/slackhq/nebula/config"
"github.com/slackhq/nebula/header"
)
@ -72,7 +73,7 @@ func AddRelay(l *logrus.Logger, relayHostInfo *HostInfo, hm *HostMap, vpnIp neti
Type: relayType,
State: state,
LocalIndex: index,
PeerIp: vpnIp,
PeerAddr: vpnIp,
}
if remoteIdx != nil {
@ -91,40 +92,60 @@ func AddRelay(l *logrus.Logger, relayHostInfo *HostInfo, hm *HostMap, vpnIp neti
func (rm *relayManager) EstablishRelay(relayHostInfo *HostInfo, m *NebulaControl) (*Relay, error) {
relay, ok := relayHostInfo.relayState.CompleteRelayByIdx(m.InitiatorRelayIndex, m.ResponderRelayIndex)
if !ok {
rm.l.WithFields(logrus.Fields{"relay": relayHostInfo.vpnIp,
//TODO: we need to handle possibly logging deprecated fields as well
rm.l.WithFields(logrus.Fields{"relay": relayHostInfo.vpnAddrs[0],
"initiatorRelayIndex": m.InitiatorRelayIndex,
"relayFrom": m.RelayFromIp,
"relayTo": m.RelayToIp}).Info("relayManager failed to update relay")
"relayFrom": m.RelayFromAddr,
"relayTo": m.RelayToAddr}).Info("relayManager failed to update relay")
return nil, fmt.Errorf("unknown relay")
}
return relay, nil
}
func (rm *relayManager) HandleControlMsg(h *HostInfo, m *NebulaControl, f *Interface) {
func (rm *relayManager) HandleControlMsg(h *HostInfo, d []byte, f *Interface) {
msg := &NebulaControl{}
err := msg.Unmarshal(d)
if err != nil {
h.logger(f.l).WithError(err).Error("Failed to unmarshal control message")
return
}
switch m.Type {
var v cert.Version
if msg.OldRelayFromAddr > 0 || msg.OldRelayToAddr > 0 {
v = cert.Version1
//TODO: yeah this is junk but maybe its less junky than the other options
b := [4]byte{}
binary.BigEndian.PutUint32(b[:], msg.OldRelayFromAddr)
msg.RelayFromAddr = netAddrToProtoAddr(netip.AddrFrom4(b))
binary.BigEndian.PutUint32(b[:], msg.OldRelayToAddr)
msg.RelayToAddr = netAddrToProtoAddr(netip.AddrFrom4(b))
} else {
v = cert.Version2
}
switch msg.Type {
case NebulaControl_CreateRelayRequest:
rm.handleCreateRelayRequest(h, f, m)
rm.handleCreateRelayRequest(v, h, f, msg)
case NebulaControl_CreateRelayResponse:
rm.handleCreateRelayResponse(h, f, m)
rm.handleCreateRelayResponse(v, h, f, msg)
}
}
func (rm *relayManager) handleCreateRelayResponse(h *HostInfo, f *Interface, m *NebulaControl) {
func (rm *relayManager) handleCreateRelayResponse(v cert.Version, h *HostInfo, f *Interface, m *NebulaControl) {
rm.l.WithFields(logrus.Fields{
"relayFrom": m.RelayFromIp,
"relayTo": m.RelayToIp,
"relayFrom": m.RelayFromAddr,
"relayTo": m.RelayToAddr,
"initiatorRelayIndex": m.InitiatorRelayIndex,
"responderRelayIndex": m.ResponderRelayIndex,
"vpnIp": h.vpnIp}).
"vpnAddrs": h.vpnAddrs}).
Info("handleCreateRelayResponse")
target := m.RelayToIp
//TODO: IPV6-WORK
b := [4]byte{}
binary.BigEndian.PutUint32(b[:], m.RelayToIp)
targetAddr := netip.AddrFrom4(b)
target := m.RelayToAddr
targetAddr := protoAddrToNetAddr(target)
relay, err := rm.EstablishRelay(h, m)
if err != nil {
@ -136,68 +157,79 @@ func (rm *relayManager) handleCreateRelayResponse(h *HostInfo, f *Interface, m *
return
}
// I'm the middle man. Let the initiator know that the I've established the relay they requested.
peerHostInfo := rm.hostmap.QueryVpnIp(relay.PeerIp)
peerHostInfo := rm.hostmap.QueryVpnAddr(relay.PeerAddr)
if peerHostInfo == nil {
rm.l.WithField("relayTo", relay.PeerIp).Error("Can't find a HostInfo for peer")
rm.l.WithField("relayTo", relay.PeerAddr).Error("Can't find a HostInfo for peer")
return
}
peerRelay, ok := peerHostInfo.relayState.QueryRelayForByIp(targetAddr)
if !ok {
rm.l.WithField("relayTo", peerHostInfo.vpnIp).Error("peerRelay does not have Relay state for relayTo")
rm.l.WithField("relayTo", peerHostInfo.vpnAddrs[0]).Error("peerRelay does not have Relay state for relayTo")
return
}
if peerRelay.State == PeerRequested {
//TODO: IPV6-WORK
b = peerHostInfo.vpnIp.As4()
peerRelay.State = Established
resp := NebulaControl{
Type: NebulaControl_CreateRelayResponse,
ResponderRelayIndex: peerRelay.LocalIndex,
InitiatorRelayIndex: peerRelay.RemoteIndex,
RelayFromIp: binary.BigEndian.Uint32(b[:]),
RelayToIp: uint32(target),
}
if v == cert.Version1 {
peer := peerHostInfo.vpnAddrs[0]
if !peer.Is4() {
//TODO: log cant do it
return
}
b := peer.As4()
resp.OldRelayFromAddr = binary.BigEndian.Uint32(b[:])
b = targetAddr.As4()
resp.OldRelayToAddr = binary.BigEndian.Uint32(b[:])
} else {
resp.RelayFromAddr = netAddrToProtoAddr(peerHostInfo.vpnAddrs[0])
resp.RelayToAddr = target
}
msg, err := resp.Marshal()
if err != nil {
rm.l.
WithError(err).Error("relayManager Failed to marshal Control CreateRelayResponse message to create relay")
rm.l.WithError(err).
Error("relayManager Failed to marshal Control CreateRelayResponse message to create relay")
} else {
f.SendMessageToHostInfo(header.Control, 0, peerHostInfo, msg, make([]byte, 12), make([]byte, mtu))
rm.l.WithFields(logrus.Fields{
"relayFrom": resp.RelayFromIp,
"relayTo": resp.RelayToIp,
"relayFrom": resp.RelayFromAddr,
"relayTo": resp.RelayToAddr,
"initiatorRelayIndex": resp.InitiatorRelayIndex,
"responderRelayIndex": resp.ResponderRelayIndex,
"vpnIp": peerHostInfo.vpnIp}).
"vpnAddrs": peerHostInfo.vpnAddrs}).
Info("send CreateRelayResponse")
}
}
}
func (rm *relayManager) handleCreateRelayRequest(h *HostInfo, f *Interface, m *NebulaControl) {
//TODO: IPV6-WORK
b := [4]byte{}
binary.BigEndian.PutUint32(b[:], m.RelayFromIp)
from := netip.AddrFrom4(b)
binary.BigEndian.PutUint32(b[:], m.RelayToIp)
target := netip.AddrFrom4(b)
func (rm *relayManager) handleCreateRelayRequest(v cert.Version, h *HostInfo, f *Interface, m *NebulaControl) {
from := protoAddrToNetAddr(m.RelayFromAddr)
target := protoAddrToNetAddr(m.RelayToAddr)
logMsg := rm.l.WithFields(logrus.Fields{
"relayFrom": from,
"relayTo": target,
"initiatorRelayIndex": m.InitiatorRelayIndex,
"vpnIp": h.vpnIp})
"vpnAddrs": h.vpnAddrs})
logMsg.Info("handleCreateRelayRequest")
// Is the source of the relay me? This should never happen, but did happen due to
// an issue migrating relays over to newly re-handshaked host info objects.
if from == f.myVpnNet.Addr() {
_, found := f.myVpnAddrsTable.Lookup(from)
if found {
logMsg.WithField("myIP", from).Error("Discarding relay request from myself")
return
}
// Is the target of the relay me?
if target == f.myVpnNet.Addr() {
_, found = f.myVpnAddrsTable.Lookup(target)
if found {
existingRelay, ok := h.relayState.QueryRelayForByIp(from)
if ok {
switch existingRelay.State {
@ -230,17 +262,22 @@ func (rm *relayManager) handleCreateRelayRequest(h *HostInfo, f *Interface, m *N
return
}
//TODO: IPV6-WORK
fromB := from.As4()
targetB := target.As4()
resp := NebulaControl{
Type: NebulaControl_CreateRelayResponse,
ResponderRelayIndex: relay.LocalIndex,
InitiatorRelayIndex: relay.RemoteIndex,
RelayFromIp: binary.BigEndian.Uint32(fromB[:]),
RelayToIp: binary.BigEndian.Uint32(targetB[:]),
}
if v == cert.Version1 {
b := from.As4()
resp.OldRelayFromAddr = binary.BigEndian.Uint32(b[:])
b = target.As4()
resp.OldRelayToAddr = binary.BigEndian.Uint32(b[:])
} else {
resp.RelayFromAddr = netAddrToProtoAddr(from)
resp.RelayToAddr = netAddrToProtoAddr(target)
}
msg, err := resp.Marshal()
if err != nil {
logMsg.
@ -253,7 +290,7 @@ func (rm *relayManager) handleCreateRelayRequest(h *HostInfo, f *Interface, m *N
"relayTo": target,
"initiatorRelayIndex": resp.InitiatorRelayIndex,
"responderRelayIndex": resp.ResponderRelayIndex,
"vpnIp": h.vpnIp}).
"vpnAddrs": h.vpnAddrs}).
Info("send CreateRelayResponse")
}
return
@ -262,7 +299,7 @@ func (rm *relayManager) handleCreateRelayRequest(h *HostInfo, f *Interface, m *N
if !rm.GetAmRelay() {
return
}
peer := rm.hostmap.QueryVpnIp(target)
peer := rm.hostmap.QueryVpnAddr(target)
if peer == nil {
// Try to establish a connection to this host. If we get a future relay request,
// we'll be ready!
@ -291,17 +328,27 @@ func (rm *relayManager) handleCreateRelayRequest(h *HostInfo, f *Interface, m *N
sendCreateRequest = true
}
if sendCreateRequest {
//TODO: IPV6-WORK
fromB := h.vpnIp.As4()
targetB := target.As4()
// Send a CreateRelayRequest to the peer.
req := NebulaControl{
Type: NebulaControl_CreateRelayRequest,
InitiatorRelayIndex: index,
RelayFromIp: binary.BigEndian.Uint32(fromB[:]),
RelayToIp: binary.BigEndian.Uint32(targetB[:]),
}
if v == cert.Version1 {
if !h.vpnAddrs[0].Is4() {
//TODO: log it
return
}
b := h.vpnAddrs[0].As4()
req.OldRelayFromAddr = binary.BigEndian.Uint32(b[:])
b = target.As4()
req.OldRelayToAddr = binary.BigEndian.Uint32(b[:])
} else {
req.RelayFromAddr = netAddrToProtoAddr(h.vpnAddrs[0])
req.RelayToAddr = netAddrToProtoAddr(target)
}
msg, err := req.Marshal()
if err != nil {
logMsg.
@ -310,11 +357,11 @@ func (rm *relayManager) handleCreateRelayRequest(h *HostInfo, f *Interface, m *N
f.SendMessageToHostInfo(header.Control, 0, peer, msg, make([]byte, 12), make([]byte, mtu))
rm.l.WithFields(logrus.Fields{
//TODO: IPV6-WORK another lazy used to use the req object
"relayFrom": h.vpnIp,
"relayFrom": h.vpnAddrs[0],
"relayTo": target,
"initiatorRelayIndex": req.InitiatorRelayIndex,
"responderRelayIndex": req.ResponderRelayIndex,
"vpnIp": target}).
"vpnAddr": target}).
Info("send CreateRelayRequest")
}
}
@ -342,16 +389,28 @@ func (rm *relayManager) handleCreateRelayRequest(h *HostInfo, f *Interface, m *N
"existingRemoteIndex": relay.RemoteIndex}).Error("Existing relay mismatch with CreateRelayRequest")
return
}
//TODO: IPV6-WORK
fromB := h.vpnIp.As4()
targetB := target.As4()
resp := NebulaControl{
Type: NebulaControl_CreateRelayResponse,
ResponderRelayIndex: relay.LocalIndex,
InitiatorRelayIndex: relay.RemoteIndex,
RelayFromIp: binary.BigEndian.Uint32(fromB[:]),
RelayToIp: binary.BigEndian.Uint32(targetB[:]),
}
if v == cert.Version1 {
if !h.vpnAddrs[0].Is4() {
//TODO: log it
return
}
b := h.vpnAddrs[0].As4()
resp.OldRelayFromAddr = binary.BigEndian.Uint32(b[:])
b = target.As4()
resp.OldRelayToAddr = binary.BigEndian.Uint32(b[:])
} else {
resp.RelayFromAddr = netAddrToProtoAddr(h.vpnAddrs[0])
resp.RelayToAddr = netAddrToProtoAddr(target)
}
msg, err := resp.Marshal()
if err != nil {
rm.l.
@ -360,11 +419,11 @@ func (rm *relayManager) handleCreateRelayRequest(h *HostInfo, f *Interface, m *N
f.SendMessageToHostInfo(header.Control, 0, h, msg, make([]byte, 12), make([]byte, mtu))
rm.l.WithFields(logrus.Fields{
//TODO: IPV6-WORK more lazy, used to use resp object
"relayFrom": h.vpnIp,
"relayFrom": h.vpnAddrs[0],
"relayTo": target,
"initiatorRelayIndex": resp.InitiatorRelayIndex,
"responderRelayIndex": resp.ResponderRelayIndex,
"vpnIp": h.vpnIp}).
"vpnAddrs": h.vpnAddrs}).
Info("send CreateRelayResponse")
}

View File

@ -17,8 +17,8 @@ import (
type forEachFunc func(addr netip.AddrPort, preferred bool)
// The checkFuncs here are to simplify bulk importing LH query response logic into a single function (reset slice and iterate)
type checkFuncV4 func(vpnIp netip.Addr, to *Ip4AndPort) bool
type checkFuncV6 func(vpnIp netip.Addr, to *Ip6AndPort) bool
type checkFuncV4 func(vpnIp netip.Addr, to *V4AddrPort) bool
type checkFuncV6 func(vpnIp netip.Addr, to *V6AddrPort) bool
// CacheMap is a struct that better represents the lighthouse cache for humans
// The string key is the owners vpnIp
@ -48,14 +48,14 @@ type cacheRelay struct {
// cacheV4 stores learned and reported ipv4 records under cache
type cacheV4 struct {
learned *Ip4AndPort
reported []*Ip4AndPort
learned *V4AddrPort
reported []*V4AddrPort
}
// cacheV4 stores learned and reported ipv6 records under cache
type cacheV6 struct {
learned *Ip6AndPort
reported []*Ip6AndPort
learned *V6AddrPort
reported []*V6AddrPort
}
type hostnamePort struct {
@ -170,7 +170,7 @@ func (hr *hostnamesResults) Cancel() {
}
}
func (hr *hostnamesResults) GetIPs() []netip.AddrPort {
func (hr *hostnamesResults) GetAddrs() []netip.AddrPort {
var retSlice []netip.AddrPort
if hr != nil {
p := hr.ips.Load()
@ -189,6 +189,9 @@ type RemoteList struct {
// Every interaction with internals requires a lock!
sync.RWMutex
// The full list of vpn addresses assigned to this host
vpnAddrs []netip.Addr
// A deduplicated set of addresses. Any accessor should lock beforehand.
addrs []netip.AddrPort
@ -212,13 +215,16 @@ type RemoteList struct {
}
// NewRemoteList creates a new empty RemoteList
func NewRemoteList(shouldAdd func(netip.Addr) bool) *RemoteList {
return &RemoteList{
func NewRemoteList(vpnAddrs []netip.Addr, shouldAdd func(netip.Addr) bool) *RemoteList {
r := &RemoteList{
vpnAddrs: make([]netip.Addr, len(vpnAddrs)),
addrs: make([]netip.AddrPort, 0),
relays: make([]netip.Addr, 0),
cache: make(map[netip.Addr]*cache),
shouldAdd: shouldAdd,
}
copy(r.vpnAddrs, vpnAddrs)
return r
}
func (r *RemoteList) unlockedSetHostnamesResults(hr *hostnamesResults) {
@ -273,9 +279,9 @@ func (r *RemoteList) LearnRemote(ownerVpnIp netip.Addr, remote netip.AddrPort) {
r.Lock()
defer r.Unlock()
if remote.Addr().Is4() {
r.unlockedSetLearnedV4(ownerVpnIp, NewIp4AndPortFromNetIP(remote.Addr(), remote.Port()))
r.unlockedSetLearnedV4(ownerVpnIp, netAddrToProtoV4AddrPort(remote.Addr(), remote.Port()))
} else {
r.unlockedSetLearnedV6(ownerVpnIp, NewIp6AndPortFromNetIP(remote.Addr(), remote.Port()))
r.unlockedSetLearnedV6(ownerVpnIp, netAddrToProtoV6AddrPort(remote.Addr(), remote.Port()))
}
}
@ -304,21 +310,21 @@ func (r *RemoteList) CopyCache() *CacheMap {
if mc.v4 != nil {
if mc.v4.learned != nil {
c.Learned = append(c.Learned, AddrPortFromIp4AndPort(mc.v4.learned))
c.Learned = append(c.Learned, protoV4AddrPortToNetAddrPort(mc.v4.learned))
}
for _, a := range mc.v4.reported {
c.Reported = append(c.Reported, AddrPortFromIp4AndPort(a))
c.Reported = append(c.Reported, protoV4AddrPortToNetAddrPort(a))
}
}
if mc.v6 != nil {
if mc.v6.learned != nil {
c.Learned = append(c.Learned, AddrPortFromIp6AndPort(mc.v6.learned))
c.Learned = append(c.Learned, protoV6AddrPortToNetAddrPort(mc.v6.learned))
}
for _, a := range mc.v6.reported {
c.Reported = append(c.Reported, AddrPortFromIp6AndPort(a))
c.Reported = append(c.Reported, protoV6AddrPortToNetAddrPort(a))
}
}
@ -401,14 +407,14 @@ func (r *RemoteList) unlockedIsBad(remote netip.AddrPort) bool {
// unlockedSetLearnedV4 assumes you have the write lock and sets the current learned address for this owner and marks the
// deduplicated address list as dirty
func (r *RemoteList) unlockedSetLearnedV4(ownerVpnIp netip.Addr, to *Ip4AndPort) {
func (r *RemoteList) unlockedSetLearnedV4(ownerVpnIp netip.Addr, to *V4AddrPort) {
r.shouldRebuild = true
r.unlockedGetOrMakeV4(ownerVpnIp).learned = to
}
// unlockedSetV4 assumes you have the write lock and resets the reported list of ips for this owner to the list provided
// and marks the deduplicated address list as dirty
func (r *RemoteList) unlockedSetV4(ownerVpnIp, vpnIp netip.Addr, to []*Ip4AndPort, check checkFuncV4) {
func (r *RemoteList) unlockedSetV4(ownerVpnIp, vpnIp netip.Addr, to []*V4AddrPort, check checkFuncV4) {
r.shouldRebuild = true
c := r.unlockedGetOrMakeV4(ownerVpnIp)
@ -436,12 +442,12 @@ func (r *RemoteList) unlockedSetRelay(ownerVpnIp, vpnIp netip.Addr, to []netip.A
// unlockedPrependV4 assumes you have the write lock and prepends the address in the reported list for this owner
// This is only useful for establishing static hosts
func (r *RemoteList) unlockedPrependV4(ownerVpnIp netip.Addr, to *Ip4AndPort) {
func (r *RemoteList) unlockedPrependV4(ownerVpnIp netip.Addr, to *V4AddrPort) {
r.shouldRebuild = true
c := r.unlockedGetOrMakeV4(ownerVpnIp)
// We are doing the easy append because this is rarely called
c.reported = append([]*Ip4AndPort{to}, c.reported...)
c.reported = append([]*V4AddrPort{to}, c.reported...)
if len(c.reported) > MaxRemotes {
c.reported = c.reported[:MaxRemotes]
}
@ -449,14 +455,14 @@ func (r *RemoteList) unlockedPrependV4(ownerVpnIp netip.Addr, to *Ip4AndPort) {
// unlockedSetLearnedV6 assumes you have the write lock and sets the current learned address for this owner and marks the
// deduplicated address list as dirty
func (r *RemoteList) unlockedSetLearnedV6(ownerVpnIp netip.Addr, to *Ip6AndPort) {
func (r *RemoteList) unlockedSetLearnedV6(ownerVpnIp netip.Addr, to *V6AddrPort) {
r.shouldRebuild = true
r.unlockedGetOrMakeV6(ownerVpnIp).learned = to
}
// unlockedSetV6 assumes you have the write lock and resets the reported list of ips for this owner to the list provided
// and marks the deduplicated address list as dirty
func (r *RemoteList) unlockedSetV6(ownerVpnIp, vpnIp netip.Addr, to []*Ip6AndPort, check checkFuncV6) {
func (r *RemoteList) unlockedSetV6(ownerVpnIp, vpnIp netip.Addr, to []*V6AddrPort, check checkFuncV6) {
r.shouldRebuild = true
c := r.unlockedGetOrMakeV6(ownerVpnIp)
@ -473,12 +479,12 @@ func (r *RemoteList) unlockedSetV6(ownerVpnIp, vpnIp netip.Addr, to []*Ip6AndPor
// unlockedPrependV6 assumes you have the write lock and prepends the address in the reported list for this owner
// This is only useful for establishing static hosts
func (r *RemoteList) unlockedPrependV6(ownerVpnIp netip.Addr, to *Ip6AndPort) {
func (r *RemoteList) unlockedPrependV6(ownerVpnIp netip.Addr, to *V6AddrPort) {
r.shouldRebuild = true
c := r.unlockedGetOrMakeV6(ownerVpnIp)
// We are doing the easy append because this is rarely called
c.reported = append([]*Ip6AndPort{to}, c.reported...)
c.reported = append([]*V6AddrPort{to}, c.reported...)
if len(c.reported) > MaxRemotes {
c.reported = c.reported[:MaxRemotes]
}
@ -536,14 +542,14 @@ func (r *RemoteList) unlockedCollect() {
for _, c := range r.cache {
if c.v4 != nil {
if c.v4.learned != nil {
u := AddrPortFromIp4AndPort(c.v4.learned)
u := protoV4AddrPortToNetAddrPort(c.v4.learned)
if !r.unlockedIsBad(u) {
addrs = append(addrs, u)
}
}
for _, v := range c.v4.reported {
u := AddrPortFromIp4AndPort(v)
u := protoV4AddrPortToNetAddrPort(v)
if !r.unlockedIsBad(u) {
addrs = append(addrs, u)
}
@ -552,14 +558,14 @@ func (r *RemoteList) unlockedCollect() {
if c.v6 != nil {
if c.v6.learned != nil {
u := AddrPortFromIp6AndPort(c.v6.learned)
u := protoV6AddrPortToNetAddrPort(c.v6.learned)
if !r.unlockedIsBad(u) {
addrs = append(addrs, u)
}
}
for _, v := range c.v6.reported {
u := AddrPortFromIp6AndPort(v)
u := protoV6AddrPortToNetAddrPort(v)
if !r.unlockedIsBad(u) {
addrs = append(addrs, u)
}
@ -573,7 +579,7 @@ func (r *RemoteList) unlockedCollect() {
}
}
dnsAddrs := r.hr.GetIPs()
dnsAddrs := r.hr.GetAddrs()
for _, addr := range dnsAddrs {
if r.shouldAdd == nil || r.shouldAdd(addr.Addr()) {
if !r.unlockedIsBad(addr) {

View File

@ -9,11 +9,11 @@ import (
)
func TestRemoteList_Rebuild(t *testing.T) {
rl := NewRemoteList(nil)
rl := NewRemoteList([]netip.Addr{netip.MustParseAddr("0.0.0.0")}, nil)
rl.unlockedSetV4(
netip.MustParseAddr("0.0.0.0"),
netip.MustParseAddr("0.0.0.0"),
[]*Ip4AndPort{
[]*V4AddrPort{
newIp4AndPortFromString("70.199.182.92:1475"), // this is duped
newIp4AndPortFromString("172.17.0.182:10101"),
newIp4AndPortFromString("172.17.1.1:10101"), // this is duped
@ -25,20 +25,20 @@ func TestRemoteList_Rebuild(t *testing.T) {
newIp4AndPortFromString("70.199.182.92:1476"), // almost dupe of 0 with a diff port
newIp4AndPortFromString("70.199.182.92:1475"), // this is a dupe
},
func(netip.Addr, *Ip4AndPort) bool { return true },
func(netip.Addr, *V4AddrPort) bool { return true },
)
rl.unlockedSetV6(
netip.MustParseAddr("0.0.0.1"),
netip.MustParseAddr("0.0.0.1"),
[]*Ip6AndPort{
[]*V6AddrPort{
newIp6AndPortFromString("[1::1]:1"), // this is duped
newIp6AndPortFromString("[1::1]:2"), // almost dupe of 0 with a diff port, also gets duped
newIp6AndPortFromString("[1:100::1]:1"),
newIp6AndPortFromString("[1::1]:1"), // this is a dupe
newIp6AndPortFromString("[1::1]:2"), // this is a dupe
},
func(netip.Addr, *Ip6AndPort) bool { return true },
func(netip.Addr, *V6AddrPort) bool { return true },
)
rl.Rebuild([]netip.Prefix{})
@ -98,11 +98,11 @@ func TestRemoteList_Rebuild(t *testing.T) {
}
func BenchmarkFullRebuild(b *testing.B) {
rl := NewRemoteList(nil)
rl := NewRemoteList([]netip.Addr{netip.MustParseAddr("0.0.0.0")}, nil)
rl.unlockedSetV4(
netip.MustParseAddr("0.0.0.0"),
netip.MustParseAddr("0.0.0.0"),
[]*Ip4AndPort{
[]*V4AddrPort{
newIp4AndPortFromString("70.199.182.92:1475"),
newIp4AndPortFromString("172.17.0.182:10101"),
newIp4AndPortFromString("172.17.1.1:10101"),
@ -112,19 +112,19 @@ func BenchmarkFullRebuild(b *testing.B) {
newIp4AndPortFromString("172.17.1.1:10101"), // this is a dupe
newIp4AndPortFromString("70.199.182.92:1476"), // dupe of 0 with a diff port
},
func(netip.Addr, *Ip4AndPort) bool { return true },
func(netip.Addr, *V4AddrPort) bool { return true },
)
rl.unlockedSetV6(
netip.MustParseAddr("0.0.0.0"),
netip.MustParseAddr("0.0.0.0"),
[]*Ip6AndPort{
[]*V6AddrPort{
newIp6AndPortFromString("[1::1]:1"),
newIp6AndPortFromString("[1::1]:2"), // dupe of 0 with a diff port
newIp6AndPortFromString("[1:100::1]:1"),
newIp6AndPortFromString("[1::1]:1"), // this is a dupe
},
func(netip.Addr, *Ip6AndPort) bool { return true },
func(netip.Addr, *V6AddrPort) bool { return true },
)
b.Run("no preferred", func(b *testing.B) {
@ -160,11 +160,11 @@ func BenchmarkFullRebuild(b *testing.B) {
}
func BenchmarkSortRebuild(b *testing.B) {
rl := NewRemoteList(nil)
rl := NewRemoteList([]netip.Addr{netip.MustParseAddr("0.0.0.0")}, nil)
rl.unlockedSetV4(
netip.MustParseAddr("0.0.0.0"),
netip.MustParseAddr("0.0.0.0"),
[]*Ip4AndPort{
[]*V4AddrPort{
newIp4AndPortFromString("70.199.182.92:1475"),
newIp4AndPortFromString("172.17.0.182:10101"),
newIp4AndPortFromString("172.17.1.1:10101"),
@ -174,19 +174,19 @@ func BenchmarkSortRebuild(b *testing.B) {
newIp4AndPortFromString("172.17.1.1:10101"), // this is a dupe
newIp4AndPortFromString("70.199.182.92:1476"), // dupe of 0 with a diff port
},
func(netip.Addr, *Ip4AndPort) bool { return true },
func(netip.Addr, *V4AddrPort) bool { return true },
)
rl.unlockedSetV6(
netip.MustParseAddr("0.0.0.0"),
netip.MustParseAddr("0.0.0.0"),
[]*Ip6AndPort{
[]*V6AddrPort{
newIp6AndPortFromString("[1::1]:1"),
newIp6AndPortFromString("[1::1]:2"), // dupe of 0 with a diff port
newIp6AndPortFromString("[1:100::1]:1"),
newIp6AndPortFromString("[1::1]:1"), // this is a dupe
},
func(netip.Addr, *Ip6AndPort) bool { return true },
func(netip.Addr, *V6AddrPort) bool { return true },
)
b.Run("no preferred", func(b *testing.B) {
@ -224,19 +224,19 @@ func BenchmarkSortRebuild(b *testing.B) {
})
}
func newIp4AndPortFromString(s string) *Ip4AndPort {
func newIp4AndPortFromString(s string) *V4AddrPort {
a := netip.MustParseAddrPort(s)
v4Addr := a.Addr().As4()
return &Ip4AndPort{
Ip: binary.BigEndian.Uint32(v4Addr[:]),
return &V4AddrPort{
Addr: binary.BigEndian.Uint32(v4Addr[:]),
Port: uint32(a.Port()),
}
}
func newIp6AndPortFromString(s string) *Ip6AndPort {
func newIp6AndPortFromString(s string) *V6AddrPort {
a := netip.MustParseAddrPort(s)
v6Addr := a.Addr().As16()
return &Ip6AndPort{
return &V6AddrPort{
Hi: binary.BigEndian.Uint64(v6Addr[:8]),
Lo: binary.BigEndian.Uint64(v6Addr[8:]),
Port: uint32(a.Port()),

View File

@ -90,9 +90,9 @@ func New(config *config.C) (*Service, error) {
},
})
ipNet := device.Cidr()
ipNet := device.Networks()
pa := tcpip.ProtocolAddress{
AddressWithPrefix: tcpip.AddrFromSlice(ipNet.Addr().AsSlice()).WithPrefix(),
AddressWithPrefix: tcpip.AddrFromSlice(ipNet[0].Addr().AsSlice()).WithPrefix(),
Protocol: ipv4.ProtocolNumber,
}
if err := s.ipstack.AddProtocolAddress(nicID, pa, stack.AddressProperties{

View File

@ -19,7 +19,7 @@ import (
type m map[string]interface{}
func newSimpleService(caCrt cert.Certificate, caKey []byte, name string, udpIp netip.Addr, overrides m) *Service {
_, _, myPrivKey, myPEM := e2e.NewTestCert(caCrt, caKey, "a", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{netip.PrefixFrom(udpIp, 24)}, nil, []string{})
_, _, myPrivKey, myPEM := e2e.NewTestCert(cert.Version2, caCrt, caKey, "a", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{netip.PrefixFrom(udpIp, 24)}, nil, []string{})
caB, err := caCrt.MarshalPEM()
if err != nil {
panic(err)

37
ssh.go
View File

@ -430,7 +430,7 @@ func sshListHostMap(hl controlHostLister, a interface{}, w sshd.StringWriter) er
}
sort.Slice(hm, func(i, j int) bool {
return hm[i].VpnIp.Compare(hm[j].VpnIp) < 0
return hm[i].VpnAddrs[0].Compare(hm[j].VpnAddrs[0]) < 0
})
if fs.Json || fs.Pretty {
@ -447,7 +447,7 @@ func sshListHostMap(hl controlHostLister, a interface{}, w sshd.StringWriter) er
} else {
for _, v := range hm {
err := w.WriteLine(fmt.Sprintf("%s: %s", v.VpnIp, v.RemoteAddrs))
err := w.WriteLine(fmt.Sprintf("%s: %s", v.VpnAddrs, v.RemoteAddrs))
if err != nil {
return err
}
@ -581,7 +581,7 @@ func sshCloseTunnel(ifce *Interface, fs interface{}, a []string, w sshd.StringWr
return w.WriteLine(fmt.Sprintf("The provided vpn ip could not be parsed: %s", a[0]))
}
hostInfo := ifce.hostMap.QueryVpnIp(vpnIp)
hostInfo := ifce.hostMap.QueryVpnAddr(vpnIp)
if hostInfo == nil {
return w.WriteLine(fmt.Sprintf("Could not find tunnel for vpn ip: %v", a[0]))
}
@ -622,12 +622,12 @@ func sshCreateTunnel(ifce *Interface, fs interface{}, a []string, w sshd.StringW
return w.WriteLine(fmt.Sprintf("The provided vpn ip could not be parsed: %s", a[0]))
}
hostInfo := ifce.hostMap.QueryVpnIp(vpnIp)
hostInfo := ifce.hostMap.QueryVpnAddr(vpnIp)
if hostInfo != nil {
return w.WriteLine(fmt.Sprintf("Tunnel already exists"))
}
hostInfo = ifce.handshakeManager.QueryVpnIp(vpnIp)
hostInfo = ifce.handshakeManager.QueryVpnAddr(vpnIp)
if hostInfo != nil {
return w.WriteLine(fmt.Sprintf("Tunnel already handshaking"))
}
@ -677,7 +677,7 @@ func sshChangeRemote(ifce *Interface, fs interface{}, a []string, w sshd.StringW
return w.WriteLine(fmt.Sprintf("The provided vpn ip could not be parsed: %s", a[0]))
}
hostInfo := ifce.hostMap.QueryVpnIp(vpnIp)
hostInfo := ifce.hostMap.QueryVpnAddr(vpnIp)
if hostInfo == nil {
return w.WriteLine(fmt.Sprintf("Could not find tunnel for vpn ip: %v", a[0]))
}
@ -785,7 +785,8 @@ func sshPrintCert(ifce *Interface, fs interface{}, a []string, w sshd.StringWrit
return nil
}
cert := ifce.pki.GetCertState().Certificate
//TODO: This should return both certs
cert := ifce.pki.getDefaultCertificate()
if len(a) > 0 {
vpnIp, err := netip.ParseAddr(a[0])
if err != nil {
@ -796,7 +797,7 @@ func sshPrintCert(ifce *Interface, fs interface{}, a []string, w sshd.StringWrit
return w.WriteLine(fmt.Sprintf("The provided vpn ip could not be parsed: %s", a[0]))
}
hostInfo := ifce.hostMap.QueryVpnIp(vpnIp)
hostInfo := ifce.hostMap.QueryVpnAddr(vpnIp)
if hostInfo == nil {
return w.WriteLine(fmt.Sprintf("Could not find tunnel for vpn ip: %v", a[0]))
}
@ -880,16 +881,16 @@ func sshPrintRelays(ifce *Interface, fs interface{}, a []string, w sshd.StringWr
}
for k, v := range relays {
ro := RelayOutput{NebulaIp: v.vpnIp}
ro := RelayOutput{NebulaIp: v.vpnAddrs[0]}
co.Relays = append(co.Relays, &ro)
relayHI := ifce.hostMap.QueryVpnIp(v.vpnIp)
relayHI := ifce.hostMap.QueryVpnAddr(v.vpnAddrs[0])
if relayHI == nil {
ro.RelayForIps = append(ro.RelayForIps, RelayFor{Error: errors.New("could not find hostinfo")})
continue
}
for _, vpnIp := range relayHI.relayState.CopyRelayForIps() {
rf := RelayFor{Error: nil}
r, ok := relayHI.relayState.GetRelayForByIp(vpnIp)
r, ok := relayHI.relayState.GetRelayForByAddr(vpnIp)
if ok {
t := ""
switch r.Type {
@ -913,14 +914,14 @@ func sshPrintRelays(ifce *Interface, fs interface{}, a []string, w sshd.StringWr
rf.LocalIndex = r.LocalIndex
rf.RemoteIndex = r.RemoteIndex
rf.PeerIp = r.PeerIp
rf.PeerIp = r.PeerAddr
rf.Type = t
rf.State = s
if rf.LocalIndex != k {
rf.Error = fmt.Errorf("hostmap LocalIndex '%v' does not match RelayState LocalIndex", k)
}
}
relayedHI := ifce.hostMap.QueryVpnIp(vpnIp)
relayedHI := ifce.hostMap.QueryVpnAddr(vpnIp)
if relayedHI != nil {
rf.RelayedThrough = append(rf.RelayedThrough, relayedHI.relayState.CopyRelayIps()...)
}
@ -955,7 +956,7 @@ func sshPrintTunnel(ifce *Interface, fs interface{}, a []string, w sshd.StringWr
return w.WriteLine(fmt.Sprintf("The provided vpn ip could not be parsed: %s", a[0]))
}
hostInfo := ifce.hostMap.QueryVpnIp(vpnIp)
hostInfo := ifce.hostMap.QueryVpnAddr(vpnIp)
if hostInfo == nil {
return w.WriteLine(fmt.Sprintf("Could not find tunnel for vpn ip: %v", a[0]))
}
@ -971,13 +972,15 @@ func sshPrintTunnel(ifce *Interface, fs interface{}, a []string, w sshd.StringWr
func sshDeviceInfo(ifce *Interface, fs interface{}, w sshd.StringWriter) error {
data := struct {
Name string `json:"name"`
Cidr string `json:"cidr"`
Name string `json:"name"`
Cidr []netip.Prefix `json:"cidr"`
}{
Name: ifce.inside.Name(),
Cidr: ifce.inside.Cidr().String(),
Cidr: make([]netip.Prefix, len(ifce.inside.Networks())),
}
copy(data.Cidr, ifce.inside.Networks())
flags, ok := fs.(*sshDeviceInfoFlags)
if !ok {
return fmt.Errorf("internal error: expected flags to be sshDeviceInfoFlags but was %+v", fs)

View File

@ -16,8 +16,8 @@ func (NoopTun) Activate() error {
return nil
}
func (NoopTun) Cidr() netip.Prefix {
return netip.Prefix{}
func (NoopTun) Networks() []netip.Prefix {
return []netip.Prefix{}
}
func (NoopTun) Name() string {

View File

@ -116,10 +116,10 @@ func TestTimerWheel_Purge(t *testing.T) {
assert.Equal(t, 0, tw.current)
fps := []firewall.Packet{
{LocalIP: netip.MustParseAddr("0.0.0.1")},
{LocalIP: netip.MustParseAddr("0.0.0.2")},
{LocalIP: netip.MustParseAddr("0.0.0.3")},
{LocalIP: netip.MustParseAddr("0.0.0.4")},
{LocalAddr: netip.MustParseAddr("0.0.0.1")},
{LocalAddr: netip.MustParseAddr("0.0.0.2")},
{LocalAddr: netip.MustParseAddr("0.0.0.3")},
{LocalAddr: netip.MustParseAddr("0.0.0.4")},
}
tw.Add(fps[0], time.Second*1)

View File

@ -4,28 +4,19 @@ import (
"net/netip"
"github.com/slackhq/nebula/config"
"github.com/slackhq/nebula/firewall"
"github.com/slackhq/nebula/header"
)
const MTU = 9001
type EncReader func(
addr netip.AddrPort,
out []byte,
packet []byte,
header *header.H,
fwPacket *firewall.Packet,
lhh LightHouseHandlerFunc,
nb []byte,
q int,
localCache firewall.ConntrackCache,
payload []byte,
)
type Conn interface {
Rebind() error
LocalAddr() (netip.AddrPort, error)
ListenOut(r EncReader, lhf LightHouseHandlerFunc, cache *firewall.ConntrackCacheTicker, q int)
ListenOut(r EncReader)
WriteTo(b []byte, addr netip.AddrPort) error
ReloadConfig(c *config.C)
Close() error
@ -39,7 +30,7 @@ func (NoopConn) Rebind() error {
func (NoopConn) LocalAddr() (netip.AddrPort, error) {
return netip.AddrPort{}, nil
}
func (NoopConn) ListenOut(_ EncReader, _ LightHouseHandlerFunc, _ *firewall.ConntrackCacheTicker, _ int) {
func (NoopConn) ListenOut(_ EncReader) {
return
}
func (NoopConn) WriteTo(_ []byte, _ netip.AddrPort) error {

View File

@ -1,10 +0,0 @@
package udp
import (
"net/netip"
)
//TODO: The items in this file belong in their own packages but doing that in a single PR is a nightmare
// TODO: IPV6-WORK this can likely be removed now
type LightHouseHandlerFunc func(rAddr netip.AddrPort, vpnIp netip.Addr, p []byte)

View File

@ -15,8 +15,6 @@ import (
"github.com/sirupsen/logrus"
"github.com/slackhq/nebula/config"
"github.com/slackhq/nebula/firewall"
"github.com/slackhq/nebula/header"
)
type GenericConn struct {
@ -72,12 +70,8 @@ type rawMessage struct {
Len uint32
}
func (u *GenericConn) ListenOut(r EncReader, lhf LightHouseHandlerFunc, cache *firewall.ConntrackCacheTicker, q int) {
plaintext := make([]byte, MTU)
func (u *GenericConn) ListenOut(r EncReader) {
buffer := make([]byte, MTU)
h := &header.H{}
fwPacket := &firewall.Packet{}
nb := make([]byte, 12, 12)
for {
// Just read one packet at a time
@ -87,16 +81,6 @@ func (u *GenericConn) ListenOut(r EncReader, lhf LightHouseHandlerFunc, cache *f
return
}
r(
netip.AddrPortFrom(rua.Addr().Unmap(), rua.Port()),
plaintext[:0],
buffer[:n],
h,
fwPacket,
lhf,
nb,
q,
cache.Get(u.l),
)
r(netip.AddrPortFrom(rua.Addr().Unmap(), rua.Port()), buffer[:n])
}
}

View File

@ -14,8 +14,6 @@ import (
"github.com/rcrowley/go-metrics"
"github.com/sirupsen/logrus"
"github.com/slackhq/nebula/config"
"github.com/slackhq/nebula/firewall"
"github.com/slackhq/nebula/header"
"golang.org/x/sys/unix"
)
@ -120,15 +118,9 @@ func (u *StdConn) LocalAddr() (netip.AddrPort, error) {
}
}
func (u *StdConn) ListenOut(r EncReader, lhf LightHouseHandlerFunc, cache *firewall.ConntrackCacheTicker, q int) {
plaintext := make([]byte, MTU)
h := &header.H{}
fwPacket := &firewall.Packet{}
func (u *StdConn) ListenOut(r EncReader) {
var ip netip.Addr
nb := make([]byte, 12, 12)
//TODO: should we track this?
//metric := metrics.GetOrRegisterHistogram("test.batch_read", nil, metrics.NewExpDecaySample(1028, 0.015))
msgs, buffers, names := u.PrepareRawMessages(u.batch)
read := u.ReadMulti
if u.batch == 1 {
@ -142,26 +134,14 @@ func (u *StdConn) ListenOut(r EncReader, lhf LightHouseHandlerFunc, cache *firew
return
}
//metric.Update(int64(n))
for i := 0; i < n; i++ {
// Its ok to skip the ok check here, the slicing is the only error that can occur and it will panic
if u.isV4 {
ip, _ = netip.AddrFromSlice(names[i][4:8])
//TODO: IPV6-WORK what is not ok?
} else {
ip, _ = netip.AddrFromSlice(names[i][8:24])
//TODO: IPV6-WORK what is not ok?
}
r(
netip.AddrPortFrom(ip.Unmap(), binary.BigEndian.Uint16(names[i][2:4])),
plaintext[:0],
buffers[i][:msgs[i].Len],
h,
fwPacket,
lhf,
nb,
q,
cache.Get(u.l),
)
r(netip.AddrPortFrom(ip.Unmap(), binary.BigEndian.Uint16(names[i][2:4])), buffers[i][:msgs[i].Len])
}
}
}

View File

@ -18,9 +18,6 @@ import (
"github.com/sirupsen/logrus"
"github.com/slackhq/nebula/config"
"github.com/slackhq/nebula/firewall"
"github.com/slackhq/nebula/header"
"golang.org/x/sys/windows"
"golang.zx2c4.com/wireguard/conn/winrio"
)
@ -118,12 +115,8 @@ func (u *RIOConn) bind(sa windows.Sockaddr) error {
return nil
}
func (u *RIOConn) ListenOut(r EncReader, lhf LightHouseHandlerFunc, cache *firewall.ConntrackCacheTicker, q int) {
plaintext := make([]byte, MTU)
func (u *RIOConn) ListenOut(r EncReader) {
buffer := make([]byte, MTU)
h := &header.H{}
fwPacket := &firewall.Packet{}
nb := make([]byte, 12, 12)
for {
// Just read one packet at a time
@ -133,17 +126,7 @@ func (u *RIOConn) ListenOut(r EncReader, lhf LightHouseHandlerFunc, cache *firew
return
}
r(
netip.AddrPortFrom(netip.AddrFrom16(rua.Addr).Unmap(), (rua.Port>>8)|((rua.Port&0xff)<<8)),
plaintext[:0],
buffer[:n],
h,
fwPacket,
lhf,
nb,
q,
cache.Get(u.l),
)
r(netip.AddrPortFrom(netip.AddrFrom16(rua.Addr).Unmap(), (rua.Port>>8)|((rua.Port&0xff)<<8)), buffer[:n])
}
}

View File

@ -10,7 +10,6 @@ import (
"github.com/sirupsen/logrus"
"github.com/slackhq/nebula/config"
"github.com/slackhq/nebula/firewall"
"github.com/slackhq/nebula/header"
)
@ -107,18 +106,13 @@ func (u *TesterConn) WriteTo(b []byte, addr netip.AddrPort) error {
return nil
}
func (u *TesterConn) ListenOut(r EncReader, lhf LightHouseHandlerFunc, cache *firewall.ConntrackCacheTicker, q int) {
plaintext := make([]byte, MTU)
h := &header.H{}
fwPacket := &firewall.Packet{}
nb := make([]byte, 12, 12)
func (u *TesterConn) ListenOut(r EncReader) {
for {
p, ok := <-u.RxPackets
if !ok {
return
}
r(p.From, plaintext[:0], p.Data, h, fwPacket, lhf, nb, q, cache.Get(u.l))
r(p.From, p.Data)
}
}