Support for ipv6 in the overlay with v2 certificates

---------

Co-authored-by: Jack Doan <jackdoan@rivian.com>
This commit is contained in:
Nate Brown 2024-10-23 22:02:10 -05:00
parent 3e6c75573f
commit f2c32421c4
86 changed files with 5747 additions and 3335 deletions

View File

@ -196,7 +196,7 @@ bench-cpu-long:
go test -bench=. -benchtime=60s -cpuprofile=cpu.pprof go test -bench=. -benchtime=60s -cpuprofile=cpu.pprof
go tool pprof go-audit.test cpu.pprof go tool pprof go-audit.test cpu.pprof
proto: nebula.pb.go cert/cert.pb.go proto: nebula.pb.go cert/cert_v1.pb.go
nebula.pb.go: nebula.proto .FORCE nebula.pb.go: nebula.proto .FORCE
go build github.com/gogo/protobuf/protoc-gen-gogofaster go build github.com/gogo/protobuf/protoc-gen-gogofaster

View File

@ -21,7 +21,11 @@ type calculatedRemote struct {
port uint32 port uint32
} }
func newCalculatedRemote(maskCidr netip.Prefix, port int) (*calculatedRemote, error) { func newCalculatedRemote(cidr, maskCidr netip.Prefix, port int) (*calculatedRemote, error) {
if maskCidr.Addr().BitLen() != cidr.Addr().BitLen() {
return nil, fmt.Errorf("invalid mask: %s for cidr: %s", maskCidr, cidr)
}
masked := maskCidr.Masked() masked := maskCidr.Masked()
if port < 0 || port > math.MaxUint16 { if port < 0 || port > math.MaxUint16 {
return nil, fmt.Errorf("invalid port: %d", port) return nil, fmt.Errorf("invalid port: %d", port)
@ -38,32 +42,38 @@ func (c *calculatedRemote) String() string {
return fmt.Sprintf("CalculatedRemote(mask=%v port=%d)", c.ipNet, c.port) return fmt.Sprintf("CalculatedRemote(mask=%v port=%d)", c.ipNet, c.port)
} }
func (c *calculatedRemote) Apply(ip netip.Addr) *Ip4AndPort { func (c *calculatedRemote) ApplyV4(addr netip.Addr) *V4AddrPort {
// Combine the masked bytes of the "mask" IP with the unmasked bytes // Combine the masked bytes of the "mask" IP with the unmasked bytes of the overlay IP
// of the overlay IP
if c.ipNet.Addr().Is4() {
return c.apply4(ip)
}
return c.apply6(ip)
}
func (c *calculatedRemote) apply4(ip netip.Addr) *Ip4AndPort {
//TODO: IPV6-WORK this can be less crappy
maskb := net.CIDRMask(c.mask.Bits(), c.mask.Addr().BitLen()) maskb := net.CIDRMask(c.mask.Bits(), c.mask.Addr().BitLen())
mask := binary.BigEndian.Uint32(maskb[:]) mask := binary.BigEndian.Uint32(maskb[:])
b := c.mask.Addr().As4() b := c.mask.Addr().As4()
maskIp := binary.BigEndian.Uint32(b[:]) maskAddr := binary.BigEndian.Uint32(b[:])
b = ip.As4() b = addr.As4()
intIp := binary.BigEndian.Uint32(b[:]) intAddr := binary.BigEndian.Uint32(b[:])
return &Ip4AndPort{(maskIp & mask) | (intIp & ^mask), c.port} return &V4AddrPort{(maskAddr & mask) | (intAddr & ^mask), c.port}
} }
func (c *calculatedRemote) apply6(ip netip.Addr) *Ip4AndPort { func (c *calculatedRemote) ApplyV6(addr netip.Addr) *V6AddrPort {
//TODO: IPV6-WORK mask := net.CIDRMask(c.mask.Bits(), c.mask.Addr().BitLen())
panic("Can not calculate ipv6 remote addresses") maskAddr := c.mask.Addr().As16()
calcAddr := addr.As16()
ap := V6AddrPort{Port: c.port}
maskb := binary.BigEndian.Uint64(mask[:8])
maskAddrb := binary.BigEndian.Uint64(maskAddr[:8])
calcAddrb := binary.BigEndian.Uint64(calcAddr[:8])
ap.Hi = (maskAddrb & maskb) | (calcAddrb & ^maskb)
maskb = binary.BigEndian.Uint64(mask[8:])
maskAddrb = binary.BigEndian.Uint64(maskAddr[8:])
calcAddrb = binary.BigEndian.Uint64(calcAddr[8:])
ap.Lo = (maskAddrb & maskb) | (calcAddrb & ^maskb)
return &ap
} }
func NewCalculatedRemotesFromConfig(c *config.C, k string) (*bart.Table[[]*calculatedRemote], error) { func NewCalculatedRemotesFromConfig(c *config.C, k string) (*bart.Table[[]*calculatedRemote], error) {
@ -89,8 +99,7 @@ func NewCalculatedRemotesFromConfig(c *config.C, k string) (*bart.Table[[]*calcu
return nil, fmt.Errorf("config `%s` has invalid CIDR: %s", k, rawCIDR) return nil, fmt.Errorf("config `%s` has invalid CIDR: %s", k, rawCIDR)
} }
//TODO: IPV6-WORK this does not verify that rawValue contains the same bits as cidr here entry, err := newCalculatedRemotesListFromConfig(cidr, rawValue)
entry, err := newCalculatedRemotesListFromConfig(rawValue)
if err != nil { if err != nil {
return nil, fmt.Errorf("config '%s.%s': %w", k, rawCIDR, err) return nil, fmt.Errorf("config '%s.%s': %w", k, rawCIDR, err)
} }
@ -101,7 +110,7 @@ func NewCalculatedRemotesFromConfig(c *config.C, k string) (*bart.Table[[]*calcu
return calculatedRemotes, nil return calculatedRemotes, nil
} }
func newCalculatedRemotesListFromConfig(raw any) ([]*calculatedRemote, error) { func newCalculatedRemotesListFromConfig(cidr netip.Prefix, raw any) ([]*calculatedRemote, error) {
rawList, ok := raw.([]any) rawList, ok := raw.([]any)
if !ok { if !ok {
return nil, fmt.Errorf("calculated_remotes entry has invalid type: %T", raw) return nil, fmt.Errorf("calculated_remotes entry has invalid type: %T", raw)
@ -109,7 +118,7 @@ func newCalculatedRemotesListFromConfig(raw any) ([]*calculatedRemote, error) {
var l []*calculatedRemote var l []*calculatedRemote
for _, e := range rawList { for _, e := range rawList {
c, err := newCalculatedRemotesEntryFromConfig(e) c, err := newCalculatedRemotesEntryFromConfig(cidr, e)
if err != nil { if err != nil {
return nil, fmt.Errorf("calculated_remotes entry: %w", err) return nil, fmt.Errorf("calculated_remotes entry: %w", err)
} }
@ -119,7 +128,7 @@ func newCalculatedRemotesListFromConfig(raw any) ([]*calculatedRemote, error) {
return l, nil return l, nil
} }
func newCalculatedRemotesEntryFromConfig(raw any) (*calculatedRemote, error) { func newCalculatedRemotesEntryFromConfig(cidr netip.Prefix, raw any) (*calculatedRemote, error) {
rawMap, ok := raw.(map[any]any) rawMap, ok := raw.(map[any]any)
if !ok { if !ok {
return nil, fmt.Errorf("invalid type: %T", raw) return nil, fmt.Errorf("invalid type: %T", raw)
@ -155,5 +164,5 @@ func newCalculatedRemotesEntryFromConfig(raw any) (*calculatedRemote, error) {
return nil, fmt.Errorf("invalid port (type %T): %v", rawValue, rawValue) return nil, fmt.Errorf("invalid port (type %T): %v", rawValue, rawValue)
} }
return newCalculatedRemote(maskCidr, port) return newCalculatedRemote(cidr, maskCidr, port)
} }

View File

@ -9,10 +9,9 @@ import (
) )
func TestCalculatedRemoteApply(t *testing.T) { func TestCalculatedRemoteApply(t *testing.T) {
ipNet, err := netip.ParsePrefix("192.168.1.0/24") // Test v4 addresses
require.NoError(t, err) ipNet := netip.MustParsePrefix("192.168.1.0/24")
c, err := newCalculatedRemote(ipNet, ipNet, 4242)
c, err := newCalculatedRemote(ipNet, 4242)
require.NoError(t, err) require.NoError(t, err)
input, err := netip.ParseAddr("10.0.10.182") input, err := netip.ParseAddr("10.0.10.182")
@ -21,5 +20,62 @@ func TestCalculatedRemoteApply(t *testing.T) {
expected, err := netip.ParseAddr("192.168.1.182") expected, err := netip.ParseAddr("192.168.1.182")
assert.NoError(t, err) assert.NoError(t, err)
assert.Equal(t, NewIp4AndPortFromNetIP(expected, 4242), c.Apply(input)) assert.Equal(t, netAddrToProtoV4AddrPort(expected, 4242), c.ApplyV4(input))
// Test v6 addresses
ipNet = netip.MustParsePrefix("ffff:ffff:ffff:ffff::0/64")
c, err = newCalculatedRemote(ipNet, ipNet, 4242)
require.NoError(t, err)
input, err = netip.ParseAddr("beef:beef:beef:beef:beef:beef:beef:beef")
assert.NoError(t, err)
expected, err = netip.ParseAddr("ffff:ffff:ffff:ffff:beef:beef:beef:beef")
assert.NoError(t, err)
assert.Equal(t, netAddrToProtoV6AddrPort(expected, 4242), c.ApplyV6(input))
// Test v6 addresses part 2
ipNet = netip.MustParsePrefix("ffff:ffff:ffff:ffff:ffff::0/80")
c, err = newCalculatedRemote(ipNet, ipNet, 4242)
require.NoError(t, err)
input, err = netip.ParseAddr("beef:beef:beef:beef:beef:beef:beef:beef")
assert.NoError(t, err)
expected, err = netip.ParseAddr("ffff:ffff:ffff:ffff:ffff:beef:beef:beef")
assert.NoError(t, err)
assert.Equal(t, netAddrToProtoV6AddrPort(expected, 4242), c.ApplyV6(input))
// Test v6 addresses part 2
ipNet = netip.MustParsePrefix("ffff:ffff:ffff::0/48")
c, err = newCalculatedRemote(ipNet, ipNet, 4242)
require.NoError(t, err)
input, err = netip.ParseAddr("beef:beef:beef:beef:beef:beef:beef:beef")
assert.NoError(t, err)
expected, err = netip.ParseAddr("ffff:ffff:ffff:beef:beef:beef:beef:beef")
assert.NoError(t, err)
assert.Equal(t, netAddrToProtoV6AddrPort(expected, 4242), c.ApplyV6(input))
}
func Test_newCalculatedRemote(t *testing.T) {
c, err := newCalculatedRemote(netip.MustParsePrefix("1::1/128"), netip.MustParsePrefix("1.0.0.0/32"), 4242)
require.EqualError(t, err, "invalid mask: 1.0.0.0/32 for cidr: 1::1/128")
require.Nil(t, c)
c, err = newCalculatedRemote(netip.MustParsePrefix("1.0.0.0/32"), netip.MustParsePrefix("1::1/128"), 4242)
require.EqualError(t, err, "invalid mask: 1::1/128 for cidr: 1.0.0.0/32")
require.Nil(t, c)
c, err = newCalculatedRemote(netip.MustParsePrefix("1.0.0.0/32"), netip.MustParsePrefix("1.0.0.0/32"), 4242)
require.NoError(t, err)
require.NotNil(t, c)
c, err = newCalculatedRemote(netip.MustParsePrefix("1::1/128"), netip.MustParsePrefix("1::1/128"), 4242)
require.NoError(t, err)
require.NotNil(t, c)
} }

View File

@ -2,14 +2,25 @@
This is a library for interacting with `nebula` style certificates and authorities. This is a library for interacting with `nebula` style certificates and authorities.
A `protobuf` definition of the certificate format is also included There are now 2 versions of `nebula` certificates:
### Compiling the protobuf definition ## v1
Make sure you have `protoc` installed. This version is deprecated.
A `protobuf` definition of the certificate format is included at `cert_v1.proto`
To compile the definition you will need `protoc` installed.
To compile for `go` with the same version of protobuf specified in go.mod: To compile for `go` with the same version of protobuf specified in go.mod:
```bash ```bash
make make proto
``` ```
## v2
This is the latest version which uses asn.1 DER encoding. It can support ipv4 and ipv6 and tolerate
future certificate changes better than v1.
`cert_v2.asn1` defines the wire format and can be used to compile marshalers.

52
cert/asn1.go Normal file
View File

@ -0,0 +1,52 @@
package cert
import (
"golang.org/x/crypto/cryptobyte"
"golang.org/x/crypto/cryptobyte/asn1"
)
// readOptionalASN1Boolean reads an asn.1 boolean with a specific tag instead of a asn.1 tag wrapping a boolean with a value
// https://github.com/golang/go/issues/64811#issuecomment-1944446920
func readOptionalASN1Boolean(b *cryptobyte.String, out *bool, tag asn1.Tag, defaultValue bool) bool {
var present bool
var child cryptobyte.String
if !b.ReadOptionalASN1(&child, &present, tag) {
return false
}
if !present {
*out = defaultValue
return true
}
// Ensure we have 1 byte
if len(child) == 1 {
*out = child[0] > 0
return true
}
return false
}
// readOptionalASN1Byte reads an asn.1 uint8 with a specific tag instead of a asn.1 tag wrapping a uint8 with a value
// Similar issue as with readOptionalASN1Boolean
func readOptionalASN1Byte(b *cryptobyte.String, out *byte, tag asn1.Tag, defaultValue byte) bool {
var present bool
var child cryptobyte.String
if !b.ReadOptionalASN1(&child, &present, tag) {
return false
}
if !present {
*out = defaultValue
return true
}
// Ensure we have 1 byte
if len(child) == 1 {
*out = child[0]
return true
}
return false
}

View File

@ -63,31 +63,31 @@ IBNWYMep3ysx9zCgknfG5dKtwGTaqF++BWKDYdyl34KX
rootCA := certificateV1{ rootCA := certificateV1{
details: detailsV1{ details: detailsV1{
Name: "nebula root ca", name: "nebula root ca",
}, },
} }
rootCA01 := certificateV1{ rootCA01 := certificateV1{
details: detailsV1{ details: detailsV1{
Name: "nebula root ca 01", name: "nebula root ca 01",
}, },
} }
rootCAP256 := certificateV1{ rootCAP256 := certificateV1{
details: detailsV1{ details: detailsV1{
Name: "nebula P256 test", name: "nebula P256 test",
}, },
} }
p, err := NewCAPoolFromPEM([]byte(noNewLines)) p, err := NewCAPoolFromPEM([]byte(noNewLines))
assert.Nil(t, err) assert.Nil(t, err)
assert.Equal(t, p.CAs[string("c9bfaf7ce8e84b2eeda2e27b469f4b9617bde192efd214b68891ecda6ed49522")].Certificate.Name(), rootCA.details.Name) assert.Equal(t, p.CAs[string("c9bfaf7ce8e84b2eeda2e27b469f4b9617bde192efd214b68891ecda6ed49522")].Certificate.Name(), rootCA.details.name)
assert.Equal(t, p.CAs[string("5c9c3f23e7ee7fe97637cbd3a0a5b854154d1d9aaaf7b566a51f4a88f76b64cd")].Certificate.Name(), rootCA01.details.Name) assert.Equal(t, p.CAs[string("5c9c3f23e7ee7fe97637cbd3a0a5b854154d1d9aaaf7b566a51f4a88f76b64cd")].Certificate.Name(), rootCA01.details.name)
pp, err := NewCAPoolFromPEM([]byte(withNewLines)) pp, err := NewCAPoolFromPEM([]byte(withNewLines))
assert.Nil(t, err) assert.Nil(t, err)
assert.Equal(t, pp.CAs[string("c9bfaf7ce8e84b2eeda2e27b469f4b9617bde192efd214b68891ecda6ed49522")].Certificate.Name(), rootCA.details.Name) assert.Equal(t, pp.CAs[string("c9bfaf7ce8e84b2eeda2e27b469f4b9617bde192efd214b68891ecda6ed49522")].Certificate.Name(), rootCA.details.name)
assert.Equal(t, pp.CAs[string("5c9c3f23e7ee7fe97637cbd3a0a5b854154d1d9aaaf7b566a51f4a88f76b64cd")].Certificate.Name(), rootCA01.details.Name) assert.Equal(t, pp.CAs[string("5c9c3f23e7ee7fe97637cbd3a0a5b854154d1d9aaaf7b566a51f4a88f76b64cd")].Certificate.Name(), rootCA01.details.name)
// expired cert, no valid certs // expired cert, no valid certs
ppp, err := NewCAPoolFromPEM([]byte(expired)) ppp, err := NewCAPoolFromPEM([]byte(expired))
@ -97,13 +97,13 @@ IBNWYMep3ysx9zCgknfG5dKtwGTaqF++BWKDYdyl34KX
// expired cert, with valid certs // expired cert, with valid certs
pppp, err := NewCAPoolFromPEM(append([]byte(expired), noNewLines...)) pppp, err := NewCAPoolFromPEM(append([]byte(expired), noNewLines...))
assert.Equal(t, ErrExpired, err) assert.Equal(t, ErrExpired, err)
assert.Equal(t, pppp.CAs[string("c9bfaf7ce8e84b2eeda2e27b469f4b9617bde192efd214b68891ecda6ed49522")].Certificate.Name(), rootCA.details.Name) assert.Equal(t, pppp.CAs[string("c9bfaf7ce8e84b2eeda2e27b469f4b9617bde192efd214b68891ecda6ed49522")].Certificate.Name(), rootCA.details.name)
assert.Equal(t, pppp.CAs[string("5c9c3f23e7ee7fe97637cbd3a0a5b854154d1d9aaaf7b566a51f4a88f76b64cd")].Certificate.Name(), rootCA01.details.Name) assert.Equal(t, pppp.CAs[string("5c9c3f23e7ee7fe97637cbd3a0a5b854154d1d9aaaf7b566a51f4a88f76b64cd")].Certificate.Name(), rootCA01.details.name)
assert.Equal(t, pppp.CAs[string("152070be6bb19bc9e3bde4c2f0e7d8f4ff5448b4c9856b8eccb314fade0229b0")].Certificate.Name(), "expired") assert.Equal(t, pppp.CAs[string("152070be6bb19bc9e3bde4c2f0e7d8f4ff5448b4c9856b8eccb314fade0229b0")].Certificate.Name(), "expired")
assert.Equal(t, len(pppp.CAs), 3) assert.Equal(t, len(pppp.CAs), 3)
ppppp, err := NewCAPoolFromPEM([]byte(p256)) ppppp, err := NewCAPoolFromPEM([]byte(p256))
assert.Nil(t, err) assert.Nil(t, err)
assert.Equal(t, ppppp.CAs[string("a7938893ec8c4ef769b06d7f425e5e46f7a7f5ffa49c3bcf4a86b608caba9159")].Certificate.Name(), rootCAP256.details.Name) assert.Equal(t, ppppp.CAs[string("a7938893ec8c4ef769b06d7f425e5e46f7a7f5ffa49c3bcf4a86b608caba9159")].Certificate.Name(), rootCAP256.details.name)
assert.Equal(t, len(ppppp.CAs), 1) assert.Equal(t, len(ppppp.CAs), 1)
} }

View File

@ -1,13 +1,15 @@
package cert package cert
import ( import (
"fmt"
"net/netip" "net/netip"
"time" "time"
) )
type Version int type Version uint8
const ( const (
VersionPre1 Version = 0
Version1 Version = 1 Version1 Version = 1
Version2 Version = 2 Version2 Version = 2
) )
@ -107,23 +109,56 @@ type CachedCertificate struct {
signerFingerprint string signerFingerprint string
} }
// UnmarshalCertificate will attempt to unmarshal a wire protocol level certificate. func (cc *CachedCertificate) String() string {
func UnmarshalCertificate(b []byte) (Certificate, error) { return cc.Certificate.String()
c, err := unmarshalCertificateV1(b, true)
if err != nil {
return nil, err
}
return c, nil
} }
// UnmarshalCertificateFromHandshake will attempt to unmarshal a certificate received in a handshake. // RecombineAndValidate will attempt to unmarshal a certificate received in a handshake.
// Handshakes save space by placing the peers public key in a different part of the packet, we have to // Handshakes save space by placing the peers public key in a different part of the packet, we have to
// reassemble the actual certificate structure with that in mind. // reassemble the actual certificate structure with that in mind.
func UnmarshalCertificateFromHandshake(b []byte, publicKey []byte) (Certificate, error) { func RecombineAndValidate(v Version, rawCertBytes, publicKey []byte, curve Curve, caPool *CAPool) (*CachedCertificate, error) {
c, err := unmarshalCertificateV1(b, false) if publicKey == nil {
return nil, ErrNoPeerStaticKey
}
if rawCertBytes == nil {
return nil, ErrNoPayload
}
c, err := unmarshalCertificateFromHandshake(v, rawCertBytes, publicKey, curve)
if err != nil {
return nil, fmt.Errorf("error unmarshaling cert: %w", err)
}
cc, err := caPool.VerifyCertificate(time.Now(), c)
if err != nil {
return nil, fmt.Errorf("certificate validation failed: %w", err)
}
return cc, nil
}
func unmarshalCertificateFromHandshake(v Version, b []byte, publicKey []byte, curve Curve) (Certificate, error) {
var c Certificate
var err error
switch v {
case VersionPre1, Version1:
c, err = unmarshalCertificateV1(b, publicKey)
case Version2:
c, err = unmarshalCertificateV2(b, publicKey, curve)
default:
//TODO: make a static var
return nil, fmt.Errorf("unknown certificate version %d", v)
}
if err != nil { if err != nil {
return nil, err return nil, err
} }
c.details.PublicKey = publicKey
if c.Curve() != curve {
return nil, fmt.Errorf("certificate curve %s does not match expected %s", c.Curve().String(), curve.String())
}
return c, nil return c, nil
} }

View File

@ -24,21 +24,21 @@ func TestMarshalingNebulaCertificate(t *testing.T) {
nc := certificateV1{ nc := certificateV1{
details: detailsV1{ details: detailsV1{
Name: "testing", name: "testing",
Ips: []netip.Prefix{ networks: []netip.Prefix{
mustParsePrefixUnmapped("10.1.1.1/24"), mustParsePrefixUnmapped("10.1.1.1/24"),
mustParsePrefixUnmapped("10.1.1.2/16"), mustParsePrefixUnmapped("10.1.1.2/16"),
}, },
Subnets: []netip.Prefix{ unsafeNetworks: []netip.Prefix{
mustParsePrefixUnmapped("9.1.1.2/24"), mustParsePrefixUnmapped("9.1.1.2/24"),
mustParsePrefixUnmapped("9.1.1.3/16"), mustParsePrefixUnmapped("9.1.1.3/16"),
}, },
Groups: []string{"test-group1", "test-group2", "test-group3"}, groups: []string{"test-group1", "test-group2", "test-group3"},
NotBefore: before, notBefore: before,
NotAfter: after, notAfter: after,
PublicKey: pubKey, publicKey: pubKey,
IsCA: false, isCA: false,
Issuer: "1234567890abcedfghij1234567890ab", issuer: "1234567890abcedfghij1234567890ab",
}, },
signature: []byte("1234567890abcedfghij1234567890ab"), signature: []byte("1234567890abcedfghij1234567890ab"),
} }
@ -47,20 +47,20 @@ func TestMarshalingNebulaCertificate(t *testing.T) {
assert.Nil(t, err) assert.Nil(t, err)
//t.Log("Cert size:", len(b)) //t.Log("Cert size:", len(b))
nc2, err := unmarshalCertificateV1(b, true) nc2, err := unmarshalCertificateV1(b, nil)
assert.Nil(t, err) assert.Nil(t, err)
assert.Equal(t, nc.signature, nc2.Signature()) assert.Equal(t, nc.signature, nc2.Signature())
assert.Equal(t, nc.details.Name, nc2.Name()) assert.Equal(t, nc.details.name, nc2.Name())
assert.Equal(t, nc.details.NotBefore, nc2.NotBefore()) assert.Equal(t, nc.details.notBefore, nc2.NotBefore())
assert.Equal(t, nc.details.NotAfter, nc2.NotAfter()) assert.Equal(t, nc.details.notAfter, nc2.NotAfter())
assert.Equal(t, nc.details.PublicKey, nc2.PublicKey()) assert.Equal(t, nc.details.publicKey, nc2.PublicKey())
assert.Equal(t, nc.details.IsCA, nc2.IsCA()) assert.Equal(t, nc.details.isCA, nc2.IsCA())
assert.Equal(t, nc.details.Ips, nc2.Networks()) assert.Equal(t, nc.details.networks, nc2.Networks())
assert.Equal(t, nc.details.Subnets, nc2.UnsafeNetworks()) assert.Equal(t, nc.details.unsafeNetworks, nc2.UnsafeNetworks())
assert.Equal(t, nc.details.Groups, nc2.Groups()) assert.Equal(t, nc.details.groups, nc2.Groups())
} }
//func TestNebulaCertificate_Sign(t *testing.T) { //func TestNebulaCertificate_Sign(t *testing.T) {
@ -150,8 +150,8 @@ func TestMarshalingNebulaCertificate(t *testing.T) {
func TestNebulaCertificate_Expired(t *testing.T) { func TestNebulaCertificate_Expired(t *testing.T) {
nc := certificateV1{ nc := certificateV1{
details: detailsV1{ details: detailsV1{
NotBefore: time.Now().Add(time.Second * -60).Round(time.Second), notBefore: time.Now().Add(time.Second * -60).Round(time.Second),
NotAfter: time.Now().Add(time.Second * 60).Round(time.Second), notAfter: time.Now().Add(time.Second * 60).Round(time.Second),
}, },
} }
@ -166,21 +166,21 @@ func TestNebulaCertificate_MarshalJSON(t *testing.T) {
nc := certificateV1{ nc := certificateV1{
details: detailsV1{ details: detailsV1{
Name: "testing", name: "testing",
Ips: []netip.Prefix{ networks: []netip.Prefix{
mustParsePrefixUnmapped("10.1.1.1/24"), mustParsePrefixUnmapped("10.1.1.1/24"),
mustParsePrefixUnmapped("10.1.1.2/16"), mustParsePrefixUnmapped("10.1.1.2/16"),
}, },
Subnets: []netip.Prefix{ unsafeNetworks: []netip.Prefix{
mustParsePrefixUnmapped("9.1.1.2/24"), mustParsePrefixUnmapped("9.1.1.2/24"),
mustParsePrefixUnmapped("9.1.1.3/16"), mustParsePrefixUnmapped("9.1.1.3/16"),
}, },
Groups: []string{"test-group1", "test-group2", "test-group3"}, groups: []string{"test-group1", "test-group2", "test-group3"},
NotBefore: time.Date(1, 0, 0, 1, 0, 0, 0, time.UTC), notBefore: time.Date(1, 0, 0, 1, 0, 0, 0, time.UTC),
NotAfter: time.Date(1, 0, 0, 2, 0, 0, 0, time.UTC), notAfter: time.Date(1, 0, 0, 2, 0, 0, 0, time.UTC),
PublicKey: pubKey, publicKey: pubKey,
IsCA: false, isCA: false,
Issuer: "1234567890abcedfghij1234567890ab", issuer: "1234567890abcedfghij1234567890ab",
}, },
signature: []byte("1234567890abcedfghij1234567890ab"), signature: []byte("1234567890abcedfghij1234567890ab"),
} }
@ -189,7 +189,7 @@ func TestNebulaCertificate_MarshalJSON(t *testing.T) {
assert.Nil(t, err) assert.Nil(t, err)
assert.Equal( assert.Equal(
t, t,
"{\"details\":{\"curve\":\"CURVE25519\",\"groups\":[\"test-group1\",\"test-group2\",\"test-group3\"],\"ips\":[\"10.1.1.1/24\",\"10.1.1.2/16\"],\"isCa\":false,\"issuer\":\"1234567890abcedfghij1234567890ab\",\"name\":\"testing\",\"notAfter\":\"0000-11-30T02:00:00Z\",\"notBefore\":\"0000-11-30T01:00:00Z\",\"publicKey\":\"313233343536373839306162636564666768696a313233343536373839306162\",\"subnets\":[\"9.1.1.2/24\",\"9.1.1.3/16\"]},\"fingerprint\":\"3944c53d4267a229295b56cb2d27d459164c010ac97d655063ba421e0670f4ba\",\"signature\":\"313233343536373839306162636564666768696a313233343536373839306162\"}", "{\"details\":{\"curve\":\"CURVE25519\",\"groups\":[\"test-group1\",\"test-group2\",\"test-group3\"],\"isCa\":false,\"issuer\":\"1234567890abcedfghij1234567890ab\",\"name\":\"testing\",\"networks\":[\"10.1.1.1/24\",\"10.1.1.2/16\"],\"notAfter\":\"0000-11-30T02:00:00Z\",\"notBefore\":\"0000-11-30T01:00:00Z\",\"publicKey\":\"313233343536373839306162636564666768696a313233343536373839306162\",\"unsafeNetworks\":[\"9.1.1.2/24\",\"9.1.1.3/16\"]},\"fingerprint\":\"3944c53d4267a229295b56cb2d27d459164c010ac97d655063ba421e0670f4ba\",\"signature\":\"313233343536373839306162636564666768696a313233343536373839306162\",\"version\":1}",
string(b), string(b),
) )
} }
@ -526,7 +526,7 @@ func TestNebulaCertificate_Copy(t *testing.T) {
func TestUnmarshalNebulaCertificate(t *testing.T) { func TestUnmarshalNebulaCertificate(t *testing.T) {
// Test that we don't panic with an invalid certificate (#332) // Test that we don't panic with an invalid certificate (#332)
data := []byte("\x98\x00\x00") data := []byte("\x98\x00\x00")
_, err := unmarshalCertificateV1(data, true) _, err := unmarshalCertificateV1(data, nil)
assert.EqualError(t, err, "encoded Details was nil") assert.EqualError(t, err, "encoded Details was nil")
} }

View File

@ -6,19 +6,16 @@ import (
"crypto/ecdsa" "crypto/ecdsa"
"crypto/ed25519" "crypto/ed25519"
"crypto/elliptic" "crypto/elliptic"
"crypto/rand"
"crypto/sha256" "crypto/sha256"
"encoding/binary" "encoding/binary"
"encoding/hex" "encoding/hex"
"encoding/json" "encoding/json"
"encoding/pem" "encoding/pem"
"fmt" "fmt"
"math/big"
"net" "net"
"net/netip" "net/netip"
"time" "time"
"github.com/slackhq/nebula/pkclient"
"golang.org/x/crypto/curve25519" "golang.org/x/crypto/curve25519"
"google.golang.org/protobuf/proto" "google.golang.org/protobuf/proto"
) )
@ -31,71 +28,71 @@ type certificateV1 struct {
} }
type detailsV1 struct { type detailsV1 struct {
Name string name string
Ips []netip.Prefix networks []netip.Prefix
Subnets []netip.Prefix unsafeNetworks []netip.Prefix
Groups []string groups []string
NotBefore time.Time notBefore time.Time
NotAfter time.Time notAfter time.Time
PublicKey []byte publicKey []byte
IsCA bool isCA bool
Issuer string issuer string
Curve Curve curve Curve
} }
type m map[string]interface{} type m map[string]interface{}
func (nc *certificateV1) Version() Version { func (c *certificateV1) Version() Version {
return Version1 return Version1
} }
func (nc *certificateV1) Curve() Curve { func (c *certificateV1) Curve() Curve {
return nc.details.Curve return c.details.curve
} }
func (nc *certificateV1) Groups() []string { func (c *certificateV1) Groups() []string {
return nc.details.Groups return c.details.groups
} }
func (nc *certificateV1) IsCA() bool { func (c *certificateV1) IsCA() bool {
return nc.details.IsCA return c.details.isCA
} }
func (nc *certificateV1) Issuer() string { func (c *certificateV1) Issuer() string {
return nc.details.Issuer return c.details.issuer
} }
func (nc *certificateV1) Name() string { func (c *certificateV1) Name() string {
return nc.details.Name return c.details.name
} }
func (nc *certificateV1) Networks() []netip.Prefix { func (c *certificateV1) Networks() []netip.Prefix {
return nc.details.Ips return c.details.networks
} }
func (nc *certificateV1) NotAfter() time.Time { func (c *certificateV1) NotAfter() time.Time {
return nc.details.NotAfter return c.details.notAfter
} }
func (nc *certificateV1) NotBefore() time.Time { func (c *certificateV1) NotBefore() time.Time {
return nc.details.NotBefore return c.details.notBefore
} }
func (nc *certificateV1) PublicKey() []byte { func (c *certificateV1) PublicKey() []byte {
return nc.details.PublicKey return c.details.publicKey
} }
func (nc *certificateV1) Signature() []byte { func (c *certificateV1) Signature() []byte {
return nc.signature return c.signature
} }
func (nc *certificateV1) UnsafeNetworks() []netip.Prefix { func (c *certificateV1) UnsafeNetworks() []netip.Prefix {
return nc.details.Subnets return c.details.unsafeNetworks
} }
func (nc *certificateV1) Fingerprint() (string, error) { func (c *certificateV1) Fingerprint() (string, error) {
b, err := nc.Marshal() b, err := c.Marshal()
if err != nil { if err != nil {
return "", err return "", err
} }
@ -104,33 +101,33 @@ func (nc *certificateV1) Fingerprint() (string, error) {
return hex.EncodeToString(sum[:]), nil return hex.EncodeToString(sum[:]), nil
} }
func (nc *certificateV1) CheckSignature(key []byte) bool { func (c *certificateV1) CheckSignature(key []byte) bool {
b, err := proto.Marshal(nc.getRawDetails()) b, err := proto.Marshal(c.getRawDetails())
if err != nil { if err != nil {
return false return false
} }
switch nc.details.Curve { switch c.details.curve {
case Curve_CURVE25519: case Curve_CURVE25519:
return ed25519.Verify(key, b, nc.signature) return ed25519.Verify(key, b, c.signature)
case Curve_P256: case Curve_P256:
x, y := elliptic.Unmarshal(elliptic.P256(), key) x, y := elliptic.Unmarshal(elliptic.P256(), key)
pubKey := &ecdsa.PublicKey{Curve: elliptic.P256(), X: x, Y: y} pubKey := &ecdsa.PublicKey{Curve: elliptic.P256(), X: x, Y: y}
hashed := sha256.Sum256(b) hashed := sha256.Sum256(b)
return ecdsa.VerifyASN1(pubKey, hashed[:], nc.signature) return ecdsa.VerifyASN1(pubKey, hashed[:], c.signature)
default: default:
return false return false
} }
} }
func (nc *certificateV1) Expired(t time.Time) bool { func (c *certificateV1) Expired(t time.Time) bool {
return nc.details.NotBefore.After(t) || nc.details.NotAfter.Before(t) return c.details.notBefore.After(t) || c.details.notAfter.Before(t)
} }
func (nc *certificateV1) VerifyPrivateKey(curve Curve, key []byte) error { func (c *certificateV1) VerifyPrivateKey(curve Curve, key []byte) error {
if curve != nc.details.Curve { if curve != c.details.curve {
return fmt.Errorf("curve in cert and private key supplied don't match") return fmt.Errorf("curve in cert and private key supplied don't match")
} }
if nc.details.IsCA { if c.details.isCA {
switch curve { switch curve {
case Curve_CURVE25519: case Curve_CURVE25519:
// the call to PublicKey below will panic slice bounds out of range otherwise // the call to PublicKey below will panic slice bounds out of range otherwise
@ -138,7 +135,7 @@ func (nc *certificateV1) VerifyPrivateKey(curve Curve, key []byte) error {
return fmt.Errorf("key was not 64 bytes, is invalid ed25519 private key") return fmt.Errorf("key was not 64 bytes, is invalid ed25519 private key")
} }
if !ed25519.PublicKey(nc.details.PublicKey).Equal(ed25519.PrivateKey(key).Public()) { if !ed25519.PublicKey(c.details.publicKey).Equal(ed25519.PrivateKey(key).Public()) {
return fmt.Errorf("public key in cert and private key supplied don't match") return fmt.Errorf("public key in cert and private key supplied don't match")
} }
case Curve_P256: case Curve_P256:
@ -147,7 +144,7 @@ func (nc *certificateV1) VerifyPrivateKey(curve Curve, key []byte) error {
return fmt.Errorf("cannot parse private key as P256: %w", err) return fmt.Errorf("cannot parse private key as P256: %w", err)
} }
pub := privkey.PublicKey().Bytes() pub := privkey.PublicKey().Bytes()
if !bytes.Equal(pub, nc.details.PublicKey) { if !bytes.Equal(pub, c.details.publicKey) {
return fmt.Errorf("public key in cert and private key supplied don't match") return fmt.Errorf("public key in cert and private key supplied don't match")
} }
default: default:
@ -173,7 +170,7 @@ func (nc *certificateV1) VerifyPrivateKey(curve Curve, key []byte) error {
default: default:
return fmt.Errorf("invalid curve: %s", curve) return fmt.Errorf("invalid curve: %s", curve)
} }
if !bytes.Equal(pub, nc.details.PublicKey) { if !bytes.Equal(pub, c.details.publicKey) {
return fmt.Errorf("public key in cert and private key supplied don't match") return fmt.Errorf("public key in cert and private key supplied don't match")
} }
@ -181,173 +178,155 @@ func (nc *certificateV1) VerifyPrivateKey(curve Curve, key []byte) error {
} }
// getRawDetails marshals the raw details into protobuf ready struct // getRawDetails marshals the raw details into protobuf ready struct
func (nc *certificateV1) getRawDetails() *RawNebulaCertificateDetails { func (c *certificateV1) getRawDetails() *RawNebulaCertificateDetails {
rd := &RawNebulaCertificateDetails{ rd := &RawNebulaCertificateDetails{
Name: nc.details.Name, Name: c.details.name,
Groups: nc.details.Groups, Groups: c.details.groups,
NotBefore: nc.details.NotBefore.Unix(), NotBefore: c.details.notBefore.Unix(),
NotAfter: nc.details.NotAfter.Unix(), NotAfter: c.details.notAfter.Unix(),
PublicKey: make([]byte, len(nc.details.PublicKey)), PublicKey: make([]byte, len(c.details.publicKey)),
IsCA: nc.details.IsCA, IsCA: c.details.isCA,
Curve: nc.details.Curve, Curve: c.details.curve,
} }
for _, ipNet := range nc.details.Ips { for _, ipNet := range c.details.networks {
mask := net.CIDRMask(ipNet.Bits(), ipNet.Addr().BitLen()) mask := net.CIDRMask(ipNet.Bits(), ipNet.Addr().BitLen())
rd.Ips = append(rd.Ips, addr2int(ipNet.Addr()), ip2int(mask)) rd.Ips = append(rd.Ips, addr2int(ipNet.Addr()), ip2int(mask))
} }
for _, ipNet := range nc.details.Subnets { for _, ipNet := range c.details.unsafeNetworks {
mask := net.CIDRMask(ipNet.Bits(), ipNet.Addr().BitLen()) mask := net.CIDRMask(ipNet.Bits(), ipNet.Addr().BitLen())
rd.Subnets = append(rd.Subnets, addr2int(ipNet.Addr()), ip2int(mask)) rd.Subnets = append(rd.Subnets, addr2int(ipNet.Addr()), ip2int(mask))
} }
copy(rd.PublicKey, nc.details.PublicKey[:]) copy(rd.PublicKey, c.details.publicKey[:])
// I know, this is terrible // I know, this is terrible
rd.Issuer, _ = hex.DecodeString(nc.details.Issuer) rd.Issuer, _ = hex.DecodeString(c.details.issuer)
return rd return rd
} }
func (nc *certificateV1) String() string { func (c *certificateV1) String() string {
if nc == nil { b, err := json.MarshalIndent(c.marshalJSON(), "", "\t")
return "Certificate {}\n" if err != nil {
return "<error marshalling certificate>"
} }
return string(b)
s := "NebulaCertificate {\n"
s += "\tDetails {\n"
s += fmt.Sprintf("\t\tName: %v\n", nc.details.Name)
if len(nc.details.Ips) > 0 {
s += "\t\tIps: [\n"
for _, ip := range nc.details.Ips {
s += fmt.Sprintf("\t\t\t%v\n", ip.String())
}
s += "\t\t]\n"
} else {
s += "\t\tIps: []\n"
}
if len(nc.details.Subnets) > 0 {
s += "\t\tSubnets: [\n"
for _, ip := range nc.details.Subnets {
s += fmt.Sprintf("\t\t\t%v\n", ip.String())
}
s += "\t\t]\n"
} else {
s += "\t\tSubnets: []\n"
}
if len(nc.details.Groups) > 0 {
s += "\t\tGroups: [\n"
for _, g := range nc.details.Groups {
s += fmt.Sprintf("\t\t\t\"%v\"\n", g)
}
s += "\t\t]\n"
} else {
s += "\t\tGroups: []\n"
}
s += fmt.Sprintf("\t\tNot before: %v\n", nc.details.NotBefore)
s += fmt.Sprintf("\t\tNot After: %v\n", nc.details.NotAfter)
s += fmt.Sprintf("\t\tIs CA: %v\n", nc.details.IsCA)
s += fmt.Sprintf("\t\tIssuer: %s\n", nc.details.Issuer)
s += fmt.Sprintf("\t\tPublic key: %x\n", nc.details.PublicKey)
s += fmt.Sprintf("\t\tCurve: %s\n", nc.details.Curve)
s += "\t}\n"
fp, err := nc.Fingerprint()
if err == nil {
s += fmt.Sprintf("\tFingerprint: %s\n", fp)
}
s += fmt.Sprintf("\tSignature: %x\n", nc.Signature())
s += "}"
return s
} }
func (nc *certificateV1) MarshalForHandshakes() ([]byte, error) { func (c *certificateV1) MarshalForHandshakes() ([]byte, error) {
pubKey := nc.details.PublicKey pubKey := c.details.publicKey
nc.details.PublicKey = nil c.details.publicKey = nil
rawCertNoKey, err := nc.Marshal() rawCertNoKey, err := c.Marshal()
if err != nil { if err != nil {
return nil, err return nil, err
} }
nc.details.PublicKey = pubKey c.details.publicKey = pubKey
return rawCertNoKey, nil return rawCertNoKey, nil
} }
func (nc *certificateV1) Marshal() ([]byte, error) { func (c *certificateV1) Marshal() ([]byte, error) {
rc := RawNebulaCertificate{ rc := RawNebulaCertificate{
Details: nc.getRawDetails(), Details: c.getRawDetails(),
Signature: nc.signature, Signature: c.signature,
} }
return proto.Marshal(&rc) return proto.Marshal(&rc)
} }
func (nc *certificateV1) MarshalPEM() ([]byte, error) { func (c *certificateV1) MarshalPEM() ([]byte, error) {
b, err := nc.Marshal() b, err := c.Marshal()
if err != nil { if err != nil {
return nil, err return nil, err
} }
return pem.EncodeToMemory(&pem.Block{Type: CertificateBanner, Bytes: b}), nil return pem.EncodeToMemory(&pem.Block{Type: CertificateBanner, Bytes: b}), nil
} }
func (nc *certificateV1) MarshalJSON() ([]byte, error) { func (c *certificateV1) MarshalJSON() ([]byte, error) {
fp, _ := nc.Fingerprint() return json.Marshal(c.marshalJSON())
jc := m{
"details": m{
"name": nc.details.Name,
"ips": nc.details.Ips,
"subnets": nc.details.Subnets,
"groups": nc.details.Groups,
"notBefore": nc.details.NotBefore,
"notAfter": nc.details.NotAfter,
"publicKey": fmt.Sprintf("%x", nc.details.PublicKey),
"isCa": nc.details.IsCA,
"issuer": nc.details.Issuer,
"curve": nc.details.Curve.String(),
},
"fingerprint": fp,
"signature": fmt.Sprintf("%x", nc.Signature()),
}
return json.Marshal(jc)
} }
func (nc *certificateV1) Copy() Certificate { func (c *certificateV1) marshalJSON() m {
c := &certificateV1{ fp, _ := c.Fingerprint()
details: detailsV1{ return m{
Name: nc.details.Name, "version": Version1,
Groups: make([]string, len(nc.details.Groups)), "details": m{
Ips: make([]netip.Prefix, len(nc.details.Ips)), "name": c.details.name,
Subnets: make([]netip.Prefix, len(nc.details.Subnets)), "networks": c.details.networks,
NotBefore: nc.details.NotBefore, "unsafeNetworks": c.details.unsafeNetworks,
NotAfter: nc.details.NotAfter, "groups": c.details.groups,
PublicKey: make([]byte, len(nc.details.PublicKey)), "notBefore": c.details.notBefore,
IsCA: nc.details.IsCA, "notAfter": c.details.notAfter,
Issuer: nc.details.Issuer, "publicKey": fmt.Sprintf("%x", c.details.publicKey),
"isCa": c.details.isCA,
"issuer": c.details.issuer,
"curve": c.details.curve.String(),
}, },
signature: make([]byte, len(nc.signature)), "fingerprint": fp,
"signature": fmt.Sprintf("%x", c.Signature()),
}
}
func (c *certificateV1) Copy() Certificate {
nc := &certificateV1{
details: detailsV1{
name: c.details.name,
groups: make([]string, len(c.details.groups)),
networks: make([]netip.Prefix, len(c.details.networks)),
unsafeNetworks: make([]netip.Prefix, len(c.details.unsafeNetworks)),
notBefore: c.details.notBefore,
notAfter: c.details.notAfter,
publicKey: make([]byte, len(c.details.publicKey)),
isCA: c.details.isCA,
issuer: c.details.issuer,
curve: c.details.curve,
},
signature: make([]byte, len(c.signature)),
} }
copy(c.signature, nc.signature) copy(nc.signature, c.signature)
copy(c.details.Groups, nc.details.Groups) copy(nc.details.groups, c.details.groups)
copy(c.details.PublicKey, nc.details.PublicKey) copy(nc.details.publicKey, c.details.publicKey)
copy(nc.details.networks, c.details.networks)
copy(nc.details.unsafeNetworks, c.details.unsafeNetworks)
for i, p := range nc.details.Ips { return nc
c.details.Ips[i] = p }
func (c *certificateV1) fromTBSCertificate(t *TBSCertificate) error {
c.details = detailsV1{
name: t.Name,
networks: t.Networks,
unsafeNetworks: t.UnsafeNetworks,
groups: t.Groups,
notBefore: t.NotBefore,
notAfter: t.NotAfter,
publicKey: t.PublicKey,
isCA: t.IsCA,
curve: t.Curve,
issuer: t.issuer,
} }
for i, p := range nc.details.Subnets { return nil
c.details.Subnets[i] = p }
}
return c func (c *certificateV1) marshalForSigning() ([]byte, error) {
b, err := proto.Marshal(c.getRawDetails())
if err != nil {
return nil, err
}
return b, nil
}
func (c *certificateV1) setSignature(b []byte) error {
c.signature = b
return nil
} }
// unmarshalCertificateV1 will unmarshal a protobuf byte representation of a nebula cert // unmarshalCertificateV1 will unmarshal a protobuf byte representation of a nebula cert
func unmarshalCertificateV1(b []byte, assertPublicKey bool) (*certificateV1, error) { // if the publicKey is provided here then it is not required to be present in `b`
func unmarshalCertificateV1(b []byte, publicKey []byte) (*certificateV1, error) {
if len(b) == 0 { if len(b) == 0 {
return nil, fmt.Errorf("nil byte array") return nil, fmt.Errorf("nil byte array")
} }
@ -371,27 +350,28 @@ func unmarshalCertificateV1(b []byte, assertPublicKey bool) (*certificateV1, err
nc := certificateV1{ nc := certificateV1{
details: detailsV1{ details: detailsV1{
Name: rc.Details.Name, name: rc.Details.Name,
Groups: make([]string, len(rc.Details.Groups)), groups: make([]string, len(rc.Details.Groups)),
Ips: make([]netip.Prefix, len(rc.Details.Ips)/2), networks: make([]netip.Prefix, len(rc.Details.Ips)/2),
Subnets: make([]netip.Prefix, len(rc.Details.Subnets)/2), unsafeNetworks: make([]netip.Prefix, len(rc.Details.Subnets)/2),
NotBefore: time.Unix(rc.Details.NotBefore, 0), notBefore: time.Unix(rc.Details.NotBefore, 0),
NotAfter: time.Unix(rc.Details.NotAfter, 0), notAfter: time.Unix(rc.Details.NotAfter, 0),
PublicKey: make([]byte, len(rc.Details.PublicKey)), publicKey: make([]byte, len(rc.Details.PublicKey)),
IsCA: rc.Details.IsCA, isCA: rc.Details.IsCA,
Curve: rc.Details.Curve, curve: rc.Details.Curve,
}, },
signature: make([]byte, len(rc.Signature)), signature: make([]byte, len(rc.Signature)),
} }
copy(nc.signature, rc.Signature) copy(nc.signature, rc.Signature)
copy(nc.details.Groups, rc.Details.Groups) copy(nc.details.groups, rc.Details.Groups)
nc.details.Issuer = hex.EncodeToString(rc.Details.Issuer) nc.details.issuer = hex.EncodeToString(rc.Details.Issuer)
if len(rc.Details.PublicKey) < publicKeyLen && assertPublicKey { if len(publicKey) > 0 {
return nil, fmt.Errorf("public key was fewer than 32 bytes; %v", len(rc.Details.PublicKey)) nc.details.publicKey = publicKey
} }
copy(nc.details.PublicKey, rc.Details.PublicKey)
copy(nc.details.publicKey, rc.Details.PublicKey)
var ip netip.Addr var ip netip.Addr
for i, rawIp := range rc.Details.Ips { for i, rawIp := range rc.Details.Ips {
@ -399,7 +379,7 @@ func unmarshalCertificateV1(b []byte, assertPublicKey bool) (*certificateV1, err
ip = int2addr(rawIp) ip = int2addr(rawIp)
} else { } else {
ones, _ := net.IPMask(int2ip(rawIp)).Size() ones, _ := net.IPMask(int2ip(rawIp)).Size()
nc.details.Ips[i/2] = netip.PrefixFrom(ip, ones) nc.details.networks[i/2] = netip.PrefixFrom(ip, ones)
} }
} }
@ -408,69 +388,15 @@ func unmarshalCertificateV1(b []byte, assertPublicKey bool) (*certificateV1, err
ip = int2addr(rawIp) ip = int2addr(rawIp)
} else { } else {
ones, _ := net.IPMask(int2ip(rawIp)).Size() ones, _ := net.IPMask(int2ip(rawIp)).Size()
nc.details.Subnets[i/2] = netip.PrefixFrom(ip, ones) nc.details.unsafeNetworks[i/2] = netip.PrefixFrom(ip, ones)
} }
} }
//do not sort the subnets field for V1 certs
return &nc, nil return &nc, nil
} }
func signV1(t *TBSCertificate, curve Curve, key []byte, client *pkclient.PKClient) (*certificateV1, error) {
c := &certificateV1{
details: detailsV1{
Name: t.Name,
Ips: t.Networks,
Subnets: t.UnsafeNetworks,
Groups: t.Groups,
NotBefore: t.NotBefore,
NotAfter: t.NotAfter,
PublicKey: t.PublicKey,
IsCA: t.IsCA,
Curve: t.Curve,
Issuer: t.issuer,
},
}
b, err := proto.Marshal(c.getRawDetails())
if err != nil {
return nil, err
}
var sig []byte
switch curve {
case Curve_CURVE25519:
signer := ed25519.PrivateKey(key)
sig = ed25519.Sign(signer, b)
case Curve_P256:
if client != nil {
sig, err = client.SignASN1(b)
} else {
signer := &ecdsa.PrivateKey{
PublicKey: ecdsa.PublicKey{
Curve: elliptic.P256(),
},
// ref: https://github.com/golang/go/blob/go1.19/src/crypto/x509/sec1.go#L95
D: new(big.Int).SetBytes(key),
}
// ref: https://github.com/golang/go/blob/go1.19/src/crypto/x509/sec1.go#L119
signer.X, signer.Y = signer.Curve.ScalarBaseMult(key)
// We need to hash first for ECDSA
// - https://pkg.go.dev/crypto/ecdsa#SignASN1
hashed := sha256.Sum256(b)
sig, err = ecdsa.SignASN1(rand.Reader, signer, hashed[:])
if err != nil {
return nil, err
}
}
default:
return nil, fmt.Errorf("invalid curve: %s", c.details.Curve)
}
c.signature = sig
return c, nil
}
func ip2int(ip []byte) uint32 { func ip2int(ip []byte) uint32 {
if len(ip) == 16 { if len(ip) == 16 {
return binary.BigEndian.Uint32(ip[12:16]) return binary.BigEndian.Uint32(ip[12:16])

37
cert/cert_v2.asn1 Normal file
View File

@ -0,0 +1,37 @@
Nebula DEFINITIONS AUTOMATIC TAGS ::= BEGIN
Name ::= UTF8String (SIZE (1..253))
Time ::= INTEGER (0..18446744073709551615) -- Seconds since unix epoch, uint64 maximum
Network ::= OCTET STRING (SIZE (5,17)) -- IP addresses are 4 or 16 bytes + 1 byte for the prefix length
Curve ::= ENUMERATED {
curve25519 (0),
p256 (1)
}
-- The maximum size of a certificate must not exceed 65536 bytes
Certificate ::= SEQUENCE {
details OCTET STRING,
curve Curve DEFAULT curve25519,
publicKey OCTET STRING,
-- signature(details + curve + publicKey) using the appropriate method for curve
signature OCTET STRING
}
Details ::= SEQUENCE {
name Name,
-- At least 1 ipv4 or ipv6 address must be present if isCA is false
networks SEQUENCE OF Network OPTIONAL,
unsafeNetworks SEQUENCE OF Network OPTIONAL,
groups SEQUENCE OF Name OPTIONAL,
isCA BOOLEAN DEFAULT false,
notBefore Time,
notAfter Time,
-- issuer is only required if isCA is false, if isCA is true then it must not be present
issuer OCTET STRING OPTIONAL,
...
-- New fields can be added below here
}
END

621
cert/cert_v2.go Normal file
View File

@ -0,0 +1,621 @@
package cert
import (
"bytes"
"crypto/ecdh"
"crypto/ecdsa"
"crypto/ed25519"
"crypto/elliptic"
"crypto/sha256"
"encoding/hex"
"encoding/json"
"encoding/pem"
"fmt"
"net/netip"
"slices"
"time"
"golang.org/x/crypto/cryptobyte"
"golang.org/x/crypto/cryptobyte/asn1"
"golang.org/x/crypto/curve25519"
)
//TODO: should we avoid hex encoding shit on output? Just let it be base64?
const (
classConstructed = 0x20
classContextSpecific = 0x80
TagCertDetails = 0 | classConstructed | classContextSpecific
TagCertCurve = 1 | classContextSpecific
TagCertPublicKey = 2 | classContextSpecific
TagCertSignature = 3 | classContextSpecific
TagDetailsName = 0 | classContextSpecific
TagDetailsNetworks = 1 | classConstructed | classContextSpecific
TagDetailsUnsafeNetworks = 2 | classConstructed | classContextSpecific
TagDetailsGroups = 3 | classConstructed | classContextSpecific
TagDetailsIsCA = 4 | classContextSpecific
TagDetailsNotBefore = 5 | classContextSpecific
TagDetailsNotAfter = 6 | classContextSpecific
TagDetailsIssuer = 7 | classContextSpecific
)
const (
// MaxCertificateSize is the maximum length a valid certificate can be
MaxCertificateSize = 65536
// MaxNameLength is limited to a maximum realistic DNS domain name to help facilitate DNS systems
MaxNameLength = 253
// MaxNetworkLength is the maximum length a network value can be.
// 16 bytes for an ipv6 address + 1 byte for the prefix length
MaxNetworkLength = 17
)
type certificateV2 struct {
details detailsV2
// RawDetails contains the entire asn.1 DER encoded Details struct
// This is to benefit forwards compatibility in signature checking.
// signature(RawDetails + Curve + PublicKey) == Signature
rawDetails []byte
curve Curve
publicKey []byte
signature []byte
}
type detailsV2 struct {
name string
networks []netip.Prefix
unsafeNetworks []netip.Prefix
groups []string
isCA bool
notBefore time.Time
notAfter time.Time
issuer string
}
func (c *certificateV2) Version() Version {
return Version2
}
func (c *certificateV2) Curve() Curve {
return c.curve
}
func (c *certificateV2) Groups() []string {
return c.details.groups
}
func (c *certificateV2) IsCA() bool {
return c.details.isCA
}
func (c *certificateV2) Issuer() string {
return c.details.issuer
}
func (c *certificateV2) Name() string {
return c.details.name
}
func (c *certificateV2) Networks() []netip.Prefix {
return c.details.networks
}
func (c *certificateV2) NotAfter() time.Time {
return c.details.notAfter
}
func (c *certificateV2) NotBefore() time.Time {
return c.details.notBefore
}
func (c *certificateV2) PublicKey() []byte {
return c.publicKey
}
func (c *certificateV2) Signature() []byte {
return c.signature
}
func (c *certificateV2) UnsafeNetworks() []netip.Prefix {
return c.details.unsafeNetworks
}
func (c *certificateV2) Fingerprint() (string, error) {
b := make([]byte, len(c.rawDetails)+1+len(c.publicKey))
//TODO: double check this, panic on empty raw details
copy(b, c.rawDetails)
b[len(c.rawDetails)] = byte(c.curve)
copy(b[len(c.rawDetails)+1:], c.publicKey)
copy(b[len(c.rawDetails)+1+len(c.publicKey):], c.signature)
sum := sha256.Sum256(b)
return hex.EncodeToString(sum[:]), nil
}
func (c *certificateV2) CheckSignature(key []byte) bool {
b := make([]byte, len(c.rawDetails)+1+len(c.publicKey))
//TODO: double check this, panic on empty raw details
copy(b, c.rawDetails)
b[len(c.rawDetails)] = byte(c.curve)
copy(b[len(c.rawDetails)+1:], c.publicKey)
switch c.curve {
case Curve_CURVE25519:
return ed25519.Verify(key, b, c.signature)
case Curve_P256:
//TODO: NewPublicKey
x, y := elliptic.Unmarshal(elliptic.P256(), key)
pubKey := &ecdsa.PublicKey{Curve: elliptic.P256(), X: x, Y: y}
hashed := sha256.Sum256(b)
return ecdsa.VerifyASN1(pubKey, hashed[:], c.signature)
default:
return false
}
}
func (c *certificateV2) Expired(t time.Time) bool {
return c.details.notBefore.After(t) || c.details.notAfter.Before(t)
}
func (c *certificateV2) VerifyPrivateKey(curve Curve, key []byte) error {
if curve != c.curve {
return fmt.Errorf("curve in cert and private key supplied don't match")
}
if c.details.isCA {
switch curve {
case Curve_CURVE25519:
// the call to PublicKey below will panic slice bounds out of range otherwise
if len(key) != ed25519.PrivateKeySize {
return fmt.Errorf("key was not 64 bytes, is invalid ed25519 private key")
}
if !ed25519.PublicKey(c.publicKey).Equal(ed25519.PrivateKey(key).Public()) {
return fmt.Errorf("public key in cert and private key supplied don't match")
}
case Curve_P256:
privkey, err := ecdh.P256().NewPrivateKey(key)
if err != nil {
return fmt.Errorf("cannot parse private key as P256")
}
pub := privkey.PublicKey().Bytes()
if !bytes.Equal(pub, c.publicKey) {
return fmt.Errorf("public key in cert and private key supplied don't match")
}
default:
return fmt.Errorf("invalid curve: %s", curve)
}
return nil
}
var pub []byte
switch curve {
case Curve_CURVE25519:
var err error
pub, err = curve25519.X25519(key, curve25519.Basepoint)
if err != nil {
return err
}
case Curve_P256:
privkey, err := ecdh.P256().NewPrivateKey(key)
if err != nil {
return err
}
pub = privkey.PublicKey().Bytes()
default:
return fmt.Errorf("invalid curve: %s", curve)
}
if !bytes.Equal(pub, c.publicKey) {
return fmt.Errorf("public key in cert and private key supplied don't match")
}
return nil
}
func (c *certificateV2) String() string {
b, err := json.MarshalIndent(c.marshalJSON(), "", "\t")
if err != nil {
return "<error marshalling certificate>"
}
return string(b)
}
func (c *certificateV2) MarshalForHandshakes() ([]byte, error) {
var b cryptobyte.Builder
// Outermost certificate
b.AddASN1(asn1.SEQUENCE, func(b *cryptobyte.Builder) {
// Add the cert details which is already marshalled
//TODO: panic on nil rawDetails
b.AddBytes(c.rawDetails)
// Skipping the curve and public key since those come across in a different part of the handshake
// Add the signature
b.AddASN1(TagCertSignature, func(b *cryptobyte.Builder) {
b.AddBytes(c.signature)
})
})
return b.Bytes()
}
func (c *certificateV2) Marshal() ([]byte, error) {
var b cryptobyte.Builder
// Outermost certificate
b.AddASN1(asn1.SEQUENCE, func(b *cryptobyte.Builder) {
// Add the cert details which is already marshalled
b.AddBytes(c.rawDetails)
// Add the curve only if its not the default value
if c.curve != Curve_CURVE25519 {
b.AddASN1(TagCertCurve, func(b *cryptobyte.Builder) {
b.AddBytes([]byte{byte(c.curve)})
})
}
// Add the public key if it is not empty
if c.publicKey != nil {
b.AddASN1(TagCertPublicKey, func(b *cryptobyte.Builder) {
b.AddBytes(c.publicKey)
})
}
// Add the signature
b.AddASN1(TagCertSignature, func(b *cryptobyte.Builder) {
b.AddBytes(c.signature)
})
})
return b.Bytes()
}
func (c *certificateV2) MarshalPEM() ([]byte, error) {
b, err := c.Marshal()
if err != nil {
return nil, err
}
return pem.EncodeToMemory(&pem.Block{Type: CertificateV2Banner, Bytes: b}), nil
}
func (c *certificateV2) MarshalJSON() ([]byte, error) {
return json.Marshal(c.marshalJSON())
}
func (c *certificateV2) marshalJSON() m {
fp, _ := c.Fingerprint()
return m{
"details": m{
"name": c.details.name,
"networks": c.details.networks,
"unsafeNetworks": c.details.unsafeNetworks,
"groups": c.details.groups,
"notBefore": c.details.notBefore,
"notAfter": c.details.notAfter,
"isCa": c.details.isCA,
"issuer": c.details.issuer,
},
"version": Version2,
"publicKey": fmt.Sprintf("%x", c.publicKey),
"curve": c.curve.String(),
"fingerprint": fp,
"signature": fmt.Sprintf("%x", c.Signature()),
}
}
func (c *certificateV2) Copy() Certificate {
nc := &certificateV2{
details: detailsV2{
name: c.details.name,
groups: make([]string, len(c.details.groups)),
networks: make([]netip.Prefix, len(c.details.networks)),
unsafeNetworks: make([]netip.Prefix, len(c.details.unsafeNetworks)),
notBefore: c.details.notBefore,
notAfter: c.details.notAfter,
isCA: c.details.isCA,
issuer: c.details.issuer,
},
curve: c.curve,
publicKey: make([]byte, len(c.publicKey)),
signature: make([]byte, len(c.signature)),
}
copy(nc.signature, c.signature)
copy(nc.details.groups, c.details.groups)
copy(nc.publicKey, c.publicKey)
copy(nc.details.networks, c.details.networks)
copy(nc.details.unsafeNetworks, c.details.unsafeNetworks)
return nc
}
func (c *certificateV2) fromTBSCertificate(t *TBSCertificate) error {
c.details = detailsV2{
name: t.Name,
networks: t.Networks,
unsafeNetworks: t.UnsafeNetworks,
groups: t.Groups,
isCA: t.IsCA,
notBefore: t.NotBefore,
notAfter: t.NotAfter,
issuer: t.issuer,
}
c.curve = t.Curve
c.publicKey = t.PublicKey
return nil
}
func (c *certificateV2) marshalForSigning() ([]byte, error) {
d, err := c.details.Marshal()
if err != nil {
//TODO: annotate?
return nil, err
}
c.rawDetails = d
b := make([]byte, len(c.rawDetails)+1+len(c.publicKey))
//TODO: double check this
copy(b, c.rawDetails)
b[len(c.rawDetails)] = byte(c.curve)
copy(b[len(c.rawDetails)+1:], c.publicKey)
return b, nil
}
func (c *certificateV2) setSignature(b []byte) error {
c.signature = b
return nil
}
func (d *detailsV2) Marshal() ([]byte, error) {
var b cryptobyte.Builder
var err error
// Details are a structure
b.AddASN1(TagCertDetails, func(b *cryptobyte.Builder) {
// Add the name
b.AddASN1(TagDetailsName, func(b *cryptobyte.Builder) {
b.AddBytes([]byte(d.name))
})
// Add the networks if any exist
if len(d.networks) > 0 {
b.AddASN1(TagDetailsNetworks, func(b *cryptobyte.Builder) {
for _, n := range d.networks {
sb, innerErr := n.MarshalBinary()
if innerErr != nil {
// MarshalBinary never returns an error
err = fmt.Errorf("unable to marshal network: %w", innerErr)
return
}
b.AddASN1OctetString(sb)
}
})
}
// Add the unsafe networks if any exist
if len(d.unsafeNetworks) > 0 {
b.AddASN1(TagDetailsUnsafeNetworks, func(b *cryptobyte.Builder) {
for _, n := range d.unsafeNetworks {
sb, innerErr := n.MarshalBinary()
if innerErr != nil {
// MarshalBinary never returns an error
err = fmt.Errorf("unable to marshal unsafe network: %w", innerErr)
return
}
b.AddASN1OctetString(sb)
}
})
}
// Add groups if any exist
if len(d.groups) > 0 {
b.AddASN1(TagDetailsGroups, func(b *cryptobyte.Builder) {
for _, group := range d.groups {
b.AddASN1(asn1.UTF8String, func(b *cryptobyte.Builder) {
b.AddBytes([]byte(group))
})
}
})
}
// Add IsCA only if true
if d.isCA {
b.AddASN1(TagDetailsIsCA, func(b *cryptobyte.Builder) {
b.AddUint8(0xff)
})
}
// Add not before
b.AddASN1Int64WithTag(d.notBefore.Unix(), TagDetailsNotBefore)
// Add not after
b.AddASN1Int64WithTag(d.notAfter.Unix(), TagDetailsNotAfter)
// Add the issuer if present
if d.issuer != "" {
issuerBytes, innerErr := hex.DecodeString(d.issuer)
if innerErr != nil {
err = fmt.Errorf("failed to decode issuer: %w", innerErr)
return
}
b.AddASN1(TagDetailsIssuer, func(b *cryptobyte.Builder) {
b.AddBytes(issuerBytes)
})
}
})
if err != nil {
return nil, err
}
return b.Bytes()
}
func unmarshalCertificateV2(b []byte, publicKey []byte, curve Curve) (*certificateV2, error) {
l := len(b)
if l == 0 || l > MaxCertificateSize {
return nil, ErrBadFormat
}
input := cryptobyte.String(b)
// Open the envelope
if !input.ReadASN1(&input, asn1.SEQUENCE) || input.Empty() {
return nil, ErrBadFormat
}
// Grab the cert details, we need to preserve the tag and length
var rawDetails cryptobyte.String
if !input.ReadASN1Element(&rawDetails, TagCertDetails) || rawDetails.Empty() {
return nil, ErrBadFormat
}
//Maybe grab the curve
var rawCurve byte
if !readOptionalASN1Byte(&input, &rawCurve, TagCertCurve, byte(curve)) {
return nil, ErrBadFormat
}
curve = Curve(rawCurve)
// Maybe grab the public key
var rawPublicKey cryptobyte.String
if len(publicKey) > 0 {
rawPublicKey = publicKey
} else if !input.ReadOptionalASN1(&rawPublicKey, nil, TagCertPublicKey) {
return nil, ErrBadFormat
}
//TODO: Assert public key length
// Grab the signature
var rawSignature cryptobyte.String
if !input.ReadASN1(&rawSignature, TagCertSignature) || rawSignature.Empty() {
return nil, ErrBadFormat
}
// Finally unmarshal the details
details, err := unmarshalDetails(rawDetails)
if err != nil {
return nil, err
}
return &certificateV2{
details: details,
rawDetails: rawDetails,
curve: curve,
publicKey: rawPublicKey,
signature: rawSignature,
}, nil
}
func unmarshalDetails(b cryptobyte.String) (detailsV2, error) {
// Open the envelope
if !b.ReadASN1(&b, TagCertDetails) || b.Empty() {
return detailsV2{}, ErrBadFormat
}
// Read the name
var name cryptobyte.String
if !b.ReadASN1(&name, TagDetailsName) || name.Empty() || len(name) > MaxNameLength {
return detailsV2{}, ErrBadFormat
}
// Read the network addresses
var subString cryptobyte.String
var found bool
if !b.ReadOptionalASN1(&subString, &found, TagDetailsNetworks) {
return detailsV2{}, ErrBadFormat
}
var networks []netip.Prefix
var val cryptobyte.String
if found {
for !subString.Empty() {
if !subString.ReadASN1(&val, asn1.OCTET_STRING) || val.Empty() || len(val) > MaxNetworkLength {
return detailsV2{}, ErrBadFormat
}
var n netip.Prefix
if err := n.UnmarshalBinary(val); err != nil {
return detailsV2{}, ErrBadFormat
}
networks = append(networks, n)
}
}
// Read out any unsafe networks
if !b.ReadOptionalASN1(&subString, &found, TagDetailsUnsafeNetworks) {
return detailsV2{}, ErrBadFormat
}
var unsafeNetworks []netip.Prefix
if found {
for !subString.Empty() {
if !subString.ReadASN1(&val, asn1.OCTET_STRING) || val.Empty() || len(val) > MaxNetworkLength {
return detailsV2{}, ErrBadFormat
}
var n netip.Prefix
if err := n.UnmarshalBinary(val); err != nil {
return detailsV2{}, ErrBadFormat
}
unsafeNetworks = append(unsafeNetworks, n)
}
}
// Read out any groups
if !b.ReadOptionalASN1(&subString, &found, TagDetailsGroups) {
return detailsV2{}, ErrBadFormat
}
var groups []string
if found {
for !subString.Empty() {
if !subString.ReadASN1(&val, asn1.UTF8String) || val.Empty() {
return detailsV2{}, ErrBadFormat
}
groups = append(groups, string(val))
}
}
// Read out IsCA
var isCa bool
if !readOptionalASN1Boolean(&b, &isCa, TagDetailsIsCA, false) {
return detailsV2{}, ErrBadFormat
}
// Read not before and not after
var notBefore int64
if !b.ReadASN1Int64WithTag(&notBefore, TagDetailsNotBefore) {
return detailsV2{}, ErrBadFormat
}
var notAfter int64
if !b.ReadASN1Int64WithTag(&notAfter, TagDetailsNotAfter) {
return detailsV2{}, ErrBadFormat
}
// Read issuer
var issuer cryptobyte.String
if !b.ReadOptionalASN1(&issuer, nil, TagDetailsIssuer) {
return detailsV2{}, ErrBadFormat
}
slices.SortFunc(networks, comparePrefix)
slices.SortFunc(unsafeNetworks, comparePrefix)
return detailsV2{
name: string(name),
networks: networks,
unsafeNetworks: unsafeNetworks,
groups: groups,
isCA: isCa,
notBefore: time.Unix(notBefore, 0),
notAfter: time.Unix(notAfter, 0),
issuer: hex.EncodeToString(issuer),
}, nil
}

View File

@ -24,4 +24,7 @@ var (
ErrInvalidPEMX25519PrivateKeyBanner = errors.New("bytes did not contain a proper X25519 private key banner") ErrInvalidPEMX25519PrivateKeyBanner = errors.New("bytes did not contain a proper X25519 private key banner")
ErrInvalidPEMEd25519PublicKeyBanner = errors.New("bytes did not contain a proper Ed25519 public key banner") ErrInvalidPEMEd25519PublicKeyBanner = errors.New("bytes did not contain a proper Ed25519 public key banner")
ErrInvalidPEMEd25519PrivateKeyBanner = errors.New("bytes did not contain a proper Ed25519 private key banner") ErrInvalidPEMEd25519PrivateKeyBanner = errors.New("bytes did not contain a proper Ed25519 private key banner")
ErrNoPeerStaticKey = errors.New("no peer static key was present")
ErrNoPayload = errors.New("provided payload was empty")
) )

View File

@ -30,19 +30,24 @@ func UnmarshalCertificateFromPEM(b []byte) (Certificate, []byte, error) {
return nil, r, ErrInvalidPEMBlock return nil, r, ErrInvalidPEMBlock
} }
var c Certificate
var err error
switch p.Type { switch p.Type {
case CertificateBanner: case CertificateBanner:
c, err := unmarshalCertificateV1(p.Bytes, true) c, err = unmarshalCertificateV1(p.Bytes, nil)
if err != nil {
return nil, nil, err
}
return c, r, nil
case CertificateV2Banner: case CertificateV2Banner:
//TODO c, err = unmarshalCertificateV2(p.Bytes, nil, Curve_CURVE25519)
panic("TODO")
default: default:
return nil, r, ErrInvalidPEMCertificateBanner return nil, r, ErrInvalidPEMCertificateBanner
} }
if err != nil {
return nil, r, err
}
return c, r, nil
} }
func MarshalPublicKeyToPEM(curve Curve, b []byte) []byte { func MarshalPublicKeyToPEM(curve Curve, b []byte) []byte {

View File

@ -1,11 +1,16 @@
package cert package cert
import ( import (
"crypto/ecdsa"
"crypto/ed25519"
"crypto/elliptic"
"crypto/rand"
"crypto/sha256"
"fmt" "fmt"
"math/big"
"net/netip" "net/netip"
"slices"
"time" "time"
"github.com/slackhq/nebula/pkclient"
) )
// TBSCertificate represents a certificate intended to be signed. // TBSCertificate represents a certificate intended to be signed.
@ -24,27 +29,62 @@ type TBSCertificate struct {
issuer string issuer string
} }
type beingSignedCertificate interface {
// fromTBSCertificate copies the values from the TBSCertificate to this versions internal representation
fromTBSCertificate(*TBSCertificate) error
// marshalForSigning returns the bytes that should be signed
marshalForSigning() ([]byte, error)
// setSignature sets the signature for the certificate that has just been signed
setSignature([]byte) error
}
type SignerLambda func(certBytes []byte) ([]byte, error)
// Sign will create a sealed certificate using details provided by the TBSCertificate as long as those // Sign will create a sealed certificate using details provided by the TBSCertificate as long as those
// details do not violate constraints of the signing certificate. // details do not violate constraints of the signing certificate.
// If the TBSCertificate is a CA then signer must be nil. // If the TBSCertificate is a CA then signer must be nil.
func (t *TBSCertificate) Sign(signer Certificate, curve Curve, key []byte) (Certificate, error) { func (t *TBSCertificate) Sign(signer Certificate, curve Curve, key []byte) (Certificate, error) {
return t.sign(signer, curve, key, nil) switch t.Curve {
} case Curve_CURVE25519:
pk := ed25519.PrivateKey(key)
func (t *TBSCertificate) SignPkcs11(signer Certificate, curve Curve, client *pkclient.PKClient) (Certificate, error) { sp := func(certBytes []byte) ([]byte, error) {
if curve != Curve_P256 { sig := ed25519.Sign(pk, certBytes)
return nil, fmt.Errorf("only P256 is supported by PKCS#11") return sig, nil
}
return t.SignWith(signer, curve, sp)
case Curve_P256:
pk := &ecdsa.PrivateKey{
PublicKey: ecdsa.PublicKey{
Curve: elliptic.P256(),
},
// ref: https://github.com/golang/go/blob/go1.19/src/crypto/x509/sec1.go#L95
D: new(big.Int).SetBytes(key),
}
// ref: https://github.com/golang/go/blob/go1.19/src/crypto/x509/sec1.go#L119
pk.X, pk.Y = pk.Curve.ScalarBaseMult(key)
sp := func(certBytes []byte) ([]byte, error) {
// We need to hash first for ECDSA
// - https://pkg.go.dev/crypto/ecdsa#SignASN1
hashed := sha256.Sum256(certBytes)
return ecdsa.SignASN1(rand.Reader, pk, hashed[:])
}
return t.SignWith(signer, curve, sp)
default:
return nil, fmt.Errorf("invalid curve: %s", t.Curve)
} }
return t.sign(signer, curve, nil, client)
} }
func (t *TBSCertificate) sign(signer Certificate, curve Curve, key []byte, client *pkclient.PKClient) (Certificate, error) { // SignWith does the same thing as sign, but uses the function in `sp` to calculate the signature.
// You should only use SignWith if you do not have direct access to your private key.
func (t *TBSCertificate) SignWith(signer Certificate, curve Curve, sp SignerLambda) (Certificate, error) {
if curve != t.Curve { if curve != t.Curve {
return nil, fmt.Errorf("curve in cert and private key supplied don't match") return nil, fmt.Errorf("curve in cert and private key supplied don't match")
} }
//TODO: make sure we have all minimum properties to sign, like a public key //TODO: make sure we have all minimum properties to sign, like a public key
//TODO: we need to verify networks and unsafe networks (no duplicates, max of 1 of each version for v2 certs
if signer != nil { if signer != nil {
if t.IsCA { if t.IsCA {
@ -67,10 +107,55 @@ func (t *TBSCertificate) sign(signer Certificate, curve Curve, key []byte, clien
} }
} }
slices.SortFunc(t.Networks, comparePrefix)
slices.SortFunc(t.UnsafeNetworks, comparePrefix)
var c beingSignedCertificate
switch t.Version { switch t.Version {
case Version1: case Version1:
return signV1(t, curve, key, client) c = &certificateV1{}
err := c.fromTBSCertificate(t)
if err != nil {
return nil, err
}
case Version2:
c = &certificateV2{}
err := c.fromTBSCertificate(t)
if err != nil {
return nil, err
}
default: default:
return nil, fmt.Errorf("unknown cert version %d", t.Version) return nil, fmt.Errorf("unknown cert version %d", t.Version)
} }
certBytes, err := c.marshalForSigning()
if err != nil {
return nil, err
}
sig, err := sp(certBytes)
if err != nil {
return nil, err
}
//TODO: check if we have sig bytes?
err = c.setSignature(sig)
if err != nil {
return nil, err
}
sc, ok := c.(Certificate)
if !ok {
return nil, fmt.Errorf("invalid certificate")
}
return sc, nil
}
func comparePrefix(a, b netip.Prefix) int {
addr := a.Addr().Compare(b.Addr())
if addr == 0 {
return a.Bits() - b.Bits()
}
return addr
} }

View File

@ -27,34 +27,43 @@ type caFlags struct {
outCertPath *string outCertPath *string
outQRPath *string outQRPath *string
groups *string groups *string
ips *string networks *string
subnets *string unsafeNetworks *string
argonMemory *uint argonMemory *uint
argonIterations *uint argonIterations *uint
argonParallelism *uint argonParallelism *uint
encryption *bool encryption *bool
version *uint
curve *string curve *string
p11url *string p11url *string
// Deprecated options
ips *string
subnets *string
} }
func newCaFlags() *caFlags { func newCaFlags() *caFlags {
cf := caFlags{set: flag.NewFlagSet("ca", flag.ContinueOnError)} cf := caFlags{set: flag.NewFlagSet("ca", flag.ContinueOnError)}
cf.set.Usage = func() {} cf.set.Usage = func() {}
cf.name = cf.set.String("name", "", "Required: name of the certificate authority") cf.name = cf.set.String("name", "", "Required: name of the certificate authority")
cf.version = cf.set.Uint("version", uint(cert.Version2), "Optional: version of the certificate format to use")
cf.duration = cf.set.Duration("duration", time.Duration(time.Hour*8760), "Optional: amount of time the certificate should be valid for. Valid time units are seconds: \"s\", minutes: \"m\", hours: \"h\"") cf.duration = cf.set.Duration("duration", time.Duration(time.Hour*8760), "Optional: amount of time the certificate should be valid for. Valid time units are seconds: \"s\", minutes: \"m\", hours: \"h\"")
cf.outKeyPath = cf.set.String("out-key", "ca.key", "Optional: path to write the private key to") cf.outKeyPath = cf.set.String("out-key", "ca.key", "Optional: path to write the private key to")
cf.outCertPath = cf.set.String("out-crt", "ca.crt", "Optional: path to write the certificate to") cf.outCertPath = cf.set.String("out-crt", "ca.crt", "Optional: path to write the certificate to")
cf.outQRPath = cf.set.String("out-qr", "", "Optional: output a qr code image (png) of the certificate") cf.outQRPath = cf.set.String("out-qr", "", "Optional: output a qr code image (png) of the certificate")
cf.groups = cf.set.String("groups", "", "Optional: comma separated list of groups. This will limit which groups subordinate certs can use") cf.groups = cf.set.String("groups", "", "Optional: comma separated list of groups. This will limit which groups subordinate certs can use")
cf.ips = cf.set.String("ips", "", "Optional: comma separated list of ipv4 address and network in CIDR notation. This will limit which ipv4 addresses and networks subordinate certs can use for ip addresses") cf.networks = cf.set.String("networks", "", "Optional: comma separated list of ip address and network in CIDR notation. This will limit which ip addresses and networks subordinate certs can use in networks")
cf.subnets = cf.set.String("subnets", "", "Optional: comma separated list of ipv4 address and network in CIDR notation. This will limit which ipv4 addresses and networks subordinate certs can use in subnets") cf.unsafeNetworks = cf.set.String("unsafe-networks", "", "Optional: comma separated list of ip address and network in CIDR notation. This will limit which ip addresses and networks subordinate certs can use in unsafe networks")
cf.argonMemory = cf.set.Uint("argon-memory", 2*1024*1024, "Optional: Argon2 memory parameter (in KiB) used for encrypted private key passphrase") cf.argonMemory = cf.set.Uint("argon-memory", 2*1024*1024, "Optional: Argon2 memory parameter (in KiB) used for encrypted private key passphrase")
cf.argonParallelism = cf.set.Uint("argon-parallelism", 4, "Optional: Argon2 parallelism parameter used for encrypted private key passphrase") cf.argonParallelism = cf.set.Uint("argon-parallelism", 4, "Optional: Argon2 parallelism parameter used for encrypted private key passphrase")
cf.argonIterations = cf.set.Uint("argon-iterations", 1, "Optional: Argon2 iterations parameter used for encrypted private key passphrase") cf.argonIterations = cf.set.Uint("argon-iterations", 1, "Optional: Argon2 iterations parameter used for encrypted private key passphrase")
cf.encryption = cf.set.Bool("encrypt", false, "Optional: prompt for passphrase and write out-key in an encrypted format") cf.encryption = cf.set.Bool("encrypt", false, "Optional: prompt for passphrase and write out-key in an encrypted format")
cf.curve = cf.set.String("curve", "25519", "EdDSA/ECDSA Curve (25519, P256)") cf.curve = cf.set.String("curve", "25519", "EdDSA/ECDSA Curve (25519, P256)")
cf.p11url = p11Flag(cf.set) cf.p11url = p11Flag(cf.set)
cf.ips = cf.set.String("ips", "", "Deprecated, see -networks")
cf.subnets = cf.set.String("subnets", "", "Deprecated, see -unsafe-networks")
return &cf return &cf
} }
@ -113,36 +122,51 @@ func ca(args []string, out io.Writer, errOut io.Writer, pr PasswordReader) error
} }
} }
var ips []netip.Prefix version := cert.Version(*cf.version)
if *cf.ips != "" { if version != cert.Version1 && version != cert.Version2 {
for _, rs := range strings.Split(*cf.ips, ",") { return newHelpErrorf("-version must be either %v or %v", cert.Version1, cert.Version2)
}
var networks []netip.Prefix
if *cf.networks == "" && *cf.ips != "" {
// Pull up deprecated -ips flag if needed
*cf.networks = *cf.ips
}
if *cf.networks != "" {
for _, rs := range strings.Split(*cf.networks, ",") {
rs := strings.Trim(rs, " ") rs := strings.Trim(rs, " ")
if rs != "" { if rs != "" {
n, err := netip.ParsePrefix(rs) n, err := netip.ParsePrefix(rs)
if err != nil { if err != nil {
return newHelpErrorf("invalid ip definition: %s", err) return newHelpErrorf("invalid -networks definition: %s", rs)
} }
if !n.Addr().Is4() { if version == cert.Version1 && !n.Addr().Is4() {
return newHelpErrorf("invalid ip definition: can only be ipv4, have %s", rs) return newHelpErrorf("invalid -networks definition: v1 certificates can only be ipv4, have %s", rs)
} }
ips = append(ips, n) networks = append(networks, n)
} }
} }
} }
var subnets []netip.Prefix var unsafeNetworks []netip.Prefix
if *cf.subnets != "" { if *cf.unsafeNetworks == "" && *cf.subnets != "" {
for _, rs := range strings.Split(*cf.subnets, ",") { // Pull up deprecated -subnets flag if needed
*cf.unsafeNetworks = *cf.subnets
}
if *cf.unsafeNetworks != "" {
for _, rs := range strings.Split(*cf.unsafeNetworks, ",") {
rs := strings.Trim(rs, " ") rs := strings.Trim(rs, " ")
if rs != "" { if rs != "" {
n, err := netip.ParsePrefix(rs) n, err := netip.ParsePrefix(rs)
if err != nil { if err != nil {
return newHelpErrorf("invalid subnet definition: %s", err) return newHelpErrorf("invalid -unsafe-networks definition: %s", rs)
} }
if !n.Addr().Is4() { if version == cert.Version1 && !n.Addr().Is4() {
return newHelpErrorf("invalid subnet definition: can only be ipv4, have %s", rs) return newHelpErrorf("invalid -unsafe-networks definition: v1 certificates can only be ipv4, have %s", rs)
} }
subnets = append(subnets, n) unsafeNetworks = append(unsafeNetworks, n)
} }
} }
} }
@ -222,11 +246,11 @@ func ca(args []string, out io.Writer, errOut io.Writer, pr PasswordReader) error
} }
t := &cert.TBSCertificate{ t := &cert.TBSCertificate{
Version: cert.Version1, Version: version,
Name: *cf.name, Name: *cf.name,
Groups: groups, Groups: groups,
Networks: ips, Networks: networks,
UnsafeNetworks: subnets, UnsafeNetworks: unsafeNetworks,
NotBefore: time.Now(), NotBefore: time.Now(),
NotAfter: time.Now().Add(*cf.duration), NotAfter: time.Now().Add(*cf.duration),
PublicKey: pub, PublicKey: pub,
@ -248,7 +272,7 @@ func ca(args []string, out io.Writer, errOut io.Writer, pr PasswordReader) error
var b []byte var b []byte
if isP11 { if isP11 {
c, err = t.SignPkcs11(nil, curve, p11Client) c, err = t.SignWith(nil, curve, p11Client.SignASN1)
if err != nil { if err != nil {
return fmt.Errorf("error while signing with PKCS#11: %w", err) return fmt.Errorf("error while signing with PKCS#11: %w", err)
} }

View File

@ -43,9 +43,11 @@ func Test_caHelp(t *testing.T) {
" -groups string\n"+ " -groups string\n"+
" \tOptional: comma separated list of groups. This will limit which groups subordinate certs can use\n"+ " \tOptional: comma separated list of groups. This will limit which groups subordinate certs can use\n"+
" -ips string\n"+ " -ips string\n"+
" \tOptional: comma separated list of ipv4 address and network in CIDR notation. This will limit which ipv4 addresses and networks subordinate certs can use for ip addresses\n"+ " Deprecated, see -networks\n"+
" -name string\n"+ " -name string\n"+
" \tRequired: name of the certificate authority\n"+ " \tRequired: name of the certificate authority\n"+
" -networks string\n"+
" \tOptional: comma separated list of ip address and network in CIDR notation. This will limit which ip addresses and networks subordinate certs can use in networks\n"+
" -out-crt string\n"+ " -out-crt string\n"+
" \tOptional: path to write the certificate to (default \"ca.crt\")\n"+ " \tOptional: path to write the certificate to (default \"ca.crt\")\n"+
" -out-key string\n"+ " -out-key string\n"+
@ -54,7 +56,11 @@ func Test_caHelp(t *testing.T) {
" \tOptional: output a qr code image (png) of the certificate\n"+ " \tOptional: output a qr code image (png) of the certificate\n"+
optionalPkcs11String(" -pkcs11 string\n \tOptional: PKCS#11 URI to an existing private key\n")+ optionalPkcs11String(" -pkcs11 string\n \tOptional: PKCS#11 URI to an existing private key\n")+
" -subnets string\n"+ " -subnets string\n"+
" \tOptional: comma separated list of ipv4 address and network in CIDR notation. This will limit which ipv4 addresses and networks subordinate certs can use in subnets\n", " \tDeprecated, see -unsafe-networks\n"+
" -unsafe-networks string\n"+
" \tOptional: comma separated list of ip address and network in CIDR notation. This will limit which ip addresses and networks subordinate certs can use in unsafe networks\n"+
" -version uint\n"+
" \tOptional: version of the certificate format to use (default 2)\n",
ob.String(), ob.String(),
) )
} }
@ -83,25 +89,25 @@ func Test_ca(t *testing.T) {
// required args // required args
assertHelpError(t, ca( assertHelpError(t, ca(
[]string{"-out-key", "nope", "-out-crt", "nope", "duration", "100m"}, ob, eb, nopw, []string{"-version", "1", "-out-key", "nope", "-out-crt", "nope", "duration", "100m"}, ob, eb, nopw,
), "-name is required") ), "-name is required")
assert.Equal(t, "", ob.String()) assert.Equal(t, "", ob.String())
assert.Equal(t, "", eb.String()) assert.Equal(t, "", eb.String())
// ipv4 only ips // ipv4 only ips
assertHelpError(t, ca([]string{"-name", "ipv6", "-ips", "100::100/100"}, ob, eb, nopw), "invalid ip definition: can only be ipv4, have 100::100/100") assertHelpError(t, ca([]string{"-version", "1", "-name", "ipv6", "-ips", "100::100/100"}, ob, eb, nopw), "invalid -networks definition: v1 certificates can only be ipv4, have 100::100/100")
assert.Equal(t, "", ob.String()) assert.Equal(t, "", ob.String())
assert.Equal(t, "", eb.String()) assert.Equal(t, "", eb.String())
// ipv4 only subnets // ipv4 only subnets
assertHelpError(t, ca([]string{"-name", "ipv6", "-subnets", "100::100/100"}, ob, eb, nopw), "invalid subnet definition: can only be ipv4, have 100::100/100") assertHelpError(t, ca([]string{"-version", "1", "-name", "ipv6", "-subnets", "100::100/100"}, ob, eb, nopw), "invalid -unsafe-networks definition: v1 certificates can only be ipv4, have 100::100/100")
assert.Equal(t, "", ob.String()) assert.Equal(t, "", ob.String())
assert.Equal(t, "", eb.String()) assert.Equal(t, "", eb.String())
// failed key write // failed key write
ob.Reset() ob.Reset()
eb.Reset() eb.Reset()
args := []string{"-name", "test", "-duration", "100m", "-out-crt", "/do/not/write/pleasecrt", "-out-key", "/do/not/write/pleasekey"} args := []string{"-version", "1", "-name", "test", "-duration", "100m", "-out-crt", "/do/not/write/pleasecrt", "-out-key", "/do/not/write/pleasekey"}
assert.EqualError(t, ca(args, ob, eb, nopw), "error while writing out-key: open /do/not/write/pleasekey: "+NoSuchDirError) assert.EqualError(t, ca(args, ob, eb, nopw), "error while writing out-key: open /do/not/write/pleasekey: "+NoSuchDirError)
assert.Equal(t, "", ob.String()) assert.Equal(t, "", ob.String())
assert.Equal(t, "", eb.String()) assert.Equal(t, "", eb.String())
@ -114,7 +120,7 @@ func Test_ca(t *testing.T) {
// failed cert write // failed cert write
ob.Reset() ob.Reset()
eb.Reset() eb.Reset()
args = []string{"-name", "test", "-duration", "100m", "-out-crt", "/do/not/write/pleasecrt", "-out-key", keyF.Name()} args = []string{"-version", "1", "-name", "test", "-duration", "100m", "-out-crt", "/do/not/write/pleasecrt", "-out-key", keyF.Name()}
assert.EqualError(t, ca(args, ob, eb, nopw), "error while writing out-crt: open /do/not/write/pleasecrt: "+NoSuchDirError) assert.EqualError(t, ca(args, ob, eb, nopw), "error while writing out-crt: open /do/not/write/pleasecrt: "+NoSuchDirError)
assert.Equal(t, "", ob.String()) assert.Equal(t, "", ob.String())
assert.Equal(t, "", eb.String()) assert.Equal(t, "", eb.String())
@ -128,7 +134,7 @@ func Test_ca(t *testing.T) {
// test proper cert with removed empty groups and subnets // test proper cert with removed empty groups and subnets
ob.Reset() ob.Reset()
eb.Reset() eb.Reset()
args = []string{"-name", "test", "-duration", "100m", "-groups", "1,, 2 , ,,,3,4,5", "-out-crt", crtF.Name(), "-out-key", keyF.Name()} args = []string{"-version", "1", "-name", "test", "-duration", "100m", "-groups", "1,, 2 , ,,,3,4,5", "-out-crt", crtF.Name(), "-out-key", keyF.Name()}
assert.Nil(t, ca(args, ob, eb, nopw)) assert.Nil(t, ca(args, ob, eb, nopw))
assert.Equal(t, "", ob.String()) assert.Equal(t, "", ob.String())
assert.Equal(t, "", eb.String()) assert.Equal(t, "", eb.String())
@ -161,7 +167,7 @@ func Test_ca(t *testing.T) {
os.Remove(crtF.Name()) os.Remove(crtF.Name())
ob.Reset() ob.Reset()
eb.Reset() eb.Reset()
args = []string{"-encrypt", "-name", "test", "-duration", "100m", "-groups", "1,2,3,4,5", "-out-crt", crtF.Name(), "-out-key", keyF.Name()} args = []string{"-version", "1", "-encrypt", "-name", "test", "-duration", "100m", "-groups", "1,2,3,4,5", "-out-crt", crtF.Name(), "-out-key", keyF.Name()}
assert.Nil(t, ca(args, ob, eb, testpw)) assert.Nil(t, ca(args, ob, eb, testpw))
assert.Equal(t, pwPromptOb, ob.String()) assert.Equal(t, pwPromptOb, ob.String())
assert.Equal(t, "", eb.String()) assert.Equal(t, "", eb.String())
@ -189,7 +195,7 @@ func Test_ca(t *testing.T) {
os.Remove(crtF.Name()) os.Remove(crtF.Name())
ob.Reset() ob.Reset()
eb.Reset() eb.Reset()
args = []string{"-encrypt", "-name", "test", "-duration", "100m", "-groups", "1,2,3,4,5", "-out-crt", crtF.Name(), "-out-key", keyF.Name()} args = []string{"-version", "1", "-encrypt", "-name", "test", "-duration", "100m", "-groups", "1,2,3,4,5", "-out-crt", crtF.Name(), "-out-key", keyF.Name()}
assert.Error(t, ca(args, ob, eb, errpw)) assert.Error(t, ca(args, ob, eb, errpw))
assert.Equal(t, pwPromptOb, ob.String()) assert.Equal(t, pwPromptOb, ob.String())
assert.Equal(t, "", eb.String()) assert.Equal(t, "", eb.String())
@ -199,7 +205,7 @@ func Test_ca(t *testing.T) {
os.Remove(crtF.Name()) os.Remove(crtF.Name())
ob.Reset() ob.Reset()
eb.Reset() eb.Reset()
args = []string{"-encrypt", "-name", "test", "-duration", "100m", "-groups", "1,2,3,4,5", "-out-crt", crtF.Name(), "-out-key", keyF.Name()} args = []string{"-version", "1", "-encrypt", "-name", "test", "-duration", "100m", "-groups", "1,2,3,4,5", "-out-crt", crtF.Name(), "-out-key", keyF.Name()}
assert.EqualError(t, ca(args, ob, eb, nopw), "no passphrase specified, remove -encrypt flag to write out-key in plaintext") assert.EqualError(t, ca(args, ob, eb, nopw), "no passphrase specified, remove -encrypt flag to write out-key in plaintext")
assert.Equal(t, strings.Repeat(pwPromptOb, 5), ob.String()) // prompts 5 times before giving up assert.Equal(t, strings.Repeat(pwPromptOb, 5), ob.String()) // prompts 5 times before giving up
assert.Equal(t, "", eb.String()) assert.Equal(t, "", eb.String())
@ -209,13 +215,13 @@ func Test_ca(t *testing.T) {
os.Remove(crtF.Name()) os.Remove(crtF.Name())
ob.Reset() ob.Reset()
eb.Reset() eb.Reset()
args = []string{"-name", "test", "-duration", "100m", "-groups", "1,, 2 , ,,,3,4,5", "-out-crt", crtF.Name(), "-out-key", keyF.Name()} args = []string{"-version", "1", "-name", "test", "-duration", "100m", "-groups", "1,, 2 , ,,,3,4,5", "-out-crt", crtF.Name(), "-out-key", keyF.Name()}
assert.Nil(t, ca(args, ob, eb, nopw)) assert.Nil(t, ca(args, ob, eb, nopw))
// test that we won't overwrite existing certificate file // test that we won't overwrite existing certificate file
ob.Reset() ob.Reset()
eb.Reset() eb.Reset()
args = []string{"-name", "test", "-duration", "100m", "-groups", "1,, 2 , ,,,3,4,5", "-out-crt", crtF.Name(), "-out-key", keyF.Name()} args = []string{"-version", "1", "-name", "test", "-duration", "100m", "-groups", "1,, 2 , ,,,3,4,5", "-out-crt", crtF.Name(), "-out-key", keyF.Name()}
assert.EqualError(t, ca(args, ob, eb, nopw), "refusing to overwrite existing CA key: "+keyF.Name()) assert.EqualError(t, ca(args, ob, eb, nopw), "refusing to overwrite existing CA key: "+keyF.Name())
assert.Equal(t, "", ob.String()) assert.Equal(t, "", ob.String())
assert.Equal(t, "", eb.String()) assert.Equal(t, "", eb.String())
@ -224,7 +230,7 @@ func Test_ca(t *testing.T) {
os.Remove(keyF.Name()) os.Remove(keyF.Name())
ob.Reset() ob.Reset()
eb.Reset() eb.Reset()
args = []string{"-name", "test", "-duration", "100m", "-groups", "1,, 2 , ,,,3,4,5", "-out-crt", crtF.Name(), "-out-key", keyF.Name()} args = []string{"-version", "1", "-name", "test", "-duration", "100m", "-groups", "1,, 2 , ,,,3,4,5", "-out-crt", crtF.Name(), "-out-key", keyF.Name()}
assert.EqualError(t, ca(args, ob, eb, nopw), "refusing to overwrite existing CA cert: "+crtF.Name()) assert.EqualError(t, ca(args, ob, eb, nopw), "refusing to overwrite existing CA cert: "+crtF.Name())
assert.Equal(t, "", ob.String()) assert.Equal(t, "", ob.String())
assert.Equal(t, "", eb.String()) assert.Equal(t, "", eb.String())

View File

@ -49,6 +49,8 @@ func printCert(args []string, out io.Writer, errOut io.Writer) error {
var qrBytes []byte var qrBytes []byte
part := 0 part := 0
var jsonCerts []cert.Certificate
for { for {
c, rawCert, err = cert.UnmarshalCertificateFromPEM(rawCert) c, rawCert, err = cert.UnmarshalCertificateFromPEM(rawCert)
if err != nil { if err != nil {
@ -56,13 +58,10 @@ func printCert(args []string, out io.Writer, errOut io.Writer) error {
} }
if *pf.json { if *pf.json {
b, _ := json.Marshal(c) jsonCerts = append(jsonCerts, c)
out.Write(b)
out.Write([]byte("\n"))
} else { } else {
out.Write([]byte(c.String())) _, _ = out.Write([]byte(c.String()))
out.Write([]byte("\n")) _, _ = out.Write([]byte("\n"))
} }
if *pf.outQRPath != "" { if *pf.outQRPath != "" {
@ -80,6 +79,12 @@ func printCert(args []string, out io.Writer, errOut io.Writer) error {
part++ part++
} }
if *pf.json {
b, _ := json.Marshal(jsonCerts)
_, _ = out.Write(b)
_, _ = out.Write([]byte("\n"))
}
if *pf.outQRPath != "" { if *pf.outQRPath != "" {
b, err := qrcode.Encode(string(qrBytes), qrcode.Medium, -5) b, err := qrcode.Encode(string(qrBytes), qrcode.Medium, -5)
if err != nil { if err != nil {

View File

@ -87,7 +87,65 @@ func Test_printCert(t *testing.T) {
assert.Nil(t, err) assert.Nil(t, err)
assert.Equal( assert.Equal(
t, t,
"NebulaCertificate {\n\tDetails {\n\t\tName: test\n\t\tIps: []\n\t\tSubnets: []\n\t\tGroups: [\n\t\t\t\"hi\"\n\t\t]\n\t\tNot before: 0001-01-01 00:00:00 +0000 UTC\n\t\tNot After: 0001-01-01 00:00:00 +0000 UTC\n\t\tIs CA: false\n\t\tIssuer: "+c.Issuer()+"\n\t\tPublic key: "+pk+"\n\t\tCurve: CURVE25519\n\t}\n\tFingerprint: "+fp+"\n\tSignature: "+sig+"\n}\nNebulaCertificate {\n\tDetails {\n\t\tName: test\n\t\tIps: []\n\t\tSubnets: []\n\t\tGroups: [\n\t\t\t\"hi\"\n\t\t]\n\t\tNot before: 0001-01-01 00:00:00 +0000 UTC\n\t\tNot After: 0001-01-01 00:00:00 +0000 UTC\n\t\tIs CA: false\n\t\tIssuer: "+c.Issuer()+"\n\t\tPublic key: "+pk+"\n\t\tCurve: CURVE25519\n\t}\n\tFingerprint: "+fp+"\n\tSignature: "+sig+"\n}\nNebulaCertificate {\n\tDetails {\n\t\tName: test\n\t\tIps: []\n\t\tSubnets: []\n\t\tGroups: [\n\t\t\t\"hi\"\n\t\t]\n\t\tNot before: 0001-01-01 00:00:00 +0000 UTC\n\t\tNot After: 0001-01-01 00:00:00 +0000 UTC\n\t\tIs CA: false\n\t\tIssuer: "+c.Issuer()+"\n\t\tPublic key: "+pk+"\n\t\tCurve: CURVE25519\n\t}\n\tFingerprint: "+fp+"\n\tSignature: "+sig+"\n}\n", //"NebulaCertificate {\n\tDetails {\n\t\tName: test\n\t\tIps: []\n\t\tSubnets: []\n\t\tGroups: [\n\t\t\t\"hi\"\n\t\t]\n\t\tNot before: 0001-01-01 00:00:00 +0000 UTC\n\t\tNot After: 0001-01-01 00:00:00 +0000 UTC\n\t\tIs CA: false\n\t\tIssuer: "+c.Issuer()+"\n\t\tPublic key: "+pk+"\n\t\tCurve: CURVE25519\n\t}\n\tFingerprint: "+fp+"\n\tSignature: "+sig+"\n}\nNebulaCertificate {\n\tDetails {\n\t\tName: test\n\t\tIps: []\n\t\tSubnets: []\n\t\tGroups: [\n\t\t\t\"hi\"\n\t\t]\n\t\tNot before: 0001-01-01 00:00:00 +0000 UTC\n\t\tNot After: 0001-01-01 00:00:00 +0000 UTC\n\t\tIs CA: false\n\t\tIssuer: "+c.Issuer()+"\n\t\tPublic key: "+pk+"\n\t\tCurve: CURVE25519\n\t}\n\tFingerprint: "+fp+"\n\tSignature: "+sig+"\n}\nNebulaCertificate {\n\tDetails {\n\t\tName: test\n\t\tIps: []\n\t\tSubnets: []\n\t\tGroups: [\n\t\t\t\"hi\"\n\t\t]\n\t\tNot before: 0001-01-01 00:00:00 +0000 UTC\n\t\tNot After: 0001-01-01 00:00:00 +0000 UTC\n\t\tIs CA: false\n\t\tIssuer: "+c.Issuer()+"\n\t\tPublic key: "+pk+"\n\t\tCurve: CURVE25519\n\t}\n\tFingerprint: "+fp+"\n\tSignature: "+sig+"\n}\n",
`{
"details": {
"curve": "CURVE25519",
"groups": [
"hi"
],
"isCa": false,
"issuer": "`+c.Issuer()+`",
"name": "test",
"networks": [],
"notAfter": "0001-01-01T00:00:00Z",
"notBefore": "0001-01-01T00:00:00Z",
"publicKey": "`+pk+`",
"unsafeNetworks": []
},
"fingerprint": "`+fp+`",
"signature": "`+sig+`",
"version": 1
}
{
"details": {
"curve": "CURVE25519",
"groups": [
"hi"
],
"isCa": false,
"issuer": "`+c.Issuer()+`",
"name": "test",
"networks": [],
"notAfter": "0001-01-01T00:00:00Z",
"notBefore": "0001-01-01T00:00:00Z",
"publicKey": "`+pk+`",
"unsafeNetworks": []
},
"fingerprint": "`+fp+`",
"signature": "`+sig+`",
"version": 1
}
{
"details": {
"curve": "CURVE25519",
"groups": [
"hi"
],
"isCa": false,
"issuer": "`+c.Issuer()+`",
"name": "test",
"networks": [],
"notAfter": "0001-01-01T00:00:00Z",
"notBefore": "0001-01-01T00:00:00Z",
"publicKey": "`+pk+`",
"unsafeNetworks": []
},
"fingerprint": "`+fp+`",
"signature": "`+sig+`",
"version": 1
}
`,
ob.String(), ob.String(),
) )
assert.Equal(t, "", eb.String()) assert.Equal(t, "", eb.String())
@ -108,7 +166,8 @@ func Test_printCert(t *testing.T) {
assert.Nil(t, err) assert.Nil(t, err)
assert.Equal( assert.Equal(
t, t,
"{\"details\":{\"curve\":\"CURVE25519\",\"groups\":[\"hi\"],\"ips\":[],\"isCa\":false,\"issuer\":\""+c.Issuer()+"\",\"name\":\"test\",\"notAfter\":\"0001-01-01T00:00:00Z\",\"notBefore\":\"0001-01-01T00:00:00Z\",\"publicKey\":\""+pk+"\",\"subnets\":[]},\"fingerprint\":\""+fp+"\",\"signature\":\""+sig+"\"}\n{\"details\":{\"curve\":\"CURVE25519\",\"groups\":[\"hi\"],\"ips\":[],\"isCa\":false,\"issuer\":\""+c.Issuer()+"\",\"name\":\"test\",\"notAfter\":\"0001-01-01T00:00:00Z\",\"notBefore\":\"0001-01-01T00:00:00Z\",\"publicKey\":\""+pk+"\",\"subnets\":[]},\"fingerprint\":\""+fp+"\",\"signature\":\""+sig+"\"}\n{\"details\":{\"curve\":\"CURVE25519\",\"groups\":[\"hi\"],\"ips\":[],\"isCa\":false,\"issuer\":\""+c.Issuer()+"\",\"name\":\"test\",\"notAfter\":\"0001-01-01T00:00:00Z\",\"notBefore\":\"0001-01-01T00:00:00Z\",\"publicKey\":\""+pk+"\",\"subnets\":[]},\"fingerprint\":\""+fp+"\",\"signature\":\""+sig+"\"}\n", `[{"details":{"curve":"CURVE25519","groups":["hi"],"isCa":false,"issuer":"`+c.Issuer()+`","name":"test","networks":[],"notAfter":"0001-01-01T00:00:00Z","notBefore":"0001-01-01T00:00:00Z","publicKey":"`+pk+`","unsafeNetworks":[]},"fingerprint":"`+fp+`","signature":"`+sig+`","version":1},{"details":{"curve":"CURVE25519","groups":["hi"],"isCa":false,"issuer":"`+c.Issuer()+`","name":"test","networks":[],"notAfter":"0001-01-01T00:00:00Z","notBefore":"0001-01-01T00:00:00Z","publicKey":"`+pk+`","unsafeNetworks":[]},"fingerprint":"`+fp+`","signature":"`+sig+`","version":1},{"details":{"curve":"CURVE25519","groups":["hi"],"isCa":false,"issuer":"`+c.Issuer()+`","name":"test","networks":[],"notAfter":"0001-01-01T00:00:00Z","notBefore":"0001-01-01T00:00:00Z","publicKey":"`+pk+`","unsafeNetworks":[]},"fingerprint":"`+fp+`","signature":"`+sig+`","version":1}]
`,
ob.String(), ob.String(),
) )
assert.Equal(t, "", eb.String()) assert.Equal(t, "", eb.String())

View File

@ -3,6 +3,7 @@ package main
import ( import (
"crypto/ecdh" "crypto/ecdh"
"crypto/rand" "crypto/rand"
"errors"
"flag" "flag"
"fmt" "fmt"
"io" "io"
@ -19,35 +20,45 @@ import (
type signFlags struct { type signFlags struct {
set *flag.FlagSet set *flag.FlagSet
version *uint
caKeyPath *string caKeyPath *string
caCertPath *string caCertPath *string
name *string name *string
ip *string networks *string
unsafeNetworks *string
duration *time.Duration duration *time.Duration
inPubPath *string inPubPath *string
outKeyPath *string outKeyPath *string
outCertPath *string outCertPath *string
outQRPath *string outQRPath *string
groups *string groups *string
subnets *string
p11url *string p11url *string
// Deprecated options
ip *string
subnets *string
} }
func newSignFlags() *signFlags { func newSignFlags() *signFlags {
sf := signFlags{set: flag.NewFlagSet("sign", flag.ContinueOnError)} sf := signFlags{set: flag.NewFlagSet("sign", flag.ContinueOnError)}
sf.set.Usage = func() {} sf.set.Usage = func() {}
sf.version = sf.set.Uint("version", 0, "Optional: version of the certificate format to use, the default is to create both v1 and v2 certificates.")
sf.caKeyPath = sf.set.String("ca-key", "ca.key", "Optional: path to the signing CA key") sf.caKeyPath = sf.set.String("ca-key", "ca.key", "Optional: path to the signing CA key")
sf.caCertPath = sf.set.String("ca-crt", "ca.crt", "Optional: path to the signing CA cert") sf.caCertPath = sf.set.String("ca-crt", "ca.crt", "Optional: path to the signing CA cert")
sf.name = sf.set.String("name", "", "Required: name of the cert, usually a hostname") sf.name = sf.set.String("name", "", "Required: name of the cert, usually a hostname")
sf.ip = sf.set.String("ip", "", "Required: ipv4 address and network in CIDR notation to assign the cert") sf.networks = sf.set.String("networks", "", "Required: comma separated list of ip address and network in CIDR notation to assign to this cert")
sf.unsafeNetworks = sf.set.String("unsafe-networks", "", "Optional: comma separated list of ip address and network in CIDR notation. Unsafe networks this cert can route for")
sf.duration = sf.set.Duration("duration", 0, "Optional: how long the cert should be valid for. The default is 1 second before the signing cert expires. Valid time units are seconds: \"s\", minutes: \"m\", hours: \"h\"") sf.duration = sf.set.Duration("duration", 0, "Optional: how long the cert should be valid for. The default is 1 second before the signing cert expires. Valid time units are seconds: \"s\", minutes: \"m\", hours: \"h\"")
sf.inPubPath = sf.set.String("in-pub", "", "Optional (if out-key not set): path to read a previously generated public key") sf.inPubPath = sf.set.String("in-pub", "", "Optional (if out-key not set): path to read a previously generated public key")
sf.outKeyPath = sf.set.String("out-key", "", "Optional (if in-pub not set): path to write the private key to") sf.outKeyPath = sf.set.String("out-key", "", "Optional (if in-pub not set): path to write the private key to")
sf.outCertPath = sf.set.String("out-crt", "", "Optional: path to write the certificate to") sf.outCertPath = sf.set.String("out-crt", "", "Optional: path to write the certificate to")
sf.outQRPath = sf.set.String("out-qr", "", "Optional: output a qr code image (png) of the certificate") sf.outQRPath = sf.set.String("out-qr", "", "Optional: output a qr code image (png) of the certificate")
sf.groups = sf.set.String("groups", "", "Optional: comma separated list of groups") sf.groups = sf.set.String("groups", "", "Optional: comma separated list of groups")
sf.subnets = sf.set.String("subnets", "", "Optional: comma separated list of ipv4 address and network in CIDR notation. Subnets this cert can serve for")
sf.p11url = p11Flag(sf.set) sf.p11url = p11Flag(sf.set)
sf.ip = sf.set.String("ip", "", "Deprecated, see -networks")
sf.subnets = sf.set.String("subnets", "", "Deprecated, see -unsafe-networks")
return &sf return &sf
} }
@ -71,13 +82,26 @@ func signCert(args []string, out io.Writer, errOut io.Writer, pr PasswordReader)
if err := mustFlagString("name", sf.name); err != nil { if err := mustFlagString("name", sf.name); err != nil {
return err return err
} }
if err := mustFlagString("ip", sf.ip); err != nil {
return err
}
if !isP11 && *sf.inPubPath != "" && *sf.outKeyPath != "" { if !isP11 && *sf.inPubPath != "" && *sf.outKeyPath != "" {
return newHelpErrorf("cannot set both -in-pub and -out-key") return newHelpErrorf("cannot set both -in-pub and -out-key")
} }
var v4Networks []netip.Prefix
var v6Networks []netip.Prefix
if *sf.networks == "" && *sf.ip != "" {
// Pull up deprecated -ip flag if needed
*sf.networks = *sf.ip
}
if len(*sf.networks) == 0 {
return newHelpErrorf("-networks is required")
}
version := cert.Version(*sf.version)
if version != 0 && version != cert.Version1 && version != cert.Version2 {
return newHelpErrorf("-version must be either %v or %v", cert.Version1, cert.Version2)
}
var curve cert.Curve var curve cert.Curve
var caKey []byte var caKey []byte
@ -91,14 +115,14 @@ func signCert(args []string, out io.Writer, errOut io.Writer, pr PasswordReader)
// naively attempt to decode the private key as though it is not encrypted // naively attempt to decode the private key as though it is not encrypted
caKey, _, curve, err = cert.UnmarshalSigningPrivateKeyFromPEM(rawCAKey) caKey, _, curve, err = cert.UnmarshalSigningPrivateKeyFromPEM(rawCAKey)
if err == cert.ErrPrivateKeyEncrypted { if errors.Is(err, cert.ErrPrivateKeyEncrypted) {
// ask for a passphrase until we get one // ask for a passphrase until we get one
var passphrase []byte var passphrase []byte
for i := 0; i < 5; i++ { for i := 0; i < 5; i++ {
out.Write([]byte("Enter passphrase: ")) out.Write([]byte("Enter passphrase: "))
passphrase, err = pr.ReadPassword() passphrase, err = pr.ReadPassword()
if err == ErrNoTerminal { if errors.Is(err, ErrNoTerminal) {
return fmt.Errorf("ca-key is encrypted and must be decrypted interactively") return fmt.Errorf("ca-key is encrypted and must be decrypted interactively")
} else if err != nil { } else if err != nil {
return fmt.Errorf("error reading password: %s", err) return fmt.Errorf("error reading password: %s", err)
@ -146,12 +170,49 @@ func signCert(args []string, out io.Writer, errOut io.Writer, pr PasswordReader)
*sf.duration = time.Until(caCert.NotAfter()) - time.Second*1 *sf.duration = time.Until(caCert.NotAfter()) - time.Second*1
} }
network, err := netip.ParsePrefix(*sf.ip) if *sf.networks != "" {
for _, rs := range strings.Split(*sf.networks, ",") {
//TODO: error on duplicates? Mainly only addr matters, having two of the same addr in the same or different prefix space is strange
rs := strings.Trim(rs, " ")
if rs != "" {
n, err := netip.ParsePrefix(rs)
if err != nil { if err != nil {
return newHelpErrorf("invalid ip definition: %s", *sf.ip) return newHelpErrorf("invalid -networks definition: %s", rs)
}
if n.Addr().Is4() {
v4Networks = append(v4Networks, n)
} else {
v6Networks = append(v6Networks, n)
}
}
}
}
var v4UnsafeNetworks []netip.Prefix
var v6UnsafeNetworks []netip.Prefix
if *sf.unsafeNetworks == "" && *sf.subnets != "" {
// Pull up deprecated -subnets flag if needed
*sf.unsafeNetworks = *sf.subnets
}
if *sf.unsafeNetworks != "" {
//TODO: error on duplicates?
for _, rs := range strings.Split(*sf.unsafeNetworks, ",") {
rs := strings.Trim(rs, " ")
if rs != "" {
n, err := netip.ParsePrefix(rs)
if err != nil {
return newHelpErrorf("invalid -unsafe-networks definition: %s", rs)
}
if n.Addr().Is4() {
v4UnsafeNetworks = append(v4UnsafeNetworks, n)
} else {
v6UnsafeNetworks = append(v6UnsafeNetworks, n)
}
}
} }
if !network.Addr().Is4() {
return newHelpErrorf("invalid ip definition: can only be ipv4, have %s", *sf.ip)
} }
var groups []string var groups []string
@ -164,23 +225,6 @@ func signCert(args []string, out io.Writer, errOut io.Writer, pr PasswordReader)
} }
} }
var subnets []netip.Prefix
if *sf.subnets != "" {
for _, rs := range strings.Split(*sf.subnets, ",") {
rs := strings.Trim(rs, " ")
if rs != "" {
s, err := netip.ParsePrefix(rs)
if err != nil {
return newHelpErrorf("invalid subnet definition: %s", rs)
}
if !s.Addr().Is4() {
return newHelpErrorf("invalid subnet definition: can only be ipv4, have %s", rs)
}
subnets = append(subnets, s)
}
}
}
var pub, rawPriv []byte var pub, rawPriv []byte
var p11Client *pkclient.PKClient var p11Client *pkclient.PKClient
@ -218,19 +262,6 @@ func signCert(args []string, out io.Writer, errOut io.Writer, pr PasswordReader)
pub, rawPriv = newKeypair(curve) pub, rawPriv = newKeypair(curve)
} }
t := &cert.TBSCertificate{
Version: cert.Version1,
Name: *sf.name,
Networks: []netip.Prefix{network},
Groups: groups,
UnsafeNetworks: subnets,
NotBefore: time.Now(),
NotAfter: time.Now().Add(*sf.duration),
PublicKey: pub,
IsCA: false,
Curve: curve,
}
if *sf.outKeyPath == "" { if *sf.outKeyPath == "" {
*sf.outKeyPath = *sf.name + ".key" *sf.outKeyPath = *sf.name + ".key"
} }
@ -243,20 +274,87 @@ func signCert(args []string, out io.Writer, errOut io.Writer, pr PasswordReader)
return fmt.Errorf("refusing to overwrite existing cert: %s", *sf.outCertPath) return fmt.Errorf("refusing to overwrite existing cert: %s", *sf.outCertPath)
} }
var c cert.Certificate var crts []cert.Certificate
notBefore := time.Now()
notAfter := notBefore.Add(*sf.duration)
if version == 0 || version == cert.Version1 {
// Make sure we at least have an ip
if len(v4Networks) != 1 {
return newHelpErrorf("invalid -networks definition: v1 certificates can only have a single ipv4 address")
}
if version == cert.Version1 {
// If we are asked to mint a v1 certificate only then we cant just ignore any v6 addresses
if len(v6Networks) > 0 {
return newHelpErrorf("invalid -networks definition: v1 certificates can only be ipv4")
}
if len(v6UnsafeNetworks) > 0 {
return newHelpErrorf("invalid -unsafe-networks definition: v1 certificates can only be ipv4")
}
}
t := &cert.TBSCertificate{
Version: cert.Version1,
Name: *sf.name,
Networks: []netip.Prefix{v4Networks[0]},
Groups: groups,
UnsafeNetworks: v4UnsafeNetworks,
NotBefore: notBefore,
NotAfter: notAfter,
PublicKey: pub,
IsCA: false,
Curve: curve,
}
var nc cert.Certificate
if p11Client == nil { if p11Client == nil {
c, err = t.Sign(caCert, curve, caKey) nc, err = t.Sign(caCert, curve, caKey)
if err != nil { if err != nil {
return fmt.Errorf("error while signing: %w", err) return fmt.Errorf("error while signing: %w", err)
} }
} else { } else {
c, err = t.SignPkcs11(caCert, curve, p11Client) nc, err = t.SignWith(caCert, curve, p11Client.SignASN1)
if err != nil { if err != nil {
return fmt.Errorf("error while signing with PKCS#11: %w", err) return fmt.Errorf("error while signing with PKCS#11: %w", err)
} }
} }
crts = append(crts, nc)
}
if version == 0 || version == cert.Version2 {
t := &cert.TBSCertificate{
Version: cert.Version2,
Name: *sf.name,
Networks: append(v4Networks, v6Networks...),
Groups: groups,
UnsafeNetworks: append(v4UnsafeNetworks, v6UnsafeNetworks...),
NotBefore: notBefore,
NotAfter: notAfter,
PublicKey: pub,
IsCA: false,
Curve: curve,
}
var nc cert.Certificate
if p11Client == nil {
nc, err = t.Sign(caCert, curve, caKey)
if err != nil {
return fmt.Errorf("error while signing: %w", err)
}
} else {
nc, err = t.SignWith(caCert, curve, p11Client.SignASN1)
if err != nil {
return fmt.Errorf("error while signing with PKCS#11: %w", err)
}
}
crts = append(crts, nc)
}
if !isP11 && *sf.inPubPath == "" { if !isP11 && *sf.inPubPath == "" {
if _, err := os.Stat(*sf.outKeyPath); err == nil { if _, err := os.Stat(*sf.outKeyPath); err == nil {
return fmt.Errorf("refusing to overwrite existing key: %s", *sf.outKeyPath) return fmt.Errorf("refusing to overwrite existing key: %s", *sf.outKeyPath)
@ -268,10 +366,14 @@ func signCert(args []string, out io.Writer, errOut io.Writer, pr PasswordReader)
} }
} }
b, err := c.MarshalPEM() var b []byte
for _, c := range crts {
sb, err := c.MarshalPEM()
if err != nil { if err != nil {
return fmt.Errorf("error while marshalling certificate: %s", err) return fmt.Errorf("error while marshalling certificate: %s", err)
} }
b = append(b, sb...)
}
err = os.WriteFile(*sf.outCertPath, b, 0600) err = os.WriteFile(*sf.outCertPath, b, 0600)
if err != nil { if err != nil {

View File

@ -39,9 +39,11 @@ func Test_signHelp(t *testing.T) {
" -in-pub string\n"+ " -in-pub string\n"+
" \tOptional (if out-key not set): path to read a previously generated public key\n"+ " \tOptional (if out-key not set): path to read a previously generated public key\n"+
" -ip string\n"+ " -ip string\n"+
" \tRequired: ipv4 address and network in CIDR notation to assign the cert\n"+ " \tDeprecated, see -networks\n"+
" -name string\n"+ " -name string\n"+
" \tRequired: name of the cert, usually a hostname\n"+ " \tRequired: name of the cert, usually a hostname\n"+
" -networks string\n"+
" \tRequired: comma separated list of ip address and network in CIDR notation to assign to this cert\n"+
" -out-crt string\n"+ " -out-crt string\n"+
" \tOptional: path to write the certificate to\n"+ " \tOptional: path to write the certificate to\n"+
" -out-key string\n"+ " -out-key string\n"+
@ -50,7 +52,11 @@ func Test_signHelp(t *testing.T) {
" \tOptional: output a qr code image (png) of the certificate\n"+ " \tOptional: output a qr code image (png) of the certificate\n"+
optionalPkcs11String(" -pkcs11 string\n \tOptional: PKCS#11 URI to an existing private key\n")+ optionalPkcs11String(" -pkcs11 string\n \tOptional: PKCS#11 URI to an existing private key\n")+
" -subnets string\n"+ " -subnets string\n"+
" \tOptional: comma separated list of ipv4 address and network in CIDR notation. Subnets this cert can serve for\n", " \tDeprecated, see -unsafe-networks\n"+
" -unsafe-networks string\n"+
" \tOptional: comma separated list of ip address and network in CIDR notation. Unsafe networks this cert can route for\n"+
" -version uint\n"+
" \tOptional: version of the certificate format to use, the default is to create both v1 and v2 certificates.\n",
ob.String(), ob.String(),
) )
} }
@ -77,20 +83,20 @@ func Test_signCert(t *testing.T) {
// required args // required args
assertHelpError(t, signCert( assertHelpError(t, signCert(
[]string{"-ca-crt", "./nope", "-ca-key", "./nope", "-ip", "1.1.1.1/24", "-out-key", "nope", "-out-crt", "nope"}, ob, eb, nopw, []string{"-version", "1", "-ca-crt", "./nope", "-ca-key", "./nope", "-ip", "1.1.1.1/24", "-out-key", "nope", "-out-crt", "nope"}, ob, eb, nopw,
), "-name is required") ), "-name is required")
assert.Empty(t, ob.String()) assert.Empty(t, ob.String())
assert.Empty(t, eb.String()) assert.Empty(t, eb.String())
assertHelpError(t, signCert( assertHelpError(t, signCert(
[]string{"-ca-crt", "./nope", "-ca-key", "./nope", "-name", "test", "-out-key", "nope", "-out-crt", "nope"}, ob, eb, nopw, []string{"-version", "1", "-ca-crt", "./nope", "-ca-key", "./nope", "-name", "test", "-out-key", "nope", "-out-crt", "nope"}, ob, eb, nopw,
), "-ip is required") ), "-networks is required")
assert.Empty(t, ob.String()) assert.Empty(t, ob.String())
assert.Empty(t, eb.String()) assert.Empty(t, eb.String())
// cannot set -in-pub and -out-key // cannot set -in-pub and -out-key
assertHelpError(t, signCert( assertHelpError(t, signCert(
[]string{"-ca-crt", "./nope", "-ca-key", "./nope", "-name", "test", "-in-pub", "nope", "-ip", "1.1.1.1/24", "-out-crt", "nope", "-out-key", "nope"}, ob, eb, nopw, []string{"-version", "1", "-ca-crt", "./nope", "-ca-key", "./nope", "-name", "test", "-in-pub", "nope", "-ip", "1.1.1.1/24", "-out-crt", "nope", "-out-key", "nope"}, ob, eb, nopw,
), "cannot set both -in-pub and -out-key") ), "cannot set both -in-pub and -out-key")
assert.Empty(t, ob.String()) assert.Empty(t, ob.String())
assert.Empty(t, eb.String()) assert.Empty(t, eb.String())
@ -98,7 +104,7 @@ func Test_signCert(t *testing.T) {
// failed to read key // failed to read key
ob.Reset() ob.Reset()
eb.Reset() eb.Reset()
args := []string{"-ca-crt", "./nope", "-ca-key", "./nope", "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", "nope", "-out-key", "nope", "-duration", "100m"} args := []string{"-version", "1", "-ca-crt", "./nope", "-ca-key", "./nope", "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", "nope", "-out-key", "nope", "-duration", "100m"}
assert.EqualError(t, signCert(args, ob, eb, nopw), "error while reading ca-key: open ./nope: "+NoSuchFileError) assert.EqualError(t, signCert(args, ob, eb, nopw), "error while reading ca-key: open ./nope: "+NoSuchFileError)
// failed to unmarshal key // failed to unmarshal key
@ -108,7 +114,7 @@ func Test_signCert(t *testing.T) {
assert.Nil(t, err) assert.Nil(t, err)
defer os.Remove(caKeyF.Name()) defer os.Remove(caKeyF.Name())
args = []string{"-ca-crt", "./nope", "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", "nope", "-out-key", "nope", "-duration", "100m"} args = []string{"-version", "1", "-ca-crt", "./nope", "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", "nope", "-out-key", "nope", "-duration", "100m"}
assert.EqualError(t, signCert(args, ob, eb, nopw), "error while parsing ca-key: input did not contain a valid PEM encoded block") assert.EqualError(t, signCert(args, ob, eb, nopw), "error while parsing ca-key: input did not contain a valid PEM encoded block")
assert.Empty(t, ob.String()) assert.Empty(t, ob.String())
assert.Empty(t, eb.String()) assert.Empty(t, eb.String())
@ -120,7 +126,7 @@ func Test_signCert(t *testing.T) {
caKeyF.Write(cert.MarshalSigningPrivateKeyToPEM(cert.Curve_CURVE25519, caPriv)) caKeyF.Write(cert.MarshalSigningPrivateKeyToPEM(cert.Curve_CURVE25519, caPriv))
// failed to read cert // failed to read cert
args = []string{"-ca-crt", "./nope", "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", "nope", "-out-key", "nope", "-duration", "100m"} args = []string{"-version", "1", "-ca-crt", "./nope", "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", "nope", "-out-key", "nope", "-duration", "100m"}
assert.EqualError(t, signCert(args, ob, eb, nopw), "error while reading ca-crt: open ./nope: "+NoSuchFileError) assert.EqualError(t, signCert(args, ob, eb, nopw), "error while reading ca-crt: open ./nope: "+NoSuchFileError)
assert.Empty(t, ob.String()) assert.Empty(t, ob.String())
assert.Empty(t, eb.String()) assert.Empty(t, eb.String())
@ -132,7 +138,7 @@ func Test_signCert(t *testing.T) {
assert.Nil(t, err) assert.Nil(t, err)
defer os.Remove(caCrtF.Name()) defer os.Remove(caCrtF.Name())
args = []string{"-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", "nope", "-out-key", "nope", "-duration", "100m"} args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", "nope", "-out-key", "nope", "-duration", "100m"}
assert.EqualError(t, signCert(args, ob, eb, nopw), "error while parsing ca-crt: input did not contain a valid PEM encoded block") assert.EqualError(t, signCert(args, ob, eb, nopw), "error while parsing ca-crt: input did not contain a valid PEM encoded block")
assert.Empty(t, ob.String()) assert.Empty(t, ob.String())
assert.Empty(t, eb.String()) assert.Empty(t, eb.String())
@ -143,7 +149,7 @@ func Test_signCert(t *testing.T) {
caCrtF.Write(b) caCrtF.Write(b)
// failed to read pub // failed to read pub
args = []string{"-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", "nope", "-in-pub", "./nope", "-duration", "100m"} args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", "nope", "-in-pub", "./nope", "-duration", "100m"}
assert.EqualError(t, signCert(args, ob, eb, nopw), "error while reading in-pub: open ./nope: "+NoSuchFileError) assert.EqualError(t, signCert(args, ob, eb, nopw), "error while reading in-pub: open ./nope: "+NoSuchFileError)
assert.Empty(t, ob.String()) assert.Empty(t, ob.String())
assert.Empty(t, eb.String()) assert.Empty(t, eb.String())
@ -155,7 +161,7 @@ func Test_signCert(t *testing.T) {
assert.Nil(t, err) assert.Nil(t, err)
defer os.Remove(inPubF.Name()) defer os.Remove(inPubF.Name())
args = []string{"-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", "nope", "-in-pub", inPubF.Name(), "-duration", "100m"} args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", "nope", "-in-pub", inPubF.Name(), "-duration", "100m"}
assert.EqualError(t, signCert(args, ob, eb, nopw), "error while parsing in-pub: input did not contain a valid PEM encoded block") assert.EqualError(t, signCert(args, ob, eb, nopw), "error while parsing in-pub: input did not contain a valid PEM encoded block")
assert.Empty(t, ob.String()) assert.Empty(t, ob.String())
assert.Empty(t, eb.String()) assert.Empty(t, eb.String())
@ -169,30 +175,37 @@ func Test_signCert(t *testing.T) {
// bad ip cidr // bad ip cidr
ob.Reset() ob.Reset()
eb.Reset() eb.Reset()
args = []string{"-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "a1.1.1.1/24", "-out-crt", "nope", "-out-key", "nope", "-duration", "100m"} args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "a1.1.1.1/24", "-out-crt", "nope", "-out-key", "nope", "-duration", "100m"}
assertHelpError(t, signCert(args, ob, eb, nopw), "invalid ip definition: a1.1.1.1/24") assertHelpError(t, signCert(args, ob, eb, nopw), "invalid -networks definition: a1.1.1.1/24")
assert.Empty(t, ob.String()) assert.Empty(t, ob.String())
assert.Empty(t, eb.String()) assert.Empty(t, eb.String())
ob.Reset() ob.Reset()
eb.Reset() eb.Reset()
args = []string{"-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "100::100/100", "-out-crt", "nope", "-out-key", "nope", "-duration", "100m"} args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "100::100/100", "-out-crt", "nope", "-out-key", "nope", "-duration", "100m"}
assertHelpError(t, signCert(args, ob, eb, nopw), "invalid ip definition: can only be ipv4, have 100::100/100") assertHelpError(t, signCert(args, ob, eb, nopw), "invalid -networks definition: v1 certificates can only have a single ipv4 address")
assert.Empty(t, ob.String())
assert.Empty(t, eb.String())
ob.Reset()
eb.Reset()
args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24,1.1.1.2/24", "-out-crt", "nope", "-out-key", "nope", "-duration", "100m"}
assertHelpError(t, signCert(args, ob, eb, nopw), "invalid -networks definition: v1 certificates can only have a single ipv4 address")
assert.Empty(t, ob.String()) assert.Empty(t, ob.String())
assert.Empty(t, eb.String()) assert.Empty(t, eb.String())
// bad subnet cidr // bad subnet cidr
ob.Reset() ob.Reset()
eb.Reset() eb.Reset()
args = []string{"-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", "nope", "-out-key", "nope", "-duration", "100m", "-subnets", "a"} args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", "nope", "-out-key", "nope", "-duration", "100m", "-subnets", "a"}
assertHelpError(t, signCert(args, ob, eb, nopw), "invalid subnet definition: a") assertHelpError(t, signCert(args, ob, eb, nopw), "invalid -unsafe-networks definition: a")
assert.Empty(t, ob.String()) assert.Empty(t, ob.String())
assert.Empty(t, eb.String()) assert.Empty(t, eb.String())
ob.Reset() ob.Reset()
eb.Reset() eb.Reset()
args = []string{"-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", "nope", "-out-key", "nope", "-duration", "100m", "-subnets", "100::100/100"} args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", "nope", "-out-key", "nope", "-duration", "100m", "-subnets", "100::100/100"}
assertHelpError(t, signCert(args, ob, eb, nopw), "invalid subnet definition: can only be ipv4, have 100::100/100") assertHelpError(t, signCert(args, ob, eb, nopw), "invalid -unsafe-networks definition: v1 certificates can only be ipv4")
assert.Empty(t, ob.String()) assert.Empty(t, ob.String())
assert.Empty(t, eb.String()) assert.Empty(t, eb.String())
@ -205,7 +218,7 @@ func Test_signCert(t *testing.T) {
ob.Reset() ob.Reset()
eb.Reset() eb.Reset()
args = []string{"-ca-crt", caCrtF.Name(), "-ca-key", caKeyF2.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", "nope", "-out-key", "nope", "-duration", "100m", "-subnets", "a"} args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF2.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", "nope", "-out-key", "nope", "-duration", "100m", "-subnets", "a"}
assert.EqualError(t, signCert(args, ob, eb, nopw), "refusing to sign, root certificate does not match private key") assert.EqualError(t, signCert(args, ob, eb, nopw), "refusing to sign, root certificate does not match private key")
assert.Empty(t, ob.String()) assert.Empty(t, ob.String())
assert.Empty(t, eb.String()) assert.Empty(t, eb.String())
@ -213,7 +226,7 @@ func Test_signCert(t *testing.T) {
// failed key write // failed key write
ob.Reset() ob.Reset()
eb.Reset() eb.Reset()
args = []string{"-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", "/do/not/write/pleasecrt", "-out-key", "/do/not/write/pleasekey", "-duration", "100m", "-subnets", "10.1.1.1/32"} args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", "/do/not/write/pleasecrt", "-out-key", "/do/not/write/pleasekey", "-duration", "100m", "-subnets", "10.1.1.1/32"}
assert.EqualError(t, signCert(args, ob, eb, nopw), "error while writing out-key: open /do/not/write/pleasekey: "+NoSuchDirError) assert.EqualError(t, signCert(args, ob, eb, nopw), "error while writing out-key: open /do/not/write/pleasekey: "+NoSuchDirError)
assert.Empty(t, ob.String()) assert.Empty(t, ob.String())
assert.Empty(t, eb.String()) assert.Empty(t, eb.String())
@ -226,7 +239,7 @@ func Test_signCert(t *testing.T) {
// failed cert write // failed cert write
ob.Reset() ob.Reset()
eb.Reset() eb.Reset()
args = []string{"-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", "/do/not/write/pleasecrt", "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32"} args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", "/do/not/write/pleasecrt", "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32"}
assert.EqualError(t, signCert(args, ob, eb, nopw), "error while writing out-crt: open /do/not/write/pleasecrt: "+NoSuchDirError) assert.EqualError(t, signCert(args, ob, eb, nopw), "error while writing out-crt: open /do/not/write/pleasecrt: "+NoSuchDirError)
assert.Empty(t, ob.String()) assert.Empty(t, ob.String())
assert.Empty(t, eb.String()) assert.Empty(t, eb.String())
@ -240,7 +253,7 @@ func Test_signCert(t *testing.T) {
// test proper cert with removed empty groups and subnets // test proper cert with removed empty groups and subnets
ob.Reset() ob.Reset()
eb.Reset() eb.Reset()
args = []string{"-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32, , 10.2.2.2/32 , , ,, 10.5.5.5/32", "-groups", "1,, 2 , ,,,3,4,5"} args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32, , 10.2.2.2/32 , , ,, 10.5.5.5/32", "-groups", "1,, 2 , ,,,3,4,5"}
assert.Nil(t, signCert(args, ob, eb, nopw)) assert.Nil(t, signCert(args, ob, eb, nopw))
assert.Empty(t, ob.String()) assert.Empty(t, ob.String())
assert.Empty(t, eb.String()) assert.Empty(t, eb.String())
@ -283,7 +296,7 @@ func Test_signCert(t *testing.T) {
os.Remove(crtF.Name()) os.Remove(crtF.Name())
ob.Reset() ob.Reset()
eb.Reset() eb.Reset()
args = []string{"-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-in-pub", inPubF.Name(), "-duration", "100m", "-groups", "1"} args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-in-pub", inPubF.Name(), "-duration", "100m", "-groups", "1"}
assert.Nil(t, signCert(args, ob, eb, nopw)) assert.Nil(t, signCert(args, ob, eb, nopw))
assert.Empty(t, ob.String()) assert.Empty(t, ob.String())
assert.Empty(t, eb.String()) assert.Empty(t, eb.String())
@ -300,7 +313,7 @@ func Test_signCert(t *testing.T) {
eb.Reset() eb.Reset()
os.Remove(keyF.Name()) os.Remove(keyF.Name())
os.Remove(crtF.Name()) os.Remove(crtF.Name())
args = []string{"-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "1000m", "-subnets", "10.1.1.1/32, , 10.2.2.2/32 , , ,, 10.5.5.5/32", "-groups", "1,, 2 , ,,,3,4,5"} args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "1000m", "-subnets", "10.1.1.1/32, , 10.2.2.2/32 , , ,, 10.5.5.5/32", "-groups", "1,, 2 , ,,,3,4,5"}
assert.EqualError(t, signCert(args, ob, eb, nopw), "error while signing: certificate expires after signing certificate") assert.EqualError(t, signCert(args, ob, eb, nopw), "error while signing: certificate expires after signing certificate")
assert.Empty(t, ob.String()) assert.Empty(t, ob.String())
assert.Empty(t, eb.String()) assert.Empty(t, eb.String())
@ -308,14 +321,14 @@ func Test_signCert(t *testing.T) {
// create valid cert/key for overwrite tests // create valid cert/key for overwrite tests
os.Remove(keyF.Name()) os.Remove(keyF.Name())
os.Remove(crtF.Name()) os.Remove(crtF.Name())
args = []string{"-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32, , 10.2.2.2/32 , , ,, 10.5.5.5/32", "-groups", "1,, 2 , ,,,3,4,5"} args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32, , 10.2.2.2/32 , , ,, 10.5.5.5/32", "-groups", "1,, 2 , ,,,3,4,5"}
assert.Nil(t, signCert(args, ob, eb, nopw)) assert.Nil(t, signCert(args, ob, eb, nopw))
// test that we won't overwrite existing key file // test that we won't overwrite existing key file
os.Remove(crtF.Name()) os.Remove(crtF.Name())
ob.Reset() ob.Reset()
eb.Reset() eb.Reset()
args = []string{"-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32, , 10.2.2.2/32 , , ,, 10.5.5.5/32", "-groups", "1,, 2 , ,,,3,4,5"} args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32, , 10.2.2.2/32 , , ,, 10.5.5.5/32", "-groups", "1,, 2 , ,,,3,4,5"}
assert.EqualError(t, signCert(args, ob, eb, nopw), "refusing to overwrite existing key: "+keyF.Name()) assert.EqualError(t, signCert(args, ob, eb, nopw), "refusing to overwrite existing key: "+keyF.Name())
assert.Empty(t, ob.String()) assert.Empty(t, ob.String())
assert.Empty(t, eb.String()) assert.Empty(t, eb.String())
@ -323,14 +336,14 @@ func Test_signCert(t *testing.T) {
// create valid cert/key for overwrite tests // create valid cert/key for overwrite tests
os.Remove(keyF.Name()) os.Remove(keyF.Name())
os.Remove(crtF.Name()) os.Remove(crtF.Name())
args = []string{"-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32, , 10.2.2.2/32 , , ,, 10.5.5.5/32", "-groups", "1,, 2 , ,,,3,4,5"} args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32, , 10.2.2.2/32 , , ,, 10.5.5.5/32", "-groups", "1,, 2 , ,,,3,4,5"}
assert.Nil(t, signCert(args, ob, eb, nopw)) assert.Nil(t, signCert(args, ob, eb, nopw))
// test that we won't overwrite existing certificate file // test that we won't overwrite existing certificate file
os.Remove(keyF.Name()) os.Remove(keyF.Name())
ob.Reset() ob.Reset()
eb.Reset() eb.Reset()
args = []string{"-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32, , 10.2.2.2/32 , , ,, 10.5.5.5/32", "-groups", "1,, 2 , ,,,3,4,5"} args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32, , 10.2.2.2/32 , , ,, 10.5.5.5/32", "-groups", "1,, 2 , ,,,3,4,5"}
assert.EqualError(t, signCert(args, ob, eb, nopw), "refusing to overwrite existing cert: "+crtF.Name()) assert.EqualError(t, signCert(args, ob, eb, nopw), "refusing to overwrite existing cert: "+crtF.Name())
assert.Empty(t, ob.String()) assert.Empty(t, ob.String())
assert.Empty(t, eb.String()) assert.Empty(t, eb.String())
@ -362,7 +375,7 @@ func Test_signCert(t *testing.T) {
caCrtF.Write(b) caCrtF.Write(b)
// test with the proper password // test with the proper password
args = []string{"-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32, , 10.2.2.2/32 , , ,, 10.5.5.5/32", "-groups", "1,, 2 , ,,,3,4,5"} args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32, , 10.2.2.2/32 , , ,, 10.5.5.5/32", "-groups", "1,, 2 , ,,,3,4,5"}
assert.Nil(t, signCert(args, ob, eb, testpw)) assert.Nil(t, signCert(args, ob, eb, testpw))
assert.Equal(t, "Enter passphrase: ", ob.String()) assert.Equal(t, "Enter passphrase: ", ob.String())
assert.Empty(t, eb.String()) assert.Empty(t, eb.String())
@ -372,7 +385,7 @@ func Test_signCert(t *testing.T) {
eb.Reset() eb.Reset()
testpw.password = []byte("invalid password") testpw.password = []byte("invalid password")
args = []string{"-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32, , 10.2.2.2/32 , , ,, 10.5.5.5/32", "-groups", "1,, 2 , ,,,3,4,5"} args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32, , 10.2.2.2/32 , , ,, 10.5.5.5/32", "-groups", "1,, 2 , ,,,3,4,5"}
assert.Error(t, signCert(args, ob, eb, testpw)) assert.Error(t, signCert(args, ob, eb, testpw))
assert.Equal(t, "Enter passphrase: ", ob.String()) assert.Equal(t, "Enter passphrase: ", ob.String())
assert.Empty(t, eb.String()) assert.Empty(t, eb.String())
@ -381,7 +394,7 @@ func Test_signCert(t *testing.T) {
ob.Reset() ob.Reset()
eb.Reset() eb.Reset()
args = []string{"-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32, , 10.2.2.2/32 , , ,, 10.5.5.5/32", "-groups", "1,, 2 , ,,,3,4,5"} args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32, , 10.2.2.2/32 , , ,, 10.5.5.5/32", "-groups", "1,, 2 , ,,,3,4,5"}
assert.Error(t, signCert(args, ob, eb, nopw)) assert.Error(t, signCert(args, ob, eb, nopw))
// normally the user hitting enter on the prompt would add newlines between these // normally the user hitting enter on the prompt would add newlines between these
assert.Equal(t, "Enter passphrase: Enter passphrase: Enter passphrase: Enter passphrase: Enter passphrase: ", ob.String()) assert.Equal(t, "Enter passphrase: Enter passphrase: Enter passphrase: Enter passphrase: Enter passphrase: ", ob.String())
@ -391,7 +404,7 @@ func Test_signCert(t *testing.T) {
ob.Reset() ob.Reset()
eb.Reset() eb.Reset()
args = []string{"-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32, , 10.2.2.2/32 , , ,, 10.5.5.5/32", "-groups", "1,, 2 , ,,,3,4,5"} args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32, , 10.2.2.2/32 , , ,, 10.5.5.5/32", "-groups", "1,, 2 , ,,,3,4,5"}
assert.Error(t, signCert(args, ob, eb, errpw)) assert.Error(t, signCert(args, ob, eb, errpw))
assert.Equal(t, "Enter passphrase: ", ob.String()) assert.Equal(t, "Enter passphrase: ", ob.String())
assert.Empty(t, eb.String()) assert.Empty(t, eb.String())

View File

@ -183,7 +183,7 @@ func (n *connectionManager) doTrafficCheck(localIndex uint32, p, nb, out []byte,
case deleteTunnel: case deleteTunnel:
if n.hostMap.DeleteHostInfo(hostinfo) { if n.hostMap.DeleteHostInfo(hostinfo) {
// Only clearing the lighthouse cache if this is the last hostinfo for this vpn ip in the hostmap // Only clearing the lighthouse cache if this is the last hostinfo for this vpn ip in the hostmap
n.intf.lightHouse.DeleteVpnIp(hostinfo.vpnIp) n.intf.lightHouse.DeleteVpnAddrs(hostinfo.vpnAddrs)
} }
case closeTunnel: case closeTunnel:
@ -221,7 +221,7 @@ func (n *connectionManager) migrateRelayUsed(oldhostinfo, newhostinfo *HostInfo)
relayFor := oldhostinfo.relayState.CopyAllRelayFor() relayFor := oldhostinfo.relayState.CopyAllRelayFor()
for _, r := range relayFor { for _, r := range relayFor {
existing, ok := newhostinfo.relayState.QueryRelayForByIp(r.PeerIp) existing, ok := newhostinfo.relayState.QueryRelayForByIp(r.PeerAddr)
var index uint32 var index uint32
var relayFrom netip.Addr var relayFrom netip.Addr
@ -235,11 +235,11 @@ func (n *connectionManager) migrateRelayUsed(oldhostinfo, newhostinfo *HostInfo)
index = existing.LocalIndex index = existing.LocalIndex
switch r.Type { switch r.Type {
case TerminalType: case TerminalType:
relayFrom = n.intf.myVpnNet.Addr() relayFrom = n.intf.myVpnAddrs[0]
relayTo = existing.PeerIp relayTo = existing.PeerAddr
case ForwardingType: case ForwardingType:
relayFrom = existing.PeerIp relayFrom = existing.PeerAddr
relayTo = newhostinfo.vpnIp relayTo = newhostinfo.vpnAddrs[0]
default: default:
// should never happen // should never happen
} }
@ -253,45 +253,64 @@ func (n *connectionManager) migrateRelayUsed(oldhostinfo, newhostinfo *HostInfo)
n.relayUsedLock.RUnlock() n.relayUsedLock.RUnlock()
// The relay doesn't exist at all; create some relay state and send the request. // The relay doesn't exist at all; create some relay state and send the request.
var err error var err error
index, err = AddRelay(n.l, newhostinfo, n.hostMap, r.PeerIp, nil, r.Type, Requested) index, err = AddRelay(n.l, newhostinfo, n.hostMap, r.PeerAddr, nil, r.Type, Requested)
if err != nil { if err != nil {
n.l.WithError(err).Error("failed to migrate relay to new hostinfo") n.l.WithError(err).Error("failed to migrate relay to new hostinfo")
continue continue
} }
switch r.Type { switch r.Type {
case TerminalType: case TerminalType:
relayFrom = n.intf.myVpnNet.Addr() relayFrom = n.intf.myVpnAddrs[0]
relayTo = r.PeerIp relayTo = r.PeerAddr
case ForwardingType: case ForwardingType:
relayFrom = r.PeerIp relayFrom = r.PeerAddr
relayTo = newhostinfo.vpnIp relayTo = newhostinfo.vpnAddrs[0]
default: default:
// should never happen // should never happen
} }
} }
//TODO: IPV6-WORK
relayFromB := relayFrom.As4()
relayToB := relayTo.As4()
// Send a CreateRelayRequest to the peer. // Send a CreateRelayRequest to the peer.
req := NebulaControl{ req := NebulaControl{
Type: NebulaControl_CreateRelayRequest, Type: NebulaControl_CreateRelayRequest,
InitiatorRelayIndex: index, InitiatorRelayIndex: index,
RelayFromIp: binary.BigEndian.Uint32(relayFromB[:]),
RelayToIp: binary.BigEndian.Uint32(relayToB[:]),
} }
switch newhostinfo.GetCert().Certificate.Version() {
case cert.Version1:
if !relayFrom.Is4() {
n.l.Error("can not migrate v1 relay with a v6 network because the relay is not running a current nebula version")
continue
}
if !relayTo.Is4() {
n.l.Error("can not migrate v1 relay with a v6 remote network because the relay is not running a current nebula version")
continue
}
b := relayFrom.As4()
req.OldRelayFromAddr = binary.BigEndian.Uint32(b[:])
b = relayTo.As4()
req.OldRelayToAddr = binary.BigEndian.Uint32(b[:])
case cert.Version2:
req.RelayFromAddr = netAddrToProtoAddr(relayFrom)
req.RelayToAddr = netAddrToProtoAddr(relayTo)
default:
newhostinfo.logger(n.l).Error("Unknown certificate version found while attempting to migrate relay")
continue
}
msg, err := req.Marshal() msg, err := req.Marshal()
if err != nil { if err != nil {
n.l.WithError(err).Error("failed to marshal Control message to migrate relay") n.l.WithError(err).Error("failed to marshal Control message to migrate relay")
} else { } else {
n.intf.SendMessageToHostInfo(header.Control, 0, newhostinfo, msg, make([]byte, 12), make([]byte, mtu)) n.intf.SendMessageToHostInfo(header.Control, 0, newhostinfo, msg, make([]byte, 12), make([]byte, mtu))
n.l.WithFields(logrus.Fields{ n.l.WithFields(logrus.Fields{
"relayFrom": req.RelayFromIp, "relayFrom": req.RelayFromAddr,
"relayTo": req.RelayToIp, "relayTo": req.RelayToAddr,
"initiatorRelayIndex": req.InitiatorRelayIndex, "initiatorRelayIndex": req.InitiatorRelayIndex,
"responderRelayIndex": req.ResponderRelayIndex, "responderRelayIndex": req.ResponderRelayIndex,
"vpnIp": newhostinfo.vpnIp}). "vpnAddrs": newhostinfo.vpnAddrs}).
Info("send CreateRelayRequest") Info("send CreateRelayRequest")
} }
} }
@ -313,7 +332,7 @@ func (n *connectionManager) makeTrafficDecision(localIndex uint32, now time.Time
return closeTunnel, hostinfo, nil return closeTunnel, hostinfo, nil
} }
primary := n.hostMap.Hosts[hostinfo.vpnIp] primary := n.hostMap.Hosts[hostinfo.vpnAddrs[0]]
mainHostInfo := true mainHostInfo := true
if primary != nil && primary != hostinfo { if primary != nil && primary != hostinfo {
mainHostInfo = false mainHostInfo = false
@ -407,21 +426,24 @@ func (n *connectionManager) shouldSwapPrimary(current, primary *HostInfo) bool {
// If we are here then we have multiple tunnels for a host pair and neither side believes the same tunnel is primary. // If we are here then we have multiple tunnels for a host pair and neither side believes the same tunnel is primary.
// Let's sort this out. // Let's sort this out.
if current.vpnIp.Compare(n.intf.myVpnNet.Addr()) < 0 { //TODO: current.vpnIp should become an array of vpnIps
if current.vpnAddrs[0].Compare(n.intf.myVpnAddrs[0]) < 0 {
// Only one side should flip primary because if both flip then we may never resolve to a single tunnel. // Only one side should flip primary because if both flip then we may never resolve to a single tunnel.
// vpn ip is static across all tunnels for this host pair so lets use that to determine who is flipping. // vpn ip is static across all tunnels for this host pair so lets use that to determine who is flipping.
// The remotes vpn ip is lower than mine. I will not flip. // The remotes vpn ip is lower than mine. I will not flip.
return false return false
} }
certState := n.intf.pki.GetCertState() //TODO: we should favor v2 over v1 certificates if configured to send them
return bytes.Equal(current.ConnectionState.myCert.Signature(), certState.Certificate.Signature())
crt := n.intf.pki.getCertificate(current.ConnectionState.myCert.Version())
return bytes.Equal(current.ConnectionState.myCert.Signature(), crt.Signature())
} }
func (n *connectionManager) swapPrimary(current, primary *HostInfo) { func (n *connectionManager) swapPrimary(current, primary *HostInfo) {
n.hostMap.Lock() n.hostMap.Lock()
// Make sure the primary is still the same after the write lock. This avoids a race with a rehandshake. // Make sure the primary is still the same after the write lock. This avoids a race with a rehandshake.
if n.hostMap.Hosts[current.vpnIp] == primary { if n.hostMap.Hosts[current.vpnAddrs[0]] == primary {
n.hostMap.unlockedMakePrimary(current) n.hostMap.unlockedMakePrimary(current)
} }
n.hostMap.Unlock() n.hostMap.Unlock()
@ -473,14 +495,16 @@ func (n *connectionManager) sendPunch(hostinfo *HostInfo) {
} }
func (n *connectionManager) tryRehandshake(hostinfo *HostInfo) { func (n *connectionManager) tryRehandshake(hostinfo *HostInfo) {
certState := n.intf.pki.GetCertState() crt := n.intf.pki.getCertificate(hostinfo.ConnectionState.myCert.Version())
if bytes.Equal(hostinfo.ConnectionState.myCert.Signature(), certState.Certificate.Signature()) { if bytes.Equal(hostinfo.ConnectionState.myCert.Signature(), crt.Signature()) {
return return
} }
n.l.WithField("vpnIp", hostinfo.vpnIp). //TODO: we should favor v2 over v1 certificates if configured to send them
n.l.WithField("vpnAddrs", hostinfo.vpnAddrs).
WithField("reason", "local certificate is not current"). WithField("reason", "local certificate is not current").
Info("Re-handshaking with remote") Info("Re-handshaking with remote")
n.intf.handshakeManager.StartHandshake(hostinfo.vpnIp, nil) n.intf.handshakeManager.StartHandshake(hostinfo.vpnAddrs[0], nil)
} }

View File

@ -34,20 +34,19 @@ func newTestLighthouse() *LightHouse {
func Test_NewConnectionManagerTest(t *testing.T) { func Test_NewConnectionManagerTest(t *testing.T) {
l := test.NewLogger() l := test.NewLogger()
//_, tuncidr, _ := net.ParseCIDR("1.1.1.1/24") //_, tuncidr, _ := net.ParseCIDR("1.1.1.1/24")
vpncidr := netip.MustParsePrefix("172.1.1.1/24")
localrange := netip.MustParsePrefix("10.1.1.1/24") localrange := netip.MustParsePrefix("10.1.1.1/24")
vpnIp := netip.MustParseAddr("172.1.1.2") vpnIp := netip.MustParseAddr("172.1.1.2")
preferredRanges := []netip.Prefix{localrange} preferredRanges := []netip.Prefix{localrange}
// Very incomplete mock objects // Very incomplete mock objects
hostMap := newHostMap(l, vpncidr) hostMap := newHostMap(l)
hostMap.preferredRanges.Store(&preferredRanges) hostMap.preferredRanges.Store(&preferredRanges)
cs := &CertState{ cs := &CertState{
RawCertificate: []byte{}, defaultVersion: cert.Version1,
PrivateKey: []byte{}, privateKey: []byte{},
Certificate: &dummyCert{}, v1Cert: &dummyCert{version: cert.Version1},
RawCertificateNoKey: []byte{}, v1HandshakeBytes: []byte{},
} }
lh := newTestLighthouse() lh := newTestLighthouse()
@ -74,12 +73,12 @@ func Test_NewConnectionManagerTest(t *testing.T) {
// Add an ip we have established a connection w/ to hostmap // Add an ip we have established a connection w/ to hostmap
hostinfo := &HostInfo{ hostinfo := &HostInfo{
vpnIp: vpnIp, vpnAddrs: []netip.Addr{vpnIp},
localIndexId: 1099, localIndexId: 1099,
remoteIndexId: 9901, remoteIndexId: 9901,
} }
hostinfo.ConnectionState = &ConnectionState{ hostinfo.ConnectionState = &ConnectionState{
myCert: &dummyCert{}, myCert: &dummyCert{version: cert.Version1},
H: &noise.HandshakeState{}, H: &noise.HandshakeState{},
} }
nc.hostMap.unlockedAddHostInfo(hostinfo, ifce) nc.hostMap.unlockedAddHostInfo(hostinfo, ifce)
@ -88,7 +87,7 @@ func Test_NewConnectionManagerTest(t *testing.T) {
nc.Out(hostinfo.localIndexId) nc.Out(hostinfo.localIndexId)
nc.In(hostinfo.localIndexId) nc.In(hostinfo.localIndexId)
assert.NotContains(t, nc.pendingDeletion, hostinfo.localIndexId) assert.NotContains(t, nc.pendingDeletion, hostinfo.localIndexId)
assert.Contains(t, nc.hostMap.Hosts, hostinfo.vpnIp) assert.Contains(t, nc.hostMap.Hosts, hostinfo.vpnAddrs[0])
assert.Contains(t, nc.hostMap.Indexes, hostinfo.localIndexId) assert.Contains(t, nc.hostMap.Indexes, hostinfo.localIndexId)
assert.Contains(t, nc.out, hostinfo.localIndexId) assert.Contains(t, nc.out, hostinfo.localIndexId)
@ -105,32 +104,31 @@ func Test_NewConnectionManagerTest(t *testing.T) {
assert.NotContains(t, nc.out, hostinfo.localIndexId) assert.NotContains(t, nc.out, hostinfo.localIndexId)
assert.NotContains(t, nc.in, hostinfo.localIndexId) assert.NotContains(t, nc.in, hostinfo.localIndexId)
assert.Contains(t, nc.hostMap.Indexes, hostinfo.localIndexId) assert.Contains(t, nc.hostMap.Indexes, hostinfo.localIndexId)
assert.Contains(t, nc.hostMap.Hosts, hostinfo.vpnIp) assert.Contains(t, nc.hostMap.Hosts, hostinfo.vpnAddrs[0])
// Do a final traffic check tick, the host should now be removed // Do a final traffic check tick, the host should now be removed
nc.doTrafficCheck(hostinfo.localIndexId, p, nb, out, time.Now()) nc.doTrafficCheck(hostinfo.localIndexId, p, nb, out, time.Now())
assert.NotContains(t, nc.pendingDeletion, hostinfo.localIndexId) assert.NotContains(t, nc.pendingDeletion, hostinfo.localIndexId)
assert.NotContains(t, nc.hostMap.Hosts, hostinfo.vpnIp) assert.NotContains(t, nc.hostMap.Hosts, hostinfo.vpnAddrs[0])
assert.NotContains(t, nc.hostMap.Indexes, hostinfo.localIndexId) assert.NotContains(t, nc.hostMap.Indexes, hostinfo.localIndexId)
} }
func Test_NewConnectionManagerTest2(t *testing.T) { func Test_NewConnectionManagerTest2(t *testing.T) {
l := test.NewLogger() l := test.NewLogger()
//_, tuncidr, _ := net.ParseCIDR("1.1.1.1/24") //_, tuncidr, _ := net.ParseCIDR("1.1.1.1/24")
vpncidr := netip.MustParsePrefix("172.1.1.1/24")
localrange := netip.MustParsePrefix("10.1.1.1/24") localrange := netip.MustParsePrefix("10.1.1.1/24")
vpnIp := netip.MustParseAddr("172.1.1.2") vpnIp := netip.MustParseAddr("172.1.1.2")
preferredRanges := []netip.Prefix{localrange} preferredRanges := []netip.Prefix{localrange}
// Very incomplete mock objects // Very incomplete mock objects
hostMap := newHostMap(l, vpncidr) hostMap := newHostMap(l)
hostMap.preferredRanges.Store(&preferredRanges) hostMap.preferredRanges.Store(&preferredRanges)
cs := &CertState{ cs := &CertState{
RawCertificate: []byte{}, defaultVersion: cert.Version1,
PrivateKey: []byte{}, privateKey: []byte{},
Certificate: &dummyCert{}, v1Cert: &dummyCert{version: cert.Version1},
RawCertificateNoKey: []byte{}, v1HandshakeBytes: []byte{},
} }
lh := newTestLighthouse() lh := newTestLighthouse()
@ -157,12 +155,12 @@ func Test_NewConnectionManagerTest2(t *testing.T) {
// Add an ip we have established a connection w/ to hostmap // Add an ip we have established a connection w/ to hostmap
hostinfo := &HostInfo{ hostinfo := &HostInfo{
vpnIp: vpnIp, vpnAddrs: []netip.Addr{vpnIp},
localIndexId: 1099, localIndexId: 1099,
remoteIndexId: 9901, remoteIndexId: 9901,
} }
hostinfo.ConnectionState = &ConnectionState{ hostinfo.ConnectionState = &ConnectionState{
myCert: &dummyCert{}, myCert: &dummyCert{version: cert.Version1},
H: &noise.HandshakeState{}, H: &noise.HandshakeState{},
} }
nc.hostMap.unlockedAddHostInfo(hostinfo, ifce) nc.hostMap.unlockedAddHostInfo(hostinfo, ifce)
@ -170,8 +168,8 @@ func Test_NewConnectionManagerTest2(t *testing.T) {
// We saw traffic out to vpnIp // We saw traffic out to vpnIp
nc.Out(hostinfo.localIndexId) nc.Out(hostinfo.localIndexId)
nc.In(hostinfo.localIndexId) nc.In(hostinfo.localIndexId)
assert.NotContains(t, nc.pendingDeletion, hostinfo.vpnIp) assert.NotContains(t, nc.pendingDeletion, hostinfo.vpnAddrs[0])
assert.Contains(t, nc.hostMap.Hosts, hostinfo.vpnIp) assert.Contains(t, nc.hostMap.Hosts, hostinfo.vpnAddrs[0])
assert.Contains(t, nc.hostMap.Indexes, hostinfo.localIndexId) assert.Contains(t, nc.hostMap.Indexes, hostinfo.localIndexId)
// Do a traffic check tick, should not be pending deletion but should not have any in/out packets recorded // Do a traffic check tick, should not be pending deletion but should not have any in/out packets recorded
@ -187,7 +185,7 @@ func Test_NewConnectionManagerTest2(t *testing.T) {
assert.NotContains(t, nc.out, hostinfo.localIndexId) assert.NotContains(t, nc.out, hostinfo.localIndexId)
assert.NotContains(t, nc.in, hostinfo.localIndexId) assert.NotContains(t, nc.in, hostinfo.localIndexId)
assert.Contains(t, nc.hostMap.Indexes, hostinfo.localIndexId) assert.Contains(t, nc.hostMap.Indexes, hostinfo.localIndexId)
assert.Contains(t, nc.hostMap.Hosts, hostinfo.vpnIp) assert.Contains(t, nc.hostMap.Hosts, hostinfo.vpnAddrs[0])
// We saw traffic, should no longer be pending deletion // We saw traffic, should no longer be pending deletion
nc.In(hostinfo.localIndexId) nc.In(hostinfo.localIndexId)
@ -196,7 +194,7 @@ func Test_NewConnectionManagerTest2(t *testing.T) {
assert.NotContains(t, nc.out, hostinfo.localIndexId) assert.NotContains(t, nc.out, hostinfo.localIndexId)
assert.NotContains(t, nc.in, hostinfo.localIndexId) assert.NotContains(t, nc.in, hostinfo.localIndexId)
assert.Contains(t, nc.hostMap.Indexes, hostinfo.localIndexId) assert.Contains(t, nc.hostMap.Indexes, hostinfo.localIndexId)
assert.Contains(t, nc.hostMap.Hosts, hostinfo.vpnIp) assert.Contains(t, nc.hostMap.Hosts, hostinfo.vpnAddrs[0])
} }
// Check if we can disconnect the peer. // Check if we can disconnect the peer.
@ -210,7 +208,7 @@ func Test_NewConnectionManagerTest_DisconnectInvalid(t *testing.T) {
localrange := netip.MustParsePrefix("10.1.1.1/24") localrange := netip.MustParsePrefix("10.1.1.1/24")
vpnIp := netip.MustParseAddr("172.1.1.2") vpnIp := netip.MustParseAddr("172.1.1.2")
preferredRanges := []netip.Prefix{localrange} preferredRanges := []netip.Prefix{localrange}
hostMap := newHostMap(l, vpncidr) hostMap := newHostMap(l)
hostMap.preferredRanges.Store(&preferredRanges) hostMap.preferredRanges.Store(&preferredRanges)
// Generate keys for CA and peer's cert. // Generate keys for CA and peer's cert.
@ -244,10 +242,9 @@ func Test_NewConnectionManagerTest_DisconnectInvalid(t *testing.T) {
cachedPeerCert, err := ncp.VerifyCertificate(now.Add(time.Second), peerCert) cachedPeerCert, err := ncp.VerifyCertificate(now.Add(time.Second), peerCert)
cs := &CertState{ cs := &CertState{
RawCertificate: []byte{}, privateKey: []byte{},
PrivateKey: []byte{}, v1Cert: &dummyCert{},
Certificate: &dummyCert{}, v1HandshakeBytes: []byte{},
RawCertificateNoKey: []byte{},
} }
lh := newTestLighthouse() lh := newTestLighthouse()
@ -273,7 +270,7 @@ func Test_NewConnectionManagerTest_DisconnectInvalid(t *testing.T) {
ifce.connectionManager = nc ifce.connectionManager = nc
hostinfo := &HostInfo{ hostinfo := &HostInfo{
vpnIp: vpnIp, vpnAddrs: []netip.Addr{vpnIp},
ConnectionState: &ConnectionState{ ConnectionState: &ConnectionState{
myCert: &dummyCert{}, myCert: &dummyCert{},
peerCert: cachedPeerCert, peerCert: cachedPeerCert,

View File

@ -3,6 +3,7 @@ package nebula
import ( import (
"crypto/rand" "crypto/rand"
"encoding/json" "encoding/json"
"fmt"
"sync" "sync"
"sync/atomic" "sync/atomic"
@ -26,46 +27,46 @@ type ConnectionState struct {
writeLock sync.Mutex writeLock sync.Mutex
} }
func NewConnectionState(l *logrus.Logger, cipher string, certState *CertState, initiator bool, pattern noise.HandshakePattern, psk []byte, pskStage int) *ConnectionState { func NewConnectionState(l *logrus.Logger, cs *CertState, crt cert.Certificate, initiator bool, pattern noise.HandshakePattern) (*ConnectionState, error) {
var dhFunc noise.DHFunc var dhFunc noise.DHFunc
switch certState.Certificate.Curve() { switch crt.Curve() {
case cert.Curve_CURVE25519: case cert.Curve_CURVE25519:
dhFunc = noise.DH25519 dhFunc = noise.DH25519
case cert.Curve_P256: case cert.Curve_P256:
if certState.pkcs11Backed { if cs.pkcs11Backed {
dhFunc = noiseutil.DHP256PKCS11 dhFunc = noiseutil.DHP256PKCS11
} else { } else {
dhFunc = noiseutil.DHP256 dhFunc = noiseutil.DHP256
} }
default: default:
l.Errorf("invalid curve: %s", certState.Certificate.Curve()) return nil, fmt.Errorf("invalid curve: %s", crt.Curve())
return nil
} }
var cs noise.CipherSuite var ncs noise.CipherSuite
if cipher == "chachapoly" { if cs.cipher == "chachapoly" {
cs = noise.NewCipherSuite(dhFunc, noise.CipherChaChaPoly, noise.HashSHA256) ncs = noise.NewCipherSuite(dhFunc, noise.CipherChaChaPoly, noise.HashSHA256)
} else { } else {
cs = noise.NewCipherSuite(dhFunc, noiseutil.CipherAESGCM, noise.HashSHA256) ncs = noise.NewCipherSuite(dhFunc, noiseutil.CipherAESGCM, noise.HashSHA256)
} }
static := noise.DHKey{Private: certState.PrivateKey, Public: certState.PublicKey} static := noise.DHKey{Private: cs.privateKey, Public: crt.PublicKey()}
b := NewBits(ReplayWindow) b := NewBits(ReplayWindow)
// Clear out bit 0, we never transmit it and we don't want it showing as packet loss // Clear out bit 0, we never transmit it, and we don't want it showing as packet loss
b.Update(l, 0) b.Update(l, 0)
hs, err := noise.NewHandshakeState(noise.Config{ hs, err := noise.NewHandshakeState(noise.Config{
CipherSuite: cs, CipherSuite: ncs,
Random: rand.Reader, Random: rand.Reader,
Pattern: pattern, Pattern: pattern,
Initiator: initiator, Initiator: initiator,
StaticKeypair: static, StaticKeypair: static,
PresharedKey: psk, //NOTE: These should come from CertState (pki.go) when we finally implement it
PresharedKeyPlacement: pskStage, PresharedKey: []byte{},
PresharedKeyPlacement: 0,
}) })
if err != nil { if err != nil {
return nil return nil, fmt.Errorf("NewConnectionState: %s", err)
} }
// The queue and ready params prevent a counter race that would happen when // The queue and ready params prevent a counter race that would happen when
@ -74,12 +75,12 @@ func NewConnectionState(l *logrus.Logger, cipher string, certState *CertState, i
H: hs, H: hs,
initiator: initiator, initiator: initiator,
window: b, window: b,
myCert: certState.Certificate, myCert: crt,
} }
// always start the counter from 2, as packet 1 and packet 2 are handshake packets. // always start the counter from 2, as packet 1 and packet 2 are handshake packets.
ci.messageCounter.Add(2) ci.messageCounter.Add(2)
return ci return ci, nil
} }
func (cs *ConnectionState) MarshalJSON() ([]byte, error) { func (cs *ConnectionState) MarshalJSON() ([]byte, error) {
@ -89,3 +90,7 @@ func (cs *ConnectionState) MarshalJSON() ([]byte, error) {
"message_counter": cs.messageCounter.Load(), "message_counter": cs.messageCounter.Load(),
}) })
} }
func (cs *ConnectionState) Curve() cert.Curve {
return cs.myCert.Curve()
}

View File

@ -19,9 +19,9 @@ import (
type controlEach func(h *HostInfo) type controlEach func(h *HostInfo)
type controlHostLister interface { type controlHostLister interface {
QueryVpnIp(vpnIp netip.Addr) *HostInfo QueryVpnAddr(vpnAddr netip.Addr) *HostInfo
ForEachIndex(each controlEach) ForEachIndex(each controlEach)
ForEachVpnIp(each controlEach) ForEachVpnAddr(each controlEach)
GetPreferredRanges() []netip.Prefix GetPreferredRanges() []netip.Prefix
} }
@ -37,7 +37,7 @@ type Control struct {
} }
type ControlHostInfo struct { type ControlHostInfo struct {
VpnIp netip.Addr `json:"vpnIp"` VpnAddrs []netip.Addr `json:"vpnAddrs"`
LocalIndex uint32 `json:"localIndex"` LocalIndex uint32 `json:"localIndex"`
RemoteIndex uint32 `json:"remoteIndex"` RemoteIndex uint32 `json:"remoteIndex"`
RemoteAddrs []netip.AddrPort `json:"remoteAddrs"` RemoteAddrs []netip.AddrPort `json:"remoteAddrs"`
@ -131,10 +131,13 @@ func (c *Control) ListHostmapIndexes(pendingMap bool) []ControlHostInfo {
// GetCertByVpnIp returns the authenticated certificate of the given vpn IP, or nil if not found // GetCertByVpnIp returns the authenticated certificate of the given vpn IP, or nil if not found
func (c *Control) GetCertByVpnIp(vpnIp netip.Addr) cert.Certificate { func (c *Control) GetCertByVpnIp(vpnIp netip.Addr) cert.Certificate {
if c.f.myVpnNet.Addr() == vpnIp { _, found := c.f.myVpnAddrsTable.Lookup(vpnIp)
return c.f.pki.GetCertState().Certificate.Copy() if found {
//TODO: we might have 2 certs....
//TODO: this should return our latest version cert
return c.f.pki.getDefaultCertificate().Copy()
} }
hi := c.f.hostMap.QueryVpnIp(vpnIp) hi := c.f.hostMap.QueryVpnAddr(vpnIp)
if hi == nil { if hi == nil {
return nil return nil
} }
@ -148,7 +151,7 @@ func (c *Control) CreateTunnel(vpnIp netip.Addr) {
// PrintTunnel creates a new tunnel to the given vpn ip. // PrintTunnel creates a new tunnel to the given vpn ip.
func (c *Control) PrintTunnel(vpnIp netip.Addr) *ControlHostInfo { func (c *Control) PrintTunnel(vpnIp netip.Addr) *ControlHostInfo {
hi := c.f.hostMap.QueryVpnIp(vpnIp) hi := c.f.hostMap.QueryVpnAddr(vpnIp)
if hi == nil { if hi == nil {
return nil return nil
} }
@ -165,9 +168,9 @@ func (c *Control) QueryLighthouse(vpnIp netip.Addr) *CacheMap {
return hi.CopyCache() return hi.CopyCache()
} }
// GetHostInfoByVpnIp returns a single tunnels hostInfo, or nil if not found // GetHostInfoByVpnAddr returns a single tunnels hostInfo, or nil if not found
// Caller should take care to Unmap() any 4in6 addresses prior to calling. // Caller should take care to Unmap() any 4in6 addresses prior to calling.
func (c *Control) GetHostInfoByVpnIp(vpnIp netip.Addr, pending bool) *ControlHostInfo { func (c *Control) GetHostInfoByVpnAddr(vpnAddr netip.Addr, pending bool) *ControlHostInfo {
var hl controlHostLister var hl controlHostLister
if pending { if pending {
hl = c.f.handshakeManager hl = c.f.handshakeManager
@ -175,7 +178,7 @@ func (c *Control) GetHostInfoByVpnIp(vpnIp netip.Addr, pending bool) *ControlHos
hl = c.f.hostMap hl = c.f.hostMap
} }
h := hl.QueryVpnIp(vpnIp) h := hl.QueryVpnAddr(vpnAddr)
if h == nil { if h == nil {
return nil return nil
} }
@ -187,7 +190,7 @@ func (c *Control) GetHostInfoByVpnIp(vpnIp netip.Addr, pending bool) *ControlHos
// SetRemoteForTunnel forces a tunnel to use a specific remote // SetRemoteForTunnel forces a tunnel to use a specific remote
// Caller should take care to Unmap() any 4in6 addresses prior to calling. // Caller should take care to Unmap() any 4in6 addresses prior to calling.
func (c *Control) SetRemoteForTunnel(vpnIp netip.Addr, addr netip.AddrPort) *ControlHostInfo { func (c *Control) SetRemoteForTunnel(vpnIp netip.Addr, addr netip.AddrPort) *ControlHostInfo {
hostInfo := c.f.hostMap.QueryVpnIp(vpnIp) hostInfo := c.f.hostMap.QueryVpnAddr(vpnIp)
if hostInfo == nil { if hostInfo == nil {
return nil return nil
} }
@ -200,7 +203,7 @@ func (c *Control) SetRemoteForTunnel(vpnIp netip.Addr, addr netip.AddrPort) *Con
// CloseTunnel closes a fully established tunnel. If localOnly is false it will notify the remote end as well. // CloseTunnel closes a fully established tunnel. If localOnly is false it will notify the remote end as well.
// Caller should take care to Unmap() any 4in6 addresses prior to calling. // Caller should take care to Unmap() any 4in6 addresses prior to calling.
func (c *Control) CloseTunnel(vpnIp netip.Addr, localOnly bool) bool { func (c *Control) CloseTunnel(vpnIp netip.Addr, localOnly bool) bool {
hostInfo := c.f.hostMap.QueryVpnIp(vpnIp) hostInfo := c.f.hostMap.QueryVpnAddr(vpnIp)
if hostInfo == nil { if hostInfo == nil {
return false return false
} }
@ -229,14 +232,14 @@ func (c *Control) CloseAllTunnels(excludeLighthouses bool) (closed int) {
shutdown := func(h *HostInfo) { shutdown := func(h *HostInfo) {
if excludeLighthouses { if excludeLighthouses {
if _, ok := lighthouses[h.vpnIp]; ok { if _, ok := lighthouses[h.vpnAddrs[0]]; ok {
return return
} }
} }
c.f.send(header.CloseTunnel, 0, h.ConnectionState, h, []byte{}, make([]byte, 12, 12), make([]byte, mtu)) c.f.send(header.CloseTunnel, 0, h.ConnectionState, h, []byte{}, make([]byte, 12, 12), make([]byte, mtu))
c.f.closeTunnel(h) c.f.closeTunnel(h)
c.l.WithField("vpnIp", h.vpnIp).WithField("udpAddr", h.remote). c.l.WithField("vpnIp", h.vpnAddrs[0]).WithField("udpAddr", h.remote).
Debug("Sending close tunnel message") Debug("Sending close tunnel message")
closed++ closed++
} }
@ -246,7 +249,7 @@ func (c *Control) CloseAllTunnels(excludeLighthouses bool) (closed int) {
// Grab the hostMap lock to access the Relays map // Grab the hostMap lock to access the Relays map
c.f.hostMap.Lock() c.f.hostMap.Lock()
for _, relayingHost := range c.f.hostMap.Relays { for _, relayingHost := range c.f.hostMap.Relays {
relayingHosts[relayingHost.vpnIp] = relayingHost relayingHosts[relayingHost.vpnAddrs[0]] = relayingHost
} }
c.f.hostMap.Unlock() c.f.hostMap.Unlock()
@ -254,7 +257,7 @@ func (c *Control) CloseAllTunnels(excludeLighthouses bool) (closed int) {
// Grab the hostMap lock to access the Hosts map // Grab the hostMap lock to access the Hosts map
c.f.hostMap.Lock() c.f.hostMap.Lock()
for _, relayHost := range c.f.hostMap.Indexes { for _, relayHost := range c.f.hostMap.Indexes {
if _, ok := relayingHosts[relayHost.vpnIp]; !ok { if _, ok := relayingHosts[relayHost.vpnAddrs[0]]; !ok {
hostInfos = append(hostInfos, relayHost) hostInfos = append(hostInfos, relayHost)
} }
} }
@ -274,9 +277,8 @@ func (c *Control) Device() overlay.Device {
} }
func copyHostInfo(h *HostInfo, preferredRanges []netip.Prefix) ControlHostInfo { func copyHostInfo(h *HostInfo, preferredRanges []netip.Prefix) ControlHostInfo {
chi := ControlHostInfo{ chi := ControlHostInfo{
VpnIp: h.vpnIp, VpnAddrs: make([]netip.Addr, len(h.vpnAddrs)),
LocalIndex: h.localIndexId, LocalIndex: h.localIndexId,
RemoteIndex: h.remoteIndexId, RemoteIndex: h.remoteIndexId,
RemoteAddrs: h.remotes.CopyAddrs(preferredRanges), RemoteAddrs: h.remotes.CopyAddrs(preferredRanges),
@ -285,6 +287,10 @@ func copyHostInfo(h *HostInfo, preferredRanges []netip.Prefix) ControlHostInfo {
CurrentRemote: h.remote, CurrentRemote: h.remote,
} }
for i, a := range h.vpnAddrs {
chi.VpnAddrs[i] = a
}
if h.ConnectionState != nil { if h.ConnectionState != nil {
chi.MessageCounter = h.ConnectionState.messageCounter.Load() chi.MessageCounter = h.ConnectionState.messageCounter.Load()
} }
@ -299,7 +305,7 @@ func copyHostInfo(h *HostInfo, preferredRanges []netip.Prefix) ControlHostInfo {
func listHostMapHosts(hl controlHostLister) []ControlHostInfo { func listHostMapHosts(hl controlHostLister) []ControlHostInfo {
hosts := make([]ControlHostInfo, 0) hosts := make([]ControlHostInfo, 0)
pr := hl.GetPreferredRanges() pr := hl.GetPreferredRanges()
hl.ForEachVpnIp(func(hostinfo *HostInfo) { hl.ForEachVpnAddr(func(hostinfo *HostInfo) {
hosts = append(hosts, copyHostInfo(hostinfo, pr)) hosts = append(hosts, copyHostInfo(hostinfo, pr))
}) })
return hosts return hosts

View File

@ -19,7 +19,7 @@ func TestControl_GetHostInfoByVpnIp(t *testing.T) {
l := test.NewLogger() l := test.NewLogger()
// Special care must be taken to re-use all objects provided to the hostmap and certificate in the expectedInfo object // Special care must be taken to re-use all objects provided to the hostmap and certificate in the expectedInfo object
// To properly ensure we are not exposing core memory to the caller // To properly ensure we are not exposing core memory to the caller
hm := newHostMap(l, netip.Prefix{}) hm := newHostMap(l)
hm.preferredRanges.Store(&[]netip.Prefix{}) hm.preferredRanges.Store(&[]netip.Prefix{})
remote1 := netip.MustParseAddrPort("0.0.0.100:4444") remote1 := netip.MustParseAddrPort("0.0.0.100:4444")
@ -35,9 +35,9 @@ func TestControl_GetHostInfoByVpnIp(t *testing.T) {
Mask: net.IPMask{255, 255, 255, 0}, Mask: net.IPMask{255, 255, 255, 0},
} }
remotes := NewRemoteList(nil) remotes := NewRemoteList([]netip.Addr{netip.IPv4Unspecified()}, nil)
remotes.unlockedPrependV4(netip.IPv4Unspecified(), NewIp4AndPortFromNetIP(remote1.Addr(), remote1.Port())) remotes.unlockedPrependV4(netip.IPv4Unspecified(), netAddrToProtoV4AddrPort(remote1.Addr(), remote1.Port()))
remotes.unlockedPrependV6(netip.IPv4Unspecified(), NewIp6AndPortFromNetIP(remote2.Addr(), remote2.Port())) remotes.unlockedPrependV6(netip.IPv4Unspecified(), netAddrToProtoV6AddrPort(remote2.Addr(), remote2.Port()))
vpnIp, ok := netip.AddrFromSlice(ipNet.IP) vpnIp, ok := netip.AddrFromSlice(ipNet.IP)
assert.True(t, ok) assert.True(t, ok)
@ -51,10 +51,10 @@ func TestControl_GetHostInfoByVpnIp(t *testing.T) {
}, },
remoteIndexId: 200, remoteIndexId: 200,
localIndexId: 201, localIndexId: 201,
vpnIp: vpnIp, vpnAddrs: []netip.Addr{vpnIp},
relayState: RelayState{ relayState: RelayState{
relays: map[netip.Addr]struct{}{}, relays: map[netip.Addr]struct{}{},
relayForByIp: map[netip.Addr]*Relay{}, relayForByAddr: map[netip.Addr]*Relay{},
relayForByIdx: map[uint32]*Relay{}, relayForByIdx: map[uint32]*Relay{},
}, },
}, &Interface{}) }, &Interface{})
@ -70,10 +70,10 @@ func TestControl_GetHostInfoByVpnIp(t *testing.T) {
}, },
remoteIndexId: 200, remoteIndexId: 200,
localIndexId: 201, localIndexId: 201,
vpnIp: vpnIp2, vpnAddrs: []netip.Addr{vpnIp2},
relayState: RelayState{ relayState: RelayState{
relays: map[netip.Addr]struct{}{}, relays: map[netip.Addr]struct{}{},
relayForByIp: map[netip.Addr]*Relay{}, relayForByAddr: map[netip.Addr]*Relay{},
relayForByIdx: map[uint32]*Relay{}, relayForByIdx: map[uint32]*Relay{},
}, },
}, &Interface{}) }, &Interface{})
@ -85,10 +85,10 @@ func TestControl_GetHostInfoByVpnIp(t *testing.T) {
l: logrus.New(), l: logrus.New(),
} }
thi := c.GetHostInfoByVpnIp(vpnIp, false) thi := c.GetHostInfoByVpnAddr(vpnIp, false)
expectedInfo := ControlHostInfo{ expectedInfo := ControlHostInfo{
VpnIp: vpnIp, VpnAddrs: []netip.Addr{vpnIp},
LocalIndex: 201, LocalIndex: 201,
RemoteIndex: 200, RemoteIndex: 200,
RemoteAddrs: []netip.AddrPort{remote2, remote1}, RemoteAddrs: []netip.AddrPort{remote2, remote1},
@ -100,13 +100,13 @@ func TestControl_GetHostInfoByVpnIp(t *testing.T) {
} }
// Make sure we don't have any unexpected fields // Make sure we don't have any unexpected fields
assertFields(t, []string{"VpnIp", "LocalIndex", "RemoteIndex", "RemoteAddrs", "Cert", "MessageCounter", "CurrentRemote", "CurrentRelaysToMe", "CurrentRelaysThroughMe"}, thi) assertFields(t, []string{"VpnAddrs", "LocalIndex", "RemoteIndex", "RemoteAddrs", "Cert", "MessageCounter", "CurrentRemote", "CurrentRelaysToMe", "CurrentRelaysThroughMe"}, thi)
assert.EqualValues(t, &expectedInfo, thi) assert.EqualValues(t, &expectedInfo, thi)
test.AssertDeepCopyEqual(t, &expectedInfo, thi) test.AssertDeepCopyEqual(t, &expectedInfo, thi)
// Make sure we don't panic if the host info doesn't have a cert yet // Make sure we don't panic if the host info doesn't have a cert yet
assert.NotPanics(t, func() { assert.NotPanics(t, func() {
thi = c.GetHostInfoByVpnIp(vpnIp2, false) thi = c.GetHostInfoByVpnAddr(vpnIp2, false)
}) })
} }

View File

@ -6,8 +6,6 @@ package nebula
import ( import (
"net/netip" "net/netip"
"github.com/slackhq/nebula/cert"
"github.com/google/gopacket" "github.com/google/gopacket"
"github.com/google/gopacket/layers" "github.com/google/gopacket/layers"
"github.com/slackhq/nebula/header" "github.com/slackhq/nebula/header"
@ -51,15 +49,15 @@ func (c *Control) WaitForTypeByIndex(toIndex uint32, msgType header.MessageType,
// This is necessary if you did not configure static hosts or are not running a lighthouse // This is necessary if you did not configure static hosts or are not running a lighthouse
func (c *Control) InjectLightHouseAddr(vpnIp netip.Addr, toAddr netip.AddrPort) { func (c *Control) InjectLightHouseAddr(vpnIp netip.Addr, toAddr netip.AddrPort) {
c.f.lightHouse.Lock() c.f.lightHouse.Lock()
remoteList := c.f.lightHouse.unlockedGetRemoteList(vpnIp) remoteList := c.f.lightHouse.unlockedGetRemoteList([]netip.Addr{vpnIp})
remoteList.Lock() remoteList.Lock()
defer remoteList.Unlock() defer remoteList.Unlock()
c.f.lightHouse.Unlock() c.f.lightHouse.Unlock()
if toAddr.Addr().Is4() { if toAddr.Addr().Is4() {
remoteList.unlockedPrependV4(vpnIp, NewIp4AndPortFromNetIP(toAddr.Addr(), toAddr.Port())) remoteList.unlockedPrependV4(vpnIp, netAddrToProtoV4AddrPort(toAddr.Addr(), toAddr.Port()))
} else { } else {
remoteList.unlockedPrependV6(vpnIp, NewIp6AndPortFromNetIP(toAddr.Addr(), toAddr.Port())) remoteList.unlockedPrependV6(vpnIp, netAddrToProtoV6AddrPort(toAddr.Addr(), toAddr.Port()))
} }
} }
@ -67,7 +65,7 @@ func (c *Control) InjectLightHouseAddr(vpnIp netip.Addr, toAddr netip.AddrPort)
// This is necessary to inform an initiator of possible relays for communicating with a responder // This is necessary to inform an initiator of possible relays for communicating with a responder
func (c *Control) InjectRelays(vpnIp netip.Addr, relayVpnIps []netip.Addr) { func (c *Control) InjectRelays(vpnIp netip.Addr, relayVpnIps []netip.Addr) {
c.f.lightHouse.Lock() c.f.lightHouse.Lock()
remoteList := c.f.lightHouse.unlockedGetRemoteList(vpnIp) remoteList := c.f.lightHouse.unlockedGetRemoteList([]netip.Addr{vpnIp})
remoteList.Lock() remoteList.Lock()
defer remoteList.Unlock() defer remoteList.Unlock()
c.f.lightHouse.Unlock() c.f.lightHouse.Unlock()
@ -99,21 +97,42 @@ func (c *Control) InjectUDPPacket(p *udp.Packet) {
} }
// InjectTunUDPPacket puts a udp packet on the tun interface. Using UDP here because it's a simpler protocol // InjectTunUDPPacket puts a udp packet on the tun interface. Using UDP here because it's a simpler protocol
func (c *Control) InjectTunUDPPacket(toIp netip.Addr, toPort uint16, fromPort uint16, data []byte) { func (c *Control) InjectTunUDPPacket(toAddr netip.Addr, toPort uint16, fromAddr netip.Addr, fromPort uint16, data []byte) {
//TODO: IPV6-WORK serialize := make([]gopacket.SerializableLayer, 0)
ip := layers.IPv4{ var netLayer gopacket.NetworkLayer
if toAddr.Is6() {
if !fromAddr.Is6() {
panic("Cant send ipv6 to ipv4")
}
ip := &layers.IPv6{
Version: 6,
NextHeader: layers.IPProtocolUDP,
SrcIP: fromAddr.Unmap().AsSlice(),
DstIP: toAddr.Unmap().AsSlice(),
}
serialize = append(serialize, ip)
netLayer = ip
} else {
if !fromAddr.Is4() {
panic("Cant send ipv4 to ipv6")
}
ip := &layers.IPv4{
Version: 4, Version: 4,
TTL: 64, TTL: 64,
Protocol: layers.IPProtocolUDP, Protocol: layers.IPProtocolUDP,
SrcIP: c.f.inside.Cidr().Addr().Unmap().AsSlice(), SrcIP: fromAddr.Unmap().AsSlice(),
DstIP: toIp.Unmap().AsSlice(), DstIP: toAddr.Unmap().AsSlice(),
}
serialize = append(serialize, ip)
netLayer = ip
} }
udp := layers.UDP{ udp := layers.UDP{
SrcPort: layers.UDPPort(fromPort), SrcPort: layers.UDPPort(fromPort),
DstPort: layers.UDPPort(toPort), DstPort: layers.UDPPort(toPort),
} }
err := udp.SetNetworkLayerForChecksum(&ip) err := udp.SetNetworkLayerForChecksum(netLayer)
if err != nil { if err != nil {
panic(err) panic(err)
} }
@ -123,7 +142,9 @@ func (c *Control) InjectTunUDPPacket(toIp netip.Addr, toPort uint16, fromPort ui
ComputeChecksums: true, ComputeChecksums: true,
FixLengths: true, FixLengths: true,
} }
err = gopacket.SerializeLayers(buffer, opt, &ip, &udp, gopacket.Payload(data))
serialize = append(serialize, &udp, gopacket.Payload(data))
err = gopacket.SerializeLayers(buffer, opt, serialize...)
if err != nil { if err != nil {
panic(err) panic(err)
} }
@ -131,8 +152,8 @@ func (c *Control) InjectTunUDPPacket(toIp netip.Addr, toPort uint16, fromPort ui
c.f.inside.(*overlay.TestTun).Send(buffer.Bytes()) c.f.inside.(*overlay.TestTun).Send(buffer.Bytes())
} }
func (c *Control) GetVpnIp() netip.Addr { func (c *Control) GetVpnAddrs() []netip.Addr {
return c.f.myVpnNet.Addr() return c.f.myVpnAddrs
} }
func (c *Control) GetUDPAddr() netip.AddrPort { func (c *Control) GetUDPAddr() netip.AddrPort {
@ -140,7 +161,7 @@ func (c *Control) GetUDPAddr() netip.AddrPort {
} }
func (c *Control) KillPendingTunnel(vpnIp netip.Addr) bool { func (c *Control) KillPendingTunnel(vpnIp netip.Addr) bool {
hostinfo := c.f.handshakeManager.QueryVpnIp(vpnIp) hostinfo := c.f.handshakeManager.QueryVpnAddr(vpnIp)
if hostinfo == nil { if hostinfo == nil {
return false return false
} }
@ -153,8 +174,8 @@ func (c *Control) GetHostmap() *HostMap {
return c.f.hostMap return c.f.hostMap
} }
func (c *Control) GetCert() cert.Certificate { func (c *Control) GetCertState() *CertState {
return c.f.pki.GetCertState().Certificate return c.f.pki.getCertState()
} }
func (c *Control) ReHandshake(vpnIp netip.Addr) { func (c *Control) ReHandshake(vpnIp netip.Addr) {

View File

@ -8,6 +8,7 @@ import (
"strings" "strings"
"sync" "sync"
"github.com/gaissmai/bart"
"github.com/miekg/dns" "github.com/miekg/dns"
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
"github.com/slackhq/nebula/config" "github.com/slackhq/nebula/config"
@ -21,24 +22,39 @@ var dnsAddr string
type dnsRecords struct { type dnsRecords struct {
sync.RWMutex sync.RWMutex
dnsMap map[string]string l *logrus.Logger
dnsMap4 map[string]netip.Addr
dnsMap6 map[string]netip.Addr
hostMap *HostMap hostMap *HostMap
myVpnAddrsTable *bart.Table[struct{}]
} }
func newDnsRecords(hostMap *HostMap) *dnsRecords { func newDnsRecords(l *logrus.Logger, cs *CertState, hostMap *HostMap) *dnsRecords {
return &dnsRecords{ return &dnsRecords{
dnsMap: make(map[string]string), l: l,
dnsMap4: make(map[string]netip.Addr),
dnsMap6: make(map[string]netip.Addr),
hostMap: hostMap, hostMap: hostMap,
myVpnAddrsTable: cs.myVpnAddrsTable,
} }
} }
func (d *dnsRecords) Query(data string) string { func (d *dnsRecords) Query(q uint16, data string) netip.Addr {
data = strings.ToLower(data)
d.RLock() d.RLock()
defer d.RUnlock() defer d.RUnlock()
if r, ok := d.dnsMap[strings.ToLower(data)]; ok { switch q {
case dns.TypeA:
if r, ok := d.dnsMap4[data]; ok {
return r return r
} }
return "" case dns.TypeAAAA:
if r, ok := d.dnsMap6[data]; ok {
return r
}
}
return netip.Addr{}
} }
func (d *dnsRecords) QueryCert(data string) string { func (d *dnsRecords) QueryCert(data string) string {
@ -47,7 +63,7 @@ func (d *dnsRecords) QueryCert(data string) string {
return "" return ""
} }
hostinfo := d.hostMap.QueryVpnIp(ip) hostinfo := d.hostMap.QueryVpnAddr(ip)
if hostinfo == nil { if hostinfo == nil {
return "" return ""
} }
@ -64,38 +80,62 @@ func (d *dnsRecords) QueryCert(data string) string {
return string(b) return string(b)
} }
func (d *dnsRecords) Add(host, data string) { // Add adds the first IPv4 and IPv6 address that appears in `addresses` as the record for `host`
func (d *dnsRecords) Add(host string, addresses []netip.Addr) {
host = strings.ToLower(host)
d.Lock() d.Lock()
defer d.Unlock() defer d.Unlock()
d.dnsMap[strings.ToLower(host)] = data haveV4 := false
haveV6 := false
for _, addr := range addresses {
if addr.Is4() && !haveV4 {
d.dnsMap4[host] = addr
haveV4 = true
} else if addr.Is6() && !haveV6 {
d.dnsMap6[host] = addr
haveV6 = true
}
if haveV4 && haveV6 {
break
}
}
} }
func parseQuery(l *logrus.Logger, m *dns.Msg, w dns.ResponseWriter) { func (d *dnsRecords) isSelfNebulaOrLocalhost(addr string) bool {
a, _, _ := net.SplitHostPort(addr)
b, err := netip.ParseAddr(a)
if err != nil {
return false
}
if b.IsLoopback() {
return true
}
_, found := d.myVpnAddrsTable.Lookup(b)
return found //if we found it in this table, it's good
}
func (d *dnsRecords) parseQuery(m *dns.Msg, w dns.ResponseWriter) {
for _, q := range m.Question { for _, q := range m.Question {
switch q.Qtype { switch q.Qtype {
case dns.TypeA: case dns.TypeA, dns.TypeAAAA:
l.Debugf("Query for A %s", q.Name) qType := dns.TypeToString[q.Qtype]
ip := dnsR.Query(q.Name) d.l.Debugf("Query for %s %s", qType, q.Name)
if ip != "" { ip := d.Query(q.Qtype, q.Name)
rr, err := dns.NewRR(fmt.Sprintf("%s A %s", q.Name, ip)) if ip.IsValid() {
rr, err := dns.NewRR(fmt.Sprintf("%s %s %s", q.Name, qType, ip))
if err == nil { if err == nil {
m.Answer = append(m.Answer, rr) m.Answer = append(m.Answer, rr)
} }
} }
case dns.TypeTXT: case dns.TypeTXT:
a, _, _ := net.SplitHostPort(w.RemoteAddr().String()) // We only answer these queries from nebula nodes or localhost
b, err := netip.ParseAddr(a) if !d.isSelfNebulaOrLocalhost(w.RemoteAddr().String()) {
if err != nil {
return return
} }
d.l.Debugf("Query for TXT %s", q.Name)
// We don't answer these queries from non nebula nodes or localhost ip := d.QueryCert(q.Name)
//l.Debugf("Does %s contain %s", b, dnsR.hostMap.vpnCIDR)
if !dnsR.hostMap.vpnCIDR.Contains(b) && a != "127.0.0.1" {
return
}
l.Debugf("Query for TXT %s", q.Name)
ip := dnsR.QueryCert(q.Name)
if ip != "" { if ip != "" {
rr, err := dns.NewRR(fmt.Sprintf("%s TXT %s", q.Name, ip)) rr, err := dns.NewRR(fmt.Sprintf("%s TXT %s", q.Name, ip))
if err == nil { if err == nil {
@ -110,26 +150,24 @@ func parseQuery(l *logrus.Logger, m *dns.Msg, w dns.ResponseWriter) {
} }
} }
func handleDnsRequest(l *logrus.Logger, w dns.ResponseWriter, r *dns.Msg) { func (d *dnsRecords) handleDnsRequest(w dns.ResponseWriter, r *dns.Msg) {
m := new(dns.Msg) m := new(dns.Msg)
m.SetReply(r) m.SetReply(r)
m.Compress = false m.Compress = false
switch r.Opcode { switch r.Opcode {
case dns.OpcodeQuery: case dns.OpcodeQuery:
parseQuery(l, m, w) d.parseQuery(m, w)
} }
w.WriteMsg(m) w.WriteMsg(m)
} }
func dnsMain(l *logrus.Logger, hostMap *HostMap, c *config.C) func() { func dnsMain(l *logrus.Logger, cs *CertState, hostMap *HostMap, c *config.C) func() {
dnsR = newDnsRecords(hostMap) dnsR = newDnsRecords(l, cs, hostMap)
// attach request handler func // attach request handler func
dns.HandleFunc(".", func(w dns.ResponseWriter, r *dns.Msg) { dns.HandleFunc(".", dnsR.handleDnsRequest)
handleDnsRequest(l, w, r)
})
c.RegisterReloadCallback(func(c *config.C) { c.RegisterReloadCallback(func(c *config.C) {
reloadDns(l, c) reloadDns(l, c)

View File

@ -1,23 +1,38 @@
package nebula package nebula
import ( import (
"net/netip"
"testing" "testing"
"github.com/miekg/dns" "github.com/miekg/dns"
"github.com/sirupsen/logrus"
"github.com/slackhq/nebula/config" "github.com/slackhq/nebula/config"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
) )
func TestParsequery(t *testing.T) { func TestParsequery(t *testing.T) {
//TODO: This test is basically pointless l := logrus.New()
hostMap := &HostMap{} hostMap := &HostMap{}
ds := newDnsRecords(hostMap) ds := newDnsRecords(l, &CertState{}, hostMap)
ds.Add("test.com.com", "1.2.3.4") addrs := []netip.Addr{
netip.MustParseAddr("1.2.3.4"),
netip.MustParseAddr("1.2.3.5"),
netip.MustParseAddr("fd01::24"),
netip.MustParseAddr("fd01::25"),
}
ds.Add("test.com.com", addrs)
m := new(dns.Msg) m := &dns.Msg{}
m.SetQuestion("test.com.com", dns.TypeA) m.SetQuestion("test.com.com", dns.TypeA)
ds.parseQuery(m, nil)
assert.NotNil(t, m.Answer)
assert.Equal(t, "1.2.3.4", m.Answer[0].(*dns.A).A.String())
//parseQuery(m) m = &dns.Msg{}
m.SetQuestion("test.com.com", dns.TypeAAAA)
ds.parseQuery(m, nil)
assert.NotNil(t, m.Answer)
assert.Equal(t, "fd01::24", m.Answer[0].(*dns.AAAA).AAAA.String())
} }
func Test_getDnsServerAddr(t *testing.T) { func Test_getDnsServerAddr(t *testing.T) {

View File

@ -4,7 +4,6 @@
package e2e package e2e
import ( import (
"fmt"
"net/netip" "net/netip"
"slices" "slices"
"testing" "testing"
@ -12,6 +11,7 @@ import (
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
"github.com/slackhq/nebula" "github.com/slackhq/nebula"
"github.com/slackhq/nebula/cert"
"github.com/slackhq/nebula/e2e/router" "github.com/slackhq/nebula/e2e/router"
"github.com/slackhq/nebula/header" "github.com/slackhq/nebula/header"
"github.com/slackhq/nebula/udp" "github.com/slackhq/nebula/udp"
@ -21,11 +21,11 @@ import (
func BenchmarkHotPath(b *testing.B) { func BenchmarkHotPath(b *testing.B) {
ca, _, caKey, _ := NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{}) ca, _, caKey, _ := NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
myControl, _, _, _ := newSimpleServer(ca, caKey, "me", "10.128.0.1/24", nil) myControl, myVpnIpNet, _, _ := newSimpleServer(cert.Version1, ca, caKey, "me", "10.128.0.1/24", nil)
theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(ca, caKey, "them", "10.128.0.2/24", nil) theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "them", "10.128.0.2/24", nil)
// Put their info in our lighthouse // Put their info in our lighthouse
myControl.InjectLightHouseAddr(theirVpnIpNet.Addr(), theirUdpAddr) myControl.InjectLightHouseAddr(theirVpnIpNet[0].Addr(), theirUdpAddr)
// Start the servers // Start the servers
myControl.Start() myControl.Start()
@ -35,7 +35,7 @@ func BenchmarkHotPath(b *testing.B) {
r.CancelFlowLogs() r.CancelFlowLogs()
for n := 0; n < b.N; n++ { for n := 0; n < b.N; n++ {
myControl.InjectTunUDPPacket(theirVpnIpNet.Addr(), 80, 80, []byte("Hi from me")) myControl.InjectTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi from me"))
_ = r.RouteForAllUntilTxTun(theirControl) _ = r.RouteForAllUntilTxTun(theirControl)
} }
@ -45,18 +45,18 @@ func BenchmarkHotPath(b *testing.B) {
func TestGoodHandshake(t *testing.T) { func TestGoodHandshake(t *testing.T) {
ca, _, caKey, _ := NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{}) ca, _, caKey, _ := NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(ca, caKey, "me", "10.128.0.1/24", nil) myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "me", "10.128.0.1/24", nil)
theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(ca, caKey, "them", "10.128.0.2/24", nil) theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "them", "10.128.0.2/24", nil)
// Put their info in our lighthouse // Put their info in our lighthouse
myControl.InjectLightHouseAddr(theirVpnIpNet.Addr(), theirUdpAddr) myControl.InjectLightHouseAddr(theirVpnIpNet[0].Addr(), theirUdpAddr)
// Start the servers // Start the servers
myControl.Start() myControl.Start()
theirControl.Start() theirControl.Start()
t.Log("Send a udp packet through to begin standing up the tunnel, this should come out the other side") t.Log("Send a udp packet through to begin standing up the tunnel, this should come out the other side")
myControl.InjectTunUDPPacket(theirVpnIpNet.Addr(), 80, 80, []byte("Hi from me")) myControl.InjectTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi from me"))
t.Log("Have them consume my stage 0 packet. They have a tunnel now") t.Log("Have them consume my stage 0 packet. They have a tunnel now")
theirControl.InjectUDPPacket(myControl.GetFromUDP(true)) theirControl.InjectUDPPacket(myControl.GetFromUDP(true))
@ -77,16 +77,16 @@ func TestGoodHandshake(t *testing.T) {
myControl.WaitForType(1, 0, theirControl) myControl.WaitForType(1, 0, theirControl)
t.Log("Make sure our host infos are correct") t.Log("Make sure our host infos are correct")
assertHostInfoPair(t, myUdpAddr, theirUdpAddr, myVpnIpNet.Addr(), theirVpnIpNet.Addr(), myControl, theirControl) assertHostInfoPair(t, myUdpAddr, theirUdpAddr, myVpnIpNet, theirVpnIpNet, myControl, theirControl)
t.Log("Get that cached packet and make sure it looks right") t.Log("Get that cached packet and make sure it looks right")
myCachedPacket := theirControl.GetFromTun(true) myCachedPacket := theirControl.GetFromTun(true)
assertUdpPacket(t, []byte("Hi from me"), myCachedPacket, myVpnIpNet.Addr(), theirVpnIpNet.Addr(), 80, 80) assertUdpPacket(t, []byte("Hi from me"), myCachedPacket, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), 80, 80)
t.Log("Do a bidirectional tunnel test") t.Log("Do a bidirectional tunnel test")
r := router.NewR(t, myControl, theirControl) r := router.NewR(t, myControl, theirControl)
defer r.RenderFlow() defer r.RenderFlow()
assertTunnel(t, myVpnIpNet.Addr(), theirVpnIpNet.Addr(), myControl, theirControl, r) assertTunnel(t, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), myControl, theirControl, r)
r.RenderHostmaps("Final hostmaps", myControl, theirControl) r.RenderHostmaps("Final hostmaps", myControl, theirControl)
myControl.Stop() myControl.Stop()
@ -97,12 +97,12 @@ func TestGoodHandshake(t *testing.T) {
func TestWrongResponderHandshake(t *testing.T) { func TestWrongResponderHandshake(t *testing.T) {
ca, _, caKey, _ := NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{}) ca, _, caKey, _ := NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(ca, caKey, "me", "10.128.0.100/24", nil) myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "me", "10.128.0.100/24", nil)
theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(ca, caKey, "them", "10.128.0.99/24", nil) theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "them", "10.128.0.99/24", nil)
evilControl, evilVpnIp, evilUdpAddr, _ := newSimpleServer(ca, caKey, "evil", "10.128.0.2/24", nil) evilControl, evilVpnIp, evilUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "evil", "10.128.0.2/24", nil)
// Put the evil udp addr in for their vpn Ip, this is a case of being lied to by the lighthouse. // Put the evil udp addr in for their vpn Ip, this is a case of being lied to by the lighthouse.
myControl.InjectLightHouseAddr(theirVpnIpNet.Addr(), evilUdpAddr) myControl.InjectLightHouseAddr(theirVpnIpNet[0].Addr(), evilUdpAddr)
// Build a router so we don't have to reason who gets which packet // Build a router so we don't have to reason who gets which packet
r := router.NewR(t, myControl, theirControl, evilControl) r := router.NewR(t, myControl, theirControl, evilControl)
@ -114,7 +114,7 @@ func TestWrongResponderHandshake(t *testing.T) {
evilControl.Start() evilControl.Start()
t.Log("Start the handshake process, we will route until we see the evil tunnel closed") t.Log("Start the handshake process, we will route until we see the evil tunnel closed")
myControl.InjectTunUDPPacket(theirVpnIpNet.Addr(), 80, 80, []byte("Hi from me")) myControl.InjectTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi from me"))
h := &header.H{} h := &header.H{}
r.RouteForAllExitFunc(func(p *udp.Packet, c *nebula.Control) router.ExitType { r.RouteForAllExitFunc(func(p *udp.Packet, c *nebula.Control) router.ExitType {
@ -131,8 +131,8 @@ func TestWrongResponderHandshake(t *testing.T) {
}) })
t.Log("Evil tunnel is closed, inject the correct udp addr for them") t.Log("Evil tunnel is closed, inject the correct udp addr for them")
myControl.InjectLightHouseAddr(theirVpnIpNet.Addr(), theirUdpAddr) myControl.InjectLightHouseAddr(theirVpnIpNet[0].Addr(), theirUdpAddr)
pendingHi := myControl.GetHostInfoByVpnIp(theirVpnIpNet.Addr(), true) pendingHi := myControl.GetHostInfoByVpnAddr(theirVpnIpNet[0].Addr(), true)
assert.NotContains(t, pendingHi.RemoteAddrs, evilUdpAddr) assert.NotContains(t, pendingHi.RemoteAddrs, evilUdpAddr)
t.Log("Route until we see the cached packet") t.Log("Route until we see the cached packet")
@ -153,18 +153,18 @@ func TestWrongResponderHandshake(t *testing.T) {
t.Log("My cached packet should be received by them") t.Log("My cached packet should be received by them")
myCachedPacket := theirControl.GetFromTun(true) myCachedPacket := theirControl.GetFromTun(true)
assertUdpPacket(t, []byte("Hi from me"), myCachedPacket, myVpnIpNet.Addr(), theirVpnIpNet.Addr(), 80, 80) assertUdpPacket(t, []byte("Hi from me"), myCachedPacket, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), 80, 80)
t.Log("Test the tunnel with them") t.Log("Test the tunnel with them")
assertHostInfoPair(t, myUdpAddr, theirUdpAddr, myVpnIpNet.Addr(), theirVpnIpNet.Addr(), myControl, theirControl) assertHostInfoPair(t, myUdpAddr, theirUdpAddr, myVpnIpNet, theirVpnIpNet, myControl, theirControl)
assertTunnel(t, myVpnIpNet.Addr(), theirVpnIpNet.Addr(), myControl, theirControl, r) assertTunnel(t, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), myControl, theirControl, r)
t.Log("Flush all packets from all controllers") t.Log("Flush all packets from all controllers")
r.FlushAll() r.FlushAll()
t.Log("Ensure ensure I don't have any hostinfo artifacts from evil") t.Log("Ensure ensure I don't have any hostinfo artifacts from evil")
assert.Nil(t, myControl.GetHostInfoByVpnIp(evilVpnIp.Addr(), true), "My pending hostmap should not contain evil") assert.Nil(t, myControl.GetHostInfoByVpnAddr(evilVpnIp[0].Addr(), true), "My pending hostmap should not contain evil")
assert.Nil(t, myControl.GetHostInfoByVpnIp(evilVpnIp.Addr(), false), "My main hostmap should not contain evil") assert.Nil(t, myControl.GetHostInfoByVpnAddr(evilVpnIp[0].Addr(), false), "My main hostmap should not contain evil")
//TODO: assert hostmaps for everyone //TODO: assert hostmaps for everyone
r.RenderHostmaps("Final hostmaps", myControl, theirControl, evilControl) r.RenderHostmaps("Final hostmaps", myControl, theirControl, evilControl)
@ -176,17 +176,17 @@ func TestWrongResponderHandshake(t *testing.T) {
func TestWrongResponderHandshakeStaticHostMap(t *testing.T) { func TestWrongResponderHandshakeStaticHostMap(t *testing.T) {
ca, _, caKey, _ := NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{}) ca, _, caKey, _ := NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(ca, caKey, "them", "10.128.0.99/24", nil) theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "them", "10.128.0.99/24", nil)
evilControl, evilVpnIp, evilUdpAddr, _ := newSimpleServer(ca, caKey, "evil", "10.128.0.2/24", nil) evilControl, evilVpnIp, evilUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "evil", "10.128.0.2/24", nil)
o := m{ o := m{
"static_host_map": m{ "static_host_map": m{
theirVpnIpNet.Addr().String(): []string{evilUdpAddr.String()}, theirVpnIpNet[0].Addr().String(): []string{evilUdpAddr.String()},
}, },
} }
myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(ca, caKey, "me", "10.128.0.100/24", o) myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "me", "10.128.0.100/24", o)
// Put the evil udp addr in for their vpn addr, this is a case of a remote at a static entry changing its vpn addr. // Put the evil udp addr in for their vpn addr, this is a case of a remote at a static entry changing its vpn addr.
myControl.InjectLightHouseAddr(theirVpnIpNet.Addr(), evilUdpAddr) myControl.InjectLightHouseAddr(theirVpnIpNet[0].Addr(), evilUdpAddr)
// Build a router so we don't have to reason who gets which packet // Build a router so we don't have to reason who gets which packet
r := router.NewR(t, myControl, theirControl, evilControl) r := router.NewR(t, myControl, theirControl, evilControl)
@ -198,7 +198,7 @@ func TestWrongResponderHandshakeStaticHostMap(t *testing.T) {
evilControl.Start() evilControl.Start()
t.Log("Start the handshake process, we will route until we see the evil tunnel closed") t.Log("Start the handshake process, we will route until we see the evil tunnel closed")
myControl.InjectTunUDPPacket(theirVpnIpNet.Addr(), 80, 80, []byte("Hi from me")) myControl.InjectTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi from me"))
h := &header.H{} h := &header.H{}
r.RouteForAllExitFunc(func(p *udp.Packet, c *nebula.Control) router.ExitType { r.RouteForAllExitFunc(func(p *udp.Packet, c *nebula.Control) router.ExitType {
@ -215,8 +215,8 @@ func TestWrongResponderHandshakeStaticHostMap(t *testing.T) {
}) })
t.Log("Evil tunnel is closed, inject the correct udp addr for them") t.Log("Evil tunnel is closed, inject the correct udp addr for them")
myControl.InjectLightHouseAddr(theirVpnIpNet.Addr(), theirUdpAddr) myControl.InjectLightHouseAddr(theirVpnIpNet[0].Addr(), theirUdpAddr)
pendingHi := myControl.GetHostInfoByVpnIp(theirVpnIpNet.Addr(), true) pendingHi := myControl.GetHostInfoByVpnAddr(theirVpnIpNet[0].Addr(), true)
assert.NotContains(t, pendingHi.RemoteAddrs, evilUdpAddr) assert.NotContains(t, pendingHi.RemoteAddrs, evilUdpAddr)
t.Log("Route until we see the cached packet") t.Log("Route until we see the cached packet")
@ -237,18 +237,19 @@ func TestWrongResponderHandshakeStaticHostMap(t *testing.T) {
t.Log("My cached packet should be received by them") t.Log("My cached packet should be received by them")
myCachedPacket := theirControl.GetFromTun(true) myCachedPacket := theirControl.GetFromTun(true)
assertUdpPacket(t, []byte("Hi from me"), myCachedPacket, myVpnIpNet.Addr(), theirVpnIpNet.Addr(), 80, 80) assertUdpPacket(t, []byte("Hi from me"), myCachedPacket, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), 80, 80)
t.Log("Test the tunnel with them") t.Log("Test the tunnel with them")
assertHostInfoPair(t, myUdpAddr, theirUdpAddr, myVpnIpNet.Addr(), theirVpnIpNet.Addr(), myControl, theirControl) assertHostInfoPair(t, myUdpAddr, theirUdpAddr, myVpnIpNet, theirVpnIpNet, myControl, theirControl)
assertTunnel(t, myVpnIpNet.Addr(), theirVpnIpNet.Addr(), myControl, theirControl, r) assertTunnel(t, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), myControl, theirControl, r)
t.Log("Flush all packets from all controllers") t.Log("Flush all packets from all controllers")
r.FlushAll() r.FlushAll()
t.Log("Ensure ensure I don't have any hostinfo artifacts from evil") t.Log("Ensure ensure I don't have any hostinfo artifacts from evil")
assert.Nil(t, myControl.GetHostInfoByVpnIp(evilVpnIp.Addr(), true), "My pending hostmap should not contain evil") assert.Nil(t, myControl.GetHostInfoByVpnAddr(evilVpnIp[0].Addr(), true), "My pending hostmap should not contain evil")
assert.Nil(t, myControl.GetHostInfoByVpnIp(evilVpnIp.Addr(), false), "My main hostmap should not contain evil") assert.Nil(t, myControl.GetHostInfoByVpnAddr(evilVpnIp[0].Addr(), false), "My main hostmap should not contain evil")
//NOTE: if evil lost the handshake race it may still have a tunnel since me would reject the handshake since the tunnel is complete
//TODO: assert hostmaps for everyone //TODO: assert hostmaps for everyone
r.RenderHostmaps("Final hostmaps", myControl, theirControl, evilControl) r.RenderHostmaps("Final hostmaps", myControl, theirControl, evilControl)
@ -262,12 +263,12 @@ func TestStage1Race(t *testing.T) {
// But will eventually collapse down to a single tunnel // But will eventually collapse down to a single tunnel
ca, _, caKey, _ := NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{}) ca, _, caKey, _ := NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(ca, caKey, "me ", "10.128.0.1/24", nil) myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "me ", "10.128.0.1/24", nil)
theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(ca, caKey, "them", "10.128.0.2/24", nil) theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "them", "10.128.0.2/24", nil)
// Put their info in our lighthouse and vice versa // Put their info in our lighthouse and vice versa
myControl.InjectLightHouseAddr(theirVpnIpNet.Addr(), theirUdpAddr) myControl.InjectLightHouseAddr(theirVpnIpNet[0].Addr(), theirUdpAddr)
theirControl.InjectLightHouseAddr(myVpnIpNet.Addr(), myUdpAddr) theirControl.InjectLightHouseAddr(myVpnIpNet[0].Addr(), myUdpAddr)
// Build a router so we don't have to reason who gets which packet // Build a router so we don't have to reason who gets which packet
r := router.NewR(t, myControl, theirControl) r := router.NewR(t, myControl, theirControl)
@ -278,8 +279,8 @@ func TestStage1Race(t *testing.T) {
theirControl.Start() theirControl.Start()
t.Log("Trigger a handshake to start on both me and them") t.Log("Trigger a handshake to start on both me and them")
myControl.InjectTunUDPPacket(theirVpnIpNet.Addr(), 80, 80, []byte("Hi from me")) myControl.InjectTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi from me"))
theirControl.InjectTunUDPPacket(myVpnIpNet.Addr(), 80, 80, []byte("Hi from them")) theirControl.InjectTunUDPPacket(myVpnIpNet[0].Addr(), 80, theirVpnIpNet[0].Addr(), 80, []byte("Hi from them"))
t.Log("Get both stage 1 handshake packets") t.Log("Get both stage 1 handshake packets")
myHsForThem := myControl.GetFromUDP(true) myHsForThem := myControl.GetFromUDP(true)
@ -291,14 +292,14 @@ func TestStage1Race(t *testing.T) {
r.Log("Route until they receive a message packet") r.Log("Route until they receive a message packet")
myCachedPacket := r.RouteForAllUntilTxTun(theirControl) myCachedPacket := r.RouteForAllUntilTxTun(theirControl)
assertUdpPacket(t, []byte("Hi from me"), myCachedPacket, myVpnIpNet.Addr(), theirVpnIpNet.Addr(), 80, 80) assertUdpPacket(t, []byte("Hi from me"), myCachedPacket, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), 80, 80)
r.Log("Their cached packet should be received by me") r.Log("Their cached packet should be received by me")
theirCachedPacket := r.RouteForAllUntilTxTun(myControl) theirCachedPacket := r.RouteForAllUntilTxTun(myControl)
assertUdpPacket(t, []byte("Hi from them"), theirCachedPacket, theirVpnIpNet.Addr(), myVpnIpNet.Addr(), 80, 80) assertUdpPacket(t, []byte("Hi from them"), theirCachedPacket, theirVpnIpNet[0].Addr(), myVpnIpNet[0].Addr(), 80, 80)
r.Log("Do a bidirectional tunnel test") r.Log("Do a bidirectional tunnel test")
assertTunnel(t, myVpnIpNet.Addr(), theirVpnIpNet.Addr(), myControl, theirControl, r) assertTunnel(t, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), myControl, theirControl, r)
myHostmapHosts := myControl.ListHostmapHosts(false) myHostmapHosts := myControl.ListHostmapHosts(false)
myHostmapIndexes := myControl.ListHostmapIndexes(false) myHostmapIndexes := myControl.ListHostmapIndexes(false)
@ -316,7 +317,7 @@ func TestStage1Race(t *testing.T) {
r.Log("Spin until connection manager tears down a tunnel") r.Log("Spin until connection manager tears down a tunnel")
for len(myControl.GetHostmap().Indexes)+len(theirControl.GetHostmap().Indexes) > 2 { for len(myControl.GetHostmap().Indexes)+len(theirControl.GetHostmap().Indexes) > 2 {
assertTunnel(t, myVpnIpNet.Addr(), theirVpnIpNet.Addr(), myControl, theirControl, r) assertTunnel(t, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), myControl, theirControl, r)
t.Log("Connection manager hasn't ticked yet") t.Log("Connection manager hasn't ticked yet")
time.Sleep(time.Second) time.Sleep(time.Second)
} }
@ -339,12 +340,12 @@ func TestStage1Race(t *testing.T) {
func TestUncleanShutdownRaceLoser(t *testing.T) { func TestUncleanShutdownRaceLoser(t *testing.T) {
ca, _, caKey, _ := NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{}) ca, _, caKey, _ := NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(ca, caKey, "me ", "10.128.0.1/24", nil) myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "me ", "10.128.0.1/24", nil)
theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(ca, caKey, "them", "10.128.0.2/24", nil) theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "them", "10.128.0.2/24", nil)
// Teach my how to get to the relay and that their can be reached via the relay // Teach my how to get to the relay and that their can be reached via the relay
myControl.InjectLightHouseAddr(theirVpnIpNet.Addr(), theirUdpAddr) myControl.InjectLightHouseAddr(theirVpnIpNet[0].Addr(), theirUdpAddr)
theirControl.InjectLightHouseAddr(myVpnIpNet.Addr(), myUdpAddr) theirControl.InjectLightHouseAddr(myVpnIpNet[0].Addr(), myUdpAddr)
// Build a router so we don't have to reason who gets which packet // Build a router so we don't have to reason who gets which packet
r := router.NewR(t, myControl, theirControl) r := router.NewR(t, myControl, theirControl)
@ -355,10 +356,10 @@ func TestUncleanShutdownRaceLoser(t *testing.T) {
theirControl.Start() theirControl.Start()
r.Log("Trigger a handshake from me to them") r.Log("Trigger a handshake from me to them")
myControl.InjectTunUDPPacket(theirVpnIpNet.Addr(), 80, 80, []byte("Hi from me")) myControl.InjectTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi from me"))
p := r.RouteForAllUntilTxTun(theirControl) p := r.RouteForAllUntilTxTun(theirControl)
assertUdpPacket(t, []byte("Hi from me"), p, myVpnIpNet.Addr(), theirVpnIpNet.Addr(), 80, 80) assertUdpPacket(t, []byte("Hi from me"), p, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), 80, 80)
r.Log("Nuke my hostmap") r.Log("Nuke my hostmap")
myHostmap := myControl.GetHostmap() myHostmap := myControl.GetHostmap()
@ -366,17 +367,17 @@ func TestUncleanShutdownRaceLoser(t *testing.T) {
myHostmap.Indexes = map[uint32]*nebula.HostInfo{} myHostmap.Indexes = map[uint32]*nebula.HostInfo{}
myHostmap.RemoteIndexes = map[uint32]*nebula.HostInfo{} myHostmap.RemoteIndexes = map[uint32]*nebula.HostInfo{}
myControl.InjectTunUDPPacket(theirVpnIpNet.Addr(), 80, 80, []byte("Hi from me again")) myControl.InjectTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi from me again"))
p = r.RouteForAllUntilTxTun(theirControl) p = r.RouteForAllUntilTxTun(theirControl)
assertUdpPacket(t, []byte("Hi from me again"), p, myVpnIpNet.Addr(), theirVpnIpNet.Addr(), 80, 80) assertUdpPacket(t, []byte("Hi from me again"), p, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), 80, 80)
r.Log("Assert the tunnel works") r.Log("Assert the tunnel works")
assertTunnel(t, myVpnIpNet.Addr(), theirVpnIpNet.Addr(), myControl, theirControl, r) assertTunnel(t, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), myControl, theirControl, r)
r.Log("Wait for the dead index to go away") r.Log("Wait for the dead index to go away")
start := len(theirControl.GetHostmap().Indexes) start := len(theirControl.GetHostmap().Indexes)
for { for {
assertTunnel(t, myVpnIpNet.Addr(), theirVpnIpNet.Addr(), myControl, theirControl, r) assertTunnel(t, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), myControl, theirControl, r)
if len(theirControl.GetHostmap().Indexes) < start { if len(theirControl.GetHostmap().Indexes) < start {
break break
} }
@ -388,12 +389,12 @@ func TestUncleanShutdownRaceLoser(t *testing.T) {
func TestUncleanShutdownRaceWinner(t *testing.T) { func TestUncleanShutdownRaceWinner(t *testing.T) {
ca, _, caKey, _ := NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{}) ca, _, caKey, _ := NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(ca, caKey, "me ", "10.128.0.1/24", nil) myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "me ", "10.128.0.1/24", nil)
theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(ca, caKey, "them", "10.128.0.2/24", nil) theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "them", "10.128.0.2/24", nil)
// Teach my how to get to the relay and that their can be reached via the relay // Teach my how to get to the relay and that their can be reached via the relay
myControl.InjectLightHouseAddr(theirVpnIpNet.Addr(), theirUdpAddr) myControl.InjectLightHouseAddr(theirVpnIpNet[0].Addr(), theirUdpAddr)
theirControl.InjectLightHouseAddr(myVpnIpNet.Addr(), myUdpAddr) theirControl.InjectLightHouseAddr(myVpnIpNet[0].Addr(), myUdpAddr)
// Build a router so we don't have to reason who gets which packet // Build a router so we don't have to reason who gets which packet
r := router.NewR(t, myControl, theirControl) r := router.NewR(t, myControl, theirControl)
@ -404,10 +405,10 @@ func TestUncleanShutdownRaceWinner(t *testing.T) {
theirControl.Start() theirControl.Start()
r.Log("Trigger a handshake from me to them") r.Log("Trigger a handshake from me to them")
myControl.InjectTunUDPPacket(theirVpnIpNet.Addr(), 80, 80, []byte("Hi from me")) myControl.InjectTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi from me"))
p := r.RouteForAllUntilTxTun(theirControl) p := r.RouteForAllUntilTxTun(theirControl)
assertUdpPacket(t, []byte("Hi from me"), p, myVpnIpNet.Addr(), theirVpnIpNet.Addr(), 80, 80) assertUdpPacket(t, []byte("Hi from me"), p, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), 80, 80)
r.RenderHostmaps("Final hostmaps", myControl, theirControl) r.RenderHostmaps("Final hostmaps", myControl, theirControl)
r.Log("Nuke my hostmap") r.Log("Nuke my hostmap")
@ -416,18 +417,18 @@ func TestUncleanShutdownRaceWinner(t *testing.T) {
theirHostmap.Indexes = map[uint32]*nebula.HostInfo{} theirHostmap.Indexes = map[uint32]*nebula.HostInfo{}
theirHostmap.RemoteIndexes = map[uint32]*nebula.HostInfo{} theirHostmap.RemoteIndexes = map[uint32]*nebula.HostInfo{}
theirControl.InjectTunUDPPacket(myVpnIpNet.Addr(), 80, 80, []byte("Hi from them again")) theirControl.InjectTunUDPPacket(myVpnIpNet[0].Addr(), 80, theirVpnIpNet[0].Addr(), 80, []byte("Hi from them again"))
p = r.RouteForAllUntilTxTun(myControl) p = r.RouteForAllUntilTxTun(myControl)
assertUdpPacket(t, []byte("Hi from them again"), p, theirVpnIpNet.Addr(), myVpnIpNet.Addr(), 80, 80) assertUdpPacket(t, []byte("Hi from them again"), p, theirVpnIpNet[0].Addr(), myVpnIpNet[0].Addr(), 80, 80)
r.RenderHostmaps("Derp hostmaps", myControl, theirControl) r.RenderHostmaps("Derp hostmaps", myControl, theirControl)
r.Log("Assert the tunnel works") r.Log("Assert the tunnel works")
assertTunnel(t, myVpnIpNet.Addr(), theirVpnIpNet.Addr(), myControl, theirControl, r) assertTunnel(t, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), myControl, theirControl, r)
r.Log("Wait for the dead index to go away") r.Log("Wait for the dead index to go away")
start := len(myControl.GetHostmap().Indexes) start := len(myControl.GetHostmap().Indexes)
for { for {
assertTunnel(t, myVpnIpNet.Addr(), theirVpnIpNet.Addr(), myControl, theirControl, r) assertTunnel(t, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), myControl, theirControl, r)
if len(myControl.GetHostmap().Indexes) < start { if len(myControl.GetHostmap().Indexes) < start {
break break
} }
@ -439,14 +440,14 @@ func TestUncleanShutdownRaceWinner(t *testing.T) {
func TestRelays(t *testing.T) { func TestRelays(t *testing.T) {
ca, _, caKey, _ := NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{}) ca, _, caKey, _ := NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
myControl, myVpnIpNet, _, _ := newSimpleServer(ca, caKey, "me ", "10.128.0.1/24", m{"relay": m{"use_relays": true}}) myControl, myVpnIpNet, _, _ := newSimpleServer(cert.Version1, ca, caKey, "me ", "10.128.0.1/24", m{"relay": m{"use_relays": true}})
relayControl, relayVpnIpNet, relayUdpAddr, _ := newSimpleServer(ca, caKey, "relay ", "10.128.0.128/24", m{"relay": m{"am_relay": true}}) relayControl, relayVpnIpNet, relayUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "relay ", "10.128.0.128/24", m{"relay": m{"am_relay": true}})
theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(ca, caKey, "them ", "10.128.0.2/24", m{"relay": m{"use_relays": true}}) theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "them ", "10.128.0.2/24", m{"relay": m{"use_relays": true}})
// Teach my how to get to the relay and that their can be reached via the relay // Teach my how to get to the relay and that their can be reached via the relay
myControl.InjectLightHouseAddr(relayVpnIpNet.Addr(), relayUdpAddr) myControl.InjectLightHouseAddr(relayVpnIpNet[0].Addr(), relayUdpAddr)
myControl.InjectRelays(theirVpnIpNet.Addr(), []netip.Addr{relayVpnIpNet.Addr()}) myControl.InjectRelays(theirVpnIpNet[0].Addr(), []netip.Addr{relayVpnIpNet[0].Addr()})
relayControl.InjectLightHouseAddr(theirVpnIpNet.Addr(), theirUdpAddr) relayControl.InjectLightHouseAddr(theirVpnIpNet[0].Addr(), theirUdpAddr)
// Build a router so we don't have to reason who gets which packet // Build a router so we don't have to reason who gets which packet
r := router.NewR(t, myControl, relayControl, theirControl) r := router.NewR(t, myControl, relayControl, theirControl)
@ -458,11 +459,11 @@ func TestRelays(t *testing.T) {
theirControl.Start() theirControl.Start()
t.Log("Trigger a handshake from me to them via the relay") t.Log("Trigger a handshake from me to them via the relay")
myControl.InjectTunUDPPacket(theirVpnIpNet.Addr(), 80, 80, []byte("Hi from me")) myControl.InjectTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi from me"))
p := r.RouteForAllUntilTxTun(theirControl) p := r.RouteForAllUntilTxTun(theirControl)
r.Log("Assert the tunnel works") r.Log("Assert the tunnel works")
assertUdpPacket(t, []byte("Hi from me"), p, myVpnIpNet.Addr(), theirVpnIpNet.Addr(), 80, 80) assertUdpPacket(t, []byte("Hi from me"), p, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), 80, 80)
r.RenderHostmaps("Final hostmaps", myControl, relayControl, theirControl) r.RenderHostmaps("Final hostmaps", myControl, relayControl, theirControl)
//TODO: assert we actually used the relay even though it should be impossible for a tunnel to have occurred without it //TODO: assert we actually used the relay even though it should be impossible for a tunnel to have occurred without it
} }
@ -470,19 +471,19 @@ func TestRelays(t *testing.T) {
func TestStage1RaceRelays(t *testing.T) { func TestStage1RaceRelays(t *testing.T) {
//NOTE: this is a race between me and relay resulting in a full tunnel from me to them via relay //NOTE: this is a race between me and relay resulting in a full tunnel from me to them via relay
ca, _, caKey, _ := NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{}) ca, _, caKey, _ := NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(ca, caKey, "me ", "10.128.0.1/24", m{"relay": m{"use_relays": true}}) myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "me ", "10.128.0.1/24", m{"relay": m{"use_relays": true}})
relayControl, relayVpnIpNet, relayUdpAddr, _ := newSimpleServer(ca, caKey, "relay ", "10.128.0.128/24", m{"relay": m{"am_relay": true}}) relayControl, relayVpnIpNet, relayUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "relay ", "10.128.0.128/24", m{"relay": m{"am_relay": true}})
theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(ca, caKey, "them ", "10.128.0.2/24", m{"relay": m{"use_relays": true}}) theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "them ", "10.128.0.2/24", m{"relay": m{"use_relays": true}})
// Teach my how to get to the relay and that their can be reached via the relay // Teach my how to get to the relay and that their can be reached via the relay
myControl.InjectLightHouseAddr(relayVpnIpNet.Addr(), relayUdpAddr) myControl.InjectLightHouseAddr(relayVpnIpNet[0].Addr(), relayUdpAddr)
theirControl.InjectLightHouseAddr(relayVpnIpNet.Addr(), relayUdpAddr) theirControl.InjectLightHouseAddr(relayVpnIpNet[0].Addr(), relayUdpAddr)
myControl.InjectRelays(theirVpnIpNet.Addr(), []netip.Addr{relayVpnIpNet.Addr()}) myControl.InjectRelays(theirVpnIpNet[0].Addr(), []netip.Addr{relayVpnIpNet[0].Addr()})
theirControl.InjectRelays(myVpnIpNet.Addr(), []netip.Addr{relayVpnIpNet.Addr()}) theirControl.InjectRelays(myVpnIpNet[0].Addr(), []netip.Addr{relayVpnIpNet[0].Addr()})
relayControl.InjectLightHouseAddr(theirVpnIpNet.Addr(), theirUdpAddr) relayControl.InjectLightHouseAddr(theirVpnIpNet[0].Addr(), theirUdpAddr)
relayControl.InjectLightHouseAddr(myVpnIpNet.Addr(), myUdpAddr) relayControl.InjectLightHouseAddr(myVpnIpNet[0].Addr(), myUdpAddr)
// Build a router so we don't have to reason who gets which packet // Build a router so we don't have to reason who gets which packet
r := router.NewR(t, myControl, relayControl, theirControl) r := router.NewR(t, myControl, relayControl, theirControl)
@ -494,14 +495,14 @@ func TestStage1RaceRelays(t *testing.T) {
theirControl.Start() theirControl.Start()
r.Log("Get a tunnel between me and relay") r.Log("Get a tunnel between me and relay")
assertTunnel(t, myVpnIpNet.Addr(), relayVpnIpNet.Addr(), myControl, relayControl, r) assertTunnel(t, myVpnIpNet[0].Addr(), relayVpnIpNet[0].Addr(), myControl, relayControl, r)
r.Log("Get a tunnel between them and relay") r.Log("Get a tunnel between them and relay")
assertTunnel(t, theirVpnIpNet.Addr(), relayVpnIpNet.Addr(), theirControl, relayControl, r) assertTunnel(t, theirVpnIpNet[0].Addr(), relayVpnIpNet[0].Addr(), theirControl, relayControl, r)
r.Log("Trigger a handshake from both them and me via relay to them and me") r.Log("Trigger a handshake from both them and me via relay to them and me")
myControl.InjectTunUDPPacket(theirVpnIpNet.Addr(), 80, 80, []byte("Hi from me")) myControl.InjectTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi from me"))
theirControl.InjectTunUDPPacket(myVpnIpNet.Addr(), 80, 80, []byte("Hi from them")) theirControl.InjectTunUDPPacket(myVpnIpNet[0].Addr(), 80, theirVpnIpNet[0].Addr(), 80, []byte("Hi from them"))
r.Log("Wait for a packet from them to me") r.Log("Wait for a packet from them to me")
p := r.RouteForAllUntilTxTun(myControl) p := r.RouteForAllUntilTxTun(myControl)
@ -519,20 +520,20 @@ func TestStage1RaceRelays(t *testing.T) {
func TestStage1RaceRelays2(t *testing.T) { func TestStage1RaceRelays2(t *testing.T) {
//NOTE: this is a race between me and relay resulting in a full tunnel from me to them via relay //NOTE: this is a race between me and relay resulting in a full tunnel from me to them via relay
ca, _, caKey, _ := NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{}) ca, _, caKey, _ := NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(ca, caKey, "me ", "10.128.0.1/24", m{"relay": m{"use_relays": true}}) myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "me ", "10.128.0.1/24", m{"relay": m{"use_relays": true}})
relayControl, relayVpnIpNet, relayUdpAddr, _ := newSimpleServer(ca, caKey, "relay ", "10.128.0.128/24", m{"relay": m{"am_relay": true}}) relayControl, relayVpnIpNet, relayUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "relay ", "10.128.0.128/24", m{"relay": m{"am_relay": true}})
theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(ca, caKey, "them ", "10.128.0.2/24", m{"relay": m{"use_relays": true}}) theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "them ", "10.128.0.2/24", m{"relay": m{"use_relays": true}})
l := NewTestLogger() l := NewTestLogger()
// Teach my how to get to the relay and that their can be reached via the relay // Teach my how to get to the relay and that their can be reached via the relay
myControl.InjectLightHouseAddr(relayVpnIpNet.Addr(), relayUdpAddr) myControl.InjectLightHouseAddr(relayVpnIpNet[0].Addr(), relayUdpAddr)
theirControl.InjectLightHouseAddr(relayVpnIpNet.Addr(), relayUdpAddr) theirControl.InjectLightHouseAddr(relayVpnIpNet[0].Addr(), relayUdpAddr)
myControl.InjectRelays(theirVpnIpNet.Addr(), []netip.Addr{relayVpnIpNet.Addr()}) myControl.InjectRelays(theirVpnIpNet[0].Addr(), []netip.Addr{relayVpnIpNet[0].Addr()})
theirControl.InjectRelays(myVpnIpNet.Addr(), []netip.Addr{relayVpnIpNet.Addr()}) theirControl.InjectRelays(myVpnIpNet[0].Addr(), []netip.Addr{relayVpnIpNet[0].Addr()})
relayControl.InjectLightHouseAddr(theirVpnIpNet.Addr(), theirUdpAddr) relayControl.InjectLightHouseAddr(theirVpnIpNet[0].Addr(), theirUdpAddr)
relayControl.InjectLightHouseAddr(myVpnIpNet.Addr(), myUdpAddr) relayControl.InjectLightHouseAddr(myVpnIpNet[0].Addr(), myUdpAddr)
// Build a router so we don't have to reason who gets which packet // Build a router so we don't have to reason who gets which packet
r := router.NewR(t, myControl, relayControl, theirControl) r := router.NewR(t, myControl, relayControl, theirControl)
@ -545,16 +546,16 @@ func TestStage1RaceRelays2(t *testing.T) {
r.Log("Get a tunnel between me and relay") r.Log("Get a tunnel between me and relay")
l.Info("Get a tunnel between me and relay") l.Info("Get a tunnel between me and relay")
assertTunnel(t, myVpnIpNet.Addr(), relayVpnIpNet.Addr(), myControl, relayControl, r) assertTunnel(t, myVpnIpNet[0].Addr(), relayVpnIpNet[0].Addr(), myControl, relayControl, r)
r.Log("Get a tunnel between them and relay") r.Log("Get a tunnel between them and relay")
l.Info("Get a tunnel between them and relay") l.Info("Get a tunnel between them and relay")
assertTunnel(t, theirVpnIpNet.Addr(), relayVpnIpNet.Addr(), theirControl, relayControl, r) assertTunnel(t, theirVpnIpNet[0].Addr(), relayVpnIpNet[0].Addr(), theirControl, relayControl, r)
r.Log("Trigger a handshake from both them and me via relay to them and me") r.Log("Trigger a handshake from both them and me via relay to them and me")
l.Info("Trigger a handshake from both them and me via relay to them and me") l.Info("Trigger a handshake from both them and me via relay to them and me")
myControl.InjectTunUDPPacket(theirVpnIpNet.Addr(), 80, 80, []byte("Hi from me")) myControl.InjectTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi from me"))
theirControl.InjectTunUDPPacket(myVpnIpNet.Addr(), 80, 80, []byte("Hi from them")) theirControl.InjectTunUDPPacket(myVpnIpNet[0].Addr(), 80, theirVpnIpNet[0].Addr(), 80, []byte("Hi from them"))
//r.RouteUntilAfterMsgType(myControl, header.Control, header.MessageNone) //r.RouteUntilAfterMsgType(myControl, header.Control, header.MessageNone)
//r.RouteUntilAfterMsgType(theirControl, header.Control, header.MessageNone) //r.RouteUntilAfterMsgType(theirControl, header.Control, header.MessageNone)
@ -567,7 +568,7 @@ func TestStage1RaceRelays2(t *testing.T) {
r.Log("Assert the tunnel works") r.Log("Assert the tunnel works")
l.Info("Assert the tunnel works") l.Info("Assert the tunnel works")
assertTunnel(t, theirVpnIpNet.Addr(), myVpnIpNet.Addr(), theirControl, myControl, r) assertTunnel(t, theirVpnIpNet[0].Addr(), myVpnIpNet[0].Addr(), theirControl, myControl, r)
t.Log("Wait until we remove extra tunnels") t.Log("Wait until we remove extra tunnels")
l.Info("Wait until we remove extra tunnels") l.Info("Wait until we remove extra tunnels")
@ -587,7 +588,7 @@ func TestStage1RaceRelays2(t *testing.T) {
"theirControl": len(theirControl.GetHostmap().Indexes), "theirControl": len(theirControl.GetHostmap().Indexes),
"relayControl": len(relayControl.GetHostmap().Indexes), "relayControl": len(relayControl.GetHostmap().Indexes),
}).Info("Waiting for hostinfos to be removed...") }).Info("Waiting for hostinfos to be removed...")
assertTunnel(t, myVpnIpNet.Addr(), theirVpnIpNet.Addr(), myControl, theirControl, r) assertTunnel(t, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), myControl, theirControl, r)
t.Log("Connection manager hasn't ticked yet") t.Log("Connection manager hasn't ticked yet")
time.Sleep(time.Second) time.Sleep(time.Second)
retries-- retries--
@ -595,7 +596,7 @@ func TestStage1RaceRelays2(t *testing.T) {
r.Log("Assert the tunnel works") r.Log("Assert the tunnel works")
l.Info("Assert the tunnel works") l.Info("Assert the tunnel works")
assertTunnel(t, theirVpnIpNet.Addr(), myVpnIpNet.Addr(), theirControl, myControl, r) assertTunnel(t, theirVpnIpNet[0].Addr(), myVpnIpNet[0].Addr(), theirControl, myControl, r)
myControl.Stop() myControl.Stop()
theirControl.Stop() theirControl.Stop()
@ -607,14 +608,14 @@ func TestStage1RaceRelays2(t *testing.T) {
func TestRehandshakingRelays(t *testing.T) { func TestRehandshakingRelays(t *testing.T) {
ca, _, caKey, _ := NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{}) ca, _, caKey, _ := NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
myControl, myVpnIpNet, _, _ := newSimpleServer(ca, caKey, "me ", "10.128.0.1/24", m{"relay": m{"use_relays": true}}) myControl, myVpnIpNet, _, _ := newSimpleServer(cert.Version1, ca, caKey, "me ", "10.128.0.1/24", m{"relay": m{"use_relays": true}})
relayControl, relayVpnIpNet, relayUdpAddr, relayConfig := newSimpleServer(ca, caKey, "relay ", "10.128.0.128/24", m{"relay": m{"am_relay": true}}) relayControl, relayVpnIpNet, relayUdpAddr, relayConfig := newSimpleServer(cert.Version1, ca, caKey, "relay ", "10.128.0.128/24", m{"relay": m{"am_relay": true}})
theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(ca, caKey, "them ", "10.128.0.2/24", m{"relay": m{"use_relays": true}}) theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "them ", "10.128.0.2/24", m{"relay": m{"use_relays": true}})
// Teach my how to get to the relay and that their can be reached via the relay // Teach my how to get to the relay and that their can be reached via the relay
myControl.InjectLightHouseAddr(relayVpnIpNet.Addr(), relayUdpAddr) myControl.InjectLightHouseAddr(relayVpnIpNet[0].Addr(), relayUdpAddr)
myControl.InjectRelays(theirVpnIpNet.Addr(), []netip.Addr{relayVpnIpNet.Addr()}) myControl.InjectRelays(theirVpnIpNet[0].Addr(), []netip.Addr{relayVpnIpNet[0].Addr()})
relayControl.InjectLightHouseAddr(theirVpnIpNet.Addr(), theirUdpAddr) relayControl.InjectLightHouseAddr(theirVpnIpNet[0].Addr(), theirUdpAddr)
// Build a router so we don't have to reason who gets which packet // Build a router so we don't have to reason who gets which packet
r := router.NewR(t, myControl, relayControl, theirControl) r := router.NewR(t, myControl, relayControl, theirControl)
@ -626,17 +627,17 @@ func TestRehandshakingRelays(t *testing.T) {
theirControl.Start() theirControl.Start()
t.Log("Trigger a handshake from me to them via the relay") t.Log("Trigger a handshake from me to them via the relay")
myControl.InjectTunUDPPacket(theirVpnIpNet.Addr(), 80, 80, []byte("Hi from me")) myControl.InjectTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi from me"))
p := r.RouteForAllUntilTxTun(theirControl) p := r.RouteForAllUntilTxTun(theirControl)
r.Log("Assert the tunnel works") r.Log("Assert the tunnel works")
assertUdpPacket(t, []byte("Hi from me"), p, myVpnIpNet.Addr(), theirVpnIpNet.Addr(), 80, 80) assertUdpPacket(t, []byte("Hi from me"), p, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), 80, 80)
r.RenderHostmaps("working hostmaps", myControl, relayControl, theirControl) r.RenderHostmaps("working hostmaps", myControl, relayControl, theirControl)
// When I update the certificate for the relay, both me and them will have 2 host infos for the relay, // When I update the certificate for the relay, both me and them will have 2 host infos for the relay,
// and the main host infos will not have any relay state to handle the me<->relay<->them tunnel. // and the main host infos will not have any relay state to handle the me<->relay<->them tunnel.
r.Log("Renew relay certificate and spin until me and them sees it") r.Log("Renew relay certificate and spin until me and them sees it")
_, _, myNextPrivKey, myNextPEM := NewTestCert(ca, caKey, "relay", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{relayVpnIpNet}, nil, []string{"new group"}) _, _, myNextPrivKey, myNextPEM := NewTestCert(cert.Version1, ca, caKey, "relay", time.Now(), time.Now().Add(5*time.Minute), relayVpnIpNet, nil, []string{"new group"})
caB, err := ca.MarshalPEM() caB, err := ca.MarshalPEM()
if err != nil { if err != nil {
@ -654,8 +655,8 @@ func TestRehandshakingRelays(t *testing.T) {
for { for {
r.Log("Assert the tunnel works between myVpnIpNet and relayVpnIpNet") r.Log("Assert the tunnel works between myVpnIpNet and relayVpnIpNet")
assertTunnel(t, myVpnIpNet.Addr(), relayVpnIpNet.Addr(), myControl, relayControl, r) assertTunnel(t, myVpnIpNet[0].Addr(), relayVpnIpNet[0].Addr(), myControl, relayControl, r)
c := myControl.GetHostInfoByVpnIp(relayVpnIpNet.Addr(), false) c := myControl.GetHostInfoByVpnAddr(relayVpnIpNet[0].Addr(), false)
if len(c.Cert.Groups()) != 0 { if len(c.Cert.Groups()) != 0 {
// We have a new certificate now // We have a new certificate now
r.Log("Certificate between my and relay is updated!") r.Log("Certificate between my and relay is updated!")
@ -667,8 +668,8 @@ func TestRehandshakingRelays(t *testing.T) {
for { for {
r.Log("Assert the tunnel works between theirVpnIpNet and relayVpnIpNet") r.Log("Assert the tunnel works between theirVpnIpNet and relayVpnIpNet")
assertTunnel(t, theirVpnIpNet.Addr(), relayVpnIpNet.Addr(), theirControl, relayControl, r) assertTunnel(t, theirVpnIpNet[0].Addr(), relayVpnIpNet[0].Addr(), theirControl, relayControl, r)
c := theirControl.GetHostInfoByVpnIp(relayVpnIpNet.Addr(), false) c := theirControl.GetHostInfoByVpnAddr(relayVpnIpNet[0].Addr(), false)
if len(c.Cert.Groups()) != 0 { if len(c.Cert.Groups()) != 0 {
// We have a new certificate now // We have a new certificate now
r.Log("Certificate between their and relay is updated!") r.Log("Certificate between their and relay is updated!")
@ -679,13 +680,13 @@ func TestRehandshakingRelays(t *testing.T) {
} }
r.Log("Assert the relay tunnel still works") r.Log("Assert the relay tunnel still works")
assertTunnel(t, theirVpnIpNet.Addr(), myVpnIpNet.Addr(), theirControl, myControl, r) assertTunnel(t, theirVpnIpNet[0].Addr(), myVpnIpNet[0].Addr(), theirControl, myControl, r)
r.RenderHostmaps("working hostmaps", myControl, relayControl, theirControl) r.RenderHostmaps("working hostmaps", myControl, relayControl, theirControl)
// We should have two hostinfos on all sides // We should have two hostinfos on all sides
for len(myControl.GetHostmap().Indexes) != 2 { for len(myControl.GetHostmap().Indexes) != 2 {
t.Logf("Waiting for myControl hostinfos (%v != 2) to get cleaned up from lack of use...", len(myControl.GetHostmap().Indexes)) t.Logf("Waiting for myControl hostinfos (%v != 2) to get cleaned up from lack of use...", len(myControl.GetHostmap().Indexes))
r.Log("Assert the relay tunnel still works") r.Log("Assert the relay tunnel still works")
assertTunnel(t, theirVpnIpNet.Addr(), myVpnIpNet.Addr(), theirControl, myControl, r) assertTunnel(t, theirVpnIpNet[0].Addr(), myVpnIpNet[0].Addr(), theirControl, myControl, r)
r.Log("yupitdoes") r.Log("yupitdoes")
time.Sleep(time.Second) time.Sleep(time.Second)
} }
@ -693,7 +694,7 @@ func TestRehandshakingRelays(t *testing.T) {
for len(theirControl.GetHostmap().Indexes) != 2 { for len(theirControl.GetHostmap().Indexes) != 2 {
t.Logf("Waiting for theirControl hostinfos (%v != 2) to get cleaned up from lack of use...", len(theirControl.GetHostmap().Indexes)) t.Logf("Waiting for theirControl hostinfos (%v != 2) to get cleaned up from lack of use...", len(theirControl.GetHostmap().Indexes))
r.Log("Assert the relay tunnel still works") r.Log("Assert the relay tunnel still works")
assertTunnel(t, theirVpnIpNet.Addr(), myVpnIpNet.Addr(), theirControl, myControl, r) assertTunnel(t, theirVpnIpNet[0].Addr(), myVpnIpNet[0].Addr(), theirControl, myControl, r)
r.Log("yupitdoes") r.Log("yupitdoes")
time.Sleep(time.Second) time.Sleep(time.Second)
} }
@ -701,7 +702,7 @@ func TestRehandshakingRelays(t *testing.T) {
for len(relayControl.GetHostmap().Indexes) != 2 { for len(relayControl.GetHostmap().Indexes) != 2 {
t.Logf("Waiting for relayControl hostinfos (%v != 2) to get cleaned up from lack of use...", len(relayControl.GetHostmap().Indexes)) t.Logf("Waiting for relayControl hostinfos (%v != 2) to get cleaned up from lack of use...", len(relayControl.GetHostmap().Indexes))
r.Log("Assert the relay tunnel still works") r.Log("Assert the relay tunnel still works")
assertTunnel(t, theirVpnIpNet.Addr(), myVpnIpNet.Addr(), theirControl, myControl, r) assertTunnel(t, theirVpnIpNet[0].Addr(), myVpnIpNet[0].Addr(), theirControl, myControl, r)
r.Log("yupitdoes") r.Log("yupitdoes")
time.Sleep(time.Second) time.Sleep(time.Second)
} }
@ -711,14 +712,14 @@ func TestRehandshakingRelays(t *testing.T) {
func TestRehandshakingRelaysPrimary(t *testing.T) { func TestRehandshakingRelaysPrimary(t *testing.T) {
// This test is the same as TestRehandshakingRelays but one of the terminal types is a primary swap winner // This test is the same as TestRehandshakingRelays but one of the terminal types is a primary swap winner
ca, _, caKey, _ := NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{}) ca, _, caKey, _ := NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
myControl, myVpnIpNet, _, _ := newSimpleServer(ca, caKey, "me ", "10.128.0.128/24", m{"relay": m{"use_relays": true}}) myControl, myVpnIpNet, _, _ := newSimpleServer(cert.Version1, ca, caKey, "me ", "10.128.0.128/24", m{"relay": m{"use_relays": true}})
relayControl, relayVpnIpNet, relayUdpAddr, relayConfig := newSimpleServer(ca, caKey, "relay ", "10.128.0.1/24", m{"relay": m{"am_relay": true}}) relayControl, relayVpnIpNet, relayUdpAddr, relayConfig := newSimpleServer(cert.Version1, ca, caKey, "relay ", "10.128.0.1/24", m{"relay": m{"am_relay": true}})
theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(ca, caKey, "them ", "10.128.0.2/24", m{"relay": m{"use_relays": true}}) theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "them ", "10.128.0.2/24", m{"relay": m{"use_relays": true}})
// Teach my how to get to the relay and that their can be reached via the relay // Teach my how to get to the relay and that their can be reached via the relay
myControl.InjectLightHouseAddr(relayVpnIpNet.Addr(), relayUdpAddr) myControl.InjectLightHouseAddr(relayVpnIpNet[0].Addr(), relayUdpAddr)
myControl.InjectRelays(theirVpnIpNet.Addr(), []netip.Addr{relayVpnIpNet.Addr()}) myControl.InjectRelays(theirVpnIpNet[0].Addr(), []netip.Addr{relayVpnIpNet[0].Addr()})
relayControl.InjectLightHouseAddr(theirVpnIpNet.Addr(), theirUdpAddr) relayControl.InjectLightHouseAddr(theirVpnIpNet[0].Addr(), theirUdpAddr)
// Build a router so we don't have to reason who gets which packet // Build a router so we don't have to reason who gets which packet
r := router.NewR(t, myControl, relayControl, theirControl) r := router.NewR(t, myControl, relayControl, theirControl)
@ -730,17 +731,17 @@ func TestRehandshakingRelaysPrimary(t *testing.T) {
theirControl.Start() theirControl.Start()
t.Log("Trigger a handshake from me to them via the relay") t.Log("Trigger a handshake from me to them via the relay")
myControl.InjectTunUDPPacket(theirVpnIpNet.Addr(), 80, 80, []byte("Hi from me")) myControl.InjectTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi from me"))
p := r.RouteForAllUntilTxTun(theirControl) p := r.RouteForAllUntilTxTun(theirControl)
r.Log("Assert the tunnel works") r.Log("Assert the tunnel works")
assertUdpPacket(t, []byte("Hi from me"), p, myVpnIpNet.Addr(), theirVpnIpNet.Addr(), 80, 80) assertUdpPacket(t, []byte("Hi from me"), p, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), 80, 80)
r.RenderHostmaps("working hostmaps", myControl, relayControl, theirControl) r.RenderHostmaps("working hostmaps", myControl, relayControl, theirControl)
// When I update the certificate for the relay, both me and them will have 2 host infos for the relay, // When I update the certificate for the relay, both me and them will have 2 host infos for the relay,
// and the main host infos will not have any relay state to handle the me<->relay<->them tunnel. // and the main host infos will not have any relay state to handle the me<->relay<->them tunnel.
r.Log("Renew relay certificate and spin until me and them sees it") r.Log("Renew relay certificate and spin until me and them sees it")
_, _, myNextPrivKey, myNextPEM := NewTestCert(ca, caKey, "relay", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{relayVpnIpNet}, nil, []string{"new group"}) _, _, myNextPrivKey, myNextPEM := NewTestCert(cert.Version1, ca, caKey, "relay", time.Now(), time.Now().Add(5*time.Minute), relayVpnIpNet, nil, []string{"new group"})
caB, err := ca.MarshalPEM() caB, err := ca.MarshalPEM()
if err != nil { if err != nil {
@ -758,8 +759,8 @@ func TestRehandshakingRelaysPrimary(t *testing.T) {
for { for {
r.Log("Assert the tunnel works between myVpnIpNet and relayVpnIpNet") r.Log("Assert the tunnel works between myVpnIpNet and relayVpnIpNet")
assertTunnel(t, myVpnIpNet.Addr(), relayVpnIpNet.Addr(), myControl, relayControl, r) assertTunnel(t, myVpnIpNet[0].Addr(), relayVpnIpNet[0].Addr(), myControl, relayControl, r)
c := myControl.GetHostInfoByVpnIp(relayVpnIpNet.Addr(), false) c := myControl.GetHostInfoByVpnAddr(relayVpnIpNet[0].Addr(), false)
if len(c.Cert.Groups()) != 0 { if len(c.Cert.Groups()) != 0 {
// We have a new certificate now // We have a new certificate now
r.Log("Certificate between my and relay is updated!") r.Log("Certificate between my and relay is updated!")
@ -771,8 +772,8 @@ func TestRehandshakingRelaysPrimary(t *testing.T) {
for { for {
r.Log("Assert the tunnel works between theirVpnIpNet and relayVpnIpNet") r.Log("Assert the tunnel works between theirVpnIpNet and relayVpnIpNet")
assertTunnel(t, theirVpnIpNet.Addr(), relayVpnIpNet.Addr(), theirControl, relayControl, r) assertTunnel(t, theirVpnIpNet[0].Addr(), relayVpnIpNet[0].Addr(), theirControl, relayControl, r)
c := theirControl.GetHostInfoByVpnIp(relayVpnIpNet.Addr(), false) c := theirControl.GetHostInfoByVpnAddr(relayVpnIpNet[0].Addr(), false)
if len(c.Cert.Groups()) != 0 { if len(c.Cert.Groups()) != 0 {
// We have a new certificate now // We have a new certificate now
r.Log("Certificate between their and relay is updated!") r.Log("Certificate between their and relay is updated!")
@ -783,13 +784,13 @@ func TestRehandshakingRelaysPrimary(t *testing.T) {
} }
r.Log("Assert the relay tunnel still works") r.Log("Assert the relay tunnel still works")
assertTunnel(t, theirVpnIpNet.Addr(), myVpnIpNet.Addr(), theirControl, myControl, r) assertTunnel(t, theirVpnIpNet[0].Addr(), myVpnIpNet[0].Addr(), theirControl, myControl, r)
r.RenderHostmaps("working hostmaps", myControl, relayControl, theirControl) r.RenderHostmaps("working hostmaps", myControl, relayControl, theirControl)
// We should have two hostinfos on all sides // We should have two hostinfos on all sides
for len(myControl.GetHostmap().Indexes) != 2 { for len(myControl.GetHostmap().Indexes) != 2 {
t.Logf("Waiting for myControl hostinfos (%v != 2) to get cleaned up from lack of use...", len(myControl.GetHostmap().Indexes)) t.Logf("Waiting for myControl hostinfos (%v != 2) to get cleaned up from lack of use...", len(myControl.GetHostmap().Indexes))
r.Log("Assert the relay tunnel still works") r.Log("Assert the relay tunnel still works")
assertTunnel(t, theirVpnIpNet.Addr(), myVpnIpNet.Addr(), theirControl, myControl, r) assertTunnel(t, theirVpnIpNet[0].Addr(), myVpnIpNet[0].Addr(), theirControl, myControl, r)
r.Log("yupitdoes") r.Log("yupitdoes")
time.Sleep(time.Second) time.Sleep(time.Second)
} }
@ -797,7 +798,7 @@ func TestRehandshakingRelaysPrimary(t *testing.T) {
for len(theirControl.GetHostmap().Indexes) != 2 { for len(theirControl.GetHostmap().Indexes) != 2 {
t.Logf("Waiting for theirControl hostinfos (%v != 2) to get cleaned up from lack of use...", len(theirControl.GetHostmap().Indexes)) t.Logf("Waiting for theirControl hostinfos (%v != 2) to get cleaned up from lack of use...", len(theirControl.GetHostmap().Indexes))
r.Log("Assert the relay tunnel still works") r.Log("Assert the relay tunnel still works")
assertTunnel(t, theirVpnIpNet.Addr(), myVpnIpNet.Addr(), theirControl, myControl, r) assertTunnel(t, theirVpnIpNet[0].Addr(), myVpnIpNet[0].Addr(), theirControl, myControl, r)
r.Log("yupitdoes") r.Log("yupitdoes")
time.Sleep(time.Second) time.Sleep(time.Second)
} }
@ -805,7 +806,7 @@ func TestRehandshakingRelaysPrimary(t *testing.T) {
for len(relayControl.GetHostmap().Indexes) != 2 { for len(relayControl.GetHostmap().Indexes) != 2 {
t.Logf("Waiting for relayControl hostinfos (%v != 2) to get cleaned up from lack of use...", len(relayControl.GetHostmap().Indexes)) t.Logf("Waiting for relayControl hostinfos (%v != 2) to get cleaned up from lack of use...", len(relayControl.GetHostmap().Indexes))
r.Log("Assert the relay tunnel still works") r.Log("Assert the relay tunnel still works")
assertTunnel(t, theirVpnIpNet.Addr(), myVpnIpNet.Addr(), theirControl, myControl, r) assertTunnel(t, theirVpnIpNet[0].Addr(), myVpnIpNet[0].Addr(), theirControl, myControl, r)
r.Log("yupitdoes") r.Log("yupitdoes")
time.Sleep(time.Second) time.Sleep(time.Second)
} }
@ -814,12 +815,12 @@ func TestRehandshakingRelaysPrimary(t *testing.T) {
func TestRehandshaking(t *testing.T) { func TestRehandshaking(t *testing.T) {
ca, _, caKey, _ := NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{}) ca, _, caKey, _ := NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
myControl, myVpnIpNet, myUdpAddr, myConfig := newSimpleServer(ca, caKey, "me ", "10.128.0.2/24", nil) myControl, myVpnIpNet, myUdpAddr, myConfig := newSimpleServer(cert.Version1, ca, caKey, "me ", "10.128.0.2/24", nil)
theirControl, theirVpnIpNet, theirUdpAddr, theirConfig := newSimpleServer(ca, caKey, "them", "10.128.0.1/24", nil) theirControl, theirVpnIpNet, theirUdpAddr, theirConfig := newSimpleServer(cert.Version1, ca, caKey, "them", "10.128.0.1/24", nil)
// Put their info in our lighthouse and vice versa // Put their info in our lighthouse and vice versa
myControl.InjectLightHouseAddr(theirVpnIpNet.Addr(), theirUdpAddr) myControl.InjectLightHouseAddr(theirVpnIpNet[0].Addr(), theirUdpAddr)
theirControl.InjectLightHouseAddr(myVpnIpNet.Addr(), myUdpAddr) theirControl.InjectLightHouseAddr(myVpnIpNet[0].Addr(), myUdpAddr)
// Build a router so we don't have to reason who gets which packet // Build a router so we don't have to reason who gets which packet
r := router.NewR(t, myControl, theirControl) r := router.NewR(t, myControl, theirControl)
@ -830,12 +831,12 @@ func TestRehandshaking(t *testing.T) {
theirControl.Start() theirControl.Start()
t.Log("Stand up a tunnel between me and them") t.Log("Stand up a tunnel between me and them")
assertTunnel(t, myVpnIpNet.Addr(), theirVpnIpNet.Addr(), myControl, theirControl, r) assertTunnel(t, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), myControl, theirControl, r)
r.RenderHostmaps("Starting hostmaps", myControl, theirControl) r.RenderHostmaps("Starting hostmaps", myControl, theirControl)
r.Log("Renew my certificate and spin until their sees it") r.Log("Renew my certificate and spin until their sees it")
_, _, myNextPrivKey, myNextPEM := NewTestCert(ca, caKey, "me", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{myVpnIpNet}, nil, []string{"new group"}) _, _, myNextPrivKey, myNextPEM := NewTestCert(cert.Version1, ca, caKey, "me", time.Now(), time.Now().Add(5*time.Minute), myVpnIpNet, nil, []string{"new group"})
caB, err := ca.MarshalPEM() caB, err := ca.MarshalPEM()
if err != nil { if err != nil {
@ -852,8 +853,8 @@ func TestRehandshaking(t *testing.T) {
myConfig.ReloadConfigString(string(rc)) myConfig.ReloadConfigString(string(rc))
for { for {
assertTunnel(t, myVpnIpNet.Addr(), theirVpnIpNet.Addr(), myControl, theirControl, r) assertTunnel(t, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), myControl, theirControl, r)
c := theirControl.GetHostInfoByVpnIp(myVpnIpNet.Addr(), false) c := theirControl.GetHostInfoByVpnAddr(myVpnIpNet[0].Addr(), false)
if len(c.Cert.Groups()) != 0 { if len(c.Cert.Groups()) != 0 {
// We have a new certificate now // We have a new certificate now
break break
@ -880,19 +881,19 @@ func TestRehandshaking(t *testing.T) {
r.Log("Spin until there is only 1 tunnel") r.Log("Spin until there is only 1 tunnel")
for len(myControl.GetHostmap().Indexes)+len(theirControl.GetHostmap().Indexes) > 2 { for len(myControl.GetHostmap().Indexes)+len(theirControl.GetHostmap().Indexes) > 2 {
assertTunnel(t, myVpnIpNet.Addr(), theirVpnIpNet.Addr(), myControl, theirControl, r) assertTunnel(t, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), myControl, theirControl, r)
t.Log("Connection manager hasn't ticked yet") t.Log("Connection manager hasn't ticked yet")
time.Sleep(time.Second) time.Sleep(time.Second)
} }
assertTunnel(t, myVpnIpNet.Addr(), theirVpnIpNet.Addr(), myControl, theirControl, r) assertTunnel(t, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), myControl, theirControl, r)
myFinalHostmapHosts := myControl.ListHostmapHosts(false) myFinalHostmapHosts := myControl.ListHostmapHosts(false)
myFinalHostmapIndexes := myControl.ListHostmapIndexes(false) myFinalHostmapIndexes := myControl.ListHostmapIndexes(false)
theirFinalHostmapHosts := theirControl.ListHostmapHosts(false) theirFinalHostmapHosts := theirControl.ListHostmapHosts(false)
theirFinalHostmapIndexes := theirControl.ListHostmapIndexes(false) theirFinalHostmapIndexes := theirControl.ListHostmapIndexes(false)
// Make sure the correct tunnel won // Make sure the correct tunnel won
c := theirControl.GetHostInfoByVpnIp(myVpnIpNet.Addr(), false) c := theirControl.GetHostInfoByVpnAddr(myVpnIpNet[0].Addr(), false)
assert.Contains(t, c.Cert.Groups(), "new group") assert.Contains(t, c.Cert.Groups(), "new group")
// We should only have a single tunnel now on both sides // We should only have a single tunnel now on both sides
@ -911,12 +912,12 @@ func TestRehandshakingLoser(t *testing.T) {
// The purpose of this test is that the race loser renews their certificate and rehandshakes. The final tunnel // The purpose of this test is that the race loser renews their certificate and rehandshakes. The final tunnel
// Should be the one with the new certificate // Should be the one with the new certificate
ca, _, caKey, _ := NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{}) ca, _, caKey, _ := NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
myControl, myVpnIpNet, myUdpAddr, myConfig := newSimpleServer(ca, caKey, "me ", "10.128.0.2/24", nil) myControl, myVpnIpNet, myUdpAddr, myConfig := newSimpleServer(cert.Version1, ca, caKey, "me ", "10.128.0.2/24", nil)
theirControl, theirVpnIpNet, theirUdpAddr, theirConfig := newSimpleServer(ca, caKey, "them", "10.128.0.1/24", nil) theirControl, theirVpnIpNet, theirUdpAddr, theirConfig := newSimpleServer(cert.Version1, ca, caKey, "them", "10.128.0.1/24", nil)
// Put their info in our lighthouse and vice versa // Put their info in our lighthouse and vice versa
myControl.InjectLightHouseAddr(theirVpnIpNet.Addr(), theirUdpAddr) myControl.InjectLightHouseAddr(theirVpnIpNet[0].Addr(), theirUdpAddr)
theirControl.InjectLightHouseAddr(myVpnIpNet.Addr(), myUdpAddr) theirControl.InjectLightHouseAddr(myVpnIpNet[0].Addr(), myUdpAddr)
// Build a router so we don't have to reason who gets which packet // Build a router so we don't have to reason who gets which packet
r := router.NewR(t, myControl, theirControl) r := router.NewR(t, myControl, theirControl)
@ -927,16 +928,12 @@ func TestRehandshakingLoser(t *testing.T) {
theirControl.Start() theirControl.Start()
t.Log("Stand up a tunnel between me and them") t.Log("Stand up a tunnel between me and them")
assertTunnel(t, myVpnIpNet.Addr(), theirVpnIpNet.Addr(), myControl, theirControl, r) assertTunnel(t, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), myControl, theirControl, r)
tt1 := myControl.GetHostInfoByVpnIp(theirVpnIpNet.Addr(), false)
tt2 := theirControl.GetHostInfoByVpnIp(myVpnIpNet.Addr(), false)
fmt.Println(tt1.LocalIndex, tt2.LocalIndex)
r.RenderHostmaps("Starting hostmaps", myControl, theirControl) r.RenderHostmaps("Starting hostmaps", myControl, theirControl)
r.Log("Renew their certificate and spin until mine sees it") r.Log("Renew their certificate and spin until mine sees it")
_, _, theirNextPrivKey, theirNextPEM := NewTestCert(ca, caKey, "them", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{theirVpnIpNet}, nil, []string{"their new group"}) _, _, theirNextPrivKey, theirNextPEM := NewTestCert(cert.Version1, ca, caKey, "them", time.Now(), time.Now().Add(5*time.Minute), theirVpnIpNet, nil, []string{"their new group"})
caB, err := ca.MarshalPEM() caB, err := ca.MarshalPEM()
if err != nil { if err != nil {
@ -953,8 +950,8 @@ func TestRehandshakingLoser(t *testing.T) {
theirConfig.ReloadConfigString(string(rc)) theirConfig.ReloadConfigString(string(rc))
for { for {
assertTunnel(t, myVpnIpNet.Addr(), theirVpnIpNet.Addr(), myControl, theirControl, r) assertTunnel(t, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), myControl, theirControl, r)
theirCertInMe := myControl.GetHostInfoByVpnIp(theirVpnIpNet.Addr(), false) theirCertInMe := myControl.GetHostInfoByVpnAddr(theirVpnIpNet[0].Addr(), false)
if slices.Contains(theirCertInMe.Cert.Groups(), "their new group") { if slices.Contains(theirCertInMe.Cert.Groups(), "their new group") {
break break
@ -980,19 +977,19 @@ func TestRehandshakingLoser(t *testing.T) {
r.Log("Spin until there is only 1 tunnel") r.Log("Spin until there is only 1 tunnel")
for len(myControl.GetHostmap().Indexes)+len(theirControl.GetHostmap().Indexes) > 2 { for len(myControl.GetHostmap().Indexes)+len(theirControl.GetHostmap().Indexes) > 2 {
assertTunnel(t, myVpnIpNet.Addr(), theirVpnIpNet.Addr(), myControl, theirControl, r) assertTunnel(t, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), myControl, theirControl, r)
t.Log("Connection manager hasn't ticked yet") t.Log("Connection manager hasn't ticked yet")
time.Sleep(time.Second) time.Sleep(time.Second)
} }
assertTunnel(t, myVpnIpNet.Addr(), theirVpnIpNet.Addr(), myControl, theirControl, r) assertTunnel(t, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), myControl, theirControl, r)
myFinalHostmapHosts := myControl.ListHostmapHosts(false) myFinalHostmapHosts := myControl.ListHostmapHosts(false)
myFinalHostmapIndexes := myControl.ListHostmapIndexes(false) myFinalHostmapIndexes := myControl.ListHostmapIndexes(false)
theirFinalHostmapHosts := theirControl.ListHostmapHosts(false) theirFinalHostmapHosts := theirControl.ListHostmapHosts(false)
theirFinalHostmapIndexes := theirControl.ListHostmapIndexes(false) theirFinalHostmapIndexes := theirControl.ListHostmapIndexes(false)
// Make sure the correct tunnel won // Make sure the correct tunnel won
theirCertInMe := myControl.GetHostInfoByVpnIp(theirVpnIpNet.Addr(), false) theirCertInMe := myControl.GetHostInfoByVpnAddr(theirVpnIpNet[0].Addr(), false)
assert.Contains(t, theirCertInMe.Cert.Groups(), "their new group") assert.Contains(t, theirCertInMe.Cert.Groups(), "their new group")
// We should only have a single tunnel now on both sides // We should only have a single tunnel now on both sides
@ -1011,12 +1008,12 @@ func TestRaceRegression(t *testing.T) {
// We had a bug where we were not finding the duplicate handshake and responding to the final stage 1 which // We had a bug where we were not finding the duplicate handshake and responding to the final stage 1 which
// caused a cross-linked hostinfo // caused a cross-linked hostinfo
ca, _, caKey, _ := NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{}) ca, _, caKey, _ := NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(ca, caKey, "me", "10.128.0.1/24", nil) myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "me", "10.128.0.1/24", nil)
theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(ca, caKey, "them", "10.128.0.2/24", nil) theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "them", "10.128.0.2/24", nil)
// Put their info in our lighthouse // Put their info in our lighthouse
myControl.InjectLightHouseAddr(theirVpnIpNet.Addr(), theirUdpAddr) myControl.InjectLightHouseAddr(theirVpnIpNet[0].Addr(), theirUdpAddr)
theirControl.InjectLightHouseAddr(myVpnIpNet.Addr(), myUdpAddr) theirControl.InjectLightHouseAddr(myVpnIpNet[0].Addr(), myUdpAddr)
// Start the servers // Start the servers
myControl.Start() myControl.Start()
@ -1030,8 +1027,8 @@ func TestRaceRegression(t *testing.T) {
//them rx stage:2 initiatorIndex=120607833 responderIndex=4209862089 //them rx stage:2 initiatorIndex=120607833 responderIndex=4209862089
t.Log("Start both handshakes") t.Log("Start both handshakes")
myControl.InjectTunUDPPacket(theirVpnIpNet.Addr(), 80, 80, []byte("Hi from me")) myControl.InjectTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi from me"))
theirControl.InjectTunUDPPacket(myVpnIpNet.Addr(), 80, 80, []byte("Hi from them")) theirControl.InjectTunUDPPacket(myVpnIpNet[0].Addr(), 80, theirVpnIpNet[0].Addr(), 80, []byte("Hi from them"))
t.Log("Get both stage 1") t.Log("Get both stage 1")
myStage1ForThem := myControl.GetFromUDP(true) myStage1ForThem := myControl.GetFromUDP(true)
@ -1061,12 +1058,52 @@ func TestRaceRegression(t *testing.T) {
r.RenderHostmaps("Starting hostmaps", myControl, theirControl) r.RenderHostmaps("Starting hostmaps", myControl, theirControl)
t.Log("Make sure the tunnel still works") t.Log("Make sure the tunnel still works")
assertTunnel(t, myVpnIpNet.Addr(), theirVpnIpNet.Addr(), myControl, theirControl, r) assertTunnel(t, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), myControl, theirControl, r)
myControl.Stop() myControl.Stop()
theirControl.Stop() theirControl.Stop()
} }
func TestV2NonPrimaryWithLighthouse(t *testing.T) {
ca, _, caKey, _ := NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
lhControl, lhVpnIpNet, lhUdpAddr, _ := newSimpleServer(cert.Version2, ca, caKey, "lh ", "10.128.0.1/24, ff::1/64", m{"lighthouse": m{"am_lighthouse": true}})
o := m{
"static_host_map": m{
lhVpnIpNet[1].Addr().String(): []string{lhUdpAddr.String()},
},
"lighthouse": m{
"hosts": []string{lhVpnIpNet[1].Addr().String()},
"local_allow_list": m{
// Try and block our lighthouse updates from using the actual addresses assigned to this computer
// If we start discovering addresses the test router doesn't know about then test traffic cant flow
"10.0.0.0/24": true,
"::/0": false,
},
},
}
myControl, myVpnIpNet, _, _ := newSimpleServer(cert.Version2, ca, caKey, "me ", "10.128.0.2/24, ff::2/64", o)
theirControl, theirVpnIpNet, _, _ := newSimpleServer(cert.Version2, ca, caKey, "them", "10.128.0.3/24, ff::3/64", o)
// Build a router so we don't have to reason who gets which packet
r := router.NewR(t, lhControl, myControl, theirControl)
defer r.RenderFlow()
// Start the servers
lhControl.Start()
myControl.Start()
theirControl.Start()
t.Log("Stand up an ipv6 tunnel between me and them")
assert.True(t, myVpnIpNet[1].Addr().Is6())
assert.True(t, theirVpnIpNet[1].Addr().Is6())
assertTunnel(t, myVpnIpNet[1].Addr(), theirVpnIpNet[1].Addr(), myControl, theirControl, r)
lhControl.Stop()
myControl.Stop()
theirControl.Stop()
}
//TODO: test //TODO: test
// Race winner renews and handshakes // Race winner renews and handshakes
// Race loser renews and handshakes // Race loser renews and handshakes

View File

@ -48,7 +48,7 @@ func NewTestCaCert(before, after time.Time, networks, unsafeNetworks []netip.Pre
// NewTestCert will generate a signed certificate with the provided details. // NewTestCert will generate a signed certificate with the provided details.
// Expiry times are defaulted if you do not pass them in // Expiry times are defaulted if you do not pass them in
func NewTestCert(ca cert.Certificate, key []byte, name string, before, after time.Time, networks, unsafeNetworks []netip.Prefix, groups []string) (cert.Certificate, []byte, []byte, []byte) { func NewTestCert(v cert.Version, ca cert.Certificate, key []byte, name string, before, after time.Time, networks, unsafeNetworks []netip.Prefix, groups []string) (cert.Certificate, []byte, []byte, []byte) {
if before.IsZero() { if before.IsZero() {
before = time.Now().Add(time.Second * -60).Round(time.Second) before = time.Now().Add(time.Second * -60).Round(time.Second)
} }
@ -59,7 +59,7 @@ func NewTestCert(ca cert.Certificate, key []byte, name string, before, after tim
pub, rawPriv := x25519Keypair() pub, rawPriv := x25519Keypair()
nc := &cert.TBSCertificate{ nc := &cert.TBSCertificate{
Version: cert.Version1, Version: v,
Name: name, Name: name,
Networks: networks, Networks: networks,
UnsafeNetworks: unsafeNetworks, UnsafeNetworks: unsafeNetworks,

View File

@ -8,6 +8,7 @@ import (
"io" "io"
"net/netip" "net/netip"
"os" "os"
"strings"
"testing" "testing"
"time" "time"
@ -26,25 +27,35 @@ import (
type m map[string]interface{} type m map[string]interface{}
// newSimpleServer creates a nebula instance with many assumptions // newSimpleServer creates a nebula instance with many assumptions
func newSimpleServer(caCrt cert.Certificate, caKey []byte, name string, sVpnIpNet string, overrides m) (*nebula.Control, netip.Prefix, netip.AddrPort, *config.C) { func newSimpleServer(v cert.Version, caCrt cert.Certificate, caKey []byte, name string, sVpnNetworks string, overrides m) (*nebula.Control, []netip.Prefix, netip.AddrPort, *config.C) {
l := NewTestLogger() l := NewTestLogger()
vpnIpNet, err := netip.ParsePrefix(sVpnIpNet) var vpnNetworks []netip.Prefix
for _, sn := range strings.Split(sVpnNetworks, ",") {
vpnIpNet, err := netip.ParsePrefix(strings.TrimSpace(sn))
if err != nil { if err != nil {
panic(err) panic(err)
} }
vpnNetworks = append(vpnNetworks, vpnIpNet)
}
if len(vpnNetworks) == 0 {
panic("no vpn networks")
}
var udpAddr netip.AddrPort var udpAddr netip.AddrPort
if vpnIpNet.Addr().Is4() { if vpnNetworks[0].Addr().Is4() {
budpIp := vpnIpNet.Addr().As4() budpIp := vpnNetworks[0].Addr().As4()
budpIp[1] -= 128 budpIp[1] -= 128
udpAddr = netip.AddrPortFrom(netip.AddrFrom4(budpIp), 4242) udpAddr = netip.AddrPortFrom(netip.AddrFrom4(budpIp), 4242)
} else { } else {
budpIp := vpnIpNet.Addr().As16() budpIp := vpnNetworks[0].Addr().As16()
budpIp[13] -= 128 // beef for funsies
budpIp[2] = 190
budpIp[3] = 239
udpAddr = netip.AddrPortFrom(netip.AddrFrom16(budpIp), 4242) udpAddr = netip.AddrPortFrom(netip.AddrFrom16(budpIp), 4242)
} }
_, _, myPrivKey, myPEM := NewTestCert(caCrt, caKey, name, time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{vpnIpNet}, nil, []string{}) _, _, myPrivKey, myPEM := NewTestCert(v, caCrt, caKey, name, time.Now(), time.Now().Add(5*time.Minute), vpnNetworks, nil, []string{})
caB, err := caCrt.MarshalPEM() caB, err := caCrt.MarshalPEM()
if err != nil { if err != nil {
@ -88,11 +99,16 @@ func newSimpleServer(caCrt cert.Certificate, caKey []byte, name string, sVpnIpNe
} }
if overrides != nil { if overrides != nil {
err = mergo.Merge(&overrides, mc, mergo.WithAppendSlice) final := m{}
err = mergo.Merge(&final, overrides, mergo.WithAppendSlice)
if err != nil { if err != nil {
panic(err) panic(err)
} }
mc = overrides err = mergo.Merge(&final, mc, mergo.WithAppendSlice)
if err != nil {
panic(err)
}
mc = final
} }
cb, err := yaml.Marshal(mc) cb, err := yaml.Marshal(mc)
@ -109,7 +125,7 @@ func newSimpleServer(caCrt cert.Certificate, caKey []byte, name string, sVpnIpNe
panic(err) panic(err)
} }
return control, vpnIpNet, udpAddr, c return control, vpnNetworks, udpAddr, c
} }
type doneCb func() type doneCb func()
@ -132,27 +148,28 @@ func deadline(t *testing.T, seconds time.Duration) doneCb {
func assertTunnel(t *testing.T, vpnIpA, vpnIpB netip.Addr, controlA, controlB *nebula.Control, r *router.R) { func assertTunnel(t *testing.T, vpnIpA, vpnIpB netip.Addr, controlA, controlB *nebula.Control, r *router.R) {
// Send a packet from them to me // Send a packet from them to me
controlB.InjectTunUDPPacket(vpnIpA, 80, 90, []byte("Hi from B")) controlB.InjectTunUDPPacket(vpnIpA, 80, vpnIpB, 90, []byte("Hi from B"))
bPacket := r.RouteForAllUntilTxTun(controlA) bPacket := r.RouteForAllUntilTxTun(controlA)
assertUdpPacket(t, []byte("Hi from B"), bPacket, vpnIpB, vpnIpA, 90, 80) assertUdpPacket(t, []byte("Hi from B"), bPacket, vpnIpB, vpnIpA, 90, 80)
// And once more from me to them // And once more from me to them
controlA.InjectTunUDPPacket(vpnIpB, 80, 90, []byte("Hello from A")) controlA.InjectTunUDPPacket(vpnIpB, 80, vpnIpA, 90, []byte("Hello from A"))
aPacket := r.RouteForAllUntilTxTun(controlB) aPacket := r.RouteForAllUntilTxTun(controlB)
assertUdpPacket(t, []byte("Hello from A"), aPacket, vpnIpA, vpnIpB, 90, 80) assertUdpPacket(t, []byte("Hello from A"), aPacket, vpnIpA, vpnIpB, 90, 80)
} }
func assertHostInfoPair(t *testing.T, addrA, addrB netip.AddrPort, vpnIpA, vpnIpB netip.Addr, controlA, controlB *nebula.Control) { func assertHostInfoPair(t *testing.T, addrA, addrB netip.AddrPort, vpnNetsA, vpnNetsB []netip.Prefix, controlA, controlB *nebula.Control) {
// Get both host infos // Get both host infos
hBinA := controlA.GetHostInfoByVpnIp(vpnIpB, false) //TODO: we may want to loop over each vpnAddr and assert all the things
assert.NotNil(t, hBinA, "Host B was not found by vpnIp in controlA") hBinA := controlA.GetHostInfoByVpnAddr(vpnNetsB[0].Addr(), false)
assert.NotNil(t, hBinA, "Host B was not found by vpnAddr in controlA")
hAinB := controlB.GetHostInfoByVpnIp(vpnIpA, false) hAinB := controlB.GetHostInfoByVpnAddr(vpnNetsA[0].Addr(), false)
assert.NotNil(t, hAinB, "Host A was not found by vpnIp in controlB") assert.NotNil(t, hAinB, "Host A was not found by vpnAddr in controlB")
// Check that both vpn and real addr are correct // Check that both vpn and real addr are correct
assert.Equal(t, vpnIpB, hBinA.VpnIp, "Host B VpnIp is wrong in control A") assert.EqualValues(t, getAddrs(vpnNetsB), hBinA.VpnAddrs, "Host B VpnIp is wrong in control A")
assert.Equal(t, vpnIpA, hAinB.VpnIp, "Host A VpnIp is wrong in control B") assert.EqualValues(t, getAddrs(vpnNetsA), hAinB.VpnAddrs, "Host A VpnIp is wrong in control B")
assert.Equal(t, addrB, hBinA.CurrentRemote, "Host B remote is wrong in control A") assert.Equal(t, addrB, hBinA.CurrentRemote, "Host B remote is wrong in control A")
assert.Equal(t, addrA, hAinB.CurrentRemote, "Host A remote is wrong in control B") assert.Equal(t, addrA, hAinB.CurrentRemote, "Host A remote is wrong in control B")
@ -179,6 +196,33 @@ func assertHostInfoPair(t *testing.T, addrA, addrB netip.AddrPort, vpnIpA, vpnIp
} }
func assertUdpPacket(t *testing.T, expected, b []byte, fromIp, toIp netip.Addr, fromPort, toPort uint16) { func assertUdpPacket(t *testing.T, expected, b []byte, fromIp, toIp netip.Addr, fromPort, toPort uint16) {
if toIp.Is6() {
assertUdpPacket6(t, expected, b, fromIp, toIp, fromPort, toPort)
} else {
assertUdpPacket4(t, expected, b, fromIp, toIp, fromPort, toPort)
}
}
func assertUdpPacket6(t *testing.T, expected, b []byte, fromIp, toIp netip.Addr, fromPort, toPort uint16) {
packet := gopacket.NewPacket(b, layers.LayerTypeIPv6, gopacket.Lazy)
v6 := packet.Layer(layers.LayerTypeIPv6).(*layers.IPv6)
assert.NotNil(t, v6, "No ipv6 data found")
assert.Equal(t, fromIp.AsSlice(), []byte(v6.SrcIP), "Source ip was incorrect")
assert.Equal(t, toIp.AsSlice(), []byte(v6.DstIP), "Dest ip was incorrect")
udp := packet.Layer(layers.LayerTypeUDP).(*layers.UDP)
assert.NotNil(t, udp, "No udp data found")
assert.Equal(t, fromPort, uint16(udp.SrcPort), "Source port was incorrect")
assert.Equal(t, toPort, uint16(udp.DstPort), "Dest port was incorrect")
data := packet.ApplicationLayer()
assert.NotNil(t, data)
assert.Equal(t, expected, data.Payload(), "Data was incorrect")
}
func assertUdpPacket4(t *testing.T, expected, b []byte, fromIp, toIp netip.Addr, fromPort, toPort uint16) {
packet := gopacket.NewPacket(b, layers.LayerTypeIPv4, gopacket.Lazy) packet := gopacket.NewPacket(b, layers.LayerTypeIPv4, gopacket.Lazy)
v4 := packet.Layer(layers.LayerTypeIPv4).(*layers.IPv4) v4 := packet.Layer(layers.LayerTypeIPv4).(*layers.IPv4)
assert.NotNil(t, v4, "No ipv4 data found") assert.NotNil(t, v4, "No ipv4 data found")
@ -197,6 +241,14 @@ func assertUdpPacket(t *testing.T, expected, b []byte, fromIp, toIp netip.Addr,
assert.Equal(t, expected, data.Payload(), "Data was incorrect") assert.Equal(t, expected, data.Payload(), "Data was incorrect")
} }
func getAddrs(ns []netip.Prefix) []netip.Addr {
var a []netip.Addr
for _, n := range ns {
a = append(a, n.Addr())
}
return a
}
func NewTestLogger() *logrus.Logger { func NewTestLogger() *logrus.Logger {
l := logrus.New() l := logrus.New()

View File

@ -58,8 +58,9 @@ func renderHostmap(c *nebula.Control) (string, []*edge) {
var lines []string var lines []string
var globalLines []*edge var globalLines []*edge
clusterName := strings.Trim(c.GetCert().Name(), " ") crt := c.GetCertState().GetDefaultCertificate()
clusterVpnIp := c.GetCert().Networks()[0].Addr() clusterName := strings.Trim(crt.Name(), " ")
clusterVpnIp := crt.Networks()[0].Addr()
r := fmt.Sprintf("\tsubgraph %s[\"%s (%s)\"]\n", clusterName, clusterName, clusterVpnIp) r := fmt.Sprintf("\tsubgraph %s[\"%s (%s)\"]\n", clusterName, clusterName, clusterVpnIp)
hm := c.GetHostmap() hm := c.GetHostmap()
@ -101,7 +102,7 @@ func renderHostmap(c *nebula.Control) (string, []*edge) {
for _, idx := range indexes { for _, idx := range indexes {
hi, ok := hm.Indexes[idx] hi, ok := hm.Indexes[idx]
if ok { if ok {
r += fmt.Sprintf("\t\t\t%v.%v[\"%v (%v)\"]\n", clusterName, idx, idx, hi.GetVpnIp()) r += fmt.Sprintf("\t\t\t%v.%v[\"%v (%v)\"]\n", clusterName, idx, idx, hi.GetVpnAddrs())
remoteClusterName := strings.Trim(hi.GetCert().Certificate.Name(), " ") remoteClusterName := strings.Trim(hi.GetCert().Certificate.Name(), " ")
globalLines = append(globalLines, &edge{from: fmt.Sprintf("%v.%v", clusterName, idx), to: fmt.Sprintf("%v.%v", remoteClusterName, hi.GetRemoteIndex())}) globalLines = append(globalLines, &edge{from: fmt.Sprintf("%v.%v", clusterName, idx), to: fmt.Sprintf("%v.%v", remoteClusterName, hi.GetRemoteIndex())})
_ = hi _ = hi

View File

@ -10,8 +10,8 @@ import (
"os" "os"
"path/filepath" "path/filepath"
"reflect" "reflect"
"regexp"
"sort" "sort"
"strings"
"sync" "sync"
"testing" "testing"
"time" "time"
@ -136,7 +136,10 @@ func NewR(t testing.TB, controls ...*nebula.Control) *R {
panic("Duplicate listen address: " + addr.String()) panic("Duplicate listen address: " + addr.String())
} }
r.vpnControls[c.GetVpnIp()] = c for _, vpnAddr := range c.GetVpnAddrs() {
r.vpnControls[vpnAddr] = c
}
r.controls[addr] = c r.controls[addr] = c
} }
@ -213,11 +216,11 @@ func (r *R) renderFlow() {
continue continue
} }
participants[addr] = struct{}{} participants[addr] = struct{}{}
sanAddr := strings.Replace(addr.String(), ":", "-", 1) sanAddr := normalizeName(addr.String())
participantsVals = append(participantsVals, sanAddr) participantsVals = append(participantsVals, sanAddr)
fmt.Fprintf( fmt.Fprintf(
f, " participant %s as Nebula: %s<br/>UDP: %s\n", f, " participant %s as Nebula: %s<br/>UDP: %s\n",
sanAddr, e.packet.from.GetVpnIp(), sanAddr, sanAddr, e.packet.from.GetVpnAddrs(), sanAddr,
) )
} }
@ -250,9 +253,9 @@ func (r *R) renderFlow() {
fmt.Fprintf(f, fmt.Fprintf(f,
" %s%s%s: %s(%s), index %v, counter: %v\n", " %s%s%s: %s(%s), index %v, counter: %v\n",
strings.Replace(p.from.GetUDPAddr().String(), ":", "-", 1), normalizeName(p.from.GetUDPAddr().String()),
line, line,
strings.Replace(p.to.GetUDPAddr().String(), ":", "-", 1), normalizeName(p.to.GetUDPAddr().String()),
h.TypeName(), h.SubTypeName(), h.RemoteIndex, h.MessageCounter, h.TypeName(), h.SubTypeName(), h.RemoteIndex, h.MessageCounter,
) )
} }
@ -267,6 +270,11 @@ func (r *R) renderFlow() {
} }
} }
func normalizeName(s string) string {
rx := regexp.MustCompile("[\\[\\]\\:]")
return rx.ReplaceAllLiteralString(s, "_")
}
// IgnoreFlow tells the router to stop recording future flows that matches the provided criteria. // IgnoreFlow tells the router to stop recording future flows that matches the provided criteria.
// messageType and subType will target nebula underlay packets while tun will target nebula overlay packets // messageType and subType will target nebula underlay packets while tun will target nebula overlay packets
// NOTE: This is a very broad system, if you set tun to true then no more tun traffic will be rendered // NOTE: This is a very broad system, if you set tun to true then no more tun traffic will be rendered
@ -303,7 +311,7 @@ func (r *R) RenderHostmaps(title string, controls ...*nebula.Control) {
func (r *R) renderHostmaps(title string) { func (r *R) renderHostmaps(title string) {
c := maps.Values(r.controls) c := maps.Values(r.controls)
sort.SliceStable(c, func(i, j int) bool { sort.SliceStable(c, func(i, j int) bool {
return c[i].GetVpnIp().Compare(c[j].GetVpnIp()) > 0 return c[i].GetVpnAddrs()[0].Compare(c[j].GetVpnAddrs()[0]) > 0
}) })
s := renderHostmaps(c...) s := renderHostmaps(c...)
@ -419,10 +427,11 @@ func (r *R) RouteUntilTxTun(sender *nebula.Control, receiver *nebula.Control) []
// Nope, lets push the sender along // Nope, lets push the sender along
case p := <-udpTx: case p := <-udpTx:
r.Lock() r.Lock()
c := r.getControl(sender.GetUDPAddr(), p.To, p) a := sender.GetUDPAddr()
c := r.getControl(a, p.To, p)
if c == nil { if c == nil {
r.Unlock() r.Unlock()
panic("No control for udp tx") panic("No control for udp tx " + a.String())
} }
fp := r.unlockedInjectFlow(sender, c, p, false) fp := r.unlockedInjectFlow(sender, c, p, false)
c.InjectUDPPacket(p) c.InjectUDPPacket(p)
@ -475,10 +484,11 @@ func (r *R) RouteForAllUntilTxTun(receiver *nebula.Control) []byte {
} else { } else {
// we are a udp tx, route and continue // we are a udp tx, route and continue
p := rx.Interface().(*udp.Packet) p := rx.Interface().(*udp.Packet)
c := r.getControl(cm[x].GetUDPAddr(), p.To, p) a := cm[x].GetUDPAddr()
c := r.getControl(a, p.To, p)
if c == nil { if c == nil {
r.Unlock() r.Unlock()
panic("No control for udp tx") panic(fmt.Sprintf("No control for udp tx %s", p.To))
} }
fp := r.unlockedInjectFlow(cm[x], c, p, false) fp := r.unlockedInjectFlow(cm[x], c, p, false)
c.InjectUDPPacket(p) c.InjectUDPPacket(p)
@ -711,30 +721,42 @@ func (r *R) getControl(fromAddr, toAddr netip.AddrPort, p *udp.Packet) *nebula.C
} }
func (r *R) formatUdpPacket(p *packet) string { func (r *R) formatUdpPacket(p *packet) string {
packet := gopacket.NewPacket(p.packet.Data, layers.LayerTypeIPv4, gopacket.Lazy) var packet gopacket.Packet
v4 := packet.Layer(layers.LayerTypeIPv4).(*layers.IPv4) var srcAddr netip.Addr
if v4 == nil {
panic("not an ipv4 packet") packet = gopacket.NewPacket(p.packet.Data, layers.LayerTypeIPv6, gopacket.Lazy)
if packet.ErrorLayer() == nil {
v6 := packet.Layer(layers.LayerTypeIPv6).(*layers.IPv6)
if v6 == nil {
panic("not an ipv6 packet")
}
srcAddr, _ = netip.AddrFromSlice(v6.SrcIP)
} else {
packet = gopacket.NewPacket(p.packet.Data, layers.LayerTypeIPv4, gopacket.Lazy)
v6 := packet.Layer(layers.LayerTypeIPv4).(*layers.IPv4)
if v6 == nil {
panic("not an ipv6 packet")
}
srcAddr, _ = netip.AddrFromSlice(v6.SrcIP)
} }
from := "unknown" from := "unknown"
srcAddr, _ := netip.AddrFromSlice(v4.SrcIP)
if c, ok := r.vpnControls[srcAddr]; ok { if c, ok := r.vpnControls[srcAddr]; ok {
from = c.GetUDPAddr().String() from = c.GetUDPAddr().String()
} }
udp := packet.Layer(layers.LayerTypeUDP).(*layers.UDP) udpLayer := packet.Layer(layers.LayerTypeUDP).(*layers.UDP)
if udp == nil { if udpLayer == nil {
panic("not a udp packet") panic("not a udp packet")
} }
data := packet.ApplicationLayer() data := packet.ApplicationLayer()
return fmt.Sprintf( return fmt.Sprintf(
" %s-->>%s: src port: %v<br/>dest port: %v<br/>data: \"%v\"\n", " %s-->>%s: src port: %v<br/>dest port: %v<br/>data: \"%v\"\n",
strings.Replace(from, ":", "-", 1), normalizeName(from),
strings.Replace(p.to.GetUDPAddr().String(), ":", "-", 1), normalizeName(p.to.GetUDPAddr().String()),
udp.SrcPort, udpLayer.SrcPort,
udp.DstPort, udpLayer.DstPort,
string(data.Payload()), string(data.Payload()),
) )
} }

View File

@ -13,6 +13,12 @@ pki:
# disconnect_invalid is a toggle to force a client to be disconnected if the certificate is expired or invalid. # disconnect_invalid is a toggle to force a client to be disconnected if the certificate is expired or invalid.
#disconnect_invalid: true #disconnect_invalid: true
# default_version controls which certificate version is used in handshakes.
# This setting only applies if both a v1 and a v2 certificate are configured, in which case it will default to `1`.
# Once all hosts in the mesh are configured with both a v1 and v2 certificate then this should be changed to `2`.
# After all hosts in the mesh are using a v2 certificate then v1 certificates are no longer needed.
# default_version: 1
# The static host map defines a set of hosts with fixed IP addresses on the internet (or any network). # The static host map defines a set of hosts with fixed IP addresses on the internet (or any network).
# A host can have multiple fixed IP addresses defined here, and nebula will try each when establishing a tunnel. # A host can have multiple fixed IP addresses defined here, and nebula will try each when establishing a tunnel.
# The syntax is: # The syntax is:
@ -336,10 +342,13 @@ firewall:
# host: `any` or a literal hostname, ie `test-host` # host: `any` or a literal hostname, ie `test-host`
# group: `any` or a literal group name, ie `default-group` # group: `any` or a literal group name, ie `default-group`
# groups: Same as group but accepts a list of values. Multiple values are AND'd together and a certificate would have to contain all groups to pass # groups: Same as group but accepts a list of values. Multiple values are AND'd together and a certificate would have to contain all groups to pass
# cidr: a remote CIDR, `0.0.0.0/0` is any. # cidr: a remote CIDR, `0.0.0.0/0` is any ipv4 and `::/0` is any ipv6. //TODO: we have a problem, firewall needs to understand this and should probably allow `any` for both
# local_cidr: a local CIDR, `0.0.0.0/0` is any. This could be used to filter destinations when using unsafe_routes. # local_cidr: a local CIDR, `0.0.0.0/0` is any ipv4 and `::/0` is any ipv6. This could be used to filter destinations when using unsafe_routes.
# Default is `any` unless the certificate contains subnets and then the default is the ip issued in the certificate # //TODO: probably should have an `any` that covers both ip versions
# if `default_local_cidr_any` is false, otherwise its `any`. # If no unsafe networks are present in the certificate(s) or `default_local_cidr_any` is true then the default is any ipv4 or ipv6 network.
# Otherwise the default is any vpn network assigned to via the certificate.
# `default_local_cidr_any` defaults to false and is deprecated, it will be removed in a future release.
# If there are unsafe routes present its best to set `local_cidr` to whatever best fits the situation.
# ca_name: An issuing CA name # ca_name: An issuing CA name
# ca_sha: An issuing CA shasum # ca_sha: An issuing CA shasum

View File

@ -8,6 +8,7 @@ import (
"hash/fnv" "hash/fnv"
"net/netip" "net/netip"
"reflect" "reflect"
"slices"
"strconv" "strconv"
"strings" "strings"
"sync" "sync"
@ -22,7 +23,8 @@ import (
) )
type FirewallInterface interface { type FirewallInterface interface {
AddRule(incoming bool, proto uint8, startPort int32, endPort int32, groups []string, host string, ip, localIp netip.Prefix, caName string, caSha string) error //TODO: name these better addr, localAddr. Are they vpnAddrs?
AddRule(incoming bool, proto uint8, startPort int32, endPort int32, groups []string, host string, addr, localAddr netip.Prefix, caName string, caSha string) error
} }
type conn struct { type conn struct {
@ -51,9 +53,12 @@ type Firewall struct {
UDPTimeout time.Duration //linux: 180s max UDPTimeout time.Duration //linux: 180s max
DefaultTimeout time.Duration //linux: 600s DefaultTimeout time.Duration //linux: 600s
// Used to ensure we don't emit local packets for ips we don't own // routableNetworks describes the vpn addresses as well as any unsafe networks issued to us in the certificate.
localIps *bart.Table[struct{}] // The vpn addresses are a full bit match while the unsafe networks only match the prefix
assignedCIDR netip.Prefix routableNetworks *bart.Table[struct{}]
// assignedNetworks is a list of vpn networks assigned to us in the certificate.
assignedNetworks []netip.Prefix
hasUnsafeNetworks bool hasUnsafeNetworks bool
rules string rules string
@ -67,8 +72,8 @@ type Firewall struct {
} }
type firewallMetrics struct { type firewallMetrics struct {
droppedLocalIP metrics.Counter droppedLocalAddr metrics.Counter
droppedRemoteIP metrics.Counter droppedRemoteAddr metrics.Counter
droppedNoRule metrics.Counter droppedNoRule metrics.Counter
} }
@ -126,84 +131,87 @@ type firewallLocalCIDR struct {
} }
// NewFirewall creates a new Firewall object. A TimerWheel is created for you from the provided timeouts. // NewFirewall creates a new Firewall object. A TimerWheel is created for you from the provided timeouts.
// The certificate provided should be the highest version loaded in memory.
func NewFirewall(l *logrus.Logger, tcpTimeout, UDPTimeout, defaultTimeout time.Duration, c cert.Certificate) *Firewall { func NewFirewall(l *logrus.Logger, tcpTimeout, UDPTimeout, defaultTimeout time.Duration, c cert.Certificate) *Firewall {
//TODO: error on 0 duration //TODO: error on 0 duration
var min, max time.Duration var tmin, tmax time.Duration
if tcpTimeout < UDPTimeout { if tcpTimeout < UDPTimeout {
min = tcpTimeout tmin = tcpTimeout
max = UDPTimeout tmax = UDPTimeout
} else { } else {
min = UDPTimeout tmin = UDPTimeout
max = tcpTimeout tmax = tcpTimeout
} }
if defaultTimeout < min { if defaultTimeout < tmin {
min = defaultTimeout tmin = defaultTimeout
} else if defaultTimeout > max { } else if defaultTimeout > tmax {
max = defaultTimeout tmax = defaultTimeout
} }
localIps := new(bart.Table[struct{}]) routableNetworks := new(bart.Table[struct{}])
var assignedCIDR netip.Prefix var assignedNetworks []netip.Prefix
var assignedSet bool
for _, network := range c.Networks() { for _, network := range c.Networks() {
nprefix := netip.PrefixFrom(network.Addr(), network.Addr().BitLen()) nprefix := netip.PrefixFrom(network.Addr(), network.Addr().BitLen())
localIps.Insert(nprefix, struct{}{}) routableNetworks.Insert(nprefix, struct{}{})
assignedNetworks = append(assignedNetworks, network)
if !assignedSet {
// Only grabbing the first one in the cert since any more than that currently has undefined behavior
assignedCIDR = nprefix
assignedSet = true
}
} }
hasUnsafeNetworks := false hasUnsafeNetworks := false
for _, n := range c.UnsafeNetworks() { for _, n := range c.UnsafeNetworks() {
localIps.Insert(n, struct{}{}) routableNetworks.Insert(n, struct{}{})
hasUnsafeNetworks = true hasUnsafeNetworks = true
} }
return &Firewall{ return &Firewall{
Conntrack: &FirewallConntrack{ Conntrack: &FirewallConntrack{
Conns: make(map[firewall.Packet]*conn), Conns: make(map[firewall.Packet]*conn),
TimerWheel: NewTimerWheel[firewall.Packet](min, max), TimerWheel: NewTimerWheel[firewall.Packet](tmin, tmax),
}, },
InRules: newFirewallTable(), InRules: newFirewallTable(),
OutRules: newFirewallTable(), OutRules: newFirewallTable(),
TCPTimeout: tcpTimeout, TCPTimeout: tcpTimeout,
UDPTimeout: UDPTimeout, UDPTimeout: UDPTimeout,
DefaultTimeout: defaultTimeout, DefaultTimeout: defaultTimeout,
localIps: localIps, routableNetworks: routableNetworks,
assignedCIDR: assignedCIDR, assignedNetworks: assignedNetworks,
hasUnsafeNetworks: hasUnsafeNetworks, hasUnsafeNetworks: hasUnsafeNetworks,
l: l, l: l,
incomingMetrics: firewallMetrics{ incomingMetrics: firewallMetrics{
droppedLocalIP: metrics.GetOrRegisterCounter("firewall.incoming.dropped.local_ip", nil), droppedLocalAddr: metrics.GetOrRegisterCounter("firewall.incoming.dropped.local_addr", nil),
droppedRemoteIP: metrics.GetOrRegisterCounter("firewall.incoming.dropped.remote_ip", nil), droppedRemoteAddr: metrics.GetOrRegisterCounter("firewall.incoming.dropped.remote_addr", nil),
droppedNoRule: metrics.GetOrRegisterCounter("firewall.incoming.dropped.no_rule", nil), droppedNoRule: metrics.GetOrRegisterCounter("firewall.incoming.dropped.no_rule", nil),
}, },
outgoingMetrics: firewallMetrics{ outgoingMetrics: firewallMetrics{
droppedLocalIP: metrics.GetOrRegisterCounter("firewall.outgoing.dropped.local_ip", nil), droppedLocalAddr: metrics.GetOrRegisterCounter("firewall.outgoing.dropped.local_addr", nil),
droppedRemoteIP: metrics.GetOrRegisterCounter("firewall.outgoing.dropped.remote_ip", nil), droppedRemoteAddr: metrics.GetOrRegisterCounter("firewall.outgoing.dropped.remote_addr", nil),
droppedNoRule: metrics.GetOrRegisterCounter("firewall.outgoing.dropped.no_rule", nil), droppedNoRule: metrics.GetOrRegisterCounter("firewall.outgoing.dropped.no_rule", nil),
}, },
} }
} }
func NewFirewallFromConfig(l *logrus.Logger, nc cert.Certificate, c *config.C) (*Firewall, error) { func NewFirewallFromConfig(l *logrus.Logger, cs *CertState, c *config.C) (*Firewall, error) {
certificate := cs.getCertificate(cert.Version2)
if certificate == nil {
certificate = cs.getCertificate(cert.Version1)
}
if certificate == nil {
panic("No certificate available to reconfigure the firewall")
}
fw := NewFirewall( fw := NewFirewall(
l, l,
c.GetDuration("firewall.conntrack.tcp_timeout", time.Minute*12), c.GetDuration("firewall.conntrack.tcp_timeout", time.Minute*12),
c.GetDuration("firewall.conntrack.udp_timeout", time.Minute*3), c.GetDuration("firewall.conntrack.udp_timeout", time.Minute*3),
c.GetDuration("firewall.conntrack.default_timeout", time.Minute*10), c.GetDuration("firewall.conntrack.default_timeout", time.Minute*10),
nc, certificate,
//TODO: max_connections //TODO: max_connections
) )
//TODO: Flip to false after v1.9 release fw.defaultLocalCIDRAny = c.GetBool("firewall.default_local_cidr_any", false)
fw.defaultLocalCIDRAny = c.GetBool("firewall.default_local_cidr_any", true)
inboundAction := c.GetString("firewall.inbound_action", "drop") inboundAction := c.GetString("firewall.inbound_action", "drop")
switch inboundAction { switch inboundAction {
@ -283,7 +291,7 @@ func (f *Firewall) AddRule(incoming bool, proto uint8, startPort int32, endPort
fp = ft.TCP fp = ft.TCP
case firewall.ProtoUDP: case firewall.ProtoUDP:
fp = ft.UDP fp = ft.UDP
case firewall.ProtoICMP: case firewall.ProtoICMP, firewall.ProtoICMPv6:
fp = ft.ICMP fp = ft.ICMP
case firewall.ProtoAny: case firewall.ProtoAny:
fp = ft.AnyProto fp = ft.AnyProto
@ -424,26 +432,25 @@ func (f *Firewall) Drop(fp firewall.Packet, incoming bool, h *HostInfo, caPool *
} }
// Make sure remote address matches nebula certificate // Make sure remote address matches nebula certificate
if remoteCidr := h.remoteCidr; remoteCidr != nil { if h.networks != nil {
//TODO: this would be better if we had a least specific match lookup, could waste time here, need to benchmark since the algo is different _, ok := h.networks.Lookup(fp.RemoteAddr)
_, ok := remoteCidr.Lookup(fp.RemoteIP)
if !ok { if !ok {
f.metrics(incoming).droppedRemoteIP.Inc(1) f.metrics(incoming).droppedRemoteAddr.Inc(1)
return ErrInvalidRemoteIP return ErrInvalidRemoteIP
} }
} else { } else {
// Simple case: Certificate has one IP and no subnets // Simple case: Certificate has one IP and no subnets
if fp.RemoteIP != h.vpnIp { //TODO: we can make this more performant
f.metrics(incoming).droppedRemoteIP.Inc(1) if !slices.Contains(h.vpnAddrs, fp.RemoteAddr) {
f.metrics(incoming).droppedRemoteAddr.Inc(1)
return ErrInvalidRemoteIP return ErrInvalidRemoteIP
} }
} }
// Make sure we are supposed to be handling this local ip address // Make sure we are supposed to be handling this local ip address
//TODO: this would be better if we had a least specific match lookup, could waste time here, need to benchmark since the algo is different _, ok := f.routableNetworks.Lookup(fp.LocalAddr)
_, ok := f.localIps.Lookup(fp.LocalIP)
if !ok { if !ok {
f.metrics(incoming).droppedLocalIP.Inc(1) f.metrics(incoming).droppedLocalAddr.Inc(1)
return ErrInvalidLocalIP return ErrInvalidLocalIP
} }
@ -629,7 +636,7 @@ func (ft *FirewallTable) match(p firewall.Packet, incoming bool, c *cert.CachedC
if ft.UDP.match(p, incoming, c, caPool) { if ft.UDP.match(p, incoming, c, caPool) {
return true return true
} }
case firewall.ProtoICMP: case firewall.ProtoICMP, firewall.ProtoICMPv6:
if ft.ICMP.match(p, incoming, c, caPool) { if ft.ICMP.match(p, incoming, c, caPool) {
return true return true
} }
@ -859,9 +866,9 @@ func (fr *FirewallRule) match(p firewall.Packet, c *cert.CachedCertificate) bool
} }
matched := false matched := false
prefix := netip.PrefixFrom(p.RemoteIP, p.RemoteIP.BitLen()) prefix := netip.PrefixFrom(p.RemoteAddr, p.RemoteAddr.BitLen())
fr.CIDR.EachLookupPrefix(prefix, func(prefix netip.Prefix, val *firewallLocalCIDR) bool { fr.CIDR.EachLookupPrefix(prefix, func(prefix netip.Prefix, val *firewallLocalCIDR) bool {
if prefix.Contains(p.RemoteIP) && val.match(p, c) { if prefix.Contains(p.RemoteAddr) && val.match(p, c) {
matched = true matched = true
return false return false
} }
@ -877,9 +884,14 @@ func (flc *firewallLocalCIDR) addRule(f *Firewall, localIp netip.Prefix) error {
return nil return nil
} }
localIp = f.assignedCIDR for _, network := range f.assignedNetworks {
flc.LocalCIDR.Insert(network, struct{}{})
}
return nil
} else if localIp.Bits() == 0 { } else if localIp.Bits() == 0 {
flc.Any = true flc.Any = true
return nil
} }
flc.LocalCIDR.Insert(localIp, struct{}{}) flc.LocalCIDR.Insert(localIp, struct{}{})
@ -895,7 +907,7 @@ func (flc *firewallLocalCIDR) match(p firewall.Packet, c *cert.CachedCertificate
return true return true
} }
_, ok := flc.LocalCIDR.Lookup(p.LocalIP) _, ok := flc.LocalCIDR.Lookup(p.LocalAddr)
return ok return ok
} }

View File

@ -13,14 +13,15 @@ const (
ProtoTCP = 6 ProtoTCP = 6
ProtoUDP = 17 ProtoUDP = 17
ProtoICMP = 1 ProtoICMP = 1
ProtoICMPv6 = 58
PortAny = 0 // Special value for matching `port: any` PortAny = 0 // Special value for matching `port: any`
PortFragment = -1 // Special value for matching `port: fragment` PortFragment = -1 // Special value for matching `port: fragment`
) )
type Packet struct { type Packet struct {
LocalIP netip.Addr LocalAddr netip.Addr
RemoteIP netip.Addr RemoteAddr netip.Addr
LocalPort uint16 LocalPort uint16
RemotePort uint16 RemotePort uint16
Protocol uint8 Protocol uint8
@ -29,8 +30,8 @@ type Packet struct {
func (fp *Packet) Copy() *Packet { func (fp *Packet) Copy() *Packet {
return &Packet{ return &Packet{
LocalIP: fp.LocalIP, LocalAddr: fp.LocalAddr,
RemoteIP: fp.RemoteIP, RemoteAddr: fp.RemoteAddr,
LocalPort: fp.LocalPort, LocalPort: fp.LocalPort,
RemotePort: fp.RemotePort, RemotePort: fp.RemotePort,
Protocol: fp.Protocol, Protocol: fp.Protocol,
@ -51,8 +52,8 @@ func (fp Packet) MarshalJSON() ([]byte, error) {
proto = fmt.Sprintf("unknown %v", fp.Protocol) proto = fmt.Sprintf("unknown %v", fp.Protocol)
} }
return json.Marshal(m{ return json.Marshal(m{
"LocalIP": fp.LocalIP.String(), "LocalAddr": fp.LocalAddr.String(),
"RemoteIP": fp.RemoteIP.String(), "RemoteAddr": fp.RemoteAddr.String(),
"LocalPort": fp.LocalPort, "LocalPort": fp.LocalPort,
"RemotePort": fp.RemotePort, "RemotePort": fp.RemotePort,
"Protocol": proto, "Protocol": proto,

View File

@ -13,6 +13,7 @@ import (
"github.com/slackhq/nebula/firewall" "github.com/slackhq/nebula/firewall"
"github.com/slackhq/nebula/test" "github.com/slackhq/nebula/test"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
) )
func TestNewFirewall(t *testing.T) { func TestNewFirewall(t *testing.T) {
@ -128,8 +129,8 @@ func TestFirewall_Drop(t *testing.T) {
l.SetOutput(ob) l.SetOutput(ob)
p := firewall.Packet{ p := firewall.Packet{
LocalIP: netip.MustParseAddr("1.2.3.4"), LocalAddr: netip.MustParseAddr("1.2.3.4"),
RemoteIP: netip.MustParseAddr("1.2.3.4"), RemoteAddr: netip.MustParseAddr("1.2.3.4"),
LocalPort: 10, LocalPort: 10,
RemotePort: 90, RemotePort: 90,
Protocol: firewall.ProtoUDP, Protocol: firewall.ProtoUDP,
@ -149,9 +150,9 @@ func TestFirewall_Drop(t *testing.T) {
InvertedGroups: map[string]struct{}{"default-group": {}}, InvertedGroups: map[string]struct{}{"default-group": {}},
}, },
}, },
vpnIp: netip.MustParseAddr("1.2.3.4"), vpnAddrs: []netip.Addr{netip.MustParseAddr("1.2.3.4")},
} }
h.CreateRemoteCIDR(&c) h.buildNetworks(&c)
fw := NewFirewall(l, time.Second, time.Minute, time.Hour, &c) fw := NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"any"}, "", netip.Prefix{}, netip.Prefix{}, "", "")) assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"any"}, "", netip.Prefix{}, netip.Prefix{}, "", ""))
@ -166,10 +167,10 @@ func TestFirewall_Drop(t *testing.T) {
assert.NoError(t, fw.Drop(p, false, &h, cp, nil)) assert.NoError(t, fw.Drop(p, false, &h, cp, nil))
// test remote mismatch // test remote mismatch
oldRemote := p.RemoteIP oldRemote := p.RemoteAddr
p.RemoteIP = netip.MustParseAddr("1.2.3.10") p.RemoteAddr = netip.MustParseAddr("1.2.3.10")
assert.Equal(t, fw.Drop(p, false, &h, cp, nil), ErrInvalidRemoteIP) assert.Equal(t, fw.Drop(p, false, &h, cp, nil), ErrInvalidRemoteIP)
p.RemoteIP = oldRemote p.RemoteAddr = oldRemote
// ensure signer doesn't get in the way of group checks // ensure signer doesn't get in the way of group checks
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c) fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
@ -235,7 +236,7 @@ func BenchmarkFirewallTable_match(b *testing.B) {
} }
ip := netip.MustParsePrefix("9.254.254.254/32") ip := netip.MustParsePrefix("9.254.254.254/32")
for n := 0; n < b.N; n++ { for n := 0; n < b.N; n++ {
assert.False(b, ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 100, LocalIP: ip.Addr()}, true, c, cp)) assert.False(b, ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 100, LocalAddr: ip.Addr()}, true, c, cp))
} }
}) })
@ -261,7 +262,7 @@ func BenchmarkFirewallTable_match(b *testing.B) {
InvertedGroups: map[string]struct{}{"nope": {}}, InvertedGroups: map[string]struct{}{"nope": {}},
} }
for n := 0; n < b.N; n++ { for n := 0; n < b.N; n++ {
assert.False(b, ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 100, LocalIP: pfix.Addr()}, true, c, cp)) assert.False(b, ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 100, LocalAddr: pfix.Addr()}, true, c, cp))
} }
}) })
@ -285,7 +286,7 @@ func BenchmarkFirewallTable_match(b *testing.B) {
InvertedGroups: map[string]struct{}{"good-group": {}}, InvertedGroups: map[string]struct{}{"good-group": {}},
} }
for n := 0; n < b.N; n++ { for n := 0; n < b.N; n++ {
assert.True(b, ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 100, LocalIP: pfix.Addr()}, true, c, cp)) assert.True(b, ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 100, LocalAddr: pfix.Addr()}, true, c, cp))
} }
}) })
@ -308,8 +309,8 @@ func TestFirewall_Drop2(t *testing.T) {
l.SetOutput(ob) l.SetOutput(ob)
p := firewall.Packet{ p := firewall.Packet{
LocalIP: netip.MustParseAddr("1.2.3.4"), LocalAddr: netip.MustParseAddr("1.2.3.4"),
RemoteIP: netip.MustParseAddr("1.2.3.4"), RemoteAddr: netip.MustParseAddr("1.2.3.4"),
LocalPort: 10, LocalPort: 10,
RemotePort: 90, RemotePort: 90,
Protocol: firewall.ProtoUDP, Protocol: firewall.ProtoUDP,
@ -329,9 +330,9 @@ func TestFirewall_Drop2(t *testing.T) {
ConnectionState: &ConnectionState{ ConnectionState: &ConnectionState{
peerCert: &c, peerCert: &c,
}, },
vpnIp: network.Addr(), vpnAddrs: []netip.Addr{network.Addr()},
} }
h.CreateRemoteCIDR(c.Certificate) h.buildNetworks(c.Certificate)
c1 := cert.CachedCertificate{ c1 := cert.CachedCertificate{
Certificate: &dummyCert{ Certificate: &dummyCert{
@ -345,7 +346,7 @@ func TestFirewall_Drop2(t *testing.T) {
peerCert: &c1, peerCert: &c1,
}, },
} }
h1.CreateRemoteCIDR(c1.Certificate) h1.buildNetworks(c1.Certificate)
fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate) fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate)
assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group", "test-group"}, "", netip.Prefix{}, netip.Prefix{}, "", "")) assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group", "test-group"}, "", netip.Prefix{}, netip.Prefix{}, "", ""))
@ -364,8 +365,8 @@ func TestFirewall_Drop3(t *testing.T) {
l.SetOutput(ob) l.SetOutput(ob)
p := firewall.Packet{ p := firewall.Packet{
LocalIP: netip.MustParseAddr("1.2.3.4"), LocalAddr: netip.MustParseAddr("1.2.3.4"),
RemoteIP: netip.MustParseAddr("1.2.3.4"), RemoteAddr: netip.MustParseAddr("1.2.3.4"),
LocalPort: 1, LocalPort: 1,
RemotePort: 1, RemotePort: 1,
Protocol: firewall.ProtoUDP, Protocol: firewall.ProtoUDP,
@ -391,9 +392,9 @@ func TestFirewall_Drop3(t *testing.T) {
ConnectionState: &ConnectionState{ ConnectionState: &ConnectionState{
peerCert: &c1, peerCert: &c1,
}, },
vpnIp: network.Addr(), vpnAddrs: []netip.Addr{network.Addr()},
} }
h1.CreateRemoteCIDR(c1.Certificate) h1.buildNetworks(c1.Certificate)
c2 := cert.CachedCertificate{ c2 := cert.CachedCertificate{
Certificate: &dummyCert{ Certificate: &dummyCert{
@ -406,9 +407,9 @@ func TestFirewall_Drop3(t *testing.T) {
ConnectionState: &ConnectionState{ ConnectionState: &ConnectionState{
peerCert: &c2, peerCert: &c2,
}, },
vpnIp: network.Addr(), vpnAddrs: []netip.Addr{network.Addr()},
} }
h2.CreateRemoteCIDR(c2.Certificate) h2.buildNetworks(c2.Certificate)
c3 := cert.CachedCertificate{ c3 := cert.CachedCertificate{
Certificate: &dummyCert{ Certificate: &dummyCert{
@ -421,9 +422,9 @@ func TestFirewall_Drop3(t *testing.T) {
ConnectionState: &ConnectionState{ ConnectionState: &ConnectionState{
peerCert: &c3, peerCert: &c3,
}, },
vpnIp: network.Addr(), vpnAddrs: []netip.Addr{network.Addr()},
} }
h3.CreateRemoteCIDR(c3.Certificate) h3.buildNetworks(c3.Certificate)
fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate) fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate)
assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 1, 1, []string{}, "host1", netip.Prefix{}, netip.Prefix{}, "", "")) assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 1, 1, []string{}, "host1", netip.Prefix{}, netip.Prefix{}, "", ""))
@ -446,8 +447,8 @@ func TestFirewall_DropConntrackReload(t *testing.T) {
l.SetOutput(ob) l.SetOutput(ob)
p := firewall.Packet{ p := firewall.Packet{
LocalIP: netip.MustParseAddr("1.2.3.4"), LocalAddr: netip.MustParseAddr("1.2.3.4"),
RemoteIP: netip.MustParseAddr("1.2.3.4"), RemoteAddr: netip.MustParseAddr("1.2.3.4"),
LocalPort: 10, LocalPort: 10,
RemotePort: 90, RemotePort: 90,
Protocol: firewall.ProtoUDP, Protocol: firewall.ProtoUDP,
@ -468,9 +469,9 @@ func TestFirewall_DropConntrackReload(t *testing.T) {
ConnectionState: &ConnectionState{ ConnectionState: &ConnectionState{
peerCert: &c, peerCert: &c,
}, },
vpnIp: network.Addr(), vpnAddrs: []netip.Addr{network.Addr()},
} }
h.CreateRemoteCIDR(c.Certificate) h.buildNetworks(c.Certificate)
fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate) fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate)
assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"any"}, "", netip.Prefix{}, netip.Prefix{}, "", "")) assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"any"}, "", netip.Prefix{}, netip.Prefix{}, "", ""))
@ -622,55 +623,58 @@ func TestNewFirewallFromConfig(t *testing.T) {
l := test.NewLogger() l := test.NewLogger()
// Test a bad rule definition // Test a bad rule definition
c := &dummyCert{} c := &dummyCert{}
cs, err := newCertState(cert.Version2, nil, c, false, cert.Curve_CURVE25519, nil)
require.NoError(t, err)
conf := config.NewC(l) conf := config.NewC(l)
conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": "asdf"} conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": "asdf"}
_, err := NewFirewallFromConfig(l, c, conf) _, err = NewFirewallFromConfig(l, cs, conf)
assert.EqualError(t, err, "firewall.outbound failed to parse, should be an array of rules") assert.EqualError(t, err, "firewall.outbound failed to parse, should be an array of rules")
// Test both port and code // Test both port and code
conf = config.NewC(l) conf = config.NewC(l)
conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"port": "1", "code": "2"}}} conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"port": "1", "code": "2"}}}
_, err = NewFirewallFromConfig(l, c, conf) _, err = NewFirewallFromConfig(l, cs, conf)
assert.EqualError(t, err, "firewall.outbound rule #0; only one of port or code should be provided") assert.EqualError(t, err, "firewall.outbound rule #0; only one of port or code should be provided")
// Test missing host, group, cidr, ca_name and ca_sha // Test missing host, group, cidr, ca_name and ca_sha
conf = config.NewC(l) conf = config.NewC(l)
conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{}}} conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{}}}
_, err = NewFirewallFromConfig(l, c, conf) _, err = NewFirewallFromConfig(l, cs, conf)
assert.EqualError(t, err, "firewall.outbound rule #0; at least one of host, group, cidr, local_cidr, ca_name, or ca_sha must be provided") assert.EqualError(t, err, "firewall.outbound rule #0; at least one of host, group, cidr, local_cidr, ca_name, or ca_sha must be provided")
// Test code/port error // Test code/port error
conf = config.NewC(l) conf = config.NewC(l)
conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"code": "a", "host": "testh"}}} conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"code": "a", "host": "testh"}}}
_, err = NewFirewallFromConfig(l, c, conf) _, err = NewFirewallFromConfig(l, cs, conf)
assert.EqualError(t, err, "firewall.outbound rule #0; code was not a number; `a`") assert.EqualError(t, err, "firewall.outbound rule #0; code was not a number; `a`")
conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"port": "a", "host": "testh"}}} conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"port": "a", "host": "testh"}}}
_, err = NewFirewallFromConfig(l, c, conf) _, err = NewFirewallFromConfig(l, cs, conf)
assert.EqualError(t, err, "firewall.outbound rule #0; port was not a number; `a`") assert.EqualError(t, err, "firewall.outbound rule #0; port was not a number; `a`")
// Test proto error // Test proto error
conf = config.NewC(l) conf = config.NewC(l)
conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"code": "1", "host": "testh"}}} conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"code": "1", "host": "testh"}}}
_, err = NewFirewallFromConfig(l, c, conf) _, err = NewFirewallFromConfig(l, cs, conf)
assert.EqualError(t, err, "firewall.outbound rule #0; proto was not understood; ``") assert.EqualError(t, err, "firewall.outbound rule #0; proto was not understood; ``")
// Test cidr parse error // Test cidr parse error
conf = config.NewC(l) conf = config.NewC(l)
conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"code": "1", "cidr": "testh", "proto": "any"}}} conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"code": "1", "cidr": "testh", "proto": "any"}}}
_, err = NewFirewallFromConfig(l, c, conf) _, err = NewFirewallFromConfig(l, cs, conf)
assert.EqualError(t, err, "firewall.outbound rule #0; cidr did not parse; netip.ParsePrefix(\"testh\"): no '/'") assert.EqualError(t, err, "firewall.outbound rule #0; cidr did not parse; netip.ParsePrefix(\"testh\"): no '/'")
// Test local_cidr parse error // Test local_cidr parse error
conf = config.NewC(l) conf = config.NewC(l)
conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"code": "1", "local_cidr": "testh", "proto": "any"}}} conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"code": "1", "local_cidr": "testh", "proto": "any"}}}
_, err = NewFirewallFromConfig(l, c, conf) _, err = NewFirewallFromConfig(l, cs, conf)
assert.EqualError(t, err, "firewall.outbound rule #0; local_cidr did not parse; netip.ParsePrefix(\"testh\"): no '/'") assert.EqualError(t, err, "firewall.outbound rule #0; local_cidr did not parse; netip.ParsePrefix(\"testh\"): no '/'")
// Test both group and groups // Test both group and groups
conf = config.NewC(l) conf = config.NewC(l)
conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "group": "a", "groups": []string{"b", "c"}}}} conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "group": "a", "groups": []string{"b", "c"}}}}
_, err = NewFirewallFromConfig(l, c, conf) _, err = NewFirewallFromConfig(l, cs, conf)
assert.EqualError(t, err, "firewall.inbound rule #0; only one of group or groups should be defined, both provided") assert.EqualError(t, err, "firewall.inbound rule #0; only one of group or groups should be defined, both provided")
} }

1
go.mod
View File

@ -21,7 +21,6 @@ require (
github.com/rcrowley/go-metrics v0.0.0-20201227073835-cf1acfcdf475 github.com/rcrowley/go-metrics v0.0.0-20201227073835-cf1acfcdf475
github.com/sirupsen/logrus v1.9.3 github.com/sirupsen/logrus v1.9.3
github.com/skip2/go-qrcode v0.0.0-20200617195104-da1b6568686e github.com/skip2/go-qrcode v0.0.0-20200617195104-da1b6568686e
github.com/songgao/water v0.0.0-20200317203138-2b4b6d7c09d8
github.com/stefanberger/go-pkcs11uri v0.0.0-20230803200340-78284954bff6 github.com/stefanberger/go-pkcs11uri v0.0.0-20230803200340-78284954bff6
github.com/stretchr/testify v1.9.0 github.com/stretchr/testify v1.9.0
github.com/vishvananda/netlink v1.3.0 github.com/vishvananda/netlink v1.3.0

2
go.sum
View File

@ -137,8 +137,6 @@ github.com/sirupsen/logrus v1.9.3 h1:dueUQJ1C2q9oE3F7wvmSGAaVtTmUizReu6fjN8uqzbQ
github.com/sirupsen/logrus v1.9.3/go.mod h1:naHLuLoDiP4jHNo9R0sCBMtWGeIprob74mVsIT4qYEQ= github.com/sirupsen/logrus v1.9.3/go.mod h1:naHLuLoDiP4jHNo9R0sCBMtWGeIprob74mVsIT4qYEQ=
github.com/skip2/go-qrcode v0.0.0-20200617195104-da1b6568686e h1:MRM5ITcdelLK2j1vwZ3Je0FKVCfqOLp5zO6trqMLYs0= github.com/skip2/go-qrcode v0.0.0-20200617195104-da1b6568686e h1:MRM5ITcdelLK2j1vwZ3Je0FKVCfqOLp5zO6trqMLYs0=
github.com/skip2/go-qrcode v0.0.0-20200617195104-da1b6568686e/go.mod h1:XV66xRDqSt+GTGFMVlhk3ULuV0y9ZmzeVGR4mloJI3M= github.com/skip2/go-qrcode v0.0.0-20200617195104-da1b6568686e/go.mod h1:XV66xRDqSt+GTGFMVlhk3ULuV0y9ZmzeVGR4mloJI3M=
github.com/songgao/water v0.0.0-20200317203138-2b4b6d7c09d8 h1:TG/diQgUe0pntT/2D9tmUCz4VNwm9MfrtPr0SU2qSX8=
github.com/songgao/water v0.0.0-20200317203138-2b4b6d7c09d8/go.mod h1:P5HUIBuIWKbyjl083/loAegFkfbFNx5i2qEP4CNbm7E=
github.com/stefanberger/go-pkcs11uri v0.0.0-20230803200340-78284954bff6 h1:pnnLyeX7o/5aX8qUQ69P/mLojDqwda8hFOCBTmP/6hw= github.com/stefanberger/go-pkcs11uri v0.0.0-20230803200340-78284954bff6 h1:pnnLyeX7o/5aX8qUQ69P/mLojDqwda8hFOCBTmP/6hw=
github.com/stefanberger/go-pkcs11uri v0.0.0-20230803200340-78284954bff6/go.mod h1:39R/xuhNgVhi+K0/zst4TLrJrVmbm6LVgl4A0+ZFS5M= github.com/stefanberger/go-pkcs11uri v0.0.0-20230803200340-78284954bff6/go.mod h1:39R/xuhNgVhi+K0/zst4TLrJrVmbm6LVgl4A0+ZFS5M=
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=

View File

@ -2,10 +2,12 @@ package nebula
import ( import (
"net/netip" "net/netip"
"slices"
"time" "time"
"github.com/flynn/noise" "github.com/flynn/noise"
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
"github.com/slackhq/nebula/cert"
"github.com/slackhq/nebula/header" "github.com/slackhq/nebula/header"
) )
@ -16,30 +18,60 @@ import (
func ixHandshakeStage0(f *Interface, hh *HandshakeHostInfo) bool { func ixHandshakeStage0(f *Interface, hh *HandshakeHostInfo) bool {
err := f.handshakeManager.allocateIndex(hh) err := f.handshakeManager.allocateIndex(hh)
if err != nil { if err != nil {
f.l.WithError(err).WithField("vpnIp", hh.hostinfo.vpnIp). f.l.WithError(err).WithField("vpnAddrs", hh.hostinfo.vpnAddrs).
WithField("handshake", m{"stage": 0, "style": "ix_psk0"}).Error("Failed to generate index") WithField("handshake", m{"stage": 0, "style": "ix_psk0"}).Error("Failed to generate index")
return false return false
} }
certState := f.pki.GetCertState() // If we're connecting to a v6 address we must use a v2 cert
ci := NewConnectionState(f.l, f.cipher, certState, true, noise.HandshakeIX, []byte{}, 0) cs := f.pki.getCertState()
v := cs.defaultVersion
for _, a := range hh.hostinfo.vpnAddrs {
if a.Is6() {
v = cert.Version2
break
}
}
crt := cs.getCertificate(v)
if crt == nil {
f.l.WithField("vpnAddrs", hh.hostinfo.vpnAddrs).
WithField("handshake", m{"stage": 0, "style": "ix_psk0"}).
WithField("certVersion", v).
Error("Unable to handshake with host because no certificate is available")
return false
}
crtHs := cs.getHandshakeBytes(v)
if crtHs == nil {
f.l.WithField("vpnAddrs", hh.hostinfo.vpnAddrs).
WithField("handshake", m{"stage": 0, "style": "ix_psk0"}).
WithField("certVersion", v).
Error("Unable to handshake with host because no certificate handshake bytes is available")
}
ci, err := NewConnectionState(f.l, cs, crt, true, noise.HandshakeIX)
if err != nil {
f.l.WithError(err).WithField("vpnAddrs", hh.hostinfo.vpnAddrs).
WithField("handshake", m{"stage": 0, "style": "ix_psk0"}).
WithField("certVersion", v).
Error("Failed to create connection state")
return false
}
hh.hostinfo.ConnectionState = ci hh.hostinfo.ConnectionState = ci
hsProto := &NebulaHandshakeDetails{ hs := &NebulaHandshake{
Details: &NebulaHandshakeDetails{
InitiatorIndex: hh.hostinfo.localIndexId, InitiatorIndex: hh.hostinfo.localIndexId,
Time: uint64(time.Now().UnixNano()), Time: uint64(time.Now().UnixNano()),
Cert: certState.RawCertificateNoKey, Cert: crtHs,
CertVersion: uint32(v),
},
} }
hsBytes := []byte{} hsBytes, err := hs.Marshal()
hs := &NebulaHandshake{
Details: hsProto,
}
hsBytes, err = hs.Marshal()
if err != nil { if err != nil {
f.l.WithError(err).WithField("vpnIp", hh.hostinfo.vpnIp). f.l.WithError(err).WithField("vpnAddrs", hh.hostinfo.vpnAddrs).WithField("certVersion", v).
WithField("handshake", m{"stage": 0, "style": "ix_psk0"}).Error("Failed to marshal handshake message") WithField("handshake", m{"stage": 0, "style": "ix_psk0"}).Error("Failed to marshal handshake message")
return false return false
} }
@ -48,7 +80,7 @@ func ixHandshakeStage0(f *Interface, hh *HandshakeHostInfo) bool {
msg, _, _, err := ci.H.WriteMessage(h, hsBytes) msg, _, _, err := ci.H.WriteMessage(h, hsBytes)
if err != nil { if err != nil {
f.l.WithError(err).WithField("vpnIp", hh.hostinfo.vpnIp). f.l.WithError(err).WithField("vpnAddrs", hh.hostinfo.vpnAddrs).
WithField("handshake", m{"stage": 0, "style": "ix_psk0"}).Error("Failed to call noise.WriteMessage") WithField("handshake", m{"stage": 0, "style": "ix_psk0"}).Error("Failed to call noise.WriteMessage")
return false return false
} }
@ -63,30 +95,44 @@ func ixHandshakeStage0(f *Interface, hh *HandshakeHostInfo) bool {
} }
func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet []byte, h *header.H) { func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet []byte, h *header.H) {
certState := f.pki.GetCertState() cs := f.pki.getCertState()
ci := NewConnectionState(f.l, f.cipher, certState, false, noise.HandshakeIX, []byte{}, 0) crt := cs.GetDefaultCertificate()
if crt == nil {
f.l.WithField("udpAddr", addr).
WithField("handshake", m{"stage": 0, "style": "ix_psk0"}).
WithField("certVersion", cs.defaultVersion).
Error("Unable to handshake with host because no certificate is available")
}
ci, err := NewConnectionState(f.l, cs, crt, false, noise.HandshakeIX)
if err != nil {
f.l.WithError(err).WithField("udpAddr", addr).
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).
Error("Failed to create connection state")
return
}
// Mark packet 1 as seen so it doesn't show up as missed // Mark packet 1 as seen so it doesn't show up as missed
ci.window.Update(f.l, 1) ci.window.Update(f.l, 1)
msg, _, _, err := ci.H.ReadMessage(nil, packet[header.Len:]) msg, _, _, err := ci.H.ReadMessage(nil, packet[header.Len:])
if err != nil { if err != nil {
f.l.WithError(err).WithField("udpAddr", addr). f.l.WithError(err).WithField("udpAddr", addr).
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).Error("Failed to call noise.ReadMessage") WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).
Error("Failed to call noise.ReadMessage")
return return
} }
hs := &NebulaHandshake{} hs := &NebulaHandshake{}
err = hs.Unmarshal(msg) err = hs.Unmarshal(msg)
/*
l.Debugln("GOT INDEX: ", hs.Details.InitiatorIndex)
*/
if err != nil || hs.Details == nil { if err != nil || hs.Details == nil {
f.l.WithError(err).WithField("udpAddr", addr). f.l.WithError(err).WithField("udpAddr", addr).
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).Error("Failed unmarshal handshake message") WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).
Error("Failed unmarshal handshake message")
return return
} }
remoteCert, err := RecombineCertAndValidate(ci.H, hs.Details.Cert, f.pki.GetCAPool()) remoteCert, err := cert.RecombineAndValidate(cert.Version(hs.Details.CertVersion), hs.Details.Cert, ci.H.PeerStatic(), ci.Curve(), f.pki.GetCAPool())
if err != nil { if err != nil {
e := f.l.WithError(err).WithField("udpAddr", addr). e := f.l.WithError(err).WithField("udpAddr", addr).
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}) WithField("handshake", m{"stage": 1, "style": "ix_psk0"})
@ -99,6 +145,20 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet
return return
} }
if remoteCert.Certificate.Version() != ci.myCert.Version() {
// We started off using the wrong certificate version, lets see if we can match the version that was sent to us
rc := cs.getCertificate(remoteCert.Certificate.Version())
if rc == nil {
f.l.WithError(err).WithField("udpAddr", addr).
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).WithField("cert", remoteCert).
Info("Unable to handshake with host due to missing certificate version")
return
}
// Record the certificate we are actually using
ci.myCert = rc
}
if len(remoteCert.Certificate.Networks()) == 0 { if len(remoteCert.Certificate.Networks()) == 0 {
e := f.l.WithError(err).WithField("udpAddr", addr). e := f.l.WithError(err).WithField("udpAddr", addr).
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}) WithField("handshake", m{"stage": 1, "style": "ix_psk0"})
@ -111,13 +171,16 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet
return return
} }
vpnIp := remoteCert.Certificate.Networks()[0].Addr().Unmap() var vpnAddrs []netip.Addr
certName := remoteCert.Certificate.Name() certName := remoteCert.Certificate.Name()
fingerprint := remoteCert.Fingerprint fingerprint := remoteCert.Fingerprint
issuer := remoteCert.Certificate.Issuer() issuer := remoteCert.Certificate.Issuer()
if vpnIp == f.myVpnNet.Addr() { for _, network := range remoteCert.Certificate.Networks() {
f.l.WithField("vpnIp", vpnIp).WithField("udpAddr", addr). vpnAddr := network.Addr()
_, found := f.myVpnAddrsTable.Lookup(vpnAddr)
if found {
f.l.WithField("vpnAddr", vpnAddr).WithField("udpAddr", addr).
WithField("certName", certName). WithField("certName", certName).
WithField("fingerprint", fingerprint). WithField("fingerprint", fingerprint).
WithField("issuer", issuer). WithField("issuer", issuer).
@ -126,15 +189,18 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet
} }
if addr.IsValid() { if addr.IsValid() {
if !f.lightHouse.GetRemoteAllowList().Allow(vpnIp, addr.Addr()) { if !f.lightHouse.GetRemoteAllowList().Allow(vpnAddr, addr.Addr()) {
f.l.WithField("vpnIp", vpnIp).WithField("udpAddr", addr).Debug("lighthouse.remote_allow_list denied incoming handshake") f.l.WithField("vpnAddr", vpnAddr).WithField("udpAddr", addr).Debug("lighthouse.remote_allow_list denied incoming handshake")
return return
} }
} }
vpnAddrs = append(vpnAddrs, vpnAddr)
}
myIndex, err := generateIndex(f.l) myIndex, err := generateIndex(f.l)
if err != nil { if err != nil {
f.l.WithError(err).WithField("vpnIp", vpnIp).WithField("udpAddr", addr). f.l.WithError(err).WithField("vpnAddrs", vpnAddrs).WithField("udpAddr", addr).
WithField("certName", certName). WithField("certName", certName).
WithField("fingerprint", fingerprint). WithField("fingerprint", fingerprint).
WithField("issuer", issuer). WithField("issuer", issuer).
@ -146,17 +212,17 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet
ConnectionState: ci, ConnectionState: ci,
localIndexId: myIndex, localIndexId: myIndex,
remoteIndexId: hs.Details.InitiatorIndex, remoteIndexId: hs.Details.InitiatorIndex,
vpnIp: vpnIp, vpnAddrs: vpnAddrs,
HandshakePacket: make(map[uint8][]byte, 0), HandshakePacket: make(map[uint8][]byte, 0),
lastHandshakeTime: hs.Details.Time, lastHandshakeTime: hs.Details.Time,
relayState: RelayState{ relayState: RelayState{
relays: map[netip.Addr]struct{}{}, relays: map[netip.Addr]struct{}{},
relayForByIp: map[netip.Addr]*Relay{}, relayForByAddr: map[netip.Addr]*Relay{},
relayForByIdx: map[uint32]*Relay{}, relayForByIdx: map[uint32]*Relay{},
}, },
} }
f.l.WithField("vpnIp", vpnIp).WithField("udpAddr", addr). f.l.WithField("vpnAddrs", vpnAddrs).WithField("udpAddr", addr).
WithField("certName", certName). WithField("certName", certName).
WithField("fingerprint", fingerprint). WithField("fingerprint", fingerprint).
WithField("issuer", issuer). WithField("issuer", issuer).
@ -165,13 +231,26 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet
Info("Handshake message received") Info("Handshake message received")
hs.Details.ResponderIndex = myIndex hs.Details.ResponderIndex = myIndex
hs.Details.Cert = certState.RawCertificateNoKey hs.Details.Cert = cs.getHandshakeBytes(ci.myCert.Version())
if hs.Details.Cert == nil {
f.l.WithField("vpnAddrs", vpnAddrs).WithField("udpAddr", addr).
WithField("certName", certName).
WithField("fingerprint", fingerprint).
WithField("issuer", issuer).
WithField("initiatorIndex", hs.Details.InitiatorIndex).WithField("responderIndex", hs.Details.ResponderIndex).
WithField("remoteIndex", h.RemoteIndex).WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).
WithField("certVersion", ci.myCert.Version()).
Error("Unable to handshake with host because no certificate handshake bytes is available")
return
}
hs.Details.CertVersion = uint32(ci.myCert.Version())
// Update the time in case their clock is way off from ours // Update the time in case their clock is way off from ours
hs.Details.Time = uint64(time.Now().UnixNano()) hs.Details.Time = uint64(time.Now().UnixNano())
hsBytes, err := hs.Marshal() hsBytes, err := hs.Marshal()
if err != nil { if err != nil {
f.l.WithError(err).WithField("vpnIp", hostinfo.vpnIp).WithField("udpAddr", addr). f.l.WithError(err).WithField("vpnAddrs", hostinfo.vpnAddrs).WithField("udpAddr", addr).
WithField("certName", certName). WithField("certName", certName).
WithField("fingerprint", fingerprint). WithField("fingerprint", fingerprint).
WithField("issuer", issuer). WithField("issuer", issuer).
@ -182,14 +261,14 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet
nh := header.Encode(make([]byte, header.Len), header.Version, header.Handshake, header.HandshakeIXPSK0, hs.Details.InitiatorIndex, 2) nh := header.Encode(make([]byte, header.Len), header.Version, header.Handshake, header.HandshakeIXPSK0, hs.Details.InitiatorIndex, 2)
msg, dKey, eKey, err := ci.H.WriteMessage(nh, hsBytes) msg, dKey, eKey, err := ci.H.WriteMessage(nh, hsBytes)
if err != nil { if err != nil {
f.l.WithError(err).WithField("vpnIp", hostinfo.vpnIp).WithField("udpAddr", addr). f.l.WithError(err).WithField("vpnAddrs", hostinfo.vpnAddrs).WithField("udpAddr", addr).
WithField("certName", certName). WithField("certName", certName).
WithField("fingerprint", fingerprint). WithField("fingerprint", fingerprint).
WithField("issuer", issuer). WithField("issuer", issuer).
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).Error("Failed to call noise.WriteMessage") WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).Error("Failed to call noise.WriteMessage")
return return
} else if dKey == nil || eKey == nil { } else if dKey == nil || eKey == nil {
f.l.WithField("vpnIp", hostinfo.vpnIp).WithField("udpAddr", addr). f.l.WithField("vpnAddrs", hostinfo.vpnAddrs).WithField("udpAddr", addr).
WithField("certName", certName). WithField("certName", certName).
WithField("fingerprint", fingerprint). WithField("fingerprint", fingerprint).
WithField("issuer", issuer). WithField("issuer", issuer).
@ -213,9 +292,9 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet
ci.dKey = NewNebulaCipherState(dKey) ci.dKey = NewNebulaCipherState(dKey)
ci.eKey = NewNebulaCipherState(eKey) ci.eKey = NewNebulaCipherState(eKey)
hostinfo.remotes = f.lightHouse.QueryCache(vpnIp) hostinfo.remotes = f.lightHouse.QueryCache(vpnAddrs)
hostinfo.SetRemote(addr) hostinfo.SetRemote(addr)
hostinfo.CreateRemoteCIDR(remoteCert.Certificate) hostinfo.buildNetworks(remoteCert.Certificate)
existing, err := f.handshakeManager.CheckAndComplete(hostinfo, 0, f) existing, err := f.handshakeManager.CheckAndComplete(hostinfo, 0, f)
if err != nil { if err != nil {
@ -225,7 +304,7 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet
if existing.SetRemoteIfPreferred(f.hostMap, addr) { if existing.SetRemoteIfPreferred(f.hostMap, addr) {
// Send a test packet to ensure the other side has also switched to // Send a test packet to ensure the other side has also switched to
// the preferred remote // the preferred remote
f.SendMessageToVpnIp(header.Test, header.TestRequest, vpnIp, []byte(""), make([]byte, 12, 12), make([]byte, mtu)) f.SendMessageToVpnAddr(header.Test, header.TestRequest, vpnAddrs[0], []byte(""), make([]byte, 12, 12), make([]byte, mtu))
} }
msg = existing.HandshakePacket[2] msg = existing.HandshakePacket[2]
@ -233,11 +312,11 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet
if addr.IsValid() { if addr.IsValid() {
err := f.outside.WriteTo(msg, addr) err := f.outside.WriteTo(msg, addr)
if err != nil { if err != nil {
f.l.WithField("vpnIp", existing.vpnIp).WithField("udpAddr", addr). f.l.WithField("vpnAddrs", existing.vpnAddrs).WithField("udpAddr", addr).
WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).WithField("cached", true). WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).WithField("cached", true).
WithError(err).Error("Failed to send handshake message") WithError(err).Error("Failed to send handshake message")
} else { } else {
f.l.WithField("vpnIp", existing.vpnIp).WithField("udpAddr", addr). f.l.WithField("vpnAddrs", existing.vpnAddrs).WithField("udpAddr", addr).
WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).WithField("cached", true). WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).WithField("cached", true).
Info("Handshake message sent") Info("Handshake message sent")
} }
@ -247,16 +326,16 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet
f.l.Error("Handshake send failed: both addr and via are nil.") f.l.Error("Handshake send failed: both addr and via are nil.")
return return
} }
hostinfo.relayState.InsertRelayTo(via.relayHI.vpnIp) hostinfo.relayState.InsertRelayTo(via.relayHI.vpnAddrs[0])
f.SendVia(via.relayHI, via.relay, msg, make([]byte, 12), make([]byte, mtu), false) f.SendVia(via.relayHI, via.relay, msg, make([]byte, 12), make([]byte, mtu), false)
f.l.WithField("vpnIp", existing.vpnIp).WithField("relay", via.relayHI.vpnIp). f.l.WithField("vpnAddrs", existing.vpnAddrs).WithField("relay", via.relayHI.vpnAddrs[0]).
WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).WithField("cached", true). WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).WithField("cached", true).
Info("Handshake message sent") Info("Handshake message sent")
return return
} }
case ErrExistingHostInfo: case ErrExistingHostInfo:
// This means there was an existing tunnel and this handshake was older than the one we are currently based on // This means there was an existing tunnel and this handshake was older than the one we are currently based on
f.l.WithField("vpnIp", vpnIp).WithField("udpAddr", addr). f.l.WithField("vpnAddrs", vpnAddrs).WithField("udpAddr", addr).
WithField("certName", certName). WithField("certName", certName).
WithField("oldHandshakeTime", existing.lastHandshakeTime). WithField("oldHandshakeTime", existing.lastHandshakeTime).
WithField("newHandshakeTime", hostinfo.lastHandshakeTime). WithField("newHandshakeTime", hostinfo.lastHandshakeTime).
@ -267,23 +346,23 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet
Info("Handshake too old") Info("Handshake too old")
// Send a test packet to trigger an authenticated tunnel test, this should suss out any lingering tunnel issues // Send a test packet to trigger an authenticated tunnel test, this should suss out any lingering tunnel issues
f.SendMessageToVpnIp(header.Test, header.TestRequest, vpnIp, []byte(""), make([]byte, 12, 12), make([]byte, mtu)) f.SendMessageToVpnAddr(header.Test, header.TestRequest, vpnAddrs[0], []byte(""), make([]byte, 12, 12), make([]byte, mtu))
return return
case ErrLocalIndexCollision: case ErrLocalIndexCollision:
// This means we failed to insert because of collision on localIndexId. Just let the next handshake packet retry // This means we failed to insert because of collision on localIndexId. Just let the next handshake packet retry
f.l.WithField("vpnIp", vpnIp).WithField("udpAddr", addr). f.l.WithField("vpnAddrs", vpnAddrs).WithField("udpAddr", addr).
WithField("certName", certName). WithField("certName", certName).
WithField("fingerprint", fingerprint). WithField("fingerprint", fingerprint).
WithField("issuer", issuer). WithField("issuer", issuer).
WithField("initiatorIndex", hs.Details.InitiatorIndex).WithField("responderIndex", hs.Details.ResponderIndex). WithField("initiatorIndex", hs.Details.InitiatorIndex).WithField("responderIndex", hs.Details.ResponderIndex).
WithField("remoteIndex", h.RemoteIndex).WithField("handshake", m{"stage": 1, "style": "ix_psk0"}). WithField("remoteIndex", h.RemoteIndex).WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).
WithField("localIndex", hostinfo.localIndexId).WithField("collision", existing.vpnIp). WithField("localIndex", hostinfo.localIndexId).WithField("collision", existing.vpnAddrs).
Error("Failed to add HostInfo due to localIndex collision") Error("Failed to add HostInfo due to localIndex collision")
return return
default: default:
// Shouldn't happen, but just in case someone adds a new error type to CheckAndComplete // Shouldn't happen, but just in case someone adds a new error type to CheckAndComplete
// And we forget to update it here // And we forget to update it here
f.l.WithError(err).WithField("vpnIp", vpnIp).WithField("udpAddr", addr). f.l.WithError(err).WithField("vpnAddrs", vpnAddrs).WithField("udpAddr", addr).
WithField("certName", certName). WithField("certName", certName).
WithField("fingerprint", fingerprint). WithField("fingerprint", fingerprint).
WithField("issuer", issuer). WithField("issuer", issuer).
@ -299,7 +378,7 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet
if addr.IsValid() { if addr.IsValid() {
err = f.outside.WriteTo(msg, addr) err = f.outside.WriteTo(msg, addr)
if err != nil { if err != nil {
f.l.WithField("vpnIp", vpnIp).WithField("udpAddr", addr). f.l.WithField("vpnAddrs", vpnAddrs).WithField("udpAddr", addr).
WithField("certName", certName). WithField("certName", certName).
WithField("fingerprint", fingerprint). WithField("fingerprint", fingerprint).
WithField("issuer", issuer). WithField("issuer", issuer).
@ -307,7 +386,7 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet
WithField("remoteIndex", h.RemoteIndex).WithField("handshake", m{"stage": 2, "style": "ix_psk0"}). WithField("remoteIndex", h.RemoteIndex).WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).
WithError(err).Error("Failed to send handshake") WithError(err).Error("Failed to send handshake")
} else { } else {
f.l.WithField("vpnIp", vpnIp).WithField("udpAddr", addr). f.l.WithField("vpnAddrs", vpnAddrs).WithField("udpAddr", addr).
WithField("certName", certName). WithField("certName", certName).
WithField("fingerprint", fingerprint). WithField("fingerprint", fingerprint).
WithField("issuer", issuer). WithField("issuer", issuer).
@ -320,9 +399,9 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet
f.l.Error("Handshake send failed: both addr and via are nil.") f.l.Error("Handshake send failed: both addr and via are nil.")
return return
} }
hostinfo.relayState.InsertRelayTo(via.relayHI.vpnIp) hostinfo.relayState.InsertRelayTo(via.relayHI.vpnAddrs[0])
f.SendVia(via.relayHI, via.relay, msg, make([]byte, 12), make([]byte, mtu), false) f.SendVia(via.relayHI, via.relay, msg, make([]byte, 12), make([]byte, mtu), false)
f.l.WithField("vpnIp", vpnIp).WithField("relay", via.relayHI.vpnIp). f.l.WithField("vpnAddrs", vpnAddrs).WithField("relay", via.relayHI.vpnAddrs[0]).
WithField("certName", certName). WithField("certName", certName).
WithField("fingerprint", fingerprint). WithField("fingerprint", fingerprint).
WithField("issuer", issuer). WithField("issuer", issuer).
@ -349,8 +428,9 @@ func ixHandshakeStage2(f *Interface, addr netip.AddrPort, via *ViaSender, hh *Ha
hostinfo := hh.hostinfo hostinfo := hh.hostinfo
if addr.IsValid() { if addr.IsValid() {
if !f.lightHouse.GetRemoteAllowList().Allow(hostinfo.vpnIp, addr.Addr()) { //TODO: this is kind of nonsense now
f.l.WithField("vpnIp", hostinfo.vpnIp).WithField("udpAddr", addr).Debug("lighthouse.remote_allow_list denied incoming handshake") if !f.lightHouse.GetRemoteAllowList().Allow(hostinfo.vpnAddrs[0], addr.Addr()) {
f.l.WithField("vpnIp", hostinfo.vpnAddrs).WithField("udpAddr", addr).Debug("lighthouse.remote_allow_list denied incoming handshake")
return false return false
} }
} }
@ -358,7 +438,7 @@ func ixHandshakeStage2(f *Interface, addr netip.AddrPort, via *ViaSender, hh *Ha
ci := hostinfo.ConnectionState ci := hostinfo.ConnectionState
msg, eKey, dKey, err := ci.H.ReadMessage(nil, packet[header.Len:]) msg, eKey, dKey, err := ci.H.ReadMessage(nil, packet[header.Len:])
if err != nil { if err != nil {
f.l.WithError(err).WithField("vpnIp", hostinfo.vpnIp).WithField("udpAddr", addr). f.l.WithError(err).WithField("vpnAddrs", hostinfo.vpnAddrs).WithField("udpAddr", addr).
WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).WithField("header", h). WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).WithField("header", h).
Error("Failed to call noise.ReadMessage") Error("Failed to call noise.ReadMessage")
@ -367,7 +447,7 @@ func ixHandshakeStage2(f *Interface, addr netip.AddrPort, via *ViaSender, hh *Ha
// near future // near future
return false return false
} else if dKey == nil || eKey == nil { } else if dKey == nil || eKey == nil {
f.l.WithField("vpnIp", hostinfo.vpnIp).WithField("udpAddr", addr). f.l.WithField("vpnAddrs", hostinfo.vpnAddrs).WithField("udpAddr", addr).
WithField("handshake", m{"stage": 2, "style": "ix_psk0"}). WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).
Error("Noise did not arrive at a key") Error("Noise did not arrive at a key")
@ -379,16 +459,16 @@ func ixHandshakeStage2(f *Interface, addr netip.AddrPort, via *ViaSender, hh *Ha
hs := &NebulaHandshake{} hs := &NebulaHandshake{}
err = hs.Unmarshal(msg) err = hs.Unmarshal(msg)
if err != nil || hs.Details == nil { if err != nil || hs.Details == nil {
f.l.WithError(err).WithField("vpnIp", hostinfo.vpnIp).WithField("udpAddr", addr). f.l.WithError(err).WithField("vpnAddrs", hostinfo.vpnAddrs).WithField("udpAddr", addr).
WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).Error("Failed unmarshal handshake message") WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).Error("Failed unmarshal handshake message")
// The handshake state machine is complete, if things break now there is no chance to recover. Tear down and start again // The handshake state machine is complete, if things break now there is no chance to recover. Tear down and start again
return true return true
} }
remoteCert, err := RecombineCertAndValidate(ci.H, hs.Details.Cert, f.pki.GetCAPool()) remoteCert, err := cert.RecombineAndValidate(cert.Version(hs.Details.CertVersion), hs.Details.Cert, ci.H.PeerStatic(), ci.Curve(), f.pki.GetCAPool())
if err != nil { if err != nil {
e := f.l.WithError(err).WithField("vpnIp", hostinfo.vpnIp).WithField("udpAddr", addr). e := f.l.WithError(err).WithField("vpnAddrs", hostinfo.vpnAddrs).WithField("udpAddr", addr).
WithField("handshake", m{"stage": 2, "style": "ix_psk0"}) WithField("handshake", m{"stage": 2, "style": "ix_psk0"})
if f.l.Level > logrus.DebugLevel { if f.l.Level > logrus.DebugLevel {
@ -413,7 +493,7 @@ func ixHandshakeStage2(f *Interface, addr netip.AddrPort, via *ViaSender, hh *Ha
return true return true
} }
vpnIp := remoteCert.Certificate.Networks()[0].Addr().Unmap() vpnNetworks := remoteCert.Certificate.Networks()
certName := remoteCert.Certificate.Name() certName := remoteCert.Certificate.Name()
fingerprint := remoteCert.Fingerprint fingerprint := remoteCert.Fingerprint
issuer := remoteCert.Certificate.Issuer() issuer := remoteCert.Certificate.Issuer()
@ -430,12 +510,17 @@ func ixHandshakeStage2(f *Interface, addr netip.AddrPort, via *ViaSender, hh *Ha
if addr.IsValid() { if addr.IsValid() {
hostinfo.SetRemote(addr) hostinfo.SetRemote(addr)
} else { } else {
hostinfo.relayState.InsertRelayTo(via.relayHI.vpnIp) hostinfo.relayState.InsertRelayTo(via.relayHI.vpnAddrs[0])
}
vpnAddrs := make([]netip.Addr, len(vpnNetworks))
for i, n := range vpnNetworks {
vpnAddrs[i] = n.Addr()
} }
// Ensure the right host responded // Ensure the right host responded
if vpnIp != hostinfo.vpnIp { if !slices.Contains(vpnAddrs, hostinfo.vpnAddrs[0]) {
f.l.WithField("intendedVpnIp", hostinfo.vpnIp).WithField("haveVpnIp", vpnIp). f.l.WithField("intendedVpnAddrs", hostinfo.vpnAddrs).WithField("haveVpnNetworks", vpnNetworks).
WithField("udpAddr", addr).WithField("certName", certName). WithField("udpAddr", addr).WithField("certName", certName).
WithField("handshake", m{"stage": 2, "style": "ix_psk0"}). WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).
Info("Incorrect host responded to handshake") Info("Incorrect host responded to handshake")
@ -444,14 +529,14 @@ func ixHandshakeStage2(f *Interface, addr netip.AddrPort, via *ViaSender, hh *Ha
f.handshakeManager.DeleteHostInfo(hostinfo) f.handshakeManager.DeleteHostInfo(hostinfo)
// Create a new hostinfo/handshake for the intended vpn ip // Create a new hostinfo/handshake for the intended vpn ip
f.handshakeManager.StartHandshake(hostinfo.vpnIp, func(newHH *HandshakeHostInfo) { f.handshakeManager.StartHandshake(hostinfo.vpnAddrs[0], func(newHH *HandshakeHostInfo) {
//TODO: this doesnt know if its being added or is being used for caching a packet //TODO: this doesnt know if its being added or is being used for caching a packet
// Block the current used address // Block the current used address
newHH.hostinfo.remotes = hostinfo.remotes newHH.hostinfo.remotes = hostinfo.remotes
newHH.hostinfo.remotes.BlockRemote(addr) newHH.hostinfo.remotes.BlockRemote(addr)
f.l.WithField("blockedUdpAddrs", newHH.hostinfo.remotes.CopyBlockedRemotes()). f.l.WithField("blockedUdpAddrs", newHH.hostinfo.remotes.CopyBlockedRemotes()).
WithField("vpnIp", newHH.hostinfo.vpnIp). WithField("vpnNetworks", vpnNetworks).
WithField("remotes", newHH.hostinfo.remotes.CopyAddrs(f.hostMap.GetPreferredRanges())). WithField("remotes", newHH.hostinfo.remotes.CopyAddrs(f.hostMap.GetPreferredRanges())).
Info("Blocked addresses for handshakes") Info("Blocked addresses for handshakes")
@ -459,11 +544,8 @@ func ixHandshakeStage2(f *Interface, addr netip.AddrPort, via *ViaSender, hh *Ha
newHH.packetStore = hh.packetStore newHH.packetStore = hh.packetStore
hh.packetStore = []*cachedPacket{} hh.packetStore = []*cachedPacket{}
// Get the correct remote list for the host we did handshake with // Finally, put the correct vpn addrs in the host info, tell them to close the tunnel, and return true to tear down
hostinfo.SetRemote(addr) hostinfo.vpnAddrs = vpnAddrs
hostinfo.remotes = f.lightHouse.QueryCache(vpnIp)
// Finally, put the correct vpn ip in the host info, tell them to close the tunnel, and return true to tear down
hostinfo.vpnIp = vpnIp
f.sendCloseTunnel(hostinfo) f.sendCloseTunnel(hostinfo)
}) })
@ -474,7 +556,7 @@ func ixHandshakeStage2(f *Interface, addr netip.AddrPort, via *ViaSender, hh *Ha
ci.window.Update(f.l, 2) ci.window.Update(f.l, 2)
duration := time.Since(hh.startTime).Nanoseconds() duration := time.Since(hh.startTime).Nanoseconds()
f.l.WithField("vpnIp", vpnIp).WithField("udpAddr", addr). f.l.WithField("vpnNetworks", vpnNetworks).WithField("udpAddr", addr).
WithField("certName", certName). WithField("certName", certName).
WithField("fingerprint", fingerprint). WithField("fingerprint", fingerprint).
WithField("issuer", issuer). WithField("issuer", issuer).
@ -485,7 +567,7 @@ func ixHandshakeStage2(f *Interface, addr netip.AddrPort, via *ViaSender, hh *Ha
Info("Handshake message received") Info("Handshake message received")
// Build up the radix for the firewall if we have subnets in the cert // Build up the radix for the firewall if we have subnets in the cert
hostinfo.CreateRemoteCIDR(remoteCert.Certificate) hostinfo.buildNetworks(remoteCert.Certificate)
// Complete our handshake and update metrics, this will replace any existing tunnels for this vpnIp // Complete our handshake and update metrics, this will replace any existing tunnels for this vpnIp
f.handshakeManager.Complete(hostinfo, f) f.handshakeManager.Complete(hostinfo, f)

View File

@ -13,6 +13,7 @@ import (
"github.com/rcrowley/go-metrics" "github.com/rcrowley/go-metrics"
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
"github.com/slackhq/nebula/cert"
"github.com/slackhq/nebula/header" "github.com/slackhq/nebula/header"
"github.com/slackhq/nebula/udp" "github.com/slackhq/nebula/udp"
) )
@ -118,18 +119,18 @@ func NewHandshakeManager(l *logrus.Logger, mainHostMap *HostMap, lightHouse *Lig
} }
} }
func (c *HandshakeManager) Run(ctx context.Context) { func (hm *HandshakeManager) Run(ctx context.Context) {
clockSource := time.NewTicker(c.config.tryInterval) clockSource := time.NewTicker(hm.config.tryInterval)
defer clockSource.Stop() defer clockSource.Stop()
for { for {
select { select {
case <-ctx.Done(): case <-ctx.Done():
return return
case vpnIP := <-c.trigger: case vpnIP := <-hm.trigger:
c.handleOutbound(vpnIP, true) hm.handleOutbound(vpnIP, true)
case now := <-clockSource.C: case now := <-clockSource.C:
c.NextOutboundHandshakeTimerTick(now) hm.NextOutboundHandshakeTimerTick(now)
} }
} }
} }
@ -159,14 +160,14 @@ func (hm *HandshakeManager) HandleIncoming(addr netip.AddrPort, via *ViaSender,
} }
} }
func (c *HandshakeManager) NextOutboundHandshakeTimerTick(now time.Time) { func (hm *HandshakeManager) NextOutboundHandshakeTimerTick(now time.Time) {
c.OutboundHandshakeTimer.Advance(now) hm.OutboundHandshakeTimer.Advance(now)
for { for {
vpnIp, has := c.OutboundHandshakeTimer.Purge() vpnIp, has := hm.OutboundHandshakeTimer.Purge()
if !has { if !has {
break break
} }
c.handleOutbound(vpnIp, false) hm.handleOutbound(vpnIp, false)
} }
} }
@ -208,7 +209,7 @@ func (hm *HandshakeManager) handleOutbound(vpnIp netip.Addr, lighthouseTriggered
// NB ^ This comment doesn't jive. It's how the thing gets initialized. // NB ^ This comment doesn't jive. It's how the thing gets initialized.
// It's the common path. Should it update every time, in case a future LH query/queries give us more info? // It's the common path. Should it update every time, in case a future LH query/queries give us more info?
if hostinfo.remotes == nil { if hostinfo.remotes == nil {
hostinfo.remotes = hm.lightHouse.QueryCache(vpnIp) hostinfo.remotes = hm.lightHouse.QueryCache([]netip.Addr{vpnIp})
} }
remotes := hostinfo.remotes.CopyAddrs(hm.mainHostMap.GetPreferredRanges()) remotes := hostinfo.remotes.CopyAddrs(hm.mainHostMap.GetPreferredRanges())
@ -267,11 +268,18 @@ func (hm *HandshakeManager) handleOutbound(vpnIp netip.Addr, lighthouseTriggered
hostinfo.logger(hm.l).WithField("relays", hostinfo.remotes.relays).Info("Attempt to relay through hosts") hostinfo.logger(hm.l).WithField("relays", hostinfo.remotes.relays).Info("Attempt to relay through hosts")
// Send a RelayRequest to all known Relay IP's // Send a RelayRequest to all known Relay IP's
for _, relay := range hostinfo.remotes.relays { for _, relay := range hostinfo.remotes.relays {
// Don't relay to myself, and don't relay through the host I'm trying to connect to // Don't relay to myself
if relay == vpnIp || relay == hm.lightHouse.myVpnNet.Addr() { if relay == vpnIp {
continue continue
} }
relayHostInfo := hm.mainHostMap.QueryVpnIp(relay)
// Don't relay through the host I'm trying to connect to
_, found := hm.f.myVpnAddrsTable.Lookup(relay)
if found {
continue
}
relayHostInfo := hm.mainHostMap.QueryVpnAddr(relay)
if relayHostInfo == nil || !relayHostInfo.remote.IsValid() { if relayHostInfo == nil || !relayHostInfo.remote.IsValid() {
hostinfo.logger(hm.l).WithField("relay", relay.String()).Info("Establish tunnel to relay target") hostinfo.logger(hm.l).WithField("relay", relay.String()).Info("Establish tunnel to relay target")
hm.f.Handshake(relay) hm.f.Handshake(relay)
@ -286,17 +294,35 @@ func (hm *HandshakeManager) handleOutbound(vpnIp netip.Addr, lighthouseTriggered
case Requested: case Requested:
hostinfo.logger(hm.l).WithField("relay", relay.String()).Info("Re-send CreateRelay request") hostinfo.logger(hm.l).WithField("relay", relay.String()).Info("Re-send CreateRelay request")
//TODO: IPV6-WORK
myVpnIpB := hm.f.myVpnNet.Addr().As4()
theirVpnIpB := vpnIp.As4()
// Re-send the CreateRelay request, in case the previous one was lost.
m := NebulaControl{ m := NebulaControl{
Type: NebulaControl_CreateRelayRequest, Type: NebulaControl_CreateRelayRequest,
InitiatorRelayIndex: existingRelay.LocalIndex, InitiatorRelayIndex: existingRelay.LocalIndex,
RelayFromIp: binary.BigEndian.Uint32(myVpnIpB[:]),
RelayToIp: binary.BigEndian.Uint32(theirVpnIpB[:]),
} }
switch relayHostInfo.GetCert().Certificate.Version() {
case cert.Version1:
if !hm.f.myVpnAddrs[0].Is4() {
hostinfo.logger(hm.l).Error("can not establish v1 relay with a v6 network because the relay is not running a current nebula version")
continue
}
if !vpnIp.Is4() {
hostinfo.logger(hm.l).Error("can not establish v1 relay with a v6 remote network because the relay is not running a current nebula version")
continue
}
b := hm.f.myVpnAddrs[0].As4()
m.OldRelayFromAddr = binary.BigEndian.Uint32(b[:])
b = vpnIp.As4()
m.OldRelayToAddr = binary.BigEndian.Uint32(b[:])
case cert.Version2:
m.RelayFromAddr = netAddrToProtoAddr(hm.f.myVpnAddrs[0])
m.RelayToAddr = netAddrToProtoAddr(vpnIp)
default:
hostinfo.logger(hm.l).Error("Unknown certificate version found while creating relay")
continue
}
msg, err := m.Marshal() msg, err := m.Marshal()
if err != nil { if err != nil {
hostinfo.logger(hm.l). hostinfo.logger(hm.l).
@ -306,7 +332,7 @@ func (hm *HandshakeManager) handleOutbound(vpnIp netip.Addr, lighthouseTriggered
// This must send over the hostinfo, not over hm.Hosts[ip] // This must send over the hostinfo, not over hm.Hosts[ip]
hm.f.SendMessageToHostInfo(header.Control, 0, relayHostInfo, msg, make([]byte, 12), make([]byte, mtu)) hm.f.SendMessageToHostInfo(header.Control, 0, relayHostInfo, msg, make([]byte, 12), make([]byte, mtu))
hm.l.WithFields(logrus.Fields{ hm.l.WithFields(logrus.Fields{
"relayFrom": hm.f.myVpnNet.Addr(), "relayFrom": hm.f.myVpnAddrs[0],
"relayTo": vpnIp, "relayTo": vpnIp,
"initiatorRelayIndex": existingRelay.LocalIndex, "initiatorRelayIndex": existingRelay.LocalIndex,
"relay": relay}). "relay": relay}).
@ -316,7 +342,7 @@ func (hm *HandshakeManager) handleOutbound(vpnIp netip.Addr, lighthouseTriggered
hostinfo.logger(hm.l). hostinfo.logger(hm.l).
WithField("vpnIp", vpnIp). WithField("vpnIp", vpnIp).
WithField("state", existingRelay.State). WithField("state", existingRelay.State).
WithField("relay", relayHostInfo.vpnIp). WithField("relay", relayHostInfo.vpnAddrs[0]).
Errorf("Relay unexpected state") Errorf("Relay unexpected state")
} }
} else { } else {
@ -327,16 +353,35 @@ func (hm *HandshakeManager) handleOutbound(vpnIp netip.Addr, lighthouseTriggered
hostinfo.logger(hm.l).WithField("relay", relay.String()).WithError(err).Info("Failed to add relay to hostmap") hostinfo.logger(hm.l).WithField("relay", relay.String()).WithError(err).Info("Failed to add relay to hostmap")
} }
//TODO: IPV6-WORK
myVpnIpB := hm.f.myVpnNet.Addr().As4()
theirVpnIpB := vpnIp.As4()
m := NebulaControl{ m := NebulaControl{
Type: NebulaControl_CreateRelayRequest, Type: NebulaControl_CreateRelayRequest,
InitiatorRelayIndex: idx, InitiatorRelayIndex: idx,
RelayFromIp: binary.BigEndian.Uint32(myVpnIpB[:]),
RelayToIp: binary.BigEndian.Uint32(theirVpnIpB[:]),
} }
switch relayHostInfo.GetCert().Certificate.Version() {
case cert.Version1:
if !hm.f.myVpnAddrs[0].Is4() {
hostinfo.logger(hm.l).Error("can not establish v1 relay with a v6 network because the relay is not running a current nebula version")
continue
}
if !vpnIp.Is4() {
hostinfo.logger(hm.l).Error("can not establish v1 relay with a v6 remote network because the relay is not running a current nebula version")
continue
}
b := hm.f.myVpnAddrs[0].As4()
m.OldRelayFromAddr = binary.BigEndian.Uint32(b[:])
b = vpnIp.As4()
m.OldRelayToAddr = binary.BigEndian.Uint32(b[:])
case cert.Version2:
m.RelayFromAddr = netAddrToProtoAddr(hm.f.myVpnAddrs[0])
m.RelayToAddr = netAddrToProtoAddr(vpnIp)
default:
hostinfo.logger(hm.l).Error("Unknown certificate version found while creating relay")
continue
}
msg, err := m.Marshal() msg, err := m.Marshal()
if err != nil { if err != nil {
hostinfo.logger(hm.l). hostinfo.logger(hm.l).
@ -345,7 +390,7 @@ func (hm *HandshakeManager) handleOutbound(vpnIp netip.Addr, lighthouseTriggered
} else { } else {
hm.f.SendMessageToHostInfo(header.Control, 0, relayHostInfo, msg, make([]byte, 12), make([]byte, mtu)) hm.f.SendMessageToHostInfo(header.Control, 0, relayHostInfo, msg, make([]byte, 12), make([]byte, mtu))
hm.l.WithFields(logrus.Fields{ hm.l.WithFields(logrus.Fields{
"relayFrom": hm.f.myVpnNet.Addr(), "relayFrom": hm.f.myVpnAddrs[0],
"relayTo": vpnIp, "relayTo": vpnIp,
"initiatorRelayIndex": idx, "initiatorRelayIndex": idx,
"relay": relay}). "relay": relay}).
@ -381,10 +426,10 @@ func (hm *HandshakeManager) GetOrHandshake(vpnIp netip.Addr, cacheCb func(*Hands
} }
// StartHandshake will ensure a handshake is currently being attempted for the provided vpn ip // StartHandshake will ensure a handshake is currently being attempted for the provided vpn ip
func (hm *HandshakeManager) StartHandshake(vpnIp netip.Addr, cacheCb func(*HandshakeHostInfo)) *HostInfo { func (hm *HandshakeManager) StartHandshake(vpnAddr netip.Addr, cacheCb func(*HandshakeHostInfo)) *HostInfo {
hm.Lock() hm.Lock()
if hh, ok := hm.vpnIps[vpnIp]; ok { if hh, ok := hm.vpnIps[vpnAddr]; ok {
// We are already trying to handshake with this vpn ip // We are already trying to handshake with this vpn ip
if cacheCb != nil { if cacheCb != nil {
cacheCb(hh) cacheCb(hh)
@ -394,11 +439,11 @@ func (hm *HandshakeManager) StartHandshake(vpnIp netip.Addr, cacheCb func(*Hands
} }
hostinfo := &HostInfo{ hostinfo := &HostInfo{
vpnIp: vpnIp, vpnAddrs: []netip.Addr{vpnAddr},
HandshakePacket: make(map[uint8][]byte, 0), HandshakePacket: make(map[uint8][]byte, 0),
relayState: RelayState{ relayState: RelayState{
relays: map[netip.Addr]struct{}{}, relays: map[netip.Addr]struct{}{},
relayForByIp: map[netip.Addr]*Relay{}, relayForByAddr: map[netip.Addr]*Relay{},
relayForByIdx: map[uint32]*Relay{}, relayForByIdx: map[uint32]*Relay{},
}, },
} }
@ -407,9 +452,9 @@ func (hm *HandshakeManager) StartHandshake(vpnIp netip.Addr, cacheCb func(*Hands
hostinfo: hostinfo, hostinfo: hostinfo,
startTime: time.Now(), startTime: time.Now(),
} }
hm.vpnIps[vpnIp] = hh hm.vpnIps[vpnAddr] = hh
hm.metricInitiated.Inc(1) hm.metricInitiated.Inc(1)
hm.OutboundHandshakeTimer.Add(vpnIp, hm.config.tryInterval) hm.OutboundHandshakeTimer.Add(vpnAddr, hm.config.tryInterval)
if cacheCb != nil { if cacheCb != nil {
cacheCb(hh) cacheCb(hh)
@ -417,21 +462,21 @@ func (hm *HandshakeManager) StartHandshake(vpnIp netip.Addr, cacheCb func(*Hands
// If this is a static host, we don't need to wait for the HostQueryReply // If this is a static host, we don't need to wait for the HostQueryReply
// We can trigger the handshake right now // We can trigger the handshake right now
_, doTrigger := hm.lightHouse.GetStaticHostList()[vpnIp] _, doTrigger := hm.lightHouse.GetStaticHostList()[vpnAddr]
if !doTrigger { if !doTrigger {
// Add any calculated remotes, and trigger early handshake if one found // Add any calculated remotes, and trigger early handshake if one found
doTrigger = hm.lightHouse.addCalculatedRemotes(vpnIp) doTrigger = hm.lightHouse.addCalculatedRemotes(vpnAddr)
} }
if doTrigger { if doTrigger {
select { select {
case hm.trigger <- vpnIp: case hm.trigger <- vpnAddr:
default: default:
} }
} }
hm.Unlock() hm.Unlock()
hm.lightHouse.QueryServer(vpnIp) hm.lightHouse.QueryServer(vpnAddr)
return hostinfo return hostinfo
} }
@ -452,14 +497,14 @@ var (
// //
// ErrLocalIndexCollision if we already have an entry in the main or pending // ErrLocalIndexCollision if we already have an entry in the main or pending
// hostmap for the hostinfo.localIndexId. // hostmap for the hostinfo.localIndexId.
func (c *HandshakeManager) CheckAndComplete(hostinfo *HostInfo, handshakePacket uint8, f *Interface) (*HostInfo, error) { func (hm *HandshakeManager) CheckAndComplete(hostinfo *HostInfo, handshakePacket uint8, f *Interface) (*HostInfo, error) {
c.mainHostMap.Lock() hm.mainHostMap.Lock()
defer c.mainHostMap.Unlock() defer hm.mainHostMap.Unlock()
c.Lock() hm.Lock()
defer c.Unlock() defer hm.Unlock()
// Check if we already have a tunnel with this vpn ip // Check if we already have a tunnel with this vpn ip
existingHostInfo, found := c.mainHostMap.Hosts[hostinfo.vpnIp] existingHostInfo, found := hm.mainHostMap.Hosts[hostinfo.vpnAddrs[0]]
if found && existingHostInfo != nil { if found && existingHostInfo != nil {
testHostInfo := existingHostInfo testHostInfo := existingHostInfo
for testHostInfo != nil { for testHostInfo != nil {
@ -476,31 +521,31 @@ func (c *HandshakeManager) CheckAndComplete(hostinfo *HostInfo, handshakePacket
return existingHostInfo, ErrExistingHostInfo return existingHostInfo, ErrExistingHostInfo
} }
existingHostInfo.logger(c.l).Info("Taking new handshake") existingHostInfo.logger(hm.l).Info("Taking new handshake")
} }
existingIndex, found := c.mainHostMap.Indexes[hostinfo.localIndexId] existingIndex, found := hm.mainHostMap.Indexes[hostinfo.localIndexId]
if found { if found {
// We have a collision, but for a different hostinfo // We have a collision, but for a different hostinfo
return existingIndex, ErrLocalIndexCollision return existingIndex, ErrLocalIndexCollision
} }
existingPendingIndex, found := c.indexes[hostinfo.localIndexId] existingPendingIndex, found := hm.indexes[hostinfo.localIndexId]
if found && existingPendingIndex.hostinfo != hostinfo { if found && existingPendingIndex.hostinfo != hostinfo {
// We have a collision, but for a different hostinfo // We have a collision, but for a different hostinfo
return existingPendingIndex.hostinfo, ErrLocalIndexCollision return existingPendingIndex.hostinfo, ErrLocalIndexCollision
} }
existingRemoteIndex, found := c.mainHostMap.RemoteIndexes[hostinfo.remoteIndexId] existingRemoteIndex, found := hm.mainHostMap.RemoteIndexes[hostinfo.remoteIndexId]
if found && existingRemoteIndex != nil && existingRemoteIndex.vpnIp != hostinfo.vpnIp { if found && existingRemoteIndex != nil && existingRemoteIndex.vpnAddrs[0] != hostinfo.vpnAddrs[0] {
// We have a collision, but this can happen since we can't control // We have a collision, but this can happen since we can't control
// the remote ID. Just log about the situation as a note. // the remote ID. Just log about the situation as a note.
hostinfo.logger(c.l). hostinfo.logger(hm.l).
WithField("remoteIndex", hostinfo.remoteIndexId).WithField("collision", existingRemoteIndex.vpnIp). WithField("remoteIndex", hostinfo.remoteIndexId).WithField("collision", existingRemoteIndex.vpnAddrs).
Info("New host shadows existing host remoteIndex") Info("New host shadows existing host remoteIndex")
} }
c.mainHostMap.unlockedAddHostInfo(hostinfo, f) hm.mainHostMap.unlockedAddHostInfo(hostinfo, f)
return existingHostInfo, nil return existingHostInfo, nil
} }
@ -518,7 +563,7 @@ func (hm *HandshakeManager) Complete(hostinfo *HostInfo, f *Interface) {
// We have a collision, but this can happen since we can't control // We have a collision, but this can happen since we can't control
// the remote ID. Just log about the situation as a note. // the remote ID. Just log about the situation as a note.
hostinfo.logger(hm.l). hostinfo.logger(hm.l).
WithField("remoteIndex", hostinfo.remoteIndexId).WithField("collision", existingRemoteIndex.vpnIp). WithField("remoteIndex", hostinfo.remoteIndexId).WithField("collision", existingRemoteIndex.vpnAddrs).
Info("New host shadows existing host remoteIndex") Info("New host shadows existing host remoteIndex")
} }
@ -555,31 +600,34 @@ func (hm *HandshakeManager) allocateIndex(hh *HandshakeHostInfo) error {
return errors.New("failed to generate unique localIndexId") return errors.New("failed to generate unique localIndexId")
} }
func (c *HandshakeManager) DeleteHostInfo(hostinfo *HostInfo) { func (hm *HandshakeManager) DeleteHostInfo(hostinfo *HostInfo) {
c.Lock() hm.Lock()
defer c.Unlock() defer hm.Unlock()
c.unlockedDeleteHostInfo(hostinfo) hm.unlockedDeleteHostInfo(hostinfo)
} }
func (c *HandshakeManager) unlockedDeleteHostInfo(hostinfo *HostInfo) { func (hm *HandshakeManager) unlockedDeleteHostInfo(hostinfo *HostInfo) {
delete(c.vpnIps, hostinfo.vpnIp) for _, addr := range hostinfo.vpnAddrs {
if len(c.vpnIps) == 0 { delete(hm.vpnIps, addr)
c.vpnIps = map[netip.Addr]*HandshakeHostInfo{}
} }
delete(c.indexes, hostinfo.localIndexId) if len(hm.vpnIps) == 0 {
if len(c.vpnIps) == 0 { hm.vpnIps = map[netip.Addr]*HandshakeHostInfo{}
c.indexes = map[uint32]*HandshakeHostInfo{}
} }
if c.l.Level >= logrus.DebugLevel { delete(hm.indexes, hostinfo.localIndexId)
c.l.WithField("hostMap", m{"mapTotalSize": len(c.vpnIps), if len(hm.indexes) == 0 {
"vpnIp": hostinfo.vpnIp, "indexNumber": hostinfo.localIndexId, "remoteIndexNumber": hostinfo.remoteIndexId}). hm.indexes = map[uint32]*HandshakeHostInfo{}
}
if hm.l.Level >= logrus.DebugLevel {
hm.l.WithField("hostMap", m{"mapTotalSize": len(hm.vpnIps),
"vpnAddrs": hostinfo.vpnAddrs, "indexNumber": hostinfo.localIndexId, "remoteIndexNumber": hostinfo.remoteIndexId}).
Debug("Pending hostmap hostInfo deleted") Debug("Pending hostmap hostInfo deleted")
} }
} }
func (hm *HandshakeManager) QueryVpnIp(vpnIp netip.Addr) *HostInfo { func (hm *HandshakeManager) QueryVpnAddr(vpnIp netip.Addr) *HostInfo {
hh := hm.queryVpnIp(vpnIp) hh := hm.queryVpnIp(vpnIp)
if hh != nil { if hh != nil {
return hh.hostinfo return hh.hostinfo
@ -608,37 +656,37 @@ func (hm *HandshakeManager) queryIndex(index uint32) *HandshakeHostInfo {
return hm.indexes[index] return hm.indexes[index]
} }
func (c *HandshakeManager) GetPreferredRanges() []netip.Prefix { func (hm *HandshakeManager) GetPreferredRanges() []netip.Prefix {
return c.mainHostMap.GetPreferredRanges() return hm.mainHostMap.GetPreferredRanges()
} }
func (c *HandshakeManager) ForEachVpnIp(f controlEach) { func (hm *HandshakeManager) ForEachVpnAddr(f controlEach) {
c.RLock() hm.RLock()
defer c.RUnlock() defer hm.RUnlock()
for _, v := range c.vpnIps { for _, v := range hm.vpnIps {
f(v.hostinfo) f(v.hostinfo)
} }
} }
func (c *HandshakeManager) ForEachIndex(f controlEach) { func (hm *HandshakeManager) ForEachIndex(f controlEach) {
c.RLock() hm.RLock()
defer c.RUnlock() defer hm.RUnlock()
for _, v := range c.indexes { for _, v := range hm.indexes {
f(v.hostinfo) f(v.hostinfo)
} }
} }
func (c *HandshakeManager) EmitStats() { func (hm *HandshakeManager) EmitStats() {
c.RLock() hm.RLock()
hostLen := len(c.vpnIps) hostLen := len(hm.vpnIps)
indexLen := len(c.indexes) indexLen := len(hm.indexes)
c.RUnlock() hm.RUnlock()
metrics.GetOrRegisterGauge("hostmap.pending.hosts", nil).Update(int64(hostLen)) metrics.GetOrRegisterGauge("hostmap.pending.hosts", nil).Update(int64(hostLen))
metrics.GetOrRegisterGauge("hostmap.pending.indexes", nil).Update(int64(indexLen)) metrics.GetOrRegisterGauge("hostmap.pending.indexes", nil).Update(int64(indexLen))
c.mainHostMap.EmitStats() hm.mainHostMap.EmitStats()
} }
// Utility functions below // Utility functions below

View File

@ -5,6 +5,7 @@ import (
"testing" "testing"
"time" "time"
"github.com/slackhq/nebula/cert"
"github.com/slackhq/nebula/header" "github.com/slackhq/nebula/header"
"github.com/slackhq/nebula/test" "github.com/slackhq/nebula/test"
"github.com/slackhq/nebula/udp" "github.com/slackhq/nebula/udp"
@ -13,21 +14,20 @@ import (
func Test_NewHandshakeManagerVpnIp(t *testing.T) { func Test_NewHandshakeManagerVpnIp(t *testing.T) {
l := test.NewLogger() l := test.NewLogger()
vpncidr := netip.MustParsePrefix("172.1.1.1/24")
localrange := netip.MustParsePrefix("10.1.1.1/24") localrange := netip.MustParsePrefix("10.1.1.1/24")
ip := netip.MustParseAddr("172.1.1.2") ip := netip.MustParseAddr("172.1.1.2")
preferredRanges := []netip.Prefix{localrange} preferredRanges := []netip.Prefix{localrange}
mainHM := newHostMap(l, vpncidr) mainHM := newHostMap(l)
mainHM.preferredRanges.Store(&preferredRanges) mainHM.preferredRanges.Store(&preferredRanges)
lh := newTestLighthouse() lh := newTestLighthouse()
cs := &CertState{ cs := &CertState{
RawCertificate: []byte{}, defaultVersion: cert.Version1,
PrivateKey: []byte{}, privateKey: []byte{},
Certificate: &dummyCert{}, v1Cert: &dummyCert{version: cert.Version1},
RawCertificateNoKey: []byte{}, v1HandshakeBytes: []byte{},
} }
blah := NewHandshakeManager(l, mainHM, lh, &udp.NoopConn{}, defaultHandshakeConfig) blah := NewHandshakeManager(l, mainHM, lh, &udp.NoopConn{}, defaultHandshakeConfig)
@ -41,7 +41,7 @@ func Test_NewHandshakeManagerVpnIp(t *testing.T) {
i2 := blah.StartHandshake(ip, nil) i2 := blah.StartHandshake(ip, nil)
assert.Same(t, i, i2) assert.Same(t, i, i2)
i.remotes = NewRemoteList(nil) i.remotes = NewRemoteList([]netip.Addr{}, nil)
// Adding something to pending should not affect the main hostmap // Adding something to pending should not affect the main hostmap
assert.Len(t, mainHM.Hosts, 0) assert.Len(t, mainHM.Hosts, 0)
@ -79,16 +79,24 @@ func testCountTimerWheelEntries(tw *LockingTimerWheel[netip.Addr]) (c int) {
type mockEncWriter struct { type mockEncWriter struct {
} }
func (mw *mockEncWriter) SendMessageToVpnIp(t header.MessageType, st header.MessageSubType, vpnIp netip.Addr, p, nb, out []byte) { func (mw *mockEncWriter) SendMessageToVpnAddr(_ header.MessageType, _ header.MessageSubType, _ netip.Addr, _, _, _ []byte) {
return return
} }
func (mw *mockEncWriter) SendVia(via *HostInfo, relay *Relay, ad, nb, out []byte, nocopy bool) { func (mw *mockEncWriter) SendVia(_ *HostInfo, _ *Relay, _, _, _ []byte, _ bool) {
return return
} }
func (mw *mockEncWriter) SendMessageToHostInfo(t header.MessageType, st header.MessageSubType, hostinfo *HostInfo, p, nb, out []byte) { func (mw *mockEncWriter) SendMessageToHostInfo(_ header.MessageType, _ header.MessageSubType, _ *HostInfo, _, _, _ []byte) {
return return
} }
func (mw *mockEncWriter) Handshake(vpnIP netip.Addr) {} func (mw *mockEncWriter) Handshake(_ netip.Addr) {}
func (mw *mockEncWriter) GetHostInfo(_ netip.Addr) *HostInfo {
return nil
}
func (mw *mockEncWriter) GetCertState() *CertState {
return &CertState{defaultVersion: cert.Version2}
}

View File

@ -48,7 +48,7 @@ type Relay struct {
State int State int
LocalIndex uint32 LocalIndex uint32
RemoteIndex uint32 RemoteIndex uint32
PeerIp netip.Addr PeerAddr netip.Addr
} }
type HostMap struct { type HostMap struct {
@ -58,7 +58,6 @@ type HostMap struct {
RemoteIndexes map[uint32]*HostInfo RemoteIndexes map[uint32]*HostInfo
Hosts map[netip.Addr]*HostInfo Hosts map[netip.Addr]*HostInfo
preferredRanges atomic.Pointer[[]netip.Prefix] preferredRanges atomic.Pointer[[]netip.Prefix]
vpnCIDR netip.Prefix
l *logrus.Logger l *logrus.Logger
} }
@ -68,8 +67,8 @@ type HostMap struct {
type RelayState struct { type RelayState struct {
sync.RWMutex sync.RWMutex
relays map[netip.Addr]struct{} // Set of VpnIp's of Hosts to use as relays to access this peer relays map[netip.Addr]struct{} // Set of vpnAddr's of Hosts to use as relays to access this peer
relayForByIp map[netip.Addr]*Relay // Maps VpnIps of peers for which this HostInfo is a relay to some Relay info relayForByAddr map[netip.Addr]*Relay // Maps vpnAddr of peers for which this HostInfo is a relay to some Relay info
relayForByIdx map[uint32]*Relay // Maps a local index to some Relay info relayForByIdx map[uint32]*Relay // Maps a local index to some Relay info
} }
@ -89,10 +88,10 @@ func (rs *RelayState) CopyAllRelayFor() []*Relay {
return ret return ret
} }
func (rs *RelayState) GetRelayForByIp(ip netip.Addr) (*Relay, bool) { func (rs *RelayState) GetRelayForByAddr(addr netip.Addr) (*Relay, bool) {
rs.RLock() rs.RLock()
defer rs.RUnlock() defer rs.RUnlock()
r, ok := rs.relayForByIp[ip] r, ok := rs.relayForByAddr[addr]
return r, ok return r, ok
} }
@ -115,8 +114,8 @@ func (rs *RelayState) CopyRelayIps() []netip.Addr {
func (rs *RelayState) CopyRelayForIps() []netip.Addr { func (rs *RelayState) CopyRelayForIps() []netip.Addr {
rs.RLock() rs.RLock()
defer rs.RUnlock() defer rs.RUnlock()
currentRelays := make([]netip.Addr, 0, len(rs.relayForByIp)) currentRelays := make([]netip.Addr, 0, len(rs.relayForByAddr))
for relayIp := range rs.relayForByIp { for relayIp := range rs.relayForByAddr {
currentRelays = append(currentRelays, relayIp) currentRelays = append(currentRelays, relayIp)
} }
return currentRelays return currentRelays
@ -135,7 +134,7 @@ func (rs *RelayState) CopyRelayForIdxs() []uint32 {
func (rs *RelayState) CompleteRelayByIP(vpnIp netip.Addr, remoteIdx uint32) bool { func (rs *RelayState) CompleteRelayByIP(vpnIp netip.Addr, remoteIdx uint32) bool {
rs.Lock() rs.Lock()
defer rs.Unlock() defer rs.Unlock()
r, ok := rs.relayForByIp[vpnIp] r, ok := rs.relayForByAddr[vpnIp]
if !ok { if !ok {
return false return false
} }
@ -143,7 +142,7 @@ func (rs *RelayState) CompleteRelayByIP(vpnIp netip.Addr, remoteIdx uint32) bool
newRelay.State = Established newRelay.State = Established
newRelay.RemoteIndex = remoteIdx newRelay.RemoteIndex = remoteIdx
rs.relayForByIdx[r.LocalIndex] = &newRelay rs.relayForByIdx[r.LocalIndex] = &newRelay
rs.relayForByIp[r.PeerIp] = &newRelay rs.relayForByAddr[r.PeerAddr] = &newRelay
return true return true
} }
@ -158,14 +157,14 @@ func (rs *RelayState) CompleteRelayByIdx(localIdx uint32, remoteIdx uint32) (*Re
newRelay.State = Established newRelay.State = Established
newRelay.RemoteIndex = remoteIdx newRelay.RemoteIndex = remoteIdx
rs.relayForByIdx[r.LocalIndex] = &newRelay rs.relayForByIdx[r.LocalIndex] = &newRelay
rs.relayForByIp[r.PeerIp] = &newRelay rs.relayForByAddr[r.PeerAddr] = &newRelay
return &newRelay, true return &newRelay, true
} }
func (rs *RelayState) QueryRelayForByIp(vpnIp netip.Addr) (*Relay, bool) { func (rs *RelayState) QueryRelayForByIp(vpnIp netip.Addr) (*Relay, bool) {
rs.RLock() rs.RLock()
defer rs.RUnlock() defer rs.RUnlock()
r, ok := rs.relayForByIp[vpnIp] r, ok := rs.relayForByAddr[vpnIp]
return r, ok return r, ok
} }
@ -179,7 +178,7 @@ func (rs *RelayState) QueryRelayForByIdx(idx uint32) (*Relay, bool) {
func (rs *RelayState) InsertRelay(ip netip.Addr, idx uint32, r *Relay) { func (rs *RelayState) InsertRelay(ip netip.Addr, idx uint32, r *Relay) {
rs.Lock() rs.Lock()
defer rs.Unlock() defer rs.Unlock()
rs.relayForByIp[ip] = r rs.relayForByAddr[ip] = r
rs.relayForByIdx[idx] = r rs.relayForByIdx[idx] = r
} }
@ -190,9 +189,11 @@ type HostInfo struct {
ConnectionState *ConnectionState ConnectionState *ConnectionState
remoteIndexId uint32 remoteIndexId uint32
localIndexId uint32 localIndexId uint32
vpnIp netip.Addr vpnAddrs []netip.Addr
recvError atomic.Uint32 recvError atomic.Uint32
remoteCidr *bart.Table[struct{}]
// networks are both all vpn and unsafe networks assigned to this host
networks *bart.Table[struct{}]
relayState RelayState relayState RelayState
// HandshakePacket records the packets used to create this hostinfo // HandshakePacket records the packets used to create this hostinfo
@ -241,28 +242,26 @@ type cachedPacketMetrics struct {
dropped metrics.Counter dropped metrics.Counter
} }
func NewHostMapFromConfig(l *logrus.Logger, vpnCIDR netip.Prefix, c *config.C) *HostMap { func NewHostMapFromConfig(l *logrus.Logger, c *config.C) *HostMap {
hm := newHostMap(l, vpnCIDR) hm := newHostMap(l)
hm.reload(c, true) hm.reload(c, true)
c.RegisterReloadCallback(func(c *config.C) { c.RegisterReloadCallback(func(c *config.C) {
hm.reload(c, false) hm.reload(c, false)
}) })
l.WithField("network", hm.vpnCIDR.String()). l.WithField("preferredRanges", hm.GetPreferredRanges()).
WithField("preferredRanges", hm.GetPreferredRanges()).
Info("Main HostMap created") Info("Main HostMap created")
return hm return hm
} }
func newHostMap(l *logrus.Logger, vpnCIDR netip.Prefix) *HostMap { func newHostMap(l *logrus.Logger) *HostMap {
return &HostMap{ return &HostMap{
Indexes: map[uint32]*HostInfo{}, Indexes: map[uint32]*HostInfo{},
Relays: map[uint32]*HostInfo{}, Relays: map[uint32]*HostInfo{},
RemoteIndexes: map[uint32]*HostInfo{}, RemoteIndexes: map[uint32]*HostInfo{},
Hosts: map[netip.Addr]*HostInfo{}, Hosts: map[netip.Addr]*HostInfo{},
vpnCIDR: vpnCIDR,
l: l, l: l,
} }
} }
@ -305,17 +304,6 @@ func (hm *HostMap) EmitStats() {
metrics.GetOrRegisterGauge("hostmap.main.relayIndexes", nil).Update(int64(relaysLen)) metrics.GetOrRegisterGauge("hostmap.main.relayIndexes", nil).Update(int64(relaysLen))
} }
func (hm *HostMap) RemoveRelay(localIdx uint32) {
hm.Lock()
_, ok := hm.Relays[localIdx]
if !ok {
hm.Unlock()
return
}
delete(hm.Relays, localIdx)
hm.Unlock()
}
// DeleteHostInfo will fully unlink the hostinfo and return true if it was the final hostinfo for this vpn ip // DeleteHostInfo will fully unlink the hostinfo and return true if it was the final hostinfo for this vpn ip
func (hm *HostMap) DeleteHostInfo(hostinfo *HostInfo) bool { func (hm *HostMap) DeleteHostInfo(hostinfo *HostInfo) bool {
// Delete the host itself, ensuring it's not modified anymore // Delete the host itself, ensuring it's not modified anymore
@ -335,7 +323,9 @@ func (hm *HostMap) MakePrimary(hostinfo *HostInfo) {
} }
func (hm *HostMap) unlockedMakePrimary(hostinfo *HostInfo) { func (hm *HostMap) unlockedMakePrimary(hostinfo *HostInfo) {
oldHostinfo := hm.Hosts[hostinfo.vpnIp] //TODO: we may need to promote follow on hostinfos from these vpnAddrs as well since their oldHostinfo might not be the same as this one
// this really looks like an ideal spot for memory leaks
oldHostinfo := hm.Hosts[hostinfo.vpnAddrs[0]]
if oldHostinfo == hostinfo { if oldHostinfo == hostinfo {
return return
} }
@ -348,7 +338,7 @@ func (hm *HostMap) unlockedMakePrimary(hostinfo *HostInfo) {
hostinfo.next.prev = hostinfo.prev hostinfo.next.prev = hostinfo.prev
} }
hm.Hosts[hostinfo.vpnIp] = hostinfo hm.Hosts[hostinfo.vpnAddrs[0]] = hostinfo
if oldHostinfo == nil { if oldHostinfo == nil {
return return
@ -360,23 +350,35 @@ func (hm *HostMap) unlockedMakePrimary(hostinfo *HostInfo) {
} }
func (hm *HostMap) unlockedDeleteHostInfo(hostinfo *HostInfo) { func (hm *HostMap) unlockedDeleteHostInfo(hostinfo *HostInfo) {
primary, ok := hm.Hosts[hostinfo.vpnIp] for _, addr := range hostinfo.vpnAddrs {
h := hm.Hosts[addr]
for h != nil {
if h == hostinfo {
hm.unlockedInnerDeleteHostInfo(h, addr)
}
h = h.next
}
}
}
func (hm *HostMap) unlockedInnerDeleteHostInfo(hostinfo *HostInfo, addr netip.Addr) {
primary, ok := hm.Hosts[addr]
if ok && primary == hostinfo { if ok && primary == hostinfo {
// The vpnIp pointer points to the same hostinfo as the local index id, we can remove it // The vpn addr pointer points to the same hostinfo as the local index id, we can remove it
delete(hm.Hosts, hostinfo.vpnIp) delete(hm.Hosts, addr)
if len(hm.Hosts) == 0 { if len(hm.Hosts) == 0 {
hm.Hosts = map[netip.Addr]*HostInfo{} hm.Hosts = map[netip.Addr]*HostInfo{}
} }
if hostinfo.next != nil { if hostinfo.next != nil {
// We had more than 1 hostinfo at this vpnip, promote the next in the list to primary // We had more than 1 hostinfo at this vpn addr, promote the next in the list to primary
hm.Hosts[hostinfo.vpnIp] = hostinfo.next hm.Hosts[addr] = hostinfo.next
// It is primary, there is no previous hostinfo now // It is primary, there is no previous hostinfo now
hostinfo.next.prev = nil hostinfo.next.prev = nil
} }
} else { } else {
// Relink if we were in the middle of multiple hostinfos for this vpn ip // Relink if we were in the middle of multiple hostinfos for this vpn addr
if hostinfo.prev != nil { if hostinfo.prev != nil {
hostinfo.prev.next = hostinfo.next hostinfo.prev.next = hostinfo.next
} }
@ -406,7 +408,7 @@ func (hm *HostMap) unlockedDeleteHostInfo(hostinfo *HostInfo) {
if hm.l.Level >= logrus.DebugLevel { if hm.l.Level >= logrus.DebugLevel {
hm.l.WithField("hostMap", m{"mapTotalSize": len(hm.Hosts), hm.l.WithField("hostMap", m{"mapTotalSize": len(hm.Hosts),
"vpnIp": hostinfo.vpnIp, "indexNumber": hostinfo.localIndexId, "remoteIndexNumber": hostinfo.remoteIndexId}). "vpnAddrs": hostinfo.vpnAddrs, "indexNumber": hostinfo.localIndexId, "remoteIndexNumber": hostinfo.remoteIndexId}).
Debug("Hostmap hostInfo deleted") Debug("Hostmap hostInfo deleted")
} }
@ -448,11 +450,11 @@ func (hm *HostMap) QueryReverseIndex(index uint32) *HostInfo {
} }
} }
func (hm *HostMap) QueryVpnIp(vpnIp netip.Addr) *HostInfo { func (hm *HostMap) QueryVpnAddr(vpnIp netip.Addr) *HostInfo {
return hm.queryVpnIp(vpnIp, nil) return hm.queryVpnAddr(vpnIp, nil)
} }
func (hm *HostMap) QueryVpnIpRelayFor(targetIp, relayHostIp netip.Addr) (*HostInfo, *Relay, error) { func (hm *HostMap) QueryVpnAddrsRelayFor(targetIps []netip.Addr, relayHostIp netip.Addr) (*HostInfo, *Relay, error) {
hm.RLock() hm.RLock()
defer hm.RUnlock() defer hm.RUnlock()
@ -460,17 +462,21 @@ func (hm *HostMap) QueryVpnIpRelayFor(targetIp, relayHostIp netip.Addr) (*HostIn
if !ok { if !ok {
return nil, nil, errors.New("unable to find host") return nil, nil, errors.New("unable to find host")
} }
for h != nil { for h != nil {
for _, targetIp := range targetIps {
r, ok := h.relayState.QueryRelayForByIp(targetIp) r, ok := h.relayState.QueryRelayForByIp(targetIp)
if ok && r.State == Established { if ok && r.State == Established {
return h, r, nil return h, r, nil
} }
}
h = h.next h = h.next
} }
return nil, nil, errors.New("unable to find host with relay") return nil, nil, errors.New("unable to find host with relay")
} }
func (hm *HostMap) queryVpnIp(vpnIp netip.Addr, promoteIfce *Interface) *HostInfo { func (hm *HostMap) queryVpnAddr(vpnIp netip.Addr, promoteIfce *Interface) *HostInfo {
hm.RLock() hm.RLock()
if h, ok := hm.Hosts[vpnIp]; ok { if h, ok := hm.Hosts[vpnIp]; ok {
hm.RUnlock() hm.RUnlock()
@ -491,25 +497,30 @@ func (hm *HostMap) queryVpnIp(vpnIp netip.Addr, promoteIfce *Interface) *HostInf
func (hm *HostMap) unlockedAddHostInfo(hostinfo *HostInfo, f *Interface) { func (hm *HostMap) unlockedAddHostInfo(hostinfo *HostInfo, f *Interface) {
if f.serveDns { if f.serveDns {
remoteCert := hostinfo.ConnectionState.peerCert remoteCert := hostinfo.ConnectionState.peerCert
dnsR.Add(remoteCert.Certificate.Name()+".", remoteCert.Certificate.Networks()[0].Addr().String()) dnsR.Add(remoteCert.Certificate.Name()+".", hostinfo.vpnAddrs)
} }
for _, addr := range hostinfo.vpnAddrs {
existing := hm.Hosts[hostinfo.vpnIp] hm.unlockedInnerAddHostInfo(addr, hostinfo, f)
hm.Hosts[hostinfo.vpnIp] = hostinfo
if existing != nil {
hostinfo.next = existing
existing.prev = hostinfo
} }
hm.Indexes[hostinfo.localIndexId] = hostinfo hm.Indexes[hostinfo.localIndexId] = hostinfo
hm.RemoteIndexes[hostinfo.remoteIndexId] = hostinfo hm.RemoteIndexes[hostinfo.remoteIndexId] = hostinfo
if hm.l.Level >= logrus.DebugLevel { if hm.l.Level >= logrus.DebugLevel {
hm.l.WithField("hostMap", m{"vpnIp": hostinfo.vpnIp, "mapTotalSize": len(hm.Hosts), hm.l.WithField("hostMap", m{"vpnAddrs": hostinfo.vpnAddrs, "mapTotalSize": len(hm.Hosts),
"hostinfo": m{"existing": true, "localIndexId": hostinfo.localIndexId, "hostId": hostinfo.vpnIp}}). "hostinfo": m{"existing": true, "localIndexId": hostinfo.localIndexId, "vpnAddrs": hostinfo.vpnAddrs}}).
Debug("Hostmap vpnIp added") Debug("Hostmap vpnIp added")
} }
}
func (hm *HostMap) unlockedInnerAddHostInfo(vpnAddr netip.Addr, hostinfo *HostInfo, f *Interface) {
existing := hm.Hosts[vpnAddr]
hm.Hosts[vpnAddr] = hostinfo
if existing != nil && existing != hostinfo {
hostinfo.next = existing
existing.prev = hostinfo
}
i := 1 i := 1
check := hostinfo check := hostinfo
@ -527,7 +538,7 @@ func (hm *HostMap) GetPreferredRanges() []netip.Prefix {
return *hm.preferredRanges.Load() return *hm.preferredRanges.Load()
} }
func (hm *HostMap) ForEachVpnIp(f controlEach) { func (hm *HostMap) ForEachVpnAddr(f controlEach) {
hm.RLock() hm.RLock()
defer hm.RUnlock() defer hm.RUnlock()
@ -581,7 +592,7 @@ func (i *HostInfo) TryPromoteBest(preferredRanges []netip.Prefix, ifce *Interfac
} }
i.nextLHQuery.Store(now + ifce.reQueryWait.Load()) i.nextLHQuery.Store(now + ifce.reQueryWait.Load())
ifce.lightHouse.QueryServer(i.vpnIp) ifce.lightHouse.QueryServer(i.vpnAddrs[0])
} }
} }
@ -596,7 +607,7 @@ func (i *HostInfo) SetRemote(remote netip.AddrPort) {
// We copy here because we likely got this remote from a source that reuses the object // We copy here because we likely got this remote from a source that reuses the object
if i.remote != remote { if i.remote != remote {
i.remote = remote i.remote = remote
i.remotes.LearnRemote(i.vpnIp, remote) i.remotes.LearnRemote(i.vpnAddrs[0], remote)
} }
} }
@ -647,21 +658,20 @@ func (i *HostInfo) RecvErrorExceeded() bool {
return true return true
} }
func (i *HostInfo) CreateRemoteCIDR(c cert.Certificate) { func (i *HostInfo) buildNetworks(c cert.Certificate) {
if len(c.Networks()) == 1 && len(c.UnsafeNetworks()) == 0 { if len(c.Networks()) == 1 && len(c.UnsafeNetworks()) == 0 {
// Simple case, no CIDRTree needed // Simple case, no CIDRTree needed
return return
} }
remoteCidr := new(bart.Table[struct{}]) i.networks = new(bart.Table[struct{}])
for _, network := range c.Networks() { for _, network := range c.Networks() {
remoteCidr.Insert(network, struct{}{}) i.networks.Insert(network, struct{}{})
} }
for _, network := range c.UnsafeNetworks() { for _, network := range c.UnsafeNetworks() {
remoteCidr.Insert(network, struct{}{}) i.networks.Insert(network, struct{}{})
} }
i.remoteCidr = remoteCidr
} }
func (i *HostInfo) logger(l *logrus.Logger) *logrus.Entry { func (i *HostInfo) logger(l *logrus.Logger) *logrus.Entry {
@ -669,7 +679,7 @@ func (i *HostInfo) logger(l *logrus.Logger) *logrus.Entry {
return logrus.NewEntry(l) return logrus.NewEntry(l)
} }
li := l.WithField("vpnIp", i.vpnIp). li := l.WithField("vpnAddrs", i.vpnAddrs).
WithField("localIndex", i.localIndexId). WithField("localIndex", i.localIndexId).
WithField("remoteIndex", i.remoteIndexId) WithField("remoteIndex", i.remoteIndexId)
@ -684,9 +694,9 @@ func (i *HostInfo) logger(l *logrus.Logger) *logrus.Entry {
// Utility functions // Utility functions
func localIps(l *logrus.Logger, allowList *LocalAllowList) []netip.Addr { func localAddrs(l *logrus.Logger, allowList *LocalAllowList) []netip.Addr {
//FIXME: This function is pretty garbage //FIXME: This function is pretty garbage
var ips []netip.Addr var finalAddrs []netip.Addr
ifaces, _ := net.Interfaces() ifaces, _ := net.Interfaces()
for _, i := range ifaces { for _, i := range ifaces {
allow := allowList.AllowName(i.Name) allow := allowList.AllowName(i.Name)
@ -698,39 +708,38 @@ func localIps(l *logrus.Logger, allowList *LocalAllowList) []netip.Addr {
continue continue
} }
addrs, _ := i.Addrs() addrs, _ := i.Addrs()
for _, addr := range addrs { for _, rawAddr := range addrs {
var ip net.IP var addr netip.Addr
switch v := addr.(type) { switch v := rawAddr.(type) {
case *net.IPNet: case *net.IPNet:
//continue //continue
ip = v.IP addr, _ = netip.AddrFromSlice(v.IP)
case *net.IPAddr: case *net.IPAddr:
ip = v.IP addr, _ = netip.AddrFromSlice(v.IP)
} }
nip, ok := netip.AddrFromSlice(ip) if !addr.IsValid() {
if !ok {
if l.Level >= logrus.DebugLevel { if l.Level >= logrus.DebugLevel {
l.WithField("localIp", ip).Debug("ip was invalid for netip") l.WithField("localAddr", rawAddr).Debug("addr was invalid")
} }
continue continue
} }
nip = nip.Unmap() addr = addr.Unmap()
//TODO: Filtering out link local for now, this is probably the most correct thing //TODO: Filtering out link local for now, this is probably the most correct thing
//TODO: Would be nice to filter out SLAAC MAC based ips as well //TODO: Would be nice to filter out SLAAC MAC based ips as well
if nip.IsLoopback() == false && nip.IsLinkLocalUnicast() == false { if addr.IsLoopback() == false && addr.IsLinkLocalUnicast() == false {
allow := allowList.Allow(nip) isAllowed := allowList.Allow(addr)
if l.Level >= logrus.TraceLevel { if l.Level >= logrus.TraceLevel {
l.WithField("localIp", nip).WithField("allow", allow).Trace("localAllowList.Allow") l.WithField("localAddr", addr).WithField("allowed", isAllowed).Trace("localAllowList.Allow")
} }
if !allow { if !isAllowed {
continue continue
} }
ips = append(ips, nip) finalAddrs = append(finalAddrs, addr)
} }
} }
} }
return ips return finalAddrs
} }

View File

@ -11,17 +11,14 @@ import (
func TestHostMap_MakePrimary(t *testing.T) { func TestHostMap_MakePrimary(t *testing.T) {
l := test.NewLogger() l := test.NewLogger()
hm := newHostMap( hm := newHostMap(l)
l,
netip.MustParsePrefix("10.0.0.1/24"),
)
f := &Interface{} f := &Interface{}
h1 := &HostInfo{vpnIp: netip.MustParseAddr("0.0.0.1"), localIndexId: 1} h1 := &HostInfo{vpnAddrs: []netip.Addr{netip.MustParseAddr("0.0.0.1")}, localIndexId: 1}
h2 := &HostInfo{vpnIp: netip.MustParseAddr("0.0.0.1"), localIndexId: 2} h2 := &HostInfo{vpnAddrs: []netip.Addr{netip.MustParseAddr("0.0.0.1")}, localIndexId: 2}
h3 := &HostInfo{vpnIp: netip.MustParseAddr("0.0.0.1"), localIndexId: 3} h3 := &HostInfo{vpnAddrs: []netip.Addr{netip.MustParseAddr("0.0.0.1")}, localIndexId: 3}
h4 := &HostInfo{vpnIp: netip.MustParseAddr("0.0.0.1"), localIndexId: 4} h4 := &HostInfo{vpnAddrs: []netip.Addr{netip.MustParseAddr("0.0.0.1")}, localIndexId: 4}
hm.unlockedAddHostInfo(h4, f) hm.unlockedAddHostInfo(h4, f)
hm.unlockedAddHostInfo(h3, f) hm.unlockedAddHostInfo(h3, f)
@ -29,7 +26,7 @@ func TestHostMap_MakePrimary(t *testing.T) {
hm.unlockedAddHostInfo(h1, f) hm.unlockedAddHostInfo(h1, f)
// Make sure we go h1 -> h2 -> h3 -> h4 // Make sure we go h1 -> h2 -> h3 -> h4
prim := hm.QueryVpnIp(netip.MustParseAddr("0.0.0.1")) prim := hm.QueryVpnAddr(netip.MustParseAddr("0.0.0.1"))
assert.Equal(t, h1.localIndexId, prim.localIndexId) assert.Equal(t, h1.localIndexId, prim.localIndexId)
assert.Equal(t, h2.localIndexId, prim.next.localIndexId) assert.Equal(t, h2.localIndexId, prim.next.localIndexId)
assert.Nil(t, prim.prev) assert.Nil(t, prim.prev)
@ -44,7 +41,7 @@ func TestHostMap_MakePrimary(t *testing.T) {
hm.MakePrimary(h3) hm.MakePrimary(h3)
// Make sure we go h3 -> h1 -> h2 -> h4 // Make sure we go h3 -> h1 -> h2 -> h4
prim = hm.QueryVpnIp(netip.MustParseAddr("0.0.0.1")) prim = hm.QueryVpnAddr(netip.MustParseAddr("0.0.0.1"))
assert.Equal(t, h3.localIndexId, prim.localIndexId) assert.Equal(t, h3.localIndexId, prim.localIndexId)
assert.Equal(t, h1.localIndexId, prim.next.localIndexId) assert.Equal(t, h1.localIndexId, prim.next.localIndexId)
assert.Nil(t, prim.prev) assert.Nil(t, prim.prev)
@ -59,7 +56,7 @@ func TestHostMap_MakePrimary(t *testing.T) {
hm.MakePrimary(h4) hm.MakePrimary(h4)
// Make sure we go h4 -> h3 -> h1 -> h2 // Make sure we go h4 -> h3 -> h1 -> h2
prim = hm.QueryVpnIp(netip.MustParseAddr("0.0.0.1")) prim = hm.QueryVpnAddr(netip.MustParseAddr("0.0.0.1"))
assert.Equal(t, h4.localIndexId, prim.localIndexId) assert.Equal(t, h4.localIndexId, prim.localIndexId)
assert.Equal(t, h3.localIndexId, prim.next.localIndexId) assert.Equal(t, h3.localIndexId, prim.next.localIndexId)
assert.Nil(t, prim.prev) assert.Nil(t, prim.prev)
@ -74,7 +71,7 @@ func TestHostMap_MakePrimary(t *testing.T) {
hm.MakePrimary(h4) hm.MakePrimary(h4)
// Make sure we go h4 -> h3 -> h1 -> h2 // Make sure we go h4 -> h3 -> h1 -> h2
prim = hm.QueryVpnIp(netip.MustParseAddr("0.0.0.1")) prim = hm.QueryVpnAddr(netip.MustParseAddr("0.0.0.1"))
assert.Equal(t, h4.localIndexId, prim.localIndexId) assert.Equal(t, h4.localIndexId, prim.localIndexId)
assert.Equal(t, h3.localIndexId, prim.next.localIndexId) assert.Equal(t, h3.localIndexId, prim.next.localIndexId)
assert.Nil(t, prim.prev) assert.Nil(t, prim.prev)
@ -88,19 +85,16 @@ func TestHostMap_MakePrimary(t *testing.T) {
func TestHostMap_DeleteHostInfo(t *testing.T) { func TestHostMap_DeleteHostInfo(t *testing.T) {
l := test.NewLogger() l := test.NewLogger()
hm := newHostMap( hm := newHostMap(l)
l,
netip.MustParsePrefix("10.0.0.1/24"),
)
f := &Interface{} f := &Interface{}
h1 := &HostInfo{vpnIp: netip.MustParseAddr("0.0.0.1"), localIndexId: 1} h1 := &HostInfo{vpnAddrs: []netip.Addr{netip.MustParseAddr("0.0.0.1")}, localIndexId: 1}
h2 := &HostInfo{vpnIp: netip.MustParseAddr("0.0.0.1"), localIndexId: 2} h2 := &HostInfo{vpnAddrs: []netip.Addr{netip.MustParseAddr("0.0.0.1")}, localIndexId: 2}
h3 := &HostInfo{vpnIp: netip.MustParseAddr("0.0.0.1"), localIndexId: 3} h3 := &HostInfo{vpnAddrs: []netip.Addr{netip.MustParseAddr("0.0.0.1")}, localIndexId: 3}
h4 := &HostInfo{vpnIp: netip.MustParseAddr("0.0.0.1"), localIndexId: 4} h4 := &HostInfo{vpnAddrs: []netip.Addr{netip.MustParseAddr("0.0.0.1")}, localIndexId: 4}
h5 := &HostInfo{vpnIp: netip.MustParseAddr("0.0.0.1"), localIndexId: 5} h5 := &HostInfo{vpnAddrs: []netip.Addr{netip.MustParseAddr("0.0.0.1")}, localIndexId: 5}
h6 := &HostInfo{vpnIp: netip.MustParseAddr("0.0.0.1"), localIndexId: 6} h6 := &HostInfo{vpnAddrs: []netip.Addr{netip.MustParseAddr("0.0.0.1")}, localIndexId: 6}
hm.unlockedAddHostInfo(h6, f) hm.unlockedAddHostInfo(h6, f)
hm.unlockedAddHostInfo(h5, f) hm.unlockedAddHostInfo(h5, f)
@ -116,7 +110,7 @@ func TestHostMap_DeleteHostInfo(t *testing.T) {
assert.Nil(t, h) assert.Nil(t, h)
// Make sure we go h1 -> h2 -> h3 -> h4 -> h5 // Make sure we go h1 -> h2 -> h3 -> h4 -> h5
prim := hm.QueryVpnIp(netip.MustParseAddr("0.0.0.1")) prim := hm.QueryVpnAddr(netip.MustParseAddr("0.0.0.1"))
assert.Equal(t, h1.localIndexId, prim.localIndexId) assert.Equal(t, h1.localIndexId, prim.localIndexId)
assert.Equal(t, h2.localIndexId, prim.next.localIndexId) assert.Equal(t, h2.localIndexId, prim.next.localIndexId)
assert.Nil(t, prim.prev) assert.Nil(t, prim.prev)
@ -135,7 +129,7 @@ func TestHostMap_DeleteHostInfo(t *testing.T) {
assert.Nil(t, h1.next) assert.Nil(t, h1.next)
// Make sure we go h2 -> h3 -> h4 -> h5 // Make sure we go h2 -> h3 -> h4 -> h5
prim = hm.QueryVpnIp(netip.MustParseAddr("0.0.0.1")) prim = hm.QueryVpnAddr(netip.MustParseAddr("0.0.0.1"))
assert.Equal(t, h2.localIndexId, prim.localIndexId) assert.Equal(t, h2.localIndexId, prim.localIndexId)
assert.Equal(t, h3.localIndexId, prim.next.localIndexId) assert.Equal(t, h3.localIndexId, prim.next.localIndexId)
assert.Nil(t, prim.prev) assert.Nil(t, prim.prev)
@ -153,7 +147,7 @@ func TestHostMap_DeleteHostInfo(t *testing.T) {
assert.Nil(t, h3.next) assert.Nil(t, h3.next)
// Make sure we go h2 -> h4 -> h5 // Make sure we go h2 -> h4 -> h5
prim = hm.QueryVpnIp(netip.MustParseAddr("0.0.0.1")) prim = hm.QueryVpnAddr(netip.MustParseAddr("0.0.0.1"))
assert.Equal(t, h2.localIndexId, prim.localIndexId) assert.Equal(t, h2.localIndexId, prim.localIndexId)
assert.Equal(t, h4.localIndexId, prim.next.localIndexId) assert.Equal(t, h4.localIndexId, prim.next.localIndexId)
assert.Nil(t, prim.prev) assert.Nil(t, prim.prev)
@ -169,7 +163,7 @@ func TestHostMap_DeleteHostInfo(t *testing.T) {
assert.Nil(t, h5.next) assert.Nil(t, h5.next)
// Make sure we go h2 -> h4 // Make sure we go h2 -> h4
prim = hm.QueryVpnIp(netip.MustParseAddr("0.0.0.1")) prim = hm.QueryVpnAddr(netip.MustParseAddr("0.0.0.1"))
assert.Equal(t, h2.localIndexId, prim.localIndexId) assert.Equal(t, h2.localIndexId, prim.localIndexId)
assert.Equal(t, h4.localIndexId, prim.next.localIndexId) assert.Equal(t, h4.localIndexId, prim.next.localIndexId)
assert.Nil(t, prim.prev) assert.Nil(t, prim.prev)
@ -183,7 +177,7 @@ func TestHostMap_DeleteHostInfo(t *testing.T) {
assert.Nil(t, h2.next) assert.Nil(t, h2.next)
// Make sure we only have h4 // Make sure we only have h4
prim = hm.QueryVpnIp(netip.MustParseAddr("0.0.0.1")) prim = hm.QueryVpnAddr(netip.MustParseAddr("0.0.0.1"))
assert.Equal(t, h4.localIndexId, prim.localIndexId) assert.Equal(t, h4.localIndexId, prim.localIndexId)
assert.Nil(t, prim.prev) assert.Nil(t, prim.prev)
assert.Nil(t, prim.next) assert.Nil(t, prim.next)
@ -195,7 +189,7 @@ func TestHostMap_DeleteHostInfo(t *testing.T) {
assert.Nil(t, h4.next) assert.Nil(t, h4.next)
// Make sure we have nil // Make sure we have nil
prim = hm.QueryVpnIp(netip.MustParseAddr("0.0.0.1")) prim = hm.QueryVpnAddr(netip.MustParseAddr("0.0.0.1"))
assert.Nil(t, prim) assert.Nil(t, prim)
} }
@ -203,11 +197,7 @@ func TestHostMap_reload(t *testing.T) {
l := test.NewLogger() l := test.NewLogger()
c := config.NewC(l) c := config.NewC(l)
hm := NewHostMapFromConfig( hm := NewHostMapFromConfig(l, c)
l,
netip.MustParsePrefix("10.0.0.1/24"),
c,
)
toS := func(ipn []netip.Prefix) []string { toS := func(ipn []netip.Prefix) []string {
var s []string var s []string

View File

@ -9,8 +9,8 @@ import (
"net/netip" "net/netip"
) )
func (i *HostInfo) GetVpnIp() netip.Addr { func (i *HostInfo) GetVpnAddrs() []netip.Addr {
return i.vpnIp return i.vpnAddrs
} }
func (i *HostInfo) GetLocalIndex() uint32 { func (i *HostInfo) GetLocalIndex() uint32 {

View File

@ -20,14 +20,19 @@ func (f *Interface) consumeInsidePacket(packet []byte, fwPacket *firewall.Packet
} }
// Ignore local broadcast packets // Ignore local broadcast packets
if f.dropLocalBroadcast && fwPacket.RemoteIP == f.myBroadcastAddr { if f.dropLocalBroadcast {
_, found := f.myBroadcastAddrsTable.Lookup(fwPacket.RemoteAddr)
if found {
return return
} }
}
if fwPacket.RemoteIP == f.myVpnNet.Addr() { //TODO: seems like a huge bummer
_, found := f.myVpnAddrsTable.Lookup(fwPacket.RemoteAddr)
if found {
// Immediately forward packets from self to self. // Immediately forward packets from self to self.
// This should only happen on Darwin-based and FreeBSD hosts, which // This should only happen on Darwin-based and FreeBSD hosts, which
// routes packets from the Nebula IP to the Nebula IP through the Nebula // routes packets from the Nebula addr to the Nebula addr through the Nebula
// TUN device. // TUN device.
if immediatelyForwardToSelf { if immediatelyForwardToSelf {
_, err := f.readers[q].Write(packet) _, err := f.readers[q].Write(packet)
@ -36,25 +41,25 @@ func (f *Interface) consumeInsidePacket(packet []byte, fwPacket *firewall.Packet
} }
} }
// Otherwise, drop. On linux, we should never see these packets - Linux // Otherwise, drop. On linux, we should never see these packets - Linux
// routes packets from the nebula IP to the nebula IP through the loopback device. // routes packets from the nebula addr to the nebula addr through the loopback device.
return return
} }
// Ignore multicast packets // Ignore multicast packets
if f.dropMulticast && fwPacket.RemoteIP.IsMulticast() { if f.dropMulticast && fwPacket.RemoteAddr.IsMulticast() {
return return
} }
hostinfo, ready := f.getOrHandshake(fwPacket.RemoteIP, func(hh *HandshakeHostInfo) { hostinfo, ready := f.getOrHandshake(fwPacket.RemoteAddr, func(hh *HandshakeHostInfo) {
hh.cachePacket(f.l, header.Message, 0, packet, f.sendMessageNow, f.cachedPacketMetrics) hh.cachePacket(f.l, header.Message, 0, packet, f.sendMessageNow, f.cachedPacketMetrics)
}) })
if hostinfo == nil { if hostinfo == nil {
f.rejectInside(packet, out, q) f.rejectInside(packet, out, q)
if f.l.Level >= logrus.DebugLevel { if f.l.Level >= logrus.DebugLevel {
f.l.WithField("vpnIp", fwPacket.RemoteIP). f.l.WithField("vpnAddr", fwPacket.RemoteAddr).
WithField("fwPacket", fwPacket). WithField("fwPacket", fwPacket).
Debugln("dropping outbound packet, vpnIp not in our CIDR or in unsafe routes") Debugln("dropping outbound packet, vpnAddr not in our vpn networks or in unsafe networks")
} }
return return
} }
@ -117,21 +122,22 @@ func (f *Interface) rejectOutside(packet []byte, ci *ConnectionState, hostinfo *
f.sendNoMetrics(header.Message, 0, ci, hostinfo, netip.AddrPort{}, out, nb, packet, q) f.sendNoMetrics(header.Message, 0, ci, hostinfo, netip.AddrPort{}, out, nb, packet, q)
} }
func (f *Interface) Handshake(vpnIp netip.Addr) { func (f *Interface) Handshake(vpnAddr netip.Addr) {
f.getOrHandshake(vpnIp, nil) f.getOrHandshake(vpnAddr, nil)
} }
// getOrHandshake returns nil if the vpnIp is not routable. // getOrHandshake returns nil if the vpnAddr is not routable.
// If the 2nd return var is false then the hostinfo is not ready to be used in a tunnel // If the 2nd return var is false then the hostinfo is not ready to be used in a tunnel
func (f *Interface) getOrHandshake(vpnIp netip.Addr, cacheCallback func(*HandshakeHostInfo)) (*HostInfo, bool) { func (f *Interface) getOrHandshake(vpnAddr netip.Addr, cacheCallback func(*HandshakeHostInfo)) (*HostInfo, bool) {
if !f.myVpnNet.Contains(vpnIp) { _, found := f.myVpnNetworksTable.Lookup(vpnAddr)
vpnIp = f.inside.RouteFor(vpnIp) if !found {
if !vpnIp.IsValid() { vpnAddr = f.inside.RouteFor(vpnAddr)
if !vpnAddr.IsValid() {
return nil, false return nil, false
} }
} }
return f.handshakeManager.GetOrHandshake(vpnIp, cacheCallback) return f.handshakeManager.GetOrHandshake(vpnAddr, cacheCallback)
} }
func (f *Interface) sendMessageNow(t header.MessageType, st header.MessageSubType, hostinfo *HostInfo, p, nb, out []byte) { func (f *Interface) sendMessageNow(t header.MessageType, st header.MessageSubType, hostinfo *HostInfo, p, nb, out []byte) {
@ -156,16 +162,16 @@ func (f *Interface) sendMessageNow(t header.MessageType, st header.MessageSubTyp
f.sendNoMetrics(header.Message, st, hostinfo.ConnectionState, hostinfo, netip.AddrPort{}, p, nb, out, 0) f.sendNoMetrics(header.Message, st, hostinfo.ConnectionState, hostinfo, netip.AddrPort{}, p, nb, out, 0)
} }
// SendMessageToVpnIp handles real ip:port lookup and sends to the current best known address for vpnIp // SendMessageToVpnAddr handles real addr:port lookup and sends to the current best known address for vpnAddr
func (f *Interface) SendMessageToVpnIp(t header.MessageType, st header.MessageSubType, vpnIp netip.Addr, p, nb, out []byte) { func (f *Interface) SendMessageToVpnAddr(t header.MessageType, st header.MessageSubType, vpnAddr netip.Addr, p, nb, out []byte) {
hostInfo, ready := f.getOrHandshake(vpnIp, func(hh *HandshakeHostInfo) { hostInfo, ready := f.getOrHandshake(vpnAddr, func(hh *HandshakeHostInfo) {
hh.cachePacket(f.l, t, st, p, f.SendMessageToHostInfo, f.cachedPacketMetrics) hh.cachePacket(f.l, t, st, p, f.SendMessageToHostInfo, f.cachedPacketMetrics)
}) })
if hostInfo == nil { if hostInfo == nil {
if f.l.Level >= logrus.DebugLevel { if f.l.Level >= logrus.DebugLevel {
f.l.WithField("vpnIp", vpnIp). f.l.WithField("vpnAddr", vpnAddr).
Debugln("dropping SendMessageToVpnIp, vpnIp not in our CIDR or in unsafe routes") Debugln("dropping SendMessageToVpnAddr, vpnAddr not in our vpn networks or in unsafe routes")
} }
return return
} }
@ -285,14 +291,14 @@ func (f *Interface) sendNoMetrics(t header.MessageType, st header.MessageSubType
f.connectionManager.Out(hostinfo.localIndexId) f.connectionManager.Out(hostinfo.localIndexId)
// Query our LH if we haven't since the last time we've been rebound, this will cause the remote to punch against // Query our LH if we haven't since the last time we've been rebound, this will cause the remote to punch against
// all our IPs and enable a faster roaming. // all our addrs and enable a faster roaming.
if t != header.CloseTunnel && hostinfo.lastRebindCount != f.rebindCount { if t != header.CloseTunnel && hostinfo.lastRebindCount != f.rebindCount {
//NOTE: there is an update hole if a tunnel isn't used and exactly 256 rebinds occur before the tunnel is //NOTE: there is an update hole if a tunnel isn't used and exactly 256 rebinds occur before the tunnel is
// finally used again. This tunnel would eventually be torn down and recreated if this action didn't help. // finally used again. This tunnel would eventually be torn down and recreated if this action didn't help.
f.lightHouse.QueryServer(hostinfo.vpnIp) f.lightHouse.QueryServer(hostinfo.vpnAddrs[0])
hostinfo.lastRebindCount = f.rebindCount hostinfo.lastRebindCount = f.rebindCount
if f.l.Level >= logrus.DebugLevel { if f.l.Level >= logrus.DebugLevel {
f.l.WithField("vpnIp", hostinfo.vpnIp).Debug("Lighthouse update triggered for punch due to rebind counter") f.l.WithField("vpnAddrs", hostinfo.vpnAddrs).Debug("Lighthouse update triggered for punch due to rebind counter")
} }
} }
@ -324,7 +330,7 @@ func (f *Interface) sendNoMetrics(t header.MessageType, st header.MessageSubType
} else { } else {
// Try to send via a relay // Try to send via a relay
for _, relayIP := range hostinfo.relayState.CopyRelayIps() { for _, relayIP := range hostinfo.relayState.CopyRelayIps() {
relayHostInfo, relay, err := f.hostMap.QueryVpnIpRelayFor(hostinfo.vpnIp, relayIP) relayHostInfo, relay, err := f.hostMap.QueryVpnAddrsRelayFor(hostinfo.vpnAddrs, relayIP)
if err != nil { if err != nil {
hostinfo.relayState.DeleteRelay(relayIP) hostinfo.relayState.DeleteRelay(relayIP)
hostinfo.logger(f.l).WithField("relay", relayIP).WithError(err).Info("sendNoMetrics failed to find HostInfo") hostinfo.logger(f.l).WithField("relay", relayIP).WithError(err).Info("sendNoMetrics failed to find HostInfo")

View File

@ -2,17 +2,16 @@ package nebula
import ( import (
"context" "context"
"encoding/binary"
"errors" "errors"
"fmt" "fmt"
"io" "io"
"net"
"net/netip" "net/netip"
"os" "os"
"runtime" "runtime"
"sync/atomic" "sync/atomic"
"time" "time"
"github.com/gaissmai/bart"
"github.com/rcrowley/go-metrics" "github.com/rcrowley/go-metrics"
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
"github.com/slackhq/nebula/config" "github.com/slackhq/nebula/config"
@ -29,7 +28,6 @@ type InterfaceConfig struct {
Outside udp.Conn Outside udp.Conn
Inside overlay.Device Inside overlay.Device
pki *PKI pki *PKI
Cipher string
Firewall *Firewall Firewall *Firewall
ServeDns bool ServeDns bool
HandshakeManager *HandshakeManager HandshakeManager *HandshakeManager
@ -57,15 +55,17 @@ type Interface struct {
outside udp.Conn outside udp.Conn
inside overlay.Device inside overlay.Device
pki *PKI pki *PKI
cipher string
firewall *Firewall firewall *Firewall
connectionManager *connectionManager connectionManager *connectionManager
handshakeManager *HandshakeManager handshakeManager *HandshakeManager
serveDns bool serveDns bool
createTime time.Time createTime time.Time
lightHouse *LightHouse lightHouse *LightHouse
myBroadcastAddr netip.Addr myBroadcastAddrsTable *bart.Table[struct{}]
myVpnNet netip.Prefix myVpnAddrs []netip.Addr // A list of addresses assigned to us via our certificate
myVpnAddrsTable *bart.Table[struct{}] // A table of addresses assigned to us via our certificate
myVpnNetworks []netip.Prefix // A table of networks assigned to us via our certificate
myVpnNetworksTable *bart.Table[struct{}] // A table of networks assigned to us via our certificate
dropLocalBroadcast bool dropLocalBroadcast bool
dropMulticast bool dropMulticast bool
routines int routines int
@ -103,9 +103,11 @@ type EncWriter interface {
out []byte, out []byte,
nocopy bool, nocopy bool,
) )
SendMessageToVpnIp(t header.MessageType, st header.MessageSubType, vpnIp netip.Addr, p, nb, out []byte) SendMessageToVpnAddr(t header.MessageType, st header.MessageSubType, vpnAddr netip.Addr, p, nb, out []byte)
SendMessageToHostInfo(t header.MessageType, st header.MessageSubType, hostinfo *HostInfo, p, nb, out []byte) SendMessageToHostInfo(t header.MessageType, st header.MessageSubType, hostinfo *HostInfo, p, nb, out []byte)
Handshake(vpnIp netip.Addr) Handshake(vpnAddr netip.Addr)
GetHostInfo(vpnAddr netip.Addr) *HostInfo
GetCertState() *CertState
} }
type sendRecvErrorConfig uint8 type sendRecvErrorConfig uint8
@ -116,10 +118,10 @@ const (
sendRecvErrorPrivate sendRecvErrorPrivate
) )
func (s sendRecvErrorConfig) ShouldSendRecvError(ip netip.AddrPort) bool { func (s sendRecvErrorConfig) ShouldSendRecvError(endpoint netip.AddrPort) bool {
switch s { switch s {
case sendRecvErrorPrivate: case sendRecvErrorPrivate:
return ip.Addr().IsPrivate() return endpoint.Addr().IsPrivate()
case sendRecvErrorAlways: case sendRecvErrorAlways:
return true return true
case sendRecvErrorNever: case sendRecvErrorNever:
@ -156,14 +158,12 @@ func NewInterface(ctx context.Context, c *InterfaceConfig) (*Interface, error) {
return nil, errors.New("no firewall rules") return nil, errors.New("no firewall rules")
} }
certificate := c.pki.GetCertState().Certificate cs := c.pki.getCertState()
ifce := &Interface{ ifce := &Interface{
pki: c.pki, pki: c.pki,
hostMap: c.HostMap, hostMap: c.HostMap,
outside: c.Outside, outside: c.Outside,
inside: c.Inside, inside: c.Inside,
cipher: c.Cipher,
firewall: c.Firewall, firewall: c.Firewall,
serveDns: c.ServeDns, serveDns: c.ServeDns,
handshakeManager: c.HandshakeManager, handshakeManager: c.HandshakeManager,
@ -175,7 +175,11 @@ func NewInterface(ctx context.Context, c *InterfaceConfig) (*Interface, error) {
version: c.version, version: c.version,
writers: make([]udp.Conn, c.routines), writers: make([]udp.Conn, c.routines),
readers: make([]io.ReadWriteCloser, c.routines), readers: make([]io.ReadWriteCloser, c.routines),
myVpnNet: certificate.Networks()[0], myVpnNetworks: cs.myVpnNetworks,
myVpnNetworksTable: cs.myVpnNetworksTable,
myVpnAddrs: cs.myVpnAddrs,
myVpnAddrsTable: cs.myVpnAddrsTable,
myBroadcastAddrsTable: cs.myVpnBroadcastAddrsTable,
relayManager: c.relayManager, relayManager: c.relayManager,
conntrackCacheTimeout: c.ConntrackCacheTimeout, conntrackCacheTimeout: c.ConntrackCacheTimeout,
@ -190,14 +194,6 @@ func NewInterface(ctx context.Context, c *InterfaceConfig) (*Interface, error) {
l: c.l, l: c.l,
} }
if ifce.myVpnNet.Addr().Is4() {
maskedAddr := certificate.Networks()[0].Masked()
addr := maskedAddr.Addr().As4()
mask := net.CIDRMask(maskedAddr.Bits(), maskedAddr.Addr().BitLen())
binary.BigEndian.PutUint32(addr[:], binary.BigEndian.Uint32(addr[:])|^binary.BigEndian.Uint32(mask))
ifce.myBroadcastAddr = netip.AddrFrom4(addr)
}
ifce.tryPromoteEvery.Store(c.tryPromoteEvery) ifce.tryPromoteEvery.Store(c.tryPromoteEvery)
ifce.reQueryEvery.Store(c.reQueryEvery) ifce.reQueryEvery.Store(c.reQueryEvery)
ifce.reQueryWait.Store(int64(c.reQueryWait)) ifce.reQueryWait.Store(int64(c.reQueryWait))
@ -218,7 +214,7 @@ func (f *Interface) activate() {
f.l.WithError(err).Error("Failed to get udp listen address") f.l.WithError(err).Error("Failed to get udp listen address")
} }
f.l.WithField("interface", f.inside.Name()).WithField("network", f.inside.Cidr().String()). f.l.WithField("interface", f.inside.Name()).WithField("networks", f.myVpnNetworks).
WithField("build", f.version).WithField("udpAddr", addr). WithField("build", f.version).WithField("udpAddr", addr).
WithField("boringcrypto", boringEnabled()). WithField("boringcrypto", boringEnabled()).
Info("Nebula interface is active") Info("Nebula interface is active")
@ -259,16 +255,22 @@ func (f *Interface) listenOut(i int) {
runtime.LockOSThread() runtime.LockOSThread()
var li udp.Conn var li udp.Conn
// TODO clean this up with a coherent interface for each outside connection
if i > 0 { if i > 0 {
li = f.writers[i] li = f.writers[i]
} else { } else {
li = f.outside li = f.outside
} }
ctCache := firewall.NewConntrackCacheTicker(f.conntrackCacheTimeout)
lhh := f.lightHouse.NewRequestHandler() lhh := f.lightHouse.NewRequestHandler()
conntrackCache := firewall.NewConntrackCacheTicker(f.conntrackCacheTimeout) plaintext := make([]byte, udp.MTU)
li.ListenOut(readOutsidePackets(f), lhHandleRequest(lhh, f), conntrackCache, i) h := &header.H{}
fwPacket := &firewall.Packet{}
nb := make([]byte, 12, 12)
li.ListenOut(func(fromUdpAddr netip.AddrPort, payload []byte) {
f.readOutsidePackets(fromUdpAddr, nil, plaintext[:0], payload, h, fwPacket, lhh, nb, i, ctCache.Get(f.l))
})
} }
func (f *Interface) listenIn(reader io.ReadWriteCloser, i int) { func (f *Interface) listenIn(reader io.ReadWriteCloser, i int) {
@ -325,7 +327,7 @@ func (f *Interface) reloadFirewall(c *config.C) {
return return
} }
fw, err := NewFirewallFromConfig(f.l, f.pki.GetCertState().Certificate, c) fw, err := NewFirewallFromConfig(f.l, f.pki.getCertState(), c)
if err != nil { if err != nil {
f.l.WithError(err).Error("Error while creating firewall during reload") f.l.WithError(err).Error("Error while creating firewall during reload")
return return
@ -417,11 +419,20 @@ func (f *Interface) emitStats(ctx context.Context, i time.Duration) {
f.firewall.EmitStats() f.firewall.EmitStats()
f.handshakeManager.EmitStats() f.handshakeManager.EmitStats()
udpStats() udpStats()
certExpirationGauge.Update(int64(f.pki.GetCertState().Certificate.NotAfter().Sub(time.Now()) / time.Second)) certExpirationGauge.Update(int64(f.pki.getDefaultCertificate().NotAfter().Sub(time.Now()) / time.Second))
//TODO: we should also report the default certificate version
} }
} }
} }
func (f *Interface) GetHostInfo(vpnIp netip.Addr) *HostInfo {
return f.hostMap.QueryVpnAddr(vpnIp)
}
func (f *Interface) GetCertState() *CertState {
return f.pki.getCertState()
}
func (f *Interface) Close() error { func (f *Interface) Close() error {
f.closed.Store(true) f.closed.Store(true)

File diff suppressed because it is too large Load Diff

View File

@ -7,6 +7,8 @@ import (
"net/netip" "net/netip"
"testing" "testing"
"github.com/gaissmai/bart"
"github.com/slackhq/nebula/cert"
"github.com/slackhq/nebula/config" "github.com/slackhq/nebula/config"
"github.com/slackhq/nebula/header" "github.com/slackhq/nebula/header"
"github.com/slackhq/nebula/test" "github.com/slackhq/nebula/test"
@ -19,57 +21,48 @@ import (
func TestOldIPv4Only(t *testing.T) { func TestOldIPv4Only(t *testing.T) {
// This test ensures our new ipv6 enabled LH protobuf IpAndPorts works with the old style to enable backwards compatibility // This test ensures our new ipv6 enabled LH protobuf IpAndPorts works with the old style to enable backwards compatibility
b := []byte{8, 129, 130, 132, 80, 16, 10} b := []byte{8, 129, 130, 132, 80, 16, 10}
var m Ip4AndPort var m V4AddrPort
err := m.Unmarshal(b) err := m.Unmarshal(b)
assert.NoError(t, err) assert.NoError(t, err)
ip := netip.MustParseAddr("10.1.1.1") ip := netip.MustParseAddr("10.1.1.1")
bp := ip.As4() bp := ip.As4()
assert.Equal(t, binary.BigEndian.Uint32(bp[:]), m.GetIp()) assert.Equal(t, binary.BigEndian.Uint32(bp[:]), m.GetAddr())
}
func TestNewLhQuery(t *testing.T) {
myIp, err := netip.ParseAddr("192.1.1.1")
assert.NoError(t, err)
// Generating a new lh query should work
a := NewLhQueryByInt(myIp)
// The result should be a nebulameta protobuf
assert.IsType(t, &NebulaMeta{}, a)
// It should also Marshal fine
b, err := a.Marshal()
assert.Nil(t, err)
// and then Unmarshal fine
n := &NebulaMeta{}
err = n.Unmarshal(b)
assert.Nil(t, err)
} }
func Test_lhStaticMapping(t *testing.T) { func Test_lhStaticMapping(t *testing.T) {
l := test.NewLogger() l := test.NewLogger()
myVpnNet := netip.MustParsePrefix("10.128.0.1/16") myVpnNet := netip.MustParsePrefix("10.128.0.1/16")
nt := new(bart.Table[struct{}])
nt.Insert(myVpnNet, struct{}{})
cs := &CertState{
myVpnNetworks: []netip.Prefix{myVpnNet},
myVpnNetworksTable: nt,
}
lh1 := "10.128.0.2" lh1 := "10.128.0.2"
c := config.NewC(l) c := config.NewC(l)
c.Settings["lighthouse"] = map[interface{}]interface{}{"hosts": []interface{}{lh1}} c.Settings["lighthouse"] = map[interface{}]interface{}{"hosts": []interface{}{lh1}}
c.Settings["static_host_map"] = map[interface{}]interface{}{lh1: []interface{}{"1.1.1.1:4242"}} c.Settings["static_host_map"] = map[interface{}]interface{}{lh1: []interface{}{"1.1.1.1:4242"}}
_, err := NewLightHouseFromConfig(context.Background(), l, c, myVpnNet, nil, nil) _, err := NewLightHouseFromConfig(context.Background(), l, c, cs, nil, nil)
assert.Nil(t, err) assert.Nil(t, err)
lh2 := "10.128.0.3" lh2 := "10.128.0.3"
c = config.NewC(l) c = config.NewC(l)
c.Settings["lighthouse"] = map[interface{}]interface{}{"hosts": []interface{}{lh1, lh2}} c.Settings["lighthouse"] = map[interface{}]interface{}{"hosts": []interface{}{lh1, lh2}}
c.Settings["static_host_map"] = map[interface{}]interface{}{lh1: []interface{}{"100.1.1.1:4242"}} c.Settings["static_host_map"] = map[interface{}]interface{}{lh1: []interface{}{"100.1.1.1:4242"}}
_, err = NewLightHouseFromConfig(context.Background(), l, c, myVpnNet, nil, nil) _, err = NewLightHouseFromConfig(context.Background(), l, c, cs, nil, nil)
assert.EqualError(t, err, "lighthouse 10.128.0.3 does not have a static_host_map entry") assert.EqualError(t, err, "lighthouse 10.128.0.3 does not have a static_host_map entry")
} }
func TestReloadLighthouseInterval(t *testing.T) { func TestReloadLighthouseInterval(t *testing.T) {
l := test.NewLogger() l := test.NewLogger()
myVpnNet := netip.MustParsePrefix("10.128.0.1/16") myVpnNet := netip.MustParsePrefix("10.128.0.1/16")
nt := new(bart.Table[struct{}])
nt.Insert(myVpnNet, struct{}{})
cs := &CertState{
myVpnNetworks: []netip.Prefix{myVpnNet},
myVpnNetworksTable: nt,
}
lh1 := "10.128.0.2" lh1 := "10.128.0.2"
c := config.NewC(l) c := config.NewC(l)
@ -79,7 +72,7 @@ func TestReloadLighthouseInterval(t *testing.T) {
} }
c.Settings["static_host_map"] = map[interface{}]interface{}{lh1: []interface{}{"1.1.1.1:4242"}} c.Settings["static_host_map"] = map[interface{}]interface{}{lh1: []interface{}{"1.1.1.1:4242"}}
lh, err := NewLightHouseFromConfig(context.Background(), l, c, myVpnNet, nil, nil) lh, err := NewLightHouseFromConfig(context.Background(), l, c, cs, nil, nil)
assert.NoError(t, err) assert.NoError(t, err)
lh.ifce = &mockEncWriter{} lh.ifce = &mockEncWriter{}
@ -99,9 +92,15 @@ func TestReloadLighthouseInterval(t *testing.T) {
func BenchmarkLighthouseHandleRequest(b *testing.B) { func BenchmarkLighthouseHandleRequest(b *testing.B) {
l := test.NewLogger() l := test.NewLogger()
myVpnNet := netip.MustParsePrefix("10.128.0.1/0") myVpnNet := netip.MustParsePrefix("10.128.0.1/0")
nt := new(bart.Table[struct{}])
nt.Insert(myVpnNet, struct{}{})
cs := &CertState{
myVpnNetworks: []netip.Prefix{myVpnNet},
myVpnNetworksTable: nt,
}
c := config.NewC(l) c := config.NewC(l)
lh, err := NewLightHouseFromConfig(context.Background(), l, c, myVpnNet, nil, nil) lh, err := NewLightHouseFromConfig(context.Background(), l, c, cs, nil, nil)
if !assert.NoError(b, err) { if !assert.NoError(b, err) {
b.Fatal() b.Fatal()
} }
@ -110,46 +109,47 @@ func BenchmarkLighthouseHandleRequest(b *testing.B) {
hAddr2 := netip.MustParseAddrPort("4.5.6.7:12346") hAddr2 := netip.MustParseAddrPort("4.5.6.7:12346")
vpnIp3 := netip.MustParseAddr("0.0.0.3") vpnIp3 := netip.MustParseAddr("0.0.0.3")
lh.addrMap[vpnIp3] = NewRemoteList(nil) lh.addrMap[vpnIp3] = NewRemoteList([]netip.Addr{vpnIp3}, nil)
lh.addrMap[vpnIp3].unlockedSetV4( lh.addrMap[vpnIp3].unlockedSetV4(
vpnIp3, vpnIp3,
vpnIp3, vpnIp3,
[]*Ip4AndPort{ []*V4AddrPort{
NewIp4AndPortFromNetIP(hAddr.Addr(), hAddr.Port()), netAddrToProtoV4AddrPort(hAddr.Addr(), hAddr.Port()),
NewIp4AndPortFromNetIP(hAddr2.Addr(), hAddr2.Port()), netAddrToProtoV4AddrPort(hAddr2.Addr(), hAddr2.Port()),
}, },
func(netip.Addr, *Ip4AndPort) bool { return true }, func(netip.Addr, *V4AddrPort) bool { return true },
) )
rAddr := netip.MustParseAddrPort("1.2.2.3:12345") rAddr := netip.MustParseAddrPort("1.2.2.3:12345")
rAddr2 := netip.MustParseAddrPort("1.2.2.3:12346") rAddr2 := netip.MustParseAddrPort("1.2.2.3:12346")
vpnIp2 := netip.MustParseAddr("0.0.0.3") vpnIp2 := netip.MustParseAddr("0.0.0.3")
lh.addrMap[vpnIp2] = NewRemoteList(nil) lh.addrMap[vpnIp2] = NewRemoteList([]netip.Addr{vpnIp2}, nil)
lh.addrMap[vpnIp2].unlockedSetV4( lh.addrMap[vpnIp2].unlockedSetV4(
vpnIp3, vpnIp3,
vpnIp3, vpnIp3,
[]*Ip4AndPort{ []*V4AddrPort{
NewIp4AndPortFromNetIP(rAddr.Addr(), rAddr.Port()), netAddrToProtoV4AddrPort(rAddr.Addr(), rAddr.Port()),
NewIp4AndPortFromNetIP(rAddr2.Addr(), rAddr2.Port()), netAddrToProtoV4AddrPort(rAddr2.Addr(), rAddr2.Port()),
}, },
func(netip.Addr, *Ip4AndPort) bool { return true }, func(netip.Addr, *V4AddrPort) bool { return true },
) )
mw := &mockEncWriter{} mw := &mockEncWriter{}
hi := []netip.Addr{vpnIp2}
b.Run("notfound", func(b *testing.B) { b.Run("notfound", func(b *testing.B) {
lhh := lh.NewRequestHandler() lhh := lh.NewRequestHandler()
req := &NebulaMeta{ req := &NebulaMeta{
Type: NebulaMeta_HostQuery, Type: NebulaMeta_HostQuery,
Details: &NebulaMetaDetails{ Details: &NebulaMetaDetails{
VpnIp: 4, OldVpnAddr: 4,
Ip4AndPorts: nil, V4AddrPorts: nil,
}, },
} }
p, err := req.Marshal() p, err := req.Marshal()
assert.NoError(b, err) assert.NoError(b, err)
for n := 0; n < b.N; n++ { for n := 0; n < b.N; n++ {
lhh.HandleRequest(rAddr, vpnIp2, p, mw) lhh.HandleRequest(rAddr, hi, p, mw)
} }
}) })
b.Run("found", func(b *testing.B) { b.Run("found", func(b *testing.B) {
@ -157,15 +157,15 @@ func BenchmarkLighthouseHandleRequest(b *testing.B) {
req := &NebulaMeta{ req := &NebulaMeta{
Type: NebulaMeta_HostQuery, Type: NebulaMeta_HostQuery,
Details: &NebulaMetaDetails{ Details: &NebulaMetaDetails{
VpnIp: 3, OldVpnAddr: 3,
Ip4AndPorts: nil, V4AddrPorts: nil,
}, },
} }
p, err := req.Marshal() p, err := req.Marshal()
assert.NoError(b, err) assert.NoError(b, err)
for n := 0; n < b.N; n++ { for n := 0; n < b.N; n++ {
lhh.HandleRequest(rAddr, vpnIp2, p, mw) lhh.HandleRequest(rAddr, hi, p, mw)
} }
}) })
} }
@ -197,40 +197,49 @@ func TestLighthouse_Memory(t *testing.T) {
c := config.NewC(l) c := config.NewC(l)
c.Settings["lighthouse"] = map[interface{}]interface{}{"am_lighthouse": true} c.Settings["lighthouse"] = map[interface{}]interface{}{"am_lighthouse": true}
c.Settings["listen"] = map[interface{}]interface{}{"port": 4242} c.Settings["listen"] = map[interface{}]interface{}{"port": 4242}
lh, err := NewLightHouseFromConfig(context.Background(), l, c, netip.MustParsePrefix("10.128.0.1/24"), nil, nil)
myVpnNet := netip.MustParsePrefix("10.128.0.1/24")
nt := new(bart.Table[struct{}])
nt.Insert(myVpnNet, struct{}{})
cs := &CertState{
myVpnNetworks: []netip.Prefix{myVpnNet},
myVpnNetworksTable: nt,
}
lh, err := NewLightHouseFromConfig(context.Background(), l, c, cs, nil, nil)
lh.ifce = &mockEncWriter{}
assert.NoError(t, err) assert.NoError(t, err)
lhh := lh.NewRequestHandler() lhh := lh.NewRequestHandler()
// Test that my first update responds with just that // Test that my first update responds with just that
newLHHostUpdate(myUdpAddr0, myVpnIp, []netip.AddrPort{myUdpAddr1, myUdpAddr2}, lhh) newLHHostUpdate(myUdpAddr0, myVpnIp, []netip.AddrPort{myUdpAddr1, myUdpAddr2}, lhh)
r := newLHHostRequest(myUdpAddr0, myVpnIp, myVpnIp, lhh) r := newLHHostRequest(myUdpAddr0, myVpnIp, myVpnIp, lhh)
assertIp4InArray(t, r.msg.Details.Ip4AndPorts, myUdpAddr1, myUdpAddr2) assertIp4InArray(t, r.msg.Details.V4AddrPorts, myUdpAddr1, myUdpAddr2)
// Ensure we don't accumulate addresses // Ensure we don't accumulate addresses
newLHHostUpdate(myUdpAddr0, myVpnIp, []netip.AddrPort{myUdpAddr3}, lhh) newLHHostUpdate(myUdpAddr0, myVpnIp, []netip.AddrPort{myUdpAddr3}, lhh)
r = newLHHostRequest(myUdpAddr0, myVpnIp, myVpnIp, lhh) r = newLHHostRequest(myUdpAddr0, myVpnIp, myVpnIp, lhh)
assertIp4InArray(t, r.msg.Details.Ip4AndPorts, myUdpAddr3) assertIp4InArray(t, r.msg.Details.V4AddrPorts, myUdpAddr3)
// Grow it back to 2 // Grow it back to 2
newLHHostUpdate(myUdpAddr0, myVpnIp, []netip.AddrPort{myUdpAddr1, myUdpAddr4}, lhh) newLHHostUpdate(myUdpAddr0, myVpnIp, []netip.AddrPort{myUdpAddr1, myUdpAddr4}, lhh)
r = newLHHostRequest(myUdpAddr0, myVpnIp, myVpnIp, lhh) r = newLHHostRequest(myUdpAddr0, myVpnIp, myVpnIp, lhh)
assertIp4InArray(t, r.msg.Details.Ip4AndPorts, myUdpAddr1, myUdpAddr4) assertIp4InArray(t, r.msg.Details.V4AddrPorts, myUdpAddr1, myUdpAddr4)
// Update a different host and ask about it // Update a different host and ask about it
newLHHostUpdate(theirUdpAddr0, theirVpnIp, []netip.AddrPort{theirUdpAddr1, theirUdpAddr2, theirUdpAddr3, theirUdpAddr4}, lhh) newLHHostUpdate(theirUdpAddr0, theirVpnIp, []netip.AddrPort{theirUdpAddr1, theirUdpAddr2, theirUdpAddr3, theirUdpAddr4}, lhh)
r = newLHHostRequest(theirUdpAddr0, theirVpnIp, theirVpnIp, lhh) r = newLHHostRequest(theirUdpAddr0, theirVpnIp, theirVpnIp, lhh)
assertIp4InArray(t, r.msg.Details.Ip4AndPorts, theirUdpAddr1, theirUdpAddr2, theirUdpAddr3, theirUdpAddr4) assertIp4InArray(t, r.msg.Details.V4AddrPorts, theirUdpAddr1, theirUdpAddr2, theirUdpAddr3, theirUdpAddr4)
// Have both hosts ask about the other // Have both hosts ask about the other
r = newLHHostRequest(theirUdpAddr0, theirVpnIp, myVpnIp, lhh) r = newLHHostRequest(theirUdpAddr0, theirVpnIp, myVpnIp, lhh)
assertIp4InArray(t, r.msg.Details.Ip4AndPorts, myUdpAddr1, myUdpAddr4) assertIp4InArray(t, r.msg.Details.V4AddrPorts, myUdpAddr1, myUdpAddr4)
r = newLHHostRequest(myUdpAddr0, myVpnIp, theirVpnIp, lhh) r = newLHHostRequest(myUdpAddr0, myVpnIp, theirVpnIp, lhh)
assertIp4InArray(t, r.msg.Details.Ip4AndPorts, theirUdpAddr1, theirUdpAddr2, theirUdpAddr3, theirUdpAddr4) assertIp4InArray(t, r.msg.Details.V4AddrPorts, theirUdpAddr1, theirUdpAddr2, theirUdpAddr3, theirUdpAddr4)
// Make sure we didn't get changed // Make sure we didn't get changed
r = newLHHostRequest(myUdpAddr0, myVpnIp, myVpnIp, lhh) r = newLHHostRequest(myUdpAddr0, myVpnIp, myVpnIp, lhh)
assertIp4InArray(t, r.msg.Details.Ip4AndPorts, myUdpAddr1, myUdpAddr4) assertIp4InArray(t, r.msg.Details.V4AddrPorts, myUdpAddr1, myUdpAddr4)
// Ensure proper ordering and limiting // Ensure proper ordering and limiting
// Send 12 addrs, get 10 back, the last 2 removed, allowing the duplicate to remain (clients dedupe) // Send 12 addrs, get 10 back, the last 2 removed, allowing the duplicate to remain (clients dedupe)
@ -255,7 +264,7 @@ func TestLighthouse_Memory(t *testing.T) {
r = newLHHostRequest(myUdpAddr0, myVpnIp, myVpnIp, lhh) r = newLHHostRequest(myUdpAddr0, myVpnIp, myVpnIp, lhh)
assertIp4InArray( assertIp4InArray(
t, t,
r.msg.Details.Ip4AndPorts, r.msg.Details.V4AddrPorts,
myUdpAddr1, myUdpAddr2, myUdpAddr3, myUdpAddr4, myUdpAddr5, myUdpAddr5, myUdpAddr6, myUdpAddr7, myUdpAddr8, myUdpAddr9, myUdpAddr1, myUdpAddr2, myUdpAddr3, myUdpAddr4, myUdpAddr5, myUdpAddr5, myUdpAddr6, myUdpAddr7, myUdpAddr8, myUdpAddr9,
) )
@ -265,7 +274,7 @@ func TestLighthouse_Memory(t *testing.T) {
good := netip.MustParseAddrPort("1.128.0.99:4242") good := netip.MustParseAddrPort("1.128.0.99:4242")
newLHHostUpdate(myUdpAddr0, myVpnIp, []netip.AddrPort{bad1, bad2, good}, lhh) newLHHostUpdate(myUdpAddr0, myVpnIp, []netip.AddrPort{bad1, bad2, good}, lhh)
r = newLHHostRequest(myUdpAddr0, myVpnIp, myVpnIp, lhh) r = newLHHostRequest(myUdpAddr0, myVpnIp, myVpnIp, lhh)
assertIp4InArray(t, r.msg.Details.Ip4AndPorts, good) assertIp4InArray(t, r.msg.Details.V4AddrPorts, good)
} }
func TestLighthouse_reload(t *testing.T) { func TestLighthouse_reload(t *testing.T) {
@ -273,7 +282,16 @@ func TestLighthouse_reload(t *testing.T) {
c := config.NewC(l) c := config.NewC(l)
c.Settings["lighthouse"] = map[interface{}]interface{}{"am_lighthouse": true} c.Settings["lighthouse"] = map[interface{}]interface{}{"am_lighthouse": true}
c.Settings["listen"] = map[interface{}]interface{}{"port": 4242} c.Settings["listen"] = map[interface{}]interface{}{"port": 4242}
lh, err := NewLightHouseFromConfig(context.Background(), l, c, netip.MustParsePrefix("10.128.0.1/24"), nil, nil)
myVpnNet := netip.MustParsePrefix("10.128.0.1/24")
nt := new(bart.Table[struct{}])
nt.Insert(myVpnNet, struct{}{})
cs := &CertState{
myVpnNetworks: []netip.Prefix{myVpnNet},
myVpnNetworksTable: nt,
}
lh, err := NewLightHouseFromConfig(context.Background(), l, c, cs, nil, nil)
assert.NoError(t, err) assert.NoError(t, err)
nc := map[interface{}]interface{}{ nc := map[interface{}]interface{}{
@ -295,7 +313,7 @@ func newLHHostRequest(fromAddr netip.AddrPort, myVpnIp, queryVpnIp netip.Addr, l
req := &NebulaMeta{ req := &NebulaMeta{
Type: NebulaMeta_HostQuery, Type: NebulaMeta_HostQuery,
Details: &NebulaMetaDetails{ Details: &NebulaMetaDetails{
VpnIp: binary.BigEndian.Uint32(bip[:]), OldVpnAddr: binary.BigEndian.Uint32(bip[:]),
}, },
} }
@ -308,7 +326,7 @@ func newLHHostRequest(fromAddr netip.AddrPort, myVpnIp, queryVpnIp netip.Addr, l
w := &testEncWriter{ w := &testEncWriter{
metaFilter: &filter, metaFilter: &filter,
} }
lhh.HandleRequest(fromAddr, myVpnIp, b, w) lhh.HandleRequest(fromAddr, []netip.Addr{myVpnIp}, b, w)
return w.lastReply return w.lastReply
} }
@ -318,13 +336,13 @@ func newLHHostUpdate(fromAddr netip.AddrPort, vpnIp netip.Addr, addrs []netip.Ad
req := &NebulaMeta{ req := &NebulaMeta{
Type: NebulaMeta_HostUpdateNotification, Type: NebulaMeta_HostUpdateNotification,
Details: &NebulaMetaDetails{ Details: &NebulaMetaDetails{
VpnIp: binary.BigEndian.Uint32(bip[:]), OldVpnAddr: binary.BigEndian.Uint32(bip[:]),
Ip4AndPorts: make([]*Ip4AndPort, len(addrs)), V4AddrPorts: make([]*V4AddrPort, len(addrs)),
}, },
} }
for k, v := range addrs { for k, v := range addrs {
req.Details.Ip4AndPorts[k] = NewIp4AndPortFromNetIP(v.Addr(), v.Port()) req.Details.V4AddrPorts[k] = netAddrToProtoV4AddrPort(v.Addr(), v.Port())
} }
b, err := req.Marshal() b, err := req.Marshal()
@ -333,7 +351,7 @@ func newLHHostUpdate(fromAddr netip.AddrPort, vpnIp netip.Addr, addrs []netip.Ad
} }
w := &testEncWriter{} w := &testEncWriter{}
lhh.HandleRequest(fromAddr, vpnIp, b, w) lhh.HandleRequest(fromAddr, []netip.Addr{vpnIp}, b, w)
} }
//TODO: this is a RemoteList test //TODO: this is a RemoteList test
@ -412,6 +430,7 @@ type testLhReply struct {
type testEncWriter struct { type testEncWriter struct {
lastReply testLhReply lastReply testLhReply
metaFilter *NebulaMeta_MessageType metaFilter *NebulaMeta_MessageType
protocolVersion cert.Version
} }
func (tw *testEncWriter) SendVia(via *HostInfo, relay *Relay, ad, nb, out []byte, nocopy bool) { func (tw *testEncWriter) SendVia(via *HostInfo, relay *Relay, ad, nb, out []byte, nocopy bool) {
@ -426,7 +445,7 @@ func (tw *testEncWriter) SendMessageToHostInfo(t header.MessageType, st header.M
tw.lastReply = testLhReply{ tw.lastReply = testLhReply{
nebType: t, nebType: t,
nebSubType: st, nebSubType: st,
vpnIp: hostinfo.vpnIp, vpnIp: hostinfo.vpnAddrs[0],
msg: msg, msg: msg,
} }
} }
@ -436,7 +455,7 @@ func (tw *testEncWriter) SendMessageToHostInfo(t header.MessageType, st header.M
} }
} }
func (tw *testEncWriter) SendMessageToVpnIp(t header.MessageType, st header.MessageSubType, vpnIp netip.Addr, p, _, _ []byte) { func (tw *testEncWriter) SendMessageToVpnAddr(t header.MessageType, st header.MessageSubType, vpnIp netip.Addr, p, _, _ []byte) {
msg := &NebulaMeta{} msg := &NebulaMeta{}
err := msg.Unmarshal(p) err := msg.Unmarshal(p)
if tw.metaFilter == nil || msg.Type == *tw.metaFilter { if tw.metaFilter == nil || msg.Type == *tw.metaFilter {
@ -453,15 +472,23 @@ func (tw *testEncWriter) SendMessageToVpnIp(t header.MessageType, st header.Mess
} }
} }
func (tw *testEncWriter) GetHostInfo(vpnIp netip.Addr) *HostInfo {
return nil
}
func (tw *testEncWriter) GetCertState() *CertState {
return &CertState{defaultVersion: tw.protocolVersion}
}
// assertIp4InArray asserts every address in want is at the same position in have and that the lengths match // assertIp4InArray asserts every address in want is at the same position in have and that the lengths match
func assertIp4InArray(t *testing.T, have []*Ip4AndPort, want ...netip.AddrPort) { func assertIp4InArray(t *testing.T, have []*V4AddrPort, want ...netip.AddrPort) {
if !assert.Len(t, have, len(want)) { if !assert.Len(t, have, len(want)) {
return return
} }
for k, w := range want { for k, w := range want {
//TODO: IPV6-WORK //TODO: IPV6-WORK
h := AddrPortFromIp4AndPort(have[k]) h := protoV4AddrPortToNetAddrPort(have[k])
if !(h == w) { if !(h == w) {
assert.Fail(t, fmt.Sprintf("Response did not contain: %v at %v, found %v", w, k, h)) assert.Fail(t, fmt.Sprintf("Response did not contain: %v at %v, found %v", w, k, h))
} }

24
main.go
View File

@ -2,7 +2,6 @@ package nebula
import ( import (
"context" "context"
"encoding/binary"
"fmt" "fmt"
"net" "net"
"net/netip" "net/netip"
@ -61,15 +60,12 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
return nil, util.ContextualizeIfNeeded("Failed to load PKI from config", err) return nil, util.ContextualizeIfNeeded("Failed to load PKI from config", err)
} }
certificate := pki.GetCertState().Certificate fw, err := NewFirewallFromConfig(l, pki.getCertState(), c)
fw, err := NewFirewallFromConfig(l, certificate, c)
if err != nil { if err != nil {
return nil, util.ContextualizeIfNeeded("Error while loading firewall rules", err) return nil, util.ContextualizeIfNeeded("Error while loading firewall rules", err)
} }
l.WithField("firewallHashes", fw.GetRuleHashes()).Info("Firewall started") l.WithField("firewallHashes", fw.GetRuleHashes()).Info("Firewall started")
tunCidr := certificate.Networks()[0]
ssh, err := sshd.NewSSHServer(l.WithField("subsystem", "sshd")) ssh, err := sshd.NewSSHServer(l.WithField("subsystem", "sshd"))
if err != nil { if err != nil {
return nil, util.ContextualizeIfNeeded("Error while creating SSH server", err) return nil, util.ContextualizeIfNeeded("Error while creating SSH server", err)
@ -132,7 +128,7 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
deviceFactory = overlay.NewDeviceFromConfig deviceFactory = overlay.NewDeviceFromConfig
} }
tun, err = deviceFactory(c, l, tunCidr, routines) tun, err = deviceFactory(c, l, pki.getCertState().myVpnNetworks, routines)
if err != nil { if err != nil {
return nil, util.ContextualizeIfNeeded("Failed to get a tun/tap device", err) return nil, util.ContextualizeIfNeeded("Failed to get a tun/tap device", err)
} }
@ -187,9 +183,9 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
} }
} }
hostMap := NewHostMapFromConfig(l, tunCidr, c) hostMap := NewHostMapFromConfig(l, c)
punchy := NewPunchyFromConfig(l, c) punchy := NewPunchyFromConfig(l, c)
lightHouse, err := NewLightHouseFromConfig(ctx, l, c, tunCidr, udpConns[0], punchy) lightHouse, err := NewLightHouseFromConfig(ctx, l, c, pki.getCertState(), udpConns[0], punchy)
if err != nil { if err != nil {
return nil, util.ContextualizeIfNeeded("Failed to initialize lighthouse handler", err) return nil, util.ContextualizeIfNeeded("Failed to initialize lighthouse handler", err)
} }
@ -232,7 +228,6 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
Inside: tun, Inside: tun,
Outside: udpConns[0], Outside: udpConns[0],
pki: pki, pki: pki,
Cipher: c.GetString("cipher", "aes"),
Firewall: fw, Firewall: fw,
ServeDns: serveDns, ServeDns: serveDns,
HandshakeManager: handshakeManager, HandshakeManager: handshakeManager,
@ -254,15 +249,6 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
l: l, l: l,
} }
switch ifConfig.Cipher {
case "aes":
noiseEndianness = binary.BigEndian
case "chachapoly":
noiseEndianness = binary.LittleEndian
default:
return nil, fmt.Errorf("unknown cipher: %v", ifConfig.Cipher)
}
var ifce *Interface var ifce *Interface
if !configTest { if !configTest {
ifce, err = NewInterface(ctx, ifConfig) ifce, err = NewInterface(ctx, ifConfig)
@ -303,7 +289,7 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
var dnsStart func() var dnsStart func()
if lightHouse.amLighthouse && serveDns { if lightHouse.amLighthouse && serveDns {
l.Debugln("Starting dns server") l.Debugln("Starting dns server")
dnsStart = dnsMain(l, hostMap, c) dnsStart = dnsMain(l, pki.getCertState(), hostMap, c)
} }
return &Control{ return &Control{

File diff suppressed because it is too large Load Diff

View File

@ -23,19 +23,28 @@ message NebulaMeta {
} }
message NebulaMetaDetails { message NebulaMetaDetails {
uint32 VpnIp = 1; uint32 OldVpnAddr = 1 [deprecated = true];
repeated Ip4AndPort Ip4AndPorts = 2; Addr VpnAddr = 6;
repeated Ip6AndPort Ip6AndPorts = 4;
repeated uint32 RelayVpnIp = 5; repeated uint32 OldRelayVpnAddrs = 5 [deprecated = true];
repeated Addr RelayVpnAddrs = 7;
repeated V4AddrPort V4AddrPorts = 2;
repeated V6AddrPort V6AddrPorts = 4;
uint32 counter = 3; uint32 counter = 3;
} }
message Ip4AndPort { message Addr {
uint32 Ip = 1; uint64 Hi = 1;
uint64 Lo = 2;
}
message V4AddrPort {
uint32 Addr = 1;
uint32 Port = 2; uint32 Port = 2;
} }
message Ip6AndPort { message V6AddrPort {
uint64 Hi = 1; uint64 Hi = 1;
uint64 Lo = 2; uint64 Lo = 2;
uint32 Port = 3; uint32 Port = 3;
@ -62,6 +71,7 @@ message NebulaHandshakeDetails {
uint32 ResponderIndex = 3; uint32 ResponderIndex = 3;
uint64 Cookie = 4; uint64 Cookie = 4;
uint64 Time = 5; uint64 Time = 5;
uint32 CertVersion = 8;
// reserved for WIP multiport // reserved for WIP multiport
reserved 6, 7; reserved 6, 7;
} }
@ -76,6 +86,10 @@ message NebulaControl {
uint32 InitiatorRelayIndex = 2; uint32 InitiatorRelayIndex = 2;
uint32 ResponderRelayIndex = 3; uint32 ResponderRelayIndex = 3;
uint32 RelayToIp = 4;
uint32 RelayFromIp = 5; uint32 OldRelayToAddr = 4 [deprecated = true];
uint32 OldRelayFromAddr = 5 [deprecated = true];
Addr RelayToAddr = 6;
Addr RelayFromAddr = 7;
} }

View File

@ -7,12 +7,12 @@ import (
"net/netip" "net/netip"
"time" "time"
"github.com/flynn/noise" "github.com/google/gopacket/layers"
"golang.org/x/net/ipv6"
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
"github.com/slackhq/nebula/cert"
"github.com/slackhq/nebula/firewall" "github.com/slackhq/nebula/firewall"
"github.com/slackhq/nebula/header" "github.com/slackhq/nebula/header"
"github.com/slackhq/nebula/udp"
"golang.org/x/net/ipv4" "golang.org/x/net/ipv4"
) )
@ -20,24 +20,7 @@ const (
minFwPacketLen = 4 minFwPacketLen = 4
) )
// TODO: IPV6-WORK this can likely be removed now func (f *Interface) readOutsidePackets(ip netip.AddrPort, via *ViaSender, out []byte, packet []byte, h *header.H, fwPacket *firewall.Packet, lhf *LightHouseHandler, nb []byte, q int, localCache firewall.ConntrackCache) {
func readOutsidePackets(f *Interface) udp.EncReader {
return func(
addr netip.AddrPort,
out []byte,
packet []byte,
header *header.H,
fwPacket *firewall.Packet,
lhh udp.LightHouseHandlerFunc,
nb []byte,
q int,
localCache firewall.ConntrackCache,
) {
f.readOutsidePackets(addr, nil, out, packet, header, fwPacket, lhh, nb, q, localCache)
}
}
func (f *Interface) readOutsidePackets(ip netip.AddrPort, via *ViaSender, out []byte, packet []byte, h *header.H, fwPacket *firewall.Packet, lhf udp.LightHouseHandlerFunc, nb []byte, q int, localCache firewall.ConntrackCache) {
err := h.Parse(packet) err := h.Parse(packet)
if err != nil { if err != nil {
// TODO: best if we return this and let caller log // TODO: best if we return this and let caller log
@ -51,7 +34,8 @@ func (f *Interface) readOutsidePackets(ip netip.AddrPort, via *ViaSender, out []
//l.Error("in packet ", header, packet[HeaderLen:]) //l.Error("in packet ", header, packet[HeaderLen:])
if ip.IsValid() { if ip.IsValid() {
if f.myVpnNet.Contains(ip.Addr()) { _, found := f.myVpnNetworksTable.Lookup(ip.Addr())
if found {
if f.l.Level >= logrus.DebugLevel { if f.l.Level >= logrus.DebugLevel {
f.l.WithField("udpAddr", ip).Debug("Refusing to process double encrypted packet") f.l.WithField("udpAddr", ip).Debug("Refusing to process double encrypted packet")
} }
@ -108,7 +92,7 @@ func (f *Interface) readOutsidePackets(ip netip.AddrPort, via *ViaSender, out []
if !ok { if !ok {
// The only way this happens is if hostmap has an index to the correct HostInfo, but the HostInfo is missing // The only way this happens is if hostmap has an index to the correct HostInfo, but the HostInfo is missing
// its internal mapping. This should never happen. // its internal mapping. This should never happen.
hostinfo.logger(f.l).WithFields(logrus.Fields{"vpnIp": hostinfo.vpnIp, "remoteIndex": h.RemoteIndex}).Error("HostInfo missing remote relay index") hostinfo.logger(f.l).WithFields(logrus.Fields{"vpnAddrs": hostinfo.vpnAddrs, "remoteIndex": h.RemoteIndex}).Error("HostInfo missing remote relay index")
return return
} }
@ -120,9 +104,9 @@ func (f *Interface) readOutsidePackets(ip netip.AddrPort, via *ViaSender, out []
return return
case ForwardingType: case ForwardingType:
// Find the target HostInfo relay object // Find the target HostInfo relay object
targetHI, targetRelay, err := f.hostMap.QueryVpnIpRelayFor(hostinfo.vpnIp, relay.PeerIp) targetHI, targetRelay, err := f.hostMap.QueryVpnAddrsRelayFor(hostinfo.vpnAddrs, relay.PeerAddr)
if err != nil { if err != nil {
hostinfo.logger(f.l).WithField("relayTo", relay.PeerIp).WithError(err).Info("Failed to find target host info by ip") hostinfo.logger(f.l).WithField("relayTo", relay.PeerAddr).WithError(err).WithField("hostinfo.vpnAddrs", hostinfo.vpnAddrs).Info("Failed to find target host info by ip")
return return
} }
@ -138,7 +122,7 @@ func (f *Interface) readOutsidePackets(ip netip.AddrPort, via *ViaSender, out []
hostinfo.logger(f.l).Error("Unexpected Relay Type of Terminal") hostinfo.logger(f.l).Error("Unexpected Relay Type of Terminal")
} }
} else { } else {
hostinfo.logger(f.l).WithFields(logrus.Fields{"relayTo": relay.PeerIp, "relayFrom": hostinfo.vpnIp, "targetRelayState": targetRelay.State}).Info("Unexpected target relay state") hostinfo.logger(f.l).WithFields(logrus.Fields{"relayTo": relay.PeerAddr, "relayFrom": hostinfo.vpnAddrs[0], "targetRelayState": targetRelay.State}).Info("Unexpected target relay state")
return return
} }
} }
@ -161,7 +145,7 @@ func (f *Interface) readOutsidePackets(ip netip.AddrPort, via *ViaSender, out []
return return
} }
lhf(ip, hostinfo.vpnIp, d) lhf.HandleRequest(ip, hostinfo.vpnAddrs, d, f)
// Fallthrough to the bottom to record incoming traffic // Fallthrough to the bottom to record incoming traffic
@ -228,14 +212,8 @@ func (f *Interface) readOutsidePackets(ip netip.AddrPort, via *ViaSender, out []
Error("Failed to decrypt Control packet") Error("Failed to decrypt Control packet")
return return
} }
m := &NebulaControl{}
err = m.Unmarshal(d)
if err != nil {
hostinfo.logger(f.l).WithError(err).Error("Failed to unmarshal control message")
break
}
f.relayManager.HandleControlMsg(hostinfo, m, f) f.relayManager.HandleControlMsg(hostinfo, d, f)
default: default:
f.messageMetrics.Rx(h.Type, h.Subtype, 1) f.messageMetrics.Rx(h.Type, h.Subtype, 1)
@ -252,8 +230,8 @@ func (f *Interface) readOutsidePackets(ip netip.AddrPort, via *ViaSender, out []
func (f *Interface) closeTunnel(hostInfo *HostInfo) { func (f *Interface) closeTunnel(hostInfo *HostInfo) {
final := f.hostMap.DeleteHostInfo(hostInfo) final := f.hostMap.DeleteHostInfo(hostInfo)
if final { if final {
// We no longer have any tunnels with this vpn ip, clear learned lighthouse state to lower memory usage // We no longer have any tunnels with this vpn addr, clear learned lighthouse state to lower memory usage
f.lightHouse.DeleteVpnIp(hostInfo.vpnIp) f.lightHouse.DeleteVpnAddrs(hostInfo.vpnAddrs)
} }
} }
@ -262,25 +240,26 @@ func (f *Interface) sendCloseTunnel(h *HostInfo) {
f.send(header.CloseTunnel, 0, h.ConnectionState, h, []byte{}, make([]byte, 12, 12), make([]byte, mtu)) f.send(header.CloseTunnel, 0, h.ConnectionState, h, []byte{}, make([]byte, 12, 12), make([]byte, mtu))
} }
func (f *Interface) handleHostRoaming(hostinfo *HostInfo, ip netip.AddrPort) { func (f *Interface) handleHostRoaming(hostinfo *HostInfo, vpnAddr netip.AddrPort) {
if ip.IsValid() && hostinfo.remote != ip { if vpnAddr.IsValid() && hostinfo.remote != vpnAddr {
if !f.lightHouse.GetRemoteAllowList().Allow(hostinfo.vpnIp, ip.Addr()) { //TODO: this is weird now that we can have multiple vpn addrs
hostinfo.logger(f.l).WithField("newAddr", ip).Debug("lighthouse.remote_allow_list denied roaming") if !f.lightHouse.GetRemoteAllowList().Allow(hostinfo.vpnAddrs[0], vpnAddr.Addr()) {
hostinfo.logger(f.l).WithField("newAddr", vpnAddr).Debug("lighthouse.remote_allow_list denied roaming")
return return
} }
if !hostinfo.lastRoam.IsZero() && ip == hostinfo.lastRoamRemote && time.Since(hostinfo.lastRoam) < RoamingSuppressSeconds*time.Second { if !hostinfo.lastRoam.IsZero() && vpnAddr == hostinfo.lastRoamRemote && time.Since(hostinfo.lastRoam) < RoamingSuppressSeconds*time.Second {
if f.l.Level >= logrus.DebugLevel { if f.l.Level >= logrus.DebugLevel {
hostinfo.logger(f.l).WithField("udpAddr", hostinfo.remote).WithField("newAddr", ip). hostinfo.logger(f.l).WithField("udpAddr", hostinfo.remote).WithField("newAddr", vpnAddr).
Debugf("Suppressing roam back to previous remote for %d seconds", RoamingSuppressSeconds) Debugf("Suppressing roam back to previous remote for %d seconds", RoamingSuppressSeconds)
} }
return return
} }
hostinfo.logger(f.l).WithField("udpAddr", hostinfo.remote).WithField("newAddr", ip). hostinfo.logger(f.l).WithField("udpAddr", hostinfo.remote).WithField("newAddr", vpnAddr).
Info("Host roamed to new udp ip/port.") Info("Host roamed to new udp ip/port.")
hostinfo.lastRoam = time.Now() hostinfo.lastRoam = time.Now()
hostinfo.lastRoamRemote = hostinfo.remote hostinfo.lastRoamRemote = hostinfo.remote
hostinfo.SetRemote(ip) hostinfo.SetRemote(vpnAddr)
} }
} }
@ -302,14 +281,114 @@ func (f *Interface) handleEncrypted(ci *ConnectionState, addr netip.AddrPort, h
// newPacket validates and parses the interesting bits for the firewall out of the ip and sub protocol headers // newPacket validates and parses the interesting bits for the firewall out of the ip and sub protocol headers
func newPacket(data []byte, incoming bool, fp *firewall.Packet) error { func newPacket(data []byte, incoming bool, fp *firewall.Packet) error {
// Do we at least have an ipv4 header worth of data? if len(data) < 1 {
if len(data) < ipv4.HeaderLen { return errors.New("packet too short")
return fmt.Errorf("packet is less than %v bytes", ipv4.HeaderLen)
} }
// Is it an ipv4 packet? version := int((data[0] >> 4) & 0x0f)
if int((data[0]>>4)&0x0f) != 4 { switch version {
return fmt.Errorf("packet is not ipv4, type: %v", int((data[0]>>4)&0x0f)) case ipv4.Version:
return parseV4(data, incoming, fp)
case ipv6.Version:
return parseV6(data, incoming, fp)
}
return fmt.Errorf("packet is an unknown ip version: %v", version)
}
func parseV6(data []byte, incoming bool, fp *firewall.Packet) error {
dataLen := len(data)
if dataLen < ipv6.HeaderLen {
return fmt.Errorf("ipv6 packet is less than %v bytes", ipv4.HeaderLen)
}
if incoming {
fp.RemoteAddr, _ = netip.AddrFromSlice(data[8:24])
fp.LocalAddr, _ = netip.AddrFromSlice(data[24:40])
} else {
fp.LocalAddr, _ = netip.AddrFromSlice(data[8:24])
fp.RemoteAddr, _ = netip.AddrFromSlice(data[24:40])
}
//TODO: whats a reasonable number of extension headers to attempt to parse?
//https://www.ietf.org/archive/id/draft-ietf-6man-eh-limits-00.html
protoAt := 6
offset := 40
for i := 0; i < 24; i++ {
if dataLen < offset {
break
}
proto := layers.IPProtocol(data[protoAt])
//fmt.Println(proto, protoAt)
switch proto {
case layers.IPProtocolICMPv6:
//TODO: we need a new protocol in config language "icmpv6"
fp.Protocol = uint8(proto)
fp.RemotePort = 0
fp.LocalPort = 0
fp.Fragment = false
return nil
case layers.IPProtocolTCP:
if dataLen < offset+4 {
return fmt.Errorf("ipv6 packet was too small")
}
fp.Protocol = uint8(proto)
if incoming {
fp.RemotePort = binary.BigEndian.Uint16(data[offset : offset+2])
fp.LocalPort = binary.BigEndian.Uint16(data[offset+2 : offset+4])
} else {
fp.LocalPort = binary.BigEndian.Uint16(data[offset : offset+2])
fp.RemotePort = binary.BigEndian.Uint16(data[offset+2 : offset+4])
}
fp.Fragment = false
return nil
case layers.IPProtocolUDP:
if dataLen < offset+4 {
return fmt.Errorf("ipv6 packet was too small")
}
fp.Protocol = uint8(proto)
if incoming {
fp.RemotePort = binary.BigEndian.Uint16(data[offset : offset+2])
fp.LocalPort = binary.BigEndian.Uint16(data[offset+2 : offset+4])
} else {
fp.LocalPort = binary.BigEndian.Uint16(data[offset : offset+2])
fp.RemotePort = binary.BigEndian.Uint16(data[offset+2 : offset+4])
}
fp.Fragment = false
return nil
case layers.IPProtocolIPv6Fragment:
//TODO: can we determine the protocol?
fp.RemotePort = 0
fp.LocalPort = 0
fp.Fragment = true
return nil
default:
if dataLen < offset+1 {
break
}
next := int(data[offset+1]) * 8
if next == 0 {
// each extension is at least 8 bytes
next = 8
}
protoAt = offset
offset = offset + next
}
}
return fmt.Errorf("could not find payload in ipv6 packet")
}
func parseV4(data []byte, incoming bool, fp *firewall.Packet) error {
// Do we at least have an ipv4 header worth of data?
if len(data) < ipv4.HeaderLen {
return fmt.Errorf("ipv4 packet is less than %v bytes", ipv4.HeaderLen)
} }
// Adjust our start position based on the advertised ip header length // Adjust our start position based on the advertised ip header length
@ -317,7 +396,7 @@ func newPacket(data []byte, incoming bool, fp *firewall.Packet) error {
// Well formed ip header length? // Well formed ip header length?
if ihl < ipv4.HeaderLen { if ihl < ipv4.HeaderLen {
return fmt.Errorf("packet had an invalid header length: %v", ihl) return fmt.Errorf("ipv4 packet had an invalid header length: %v", ihl)
} }
// Check if this is the second or further fragment of a fragmented packet. // Check if this is the second or further fragment of a fragmented packet.
@ -333,14 +412,13 @@ func newPacket(data []byte, incoming bool, fp *firewall.Packet) error {
minLen += minFwPacketLen minLen += minFwPacketLen
} }
if len(data) < minLen { if len(data) < minLen {
return fmt.Errorf("packet is less than %v bytes, ip header len: %v", minLen, ihl) return fmt.Errorf("ipv4 packet is less than %v bytes, ip header len: %v", minLen, ihl)
} }
// Firewall packets are locally oriented // Firewall packets are locally oriented
if incoming { if incoming {
//TODO: IPV6-WORK fp.RemoteAddr, _ = netip.AddrFromSlice(data[12:16])
fp.RemoteIP, _ = netip.AddrFromSlice(data[12:16]) fp.LocalAddr, _ = netip.AddrFromSlice(data[16:20])
fp.LocalIP, _ = netip.AddrFromSlice(data[16:20])
if fp.Fragment || fp.Protocol == firewall.ProtoICMP { if fp.Fragment || fp.Protocol == firewall.ProtoICMP {
fp.RemotePort = 0 fp.RemotePort = 0
fp.LocalPort = 0 fp.LocalPort = 0
@ -349,9 +427,8 @@ func newPacket(data []byte, incoming bool, fp *firewall.Packet) error {
fp.LocalPort = binary.BigEndian.Uint16(data[ihl+2 : ihl+4]) fp.LocalPort = binary.BigEndian.Uint16(data[ihl+2 : ihl+4])
} }
} else { } else {
//TODO: IPV6-WORK fp.LocalAddr, _ = netip.AddrFromSlice(data[12:16])
fp.LocalIP, _ = netip.AddrFromSlice(data[12:16]) fp.RemoteAddr, _ = netip.AddrFromSlice(data[16:20])
fp.RemoteIP, _ = netip.AddrFromSlice(data[16:20])
if fp.Fragment || fp.Protocol == firewall.ProtoICMP { if fp.Fragment || fp.Protocol == firewall.ProtoICMP {
fp.RemotePort = 0 fp.RemotePort = 0
fp.LocalPort = 0 fp.LocalPort = 0
@ -492,27 +569,3 @@ func (f *Interface) sendMeta(ci *ConnectionState, endpoint *net.UDPAddr, meta *N
f.outside.WriteTo(msg, endpoint) f.outside.WriteTo(msg, endpoint)
} }
*/ */
func RecombineCertAndValidate(h *noise.HandshakeState, rawCertBytes []byte, caPool *cert.CAPool) (*cert.CachedCertificate, error) {
pk := h.PeerStatic()
if pk == nil {
return nil, errors.New("no peer static key was present")
}
if rawCertBytes == nil {
return nil, errors.New("provided payload was empty")
}
c, err := cert.UnmarshalCertificateFromHandshake(rawCertBytes, pk)
if err != nil {
return nil, fmt.Errorf("error unmarshaling cert: %w", err)
}
cc, err := caPool.VerifyCertificate(time.Now(), c)
if err != nil {
return nil, fmt.Errorf("certificate validation failed: %w", err)
}
return cc, nil
}

View File

@ -5,6 +5,9 @@ import (
"net/netip" "net/netip"
"testing" "testing"
"github.com/google/gopacket"
"github.com/google/gopacket/layers"
"github.com/slackhq/nebula/firewall" "github.com/slackhq/nebula/firewall"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"golang.org/x/net/ipv4" "golang.org/x/net/ipv4"
@ -13,9 +16,15 @@ import (
func Test_newPacket(t *testing.T) { func Test_newPacket(t *testing.T) {
p := &firewall.Packet{} p := &firewall.Packet{}
// length fail // length fails
err := newPacket([]byte{0, 1}, true, p) err := newPacket([]byte{}, true, p)
assert.EqualError(t, err, "packet is less than 20 bytes") assert.EqualError(t, err, "packet too short")
err = newPacket([]byte{0x40}, true, p)
assert.EqualError(t, err, "ipv4 packet is less than 20 bytes")
err = newPacket([]byte{0x60}, true, p)
assert.EqualError(t, err, "ipv6 packet is less than 20 bytes")
// length fail with ip options // length fail with ip options
h := ipv4.Header{ h := ipv4.Header{
@ -29,15 +38,15 @@ func Test_newPacket(t *testing.T) {
b, _ := h.Marshal() b, _ := h.Marshal()
err = newPacket(b, true, p) err = newPacket(b, true, p)
assert.EqualError(t, err, "packet is less than 28 bytes, ip header len: 24") assert.EqualError(t, err, "ipv4 packet is less than 28 bytes, ip header len: 24")
// not an ipv4 packet // not an ipv4 packet
err = newPacket([]byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}, true, p) err = newPacket([]byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}, true, p)
assert.EqualError(t, err, "packet is not ipv4, type: 0") assert.EqualError(t, err, "packet is an unknown ip version: 0")
// invalid ihl // invalid ihl
err = newPacket([]byte{4<<4 | (8 >> 2 & 0x0f), 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}, true, p) err = newPacket([]byte{4<<4 | (8 >> 2 & 0x0f), 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}, true, p)
assert.EqualError(t, err, "packet had an invalid header length: 8") assert.EqualError(t, err, "ipv4 packet had an invalid header length: 8")
// account for variable ip header length - incoming // account for variable ip header length - incoming
h = ipv4.Header{ h = ipv4.Header{
@ -55,8 +64,8 @@ func Test_newPacket(t *testing.T) {
assert.Nil(t, err) assert.Nil(t, err)
assert.Equal(t, p.Protocol, uint8(firewall.ProtoTCP)) assert.Equal(t, p.Protocol, uint8(firewall.ProtoTCP))
assert.Equal(t, p.LocalIP, netip.MustParseAddr("10.0.0.2")) assert.Equal(t, p.LocalAddr, netip.MustParseAddr("10.0.0.2"))
assert.Equal(t, p.RemoteIP, netip.MustParseAddr("10.0.0.1")) assert.Equal(t, p.RemoteAddr, netip.MustParseAddr("10.0.0.1"))
assert.Equal(t, p.RemotePort, uint16(3)) assert.Equal(t, p.RemotePort, uint16(3))
assert.Equal(t, p.LocalPort, uint16(4)) assert.Equal(t, p.LocalPort, uint16(4))
@ -76,8 +85,60 @@ func Test_newPacket(t *testing.T) {
assert.Nil(t, err) assert.Nil(t, err)
assert.Equal(t, p.Protocol, uint8(2)) assert.Equal(t, p.Protocol, uint8(2))
assert.Equal(t, p.LocalIP, netip.MustParseAddr("10.0.0.1")) assert.Equal(t, p.LocalAddr, netip.MustParseAddr("10.0.0.1"))
assert.Equal(t, p.RemoteIP, netip.MustParseAddr("10.0.0.2")) assert.Equal(t, p.RemoteAddr, netip.MustParseAddr("10.0.0.2"))
assert.Equal(t, p.RemotePort, uint16(6)) assert.Equal(t, p.RemotePort, uint16(6))
assert.Equal(t, p.LocalPort, uint16(5)) assert.Equal(t, p.LocalPort, uint16(5))
} }
func Test_newPacket_v6(t *testing.T) {
p := &firewall.Packet{}
ip := layers.IPv6{
Version: 6,
NextHeader: firewall.ProtoUDP,
HopLimit: 128,
SrcIP: net.IPv6linklocalallrouters,
DstIP: net.IPv6linklocalallnodes,
}
udp := layers.UDP{
SrcPort: layers.UDPPort(36123),
DstPort: layers.UDPPort(22),
}
err := udp.SetNetworkLayerForChecksum(&ip)
if err != nil {
panic(err)
}
buffer := gopacket.NewSerializeBuffer()
opt := gopacket.SerializeOptions{
ComputeChecksums: true,
FixLengths: true,
}
err = gopacket.SerializeLayers(buffer, opt, &ip, &udp, gopacket.Payload([]byte{0xde, 0xad, 0xbe, 0xef}))
if err != nil {
panic(err)
}
b := buffer.Bytes()
//test incoming
err = newPacket(b, true, p)
assert.Nil(t, err)
assert.Equal(t, p.Protocol, uint8(firewall.ProtoUDP))
assert.Equal(t, p.RemoteAddr, netip.MustParseAddr("ff02::2"))
assert.Equal(t, p.LocalAddr, netip.MustParseAddr("ff02::1"))
assert.Equal(t, p.RemotePort, uint16(36123))
assert.Equal(t, p.LocalPort, uint16(22))
//test outgoing
err = newPacket(b, false, p)
assert.Nil(t, err)
assert.Equal(t, p.Protocol, uint8(firewall.ProtoUDP))
assert.Equal(t, p.LocalAddr, netip.MustParseAddr("ff02::2"))
assert.Equal(t, p.RemoteAddr, netip.MustParseAddr("ff02::1"))
assert.Equal(t, p.LocalPort, uint16(36123))
assert.Equal(t, p.RemotePort, uint16(22))
}

View File

@ -8,7 +8,7 @@ import (
type Device interface { type Device interface {
io.ReadWriteCloser io.ReadWriteCloser
Activate() error Activate() error
Cidr() netip.Prefix Networks() []netip.Prefix
Name() string Name() string
RouteFor(netip.Addr) netip.Addr RouteFor(netip.Addr) netip.Addr
NewMultiQueueReader() (io.ReadWriteCloser, error) NewMultiQueueReader() (io.ReadWriteCloser, error)

View File

@ -61,7 +61,7 @@ func makeRouteTree(l *logrus.Logger, routes []Route, allowMTU bool) (*bart.Table
return routeTree, nil return routeTree, nil
} }
func parseRoutes(c *config.C, network netip.Prefix) ([]Route, error) { func parseRoutes(c *config.C, networks []netip.Prefix) ([]Route, error) {
var err error var err error
r := c.Get("tun.routes") r := c.Get("tun.routes")
@ -117,12 +117,20 @@ func parseRoutes(c *config.C, network netip.Prefix) ([]Route, error) {
return nil, fmt.Errorf("entry %v.route in tun.routes failed to parse: %v", i+1, err) return nil, fmt.Errorf("entry %v.route in tun.routes failed to parse: %v", i+1, err)
} }
if !network.Contains(r.Cidr.Addr()) || r.Cidr.Bits() < network.Bits() { found := false
for _, network := range networks {
if network.Contains(r.Cidr.Addr()) && r.Cidr.Bits() >= network.Bits() {
found = true
break
}
}
if !found {
return nil, fmt.Errorf( return nil, fmt.Errorf(
"entry %v.route in tun.routes is not contained within the network attached to the certificate; route: %v, network: %v", "entry %v.route in tun.routes is not contained within the configured vpn networks; route: %v, networks: %v",
i+1, i+1,
r.Cidr.String(), r.Cidr.String(),
network.String(), networks,
) )
} }
@ -132,7 +140,7 @@ func parseRoutes(c *config.C, network netip.Prefix) ([]Route, error) {
return routes, nil return routes, nil
} }
func parseUnsafeRoutes(c *config.C, network netip.Prefix) ([]Route, error) { func parseUnsafeRoutes(c *config.C, networks []netip.Prefix) ([]Route, error) {
var err error var err error
r := c.Get("tun.unsafe_routes") r := c.Get("tun.unsafe_routes")
@ -229,14 +237,16 @@ func parseUnsafeRoutes(c *config.C, network netip.Prefix) ([]Route, error) {
return nil, fmt.Errorf("entry %v.route in tun.unsafe_routes failed to parse: %v", i+1, err) return nil, fmt.Errorf("entry %v.route in tun.unsafe_routes failed to parse: %v", i+1, err)
} }
for _, network := range networks {
if network.Contains(r.Cidr.Addr()) { if network.Contains(r.Cidr.Addr()) {
return nil, fmt.Errorf( return nil, fmt.Errorf(
"entry %v.route in tun.unsafe_routes is contained within the network attached to the certificate; route: %v, network: %v", "entry %v.route in tun.unsafe_routes is contained within the configured vpn networks; route: %v, network: %v",
i+1, i+1,
r.Cidr.String(), r.Cidr.String(),
network.String(), network.String(),
) )
} }
}
routes[i] = r routes[i] = r
} }

View File

@ -17,76 +17,82 @@ func Test_parseRoutes(t *testing.T) {
assert.NoError(t, err) assert.NoError(t, err)
// test no routes config // test no routes config
routes, err := parseRoutes(c, n) routes, err := parseRoutes(c, []netip.Prefix{n})
assert.Nil(t, err) assert.Nil(t, err)
assert.Len(t, routes, 0) assert.Len(t, routes, 0)
// not an array // not an array
c.Settings["tun"] = map[interface{}]interface{}{"routes": "hi"} c.Settings["tun"] = map[interface{}]interface{}{"routes": "hi"}
routes, err = parseRoutes(c, n) routes, err = parseRoutes(c, []netip.Prefix{n})
assert.Nil(t, routes) assert.Nil(t, routes)
assert.EqualError(t, err, "tun.routes is not an array") assert.EqualError(t, err, "tun.routes is not an array")
// no routes // no routes
c.Settings["tun"] = map[interface{}]interface{}{"routes": []interface{}{}} c.Settings["tun"] = map[interface{}]interface{}{"routes": []interface{}{}}
routes, err = parseRoutes(c, n) routes, err = parseRoutes(c, []netip.Prefix{n})
assert.Nil(t, err) assert.Nil(t, err)
assert.Len(t, routes, 0) assert.Len(t, routes, 0)
// weird route // weird route
c.Settings["tun"] = map[interface{}]interface{}{"routes": []interface{}{"asdf"}} c.Settings["tun"] = map[interface{}]interface{}{"routes": []interface{}{"asdf"}}
routes, err = parseRoutes(c, n) routes, err = parseRoutes(c, []netip.Prefix{n})
assert.Nil(t, routes) assert.Nil(t, routes)
assert.EqualError(t, err, "entry 1 in tun.routes is invalid") assert.EqualError(t, err, "entry 1 in tun.routes is invalid")
// no mtu // no mtu
c.Settings["tun"] = map[interface{}]interface{}{"routes": []interface{}{map[interface{}]interface{}{}}} c.Settings["tun"] = map[interface{}]interface{}{"routes": []interface{}{map[interface{}]interface{}{}}}
routes, err = parseRoutes(c, n) routes, err = parseRoutes(c, []netip.Prefix{n})
assert.Nil(t, routes) assert.Nil(t, routes)
assert.EqualError(t, err, "entry 1.mtu in tun.routes is not present") assert.EqualError(t, err, "entry 1.mtu in tun.routes is not present")
// bad mtu // bad mtu
c.Settings["tun"] = map[interface{}]interface{}{"routes": []interface{}{map[interface{}]interface{}{"mtu": "nope"}}} c.Settings["tun"] = map[interface{}]interface{}{"routes": []interface{}{map[interface{}]interface{}{"mtu": "nope"}}}
routes, err = parseRoutes(c, n) routes, err = parseRoutes(c, []netip.Prefix{n})
assert.Nil(t, routes) assert.Nil(t, routes)
assert.EqualError(t, err, "entry 1.mtu in tun.routes is not an integer: strconv.Atoi: parsing \"nope\": invalid syntax") assert.EqualError(t, err, "entry 1.mtu in tun.routes is not an integer: strconv.Atoi: parsing \"nope\": invalid syntax")
// low mtu // low mtu
c.Settings["tun"] = map[interface{}]interface{}{"routes": []interface{}{map[interface{}]interface{}{"mtu": "499"}}} c.Settings["tun"] = map[interface{}]interface{}{"routes": []interface{}{map[interface{}]interface{}{"mtu": "499"}}}
routes, err = parseRoutes(c, n) routes, err = parseRoutes(c, []netip.Prefix{n})
assert.Nil(t, routes) assert.Nil(t, routes)
assert.EqualError(t, err, "entry 1.mtu in tun.routes is below 500: 499") assert.EqualError(t, err, "entry 1.mtu in tun.routes is below 500: 499")
// missing route // missing route
c.Settings["tun"] = map[interface{}]interface{}{"routes": []interface{}{map[interface{}]interface{}{"mtu": "500"}}} c.Settings["tun"] = map[interface{}]interface{}{"routes": []interface{}{map[interface{}]interface{}{"mtu": "500"}}}
routes, err = parseRoutes(c, n) routes, err = parseRoutes(c, []netip.Prefix{n})
assert.Nil(t, routes) assert.Nil(t, routes)
assert.EqualError(t, err, "entry 1.route in tun.routes is not present") assert.EqualError(t, err, "entry 1.route in tun.routes is not present")
// unparsable route // unparsable route
c.Settings["tun"] = map[interface{}]interface{}{"routes": []interface{}{map[interface{}]interface{}{"mtu": "500", "route": "nope"}}} c.Settings["tun"] = map[interface{}]interface{}{"routes": []interface{}{map[interface{}]interface{}{"mtu": "500", "route": "nope"}}}
routes, err = parseRoutes(c, n) routes, err = parseRoutes(c, []netip.Prefix{n})
assert.Nil(t, routes) assert.Nil(t, routes)
assert.EqualError(t, err, "entry 1.route in tun.routes failed to parse: netip.ParsePrefix(\"nope\"): no '/'") assert.EqualError(t, err, "entry 1.route in tun.routes failed to parse: netip.ParsePrefix(\"nope\"): no '/'")
// below network range // below network range
c.Settings["tun"] = map[interface{}]interface{}{"routes": []interface{}{map[interface{}]interface{}{"mtu": "500", "route": "1.0.0.0/8"}}} c.Settings["tun"] = map[interface{}]interface{}{"routes": []interface{}{map[interface{}]interface{}{"mtu": "500", "route": "1.0.0.0/8"}}}
routes, err = parseRoutes(c, n) routes, err = parseRoutes(c, []netip.Prefix{n})
assert.Nil(t, routes) assert.Nil(t, routes)
assert.EqualError(t, err, "entry 1.route in tun.routes is not contained within the network attached to the certificate; route: 1.0.0.0/8, network: 10.0.0.0/24") assert.EqualError(t, err, "entry 1.route in tun.routes is not contained within the configured vpn networks; route: 1.0.0.0/8, networks: [10.0.0.0/24]")
// above network range // above network range
c.Settings["tun"] = map[interface{}]interface{}{"routes": []interface{}{map[interface{}]interface{}{"mtu": "500", "route": "10.0.1.0/24"}}} c.Settings["tun"] = map[interface{}]interface{}{"routes": []interface{}{map[interface{}]interface{}{"mtu": "500", "route": "10.0.1.0/24"}}}
routes, err = parseRoutes(c, n) routes, err = parseRoutes(c, []netip.Prefix{n})
assert.Nil(t, routes) assert.Nil(t, routes)
assert.EqualError(t, err, "entry 1.route in tun.routes is not contained within the network attached to the certificate; route: 10.0.1.0/24, network: 10.0.0.0/24") assert.EqualError(t, err, "entry 1.route in tun.routes is not contained within the configured vpn networks; route: 10.0.1.0/24, networks: [10.0.0.0/24]")
// Not in multiple ranges
c.Settings["tun"] = map[interface{}]interface{}{"routes": []interface{}{map[interface{}]interface{}{"mtu": "500", "route": "192.0.0.0/24"}}}
routes, err = parseRoutes(c, []netip.Prefix{n, netip.MustParsePrefix("192.1.0.0/24")})
assert.Nil(t, routes)
assert.EqualError(t, err, "entry 1.route in tun.routes is not contained within the configured vpn networks; route: 192.0.0.0/24, networks: [10.0.0.0/24 192.1.0.0/24]")
// happy case // happy case
c.Settings["tun"] = map[interface{}]interface{}{"routes": []interface{}{ c.Settings["tun"] = map[interface{}]interface{}{"routes": []interface{}{
map[interface{}]interface{}{"mtu": "9000", "route": "10.0.0.0/29"}, map[interface{}]interface{}{"mtu": "9000", "route": "10.0.0.0/29"},
map[interface{}]interface{}{"mtu": "8000", "route": "10.0.0.1/32"}, map[interface{}]interface{}{"mtu": "8000", "route": "10.0.0.1/32"},
}} }}
routes, err = parseRoutes(c, n) routes, err = parseRoutes(c, []netip.Prefix{n})
assert.Nil(t, err) assert.Nil(t, err)
assert.Len(t, routes, 2) assert.Len(t, routes, 2)
@ -116,31 +122,31 @@ func Test_parseUnsafeRoutes(t *testing.T) {
assert.NoError(t, err) assert.NoError(t, err)
// test no routes config // test no routes config
routes, err := parseUnsafeRoutes(c, n) routes, err := parseUnsafeRoutes(c, []netip.Prefix{n})
assert.Nil(t, err) assert.Nil(t, err)
assert.Len(t, routes, 0) assert.Len(t, routes, 0)
// not an array // not an array
c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": "hi"} c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": "hi"}
routes, err = parseUnsafeRoutes(c, n) routes, err = parseUnsafeRoutes(c, []netip.Prefix{n})
assert.Nil(t, routes) assert.Nil(t, routes)
assert.EqualError(t, err, "tun.unsafe_routes is not an array") assert.EqualError(t, err, "tun.unsafe_routes is not an array")
// no routes // no routes
c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{}} c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{}}
routes, err = parseUnsafeRoutes(c, n) routes, err = parseUnsafeRoutes(c, []netip.Prefix{n})
assert.Nil(t, err) assert.Nil(t, err)
assert.Len(t, routes, 0) assert.Len(t, routes, 0)
// weird route // weird route
c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{"asdf"}} c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{"asdf"}}
routes, err = parseUnsafeRoutes(c, n) routes, err = parseUnsafeRoutes(c, []netip.Prefix{n})
assert.Nil(t, routes) assert.Nil(t, routes)
assert.EqualError(t, err, "entry 1 in tun.unsafe_routes is invalid") assert.EqualError(t, err, "entry 1 in tun.unsafe_routes is invalid")
// no via // no via
c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{}}} c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{}}}
routes, err = parseUnsafeRoutes(c, n) routes, err = parseUnsafeRoutes(c, []netip.Prefix{n})
assert.Nil(t, routes) assert.Nil(t, routes)
assert.EqualError(t, err, "entry 1.via in tun.unsafe_routes is not present") assert.EqualError(t, err, "entry 1.via in tun.unsafe_routes is not present")
@ -149,68 +155,68 @@ func Test_parseUnsafeRoutes(t *testing.T) {
127, false, nil, 1.0, []string{"1", "2"}, 127, false, nil, 1.0, []string{"1", "2"},
} { } {
c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{"via": invalidValue}}} c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{"via": invalidValue}}}
routes, err = parseUnsafeRoutes(c, n) routes, err = parseUnsafeRoutes(c, []netip.Prefix{n})
assert.Nil(t, routes) assert.Nil(t, routes)
assert.EqualError(t, err, fmt.Sprintf("entry 1.via in tun.unsafe_routes is not a string: found %T", invalidValue)) assert.EqualError(t, err, fmt.Sprintf("entry 1.via in tun.unsafe_routes is not a string: found %T", invalidValue))
} }
// unparsable via // unparsable via
c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{"mtu": "500", "via": "nope"}}} c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{"mtu": "500", "via": "nope"}}}
routes, err = parseUnsafeRoutes(c, n) routes, err = parseUnsafeRoutes(c, []netip.Prefix{n})
assert.Nil(t, routes) assert.Nil(t, routes)
assert.EqualError(t, err, "entry 1.via in tun.unsafe_routes failed to parse address: ParseAddr(\"nope\"): unable to parse IP") assert.EqualError(t, err, "entry 1.via in tun.unsafe_routes failed to parse address: ParseAddr(\"nope\"): unable to parse IP")
// missing route // missing route
c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{"via": "127.0.0.1", "mtu": "500"}}} c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{"via": "127.0.0.1", "mtu": "500"}}}
routes, err = parseUnsafeRoutes(c, n) routes, err = parseUnsafeRoutes(c, []netip.Prefix{n})
assert.Nil(t, routes) assert.Nil(t, routes)
assert.EqualError(t, err, "entry 1.route in tun.unsafe_routes is not present") assert.EqualError(t, err, "entry 1.route in tun.unsafe_routes is not present")
// unparsable route // unparsable route
c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{"via": "127.0.0.1", "mtu": "500", "route": "nope"}}} c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{"via": "127.0.0.1", "mtu": "500", "route": "nope"}}}
routes, err = parseUnsafeRoutes(c, n) routes, err = parseUnsafeRoutes(c, []netip.Prefix{n})
assert.Nil(t, routes) assert.Nil(t, routes)
assert.EqualError(t, err, "entry 1.route in tun.unsafe_routes failed to parse: netip.ParsePrefix(\"nope\"): no '/'") assert.EqualError(t, err, "entry 1.route in tun.unsafe_routes failed to parse: netip.ParsePrefix(\"nope\"): no '/'")
// within network range // within network range
c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{"via": "127.0.0.1", "route": "10.0.0.0/24"}}} c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{"via": "127.0.0.1", "route": "10.0.0.0/24"}}}
routes, err = parseUnsafeRoutes(c, n) routes, err = parseUnsafeRoutes(c, []netip.Prefix{n})
assert.Nil(t, routes) assert.Nil(t, routes)
assert.EqualError(t, err, "entry 1.route in tun.unsafe_routes is contained within the network attached to the certificate; route: 10.0.0.0/24, network: 10.0.0.0/24") assert.EqualError(t, err, "entry 1.route in tun.unsafe_routes is contained within the configured vpn networks; route: 10.0.0.0/24, network: 10.0.0.0/24")
// below network range // below network range
c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{"via": "127.0.0.1", "route": "1.0.0.0/8"}}} c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{"via": "127.0.0.1", "route": "1.0.0.0/8"}}}
routes, err = parseUnsafeRoutes(c, n) routes, err = parseUnsafeRoutes(c, []netip.Prefix{n})
assert.Len(t, routes, 1) assert.Len(t, routes, 1)
assert.Nil(t, err) assert.Nil(t, err)
// above network range // above network range
c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{"via": "127.0.0.1", "route": "10.0.1.0/24"}}} c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{"via": "127.0.0.1", "route": "10.0.1.0/24"}}}
routes, err = parseUnsafeRoutes(c, n) routes, err = parseUnsafeRoutes(c, []netip.Prefix{n})
assert.Len(t, routes, 1) assert.Len(t, routes, 1)
assert.Nil(t, err) assert.Nil(t, err)
// no mtu // no mtu
c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{"via": "127.0.0.1", "route": "1.0.0.0/8"}}} c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{"via": "127.0.0.1", "route": "1.0.0.0/8"}}}
routes, err = parseUnsafeRoutes(c, n) routes, err = parseUnsafeRoutes(c, []netip.Prefix{n})
assert.Len(t, routes, 1) assert.Len(t, routes, 1)
assert.Equal(t, 0, routes[0].MTU) assert.Equal(t, 0, routes[0].MTU)
// bad mtu // bad mtu
c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{"via": "127.0.0.1", "mtu": "nope"}}} c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{"via": "127.0.0.1", "mtu": "nope"}}}
routes, err = parseUnsafeRoutes(c, n) routes, err = parseUnsafeRoutes(c, []netip.Prefix{n})
assert.Nil(t, routes) assert.Nil(t, routes)
assert.EqualError(t, err, "entry 1.mtu in tun.unsafe_routes is not an integer: strconv.Atoi: parsing \"nope\": invalid syntax") assert.EqualError(t, err, "entry 1.mtu in tun.unsafe_routes is not an integer: strconv.Atoi: parsing \"nope\": invalid syntax")
// low mtu // low mtu
c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{"via": "127.0.0.1", "mtu": "499"}}} c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{"via": "127.0.0.1", "mtu": "499"}}}
routes, err = parseUnsafeRoutes(c, n) routes, err = parseUnsafeRoutes(c, []netip.Prefix{n})
assert.Nil(t, routes) assert.Nil(t, routes)
assert.EqualError(t, err, "entry 1.mtu in tun.unsafe_routes is below 500: 499") assert.EqualError(t, err, "entry 1.mtu in tun.unsafe_routes is below 500: 499")
// bad install // bad install
c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{"via": "127.0.0.1", "mtu": "9000", "route": "1.0.0.0/29", "install": "nope"}}} c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{"via": "127.0.0.1", "mtu": "9000", "route": "1.0.0.0/29", "install": "nope"}}}
routes, err = parseUnsafeRoutes(c, n) routes, err = parseUnsafeRoutes(c, []netip.Prefix{n})
assert.Nil(t, routes) assert.Nil(t, routes)
assert.EqualError(t, err, "entry 1.install in tun.unsafe_routes is not a boolean: strconv.ParseBool: parsing \"nope\": invalid syntax") assert.EqualError(t, err, "entry 1.install in tun.unsafe_routes is not a boolean: strconv.ParseBool: parsing \"nope\": invalid syntax")
@ -221,7 +227,7 @@ func Test_parseUnsafeRoutes(t *testing.T) {
map[interface{}]interface{}{"via": "127.0.0.1", "mtu": "1500", "metric": 1234, "route": "1.0.0.2/32", "install": 1}, map[interface{}]interface{}{"via": "127.0.0.1", "mtu": "1500", "metric": 1234, "route": "1.0.0.2/32", "install": 1},
map[interface{}]interface{}{"via": "127.0.0.1", "mtu": "1500", "metric": 1234, "route": "1.0.0.2/32"}, map[interface{}]interface{}{"via": "127.0.0.1", "mtu": "1500", "metric": 1234, "route": "1.0.0.2/32"},
}} }}
routes, err = parseUnsafeRoutes(c, n) routes, err = parseUnsafeRoutes(c, []netip.Prefix{n})
assert.Nil(t, err) assert.Nil(t, err)
assert.Len(t, routes, 4) assert.Len(t, routes, 4)
@ -260,7 +266,7 @@ func Test_makeRouteTree(t *testing.T) {
map[interface{}]interface{}{"via": "192.168.0.1", "route": "1.0.0.0/28"}, map[interface{}]interface{}{"via": "192.168.0.1", "route": "1.0.0.0/28"},
map[interface{}]interface{}{"via": "192.168.0.2", "route": "1.0.0.1/32"}, map[interface{}]interface{}{"via": "192.168.0.2", "route": "1.0.0.1/32"},
}} }}
routes, err := parseUnsafeRoutes(c, n) routes, err := parseUnsafeRoutes(c, []netip.Prefix{n})
assert.NoError(t, err) assert.NoError(t, err)
assert.Len(t, routes, 2) assert.Len(t, routes, 2)
routeTree, err := makeRouteTree(l, routes, true) routeTree, err := makeRouteTree(l, routes, true)

View File

@ -11,36 +11,36 @@ import (
const DefaultMTU = 1300 const DefaultMTU = 1300
// TODO: We may be able to remove routines // TODO: We may be able to remove routines
type DeviceFactory func(c *config.C, l *logrus.Logger, tunCidr netip.Prefix, routines int) (Device, error) type DeviceFactory func(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, routines int) (Device, error)
func NewDeviceFromConfig(c *config.C, l *logrus.Logger, tunCidr netip.Prefix, routines int) (Device, error) { func NewDeviceFromConfig(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, routines int) (Device, error) {
switch { switch {
case c.GetBool("tun.disabled", false): case c.GetBool("tun.disabled", false):
tun := newDisabledTun(tunCidr, c.GetInt("tun.tx_queue", 500), c.GetBool("stats.message_metrics", false), l) tun := newDisabledTun(vpnNetworks, c.GetInt("tun.tx_queue", 500), c.GetBool("stats.message_metrics", false), l)
return tun, nil return tun, nil
default: default:
return newTun(c, l, tunCidr, routines > 1) return newTun(c, l, vpnNetworks, routines > 1)
} }
} }
func NewFdDeviceFromConfig(fd *int) DeviceFactory { func NewFdDeviceFromConfig(fd *int) DeviceFactory {
return func(c *config.C, l *logrus.Logger, tunCidr netip.Prefix, routines int) (Device, error) { return func(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, routines int) (Device, error) {
return newTunFromFd(c, l, *fd, tunCidr) return newTunFromFd(c, l, *fd, vpnNetworks)
} }
} }
func getAllRoutesFromConfig(c *config.C, cidr netip.Prefix, initial bool) (bool, []Route, error) { func getAllRoutesFromConfig(c *config.C, vpnNetworks []netip.Prefix, initial bool) (bool, []Route, error) {
if !initial && !c.HasChanged("tun.routes") && !c.HasChanged("tun.unsafe_routes") { if !initial && !c.HasChanged("tun.routes") && !c.HasChanged("tun.unsafe_routes") {
return false, nil, nil return false, nil, nil
} }
routes, err := parseRoutes(c, cidr) routes, err := parseRoutes(c, vpnNetworks)
if err != nil { if err != nil {
return true, nil, util.NewContextualError("Could not parse tun.routes", nil, err) return true, nil, util.NewContextualError("Could not parse tun.routes", nil, err)
} }
unsafeRoutes, err := parseUnsafeRoutes(c, cidr) unsafeRoutes, err := parseUnsafeRoutes(c, vpnNetworks)
if err != nil { if err != nil {
return true, nil, util.NewContextualError("Could not parse tun.unsafe_routes", nil, err) return true, nil, util.NewContextualError("Could not parse tun.unsafe_routes", nil, err)
} }

View File

@ -19,13 +19,13 @@ import (
type tun struct { type tun struct {
io.ReadWriteCloser io.ReadWriteCloser
fd int fd int
cidr netip.Prefix vpnNetworks []netip.Prefix
Routes atomic.Pointer[[]Route] Routes atomic.Pointer[[]Route]
routeTree atomic.Pointer[bart.Table[netip.Addr]] routeTree atomic.Pointer[bart.Table[netip.Addr]]
l *logrus.Logger l *logrus.Logger
} }
func newTunFromFd(c *config.C, l *logrus.Logger, deviceFd int, cidr netip.Prefix) (*tun, error) { func newTunFromFd(c *config.C, l *logrus.Logger, deviceFd int, vpnNetworks []netip.Prefix) (*tun, error) {
// XXX Android returns an fd in non-blocking mode which is necessary for shutdown to work properly. // XXX Android returns an fd in non-blocking mode which is necessary for shutdown to work properly.
// Be sure not to call file.Fd() as it will set the fd to blocking mode. // Be sure not to call file.Fd() as it will set the fd to blocking mode.
file := os.NewFile(uintptr(deviceFd), "/dev/net/tun") file := os.NewFile(uintptr(deviceFd), "/dev/net/tun")
@ -33,7 +33,7 @@ func newTunFromFd(c *config.C, l *logrus.Logger, deviceFd int, cidr netip.Prefix
t := &tun{ t := &tun{
ReadWriteCloser: file, ReadWriteCloser: file,
fd: deviceFd, fd: deviceFd,
cidr: cidr, vpnNetworks: vpnNetworks,
l: l, l: l,
} }
@ -52,7 +52,7 @@ func newTunFromFd(c *config.C, l *logrus.Logger, deviceFd int, cidr netip.Prefix
return t, nil return t, nil
} }
func newTun(_ *config.C, _ *logrus.Logger, _ netip.Prefix, _ bool) (*tun, error) { func newTun(_ *config.C, _ *logrus.Logger, _ []netip.Prefix, _ bool) (*tun, error) {
return nil, fmt.Errorf("newTun not supported in Android") return nil, fmt.Errorf("newTun not supported in Android")
} }
@ -66,7 +66,7 @@ func (t tun) Activate() error {
} }
func (t *tun) reload(c *config.C, initial bool) error { func (t *tun) reload(c *config.C, initial bool) error {
change, routes, err := getAllRoutesFromConfig(c, t.cidr, initial) change, routes, err := getAllRoutesFromConfig(c, t.vpnNetworks, initial)
if err != nil { if err != nil {
return err return err
} }
@ -86,8 +86,8 @@ func (t *tun) reload(c *config.C, initial bool) error {
return nil return nil
} }
func (t *tun) Cidr() netip.Prefix { func (t *tun) Networks() []netip.Prefix {
return t.cidr return t.vpnNetworks
} }
func (t *tun) Name() string { func (t *tun) Name() string {

View File

@ -25,7 +25,7 @@ import (
type tun struct { type tun struct {
io.ReadWriteCloser io.ReadWriteCloser
Device string Device string
cidr netip.Prefix vpnNetworks []netip.Prefix
DefaultMTU int DefaultMTU int
Routes atomic.Pointer[[]Route] Routes atomic.Pointer[[]Route]
routeTree atomic.Pointer[bart.Table[netip.Addr]] routeTree atomic.Pointer[bart.Table[netip.Addr]]
@ -36,44 +36,50 @@ type tun struct {
out []byte out []byte
} }
type sockaddrCtl struct {
scLen uint8
scFamily uint8
ssSysaddr uint16
scID uint32
scUnit uint32
scReserved [5]uint32
}
type ifReq struct { type ifReq struct {
Name [16]byte Name [unix.IFNAMSIZ]byte
Flags uint16 Flags uint16
pad [8]byte pad [8]byte
} }
var sockaddrCtlSize uintptr = 32
const ( const (
_SYSPROTO_CONTROL = 2 //define SYSPROTO_CONTROL 2 /* kernel control protocol */ _SIOCAIFADDR_IN6 = 2155899162
_AF_SYS_CONTROL = 2 //#define AF_SYS_CONTROL 2 /* corresponding sub address type */ _UTUN_OPT_IFNAME = 2
_PF_SYSTEM = unix.AF_SYSTEM //#define PF_SYSTEM AF_SYSTEM _IN6_IFF_NODAD = 0x0020
_CTLIOCGINFO = 3227799043 //#define CTLIOCGINFO _IOWR('N', 3, struct ctl_info) _IN6_IFF_SECURED = 0x0400
utunControlName = "com.apple.net.utun_control" utunControlName = "com.apple.net.utun_control"
) )
type ifreqAddr struct {
Name [16]byte
Addr unix.RawSockaddrInet4
pad [8]byte
}
type ifreqMTU struct { type ifreqMTU struct {
Name [16]byte Name [16]byte
MTU int32 MTU int32
pad [8]byte pad [8]byte
} }
func newTun(c *config.C, l *logrus.Logger, cidr netip.Prefix, _ bool) (*tun, error) { type addrLifetime struct {
Expire float64
Preferred float64
Vltime uint32
Pltime uint32
}
type ifreqAlias4 struct {
Name [unix.IFNAMSIZ]byte
Addr unix.RawSockaddrInet4
DstAddr unix.RawSockaddrInet4
MaskAddr unix.RawSockaddrInet4
}
type ifreqAlias6 struct {
Name [unix.IFNAMSIZ]byte
Addr unix.RawSockaddrInet6
DstAddr unix.RawSockaddrInet6
PrefixMask unix.RawSockaddrInet6
Flags uint32
Lifetime addrLifetime
}
func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, _ bool) (*tun, error) {
name := c.GetString("tun.dev", "") name := c.GetString("tun.dev", "")
ifIndex := -1 ifIndex := -1
if name != "" && name != "utun" { if name != "" && name != "utun" {
@ -86,66 +92,41 @@ func newTun(c *config.C, l *logrus.Logger, cidr netip.Prefix, _ bool) (*tun, err
} }
} }
fd, err := unix.Socket(_PF_SYSTEM, unix.SOCK_DGRAM, _SYSPROTO_CONTROL) fd, err := unix.Socket(unix.AF_SYSTEM, unix.SOCK_DGRAM, unix.AF_SYS_CONTROL)
if err != nil { if err != nil {
return nil, fmt.Errorf("system socket: %v", err) return nil, fmt.Errorf("system socket: %v", err)
} }
var ctlInfo = &struct { var ctlInfo = &unix.CtlInfo{}
ctlID uint32 copy(ctlInfo.Name[:], utunControlName)
ctlName [96]byte
}{}
copy(ctlInfo.ctlName[:], utunControlName) err = unix.IoctlCtlInfo(fd, ctlInfo)
err = ioctl(uintptr(fd), uintptr(_CTLIOCGINFO), uintptr(unsafe.Pointer(ctlInfo)))
if err != nil { if err != nil {
return nil, fmt.Errorf("CTLIOCGINFO: %v", err) return nil, fmt.Errorf("CTLIOCGINFO: %v", err)
} }
sc := sockaddrCtl{ err = unix.Connect(fd, &unix.SockaddrCtl{
scLen: uint8(sockaddrCtlSize), ID: ctlInfo.Id,
scFamily: unix.AF_SYSTEM, Unit: uint32(ifIndex) + 1,
ssSysaddr: _AF_SYS_CONTROL, })
scID: ctlInfo.ctlID, if err != nil {
scUnit: uint32(ifIndex) + 1, return nil, fmt.Errorf("SYS_CONNECT: %v", err)
} }
_, _, errno := unix.RawSyscall( name, err = unix.GetsockoptString(fd, unix.AF_SYS_CONTROL, _UTUN_OPT_IFNAME)
unix.SYS_CONNECT, if err != nil {
uintptr(fd), return nil, fmt.Errorf("failed to retrieve tun name: %w", err)
uintptr(unsafe.Pointer(&sc)),
sockaddrCtlSize,
)
if errno != 0 {
return nil, fmt.Errorf("SYS_CONNECT: %v", errno)
} }
var ifName struct { err = unix.SetNonblock(fd, true)
name [16]byte
}
ifNameSize := uintptr(len(ifName.name))
_, _, errno = syscall.Syscall6(syscall.SYS_GETSOCKOPT, uintptr(fd),
2, // SYSPROTO_CONTROL
2, // UTUN_OPT_IFNAME
uintptr(unsafe.Pointer(&ifName)),
uintptr(unsafe.Pointer(&ifNameSize)), 0)
if errno != 0 {
return nil, fmt.Errorf("SYS_GETSOCKOPT: %v", errno)
}
name = string(ifName.name[:ifNameSize-1])
err = syscall.SetNonblock(fd, true)
if err != nil { if err != nil {
return nil, fmt.Errorf("SetNonblock: %v", err) return nil, fmt.Errorf("SetNonblock: %v", err)
} }
file := os.NewFile(uintptr(fd), "")
t := &tun{ t := &tun{
ReadWriteCloser: file, ReadWriteCloser: os.NewFile(uintptr(fd), ""),
Device: name, Device: name,
cidr: cidr, vpnNetworks: vpnNetworks,
DefaultMTU: c.GetInt("tun.mtu", DefaultMTU), DefaultMTU: c.GetInt("tun.mtu", DefaultMTU),
l: l, l: l,
} }
@ -172,7 +153,7 @@ func (t *tun) deviceBytes() (o [16]byte) {
return return
} }
func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ netip.Prefix) (*tun, error) { func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ []netip.Prefix) (*tun, error) {
return nil, fmt.Errorf("newTunFromFd not supported in Darwin") return nil, fmt.Errorf("newTunFromFd not supported in Darwin")
} }
@ -186,16 +167,6 @@ func (t *tun) Close() error {
func (t *tun) Activate() error { func (t *tun) Activate() error {
devName := t.deviceBytes() devName := t.deviceBytes()
var addr, mask [4]byte
if !t.cidr.Addr().Is4() {
//TODO: IPV6-WORK
panic("need ipv6")
}
addr = t.cidr.Addr().As4()
copy(mask[:], prefixToMask(t.cidr))
s, err := unix.Socket( s, err := unix.Socket(
unix.AF_INET, unix.AF_INET,
unix.SOCK_DGRAM, unix.SOCK_DGRAM,
@ -208,66 +179,18 @@ func (t *tun) Activate() error {
fd := uintptr(s) fd := uintptr(s)
ifra := ifreqAddr{
Name: devName,
Addr: unix.RawSockaddrInet4{
Family: unix.AF_INET,
Addr: addr,
},
}
// Set the device ip address
if err = ioctl(fd, unix.SIOCSIFADDR, uintptr(unsafe.Pointer(&ifra))); err != nil {
return fmt.Errorf("failed to set tun address: %s", err)
}
// Set the device network
ifra.Addr.Addr = mask
if err = ioctl(fd, unix.SIOCSIFNETMASK, uintptr(unsafe.Pointer(&ifra))); err != nil {
return fmt.Errorf("failed to set tun netmask: %s", err)
}
// Set the device name
ifrf := ifReq{Name: devName}
if err = ioctl(fd, unix.SIOCGIFFLAGS, uintptr(unsafe.Pointer(&ifrf))); err != nil {
return fmt.Errorf("failed to set tun device name: %s", err)
}
// Set the MTU on the device // Set the MTU on the device
ifm := ifreqMTU{Name: devName, MTU: int32(t.DefaultMTU)} ifm := ifreqMTU{Name: devName, MTU: int32(t.DefaultMTU)}
if err = ioctl(fd, unix.SIOCSIFMTU, uintptr(unsafe.Pointer(&ifm))); err != nil { if err = ioctl(fd, unix.SIOCSIFMTU, uintptr(unsafe.Pointer(&ifm))); err != nil {
return fmt.Errorf("failed to set tun mtu: %v", err) return fmt.Errorf("failed to set tun mtu: %v", err)
} }
/* // Get the device flags
// Set the transmit queue length ifrf := ifReq{Name: devName}
ifrq := ifreqQLEN{Name: devName, Value: int32(t.TXQueueLen)} if err = ioctl(fd, unix.SIOCGIFFLAGS, uintptr(unsafe.Pointer(&ifrf))); err != nil {
if err = ioctl(fd, unix.SIOCSIFTXQLEN, uintptr(unsafe.Pointer(&ifrq))); err != nil { return fmt.Errorf("failed to get tun flags: %s", err)
// If we can't set the queue length nebula will still work but it may lead to packet loss
l.WithError(err).Error("Failed to set tun tx queue length")
}
*/
// Bring up the interface
ifrf.Flags = ifrf.Flags | unix.IFF_UP
if err = ioctl(fd, unix.SIOCSIFFLAGS, uintptr(unsafe.Pointer(&ifrf))); err != nil {
return fmt.Errorf("failed to bring the tun device up: %s", err)
} }
routeSock, err := unix.Socket(unix.AF_ROUTE, unix.SOCK_RAW, unix.AF_UNSPEC)
if err != nil {
return fmt.Errorf("unable to create AF_ROUTE socket: %v", err)
}
defer func() {
unix.Shutdown(routeSock, unix.SHUT_RDWR)
err := unix.Close(routeSock)
if err != nil {
t.l.WithError(err).Error("failed to close AF_ROUTE socket")
}
}()
routeAddr := &netroute.Inet4Addr{}
maskAddr := &netroute.Inet4Addr{}
linkAddr, err := getLinkAddr(t.Device) linkAddr, err := getLinkAddr(t.Device)
if err != nil { if err != nil {
return err return err
@ -277,15 +200,19 @@ func (t *tun) Activate() error {
} }
t.linkAddr = linkAddr t.linkAddr = linkAddr
copy(routeAddr.IP[:], addr[:]) for _, network := range t.vpnNetworks {
copy(maskAddr.IP[:], mask[:]) if network.Addr().Is4() {
err = addRoute(routeSock, routeAddr, maskAddr, linkAddr) err = t.activate4(network)
if err != nil { if err != nil {
if errors.Is(err, unix.EEXIST) {
err = fmt.Errorf("unable to add tun route, identical route already exists: %s", t.cidr)
}
return err return err
} }
} else {
err = t.activate6(network)
if err != nil {
return err
}
}
}
// Run the interface // Run the interface
ifrf.Flags = ifrf.Flags | unix.IFF_UP | unix.IFF_RUNNING ifrf.Flags = ifrf.Flags | unix.IFF_UP | unix.IFF_RUNNING
@ -297,8 +224,89 @@ func (t *tun) Activate() error {
return t.addRoutes(false) return t.addRoutes(false)
} }
func (t *tun) activate4(network netip.Prefix) error {
s, err := unix.Socket(
unix.AF_INET,
unix.SOCK_DGRAM,
unix.IPPROTO_IP,
)
if err != nil {
return err
}
defer unix.Close(s)
ifr := ifreqAlias4{
Name: t.deviceBytes(),
Addr: unix.RawSockaddrInet4{
Len: unix.SizeofSockaddrInet4,
Family: unix.AF_INET,
Addr: network.Addr().As4(),
},
DstAddr: unix.RawSockaddrInet4{
Len: unix.SizeofSockaddrInet4,
Family: unix.AF_INET,
Addr: network.Addr().As4(),
},
MaskAddr: unix.RawSockaddrInet4{
Len: unix.SizeofSockaddrInet4,
Family: unix.AF_INET,
Addr: prefixToMask(network).As4(),
},
}
if err := ioctl(uintptr(s), unix.SIOCAIFADDR, uintptr(unsafe.Pointer(&ifr))); err != nil {
return fmt.Errorf("failed to set tun v4 address: %s", err)
}
err = addRoute(network, t.linkAddr)
if err != nil {
return err
}
return nil
}
func (t *tun) activate6(network netip.Prefix) error {
s, err := unix.Socket(
unix.AF_INET6,
unix.SOCK_DGRAM,
unix.IPPROTO_IP,
)
if err != nil {
return err
}
defer unix.Close(s)
ifr := ifreqAlias6{
Name: t.deviceBytes(),
Addr: unix.RawSockaddrInet6{
Len: unix.SizeofSockaddrInet6,
Family: unix.AF_INET6,
Addr: network.Addr().As16(),
},
PrefixMask: unix.RawSockaddrInet6{
Len: unix.SizeofSockaddrInet6,
Family: unix.AF_INET6,
Addr: prefixToMask(network).As16(),
},
Lifetime: addrLifetime{
// never expires
Vltime: 0xffffffff,
Pltime: 0xffffffff,
},
//TODO: should we disable DAD (duplicate address detection) and mark this as a secured address?
Flags: _IN6_IFF_NODAD,
}
if err := ioctl(uintptr(s), _SIOCAIFADDR_IN6, uintptr(unsafe.Pointer(&ifr))); err != nil {
return fmt.Errorf("failed to set tun address: %s", err)
}
return nil
}
func (t *tun) reload(c *config.C, initial bool) error { func (t *tun) reload(c *config.C, initial bool) error {
change, routes, err := getAllRoutesFromConfig(c, t.cidr, initial) change, routes, err := getAllRoutesFromConfig(c, t.vpnNetworks, initial)
if err != nil { if err != nil {
return err return err
} }
@ -371,38 +379,15 @@ func getLinkAddr(name string) (*netroute.LinkAddr, error) {
} }
func (t *tun) addRoutes(logErrors bool) error { func (t *tun) addRoutes(logErrors bool) error {
routeSock, err := unix.Socket(unix.AF_ROUTE, unix.SOCK_RAW, unix.AF_UNSPEC)
if err != nil {
return fmt.Errorf("unable to create AF_ROUTE socket: %v", err)
}
defer func() {
unix.Shutdown(routeSock, unix.SHUT_RDWR)
err := unix.Close(routeSock)
if err != nil {
t.l.WithError(err).Error("failed to close AF_ROUTE socket")
}
}()
routeAddr := &netroute.Inet4Addr{}
maskAddr := &netroute.Inet4Addr{}
routes := *t.Routes.Load() routes := *t.Routes.Load()
for _, r := range routes { for _, r := range routes {
if !r.Via.IsValid() || !r.Install { if !r.Via.IsValid() || !r.Install {
// We don't allow route MTUs so only install routes with a via // We don't allow route MTUs so only install routes with a via
continue continue
} }
if !r.Cidr.Addr().Is4() { err := addRoute(r.Cidr, t.linkAddr)
//TODO: implement ipv6
panic("Cant handle ipv6 routes yet")
}
routeAddr.IP = r.Cidr.Addr().As4()
//TODO: we could avoid the copy
copy(maskAddr.IP[:], prefixToMask(r.Cidr))
err := addRoute(routeSock, routeAddr, maskAddr, t.linkAddr)
if err != nil { if err != nil {
if errors.Is(err, unix.EEXIST) { if errors.Is(err, unix.EEXIST) {
t.l.WithField("route", r.Cidr). t.l.WithField("route", r.Cidr).
@ -424,36 +409,12 @@ func (t *tun) addRoutes(logErrors bool) error {
} }
func (t *tun) removeRoutes(routes []Route) error { func (t *tun) removeRoutes(routes []Route) error {
routeSock, err := unix.Socket(unix.AF_ROUTE, unix.SOCK_RAW, unix.AF_UNSPEC)
if err != nil {
return fmt.Errorf("unable to create AF_ROUTE socket: %v", err)
}
defer func() {
unix.Shutdown(routeSock, unix.SHUT_RDWR)
err := unix.Close(routeSock)
if err != nil {
t.l.WithError(err).Error("failed to close AF_ROUTE socket")
}
}()
routeAddr := &netroute.Inet4Addr{}
maskAddr := &netroute.Inet4Addr{}
for _, r := range routes { for _, r := range routes {
if !r.Install { if !r.Install {
continue continue
} }
if r.Cidr.Addr().Is6() { err := delRoute(r.Cidr, t.linkAddr)
//TODO: implement ipv6
panic("Cant handle ipv6 routes yet")
}
routeAddr.IP = r.Cidr.Addr().As4()
copy(maskAddr.IP[:], prefixToMask(r.Cidr))
err := delRoute(routeSock, routeAddr, maskAddr, t.linkAddr)
if err != nil { if err != nil {
t.l.WithError(err).WithField("route", r).Error("Failed to remove route") t.l.WithError(err).WithField("route", r).Error("Failed to remove route")
} else { } else {
@ -463,23 +424,39 @@ func (t *tun) removeRoutes(routes []Route) error {
return nil return nil
} }
func addRoute(sock int, addr, mask *netroute.Inet4Addr, link *netroute.LinkAddr) error { func addRoute(prefix netip.Prefix, gateway netroute.Addr) error {
r := netroute.RouteMessage{ sock, err := unix.Socket(unix.AF_ROUTE, unix.SOCK_RAW, unix.AF_UNSPEC)
if err != nil {
return fmt.Errorf("unable to create AF_ROUTE socket: %v", err)
}
defer unix.Close(sock)
route := &netroute.RouteMessage{
Version: unix.RTM_VERSION, Version: unix.RTM_VERSION,
Type: unix.RTM_ADD, Type: unix.RTM_ADD,
Flags: unix.RTF_UP, Flags: unix.RTF_UP,
Seq: 1, Seq: 1,
Addrs: []netroute.Addr{
unix.RTAX_DST: addr,
unix.RTAX_GATEWAY: link,
unix.RTAX_NETMASK: mask,
},
} }
data, err := r.Marshal() if prefix.Addr().Is4() {
route.Addrs = []netroute.Addr{
unix.RTAX_DST: &netroute.Inet4Addr{IP: prefix.Masked().Addr().As4()},
unix.RTAX_NETMASK: &netroute.Inet4Addr{IP: prefixToMask(prefix).As4()},
unix.RTAX_GATEWAY: gateway,
}
} else {
route.Addrs = []netroute.Addr{
unix.RTAX_DST: &netroute.Inet6Addr{IP: prefix.Masked().Addr().As16()},
unix.RTAX_NETMASK: &netroute.Inet6Addr{IP: prefixToMask(prefix).As16()},
unix.RTAX_GATEWAY: gateway,
}
}
data, err := route.Marshal()
if err != nil { if err != nil {
return fmt.Errorf("failed to create route.RouteMessage: %w", err) return fmt.Errorf("failed to create route.RouteMessage: %w", err)
} }
_, err = unix.Write(sock, data[:]) _, err = unix.Write(sock, data[:])
if err != nil { if err != nil {
return fmt.Errorf("failed to write route.RouteMessage to socket: %w", err) return fmt.Errorf("failed to write route.RouteMessage to socket: %w", err)
@ -488,19 +465,34 @@ func addRoute(sock int, addr, mask *netroute.Inet4Addr, link *netroute.LinkAddr)
return nil return nil
} }
func delRoute(sock int, addr, mask *netroute.Inet4Addr, link *netroute.LinkAddr) error { func delRoute(prefix netip.Prefix, gateway netroute.Addr) error {
r := netroute.RouteMessage{ sock, err := unix.Socket(unix.AF_ROUTE, unix.SOCK_RAW, unix.AF_UNSPEC)
if err != nil {
return fmt.Errorf("unable to create AF_ROUTE socket: %v", err)
}
defer unix.Close(sock)
route := netroute.RouteMessage{
Version: unix.RTM_VERSION, Version: unix.RTM_VERSION,
Type: unix.RTM_DELETE, Type: unix.RTM_DELETE,
Seq: 1, Seq: 1,
Addrs: []netroute.Addr{
unix.RTAX_DST: addr,
unix.RTAX_GATEWAY: link,
unix.RTAX_NETMASK: mask,
},
} }
data, err := r.Marshal() if prefix.Addr().Is4() {
route.Addrs = []netroute.Addr{
unix.RTAX_DST: &netroute.Inet4Addr{IP: prefix.Masked().Addr().As4()},
unix.RTAX_NETMASK: &netroute.Inet4Addr{IP: prefixToMask(prefix).As4()},
unix.RTAX_GATEWAY: gateway,
}
} else {
route.Addrs = []netroute.Addr{
unix.RTAX_DST: &netroute.Inet6Addr{IP: prefix.Masked().Addr().As16()},
unix.RTAX_NETMASK: &netroute.Inet6Addr{IP: prefixToMask(prefix).As16()},
unix.RTAX_GATEWAY: gateway,
}
}
data, err := route.Marshal()
if err != nil { if err != nil {
return fmt.Errorf("failed to create route.RouteMessage: %w", err) return fmt.Errorf("failed to create route.RouteMessage: %w", err)
} }
@ -513,7 +505,6 @@ func delRoute(sock int, addr, mask *netroute.Inet4Addr, link *netroute.LinkAddr)
} }
func (t *tun) Read(to []byte) (int, error) { func (t *tun) Read(to []byte) (int, error) {
buf := make([]byte, len(to)+4) buf := make([]byte, len(to)+4)
n, err := t.ReadWriteCloser.Read(buf) n, err := t.ReadWriteCloser.Read(buf)
@ -551,8 +542,8 @@ func (t *tun) Write(from []byte) (int, error) {
return n - 4, err return n - 4, err
} }
func (t *tun) Cidr() netip.Prefix { func (t *tun) Networks() []netip.Prefix {
return t.cidr return t.vpnNetworks
} }
func (t *tun) Name() string { func (t *tun) Name() string {
@ -563,10 +554,12 @@ func (t *tun) NewMultiQueueReader() (io.ReadWriteCloser, error) {
return nil, fmt.Errorf("TODO: multiqueue not implemented for darwin") return nil, fmt.Errorf("TODO: multiqueue not implemented for darwin")
} }
func prefixToMask(prefix netip.Prefix) []byte { func prefixToMask(prefix netip.Prefix) netip.Addr {
pLen := 128 pLen := 128
if prefix.Addr().Is4() { if prefix.Addr().Is4() {
pLen = 32 pLen = 32
} }
return net.CIDRMask(prefix.Bits(), pLen)
addr, _ := netip.AddrFromSlice(net.CIDRMask(prefix.Bits(), pLen))
return addr
} }

View File

@ -13,7 +13,7 @@ import (
type disabledTun struct { type disabledTun struct {
read chan []byte read chan []byte
cidr netip.Prefix vpnNetworks []netip.Prefix
// Track these metrics since we don't have the tun device to do it for us // Track these metrics since we don't have the tun device to do it for us
tx metrics.Counter tx metrics.Counter
@ -21,9 +21,9 @@ type disabledTun struct {
l *logrus.Logger l *logrus.Logger
} }
func newDisabledTun(cidr netip.Prefix, queueLen int, metricsEnabled bool, l *logrus.Logger) *disabledTun { func newDisabledTun(vpnNetworks []netip.Prefix, queueLen int, metricsEnabled bool, l *logrus.Logger) *disabledTun {
tun := &disabledTun{ tun := &disabledTun{
cidr: cidr, vpnNetworks: vpnNetworks,
read: make(chan []byte, queueLen), read: make(chan []byte, queueLen),
l: l, l: l,
} }
@ -47,8 +47,8 @@ func (*disabledTun) RouteFor(addr netip.Addr) netip.Addr {
return netip.Addr{} return netip.Addr{}
} }
func (t *disabledTun) Cidr() netip.Prefix { func (t *disabledTun) Networks() []netip.Prefix {
return t.cidr return t.vpnNetworks
} }
func (*disabledTun) Name() string { func (*disabledTun) Name() string {

View File

@ -47,7 +47,7 @@ type ifreqDestroy struct {
type tun struct { type tun struct {
Device string Device string
cidr netip.Prefix vpnNetworks []netip.Prefix
MTU int MTU int
Routes atomic.Pointer[[]Route] Routes atomic.Pointer[[]Route]
routeTree atomic.Pointer[bart.Table[netip.Addr]] routeTree atomic.Pointer[bart.Table[netip.Addr]]
@ -78,11 +78,11 @@ func (t *tun) Close() error {
return nil return nil
} }
func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ netip.Prefix) (*tun, error) { func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ []netip.Prefix) (*tun, error) {
return nil, fmt.Errorf("newTunFromFd not supported in FreeBSD") return nil, fmt.Errorf("newTunFromFd not supported in FreeBSD")
} }
func newTun(c *config.C, l *logrus.Logger, cidr netip.Prefix, _ bool) (*tun, error) { func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, _ bool) (*tun, error) {
// Try to open existing tun device // Try to open existing tun device
var file *os.File var file *os.File
var err error var err error
@ -150,7 +150,7 @@ func newTun(c *config.C, l *logrus.Logger, cidr netip.Prefix, _ bool) (*tun, err
t := &tun{ t := &tun{
ReadWriteCloser: file, ReadWriteCloser: file,
Device: deviceName, Device: deviceName,
cidr: cidr, vpnNetworks: vpnNetworks,
MTU: c.GetInt("tun.mtu", DefaultMTU), MTU: c.GetInt("tun.mtu", DefaultMTU),
l: l, l: l,
} }
@ -170,16 +170,16 @@ func newTun(c *config.C, l *logrus.Logger, cidr netip.Prefix, _ bool) (*tun, err
return t, nil return t, nil
} }
func (t *tun) Activate() error { func (t *tun) addIp(cidr netip.Prefix) error {
var err error var err error
// TODO use syscalls instead of exec.Command // TODO use syscalls instead of exec.Command
cmd := exec.Command("/sbin/ifconfig", t.Device, t.cidr.String(), t.cidr.Addr().String()) cmd := exec.Command("/sbin/ifconfig", t.Device, cidr.String(), cidr.Addr().String())
t.l.Debug("command: ", cmd.String()) t.l.Debug("command: ", cmd.String())
if err = cmd.Run(); err != nil { if err = cmd.Run(); err != nil {
return fmt.Errorf("failed to run 'ifconfig': %s", err) return fmt.Errorf("failed to run 'ifconfig': %s", err)
} }
cmd = exec.Command("/sbin/route", "-n", "add", "-net", t.cidr.String(), "-interface", t.Device) cmd = exec.Command("/sbin/route", "-n", "add", "-net", cidr.String(), "-interface", t.Device)
t.l.Debug("command: ", cmd.String()) t.l.Debug("command: ", cmd.String())
if err = cmd.Run(); err != nil { if err = cmd.Run(); err != nil {
return fmt.Errorf("failed to run 'route add': %s", err) return fmt.Errorf("failed to run 'route add': %s", err)
@ -195,8 +195,18 @@ func (t *tun) Activate() error {
return t.addRoutes(false) return t.addRoutes(false)
} }
func (t *tun) Activate() error {
for i := range t.vpnNetworks {
err := t.addIp(t.vpnNetworks[i])
if err != nil {
return err
}
}
return nil
}
func (t *tun) reload(c *config.C, initial bool) error { func (t *tun) reload(c *config.C, initial bool) error {
change, routes, err := getAllRoutesFromConfig(c, t.cidr, initial) change, routes, err := getAllRoutesFromConfig(c, t.vpnNetworks, initial)
if err != nil { if err != nil {
return err return err
} }
@ -237,8 +247,8 @@ func (t *tun) RouteFor(ip netip.Addr) netip.Addr {
return r return r
} }
func (t *tun) Cidr() netip.Prefix { func (t *tun) Networks() []netip.Prefix {
return t.cidr return t.vpnNetworks
} }
func (t *tun) Name() string { func (t *tun) Name() string {

View File

@ -21,20 +21,20 @@ import (
type tun struct { type tun struct {
io.ReadWriteCloser io.ReadWriteCloser
cidr netip.Prefix vpnNetworks []netip.Prefix
Routes atomic.Pointer[[]Route] Routes atomic.Pointer[[]Route]
routeTree atomic.Pointer[bart.Table[netip.Addr]] routeTree atomic.Pointer[bart.Table[netip.Addr]]
l *logrus.Logger l *logrus.Logger
} }
func newTun(_ *config.C, _ *logrus.Logger, _ netip.Prefix, _ bool) (*tun, error) { func newTun(_ *config.C, _ *logrus.Logger, _ []netip.Prefix, _ bool) (*tun, error) {
return nil, fmt.Errorf("newTun not supported in iOS") return nil, fmt.Errorf("newTun not supported in iOS")
} }
func newTunFromFd(c *config.C, l *logrus.Logger, deviceFd int, cidr netip.Prefix) (*tun, error) { func newTunFromFd(c *config.C, l *logrus.Logger, deviceFd int, vpnNetworks []netip.Prefix) (*tun, error) {
file := os.NewFile(uintptr(deviceFd), "/dev/tun") file := os.NewFile(uintptr(deviceFd), "/dev/tun")
t := &tun{ t := &tun{
cidr: cidr, vpnNetworks: vpnNetworks,
ReadWriteCloser: &tunReadCloser{f: file}, ReadWriteCloser: &tunReadCloser{f: file},
l: l, l: l,
} }
@ -59,7 +59,7 @@ func (t *tun) Activate() error {
} }
func (t *tun) reload(c *config.C, initial bool) error { func (t *tun) reload(c *config.C, initial bool) error {
change, routes, err := getAllRoutesFromConfig(c, t.cidr, initial) change, routes, err := getAllRoutesFromConfig(c, t.vpnNetworks, initial)
if err != nil { if err != nil {
return err return err
} }
@ -142,8 +142,8 @@ func (tr *tunReadCloser) Close() error {
return tr.f.Close() return tr.f.Close()
} }
func (t *tun) Cidr() netip.Prefix { func (t *tun) Networks() []netip.Prefix {
return t.cidr return t.vpnNetworks
} }
func (t *tun) Name() string { func (t *tun) Name() string {

View File

@ -25,7 +25,7 @@ type tun struct {
io.ReadWriteCloser io.ReadWriteCloser
fd int fd int
Device string Device string
cidr netip.Prefix vpnNetworks []netip.Prefix
MaxMTU int MaxMTU int
DefaultMTU int DefaultMTU int
TXQueueLen int TXQueueLen int
@ -40,18 +40,16 @@ type tun struct {
l *logrus.Logger l *logrus.Logger
} }
func (t *tun) Networks() []netip.Prefix {
return t.vpnNetworks
}
type ifReq struct { type ifReq struct {
Name [16]byte Name [16]byte
Flags uint16 Flags uint16
pad [8]byte pad [8]byte
} }
type ifreqAddr struct {
Name [16]byte
Addr unix.RawSockaddrInet4
pad [8]byte
}
type ifreqMTU struct { type ifreqMTU struct {
Name [16]byte Name [16]byte
MTU int32 MTU int32
@ -64,10 +62,10 @@ type ifreqQLEN struct {
pad [8]byte pad [8]byte
} }
func newTunFromFd(c *config.C, l *logrus.Logger, deviceFd int, cidr netip.Prefix) (*tun, error) { func newTunFromFd(c *config.C, l *logrus.Logger, deviceFd int, vpnNetworks []netip.Prefix) (*tun, error) {
file := os.NewFile(uintptr(deviceFd), "/dev/net/tun") file := os.NewFile(uintptr(deviceFd), "/dev/net/tun")
t, err := newTunGeneric(c, l, file, cidr) t, err := newTunGeneric(c, l, file, vpnNetworks)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -77,7 +75,7 @@ func newTunFromFd(c *config.C, l *logrus.Logger, deviceFd int, cidr netip.Prefix
return t, nil return t, nil
} }
func newTun(c *config.C, l *logrus.Logger, cidr netip.Prefix, multiqueue bool) (*tun, error) { func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, multiqueue bool) (*tun, error) {
fd, err := unix.Open("/dev/net/tun", os.O_RDWR, 0) fd, err := unix.Open("/dev/net/tun", os.O_RDWR, 0)
if err != nil { if err != nil {
// If /dev/net/tun doesn't exist, try to create it (will happen in docker) // If /dev/net/tun doesn't exist, try to create it (will happen in docker)
@ -112,7 +110,7 @@ func newTun(c *config.C, l *logrus.Logger, cidr netip.Prefix, multiqueue bool) (
name := strings.Trim(string(req.Name[:]), "\x00") name := strings.Trim(string(req.Name[:]), "\x00")
file := os.NewFile(uintptr(fd), "/dev/net/tun") file := os.NewFile(uintptr(fd), "/dev/net/tun")
t, err := newTunGeneric(c, l, file, cidr) t, err := newTunGeneric(c, l, file, vpnNetworks)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -122,11 +120,11 @@ func newTun(c *config.C, l *logrus.Logger, cidr netip.Prefix, multiqueue bool) (
return t, nil return t, nil
} }
func newTunGeneric(c *config.C, l *logrus.Logger, file *os.File, cidr netip.Prefix) (*tun, error) { func newTunGeneric(c *config.C, l *logrus.Logger, file *os.File, vpnNetworks []netip.Prefix) (*tun, error) {
t := &tun{ t := &tun{
ReadWriteCloser: file, ReadWriteCloser: file,
fd: int(file.Fd()), fd: int(file.Fd()),
cidr: cidr, vpnNetworks: vpnNetworks,
TXQueueLen: c.GetInt("tun.tx_queue", 500), TXQueueLen: c.GetInt("tun.tx_queue", 500),
useSystemRoutes: c.GetBool("tun.use_system_route_table", false), useSystemRoutes: c.GetBool("tun.use_system_route_table", false),
l: l, l: l,
@ -148,7 +146,7 @@ func newTunGeneric(c *config.C, l *logrus.Logger, file *os.File, cidr netip.Pref
} }
func (t *tun) reload(c *config.C, initial bool) error { func (t *tun) reload(c *config.C, initial bool) error {
routeChange, routes, err := getAllRoutesFromConfig(c, t.cidr, initial) routeChange, routes, err := getAllRoutesFromConfig(c, t.vpnNetworks, initial)
if err != nil { if err != nil {
return err return err
} }
@ -190,13 +188,15 @@ func (t *tun) reload(c *config.C, initial bool) error {
} }
if oldDefaultMTU != newDefaultMTU { if oldDefaultMTU != newDefaultMTU {
err := t.setDefaultRoute() for i := range t.vpnNetworks {
err := t.setDefaultRoute(t.vpnNetworks[i])
if err != nil { if err != nil {
t.l.Warn(err) t.l.Warn(err)
} else { } else {
t.l.Infof("Set default MTU to %v was %v", t.DefaultMTU, oldDefaultMTU) t.l.Infof("Set default MTU to %v was %v", t.DefaultMTU, oldDefaultMTU)
} }
} }
}
// Remove first, if the system removes a wanted route hopefully it will be re-added next // Remove first, if the system removes a wanted route hopefully it will be re-added next
t.removeRoutes(findRemovedRoutes(routes, *oldRoutes)) t.removeRoutes(findRemovedRoutes(routes, *oldRoutes))
@ -237,10 +237,10 @@ func (t *tun) RouteFor(ip netip.Addr) netip.Addr {
func (t *tun) Write(b []byte) (int, error) { func (t *tun) Write(b []byte) (int, error) {
var nn int var nn int
max := len(b) maximum := len(b)
for { for {
n, err := unix.Write(t.fd, b[nn:max]) n, err := unix.Write(t.fd, b[nn:maximum])
if n > 0 { if n > 0 {
nn += n nn += n
} }
@ -265,6 +265,58 @@ func (t *tun) deviceBytes() (o [16]byte) {
return return
} }
func hasNetlinkAddr(al []*netlink.Addr, x netlink.Addr) bool {
for i := range al {
if al[i].Equal(x) {
return true
}
}
return false
}
// addIPs uses netlink to add all addresses that don't exist, then it removes ones that should not be there
func (t *tun) addIPs(link netlink.Link) error {
newAddrs := make([]*netlink.Addr, len(t.vpnNetworks))
for i := range t.vpnNetworks {
newAddrs[i] = &netlink.Addr{
IPNet: &net.IPNet{
IP: t.vpnNetworks[i].Addr().AsSlice(),
Mask: net.CIDRMask(t.vpnNetworks[i].Bits(), t.vpnNetworks[i].Addr().BitLen()),
},
Label: t.vpnNetworks[i].Addr().Zone(),
}
}
//add all new addresses
for i := range newAddrs {
//todo do we want to stack errors and try as many ops as possible?
//AddrReplace still adds new IPs, but if their properties change it will change them as well
if err := netlink.AddrReplace(link, newAddrs[i]); err != nil {
return err
}
}
//iterate over remainder, remove whoever shouldn't be there
al, err := netlink.AddrList(link, netlink.FAMILY_ALL)
if err != nil {
return fmt.Errorf("failed to get tun address list: %s", err)
}
for i := range al {
if hasNetlinkAddr(newAddrs, al[i]) {
continue
}
err = netlink.AddrDel(link, &al[i])
if err != nil {
t.l.WithError(err).Error("failed to remove address from tun address list")
} else {
t.l.WithField("removed", al[i].String()).Info("removed address not listed in cert(s)")
}
}
return nil
}
func (t *tun) Activate() error { func (t *tun) Activate() error {
devName := t.deviceBytes() devName := t.deviceBytes()
@ -272,15 +324,8 @@ func (t *tun) Activate() error {
t.watchRoutes() t.watchRoutes()
} }
var addr, mask [4]byte
//TODO: IPV6-WORK
addr = t.cidr.Addr().As4()
tmask := net.CIDRMask(t.cidr.Bits(), 32)
copy(mask[:], tmask)
s, err := unix.Socket( s, err := unix.Socket(
unix.AF_INET, unix.AF_INET, //because everything we use t.ioctlFd for is address family independent, this is fine
unix.SOCK_DGRAM, unix.SOCK_DGRAM,
unix.IPPROTO_IP, unix.IPPROTO_IP,
) )
@ -289,31 +334,19 @@ func (t *tun) Activate() error {
} }
t.ioctlFd = uintptr(s) t.ioctlFd = uintptr(s)
ifra := ifreqAddr{
Name: devName,
Addr: unix.RawSockaddrInet4{
Family: unix.AF_INET,
Addr: addr,
},
}
// Set the device ip address
if err = ioctl(t.ioctlFd, unix.SIOCSIFADDR, uintptr(unsafe.Pointer(&ifra))); err != nil {
return fmt.Errorf("failed to set tun address: %s", err)
}
// Set the device network
ifra.Addr.Addr = mask
if err = ioctl(t.ioctlFd, unix.SIOCSIFNETMASK, uintptr(unsafe.Pointer(&ifra))); err != nil {
return fmt.Errorf("failed to set tun netmask: %s", err)
}
// Set the device name // Set the device name
ifrf := ifReq{Name: devName} ifrf := ifReq{Name: devName}
if err = ioctl(t.ioctlFd, unix.SIOCGIFFLAGS, uintptr(unsafe.Pointer(&ifrf))); err != nil { if err = ioctl(t.ioctlFd, unix.SIOCGIFFLAGS, uintptr(unsafe.Pointer(&ifrf))); err != nil {
return fmt.Errorf("failed to set tun device name: %s", err) return fmt.Errorf("failed to set tun device name: %s", err)
} }
link, err := netlink.LinkByName(t.Device)
if err != nil {
return fmt.Errorf("failed to get tun device link: %s", err)
}
t.deviceIndex = link.Attrs().Index
// Setup our default MTU // Setup our default MTU
t.setMTU() t.setMTU()
@ -324,33 +357,36 @@ func (t *tun) Activate() error {
t.l.WithError(err).Error("Failed to set tun tx queue length") t.l.WithError(err).Error("Failed to set tun tx queue length")
} }
if err = t.addIPs(link); err != nil {
return err
}
// Bring up the interface // Bring up the interface
ifrf.Flags = ifrf.Flags | unix.IFF_UP ifrf.Flags = ifrf.Flags | unix.IFF_UP
if err = ioctl(t.ioctlFd, unix.SIOCSIFFLAGS, uintptr(unsafe.Pointer(&ifrf))); err != nil { if err = ioctl(t.ioctlFd, unix.SIOCSIFFLAGS, uintptr(unsafe.Pointer(&ifrf))); err != nil {
return fmt.Errorf("failed to bring the tun device up: %s", err) return fmt.Errorf("failed to bring the tun device up: %s", err)
} }
link, err := netlink.LinkByName(t.Device)
if err != nil {
return fmt.Errorf("failed to get tun device link: %s", err)
}
t.deviceIndex = link.Attrs().Index
if err = t.setDefaultRoute(); err != nil {
return err
}
// Set the routes
if err = t.addRoutes(false); err != nil {
return err
}
// Run the interface // Run the interface
ifrf.Flags = ifrf.Flags | unix.IFF_UP | unix.IFF_RUNNING ifrf.Flags = ifrf.Flags | unix.IFF_UP | unix.IFF_RUNNING
if err = ioctl(t.ioctlFd, unix.SIOCSIFFLAGS, uintptr(unsafe.Pointer(&ifrf))); err != nil { if err = ioctl(t.ioctlFd, unix.SIOCSIFFLAGS, uintptr(unsafe.Pointer(&ifrf))); err != nil {
return fmt.Errorf("failed to run tun device: %s", err) return fmt.Errorf("failed to run tun device: %s", err)
} }
//set route MTU
for i := range t.vpnNetworks {
if err = t.setDefaultRoute(t.vpnNetworks[i]); err != nil {
return fmt.Errorf("failed to set default route MTU: %w", err)
}
}
// Set the routes
if err = t.addRoutes(false); err != nil {
return err
}
//todo do we want to keep the link-local address?
return nil return nil
} }
@ -363,12 +399,12 @@ func (t *tun) setMTU() {
} }
} }
func (t *tun) setDefaultRoute() error { func (t *tun) setDefaultRoute(cidr netip.Prefix) error {
// Default route // Default route
dr := &net.IPNet{ dr := &net.IPNet{
IP: t.cidr.Masked().Addr().AsSlice(), IP: cidr.Masked().Addr().AsSlice(),
Mask: net.CIDRMask(t.cidr.Bits(), t.cidr.Addr().BitLen()), Mask: net.CIDRMask(cidr.Bits(), cidr.Addr().BitLen()),
} }
nr := netlink.Route{ nr := netlink.Route{
@ -377,7 +413,7 @@ func (t *tun) setDefaultRoute() error {
MTU: t.DefaultMTU, MTU: t.DefaultMTU,
AdvMSS: t.advMSS(Route{}), AdvMSS: t.advMSS(Route{}),
Scope: unix.RT_SCOPE_LINK, Scope: unix.RT_SCOPE_LINK,
Src: net.IP(t.cidr.Addr().AsSlice()), Src: net.IP(cidr.Addr().AsSlice()),
Protocol: unix.RTPROT_KERNEL, Protocol: unix.RTPROT_KERNEL,
Table: unix.RT_TABLE_MAIN, Table: unix.RT_TABLE_MAIN,
Type: unix.RTN_UNICAST, Type: unix.RTN_UNICAST,
@ -463,10 +499,6 @@ func (t *tun) removeRoutes(routes []Route) {
} }
} }
func (t *tun) Cidr() netip.Prefix {
return t.cidr
}
func (t *tun) Name() string { func (t *tun) Name() string {
return t.Device return t.Device
} }
@ -523,9 +555,16 @@ func (t *tun) updateRoutes(r netlink.RouteUpdate) {
} }
gwAddr = gwAddr.Unmap() gwAddr = gwAddr.Unmap()
if !t.cidr.Contains(gwAddr) { withinNetworks := false
for i := range t.vpnNetworks {
if t.vpnNetworks[i].Contains(gwAddr) {
withinNetworks = true
break
}
}
if !withinNetworks {
// Gateway isn't in our overlay network, ignore // Gateway isn't in our overlay network, ignore
t.l.WithField("route", r).Debug("Ignoring route update, not in our network") t.l.WithField("route", r).Debug("Ignoring route update, not in our networks")
return return
} }

View File

@ -28,7 +28,7 @@ type ifreqDestroy struct {
type tun struct { type tun struct {
Device string Device string
cidr netip.Prefix vpnNetworks []netip.Prefix
MTU int MTU int
Routes atomic.Pointer[[]Route] Routes atomic.Pointer[[]Route]
routeTree atomic.Pointer[bart.Table[netip.Addr]] routeTree atomic.Pointer[bart.Table[netip.Addr]]
@ -58,13 +58,13 @@ func (t *tun) Close() error {
return nil return nil
} }
func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ netip.Prefix) (*tun, error) { func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ []netip.Prefix) (*tun, error) {
return nil, fmt.Errorf("newTunFromFd not supported in NetBSD") return nil, fmt.Errorf("newTunFromFd not supported in NetBSD")
} }
var deviceNameRE = regexp.MustCompile(`^tun[0-9]+$`) var deviceNameRE = regexp.MustCompile(`^tun[0-9]+$`)
func newTun(c *config.C, l *logrus.Logger, cidr netip.Prefix, _ bool) (*tun, error) { func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, _ bool) (*tun, error) {
// Try to open tun device // Try to open tun device
var file *os.File var file *os.File
var err error var err error
@ -84,7 +84,7 @@ func newTun(c *config.C, l *logrus.Logger, cidr netip.Prefix, _ bool) (*tun, err
t := &tun{ t := &tun{
ReadWriteCloser: file, ReadWriteCloser: file,
Device: deviceName, Device: deviceName,
cidr: cidr, vpnNetworks: vpnNetworks,
MTU: c.GetInt("tun.mtu", DefaultMTU), MTU: c.GetInt("tun.mtu", DefaultMTU),
l: l, l: l,
} }
@ -104,17 +104,17 @@ func newTun(c *config.C, l *logrus.Logger, cidr netip.Prefix, _ bool) (*tun, err
return t, nil return t, nil
} }
func (t *tun) Activate() error { func (t *tun) addIp(cidr netip.Prefix) error {
var err error var err error
// TODO use syscalls instead of exec.Command // TODO use syscalls instead of exec.Command
cmd := exec.Command("/sbin/ifconfig", t.Device, t.cidr.String(), t.cidr.Addr().String()) cmd := exec.Command("/sbin/ifconfig", t.Device, cidr.String(), cidr.Addr().String())
t.l.Debug("command: ", cmd.String()) t.l.Debug("command: ", cmd.String())
if err = cmd.Run(); err != nil { if err = cmd.Run(); err != nil {
return fmt.Errorf("failed to run 'ifconfig': %s", err) return fmt.Errorf("failed to run 'ifconfig': %s", err)
} }
cmd = exec.Command("/sbin/route", "-n", "add", "-net", t.cidr.String(), t.cidr.Addr().String()) cmd = exec.Command("/sbin/route", "-n", "add", "-net", cidr.String(), cidr.Addr().String())
t.l.Debug("command: ", cmd.String()) t.l.Debug("command: ", cmd.String())
if err = cmd.Run(); err != nil { if err = cmd.Run(); err != nil {
return fmt.Errorf("failed to run 'route add': %s", err) return fmt.Errorf("failed to run 'route add': %s", err)
@ -130,8 +130,18 @@ func (t *tun) Activate() error {
return t.addRoutes(false) return t.addRoutes(false)
} }
func (t *tun) Activate() error {
for i := range t.vpnNetworks {
err := t.addIp(t.vpnNetworks[i])
if err != nil {
return err
}
}
return nil
}
func (t *tun) reload(c *config.C, initial bool) error { func (t *tun) reload(c *config.C, initial bool) error {
change, routes, err := getAllRoutesFromConfig(c, t.cidr, initial) change, routes, err := getAllRoutesFromConfig(c, t.vpnNetworks, initial)
if err != nil { if err != nil {
return err return err
} }
@ -172,8 +182,8 @@ func (t *tun) RouteFor(ip netip.Addr) netip.Addr {
return r return r
} }
func (t *tun) Cidr() netip.Prefix { func (t *tun) Networks() []netip.Prefix {
return t.cidr return t.vpnNetworks
} }
func (t *tun) Name() string { func (t *tun) Name() string {
@ -192,7 +202,7 @@ func (t *tun) addRoutes(logErrors bool) error {
continue continue
} }
cmd := exec.Command("/sbin/route", "-n", "add", "-net", r.Cidr.String(), t.cidr.Addr().String()) cmd := exec.Command("/sbin/route", "-n", "add", "-net", r.Cidr.String(), t.vpnNetworks[0].Addr().String())
t.l.Debug("command: ", cmd.String()) t.l.Debug("command: ", cmd.String())
if err := cmd.Run(); err != nil { if err := cmd.Run(); err != nil {
retErr := util.NewContextualError("failed to run 'route add' for unsafe_route", map[string]interface{}{"route": r}, err) retErr := util.NewContextualError("failed to run 'route add' for unsafe_route", map[string]interface{}{"route": r}, err)
@ -213,7 +223,8 @@ func (t *tun) removeRoutes(routes []Route) error {
continue continue
} }
cmd := exec.Command("/sbin/route", "-n", "delete", "-net", r.Cidr.String(), t.cidr.Addr().String()) //todo is this right?
cmd := exec.Command("/sbin/route", "-n", "delete", "-net", r.Cidr.String(), t.vpnNetworks[0].Addr().String())
t.l.Debug("command: ", cmd.String()) t.l.Debug("command: ", cmd.String())
if err := cmd.Run(); err != nil { if err := cmd.Run(); err != nil {
t.l.WithError(err).WithField("route", r).Error("Failed to remove route") t.l.WithError(err).WithField("route", r).Error("Failed to remove route")

View File

@ -22,7 +22,7 @@ import (
type tun struct { type tun struct {
Device string Device string
cidr netip.Prefix vpnNetworks []netip.Prefix
MTU int MTU int
Routes atomic.Pointer[[]Route] Routes atomic.Pointer[[]Route]
routeTree atomic.Pointer[bart.Table[netip.Addr]] routeTree atomic.Pointer[bart.Table[netip.Addr]]
@ -42,13 +42,13 @@ func (t *tun) Close() error {
return nil return nil
} }
func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ netip.Prefix) (*tun, error) { func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ []netip.Prefix) (*tun, error) {
return nil, fmt.Errorf("newTunFromFd not supported in OpenBSD") return nil, fmt.Errorf("newTunFromFd not supported in OpenBSD")
} }
var deviceNameRE = regexp.MustCompile(`^tun[0-9]+$`) var deviceNameRE = regexp.MustCompile(`^tun[0-9]+$`)
func newTun(c *config.C, l *logrus.Logger, cidr netip.Prefix, _ bool) (*tun, error) { func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, _ bool) (*tun, error) {
deviceName := c.GetString("tun.dev", "") deviceName := c.GetString("tun.dev", "")
if deviceName == "" { if deviceName == "" {
return nil, fmt.Errorf("a device name in the format of tunN must be specified") return nil, fmt.Errorf("a device name in the format of tunN must be specified")
@ -66,7 +66,7 @@ func newTun(c *config.C, l *logrus.Logger, cidr netip.Prefix, _ bool) (*tun, err
t := &tun{ t := &tun{
ReadWriteCloser: file, ReadWriteCloser: file,
Device: deviceName, Device: deviceName,
cidr: cidr, vpnNetworks: vpnNetworks,
MTU: c.GetInt("tun.mtu", DefaultMTU), MTU: c.GetInt("tun.mtu", DefaultMTU),
l: l, l: l,
} }
@ -87,7 +87,7 @@ func newTun(c *config.C, l *logrus.Logger, cidr netip.Prefix, _ bool) (*tun, err
} }
func (t *tun) reload(c *config.C, initial bool) error { func (t *tun) reload(c *config.C, initial bool) error {
change, routes, err := getAllRoutesFromConfig(c, t.cidr, initial) change, routes, err := getAllRoutesFromConfig(c, t.vpnNetworks, initial)
if err != nil { if err != nil {
return err return err
} }
@ -123,10 +123,10 @@ func (t *tun) reload(c *config.C, initial bool) error {
return nil return nil
} }
func (t *tun) Activate() error { func (t *tun) addIp(cidr netip.Prefix) error {
var err error var err error
// TODO use syscalls instead of exec.Command // TODO use syscalls instead of exec.Command
cmd := exec.Command("/sbin/ifconfig", t.Device, t.cidr.String(), t.cidr.Addr().String()) cmd := exec.Command("/sbin/ifconfig", t.Device, cidr.String(), cidr.Addr().String())
t.l.Debug("command: ", cmd.String()) t.l.Debug("command: ", cmd.String())
if err = cmd.Run(); err != nil { if err = cmd.Run(); err != nil {
return fmt.Errorf("failed to run 'ifconfig': %s", err) return fmt.Errorf("failed to run 'ifconfig': %s", err)
@ -138,7 +138,7 @@ func (t *tun) Activate() error {
return fmt.Errorf("failed to run 'ifconfig': %s", err) return fmt.Errorf("failed to run 'ifconfig': %s", err)
} }
cmd = exec.Command("/sbin/route", "-n", "add", "-inet", t.cidr.String(), t.cidr.Addr().String()) cmd = exec.Command("/sbin/route", "-n", "add", "-inet", cidr.String(), cidr.Addr().String())
t.l.Debug("command: ", cmd.String()) t.l.Debug("command: ", cmd.String())
if err = cmd.Run(); err != nil { if err = cmd.Run(); err != nil {
return fmt.Errorf("failed to run 'route add': %s", err) return fmt.Errorf("failed to run 'route add': %s", err)
@ -148,6 +148,16 @@ func (t *tun) Activate() error {
return t.addRoutes(false) return t.addRoutes(false)
} }
func (t *tun) Activate() error {
for i := range t.vpnNetworks {
err := t.addIp(t.vpnNetworks[i])
if err != nil {
return err
}
}
return nil
}
func (t *tun) RouteFor(ip netip.Addr) netip.Addr { func (t *tun) RouteFor(ip netip.Addr) netip.Addr {
r, _ := t.routeTree.Load().Lookup(ip) r, _ := t.routeTree.Load().Lookup(ip)
return r return r
@ -160,8 +170,8 @@ func (t *tun) addRoutes(logErrors bool) error {
// We don't allow route MTUs so only install routes with a via // We don't allow route MTUs so only install routes with a via
continue continue
} }
//todo is this right?
cmd := exec.Command("/sbin/route", "-n", "add", "-inet", r.Cidr.String(), t.cidr.Addr().String()) cmd := exec.Command("/sbin/route", "-n", "add", "-inet", r.Cidr.String(), t.vpnNetworks[0].Addr().String())
t.l.Debug("command: ", cmd.String()) t.l.Debug("command: ", cmd.String())
if err := cmd.Run(); err != nil { if err := cmd.Run(); err != nil {
retErr := util.NewContextualError("failed to run 'route add' for unsafe_route", map[string]interface{}{"route": r}, err) retErr := util.NewContextualError("failed to run 'route add' for unsafe_route", map[string]interface{}{"route": r}, err)
@ -181,8 +191,8 @@ func (t *tun) removeRoutes(routes []Route) error {
if !r.Install { if !r.Install {
continue continue
} }
//todo is this right?
cmd := exec.Command("/sbin/route", "-n", "delete", "-inet", r.Cidr.String(), t.cidr.Addr().String()) cmd := exec.Command("/sbin/route", "-n", "delete", "-inet", r.Cidr.String(), t.vpnNetworks[0].Addr().String())
t.l.Debug("command: ", cmd.String()) t.l.Debug("command: ", cmd.String())
if err := cmd.Run(); err != nil { if err := cmd.Run(); err != nil {
t.l.WithError(err).WithField("route", r).Error("Failed to remove route") t.l.WithError(err).WithField("route", r).Error("Failed to remove route")
@ -193,8 +203,8 @@ func (t *tun) removeRoutes(routes []Route) error {
return nil return nil
} }
func (t *tun) Cidr() netip.Prefix { func (t *tun) Networks() []netip.Prefix {
return t.cidr return t.vpnNetworks
} }
func (t *tun) Name() string { func (t *tun) Name() string {

View File

@ -17,7 +17,7 @@ import (
type TestTun struct { type TestTun struct {
Device string Device string
cidr netip.Prefix vpnNetworks []netip.Prefix
Routes []Route Routes []Route
routeTree *bart.Table[netip.Addr] routeTree *bart.Table[netip.Addr]
l *logrus.Logger l *logrus.Logger
@ -27,8 +27,8 @@ type TestTun struct {
TxPackets chan []byte // Packets transmitted outside by nebula TxPackets chan []byte // Packets transmitted outside by nebula
} }
func newTun(c *config.C, l *logrus.Logger, cidr netip.Prefix, _ bool) (*TestTun, error) { func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, _ bool) (*TestTun, error) {
_, routes, err := getAllRoutesFromConfig(c, cidr, true) _, routes, err := getAllRoutesFromConfig(c, vpnNetworks, true)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -39,7 +39,7 @@ func newTun(c *config.C, l *logrus.Logger, cidr netip.Prefix, _ bool) (*TestTun,
return &TestTun{ return &TestTun{
Device: c.GetString("tun.dev", ""), Device: c.GetString("tun.dev", ""),
cidr: cidr, vpnNetworks: vpnNetworks,
Routes: routes, Routes: routes,
routeTree: routeTree, routeTree: routeTree,
l: l, l: l,
@ -48,7 +48,7 @@ func newTun(c *config.C, l *logrus.Logger, cidr netip.Prefix, _ bool) (*TestTun,
}, nil }, nil
} }
func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ netip.Prefix) (*TestTun, error) { func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ []netip.Prefix) (*TestTun, error) {
return nil, fmt.Errorf("newTunFromFd not supported") return nil, fmt.Errorf("newTunFromFd not supported")
} }
@ -95,8 +95,8 @@ func (t *TestTun) Activate() error {
return nil return nil
} }
func (t *TestTun) Cidr() netip.Prefix { func (t *TestTun) Networks() []netip.Prefix {
return t.cidr return t.vpnNetworks
} }
func (t *TestTun) Name() string { func (t *TestTun) Name() string {

View File

@ -1,208 +0,0 @@
package overlay
import (
"fmt"
"io"
"net"
"net/netip"
"os/exec"
"strconv"
"sync/atomic"
"github.com/gaissmai/bart"
"github.com/sirupsen/logrus"
"github.com/slackhq/nebula/config"
"github.com/slackhq/nebula/util"
"github.com/songgao/water"
)
type waterTun struct {
Device string
cidr netip.Prefix
MTU int
Routes atomic.Pointer[[]Route]
routeTree atomic.Pointer[bart.Table[netip.Addr]]
l *logrus.Logger
f *net.Interface
*water.Interface
}
func newWaterTun(c *config.C, l *logrus.Logger, cidr netip.Prefix, _ bool) (*waterTun, error) {
// NOTE: You cannot set the deviceName under Windows, so you must check tun.Device after calling .Activate()
t := &waterTun{
cidr: cidr,
MTU: c.GetInt("tun.mtu", DefaultMTU),
l: l,
}
err := t.reload(c, true)
if err != nil {
return nil, err
}
c.RegisterReloadCallback(func(c *config.C) {
err := t.reload(c, false)
if err != nil {
util.LogWithContextIfNeeded("failed to reload tun device", err, t.l)
}
})
return t, nil
}
func (t *waterTun) Activate() error {
var err error
t.Interface, err = water.New(water.Config{
DeviceType: water.TUN,
PlatformSpecificParams: water.PlatformSpecificParams{
ComponentID: "tap0901",
Network: t.cidr.String(),
},
})
if err != nil {
return fmt.Errorf("activate failed: %v", err)
}
t.Device = t.Interface.Name()
// TODO use syscalls instead of exec.Command
err = exec.Command(
`C:\Windows\System32\netsh.exe`, "interface", "ipv4", "set", "address",
fmt.Sprintf("name=%s", t.Device),
"source=static",
fmt.Sprintf("addr=%s", t.cidr.Addr()),
fmt.Sprintf("mask=%s", net.CIDRMask(t.cidr.Bits(), t.cidr.Addr().BitLen())),
"gateway=none",
).Run()
if err != nil {
return fmt.Errorf("failed to run 'netsh' to set address: %s", err)
}
err = exec.Command(
`C:\Windows\System32\netsh.exe`, "interface", "ipv4", "set", "interface",
t.Device,
fmt.Sprintf("mtu=%d", t.MTU),
).Run()
if err != nil {
return fmt.Errorf("failed to run 'netsh' to set MTU: %s", err)
}
t.f, err = net.InterfaceByName(t.Device)
if err != nil {
return fmt.Errorf("failed to find interface named %s: %v", t.Device, err)
}
err = t.addRoutes(false)
if err != nil {
return err
}
return nil
}
func (t *waterTun) reload(c *config.C, initial bool) error {
change, routes, err := getAllRoutesFromConfig(c, t.cidr, initial)
if err != nil {
return err
}
if !initial && !change {
return nil
}
routeTree, err := makeRouteTree(t.l, routes, false)
if err != nil {
return err
}
// Teach nebula how to handle the routes before establishing them in the system table
oldRoutes := t.Routes.Swap(&routes)
t.routeTree.Store(routeTree)
if !initial {
// Remove first, if the system removes a wanted route hopefully it will be re-added next
t.removeRoutes(findRemovedRoutes(routes, *oldRoutes))
// Ensure any routes we actually want are installed
err = t.addRoutes(true)
if err != nil {
// Catch any stray logs
util.LogWithContextIfNeeded("Failed to set routes", err, t.l)
} else {
for _, r := range findRemovedRoutes(routes, *oldRoutes) {
t.l.WithField("route", r).Info("Removed route")
}
}
}
return nil
}
func (t *waterTun) addRoutes(logErrors bool) error {
// Path routes
routes := *t.Routes.Load()
for _, r := range routes {
if !r.Via.IsValid() || !r.Install {
// We don't allow route MTUs so only install routes with a via
continue
}
err := exec.Command(
"C:\\Windows\\System32\\route.exe", "add", r.Cidr.String(), r.Via.String(), "IF", strconv.Itoa(t.f.Index), "METRIC", strconv.Itoa(r.Metric),
).Run()
if err != nil {
retErr := util.NewContextualError("Failed to add route", map[string]interface{}{"route": r}, err)
if logErrors {
retErr.Log(t.l)
} else {
return retErr
}
} else {
t.l.WithField("route", r).Info("Added route")
}
}
return nil
}
func (t *waterTun) removeRoutes(routes []Route) {
for _, r := range routes {
if !r.Install {
continue
}
err := exec.Command(
"C:\\Windows\\System32\\route.exe", "delete", r.Cidr.String(), r.Via.String(), "IF", strconv.Itoa(t.f.Index), "METRIC", strconv.Itoa(r.Metric),
).Run()
if err != nil {
t.l.WithError(err).WithField("route", r).Error("Failed to remove route")
} else {
t.l.WithField("route", r).Info("Removed route")
}
}
}
func (t *waterTun) RouteFor(ip netip.Addr) netip.Addr {
r, _ := t.routeTree.Load().Lookup(ip)
return r
}
func (t *waterTun) Cidr() netip.Prefix {
return t.cidr
}
func (t *waterTun) Name() string {
return t.Device
}
func (t *waterTun) Close() error {
if t.Interface == nil {
return nil
}
return t.Interface.Close()
}
func (t *waterTun) NewMultiQueueReader() (io.ReadWriteCloser, error) {
return nil, fmt.Errorf("TODO: multiqueue not implemented for windows")
}

View File

@ -4,41 +4,267 @@
package overlay package overlay
import ( import (
"crypto"
"fmt" "fmt"
"io"
"net/netip" "net/netip"
"os" "os"
"path/filepath" "path/filepath"
"runtime" "runtime"
"sync/atomic"
"syscall" "syscall"
"unsafe"
"github.com/gaissmai/bart"
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
"github.com/slackhq/nebula/config" "github.com/slackhq/nebula/config"
"github.com/slackhq/nebula/util"
"github.com/slackhq/nebula/wintun"
"golang.org/x/sys/windows"
"golang.zx2c4.com/wireguard/windows/tunnel/winipcfg"
) )
func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ netip.Prefix) (Device, error) { const tunGUIDLabel = "Fixed Nebula Windows GUID v1"
type winTun struct {
Device string
vpnNetworks []netip.Prefix
MTU int
Routes atomic.Pointer[[]Route]
routeTree atomic.Pointer[bart.Table[netip.Addr]]
l *logrus.Logger
tun *wintun.NativeTun
}
func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ []netip.Prefix) (Device, error) {
return nil, fmt.Errorf("newTunFromFd not supported in Windows") return nil, fmt.Errorf("newTunFromFd not supported in Windows")
} }
func newTun(c *config.C, l *logrus.Logger, cidr netip.Prefix, multiqueue bool) (Device, error) { func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, _ bool) (*winTun, error) {
useWintun := true err := checkWinTunExists()
if err := checkWinTunExists(); err != nil { if err != nil {
l.WithError(err).Warn("Check Wintun driver failed, fallback to wintap driver") return nil, fmt.Errorf("can not load the wintun driver: %w", err)
useWintun = false
} }
if useWintun { deviceName := c.GetString("tun.dev", "")
device, err := newWinTun(c, l, cidr, multiqueue) guid, err := generateGUIDByDeviceName(deviceName)
if err != nil { if err != nil {
return nil, fmt.Errorf("create Wintun interface failed, %w", err) return nil, fmt.Errorf("generate GUID failed: %w", err)
}
return device, nil
} }
device, err := newWaterTun(c, l, cidr, multiqueue) t := &winTun{
if err != nil { Device: deviceName,
return nil, fmt.Errorf("create wintap driver failed, %w", err) vpnNetworks: vpnNetworks,
MTU: c.GetInt("tun.mtu", DefaultMTU),
l: l,
} }
return device, nil
err = t.reload(c, true)
if err != nil {
return nil, err
}
var tunDevice wintun.Device
tunDevice, err = wintun.CreateTUNWithRequestedGUID(deviceName, guid, t.MTU)
if err != nil {
// Windows 10 has an issue with unclean shutdowns not fully cleaning up the wintun device.
// Trying a second time resolves the issue.
l.WithError(err).Debug("Failed to create wintun device, retrying")
tunDevice, err = wintun.CreateTUNWithRequestedGUID(deviceName, guid, t.MTU)
if err != nil {
return nil, fmt.Errorf("create TUN device failed: %w", err)
}
}
t.tun = tunDevice.(*wintun.NativeTun)
c.RegisterReloadCallback(func(c *config.C) {
err := t.reload(c, false)
if err != nil {
util.LogWithContextIfNeeded("failed to reload tun device", err, t.l)
}
})
return t, nil
}
func (t *winTun) reload(c *config.C, initial bool) error {
change, routes, err := getAllRoutesFromConfig(c, t.vpnNetworks, initial)
if err != nil {
return err
}
if !initial && !change {
return nil
}
routeTree, err := makeRouteTree(t.l, routes, false)
if err != nil {
return err
}
// Teach nebula how to handle the routes before establishing them in the system table
oldRoutes := t.Routes.Swap(&routes)
t.routeTree.Store(routeTree)
if !initial {
// Remove first, if the system removes a wanted route hopefully it will be re-added next
err := t.removeRoutes(findRemovedRoutes(routes, *oldRoutes))
if err != nil {
util.LogWithContextIfNeeded("Failed to remove routes", err, t.l)
}
// Ensure any routes we actually want are installed
err = t.addRoutes(true)
if err != nil {
// Catch any stray logs
util.LogWithContextIfNeeded("Failed to add routes", err, t.l)
}
}
return nil
}
func (t *winTun) Activate() error {
luid := winipcfg.LUID(t.tun.LUID())
err := luid.SetIPAddresses(t.vpnNetworks)
if err != nil {
return fmt.Errorf("failed to set address: %w", err)
}
err = t.addRoutes(false)
if err != nil {
return err
}
return nil
}
func (t *winTun) addRoutes(logErrors bool) error {
luid := winipcfg.LUID(t.tun.LUID())
routes := *t.Routes.Load()
foundDefault4 := false
for _, r := range routes {
if !r.Via.IsValid() || !r.Install {
// We don't allow route MTUs so only install routes with a via
continue
}
// Add our unsafe route
err := luid.AddRoute(r.Cidr, r.Via, uint32(r.Metric))
if err != nil {
retErr := util.NewContextualError("Failed to add route", map[string]interface{}{"route": r}, err)
if logErrors {
retErr.Log(t.l)
continue
} else {
return retErr
}
} else {
t.l.WithField("route", r).Info("Added route")
}
if !foundDefault4 {
if r.Cidr.Bits() == 0 && r.Cidr.Addr().BitLen() == 32 {
foundDefault4 = true
}
}
}
ipif, err := luid.IPInterface(windows.AF_INET)
if err != nil {
return fmt.Errorf("failed to get ip interface: %w", err)
}
ipif.NLMTU = uint32(t.MTU)
if foundDefault4 {
ipif.UseAutomaticMetric = false
ipif.Metric = 0
}
if err := ipif.Set(); err != nil {
return fmt.Errorf("failed to set ip interface: %w", err)
}
return nil
}
func (t *winTun) removeRoutes(routes []Route) error {
luid := winipcfg.LUID(t.tun.LUID())
for _, r := range routes {
if !r.Install {
continue
}
err := luid.DeleteRoute(r.Cidr, r.Via)
if err != nil {
t.l.WithError(err).WithField("route", r).Error("Failed to remove route")
} else {
t.l.WithField("route", r).Info("Removed route")
}
}
return nil
}
func (t *winTun) RouteFor(ip netip.Addr) netip.Addr {
r, _ := t.routeTree.Load().Lookup(ip)
return r
}
func (t *winTun) Networks() []netip.Prefix {
return t.vpnNetworks
}
func (t *winTun) Name() string {
return t.Device
}
func (t *winTun) Read(b []byte) (int, error) {
return t.tun.Read(b, 0)
}
func (t *winTun) Write(b []byte) (int, error) {
return t.tun.Write(b, 0)
}
func (t *winTun) NewMultiQueueReader() (io.ReadWriteCloser, error) {
return nil, fmt.Errorf("TODO: multiqueue not implemented for windows")
}
func (t *winTun) Close() error {
// It seems that the Windows networking stack doesn't like it when we destroy interfaces that have active routes,
// so to be certain, just remove everything before destroying.
luid := winipcfg.LUID(t.tun.LUID())
_ = luid.FlushRoutes(windows.AF_INET)
_ = luid.FlushIPAddresses(windows.AF_INET)
/* We don't support IPV6 yet
_ = luid.FlushRoutes(windows.AF_INET6)
_ = luid.FlushIPAddresses(windows.AF_INET6)
*/
_ = luid.FlushDNS(windows.AF_INET)
return t.tun.Close()
}
func generateGUIDByDeviceName(name string) (*windows.GUID, error) {
// GUID is 128 bit
hash := crypto.MD5.New()
_, err := hash.Write([]byte(tunGUIDLabel))
if err != nil {
return nil, err
}
_, err = hash.Write([]byte(name))
if err != nil {
return nil, err
}
sum := hash.Sum(nil)
return (*windows.GUID)(unsafe.Pointer(&sum[0])), nil
} }
func checkWinTunExists() error { func checkWinTunExists() error {

View File

@ -1,252 +0,0 @@
package overlay
import (
"crypto"
"fmt"
"io"
"net/netip"
"sync/atomic"
"unsafe"
"github.com/gaissmai/bart"
"github.com/sirupsen/logrus"
"github.com/slackhq/nebula/config"
"github.com/slackhq/nebula/util"
"github.com/slackhq/nebula/wintun"
"golang.org/x/sys/windows"
"golang.zx2c4.com/wireguard/windows/tunnel/winipcfg"
)
const tunGUIDLabel = "Fixed Nebula Windows GUID v1"
type winTun struct {
Device string
cidr netip.Prefix
MTU int
Routes atomic.Pointer[[]Route]
routeTree atomic.Pointer[bart.Table[netip.Addr]]
l *logrus.Logger
tun *wintun.NativeTun
}
func generateGUIDByDeviceName(name string) (*windows.GUID, error) {
// GUID is 128 bit
hash := crypto.MD5.New()
_, err := hash.Write([]byte(tunGUIDLabel))
if err != nil {
return nil, err
}
_, err = hash.Write([]byte(name))
if err != nil {
return nil, err
}
sum := hash.Sum(nil)
return (*windows.GUID)(unsafe.Pointer(&sum[0])), nil
}
func newWinTun(c *config.C, l *logrus.Logger, cidr netip.Prefix, _ bool) (*winTun, error) {
deviceName := c.GetString("tun.dev", "")
guid, err := generateGUIDByDeviceName(deviceName)
if err != nil {
return nil, fmt.Errorf("generate GUID failed: %w", err)
}
t := &winTun{
Device: deviceName,
cidr: cidr,
MTU: c.GetInt("tun.mtu", DefaultMTU),
l: l,
}
err = t.reload(c, true)
if err != nil {
return nil, err
}
var tunDevice wintun.Device
tunDevice, err = wintun.CreateTUNWithRequestedGUID(deviceName, guid, t.MTU)
if err != nil {
// Windows 10 has an issue with unclean shutdowns not fully cleaning up the wintun device.
// Trying a second time resolves the issue.
l.WithError(err).Debug("Failed to create wintun device, retrying")
tunDevice, err = wintun.CreateTUNWithRequestedGUID(deviceName, guid, t.MTU)
if err != nil {
return nil, fmt.Errorf("create TUN device failed: %w", err)
}
}
t.tun = tunDevice.(*wintun.NativeTun)
c.RegisterReloadCallback(func(c *config.C) {
err := t.reload(c, false)
if err != nil {
util.LogWithContextIfNeeded("failed to reload tun device", err, t.l)
}
})
return t, nil
}
func (t *winTun) reload(c *config.C, initial bool) error {
change, routes, err := getAllRoutesFromConfig(c, t.cidr, initial)
if err != nil {
return err
}
if !initial && !change {
return nil
}
routeTree, err := makeRouteTree(t.l, routes, false)
if err != nil {
return err
}
// Teach nebula how to handle the routes before establishing them in the system table
oldRoutes := t.Routes.Swap(&routes)
t.routeTree.Store(routeTree)
if !initial {
// Remove first, if the system removes a wanted route hopefully it will be re-added next
err := t.removeRoutes(findRemovedRoutes(routes, *oldRoutes))
if err != nil {
util.LogWithContextIfNeeded("Failed to remove routes", err, t.l)
}
// Ensure any routes we actually want are installed
err = t.addRoutes(true)
if err != nil {
// Catch any stray logs
util.LogWithContextIfNeeded("Failed to add routes", err, t.l)
}
}
return nil
}
func (t *winTun) Activate() error {
luid := winipcfg.LUID(t.tun.LUID())
err := luid.SetIPAddresses([]netip.Prefix{t.cidr})
if err != nil {
return fmt.Errorf("failed to set address: %w", err)
}
err = t.addRoutes(false)
if err != nil {
return err
}
return nil
}
func (t *winTun) addRoutes(logErrors bool) error {
luid := winipcfg.LUID(t.tun.LUID())
routes := *t.Routes.Load()
foundDefault4 := false
for _, r := range routes {
if !r.Via.IsValid() || !r.Install {
// We don't allow route MTUs so only install routes with a via
continue
}
// Add our unsafe route
err := luid.AddRoute(r.Cidr, r.Via, uint32(r.Metric))
if err != nil {
retErr := util.NewContextualError("Failed to add route", map[string]interface{}{"route": r}, err)
if logErrors {
retErr.Log(t.l)
continue
} else {
return retErr
}
} else {
t.l.WithField("route", r).Info("Added route")
}
if !foundDefault4 {
if r.Cidr.Bits() == 0 && r.Cidr.Addr().BitLen() == 32 {
foundDefault4 = true
}
}
}
ipif, err := luid.IPInterface(windows.AF_INET)
if err != nil {
return fmt.Errorf("failed to get ip interface: %w", err)
}
ipif.NLMTU = uint32(t.MTU)
if foundDefault4 {
ipif.UseAutomaticMetric = false
ipif.Metric = 0
}
if err := ipif.Set(); err != nil {
return fmt.Errorf("failed to set ip interface: %w", err)
}
return nil
}
func (t *winTun) removeRoutes(routes []Route) error {
luid := winipcfg.LUID(t.tun.LUID())
for _, r := range routes {
if !r.Install {
continue
}
err := luid.DeleteRoute(r.Cidr, r.Via)
if err != nil {
t.l.WithError(err).WithField("route", r).Error("Failed to remove route")
} else {
t.l.WithField("route", r).Info("Removed route")
}
}
return nil
}
func (t *winTun) RouteFor(ip netip.Addr) netip.Addr {
r, _ := t.routeTree.Load().Lookup(ip)
return r
}
func (t *winTun) Cidr() netip.Prefix {
return t.cidr
}
func (t *winTun) Name() string {
return t.Device
}
func (t *winTun) Read(b []byte) (int, error) {
return t.tun.Read(b, 0)
}
func (t *winTun) Write(b []byte) (int, error) {
return t.tun.Write(b, 0)
}
func (t *winTun) NewMultiQueueReader() (io.ReadWriteCloser, error) {
return nil, fmt.Errorf("TODO: multiqueue not implemented for windows")
}
func (t *winTun) Close() error {
// It seems that the Windows networking stack doesn't like it when we destroy interfaces that have active routes,
// so to be certain, just remove everything before destroying.
luid := winipcfg.LUID(t.tun.LUID())
_ = luid.FlushRoutes(windows.AF_INET)
_ = luid.FlushIPAddresses(windows.AF_INET)
/* We don't support IPV6 yet
_ = luid.FlushRoutes(windows.AF_INET6)
_ = luid.FlushIPAddresses(windows.AF_INET6)
*/
_ = luid.FlushDNS(windows.AF_INET)
return t.tun.Close()
}

View File

@ -8,16 +8,16 @@ import (
"github.com/slackhq/nebula/config" "github.com/slackhq/nebula/config"
) )
func NewUserDeviceFromConfig(c *config.C, l *logrus.Logger, tunCidr netip.Prefix, routines int) (Device, error) { func NewUserDeviceFromConfig(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, routines int) (Device, error) {
return NewUserDevice(tunCidr) return NewUserDevice(vpnNetworks)
} }
func NewUserDevice(tunCidr netip.Prefix) (Device, error) { func NewUserDevice(vpnNetworks []netip.Prefix) (Device, error) {
// these pipes guarantee each write/read will match 1:1 // these pipes guarantee each write/read will match 1:1
or, ow := io.Pipe() or, ow := io.Pipe()
ir, iw := io.Pipe() ir, iw := io.Pipe()
return &UserDevice{ return &UserDevice{
tunCidr: tunCidr, vpnNetworks: vpnNetworks,
outboundReader: or, outboundReader: or,
outboundWriter: ow, outboundWriter: ow,
inboundReader: ir, inboundReader: ir,
@ -26,7 +26,7 @@ func NewUserDevice(tunCidr netip.Prefix) (Device, error) {
} }
type UserDevice struct { type UserDevice struct {
tunCidr netip.Prefix vpnNetworks []netip.Prefix
outboundReader *io.PipeReader outboundReader *io.PipeReader
outboundWriter *io.PipeWriter outboundWriter *io.PipeWriter
@ -38,7 +38,7 @@ type UserDevice struct {
func (d *UserDevice) Activate() error { func (d *UserDevice) Activate() error {
return nil return nil
} }
func (d *UserDevice) Cidr() netip.Prefix { return d.tunCidr } func (d *UserDevice) Networks() []netip.Prefix { return d.vpnNetworks }
func (d *UserDevice) Name() string { return "faketun0" } func (d *UserDevice) Name() string { return "faketun0" }
func (d *UserDevice) RouteFor(ip netip.Addr) netip.Addr { return ip } func (d *UserDevice) RouteFor(ip netip.Addr) netip.Addr { return ip }
func (d *UserDevice) NewMultiQueueReader() (io.ReadWriteCloser, error) { func (d *UserDevice) NewMultiQueueReader() (io.ReadWriteCloser, error) {

417
pki.go
View File

@ -1,13 +1,19 @@
package nebula package nebula
import ( import (
"encoding/binary"
"encoding/json"
"errors" "errors"
"fmt" "fmt"
"net"
"net/netip"
"os" "os"
"slices"
"strings" "strings"
"sync/atomic" "sync/atomic"
"time" "time"
"github.com/gaissmai/bart"
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
"github.com/slackhq/nebula/cert" "github.com/slackhq/nebula/cert"
"github.com/slackhq/nebula/config" "github.com/slackhq/nebula/config"
@ -21,12 +27,22 @@ type PKI struct {
} }
type CertState struct { type CertState struct {
Certificate cert.Certificate v1Cert cert.Certificate
RawCertificate []byte v1HandshakeBytes []byte
RawCertificateNoKey []byte
PublicKey []byte v2Cert cert.Certificate
PrivateKey []byte v2HandshakeBytes []byte
defaultVersion cert.Version
privateKey []byte
pkcs11Backed bool pkcs11Backed bool
cipher string
myVpnNetworks []netip.Prefix
myVpnNetworksTable *bart.Table[struct{}]
myVpnAddrs []netip.Addr
myVpnAddrsTable *bart.Table[struct{}]
myVpnBroadcastAddrsTable *bart.Table[struct{}]
} }
func NewPKIFromConfig(l *logrus.Logger, c *config.C) (*PKI, error) { func NewPKIFromConfig(l *logrus.Logger, c *config.C) (*PKI, error) {
@ -46,16 +62,26 @@ func NewPKIFromConfig(l *logrus.Logger, c *config.C) (*PKI, error) {
return pki, nil return pki, nil
} }
func (p *PKI) GetCertState() *CertState {
return p.cs.Load()
}
func (p *PKI) GetCAPool() *cert.CAPool { func (p *PKI) GetCAPool() *cert.CAPool {
return p.caPool.Load() return p.caPool.Load()
} }
func (p *PKI) getCertState() *CertState {
return p.cs.Load()
}
// TODO: We should remove this
func (p *PKI) getDefaultCertificate() cert.Certificate {
return p.cs.Load().GetDefaultCertificate()
}
// TODO: We should remove this
func (p *PKI) getCertificate(v cert.Version) cert.Certificate {
return p.cs.Load().getCertificate(v)
}
func (p *PKI) reload(c *config.C, initial bool) error { func (p *PKI) reload(c *config.C, initial bool) error {
err := p.reloadCert(c, initial) err := p.reloadCerts(c, initial)
if err != nil { if err != nil {
if initial { if initial {
return err return err
@ -74,33 +100,94 @@ func (p *PKI) reload(c *config.C, initial bool) error {
return nil return nil
} }
func (p *PKI) reloadCert(c *config.C, initial bool) *util.ContextualError { func (p *PKI) reloadCerts(c *config.C, initial bool) *util.ContextualError {
cs, err := newCertStateFromConfig(c) newState, err := newCertStateFromConfig(c)
if err != nil { if err != nil {
return util.NewContextualError("Could not load client cert", nil, err) return util.NewContextualError("Could not load client cert", nil, err)
} }
if !initial { if !initial {
//TODO: include check for mask equality as well currentState := p.cs.Load()
if newState.v1Cert != nil {
if currentState.v1Cert == nil {
return util.NewContextualError("v1 certificate was added, restart required", nil, err)
}
// did IP in cert change? if so, don't set // did IP in cert change? if so, don't set
currentCert := p.cs.Load().Certificate if !slices.Equal(currentState.v1Cert.Networks(), newState.v1Cert.Networks()) {
oldIPs := currentCert.Networks()
newIPs := cs.Certificate.Networks()
if len(oldIPs) > 0 && len(newIPs) > 0 && oldIPs[0].String() != newIPs[0].String() {
return util.NewContextualError( return util.NewContextualError(
"Networks in new cert was different from old", "Networks in new cert was different from old",
m{"new_network": newIPs[0], "old_network": oldIPs[0]}, m{"new_networks": newState.v1Cert.Networks(), "old_networks": currentState.v1Cert.Networks()},
nil,
)
}
if currentState.v1Cert.Curve() != newState.v1Cert.Curve() {
return util.NewContextualError(
"Curve in new cert was different from old",
m{"new_curve": newState.v1Cert.Curve(), "old_curve": currentState.v1Cert.Curve()},
nil,
)
}
} else if currentState.v1Cert != nil {
//TODO: we should be able to tear this down
return util.NewContextualError("v1 certificate was removed, restart required", nil, err)
}
if newState.v2Cert != nil {
if currentState.v2Cert == nil {
return util.NewContextualError("v2 certificate was added, restart required", nil, err)
}
// did IP in cert change? if so, don't set
if !slices.Equal(currentState.v2Cert.Networks(), newState.v2Cert.Networks()) {
return util.NewContextualError(
"Networks in new cert was different from old",
m{"new_networks": newState.v2Cert.Networks(), "old_networks": currentState.v2Cert.Networks()},
nil,
)
}
if currentState.v2Cert.Curve() != newState.v2Cert.Curve() {
return util.NewContextualError(
"Curve in new cert was different from old",
m{"new_curve": newState.v2Cert.Curve(), "old_curve": currentState.v2Cert.Curve()},
nil,
)
}
} else if currentState.v2Cert != nil {
return util.NewContextualError("v2 certificate was removed, restart required", nil, err)
}
// Cipher cant be hot swapped so just leave it at what it was before
newState.cipher = currentState.cipher
} else {
newState.cipher = c.GetString("cipher", "aes")
//TODO: this sucks and we should make it not a global
switch newState.cipher {
case "aes":
noiseEndianness = binary.BigEndian
case "chachapoly":
noiseEndianness = binary.LittleEndian
default:
return util.NewContextualError(
"unknown cipher",
m{"cipher": newState.cipher},
nil, nil,
) )
} }
} }
p.cs.Store(cs) p.cs.Store(newState)
//TODO: newState needs a stringer that does json
if initial { if initial {
p.l.WithField("cert", cs.Certificate).Debug("Client nebula certificate") p.l.WithField("cert", newState).Debug("Client nebula certificate(s)")
} else { } else {
p.l.WithField("cert", cs.Certificate).Info("Client cert refreshed from disk") p.l.WithField("cert", newState).Info("Client certificate(s) refreshed from disk")
} }
return nil return nil
} }
@ -116,55 +203,65 @@ func (p *PKI) reloadCAPool(c *config.C) *util.ContextualError {
return nil return nil
} }
func newCertState(certificate cert.Certificate, pkcs11backed bool, privateKey []byte) (*CertState, error) { func (cs *CertState) GetDefaultCertificate() cert.Certificate {
// Marshal the certificate to ensure it is valid c := cs.getCertificate(cs.defaultVersion)
rawCertificate, err := certificate.Marshal() if c == nil {
if err != nil { panic("No default certificate found")
return nil, fmt.Errorf("invalid nebula certificate on interface: %s", err)
} }
return c
publicKey := certificate.PublicKey()
cs := &CertState{
RawCertificate: rawCertificate,
Certificate: certificate,
PrivateKey: privateKey,
PublicKey: publicKey,
pkcs11Backed: pkcs11backed,
}
rawCertNoKey, err := cs.Certificate.MarshalForHandshakes()
if err != nil {
return nil, fmt.Errorf("error marshalling certificate no key: %s", err)
}
cs.RawCertificateNoKey = rawCertNoKey
return cs, nil
} }
func loadPrivateKey(privPathOrPEM string) (rawKey []byte, curve cert.Curve, isPkcs11 bool, err error) { func (cs *CertState) getCertificate(v cert.Version) cert.Certificate {
var pemPrivateKey []byte switch v {
if strings.Contains(privPathOrPEM, "-----BEGIN") { case cert.Version1:
pemPrivateKey = []byte(privPathOrPEM) return cs.v1Cert
privPathOrPEM = "<inline>" case cert.Version2:
rawKey, _, curve, err = cert.UnmarshalPrivateKeyFromPEM(pemPrivateKey) return cs.v2Cert
if err != nil {
return nil, curve, false, fmt.Errorf("error while unmarshaling pki.key %s: %s", privPathOrPEM, err)
}
} else if strings.HasPrefix(privPathOrPEM, "pkcs11:") {
rawKey = []byte(privPathOrPEM)
return rawKey, cert.Curve_P256, true, nil
} else {
pemPrivateKey, err = os.ReadFile(privPathOrPEM)
if err != nil {
return nil, curve, false, fmt.Errorf("unable to read pki.key file %s: %s", privPathOrPEM, err)
}
rawKey, _, curve, err = cert.UnmarshalPrivateKeyFromPEM(pemPrivateKey)
if err != nil {
return nil, curve, false, fmt.Errorf("error while unmarshaling pki.key %s: %s", privPathOrPEM, err)
}
} }
return return nil
}
// getHandshakeBytes returns the cached bytes to be used in a handshake message for the requested version.
// Callers must check if the return []byte is nil.
func (cs *CertState) getHandshakeBytes(v cert.Version) []byte {
switch v {
case cert.Version1:
return cs.v1HandshakeBytes
case cert.Version2:
return cs.v2HandshakeBytes
default:
return nil
}
}
func (cs *CertState) String() string {
b, err := cs.MarshalJSON()
if err != nil {
return fmt.Sprintf("error marshaling certificate state: %v", err)
}
return string(b)
}
func (cs *CertState) MarshalJSON() ([]byte, error) {
msg := []json.RawMessage{}
if cs.v1Cert != nil {
b, err := cs.v1Cert.MarshalJSON()
if err != nil {
return nil, err
}
msg = append(msg, b)
}
if cs.v2Cert != nil {
b, err := cs.v2Cert.MarshalJSON()
if err != nil {
return nil, err
}
msg = append(msg, b)
}
return json.Marshal(msg)
} }
func newCertStateFromConfig(c *config.C) (*CertState, error) { func newCertStateFromConfig(c *config.C) (*CertState, error) {
@ -198,24 +295,198 @@ func newCertStateFromConfig(c *config.C) (*CertState, error) {
} }
} }
nebulaCert, _, err := cert.UnmarshalCertificateFromPEM(rawCert) var crt, v1, v2 cert.Certificate
for {
// Load the certificate
crt, rawCert, err = loadCertificate(rawCert)
if err != nil { if err != nil {
return nil, fmt.Errorf("error while unmarshaling pki.cert %s: %s", pubPathOrPEM, err) //TODO: check error
return nil, err
} }
if nebulaCert.Expired(time.Now()) { switch crt.Version() {
return nil, fmt.Errorf("nebula certificate for this host is expired") case cert.Version1:
if v1 != nil {
return nil, fmt.Errorf("v1 certificate already found in pki.cert")
}
v1 = crt
case cert.Version2:
if v2 != nil {
return nil, fmt.Errorf("v2 certificate already found in pki.cert")
}
v2 = crt
default:
return nil, fmt.Errorf("unknown certificate version %v", crt.Version())
} }
if len(nebulaCert.Networks()) == 0 { if len(rawCert) == 0 || strings.TrimSpace(string(rawCert)) == "" {
return nil, fmt.Errorf("no networks encoded in certificate") break
}
} }
if err = nebulaCert.VerifyPrivateKey(curve, rawKey); err != nil { if v1 == nil && v2 == nil {
return nil, errors.New("no certificates found in pki.cert")
}
useDefaultVersion := uint32(1)
if v1 == nil {
// The only condition that requires v2 as the default is if only a v2 certificate is present
// We do this to avoid having to configure it specifically in the config file
useDefaultVersion = 2
}
rawDefaultVersion := c.GetUint32("pki.default_version", useDefaultVersion)
var defaultVersion cert.Version
switch rawDefaultVersion {
case 1:
if v1 == nil {
return nil, fmt.Errorf("can not use pki.default_version 1 without a v1 certificate in pki.cert")
}
defaultVersion = cert.Version1
case 2:
defaultVersion = cert.Version2
default:
return nil, fmt.Errorf("unknown pki.default_version: %v", rawDefaultVersion)
}
return newCertState(defaultVersion, v1, v2, isPkcs11, curve, rawKey)
}
func newCertState(dv cert.Version, v1, v2 cert.Certificate, pkcs11backed bool, privateKeyCurve cert.Curve, privateKey []byte) (*CertState, error) {
cs := CertState{
privateKey: privateKey,
pkcs11Backed: pkcs11backed,
myVpnNetworksTable: new(bart.Table[struct{}]),
myVpnAddrsTable: new(bart.Table[struct{}]),
myVpnBroadcastAddrsTable: new(bart.Table[struct{}]),
}
if v1 != nil && v2 != nil {
if !slices.Equal(v1.PublicKey(), v2.PublicKey()) {
return nil, util.NewContextualError("v1 and v2 public keys are not the same, ignoring", nil, nil)
}
if v1.Curve() != v2.Curve() {
return nil, util.NewContextualError("v1 and v2 curve are not the same, ignoring", nil, nil)
}
//TODO: make sure v2 has v1s address
cs.defaultVersion = dv
}
if v1 != nil {
if pkcs11backed {
//TODO: We do not currently have a method to verify a public private key pair when the private key is in an hsm
} else {
if err := v1.VerifyPrivateKey(privateKeyCurve, privateKey); err != nil {
return nil, fmt.Errorf("private key is not a pair with public key in nebula cert") return nil, fmt.Errorf("private key is not a pair with public key in nebula cert")
} }
}
return newCertState(nebulaCert, isPkcs11, rawKey) v1hs, err := v1.MarshalForHandshakes()
if err != nil {
return nil, fmt.Errorf("error marshalling certificate for handshake: %w", err)
}
cs.v1Cert = v1
cs.v1HandshakeBytes = v1hs
if cs.defaultVersion == 0 {
cs.defaultVersion = cert.Version1
}
}
if v2 != nil {
if pkcs11backed {
//TODO: We do not currently have a method to verify a public private key pair when the private key is in an hsm
} else {
if err := v2.VerifyPrivateKey(privateKeyCurve, privateKey); err != nil {
return nil, fmt.Errorf("private key is not a pair with public key in nebula cert")
}
}
v2hs, err := v2.MarshalForHandshakes()
if err != nil {
return nil, fmt.Errorf("error marshalling certificate for handshake: %w", err)
}
cs.v2Cert = v2
cs.v2HandshakeBytes = v2hs
if cs.defaultVersion == 0 {
cs.defaultVersion = cert.Version2
}
}
var crt cert.Certificate
crt = cs.getCertificate(cert.Version2)
if crt == nil {
// v2 certificates are a superset, only look at v1 if its all we have
crt = cs.getCertificate(cert.Version1)
}
for _, network := range crt.Networks() {
cs.myVpnNetworks = append(cs.myVpnNetworks, network)
cs.myVpnNetworksTable.Insert(network, struct{}{})
cs.myVpnAddrs = append(cs.myVpnAddrs, network.Addr())
cs.myVpnAddrsTable.Insert(netip.PrefixFrom(network.Addr(), network.Addr().BitLen()), struct{}{})
if network.Addr().Is4() {
addr := network.Masked().Addr().As4()
mask := net.CIDRMask(network.Bits(), network.Addr().BitLen())
binary.BigEndian.PutUint32(addr[:], binary.BigEndian.Uint32(addr[:])|^binary.BigEndian.Uint32(mask))
cs.myVpnBroadcastAddrsTable.Insert(netip.PrefixFrom(netip.AddrFrom4(addr), network.Addr().BitLen()), struct{}{})
}
}
return &cs, nil
}
func loadPrivateKey(privPathOrPEM string) (rawKey []byte, curve cert.Curve, isPkcs11 bool, err error) {
var pemPrivateKey []byte
if strings.Contains(privPathOrPEM, "-----BEGIN") {
pemPrivateKey = []byte(privPathOrPEM)
privPathOrPEM = "<inline>"
rawKey, _, curve, err = cert.UnmarshalPrivateKeyFromPEM(pemPrivateKey)
if err != nil {
return nil, curve, false, fmt.Errorf("error while unmarshaling pki.key %s: %s", privPathOrPEM, err)
}
} else if strings.HasPrefix(privPathOrPEM, "pkcs11:") {
rawKey = []byte(privPathOrPEM)
return rawKey, cert.Curve_P256, true, nil
} else {
pemPrivateKey, err = os.ReadFile(privPathOrPEM)
if err != nil {
return nil, curve, false, fmt.Errorf("unable to read pki.key file %s: %s", privPathOrPEM, err)
}
rawKey, _, curve, err = cert.UnmarshalPrivateKeyFromPEM(pemPrivateKey)
if err != nil {
return nil, curve, false, fmt.Errorf("error while unmarshaling pki.key %s: %s", privPathOrPEM, err)
}
}
return
}
func loadCertificate(b []byte) (cert.Certificate, []byte, error) {
c, b, err := cert.UnmarshalCertificateFromPEM(b)
if err != nil {
return nil, b, fmt.Errorf("error while unmarshaling pki.cert: %w", err)
}
if c.Expired(time.Now()) {
return nil, b, fmt.Errorf("nebula certificate for this host is expired")
}
if len(c.Networks()) == 0 {
return nil, b, fmt.Errorf("no networks encoded in certificate")
}
if c.IsCA() {
return nil, b, fmt.Errorf("host certificate is a CA certificate")
}
return c, b, nil
} }
func loadCAPoolFromConfig(l *logrus.Logger, c *config.C) (*cert.CAPool, error) { func loadCAPoolFromConfig(l *logrus.Logger, c *config.C) (*cert.CAPool, error) {

View File

@ -9,6 +9,7 @@ import (
"sync/atomic" "sync/atomic"
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
"github.com/slackhq/nebula/cert"
"github.com/slackhq/nebula/config" "github.com/slackhq/nebula/config"
"github.com/slackhq/nebula/header" "github.com/slackhq/nebula/header"
) )
@ -72,7 +73,7 @@ func AddRelay(l *logrus.Logger, relayHostInfo *HostInfo, hm *HostMap, vpnIp neti
Type: relayType, Type: relayType,
State: state, State: state,
LocalIndex: index, LocalIndex: index,
PeerIp: vpnIp, PeerAddr: vpnIp,
} }
if remoteIdx != nil { if remoteIdx != nil {
@ -91,40 +92,60 @@ func AddRelay(l *logrus.Logger, relayHostInfo *HostInfo, hm *HostMap, vpnIp neti
func (rm *relayManager) EstablishRelay(relayHostInfo *HostInfo, m *NebulaControl) (*Relay, error) { func (rm *relayManager) EstablishRelay(relayHostInfo *HostInfo, m *NebulaControl) (*Relay, error) {
relay, ok := relayHostInfo.relayState.CompleteRelayByIdx(m.InitiatorRelayIndex, m.ResponderRelayIndex) relay, ok := relayHostInfo.relayState.CompleteRelayByIdx(m.InitiatorRelayIndex, m.ResponderRelayIndex)
if !ok { if !ok {
rm.l.WithFields(logrus.Fields{"relay": relayHostInfo.vpnIp, //TODO: we need to handle possibly logging deprecated fields as well
rm.l.WithFields(logrus.Fields{"relay": relayHostInfo.vpnAddrs[0],
"initiatorRelayIndex": m.InitiatorRelayIndex, "initiatorRelayIndex": m.InitiatorRelayIndex,
"relayFrom": m.RelayFromIp, "relayFrom": m.RelayFromAddr,
"relayTo": m.RelayToIp}).Info("relayManager failed to update relay") "relayTo": m.RelayToAddr}).Info("relayManager failed to update relay")
return nil, fmt.Errorf("unknown relay") return nil, fmt.Errorf("unknown relay")
} }
return relay, nil return relay, nil
} }
func (rm *relayManager) HandleControlMsg(h *HostInfo, m *NebulaControl, f *Interface) { func (rm *relayManager) HandleControlMsg(h *HostInfo, d []byte, f *Interface) {
msg := &NebulaControl{}
err := msg.Unmarshal(d)
if err != nil {
h.logger(f.l).WithError(err).Error("Failed to unmarshal control message")
return
}
switch m.Type { var v cert.Version
if msg.OldRelayFromAddr > 0 || msg.OldRelayToAddr > 0 {
v = cert.Version1
//TODO: yeah this is junk but maybe its less junky than the other options
b := [4]byte{}
binary.BigEndian.PutUint32(b[:], msg.OldRelayFromAddr)
msg.RelayFromAddr = netAddrToProtoAddr(netip.AddrFrom4(b))
binary.BigEndian.PutUint32(b[:], msg.OldRelayToAddr)
msg.RelayToAddr = netAddrToProtoAddr(netip.AddrFrom4(b))
} else {
v = cert.Version2
}
switch msg.Type {
case NebulaControl_CreateRelayRequest: case NebulaControl_CreateRelayRequest:
rm.handleCreateRelayRequest(h, f, m) rm.handleCreateRelayRequest(v, h, f, msg)
case NebulaControl_CreateRelayResponse: case NebulaControl_CreateRelayResponse:
rm.handleCreateRelayResponse(h, f, m) rm.handleCreateRelayResponse(v, h, f, msg)
} }
} }
func (rm *relayManager) handleCreateRelayResponse(h *HostInfo, f *Interface, m *NebulaControl) { func (rm *relayManager) handleCreateRelayResponse(v cert.Version, h *HostInfo, f *Interface, m *NebulaControl) {
rm.l.WithFields(logrus.Fields{ rm.l.WithFields(logrus.Fields{
"relayFrom": m.RelayFromIp, "relayFrom": m.RelayFromAddr,
"relayTo": m.RelayToIp, "relayTo": m.RelayToAddr,
"initiatorRelayIndex": m.InitiatorRelayIndex, "initiatorRelayIndex": m.InitiatorRelayIndex,
"responderRelayIndex": m.ResponderRelayIndex, "responderRelayIndex": m.ResponderRelayIndex,
"vpnIp": h.vpnIp}). "vpnAddrs": h.vpnAddrs}).
Info("handleCreateRelayResponse") Info("handleCreateRelayResponse")
target := m.RelayToIp
//TODO: IPV6-WORK target := m.RelayToAddr
b := [4]byte{} targetAddr := protoAddrToNetAddr(target)
binary.BigEndian.PutUint32(b[:], m.RelayToIp)
targetAddr := netip.AddrFrom4(b)
relay, err := rm.EstablishRelay(h, m) relay, err := rm.EstablishRelay(h, m)
if err != nil { if err != nil {
@ -136,68 +157,79 @@ func (rm *relayManager) handleCreateRelayResponse(h *HostInfo, f *Interface, m *
return return
} }
// I'm the middle man. Let the initiator know that the I've established the relay they requested. // I'm the middle man. Let the initiator know that the I've established the relay they requested.
peerHostInfo := rm.hostmap.QueryVpnIp(relay.PeerIp) peerHostInfo := rm.hostmap.QueryVpnAddr(relay.PeerAddr)
if peerHostInfo == nil { if peerHostInfo == nil {
rm.l.WithField("relayTo", relay.PeerIp).Error("Can't find a HostInfo for peer") rm.l.WithField("relayTo", relay.PeerAddr).Error("Can't find a HostInfo for peer")
return return
} }
peerRelay, ok := peerHostInfo.relayState.QueryRelayForByIp(targetAddr) peerRelay, ok := peerHostInfo.relayState.QueryRelayForByIp(targetAddr)
if !ok { if !ok {
rm.l.WithField("relayTo", peerHostInfo.vpnIp).Error("peerRelay does not have Relay state for relayTo") rm.l.WithField("relayTo", peerHostInfo.vpnAddrs[0]).Error("peerRelay does not have Relay state for relayTo")
return return
} }
if peerRelay.State == PeerRequested { if peerRelay.State == PeerRequested {
//TODO: IPV6-WORK
b = peerHostInfo.vpnIp.As4()
peerRelay.State = Established peerRelay.State = Established
resp := NebulaControl{ resp := NebulaControl{
Type: NebulaControl_CreateRelayResponse, Type: NebulaControl_CreateRelayResponse,
ResponderRelayIndex: peerRelay.LocalIndex, ResponderRelayIndex: peerRelay.LocalIndex,
InitiatorRelayIndex: peerRelay.RemoteIndex, InitiatorRelayIndex: peerRelay.RemoteIndex,
RelayFromIp: binary.BigEndian.Uint32(b[:]),
RelayToIp: uint32(target),
} }
if v == cert.Version1 {
peer := peerHostInfo.vpnAddrs[0]
if !peer.Is4() {
//TODO: log cant do it
return
}
b := peer.As4()
resp.OldRelayFromAddr = binary.BigEndian.Uint32(b[:])
b = targetAddr.As4()
resp.OldRelayToAddr = binary.BigEndian.Uint32(b[:])
} else {
resp.RelayFromAddr = netAddrToProtoAddr(peerHostInfo.vpnAddrs[0])
resp.RelayToAddr = target
}
msg, err := resp.Marshal() msg, err := resp.Marshal()
if err != nil { if err != nil {
rm.l. rm.l.WithError(err).
WithError(err).Error("relayManager Failed to marshal Control CreateRelayResponse message to create relay") Error("relayManager Failed to marshal Control CreateRelayResponse message to create relay")
} else { } else {
f.SendMessageToHostInfo(header.Control, 0, peerHostInfo, msg, make([]byte, 12), make([]byte, mtu)) f.SendMessageToHostInfo(header.Control, 0, peerHostInfo, msg, make([]byte, 12), make([]byte, mtu))
rm.l.WithFields(logrus.Fields{ rm.l.WithFields(logrus.Fields{
"relayFrom": resp.RelayFromIp, "relayFrom": resp.RelayFromAddr,
"relayTo": resp.RelayToIp, "relayTo": resp.RelayToAddr,
"initiatorRelayIndex": resp.InitiatorRelayIndex, "initiatorRelayIndex": resp.InitiatorRelayIndex,
"responderRelayIndex": resp.ResponderRelayIndex, "responderRelayIndex": resp.ResponderRelayIndex,
"vpnIp": peerHostInfo.vpnIp}). "vpnAddrs": peerHostInfo.vpnAddrs}).
Info("send CreateRelayResponse") Info("send CreateRelayResponse")
} }
} }
} }
func (rm *relayManager) handleCreateRelayRequest(h *HostInfo, f *Interface, m *NebulaControl) { func (rm *relayManager) handleCreateRelayRequest(v cert.Version, h *HostInfo, f *Interface, m *NebulaControl) {
//TODO: IPV6-WORK from := protoAddrToNetAddr(m.RelayFromAddr)
b := [4]byte{} target := protoAddrToNetAddr(m.RelayToAddr)
binary.BigEndian.PutUint32(b[:], m.RelayFromIp)
from := netip.AddrFrom4(b)
binary.BigEndian.PutUint32(b[:], m.RelayToIp)
target := netip.AddrFrom4(b)
logMsg := rm.l.WithFields(logrus.Fields{ logMsg := rm.l.WithFields(logrus.Fields{
"relayFrom": from, "relayFrom": from,
"relayTo": target, "relayTo": target,
"initiatorRelayIndex": m.InitiatorRelayIndex, "initiatorRelayIndex": m.InitiatorRelayIndex,
"vpnIp": h.vpnIp}) "vpnAddrs": h.vpnAddrs})
logMsg.Info("handleCreateRelayRequest") logMsg.Info("handleCreateRelayRequest")
// Is the source of the relay me? This should never happen, but did happen due to // Is the source of the relay me? This should never happen, but did happen due to
// an issue migrating relays over to newly re-handshaked host info objects. // an issue migrating relays over to newly re-handshaked host info objects.
if from == f.myVpnNet.Addr() { _, found := f.myVpnAddrsTable.Lookup(from)
if found {
logMsg.WithField("myIP", from).Error("Discarding relay request from myself") logMsg.WithField("myIP", from).Error("Discarding relay request from myself")
return return
} }
// Is the target of the relay me? // Is the target of the relay me?
if target == f.myVpnNet.Addr() { _, found = f.myVpnAddrsTable.Lookup(target)
if found {
existingRelay, ok := h.relayState.QueryRelayForByIp(from) existingRelay, ok := h.relayState.QueryRelayForByIp(from)
if ok { if ok {
switch existingRelay.State { switch existingRelay.State {
@ -230,17 +262,22 @@ func (rm *relayManager) handleCreateRelayRequest(h *HostInfo, f *Interface, m *N
return return
} }
//TODO: IPV6-WORK
fromB := from.As4()
targetB := target.As4()
resp := NebulaControl{ resp := NebulaControl{
Type: NebulaControl_CreateRelayResponse, Type: NebulaControl_CreateRelayResponse,
ResponderRelayIndex: relay.LocalIndex, ResponderRelayIndex: relay.LocalIndex,
InitiatorRelayIndex: relay.RemoteIndex, InitiatorRelayIndex: relay.RemoteIndex,
RelayFromIp: binary.BigEndian.Uint32(fromB[:]),
RelayToIp: binary.BigEndian.Uint32(targetB[:]),
} }
if v == cert.Version1 {
b := from.As4()
resp.OldRelayFromAddr = binary.BigEndian.Uint32(b[:])
b = target.As4()
resp.OldRelayToAddr = binary.BigEndian.Uint32(b[:])
} else {
resp.RelayFromAddr = netAddrToProtoAddr(from)
resp.RelayToAddr = netAddrToProtoAddr(target)
}
msg, err := resp.Marshal() msg, err := resp.Marshal()
if err != nil { if err != nil {
logMsg. logMsg.
@ -253,7 +290,7 @@ func (rm *relayManager) handleCreateRelayRequest(h *HostInfo, f *Interface, m *N
"relayTo": target, "relayTo": target,
"initiatorRelayIndex": resp.InitiatorRelayIndex, "initiatorRelayIndex": resp.InitiatorRelayIndex,
"responderRelayIndex": resp.ResponderRelayIndex, "responderRelayIndex": resp.ResponderRelayIndex,
"vpnIp": h.vpnIp}). "vpnAddrs": h.vpnAddrs}).
Info("send CreateRelayResponse") Info("send CreateRelayResponse")
} }
return return
@ -262,7 +299,7 @@ func (rm *relayManager) handleCreateRelayRequest(h *HostInfo, f *Interface, m *N
if !rm.GetAmRelay() { if !rm.GetAmRelay() {
return return
} }
peer := rm.hostmap.QueryVpnIp(target) peer := rm.hostmap.QueryVpnAddr(target)
if peer == nil { if peer == nil {
// Try to establish a connection to this host. If we get a future relay request, // Try to establish a connection to this host. If we get a future relay request,
// we'll be ready! // we'll be ready!
@ -291,17 +328,27 @@ func (rm *relayManager) handleCreateRelayRequest(h *HostInfo, f *Interface, m *N
sendCreateRequest = true sendCreateRequest = true
} }
if sendCreateRequest { if sendCreateRequest {
//TODO: IPV6-WORK
fromB := h.vpnIp.As4()
targetB := target.As4()
// Send a CreateRelayRequest to the peer. // Send a CreateRelayRequest to the peer.
req := NebulaControl{ req := NebulaControl{
Type: NebulaControl_CreateRelayRequest, Type: NebulaControl_CreateRelayRequest,
InitiatorRelayIndex: index, InitiatorRelayIndex: index,
RelayFromIp: binary.BigEndian.Uint32(fromB[:]),
RelayToIp: binary.BigEndian.Uint32(targetB[:]),
} }
if v == cert.Version1 {
if !h.vpnAddrs[0].Is4() {
//TODO: log it
return
}
b := h.vpnAddrs[0].As4()
req.OldRelayFromAddr = binary.BigEndian.Uint32(b[:])
b = target.As4()
req.OldRelayToAddr = binary.BigEndian.Uint32(b[:])
} else {
req.RelayFromAddr = netAddrToProtoAddr(h.vpnAddrs[0])
req.RelayToAddr = netAddrToProtoAddr(target)
}
msg, err := req.Marshal() msg, err := req.Marshal()
if err != nil { if err != nil {
logMsg. logMsg.
@ -310,11 +357,11 @@ func (rm *relayManager) handleCreateRelayRequest(h *HostInfo, f *Interface, m *N
f.SendMessageToHostInfo(header.Control, 0, peer, msg, make([]byte, 12), make([]byte, mtu)) f.SendMessageToHostInfo(header.Control, 0, peer, msg, make([]byte, 12), make([]byte, mtu))
rm.l.WithFields(logrus.Fields{ rm.l.WithFields(logrus.Fields{
//TODO: IPV6-WORK another lazy used to use the req object //TODO: IPV6-WORK another lazy used to use the req object
"relayFrom": h.vpnIp, "relayFrom": h.vpnAddrs[0],
"relayTo": target, "relayTo": target,
"initiatorRelayIndex": req.InitiatorRelayIndex, "initiatorRelayIndex": req.InitiatorRelayIndex,
"responderRelayIndex": req.ResponderRelayIndex, "responderRelayIndex": req.ResponderRelayIndex,
"vpnIp": target}). "vpnAddr": target}).
Info("send CreateRelayRequest") Info("send CreateRelayRequest")
} }
} }
@ -342,16 +389,28 @@ func (rm *relayManager) handleCreateRelayRequest(h *HostInfo, f *Interface, m *N
"existingRemoteIndex": relay.RemoteIndex}).Error("Existing relay mismatch with CreateRelayRequest") "existingRemoteIndex": relay.RemoteIndex}).Error("Existing relay mismatch with CreateRelayRequest")
return return
} }
//TODO: IPV6-WORK
fromB := h.vpnIp.As4()
targetB := target.As4()
resp := NebulaControl{ resp := NebulaControl{
Type: NebulaControl_CreateRelayResponse, Type: NebulaControl_CreateRelayResponse,
ResponderRelayIndex: relay.LocalIndex, ResponderRelayIndex: relay.LocalIndex,
InitiatorRelayIndex: relay.RemoteIndex, InitiatorRelayIndex: relay.RemoteIndex,
RelayFromIp: binary.BigEndian.Uint32(fromB[:]),
RelayToIp: binary.BigEndian.Uint32(targetB[:]),
} }
if v == cert.Version1 {
if !h.vpnAddrs[0].Is4() {
//TODO: log it
return
}
b := h.vpnAddrs[0].As4()
resp.OldRelayFromAddr = binary.BigEndian.Uint32(b[:])
b = target.As4()
resp.OldRelayToAddr = binary.BigEndian.Uint32(b[:])
} else {
resp.RelayFromAddr = netAddrToProtoAddr(h.vpnAddrs[0])
resp.RelayToAddr = netAddrToProtoAddr(target)
}
msg, err := resp.Marshal() msg, err := resp.Marshal()
if err != nil { if err != nil {
rm.l. rm.l.
@ -360,11 +419,11 @@ func (rm *relayManager) handleCreateRelayRequest(h *HostInfo, f *Interface, m *N
f.SendMessageToHostInfo(header.Control, 0, h, msg, make([]byte, 12), make([]byte, mtu)) f.SendMessageToHostInfo(header.Control, 0, h, msg, make([]byte, 12), make([]byte, mtu))
rm.l.WithFields(logrus.Fields{ rm.l.WithFields(logrus.Fields{
//TODO: IPV6-WORK more lazy, used to use resp object //TODO: IPV6-WORK more lazy, used to use resp object
"relayFrom": h.vpnIp, "relayFrom": h.vpnAddrs[0],
"relayTo": target, "relayTo": target,
"initiatorRelayIndex": resp.InitiatorRelayIndex, "initiatorRelayIndex": resp.InitiatorRelayIndex,
"responderRelayIndex": resp.ResponderRelayIndex, "responderRelayIndex": resp.ResponderRelayIndex,
"vpnIp": h.vpnIp}). "vpnAddrs": h.vpnAddrs}).
Info("send CreateRelayResponse") Info("send CreateRelayResponse")
} }

View File

@ -17,8 +17,8 @@ import (
type forEachFunc func(addr netip.AddrPort, preferred bool) type forEachFunc func(addr netip.AddrPort, preferred bool)
// The checkFuncs here are to simplify bulk importing LH query response logic into a single function (reset slice and iterate) // The checkFuncs here are to simplify bulk importing LH query response logic into a single function (reset slice and iterate)
type checkFuncV4 func(vpnIp netip.Addr, to *Ip4AndPort) bool type checkFuncV4 func(vpnIp netip.Addr, to *V4AddrPort) bool
type checkFuncV6 func(vpnIp netip.Addr, to *Ip6AndPort) bool type checkFuncV6 func(vpnIp netip.Addr, to *V6AddrPort) bool
// CacheMap is a struct that better represents the lighthouse cache for humans // CacheMap is a struct that better represents the lighthouse cache for humans
// The string key is the owners vpnIp // The string key is the owners vpnIp
@ -48,14 +48,14 @@ type cacheRelay struct {
// cacheV4 stores learned and reported ipv4 records under cache // cacheV4 stores learned and reported ipv4 records under cache
type cacheV4 struct { type cacheV4 struct {
learned *Ip4AndPort learned *V4AddrPort
reported []*Ip4AndPort reported []*V4AddrPort
} }
// cacheV4 stores learned and reported ipv6 records under cache // cacheV4 stores learned and reported ipv6 records under cache
type cacheV6 struct { type cacheV6 struct {
learned *Ip6AndPort learned *V6AddrPort
reported []*Ip6AndPort reported []*V6AddrPort
} }
type hostnamePort struct { type hostnamePort struct {
@ -170,7 +170,7 @@ func (hr *hostnamesResults) Cancel() {
} }
} }
func (hr *hostnamesResults) GetIPs() []netip.AddrPort { func (hr *hostnamesResults) GetAddrs() []netip.AddrPort {
var retSlice []netip.AddrPort var retSlice []netip.AddrPort
if hr != nil { if hr != nil {
p := hr.ips.Load() p := hr.ips.Load()
@ -189,6 +189,9 @@ type RemoteList struct {
// Every interaction with internals requires a lock! // Every interaction with internals requires a lock!
sync.RWMutex sync.RWMutex
// The full list of vpn addresses assigned to this host
vpnAddrs []netip.Addr
// A deduplicated set of addresses. Any accessor should lock beforehand. // A deduplicated set of addresses. Any accessor should lock beforehand.
addrs []netip.AddrPort addrs []netip.AddrPort
@ -212,13 +215,16 @@ type RemoteList struct {
} }
// NewRemoteList creates a new empty RemoteList // NewRemoteList creates a new empty RemoteList
func NewRemoteList(shouldAdd func(netip.Addr) bool) *RemoteList { func NewRemoteList(vpnAddrs []netip.Addr, shouldAdd func(netip.Addr) bool) *RemoteList {
return &RemoteList{ r := &RemoteList{
vpnAddrs: make([]netip.Addr, len(vpnAddrs)),
addrs: make([]netip.AddrPort, 0), addrs: make([]netip.AddrPort, 0),
relays: make([]netip.Addr, 0), relays: make([]netip.Addr, 0),
cache: make(map[netip.Addr]*cache), cache: make(map[netip.Addr]*cache),
shouldAdd: shouldAdd, shouldAdd: shouldAdd,
} }
copy(r.vpnAddrs, vpnAddrs)
return r
} }
func (r *RemoteList) unlockedSetHostnamesResults(hr *hostnamesResults) { func (r *RemoteList) unlockedSetHostnamesResults(hr *hostnamesResults) {
@ -273,9 +279,9 @@ func (r *RemoteList) LearnRemote(ownerVpnIp netip.Addr, remote netip.AddrPort) {
r.Lock() r.Lock()
defer r.Unlock() defer r.Unlock()
if remote.Addr().Is4() { if remote.Addr().Is4() {
r.unlockedSetLearnedV4(ownerVpnIp, NewIp4AndPortFromNetIP(remote.Addr(), remote.Port())) r.unlockedSetLearnedV4(ownerVpnIp, netAddrToProtoV4AddrPort(remote.Addr(), remote.Port()))
} else { } else {
r.unlockedSetLearnedV6(ownerVpnIp, NewIp6AndPortFromNetIP(remote.Addr(), remote.Port())) r.unlockedSetLearnedV6(ownerVpnIp, netAddrToProtoV6AddrPort(remote.Addr(), remote.Port()))
} }
} }
@ -304,21 +310,21 @@ func (r *RemoteList) CopyCache() *CacheMap {
if mc.v4 != nil { if mc.v4 != nil {
if mc.v4.learned != nil { if mc.v4.learned != nil {
c.Learned = append(c.Learned, AddrPortFromIp4AndPort(mc.v4.learned)) c.Learned = append(c.Learned, protoV4AddrPortToNetAddrPort(mc.v4.learned))
} }
for _, a := range mc.v4.reported { for _, a := range mc.v4.reported {
c.Reported = append(c.Reported, AddrPortFromIp4AndPort(a)) c.Reported = append(c.Reported, protoV4AddrPortToNetAddrPort(a))
} }
} }
if mc.v6 != nil { if mc.v6 != nil {
if mc.v6.learned != nil { if mc.v6.learned != nil {
c.Learned = append(c.Learned, AddrPortFromIp6AndPort(mc.v6.learned)) c.Learned = append(c.Learned, protoV6AddrPortToNetAddrPort(mc.v6.learned))
} }
for _, a := range mc.v6.reported { for _, a := range mc.v6.reported {
c.Reported = append(c.Reported, AddrPortFromIp6AndPort(a)) c.Reported = append(c.Reported, protoV6AddrPortToNetAddrPort(a))
} }
} }
@ -401,14 +407,14 @@ func (r *RemoteList) unlockedIsBad(remote netip.AddrPort) bool {
// unlockedSetLearnedV4 assumes you have the write lock and sets the current learned address for this owner and marks the // unlockedSetLearnedV4 assumes you have the write lock and sets the current learned address for this owner and marks the
// deduplicated address list as dirty // deduplicated address list as dirty
func (r *RemoteList) unlockedSetLearnedV4(ownerVpnIp netip.Addr, to *Ip4AndPort) { func (r *RemoteList) unlockedSetLearnedV4(ownerVpnIp netip.Addr, to *V4AddrPort) {
r.shouldRebuild = true r.shouldRebuild = true
r.unlockedGetOrMakeV4(ownerVpnIp).learned = to r.unlockedGetOrMakeV4(ownerVpnIp).learned = to
} }
// unlockedSetV4 assumes you have the write lock and resets the reported list of ips for this owner to the list provided // unlockedSetV4 assumes you have the write lock and resets the reported list of ips for this owner to the list provided
// and marks the deduplicated address list as dirty // and marks the deduplicated address list as dirty
func (r *RemoteList) unlockedSetV4(ownerVpnIp, vpnIp netip.Addr, to []*Ip4AndPort, check checkFuncV4) { func (r *RemoteList) unlockedSetV4(ownerVpnIp, vpnIp netip.Addr, to []*V4AddrPort, check checkFuncV4) {
r.shouldRebuild = true r.shouldRebuild = true
c := r.unlockedGetOrMakeV4(ownerVpnIp) c := r.unlockedGetOrMakeV4(ownerVpnIp)
@ -436,12 +442,12 @@ func (r *RemoteList) unlockedSetRelay(ownerVpnIp, vpnIp netip.Addr, to []netip.A
// unlockedPrependV4 assumes you have the write lock and prepends the address in the reported list for this owner // unlockedPrependV4 assumes you have the write lock and prepends the address in the reported list for this owner
// This is only useful for establishing static hosts // This is only useful for establishing static hosts
func (r *RemoteList) unlockedPrependV4(ownerVpnIp netip.Addr, to *Ip4AndPort) { func (r *RemoteList) unlockedPrependV4(ownerVpnIp netip.Addr, to *V4AddrPort) {
r.shouldRebuild = true r.shouldRebuild = true
c := r.unlockedGetOrMakeV4(ownerVpnIp) c := r.unlockedGetOrMakeV4(ownerVpnIp)
// We are doing the easy append because this is rarely called // We are doing the easy append because this is rarely called
c.reported = append([]*Ip4AndPort{to}, c.reported...) c.reported = append([]*V4AddrPort{to}, c.reported...)
if len(c.reported) > MaxRemotes { if len(c.reported) > MaxRemotes {
c.reported = c.reported[:MaxRemotes] c.reported = c.reported[:MaxRemotes]
} }
@ -449,14 +455,14 @@ func (r *RemoteList) unlockedPrependV4(ownerVpnIp netip.Addr, to *Ip4AndPort) {
// unlockedSetLearnedV6 assumes you have the write lock and sets the current learned address for this owner and marks the // unlockedSetLearnedV6 assumes you have the write lock and sets the current learned address for this owner and marks the
// deduplicated address list as dirty // deduplicated address list as dirty
func (r *RemoteList) unlockedSetLearnedV6(ownerVpnIp netip.Addr, to *Ip6AndPort) { func (r *RemoteList) unlockedSetLearnedV6(ownerVpnIp netip.Addr, to *V6AddrPort) {
r.shouldRebuild = true r.shouldRebuild = true
r.unlockedGetOrMakeV6(ownerVpnIp).learned = to r.unlockedGetOrMakeV6(ownerVpnIp).learned = to
} }
// unlockedSetV6 assumes you have the write lock and resets the reported list of ips for this owner to the list provided // unlockedSetV6 assumes you have the write lock and resets the reported list of ips for this owner to the list provided
// and marks the deduplicated address list as dirty // and marks the deduplicated address list as dirty
func (r *RemoteList) unlockedSetV6(ownerVpnIp, vpnIp netip.Addr, to []*Ip6AndPort, check checkFuncV6) { func (r *RemoteList) unlockedSetV6(ownerVpnIp, vpnIp netip.Addr, to []*V6AddrPort, check checkFuncV6) {
r.shouldRebuild = true r.shouldRebuild = true
c := r.unlockedGetOrMakeV6(ownerVpnIp) c := r.unlockedGetOrMakeV6(ownerVpnIp)
@ -473,12 +479,12 @@ func (r *RemoteList) unlockedSetV6(ownerVpnIp, vpnIp netip.Addr, to []*Ip6AndPor
// unlockedPrependV6 assumes you have the write lock and prepends the address in the reported list for this owner // unlockedPrependV6 assumes you have the write lock and prepends the address in the reported list for this owner
// This is only useful for establishing static hosts // This is only useful for establishing static hosts
func (r *RemoteList) unlockedPrependV6(ownerVpnIp netip.Addr, to *Ip6AndPort) { func (r *RemoteList) unlockedPrependV6(ownerVpnIp netip.Addr, to *V6AddrPort) {
r.shouldRebuild = true r.shouldRebuild = true
c := r.unlockedGetOrMakeV6(ownerVpnIp) c := r.unlockedGetOrMakeV6(ownerVpnIp)
// We are doing the easy append because this is rarely called // We are doing the easy append because this is rarely called
c.reported = append([]*Ip6AndPort{to}, c.reported...) c.reported = append([]*V6AddrPort{to}, c.reported...)
if len(c.reported) > MaxRemotes { if len(c.reported) > MaxRemotes {
c.reported = c.reported[:MaxRemotes] c.reported = c.reported[:MaxRemotes]
} }
@ -536,14 +542,14 @@ func (r *RemoteList) unlockedCollect() {
for _, c := range r.cache { for _, c := range r.cache {
if c.v4 != nil { if c.v4 != nil {
if c.v4.learned != nil { if c.v4.learned != nil {
u := AddrPortFromIp4AndPort(c.v4.learned) u := protoV4AddrPortToNetAddrPort(c.v4.learned)
if !r.unlockedIsBad(u) { if !r.unlockedIsBad(u) {
addrs = append(addrs, u) addrs = append(addrs, u)
} }
} }
for _, v := range c.v4.reported { for _, v := range c.v4.reported {
u := AddrPortFromIp4AndPort(v) u := protoV4AddrPortToNetAddrPort(v)
if !r.unlockedIsBad(u) { if !r.unlockedIsBad(u) {
addrs = append(addrs, u) addrs = append(addrs, u)
} }
@ -552,14 +558,14 @@ func (r *RemoteList) unlockedCollect() {
if c.v6 != nil { if c.v6 != nil {
if c.v6.learned != nil { if c.v6.learned != nil {
u := AddrPortFromIp6AndPort(c.v6.learned) u := protoV6AddrPortToNetAddrPort(c.v6.learned)
if !r.unlockedIsBad(u) { if !r.unlockedIsBad(u) {
addrs = append(addrs, u) addrs = append(addrs, u)
} }
} }
for _, v := range c.v6.reported { for _, v := range c.v6.reported {
u := AddrPortFromIp6AndPort(v) u := protoV6AddrPortToNetAddrPort(v)
if !r.unlockedIsBad(u) { if !r.unlockedIsBad(u) {
addrs = append(addrs, u) addrs = append(addrs, u)
} }
@ -573,7 +579,7 @@ func (r *RemoteList) unlockedCollect() {
} }
} }
dnsAddrs := r.hr.GetIPs() dnsAddrs := r.hr.GetAddrs()
for _, addr := range dnsAddrs { for _, addr := range dnsAddrs {
if r.shouldAdd == nil || r.shouldAdd(addr.Addr()) { if r.shouldAdd == nil || r.shouldAdd(addr.Addr()) {
if !r.unlockedIsBad(addr) { if !r.unlockedIsBad(addr) {

View File

@ -9,11 +9,11 @@ import (
) )
func TestRemoteList_Rebuild(t *testing.T) { func TestRemoteList_Rebuild(t *testing.T) {
rl := NewRemoteList(nil) rl := NewRemoteList([]netip.Addr{netip.MustParseAddr("0.0.0.0")}, nil)
rl.unlockedSetV4( rl.unlockedSetV4(
netip.MustParseAddr("0.0.0.0"), netip.MustParseAddr("0.0.0.0"),
netip.MustParseAddr("0.0.0.0"), netip.MustParseAddr("0.0.0.0"),
[]*Ip4AndPort{ []*V4AddrPort{
newIp4AndPortFromString("70.199.182.92:1475"), // this is duped newIp4AndPortFromString("70.199.182.92:1475"), // this is duped
newIp4AndPortFromString("172.17.0.182:10101"), newIp4AndPortFromString("172.17.0.182:10101"),
newIp4AndPortFromString("172.17.1.1:10101"), // this is duped newIp4AndPortFromString("172.17.1.1:10101"), // this is duped
@ -25,20 +25,20 @@ func TestRemoteList_Rebuild(t *testing.T) {
newIp4AndPortFromString("70.199.182.92:1476"), // almost dupe of 0 with a diff port newIp4AndPortFromString("70.199.182.92:1476"), // almost dupe of 0 with a diff port
newIp4AndPortFromString("70.199.182.92:1475"), // this is a dupe newIp4AndPortFromString("70.199.182.92:1475"), // this is a dupe
}, },
func(netip.Addr, *Ip4AndPort) bool { return true }, func(netip.Addr, *V4AddrPort) bool { return true },
) )
rl.unlockedSetV6( rl.unlockedSetV6(
netip.MustParseAddr("0.0.0.1"), netip.MustParseAddr("0.0.0.1"),
netip.MustParseAddr("0.0.0.1"), netip.MustParseAddr("0.0.0.1"),
[]*Ip6AndPort{ []*V6AddrPort{
newIp6AndPortFromString("[1::1]:1"), // this is duped newIp6AndPortFromString("[1::1]:1"), // this is duped
newIp6AndPortFromString("[1::1]:2"), // almost dupe of 0 with a diff port, also gets duped newIp6AndPortFromString("[1::1]:2"), // almost dupe of 0 with a diff port, also gets duped
newIp6AndPortFromString("[1:100::1]:1"), newIp6AndPortFromString("[1:100::1]:1"),
newIp6AndPortFromString("[1::1]:1"), // this is a dupe newIp6AndPortFromString("[1::1]:1"), // this is a dupe
newIp6AndPortFromString("[1::1]:2"), // this is a dupe newIp6AndPortFromString("[1::1]:2"), // this is a dupe
}, },
func(netip.Addr, *Ip6AndPort) bool { return true }, func(netip.Addr, *V6AddrPort) bool { return true },
) )
rl.Rebuild([]netip.Prefix{}) rl.Rebuild([]netip.Prefix{})
@ -98,11 +98,11 @@ func TestRemoteList_Rebuild(t *testing.T) {
} }
func BenchmarkFullRebuild(b *testing.B) { func BenchmarkFullRebuild(b *testing.B) {
rl := NewRemoteList(nil) rl := NewRemoteList([]netip.Addr{netip.MustParseAddr("0.0.0.0")}, nil)
rl.unlockedSetV4( rl.unlockedSetV4(
netip.MustParseAddr("0.0.0.0"), netip.MustParseAddr("0.0.0.0"),
netip.MustParseAddr("0.0.0.0"), netip.MustParseAddr("0.0.0.0"),
[]*Ip4AndPort{ []*V4AddrPort{
newIp4AndPortFromString("70.199.182.92:1475"), newIp4AndPortFromString("70.199.182.92:1475"),
newIp4AndPortFromString("172.17.0.182:10101"), newIp4AndPortFromString("172.17.0.182:10101"),
newIp4AndPortFromString("172.17.1.1:10101"), newIp4AndPortFromString("172.17.1.1:10101"),
@ -112,19 +112,19 @@ func BenchmarkFullRebuild(b *testing.B) {
newIp4AndPortFromString("172.17.1.1:10101"), // this is a dupe newIp4AndPortFromString("172.17.1.1:10101"), // this is a dupe
newIp4AndPortFromString("70.199.182.92:1476"), // dupe of 0 with a diff port newIp4AndPortFromString("70.199.182.92:1476"), // dupe of 0 with a diff port
}, },
func(netip.Addr, *Ip4AndPort) bool { return true }, func(netip.Addr, *V4AddrPort) bool { return true },
) )
rl.unlockedSetV6( rl.unlockedSetV6(
netip.MustParseAddr("0.0.0.0"), netip.MustParseAddr("0.0.0.0"),
netip.MustParseAddr("0.0.0.0"), netip.MustParseAddr("0.0.0.0"),
[]*Ip6AndPort{ []*V6AddrPort{
newIp6AndPortFromString("[1::1]:1"), newIp6AndPortFromString("[1::1]:1"),
newIp6AndPortFromString("[1::1]:2"), // dupe of 0 with a diff port newIp6AndPortFromString("[1::1]:2"), // dupe of 0 with a diff port
newIp6AndPortFromString("[1:100::1]:1"), newIp6AndPortFromString("[1:100::1]:1"),
newIp6AndPortFromString("[1::1]:1"), // this is a dupe newIp6AndPortFromString("[1::1]:1"), // this is a dupe
}, },
func(netip.Addr, *Ip6AndPort) bool { return true }, func(netip.Addr, *V6AddrPort) bool { return true },
) )
b.Run("no preferred", func(b *testing.B) { b.Run("no preferred", func(b *testing.B) {
@ -160,11 +160,11 @@ func BenchmarkFullRebuild(b *testing.B) {
} }
func BenchmarkSortRebuild(b *testing.B) { func BenchmarkSortRebuild(b *testing.B) {
rl := NewRemoteList(nil) rl := NewRemoteList([]netip.Addr{netip.MustParseAddr("0.0.0.0")}, nil)
rl.unlockedSetV4( rl.unlockedSetV4(
netip.MustParseAddr("0.0.0.0"), netip.MustParseAddr("0.0.0.0"),
netip.MustParseAddr("0.0.0.0"), netip.MustParseAddr("0.0.0.0"),
[]*Ip4AndPort{ []*V4AddrPort{
newIp4AndPortFromString("70.199.182.92:1475"), newIp4AndPortFromString("70.199.182.92:1475"),
newIp4AndPortFromString("172.17.0.182:10101"), newIp4AndPortFromString("172.17.0.182:10101"),
newIp4AndPortFromString("172.17.1.1:10101"), newIp4AndPortFromString("172.17.1.1:10101"),
@ -174,19 +174,19 @@ func BenchmarkSortRebuild(b *testing.B) {
newIp4AndPortFromString("172.17.1.1:10101"), // this is a dupe newIp4AndPortFromString("172.17.1.1:10101"), // this is a dupe
newIp4AndPortFromString("70.199.182.92:1476"), // dupe of 0 with a diff port newIp4AndPortFromString("70.199.182.92:1476"), // dupe of 0 with a diff port
}, },
func(netip.Addr, *Ip4AndPort) bool { return true }, func(netip.Addr, *V4AddrPort) bool { return true },
) )
rl.unlockedSetV6( rl.unlockedSetV6(
netip.MustParseAddr("0.0.0.0"), netip.MustParseAddr("0.0.0.0"),
netip.MustParseAddr("0.0.0.0"), netip.MustParseAddr("0.0.0.0"),
[]*Ip6AndPort{ []*V6AddrPort{
newIp6AndPortFromString("[1::1]:1"), newIp6AndPortFromString("[1::1]:1"),
newIp6AndPortFromString("[1::1]:2"), // dupe of 0 with a diff port newIp6AndPortFromString("[1::1]:2"), // dupe of 0 with a diff port
newIp6AndPortFromString("[1:100::1]:1"), newIp6AndPortFromString("[1:100::1]:1"),
newIp6AndPortFromString("[1::1]:1"), // this is a dupe newIp6AndPortFromString("[1::1]:1"), // this is a dupe
}, },
func(netip.Addr, *Ip6AndPort) bool { return true }, func(netip.Addr, *V6AddrPort) bool { return true },
) )
b.Run("no preferred", func(b *testing.B) { b.Run("no preferred", func(b *testing.B) {
@ -224,19 +224,19 @@ func BenchmarkSortRebuild(b *testing.B) {
}) })
} }
func newIp4AndPortFromString(s string) *Ip4AndPort { func newIp4AndPortFromString(s string) *V4AddrPort {
a := netip.MustParseAddrPort(s) a := netip.MustParseAddrPort(s)
v4Addr := a.Addr().As4() v4Addr := a.Addr().As4()
return &Ip4AndPort{ return &V4AddrPort{
Ip: binary.BigEndian.Uint32(v4Addr[:]), Addr: binary.BigEndian.Uint32(v4Addr[:]),
Port: uint32(a.Port()), Port: uint32(a.Port()),
} }
} }
func newIp6AndPortFromString(s string) *Ip6AndPort { func newIp6AndPortFromString(s string) *V6AddrPort {
a := netip.MustParseAddrPort(s) a := netip.MustParseAddrPort(s)
v6Addr := a.Addr().As16() v6Addr := a.Addr().As16()
return &Ip6AndPort{ return &V6AddrPort{
Hi: binary.BigEndian.Uint64(v6Addr[:8]), Hi: binary.BigEndian.Uint64(v6Addr[:8]),
Lo: binary.BigEndian.Uint64(v6Addr[8:]), Lo: binary.BigEndian.Uint64(v6Addr[8:]),
Port: uint32(a.Port()), Port: uint32(a.Port()),

View File

@ -90,9 +90,9 @@ func New(config *config.C) (*Service, error) {
}, },
}) })
ipNet := device.Cidr() ipNet := device.Networks()
pa := tcpip.ProtocolAddress{ pa := tcpip.ProtocolAddress{
AddressWithPrefix: tcpip.AddrFromSlice(ipNet.Addr().AsSlice()).WithPrefix(), AddressWithPrefix: tcpip.AddrFromSlice(ipNet[0].Addr().AsSlice()).WithPrefix(),
Protocol: ipv4.ProtocolNumber, Protocol: ipv4.ProtocolNumber,
} }
if err := s.ipstack.AddProtocolAddress(nicID, pa, stack.AddressProperties{ if err := s.ipstack.AddProtocolAddress(nicID, pa, stack.AddressProperties{

View File

@ -19,7 +19,7 @@ import (
type m map[string]interface{} type m map[string]interface{}
func newSimpleService(caCrt cert.Certificate, caKey []byte, name string, udpIp netip.Addr, overrides m) *Service { func newSimpleService(caCrt cert.Certificate, caKey []byte, name string, udpIp netip.Addr, overrides m) *Service {
_, _, myPrivKey, myPEM := e2e.NewTestCert(caCrt, caKey, "a", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{netip.PrefixFrom(udpIp, 24)}, nil, []string{}) _, _, myPrivKey, myPEM := e2e.NewTestCert(cert.Version2, caCrt, caKey, "a", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{netip.PrefixFrom(udpIp, 24)}, nil, []string{})
caB, err := caCrt.MarshalPEM() caB, err := caCrt.MarshalPEM()
if err != nil { if err != nil {
panic(err) panic(err)

35
ssh.go
View File

@ -430,7 +430,7 @@ func sshListHostMap(hl controlHostLister, a interface{}, w sshd.StringWriter) er
} }
sort.Slice(hm, func(i, j int) bool { sort.Slice(hm, func(i, j int) bool {
return hm[i].VpnIp.Compare(hm[j].VpnIp) < 0 return hm[i].VpnAddrs[0].Compare(hm[j].VpnAddrs[0]) < 0
}) })
if fs.Json || fs.Pretty { if fs.Json || fs.Pretty {
@ -447,7 +447,7 @@ func sshListHostMap(hl controlHostLister, a interface{}, w sshd.StringWriter) er
} else { } else {
for _, v := range hm { for _, v := range hm {
err := w.WriteLine(fmt.Sprintf("%s: %s", v.VpnIp, v.RemoteAddrs)) err := w.WriteLine(fmt.Sprintf("%s: %s", v.VpnAddrs, v.RemoteAddrs))
if err != nil { if err != nil {
return err return err
} }
@ -581,7 +581,7 @@ func sshCloseTunnel(ifce *Interface, fs interface{}, a []string, w sshd.StringWr
return w.WriteLine(fmt.Sprintf("The provided vpn ip could not be parsed: %s", a[0])) return w.WriteLine(fmt.Sprintf("The provided vpn ip could not be parsed: %s", a[0]))
} }
hostInfo := ifce.hostMap.QueryVpnIp(vpnIp) hostInfo := ifce.hostMap.QueryVpnAddr(vpnIp)
if hostInfo == nil { if hostInfo == nil {
return w.WriteLine(fmt.Sprintf("Could not find tunnel for vpn ip: %v", a[0])) return w.WriteLine(fmt.Sprintf("Could not find tunnel for vpn ip: %v", a[0]))
} }
@ -622,12 +622,12 @@ func sshCreateTunnel(ifce *Interface, fs interface{}, a []string, w sshd.StringW
return w.WriteLine(fmt.Sprintf("The provided vpn ip could not be parsed: %s", a[0])) return w.WriteLine(fmt.Sprintf("The provided vpn ip could not be parsed: %s", a[0]))
} }
hostInfo := ifce.hostMap.QueryVpnIp(vpnIp) hostInfo := ifce.hostMap.QueryVpnAddr(vpnIp)
if hostInfo != nil { if hostInfo != nil {
return w.WriteLine(fmt.Sprintf("Tunnel already exists")) return w.WriteLine(fmt.Sprintf("Tunnel already exists"))
} }
hostInfo = ifce.handshakeManager.QueryVpnIp(vpnIp) hostInfo = ifce.handshakeManager.QueryVpnAddr(vpnIp)
if hostInfo != nil { if hostInfo != nil {
return w.WriteLine(fmt.Sprintf("Tunnel already handshaking")) return w.WriteLine(fmt.Sprintf("Tunnel already handshaking"))
} }
@ -677,7 +677,7 @@ func sshChangeRemote(ifce *Interface, fs interface{}, a []string, w sshd.StringW
return w.WriteLine(fmt.Sprintf("The provided vpn ip could not be parsed: %s", a[0])) return w.WriteLine(fmt.Sprintf("The provided vpn ip could not be parsed: %s", a[0]))
} }
hostInfo := ifce.hostMap.QueryVpnIp(vpnIp) hostInfo := ifce.hostMap.QueryVpnAddr(vpnIp)
if hostInfo == nil { if hostInfo == nil {
return w.WriteLine(fmt.Sprintf("Could not find tunnel for vpn ip: %v", a[0])) return w.WriteLine(fmt.Sprintf("Could not find tunnel for vpn ip: %v", a[0]))
} }
@ -785,7 +785,8 @@ func sshPrintCert(ifce *Interface, fs interface{}, a []string, w sshd.StringWrit
return nil return nil
} }
cert := ifce.pki.GetCertState().Certificate //TODO: This should return both certs
cert := ifce.pki.getDefaultCertificate()
if len(a) > 0 { if len(a) > 0 {
vpnIp, err := netip.ParseAddr(a[0]) vpnIp, err := netip.ParseAddr(a[0])
if err != nil { if err != nil {
@ -796,7 +797,7 @@ func sshPrintCert(ifce *Interface, fs interface{}, a []string, w sshd.StringWrit
return w.WriteLine(fmt.Sprintf("The provided vpn ip could not be parsed: %s", a[0])) return w.WriteLine(fmt.Sprintf("The provided vpn ip could not be parsed: %s", a[0]))
} }
hostInfo := ifce.hostMap.QueryVpnIp(vpnIp) hostInfo := ifce.hostMap.QueryVpnAddr(vpnIp)
if hostInfo == nil { if hostInfo == nil {
return w.WriteLine(fmt.Sprintf("Could not find tunnel for vpn ip: %v", a[0])) return w.WriteLine(fmt.Sprintf("Could not find tunnel for vpn ip: %v", a[0]))
} }
@ -880,16 +881,16 @@ func sshPrintRelays(ifce *Interface, fs interface{}, a []string, w sshd.StringWr
} }
for k, v := range relays { for k, v := range relays {
ro := RelayOutput{NebulaIp: v.vpnIp} ro := RelayOutput{NebulaIp: v.vpnAddrs[0]}
co.Relays = append(co.Relays, &ro) co.Relays = append(co.Relays, &ro)
relayHI := ifce.hostMap.QueryVpnIp(v.vpnIp) relayHI := ifce.hostMap.QueryVpnAddr(v.vpnAddrs[0])
if relayHI == nil { if relayHI == nil {
ro.RelayForIps = append(ro.RelayForIps, RelayFor{Error: errors.New("could not find hostinfo")}) ro.RelayForIps = append(ro.RelayForIps, RelayFor{Error: errors.New("could not find hostinfo")})
continue continue
} }
for _, vpnIp := range relayHI.relayState.CopyRelayForIps() { for _, vpnIp := range relayHI.relayState.CopyRelayForIps() {
rf := RelayFor{Error: nil} rf := RelayFor{Error: nil}
r, ok := relayHI.relayState.GetRelayForByIp(vpnIp) r, ok := relayHI.relayState.GetRelayForByAddr(vpnIp)
if ok { if ok {
t := "" t := ""
switch r.Type { switch r.Type {
@ -913,14 +914,14 @@ func sshPrintRelays(ifce *Interface, fs interface{}, a []string, w sshd.StringWr
rf.LocalIndex = r.LocalIndex rf.LocalIndex = r.LocalIndex
rf.RemoteIndex = r.RemoteIndex rf.RemoteIndex = r.RemoteIndex
rf.PeerIp = r.PeerIp rf.PeerIp = r.PeerAddr
rf.Type = t rf.Type = t
rf.State = s rf.State = s
if rf.LocalIndex != k { if rf.LocalIndex != k {
rf.Error = fmt.Errorf("hostmap LocalIndex '%v' does not match RelayState LocalIndex", k) rf.Error = fmt.Errorf("hostmap LocalIndex '%v' does not match RelayState LocalIndex", k)
} }
} }
relayedHI := ifce.hostMap.QueryVpnIp(vpnIp) relayedHI := ifce.hostMap.QueryVpnAddr(vpnIp)
if relayedHI != nil { if relayedHI != nil {
rf.RelayedThrough = append(rf.RelayedThrough, relayedHI.relayState.CopyRelayIps()...) rf.RelayedThrough = append(rf.RelayedThrough, relayedHI.relayState.CopyRelayIps()...)
} }
@ -955,7 +956,7 @@ func sshPrintTunnel(ifce *Interface, fs interface{}, a []string, w sshd.StringWr
return w.WriteLine(fmt.Sprintf("The provided vpn ip could not be parsed: %s", a[0])) return w.WriteLine(fmt.Sprintf("The provided vpn ip could not be parsed: %s", a[0]))
} }
hostInfo := ifce.hostMap.QueryVpnIp(vpnIp) hostInfo := ifce.hostMap.QueryVpnAddr(vpnIp)
if hostInfo == nil { if hostInfo == nil {
return w.WriteLine(fmt.Sprintf("Could not find tunnel for vpn ip: %v", a[0])) return w.WriteLine(fmt.Sprintf("Could not find tunnel for vpn ip: %v", a[0]))
} }
@ -972,12 +973,14 @@ func sshDeviceInfo(ifce *Interface, fs interface{}, w sshd.StringWriter) error {
data := struct { data := struct {
Name string `json:"name"` Name string `json:"name"`
Cidr string `json:"cidr"` Cidr []netip.Prefix `json:"cidr"`
}{ }{
Name: ifce.inside.Name(), Name: ifce.inside.Name(),
Cidr: ifce.inside.Cidr().String(), Cidr: make([]netip.Prefix, len(ifce.inside.Networks())),
} }
copy(data.Cidr, ifce.inside.Networks())
flags, ok := fs.(*sshDeviceInfoFlags) flags, ok := fs.(*sshDeviceInfoFlags)
if !ok { if !ok {
return fmt.Errorf("internal error: expected flags to be sshDeviceInfoFlags but was %+v", fs) return fmt.Errorf("internal error: expected flags to be sshDeviceInfoFlags but was %+v", fs)

View File

@ -16,8 +16,8 @@ func (NoopTun) Activate() error {
return nil return nil
} }
func (NoopTun) Cidr() netip.Prefix { func (NoopTun) Networks() []netip.Prefix {
return netip.Prefix{} return []netip.Prefix{}
} }
func (NoopTun) Name() string { func (NoopTun) Name() string {

View File

@ -116,10 +116,10 @@ func TestTimerWheel_Purge(t *testing.T) {
assert.Equal(t, 0, tw.current) assert.Equal(t, 0, tw.current)
fps := []firewall.Packet{ fps := []firewall.Packet{
{LocalIP: netip.MustParseAddr("0.0.0.1")}, {LocalAddr: netip.MustParseAddr("0.0.0.1")},
{LocalIP: netip.MustParseAddr("0.0.0.2")}, {LocalAddr: netip.MustParseAddr("0.0.0.2")},
{LocalIP: netip.MustParseAddr("0.0.0.3")}, {LocalAddr: netip.MustParseAddr("0.0.0.3")},
{LocalIP: netip.MustParseAddr("0.0.0.4")}, {LocalAddr: netip.MustParseAddr("0.0.0.4")},
} }
tw.Add(fps[0], time.Second*1) tw.Add(fps[0], time.Second*1)

View File

@ -4,28 +4,19 @@ import (
"net/netip" "net/netip"
"github.com/slackhq/nebula/config" "github.com/slackhq/nebula/config"
"github.com/slackhq/nebula/firewall"
"github.com/slackhq/nebula/header"
) )
const MTU = 9001 const MTU = 9001
type EncReader func( type EncReader func(
addr netip.AddrPort, addr netip.AddrPort,
out []byte, payload []byte,
packet []byte,
header *header.H,
fwPacket *firewall.Packet,
lhh LightHouseHandlerFunc,
nb []byte,
q int,
localCache firewall.ConntrackCache,
) )
type Conn interface { type Conn interface {
Rebind() error Rebind() error
LocalAddr() (netip.AddrPort, error) LocalAddr() (netip.AddrPort, error)
ListenOut(r EncReader, lhf LightHouseHandlerFunc, cache *firewall.ConntrackCacheTicker, q int) ListenOut(r EncReader)
WriteTo(b []byte, addr netip.AddrPort) error WriteTo(b []byte, addr netip.AddrPort) error
ReloadConfig(c *config.C) ReloadConfig(c *config.C)
Close() error Close() error
@ -39,7 +30,7 @@ func (NoopConn) Rebind() error {
func (NoopConn) LocalAddr() (netip.AddrPort, error) { func (NoopConn) LocalAddr() (netip.AddrPort, error) {
return netip.AddrPort{}, nil return netip.AddrPort{}, nil
} }
func (NoopConn) ListenOut(_ EncReader, _ LightHouseHandlerFunc, _ *firewall.ConntrackCacheTicker, _ int) { func (NoopConn) ListenOut(_ EncReader) {
return return
} }
func (NoopConn) WriteTo(_ []byte, _ netip.AddrPort) error { func (NoopConn) WriteTo(_ []byte, _ netip.AddrPort) error {

View File

@ -1,10 +0,0 @@
package udp
import (
"net/netip"
)
//TODO: The items in this file belong in their own packages but doing that in a single PR is a nightmare
// TODO: IPV6-WORK this can likely be removed now
type LightHouseHandlerFunc func(rAddr netip.AddrPort, vpnIp netip.Addr, p []byte)

View File

@ -15,8 +15,6 @@ import (
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
"github.com/slackhq/nebula/config" "github.com/slackhq/nebula/config"
"github.com/slackhq/nebula/firewall"
"github.com/slackhq/nebula/header"
) )
type GenericConn struct { type GenericConn struct {
@ -72,12 +70,8 @@ type rawMessage struct {
Len uint32 Len uint32
} }
func (u *GenericConn) ListenOut(r EncReader, lhf LightHouseHandlerFunc, cache *firewall.ConntrackCacheTicker, q int) { func (u *GenericConn) ListenOut(r EncReader) {
plaintext := make([]byte, MTU)
buffer := make([]byte, MTU) buffer := make([]byte, MTU)
h := &header.H{}
fwPacket := &firewall.Packet{}
nb := make([]byte, 12, 12)
for { for {
// Just read one packet at a time // Just read one packet at a time
@ -87,16 +81,6 @@ func (u *GenericConn) ListenOut(r EncReader, lhf LightHouseHandlerFunc, cache *f
return return
} }
r( r(netip.AddrPortFrom(rua.Addr().Unmap(), rua.Port()), buffer[:n])
netip.AddrPortFrom(rua.Addr().Unmap(), rua.Port()),
plaintext[:0],
buffer[:n],
h,
fwPacket,
lhf,
nb,
q,
cache.Get(u.l),
)
} }
} }

View File

@ -14,8 +14,6 @@ import (
"github.com/rcrowley/go-metrics" "github.com/rcrowley/go-metrics"
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
"github.com/slackhq/nebula/config" "github.com/slackhq/nebula/config"
"github.com/slackhq/nebula/firewall"
"github.com/slackhq/nebula/header"
"golang.org/x/sys/unix" "golang.org/x/sys/unix"
) )
@ -120,15 +118,9 @@ func (u *StdConn) LocalAddr() (netip.AddrPort, error) {
} }
} }
func (u *StdConn) ListenOut(r EncReader, lhf LightHouseHandlerFunc, cache *firewall.ConntrackCacheTicker, q int) { func (u *StdConn) ListenOut(r EncReader) {
plaintext := make([]byte, MTU)
h := &header.H{}
fwPacket := &firewall.Packet{}
var ip netip.Addr var ip netip.Addr
nb := make([]byte, 12, 12)
//TODO: should we track this?
//metric := metrics.GetOrRegisterHistogram("test.batch_read", nil, metrics.NewExpDecaySample(1028, 0.015))
msgs, buffers, names := u.PrepareRawMessages(u.batch) msgs, buffers, names := u.PrepareRawMessages(u.batch)
read := u.ReadMulti read := u.ReadMulti
if u.batch == 1 { if u.batch == 1 {
@ -142,26 +134,14 @@ func (u *StdConn) ListenOut(r EncReader, lhf LightHouseHandlerFunc, cache *firew
return return
} }
//metric.Update(int64(n))
for i := 0; i < n; i++ { for i := 0; i < n; i++ {
// Its ok to skip the ok check here, the slicing is the only error that can occur and it will panic
if u.isV4 { if u.isV4 {
ip, _ = netip.AddrFromSlice(names[i][4:8]) ip, _ = netip.AddrFromSlice(names[i][4:8])
//TODO: IPV6-WORK what is not ok?
} else { } else {
ip, _ = netip.AddrFromSlice(names[i][8:24]) ip, _ = netip.AddrFromSlice(names[i][8:24])
//TODO: IPV6-WORK what is not ok?
} }
r( r(netip.AddrPortFrom(ip.Unmap(), binary.BigEndian.Uint16(names[i][2:4])), buffers[i][:msgs[i].Len])
netip.AddrPortFrom(ip.Unmap(), binary.BigEndian.Uint16(names[i][2:4])),
plaintext[:0],
buffers[i][:msgs[i].Len],
h,
fwPacket,
lhf,
nb,
q,
cache.Get(u.l),
)
} }
} }
} }

View File

@ -18,9 +18,6 @@ import (
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
"github.com/slackhq/nebula/config" "github.com/slackhq/nebula/config"
"github.com/slackhq/nebula/firewall"
"github.com/slackhq/nebula/header"
"golang.org/x/sys/windows" "golang.org/x/sys/windows"
"golang.zx2c4.com/wireguard/conn/winrio" "golang.zx2c4.com/wireguard/conn/winrio"
) )
@ -118,12 +115,8 @@ func (u *RIOConn) bind(sa windows.Sockaddr) error {
return nil return nil
} }
func (u *RIOConn) ListenOut(r EncReader, lhf LightHouseHandlerFunc, cache *firewall.ConntrackCacheTicker, q int) { func (u *RIOConn) ListenOut(r EncReader) {
plaintext := make([]byte, MTU)
buffer := make([]byte, MTU) buffer := make([]byte, MTU)
h := &header.H{}
fwPacket := &firewall.Packet{}
nb := make([]byte, 12, 12)
for { for {
// Just read one packet at a time // Just read one packet at a time
@ -133,17 +126,7 @@ func (u *RIOConn) ListenOut(r EncReader, lhf LightHouseHandlerFunc, cache *firew
return return
} }
r( r(netip.AddrPortFrom(netip.AddrFrom16(rua.Addr).Unmap(), (rua.Port>>8)|((rua.Port&0xff)<<8)), buffer[:n])
netip.AddrPortFrom(netip.AddrFrom16(rua.Addr).Unmap(), (rua.Port>>8)|((rua.Port&0xff)<<8)),
plaintext[:0],
buffer[:n],
h,
fwPacket,
lhf,
nb,
q,
cache.Get(u.l),
)
} }
} }

View File

@ -10,7 +10,6 @@ import (
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
"github.com/slackhq/nebula/config" "github.com/slackhq/nebula/config"
"github.com/slackhq/nebula/firewall"
"github.com/slackhq/nebula/header" "github.com/slackhq/nebula/header"
) )
@ -107,18 +106,13 @@ func (u *TesterConn) WriteTo(b []byte, addr netip.AddrPort) error {
return nil return nil
} }
func (u *TesterConn) ListenOut(r EncReader, lhf LightHouseHandlerFunc, cache *firewall.ConntrackCacheTicker, q int) { func (u *TesterConn) ListenOut(r EncReader) {
plaintext := make([]byte, MTU)
h := &header.H{}
fwPacket := &firewall.Packet{}
nb := make([]byte, 12, 12)
for { for {
p, ok := <-u.RxPackets p, ok := <-u.RxPackets
if !ok { if !ok {
return return
} }
r(p.From, plaintext[:0], p.Data, h, fwPacket, lhf, nb, q, cache.Get(u.l)) r(p.From, p.Data)
} }
} }