mirror of
https://github.com/slackhq/nebula.git
synced 2025-11-12 06:33:58 +01:00
Support for ipv6 in the overlay with v2 certificates
--------- Co-authored-by: Jack Doan <jackdoan@rivian.com>
This commit is contained in:
parent
3e6c75573f
commit
f2c32421c4
2
Makefile
2
Makefile
@ -196,7 +196,7 @@ bench-cpu-long:
|
|||||||
go test -bench=. -benchtime=60s -cpuprofile=cpu.pprof
|
go test -bench=. -benchtime=60s -cpuprofile=cpu.pprof
|
||||||
go tool pprof go-audit.test cpu.pprof
|
go tool pprof go-audit.test cpu.pprof
|
||||||
|
|
||||||
proto: nebula.pb.go cert/cert.pb.go
|
proto: nebula.pb.go cert/cert_v1.pb.go
|
||||||
|
|
||||||
nebula.pb.go: nebula.proto .FORCE
|
nebula.pb.go: nebula.proto .FORCE
|
||||||
go build github.com/gogo/protobuf/protoc-gen-gogofaster
|
go build github.com/gogo/protobuf/protoc-gen-gogofaster
|
||||||
|
|||||||
@ -21,7 +21,11 @@ type calculatedRemote struct {
|
|||||||
port uint32
|
port uint32
|
||||||
}
|
}
|
||||||
|
|
||||||
func newCalculatedRemote(maskCidr netip.Prefix, port int) (*calculatedRemote, error) {
|
func newCalculatedRemote(cidr, maskCidr netip.Prefix, port int) (*calculatedRemote, error) {
|
||||||
|
if maskCidr.Addr().BitLen() != cidr.Addr().BitLen() {
|
||||||
|
return nil, fmt.Errorf("invalid mask: %s for cidr: %s", maskCidr, cidr)
|
||||||
|
}
|
||||||
|
|
||||||
masked := maskCidr.Masked()
|
masked := maskCidr.Masked()
|
||||||
if port < 0 || port > math.MaxUint16 {
|
if port < 0 || port > math.MaxUint16 {
|
||||||
return nil, fmt.Errorf("invalid port: %d", port)
|
return nil, fmt.Errorf("invalid port: %d", port)
|
||||||
@ -38,32 +42,38 @@ func (c *calculatedRemote) String() string {
|
|||||||
return fmt.Sprintf("CalculatedRemote(mask=%v port=%d)", c.ipNet, c.port)
|
return fmt.Sprintf("CalculatedRemote(mask=%v port=%d)", c.ipNet, c.port)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *calculatedRemote) Apply(ip netip.Addr) *Ip4AndPort {
|
func (c *calculatedRemote) ApplyV4(addr netip.Addr) *V4AddrPort {
|
||||||
// Combine the masked bytes of the "mask" IP with the unmasked bytes
|
// Combine the masked bytes of the "mask" IP with the unmasked bytes of the overlay IP
|
||||||
// of the overlay IP
|
|
||||||
if c.ipNet.Addr().Is4() {
|
|
||||||
return c.apply4(ip)
|
|
||||||
}
|
|
||||||
return c.apply6(ip)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *calculatedRemote) apply4(ip netip.Addr) *Ip4AndPort {
|
|
||||||
//TODO: IPV6-WORK this can be less crappy
|
|
||||||
maskb := net.CIDRMask(c.mask.Bits(), c.mask.Addr().BitLen())
|
maskb := net.CIDRMask(c.mask.Bits(), c.mask.Addr().BitLen())
|
||||||
mask := binary.BigEndian.Uint32(maskb[:])
|
mask := binary.BigEndian.Uint32(maskb[:])
|
||||||
|
|
||||||
b := c.mask.Addr().As4()
|
b := c.mask.Addr().As4()
|
||||||
maskIp := binary.BigEndian.Uint32(b[:])
|
maskAddr := binary.BigEndian.Uint32(b[:])
|
||||||
|
|
||||||
b = ip.As4()
|
b = addr.As4()
|
||||||
intIp := binary.BigEndian.Uint32(b[:])
|
intAddr := binary.BigEndian.Uint32(b[:])
|
||||||
|
|
||||||
return &Ip4AndPort{(maskIp & mask) | (intIp & ^mask), c.port}
|
return &V4AddrPort{(maskAddr & mask) | (intAddr & ^mask), c.port}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *calculatedRemote) apply6(ip netip.Addr) *Ip4AndPort {
|
func (c *calculatedRemote) ApplyV6(addr netip.Addr) *V6AddrPort {
|
||||||
//TODO: IPV6-WORK
|
mask := net.CIDRMask(c.mask.Bits(), c.mask.Addr().BitLen())
|
||||||
panic("Can not calculate ipv6 remote addresses")
|
maskAddr := c.mask.Addr().As16()
|
||||||
|
calcAddr := addr.As16()
|
||||||
|
|
||||||
|
ap := V6AddrPort{Port: c.port}
|
||||||
|
|
||||||
|
maskb := binary.BigEndian.Uint64(mask[:8])
|
||||||
|
maskAddrb := binary.BigEndian.Uint64(maskAddr[:8])
|
||||||
|
calcAddrb := binary.BigEndian.Uint64(calcAddr[:8])
|
||||||
|
ap.Hi = (maskAddrb & maskb) | (calcAddrb & ^maskb)
|
||||||
|
|
||||||
|
maskb = binary.BigEndian.Uint64(mask[8:])
|
||||||
|
maskAddrb = binary.BigEndian.Uint64(maskAddr[8:])
|
||||||
|
calcAddrb = binary.BigEndian.Uint64(calcAddr[8:])
|
||||||
|
ap.Lo = (maskAddrb & maskb) | (calcAddrb & ^maskb)
|
||||||
|
|
||||||
|
return &ap
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewCalculatedRemotesFromConfig(c *config.C, k string) (*bart.Table[[]*calculatedRemote], error) {
|
func NewCalculatedRemotesFromConfig(c *config.C, k string) (*bart.Table[[]*calculatedRemote], error) {
|
||||||
@ -89,8 +99,7 @@ func NewCalculatedRemotesFromConfig(c *config.C, k string) (*bart.Table[[]*calcu
|
|||||||
return nil, fmt.Errorf("config `%s` has invalid CIDR: %s", k, rawCIDR)
|
return nil, fmt.Errorf("config `%s` has invalid CIDR: %s", k, rawCIDR)
|
||||||
}
|
}
|
||||||
|
|
||||||
//TODO: IPV6-WORK this does not verify that rawValue contains the same bits as cidr here
|
entry, err := newCalculatedRemotesListFromConfig(cidr, rawValue)
|
||||||
entry, err := newCalculatedRemotesListFromConfig(rawValue)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("config '%s.%s': %w", k, rawCIDR, err)
|
return nil, fmt.Errorf("config '%s.%s': %w", k, rawCIDR, err)
|
||||||
}
|
}
|
||||||
@ -101,7 +110,7 @@ func NewCalculatedRemotesFromConfig(c *config.C, k string) (*bart.Table[[]*calcu
|
|||||||
return calculatedRemotes, nil
|
return calculatedRemotes, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func newCalculatedRemotesListFromConfig(raw any) ([]*calculatedRemote, error) {
|
func newCalculatedRemotesListFromConfig(cidr netip.Prefix, raw any) ([]*calculatedRemote, error) {
|
||||||
rawList, ok := raw.([]any)
|
rawList, ok := raw.([]any)
|
||||||
if !ok {
|
if !ok {
|
||||||
return nil, fmt.Errorf("calculated_remotes entry has invalid type: %T", raw)
|
return nil, fmt.Errorf("calculated_remotes entry has invalid type: %T", raw)
|
||||||
@ -109,7 +118,7 @@ func newCalculatedRemotesListFromConfig(raw any) ([]*calculatedRemote, error) {
|
|||||||
|
|
||||||
var l []*calculatedRemote
|
var l []*calculatedRemote
|
||||||
for _, e := range rawList {
|
for _, e := range rawList {
|
||||||
c, err := newCalculatedRemotesEntryFromConfig(e)
|
c, err := newCalculatedRemotesEntryFromConfig(cidr, e)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("calculated_remotes entry: %w", err)
|
return nil, fmt.Errorf("calculated_remotes entry: %w", err)
|
||||||
}
|
}
|
||||||
@ -119,7 +128,7 @@ func newCalculatedRemotesListFromConfig(raw any) ([]*calculatedRemote, error) {
|
|||||||
return l, nil
|
return l, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func newCalculatedRemotesEntryFromConfig(raw any) (*calculatedRemote, error) {
|
func newCalculatedRemotesEntryFromConfig(cidr netip.Prefix, raw any) (*calculatedRemote, error) {
|
||||||
rawMap, ok := raw.(map[any]any)
|
rawMap, ok := raw.(map[any]any)
|
||||||
if !ok {
|
if !ok {
|
||||||
return nil, fmt.Errorf("invalid type: %T", raw)
|
return nil, fmt.Errorf("invalid type: %T", raw)
|
||||||
@ -155,5 +164,5 @@ func newCalculatedRemotesEntryFromConfig(raw any) (*calculatedRemote, error) {
|
|||||||
return nil, fmt.Errorf("invalid port (type %T): %v", rawValue, rawValue)
|
return nil, fmt.Errorf("invalid port (type %T): %v", rawValue, rawValue)
|
||||||
}
|
}
|
||||||
|
|
||||||
return newCalculatedRemote(maskCidr, port)
|
return newCalculatedRemote(cidr, maskCidr, port)
|
||||||
}
|
}
|
||||||
|
|||||||
@ -9,10 +9,9 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
func TestCalculatedRemoteApply(t *testing.T) {
|
func TestCalculatedRemoteApply(t *testing.T) {
|
||||||
ipNet, err := netip.ParsePrefix("192.168.1.0/24")
|
// Test v4 addresses
|
||||||
require.NoError(t, err)
|
ipNet := netip.MustParsePrefix("192.168.1.0/24")
|
||||||
|
c, err := newCalculatedRemote(ipNet, ipNet, 4242)
|
||||||
c, err := newCalculatedRemote(ipNet, 4242)
|
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
input, err := netip.ParseAddr("10.0.10.182")
|
input, err := netip.ParseAddr("10.0.10.182")
|
||||||
@ -21,5 +20,62 @@ func TestCalculatedRemoteApply(t *testing.T) {
|
|||||||
expected, err := netip.ParseAddr("192.168.1.182")
|
expected, err := netip.ParseAddr("192.168.1.182")
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
|
|
||||||
assert.Equal(t, NewIp4AndPortFromNetIP(expected, 4242), c.Apply(input))
|
assert.Equal(t, netAddrToProtoV4AddrPort(expected, 4242), c.ApplyV4(input))
|
||||||
|
|
||||||
|
// Test v6 addresses
|
||||||
|
ipNet = netip.MustParsePrefix("ffff:ffff:ffff:ffff::0/64")
|
||||||
|
c, err = newCalculatedRemote(ipNet, ipNet, 4242)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
input, err = netip.ParseAddr("beef:beef:beef:beef:beef:beef:beef:beef")
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
expected, err = netip.ParseAddr("ffff:ffff:ffff:ffff:beef:beef:beef:beef")
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
assert.Equal(t, netAddrToProtoV6AddrPort(expected, 4242), c.ApplyV6(input))
|
||||||
|
|
||||||
|
// Test v6 addresses part 2
|
||||||
|
ipNet = netip.MustParsePrefix("ffff:ffff:ffff:ffff:ffff::0/80")
|
||||||
|
c, err = newCalculatedRemote(ipNet, ipNet, 4242)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
input, err = netip.ParseAddr("beef:beef:beef:beef:beef:beef:beef:beef")
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
expected, err = netip.ParseAddr("ffff:ffff:ffff:ffff:ffff:beef:beef:beef")
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
assert.Equal(t, netAddrToProtoV6AddrPort(expected, 4242), c.ApplyV6(input))
|
||||||
|
|
||||||
|
// Test v6 addresses part 2
|
||||||
|
ipNet = netip.MustParsePrefix("ffff:ffff:ffff::0/48")
|
||||||
|
c, err = newCalculatedRemote(ipNet, ipNet, 4242)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
input, err = netip.ParseAddr("beef:beef:beef:beef:beef:beef:beef:beef")
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
expected, err = netip.ParseAddr("ffff:ffff:ffff:beef:beef:beef:beef:beef")
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
assert.Equal(t, netAddrToProtoV6AddrPort(expected, 4242), c.ApplyV6(input))
|
||||||
|
}
|
||||||
|
|
||||||
|
func Test_newCalculatedRemote(t *testing.T) {
|
||||||
|
c, err := newCalculatedRemote(netip.MustParsePrefix("1::1/128"), netip.MustParsePrefix("1.0.0.0/32"), 4242)
|
||||||
|
require.EqualError(t, err, "invalid mask: 1.0.0.0/32 for cidr: 1::1/128")
|
||||||
|
require.Nil(t, c)
|
||||||
|
|
||||||
|
c, err = newCalculatedRemote(netip.MustParsePrefix("1.0.0.0/32"), netip.MustParsePrefix("1::1/128"), 4242)
|
||||||
|
require.EqualError(t, err, "invalid mask: 1::1/128 for cidr: 1.0.0.0/32")
|
||||||
|
require.Nil(t, c)
|
||||||
|
|
||||||
|
c, err = newCalculatedRemote(netip.MustParsePrefix("1.0.0.0/32"), netip.MustParsePrefix("1.0.0.0/32"), 4242)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NotNil(t, c)
|
||||||
|
|
||||||
|
c, err = newCalculatedRemote(netip.MustParsePrefix("1::1/128"), netip.MustParsePrefix("1::1/128"), 4242)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NotNil(t, c)
|
||||||
}
|
}
|
||||||
|
|||||||
@ -2,14 +2,25 @@
|
|||||||
|
|
||||||
This is a library for interacting with `nebula` style certificates and authorities.
|
This is a library for interacting with `nebula` style certificates and authorities.
|
||||||
|
|
||||||
A `protobuf` definition of the certificate format is also included
|
There are now 2 versions of `nebula` certificates:
|
||||||
|
|
||||||
### Compiling the protobuf definition
|
## v1
|
||||||
|
|
||||||
Make sure you have `protoc` installed.
|
This version is deprecated.
|
||||||
|
|
||||||
|
A `protobuf` definition of the certificate format is included at `cert_v1.proto`
|
||||||
|
|
||||||
|
To compile the definition you will need `protoc` installed.
|
||||||
|
|
||||||
To compile for `go` with the same version of protobuf specified in go.mod:
|
To compile for `go` with the same version of protobuf specified in go.mod:
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
make
|
make proto
|
||||||
```
|
```
|
||||||
|
|
||||||
|
## v2
|
||||||
|
|
||||||
|
This is the latest version which uses asn.1 DER encoding. It can support ipv4 and ipv6 and tolerate
|
||||||
|
future certificate changes better than v1.
|
||||||
|
|
||||||
|
`cert_v2.asn1` defines the wire format and can be used to compile marshalers.
|
||||||
52
cert/asn1.go
Normal file
52
cert/asn1.go
Normal file
@ -0,0 +1,52 @@
|
|||||||
|
package cert
|
||||||
|
|
||||||
|
import (
|
||||||
|
"golang.org/x/crypto/cryptobyte"
|
||||||
|
"golang.org/x/crypto/cryptobyte/asn1"
|
||||||
|
)
|
||||||
|
|
||||||
|
// readOptionalASN1Boolean reads an asn.1 boolean with a specific tag instead of a asn.1 tag wrapping a boolean with a value
|
||||||
|
// https://github.com/golang/go/issues/64811#issuecomment-1944446920
|
||||||
|
func readOptionalASN1Boolean(b *cryptobyte.String, out *bool, tag asn1.Tag, defaultValue bool) bool {
|
||||||
|
var present bool
|
||||||
|
var child cryptobyte.String
|
||||||
|
if !b.ReadOptionalASN1(&child, &present, tag) {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
if !present {
|
||||||
|
*out = defaultValue
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
// Ensure we have 1 byte
|
||||||
|
if len(child) == 1 {
|
||||||
|
*out = child[0] > 0
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
// readOptionalASN1Byte reads an asn.1 uint8 with a specific tag instead of a asn.1 tag wrapping a uint8 with a value
|
||||||
|
// Similar issue as with readOptionalASN1Boolean
|
||||||
|
func readOptionalASN1Byte(b *cryptobyte.String, out *byte, tag asn1.Tag, defaultValue byte) bool {
|
||||||
|
var present bool
|
||||||
|
var child cryptobyte.String
|
||||||
|
if !b.ReadOptionalASN1(&child, &present, tag) {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
if !present {
|
||||||
|
*out = defaultValue
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
// Ensure we have 1 byte
|
||||||
|
if len(child) == 1 {
|
||||||
|
*out = child[0]
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
return false
|
||||||
|
}
|
||||||
@ -63,31 +63,31 @@ IBNWYMep3ysx9zCgknfG5dKtwGTaqF++BWKDYdyl34KX
|
|||||||
|
|
||||||
rootCA := certificateV1{
|
rootCA := certificateV1{
|
||||||
details: detailsV1{
|
details: detailsV1{
|
||||||
Name: "nebula root ca",
|
name: "nebula root ca",
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
rootCA01 := certificateV1{
|
rootCA01 := certificateV1{
|
||||||
details: detailsV1{
|
details: detailsV1{
|
||||||
Name: "nebula root ca 01",
|
name: "nebula root ca 01",
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
rootCAP256 := certificateV1{
|
rootCAP256 := certificateV1{
|
||||||
details: detailsV1{
|
details: detailsV1{
|
||||||
Name: "nebula P256 test",
|
name: "nebula P256 test",
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
p, err := NewCAPoolFromPEM([]byte(noNewLines))
|
p, err := NewCAPoolFromPEM([]byte(noNewLines))
|
||||||
assert.Nil(t, err)
|
assert.Nil(t, err)
|
||||||
assert.Equal(t, p.CAs[string("c9bfaf7ce8e84b2eeda2e27b469f4b9617bde192efd214b68891ecda6ed49522")].Certificate.Name(), rootCA.details.Name)
|
assert.Equal(t, p.CAs[string("c9bfaf7ce8e84b2eeda2e27b469f4b9617bde192efd214b68891ecda6ed49522")].Certificate.Name(), rootCA.details.name)
|
||||||
assert.Equal(t, p.CAs[string("5c9c3f23e7ee7fe97637cbd3a0a5b854154d1d9aaaf7b566a51f4a88f76b64cd")].Certificate.Name(), rootCA01.details.Name)
|
assert.Equal(t, p.CAs[string("5c9c3f23e7ee7fe97637cbd3a0a5b854154d1d9aaaf7b566a51f4a88f76b64cd")].Certificate.Name(), rootCA01.details.name)
|
||||||
|
|
||||||
pp, err := NewCAPoolFromPEM([]byte(withNewLines))
|
pp, err := NewCAPoolFromPEM([]byte(withNewLines))
|
||||||
assert.Nil(t, err)
|
assert.Nil(t, err)
|
||||||
assert.Equal(t, pp.CAs[string("c9bfaf7ce8e84b2eeda2e27b469f4b9617bde192efd214b68891ecda6ed49522")].Certificate.Name(), rootCA.details.Name)
|
assert.Equal(t, pp.CAs[string("c9bfaf7ce8e84b2eeda2e27b469f4b9617bde192efd214b68891ecda6ed49522")].Certificate.Name(), rootCA.details.name)
|
||||||
assert.Equal(t, pp.CAs[string("5c9c3f23e7ee7fe97637cbd3a0a5b854154d1d9aaaf7b566a51f4a88f76b64cd")].Certificate.Name(), rootCA01.details.Name)
|
assert.Equal(t, pp.CAs[string("5c9c3f23e7ee7fe97637cbd3a0a5b854154d1d9aaaf7b566a51f4a88f76b64cd")].Certificate.Name(), rootCA01.details.name)
|
||||||
|
|
||||||
// expired cert, no valid certs
|
// expired cert, no valid certs
|
||||||
ppp, err := NewCAPoolFromPEM([]byte(expired))
|
ppp, err := NewCAPoolFromPEM([]byte(expired))
|
||||||
@ -97,13 +97,13 @@ IBNWYMep3ysx9zCgknfG5dKtwGTaqF++BWKDYdyl34KX
|
|||||||
// expired cert, with valid certs
|
// expired cert, with valid certs
|
||||||
pppp, err := NewCAPoolFromPEM(append([]byte(expired), noNewLines...))
|
pppp, err := NewCAPoolFromPEM(append([]byte(expired), noNewLines...))
|
||||||
assert.Equal(t, ErrExpired, err)
|
assert.Equal(t, ErrExpired, err)
|
||||||
assert.Equal(t, pppp.CAs[string("c9bfaf7ce8e84b2eeda2e27b469f4b9617bde192efd214b68891ecda6ed49522")].Certificate.Name(), rootCA.details.Name)
|
assert.Equal(t, pppp.CAs[string("c9bfaf7ce8e84b2eeda2e27b469f4b9617bde192efd214b68891ecda6ed49522")].Certificate.Name(), rootCA.details.name)
|
||||||
assert.Equal(t, pppp.CAs[string("5c9c3f23e7ee7fe97637cbd3a0a5b854154d1d9aaaf7b566a51f4a88f76b64cd")].Certificate.Name(), rootCA01.details.Name)
|
assert.Equal(t, pppp.CAs[string("5c9c3f23e7ee7fe97637cbd3a0a5b854154d1d9aaaf7b566a51f4a88f76b64cd")].Certificate.Name(), rootCA01.details.name)
|
||||||
assert.Equal(t, pppp.CAs[string("152070be6bb19bc9e3bde4c2f0e7d8f4ff5448b4c9856b8eccb314fade0229b0")].Certificate.Name(), "expired")
|
assert.Equal(t, pppp.CAs[string("152070be6bb19bc9e3bde4c2f0e7d8f4ff5448b4c9856b8eccb314fade0229b0")].Certificate.Name(), "expired")
|
||||||
assert.Equal(t, len(pppp.CAs), 3)
|
assert.Equal(t, len(pppp.CAs), 3)
|
||||||
|
|
||||||
ppppp, err := NewCAPoolFromPEM([]byte(p256))
|
ppppp, err := NewCAPoolFromPEM([]byte(p256))
|
||||||
assert.Nil(t, err)
|
assert.Nil(t, err)
|
||||||
assert.Equal(t, ppppp.CAs[string("a7938893ec8c4ef769b06d7f425e5e46f7a7f5ffa49c3bcf4a86b608caba9159")].Certificate.Name(), rootCAP256.details.Name)
|
assert.Equal(t, ppppp.CAs[string("a7938893ec8c4ef769b06d7f425e5e46f7a7f5ffa49c3bcf4a86b608caba9159")].Certificate.Name(), rootCAP256.details.name)
|
||||||
assert.Equal(t, len(ppppp.CAs), 1)
|
assert.Equal(t, len(ppppp.CAs), 1)
|
||||||
}
|
}
|
||||||
|
|||||||
59
cert/cert.go
59
cert/cert.go
@ -1,13 +1,15 @@
|
|||||||
package cert
|
package cert
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"fmt"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
"time"
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
type Version int
|
type Version uint8
|
||||||
|
|
||||||
const (
|
const (
|
||||||
|
VersionPre1 Version = 0
|
||||||
Version1 Version = 1
|
Version1 Version = 1
|
||||||
Version2 Version = 2
|
Version2 Version = 2
|
||||||
)
|
)
|
||||||
@ -107,23 +109,56 @@ type CachedCertificate struct {
|
|||||||
signerFingerprint string
|
signerFingerprint string
|
||||||
}
|
}
|
||||||
|
|
||||||
// UnmarshalCertificate will attempt to unmarshal a wire protocol level certificate.
|
func (cc *CachedCertificate) String() string {
|
||||||
func UnmarshalCertificate(b []byte) (Certificate, error) {
|
return cc.Certificate.String()
|
||||||
c, err := unmarshalCertificateV1(b, true)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
return c, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// UnmarshalCertificateFromHandshake will attempt to unmarshal a certificate received in a handshake.
|
// RecombineAndValidate will attempt to unmarshal a certificate received in a handshake.
|
||||||
// Handshakes save space by placing the peers public key in a different part of the packet, we have to
|
// Handshakes save space by placing the peers public key in a different part of the packet, we have to
|
||||||
// reassemble the actual certificate structure with that in mind.
|
// reassemble the actual certificate structure with that in mind.
|
||||||
func UnmarshalCertificateFromHandshake(b []byte, publicKey []byte) (Certificate, error) {
|
func RecombineAndValidate(v Version, rawCertBytes, publicKey []byte, curve Curve, caPool *CAPool) (*CachedCertificate, error) {
|
||||||
c, err := unmarshalCertificateV1(b, false)
|
if publicKey == nil {
|
||||||
|
return nil, ErrNoPeerStaticKey
|
||||||
|
}
|
||||||
|
|
||||||
|
if rawCertBytes == nil {
|
||||||
|
return nil, ErrNoPayload
|
||||||
|
}
|
||||||
|
|
||||||
|
c, err := unmarshalCertificateFromHandshake(v, rawCertBytes, publicKey, curve)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("error unmarshaling cert: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
cc, err := caPool.VerifyCertificate(time.Now(), c)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("certificate validation failed: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return cc, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func unmarshalCertificateFromHandshake(v Version, b []byte, publicKey []byte, curve Curve) (Certificate, error) {
|
||||||
|
var c Certificate
|
||||||
|
var err error
|
||||||
|
|
||||||
|
switch v {
|
||||||
|
case VersionPre1, Version1:
|
||||||
|
c, err = unmarshalCertificateV1(b, publicKey)
|
||||||
|
case Version2:
|
||||||
|
c, err = unmarshalCertificateV2(b, publicKey, curve)
|
||||||
|
default:
|
||||||
|
//TODO: make a static var
|
||||||
|
return nil, fmt.Errorf("unknown certificate version %d", v)
|
||||||
|
}
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
c.details.PublicKey = publicKey
|
|
||||||
|
if c.Curve() != curve {
|
||||||
|
return nil, fmt.Errorf("certificate curve %s does not match expected %s", c.Curve().String(), curve.String())
|
||||||
|
}
|
||||||
|
|
||||||
return c, nil
|
return c, nil
|
||||||
}
|
}
|
||||||
|
|||||||
@ -24,21 +24,21 @@ func TestMarshalingNebulaCertificate(t *testing.T) {
|
|||||||
|
|
||||||
nc := certificateV1{
|
nc := certificateV1{
|
||||||
details: detailsV1{
|
details: detailsV1{
|
||||||
Name: "testing",
|
name: "testing",
|
||||||
Ips: []netip.Prefix{
|
networks: []netip.Prefix{
|
||||||
mustParsePrefixUnmapped("10.1.1.1/24"),
|
mustParsePrefixUnmapped("10.1.1.1/24"),
|
||||||
mustParsePrefixUnmapped("10.1.1.2/16"),
|
mustParsePrefixUnmapped("10.1.1.2/16"),
|
||||||
},
|
},
|
||||||
Subnets: []netip.Prefix{
|
unsafeNetworks: []netip.Prefix{
|
||||||
mustParsePrefixUnmapped("9.1.1.2/24"),
|
mustParsePrefixUnmapped("9.1.1.2/24"),
|
||||||
mustParsePrefixUnmapped("9.1.1.3/16"),
|
mustParsePrefixUnmapped("9.1.1.3/16"),
|
||||||
},
|
},
|
||||||
Groups: []string{"test-group1", "test-group2", "test-group3"},
|
groups: []string{"test-group1", "test-group2", "test-group3"},
|
||||||
NotBefore: before,
|
notBefore: before,
|
||||||
NotAfter: after,
|
notAfter: after,
|
||||||
PublicKey: pubKey,
|
publicKey: pubKey,
|
||||||
IsCA: false,
|
isCA: false,
|
||||||
Issuer: "1234567890abcedfghij1234567890ab",
|
issuer: "1234567890abcedfghij1234567890ab",
|
||||||
},
|
},
|
||||||
signature: []byte("1234567890abcedfghij1234567890ab"),
|
signature: []byte("1234567890abcedfghij1234567890ab"),
|
||||||
}
|
}
|
||||||
@ -47,20 +47,20 @@ func TestMarshalingNebulaCertificate(t *testing.T) {
|
|||||||
assert.Nil(t, err)
|
assert.Nil(t, err)
|
||||||
//t.Log("Cert size:", len(b))
|
//t.Log("Cert size:", len(b))
|
||||||
|
|
||||||
nc2, err := unmarshalCertificateV1(b, true)
|
nc2, err := unmarshalCertificateV1(b, nil)
|
||||||
assert.Nil(t, err)
|
assert.Nil(t, err)
|
||||||
|
|
||||||
assert.Equal(t, nc.signature, nc2.Signature())
|
assert.Equal(t, nc.signature, nc2.Signature())
|
||||||
assert.Equal(t, nc.details.Name, nc2.Name())
|
assert.Equal(t, nc.details.name, nc2.Name())
|
||||||
assert.Equal(t, nc.details.NotBefore, nc2.NotBefore())
|
assert.Equal(t, nc.details.notBefore, nc2.NotBefore())
|
||||||
assert.Equal(t, nc.details.NotAfter, nc2.NotAfter())
|
assert.Equal(t, nc.details.notAfter, nc2.NotAfter())
|
||||||
assert.Equal(t, nc.details.PublicKey, nc2.PublicKey())
|
assert.Equal(t, nc.details.publicKey, nc2.PublicKey())
|
||||||
assert.Equal(t, nc.details.IsCA, nc2.IsCA())
|
assert.Equal(t, nc.details.isCA, nc2.IsCA())
|
||||||
|
|
||||||
assert.Equal(t, nc.details.Ips, nc2.Networks())
|
assert.Equal(t, nc.details.networks, nc2.Networks())
|
||||||
assert.Equal(t, nc.details.Subnets, nc2.UnsafeNetworks())
|
assert.Equal(t, nc.details.unsafeNetworks, nc2.UnsafeNetworks())
|
||||||
|
|
||||||
assert.Equal(t, nc.details.Groups, nc2.Groups())
|
assert.Equal(t, nc.details.groups, nc2.Groups())
|
||||||
}
|
}
|
||||||
|
|
||||||
//func TestNebulaCertificate_Sign(t *testing.T) {
|
//func TestNebulaCertificate_Sign(t *testing.T) {
|
||||||
@ -150,8 +150,8 @@ func TestMarshalingNebulaCertificate(t *testing.T) {
|
|||||||
func TestNebulaCertificate_Expired(t *testing.T) {
|
func TestNebulaCertificate_Expired(t *testing.T) {
|
||||||
nc := certificateV1{
|
nc := certificateV1{
|
||||||
details: detailsV1{
|
details: detailsV1{
|
||||||
NotBefore: time.Now().Add(time.Second * -60).Round(time.Second),
|
notBefore: time.Now().Add(time.Second * -60).Round(time.Second),
|
||||||
NotAfter: time.Now().Add(time.Second * 60).Round(time.Second),
|
notAfter: time.Now().Add(time.Second * 60).Round(time.Second),
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -166,21 +166,21 @@ func TestNebulaCertificate_MarshalJSON(t *testing.T) {
|
|||||||
|
|
||||||
nc := certificateV1{
|
nc := certificateV1{
|
||||||
details: detailsV1{
|
details: detailsV1{
|
||||||
Name: "testing",
|
name: "testing",
|
||||||
Ips: []netip.Prefix{
|
networks: []netip.Prefix{
|
||||||
mustParsePrefixUnmapped("10.1.1.1/24"),
|
mustParsePrefixUnmapped("10.1.1.1/24"),
|
||||||
mustParsePrefixUnmapped("10.1.1.2/16"),
|
mustParsePrefixUnmapped("10.1.1.2/16"),
|
||||||
},
|
},
|
||||||
Subnets: []netip.Prefix{
|
unsafeNetworks: []netip.Prefix{
|
||||||
mustParsePrefixUnmapped("9.1.1.2/24"),
|
mustParsePrefixUnmapped("9.1.1.2/24"),
|
||||||
mustParsePrefixUnmapped("9.1.1.3/16"),
|
mustParsePrefixUnmapped("9.1.1.3/16"),
|
||||||
},
|
},
|
||||||
Groups: []string{"test-group1", "test-group2", "test-group3"},
|
groups: []string{"test-group1", "test-group2", "test-group3"},
|
||||||
NotBefore: time.Date(1, 0, 0, 1, 0, 0, 0, time.UTC),
|
notBefore: time.Date(1, 0, 0, 1, 0, 0, 0, time.UTC),
|
||||||
NotAfter: time.Date(1, 0, 0, 2, 0, 0, 0, time.UTC),
|
notAfter: time.Date(1, 0, 0, 2, 0, 0, 0, time.UTC),
|
||||||
PublicKey: pubKey,
|
publicKey: pubKey,
|
||||||
IsCA: false,
|
isCA: false,
|
||||||
Issuer: "1234567890abcedfghij1234567890ab",
|
issuer: "1234567890abcedfghij1234567890ab",
|
||||||
},
|
},
|
||||||
signature: []byte("1234567890abcedfghij1234567890ab"),
|
signature: []byte("1234567890abcedfghij1234567890ab"),
|
||||||
}
|
}
|
||||||
@ -189,7 +189,7 @@ func TestNebulaCertificate_MarshalJSON(t *testing.T) {
|
|||||||
assert.Nil(t, err)
|
assert.Nil(t, err)
|
||||||
assert.Equal(
|
assert.Equal(
|
||||||
t,
|
t,
|
||||||
"{\"details\":{\"curve\":\"CURVE25519\",\"groups\":[\"test-group1\",\"test-group2\",\"test-group3\"],\"ips\":[\"10.1.1.1/24\",\"10.1.1.2/16\"],\"isCa\":false,\"issuer\":\"1234567890abcedfghij1234567890ab\",\"name\":\"testing\",\"notAfter\":\"0000-11-30T02:00:00Z\",\"notBefore\":\"0000-11-30T01:00:00Z\",\"publicKey\":\"313233343536373839306162636564666768696a313233343536373839306162\",\"subnets\":[\"9.1.1.2/24\",\"9.1.1.3/16\"]},\"fingerprint\":\"3944c53d4267a229295b56cb2d27d459164c010ac97d655063ba421e0670f4ba\",\"signature\":\"313233343536373839306162636564666768696a313233343536373839306162\"}",
|
"{\"details\":{\"curve\":\"CURVE25519\",\"groups\":[\"test-group1\",\"test-group2\",\"test-group3\"],\"isCa\":false,\"issuer\":\"1234567890abcedfghij1234567890ab\",\"name\":\"testing\",\"networks\":[\"10.1.1.1/24\",\"10.1.1.2/16\"],\"notAfter\":\"0000-11-30T02:00:00Z\",\"notBefore\":\"0000-11-30T01:00:00Z\",\"publicKey\":\"313233343536373839306162636564666768696a313233343536373839306162\",\"unsafeNetworks\":[\"9.1.1.2/24\",\"9.1.1.3/16\"]},\"fingerprint\":\"3944c53d4267a229295b56cb2d27d459164c010ac97d655063ba421e0670f4ba\",\"signature\":\"313233343536373839306162636564666768696a313233343536373839306162\",\"version\":1}",
|
||||||
string(b),
|
string(b),
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
@ -526,7 +526,7 @@ func TestNebulaCertificate_Copy(t *testing.T) {
|
|||||||
func TestUnmarshalNebulaCertificate(t *testing.T) {
|
func TestUnmarshalNebulaCertificate(t *testing.T) {
|
||||||
// Test that we don't panic with an invalid certificate (#332)
|
// Test that we don't panic with an invalid certificate (#332)
|
||||||
data := []byte("\x98\x00\x00")
|
data := []byte("\x98\x00\x00")
|
||||||
_, err := unmarshalCertificateV1(data, true)
|
_, err := unmarshalCertificateV1(data, nil)
|
||||||
assert.EqualError(t, err, "encoded Details was nil")
|
assert.EqualError(t, err, "encoded Details was nil")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
396
cert/cert_v1.go
396
cert/cert_v1.go
@ -6,19 +6,16 @@ import (
|
|||||||
"crypto/ecdsa"
|
"crypto/ecdsa"
|
||||||
"crypto/ed25519"
|
"crypto/ed25519"
|
||||||
"crypto/elliptic"
|
"crypto/elliptic"
|
||||||
"crypto/rand"
|
|
||||||
"crypto/sha256"
|
"crypto/sha256"
|
||||||
"encoding/binary"
|
"encoding/binary"
|
||||||
"encoding/hex"
|
"encoding/hex"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"encoding/pem"
|
"encoding/pem"
|
||||||
"fmt"
|
"fmt"
|
||||||
"math/big"
|
|
||||||
"net"
|
"net"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/slackhq/nebula/pkclient"
|
|
||||||
"golang.org/x/crypto/curve25519"
|
"golang.org/x/crypto/curve25519"
|
||||||
"google.golang.org/protobuf/proto"
|
"google.golang.org/protobuf/proto"
|
||||||
)
|
)
|
||||||
@ -31,71 +28,71 @@ type certificateV1 struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
type detailsV1 struct {
|
type detailsV1 struct {
|
||||||
Name string
|
name string
|
||||||
Ips []netip.Prefix
|
networks []netip.Prefix
|
||||||
Subnets []netip.Prefix
|
unsafeNetworks []netip.Prefix
|
||||||
Groups []string
|
groups []string
|
||||||
NotBefore time.Time
|
notBefore time.Time
|
||||||
NotAfter time.Time
|
notAfter time.Time
|
||||||
PublicKey []byte
|
publicKey []byte
|
||||||
IsCA bool
|
isCA bool
|
||||||
Issuer string
|
issuer string
|
||||||
|
|
||||||
Curve Curve
|
curve Curve
|
||||||
}
|
}
|
||||||
|
|
||||||
type m map[string]interface{}
|
type m map[string]interface{}
|
||||||
|
|
||||||
func (nc *certificateV1) Version() Version {
|
func (c *certificateV1) Version() Version {
|
||||||
return Version1
|
return Version1
|
||||||
}
|
}
|
||||||
|
|
||||||
func (nc *certificateV1) Curve() Curve {
|
func (c *certificateV1) Curve() Curve {
|
||||||
return nc.details.Curve
|
return c.details.curve
|
||||||
}
|
}
|
||||||
|
|
||||||
func (nc *certificateV1) Groups() []string {
|
func (c *certificateV1) Groups() []string {
|
||||||
return nc.details.Groups
|
return c.details.groups
|
||||||
}
|
}
|
||||||
|
|
||||||
func (nc *certificateV1) IsCA() bool {
|
func (c *certificateV1) IsCA() bool {
|
||||||
return nc.details.IsCA
|
return c.details.isCA
|
||||||
}
|
}
|
||||||
|
|
||||||
func (nc *certificateV1) Issuer() string {
|
func (c *certificateV1) Issuer() string {
|
||||||
return nc.details.Issuer
|
return c.details.issuer
|
||||||
}
|
}
|
||||||
|
|
||||||
func (nc *certificateV1) Name() string {
|
func (c *certificateV1) Name() string {
|
||||||
return nc.details.Name
|
return c.details.name
|
||||||
}
|
}
|
||||||
|
|
||||||
func (nc *certificateV1) Networks() []netip.Prefix {
|
func (c *certificateV1) Networks() []netip.Prefix {
|
||||||
return nc.details.Ips
|
return c.details.networks
|
||||||
}
|
}
|
||||||
|
|
||||||
func (nc *certificateV1) NotAfter() time.Time {
|
func (c *certificateV1) NotAfter() time.Time {
|
||||||
return nc.details.NotAfter
|
return c.details.notAfter
|
||||||
}
|
}
|
||||||
|
|
||||||
func (nc *certificateV1) NotBefore() time.Time {
|
func (c *certificateV1) NotBefore() time.Time {
|
||||||
return nc.details.NotBefore
|
return c.details.notBefore
|
||||||
}
|
}
|
||||||
|
|
||||||
func (nc *certificateV1) PublicKey() []byte {
|
func (c *certificateV1) PublicKey() []byte {
|
||||||
return nc.details.PublicKey
|
return c.details.publicKey
|
||||||
}
|
}
|
||||||
|
|
||||||
func (nc *certificateV1) Signature() []byte {
|
func (c *certificateV1) Signature() []byte {
|
||||||
return nc.signature
|
return c.signature
|
||||||
}
|
}
|
||||||
|
|
||||||
func (nc *certificateV1) UnsafeNetworks() []netip.Prefix {
|
func (c *certificateV1) UnsafeNetworks() []netip.Prefix {
|
||||||
return nc.details.Subnets
|
return c.details.unsafeNetworks
|
||||||
}
|
}
|
||||||
|
|
||||||
func (nc *certificateV1) Fingerprint() (string, error) {
|
func (c *certificateV1) Fingerprint() (string, error) {
|
||||||
b, err := nc.Marshal()
|
b, err := c.Marshal()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", err
|
return "", err
|
||||||
}
|
}
|
||||||
@ -104,33 +101,33 @@ func (nc *certificateV1) Fingerprint() (string, error) {
|
|||||||
return hex.EncodeToString(sum[:]), nil
|
return hex.EncodeToString(sum[:]), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (nc *certificateV1) CheckSignature(key []byte) bool {
|
func (c *certificateV1) CheckSignature(key []byte) bool {
|
||||||
b, err := proto.Marshal(nc.getRawDetails())
|
b, err := proto.Marshal(c.getRawDetails())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
switch nc.details.Curve {
|
switch c.details.curve {
|
||||||
case Curve_CURVE25519:
|
case Curve_CURVE25519:
|
||||||
return ed25519.Verify(key, b, nc.signature)
|
return ed25519.Verify(key, b, c.signature)
|
||||||
case Curve_P256:
|
case Curve_P256:
|
||||||
x, y := elliptic.Unmarshal(elliptic.P256(), key)
|
x, y := elliptic.Unmarshal(elliptic.P256(), key)
|
||||||
pubKey := &ecdsa.PublicKey{Curve: elliptic.P256(), X: x, Y: y}
|
pubKey := &ecdsa.PublicKey{Curve: elliptic.P256(), X: x, Y: y}
|
||||||
hashed := sha256.Sum256(b)
|
hashed := sha256.Sum256(b)
|
||||||
return ecdsa.VerifyASN1(pubKey, hashed[:], nc.signature)
|
return ecdsa.VerifyASN1(pubKey, hashed[:], c.signature)
|
||||||
default:
|
default:
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (nc *certificateV1) Expired(t time.Time) bool {
|
func (c *certificateV1) Expired(t time.Time) bool {
|
||||||
return nc.details.NotBefore.After(t) || nc.details.NotAfter.Before(t)
|
return c.details.notBefore.After(t) || c.details.notAfter.Before(t)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (nc *certificateV1) VerifyPrivateKey(curve Curve, key []byte) error {
|
func (c *certificateV1) VerifyPrivateKey(curve Curve, key []byte) error {
|
||||||
if curve != nc.details.Curve {
|
if curve != c.details.curve {
|
||||||
return fmt.Errorf("curve in cert and private key supplied don't match")
|
return fmt.Errorf("curve in cert and private key supplied don't match")
|
||||||
}
|
}
|
||||||
if nc.details.IsCA {
|
if c.details.isCA {
|
||||||
switch curve {
|
switch curve {
|
||||||
case Curve_CURVE25519:
|
case Curve_CURVE25519:
|
||||||
// the call to PublicKey below will panic slice bounds out of range otherwise
|
// the call to PublicKey below will panic slice bounds out of range otherwise
|
||||||
@ -138,7 +135,7 @@ func (nc *certificateV1) VerifyPrivateKey(curve Curve, key []byte) error {
|
|||||||
return fmt.Errorf("key was not 64 bytes, is invalid ed25519 private key")
|
return fmt.Errorf("key was not 64 bytes, is invalid ed25519 private key")
|
||||||
}
|
}
|
||||||
|
|
||||||
if !ed25519.PublicKey(nc.details.PublicKey).Equal(ed25519.PrivateKey(key).Public()) {
|
if !ed25519.PublicKey(c.details.publicKey).Equal(ed25519.PrivateKey(key).Public()) {
|
||||||
return fmt.Errorf("public key in cert and private key supplied don't match")
|
return fmt.Errorf("public key in cert and private key supplied don't match")
|
||||||
}
|
}
|
||||||
case Curve_P256:
|
case Curve_P256:
|
||||||
@ -147,7 +144,7 @@ func (nc *certificateV1) VerifyPrivateKey(curve Curve, key []byte) error {
|
|||||||
return fmt.Errorf("cannot parse private key as P256: %w", err)
|
return fmt.Errorf("cannot parse private key as P256: %w", err)
|
||||||
}
|
}
|
||||||
pub := privkey.PublicKey().Bytes()
|
pub := privkey.PublicKey().Bytes()
|
||||||
if !bytes.Equal(pub, nc.details.PublicKey) {
|
if !bytes.Equal(pub, c.details.publicKey) {
|
||||||
return fmt.Errorf("public key in cert and private key supplied don't match")
|
return fmt.Errorf("public key in cert and private key supplied don't match")
|
||||||
}
|
}
|
||||||
default:
|
default:
|
||||||
@ -173,7 +170,7 @@ func (nc *certificateV1) VerifyPrivateKey(curve Curve, key []byte) error {
|
|||||||
default:
|
default:
|
||||||
return fmt.Errorf("invalid curve: %s", curve)
|
return fmt.Errorf("invalid curve: %s", curve)
|
||||||
}
|
}
|
||||||
if !bytes.Equal(pub, nc.details.PublicKey) {
|
if !bytes.Equal(pub, c.details.publicKey) {
|
||||||
return fmt.Errorf("public key in cert and private key supplied don't match")
|
return fmt.Errorf("public key in cert and private key supplied don't match")
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -181,173 +178,155 @@ func (nc *certificateV1) VerifyPrivateKey(curve Curve, key []byte) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// getRawDetails marshals the raw details into protobuf ready struct
|
// getRawDetails marshals the raw details into protobuf ready struct
|
||||||
func (nc *certificateV1) getRawDetails() *RawNebulaCertificateDetails {
|
func (c *certificateV1) getRawDetails() *RawNebulaCertificateDetails {
|
||||||
rd := &RawNebulaCertificateDetails{
|
rd := &RawNebulaCertificateDetails{
|
||||||
Name: nc.details.Name,
|
Name: c.details.name,
|
||||||
Groups: nc.details.Groups,
|
Groups: c.details.groups,
|
||||||
NotBefore: nc.details.NotBefore.Unix(),
|
NotBefore: c.details.notBefore.Unix(),
|
||||||
NotAfter: nc.details.NotAfter.Unix(),
|
NotAfter: c.details.notAfter.Unix(),
|
||||||
PublicKey: make([]byte, len(nc.details.PublicKey)),
|
PublicKey: make([]byte, len(c.details.publicKey)),
|
||||||
IsCA: nc.details.IsCA,
|
IsCA: c.details.isCA,
|
||||||
Curve: nc.details.Curve,
|
Curve: c.details.curve,
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, ipNet := range nc.details.Ips {
|
for _, ipNet := range c.details.networks {
|
||||||
mask := net.CIDRMask(ipNet.Bits(), ipNet.Addr().BitLen())
|
mask := net.CIDRMask(ipNet.Bits(), ipNet.Addr().BitLen())
|
||||||
rd.Ips = append(rd.Ips, addr2int(ipNet.Addr()), ip2int(mask))
|
rd.Ips = append(rd.Ips, addr2int(ipNet.Addr()), ip2int(mask))
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, ipNet := range nc.details.Subnets {
|
for _, ipNet := range c.details.unsafeNetworks {
|
||||||
mask := net.CIDRMask(ipNet.Bits(), ipNet.Addr().BitLen())
|
mask := net.CIDRMask(ipNet.Bits(), ipNet.Addr().BitLen())
|
||||||
rd.Subnets = append(rd.Subnets, addr2int(ipNet.Addr()), ip2int(mask))
|
rd.Subnets = append(rd.Subnets, addr2int(ipNet.Addr()), ip2int(mask))
|
||||||
}
|
}
|
||||||
|
|
||||||
copy(rd.PublicKey, nc.details.PublicKey[:])
|
copy(rd.PublicKey, c.details.publicKey[:])
|
||||||
|
|
||||||
// I know, this is terrible
|
// I know, this is terrible
|
||||||
rd.Issuer, _ = hex.DecodeString(nc.details.Issuer)
|
rd.Issuer, _ = hex.DecodeString(c.details.issuer)
|
||||||
|
|
||||||
return rd
|
return rd
|
||||||
}
|
}
|
||||||
|
|
||||||
func (nc *certificateV1) String() string {
|
func (c *certificateV1) String() string {
|
||||||
if nc == nil {
|
b, err := json.MarshalIndent(c.marshalJSON(), "", "\t")
|
||||||
return "Certificate {}\n"
|
if err != nil {
|
||||||
|
return "<error marshalling certificate>"
|
||||||
|
}
|
||||||
|
return string(b)
|
||||||
}
|
}
|
||||||
|
|
||||||
s := "NebulaCertificate {\n"
|
func (c *certificateV1) MarshalForHandshakes() ([]byte, error) {
|
||||||
s += "\tDetails {\n"
|
pubKey := c.details.publicKey
|
||||||
s += fmt.Sprintf("\t\tName: %v\n", nc.details.Name)
|
c.details.publicKey = nil
|
||||||
|
rawCertNoKey, err := c.Marshal()
|
||||||
if len(nc.details.Ips) > 0 {
|
|
||||||
s += "\t\tIps: [\n"
|
|
||||||
for _, ip := range nc.details.Ips {
|
|
||||||
s += fmt.Sprintf("\t\t\t%v\n", ip.String())
|
|
||||||
}
|
|
||||||
s += "\t\t]\n"
|
|
||||||
} else {
|
|
||||||
s += "\t\tIps: []\n"
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(nc.details.Subnets) > 0 {
|
|
||||||
s += "\t\tSubnets: [\n"
|
|
||||||
for _, ip := range nc.details.Subnets {
|
|
||||||
s += fmt.Sprintf("\t\t\t%v\n", ip.String())
|
|
||||||
}
|
|
||||||
s += "\t\t]\n"
|
|
||||||
} else {
|
|
||||||
s += "\t\tSubnets: []\n"
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(nc.details.Groups) > 0 {
|
|
||||||
s += "\t\tGroups: [\n"
|
|
||||||
for _, g := range nc.details.Groups {
|
|
||||||
s += fmt.Sprintf("\t\t\t\"%v\"\n", g)
|
|
||||||
}
|
|
||||||
s += "\t\t]\n"
|
|
||||||
} else {
|
|
||||||
s += "\t\tGroups: []\n"
|
|
||||||
}
|
|
||||||
|
|
||||||
s += fmt.Sprintf("\t\tNot before: %v\n", nc.details.NotBefore)
|
|
||||||
s += fmt.Sprintf("\t\tNot After: %v\n", nc.details.NotAfter)
|
|
||||||
s += fmt.Sprintf("\t\tIs CA: %v\n", nc.details.IsCA)
|
|
||||||
s += fmt.Sprintf("\t\tIssuer: %s\n", nc.details.Issuer)
|
|
||||||
s += fmt.Sprintf("\t\tPublic key: %x\n", nc.details.PublicKey)
|
|
||||||
s += fmt.Sprintf("\t\tCurve: %s\n", nc.details.Curve)
|
|
||||||
s += "\t}\n"
|
|
||||||
fp, err := nc.Fingerprint()
|
|
||||||
if err == nil {
|
|
||||||
s += fmt.Sprintf("\tFingerprint: %s\n", fp)
|
|
||||||
}
|
|
||||||
s += fmt.Sprintf("\tSignature: %x\n", nc.Signature())
|
|
||||||
s += "}"
|
|
||||||
|
|
||||||
return s
|
|
||||||
}
|
|
||||||
|
|
||||||
func (nc *certificateV1) MarshalForHandshakes() ([]byte, error) {
|
|
||||||
pubKey := nc.details.PublicKey
|
|
||||||
nc.details.PublicKey = nil
|
|
||||||
rawCertNoKey, err := nc.Marshal()
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
nc.details.PublicKey = pubKey
|
c.details.publicKey = pubKey
|
||||||
return rawCertNoKey, nil
|
return rawCertNoKey, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (nc *certificateV1) Marshal() ([]byte, error) {
|
func (c *certificateV1) Marshal() ([]byte, error) {
|
||||||
rc := RawNebulaCertificate{
|
rc := RawNebulaCertificate{
|
||||||
Details: nc.getRawDetails(),
|
Details: c.getRawDetails(),
|
||||||
Signature: nc.signature,
|
Signature: c.signature,
|
||||||
}
|
}
|
||||||
|
|
||||||
return proto.Marshal(&rc)
|
return proto.Marshal(&rc)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (nc *certificateV1) MarshalPEM() ([]byte, error) {
|
func (c *certificateV1) MarshalPEM() ([]byte, error) {
|
||||||
b, err := nc.Marshal()
|
b, err := c.Marshal()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
return pem.EncodeToMemory(&pem.Block{Type: CertificateBanner, Bytes: b}), nil
|
return pem.EncodeToMemory(&pem.Block{Type: CertificateBanner, Bytes: b}), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (nc *certificateV1) MarshalJSON() ([]byte, error) {
|
func (c *certificateV1) MarshalJSON() ([]byte, error) {
|
||||||
fp, _ := nc.Fingerprint()
|
return json.Marshal(c.marshalJSON())
|
||||||
jc := m{
|
}
|
||||||
|
|
||||||
|
func (c *certificateV1) marshalJSON() m {
|
||||||
|
fp, _ := c.Fingerprint()
|
||||||
|
return m{
|
||||||
|
"version": Version1,
|
||||||
"details": m{
|
"details": m{
|
||||||
"name": nc.details.Name,
|
"name": c.details.name,
|
||||||
"ips": nc.details.Ips,
|
"networks": c.details.networks,
|
||||||
"subnets": nc.details.Subnets,
|
"unsafeNetworks": c.details.unsafeNetworks,
|
||||||
"groups": nc.details.Groups,
|
"groups": c.details.groups,
|
||||||
"notBefore": nc.details.NotBefore,
|
"notBefore": c.details.notBefore,
|
||||||
"notAfter": nc.details.NotAfter,
|
"notAfter": c.details.notAfter,
|
||||||
"publicKey": fmt.Sprintf("%x", nc.details.PublicKey),
|
"publicKey": fmt.Sprintf("%x", c.details.publicKey),
|
||||||
"isCa": nc.details.IsCA,
|
"isCa": c.details.isCA,
|
||||||
"issuer": nc.details.Issuer,
|
"issuer": c.details.issuer,
|
||||||
"curve": nc.details.Curve.String(),
|
"curve": c.details.curve.String(),
|
||||||
},
|
},
|
||||||
"fingerprint": fp,
|
"fingerprint": fp,
|
||||||
"signature": fmt.Sprintf("%x", nc.Signature()),
|
"signature": fmt.Sprintf("%x", c.Signature()),
|
||||||
}
|
}
|
||||||
return json.Marshal(jc)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (nc *certificateV1) Copy() Certificate {
|
func (c *certificateV1) Copy() Certificate {
|
||||||
c := &certificateV1{
|
nc := &certificateV1{
|
||||||
details: detailsV1{
|
details: detailsV1{
|
||||||
Name: nc.details.Name,
|
name: c.details.name,
|
||||||
Groups: make([]string, len(nc.details.Groups)),
|
groups: make([]string, len(c.details.groups)),
|
||||||
Ips: make([]netip.Prefix, len(nc.details.Ips)),
|
networks: make([]netip.Prefix, len(c.details.networks)),
|
||||||
Subnets: make([]netip.Prefix, len(nc.details.Subnets)),
|
unsafeNetworks: make([]netip.Prefix, len(c.details.unsafeNetworks)),
|
||||||
NotBefore: nc.details.NotBefore,
|
notBefore: c.details.notBefore,
|
||||||
NotAfter: nc.details.NotAfter,
|
notAfter: c.details.notAfter,
|
||||||
PublicKey: make([]byte, len(nc.details.PublicKey)),
|
publicKey: make([]byte, len(c.details.publicKey)),
|
||||||
IsCA: nc.details.IsCA,
|
isCA: c.details.isCA,
|
||||||
Issuer: nc.details.Issuer,
|
issuer: c.details.issuer,
|
||||||
|
curve: c.details.curve,
|
||||||
},
|
},
|
||||||
signature: make([]byte, len(nc.signature)),
|
signature: make([]byte, len(c.signature)),
|
||||||
}
|
}
|
||||||
|
|
||||||
copy(c.signature, nc.signature)
|
copy(nc.signature, c.signature)
|
||||||
copy(c.details.Groups, nc.details.Groups)
|
copy(nc.details.groups, c.details.groups)
|
||||||
copy(c.details.PublicKey, nc.details.PublicKey)
|
copy(nc.details.publicKey, c.details.publicKey)
|
||||||
|
copy(nc.details.networks, c.details.networks)
|
||||||
|
copy(nc.details.unsafeNetworks, c.details.unsafeNetworks)
|
||||||
|
|
||||||
for i, p := range nc.details.Ips {
|
return nc
|
||||||
c.details.Ips[i] = p
|
|
||||||
}
|
}
|
||||||
|
|
||||||
for i, p := range nc.details.Subnets {
|
func (c *certificateV1) fromTBSCertificate(t *TBSCertificate) error {
|
||||||
c.details.Subnets[i] = p
|
c.details = detailsV1{
|
||||||
|
name: t.Name,
|
||||||
|
networks: t.Networks,
|
||||||
|
unsafeNetworks: t.UnsafeNetworks,
|
||||||
|
groups: t.Groups,
|
||||||
|
notBefore: t.NotBefore,
|
||||||
|
notAfter: t.NotAfter,
|
||||||
|
publicKey: t.PublicKey,
|
||||||
|
isCA: t.IsCA,
|
||||||
|
curve: t.Curve,
|
||||||
|
issuer: t.issuer,
|
||||||
}
|
}
|
||||||
|
|
||||||
return c
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *certificateV1) marshalForSigning() ([]byte, error) {
|
||||||
|
b, err := proto.Marshal(c.getRawDetails())
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return b, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *certificateV1) setSignature(b []byte) error {
|
||||||
|
c.signature = b
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// unmarshalCertificateV1 will unmarshal a protobuf byte representation of a nebula cert
|
// unmarshalCertificateV1 will unmarshal a protobuf byte representation of a nebula cert
|
||||||
func unmarshalCertificateV1(b []byte, assertPublicKey bool) (*certificateV1, error) {
|
// if the publicKey is provided here then it is not required to be present in `b`
|
||||||
|
func unmarshalCertificateV1(b []byte, publicKey []byte) (*certificateV1, error) {
|
||||||
if len(b) == 0 {
|
if len(b) == 0 {
|
||||||
return nil, fmt.Errorf("nil byte array")
|
return nil, fmt.Errorf("nil byte array")
|
||||||
}
|
}
|
||||||
@ -371,27 +350,28 @@ func unmarshalCertificateV1(b []byte, assertPublicKey bool) (*certificateV1, err
|
|||||||
|
|
||||||
nc := certificateV1{
|
nc := certificateV1{
|
||||||
details: detailsV1{
|
details: detailsV1{
|
||||||
Name: rc.Details.Name,
|
name: rc.Details.Name,
|
||||||
Groups: make([]string, len(rc.Details.Groups)),
|
groups: make([]string, len(rc.Details.Groups)),
|
||||||
Ips: make([]netip.Prefix, len(rc.Details.Ips)/2),
|
networks: make([]netip.Prefix, len(rc.Details.Ips)/2),
|
||||||
Subnets: make([]netip.Prefix, len(rc.Details.Subnets)/2),
|
unsafeNetworks: make([]netip.Prefix, len(rc.Details.Subnets)/2),
|
||||||
NotBefore: time.Unix(rc.Details.NotBefore, 0),
|
notBefore: time.Unix(rc.Details.NotBefore, 0),
|
||||||
NotAfter: time.Unix(rc.Details.NotAfter, 0),
|
notAfter: time.Unix(rc.Details.NotAfter, 0),
|
||||||
PublicKey: make([]byte, len(rc.Details.PublicKey)),
|
publicKey: make([]byte, len(rc.Details.PublicKey)),
|
||||||
IsCA: rc.Details.IsCA,
|
isCA: rc.Details.IsCA,
|
||||||
Curve: rc.Details.Curve,
|
curve: rc.Details.Curve,
|
||||||
},
|
},
|
||||||
signature: make([]byte, len(rc.Signature)),
|
signature: make([]byte, len(rc.Signature)),
|
||||||
}
|
}
|
||||||
|
|
||||||
copy(nc.signature, rc.Signature)
|
copy(nc.signature, rc.Signature)
|
||||||
copy(nc.details.Groups, rc.Details.Groups)
|
copy(nc.details.groups, rc.Details.Groups)
|
||||||
nc.details.Issuer = hex.EncodeToString(rc.Details.Issuer)
|
nc.details.issuer = hex.EncodeToString(rc.Details.Issuer)
|
||||||
|
|
||||||
if len(rc.Details.PublicKey) < publicKeyLen && assertPublicKey {
|
if len(publicKey) > 0 {
|
||||||
return nil, fmt.Errorf("public key was fewer than 32 bytes; %v", len(rc.Details.PublicKey))
|
nc.details.publicKey = publicKey
|
||||||
}
|
}
|
||||||
copy(nc.details.PublicKey, rc.Details.PublicKey)
|
|
||||||
|
copy(nc.details.publicKey, rc.Details.PublicKey)
|
||||||
|
|
||||||
var ip netip.Addr
|
var ip netip.Addr
|
||||||
for i, rawIp := range rc.Details.Ips {
|
for i, rawIp := range rc.Details.Ips {
|
||||||
@ -399,7 +379,7 @@ func unmarshalCertificateV1(b []byte, assertPublicKey bool) (*certificateV1, err
|
|||||||
ip = int2addr(rawIp)
|
ip = int2addr(rawIp)
|
||||||
} else {
|
} else {
|
||||||
ones, _ := net.IPMask(int2ip(rawIp)).Size()
|
ones, _ := net.IPMask(int2ip(rawIp)).Size()
|
||||||
nc.details.Ips[i/2] = netip.PrefixFrom(ip, ones)
|
nc.details.networks[i/2] = netip.PrefixFrom(ip, ones)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -408,69 +388,15 @@ func unmarshalCertificateV1(b []byte, assertPublicKey bool) (*certificateV1, err
|
|||||||
ip = int2addr(rawIp)
|
ip = int2addr(rawIp)
|
||||||
} else {
|
} else {
|
||||||
ones, _ := net.IPMask(int2ip(rawIp)).Size()
|
ones, _ := net.IPMask(int2ip(rawIp)).Size()
|
||||||
nc.details.Subnets[i/2] = netip.PrefixFrom(ip, ones)
|
nc.details.unsafeNetworks[i/2] = netip.PrefixFrom(ip, ones)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
//do not sort the subnets field for V1 certs
|
||||||
|
|
||||||
return &nc, nil
|
return &nc, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func signV1(t *TBSCertificate, curve Curve, key []byte, client *pkclient.PKClient) (*certificateV1, error) {
|
|
||||||
c := &certificateV1{
|
|
||||||
details: detailsV1{
|
|
||||||
Name: t.Name,
|
|
||||||
Ips: t.Networks,
|
|
||||||
Subnets: t.UnsafeNetworks,
|
|
||||||
Groups: t.Groups,
|
|
||||||
NotBefore: t.NotBefore,
|
|
||||||
NotAfter: t.NotAfter,
|
|
||||||
PublicKey: t.PublicKey,
|
|
||||||
IsCA: t.IsCA,
|
|
||||||
Curve: t.Curve,
|
|
||||||
Issuer: t.issuer,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
b, err := proto.Marshal(c.getRawDetails())
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
var sig []byte
|
|
||||||
|
|
||||||
switch curve {
|
|
||||||
case Curve_CURVE25519:
|
|
||||||
signer := ed25519.PrivateKey(key)
|
|
||||||
sig = ed25519.Sign(signer, b)
|
|
||||||
case Curve_P256:
|
|
||||||
if client != nil {
|
|
||||||
sig, err = client.SignASN1(b)
|
|
||||||
} else {
|
|
||||||
signer := &ecdsa.PrivateKey{
|
|
||||||
PublicKey: ecdsa.PublicKey{
|
|
||||||
Curve: elliptic.P256(),
|
|
||||||
},
|
|
||||||
// ref: https://github.com/golang/go/blob/go1.19/src/crypto/x509/sec1.go#L95
|
|
||||||
D: new(big.Int).SetBytes(key),
|
|
||||||
}
|
|
||||||
// ref: https://github.com/golang/go/blob/go1.19/src/crypto/x509/sec1.go#L119
|
|
||||||
signer.X, signer.Y = signer.Curve.ScalarBaseMult(key)
|
|
||||||
|
|
||||||
// We need to hash first for ECDSA
|
|
||||||
// - https://pkg.go.dev/crypto/ecdsa#SignASN1
|
|
||||||
hashed := sha256.Sum256(b)
|
|
||||||
sig, err = ecdsa.SignASN1(rand.Reader, signer, hashed[:])
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
default:
|
|
||||||
return nil, fmt.Errorf("invalid curve: %s", c.details.Curve)
|
|
||||||
}
|
|
||||||
|
|
||||||
c.signature = sig
|
|
||||||
return c, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func ip2int(ip []byte) uint32 {
|
func ip2int(ip []byte) uint32 {
|
||||||
if len(ip) == 16 {
|
if len(ip) == 16 {
|
||||||
return binary.BigEndian.Uint32(ip[12:16])
|
return binary.BigEndian.Uint32(ip[12:16])
|
||||||
|
|||||||
37
cert/cert_v2.asn1
Normal file
37
cert/cert_v2.asn1
Normal file
@ -0,0 +1,37 @@
|
|||||||
|
Nebula DEFINITIONS AUTOMATIC TAGS ::= BEGIN
|
||||||
|
|
||||||
|
Name ::= UTF8String (SIZE (1..253))
|
||||||
|
Time ::= INTEGER (0..18446744073709551615) -- Seconds since unix epoch, uint64 maximum
|
||||||
|
Network ::= OCTET STRING (SIZE (5,17)) -- IP addresses are 4 or 16 bytes + 1 byte for the prefix length
|
||||||
|
Curve ::= ENUMERATED {
|
||||||
|
curve25519 (0),
|
||||||
|
p256 (1)
|
||||||
|
}
|
||||||
|
|
||||||
|
-- The maximum size of a certificate must not exceed 65536 bytes
|
||||||
|
Certificate ::= SEQUENCE {
|
||||||
|
details OCTET STRING,
|
||||||
|
curve Curve DEFAULT curve25519,
|
||||||
|
publicKey OCTET STRING,
|
||||||
|
-- signature(details + curve + publicKey) using the appropriate method for curve
|
||||||
|
signature OCTET STRING
|
||||||
|
}
|
||||||
|
|
||||||
|
Details ::= SEQUENCE {
|
||||||
|
name Name,
|
||||||
|
|
||||||
|
-- At least 1 ipv4 or ipv6 address must be present if isCA is false
|
||||||
|
networks SEQUENCE OF Network OPTIONAL,
|
||||||
|
unsafeNetworks SEQUENCE OF Network OPTIONAL,
|
||||||
|
groups SEQUENCE OF Name OPTIONAL,
|
||||||
|
isCA BOOLEAN DEFAULT false,
|
||||||
|
notBefore Time,
|
||||||
|
notAfter Time,
|
||||||
|
|
||||||
|
-- issuer is only required if isCA is false, if isCA is true then it must not be present
|
||||||
|
issuer OCTET STRING OPTIONAL,
|
||||||
|
...
|
||||||
|
-- New fields can be added below here
|
||||||
|
}
|
||||||
|
|
||||||
|
END
|
||||||
621
cert/cert_v2.go
Normal file
621
cert/cert_v2.go
Normal file
@ -0,0 +1,621 @@
|
|||||||
|
package cert
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"crypto/ecdh"
|
||||||
|
"crypto/ecdsa"
|
||||||
|
"crypto/ed25519"
|
||||||
|
"crypto/elliptic"
|
||||||
|
"crypto/sha256"
|
||||||
|
"encoding/hex"
|
||||||
|
"encoding/json"
|
||||||
|
"encoding/pem"
|
||||||
|
"fmt"
|
||||||
|
"net/netip"
|
||||||
|
"slices"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"golang.org/x/crypto/cryptobyte"
|
||||||
|
"golang.org/x/crypto/cryptobyte/asn1"
|
||||||
|
"golang.org/x/crypto/curve25519"
|
||||||
|
)
|
||||||
|
|
||||||
|
//TODO: should we avoid hex encoding shit on output? Just let it be base64?
|
||||||
|
|
||||||
|
const (
|
||||||
|
classConstructed = 0x20
|
||||||
|
classContextSpecific = 0x80
|
||||||
|
|
||||||
|
TagCertDetails = 0 | classConstructed | classContextSpecific
|
||||||
|
TagCertCurve = 1 | classContextSpecific
|
||||||
|
TagCertPublicKey = 2 | classContextSpecific
|
||||||
|
TagCertSignature = 3 | classContextSpecific
|
||||||
|
|
||||||
|
TagDetailsName = 0 | classContextSpecific
|
||||||
|
TagDetailsNetworks = 1 | classConstructed | classContextSpecific
|
||||||
|
TagDetailsUnsafeNetworks = 2 | classConstructed | classContextSpecific
|
||||||
|
TagDetailsGroups = 3 | classConstructed | classContextSpecific
|
||||||
|
TagDetailsIsCA = 4 | classContextSpecific
|
||||||
|
TagDetailsNotBefore = 5 | classContextSpecific
|
||||||
|
TagDetailsNotAfter = 6 | classContextSpecific
|
||||||
|
TagDetailsIssuer = 7 | classContextSpecific
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
// MaxCertificateSize is the maximum length a valid certificate can be
|
||||||
|
MaxCertificateSize = 65536
|
||||||
|
|
||||||
|
// MaxNameLength is limited to a maximum realistic DNS domain name to help facilitate DNS systems
|
||||||
|
MaxNameLength = 253
|
||||||
|
|
||||||
|
// MaxNetworkLength is the maximum length a network value can be.
|
||||||
|
// 16 bytes for an ipv6 address + 1 byte for the prefix length
|
||||||
|
MaxNetworkLength = 17
|
||||||
|
)
|
||||||
|
|
||||||
|
type certificateV2 struct {
|
||||||
|
details detailsV2
|
||||||
|
|
||||||
|
// RawDetails contains the entire asn.1 DER encoded Details struct
|
||||||
|
// This is to benefit forwards compatibility in signature checking.
|
||||||
|
// signature(RawDetails + Curve + PublicKey) == Signature
|
||||||
|
rawDetails []byte
|
||||||
|
curve Curve
|
||||||
|
publicKey []byte
|
||||||
|
signature []byte
|
||||||
|
}
|
||||||
|
|
||||||
|
type detailsV2 struct {
|
||||||
|
name string
|
||||||
|
networks []netip.Prefix
|
||||||
|
unsafeNetworks []netip.Prefix
|
||||||
|
groups []string
|
||||||
|
isCA bool
|
||||||
|
notBefore time.Time
|
||||||
|
notAfter time.Time
|
||||||
|
issuer string
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *certificateV2) Version() Version {
|
||||||
|
return Version2
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *certificateV2) Curve() Curve {
|
||||||
|
return c.curve
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *certificateV2) Groups() []string {
|
||||||
|
return c.details.groups
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *certificateV2) IsCA() bool {
|
||||||
|
return c.details.isCA
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *certificateV2) Issuer() string {
|
||||||
|
return c.details.issuer
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *certificateV2) Name() string {
|
||||||
|
return c.details.name
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *certificateV2) Networks() []netip.Prefix {
|
||||||
|
return c.details.networks
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *certificateV2) NotAfter() time.Time {
|
||||||
|
return c.details.notAfter
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *certificateV2) NotBefore() time.Time {
|
||||||
|
return c.details.notBefore
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *certificateV2) PublicKey() []byte {
|
||||||
|
return c.publicKey
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *certificateV2) Signature() []byte {
|
||||||
|
return c.signature
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *certificateV2) UnsafeNetworks() []netip.Prefix {
|
||||||
|
return c.details.unsafeNetworks
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *certificateV2) Fingerprint() (string, error) {
|
||||||
|
b := make([]byte, len(c.rawDetails)+1+len(c.publicKey))
|
||||||
|
//TODO: double check this, panic on empty raw details
|
||||||
|
copy(b, c.rawDetails)
|
||||||
|
b[len(c.rawDetails)] = byte(c.curve)
|
||||||
|
copy(b[len(c.rawDetails)+1:], c.publicKey)
|
||||||
|
copy(b[len(c.rawDetails)+1+len(c.publicKey):], c.signature)
|
||||||
|
sum := sha256.Sum256(b)
|
||||||
|
return hex.EncodeToString(sum[:]), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *certificateV2) CheckSignature(key []byte) bool {
|
||||||
|
b := make([]byte, len(c.rawDetails)+1+len(c.publicKey))
|
||||||
|
//TODO: double check this, panic on empty raw details
|
||||||
|
copy(b, c.rawDetails)
|
||||||
|
b[len(c.rawDetails)] = byte(c.curve)
|
||||||
|
copy(b[len(c.rawDetails)+1:], c.publicKey)
|
||||||
|
|
||||||
|
switch c.curve {
|
||||||
|
case Curve_CURVE25519:
|
||||||
|
return ed25519.Verify(key, b, c.signature)
|
||||||
|
case Curve_P256:
|
||||||
|
//TODO: NewPublicKey
|
||||||
|
x, y := elliptic.Unmarshal(elliptic.P256(), key)
|
||||||
|
pubKey := &ecdsa.PublicKey{Curve: elliptic.P256(), X: x, Y: y}
|
||||||
|
hashed := sha256.Sum256(b)
|
||||||
|
return ecdsa.VerifyASN1(pubKey, hashed[:], c.signature)
|
||||||
|
default:
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *certificateV2) Expired(t time.Time) bool {
|
||||||
|
return c.details.notBefore.After(t) || c.details.notAfter.Before(t)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *certificateV2) VerifyPrivateKey(curve Curve, key []byte) error {
|
||||||
|
if curve != c.curve {
|
||||||
|
return fmt.Errorf("curve in cert and private key supplied don't match")
|
||||||
|
}
|
||||||
|
if c.details.isCA {
|
||||||
|
switch curve {
|
||||||
|
case Curve_CURVE25519:
|
||||||
|
// the call to PublicKey below will panic slice bounds out of range otherwise
|
||||||
|
if len(key) != ed25519.PrivateKeySize {
|
||||||
|
return fmt.Errorf("key was not 64 bytes, is invalid ed25519 private key")
|
||||||
|
}
|
||||||
|
|
||||||
|
if !ed25519.PublicKey(c.publicKey).Equal(ed25519.PrivateKey(key).Public()) {
|
||||||
|
return fmt.Errorf("public key in cert and private key supplied don't match")
|
||||||
|
}
|
||||||
|
case Curve_P256:
|
||||||
|
privkey, err := ecdh.P256().NewPrivateKey(key)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("cannot parse private key as P256")
|
||||||
|
}
|
||||||
|
pub := privkey.PublicKey().Bytes()
|
||||||
|
if !bytes.Equal(pub, c.publicKey) {
|
||||||
|
return fmt.Errorf("public key in cert and private key supplied don't match")
|
||||||
|
}
|
||||||
|
default:
|
||||||
|
return fmt.Errorf("invalid curve: %s", curve)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
var pub []byte
|
||||||
|
switch curve {
|
||||||
|
case Curve_CURVE25519:
|
||||||
|
var err error
|
||||||
|
pub, err = curve25519.X25519(key, curve25519.Basepoint)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
case Curve_P256:
|
||||||
|
privkey, err := ecdh.P256().NewPrivateKey(key)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
pub = privkey.PublicKey().Bytes()
|
||||||
|
default:
|
||||||
|
return fmt.Errorf("invalid curve: %s", curve)
|
||||||
|
}
|
||||||
|
if !bytes.Equal(pub, c.publicKey) {
|
||||||
|
return fmt.Errorf("public key in cert and private key supplied don't match")
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *certificateV2) String() string {
|
||||||
|
b, err := json.MarshalIndent(c.marshalJSON(), "", "\t")
|
||||||
|
if err != nil {
|
||||||
|
return "<error marshalling certificate>"
|
||||||
|
}
|
||||||
|
return string(b)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *certificateV2) MarshalForHandshakes() ([]byte, error) {
|
||||||
|
var b cryptobyte.Builder
|
||||||
|
// Outermost certificate
|
||||||
|
b.AddASN1(asn1.SEQUENCE, func(b *cryptobyte.Builder) {
|
||||||
|
|
||||||
|
// Add the cert details which is already marshalled
|
||||||
|
//TODO: panic on nil rawDetails
|
||||||
|
b.AddBytes(c.rawDetails)
|
||||||
|
|
||||||
|
// Skipping the curve and public key since those come across in a different part of the handshake
|
||||||
|
|
||||||
|
// Add the signature
|
||||||
|
b.AddASN1(TagCertSignature, func(b *cryptobyte.Builder) {
|
||||||
|
b.AddBytes(c.signature)
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
return b.Bytes()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *certificateV2) Marshal() ([]byte, error) {
|
||||||
|
var b cryptobyte.Builder
|
||||||
|
// Outermost certificate
|
||||||
|
b.AddASN1(asn1.SEQUENCE, func(b *cryptobyte.Builder) {
|
||||||
|
|
||||||
|
// Add the cert details which is already marshalled
|
||||||
|
b.AddBytes(c.rawDetails)
|
||||||
|
|
||||||
|
// Add the curve only if its not the default value
|
||||||
|
if c.curve != Curve_CURVE25519 {
|
||||||
|
b.AddASN1(TagCertCurve, func(b *cryptobyte.Builder) {
|
||||||
|
b.AddBytes([]byte{byte(c.curve)})
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add the public key if it is not empty
|
||||||
|
if c.publicKey != nil {
|
||||||
|
b.AddASN1(TagCertPublicKey, func(b *cryptobyte.Builder) {
|
||||||
|
b.AddBytes(c.publicKey)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add the signature
|
||||||
|
b.AddASN1(TagCertSignature, func(b *cryptobyte.Builder) {
|
||||||
|
b.AddBytes(c.signature)
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
return b.Bytes()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *certificateV2) MarshalPEM() ([]byte, error) {
|
||||||
|
b, err := c.Marshal()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return pem.EncodeToMemory(&pem.Block{Type: CertificateV2Banner, Bytes: b}), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *certificateV2) MarshalJSON() ([]byte, error) {
|
||||||
|
return json.Marshal(c.marshalJSON())
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *certificateV2) marshalJSON() m {
|
||||||
|
fp, _ := c.Fingerprint()
|
||||||
|
return m{
|
||||||
|
"details": m{
|
||||||
|
"name": c.details.name,
|
||||||
|
"networks": c.details.networks,
|
||||||
|
"unsafeNetworks": c.details.unsafeNetworks,
|
||||||
|
"groups": c.details.groups,
|
||||||
|
"notBefore": c.details.notBefore,
|
||||||
|
"notAfter": c.details.notAfter,
|
||||||
|
"isCa": c.details.isCA,
|
||||||
|
"issuer": c.details.issuer,
|
||||||
|
},
|
||||||
|
"version": Version2,
|
||||||
|
"publicKey": fmt.Sprintf("%x", c.publicKey),
|
||||||
|
"curve": c.curve.String(),
|
||||||
|
"fingerprint": fp,
|
||||||
|
"signature": fmt.Sprintf("%x", c.Signature()),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *certificateV2) Copy() Certificate {
|
||||||
|
nc := &certificateV2{
|
||||||
|
details: detailsV2{
|
||||||
|
name: c.details.name,
|
||||||
|
groups: make([]string, len(c.details.groups)),
|
||||||
|
networks: make([]netip.Prefix, len(c.details.networks)),
|
||||||
|
unsafeNetworks: make([]netip.Prefix, len(c.details.unsafeNetworks)),
|
||||||
|
notBefore: c.details.notBefore,
|
||||||
|
notAfter: c.details.notAfter,
|
||||||
|
isCA: c.details.isCA,
|
||||||
|
issuer: c.details.issuer,
|
||||||
|
},
|
||||||
|
curve: c.curve,
|
||||||
|
publicKey: make([]byte, len(c.publicKey)),
|
||||||
|
signature: make([]byte, len(c.signature)),
|
||||||
|
}
|
||||||
|
|
||||||
|
copy(nc.signature, c.signature)
|
||||||
|
copy(nc.details.groups, c.details.groups)
|
||||||
|
copy(nc.publicKey, c.publicKey)
|
||||||
|
copy(nc.details.networks, c.details.networks)
|
||||||
|
copy(nc.details.unsafeNetworks, c.details.unsafeNetworks)
|
||||||
|
|
||||||
|
return nc
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *certificateV2) fromTBSCertificate(t *TBSCertificate) error {
|
||||||
|
c.details = detailsV2{
|
||||||
|
name: t.Name,
|
||||||
|
networks: t.Networks,
|
||||||
|
unsafeNetworks: t.UnsafeNetworks,
|
||||||
|
groups: t.Groups,
|
||||||
|
isCA: t.IsCA,
|
||||||
|
notBefore: t.NotBefore,
|
||||||
|
notAfter: t.NotAfter,
|
||||||
|
issuer: t.issuer,
|
||||||
|
}
|
||||||
|
c.curve = t.Curve
|
||||||
|
c.publicKey = t.PublicKey
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *certificateV2) marshalForSigning() ([]byte, error) {
|
||||||
|
d, err := c.details.Marshal()
|
||||||
|
if err != nil {
|
||||||
|
//TODO: annotate?
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
c.rawDetails = d
|
||||||
|
|
||||||
|
b := make([]byte, len(c.rawDetails)+1+len(c.publicKey))
|
||||||
|
//TODO: double check this
|
||||||
|
copy(b, c.rawDetails)
|
||||||
|
b[len(c.rawDetails)] = byte(c.curve)
|
||||||
|
copy(b[len(c.rawDetails)+1:], c.publicKey)
|
||||||
|
return b, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *certificateV2) setSignature(b []byte) error {
|
||||||
|
c.signature = b
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *detailsV2) Marshal() ([]byte, error) {
|
||||||
|
var b cryptobyte.Builder
|
||||||
|
var err error
|
||||||
|
|
||||||
|
// Details are a structure
|
||||||
|
b.AddASN1(TagCertDetails, func(b *cryptobyte.Builder) {
|
||||||
|
|
||||||
|
// Add the name
|
||||||
|
b.AddASN1(TagDetailsName, func(b *cryptobyte.Builder) {
|
||||||
|
b.AddBytes([]byte(d.name))
|
||||||
|
})
|
||||||
|
|
||||||
|
// Add the networks if any exist
|
||||||
|
if len(d.networks) > 0 {
|
||||||
|
b.AddASN1(TagDetailsNetworks, func(b *cryptobyte.Builder) {
|
||||||
|
for _, n := range d.networks {
|
||||||
|
sb, innerErr := n.MarshalBinary()
|
||||||
|
if innerErr != nil {
|
||||||
|
// MarshalBinary never returns an error
|
||||||
|
err = fmt.Errorf("unable to marshal network: %w", innerErr)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
b.AddASN1OctetString(sb)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add the unsafe networks if any exist
|
||||||
|
if len(d.unsafeNetworks) > 0 {
|
||||||
|
b.AddASN1(TagDetailsUnsafeNetworks, func(b *cryptobyte.Builder) {
|
||||||
|
for _, n := range d.unsafeNetworks {
|
||||||
|
sb, innerErr := n.MarshalBinary()
|
||||||
|
if innerErr != nil {
|
||||||
|
// MarshalBinary never returns an error
|
||||||
|
err = fmt.Errorf("unable to marshal unsafe network: %w", innerErr)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
b.AddASN1OctetString(sb)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add groups if any exist
|
||||||
|
if len(d.groups) > 0 {
|
||||||
|
b.AddASN1(TagDetailsGroups, func(b *cryptobyte.Builder) {
|
||||||
|
for _, group := range d.groups {
|
||||||
|
b.AddASN1(asn1.UTF8String, func(b *cryptobyte.Builder) {
|
||||||
|
b.AddBytes([]byte(group))
|
||||||
|
})
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add IsCA only if true
|
||||||
|
if d.isCA {
|
||||||
|
b.AddASN1(TagDetailsIsCA, func(b *cryptobyte.Builder) {
|
||||||
|
b.AddUint8(0xff)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add not before
|
||||||
|
b.AddASN1Int64WithTag(d.notBefore.Unix(), TagDetailsNotBefore)
|
||||||
|
|
||||||
|
// Add not after
|
||||||
|
b.AddASN1Int64WithTag(d.notAfter.Unix(), TagDetailsNotAfter)
|
||||||
|
|
||||||
|
// Add the issuer if present
|
||||||
|
if d.issuer != "" {
|
||||||
|
issuerBytes, innerErr := hex.DecodeString(d.issuer)
|
||||||
|
if innerErr != nil {
|
||||||
|
err = fmt.Errorf("failed to decode issuer: %w", innerErr)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
b.AddASN1(TagDetailsIssuer, func(b *cryptobyte.Builder) {
|
||||||
|
b.AddBytes(issuerBytes)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return b.Bytes()
|
||||||
|
}
|
||||||
|
|
||||||
|
func unmarshalCertificateV2(b []byte, publicKey []byte, curve Curve) (*certificateV2, error) {
|
||||||
|
l := len(b)
|
||||||
|
if l == 0 || l > MaxCertificateSize {
|
||||||
|
return nil, ErrBadFormat
|
||||||
|
}
|
||||||
|
|
||||||
|
input := cryptobyte.String(b)
|
||||||
|
// Open the envelope
|
||||||
|
if !input.ReadASN1(&input, asn1.SEQUENCE) || input.Empty() {
|
||||||
|
return nil, ErrBadFormat
|
||||||
|
}
|
||||||
|
|
||||||
|
// Grab the cert details, we need to preserve the tag and length
|
||||||
|
var rawDetails cryptobyte.String
|
||||||
|
if !input.ReadASN1Element(&rawDetails, TagCertDetails) || rawDetails.Empty() {
|
||||||
|
return nil, ErrBadFormat
|
||||||
|
}
|
||||||
|
|
||||||
|
//Maybe grab the curve
|
||||||
|
var rawCurve byte
|
||||||
|
if !readOptionalASN1Byte(&input, &rawCurve, TagCertCurve, byte(curve)) {
|
||||||
|
return nil, ErrBadFormat
|
||||||
|
}
|
||||||
|
curve = Curve(rawCurve)
|
||||||
|
|
||||||
|
// Maybe grab the public key
|
||||||
|
var rawPublicKey cryptobyte.String
|
||||||
|
if len(publicKey) > 0 {
|
||||||
|
rawPublicKey = publicKey
|
||||||
|
} else if !input.ReadOptionalASN1(&rawPublicKey, nil, TagCertPublicKey) {
|
||||||
|
return nil, ErrBadFormat
|
||||||
|
}
|
||||||
|
|
||||||
|
//TODO: Assert public key length
|
||||||
|
|
||||||
|
// Grab the signature
|
||||||
|
var rawSignature cryptobyte.String
|
||||||
|
if !input.ReadASN1(&rawSignature, TagCertSignature) || rawSignature.Empty() {
|
||||||
|
return nil, ErrBadFormat
|
||||||
|
}
|
||||||
|
|
||||||
|
// Finally unmarshal the details
|
||||||
|
details, err := unmarshalDetails(rawDetails)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return &certificateV2{
|
||||||
|
details: details,
|
||||||
|
rawDetails: rawDetails,
|
||||||
|
curve: curve,
|
||||||
|
publicKey: rawPublicKey,
|
||||||
|
signature: rawSignature,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func unmarshalDetails(b cryptobyte.String) (detailsV2, error) {
|
||||||
|
// Open the envelope
|
||||||
|
if !b.ReadASN1(&b, TagCertDetails) || b.Empty() {
|
||||||
|
return detailsV2{}, ErrBadFormat
|
||||||
|
}
|
||||||
|
|
||||||
|
// Read the name
|
||||||
|
var name cryptobyte.String
|
||||||
|
if !b.ReadASN1(&name, TagDetailsName) || name.Empty() || len(name) > MaxNameLength {
|
||||||
|
return detailsV2{}, ErrBadFormat
|
||||||
|
}
|
||||||
|
|
||||||
|
// Read the network addresses
|
||||||
|
var subString cryptobyte.String
|
||||||
|
var found bool
|
||||||
|
|
||||||
|
if !b.ReadOptionalASN1(&subString, &found, TagDetailsNetworks) {
|
||||||
|
return detailsV2{}, ErrBadFormat
|
||||||
|
}
|
||||||
|
|
||||||
|
var networks []netip.Prefix
|
||||||
|
var val cryptobyte.String
|
||||||
|
if found {
|
||||||
|
for !subString.Empty() {
|
||||||
|
if !subString.ReadASN1(&val, asn1.OCTET_STRING) || val.Empty() || len(val) > MaxNetworkLength {
|
||||||
|
return detailsV2{}, ErrBadFormat
|
||||||
|
}
|
||||||
|
|
||||||
|
var n netip.Prefix
|
||||||
|
if err := n.UnmarshalBinary(val); err != nil {
|
||||||
|
return detailsV2{}, ErrBadFormat
|
||||||
|
}
|
||||||
|
networks = append(networks, n)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Read out any unsafe networks
|
||||||
|
if !b.ReadOptionalASN1(&subString, &found, TagDetailsUnsafeNetworks) {
|
||||||
|
return detailsV2{}, ErrBadFormat
|
||||||
|
}
|
||||||
|
|
||||||
|
var unsafeNetworks []netip.Prefix
|
||||||
|
if found {
|
||||||
|
for !subString.Empty() {
|
||||||
|
if !subString.ReadASN1(&val, asn1.OCTET_STRING) || val.Empty() || len(val) > MaxNetworkLength {
|
||||||
|
return detailsV2{}, ErrBadFormat
|
||||||
|
}
|
||||||
|
|
||||||
|
var n netip.Prefix
|
||||||
|
if err := n.UnmarshalBinary(val); err != nil {
|
||||||
|
return detailsV2{}, ErrBadFormat
|
||||||
|
}
|
||||||
|
unsafeNetworks = append(unsafeNetworks, n)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Read out any groups
|
||||||
|
if !b.ReadOptionalASN1(&subString, &found, TagDetailsGroups) {
|
||||||
|
return detailsV2{}, ErrBadFormat
|
||||||
|
}
|
||||||
|
|
||||||
|
var groups []string
|
||||||
|
if found {
|
||||||
|
for !subString.Empty() {
|
||||||
|
if !subString.ReadASN1(&val, asn1.UTF8String) || val.Empty() {
|
||||||
|
return detailsV2{}, ErrBadFormat
|
||||||
|
}
|
||||||
|
groups = append(groups, string(val))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Read out IsCA
|
||||||
|
var isCa bool
|
||||||
|
if !readOptionalASN1Boolean(&b, &isCa, TagDetailsIsCA, false) {
|
||||||
|
return detailsV2{}, ErrBadFormat
|
||||||
|
}
|
||||||
|
|
||||||
|
// Read not before and not after
|
||||||
|
var notBefore int64
|
||||||
|
if !b.ReadASN1Int64WithTag(¬Before, TagDetailsNotBefore) {
|
||||||
|
return detailsV2{}, ErrBadFormat
|
||||||
|
}
|
||||||
|
|
||||||
|
var notAfter int64
|
||||||
|
if !b.ReadASN1Int64WithTag(¬After, TagDetailsNotAfter) {
|
||||||
|
return detailsV2{}, ErrBadFormat
|
||||||
|
}
|
||||||
|
|
||||||
|
// Read issuer
|
||||||
|
var issuer cryptobyte.String
|
||||||
|
if !b.ReadOptionalASN1(&issuer, nil, TagDetailsIssuer) {
|
||||||
|
return detailsV2{}, ErrBadFormat
|
||||||
|
}
|
||||||
|
|
||||||
|
slices.SortFunc(networks, comparePrefix)
|
||||||
|
slices.SortFunc(unsafeNetworks, comparePrefix)
|
||||||
|
|
||||||
|
return detailsV2{
|
||||||
|
name: string(name),
|
||||||
|
networks: networks,
|
||||||
|
unsafeNetworks: unsafeNetworks,
|
||||||
|
groups: groups,
|
||||||
|
isCA: isCa,
|
||||||
|
notBefore: time.Unix(notBefore, 0),
|
||||||
|
notAfter: time.Unix(notAfter, 0),
|
||||||
|
issuer: hex.EncodeToString(issuer),
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
@ -24,4 +24,7 @@ var (
|
|||||||
ErrInvalidPEMX25519PrivateKeyBanner = errors.New("bytes did not contain a proper X25519 private key banner")
|
ErrInvalidPEMX25519PrivateKeyBanner = errors.New("bytes did not contain a proper X25519 private key banner")
|
||||||
ErrInvalidPEMEd25519PublicKeyBanner = errors.New("bytes did not contain a proper Ed25519 public key banner")
|
ErrInvalidPEMEd25519PublicKeyBanner = errors.New("bytes did not contain a proper Ed25519 public key banner")
|
||||||
ErrInvalidPEMEd25519PrivateKeyBanner = errors.New("bytes did not contain a proper Ed25519 private key banner")
|
ErrInvalidPEMEd25519PrivateKeyBanner = errors.New("bytes did not contain a proper Ed25519 private key banner")
|
||||||
|
|
||||||
|
ErrNoPeerStaticKey = errors.New("no peer static key was present")
|
||||||
|
ErrNoPayload = errors.New("provided payload was empty")
|
||||||
)
|
)
|
||||||
|
|||||||
19
cert/pem.go
19
cert/pem.go
@ -30,19 +30,24 @@ func UnmarshalCertificateFromPEM(b []byte) (Certificate, []byte, error) {
|
|||||||
return nil, r, ErrInvalidPEMBlock
|
return nil, r, ErrInvalidPEMBlock
|
||||||
}
|
}
|
||||||
|
|
||||||
|
var c Certificate
|
||||||
|
var err error
|
||||||
|
|
||||||
switch p.Type {
|
switch p.Type {
|
||||||
case CertificateBanner:
|
case CertificateBanner:
|
||||||
c, err := unmarshalCertificateV1(p.Bytes, true)
|
c, err = unmarshalCertificateV1(p.Bytes, nil)
|
||||||
if err != nil {
|
|
||||||
return nil, nil, err
|
|
||||||
}
|
|
||||||
return c, r, nil
|
|
||||||
case CertificateV2Banner:
|
case CertificateV2Banner:
|
||||||
//TODO
|
c, err = unmarshalCertificateV2(p.Bytes, nil, Curve_CURVE25519)
|
||||||
panic("TODO")
|
|
||||||
default:
|
default:
|
||||||
return nil, r, ErrInvalidPEMCertificateBanner
|
return nil, r, ErrInvalidPEMCertificateBanner
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
return nil, r, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return c, r, nil
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func MarshalPublicKeyToPEM(curve Curve, b []byte) []byte {
|
func MarshalPublicKeyToPEM(curve Curve, b []byte) []byte {
|
||||||
|
|||||||
111
cert/sign.go
111
cert/sign.go
@ -1,11 +1,16 @@
|
|||||||
package cert
|
package cert
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"crypto/ecdsa"
|
||||||
|
"crypto/ed25519"
|
||||||
|
"crypto/elliptic"
|
||||||
|
"crypto/rand"
|
||||||
|
"crypto/sha256"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"math/big"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
|
"slices"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/slackhq/nebula/pkclient"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
// TBSCertificate represents a certificate intended to be signed.
|
// TBSCertificate represents a certificate intended to be signed.
|
||||||
@ -24,27 +29,62 @@ type TBSCertificate struct {
|
|||||||
issuer string
|
issuer string
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type beingSignedCertificate interface {
|
||||||
|
// fromTBSCertificate copies the values from the TBSCertificate to this versions internal representation
|
||||||
|
fromTBSCertificate(*TBSCertificate) error
|
||||||
|
|
||||||
|
// marshalForSigning returns the bytes that should be signed
|
||||||
|
marshalForSigning() ([]byte, error)
|
||||||
|
|
||||||
|
// setSignature sets the signature for the certificate that has just been signed
|
||||||
|
setSignature([]byte) error
|
||||||
|
}
|
||||||
|
|
||||||
|
type SignerLambda func(certBytes []byte) ([]byte, error)
|
||||||
|
|
||||||
// Sign will create a sealed certificate using details provided by the TBSCertificate as long as those
|
// Sign will create a sealed certificate using details provided by the TBSCertificate as long as those
|
||||||
// details do not violate constraints of the signing certificate.
|
// details do not violate constraints of the signing certificate.
|
||||||
// If the TBSCertificate is a CA then signer must be nil.
|
// If the TBSCertificate is a CA then signer must be nil.
|
||||||
func (t *TBSCertificate) Sign(signer Certificate, curve Curve, key []byte) (Certificate, error) {
|
func (t *TBSCertificate) Sign(signer Certificate, curve Curve, key []byte) (Certificate, error) {
|
||||||
return t.sign(signer, curve, key, nil)
|
switch t.Curve {
|
||||||
|
case Curve_CURVE25519:
|
||||||
|
pk := ed25519.PrivateKey(key)
|
||||||
|
sp := func(certBytes []byte) ([]byte, error) {
|
||||||
|
sig := ed25519.Sign(pk, certBytes)
|
||||||
|
return sig, nil
|
||||||
|
}
|
||||||
|
return t.SignWith(signer, curve, sp)
|
||||||
|
case Curve_P256:
|
||||||
|
pk := &ecdsa.PrivateKey{
|
||||||
|
PublicKey: ecdsa.PublicKey{
|
||||||
|
Curve: elliptic.P256(),
|
||||||
|
},
|
||||||
|
// ref: https://github.com/golang/go/blob/go1.19/src/crypto/x509/sec1.go#L95
|
||||||
|
D: new(big.Int).SetBytes(key),
|
||||||
|
}
|
||||||
|
// ref: https://github.com/golang/go/blob/go1.19/src/crypto/x509/sec1.go#L119
|
||||||
|
pk.X, pk.Y = pk.Curve.ScalarBaseMult(key)
|
||||||
|
sp := func(certBytes []byte) ([]byte, error) {
|
||||||
|
// We need to hash first for ECDSA
|
||||||
|
// - https://pkg.go.dev/crypto/ecdsa#SignASN1
|
||||||
|
hashed := sha256.Sum256(certBytes)
|
||||||
|
return ecdsa.SignASN1(rand.Reader, pk, hashed[:])
|
||||||
|
}
|
||||||
|
return t.SignWith(signer, curve, sp)
|
||||||
|
default:
|
||||||
|
return nil, fmt.Errorf("invalid curve: %s", t.Curve)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t *TBSCertificate) SignPkcs11(signer Certificate, curve Curve, client *pkclient.PKClient) (Certificate, error) {
|
// SignWith does the same thing as sign, but uses the function in `sp` to calculate the signature.
|
||||||
if curve != Curve_P256 {
|
// You should only use SignWith if you do not have direct access to your private key.
|
||||||
return nil, fmt.Errorf("only P256 is supported by PKCS#11")
|
func (t *TBSCertificate) SignWith(signer Certificate, curve Curve, sp SignerLambda) (Certificate, error) {
|
||||||
}
|
|
||||||
|
|
||||||
return t.sign(signer, curve, nil, client)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (t *TBSCertificate) sign(signer Certificate, curve Curve, key []byte, client *pkclient.PKClient) (Certificate, error) {
|
|
||||||
if curve != t.Curve {
|
if curve != t.Curve {
|
||||||
return nil, fmt.Errorf("curve in cert and private key supplied don't match")
|
return nil, fmt.Errorf("curve in cert and private key supplied don't match")
|
||||||
}
|
}
|
||||||
|
|
||||||
//TODO: make sure we have all minimum properties to sign, like a public key
|
//TODO: make sure we have all minimum properties to sign, like a public key
|
||||||
|
//TODO: we need to verify networks and unsafe networks (no duplicates, max of 1 of each version for v2 certs
|
||||||
|
|
||||||
if signer != nil {
|
if signer != nil {
|
||||||
if t.IsCA {
|
if t.IsCA {
|
||||||
@ -67,10 +107,55 @@ func (t *TBSCertificate) sign(signer Certificate, curve Curve, key []byte, clien
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
slices.SortFunc(t.Networks, comparePrefix)
|
||||||
|
slices.SortFunc(t.UnsafeNetworks, comparePrefix)
|
||||||
|
|
||||||
|
var c beingSignedCertificate
|
||||||
switch t.Version {
|
switch t.Version {
|
||||||
case Version1:
|
case Version1:
|
||||||
return signV1(t, curve, key, client)
|
c = &certificateV1{}
|
||||||
|
err := c.fromTBSCertificate(t)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
case Version2:
|
||||||
|
c = &certificateV2{}
|
||||||
|
err := c.fromTBSCertificate(t)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
default:
|
default:
|
||||||
return nil, fmt.Errorf("unknown cert version %d", t.Version)
|
return nil, fmt.Errorf("unknown cert version %d", t.Version)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
certBytes, err := c.marshalForSigning()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
sig, err := sp(certBytes)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
//TODO: check if we have sig bytes?
|
||||||
|
err = c.setSignature(sig)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
sc, ok := c.(Certificate)
|
||||||
|
if !ok {
|
||||||
|
return nil, fmt.Errorf("invalid certificate")
|
||||||
|
}
|
||||||
|
|
||||||
|
return sc, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func comparePrefix(a, b netip.Prefix) int {
|
||||||
|
addr := a.Addr().Compare(b.Addr())
|
||||||
|
if addr == 0 {
|
||||||
|
return a.Bits() - b.Bits()
|
||||||
|
}
|
||||||
|
return addr
|
||||||
}
|
}
|
||||||
|
|||||||
@ -27,34 +27,43 @@ type caFlags struct {
|
|||||||
outCertPath *string
|
outCertPath *string
|
||||||
outQRPath *string
|
outQRPath *string
|
||||||
groups *string
|
groups *string
|
||||||
ips *string
|
networks *string
|
||||||
subnets *string
|
unsafeNetworks *string
|
||||||
argonMemory *uint
|
argonMemory *uint
|
||||||
argonIterations *uint
|
argonIterations *uint
|
||||||
argonParallelism *uint
|
argonParallelism *uint
|
||||||
encryption *bool
|
encryption *bool
|
||||||
|
version *uint
|
||||||
|
|
||||||
curve *string
|
curve *string
|
||||||
p11url *string
|
p11url *string
|
||||||
|
|
||||||
|
// Deprecated options
|
||||||
|
ips *string
|
||||||
|
subnets *string
|
||||||
}
|
}
|
||||||
|
|
||||||
func newCaFlags() *caFlags {
|
func newCaFlags() *caFlags {
|
||||||
cf := caFlags{set: flag.NewFlagSet("ca", flag.ContinueOnError)}
|
cf := caFlags{set: flag.NewFlagSet("ca", flag.ContinueOnError)}
|
||||||
cf.set.Usage = func() {}
|
cf.set.Usage = func() {}
|
||||||
cf.name = cf.set.String("name", "", "Required: name of the certificate authority")
|
cf.name = cf.set.String("name", "", "Required: name of the certificate authority")
|
||||||
|
cf.version = cf.set.Uint("version", uint(cert.Version2), "Optional: version of the certificate format to use")
|
||||||
cf.duration = cf.set.Duration("duration", time.Duration(time.Hour*8760), "Optional: amount of time the certificate should be valid for. Valid time units are seconds: \"s\", minutes: \"m\", hours: \"h\"")
|
cf.duration = cf.set.Duration("duration", time.Duration(time.Hour*8760), "Optional: amount of time the certificate should be valid for. Valid time units are seconds: \"s\", minutes: \"m\", hours: \"h\"")
|
||||||
cf.outKeyPath = cf.set.String("out-key", "ca.key", "Optional: path to write the private key to")
|
cf.outKeyPath = cf.set.String("out-key", "ca.key", "Optional: path to write the private key to")
|
||||||
cf.outCertPath = cf.set.String("out-crt", "ca.crt", "Optional: path to write the certificate to")
|
cf.outCertPath = cf.set.String("out-crt", "ca.crt", "Optional: path to write the certificate to")
|
||||||
cf.outQRPath = cf.set.String("out-qr", "", "Optional: output a qr code image (png) of the certificate")
|
cf.outQRPath = cf.set.String("out-qr", "", "Optional: output a qr code image (png) of the certificate")
|
||||||
cf.groups = cf.set.String("groups", "", "Optional: comma separated list of groups. This will limit which groups subordinate certs can use")
|
cf.groups = cf.set.String("groups", "", "Optional: comma separated list of groups. This will limit which groups subordinate certs can use")
|
||||||
cf.ips = cf.set.String("ips", "", "Optional: comma separated list of ipv4 address and network in CIDR notation. This will limit which ipv4 addresses and networks subordinate certs can use for ip addresses")
|
cf.networks = cf.set.String("networks", "", "Optional: comma separated list of ip address and network in CIDR notation. This will limit which ip addresses and networks subordinate certs can use in networks")
|
||||||
cf.subnets = cf.set.String("subnets", "", "Optional: comma separated list of ipv4 address and network in CIDR notation. This will limit which ipv4 addresses and networks subordinate certs can use in subnets")
|
cf.unsafeNetworks = cf.set.String("unsafe-networks", "", "Optional: comma separated list of ip address and network in CIDR notation. This will limit which ip addresses and networks subordinate certs can use in unsafe networks")
|
||||||
cf.argonMemory = cf.set.Uint("argon-memory", 2*1024*1024, "Optional: Argon2 memory parameter (in KiB) used for encrypted private key passphrase")
|
cf.argonMemory = cf.set.Uint("argon-memory", 2*1024*1024, "Optional: Argon2 memory parameter (in KiB) used for encrypted private key passphrase")
|
||||||
cf.argonParallelism = cf.set.Uint("argon-parallelism", 4, "Optional: Argon2 parallelism parameter used for encrypted private key passphrase")
|
cf.argonParallelism = cf.set.Uint("argon-parallelism", 4, "Optional: Argon2 parallelism parameter used for encrypted private key passphrase")
|
||||||
cf.argonIterations = cf.set.Uint("argon-iterations", 1, "Optional: Argon2 iterations parameter used for encrypted private key passphrase")
|
cf.argonIterations = cf.set.Uint("argon-iterations", 1, "Optional: Argon2 iterations parameter used for encrypted private key passphrase")
|
||||||
cf.encryption = cf.set.Bool("encrypt", false, "Optional: prompt for passphrase and write out-key in an encrypted format")
|
cf.encryption = cf.set.Bool("encrypt", false, "Optional: prompt for passphrase and write out-key in an encrypted format")
|
||||||
cf.curve = cf.set.String("curve", "25519", "EdDSA/ECDSA Curve (25519, P256)")
|
cf.curve = cf.set.String("curve", "25519", "EdDSA/ECDSA Curve (25519, P256)")
|
||||||
cf.p11url = p11Flag(cf.set)
|
cf.p11url = p11Flag(cf.set)
|
||||||
|
|
||||||
|
cf.ips = cf.set.String("ips", "", "Deprecated, see -networks")
|
||||||
|
cf.subnets = cf.set.String("subnets", "", "Deprecated, see -unsafe-networks")
|
||||||
return &cf
|
return &cf
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -113,36 +122,51 @@ func ca(args []string, out io.Writer, errOut io.Writer, pr PasswordReader) error
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
var ips []netip.Prefix
|
version := cert.Version(*cf.version)
|
||||||
if *cf.ips != "" {
|
if version != cert.Version1 && version != cert.Version2 {
|
||||||
for _, rs := range strings.Split(*cf.ips, ",") {
|
return newHelpErrorf("-version must be either %v or %v", cert.Version1, cert.Version2)
|
||||||
|
}
|
||||||
|
|
||||||
|
var networks []netip.Prefix
|
||||||
|
if *cf.networks == "" && *cf.ips != "" {
|
||||||
|
// Pull up deprecated -ips flag if needed
|
||||||
|
*cf.networks = *cf.ips
|
||||||
|
}
|
||||||
|
|
||||||
|
if *cf.networks != "" {
|
||||||
|
for _, rs := range strings.Split(*cf.networks, ",") {
|
||||||
rs := strings.Trim(rs, " ")
|
rs := strings.Trim(rs, " ")
|
||||||
if rs != "" {
|
if rs != "" {
|
||||||
n, err := netip.ParsePrefix(rs)
|
n, err := netip.ParsePrefix(rs)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return newHelpErrorf("invalid ip definition: %s", err)
|
return newHelpErrorf("invalid -networks definition: %s", rs)
|
||||||
}
|
}
|
||||||
if !n.Addr().Is4() {
|
if version == cert.Version1 && !n.Addr().Is4() {
|
||||||
return newHelpErrorf("invalid ip definition: can only be ipv4, have %s", rs)
|
return newHelpErrorf("invalid -networks definition: v1 certificates can only be ipv4, have %s", rs)
|
||||||
}
|
}
|
||||||
ips = append(ips, n)
|
networks = append(networks, n)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
var subnets []netip.Prefix
|
var unsafeNetworks []netip.Prefix
|
||||||
if *cf.subnets != "" {
|
if *cf.unsafeNetworks == "" && *cf.subnets != "" {
|
||||||
for _, rs := range strings.Split(*cf.subnets, ",") {
|
// Pull up deprecated -subnets flag if needed
|
||||||
|
*cf.unsafeNetworks = *cf.subnets
|
||||||
|
}
|
||||||
|
|
||||||
|
if *cf.unsafeNetworks != "" {
|
||||||
|
for _, rs := range strings.Split(*cf.unsafeNetworks, ",") {
|
||||||
rs := strings.Trim(rs, " ")
|
rs := strings.Trim(rs, " ")
|
||||||
if rs != "" {
|
if rs != "" {
|
||||||
n, err := netip.ParsePrefix(rs)
|
n, err := netip.ParsePrefix(rs)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return newHelpErrorf("invalid subnet definition: %s", err)
|
return newHelpErrorf("invalid -unsafe-networks definition: %s", rs)
|
||||||
}
|
}
|
||||||
if !n.Addr().Is4() {
|
if version == cert.Version1 && !n.Addr().Is4() {
|
||||||
return newHelpErrorf("invalid subnet definition: can only be ipv4, have %s", rs)
|
return newHelpErrorf("invalid -unsafe-networks definition: v1 certificates can only be ipv4, have %s", rs)
|
||||||
}
|
}
|
||||||
subnets = append(subnets, n)
|
unsafeNetworks = append(unsafeNetworks, n)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -222,11 +246,11 @@ func ca(args []string, out io.Writer, errOut io.Writer, pr PasswordReader) error
|
|||||||
}
|
}
|
||||||
|
|
||||||
t := &cert.TBSCertificate{
|
t := &cert.TBSCertificate{
|
||||||
Version: cert.Version1,
|
Version: version,
|
||||||
Name: *cf.name,
|
Name: *cf.name,
|
||||||
Groups: groups,
|
Groups: groups,
|
||||||
Networks: ips,
|
Networks: networks,
|
||||||
UnsafeNetworks: subnets,
|
UnsafeNetworks: unsafeNetworks,
|
||||||
NotBefore: time.Now(),
|
NotBefore: time.Now(),
|
||||||
NotAfter: time.Now().Add(*cf.duration),
|
NotAfter: time.Now().Add(*cf.duration),
|
||||||
PublicKey: pub,
|
PublicKey: pub,
|
||||||
@ -248,7 +272,7 @@ func ca(args []string, out io.Writer, errOut io.Writer, pr PasswordReader) error
|
|||||||
var b []byte
|
var b []byte
|
||||||
|
|
||||||
if isP11 {
|
if isP11 {
|
||||||
c, err = t.SignPkcs11(nil, curve, p11Client)
|
c, err = t.SignWith(nil, curve, p11Client.SignASN1)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("error while signing with PKCS#11: %w", err)
|
return fmt.Errorf("error while signing with PKCS#11: %w", err)
|
||||||
}
|
}
|
||||||
|
|||||||
@ -43,9 +43,11 @@ func Test_caHelp(t *testing.T) {
|
|||||||
" -groups string\n"+
|
" -groups string\n"+
|
||||||
" \tOptional: comma separated list of groups. This will limit which groups subordinate certs can use\n"+
|
" \tOptional: comma separated list of groups. This will limit which groups subordinate certs can use\n"+
|
||||||
" -ips string\n"+
|
" -ips string\n"+
|
||||||
" \tOptional: comma separated list of ipv4 address and network in CIDR notation. This will limit which ipv4 addresses and networks subordinate certs can use for ip addresses\n"+
|
" Deprecated, see -networks\n"+
|
||||||
" -name string\n"+
|
" -name string\n"+
|
||||||
" \tRequired: name of the certificate authority\n"+
|
" \tRequired: name of the certificate authority\n"+
|
||||||
|
" -networks string\n"+
|
||||||
|
" \tOptional: comma separated list of ip address and network in CIDR notation. This will limit which ip addresses and networks subordinate certs can use in networks\n"+
|
||||||
" -out-crt string\n"+
|
" -out-crt string\n"+
|
||||||
" \tOptional: path to write the certificate to (default \"ca.crt\")\n"+
|
" \tOptional: path to write the certificate to (default \"ca.crt\")\n"+
|
||||||
" -out-key string\n"+
|
" -out-key string\n"+
|
||||||
@ -54,7 +56,11 @@ func Test_caHelp(t *testing.T) {
|
|||||||
" \tOptional: output a qr code image (png) of the certificate\n"+
|
" \tOptional: output a qr code image (png) of the certificate\n"+
|
||||||
optionalPkcs11String(" -pkcs11 string\n \tOptional: PKCS#11 URI to an existing private key\n")+
|
optionalPkcs11String(" -pkcs11 string\n \tOptional: PKCS#11 URI to an existing private key\n")+
|
||||||
" -subnets string\n"+
|
" -subnets string\n"+
|
||||||
" \tOptional: comma separated list of ipv4 address and network in CIDR notation. This will limit which ipv4 addresses and networks subordinate certs can use in subnets\n",
|
" \tDeprecated, see -unsafe-networks\n"+
|
||||||
|
" -unsafe-networks string\n"+
|
||||||
|
" \tOptional: comma separated list of ip address and network in CIDR notation. This will limit which ip addresses and networks subordinate certs can use in unsafe networks\n"+
|
||||||
|
" -version uint\n"+
|
||||||
|
" \tOptional: version of the certificate format to use (default 2)\n",
|
||||||
ob.String(),
|
ob.String(),
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
@ -83,25 +89,25 @@ func Test_ca(t *testing.T) {
|
|||||||
|
|
||||||
// required args
|
// required args
|
||||||
assertHelpError(t, ca(
|
assertHelpError(t, ca(
|
||||||
[]string{"-out-key", "nope", "-out-crt", "nope", "duration", "100m"}, ob, eb, nopw,
|
[]string{"-version", "1", "-out-key", "nope", "-out-crt", "nope", "duration", "100m"}, ob, eb, nopw,
|
||||||
), "-name is required")
|
), "-name is required")
|
||||||
assert.Equal(t, "", ob.String())
|
assert.Equal(t, "", ob.String())
|
||||||
assert.Equal(t, "", eb.String())
|
assert.Equal(t, "", eb.String())
|
||||||
|
|
||||||
// ipv4 only ips
|
// ipv4 only ips
|
||||||
assertHelpError(t, ca([]string{"-name", "ipv6", "-ips", "100::100/100"}, ob, eb, nopw), "invalid ip definition: can only be ipv4, have 100::100/100")
|
assertHelpError(t, ca([]string{"-version", "1", "-name", "ipv6", "-ips", "100::100/100"}, ob, eb, nopw), "invalid -networks definition: v1 certificates can only be ipv4, have 100::100/100")
|
||||||
assert.Equal(t, "", ob.String())
|
assert.Equal(t, "", ob.String())
|
||||||
assert.Equal(t, "", eb.String())
|
assert.Equal(t, "", eb.String())
|
||||||
|
|
||||||
// ipv4 only subnets
|
// ipv4 only subnets
|
||||||
assertHelpError(t, ca([]string{"-name", "ipv6", "-subnets", "100::100/100"}, ob, eb, nopw), "invalid subnet definition: can only be ipv4, have 100::100/100")
|
assertHelpError(t, ca([]string{"-version", "1", "-name", "ipv6", "-subnets", "100::100/100"}, ob, eb, nopw), "invalid -unsafe-networks definition: v1 certificates can only be ipv4, have 100::100/100")
|
||||||
assert.Equal(t, "", ob.String())
|
assert.Equal(t, "", ob.String())
|
||||||
assert.Equal(t, "", eb.String())
|
assert.Equal(t, "", eb.String())
|
||||||
|
|
||||||
// failed key write
|
// failed key write
|
||||||
ob.Reset()
|
ob.Reset()
|
||||||
eb.Reset()
|
eb.Reset()
|
||||||
args := []string{"-name", "test", "-duration", "100m", "-out-crt", "/do/not/write/pleasecrt", "-out-key", "/do/not/write/pleasekey"}
|
args := []string{"-version", "1", "-name", "test", "-duration", "100m", "-out-crt", "/do/not/write/pleasecrt", "-out-key", "/do/not/write/pleasekey"}
|
||||||
assert.EqualError(t, ca(args, ob, eb, nopw), "error while writing out-key: open /do/not/write/pleasekey: "+NoSuchDirError)
|
assert.EqualError(t, ca(args, ob, eb, nopw), "error while writing out-key: open /do/not/write/pleasekey: "+NoSuchDirError)
|
||||||
assert.Equal(t, "", ob.String())
|
assert.Equal(t, "", ob.String())
|
||||||
assert.Equal(t, "", eb.String())
|
assert.Equal(t, "", eb.String())
|
||||||
@ -114,7 +120,7 @@ func Test_ca(t *testing.T) {
|
|||||||
// failed cert write
|
// failed cert write
|
||||||
ob.Reset()
|
ob.Reset()
|
||||||
eb.Reset()
|
eb.Reset()
|
||||||
args = []string{"-name", "test", "-duration", "100m", "-out-crt", "/do/not/write/pleasecrt", "-out-key", keyF.Name()}
|
args = []string{"-version", "1", "-name", "test", "-duration", "100m", "-out-crt", "/do/not/write/pleasecrt", "-out-key", keyF.Name()}
|
||||||
assert.EqualError(t, ca(args, ob, eb, nopw), "error while writing out-crt: open /do/not/write/pleasecrt: "+NoSuchDirError)
|
assert.EqualError(t, ca(args, ob, eb, nopw), "error while writing out-crt: open /do/not/write/pleasecrt: "+NoSuchDirError)
|
||||||
assert.Equal(t, "", ob.String())
|
assert.Equal(t, "", ob.String())
|
||||||
assert.Equal(t, "", eb.String())
|
assert.Equal(t, "", eb.String())
|
||||||
@ -128,7 +134,7 @@ func Test_ca(t *testing.T) {
|
|||||||
// test proper cert with removed empty groups and subnets
|
// test proper cert with removed empty groups and subnets
|
||||||
ob.Reset()
|
ob.Reset()
|
||||||
eb.Reset()
|
eb.Reset()
|
||||||
args = []string{"-name", "test", "-duration", "100m", "-groups", "1,, 2 , ,,,3,4,5", "-out-crt", crtF.Name(), "-out-key", keyF.Name()}
|
args = []string{"-version", "1", "-name", "test", "-duration", "100m", "-groups", "1,, 2 , ,,,3,4,5", "-out-crt", crtF.Name(), "-out-key", keyF.Name()}
|
||||||
assert.Nil(t, ca(args, ob, eb, nopw))
|
assert.Nil(t, ca(args, ob, eb, nopw))
|
||||||
assert.Equal(t, "", ob.String())
|
assert.Equal(t, "", ob.String())
|
||||||
assert.Equal(t, "", eb.String())
|
assert.Equal(t, "", eb.String())
|
||||||
@ -161,7 +167,7 @@ func Test_ca(t *testing.T) {
|
|||||||
os.Remove(crtF.Name())
|
os.Remove(crtF.Name())
|
||||||
ob.Reset()
|
ob.Reset()
|
||||||
eb.Reset()
|
eb.Reset()
|
||||||
args = []string{"-encrypt", "-name", "test", "-duration", "100m", "-groups", "1,2,3,4,5", "-out-crt", crtF.Name(), "-out-key", keyF.Name()}
|
args = []string{"-version", "1", "-encrypt", "-name", "test", "-duration", "100m", "-groups", "1,2,3,4,5", "-out-crt", crtF.Name(), "-out-key", keyF.Name()}
|
||||||
assert.Nil(t, ca(args, ob, eb, testpw))
|
assert.Nil(t, ca(args, ob, eb, testpw))
|
||||||
assert.Equal(t, pwPromptOb, ob.String())
|
assert.Equal(t, pwPromptOb, ob.String())
|
||||||
assert.Equal(t, "", eb.String())
|
assert.Equal(t, "", eb.String())
|
||||||
@ -189,7 +195,7 @@ func Test_ca(t *testing.T) {
|
|||||||
os.Remove(crtF.Name())
|
os.Remove(crtF.Name())
|
||||||
ob.Reset()
|
ob.Reset()
|
||||||
eb.Reset()
|
eb.Reset()
|
||||||
args = []string{"-encrypt", "-name", "test", "-duration", "100m", "-groups", "1,2,3,4,5", "-out-crt", crtF.Name(), "-out-key", keyF.Name()}
|
args = []string{"-version", "1", "-encrypt", "-name", "test", "-duration", "100m", "-groups", "1,2,3,4,5", "-out-crt", crtF.Name(), "-out-key", keyF.Name()}
|
||||||
assert.Error(t, ca(args, ob, eb, errpw))
|
assert.Error(t, ca(args, ob, eb, errpw))
|
||||||
assert.Equal(t, pwPromptOb, ob.String())
|
assert.Equal(t, pwPromptOb, ob.String())
|
||||||
assert.Equal(t, "", eb.String())
|
assert.Equal(t, "", eb.String())
|
||||||
@ -199,7 +205,7 @@ func Test_ca(t *testing.T) {
|
|||||||
os.Remove(crtF.Name())
|
os.Remove(crtF.Name())
|
||||||
ob.Reset()
|
ob.Reset()
|
||||||
eb.Reset()
|
eb.Reset()
|
||||||
args = []string{"-encrypt", "-name", "test", "-duration", "100m", "-groups", "1,2,3,4,5", "-out-crt", crtF.Name(), "-out-key", keyF.Name()}
|
args = []string{"-version", "1", "-encrypt", "-name", "test", "-duration", "100m", "-groups", "1,2,3,4,5", "-out-crt", crtF.Name(), "-out-key", keyF.Name()}
|
||||||
assert.EqualError(t, ca(args, ob, eb, nopw), "no passphrase specified, remove -encrypt flag to write out-key in plaintext")
|
assert.EqualError(t, ca(args, ob, eb, nopw), "no passphrase specified, remove -encrypt flag to write out-key in plaintext")
|
||||||
assert.Equal(t, strings.Repeat(pwPromptOb, 5), ob.String()) // prompts 5 times before giving up
|
assert.Equal(t, strings.Repeat(pwPromptOb, 5), ob.String()) // prompts 5 times before giving up
|
||||||
assert.Equal(t, "", eb.String())
|
assert.Equal(t, "", eb.String())
|
||||||
@ -209,13 +215,13 @@ func Test_ca(t *testing.T) {
|
|||||||
os.Remove(crtF.Name())
|
os.Remove(crtF.Name())
|
||||||
ob.Reset()
|
ob.Reset()
|
||||||
eb.Reset()
|
eb.Reset()
|
||||||
args = []string{"-name", "test", "-duration", "100m", "-groups", "1,, 2 , ,,,3,4,5", "-out-crt", crtF.Name(), "-out-key", keyF.Name()}
|
args = []string{"-version", "1", "-name", "test", "-duration", "100m", "-groups", "1,, 2 , ,,,3,4,5", "-out-crt", crtF.Name(), "-out-key", keyF.Name()}
|
||||||
assert.Nil(t, ca(args, ob, eb, nopw))
|
assert.Nil(t, ca(args, ob, eb, nopw))
|
||||||
|
|
||||||
// test that we won't overwrite existing certificate file
|
// test that we won't overwrite existing certificate file
|
||||||
ob.Reset()
|
ob.Reset()
|
||||||
eb.Reset()
|
eb.Reset()
|
||||||
args = []string{"-name", "test", "-duration", "100m", "-groups", "1,, 2 , ,,,3,4,5", "-out-crt", crtF.Name(), "-out-key", keyF.Name()}
|
args = []string{"-version", "1", "-name", "test", "-duration", "100m", "-groups", "1,, 2 , ,,,3,4,5", "-out-crt", crtF.Name(), "-out-key", keyF.Name()}
|
||||||
assert.EqualError(t, ca(args, ob, eb, nopw), "refusing to overwrite existing CA key: "+keyF.Name())
|
assert.EqualError(t, ca(args, ob, eb, nopw), "refusing to overwrite existing CA key: "+keyF.Name())
|
||||||
assert.Equal(t, "", ob.String())
|
assert.Equal(t, "", ob.String())
|
||||||
assert.Equal(t, "", eb.String())
|
assert.Equal(t, "", eb.String())
|
||||||
@ -224,7 +230,7 @@ func Test_ca(t *testing.T) {
|
|||||||
os.Remove(keyF.Name())
|
os.Remove(keyF.Name())
|
||||||
ob.Reset()
|
ob.Reset()
|
||||||
eb.Reset()
|
eb.Reset()
|
||||||
args = []string{"-name", "test", "-duration", "100m", "-groups", "1,, 2 , ,,,3,4,5", "-out-crt", crtF.Name(), "-out-key", keyF.Name()}
|
args = []string{"-version", "1", "-name", "test", "-duration", "100m", "-groups", "1,, 2 , ,,,3,4,5", "-out-crt", crtF.Name(), "-out-key", keyF.Name()}
|
||||||
assert.EqualError(t, ca(args, ob, eb, nopw), "refusing to overwrite existing CA cert: "+crtF.Name())
|
assert.EqualError(t, ca(args, ob, eb, nopw), "refusing to overwrite existing CA cert: "+crtF.Name())
|
||||||
assert.Equal(t, "", ob.String())
|
assert.Equal(t, "", ob.String())
|
||||||
assert.Equal(t, "", eb.String())
|
assert.Equal(t, "", eb.String())
|
||||||
|
|||||||
@ -49,6 +49,8 @@ func printCert(args []string, out io.Writer, errOut io.Writer) error {
|
|||||||
var qrBytes []byte
|
var qrBytes []byte
|
||||||
part := 0
|
part := 0
|
||||||
|
|
||||||
|
var jsonCerts []cert.Certificate
|
||||||
|
|
||||||
for {
|
for {
|
||||||
c, rawCert, err = cert.UnmarshalCertificateFromPEM(rawCert)
|
c, rawCert, err = cert.UnmarshalCertificateFromPEM(rawCert)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -56,13 +58,10 @@ func printCert(args []string, out io.Writer, errOut io.Writer) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if *pf.json {
|
if *pf.json {
|
||||||
b, _ := json.Marshal(c)
|
jsonCerts = append(jsonCerts, c)
|
||||||
out.Write(b)
|
|
||||||
out.Write([]byte("\n"))
|
|
||||||
|
|
||||||
} else {
|
} else {
|
||||||
out.Write([]byte(c.String()))
|
_, _ = out.Write([]byte(c.String()))
|
||||||
out.Write([]byte("\n"))
|
_, _ = out.Write([]byte("\n"))
|
||||||
}
|
}
|
||||||
|
|
||||||
if *pf.outQRPath != "" {
|
if *pf.outQRPath != "" {
|
||||||
@ -80,6 +79,12 @@ func printCert(args []string, out io.Writer, errOut io.Writer) error {
|
|||||||
part++
|
part++
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if *pf.json {
|
||||||
|
b, _ := json.Marshal(jsonCerts)
|
||||||
|
_, _ = out.Write(b)
|
||||||
|
_, _ = out.Write([]byte("\n"))
|
||||||
|
}
|
||||||
|
|
||||||
if *pf.outQRPath != "" {
|
if *pf.outQRPath != "" {
|
||||||
b, err := qrcode.Encode(string(qrBytes), qrcode.Medium, -5)
|
b, err := qrcode.Encode(string(qrBytes), qrcode.Medium, -5)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|||||||
@ -87,7 +87,65 @@ func Test_printCert(t *testing.T) {
|
|||||||
assert.Nil(t, err)
|
assert.Nil(t, err)
|
||||||
assert.Equal(
|
assert.Equal(
|
||||||
t,
|
t,
|
||||||
"NebulaCertificate {\n\tDetails {\n\t\tName: test\n\t\tIps: []\n\t\tSubnets: []\n\t\tGroups: [\n\t\t\t\"hi\"\n\t\t]\n\t\tNot before: 0001-01-01 00:00:00 +0000 UTC\n\t\tNot After: 0001-01-01 00:00:00 +0000 UTC\n\t\tIs CA: false\n\t\tIssuer: "+c.Issuer()+"\n\t\tPublic key: "+pk+"\n\t\tCurve: CURVE25519\n\t}\n\tFingerprint: "+fp+"\n\tSignature: "+sig+"\n}\nNebulaCertificate {\n\tDetails {\n\t\tName: test\n\t\tIps: []\n\t\tSubnets: []\n\t\tGroups: [\n\t\t\t\"hi\"\n\t\t]\n\t\tNot before: 0001-01-01 00:00:00 +0000 UTC\n\t\tNot After: 0001-01-01 00:00:00 +0000 UTC\n\t\tIs CA: false\n\t\tIssuer: "+c.Issuer()+"\n\t\tPublic key: "+pk+"\n\t\tCurve: CURVE25519\n\t}\n\tFingerprint: "+fp+"\n\tSignature: "+sig+"\n}\nNebulaCertificate {\n\tDetails {\n\t\tName: test\n\t\tIps: []\n\t\tSubnets: []\n\t\tGroups: [\n\t\t\t\"hi\"\n\t\t]\n\t\tNot before: 0001-01-01 00:00:00 +0000 UTC\n\t\tNot After: 0001-01-01 00:00:00 +0000 UTC\n\t\tIs CA: false\n\t\tIssuer: "+c.Issuer()+"\n\t\tPublic key: "+pk+"\n\t\tCurve: CURVE25519\n\t}\n\tFingerprint: "+fp+"\n\tSignature: "+sig+"\n}\n",
|
//"NebulaCertificate {\n\tDetails {\n\t\tName: test\n\t\tIps: []\n\t\tSubnets: []\n\t\tGroups: [\n\t\t\t\"hi\"\n\t\t]\n\t\tNot before: 0001-01-01 00:00:00 +0000 UTC\n\t\tNot After: 0001-01-01 00:00:00 +0000 UTC\n\t\tIs CA: false\n\t\tIssuer: "+c.Issuer()+"\n\t\tPublic key: "+pk+"\n\t\tCurve: CURVE25519\n\t}\n\tFingerprint: "+fp+"\n\tSignature: "+sig+"\n}\nNebulaCertificate {\n\tDetails {\n\t\tName: test\n\t\tIps: []\n\t\tSubnets: []\n\t\tGroups: [\n\t\t\t\"hi\"\n\t\t]\n\t\tNot before: 0001-01-01 00:00:00 +0000 UTC\n\t\tNot After: 0001-01-01 00:00:00 +0000 UTC\n\t\tIs CA: false\n\t\tIssuer: "+c.Issuer()+"\n\t\tPublic key: "+pk+"\n\t\tCurve: CURVE25519\n\t}\n\tFingerprint: "+fp+"\n\tSignature: "+sig+"\n}\nNebulaCertificate {\n\tDetails {\n\t\tName: test\n\t\tIps: []\n\t\tSubnets: []\n\t\tGroups: [\n\t\t\t\"hi\"\n\t\t]\n\t\tNot before: 0001-01-01 00:00:00 +0000 UTC\n\t\tNot After: 0001-01-01 00:00:00 +0000 UTC\n\t\tIs CA: false\n\t\tIssuer: "+c.Issuer()+"\n\t\tPublic key: "+pk+"\n\t\tCurve: CURVE25519\n\t}\n\tFingerprint: "+fp+"\n\tSignature: "+sig+"\n}\n",
|
||||||
|
`{
|
||||||
|
"details": {
|
||||||
|
"curve": "CURVE25519",
|
||||||
|
"groups": [
|
||||||
|
"hi"
|
||||||
|
],
|
||||||
|
"isCa": false,
|
||||||
|
"issuer": "`+c.Issuer()+`",
|
||||||
|
"name": "test",
|
||||||
|
"networks": [],
|
||||||
|
"notAfter": "0001-01-01T00:00:00Z",
|
||||||
|
"notBefore": "0001-01-01T00:00:00Z",
|
||||||
|
"publicKey": "`+pk+`",
|
||||||
|
"unsafeNetworks": []
|
||||||
|
},
|
||||||
|
"fingerprint": "`+fp+`",
|
||||||
|
"signature": "`+sig+`",
|
||||||
|
"version": 1
|
||||||
|
}
|
||||||
|
{
|
||||||
|
"details": {
|
||||||
|
"curve": "CURVE25519",
|
||||||
|
"groups": [
|
||||||
|
"hi"
|
||||||
|
],
|
||||||
|
"isCa": false,
|
||||||
|
"issuer": "`+c.Issuer()+`",
|
||||||
|
"name": "test",
|
||||||
|
"networks": [],
|
||||||
|
"notAfter": "0001-01-01T00:00:00Z",
|
||||||
|
"notBefore": "0001-01-01T00:00:00Z",
|
||||||
|
"publicKey": "`+pk+`",
|
||||||
|
"unsafeNetworks": []
|
||||||
|
},
|
||||||
|
"fingerprint": "`+fp+`",
|
||||||
|
"signature": "`+sig+`",
|
||||||
|
"version": 1
|
||||||
|
}
|
||||||
|
{
|
||||||
|
"details": {
|
||||||
|
"curve": "CURVE25519",
|
||||||
|
"groups": [
|
||||||
|
"hi"
|
||||||
|
],
|
||||||
|
"isCa": false,
|
||||||
|
"issuer": "`+c.Issuer()+`",
|
||||||
|
"name": "test",
|
||||||
|
"networks": [],
|
||||||
|
"notAfter": "0001-01-01T00:00:00Z",
|
||||||
|
"notBefore": "0001-01-01T00:00:00Z",
|
||||||
|
"publicKey": "`+pk+`",
|
||||||
|
"unsafeNetworks": []
|
||||||
|
},
|
||||||
|
"fingerprint": "`+fp+`",
|
||||||
|
"signature": "`+sig+`",
|
||||||
|
"version": 1
|
||||||
|
}
|
||||||
|
`,
|
||||||
ob.String(),
|
ob.String(),
|
||||||
)
|
)
|
||||||
assert.Equal(t, "", eb.String())
|
assert.Equal(t, "", eb.String())
|
||||||
@ -108,7 +166,8 @@ func Test_printCert(t *testing.T) {
|
|||||||
assert.Nil(t, err)
|
assert.Nil(t, err)
|
||||||
assert.Equal(
|
assert.Equal(
|
||||||
t,
|
t,
|
||||||
"{\"details\":{\"curve\":\"CURVE25519\",\"groups\":[\"hi\"],\"ips\":[],\"isCa\":false,\"issuer\":\""+c.Issuer()+"\",\"name\":\"test\",\"notAfter\":\"0001-01-01T00:00:00Z\",\"notBefore\":\"0001-01-01T00:00:00Z\",\"publicKey\":\""+pk+"\",\"subnets\":[]},\"fingerprint\":\""+fp+"\",\"signature\":\""+sig+"\"}\n{\"details\":{\"curve\":\"CURVE25519\",\"groups\":[\"hi\"],\"ips\":[],\"isCa\":false,\"issuer\":\""+c.Issuer()+"\",\"name\":\"test\",\"notAfter\":\"0001-01-01T00:00:00Z\",\"notBefore\":\"0001-01-01T00:00:00Z\",\"publicKey\":\""+pk+"\",\"subnets\":[]},\"fingerprint\":\""+fp+"\",\"signature\":\""+sig+"\"}\n{\"details\":{\"curve\":\"CURVE25519\",\"groups\":[\"hi\"],\"ips\":[],\"isCa\":false,\"issuer\":\""+c.Issuer()+"\",\"name\":\"test\",\"notAfter\":\"0001-01-01T00:00:00Z\",\"notBefore\":\"0001-01-01T00:00:00Z\",\"publicKey\":\""+pk+"\",\"subnets\":[]},\"fingerprint\":\""+fp+"\",\"signature\":\""+sig+"\"}\n",
|
`[{"details":{"curve":"CURVE25519","groups":["hi"],"isCa":false,"issuer":"`+c.Issuer()+`","name":"test","networks":[],"notAfter":"0001-01-01T00:00:00Z","notBefore":"0001-01-01T00:00:00Z","publicKey":"`+pk+`","unsafeNetworks":[]},"fingerprint":"`+fp+`","signature":"`+sig+`","version":1},{"details":{"curve":"CURVE25519","groups":["hi"],"isCa":false,"issuer":"`+c.Issuer()+`","name":"test","networks":[],"notAfter":"0001-01-01T00:00:00Z","notBefore":"0001-01-01T00:00:00Z","publicKey":"`+pk+`","unsafeNetworks":[]},"fingerprint":"`+fp+`","signature":"`+sig+`","version":1},{"details":{"curve":"CURVE25519","groups":["hi"],"isCa":false,"issuer":"`+c.Issuer()+`","name":"test","networks":[],"notAfter":"0001-01-01T00:00:00Z","notBefore":"0001-01-01T00:00:00Z","publicKey":"`+pk+`","unsafeNetworks":[]},"fingerprint":"`+fp+`","signature":"`+sig+`","version":1}]
|
||||||
|
`,
|
||||||
ob.String(),
|
ob.String(),
|
||||||
)
|
)
|
||||||
assert.Equal(t, "", eb.String())
|
assert.Equal(t, "", eb.String())
|
||||||
|
|||||||
@ -3,6 +3,7 @@ package main
|
|||||||
import (
|
import (
|
||||||
"crypto/ecdh"
|
"crypto/ecdh"
|
||||||
"crypto/rand"
|
"crypto/rand"
|
||||||
|
"errors"
|
||||||
"flag"
|
"flag"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
@ -19,35 +20,45 @@ import (
|
|||||||
|
|
||||||
type signFlags struct {
|
type signFlags struct {
|
||||||
set *flag.FlagSet
|
set *flag.FlagSet
|
||||||
|
version *uint
|
||||||
caKeyPath *string
|
caKeyPath *string
|
||||||
caCertPath *string
|
caCertPath *string
|
||||||
name *string
|
name *string
|
||||||
ip *string
|
networks *string
|
||||||
|
unsafeNetworks *string
|
||||||
duration *time.Duration
|
duration *time.Duration
|
||||||
inPubPath *string
|
inPubPath *string
|
||||||
outKeyPath *string
|
outKeyPath *string
|
||||||
outCertPath *string
|
outCertPath *string
|
||||||
outQRPath *string
|
outQRPath *string
|
||||||
groups *string
|
groups *string
|
||||||
subnets *string
|
|
||||||
p11url *string
|
p11url *string
|
||||||
|
|
||||||
|
// Deprecated options
|
||||||
|
ip *string
|
||||||
|
subnets *string
|
||||||
}
|
}
|
||||||
|
|
||||||
func newSignFlags() *signFlags {
|
func newSignFlags() *signFlags {
|
||||||
sf := signFlags{set: flag.NewFlagSet("sign", flag.ContinueOnError)}
|
sf := signFlags{set: flag.NewFlagSet("sign", flag.ContinueOnError)}
|
||||||
sf.set.Usage = func() {}
|
sf.set.Usage = func() {}
|
||||||
|
sf.version = sf.set.Uint("version", 0, "Optional: version of the certificate format to use, the default is to create both v1 and v2 certificates.")
|
||||||
sf.caKeyPath = sf.set.String("ca-key", "ca.key", "Optional: path to the signing CA key")
|
sf.caKeyPath = sf.set.String("ca-key", "ca.key", "Optional: path to the signing CA key")
|
||||||
sf.caCertPath = sf.set.String("ca-crt", "ca.crt", "Optional: path to the signing CA cert")
|
sf.caCertPath = sf.set.String("ca-crt", "ca.crt", "Optional: path to the signing CA cert")
|
||||||
sf.name = sf.set.String("name", "", "Required: name of the cert, usually a hostname")
|
sf.name = sf.set.String("name", "", "Required: name of the cert, usually a hostname")
|
||||||
sf.ip = sf.set.String("ip", "", "Required: ipv4 address and network in CIDR notation to assign the cert")
|
sf.networks = sf.set.String("networks", "", "Required: comma separated list of ip address and network in CIDR notation to assign to this cert")
|
||||||
|
sf.unsafeNetworks = sf.set.String("unsafe-networks", "", "Optional: comma separated list of ip address and network in CIDR notation. Unsafe networks this cert can route for")
|
||||||
sf.duration = sf.set.Duration("duration", 0, "Optional: how long the cert should be valid for. The default is 1 second before the signing cert expires. Valid time units are seconds: \"s\", minutes: \"m\", hours: \"h\"")
|
sf.duration = sf.set.Duration("duration", 0, "Optional: how long the cert should be valid for. The default is 1 second before the signing cert expires. Valid time units are seconds: \"s\", minutes: \"m\", hours: \"h\"")
|
||||||
sf.inPubPath = sf.set.String("in-pub", "", "Optional (if out-key not set): path to read a previously generated public key")
|
sf.inPubPath = sf.set.String("in-pub", "", "Optional (if out-key not set): path to read a previously generated public key")
|
||||||
sf.outKeyPath = sf.set.String("out-key", "", "Optional (if in-pub not set): path to write the private key to")
|
sf.outKeyPath = sf.set.String("out-key", "", "Optional (if in-pub not set): path to write the private key to")
|
||||||
sf.outCertPath = sf.set.String("out-crt", "", "Optional: path to write the certificate to")
|
sf.outCertPath = sf.set.String("out-crt", "", "Optional: path to write the certificate to")
|
||||||
sf.outQRPath = sf.set.String("out-qr", "", "Optional: output a qr code image (png) of the certificate")
|
sf.outQRPath = sf.set.String("out-qr", "", "Optional: output a qr code image (png) of the certificate")
|
||||||
sf.groups = sf.set.String("groups", "", "Optional: comma separated list of groups")
|
sf.groups = sf.set.String("groups", "", "Optional: comma separated list of groups")
|
||||||
sf.subnets = sf.set.String("subnets", "", "Optional: comma separated list of ipv4 address and network in CIDR notation. Subnets this cert can serve for")
|
|
||||||
sf.p11url = p11Flag(sf.set)
|
sf.p11url = p11Flag(sf.set)
|
||||||
|
|
||||||
|
sf.ip = sf.set.String("ip", "", "Deprecated, see -networks")
|
||||||
|
sf.subnets = sf.set.String("subnets", "", "Deprecated, see -unsafe-networks")
|
||||||
return &sf
|
return &sf
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -71,13 +82,26 @@ func signCert(args []string, out io.Writer, errOut io.Writer, pr PasswordReader)
|
|||||||
if err := mustFlagString("name", sf.name); err != nil {
|
if err := mustFlagString("name", sf.name); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
if err := mustFlagString("ip", sf.ip); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
if !isP11 && *sf.inPubPath != "" && *sf.outKeyPath != "" {
|
if !isP11 && *sf.inPubPath != "" && *sf.outKeyPath != "" {
|
||||||
return newHelpErrorf("cannot set both -in-pub and -out-key")
|
return newHelpErrorf("cannot set both -in-pub and -out-key")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
var v4Networks []netip.Prefix
|
||||||
|
var v6Networks []netip.Prefix
|
||||||
|
if *sf.networks == "" && *sf.ip != "" {
|
||||||
|
// Pull up deprecated -ip flag if needed
|
||||||
|
*sf.networks = *sf.ip
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(*sf.networks) == 0 {
|
||||||
|
return newHelpErrorf("-networks is required")
|
||||||
|
}
|
||||||
|
|
||||||
|
version := cert.Version(*sf.version)
|
||||||
|
if version != 0 && version != cert.Version1 && version != cert.Version2 {
|
||||||
|
return newHelpErrorf("-version must be either %v or %v", cert.Version1, cert.Version2)
|
||||||
|
}
|
||||||
|
|
||||||
var curve cert.Curve
|
var curve cert.Curve
|
||||||
var caKey []byte
|
var caKey []byte
|
||||||
|
|
||||||
@ -91,14 +115,14 @@ func signCert(args []string, out io.Writer, errOut io.Writer, pr PasswordReader)
|
|||||||
|
|
||||||
// naively attempt to decode the private key as though it is not encrypted
|
// naively attempt to decode the private key as though it is not encrypted
|
||||||
caKey, _, curve, err = cert.UnmarshalSigningPrivateKeyFromPEM(rawCAKey)
|
caKey, _, curve, err = cert.UnmarshalSigningPrivateKeyFromPEM(rawCAKey)
|
||||||
if err == cert.ErrPrivateKeyEncrypted {
|
if errors.Is(err, cert.ErrPrivateKeyEncrypted) {
|
||||||
// ask for a passphrase until we get one
|
// ask for a passphrase until we get one
|
||||||
var passphrase []byte
|
var passphrase []byte
|
||||||
for i := 0; i < 5; i++ {
|
for i := 0; i < 5; i++ {
|
||||||
out.Write([]byte("Enter passphrase: "))
|
out.Write([]byte("Enter passphrase: "))
|
||||||
passphrase, err = pr.ReadPassword()
|
passphrase, err = pr.ReadPassword()
|
||||||
|
|
||||||
if err == ErrNoTerminal {
|
if errors.Is(err, ErrNoTerminal) {
|
||||||
return fmt.Errorf("ca-key is encrypted and must be decrypted interactively")
|
return fmt.Errorf("ca-key is encrypted and must be decrypted interactively")
|
||||||
} else if err != nil {
|
} else if err != nil {
|
||||||
return fmt.Errorf("error reading password: %s", err)
|
return fmt.Errorf("error reading password: %s", err)
|
||||||
@ -146,12 +170,49 @@ func signCert(args []string, out io.Writer, errOut io.Writer, pr PasswordReader)
|
|||||||
*sf.duration = time.Until(caCert.NotAfter()) - time.Second*1
|
*sf.duration = time.Until(caCert.NotAfter()) - time.Second*1
|
||||||
}
|
}
|
||||||
|
|
||||||
network, err := netip.ParsePrefix(*sf.ip)
|
if *sf.networks != "" {
|
||||||
|
for _, rs := range strings.Split(*sf.networks, ",") {
|
||||||
|
//TODO: error on duplicates? Mainly only addr matters, having two of the same addr in the same or different prefix space is strange
|
||||||
|
rs := strings.Trim(rs, " ")
|
||||||
|
if rs != "" {
|
||||||
|
n, err := netip.ParsePrefix(rs)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return newHelpErrorf("invalid ip definition: %s", *sf.ip)
|
return newHelpErrorf("invalid -networks definition: %s", rs)
|
||||||
|
}
|
||||||
|
|
||||||
|
if n.Addr().Is4() {
|
||||||
|
v4Networks = append(v4Networks, n)
|
||||||
|
} else {
|
||||||
|
v6Networks = append(v6Networks, n)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
var v4UnsafeNetworks []netip.Prefix
|
||||||
|
var v6UnsafeNetworks []netip.Prefix
|
||||||
|
if *sf.unsafeNetworks == "" && *sf.subnets != "" {
|
||||||
|
// Pull up deprecated -subnets flag if needed
|
||||||
|
*sf.unsafeNetworks = *sf.subnets
|
||||||
|
}
|
||||||
|
|
||||||
|
if *sf.unsafeNetworks != "" {
|
||||||
|
//TODO: error on duplicates?
|
||||||
|
for _, rs := range strings.Split(*sf.unsafeNetworks, ",") {
|
||||||
|
rs := strings.Trim(rs, " ")
|
||||||
|
if rs != "" {
|
||||||
|
n, err := netip.ParsePrefix(rs)
|
||||||
|
if err != nil {
|
||||||
|
return newHelpErrorf("invalid -unsafe-networks definition: %s", rs)
|
||||||
|
}
|
||||||
|
|
||||||
|
if n.Addr().Is4() {
|
||||||
|
v4UnsafeNetworks = append(v4UnsafeNetworks, n)
|
||||||
|
} else {
|
||||||
|
v6UnsafeNetworks = append(v6UnsafeNetworks, n)
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
if !network.Addr().Is4() {
|
|
||||||
return newHelpErrorf("invalid ip definition: can only be ipv4, have %s", *sf.ip)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
var groups []string
|
var groups []string
|
||||||
@ -164,23 +225,6 @@ func signCert(args []string, out io.Writer, errOut io.Writer, pr PasswordReader)
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
var subnets []netip.Prefix
|
|
||||||
if *sf.subnets != "" {
|
|
||||||
for _, rs := range strings.Split(*sf.subnets, ",") {
|
|
||||||
rs := strings.Trim(rs, " ")
|
|
||||||
if rs != "" {
|
|
||||||
s, err := netip.ParsePrefix(rs)
|
|
||||||
if err != nil {
|
|
||||||
return newHelpErrorf("invalid subnet definition: %s", rs)
|
|
||||||
}
|
|
||||||
if !s.Addr().Is4() {
|
|
||||||
return newHelpErrorf("invalid subnet definition: can only be ipv4, have %s", rs)
|
|
||||||
}
|
|
||||||
subnets = append(subnets, s)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
var pub, rawPriv []byte
|
var pub, rawPriv []byte
|
||||||
var p11Client *pkclient.PKClient
|
var p11Client *pkclient.PKClient
|
||||||
|
|
||||||
@ -218,19 +262,6 @@ func signCert(args []string, out io.Writer, errOut io.Writer, pr PasswordReader)
|
|||||||
pub, rawPriv = newKeypair(curve)
|
pub, rawPriv = newKeypair(curve)
|
||||||
}
|
}
|
||||||
|
|
||||||
t := &cert.TBSCertificate{
|
|
||||||
Version: cert.Version1,
|
|
||||||
Name: *sf.name,
|
|
||||||
Networks: []netip.Prefix{network},
|
|
||||||
Groups: groups,
|
|
||||||
UnsafeNetworks: subnets,
|
|
||||||
NotBefore: time.Now(),
|
|
||||||
NotAfter: time.Now().Add(*sf.duration),
|
|
||||||
PublicKey: pub,
|
|
||||||
IsCA: false,
|
|
||||||
Curve: curve,
|
|
||||||
}
|
|
||||||
|
|
||||||
if *sf.outKeyPath == "" {
|
if *sf.outKeyPath == "" {
|
||||||
*sf.outKeyPath = *sf.name + ".key"
|
*sf.outKeyPath = *sf.name + ".key"
|
||||||
}
|
}
|
||||||
@ -243,20 +274,87 @@ func signCert(args []string, out io.Writer, errOut io.Writer, pr PasswordReader)
|
|||||||
return fmt.Errorf("refusing to overwrite existing cert: %s", *sf.outCertPath)
|
return fmt.Errorf("refusing to overwrite existing cert: %s", *sf.outCertPath)
|
||||||
}
|
}
|
||||||
|
|
||||||
var c cert.Certificate
|
var crts []cert.Certificate
|
||||||
|
|
||||||
|
notBefore := time.Now()
|
||||||
|
notAfter := notBefore.Add(*sf.duration)
|
||||||
|
|
||||||
|
if version == 0 || version == cert.Version1 {
|
||||||
|
// Make sure we at least have an ip
|
||||||
|
if len(v4Networks) != 1 {
|
||||||
|
return newHelpErrorf("invalid -networks definition: v1 certificates can only have a single ipv4 address")
|
||||||
|
}
|
||||||
|
|
||||||
|
if version == cert.Version1 {
|
||||||
|
// If we are asked to mint a v1 certificate only then we cant just ignore any v6 addresses
|
||||||
|
if len(v6Networks) > 0 {
|
||||||
|
return newHelpErrorf("invalid -networks definition: v1 certificates can only be ipv4")
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(v6UnsafeNetworks) > 0 {
|
||||||
|
return newHelpErrorf("invalid -unsafe-networks definition: v1 certificates can only be ipv4")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
t := &cert.TBSCertificate{
|
||||||
|
Version: cert.Version1,
|
||||||
|
Name: *sf.name,
|
||||||
|
Networks: []netip.Prefix{v4Networks[0]},
|
||||||
|
Groups: groups,
|
||||||
|
UnsafeNetworks: v4UnsafeNetworks,
|
||||||
|
NotBefore: notBefore,
|
||||||
|
NotAfter: notAfter,
|
||||||
|
PublicKey: pub,
|
||||||
|
IsCA: false,
|
||||||
|
Curve: curve,
|
||||||
|
}
|
||||||
|
|
||||||
|
var nc cert.Certificate
|
||||||
if p11Client == nil {
|
if p11Client == nil {
|
||||||
c, err = t.Sign(caCert, curve, caKey)
|
nc, err = t.Sign(caCert, curve, caKey)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("error while signing: %w", err)
|
return fmt.Errorf("error while signing: %w", err)
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
c, err = t.SignPkcs11(caCert, curve, p11Client)
|
nc, err = t.SignWith(caCert, curve, p11Client.SignASN1)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("error while signing with PKCS#11: %w", err)
|
return fmt.Errorf("error while signing with PKCS#11: %w", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
crts = append(crts, nc)
|
||||||
|
}
|
||||||
|
|
||||||
|
if version == 0 || version == cert.Version2 {
|
||||||
|
t := &cert.TBSCertificate{
|
||||||
|
Version: cert.Version2,
|
||||||
|
Name: *sf.name,
|
||||||
|
Networks: append(v4Networks, v6Networks...),
|
||||||
|
Groups: groups,
|
||||||
|
UnsafeNetworks: append(v4UnsafeNetworks, v6UnsafeNetworks...),
|
||||||
|
NotBefore: notBefore,
|
||||||
|
NotAfter: notAfter,
|
||||||
|
PublicKey: pub,
|
||||||
|
IsCA: false,
|
||||||
|
Curve: curve,
|
||||||
|
}
|
||||||
|
|
||||||
|
var nc cert.Certificate
|
||||||
|
if p11Client == nil {
|
||||||
|
nc, err = t.Sign(caCert, curve, caKey)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("error while signing: %w", err)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
nc, err = t.SignWith(caCert, curve, p11Client.SignASN1)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("error while signing with PKCS#11: %w", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
crts = append(crts, nc)
|
||||||
|
}
|
||||||
|
|
||||||
if !isP11 && *sf.inPubPath == "" {
|
if !isP11 && *sf.inPubPath == "" {
|
||||||
if _, err := os.Stat(*sf.outKeyPath); err == nil {
|
if _, err := os.Stat(*sf.outKeyPath); err == nil {
|
||||||
return fmt.Errorf("refusing to overwrite existing key: %s", *sf.outKeyPath)
|
return fmt.Errorf("refusing to overwrite existing key: %s", *sf.outKeyPath)
|
||||||
@ -268,10 +366,14 @@ func signCert(args []string, out io.Writer, errOut io.Writer, pr PasswordReader)
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
b, err := c.MarshalPEM()
|
var b []byte
|
||||||
|
for _, c := range crts {
|
||||||
|
sb, err := c.MarshalPEM()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("error while marshalling certificate: %s", err)
|
return fmt.Errorf("error while marshalling certificate: %s", err)
|
||||||
}
|
}
|
||||||
|
b = append(b, sb...)
|
||||||
|
}
|
||||||
|
|
||||||
err = os.WriteFile(*sf.outCertPath, b, 0600)
|
err = os.WriteFile(*sf.outCertPath, b, 0600)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|||||||
@ -39,9 +39,11 @@ func Test_signHelp(t *testing.T) {
|
|||||||
" -in-pub string\n"+
|
" -in-pub string\n"+
|
||||||
" \tOptional (if out-key not set): path to read a previously generated public key\n"+
|
" \tOptional (if out-key not set): path to read a previously generated public key\n"+
|
||||||
" -ip string\n"+
|
" -ip string\n"+
|
||||||
" \tRequired: ipv4 address and network in CIDR notation to assign the cert\n"+
|
" \tDeprecated, see -networks\n"+
|
||||||
" -name string\n"+
|
" -name string\n"+
|
||||||
" \tRequired: name of the cert, usually a hostname\n"+
|
" \tRequired: name of the cert, usually a hostname\n"+
|
||||||
|
" -networks string\n"+
|
||||||
|
" \tRequired: comma separated list of ip address and network in CIDR notation to assign to this cert\n"+
|
||||||
" -out-crt string\n"+
|
" -out-crt string\n"+
|
||||||
" \tOptional: path to write the certificate to\n"+
|
" \tOptional: path to write the certificate to\n"+
|
||||||
" -out-key string\n"+
|
" -out-key string\n"+
|
||||||
@ -50,7 +52,11 @@ func Test_signHelp(t *testing.T) {
|
|||||||
" \tOptional: output a qr code image (png) of the certificate\n"+
|
" \tOptional: output a qr code image (png) of the certificate\n"+
|
||||||
optionalPkcs11String(" -pkcs11 string\n \tOptional: PKCS#11 URI to an existing private key\n")+
|
optionalPkcs11String(" -pkcs11 string\n \tOptional: PKCS#11 URI to an existing private key\n")+
|
||||||
" -subnets string\n"+
|
" -subnets string\n"+
|
||||||
" \tOptional: comma separated list of ipv4 address and network in CIDR notation. Subnets this cert can serve for\n",
|
" \tDeprecated, see -unsafe-networks\n"+
|
||||||
|
" -unsafe-networks string\n"+
|
||||||
|
" \tOptional: comma separated list of ip address and network in CIDR notation. Unsafe networks this cert can route for\n"+
|
||||||
|
" -version uint\n"+
|
||||||
|
" \tOptional: version of the certificate format to use, the default is to create both v1 and v2 certificates.\n",
|
||||||
ob.String(),
|
ob.String(),
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
@ -77,20 +83,20 @@ func Test_signCert(t *testing.T) {
|
|||||||
|
|
||||||
// required args
|
// required args
|
||||||
assertHelpError(t, signCert(
|
assertHelpError(t, signCert(
|
||||||
[]string{"-ca-crt", "./nope", "-ca-key", "./nope", "-ip", "1.1.1.1/24", "-out-key", "nope", "-out-crt", "nope"}, ob, eb, nopw,
|
[]string{"-version", "1", "-ca-crt", "./nope", "-ca-key", "./nope", "-ip", "1.1.1.1/24", "-out-key", "nope", "-out-crt", "nope"}, ob, eb, nopw,
|
||||||
), "-name is required")
|
), "-name is required")
|
||||||
assert.Empty(t, ob.String())
|
assert.Empty(t, ob.String())
|
||||||
assert.Empty(t, eb.String())
|
assert.Empty(t, eb.String())
|
||||||
|
|
||||||
assertHelpError(t, signCert(
|
assertHelpError(t, signCert(
|
||||||
[]string{"-ca-crt", "./nope", "-ca-key", "./nope", "-name", "test", "-out-key", "nope", "-out-crt", "nope"}, ob, eb, nopw,
|
[]string{"-version", "1", "-ca-crt", "./nope", "-ca-key", "./nope", "-name", "test", "-out-key", "nope", "-out-crt", "nope"}, ob, eb, nopw,
|
||||||
), "-ip is required")
|
), "-networks is required")
|
||||||
assert.Empty(t, ob.String())
|
assert.Empty(t, ob.String())
|
||||||
assert.Empty(t, eb.String())
|
assert.Empty(t, eb.String())
|
||||||
|
|
||||||
// cannot set -in-pub and -out-key
|
// cannot set -in-pub and -out-key
|
||||||
assertHelpError(t, signCert(
|
assertHelpError(t, signCert(
|
||||||
[]string{"-ca-crt", "./nope", "-ca-key", "./nope", "-name", "test", "-in-pub", "nope", "-ip", "1.1.1.1/24", "-out-crt", "nope", "-out-key", "nope"}, ob, eb, nopw,
|
[]string{"-version", "1", "-ca-crt", "./nope", "-ca-key", "./nope", "-name", "test", "-in-pub", "nope", "-ip", "1.1.1.1/24", "-out-crt", "nope", "-out-key", "nope"}, ob, eb, nopw,
|
||||||
), "cannot set both -in-pub and -out-key")
|
), "cannot set both -in-pub and -out-key")
|
||||||
assert.Empty(t, ob.String())
|
assert.Empty(t, ob.String())
|
||||||
assert.Empty(t, eb.String())
|
assert.Empty(t, eb.String())
|
||||||
@ -98,7 +104,7 @@ func Test_signCert(t *testing.T) {
|
|||||||
// failed to read key
|
// failed to read key
|
||||||
ob.Reset()
|
ob.Reset()
|
||||||
eb.Reset()
|
eb.Reset()
|
||||||
args := []string{"-ca-crt", "./nope", "-ca-key", "./nope", "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", "nope", "-out-key", "nope", "-duration", "100m"}
|
args := []string{"-version", "1", "-ca-crt", "./nope", "-ca-key", "./nope", "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", "nope", "-out-key", "nope", "-duration", "100m"}
|
||||||
assert.EqualError(t, signCert(args, ob, eb, nopw), "error while reading ca-key: open ./nope: "+NoSuchFileError)
|
assert.EqualError(t, signCert(args, ob, eb, nopw), "error while reading ca-key: open ./nope: "+NoSuchFileError)
|
||||||
|
|
||||||
// failed to unmarshal key
|
// failed to unmarshal key
|
||||||
@ -108,7 +114,7 @@ func Test_signCert(t *testing.T) {
|
|||||||
assert.Nil(t, err)
|
assert.Nil(t, err)
|
||||||
defer os.Remove(caKeyF.Name())
|
defer os.Remove(caKeyF.Name())
|
||||||
|
|
||||||
args = []string{"-ca-crt", "./nope", "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", "nope", "-out-key", "nope", "-duration", "100m"}
|
args = []string{"-version", "1", "-ca-crt", "./nope", "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", "nope", "-out-key", "nope", "-duration", "100m"}
|
||||||
assert.EqualError(t, signCert(args, ob, eb, nopw), "error while parsing ca-key: input did not contain a valid PEM encoded block")
|
assert.EqualError(t, signCert(args, ob, eb, nopw), "error while parsing ca-key: input did not contain a valid PEM encoded block")
|
||||||
assert.Empty(t, ob.String())
|
assert.Empty(t, ob.String())
|
||||||
assert.Empty(t, eb.String())
|
assert.Empty(t, eb.String())
|
||||||
@ -120,7 +126,7 @@ func Test_signCert(t *testing.T) {
|
|||||||
caKeyF.Write(cert.MarshalSigningPrivateKeyToPEM(cert.Curve_CURVE25519, caPriv))
|
caKeyF.Write(cert.MarshalSigningPrivateKeyToPEM(cert.Curve_CURVE25519, caPriv))
|
||||||
|
|
||||||
// failed to read cert
|
// failed to read cert
|
||||||
args = []string{"-ca-crt", "./nope", "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", "nope", "-out-key", "nope", "-duration", "100m"}
|
args = []string{"-version", "1", "-ca-crt", "./nope", "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", "nope", "-out-key", "nope", "-duration", "100m"}
|
||||||
assert.EqualError(t, signCert(args, ob, eb, nopw), "error while reading ca-crt: open ./nope: "+NoSuchFileError)
|
assert.EqualError(t, signCert(args, ob, eb, nopw), "error while reading ca-crt: open ./nope: "+NoSuchFileError)
|
||||||
assert.Empty(t, ob.String())
|
assert.Empty(t, ob.String())
|
||||||
assert.Empty(t, eb.String())
|
assert.Empty(t, eb.String())
|
||||||
@ -132,7 +138,7 @@ func Test_signCert(t *testing.T) {
|
|||||||
assert.Nil(t, err)
|
assert.Nil(t, err)
|
||||||
defer os.Remove(caCrtF.Name())
|
defer os.Remove(caCrtF.Name())
|
||||||
|
|
||||||
args = []string{"-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", "nope", "-out-key", "nope", "-duration", "100m"}
|
args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", "nope", "-out-key", "nope", "-duration", "100m"}
|
||||||
assert.EqualError(t, signCert(args, ob, eb, nopw), "error while parsing ca-crt: input did not contain a valid PEM encoded block")
|
assert.EqualError(t, signCert(args, ob, eb, nopw), "error while parsing ca-crt: input did not contain a valid PEM encoded block")
|
||||||
assert.Empty(t, ob.String())
|
assert.Empty(t, ob.String())
|
||||||
assert.Empty(t, eb.String())
|
assert.Empty(t, eb.String())
|
||||||
@ -143,7 +149,7 @@ func Test_signCert(t *testing.T) {
|
|||||||
caCrtF.Write(b)
|
caCrtF.Write(b)
|
||||||
|
|
||||||
// failed to read pub
|
// failed to read pub
|
||||||
args = []string{"-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", "nope", "-in-pub", "./nope", "-duration", "100m"}
|
args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", "nope", "-in-pub", "./nope", "-duration", "100m"}
|
||||||
assert.EqualError(t, signCert(args, ob, eb, nopw), "error while reading in-pub: open ./nope: "+NoSuchFileError)
|
assert.EqualError(t, signCert(args, ob, eb, nopw), "error while reading in-pub: open ./nope: "+NoSuchFileError)
|
||||||
assert.Empty(t, ob.String())
|
assert.Empty(t, ob.String())
|
||||||
assert.Empty(t, eb.String())
|
assert.Empty(t, eb.String())
|
||||||
@ -155,7 +161,7 @@ func Test_signCert(t *testing.T) {
|
|||||||
assert.Nil(t, err)
|
assert.Nil(t, err)
|
||||||
defer os.Remove(inPubF.Name())
|
defer os.Remove(inPubF.Name())
|
||||||
|
|
||||||
args = []string{"-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", "nope", "-in-pub", inPubF.Name(), "-duration", "100m"}
|
args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", "nope", "-in-pub", inPubF.Name(), "-duration", "100m"}
|
||||||
assert.EqualError(t, signCert(args, ob, eb, nopw), "error while parsing in-pub: input did not contain a valid PEM encoded block")
|
assert.EqualError(t, signCert(args, ob, eb, nopw), "error while parsing in-pub: input did not contain a valid PEM encoded block")
|
||||||
assert.Empty(t, ob.String())
|
assert.Empty(t, ob.String())
|
||||||
assert.Empty(t, eb.String())
|
assert.Empty(t, eb.String())
|
||||||
@ -169,30 +175,37 @@ func Test_signCert(t *testing.T) {
|
|||||||
// bad ip cidr
|
// bad ip cidr
|
||||||
ob.Reset()
|
ob.Reset()
|
||||||
eb.Reset()
|
eb.Reset()
|
||||||
args = []string{"-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "a1.1.1.1/24", "-out-crt", "nope", "-out-key", "nope", "-duration", "100m"}
|
args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "a1.1.1.1/24", "-out-crt", "nope", "-out-key", "nope", "-duration", "100m"}
|
||||||
assertHelpError(t, signCert(args, ob, eb, nopw), "invalid ip definition: a1.1.1.1/24")
|
assertHelpError(t, signCert(args, ob, eb, nopw), "invalid -networks definition: a1.1.1.1/24")
|
||||||
assert.Empty(t, ob.String())
|
assert.Empty(t, ob.String())
|
||||||
assert.Empty(t, eb.String())
|
assert.Empty(t, eb.String())
|
||||||
|
|
||||||
ob.Reset()
|
ob.Reset()
|
||||||
eb.Reset()
|
eb.Reset()
|
||||||
args = []string{"-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "100::100/100", "-out-crt", "nope", "-out-key", "nope", "-duration", "100m"}
|
args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "100::100/100", "-out-crt", "nope", "-out-key", "nope", "-duration", "100m"}
|
||||||
assertHelpError(t, signCert(args, ob, eb, nopw), "invalid ip definition: can only be ipv4, have 100::100/100")
|
assertHelpError(t, signCert(args, ob, eb, nopw), "invalid -networks definition: v1 certificates can only have a single ipv4 address")
|
||||||
|
assert.Empty(t, ob.String())
|
||||||
|
assert.Empty(t, eb.String())
|
||||||
|
|
||||||
|
ob.Reset()
|
||||||
|
eb.Reset()
|
||||||
|
args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24,1.1.1.2/24", "-out-crt", "nope", "-out-key", "nope", "-duration", "100m"}
|
||||||
|
assertHelpError(t, signCert(args, ob, eb, nopw), "invalid -networks definition: v1 certificates can only have a single ipv4 address")
|
||||||
assert.Empty(t, ob.String())
|
assert.Empty(t, ob.String())
|
||||||
assert.Empty(t, eb.String())
|
assert.Empty(t, eb.String())
|
||||||
|
|
||||||
// bad subnet cidr
|
// bad subnet cidr
|
||||||
ob.Reset()
|
ob.Reset()
|
||||||
eb.Reset()
|
eb.Reset()
|
||||||
args = []string{"-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", "nope", "-out-key", "nope", "-duration", "100m", "-subnets", "a"}
|
args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", "nope", "-out-key", "nope", "-duration", "100m", "-subnets", "a"}
|
||||||
assertHelpError(t, signCert(args, ob, eb, nopw), "invalid subnet definition: a")
|
assertHelpError(t, signCert(args, ob, eb, nopw), "invalid -unsafe-networks definition: a")
|
||||||
assert.Empty(t, ob.String())
|
assert.Empty(t, ob.String())
|
||||||
assert.Empty(t, eb.String())
|
assert.Empty(t, eb.String())
|
||||||
|
|
||||||
ob.Reset()
|
ob.Reset()
|
||||||
eb.Reset()
|
eb.Reset()
|
||||||
args = []string{"-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", "nope", "-out-key", "nope", "-duration", "100m", "-subnets", "100::100/100"}
|
args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", "nope", "-out-key", "nope", "-duration", "100m", "-subnets", "100::100/100"}
|
||||||
assertHelpError(t, signCert(args, ob, eb, nopw), "invalid subnet definition: can only be ipv4, have 100::100/100")
|
assertHelpError(t, signCert(args, ob, eb, nopw), "invalid -unsafe-networks definition: v1 certificates can only be ipv4")
|
||||||
assert.Empty(t, ob.String())
|
assert.Empty(t, ob.String())
|
||||||
assert.Empty(t, eb.String())
|
assert.Empty(t, eb.String())
|
||||||
|
|
||||||
@ -205,7 +218,7 @@ func Test_signCert(t *testing.T) {
|
|||||||
|
|
||||||
ob.Reset()
|
ob.Reset()
|
||||||
eb.Reset()
|
eb.Reset()
|
||||||
args = []string{"-ca-crt", caCrtF.Name(), "-ca-key", caKeyF2.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", "nope", "-out-key", "nope", "-duration", "100m", "-subnets", "a"}
|
args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF2.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", "nope", "-out-key", "nope", "-duration", "100m", "-subnets", "a"}
|
||||||
assert.EqualError(t, signCert(args, ob, eb, nopw), "refusing to sign, root certificate does not match private key")
|
assert.EqualError(t, signCert(args, ob, eb, nopw), "refusing to sign, root certificate does not match private key")
|
||||||
assert.Empty(t, ob.String())
|
assert.Empty(t, ob.String())
|
||||||
assert.Empty(t, eb.String())
|
assert.Empty(t, eb.String())
|
||||||
@ -213,7 +226,7 @@ func Test_signCert(t *testing.T) {
|
|||||||
// failed key write
|
// failed key write
|
||||||
ob.Reset()
|
ob.Reset()
|
||||||
eb.Reset()
|
eb.Reset()
|
||||||
args = []string{"-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", "/do/not/write/pleasecrt", "-out-key", "/do/not/write/pleasekey", "-duration", "100m", "-subnets", "10.1.1.1/32"}
|
args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", "/do/not/write/pleasecrt", "-out-key", "/do/not/write/pleasekey", "-duration", "100m", "-subnets", "10.1.1.1/32"}
|
||||||
assert.EqualError(t, signCert(args, ob, eb, nopw), "error while writing out-key: open /do/not/write/pleasekey: "+NoSuchDirError)
|
assert.EqualError(t, signCert(args, ob, eb, nopw), "error while writing out-key: open /do/not/write/pleasekey: "+NoSuchDirError)
|
||||||
assert.Empty(t, ob.String())
|
assert.Empty(t, ob.String())
|
||||||
assert.Empty(t, eb.String())
|
assert.Empty(t, eb.String())
|
||||||
@ -226,7 +239,7 @@ func Test_signCert(t *testing.T) {
|
|||||||
// failed cert write
|
// failed cert write
|
||||||
ob.Reset()
|
ob.Reset()
|
||||||
eb.Reset()
|
eb.Reset()
|
||||||
args = []string{"-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", "/do/not/write/pleasecrt", "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32"}
|
args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", "/do/not/write/pleasecrt", "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32"}
|
||||||
assert.EqualError(t, signCert(args, ob, eb, nopw), "error while writing out-crt: open /do/not/write/pleasecrt: "+NoSuchDirError)
|
assert.EqualError(t, signCert(args, ob, eb, nopw), "error while writing out-crt: open /do/not/write/pleasecrt: "+NoSuchDirError)
|
||||||
assert.Empty(t, ob.String())
|
assert.Empty(t, ob.String())
|
||||||
assert.Empty(t, eb.String())
|
assert.Empty(t, eb.String())
|
||||||
@ -240,7 +253,7 @@ func Test_signCert(t *testing.T) {
|
|||||||
// test proper cert with removed empty groups and subnets
|
// test proper cert with removed empty groups and subnets
|
||||||
ob.Reset()
|
ob.Reset()
|
||||||
eb.Reset()
|
eb.Reset()
|
||||||
args = []string{"-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32, , 10.2.2.2/32 , , ,, 10.5.5.5/32", "-groups", "1,, 2 , ,,,3,4,5"}
|
args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32, , 10.2.2.2/32 , , ,, 10.5.5.5/32", "-groups", "1,, 2 , ,,,3,4,5"}
|
||||||
assert.Nil(t, signCert(args, ob, eb, nopw))
|
assert.Nil(t, signCert(args, ob, eb, nopw))
|
||||||
assert.Empty(t, ob.String())
|
assert.Empty(t, ob.String())
|
||||||
assert.Empty(t, eb.String())
|
assert.Empty(t, eb.String())
|
||||||
@ -283,7 +296,7 @@ func Test_signCert(t *testing.T) {
|
|||||||
os.Remove(crtF.Name())
|
os.Remove(crtF.Name())
|
||||||
ob.Reset()
|
ob.Reset()
|
||||||
eb.Reset()
|
eb.Reset()
|
||||||
args = []string{"-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-in-pub", inPubF.Name(), "-duration", "100m", "-groups", "1"}
|
args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-in-pub", inPubF.Name(), "-duration", "100m", "-groups", "1"}
|
||||||
assert.Nil(t, signCert(args, ob, eb, nopw))
|
assert.Nil(t, signCert(args, ob, eb, nopw))
|
||||||
assert.Empty(t, ob.String())
|
assert.Empty(t, ob.String())
|
||||||
assert.Empty(t, eb.String())
|
assert.Empty(t, eb.String())
|
||||||
@ -300,7 +313,7 @@ func Test_signCert(t *testing.T) {
|
|||||||
eb.Reset()
|
eb.Reset()
|
||||||
os.Remove(keyF.Name())
|
os.Remove(keyF.Name())
|
||||||
os.Remove(crtF.Name())
|
os.Remove(crtF.Name())
|
||||||
args = []string{"-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "1000m", "-subnets", "10.1.1.1/32, , 10.2.2.2/32 , , ,, 10.5.5.5/32", "-groups", "1,, 2 , ,,,3,4,5"}
|
args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "1000m", "-subnets", "10.1.1.1/32, , 10.2.2.2/32 , , ,, 10.5.5.5/32", "-groups", "1,, 2 , ,,,3,4,5"}
|
||||||
assert.EqualError(t, signCert(args, ob, eb, nopw), "error while signing: certificate expires after signing certificate")
|
assert.EqualError(t, signCert(args, ob, eb, nopw), "error while signing: certificate expires after signing certificate")
|
||||||
assert.Empty(t, ob.String())
|
assert.Empty(t, ob.String())
|
||||||
assert.Empty(t, eb.String())
|
assert.Empty(t, eb.String())
|
||||||
@ -308,14 +321,14 @@ func Test_signCert(t *testing.T) {
|
|||||||
// create valid cert/key for overwrite tests
|
// create valid cert/key for overwrite tests
|
||||||
os.Remove(keyF.Name())
|
os.Remove(keyF.Name())
|
||||||
os.Remove(crtF.Name())
|
os.Remove(crtF.Name())
|
||||||
args = []string{"-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32, , 10.2.2.2/32 , , ,, 10.5.5.5/32", "-groups", "1,, 2 , ,,,3,4,5"}
|
args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32, , 10.2.2.2/32 , , ,, 10.5.5.5/32", "-groups", "1,, 2 , ,,,3,4,5"}
|
||||||
assert.Nil(t, signCert(args, ob, eb, nopw))
|
assert.Nil(t, signCert(args, ob, eb, nopw))
|
||||||
|
|
||||||
// test that we won't overwrite existing key file
|
// test that we won't overwrite existing key file
|
||||||
os.Remove(crtF.Name())
|
os.Remove(crtF.Name())
|
||||||
ob.Reset()
|
ob.Reset()
|
||||||
eb.Reset()
|
eb.Reset()
|
||||||
args = []string{"-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32, , 10.2.2.2/32 , , ,, 10.5.5.5/32", "-groups", "1,, 2 , ,,,3,4,5"}
|
args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32, , 10.2.2.2/32 , , ,, 10.5.5.5/32", "-groups", "1,, 2 , ,,,3,4,5"}
|
||||||
assert.EqualError(t, signCert(args, ob, eb, nopw), "refusing to overwrite existing key: "+keyF.Name())
|
assert.EqualError(t, signCert(args, ob, eb, nopw), "refusing to overwrite existing key: "+keyF.Name())
|
||||||
assert.Empty(t, ob.String())
|
assert.Empty(t, ob.String())
|
||||||
assert.Empty(t, eb.String())
|
assert.Empty(t, eb.String())
|
||||||
@ -323,14 +336,14 @@ func Test_signCert(t *testing.T) {
|
|||||||
// create valid cert/key for overwrite tests
|
// create valid cert/key for overwrite tests
|
||||||
os.Remove(keyF.Name())
|
os.Remove(keyF.Name())
|
||||||
os.Remove(crtF.Name())
|
os.Remove(crtF.Name())
|
||||||
args = []string{"-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32, , 10.2.2.2/32 , , ,, 10.5.5.5/32", "-groups", "1,, 2 , ,,,3,4,5"}
|
args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32, , 10.2.2.2/32 , , ,, 10.5.5.5/32", "-groups", "1,, 2 , ,,,3,4,5"}
|
||||||
assert.Nil(t, signCert(args, ob, eb, nopw))
|
assert.Nil(t, signCert(args, ob, eb, nopw))
|
||||||
|
|
||||||
// test that we won't overwrite existing certificate file
|
// test that we won't overwrite existing certificate file
|
||||||
os.Remove(keyF.Name())
|
os.Remove(keyF.Name())
|
||||||
ob.Reset()
|
ob.Reset()
|
||||||
eb.Reset()
|
eb.Reset()
|
||||||
args = []string{"-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32, , 10.2.2.2/32 , , ,, 10.5.5.5/32", "-groups", "1,, 2 , ,,,3,4,5"}
|
args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32, , 10.2.2.2/32 , , ,, 10.5.5.5/32", "-groups", "1,, 2 , ,,,3,4,5"}
|
||||||
assert.EqualError(t, signCert(args, ob, eb, nopw), "refusing to overwrite existing cert: "+crtF.Name())
|
assert.EqualError(t, signCert(args, ob, eb, nopw), "refusing to overwrite existing cert: "+crtF.Name())
|
||||||
assert.Empty(t, ob.String())
|
assert.Empty(t, ob.String())
|
||||||
assert.Empty(t, eb.String())
|
assert.Empty(t, eb.String())
|
||||||
@ -362,7 +375,7 @@ func Test_signCert(t *testing.T) {
|
|||||||
caCrtF.Write(b)
|
caCrtF.Write(b)
|
||||||
|
|
||||||
// test with the proper password
|
// test with the proper password
|
||||||
args = []string{"-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32, , 10.2.2.2/32 , , ,, 10.5.5.5/32", "-groups", "1,, 2 , ,,,3,4,5"}
|
args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32, , 10.2.2.2/32 , , ,, 10.5.5.5/32", "-groups", "1,, 2 , ,,,3,4,5"}
|
||||||
assert.Nil(t, signCert(args, ob, eb, testpw))
|
assert.Nil(t, signCert(args, ob, eb, testpw))
|
||||||
assert.Equal(t, "Enter passphrase: ", ob.String())
|
assert.Equal(t, "Enter passphrase: ", ob.String())
|
||||||
assert.Empty(t, eb.String())
|
assert.Empty(t, eb.String())
|
||||||
@ -372,7 +385,7 @@ func Test_signCert(t *testing.T) {
|
|||||||
eb.Reset()
|
eb.Reset()
|
||||||
|
|
||||||
testpw.password = []byte("invalid password")
|
testpw.password = []byte("invalid password")
|
||||||
args = []string{"-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32, , 10.2.2.2/32 , , ,, 10.5.5.5/32", "-groups", "1,, 2 , ,,,3,4,5"}
|
args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32, , 10.2.2.2/32 , , ,, 10.5.5.5/32", "-groups", "1,, 2 , ,,,3,4,5"}
|
||||||
assert.Error(t, signCert(args, ob, eb, testpw))
|
assert.Error(t, signCert(args, ob, eb, testpw))
|
||||||
assert.Equal(t, "Enter passphrase: ", ob.String())
|
assert.Equal(t, "Enter passphrase: ", ob.String())
|
||||||
assert.Empty(t, eb.String())
|
assert.Empty(t, eb.String())
|
||||||
@ -381,7 +394,7 @@ func Test_signCert(t *testing.T) {
|
|||||||
ob.Reset()
|
ob.Reset()
|
||||||
eb.Reset()
|
eb.Reset()
|
||||||
|
|
||||||
args = []string{"-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32, , 10.2.2.2/32 , , ,, 10.5.5.5/32", "-groups", "1,, 2 , ,,,3,4,5"}
|
args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32, , 10.2.2.2/32 , , ,, 10.5.5.5/32", "-groups", "1,, 2 , ,,,3,4,5"}
|
||||||
assert.Error(t, signCert(args, ob, eb, nopw))
|
assert.Error(t, signCert(args, ob, eb, nopw))
|
||||||
// normally the user hitting enter on the prompt would add newlines between these
|
// normally the user hitting enter on the prompt would add newlines between these
|
||||||
assert.Equal(t, "Enter passphrase: Enter passphrase: Enter passphrase: Enter passphrase: Enter passphrase: ", ob.String())
|
assert.Equal(t, "Enter passphrase: Enter passphrase: Enter passphrase: Enter passphrase: Enter passphrase: ", ob.String())
|
||||||
@ -391,7 +404,7 @@ func Test_signCert(t *testing.T) {
|
|||||||
ob.Reset()
|
ob.Reset()
|
||||||
eb.Reset()
|
eb.Reset()
|
||||||
|
|
||||||
args = []string{"-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32, , 10.2.2.2/32 , , ,, 10.5.5.5/32", "-groups", "1,, 2 , ,,,3,4,5"}
|
args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32, , 10.2.2.2/32 , , ,, 10.5.5.5/32", "-groups", "1,, 2 , ,,,3,4,5"}
|
||||||
assert.Error(t, signCert(args, ob, eb, errpw))
|
assert.Error(t, signCert(args, ob, eb, errpw))
|
||||||
assert.Equal(t, "Enter passphrase: ", ob.String())
|
assert.Equal(t, "Enter passphrase: ", ob.String())
|
||||||
assert.Empty(t, eb.String())
|
assert.Empty(t, eb.String())
|
||||||
|
|||||||
@ -183,7 +183,7 @@ func (n *connectionManager) doTrafficCheck(localIndex uint32, p, nb, out []byte,
|
|||||||
case deleteTunnel:
|
case deleteTunnel:
|
||||||
if n.hostMap.DeleteHostInfo(hostinfo) {
|
if n.hostMap.DeleteHostInfo(hostinfo) {
|
||||||
// Only clearing the lighthouse cache if this is the last hostinfo for this vpn ip in the hostmap
|
// Only clearing the lighthouse cache if this is the last hostinfo for this vpn ip in the hostmap
|
||||||
n.intf.lightHouse.DeleteVpnIp(hostinfo.vpnIp)
|
n.intf.lightHouse.DeleteVpnAddrs(hostinfo.vpnAddrs)
|
||||||
}
|
}
|
||||||
|
|
||||||
case closeTunnel:
|
case closeTunnel:
|
||||||
@ -221,7 +221,7 @@ func (n *connectionManager) migrateRelayUsed(oldhostinfo, newhostinfo *HostInfo)
|
|||||||
relayFor := oldhostinfo.relayState.CopyAllRelayFor()
|
relayFor := oldhostinfo.relayState.CopyAllRelayFor()
|
||||||
|
|
||||||
for _, r := range relayFor {
|
for _, r := range relayFor {
|
||||||
existing, ok := newhostinfo.relayState.QueryRelayForByIp(r.PeerIp)
|
existing, ok := newhostinfo.relayState.QueryRelayForByIp(r.PeerAddr)
|
||||||
|
|
||||||
var index uint32
|
var index uint32
|
||||||
var relayFrom netip.Addr
|
var relayFrom netip.Addr
|
||||||
@ -235,11 +235,11 @@ func (n *connectionManager) migrateRelayUsed(oldhostinfo, newhostinfo *HostInfo)
|
|||||||
index = existing.LocalIndex
|
index = existing.LocalIndex
|
||||||
switch r.Type {
|
switch r.Type {
|
||||||
case TerminalType:
|
case TerminalType:
|
||||||
relayFrom = n.intf.myVpnNet.Addr()
|
relayFrom = n.intf.myVpnAddrs[0]
|
||||||
relayTo = existing.PeerIp
|
relayTo = existing.PeerAddr
|
||||||
case ForwardingType:
|
case ForwardingType:
|
||||||
relayFrom = existing.PeerIp
|
relayFrom = existing.PeerAddr
|
||||||
relayTo = newhostinfo.vpnIp
|
relayTo = newhostinfo.vpnAddrs[0]
|
||||||
default:
|
default:
|
||||||
// should never happen
|
// should never happen
|
||||||
}
|
}
|
||||||
@ -253,45 +253,64 @@ func (n *connectionManager) migrateRelayUsed(oldhostinfo, newhostinfo *HostInfo)
|
|||||||
n.relayUsedLock.RUnlock()
|
n.relayUsedLock.RUnlock()
|
||||||
// The relay doesn't exist at all; create some relay state and send the request.
|
// The relay doesn't exist at all; create some relay state and send the request.
|
||||||
var err error
|
var err error
|
||||||
index, err = AddRelay(n.l, newhostinfo, n.hostMap, r.PeerIp, nil, r.Type, Requested)
|
index, err = AddRelay(n.l, newhostinfo, n.hostMap, r.PeerAddr, nil, r.Type, Requested)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
n.l.WithError(err).Error("failed to migrate relay to new hostinfo")
|
n.l.WithError(err).Error("failed to migrate relay to new hostinfo")
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
switch r.Type {
|
switch r.Type {
|
||||||
case TerminalType:
|
case TerminalType:
|
||||||
relayFrom = n.intf.myVpnNet.Addr()
|
relayFrom = n.intf.myVpnAddrs[0]
|
||||||
relayTo = r.PeerIp
|
relayTo = r.PeerAddr
|
||||||
case ForwardingType:
|
case ForwardingType:
|
||||||
relayFrom = r.PeerIp
|
relayFrom = r.PeerAddr
|
||||||
relayTo = newhostinfo.vpnIp
|
relayTo = newhostinfo.vpnAddrs[0]
|
||||||
default:
|
default:
|
||||||
// should never happen
|
// should never happen
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
//TODO: IPV6-WORK
|
|
||||||
relayFromB := relayFrom.As4()
|
|
||||||
relayToB := relayTo.As4()
|
|
||||||
|
|
||||||
// Send a CreateRelayRequest to the peer.
|
// Send a CreateRelayRequest to the peer.
|
||||||
req := NebulaControl{
|
req := NebulaControl{
|
||||||
Type: NebulaControl_CreateRelayRequest,
|
Type: NebulaControl_CreateRelayRequest,
|
||||||
InitiatorRelayIndex: index,
|
InitiatorRelayIndex: index,
|
||||||
RelayFromIp: binary.BigEndian.Uint32(relayFromB[:]),
|
|
||||||
RelayToIp: binary.BigEndian.Uint32(relayToB[:]),
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
switch newhostinfo.GetCert().Certificate.Version() {
|
||||||
|
case cert.Version1:
|
||||||
|
if !relayFrom.Is4() {
|
||||||
|
n.l.Error("can not migrate v1 relay with a v6 network because the relay is not running a current nebula version")
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
if !relayTo.Is4() {
|
||||||
|
n.l.Error("can not migrate v1 relay with a v6 remote network because the relay is not running a current nebula version")
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
b := relayFrom.As4()
|
||||||
|
req.OldRelayFromAddr = binary.BigEndian.Uint32(b[:])
|
||||||
|
b = relayTo.As4()
|
||||||
|
req.OldRelayToAddr = binary.BigEndian.Uint32(b[:])
|
||||||
|
case cert.Version2:
|
||||||
|
req.RelayFromAddr = netAddrToProtoAddr(relayFrom)
|
||||||
|
req.RelayToAddr = netAddrToProtoAddr(relayTo)
|
||||||
|
default:
|
||||||
|
newhostinfo.logger(n.l).Error("Unknown certificate version found while attempting to migrate relay")
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
msg, err := req.Marshal()
|
msg, err := req.Marshal()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
n.l.WithError(err).Error("failed to marshal Control message to migrate relay")
|
n.l.WithError(err).Error("failed to marshal Control message to migrate relay")
|
||||||
} else {
|
} else {
|
||||||
n.intf.SendMessageToHostInfo(header.Control, 0, newhostinfo, msg, make([]byte, 12), make([]byte, mtu))
|
n.intf.SendMessageToHostInfo(header.Control, 0, newhostinfo, msg, make([]byte, 12), make([]byte, mtu))
|
||||||
n.l.WithFields(logrus.Fields{
|
n.l.WithFields(logrus.Fields{
|
||||||
"relayFrom": req.RelayFromIp,
|
"relayFrom": req.RelayFromAddr,
|
||||||
"relayTo": req.RelayToIp,
|
"relayTo": req.RelayToAddr,
|
||||||
"initiatorRelayIndex": req.InitiatorRelayIndex,
|
"initiatorRelayIndex": req.InitiatorRelayIndex,
|
||||||
"responderRelayIndex": req.ResponderRelayIndex,
|
"responderRelayIndex": req.ResponderRelayIndex,
|
||||||
"vpnIp": newhostinfo.vpnIp}).
|
"vpnAddrs": newhostinfo.vpnAddrs}).
|
||||||
Info("send CreateRelayRequest")
|
Info("send CreateRelayRequest")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -313,7 +332,7 @@ func (n *connectionManager) makeTrafficDecision(localIndex uint32, now time.Time
|
|||||||
return closeTunnel, hostinfo, nil
|
return closeTunnel, hostinfo, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
primary := n.hostMap.Hosts[hostinfo.vpnIp]
|
primary := n.hostMap.Hosts[hostinfo.vpnAddrs[0]]
|
||||||
mainHostInfo := true
|
mainHostInfo := true
|
||||||
if primary != nil && primary != hostinfo {
|
if primary != nil && primary != hostinfo {
|
||||||
mainHostInfo = false
|
mainHostInfo = false
|
||||||
@ -407,21 +426,24 @@ func (n *connectionManager) shouldSwapPrimary(current, primary *HostInfo) bool {
|
|||||||
// If we are here then we have multiple tunnels for a host pair and neither side believes the same tunnel is primary.
|
// If we are here then we have multiple tunnels for a host pair and neither side believes the same tunnel is primary.
|
||||||
// Let's sort this out.
|
// Let's sort this out.
|
||||||
|
|
||||||
if current.vpnIp.Compare(n.intf.myVpnNet.Addr()) < 0 {
|
//TODO: current.vpnIp should become an array of vpnIps
|
||||||
|
if current.vpnAddrs[0].Compare(n.intf.myVpnAddrs[0]) < 0 {
|
||||||
// Only one side should flip primary because if both flip then we may never resolve to a single tunnel.
|
// Only one side should flip primary because if both flip then we may never resolve to a single tunnel.
|
||||||
// vpn ip is static across all tunnels for this host pair so lets use that to determine who is flipping.
|
// vpn ip is static across all tunnels for this host pair so lets use that to determine who is flipping.
|
||||||
// The remotes vpn ip is lower than mine. I will not flip.
|
// The remotes vpn ip is lower than mine. I will not flip.
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
certState := n.intf.pki.GetCertState()
|
//TODO: we should favor v2 over v1 certificates if configured to send them
|
||||||
return bytes.Equal(current.ConnectionState.myCert.Signature(), certState.Certificate.Signature())
|
|
||||||
|
crt := n.intf.pki.getCertificate(current.ConnectionState.myCert.Version())
|
||||||
|
return bytes.Equal(current.ConnectionState.myCert.Signature(), crt.Signature())
|
||||||
}
|
}
|
||||||
|
|
||||||
func (n *connectionManager) swapPrimary(current, primary *HostInfo) {
|
func (n *connectionManager) swapPrimary(current, primary *HostInfo) {
|
||||||
n.hostMap.Lock()
|
n.hostMap.Lock()
|
||||||
// Make sure the primary is still the same after the write lock. This avoids a race with a rehandshake.
|
// Make sure the primary is still the same after the write lock. This avoids a race with a rehandshake.
|
||||||
if n.hostMap.Hosts[current.vpnIp] == primary {
|
if n.hostMap.Hosts[current.vpnAddrs[0]] == primary {
|
||||||
n.hostMap.unlockedMakePrimary(current)
|
n.hostMap.unlockedMakePrimary(current)
|
||||||
}
|
}
|
||||||
n.hostMap.Unlock()
|
n.hostMap.Unlock()
|
||||||
@ -473,14 +495,16 @@ func (n *connectionManager) sendPunch(hostinfo *HostInfo) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (n *connectionManager) tryRehandshake(hostinfo *HostInfo) {
|
func (n *connectionManager) tryRehandshake(hostinfo *HostInfo) {
|
||||||
certState := n.intf.pki.GetCertState()
|
crt := n.intf.pki.getCertificate(hostinfo.ConnectionState.myCert.Version())
|
||||||
if bytes.Equal(hostinfo.ConnectionState.myCert.Signature(), certState.Certificate.Signature()) {
|
if bytes.Equal(hostinfo.ConnectionState.myCert.Signature(), crt.Signature()) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
n.l.WithField("vpnIp", hostinfo.vpnIp).
|
//TODO: we should favor v2 over v1 certificates if configured to send them
|
||||||
|
|
||||||
|
n.l.WithField("vpnAddrs", hostinfo.vpnAddrs).
|
||||||
WithField("reason", "local certificate is not current").
|
WithField("reason", "local certificate is not current").
|
||||||
Info("Re-handshaking with remote")
|
Info("Re-handshaking with remote")
|
||||||
|
|
||||||
n.intf.handshakeManager.StartHandshake(hostinfo.vpnIp, nil)
|
n.intf.handshakeManager.StartHandshake(hostinfo.vpnAddrs[0], nil)
|
||||||
}
|
}
|
||||||
|
|||||||
@ -34,20 +34,19 @@ func newTestLighthouse() *LightHouse {
|
|||||||
func Test_NewConnectionManagerTest(t *testing.T) {
|
func Test_NewConnectionManagerTest(t *testing.T) {
|
||||||
l := test.NewLogger()
|
l := test.NewLogger()
|
||||||
//_, tuncidr, _ := net.ParseCIDR("1.1.1.1/24")
|
//_, tuncidr, _ := net.ParseCIDR("1.1.1.1/24")
|
||||||
vpncidr := netip.MustParsePrefix("172.1.1.1/24")
|
|
||||||
localrange := netip.MustParsePrefix("10.1.1.1/24")
|
localrange := netip.MustParsePrefix("10.1.1.1/24")
|
||||||
vpnIp := netip.MustParseAddr("172.1.1.2")
|
vpnIp := netip.MustParseAddr("172.1.1.2")
|
||||||
preferredRanges := []netip.Prefix{localrange}
|
preferredRanges := []netip.Prefix{localrange}
|
||||||
|
|
||||||
// Very incomplete mock objects
|
// Very incomplete mock objects
|
||||||
hostMap := newHostMap(l, vpncidr)
|
hostMap := newHostMap(l)
|
||||||
hostMap.preferredRanges.Store(&preferredRanges)
|
hostMap.preferredRanges.Store(&preferredRanges)
|
||||||
|
|
||||||
cs := &CertState{
|
cs := &CertState{
|
||||||
RawCertificate: []byte{},
|
defaultVersion: cert.Version1,
|
||||||
PrivateKey: []byte{},
|
privateKey: []byte{},
|
||||||
Certificate: &dummyCert{},
|
v1Cert: &dummyCert{version: cert.Version1},
|
||||||
RawCertificateNoKey: []byte{},
|
v1HandshakeBytes: []byte{},
|
||||||
}
|
}
|
||||||
|
|
||||||
lh := newTestLighthouse()
|
lh := newTestLighthouse()
|
||||||
@ -74,12 +73,12 @@ func Test_NewConnectionManagerTest(t *testing.T) {
|
|||||||
|
|
||||||
// Add an ip we have established a connection w/ to hostmap
|
// Add an ip we have established a connection w/ to hostmap
|
||||||
hostinfo := &HostInfo{
|
hostinfo := &HostInfo{
|
||||||
vpnIp: vpnIp,
|
vpnAddrs: []netip.Addr{vpnIp},
|
||||||
localIndexId: 1099,
|
localIndexId: 1099,
|
||||||
remoteIndexId: 9901,
|
remoteIndexId: 9901,
|
||||||
}
|
}
|
||||||
hostinfo.ConnectionState = &ConnectionState{
|
hostinfo.ConnectionState = &ConnectionState{
|
||||||
myCert: &dummyCert{},
|
myCert: &dummyCert{version: cert.Version1},
|
||||||
H: &noise.HandshakeState{},
|
H: &noise.HandshakeState{},
|
||||||
}
|
}
|
||||||
nc.hostMap.unlockedAddHostInfo(hostinfo, ifce)
|
nc.hostMap.unlockedAddHostInfo(hostinfo, ifce)
|
||||||
@ -88,7 +87,7 @@ func Test_NewConnectionManagerTest(t *testing.T) {
|
|||||||
nc.Out(hostinfo.localIndexId)
|
nc.Out(hostinfo.localIndexId)
|
||||||
nc.In(hostinfo.localIndexId)
|
nc.In(hostinfo.localIndexId)
|
||||||
assert.NotContains(t, nc.pendingDeletion, hostinfo.localIndexId)
|
assert.NotContains(t, nc.pendingDeletion, hostinfo.localIndexId)
|
||||||
assert.Contains(t, nc.hostMap.Hosts, hostinfo.vpnIp)
|
assert.Contains(t, nc.hostMap.Hosts, hostinfo.vpnAddrs[0])
|
||||||
assert.Contains(t, nc.hostMap.Indexes, hostinfo.localIndexId)
|
assert.Contains(t, nc.hostMap.Indexes, hostinfo.localIndexId)
|
||||||
assert.Contains(t, nc.out, hostinfo.localIndexId)
|
assert.Contains(t, nc.out, hostinfo.localIndexId)
|
||||||
|
|
||||||
@ -105,32 +104,31 @@ func Test_NewConnectionManagerTest(t *testing.T) {
|
|||||||
assert.NotContains(t, nc.out, hostinfo.localIndexId)
|
assert.NotContains(t, nc.out, hostinfo.localIndexId)
|
||||||
assert.NotContains(t, nc.in, hostinfo.localIndexId)
|
assert.NotContains(t, nc.in, hostinfo.localIndexId)
|
||||||
assert.Contains(t, nc.hostMap.Indexes, hostinfo.localIndexId)
|
assert.Contains(t, nc.hostMap.Indexes, hostinfo.localIndexId)
|
||||||
assert.Contains(t, nc.hostMap.Hosts, hostinfo.vpnIp)
|
assert.Contains(t, nc.hostMap.Hosts, hostinfo.vpnAddrs[0])
|
||||||
|
|
||||||
// Do a final traffic check tick, the host should now be removed
|
// Do a final traffic check tick, the host should now be removed
|
||||||
nc.doTrafficCheck(hostinfo.localIndexId, p, nb, out, time.Now())
|
nc.doTrafficCheck(hostinfo.localIndexId, p, nb, out, time.Now())
|
||||||
assert.NotContains(t, nc.pendingDeletion, hostinfo.localIndexId)
|
assert.NotContains(t, nc.pendingDeletion, hostinfo.localIndexId)
|
||||||
assert.NotContains(t, nc.hostMap.Hosts, hostinfo.vpnIp)
|
assert.NotContains(t, nc.hostMap.Hosts, hostinfo.vpnAddrs[0])
|
||||||
assert.NotContains(t, nc.hostMap.Indexes, hostinfo.localIndexId)
|
assert.NotContains(t, nc.hostMap.Indexes, hostinfo.localIndexId)
|
||||||
}
|
}
|
||||||
|
|
||||||
func Test_NewConnectionManagerTest2(t *testing.T) {
|
func Test_NewConnectionManagerTest2(t *testing.T) {
|
||||||
l := test.NewLogger()
|
l := test.NewLogger()
|
||||||
//_, tuncidr, _ := net.ParseCIDR("1.1.1.1/24")
|
//_, tuncidr, _ := net.ParseCIDR("1.1.1.1/24")
|
||||||
vpncidr := netip.MustParsePrefix("172.1.1.1/24")
|
|
||||||
localrange := netip.MustParsePrefix("10.1.1.1/24")
|
localrange := netip.MustParsePrefix("10.1.1.1/24")
|
||||||
vpnIp := netip.MustParseAddr("172.1.1.2")
|
vpnIp := netip.MustParseAddr("172.1.1.2")
|
||||||
preferredRanges := []netip.Prefix{localrange}
|
preferredRanges := []netip.Prefix{localrange}
|
||||||
|
|
||||||
// Very incomplete mock objects
|
// Very incomplete mock objects
|
||||||
hostMap := newHostMap(l, vpncidr)
|
hostMap := newHostMap(l)
|
||||||
hostMap.preferredRanges.Store(&preferredRanges)
|
hostMap.preferredRanges.Store(&preferredRanges)
|
||||||
|
|
||||||
cs := &CertState{
|
cs := &CertState{
|
||||||
RawCertificate: []byte{},
|
defaultVersion: cert.Version1,
|
||||||
PrivateKey: []byte{},
|
privateKey: []byte{},
|
||||||
Certificate: &dummyCert{},
|
v1Cert: &dummyCert{version: cert.Version1},
|
||||||
RawCertificateNoKey: []byte{},
|
v1HandshakeBytes: []byte{},
|
||||||
}
|
}
|
||||||
|
|
||||||
lh := newTestLighthouse()
|
lh := newTestLighthouse()
|
||||||
@ -157,12 +155,12 @@ func Test_NewConnectionManagerTest2(t *testing.T) {
|
|||||||
|
|
||||||
// Add an ip we have established a connection w/ to hostmap
|
// Add an ip we have established a connection w/ to hostmap
|
||||||
hostinfo := &HostInfo{
|
hostinfo := &HostInfo{
|
||||||
vpnIp: vpnIp,
|
vpnAddrs: []netip.Addr{vpnIp},
|
||||||
localIndexId: 1099,
|
localIndexId: 1099,
|
||||||
remoteIndexId: 9901,
|
remoteIndexId: 9901,
|
||||||
}
|
}
|
||||||
hostinfo.ConnectionState = &ConnectionState{
|
hostinfo.ConnectionState = &ConnectionState{
|
||||||
myCert: &dummyCert{},
|
myCert: &dummyCert{version: cert.Version1},
|
||||||
H: &noise.HandshakeState{},
|
H: &noise.HandshakeState{},
|
||||||
}
|
}
|
||||||
nc.hostMap.unlockedAddHostInfo(hostinfo, ifce)
|
nc.hostMap.unlockedAddHostInfo(hostinfo, ifce)
|
||||||
@ -170,8 +168,8 @@ func Test_NewConnectionManagerTest2(t *testing.T) {
|
|||||||
// We saw traffic out to vpnIp
|
// We saw traffic out to vpnIp
|
||||||
nc.Out(hostinfo.localIndexId)
|
nc.Out(hostinfo.localIndexId)
|
||||||
nc.In(hostinfo.localIndexId)
|
nc.In(hostinfo.localIndexId)
|
||||||
assert.NotContains(t, nc.pendingDeletion, hostinfo.vpnIp)
|
assert.NotContains(t, nc.pendingDeletion, hostinfo.vpnAddrs[0])
|
||||||
assert.Contains(t, nc.hostMap.Hosts, hostinfo.vpnIp)
|
assert.Contains(t, nc.hostMap.Hosts, hostinfo.vpnAddrs[0])
|
||||||
assert.Contains(t, nc.hostMap.Indexes, hostinfo.localIndexId)
|
assert.Contains(t, nc.hostMap.Indexes, hostinfo.localIndexId)
|
||||||
|
|
||||||
// Do a traffic check tick, should not be pending deletion but should not have any in/out packets recorded
|
// Do a traffic check tick, should not be pending deletion but should not have any in/out packets recorded
|
||||||
@ -187,7 +185,7 @@ func Test_NewConnectionManagerTest2(t *testing.T) {
|
|||||||
assert.NotContains(t, nc.out, hostinfo.localIndexId)
|
assert.NotContains(t, nc.out, hostinfo.localIndexId)
|
||||||
assert.NotContains(t, nc.in, hostinfo.localIndexId)
|
assert.NotContains(t, nc.in, hostinfo.localIndexId)
|
||||||
assert.Contains(t, nc.hostMap.Indexes, hostinfo.localIndexId)
|
assert.Contains(t, nc.hostMap.Indexes, hostinfo.localIndexId)
|
||||||
assert.Contains(t, nc.hostMap.Hosts, hostinfo.vpnIp)
|
assert.Contains(t, nc.hostMap.Hosts, hostinfo.vpnAddrs[0])
|
||||||
|
|
||||||
// We saw traffic, should no longer be pending deletion
|
// We saw traffic, should no longer be pending deletion
|
||||||
nc.In(hostinfo.localIndexId)
|
nc.In(hostinfo.localIndexId)
|
||||||
@ -196,7 +194,7 @@ func Test_NewConnectionManagerTest2(t *testing.T) {
|
|||||||
assert.NotContains(t, nc.out, hostinfo.localIndexId)
|
assert.NotContains(t, nc.out, hostinfo.localIndexId)
|
||||||
assert.NotContains(t, nc.in, hostinfo.localIndexId)
|
assert.NotContains(t, nc.in, hostinfo.localIndexId)
|
||||||
assert.Contains(t, nc.hostMap.Indexes, hostinfo.localIndexId)
|
assert.Contains(t, nc.hostMap.Indexes, hostinfo.localIndexId)
|
||||||
assert.Contains(t, nc.hostMap.Hosts, hostinfo.vpnIp)
|
assert.Contains(t, nc.hostMap.Hosts, hostinfo.vpnAddrs[0])
|
||||||
}
|
}
|
||||||
|
|
||||||
// Check if we can disconnect the peer.
|
// Check if we can disconnect the peer.
|
||||||
@ -210,7 +208,7 @@ func Test_NewConnectionManagerTest_DisconnectInvalid(t *testing.T) {
|
|||||||
localrange := netip.MustParsePrefix("10.1.1.1/24")
|
localrange := netip.MustParsePrefix("10.1.1.1/24")
|
||||||
vpnIp := netip.MustParseAddr("172.1.1.2")
|
vpnIp := netip.MustParseAddr("172.1.1.2")
|
||||||
preferredRanges := []netip.Prefix{localrange}
|
preferredRanges := []netip.Prefix{localrange}
|
||||||
hostMap := newHostMap(l, vpncidr)
|
hostMap := newHostMap(l)
|
||||||
hostMap.preferredRanges.Store(&preferredRanges)
|
hostMap.preferredRanges.Store(&preferredRanges)
|
||||||
|
|
||||||
// Generate keys for CA and peer's cert.
|
// Generate keys for CA and peer's cert.
|
||||||
@ -244,10 +242,9 @@ func Test_NewConnectionManagerTest_DisconnectInvalid(t *testing.T) {
|
|||||||
cachedPeerCert, err := ncp.VerifyCertificate(now.Add(time.Second), peerCert)
|
cachedPeerCert, err := ncp.VerifyCertificate(now.Add(time.Second), peerCert)
|
||||||
|
|
||||||
cs := &CertState{
|
cs := &CertState{
|
||||||
RawCertificate: []byte{},
|
privateKey: []byte{},
|
||||||
PrivateKey: []byte{},
|
v1Cert: &dummyCert{},
|
||||||
Certificate: &dummyCert{},
|
v1HandshakeBytes: []byte{},
|
||||||
RawCertificateNoKey: []byte{},
|
|
||||||
}
|
}
|
||||||
|
|
||||||
lh := newTestLighthouse()
|
lh := newTestLighthouse()
|
||||||
@ -273,7 +270,7 @@ func Test_NewConnectionManagerTest_DisconnectInvalid(t *testing.T) {
|
|||||||
ifce.connectionManager = nc
|
ifce.connectionManager = nc
|
||||||
|
|
||||||
hostinfo := &HostInfo{
|
hostinfo := &HostInfo{
|
||||||
vpnIp: vpnIp,
|
vpnAddrs: []netip.Addr{vpnIp},
|
||||||
ConnectionState: &ConnectionState{
|
ConnectionState: &ConnectionState{
|
||||||
myCert: &dummyCert{},
|
myCert: &dummyCert{},
|
||||||
peerCert: cachedPeerCert,
|
peerCert: cachedPeerCert,
|
||||||
|
|||||||
@ -3,6 +3,7 @@ package nebula
|
|||||||
import (
|
import (
|
||||||
"crypto/rand"
|
"crypto/rand"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
"sync"
|
"sync"
|
||||||
"sync/atomic"
|
"sync/atomic"
|
||||||
|
|
||||||
@ -26,46 +27,46 @@ type ConnectionState struct {
|
|||||||
writeLock sync.Mutex
|
writeLock sync.Mutex
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewConnectionState(l *logrus.Logger, cipher string, certState *CertState, initiator bool, pattern noise.HandshakePattern, psk []byte, pskStage int) *ConnectionState {
|
func NewConnectionState(l *logrus.Logger, cs *CertState, crt cert.Certificate, initiator bool, pattern noise.HandshakePattern) (*ConnectionState, error) {
|
||||||
var dhFunc noise.DHFunc
|
var dhFunc noise.DHFunc
|
||||||
switch certState.Certificate.Curve() {
|
switch crt.Curve() {
|
||||||
case cert.Curve_CURVE25519:
|
case cert.Curve_CURVE25519:
|
||||||
dhFunc = noise.DH25519
|
dhFunc = noise.DH25519
|
||||||
case cert.Curve_P256:
|
case cert.Curve_P256:
|
||||||
if certState.pkcs11Backed {
|
if cs.pkcs11Backed {
|
||||||
dhFunc = noiseutil.DHP256PKCS11
|
dhFunc = noiseutil.DHP256PKCS11
|
||||||
} else {
|
} else {
|
||||||
dhFunc = noiseutil.DHP256
|
dhFunc = noiseutil.DHP256
|
||||||
}
|
}
|
||||||
default:
|
default:
|
||||||
l.Errorf("invalid curve: %s", certState.Certificate.Curve())
|
return nil, fmt.Errorf("invalid curve: %s", crt.Curve())
|
||||||
return nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
var cs noise.CipherSuite
|
var ncs noise.CipherSuite
|
||||||
if cipher == "chachapoly" {
|
if cs.cipher == "chachapoly" {
|
||||||
cs = noise.NewCipherSuite(dhFunc, noise.CipherChaChaPoly, noise.HashSHA256)
|
ncs = noise.NewCipherSuite(dhFunc, noise.CipherChaChaPoly, noise.HashSHA256)
|
||||||
} else {
|
} else {
|
||||||
cs = noise.NewCipherSuite(dhFunc, noiseutil.CipherAESGCM, noise.HashSHA256)
|
ncs = noise.NewCipherSuite(dhFunc, noiseutil.CipherAESGCM, noise.HashSHA256)
|
||||||
}
|
}
|
||||||
|
|
||||||
static := noise.DHKey{Private: certState.PrivateKey, Public: certState.PublicKey}
|
static := noise.DHKey{Private: cs.privateKey, Public: crt.PublicKey()}
|
||||||
|
|
||||||
b := NewBits(ReplayWindow)
|
b := NewBits(ReplayWindow)
|
||||||
// Clear out bit 0, we never transmit it and we don't want it showing as packet loss
|
// Clear out bit 0, we never transmit it, and we don't want it showing as packet loss
|
||||||
b.Update(l, 0)
|
b.Update(l, 0)
|
||||||
|
|
||||||
hs, err := noise.NewHandshakeState(noise.Config{
|
hs, err := noise.NewHandshakeState(noise.Config{
|
||||||
CipherSuite: cs,
|
CipherSuite: ncs,
|
||||||
Random: rand.Reader,
|
Random: rand.Reader,
|
||||||
Pattern: pattern,
|
Pattern: pattern,
|
||||||
Initiator: initiator,
|
Initiator: initiator,
|
||||||
StaticKeypair: static,
|
StaticKeypair: static,
|
||||||
PresharedKey: psk,
|
//NOTE: These should come from CertState (pki.go) when we finally implement it
|
||||||
PresharedKeyPlacement: pskStage,
|
PresharedKey: []byte{},
|
||||||
|
PresharedKeyPlacement: 0,
|
||||||
})
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil
|
return nil, fmt.Errorf("NewConnectionState: %s", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// The queue and ready params prevent a counter race that would happen when
|
// The queue and ready params prevent a counter race that would happen when
|
||||||
@ -74,12 +75,12 @@ func NewConnectionState(l *logrus.Logger, cipher string, certState *CertState, i
|
|||||||
H: hs,
|
H: hs,
|
||||||
initiator: initiator,
|
initiator: initiator,
|
||||||
window: b,
|
window: b,
|
||||||
myCert: certState.Certificate,
|
myCert: crt,
|
||||||
}
|
}
|
||||||
// always start the counter from 2, as packet 1 and packet 2 are handshake packets.
|
// always start the counter from 2, as packet 1 and packet 2 are handshake packets.
|
||||||
ci.messageCounter.Add(2)
|
ci.messageCounter.Add(2)
|
||||||
|
|
||||||
return ci
|
return ci, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (cs *ConnectionState) MarshalJSON() ([]byte, error) {
|
func (cs *ConnectionState) MarshalJSON() ([]byte, error) {
|
||||||
@ -89,3 +90,7 @@ func (cs *ConnectionState) MarshalJSON() ([]byte, error) {
|
|||||||
"message_counter": cs.messageCounter.Load(),
|
"message_counter": cs.messageCounter.Load(),
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (cs *ConnectionState) Curve() cert.Curve {
|
||||||
|
return cs.myCert.Curve()
|
||||||
|
}
|
||||||
|
|||||||
44
control.go
44
control.go
@ -19,9 +19,9 @@ import (
|
|||||||
type controlEach func(h *HostInfo)
|
type controlEach func(h *HostInfo)
|
||||||
|
|
||||||
type controlHostLister interface {
|
type controlHostLister interface {
|
||||||
QueryVpnIp(vpnIp netip.Addr) *HostInfo
|
QueryVpnAddr(vpnAddr netip.Addr) *HostInfo
|
||||||
ForEachIndex(each controlEach)
|
ForEachIndex(each controlEach)
|
||||||
ForEachVpnIp(each controlEach)
|
ForEachVpnAddr(each controlEach)
|
||||||
GetPreferredRanges() []netip.Prefix
|
GetPreferredRanges() []netip.Prefix
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -37,7 +37,7 @@ type Control struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
type ControlHostInfo struct {
|
type ControlHostInfo struct {
|
||||||
VpnIp netip.Addr `json:"vpnIp"`
|
VpnAddrs []netip.Addr `json:"vpnAddrs"`
|
||||||
LocalIndex uint32 `json:"localIndex"`
|
LocalIndex uint32 `json:"localIndex"`
|
||||||
RemoteIndex uint32 `json:"remoteIndex"`
|
RemoteIndex uint32 `json:"remoteIndex"`
|
||||||
RemoteAddrs []netip.AddrPort `json:"remoteAddrs"`
|
RemoteAddrs []netip.AddrPort `json:"remoteAddrs"`
|
||||||
@ -131,10 +131,13 @@ func (c *Control) ListHostmapIndexes(pendingMap bool) []ControlHostInfo {
|
|||||||
|
|
||||||
// GetCertByVpnIp returns the authenticated certificate of the given vpn IP, or nil if not found
|
// GetCertByVpnIp returns the authenticated certificate of the given vpn IP, or nil if not found
|
||||||
func (c *Control) GetCertByVpnIp(vpnIp netip.Addr) cert.Certificate {
|
func (c *Control) GetCertByVpnIp(vpnIp netip.Addr) cert.Certificate {
|
||||||
if c.f.myVpnNet.Addr() == vpnIp {
|
_, found := c.f.myVpnAddrsTable.Lookup(vpnIp)
|
||||||
return c.f.pki.GetCertState().Certificate.Copy()
|
if found {
|
||||||
|
//TODO: we might have 2 certs....
|
||||||
|
//TODO: this should return our latest version cert
|
||||||
|
return c.f.pki.getDefaultCertificate().Copy()
|
||||||
}
|
}
|
||||||
hi := c.f.hostMap.QueryVpnIp(vpnIp)
|
hi := c.f.hostMap.QueryVpnAddr(vpnIp)
|
||||||
if hi == nil {
|
if hi == nil {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
@ -148,7 +151,7 @@ func (c *Control) CreateTunnel(vpnIp netip.Addr) {
|
|||||||
|
|
||||||
// PrintTunnel creates a new tunnel to the given vpn ip.
|
// PrintTunnel creates a new tunnel to the given vpn ip.
|
||||||
func (c *Control) PrintTunnel(vpnIp netip.Addr) *ControlHostInfo {
|
func (c *Control) PrintTunnel(vpnIp netip.Addr) *ControlHostInfo {
|
||||||
hi := c.f.hostMap.QueryVpnIp(vpnIp)
|
hi := c.f.hostMap.QueryVpnAddr(vpnIp)
|
||||||
if hi == nil {
|
if hi == nil {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
@ -165,9 +168,9 @@ func (c *Control) QueryLighthouse(vpnIp netip.Addr) *CacheMap {
|
|||||||
return hi.CopyCache()
|
return hi.CopyCache()
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetHostInfoByVpnIp returns a single tunnels hostInfo, or nil if not found
|
// GetHostInfoByVpnAddr returns a single tunnels hostInfo, or nil if not found
|
||||||
// Caller should take care to Unmap() any 4in6 addresses prior to calling.
|
// Caller should take care to Unmap() any 4in6 addresses prior to calling.
|
||||||
func (c *Control) GetHostInfoByVpnIp(vpnIp netip.Addr, pending bool) *ControlHostInfo {
|
func (c *Control) GetHostInfoByVpnAddr(vpnAddr netip.Addr, pending bool) *ControlHostInfo {
|
||||||
var hl controlHostLister
|
var hl controlHostLister
|
||||||
if pending {
|
if pending {
|
||||||
hl = c.f.handshakeManager
|
hl = c.f.handshakeManager
|
||||||
@ -175,7 +178,7 @@ func (c *Control) GetHostInfoByVpnIp(vpnIp netip.Addr, pending bool) *ControlHos
|
|||||||
hl = c.f.hostMap
|
hl = c.f.hostMap
|
||||||
}
|
}
|
||||||
|
|
||||||
h := hl.QueryVpnIp(vpnIp)
|
h := hl.QueryVpnAddr(vpnAddr)
|
||||||
if h == nil {
|
if h == nil {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
@ -187,7 +190,7 @@ func (c *Control) GetHostInfoByVpnIp(vpnIp netip.Addr, pending bool) *ControlHos
|
|||||||
// SetRemoteForTunnel forces a tunnel to use a specific remote
|
// SetRemoteForTunnel forces a tunnel to use a specific remote
|
||||||
// Caller should take care to Unmap() any 4in6 addresses prior to calling.
|
// Caller should take care to Unmap() any 4in6 addresses prior to calling.
|
||||||
func (c *Control) SetRemoteForTunnel(vpnIp netip.Addr, addr netip.AddrPort) *ControlHostInfo {
|
func (c *Control) SetRemoteForTunnel(vpnIp netip.Addr, addr netip.AddrPort) *ControlHostInfo {
|
||||||
hostInfo := c.f.hostMap.QueryVpnIp(vpnIp)
|
hostInfo := c.f.hostMap.QueryVpnAddr(vpnIp)
|
||||||
if hostInfo == nil {
|
if hostInfo == nil {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
@ -200,7 +203,7 @@ func (c *Control) SetRemoteForTunnel(vpnIp netip.Addr, addr netip.AddrPort) *Con
|
|||||||
// CloseTunnel closes a fully established tunnel. If localOnly is false it will notify the remote end as well.
|
// CloseTunnel closes a fully established tunnel. If localOnly is false it will notify the remote end as well.
|
||||||
// Caller should take care to Unmap() any 4in6 addresses prior to calling.
|
// Caller should take care to Unmap() any 4in6 addresses prior to calling.
|
||||||
func (c *Control) CloseTunnel(vpnIp netip.Addr, localOnly bool) bool {
|
func (c *Control) CloseTunnel(vpnIp netip.Addr, localOnly bool) bool {
|
||||||
hostInfo := c.f.hostMap.QueryVpnIp(vpnIp)
|
hostInfo := c.f.hostMap.QueryVpnAddr(vpnIp)
|
||||||
if hostInfo == nil {
|
if hostInfo == nil {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
@ -229,14 +232,14 @@ func (c *Control) CloseAllTunnels(excludeLighthouses bool) (closed int) {
|
|||||||
|
|
||||||
shutdown := func(h *HostInfo) {
|
shutdown := func(h *HostInfo) {
|
||||||
if excludeLighthouses {
|
if excludeLighthouses {
|
||||||
if _, ok := lighthouses[h.vpnIp]; ok {
|
if _, ok := lighthouses[h.vpnAddrs[0]]; ok {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
c.f.send(header.CloseTunnel, 0, h.ConnectionState, h, []byte{}, make([]byte, 12, 12), make([]byte, mtu))
|
c.f.send(header.CloseTunnel, 0, h.ConnectionState, h, []byte{}, make([]byte, 12, 12), make([]byte, mtu))
|
||||||
c.f.closeTunnel(h)
|
c.f.closeTunnel(h)
|
||||||
|
|
||||||
c.l.WithField("vpnIp", h.vpnIp).WithField("udpAddr", h.remote).
|
c.l.WithField("vpnIp", h.vpnAddrs[0]).WithField("udpAddr", h.remote).
|
||||||
Debug("Sending close tunnel message")
|
Debug("Sending close tunnel message")
|
||||||
closed++
|
closed++
|
||||||
}
|
}
|
||||||
@ -246,7 +249,7 @@ func (c *Control) CloseAllTunnels(excludeLighthouses bool) (closed int) {
|
|||||||
// Grab the hostMap lock to access the Relays map
|
// Grab the hostMap lock to access the Relays map
|
||||||
c.f.hostMap.Lock()
|
c.f.hostMap.Lock()
|
||||||
for _, relayingHost := range c.f.hostMap.Relays {
|
for _, relayingHost := range c.f.hostMap.Relays {
|
||||||
relayingHosts[relayingHost.vpnIp] = relayingHost
|
relayingHosts[relayingHost.vpnAddrs[0]] = relayingHost
|
||||||
}
|
}
|
||||||
c.f.hostMap.Unlock()
|
c.f.hostMap.Unlock()
|
||||||
|
|
||||||
@ -254,7 +257,7 @@ func (c *Control) CloseAllTunnels(excludeLighthouses bool) (closed int) {
|
|||||||
// Grab the hostMap lock to access the Hosts map
|
// Grab the hostMap lock to access the Hosts map
|
||||||
c.f.hostMap.Lock()
|
c.f.hostMap.Lock()
|
||||||
for _, relayHost := range c.f.hostMap.Indexes {
|
for _, relayHost := range c.f.hostMap.Indexes {
|
||||||
if _, ok := relayingHosts[relayHost.vpnIp]; !ok {
|
if _, ok := relayingHosts[relayHost.vpnAddrs[0]]; !ok {
|
||||||
hostInfos = append(hostInfos, relayHost)
|
hostInfos = append(hostInfos, relayHost)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -274,9 +277,8 @@ func (c *Control) Device() overlay.Device {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func copyHostInfo(h *HostInfo, preferredRanges []netip.Prefix) ControlHostInfo {
|
func copyHostInfo(h *HostInfo, preferredRanges []netip.Prefix) ControlHostInfo {
|
||||||
|
|
||||||
chi := ControlHostInfo{
|
chi := ControlHostInfo{
|
||||||
VpnIp: h.vpnIp,
|
VpnAddrs: make([]netip.Addr, len(h.vpnAddrs)),
|
||||||
LocalIndex: h.localIndexId,
|
LocalIndex: h.localIndexId,
|
||||||
RemoteIndex: h.remoteIndexId,
|
RemoteIndex: h.remoteIndexId,
|
||||||
RemoteAddrs: h.remotes.CopyAddrs(preferredRanges),
|
RemoteAddrs: h.remotes.CopyAddrs(preferredRanges),
|
||||||
@ -285,6 +287,10 @@ func copyHostInfo(h *HostInfo, preferredRanges []netip.Prefix) ControlHostInfo {
|
|||||||
CurrentRemote: h.remote,
|
CurrentRemote: h.remote,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
for i, a := range h.vpnAddrs {
|
||||||
|
chi.VpnAddrs[i] = a
|
||||||
|
}
|
||||||
|
|
||||||
if h.ConnectionState != nil {
|
if h.ConnectionState != nil {
|
||||||
chi.MessageCounter = h.ConnectionState.messageCounter.Load()
|
chi.MessageCounter = h.ConnectionState.messageCounter.Load()
|
||||||
}
|
}
|
||||||
@ -299,7 +305,7 @@ func copyHostInfo(h *HostInfo, preferredRanges []netip.Prefix) ControlHostInfo {
|
|||||||
func listHostMapHosts(hl controlHostLister) []ControlHostInfo {
|
func listHostMapHosts(hl controlHostLister) []ControlHostInfo {
|
||||||
hosts := make([]ControlHostInfo, 0)
|
hosts := make([]ControlHostInfo, 0)
|
||||||
pr := hl.GetPreferredRanges()
|
pr := hl.GetPreferredRanges()
|
||||||
hl.ForEachVpnIp(func(hostinfo *HostInfo) {
|
hl.ForEachVpnAddr(func(hostinfo *HostInfo) {
|
||||||
hosts = append(hosts, copyHostInfo(hostinfo, pr))
|
hosts = append(hosts, copyHostInfo(hostinfo, pr))
|
||||||
})
|
})
|
||||||
return hosts
|
return hosts
|
||||||
|
|||||||
@ -19,7 +19,7 @@ func TestControl_GetHostInfoByVpnIp(t *testing.T) {
|
|||||||
l := test.NewLogger()
|
l := test.NewLogger()
|
||||||
// Special care must be taken to re-use all objects provided to the hostmap and certificate in the expectedInfo object
|
// Special care must be taken to re-use all objects provided to the hostmap and certificate in the expectedInfo object
|
||||||
// To properly ensure we are not exposing core memory to the caller
|
// To properly ensure we are not exposing core memory to the caller
|
||||||
hm := newHostMap(l, netip.Prefix{})
|
hm := newHostMap(l)
|
||||||
hm.preferredRanges.Store(&[]netip.Prefix{})
|
hm.preferredRanges.Store(&[]netip.Prefix{})
|
||||||
|
|
||||||
remote1 := netip.MustParseAddrPort("0.0.0.100:4444")
|
remote1 := netip.MustParseAddrPort("0.0.0.100:4444")
|
||||||
@ -35,9 +35,9 @@ func TestControl_GetHostInfoByVpnIp(t *testing.T) {
|
|||||||
Mask: net.IPMask{255, 255, 255, 0},
|
Mask: net.IPMask{255, 255, 255, 0},
|
||||||
}
|
}
|
||||||
|
|
||||||
remotes := NewRemoteList(nil)
|
remotes := NewRemoteList([]netip.Addr{netip.IPv4Unspecified()}, nil)
|
||||||
remotes.unlockedPrependV4(netip.IPv4Unspecified(), NewIp4AndPortFromNetIP(remote1.Addr(), remote1.Port()))
|
remotes.unlockedPrependV4(netip.IPv4Unspecified(), netAddrToProtoV4AddrPort(remote1.Addr(), remote1.Port()))
|
||||||
remotes.unlockedPrependV6(netip.IPv4Unspecified(), NewIp6AndPortFromNetIP(remote2.Addr(), remote2.Port()))
|
remotes.unlockedPrependV6(netip.IPv4Unspecified(), netAddrToProtoV6AddrPort(remote2.Addr(), remote2.Port()))
|
||||||
|
|
||||||
vpnIp, ok := netip.AddrFromSlice(ipNet.IP)
|
vpnIp, ok := netip.AddrFromSlice(ipNet.IP)
|
||||||
assert.True(t, ok)
|
assert.True(t, ok)
|
||||||
@ -51,10 +51,10 @@ func TestControl_GetHostInfoByVpnIp(t *testing.T) {
|
|||||||
},
|
},
|
||||||
remoteIndexId: 200,
|
remoteIndexId: 200,
|
||||||
localIndexId: 201,
|
localIndexId: 201,
|
||||||
vpnIp: vpnIp,
|
vpnAddrs: []netip.Addr{vpnIp},
|
||||||
relayState: RelayState{
|
relayState: RelayState{
|
||||||
relays: map[netip.Addr]struct{}{},
|
relays: map[netip.Addr]struct{}{},
|
||||||
relayForByIp: map[netip.Addr]*Relay{},
|
relayForByAddr: map[netip.Addr]*Relay{},
|
||||||
relayForByIdx: map[uint32]*Relay{},
|
relayForByIdx: map[uint32]*Relay{},
|
||||||
},
|
},
|
||||||
}, &Interface{})
|
}, &Interface{})
|
||||||
@ -70,10 +70,10 @@ func TestControl_GetHostInfoByVpnIp(t *testing.T) {
|
|||||||
},
|
},
|
||||||
remoteIndexId: 200,
|
remoteIndexId: 200,
|
||||||
localIndexId: 201,
|
localIndexId: 201,
|
||||||
vpnIp: vpnIp2,
|
vpnAddrs: []netip.Addr{vpnIp2},
|
||||||
relayState: RelayState{
|
relayState: RelayState{
|
||||||
relays: map[netip.Addr]struct{}{},
|
relays: map[netip.Addr]struct{}{},
|
||||||
relayForByIp: map[netip.Addr]*Relay{},
|
relayForByAddr: map[netip.Addr]*Relay{},
|
||||||
relayForByIdx: map[uint32]*Relay{},
|
relayForByIdx: map[uint32]*Relay{},
|
||||||
},
|
},
|
||||||
}, &Interface{})
|
}, &Interface{})
|
||||||
@ -85,10 +85,10 @@ func TestControl_GetHostInfoByVpnIp(t *testing.T) {
|
|||||||
l: logrus.New(),
|
l: logrus.New(),
|
||||||
}
|
}
|
||||||
|
|
||||||
thi := c.GetHostInfoByVpnIp(vpnIp, false)
|
thi := c.GetHostInfoByVpnAddr(vpnIp, false)
|
||||||
|
|
||||||
expectedInfo := ControlHostInfo{
|
expectedInfo := ControlHostInfo{
|
||||||
VpnIp: vpnIp,
|
VpnAddrs: []netip.Addr{vpnIp},
|
||||||
LocalIndex: 201,
|
LocalIndex: 201,
|
||||||
RemoteIndex: 200,
|
RemoteIndex: 200,
|
||||||
RemoteAddrs: []netip.AddrPort{remote2, remote1},
|
RemoteAddrs: []netip.AddrPort{remote2, remote1},
|
||||||
@ -100,13 +100,13 @@ func TestControl_GetHostInfoByVpnIp(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Make sure we don't have any unexpected fields
|
// Make sure we don't have any unexpected fields
|
||||||
assertFields(t, []string{"VpnIp", "LocalIndex", "RemoteIndex", "RemoteAddrs", "Cert", "MessageCounter", "CurrentRemote", "CurrentRelaysToMe", "CurrentRelaysThroughMe"}, thi)
|
assertFields(t, []string{"VpnAddrs", "LocalIndex", "RemoteIndex", "RemoteAddrs", "Cert", "MessageCounter", "CurrentRemote", "CurrentRelaysToMe", "CurrentRelaysThroughMe"}, thi)
|
||||||
assert.EqualValues(t, &expectedInfo, thi)
|
assert.EqualValues(t, &expectedInfo, thi)
|
||||||
test.AssertDeepCopyEqual(t, &expectedInfo, thi)
|
test.AssertDeepCopyEqual(t, &expectedInfo, thi)
|
||||||
|
|
||||||
// Make sure we don't panic if the host info doesn't have a cert yet
|
// Make sure we don't panic if the host info doesn't have a cert yet
|
||||||
assert.NotPanics(t, func() {
|
assert.NotPanics(t, func() {
|
||||||
thi = c.GetHostInfoByVpnIp(vpnIp2, false)
|
thi = c.GetHostInfoByVpnAddr(vpnIp2, false)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@ -6,8 +6,6 @@ package nebula
|
|||||||
import (
|
import (
|
||||||
"net/netip"
|
"net/netip"
|
||||||
|
|
||||||
"github.com/slackhq/nebula/cert"
|
|
||||||
|
|
||||||
"github.com/google/gopacket"
|
"github.com/google/gopacket"
|
||||||
"github.com/google/gopacket/layers"
|
"github.com/google/gopacket/layers"
|
||||||
"github.com/slackhq/nebula/header"
|
"github.com/slackhq/nebula/header"
|
||||||
@ -51,15 +49,15 @@ func (c *Control) WaitForTypeByIndex(toIndex uint32, msgType header.MessageType,
|
|||||||
// This is necessary if you did not configure static hosts or are not running a lighthouse
|
// This is necessary if you did not configure static hosts or are not running a lighthouse
|
||||||
func (c *Control) InjectLightHouseAddr(vpnIp netip.Addr, toAddr netip.AddrPort) {
|
func (c *Control) InjectLightHouseAddr(vpnIp netip.Addr, toAddr netip.AddrPort) {
|
||||||
c.f.lightHouse.Lock()
|
c.f.lightHouse.Lock()
|
||||||
remoteList := c.f.lightHouse.unlockedGetRemoteList(vpnIp)
|
remoteList := c.f.lightHouse.unlockedGetRemoteList([]netip.Addr{vpnIp})
|
||||||
remoteList.Lock()
|
remoteList.Lock()
|
||||||
defer remoteList.Unlock()
|
defer remoteList.Unlock()
|
||||||
c.f.lightHouse.Unlock()
|
c.f.lightHouse.Unlock()
|
||||||
|
|
||||||
if toAddr.Addr().Is4() {
|
if toAddr.Addr().Is4() {
|
||||||
remoteList.unlockedPrependV4(vpnIp, NewIp4AndPortFromNetIP(toAddr.Addr(), toAddr.Port()))
|
remoteList.unlockedPrependV4(vpnIp, netAddrToProtoV4AddrPort(toAddr.Addr(), toAddr.Port()))
|
||||||
} else {
|
} else {
|
||||||
remoteList.unlockedPrependV6(vpnIp, NewIp6AndPortFromNetIP(toAddr.Addr(), toAddr.Port()))
|
remoteList.unlockedPrependV6(vpnIp, netAddrToProtoV6AddrPort(toAddr.Addr(), toAddr.Port()))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -67,7 +65,7 @@ func (c *Control) InjectLightHouseAddr(vpnIp netip.Addr, toAddr netip.AddrPort)
|
|||||||
// This is necessary to inform an initiator of possible relays for communicating with a responder
|
// This is necessary to inform an initiator of possible relays for communicating with a responder
|
||||||
func (c *Control) InjectRelays(vpnIp netip.Addr, relayVpnIps []netip.Addr) {
|
func (c *Control) InjectRelays(vpnIp netip.Addr, relayVpnIps []netip.Addr) {
|
||||||
c.f.lightHouse.Lock()
|
c.f.lightHouse.Lock()
|
||||||
remoteList := c.f.lightHouse.unlockedGetRemoteList(vpnIp)
|
remoteList := c.f.lightHouse.unlockedGetRemoteList([]netip.Addr{vpnIp})
|
||||||
remoteList.Lock()
|
remoteList.Lock()
|
||||||
defer remoteList.Unlock()
|
defer remoteList.Unlock()
|
||||||
c.f.lightHouse.Unlock()
|
c.f.lightHouse.Unlock()
|
||||||
@ -99,21 +97,42 @@ func (c *Control) InjectUDPPacket(p *udp.Packet) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// InjectTunUDPPacket puts a udp packet on the tun interface. Using UDP here because it's a simpler protocol
|
// InjectTunUDPPacket puts a udp packet on the tun interface. Using UDP here because it's a simpler protocol
|
||||||
func (c *Control) InjectTunUDPPacket(toIp netip.Addr, toPort uint16, fromPort uint16, data []byte) {
|
func (c *Control) InjectTunUDPPacket(toAddr netip.Addr, toPort uint16, fromAddr netip.Addr, fromPort uint16, data []byte) {
|
||||||
//TODO: IPV6-WORK
|
serialize := make([]gopacket.SerializableLayer, 0)
|
||||||
ip := layers.IPv4{
|
var netLayer gopacket.NetworkLayer
|
||||||
|
if toAddr.Is6() {
|
||||||
|
if !fromAddr.Is6() {
|
||||||
|
panic("Cant send ipv6 to ipv4")
|
||||||
|
}
|
||||||
|
ip := &layers.IPv6{
|
||||||
|
Version: 6,
|
||||||
|
NextHeader: layers.IPProtocolUDP,
|
||||||
|
SrcIP: fromAddr.Unmap().AsSlice(),
|
||||||
|
DstIP: toAddr.Unmap().AsSlice(),
|
||||||
|
}
|
||||||
|
serialize = append(serialize, ip)
|
||||||
|
netLayer = ip
|
||||||
|
} else {
|
||||||
|
if !fromAddr.Is4() {
|
||||||
|
panic("Cant send ipv4 to ipv6")
|
||||||
|
}
|
||||||
|
|
||||||
|
ip := &layers.IPv4{
|
||||||
Version: 4,
|
Version: 4,
|
||||||
TTL: 64,
|
TTL: 64,
|
||||||
Protocol: layers.IPProtocolUDP,
|
Protocol: layers.IPProtocolUDP,
|
||||||
SrcIP: c.f.inside.Cidr().Addr().Unmap().AsSlice(),
|
SrcIP: fromAddr.Unmap().AsSlice(),
|
||||||
DstIP: toIp.Unmap().AsSlice(),
|
DstIP: toAddr.Unmap().AsSlice(),
|
||||||
|
}
|
||||||
|
serialize = append(serialize, ip)
|
||||||
|
netLayer = ip
|
||||||
}
|
}
|
||||||
|
|
||||||
udp := layers.UDP{
|
udp := layers.UDP{
|
||||||
SrcPort: layers.UDPPort(fromPort),
|
SrcPort: layers.UDPPort(fromPort),
|
||||||
DstPort: layers.UDPPort(toPort),
|
DstPort: layers.UDPPort(toPort),
|
||||||
}
|
}
|
||||||
err := udp.SetNetworkLayerForChecksum(&ip)
|
err := udp.SetNetworkLayerForChecksum(netLayer)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
panic(err)
|
panic(err)
|
||||||
}
|
}
|
||||||
@ -123,7 +142,9 @@ func (c *Control) InjectTunUDPPacket(toIp netip.Addr, toPort uint16, fromPort ui
|
|||||||
ComputeChecksums: true,
|
ComputeChecksums: true,
|
||||||
FixLengths: true,
|
FixLengths: true,
|
||||||
}
|
}
|
||||||
err = gopacket.SerializeLayers(buffer, opt, &ip, &udp, gopacket.Payload(data))
|
|
||||||
|
serialize = append(serialize, &udp, gopacket.Payload(data))
|
||||||
|
err = gopacket.SerializeLayers(buffer, opt, serialize...)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
panic(err)
|
panic(err)
|
||||||
}
|
}
|
||||||
@ -131,8 +152,8 @@ func (c *Control) InjectTunUDPPacket(toIp netip.Addr, toPort uint16, fromPort ui
|
|||||||
c.f.inside.(*overlay.TestTun).Send(buffer.Bytes())
|
c.f.inside.(*overlay.TestTun).Send(buffer.Bytes())
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Control) GetVpnIp() netip.Addr {
|
func (c *Control) GetVpnAddrs() []netip.Addr {
|
||||||
return c.f.myVpnNet.Addr()
|
return c.f.myVpnAddrs
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Control) GetUDPAddr() netip.AddrPort {
|
func (c *Control) GetUDPAddr() netip.AddrPort {
|
||||||
@ -140,7 +161,7 @@ func (c *Control) GetUDPAddr() netip.AddrPort {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (c *Control) KillPendingTunnel(vpnIp netip.Addr) bool {
|
func (c *Control) KillPendingTunnel(vpnIp netip.Addr) bool {
|
||||||
hostinfo := c.f.handshakeManager.QueryVpnIp(vpnIp)
|
hostinfo := c.f.handshakeManager.QueryVpnAddr(vpnIp)
|
||||||
if hostinfo == nil {
|
if hostinfo == nil {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
@ -153,8 +174,8 @@ func (c *Control) GetHostmap() *HostMap {
|
|||||||
return c.f.hostMap
|
return c.f.hostMap
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Control) GetCert() cert.Certificate {
|
func (c *Control) GetCertState() *CertState {
|
||||||
return c.f.pki.GetCertState().Certificate
|
return c.f.pki.getCertState()
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Control) ReHandshake(vpnIp netip.Addr) {
|
func (c *Control) ReHandshake(vpnIp netip.Addr) {
|
||||||
|
|||||||
104
dns_server.go
104
dns_server.go
@ -8,6 +8,7 @@ import (
|
|||||||
"strings"
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
|
|
||||||
|
"github.com/gaissmai/bart"
|
||||||
"github.com/miekg/dns"
|
"github.com/miekg/dns"
|
||||||
"github.com/sirupsen/logrus"
|
"github.com/sirupsen/logrus"
|
||||||
"github.com/slackhq/nebula/config"
|
"github.com/slackhq/nebula/config"
|
||||||
@ -21,24 +22,39 @@ var dnsAddr string
|
|||||||
|
|
||||||
type dnsRecords struct {
|
type dnsRecords struct {
|
||||||
sync.RWMutex
|
sync.RWMutex
|
||||||
dnsMap map[string]string
|
l *logrus.Logger
|
||||||
|
dnsMap4 map[string]netip.Addr
|
||||||
|
dnsMap6 map[string]netip.Addr
|
||||||
hostMap *HostMap
|
hostMap *HostMap
|
||||||
|
myVpnAddrsTable *bart.Table[struct{}]
|
||||||
}
|
}
|
||||||
|
|
||||||
func newDnsRecords(hostMap *HostMap) *dnsRecords {
|
func newDnsRecords(l *logrus.Logger, cs *CertState, hostMap *HostMap) *dnsRecords {
|
||||||
return &dnsRecords{
|
return &dnsRecords{
|
||||||
dnsMap: make(map[string]string),
|
l: l,
|
||||||
|
dnsMap4: make(map[string]netip.Addr),
|
||||||
|
dnsMap6: make(map[string]netip.Addr),
|
||||||
hostMap: hostMap,
|
hostMap: hostMap,
|
||||||
|
myVpnAddrsTable: cs.myVpnAddrsTable,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (d *dnsRecords) Query(data string) string {
|
func (d *dnsRecords) Query(q uint16, data string) netip.Addr {
|
||||||
|
data = strings.ToLower(data)
|
||||||
d.RLock()
|
d.RLock()
|
||||||
defer d.RUnlock()
|
defer d.RUnlock()
|
||||||
if r, ok := d.dnsMap[strings.ToLower(data)]; ok {
|
switch q {
|
||||||
|
case dns.TypeA:
|
||||||
|
if r, ok := d.dnsMap4[data]; ok {
|
||||||
return r
|
return r
|
||||||
}
|
}
|
||||||
return ""
|
case dns.TypeAAAA:
|
||||||
|
if r, ok := d.dnsMap6[data]; ok {
|
||||||
|
return r
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return netip.Addr{}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (d *dnsRecords) QueryCert(data string) string {
|
func (d *dnsRecords) QueryCert(data string) string {
|
||||||
@ -47,7 +63,7 @@ func (d *dnsRecords) QueryCert(data string) string {
|
|||||||
return ""
|
return ""
|
||||||
}
|
}
|
||||||
|
|
||||||
hostinfo := d.hostMap.QueryVpnIp(ip)
|
hostinfo := d.hostMap.QueryVpnAddr(ip)
|
||||||
if hostinfo == nil {
|
if hostinfo == nil {
|
||||||
return ""
|
return ""
|
||||||
}
|
}
|
||||||
@ -64,38 +80,62 @@ func (d *dnsRecords) QueryCert(data string) string {
|
|||||||
return string(b)
|
return string(b)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (d *dnsRecords) Add(host, data string) {
|
// Add adds the first IPv4 and IPv6 address that appears in `addresses` as the record for `host`
|
||||||
|
func (d *dnsRecords) Add(host string, addresses []netip.Addr) {
|
||||||
|
host = strings.ToLower(host)
|
||||||
d.Lock()
|
d.Lock()
|
||||||
defer d.Unlock()
|
defer d.Unlock()
|
||||||
d.dnsMap[strings.ToLower(host)] = data
|
haveV4 := false
|
||||||
|
haveV6 := false
|
||||||
|
for _, addr := range addresses {
|
||||||
|
if addr.Is4() && !haveV4 {
|
||||||
|
d.dnsMap4[host] = addr
|
||||||
|
haveV4 = true
|
||||||
|
} else if addr.Is6() && !haveV6 {
|
||||||
|
d.dnsMap6[host] = addr
|
||||||
|
haveV6 = true
|
||||||
|
}
|
||||||
|
if haveV4 && haveV6 {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func parseQuery(l *logrus.Logger, m *dns.Msg, w dns.ResponseWriter) {
|
func (d *dnsRecords) isSelfNebulaOrLocalhost(addr string) bool {
|
||||||
|
a, _, _ := net.SplitHostPort(addr)
|
||||||
|
b, err := netip.ParseAddr(a)
|
||||||
|
if err != nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
if b.IsLoopback() {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
_, found := d.myVpnAddrsTable.Lookup(b)
|
||||||
|
return found //if we found it in this table, it's good
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *dnsRecords) parseQuery(m *dns.Msg, w dns.ResponseWriter) {
|
||||||
for _, q := range m.Question {
|
for _, q := range m.Question {
|
||||||
switch q.Qtype {
|
switch q.Qtype {
|
||||||
case dns.TypeA:
|
case dns.TypeA, dns.TypeAAAA:
|
||||||
l.Debugf("Query for A %s", q.Name)
|
qType := dns.TypeToString[q.Qtype]
|
||||||
ip := dnsR.Query(q.Name)
|
d.l.Debugf("Query for %s %s", qType, q.Name)
|
||||||
if ip != "" {
|
ip := d.Query(q.Qtype, q.Name)
|
||||||
rr, err := dns.NewRR(fmt.Sprintf("%s A %s", q.Name, ip))
|
if ip.IsValid() {
|
||||||
|
rr, err := dns.NewRR(fmt.Sprintf("%s %s %s", q.Name, qType, ip))
|
||||||
if err == nil {
|
if err == nil {
|
||||||
m.Answer = append(m.Answer, rr)
|
m.Answer = append(m.Answer, rr)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
case dns.TypeTXT:
|
case dns.TypeTXT:
|
||||||
a, _, _ := net.SplitHostPort(w.RemoteAddr().String())
|
// We only answer these queries from nebula nodes or localhost
|
||||||
b, err := netip.ParseAddr(a)
|
if !d.isSelfNebulaOrLocalhost(w.RemoteAddr().String()) {
|
||||||
if err != nil {
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
d.l.Debugf("Query for TXT %s", q.Name)
|
||||||
// We don't answer these queries from non nebula nodes or localhost
|
ip := d.QueryCert(q.Name)
|
||||||
//l.Debugf("Does %s contain %s", b, dnsR.hostMap.vpnCIDR)
|
|
||||||
if !dnsR.hostMap.vpnCIDR.Contains(b) && a != "127.0.0.1" {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
l.Debugf("Query for TXT %s", q.Name)
|
|
||||||
ip := dnsR.QueryCert(q.Name)
|
|
||||||
if ip != "" {
|
if ip != "" {
|
||||||
rr, err := dns.NewRR(fmt.Sprintf("%s TXT %s", q.Name, ip))
|
rr, err := dns.NewRR(fmt.Sprintf("%s TXT %s", q.Name, ip))
|
||||||
if err == nil {
|
if err == nil {
|
||||||
@ -110,26 +150,24 @@ func parseQuery(l *logrus.Logger, m *dns.Msg, w dns.ResponseWriter) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func handleDnsRequest(l *logrus.Logger, w dns.ResponseWriter, r *dns.Msg) {
|
func (d *dnsRecords) handleDnsRequest(w dns.ResponseWriter, r *dns.Msg) {
|
||||||
m := new(dns.Msg)
|
m := new(dns.Msg)
|
||||||
m.SetReply(r)
|
m.SetReply(r)
|
||||||
m.Compress = false
|
m.Compress = false
|
||||||
|
|
||||||
switch r.Opcode {
|
switch r.Opcode {
|
||||||
case dns.OpcodeQuery:
|
case dns.OpcodeQuery:
|
||||||
parseQuery(l, m, w)
|
d.parseQuery(m, w)
|
||||||
}
|
}
|
||||||
|
|
||||||
w.WriteMsg(m)
|
w.WriteMsg(m)
|
||||||
}
|
}
|
||||||
|
|
||||||
func dnsMain(l *logrus.Logger, hostMap *HostMap, c *config.C) func() {
|
func dnsMain(l *logrus.Logger, cs *CertState, hostMap *HostMap, c *config.C) func() {
|
||||||
dnsR = newDnsRecords(hostMap)
|
dnsR = newDnsRecords(l, cs, hostMap)
|
||||||
|
|
||||||
// attach request handler func
|
// attach request handler func
|
||||||
dns.HandleFunc(".", func(w dns.ResponseWriter, r *dns.Msg) {
|
dns.HandleFunc(".", dnsR.handleDnsRequest)
|
||||||
handleDnsRequest(l, w, r)
|
|
||||||
})
|
|
||||||
|
|
||||||
c.RegisterReloadCallback(func(c *config.C) {
|
c.RegisterReloadCallback(func(c *config.C) {
|
||||||
reloadDns(l, c)
|
reloadDns(l, c)
|
||||||
|
|||||||
@ -1,23 +1,38 @@
|
|||||||
package nebula
|
package nebula
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"net/netip"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/miekg/dns"
|
"github.com/miekg/dns"
|
||||||
|
"github.com/sirupsen/logrus"
|
||||||
"github.com/slackhq/nebula/config"
|
"github.com/slackhq/nebula/config"
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestParsequery(t *testing.T) {
|
func TestParsequery(t *testing.T) {
|
||||||
//TODO: This test is basically pointless
|
l := logrus.New()
|
||||||
hostMap := &HostMap{}
|
hostMap := &HostMap{}
|
||||||
ds := newDnsRecords(hostMap)
|
ds := newDnsRecords(l, &CertState{}, hostMap)
|
||||||
ds.Add("test.com.com", "1.2.3.4")
|
addrs := []netip.Addr{
|
||||||
|
netip.MustParseAddr("1.2.3.4"),
|
||||||
|
netip.MustParseAddr("1.2.3.5"),
|
||||||
|
netip.MustParseAddr("fd01::24"),
|
||||||
|
netip.MustParseAddr("fd01::25"),
|
||||||
|
}
|
||||||
|
ds.Add("test.com.com", addrs)
|
||||||
|
|
||||||
m := new(dns.Msg)
|
m := &dns.Msg{}
|
||||||
m.SetQuestion("test.com.com", dns.TypeA)
|
m.SetQuestion("test.com.com", dns.TypeA)
|
||||||
|
ds.parseQuery(m, nil)
|
||||||
|
assert.NotNil(t, m.Answer)
|
||||||
|
assert.Equal(t, "1.2.3.4", m.Answer[0].(*dns.A).A.String())
|
||||||
|
|
||||||
//parseQuery(m)
|
m = &dns.Msg{}
|
||||||
|
m.SetQuestion("test.com.com", dns.TypeAAAA)
|
||||||
|
ds.parseQuery(m, nil)
|
||||||
|
assert.NotNil(t, m.Answer)
|
||||||
|
assert.Equal(t, "fd01::24", m.Answer[0].(*dns.AAAA).AAAA.String())
|
||||||
}
|
}
|
||||||
|
|
||||||
func Test_getDnsServerAddr(t *testing.T) {
|
func Test_getDnsServerAddr(t *testing.T) {
|
||||||
|
|||||||
@ -4,7 +4,6 @@
|
|||||||
package e2e
|
package e2e
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
|
||||||
"net/netip"
|
"net/netip"
|
||||||
"slices"
|
"slices"
|
||||||
"testing"
|
"testing"
|
||||||
@ -12,6 +11,7 @@ import (
|
|||||||
|
|
||||||
"github.com/sirupsen/logrus"
|
"github.com/sirupsen/logrus"
|
||||||
"github.com/slackhq/nebula"
|
"github.com/slackhq/nebula"
|
||||||
|
"github.com/slackhq/nebula/cert"
|
||||||
"github.com/slackhq/nebula/e2e/router"
|
"github.com/slackhq/nebula/e2e/router"
|
||||||
"github.com/slackhq/nebula/header"
|
"github.com/slackhq/nebula/header"
|
||||||
"github.com/slackhq/nebula/udp"
|
"github.com/slackhq/nebula/udp"
|
||||||
@ -21,11 +21,11 @@ import (
|
|||||||
|
|
||||||
func BenchmarkHotPath(b *testing.B) {
|
func BenchmarkHotPath(b *testing.B) {
|
||||||
ca, _, caKey, _ := NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
|
ca, _, caKey, _ := NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
|
||||||
myControl, _, _, _ := newSimpleServer(ca, caKey, "me", "10.128.0.1/24", nil)
|
myControl, myVpnIpNet, _, _ := newSimpleServer(cert.Version1, ca, caKey, "me", "10.128.0.1/24", nil)
|
||||||
theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(ca, caKey, "them", "10.128.0.2/24", nil)
|
theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "them", "10.128.0.2/24", nil)
|
||||||
|
|
||||||
// Put their info in our lighthouse
|
// Put their info in our lighthouse
|
||||||
myControl.InjectLightHouseAddr(theirVpnIpNet.Addr(), theirUdpAddr)
|
myControl.InjectLightHouseAddr(theirVpnIpNet[0].Addr(), theirUdpAddr)
|
||||||
|
|
||||||
// Start the servers
|
// Start the servers
|
||||||
myControl.Start()
|
myControl.Start()
|
||||||
@ -35,7 +35,7 @@ func BenchmarkHotPath(b *testing.B) {
|
|||||||
r.CancelFlowLogs()
|
r.CancelFlowLogs()
|
||||||
|
|
||||||
for n := 0; n < b.N; n++ {
|
for n := 0; n < b.N; n++ {
|
||||||
myControl.InjectTunUDPPacket(theirVpnIpNet.Addr(), 80, 80, []byte("Hi from me"))
|
myControl.InjectTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi from me"))
|
||||||
_ = r.RouteForAllUntilTxTun(theirControl)
|
_ = r.RouteForAllUntilTxTun(theirControl)
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -45,18 +45,18 @@ func BenchmarkHotPath(b *testing.B) {
|
|||||||
|
|
||||||
func TestGoodHandshake(t *testing.T) {
|
func TestGoodHandshake(t *testing.T) {
|
||||||
ca, _, caKey, _ := NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
|
ca, _, caKey, _ := NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
|
||||||
myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(ca, caKey, "me", "10.128.0.1/24", nil)
|
myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "me", "10.128.0.1/24", nil)
|
||||||
theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(ca, caKey, "them", "10.128.0.2/24", nil)
|
theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "them", "10.128.0.2/24", nil)
|
||||||
|
|
||||||
// Put their info in our lighthouse
|
// Put their info in our lighthouse
|
||||||
myControl.InjectLightHouseAddr(theirVpnIpNet.Addr(), theirUdpAddr)
|
myControl.InjectLightHouseAddr(theirVpnIpNet[0].Addr(), theirUdpAddr)
|
||||||
|
|
||||||
// Start the servers
|
// Start the servers
|
||||||
myControl.Start()
|
myControl.Start()
|
||||||
theirControl.Start()
|
theirControl.Start()
|
||||||
|
|
||||||
t.Log("Send a udp packet through to begin standing up the tunnel, this should come out the other side")
|
t.Log("Send a udp packet through to begin standing up the tunnel, this should come out the other side")
|
||||||
myControl.InjectTunUDPPacket(theirVpnIpNet.Addr(), 80, 80, []byte("Hi from me"))
|
myControl.InjectTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi from me"))
|
||||||
|
|
||||||
t.Log("Have them consume my stage 0 packet. They have a tunnel now")
|
t.Log("Have them consume my stage 0 packet. They have a tunnel now")
|
||||||
theirControl.InjectUDPPacket(myControl.GetFromUDP(true))
|
theirControl.InjectUDPPacket(myControl.GetFromUDP(true))
|
||||||
@ -77,16 +77,16 @@ func TestGoodHandshake(t *testing.T) {
|
|||||||
myControl.WaitForType(1, 0, theirControl)
|
myControl.WaitForType(1, 0, theirControl)
|
||||||
|
|
||||||
t.Log("Make sure our host infos are correct")
|
t.Log("Make sure our host infos are correct")
|
||||||
assertHostInfoPair(t, myUdpAddr, theirUdpAddr, myVpnIpNet.Addr(), theirVpnIpNet.Addr(), myControl, theirControl)
|
assertHostInfoPair(t, myUdpAddr, theirUdpAddr, myVpnIpNet, theirVpnIpNet, myControl, theirControl)
|
||||||
|
|
||||||
t.Log("Get that cached packet and make sure it looks right")
|
t.Log("Get that cached packet and make sure it looks right")
|
||||||
myCachedPacket := theirControl.GetFromTun(true)
|
myCachedPacket := theirControl.GetFromTun(true)
|
||||||
assertUdpPacket(t, []byte("Hi from me"), myCachedPacket, myVpnIpNet.Addr(), theirVpnIpNet.Addr(), 80, 80)
|
assertUdpPacket(t, []byte("Hi from me"), myCachedPacket, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), 80, 80)
|
||||||
|
|
||||||
t.Log("Do a bidirectional tunnel test")
|
t.Log("Do a bidirectional tunnel test")
|
||||||
r := router.NewR(t, myControl, theirControl)
|
r := router.NewR(t, myControl, theirControl)
|
||||||
defer r.RenderFlow()
|
defer r.RenderFlow()
|
||||||
assertTunnel(t, myVpnIpNet.Addr(), theirVpnIpNet.Addr(), myControl, theirControl, r)
|
assertTunnel(t, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), myControl, theirControl, r)
|
||||||
|
|
||||||
r.RenderHostmaps("Final hostmaps", myControl, theirControl)
|
r.RenderHostmaps("Final hostmaps", myControl, theirControl)
|
||||||
myControl.Stop()
|
myControl.Stop()
|
||||||
@ -97,12 +97,12 @@ func TestGoodHandshake(t *testing.T) {
|
|||||||
func TestWrongResponderHandshake(t *testing.T) {
|
func TestWrongResponderHandshake(t *testing.T) {
|
||||||
ca, _, caKey, _ := NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
|
ca, _, caKey, _ := NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
|
||||||
|
|
||||||
myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(ca, caKey, "me", "10.128.0.100/24", nil)
|
myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "me", "10.128.0.100/24", nil)
|
||||||
theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(ca, caKey, "them", "10.128.0.99/24", nil)
|
theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "them", "10.128.0.99/24", nil)
|
||||||
evilControl, evilVpnIp, evilUdpAddr, _ := newSimpleServer(ca, caKey, "evil", "10.128.0.2/24", nil)
|
evilControl, evilVpnIp, evilUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "evil", "10.128.0.2/24", nil)
|
||||||
|
|
||||||
// Put the evil udp addr in for their vpn Ip, this is a case of being lied to by the lighthouse.
|
// Put the evil udp addr in for their vpn Ip, this is a case of being lied to by the lighthouse.
|
||||||
myControl.InjectLightHouseAddr(theirVpnIpNet.Addr(), evilUdpAddr)
|
myControl.InjectLightHouseAddr(theirVpnIpNet[0].Addr(), evilUdpAddr)
|
||||||
|
|
||||||
// Build a router so we don't have to reason who gets which packet
|
// Build a router so we don't have to reason who gets which packet
|
||||||
r := router.NewR(t, myControl, theirControl, evilControl)
|
r := router.NewR(t, myControl, theirControl, evilControl)
|
||||||
@ -114,7 +114,7 @@ func TestWrongResponderHandshake(t *testing.T) {
|
|||||||
evilControl.Start()
|
evilControl.Start()
|
||||||
|
|
||||||
t.Log("Start the handshake process, we will route until we see the evil tunnel closed")
|
t.Log("Start the handshake process, we will route until we see the evil tunnel closed")
|
||||||
myControl.InjectTunUDPPacket(theirVpnIpNet.Addr(), 80, 80, []byte("Hi from me"))
|
myControl.InjectTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi from me"))
|
||||||
|
|
||||||
h := &header.H{}
|
h := &header.H{}
|
||||||
r.RouteForAllExitFunc(func(p *udp.Packet, c *nebula.Control) router.ExitType {
|
r.RouteForAllExitFunc(func(p *udp.Packet, c *nebula.Control) router.ExitType {
|
||||||
@ -131,8 +131,8 @@ func TestWrongResponderHandshake(t *testing.T) {
|
|||||||
})
|
})
|
||||||
|
|
||||||
t.Log("Evil tunnel is closed, inject the correct udp addr for them")
|
t.Log("Evil tunnel is closed, inject the correct udp addr for them")
|
||||||
myControl.InjectLightHouseAddr(theirVpnIpNet.Addr(), theirUdpAddr)
|
myControl.InjectLightHouseAddr(theirVpnIpNet[0].Addr(), theirUdpAddr)
|
||||||
pendingHi := myControl.GetHostInfoByVpnIp(theirVpnIpNet.Addr(), true)
|
pendingHi := myControl.GetHostInfoByVpnAddr(theirVpnIpNet[0].Addr(), true)
|
||||||
assert.NotContains(t, pendingHi.RemoteAddrs, evilUdpAddr)
|
assert.NotContains(t, pendingHi.RemoteAddrs, evilUdpAddr)
|
||||||
|
|
||||||
t.Log("Route until we see the cached packet")
|
t.Log("Route until we see the cached packet")
|
||||||
@ -153,18 +153,18 @@ func TestWrongResponderHandshake(t *testing.T) {
|
|||||||
|
|
||||||
t.Log("My cached packet should be received by them")
|
t.Log("My cached packet should be received by them")
|
||||||
myCachedPacket := theirControl.GetFromTun(true)
|
myCachedPacket := theirControl.GetFromTun(true)
|
||||||
assertUdpPacket(t, []byte("Hi from me"), myCachedPacket, myVpnIpNet.Addr(), theirVpnIpNet.Addr(), 80, 80)
|
assertUdpPacket(t, []byte("Hi from me"), myCachedPacket, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), 80, 80)
|
||||||
|
|
||||||
t.Log("Test the tunnel with them")
|
t.Log("Test the tunnel with them")
|
||||||
assertHostInfoPair(t, myUdpAddr, theirUdpAddr, myVpnIpNet.Addr(), theirVpnIpNet.Addr(), myControl, theirControl)
|
assertHostInfoPair(t, myUdpAddr, theirUdpAddr, myVpnIpNet, theirVpnIpNet, myControl, theirControl)
|
||||||
assertTunnel(t, myVpnIpNet.Addr(), theirVpnIpNet.Addr(), myControl, theirControl, r)
|
assertTunnel(t, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), myControl, theirControl, r)
|
||||||
|
|
||||||
t.Log("Flush all packets from all controllers")
|
t.Log("Flush all packets from all controllers")
|
||||||
r.FlushAll()
|
r.FlushAll()
|
||||||
|
|
||||||
t.Log("Ensure ensure I don't have any hostinfo artifacts from evil")
|
t.Log("Ensure ensure I don't have any hostinfo artifacts from evil")
|
||||||
assert.Nil(t, myControl.GetHostInfoByVpnIp(evilVpnIp.Addr(), true), "My pending hostmap should not contain evil")
|
assert.Nil(t, myControl.GetHostInfoByVpnAddr(evilVpnIp[0].Addr(), true), "My pending hostmap should not contain evil")
|
||||||
assert.Nil(t, myControl.GetHostInfoByVpnIp(evilVpnIp.Addr(), false), "My main hostmap should not contain evil")
|
assert.Nil(t, myControl.GetHostInfoByVpnAddr(evilVpnIp[0].Addr(), false), "My main hostmap should not contain evil")
|
||||||
|
|
||||||
//TODO: assert hostmaps for everyone
|
//TODO: assert hostmaps for everyone
|
||||||
r.RenderHostmaps("Final hostmaps", myControl, theirControl, evilControl)
|
r.RenderHostmaps("Final hostmaps", myControl, theirControl, evilControl)
|
||||||
@ -176,17 +176,17 @@ func TestWrongResponderHandshake(t *testing.T) {
|
|||||||
func TestWrongResponderHandshakeStaticHostMap(t *testing.T) {
|
func TestWrongResponderHandshakeStaticHostMap(t *testing.T) {
|
||||||
ca, _, caKey, _ := NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
|
ca, _, caKey, _ := NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
|
||||||
|
|
||||||
theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(ca, caKey, "them", "10.128.0.99/24", nil)
|
theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "them", "10.128.0.99/24", nil)
|
||||||
evilControl, evilVpnIp, evilUdpAddr, _ := newSimpleServer(ca, caKey, "evil", "10.128.0.2/24", nil)
|
evilControl, evilVpnIp, evilUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "evil", "10.128.0.2/24", nil)
|
||||||
o := m{
|
o := m{
|
||||||
"static_host_map": m{
|
"static_host_map": m{
|
||||||
theirVpnIpNet.Addr().String(): []string{evilUdpAddr.String()},
|
theirVpnIpNet[0].Addr().String(): []string{evilUdpAddr.String()},
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(ca, caKey, "me", "10.128.0.100/24", o)
|
myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "me", "10.128.0.100/24", o)
|
||||||
|
|
||||||
// Put the evil udp addr in for their vpn addr, this is a case of a remote at a static entry changing its vpn addr.
|
// Put the evil udp addr in for their vpn addr, this is a case of a remote at a static entry changing its vpn addr.
|
||||||
myControl.InjectLightHouseAddr(theirVpnIpNet.Addr(), evilUdpAddr)
|
myControl.InjectLightHouseAddr(theirVpnIpNet[0].Addr(), evilUdpAddr)
|
||||||
|
|
||||||
// Build a router so we don't have to reason who gets which packet
|
// Build a router so we don't have to reason who gets which packet
|
||||||
r := router.NewR(t, myControl, theirControl, evilControl)
|
r := router.NewR(t, myControl, theirControl, evilControl)
|
||||||
@ -198,7 +198,7 @@ func TestWrongResponderHandshakeStaticHostMap(t *testing.T) {
|
|||||||
evilControl.Start()
|
evilControl.Start()
|
||||||
|
|
||||||
t.Log("Start the handshake process, we will route until we see the evil tunnel closed")
|
t.Log("Start the handshake process, we will route until we see the evil tunnel closed")
|
||||||
myControl.InjectTunUDPPacket(theirVpnIpNet.Addr(), 80, 80, []byte("Hi from me"))
|
myControl.InjectTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi from me"))
|
||||||
|
|
||||||
h := &header.H{}
|
h := &header.H{}
|
||||||
r.RouteForAllExitFunc(func(p *udp.Packet, c *nebula.Control) router.ExitType {
|
r.RouteForAllExitFunc(func(p *udp.Packet, c *nebula.Control) router.ExitType {
|
||||||
@ -215,8 +215,8 @@ func TestWrongResponderHandshakeStaticHostMap(t *testing.T) {
|
|||||||
})
|
})
|
||||||
|
|
||||||
t.Log("Evil tunnel is closed, inject the correct udp addr for them")
|
t.Log("Evil tunnel is closed, inject the correct udp addr for them")
|
||||||
myControl.InjectLightHouseAddr(theirVpnIpNet.Addr(), theirUdpAddr)
|
myControl.InjectLightHouseAddr(theirVpnIpNet[0].Addr(), theirUdpAddr)
|
||||||
pendingHi := myControl.GetHostInfoByVpnIp(theirVpnIpNet.Addr(), true)
|
pendingHi := myControl.GetHostInfoByVpnAddr(theirVpnIpNet[0].Addr(), true)
|
||||||
assert.NotContains(t, pendingHi.RemoteAddrs, evilUdpAddr)
|
assert.NotContains(t, pendingHi.RemoteAddrs, evilUdpAddr)
|
||||||
|
|
||||||
t.Log("Route until we see the cached packet")
|
t.Log("Route until we see the cached packet")
|
||||||
@ -237,18 +237,19 @@ func TestWrongResponderHandshakeStaticHostMap(t *testing.T) {
|
|||||||
|
|
||||||
t.Log("My cached packet should be received by them")
|
t.Log("My cached packet should be received by them")
|
||||||
myCachedPacket := theirControl.GetFromTun(true)
|
myCachedPacket := theirControl.GetFromTun(true)
|
||||||
assertUdpPacket(t, []byte("Hi from me"), myCachedPacket, myVpnIpNet.Addr(), theirVpnIpNet.Addr(), 80, 80)
|
assertUdpPacket(t, []byte("Hi from me"), myCachedPacket, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), 80, 80)
|
||||||
|
|
||||||
t.Log("Test the tunnel with them")
|
t.Log("Test the tunnel with them")
|
||||||
assertHostInfoPair(t, myUdpAddr, theirUdpAddr, myVpnIpNet.Addr(), theirVpnIpNet.Addr(), myControl, theirControl)
|
assertHostInfoPair(t, myUdpAddr, theirUdpAddr, myVpnIpNet, theirVpnIpNet, myControl, theirControl)
|
||||||
assertTunnel(t, myVpnIpNet.Addr(), theirVpnIpNet.Addr(), myControl, theirControl, r)
|
assertTunnel(t, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), myControl, theirControl, r)
|
||||||
|
|
||||||
t.Log("Flush all packets from all controllers")
|
t.Log("Flush all packets from all controllers")
|
||||||
r.FlushAll()
|
r.FlushAll()
|
||||||
|
|
||||||
t.Log("Ensure ensure I don't have any hostinfo artifacts from evil")
|
t.Log("Ensure ensure I don't have any hostinfo artifacts from evil")
|
||||||
assert.Nil(t, myControl.GetHostInfoByVpnIp(evilVpnIp.Addr(), true), "My pending hostmap should not contain evil")
|
assert.Nil(t, myControl.GetHostInfoByVpnAddr(evilVpnIp[0].Addr(), true), "My pending hostmap should not contain evil")
|
||||||
assert.Nil(t, myControl.GetHostInfoByVpnIp(evilVpnIp.Addr(), false), "My main hostmap should not contain evil")
|
assert.Nil(t, myControl.GetHostInfoByVpnAddr(evilVpnIp[0].Addr(), false), "My main hostmap should not contain evil")
|
||||||
|
//NOTE: if evil lost the handshake race it may still have a tunnel since me would reject the handshake since the tunnel is complete
|
||||||
|
|
||||||
//TODO: assert hostmaps for everyone
|
//TODO: assert hostmaps for everyone
|
||||||
r.RenderHostmaps("Final hostmaps", myControl, theirControl, evilControl)
|
r.RenderHostmaps("Final hostmaps", myControl, theirControl, evilControl)
|
||||||
@ -262,12 +263,12 @@ func TestStage1Race(t *testing.T) {
|
|||||||
// But will eventually collapse down to a single tunnel
|
// But will eventually collapse down to a single tunnel
|
||||||
|
|
||||||
ca, _, caKey, _ := NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
|
ca, _, caKey, _ := NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
|
||||||
myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(ca, caKey, "me ", "10.128.0.1/24", nil)
|
myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "me ", "10.128.0.1/24", nil)
|
||||||
theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(ca, caKey, "them", "10.128.0.2/24", nil)
|
theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "them", "10.128.0.2/24", nil)
|
||||||
|
|
||||||
// Put their info in our lighthouse and vice versa
|
// Put their info in our lighthouse and vice versa
|
||||||
myControl.InjectLightHouseAddr(theirVpnIpNet.Addr(), theirUdpAddr)
|
myControl.InjectLightHouseAddr(theirVpnIpNet[0].Addr(), theirUdpAddr)
|
||||||
theirControl.InjectLightHouseAddr(myVpnIpNet.Addr(), myUdpAddr)
|
theirControl.InjectLightHouseAddr(myVpnIpNet[0].Addr(), myUdpAddr)
|
||||||
|
|
||||||
// Build a router so we don't have to reason who gets which packet
|
// Build a router so we don't have to reason who gets which packet
|
||||||
r := router.NewR(t, myControl, theirControl)
|
r := router.NewR(t, myControl, theirControl)
|
||||||
@ -278,8 +279,8 @@ func TestStage1Race(t *testing.T) {
|
|||||||
theirControl.Start()
|
theirControl.Start()
|
||||||
|
|
||||||
t.Log("Trigger a handshake to start on both me and them")
|
t.Log("Trigger a handshake to start on both me and them")
|
||||||
myControl.InjectTunUDPPacket(theirVpnIpNet.Addr(), 80, 80, []byte("Hi from me"))
|
myControl.InjectTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi from me"))
|
||||||
theirControl.InjectTunUDPPacket(myVpnIpNet.Addr(), 80, 80, []byte("Hi from them"))
|
theirControl.InjectTunUDPPacket(myVpnIpNet[0].Addr(), 80, theirVpnIpNet[0].Addr(), 80, []byte("Hi from them"))
|
||||||
|
|
||||||
t.Log("Get both stage 1 handshake packets")
|
t.Log("Get both stage 1 handshake packets")
|
||||||
myHsForThem := myControl.GetFromUDP(true)
|
myHsForThem := myControl.GetFromUDP(true)
|
||||||
@ -291,14 +292,14 @@ func TestStage1Race(t *testing.T) {
|
|||||||
|
|
||||||
r.Log("Route until they receive a message packet")
|
r.Log("Route until they receive a message packet")
|
||||||
myCachedPacket := r.RouteForAllUntilTxTun(theirControl)
|
myCachedPacket := r.RouteForAllUntilTxTun(theirControl)
|
||||||
assertUdpPacket(t, []byte("Hi from me"), myCachedPacket, myVpnIpNet.Addr(), theirVpnIpNet.Addr(), 80, 80)
|
assertUdpPacket(t, []byte("Hi from me"), myCachedPacket, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), 80, 80)
|
||||||
|
|
||||||
r.Log("Their cached packet should be received by me")
|
r.Log("Their cached packet should be received by me")
|
||||||
theirCachedPacket := r.RouteForAllUntilTxTun(myControl)
|
theirCachedPacket := r.RouteForAllUntilTxTun(myControl)
|
||||||
assertUdpPacket(t, []byte("Hi from them"), theirCachedPacket, theirVpnIpNet.Addr(), myVpnIpNet.Addr(), 80, 80)
|
assertUdpPacket(t, []byte("Hi from them"), theirCachedPacket, theirVpnIpNet[0].Addr(), myVpnIpNet[0].Addr(), 80, 80)
|
||||||
|
|
||||||
r.Log("Do a bidirectional tunnel test")
|
r.Log("Do a bidirectional tunnel test")
|
||||||
assertTunnel(t, myVpnIpNet.Addr(), theirVpnIpNet.Addr(), myControl, theirControl, r)
|
assertTunnel(t, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), myControl, theirControl, r)
|
||||||
|
|
||||||
myHostmapHosts := myControl.ListHostmapHosts(false)
|
myHostmapHosts := myControl.ListHostmapHosts(false)
|
||||||
myHostmapIndexes := myControl.ListHostmapIndexes(false)
|
myHostmapIndexes := myControl.ListHostmapIndexes(false)
|
||||||
@ -316,7 +317,7 @@ func TestStage1Race(t *testing.T) {
|
|||||||
r.Log("Spin until connection manager tears down a tunnel")
|
r.Log("Spin until connection manager tears down a tunnel")
|
||||||
|
|
||||||
for len(myControl.GetHostmap().Indexes)+len(theirControl.GetHostmap().Indexes) > 2 {
|
for len(myControl.GetHostmap().Indexes)+len(theirControl.GetHostmap().Indexes) > 2 {
|
||||||
assertTunnel(t, myVpnIpNet.Addr(), theirVpnIpNet.Addr(), myControl, theirControl, r)
|
assertTunnel(t, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), myControl, theirControl, r)
|
||||||
t.Log("Connection manager hasn't ticked yet")
|
t.Log("Connection manager hasn't ticked yet")
|
||||||
time.Sleep(time.Second)
|
time.Sleep(time.Second)
|
||||||
}
|
}
|
||||||
@ -339,12 +340,12 @@ func TestStage1Race(t *testing.T) {
|
|||||||
|
|
||||||
func TestUncleanShutdownRaceLoser(t *testing.T) {
|
func TestUncleanShutdownRaceLoser(t *testing.T) {
|
||||||
ca, _, caKey, _ := NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
|
ca, _, caKey, _ := NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
|
||||||
myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(ca, caKey, "me ", "10.128.0.1/24", nil)
|
myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "me ", "10.128.0.1/24", nil)
|
||||||
theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(ca, caKey, "them", "10.128.0.2/24", nil)
|
theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "them", "10.128.0.2/24", nil)
|
||||||
|
|
||||||
// Teach my how to get to the relay and that their can be reached via the relay
|
// Teach my how to get to the relay and that their can be reached via the relay
|
||||||
myControl.InjectLightHouseAddr(theirVpnIpNet.Addr(), theirUdpAddr)
|
myControl.InjectLightHouseAddr(theirVpnIpNet[0].Addr(), theirUdpAddr)
|
||||||
theirControl.InjectLightHouseAddr(myVpnIpNet.Addr(), myUdpAddr)
|
theirControl.InjectLightHouseAddr(myVpnIpNet[0].Addr(), myUdpAddr)
|
||||||
|
|
||||||
// Build a router so we don't have to reason who gets which packet
|
// Build a router so we don't have to reason who gets which packet
|
||||||
r := router.NewR(t, myControl, theirControl)
|
r := router.NewR(t, myControl, theirControl)
|
||||||
@ -355,10 +356,10 @@ func TestUncleanShutdownRaceLoser(t *testing.T) {
|
|||||||
theirControl.Start()
|
theirControl.Start()
|
||||||
|
|
||||||
r.Log("Trigger a handshake from me to them")
|
r.Log("Trigger a handshake from me to them")
|
||||||
myControl.InjectTunUDPPacket(theirVpnIpNet.Addr(), 80, 80, []byte("Hi from me"))
|
myControl.InjectTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi from me"))
|
||||||
|
|
||||||
p := r.RouteForAllUntilTxTun(theirControl)
|
p := r.RouteForAllUntilTxTun(theirControl)
|
||||||
assertUdpPacket(t, []byte("Hi from me"), p, myVpnIpNet.Addr(), theirVpnIpNet.Addr(), 80, 80)
|
assertUdpPacket(t, []byte("Hi from me"), p, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), 80, 80)
|
||||||
|
|
||||||
r.Log("Nuke my hostmap")
|
r.Log("Nuke my hostmap")
|
||||||
myHostmap := myControl.GetHostmap()
|
myHostmap := myControl.GetHostmap()
|
||||||
@ -366,17 +367,17 @@ func TestUncleanShutdownRaceLoser(t *testing.T) {
|
|||||||
myHostmap.Indexes = map[uint32]*nebula.HostInfo{}
|
myHostmap.Indexes = map[uint32]*nebula.HostInfo{}
|
||||||
myHostmap.RemoteIndexes = map[uint32]*nebula.HostInfo{}
|
myHostmap.RemoteIndexes = map[uint32]*nebula.HostInfo{}
|
||||||
|
|
||||||
myControl.InjectTunUDPPacket(theirVpnIpNet.Addr(), 80, 80, []byte("Hi from me again"))
|
myControl.InjectTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi from me again"))
|
||||||
p = r.RouteForAllUntilTxTun(theirControl)
|
p = r.RouteForAllUntilTxTun(theirControl)
|
||||||
assertUdpPacket(t, []byte("Hi from me again"), p, myVpnIpNet.Addr(), theirVpnIpNet.Addr(), 80, 80)
|
assertUdpPacket(t, []byte("Hi from me again"), p, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), 80, 80)
|
||||||
|
|
||||||
r.Log("Assert the tunnel works")
|
r.Log("Assert the tunnel works")
|
||||||
assertTunnel(t, myVpnIpNet.Addr(), theirVpnIpNet.Addr(), myControl, theirControl, r)
|
assertTunnel(t, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), myControl, theirControl, r)
|
||||||
|
|
||||||
r.Log("Wait for the dead index to go away")
|
r.Log("Wait for the dead index to go away")
|
||||||
start := len(theirControl.GetHostmap().Indexes)
|
start := len(theirControl.GetHostmap().Indexes)
|
||||||
for {
|
for {
|
||||||
assertTunnel(t, myVpnIpNet.Addr(), theirVpnIpNet.Addr(), myControl, theirControl, r)
|
assertTunnel(t, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), myControl, theirControl, r)
|
||||||
if len(theirControl.GetHostmap().Indexes) < start {
|
if len(theirControl.GetHostmap().Indexes) < start {
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
@ -388,12 +389,12 @@ func TestUncleanShutdownRaceLoser(t *testing.T) {
|
|||||||
|
|
||||||
func TestUncleanShutdownRaceWinner(t *testing.T) {
|
func TestUncleanShutdownRaceWinner(t *testing.T) {
|
||||||
ca, _, caKey, _ := NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
|
ca, _, caKey, _ := NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
|
||||||
myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(ca, caKey, "me ", "10.128.0.1/24", nil)
|
myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "me ", "10.128.0.1/24", nil)
|
||||||
theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(ca, caKey, "them", "10.128.0.2/24", nil)
|
theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "them", "10.128.0.2/24", nil)
|
||||||
|
|
||||||
// Teach my how to get to the relay and that their can be reached via the relay
|
// Teach my how to get to the relay and that their can be reached via the relay
|
||||||
myControl.InjectLightHouseAddr(theirVpnIpNet.Addr(), theirUdpAddr)
|
myControl.InjectLightHouseAddr(theirVpnIpNet[0].Addr(), theirUdpAddr)
|
||||||
theirControl.InjectLightHouseAddr(myVpnIpNet.Addr(), myUdpAddr)
|
theirControl.InjectLightHouseAddr(myVpnIpNet[0].Addr(), myUdpAddr)
|
||||||
|
|
||||||
// Build a router so we don't have to reason who gets which packet
|
// Build a router so we don't have to reason who gets which packet
|
||||||
r := router.NewR(t, myControl, theirControl)
|
r := router.NewR(t, myControl, theirControl)
|
||||||
@ -404,10 +405,10 @@ func TestUncleanShutdownRaceWinner(t *testing.T) {
|
|||||||
theirControl.Start()
|
theirControl.Start()
|
||||||
|
|
||||||
r.Log("Trigger a handshake from me to them")
|
r.Log("Trigger a handshake from me to them")
|
||||||
myControl.InjectTunUDPPacket(theirVpnIpNet.Addr(), 80, 80, []byte("Hi from me"))
|
myControl.InjectTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi from me"))
|
||||||
|
|
||||||
p := r.RouteForAllUntilTxTun(theirControl)
|
p := r.RouteForAllUntilTxTun(theirControl)
|
||||||
assertUdpPacket(t, []byte("Hi from me"), p, myVpnIpNet.Addr(), theirVpnIpNet.Addr(), 80, 80)
|
assertUdpPacket(t, []byte("Hi from me"), p, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), 80, 80)
|
||||||
r.RenderHostmaps("Final hostmaps", myControl, theirControl)
|
r.RenderHostmaps("Final hostmaps", myControl, theirControl)
|
||||||
|
|
||||||
r.Log("Nuke my hostmap")
|
r.Log("Nuke my hostmap")
|
||||||
@ -416,18 +417,18 @@ func TestUncleanShutdownRaceWinner(t *testing.T) {
|
|||||||
theirHostmap.Indexes = map[uint32]*nebula.HostInfo{}
|
theirHostmap.Indexes = map[uint32]*nebula.HostInfo{}
|
||||||
theirHostmap.RemoteIndexes = map[uint32]*nebula.HostInfo{}
|
theirHostmap.RemoteIndexes = map[uint32]*nebula.HostInfo{}
|
||||||
|
|
||||||
theirControl.InjectTunUDPPacket(myVpnIpNet.Addr(), 80, 80, []byte("Hi from them again"))
|
theirControl.InjectTunUDPPacket(myVpnIpNet[0].Addr(), 80, theirVpnIpNet[0].Addr(), 80, []byte("Hi from them again"))
|
||||||
p = r.RouteForAllUntilTxTun(myControl)
|
p = r.RouteForAllUntilTxTun(myControl)
|
||||||
assertUdpPacket(t, []byte("Hi from them again"), p, theirVpnIpNet.Addr(), myVpnIpNet.Addr(), 80, 80)
|
assertUdpPacket(t, []byte("Hi from them again"), p, theirVpnIpNet[0].Addr(), myVpnIpNet[0].Addr(), 80, 80)
|
||||||
r.RenderHostmaps("Derp hostmaps", myControl, theirControl)
|
r.RenderHostmaps("Derp hostmaps", myControl, theirControl)
|
||||||
|
|
||||||
r.Log("Assert the tunnel works")
|
r.Log("Assert the tunnel works")
|
||||||
assertTunnel(t, myVpnIpNet.Addr(), theirVpnIpNet.Addr(), myControl, theirControl, r)
|
assertTunnel(t, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), myControl, theirControl, r)
|
||||||
|
|
||||||
r.Log("Wait for the dead index to go away")
|
r.Log("Wait for the dead index to go away")
|
||||||
start := len(myControl.GetHostmap().Indexes)
|
start := len(myControl.GetHostmap().Indexes)
|
||||||
for {
|
for {
|
||||||
assertTunnel(t, myVpnIpNet.Addr(), theirVpnIpNet.Addr(), myControl, theirControl, r)
|
assertTunnel(t, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), myControl, theirControl, r)
|
||||||
if len(myControl.GetHostmap().Indexes) < start {
|
if len(myControl.GetHostmap().Indexes) < start {
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
@ -439,14 +440,14 @@ func TestUncleanShutdownRaceWinner(t *testing.T) {
|
|||||||
|
|
||||||
func TestRelays(t *testing.T) {
|
func TestRelays(t *testing.T) {
|
||||||
ca, _, caKey, _ := NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
|
ca, _, caKey, _ := NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
|
||||||
myControl, myVpnIpNet, _, _ := newSimpleServer(ca, caKey, "me ", "10.128.0.1/24", m{"relay": m{"use_relays": true}})
|
myControl, myVpnIpNet, _, _ := newSimpleServer(cert.Version1, ca, caKey, "me ", "10.128.0.1/24", m{"relay": m{"use_relays": true}})
|
||||||
relayControl, relayVpnIpNet, relayUdpAddr, _ := newSimpleServer(ca, caKey, "relay ", "10.128.0.128/24", m{"relay": m{"am_relay": true}})
|
relayControl, relayVpnIpNet, relayUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "relay ", "10.128.0.128/24", m{"relay": m{"am_relay": true}})
|
||||||
theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(ca, caKey, "them ", "10.128.0.2/24", m{"relay": m{"use_relays": true}})
|
theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "them ", "10.128.0.2/24", m{"relay": m{"use_relays": true}})
|
||||||
|
|
||||||
// Teach my how to get to the relay and that their can be reached via the relay
|
// Teach my how to get to the relay and that their can be reached via the relay
|
||||||
myControl.InjectLightHouseAddr(relayVpnIpNet.Addr(), relayUdpAddr)
|
myControl.InjectLightHouseAddr(relayVpnIpNet[0].Addr(), relayUdpAddr)
|
||||||
myControl.InjectRelays(theirVpnIpNet.Addr(), []netip.Addr{relayVpnIpNet.Addr()})
|
myControl.InjectRelays(theirVpnIpNet[0].Addr(), []netip.Addr{relayVpnIpNet[0].Addr()})
|
||||||
relayControl.InjectLightHouseAddr(theirVpnIpNet.Addr(), theirUdpAddr)
|
relayControl.InjectLightHouseAddr(theirVpnIpNet[0].Addr(), theirUdpAddr)
|
||||||
|
|
||||||
// Build a router so we don't have to reason who gets which packet
|
// Build a router so we don't have to reason who gets which packet
|
||||||
r := router.NewR(t, myControl, relayControl, theirControl)
|
r := router.NewR(t, myControl, relayControl, theirControl)
|
||||||
@ -458,11 +459,11 @@ func TestRelays(t *testing.T) {
|
|||||||
theirControl.Start()
|
theirControl.Start()
|
||||||
|
|
||||||
t.Log("Trigger a handshake from me to them via the relay")
|
t.Log("Trigger a handshake from me to them via the relay")
|
||||||
myControl.InjectTunUDPPacket(theirVpnIpNet.Addr(), 80, 80, []byte("Hi from me"))
|
myControl.InjectTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi from me"))
|
||||||
|
|
||||||
p := r.RouteForAllUntilTxTun(theirControl)
|
p := r.RouteForAllUntilTxTun(theirControl)
|
||||||
r.Log("Assert the tunnel works")
|
r.Log("Assert the tunnel works")
|
||||||
assertUdpPacket(t, []byte("Hi from me"), p, myVpnIpNet.Addr(), theirVpnIpNet.Addr(), 80, 80)
|
assertUdpPacket(t, []byte("Hi from me"), p, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), 80, 80)
|
||||||
r.RenderHostmaps("Final hostmaps", myControl, relayControl, theirControl)
|
r.RenderHostmaps("Final hostmaps", myControl, relayControl, theirControl)
|
||||||
//TODO: assert we actually used the relay even though it should be impossible for a tunnel to have occurred without it
|
//TODO: assert we actually used the relay even though it should be impossible for a tunnel to have occurred without it
|
||||||
}
|
}
|
||||||
@ -470,19 +471,19 @@ func TestRelays(t *testing.T) {
|
|||||||
func TestStage1RaceRelays(t *testing.T) {
|
func TestStage1RaceRelays(t *testing.T) {
|
||||||
//NOTE: this is a race between me and relay resulting in a full tunnel from me to them via relay
|
//NOTE: this is a race between me and relay resulting in a full tunnel from me to them via relay
|
||||||
ca, _, caKey, _ := NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
|
ca, _, caKey, _ := NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
|
||||||
myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(ca, caKey, "me ", "10.128.0.1/24", m{"relay": m{"use_relays": true}})
|
myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "me ", "10.128.0.1/24", m{"relay": m{"use_relays": true}})
|
||||||
relayControl, relayVpnIpNet, relayUdpAddr, _ := newSimpleServer(ca, caKey, "relay ", "10.128.0.128/24", m{"relay": m{"am_relay": true}})
|
relayControl, relayVpnIpNet, relayUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "relay ", "10.128.0.128/24", m{"relay": m{"am_relay": true}})
|
||||||
theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(ca, caKey, "them ", "10.128.0.2/24", m{"relay": m{"use_relays": true}})
|
theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "them ", "10.128.0.2/24", m{"relay": m{"use_relays": true}})
|
||||||
|
|
||||||
// Teach my how to get to the relay and that their can be reached via the relay
|
// Teach my how to get to the relay and that their can be reached via the relay
|
||||||
myControl.InjectLightHouseAddr(relayVpnIpNet.Addr(), relayUdpAddr)
|
myControl.InjectLightHouseAddr(relayVpnIpNet[0].Addr(), relayUdpAddr)
|
||||||
theirControl.InjectLightHouseAddr(relayVpnIpNet.Addr(), relayUdpAddr)
|
theirControl.InjectLightHouseAddr(relayVpnIpNet[0].Addr(), relayUdpAddr)
|
||||||
|
|
||||||
myControl.InjectRelays(theirVpnIpNet.Addr(), []netip.Addr{relayVpnIpNet.Addr()})
|
myControl.InjectRelays(theirVpnIpNet[0].Addr(), []netip.Addr{relayVpnIpNet[0].Addr()})
|
||||||
theirControl.InjectRelays(myVpnIpNet.Addr(), []netip.Addr{relayVpnIpNet.Addr()})
|
theirControl.InjectRelays(myVpnIpNet[0].Addr(), []netip.Addr{relayVpnIpNet[0].Addr()})
|
||||||
|
|
||||||
relayControl.InjectLightHouseAddr(theirVpnIpNet.Addr(), theirUdpAddr)
|
relayControl.InjectLightHouseAddr(theirVpnIpNet[0].Addr(), theirUdpAddr)
|
||||||
relayControl.InjectLightHouseAddr(myVpnIpNet.Addr(), myUdpAddr)
|
relayControl.InjectLightHouseAddr(myVpnIpNet[0].Addr(), myUdpAddr)
|
||||||
|
|
||||||
// Build a router so we don't have to reason who gets which packet
|
// Build a router so we don't have to reason who gets which packet
|
||||||
r := router.NewR(t, myControl, relayControl, theirControl)
|
r := router.NewR(t, myControl, relayControl, theirControl)
|
||||||
@ -494,14 +495,14 @@ func TestStage1RaceRelays(t *testing.T) {
|
|||||||
theirControl.Start()
|
theirControl.Start()
|
||||||
|
|
||||||
r.Log("Get a tunnel between me and relay")
|
r.Log("Get a tunnel between me and relay")
|
||||||
assertTunnel(t, myVpnIpNet.Addr(), relayVpnIpNet.Addr(), myControl, relayControl, r)
|
assertTunnel(t, myVpnIpNet[0].Addr(), relayVpnIpNet[0].Addr(), myControl, relayControl, r)
|
||||||
|
|
||||||
r.Log("Get a tunnel between them and relay")
|
r.Log("Get a tunnel between them and relay")
|
||||||
assertTunnel(t, theirVpnIpNet.Addr(), relayVpnIpNet.Addr(), theirControl, relayControl, r)
|
assertTunnel(t, theirVpnIpNet[0].Addr(), relayVpnIpNet[0].Addr(), theirControl, relayControl, r)
|
||||||
|
|
||||||
r.Log("Trigger a handshake from both them and me via relay to them and me")
|
r.Log("Trigger a handshake from both them and me via relay to them and me")
|
||||||
myControl.InjectTunUDPPacket(theirVpnIpNet.Addr(), 80, 80, []byte("Hi from me"))
|
myControl.InjectTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi from me"))
|
||||||
theirControl.InjectTunUDPPacket(myVpnIpNet.Addr(), 80, 80, []byte("Hi from them"))
|
theirControl.InjectTunUDPPacket(myVpnIpNet[0].Addr(), 80, theirVpnIpNet[0].Addr(), 80, []byte("Hi from them"))
|
||||||
|
|
||||||
r.Log("Wait for a packet from them to me")
|
r.Log("Wait for a packet from them to me")
|
||||||
p := r.RouteForAllUntilTxTun(myControl)
|
p := r.RouteForAllUntilTxTun(myControl)
|
||||||
@ -519,20 +520,20 @@ func TestStage1RaceRelays(t *testing.T) {
|
|||||||
func TestStage1RaceRelays2(t *testing.T) {
|
func TestStage1RaceRelays2(t *testing.T) {
|
||||||
//NOTE: this is a race between me and relay resulting in a full tunnel from me to them via relay
|
//NOTE: this is a race between me and relay resulting in a full tunnel from me to them via relay
|
||||||
ca, _, caKey, _ := NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
|
ca, _, caKey, _ := NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
|
||||||
myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(ca, caKey, "me ", "10.128.0.1/24", m{"relay": m{"use_relays": true}})
|
myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "me ", "10.128.0.1/24", m{"relay": m{"use_relays": true}})
|
||||||
relayControl, relayVpnIpNet, relayUdpAddr, _ := newSimpleServer(ca, caKey, "relay ", "10.128.0.128/24", m{"relay": m{"am_relay": true}})
|
relayControl, relayVpnIpNet, relayUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "relay ", "10.128.0.128/24", m{"relay": m{"am_relay": true}})
|
||||||
theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(ca, caKey, "them ", "10.128.0.2/24", m{"relay": m{"use_relays": true}})
|
theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "them ", "10.128.0.2/24", m{"relay": m{"use_relays": true}})
|
||||||
l := NewTestLogger()
|
l := NewTestLogger()
|
||||||
|
|
||||||
// Teach my how to get to the relay and that their can be reached via the relay
|
// Teach my how to get to the relay and that their can be reached via the relay
|
||||||
myControl.InjectLightHouseAddr(relayVpnIpNet.Addr(), relayUdpAddr)
|
myControl.InjectLightHouseAddr(relayVpnIpNet[0].Addr(), relayUdpAddr)
|
||||||
theirControl.InjectLightHouseAddr(relayVpnIpNet.Addr(), relayUdpAddr)
|
theirControl.InjectLightHouseAddr(relayVpnIpNet[0].Addr(), relayUdpAddr)
|
||||||
|
|
||||||
myControl.InjectRelays(theirVpnIpNet.Addr(), []netip.Addr{relayVpnIpNet.Addr()})
|
myControl.InjectRelays(theirVpnIpNet[0].Addr(), []netip.Addr{relayVpnIpNet[0].Addr()})
|
||||||
theirControl.InjectRelays(myVpnIpNet.Addr(), []netip.Addr{relayVpnIpNet.Addr()})
|
theirControl.InjectRelays(myVpnIpNet[0].Addr(), []netip.Addr{relayVpnIpNet[0].Addr()})
|
||||||
|
|
||||||
relayControl.InjectLightHouseAddr(theirVpnIpNet.Addr(), theirUdpAddr)
|
relayControl.InjectLightHouseAddr(theirVpnIpNet[0].Addr(), theirUdpAddr)
|
||||||
relayControl.InjectLightHouseAddr(myVpnIpNet.Addr(), myUdpAddr)
|
relayControl.InjectLightHouseAddr(myVpnIpNet[0].Addr(), myUdpAddr)
|
||||||
|
|
||||||
// Build a router so we don't have to reason who gets which packet
|
// Build a router so we don't have to reason who gets which packet
|
||||||
r := router.NewR(t, myControl, relayControl, theirControl)
|
r := router.NewR(t, myControl, relayControl, theirControl)
|
||||||
@ -545,16 +546,16 @@ func TestStage1RaceRelays2(t *testing.T) {
|
|||||||
|
|
||||||
r.Log("Get a tunnel between me and relay")
|
r.Log("Get a tunnel between me and relay")
|
||||||
l.Info("Get a tunnel between me and relay")
|
l.Info("Get a tunnel between me and relay")
|
||||||
assertTunnel(t, myVpnIpNet.Addr(), relayVpnIpNet.Addr(), myControl, relayControl, r)
|
assertTunnel(t, myVpnIpNet[0].Addr(), relayVpnIpNet[0].Addr(), myControl, relayControl, r)
|
||||||
|
|
||||||
r.Log("Get a tunnel between them and relay")
|
r.Log("Get a tunnel between them and relay")
|
||||||
l.Info("Get a tunnel between them and relay")
|
l.Info("Get a tunnel between them and relay")
|
||||||
assertTunnel(t, theirVpnIpNet.Addr(), relayVpnIpNet.Addr(), theirControl, relayControl, r)
|
assertTunnel(t, theirVpnIpNet[0].Addr(), relayVpnIpNet[0].Addr(), theirControl, relayControl, r)
|
||||||
|
|
||||||
r.Log("Trigger a handshake from both them and me via relay to them and me")
|
r.Log("Trigger a handshake from both them and me via relay to them and me")
|
||||||
l.Info("Trigger a handshake from both them and me via relay to them and me")
|
l.Info("Trigger a handshake from both them and me via relay to them and me")
|
||||||
myControl.InjectTunUDPPacket(theirVpnIpNet.Addr(), 80, 80, []byte("Hi from me"))
|
myControl.InjectTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi from me"))
|
||||||
theirControl.InjectTunUDPPacket(myVpnIpNet.Addr(), 80, 80, []byte("Hi from them"))
|
theirControl.InjectTunUDPPacket(myVpnIpNet[0].Addr(), 80, theirVpnIpNet[0].Addr(), 80, []byte("Hi from them"))
|
||||||
|
|
||||||
//r.RouteUntilAfterMsgType(myControl, header.Control, header.MessageNone)
|
//r.RouteUntilAfterMsgType(myControl, header.Control, header.MessageNone)
|
||||||
//r.RouteUntilAfterMsgType(theirControl, header.Control, header.MessageNone)
|
//r.RouteUntilAfterMsgType(theirControl, header.Control, header.MessageNone)
|
||||||
@ -567,7 +568,7 @@ func TestStage1RaceRelays2(t *testing.T) {
|
|||||||
|
|
||||||
r.Log("Assert the tunnel works")
|
r.Log("Assert the tunnel works")
|
||||||
l.Info("Assert the tunnel works")
|
l.Info("Assert the tunnel works")
|
||||||
assertTunnel(t, theirVpnIpNet.Addr(), myVpnIpNet.Addr(), theirControl, myControl, r)
|
assertTunnel(t, theirVpnIpNet[0].Addr(), myVpnIpNet[0].Addr(), theirControl, myControl, r)
|
||||||
|
|
||||||
t.Log("Wait until we remove extra tunnels")
|
t.Log("Wait until we remove extra tunnels")
|
||||||
l.Info("Wait until we remove extra tunnels")
|
l.Info("Wait until we remove extra tunnels")
|
||||||
@ -587,7 +588,7 @@ func TestStage1RaceRelays2(t *testing.T) {
|
|||||||
"theirControl": len(theirControl.GetHostmap().Indexes),
|
"theirControl": len(theirControl.GetHostmap().Indexes),
|
||||||
"relayControl": len(relayControl.GetHostmap().Indexes),
|
"relayControl": len(relayControl.GetHostmap().Indexes),
|
||||||
}).Info("Waiting for hostinfos to be removed...")
|
}).Info("Waiting for hostinfos to be removed...")
|
||||||
assertTunnel(t, myVpnIpNet.Addr(), theirVpnIpNet.Addr(), myControl, theirControl, r)
|
assertTunnel(t, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), myControl, theirControl, r)
|
||||||
t.Log("Connection manager hasn't ticked yet")
|
t.Log("Connection manager hasn't ticked yet")
|
||||||
time.Sleep(time.Second)
|
time.Sleep(time.Second)
|
||||||
retries--
|
retries--
|
||||||
@ -595,7 +596,7 @@ func TestStage1RaceRelays2(t *testing.T) {
|
|||||||
|
|
||||||
r.Log("Assert the tunnel works")
|
r.Log("Assert the tunnel works")
|
||||||
l.Info("Assert the tunnel works")
|
l.Info("Assert the tunnel works")
|
||||||
assertTunnel(t, theirVpnIpNet.Addr(), myVpnIpNet.Addr(), theirControl, myControl, r)
|
assertTunnel(t, theirVpnIpNet[0].Addr(), myVpnIpNet[0].Addr(), theirControl, myControl, r)
|
||||||
|
|
||||||
myControl.Stop()
|
myControl.Stop()
|
||||||
theirControl.Stop()
|
theirControl.Stop()
|
||||||
@ -607,14 +608,14 @@ func TestStage1RaceRelays2(t *testing.T) {
|
|||||||
|
|
||||||
func TestRehandshakingRelays(t *testing.T) {
|
func TestRehandshakingRelays(t *testing.T) {
|
||||||
ca, _, caKey, _ := NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
|
ca, _, caKey, _ := NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
|
||||||
myControl, myVpnIpNet, _, _ := newSimpleServer(ca, caKey, "me ", "10.128.0.1/24", m{"relay": m{"use_relays": true}})
|
myControl, myVpnIpNet, _, _ := newSimpleServer(cert.Version1, ca, caKey, "me ", "10.128.0.1/24", m{"relay": m{"use_relays": true}})
|
||||||
relayControl, relayVpnIpNet, relayUdpAddr, relayConfig := newSimpleServer(ca, caKey, "relay ", "10.128.0.128/24", m{"relay": m{"am_relay": true}})
|
relayControl, relayVpnIpNet, relayUdpAddr, relayConfig := newSimpleServer(cert.Version1, ca, caKey, "relay ", "10.128.0.128/24", m{"relay": m{"am_relay": true}})
|
||||||
theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(ca, caKey, "them ", "10.128.0.2/24", m{"relay": m{"use_relays": true}})
|
theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "them ", "10.128.0.2/24", m{"relay": m{"use_relays": true}})
|
||||||
|
|
||||||
// Teach my how to get to the relay and that their can be reached via the relay
|
// Teach my how to get to the relay and that their can be reached via the relay
|
||||||
myControl.InjectLightHouseAddr(relayVpnIpNet.Addr(), relayUdpAddr)
|
myControl.InjectLightHouseAddr(relayVpnIpNet[0].Addr(), relayUdpAddr)
|
||||||
myControl.InjectRelays(theirVpnIpNet.Addr(), []netip.Addr{relayVpnIpNet.Addr()})
|
myControl.InjectRelays(theirVpnIpNet[0].Addr(), []netip.Addr{relayVpnIpNet[0].Addr()})
|
||||||
relayControl.InjectLightHouseAddr(theirVpnIpNet.Addr(), theirUdpAddr)
|
relayControl.InjectLightHouseAddr(theirVpnIpNet[0].Addr(), theirUdpAddr)
|
||||||
|
|
||||||
// Build a router so we don't have to reason who gets which packet
|
// Build a router so we don't have to reason who gets which packet
|
||||||
r := router.NewR(t, myControl, relayControl, theirControl)
|
r := router.NewR(t, myControl, relayControl, theirControl)
|
||||||
@ -626,17 +627,17 @@ func TestRehandshakingRelays(t *testing.T) {
|
|||||||
theirControl.Start()
|
theirControl.Start()
|
||||||
|
|
||||||
t.Log("Trigger a handshake from me to them via the relay")
|
t.Log("Trigger a handshake from me to them via the relay")
|
||||||
myControl.InjectTunUDPPacket(theirVpnIpNet.Addr(), 80, 80, []byte("Hi from me"))
|
myControl.InjectTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi from me"))
|
||||||
|
|
||||||
p := r.RouteForAllUntilTxTun(theirControl)
|
p := r.RouteForAllUntilTxTun(theirControl)
|
||||||
r.Log("Assert the tunnel works")
|
r.Log("Assert the tunnel works")
|
||||||
assertUdpPacket(t, []byte("Hi from me"), p, myVpnIpNet.Addr(), theirVpnIpNet.Addr(), 80, 80)
|
assertUdpPacket(t, []byte("Hi from me"), p, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), 80, 80)
|
||||||
r.RenderHostmaps("working hostmaps", myControl, relayControl, theirControl)
|
r.RenderHostmaps("working hostmaps", myControl, relayControl, theirControl)
|
||||||
|
|
||||||
// When I update the certificate for the relay, both me and them will have 2 host infos for the relay,
|
// When I update the certificate for the relay, both me and them will have 2 host infos for the relay,
|
||||||
// and the main host infos will not have any relay state to handle the me<->relay<->them tunnel.
|
// and the main host infos will not have any relay state to handle the me<->relay<->them tunnel.
|
||||||
r.Log("Renew relay certificate and spin until me and them sees it")
|
r.Log("Renew relay certificate and spin until me and them sees it")
|
||||||
_, _, myNextPrivKey, myNextPEM := NewTestCert(ca, caKey, "relay", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{relayVpnIpNet}, nil, []string{"new group"})
|
_, _, myNextPrivKey, myNextPEM := NewTestCert(cert.Version1, ca, caKey, "relay", time.Now(), time.Now().Add(5*time.Minute), relayVpnIpNet, nil, []string{"new group"})
|
||||||
|
|
||||||
caB, err := ca.MarshalPEM()
|
caB, err := ca.MarshalPEM()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -654,8 +655,8 @@ func TestRehandshakingRelays(t *testing.T) {
|
|||||||
|
|
||||||
for {
|
for {
|
||||||
r.Log("Assert the tunnel works between myVpnIpNet and relayVpnIpNet")
|
r.Log("Assert the tunnel works between myVpnIpNet and relayVpnIpNet")
|
||||||
assertTunnel(t, myVpnIpNet.Addr(), relayVpnIpNet.Addr(), myControl, relayControl, r)
|
assertTunnel(t, myVpnIpNet[0].Addr(), relayVpnIpNet[0].Addr(), myControl, relayControl, r)
|
||||||
c := myControl.GetHostInfoByVpnIp(relayVpnIpNet.Addr(), false)
|
c := myControl.GetHostInfoByVpnAddr(relayVpnIpNet[0].Addr(), false)
|
||||||
if len(c.Cert.Groups()) != 0 {
|
if len(c.Cert.Groups()) != 0 {
|
||||||
// We have a new certificate now
|
// We have a new certificate now
|
||||||
r.Log("Certificate between my and relay is updated!")
|
r.Log("Certificate between my and relay is updated!")
|
||||||
@ -667,8 +668,8 @@ func TestRehandshakingRelays(t *testing.T) {
|
|||||||
|
|
||||||
for {
|
for {
|
||||||
r.Log("Assert the tunnel works between theirVpnIpNet and relayVpnIpNet")
|
r.Log("Assert the tunnel works between theirVpnIpNet and relayVpnIpNet")
|
||||||
assertTunnel(t, theirVpnIpNet.Addr(), relayVpnIpNet.Addr(), theirControl, relayControl, r)
|
assertTunnel(t, theirVpnIpNet[0].Addr(), relayVpnIpNet[0].Addr(), theirControl, relayControl, r)
|
||||||
c := theirControl.GetHostInfoByVpnIp(relayVpnIpNet.Addr(), false)
|
c := theirControl.GetHostInfoByVpnAddr(relayVpnIpNet[0].Addr(), false)
|
||||||
if len(c.Cert.Groups()) != 0 {
|
if len(c.Cert.Groups()) != 0 {
|
||||||
// We have a new certificate now
|
// We have a new certificate now
|
||||||
r.Log("Certificate between their and relay is updated!")
|
r.Log("Certificate between their and relay is updated!")
|
||||||
@ -679,13 +680,13 @@ func TestRehandshakingRelays(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
r.Log("Assert the relay tunnel still works")
|
r.Log("Assert the relay tunnel still works")
|
||||||
assertTunnel(t, theirVpnIpNet.Addr(), myVpnIpNet.Addr(), theirControl, myControl, r)
|
assertTunnel(t, theirVpnIpNet[0].Addr(), myVpnIpNet[0].Addr(), theirControl, myControl, r)
|
||||||
r.RenderHostmaps("working hostmaps", myControl, relayControl, theirControl)
|
r.RenderHostmaps("working hostmaps", myControl, relayControl, theirControl)
|
||||||
// We should have two hostinfos on all sides
|
// We should have two hostinfos on all sides
|
||||||
for len(myControl.GetHostmap().Indexes) != 2 {
|
for len(myControl.GetHostmap().Indexes) != 2 {
|
||||||
t.Logf("Waiting for myControl hostinfos (%v != 2) to get cleaned up from lack of use...", len(myControl.GetHostmap().Indexes))
|
t.Logf("Waiting for myControl hostinfos (%v != 2) to get cleaned up from lack of use...", len(myControl.GetHostmap().Indexes))
|
||||||
r.Log("Assert the relay tunnel still works")
|
r.Log("Assert the relay tunnel still works")
|
||||||
assertTunnel(t, theirVpnIpNet.Addr(), myVpnIpNet.Addr(), theirControl, myControl, r)
|
assertTunnel(t, theirVpnIpNet[0].Addr(), myVpnIpNet[0].Addr(), theirControl, myControl, r)
|
||||||
r.Log("yupitdoes")
|
r.Log("yupitdoes")
|
||||||
time.Sleep(time.Second)
|
time.Sleep(time.Second)
|
||||||
}
|
}
|
||||||
@ -693,7 +694,7 @@ func TestRehandshakingRelays(t *testing.T) {
|
|||||||
for len(theirControl.GetHostmap().Indexes) != 2 {
|
for len(theirControl.GetHostmap().Indexes) != 2 {
|
||||||
t.Logf("Waiting for theirControl hostinfos (%v != 2) to get cleaned up from lack of use...", len(theirControl.GetHostmap().Indexes))
|
t.Logf("Waiting for theirControl hostinfos (%v != 2) to get cleaned up from lack of use...", len(theirControl.GetHostmap().Indexes))
|
||||||
r.Log("Assert the relay tunnel still works")
|
r.Log("Assert the relay tunnel still works")
|
||||||
assertTunnel(t, theirVpnIpNet.Addr(), myVpnIpNet.Addr(), theirControl, myControl, r)
|
assertTunnel(t, theirVpnIpNet[0].Addr(), myVpnIpNet[0].Addr(), theirControl, myControl, r)
|
||||||
r.Log("yupitdoes")
|
r.Log("yupitdoes")
|
||||||
time.Sleep(time.Second)
|
time.Sleep(time.Second)
|
||||||
}
|
}
|
||||||
@ -701,7 +702,7 @@ func TestRehandshakingRelays(t *testing.T) {
|
|||||||
for len(relayControl.GetHostmap().Indexes) != 2 {
|
for len(relayControl.GetHostmap().Indexes) != 2 {
|
||||||
t.Logf("Waiting for relayControl hostinfos (%v != 2) to get cleaned up from lack of use...", len(relayControl.GetHostmap().Indexes))
|
t.Logf("Waiting for relayControl hostinfos (%v != 2) to get cleaned up from lack of use...", len(relayControl.GetHostmap().Indexes))
|
||||||
r.Log("Assert the relay tunnel still works")
|
r.Log("Assert the relay tunnel still works")
|
||||||
assertTunnel(t, theirVpnIpNet.Addr(), myVpnIpNet.Addr(), theirControl, myControl, r)
|
assertTunnel(t, theirVpnIpNet[0].Addr(), myVpnIpNet[0].Addr(), theirControl, myControl, r)
|
||||||
r.Log("yupitdoes")
|
r.Log("yupitdoes")
|
||||||
time.Sleep(time.Second)
|
time.Sleep(time.Second)
|
||||||
}
|
}
|
||||||
@ -711,14 +712,14 @@ func TestRehandshakingRelays(t *testing.T) {
|
|||||||
func TestRehandshakingRelaysPrimary(t *testing.T) {
|
func TestRehandshakingRelaysPrimary(t *testing.T) {
|
||||||
// This test is the same as TestRehandshakingRelays but one of the terminal types is a primary swap winner
|
// This test is the same as TestRehandshakingRelays but one of the terminal types is a primary swap winner
|
||||||
ca, _, caKey, _ := NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
|
ca, _, caKey, _ := NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
|
||||||
myControl, myVpnIpNet, _, _ := newSimpleServer(ca, caKey, "me ", "10.128.0.128/24", m{"relay": m{"use_relays": true}})
|
myControl, myVpnIpNet, _, _ := newSimpleServer(cert.Version1, ca, caKey, "me ", "10.128.0.128/24", m{"relay": m{"use_relays": true}})
|
||||||
relayControl, relayVpnIpNet, relayUdpAddr, relayConfig := newSimpleServer(ca, caKey, "relay ", "10.128.0.1/24", m{"relay": m{"am_relay": true}})
|
relayControl, relayVpnIpNet, relayUdpAddr, relayConfig := newSimpleServer(cert.Version1, ca, caKey, "relay ", "10.128.0.1/24", m{"relay": m{"am_relay": true}})
|
||||||
theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(ca, caKey, "them ", "10.128.0.2/24", m{"relay": m{"use_relays": true}})
|
theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "them ", "10.128.0.2/24", m{"relay": m{"use_relays": true}})
|
||||||
|
|
||||||
// Teach my how to get to the relay and that their can be reached via the relay
|
// Teach my how to get to the relay and that their can be reached via the relay
|
||||||
myControl.InjectLightHouseAddr(relayVpnIpNet.Addr(), relayUdpAddr)
|
myControl.InjectLightHouseAddr(relayVpnIpNet[0].Addr(), relayUdpAddr)
|
||||||
myControl.InjectRelays(theirVpnIpNet.Addr(), []netip.Addr{relayVpnIpNet.Addr()})
|
myControl.InjectRelays(theirVpnIpNet[0].Addr(), []netip.Addr{relayVpnIpNet[0].Addr()})
|
||||||
relayControl.InjectLightHouseAddr(theirVpnIpNet.Addr(), theirUdpAddr)
|
relayControl.InjectLightHouseAddr(theirVpnIpNet[0].Addr(), theirUdpAddr)
|
||||||
|
|
||||||
// Build a router so we don't have to reason who gets which packet
|
// Build a router so we don't have to reason who gets which packet
|
||||||
r := router.NewR(t, myControl, relayControl, theirControl)
|
r := router.NewR(t, myControl, relayControl, theirControl)
|
||||||
@ -730,17 +731,17 @@ func TestRehandshakingRelaysPrimary(t *testing.T) {
|
|||||||
theirControl.Start()
|
theirControl.Start()
|
||||||
|
|
||||||
t.Log("Trigger a handshake from me to them via the relay")
|
t.Log("Trigger a handshake from me to them via the relay")
|
||||||
myControl.InjectTunUDPPacket(theirVpnIpNet.Addr(), 80, 80, []byte("Hi from me"))
|
myControl.InjectTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi from me"))
|
||||||
|
|
||||||
p := r.RouteForAllUntilTxTun(theirControl)
|
p := r.RouteForAllUntilTxTun(theirControl)
|
||||||
r.Log("Assert the tunnel works")
|
r.Log("Assert the tunnel works")
|
||||||
assertUdpPacket(t, []byte("Hi from me"), p, myVpnIpNet.Addr(), theirVpnIpNet.Addr(), 80, 80)
|
assertUdpPacket(t, []byte("Hi from me"), p, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), 80, 80)
|
||||||
r.RenderHostmaps("working hostmaps", myControl, relayControl, theirControl)
|
r.RenderHostmaps("working hostmaps", myControl, relayControl, theirControl)
|
||||||
|
|
||||||
// When I update the certificate for the relay, both me and them will have 2 host infos for the relay,
|
// When I update the certificate for the relay, both me and them will have 2 host infos for the relay,
|
||||||
// and the main host infos will not have any relay state to handle the me<->relay<->them tunnel.
|
// and the main host infos will not have any relay state to handle the me<->relay<->them tunnel.
|
||||||
r.Log("Renew relay certificate and spin until me and them sees it")
|
r.Log("Renew relay certificate and spin until me and them sees it")
|
||||||
_, _, myNextPrivKey, myNextPEM := NewTestCert(ca, caKey, "relay", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{relayVpnIpNet}, nil, []string{"new group"})
|
_, _, myNextPrivKey, myNextPEM := NewTestCert(cert.Version1, ca, caKey, "relay", time.Now(), time.Now().Add(5*time.Minute), relayVpnIpNet, nil, []string{"new group"})
|
||||||
|
|
||||||
caB, err := ca.MarshalPEM()
|
caB, err := ca.MarshalPEM()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -758,8 +759,8 @@ func TestRehandshakingRelaysPrimary(t *testing.T) {
|
|||||||
|
|
||||||
for {
|
for {
|
||||||
r.Log("Assert the tunnel works between myVpnIpNet and relayVpnIpNet")
|
r.Log("Assert the tunnel works between myVpnIpNet and relayVpnIpNet")
|
||||||
assertTunnel(t, myVpnIpNet.Addr(), relayVpnIpNet.Addr(), myControl, relayControl, r)
|
assertTunnel(t, myVpnIpNet[0].Addr(), relayVpnIpNet[0].Addr(), myControl, relayControl, r)
|
||||||
c := myControl.GetHostInfoByVpnIp(relayVpnIpNet.Addr(), false)
|
c := myControl.GetHostInfoByVpnAddr(relayVpnIpNet[0].Addr(), false)
|
||||||
if len(c.Cert.Groups()) != 0 {
|
if len(c.Cert.Groups()) != 0 {
|
||||||
// We have a new certificate now
|
// We have a new certificate now
|
||||||
r.Log("Certificate between my and relay is updated!")
|
r.Log("Certificate between my and relay is updated!")
|
||||||
@ -771,8 +772,8 @@ func TestRehandshakingRelaysPrimary(t *testing.T) {
|
|||||||
|
|
||||||
for {
|
for {
|
||||||
r.Log("Assert the tunnel works between theirVpnIpNet and relayVpnIpNet")
|
r.Log("Assert the tunnel works between theirVpnIpNet and relayVpnIpNet")
|
||||||
assertTunnel(t, theirVpnIpNet.Addr(), relayVpnIpNet.Addr(), theirControl, relayControl, r)
|
assertTunnel(t, theirVpnIpNet[0].Addr(), relayVpnIpNet[0].Addr(), theirControl, relayControl, r)
|
||||||
c := theirControl.GetHostInfoByVpnIp(relayVpnIpNet.Addr(), false)
|
c := theirControl.GetHostInfoByVpnAddr(relayVpnIpNet[0].Addr(), false)
|
||||||
if len(c.Cert.Groups()) != 0 {
|
if len(c.Cert.Groups()) != 0 {
|
||||||
// We have a new certificate now
|
// We have a new certificate now
|
||||||
r.Log("Certificate between their and relay is updated!")
|
r.Log("Certificate between their and relay is updated!")
|
||||||
@ -783,13 +784,13 @@ func TestRehandshakingRelaysPrimary(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
r.Log("Assert the relay tunnel still works")
|
r.Log("Assert the relay tunnel still works")
|
||||||
assertTunnel(t, theirVpnIpNet.Addr(), myVpnIpNet.Addr(), theirControl, myControl, r)
|
assertTunnel(t, theirVpnIpNet[0].Addr(), myVpnIpNet[0].Addr(), theirControl, myControl, r)
|
||||||
r.RenderHostmaps("working hostmaps", myControl, relayControl, theirControl)
|
r.RenderHostmaps("working hostmaps", myControl, relayControl, theirControl)
|
||||||
// We should have two hostinfos on all sides
|
// We should have two hostinfos on all sides
|
||||||
for len(myControl.GetHostmap().Indexes) != 2 {
|
for len(myControl.GetHostmap().Indexes) != 2 {
|
||||||
t.Logf("Waiting for myControl hostinfos (%v != 2) to get cleaned up from lack of use...", len(myControl.GetHostmap().Indexes))
|
t.Logf("Waiting for myControl hostinfos (%v != 2) to get cleaned up from lack of use...", len(myControl.GetHostmap().Indexes))
|
||||||
r.Log("Assert the relay tunnel still works")
|
r.Log("Assert the relay tunnel still works")
|
||||||
assertTunnel(t, theirVpnIpNet.Addr(), myVpnIpNet.Addr(), theirControl, myControl, r)
|
assertTunnel(t, theirVpnIpNet[0].Addr(), myVpnIpNet[0].Addr(), theirControl, myControl, r)
|
||||||
r.Log("yupitdoes")
|
r.Log("yupitdoes")
|
||||||
time.Sleep(time.Second)
|
time.Sleep(time.Second)
|
||||||
}
|
}
|
||||||
@ -797,7 +798,7 @@ func TestRehandshakingRelaysPrimary(t *testing.T) {
|
|||||||
for len(theirControl.GetHostmap().Indexes) != 2 {
|
for len(theirControl.GetHostmap().Indexes) != 2 {
|
||||||
t.Logf("Waiting for theirControl hostinfos (%v != 2) to get cleaned up from lack of use...", len(theirControl.GetHostmap().Indexes))
|
t.Logf("Waiting for theirControl hostinfos (%v != 2) to get cleaned up from lack of use...", len(theirControl.GetHostmap().Indexes))
|
||||||
r.Log("Assert the relay tunnel still works")
|
r.Log("Assert the relay tunnel still works")
|
||||||
assertTunnel(t, theirVpnIpNet.Addr(), myVpnIpNet.Addr(), theirControl, myControl, r)
|
assertTunnel(t, theirVpnIpNet[0].Addr(), myVpnIpNet[0].Addr(), theirControl, myControl, r)
|
||||||
r.Log("yupitdoes")
|
r.Log("yupitdoes")
|
||||||
time.Sleep(time.Second)
|
time.Sleep(time.Second)
|
||||||
}
|
}
|
||||||
@ -805,7 +806,7 @@ func TestRehandshakingRelaysPrimary(t *testing.T) {
|
|||||||
for len(relayControl.GetHostmap().Indexes) != 2 {
|
for len(relayControl.GetHostmap().Indexes) != 2 {
|
||||||
t.Logf("Waiting for relayControl hostinfos (%v != 2) to get cleaned up from lack of use...", len(relayControl.GetHostmap().Indexes))
|
t.Logf("Waiting for relayControl hostinfos (%v != 2) to get cleaned up from lack of use...", len(relayControl.GetHostmap().Indexes))
|
||||||
r.Log("Assert the relay tunnel still works")
|
r.Log("Assert the relay tunnel still works")
|
||||||
assertTunnel(t, theirVpnIpNet.Addr(), myVpnIpNet.Addr(), theirControl, myControl, r)
|
assertTunnel(t, theirVpnIpNet[0].Addr(), myVpnIpNet[0].Addr(), theirControl, myControl, r)
|
||||||
r.Log("yupitdoes")
|
r.Log("yupitdoes")
|
||||||
time.Sleep(time.Second)
|
time.Sleep(time.Second)
|
||||||
}
|
}
|
||||||
@ -814,12 +815,12 @@ func TestRehandshakingRelaysPrimary(t *testing.T) {
|
|||||||
|
|
||||||
func TestRehandshaking(t *testing.T) {
|
func TestRehandshaking(t *testing.T) {
|
||||||
ca, _, caKey, _ := NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
|
ca, _, caKey, _ := NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
|
||||||
myControl, myVpnIpNet, myUdpAddr, myConfig := newSimpleServer(ca, caKey, "me ", "10.128.0.2/24", nil)
|
myControl, myVpnIpNet, myUdpAddr, myConfig := newSimpleServer(cert.Version1, ca, caKey, "me ", "10.128.0.2/24", nil)
|
||||||
theirControl, theirVpnIpNet, theirUdpAddr, theirConfig := newSimpleServer(ca, caKey, "them", "10.128.0.1/24", nil)
|
theirControl, theirVpnIpNet, theirUdpAddr, theirConfig := newSimpleServer(cert.Version1, ca, caKey, "them", "10.128.0.1/24", nil)
|
||||||
|
|
||||||
// Put their info in our lighthouse and vice versa
|
// Put their info in our lighthouse and vice versa
|
||||||
myControl.InjectLightHouseAddr(theirVpnIpNet.Addr(), theirUdpAddr)
|
myControl.InjectLightHouseAddr(theirVpnIpNet[0].Addr(), theirUdpAddr)
|
||||||
theirControl.InjectLightHouseAddr(myVpnIpNet.Addr(), myUdpAddr)
|
theirControl.InjectLightHouseAddr(myVpnIpNet[0].Addr(), myUdpAddr)
|
||||||
|
|
||||||
// Build a router so we don't have to reason who gets which packet
|
// Build a router so we don't have to reason who gets which packet
|
||||||
r := router.NewR(t, myControl, theirControl)
|
r := router.NewR(t, myControl, theirControl)
|
||||||
@ -830,12 +831,12 @@ func TestRehandshaking(t *testing.T) {
|
|||||||
theirControl.Start()
|
theirControl.Start()
|
||||||
|
|
||||||
t.Log("Stand up a tunnel between me and them")
|
t.Log("Stand up a tunnel between me and them")
|
||||||
assertTunnel(t, myVpnIpNet.Addr(), theirVpnIpNet.Addr(), myControl, theirControl, r)
|
assertTunnel(t, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), myControl, theirControl, r)
|
||||||
|
|
||||||
r.RenderHostmaps("Starting hostmaps", myControl, theirControl)
|
r.RenderHostmaps("Starting hostmaps", myControl, theirControl)
|
||||||
|
|
||||||
r.Log("Renew my certificate and spin until their sees it")
|
r.Log("Renew my certificate and spin until their sees it")
|
||||||
_, _, myNextPrivKey, myNextPEM := NewTestCert(ca, caKey, "me", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{myVpnIpNet}, nil, []string{"new group"})
|
_, _, myNextPrivKey, myNextPEM := NewTestCert(cert.Version1, ca, caKey, "me", time.Now(), time.Now().Add(5*time.Minute), myVpnIpNet, nil, []string{"new group"})
|
||||||
|
|
||||||
caB, err := ca.MarshalPEM()
|
caB, err := ca.MarshalPEM()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -852,8 +853,8 @@ func TestRehandshaking(t *testing.T) {
|
|||||||
myConfig.ReloadConfigString(string(rc))
|
myConfig.ReloadConfigString(string(rc))
|
||||||
|
|
||||||
for {
|
for {
|
||||||
assertTunnel(t, myVpnIpNet.Addr(), theirVpnIpNet.Addr(), myControl, theirControl, r)
|
assertTunnel(t, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), myControl, theirControl, r)
|
||||||
c := theirControl.GetHostInfoByVpnIp(myVpnIpNet.Addr(), false)
|
c := theirControl.GetHostInfoByVpnAddr(myVpnIpNet[0].Addr(), false)
|
||||||
if len(c.Cert.Groups()) != 0 {
|
if len(c.Cert.Groups()) != 0 {
|
||||||
// We have a new certificate now
|
// We have a new certificate now
|
||||||
break
|
break
|
||||||
@ -880,19 +881,19 @@ func TestRehandshaking(t *testing.T) {
|
|||||||
|
|
||||||
r.Log("Spin until there is only 1 tunnel")
|
r.Log("Spin until there is only 1 tunnel")
|
||||||
for len(myControl.GetHostmap().Indexes)+len(theirControl.GetHostmap().Indexes) > 2 {
|
for len(myControl.GetHostmap().Indexes)+len(theirControl.GetHostmap().Indexes) > 2 {
|
||||||
assertTunnel(t, myVpnIpNet.Addr(), theirVpnIpNet.Addr(), myControl, theirControl, r)
|
assertTunnel(t, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), myControl, theirControl, r)
|
||||||
t.Log("Connection manager hasn't ticked yet")
|
t.Log("Connection manager hasn't ticked yet")
|
||||||
time.Sleep(time.Second)
|
time.Sleep(time.Second)
|
||||||
}
|
}
|
||||||
|
|
||||||
assertTunnel(t, myVpnIpNet.Addr(), theirVpnIpNet.Addr(), myControl, theirControl, r)
|
assertTunnel(t, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), myControl, theirControl, r)
|
||||||
myFinalHostmapHosts := myControl.ListHostmapHosts(false)
|
myFinalHostmapHosts := myControl.ListHostmapHosts(false)
|
||||||
myFinalHostmapIndexes := myControl.ListHostmapIndexes(false)
|
myFinalHostmapIndexes := myControl.ListHostmapIndexes(false)
|
||||||
theirFinalHostmapHosts := theirControl.ListHostmapHosts(false)
|
theirFinalHostmapHosts := theirControl.ListHostmapHosts(false)
|
||||||
theirFinalHostmapIndexes := theirControl.ListHostmapIndexes(false)
|
theirFinalHostmapIndexes := theirControl.ListHostmapIndexes(false)
|
||||||
|
|
||||||
// Make sure the correct tunnel won
|
// Make sure the correct tunnel won
|
||||||
c := theirControl.GetHostInfoByVpnIp(myVpnIpNet.Addr(), false)
|
c := theirControl.GetHostInfoByVpnAddr(myVpnIpNet[0].Addr(), false)
|
||||||
assert.Contains(t, c.Cert.Groups(), "new group")
|
assert.Contains(t, c.Cert.Groups(), "new group")
|
||||||
|
|
||||||
// We should only have a single tunnel now on both sides
|
// We should only have a single tunnel now on both sides
|
||||||
@ -911,12 +912,12 @@ func TestRehandshakingLoser(t *testing.T) {
|
|||||||
// The purpose of this test is that the race loser renews their certificate and rehandshakes. The final tunnel
|
// The purpose of this test is that the race loser renews their certificate and rehandshakes. The final tunnel
|
||||||
// Should be the one with the new certificate
|
// Should be the one with the new certificate
|
||||||
ca, _, caKey, _ := NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
|
ca, _, caKey, _ := NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
|
||||||
myControl, myVpnIpNet, myUdpAddr, myConfig := newSimpleServer(ca, caKey, "me ", "10.128.0.2/24", nil)
|
myControl, myVpnIpNet, myUdpAddr, myConfig := newSimpleServer(cert.Version1, ca, caKey, "me ", "10.128.0.2/24", nil)
|
||||||
theirControl, theirVpnIpNet, theirUdpAddr, theirConfig := newSimpleServer(ca, caKey, "them", "10.128.0.1/24", nil)
|
theirControl, theirVpnIpNet, theirUdpAddr, theirConfig := newSimpleServer(cert.Version1, ca, caKey, "them", "10.128.0.1/24", nil)
|
||||||
|
|
||||||
// Put their info in our lighthouse and vice versa
|
// Put their info in our lighthouse and vice versa
|
||||||
myControl.InjectLightHouseAddr(theirVpnIpNet.Addr(), theirUdpAddr)
|
myControl.InjectLightHouseAddr(theirVpnIpNet[0].Addr(), theirUdpAddr)
|
||||||
theirControl.InjectLightHouseAddr(myVpnIpNet.Addr(), myUdpAddr)
|
theirControl.InjectLightHouseAddr(myVpnIpNet[0].Addr(), myUdpAddr)
|
||||||
|
|
||||||
// Build a router so we don't have to reason who gets which packet
|
// Build a router so we don't have to reason who gets which packet
|
||||||
r := router.NewR(t, myControl, theirControl)
|
r := router.NewR(t, myControl, theirControl)
|
||||||
@ -927,16 +928,12 @@ func TestRehandshakingLoser(t *testing.T) {
|
|||||||
theirControl.Start()
|
theirControl.Start()
|
||||||
|
|
||||||
t.Log("Stand up a tunnel between me and them")
|
t.Log("Stand up a tunnel between me and them")
|
||||||
assertTunnel(t, myVpnIpNet.Addr(), theirVpnIpNet.Addr(), myControl, theirControl, r)
|
assertTunnel(t, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), myControl, theirControl, r)
|
||||||
|
|
||||||
tt1 := myControl.GetHostInfoByVpnIp(theirVpnIpNet.Addr(), false)
|
|
||||||
tt2 := theirControl.GetHostInfoByVpnIp(myVpnIpNet.Addr(), false)
|
|
||||||
fmt.Println(tt1.LocalIndex, tt2.LocalIndex)
|
|
||||||
|
|
||||||
r.RenderHostmaps("Starting hostmaps", myControl, theirControl)
|
r.RenderHostmaps("Starting hostmaps", myControl, theirControl)
|
||||||
|
|
||||||
r.Log("Renew their certificate and spin until mine sees it")
|
r.Log("Renew their certificate and spin until mine sees it")
|
||||||
_, _, theirNextPrivKey, theirNextPEM := NewTestCert(ca, caKey, "them", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{theirVpnIpNet}, nil, []string{"their new group"})
|
_, _, theirNextPrivKey, theirNextPEM := NewTestCert(cert.Version1, ca, caKey, "them", time.Now(), time.Now().Add(5*time.Minute), theirVpnIpNet, nil, []string{"their new group"})
|
||||||
|
|
||||||
caB, err := ca.MarshalPEM()
|
caB, err := ca.MarshalPEM()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -953,8 +950,8 @@ func TestRehandshakingLoser(t *testing.T) {
|
|||||||
theirConfig.ReloadConfigString(string(rc))
|
theirConfig.ReloadConfigString(string(rc))
|
||||||
|
|
||||||
for {
|
for {
|
||||||
assertTunnel(t, myVpnIpNet.Addr(), theirVpnIpNet.Addr(), myControl, theirControl, r)
|
assertTunnel(t, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), myControl, theirControl, r)
|
||||||
theirCertInMe := myControl.GetHostInfoByVpnIp(theirVpnIpNet.Addr(), false)
|
theirCertInMe := myControl.GetHostInfoByVpnAddr(theirVpnIpNet[0].Addr(), false)
|
||||||
|
|
||||||
if slices.Contains(theirCertInMe.Cert.Groups(), "their new group") {
|
if slices.Contains(theirCertInMe.Cert.Groups(), "their new group") {
|
||||||
break
|
break
|
||||||
@ -980,19 +977,19 @@ func TestRehandshakingLoser(t *testing.T) {
|
|||||||
|
|
||||||
r.Log("Spin until there is only 1 tunnel")
|
r.Log("Spin until there is only 1 tunnel")
|
||||||
for len(myControl.GetHostmap().Indexes)+len(theirControl.GetHostmap().Indexes) > 2 {
|
for len(myControl.GetHostmap().Indexes)+len(theirControl.GetHostmap().Indexes) > 2 {
|
||||||
assertTunnel(t, myVpnIpNet.Addr(), theirVpnIpNet.Addr(), myControl, theirControl, r)
|
assertTunnel(t, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), myControl, theirControl, r)
|
||||||
t.Log("Connection manager hasn't ticked yet")
|
t.Log("Connection manager hasn't ticked yet")
|
||||||
time.Sleep(time.Second)
|
time.Sleep(time.Second)
|
||||||
}
|
}
|
||||||
|
|
||||||
assertTunnel(t, myVpnIpNet.Addr(), theirVpnIpNet.Addr(), myControl, theirControl, r)
|
assertTunnel(t, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), myControl, theirControl, r)
|
||||||
myFinalHostmapHosts := myControl.ListHostmapHosts(false)
|
myFinalHostmapHosts := myControl.ListHostmapHosts(false)
|
||||||
myFinalHostmapIndexes := myControl.ListHostmapIndexes(false)
|
myFinalHostmapIndexes := myControl.ListHostmapIndexes(false)
|
||||||
theirFinalHostmapHosts := theirControl.ListHostmapHosts(false)
|
theirFinalHostmapHosts := theirControl.ListHostmapHosts(false)
|
||||||
theirFinalHostmapIndexes := theirControl.ListHostmapIndexes(false)
|
theirFinalHostmapIndexes := theirControl.ListHostmapIndexes(false)
|
||||||
|
|
||||||
// Make sure the correct tunnel won
|
// Make sure the correct tunnel won
|
||||||
theirCertInMe := myControl.GetHostInfoByVpnIp(theirVpnIpNet.Addr(), false)
|
theirCertInMe := myControl.GetHostInfoByVpnAddr(theirVpnIpNet[0].Addr(), false)
|
||||||
assert.Contains(t, theirCertInMe.Cert.Groups(), "their new group")
|
assert.Contains(t, theirCertInMe.Cert.Groups(), "their new group")
|
||||||
|
|
||||||
// We should only have a single tunnel now on both sides
|
// We should only have a single tunnel now on both sides
|
||||||
@ -1011,12 +1008,12 @@ func TestRaceRegression(t *testing.T) {
|
|||||||
// We had a bug where we were not finding the duplicate handshake and responding to the final stage 1 which
|
// We had a bug where we were not finding the duplicate handshake and responding to the final stage 1 which
|
||||||
// caused a cross-linked hostinfo
|
// caused a cross-linked hostinfo
|
||||||
ca, _, caKey, _ := NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
|
ca, _, caKey, _ := NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
|
||||||
myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(ca, caKey, "me", "10.128.0.1/24", nil)
|
myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "me", "10.128.0.1/24", nil)
|
||||||
theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(ca, caKey, "them", "10.128.0.2/24", nil)
|
theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "them", "10.128.0.2/24", nil)
|
||||||
|
|
||||||
// Put their info in our lighthouse
|
// Put their info in our lighthouse
|
||||||
myControl.InjectLightHouseAddr(theirVpnIpNet.Addr(), theirUdpAddr)
|
myControl.InjectLightHouseAddr(theirVpnIpNet[0].Addr(), theirUdpAddr)
|
||||||
theirControl.InjectLightHouseAddr(myVpnIpNet.Addr(), myUdpAddr)
|
theirControl.InjectLightHouseAddr(myVpnIpNet[0].Addr(), myUdpAddr)
|
||||||
|
|
||||||
// Start the servers
|
// Start the servers
|
||||||
myControl.Start()
|
myControl.Start()
|
||||||
@ -1030,8 +1027,8 @@ func TestRaceRegression(t *testing.T) {
|
|||||||
//them rx stage:2 initiatorIndex=120607833 responderIndex=4209862089
|
//them rx stage:2 initiatorIndex=120607833 responderIndex=4209862089
|
||||||
|
|
||||||
t.Log("Start both handshakes")
|
t.Log("Start both handshakes")
|
||||||
myControl.InjectTunUDPPacket(theirVpnIpNet.Addr(), 80, 80, []byte("Hi from me"))
|
myControl.InjectTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi from me"))
|
||||||
theirControl.InjectTunUDPPacket(myVpnIpNet.Addr(), 80, 80, []byte("Hi from them"))
|
theirControl.InjectTunUDPPacket(myVpnIpNet[0].Addr(), 80, theirVpnIpNet[0].Addr(), 80, []byte("Hi from them"))
|
||||||
|
|
||||||
t.Log("Get both stage 1")
|
t.Log("Get both stage 1")
|
||||||
myStage1ForThem := myControl.GetFromUDP(true)
|
myStage1ForThem := myControl.GetFromUDP(true)
|
||||||
@ -1061,12 +1058,52 @@ func TestRaceRegression(t *testing.T) {
|
|||||||
r.RenderHostmaps("Starting hostmaps", myControl, theirControl)
|
r.RenderHostmaps("Starting hostmaps", myControl, theirControl)
|
||||||
|
|
||||||
t.Log("Make sure the tunnel still works")
|
t.Log("Make sure the tunnel still works")
|
||||||
assertTunnel(t, myVpnIpNet.Addr(), theirVpnIpNet.Addr(), myControl, theirControl, r)
|
assertTunnel(t, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), myControl, theirControl, r)
|
||||||
|
|
||||||
myControl.Stop()
|
myControl.Stop()
|
||||||
theirControl.Stop()
|
theirControl.Stop()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestV2NonPrimaryWithLighthouse(t *testing.T) {
|
||||||
|
ca, _, caKey, _ := NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
|
||||||
|
lhControl, lhVpnIpNet, lhUdpAddr, _ := newSimpleServer(cert.Version2, ca, caKey, "lh ", "10.128.0.1/24, ff::1/64", m{"lighthouse": m{"am_lighthouse": true}})
|
||||||
|
|
||||||
|
o := m{
|
||||||
|
"static_host_map": m{
|
||||||
|
lhVpnIpNet[1].Addr().String(): []string{lhUdpAddr.String()},
|
||||||
|
},
|
||||||
|
"lighthouse": m{
|
||||||
|
"hosts": []string{lhVpnIpNet[1].Addr().String()},
|
||||||
|
"local_allow_list": m{
|
||||||
|
// Try and block our lighthouse updates from using the actual addresses assigned to this computer
|
||||||
|
// If we start discovering addresses the test router doesn't know about then test traffic cant flow
|
||||||
|
"10.0.0.0/24": true,
|
||||||
|
"::/0": false,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
myControl, myVpnIpNet, _, _ := newSimpleServer(cert.Version2, ca, caKey, "me ", "10.128.0.2/24, ff::2/64", o)
|
||||||
|
theirControl, theirVpnIpNet, _, _ := newSimpleServer(cert.Version2, ca, caKey, "them", "10.128.0.3/24, ff::3/64", o)
|
||||||
|
|
||||||
|
// Build a router so we don't have to reason who gets which packet
|
||||||
|
r := router.NewR(t, lhControl, myControl, theirControl)
|
||||||
|
defer r.RenderFlow()
|
||||||
|
|
||||||
|
// Start the servers
|
||||||
|
lhControl.Start()
|
||||||
|
myControl.Start()
|
||||||
|
theirControl.Start()
|
||||||
|
|
||||||
|
t.Log("Stand up an ipv6 tunnel between me and them")
|
||||||
|
assert.True(t, myVpnIpNet[1].Addr().Is6())
|
||||||
|
assert.True(t, theirVpnIpNet[1].Addr().Is6())
|
||||||
|
assertTunnel(t, myVpnIpNet[1].Addr(), theirVpnIpNet[1].Addr(), myControl, theirControl, r)
|
||||||
|
|
||||||
|
lhControl.Stop()
|
||||||
|
myControl.Stop()
|
||||||
|
theirControl.Stop()
|
||||||
|
}
|
||||||
|
|
||||||
//TODO: test
|
//TODO: test
|
||||||
// Race winner renews and handshakes
|
// Race winner renews and handshakes
|
||||||
// Race loser renews and handshakes
|
// Race loser renews and handshakes
|
||||||
|
|||||||
@ -48,7 +48,7 @@ func NewTestCaCert(before, after time.Time, networks, unsafeNetworks []netip.Pre
|
|||||||
|
|
||||||
// NewTestCert will generate a signed certificate with the provided details.
|
// NewTestCert will generate a signed certificate with the provided details.
|
||||||
// Expiry times are defaulted if you do not pass them in
|
// Expiry times are defaulted if you do not pass them in
|
||||||
func NewTestCert(ca cert.Certificate, key []byte, name string, before, after time.Time, networks, unsafeNetworks []netip.Prefix, groups []string) (cert.Certificate, []byte, []byte, []byte) {
|
func NewTestCert(v cert.Version, ca cert.Certificate, key []byte, name string, before, after time.Time, networks, unsafeNetworks []netip.Prefix, groups []string) (cert.Certificate, []byte, []byte, []byte) {
|
||||||
if before.IsZero() {
|
if before.IsZero() {
|
||||||
before = time.Now().Add(time.Second * -60).Round(time.Second)
|
before = time.Now().Add(time.Second * -60).Round(time.Second)
|
||||||
}
|
}
|
||||||
@ -59,7 +59,7 @@ func NewTestCert(ca cert.Certificate, key []byte, name string, before, after tim
|
|||||||
|
|
||||||
pub, rawPriv := x25519Keypair()
|
pub, rawPriv := x25519Keypair()
|
||||||
nc := &cert.TBSCertificate{
|
nc := &cert.TBSCertificate{
|
||||||
Version: cert.Version1,
|
Version: v,
|
||||||
Name: name,
|
Name: name,
|
||||||
Networks: networks,
|
Networks: networks,
|
||||||
UnsafeNetworks: unsafeNetworks,
|
UnsafeNetworks: unsafeNetworks,
|
||||||
|
|||||||
@ -8,6 +8,7 @@ import (
|
|||||||
"io"
|
"io"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
"os"
|
"os"
|
||||||
|
"strings"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
@ -26,25 +27,35 @@ import (
|
|||||||
type m map[string]interface{}
|
type m map[string]interface{}
|
||||||
|
|
||||||
// newSimpleServer creates a nebula instance with many assumptions
|
// newSimpleServer creates a nebula instance with many assumptions
|
||||||
func newSimpleServer(caCrt cert.Certificate, caKey []byte, name string, sVpnIpNet string, overrides m) (*nebula.Control, netip.Prefix, netip.AddrPort, *config.C) {
|
func newSimpleServer(v cert.Version, caCrt cert.Certificate, caKey []byte, name string, sVpnNetworks string, overrides m) (*nebula.Control, []netip.Prefix, netip.AddrPort, *config.C) {
|
||||||
l := NewTestLogger()
|
l := NewTestLogger()
|
||||||
|
|
||||||
vpnIpNet, err := netip.ParsePrefix(sVpnIpNet)
|
var vpnNetworks []netip.Prefix
|
||||||
|
for _, sn := range strings.Split(sVpnNetworks, ",") {
|
||||||
|
vpnIpNet, err := netip.ParsePrefix(strings.TrimSpace(sn))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
panic(err)
|
panic(err)
|
||||||
}
|
}
|
||||||
|
vpnNetworks = append(vpnNetworks, vpnIpNet)
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(vpnNetworks) == 0 {
|
||||||
|
panic("no vpn networks")
|
||||||
|
}
|
||||||
|
|
||||||
var udpAddr netip.AddrPort
|
var udpAddr netip.AddrPort
|
||||||
if vpnIpNet.Addr().Is4() {
|
if vpnNetworks[0].Addr().Is4() {
|
||||||
budpIp := vpnIpNet.Addr().As4()
|
budpIp := vpnNetworks[0].Addr().As4()
|
||||||
budpIp[1] -= 128
|
budpIp[1] -= 128
|
||||||
udpAddr = netip.AddrPortFrom(netip.AddrFrom4(budpIp), 4242)
|
udpAddr = netip.AddrPortFrom(netip.AddrFrom4(budpIp), 4242)
|
||||||
} else {
|
} else {
|
||||||
budpIp := vpnIpNet.Addr().As16()
|
budpIp := vpnNetworks[0].Addr().As16()
|
||||||
budpIp[13] -= 128
|
// beef for funsies
|
||||||
|
budpIp[2] = 190
|
||||||
|
budpIp[3] = 239
|
||||||
udpAddr = netip.AddrPortFrom(netip.AddrFrom16(budpIp), 4242)
|
udpAddr = netip.AddrPortFrom(netip.AddrFrom16(budpIp), 4242)
|
||||||
}
|
}
|
||||||
_, _, myPrivKey, myPEM := NewTestCert(caCrt, caKey, name, time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{vpnIpNet}, nil, []string{})
|
_, _, myPrivKey, myPEM := NewTestCert(v, caCrt, caKey, name, time.Now(), time.Now().Add(5*time.Minute), vpnNetworks, nil, []string{})
|
||||||
|
|
||||||
caB, err := caCrt.MarshalPEM()
|
caB, err := caCrt.MarshalPEM()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -88,11 +99,16 @@ func newSimpleServer(caCrt cert.Certificate, caKey []byte, name string, sVpnIpNe
|
|||||||
}
|
}
|
||||||
|
|
||||||
if overrides != nil {
|
if overrides != nil {
|
||||||
err = mergo.Merge(&overrides, mc, mergo.WithAppendSlice)
|
final := m{}
|
||||||
|
err = mergo.Merge(&final, overrides, mergo.WithAppendSlice)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
panic(err)
|
panic(err)
|
||||||
}
|
}
|
||||||
mc = overrides
|
err = mergo.Merge(&final, mc, mergo.WithAppendSlice)
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
mc = final
|
||||||
}
|
}
|
||||||
|
|
||||||
cb, err := yaml.Marshal(mc)
|
cb, err := yaml.Marshal(mc)
|
||||||
@ -109,7 +125,7 @@ func newSimpleServer(caCrt cert.Certificate, caKey []byte, name string, sVpnIpNe
|
|||||||
panic(err)
|
panic(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
return control, vpnIpNet, udpAddr, c
|
return control, vpnNetworks, udpAddr, c
|
||||||
}
|
}
|
||||||
|
|
||||||
type doneCb func()
|
type doneCb func()
|
||||||
@ -132,27 +148,28 @@ func deadline(t *testing.T, seconds time.Duration) doneCb {
|
|||||||
|
|
||||||
func assertTunnel(t *testing.T, vpnIpA, vpnIpB netip.Addr, controlA, controlB *nebula.Control, r *router.R) {
|
func assertTunnel(t *testing.T, vpnIpA, vpnIpB netip.Addr, controlA, controlB *nebula.Control, r *router.R) {
|
||||||
// Send a packet from them to me
|
// Send a packet from them to me
|
||||||
controlB.InjectTunUDPPacket(vpnIpA, 80, 90, []byte("Hi from B"))
|
controlB.InjectTunUDPPacket(vpnIpA, 80, vpnIpB, 90, []byte("Hi from B"))
|
||||||
bPacket := r.RouteForAllUntilTxTun(controlA)
|
bPacket := r.RouteForAllUntilTxTun(controlA)
|
||||||
assertUdpPacket(t, []byte("Hi from B"), bPacket, vpnIpB, vpnIpA, 90, 80)
|
assertUdpPacket(t, []byte("Hi from B"), bPacket, vpnIpB, vpnIpA, 90, 80)
|
||||||
|
|
||||||
// And once more from me to them
|
// And once more from me to them
|
||||||
controlA.InjectTunUDPPacket(vpnIpB, 80, 90, []byte("Hello from A"))
|
controlA.InjectTunUDPPacket(vpnIpB, 80, vpnIpA, 90, []byte("Hello from A"))
|
||||||
aPacket := r.RouteForAllUntilTxTun(controlB)
|
aPacket := r.RouteForAllUntilTxTun(controlB)
|
||||||
assertUdpPacket(t, []byte("Hello from A"), aPacket, vpnIpA, vpnIpB, 90, 80)
|
assertUdpPacket(t, []byte("Hello from A"), aPacket, vpnIpA, vpnIpB, 90, 80)
|
||||||
}
|
}
|
||||||
|
|
||||||
func assertHostInfoPair(t *testing.T, addrA, addrB netip.AddrPort, vpnIpA, vpnIpB netip.Addr, controlA, controlB *nebula.Control) {
|
func assertHostInfoPair(t *testing.T, addrA, addrB netip.AddrPort, vpnNetsA, vpnNetsB []netip.Prefix, controlA, controlB *nebula.Control) {
|
||||||
// Get both host infos
|
// Get both host infos
|
||||||
hBinA := controlA.GetHostInfoByVpnIp(vpnIpB, false)
|
//TODO: we may want to loop over each vpnAddr and assert all the things
|
||||||
assert.NotNil(t, hBinA, "Host B was not found by vpnIp in controlA")
|
hBinA := controlA.GetHostInfoByVpnAddr(vpnNetsB[0].Addr(), false)
|
||||||
|
assert.NotNil(t, hBinA, "Host B was not found by vpnAddr in controlA")
|
||||||
|
|
||||||
hAinB := controlB.GetHostInfoByVpnIp(vpnIpA, false)
|
hAinB := controlB.GetHostInfoByVpnAddr(vpnNetsA[0].Addr(), false)
|
||||||
assert.NotNil(t, hAinB, "Host A was not found by vpnIp in controlB")
|
assert.NotNil(t, hAinB, "Host A was not found by vpnAddr in controlB")
|
||||||
|
|
||||||
// Check that both vpn and real addr are correct
|
// Check that both vpn and real addr are correct
|
||||||
assert.Equal(t, vpnIpB, hBinA.VpnIp, "Host B VpnIp is wrong in control A")
|
assert.EqualValues(t, getAddrs(vpnNetsB), hBinA.VpnAddrs, "Host B VpnIp is wrong in control A")
|
||||||
assert.Equal(t, vpnIpA, hAinB.VpnIp, "Host A VpnIp is wrong in control B")
|
assert.EqualValues(t, getAddrs(vpnNetsA), hAinB.VpnAddrs, "Host A VpnIp is wrong in control B")
|
||||||
|
|
||||||
assert.Equal(t, addrB, hBinA.CurrentRemote, "Host B remote is wrong in control A")
|
assert.Equal(t, addrB, hBinA.CurrentRemote, "Host B remote is wrong in control A")
|
||||||
assert.Equal(t, addrA, hAinB.CurrentRemote, "Host A remote is wrong in control B")
|
assert.Equal(t, addrA, hAinB.CurrentRemote, "Host A remote is wrong in control B")
|
||||||
@ -179,6 +196,33 @@ func assertHostInfoPair(t *testing.T, addrA, addrB netip.AddrPort, vpnIpA, vpnIp
|
|||||||
}
|
}
|
||||||
|
|
||||||
func assertUdpPacket(t *testing.T, expected, b []byte, fromIp, toIp netip.Addr, fromPort, toPort uint16) {
|
func assertUdpPacket(t *testing.T, expected, b []byte, fromIp, toIp netip.Addr, fromPort, toPort uint16) {
|
||||||
|
if toIp.Is6() {
|
||||||
|
assertUdpPacket6(t, expected, b, fromIp, toIp, fromPort, toPort)
|
||||||
|
} else {
|
||||||
|
assertUdpPacket4(t, expected, b, fromIp, toIp, fromPort, toPort)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func assertUdpPacket6(t *testing.T, expected, b []byte, fromIp, toIp netip.Addr, fromPort, toPort uint16) {
|
||||||
|
packet := gopacket.NewPacket(b, layers.LayerTypeIPv6, gopacket.Lazy)
|
||||||
|
v6 := packet.Layer(layers.LayerTypeIPv6).(*layers.IPv6)
|
||||||
|
assert.NotNil(t, v6, "No ipv6 data found")
|
||||||
|
|
||||||
|
assert.Equal(t, fromIp.AsSlice(), []byte(v6.SrcIP), "Source ip was incorrect")
|
||||||
|
assert.Equal(t, toIp.AsSlice(), []byte(v6.DstIP), "Dest ip was incorrect")
|
||||||
|
|
||||||
|
udp := packet.Layer(layers.LayerTypeUDP).(*layers.UDP)
|
||||||
|
assert.NotNil(t, udp, "No udp data found")
|
||||||
|
|
||||||
|
assert.Equal(t, fromPort, uint16(udp.SrcPort), "Source port was incorrect")
|
||||||
|
assert.Equal(t, toPort, uint16(udp.DstPort), "Dest port was incorrect")
|
||||||
|
|
||||||
|
data := packet.ApplicationLayer()
|
||||||
|
assert.NotNil(t, data)
|
||||||
|
assert.Equal(t, expected, data.Payload(), "Data was incorrect")
|
||||||
|
}
|
||||||
|
|
||||||
|
func assertUdpPacket4(t *testing.T, expected, b []byte, fromIp, toIp netip.Addr, fromPort, toPort uint16) {
|
||||||
packet := gopacket.NewPacket(b, layers.LayerTypeIPv4, gopacket.Lazy)
|
packet := gopacket.NewPacket(b, layers.LayerTypeIPv4, gopacket.Lazy)
|
||||||
v4 := packet.Layer(layers.LayerTypeIPv4).(*layers.IPv4)
|
v4 := packet.Layer(layers.LayerTypeIPv4).(*layers.IPv4)
|
||||||
assert.NotNil(t, v4, "No ipv4 data found")
|
assert.NotNil(t, v4, "No ipv4 data found")
|
||||||
@ -197,6 +241,14 @@ func assertUdpPacket(t *testing.T, expected, b []byte, fromIp, toIp netip.Addr,
|
|||||||
assert.Equal(t, expected, data.Payload(), "Data was incorrect")
|
assert.Equal(t, expected, data.Payload(), "Data was incorrect")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func getAddrs(ns []netip.Prefix) []netip.Addr {
|
||||||
|
var a []netip.Addr
|
||||||
|
for _, n := range ns {
|
||||||
|
a = append(a, n.Addr())
|
||||||
|
}
|
||||||
|
return a
|
||||||
|
}
|
||||||
|
|
||||||
func NewTestLogger() *logrus.Logger {
|
func NewTestLogger() *logrus.Logger {
|
||||||
l := logrus.New()
|
l := logrus.New()
|
||||||
|
|
||||||
|
|||||||
@ -58,8 +58,9 @@ func renderHostmap(c *nebula.Control) (string, []*edge) {
|
|||||||
var lines []string
|
var lines []string
|
||||||
var globalLines []*edge
|
var globalLines []*edge
|
||||||
|
|
||||||
clusterName := strings.Trim(c.GetCert().Name(), " ")
|
crt := c.GetCertState().GetDefaultCertificate()
|
||||||
clusterVpnIp := c.GetCert().Networks()[0].Addr()
|
clusterName := strings.Trim(crt.Name(), " ")
|
||||||
|
clusterVpnIp := crt.Networks()[0].Addr()
|
||||||
r := fmt.Sprintf("\tsubgraph %s[\"%s (%s)\"]\n", clusterName, clusterName, clusterVpnIp)
|
r := fmt.Sprintf("\tsubgraph %s[\"%s (%s)\"]\n", clusterName, clusterName, clusterVpnIp)
|
||||||
|
|
||||||
hm := c.GetHostmap()
|
hm := c.GetHostmap()
|
||||||
@ -101,7 +102,7 @@ func renderHostmap(c *nebula.Control) (string, []*edge) {
|
|||||||
for _, idx := range indexes {
|
for _, idx := range indexes {
|
||||||
hi, ok := hm.Indexes[idx]
|
hi, ok := hm.Indexes[idx]
|
||||||
if ok {
|
if ok {
|
||||||
r += fmt.Sprintf("\t\t\t%v.%v[\"%v (%v)\"]\n", clusterName, idx, idx, hi.GetVpnIp())
|
r += fmt.Sprintf("\t\t\t%v.%v[\"%v (%v)\"]\n", clusterName, idx, idx, hi.GetVpnAddrs())
|
||||||
remoteClusterName := strings.Trim(hi.GetCert().Certificate.Name(), " ")
|
remoteClusterName := strings.Trim(hi.GetCert().Certificate.Name(), " ")
|
||||||
globalLines = append(globalLines, &edge{from: fmt.Sprintf("%v.%v", clusterName, idx), to: fmt.Sprintf("%v.%v", remoteClusterName, hi.GetRemoteIndex())})
|
globalLines = append(globalLines, &edge{from: fmt.Sprintf("%v.%v", clusterName, idx), to: fmt.Sprintf("%v.%v", remoteClusterName, hi.GetRemoteIndex())})
|
||||||
_ = hi
|
_ = hi
|
||||||
|
|||||||
@ -10,8 +10,8 @@ import (
|
|||||||
"os"
|
"os"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
"reflect"
|
"reflect"
|
||||||
|
"regexp"
|
||||||
"sort"
|
"sort"
|
||||||
"strings"
|
|
||||||
"sync"
|
"sync"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
@ -136,7 +136,10 @@ func NewR(t testing.TB, controls ...*nebula.Control) *R {
|
|||||||
panic("Duplicate listen address: " + addr.String())
|
panic("Duplicate listen address: " + addr.String())
|
||||||
}
|
}
|
||||||
|
|
||||||
r.vpnControls[c.GetVpnIp()] = c
|
for _, vpnAddr := range c.GetVpnAddrs() {
|
||||||
|
r.vpnControls[vpnAddr] = c
|
||||||
|
}
|
||||||
|
|
||||||
r.controls[addr] = c
|
r.controls[addr] = c
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -213,11 +216,11 @@ func (r *R) renderFlow() {
|
|||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
participants[addr] = struct{}{}
|
participants[addr] = struct{}{}
|
||||||
sanAddr := strings.Replace(addr.String(), ":", "-", 1)
|
sanAddr := normalizeName(addr.String())
|
||||||
participantsVals = append(participantsVals, sanAddr)
|
participantsVals = append(participantsVals, sanAddr)
|
||||||
fmt.Fprintf(
|
fmt.Fprintf(
|
||||||
f, " participant %s as Nebula: %s<br/>UDP: %s\n",
|
f, " participant %s as Nebula: %s<br/>UDP: %s\n",
|
||||||
sanAddr, e.packet.from.GetVpnIp(), sanAddr,
|
sanAddr, e.packet.from.GetVpnAddrs(), sanAddr,
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -250,9 +253,9 @@ func (r *R) renderFlow() {
|
|||||||
|
|
||||||
fmt.Fprintf(f,
|
fmt.Fprintf(f,
|
||||||
" %s%s%s: %s(%s), index %v, counter: %v\n",
|
" %s%s%s: %s(%s), index %v, counter: %v\n",
|
||||||
strings.Replace(p.from.GetUDPAddr().String(), ":", "-", 1),
|
normalizeName(p.from.GetUDPAddr().String()),
|
||||||
line,
|
line,
|
||||||
strings.Replace(p.to.GetUDPAddr().String(), ":", "-", 1),
|
normalizeName(p.to.GetUDPAddr().String()),
|
||||||
h.TypeName(), h.SubTypeName(), h.RemoteIndex, h.MessageCounter,
|
h.TypeName(), h.SubTypeName(), h.RemoteIndex, h.MessageCounter,
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
@ -267,6 +270,11 @@ func (r *R) renderFlow() {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func normalizeName(s string) string {
|
||||||
|
rx := regexp.MustCompile("[\\[\\]\\:]")
|
||||||
|
return rx.ReplaceAllLiteralString(s, "_")
|
||||||
|
}
|
||||||
|
|
||||||
// IgnoreFlow tells the router to stop recording future flows that matches the provided criteria.
|
// IgnoreFlow tells the router to stop recording future flows that matches the provided criteria.
|
||||||
// messageType and subType will target nebula underlay packets while tun will target nebula overlay packets
|
// messageType and subType will target nebula underlay packets while tun will target nebula overlay packets
|
||||||
// NOTE: This is a very broad system, if you set tun to true then no more tun traffic will be rendered
|
// NOTE: This is a very broad system, if you set tun to true then no more tun traffic will be rendered
|
||||||
@ -303,7 +311,7 @@ func (r *R) RenderHostmaps(title string, controls ...*nebula.Control) {
|
|||||||
func (r *R) renderHostmaps(title string) {
|
func (r *R) renderHostmaps(title string) {
|
||||||
c := maps.Values(r.controls)
|
c := maps.Values(r.controls)
|
||||||
sort.SliceStable(c, func(i, j int) bool {
|
sort.SliceStable(c, func(i, j int) bool {
|
||||||
return c[i].GetVpnIp().Compare(c[j].GetVpnIp()) > 0
|
return c[i].GetVpnAddrs()[0].Compare(c[j].GetVpnAddrs()[0]) > 0
|
||||||
})
|
})
|
||||||
|
|
||||||
s := renderHostmaps(c...)
|
s := renderHostmaps(c...)
|
||||||
@ -419,10 +427,11 @@ func (r *R) RouteUntilTxTun(sender *nebula.Control, receiver *nebula.Control) []
|
|||||||
// Nope, lets push the sender along
|
// Nope, lets push the sender along
|
||||||
case p := <-udpTx:
|
case p := <-udpTx:
|
||||||
r.Lock()
|
r.Lock()
|
||||||
c := r.getControl(sender.GetUDPAddr(), p.To, p)
|
a := sender.GetUDPAddr()
|
||||||
|
c := r.getControl(a, p.To, p)
|
||||||
if c == nil {
|
if c == nil {
|
||||||
r.Unlock()
|
r.Unlock()
|
||||||
panic("No control for udp tx")
|
panic("No control for udp tx " + a.String())
|
||||||
}
|
}
|
||||||
fp := r.unlockedInjectFlow(sender, c, p, false)
|
fp := r.unlockedInjectFlow(sender, c, p, false)
|
||||||
c.InjectUDPPacket(p)
|
c.InjectUDPPacket(p)
|
||||||
@ -475,10 +484,11 @@ func (r *R) RouteForAllUntilTxTun(receiver *nebula.Control) []byte {
|
|||||||
} else {
|
} else {
|
||||||
// we are a udp tx, route and continue
|
// we are a udp tx, route and continue
|
||||||
p := rx.Interface().(*udp.Packet)
|
p := rx.Interface().(*udp.Packet)
|
||||||
c := r.getControl(cm[x].GetUDPAddr(), p.To, p)
|
a := cm[x].GetUDPAddr()
|
||||||
|
c := r.getControl(a, p.To, p)
|
||||||
if c == nil {
|
if c == nil {
|
||||||
r.Unlock()
|
r.Unlock()
|
||||||
panic("No control for udp tx")
|
panic(fmt.Sprintf("No control for udp tx %s", p.To))
|
||||||
}
|
}
|
||||||
fp := r.unlockedInjectFlow(cm[x], c, p, false)
|
fp := r.unlockedInjectFlow(cm[x], c, p, false)
|
||||||
c.InjectUDPPacket(p)
|
c.InjectUDPPacket(p)
|
||||||
@ -711,30 +721,42 @@ func (r *R) getControl(fromAddr, toAddr netip.AddrPort, p *udp.Packet) *nebula.C
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (r *R) formatUdpPacket(p *packet) string {
|
func (r *R) formatUdpPacket(p *packet) string {
|
||||||
packet := gopacket.NewPacket(p.packet.Data, layers.LayerTypeIPv4, gopacket.Lazy)
|
var packet gopacket.Packet
|
||||||
v4 := packet.Layer(layers.LayerTypeIPv4).(*layers.IPv4)
|
var srcAddr netip.Addr
|
||||||
if v4 == nil {
|
|
||||||
panic("not an ipv4 packet")
|
packet = gopacket.NewPacket(p.packet.Data, layers.LayerTypeIPv6, gopacket.Lazy)
|
||||||
|
if packet.ErrorLayer() == nil {
|
||||||
|
v6 := packet.Layer(layers.LayerTypeIPv6).(*layers.IPv6)
|
||||||
|
if v6 == nil {
|
||||||
|
panic("not an ipv6 packet")
|
||||||
|
}
|
||||||
|
srcAddr, _ = netip.AddrFromSlice(v6.SrcIP)
|
||||||
|
} else {
|
||||||
|
packet = gopacket.NewPacket(p.packet.Data, layers.LayerTypeIPv4, gopacket.Lazy)
|
||||||
|
v6 := packet.Layer(layers.LayerTypeIPv4).(*layers.IPv4)
|
||||||
|
if v6 == nil {
|
||||||
|
panic("not an ipv6 packet")
|
||||||
|
}
|
||||||
|
srcAddr, _ = netip.AddrFromSlice(v6.SrcIP)
|
||||||
}
|
}
|
||||||
|
|
||||||
from := "unknown"
|
from := "unknown"
|
||||||
srcAddr, _ := netip.AddrFromSlice(v4.SrcIP)
|
|
||||||
if c, ok := r.vpnControls[srcAddr]; ok {
|
if c, ok := r.vpnControls[srcAddr]; ok {
|
||||||
from = c.GetUDPAddr().String()
|
from = c.GetUDPAddr().String()
|
||||||
}
|
}
|
||||||
|
|
||||||
udp := packet.Layer(layers.LayerTypeUDP).(*layers.UDP)
|
udpLayer := packet.Layer(layers.LayerTypeUDP).(*layers.UDP)
|
||||||
if udp == nil {
|
if udpLayer == nil {
|
||||||
panic("not a udp packet")
|
panic("not a udp packet")
|
||||||
}
|
}
|
||||||
|
|
||||||
data := packet.ApplicationLayer()
|
data := packet.ApplicationLayer()
|
||||||
return fmt.Sprintf(
|
return fmt.Sprintf(
|
||||||
" %s-->>%s: src port: %v<br/>dest port: %v<br/>data: \"%v\"\n",
|
" %s-->>%s: src port: %v<br/>dest port: %v<br/>data: \"%v\"\n",
|
||||||
strings.Replace(from, ":", "-", 1),
|
normalizeName(from),
|
||||||
strings.Replace(p.to.GetUDPAddr().String(), ":", "-", 1),
|
normalizeName(p.to.GetUDPAddr().String()),
|
||||||
udp.SrcPort,
|
udpLayer.SrcPort,
|
||||||
udp.DstPort,
|
udpLayer.DstPort,
|
||||||
string(data.Payload()),
|
string(data.Payload()),
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|||||||
@ -13,6 +13,12 @@ pki:
|
|||||||
# disconnect_invalid is a toggle to force a client to be disconnected if the certificate is expired or invalid.
|
# disconnect_invalid is a toggle to force a client to be disconnected if the certificate is expired or invalid.
|
||||||
#disconnect_invalid: true
|
#disconnect_invalid: true
|
||||||
|
|
||||||
|
# default_version controls which certificate version is used in handshakes.
|
||||||
|
# This setting only applies if both a v1 and a v2 certificate are configured, in which case it will default to `1`.
|
||||||
|
# Once all hosts in the mesh are configured with both a v1 and v2 certificate then this should be changed to `2`.
|
||||||
|
# After all hosts in the mesh are using a v2 certificate then v1 certificates are no longer needed.
|
||||||
|
# default_version: 1
|
||||||
|
|
||||||
# The static host map defines a set of hosts with fixed IP addresses on the internet (or any network).
|
# The static host map defines a set of hosts with fixed IP addresses on the internet (or any network).
|
||||||
# A host can have multiple fixed IP addresses defined here, and nebula will try each when establishing a tunnel.
|
# A host can have multiple fixed IP addresses defined here, and nebula will try each when establishing a tunnel.
|
||||||
# The syntax is:
|
# The syntax is:
|
||||||
@ -336,10 +342,13 @@ firewall:
|
|||||||
# host: `any` or a literal hostname, ie `test-host`
|
# host: `any` or a literal hostname, ie `test-host`
|
||||||
# group: `any` or a literal group name, ie `default-group`
|
# group: `any` or a literal group name, ie `default-group`
|
||||||
# groups: Same as group but accepts a list of values. Multiple values are AND'd together and a certificate would have to contain all groups to pass
|
# groups: Same as group but accepts a list of values. Multiple values are AND'd together and a certificate would have to contain all groups to pass
|
||||||
# cidr: a remote CIDR, `0.0.0.0/0` is any.
|
# cidr: a remote CIDR, `0.0.0.0/0` is any ipv4 and `::/0` is any ipv6. //TODO: we have a problem, firewall needs to understand this and should probably allow `any` for both
|
||||||
# local_cidr: a local CIDR, `0.0.0.0/0` is any. This could be used to filter destinations when using unsafe_routes.
|
# local_cidr: a local CIDR, `0.0.0.0/0` is any ipv4 and `::/0` is any ipv6. This could be used to filter destinations when using unsafe_routes.
|
||||||
# Default is `any` unless the certificate contains subnets and then the default is the ip issued in the certificate
|
# //TODO: probably should have an `any` that covers both ip versions
|
||||||
# if `default_local_cidr_any` is false, otherwise its `any`.
|
# If no unsafe networks are present in the certificate(s) or `default_local_cidr_any` is true then the default is any ipv4 or ipv6 network.
|
||||||
|
# Otherwise the default is any vpn network assigned to via the certificate.
|
||||||
|
# `default_local_cidr_any` defaults to false and is deprecated, it will be removed in a future release.
|
||||||
|
# If there are unsafe routes present its best to set `local_cidr` to whatever best fits the situation.
|
||||||
# ca_name: An issuing CA name
|
# ca_name: An issuing CA name
|
||||||
# ca_sha: An issuing CA shasum
|
# ca_sha: An issuing CA shasum
|
||||||
|
|
||||||
|
|||||||
116
firewall.go
116
firewall.go
@ -8,6 +8,7 @@ import (
|
|||||||
"hash/fnv"
|
"hash/fnv"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
"reflect"
|
"reflect"
|
||||||
|
"slices"
|
||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
@ -22,7 +23,8 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
type FirewallInterface interface {
|
type FirewallInterface interface {
|
||||||
AddRule(incoming bool, proto uint8, startPort int32, endPort int32, groups []string, host string, ip, localIp netip.Prefix, caName string, caSha string) error
|
//TODO: name these better addr, localAddr. Are they vpnAddrs?
|
||||||
|
AddRule(incoming bool, proto uint8, startPort int32, endPort int32, groups []string, host string, addr, localAddr netip.Prefix, caName string, caSha string) error
|
||||||
}
|
}
|
||||||
|
|
||||||
type conn struct {
|
type conn struct {
|
||||||
@ -51,9 +53,12 @@ type Firewall struct {
|
|||||||
UDPTimeout time.Duration //linux: 180s max
|
UDPTimeout time.Duration //linux: 180s max
|
||||||
DefaultTimeout time.Duration //linux: 600s
|
DefaultTimeout time.Duration //linux: 600s
|
||||||
|
|
||||||
// Used to ensure we don't emit local packets for ips we don't own
|
// routableNetworks describes the vpn addresses as well as any unsafe networks issued to us in the certificate.
|
||||||
localIps *bart.Table[struct{}]
|
// The vpn addresses are a full bit match while the unsafe networks only match the prefix
|
||||||
assignedCIDR netip.Prefix
|
routableNetworks *bart.Table[struct{}]
|
||||||
|
|
||||||
|
// assignedNetworks is a list of vpn networks assigned to us in the certificate.
|
||||||
|
assignedNetworks []netip.Prefix
|
||||||
hasUnsafeNetworks bool
|
hasUnsafeNetworks bool
|
||||||
|
|
||||||
rules string
|
rules string
|
||||||
@ -67,8 +72,8 @@ type Firewall struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
type firewallMetrics struct {
|
type firewallMetrics struct {
|
||||||
droppedLocalIP metrics.Counter
|
droppedLocalAddr metrics.Counter
|
||||||
droppedRemoteIP metrics.Counter
|
droppedRemoteAddr metrics.Counter
|
||||||
droppedNoRule metrics.Counter
|
droppedNoRule metrics.Counter
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -126,84 +131,87 @@ type firewallLocalCIDR struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// NewFirewall creates a new Firewall object. A TimerWheel is created for you from the provided timeouts.
|
// NewFirewall creates a new Firewall object. A TimerWheel is created for you from the provided timeouts.
|
||||||
|
// The certificate provided should be the highest version loaded in memory.
|
||||||
func NewFirewall(l *logrus.Logger, tcpTimeout, UDPTimeout, defaultTimeout time.Duration, c cert.Certificate) *Firewall {
|
func NewFirewall(l *logrus.Logger, tcpTimeout, UDPTimeout, defaultTimeout time.Duration, c cert.Certificate) *Firewall {
|
||||||
//TODO: error on 0 duration
|
//TODO: error on 0 duration
|
||||||
var min, max time.Duration
|
var tmin, tmax time.Duration
|
||||||
|
|
||||||
if tcpTimeout < UDPTimeout {
|
if tcpTimeout < UDPTimeout {
|
||||||
min = tcpTimeout
|
tmin = tcpTimeout
|
||||||
max = UDPTimeout
|
tmax = UDPTimeout
|
||||||
} else {
|
} else {
|
||||||
min = UDPTimeout
|
tmin = UDPTimeout
|
||||||
max = tcpTimeout
|
tmax = tcpTimeout
|
||||||
}
|
}
|
||||||
|
|
||||||
if defaultTimeout < min {
|
if defaultTimeout < tmin {
|
||||||
min = defaultTimeout
|
tmin = defaultTimeout
|
||||||
} else if defaultTimeout > max {
|
} else if defaultTimeout > tmax {
|
||||||
max = defaultTimeout
|
tmax = defaultTimeout
|
||||||
}
|
}
|
||||||
|
|
||||||
localIps := new(bart.Table[struct{}])
|
routableNetworks := new(bart.Table[struct{}])
|
||||||
var assignedCIDR netip.Prefix
|
var assignedNetworks []netip.Prefix
|
||||||
var assignedSet bool
|
|
||||||
for _, network := range c.Networks() {
|
for _, network := range c.Networks() {
|
||||||
nprefix := netip.PrefixFrom(network.Addr(), network.Addr().BitLen())
|
nprefix := netip.PrefixFrom(network.Addr(), network.Addr().BitLen())
|
||||||
localIps.Insert(nprefix, struct{}{})
|
routableNetworks.Insert(nprefix, struct{}{})
|
||||||
|
assignedNetworks = append(assignedNetworks, network)
|
||||||
if !assignedSet {
|
|
||||||
// Only grabbing the first one in the cert since any more than that currently has undefined behavior
|
|
||||||
assignedCIDR = nprefix
|
|
||||||
assignedSet = true
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
hasUnsafeNetworks := false
|
hasUnsafeNetworks := false
|
||||||
for _, n := range c.UnsafeNetworks() {
|
for _, n := range c.UnsafeNetworks() {
|
||||||
localIps.Insert(n, struct{}{})
|
routableNetworks.Insert(n, struct{}{})
|
||||||
hasUnsafeNetworks = true
|
hasUnsafeNetworks = true
|
||||||
}
|
}
|
||||||
|
|
||||||
return &Firewall{
|
return &Firewall{
|
||||||
Conntrack: &FirewallConntrack{
|
Conntrack: &FirewallConntrack{
|
||||||
Conns: make(map[firewall.Packet]*conn),
|
Conns: make(map[firewall.Packet]*conn),
|
||||||
TimerWheel: NewTimerWheel[firewall.Packet](min, max),
|
TimerWheel: NewTimerWheel[firewall.Packet](tmin, tmax),
|
||||||
},
|
},
|
||||||
InRules: newFirewallTable(),
|
InRules: newFirewallTable(),
|
||||||
OutRules: newFirewallTable(),
|
OutRules: newFirewallTable(),
|
||||||
TCPTimeout: tcpTimeout,
|
TCPTimeout: tcpTimeout,
|
||||||
UDPTimeout: UDPTimeout,
|
UDPTimeout: UDPTimeout,
|
||||||
DefaultTimeout: defaultTimeout,
|
DefaultTimeout: defaultTimeout,
|
||||||
localIps: localIps,
|
routableNetworks: routableNetworks,
|
||||||
assignedCIDR: assignedCIDR,
|
assignedNetworks: assignedNetworks,
|
||||||
hasUnsafeNetworks: hasUnsafeNetworks,
|
hasUnsafeNetworks: hasUnsafeNetworks,
|
||||||
l: l,
|
l: l,
|
||||||
|
|
||||||
incomingMetrics: firewallMetrics{
|
incomingMetrics: firewallMetrics{
|
||||||
droppedLocalIP: metrics.GetOrRegisterCounter("firewall.incoming.dropped.local_ip", nil),
|
droppedLocalAddr: metrics.GetOrRegisterCounter("firewall.incoming.dropped.local_addr", nil),
|
||||||
droppedRemoteIP: metrics.GetOrRegisterCounter("firewall.incoming.dropped.remote_ip", nil),
|
droppedRemoteAddr: metrics.GetOrRegisterCounter("firewall.incoming.dropped.remote_addr", nil),
|
||||||
droppedNoRule: metrics.GetOrRegisterCounter("firewall.incoming.dropped.no_rule", nil),
|
droppedNoRule: metrics.GetOrRegisterCounter("firewall.incoming.dropped.no_rule", nil),
|
||||||
},
|
},
|
||||||
outgoingMetrics: firewallMetrics{
|
outgoingMetrics: firewallMetrics{
|
||||||
droppedLocalIP: metrics.GetOrRegisterCounter("firewall.outgoing.dropped.local_ip", nil),
|
droppedLocalAddr: metrics.GetOrRegisterCounter("firewall.outgoing.dropped.local_addr", nil),
|
||||||
droppedRemoteIP: metrics.GetOrRegisterCounter("firewall.outgoing.dropped.remote_ip", nil),
|
droppedRemoteAddr: metrics.GetOrRegisterCounter("firewall.outgoing.dropped.remote_addr", nil),
|
||||||
droppedNoRule: metrics.GetOrRegisterCounter("firewall.outgoing.dropped.no_rule", nil),
|
droppedNoRule: metrics.GetOrRegisterCounter("firewall.outgoing.dropped.no_rule", nil),
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewFirewallFromConfig(l *logrus.Logger, nc cert.Certificate, c *config.C) (*Firewall, error) {
|
func NewFirewallFromConfig(l *logrus.Logger, cs *CertState, c *config.C) (*Firewall, error) {
|
||||||
|
certificate := cs.getCertificate(cert.Version2)
|
||||||
|
if certificate == nil {
|
||||||
|
certificate = cs.getCertificate(cert.Version1)
|
||||||
|
}
|
||||||
|
|
||||||
|
if certificate == nil {
|
||||||
|
panic("No certificate available to reconfigure the firewall")
|
||||||
|
}
|
||||||
|
|
||||||
fw := NewFirewall(
|
fw := NewFirewall(
|
||||||
l,
|
l,
|
||||||
c.GetDuration("firewall.conntrack.tcp_timeout", time.Minute*12),
|
c.GetDuration("firewall.conntrack.tcp_timeout", time.Minute*12),
|
||||||
c.GetDuration("firewall.conntrack.udp_timeout", time.Minute*3),
|
c.GetDuration("firewall.conntrack.udp_timeout", time.Minute*3),
|
||||||
c.GetDuration("firewall.conntrack.default_timeout", time.Minute*10),
|
c.GetDuration("firewall.conntrack.default_timeout", time.Minute*10),
|
||||||
nc,
|
certificate,
|
||||||
//TODO: max_connections
|
//TODO: max_connections
|
||||||
)
|
)
|
||||||
|
|
||||||
//TODO: Flip to false after v1.9 release
|
fw.defaultLocalCIDRAny = c.GetBool("firewall.default_local_cidr_any", false)
|
||||||
fw.defaultLocalCIDRAny = c.GetBool("firewall.default_local_cidr_any", true)
|
|
||||||
|
|
||||||
inboundAction := c.GetString("firewall.inbound_action", "drop")
|
inboundAction := c.GetString("firewall.inbound_action", "drop")
|
||||||
switch inboundAction {
|
switch inboundAction {
|
||||||
@ -283,7 +291,7 @@ func (f *Firewall) AddRule(incoming bool, proto uint8, startPort int32, endPort
|
|||||||
fp = ft.TCP
|
fp = ft.TCP
|
||||||
case firewall.ProtoUDP:
|
case firewall.ProtoUDP:
|
||||||
fp = ft.UDP
|
fp = ft.UDP
|
||||||
case firewall.ProtoICMP:
|
case firewall.ProtoICMP, firewall.ProtoICMPv6:
|
||||||
fp = ft.ICMP
|
fp = ft.ICMP
|
||||||
case firewall.ProtoAny:
|
case firewall.ProtoAny:
|
||||||
fp = ft.AnyProto
|
fp = ft.AnyProto
|
||||||
@ -424,26 +432,25 @@ func (f *Firewall) Drop(fp firewall.Packet, incoming bool, h *HostInfo, caPool *
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Make sure remote address matches nebula certificate
|
// Make sure remote address matches nebula certificate
|
||||||
if remoteCidr := h.remoteCidr; remoteCidr != nil {
|
if h.networks != nil {
|
||||||
//TODO: this would be better if we had a least specific match lookup, could waste time here, need to benchmark since the algo is different
|
_, ok := h.networks.Lookup(fp.RemoteAddr)
|
||||||
_, ok := remoteCidr.Lookup(fp.RemoteIP)
|
|
||||||
if !ok {
|
if !ok {
|
||||||
f.metrics(incoming).droppedRemoteIP.Inc(1)
|
f.metrics(incoming).droppedRemoteAddr.Inc(1)
|
||||||
return ErrInvalidRemoteIP
|
return ErrInvalidRemoteIP
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
// Simple case: Certificate has one IP and no subnets
|
// Simple case: Certificate has one IP and no subnets
|
||||||
if fp.RemoteIP != h.vpnIp {
|
//TODO: we can make this more performant
|
||||||
f.metrics(incoming).droppedRemoteIP.Inc(1)
|
if !slices.Contains(h.vpnAddrs, fp.RemoteAddr) {
|
||||||
|
f.metrics(incoming).droppedRemoteAddr.Inc(1)
|
||||||
return ErrInvalidRemoteIP
|
return ErrInvalidRemoteIP
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Make sure we are supposed to be handling this local ip address
|
// Make sure we are supposed to be handling this local ip address
|
||||||
//TODO: this would be better if we had a least specific match lookup, could waste time here, need to benchmark since the algo is different
|
_, ok := f.routableNetworks.Lookup(fp.LocalAddr)
|
||||||
_, ok := f.localIps.Lookup(fp.LocalIP)
|
|
||||||
if !ok {
|
if !ok {
|
||||||
f.metrics(incoming).droppedLocalIP.Inc(1)
|
f.metrics(incoming).droppedLocalAddr.Inc(1)
|
||||||
return ErrInvalidLocalIP
|
return ErrInvalidLocalIP
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -629,7 +636,7 @@ func (ft *FirewallTable) match(p firewall.Packet, incoming bool, c *cert.CachedC
|
|||||||
if ft.UDP.match(p, incoming, c, caPool) {
|
if ft.UDP.match(p, incoming, c, caPool) {
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
case firewall.ProtoICMP:
|
case firewall.ProtoICMP, firewall.ProtoICMPv6:
|
||||||
if ft.ICMP.match(p, incoming, c, caPool) {
|
if ft.ICMP.match(p, incoming, c, caPool) {
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
@ -859,9 +866,9 @@ func (fr *FirewallRule) match(p firewall.Packet, c *cert.CachedCertificate) bool
|
|||||||
}
|
}
|
||||||
|
|
||||||
matched := false
|
matched := false
|
||||||
prefix := netip.PrefixFrom(p.RemoteIP, p.RemoteIP.BitLen())
|
prefix := netip.PrefixFrom(p.RemoteAddr, p.RemoteAddr.BitLen())
|
||||||
fr.CIDR.EachLookupPrefix(prefix, func(prefix netip.Prefix, val *firewallLocalCIDR) bool {
|
fr.CIDR.EachLookupPrefix(prefix, func(prefix netip.Prefix, val *firewallLocalCIDR) bool {
|
||||||
if prefix.Contains(p.RemoteIP) && val.match(p, c) {
|
if prefix.Contains(p.RemoteAddr) && val.match(p, c) {
|
||||||
matched = true
|
matched = true
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
@ -877,9 +884,14 @@ func (flc *firewallLocalCIDR) addRule(f *Firewall, localIp netip.Prefix) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
localIp = f.assignedCIDR
|
for _, network := range f.assignedNetworks {
|
||||||
|
flc.LocalCIDR.Insert(network, struct{}{})
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
|
||||||
} else if localIp.Bits() == 0 {
|
} else if localIp.Bits() == 0 {
|
||||||
flc.Any = true
|
flc.Any = true
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
flc.LocalCIDR.Insert(localIp, struct{}{})
|
flc.LocalCIDR.Insert(localIp, struct{}{})
|
||||||
@ -895,7 +907,7 @@ func (flc *firewallLocalCIDR) match(p firewall.Packet, c *cert.CachedCertificate
|
|||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
_, ok := flc.LocalCIDR.Lookup(p.LocalIP)
|
_, ok := flc.LocalCIDR.Lookup(p.LocalAddr)
|
||||||
return ok
|
return ok
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@ -13,14 +13,15 @@ const (
|
|||||||
ProtoTCP = 6
|
ProtoTCP = 6
|
||||||
ProtoUDP = 17
|
ProtoUDP = 17
|
||||||
ProtoICMP = 1
|
ProtoICMP = 1
|
||||||
|
ProtoICMPv6 = 58
|
||||||
|
|
||||||
PortAny = 0 // Special value for matching `port: any`
|
PortAny = 0 // Special value for matching `port: any`
|
||||||
PortFragment = -1 // Special value for matching `port: fragment`
|
PortFragment = -1 // Special value for matching `port: fragment`
|
||||||
)
|
)
|
||||||
|
|
||||||
type Packet struct {
|
type Packet struct {
|
||||||
LocalIP netip.Addr
|
LocalAddr netip.Addr
|
||||||
RemoteIP netip.Addr
|
RemoteAddr netip.Addr
|
||||||
LocalPort uint16
|
LocalPort uint16
|
||||||
RemotePort uint16
|
RemotePort uint16
|
||||||
Protocol uint8
|
Protocol uint8
|
||||||
@ -29,8 +30,8 @@ type Packet struct {
|
|||||||
|
|
||||||
func (fp *Packet) Copy() *Packet {
|
func (fp *Packet) Copy() *Packet {
|
||||||
return &Packet{
|
return &Packet{
|
||||||
LocalIP: fp.LocalIP,
|
LocalAddr: fp.LocalAddr,
|
||||||
RemoteIP: fp.RemoteIP,
|
RemoteAddr: fp.RemoteAddr,
|
||||||
LocalPort: fp.LocalPort,
|
LocalPort: fp.LocalPort,
|
||||||
RemotePort: fp.RemotePort,
|
RemotePort: fp.RemotePort,
|
||||||
Protocol: fp.Protocol,
|
Protocol: fp.Protocol,
|
||||||
@ -51,8 +52,8 @@ func (fp Packet) MarshalJSON() ([]byte, error) {
|
|||||||
proto = fmt.Sprintf("unknown %v", fp.Protocol)
|
proto = fmt.Sprintf("unknown %v", fp.Protocol)
|
||||||
}
|
}
|
||||||
return json.Marshal(m{
|
return json.Marshal(m{
|
||||||
"LocalIP": fp.LocalIP.String(),
|
"LocalAddr": fp.LocalAddr.String(),
|
||||||
"RemoteIP": fp.RemoteIP.String(),
|
"RemoteAddr": fp.RemoteAddr.String(),
|
||||||
"LocalPort": fp.LocalPort,
|
"LocalPort": fp.LocalPort,
|
||||||
"RemotePort": fp.RemotePort,
|
"RemotePort": fp.RemotePort,
|
||||||
"Protocol": proto,
|
"Protocol": proto,
|
||||||
|
|||||||
@ -13,6 +13,7 @@ import (
|
|||||||
"github.com/slackhq/nebula/firewall"
|
"github.com/slackhq/nebula/firewall"
|
||||||
"github.com/slackhq/nebula/test"
|
"github.com/slackhq/nebula/test"
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestNewFirewall(t *testing.T) {
|
func TestNewFirewall(t *testing.T) {
|
||||||
@ -128,8 +129,8 @@ func TestFirewall_Drop(t *testing.T) {
|
|||||||
l.SetOutput(ob)
|
l.SetOutput(ob)
|
||||||
|
|
||||||
p := firewall.Packet{
|
p := firewall.Packet{
|
||||||
LocalIP: netip.MustParseAddr("1.2.3.4"),
|
LocalAddr: netip.MustParseAddr("1.2.3.4"),
|
||||||
RemoteIP: netip.MustParseAddr("1.2.3.4"),
|
RemoteAddr: netip.MustParseAddr("1.2.3.4"),
|
||||||
LocalPort: 10,
|
LocalPort: 10,
|
||||||
RemotePort: 90,
|
RemotePort: 90,
|
||||||
Protocol: firewall.ProtoUDP,
|
Protocol: firewall.ProtoUDP,
|
||||||
@ -149,9 +150,9 @@ func TestFirewall_Drop(t *testing.T) {
|
|||||||
InvertedGroups: map[string]struct{}{"default-group": {}},
|
InvertedGroups: map[string]struct{}{"default-group": {}},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
vpnIp: netip.MustParseAddr("1.2.3.4"),
|
vpnAddrs: []netip.Addr{netip.MustParseAddr("1.2.3.4")},
|
||||||
}
|
}
|
||||||
h.CreateRemoteCIDR(&c)
|
h.buildNetworks(&c)
|
||||||
|
|
||||||
fw := NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
|
fw := NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
|
||||||
assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"any"}, "", netip.Prefix{}, netip.Prefix{}, "", ""))
|
assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"any"}, "", netip.Prefix{}, netip.Prefix{}, "", ""))
|
||||||
@ -166,10 +167,10 @@ func TestFirewall_Drop(t *testing.T) {
|
|||||||
assert.NoError(t, fw.Drop(p, false, &h, cp, nil))
|
assert.NoError(t, fw.Drop(p, false, &h, cp, nil))
|
||||||
|
|
||||||
// test remote mismatch
|
// test remote mismatch
|
||||||
oldRemote := p.RemoteIP
|
oldRemote := p.RemoteAddr
|
||||||
p.RemoteIP = netip.MustParseAddr("1.2.3.10")
|
p.RemoteAddr = netip.MustParseAddr("1.2.3.10")
|
||||||
assert.Equal(t, fw.Drop(p, false, &h, cp, nil), ErrInvalidRemoteIP)
|
assert.Equal(t, fw.Drop(p, false, &h, cp, nil), ErrInvalidRemoteIP)
|
||||||
p.RemoteIP = oldRemote
|
p.RemoteAddr = oldRemote
|
||||||
|
|
||||||
// ensure signer doesn't get in the way of group checks
|
// ensure signer doesn't get in the way of group checks
|
||||||
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
|
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
|
||||||
@ -235,7 +236,7 @@ func BenchmarkFirewallTable_match(b *testing.B) {
|
|||||||
}
|
}
|
||||||
ip := netip.MustParsePrefix("9.254.254.254/32")
|
ip := netip.MustParsePrefix("9.254.254.254/32")
|
||||||
for n := 0; n < b.N; n++ {
|
for n := 0; n < b.N; n++ {
|
||||||
assert.False(b, ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 100, LocalIP: ip.Addr()}, true, c, cp))
|
assert.False(b, ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 100, LocalAddr: ip.Addr()}, true, c, cp))
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
@ -261,7 +262,7 @@ func BenchmarkFirewallTable_match(b *testing.B) {
|
|||||||
InvertedGroups: map[string]struct{}{"nope": {}},
|
InvertedGroups: map[string]struct{}{"nope": {}},
|
||||||
}
|
}
|
||||||
for n := 0; n < b.N; n++ {
|
for n := 0; n < b.N; n++ {
|
||||||
assert.False(b, ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 100, LocalIP: pfix.Addr()}, true, c, cp))
|
assert.False(b, ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 100, LocalAddr: pfix.Addr()}, true, c, cp))
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
@ -285,7 +286,7 @@ func BenchmarkFirewallTable_match(b *testing.B) {
|
|||||||
InvertedGroups: map[string]struct{}{"good-group": {}},
|
InvertedGroups: map[string]struct{}{"good-group": {}},
|
||||||
}
|
}
|
||||||
for n := 0; n < b.N; n++ {
|
for n := 0; n < b.N; n++ {
|
||||||
assert.True(b, ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 100, LocalIP: pfix.Addr()}, true, c, cp))
|
assert.True(b, ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 100, LocalAddr: pfix.Addr()}, true, c, cp))
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
@ -308,8 +309,8 @@ func TestFirewall_Drop2(t *testing.T) {
|
|||||||
l.SetOutput(ob)
|
l.SetOutput(ob)
|
||||||
|
|
||||||
p := firewall.Packet{
|
p := firewall.Packet{
|
||||||
LocalIP: netip.MustParseAddr("1.2.3.4"),
|
LocalAddr: netip.MustParseAddr("1.2.3.4"),
|
||||||
RemoteIP: netip.MustParseAddr("1.2.3.4"),
|
RemoteAddr: netip.MustParseAddr("1.2.3.4"),
|
||||||
LocalPort: 10,
|
LocalPort: 10,
|
||||||
RemotePort: 90,
|
RemotePort: 90,
|
||||||
Protocol: firewall.ProtoUDP,
|
Protocol: firewall.ProtoUDP,
|
||||||
@ -329,9 +330,9 @@ func TestFirewall_Drop2(t *testing.T) {
|
|||||||
ConnectionState: &ConnectionState{
|
ConnectionState: &ConnectionState{
|
||||||
peerCert: &c,
|
peerCert: &c,
|
||||||
},
|
},
|
||||||
vpnIp: network.Addr(),
|
vpnAddrs: []netip.Addr{network.Addr()},
|
||||||
}
|
}
|
||||||
h.CreateRemoteCIDR(c.Certificate)
|
h.buildNetworks(c.Certificate)
|
||||||
|
|
||||||
c1 := cert.CachedCertificate{
|
c1 := cert.CachedCertificate{
|
||||||
Certificate: &dummyCert{
|
Certificate: &dummyCert{
|
||||||
@ -345,7 +346,7 @@ func TestFirewall_Drop2(t *testing.T) {
|
|||||||
peerCert: &c1,
|
peerCert: &c1,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
h1.CreateRemoteCIDR(c1.Certificate)
|
h1.buildNetworks(c1.Certificate)
|
||||||
|
|
||||||
fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate)
|
fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate)
|
||||||
assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group", "test-group"}, "", netip.Prefix{}, netip.Prefix{}, "", ""))
|
assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group", "test-group"}, "", netip.Prefix{}, netip.Prefix{}, "", ""))
|
||||||
@ -364,8 +365,8 @@ func TestFirewall_Drop3(t *testing.T) {
|
|||||||
l.SetOutput(ob)
|
l.SetOutput(ob)
|
||||||
|
|
||||||
p := firewall.Packet{
|
p := firewall.Packet{
|
||||||
LocalIP: netip.MustParseAddr("1.2.3.4"),
|
LocalAddr: netip.MustParseAddr("1.2.3.4"),
|
||||||
RemoteIP: netip.MustParseAddr("1.2.3.4"),
|
RemoteAddr: netip.MustParseAddr("1.2.3.4"),
|
||||||
LocalPort: 1,
|
LocalPort: 1,
|
||||||
RemotePort: 1,
|
RemotePort: 1,
|
||||||
Protocol: firewall.ProtoUDP,
|
Protocol: firewall.ProtoUDP,
|
||||||
@ -391,9 +392,9 @@ func TestFirewall_Drop3(t *testing.T) {
|
|||||||
ConnectionState: &ConnectionState{
|
ConnectionState: &ConnectionState{
|
||||||
peerCert: &c1,
|
peerCert: &c1,
|
||||||
},
|
},
|
||||||
vpnIp: network.Addr(),
|
vpnAddrs: []netip.Addr{network.Addr()},
|
||||||
}
|
}
|
||||||
h1.CreateRemoteCIDR(c1.Certificate)
|
h1.buildNetworks(c1.Certificate)
|
||||||
|
|
||||||
c2 := cert.CachedCertificate{
|
c2 := cert.CachedCertificate{
|
||||||
Certificate: &dummyCert{
|
Certificate: &dummyCert{
|
||||||
@ -406,9 +407,9 @@ func TestFirewall_Drop3(t *testing.T) {
|
|||||||
ConnectionState: &ConnectionState{
|
ConnectionState: &ConnectionState{
|
||||||
peerCert: &c2,
|
peerCert: &c2,
|
||||||
},
|
},
|
||||||
vpnIp: network.Addr(),
|
vpnAddrs: []netip.Addr{network.Addr()},
|
||||||
}
|
}
|
||||||
h2.CreateRemoteCIDR(c2.Certificate)
|
h2.buildNetworks(c2.Certificate)
|
||||||
|
|
||||||
c3 := cert.CachedCertificate{
|
c3 := cert.CachedCertificate{
|
||||||
Certificate: &dummyCert{
|
Certificate: &dummyCert{
|
||||||
@ -421,9 +422,9 @@ func TestFirewall_Drop3(t *testing.T) {
|
|||||||
ConnectionState: &ConnectionState{
|
ConnectionState: &ConnectionState{
|
||||||
peerCert: &c3,
|
peerCert: &c3,
|
||||||
},
|
},
|
||||||
vpnIp: network.Addr(),
|
vpnAddrs: []netip.Addr{network.Addr()},
|
||||||
}
|
}
|
||||||
h3.CreateRemoteCIDR(c3.Certificate)
|
h3.buildNetworks(c3.Certificate)
|
||||||
|
|
||||||
fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate)
|
fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate)
|
||||||
assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 1, 1, []string{}, "host1", netip.Prefix{}, netip.Prefix{}, "", ""))
|
assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 1, 1, []string{}, "host1", netip.Prefix{}, netip.Prefix{}, "", ""))
|
||||||
@ -446,8 +447,8 @@ func TestFirewall_DropConntrackReload(t *testing.T) {
|
|||||||
l.SetOutput(ob)
|
l.SetOutput(ob)
|
||||||
|
|
||||||
p := firewall.Packet{
|
p := firewall.Packet{
|
||||||
LocalIP: netip.MustParseAddr("1.2.3.4"),
|
LocalAddr: netip.MustParseAddr("1.2.3.4"),
|
||||||
RemoteIP: netip.MustParseAddr("1.2.3.4"),
|
RemoteAddr: netip.MustParseAddr("1.2.3.4"),
|
||||||
LocalPort: 10,
|
LocalPort: 10,
|
||||||
RemotePort: 90,
|
RemotePort: 90,
|
||||||
Protocol: firewall.ProtoUDP,
|
Protocol: firewall.ProtoUDP,
|
||||||
@ -468,9 +469,9 @@ func TestFirewall_DropConntrackReload(t *testing.T) {
|
|||||||
ConnectionState: &ConnectionState{
|
ConnectionState: &ConnectionState{
|
||||||
peerCert: &c,
|
peerCert: &c,
|
||||||
},
|
},
|
||||||
vpnIp: network.Addr(),
|
vpnAddrs: []netip.Addr{network.Addr()},
|
||||||
}
|
}
|
||||||
h.CreateRemoteCIDR(c.Certificate)
|
h.buildNetworks(c.Certificate)
|
||||||
|
|
||||||
fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate)
|
fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate)
|
||||||
assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"any"}, "", netip.Prefix{}, netip.Prefix{}, "", ""))
|
assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"any"}, "", netip.Prefix{}, netip.Prefix{}, "", ""))
|
||||||
@ -622,55 +623,58 @@ func TestNewFirewallFromConfig(t *testing.T) {
|
|||||||
l := test.NewLogger()
|
l := test.NewLogger()
|
||||||
// Test a bad rule definition
|
// Test a bad rule definition
|
||||||
c := &dummyCert{}
|
c := &dummyCert{}
|
||||||
|
cs, err := newCertState(cert.Version2, nil, c, false, cert.Curve_CURVE25519, nil)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
conf := config.NewC(l)
|
conf := config.NewC(l)
|
||||||
conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": "asdf"}
|
conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": "asdf"}
|
||||||
_, err := NewFirewallFromConfig(l, c, conf)
|
_, err = NewFirewallFromConfig(l, cs, conf)
|
||||||
assert.EqualError(t, err, "firewall.outbound failed to parse, should be an array of rules")
|
assert.EqualError(t, err, "firewall.outbound failed to parse, should be an array of rules")
|
||||||
|
|
||||||
// Test both port and code
|
// Test both port and code
|
||||||
conf = config.NewC(l)
|
conf = config.NewC(l)
|
||||||
conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"port": "1", "code": "2"}}}
|
conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"port": "1", "code": "2"}}}
|
||||||
_, err = NewFirewallFromConfig(l, c, conf)
|
_, err = NewFirewallFromConfig(l, cs, conf)
|
||||||
assert.EqualError(t, err, "firewall.outbound rule #0; only one of port or code should be provided")
|
assert.EqualError(t, err, "firewall.outbound rule #0; only one of port or code should be provided")
|
||||||
|
|
||||||
// Test missing host, group, cidr, ca_name and ca_sha
|
// Test missing host, group, cidr, ca_name and ca_sha
|
||||||
conf = config.NewC(l)
|
conf = config.NewC(l)
|
||||||
conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{}}}
|
conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{}}}
|
||||||
_, err = NewFirewallFromConfig(l, c, conf)
|
_, err = NewFirewallFromConfig(l, cs, conf)
|
||||||
assert.EqualError(t, err, "firewall.outbound rule #0; at least one of host, group, cidr, local_cidr, ca_name, or ca_sha must be provided")
|
assert.EqualError(t, err, "firewall.outbound rule #0; at least one of host, group, cidr, local_cidr, ca_name, or ca_sha must be provided")
|
||||||
|
|
||||||
// Test code/port error
|
// Test code/port error
|
||||||
conf = config.NewC(l)
|
conf = config.NewC(l)
|
||||||
conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"code": "a", "host": "testh"}}}
|
conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"code": "a", "host": "testh"}}}
|
||||||
_, err = NewFirewallFromConfig(l, c, conf)
|
_, err = NewFirewallFromConfig(l, cs, conf)
|
||||||
assert.EqualError(t, err, "firewall.outbound rule #0; code was not a number; `a`")
|
assert.EqualError(t, err, "firewall.outbound rule #0; code was not a number; `a`")
|
||||||
|
|
||||||
conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"port": "a", "host": "testh"}}}
|
conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"port": "a", "host": "testh"}}}
|
||||||
_, err = NewFirewallFromConfig(l, c, conf)
|
_, err = NewFirewallFromConfig(l, cs, conf)
|
||||||
assert.EqualError(t, err, "firewall.outbound rule #0; port was not a number; `a`")
|
assert.EqualError(t, err, "firewall.outbound rule #0; port was not a number; `a`")
|
||||||
|
|
||||||
// Test proto error
|
// Test proto error
|
||||||
conf = config.NewC(l)
|
conf = config.NewC(l)
|
||||||
conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"code": "1", "host": "testh"}}}
|
conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"code": "1", "host": "testh"}}}
|
||||||
_, err = NewFirewallFromConfig(l, c, conf)
|
_, err = NewFirewallFromConfig(l, cs, conf)
|
||||||
assert.EqualError(t, err, "firewall.outbound rule #0; proto was not understood; ``")
|
assert.EqualError(t, err, "firewall.outbound rule #0; proto was not understood; ``")
|
||||||
|
|
||||||
// Test cidr parse error
|
// Test cidr parse error
|
||||||
conf = config.NewC(l)
|
conf = config.NewC(l)
|
||||||
conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"code": "1", "cidr": "testh", "proto": "any"}}}
|
conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"code": "1", "cidr": "testh", "proto": "any"}}}
|
||||||
_, err = NewFirewallFromConfig(l, c, conf)
|
_, err = NewFirewallFromConfig(l, cs, conf)
|
||||||
assert.EqualError(t, err, "firewall.outbound rule #0; cidr did not parse; netip.ParsePrefix(\"testh\"): no '/'")
|
assert.EqualError(t, err, "firewall.outbound rule #0; cidr did not parse; netip.ParsePrefix(\"testh\"): no '/'")
|
||||||
|
|
||||||
// Test local_cidr parse error
|
// Test local_cidr parse error
|
||||||
conf = config.NewC(l)
|
conf = config.NewC(l)
|
||||||
conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"code": "1", "local_cidr": "testh", "proto": "any"}}}
|
conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"code": "1", "local_cidr": "testh", "proto": "any"}}}
|
||||||
_, err = NewFirewallFromConfig(l, c, conf)
|
_, err = NewFirewallFromConfig(l, cs, conf)
|
||||||
assert.EqualError(t, err, "firewall.outbound rule #0; local_cidr did not parse; netip.ParsePrefix(\"testh\"): no '/'")
|
assert.EqualError(t, err, "firewall.outbound rule #0; local_cidr did not parse; netip.ParsePrefix(\"testh\"): no '/'")
|
||||||
|
|
||||||
// Test both group and groups
|
// Test both group and groups
|
||||||
conf = config.NewC(l)
|
conf = config.NewC(l)
|
||||||
conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "group": "a", "groups": []string{"b", "c"}}}}
|
conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "group": "a", "groups": []string{"b", "c"}}}}
|
||||||
_, err = NewFirewallFromConfig(l, c, conf)
|
_, err = NewFirewallFromConfig(l, cs, conf)
|
||||||
assert.EqualError(t, err, "firewall.inbound rule #0; only one of group or groups should be defined, both provided")
|
assert.EqualError(t, err, "firewall.inbound rule #0; only one of group or groups should be defined, both provided")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
1
go.mod
1
go.mod
@ -21,7 +21,6 @@ require (
|
|||||||
github.com/rcrowley/go-metrics v0.0.0-20201227073835-cf1acfcdf475
|
github.com/rcrowley/go-metrics v0.0.0-20201227073835-cf1acfcdf475
|
||||||
github.com/sirupsen/logrus v1.9.3
|
github.com/sirupsen/logrus v1.9.3
|
||||||
github.com/skip2/go-qrcode v0.0.0-20200617195104-da1b6568686e
|
github.com/skip2/go-qrcode v0.0.0-20200617195104-da1b6568686e
|
||||||
github.com/songgao/water v0.0.0-20200317203138-2b4b6d7c09d8
|
|
||||||
github.com/stefanberger/go-pkcs11uri v0.0.0-20230803200340-78284954bff6
|
github.com/stefanberger/go-pkcs11uri v0.0.0-20230803200340-78284954bff6
|
||||||
github.com/stretchr/testify v1.9.0
|
github.com/stretchr/testify v1.9.0
|
||||||
github.com/vishvananda/netlink v1.3.0
|
github.com/vishvananda/netlink v1.3.0
|
||||||
|
|||||||
2
go.sum
2
go.sum
@ -137,8 +137,6 @@ github.com/sirupsen/logrus v1.9.3 h1:dueUQJ1C2q9oE3F7wvmSGAaVtTmUizReu6fjN8uqzbQ
|
|||||||
github.com/sirupsen/logrus v1.9.3/go.mod h1:naHLuLoDiP4jHNo9R0sCBMtWGeIprob74mVsIT4qYEQ=
|
github.com/sirupsen/logrus v1.9.3/go.mod h1:naHLuLoDiP4jHNo9R0sCBMtWGeIprob74mVsIT4qYEQ=
|
||||||
github.com/skip2/go-qrcode v0.0.0-20200617195104-da1b6568686e h1:MRM5ITcdelLK2j1vwZ3Je0FKVCfqOLp5zO6trqMLYs0=
|
github.com/skip2/go-qrcode v0.0.0-20200617195104-da1b6568686e h1:MRM5ITcdelLK2j1vwZ3Je0FKVCfqOLp5zO6trqMLYs0=
|
||||||
github.com/skip2/go-qrcode v0.0.0-20200617195104-da1b6568686e/go.mod h1:XV66xRDqSt+GTGFMVlhk3ULuV0y9ZmzeVGR4mloJI3M=
|
github.com/skip2/go-qrcode v0.0.0-20200617195104-da1b6568686e/go.mod h1:XV66xRDqSt+GTGFMVlhk3ULuV0y9ZmzeVGR4mloJI3M=
|
||||||
github.com/songgao/water v0.0.0-20200317203138-2b4b6d7c09d8 h1:TG/diQgUe0pntT/2D9tmUCz4VNwm9MfrtPr0SU2qSX8=
|
|
||||||
github.com/songgao/water v0.0.0-20200317203138-2b4b6d7c09d8/go.mod h1:P5HUIBuIWKbyjl083/loAegFkfbFNx5i2qEP4CNbm7E=
|
|
||||||
github.com/stefanberger/go-pkcs11uri v0.0.0-20230803200340-78284954bff6 h1:pnnLyeX7o/5aX8qUQ69P/mLojDqwda8hFOCBTmP/6hw=
|
github.com/stefanberger/go-pkcs11uri v0.0.0-20230803200340-78284954bff6 h1:pnnLyeX7o/5aX8qUQ69P/mLojDqwda8hFOCBTmP/6hw=
|
||||||
github.com/stefanberger/go-pkcs11uri v0.0.0-20230803200340-78284954bff6/go.mod h1:39R/xuhNgVhi+K0/zst4TLrJrVmbm6LVgl4A0+ZFS5M=
|
github.com/stefanberger/go-pkcs11uri v0.0.0-20230803200340-78284954bff6/go.mod h1:39R/xuhNgVhi+K0/zst4TLrJrVmbm6LVgl4A0+ZFS5M=
|
||||||
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
|
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
|
||||||
|
|||||||
224
handshake_ix.go
224
handshake_ix.go
@ -2,10 +2,12 @@ package nebula
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"net/netip"
|
"net/netip"
|
||||||
|
"slices"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/flynn/noise"
|
"github.com/flynn/noise"
|
||||||
"github.com/sirupsen/logrus"
|
"github.com/sirupsen/logrus"
|
||||||
|
"github.com/slackhq/nebula/cert"
|
||||||
"github.com/slackhq/nebula/header"
|
"github.com/slackhq/nebula/header"
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -16,30 +18,60 @@ import (
|
|||||||
func ixHandshakeStage0(f *Interface, hh *HandshakeHostInfo) bool {
|
func ixHandshakeStage0(f *Interface, hh *HandshakeHostInfo) bool {
|
||||||
err := f.handshakeManager.allocateIndex(hh)
|
err := f.handshakeManager.allocateIndex(hh)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
f.l.WithError(err).WithField("vpnIp", hh.hostinfo.vpnIp).
|
f.l.WithError(err).WithField("vpnAddrs", hh.hostinfo.vpnAddrs).
|
||||||
WithField("handshake", m{"stage": 0, "style": "ix_psk0"}).Error("Failed to generate index")
|
WithField("handshake", m{"stage": 0, "style": "ix_psk0"}).Error("Failed to generate index")
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
certState := f.pki.GetCertState()
|
// If we're connecting to a v6 address we must use a v2 cert
|
||||||
ci := NewConnectionState(f.l, f.cipher, certState, true, noise.HandshakeIX, []byte{}, 0)
|
cs := f.pki.getCertState()
|
||||||
|
v := cs.defaultVersion
|
||||||
|
for _, a := range hh.hostinfo.vpnAddrs {
|
||||||
|
if a.Is6() {
|
||||||
|
v = cert.Version2
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
crt := cs.getCertificate(v)
|
||||||
|
if crt == nil {
|
||||||
|
f.l.WithField("vpnAddrs", hh.hostinfo.vpnAddrs).
|
||||||
|
WithField("handshake", m{"stage": 0, "style": "ix_psk0"}).
|
||||||
|
WithField("certVersion", v).
|
||||||
|
Error("Unable to handshake with host because no certificate is available")
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
crtHs := cs.getHandshakeBytes(v)
|
||||||
|
if crtHs == nil {
|
||||||
|
f.l.WithField("vpnAddrs", hh.hostinfo.vpnAddrs).
|
||||||
|
WithField("handshake", m{"stage": 0, "style": "ix_psk0"}).
|
||||||
|
WithField("certVersion", v).
|
||||||
|
Error("Unable to handshake with host because no certificate handshake bytes is available")
|
||||||
|
}
|
||||||
|
|
||||||
|
ci, err := NewConnectionState(f.l, cs, crt, true, noise.HandshakeIX)
|
||||||
|
if err != nil {
|
||||||
|
f.l.WithError(err).WithField("vpnAddrs", hh.hostinfo.vpnAddrs).
|
||||||
|
WithField("handshake", m{"stage": 0, "style": "ix_psk0"}).
|
||||||
|
WithField("certVersion", v).
|
||||||
|
Error("Failed to create connection state")
|
||||||
|
return false
|
||||||
|
}
|
||||||
hh.hostinfo.ConnectionState = ci
|
hh.hostinfo.ConnectionState = ci
|
||||||
|
|
||||||
hsProto := &NebulaHandshakeDetails{
|
hs := &NebulaHandshake{
|
||||||
|
Details: &NebulaHandshakeDetails{
|
||||||
InitiatorIndex: hh.hostinfo.localIndexId,
|
InitiatorIndex: hh.hostinfo.localIndexId,
|
||||||
Time: uint64(time.Now().UnixNano()),
|
Time: uint64(time.Now().UnixNano()),
|
||||||
Cert: certState.RawCertificateNoKey,
|
Cert: crtHs,
|
||||||
|
CertVersion: uint32(v),
|
||||||
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
hsBytes := []byte{}
|
hsBytes, err := hs.Marshal()
|
||||||
|
|
||||||
hs := &NebulaHandshake{
|
|
||||||
Details: hsProto,
|
|
||||||
}
|
|
||||||
hsBytes, err = hs.Marshal()
|
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
f.l.WithError(err).WithField("vpnIp", hh.hostinfo.vpnIp).
|
f.l.WithError(err).WithField("vpnAddrs", hh.hostinfo.vpnAddrs).WithField("certVersion", v).
|
||||||
WithField("handshake", m{"stage": 0, "style": "ix_psk0"}).Error("Failed to marshal handshake message")
|
WithField("handshake", m{"stage": 0, "style": "ix_psk0"}).Error("Failed to marshal handshake message")
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
@ -48,7 +80,7 @@ func ixHandshakeStage0(f *Interface, hh *HandshakeHostInfo) bool {
|
|||||||
|
|
||||||
msg, _, _, err := ci.H.WriteMessage(h, hsBytes)
|
msg, _, _, err := ci.H.WriteMessage(h, hsBytes)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
f.l.WithError(err).WithField("vpnIp", hh.hostinfo.vpnIp).
|
f.l.WithError(err).WithField("vpnAddrs", hh.hostinfo.vpnAddrs).
|
||||||
WithField("handshake", m{"stage": 0, "style": "ix_psk0"}).Error("Failed to call noise.WriteMessage")
|
WithField("handshake", m{"stage": 0, "style": "ix_psk0"}).Error("Failed to call noise.WriteMessage")
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
@ -63,30 +95,44 @@ func ixHandshakeStage0(f *Interface, hh *HandshakeHostInfo) bool {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet []byte, h *header.H) {
|
func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet []byte, h *header.H) {
|
||||||
certState := f.pki.GetCertState()
|
cs := f.pki.getCertState()
|
||||||
ci := NewConnectionState(f.l, f.cipher, certState, false, noise.HandshakeIX, []byte{}, 0)
|
crt := cs.GetDefaultCertificate()
|
||||||
|
if crt == nil {
|
||||||
|
f.l.WithField("udpAddr", addr).
|
||||||
|
WithField("handshake", m{"stage": 0, "style": "ix_psk0"}).
|
||||||
|
WithField("certVersion", cs.defaultVersion).
|
||||||
|
Error("Unable to handshake with host because no certificate is available")
|
||||||
|
}
|
||||||
|
|
||||||
|
ci, err := NewConnectionState(f.l, cs, crt, false, noise.HandshakeIX)
|
||||||
|
if err != nil {
|
||||||
|
f.l.WithError(err).WithField("udpAddr", addr).
|
||||||
|
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).
|
||||||
|
Error("Failed to create connection state")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
// Mark packet 1 as seen so it doesn't show up as missed
|
// Mark packet 1 as seen so it doesn't show up as missed
|
||||||
ci.window.Update(f.l, 1)
|
ci.window.Update(f.l, 1)
|
||||||
|
|
||||||
msg, _, _, err := ci.H.ReadMessage(nil, packet[header.Len:])
|
msg, _, _, err := ci.H.ReadMessage(nil, packet[header.Len:])
|
||||||
if err != nil {
|
if err != nil {
|
||||||
f.l.WithError(err).WithField("udpAddr", addr).
|
f.l.WithError(err).WithField("udpAddr", addr).
|
||||||
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).Error("Failed to call noise.ReadMessage")
|
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).
|
||||||
|
Error("Failed to call noise.ReadMessage")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
hs := &NebulaHandshake{}
|
hs := &NebulaHandshake{}
|
||||||
err = hs.Unmarshal(msg)
|
err = hs.Unmarshal(msg)
|
||||||
/*
|
|
||||||
l.Debugln("GOT INDEX: ", hs.Details.InitiatorIndex)
|
|
||||||
*/
|
|
||||||
if err != nil || hs.Details == nil {
|
if err != nil || hs.Details == nil {
|
||||||
f.l.WithError(err).WithField("udpAddr", addr).
|
f.l.WithError(err).WithField("udpAddr", addr).
|
||||||
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).Error("Failed unmarshal handshake message")
|
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).
|
||||||
|
Error("Failed unmarshal handshake message")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
remoteCert, err := RecombineCertAndValidate(ci.H, hs.Details.Cert, f.pki.GetCAPool())
|
remoteCert, err := cert.RecombineAndValidate(cert.Version(hs.Details.CertVersion), hs.Details.Cert, ci.H.PeerStatic(), ci.Curve(), f.pki.GetCAPool())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
e := f.l.WithError(err).WithField("udpAddr", addr).
|
e := f.l.WithError(err).WithField("udpAddr", addr).
|
||||||
WithField("handshake", m{"stage": 1, "style": "ix_psk0"})
|
WithField("handshake", m{"stage": 1, "style": "ix_psk0"})
|
||||||
@ -99,6 +145,20 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if remoteCert.Certificate.Version() != ci.myCert.Version() {
|
||||||
|
// We started off using the wrong certificate version, lets see if we can match the version that was sent to us
|
||||||
|
rc := cs.getCertificate(remoteCert.Certificate.Version())
|
||||||
|
if rc == nil {
|
||||||
|
f.l.WithError(err).WithField("udpAddr", addr).
|
||||||
|
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).WithField("cert", remoteCert).
|
||||||
|
Info("Unable to handshake with host due to missing certificate version")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Record the certificate we are actually using
|
||||||
|
ci.myCert = rc
|
||||||
|
}
|
||||||
|
|
||||||
if len(remoteCert.Certificate.Networks()) == 0 {
|
if len(remoteCert.Certificate.Networks()) == 0 {
|
||||||
e := f.l.WithError(err).WithField("udpAddr", addr).
|
e := f.l.WithError(err).WithField("udpAddr", addr).
|
||||||
WithField("handshake", m{"stage": 1, "style": "ix_psk0"})
|
WithField("handshake", m{"stage": 1, "style": "ix_psk0"})
|
||||||
@ -111,13 +171,16 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
vpnIp := remoteCert.Certificate.Networks()[0].Addr().Unmap()
|
var vpnAddrs []netip.Addr
|
||||||
certName := remoteCert.Certificate.Name()
|
certName := remoteCert.Certificate.Name()
|
||||||
fingerprint := remoteCert.Fingerprint
|
fingerprint := remoteCert.Fingerprint
|
||||||
issuer := remoteCert.Certificate.Issuer()
|
issuer := remoteCert.Certificate.Issuer()
|
||||||
|
|
||||||
if vpnIp == f.myVpnNet.Addr() {
|
for _, network := range remoteCert.Certificate.Networks() {
|
||||||
f.l.WithField("vpnIp", vpnIp).WithField("udpAddr", addr).
|
vpnAddr := network.Addr()
|
||||||
|
_, found := f.myVpnAddrsTable.Lookup(vpnAddr)
|
||||||
|
if found {
|
||||||
|
f.l.WithField("vpnAddr", vpnAddr).WithField("udpAddr", addr).
|
||||||
WithField("certName", certName).
|
WithField("certName", certName).
|
||||||
WithField("fingerprint", fingerprint).
|
WithField("fingerprint", fingerprint).
|
||||||
WithField("issuer", issuer).
|
WithField("issuer", issuer).
|
||||||
@ -126,15 +189,18 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet
|
|||||||
}
|
}
|
||||||
|
|
||||||
if addr.IsValid() {
|
if addr.IsValid() {
|
||||||
if !f.lightHouse.GetRemoteAllowList().Allow(vpnIp, addr.Addr()) {
|
if !f.lightHouse.GetRemoteAllowList().Allow(vpnAddr, addr.Addr()) {
|
||||||
f.l.WithField("vpnIp", vpnIp).WithField("udpAddr", addr).Debug("lighthouse.remote_allow_list denied incoming handshake")
|
f.l.WithField("vpnAddr", vpnAddr).WithField("udpAddr", addr).Debug("lighthouse.remote_allow_list denied incoming handshake")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
vpnAddrs = append(vpnAddrs, vpnAddr)
|
||||||
|
}
|
||||||
|
|
||||||
myIndex, err := generateIndex(f.l)
|
myIndex, err := generateIndex(f.l)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
f.l.WithError(err).WithField("vpnIp", vpnIp).WithField("udpAddr", addr).
|
f.l.WithError(err).WithField("vpnAddrs", vpnAddrs).WithField("udpAddr", addr).
|
||||||
WithField("certName", certName).
|
WithField("certName", certName).
|
||||||
WithField("fingerprint", fingerprint).
|
WithField("fingerprint", fingerprint).
|
||||||
WithField("issuer", issuer).
|
WithField("issuer", issuer).
|
||||||
@ -146,17 +212,17 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet
|
|||||||
ConnectionState: ci,
|
ConnectionState: ci,
|
||||||
localIndexId: myIndex,
|
localIndexId: myIndex,
|
||||||
remoteIndexId: hs.Details.InitiatorIndex,
|
remoteIndexId: hs.Details.InitiatorIndex,
|
||||||
vpnIp: vpnIp,
|
vpnAddrs: vpnAddrs,
|
||||||
HandshakePacket: make(map[uint8][]byte, 0),
|
HandshakePacket: make(map[uint8][]byte, 0),
|
||||||
lastHandshakeTime: hs.Details.Time,
|
lastHandshakeTime: hs.Details.Time,
|
||||||
relayState: RelayState{
|
relayState: RelayState{
|
||||||
relays: map[netip.Addr]struct{}{},
|
relays: map[netip.Addr]struct{}{},
|
||||||
relayForByIp: map[netip.Addr]*Relay{},
|
relayForByAddr: map[netip.Addr]*Relay{},
|
||||||
relayForByIdx: map[uint32]*Relay{},
|
relayForByIdx: map[uint32]*Relay{},
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
f.l.WithField("vpnIp", vpnIp).WithField("udpAddr", addr).
|
f.l.WithField("vpnAddrs", vpnAddrs).WithField("udpAddr", addr).
|
||||||
WithField("certName", certName).
|
WithField("certName", certName).
|
||||||
WithField("fingerprint", fingerprint).
|
WithField("fingerprint", fingerprint).
|
||||||
WithField("issuer", issuer).
|
WithField("issuer", issuer).
|
||||||
@ -165,13 +231,26 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet
|
|||||||
Info("Handshake message received")
|
Info("Handshake message received")
|
||||||
|
|
||||||
hs.Details.ResponderIndex = myIndex
|
hs.Details.ResponderIndex = myIndex
|
||||||
hs.Details.Cert = certState.RawCertificateNoKey
|
hs.Details.Cert = cs.getHandshakeBytes(ci.myCert.Version())
|
||||||
|
if hs.Details.Cert == nil {
|
||||||
|
f.l.WithField("vpnAddrs", vpnAddrs).WithField("udpAddr", addr).
|
||||||
|
WithField("certName", certName).
|
||||||
|
WithField("fingerprint", fingerprint).
|
||||||
|
WithField("issuer", issuer).
|
||||||
|
WithField("initiatorIndex", hs.Details.InitiatorIndex).WithField("responderIndex", hs.Details.ResponderIndex).
|
||||||
|
WithField("remoteIndex", h.RemoteIndex).WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).
|
||||||
|
WithField("certVersion", ci.myCert.Version()).
|
||||||
|
Error("Unable to handshake with host because no certificate handshake bytes is available")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
hs.Details.CertVersion = uint32(ci.myCert.Version())
|
||||||
// Update the time in case their clock is way off from ours
|
// Update the time in case their clock is way off from ours
|
||||||
hs.Details.Time = uint64(time.Now().UnixNano())
|
hs.Details.Time = uint64(time.Now().UnixNano())
|
||||||
|
|
||||||
hsBytes, err := hs.Marshal()
|
hsBytes, err := hs.Marshal()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
f.l.WithError(err).WithField("vpnIp", hostinfo.vpnIp).WithField("udpAddr", addr).
|
f.l.WithError(err).WithField("vpnAddrs", hostinfo.vpnAddrs).WithField("udpAddr", addr).
|
||||||
WithField("certName", certName).
|
WithField("certName", certName).
|
||||||
WithField("fingerprint", fingerprint).
|
WithField("fingerprint", fingerprint).
|
||||||
WithField("issuer", issuer).
|
WithField("issuer", issuer).
|
||||||
@ -182,14 +261,14 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet
|
|||||||
nh := header.Encode(make([]byte, header.Len), header.Version, header.Handshake, header.HandshakeIXPSK0, hs.Details.InitiatorIndex, 2)
|
nh := header.Encode(make([]byte, header.Len), header.Version, header.Handshake, header.HandshakeIXPSK0, hs.Details.InitiatorIndex, 2)
|
||||||
msg, dKey, eKey, err := ci.H.WriteMessage(nh, hsBytes)
|
msg, dKey, eKey, err := ci.H.WriteMessage(nh, hsBytes)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
f.l.WithError(err).WithField("vpnIp", hostinfo.vpnIp).WithField("udpAddr", addr).
|
f.l.WithError(err).WithField("vpnAddrs", hostinfo.vpnAddrs).WithField("udpAddr", addr).
|
||||||
WithField("certName", certName).
|
WithField("certName", certName).
|
||||||
WithField("fingerprint", fingerprint).
|
WithField("fingerprint", fingerprint).
|
||||||
WithField("issuer", issuer).
|
WithField("issuer", issuer).
|
||||||
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).Error("Failed to call noise.WriteMessage")
|
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).Error("Failed to call noise.WriteMessage")
|
||||||
return
|
return
|
||||||
} else if dKey == nil || eKey == nil {
|
} else if dKey == nil || eKey == nil {
|
||||||
f.l.WithField("vpnIp", hostinfo.vpnIp).WithField("udpAddr", addr).
|
f.l.WithField("vpnAddrs", hostinfo.vpnAddrs).WithField("udpAddr", addr).
|
||||||
WithField("certName", certName).
|
WithField("certName", certName).
|
||||||
WithField("fingerprint", fingerprint).
|
WithField("fingerprint", fingerprint).
|
||||||
WithField("issuer", issuer).
|
WithField("issuer", issuer).
|
||||||
@ -213,9 +292,9 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet
|
|||||||
ci.dKey = NewNebulaCipherState(dKey)
|
ci.dKey = NewNebulaCipherState(dKey)
|
||||||
ci.eKey = NewNebulaCipherState(eKey)
|
ci.eKey = NewNebulaCipherState(eKey)
|
||||||
|
|
||||||
hostinfo.remotes = f.lightHouse.QueryCache(vpnIp)
|
hostinfo.remotes = f.lightHouse.QueryCache(vpnAddrs)
|
||||||
hostinfo.SetRemote(addr)
|
hostinfo.SetRemote(addr)
|
||||||
hostinfo.CreateRemoteCIDR(remoteCert.Certificate)
|
hostinfo.buildNetworks(remoteCert.Certificate)
|
||||||
|
|
||||||
existing, err := f.handshakeManager.CheckAndComplete(hostinfo, 0, f)
|
existing, err := f.handshakeManager.CheckAndComplete(hostinfo, 0, f)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -225,7 +304,7 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet
|
|||||||
if existing.SetRemoteIfPreferred(f.hostMap, addr) {
|
if existing.SetRemoteIfPreferred(f.hostMap, addr) {
|
||||||
// Send a test packet to ensure the other side has also switched to
|
// Send a test packet to ensure the other side has also switched to
|
||||||
// the preferred remote
|
// the preferred remote
|
||||||
f.SendMessageToVpnIp(header.Test, header.TestRequest, vpnIp, []byte(""), make([]byte, 12, 12), make([]byte, mtu))
|
f.SendMessageToVpnAddr(header.Test, header.TestRequest, vpnAddrs[0], []byte(""), make([]byte, 12, 12), make([]byte, mtu))
|
||||||
}
|
}
|
||||||
|
|
||||||
msg = existing.HandshakePacket[2]
|
msg = existing.HandshakePacket[2]
|
||||||
@ -233,11 +312,11 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet
|
|||||||
if addr.IsValid() {
|
if addr.IsValid() {
|
||||||
err := f.outside.WriteTo(msg, addr)
|
err := f.outside.WriteTo(msg, addr)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
f.l.WithField("vpnIp", existing.vpnIp).WithField("udpAddr", addr).
|
f.l.WithField("vpnAddrs", existing.vpnAddrs).WithField("udpAddr", addr).
|
||||||
WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).WithField("cached", true).
|
WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).WithField("cached", true).
|
||||||
WithError(err).Error("Failed to send handshake message")
|
WithError(err).Error("Failed to send handshake message")
|
||||||
} else {
|
} else {
|
||||||
f.l.WithField("vpnIp", existing.vpnIp).WithField("udpAddr", addr).
|
f.l.WithField("vpnAddrs", existing.vpnAddrs).WithField("udpAddr", addr).
|
||||||
WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).WithField("cached", true).
|
WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).WithField("cached", true).
|
||||||
Info("Handshake message sent")
|
Info("Handshake message sent")
|
||||||
}
|
}
|
||||||
@ -247,16 +326,16 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet
|
|||||||
f.l.Error("Handshake send failed: both addr and via are nil.")
|
f.l.Error("Handshake send failed: both addr and via are nil.")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
hostinfo.relayState.InsertRelayTo(via.relayHI.vpnIp)
|
hostinfo.relayState.InsertRelayTo(via.relayHI.vpnAddrs[0])
|
||||||
f.SendVia(via.relayHI, via.relay, msg, make([]byte, 12), make([]byte, mtu), false)
|
f.SendVia(via.relayHI, via.relay, msg, make([]byte, 12), make([]byte, mtu), false)
|
||||||
f.l.WithField("vpnIp", existing.vpnIp).WithField("relay", via.relayHI.vpnIp).
|
f.l.WithField("vpnAddrs", existing.vpnAddrs).WithField("relay", via.relayHI.vpnAddrs[0]).
|
||||||
WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).WithField("cached", true).
|
WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).WithField("cached", true).
|
||||||
Info("Handshake message sent")
|
Info("Handshake message sent")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
case ErrExistingHostInfo:
|
case ErrExistingHostInfo:
|
||||||
// This means there was an existing tunnel and this handshake was older than the one we are currently based on
|
// This means there was an existing tunnel and this handshake was older than the one we are currently based on
|
||||||
f.l.WithField("vpnIp", vpnIp).WithField("udpAddr", addr).
|
f.l.WithField("vpnAddrs", vpnAddrs).WithField("udpAddr", addr).
|
||||||
WithField("certName", certName).
|
WithField("certName", certName).
|
||||||
WithField("oldHandshakeTime", existing.lastHandshakeTime).
|
WithField("oldHandshakeTime", existing.lastHandshakeTime).
|
||||||
WithField("newHandshakeTime", hostinfo.lastHandshakeTime).
|
WithField("newHandshakeTime", hostinfo.lastHandshakeTime).
|
||||||
@ -267,23 +346,23 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet
|
|||||||
Info("Handshake too old")
|
Info("Handshake too old")
|
||||||
|
|
||||||
// Send a test packet to trigger an authenticated tunnel test, this should suss out any lingering tunnel issues
|
// Send a test packet to trigger an authenticated tunnel test, this should suss out any lingering tunnel issues
|
||||||
f.SendMessageToVpnIp(header.Test, header.TestRequest, vpnIp, []byte(""), make([]byte, 12, 12), make([]byte, mtu))
|
f.SendMessageToVpnAddr(header.Test, header.TestRequest, vpnAddrs[0], []byte(""), make([]byte, 12, 12), make([]byte, mtu))
|
||||||
return
|
return
|
||||||
case ErrLocalIndexCollision:
|
case ErrLocalIndexCollision:
|
||||||
// This means we failed to insert because of collision on localIndexId. Just let the next handshake packet retry
|
// This means we failed to insert because of collision on localIndexId. Just let the next handshake packet retry
|
||||||
f.l.WithField("vpnIp", vpnIp).WithField("udpAddr", addr).
|
f.l.WithField("vpnAddrs", vpnAddrs).WithField("udpAddr", addr).
|
||||||
WithField("certName", certName).
|
WithField("certName", certName).
|
||||||
WithField("fingerprint", fingerprint).
|
WithField("fingerprint", fingerprint).
|
||||||
WithField("issuer", issuer).
|
WithField("issuer", issuer).
|
||||||
WithField("initiatorIndex", hs.Details.InitiatorIndex).WithField("responderIndex", hs.Details.ResponderIndex).
|
WithField("initiatorIndex", hs.Details.InitiatorIndex).WithField("responderIndex", hs.Details.ResponderIndex).
|
||||||
WithField("remoteIndex", h.RemoteIndex).WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).
|
WithField("remoteIndex", h.RemoteIndex).WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).
|
||||||
WithField("localIndex", hostinfo.localIndexId).WithField("collision", existing.vpnIp).
|
WithField("localIndex", hostinfo.localIndexId).WithField("collision", existing.vpnAddrs).
|
||||||
Error("Failed to add HostInfo due to localIndex collision")
|
Error("Failed to add HostInfo due to localIndex collision")
|
||||||
return
|
return
|
||||||
default:
|
default:
|
||||||
// Shouldn't happen, but just in case someone adds a new error type to CheckAndComplete
|
// Shouldn't happen, but just in case someone adds a new error type to CheckAndComplete
|
||||||
// And we forget to update it here
|
// And we forget to update it here
|
||||||
f.l.WithError(err).WithField("vpnIp", vpnIp).WithField("udpAddr", addr).
|
f.l.WithError(err).WithField("vpnAddrs", vpnAddrs).WithField("udpAddr", addr).
|
||||||
WithField("certName", certName).
|
WithField("certName", certName).
|
||||||
WithField("fingerprint", fingerprint).
|
WithField("fingerprint", fingerprint).
|
||||||
WithField("issuer", issuer).
|
WithField("issuer", issuer).
|
||||||
@ -299,7 +378,7 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet
|
|||||||
if addr.IsValid() {
|
if addr.IsValid() {
|
||||||
err = f.outside.WriteTo(msg, addr)
|
err = f.outside.WriteTo(msg, addr)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
f.l.WithField("vpnIp", vpnIp).WithField("udpAddr", addr).
|
f.l.WithField("vpnAddrs", vpnAddrs).WithField("udpAddr", addr).
|
||||||
WithField("certName", certName).
|
WithField("certName", certName).
|
||||||
WithField("fingerprint", fingerprint).
|
WithField("fingerprint", fingerprint).
|
||||||
WithField("issuer", issuer).
|
WithField("issuer", issuer).
|
||||||
@ -307,7 +386,7 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet
|
|||||||
WithField("remoteIndex", h.RemoteIndex).WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).
|
WithField("remoteIndex", h.RemoteIndex).WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).
|
||||||
WithError(err).Error("Failed to send handshake")
|
WithError(err).Error("Failed to send handshake")
|
||||||
} else {
|
} else {
|
||||||
f.l.WithField("vpnIp", vpnIp).WithField("udpAddr", addr).
|
f.l.WithField("vpnAddrs", vpnAddrs).WithField("udpAddr", addr).
|
||||||
WithField("certName", certName).
|
WithField("certName", certName).
|
||||||
WithField("fingerprint", fingerprint).
|
WithField("fingerprint", fingerprint).
|
||||||
WithField("issuer", issuer).
|
WithField("issuer", issuer).
|
||||||
@ -320,9 +399,9 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet
|
|||||||
f.l.Error("Handshake send failed: both addr and via are nil.")
|
f.l.Error("Handshake send failed: both addr and via are nil.")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
hostinfo.relayState.InsertRelayTo(via.relayHI.vpnIp)
|
hostinfo.relayState.InsertRelayTo(via.relayHI.vpnAddrs[0])
|
||||||
f.SendVia(via.relayHI, via.relay, msg, make([]byte, 12), make([]byte, mtu), false)
|
f.SendVia(via.relayHI, via.relay, msg, make([]byte, 12), make([]byte, mtu), false)
|
||||||
f.l.WithField("vpnIp", vpnIp).WithField("relay", via.relayHI.vpnIp).
|
f.l.WithField("vpnAddrs", vpnAddrs).WithField("relay", via.relayHI.vpnAddrs[0]).
|
||||||
WithField("certName", certName).
|
WithField("certName", certName).
|
||||||
WithField("fingerprint", fingerprint).
|
WithField("fingerprint", fingerprint).
|
||||||
WithField("issuer", issuer).
|
WithField("issuer", issuer).
|
||||||
@ -349,8 +428,9 @@ func ixHandshakeStage2(f *Interface, addr netip.AddrPort, via *ViaSender, hh *Ha
|
|||||||
|
|
||||||
hostinfo := hh.hostinfo
|
hostinfo := hh.hostinfo
|
||||||
if addr.IsValid() {
|
if addr.IsValid() {
|
||||||
if !f.lightHouse.GetRemoteAllowList().Allow(hostinfo.vpnIp, addr.Addr()) {
|
//TODO: this is kind of nonsense now
|
||||||
f.l.WithField("vpnIp", hostinfo.vpnIp).WithField("udpAddr", addr).Debug("lighthouse.remote_allow_list denied incoming handshake")
|
if !f.lightHouse.GetRemoteAllowList().Allow(hostinfo.vpnAddrs[0], addr.Addr()) {
|
||||||
|
f.l.WithField("vpnIp", hostinfo.vpnAddrs).WithField("udpAddr", addr).Debug("lighthouse.remote_allow_list denied incoming handshake")
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -358,7 +438,7 @@ func ixHandshakeStage2(f *Interface, addr netip.AddrPort, via *ViaSender, hh *Ha
|
|||||||
ci := hostinfo.ConnectionState
|
ci := hostinfo.ConnectionState
|
||||||
msg, eKey, dKey, err := ci.H.ReadMessage(nil, packet[header.Len:])
|
msg, eKey, dKey, err := ci.H.ReadMessage(nil, packet[header.Len:])
|
||||||
if err != nil {
|
if err != nil {
|
||||||
f.l.WithError(err).WithField("vpnIp", hostinfo.vpnIp).WithField("udpAddr", addr).
|
f.l.WithError(err).WithField("vpnAddrs", hostinfo.vpnAddrs).WithField("udpAddr", addr).
|
||||||
WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).WithField("header", h).
|
WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).WithField("header", h).
|
||||||
Error("Failed to call noise.ReadMessage")
|
Error("Failed to call noise.ReadMessage")
|
||||||
|
|
||||||
@ -367,7 +447,7 @@ func ixHandshakeStage2(f *Interface, addr netip.AddrPort, via *ViaSender, hh *Ha
|
|||||||
// near future
|
// near future
|
||||||
return false
|
return false
|
||||||
} else if dKey == nil || eKey == nil {
|
} else if dKey == nil || eKey == nil {
|
||||||
f.l.WithField("vpnIp", hostinfo.vpnIp).WithField("udpAddr", addr).
|
f.l.WithField("vpnAddrs", hostinfo.vpnAddrs).WithField("udpAddr", addr).
|
||||||
WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).
|
WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).
|
||||||
Error("Noise did not arrive at a key")
|
Error("Noise did not arrive at a key")
|
||||||
|
|
||||||
@ -379,16 +459,16 @@ func ixHandshakeStage2(f *Interface, addr netip.AddrPort, via *ViaSender, hh *Ha
|
|||||||
hs := &NebulaHandshake{}
|
hs := &NebulaHandshake{}
|
||||||
err = hs.Unmarshal(msg)
|
err = hs.Unmarshal(msg)
|
||||||
if err != nil || hs.Details == nil {
|
if err != nil || hs.Details == nil {
|
||||||
f.l.WithError(err).WithField("vpnIp", hostinfo.vpnIp).WithField("udpAddr", addr).
|
f.l.WithError(err).WithField("vpnAddrs", hostinfo.vpnAddrs).WithField("udpAddr", addr).
|
||||||
WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).Error("Failed unmarshal handshake message")
|
WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).Error("Failed unmarshal handshake message")
|
||||||
|
|
||||||
// The handshake state machine is complete, if things break now there is no chance to recover. Tear down and start again
|
// The handshake state machine is complete, if things break now there is no chance to recover. Tear down and start again
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
remoteCert, err := RecombineCertAndValidate(ci.H, hs.Details.Cert, f.pki.GetCAPool())
|
remoteCert, err := cert.RecombineAndValidate(cert.Version(hs.Details.CertVersion), hs.Details.Cert, ci.H.PeerStatic(), ci.Curve(), f.pki.GetCAPool())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
e := f.l.WithError(err).WithField("vpnIp", hostinfo.vpnIp).WithField("udpAddr", addr).
|
e := f.l.WithError(err).WithField("vpnAddrs", hostinfo.vpnAddrs).WithField("udpAddr", addr).
|
||||||
WithField("handshake", m{"stage": 2, "style": "ix_psk0"})
|
WithField("handshake", m{"stage": 2, "style": "ix_psk0"})
|
||||||
|
|
||||||
if f.l.Level > logrus.DebugLevel {
|
if f.l.Level > logrus.DebugLevel {
|
||||||
@ -413,7 +493,7 @@ func ixHandshakeStage2(f *Interface, addr netip.AddrPort, via *ViaSender, hh *Ha
|
|||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
vpnIp := remoteCert.Certificate.Networks()[0].Addr().Unmap()
|
vpnNetworks := remoteCert.Certificate.Networks()
|
||||||
certName := remoteCert.Certificate.Name()
|
certName := remoteCert.Certificate.Name()
|
||||||
fingerprint := remoteCert.Fingerprint
|
fingerprint := remoteCert.Fingerprint
|
||||||
issuer := remoteCert.Certificate.Issuer()
|
issuer := remoteCert.Certificate.Issuer()
|
||||||
@ -430,12 +510,17 @@ func ixHandshakeStage2(f *Interface, addr netip.AddrPort, via *ViaSender, hh *Ha
|
|||||||
if addr.IsValid() {
|
if addr.IsValid() {
|
||||||
hostinfo.SetRemote(addr)
|
hostinfo.SetRemote(addr)
|
||||||
} else {
|
} else {
|
||||||
hostinfo.relayState.InsertRelayTo(via.relayHI.vpnIp)
|
hostinfo.relayState.InsertRelayTo(via.relayHI.vpnAddrs[0])
|
||||||
|
}
|
||||||
|
|
||||||
|
vpnAddrs := make([]netip.Addr, len(vpnNetworks))
|
||||||
|
for i, n := range vpnNetworks {
|
||||||
|
vpnAddrs[i] = n.Addr()
|
||||||
}
|
}
|
||||||
|
|
||||||
// Ensure the right host responded
|
// Ensure the right host responded
|
||||||
if vpnIp != hostinfo.vpnIp {
|
if !slices.Contains(vpnAddrs, hostinfo.vpnAddrs[0]) {
|
||||||
f.l.WithField("intendedVpnIp", hostinfo.vpnIp).WithField("haveVpnIp", vpnIp).
|
f.l.WithField("intendedVpnAddrs", hostinfo.vpnAddrs).WithField("haveVpnNetworks", vpnNetworks).
|
||||||
WithField("udpAddr", addr).WithField("certName", certName).
|
WithField("udpAddr", addr).WithField("certName", certName).
|
||||||
WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).
|
WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).
|
||||||
Info("Incorrect host responded to handshake")
|
Info("Incorrect host responded to handshake")
|
||||||
@ -444,14 +529,14 @@ func ixHandshakeStage2(f *Interface, addr netip.AddrPort, via *ViaSender, hh *Ha
|
|||||||
f.handshakeManager.DeleteHostInfo(hostinfo)
|
f.handshakeManager.DeleteHostInfo(hostinfo)
|
||||||
|
|
||||||
// Create a new hostinfo/handshake for the intended vpn ip
|
// Create a new hostinfo/handshake for the intended vpn ip
|
||||||
f.handshakeManager.StartHandshake(hostinfo.vpnIp, func(newHH *HandshakeHostInfo) {
|
f.handshakeManager.StartHandshake(hostinfo.vpnAddrs[0], func(newHH *HandshakeHostInfo) {
|
||||||
//TODO: this doesnt know if its being added or is being used for caching a packet
|
//TODO: this doesnt know if its being added or is being used for caching a packet
|
||||||
// Block the current used address
|
// Block the current used address
|
||||||
newHH.hostinfo.remotes = hostinfo.remotes
|
newHH.hostinfo.remotes = hostinfo.remotes
|
||||||
newHH.hostinfo.remotes.BlockRemote(addr)
|
newHH.hostinfo.remotes.BlockRemote(addr)
|
||||||
|
|
||||||
f.l.WithField("blockedUdpAddrs", newHH.hostinfo.remotes.CopyBlockedRemotes()).
|
f.l.WithField("blockedUdpAddrs", newHH.hostinfo.remotes.CopyBlockedRemotes()).
|
||||||
WithField("vpnIp", newHH.hostinfo.vpnIp).
|
WithField("vpnNetworks", vpnNetworks).
|
||||||
WithField("remotes", newHH.hostinfo.remotes.CopyAddrs(f.hostMap.GetPreferredRanges())).
|
WithField("remotes", newHH.hostinfo.remotes.CopyAddrs(f.hostMap.GetPreferredRanges())).
|
||||||
Info("Blocked addresses for handshakes")
|
Info("Blocked addresses for handshakes")
|
||||||
|
|
||||||
@ -459,11 +544,8 @@ func ixHandshakeStage2(f *Interface, addr netip.AddrPort, via *ViaSender, hh *Ha
|
|||||||
newHH.packetStore = hh.packetStore
|
newHH.packetStore = hh.packetStore
|
||||||
hh.packetStore = []*cachedPacket{}
|
hh.packetStore = []*cachedPacket{}
|
||||||
|
|
||||||
// Get the correct remote list for the host we did handshake with
|
// Finally, put the correct vpn addrs in the host info, tell them to close the tunnel, and return true to tear down
|
||||||
hostinfo.SetRemote(addr)
|
hostinfo.vpnAddrs = vpnAddrs
|
||||||
hostinfo.remotes = f.lightHouse.QueryCache(vpnIp)
|
|
||||||
// Finally, put the correct vpn ip in the host info, tell them to close the tunnel, and return true to tear down
|
|
||||||
hostinfo.vpnIp = vpnIp
|
|
||||||
f.sendCloseTunnel(hostinfo)
|
f.sendCloseTunnel(hostinfo)
|
||||||
})
|
})
|
||||||
|
|
||||||
@ -474,7 +556,7 @@ func ixHandshakeStage2(f *Interface, addr netip.AddrPort, via *ViaSender, hh *Ha
|
|||||||
ci.window.Update(f.l, 2)
|
ci.window.Update(f.l, 2)
|
||||||
|
|
||||||
duration := time.Since(hh.startTime).Nanoseconds()
|
duration := time.Since(hh.startTime).Nanoseconds()
|
||||||
f.l.WithField("vpnIp", vpnIp).WithField("udpAddr", addr).
|
f.l.WithField("vpnNetworks", vpnNetworks).WithField("udpAddr", addr).
|
||||||
WithField("certName", certName).
|
WithField("certName", certName).
|
||||||
WithField("fingerprint", fingerprint).
|
WithField("fingerprint", fingerprint).
|
||||||
WithField("issuer", issuer).
|
WithField("issuer", issuer).
|
||||||
@ -485,7 +567,7 @@ func ixHandshakeStage2(f *Interface, addr netip.AddrPort, via *ViaSender, hh *Ha
|
|||||||
Info("Handshake message received")
|
Info("Handshake message received")
|
||||||
|
|
||||||
// Build up the radix for the firewall if we have subnets in the cert
|
// Build up the radix for the firewall if we have subnets in the cert
|
||||||
hostinfo.CreateRemoteCIDR(remoteCert.Certificate)
|
hostinfo.buildNetworks(remoteCert.Certificate)
|
||||||
|
|
||||||
// Complete our handshake and update metrics, this will replace any existing tunnels for this vpnIp
|
// Complete our handshake and update metrics, this will replace any existing tunnels for this vpnIp
|
||||||
f.handshakeManager.Complete(hostinfo, f)
|
f.handshakeManager.Complete(hostinfo, f)
|
||||||
|
|||||||
@ -13,6 +13,7 @@ import (
|
|||||||
|
|
||||||
"github.com/rcrowley/go-metrics"
|
"github.com/rcrowley/go-metrics"
|
||||||
"github.com/sirupsen/logrus"
|
"github.com/sirupsen/logrus"
|
||||||
|
"github.com/slackhq/nebula/cert"
|
||||||
"github.com/slackhq/nebula/header"
|
"github.com/slackhq/nebula/header"
|
||||||
"github.com/slackhq/nebula/udp"
|
"github.com/slackhq/nebula/udp"
|
||||||
)
|
)
|
||||||
@ -118,18 +119,18 @@ func NewHandshakeManager(l *logrus.Logger, mainHostMap *HostMap, lightHouse *Lig
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *HandshakeManager) Run(ctx context.Context) {
|
func (hm *HandshakeManager) Run(ctx context.Context) {
|
||||||
clockSource := time.NewTicker(c.config.tryInterval)
|
clockSource := time.NewTicker(hm.config.tryInterval)
|
||||||
defer clockSource.Stop()
|
defer clockSource.Stop()
|
||||||
|
|
||||||
for {
|
for {
|
||||||
select {
|
select {
|
||||||
case <-ctx.Done():
|
case <-ctx.Done():
|
||||||
return
|
return
|
||||||
case vpnIP := <-c.trigger:
|
case vpnIP := <-hm.trigger:
|
||||||
c.handleOutbound(vpnIP, true)
|
hm.handleOutbound(vpnIP, true)
|
||||||
case now := <-clockSource.C:
|
case now := <-clockSource.C:
|
||||||
c.NextOutboundHandshakeTimerTick(now)
|
hm.NextOutboundHandshakeTimerTick(now)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -159,14 +160,14 @@ func (hm *HandshakeManager) HandleIncoming(addr netip.AddrPort, via *ViaSender,
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *HandshakeManager) NextOutboundHandshakeTimerTick(now time.Time) {
|
func (hm *HandshakeManager) NextOutboundHandshakeTimerTick(now time.Time) {
|
||||||
c.OutboundHandshakeTimer.Advance(now)
|
hm.OutboundHandshakeTimer.Advance(now)
|
||||||
for {
|
for {
|
||||||
vpnIp, has := c.OutboundHandshakeTimer.Purge()
|
vpnIp, has := hm.OutboundHandshakeTimer.Purge()
|
||||||
if !has {
|
if !has {
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
c.handleOutbound(vpnIp, false)
|
hm.handleOutbound(vpnIp, false)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -208,7 +209,7 @@ func (hm *HandshakeManager) handleOutbound(vpnIp netip.Addr, lighthouseTriggered
|
|||||||
// NB ^ This comment doesn't jive. It's how the thing gets initialized.
|
// NB ^ This comment doesn't jive. It's how the thing gets initialized.
|
||||||
// It's the common path. Should it update every time, in case a future LH query/queries give us more info?
|
// It's the common path. Should it update every time, in case a future LH query/queries give us more info?
|
||||||
if hostinfo.remotes == nil {
|
if hostinfo.remotes == nil {
|
||||||
hostinfo.remotes = hm.lightHouse.QueryCache(vpnIp)
|
hostinfo.remotes = hm.lightHouse.QueryCache([]netip.Addr{vpnIp})
|
||||||
}
|
}
|
||||||
|
|
||||||
remotes := hostinfo.remotes.CopyAddrs(hm.mainHostMap.GetPreferredRanges())
|
remotes := hostinfo.remotes.CopyAddrs(hm.mainHostMap.GetPreferredRanges())
|
||||||
@ -267,11 +268,18 @@ func (hm *HandshakeManager) handleOutbound(vpnIp netip.Addr, lighthouseTriggered
|
|||||||
hostinfo.logger(hm.l).WithField("relays", hostinfo.remotes.relays).Info("Attempt to relay through hosts")
|
hostinfo.logger(hm.l).WithField("relays", hostinfo.remotes.relays).Info("Attempt to relay through hosts")
|
||||||
// Send a RelayRequest to all known Relay IP's
|
// Send a RelayRequest to all known Relay IP's
|
||||||
for _, relay := range hostinfo.remotes.relays {
|
for _, relay := range hostinfo.remotes.relays {
|
||||||
// Don't relay to myself, and don't relay through the host I'm trying to connect to
|
// Don't relay to myself
|
||||||
if relay == vpnIp || relay == hm.lightHouse.myVpnNet.Addr() {
|
if relay == vpnIp {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
relayHostInfo := hm.mainHostMap.QueryVpnIp(relay)
|
|
||||||
|
// Don't relay through the host I'm trying to connect to
|
||||||
|
_, found := hm.f.myVpnAddrsTable.Lookup(relay)
|
||||||
|
if found {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
relayHostInfo := hm.mainHostMap.QueryVpnAddr(relay)
|
||||||
if relayHostInfo == nil || !relayHostInfo.remote.IsValid() {
|
if relayHostInfo == nil || !relayHostInfo.remote.IsValid() {
|
||||||
hostinfo.logger(hm.l).WithField("relay", relay.String()).Info("Establish tunnel to relay target")
|
hostinfo.logger(hm.l).WithField("relay", relay.String()).Info("Establish tunnel to relay target")
|
||||||
hm.f.Handshake(relay)
|
hm.f.Handshake(relay)
|
||||||
@ -286,17 +294,35 @@ func (hm *HandshakeManager) handleOutbound(vpnIp netip.Addr, lighthouseTriggered
|
|||||||
case Requested:
|
case Requested:
|
||||||
hostinfo.logger(hm.l).WithField("relay", relay.String()).Info("Re-send CreateRelay request")
|
hostinfo.logger(hm.l).WithField("relay", relay.String()).Info("Re-send CreateRelay request")
|
||||||
|
|
||||||
//TODO: IPV6-WORK
|
|
||||||
myVpnIpB := hm.f.myVpnNet.Addr().As4()
|
|
||||||
theirVpnIpB := vpnIp.As4()
|
|
||||||
|
|
||||||
// Re-send the CreateRelay request, in case the previous one was lost.
|
|
||||||
m := NebulaControl{
|
m := NebulaControl{
|
||||||
Type: NebulaControl_CreateRelayRequest,
|
Type: NebulaControl_CreateRelayRequest,
|
||||||
InitiatorRelayIndex: existingRelay.LocalIndex,
|
InitiatorRelayIndex: existingRelay.LocalIndex,
|
||||||
RelayFromIp: binary.BigEndian.Uint32(myVpnIpB[:]),
|
|
||||||
RelayToIp: binary.BigEndian.Uint32(theirVpnIpB[:]),
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
switch relayHostInfo.GetCert().Certificate.Version() {
|
||||||
|
case cert.Version1:
|
||||||
|
if !hm.f.myVpnAddrs[0].Is4() {
|
||||||
|
hostinfo.logger(hm.l).Error("can not establish v1 relay with a v6 network because the relay is not running a current nebula version")
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
if !vpnIp.Is4() {
|
||||||
|
hostinfo.logger(hm.l).Error("can not establish v1 relay with a v6 remote network because the relay is not running a current nebula version")
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
b := hm.f.myVpnAddrs[0].As4()
|
||||||
|
m.OldRelayFromAddr = binary.BigEndian.Uint32(b[:])
|
||||||
|
b = vpnIp.As4()
|
||||||
|
m.OldRelayToAddr = binary.BigEndian.Uint32(b[:])
|
||||||
|
case cert.Version2:
|
||||||
|
m.RelayFromAddr = netAddrToProtoAddr(hm.f.myVpnAddrs[0])
|
||||||
|
m.RelayToAddr = netAddrToProtoAddr(vpnIp)
|
||||||
|
default:
|
||||||
|
hostinfo.logger(hm.l).Error("Unknown certificate version found while creating relay")
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
msg, err := m.Marshal()
|
msg, err := m.Marshal()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
hostinfo.logger(hm.l).
|
hostinfo.logger(hm.l).
|
||||||
@ -306,7 +332,7 @@ func (hm *HandshakeManager) handleOutbound(vpnIp netip.Addr, lighthouseTriggered
|
|||||||
// This must send over the hostinfo, not over hm.Hosts[ip]
|
// This must send over the hostinfo, not over hm.Hosts[ip]
|
||||||
hm.f.SendMessageToHostInfo(header.Control, 0, relayHostInfo, msg, make([]byte, 12), make([]byte, mtu))
|
hm.f.SendMessageToHostInfo(header.Control, 0, relayHostInfo, msg, make([]byte, 12), make([]byte, mtu))
|
||||||
hm.l.WithFields(logrus.Fields{
|
hm.l.WithFields(logrus.Fields{
|
||||||
"relayFrom": hm.f.myVpnNet.Addr(),
|
"relayFrom": hm.f.myVpnAddrs[0],
|
||||||
"relayTo": vpnIp,
|
"relayTo": vpnIp,
|
||||||
"initiatorRelayIndex": existingRelay.LocalIndex,
|
"initiatorRelayIndex": existingRelay.LocalIndex,
|
||||||
"relay": relay}).
|
"relay": relay}).
|
||||||
@ -316,7 +342,7 @@ func (hm *HandshakeManager) handleOutbound(vpnIp netip.Addr, lighthouseTriggered
|
|||||||
hostinfo.logger(hm.l).
|
hostinfo.logger(hm.l).
|
||||||
WithField("vpnIp", vpnIp).
|
WithField("vpnIp", vpnIp).
|
||||||
WithField("state", existingRelay.State).
|
WithField("state", existingRelay.State).
|
||||||
WithField("relay", relayHostInfo.vpnIp).
|
WithField("relay", relayHostInfo.vpnAddrs[0]).
|
||||||
Errorf("Relay unexpected state")
|
Errorf("Relay unexpected state")
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
@ -327,16 +353,35 @@ func (hm *HandshakeManager) handleOutbound(vpnIp netip.Addr, lighthouseTriggered
|
|||||||
hostinfo.logger(hm.l).WithField("relay", relay.String()).WithError(err).Info("Failed to add relay to hostmap")
|
hostinfo.logger(hm.l).WithField("relay", relay.String()).WithError(err).Info("Failed to add relay to hostmap")
|
||||||
}
|
}
|
||||||
|
|
||||||
//TODO: IPV6-WORK
|
|
||||||
myVpnIpB := hm.f.myVpnNet.Addr().As4()
|
|
||||||
theirVpnIpB := vpnIp.As4()
|
|
||||||
|
|
||||||
m := NebulaControl{
|
m := NebulaControl{
|
||||||
Type: NebulaControl_CreateRelayRequest,
|
Type: NebulaControl_CreateRelayRequest,
|
||||||
InitiatorRelayIndex: idx,
|
InitiatorRelayIndex: idx,
|
||||||
RelayFromIp: binary.BigEndian.Uint32(myVpnIpB[:]),
|
|
||||||
RelayToIp: binary.BigEndian.Uint32(theirVpnIpB[:]),
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
switch relayHostInfo.GetCert().Certificate.Version() {
|
||||||
|
case cert.Version1:
|
||||||
|
if !hm.f.myVpnAddrs[0].Is4() {
|
||||||
|
hostinfo.logger(hm.l).Error("can not establish v1 relay with a v6 network because the relay is not running a current nebula version")
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
if !vpnIp.Is4() {
|
||||||
|
hostinfo.logger(hm.l).Error("can not establish v1 relay with a v6 remote network because the relay is not running a current nebula version")
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
b := hm.f.myVpnAddrs[0].As4()
|
||||||
|
m.OldRelayFromAddr = binary.BigEndian.Uint32(b[:])
|
||||||
|
b = vpnIp.As4()
|
||||||
|
m.OldRelayToAddr = binary.BigEndian.Uint32(b[:])
|
||||||
|
case cert.Version2:
|
||||||
|
m.RelayFromAddr = netAddrToProtoAddr(hm.f.myVpnAddrs[0])
|
||||||
|
m.RelayToAddr = netAddrToProtoAddr(vpnIp)
|
||||||
|
default:
|
||||||
|
hostinfo.logger(hm.l).Error("Unknown certificate version found while creating relay")
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
msg, err := m.Marshal()
|
msg, err := m.Marshal()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
hostinfo.logger(hm.l).
|
hostinfo.logger(hm.l).
|
||||||
@ -345,7 +390,7 @@ func (hm *HandshakeManager) handleOutbound(vpnIp netip.Addr, lighthouseTriggered
|
|||||||
} else {
|
} else {
|
||||||
hm.f.SendMessageToHostInfo(header.Control, 0, relayHostInfo, msg, make([]byte, 12), make([]byte, mtu))
|
hm.f.SendMessageToHostInfo(header.Control, 0, relayHostInfo, msg, make([]byte, 12), make([]byte, mtu))
|
||||||
hm.l.WithFields(logrus.Fields{
|
hm.l.WithFields(logrus.Fields{
|
||||||
"relayFrom": hm.f.myVpnNet.Addr(),
|
"relayFrom": hm.f.myVpnAddrs[0],
|
||||||
"relayTo": vpnIp,
|
"relayTo": vpnIp,
|
||||||
"initiatorRelayIndex": idx,
|
"initiatorRelayIndex": idx,
|
||||||
"relay": relay}).
|
"relay": relay}).
|
||||||
@ -381,10 +426,10 @@ func (hm *HandshakeManager) GetOrHandshake(vpnIp netip.Addr, cacheCb func(*Hands
|
|||||||
}
|
}
|
||||||
|
|
||||||
// StartHandshake will ensure a handshake is currently being attempted for the provided vpn ip
|
// StartHandshake will ensure a handshake is currently being attempted for the provided vpn ip
|
||||||
func (hm *HandshakeManager) StartHandshake(vpnIp netip.Addr, cacheCb func(*HandshakeHostInfo)) *HostInfo {
|
func (hm *HandshakeManager) StartHandshake(vpnAddr netip.Addr, cacheCb func(*HandshakeHostInfo)) *HostInfo {
|
||||||
hm.Lock()
|
hm.Lock()
|
||||||
|
|
||||||
if hh, ok := hm.vpnIps[vpnIp]; ok {
|
if hh, ok := hm.vpnIps[vpnAddr]; ok {
|
||||||
// We are already trying to handshake with this vpn ip
|
// We are already trying to handshake with this vpn ip
|
||||||
if cacheCb != nil {
|
if cacheCb != nil {
|
||||||
cacheCb(hh)
|
cacheCb(hh)
|
||||||
@ -394,11 +439,11 @@ func (hm *HandshakeManager) StartHandshake(vpnIp netip.Addr, cacheCb func(*Hands
|
|||||||
}
|
}
|
||||||
|
|
||||||
hostinfo := &HostInfo{
|
hostinfo := &HostInfo{
|
||||||
vpnIp: vpnIp,
|
vpnAddrs: []netip.Addr{vpnAddr},
|
||||||
HandshakePacket: make(map[uint8][]byte, 0),
|
HandshakePacket: make(map[uint8][]byte, 0),
|
||||||
relayState: RelayState{
|
relayState: RelayState{
|
||||||
relays: map[netip.Addr]struct{}{},
|
relays: map[netip.Addr]struct{}{},
|
||||||
relayForByIp: map[netip.Addr]*Relay{},
|
relayForByAddr: map[netip.Addr]*Relay{},
|
||||||
relayForByIdx: map[uint32]*Relay{},
|
relayForByIdx: map[uint32]*Relay{},
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
@ -407,9 +452,9 @@ func (hm *HandshakeManager) StartHandshake(vpnIp netip.Addr, cacheCb func(*Hands
|
|||||||
hostinfo: hostinfo,
|
hostinfo: hostinfo,
|
||||||
startTime: time.Now(),
|
startTime: time.Now(),
|
||||||
}
|
}
|
||||||
hm.vpnIps[vpnIp] = hh
|
hm.vpnIps[vpnAddr] = hh
|
||||||
hm.metricInitiated.Inc(1)
|
hm.metricInitiated.Inc(1)
|
||||||
hm.OutboundHandshakeTimer.Add(vpnIp, hm.config.tryInterval)
|
hm.OutboundHandshakeTimer.Add(vpnAddr, hm.config.tryInterval)
|
||||||
|
|
||||||
if cacheCb != nil {
|
if cacheCb != nil {
|
||||||
cacheCb(hh)
|
cacheCb(hh)
|
||||||
@ -417,21 +462,21 @@ func (hm *HandshakeManager) StartHandshake(vpnIp netip.Addr, cacheCb func(*Hands
|
|||||||
|
|
||||||
// If this is a static host, we don't need to wait for the HostQueryReply
|
// If this is a static host, we don't need to wait for the HostQueryReply
|
||||||
// We can trigger the handshake right now
|
// We can trigger the handshake right now
|
||||||
_, doTrigger := hm.lightHouse.GetStaticHostList()[vpnIp]
|
_, doTrigger := hm.lightHouse.GetStaticHostList()[vpnAddr]
|
||||||
if !doTrigger {
|
if !doTrigger {
|
||||||
// Add any calculated remotes, and trigger early handshake if one found
|
// Add any calculated remotes, and trigger early handshake if one found
|
||||||
doTrigger = hm.lightHouse.addCalculatedRemotes(vpnIp)
|
doTrigger = hm.lightHouse.addCalculatedRemotes(vpnAddr)
|
||||||
}
|
}
|
||||||
|
|
||||||
if doTrigger {
|
if doTrigger {
|
||||||
select {
|
select {
|
||||||
case hm.trigger <- vpnIp:
|
case hm.trigger <- vpnAddr:
|
||||||
default:
|
default:
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
hm.Unlock()
|
hm.Unlock()
|
||||||
hm.lightHouse.QueryServer(vpnIp)
|
hm.lightHouse.QueryServer(vpnAddr)
|
||||||
return hostinfo
|
return hostinfo
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -452,14 +497,14 @@ var (
|
|||||||
//
|
//
|
||||||
// ErrLocalIndexCollision if we already have an entry in the main or pending
|
// ErrLocalIndexCollision if we already have an entry in the main or pending
|
||||||
// hostmap for the hostinfo.localIndexId.
|
// hostmap for the hostinfo.localIndexId.
|
||||||
func (c *HandshakeManager) CheckAndComplete(hostinfo *HostInfo, handshakePacket uint8, f *Interface) (*HostInfo, error) {
|
func (hm *HandshakeManager) CheckAndComplete(hostinfo *HostInfo, handshakePacket uint8, f *Interface) (*HostInfo, error) {
|
||||||
c.mainHostMap.Lock()
|
hm.mainHostMap.Lock()
|
||||||
defer c.mainHostMap.Unlock()
|
defer hm.mainHostMap.Unlock()
|
||||||
c.Lock()
|
hm.Lock()
|
||||||
defer c.Unlock()
|
defer hm.Unlock()
|
||||||
|
|
||||||
// Check if we already have a tunnel with this vpn ip
|
// Check if we already have a tunnel with this vpn ip
|
||||||
existingHostInfo, found := c.mainHostMap.Hosts[hostinfo.vpnIp]
|
existingHostInfo, found := hm.mainHostMap.Hosts[hostinfo.vpnAddrs[0]]
|
||||||
if found && existingHostInfo != nil {
|
if found && existingHostInfo != nil {
|
||||||
testHostInfo := existingHostInfo
|
testHostInfo := existingHostInfo
|
||||||
for testHostInfo != nil {
|
for testHostInfo != nil {
|
||||||
@ -476,31 +521,31 @@ func (c *HandshakeManager) CheckAndComplete(hostinfo *HostInfo, handshakePacket
|
|||||||
return existingHostInfo, ErrExistingHostInfo
|
return existingHostInfo, ErrExistingHostInfo
|
||||||
}
|
}
|
||||||
|
|
||||||
existingHostInfo.logger(c.l).Info("Taking new handshake")
|
existingHostInfo.logger(hm.l).Info("Taking new handshake")
|
||||||
}
|
}
|
||||||
|
|
||||||
existingIndex, found := c.mainHostMap.Indexes[hostinfo.localIndexId]
|
existingIndex, found := hm.mainHostMap.Indexes[hostinfo.localIndexId]
|
||||||
if found {
|
if found {
|
||||||
// We have a collision, but for a different hostinfo
|
// We have a collision, but for a different hostinfo
|
||||||
return existingIndex, ErrLocalIndexCollision
|
return existingIndex, ErrLocalIndexCollision
|
||||||
}
|
}
|
||||||
|
|
||||||
existingPendingIndex, found := c.indexes[hostinfo.localIndexId]
|
existingPendingIndex, found := hm.indexes[hostinfo.localIndexId]
|
||||||
if found && existingPendingIndex.hostinfo != hostinfo {
|
if found && existingPendingIndex.hostinfo != hostinfo {
|
||||||
// We have a collision, but for a different hostinfo
|
// We have a collision, but for a different hostinfo
|
||||||
return existingPendingIndex.hostinfo, ErrLocalIndexCollision
|
return existingPendingIndex.hostinfo, ErrLocalIndexCollision
|
||||||
}
|
}
|
||||||
|
|
||||||
existingRemoteIndex, found := c.mainHostMap.RemoteIndexes[hostinfo.remoteIndexId]
|
existingRemoteIndex, found := hm.mainHostMap.RemoteIndexes[hostinfo.remoteIndexId]
|
||||||
if found && existingRemoteIndex != nil && existingRemoteIndex.vpnIp != hostinfo.vpnIp {
|
if found && existingRemoteIndex != nil && existingRemoteIndex.vpnAddrs[0] != hostinfo.vpnAddrs[0] {
|
||||||
// We have a collision, but this can happen since we can't control
|
// We have a collision, but this can happen since we can't control
|
||||||
// the remote ID. Just log about the situation as a note.
|
// the remote ID. Just log about the situation as a note.
|
||||||
hostinfo.logger(c.l).
|
hostinfo.logger(hm.l).
|
||||||
WithField("remoteIndex", hostinfo.remoteIndexId).WithField("collision", existingRemoteIndex.vpnIp).
|
WithField("remoteIndex", hostinfo.remoteIndexId).WithField("collision", existingRemoteIndex.vpnAddrs).
|
||||||
Info("New host shadows existing host remoteIndex")
|
Info("New host shadows existing host remoteIndex")
|
||||||
}
|
}
|
||||||
|
|
||||||
c.mainHostMap.unlockedAddHostInfo(hostinfo, f)
|
hm.mainHostMap.unlockedAddHostInfo(hostinfo, f)
|
||||||
return existingHostInfo, nil
|
return existingHostInfo, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -518,7 +563,7 @@ func (hm *HandshakeManager) Complete(hostinfo *HostInfo, f *Interface) {
|
|||||||
// We have a collision, but this can happen since we can't control
|
// We have a collision, but this can happen since we can't control
|
||||||
// the remote ID. Just log about the situation as a note.
|
// the remote ID. Just log about the situation as a note.
|
||||||
hostinfo.logger(hm.l).
|
hostinfo.logger(hm.l).
|
||||||
WithField("remoteIndex", hostinfo.remoteIndexId).WithField("collision", existingRemoteIndex.vpnIp).
|
WithField("remoteIndex", hostinfo.remoteIndexId).WithField("collision", existingRemoteIndex.vpnAddrs).
|
||||||
Info("New host shadows existing host remoteIndex")
|
Info("New host shadows existing host remoteIndex")
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -555,31 +600,34 @@ func (hm *HandshakeManager) allocateIndex(hh *HandshakeHostInfo) error {
|
|||||||
return errors.New("failed to generate unique localIndexId")
|
return errors.New("failed to generate unique localIndexId")
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *HandshakeManager) DeleteHostInfo(hostinfo *HostInfo) {
|
func (hm *HandshakeManager) DeleteHostInfo(hostinfo *HostInfo) {
|
||||||
c.Lock()
|
hm.Lock()
|
||||||
defer c.Unlock()
|
defer hm.Unlock()
|
||||||
c.unlockedDeleteHostInfo(hostinfo)
|
hm.unlockedDeleteHostInfo(hostinfo)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *HandshakeManager) unlockedDeleteHostInfo(hostinfo *HostInfo) {
|
func (hm *HandshakeManager) unlockedDeleteHostInfo(hostinfo *HostInfo) {
|
||||||
delete(c.vpnIps, hostinfo.vpnIp)
|
for _, addr := range hostinfo.vpnAddrs {
|
||||||
if len(c.vpnIps) == 0 {
|
delete(hm.vpnIps, addr)
|
||||||
c.vpnIps = map[netip.Addr]*HandshakeHostInfo{}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
delete(c.indexes, hostinfo.localIndexId)
|
if len(hm.vpnIps) == 0 {
|
||||||
if len(c.vpnIps) == 0 {
|
hm.vpnIps = map[netip.Addr]*HandshakeHostInfo{}
|
||||||
c.indexes = map[uint32]*HandshakeHostInfo{}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if c.l.Level >= logrus.DebugLevel {
|
delete(hm.indexes, hostinfo.localIndexId)
|
||||||
c.l.WithField("hostMap", m{"mapTotalSize": len(c.vpnIps),
|
if len(hm.indexes) == 0 {
|
||||||
"vpnIp": hostinfo.vpnIp, "indexNumber": hostinfo.localIndexId, "remoteIndexNumber": hostinfo.remoteIndexId}).
|
hm.indexes = map[uint32]*HandshakeHostInfo{}
|
||||||
|
}
|
||||||
|
|
||||||
|
if hm.l.Level >= logrus.DebugLevel {
|
||||||
|
hm.l.WithField("hostMap", m{"mapTotalSize": len(hm.vpnIps),
|
||||||
|
"vpnAddrs": hostinfo.vpnAddrs, "indexNumber": hostinfo.localIndexId, "remoteIndexNumber": hostinfo.remoteIndexId}).
|
||||||
Debug("Pending hostmap hostInfo deleted")
|
Debug("Pending hostmap hostInfo deleted")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (hm *HandshakeManager) QueryVpnIp(vpnIp netip.Addr) *HostInfo {
|
func (hm *HandshakeManager) QueryVpnAddr(vpnIp netip.Addr) *HostInfo {
|
||||||
hh := hm.queryVpnIp(vpnIp)
|
hh := hm.queryVpnIp(vpnIp)
|
||||||
if hh != nil {
|
if hh != nil {
|
||||||
return hh.hostinfo
|
return hh.hostinfo
|
||||||
@ -608,37 +656,37 @@ func (hm *HandshakeManager) queryIndex(index uint32) *HandshakeHostInfo {
|
|||||||
return hm.indexes[index]
|
return hm.indexes[index]
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *HandshakeManager) GetPreferredRanges() []netip.Prefix {
|
func (hm *HandshakeManager) GetPreferredRanges() []netip.Prefix {
|
||||||
return c.mainHostMap.GetPreferredRanges()
|
return hm.mainHostMap.GetPreferredRanges()
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *HandshakeManager) ForEachVpnIp(f controlEach) {
|
func (hm *HandshakeManager) ForEachVpnAddr(f controlEach) {
|
||||||
c.RLock()
|
hm.RLock()
|
||||||
defer c.RUnlock()
|
defer hm.RUnlock()
|
||||||
|
|
||||||
for _, v := range c.vpnIps {
|
for _, v := range hm.vpnIps {
|
||||||
f(v.hostinfo)
|
f(v.hostinfo)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *HandshakeManager) ForEachIndex(f controlEach) {
|
func (hm *HandshakeManager) ForEachIndex(f controlEach) {
|
||||||
c.RLock()
|
hm.RLock()
|
||||||
defer c.RUnlock()
|
defer hm.RUnlock()
|
||||||
|
|
||||||
for _, v := range c.indexes {
|
for _, v := range hm.indexes {
|
||||||
f(v.hostinfo)
|
f(v.hostinfo)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *HandshakeManager) EmitStats() {
|
func (hm *HandshakeManager) EmitStats() {
|
||||||
c.RLock()
|
hm.RLock()
|
||||||
hostLen := len(c.vpnIps)
|
hostLen := len(hm.vpnIps)
|
||||||
indexLen := len(c.indexes)
|
indexLen := len(hm.indexes)
|
||||||
c.RUnlock()
|
hm.RUnlock()
|
||||||
|
|
||||||
metrics.GetOrRegisterGauge("hostmap.pending.hosts", nil).Update(int64(hostLen))
|
metrics.GetOrRegisterGauge("hostmap.pending.hosts", nil).Update(int64(hostLen))
|
||||||
metrics.GetOrRegisterGauge("hostmap.pending.indexes", nil).Update(int64(indexLen))
|
metrics.GetOrRegisterGauge("hostmap.pending.indexes", nil).Update(int64(indexLen))
|
||||||
c.mainHostMap.EmitStats()
|
hm.mainHostMap.EmitStats()
|
||||||
}
|
}
|
||||||
|
|
||||||
// Utility functions below
|
// Utility functions below
|
||||||
|
|||||||
@ -5,6 +5,7 @@ import (
|
|||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/slackhq/nebula/cert"
|
||||||
"github.com/slackhq/nebula/header"
|
"github.com/slackhq/nebula/header"
|
||||||
"github.com/slackhq/nebula/test"
|
"github.com/slackhq/nebula/test"
|
||||||
"github.com/slackhq/nebula/udp"
|
"github.com/slackhq/nebula/udp"
|
||||||
@ -13,21 +14,20 @@ import (
|
|||||||
|
|
||||||
func Test_NewHandshakeManagerVpnIp(t *testing.T) {
|
func Test_NewHandshakeManagerVpnIp(t *testing.T) {
|
||||||
l := test.NewLogger()
|
l := test.NewLogger()
|
||||||
vpncidr := netip.MustParsePrefix("172.1.1.1/24")
|
|
||||||
localrange := netip.MustParsePrefix("10.1.1.1/24")
|
localrange := netip.MustParsePrefix("10.1.1.1/24")
|
||||||
ip := netip.MustParseAddr("172.1.1.2")
|
ip := netip.MustParseAddr("172.1.1.2")
|
||||||
|
|
||||||
preferredRanges := []netip.Prefix{localrange}
|
preferredRanges := []netip.Prefix{localrange}
|
||||||
mainHM := newHostMap(l, vpncidr)
|
mainHM := newHostMap(l)
|
||||||
mainHM.preferredRanges.Store(&preferredRanges)
|
mainHM.preferredRanges.Store(&preferredRanges)
|
||||||
|
|
||||||
lh := newTestLighthouse()
|
lh := newTestLighthouse()
|
||||||
|
|
||||||
cs := &CertState{
|
cs := &CertState{
|
||||||
RawCertificate: []byte{},
|
defaultVersion: cert.Version1,
|
||||||
PrivateKey: []byte{},
|
privateKey: []byte{},
|
||||||
Certificate: &dummyCert{},
|
v1Cert: &dummyCert{version: cert.Version1},
|
||||||
RawCertificateNoKey: []byte{},
|
v1HandshakeBytes: []byte{},
|
||||||
}
|
}
|
||||||
|
|
||||||
blah := NewHandshakeManager(l, mainHM, lh, &udp.NoopConn{}, defaultHandshakeConfig)
|
blah := NewHandshakeManager(l, mainHM, lh, &udp.NoopConn{}, defaultHandshakeConfig)
|
||||||
@ -41,7 +41,7 @@ func Test_NewHandshakeManagerVpnIp(t *testing.T) {
|
|||||||
i2 := blah.StartHandshake(ip, nil)
|
i2 := blah.StartHandshake(ip, nil)
|
||||||
assert.Same(t, i, i2)
|
assert.Same(t, i, i2)
|
||||||
|
|
||||||
i.remotes = NewRemoteList(nil)
|
i.remotes = NewRemoteList([]netip.Addr{}, nil)
|
||||||
|
|
||||||
// Adding something to pending should not affect the main hostmap
|
// Adding something to pending should not affect the main hostmap
|
||||||
assert.Len(t, mainHM.Hosts, 0)
|
assert.Len(t, mainHM.Hosts, 0)
|
||||||
@ -79,16 +79,24 @@ func testCountTimerWheelEntries(tw *LockingTimerWheel[netip.Addr]) (c int) {
|
|||||||
type mockEncWriter struct {
|
type mockEncWriter struct {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (mw *mockEncWriter) SendMessageToVpnIp(t header.MessageType, st header.MessageSubType, vpnIp netip.Addr, p, nb, out []byte) {
|
func (mw *mockEncWriter) SendMessageToVpnAddr(_ header.MessageType, _ header.MessageSubType, _ netip.Addr, _, _, _ []byte) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
func (mw *mockEncWriter) SendVia(via *HostInfo, relay *Relay, ad, nb, out []byte, nocopy bool) {
|
func (mw *mockEncWriter) SendVia(_ *HostInfo, _ *Relay, _, _, _ []byte, _ bool) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
func (mw *mockEncWriter) SendMessageToHostInfo(t header.MessageType, st header.MessageSubType, hostinfo *HostInfo, p, nb, out []byte) {
|
func (mw *mockEncWriter) SendMessageToHostInfo(_ header.MessageType, _ header.MessageSubType, _ *HostInfo, _, _, _ []byte) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
func (mw *mockEncWriter) Handshake(vpnIP netip.Addr) {}
|
func (mw *mockEncWriter) Handshake(_ netip.Addr) {}
|
||||||
|
|
||||||
|
func (mw *mockEncWriter) GetHostInfo(_ netip.Addr) *HostInfo {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (mw *mockEncWriter) GetCertState() *CertState {
|
||||||
|
return &CertState{defaultVersion: cert.Version2}
|
||||||
|
}
|
||||||
|
|||||||
171
hostmap.go
171
hostmap.go
@ -48,7 +48,7 @@ type Relay struct {
|
|||||||
State int
|
State int
|
||||||
LocalIndex uint32
|
LocalIndex uint32
|
||||||
RemoteIndex uint32
|
RemoteIndex uint32
|
||||||
PeerIp netip.Addr
|
PeerAddr netip.Addr
|
||||||
}
|
}
|
||||||
|
|
||||||
type HostMap struct {
|
type HostMap struct {
|
||||||
@ -58,7 +58,6 @@ type HostMap struct {
|
|||||||
RemoteIndexes map[uint32]*HostInfo
|
RemoteIndexes map[uint32]*HostInfo
|
||||||
Hosts map[netip.Addr]*HostInfo
|
Hosts map[netip.Addr]*HostInfo
|
||||||
preferredRanges atomic.Pointer[[]netip.Prefix]
|
preferredRanges atomic.Pointer[[]netip.Prefix]
|
||||||
vpnCIDR netip.Prefix
|
|
||||||
l *logrus.Logger
|
l *logrus.Logger
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -68,8 +67,8 @@ type HostMap struct {
|
|||||||
type RelayState struct {
|
type RelayState struct {
|
||||||
sync.RWMutex
|
sync.RWMutex
|
||||||
|
|
||||||
relays map[netip.Addr]struct{} // Set of VpnIp's of Hosts to use as relays to access this peer
|
relays map[netip.Addr]struct{} // Set of vpnAddr's of Hosts to use as relays to access this peer
|
||||||
relayForByIp map[netip.Addr]*Relay // Maps VpnIps of peers for which this HostInfo is a relay to some Relay info
|
relayForByAddr map[netip.Addr]*Relay // Maps vpnAddr of peers for which this HostInfo is a relay to some Relay info
|
||||||
relayForByIdx map[uint32]*Relay // Maps a local index to some Relay info
|
relayForByIdx map[uint32]*Relay // Maps a local index to some Relay info
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -89,10 +88,10 @@ func (rs *RelayState) CopyAllRelayFor() []*Relay {
|
|||||||
return ret
|
return ret
|
||||||
}
|
}
|
||||||
|
|
||||||
func (rs *RelayState) GetRelayForByIp(ip netip.Addr) (*Relay, bool) {
|
func (rs *RelayState) GetRelayForByAddr(addr netip.Addr) (*Relay, bool) {
|
||||||
rs.RLock()
|
rs.RLock()
|
||||||
defer rs.RUnlock()
|
defer rs.RUnlock()
|
||||||
r, ok := rs.relayForByIp[ip]
|
r, ok := rs.relayForByAddr[addr]
|
||||||
return r, ok
|
return r, ok
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -115,8 +114,8 @@ func (rs *RelayState) CopyRelayIps() []netip.Addr {
|
|||||||
func (rs *RelayState) CopyRelayForIps() []netip.Addr {
|
func (rs *RelayState) CopyRelayForIps() []netip.Addr {
|
||||||
rs.RLock()
|
rs.RLock()
|
||||||
defer rs.RUnlock()
|
defer rs.RUnlock()
|
||||||
currentRelays := make([]netip.Addr, 0, len(rs.relayForByIp))
|
currentRelays := make([]netip.Addr, 0, len(rs.relayForByAddr))
|
||||||
for relayIp := range rs.relayForByIp {
|
for relayIp := range rs.relayForByAddr {
|
||||||
currentRelays = append(currentRelays, relayIp)
|
currentRelays = append(currentRelays, relayIp)
|
||||||
}
|
}
|
||||||
return currentRelays
|
return currentRelays
|
||||||
@ -135,7 +134,7 @@ func (rs *RelayState) CopyRelayForIdxs() []uint32 {
|
|||||||
func (rs *RelayState) CompleteRelayByIP(vpnIp netip.Addr, remoteIdx uint32) bool {
|
func (rs *RelayState) CompleteRelayByIP(vpnIp netip.Addr, remoteIdx uint32) bool {
|
||||||
rs.Lock()
|
rs.Lock()
|
||||||
defer rs.Unlock()
|
defer rs.Unlock()
|
||||||
r, ok := rs.relayForByIp[vpnIp]
|
r, ok := rs.relayForByAddr[vpnIp]
|
||||||
if !ok {
|
if !ok {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
@ -143,7 +142,7 @@ func (rs *RelayState) CompleteRelayByIP(vpnIp netip.Addr, remoteIdx uint32) bool
|
|||||||
newRelay.State = Established
|
newRelay.State = Established
|
||||||
newRelay.RemoteIndex = remoteIdx
|
newRelay.RemoteIndex = remoteIdx
|
||||||
rs.relayForByIdx[r.LocalIndex] = &newRelay
|
rs.relayForByIdx[r.LocalIndex] = &newRelay
|
||||||
rs.relayForByIp[r.PeerIp] = &newRelay
|
rs.relayForByAddr[r.PeerAddr] = &newRelay
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -158,14 +157,14 @@ func (rs *RelayState) CompleteRelayByIdx(localIdx uint32, remoteIdx uint32) (*Re
|
|||||||
newRelay.State = Established
|
newRelay.State = Established
|
||||||
newRelay.RemoteIndex = remoteIdx
|
newRelay.RemoteIndex = remoteIdx
|
||||||
rs.relayForByIdx[r.LocalIndex] = &newRelay
|
rs.relayForByIdx[r.LocalIndex] = &newRelay
|
||||||
rs.relayForByIp[r.PeerIp] = &newRelay
|
rs.relayForByAddr[r.PeerAddr] = &newRelay
|
||||||
return &newRelay, true
|
return &newRelay, true
|
||||||
}
|
}
|
||||||
|
|
||||||
func (rs *RelayState) QueryRelayForByIp(vpnIp netip.Addr) (*Relay, bool) {
|
func (rs *RelayState) QueryRelayForByIp(vpnIp netip.Addr) (*Relay, bool) {
|
||||||
rs.RLock()
|
rs.RLock()
|
||||||
defer rs.RUnlock()
|
defer rs.RUnlock()
|
||||||
r, ok := rs.relayForByIp[vpnIp]
|
r, ok := rs.relayForByAddr[vpnIp]
|
||||||
return r, ok
|
return r, ok
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -179,7 +178,7 @@ func (rs *RelayState) QueryRelayForByIdx(idx uint32) (*Relay, bool) {
|
|||||||
func (rs *RelayState) InsertRelay(ip netip.Addr, idx uint32, r *Relay) {
|
func (rs *RelayState) InsertRelay(ip netip.Addr, idx uint32, r *Relay) {
|
||||||
rs.Lock()
|
rs.Lock()
|
||||||
defer rs.Unlock()
|
defer rs.Unlock()
|
||||||
rs.relayForByIp[ip] = r
|
rs.relayForByAddr[ip] = r
|
||||||
rs.relayForByIdx[idx] = r
|
rs.relayForByIdx[idx] = r
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -190,9 +189,11 @@ type HostInfo struct {
|
|||||||
ConnectionState *ConnectionState
|
ConnectionState *ConnectionState
|
||||||
remoteIndexId uint32
|
remoteIndexId uint32
|
||||||
localIndexId uint32
|
localIndexId uint32
|
||||||
vpnIp netip.Addr
|
vpnAddrs []netip.Addr
|
||||||
recvError atomic.Uint32
|
recvError atomic.Uint32
|
||||||
remoteCidr *bart.Table[struct{}]
|
|
||||||
|
// networks are both all vpn and unsafe networks assigned to this host
|
||||||
|
networks *bart.Table[struct{}]
|
||||||
relayState RelayState
|
relayState RelayState
|
||||||
|
|
||||||
// HandshakePacket records the packets used to create this hostinfo
|
// HandshakePacket records the packets used to create this hostinfo
|
||||||
@ -241,28 +242,26 @@ type cachedPacketMetrics struct {
|
|||||||
dropped metrics.Counter
|
dropped metrics.Counter
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewHostMapFromConfig(l *logrus.Logger, vpnCIDR netip.Prefix, c *config.C) *HostMap {
|
func NewHostMapFromConfig(l *logrus.Logger, c *config.C) *HostMap {
|
||||||
hm := newHostMap(l, vpnCIDR)
|
hm := newHostMap(l)
|
||||||
|
|
||||||
hm.reload(c, true)
|
hm.reload(c, true)
|
||||||
c.RegisterReloadCallback(func(c *config.C) {
|
c.RegisterReloadCallback(func(c *config.C) {
|
||||||
hm.reload(c, false)
|
hm.reload(c, false)
|
||||||
})
|
})
|
||||||
|
|
||||||
l.WithField("network", hm.vpnCIDR.String()).
|
l.WithField("preferredRanges", hm.GetPreferredRanges()).
|
||||||
WithField("preferredRanges", hm.GetPreferredRanges()).
|
|
||||||
Info("Main HostMap created")
|
Info("Main HostMap created")
|
||||||
|
|
||||||
return hm
|
return hm
|
||||||
}
|
}
|
||||||
|
|
||||||
func newHostMap(l *logrus.Logger, vpnCIDR netip.Prefix) *HostMap {
|
func newHostMap(l *logrus.Logger) *HostMap {
|
||||||
return &HostMap{
|
return &HostMap{
|
||||||
Indexes: map[uint32]*HostInfo{},
|
Indexes: map[uint32]*HostInfo{},
|
||||||
Relays: map[uint32]*HostInfo{},
|
Relays: map[uint32]*HostInfo{},
|
||||||
RemoteIndexes: map[uint32]*HostInfo{},
|
RemoteIndexes: map[uint32]*HostInfo{},
|
||||||
Hosts: map[netip.Addr]*HostInfo{},
|
Hosts: map[netip.Addr]*HostInfo{},
|
||||||
vpnCIDR: vpnCIDR,
|
|
||||||
l: l,
|
l: l,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -305,17 +304,6 @@ func (hm *HostMap) EmitStats() {
|
|||||||
metrics.GetOrRegisterGauge("hostmap.main.relayIndexes", nil).Update(int64(relaysLen))
|
metrics.GetOrRegisterGauge("hostmap.main.relayIndexes", nil).Update(int64(relaysLen))
|
||||||
}
|
}
|
||||||
|
|
||||||
func (hm *HostMap) RemoveRelay(localIdx uint32) {
|
|
||||||
hm.Lock()
|
|
||||||
_, ok := hm.Relays[localIdx]
|
|
||||||
if !ok {
|
|
||||||
hm.Unlock()
|
|
||||||
return
|
|
||||||
}
|
|
||||||
delete(hm.Relays, localIdx)
|
|
||||||
hm.Unlock()
|
|
||||||
}
|
|
||||||
|
|
||||||
// DeleteHostInfo will fully unlink the hostinfo and return true if it was the final hostinfo for this vpn ip
|
// DeleteHostInfo will fully unlink the hostinfo and return true if it was the final hostinfo for this vpn ip
|
||||||
func (hm *HostMap) DeleteHostInfo(hostinfo *HostInfo) bool {
|
func (hm *HostMap) DeleteHostInfo(hostinfo *HostInfo) bool {
|
||||||
// Delete the host itself, ensuring it's not modified anymore
|
// Delete the host itself, ensuring it's not modified anymore
|
||||||
@ -335,7 +323,9 @@ func (hm *HostMap) MakePrimary(hostinfo *HostInfo) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (hm *HostMap) unlockedMakePrimary(hostinfo *HostInfo) {
|
func (hm *HostMap) unlockedMakePrimary(hostinfo *HostInfo) {
|
||||||
oldHostinfo := hm.Hosts[hostinfo.vpnIp]
|
//TODO: we may need to promote follow on hostinfos from these vpnAddrs as well since their oldHostinfo might not be the same as this one
|
||||||
|
// this really looks like an ideal spot for memory leaks
|
||||||
|
oldHostinfo := hm.Hosts[hostinfo.vpnAddrs[0]]
|
||||||
if oldHostinfo == hostinfo {
|
if oldHostinfo == hostinfo {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@ -348,7 +338,7 @@ func (hm *HostMap) unlockedMakePrimary(hostinfo *HostInfo) {
|
|||||||
hostinfo.next.prev = hostinfo.prev
|
hostinfo.next.prev = hostinfo.prev
|
||||||
}
|
}
|
||||||
|
|
||||||
hm.Hosts[hostinfo.vpnIp] = hostinfo
|
hm.Hosts[hostinfo.vpnAddrs[0]] = hostinfo
|
||||||
|
|
||||||
if oldHostinfo == nil {
|
if oldHostinfo == nil {
|
||||||
return
|
return
|
||||||
@ -360,23 +350,35 @@ func (hm *HostMap) unlockedMakePrimary(hostinfo *HostInfo) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (hm *HostMap) unlockedDeleteHostInfo(hostinfo *HostInfo) {
|
func (hm *HostMap) unlockedDeleteHostInfo(hostinfo *HostInfo) {
|
||||||
primary, ok := hm.Hosts[hostinfo.vpnIp]
|
for _, addr := range hostinfo.vpnAddrs {
|
||||||
|
h := hm.Hosts[addr]
|
||||||
|
for h != nil {
|
||||||
|
if h == hostinfo {
|
||||||
|
hm.unlockedInnerDeleteHostInfo(h, addr)
|
||||||
|
}
|
||||||
|
h = h.next
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (hm *HostMap) unlockedInnerDeleteHostInfo(hostinfo *HostInfo, addr netip.Addr) {
|
||||||
|
primary, ok := hm.Hosts[addr]
|
||||||
if ok && primary == hostinfo {
|
if ok && primary == hostinfo {
|
||||||
// The vpnIp pointer points to the same hostinfo as the local index id, we can remove it
|
// The vpn addr pointer points to the same hostinfo as the local index id, we can remove it
|
||||||
delete(hm.Hosts, hostinfo.vpnIp)
|
delete(hm.Hosts, addr)
|
||||||
if len(hm.Hosts) == 0 {
|
if len(hm.Hosts) == 0 {
|
||||||
hm.Hosts = map[netip.Addr]*HostInfo{}
|
hm.Hosts = map[netip.Addr]*HostInfo{}
|
||||||
}
|
}
|
||||||
|
|
||||||
if hostinfo.next != nil {
|
if hostinfo.next != nil {
|
||||||
// We had more than 1 hostinfo at this vpnip, promote the next in the list to primary
|
// We had more than 1 hostinfo at this vpn addr, promote the next in the list to primary
|
||||||
hm.Hosts[hostinfo.vpnIp] = hostinfo.next
|
hm.Hosts[addr] = hostinfo.next
|
||||||
// It is primary, there is no previous hostinfo now
|
// It is primary, there is no previous hostinfo now
|
||||||
hostinfo.next.prev = nil
|
hostinfo.next.prev = nil
|
||||||
}
|
}
|
||||||
|
|
||||||
} else {
|
} else {
|
||||||
// Relink if we were in the middle of multiple hostinfos for this vpn ip
|
// Relink if we were in the middle of multiple hostinfos for this vpn addr
|
||||||
if hostinfo.prev != nil {
|
if hostinfo.prev != nil {
|
||||||
hostinfo.prev.next = hostinfo.next
|
hostinfo.prev.next = hostinfo.next
|
||||||
}
|
}
|
||||||
@ -406,7 +408,7 @@ func (hm *HostMap) unlockedDeleteHostInfo(hostinfo *HostInfo) {
|
|||||||
|
|
||||||
if hm.l.Level >= logrus.DebugLevel {
|
if hm.l.Level >= logrus.DebugLevel {
|
||||||
hm.l.WithField("hostMap", m{"mapTotalSize": len(hm.Hosts),
|
hm.l.WithField("hostMap", m{"mapTotalSize": len(hm.Hosts),
|
||||||
"vpnIp": hostinfo.vpnIp, "indexNumber": hostinfo.localIndexId, "remoteIndexNumber": hostinfo.remoteIndexId}).
|
"vpnAddrs": hostinfo.vpnAddrs, "indexNumber": hostinfo.localIndexId, "remoteIndexNumber": hostinfo.remoteIndexId}).
|
||||||
Debug("Hostmap hostInfo deleted")
|
Debug("Hostmap hostInfo deleted")
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -448,11 +450,11 @@ func (hm *HostMap) QueryReverseIndex(index uint32) *HostInfo {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (hm *HostMap) QueryVpnIp(vpnIp netip.Addr) *HostInfo {
|
func (hm *HostMap) QueryVpnAddr(vpnIp netip.Addr) *HostInfo {
|
||||||
return hm.queryVpnIp(vpnIp, nil)
|
return hm.queryVpnAddr(vpnIp, nil)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (hm *HostMap) QueryVpnIpRelayFor(targetIp, relayHostIp netip.Addr) (*HostInfo, *Relay, error) {
|
func (hm *HostMap) QueryVpnAddrsRelayFor(targetIps []netip.Addr, relayHostIp netip.Addr) (*HostInfo, *Relay, error) {
|
||||||
hm.RLock()
|
hm.RLock()
|
||||||
defer hm.RUnlock()
|
defer hm.RUnlock()
|
||||||
|
|
||||||
@ -460,17 +462,21 @@ func (hm *HostMap) QueryVpnIpRelayFor(targetIp, relayHostIp netip.Addr) (*HostIn
|
|||||||
if !ok {
|
if !ok {
|
||||||
return nil, nil, errors.New("unable to find host")
|
return nil, nil, errors.New("unable to find host")
|
||||||
}
|
}
|
||||||
|
|
||||||
for h != nil {
|
for h != nil {
|
||||||
|
for _, targetIp := range targetIps {
|
||||||
r, ok := h.relayState.QueryRelayForByIp(targetIp)
|
r, ok := h.relayState.QueryRelayForByIp(targetIp)
|
||||||
if ok && r.State == Established {
|
if ok && r.State == Established {
|
||||||
return h, r, nil
|
return h, r, nil
|
||||||
}
|
}
|
||||||
|
}
|
||||||
h = h.next
|
h = h.next
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil, nil, errors.New("unable to find host with relay")
|
return nil, nil, errors.New("unable to find host with relay")
|
||||||
}
|
}
|
||||||
|
|
||||||
func (hm *HostMap) queryVpnIp(vpnIp netip.Addr, promoteIfce *Interface) *HostInfo {
|
func (hm *HostMap) queryVpnAddr(vpnIp netip.Addr, promoteIfce *Interface) *HostInfo {
|
||||||
hm.RLock()
|
hm.RLock()
|
||||||
if h, ok := hm.Hosts[vpnIp]; ok {
|
if h, ok := hm.Hosts[vpnIp]; ok {
|
||||||
hm.RUnlock()
|
hm.RUnlock()
|
||||||
@ -491,25 +497,30 @@ func (hm *HostMap) queryVpnIp(vpnIp netip.Addr, promoteIfce *Interface) *HostInf
|
|||||||
func (hm *HostMap) unlockedAddHostInfo(hostinfo *HostInfo, f *Interface) {
|
func (hm *HostMap) unlockedAddHostInfo(hostinfo *HostInfo, f *Interface) {
|
||||||
if f.serveDns {
|
if f.serveDns {
|
||||||
remoteCert := hostinfo.ConnectionState.peerCert
|
remoteCert := hostinfo.ConnectionState.peerCert
|
||||||
dnsR.Add(remoteCert.Certificate.Name()+".", remoteCert.Certificate.Networks()[0].Addr().String())
|
dnsR.Add(remoteCert.Certificate.Name()+".", hostinfo.vpnAddrs)
|
||||||
}
|
}
|
||||||
|
for _, addr := range hostinfo.vpnAddrs {
|
||||||
existing := hm.Hosts[hostinfo.vpnIp]
|
hm.unlockedInnerAddHostInfo(addr, hostinfo, f)
|
||||||
hm.Hosts[hostinfo.vpnIp] = hostinfo
|
|
||||||
|
|
||||||
if existing != nil {
|
|
||||||
hostinfo.next = existing
|
|
||||||
existing.prev = hostinfo
|
|
||||||
}
|
}
|
||||||
|
|
||||||
hm.Indexes[hostinfo.localIndexId] = hostinfo
|
hm.Indexes[hostinfo.localIndexId] = hostinfo
|
||||||
hm.RemoteIndexes[hostinfo.remoteIndexId] = hostinfo
|
hm.RemoteIndexes[hostinfo.remoteIndexId] = hostinfo
|
||||||
|
|
||||||
if hm.l.Level >= logrus.DebugLevel {
|
if hm.l.Level >= logrus.DebugLevel {
|
||||||
hm.l.WithField("hostMap", m{"vpnIp": hostinfo.vpnIp, "mapTotalSize": len(hm.Hosts),
|
hm.l.WithField("hostMap", m{"vpnAddrs": hostinfo.vpnAddrs, "mapTotalSize": len(hm.Hosts),
|
||||||
"hostinfo": m{"existing": true, "localIndexId": hostinfo.localIndexId, "hostId": hostinfo.vpnIp}}).
|
"hostinfo": m{"existing": true, "localIndexId": hostinfo.localIndexId, "vpnAddrs": hostinfo.vpnAddrs}}).
|
||||||
Debug("Hostmap vpnIp added")
|
Debug("Hostmap vpnIp added")
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (hm *HostMap) unlockedInnerAddHostInfo(vpnAddr netip.Addr, hostinfo *HostInfo, f *Interface) {
|
||||||
|
existing := hm.Hosts[vpnAddr]
|
||||||
|
hm.Hosts[vpnAddr] = hostinfo
|
||||||
|
|
||||||
|
if existing != nil && existing != hostinfo {
|
||||||
|
hostinfo.next = existing
|
||||||
|
existing.prev = hostinfo
|
||||||
|
}
|
||||||
|
|
||||||
i := 1
|
i := 1
|
||||||
check := hostinfo
|
check := hostinfo
|
||||||
@ -527,7 +538,7 @@ func (hm *HostMap) GetPreferredRanges() []netip.Prefix {
|
|||||||
return *hm.preferredRanges.Load()
|
return *hm.preferredRanges.Load()
|
||||||
}
|
}
|
||||||
|
|
||||||
func (hm *HostMap) ForEachVpnIp(f controlEach) {
|
func (hm *HostMap) ForEachVpnAddr(f controlEach) {
|
||||||
hm.RLock()
|
hm.RLock()
|
||||||
defer hm.RUnlock()
|
defer hm.RUnlock()
|
||||||
|
|
||||||
@ -581,7 +592,7 @@ func (i *HostInfo) TryPromoteBest(preferredRanges []netip.Prefix, ifce *Interfac
|
|||||||
}
|
}
|
||||||
|
|
||||||
i.nextLHQuery.Store(now + ifce.reQueryWait.Load())
|
i.nextLHQuery.Store(now + ifce.reQueryWait.Load())
|
||||||
ifce.lightHouse.QueryServer(i.vpnIp)
|
ifce.lightHouse.QueryServer(i.vpnAddrs[0])
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -596,7 +607,7 @@ func (i *HostInfo) SetRemote(remote netip.AddrPort) {
|
|||||||
// We copy here because we likely got this remote from a source that reuses the object
|
// We copy here because we likely got this remote from a source that reuses the object
|
||||||
if i.remote != remote {
|
if i.remote != remote {
|
||||||
i.remote = remote
|
i.remote = remote
|
||||||
i.remotes.LearnRemote(i.vpnIp, remote)
|
i.remotes.LearnRemote(i.vpnAddrs[0], remote)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -647,21 +658,20 @@ func (i *HostInfo) RecvErrorExceeded() bool {
|
|||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
func (i *HostInfo) CreateRemoteCIDR(c cert.Certificate) {
|
func (i *HostInfo) buildNetworks(c cert.Certificate) {
|
||||||
if len(c.Networks()) == 1 && len(c.UnsafeNetworks()) == 0 {
|
if len(c.Networks()) == 1 && len(c.UnsafeNetworks()) == 0 {
|
||||||
// Simple case, no CIDRTree needed
|
// Simple case, no CIDRTree needed
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
remoteCidr := new(bart.Table[struct{}])
|
i.networks = new(bart.Table[struct{}])
|
||||||
for _, network := range c.Networks() {
|
for _, network := range c.Networks() {
|
||||||
remoteCidr.Insert(network, struct{}{})
|
i.networks.Insert(network, struct{}{})
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, network := range c.UnsafeNetworks() {
|
for _, network := range c.UnsafeNetworks() {
|
||||||
remoteCidr.Insert(network, struct{}{})
|
i.networks.Insert(network, struct{}{})
|
||||||
}
|
}
|
||||||
i.remoteCidr = remoteCidr
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (i *HostInfo) logger(l *logrus.Logger) *logrus.Entry {
|
func (i *HostInfo) logger(l *logrus.Logger) *logrus.Entry {
|
||||||
@ -669,7 +679,7 @@ func (i *HostInfo) logger(l *logrus.Logger) *logrus.Entry {
|
|||||||
return logrus.NewEntry(l)
|
return logrus.NewEntry(l)
|
||||||
}
|
}
|
||||||
|
|
||||||
li := l.WithField("vpnIp", i.vpnIp).
|
li := l.WithField("vpnAddrs", i.vpnAddrs).
|
||||||
WithField("localIndex", i.localIndexId).
|
WithField("localIndex", i.localIndexId).
|
||||||
WithField("remoteIndex", i.remoteIndexId)
|
WithField("remoteIndex", i.remoteIndexId)
|
||||||
|
|
||||||
@ -684,9 +694,9 @@ func (i *HostInfo) logger(l *logrus.Logger) *logrus.Entry {
|
|||||||
|
|
||||||
// Utility functions
|
// Utility functions
|
||||||
|
|
||||||
func localIps(l *logrus.Logger, allowList *LocalAllowList) []netip.Addr {
|
func localAddrs(l *logrus.Logger, allowList *LocalAllowList) []netip.Addr {
|
||||||
//FIXME: This function is pretty garbage
|
//FIXME: This function is pretty garbage
|
||||||
var ips []netip.Addr
|
var finalAddrs []netip.Addr
|
||||||
ifaces, _ := net.Interfaces()
|
ifaces, _ := net.Interfaces()
|
||||||
for _, i := range ifaces {
|
for _, i := range ifaces {
|
||||||
allow := allowList.AllowName(i.Name)
|
allow := allowList.AllowName(i.Name)
|
||||||
@ -698,39 +708,38 @@ func localIps(l *logrus.Logger, allowList *LocalAllowList) []netip.Addr {
|
|||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
addrs, _ := i.Addrs()
|
addrs, _ := i.Addrs()
|
||||||
for _, addr := range addrs {
|
for _, rawAddr := range addrs {
|
||||||
var ip net.IP
|
var addr netip.Addr
|
||||||
switch v := addr.(type) {
|
switch v := rawAddr.(type) {
|
||||||
case *net.IPNet:
|
case *net.IPNet:
|
||||||
//continue
|
//continue
|
||||||
ip = v.IP
|
addr, _ = netip.AddrFromSlice(v.IP)
|
||||||
case *net.IPAddr:
|
case *net.IPAddr:
|
||||||
ip = v.IP
|
addr, _ = netip.AddrFromSlice(v.IP)
|
||||||
}
|
}
|
||||||
|
|
||||||
nip, ok := netip.AddrFromSlice(ip)
|
if !addr.IsValid() {
|
||||||
if !ok {
|
|
||||||
if l.Level >= logrus.DebugLevel {
|
if l.Level >= logrus.DebugLevel {
|
||||||
l.WithField("localIp", ip).Debug("ip was invalid for netip")
|
l.WithField("localAddr", rawAddr).Debug("addr was invalid")
|
||||||
}
|
}
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
nip = nip.Unmap()
|
addr = addr.Unmap()
|
||||||
|
|
||||||
//TODO: Filtering out link local for now, this is probably the most correct thing
|
//TODO: Filtering out link local for now, this is probably the most correct thing
|
||||||
//TODO: Would be nice to filter out SLAAC MAC based ips as well
|
//TODO: Would be nice to filter out SLAAC MAC based ips as well
|
||||||
if nip.IsLoopback() == false && nip.IsLinkLocalUnicast() == false {
|
if addr.IsLoopback() == false && addr.IsLinkLocalUnicast() == false {
|
||||||
allow := allowList.Allow(nip)
|
isAllowed := allowList.Allow(addr)
|
||||||
if l.Level >= logrus.TraceLevel {
|
if l.Level >= logrus.TraceLevel {
|
||||||
l.WithField("localIp", nip).WithField("allow", allow).Trace("localAllowList.Allow")
|
l.WithField("localAddr", addr).WithField("allowed", isAllowed).Trace("localAllowList.Allow")
|
||||||
}
|
}
|
||||||
if !allow {
|
if !isAllowed {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
ips = append(ips, nip)
|
finalAddrs = append(finalAddrs, addr)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return ips
|
return finalAddrs
|
||||||
}
|
}
|
||||||
|
|||||||
@ -11,17 +11,14 @@ import (
|
|||||||
|
|
||||||
func TestHostMap_MakePrimary(t *testing.T) {
|
func TestHostMap_MakePrimary(t *testing.T) {
|
||||||
l := test.NewLogger()
|
l := test.NewLogger()
|
||||||
hm := newHostMap(
|
hm := newHostMap(l)
|
||||||
l,
|
|
||||||
netip.MustParsePrefix("10.0.0.1/24"),
|
|
||||||
)
|
|
||||||
|
|
||||||
f := &Interface{}
|
f := &Interface{}
|
||||||
|
|
||||||
h1 := &HostInfo{vpnIp: netip.MustParseAddr("0.0.0.1"), localIndexId: 1}
|
h1 := &HostInfo{vpnAddrs: []netip.Addr{netip.MustParseAddr("0.0.0.1")}, localIndexId: 1}
|
||||||
h2 := &HostInfo{vpnIp: netip.MustParseAddr("0.0.0.1"), localIndexId: 2}
|
h2 := &HostInfo{vpnAddrs: []netip.Addr{netip.MustParseAddr("0.0.0.1")}, localIndexId: 2}
|
||||||
h3 := &HostInfo{vpnIp: netip.MustParseAddr("0.0.0.1"), localIndexId: 3}
|
h3 := &HostInfo{vpnAddrs: []netip.Addr{netip.MustParseAddr("0.0.0.1")}, localIndexId: 3}
|
||||||
h4 := &HostInfo{vpnIp: netip.MustParseAddr("0.0.0.1"), localIndexId: 4}
|
h4 := &HostInfo{vpnAddrs: []netip.Addr{netip.MustParseAddr("0.0.0.1")}, localIndexId: 4}
|
||||||
|
|
||||||
hm.unlockedAddHostInfo(h4, f)
|
hm.unlockedAddHostInfo(h4, f)
|
||||||
hm.unlockedAddHostInfo(h3, f)
|
hm.unlockedAddHostInfo(h3, f)
|
||||||
@ -29,7 +26,7 @@ func TestHostMap_MakePrimary(t *testing.T) {
|
|||||||
hm.unlockedAddHostInfo(h1, f)
|
hm.unlockedAddHostInfo(h1, f)
|
||||||
|
|
||||||
// Make sure we go h1 -> h2 -> h3 -> h4
|
// Make sure we go h1 -> h2 -> h3 -> h4
|
||||||
prim := hm.QueryVpnIp(netip.MustParseAddr("0.0.0.1"))
|
prim := hm.QueryVpnAddr(netip.MustParseAddr("0.0.0.1"))
|
||||||
assert.Equal(t, h1.localIndexId, prim.localIndexId)
|
assert.Equal(t, h1.localIndexId, prim.localIndexId)
|
||||||
assert.Equal(t, h2.localIndexId, prim.next.localIndexId)
|
assert.Equal(t, h2.localIndexId, prim.next.localIndexId)
|
||||||
assert.Nil(t, prim.prev)
|
assert.Nil(t, prim.prev)
|
||||||
@ -44,7 +41,7 @@ func TestHostMap_MakePrimary(t *testing.T) {
|
|||||||
hm.MakePrimary(h3)
|
hm.MakePrimary(h3)
|
||||||
|
|
||||||
// Make sure we go h3 -> h1 -> h2 -> h4
|
// Make sure we go h3 -> h1 -> h2 -> h4
|
||||||
prim = hm.QueryVpnIp(netip.MustParseAddr("0.0.0.1"))
|
prim = hm.QueryVpnAddr(netip.MustParseAddr("0.0.0.1"))
|
||||||
assert.Equal(t, h3.localIndexId, prim.localIndexId)
|
assert.Equal(t, h3.localIndexId, prim.localIndexId)
|
||||||
assert.Equal(t, h1.localIndexId, prim.next.localIndexId)
|
assert.Equal(t, h1.localIndexId, prim.next.localIndexId)
|
||||||
assert.Nil(t, prim.prev)
|
assert.Nil(t, prim.prev)
|
||||||
@ -59,7 +56,7 @@ func TestHostMap_MakePrimary(t *testing.T) {
|
|||||||
hm.MakePrimary(h4)
|
hm.MakePrimary(h4)
|
||||||
|
|
||||||
// Make sure we go h4 -> h3 -> h1 -> h2
|
// Make sure we go h4 -> h3 -> h1 -> h2
|
||||||
prim = hm.QueryVpnIp(netip.MustParseAddr("0.0.0.1"))
|
prim = hm.QueryVpnAddr(netip.MustParseAddr("0.0.0.1"))
|
||||||
assert.Equal(t, h4.localIndexId, prim.localIndexId)
|
assert.Equal(t, h4.localIndexId, prim.localIndexId)
|
||||||
assert.Equal(t, h3.localIndexId, prim.next.localIndexId)
|
assert.Equal(t, h3.localIndexId, prim.next.localIndexId)
|
||||||
assert.Nil(t, prim.prev)
|
assert.Nil(t, prim.prev)
|
||||||
@ -74,7 +71,7 @@ func TestHostMap_MakePrimary(t *testing.T) {
|
|||||||
hm.MakePrimary(h4)
|
hm.MakePrimary(h4)
|
||||||
|
|
||||||
// Make sure we go h4 -> h3 -> h1 -> h2
|
// Make sure we go h4 -> h3 -> h1 -> h2
|
||||||
prim = hm.QueryVpnIp(netip.MustParseAddr("0.0.0.1"))
|
prim = hm.QueryVpnAddr(netip.MustParseAddr("0.0.0.1"))
|
||||||
assert.Equal(t, h4.localIndexId, prim.localIndexId)
|
assert.Equal(t, h4.localIndexId, prim.localIndexId)
|
||||||
assert.Equal(t, h3.localIndexId, prim.next.localIndexId)
|
assert.Equal(t, h3.localIndexId, prim.next.localIndexId)
|
||||||
assert.Nil(t, prim.prev)
|
assert.Nil(t, prim.prev)
|
||||||
@ -88,19 +85,16 @@ func TestHostMap_MakePrimary(t *testing.T) {
|
|||||||
|
|
||||||
func TestHostMap_DeleteHostInfo(t *testing.T) {
|
func TestHostMap_DeleteHostInfo(t *testing.T) {
|
||||||
l := test.NewLogger()
|
l := test.NewLogger()
|
||||||
hm := newHostMap(
|
hm := newHostMap(l)
|
||||||
l,
|
|
||||||
netip.MustParsePrefix("10.0.0.1/24"),
|
|
||||||
)
|
|
||||||
|
|
||||||
f := &Interface{}
|
f := &Interface{}
|
||||||
|
|
||||||
h1 := &HostInfo{vpnIp: netip.MustParseAddr("0.0.0.1"), localIndexId: 1}
|
h1 := &HostInfo{vpnAddrs: []netip.Addr{netip.MustParseAddr("0.0.0.1")}, localIndexId: 1}
|
||||||
h2 := &HostInfo{vpnIp: netip.MustParseAddr("0.0.0.1"), localIndexId: 2}
|
h2 := &HostInfo{vpnAddrs: []netip.Addr{netip.MustParseAddr("0.0.0.1")}, localIndexId: 2}
|
||||||
h3 := &HostInfo{vpnIp: netip.MustParseAddr("0.0.0.1"), localIndexId: 3}
|
h3 := &HostInfo{vpnAddrs: []netip.Addr{netip.MustParseAddr("0.0.0.1")}, localIndexId: 3}
|
||||||
h4 := &HostInfo{vpnIp: netip.MustParseAddr("0.0.0.1"), localIndexId: 4}
|
h4 := &HostInfo{vpnAddrs: []netip.Addr{netip.MustParseAddr("0.0.0.1")}, localIndexId: 4}
|
||||||
h5 := &HostInfo{vpnIp: netip.MustParseAddr("0.0.0.1"), localIndexId: 5}
|
h5 := &HostInfo{vpnAddrs: []netip.Addr{netip.MustParseAddr("0.0.0.1")}, localIndexId: 5}
|
||||||
h6 := &HostInfo{vpnIp: netip.MustParseAddr("0.0.0.1"), localIndexId: 6}
|
h6 := &HostInfo{vpnAddrs: []netip.Addr{netip.MustParseAddr("0.0.0.1")}, localIndexId: 6}
|
||||||
|
|
||||||
hm.unlockedAddHostInfo(h6, f)
|
hm.unlockedAddHostInfo(h6, f)
|
||||||
hm.unlockedAddHostInfo(h5, f)
|
hm.unlockedAddHostInfo(h5, f)
|
||||||
@ -116,7 +110,7 @@ func TestHostMap_DeleteHostInfo(t *testing.T) {
|
|||||||
assert.Nil(t, h)
|
assert.Nil(t, h)
|
||||||
|
|
||||||
// Make sure we go h1 -> h2 -> h3 -> h4 -> h5
|
// Make sure we go h1 -> h2 -> h3 -> h4 -> h5
|
||||||
prim := hm.QueryVpnIp(netip.MustParseAddr("0.0.0.1"))
|
prim := hm.QueryVpnAddr(netip.MustParseAddr("0.0.0.1"))
|
||||||
assert.Equal(t, h1.localIndexId, prim.localIndexId)
|
assert.Equal(t, h1.localIndexId, prim.localIndexId)
|
||||||
assert.Equal(t, h2.localIndexId, prim.next.localIndexId)
|
assert.Equal(t, h2.localIndexId, prim.next.localIndexId)
|
||||||
assert.Nil(t, prim.prev)
|
assert.Nil(t, prim.prev)
|
||||||
@ -135,7 +129,7 @@ func TestHostMap_DeleteHostInfo(t *testing.T) {
|
|||||||
assert.Nil(t, h1.next)
|
assert.Nil(t, h1.next)
|
||||||
|
|
||||||
// Make sure we go h2 -> h3 -> h4 -> h5
|
// Make sure we go h2 -> h3 -> h4 -> h5
|
||||||
prim = hm.QueryVpnIp(netip.MustParseAddr("0.0.0.1"))
|
prim = hm.QueryVpnAddr(netip.MustParseAddr("0.0.0.1"))
|
||||||
assert.Equal(t, h2.localIndexId, prim.localIndexId)
|
assert.Equal(t, h2.localIndexId, prim.localIndexId)
|
||||||
assert.Equal(t, h3.localIndexId, prim.next.localIndexId)
|
assert.Equal(t, h3.localIndexId, prim.next.localIndexId)
|
||||||
assert.Nil(t, prim.prev)
|
assert.Nil(t, prim.prev)
|
||||||
@ -153,7 +147,7 @@ func TestHostMap_DeleteHostInfo(t *testing.T) {
|
|||||||
assert.Nil(t, h3.next)
|
assert.Nil(t, h3.next)
|
||||||
|
|
||||||
// Make sure we go h2 -> h4 -> h5
|
// Make sure we go h2 -> h4 -> h5
|
||||||
prim = hm.QueryVpnIp(netip.MustParseAddr("0.0.0.1"))
|
prim = hm.QueryVpnAddr(netip.MustParseAddr("0.0.0.1"))
|
||||||
assert.Equal(t, h2.localIndexId, prim.localIndexId)
|
assert.Equal(t, h2.localIndexId, prim.localIndexId)
|
||||||
assert.Equal(t, h4.localIndexId, prim.next.localIndexId)
|
assert.Equal(t, h4.localIndexId, prim.next.localIndexId)
|
||||||
assert.Nil(t, prim.prev)
|
assert.Nil(t, prim.prev)
|
||||||
@ -169,7 +163,7 @@ func TestHostMap_DeleteHostInfo(t *testing.T) {
|
|||||||
assert.Nil(t, h5.next)
|
assert.Nil(t, h5.next)
|
||||||
|
|
||||||
// Make sure we go h2 -> h4
|
// Make sure we go h2 -> h4
|
||||||
prim = hm.QueryVpnIp(netip.MustParseAddr("0.0.0.1"))
|
prim = hm.QueryVpnAddr(netip.MustParseAddr("0.0.0.1"))
|
||||||
assert.Equal(t, h2.localIndexId, prim.localIndexId)
|
assert.Equal(t, h2.localIndexId, prim.localIndexId)
|
||||||
assert.Equal(t, h4.localIndexId, prim.next.localIndexId)
|
assert.Equal(t, h4.localIndexId, prim.next.localIndexId)
|
||||||
assert.Nil(t, prim.prev)
|
assert.Nil(t, prim.prev)
|
||||||
@ -183,7 +177,7 @@ func TestHostMap_DeleteHostInfo(t *testing.T) {
|
|||||||
assert.Nil(t, h2.next)
|
assert.Nil(t, h2.next)
|
||||||
|
|
||||||
// Make sure we only have h4
|
// Make sure we only have h4
|
||||||
prim = hm.QueryVpnIp(netip.MustParseAddr("0.0.0.1"))
|
prim = hm.QueryVpnAddr(netip.MustParseAddr("0.0.0.1"))
|
||||||
assert.Equal(t, h4.localIndexId, prim.localIndexId)
|
assert.Equal(t, h4.localIndexId, prim.localIndexId)
|
||||||
assert.Nil(t, prim.prev)
|
assert.Nil(t, prim.prev)
|
||||||
assert.Nil(t, prim.next)
|
assert.Nil(t, prim.next)
|
||||||
@ -195,7 +189,7 @@ func TestHostMap_DeleteHostInfo(t *testing.T) {
|
|||||||
assert.Nil(t, h4.next)
|
assert.Nil(t, h4.next)
|
||||||
|
|
||||||
// Make sure we have nil
|
// Make sure we have nil
|
||||||
prim = hm.QueryVpnIp(netip.MustParseAddr("0.0.0.1"))
|
prim = hm.QueryVpnAddr(netip.MustParseAddr("0.0.0.1"))
|
||||||
assert.Nil(t, prim)
|
assert.Nil(t, prim)
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -203,11 +197,7 @@ func TestHostMap_reload(t *testing.T) {
|
|||||||
l := test.NewLogger()
|
l := test.NewLogger()
|
||||||
c := config.NewC(l)
|
c := config.NewC(l)
|
||||||
|
|
||||||
hm := NewHostMapFromConfig(
|
hm := NewHostMapFromConfig(l, c)
|
||||||
l,
|
|
||||||
netip.MustParsePrefix("10.0.0.1/24"),
|
|
||||||
c,
|
|
||||||
)
|
|
||||||
|
|
||||||
toS := func(ipn []netip.Prefix) []string {
|
toS := func(ipn []netip.Prefix) []string {
|
||||||
var s []string
|
var s []string
|
||||||
|
|||||||
@ -9,8 +9,8 @@ import (
|
|||||||
"net/netip"
|
"net/netip"
|
||||||
)
|
)
|
||||||
|
|
||||||
func (i *HostInfo) GetVpnIp() netip.Addr {
|
func (i *HostInfo) GetVpnAddrs() []netip.Addr {
|
||||||
return i.vpnIp
|
return i.vpnAddrs
|
||||||
}
|
}
|
||||||
|
|
||||||
func (i *HostInfo) GetLocalIndex() uint32 {
|
func (i *HostInfo) GetLocalIndex() uint32 {
|
||||||
|
|||||||
56
inside.go
56
inside.go
@ -20,14 +20,19 @@ func (f *Interface) consumeInsidePacket(packet []byte, fwPacket *firewall.Packet
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Ignore local broadcast packets
|
// Ignore local broadcast packets
|
||||||
if f.dropLocalBroadcast && fwPacket.RemoteIP == f.myBroadcastAddr {
|
if f.dropLocalBroadcast {
|
||||||
|
_, found := f.myBroadcastAddrsTable.Lookup(fwPacket.RemoteAddr)
|
||||||
|
if found {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
if fwPacket.RemoteIP == f.myVpnNet.Addr() {
|
//TODO: seems like a huge bummer
|
||||||
|
_, found := f.myVpnAddrsTable.Lookup(fwPacket.RemoteAddr)
|
||||||
|
if found {
|
||||||
// Immediately forward packets from self to self.
|
// Immediately forward packets from self to self.
|
||||||
// This should only happen on Darwin-based and FreeBSD hosts, which
|
// This should only happen on Darwin-based and FreeBSD hosts, which
|
||||||
// routes packets from the Nebula IP to the Nebula IP through the Nebula
|
// routes packets from the Nebula addr to the Nebula addr through the Nebula
|
||||||
// TUN device.
|
// TUN device.
|
||||||
if immediatelyForwardToSelf {
|
if immediatelyForwardToSelf {
|
||||||
_, err := f.readers[q].Write(packet)
|
_, err := f.readers[q].Write(packet)
|
||||||
@ -36,25 +41,25 @@ func (f *Interface) consumeInsidePacket(packet []byte, fwPacket *firewall.Packet
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
// Otherwise, drop. On linux, we should never see these packets - Linux
|
// Otherwise, drop. On linux, we should never see these packets - Linux
|
||||||
// routes packets from the nebula IP to the nebula IP through the loopback device.
|
// routes packets from the nebula addr to the nebula addr through the loopback device.
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// Ignore multicast packets
|
// Ignore multicast packets
|
||||||
if f.dropMulticast && fwPacket.RemoteIP.IsMulticast() {
|
if f.dropMulticast && fwPacket.RemoteAddr.IsMulticast() {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
hostinfo, ready := f.getOrHandshake(fwPacket.RemoteIP, func(hh *HandshakeHostInfo) {
|
hostinfo, ready := f.getOrHandshake(fwPacket.RemoteAddr, func(hh *HandshakeHostInfo) {
|
||||||
hh.cachePacket(f.l, header.Message, 0, packet, f.sendMessageNow, f.cachedPacketMetrics)
|
hh.cachePacket(f.l, header.Message, 0, packet, f.sendMessageNow, f.cachedPacketMetrics)
|
||||||
})
|
})
|
||||||
|
|
||||||
if hostinfo == nil {
|
if hostinfo == nil {
|
||||||
f.rejectInside(packet, out, q)
|
f.rejectInside(packet, out, q)
|
||||||
if f.l.Level >= logrus.DebugLevel {
|
if f.l.Level >= logrus.DebugLevel {
|
||||||
f.l.WithField("vpnIp", fwPacket.RemoteIP).
|
f.l.WithField("vpnAddr", fwPacket.RemoteAddr).
|
||||||
WithField("fwPacket", fwPacket).
|
WithField("fwPacket", fwPacket).
|
||||||
Debugln("dropping outbound packet, vpnIp not in our CIDR or in unsafe routes")
|
Debugln("dropping outbound packet, vpnAddr not in our vpn networks or in unsafe networks")
|
||||||
}
|
}
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@ -117,21 +122,22 @@ func (f *Interface) rejectOutside(packet []byte, ci *ConnectionState, hostinfo *
|
|||||||
f.sendNoMetrics(header.Message, 0, ci, hostinfo, netip.AddrPort{}, out, nb, packet, q)
|
f.sendNoMetrics(header.Message, 0, ci, hostinfo, netip.AddrPort{}, out, nb, packet, q)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (f *Interface) Handshake(vpnIp netip.Addr) {
|
func (f *Interface) Handshake(vpnAddr netip.Addr) {
|
||||||
f.getOrHandshake(vpnIp, nil)
|
f.getOrHandshake(vpnAddr, nil)
|
||||||
}
|
}
|
||||||
|
|
||||||
// getOrHandshake returns nil if the vpnIp is not routable.
|
// getOrHandshake returns nil if the vpnAddr is not routable.
|
||||||
// If the 2nd return var is false then the hostinfo is not ready to be used in a tunnel
|
// If the 2nd return var is false then the hostinfo is not ready to be used in a tunnel
|
||||||
func (f *Interface) getOrHandshake(vpnIp netip.Addr, cacheCallback func(*HandshakeHostInfo)) (*HostInfo, bool) {
|
func (f *Interface) getOrHandshake(vpnAddr netip.Addr, cacheCallback func(*HandshakeHostInfo)) (*HostInfo, bool) {
|
||||||
if !f.myVpnNet.Contains(vpnIp) {
|
_, found := f.myVpnNetworksTable.Lookup(vpnAddr)
|
||||||
vpnIp = f.inside.RouteFor(vpnIp)
|
if !found {
|
||||||
if !vpnIp.IsValid() {
|
vpnAddr = f.inside.RouteFor(vpnAddr)
|
||||||
|
if !vpnAddr.IsValid() {
|
||||||
return nil, false
|
return nil, false
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return f.handshakeManager.GetOrHandshake(vpnIp, cacheCallback)
|
return f.handshakeManager.GetOrHandshake(vpnAddr, cacheCallback)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (f *Interface) sendMessageNow(t header.MessageType, st header.MessageSubType, hostinfo *HostInfo, p, nb, out []byte) {
|
func (f *Interface) sendMessageNow(t header.MessageType, st header.MessageSubType, hostinfo *HostInfo, p, nb, out []byte) {
|
||||||
@ -156,16 +162,16 @@ func (f *Interface) sendMessageNow(t header.MessageType, st header.MessageSubTyp
|
|||||||
f.sendNoMetrics(header.Message, st, hostinfo.ConnectionState, hostinfo, netip.AddrPort{}, p, nb, out, 0)
|
f.sendNoMetrics(header.Message, st, hostinfo.ConnectionState, hostinfo, netip.AddrPort{}, p, nb, out, 0)
|
||||||
}
|
}
|
||||||
|
|
||||||
// SendMessageToVpnIp handles real ip:port lookup and sends to the current best known address for vpnIp
|
// SendMessageToVpnAddr handles real addr:port lookup and sends to the current best known address for vpnAddr
|
||||||
func (f *Interface) SendMessageToVpnIp(t header.MessageType, st header.MessageSubType, vpnIp netip.Addr, p, nb, out []byte) {
|
func (f *Interface) SendMessageToVpnAddr(t header.MessageType, st header.MessageSubType, vpnAddr netip.Addr, p, nb, out []byte) {
|
||||||
hostInfo, ready := f.getOrHandshake(vpnIp, func(hh *HandshakeHostInfo) {
|
hostInfo, ready := f.getOrHandshake(vpnAddr, func(hh *HandshakeHostInfo) {
|
||||||
hh.cachePacket(f.l, t, st, p, f.SendMessageToHostInfo, f.cachedPacketMetrics)
|
hh.cachePacket(f.l, t, st, p, f.SendMessageToHostInfo, f.cachedPacketMetrics)
|
||||||
})
|
})
|
||||||
|
|
||||||
if hostInfo == nil {
|
if hostInfo == nil {
|
||||||
if f.l.Level >= logrus.DebugLevel {
|
if f.l.Level >= logrus.DebugLevel {
|
||||||
f.l.WithField("vpnIp", vpnIp).
|
f.l.WithField("vpnAddr", vpnAddr).
|
||||||
Debugln("dropping SendMessageToVpnIp, vpnIp not in our CIDR or in unsafe routes")
|
Debugln("dropping SendMessageToVpnAddr, vpnAddr not in our vpn networks or in unsafe routes")
|
||||||
}
|
}
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@ -285,14 +291,14 @@ func (f *Interface) sendNoMetrics(t header.MessageType, st header.MessageSubType
|
|||||||
f.connectionManager.Out(hostinfo.localIndexId)
|
f.connectionManager.Out(hostinfo.localIndexId)
|
||||||
|
|
||||||
// Query our LH if we haven't since the last time we've been rebound, this will cause the remote to punch against
|
// Query our LH if we haven't since the last time we've been rebound, this will cause the remote to punch against
|
||||||
// all our IPs and enable a faster roaming.
|
// all our addrs and enable a faster roaming.
|
||||||
if t != header.CloseTunnel && hostinfo.lastRebindCount != f.rebindCount {
|
if t != header.CloseTunnel && hostinfo.lastRebindCount != f.rebindCount {
|
||||||
//NOTE: there is an update hole if a tunnel isn't used and exactly 256 rebinds occur before the tunnel is
|
//NOTE: there is an update hole if a tunnel isn't used and exactly 256 rebinds occur before the tunnel is
|
||||||
// finally used again. This tunnel would eventually be torn down and recreated if this action didn't help.
|
// finally used again. This tunnel would eventually be torn down and recreated if this action didn't help.
|
||||||
f.lightHouse.QueryServer(hostinfo.vpnIp)
|
f.lightHouse.QueryServer(hostinfo.vpnAddrs[0])
|
||||||
hostinfo.lastRebindCount = f.rebindCount
|
hostinfo.lastRebindCount = f.rebindCount
|
||||||
if f.l.Level >= logrus.DebugLevel {
|
if f.l.Level >= logrus.DebugLevel {
|
||||||
f.l.WithField("vpnIp", hostinfo.vpnIp).Debug("Lighthouse update triggered for punch due to rebind counter")
|
f.l.WithField("vpnAddrs", hostinfo.vpnAddrs).Debug("Lighthouse update triggered for punch due to rebind counter")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -324,7 +330,7 @@ func (f *Interface) sendNoMetrics(t header.MessageType, st header.MessageSubType
|
|||||||
} else {
|
} else {
|
||||||
// Try to send via a relay
|
// Try to send via a relay
|
||||||
for _, relayIP := range hostinfo.relayState.CopyRelayIps() {
|
for _, relayIP := range hostinfo.relayState.CopyRelayIps() {
|
||||||
relayHostInfo, relay, err := f.hostMap.QueryVpnIpRelayFor(hostinfo.vpnIp, relayIP)
|
relayHostInfo, relay, err := f.hostMap.QueryVpnAddrsRelayFor(hostinfo.vpnAddrs, relayIP)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
hostinfo.relayState.DeleteRelay(relayIP)
|
hostinfo.relayState.DeleteRelay(relayIP)
|
||||||
hostinfo.logger(f.l).WithField("relay", relayIP).WithError(err).Info("sendNoMetrics failed to find HostInfo")
|
hostinfo.logger(f.l).WithField("relay", relayIP).WithError(err).Info("sendNoMetrics failed to find HostInfo")
|
||||||
|
|||||||
67
interface.go
67
interface.go
@ -2,17 +2,16 @@ package nebula
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"encoding/binary"
|
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"net"
|
|
||||||
"net/netip"
|
"net/netip"
|
||||||
"os"
|
"os"
|
||||||
"runtime"
|
"runtime"
|
||||||
"sync/atomic"
|
"sync/atomic"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/gaissmai/bart"
|
||||||
"github.com/rcrowley/go-metrics"
|
"github.com/rcrowley/go-metrics"
|
||||||
"github.com/sirupsen/logrus"
|
"github.com/sirupsen/logrus"
|
||||||
"github.com/slackhq/nebula/config"
|
"github.com/slackhq/nebula/config"
|
||||||
@ -29,7 +28,6 @@ type InterfaceConfig struct {
|
|||||||
Outside udp.Conn
|
Outside udp.Conn
|
||||||
Inside overlay.Device
|
Inside overlay.Device
|
||||||
pki *PKI
|
pki *PKI
|
||||||
Cipher string
|
|
||||||
Firewall *Firewall
|
Firewall *Firewall
|
||||||
ServeDns bool
|
ServeDns bool
|
||||||
HandshakeManager *HandshakeManager
|
HandshakeManager *HandshakeManager
|
||||||
@ -57,15 +55,17 @@ type Interface struct {
|
|||||||
outside udp.Conn
|
outside udp.Conn
|
||||||
inside overlay.Device
|
inside overlay.Device
|
||||||
pki *PKI
|
pki *PKI
|
||||||
cipher string
|
|
||||||
firewall *Firewall
|
firewall *Firewall
|
||||||
connectionManager *connectionManager
|
connectionManager *connectionManager
|
||||||
handshakeManager *HandshakeManager
|
handshakeManager *HandshakeManager
|
||||||
serveDns bool
|
serveDns bool
|
||||||
createTime time.Time
|
createTime time.Time
|
||||||
lightHouse *LightHouse
|
lightHouse *LightHouse
|
||||||
myBroadcastAddr netip.Addr
|
myBroadcastAddrsTable *bart.Table[struct{}]
|
||||||
myVpnNet netip.Prefix
|
myVpnAddrs []netip.Addr // A list of addresses assigned to us via our certificate
|
||||||
|
myVpnAddrsTable *bart.Table[struct{}] // A table of addresses assigned to us via our certificate
|
||||||
|
myVpnNetworks []netip.Prefix // A table of networks assigned to us via our certificate
|
||||||
|
myVpnNetworksTable *bart.Table[struct{}] // A table of networks assigned to us via our certificate
|
||||||
dropLocalBroadcast bool
|
dropLocalBroadcast bool
|
||||||
dropMulticast bool
|
dropMulticast bool
|
||||||
routines int
|
routines int
|
||||||
@ -103,9 +103,11 @@ type EncWriter interface {
|
|||||||
out []byte,
|
out []byte,
|
||||||
nocopy bool,
|
nocopy bool,
|
||||||
)
|
)
|
||||||
SendMessageToVpnIp(t header.MessageType, st header.MessageSubType, vpnIp netip.Addr, p, nb, out []byte)
|
SendMessageToVpnAddr(t header.MessageType, st header.MessageSubType, vpnAddr netip.Addr, p, nb, out []byte)
|
||||||
SendMessageToHostInfo(t header.MessageType, st header.MessageSubType, hostinfo *HostInfo, p, nb, out []byte)
|
SendMessageToHostInfo(t header.MessageType, st header.MessageSubType, hostinfo *HostInfo, p, nb, out []byte)
|
||||||
Handshake(vpnIp netip.Addr)
|
Handshake(vpnAddr netip.Addr)
|
||||||
|
GetHostInfo(vpnAddr netip.Addr) *HostInfo
|
||||||
|
GetCertState() *CertState
|
||||||
}
|
}
|
||||||
|
|
||||||
type sendRecvErrorConfig uint8
|
type sendRecvErrorConfig uint8
|
||||||
@ -116,10 +118,10 @@ const (
|
|||||||
sendRecvErrorPrivate
|
sendRecvErrorPrivate
|
||||||
)
|
)
|
||||||
|
|
||||||
func (s sendRecvErrorConfig) ShouldSendRecvError(ip netip.AddrPort) bool {
|
func (s sendRecvErrorConfig) ShouldSendRecvError(endpoint netip.AddrPort) bool {
|
||||||
switch s {
|
switch s {
|
||||||
case sendRecvErrorPrivate:
|
case sendRecvErrorPrivate:
|
||||||
return ip.Addr().IsPrivate()
|
return endpoint.Addr().IsPrivate()
|
||||||
case sendRecvErrorAlways:
|
case sendRecvErrorAlways:
|
||||||
return true
|
return true
|
||||||
case sendRecvErrorNever:
|
case sendRecvErrorNever:
|
||||||
@ -156,14 +158,12 @@ func NewInterface(ctx context.Context, c *InterfaceConfig) (*Interface, error) {
|
|||||||
return nil, errors.New("no firewall rules")
|
return nil, errors.New("no firewall rules")
|
||||||
}
|
}
|
||||||
|
|
||||||
certificate := c.pki.GetCertState().Certificate
|
cs := c.pki.getCertState()
|
||||||
|
|
||||||
ifce := &Interface{
|
ifce := &Interface{
|
||||||
pki: c.pki,
|
pki: c.pki,
|
||||||
hostMap: c.HostMap,
|
hostMap: c.HostMap,
|
||||||
outside: c.Outside,
|
outside: c.Outside,
|
||||||
inside: c.Inside,
|
inside: c.Inside,
|
||||||
cipher: c.Cipher,
|
|
||||||
firewall: c.Firewall,
|
firewall: c.Firewall,
|
||||||
serveDns: c.ServeDns,
|
serveDns: c.ServeDns,
|
||||||
handshakeManager: c.HandshakeManager,
|
handshakeManager: c.HandshakeManager,
|
||||||
@ -175,7 +175,11 @@ func NewInterface(ctx context.Context, c *InterfaceConfig) (*Interface, error) {
|
|||||||
version: c.version,
|
version: c.version,
|
||||||
writers: make([]udp.Conn, c.routines),
|
writers: make([]udp.Conn, c.routines),
|
||||||
readers: make([]io.ReadWriteCloser, c.routines),
|
readers: make([]io.ReadWriteCloser, c.routines),
|
||||||
myVpnNet: certificate.Networks()[0],
|
myVpnNetworks: cs.myVpnNetworks,
|
||||||
|
myVpnNetworksTable: cs.myVpnNetworksTable,
|
||||||
|
myVpnAddrs: cs.myVpnAddrs,
|
||||||
|
myVpnAddrsTable: cs.myVpnAddrsTable,
|
||||||
|
myBroadcastAddrsTable: cs.myVpnBroadcastAddrsTable,
|
||||||
relayManager: c.relayManager,
|
relayManager: c.relayManager,
|
||||||
|
|
||||||
conntrackCacheTimeout: c.ConntrackCacheTimeout,
|
conntrackCacheTimeout: c.ConntrackCacheTimeout,
|
||||||
@ -190,14 +194,6 @@ func NewInterface(ctx context.Context, c *InterfaceConfig) (*Interface, error) {
|
|||||||
l: c.l,
|
l: c.l,
|
||||||
}
|
}
|
||||||
|
|
||||||
if ifce.myVpnNet.Addr().Is4() {
|
|
||||||
maskedAddr := certificate.Networks()[0].Masked()
|
|
||||||
addr := maskedAddr.Addr().As4()
|
|
||||||
mask := net.CIDRMask(maskedAddr.Bits(), maskedAddr.Addr().BitLen())
|
|
||||||
binary.BigEndian.PutUint32(addr[:], binary.BigEndian.Uint32(addr[:])|^binary.BigEndian.Uint32(mask))
|
|
||||||
ifce.myBroadcastAddr = netip.AddrFrom4(addr)
|
|
||||||
}
|
|
||||||
|
|
||||||
ifce.tryPromoteEvery.Store(c.tryPromoteEvery)
|
ifce.tryPromoteEvery.Store(c.tryPromoteEvery)
|
||||||
ifce.reQueryEvery.Store(c.reQueryEvery)
|
ifce.reQueryEvery.Store(c.reQueryEvery)
|
||||||
ifce.reQueryWait.Store(int64(c.reQueryWait))
|
ifce.reQueryWait.Store(int64(c.reQueryWait))
|
||||||
@ -218,7 +214,7 @@ func (f *Interface) activate() {
|
|||||||
f.l.WithError(err).Error("Failed to get udp listen address")
|
f.l.WithError(err).Error("Failed to get udp listen address")
|
||||||
}
|
}
|
||||||
|
|
||||||
f.l.WithField("interface", f.inside.Name()).WithField("network", f.inside.Cidr().String()).
|
f.l.WithField("interface", f.inside.Name()).WithField("networks", f.myVpnNetworks).
|
||||||
WithField("build", f.version).WithField("udpAddr", addr).
|
WithField("build", f.version).WithField("udpAddr", addr).
|
||||||
WithField("boringcrypto", boringEnabled()).
|
WithField("boringcrypto", boringEnabled()).
|
||||||
Info("Nebula interface is active")
|
Info("Nebula interface is active")
|
||||||
@ -259,16 +255,22 @@ func (f *Interface) listenOut(i int) {
|
|||||||
runtime.LockOSThread()
|
runtime.LockOSThread()
|
||||||
|
|
||||||
var li udp.Conn
|
var li udp.Conn
|
||||||
// TODO clean this up with a coherent interface for each outside connection
|
|
||||||
if i > 0 {
|
if i > 0 {
|
||||||
li = f.writers[i]
|
li = f.writers[i]
|
||||||
} else {
|
} else {
|
||||||
li = f.outside
|
li = f.outside
|
||||||
}
|
}
|
||||||
|
|
||||||
|
ctCache := firewall.NewConntrackCacheTicker(f.conntrackCacheTimeout)
|
||||||
lhh := f.lightHouse.NewRequestHandler()
|
lhh := f.lightHouse.NewRequestHandler()
|
||||||
conntrackCache := firewall.NewConntrackCacheTicker(f.conntrackCacheTimeout)
|
plaintext := make([]byte, udp.MTU)
|
||||||
li.ListenOut(readOutsidePackets(f), lhHandleRequest(lhh, f), conntrackCache, i)
|
h := &header.H{}
|
||||||
|
fwPacket := &firewall.Packet{}
|
||||||
|
nb := make([]byte, 12, 12)
|
||||||
|
|
||||||
|
li.ListenOut(func(fromUdpAddr netip.AddrPort, payload []byte) {
|
||||||
|
f.readOutsidePackets(fromUdpAddr, nil, plaintext[:0], payload, h, fwPacket, lhh, nb, i, ctCache.Get(f.l))
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func (f *Interface) listenIn(reader io.ReadWriteCloser, i int) {
|
func (f *Interface) listenIn(reader io.ReadWriteCloser, i int) {
|
||||||
@ -325,7 +327,7 @@ func (f *Interface) reloadFirewall(c *config.C) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
fw, err := NewFirewallFromConfig(f.l, f.pki.GetCertState().Certificate, c)
|
fw, err := NewFirewallFromConfig(f.l, f.pki.getCertState(), c)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
f.l.WithError(err).Error("Error while creating firewall during reload")
|
f.l.WithError(err).Error("Error while creating firewall during reload")
|
||||||
return
|
return
|
||||||
@ -417,11 +419,20 @@ func (f *Interface) emitStats(ctx context.Context, i time.Duration) {
|
|||||||
f.firewall.EmitStats()
|
f.firewall.EmitStats()
|
||||||
f.handshakeManager.EmitStats()
|
f.handshakeManager.EmitStats()
|
||||||
udpStats()
|
udpStats()
|
||||||
certExpirationGauge.Update(int64(f.pki.GetCertState().Certificate.NotAfter().Sub(time.Now()) / time.Second))
|
certExpirationGauge.Update(int64(f.pki.getDefaultCertificate().NotAfter().Sub(time.Now()) / time.Second))
|
||||||
|
//TODO: we should also report the default certificate version
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (f *Interface) GetHostInfo(vpnIp netip.Addr) *HostInfo {
|
||||||
|
return f.hostMap.QueryVpnAddr(vpnIp)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f *Interface) GetCertState() *CertState {
|
||||||
|
return f.pki.getCertState()
|
||||||
|
}
|
||||||
|
|
||||||
func (f *Interface) Close() error {
|
func (f *Interface) Close() error {
|
||||||
f.closed.Store(true)
|
f.closed.Store(true)
|
||||||
|
|
||||||
|
|||||||
836
lighthouse.go
836
lighthouse.go
File diff suppressed because it is too large
Load Diff
@ -7,6 +7,8 @@ import (
|
|||||||
"net/netip"
|
"net/netip"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
|
"github.com/gaissmai/bart"
|
||||||
|
"github.com/slackhq/nebula/cert"
|
||||||
"github.com/slackhq/nebula/config"
|
"github.com/slackhq/nebula/config"
|
||||||
"github.com/slackhq/nebula/header"
|
"github.com/slackhq/nebula/header"
|
||||||
"github.com/slackhq/nebula/test"
|
"github.com/slackhq/nebula/test"
|
||||||
@ -19,57 +21,48 @@ import (
|
|||||||
func TestOldIPv4Only(t *testing.T) {
|
func TestOldIPv4Only(t *testing.T) {
|
||||||
// This test ensures our new ipv6 enabled LH protobuf IpAndPorts works with the old style to enable backwards compatibility
|
// This test ensures our new ipv6 enabled LH protobuf IpAndPorts works with the old style to enable backwards compatibility
|
||||||
b := []byte{8, 129, 130, 132, 80, 16, 10}
|
b := []byte{8, 129, 130, 132, 80, 16, 10}
|
||||||
var m Ip4AndPort
|
var m V4AddrPort
|
||||||
err := m.Unmarshal(b)
|
err := m.Unmarshal(b)
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
ip := netip.MustParseAddr("10.1.1.1")
|
ip := netip.MustParseAddr("10.1.1.1")
|
||||||
bp := ip.As4()
|
bp := ip.As4()
|
||||||
assert.Equal(t, binary.BigEndian.Uint32(bp[:]), m.GetIp())
|
assert.Equal(t, binary.BigEndian.Uint32(bp[:]), m.GetAddr())
|
||||||
}
|
|
||||||
|
|
||||||
func TestNewLhQuery(t *testing.T) {
|
|
||||||
myIp, err := netip.ParseAddr("192.1.1.1")
|
|
||||||
assert.NoError(t, err)
|
|
||||||
|
|
||||||
// Generating a new lh query should work
|
|
||||||
a := NewLhQueryByInt(myIp)
|
|
||||||
|
|
||||||
// The result should be a nebulameta protobuf
|
|
||||||
assert.IsType(t, &NebulaMeta{}, a)
|
|
||||||
|
|
||||||
// It should also Marshal fine
|
|
||||||
b, err := a.Marshal()
|
|
||||||
assert.Nil(t, err)
|
|
||||||
|
|
||||||
// and then Unmarshal fine
|
|
||||||
n := &NebulaMeta{}
|
|
||||||
err = n.Unmarshal(b)
|
|
||||||
assert.Nil(t, err)
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func Test_lhStaticMapping(t *testing.T) {
|
func Test_lhStaticMapping(t *testing.T) {
|
||||||
l := test.NewLogger()
|
l := test.NewLogger()
|
||||||
myVpnNet := netip.MustParsePrefix("10.128.0.1/16")
|
myVpnNet := netip.MustParsePrefix("10.128.0.1/16")
|
||||||
|
nt := new(bart.Table[struct{}])
|
||||||
|
nt.Insert(myVpnNet, struct{}{})
|
||||||
|
cs := &CertState{
|
||||||
|
myVpnNetworks: []netip.Prefix{myVpnNet},
|
||||||
|
myVpnNetworksTable: nt,
|
||||||
|
}
|
||||||
lh1 := "10.128.0.2"
|
lh1 := "10.128.0.2"
|
||||||
|
|
||||||
c := config.NewC(l)
|
c := config.NewC(l)
|
||||||
c.Settings["lighthouse"] = map[interface{}]interface{}{"hosts": []interface{}{lh1}}
|
c.Settings["lighthouse"] = map[interface{}]interface{}{"hosts": []interface{}{lh1}}
|
||||||
c.Settings["static_host_map"] = map[interface{}]interface{}{lh1: []interface{}{"1.1.1.1:4242"}}
|
c.Settings["static_host_map"] = map[interface{}]interface{}{lh1: []interface{}{"1.1.1.1:4242"}}
|
||||||
_, err := NewLightHouseFromConfig(context.Background(), l, c, myVpnNet, nil, nil)
|
_, err := NewLightHouseFromConfig(context.Background(), l, c, cs, nil, nil)
|
||||||
assert.Nil(t, err)
|
assert.Nil(t, err)
|
||||||
|
|
||||||
lh2 := "10.128.0.3"
|
lh2 := "10.128.0.3"
|
||||||
c = config.NewC(l)
|
c = config.NewC(l)
|
||||||
c.Settings["lighthouse"] = map[interface{}]interface{}{"hosts": []interface{}{lh1, lh2}}
|
c.Settings["lighthouse"] = map[interface{}]interface{}{"hosts": []interface{}{lh1, lh2}}
|
||||||
c.Settings["static_host_map"] = map[interface{}]interface{}{lh1: []interface{}{"100.1.1.1:4242"}}
|
c.Settings["static_host_map"] = map[interface{}]interface{}{lh1: []interface{}{"100.1.1.1:4242"}}
|
||||||
_, err = NewLightHouseFromConfig(context.Background(), l, c, myVpnNet, nil, nil)
|
_, err = NewLightHouseFromConfig(context.Background(), l, c, cs, nil, nil)
|
||||||
assert.EqualError(t, err, "lighthouse 10.128.0.3 does not have a static_host_map entry")
|
assert.EqualError(t, err, "lighthouse 10.128.0.3 does not have a static_host_map entry")
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestReloadLighthouseInterval(t *testing.T) {
|
func TestReloadLighthouseInterval(t *testing.T) {
|
||||||
l := test.NewLogger()
|
l := test.NewLogger()
|
||||||
myVpnNet := netip.MustParsePrefix("10.128.0.1/16")
|
myVpnNet := netip.MustParsePrefix("10.128.0.1/16")
|
||||||
|
nt := new(bart.Table[struct{}])
|
||||||
|
nt.Insert(myVpnNet, struct{}{})
|
||||||
|
cs := &CertState{
|
||||||
|
myVpnNetworks: []netip.Prefix{myVpnNet},
|
||||||
|
myVpnNetworksTable: nt,
|
||||||
|
}
|
||||||
lh1 := "10.128.0.2"
|
lh1 := "10.128.0.2"
|
||||||
|
|
||||||
c := config.NewC(l)
|
c := config.NewC(l)
|
||||||
@ -79,7 +72,7 @@ func TestReloadLighthouseInterval(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
c.Settings["static_host_map"] = map[interface{}]interface{}{lh1: []interface{}{"1.1.1.1:4242"}}
|
c.Settings["static_host_map"] = map[interface{}]interface{}{lh1: []interface{}{"1.1.1.1:4242"}}
|
||||||
lh, err := NewLightHouseFromConfig(context.Background(), l, c, myVpnNet, nil, nil)
|
lh, err := NewLightHouseFromConfig(context.Background(), l, c, cs, nil, nil)
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
lh.ifce = &mockEncWriter{}
|
lh.ifce = &mockEncWriter{}
|
||||||
|
|
||||||
@ -99,9 +92,15 @@ func TestReloadLighthouseInterval(t *testing.T) {
|
|||||||
func BenchmarkLighthouseHandleRequest(b *testing.B) {
|
func BenchmarkLighthouseHandleRequest(b *testing.B) {
|
||||||
l := test.NewLogger()
|
l := test.NewLogger()
|
||||||
myVpnNet := netip.MustParsePrefix("10.128.0.1/0")
|
myVpnNet := netip.MustParsePrefix("10.128.0.1/0")
|
||||||
|
nt := new(bart.Table[struct{}])
|
||||||
|
nt.Insert(myVpnNet, struct{}{})
|
||||||
|
cs := &CertState{
|
||||||
|
myVpnNetworks: []netip.Prefix{myVpnNet},
|
||||||
|
myVpnNetworksTable: nt,
|
||||||
|
}
|
||||||
|
|
||||||
c := config.NewC(l)
|
c := config.NewC(l)
|
||||||
lh, err := NewLightHouseFromConfig(context.Background(), l, c, myVpnNet, nil, nil)
|
lh, err := NewLightHouseFromConfig(context.Background(), l, c, cs, nil, nil)
|
||||||
if !assert.NoError(b, err) {
|
if !assert.NoError(b, err) {
|
||||||
b.Fatal()
|
b.Fatal()
|
||||||
}
|
}
|
||||||
@ -110,46 +109,47 @@ func BenchmarkLighthouseHandleRequest(b *testing.B) {
|
|||||||
hAddr2 := netip.MustParseAddrPort("4.5.6.7:12346")
|
hAddr2 := netip.MustParseAddrPort("4.5.6.7:12346")
|
||||||
|
|
||||||
vpnIp3 := netip.MustParseAddr("0.0.0.3")
|
vpnIp3 := netip.MustParseAddr("0.0.0.3")
|
||||||
lh.addrMap[vpnIp3] = NewRemoteList(nil)
|
lh.addrMap[vpnIp3] = NewRemoteList([]netip.Addr{vpnIp3}, nil)
|
||||||
lh.addrMap[vpnIp3].unlockedSetV4(
|
lh.addrMap[vpnIp3].unlockedSetV4(
|
||||||
vpnIp3,
|
vpnIp3,
|
||||||
vpnIp3,
|
vpnIp3,
|
||||||
[]*Ip4AndPort{
|
[]*V4AddrPort{
|
||||||
NewIp4AndPortFromNetIP(hAddr.Addr(), hAddr.Port()),
|
netAddrToProtoV4AddrPort(hAddr.Addr(), hAddr.Port()),
|
||||||
NewIp4AndPortFromNetIP(hAddr2.Addr(), hAddr2.Port()),
|
netAddrToProtoV4AddrPort(hAddr2.Addr(), hAddr2.Port()),
|
||||||
},
|
},
|
||||||
func(netip.Addr, *Ip4AndPort) bool { return true },
|
func(netip.Addr, *V4AddrPort) bool { return true },
|
||||||
)
|
)
|
||||||
|
|
||||||
rAddr := netip.MustParseAddrPort("1.2.2.3:12345")
|
rAddr := netip.MustParseAddrPort("1.2.2.3:12345")
|
||||||
rAddr2 := netip.MustParseAddrPort("1.2.2.3:12346")
|
rAddr2 := netip.MustParseAddrPort("1.2.2.3:12346")
|
||||||
vpnIp2 := netip.MustParseAddr("0.0.0.3")
|
vpnIp2 := netip.MustParseAddr("0.0.0.3")
|
||||||
lh.addrMap[vpnIp2] = NewRemoteList(nil)
|
lh.addrMap[vpnIp2] = NewRemoteList([]netip.Addr{vpnIp2}, nil)
|
||||||
lh.addrMap[vpnIp2].unlockedSetV4(
|
lh.addrMap[vpnIp2].unlockedSetV4(
|
||||||
vpnIp3,
|
vpnIp3,
|
||||||
vpnIp3,
|
vpnIp3,
|
||||||
[]*Ip4AndPort{
|
[]*V4AddrPort{
|
||||||
NewIp4AndPortFromNetIP(rAddr.Addr(), rAddr.Port()),
|
netAddrToProtoV4AddrPort(rAddr.Addr(), rAddr.Port()),
|
||||||
NewIp4AndPortFromNetIP(rAddr2.Addr(), rAddr2.Port()),
|
netAddrToProtoV4AddrPort(rAddr2.Addr(), rAddr2.Port()),
|
||||||
},
|
},
|
||||||
func(netip.Addr, *Ip4AndPort) bool { return true },
|
func(netip.Addr, *V4AddrPort) bool { return true },
|
||||||
)
|
)
|
||||||
|
|
||||||
mw := &mockEncWriter{}
|
mw := &mockEncWriter{}
|
||||||
|
|
||||||
|
hi := []netip.Addr{vpnIp2}
|
||||||
b.Run("notfound", func(b *testing.B) {
|
b.Run("notfound", func(b *testing.B) {
|
||||||
lhh := lh.NewRequestHandler()
|
lhh := lh.NewRequestHandler()
|
||||||
req := &NebulaMeta{
|
req := &NebulaMeta{
|
||||||
Type: NebulaMeta_HostQuery,
|
Type: NebulaMeta_HostQuery,
|
||||||
Details: &NebulaMetaDetails{
|
Details: &NebulaMetaDetails{
|
||||||
VpnIp: 4,
|
OldVpnAddr: 4,
|
||||||
Ip4AndPorts: nil,
|
V4AddrPorts: nil,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
p, err := req.Marshal()
|
p, err := req.Marshal()
|
||||||
assert.NoError(b, err)
|
assert.NoError(b, err)
|
||||||
for n := 0; n < b.N; n++ {
|
for n := 0; n < b.N; n++ {
|
||||||
lhh.HandleRequest(rAddr, vpnIp2, p, mw)
|
lhh.HandleRequest(rAddr, hi, p, mw)
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
b.Run("found", func(b *testing.B) {
|
b.Run("found", func(b *testing.B) {
|
||||||
@ -157,15 +157,15 @@ func BenchmarkLighthouseHandleRequest(b *testing.B) {
|
|||||||
req := &NebulaMeta{
|
req := &NebulaMeta{
|
||||||
Type: NebulaMeta_HostQuery,
|
Type: NebulaMeta_HostQuery,
|
||||||
Details: &NebulaMetaDetails{
|
Details: &NebulaMetaDetails{
|
||||||
VpnIp: 3,
|
OldVpnAddr: 3,
|
||||||
Ip4AndPorts: nil,
|
V4AddrPorts: nil,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
p, err := req.Marshal()
|
p, err := req.Marshal()
|
||||||
assert.NoError(b, err)
|
assert.NoError(b, err)
|
||||||
|
|
||||||
for n := 0; n < b.N; n++ {
|
for n := 0; n < b.N; n++ {
|
||||||
lhh.HandleRequest(rAddr, vpnIp2, p, mw)
|
lhh.HandleRequest(rAddr, hi, p, mw)
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
@ -197,40 +197,49 @@ func TestLighthouse_Memory(t *testing.T) {
|
|||||||
c := config.NewC(l)
|
c := config.NewC(l)
|
||||||
c.Settings["lighthouse"] = map[interface{}]interface{}{"am_lighthouse": true}
|
c.Settings["lighthouse"] = map[interface{}]interface{}{"am_lighthouse": true}
|
||||||
c.Settings["listen"] = map[interface{}]interface{}{"port": 4242}
|
c.Settings["listen"] = map[interface{}]interface{}{"port": 4242}
|
||||||
lh, err := NewLightHouseFromConfig(context.Background(), l, c, netip.MustParsePrefix("10.128.0.1/24"), nil, nil)
|
|
||||||
|
myVpnNet := netip.MustParsePrefix("10.128.0.1/24")
|
||||||
|
nt := new(bart.Table[struct{}])
|
||||||
|
nt.Insert(myVpnNet, struct{}{})
|
||||||
|
cs := &CertState{
|
||||||
|
myVpnNetworks: []netip.Prefix{myVpnNet},
|
||||||
|
myVpnNetworksTable: nt,
|
||||||
|
}
|
||||||
|
lh, err := NewLightHouseFromConfig(context.Background(), l, c, cs, nil, nil)
|
||||||
|
lh.ifce = &mockEncWriter{}
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
lhh := lh.NewRequestHandler()
|
lhh := lh.NewRequestHandler()
|
||||||
|
|
||||||
// Test that my first update responds with just that
|
// Test that my first update responds with just that
|
||||||
newLHHostUpdate(myUdpAddr0, myVpnIp, []netip.AddrPort{myUdpAddr1, myUdpAddr2}, lhh)
|
newLHHostUpdate(myUdpAddr0, myVpnIp, []netip.AddrPort{myUdpAddr1, myUdpAddr2}, lhh)
|
||||||
r := newLHHostRequest(myUdpAddr0, myVpnIp, myVpnIp, lhh)
|
r := newLHHostRequest(myUdpAddr0, myVpnIp, myVpnIp, lhh)
|
||||||
assertIp4InArray(t, r.msg.Details.Ip4AndPorts, myUdpAddr1, myUdpAddr2)
|
assertIp4InArray(t, r.msg.Details.V4AddrPorts, myUdpAddr1, myUdpAddr2)
|
||||||
|
|
||||||
// Ensure we don't accumulate addresses
|
// Ensure we don't accumulate addresses
|
||||||
newLHHostUpdate(myUdpAddr0, myVpnIp, []netip.AddrPort{myUdpAddr3}, lhh)
|
newLHHostUpdate(myUdpAddr0, myVpnIp, []netip.AddrPort{myUdpAddr3}, lhh)
|
||||||
r = newLHHostRequest(myUdpAddr0, myVpnIp, myVpnIp, lhh)
|
r = newLHHostRequest(myUdpAddr0, myVpnIp, myVpnIp, lhh)
|
||||||
assertIp4InArray(t, r.msg.Details.Ip4AndPorts, myUdpAddr3)
|
assertIp4InArray(t, r.msg.Details.V4AddrPorts, myUdpAddr3)
|
||||||
|
|
||||||
// Grow it back to 2
|
// Grow it back to 2
|
||||||
newLHHostUpdate(myUdpAddr0, myVpnIp, []netip.AddrPort{myUdpAddr1, myUdpAddr4}, lhh)
|
newLHHostUpdate(myUdpAddr0, myVpnIp, []netip.AddrPort{myUdpAddr1, myUdpAddr4}, lhh)
|
||||||
r = newLHHostRequest(myUdpAddr0, myVpnIp, myVpnIp, lhh)
|
r = newLHHostRequest(myUdpAddr0, myVpnIp, myVpnIp, lhh)
|
||||||
assertIp4InArray(t, r.msg.Details.Ip4AndPorts, myUdpAddr1, myUdpAddr4)
|
assertIp4InArray(t, r.msg.Details.V4AddrPorts, myUdpAddr1, myUdpAddr4)
|
||||||
|
|
||||||
// Update a different host and ask about it
|
// Update a different host and ask about it
|
||||||
newLHHostUpdate(theirUdpAddr0, theirVpnIp, []netip.AddrPort{theirUdpAddr1, theirUdpAddr2, theirUdpAddr3, theirUdpAddr4}, lhh)
|
newLHHostUpdate(theirUdpAddr0, theirVpnIp, []netip.AddrPort{theirUdpAddr1, theirUdpAddr2, theirUdpAddr3, theirUdpAddr4}, lhh)
|
||||||
r = newLHHostRequest(theirUdpAddr0, theirVpnIp, theirVpnIp, lhh)
|
r = newLHHostRequest(theirUdpAddr0, theirVpnIp, theirVpnIp, lhh)
|
||||||
assertIp4InArray(t, r.msg.Details.Ip4AndPorts, theirUdpAddr1, theirUdpAddr2, theirUdpAddr3, theirUdpAddr4)
|
assertIp4InArray(t, r.msg.Details.V4AddrPorts, theirUdpAddr1, theirUdpAddr2, theirUdpAddr3, theirUdpAddr4)
|
||||||
|
|
||||||
// Have both hosts ask about the other
|
// Have both hosts ask about the other
|
||||||
r = newLHHostRequest(theirUdpAddr0, theirVpnIp, myVpnIp, lhh)
|
r = newLHHostRequest(theirUdpAddr0, theirVpnIp, myVpnIp, lhh)
|
||||||
assertIp4InArray(t, r.msg.Details.Ip4AndPorts, myUdpAddr1, myUdpAddr4)
|
assertIp4InArray(t, r.msg.Details.V4AddrPorts, myUdpAddr1, myUdpAddr4)
|
||||||
|
|
||||||
r = newLHHostRequest(myUdpAddr0, myVpnIp, theirVpnIp, lhh)
|
r = newLHHostRequest(myUdpAddr0, myVpnIp, theirVpnIp, lhh)
|
||||||
assertIp4InArray(t, r.msg.Details.Ip4AndPorts, theirUdpAddr1, theirUdpAddr2, theirUdpAddr3, theirUdpAddr4)
|
assertIp4InArray(t, r.msg.Details.V4AddrPorts, theirUdpAddr1, theirUdpAddr2, theirUdpAddr3, theirUdpAddr4)
|
||||||
|
|
||||||
// Make sure we didn't get changed
|
// Make sure we didn't get changed
|
||||||
r = newLHHostRequest(myUdpAddr0, myVpnIp, myVpnIp, lhh)
|
r = newLHHostRequest(myUdpAddr0, myVpnIp, myVpnIp, lhh)
|
||||||
assertIp4InArray(t, r.msg.Details.Ip4AndPorts, myUdpAddr1, myUdpAddr4)
|
assertIp4InArray(t, r.msg.Details.V4AddrPorts, myUdpAddr1, myUdpAddr4)
|
||||||
|
|
||||||
// Ensure proper ordering and limiting
|
// Ensure proper ordering and limiting
|
||||||
// Send 12 addrs, get 10 back, the last 2 removed, allowing the duplicate to remain (clients dedupe)
|
// Send 12 addrs, get 10 back, the last 2 removed, allowing the duplicate to remain (clients dedupe)
|
||||||
@ -255,7 +264,7 @@ func TestLighthouse_Memory(t *testing.T) {
|
|||||||
r = newLHHostRequest(myUdpAddr0, myVpnIp, myVpnIp, lhh)
|
r = newLHHostRequest(myUdpAddr0, myVpnIp, myVpnIp, lhh)
|
||||||
assertIp4InArray(
|
assertIp4InArray(
|
||||||
t,
|
t,
|
||||||
r.msg.Details.Ip4AndPorts,
|
r.msg.Details.V4AddrPorts,
|
||||||
myUdpAddr1, myUdpAddr2, myUdpAddr3, myUdpAddr4, myUdpAddr5, myUdpAddr5, myUdpAddr6, myUdpAddr7, myUdpAddr8, myUdpAddr9,
|
myUdpAddr1, myUdpAddr2, myUdpAddr3, myUdpAddr4, myUdpAddr5, myUdpAddr5, myUdpAddr6, myUdpAddr7, myUdpAddr8, myUdpAddr9,
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -265,7 +274,7 @@ func TestLighthouse_Memory(t *testing.T) {
|
|||||||
good := netip.MustParseAddrPort("1.128.0.99:4242")
|
good := netip.MustParseAddrPort("1.128.0.99:4242")
|
||||||
newLHHostUpdate(myUdpAddr0, myVpnIp, []netip.AddrPort{bad1, bad2, good}, lhh)
|
newLHHostUpdate(myUdpAddr0, myVpnIp, []netip.AddrPort{bad1, bad2, good}, lhh)
|
||||||
r = newLHHostRequest(myUdpAddr0, myVpnIp, myVpnIp, lhh)
|
r = newLHHostRequest(myUdpAddr0, myVpnIp, myVpnIp, lhh)
|
||||||
assertIp4InArray(t, r.msg.Details.Ip4AndPorts, good)
|
assertIp4InArray(t, r.msg.Details.V4AddrPorts, good)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestLighthouse_reload(t *testing.T) {
|
func TestLighthouse_reload(t *testing.T) {
|
||||||
@ -273,7 +282,16 @@ func TestLighthouse_reload(t *testing.T) {
|
|||||||
c := config.NewC(l)
|
c := config.NewC(l)
|
||||||
c.Settings["lighthouse"] = map[interface{}]interface{}{"am_lighthouse": true}
|
c.Settings["lighthouse"] = map[interface{}]interface{}{"am_lighthouse": true}
|
||||||
c.Settings["listen"] = map[interface{}]interface{}{"port": 4242}
|
c.Settings["listen"] = map[interface{}]interface{}{"port": 4242}
|
||||||
lh, err := NewLightHouseFromConfig(context.Background(), l, c, netip.MustParsePrefix("10.128.0.1/24"), nil, nil)
|
|
||||||
|
myVpnNet := netip.MustParsePrefix("10.128.0.1/24")
|
||||||
|
nt := new(bart.Table[struct{}])
|
||||||
|
nt.Insert(myVpnNet, struct{}{})
|
||||||
|
cs := &CertState{
|
||||||
|
myVpnNetworks: []netip.Prefix{myVpnNet},
|
||||||
|
myVpnNetworksTable: nt,
|
||||||
|
}
|
||||||
|
|
||||||
|
lh, err := NewLightHouseFromConfig(context.Background(), l, c, cs, nil, nil)
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
|
|
||||||
nc := map[interface{}]interface{}{
|
nc := map[interface{}]interface{}{
|
||||||
@ -295,7 +313,7 @@ func newLHHostRequest(fromAddr netip.AddrPort, myVpnIp, queryVpnIp netip.Addr, l
|
|||||||
req := &NebulaMeta{
|
req := &NebulaMeta{
|
||||||
Type: NebulaMeta_HostQuery,
|
Type: NebulaMeta_HostQuery,
|
||||||
Details: &NebulaMetaDetails{
|
Details: &NebulaMetaDetails{
|
||||||
VpnIp: binary.BigEndian.Uint32(bip[:]),
|
OldVpnAddr: binary.BigEndian.Uint32(bip[:]),
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -308,7 +326,7 @@ func newLHHostRequest(fromAddr netip.AddrPort, myVpnIp, queryVpnIp netip.Addr, l
|
|||||||
w := &testEncWriter{
|
w := &testEncWriter{
|
||||||
metaFilter: &filter,
|
metaFilter: &filter,
|
||||||
}
|
}
|
||||||
lhh.HandleRequest(fromAddr, myVpnIp, b, w)
|
lhh.HandleRequest(fromAddr, []netip.Addr{myVpnIp}, b, w)
|
||||||
return w.lastReply
|
return w.lastReply
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -318,13 +336,13 @@ func newLHHostUpdate(fromAddr netip.AddrPort, vpnIp netip.Addr, addrs []netip.Ad
|
|||||||
req := &NebulaMeta{
|
req := &NebulaMeta{
|
||||||
Type: NebulaMeta_HostUpdateNotification,
|
Type: NebulaMeta_HostUpdateNotification,
|
||||||
Details: &NebulaMetaDetails{
|
Details: &NebulaMetaDetails{
|
||||||
VpnIp: binary.BigEndian.Uint32(bip[:]),
|
OldVpnAddr: binary.BigEndian.Uint32(bip[:]),
|
||||||
Ip4AndPorts: make([]*Ip4AndPort, len(addrs)),
|
V4AddrPorts: make([]*V4AddrPort, len(addrs)),
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
for k, v := range addrs {
|
for k, v := range addrs {
|
||||||
req.Details.Ip4AndPorts[k] = NewIp4AndPortFromNetIP(v.Addr(), v.Port())
|
req.Details.V4AddrPorts[k] = netAddrToProtoV4AddrPort(v.Addr(), v.Port())
|
||||||
}
|
}
|
||||||
|
|
||||||
b, err := req.Marshal()
|
b, err := req.Marshal()
|
||||||
@ -333,7 +351,7 @@ func newLHHostUpdate(fromAddr netip.AddrPort, vpnIp netip.Addr, addrs []netip.Ad
|
|||||||
}
|
}
|
||||||
|
|
||||||
w := &testEncWriter{}
|
w := &testEncWriter{}
|
||||||
lhh.HandleRequest(fromAddr, vpnIp, b, w)
|
lhh.HandleRequest(fromAddr, []netip.Addr{vpnIp}, b, w)
|
||||||
}
|
}
|
||||||
|
|
||||||
//TODO: this is a RemoteList test
|
//TODO: this is a RemoteList test
|
||||||
@ -412,6 +430,7 @@ type testLhReply struct {
|
|||||||
type testEncWriter struct {
|
type testEncWriter struct {
|
||||||
lastReply testLhReply
|
lastReply testLhReply
|
||||||
metaFilter *NebulaMeta_MessageType
|
metaFilter *NebulaMeta_MessageType
|
||||||
|
protocolVersion cert.Version
|
||||||
}
|
}
|
||||||
|
|
||||||
func (tw *testEncWriter) SendVia(via *HostInfo, relay *Relay, ad, nb, out []byte, nocopy bool) {
|
func (tw *testEncWriter) SendVia(via *HostInfo, relay *Relay, ad, nb, out []byte, nocopy bool) {
|
||||||
@ -426,7 +445,7 @@ func (tw *testEncWriter) SendMessageToHostInfo(t header.MessageType, st header.M
|
|||||||
tw.lastReply = testLhReply{
|
tw.lastReply = testLhReply{
|
||||||
nebType: t,
|
nebType: t,
|
||||||
nebSubType: st,
|
nebSubType: st,
|
||||||
vpnIp: hostinfo.vpnIp,
|
vpnIp: hostinfo.vpnAddrs[0],
|
||||||
msg: msg,
|
msg: msg,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -436,7 +455,7 @@ func (tw *testEncWriter) SendMessageToHostInfo(t header.MessageType, st header.M
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (tw *testEncWriter) SendMessageToVpnIp(t header.MessageType, st header.MessageSubType, vpnIp netip.Addr, p, _, _ []byte) {
|
func (tw *testEncWriter) SendMessageToVpnAddr(t header.MessageType, st header.MessageSubType, vpnIp netip.Addr, p, _, _ []byte) {
|
||||||
msg := &NebulaMeta{}
|
msg := &NebulaMeta{}
|
||||||
err := msg.Unmarshal(p)
|
err := msg.Unmarshal(p)
|
||||||
if tw.metaFilter == nil || msg.Type == *tw.metaFilter {
|
if tw.metaFilter == nil || msg.Type == *tw.metaFilter {
|
||||||
@ -453,15 +472,23 @@ func (tw *testEncWriter) SendMessageToVpnIp(t header.MessageType, st header.Mess
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (tw *testEncWriter) GetHostInfo(vpnIp netip.Addr) *HostInfo {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (tw *testEncWriter) GetCertState() *CertState {
|
||||||
|
return &CertState{defaultVersion: tw.protocolVersion}
|
||||||
|
}
|
||||||
|
|
||||||
// assertIp4InArray asserts every address in want is at the same position in have and that the lengths match
|
// assertIp4InArray asserts every address in want is at the same position in have and that the lengths match
|
||||||
func assertIp4InArray(t *testing.T, have []*Ip4AndPort, want ...netip.AddrPort) {
|
func assertIp4InArray(t *testing.T, have []*V4AddrPort, want ...netip.AddrPort) {
|
||||||
if !assert.Len(t, have, len(want)) {
|
if !assert.Len(t, have, len(want)) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
for k, w := range want {
|
for k, w := range want {
|
||||||
//TODO: IPV6-WORK
|
//TODO: IPV6-WORK
|
||||||
h := AddrPortFromIp4AndPort(have[k])
|
h := protoV4AddrPortToNetAddrPort(have[k])
|
||||||
if !(h == w) {
|
if !(h == w) {
|
||||||
assert.Fail(t, fmt.Sprintf("Response did not contain: %v at %v, found %v", w, k, h))
|
assert.Fail(t, fmt.Sprintf("Response did not contain: %v at %v, found %v", w, k, h))
|
||||||
}
|
}
|
||||||
|
|||||||
24
main.go
24
main.go
@ -2,7 +2,6 @@ package nebula
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"encoding/binary"
|
|
||||||
"fmt"
|
"fmt"
|
||||||
"net"
|
"net"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
@ -61,15 +60,12 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
|
|||||||
return nil, util.ContextualizeIfNeeded("Failed to load PKI from config", err)
|
return nil, util.ContextualizeIfNeeded("Failed to load PKI from config", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
certificate := pki.GetCertState().Certificate
|
fw, err := NewFirewallFromConfig(l, pki.getCertState(), c)
|
||||||
fw, err := NewFirewallFromConfig(l, certificate, c)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, util.ContextualizeIfNeeded("Error while loading firewall rules", err)
|
return nil, util.ContextualizeIfNeeded("Error while loading firewall rules", err)
|
||||||
}
|
}
|
||||||
l.WithField("firewallHashes", fw.GetRuleHashes()).Info("Firewall started")
|
l.WithField("firewallHashes", fw.GetRuleHashes()).Info("Firewall started")
|
||||||
|
|
||||||
tunCidr := certificate.Networks()[0]
|
|
||||||
|
|
||||||
ssh, err := sshd.NewSSHServer(l.WithField("subsystem", "sshd"))
|
ssh, err := sshd.NewSSHServer(l.WithField("subsystem", "sshd"))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, util.ContextualizeIfNeeded("Error while creating SSH server", err)
|
return nil, util.ContextualizeIfNeeded("Error while creating SSH server", err)
|
||||||
@ -132,7 +128,7 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
|
|||||||
deviceFactory = overlay.NewDeviceFromConfig
|
deviceFactory = overlay.NewDeviceFromConfig
|
||||||
}
|
}
|
||||||
|
|
||||||
tun, err = deviceFactory(c, l, tunCidr, routines)
|
tun, err = deviceFactory(c, l, pki.getCertState().myVpnNetworks, routines)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, util.ContextualizeIfNeeded("Failed to get a tun/tap device", err)
|
return nil, util.ContextualizeIfNeeded("Failed to get a tun/tap device", err)
|
||||||
}
|
}
|
||||||
@ -187,9 +183,9 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
hostMap := NewHostMapFromConfig(l, tunCidr, c)
|
hostMap := NewHostMapFromConfig(l, c)
|
||||||
punchy := NewPunchyFromConfig(l, c)
|
punchy := NewPunchyFromConfig(l, c)
|
||||||
lightHouse, err := NewLightHouseFromConfig(ctx, l, c, tunCidr, udpConns[0], punchy)
|
lightHouse, err := NewLightHouseFromConfig(ctx, l, c, pki.getCertState(), udpConns[0], punchy)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, util.ContextualizeIfNeeded("Failed to initialize lighthouse handler", err)
|
return nil, util.ContextualizeIfNeeded("Failed to initialize lighthouse handler", err)
|
||||||
}
|
}
|
||||||
@ -232,7 +228,6 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
|
|||||||
Inside: tun,
|
Inside: tun,
|
||||||
Outside: udpConns[0],
|
Outside: udpConns[0],
|
||||||
pki: pki,
|
pki: pki,
|
||||||
Cipher: c.GetString("cipher", "aes"),
|
|
||||||
Firewall: fw,
|
Firewall: fw,
|
||||||
ServeDns: serveDns,
|
ServeDns: serveDns,
|
||||||
HandshakeManager: handshakeManager,
|
HandshakeManager: handshakeManager,
|
||||||
@ -254,15 +249,6 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
|
|||||||
l: l,
|
l: l,
|
||||||
}
|
}
|
||||||
|
|
||||||
switch ifConfig.Cipher {
|
|
||||||
case "aes":
|
|
||||||
noiseEndianness = binary.BigEndian
|
|
||||||
case "chachapoly":
|
|
||||||
noiseEndianness = binary.LittleEndian
|
|
||||||
default:
|
|
||||||
return nil, fmt.Errorf("unknown cipher: %v", ifConfig.Cipher)
|
|
||||||
}
|
|
||||||
|
|
||||||
var ifce *Interface
|
var ifce *Interface
|
||||||
if !configTest {
|
if !configTest {
|
||||||
ifce, err = NewInterface(ctx, ifConfig)
|
ifce, err = NewInterface(ctx, ifConfig)
|
||||||
@ -303,7 +289,7 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
|
|||||||
var dnsStart func()
|
var dnsStart func()
|
||||||
if lightHouse.amLighthouse && serveDns {
|
if lightHouse.amLighthouse && serveDns {
|
||||||
l.Debugln("Starting dns server")
|
l.Debugln("Starting dns server")
|
||||||
dnsStart = dnsMain(l, hostMap, c)
|
dnsStart = dnsMain(l, pki.getCertState(), hostMap, c)
|
||||||
}
|
}
|
||||||
|
|
||||||
return &Control{
|
return &Control{
|
||||||
|
|||||||
913
nebula.pb.go
913
nebula.pb.go
File diff suppressed because it is too large
Load Diff
32
nebula.proto
32
nebula.proto
@ -23,19 +23,28 @@ message NebulaMeta {
|
|||||||
}
|
}
|
||||||
|
|
||||||
message NebulaMetaDetails {
|
message NebulaMetaDetails {
|
||||||
uint32 VpnIp = 1;
|
uint32 OldVpnAddr = 1 [deprecated = true];
|
||||||
repeated Ip4AndPort Ip4AndPorts = 2;
|
Addr VpnAddr = 6;
|
||||||
repeated Ip6AndPort Ip6AndPorts = 4;
|
|
||||||
repeated uint32 RelayVpnIp = 5;
|
repeated uint32 OldRelayVpnAddrs = 5 [deprecated = true];
|
||||||
|
repeated Addr RelayVpnAddrs = 7;
|
||||||
|
|
||||||
|
repeated V4AddrPort V4AddrPorts = 2;
|
||||||
|
repeated V6AddrPort V6AddrPorts = 4;
|
||||||
uint32 counter = 3;
|
uint32 counter = 3;
|
||||||
}
|
}
|
||||||
|
|
||||||
message Ip4AndPort {
|
message Addr {
|
||||||
uint32 Ip = 1;
|
uint64 Hi = 1;
|
||||||
|
uint64 Lo = 2;
|
||||||
|
}
|
||||||
|
|
||||||
|
message V4AddrPort {
|
||||||
|
uint32 Addr = 1;
|
||||||
uint32 Port = 2;
|
uint32 Port = 2;
|
||||||
}
|
}
|
||||||
|
|
||||||
message Ip6AndPort {
|
message V6AddrPort {
|
||||||
uint64 Hi = 1;
|
uint64 Hi = 1;
|
||||||
uint64 Lo = 2;
|
uint64 Lo = 2;
|
||||||
uint32 Port = 3;
|
uint32 Port = 3;
|
||||||
@ -62,6 +71,7 @@ message NebulaHandshakeDetails {
|
|||||||
uint32 ResponderIndex = 3;
|
uint32 ResponderIndex = 3;
|
||||||
uint64 Cookie = 4;
|
uint64 Cookie = 4;
|
||||||
uint64 Time = 5;
|
uint64 Time = 5;
|
||||||
|
uint32 CertVersion = 8;
|
||||||
// reserved for WIP multiport
|
// reserved for WIP multiport
|
||||||
reserved 6, 7;
|
reserved 6, 7;
|
||||||
}
|
}
|
||||||
@ -76,6 +86,10 @@ message NebulaControl {
|
|||||||
|
|
||||||
uint32 InitiatorRelayIndex = 2;
|
uint32 InitiatorRelayIndex = 2;
|
||||||
uint32 ResponderRelayIndex = 3;
|
uint32 ResponderRelayIndex = 3;
|
||||||
uint32 RelayToIp = 4;
|
|
||||||
uint32 RelayFromIp = 5;
|
uint32 OldRelayToAddr = 4 [deprecated = true];
|
||||||
|
uint32 OldRelayFromAddr = 5 [deprecated = true];
|
||||||
|
|
||||||
|
Addr RelayToAddr = 6;
|
||||||
|
Addr RelayFromAddr = 7;
|
||||||
}
|
}
|
||||||
|
|||||||
217
outside.go
217
outside.go
@ -7,12 +7,12 @@ import (
|
|||||||
"net/netip"
|
"net/netip"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/flynn/noise"
|
"github.com/google/gopacket/layers"
|
||||||
|
"golang.org/x/net/ipv6"
|
||||||
|
|
||||||
"github.com/sirupsen/logrus"
|
"github.com/sirupsen/logrus"
|
||||||
"github.com/slackhq/nebula/cert"
|
|
||||||
"github.com/slackhq/nebula/firewall"
|
"github.com/slackhq/nebula/firewall"
|
||||||
"github.com/slackhq/nebula/header"
|
"github.com/slackhq/nebula/header"
|
||||||
"github.com/slackhq/nebula/udp"
|
|
||||||
"golang.org/x/net/ipv4"
|
"golang.org/x/net/ipv4"
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -20,24 +20,7 @@ const (
|
|||||||
minFwPacketLen = 4
|
minFwPacketLen = 4
|
||||||
)
|
)
|
||||||
|
|
||||||
// TODO: IPV6-WORK this can likely be removed now
|
func (f *Interface) readOutsidePackets(ip netip.AddrPort, via *ViaSender, out []byte, packet []byte, h *header.H, fwPacket *firewall.Packet, lhf *LightHouseHandler, nb []byte, q int, localCache firewall.ConntrackCache) {
|
||||||
func readOutsidePackets(f *Interface) udp.EncReader {
|
|
||||||
return func(
|
|
||||||
addr netip.AddrPort,
|
|
||||||
out []byte,
|
|
||||||
packet []byte,
|
|
||||||
header *header.H,
|
|
||||||
fwPacket *firewall.Packet,
|
|
||||||
lhh udp.LightHouseHandlerFunc,
|
|
||||||
nb []byte,
|
|
||||||
q int,
|
|
||||||
localCache firewall.ConntrackCache,
|
|
||||||
) {
|
|
||||||
f.readOutsidePackets(addr, nil, out, packet, header, fwPacket, lhh, nb, q, localCache)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (f *Interface) readOutsidePackets(ip netip.AddrPort, via *ViaSender, out []byte, packet []byte, h *header.H, fwPacket *firewall.Packet, lhf udp.LightHouseHandlerFunc, nb []byte, q int, localCache firewall.ConntrackCache) {
|
|
||||||
err := h.Parse(packet)
|
err := h.Parse(packet)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
// TODO: best if we return this and let caller log
|
// TODO: best if we return this and let caller log
|
||||||
@ -51,7 +34,8 @@ func (f *Interface) readOutsidePackets(ip netip.AddrPort, via *ViaSender, out []
|
|||||||
|
|
||||||
//l.Error("in packet ", header, packet[HeaderLen:])
|
//l.Error("in packet ", header, packet[HeaderLen:])
|
||||||
if ip.IsValid() {
|
if ip.IsValid() {
|
||||||
if f.myVpnNet.Contains(ip.Addr()) {
|
_, found := f.myVpnNetworksTable.Lookup(ip.Addr())
|
||||||
|
if found {
|
||||||
if f.l.Level >= logrus.DebugLevel {
|
if f.l.Level >= logrus.DebugLevel {
|
||||||
f.l.WithField("udpAddr", ip).Debug("Refusing to process double encrypted packet")
|
f.l.WithField("udpAddr", ip).Debug("Refusing to process double encrypted packet")
|
||||||
}
|
}
|
||||||
@ -108,7 +92,7 @@ func (f *Interface) readOutsidePackets(ip netip.AddrPort, via *ViaSender, out []
|
|||||||
if !ok {
|
if !ok {
|
||||||
// The only way this happens is if hostmap has an index to the correct HostInfo, but the HostInfo is missing
|
// The only way this happens is if hostmap has an index to the correct HostInfo, but the HostInfo is missing
|
||||||
// its internal mapping. This should never happen.
|
// its internal mapping. This should never happen.
|
||||||
hostinfo.logger(f.l).WithFields(logrus.Fields{"vpnIp": hostinfo.vpnIp, "remoteIndex": h.RemoteIndex}).Error("HostInfo missing remote relay index")
|
hostinfo.logger(f.l).WithFields(logrus.Fields{"vpnAddrs": hostinfo.vpnAddrs, "remoteIndex": h.RemoteIndex}).Error("HostInfo missing remote relay index")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -120,9 +104,9 @@ func (f *Interface) readOutsidePackets(ip netip.AddrPort, via *ViaSender, out []
|
|||||||
return
|
return
|
||||||
case ForwardingType:
|
case ForwardingType:
|
||||||
// Find the target HostInfo relay object
|
// Find the target HostInfo relay object
|
||||||
targetHI, targetRelay, err := f.hostMap.QueryVpnIpRelayFor(hostinfo.vpnIp, relay.PeerIp)
|
targetHI, targetRelay, err := f.hostMap.QueryVpnAddrsRelayFor(hostinfo.vpnAddrs, relay.PeerAddr)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
hostinfo.logger(f.l).WithField("relayTo", relay.PeerIp).WithError(err).Info("Failed to find target host info by ip")
|
hostinfo.logger(f.l).WithField("relayTo", relay.PeerAddr).WithError(err).WithField("hostinfo.vpnAddrs", hostinfo.vpnAddrs).Info("Failed to find target host info by ip")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -138,7 +122,7 @@ func (f *Interface) readOutsidePackets(ip netip.AddrPort, via *ViaSender, out []
|
|||||||
hostinfo.logger(f.l).Error("Unexpected Relay Type of Terminal")
|
hostinfo.logger(f.l).Error("Unexpected Relay Type of Terminal")
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
hostinfo.logger(f.l).WithFields(logrus.Fields{"relayTo": relay.PeerIp, "relayFrom": hostinfo.vpnIp, "targetRelayState": targetRelay.State}).Info("Unexpected target relay state")
|
hostinfo.logger(f.l).WithFields(logrus.Fields{"relayTo": relay.PeerAddr, "relayFrom": hostinfo.vpnAddrs[0], "targetRelayState": targetRelay.State}).Info("Unexpected target relay state")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -161,7 +145,7 @@ func (f *Interface) readOutsidePackets(ip netip.AddrPort, via *ViaSender, out []
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
lhf(ip, hostinfo.vpnIp, d)
|
lhf.HandleRequest(ip, hostinfo.vpnAddrs, d, f)
|
||||||
|
|
||||||
// Fallthrough to the bottom to record incoming traffic
|
// Fallthrough to the bottom to record incoming traffic
|
||||||
|
|
||||||
@ -228,14 +212,8 @@ func (f *Interface) readOutsidePackets(ip netip.AddrPort, via *ViaSender, out []
|
|||||||
Error("Failed to decrypt Control packet")
|
Error("Failed to decrypt Control packet")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
m := &NebulaControl{}
|
|
||||||
err = m.Unmarshal(d)
|
|
||||||
if err != nil {
|
|
||||||
hostinfo.logger(f.l).WithError(err).Error("Failed to unmarshal control message")
|
|
||||||
break
|
|
||||||
}
|
|
||||||
|
|
||||||
f.relayManager.HandleControlMsg(hostinfo, m, f)
|
f.relayManager.HandleControlMsg(hostinfo, d, f)
|
||||||
|
|
||||||
default:
|
default:
|
||||||
f.messageMetrics.Rx(h.Type, h.Subtype, 1)
|
f.messageMetrics.Rx(h.Type, h.Subtype, 1)
|
||||||
@ -252,8 +230,8 @@ func (f *Interface) readOutsidePackets(ip netip.AddrPort, via *ViaSender, out []
|
|||||||
func (f *Interface) closeTunnel(hostInfo *HostInfo) {
|
func (f *Interface) closeTunnel(hostInfo *HostInfo) {
|
||||||
final := f.hostMap.DeleteHostInfo(hostInfo)
|
final := f.hostMap.DeleteHostInfo(hostInfo)
|
||||||
if final {
|
if final {
|
||||||
// We no longer have any tunnels with this vpn ip, clear learned lighthouse state to lower memory usage
|
// We no longer have any tunnels with this vpn addr, clear learned lighthouse state to lower memory usage
|
||||||
f.lightHouse.DeleteVpnIp(hostInfo.vpnIp)
|
f.lightHouse.DeleteVpnAddrs(hostInfo.vpnAddrs)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -262,25 +240,26 @@ func (f *Interface) sendCloseTunnel(h *HostInfo) {
|
|||||||
f.send(header.CloseTunnel, 0, h.ConnectionState, h, []byte{}, make([]byte, 12, 12), make([]byte, mtu))
|
f.send(header.CloseTunnel, 0, h.ConnectionState, h, []byte{}, make([]byte, 12, 12), make([]byte, mtu))
|
||||||
}
|
}
|
||||||
|
|
||||||
func (f *Interface) handleHostRoaming(hostinfo *HostInfo, ip netip.AddrPort) {
|
func (f *Interface) handleHostRoaming(hostinfo *HostInfo, vpnAddr netip.AddrPort) {
|
||||||
if ip.IsValid() && hostinfo.remote != ip {
|
if vpnAddr.IsValid() && hostinfo.remote != vpnAddr {
|
||||||
if !f.lightHouse.GetRemoteAllowList().Allow(hostinfo.vpnIp, ip.Addr()) {
|
//TODO: this is weird now that we can have multiple vpn addrs
|
||||||
hostinfo.logger(f.l).WithField("newAddr", ip).Debug("lighthouse.remote_allow_list denied roaming")
|
if !f.lightHouse.GetRemoteAllowList().Allow(hostinfo.vpnAddrs[0], vpnAddr.Addr()) {
|
||||||
|
hostinfo.logger(f.l).WithField("newAddr", vpnAddr).Debug("lighthouse.remote_allow_list denied roaming")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if !hostinfo.lastRoam.IsZero() && ip == hostinfo.lastRoamRemote && time.Since(hostinfo.lastRoam) < RoamingSuppressSeconds*time.Second {
|
if !hostinfo.lastRoam.IsZero() && vpnAddr == hostinfo.lastRoamRemote && time.Since(hostinfo.lastRoam) < RoamingSuppressSeconds*time.Second {
|
||||||
if f.l.Level >= logrus.DebugLevel {
|
if f.l.Level >= logrus.DebugLevel {
|
||||||
hostinfo.logger(f.l).WithField("udpAddr", hostinfo.remote).WithField("newAddr", ip).
|
hostinfo.logger(f.l).WithField("udpAddr", hostinfo.remote).WithField("newAddr", vpnAddr).
|
||||||
Debugf("Suppressing roam back to previous remote for %d seconds", RoamingSuppressSeconds)
|
Debugf("Suppressing roam back to previous remote for %d seconds", RoamingSuppressSeconds)
|
||||||
}
|
}
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
hostinfo.logger(f.l).WithField("udpAddr", hostinfo.remote).WithField("newAddr", ip).
|
hostinfo.logger(f.l).WithField("udpAddr", hostinfo.remote).WithField("newAddr", vpnAddr).
|
||||||
Info("Host roamed to new udp ip/port.")
|
Info("Host roamed to new udp ip/port.")
|
||||||
hostinfo.lastRoam = time.Now()
|
hostinfo.lastRoam = time.Now()
|
||||||
hostinfo.lastRoamRemote = hostinfo.remote
|
hostinfo.lastRoamRemote = hostinfo.remote
|
||||||
hostinfo.SetRemote(ip)
|
hostinfo.SetRemote(vpnAddr)
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
@ -302,14 +281,114 @@ func (f *Interface) handleEncrypted(ci *ConnectionState, addr netip.AddrPort, h
|
|||||||
|
|
||||||
// newPacket validates and parses the interesting bits for the firewall out of the ip and sub protocol headers
|
// newPacket validates and parses the interesting bits for the firewall out of the ip and sub protocol headers
|
||||||
func newPacket(data []byte, incoming bool, fp *firewall.Packet) error {
|
func newPacket(data []byte, incoming bool, fp *firewall.Packet) error {
|
||||||
// Do we at least have an ipv4 header worth of data?
|
if len(data) < 1 {
|
||||||
if len(data) < ipv4.HeaderLen {
|
return errors.New("packet too short")
|
||||||
return fmt.Errorf("packet is less than %v bytes", ipv4.HeaderLen)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Is it an ipv4 packet?
|
version := int((data[0] >> 4) & 0x0f)
|
||||||
if int((data[0]>>4)&0x0f) != 4 {
|
switch version {
|
||||||
return fmt.Errorf("packet is not ipv4, type: %v", int((data[0]>>4)&0x0f))
|
case ipv4.Version:
|
||||||
|
return parseV4(data, incoming, fp)
|
||||||
|
case ipv6.Version:
|
||||||
|
return parseV6(data, incoming, fp)
|
||||||
|
}
|
||||||
|
return fmt.Errorf("packet is an unknown ip version: %v", version)
|
||||||
|
}
|
||||||
|
|
||||||
|
func parseV6(data []byte, incoming bool, fp *firewall.Packet) error {
|
||||||
|
dataLen := len(data)
|
||||||
|
if dataLen < ipv6.HeaderLen {
|
||||||
|
return fmt.Errorf("ipv6 packet is less than %v bytes", ipv4.HeaderLen)
|
||||||
|
}
|
||||||
|
|
||||||
|
if incoming {
|
||||||
|
fp.RemoteAddr, _ = netip.AddrFromSlice(data[8:24])
|
||||||
|
fp.LocalAddr, _ = netip.AddrFromSlice(data[24:40])
|
||||||
|
} else {
|
||||||
|
fp.LocalAddr, _ = netip.AddrFromSlice(data[8:24])
|
||||||
|
fp.RemoteAddr, _ = netip.AddrFromSlice(data[24:40])
|
||||||
|
}
|
||||||
|
|
||||||
|
//TODO: whats a reasonable number of extension headers to attempt to parse?
|
||||||
|
//https://www.ietf.org/archive/id/draft-ietf-6man-eh-limits-00.html
|
||||||
|
protoAt := 6
|
||||||
|
offset := 40
|
||||||
|
for i := 0; i < 24; i++ {
|
||||||
|
if dataLen < offset {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
|
||||||
|
proto := layers.IPProtocol(data[protoAt])
|
||||||
|
//fmt.Println(proto, protoAt)
|
||||||
|
switch proto {
|
||||||
|
case layers.IPProtocolICMPv6:
|
||||||
|
//TODO: we need a new protocol in config language "icmpv6"
|
||||||
|
fp.Protocol = uint8(proto)
|
||||||
|
fp.RemotePort = 0
|
||||||
|
fp.LocalPort = 0
|
||||||
|
fp.Fragment = false
|
||||||
|
return nil
|
||||||
|
|
||||||
|
case layers.IPProtocolTCP:
|
||||||
|
if dataLen < offset+4 {
|
||||||
|
return fmt.Errorf("ipv6 packet was too small")
|
||||||
|
}
|
||||||
|
fp.Protocol = uint8(proto)
|
||||||
|
if incoming {
|
||||||
|
fp.RemotePort = binary.BigEndian.Uint16(data[offset : offset+2])
|
||||||
|
fp.LocalPort = binary.BigEndian.Uint16(data[offset+2 : offset+4])
|
||||||
|
} else {
|
||||||
|
fp.LocalPort = binary.BigEndian.Uint16(data[offset : offset+2])
|
||||||
|
fp.RemotePort = binary.BigEndian.Uint16(data[offset+2 : offset+4])
|
||||||
|
}
|
||||||
|
fp.Fragment = false
|
||||||
|
return nil
|
||||||
|
|
||||||
|
case layers.IPProtocolUDP:
|
||||||
|
if dataLen < offset+4 {
|
||||||
|
return fmt.Errorf("ipv6 packet was too small")
|
||||||
|
}
|
||||||
|
fp.Protocol = uint8(proto)
|
||||||
|
if incoming {
|
||||||
|
fp.RemotePort = binary.BigEndian.Uint16(data[offset : offset+2])
|
||||||
|
fp.LocalPort = binary.BigEndian.Uint16(data[offset+2 : offset+4])
|
||||||
|
} else {
|
||||||
|
fp.LocalPort = binary.BigEndian.Uint16(data[offset : offset+2])
|
||||||
|
fp.RemotePort = binary.BigEndian.Uint16(data[offset+2 : offset+4])
|
||||||
|
}
|
||||||
|
fp.Fragment = false
|
||||||
|
return nil
|
||||||
|
|
||||||
|
case layers.IPProtocolIPv6Fragment:
|
||||||
|
//TODO: can we determine the protocol?
|
||||||
|
fp.RemotePort = 0
|
||||||
|
fp.LocalPort = 0
|
||||||
|
fp.Fragment = true
|
||||||
|
return nil
|
||||||
|
|
||||||
|
default:
|
||||||
|
if dataLen < offset+1 {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
|
||||||
|
next := int(data[offset+1]) * 8
|
||||||
|
if next == 0 {
|
||||||
|
// each extension is at least 8 bytes
|
||||||
|
next = 8
|
||||||
|
}
|
||||||
|
|
||||||
|
protoAt = offset
|
||||||
|
offset = offset + next
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return fmt.Errorf("could not find payload in ipv6 packet")
|
||||||
|
}
|
||||||
|
|
||||||
|
func parseV4(data []byte, incoming bool, fp *firewall.Packet) error {
|
||||||
|
// Do we at least have an ipv4 header worth of data?
|
||||||
|
if len(data) < ipv4.HeaderLen {
|
||||||
|
return fmt.Errorf("ipv4 packet is less than %v bytes", ipv4.HeaderLen)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Adjust our start position based on the advertised ip header length
|
// Adjust our start position based on the advertised ip header length
|
||||||
@ -317,7 +396,7 @@ func newPacket(data []byte, incoming bool, fp *firewall.Packet) error {
|
|||||||
|
|
||||||
// Well formed ip header length?
|
// Well formed ip header length?
|
||||||
if ihl < ipv4.HeaderLen {
|
if ihl < ipv4.HeaderLen {
|
||||||
return fmt.Errorf("packet had an invalid header length: %v", ihl)
|
return fmt.Errorf("ipv4 packet had an invalid header length: %v", ihl)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Check if this is the second or further fragment of a fragmented packet.
|
// Check if this is the second or further fragment of a fragmented packet.
|
||||||
@ -333,14 +412,13 @@ func newPacket(data []byte, incoming bool, fp *firewall.Packet) error {
|
|||||||
minLen += minFwPacketLen
|
minLen += minFwPacketLen
|
||||||
}
|
}
|
||||||
if len(data) < minLen {
|
if len(data) < minLen {
|
||||||
return fmt.Errorf("packet is less than %v bytes, ip header len: %v", minLen, ihl)
|
return fmt.Errorf("ipv4 packet is less than %v bytes, ip header len: %v", minLen, ihl)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Firewall packets are locally oriented
|
// Firewall packets are locally oriented
|
||||||
if incoming {
|
if incoming {
|
||||||
//TODO: IPV6-WORK
|
fp.RemoteAddr, _ = netip.AddrFromSlice(data[12:16])
|
||||||
fp.RemoteIP, _ = netip.AddrFromSlice(data[12:16])
|
fp.LocalAddr, _ = netip.AddrFromSlice(data[16:20])
|
||||||
fp.LocalIP, _ = netip.AddrFromSlice(data[16:20])
|
|
||||||
if fp.Fragment || fp.Protocol == firewall.ProtoICMP {
|
if fp.Fragment || fp.Protocol == firewall.ProtoICMP {
|
||||||
fp.RemotePort = 0
|
fp.RemotePort = 0
|
||||||
fp.LocalPort = 0
|
fp.LocalPort = 0
|
||||||
@ -349,9 +427,8 @@ func newPacket(data []byte, incoming bool, fp *firewall.Packet) error {
|
|||||||
fp.LocalPort = binary.BigEndian.Uint16(data[ihl+2 : ihl+4])
|
fp.LocalPort = binary.BigEndian.Uint16(data[ihl+2 : ihl+4])
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
//TODO: IPV6-WORK
|
fp.LocalAddr, _ = netip.AddrFromSlice(data[12:16])
|
||||||
fp.LocalIP, _ = netip.AddrFromSlice(data[12:16])
|
fp.RemoteAddr, _ = netip.AddrFromSlice(data[16:20])
|
||||||
fp.RemoteIP, _ = netip.AddrFromSlice(data[16:20])
|
|
||||||
if fp.Fragment || fp.Protocol == firewall.ProtoICMP {
|
if fp.Fragment || fp.Protocol == firewall.ProtoICMP {
|
||||||
fp.RemotePort = 0
|
fp.RemotePort = 0
|
||||||
fp.LocalPort = 0
|
fp.LocalPort = 0
|
||||||
@ -492,27 +569,3 @@ func (f *Interface) sendMeta(ci *ConnectionState, endpoint *net.UDPAddr, meta *N
|
|||||||
f.outside.WriteTo(msg, endpoint)
|
f.outside.WriteTo(msg, endpoint)
|
||||||
}
|
}
|
||||||
*/
|
*/
|
||||||
|
|
||||||
func RecombineCertAndValidate(h *noise.HandshakeState, rawCertBytes []byte, caPool *cert.CAPool) (*cert.CachedCertificate, error) {
|
|
||||||
pk := h.PeerStatic()
|
|
||||||
|
|
||||||
if pk == nil {
|
|
||||||
return nil, errors.New("no peer static key was present")
|
|
||||||
}
|
|
||||||
|
|
||||||
if rawCertBytes == nil {
|
|
||||||
return nil, errors.New("provided payload was empty")
|
|
||||||
}
|
|
||||||
|
|
||||||
c, err := cert.UnmarshalCertificateFromHandshake(rawCertBytes, pk)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("error unmarshaling cert: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
cc, err := caPool.VerifyCertificate(time.Now(), c)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("certificate validation failed: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
return cc, nil
|
|
||||||
}
|
|
||||||
|
|||||||
@ -5,6 +5,9 @@ import (
|
|||||||
"net/netip"
|
"net/netip"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
|
"github.com/google/gopacket"
|
||||||
|
"github.com/google/gopacket/layers"
|
||||||
|
|
||||||
"github.com/slackhq/nebula/firewall"
|
"github.com/slackhq/nebula/firewall"
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
"golang.org/x/net/ipv4"
|
"golang.org/x/net/ipv4"
|
||||||
@ -13,9 +16,15 @@ import (
|
|||||||
func Test_newPacket(t *testing.T) {
|
func Test_newPacket(t *testing.T) {
|
||||||
p := &firewall.Packet{}
|
p := &firewall.Packet{}
|
||||||
|
|
||||||
// length fail
|
// length fails
|
||||||
err := newPacket([]byte{0, 1}, true, p)
|
err := newPacket([]byte{}, true, p)
|
||||||
assert.EqualError(t, err, "packet is less than 20 bytes")
|
assert.EqualError(t, err, "packet too short")
|
||||||
|
|
||||||
|
err = newPacket([]byte{0x40}, true, p)
|
||||||
|
assert.EqualError(t, err, "ipv4 packet is less than 20 bytes")
|
||||||
|
|
||||||
|
err = newPacket([]byte{0x60}, true, p)
|
||||||
|
assert.EqualError(t, err, "ipv6 packet is less than 20 bytes")
|
||||||
|
|
||||||
// length fail with ip options
|
// length fail with ip options
|
||||||
h := ipv4.Header{
|
h := ipv4.Header{
|
||||||
@ -29,15 +38,15 @@ func Test_newPacket(t *testing.T) {
|
|||||||
b, _ := h.Marshal()
|
b, _ := h.Marshal()
|
||||||
err = newPacket(b, true, p)
|
err = newPacket(b, true, p)
|
||||||
|
|
||||||
assert.EqualError(t, err, "packet is less than 28 bytes, ip header len: 24")
|
assert.EqualError(t, err, "ipv4 packet is less than 28 bytes, ip header len: 24")
|
||||||
|
|
||||||
// not an ipv4 packet
|
// not an ipv4 packet
|
||||||
err = newPacket([]byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}, true, p)
|
err = newPacket([]byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}, true, p)
|
||||||
assert.EqualError(t, err, "packet is not ipv4, type: 0")
|
assert.EqualError(t, err, "packet is an unknown ip version: 0")
|
||||||
|
|
||||||
// invalid ihl
|
// invalid ihl
|
||||||
err = newPacket([]byte{4<<4 | (8 >> 2 & 0x0f), 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}, true, p)
|
err = newPacket([]byte{4<<4 | (8 >> 2 & 0x0f), 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}, true, p)
|
||||||
assert.EqualError(t, err, "packet had an invalid header length: 8")
|
assert.EqualError(t, err, "ipv4 packet had an invalid header length: 8")
|
||||||
|
|
||||||
// account for variable ip header length - incoming
|
// account for variable ip header length - incoming
|
||||||
h = ipv4.Header{
|
h = ipv4.Header{
|
||||||
@ -55,8 +64,8 @@ func Test_newPacket(t *testing.T) {
|
|||||||
|
|
||||||
assert.Nil(t, err)
|
assert.Nil(t, err)
|
||||||
assert.Equal(t, p.Protocol, uint8(firewall.ProtoTCP))
|
assert.Equal(t, p.Protocol, uint8(firewall.ProtoTCP))
|
||||||
assert.Equal(t, p.LocalIP, netip.MustParseAddr("10.0.0.2"))
|
assert.Equal(t, p.LocalAddr, netip.MustParseAddr("10.0.0.2"))
|
||||||
assert.Equal(t, p.RemoteIP, netip.MustParseAddr("10.0.0.1"))
|
assert.Equal(t, p.RemoteAddr, netip.MustParseAddr("10.0.0.1"))
|
||||||
assert.Equal(t, p.RemotePort, uint16(3))
|
assert.Equal(t, p.RemotePort, uint16(3))
|
||||||
assert.Equal(t, p.LocalPort, uint16(4))
|
assert.Equal(t, p.LocalPort, uint16(4))
|
||||||
|
|
||||||
@ -76,8 +85,60 @@ func Test_newPacket(t *testing.T) {
|
|||||||
|
|
||||||
assert.Nil(t, err)
|
assert.Nil(t, err)
|
||||||
assert.Equal(t, p.Protocol, uint8(2))
|
assert.Equal(t, p.Protocol, uint8(2))
|
||||||
assert.Equal(t, p.LocalIP, netip.MustParseAddr("10.0.0.1"))
|
assert.Equal(t, p.LocalAddr, netip.MustParseAddr("10.0.0.1"))
|
||||||
assert.Equal(t, p.RemoteIP, netip.MustParseAddr("10.0.0.2"))
|
assert.Equal(t, p.RemoteAddr, netip.MustParseAddr("10.0.0.2"))
|
||||||
assert.Equal(t, p.RemotePort, uint16(6))
|
assert.Equal(t, p.RemotePort, uint16(6))
|
||||||
assert.Equal(t, p.LocalPort, uint16(5))
|
assert.Equal(t, p.LocalPort, uint16(5))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func Test_newPacket_v6(t *testing.T) {
|
||||||
|
p := &firewall.Packet{}
|
||||||
|
|
||||||
|
ip := layers.IPv6{
|
||||||
|
Version: 6,
|
||||||
|
NextHeader: firewall.ProtoUDP,
|
||||||
|
HopLimit: 128,
|
||||||
|
SrcIP: net.IPv6linklocalallrouters,
|
||||||
|
DstIP: net.IPv6linklocalallnodes,
|
||||||
|
}
|
||||||
|
|
||||||
|
udp := layers.UDP{
|
||||||
|
SrcPort: layers.UDPPort(36123),
|
||||||
|
DstPort: layers.UDPPort(22),
|
||||||
|
}
|
||||||
|
err := udp.SetNetworkLayerForChecksum(&ip)
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
buffer := gopacket.NewSerializeBuffer()
|
||||||
|
opt := gopacket.SerializeOptions{
|
||||||
|
ComputeChecksums: true,
|
||||||
|
FixLengths: true,
|
||||||
|
}
|
||||||
|
err = gopacket.SerializeLayers(buffer, opt, &ip, &udp, gopacket.Payload([]byte{0xde, 0xad, 0xbe, 0xef}))
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
b := buffer.Bytes()
|
||||||
|
|
||||||
|
//test incoming
|
||||||
|
err = newPacket(b, true, p)
|
||||||
|
|
||||||
|
assert.Nil(t, err)
|
||||||
|
assert.Equal(t, p.Protocol, uint8(firewall.ProtoUDP))
|
||||||
|
assert.Equal(t, p.RemoteAddr, netip.MustParseAddr("ff02::2"))
|
||||||
|
assert.Equal(t, p.LocalAddr, netip.MustParseAddr("ff02::1"))
|
||||||
|
assert.Equal(t, p.RemotePort, uint16(36123))
|
||||||
|
assert.Equal(t, p.LocalPort, uint16(22))
|
||||||
|
|
||||||
|
//test outgoing
|
||||||
|
err = newPacket(b, false, p)
|
||||||
|
|
||||||
|
assert.Nil(t, err)
|
||||||
|
assert.Equal(t, p.Protocol, uint8(firewall.ProtoUDP))
|
||||||
|
assert.Equal(t, p.LocalAddr, netip.MustParseAddr("ff02::2"))
|
||||||
|
assert.Equal(t, p.RemoteAddr, netip.MustParseAddr("ff02::1"))
|
||||||
|
assert.Equal(t, p.LocalPort, uint16(36123))
|
||||||
|
assert.Equal(t, p.RemotePort, uint16(22))
|
||||||
|
}
|
||||||
|
|||||||
@ -8,7 +8,7 @@ import (
|
|||||||
type Device interface {
|
type Device interface {
|
||||||
io.ReadWriteCloser
|
io.ReadWriteCloser
|
||||||
Activate() error
|
Activate() error
|
||||||
Cidr() netip.Prefix
|
Networks() []netip.Prefix
|
||||||
Name() string
|
Name() string
|
||||||
RouteFor(netip.Addr) netip.Addr
|
RouteFor(netip.Addr) netip.Addr
|
||||||
NewMultiQueueReader() (io.ReadWriteCloser, error)
|
NewMultiQueueReader() (io.ReadWriteCloser, error)
|
||||||
|
|||||||
@ -61,7 +61,7 @@ func makeRouteTree(l *logrus.Logger, routes []Route, allowMTU bool) (*bart.Table
|
|||||||
return routeTree, nil
|
return routeTree, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func parseRoutes(c *config.C, network netip.Prefix) ([]Route, error) {
|
func parseRoutes(c *config.C, networks []netip.Prefix) ([]Route, error) {
|
||||||
var err error
|
var err error
|
||||||
|
|
||||||
r := c.Get("tun.routes")
|
r := c.Get("tun.routes")
|
||||||
@ -117,12 +117,20 @@ func parseRoutes(c *config.C, network netip.Prefix) ([]Route, error) {
|
|||||||
return nil, fmt.Errorf("entry %v.route in tun.routes failed to parse: %v", i+1, err)
|
return nil, fmt.Errorf("entry %v.route in tun.routes failed to parse: %v", i+1, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if !network.Contains(r.Cidr.Addr()) || r.Cidr.Bits() < network.Bits() {
|
found := false
|
||||||
|
for _, network := range networks {
|
||||||
|
if network.Contains(r.Cidr.Addr()) && r.Cidr.Bits() >= network.Bits() {
|
||||||
|
found = true
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if !found {
|
||||||
return nil, fmt.Errorf(
|
return nil, fmt.Errorf(
|
||||||
"entry %v.route in tun.routes is not contained within the network attached to the certificate; route: %v, network: %v",
|
"entry %v.route in tun.routes is not contained within the configured vpn networks; route: %v, networks: %v",
|
||||||
i+1,
|
i+1,
|
||||||
r.Cidr.String(),
|
r.Cidr.String(),
|
||||||
network.String(),
|
networks,
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -132,7 +140,7 @@ func parseRoutes(c *config.C, network netip.Prefix) ([]Route, error) {
|
|||||||
return routes, nil
|
return routes, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func parseUnsafeRoutes(c *config.C, network netip.Prefix) ([]Route, error) {
|
func parseUnsafeRoutes(c *config.C, networks []netip.Prefix) ([]Route, error) {
|
||||||
var err error
|
var err error
|
||||||
|
|
||||||
r := c.Get("tun.unsafe_routes")
|
r := c.Get("tun.unsafe_routes")
|
||||||
@ -229,14 +237,16 @@ func parseUnsafeRoutes(c *config.C, network netip.Prefix) ([]Route, error) {
|
|||||||
return nil, fmt.Errorf("entry %v.route in tun.unsafe_routes failed to parse: %v", i+1, err)
|
return nil, fmt.Errorf("entry %v.route in tun.unsafe_routes failed to parse: %v", i+1, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
for _, network := range networks {
|
||||||
if network.Contains(r.Cidr.Addr()) {
|
if network.Contains(r.Cidr.Addr()) {
|
||||||
return nil, fmt.Errorf(
|
return nil, fmt.Errorf(
|
||||||
"entry %v.route in tun.unsafe_routes is contained within the network attached to the certificate; route: %v, network: %v",
|
"entry %v.route in tun.unsafe_routes is contained within the configured vpn networks; route: %v, network: %v",
|
||||||
i+1,
|
i+1,
|
||||||
r.Cidr.String(),
|
r.Cidr.String(),
|
||||||
network.String(),
|
network.String(),
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
routes[i] = r
|
routes[i] = r
|
||||||
}
|
}
|
||||||
|
|||||||
@ -17,76 +17,82 @@ func Test_parseRoutes(t *testing.T) {
|
|||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
|
|
||||||
// test no routes config
|
// test no routes config
|
||||||
routes, err := parseRoutes(c, n)
|
routes, err := parseRoutes(c, []netip.Prefix{n})
|
||||||
assert.Nil(t, err)
|
assert.Nil(t, err)
|
||||||
assert.Len(t, routes, 0)
|
assert.Len(t, routes, 0)
|
||||||
|
|
||||||
// not an array
|
// not an array
|
||||||
c.Settings["tun"] = map[interface{}]interface{}{"routes": "hi"}
|
c.Settings["tun"] = map[interface{}]interface{}{"routes": "hi"}
|
||||||
routes, err = parseRoutes(c, n)
|
routes, err = parseRoutes(c, []netip.Prefix{n})
|
||||||
assert.Nil(t, routes)
|
assert.Nil(t, routes)
|
||||||
assert.EqualError(t, err, "tun.routes is not an array")
|
assert.EqualError(t, err, "tun.routes is not an array")
|
||||||
|
|
||||||
// no routes
|
// no routes
|
||||||
c.Settings["tun"] = map[interface{}]interface{}{"routes": []interface{}{}}
|
c.Settings["tun"] = map[interface{}]interface{}{"routes": []interface{}{}}
|
||||||
routes, err = parseRoutes(c, n)
|
routes, err = parseRoutes(c, []netip.Prefix{n})
|
||||||
assert.Nil(t, err)
|
assert.Nil(t, err)
|
||||||
assert.Len(t, routes, 0)
|
assert.Len(t, routes, 0)
|
||||||
|
|
||||||
// weird route
|
// weird route
|
||||||
c.Settings["tun"] = map[interface{}]interface{}{"routes": []interface{}{"asdf"}}
|
c.Settings["tun"] = map[interface{}]interface{}{"routes": []interface{}{"asdf"}}
|
||||||
routes, err = parseRoutes(c, n)
|
routes, err = parseRoutes(c, []netip.Prefix{n})
|
||||||
assert.Nil(t, routes)
|
assert.Nil(t, routes)
|
||||||
assert.EqualError(t, err, "entry 1 in tun.routes is invalid")
|
assert.EqualError(t, err, "entry 1 in tun.routes is invalid")
|
||||||
|
|
||||||
// no mtu
|
// no mtu
|
||||||
c.Settings["tun"] = map[interface{}]interface{}{"routes": []interface{}{map[interface{}]interface{}{}}}
|
c.Settings["tun"] = map[interface{}]interface{}{"routes": []interface{}{map[interface{}]interface{}{}}}
|
||||||
routes, err = parseRoutes(c, n)
|
routes, err = parseRoutes(c, []netip.Prefix{n})
|
||||||
assert.Nil(t, routes)
|
assert.Nil(t, routes)
|
||||||
assert.EqualError(t, err, "entry 1.mtu in tun.routes is not present")
|
assert.EqualError(t, err, "entry 1.mtu in tun.routes is not present")
|
||||||
|
|
||||||
// bad mtu
|
// bad mtu
|
||||||
c.Settings["tun"] = map[interface{}]interface{}{"routes": []interface{}{map[interface{}]interface{}{"mtu": "nope"}}}
|
c.Settings["tun"] = map[interface{}]interface{}{"routes": []interface{}{map[interface{}]interface{}{"mtu": "nope"}}}
|
||||||
routes, err = parseRoutes(c, n)
|
routes, err = parseRoutes(c, []netip.Prefix{n})
|
||||||
assert.Nil(t, routes)
|
assert.Nil(t, routes)
|
||||||
assert.EqualError(t, err, "entry 1.mtu in tun.routes is not an integer: strconv.Atoi: parsing \"nope\": invalid syntax")
|
assert.EqualError(t, err, "entry 1.mtu in tun.routes is not an integer: strconv.Atoi: parsing \"nope\": invalid syntax")
|
||||||
|
|
||||||
// low mtu
|
// low mtu
|
||||||
c.Settings["tun"] = map[interface{}]interface{}{"routes": []interface{}{map[interface{}]interface{}{"mtu": "499"}}}
|
c.Settings["tun"] = map[interface{}]interface{}{"routes": []interface{}{map[interface{}]interface{}{"mtu": "499"}}}
|
||||||
routes, err = parseRoutes(c, n)
|
routes, err = parseRoutes(c, []netip.Prefix{n})
|
||||||
assert.Nil(t, routes)
|
assert.Nil(t, routes)
|
||||||
assert.EqualError(t, err, "entry 1.mtu in tun.routes is below 500: 499")
|
assert.EqualError(t, err, "entry 1.mtu in tun.routes is below 500: 499")
|
||||||
|
|
||||||
// missing route
|
// missing route
|
||||||
c.Settings["tun"] = map[interface{}]interface{}{"routes": []interface{}{map[interface{}]interface{}{"mtu": "500"}}}
|
c.Settings["tun"] = map[interface{}]interface{}{"routes": []interface{}{map[interface{}]interface{}{"mtu": "500"}}}
|
||||||
routes, err = parseRoutes(c, n)
|
routes, err = parseRoutes(c, []netip.Prefix{n})
|
||||||
assert.Nil(t, routes)
|
assert.Nil(t, routes)
|
||||||
assert.EqualError(t, err, "entry 1.route in tun.routes is not present")
|
assert.EqualError(t, err, "entry 1.route in tun.routes is not present")
|
||||||
|
|
||||||
// unparsable route
|
// unparsable route
|
||||||
c.Settings["tun"] = map[interface{}]interface{}{"routes": []interface{}{map[interface{}]interface{}{"mtu": "500", "route": "nope"}}}
|
c.Settings["tun"] = map[interface{}]interface{}{"routes": []interface{}{map[interface{}]interface{}{"mtu": "500", "route": "nope"}}}
|
||||||
routes, err = parseRoutes(c, n)
|
routes, err = parseRoutes(c, []netip.Prefix{n})
|
||||||
assert.Nil(t, routes)
|
assert.Nil(t, routes)
|
||||||
assert.EqualError(t, err, "entry 1.route in tun.routes failed to parse: netip.ParsePrefix(\"nope\"): no '/'")
|
assert.EqualError(t, err, "entry 1.route in tun.routes failed to parse: netip.ParsePrefix(\"nope\"): no '/'")
|
||||||
|
|
||||||
// below network range
|
// below network range
|
||||||
c.Settings["tun"] = map[interface{}]interface{}{"routes": []interface{}{map[interface{}]interface{}{"mtu": "500", "route": "1.0.0.0/8"}}}
|
c.Settings["tun"] = map[interface{}]interface{}{"routes": []interface{}{map[interface{}]interface{}{"mtu": "500", "route": "1.0.0.0/8"}}}
|
||||||
routes, err = parseRoutes(c, n)
|
routes, err = parseRoutes(c, []netip.Prefix{n})
|
||||||
assert.Nil(t, routes)
|
assert.Nil(t, routes)
|
||||||
assert.EqualError(t, err, "entry 1.route in tun.routes is not contained within the network attached to the certificate; route: 1.0.0.0/8, network: 10.0.0.0/24")
|
assert.EqualError(t, err, "entry 1.route in tun.routes is not contained within the configured vpn networks; route: 1.0.0.0/8, networks: [10.0.0.0/24]")
|
||||||
|
|
||||||
// above network range
|
// above network range
|
||||||
c.Settings["tun"] = map[interface{}]interface{}{"routes": []interface{}{map[interface{}]interface{}{"mtu": "500", "route": "10.0.1.0/24"}}}
|
c.Settings["tun"] = map[interface{}]interface{}{"routes": []interface{}{map[interface{}]interface{}{"mtu": "500", "route": "10.0.1.0/24"}}}
|
||||||
routes, err = parseRoutes(c, n)
|
routes, err = parseRoutes(c, []netip.Prefix{n})
|
||||||
assert.Nil(t, routes)
|
assert.Nil(t, routes)
|
||||||
assert.EqualError(t, err, "entry 1.route in tun.routes is not contained within the network attached to the certificate; route: 10.0.1.0/24, network: 10.0.0.0/24")
|
assert.EqualError(t, err, "entry 1.route in tun.routes is not contained within the configured vpn networks; route: 10.0.1.0/24, networks: [10.0.0.0/24]")
|
||||||
|
|
||||||
|
// Not in multiple ranges
|
||||||
|
c.Settings["tun"] = map[interface{}]interface{}{"routes": []interface{}{map[interface{}]interface{}{"mtu": "500", "route": "192.0.0.0/24"}}}
|
||||||
|
routes, err = parseRoutes(c, []netip.Prefix{n, netip.MustParsePrefix("192.1.0.0/24")})
|
||||||
|
assert.Nil(t, routes)
|
||||||
|
assert.EqualError(t, err, "entry 1.route in tun.routes is not contained within the configured vpn networks; route: 192.0.0.0/24, networks: [10.0.0.0/24 192.1.0.0/24]")
|
||||||
|
|
||||||
// happy case
|
// happy case
|
||||||
c.Settings["tun"] = map[interface{}]interface{}{"routes": []interface{}{
|
c.Settings["tun"] = map[interface{}]interface{}{"routes": []interface{}{
|
||||||
map[interface{}]interface{}{"mtu": "9000", "route": "10.0.0.0/29"},
|
map[interface{}]interface{}{"mtu": "9000", "route": "10.0.0.0/29"},
|
||||||
map[interface{}]interface{}{"mtu": "8000", "route": "10.0.0.1/32"},
|
map[interface{}]interface{}{"mtu": "8000", "route": "10.0.0.1/32"},
|
||||||
}}
|
}}
|
||||||
routes, err = parseRoutes(c, n)
|
routes, err = parseRoutes(c, []netip.Prefix{n})
|
||||||
assert.Nil(t, err)
|
assert.Nil(t, err)
|
||||||
assert.Len(t, routes, 2)
|
assert.Len(t, routes, 2)
|
||||||
|
|
||||||
@ -116,31 +122,31 @@ func Test_parseUnsafeRoutes(t *testing.T) {
|
|||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
|
|
||||||
// test no routes config
|
// test no routes config
|
||||||
routes, err := parseUnsafeRoutes(c, n)
|
routes, err := parseUnsafeRoutes(c, []netip.Prefix{n})
|
||||||
assert.Nil(t, err)
|
assert.Nil(t, err)
|
||||||
assert.Len(t, routes, 0)
|
assert.Len(t, routes, 0)
|
||||||
|
|
||||||
// not an array
|
// not an array
|
||||||
c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": "hi"}
|
c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": "hi"}
|
||||||
routes, err = parseUnsafeRoutes(c, n)
|
routes, err = parseUnsafeRoutes(c, []netip.Prefix{n})
|
||||||
assert.Nil(t, routes)
|
assert.Nil(t, routes)
|
||||||
assert.EqualError(t, err, "tun.unsafe_routes is not an array")
|
assert.EqualError(t, err, "tun.unsafe_routes is not an array")
|
||||||
|
|
||||||
// no routes
|
// no routes
|
||||||
c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{}}
|
c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{}}
|
||||||
routes, err = parseUnsafeRoutes(c, n)
|
routes, err = parseUnsafeRoutes(c, []netip.Prefix{n})
|
||||||
assert.Nil(t, err)
|
assert.Nil(t, err)
|
||||||
assert.Len(t, routes, 0)
|
assert.Len(t, routes, 0)
|
||||||
|
|
||||||
// weird route
|
// weird route
|
||||||
c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{"asdf"}}
|
c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{"asdf"}}
|
||||||
routes, err = parseUnsafeRoutes(c, n)
|
routes, err = parseUnsafeRoutes(c, []netip.Prefix{n})
|
||||||
assert.Nil(t, routes)
|
assert.Nil(t, routes)
|
||||||
assert.EqualError(t, err, "entry 1 in tun.unsafe_routes is invalid")
|
assert.EqualError(t, err, "entry 1 in tun.unsafe_routes is invalid")
|
||||||
|
|
||||||
// no via
|
// no via
|
||||||
c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{}}}
|
c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{}}}
|
||||||
routes, err = parseUnsafeRoutes(c, n)
|
routes, err = parseUnsafeRoutes(c, []netip.Prefix{n})
|
||||||
assert.Nil(t, routes)
|
assert.Nil(t, routes)
|
||||||
assert.EqualError(t, err, "entry 1.via in tun.unsafe_routes is not present")
|
assert.EqualError(t, err, "entry 1.via in tun.unsafe_routes is not present")
|
||||||
|
|
||||||
@ -149,68 +155,68 @@ func Test_parseUnsafeRoutes(t *testing.T) {
|
|||||||
127, false, nil, 1.0, []string{"1", "2"},
|
127, false, nil, 1.0, []string{"1", "2"},
|
||||||
} {
|
} {
|
||||||
c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{"via": invalidValue}}}
|
c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{"via": invalidValue}}}
|
||||||
routes, err = parseUnsafeRoutes(c, n)
|
routes, err = parseUnsafeRoutes(c, []netip.Prefix{n})
|
||||||
assert.Nil(t, routes)
|
assert.Nil(t, routes)
|
||||||
assert.EqualError(t, err, fmt.Sprintf("entry 1.via in tun.unsafe_routes is not a string: found %T", invalidValue))
|
assert.EqualError(t, err, fmt.Sprintf("entry 1.via in tun.unsafe_routes is not a string: found %T", invalidValue))
|
||||||
}
|
}
|
||||||
|
|
||||||
// unparsable via
|
// unparsable via
|
||||||
c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{"mtu": "500", "via": "nope"}}}
|
c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{"mtu": "500", "via": "nope"}}}
|
||||||
routes, err = parseUnsafeRoutes(c, n)
|
routes, err = parseUnsafeRoutes(c, []netip.Prefix{n})
|
||||||
assert.Nil(t, routes)
|
assert.Nil(t, routes)
|
||||||
assert.EqualError(t, err, "entry 1.via in tun.unsafe_routes failed to parse address: ParseAddr(\"nope\"): unable to parse IP")
|
assert.EqualError(t, err, "entry 1.via in tun.unsafe_routes failed to parse address: ParseAddr(\"nope\"): unable to parse IP")
|
||||||
|
|
||||||
// missing route
|
// missing route
|
||||||
c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{"via": "127.0.0.1", "mtu": "500"}}}
|
c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{"via": "127.0.0.1", "mtu": "500"}}}
|
||||||
routes, err = parseUnsafeRoutes(c, n)
|
routes, err = parseUnsafeRoutes(c, []netip.Prefix{n})
|
||||||
assert.Nil(t, routes)
|
assert.Nil(t, routes)
|
||||||
assert.EqualError(t, err, "entry 1.route in tun.unsafe_routes is not present")
|
assert.EqualError(t, err, "entry 1.route in tun.unsafe_routes is not present")
|
||||||
|
|
||||||
// unparsable route
|
// unparsable route
|
||||||
c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{"via": "127.0.0.1", "mtu": "500", "route": "nope"}}}
|
c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{"via": "127.0.0.1", "mtu": "500", "route": "nope"}}}
|
||||||
routes, err = parseUnsafeRoutes(c, n)
|
routes, err = parseUnsafeRoutes(c, []netip.Prefix{n})
|
||||||
assert.Nil(t, routes)
|
assert.Nil(t, routes)
|
||||||
assert.EqualError(t, err, "entry 1.route in tun.unsafe_routes failed to parse: netip.ParsePrefix(\"nope\"): no '/'")
|
assert.EqualError(t, err, "entry 1.route in tun.unsafe_routes failed to parse: netip.ParsePrefix(\"nope\"): no '/'")
|
||||||
|
|
||||||
// within network range
|
// within network range
|
||||||
c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{"via": "127.0.0.1", "route": "10.0.0.0/24"}}}
|
c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{"via": "127.0.0.1", "route": "10.0.0.0/24"}}}
|
||||||
routes, err = parseUnsafeRoutes(c, n)
|
routes, err = parseUnsafeRoutes(c, []netip.Prefix{n})
|
||||||
assert.Nil(t, routes)
|
assert.Nil(t, routes)
|
||||||
assert.EqualError(t, err, "entry 1.route in tun.unsafe_routes is contained within the network attached to the certificate; route: 10.0.0.0/24, network: 10.0.0.0/24")
|
assert.EqualError(t, err, "entry 1.route in tun.unsafe_routes is contained within the configured vpn networks; route: 10.0.0.0/24, network: 10.0.0.0/24")
|
||||||
|
|
||||||
// below network range
|
// below network range
|
||||||
c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{"via": "127.0.0.1", "route": "1.0.0.0/8"}}}
|
c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{"via": "127.0.0.1", "route": "1.0.0.0/8"}}}
|
||||||
routes, err = parseUnsafeRoutes(c, n)
|
routes, err = parseUnsafeRoutes(c, []netip.Prefix{n})
|
||||||
assert.Len(t, routes, 1)
|
assert.Len(t, routes, 1)
|
||||||
assert.Nil(t, err)
|
assert.Nil(t, err)
|
||||||
|
|
||||||
// above network range
|
// above network range
|
||||||
c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{"via": "127.0.0.1", "route": "10.0.1.0/24"}}}
|
c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{"via": "127.0.0.1", "route": "10.0.1.0/24"}}}
|
||||||
routes, err = parseUnsafeRoutes(c, n)
|
routes, err = parseUnsafeRoutes(c, []netip.Prefix{n})
|
||||||
assert.Len(t, routes, 1)
|
assert.Len(t, routes, 1)
|
||||||
assert.Nil(t, err)
|
assert.Nil(t, err)
|
||||||
|
|
||||||
// no mtu
|
// no mtu
|
||||||
c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{"via": "127.0.0.1", "route": "1.0.0.0/8"}}}
|
c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{"via": "127.0.0.1", "route": "1.0.0.0/8"}}}
|
||||||
routes, err = parseUnsafeRoutes(c, n)
|
routes, err = parseUnsafeRoutes(c, []netip.Prefix{n})
|
||||||
assert.Len(t, routes, 1)
|
assert.Len(t, routes, 1)
|
||||||
assert.Equal(t, 0, routes[0].MTU)
|
assert.Equal(t, 0, routes[0].MTU)
|
||||||
|
|
||||||
// bad mtu
|
// bad mtu
|
||||||
c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{"via": "127.0.0.1", "mtu": "nope"}}}
|
c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{"via": "127.0.0.1", "mtu": "nope"}}}
|
||||||
routes, err = parseUnsafeRoutes(c, n)
|
routes, err = parseUnsafeRoutes(c, []netip.Prefix{n})
|
||||||
assert.Nil(t, routes)
|
assert.Nil(t, routes)
|
||||||
assert.EqualError(t, err, "entry 1.mtu in tun.unsafe_routes is not an integer: strconv.Atoi: parsing \"nope\": invalid syntax")
|
assert.EqualError(t, err, "entry 1.mtu in tun.unsafe_routes is not an integer: strconv.Atoi: parsing \"nope\": invalid syntax")
|
||||||
|
|
||||||
// low mtu
|
// low mtu
|
||||||
c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{"via": "127.0.0.1", "mtu": "499"}}}
|
c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{"via": "127.0.0.1", "mtu": "499"}}}
|
||||||
routes, err = parseUnsafeRoutes(c, n)
|
routes, err = parseUnsafeRoutes(c, []netip.Prefix{n})
|
||||||
assert.Nil(t, routes)
|
assert.Nil(t, routes)
|
||||||
assert.EqualError(t, err, "entry 1.mtu in tun.unsafe_routes is below 500: 499")
|
assert.EqualError(t, err, "entry 1.mtu in tun.unsafe_routes is below 500: 499")
|
||||||
|
|
||||||
// bad install
|
// bad install
|
||||||
c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{"via": "127.0.0.1", "mtu": "9000", "route": "1.0.0.0/29", "install": "nope"}}}
|
c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{"via": "127.0.0.1", "mtu": "9000", "route": "1.0.0.0/29", "install": "nope"}}}
|
||||||
routes, err = parseUnsafeRoutes(c, n)
|
routes, err = parseUnsafeRoutes(c, []netip.Prefix{n})
|
||||||
assert.Nil(t, routes)
|
assert.Nil(t, routes)
|
||||||
assert.EqualError(t, err, "entry 1.install in tun.unsafe_routes is not a boolean: strconv.ParseBool: parsing \"nope\": invalid syntax")
|
assert.EqualError(t, err, "entry 1.install in tun.unsafe_routes is not a boolean: strconv.ParseBool: parsing \"nope\": invalid syntax")
|
||||||
|
|
||||||
@ -221,7 +227,7 @@ func Test_parseUnsafeRoutes(t *testing.T) {
|
|||||||
map[interface{}]interface{}{"via": "127.0.0.1", "mtu": "1500", "metric": 1234, "route": "1.0.0.2/32", "install": 1},
|
map[interface{}]interface{}{"via": "127.0.0.1", "mtu": "1500", "metric": 1234, "route": "1.0.0.2/32", "install": 1},
|
||||||
map[interface{}]interface{}{"via": "127.0.0.1", "mtu": "1500", "metric": 1234, "route": "1.0.0.2/32"},
|
map[interface{}]interface{}{"via": "127.0.0.1", "mtu": "1500", "metric": 1234, "route": "1.0.0.2/32"},
|
||||||
}}
|
}}
|
||||||
routes, err = parseUnsafeRoutes(c, n)
|
routes, err = parseUnsafeRoutes(c, []netip.Prefix{n})
|
||||||
assert.Nil(t, err)
|
assert.Nil(t, err)
|
||||||
assert.Len(t, routes, 4)
|
assert.Len(t, routes, 4)
|
||||||
|
|
||||||
@ -260,7 +266,7 @@ func Test_makeRouteTree(t *testing.T) {
|
|||||||
map[interface{}]interface{}{"via": "192.168.0.1", "route": "1.0.0.0/28"},
|
map[interface{}]interface{}{"via": "192.168.0.1", "route": "1.0.0.0/28"},
|
||||||
map[interface{}]interface{}{"via": "192.168.0.2", "route": "1.0.0.1/32"},
|
map[interface{}]interface{}{"via": "192.168.0.2", "route": "1.0.0.1/32"},
|
||||||
}}
|
}}
|
||||||
routes, err := parseUnsafeRoutes(c, n)
|
routes, err := parseUnsafeRoutes(c, []netip.Prefix{n})
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
assert.Len(t, routes, 2)
|
assert.Len(t, routes, 2)
|
||||||
routeTree, err := makeRouteTree(l, routes, true)
|
routeTree, err := makeRouteTree(l, routes, true)
|
||||||
|
|||||||
@ -11,36 +11,36 @@ import (
|
|||||||
const DefaultMTU = 1300
|
const DefaultMTU = 1300
|
||||||
|
|
||||||
// TODO: We may be able to remove routines
|
// TODO: We may be able to remove routines
|
||||||
type DeviceFactory func(c *config.C, l *logrus.Logger, tunCidr netip.Prefix, routines int) (Device, error)
|
type DeviceFactory func(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, routines int) (Device, error)
|
||||||
|
|
||||||
func NewDeviceFromConfig(c *config.C, l *logrus.Logger, tunCidr netip.Prefix, routines int) (Device, error) {
|
func NewDeviceFromConfig(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, routines int) (Device, error) {
|
||||||
switch {
|
switch {
|
||||||
case c.GetBool("tun.disabled", false):
|
case c.GetBool("tun.disabled", false):
|
||||||
tun := newDisabledTun(tunCidr, c.GetInt("tun.tx_queue", 500), c.GetBool("stats.message_metrics", false), l)
|
tun := newDisabledTun(vpnNetworks, c.GetInt("tun.tx_queue", 500), c.GetBool("stats.message_metrics", false), l)
|
||||||
return tun, nil
|
return tun, nil
|
||||||
|
|
||||||
default:
|
default:
|
||||||
return newTun(c, l, tunCidr, routines > 1)
|
return newTun(c, l, vpnNetworks, routines > 1)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewFdDeviceFromConfig(fd *int) DeviceFactory {
|
func NewFdDeviceFromConfig(fd *int) DeviceFactory {
|
||||||
return func(c *config.C, l *logrus.Logger, tunCidr netip.Prefix, routines int) (Device, error) {
|
return func(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, routines int) (Device, error) {
|
||||||
return newTunFromFd(c, l, *fd, tunCidr)
|
return newTunFromFd(c, l, *fd, vpnNetworks)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func getAllRoutesFromConfig(c *config.C, cidr netip.Prefix, initial bool) (bool, []Route, error) {
|
func getAllRoutesFromConfig(c *config.C, vpnNetworks []netip.Prefix, initial bool) (bool, []Route, error) {
|
||||||
if !initial && !c.HasChanged("tun.routes") && !c.HasChanged("tun.unsafe_routes") {
|
if !initial && !c.HasChanged("tun.routes") && !c.HasChanged("tun.unsafe_routes") {
|
||||||
return false, nil, nil
|
return false, nil, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
routes, err := parseRoutes(c, cidr)
|
routes, err := parseRoutes(c, vpnNetworks)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return true, nil, util.NewContextualError("Could not parse tun.routes", nil, err)
|
return true, nil, util.NewContextualError("Could not parse tun.routes", nil, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
unsafeRoutes, err := parseUnsafeRoutes(c, cidr)
|
unsafeRoutes, err := parseUnsafeRoutes(c, vpnNetworks)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return true, nil, util.NewContextualError("Could not parse tun.unsafe_routes", nil, err)
|
return true, nil, util.NewContextualError("Could not parse tun.unsafe_routes", nil, err)
|
||||||
}
|
}
|
||||||
|
|||||||
@ -19,13 +19,13 @@ import (
|
|||||||
type tun struct {
|
type tun struct {
|
||||||
io.ReadWriteCloser
|
io.ReadWriteCloser
|
||||||
fd int
|
fd int
|
||||||
cidr netip.Prefix
|
vpnNetworks []netip.Prefix
|
||||||
Routes atomic.Pointer[[]Route]
|
Routes atomic.Pointer[[]Route]
|
||||||
routeTree atomic.Pointer[bart.Table[netip.Addr]]
|
routeTree atomic.Pointer[bart.Table[netip.Addr]]
|
||||||
l *logrus.Logger
|
l *logrus.Logger
|
||||||
}
|
}
|
||||||
|
|
||||||
func newTunFromFd(c *config.C, l *logrus.Logger, deviceFd int, cidr netip.Prefix) (*tun, error) {
|
func newTunFromFd(c *config.C, l *logrus.Logger, deviceFd int, vpnNetworks []netip.Prefix) (*tun, error) {
|
||||||
// XXX Android returns an fd in non-blocking mode which is necessary for shutdown to work properly.
|
// XXX Android returns an fd in non-blocking mode which is necessary for shutdown to work properly.
|
||||||
// Be sure not to call file.Fd() as it will set the fd to blocking mode.
|
// Be sure not to call file.Fd() as it will set the fd to blocking mode.
|
||||||
file := os.NewFile(uintptr(deviceFd), "/dev/net/tun")
|
file := os.NewFile(uintptr(deviceFd), "/dev/net/tun")
|
||||||
@ -33,7 +33,7 @@ func newTunFromFd(c *config.C, l *logrus.Logger, deviceFd int, cidr netip.Prefix
|
|||||||
t := &tun{
|
t := &tun{
|
||||||
ReadWriteCloser: file,
|
ReadWriteCloser: file,
|
||||||
fd: deviceFd,
|
fd: deviceFd,
|
||||||
cidr: cidr,
|
vpnNetworks: vpnNetworks,
|
||||||
l: l,
|
l: l,
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -52,7 +52,7 @@ func newTunFromFd(c *config.C, l *logrus.Logger, deviceFd int, cidr netip.Prefix
|
|||||||
return t, nil
|
return t, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func newTun(_ *config.C, _ *logrus.Logger, _ netip.Prefix, _ bool) (*tun, error) {
|
func newTun(_ *config.C, _ *logrus.Logger, _ []netip.Prefix, _ bool) (*tun, error) {
|
||||||
return nil, fmt.Errorf("newTun not supported in Android")
|
return nil, fmt.Errorf("newTun not supported in Android")
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -66,7 +66,7 @@ func (t tun) Activate() error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (t *tun) reload(c *config.C, initial bool) error {
|
func (t *tun) reload(c *config.C, initial bool) error {
|
||||||
change, routes, err := getAllRoutesFromConfig(c, t.cidr, initial)
|
change, routes, err := getAllRoutesFromConfig(c, t.vpnNetworks, initial)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@ -86,8 +86,8 @@ func (t *tun) reload(c *config.C, initial bool) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t *tun) Cidr() netip.Prefix {
|
func (t *tun) Networks() []netip.Prefix {
|
||||||
return t.cidr
|
return t.vpnNetworks
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t *tun) Name() string {
|
func (t *tun) Name() string {
|
||||||
|
|||||||
@ -25,7 +25,7 @@ import (
|
|||||||
type tun struct {
|
type tun struct {
|
||||||
io.ReadWriteCloser
|
io.ReadWriteCloser
|
||||||
Device string
|
Device string
|
||||||
cidr netip.Prefix
|
vpnNetworks []netip.Prefix
|
||||||
DefaultMTU int
|
DefaultMTU int
|
||||||
Routes atomic.Pointer[[]Route]
|
Routes atomic.Pointer[[]Route]
|
||||||
routeTree atomic.Pointer[bart.Table[netip.Addr]]
|
routeTree atomic.Pointer[bart.Table[netip.Addr]]
|
||||||
@ -36,44 +36,50 @@ type tun struct {
|
|||||||
out []byte
|
out []byte
|
||||||
}
|
}
|
||||||
|
|
||||||
type sockaddrCtl struct {
|
|
||||||
scLen uint8
|
|
||||||
scFamily uint8
|
|
||||||
ssSysaddr uint16
|
|
||||||
scID uint32
|
|
||||||
scUnit uint32
|
|
||||||
scReserved [5]uint32
|
|
||||||
}
|
|
||||||
|
|
||||||
type ifReq struct {
|
type ifReq struct {
|
||||||
Name [16]byte
|
Name [unix.IFNAMSIZ]byte
|
||||||
Flags uint16
|
Flags uint16
|
||||||
pad [8]byte
|
pad [8]byte
|
||||||
}
|
}
|
||||||
|
|
||||||
var sockaddrCtlSize uintptr = 32
|
|
||||||
|
|
||||||
const (
|
const (
|
||||||
_SYSPROTO_CONTROL = 2 //define SYSPROTO_CONTROL 2 /* kernel control protocol */
|
_SIOCAIFADDR_IN6 = 2155899162
|
||||||
_AF_SYS_CONTROL = 2 //#define AF_SYS_CONTROL 2 /* corresponding sub address type */
|
_UTUN_OPT_IFNAME = 2
|
||||||
_PF_SYSTEM = unix.AF_SYSTEM //#define PF_SYSTEM AF_SYSTEM
|
_IN6_IFF_NODAD = 0x0020
|
||||||
_CTLIOCGINFO = 3227799043 //#define CTLIOCGINFO _IOWR('N', 3, struct ctl_info)
|
_IN6_IFF_SECURED = 0x0400
|
||||||
utunControlName = "com.apple.net.utun_control"
|
utunControlName = "com.apple.net.utun_control"
|
||||||
)
|
)
|
||||||
|
|
||||||
type ifreqAddr struct {
|
|
||||||
Name [16]byte
|
|
||||||
Addr unix.RawSockaddrInet4
|
|
||||||
pad [8]byte
|
|
||||||
}
|
|
||||||
|
|
||||||
type ifreqMTU struct {
|
type ifreqMTU struct {
|
||||||
Name [16]byte
|
Name [16]byte
|
||||||
MTU int32
|
MTU int32
|
||||||
pad [8]byte
|
pad [8]byte
|
||||||
}
|
}
|
||||||
|
|
||||||
func newTun(c *config.C, l *logrus.Logger, cidr netip.Prefix, _ bool) (*tun, error) {
|
type addrLifetime struct {
|
||||||
|
Expire float64
|
||||||
|
Preferred float64
|
||||||
|
Vltime uint32
|
||||||
|
Pltime uint32
|
||||||
|
}
|
||||||
|
|
||||||
|
type ifreqAlias4 struct {
|
||||||
|
Name [unix.IFNAMSIZ]byte
|
||||||
|
Addr unix.RawSockaddrInet4
|
||||||
|
DstAddr unix.RawSockaddrInet4
|
||||||
|
MaskAddr unix.RawSockaddrInet4
|
||||||
|
}
|
||||||
|
|
||||||
|
type ifreqAlias6 struct {
|
||||||
|
Name [unix.IFNAMSIZ]byte
|
||||||
|
Addr unix.RawSockaddrInet6
|
||||||
|
DstAddr unix.RawSockaddrInet6
|
||||||
|
PrefixMask unix.RawSockaddrInet6
|
||||||
|
Flags uint32
|
||||||
|
Lifetime addrLifetime
|
||||||
|
}
|
||||||
|
|
||||||
|
func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, _ bool) (*tun, error) {
|
||||||
name := c.GetString("tun.dev", "")
|
name := c.GetString("tun.dev", "")
|
||||||
ifIndex := -1
|
ifIndex := -1
|
||||||
if name != "" && name != "utun" {
|
if name != "" && name != "utun" {
|
||||||
@ -86,66 +92,41 @@ func newTun(c *config.C, l *logrus.Logger, cidr netip.Prefix, _ bool) (*tun, err
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fd, err := unix.Socket(_PF_SYSTEM, unix.SOCK_DGRAM, _SYSPROTO_CONTROL)
|
fd, err := unix.Socket(unix.AF_SYSTEM, unix.SOCK_DGRAM, unix.AF_SYS_CONTROL)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("system socket: %v", err)
|
return nil, fmt.Errorf("system socket: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
var ctlInfo = &struct {
|
var ctlInfo = &unix.CtlInfo{}
|
||||||
ctlID uint32
|
copy(ctlInfo.Name[:], utunControlName)
|
||||||
ctlName [96]byte
|
|
||||||
}{}
|
|
||||||
|
|
||||||
copy(ctlInfo.ctlName[:], utunControlName)
|
err = unix.IoctlCtlInfo(fd, ctlInfo)
|
||||||
|
|
||||||
err = ioctl(uintptr(fd), uintptr(_CTLIOCGINFO), uintptr(unsafe.Pointer(ctlInfo)))
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("CTLIOCGINFO: %v", err)
|
return nil, fmt.Errorf("CTLIOCGINFO: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
sc := sockaddrCtl{
|
err = unix.Connect(fd, &unix.SockaddrCtl{
|
||||||
scLen: uint8(sockaddrCtlSize),
|
ID: ctlInfo.Id,
|
||||||
scFamily: unix.AF_SYSTEM,
|
Unit: uint32(ifIndex) + 1,
|
||||||
ssSysaddr: _AF_SYS_CONTROL,
|
})
|
||||||
scID: ctlInfo.ctlID,
|
if err != nil {
|
||||||
scUnit: uint32(ifIndex) + 1,
|
return nil, fmt.Errorf("SYS_CONNECT: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
_, _, errno := unix.RawSyscall(
|
name, err = unix.GetsockoptString(fd, unix.AF_SYS_CONTROL, _UTUN_OPT_IFNAME)
|
||||||
unix.SYS_CONNECT,
|
if err != nil {
|
||||||
uintptr(fd),
|
return nil, fmt.Errorf("failed to retrieve tun name: %w", err)
|
||||||
uintptr(unsafe.Pointer(&sc)),
|
|
||||||
sockaddrCtlSize,
|
|
||||||
)
|
|
||||||
if errno != 0 {
|
|
||||||
return nil, fmt.Errorf("SYS_CONNECT: %v", errno)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
var ifName struct {
|
err = unix.SetNonblock(fd, true)
|
||||||
name [16]byte
|
|
||||||
}
|
|
||||||
ifNameSize := uintptr(len(ifName.name))
|
|
||||||
_, _, errno = syscall.Syscall6(syscall.SYS_GETSOCKOPT, uintptr(fd),
|
|
||||||
2, // SYSPROTO_CONTROL
|
|
||||||
2, // UTUN_OPT_IFNAME
|
|
||||||
uintptr(unsafe.Pointer(&ifName)),
|
|
||||||
uintptr(unsafe.Pointer(&ifNameSize)), 0)
|
|
||||||
if errno != 0 {
|
|
||||||
return nil, fmt.Errorf("SYS_GETSOCKOPT: %v", errno)
|
|
||||||
}
|
|
||||||
name = string(ifName.name[:ifNameSize-1])
|
|
||||||
|
|
||||||
err = syscall.SetNonblock(fd, true)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("SetNonblock: %v", err)
|
return nil, fmt.Errorf("SetNonblock: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
file := os.NewFile(uintptr(fd), "")
|
|
||||||
|
|
||||||
t := &tun{
|
t := &tun{
|
||||||
ReadWriteCloser: file,
|
ReadWriteCloser: os.NewFile(uintptr(fd), ""),
|
||||||
Device: name,
|
Device: name,
|
||||||
cidr: cidr,
|
vpnNetworks: vpnNetworks,
|
||||||
DefaultMTU: c.GetInt("tun.mtu", DefaultMTU),
|
DefaultMTU: c.GetInt("tun.mtu", DefaultMTU),
|
||||||
l: l,
|
l: l,
|
||||||
}
|
}
|
||||||
@ -172,7 +153,7 @@ func (t *tun) deviceBytes() (o [16]byte) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ netip.Prefix) (*tun, error) {
|
func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ []netip.Prefix) (*tun, error) {
|
||||||
return nil, fmt.Errorf("newTunFromFd not supported in Darwin")
|
return nil, fmt.Errorf("newTunFromFd not supported in Darwin")
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -186,16 +167,6 @@ func (t *tun) Close() error {
|
|||||||
func (t *tun) Activate() error {
|
func (t *tun) Activate() error {
|
||||||
devName := t.deviceBytes()
|
devName := t.deviceBytes()
|
||||||
|
|
||||||
var addr, mask [4]byte
|
|
||||||
|
|
||||||
if !t.cidr.Addr().Is4() {
|
|
||||||
//TODO: IPV6-WORK
|
|
||||||
panic("need ipv6")
|
|
||||||
}
|
|
||||||
|
|
||||||
addr = t.cidr.Addr().As4()
|
|
||||||
copy(mask[:], prefixToMask(t.cidr))
|
|
||||||
|
|
||||||
s, err := unix.Socket(
|
s, err := unix.Socket(
|
||||||
unix.AF_INET,
|
unix.AF_INET,
|
||||||
unix.SOCK_DGRAM,
|
unix.SOCK_DGRAM,
|
||||||
@ -208,66 +179,18 @@ func (t *tun) Activate() error {
|
|||||||
|
|
||||||
fd := uintptr(s)
|
fd := uintptr(s)
|
||||||
|
|
||||||
ifra := ifreqAddr{
|
|
||||||
Name: devName,
|
|
||||||
Addr: unix.RawSockaddrInet4{
|
|
||||||
Family: unix.AF_INET,
|
|
||||||
Addr: addr,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
// Set the device ip address
|
|
||||||
if err = ioctl(fd, unix.SIOCSIFADDR, uintptr(unsafe.Pointer(&ifra))); err != nil {
|
|
||||||
return fmt.Errorf("failed to set tun address: %s", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Set the device network
|
|
||||||
ifra.Addr.Addr = mask
|
|
||||||
if err = ioctl(fd, unix.SIOCSIFNETMASK, uintptr(unsafe.Pointer(&ifra))); err != nil {
|
|
||||||
return fmt.Errorf("failed to set tun netmask: %s", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Set the device name
|
|
||||||
ifrf := ifReq{Name: devName}
|
|
||||||
if err = ioctl(fd, unix.SIOCGIFFLAGS, uintptr(unsafe.Pointer(&ifrf))); err != nil {
|
|
||||||
return fmt.Errorf("failed to set tun device name: %s", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Set the MTU on the device
|
// Set the MTU on the device
|
||||||
ifm := ifreqMTU{Name: devName, MTU: int32(t.DefaultMTU)}
|
ifm := ifreqMTU{Name: devName, MTU: int32(t.DefaultMTU)}
|
||||||
if err = ioctl(fd, unix.SIOCSIFMTU, uintptr(unsafe.Pointer(&ifm))); err != nil {
|
if err = ioctl(fd, unix.SIOCSIFMTU, uintptr(unsafe.Pointer(&ifm))); err != nil {
|
||||||
return fmt.Errorf("failed to set tun mtu: %v", err)
|
return fmt.Errorf("failed to set tun mtu: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
/*
|
// Get the device flags
|
||||||
// Set the transmit queue length
|
ifrf := ifReq{Name: devName}
|
||||||
ifrq := ifreqQLEN{Name: devName, Value: int32(t.TXQueueLen)}
|
if err = ioctl(fd, unix.SIOCGIFFLAGS, uintptr(unsafe.Pointer(&ifrf))); err != nil {
|
||||||
if err = ioctl(fd, unix.SIOCSIFTXQLEN, uintptr(unsafe.Pointer(&ifrq))); err != nil {
|
return fmt.Errorf("failed to get tun flags: %s", err)
|
||||||
// If we can't set the queue length nebula will still work but it may lead to packet loss
|
|
||||||
l.WithError(err).Error("Failed to set tun tx queue length")
|
|
||||||
}
|
|
||||||
*/
|
|
||||||
|
|
||||||
// Bring up the interface
|
|
||||||
ifrf.Flags = ifrf.Flags | unix.IFF_UP
|
|
||||||
if err = ioctl(fd, unix.SIOCSIFFLAGS, uintptr(unsafe.Pointer(&ifrf))); err != nil {
|
|
||||||
return fmt.Errorf("failed to bring the tun device up: %s", err)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
routeSock, err := unix.Socket(unix.AF_ROUTE, unix.SOCK_RAW, unix.AF_UNSPEC)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("unable to create AF_ROUTE socket: %v", err)
|
|
||||||
}
|
|
||||||
defer func() {
|
|
||||||
unix.Shutdown(routeSock, unix.SHUT_RDWR)
|
|
||||||
err := unix.Close(routeSock)
|
|
||||||
if err != nil {
|
|
||||||
t.l.WithError(err).Error("failed to close AF_ROUTE socket")
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
|
|
||||||
routeAddr := &netroute.Inet4Addr{}
|
|
||||||
maskAddr := &netroute.Inet4Addr{}
|
|
||||||
linkAddr, err := getLinkAddr(t.Device)
|
linkAddr, err := getLinkAddr(t.Device)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
@ -277,15 +200,19 @@ func (t *tun) Activate() error {
|
|||||||
}
|
}
|
||||||
t.linkAddr = linkAddr
|
t.linkAddr = linkAddr
|
||||||
|
|
||||||
copy(routeAddr.IP[:], addr[:])
|
for _, network := range t.vpnNetworks {
|
||||||
copy(maskAddr.IP[:], mask[:])
|
if network.Addr().Is4() {
|
||||||
err = addRoute(routeSock, routeAddr, maskAddr, linkAddr)
|
err = t.activate4(network)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if errors.Is(err, unix.EEXIST) {
|
|
||||||
err = fmt.Errorf("unable to add tun route, identical route already exists: %s", t.cidr)
|
|
||||||
}
|
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
} else {
|
||||||
|
err = t.activate6(network)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// Run the interface
|
// Run the interface
|
||||||
ifrf.Flags = ifrf.Flags | unix.IFF_UP | unix.IFF_RUNNING
|
ifrf.Flags = ifrf.Flags | unix.IFF_UP | unix.IFF_RUNNING
|
||||||
@ -297,8 +224,89 @@ func (t *tun) Activate() error {
|
|||||||
return t.addRoutes(false)
|
return t.addRoutes(false)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (t *tun) activate4(network netip.Prefix) error {
|
||||||
|
s, err := unix.Socket(
|
||||||
|
unix.AF_INET,
|
||||||
|
unix.SOCK_DGRAM,
|
||||||
|
unix.IPPROTO_IP,
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
defer unix.Close(s)
|
||||||
|
|
||||||
|
ifr := ifreqAlias4{
|
||||||
|
Name: t.deviceBytes(),
|
||||||
|
Addr: unix.RawSockaddrInet4{
|
||||||
|
Len: unix.SizeofSockaddrInet4,
|
||||||
|
Family: unix.AF_INET,
|
||||||
|
Addr: network.Addr().As4(),
|
||||||
|
},
|
||||||
|
DstAddr: unix.RawSockaddrInet4{
|
||||||
|
Len: unix.SizeofSockaddrInet4,
|
||||||
|
Family: unix.AF_INET,
|
||||||
|
Addr: network.Addr().As4(),
|
||||||
|
},
|
||||||
|
MaskAddr: unix.RawSockaddrInet4{
|
||||||
|
Len: unix.SizeofSockaddrInet4,
|
||||||
|
Family: unix.AF_INET,
|
||||||
|
Addr: prefixToMask(network).As4(),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := ioctl(uintptr(s), unix.SIOCAIFADDR, uintptr(unsafe.Pointer(&ifr))); err != nil {
|
||||||
|
return fmt.Errorf("failed to set tun v4 address: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
err = addRoute(network, t.linkAddr)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *tun) activate6(network netip.Prefix) error {
|
||||||
|
s, err := unix.Socket(
|
||||||
|
unix.AF_INET6,
|
||||||
|
unix.SOCK_DGRAM,
|
||||||
|
unix.IPPROTO_IP,
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
defer unix.Close(s)
|
||||||
|
|
||||||
|
ifr := ifreqAlias6{
|
||||||
|
Name: t.deviceBytes(),
|
||||||
|
Addr: unix.RawSockaddrInet6{
|
||||||
|
Len: unix.SizeofSockaddrInet6,
|
||||||
|
Family: unix.AF_INET6,
|
||||||
|
Addr: network.Addr().As16(),
|
||||||
|
},
|
||||||
|
PrefixMask: unix.RawSockaddrInet6{
|
||||||
|
Len: unix.SizeofSockaddrInet6,
|
||||||
|
Family: unix.AF_INET6,
|
||||||
|
Addr: prefixToMask(network).As16(),
|
||||||
|
},
|
||||||
|
Lifetime: addrLifetime{
|
||||||
|
// never expires
|
||||||
|
Vltime: 0xffffffff,
|
||||||
|
Pltime: 0xffffffff,
|
||||||
|
},
|
||||||
|
//TODO: should we disable DAD (duplicate address detection) and mark this as a secured address?
|
||||||
|
Flags: _IN6_IFF_NODAD,
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := ioctl(uintptr(s), _SIOCAIFADDR_IN6, uintptr(unsafe.Pointer(&ifr))); err != nil {
|
||||||
|
return fmt.Errorf("failed to set tun address: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
func (t *tun) reload(c *config.C, initial bool) error {
|
func (t *tun) reload(c *config.C, initial bool) error {
|
||||||
change, routes, err := getAllRoutesFromConfig(c, t.cidr, initial)
|
change, routes, err := getAllRoutesFromConfig(c, t.vpnNetworks, initial)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@ -371,38 +379,15 @@ func getLinkAddr(name string) (*netroute.LinkAddr, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (t *tun) addRoutes(logErrors bool) error {
|
func (t *tun) addRoutes(logErrors bool) error {
|
||||||
routeSock, err := unix.Socket(unix.AF_ROUTE, unix.SOCK_RAW, unix.AF_UNSPEC)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("unable to create AF_ROUTE socket: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
defer func() {
|
|
||||||
unix.Shutdown(routeSock, unix.SHUT_RDWR)
|
|
||||||
err := unix.Close(routeSock)
|
|
||||||
if err != nil {
|
|
||||||
t.l.WithError(err).Error("failed to close AF_ROUTE socket")
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
|
|
||||||
routeAddr := &netroute.Inet4Addr{}
|
|
||||||
maskAddr := &netroute.Inet4Addr{}
|
|
||||||
routes := *t.Routes.Load()
|
routes := *t.Routes.Load()
|
||||||
|
|
||||||
for _, r := range routes {
|
for _, r := range routes {
|
||||||
if !r.Via.IsValid() || !r.Install {
|
if !r.Via.IsValid() || !r.Install {
|
||||||
// We don't allow route MTUs so only install routes with a via
|
// We don't allow route MTUs so only install routes with a via
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
if !r.Cidr.Addr().Is4() {
|
err := addRoute(r.Cidr, t.linkAddr)
|
||||||
//TODO: implement ipv6
|
|
||||||
panic("Cant handle ipv6 routes yet")
|
|
||||||
}
|
|
||||||
|
|
||||||
routeAddr.IP = r.Cidr.Addr().As4()
|
|
||||||
//TODO: we could avoid the copy
|
|
||||||
copy(maskAddr.IP[:], prefixToMask(r.Cidr))
|
|
||||||
|
|
||||||
err := addRoute(routeSock, routeAddr, maskAddr, t.linkAddr)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if errors.Is(err, unix.EEXIST) {
|
if errors.Is(err, unix.EEXIST) {
|
||||||
t.l.WithField("route", r.Cidr).
|
t.l.WithField("route", r.Cidr).
|
||||||
@ -424,36 +409,12 @@ func (t *tun) addRoutes(logErrors bool) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (t *tun) removeRoutes(routes []Route) error {
|
func (t *tun) removeRoutes(routes []Route) error {
|
||||||
routeSock, err := unix.Socket(unix.AF_ROUTE, unix.SOCK_RAW, unix.AF_UNSPEC)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("unable to create AF_ROUTE socket: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
defer func() {
|
|
||||||
unix.Shutdown(routeSock, unix.SHUT_RDWR)
|
|
||||||
err := unix.Close(routeSock)
|
|
||||||
if err != nil {
|
|
||||||
t.l.WithError(err).Error("failed to close AF_ROUTE socket")
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
|
|
||||||
routeAddr := &netroute.Inet4Addr{}
|
|
||||||
maskAddr := &netroute.Inet4Addr{}
|
|
||||||
|
|
||||||
for _, r := range routes {
|
for _, r := range routes {
|
||||||
if !r.Install {
|
if !r.Install {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
if r.Cidr.Addr().Is6() {
|
err := delRoute(r.Cidr, t.linkAddr)
|
||||||
//TODO: implement ipv6
|
|
||||||
panic("Cant handle ipv6 routes yet")
|
|
||||||
}
|
|
||||||
|
|
||||||
routeAddr.IP = r.Cidr.Addr().As4()
|
|
||||||
copy(maskAddr.IP[:], prefixToMask(r.Cidr))
|
|
||||||
|
|
||||||
err := delRoute(routeSock, routeAddr, maskAddr, t.linkAddr)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.l.WithError(err).WithField("route", r).Error("Failed to remove route")
|
t.l.WithError(err).WithField("route", r).Error("Failed to remove route")
|
||||||
} else {
|
} else {
|
||||||
@ -463,23 +424,39 @@ func (t *tun) removeRoutes(routes []Route) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func addRoute(sock int, addr, mask *netroute.Inet4Addr, link *netroute.LinkAddr) error {
|
func addRoute(prefix netip.Prefix, gateway netroute.Addr) error {
|
||||||
r := netroute.RouteMessage{
|
sock, err := unix.Socket(unix.AF_ROUTE, unix.SOCK_RAW, unix.AF_UNSPEC)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("unable to create AF_ROUTE socket: %v", err)
|
||||||
|
}
|
||||||
|
defer unix.Close(sock)
|
||||||
|
|
||||||
|
route := &netroute.RouteMessage{
|
||||||
Version: unix.RTM_VERSION,
|
Version: unix.RTM_VERSION,
|
||||||
Type: unix.RTM_ADD,
|
Type: unix.RTM_ADD,
|
||||||
Flags: unix.RTF_UP,
|
Flags: unix.RTF_UP,
|
||||||
Seq: 1,
|
Seq: 1,
|
||||||
Addrs: []netroute.Addr{
|
|
||||||
unix.RTAX_DST: addr,
|
|
||||||
unix.RTAX_GATEWAY: link,
|
|
||||||
unix.RTAX_NETMASK: mask,
|
|
||||||
},
|
|
||||||
}
|
}
|
||||||
|
|
||||||
data, err := r.Marshal()
|
if prefix.Addr().Is4() {
|
||||||
|
route.Addrs = []netroute.Addr{
|
||||||
|
unix.RTAX_DST: &netroute.Inet4Addr{IP: prefix.Masked().Addr().As4()},
|
||||||
|
unix.RTAX_NETMASK: &netroute.Inet4Addr{IP: prefixToMask(prefix).As4()},
|
||||||
|
unix.RTAX_GATEWAY: gateway,
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
route.Addrs = []netroute.Addr{
|
||||||
|
unix.RTAX_DST: &netroute.Inet6Addr{IP: prefix.Masked().Addr().As16()},
|
||||||
|
unix.RTAX_NETMASK: &netroute.Inet6Addr{IP: prefixToMask(prefix).As16()},
|
||||||
|
unix.RTAX_GATEWAY: gateway,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
data, err := route.Marshal()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed to create route.RouteMessage: %w", err)
|
return fmt.Errorf("failed to create route.RouteMessage: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
_, err = unix.Write(sock, data[:])
|
_, err = unix.Write(sock, data[:])
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed to write route.RouteMessage to socket: %w", err)
|
return fmt.Errorf("failed to write route.RouteMessage to socket: %w", err)
|
||||||
@ -488,19 +465,34 @@ func addRoute(sock int, addr, mask *netroute.Inet4Addr, link *netroute.LinkAddr)
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func delRoute(sock int, addr, mask *netroute.Inet4Addr, link *netroute.LinkAddr) error {
|
func delRoute(prefix netip.Prefix, gateway netroute.Addr) error {
|
||||||
r := netroute.RouteMessage{
|
sock, err := unix.Socket(unix.AF_ROUTE, unix.SOCK_RAW, unix.AF_UNSPEC)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("unable to create AF_ROUTE socket: %v", err)
|
||||||
|
}
|
||||||
|
defer unix.Close(sock)
|
||||||
|
|
||||||
|
route := netroute.RouteMessage{
|
||||||
Version: unix.RTM_VERSION,
|
Version: unix.RTM_VERSION,
|
||||||
Type: unix.RTM_DELETE,
|
Type: unix.RTM_DELETE,
|
||||||
Seq: 1,
|
Seq: 1,
|
||||||
Addrs: []netroute.Addr{
|
|
||||||
unix.RTAX_DST: addr,
|
|
||||||
unix.RTAX_GATEWAY: link,
|
|
||||||
unix.RTAX_NETMASK: mask,
|
|
||||||
},
|
|
||||||
}
|
}
|
||||||
|
|
||||||
data, err := r.Marshal()
|
if prefix.Addr().Is4() {
|
||||||
|
route.Addrs = []netroute.Addr{
|
||||||
|
unix.RTAX_DST: &netroute.Inet4Addr{IP: prefix.Masked().Addr().As4()},
|
||||||
|
unix.RTAX_NETMASK: &netroute.Inet4Addr{IP: prefixToMask(prefix).As4()},
|
||||||
|
unix.RTAX_GATEWAY: gateway,
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
route.Addrs = []netroute.Addr{
|
||||||
|
unix.RTAX_DST: &netroute.Inet6Addr{IP: prefix.Masked().Addr().As16()},
|
||||||
|
unix.RTAX_NETMASK: &netroute.Inet6Addr{IP: prefixToMask(prefix).As16()},
|
||||||
|
unix.RTAX_GATEWAY: gateway,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
data, err := route.Marshal()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed to create route.RouteMessage: %w", err)
|
return fmt.Errorf("failed to create route.RouteMessage: %w", err)
|
||||||
}
|
}
|
||||||
@ -513,7 +505,6 @@ func delRoute(sock int, addr, mask *netroute.Inet4Addr, link *netroute.LinkAddr)
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (t *tun) Read(to []byte) (int, error) {
|
func (t *tun) Read(to []byte) (int, error) {
|
||||||
|
|
||||||
buf := make([]byte, len(to)+4)
|
buf := make([]byte, len(to)+4)
|
||||||
|
|
||||||
n, err := t.ReadWriteCloser.Read(buf)
|
n, err := t.ReadWriteCloser.Read(buf)
|
||||||
@ -551,8 +542,8 @@ func (t *tun) Write(from []byte) (int, error) {
|
|||||||
return n - 4, err
|
return n - 4, err
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t *tun) Cidr() netip.Prefix {
|
func (t *tun) Networks() []netip.Prefix {
|
||||||
return t.cidr
|
return t.vpnNetworks
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t *tun) Name() string {
|
func (t *tun) Name() string {
|
||||||
@ -563,10 +554,12 @@ func (t *tun) NewMultiQueueReader() (io.ReadWriteCloser, error) {
|
|||||||
return nil, fmt.Errorf("TODO: multiqueue not implemented for darwin")
|
return nil, fmt.Errorf("TODO: multiqueue not implemented for darwin")
|
||||||
}
|
}
|
||||||
|
|
||||||
func prefixToMask(prefix netip.Prefix) []byte {
|
func prefixToMask(prefix netip.Prefix) netip.Addr {
|
||||||
pLen := 128
|
pLen := 128
|
||||||
if prefix.Addr().Is4() {
|
if prefix.Addr().Is4() {
|
||||||
pLen = 32
|
pLen = 32
|
||||||
}
|
}
|
||||||
return net.CIDRMask(prefix.Bits(), pLen)
|
|
||||||
|
addr, _ := netip.AddrFromSlice(net.CIDRMask(prefix.Bits(), pLen))
|
||||||
|
return addr
|
||||||
}
|
}
|
||||||
|
|||||||
@ -13,7 +13,7 @@ import (
|
|||||||
|
|
||||||
type disabledTun struct {
|
type disabledTun struct {
|
||||||
read chan []byte
|
read chan []byte
|
||||||
cidr netip.Prefix
|
vpnNetworks []netip.Prefix
|
||||||
|
|
||||||
// Track these metrics since we don't have the tun device to do it for us
|
// Track these metrics since we don't have the tun device to do it for us
|
||||||
tx metrics.Counter
|
tx metrics.Counter
|
||||||
@ -21,9 +21,9 @@ type disabledTun struct {
|
|||||||
l *logrus.Logger
|
l *logrus.Logger
|
||||||
}
|
}
|
||||||
|
|
||||||
func newDisabledTun(cidr netip.Prefix, queueLen int, metricsEnabled bool, l *logrus.Logger) *disabledTun {
|
func newDisabledTun(vpnNetworks []netip.Prefix, queueLen int, metricsEnabled bool, l *logrus.Logger) *disabledTun {
|
||||||
tun := &disabledTun{
|
tun := &disabledTun{
|
||||||
cidr: cidr,
|
vpnNetworks: vpnNetworks,
|
||||||
read: make(chan []byte, queueLen),
|
read: make(chan []byte, queueLen),
|
||||||
l: l,
|
l: l,
|
||||||
}
|
}
|
||||||
@ -47,8 +47,8 @@ func (*disabledTun) RouteFor(addr netip.Addr) netip.Addr {
|
|||||||
return netip.Addr{}
|
return netip.Addr{}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t *disabledTun) Cidr() netip.Prefix {
|
func (t *disabledTun) Networks() []netip.Prefix {
|
||||||
return t.cidr
|
return t.vpnNetworks
|
||||||
}
|
}
|
||||||
|
|
||||||
func (*disabledTun) Name() string {
|
func (*disabledTun) Name() string {
|
||||||
|
|||||||
@ -47,7 +47,7 @@ type ifreqDestroy struct {
|
|||||||
|
|
||||||
type tun struct {
|
type tun struct {
|
||||||
Device string
|
Device string
|
||||||
cidr netip.Prefix
|
vpnNetworks []netip.Prefix
|
||||||
MTU int
|
MTU int
|
||||||
Routes atomic.Pointer[[]Route]
|
Routes atomic.Pointer[[]Route]
|
||||||
routeTree atomic.Pointer[bart.Table[netip.Addr]]
|
routeTree atomic.Pointer[bart.Table[netip.Addr]]
|
||||||
@ -78,11 +78,11 @@ func (t *tun) Close() error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ netip.Prefix) (*tun, error) {
|
func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ []netip.Prefix) (*tun, error) {
|
||||||
return nil, fmt.Errorf("newTunFromFd not supported in FreeBSD")
|
return nil, fmt.Errorf("newTunFromFd not supported in FreeBSD")
|
||||||
}
|
}
|
||||||
|
|
||||||
func newTun(c *config.C, l *logrus.Logger, cidr netip.Prefix, _ bool) (*tun, error) {
|
func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, _ bool) (*tun, error) {
|
||||||
// Try to open existing tun device
|
// Try to open existing tun device
|
||||||
var file *os.File
|
var file *os.File
|
||||||
var err error
|
var err error
|
||||||
@ -150,7 +150,7 @@ func newTun(c *config.C, l *logrus.Logger, cidr netip.Prefix, _ bool) (*tun, err
|
|||||||
t := &tun{
|
t := &tun{
|
||||||
ReadWriteCloser: file,
|
ReadWriteCloser: file,
|
||||||
Device: deviceName,
|
Device: deviceName,
|
||||||
cidr: cidr,
|
vpnNetworks: vpnNetworks,
|
||||||
MTU: c.GetInt("tun.mtu", DefaultMTU),
|
MTU: c.GetInt("tun.mtu", DefaultMTU),
|
||||||
l: l,
|
l: l,
|
||||||
}
|
}
|
||||||
@ -170,16 +170,16 @@ func newTun(c *config.C, l *logrus.Logger, cidr netip.Prefix, _ bool) (*tun, err
|
|||||||
return t, nil
|
return t, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t *tun) Activate() error {
|
func (t *tun) addIp(cidr netip.Prefix) error {
|
||||||
var err error
|
var err error
|
||||||
// TODO use syscalls instead of exec.Command
|
// TODO use syscalls instead of exec.Command
|
||||||
cmd := exec.Command("/sbin/ifconfig", t.Device, t.cidr.String(), t.cidr.Addr().String())
|
cmd := exec.Command("/sbin/ifconfig", t.Device, cidr.String(), cidr.Addr().String())
|
||||||
t.l.Debug("command: ", cmd.String())
|
t.l.Debug("command: ", cmd.String())
|
||||||
if err = cmd.Run(); err != nil {
|
if err = cmd.Run(); err != nil {
|
||||||
return fmt.Errorf("failed to run 'ifconfig': %s", err)
|
return fmt.Errorf("failed to run 'ifconfig': %s", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
cmd = exec.Command("/sbin/route", "-n", "add", "-net", t.cidr.String(), "-interface", t.Device)
|
cmd = exec.Command("/sbin/route", "-n", "add", "-net", cidr.String(), "-interface", t.Device)
|
||||||
t.l.Debug("command: ", cmd.String())
|
t.l.Debug("command: ", cmd.String())
|
||||||
if err = cmd.Run(); err != nil {
|
if err = cmd.Run(); err != nil {
|
||||||
return fmt.Errorf("failed to run 'route add': %s", err)
|
return fmt.Errorf("failed to run 'route add': %s", err)
|
||||||
@ -195,8 +195,18 @@ func (t *tun) Activate() error {
|
|||||||
return t.addRoutes(false)
|
return t.addRoutes(false)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (t *tun) Activate() error {
|
||||||
|
for i := range t.vpnNetworks {
|
||||||
|
err := t.addIp(t.vpnNetworks[i])
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
func (t *tun) reload(c *config.C, initial bool) error {
|
func (t *tun) reload(c *config.C, initial bool) error {
|
||||||
change, routes, err := getAllRoutesFromConfig(c, t.cidr, initial)
|
change, routes, err := getAllRoutesFromConfig(c, t.vpnNetworks, initial)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@ -237,8 +247,8 @@ func (t *tun) RouteFor(ip netip.Addr) netip.Addr {
|
|||||||
return r
|
return r
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t *tun) Cidr() netip.Prefix {
|
func (t *tun) Networks() []netip.Prefix {
|
||||||
return t.cidr
|
return t.vpnNetworks
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t *tun) Name() string {
|
func (t *tun) Name() string {
|
||||||
|
|||||||
@ -21,20 +21,20 @@ import (
|
|||||||
|
|
||||||
type tun struct {
|
type tun struct {
|
||||||
io.ReadWriteCloser
|
io.ReadWriteCloser
|
||||||
cidr netip.Prefix
|
vpnNetworks []netip.Prefix
|
||||||
Routes atomic.Pointer[[]Route]
|
Routes atomic.Pointer[[]Route]
|
||||||
routeTree atomic.Pointer[bart.Table[netip.Addr]]
|
routeTree atomic.Pointer[bart.Table[netip.Addr]]
|
||||||
l *logrus.Logger
|
l *logrus.Logger
|
||||||
}
|
}
|
||||||
|
|
||||||
func newTun(_ *config.C, _ *logrus.Logger, _ netip.Prefix, _ bool) (*tun, error) {
|
func newTun(_ *config.C, _ *logrus.Logger, _ []netip.Prefix, _ bool) (*tun, error) {
|
||||||
return nil, fmt.Errorf("newTun not supported in iOS")
|
return nil, fmt.Errorf("newTun not supported in iOS")
|
||||||
}
|
}
|
||||||
|
|
||||||
func newTunFromFd(c *config.C, l *logrus.Logger, deviceFd int, cidr netip.Prefix) (*tun, error) {
|
func newTunFromFd(c *config.C, l *logrus.Logger, deviceFd int, vpnNetworks []netip.Prefix) (*tun, error) {
|
||||||
file := os.NewFile(uintptr(deviceFd), "/dev/tun")
|
file := os.NewFile(uintptr(deviceFd), "/dev/tun")
|
||||||
t := &tun{
|
t := &tun{
|
||||||
cidr: cidr,
|
vpnNetworks: vpnNetworks,
|
||||||
ReadWriteCloser: &tunReadCloser{f: file},
|
ReadWriteCloser: &tunReadCloser{f: file},
|
||||||
l: l,
|
l: l,
|
||||||
}
|
}
|
||||||
@ -59,7 +59,7 @@ func (t *tun) Activate() error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (t *tun) reload(c *config.C, initial bool) error {
|
func (t *tun) reload(c *config.C, initial bool) error {
|
||||||
change, routes, err := getAllRoutesFromConfig(c, t.cidr, initial)
|
change, routes, err := getAllRoutesFromConfig(c, t.vpnNetworks, initial)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@ -142,8 +142,8 @@ func (tr *tunReadCloser) Close() error {
|
|||||||
return tr.f.Close()
|
return tr.f.Close()
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t *tun) Cidr() netip.Prefix {
|
func (t *tun) Networks() []netip.Prefix {
|
||||||
return t.cidr
|
return t.vpnNetworks
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t *tun) Name() string {
|
func (t *tun) Name() string {
|
||||||
|
|||||||
@ -25,7 +25,7 @@ type tun struct {
|
|||||||
io.ReadWriteCloser
|
io.ReadWriteCloser
|
||||||
fd int
|
fd int
|
||||||
Device string
|
Device string
|
||||||
cidr netip.Prefix
|
vpnNetworks []netip.Prefix
|
||||||
MaxMTU int
|
MaxMTU int
|
||||||
DefaultMTU int
|
DefaultMTU int
|
||||||
TXQueueLen int
|
TXQueueLen int
|
||||||
@ -40,18 +40,16 @@ type tun struct {
|
|||||||
l *logrus.Logger
|
l *logrus.Logger
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (t *tun) Networks() []netip.Prefix {
|
||||||
|
return t.vpnNetworks
|
||||||
|
}
|
||||||
|
|
||||||
type ifReq struct {
|
type ifReq struct {
|
||||||
Name [16]byte
|
Name [16]byte
|
||||||
Flags uint16
|
Flags uint16
|
||||||
pad [8]byte
|
pad [8]byte
|
||||||
}
|
}
|
||||||
|
|
||||||
type ifreqAddr struct {
|
|
||||||
Name [16]byte
|
|
||||||
Addr unix.RawSockaddrInet4
|
|
||||||
pad [8]byte
|
|
||||||
}
|
|
||||||
|
|
||||||
type ifreqMTU struct {
|
type ifreqMTU struct {
|
||||||
Name [16]byte
|
Name [16]byte
|
||||||
MTU int32
|
MTU int32
|
||||||
@ -64,10 +62,10 @@ type ifreqQLEN struct {
|
|||||||
pad [8]byte
|
pad [8]byte
|
||||||
}
|
}
|
||||||
|
|
||||||
func newTunFromFd(c *config.C, l *logrus.Logger, deviceFd int, cidr netip.Prefix) (*tun, error) {
|
func newTunFromFd(c *config.C, l *logrus.Logger, deviceFd int, vpnNetworks []netip.Prefix) (*tun, error) {
|
||||||
file := os.NewFile(uintptr(deviceFd), "/dev/net/tun")
|
file := os.NewFile(uintptr(deviceFd), "/dev/net/tun")
|
||||||
|
|
||||||
t, err := newTunGeneric(c, l, file, cidr)
|
t, err := newTunGeneric(c, l, file, vpnNetworks)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@ -77,7 +75,7 @@ func newTunFromFd(c *config.C, l *logrus.Logger, deviceFd int, cidr netip.Prefix
|
|||||||
return t, nil
|
return t, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func newTun(c *config.C, l *logrus.Logger, cidr netip.Prefix, multiqueue bool) (*tun, error) {
|
func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, multiqueue bool) (*tun, error) {
|
||||||
fd, err := unix.Open("/dev/net/tun", os.O_RDWR, 0)
|
fd, err := unix.Open("/dev/net/tun", os.O_RDWR, 0)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
// If /dev/net/tun doesn't exist, try to create it (will happen in docker)
|
// If /dev/net/tun doesn't exist, try to create it (will happen in docker)
|
||||||
@ -112,7 +110,7 @@ func newTun(c *config.C, l *logrus.Logger, cidr netip.Prefix, multiqueue bool) (
|
|||||||
name := strings.Trim(string(req.Name[:]), "\x00")
|
name := strings.Trim(string(req.Name[:]), "\x00")
|
||||||
|
|
||||||
file := os.NewFile(uintptr(fd), "/dev/net/tun")
|
file := os.NewFile(uintptr(fd), "/dev/net/tun")
|
||||||
t, err := newTunGeneric(c, l, file, cidr)
|
t, err := newTunGeneric(c, l, file, vpnNetworks)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@ -122,11 +120,11 @@ func newTun(c *config.C, l *logrus.Logger, cidr netip.Prefix, multiqueue bool) (
|
|||||||
return t, nil
|
return t, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func newTunGeneric(c *config.C, l *logrus.Logger, file *os.File, cidr netip.Prefix) (*tun, error) {
|
func newTunGeneric(c *config.C, l *logrus.Logger, file *os.File, vpnNetworks []netip.Prefix) (*tun, error) {
|
||||||
t := &tun{
|
t := &tun{
|
||||||
ReadWriteCloser: file,
|
ReadWriteCloser: file,
|
||||||
fd: int(file.Fd()),
|
fd: int(file.Fd()),
|
||||||
cidr: cidr,
|
vpnNetworks: vpnNetworks,
|
||||||
TXQueueLen: c.GetInt("tun.tx_queue", 500),
|
TXQueueLen: c.GetInt("tun.tx_queue", 500),
|
||||||
useSystemRoutes: c.GetBool("tun.use_system_route_table", false),
|
useSystemRoutes: c.GetBool("tun.use_system_route_table", false),
|
||||||
l: l,
|
l: l,
|
||||||
@ -148,7 +146,7 @@ func newTunGeneric(c *config.C, l *logrus.Logger, file *os.File, cidr netip.Pref
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (t *tun) reload(c *config.C, initial bool) error {
|
func (t *tun) reload(c *config.C, initial bool) error {
|
||||||
routeChange, routes, err := getAllRoutesFromConfig(c, t.cidr, initial)
|
routeChange, routes, err := getAllRoutesFromConfig(c, t.vpnNetworks, initial)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@ -190,13 +188,15 @@ func (t *tun) reload(c *config.C, initial bool) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if oldDefaultMTU != newDefaultMTU {
|
if oldDefaultMTU != newDefaultMTU {
|
||||||
err := t.setDefaultRoute()
|
for i := range t.vpnNetworks {
|
||||||
|
err := t.setDefaultRoute(t.vpnNetworks[i])
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.l.Warn(err)
|
t.l.Warn(err)
|
||||||
} else {
|
} else {
|
||||||
t.l.Infof("Set default MTU to %v was %v", t.DefaultMTU, oldDefaultMTU)
|
t.l.Infof("Set default MTU to %v was %v", t.DefaultMTU, oldDefaultMTU)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// Remove first, if the system removes a wanted route hopefully it will be re-added next
|
// Remove first, if the system removes a wanted route hopefully it will be re-added next
|
||||||
t.removeRoutes(findRemovedRoutes(routes, *oldRoutes))
|
t.removeRoutes(findRemovedRoutes(routes, *oldRoutes))
|
||||||
@ -237,10 +237,10 @@ func (t *tun) RouteFor(ip netip.Addr) netip.Addr {
|
|||||||
|
|
||||||
func (t *tun) Write(b []byte) (int, error) {
|
func (t *tun) Write(b []byte) (int, error) {
|
||||||
var nn int
|
var nn int
|
||||||
max := len(b)
|
maximum := len(b)
|
||||||
|
|
||||||
for {
|
for {
|
||||||
n, err := unix.Write(t.fd, b[nn:max])
|
n, err := unix.Write(t.fd, b[nn:maximum])
|
||||||
if n > 0 {
|
if n > 0 {
|
||||||
nn += n
|
nn += n
|
||||||
}
|
}
|
||||||
@ -265,6 +265,58 @@ func (t *tun) deviceBytes() (o [16]byte) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func hasNetlinkAddr(al []*netlink.Addr, x netlink.Addr) bool {
|
||||||
|
for i := range al {
|
||||||
|
if al[i].Equal(x) {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
// addIPs uses netlink to add all addresses that don't exist, then it removes ones that should not be there
|
||||||
|
func (t *tun) addIPs(link netlink.Link) error {
|
||||||
|
newAddrs := make([]*netlink.Addr, len(t.vpnNetworks))
|
||||||
|
for i := range t.vpnNetworks {
|
||||||
|
newAddrs[i] = &netlink.Addr{
|
||||||
|
IPNet: &net.IPNet{
|
||||||
|
IP: t.vpnNetworks[i].Addr().AsSlice(),
|
||||||
|
Mask: net.CIDRMask(t.vpnNetworks[i].Bits(), t.vpnNetworks[i].Addr().BitLen()),
|
||||||
|
},
|
||||||
|
Label: t.vpnNetworks[i].Addr().Zone(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
//add all new addresses
|
||||||
|
for i := range newAddrs {
|
||||||
|
//todo do we want to stack errors and try as many ops as possible?
|
||||||
|
//AddrReplace still adds new IPs, but if their properties change it will change them as well
|
||||||
|
if err := netlink.AddrReplace(link, newAddrs[i]); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
//iterate over remainder, remove whoever shouldn't be there
|
||||||
|
al, err := netlink.AddrList(link, netlink.FAMILY_ALL)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to get tun address list: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
for i := range al {
|
||||||
|
if hasNetlinkAddr(newAddrs, al[i]) {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
err = netlink.AddrDel(link, &al[i])
|
||||||
|
if err != nil {
|
||||||
|
t.l.WithError(err).Error("failed to remove address from tun address list")
|
||||||
|
} else {
|
||||||
|
t.l.WithField("removed", al[i].String()).Info("removed address not listed in cert(s)")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
func (t *tun) Activate() error {
|
func (t *tun) Activate() error {
|
||||||
devName := t.deviceBytes()
|
devName := t.deviceBytes()
|
||||||
|
|
||||||
@ -272,15 +324,8 @@ func (t *tun) Activate() error {
|
|||||||
t.watchRoutes()
|
t.watchRoutes()
|
||||||
}
|
}
|
||||||
|
|
||||||
var addr, mask [4]byte
|
|
||||||
|
|
||||||
//TODO: IPV6-WORK
|
|
||||||
addr = t.cidr.Addr().As4()
|
|
||||||
tmask := net.CIDRMask(t.cidr.Bits(), 32)
|
|
||||||
copy(mask[:], tmask)
|
|
||||||
|
|
||||||
s, err := unix.Socket(
|
s, err := unix.Socket(
|
||||||
unix.AF_INET,
|
unix.AF_INET, //because everything we use t.ioctlFd for is address family independent, this is fine
|
||||||
unix.SOCK_DGRAM,
|
unix.SOCK_DGRAM,
|
||||||
unix.IPPROTO_IP,
|
unix.IPPROTO_IP,
|
||||||
)
|
)
|
||||||
@ -289,31 +334,19 @@ func (t *tun) Activate() error {
|
|||||||
}
|
}
|
||||||
t.ioctlFd = uintptr(s)
|
t.ioctlFd = uintptr(s)
|
||||||
|
|
||||||
ifra := ifreqAddr{
|
|
||||||
Name: devName,
|
|
||||||
Addr: unix.RawSockaddrInet4{
|
|
||||||
Family: unix.AF_INET,
|
|
||||||
Addr: addr,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
// Set the device ip address
|
|
||||||
if err = ioctl(t.ioctlFd, unix.SIOCSIFADDR, uintptr(unsafe.Pointer(&ifra))); err != nil {
|
|
||||||
return fmt.Errorf("failed to set tun address: %s", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Set the device network
|
|
||||||
ifra.Addr.Addr = mask
|
|
||||||
if err = ioctl(t.ioctlFd, unix.SIOCSIFNETMASK, uintptr(unsafe.Pointer(&ifra))); err != nil {
|
|
||||||
return fmt.Errorf("failed to set tun netmask: %s", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Set the device name
|
// Set the device name
|
||||||
ifrf := ifReq{Name: devName}
|
ifrf := ifReq{Name: devName}
|
||||||
if err = ioctl(t.ioctlFd, unix.SIOCGIFFLAGS, uintptr(unsafe.Pointer(&ifrf))); err != nil {
|
if err = ioctl(t.ioctlFd, unix.SIOCGIFFLAGS, uintptr(unsafe.Pointer(&ifrf))); err != nil {
|
||||||
return fmt.Errorf("failed to set tun device name: %s", err)
|
return fmt.Errorf("failed to set tun device name: %s", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
link, err := netlink.LinkByName(t.Device)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to get tun device link: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
t.deviceIndex = link.Attrs().Index
|
||||||
|
|
||||||
// Setup our default MTU
|
// Setup our default MTU
|
||||||
t.setMTU()
|
t.setMTU()
|
||||||
|
|
||||||
@ -324,33 +357,36 @@ func (t *tun) Activate() error {
|
|||||||
t.l.WithError(err).Error("Failed to set tun tx queue length")
|
t.l.WithError(err).Error("Failed to set tun tx queue length")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if err = t.addIPs(link); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
// Bring up the interface
|
// Bring up the interface
|
||||||
ifrf.Flags = ifrf.Flags | unix.IFF_UP
|
ifrf.Flags = ifrf.Flags | unix.IFF_UP
|
||||||
if err = ioctl(t.ioctlFd, unix.SIOCSIFFLAGS, uintptr(unsafe.Pointer(&ifrf))); err != nil {
|
if err = ioctl(t.ioctlFd, unix.SIOCSIFFLAGS, uintptr(unsafe.Pointer(&ifrf))); err != nil {
|
||||||
return fmt.Errorf("failed to bring the tun device up: %s", err)
|
return fmt.Errorf("failed to bring the tun device up: %s", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
link, err := netlink.LinkByName(t.Device)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("failed to get tun device link: %s", err)
|
|
||||||
}
|
|
||||||
t.deviceIndex = link.Attrs().Index
|
|
||||||
|
|
||||||
if err = t.setDefaultRoute(); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
// Set the routes
|
|
||||||
if err = t.addRoutes(false); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
// Run the interface
|
// Run the interface
|
||||||
ifrf.Flags = ifrf.Flags | unix.IFF_UP | unix.IFF_RUNNING
|
ifrf.Flags = ifrf.Flags | unix.IFF_UP | unix.IFF_RUNNING
|
||||||
if err = ioctl(t.ioctlFd, unix.SIOCSIFFLAGS, uintptr(unsafe.Pointer(&ifrf))); err != nil {
|
if err = ioctl(t.ioctlFd, unix.SIOCSIFFLAGS, uintptr(unsafe.Pointer(&ifrf))); err != nil {
|
||||||
return fmt.Errorf("failed to run tun device: %s", err)
|
return fmt.Errorf("failed to run tun device: %s", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
//set route MTU
|
||||||
|
for i := range t.vpnNetworks {
|
||||||
|
if err = t.setDefaultRoute(t.vpnNetworks[i]); err != nil {
|
||||||
|
return fmt.Errorf("failed to set default route MTU: %w", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Set the routes
|
||||||
|
if err = t.addRoutes(false); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
//todo do we want to keep the link-local address?
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -363,12 +399,12 @@ func (t *tun) setMTU() {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t *tun) setDefaultRoute() error {
|
func (t *tun) setDefaultRoute(cidr netip.Prefix) error {
|
||||||
// Default route
|
// Default route
|
||||||
|
|
||||||
dr := &net.IPNet{
|
dr := &net.IPNet{
|
||||||
IP: t.cidr.Masked().Addr().AsSlice(),
|
IP: cidr.Masked().Addr().AsSlice(),
|
||||||
Mask: net.CIDRMask(t.cidr.Bits(), t.cidr.Addr().BitLen()),
|
Mask: net.CIDRMask(cidr.Bits(), cidr.Addr().BitLen()),
|
||||||
}
|
}
|
||||||
|
|
||||||
nr := netlink.Route{
|
nr := netlink.Route{
|
||||||
@ -377,7 +413,7 @@ func (t *tun) setDefaultRoute() error {
|
|||||||
MTU: t.DefaultMTU,
|
MTU: t.DefaultMTU,
|
||||||
AdvMSS: t.advMSS(Route{}),
|
AdvMSS: t.advMSS(Route{}),
|
||||||
Scope: unix.RT_SCOPE_LINK,
|
Scope: unix.RT_SCOPE_LINK,
|
||||||
Src: net.IP(t.cidr.Addr().AsSlice()),
|
Src: net.IP(cidr.Addr().AsSlice()),
|
||||||
Protocol: unix.RTPROT_KERNEL,
|
Protocol: unix.RTPROT_KERNEL,
|
||||||
Table: unix.RT_TABLE_MAIN,
|
Table: unix.RT_TABLE_MAIN,
|
||||||
Type: unix.RTN_UNICAST,
|
Type: unix.RTN_UNICAST,
|
||||||
@ -463,10 +499,6 @@ func (t *tun) removeRoutes(routes []Route) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t *tun) Cidr() netip.Prefix {
|
|
||||||
return t.cidr
|
|
||||||
}
|
|
||||||
|
|
||||||
func (t *tun) Name() string {
|
func (t *tun) Name() string {
|
||||||
return t.Device
|
return t.Device
|
||||||
}
|
}
|
||||||
@ -523,9 +555,16 @@ func (t *tun) updateRoutes(r netlink.RouteUpdate) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
gwAddr = gwAddr.Unmap()
|
gwAddr = gwAddr.Unmap()
|
||||||
if !t.cidr.Contains(gwAddr) {
|
withinNetworks := false
|
||||||
|
for i := range t.vpnNetworks {
|
||||||
|
if t.vpnNetworks[i].Contains(gwAddr) {
|
||||||
|
withinNetworks = true
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if !withinNetworks {
|
||||||
// Gateway isn't in our overlay network, ignore
|
// Gateway isn't in our overlay network, ignore
|
||||||
t.l.WithField("route", r).Debug("Ignoring route update, not in our network")
|
t.l.WithField("route", r).Debug("Ignoring route update, not in our networks")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@ -28,7 +28,7 @@ type ifreqDestroy struct {
|
|||||||
|
|
||||||
type tun struct {
|
type tun struct {
|
||||||
Device string
|
Device string
|
||||||
cidr netip.Prefix
|
vpnNetworks []netip.Prefix
|
||||||
MTU int
|
MTU int
|
||||||
Routes atomic.Pointer[[]Route]
|
Routes atomic.Pointer[[]Route]
|
||||||
routeTree atomic.Pointer[bart.Table[netip.Addr]]
|
routeTree atomic.Pointer[bart.Table[netip.Addr]]
|
||||||
@ -58,13 +58,13 @@ func (t *tun) Close() error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ netip.Prefix) (*tun, error) {
|
func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ []netip.Prefix) (*tun, error) {
|
||||||
return nil, fmt.Errorf("newTunFromFd not supported in NetBSD")
|
return nil, fmt.Errorf("newTunFromFd not supported in NetBSD")
|
||||||
}
|
}
|
||||||
|
|
||||||
var deviceNameRE = regexp.MustCompile(`^tun[0-9]+$`)
|
var deviceNameRE = regexp.MustCompile(`^tun[0-9]+$`)
|
||||||
|
|
||||||
func newTun(c *config.C, l *logrus.Logger, cidr netip.Prefix, _ bool) (*tun, error) {
|
func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, _ bool) (*tun, error) {
|
||||||
// Try to open tun device
|
// Try to open tun device
|
||||||
var file *os.File
|
var file *os.File
|
||||||
var err error
|
var err error
|
||||||
@ -84,7 +84,7 @@ func newTun(c *config.C, l *logrus.Logger, cidr netip.Prefix, _ bool) (*tun, err
|
|||||||
t := &tun{
|
t := &tun{
|
||||||
ReadWriteCloser: file,
|
ReadWriteCloser: file,
|
||||||
Device: deviceName,
|
Device: deviceName,
|
||||||
cidr: cidr,
|
vpnNetworks: vpnNetworks,
|
||||||
MTU: c.GetInt("tun.mtu", DefaultMTU),
|
MTU: c.GetInt("tun.mtu", DefaultMTU),
|
||||||
l: l,
|
l: l,
|
||||||
}
|
}
|
||||||
@ -104,17 +104,17 @@ func newTun(c *config.C, l *logrus.Logger, cidr netip.Prefix, _ bool) (*tun, err
|
|||||||
return t, nil
|
return t, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t *tun) Activate() error {
|
func (t *tun) addIp(cidr netip.Prefix) error {
|
||||||
var err error
|
var err error
|
||||||
|
|
||||||
// TODO use syscalls instead of exec.Command
|
// TODO use syscalls instead of exec.Command
|
||||||
cmd := exec.Command("/sbin/ifconfig", t.Device, t.cidr.String(), t.cidr.Addr().String())
|
cmd := exec.Command("/sbin/ifconfig", t.Device, cidr.String(), cidr.Addr().String())
|
||||||
t.l.Debug("command: ", cmd.String())
|
t.l.Debug("command: ", cmd.String())
|
||||||
if err = cmd.Run(); err != nil {
|
if err = cmd.Run(); err != nil {
|
||||||
return fmt.Errorf("failed to run 'ifconfig': %s", err)
|
return fmt.Errorf("failed to run 'ifconfig': %s", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
cmd = exec.Command("/sbin/route", "-n", "add", "-net", t.cidr.String(), t.cidr.Addr().String())
|
cmd = exec.Command("/sbin/route", "-n", "add", "-net", cidr.String(), cidr.Addr().String())
|
||||||
t.l.Debug("command: ", cmd.String())
|
t.l.Debug("command: ", cmd.String())
|
||||||
if err = cmd.Run(); err != nil {
|
if err = cmd.Run(); err != nil {
|
||||||
return fmt.Errorf("failed to run 'route add': %s", err)
|
return fmt.Errorf("failed to run 'route add': %s", err)
|
||||||
@ -130,8 +130,18 @@ func (t *tun) Activate() error {
|
|||||||
return t.addRoutes(false)
|
return t.addRoutes(false)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (t *tun) Activate() error {
|
||||||
|
for i := range t.vpnNetworks {
|
||||||
|
err := t.addIp(t.vpnNetworks[i])
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
func (t *tun) reload(c *config.C, initial bool) error {
|
func (t *tun) reload(c *config.C, initial bool) error {
|
||||||
change, routes, err := getAllRoutesFromConfig(c, t.cidr, initial)
|
change, routes, err := getAllRoutesFromConfig(c, t.vpnNetworks, initial)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@ -172,8 +182,8 @@ func (t *tun) RouteFor(ip netip.Addr) netip.Addr {
|
|||||||
return r
|
return r
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t *tun) Cidr() netip.Prefix {
|
func (t *tun) Networks() []netip.Prefix {
|
||||||
return t.cidr
|
return t.vpnNetworks
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t *tun) Name() string {
|
func (t *tun) Name() string {
|
||||||
@ -192,7 +202,7 @@ func (t *tun) addRoutes(logErrors bool) error {
|
|||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
cmd := exec.Command("/sbin/route", "-n", "add", "-net", r.Cidr.String(), t.cidr.Addr().String())
|
cmd := exec.Command("/sbin/route", "-n", "add", "-net", r.Cidr.String(), t.vpnNetworks[0].Addr().String())
|
||||||
t.l.Debug("command: ", cmd.String())
|
t.l.Debug("command: ", cmd.String())
|
||||||
if err := cmd.Run(); err != nil {
|
if err := cmd.Run(); err != nil {
|
||||||
retErr := util.NewContextualError("failed to run 'route add' for unsafe_route", map[string]interface{}{"route": r}, err)
|
retErr := util.NewContextualError("failed to run 'route add' for unsafe_route", map[string]interface{}{"route": r}, err)
|
||||||
@ -213,7 +223,8 @@ func (t *tun) removeRoutes(routes []Route) error {
|
|||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
cmd := exec.Command("/sbin/route", "-n", "delete", "-net", r.Cidr.String(), t.cidr.Addr().String())
|
//todo is this right?
|
||||||
|
cmd := exec.Command("/sbin/route", "-n", "delete", "-net", r.Cidr.String(), t.vpnNetworks[0].Addr().String())
|
||||||
t.l.Debug("command: ", cmd.String())
|
t.l.Debug("command: ", cmd.String())
|
||||||
if err := cmd.Run(); err != nil {
|
if err := cmd.Run(); err != nil {
|
||||||
t.l.WithError(err).WithField("route", r).Error("Failed to remove route")
|
t.l.WithError(err).WithField("route", r).Error("Failed to remove route")
|
||||||
|
|||||||
@ -22,7 +22,7 @@ import (
|
|||||||
|
|
||||||
type tun struct {
|
type tun struct {
|
||||||
Device string
|
Device string
|
||||||
cidr netip.Prefix
|
vpnNetworks []netip.Prefix
|
||||||
MTU int
|
MTU int
|
||||||
Routes atomic.Pointer[[]Route]
|
Routes atomic.Pointer[[]Route]
|
||||||
routeTree atomic.Pointer[bart.Table[netip.Addr]]
|
routeTree atomic.Pointer[bart.Table[netip.Addr]]
|
||||||
@ -42,13 +42,13 @@ func (t *tun) Close() error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ netip.Prefix) (*tun, error) {
|
func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ []netip.Prefix) (*tun, error) {
|
||||||
return nil, fmt.Errorf("newTunFromFd not supported in OpenBSD")
|
return nil, fmt.Errorf("newTunFromFd not supported in OpenBSD")
|
||||||
}
|
}
|
||||||
|
|
||||||
var deviceNameRE = regexp.MustCompile(`^tun[0-9]+$`)
|
var deviceNameRE = regexp.MustCompile(`^tun[0-9]+$`)
|
||||||
|
|
||||||
func newTun(c *config.C, l *logrus.Logger, cidr netip.Prefix, _ bool) (*tun, error) {
|
func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, _ bool) (*tun, error) {
|
||||||
deviceName := c.GetString("tun.dev", "")
|
deviceName := c.GetString("tun.dev", "")
|
||||||
if deviceName == "" {
|
if deviceName == "" {
|
||||||
return nil, fmt.Errorf("a device name in the format of tunN must be specified")
|
return nil, fmt.Errorf("a device name in the format of tunN must be specified")
|
||||||
@ -66,7 +66,7 @@ func newTun(c *config.C, l *logrus.Logger, cidr netip.Prefix, _ bool) (*tun, err
|
|||||||
t := &tun{
|
t := &tun{
|
||||||
ReadWriteCloser: file,
|
ReadWriteCloser: file,
|
||||||
Device: deviceName,
|
Device: deviceName,
|
||||||
cidr: cidr,
|
vpnNetworks: vpnNetworks,
|
||||||
MTU: c.GetInt("tun.mtu", DefaultMTU),
|
MTU: c.GetInt("tun.mtu", DefaultMTU),
|
||||||
l: l,
|
l: l,
|
||||||
}
|
}
|
||||||
@ -87,7 +87,7 @@ func newTun(c *config.C, l *logrus.Logger, cidr netip.Prefix, _ bool) (*tun, err
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (t *tun) reload(c *config.C, initial bool) error {
|
func (t *tun) reload(c *config.C, initial bool) error {
|
||||||
change, routes, err := getAllRoutesFromConfig(c, t.cidr, initial)
|
change, routes, err := getAllRoutesFromConfig(c, t.vpnNetworks, initial)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@ -123,10 +123,10 @@ func (t *tun) reload(c *config.C, initial bool) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t *tun) Activate() error {
|
func (t *tun) addIp(cidr netip.Prefix) error {
|
||||||
var err error
|
var err error
|
||||||
// TODO use syscalls instead of exec.Command
|
// TODO use syscalls instead of exec.Command
|
||||||
cmd := exec.Command("/sbin/ifconfig", t.Device, t.cidr.String(), t.cidr.Addr().String())
|
cmd := exec.Command("/sbin/ifconfig", t.Device, cidr.String(), cidr.Addr().String())
|
||||||
t.l.Debug("command: ", cmd.String())
|
t.l.Debug("command: ", cmd.String())
|
||||||
if err = cmd.Run(); err != nil {
|
if err = cmd.Run(); err != nil {
|
||||||
return fmt.Errorf("failed to run 'ifconfig': %s", err)
|
return fmt.Errorf("failed to run 'ifconfig': %s", err)
|
||||||
@ -138,7 +138,7 @@ func (t *tun) Activate() error {
|
|||||||
return fmt.Errorf("failed to run 'ifconfig': %s", err)
|
return fmt.Errorf("failed to run 'ifconfig': %s", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
cmd = exec.Command("/sbin/route", "-n", "add", "-inet", t.cidr.String(), t.cidr.Addr().String())
|
cmd = exec.Command("/sbin/route", "-n", "add", "-inet", cidr.String(), cidr.Addr().String())
|
||||||
t.l.Debug("command: ", cmd.String())
|
t.l.Debug("command: ", cmd.String())
|
||||||
if err = cmd.Run(); err != nil {
|
if err = cmd.Run(); err != nil {
|
||||||
return fmt.Errorf("failed to run 'route add': %s", err)
|
return fmt.Errorf("failed to run 'route add': %s", err)
|
||||||
@ -148,6 +148,16 @@ func (t *tun) Activate() error {
|
|||||||
return t.addRoutes(false)
|
return t.addRoutes(false)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (t *tun) Activate() error {
|
||||||
|
for i := range t.vpnNetworks {
|
||||||
|
err := t.addIp(t.vpnNetworks[i])
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
func (t *tun) RouteFor(ip netip.Addr) netip.Addr {
|
func (t *tun) RouteFor(ip netip.Addr) netip.Addr {
|
||||||
r, _ := t.routeTree.Load().Lookup(ip)
|
r, _ := t.routeTree.Load().Lookup(ip)
|
||||||
return r
|
return r
|
||||||
@ -160,8 +170,8 @@ func (t *tun) addRoutes(logErrors bool) error {
|
|||||||
// We don't allow route MTUs so only install routes with a via
|
// We don't allow route MTUs so only install routes with a via
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
//todo is this right?
|
||||||
cmd := exec.Command("/sbin/route", "-n", "add", "-inet", r.Cidr.String(), t.cidr.Addr().String())
|
cmd := exec.Command("/sbin/route", "-n", "add", "-inet", r.Cidr.String(), t.vpnNetworks[0].Addr().String())
|
||||||
t.l.Debug("command: ", cmd.String())
|
t.l.Debug("command: ", cmd.String())
|
||||||
if err := cmd.Run(); err != nil {
|
if err := cmd.Run(); err != nil {
|
||||||
retErr := util.NewContextualError("failed to run 'route add' for unsafe_route", map[string]interface{}{"route": r}, err)
|
retErr := util.NewContextualError("failed to run 'route add' for unsafe_route", map[string]interface{}{"route": r}, err)
|
||||||
@ -181,8 +191,8 @@ func (t *tun) removeRoutes(routes []Route) error {
|
|||||||
if !r.Install {
|
if !r.Install {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
//todo is this right?
|
||||||
cmd := exec.Command("/sbin/route", "-n", "delete", "-inet", r.Cidr.String(), t.cidr.Addr().String())
|
cmd := exec.Command("/sbin/route", "-n", "delete", "-inet", r.Cidr.String(), t.vpnNetworks[0].Addr().String())
|
||||||
t.l.Debug("command: ", cmd.String())
|
t.l.Debug("command: ", cmd.String())
|
||||||
if err := cmd.Run(); err != nil {
|
if err := cmd.Run(); err != nil {
|
||||||
t.l.WithError(err).WithField("route", r).Error("Failed to remove route")
|
t.l.WithError(err).WithField("route", r).Error("Failed to remove route")
|
||||||
@ -193,8 +203,8 @@ func (t *tun) removeRoutes(routes []Route) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t *tun) Cidr() netip.Prefix {
|
func (t *tun) Networks() []netip.Prefix {
|
||||||
return t.cidr
|
return t.vpnNetworks
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t *tun) Name() string {
|
func (t *tun) Name() string {
|
||||||
|
|||||||
@ -17,7 +17,7 @@ import (
|
|||||||
|
|
||||||
type TestTun struct {
|
type TestTun struct {
|
||||||
Device string
|
Device string
|
||||||
cidr netip.Prefix
|
vpnNetworks []netip.Prefix
|
||||||
Routes []Route
|
Routes []Route
|
||||||
routeTree *bart.Table[netip.Addr]
|
routeTree *bart.Table[netip.Addr]
|
||||||
l *logrus.Logger
|
l *logrus.Logger
|
||||||
@ -27,8 +27,8 @@ type TestTun struct {
|
|||||||
TxPackets chan []byte // Packets transmitted outside by nebula
|
TxPackets chan []byte // Packets transmitted outside by nebula
|
||||||
}
|
}
|
||||||
|
|
||||||
func newTun(c *config.C, l *logrus.Logger, cidr netip.Prefix, _ bool) (*TestTun, error) {
|
func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, _ bool) (*TestTun, error) {
|
||||||
_, routes, err := getAllRoutesFromConfig(c, cidr, true)
|
_, routes, err := getAllRoutesFromConfig(c, vpnNetworks, true)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@ -39,7 +39,7 @@ func newTun(c *config.C, l *logrus.Logger, cidr netip.Prefix, _ bool) (*TestTun,
|
|||||||
|
|
||||||
return &TestTun{
|
return &TestTun{
|
||||||
Device: c.GetString("tun.dev", ""),
|
Device: c.GetString("tun.dev", ""),
|
||||||
cidr: cidr,
|
vpnNetworks: vpnNetworks,
|
||||||
Routes: routes,
|
Routes: routes,
|
||||||
routeTree: routeTree,
|
routeTree: routeTree,
|
||||||
l: l,
|
l: l,
|
||||||
@ -48,7 +48,7 @@ func newTun(c *config.C, l *logrus.Logger, cidr netip.Prefix, _ bool) (*TestTun,
|
|||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ netip.Prefix) (*TestTun, error) {
|
func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ []netip.Prefix) (*TestTun, error) {
|
||||||
return nil, fmt.Errorf("newTunFromFd not supported")
|
return nil, fmt.Errorf("newTunFromFd not supported")
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -95,8 +95,8 @@ func (t *TestTun) Activate() error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t *TestTun) Cidr() netip.Prefix {
|
func (t *TestTun) Networks() []netip.Prefix {
|
||||||
return t.cidr
|
return t.vpnNetworks
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t *TestTun) Name() string {
|
func (t *TestTun) Name() string {
|
||||||
|
|||||||
@ -1,208 +0,0 @@
|
|||||||
package overlay
|
|
||||||
|
|
||||||
import (
|
|
||||||
"fmt"
|
|
||||||
"io"
|
|
||||||
"net"
|
|
||||||
"net/netip"
|
|
||||||
"os/exec"
|
|
||||||
"strconv"
|
|
||||||
"sync/atomic"
|
|
||||||
|
|
||||||
"github.com/gaissmai/bart"
|
|
||||||
"github.com/sirupsen/logrus"
|
|
||||||
"github.com/slackhq/nebula/config"
|
|
||||||
"github.com/slackhq/nebula/util"
|
|
||||||
"github.com/songgao/water"
|
|
||||||
)
|
|
||||||
|
|
||||||
type waterTun struct {
|
|
||||||
Device string
|
|
||||||
cidr netip.Prefix
|
|
||||||
MTU int
|
|
||||||
Routes atomic.Pointer[[]Route]
|
|
||||||
routeTree atomic.Pointer[bart.Table[netip.Addr]]
|
|
||||||
l *logrus.Logger
|
|
||||||
f *net.Interface
|
|
||||||
*water.Interface
|
|
||||||
}
|
|
||||||
|
|
||||||
func newWaterTun(c *config.C, l *logrus.Logger, cidr netip.Prefix, _ bool) (*waterTun, error) {
|
|
||||||
// NOTE: You cannot set the deviceName under Windows, so you must check tun.Device after calling .Activate()
|
|
||||||
t := &waterTun{
|
|
||||||
cidr: cidr,
|
|
||||||
MTU: c.GetInt("tun.mtu", DefaultMTU),
|
|
||||||
l: l,
|
|
||||||
}
|
|
||||||
|
|
||||||
err := t.reload(c, true)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
c.RegisterReloadCallback(func(c *config.C) {
|
|
||||||
err := t.reload(c, false)
|
|
||||||
if err != nil {
|
|
||||||
util.LogWithContextIfNeeded("failed to reload tun device", err, t.l)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
return t, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (t *waterTun) Activate() error {
|
|
||||||
var err error
|
|
||||||
t.Interface, err = water.New(water.Config{
|
|
||||||
DeviceType: water.TUN,
|
|
||||||
PlatformSpecificParams: water.PlatformSpecificParams{
|
|
||||||
ComponentID: "tap0901",
|
|
||||||
Network: t.cidr.String(),
|
|
||||||
},
|
|
||||||
})
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("activate failed: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
t.Device = t.Interface.Name()
|
|
||||||
|
|
||||||
// TODO use syscalls instead of exec.Command
|
|
||||||
err = exec.Command(
|
|
||||||
`C:\Windows\System32\netsh.exe`, "interface", "ipv4", "set", "address",
|
|
||||||
fmt.Sprintf("name=%s", t.Device),
|
|
||||||
"source=static",
|
|
||||||
fmt.Sprintf("addr=%s", t.cidr.Addr()),
|
|
||||||
fmt.Sprintf("mask=%s", net.CIDRMask(t.cidr.Bits(), t.cidr.Addr().BitLen())),
|
|
||||||
"gateway=none",
|
|
||||||
).Run()
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("failed to run 'netsh' to set address: %s", err)
|
|
||||||
}
|
|
||||||
err = exec.Command(
|
|
||||||
`C:\Windows\System32\netsh.exe`, "interface", "ipv4", "set", "interface",
|
|
||||||
t.Device,
|
|
||||||
fmt.Sprintf("mtu=%d", t.MTU),
|
|
||||||
).Run()
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("failed to run 'netsh' to set MTU: %s", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
t.f, err = net.InterfaceByName(t.Device)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("failed to find interface named %s: %v", t.Device, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
err = t.addRoutes(false)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (t *waterTun) reload(c *config.C, initial bool) error {
|
|
||||||
change, routes, err := getAllRoutesFromConfig(c, t.cidr, initial)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
if !initial && !change {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
routeTree, err := makeRouteTree(t.l, routes, false)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
// Teach nebula how to handle the routes before establishing them in the system table
|
|
||||||
oldRoutes := t.Routes.Swap(&routes)
|
|
||||||
t.routeTree.Store(routeTree)
|
|
||||||
|
|
||||||
if !initial {
|
|
||||||
// Remove first, if the system removes a wanted route hopefully it will be re-added next
|
|
||||||
t.removeRoutes(findRemovedRoutes(routes, *oldRoutes))
|
|
||||||
|
|
||||||
// Ensure any routes we actually want are installed
|
|
||||||
err = t.addRoutes(true)
|
|
||||||
if err != nil {
|
|
||||||
// Catch any stray logs
|
|
||||||
util.LogWithContextIfNeeded("Failed to set routes", err, t.l)
|
|
||||||
} else {
|
|
||||||
for _, r := range findRemovedRoutes(routes, *oldRoutes) {
|
|
||||||
t.l.WithField("route", r).Info("Removed route")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (t *waterTun) addRoutes(logErrors bool) error {
|
|
||||||
// Path routes
|
|
||||||
routes := *t.Routes.Load()
|
|
||||||
for _, r := range routes {
|
|
||||||
if !r.Via.IsValid() || !r.Install {
|
|
||||||
// We don't allow route MTUs so only install routes with a via
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
err := exec.Command(
|
|
||||||
"C:\\Windows\\System32\\route.exe", "add", r.Cidr.String(), r.Via.String(), "IF", strconv.Itoa(t.f.Index), "METRIC", strconv.Itoa(r.Metric),
|
|
||||||
).Run()
|
|
||||||
|
|
||||||
if err != nil {
|
|
||||||
retErr := util.NewContextualError("Failed to add route", map[string]interface{}{"route": r}, err)
|
|
||||||
if logErrors {
|
|
||||||
retErr.Log(t.l)
|
|
||||||
} else {
|
|
||||||
return retErr
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
t.l.WithField("route", r).Info("Added route")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (t *waterTun) removeRoutes(routes []Route) {
|
|
||||||
for _, r := range routes {
|
|
||||||
if !r.Install {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
err := exec.Command(
|
|
||||||
"C:\\Windows\\System32\\route.exe", "delete", r.Cidr.String(), r.Via.String(), "IF", strconv.Itoa(t.f.Index), "METRIC", strconv.Itoa(r.Metric),
|
|
||||||
).Run()
|
|
||||||
if err != nil {
|
|
||||||
t.l.WithError(err).WithField("route", r).Error("Failed to remove route")
|
|
||||||
} else {
|
|
||||||
t.l.WithField("route", r).Info("Removed route")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (t *waterTun) RouteFor(ip netip.Addr) netip.Addr {
|
|
||||||
r, _ := t.routeTree.Load().Lookup(ip)
|
|
||||||
return r
|
|
||||||
}
|
|
||||||
|
|
||||||
func (t *waterTun) Cidr() netip.Prefix {
|
|
||||||
return t.cidr
|
|
||||||
}
|
|
||||||
|
|
||||||
func (t *waterTun) Name() string {
|
|
||||||
return t.Device
|
|
||||||
}
|
|
||||||
|
|
||||||
func (t *waterTun) Close() error {
|
|
||||||
if t.Interface == nil {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
return t.Interface.Close()
|
|
||||||
}
|
|
||||||
|
|
||||||
func (t *waterTun) NewMultiQueueReader() (io.ReadWriteCloser, error) {
|
|
||||||
return nil, fmt.Errorf("TODO: multiqueue not implemented for windows")
|
|
||||||
}
|
|
||||||
@ -4,41 +4,267 @@
|
|||||||
package overlay
|
package overlay
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"crypto"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"io"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
"os"
|
"os"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
"runtime"
|
"runtime"
|
||||||
|
"sync/atomic"
|
||||||
"syscall"
|
"syscall"
|
||||||
|
"unsafe"
|
||||||
|
|
||||||
|
"github.com/gaissmai/bart"
|
||||||
"github.com/sirupsen/logrus"
|
"github.com/sirupsen/logrus"
|
||||||
"github.com/slackhq/nebula/config"
|
"github.com/slackhq/nebula/config"
|
||||||
|
"github.com/slackhq/nebula/util"
|
||||||
|
"github.com/slackhq/nebula/wintun"
|
||||||
|
"golang.org/x/sys/windows"
|
||||||
|
"golang.zx2c4.com/wireguard/windows/tunnel/winipcfg"
|
||||||
)
|
)
|
||||||
|
|
||||||
func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ netip.Prefix) (Device, error) {
|
const tunGUIDLabel = "Fixed Nebula Windows GUID v1"
|
||||||
|
|
||||||
|
type winTun struct {
|
||||||
|
Device string
|
||||||
|
vpnNetworks []netip.Prefix
|
||||||
|
MTU int
|
||||||
|
Routes atomic.Pointer[[]Route]
|
||||||
|
routeTree atomic.Pointer[bart.Table[netip.Addr]]
|
||||||
|
l *logrus.Logger
|
||||||
|
|
||||||
|
tun *wintun.NativeTun
|
||||||
|
}
|
||||||
|
|
||||||
|
func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ []netip.Prefix) (Device, error) {
|
||||||
return nil, fmt.Errorf("newTunFromFd not supported in Windows")
|
return nil, fmt.Errorf("newTunFromFd not supported in Windows")
|
||||||
}
|
}
|
||||||
|
|
||||||
func newTun(c *config.C, l *logrus.Logger, cidr netip.Prefix, multiqueue bool) (Device, error) {
|
func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, _ bool) (*winTun, error) {
|
||||||
useWintun := true
|
err := checkWinTunExists()
|
||||||
if err := checkWinTunExists(); err != nil {
|
if err != nil {
|
||||||
l.WithError(err).Warn("Check Wintun driver failed, fallback to wintap driver")
|
return nil, fmt.Errorf("can not load the wintun driver: %w", err)
|
||||||
useWintun = false
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if useWintun {
|
deviceName := c.GetString("tun.dev", "")
|
||||||
device, err := newWinTun(c, l, cidr, multiqueue)
|
guid, err := generateGUIDByDeviceName(deviceName)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("create Wintun interface failed, %w", err)
|
return nil, fmt.Errorf("generate GUID failed: %w", err)
|
||||||
}
|
|
||||||
return device, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
device, err := newWaterTun(c, l, cidr, multiqueue)
|
t := &winTun{
|
||||||
if err != nil {
|
Device: deviceName,
|
||||||
return nil, fmt.Errorf("create wintap driver failed, %w", err)
|
vpnNetworks: vpnNetworks,
|
||||||
|
MTU: c.GetInt("tun.mtu", DefaultMTU),
|
||||||
|
l: l,
|
||||||
}
|
}
|
||||||
return device, nil
|
|
||||||
|
err = t.reload(c, true)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
var tunDevice wintun.Device
|
||||||
|
tunDevice, err = wintun.CreateTUNWithRequestedGUID(deviceName, guid, t.MTU)
|
||||||
|
if err != nil {
|
||||||
|
// Windows 10 has an issue with unclean shutdowns not fully cleaning up the wintun device.
|
||||||
|
// Trying a second time resolves the issue.
|
||||||
|
l.WithError(err).Debug("Failed to create wintun device, retrying")
|
||||||
|
tunDevice, err = wintun.CreateTUNWithRequestedGUID(deviceName, guid, t.MTU)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("create TUN device failed: %w", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
t.tun = tunDevice.(*wintun.NativeTun)
|
||||||
|
|
||||||
|
c.RegisterReloadCallback(func(c *config.C) {
|
||||||
|
err := t.reload(c, false)
|
||||||
|
if err != nil {
|
||||||
|
util.LogWithContextIfNeeded("failed to reload tun device", err, t.l)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
return t, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *winTun) reload(c *config.C, initial bool) error {
|
||||||
|
change, routes, err := getAllRoutesFromConfig(c, t.vpnNetworks, initial)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
if !initial && !change {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
routeTree, err := makeRouteTree(t.l, routes, false)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Teach nebula how to handle the routes before establishing them in the system table
|
||||||
|
oldRoutes := t.Routes.Swap(&routes)
|
||||||
|
t.routeTree.Store(routeTree)
|
||||||
|
|
||||||
|
if !initial {
|
||||||
|
// Remove first, if the system removes a wanted route hopefully it will be re-added next
|
||||||
|
err := t.removeRoutes(findRemovedRoutes(routes, *oldRoutes))
|
||||||
|
if err != nil {
|
||||||
|
util.LogWithContextIfNeeded("Failed to remove routes", err, t.l)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Ensure any routes we actually want are installed
|
||||||
|
err = t.addRoutes(true)
|
||||||
|
if err != nil {
|
||||||
|
// Catch any stray logs
|
||||||
|
util.LogWithContextIfNeeded("Failed to add routes", err, t.l)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *winTun) Activate() error {
|
||||||
|
luid := winipcfg.LUID(t.tun.LUID())
|
||||||
|
|
||||||
|
err := luid.SetIPAddresses(t.vpnNetworks)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to set address: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
err = t.addRoutes(false)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *winTun) addRoutes(logErrors bool) error {
|
||||||
|
luid := winipcfg.LUID(t.tun.LUID())
|
||||||
|
routes := *t.Routes.Load()
|
||||||
|
foundDefault4 := false
|
||||||
|
|
||||||
|
for _, r := range routes {
|
||||||
|
if !r.Via.IsValid() || !r.Install {
|
||||||
|
// We don't allow route MTUs so only install routes with a via
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add our unsafe route
|
||||||
|
err := luid.AddRoute(r.Cidr, r.Via, uint32(r.Metric))
|
||||||
|
if err != nil {
|
||||||
|
retErr := util.NewContextualError("Failed to add route", map[string]interface{}{"route": r}, err)
|
||||||
|
if logErrors {
|
||||||
|
retErr.Log(t.l)
|
||||||
|
continue
|
||||||
|
} else {
|
||||||
|
return retErr
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
t.l.WithField("route", r).Info("Added route")
|
||||||
|
}
|
||||||
|
|
||||||
|
if !foundDefault4 {
|
||||||
|
if r.Cidr.Bits() == 0 && r.Cidr.Addr().BitLen() == 32 {
|
||||||
|
foundDefault4 = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
ipif, err := luid.IPInterface(windows.AF_INET)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to get ip interface: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
ipif.NLMTU = uint32(t.MTU)
|
||||||
|
if foundDefault4 {
|
||||||
|
ipif.UseAutomaticMetric = false
|
||||||
|
ipif.Metric = 0
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := ipif.Set(); err != nil {
|
||||||
|
return fmt.Errorf("failed to set ip interface: %w", err)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *winTun) removeRoutes(routes []Route) error {
|
||||||
|
luid := winipcfg.LUID(t.tun.LUID())
|
||||||
|
|
||||||
|
for _, r := range routes {
|
||||||
|
if !r.Install {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
err := luid.DeleteRoute(r.Cidr, r.Via)
|
||||||
|
if err != nil {
|
||||||
|
t.l.WithError(err).WithField("route", r).Error("Failed to remove route")
|
||||||
|
} else {
|
||||||
|
t.l.WithField("route", r).Info("Removed route")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *winTun) RouteFor(ip netip.Addr) netip.Addr {
|
||||||
|
r, _ := t.routeTree.Load().Lookup(ip)
|
||||||
|
return r
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *winTun) Networks() []netip.Prefix {
|
||||||
|
return t.vpnNetworks
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *winTun) Name() string {
|
||||||
|
return t.Device
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *winTun) Read(b []byte) (int, error) {
|
||||||
|
return t.tun.Read(b, 0)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *winTun) Write(b []byte) (int, error) {
|
||||||
|
return t.tun.Write(b, 0)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *winTun) NewMultiQueueReader() (io.ReadWriteCloser, error) {
|
||||||
|
return nil, fmt.Errorf("TODO: multiqueue not implemented for windows")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *winTun) Close() error {
|
||||||
|
// It seems that the Windows networking stack doesn't like it when we destroy interfaces that have active routes,
|
||||||
|
// so to be certain, just remove everything before destroying.
|
||||||
|
luid := winipcfg.LUID(t.tun.LUID())
|
||||||
|
_ = luid.FlushRoutes(windows.AF_INET)
|
||||||
|
_ = luid.FlushIPAddresses(windows.AF_INET)
|
||||||
|
/* We don't support IPV6 yet
|
||||||
|
_ = luid.FlushRoutes(windows.AF_INET6)
|
||||||
|
_ = luid.FlushIPAddresses(windows.AF_INET6)
|
||||||
|
*/
|
||||||
|
_ = luid.FlushDNS(windows.AF_INET)
|
||||||
|
|
||||||
|
return t.tun.Close()
|
||||||
|
}
|
||||||
|
|
||||||
|
func generateGUIDByDeviceName(name string) (*windows.GUID, error) {
|
||||||
|
// GUID is 128 bit
|
||||||
|
hash := crypto.MD5.New()
|
||||||
|
|
||||||
|
_, err := hash.Write([]byte(tunGUIDLabel))
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err = hash.Write([]byte(name))
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
sum := hash.Sum(nil)
|
||||||
|
|
||||||
|
return (*windows.GUID)(unsafe.Pointer(&sum[0])), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func checkWinTunExists() error {
|
func checkWinTunExists() error {
|
||||||
|
|||||||
@ -1,252 +0,0 @@
|
|||||||
package overlay
|
|
||||||
|
|
||||||
import (
|
|
||||||
"crypto"
|
|
||||||
"fmt"
|
|
||||||
"io"
|
|
||||||
"net/netip"
|
|
||||||
"sync/atomic"
|
|
||||||
"unsafe"
|
|
||||||
|
|
||||||
"github.com/gaissmai/bart"
|
|
||||||
"github.com/sirupsen/logrus"
|
|
||||||
"github.com/slackhq/nebula/config"
|
|
||||||
"github.com/slackhq/nebula/util"
|
|
||||||
"github.com/slackhq/nebula/wintun"
|
|
||||||
"golang.org/x/sys/windows"
|
|
||||||
"golang.zx2c4.com/wireguard/windows/tunnel/winipcfg"
|
|
||||||
)
|
|
||||||
|
|
||||||
const tunGUIDLabel = "Fixed Nebula Windows GUID v1"
|
|
||||||
|
|
||||||
type winTun struct {
|
|
||||||
Device string
|
|
||||||
cidr netip.Prefix
|
|
||||||
MTU int
|
|
||||||
Routes atomic.Pointer[[]Route]
|
|
||||||
routeTree atomic.Pointer[bart.Table[netip.Addr]]
|
|
||||||
l *logrus.Logger
|
|
||||||
|
|
||||||
tun *wintun.NativeTun
|
|
||||||
}
|
|
||||||
|
|
||||||
func generateGUIDByDeviceName(name string) (*windows.GUID, error) {
|
|
||||||
// GUID is 128 bit
|
|
||||||
hash := crypto.MD5.New()
|
|
||||||
|
|
||||||
_, err := hash.Write([]byte(tunGUIDLabel))
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
_, err = hash.Write([]byte(name))
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
sum := hash.Sum(nil)
|
|
||||||
|
|
||||||
return (*windows.GUID)(unsafe.Pointer(&sum[0])), nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func newWinTun(c *config.C, l *logrus.Logger, cidr netip.Prefix, _ bool) (*winTun, error) {
|
|
||||||
deviceName := c.GetString("tun.dev", "")
|
|
||||||
guid, err := generateGUIDByDeviceName(deviceName)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("generate GUID failed: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
t := &winTun{
|
|
||||||
Device: deviceName,
|
|
||||||
cidr: cidr,
|
|
||||||
MTU: c.GetInt("tun.mtu", DefaultMTU),
|
|
||||||
l: l,
|
|
||||||
}
|
|
||||||
|
|
||||||
err = t.reload(c, true)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
var tunDevice wintun.Device
|
|
||||||
tunDevice, err = wintun.CreateTUNWithRequestedGUID(deviceName, guid, t.MTU)
|
|
||||||
if err != nil {
|
|
||||||
// Windows 10 has an issue with unclean shutdowns not fully cleaning up the wintun device.
|
|
||||||
// Trying a second time resolves the issue.
|
|
||||||
l.WithError(err).Debug("Failed to create wintun device, retrying")
|
|
||||||
tunDevice, err = wintun.CreateTUNWithRequestedGUID(deviceName, guid, t.MTU)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("create TUN device failed: %w", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
t.tun = tunDevice.(*wintun.NativeTun)
|
|
||||||
|
|
||||||
c.RegisterReloadCallback(func(c *config.C) {
|
|
||||||
err := t.reload(c, false)
|
|
||||||
if err != nil {
|
|
||||||
util.LogWithContextIfNeeded("failed to reload tun device", err, t.l)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
return t, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (t *winTun) reload(c *config.C, initial bool) error {
|
|
||||||
change, routes, err := getAllRoutesFromConfig(c, t.cidr, initial)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
if !initial && !change {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
routeTree, err := makeRouteTree(t.l, routes, false)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
// Teach nebula how to handle the routes before establishing them in the system table
|
|
||||||
oldRoutes := t.Routes.Swap(&routes)
|
|
||||||
t.routeTree.Store(routeTree)
|
|
||||||
|
|
||||||
if !initial {
|
|
||||||
// Remove first, if the system removes a wanted route hopefully it will be re-added next
|
|
||||||
err := t.removeRoutes(findRemovedRoutes(routes, *oldRoutes))
|
|
||||||
if err != nil {
|
|
||||||
util.LogWithContextIfNeeded("Failed to remove routes", err, t.l)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Ensure any routes we actually want are installed
|
|
||||||
err = t.addRoutes(true)
|
|
||||||
if err != nil {
|
|
||||||
// Catch any stray logs
|
|
||||||
util.LogWithContextIfNeeded("Failed to add routes", err, t.l)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (t *winTun) Activate() error {
|
|
||||||
luid := winipcfg.LUID(t.tun.LUID())
|
|
||||||
|
|
||||||
err := luid.SetIPAddresses([]netip.Prefix{t.cidr})
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("failed to set address: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
err = t.addRoutes(false)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (t *winTun) addRoutes(logErrors bool) error {
|
|
||||||
luid := winipcfg.LUID(t.tun.LUID())
|
|
||||||
routes := *t.Routes.Load()
|
|
||||||
foundDefault4 := false
|
|
||||||
|
|
||||||
for _, r := range routes {
|
|
||||||
if !r.Via.IsValid() || !r.Install {
|
|
||||||
// We don't allow route MTUs so only install routes with a via
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
// Add our unsafe route
|
|
||||||
err := luid.AddRoute(r.Cidr, r.Via, uint32(r.Metric))
|
|
||||||
if err != nil {
|
|
||||||
retErr := util.NewContextualError("Failed to add route", map[string]interface{}{"route": r}, err)
|
|
||||||
if logErrors {
|
|
||||||
retErr.Log(t.l)
|
|
||||||
continue
|
|
||||||
} else {
|
|
||||||
return retErr
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
t.l.WithField("route", r).Info("Added route")
|
|
||||||
}
|
|
||||||
|
|
||||||
if !foundDefault4 {
|
|
||||||
if r.Cidr.Bits() == 0 && r.Cidr.Addr().BitLen() == 32 {
|
|
||||||
foundDefault4 = true
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
ipif, err := luid.IPInterface(windows.AF_INET)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("failed to get ip interface: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
ipif.NLMTU = uint32(t.MTU)
|
|
||||||
if foundDefault4 {
|
|
||||||
ipif.UseAutomaticMetric = false
|
|
||||||
ipif.Metric = 0
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := ipif.Set(); err != nil {
|
|
||||||
return fmt.Errorf("failed to set ip interface: %w", err)
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (t *winTun) removeRoutes(routes []Route) error {
|
|
||||||
luid := winipcfg.LUID(t.tun.LUID())
|
|
||||||
|
|
||||||
for _, r := range routes {
|
|
||||||
if !r.Install {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
err := luid.DeleteRoute(r.Cidr, r.Via)
|
|
||||||
if err != nil {
|
|
||||||
t.l.WithError(err).WithField("route", r).Error("Failed to remove route")
|
|
||||||
} else {
|
|
||||||
t.l.WithField("route", r).Info("Removed route")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (t *winTun) RouteFor(ip netip.Addr) netip.Addr {
|
|
||||||
r, _ := t.routeTree.Load().Lookup(ip)
|
|
||||||
return r
|
|
||||||
}
|
|
||||||
|
|
||||||
func (t *winTun) Cidr() netip.Prefix {
|
|
||||||
return t.cidr
|
|
||||||
}
|
|
||||||
|
|
||||||
func (t *winTun) Name() string {
|
|
||||||
return t.Device
|
|
||||||
}
|
|
||||||
|
|
||||||
func (t *winTun) Read(b []byte) (int, error) {
|
|
||||||
return t.tun.Read(b, 0)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (t *winTun) Write(b []byte) (int, error) {
|
|
||||||
return t.tun.Write(b, 0)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (t *winTun) NewMultiQueueReader() (io.ReadWriteCloser, error) {
|
|
||||||
return nil, fmt.Errorf("TODO: multiqueue not implemented for windows")
|
|
||||||
}
|
|
||||||
|
|
||||||
func (t *winTun) Close() error {
|
|
||||||
// It seems that the Windows networking stack doesn't like it when we destroy interfaces that have active routes,
|
|
||||||
// so to be certain, just remove everything before destroying.
|
|
||||||
luid := winipcfg.LUID(t.tun.LUID())
|
|
||||||
_ = luid.FlushRoutes(windows.AF_INET)
|
|
||||||
_ = luid.FlushIPAddresses(windows.AF_INET)
|
|
||||||
/* We don't support IPV6 yet
|
|
||||||
_ = luid.FlushRoutes(windows.AF_INET6)
|
|
||||||
_ = luid.FlushIPAddresses(windows.AF_INET6)
|
|
||||||
*/
|
|
||||||
_ = luid.FlushDNS(windows.AF_INET)
|
|
||||||
|
|
||||||
return t.tun.Close()
|
|
||||||
}
|
|
||||||
@ -8,16 +8,16 @@ import (
|
|||||||
"github.com/slackhq/nebula/config"
|
"github.com/slackhq/nebula/config"
|
||||||
)
|
)
|
||||||
|
|
||||||
func NewUserDeviceFromConfig(c *config.C, l *logrus.Logger, tunCidr netip.Prefix, routines int) (Device, error) {
|
func NewUserDeviceFromConfig(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, routines int) (Device, error) {
|
||||||
return NewUserDevice(tunCidr)
|
return NewUserDevice(vpnNetworks)
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewUserDevice(tunCidr netip.Prefix) (Device, error) {
|
func NewUserDevice(vpnNetworks []netip.Prefix) (Device, error) {
|
||||||
// these pipes guarantee each write/read will match 1:1
|
// these pipes guarantee each write/read will match 1:1
|
||||||
or, ow := io.Pipe()
|
or, ow := io.Pipe()
|
||||||
ir, iw := io.Pipe()
|
ir, iw := io.Pipe()
|
||||||
return &UserDevice{
|
return &UserDevice{
|
||||||
tunCidr: tunCidr,
|
vpnNetworks: vpnNetworks,
|
||||||
outboundReader: or,
|
outboundReader: or,
|
||||||
outboundWriter: ow,
|
outboundWriter: ow,
|
||||||
inboundReader: ir,
|
inboundReader: ir,
|
||||||
@ -26,7 +26,7 @@ func NewUserDevice(tunCidr netip.Prefix) (Device, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
type UserDevice struct {
|
type UserDevice struct {
|
||||||
tunCidr netip.Prefix
|
vpnNetworks []netip.Prefix
|
||||||
|
|
||||||
outboundReader *io.PipeReader
|
outboundReader *io.PipeReader
|
||||||
outboundWriter *io.PipeWriter
|
outboundWriter *io.PipeWriter
|
||||||
@ -38,7 +38,7 @@ type UserDevice struct {
|
|||||||
func (d *UserDevice) Activate() error {
|
func (d *UserDevice) Activate() error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
func (d *UserDevice) Cidr() netip.Prefix { return d.tunCidr }
|
func (d *UserDevice) Networks() []netip.Prefix { return d.vpnNetworks }
|
||||||
func (d *UserDevice) Name() string { return "faketun0" }
|
func (d *UserDevice) Name() string { return "faketun0" }
|
||||||
func (d *UserDevice) RouteFor(ip netip.Addr) netip.Addr { return ip }
|
func (d *UserDevice) RouteFor(ip netip.Addr) netip.Addr { return ip }
|
||||||
func (d *UserDevice) NewMultiQueueReader() (io.ReadWriteCloser, error) {
|
func (d *UserDevice) NewMultiQueueReader() (io.ReadWriteCloser, error) {
|
||||||
|
|||||||
409
pki.go
409
pki.go
@ -1,13 +1,19 @@
|
|||||||
package nebula
|
package nebula
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"encoding/binary"
|
||||||
|
"encoding/json"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"net"
|
||||||
|
"net/netip"
|
||||||
"os"
|
"os"
|
||||||
|
"slices"
|
||||||
"strings"
|
"strings"
|
||||||
"sync/atomic"
|
"sync/atomic"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/gaissmai/bart"
|
||||||
"github.com/sirupsen/logrus"
|
"github.com/sirupsen/logrus"
|
||||||
"github.com/slackhq/nebula/cert"
|
"github.com/slackhq/nebula/cert"
|
||||||
"github.com/slackhq/nebula/config"
|
"github.com/slackhq/nebula/config"
|
||||||
@ -21,12 +27,22 @@ type PKI struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
type CertState struct {
|
type CertState struct {
|
||||||
Certificate cert.Certificate
|
v1Cert cert.Certificate
|
||||||
RawCertificate []byte
|
v1HandshakeBytes []byte
|
||||||
RawCertificateNoKey []byte
|
|
||||||
PublicKey []byte
|
v2Cert cert.Certificate
|
||||||
PrivateKey []byte
|
v2HandshakeBytes []byte
|
||||||
|
|
||||||
|
defaultVersion cert.Version
|
||||||
|
privateKey []byte
|
||||||
pkcs11Backed bool
|
pkcs11Backed bool
|
||||||
|
cipher string
|
||||||
|
|
||||||
|
myVpnNetworks []netip.Prefix
|
||||||
|
myVpnNetworksTable *bart.Table[struct{}]
|
||||||
|
myVpnAddrs []netip.Addr
|
||||||
|
myVpnAddrsTable *bart.Table[struct{}]
|
||||||
|
myVpnBroadcastAddrsTable *bart.Table[struct{}]
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewPKIFromConfig(l *logrus.Logger, c *config.C) (*PKI, error) {
|
func NewPKIFromConfig(l *logrus.Logger, c *config.C) (*PKI, error) {
|
||||||
@ -46,16 +62,26 @@ func NewPKIFromConfig(l *logrus.Logger, c *config.C) (*PKI, error) {
|
|||||||
return pki, nil
|
return pki, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *PKI) GetCertState() *CertState {
|
|
||||||
return p.cs.Load()
|
|
||||||
}
|
|
||||||
|
|
||||||
func (p *PKI) GetCAPool() *cert.CAPool {
|
func (p *PKI) GetCAPool() *cert.CAPool {
|
||||||
return p.caPool.Load()
|
return p.caPool.Load()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (p *PKI) getCertState() *CertState {
|
||||||
|
return p.cs.Load()
|
||||||
|
}
|
||||||
|
|
||||||
|
// TODO: We should remove this
|
||||||
|
func (p *PKI) getDefaultCertificate() cert.Certificate {
|
||||||
|
return p.cs.Load().GetDefaultCertificate()
|
||||||
|
}
|
||||||
|
|
||||||
|
// TODO: We should remove this
|
||||||
|
func (p *PKI) getCertificate(v cert.Version) cert.Certificate {
|
||||||
|
return p.cs.Load().getCertificate(v)
|
||||||
|
}
|
||||||
|
|
||||||
func (p *PKI) reload(c *config.C, initial bool) error {
|
func (p *PKI) reload(c *config.C, initial bool) error {
|
||||||
err := p.reloadCert(c, initial)
|
err := p.reloadCerts(c, initial)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if initial {
|
if initial {
|
||||||
return err
|
return err
|
||||||
@ -74,33 +100,94 @@ func (p *PKI) reload(c *config.C, initial bool) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *PKI) reloadCert(c *config.C, initial bool) *util.ContextualError {
|
func (p *PKI) reloadCerts(c *config.C, initial bool) *util.ContextualError {
|
||||||
cs, err := newCertStateFromConfig(c)
|
newState, err := newCertStateFromConfig(c)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return util.NewContextualError("Could not load client cert", nil, err)
|
return util.NewContextualError("Could not load client cert", nil, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if !initial {
|
if !initial {
|
||||||
//TODO: include check for mask equality as well
|
currentState := p.cs.Load()
|
||||||
|
if newState.v1Cert != nil {
|
||||||
|
if currentState.v1Cert == nil {
|
||||||
|
return util.NewContextualError("v1 certificate was added, restart required", nil, err)
|
||||||
|
}
|
||||||
|
|
||||||
// did IP in cert change? if so, don't set
|
// did IP in cert change? if so, don't set
|
||||||
currentCert := p.cs.Load().Certificate
|
if !slices.Equal(currentState.v1Cert.Networks(), newState.v1Cert.Networks()) {
|
||||||
oldIPs := currentCert.Networks()
|
|
||||||
newIPs := cs.Certificate.Networks()
|
|
||||||
if len(oldIPs) > 0 && len(newIPs) > 0 && oldIPs[0].String() != newIPs[0].String() {
|
|
||||||
return util.NewContextualError(
|
return util.NewContextualError(
|
||||||
"Networks in new cert was different from old",
|
"Networks in new cert was different from old",
|
||||||
m{"new_network": newIPs[0], "old_network": oldIPs[0]},
|
m{"new_networks": newState.v1Cert.Networks(), "old_networks": currentState.v1Cert.Networks()},
|
||||||
|
nil,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
if currentState.v1Cert.Curve() != newState.v1Cert.Curve() {
|
||||||
|
return util.NewContextualError(
|
||||||
|
"Curve in new cert was different from old",
|
||||||
|
m{"new_curve": newState.v1Cert.Curve(), "old_curve": currentState.v1Cert.Curve()},
|
||||||
|
nil,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
} else if currentState.v1Cert != nil {
|
||||||
|
//TODO: we should be able to tear this down
|
||||||
|
return util.NewContextualError("v1 certificate was removed, restart required", nil, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if newState.v2Cert != nil {
|
||||||
|
if currentState.v2Cert == nil {
|
||||||
|
return util.NewContextualError("v2 certificate was added, restart required", nil, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// did IP in cert change? if so, don't set
|
||||||
|
if !slices.Equal(currentState.v2Cert.Networks(), newState.v2Cert.Networks()) {
|
||||||
|
return util.NewContextualError(
|
||||||
|
"Networks in new cert was different from old",
|
||||||
|
m{"new_networks": newState.v2Cert.Networks(), "old_networks": currentState.v2Cert.Networks()},
|
||||||
|
nil,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
if currentState.v2Cert.Curve() != newState.v2Cert.Curve() {
|
||||||
|
return util.NewContextualError(
|
||||||
|
"Curve in new cert was different from old",
|
||||||
|
m{"new_curve": newState.v2Cert.Curve(), "old_curve": currentState.v2Cert.Curve()},
|
||||||
|
nil,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
} else if currentState.v2Cert != nil {
|
||||||
|
return util.NewContextualError("v2 certificate was removed, restart required", nil, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Cipher cant be hot swapped so just leave it at what it was before
|
||||||
|
newState.cipher = currentState.cipher
|
||||||
|
|
||||||
|
} else {
|
||||||
|
newState.cipher = c.GetString("cipher", "aes")
|
||||||
|
//TODO: this sucks and we should make it not a global
|
||||||
|
switch newState.cipher {
|
||||||
|
case "aes":
|
||||||
|
noiseEndianness = binary.BigEndian
|
||||||
|
case "chachapoly":
|
||||||
|
noiseEndianness = binary.LittleEndian
|
||||||
|
default:
|
||||||
|
return util.NewContextualError(
|
||||||
|
"unknown cipher",
|
||||||
|
m{"cipher": newState.cipher},
|
||||||
nil,
|
nil,
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
p.cs.Store(cs)
|
p.cs.Store(newState)
|
||||||
|
|
||||||
|
//TODO: newState needs a stringer that does json
|
||||||
if initial {
|
if initial {
|
||||||
p.l.WithField("cert", cs.Certificate).Debug("Client nebula certificate")
|
p.l.WithField("cert", newState).Debug("Client nebula certificate(s)")
|
||||||
} else {
|
} else {
|
||||||
p.l.WithField("cert", cs.Certificate).Info("Client cert refreshed from disk")
|
p.l.WithField("cert", newState).Info("Client certificate(s) refreshed from disk")
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
@ -116,55 +203,65 @@ func (p *PKI) reloadCAPool(c *config.C) *util.ContextualError {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func newCertState(certificate cert.Certificate, pkcs11backed bool, privateKey []byte) (*CertState, error) {
|
func (cs *CertState) GetDefaultCertificate() cert.Certificate {
|
||||||
// Marshal the certificate to ensure it is valid
|
c := cs.getCertificate(cs.defaultVersion)
|
||||||
rawCertificate, err := certificate.Marshal()
|
if c == nil {
|
||||||
if err != nil {
|
panic("No default certificate found")
|
||||||
return nil, fmt.Errorf("invalid nebula certificate on interface: %s", err)
|
}
|
||||||
|
return c
|
||||||
}
|
}
|
||||||
|
|
||||||
publicKey := certificate.PublicKey()
|
func (cs *CertState) getCertificate(v cert.Version) cert.Certificate {
|
||||||
cs := &CertState{
|
switch v {
|
||||||
RawCertificate: rawCertificate,
|
case cert.Version1:
|
||||||
Certificate: certificate,
|
return cs.v1Cert
|
||||||
PrivateKey: privateKey,
|
case cert.Version2:
|
||||||
PublicKey: publicKey,
|
return cs.v2Cert
|
||||||
pkcs11Backed: pkcs11backed,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
rawCertNoKey, err := cs.Certificate.MarshalForHandshakes()
|
return nil
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("error marshalling certificate no key: %s", err)
|
|
||||||
}
|
|
||||||
cs.RawCertificateNoKey = rawCertNoKey
|
|
||||||
|
|
||||||
return cs, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func loadPrivateKey(privPathOrPEM string) (rawKey []byte, curve cert.Curve, isPkcs11 bool, err error) {
|
// getHandshakeBytes returns the cached bytes to be used in a handshake message for the requested version.
|
||||||
var pemPrivateKey []byte
|
// Callers must check if the return []byte is nil.
|
||||||
if strings.Contains(privPathOrPEM, "-----BEGIN") {
|
func (cs *CertState) getHandshakeBytes(v cert.Version) []byte {
|
||||||
pemPrivateKey = []byte(privPathOrPEM)
|
switch v {
|
||||||
privPathOrPEM = "<inline>"
|
case cert.Version1:
|
||||||
rawKey, _, curve, err = cert.UnmarshalPrivateKeyFromPEM(pemPrivateKey)
|
return cs.v1HandshakeBytes
|
||||||
if err != nil {
|
case cert.Version2:
|
||||||
return nil, curve, false, fmt.Errorf("error while unmarshaling pki.key %s: %s", privPathOrPEM, err)
|
return cs.v2HandshakeBytes
|
||||||
}
|
default:
|
||||||
} else if strings.HasPrefix(privPathOrPEM, "pkcs11:") {
|
return nil
|
||||||
rawKey = []byte(privPathOrPEM)
|
|
||||||
return rawKey, cert.Curve_P256, true, nil
|
|
||||||
} else {
|
|
||||||
pemPrivateKey, err = os.ReadFile(privPathOrPEM)
|
|
||||||
if err != nil {
|
|
||||||
return nil, curve, false, fmt.Errorf("unable to read pki.key file %s: %s", privPathOrPEM, err)
|
|
||||||
}
|
|
||||||
rawKey, _, curve, err = cert.UnmarshalPrivateKeyFromPEM(pemPrivateKey)
|
|
||||||
if err != nil {
|
|
||||||
return nil, curve, false, fmt.Errorf("error while unmarshaling pki.key %s: %s", privPathOrPEM, err)
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return
|
func (cs *CertState) String() string {
|
||||||
|
b, err := cs.MarshalJSON()
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Sprintf("error marshaling certificate state: %v", err)
|
||||||
|
}
|
||||||
|
return string(b)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (cs *CertState) MarshalJSON() ([]byte, error) {
|
||||||
|
msg := []json.RawMessage{}
|
||||||
|
if cs.v1Cert != nil {
|
||||||
|
b, err := cs.v1Cert.MarshalJSON()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
msg = append(msg, b)
|
||||||
|
}
|
||||||
|
|
||||||
|
if cs.v2Cert != nil {
|
||||||
|
b, err := cs.v2Cert.MarshalJSON()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
msg = append(msg, b)
|
||||||
|
}
|
||||||
|
|
||||||
|
return json.Marshal(msg)
|
||||||
}
|
}
|
||||||
|
|
||||||
func newCertStateFromConfig(c *config.C) (*CertState, error) {
|
func newCertStateFromConfig(c *config.C) (*CertState, error) {
|
||||||
@ -198,24 +295,198 @@ func newCertStateFromConfig(c *config.C) (*CertState, error) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
nebulaCert, _, err := cert.UnmarshalCertificateFromPEM(rawCert)
|
var crt, v1, v2 cert.Certificate
|
||||||
|
for {
|
||||||
|
// Load the certificate
|
||||||
|
crt, rawCert, err = loadCertificate(rawCert)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("error while unmarshaling pki.cert %s: %s", pubPathOrPEM, err)
|
//TODO: check error
|
||||||
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
if nebulaCert.Expired(time.Now()) {
|
switch crt.Version() {
|
||||||
return nil, fmt.Errorf("nebula certificate for this host is expired")
|
case cert.Version1:
|
||||||
|
if v1 != nil {
|
||||||
|
return nil, fmt.Errorf("v1 certificate already found in pki.cert")
|
||||||
|
}
|
||||||
|
v1 = crt
|
||||||
|
case cert.Version2:
|
||||||
|
if v2 != nil {
|
||||||
|
return nil, fmt.Errorf("v2 certificate already found in pki.cert")
|
||||||
|
}
|
||||||
|
v2 = crt
|
||||||
|
default:
|
||||||
|
return nil, fmt.Errorf("unknown certificate version %v", crt.Version())
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(nebulaCert.Networks()) == 0 {
|
if len(rawCert) == 0 || strings.TrimSpace(string(rawCert)) == "" {
|
||||||
return nil, fmt.Errorf("no networks encoded in certificate")
|
break
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if err = nebulaCert.VerifyPrivateKey(curve, rawKey); err != nil {
|
if v1 == nil && v2 == nil {
|
||||||
|
return nil, errors.New("no certificates found in pki.cert")
|
||||||
|
}
|
||||||
|
|
||||||
|
useDefaultVersion := uint32(1)
|
||||||
|
if v1 == nil {
|
||||||
|
// The only condition that requires v2 as the default is if only a v2 certificate is present
|
||||||
|
// We do this to avoid having to configure it specifically in the config file
|
||||||
|
useDefaultVersion = 2
|
||||||
|
}
|
||||||
|
|
||||||
|
rawDefaultVersion := c.GetUint32("pki.default_version", useDefaultVersion)
|
||||||
|
var defaultVersion cert.Version
|
||||||
|
switch rawDefaultVersion {
|
||||||
|
case 1:
|
||||||
|
if v1 == nil {
|
||||||
|
return nil, fmt.Errorf("can not use pki.default_version 1 without a v1 certificate in pki.cert")
|
||||||
|
}
|
||||||
|
defaultVersion = cert.Version1
|
||||||
|
case 2:
|
||||||
|
defaultVersion = cert.Version2
|
||||||
|
default:
|
||||||
|
return nil, fmt.Errorf("unknown pki.default_version: %v", rawDefaultVersion)
|
||||||
|
}
|
||||||
|
|
||||||
|
return newCertState(defaultVersion, v1, v2, isPkcs11, curve, rawKey)
|
||||||
|
}
|
||||||
|
|
||||||
|
func newCertState(dv cert.Version, v1, v2 cert.Certificate, pkcs11backed bool, privateKeyCurve cert.Curve, privateKey []byte) (*CertState, error) {
|
||||||
|
cs := CertState{
|
||||||
|
privateKey: privateKey,
|
||||||
|
pkcs11Backed: pkcs11backed,
|
||||||
|
myVpnNetworksTable: new(bart.Table[struct{}]),
|
||||||
|
myVpnAddrsTable: new(bart.Table[struct{}]),
|
||||||
|
myVpnBroadcastAddrsTable: new(bart.Table[struct{}]),
|
||||||
|
}
|
||||||
|
|
||||||
|
if v1 != nil && v2 != nil {
|
||||||
|
if !slices.Equal(v1.PublicKey(), v2.PublicKey()) {
|
||||||
|
return nil, util.NewContextualError("v1 and v2 public keys are not the same, ignoring", nil, nil)
|
||||||
|
}
|
||||||
|
|
||||||
|
if v1.Curve() != v2.Curve() {
|
||||||
|
return nil, util.NewContextualError("v1 and v2 curve are not the same, ignoring", nil, nil)
|
||||||
|
}
|
||||||
|
|
||||||
|
//TODO: make sure v2 has v1s address
|
||||||
|
|
||||||
|
cs.defaultVersion = dv
|
||||||
|
}
|
||||||
|
|
||||||
|
if v1 != nil {
|
||||||
|
if pkcs11backed {
|
||||||
|
//TODO: We do not currently have a method to verify a public private key pair when the private key is in an hsm
|
||||||
|
} else {
|
||||||
|
if err := v1.VerifyPrivateKey(privateKeyCurve, privateKey); err != nil {
|
||||||
return nil, fmt.Errorf("private key is not a pair with public key in nebula cert")
|
return nil, fmt.Errorf("private key is not a pair with public key in nebula cert")
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
return newCertState(nebulaCert, isPkcs11, rawKey)
|
v1hs, err := v1.MarshalForHandshakes()
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("error marshalling certificate for handshake: %w", err)
|
||||||
|
}
|
||||||
|
cs.v1Cert = v1
|
||||||
|
cs.v1HandshakeBytes = v1hs
|
||||||
|
|
||||||
|
if cs.defaultVersion == 0 {
|
||||||
|
cs.defaultVersion = cert.Version1
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if v2 != nil {
|
||||||
|
if pkcs11backed {
|
||||||
|
//TODO: We do not currently have a method to verify a public private key pair when the private key is in an hsm
|
||||||
|
} else {
|
||||||
|
if err := v2.VerifyPrivateKey(privateKeyCurve, privateKey); err != nil {
|
||||||
|
return nil, fmt.Errorf("private key is not a pair with public key in nebula cert")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
v2hs, err := v2.MarshalForHandshakes()
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("error marshalling certificate for handshake: %w", err)
|
||||||
|
}
|
||||||
|
cs.v2Cert = v2
|
||||||
|
cs.v2HandshakeBytes = v2hs
|
||||||
|
|
||||||
|
if cs.defaultVersion == 0 {
|
||||||
|
cs.defaultVersion = cert.Version2
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
var crt cert.Certificate
|
||||||
|
crt = cs.getCertificate(cert.Version2)
|
||||||
|
if crt == nil {
|
||||||
|
// v2 certificates are a superset, only look at v1 if its all we have
|
||||||
|
crt = cs.getCertificate(cert.Version1)
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, network := range crt.Networks() {
|
||||||
|
cs.myVpnNetworks = append(cs.myVpnNetworks, network)
|
||||||
|
cs.myVpnNetworksTable.Insert(network, struct{}{})
|
||||||
|
|
||||||
|
cs.myVpnAddrs = append(cs.myVpnAddrs, network.Addr())
|
||||||
|
cs.myVpnAddrsTable.Insert(netip.PrefixFrom(network.Addr(), network.Addr().BitLen()), struct{}{})
|
||||||
|
|
||||||
|
if network.Addr().Is4() {
|
||||||
|
addr := network.Masked().Addr().As4()
|
||||||
|
mask := net.CIDRMask(network.Bits(), network.Addr().BitLen())
|
||||||
|
binary.BigEndian.PutUint32(addr[:], binary.BigEndian.Uint32(addr[:])|^binary.BigEndian.Uint32(mask))
|
||||||
|
cs.myVpnBroadcastAddrsTable.Insert(netip.PrefixFrom(netip.AddrFrom4(addr), network.Addr().BitLen()), struct{}{})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return &cs, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func loadPrivateKey(privPathOrPEM string) (rawKey []byte, curve cert.Curve, isPkcs11 bool, err error) {
|
||||||
|
var pemPrivateKey []byte
|
||||||
|
if strings.Contains(privPathOrPEM, "-----BEGIN") {
|
||||||
|
pemPrivateKey = []byte(privPathOrPEM)
|
||||||
|
privPathOrPEM = "<inline>"
|
||||||
|
rawKey, _, curve, err = cert.UnmarshalPrivateKeyFromPEM(pemPrivateKey)
|
||||||
|
if err != nil {
|
||||||
|
return nil, curve, false, fmt.Errorf("error while unmarshaling pki.key %s: %s", privPathOrPEM, err)
|
||||||
|
}
|
||||||
|
} else if strings.HasPrefix(privPathOrPEM, "pkcs11:") {
|
||||||
|
rawKey = []byte(privPathOrPEM)
|
||||||
|
return rawKey, cert.Curve_P256, true, nil
|
||||||
|
} else {
|
||||||
|
pemPrivateKey, err = os.ReadFile(privPathOrPEM)
|
||||||
|
if err != nil {
|
||||||
|
return nil, curve, false, fmt.Errorf("unable to read pki.key file %s: %s", privPathOrPEM, err)
|
||||||
|
}
|
||||||
|
rawKey, _, curve, err = cert.UnmarshalPrivateKeyFromPEM(pemPrivateKey)
|
||||||
|
if err != nil {
|
||||||
|
return nil, curve, false, fmt.Errorf("error while unmarshaling pki.key %s: %s", privPathOrPEM, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func loadCertificate(b []byte) (cert.Certificate, []byte, error) {
|
||||||
|
c, b, err := cert.UnmarshalCertificateFromPEM(b)
|
||||||
|
if err != nil {
|
||||||
|
return nil, b, fmt.Errorf("error while unmarshaling pki.cert: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if c.Expired(time.Now()) {
|
||||||
|
return nil, b, fmt.Errorf("nebula certificate for this host is expired")
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(c.Networks()) == 0 {
|
||||||
|
return nil, b, fmt.Errorf("no networks encoded in certificate")
|
||||||
|
}
|
||||||
|
|
||||||
|
if c.IsCA() {
|
||||||
|
return nil, b, fmt.Errorf("host certificate is a CA certificate")
|
||||||
|
}
|
||||||
|
|
||||||
|
return c, b, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func loadCAPoolFromConfig(l *logrus.Logger, c *config.C) (*cert.CAPool, error) {
|
func loadCAPoolFromConfig(l *logrus.Logger, c *config.C) (*cert.CAPool, error) {
|
||||||
|
|||||||
185
relay_manager.go
185
relay_manager.go
@ -9,6 +9,7 @@ import (
|
|||||||
"sync/atomic"
|
"sync/atomic"
|
||||||
|
|
||||||
"github.com/sirupsen/logrus"
|
"github.com/sirupsen/logrus"
|
||||||
|
"github.com/slackhq/nebula/cert"
|
||||||
"github.com/slackhq/nebula/config"
|
"github.com/slackhq/nebula/config"
|
||||||
"github.com/slackhq/nebula/header"
|
"github.com/slackhq/nebula/header"
|
||||||
)
|
)
|
||||||
@ -72,7 +73,7 @@ func AddRelay(l *logrus.Logger, relayHostInfo *HostInfo, hm *HostMap, vpnIp neti
|
|||||||
Type: relayType,
|
Type: relayType,
|
||||||
State: state,
|
State: state,
|
||||||
LocalIndex: index,
|
LocalIndex: index,
|
||||||
PeerIp: vpnIp,
|
PeerAddr: vpnIp,
|
||||||
}
|
}
|
||||||
|
|
||||||
if remoteIdx != nil {
|
if remoteIdx != nil {
|
||||||
@ -91,40 +92,60 @@ func AddRelay(l *logrus.Logger, relayHostInfo *HostInfo, hm *HostMap, vpnIp neti
|
|||||||
func (rm *relayManager) EstablishRelay(relayHostInfo *HostInfo, m *NebulaControl) (*Relay, error) {
|
func (rm *relayManager) EstablishRelay(relayHostInfo *HostInfo, m *NebulaControl) (*Relay, error) {
|
||||||
relay, ok := relayHostInfo.relayState.CompleteRelayByIdx(m.InitiatorRelayIndex, m.ResponderRelayIndex)
|
relay, ok := relayHostInfo.relayState.CompleteRelayByIdx(m.InitiatorRelayIndex, m.ResponderRelayIndex)
|
||||||
if !ok {
|
if !ok {
|
||||||
rm.l.WithFields(logrus.Fields{"relay": relayHostInfo.vpnIp,
|
//TODO: we need to handle possibly logging deprecated fields as well
|
||||||
|
rm.l.WithFields(logrus.Fields{"relay": relayHostInfo.vpnAddrs[0],
|
||||||
"initiatorRelayIndex": m.InitiatorRelayIndex,
|
"initiatorRelayIndex": m.InitiatorRelayIndex,
|
||||||
"relayFrom": m.RelayFromIp,
|
"relayFrom": m.RelayFromAddr,
|
||||||
"relayTo": m.RelayToIp}).Info("relayManager failed to update relay")
|
"relayTo": m.RelayToAddr}).Info("relayManager failed to update relay")
|
||||||
return nil, fmt.Errorf("unknown relay")
|
return nil, fmt.Errorf("unknown relay")
|
||||||
}
|
}
|
||||||
|
|
||||||
return relay, nil
|
return relay, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (rm *relayManager) HandleControlMsg(h *HostInfo, m *NebulaControl, f *Interface) {
|
func (rm *relayManager) HandleControlMsg(h *HostInfo, d []byte, f *Interface) {
|
||||||
|
msg := &NebulaControl{}
|
||||||
|
err := msg.Unmarshal(d)
|
||||||
|
if err != nil {
|
||||||
|
h.logger(f.l).WithError(err).Error("Failed to unmarshal control message")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
switch m.Type {
|
var v cert.Version
|
||||||
|
if msg.OldRelayFromAddr > 0 || msg.OldRelayToAddr > 0 {
|
||||||
|
v = cert.Version1
|
||||||
|
|
||||||
|
//TODO: yeah this is junk but maybe its less junky than the other options
|
||||||
|
b := [4]byte{}
|
||||||
|
binary.BigEndian.PutUint32(b[:], msg.OldRelayFromAddr)
|
||||||
|
msg.RelayFromAddr = netAddrToProtoAddr(netip.AddrFrom4(b))
|
||||||
|
|
||||||
|
binary.BigEndian.PutUint32(b[:], msg.OldRelayToAddr)
|
||||||
|
msg.RelayToAddr = netAddrToProtoAddr(netip.AddrFrom4(b))
|
||||||
|
} else {
|
||||||
|
v = cert.Version2
|
||||||
|
}
|
||||||
|
|
||||||
|
switch msg.Type {
|
||||||
case NebulaControl_CreateRelayRequest:
|
case NebulaControl_CreateRelayRequest:
|
||||||
rm.handleCreateRelayRequest(h, f, m)
|
rm.handleCreateRelayRequest(v, h, f, msg)
|
||||||
case NebulaControl_CreateRelayResponse:
|
case NebulaControl_CreateRelayResponse:
|
||||||
rm.handleCreateRelayResponse(h, f, m)
|
rm.handleCreateRelayResponse(v, h, f, msg)
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (rm *relayManager) handleCreateRelayResponse(h *HostInfo, f *Interface, m *NebulaControl) {
|
func (rm *relayManager) handleCreateRelayResponse(v cert.Version, h *HostInfo, f *Interface, m *NebulaControl) {
|
||||||
rm.l.WithFields(logrus.Fields{
|
rm.l.WithFields(logrus.Fields{
|
||||||
"relayFrom": m.RelayFromIp,
|
"relayFrom": m.RelayFromAddr,
|
||||||
"relayTo": m.RelayToIp,
|
"relayTo": m.RelayToAddr,
|
||||||
"initiatorRelayIndex": m.InitiatorRelayIndex,
|
"initiatorRelayIndex": m.InitiatorRelayIndex,
|
||||||
"responderRelayIndex": m.ResponderRelayIndex,
|
"responderRelayIndex": m.ResponderRelayIndex,
|
||||||
"vpnIp": h.vpnIp}).
|
"vpnAddrs": h.vpnAddrs}).
|
||||||
Info("handleCreateRelayResponse")
|
Info("handleCreateRelayResponse")
|
||||||
target := m.RelayToIp
|
|
||||||
//TODO: IPV6-WORK
|
target := m.RelayToAddr
|
||||||
b := [4]byte{}
|
targetAddr := protoAddrToNetAddr(target)
|
||||||
binary.BigEndian.PutUint32(b[:], m.RelayToIp)
|
|
||||||
targetAddr := netip.AddrFrom4(b)
|
|
||||||
|
|
||||||
relay, err := rm.EstablishRelay(h, m)
|
relay, err := rm.EstablishRelay(h, m)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -136,68 +157,79 @@ func (rm *relayManager) handleCreateRelayResponse(h *HostInfo, f *Interface, m *
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
// I'm the middle man. Let the initiator know that the I've established the relay they requested.
|
// I'm the middle man. Let the initiator know that the I've established the relay they requested.
|
||||||
peerHostInfo := rm.hostmap.QueryVpnIp(relay.PeerIp)
|
peerHostInfo := rm.hostmap.QueryVpnAddr(relay.PeerAddr)
|
||||||
if peerHostInfo == nil {
|
if peerHostInfo == nil {
|
||||||
rm.l.WithField("relayTo", relay.PeerIp).Error("Can't find a HostInfo for peer")
|
rm.l.WithField("relayTo", relay.PeerAddr).Error("Can't find a HostInfo for peer")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
peerRelay, ok := peerHostInfo.relayState.QueryRelayForByIp(targetAddr)
|
peerRelay, ok := peerHostInfo.relayState.QueryRelayForByIp(targetAddr)
|
||||||
if !ok {
|
if !ok {
|
||||||
rm.l.WithField("relayTo", peerHostInfo.vpnIp).Error("peerRelay does not have Relay state for relayTo")
|
rm.l.WithField("relayTo", peerHostInfo.vpnAddrs[0]).Error("peerRelay does not have Relay state for relayTo")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if peerRelay.State == PeerRequested {
|
if peerRelay.State == PeerRequested {
|
||||||
//TODO: IPV6-WORK
|
|
||||||
b = peerHostInfo.vpnIp.As4()
|
|
||||||
peerRelay.State = Established
|
peerRelay.State = Established
|
||||||
resp := NebulaControl{
|
resp := NebulaControl{
|
||||||
Type: NebulaControl_CreateRelayResponse,
|
Type: NebulaControl_CreateRelayResponse,
|
||||||
ResponderRelayIndex: peerRelay.LocalIndex,
|
ResponderRelayIndex: peerRelay.LocalIndex,
|
||||||
InitiatorRelayIndex: peerRelay.RemoteIndex,
|
InitiatorRelayIndex: peerRelay.RemoteIndex,
|
||||||
RelayFromIp: binary.BigEndian.Uint32(b[:]),
|
|
||||||
RelayToIp: uint32(target),
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if v == cert.Version1 {
|
||||||
|
peer := peerHostInfo.vpnAddrs[0]
|
||||||
|
if !peer.Is4() {
|
||||||
|
//TODO: log cant do it
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
b := peer.As4()
|
||||||
|
resp.OldRelayFromAddr = binary.BigEndian.Uint32(b[:])
|
||||||
|
b = targetAddr.As4()
|
||||||
|
resp.OldRelayToAddr = binary.BigEndian.Uint32(b[:])
|
||||||
|
} else {
|
||||||
|
resp.RelayFromAddr = netAddrToProtoAddr(peerHostInfo.vpnAddrs[0])
|
||||||
|
resp.RelayToAddr = target
|
||||||
|
}
|
||||||
|
|
||||||
msg, err := resp.Marshal()
|
msg, err := resp.Marshal()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
rm.l.
|
rm.l.WithError(err).
|
||||||
WithError(err).Error("relayManager Failed to marshal Control CreateRelayResponse message to create relay")
|
Error("relayManager Failed to marshal Control CreateRelayResponse message to create relay")
|
||||||
} else {
|
} else {
|
||||||
f.SendMessageToHostInfo(header.Control, 0, peerHostInfo, msg, make([]byte, 12), make([]byte, mtu))
|
f.SendMessageToHostInfo(header.Control, 0, peerHostInfo, msg, make([]byte, 12), make([]byte, mtu))
|
||||||
rm.l.WithFields(logrus.Fields{
|
rm.l.WithFields(logrus.Fields{
|
||||||
"relayFrom": resp.RelayFromIp,
|
"relayFrom": resp.RelayFromAddr,
|
||||||
"relayTo": resp.RelayToIp,
|
"relayTo": resp.RelayToAddr,
|
||||||
"initiatorRelayIndex": resp.InitiatorRelayIndex,
|
"initiatorRelayIndex": resp.InitiatorRelayIndex,
|
||||||
"responderRelayIndex": resp.ResponderRelayIndex,
|
"responderRelayIndex": resp.ResponderRelayIndex,
|
||||||
"vpnIp": peerHostInfo.vpnIp}).
|
"vpnAddrs": peerHostInfo.vpnAddrs}).
|
||||||
Info("send CreateRelayResponse")
|
Info("send CreateRelayResponse")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (rm *relayManager) handleCreateRelayRequest(h *HostInfo, f *Interface, m *NebulaControl) {
|
func (rm *relayManager) handleCreateRelayRequest(v cert.Version, h *HostInfo, f *Interface, m *NebulaControl) {
|
||||||
//TODO: IPV6-WORK
|
from := protoAddrToNetAddr(m.RelayFromAddr)
|
||||||
b := [4]byte{}
|
target := protoAddrToNetAddr(m.RelayToAddr)
|
||||||
binary.BigEndian.PutUint32(b[:], m.RelayFromIp)
|
|
||||||
from := netip.AddrFrom4(b)
|
|
||||||
|
|
||||||
binary.BigEndian.PutUint32(b[:], m.RelayToIp)
|
|
||||||
target := netip.AddrFrom4(b)
|
|
||||||
|
|
||||||
logMsg := rm.l.WithFields(logrus.Fields{
|
logMsg := rm.l.WithFields(logrus.Fields{
|
||||||
"relayFrom": from,
|
"relayFrom": from,
|
||||||
"relayTo": target,
|
"relayTo": target,
|
||||||
"initiatorRelayIndex": m.InitiatorRelayIndex,
|
"initiatorRelayIndex": m.InitiatorRelayIndex,
|
||||||
"vpnIp": h.vpnIp})
|
"vpnAddrs": h.vpnAddrs})
|
||||||
|
|
||||||
logMsg.Info("handleCreateRelayRequest")
|
logMsg.Info("handleCreateRelayRequest")
|
||||||
// Is the source of the relay me? This should never happen, but did happen due to
|
// Is the source of the relay me? This should never happen, but did happen due to
|
||||||
// an issue migrating relays over to newly re-handshaked host info objects.
|
// an issue migrating relays over to newly re-handshaked host info objects.
|
||||||
if from == f.myVpnNet.Addr() {
|
_, found := f.myVpnAddrsTable.Lookup(from)
|
||||||
|
if found {
|
||||||
logMsg.WithField("myIP", from).Error("Discarding relay request from myself")
|
logMsg.WithField("myIP", from).Error("Discarding relay request from myself")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// Is the target of the relay me?
|
// Is the target of the relay me?
|
||||||
if target == f.myVpnNet.Addr() {
|
_, found = f.myVpnAddrsTable.Lookup(target)
|
||||||
|
if found {
|
||||||
existingRelay, ok := h.relayState.QueryRelayForByIp(from)
|
existingRelay, ok := h.relayState.QueryRelayForByIp(from)
|
||||||
if ok {
|
if ok {
|
||||||
switch existingRelay.State {
|
switch existingRelay.State {
|
||||||
@ -230,17 +262,22 @@ func (rm *relayManager) handleCreateRelayRequest(h *HostInfo, f *Interface, m *N
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
//TODO: IPV6-WORK
|
|
||||||
fromB := from.As4()
|
|
||||||
targetB := target.As4()
|
|
||||||
|
|
||||||
resp := NebulaControl{
|
resp := NebulaControl{
|
||||||
Type: NebulaControl_CreateRelayResponse,
|
Type: NebulaControl_CreateRelayResponse,
|
||||||
ResponderRelayIndex: relay.LocalIndex,
|
ResponderRelayIndex: relay.LocalIndex,
|
||||||
InitiatorRelayIndex: relay.RemoteIndex,
|
InitiatorRelayIndex: relay.RemoteIndex,
|
||||||
RelayFromIp: binary.BigEndian.Uint32(fromB[:]),
|
|
||||||
RelayToIp: binary.BigEndian.Uint32(targetB[:]),
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if v == cert.Version1 {
|
||||||
|
b := from.As4()
|
||||||
|
resp.OldRelayFromAddr = binary.BigEndian.Uint32(b[:])
|
||||||
|
b = target.As4()
|
||||||
|
resp.OldRelayToAddr = binary.BigEndian.Uint32(b[:])
|
||||||
|
} else {
|
||||||
|
resp.RelayFromAddr = netAddrToProtoAddr(from)
|
||||||
|
resp.RelayToAddr = netAddrToProtoAddr(target)
|
||||||
|
}
|
||||||
|
|
||||||
msg, err := resp.Marshal()
|
msg, err := resp.Marshal()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logMsg.
|
logMsg.
|
||||||
@ -253,7 +290,7 @@ func (rm *relayManager) handleCreateRelayRequest(h *HostInfo, f *Interface, m *N
|
|||||||
"relayTo": target,
|
"relayTo": target,
|
||||||
"initiatorRelayIndex": resp.InitiatorRelayIndex,
|
"initiatorRelayIndex": resp.InitiatorRelayIndex,
|
||||||
"responderRelayIndex": resp.ResponderRelayIndex,
|
"responderRelayIndex": resp.ResponderRelayIndex,
|
||||||
"vpnIp": h.vpnIp}).
|
"vpnAddrs": h.vpnAddrs}).
|
||||||
Info("send CreateRelayResponse")
|
Info("send CreateRelayResponse")
|
||||||
}
|
}
|
||||||
return
|
return
|
||||||
@ -262,7 +299,7 @@ func (rm *relayManager) handleCreateRelayRequest(h *HostInfo, f *Interface, m *N
|
|||||||
if !rm.GetAmRelay() {
|
if !rm.GetAmRelay() {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
peer := rm.hostmap.QueryVpnIp(target)
|
peer := rm.hostmap.QueryVpnAddr(target)
|
||||||
if peer == nil {
|
if peer == nil {
|
||||||
// Try to establish a connection to this host. If we get a future relay request,
|
// Try to establish a connection to this host. If we get a future relay request,
|
||||||
// we'll be ready!
|
// we'll be ready!
|
||||||
@ -291,17 +328,27 @@ func (rm *relayManager) handleCreateRelayRequest(h *HostInfo, f *Interface, m *N
|
|||||||
sendCreateRequest = true
|
sendCreateRequest = true
|
||||||
}
|
}
|
||||||
if sendCreateRequest {
|
if sendCreateRequest {
|
||||||
//TODO: IPV6-WORK
|
|
||||||
fromB := h.vpnIp.As4()
|
|
||||||
targetB := target.As4()
|
|
||||||
|
|
||||||
// Send a CreateRelayRequest to the peer.
|
// Send a CreateRelayRequest to the peer.
|
||||||
req := NebulaControl{
|
req := NebulaControl{
|
||||||
Type: NebulaControl_CreateRelayRequest,
|
Type: NebulaControl_CreateRelayRequest,
|
||||||
InitiatorRelayIndex: index,
|
InitiatorRelayIndex: index,
|
||||||
RelayFromIp: binary.BigEndian.Uint32(fromB[:]),
|
|
||||||
RelayToIp: binary.BigEndian.Uint32(targetB[:]),
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if v == cert.Version1 {
|
||||||
|
if !h.vpnAddrs[0].Is4() {
|
||||||
|
//TODO: log it
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
b := h.vpnAddrs[0].As4()
|
||||||
|
req.OldRelayFromAddr = binary.BigEndian.Uint32(b[:])
|
||||||
|
b = target.As4()
|
||||||
|
req.OldRelayToAddr = binary.BigEndian.Uint32(b[:])
|
||||||
|
} else {
|
||||||
|
req.RelayFromAddr = netAddrToProtoAddr(h.vpnAddrs[0])
|
||||||
|
req.RelayToAddr = netAddrToProtoAddr(target)
|
||||||
|
}
|
||||||
|
|
||||||
msg, err := req.Marshal()
|
msg, err := req.Marshal()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logMsg.
|
logMsg.
|
||||||
@ -310,11 +357,11 @@ func (rm *relayManager) handleCreateRelayRequest(h *HostInfo, f *Interface, m *N
|
|||||||
f.SendMessageToHostInfo(header.Control, 0, peer, msg, make([]byte, 12), make([]byte, mtu))
|
f.SendMessageToHostInfo(header.Control, 0, peer, msg, make([]byte, 12), make([]byte, mtu))
|
||||||
rm.l.WithFields(logrus.Fields{
|
rm.l.WithFields(logrus.Fields{
|
||||||
//TODO: IPV6-WORK another lazy used to use the req object
|
//TODO: IPV6-WORK another lazy used to use the req object
|
||||||
"relayFrom": h.vpnIp,
|
"relayFrom": h.vpnAddrs[0],
|
||||||
"relayTo": target,
|
"relayTo": target,
|
||||||
"initiatorRelayIndex": req.InitiatorRelayIndex,
|
"initiatorRelayIndex": req.InitiatorRelayIndex,
|
||||||
"responderRelayIndex": req.ResponderRelayIndex,
|
"responderRelayIndex": req.ResponderRelayIndex,
|
||||||
"vpnIp": target}).
|
"vpnAddr": target}).
|
||||||
Info("send CreateRelayRequest")
|
Info("send CreateRelayRequest")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -342,16 +389,28 @@ func (rm *relayManager) handleCreateRelayRequest(h *HostInfo, f *Interface, m *N
|
|||||||
"existingRemoteIndex": relay.RemoteIndex}).Error("Existing relay mismatch with CreateRelayRequest")
|
"existingRemoteIndex": relay.RemoteIndex}).Error("Existing relay mismatch with CreateRelayRequest")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
//TODO: IPV6-WORK
|
|
||||||
fromB := h.vpnIp.As4()
|
|
||||||
targetB := target.As4()
|
|
||||||
resp := NebulaControl{
|
resp := NebulaControl{
|
||||||
Type: NebulaControl_CreateRelayResponse,
|
Type: NebulaControl_CreateRelayResponse,
|
||||||
ResponderRelayIndex: relay.LocalIndex,
|
ResponderRelayIndex: relay.LocalIndex,
|
||||||
InitiatorRelayIndex: relay.RemoteIndex,
|
InitiatorRelayIndex: relay.RemoteIndex,
|
||||||
RelayFromIp: binary.BigEndian.Uint32(fromB[:]),
|
|
||||||
RelayToIp: binary.BigEndian.Uint32(targetB[:]),
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if v == cert.Version1 {
|
||||||
|
if !h.vpnAddrs[0].Is4() {
|
||||||
|
//TODO: log it
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
b := h.vpnAddrs[0].As4()
|
||||||
|
resp.OldRelayFromAddr = binary.BigEndian.Uint32(b[:])
|
||||||
|
b = target.As4()
|
||||||
|
resp.OldRelayToAddr = binary.BigEndian.Uint32(b[:])
|
||||||
|
} else {
|
||||||
|
resp.RelayFromAddr = netAddrToProtoAddr(h.vpnAddrs[0])
|
||||||
|
resp.RelayToAddr = netAddrToProtoAddr(target)
|
||||||
|
}
|
||||||
|
|
||||||
msg, err := resp.Marshal()
|
msg, err := resp.Marshal()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
rm.l.
|
rm.l.
|
||||||
@ -360,11 +419,11 @@ func (rm *relayManager) handleCreateRelayRequest(h *HostInfo, f *Interface, m *N
|
|||||||
f.SendMessageToHostInfo(header.Control, 0, h, msg, make([]byte, 12), make([]byte, mtu))
|
f.SendMessageToHostInfo(header.Control, 0, h, msg, make([]byte, 12), make([]byte, mtu))
|
||||||
rm.l.WithFields(logrus.Fields{
|
rm.l.WithFields(logrus.Fields{
|
||||||
//TODO: IPV6-WORK more lazy, used to use resp object
|
//TODO: IPV6-WORK more lazy, used to use resp object
|
||||||
"relayFrom": h.vpnIp,
|
"relayFrom": h.vpnAddrs[0],
|
||||||
"relayTo": target,
|
"relayTo": target,
|
||||||
"initiatorRelayIndex": resp.InitiatorRelayIndex,
|
"initiatorRelayIndex": resp.InitiatorRelayIndex,
|
||||||
"responderRelayIndex": resp.ResponderRelayIndex,
|
"responderRelayIndex": resp.ResponderRelayIndex,
|
||||||
"vpnIp": h.vpnIp}).
|
"vpnAddrs": h.vpnAddrs}).
|
||||||
Info("send CreateRelayResponse")
|
Info("send CreateRelayResponse")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@ -17,8 +17,8 @@ import (
|
|||||||
type forEachFunc func(addr netip.AddrPort, preferred bool)
|
type forEachFunc func(addr netip.AddrPort, preferred bool)
|
||||||
|
|
||||||
// The checkFuncs here are to simplify bulk importing LH query response logic into a single function (reset slice and iterate)
|
// The checkFuncs here are to simplify bulk importing LH query response logic into a single function (reset slice and iterate)
|
||||||
type checkFuncV4 func(vpnIp netip.Addr, to *Ip4AndPort) bool
|
type checkFuncV4 func(vpnIp netip.Addr, to *V4AddrPort) bool
|
||||||
type checkFuncV6 func(vpnIp netip.Addr, to *Ip6AndPort) bool
|
type checkFuncV6 func(vpnIp netip.Addr, to *V6AddrPort) bool
|
||||||
|
|
||||||
// CacheMap is a struct that better represents the lighthouse cache for humans
|
// CacheMap is a struct that better represents the lighthouse cache for humans
|
||||||
// The string key is the owners vpnIp
|
// The string key is the owners vpnIp
|
||||||
@ -48,14 +48,14 @@ type cacheRelay struct {
|
|||||||
|
|
||||||
// cacheV4 stores learned and reported ipv4 records under cache
|
// cacheV4 stores learned and reported ipv4 records under cache
|
||||||
type cacheV4 struct {
|
type cacheV4 struct {
|
||||||
learned *Ip4AndPort
|
learned *V4AddrPort
|
||||||
reported []*Ip4AndPort
|
reported []*V4AddrPort
|
||||||
}
|
}
|
||||||
|
|
||||||
// cacheV4 stores learned and reported ipv6 records under cache
|
// cacheV4 stores learned and reported ipv6 records under cache
|
||||||
type cacheV6 struct {
|
type cacheV6 struct {
|
||||||
learned *Ip6AndPort
|
learned *V6AddrPort
|
||||||
reported []*Ip6AndPort
|
reported []*V6AddrPort
|
||||||
}
|
}
|
||||||
|
|
||||||
type hostnamePort struct {
|
type hostnamePort struct {
|
||||||
@ -170,7 +170,7 @@ func (hr *hostnamesResults) Cancel() {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (hr *hostnamesResults) GetIPs() []netip.AddrPort {
|
func (hr *hostnamesResults) GetAddrs() []netip.AddrPort {
|
||||||
var retSlice []netip.AddrPort
|
var retSlice []netip.AddrPort
|
||||||
if hr != nil {
|
if hr != nil {
|
||||||
p := hr.ips.Load()
|
p := hr.ips.Load()
|
||||||
@ -189,6 +189,9 @@ type RemoteList struct {
|
|||||||
// Every interaction with internals requires a lock!
|
// Every interaction with internals requires a lock!
|
||||||
sync.RWMutex
|
sync.RWMutex
|
||||||
|
|
||||||
|
// The full list of vpn addresses assigned to this host
|
||||||
|
vpnAddrs []netip.Addr
|
||||||
|
|
||||||
// A deduplicated set of addresses. Any accessor should lock beforehand.
|
// A deduplicated set of addresses. Any accessor should lock beforehand.
|
||||||
addrs []netip.AddrPort
|
addrs []netip.AddrPort
|
||||||
|
|
||||||
@ -212,13 +215,16 @@ type RemoteList struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// NewRemoteList creates a new empty RemoteList
|
// NewRemoteList creates a new empty RemoteList
|
||||||
func NewRemoteList(shouldAdd func(netip.Addr) bool) *RemoteList {
|
func NewRemoteList(vpnAddrs []netip.Addr, shouldAdd func(netip.Addr) bool) *RemoteList {
|
||||||
return &RemoteList{
|
r := &RemoteList{
|
||||||
|
vpnAddrs: make([]netip.Addr, len(vpnAddrs)),
|
||||||
addrs: make([]netip.AddrPort, 0),
|
addrs: make([]netip.AddrPort, 0),
|
||||||
relays: make([]netip.Addr, 0),
|
relays: make([]netip.Addr, 0),
|
||||||
cache: make(map[netip.Addr]*cache),
|
cache: make(map[netip.Addr]*cache),
|
||||||
shouldAdd: shouldAdd,
|
shouldAdd: shouldAdd,
|
||||||
}
|
}
|
||||||
|
copy(r.vpnAddrs, vpnAddrs)
|
||||||
|
return r
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *RemoteList) unlockedSetHostnamesResults(hr *hostnamesResults) {
|
func (r *RemoteList) unlockedSetHostnamesResults(hr *hostnamesResults) {
|
||||||
@ -273,9 +279,9 @@ func (r *RemoteList) LearnRemote(ownerVpnIp netip.Addr, remote netip.AddrPort) {
|
|||||||
r.Lock()
|
r.Lock()
|
||||||
defer r.Unlock()
|
defer r.Unlock()
|
||||||
if remote.Addr().Is4() {
|
if remote.Addr().Is4() {
|
||||||
r.unlockedSetLearnedV4(ownerVpnIp, NewIp4AndPortFromNetIP(remote.Addr(), remote.Port()))
|
r.unlockedSetLearnedV4(ownerVpnIp, netAddrToProtoV4AddrPort(remote.Addr(), remote.Port()))
|
||||||
} else {
|
} else {
|
||||||
r.unlockedSetLearnedV6(ownerVpnIp, NewIp6AndPortFromNetIP(remote.Addr(), remote.Port()))
|
r.unlockedSetLearnedV6(ownerVpnIp, netAddrToProtoV6AddrPort(remote.Addr(), remote.Port()))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -304,21 +310,21 @@ func (r *RemoteList) CopyCache() *CacheMap {
|
|||||||
|
|
||||||
if mc.v4 != nil {
|
if mc.v4 != nil {
|
||||||
if mc.v4.learned != nil {
|
if mc.v4.learned != nil {
|
||||||
c.Learned = append(c.Learned, AddrPortFromIp4AndPort(mc.v4.learned))
|
c.Learned = append(c.Learned, protoV4AddrPortToNetAddrPort(mc.v4.learned))
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, a := range mc.v4.reported {
|
for _, a := range mc.v4.reported {
|
||||||
c.Reported = append(c.Reported, AddrPortFromIp4AndPort(a))
|
c.Reported = append(c.Reported, protoV4AddrPortToNetAddrPort(a))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if mc.v6 != nil {
|
if mc.v6 != nil {
|
||||||
if mc.v6.learned != nil {
|
if mc.v6.learned != nil {
|
||||||
c.Learned = append(c.Learned, AddrPortFromIp6AndPort(mc.v6.learned))
|
c.Learned = append(c.Learned, protoV6AddrPortToNetAddrPort(mc.v6.learned))
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, a := range mc.v6.reported {
|
for _, a := range mc.v6.reported {
|
||||||
c.Reported = append(c.Reported, AddrPortFromIp6AndPort(a))
|
c.Reported = append(c.Reported, protoV6AddrPortToNetAddrPort(a))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -401,14 +407,14 @@ func (r *RemoteList) unlockedIsBad(remote netip.AddrPort) bool {
|
|||||||
|
|
||||||
// unlockedSetLearnedV4 assumes you have the write lock and sets the current learned address for this owner and marks the
|
// unlockedSetLearnedV4 assumes you have the write lock and sets the current learned address for this owner and marks the
|
||||||
// deduplicated address list as dirty
|
// deduplicated address list as dirty
|
||||||
func (r *RemoteList) unlockedSetLearnedV4(ownerVpnIp netip.Addr, to *Ip4AndPort) {
|
func (r *RemoteList) unlockedSetLearnedV4(ownerVpnIp netip.Addr, to *V4AddrPort) {
|
||||||
r.shouldRebuild = true
|
r.shouldRebuild = true
|
||||||
r.unlockedGetOrMakeV4(ownerVpnIp).learned = to
|
r.unlockedGetOrMakeV4(ownerVpnIp).learned = to
|
||||||
}
|
}
|
||||||
|
|
||||||
// unlockedSetV4 assumes you have the write lock and resets the reported list of ips for this owner to the list provided
|
// unlockedSetV4 assumes you have the write lock and resets the reported list of ips for this owner to the list provided
|
||||||
// and marks the deduplicated address list as dirty
|
// and marks the deduplicated address list as dirty
|
||||||
func (r *RemoteList) unlockedSetV4(ownerVpnIp, vpnIp netip.Addr, to []*Ip4AndPort, check checkFuncV4) {
|
func (r *RemoteList) unlockedSetV4(ownerVpnIp, vpnIp netip.Addr, to []*V4AddrPort, check checkFuncV4) {
|
||||||
r.shouldRebuild = true
|
r.shouldRebuild = true
|
||||||
c := r.unlockedGetOrMakeV4(ownerVpnIp)
|
c := r.unlockedGetOrMakeV4(ownerVpnIp)
|
||||||
|
|
||||||
@ -436,12 +442,12 @@ func (r *RemoteList) unlockedSetRelay(ownerVpnIp, vpnIp netip.Addr, to []netip.A
|
|||||||
|
|
||||||
// unlockedPrependV4 assumes you have the write lock and prepends the address in the reported list for this owner
|
// unlockedPrependV4 assumes you have the write lock and prepends the address in the reported list for this owner
|
||||||
// This is only useful for establishing static hosts
|
// This is only useful for establishing static hosts
|
||||||
func (r *RemoteList) unlockedPrependV4(ownerVpnIp netip.Addr, to *Ip4AndPort) {
|
func (r *RemoteList) unlockedPrependV4(ownerVpnIp netip.Addr, to *V4AddrPort) {
|
||||||
r.shouldRebuild = true
|
r.shouldRebuild = true
|
||||||
c := r.unlockedGetOrMakeV4(ownerVpnIp)
|
c := r.unlockedGetOrMakeV4(ownerVpnIp)
|
||||||
|
|
||||||
// We are doing the easy append because this is rarely called
|
// We are doing the easy append because this is rarely called
|
||||||
c.reported = append([]*Ip4AndPort{to}, c.reported...)
|
c.reported = append([]*V4AddrPort{to}, c.reported...)
|
||||||
if len(c.reported) > MaxRemotes {
|
if len(c.reported) > MaxRemotes {
|
||||||
c.reported = c.reported[:MaxRemotes]
|
c.reported = c.reported[:MaxRemotes]
|
||||||
}
|
}
|
||||||
@ -449,14 +455,14 @@ func (r *RemoteList) unlockedPrependV4(ownerVpnIp netip.Addr, to *Ip4AndPort) {
|
|||||||
|
|
||||||
// unlockedSetLearnedV6 assumes you have the write lock and sets the current learned address for this owner and marks the
|
// unlockedSetLearnedV6 assumes you have the write lock and sets the current learned address for this owner and marks the
|
||||||
// deduplicated address list as dirty
|
// deduplicated address list as dirty
|
||||||
func (r *RemoteList) unlockedSetLearnedV6(ownerVpnIp netip.Addr, to *Ip6AndPort) {
|
func (r *RemoteList) unlockedSetLearnedV6(ownerVpnIp netip.Addr, to *V6AddrPort) {
|
||||||
r.shouldRebuild = true
|
r.shouldRebuild = true
|
||||||
r.unlockedGetOrMakeV6(ownerVpnIp).learned = to
|
r.unlockedGetOrMakeV6(ownerVpnIp).learned = to
|
||||||
}
|
}
|
||||||
|
|
||||||
// unlockedSetV6 assumes you have the write lock and resets the reported list of ips for this owner to the list provided
|
// unlockedSetV6 assumes you have the write lock and resets the reported list of ips for this owner to the list provided
|
||||||
// and marks the deduplicated address list as dirty
|
// and marks the deduplicated address list as dirty
|
||||||
func (r *RemoteList) unlockedSetV6(ownerVpnIp, vpnIp netip.Addr, to []*Ip6AndPort, check checkFuncV6) {
|
func (r *RemoteList) unlockedSetV6(ownerVpnIp, vpnIp netip.Addr, to []*V6AddrPort, check checkFuncV6) {
|
||||||
r.shouldRebuild = true
|
r.shouldRebuild = true
|
||||||
c := r.unlockedGetOrMakeV6(ownerVpnIp)
|
c := r.unlockedGetOrMakeV6(ownerVpnIp)
|
||||||
|
|
||||||
@ -473,12 +479,12 @@ func (r *RemoteList) unlockedSetV6(ownerVpnIp, vpnIp netip.Addr, to []*Ip6AndPor
|
|||||||
|
|
||||||
// unlockedPrependV6 assumes you have the write lock and prepends the address in the reported list for this owner
|
// unlockedPrependV6 assumes you have the write lock and prepends the address in the reported list for this owner
|
||||||
// This is only useful for establishing static hosts
|
// This is only useful for establishing static hosts
|
||||||
func (r *RemoteList) unlockedPrependV6(ownerVpnIp netip.Addr, to *Ip6AndPort) {
|
func (r *RemoteList) unlockedPrependV6(ownerVpnIp netip.Addr, to *V6AddrPort) {
|
||||||
r.shouldRebuild = true
|
r.shouldRebuild = true
|
||||||
c := r.unlockedGetOrMakeV6(ownerVpnIp)
|
c := r.unlockedGetOrMakeV6(ownerVpnIp)
|
||||||
|
|
||||||
// We are doing the easy append because this is rarely called
|
// We are doing the easy append because this is rarely called
|
||||||
c.reported = append([]*Ip6AndPort{to}, c.reported...)
|
c.reported = append([]*V6AddrPort{to}, c.reported...)
|
||||||
if len(c.reported) > MaxRemotes {
|
if len(c.reported) > MaxRemotes {
|
||||||
c.reported = c.reported[:MaxRemotes]
|
c.reported = c.reported[:MaxRemotes]
|
||||||
}
|
}
|
||||||
@ -536,14 +542,14 @@ func (r *RemoteList) unlockedCollect() {
|
|||||||
for _, c := range r.cache {
|
for _, c := range r.cache {
|
||||||
if c.v4 != nil {
|
if c.v4 != nil {
|
||||||
if c.v4.learned != nil {
|
if c.v4.learned != nil {
|
||||||
u := AddrPortFromIp4AndPort(c.v4.learned)
|
u := protoV4AddrPortToNetAddrPort(c.v4.learned)
|
||||||
if !r.unlockedIsBad(u) {
|
if !r.unlockedIsBad(u) {
|
||||||
addrs = append(addrs, u)
|
addrs = append(addrs, u)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, v := range c.v4.reported {
|
for _, v := range c.v4.reported {
|
||||||
u := AddrPortFromIp4AndPort(v)
|
u := protoV4AddrPortToNetAddrPort(v)
|
||||||
if !r.unlockedIsBad(u) {
|
if !r.unlockedIsBad(u) {
|
||||||
addrs = append(addrs, u)
|
addrs = append(addrs, u)
|
||||||
}
|
}
|
||||||
@ -552,14 +558,14 @@ func (r *RemoteList) unlockedCollect() {
|
|||||||
|
|
||||||
if c.v6 != nil {
|
if c.v6 != nil {
|
||||||
if c.v6.learned != nil {
|
if c.v6.learned != nil {
|
||||||
u := AddrPortFromIp6AndPort(c.v6.learned)
|
u := protoV6AddrPortToNetAddrPort(c.v6.learned)
|
||||||
if !r.unlockedIsBad(u) {
|
if !r.unlockedIsBad(u) {
|
||||||
addrs = append(addrs, u)
|
addrs = append(addrs, u)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, v := range c.v6.reported {
|
for _, v := range c.v6.reported {
|
||||||
u := AddrPortFromIp6AndPort(v)
|
u := protoV6AddrPortToNetAddrPort(v)
|
||||||
if !r.unlockedIsBad(u) {
|
if !r.unlockedIsBad(u) {
|
||||||
addrs = append(addrs, u)
|
addrs = append(addrs, u)
|
||||||
}
|
}
|
||||||
@ -573,7 +579,7 @@ func (r *RemoteList) unlockedCollect() {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
dnsAddrs := r.hr.GetIPs()
|
dnsAddrs := r.hr.GetAddrs()
|
||||||
for _, addr := range dnsAddrs {
|
for _, addr := range dnsAddrs {
|
||||||
if r.shouldAdd == nil || r.shouldAdd(addr.Addr()) {
|
if r.shouldAdd == nil || r.shouldAdd(addr.Addr()) {
|
||||||
if !r.unlockedIsBad(addr) {
|
if !r.unlockedIsBad(addr) {
|
||||||
|
|||||||
@ -9,11 +9,11 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
func TestRemoteList_Rebuild(t *testing.T) {
|
func TestRemoteList_Rebuild(t *testing.T) {
|
||||||
rl := NewRemoteList(nil)
|
rl := NewRemoteList([]netip.Addr{netip.MustParseAddr("0.0.0.0")}, nil)
|
||||||
rl.unlockedSetV4(
|
rl.unlockedSetV4(
|
||||||
netip.MustParseAddr("0.0.0.0"),
|
netip.MustParseAddr("0.0.0.0"),
|
||||||
netip.MustParseAddr("0.0.0.0"),
|
netip.MustParseAddr("0.0.0.0"),
|
||||||
[]*Ip4AndPort{
|
[]*V4AddrPort{
|
||||||
newIp4AndPortFromString("70.199.182.92:1475"), // this is duped
|
newIp4AndPortFromString("70.199.182.92:1475"), // this is duped
|
||||||
newIp4AndPortFromString("172.17.0.182:10101"),
|
newIp4AndPortFromString("172.17.0.182:10101"),
|
||||||
newIp4AndPortFromString("172.17.1.1:10101"), // this is duped
|
newIp4AndPortFromString("172.17.1.1:10101"), // this is duped
|
||||||
@ -25,20 +25,20 @@ func TestRemoteList_Rebuild(t *testing.T) {
|
|||||||
newIp4AndPortFromString("70.199.182.92:1476"), // almost dupe of 0 with a diff port
|
newIp4AndPortFromString("70.199.182.92:1476"), // almost dupe of 0 with a diff port
|
||||||
newIp4AndPortFromString("70.199.182.92:1475"), // this is a dupe
|
newIp4AndPortFromString("70.199.182.92:1475"), // this is a dupe
|
||||||
},
|
},
|
||||||
func(netip.Addr, *Ip4AndPort) bool { return true },
|
func(netip.Addr, *V4AddrPort) bool { return true },
|
||||||
)
|
)
|
||||||
|
|
||||||
rl.unlockedSetV6(
|
rl.unlockedSetV6(
|
||||||
netip.MustParseAddr("0.0.0.1"),
|
netip.MustParseAddr("0.0.0.1"),
|
||||||
netip.MustParseAddr("0.0.0.1"),
|
netip.MustParseAddr("0.0.0.1"),
|
||||||
[]*Ip6AndPort{
|
[]*V6AddrPort{
|
||||||
newIp6AndPortFromString("[1::1]:1"), // this is duped
|
newIp6AndPortFromString("[1::1]:1"), // this is duped
|
||||||
newIp6AndPortFromString("[1::1]:2"), // almost dupe of 0 with a diff port, also gets duped
|
newIp6AndPortFromString("[1::1]:2"), // almost dupe of 0 with a diff port, also gets duped
|
||||||
newIp6AndPortFromString("[1:100::1]:1"),
|
newIp6AndPortFromString("[1:100::1]:1"),
|
||||||
newIp6AndPortFromString("[1::1]:1"), // this is a dupe
|
newIp6AndPortFromString("[1::1]:1"), // this is a dupe
|
||||||
newIp6AndPortFromString("[1::1]:2"), // this is a dupe
|
newIp6AndPortFromString("[1::1]:2"), // this is a dupe
|
||||||
},
|
},
|
||||||
func(netip.Addr, *Ip6AndPort) bool { return true },
|
func(netip.Addr, *V6AddrPort) bool { return true },
|
||||||
)
|
)
|
||||||
|
|
||||||
rl.Rebuild([]netip.Prefix{})
|
rl.Rebuild([]netip.Prefix{})
|
||||||
@ -98,11 +98,11 @@ func TestRemoteList_Rebuild(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func BenchmarkFullRebuild(b *testing.B) {
|
func BenchmarkFullRebuild(b *testing.B) {
|
||||||
rl := NewRemoteList(nil)
|
rl := NewRemoteList([]netip.Addr{netip.MustParseAddr("0.0.0.0")}, nil)
|
||||||
rl.unlockedSetV4(
|
rl.unlockedSetV4(
|
||||||
netip.MustParseAddr("0.0.0.0"),
|
netip.MustParseAddr("0.0.0.0"),
|
||||||
netip.MustParseAddr("0.0.0.0"),
|
netip.MustParseAddr("0.0.0.0"),
|
||||||
[]*Ip4AndPort{
|
[]*V4AddrPort{
|
||||||
newIp4AndPortFromString("70.199.182.92:1475"),
|
newIp4AndPortFromString("70.199.182.92:1475"),
|
||||||
newIp4AndPortFromString("172.17.0.182:10101"),
|
newIp4AndPortFromString("172.17.0.182:10101"),
|
||||||
newIp4AndPortFromString("172.17.1.1:10101"),
|
newIp4AndPortFromString("172.17.1.1:10101"),
|
||||||
@ -112,19 +112,19 @@ func BenchmarkFullRebuild(b *testing.B) {
|
|||||||
newIp4AndPortFromString("172.17.1.1:10101"), // this is a dupe
|
newIp4AndPortFromString("172.17.1.1:10101"), // this is a dupe
|
||||||
newIp4AndPortFromString("70.199.182.92:1476"), // dupe of 0 with a diff port
|
newIp4AndPortFromString("70.199.182.92:1476"), // dupe of 0 with a diff port
|
||||||
},
|
},
|
||||||
func(netip.Addr, *Ip4AndPort) bool { return true },
|
func(netip.Addr, *V4AddrPort) bool { return true },
|
||||||
)
|
)
|
||||||
|
|
||||||
rl.unlockedSetV6(
|
rl.unlockedSetV6(
|
||||||
netip.MustParseAddr("0.0.0.0"),
|
netip.MustParseAddr("0.0.0.0"),
|
||||||
netip.MustParseAddr("0.0.0.0"),
|
netip.MustParseAddr("0.0.0.0"),
|
||||||
[]*Ip6AndPort{
|
[]*V6AddrPort{
|
||||||
newIp6AndPortFromString("[1::1]:1"),
|
newIp6AndPortFromString("[1::1]:1"),
|
||||||
newIp6AndPortFromString("[1::1]:2"), // dupe of 0 with a diff port
|
newIp6AndPortFromString("[1::1]:2"), // dupe of 0 with a diff port
|
||||||
newIp6AndPortFromString("[1:100::1]:1"),
|
newIp6AndPortFromString("[1:100::1]:1"),
|
||||||
newIp6AndPortFromString("[1::1]:1"), // this is a dupe
|
newIp6AndPortFromString("[1::1]:1"), // this is a dupe
|
||||||
},
|
},
|
||||||
func(netip.Addr, *Ip6AndPort) bool { return true },
|
func(netip.Addr, *V6AddrPort) bool { return true },
|
||||||
)
|
)
|
||||||
|
|
||||||
b.Run("no preferred", func(b *testing.B) {
|
b.Run("no preferred", func(b *testing.B) {
|
||||||
@ -160,11 +160,11 @@ func BenchmarkFullRebuild(b *testing.B) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func BenchmarkSortRebuild(b *testing.B) {
|
func BenchmarkSortRebuild(b *testing.B) {
|
||||||
rl := NewRemoteList(nil)
|
rl := NewRemoteList([]netip.Addr{netip.MustParseAddr("0.0.0.0")}, nil)
|
||||||
rl.unlockedSetV4(
|
rl.unlockedSetV4(
|
||||||
netip.MustParseAddr("0.0.0.0"),
|
netip.MustParseAddr("0.0.0.0"),
|
||||||
netip.MustParseAddr("0.0.0.0"),
|
netip.MustParseAddr("0.0.0.0"),
|
||||||
[]*Ip4AndPort{
|
[]*V4AddrPort{
|
||||||
newIp4AndPortFromString("70.199.182.92:1475"),
|
newIp4AndPortFromString("70.199.182.92:1475"),
|
||||||
newIp4AndPortFromString("172.17.0.182:10101"),
|
newIp4AndPortFromString("172.17.0.182:10101"),
|
||||||
newIp4AndPortFromString("172.17.1.1:10101"),
|
newIp4AndPortFromString("172.17.1.1:10101"),
|
||||||
@ -174,19 +174,19 @@ func BenchmarkSortRebuild(b *testing.B) {
|
|||||||
newIp4AndPortFromString("172.17.1.1:10101"), // this is a dupe
|
newIp4AndPortFromString("172.17.1.1:10101"), // this is a dupe
|
||||||
newIp4AndPortFromString("70.199.182.92:1476"), // dupe of 0 with a diff port
|
newIp4AndPortFromString("70.199.182.92:1476"), // dupe of 0 with a diff port
|
||||||
},
|
},
|
||||||
func(netip.Addr, *Ip4AndPort) bool { return true },
|
func(netip.Addr, *V4AddrPort) bool { return true },
|
||||||
)
|
)
|
||||||
|
|
||||||
rl.unlockedSetV6(
|
rl.unlockedSetV6(
|
||||||
netip.MustParseAddr("0.0.0.0"),
|
netip.MustParseAddr("0.0.0.0"),
|
||||||
netip.MustParseAddr("0.0.0.0"),
|
netip.MustParseAddr("0.0.0.0"),
|
||||||
[]*Ip6AndPort{
|
[]*V6AddrPort{
|
||||||
newIp6AndPortFromString("[1::1]:1"),
|
newIp6AndPortFromString("[1::1]:1"),
|
||||||
newIp6AndPortFromString("[1::1]:2"), // dupe of 0 with a diff port
|
newIp6AndPortFromString("[1::1]:2"), // dupe of 0 with a diff port
|
||||||
newIp6AndPortFromString("[1:100::1]:1"),
|
newIp6AndPortFromString("[1:100::1]:1"),
|
||||||
newIp6AndPortFromString("[1::1]:1"), // this is a dupe
|
newIp6AndPortFromString("[1::1]:1"), // this is a dupe
|
||||||
},
|
},
|
||||||
func(netip.Addr, *Ip6AndPort) bool { return true },
|
func(netip.Addr, *V6AddrPort) bool { return true },
|
||||||
)
|
)
|
||||||
|
|
||||||
b.Run("no preferred", func(b *testing.B) {
|
b.Run("no preferred", func(b *testing.B) {
|
||||||
@ -224,19 +224,19 @@ func BenchmarkSortRebuild(b *testing.B) {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func newIp4AndPortFromString(s string) *Ip4AndPort {
|
func newIp4AndPortFromString(s string) *V4AddrPort {
|
||||||
a := netip.MustParseAddrPort(s)
|
a := netip.MustParseAddrPort(s)
|
||||||
v4Addr := a.Addr().As4()
|
v4Addr := a.Addr().As4()
|
||||||
return &Ip4AndPort{
|
return &V4AddrPort{
|
||||||
Ip: binary.BigEndian.Uint32(v4Addr[:]),
|
Addr: binary.BigEndian.Uint32(v4Addr[:]),
|
||||||
Port: uint32(a.Port()),
|
Port: uint32(a.Port()),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func newIp6AndPortFromString(s string) *Ip6AndPort {
|
func newIp6AndPortFromString(s string) *V6AddrPort {
|
||||||
a := netip.MustParseAddrPort(s)
|
a := netip.MustParseAddrPort(s)
|
||||||
v6Addr := a.Addr().As16()
|
v6Addr := a.Addr().As16()
|
||||||
return &Ip6AndPort{
|
return &V6AddrPort{
|
||||||
Hi: binary.BigEndian.Uint64(v6Addr[:8]),
|
Hi: binary.BigEndian.Uint64(v6Addr[:8]),
|
||||||
Lo: binary.BigEndian.Uint64(v6Addr[8:]),
|
Lo: binary.BigEndian.Uint64(v6Addr[8:]),
|
||||||
Port: uint32(a.Port()),
|
Port: uint32(a.Port()),
|
||||||
|
|||||||
@ -90,9 +90,9 @@ func New(config *config.C) (*Service, error) {
|
|||||||
},
|
},
|
||||||
})
|
})
|
||||||
|
|
||||||
ipNet := device.Cidr()
|
ipNet := device.Networks()
|
||||||
pa := tcpip.ProtocolAddress{
|
pa := tcpip.ProtocolAddress{
|
||||||
AddressWithPrefix: tcpip.AddrFromSlice(ipNet.Addr().AsSlice()).WithPrefix(),
|
AddressWithPrefix: tcpip.AddrFromSlice(ipNet[0].Addr().AsSlice()).WithPrefix(),
|
||||||
Protocol: ipv4.ProtocolNumber,
|
Protocol: ipv4.ProtocolNumber,
|
||||||
}
|
}
|
||||||
if err := s.ipstack.AddProtocolAddress(nicID, pa, stack.AddressProperties{
|
if err := s.ipstack.AddProtocolAddress(nicID, pa, stack.AddressProperties{
|
||||||
|
|||||||
@ -19,7 +19,7 @@ import (
|
|||||||
type m map[string]interface{}
|
type m map[string]interface{}
|
||||||
|
|
||||||
func newSimpleService(caCrt cert.Certificate, caKey []byte, name string, udpIp netip.Addr, overrides m) *Service {
|
func newSimpleService(caCrt cert.Certificate, caKey []byte, name string, udpIp netip.Addr, overrides m) *Service {
|
||||||
_, _, myPrivKey, myPEM := e2e.NewTestCert(caCrt, caKey, "a", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{netip.PrefixFrom(udpIp, 24)}, nil, []string{})
|
_, _, myPrivKey, myPEM := e2e.NewTestCert(cert.Version2, caCrt, caKey, "a", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{netip.PrefixFrom(udpIp, 24)}, nil, []string{})
|
||||||
caB, err := caCrt.MarshalPEM()
|
caB, err := caCrt.MarshalPEM()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
panic(err)
|
panic(err)
|
||||||
|
|||||||
35
ssh.go
35
ssh.go
@ -430,7 +430,7 @@ func sshListHostMap(hl controlHostLister, a interface{}, w sshd.StringWriter) er
|
|||||||
}
|
}
|
||||||
|
|
||||||
sort.Slice(hm, func(i, j int) bool {
|
sort.Slice(hm, func(i, j int) bool {
|
||||||
return hm[i].VpnIp.Compare(hm[j].VpnIp) < 0
|
return hm[i].VpnAddrs[0].Compare(hm[j].VpnAddrs[0]) < 0
|
||||||
})
|
})
|
||||||
|
|
||||||
if fs.Json || fs.Pretty {
|
if fs.Json || fs.Pretty {
|
||||||
@ -447,7 +447,7 @@ func sshListHostMap(hl controlHostLister, a interface{}, w sshd.StringWriter) er
|
|||||||
|
|
||||||
} else {
|
} else {
|
||||||
for _, v := range hm {
|
for _, v := range hm {
|
||||||
err := w.WriteLine(fmt.Sprintf("%s: %s", v.VpnIp, v.RemoteAddrs))
|
err := w.WriteLine(fmt.Sprintf("%s: %s", v.VpnAddrs, v.RemoteAddrs))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@ -581,7 +581,7 @@ func sshCloseTunnel(ifce *Interface, fs interface{}, a []string, w sshd.StringWr
|
|||||||
return w.WriteLine(fmt.Sprintf("The provided vpn ip could not be parsed: %s", a[0]))
|
return w.WriteLine(fmt.Sprintf("The provided vpn ip could not be parsed: %s", a[0]))
|
||||||
}
|
}
|
||||||
|
|
||||||
hostInfo := ifce.hostMap.QueryVpnIp(vpnIp)
|
hostInfo := ifce.hostMap.QueryVpnAddr(vpnIp)
|
||||||
if hostInfo == nil {
|
if hostInfo == nil {
|
||||||
return w.WriteLine(fmt.Sprintf("Could not find tunnel for vpn ip: %v", a[0]))
|
return w.WriteLine(fmt.Sprintf("Could not find tunnel for vpn ip: %v", a[0]))
|
||||||
}
|
}
|
||||||
@ -622,12 +622,12 @@ func sshCreateTunnel(ifce *Interface, fs interface{}, a []string, w sshd.StringW
|
|||||||
return w.WriteLine(fmt.Sprintf("The provided vpn ip could not be parsed: %s", a[0]))
|
return w.WriteLine(fmt.Sprintf("The provided vpn ip could not be parsed: %s", a[0]))
|
||||||
}
|
}
|
||||||
|
|
||||||
hostInfo := ifce.hostMap.QueryVpnIp(vpnIp)
|
hostInfo := ifce.hostMap.QueryVpnAddr(vpnIp)
|
||||||
if hostInfo != nil {
|
if hostInfo != nil {
|
||||||
return w.WriteLine(fmt.Sprintf("Tunnel already exists"))
|
return w.WriteLine(fmt.Sprintf("Tunnel already exists"))
|
||||||
}
|
}
|
||||||
|
|
||||||
hostInfo = ifce.handshakeManager.QueryVpnIp(vpnIp)
|
hostInfo = ifce.handshakeManager.QueryVpnAddr(vpnIp)
|
||||||
if hostInfo != nil {
|
if hostInfo != nil {
|
||||||
return w.WriteLine(fmt.Sprintf("Tunnel already handshaking"))
|
return w.WriteLine(fmt.Sprintf("Tunnel already handshaking"))
|
||||||
}
|
}
|
||||||
@ -677,7 +677,7 @@ func sshChangeRemote(ifce *Interface, fs interface{}, a []string, w sshd.StringW
|
|||||||
return w.WriteLine(fmt.Sprintf("The provided vpn ip could not be parsed: %s", a[0]))
|
return w.WriteLine(fmt.Sprintf("The provided vpn ip could not be parsed: %s", a[0]))
|
||||||
}
|
}
|
||||||
|
|
||||||
hostInfo := ifce.hostMap.QueryVpnIp(vpnIp)
|
hostInfo := ifce.hostMap.QueryVpnAddr(vpnIp)
|
||||||
if hostInfo == nil {
|
if hostInfo == nil {
|
||||||
return w.WriteLine(fmt.Sprintf("Could not find tunnel for vpn ip: %v", a[0]))
|
return w.WriteLine(fmt.Sprintf("Could not find tunnel for vpn ip: %v", a[0]))
|
||||||
}
|
}
|
||||||
@ -785,7 +785,8 @@ func sshPrintCert(ifce *Interface, fs interface{}, a []string, w sshd.StringWrit
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
cert := ifce.pki.GetCertState().Certificate
|
//TODO: This should return both certs
|
||||||
|
cert := ifce.pki.getDefaultCertificate()
|
||||||
if len(a) > 0 {
|
if len(a) > 0 {
|
||||||
vpnIp, err := netip.ParseAddr(a[0])
|
vpnIp, err := netip.ParseAddr(a[0])
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -796,7 +797,7 @@ func sshPrintCert(ifce *Interface, fs interface{}, a []string, w sshd.StringWrit
|
|||||||
return w.WriteLine(fmt.Sprintf("The provided vpn ip could not be parsed: %s", a[0]))
|
return w.WriteLine(fmt.Sprintf("The provided vpn ip could not be parsed: %s", a[0]))
|
||||||
}
|
}
|
||||||
|
|
||||||
hostInfo := ifce.hostMap.QueryVpnIp(vpnIp)
|
hostInfo := ifce.hostMap.QueryVpnAddr(vpnIp)
|
||||||
if hostInfo == nil {
|
if hostInfo == nil {
|
||||||
return w.WriteLine(fmt.Sprintf("Could not find tunnel for vpn ip: %v", a[0]))
|
return w.WriteLine(fmt.Sprintf("Could not find tunnel for vpn ip: %v", a[0]))
|
||||||
}
|
}
|
||||||
@ -880,16 +881,16 @@ func sshPrintRelays(ifce *Interface, fs interface{}, a []string, w sshd.StringWr
|
|||||||
}
|
}
|
||||||
|
|
||||||
for k, v := range relays {
|
for k, v := range relays {
|
||||||
ro := RelayOutput{NebulaIp: v.vpnIp}
|
ro := RelayOutput{NebulaIp: v.vpnAddrs[0]}
|
||||||
co.Relays = append(co.Relays, &ro)
|
co.Relays = append(co.Relays, &ro)
|
||||||
relayHI := ifce.hostMap.QueryVpnIp(v.vpnIp)
|
relayHI := ifce.hostMap.QueryVpnAddr(v.vpnAddrs[0])
|
||||||
if relayHI == nil {
|
if relayHI == nil {
|
||||||
ro.RelayForIps = append(ro.RelayForIps, RelayFor{Error: errors.New("could not find hostinfo")})
|
ro.RelayForIps = append(ro.RelayForIps, RelayFor{Error: errors.New("could not find hostinfo")})
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
for _, vpnIp := range relayHI.relayState.CopyRelayForIps() {
|
for _, vpnIp := range relayHI.relayState.CopyRelayForIps() {
|
||||||
rf := RelayFor{Error: nil}
|
rf := RelayFor{Error: nil}
|
||||||
r, ok := relayHI.relayState.GetRelayForByIp(vpnIp)
|
r, ok := relayHI.relayState.GetRelayForByAddr(vpnIp)
|
||||||
if ok {
|
if ok {
|
||||||
t := ""
|
t := ""
|
||||||
switch r.Type {
|
switch r.Type {
|
||||||
@ -913,14 +914,14 @@ func sshPrintRelays(ifce *Interface, fs interface{}, a []string, w sshd.StringWr
|
|||||||
|
|
||||||
rf.LocalIndex = r.LocalIndex
|
rf.LocalIndex = r.LocalIndex
|
||||||
rf.RemoteIndex = r.RemoteIndex
|
rf.RemoteIndex = r.RemoteIndex
|
||||||
rf.PeerIp = r.PeerIp
|
rf.PeerIp = r.PeerAddr
|
||||||
rf.Type = t
|
rf.Type = t
|
||||||
rf.State = s
|
rf.State = s
|
||||||
if rf.LocalIndex != k {
|
if rf.LocalIndex != k {
|
||||||
rf.Error = fmt.Errorf("hostmap LocalIndex '%v' does not match RelayState LocalIndex", k)
|
rf.Error = fmt.Errorf("hostmap LocalIndex '%v' does not match RelayState LocalIndex", k)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
relayedHI := ifce.hostMap.QueryVpnIp(vpnIp)
|
relayedHI := ifce.hostMap.QueryVpnAddr(vpnIp)
|
||||||
if relayedHI != nil {
|
if relayedHI != nil {
|
||||||
rf.RelayedThrough = append(rf.RelayedThrough, relayedHI.relayState.CopyRelayIps()...)
|
rf.RelayedThrough = append(rf.RelayedThrough, relayedHI.relayState.CopyRelayIps()...)
|
||||||
}
|
}
|
||||||
@ -955,7 +956,7 @@ func sshPrintTunnel(ifce *Interface, fs interface{}, a []string, w sshd.StringWr
|
|||||||
return w.WriteLine(fmt.Sprintf("The provided vpn ip could not be parsed: %s", a[0]))
|
return w.WriteLine(fmt.Sprintf("The provided vpn ip could not be parsed: %s", a[0]))
|
||||||
}
|
}
|
||||||
|
|
||||||
hostInfo := ifce.hostMap.QueryVpnIp(vpnIp)
|
hostInfo := ifce.hostMap.QueryVpnAddr(vpnIp)
|
||||||
if hostInfo == nil {
|
if hostInfo == nil {
|
||||||
return w.WriteLine(fmt.Sprintf("Could not find tunnel for vpn ip: %v", a[0]))
|
return w.WriteLine(fmt.Sprintf("Could not find tunnel for vpn ip: %v", a[0]))
|
||||||
}
|
}
|
||||||
@ -972,12 +973,14 @@ func sshDeviceInfo(ifce *Interface, fs interface{}, w sshd.StringWriter) error {
|
|||||||
|
|
||||||
data := struct {
|
data := struct {
|
||||||
Name string `json:"name"`
|
Name string `json:"name"`
|
||||||
Cidr string `json:"cidr"`
|
Cidr []netip.Prefix `json:"cidr"`
|
||||||
}{
|
}{
|
||||||
Name: ifce.inside.Name(),
|
Name: ifce.inside.Name(),
|
||||||
Cidr: ifce.inside.Cidr().String(),
|
Cidr: make([]netip.Prefix, len(ifce.inside.Networks())),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
copy(data.Cidr, ifce.inside.Networks())
|
||||||
|
|
||||||
flags, ok := fs.(*sshDeviceInfoFlags)
|
flags, ok := fs.(*sshDeviceInfoFlags)
|
||||||
if !ok {
|
if !ok {
|
||||||
return fmt.Errorf("internal error: expected flags to be sshDeviceInfoFlags but was %+v", fs)
|
return fmt.Errorf("internal error: expected flags to be sshDeviceInfoFlags but was %+v", fs)
|
||||||
|
|||||||
@ -16,8 +16,8 @@ func (NoopTun) Activate() error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (NoopTun) Cidr() netip.Prefix {
|
func (NoopTun) Networks() []netip.Prefix {
|
||||||
return netip.Prefix{}
|
return []netip.Prefix{}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (NoopTun) Name() string {
|
func (NoopTun) Name() string {
|
||||||
|
|||||||
@ -116,10 +116,10 @@ func TestTimerWheel_Purge(t *testing.T) {
|
|||||||
assert.Equal(t, 0, tw.current)
|
assert.Equal(t, 0, tw.current)
|
||||||
|
|
||||||
fps := []firewall.Packet{
|
fps := []firewall.Packet{
|
||||||
{LocalIP: netip.MustParseAddr("0.0.0.1")},
|
{LocalAddr: netip.MustParseAddr("0.0.0.1")},
|
||||||
{LocalIP: netip.MustParseAddr("0.0.0.2")},
|
{LocalAddr: netip.MustParseAddr("0.0.0.2")},
|
||||||
{LocalIP: netip.MustParseAddr("0.0.0.3")},
|
{LocalAddr: netip.MustParseAddr("0.0.0.3")},
|
||||||
{LocalIP: netip.MustParseAddr("0.0.0.4")},
|
{LocalAddr: netip.MustParseAddr("0.0.0.4")},
|
||||||
}
|
}
|
||||||
|
|
||||||
tw.Add(fps[0], time.Second*1)
|
tw.Add(fps[0], time.Second*1)
|
||||||
|
|||||||
15
udp/conn.go
15
udp/conn.go
@ -4,28 +4,19 @@ import (
|
|||||||
"net/netip"
|
"net/netip"
|
||||||
|
|
||||||
"github.com/slackhq/nebula/config"
|
"github.com/slackhq/nebula/config"
|
||||||
"github.com/slackhq/nebula/firewall"
|
|
||||||
"github.com/slackhq/nebula/header"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
const MTU = 9001
|
const MTU = 9001
|
||||||
|
|
||||||
type EncReader func(
|
type EncReader func(
|
||||||
addr netip.AddrPort,
|
addr netip.AddrPort,
|
||||||
out []byte,
|
payload []byte,
|
||||||
packet []byte,
|
|
||||||
header *header.H,
|
|
||||||
fwPacket *firewall.Packet,
|
|
||||||
lhh LightHouseHandlerFunc,
|
|
||||||
nb []byte,
|
|
||||||
q int,
|
|
||||||
localCache firewall.ConntrackCache,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
type Conn interface {
|
type Conn interface {
|
||||||
Rebind() error
|
Rebind() error
|
||||||
LocalAddr() (netip.AddrPort, error)
|
LocalAddr() (netip.AddrPort, error)
|
||||||
ListenOut(r EncReader, lhf LightHouseHandlerFunc, cache *firewall.ConntrackCacheTicker, q int)
|
ListenOut(r EncReader)
|
||||||
WriteTo(b []byte, addr netip.AddrPort) error
|
WriteTo(b []byte, addr netip.AddrPort) error
|
||||||
ReloadConfig(c *config.C)
|
ReloadConfig(c *config.C)
|
||||||
Close() error
|
Close() error
|
||||||
@ -39,7 +30,7 @@ func (NoopConn) Rebind() error {
|
|||||||
func (NoopConn) LocalAddr() (netip.AddrPort, error) {
|
func (NoopConn) LocalAddr() (netip.AddrPort, error) {
|
||||||
return netip.AddrPort{}, nil
|
return netip.AddrPort{}, nil
|
||||||
}
|
}
|
||||||
func (NoopConn) ListenOut(_ EncReader, _ LightHouseHandlerFunc, _ *firewall.ConntrackCacheTicker, _ int) {
|
func (NoopConn) ListenOut(_ EncReader) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
func (NoopConn) WriteTo(_ []byte, _ netip.AddrPort) error {
|
func (NoopConn) WriteTo(_ []byte, _ netip.AddrPort) error {
|
||||||
|
|||||||
10
udp/temp.go
10
udp/temp.go
@ -1,10 +0,0 @@
|
|||||||
package udp
|
|
||||||
|
|
||||||
import (
|
|
||||||
"net/netip"
|
|
||||||
)
|
|
||||||
|
|
||||||
//TODO: The items in this file belong in their own packages but doing that in a single PR is a nightmare
|
|
||||||
|
|
||||||
// TODO: IPV6-WORK this can likely be removed now
|
|
||||||
type LightHouseHandlerFunc func(rAddr netip.AddrPort, vpnIp netip.Addr, p []byte)
|
|
||||||
@ -15,8 +15,6 @@ import (
|
|||||||
|
|
||||||
"github.com/sirupsen/logrus"
|
"github.com/sirupsen/logrus"
|
||||||
"github.com/slackhq/nebula/config"
|
"github.com/slackhq/nebula/config"
|
||||||
"github.com/slackhq/nebula/firewall"
|
|
||||||
"github.com/slackhq/nebula/header"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
type GenericConn struct {
|
type GenericConn struct {
|
||||||
@ -72,12 +70,8 @@ type rawMessage struct {
|
|||||||
Len uint32
|
Len uint32
|
||||||
}
|
}
|
||||||
|
|
||||||
func (u *GenericConn) ListenOut(r EncReader, lhf LightHouseHandlerFunc, cache *firewall.ConntrackCacheTicker, q int) {
|
func (u *GenericConn) ListenOut(r EncReader) {
|
||||||
plaintext := make([]byte, MTU)
|
|
||||||
buffer := make([]byte, MTU)
|
buffer := make([]byte, MTU)
|
||||||
h := &header.H{}
|
|
||||||
fwPacket := &firewall.Packet{}
|
|
||||||
nb := make([]byte, 12, 12)
|
|
||||||
|
|
||||||
for {
|
for {
|
||||||
// Just read one packet at a time
|
// Just read one packet at a time
|
||||||
@ -87,16 +81,6 @@ func (u *GenericConn) ListenOut(r EncReader, lhf LightHouseHandlerFunc, cache *f
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
r(
|
r(netip.AddrPortFrom(rua.Addr().Unmap(), rua.Port()), buffer[:n])
|
||||||
netip.AddrPortFrom(rua.Addr().Unmap(), rua.Port()),
|
|
||||||
plaintext[:0],
|
|
||||||
buffer[:n],
|
|
||||||
h,
|
|
||||||
fwPacket,
|
|
||||||
lhf,
|
|
||||||
nb,
|
|
||||||
q,
|
|
||||||
cache.Get(u.l),
|
|
||||||
)
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@ -14,8 +14,6 @@ import (
|
|||||||
"github.com/rcrowley/go-metrics"
|
"github.com/rcrowley/go-metrics"
|
||||||
"github.com/sirupsen/logrus"
|
"github.com/sirupsen/logrus"
|
||||||
"github.com/slackhq/nebula/config"
|
"github.com/slackhq/nebula/config"
|
||||||
"github.com/slackhq/nebula/firewall"
|
|
||||||
"github.com/slackhq/nebula/header"
|
|
||||||
"golang.org/x/sys/unix"
|
"golang.org/x/sys/unix"
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -120,15 +118,9 @@ func (u *StdConn) LocalAddr() (netip.AddrPort, error) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (u *StdConn) ListenOut(r EncReader, lhf LightHouseHandlerFunc, cache *firewall.ConntrackCacheTicker, q int) {
|
func (u *StdConn) ListenOut(r EncReader) {
|
||||||
plaintext := make([]byte, MTU)
|
|
||||||
h := &header.H{}
|
|
||||||
fwPacket := &firewall.Packet{}
|
|
||||||
var ip netip.Addr
|
var ip netip.Addr
|
||||||
nb := make([]byte, 12, 12)
|
|
||||||
|
|
||||||
//TODO: should we track this?
|
|
||||||
//metric := metrics.GetOrRegisterHistogram("test.batch_read", nil, metrics.NewExpDecaySample(1028, 0.015))
|
|
||||||
msgs, buffers, names := u.PrepareRawMessages(u.batch)
|
msgs, buffers, names := u.PrepareRawMessages(u.batch)
|
||||||
read := u.ReadMulti
|
read := u.ReadMulti
|
||||||
if u.batch == 1 {
|
if u.batch == 1 {
|
||||||
@ -142,26 +134,14 @@ func (u *StdConn) ListenOut(r EncReader, lhf LightHouseHandlerFunc, cache *firew
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
//metric.Update(int64(n))
|
|
||||||
for i := 0; i < n; i++ {
|
for i := 0; i < n; i++ {
|
||||||
|
// Its ok to skip the ok check here, the slicing is the only error that can occur and it will panic
|
||||||
if u.isV4 {
|
if u.isV4 {
|
||||||
ip, _ = netip.AddrFromSlice(names[i][4:8])
|
ip, _ = netip.AddrFromSlice(names[i][4:8])
|
||||||
//TODO: IPV6-WORK what is not ok?
|
|
||||||
} else {
|
} else {
|
||||||
ip, _ = netip.AddrFromSlice(names[i][8:24])
|
ip, _ = netip.AddrFromSlice(names[i][8:24])
|
||||||
//TODO: IPV6-WORK what is not ok?
|
|
||||||
}
|
}
|
||||||
r(
|
r(netip.AddrPortFrom(ip.Unmap(), binary.BigEndian.Uint16(names[i][2:4])), buffers[i][:msgs[i].Len])
|
||||||
netip.AddrPortFrom(ip.Unmap(), binary.BigEndian.Uint16(names[i][2:4])),
|
|
||||||
plaintext[:0],
|
|
||||||
buffers[i][:msgs[i].Len],
|
|
||||||
h,
|
|
||||||
fwPacket,
|
|
||||||
lhf,
|
|
||||||
nb,
|
|
||||||
q,
|
|
||||||
cache.Get(u.l),
|
|
||||||
)
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@ -18,9 +18,6 @@ import (
|
|||||||
|
|
||||||
"github.com/sirupsen/logrus"
|
"github.com/sirupsen/logrus"
|
||||||
"github.com/slackhq/nebula/config"
|
"github.com/slackhq/nebula/config"
|
||||||
"github.com/slackhq/nebula/firewall"
|
|
||||||
"github.com/slackhq/nebula/header"
|
|
||||||
|
|
||||||
"golang.org/x/sys/windows"
|
"golang.org/x/sys/windows"
|
||||||
"golang.zx2c4.com/wireguard/conn/winrio"
|
"golang.zx2c4.com/wireguard/conn/winrio"
|
||||||
)
|
)
|
||||||
@ -118,12 +115,8 @@ func (u *RIOConn) bind(sa windows.Sockaddr) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (u *RIOConn) ListenOut(r EncReader, lhf LightHouseHandlerFunc, cache *firewall.ConntrackCacheTicker, q int) {
|
func (u *RIOConn) ListenOut(r EncReader) {
|
||||||
plaintext := make([]byte, MTU)
|
|
||||||
buffer := make([]byte, MTU)
|
buffer := make([]byte, MTU)
|
||||||
h := &header.H{}
|
|
||||||
fwPacket := &firewall.Packet{}
|
|
||||||
nb := make([]byte, 12, 12)
|
|
||||||
|
|
||||||
for {
|
for {
|
||||||
// Just read one packet at a time
|
// Just read one packet at a time
|
||||||
@ -133,17 +126,7 @@ func (u *RIOConn) ListenOut(r EncReader, lhf LightHouseHandlerFunc, cache *firew
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
r(
|
r(netip.AddrPortFrom(netip.AddrFrom16(rua.Addr).Unmap(), (rua.Port>>8)|((rua.Port&0xff)<<8)), buffer[:n])
|
||||||
netip.AddrPortFrom(netip.AddrFrom16(rua.Addr).Unmap(), (rua.Port>>8)|((rua.Port&0xff)<<8)),
|
|
||||||
plaintext[:0],
|
|
||||||
buffer[:n],
|
|
||||||
h,
|
|
||||||
fwPacket,
|
|
||||||
lhf,
|
|
||||||
nb,
|
|
||||||
q,
|
|
||||||
cache.Get(u.l),
|
|
||||||
)
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@ -10,7 +10,6 @@ import (
|
|||||||
|
|
||||||
"github.com/sirupsen/logrus"
|
"github.com/sirupsen/logrus"
|
||||||
"github.com/slackhq/nebula/config"
|
"github.com/slackhq/nebula/config"
|
||||||
"github.com/slackhq/nebula/firewall"
|
|
||||||
"github.com/slackhq/nebula/header"
|
"github.com/slackhq/nebula/header"
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -107,18 +106,13 @@ func (u *TesterConn) WriteTo(b []byte, addr netip.AddrPort) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (u *TesterConn) ListenOut(r EncReader, lhf LightHouseHandlerFunc, cache *firewall.ConntrackCacheTicker, q int) {
|
func (u *TesterConn) ListenOut(r EncReader) {
|
||||||
plaintext := make([]byte, MTU)
|
|
||||||
h := &header.H{}
|
|
||||||
fwPacket := &firewall.Packet{}
|
|
||||||
nb := make([]byte, 12, 12)
|
|
||||||
|
|
||||||
for {
|
for {
|
||||||
p, ok := <-u.RxPackets
|
p, ok := <-u.RxPackets
|
||||||
if !ok {
|
if !ok {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
r(p.From, plaintext[:0], p.Data, h, fwPacket, lhf, nb, q, cache.Get(u.l))
|
r(p.From, p.Data)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user