V2 certificate format (#1216)

Co-authored-by: Nate Brown <nbrown.us@gmail.com>
Co-authored-by: Jack Doan <jackdoan@rivian.com>
Co-authored-by: brad-defined <77982333+brad-defined@users.noreply.github.com>
Co-authored-by: Jack Doan <me@jackdoan.com>
This commit is contained in:
Nate Brown 2025-03-06 11:28:26 -06:00 committed by GitHub
parent 2b427a7e89
commit d97ed57a19
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
105 changed files with 8276 additions and 4528 deletions

3
.gitignore vendored
View File

@ -5,7 +5,8 @@
/nebula-darwin /nebula-darwin
/nebula.exe /nebula.exe
/nebula-cert.exe /nebula-cert.exe
/coverage.out **/coverage.out
**/cover.out
/cpu.pprof /cpu.pprof
/build /build
/*.tar.gz /*.tar.gz

View File

@ -196,7 +196,7 @@ bench-cpu-long:
go test -bench=. -benchtime=60s -cpuprofile=cpu.pprof go test -bench=. -benchtime=60s -cpuprofile=cpu.pprof
go tool pprof go-audit.test cpu.pprof go tool pprof go-audit.test cpu.pprof
proto: nebula.pb.go cert/cert.pb.go proto: nebula.pb.go cert/cert_v1.pb.go
nebula.pb.go: nebula.proto .FORCE nebula.pb.go: nebula.proto .FORCE
go build github.com/gogo/protobuf/protoc-gen-gogofaster go build github.com/gogo/protobuf/protoc-gen-gogofaster

View File

@ -128,7 +128,6 @@ func newAllowList(k string, raw interface{}, handleKey func(key string, value in
ipNet = netip.PrefixFrom(ipNet.Addr().Unmap(), ipNet.Bits()) ipNet = netip.PrefixFrom(ipNet.Addr().Unmap(), ipNet.Bits())
// TODO: should we error on duplicate CIDRs in the config?
tree.Insert(ipNet, value) tree.Insert(ipNet, value)
maskBits := ipNet.Bits() maskBits := ipNet.Bits()
@ -251,20 +250,20 @@ func getRemoteAllowRanges(c *config.C, k string) (*bart.Table[*AllowList], error
return remoteAllowRanges, nil return remoteAllowRanges, nil
} }
func (al *AllowList) Allow(ip netip.Addr) bool { func (al *AllowList) Allow(addr netip.Addr) bool {
if al == nil { if al == nil {
return true return true
} }
result, _ := al.cidrTree.Lookup(ip) result, _ := al.cidrTree.Lookup(addr)
return result return result
} }
func (al *LocalAllowList) Allow(ip netip.Addr) bool { func (al *LocalAllowList) Allow(udpAddr netip.Addr) bool {
if al == nil { if al == nil {
return true return true
} }
return al.AllowList.Allow(ip) return al.AllowList.Allow(udpAddr)
} }
func (al *LocalAllowList) AllowName(name string) bool { func (al *LocalAllowList) AllowName(name string) bool {
@ -282,23 +281,37 @@ func (al *LocalAllowList) AllowName(name string) bool {
return !al.nameRules[0].Allow return !al.nameRules[0].Allow
} }
func (al *RemoteAllowList) AllowUnknownVpnIp(ip netip.Addr) bool { func (al *RemoteAllowList) AllowUnknownVpnAddr(vpnAddr netip.Addr) bool {
if al == nil { if al == nil {
return true return true
} }
return al.AllowList.Allow(ip) return al.AllowList.Allow(vpnAddr)
} }
func (al *RemoteAllowList) Allow(vpnIp netip.Addr, ip netip.Addr) bool { func (al *RemoteAllowList) Allow(vpnAddr netip.Addr, udpAddr netip.Addr) bool {
if !al.getInsideAllowList(vpnIp).Allow(ip) { if !al.getInsideAllowList(vpnAddr).Allow(udpAddr) {
return false return false
} }
return al.AllowList.Allow(ip) return al.AllowList.Allow(udpAddr)
} }
func (al *RemoteAllowList) getInsideAllowList(vpnIp netip.Addr) *AllowList { func (al *RemoteAllowList) AllowAll(vpnAddrs []netip.Addr, udpAddr netip.Addr) bool {
if !al.AllowList.Allow(udpAddr) {
return false
}
for _, vpnAddr := range vpnAddrs {
if !al.getInsideAllowList(vpnAddr).Allow(udpAddr) {
return false
}
}
return true
}
func (al *RemoteAllowList) getInsideAllowList(vpnAddr netip.Addr) *AllowList {
if al.insideAllowLists != nil { if al.insideAllowLists != nil {
inside, ok := al.insideAllowLists.Lookup(vpnIp) inside, ok := al.insideAllowLists.Lookup(vpnAddr)
if ok { if ok {
return inside return inside
} }

View File

@ -21,7 +21,11 @@ type calculatedRemote struct {
port uint32 port uint32
} }
func newCalculatedRemote(maskCidr netip.Prefix, port int) (*calculatedRemote, error) { func newCalculatedRemote(cidr, maskCidr netip.Prefix, port int) (*calculatedRemote, error) {
if maskCidr.Addr().BitLen() != cidr.Addr().BitLen() {
return nil, fmt.Errorf("invalid mask: %s for cidr: %s", maskCidr, cidr)
}
masked := maskCidr.Masked() masked := maskCidr.Masked()
if port < 0 || port > math.MaxUint16 { if port < 0 || port > math.MaxUint16 {
return nil, fmt.Errorf("invalid port: %d", port) return nil, fmt.Errorf("invalid port: %d", port)
@ -38,32 +42,38 @@ func (c *calculatedRemote) String() string {
return fmt.Sprintf("CalculatedRemote(mask=%v port=%d)", c.ipNet, c.port) return fmt.Sprintf("CalculatedRemote(mask=%v port=%d)", c.ipNet, c.port)
} }
func (c *calculatedRemote) Apply(ip netip.Addr) *Ip4AndPort { func (c *calculatedRemote) ApplyV4(addr netip.Addr) *V4AddrPort {
// Combine the masked bytes of the "mask" IP with the unmasked bytes // Combine the masked bytes of the "mask" IP with the unmasked bytes of the overlay IP
// of the overlay IP
if c.ipNet.Addr().Is4() {
return c.apply4(ip)
}
return c.apply6(ip)
}
func (c *calculatedRemote) apply4(ip netip.Addr) *Ip4AndPort {
//TODO: IPV6-WORK this can be less crappy
maskb := net.CIDRMask(c.mask.Bits(), c.mask.Addr().BitLen()) maskb := net.CIDRMask(c.mask.Bits(), c.mask.Addr().BitLen())
mask := binary.BigEndian.Uint32(maskb[:]) mask := binary.BigEndian.Uint32(maskb[:])
b := c.mask.Addr().As4() b := c.mask.Addr().As4()
maskIp := binary.BigEndian.Uint32(b[:]) maskAddr := binary.BigEndian.Uint32(b[:])
b = ip.As4() b = addr.As4()
intIp := binary.BigEndian.Uint32(b[:]) intAddr := binary.BigEndian.Uint32(b[:])
return &Ip4AndPort{(maskIp & mask) | (intIp & ^mask), c.port} return &V4AddrPort{(maskAddr & mask) | (intAddr & ^mask), c.port}
} }
func (c *calculatedRemote) apply6(ip netip.Addr) *Ip4AndPort { func (c *calculatedRemote) ApplyV6(addr netip.Addr) *V6AddrPort {
//TODO: IPV6-WORK mask := net.CIDRMask(c.mask.Bits(), c.mask.Addr().BitLen())
panic("Can not calculate ipv6 remote addresses") maskAddr := c.mask.Addr().As16()
calcAddr := addr.As16()
ap := V6AddrPort{Port: c.port}
maskb := binary.BigEndian.Uint64(mask[:8])
maskAddrb := binary.BigEndian.Uint64(maskAddr[:8])
calcAddrb := binary.BigEndian.Uint64(calcAddr[:8])
ap.Hi = (maskAddrb & maskb) | (calcAddrb & ^maskb)
maskb = binary.BigEndian.Uint64(mask[8:])
maskAddrb = binary.BigEndian.Uint64(maskAddr[8:])
calcAddrb = binary.BigEndian.Uint64(calcAddr[8:])
ap.Lo = (maskAddrb & maskb) | (calcAddrb & ^maskb)
return &ap
} }
func NewCalculatedRemotesFromConfig(c *config.C, k string) (*bart.Table[[]*calculatedRemote], error) { func NewCalculatedRemotesFromConfig(c *config.C, k string) (*bart.Table[[]*calculatedRemote], error) {
@ -89,8 +99,7 @@ func NewCalculatedRemotesFromConfig(c *config.C, k string) (*bart.Table[[]*calcu
return nil, fmt.Errorf("config `%s` has invalid CIDR: %s", k, rawCIDR) return nil, fmt.Errorf("config `%s` has invalid CIDR: %s", k, rawCIDR)
} }
//TODO: IPV6-WORK this does not verify that rawValue contains the same bits as cidr here entry, err := newCalculatedRemotesListFromConfig(cidr, rawValue)
entry, err := newCalculatedRemotesListFromConfig(rawValue)
if err != nil { if err != nil {
return nil, fmt.Errorf("config '%s.%s': %w", k, rawCIDR, err) return nil, fmt.Errorf("config '%s.%s': %w", k, rawCIDR, err)
} }
@ -101,7 +110,7 @@ func NewCalculatedRemotesFromConfig(c *config.C, k string) (*bart.Table[[]*calcu
return calculatedRemotes, nil return calculatedRemotes, nil
} }
func newCalculatedRemotesListFromConfig(raw any) ([]*calculatedRemote, error) { func newCalculatedRemotesListFromConfig(cidr netip.Prefix, raw any) ([]*calculatedRemote, error) {
rawList, ok := raw.([]any) rawList, ok := raw.([]any)
if !ok { if !ok {
return nil, fmt.Errorf("calculated_remotes entry has invalid type: %T", raw) return nil, fmt.Errorf("calculated_remotes entry has invalid type: %T", raw)
@ -109,7 +118,7 @@ func newCalculatedRemotesListFromConfig(raw any) ([]*calculatedRemote, error) {
var l []*calculatedRemote var l []*calculatedRemote
for _, e := range rawList { for _, e := range rawList {
c, err := newCalculatedRemotesEntryFromConfig(e) c, err := newCalculatedRemotesEntryFromConfig(cidr, e)
if err != nil { if err != nil {
return nil, fmt.Errorf("calculated_remotes entry: %w", err) return nil, fmt.Errorf("calculated_remotes entry: %w", err)
} }
@ -119,7 +128,7 @@ func newCalculatedRemotesListFromConfig(raw any) ([]*calculatedRemote, error) {
return l, nil return l, nil
} }
func newCalculatedRemotesEntryFromConfig(raw any) (*calculatedRemote, error) { func newCalculatedRemotesEntryFromConfig(cidr netip.Prefix, raw any) (*calculatedRemote, error) {
rawMap, ok := raw.(map[any]any) rawMap, ok := raw.(map[any]any)
if !ok { if !ok {
return nil, fmt.Errorf("invalid type: %T", raw) return nil, fmt.Errorf("invalid type: %T", raw)
@ -155,5 +164,5 @@ func newCalculatedRemotesEntryFromConfig(raw any) (*calculatedRemote, error) {
return nil, fmt.Errorf("invalid port (type %T): %v", rawValue, rawValue) return nil, fmt.Errorf("invalid port (type %T): %v", rawValue, rawValue)
} }
return newCalculatedRemote(maskCidr, port) return newCalculatedRemote(cidr, maskCidr, port)
} }

View File

@ -9,10 +9,9 @@ import (
) )
func TestCalculatedRemoteApply(t *testing.T) { func TestCalculatedRemoteApply(t *testing.T) {
ipNet, err := netip.ParsePrefix("192.168.1.0/24") // Test v4 addresses
require.NoError(t, err) ipNet := netip.MustParsePrefix("192.168.1.0/24")
c, err := newCalculatedRemote(ipNet, ipNet, 4242)
c, err := newCalculatedRemote(ipNet, 4242)
require.NoError(t, err) require.NoError(t, err)
input, err := netip.ParseAddr("10.0.10.182") input, err := netip.ParseAddr("10.0.10.182")
@ -21,5 +20,62 @@ func TestCalculatedRemoteApply(t *testing.T) {
expected, err := netip.ParseAddr("192.168.1.182") expected, err := netip.ParseAddr("192.168.1.182")
assert.NoError(t, err) assert.NoError(t, err)
assert.Equal(t, NewIp4AndPortFromNetIP(expected, 4242), c.Apply(input)) assert.Equal(t, netAddrToProtoV4AddrPort(expected, 4242), c.ApplyV4(input))
// Test v6 addresses
ipNet = netip.MustParsePrefix("ffff:ffff:ffff:ffff::0/64")
c, err = newCalculatedRemote(ipNet, ipNet, 4242)
require.NoError(t, err)
input, err = netip.ParseAddr("beef:beef:beef:beef:beef:beef:beef:beef")
assert.NoError(t, err)
expected, err = netip.ParseAddr("ffff:ffff:ffff:ffff:beef:beef:beef:beef")
assert.NoError(t, err)
assert.Equal(t, netAddrToProtoV6AddrPort(expected, 4242), c.ApplyV6(input))
// Test v6 addresses part 2
ipNet = netip.MustParsePrefix("ffff:ffff:ffff:ffff:ffff::0/80")
c, err = newCalculatedRemote(ipNet, ipNet, 4242)
require.NoError(t, err)
input, err = netip.ParseAddr("beef:beef:beef:beef:beef:beef:beef:beef")
assert.NoError(t, err)
expected, err = netip.ParseAddr("ffff:ffff:ffff:ffff:ffff:beef:beef:beef")
assert.NoError(t, err)
assert.Equal(t, netAddrToProtoV6AddrPort(expected, 4242), c.ApplyV6(input))
// Test v6 addresses part 2
ipNet = netip.MustParsePrefix("ffff:ffff:ffff::0/48")
c, err = newCalculatedRemote(ipNet, ipNet, 4242)
require.NoError(t, err)
input, err = netip.ParseAddr("beef:beef:beef:beef:beef:beef:beef:beef")
assert.NoError(t, err)
expected, err = netip.ParseAddr("ffff:ffff:ffff:beef:beef:beef:beef:beef")
assert.NoError(t, err)
assert.Equal(t, netAddrToProtoV6AddrPort(expected, 4242), c.ApplyV6(input))
}
func Test_newCalculatedRemote(t *testing.T) {
c, err := newCalculatedRemote(netip.MustParsePrefix("1::1/128"), netip.MustParsePrefix("1.0.0.0/32"), 4242)
require.EqualError(t, err, "invalid mask: 1.0.0.0/32 for cidr: 1::1/128")
require.Nil(t, c)
c, err = newCalculatedRemote(netip.MustParsePrefix("1.0.0.0/32"), netip.MustParsePrefix("1::1/128"), 4242)
require.EqualError(t, err, "invalid mask: 1::1/128 for cidr: 1.0.0.0/32")
require.Nil(t, c)
c, err = newCalculatedRemote(netip.MustParsePrefix("1.0.0.0/32"), netip.MustParsePrefix("1.0.0.0/32"), 4242)
require.NoError(t, err)
require.NotNil(t, c)
c, err = newCalculatedRemote(netip.MustParsePrefix("1::1/128"), netip.MustParsePrefix("1::1/128"), 4242)
require.NoError(t, err)
require.NotNil(t, c)
} }

View File

@ -2,14 +2,25 @@
This is a library for interacting with `nebula` style certificates and authorities. This is a library for interacting with `nebula` style certificates and authorities.
A `protobuf` definition of the certificate format is also included There are now 2 versions of `nebula` certificates:
### Compiling the protobuf definition ## v1
Make sure you have `protoc` installed. This version is deprecated.
A `protobuf` definition of the certificate format is included at `cert_v1.proto`
To compile the definition you will need `protoc` installed.
To compile for `go` with the same version of protobuf specified in go.mod: To compile for `go` with the same version of protobuf specified in go.mod:
```bash ```bash
make make proto
``` ```
## v2
This is the latest version which uses asn.1 DER encoding. It can support ipv4 and ipv6 and tolerate
future certificate changes better than v1.
`cert_v2.asn1` defines the wire format and can be used to compile marshalers.

52
cert/asn1.go Normal file
View File

@ -0,0 +1,52 @@
package cert
import (
"golang.org/x/crypto/cryptobyte"
"golang.org/x/crypto/cryptobyte/asn1"
)
// readOptionalASN1Boolean reads an asn.1 boolean with a specific tag instead of a asn.1 tag wrapping a boolean with a value
// https://github.com/golang/go/issues/64811#issuecomment-1944446920
func readOptionalASN1Boolean(b *cryptobyte.String, out *bool, tag asn1.Tag, defaultValue bool) bool {
var present bool
var child cryptobyte.String
if !b.ReadOptionalASN1(&child, &present, tag) {
return false
}
if !present {
*out = defaultValue
return true
}
// Ensure we have 1 byte
if len(child) == 1 {
*out = child[0] > 0
return true
}
return false
}
// readOptionalASN1Byte reads an asn.1 uint8 with a specific tag instead of a asn.1 tag wrapping a uint8 with a value
// Similar issue as with readOptionalASN1Boolean
func readOptionalASN1Byte(b *cryptobyte.String, out *byte, tag asn1.Tag, defaultValue byte) bool {
var present bool
var child cryptobyte.String
if !b.ReadOptionalASN1(&child, &present, tag) {
return false
}
if !present {
*out = defaultValue
return true
}
// Ensure we have 1 byte
if len(child) == 1 {
*out = child[0]
return true
}
return false
}

View File

@ -213,7 +213,7 @@ func (ncp *CAPool) GetCAForCert(c Certificate) (*CachedCertificate, error) {
return signer, nil return signer, nil
} }
return nil, fmt.Errorf("could not find ca for the certificate") return nil, ErrCaNotFound
} }
// GetFingerprints returns an array of trusted CA fingerprints // GetFingerprints returns an array of trusted CA fingerprints

View File

@ -1,7 +1,9 @@
package cert package cert
import ( import (
"net/netip"
"testing" "testing"
"time"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
) )
@ -10,15 +12,15 @@ func TestNewCAPoolFromBytes(t *testing.T) {
noNewLines := ` noNewLines := `
# Current provisional, Remove once everything moves over to the real root. # Current provisional, Remove once everything moves over to the real root.
-----BEGIN NEBULA CERTIFICATE----- -----BEGIN NEBULA CERTIFICATE-----
CkAKDm5lYnVsYSByb290IGNhKJfap9AFMJfg1+YGOiCUQGByMuNRhIlQBOyzXWbL Cj4KDm5lYnVsYSByb290IGNhKM0cMM24zPCvBzogV24YEw5YiqeI/oYo8XXFsoo+
vcKBwDhov900phEfJ5DN3kABEkDCq5R8qBiu8sl54yVfgRcQXEDt3cHr8UTSLszv PBmiOafNJhLacf9rsspAARJAz9OAnh8TKAUKix1kKVMyQU4iM3LsFfZRf6ODWXIf
bzBEr00kERQxxTzTsH8cpYEgRoipvmExvg8WP8NdAJEYJosB 2qWMpB6fpd3PSoVYziPoOt2bIHIFLlgRLPJz3I3xBEdBCQ==
-----END NEBULA CERTIFICATE----- -----END NEBULA CERTIFICATE-----
# root-ca01 # root-ca01
-----BEGIN NEBULA CERTIFICATE----- -----BEGIN NEBULA CERTIFICATE-----
CkMKEW5lYnVsYSByb290IGNhIDAxKJL2u9EFMJL86+cGOiDPXMH4oU6HZTk/CqTG CkEKEW5lYnVsYSByb290IGNhIDAxKM0cMM24zPCvBzogPzbWTxt8ZgXPQEwup7Br
BVG+oJpAoqokUBbI4U0N8CSfpUABEkB/Pm5A2xyH/nc8mg/wvGUWG3pZ7nHzaDMf BrtIt1O0q5AuTRT3+t2x1VJAARJAZ+2ib23qBXjdy49oU1YysrwuKkWWKrtJ7Jye
8/phAUt+FLzqTECzQKisYswKvE3pl9mbEYKbOdIHrxdIp95mo4sF rFBQpDXikOukhQD/mfkloFwJ+Yjsfru7IpTN4ZfjXL+kN/2sCA==
-----END NEBULA CERTIFICATE----- -----END NEBULA CERTIFICATE-----
` `
@ -26,18 +28,18 @@ BVG+oJpAoqokUBbI4U0N8CSfpUABEkB/Pm5A2xyH/nc8mg/wvGUWG3pZ7nHzaDMf
# Current provisional, Remove once everything moves over to the real root. # Current provisional, Remove once everything moves over to the real root.
-----BEGIN NEBULA CERTIFICATE----- -----BEGIN NEBULA CERTIFICATE-----
CkAKDm5lYnVsYSByb290IGNhKJfap9AFMJfg1+YGOiCUQGByMuNRhIlQBOyzXWbL Cj4KDm5lYnVsYSByb290IGNhKM0cMM24zPCvBzogV24YEw5YiqeI/oYo8XXFsoo+
vcKBwDhov900phEfJ5DN3kABEkDCq5R8qBiu8sl54yVfgRcQXEDt3cHr8UTSLszv PBmiOafNJhLacf9rsspAARJAz9OAnh8TKAUKix1kKVMyQU4iM3LsFfZRf6ODWXIf
bzBEr00kERQxxTzTsH8cpYEgRoipvmExvg8WP8NdAJEYJosB 2qWMpB6fpd3PSoVYziPoOt2bIHIFLlgRLPJz3I3xBEdBCQ==
-----END NEBULA CERTIFICATE----- -----END NEBULA CERTIFICATE-----
# root-ca01 # root-ca01
-----BEGIN NEBULA CERTIFICATE----- -----BEGIN NEBULA CERTIFICATE-----
CkMKEW5lYnVsYSByb290IGNhIDAxKJL2u9EFMJL86+cGOiDPXMH4oU6HZTk/CqTG CkEKEW5lYnVsYSByb290IGNhIDAxKM0cMM24zPCvBzogPzbWTxt8ZgXPQEwup7Br
BVG+oJpAoqokUBbI4U0N8CSfpUABEkB/Pm5A2xyH/nc8mg/wvGUWG3pZ7nHzaDMf BrtIt1O0q5AuTRT3+t2x1VJAARJAZ+2ib23qBXjdy49oU1YysrwuKkWWKrtJ7Jye
8/phAUt+FLzqTECzQKisYswKvE3pl9mbEYKbOdIHrxdIp95mo4sF rFBQpDXikOukhQD/mfkloFwJ+Yjsfru7IpTN4ZfjXL+kN/2sCA==
-----END NEBULA CERTIFICATE----- -----END NEBULA CERTIFICATE-----
` `
@ -45,65 +47,513 @@ BVG+oJpAoqokUBbI4U0N8CSfpUABEkB/Pm5A2xyH/nc8mg/wvGUWG3pZ7nHzaDMf
expired := ` expired := `
# expired certificate # expired certificate
-----BEGIN NEBULA CERTIFICATE----- -----BEGIN NEBULA CERTIFICATE-----
CjkKB2V4cGlyZWQouPmWjQYwufmWjQY6ILCRaoCkJlqHgv5jfDN4lzLHBvDzaQm4 CjMKB2V4cGlyZWQozRwwzRw6ICJSG94CqX8wn5I65Pwn25V6HftVfWeIySVtp2DA
vZxfu144hmgjQAESQG4qlnZi8DncvD/LDZnLgJHOaX1DWCHHEh59epVsC+BNgTie 7TY/QAESQMaAk5iJT5EnQwK524ZaaHGEJLUqqbh5yyOHhboIGiVTWkFeH3HccTW8
WH1M9n4O7cFtGlM6sJJOS+rCVVEJ3ABS7+MPdQs= Tq5a8AyWDQdfXbtEZ1FwabeHfH5Asw0=
-----END NEBULA CERTIFICATE----- -----END NEBULA CERTIFICATE-----
` `
p256 := ` p256 := `
# p256 certificate # p256 certificate
-----BEGIN NEBULA CERTIFICATE----- -----BEGIN NEBULA CERTIFICATE-----
CmYKEG5lYnVsYSBQMjU2IHRlc3Qo4s+7mgYw4tXrsAc6QQRkaW2jFmllYvN4+/k2 CmQKEG5lYnVsYSBQMjU2IHRlc3QozRwwzbjM8K8HOkEEdrmmg40zQp44AkMq6DZp
6tctO9sPT3jOx8ES6M1nIqOhpTmZeabF/4rELDqPV4aH5jfJut798DUXql0FlF8H k+coOv04r+zh33ISyhbsafnYduN17p2eD7CmHvHuerguXD9f32gcxo/KsFCKEjMe
76gvQAGgBgESRzBFAiEAib0/te6eMiZOKD8gdDeloMTS0wGuX2t0C7TFdUhAQzgC +0ABoAYBEkcwRQIgVoTg38L7uWku9xQgsr06kxZ/viQLOO/w1Qj1vFUEnhcCIQCq
IBNWYMep3ysx9zCgknfG5dKtwGTaqF++BWKDYdyl34KX 75SjTiV92kv/1GcbT3wWpAZQQDBiUHVMVmh1822szA==
-----END NEBULA CERTIFICATE----- -----END NEBULA CERTIFICATE-----
` `
rootCA := certificateV1{ rootCA := certificateV1{
details: detailsV1{ details: detailsV1{
Name: "nebula root ca", name: "nebula root ca",
}, },
} }
rootCA01 := certificateV1{ rootCA01 := certificateV1{
details: detailsV1{ details: detailsV1{
Name: "nebula root ca 01", name: "nebula root ca 01",
}, },
} }
rootCAP256 := certificateV1{ rootCAP256 := certificateV1{
details: detailsV1{ details: detailsV1{
Name: "nebula P256 test", name: "nebula P256 test",
}, },
} }
p, err := NewCAPoolFromPEM([]byte(noNewLines)) p, err := NewCAPoolFromPEM([]byte(noNewLines))
assert.Nil(t, err) assert.Nil(t, err)
assert.Equal(t, p.CAs[string("c9bfaf7ce8e84b2eeda2e27b469f4b9617bde192efd214b68891ecda6ed49522")].Certificate.Name(), rootCA.details.Name) assert.Equal(t, p.CAs["ce4e6c7a596996eb0d82a8875f0f0137a4b53ce22d2421c9fd7150e7a26f6300"].Certificate.Name(), rootCA.details.name)
assert.Equal(t, p.CAs[string("5c9c3f23e7ee7fe97637cbd3a0a5b854154d1d9aaaf7b566a51f4a88f76b64cd")].Certificate.Name(), rootCA01.details.Name) assert.Equal(t, p.CAs["04c585fcd9a49b276df956a22b7ebea3bf23f1fca5a17c0b56ce2e626631969e"].Certificate.Name(), rootCA01.details.name)
pp, err := NewCAPoolFromPEM([]byte(withNewLines)) pp, err := NewCAPoolFromPEM([]byte(withNewLines))
assert.Nil(t, err) assert.Nil(t, err)
assert.Equal(t, pp.CAs[string("c9bfaf7ce8e84b2eeda2e27b469f4b9617bde192efd214b68891ecda6ed49522")].Certificate.Name(), rootCA.details.Name) assert.Equal(t, pp.CAs["ce4e6c7a596996eb0d82a8875f0f0137a4b53ce22d2421c9fd7150e7a26f6300"].Certificate.Name(), rootCA.details.name)
assert.Equal(t, pp.CAs[string("5c9c3f23e7ee7fe97637cbd3a0a5b854154d1d9aaaf7b566a51f4a88f76b64cd")].Certificate.Name(), rootCA01.details.Name) assert.Equal(t, pp.CAs["04c585fcd9a49b276df956a22b7ebea3bf23f1fca5a17c0b56ce2e626631969e"].Certificate.Name(), rootCA01.details.name)
// expired cert, no valid certs // expired cert, no valid certs
ppp, err := NewCAPoolFromPEM([]byte(expired)) ppp, err := NewCAPoolFromPEM([]byte(expired))
assert.Equal(t, ErrExpired, err) assert.Equal(t, ErrExpired, err)
assert.Equal(t, ppp.CAs[string("152070be6bb19bc9e3bde4c2f0e7d8f4ff5448b4c9856b8eccb314fade0229b0")].Certificate.Name(), "expired") assert.Equal(t, ppp.CAs["c39b35a0e8f246203fe4f32b9aa8bfd155f1ae6a6be9d78370641e43397f48f5"].Certificate.Name(), "expired")
// expired cert, with valid certs // expired cert, with valid certs
pppp, err := NewCAPoolFromPEM(append([]byte(expired), noNewLines...)) pppp, err := NewCAPoolFromPEM(append([]byte(expired), noNewLines...))
assert.Equal(t, ErrExpired, err) assert.Equal(t, ErrExpired, err)
assert.Equal(t, pppp.CAs[string("c9bfaf7ce8e84b2eeda2e27b469f4b9617bde192efd214b68891ecda6ed49522")].Certificate.Name(), rootCA.details.Name) assert.Equal(t, pppp.CAs["ce4e6c7a596996eb0d82a8875f0f0137a4b53ce22d2421c9fd7150e7a26f6300"].Certificate.Name(), rootCA.details.name)
assert.Equal(t, pppp.CAs[string("5c9c3f23e7ee7fe97637cbd3a0a5b854154d1d9aaaf7b566a51f4a88f76b64cd")].Certificate.Name(), rootCA01.details.Name) assert.Equal(t, pppp.CAs["04c585fcd9a49b276df956a22b7ebea3bf23f1fca5a17c0b56ce2e626631969e"].Certificate.Name(), rootCA01.details.name)
assert.Equal(t, pppp.CAs[string("152070be6bb19bc9e3bde4c2f0e7d8f4ff5448b4c9856b8eccb314fade0229b0")].Certificate.Name(), "expired") assert.Equal(t, pppp.CAs["c39b35a0e8f246203fe4f32b9aa8bfd155f1ae6a6be9d78370641e43397f48f5"].Certificate.Name(), "expired")
assert.Equal(t, len(pppp.CAs), 3) assert.Equal(t, len(pppp.CAs), 3)
ppppp, err := NewCAPoolFromPEM([]byte(p256)) ppppp, err := NewCAPoolFromPEM([]byte(p256))
assert.Nil(t, err) assert.Nil(t, err)
assert.Equal(t, ppppp.CAs[string("a7938893ec8c4ef769b06d7f425e5e46f7a7f5ffa49c3bcf4a86b608caba9159")].Certificate.Name(), rootCAP256.details.Name) assert.Equal(t, ppppp.CAs["552bf7d99bec1fc775a0e4c324bf6d8f789b3078f1919c7960d2e5e0c351ee97"].Certificate.Name(), rootCAP256.details.name)
assert.Equal(t, len(ppppp.CAs), 1) assert.Equal(t, len(ppppp.CAs), 1)
} }
func TestCertificateV1_Verify(t *testing.T) {
ca, _, caKey, _ := NewTestCaCert(Version1, Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, nil)
c, _, _, _ := NewTestCert(Version1, Curve_CURVE25519, ca, caKey, "test cert", time.Now(), time.Now().Add(5*time.Minute), nil, nil, nil)
caPool := NewCAPool()
assert.NoError(t, caPool.AddCA(ca))
f, err := c.Fingerprint()
assert.Nil(t, err)
caPool.BlocklistFingerprint(f)
_, err = caPool.VerifyCertificate(time.Now(), c)
assert.EqualError(t, err, "certificate is in the block list")
caPool.ResetCertBlocklist()
_, err = caPool.VerifyCertificate(time.Now(), c)
assert.Nil(t, err)
_, err = caPool.VerifyCertificate(time.Now().Add(time.Hour*1000), c)
assert.EqualError(t, err, "root certificate is expired")
assert.PanicsWithError(t, "certificate is valid before the signing certificate", func() {
NewTestCert(Version1, Curve_CURVE25519, ca, caKey, "test cert2", time.Time{}, time.Time{}, nil, nil, nil)
})
// Test group assertion
ca, _, caKey, _ = NewTestCaCert(Version1, Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{"test1", "test2"})
caPem, err := ca.MarshalPEM()
assert.Nil(t, err)
caPool = NewCAPool()
b, err := caPool.AddCAFromPEM(caPem)
assert.NoError(t, err)
assert.Empty(t, b)
assert.PanicsWithError(t, "certificate contained a group not present on the signing ca: bad", func() {
NewTestCert(Version1, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, nil, []string{"test1", "bad"})
})
c, _, _, _ = NewTestCert(Version1, Curve_CURVE25519, ca, caKey, "test2", time.Now(), time.Now().Add(5*time.Minute), nil, nil, []string{"test1"})
assert.Nil(t, err)
_, err = caPool.VerifyCertificate(time.Now(), c)
assert.Nil(t, err)
}
func TestCertificateV1_VerifyP256(t *testing.T) {
ca, _, caKey, _ := NewTestCaCert(Version1, Curve_P256, time.Now(), time.Now().Add(10*time.Minute), nil, nil, nil)
c, _, _, _ := NewTestCert(Version1, Curve_P256, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, nil, nil)
caPool := NewCAPool()
assert.NoError(t, caPool.AddCA(ca))
f, err := c.Fingerprint()
assert.Nil(t, err)
caPool.BlocklistFingerprint(f)
_, err = caPool.VerifyCertificate(time.Now(), c)
assert.EqualError(t, err, "certificate is in the block list")
caPool.ResetCertBlocklist()
_, err = caPool.VerifyCertificate(time.Now(), c)
assert.Nil(t, err)
_, err = caPool.VerifyCertificate(time.Now().Add(time.Hour*1000), c)
assert.EqualError(t, err, "root certificate is expired")
assert.PanicsWithError(t, "certificate is valid before the signing certificate", func() {
NewTestCert(Version1, Curve_P256, ca, caKey, "test", time.Time{}, time.Time{}, nil, nil, nil)
})
// Test group assertion
ca, _, caKey, _ = NewTestCaCert(Version1, Curve_P256, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{"test1", "test2"})
caPem, err := ca.MarshalPEM()
assert.Nil(t, err)
caPool = NewCAPool()
b, err := caPool.AddCAFromPEM(caPem)
assert.NoError(t, err)
assert.Empty(t, b)
assert.PanicsWithError(t, "certificate contained a group not present on the signing ca: bad", func() {
NewTestCert(Version1, Curve_P256, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, nil, []string{"test1", "bad"})
})
c, _, _, _ = NewTestCert(Version1, Curve_P256, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, nil, []string{"test1"})
_, err = caPool.VerifyCertificate(time.Now(), c)
assert.Nil(t, err)
}
func TestCertificateV1_Verify_IPs(t *testing.T) {
caIp1 := mustParsePrefixUnmapped("10.0.0.0/16")
caIp2 := mustParsePrefixUnmapped("192.168.0.0/24")
ca, _, caKey, _ := NewTestCaCert(Version1, Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), []netip.Prefix{caIp1, caIp2}, nil, []string{"test"})
caPem, err := ca.MarshalPEM()
assert.Nil(t, err)
caPool := NewCAPool()
b, err := caPool.AddCAFromPEM(caPem)
assert.NoError(t, err)
assert.Empty(t, b)
// ip is outside the network
cIp1 := mustParsePrefixUnmapped("10.1.0.0/24")
cIp2 := mustParsePrefixUnmapped("192.168.0.1/16")
assert.PanicsWithError(t, "certificate contained a network assignment outside the limitations of the signing ca: 10.1.0.0/24", func() {
NewTestCert(Version1, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{cIp1, cIp2}, nil, []string{"test"})
})
// ip is outside the network reversed order of above
cIp1 = mustParsePrefixUnmapped("192.168.0.1/24")
cIp2 = mustParsePrefixUnmapped("10.1.0.0/24")
assert.PanicsWithError(t, "certificate contained a network assignment outside the limitations of the signing ca: 10.1.0.0/24", func() {
NewTestCert(Version1, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{cIp1, cIp2}, nil, []string{"test"})
})
// ip is within the network but mask is outside
cIp1 = mustParsePrefixUnmapped("10.0.1.0/15")
cIp2 = mustParsePrefixUnmapped("192.168.0.1/24")
assert.PanicsWithError(t, "certificate contained a network assignment outside the limitations of the signing ca: 10.0.1.0/15", func() {
NewTestCert(Version1, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{cIp1, cIp2}, nil, []string{"test"})
})
// ip is within the network but mask is outside reversed order of above
cIp1 = mustParsePrefixUnmapped("192.168.0.1/24")
cIp2 = mustParsePrefixUnmapped("10.0.1.0/15")
assert.PanicsWithError(t, "certificate contained a network assignment outside the limitations of the signing ca: 10.0.1.0/15", func() {
NewTestCert(Version1, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{cIp1, cIp2}, nil, []string{"test"})
})
// ip and mask are within the network
cIp1 = mustParsePrefixUnmapped("10.0.1.0/16")
cIp2 = mustParsePrefixUnmapped("192.168.0.1/25")
c, _, _, _ := NewTestCert(Version1, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{cIp1, cIp2}, nil, []string{"test"})
_, err = caPool.VerifyCertificate(time.Now(), c)
assert.Nil(t, err)
// Exact matches
c, _, _, _ = NewTestCert(Version1, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{caIp1, caIp2}, nil, []string{"test"})
assert.Nil(t, err)
_, err = caPool.VerifyCertificate(time.Now(), c)
assert.Nil(t, err)
// Exact matches reversed
c, _, _, _ = NewTestCert(Version1, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{caIp2, caIp1}, nil, []string{"test"})
assert.Nil(t, err)
_, err = caPool.VerifyCertificate(time.Now(), c)
assert.Nil(t, err)
// Exact matches reversed with just 1
c, _, _, _ = NewTestCert(Version1, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{caIp1}, nil, []string{"test"})
assert.Nil(t, err)
_, err = caPool.VerifyCertificate(time.Now(), c)
assert.Nil(t, err)
}
func TestCertificateV1_Verify_Subnets(t *testing.T) {
caIp1 := mustParsePrefixUnmapped("10.0.0.0/16")
caIp2 := mustParsePrefixUnmapped("192.168.0.0/24")
ca, _, caKey, _ := NewTestCaCert(Version1, Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, []netip.Prefix{caIp1, caIp2}, []string{"test"})
caPem, err := ca.MarshalPEM()
assert.Nil(t, err)
caPool := NewCAPool()
b, err := caPool.AddCAFromPEM(caPem)
assert.NoError(t, err)
assert.Empty(t, b)
// ip is outside the network
cIp1 := mustParsePrefixUnmapped("10.1.0.0/24")
cIp2 := mustParsePrefixUnmapped("192.168.0.1/16")
assert.PanicsWithError(t, "certificate contained an unsafe network assignment outside the limitations of the signing ca: 10.1.0.0/24", func() {
NewTestCert(Version1, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, []netip.Prefix{cIp1, cIp2}, []string{"test"})
})
// ip is outside the network reversed order of above
cIp1 = mustParsePrefixUnmapped("192.168.0.1/24")
cIp2 = mustParsePrefixUnmapped("10.1.0.0/24")
assert.PanicsWithError(t, "certificate contained an unsafe network assignment outside the limitations of the signing ca: 10.1.0.0/24", func() {
NewTestCert(Version1, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, []netip.Prefix{cIp1, cIp2}, []string{"test"})
})
// ip is within the network but mask is outside
cIp1 = mustParsePrefixUnmapped("10.0.1.0/15")
cIp2 = mustParsePrefixUnmapped("192.168.0.1/24")
assert.PanicsWithError(t, "certificate contained an unsafe network assignment outside the limitations of the signing ca: 10.0.1.0/15", func() {
NewTestCert(Version1, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, []netip.Prefix{cIp1, cIp2}, []string{"test"})
})
// ip is within the network but mask is outside reversed order of above
cIp1 = mustParsePrefixUnmapped("192.168.0.1/24")
cIp2 = mustParsePrefixUnmapped("10.0.1.0/15")
assert.PanicsWithError(t, "certificate contained an unsafe network assignment outside the limitations of the signing ca: 10.0.1.0/15", func() {
NewTestCert(Version1, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, []netip.Prefix{cIp1, cIp2}, []string{"test"})
})
// ip and mask are within the network
cIp1 = mustParsePrefixUnmapped("10.0.1.0/16")
cIp2 = mustParsePrefixUnmapped("192.168.0.1/25")
c, _, _, _ := NewTestCert(Version1, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, []netip.Prefix{cIp1, cIp2}, []string{"test"})
assert.Nil(t, err)
_, err = caPool.VerifyCertificate(time.Now(), c)
assert.Nil(t, err)
// Exact matches
c, _, _, _ = NewTestCert(Version1, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, []netip.Prefix{caIp1, caIp2}, []string{"test"})
assert.Nil(t, err)
_, err = caPool.VerifyCertificate(time.Now(), c)
assert.Nil(t, err)
// Exact matches reversed
c, _, _, _ = NewTestCert(Version1, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, []netip.Prefix{caIp2, caIp1}, []string{"test"})
assert.Nil(t, err)
_, err = caPool.VerifyCertificate(time.Now(), c)
assert.Nil(t, err)
// Exact matches reversed with just 1
c, _, _, _ = NewTestCert(Version1, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, []netip.Prefix{caIp1}, []string{"test"})
assert.Nil(t, err)
_, err = caPool.VerifyCertificate(time.Now(), c)
assert.Nil(t, err)
}
func TestCertificateV2_Verify(t *testing.T) {
ca, _, caKey, _ := NewTestCaCert(Version2, Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, nil)
c, _, _, _ := NewTestCert(Version2, Curve_CURVE25519, ca, caKey, "test cert", time.Now(), time.Now().Add(5*time.Minute), nil, nil, nil)
caPool := NewCAPool()
assert.NoError(t, caPool.AddCA(ca))
f, err := c.Fingerprint()
assert.Nil(t, err)
caPool.BlocklistFingerprint(f)
_, err = caPool.VerifyCertificate(time.Now(), c)
assert.EqualError(t, err, "certificate is in the block list")
caPool.ResetCertBlocklist()
_, err = caPool.VerifyCertificate(time.Now(), c)
assert.Nil(t, err)
_, err = caPool.VerifyCertificate(time.Now().Add(time.Hour*1000), c)
assert.EqualError(t, err, "root certificate is expired")
assert.PanicsWithError(t, "certificate is valid before the signing certificate", func() {
NewTestCert(Version2, Curve_CURVE25519, ca, caKey, "test cert2", time.Time{}, time.Time{}, nil, nil, nil)
})
// Test group assertion
ca, _, caKey, _ = NewTestCaCert(Version2, Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{"test1", "test2"})
caPem, err := ca.MarshalPEM()
assert.Nil(t, err)
caPool = NewCAPool()
b, err := caPool.AddCAFromPEM(caPem)
assert.NoError(t, err)
assert.Empty(t, b)
assert.PanicsWithError(t, "certificate contained a group not present on the signing ca: bad", func() {
NewTestCert(Version2, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, nil, []string{"test1", "bad"})
})
c, _, _, _ = NewTestCert(Version2, Curve_CURVE25519, ca, caKey, "test2", time.Now(), time.Now().Add(5*time.Minute), nil, nil, []string{"test1"})
assert.Nil(t, err)
_, err = caPool.VerifyCertificate(time.Now(), c)
assert.Nil(t, err)
}
func TestCertificateV2_VerifyP256(t *testing.T) {
ca, _, caKey, _ := NewTestCaCert(Version2, Curve_P256, time.Now(), time.Now().Add(10*time.Minute), nil, nil, nil)
c, _, _, _ := NewTestCert(Version2, Curve_P256, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, nil, nil)
caPool := NewCAPool()
assert.NoError(t, caPool.AddCA(ca))
f, err := c.Fingerprint()
assert.Nil(t, err)
caPool.BlocklistFingerprint(f)
_, err = caPool.VerifyCertificate(time.Now(), c)
assert.EqualError(t, err, "certificate is in the block list")
caPool.ResetCertBlocklist()
_, err = caPool.VerifyCertificate(time.Now(), c)
assert.Nil(t, err)
_, err = caPool.VerifyCertificate(time.Now().Add(time.Hour*1000), c)
assert.EqualError(t, err, "root certificate is expired")
assert.PanicsWithError(t, "certificate is valid before the signing certificate", func() {
NewTestCert(Version2, Curve_P256, ca, caKey, "test", time.Time{}, time.Time{}, nil, nil, nil)
})
// Test group assertion
ca, _, caKey, _ = NewTestCaCert(Version2, Curve_P256, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{"test1", "test2"})
caPem, err := ca.MarshalPEM()
assert.Nil(t, err)
caPool = NewCAPool()
b, err := caPool.AddCAFromPEM(caPem)
assert.NoError(t, err)
assert.Empty(t, b)
assert.PanicsWithError(t, "certificate contained a group not present on the signing ca: bad", func() {
NewTestCert(Version2, Curve_P256, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, nil, []string{"test1", "bad"})
})
c, _, _, _ = NewTestCert(Version2, Curve_P256, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, nil, []string{"test1"})
_, err = caPool.VerifyCertificate(time.Now(), c)
assert.Nil(t, err)
}
func TestCertificateV2_Verify_IPs(t *testing.T) {
caIp1 := mustParsePrefixUnmapped("10.0.0.0/16")
caIp2 := mustParsePrefixUnmapped("192.168.0.0/24")
ca, _, caKey, _ := NewTestCaCert(Version2, Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), []netip.Prefix{caIp1, caIp2}, nil, []string{"test"})
caPem, err := ca.MarshalPEM()
assert.Nil(t, err)
caPool := NewCAPool()
b, err := caPool.AddCAFromPEM(caPem)
assert.NoError(t, err)
assert.Empty(t, b)
// ip is outside the network
cIp1 := mustParsePrefixUnmapped("10.1.0.0/24")
cIp2 := mustParsePrefixUnmapped("192.168.0.1/16")
assert.PanicsWithError(t, "certificate contained a network assignment outside the limitations of the signing ca: 10.1.0.0/24", func() {
NewTestCert(Version2, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{cIp1, cIp2}, nil, []string{"test"})
})
// ip is outside the network reversed order of above
cIp1 = mustParsePrefixUnmapped("192.168.0.1/24")
cIp2 = mustParsePrefixUnmapped("10.1.0.0/24")
assert.PanicsWithError(t, "certificate contained a network assignment outside the limitations of the signing ca: 10.1.0.0/24", func() {
NewTestCert(Version2, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{cIp1, cIp2}, nil, []string{"test"})
})
// ip is within the network but mask is outside
cIp1 = mustParsePrefixUnmapped("10.0.1.0/15")
cIp2 = mustParsePrefixUnmapped("192.168.0.1/24")
assert.PanicsWithError(t, "certificate contained a network assignment outside the limitations of the signing ca: 10.0.1.0/15", func() {
NewTestCert(Version2, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{cIp1, cIp2}, nil, []string{"test"})
})
// ip is within the network but mask is outside reversed order of above
cIp1 = mustParsePrefixUnmapped("192.168.0.1/24")
cIp2 = mustParsePrefixUnmapped("10.0.1.0/15")
assert.PanicsWithError(t, "certificate contained a network assignment outside the limitations of the signing ca: 10.0.1.0/15", func() {
NewTestCert(Version2, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{cIp1, cIp2}, nil, []string{"test"})
})
// ip and mask are within the network
cIp1 = mustParsePrefixUnmapped("10.0.1.0/16")
cIp2 = mustParsePrefixUnmapped("192.168.0.1/25")
c, _, _, _ := NewTestCert(Version2, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{cIp1, cIp2}, nil, []string{"test"})
_, err = caPool.VerifyCertificate(time.Now(), c)
assert.Nil(t, err)
// Exact matches
c, _, _, _ = NewTestCert(Version2, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{caIp1, caIp2}, nil, []string{"test"})
assert.Nil(t, err)
_, err = caPool.VerifyCertificate(time.Now(), c)
assert.Nil(t, err)
// Exact matches reversed
c, _, _, _ = NewTestCert(Version2, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{caIp2, caIp1}, nil, []string{"test"})
assert.Nil(t, err)
_, err = caPool.VerifyCertificate(time.Now(), c)
assert.Nil(t, err)
// Exact matches reversed with just 1
c, _, _, _ = NewTestCert(Version2, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{caIp1}, nil, []string{"test"})
assert.Nil(t, err)
_, err = caPool.VerifyCertificate(time.Now(), c)
assert.Nil(t, err)
}
func TestCertificateV2_Verify_Subnets(t *testing.T) {
caIp1 := mustParsePrefixUnmapped("10.0.0.0/16")
caIp2 := mustParsePrefixUnmapped("192.168.0.0/24")
ca, _, caKey, _ := NewTestCaCert(Version2, Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, []netip.Prefix{caIp1, caIp2}, []string{"test"})
caPem, err := ca.MarshalPEM()
assert.Nil(t, err)
caPool := NewCAPool()
b, err := caPool.AddCAFromPEM(caPem)
assert.NoError(t, err)
assert.Empty(t, b)
// ip is outside the network
cIp1 := mustParsePrefixUnmapped("10.1.0.0/24")
cIp2 := mustParsePrefixUnmapped("192.168.0.1/16")
assert.PanicsWithError(t, "certificate contained an unsafe network assignment outside the limitations of the signing ca: 10.1.0.0/24", func() {
NewTestCert(Version2, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, []netip.Prefix{cIp1, cIp2}, []string{"test"})
})
// ip is outside the network reversed order of above
cIp1 = mustParsePrefixUnmapped("192.168.0.1/24")
cIp2 = mustParsePrefixUnmapped("10.1.0.0/24")
assert.PanicsWithError(t, "certificate contained an unsafe network assignment outside the limitations of the signing ca: 10.1.0.0/24", func() {
NewTestCert(Version2, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, []netip.Prefix{cIp1, cIp2}, []string{"test"})
})
// ip is within the network but mask is outside
cIp1 = mustParsePrefixUnmapped("10.0.1.0/15")
cIp2 = mustParsePrefixUnmapped("192.168.0.1/24")
assert.PanicsWithError(t, "certificate contained an unsafe network assignment outside the limitations of the signing ca: 10.0.1.0/15", func() {
NewTestCert(Version2, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, []netip.Prefix{cIp1, cIp2}, []string{"test"})
})
// ip is within the network but mask is outside reversed order of above
cIp1 = mustParsePrefixUnmapped("192.168.0.1/24")
cIp2 = mustParsePrefixUnmapped("10.0.1.0/15")
assert.PanicsWithError(t, "certificate contained an unsafe network assignment outside the limitations of the signing ca: 10.0.1.0/15", func() {
NewTestCert(Version2, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, []netip.Prefix{cIp1, cIp2}, []string{"test"})
})
// ip and mask are within the network
cIp1 = mustParsePrefixUnmapped("10.0.1.0/16")
cIp2 = mustParsePrefixUnmapped("192.168.0.1/25")
c, _, _, _ := NewTestCert(Version2, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, []netip.Prefix{cIp1, cIp2}, []string{"test"})
assert.Nil(t, err)
_, err = caPool.VerifyCertificate(time.Now(), c)
assert.Nil(t, err)
// Exact matches
c, _, _, _ = NewTestCert(Version2, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, []netip.Prefix{caIp1, caIp2}, []string{"test"})
assert.Nil(t, err)
_, err = caPool.VerifyCertificate(time.Now(), c)
assert.Nil(t, err)
// Exact matches reversed
c, _, _, _ = NewTestCert(Version2, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, []netip.Prefix{caIp2, caIp1}, []string{"test"})
assert.Nil(t, err)
_, err = caPool.VerifyCertificate(time.Now(), c)
assert.Nil(t, err)
// Exact matches reversed with just 1
c, _, _, _ = NewTestCert(Version2, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, []netip.Prefix{caIp1}, []string{"test"})
assert.Nil(t, err)
_, err = caPool.VerifyCertificate(time.Now(), c)
assert.Nil(t, err)
}

View File

@ -1,15 +1,17 @@
package cert package cert
import ( import (
"fmt"
"net/netip" "net/netip"
"time" "time"
) )
type Version int type Version uint8
const ( const (
Version1 Version = 1 VersionPre1 Version = 0
Version2 Version = 2 Version1 Version = 1
Version2 Version = 2
) )
type Certificate interface { type Certificate interface {
@ -107,23 +109,57 @@ type CachedCertificate struct {
signerFingerprint string signerFingerprint string
} }
// UnmarshalCertificate will attempt to unmarshal a wire protocol level certificate. func (cc *CachedCertificate) String() string {
func UnmarshalCertificate(b []byte) (Certificate, error) { return cc.Certificate.String()
c, err := unmarshalCertificateV1(b, true)
if err != nil {
return nil, err
}
return c, nil
} }
// UnmarshalCertificateFromHandshake will attempt to unmarshal a certificate received in a handshake. // RecombineAndValidate will attempt to unmarshal a certificate received in a handshake.
// Handshakes save space by placing the peers public key in a different part of the packet, we have to // Handshakes save space by placing the peers public key in a different part of the packet, we have to
// reassemble the actual certificate structure with that in mind. // reassemble the actual certificate structure with that in mind.
func UnmarshalCertificateFromHandshake(b []byte, publicKey []byte) (Certificate, error) { func RecombineAndValidate(v Version, rawCertBytes, publicKey []byte, curve Curve, caPool *CAPool) (*CachedCertificate, error) {
c, err := unmarshalCertificateV1(b, false) if publicKey == nil {
return nil, ErrNoPeerStaticKey
}
if rawCertBytes == nil {
return nil, ErrNoPayload
}
c, err := unmarshalCertificateFromHandshake(v, rawCertBytes, publicKey, curve)
if err != nil {
return nil, fmt.Errorf("error unmarshaling cert: %w", err)
}
cc, err := caPool.VerifyCertificate(time.Now(), c)
if err != nil {
return nil, fmt.Errorf("certificate validation failed: %w", err)
}
return cc, nil
}
func unmarshalCertificateFromHandshake(v Version, b []byte, publicKey []byte, curve Curve) (Certificate, error) {
var c Certificate
var err error
switch v {
// Implementations must ensure the result is a valid cert!
case VersionPre1, Version1:
c, err = unmarshalCertificateV1(b, publicKey)
case Version2:
c, err = unmarshalCertificateV2(b, publicKey, curve)
default:
//TODO: CERT-V2 make a static var
return nil, fmt.Errorf("unknown certificate version %d", v)
}
if err != nil { if err != nil {
return nil, err return nil, err
} }
c.details.PublicKey = publicKey
if c.Curve() != curve {
return nil, fmt.Errorf("certificate curve %s does not match expected %s", c.Curve().String(), curve.String())
}
return c, nil return c, nil
} }

View File

@ -1,695 +0,0 @@
package cert
import (
"crypto/ecdh"
"crypto/ecdsa"
"crypto/elliptic"
"crypto/rand"
"fmt"
"io"
"net/netip"
"testing"
"time"
"github.com/slackhq/nebula/test"
"github.com/stretchr/testify/assert"
"golang.org/x/crypto/curve25519"
"golang.org/x/crypto/ed25519"
)
func TestMarshalingNebulaCertificate(t *testing.T) {
before := time.Now().Add(time.Second * -60).Round(time.Second)
after := time.Now().Add(time.Second * 60).Round(time.Second)
pubKey := []byte("1234567890abcedfghij1234567890ab")
nc := certificateV1{
details: detailsV1{
Name: "testing",
Ips: []netip.Prefix{
mustParsePrefixUnmapped("10.1.1.1/24"),
mustParsePrefixUnmapped("10.1.1.2/16"),
},
Subnets: []netip.Prefix{
mustParsePrefixUnmapped("9.1.1.2/24"),
mustParsePrefixUnmapped("9.1.1.3/16"),
},
Groups: []string{"test-group1", "test-group2", "test-group3"},
NotBefore: before,
NotAfter: after,
PublicKey: pubKey,
IsCA: false,
Issuer: "1234567890abcedfghij1234567890ab",
},
signature: []byte("1234567890abcedfghij1234567890ab"),
}
b, err := nc.Marshal()
assert.Nil(t, err)
//t.Log("Cert size:", len(b))
nc2, err := unmarshalCertificateV1(b, true)
assert.Nil(t, err)
assert.Equal(t, nc.signature, nc2.Signature())
assert.Equal(t, nc.details.Name, nc2.Name())
assert.Equal(t, nc.details.NotBefore, nc2.NotBefore())
assert.Equal(t, nc.details.NotAfter, nc2.NotAfter())
assert.Equal(t, nc.details.PublicKey, nc2.PublicKey())
assert.Equal(t, nc.details.IsCA, nc2.IsCA())
assert.Equal(t, nc.details.Ips, nc2.Networks())
assert.Equal(t, nc.details.Subnets, nc2.UnsafeNetworks())
assert.Equal(t, nc.details.Groups, nc2.Groups())
}
//func TestNebulaCertificate_Sign(t *testing.T) {
// before := time.Now().Add(time.Second * -60).Round(time.Second)
// after := time.Now().Add(time.Second * 60).Round(time.Second)
// pubKey := []byte("1234567890abcedfghij1234567890ab")
//
// nc := certificateV1{
// details: detailsV1{
// Name: "testing",
// Ips: []netip.Prefix{
// mustParsePrefixUnmapped("10.1.1.1/24"),
// mustParsePrefixUnmapped("10.1.1.2/16"),
// //TODO: netip cant do it
// //{IP: net.ParseIP("10.1.1.3"), Mask: net.IPMask(net.ParseIP("255.0.255.0"))},
// },
// Subnets: []netip.Prefix{
// //TODO: netip cant do it
// //{IP: net.ParseIP("9.1.1.1"), Mask: net.IPMask(net.ParseIP("255.0.255.0"))},
// mustParsePrefixUnmapped("9.1.1.2/24"),
// mustParsePrefixUnmapped("9.1.1.3/24"),
// },
// Groups: []string{"test-group1", "test-group2", "test-group3"},
// NotBefore: before,
// NotAfter: after,
// PublicKey: pubKey,
// IsCA: false,
// Issuer: "1234567890abcedfghij1234567890ab",
// },
// }
//
// pub, priv, err := ed25519.GenerateKey(rand.Reader)
// assert.Nil(t, err)
// assert.False(t, nc.CheckSignature(pub))
// assert.Nil(t, nc.Sign(Curve_CURVE25519, priv))
// assert.True(t, nc.CheckSignature(pub))
//
// _, err = nc.Marshal()
// assert.Nil(t, err)
// //t.Log("Cert size:", len(b))
//}
//func TestNebulaCertificate_SignP256(t *testing.T) {
// before := time.Now().Add(time.Second * -60).Round(time.Second)
// after := time.Now().Add(time.Second * 60).Round(time.Second)
// pubKey := []byte("01234567890abcedfghij1234567890ab1234567890abcedfghij1234567890ab")
//
// nc := certificateV1{
// details: detailsV1{
// Name: "testing",
// Ips: []netip.Prefix{
// mustParsePrefixUnmapped("10.1.1.1/24"),
// mustParsePrefixUnmapped("10.1.1.2/16"),
// //TODO: netip no can do
// //{IP: net.ParseIP("10.1.1.3"), Mask: net.IPMask(net.ParseIP("255.0.255.0"))},
// },
// Subnets: []netip.Prefix{
// //TODO: netip bad
// //{IP: net.ParseIP("9.1.1.1"), Mask: net.IPMask(net.ParseIP("255.0.255.0"))},
// mustParsePrefixUnmapped("9.1.1.2/24"),
// mustParsePrefixUnmapped("9.1.1.3/16"),
// },
// Groups: []string{"test-group1", "test-group2", "test-group3"},
// NotBefore: before,
// NotAfter: after,
// PublicKey: pubKey,
// IsCA: false,
// Curve: Curve_P256,
// Issuer: "1234567890abcedfghij1234567890ab",
// },
// }
//
// priv, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
// pub := elliptic.Marshal(elliptic.P256(), priv.PublicKey.X, priv.PublicKey.Y)
// rawPriv := priv.D.FillBytes(make([]byte, 32))
//
// assert.Nil(t, err)
// assert.False(t, nc.CheckSignature(pub))
// assert.Nil(t, nc.Sign(Curve_P256, rawPriv))
// assert.True(t, nc.CheckSignature(pub))
//
// _, err = nc.Marshal()
// assert.Nil(t, err)
// //t.Log("Cert size:", len(b))
//}
func TestNebulaCertificate_Expired(t *testing.T) {
nc := certificateV1{
details: detailsV1{
NotBefore: time.Now().Add(time.Second * -60).Round(time.Second),
NotAfter: time.Now().Add(time.Second * 60).Round(time.Second),
},
}
assert.True(t, nc.Expired(time.Now().Add(time.Hour)))
assert.True(t, nc.Expired(time.Now().Add(-time.Hour)))
assert.False(t, nc.Expired(time.Now()))
}
func TestNebulaCertificate_MarshalJSON(t *testing.T) {
time.Local = time.UTC
pubKey := []byte("1234567890abcedfghij1234567890ab")
nc := certificateV1{
details: detailsV1{
Name: "testing",
Ips: []netip.Prefix{
mustParsePrefixUnmapped("10.1.1.1/24"),
mustParsePrefixUnmapped("10.1.1.2/16"),
},
Subnets: []netip.Prefix{
mustParsePrefixUnmapped("9.1.1.2/24"),
mustParsePrefixUnmapped("9.1.1.3/16"),
},
Groups: []string{"test-group1", "test-group2", "test-group3"},
NotBefore: time.Date(1, 0, 0, 1, 0, 0, 0, time.UTC),
NotAfter: time.Date(1, 0, 0, 2, 0, 0, 0, time.UTC),
PublicKey: pubKey,
IsCA: false,
Issuer: "1234567890abcedfghij1234567890ab",
},
signature: []byte("1234567890abcedfghij1234567890ab"),
}
b, err := nc.MarshalJSON()
assert.Nil(t, err)
assert.Equal(
t,
"{\"details\":{\"curve\":\"CURVE25519\",\"groups\":[\"test-group1\",\"test-group2\",\"test-group3\"],\"ips\":[\"10.1.1.1/24\",\"10.1.1.2/16\"],\"isCa\":false,\"issuer\":\"1234567890abcedfghij1234567890ab\",\"name\":\"testing\",\"notAfter\":\"0000-11-30T02:00:00Z\",\"notBefore\":\"0000-11-30T01:00:00Z\",\"publicKey\":\"313233343536373839306162636564666768696a313233343536373839306162\",\"subnets\":[\"9.1.1.2/24\",\"9.1.1.3/16\"]},\"fingerprint\":\"3944c53d4267a229295b56cb2d27d459164c010ac97d655063ba421e0670f4ba\",\"signature\":\"313233343536373839306162636564666768696a313233343536373839306162\"}",
string(b),
)
}
func TestNebulaCertificate_Verify(t *testing.T) {
ca, _, caKey, err := newTestCaCert(time.Now(), time.Now().Add(10*time.Minute), nil, nil, nil)
assert.Nil(t, err)
c, _, _, err := newTestCert(ca, caKey, time.Now(), time.Now().Add(5*time.Minute), nil, nil, nil)
assert.Nil(t, err)
caPool := NewCAPool()
assert.NoError(t, caPool.AddCA(ca))
f, err := c.Fingerprint()
assert.Nil(t, err)
caPool.BlocklistFingerprint(f)
_, err = caPool.VerifyCertificate(time.Now(), c)
assert.EqualError(t, err, "certificate is in the block list")
caPool.ResetCertBlocklist()
_, err = caPool.VerifyCertificate(time.Now(), c)
assert.Nil(t, err)
_, err = caPool.VerifyCertificate(time.Now().Add(time.Hour*1000), c)
assert.EqualError(t, err, "root certificate is expired")
c, _, _, err = newTestCert(ca, caKey, time.Time{}, time.Time{}, nil, nil, nil)
assert.EqualError(t, err, "certificate is valid before the signing certificate")
// Test group assertion
ca, _, caKey, err = newTestCaCert(time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{"test1", "test2"})
assert.Nil(t, err)
caPem, err := ca.MarshalPEM()
assert.Nil(t, err)
caPool = NewCAPool()
b, err := caPool.AddCAFromPEM(caPem)
assert.NoError(t, err)
assert.Empty(t, b)
c, _, _, err = newTestCert(ca, caKey, time.Now(), time.Now().Add(5*time.Minute), nil, nil, []string{"test1", "bad"})
assert.EqualError(t, err, "certificate contained a group not present on the signing ca: bad")
c, _, _, err = newTestCert(ca, caKey, time.Now(), time.Now().Add(5*time.Minute), nil, nil, []string{"test1"})
assert.Nil(t, err)
_, err = caPool.VerifyCertificate(time.Now(), c)
assert.Nil(t, err)
}
func TestNebulaCertificate_VerifyP256(t *testing.T) {
ca, _, caKey, err := newTestCaCertP256(time.Now(), time.Now().Add(10*time.Minute), nil, nil, nil)
assert.Nil(t, err)
c, _, _, err := newTestCert(ca, caKey, time.Now(), time.Now().Add(5*time.Minute), nil, nil, nil)
assert.Nil(t, err)
caPool := NewCAPool()
assert.NoError(t, caPool.AddCA(ca))
f, err := c.Fingerprint()
assert.Nil(t, err)
caPool.BlocklistFingerprint(f)
_, err = caPool.VerifyCertificate(time.Now(), c)
assert.EqualError(t, err, "certificate is in the block list")
caPool.ResetCertBlocklist()
_, err = caPool.VerifyCertificate(time.Now(), c)
assert.Nil(t, err)
_, err = caPool.VerifyCertificate(time.Now().Add(time.Hour*1000), c)
assert.EqualError(t, err, "root certificate is expired")
c, _, _, err = newTestCert(ca, caKey, time.Time{}, time.Time{}, nil, nil, nil)
assert.EqualError(t, err, "certificate is valid before the signing certificate")
// Test group assertion
ca, _, caKey, err = newTestCaCertP256(time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{"test1", "test2"})
assert.Nil(t, err)
caPem, err := ca.MarshalPEM()
assert.Nil(t, err)
caPool = NewCAPool()
b, err := caPool.AddCAFromPEM(caPem)
assert.NoError(t, err)
assert.Empty(t, b)
c, _, _, err = newTestCert(ca, caKey, time.Now(), time.Now().Add(5*time.Minute), nil, nil, []string{"test1", "bad"})
assert.EqualError(t, err, "certificate contained a group not present on the signing ca: bad")
c, _, _, err = newTestCert(ca, caKey, time.Now(), time.Now().Add(5*time.Minute), nil, nil, []string{"test1"})
assert.Nil(t, err)
_, err = caPool.VerifyCertificate(time.Now(), c)
assert.Nil(t, err)
}
func TestNebulaCertificate_Verify_IPs(t *testing.T) {
caIp1 := mustParsePrefixUnmapped("10.0.0.0/16")
caIp2 := mustParsePrefixUnmapped("192.168.0.0/24")
ca, _, caKey, err := newTestCaCert(time.Now(), time.Now().Add(10*time.Minute), []netip.Prefix{caIp1, caIp2}, nil, []string{"test"})
assert.Nil(t, err)
caPem, err := ca.MarshalPEM()
assert.Nil(t, err)
caPool := NewCAPool()
b, err := caPool.AddCAFromPEM(caPem)
assert.NoError(t, err)
assert.Empty(t, b)
// ip is outside the network
cIp1 := mustParsePrefixUnmapped("10.1.0.0/24")
cIp2 := mustParsePrefixUnmapped("192.168.0.1/16")
c, _, _, err := newTestCert(ca, caKey, time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{cIp1, cIp2}, nil, []string{"test"})
assert.EqualError(t, err, "certificate contained a network assignment outside the limitations of the signing ca: 10.1.0.0/24")
// ip is outside the network reversed order of above
cIp1 = mustParsePrefixUnmapped("192.168.0.1/24")
cIp2 = mustParsePrefixUnmapped("10.1.0.0/24")
c, _, _, err = newTestCert(ca, caKey, time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{cIp1, cIp2}, nil, []string{"test"})
assert.EqualError(t, err, "certificate contained a network assignment outside the limitations of the signing ca: 10.1.0.0/24")
// ip is within the network but mask is outside
cIp1 = mustParsePrefixUnmapped("10.0.1.0/15")
cIp2 = mustParsePrefixUnmapped("192.168.0.1/24")
c, _, _, err = newTestCert(ca, caKey, time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{cIp1, cIp2}, nil, []string{"test"})
assert.EqualError(t, err, "certificate contained a network assignment outside the limitations of the signing ca: 10.0.1.0/15")
// ip is within the network but mask is outside reversed order of above
cIp1 = mustParsePrefixUnmapped("192.168.0.1/24")
cIp2 = mustParsePrefixUnmapped("10.0.1.0/15")
c, _, _, err = newTestCert(ca, caKey, time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{cIp1, cIp2}, nil, []string{"test"})
assert.EqualError(t, err, "certificate contained a network assignment outside the limitations of the signing ca: 10.0.1.0/15")
// ip and mask are within the network
cIp1 = mustParsePrefixUnmapped("10.0.1.0/16")
cIp2 = mustParsePrefixUnmapped("192.168.0.1/25")
c, _, _, err = newTestCert(ca, caKey, time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{cIp1, cIp2}, nil, []string{"test"})
assert.Nil(t, err)
_, err = caPool.VerifyCertificate(time.Now(), c)
assert.Nil(t, err)
// Exact matches
c, _, _, err = newTestCert(ca, caKey, time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{caIp1, caIp2}, nil, []string{"test"})
assert.Nil(t, err)
_, err = caPool.VerifyCertificate(time.Now(), c)
assert.Nil(t, err)
// Exact matches reversed
c, _, _, err = newTestCert(ca, caKey, time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{caIp2, caIp1}, nil, []string{"test"})
assert.Nil(t, err)
_, err = caPool.VerifyCertificate(time.Now(), c)
assert.Nil(t, err)
// Exact matches reversed with just 1
c, _, _, err = newTestCert(ca, caKey, time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{caIp1}, nil, []string{"test"})
assert.Nil(t, err)
_, err = caPool.VerifyCertificate(time.Now(), c)
assert.Nil(t, err)
}
func TestNebulaCertificate_Verify_Subnets(t *testing.T) {
caIp1 := mustParsePrefixUnmapped("10.0.0.0/16")
caIp2 := mustParsePrefixUnmapped("192.168.0.0/24")
ca, _, caKey, err := newTestCaCert(time.Now(), time.Now().Add(10*time.Minute), nil, []netip.Prefix{caIp1, caIp2}, []string{"test"})
assert.Nil(t, err)
caPem, err := ca.MarshalPEM()
assert.Nil(t, err)
caPool := NewCAPool()
b, err := caPool.AddCAFromPEM(caPem)
assert.NoError(t, err)
assert.Empty(t, b)
// ip is outside the network
cIp1 := mustParsePrefixUnmapped("10.1.0.0/24")
cIp2 := mustParsePrefixUnmapped("192.168.0.1/16")
c, _, _, err := newTestCert(ca, caKey, time.Now(), time.Now().Add(5*time.Minute), nil, []netip.Prefix{cIp1, cIp2}, []string{"test"})
assert.EqualError(t, err, "certificate contained an unsafe network assignment outside the limitations of the signing ca: 10.1.0.0/24")
// ip is outside the network reversed order of above
cIp1 = mustParsePrefixUnmapped("192.168.0.1/24")
cIp2 = mustParsePrefixUnmapped("10.1.0.0/24")
c, _, _, err = newTestCert(ca, caKey, time.Now(), time.Now().Add(5*time.Minute), nil, []netip.Prefix{cIp1, cIp2}, []string{"test"})
assert.EqualError(t, err, "certificate contained an unsafe network assignment outside the limitations of the signing ca: 10.1.0.0/24")
// ip is within the network but mask is outside
cIp1 = mustParsePrefixUnmapped("10.0.1.0/15")
cIp2 = mustParsePrefixUnmapped("192.168.0.1/24")
c, _, _, err = newTestCert(ca, caKey, time.Now(), time.Now().Add(5*time.Minute), nil, []netip.Prefix{cIp1, cIp2}, []string{"test"})
assert.EqualError(t, err, "certificate contained an unsafe network assignment outside the limitations of the signing ca: 10.0.1.0/15")
// ip is within the network but mask is outside reversed order of above
cIp1 = mustParsePrefixUnmapped("192.168.0.1/24")
cIp2 = mustParsePrefixUnmapped("10.0.1.0/15")
c, _, _, err = newTestCert(ca, caKey, time.Now(), time.Now().Add(5*time.Minute), nil, []netip.Prefix{cIp1, cIp2}, []string{"test"})
assert.EqualError(t, err, "certificate contained an unsafe network assignment outside the limitations of the signing ca: 10.0.1.0/15")
// ip and mask are within the network
cIp1 = mustParsePrefixUnmapped("10.0.1.0/16")
cIp2 = mustParsePrefixUnmapped("192.168.0.1/25")
c, _, _, err = newTestCert(ca, caKey, time.Now(), time.Now().Add(5*time.Minute), nil, []netip.Prefix{cIp1, cIp2}, []string{"test"})
assert.Nil(t, err)
_, err = caPool.VerifyCertificate(time.Now(), c)
assert.Nil(t, err)
// Exact matches
c, _, _, err = newTestCert(ca, caKey, time.Now(), time.Now().Add(5*time.Minute), nil, []netip.Prefix{caIp1, caIp2}, []string{"test"})
assert.Nil(t, err)
_, err = caPool.VerifyCertificate(time.Now(), c)
assert.Nil(t, err)
// Exact matches reversed
c, _, _, err = newTestCert(ca, caKey, time.Now(), time.Now().Add(5*time.Minute), nil, []netip.Prefix{caIp2, caIp1}, []string{"test"})
assert.Nil(t, err)
_, err = caPool.VerifyCertificate(time.Now(), c)
assert.Nil(t, err)
// Exact matches reversed with just 1
c, _, _, err = newTestCert(ca, caKey, time.Now(), time.Now().Add(5*time.Minute), nil, []netip.Prefix{caIp1}, []string{"test"})
assert.Nil(t, err)
_, err = caPool.VerifyCertificate(time.Now(), c)
assert.Nil(t, err)
}
func TestNebulaCertificate_VerifyPrivateKey(t *testing.T) {
ca, _, caKey, err := newTestCaCert(time.Time{}, time.Time{}, nil, nil, nil)
assert.Nil(t, err)
err = ca.VerifyPrivateKey(Curve_CURVE25519, caKey)
assert.Nil(t, err)
_, _, caKey2, err := newTestCaCert(time.Time{}, time.Time{}, nil, nil, nil)
assert.Nil(t, err)
err = ca.VerifyPrivateKey(Curve_CURVE25519, caKey2)
assert.NotNil(t, err)
c, _, priv, err := newTestCert(ca, caKey, time.Time{}, time.Time{}, nil, nil, nil)
err = c.VerifyPrivateKey(Curve_CURVE25519, priv)
assert.Nil(t, err)
_, priv2 := x25519Keypair()
err = c.VerifyPrivateKey(Curve_CURVE25519, priv2)
assert.NotNil(t, err)
}
func TestNebulaCertificate_VerifyPrivateKeyP256(t *testing.T) {
ca, _, caKey, err := newTestCaCertP256(time.Time{}, time.Time{}, nil, nil, nil)
assert.Nil(t, err)
err = ca.VerifyPrivateKey(Curve_P256, caKey)
assert.Nil(t, err)
_, _, caKey2, err := newTestCaCertP256(time.Time{}, time.Time{}, nil, nil, nil)
assert.Nil(t, err)
err = ca.VerifyPrivateKey(Curve_P256, caKey2)
assert.NotNil(t, err)
c, _, priv, err := newTestCert(ca, caKey, time.Time{}, time.Time{}, nil, nil, nil)
err = c.VerifyPrivateKey(Curve_P256, priv)
assert.Nil(t, err)
_, priv2 := p256Keypair()
err = c.VerifyPrivateKey(Curve_P256, priv2)
assert.NotNil(t, err)
}
func appendByteSlices(b ...[]byte) []byte {
retSlice := []byte{}
for _, v := range b {
retSlice = append(retSlice, v...)
}
return retSlice
}
// Ensure that upgrading the protobuf library does not change how certificates
// are marshalled, since this would break signature verification
//TODO: since netip cant represent 255.0.255.0 netmask we can't verify the old certs are ok
//func TestMarshalingNebulaCertificateConsistency(t *testing.T) {
// before := time.Date(2009, time.November, 10, 23, 0, 0, 0, time.UTC)
// after := time.Date(2017, time.January, 18, 28, 40, 0, 0, time.UTC)
// pubKey := []byte("1234567890abcedfghij1234567890ab")
//
// nc := certificateV1{
// details: detailsV1{
// Name: "testing",
// Ips: []netip.Prefix{
// mustParsePrefixUnmapped("10.1.1.1/24"),
// mustParsePrefixUnmapped("10.1.1.2/16"),
// //TODO: netip bad
// //{IP: net.ParseIP("10.1.1.3"), Mask: net.IPMask(net.ParseIP("255.0.255.0"))},
// },
// Subnets: []netip.Prefix{
// //TODO: netip bad
// //{IP: net.ParseIP("9.1.1.1"), Mask: net.IPMask(net.ParseIP("255.0.255.0"))},
// mustParsePrefixUnmapped("9.1.1.2/24"),
// mustParsePrefixUnmapped("9.1.1.3/16"),
// },
// Groups: []string{"test-group1", "test-group2", "test-group3"},
// NotBefore: before,
// NotAfter: after,
// PublicKey: pubKey,
// IsCA: false,
// Issuer: "1234567890abcedfghij1234567890ab",
// },
// signature: []byte("1234567890abcedfghij1234567890ab"),
// }
//
// b, err := nc.Marshal()
// assert.Nil(t, err)
// //t.Log("Cert size:", len(b))
// assert.Equal(t, "0aa2010a0774657374696e67121b8182845080feffff0f828284508080fcff0f8382845080fe83f80f1a1b8182844880fe83f80f8282844880feffff0f838284488080fcff0f220b746573742d67726f757031220b746573742d67726f757032220b746573742d67726f75703328f0e0e7d70430a08681c4053a20313233343536373839306162636564666768696a3132333435363738393061624a081234567890abcedf1220313233343536373839306162636564666768696a313233343536373839306162", fmt.Sprintf("%x", b))
//
// b, err = proto.Marshal(nc.getRawDetails())
// assert.Nil(t, err)
// //t.Log("Raw cert size:", len(b))
// assert.Equal(t, "0a0774657374696e67121b8182845080feffff0f828284508080fcff0f8382845080fe83f80f1a1b8182844880fe83f80f8282844880feffff0f838284488080fcff0f220b746573742d67726f757031220b746573742d67726f757032220b746573742d67726f75703328f0e0e7d70430a08681c4053a20313233343536373839306162636564666768696a3132333435363738393061624a081234567890abcedf", fmt.Sprintf("%x", b))
//}
func TestNebulaCertificate_Copy(t *testing.T) {
ca, _, caKey, err := newTestCaCert(time.Now(), time.Now().Add(10*time.Minute), nil, nil, nil)
assert.Nil(t, err)
c, _, _, err := newTestCert(ca, caKey, time.Now(), time.Now().Add(5*time.Minute), nil, nil, nil)
assert.Nil(t, err)
cc := c.Copy()
test.AssertDeepCopyEqual(t, c, cc)
}
func TestUnmarshalNebulaCertificate(t *testing.T) {
// Test that we don't panic with an invalid certificate (#332)
data := []byte("\x98\x00\x00")
_, err := unmarshalCertificateV1(data, true)
assert.EqualError(t, err, "encoded Details was nil")
}
func newTestCaCert(before, after time.Time, ips, subnets []netip.Prefix, groups []string) (Certificate, []byte, []byte, error) {
pub, priv, err := ed25519.GenerateKey(rand.Reader)
if before.IsZero() {
before = time.Now().Add(time.Second * -60).Round(time.Second)
}
if after.IsZero() {
after = time.Now().Add(time.Second * 60).Round(time.Second)
}
tbs := &TBSCertificate{
Version: Version1,
Name: "test ca",
IsCA: true,
NotBefore: time.Unix(before.Unix(), 0),
NotAfter: time.Unix(after.Unix(), 0),
PublicKey: pub,
}
if len(ips) > 0 {
tbs.Networks = ips
}
if len(subnets) > 0 {
tbs.UnsafeNetworks = subnets
}
if len(groups) > 0 {
tbs.Groups = groups
}
nc, err := tbs.Sign(nil, Curve_CURVE25519, priv)
if err != nil {
return nil, nil, nil, err
}
return nc, pub, priv, nil
}
func newTestCaCertP256(before, after time.Time, ips, subnets []netip.Prefix, groups []string) (Certificate, []byte, []byte, error) {
priv, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
pub := elliptic.Marshal(elliptic.P256(), priv.PublicKey.X, priv.PublicKey.Y)
rawPriv := priv.D.FillBytes(make([]byte, 32))
if before.IsZero() {
before = time.Now().Add(time.Second * -60).Round(time.Second)
}
if after.IsZero() {
after = time.Now().Add(time.Second * 60).Round(time.Second)
}
tbs := &TBSCertificate{
Version: Version1,
Name: "test ca",
IsCA: true,
NotBefore: time.Unix(before.Unix(), 0),
NotAfter: time.Unix(after.Unix(), 0),
PublicKey: pub,
Curve: Curve_P256,
}
if len(ips) > 0 {
tbs.Networks = ips
}
if len(subnets) > 0 {
tbs.UnsafeNetworks = subnets
}
if len(groups) > 0 {
tbs.Groups = groups
}
nc, err := tbs.Sign(nil, Curve_P256, rawPriv)
if err != nil {
return nil, nil, nil, err
}
return nc, pub, rawPriv, nil
}
func newTestCert(ca Certificate, key []byte, before, after time.Time, ips, subnets []netip.Prefix, groups []string) (Certificate, []byte, []byte, error) {
if before.IsZero() {
before = time.Now().Add(time.Second * -60).Round(time.Second)
}
if after.IsZero() {
after = time.Now().Add(time.Second * 60).Round(time.Second)
}
if len(groups) == 0 {
groups = []string{"test-group1", "test-group2", "test-group3"}
}
if len(ips) == 0 {
ips = []netip.Prefix{
mustParsePrefixUnmapped("10.1.1.1/24"),
mustParsePrefixUnmapped("10.1.1.2/16"),
}
}
if len(subnets) == 0 {
subnets = []netip.Prefix{
mustParsePrefixUnmapped("9.1.1.2/24"),
mustParsePrefixUnmapped("9.1.1.3/16"),
}
}
var pub, rawPriv []byte
switch ca.Curve() {
case Curve_CURVE25519:
pub, rawPriv = x25519Keypair()
case Curve_P256:
pub, rawPriv = p256Keypair()
default:
return nil, nil, nil, fmt.Errorf("unknown curve: %v", ca.Curve())
}
tbs := &TBSCertificate{
Version: Version1,
Name: "testing",
Networks: ips,
UnsafeNetworks: subnets,
Groups: groups,
IsCA: false,
NotBefore: time.Unix(before.Unix(), 0),
NotAfter: time.Unix(after.Unix(), 0),
PublicKey: pub,
Curve: ca.Curve(),
}
nc, err := tbs.Sign(ca, ca.Curve(), key)
if err != nil {
return nil, nil, nil, err
}
return nc, pub, rawPriv, nil
}
func x25519Keypair() ([]byte, []byte) {
privkey := make([]byte, 32)
if _, err := io.ReadFull(rand.Reader, privkey); err != nil {
panic(err)
}
pubkey, err := curve25519.X25519(privkey, curve25519.Basepoint)
if err != nil {
panic(err)
}
return pubkey, privkey
}
func p256Keypair() ([]byte, []byte) {
privkey, err := ecdh.P256().GenerateKey(rand.Reader)
if err != nil {
panic(err)
}
pubkey := privkey.PublicKey()
return pubkey.Bytes(), privkey.Bytes()
}
func mustParsePrefixUnmapped(s string) netip.Prefix {
prefix := netip.MustParsePrefix(s)
return netip.PrefixFrom(prefix.Addr().Unmap(), prefix.Bits())
}

View File

@ -6,19 +6,16 @@ import (
"crypto/ecdsa" "crypto/ecdsa"
"crypto/ed25519" "crypto/ed25519"
"crypto/elliptic" "crypto/elliptic"
"crypto/rand"
"crypto/sha256" "crypto/sha256"
"encoding/binary" "encoding/binary"
"encoding/hex" "encoding/hex"
"encoding/json" "encoding/json"
"encoding/pem" "encoding/pem"
"fmt" "fmt"
"math/big"
"net" "net"
"net/netip" "net/netip"
"time" "time"
"github.com/slackhq/nebula/pkclient"
"golang.org/x/crypto/curve25519" "golang.org/x/crypto/curve25519"
"google.golang.org/protobuf/proto" "google.golang.org/protobuf/proto"
) )
@ -31,71 +28,71 @@ type certificateV1 struct {
} }
type detailsV1 struct { type detailsV1 struct {
Name string name string
Ips []netip.Prefix networks []netip.Prefix
Subnets []netip.Prefix unsafeNetworks []netip.Prefix
Groups []string groups []string
NotBefore time.Time notBefore time.Time
NotAfter time.Time notAfter time.Time
PublicKey []byte publicKey []byte
IsCA bool isCA bool
Issuer string issuer string
Curve Curve curve Curve
} }
type m map[string]interface{} type m map[string]interface{}
func (nc *certificateV1) Version() Version { func (c *certificateV1) Version() Version {
return Version1 return Version1
} }
func (nc *certificateV1) Curve() Curve { func (c *certificateV1) Curve() Curve {
return nc.details.Curve return c.details.curve
} }
func (nc *certificateV1) Groups() []string { func (c *certificateV1) Groups() []string {
return nc.details.Groups return c.details.groups
} }
func (nc *certificateV1) IsCA() bool { func (c *certificateV1) IsCA() bool {
return nc.details.IsCA return c.details.isCA
} }
func (nc *certificateV1) Issuer() string { func (c *certificateV1) Issuer() string {
return nc.details.Issuer return c.details.issuer
} }
func (nc *certificateV1) Name() string { func (c *certificateV1) Name() string {
return nc.details.Name return c.details.name
} }
func (nc *certificateV1) Networks() []netip.Prefix { func (c *certificateV1) Networks() []netip.Prefix {
return nc.details.Ips return c.details.networks
} }
func (nc *certificateV1) NotAfter() time.Time { func (c *certificateV1) NotAfter() time.Time {
return nc.details.NotAfter return c.details.notAfter
} }
func (nc *certificateV1) NotBefore() time.Time { func (c *certificateV1) NotBefore() time.Time {
return nc.details.NotBefore return c.details.notBefore
} }
func (nc *certificateV1) PublicKey() []byte { func (c *certificateV1) PublicKey() []byte {
return nc.details.PublicKey return c.details.publicKey
} }
func (nc *certificateV1) Signature() []byte { func (c *certificateV1) Signature() []byte {
return nc.signature return c.signature
} }
func (nc *certificateV1) UnsafeNetworks() []netip.Prefix { func (c *certificateV1) UnsafeNetworks() []netip.Prefix {
return nc.details.Subnets return c.details.unsafeNetworks
} }
func (nc *certificateV1) Fingerprint() (string, error) { func (c *certificateV1) Fingerprint() (string, error) {
b, err := nc.Marshal() b, err := c.Marshal()
if err != nil { if err != nil {
return "", err return "", err
} }
@ -104,33 +101,33 @@ func (nc *certificateV1) Fingerprint() (string, error) {
return hex.EncodeToString(sum[:]), nil return hex.EncodeToString(sum[:]), nil
} }
func (nc *certificateV1) CheckSignature(key []byte) bool { func (c *certificateV1) CheckSignature(key []byte) bool {
b, err := proto.Marshal(nc.getRawDetails()) b, err := proto.Marshal(c.getRawDetails())
if err != nil { if err != nil {
return false return false
} }
switch nc.details.Curve { switch c.details.curve {
case Curve_CURVE25519: case Curve_CURVE25519:
return ed25519.Verify(key, b, nc.signature) return ed25519.Verify(key, b, c.signature)
case Curve_P256: case Curve_P256:
x, y := elliptic.Unmarshal(elliptic.P256(), key) x, y := elliptic.Unmarshal(elliptic.P256(), key)
pubKey := &ecdsa.PublicKey{Curve: elliptic.P256(), X: x, Y: y} pubKey := &ecdsa.PublicKey{Curve: elliptic.P256(), X: x, Y: y}
hashed := sha256.Sum256(b) hashed := sha256.Sum256(b)
return ecdsa.VerifyASN1(pubKey, hashed[:], nc.signature) return ecdsa.VerifyASN1(pubKey, hashed[:], c.signature)
default: default:
return false return false
} }
} }
func (nc *certificateV1) Expired(t time.Time) bool { func (c *certificateV1) Expired(t time.Time) bool {
return nc.details.NotBefore.After(t) || nc.details.NotAfter.Before(t) return c.details.notBefore.After(t) || c.details.notAfter.Before(t)
} }
func (nc *certificateV1) VerifyPrivateKey(curve Curve, key []byte) error { func (c *certificateV1) VerifyPrivateKey(curve Curve, key []byte) error {
if curve != nc.details.Curve { if curve != c.details.curve {
return fmt.Errorf("curve in cert and private key supplied don't match") return fmt.Errorf("curve in cert and private key supplied don't match")
} }
if nc.details.IsCA { if c.details.isCA {
switch curve { switch curve {
case Curve_CURVE25519: case Curve_CURVE25519:
// the call to PublicKey below will panic slice bounds out of range otherwise // the call to PublicKey below will panic slice bounds out of range otherwise
@ -138,7 +135,7 @@ func (nc *certificateV1) VerifyPrivateKey(curve Curve, key []byte) error {
return fmt.Errorf("key was not 64 bytes, is invalid ed25519 private key") return fmt.Errorf("key was not 64 bytes, is invalid ed25519 private key")
} }
if !ed25519.PublicKey(nc.details.PublicKey).Equal(ed25519.PrivateKey(key).Public()) { if !ed25519.PublicKey(c.details.publicKey).Equal(ed25519.PrivateKey(key).Public()) {
return fmt.Errorf("public key in cert and private key supplied don't match") return fmt.Errorf("public key in cert and private key supplied don't match")
} }
case Curve_P256: case Curve_P256:
@ -147,7 +144,7 @@ func (nc *certificateV1) VerifyPrivateKey(curve Curve, key []byte) error {
return fmt.Errorf("cannot parse private key as P256: %w", err) return fmt.Errorf("cannot parse private key as P256: %w", err)
} }
pub := privkey.PublicKey().Bytes() pub := privkey.PublicKey().Bytes()
if !bytes.Equal(pub, nc.details.PublicKey) { if !bytes.Equal(pub, c.details.publicKey) {
return fmt.Errorf("public key in cert and private key supplied don't match") return fmt.Errorf("public key in cert and private key supplied don't match")
} }
default: default:
@ -173,7 +170,7 @@ func (nc *certificateV1) VerifyPrivateKey(curve Curve, key []byte) error {
default: default:
return fmt.Errorf("invalid curve: %s", curve) return fmt.Errorf("invalid curve: %s", curve)
} }
if !bytes.Equal(pub, nc.details.PublicKey) { if !bytes.Equal(pub, c.details.publicKey) {
return fmt.Errorf("public key in cert and private key supplied don't match") return fmt.Errorf("public key in cert and private key supplied don't match")
} }
@ -181,173 +178,219 @@ func (nc *certificateV1) VerifyPrivateKey(curve Curve, key []byte) error {
} }
// getRawDetails marshals the raw details into protobuf ready struct // getRawDetails marshals the raw details into protobuf ready struct
func (nc *certificateV1) getRawDetails() *RawNebulaCertificateDetails { func (c *certificateV1) getRawDetails() *RawNebulaCertificateDetails {
rd := &RawNebulaCertificateDetails{ rd := &RawNebulaCertificateDetails{
Name: nc.details.Name, Name: c.details.name,
Groups: nc.details.Groups, Groups: c.details.groups,
NotBefore: nc.details.NotBefore.Unix(), NotBefore: c.details.notBefore.Unix(),
NotAfter: nc.details.NotAfter.Unix(), NotAfter: c.details.notAfter.Unix(),
PublicKey: make([]byte, len(nc.details.PublicKey)), PublicKey: make([]byte, len(c.details.publicKey)),
IsCA: nc.details.IsCA, IsCA: c.details.isCA,
Curve: nc.details.Curve, Curve: c.details.curve,
} }
for _, ipNet := range nc.details.Ips { for _, ipNet := range c.details.networks {
mask := net.CIDRMask(ipNet.Bits(), ipNet.Addr().BitLen()) mask := net.CIDRMask(ipNet.Bits(), ipNet.Addr().BitLen())
rd.Ips = append(rd.Ips, addr2int(ipNet.Addr()), ip2int(mask)) rd.Ips = append(rd.Ips, addr2int(ipNet.Addr()), ip2int(mask))
} }
for _, ipNet := range nc.details.Subnets { for _, ipNet := range c.details.unsafeNetworks {
mask := net.CIDRMask(ipNet.Bits(), ipNet.Addr().BitLen()) mask := net.CIDRMask(ipNet.Bits(), ipNet.Addr().BitLen())
rd.Subnets = append(rd.Subnets, addr2int(ipNet.Addr()), ip2int(mask)) rd.Subnets = append(rd.Subnets, addr2int(ipNet.Addr()), ip2int(mask))
} }
copy(rd.PublicKey, nc.details.PublicKey[:]) copy(rd.PublicKey, c.details.publicKey[:])
// I know, this is terrible // I know, this is terrible
rd.Issuer, _ = hex.DecodeString(nc.details.Issuer) rd.Issuer, _ = hex.DecodeString(c.details.issuer)
return rd return rd
} }
func (nc *certificateV1) String() string { func (c *certificateV1) String() string {
if nc == nil { b, err := json.MarshalIndent(c.marshalJSON(), "", "\t")
return "Certificate {}\n" if err != nil {
return fmt.Sprintf("<error marshalling certificate: %v>", err)
} }
return string(b)
s := "NebulaCertificate {\n"
s += "\tDetails {\n"
s += fmt.Sprintf("\t\tName: %v\n", nc.details.Name)
if len(nc.details.Ips) > 0 {
s += "\t\tIps: [\n"
for _, ip := range nc.details.Ips {
s += fmt.Sprintf("\t\t\t%v\n", ip.String())
}
s += "\t\t]\n"
} else {
s += "\t\tIps: []\n"
}
if len(nc.details.Subnets) > 0 {
s += "\t\tSubnets: [\n"
for _, ip := range nc.details.Subnets {
s += fmt.Sprintf("\t\t\t%v\n", ip.String())
}
s += "\t\t]\n"
} else {
s += "\t\tSubnets: []\n"
}
if len(nc.details.Groups) > 0 {
s += "\t\tGroups: [\n"
for _, g := range nc.details.Groups {
s += fmt.Sprintf("\t\t\t\"%v\"\n", g)
}
s += "\t\t]\n"
} else {
s += "\t\tGroups: []\n"
}
s += fmt.Sprintf("\t\tNot before: %v\n", nc.details.NotBefore)
s += fmt.Sprintf("\t\tNot After: %v\n", nc.details.NotAfter)
s += fmt.Sprintf("\t\tIs CA: %v\n", nc.details.IsCA)
s += fmt.Sprintf("\t\tIssuer: %s\n", nc.details.Issuer)
s += fmt.Sprintf("\t\tPublic key: %x\n", nc.details.PublicKey)
s += fmt.Sprintf("\t\tCurve: %s\n", nc.details.Curve)
s += "\t}\n"
fp, err := nc.Fingerprint()
if err == nil {
s += fmt.Sprintf("\tFingerprint: %s\n", fp)
}
s += fmt.Sprintf("\tSignature: %x\n", nc.Signature())
s += "}"
return s
} }
func (nc *certificateV1) MarshalForHandshakes() ([]byte, error) { func (c *certificateV1) MarshalForHandshakes() ([]byte, error) {
pubKey := nc.details.PublicKey pubKey := c.details.publicKey
nc.details.PublicKey = nil c.details.publicKey = nil
rawCertNoKey, err := nc.Marshal() rawCertNoKey, err := c.Marshal()
if err != nil { if err != nil {
return nil, err return nil, err
} }
nc.details.PublicKey = pubKey c.details.publicKey = pubKey
return rawCertNoKey, nil return rawCertNoKey, nil
} }
func (nc *certificateV1) Marshal() ([]byte, error) { func (c *certificateV1) Marshal() ([]byte, error) {
rc := RawNebulaCertificate{ rc := RawNebulaCertificate{
Details: nc.getRawDetails(), Details: c.getRawDetails(),
Signature: nc.signature, Signature: c.signature,
} }
return proto.Marshal(&rc) return proto.Marshal(&rc)
} }
func (nc *certificateV1) MarshalPEM() ([]byte, error) { func (c *certificateV1) MarshalPEM() ([]byte, error) {
b, err := nc.Marshal() b, err := c.Marshal()
if err != nil { if err != nil {
return nil, err return nil, err
} }
return pem.EncodeToMemory(&pem.Block{Type: CertificateBanner, Bytes: b}), nil return pem.EncodeToMemory(&pem.Block{Type: CertificateBanner, Bytes: b}), nil
} }
func (nc *certificateV1) MarshalJSON() ([]byte, error) { func (c *certificateV1) MarshalJSON() ([]byte, error) {
fp, _ := nc.Fingerprint() return json.Marshal(c.marshalJSON())
jc := m{
"details": m{
"name": nc.details.Name,
"ips": nc.details.Ips,
"subnets": nc.details.Subnets,
"groups": nc.details.Groups,
"notBefore": nc.details.NotBefore,
"notAfter": nc.details.NotAfter,
"publicKey": fmt.Sprintf("%x", nc.details.PublicKey),
"isCa": nc.details.IsCA,
"issuer": nc.details.Issuer,
"curve": nc.details.Curve.String(),
},
"fingerprint": fp,
"signature": fmt.Sprintf("%x", nc.Signature()),
}
return json.Marshal(jc)
} }
func (nc *certificateV1) Copy() Certificate { func (c *certificateV1) marshalJSON() m {
c := &certificateV1{ fp, _ := c.Fingerprint()
details: detailsV1{ return m{
Name: nc.details.Name, "version": Version1,
Groups: make([]string, len(nc.details.Groups)), "details": m{
Ips: make([]netip.Prefix, len(nc.details.Ips)), "name": c.details.name,
Subnets: make([]netip.Prefix, len(nc.details.Subnets)), "networks": c.details.networks,
NotBefore: nc.details.NotBefore, "unsafeNetworks": c.details.unsafeNetworks,
NotAfter: nc.details.NotAfter, "groups": c.details.groups,
PublicKey: make([]byte, len(nc.details.PublicKey)), "notBefore": c.details.notBefore,
IsCA: nc.details.IsCA, "notAfter": c.details.notAfter,
Issuer: nc.details.Issuer, "publicKey": fmt.Sprintf("%x", c.details.publicKey),
"isCa": c.details.isCA,
"issuer": c.details.issuer,
"curve": c.details.curve.String(),
}, },
signature: make([]byte, len(nc.signature)), "fingerprint": fp,
"signature": fmt.Sprintf("%x", c.Signature()),
}
}
func (c *certificateV1) Copy() Certificate {
nc := &certificateV1{
details: detailsV1{
name: c.details.name,
notBefore: c.details.notBefore,
notAfter: c.details.notAfter,
publicKey: make([]byte, len(c.details.publicKey)),
isCA: c.details.isCA,
issuer: c.details.issuer,
curve: c.details.curve,
},
signature: make([]byte, len(c.signature)),
} }
copy(c.signature, nc.signature) if c.details.groups != nil {
copy(c.details.Groups, nc.details.Groups) nc.details.groups = make([]string, len(c.details.groups))
copy(c.details.PublicKey, nc.details.PublicKey) copy(nc.details.groups, c.details.groups)
for i, p := range nc.details.Ips {
c.details.Ips[i] = p
} }
for i, p := range nc.details.Subnets { if c.details.networks != nil {
c.details.Subnets[i] = p nc.details.networks = make([]netip.Prefix, len(c.details.networks))
copy(nc.details.networks, c.details.networks)
} }
return c if c.details.unsafeNetworks != nil {
nc.details.unsafeNetworks = make([]netip.Prefix, len(c.details.unsafeNetworks))
copy(nc.details.unsafeNetworks, c.details.unsafeNetworks)
}
copy(nc.signature, c.signature)
copy(nc.details.publicKey, c.details.publicKey)
return nc
}
func (c *certificateV1) fromTBSCertificate(t *TBSCertificate) error {
c.details = detailsV1{
name: t.Name,
networks: t.Networks,
unsafeNetworks: t.UnsafeNetworks,
groups: t.Groups,
notBefore: t.NotBefore,
notAfter: t.NotAfter,
publicKey: t.PublicKey,
isCA: t.IsCA,
curve: t.Curve,
issuer: t.issuer,
}
return c.validate()
}
func (c *certificateV1) validate() error {
// Empty names are allowed
if len(c.details.publicKey) == 0 {
return ErrInvalidPublicKey
}
// Original v1 rules allowed multiple networks to be present but ignored all but the first one.
// Continue to allow this behavior
if !c.details.isCA && len(c.details.networks) == 0 {
return NewErrInvalidCertificateProperties("non-CA certificates must contain exactly one network")
}
for _, network := range c.details.networks {
if !network.IsValid() || !network.Addr().IsValid() {
return NewErrInvalidCertificateProperties("invalid network: %s", network)
}
if network.Addr().Is6() {
return NewErrInvalidCertificateProperties("certificate may not contain IPv6 networks: %v", network)
}
if network.Addr().IsUnspecified() {
return NewErrInvalidCertificateProperties("non-CA certificates must not use the zero address as a network: %s", network)
}
if network.Addr().Zone() != "" {
return NewErrInvalidCertificateProperties("networks may not contain zones: %s", network)
}
}
for _, network := range c.details.unsafeNetworks {
if !network.IsValid() || !network.Addr().IsValid() {
return NewErrInvalidCertificateProperties("invalid unsafe network: %s", network)
}
if network.Addr().Is6() {
return NewErrInvalidCertificateProperties("certificate may not contain IPv6 unsafe networks: %v", network)
}
if network.Addr().Zone() != "" {
return NewErrInvalidCertificateProperties("unsafe networks may not contain zones: %s", network)
}
}
// v1 doesn't bother with sort order or uniqueness of networks or unsafe networks.
// We can't modify the unmarshalled data because verification requires re-marshalling and a re-ordered
// unsafe networks would result in a different signature.
return nil
}
func (c *certificateV1) marshalForSigning() ([]byte, error) {
b, err := proto.Marshal(c.getRawDetails())
if err != nil {
return nil, err
}
return b, nil
}
func (c *certificateV1) setSignature(b []byte) error {
if len(b) == 0 {
return ErrEmptySignature
}
c.signature = b
return nil
} }
// unmarshalCertificateV1 will unmarshal a protobuf byte representation of a nebula cert // unmarshalCertificateV1 will unmarshal a protobuf byte representation of a nebula cert
func unmarshalCertificateV1(b []byte, assertPublicKey bool) (*certificateV1, error) { // if the publicKey is provided here then it is not required to be present in `b`
func unmarshalCertificateV1(b []byte, publicKey []byte) (*certificateV1, error) {
if len(b) == 0 { if len(b) == 0 {
return nil, fmt.Errorf("nil byte array") return nil, fmt.Errorf("nil byte array")
} }
@ -371,27 +414,28 @@ func unmarshalCertificateV1(b []byte, assertPublicKey bool) (*certificateV1, err
nc := certificateV1{ nc := certificateV1{
details: detailsV1{ details: detailsV1{
Name: rc.Details.Name, name: rc.Details.Name,
Groups: make([]string, len(rc.Details.Groups)), groups: make([]string, len(rc.Details.Groups)),
Ips: make([]netip.Prefix, len(rc.Details.Ips)/2), networks: make([]netip.Prefix, len(rc.Details.Ips)/2),
Subnets: make([]netip.Prefix, len(rc.Details.Subnets)/2), unsafeNetworks: make([]netip.Prefix, len(rc.Details.Subnets)/2),
NotBefore: time.Unix(rc.Details.NotBefore, 0), notBefore: time.Unix(rc.Details.NotBefore, 0),
NotAfter: time.Unix(rc.Details.NotAfter, 0), notAfter: time.Unix(rc.Details.NotAfter, 0),
PublicKey: make([]byte, len(rc.Details.PublicKey)), publicKey: make([]byte, len(rc.Details.PublicKey)),
IsCA: rc.Details.IsCA, isCA: rc.Details.IsCA,
Curve: rc.Details.Curve, curve: rc.Details.Curve,
}, },
signature: make([]byte, len(rc.Signature)), signature: make([]byte, len(rc.Signature)),
} }
copy(nc.signature, rc.Signature) copy(nc.signature, rc.Signature)
copy(nc.details.Groups, rc.Details.Groups) copy(nc.details.groups, rc.Details.Groups)
nc.details.Issuer = hex.EncodeToString(rc.Details.Issuer) nc.details.issuer = hex.EncodeToString(rc.Details.Issuer)
if len(rc.Details.PublicKey) < publicKeyLen && assertPublicKey { if len(publicKey) > 0 {
return nil, fmt.Errorf("public key was fewer than 32 bytes; %v", len(rc.Details.PublicKey)) nc.details.publicKey = publicKey
} }
copy(nc.details.PublicKey, rc.Details.PublicKey)
copy(nc.details.publicKey, rc.Details.PublicKey)
var ip netip.Addr var ip netip.Addr
for i, rawIp := range rc.Details.Ips { for i, rawIp := range rc.Details.Ips {
@ -399,7 +443,7 @@ func unmarshalCertificateV1(b []byte, assertPublicKey bool) (*certificateV1, err
ip = int2addr(rawIp) ip = int2addr(rawIp)
} else { } else {
ones, _ := net.IPMask(int2ip(rawIp)).Size() ones, _ := net.IPMask(int2ip(rawIp)).Size()
nc.details.Ips[i/2] = netip.PrefixFrom(ip, ones) nc.details.networks[i/2] = netip.PrefixFrom(ip, ones)
} }
} }
@ -408,67 +452,16 @@ func unmarshalCertificateV1(b []byte, assertPublicKey bool) (*certificateV1, err
ip = int2addr(rawIp) ip = int2addr(rawIp)
} else { } else {
ones, _ := net.IPMask(int2ip(rawIp)).Size() ones, _ := net.IPMask(int2ip(rawIp)).Size()
nc.details.Subnets[i/2] = netip.PrefixFrom(ip, ones) nc.details.unsafeNetworks[i/2] = netip.PrefixFrom(ip, ones)
} }
} }
return &nc, nil err = nc.validate()
}
func signV1(t *TBSCertificate, curve Curve, key []byte, client *pkclient.PKClient) (*certificateV1, error) {
c := &certificateV1{
details: detailsV1{
Name: t.Name,
Ips: t.Networks,
Subnets: t.UnsafeNetworks,
Groups: t.Groups,
NotBefore: t.NotBefore,
NotAfter: t.NotAfter,
PublicKey: t.PublicKey,
IsCA: t.IsCA,
Curve: t.Curve,
Issuer: t.issuer,
},
}
b, err := proto.Marshal(c.getRawDetails())
if err != nil { if err != nil {
return nil, err return nil, err
} }
var sig []byte return &nc, nil
switch curve {
case Curve_CURVE25519:
signer := ed25519.PrivateKey(key)
sig = ed25519.Sign(signer, b)
case Curve_P256:
if client != nil {
sig, err = client.SignASN1(b)
} else {
signer := &ecdsa.PrivateKey{
PublicKey: ecdsa.PublicKey{
Curve: elliptic.P256(),
},
// ref: https://github.com/golang/go/blob/go1.19/src/crypto/x509/sec1.go#L95
D: new(big.Int).SetBytes(key),
}
// ref: https://github.com/golang/go/blob/go1.19/src/crypto/x509/sec1.go#L119
signer.X, signer.Y = signer.Curve.ScalarBaseMult(key)
// We need to hash first for ECDSA
// - https://pkg.go.dev/crypto/ecdsa#SignASN1
hashed := sha256.Sum256(b)
sig, err = ecdsa.SignASN1(rand.Reader, signer, hashed[:])
if err != nil {
return nil, err
}
}
default:
return nil, fmt.Errorf("invalid curve: %s", c.details.Curve)
}
c.signature = sig
return c, nil
} }
func ip2int(ip []byte) uint32 { func ip2int(ip []byte) uint32 {

218
cert/cert_v1_test.go Normal file
View File

@ -0,0 +1,218 @@
package cert
import (
"fmt"
"net/netip"
"testing"
"time"
"github.com/slackhq/nebula/test"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"google.golang.org/protobuf/proto"
)
func TestCertificateV1_Marshal(t *testing.T) {
before := time.Now().Add(time.Second * -60).Round(time.Second)
after := time.Now().Add(time.Second * 60).Round(time.Second)
pubKey := []byte("1234567890abcedfghij1234567890ab")
nc := certificateV1{
details: detailsV1{
name: "testing",
networks: []netip.Prefix{
mustParsePrefixUnmapped("10.1.1.1/24"),
mustParsePrefixUnmapped("10.1.1.2/16"),
},
unsafeNetworks: []netip.Prefix{
mustParsePrefixUnmapped("9.1.1.2/24"),
mustParsePrefixUnmapped("9.1.1.3/16"),
},
groups: []string{"test-group1", "test-group2", "test-group3"},
notBefore: before,
notAfter: after,
publicKey: pubKey,
isCA: false,
issuer: "1234567890abcedfghij1234567890ab",
},
signature: []byte("1234567890abcedfghij1234567890ab"),
}
b, err := nc.Marshal()
assert.Nil(t, err)
//t.Log("Cert size:", len(b))
nc2, err := unmarshalCertificateV1(b, nil)
assert.Nil(t, err)
assert.Equal(t, nc.Version(), Version1)
assert.Equal(t, nc.Curve(), Curve_CURVE25519)
assert.Equal(t, nc.Signature(), nc2.Signature())
assert.Equal(t, nc.Name(), nc2.Name())
assert.Equal(t, nc.NotBefore(), nc2.NotBefore())
assert.Equal(t, nc.NotAfter(), nc2.NotAfter())
assert.Equal(t, nc.PublicKey(), nc2.PublicKey())
assert.Equal(t, nc.IsCA(), nc2.IsCA())
assert.Equal(t, nc.Networks(), nc2.Networks())
assert.Equal(t, nc.UnsafeNetworks(), nc2.UnsafeNetworks())
assert.Equal(t, nc.Groups(), nc2.Groups())
}
func TestCertificateV1_Expired(t *testing.T) {
nc := certificateV1{
details: detailsV1{
notBefore: time.Now().Add(time.Second * -60).Round(time.Second),
notAfter: time.Now().Add(time.Second * 60).Round(time.Second),
},
}
assert.True(t, nc.Expired(time.Now().Add(time.Hour)))
assert.True(t, nc.Expired(time.Now().Add(-time.Hour)))
assert.False(t, nc.Expired(time.Now()))
}
func TestCertificateV1_MarshalJSON(t *testing.T) {
time.Local = time.UTC
pubKey := []byte("1234567890abcedfghij1234567890ab")
nc := certificateV1{
details: detailsV1{
name: "testing",
networks: []netip.Prefix{
mustParsePrefixUnmapped("10.1.1.1/24"),
mustParsePrefixUnmapped("10.1.1.2/16"),
},
unsafeNetworks: []netip.Prefix{
mustParsePrefixUnmapped("9.1.1.2/24"),
mustParsePrefixUnmapped("9.1.1.3/16"),
},
groups: []string{"test-group1", "test-group2", "test-group3"},
notBefore: time.Date(1, 0, 0, 1, 0, 0, 0, time.UTC),
notAfter: time.Date(1, 0, 0, 2, 0, 0, 0, time.UTC),
publicKey: pubKey,
isCA: false,
issuer: "1234567890abcedfghij1234567890ab",
},
signature: []byte("1234567890abcedfghij1234567890ab"),
}
b, err := nc.MarshalJSON()
assert.Nil(t, err)
assert.Equal(
t,
"{\"details\":{\"curve\":\"CURVE25519\",\"groups\":[\"test-group1\",\"test-group2\",\"test-group3\"],\"isCa\":false,\"issuer\":\"1234567890abcedfghij1234567890ab\",\"name\":\"testing\",\"networks\":[\"10.1.1.1/24\",\"10.1.1.2/16\"],\"notAfter\":\"0000-11-30T02:00:00Z\",\"notBefore\":\"0000-11-30T01:00:00Z\",\"publicKey\":\"313233343536373839306162636564666768696a313233343536373839306162\",\"unsafeNetworks\":[\"9.1.1.2/24\",\"9.1.1.3/16\"]},\"fingerprint\":\"3944c53d4267a229295b56cb2d27d459164c010ac97d655063ba421e0670f4ba\",\"signature\":\"313233343536373839306162636564666768696a313233343536373839306162\",\"version\":1}",
string(b),
)
}
func TestCertificateV1_VerifyPrivateKey(t *testing.T) {
ca, _, caKey, _ := NewTestCaCert(Version1, Curve_CURVE25519, time.Time{}, time.Time{}, nil, nil, nil)
err := ca.VerifyPrivateKey(Curve_CURVE25519, caKey)
assert.Nil(t, err)
_, _, caKey2, _ := NewTestCaCert(Version1, Curve_CURVE25519, time.Time{}, time.Time{}, nil, nil, nil)
assert.Nil(t, err)
err = ca.VerifyPrivateKey(Curve_CURVE25519, caKey2)
assert.NotNil(t, err)
c, _, priv, _ := NewTestCert(Version1, Curve_CURVE25519, ca, caKey, "test", time.Time{}, time.Time{}, nil, nil, nil)
rawPriv, b, curve, err := UnmarshalPrivateKeyFromPEM(priv)
assert.NoError(t, err)
assert.Empty(t, b)
assert.Equal(t, Curve_CURVE25519, curve)
err = c.VerifyPrivateKey(Curve_CURVE25519, rawPriv)
assert.Nil(t, err)
_, priv2 := X25519Keypair()
err = c.VerifyPrivateKey(Curve_CURVE25519, priv2)
assert.NotNil(t, err)
}
func TestCertificateV1_VerifyPrivateKeyP256(t *testing.T) {
ca, _, caKey, _ := NewTestCaCert(Version1, Curve_P256, time.Time{}, time.Time{}, nil, nil, nil)
err := ca.VerifyPrivateKey(Curve_P256, caKey)
assert.Nil(t, err)
_, _, caKey2, _ := NewTestCaCert(Version1, Curve_P256, time.Time{}, time.Time{}, nil, nil, nil)
assert.Nil(t, err)
err = ca.VerifyPrivateKey(Curve_P256, caKey2)
assert.NotNil(t, err)
c, _, priv, _ := NewTestCert(Version1, Curve_P256, ca, caKey, "test", time.Time{}, time.Time{}, nil, nil, nil)
rawPriv, b, curve, err := UnmarshalPrivateKeyFromPEM(priv)
assert.NoError(t, err)
assert.Empty(t, b)
assert.Equal(t, Curve_P256, curve)
err = c.VerifyPrivateKey(Curve_P256, rawPriv)
assert.Nil(t, err)
_, priv2 := P256Keypair()
err = c.VerifyPrivateKey(Curve_P256, priv2)
assert.NotNil(t, err)
}
// Ensure that upgrading the protobuf library does not change how certificates
// are marshalled, since this would break signature verification
func TestMarshalingCertificateV1Consistency(t *testing.T) {
before := time.Date(1970, time.January, 1, 1, 1, 1, 1, time.UTC)
after := time.Date(9999, time.January, 1, 1, 1, 1, 1, time.UTC)
pubKey := []byte("1234567890abcedfghij1234567890ab")
nc := certificateV1{
details: detailsV1{
name: "testing",
networks: []netip.Prefix{
mustParsePrefixUnmapped("10.1.1.2/16"),
mustParsePrefixUnmapped("10.1.1.1/24"),
},
unsafeNetworks: []netip.Prefix{
mustParsePrefixUnmapped("9.1.1.3/16"),
mustParsePrefixUnmapped("9.1.1.2/24"),
},
groups: []string{"test-group1", "test-group2", "test-group3"},
notBefore: before,
notAfter: after,
publicKey: pubKey,
isCA: false,
issuer: "1234567890abcedfghij1234567890ab",
},
signature: []byte("1234567890abcedfghij1234567890ab"),
}
b, err := nc.Marshal()
require.Nil(t, err)
assert.Equal(t, "0a8e010a0774657374696e671212828284508080fcff0f8182845080feffff0f1a12838284488080fcff0f8282844880feffff0f220b746573742d67726f757031220b746573742d67726f757032220b746573742d67726f75703328cd1c30cdb8ccf0af073a20313233343536373839306162636564666768696a3132333435363738393061624a081234567890abcedf1220313233343536373839306162636564666768696a313233343536373839306162", fmt.Sprintf("%x", b))
b, err = proto.Marshal(nc.getRawDetails())
assert.Nil(t, err)
assert.Equal(t, "0a0774657374696e671212828284508080fcff0f8182845080feffff0f1a12838284488080fcff0f8282844880feffff0f220b746573742d67726f757031220b746573742d67726f757032220b746573742d67726f75703328cd1c30cdb8ccf0af073a20313233343536373839306162636564666768696a3132333435363738393061624a081234567890abcedf", fmt.Sprintf("%x", b))
}
func TestCertificateV1_Copy(t *testing.T) {
ca, _, caKey, _ := NewTestCaCert(Version1, Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, nil)
c, _, _, _ := NewTestCert(Version1, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, nil, nil)
cc := c.Copy()
test.AssertDeepCopyEqual(t, c, cc)
}
func TestUnmarshalCertificateV1(t *testing.T) {
// Test that we don't panic with an invalid certificate (#332)
data := []byte("\x98\x00\x00")
_, err := unmarshalCertificateV1(data, nil)
assert.EqualError(t, err, "encoded Details was nil")
}
func appendByteSlices(b ...[]byte) []byte {
retSlice := []byte{}
for _, v := range b {
retSlice = append(retSlice, v...)
}
return retSlice
}
func mustParsePrefixUnmapped(s string) netip.Prefix {
prefix := netip.MustParsePrefix(s)
return netip.PrefixFrom(prefix.Addr().Unmap(), prefix.Bits())
}

37
cert/cert_v2.asn1 Normal file
View File

@ -0,0 +1,37 @@
Nebula DEFINITIONS AUTOMATIC TAGS ::= BEGIN
Name ::= UTF8String (SIZE (1..253))
Time ::= INTEGER (0..18446744073709551615) -- Seconds since unix epoch, uint64 maximum
Network ::= OCTET STRING (SIZE (5,17)) -- IP addresses are 4 or 16 bytes + 1 byte for the prefix length
Curve ::= ENUMERATED {
curve25519 (0),
p256 (1)
}
-- The maximum size of a certificate must not exceed 65536 bytes
Certificate ::= SEQUENCE {
details OCTET STRING,
curve Curve DEFAULT curve25519,
publicKey OCTET STRING,
-- signature(details + curve + publicKey) using the appropriate method for curve
signature OCTET STRING
}
Details ::= SEQUENCE {
name Name,
-- At least 1 ipv4 or ipv6 address must be present if isCA is false
networks SEQUENCE OF Network OPTIONAL,
unsafeNetworks SEQUENCE OF Network OPTIONAL,
groups SEQUENCE OF Name OPTIONAL,
isCA BOOLEAN DEFAULT false,
notBefore Time,
notAfter Time,
-- issuer is only required if isCA is false, if isCA is true then it must not be present
issuer OCTET STRING OPTIONAL,
...
-- New fields can be added below here
}
END

730
cert/cert_v2.go Normal file
View File

@ -0,0 +1,730 @@
package cert
import (
"bytes"
"crypto/ecdh"
"crypto/ecdsa"
"crypto/ed25519"
"crypto/elliptic"
"crypto/sha256"
"encoding/hex"
"encoding/json"
"encoding/pem"
"fmt"
"net/netip"
"slices"
"time"
"golang.org/x/crypto/cryptobyte"
"golang.org/x/crypto/cryptobyte/asn1"
"golang.org/x/crypto/curve25519"
)
const (
classConstructed = 0x20
classContextSpecific = 0x80
TagCertDetails = 0 | classConstructed | classContextSpecific
TagCertCurve = 1 | classContextSpecific
TagCertPublicKey = 2 | classContextSpecific
TagCertSignature = 3 | classContextSpecific
TagDetailsName = 0 | classContextSpecific
TagDetailsNetworks = 1 | classConstructed | classContextSpecific
TagDetailsUnsafeNetworks = 2 | classConstructed | classContextSpecific
TagDetailsGroups = 3 | classConstructed | classContextSpecific
TagDetailsIsCA = 4 | classContextSpecific
TagDetailsNotBefore = 5 | classContextSpecific
TagDetailsNotAfter = 6 | classContextSpecific
TagDetailsIssuer = 7 | classContextSpecific
)
const (
// MaxCertificateSize is the maximum length a valid certificate can be
MaxCertificateSize = 65536
// MaxNameLength is limited to a maximum realistic DNS domain name to help facilitate DNS systems
MaxNameLength = 253
// MaxNetworkLength is the maximum length a network value can be.
// 16 bytes for an ipv6 address + 1 byte for the prefix length
MaxNetworkLength = 17
)
type certificateV2 struct {
details detailsV2
// RawDetails contains the entire asn.1 DER encoded Details struct
// This is to benefit forwards compatibility in signature checking.
// signature(RawDetails + Curve + PublicKey) == Signature
rawDetails []byte
curve Curve
publicKey []byte
signature []byte
}
type detailsV2 struct {
name string
networks []netip.Prefix // MUST BE SORTED
unsafeNetworks []netip.Prefix // MUST BE SORTED
groups []string
isCA bool
notBefore time.Time
notAfter time.Time
issuer string
}
func (c *certificateV2) Version() Version {
return Version2
}
func (c *certificateV2) Curve() Curve {
return c.curve
}
func (c *certificateV2) Groups() []string {
return c.details.groups
}
func (c *certificateV2) IsCA() bool {
return c.details.isCA
}
func (c *certificateV2) Issuer() string {
return c.details.issuer
}
func (c *certificateV2) Name() string {
return c.details.name
}
func (c *certificateV2) Networks() []netip.Prefix {
return c.details.networks
}
func (c *certificateV2) NotAfter() time.Time {
return c.details.notAfter
}
func (c *certificateV2) NotBefore() time.Time {
return c.details.notBefore
}
func (c *certificateV2) PublicKey() []byte {
return c.publicKey
}
func (c *certificateV2) Signature() []byte {
return c.signature
}
func (c *certificateV2) UnsafeNetworks() []netip.Prefix {
return c.details.unsafeNetworks
}
func (c *certificateV2) Fingerprint() (string, error) {
if len(c.rawDetails) == 0 {
return "", ErrMissingDetails
}
b := make([]byte, len(c.rawDetails)+1+len(c.publicKey)+len(c.signature))
copy(b, c.rawDetails)
b[len(c.rawDetails)] = byte(c.curve)
copy(b[len(c.rawDetails)+1:], c.publicKey)
copy(b[len(c.rawDetails)+1+len(c.publicKey):], c.signature)
sum := sha256.Sum256(b)
return hex.EncodeToString(sum[:]), nil
}
func (c *certificateV2) CheckSignature(key []byte) bool {
if len(c.rawDetails) == 0 {
return false
}
b := make([]byte, len(c.rawDetails)+1+len(c.publicKey))
copy(b, c.rawDetails)
b[len(c.rawDetails)] = byte(c.curve)
copy(b[len(c.rawDetails)+1:], c.publicKey)
switch c.curve {
case Curve_CURVE25519:
return ed25519.Verify(key, b, c.signature)
case Curve_P256:
x, y := elliptic.Unmarshal(elliptic.P256(), key)
pubKey := &ecdsa.PublicKey{Curve: elliptic.P256(), X: x, Y: y}
hashed := sha256.Sum256(b)
return ecdsa.VerifyASN1(pubKey, hashed[:], c.signature)
default:
return false
}
}
func (c *certificateV2) Expired(t time.Time) bool {
return c.details.notBefore.After(t) || c.details.notAfter.Before(t)
}
func (c *certificateV2) VerifyPrivateKey(curve Curve, key []byte) error {
if curve != c.curve {
return ErrPublicPrivateCurveMismatch
}
if c.details.isCA {
switch curve {
case Curve_CURVE25519:
// the call to PublicKey below will panic slice bounds out of range otherwise
if len(key) != ed25519.PrivateKeySize {
return ErrInvalidPrivateKey
}
if !ed25519.PublicKey(c.publicKey).Equal(ed25519.PrivateKey(key).Public()) {
return ErrPublicPrivateKeyMismatch
}
case Curve_P256:
privkey, err := ecdh.P256().NewPrivateKey(key)
if err != nil {
return ErrInvalidPrivateKey
}
pub := privkey.PublicKey().Bytes()
if !bytes.Equal(pub, c.publicKey) {
return ErrPublicPrivateKeyMismatch
}
default:
return fmt.Errorf("invalid curve: %s", curve)
}
return nil
}
var pub []byte
switch curve {
case Curve_CURVE25519:
var err error
pub, err = curve25519.X25519(key, curve25519.Basepoint)
if err != nil {
return ErrInvalidPrivateKey
}
case Curve_P256:
privkey, err := ecdh.P256().NewPrivateKey(key)
if err != nil {
return ErrInvalidPrivateKey
}
pub = privkey.PublicKey().Bytes()
default:
return fmt.Errorf("invalid curve: %s", curve)
}
if !bytes.Equal(pub, c.publicKey) {
return ErrPublicPrivateKeyMismatch
}
return nil
}
func (c *certificateV2) String() string {
mb, err := c.marshalJSON()
if err != nil {
return fmt.Sprintf("<error marshalling certificate: %v>", err)
}
b, err := json.MarshalIndent(mb, "", "\t")
if err != nil {
return fmt.Sprintf("<error marshalling certificate: %v>", err)
}
return string(b)
}
func (c *certificateV2) MarshalForHandshakes() ([]byte, error) {
if c.rawDetails == nil {
return nil, ErrEmptyRawDetails
}
var b cryptobyte.Builder
// Outermost certificate
b.AddASN1(asn1.SEQUENCE, func(b *cryptobyte.Builder) {
// Add the cert details which is already marshalled
b.AddBytes(c.rawDetails)
// Skipping the curve and public key since those come across in a different part of the handshake
// Add the signature
b.AddASN1(TagCertSignature, func(b *cryptobyte.Builder) {
b.AddBytes(c.signature)
})
})
return b.Bytes()
}
func (c *certificateV2) Marshal() ([]byte, error) {
if c.rawDetails == nil {
return nil, ErrEmptyRawDetails
}
var b cryptobyte.Builder
// Outermost certificate
b.AddASN1(asn1.SEQUENCE, func(b *cryptobyte.Builder) {
// Add the cert details which is already marshalled
b.AddBytes(c.rawDetails)
// Add the curve only if its not the default value
if c.curve != Curve_CURVE25519 {
b.AddASN1(TagCertCurve, func(b *cryptobyte.Builder) {
b.AddBytes([]byte{byte(c.curve)})
})
}
// Add the public key if it is not empty
if c.publicKey != nil {
b.AddASN1(TagCertPublicKey, func(b *cryptobyte.Builder) {
b.AddBytes(c.publicKey)
})
}
// Add the signature
b.AddASN1(TagCertSignature, func(b *cryptobyte.Builder) {
b.AddBytes(c.signature)
})
})
return b.Bytes()
}
func (c *certificateV2) MarshalPEM() ([]byte, error) {
b, err := c.Marshal()
if err != nil {
return nil, err
}
return pem.EncodeToMemory(&pem.Block{Type: CertificateV2Banner, Bytes: b}), nil
}
func (c *certificateV2) MarshalJSON() ([]byte, error) {
b, err := c.marshalJSON()
if err != nil {
return nil, err
}
return json.Marshal(b)
}
func (c *certificateV2) marshalJSON() (m, error) {
fp, err := c.Fingerprint()
if err != nil {
return nil, err
}
return m{
"details": m{
"name": c.details.name,
"networks": c.details.networks,
"unsafeNetworks": c.details.unsafeNetworks,
"groups": c.details.groups,
"notBefore": c.details.notBefore,
"notAfter": c.details.notAfter,
"isCa": c.details.isCA,
"issuer": c.details.issuer,
},
"version": Version2,
"publicKey": fmt.Sprintf("%x", c.publicKey),
"curve": c.curve.String(),
"fingerprint": fp,
"signature": fmt.Sprintf("%x", c.Signature()),
}, nil
}
func (c *certificateV2) Copy() Certificate {
nc := &certificateV2{
details: detailsV2{
name: c.details.name,
notBefore: c.details.notBefore,
notAfter: c.details.notAfter,
isCA: c.details.isCA,
issuer: c.details.issuer,
},
curve: c.curve,
publicKey: make([]byte, len(c.publicKey)),
signature: make([]byte, len(c.signature)),
rawDetails: make([]byte, len(c.rawDetails)),
}
if c.details.groups != nil {
nc.details.groups = make([]string, len(c.details.groups))
copy(nc.details.groups, c.details.groups)
}
if c.details.networks != nil {
nc.details.networks = make([]netip.Prefix, len(c.details.networks))
copy(nc.details.networks, c.details.networks)
}
if c.details.unsafeNetworks != nil {
nc.details.unsafeNetworks = make([]netip.Prefix, len(c.details.unsafeNetworks))
copy(nc.details.unsafeNetworks, c.details.unsafeNetworks)
}
copy(nc.rawDetails, c.rawDetails)
copy(nc.signature, c.signature)
copy(nc.publicKey, c.publicKey)
return nc
}
func (c *certificateV2) fromTBSCertificate(t *TBSCertificate) error {
c.details = detailsV2{
name: t.Name,
networks: t.Networks,
unsafeNetworks: t.UnsafeNetworks,
groups: t.Groups,
isCA: t.IsCA,
notBefore: t.NotBefore,
notAfter: t.NotAfter,
issuer: t.issuer,
}
c.curve = t.Curve
c.publicKey = t.PublicKey
return c.validate()
}
func (c *certificateV2) validate() error {
// Empty names are allowed
if len(c.publicKey) == 0 {
return ErrInvalidPublicKey
}
if !c.details.isCA && len(c.details.networks) == 0 {
return NewErrInvalidCertificateProperties("non-CA certificate must contain at least 1 network")
}
hasV4Networks := false
hasV6Networks := false
for _, network := range c.details.networks {
if !network.IsValid() || !network.Addr().IsValid() {
return NewErrInvalidCertificateProperties("invalid network: %s", network)
}
if network.Addr().IsUnspecified() {
return NewErrInvalidCertificateProperties("non-CA certificates must not use the zero address as a network: %s", network)
}
if network.Addr().Zone() != "" {
return NewErrInvalidCertificateProperties("networks may not contain zones: %s", network)
}
if network.Addr().Is4In6() {
return NewErrInvalidCertificateProperties("4in6 networks are not allowed: %s", network)
}
hasV4Networks = hasV4Networks || network.Addr().Is4()
hasV6Networks = hasV6Networks || network.Addr().Is6()
}
slices.SortFunc(c.details.networks, comparePrefix)
err := findDuplicatePrefix(c.details.networks)
if err != nil {
return err
}
for _, network := range c.details.unsafeNetworks {
if !network.IsValid() || !network.Addr().IsValid() {
return NewErrInvalidCertificateProperties("invalid unsafe network: %s", network)
}
if network.Addr().Zone() != "" {
return NewErrInvalidCertificateProperties("unsafe networks may not contain zones: %s", network)
}
if !c.details.isCA {
if network.Addr().Is6() {
if !hasV6Networks {
return NewErrInvalidCertificateProperties("IPv6 unsafe networks require an IPv6 address assignment: %s", network)
}
} else if network.Addr().Is4() {
if !hasV4Networks {
return NewErrInvalidCertificateProperties("IPv4 unsafe networks require an IPv4 address assignment: %s", network)
}
}
}
}
slices.SortFunc(c.details.unsafeNetworks, comparePrefix)
err = findDuplicatePrefix(c.details.unsafeNetworks)
if err != nil {
return err
}
return nil
}
func (c *certificateV2) marshalForSigning() ([]byte, error) {
d, err := c.details.Marshal()
if err != nil {
return nil, fmt.Errorf("marshalling certificate details failed: %w", err)
}
c.rawDetails = d
b := make([]byte, len(c.rawDetails)+1+len(c.publicKey))
copy(b, c.rawDetails)
b[len(c.rawDetails)] = byte(c.curve)
copy(b[len(c.rawDetails)+1:], c.publicKey)
return b, nil
}
func (c *certificateV2) setSignature(b []byte) error {
if len(b) == 0 {
return ErrEmptySignature
}
c.signature = b
return nil
}
func (d *detailsV2) Marshal() ([]byte, error) {
var b cryptobyte.Builder
var err error
// Details are a structure
b.AddASN1(TagCertDetails, func(b *cryptobyte.Builder) {
// Add the name
b.AddASN1(TagDetailsName, func(b *cryptobyte.Builder) {
b.AddBytes([]byte(d.name))
})
// Add the networks if any exist
if len(d.networks) > 0 {
b.AddASN1(TagDetailsNetworks, func(b *cryptobyte.Builder) {
for _, n := range d.networks {
sb, innerErr := n.MarshalBinary()
if innerErr != nil {
// MarshalBinary never returns an error
err = fmt.Errorf("unable to marshal network: %w", innerErr)
return
}
b.AddASN1OctetString(sb)
}
})
}
// Add the unsafe networks if any exist
if len(d.unsafeNetworks) > 0 {
b.AddASN1(TagDetailsUnsafeNetworks, func(b *cryptobyte.Builder) {
for _, n := range d.unsafeNetworks {
sb, innerErr := n.MarshalBinary()
if innerErr != nil {
// MarshalBinary never returns an error
err = fmt.Errorf("unable to marshal unsafe network: %w", innerErr)
return
}
b.AddASN1OctetString(sb)
}
})
}
// Add groups if any exist
if len(d.groups) > 0 {
b.AddASN1(TagDetailsGroups, func(b *cryptobyte.Builder) {
for _, group := range d.groups {
b.AddASN1(asn1.UTF8String, func(b *cryptobyte.Builder) {
b.AddBytes([]byte(group))
})
}
})
}
// Add IsCA only if true
if d.isCA {
b.AddASN1(TagDetailsIsCA, func(b *cryptobyte.Builder) {
b.AddUint8(0xff)
})
}
// Add not before
b.AddASN1Int64WithTag(d.notBefore.Unix(), TagDetailsNotBefore)
// Add not after
b.AddASN1Int64WithTag(d.notAfter.Unix(), TagDetailsNotAfter)
// Add the issuer if present
if d.issuer != "" {
issuerBytes, innerErr := hex.DecodeString(d.issuer)
if innerErr != nil {
err = fmt.Errorf("failed to decode issuer: %w", innerErr)
return
}
b.AddASN1(TagDetailsIssuer, func(b *cryptobyte.Builder) {
b.AddBytes(issuerBytes)
})
}
})
if err != nil {
return nil, err
}
return b.Bytes()
}
func unmarshalCertificateV2(b []byte, publicKey []byte, curve Curve) (*certificateV2, error) {
l := len(b)
if l == 0 || l > MaxCertificateSize {
return nil, ErrBadFormat
}
input := cryptobyte.String(b)
// Open the envelope
if !input.ReadASN1(&input, asn1.SEQUENCE) || input.Empty() {
return nil, ErrBadFormat
}
// Grab the cert details, we need to preserve the tag and length
var rawDetails cryptobyte.String
if !input.ReadASN1Element(&rawDetails, TagCertDetails) || rawDetails.Empty() {
return nil, ErrBadFormat
}
//Maybe grab the curve
var rawCurve byte
if !readOptionalASN1Byte(&input, &rawCurve, TagCertCurve, byte(curve)) {
return nil, ErrBadFormat
}
curve = Curve(rawCurve)
// Maybe grab the public key
var rawPublicKey cryptobyte.String
if len(publicKey) > 0 {
rawPublicKey = publicKey
} else if !input.ReadOptionalASN1(&rawPublicKey, nil, TagCertPublicKey) {
return nil, ErrBadFormat
}
if len(rawPublicKey) == 0 {
return nil, ErrBadFormat
}
// Grab the signature
var rawSignature cryptobyte.String
if !input.ReadASN1(&rawSignature, TagCertSignature) || rawSignature.Empty() {
return nil, ErrBadFormat
}
// Finally unmarshal the details
details, err := unmarshalDetails(rawDetails)
if err != nil {
return nil, err
}
c := &certificateV2{
details: details,
rawDetails: rawDetails,
curve: curve,
publicKey: rawPublicKey,
signature: rawSignature,
}
err = c.validate()
if err != nil {
return nil, err
}
return c, nil
}
func unmarshalDetails(b cryptobyte.String) (detailsV2, error) {
// Open the envelope
if !b.ReadASN1(&b, TagCertDetails) || b.Empty() {
return detailsV2{}, ErrBadFormat
}
// Read the name
var name cryptobyte.String
if !b.ReadASN1(&name, TagDetailsName) || name.Empty() || len(name) > MaxNameLength {
return detailsV2{}, ErrBadFormat
}
// Read the network addresses
var subString cryptobyte.String
var found bool
if !b.ReadOptionalASN1(&subString, &found, TagDetailsNetworks) {
return detailsV2{}, ErrBadFormat
}
var networks []netip.Prefix
var val cryptobyte.String
if found {
for !subString.Empty() {
if !subString.ReadASN1(&val, asn1.OCTET_STRING) || val.Empty() || len(val) > MaxNetworkLength {
return detailsV2{}, ErrBadFormat
}
var n netip.Prefix
if err := n.UnmarshalBinary(val); err != nil {
return detailsV2{}, ErrBadFormat
}
networks = append(networks, n)
}
}
// Read out any unsafe networks
if !b.ReadOptionalASN1(&subString, &found, TagDetailsUnsafeNetworks) {
return detailsV2{}, ErrBadFormat
}
var unsafeNetworks []netip.Prefix
if found {
for !subString.Empty() {
if !subString.ReadASN1(&val, asn1.OCTET_STRING) || val.Empty() || len(val) > MaxNetworkLength {
return detailsV2{}, ErrBadFormat
}
var n netip.Prefix
if err := n.UnmarshalBinary(val); err != nil {
return detailsV2{}, ErrBadFormat
}
unsafeNetworks = append(unsafeNetworks, n)
}
}
// Read out any groups
if !b.ReadOptionalASN1(&subString, &found, TagDetailsGroups) {
return detailsV2{}, ErrBadFormat
}
var groups []string
if found {
for !subString.Empty() {
if !subString.ReadASN1(&val, asn1.UTF8String) || val.Empty() {
return detailsV2{}, ErrBadFormat
}
groups = append(groups, string(val))
}
}
// Read out IsCA
var isCa bool
if !readOptionalASN1Boolean(&b, &isCa, TagDetailsIsCA, false) {
return detailsV2{}, ErrBadFormat
}
// Read not before and not after
var notBefore int64
if !b.ReadASN1Int64WithTag(&notBefore, TagDetailsNotBefore) {
return detailsV2{}, ErrBadFormat
}
var notAfter int64
if !b.ReadASN1Int64WithTag(&notAfter, TagDetailsNotAfter) {
return detailsV2{}, ErrBadFormat
}
// Read issuer
var issuer cryptobyte.String
if !b.ReadOptionalASN1(&issuer, nil, TagDetailsIssuer) {
return detailsV2{}, ErrBadFormat
}
return detailsV2{
name: string(name),
networks: networks,
unsafeNetworks: unsafeNetworks,
groups: groups,
isCA: isCa,
notBefore: time.Unix(notBefore, 0),
notAfter: time.Unix(notAfter, 0),
issuer: hex.EncodeToString(issuer),
}, nil
}

267
cert/cert_v2_test.go Normal file
View File

@ -0,0 +1,267 @@
package cert
import (
"crypto/ed25519"
"crypto/rand"
"encoding/hex"
"net/netip"
"slices"
"testing"
"time"
"github.com/slackhq/nebula/test"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestCertificateV2_Marshal(t *testing.T) {
before := time.Now().Add(time.Second * -60).Round(time.Second)
after := time.Now().Add(time.Second * 60).Round(time.Second)
pubKey := []byte("1234567890abcedfghij1234567890ab")
nc := certificateV2{
details: detailsV2{
name: "testing",
networks: []netip.Prefix{
mustParsePrefixUnmapped("10.1.1.2/16"),
mustParsePrefixUnmapped("10.1.1.1/24"),
},
unsafeNetworks: []netip.Prefix{
mustParsePrefixUnmapped("9.1.1.3/16"),
mustParsePrefixUnmapped("9.1.1.2/24"),
},
groups: []string{"test-group1", "test-group2", "test-group3"},
notBefore: before,
notAfter: after,
isCA: false,
issuer: "1234567890abcdef1234567890abcdef",
},
signature: []byte("1234567890abcdef1234567890abcdef"),
publicKey: pubKey,
}
db, err := nc.details.Marshal()
require.NoError(t, err)
nc.rawDetails = db
b, err := nc.Marshal()
require.Nil(t, err)
//t.Log("Cert size:", len(b))
nc2, err := unmarshalCertificateV2(b, nil, Curve_CURVE25519)
assert.Nil(t, err)
assert.Equal(t, nc.Version(), Version2)
assert.Equal(t, nc.Curve(), Curve_CURVE25519)
assert.Equal(t, nc.Signature(), nc2.Signature())
assert.Equal(t, nc.Name(), nc2.Name())
assert.Equal(t, nc.NotBefore(), nc2.NotBefore())
assert.Equal(t, nc.NotAfter(), nc2.NotAfter())
assert.Equal(t, nc.PublicKey(), nc2.PublicKey())
assert.Equal(t, nc.IsCA(), nc2.IsCA())
assert.Equal(t, nc.Issuer(), nc2.Issuer())
// unmarshalling will sort networks and unsafeNetworks, we need to do the same
// but first make sure it fails
assert.NotEqual(t, nc.Networks(), nc2.Networks())
assert.NotEqual(t, nc.UnsafeNetworks(), nc2.UnsafeNetworks())
slices.SortFunc(nc.details.networks, comparePrefix)
slices.SortFunc(nc.details.unsafeNetworks, comparePrefix)
assert.Equal(t, nc.Networks(), nc2.Networks())
assert.Equal(t, nc.UnsafeNetworks(), nc2.UnsafeNetworks())
assert.Equal(t, nc.Groups(), nc2.Groups())
}
func TestCertificateV2_Expired(t *testing.T) {
nc := certificateV2{
details: detailsV2{
notBefore: time.Now().Add(time.Second * -60).Round(time.Second),
notAfter: time.Now().Add(time.Second * 60).Round(time.Second),
},
}
assert.True(t, nc.Expired(time.Now().Add(time.Hour)))
assert.True(t, nc.Expired(time.Now().Add(-time.Hour)))
assert.False(t, nc.Expired(time.Now()))
}
func TestCertificateV2_MarshalJSON(t *testing.T) {
time.Local = time.UTC
pubKey := []byte("1234567890abcedf1234567890abcedf")
nc := certificateV2{
details: detailsV2{
name: "testing",
networks: []netip.Prefix{
mustParsePrefixUnmapped("10.1.1.1/24"),
mustParsePrefixUnmapped("10.1.1.2/16"),
},
unsafeNetworks: []netip.Prefix{
mustParsePrefixUnmapped("9.1.1.2/24"),
mustParsePrefixUnmapped("9.1.1.3/16"),
},
groups: []string{"test-group1", "test-group2", "test-group3"},
notBefore: time.Date(1, 0, 0, 1, 0, 0, 0, time.UTC),
notAfter: time.Date(1, 0, 0, 2, 0, 0, 0, time.UTC),
isCA: false,
issuer: "1234567890abcedf1234567890abcedf",
},
publicKey: pubKey,
signature: []byte("1234567890abcedf1234567890abcedf1234567890abcedf1234567890abcedf"),
}
b, err := nc.MarshalJSON()
assert.ErrorIs(t, err, ErrMissingDetails)
rd, err := nc.details.Marshal()
assert.NoError(t, err)
nc.rawDetails = rd
b, err = nc.MarshalJSON()
assert.Nil(t, err)
assert.Equal(
t,
"{\"curve\":\"CURVE25519\",\"details\":{\"groups\":[\"test-group1\",\"test-group2\",\"test-group3\"],\"isCa\":false,\"issuer\":\"1234567890abcedf1234567890abcedf\",\"name\":\"testing\",\"networks\":[\"10.1.1.1/24\",\"10.1.1.2/16\"],\"notAfter\":\"0000-11-30T02:00:00Z\",\"notBefore\":\"0000-11-30T01:00:00Z\",\"unsafeNetworks\":[\"9.1.1.2/24\",\"9.1.1.3/16\"]},\"fingerprint\":\"152d9a7400c1e001cb76cffd035215ebb351f69eeb797f7f847dd086e15e56dd\",\"publicKey\":\"3132333435363738393061626365646631323334353637383930616263656466\",\"signature\":\"31323334353637383930616263656466313233343536373839306162636564663132333435363738393061626365646631323334353637383930616263656466\",\"version\":2}",
string(b),
)
}
func TestCertificateV2_VerifyPrivateKey(t *testing.T) {
ca, _, caKey, _ := NewTestCaCert(Version2, Curve_CURVE25519, time.Time{}, time.Time{}, nil, nil, nil)
err := ca.VerifyPrivateKey(Curve_CURVE25519, caKey)
assert.Nil(t, err)
err = ca.VerifyPrivateKey(Curve_CURVE25519, caKey[:16])
assert.ErrorIs(t, err, ErrInvalidPrivateKey)
_, caKey2, err := ed25519.GenerateKey(rand.Reader)
require.Nil(t, err)
err = ca.VerifyPrivateKey(Curve_CURVE25519, caKey2)
assert.ErrorIs(t, err, ErrPublicPrivateKeyMismatch)
c, _, priv, _ := NewTestCert(Version2, Curve_CURVE25519, ca, caKey, "test", time.Time{}, time.Time{}, nil, nil, nil)
rawPriv, b, curve, err := UnmarshalPrivateKeyFromPEM(priv)
assert.NoError(t, err)
assert.Empty(t, b)
assert.Equal(t, Curve_CURVE25519, curve)
err = c.VerifyPrivateKey(Curve_CURVE25519, rawPriv)
assert.Nil(t, err)
_, priv2 := X25519Keypair()
err = c.VerifyPrivateKey(Curve_P256, priv2)
assert.ErrorIs(t, err, ErrPublicPrivateCurveMismatch)
err = c.VerifyPrivateKey(Curve_CURVE25519, priv2)
assert.ErrorIs(t, err, ErrPublicPrivateKeyMismatch)
err = c.VerifyPrivateKey(Curve_CURVE25519, priv2[:16])
assert.ErrorIs(t, err, ErrInvalidPrivateKey)
ac, ok := c.(*certificateV2)
require.True(t, ok)
ac.curve = Curve(99)
err = c.VerifyPrivateKey(Curve(99), priv2)
assert.EqualError(t, err, "invalid curve: 99")
ca2, _, caKey2, _ := NewTestCaCert(Version2, Curve_P256, time.Time{}, time.Time{}, nil, nil, nil)
err = ca.VerifyPrivateKey(Curve_CURVE25519, caKey)
assert.Nil(t, err)
err = ca2.VerifyPrivateKey(Curve_P256, caKey2[:16])
assert.ErrorIs(t, err, ErrInvalidPrivateKey)
c, _, priv, _ = NewTestCert(Version2, Curve_P256, ca2, caKey2, "test", time.Time{}, time.Time{}, nil, nil, nil)
rawPriv, b, curve, err = UnmarshalPrivateKeyFromPEM(priv)
err = c.VerifyPrivateKey(Curve_P256, priv[:16])
assert.ErrorIs(t, err, ErrInvalidPrivateKey)
err = c.VerifyPrivateKey(Curve_P256, priv)
assert.ErrorIs(t, err, ErrInvalidPrivateKey)
aCa, ok := ca2.(*certificateV2)
require.True(t, ok)
aCa.curve = Curve(99)
err = aCa.VerifyPrivateKey(Curve(99), priv2)
assert.EqualError(t, err, "invalid curve: 99")
}
func TestCertificateV2_VerifyPrivateKeyP256(t *testing.T) {
ca, _, caKey, _ := NewTestCaCert(Version2, Curve_P256, time.Time{}, time.Time{}, nil, nil, nil)
err := ca.VerifyPrivateKey(Curve_P256, caKey)
assert.Nil(t, err)
_, _, caKey2, _ := NewTestCaCert(Version2, Curve_P256, time.Time{}, time.Time{}, nil, nil, nil)
assert.Nil(t, err)
err = ca.VerifyPrivateKey(Curve_P256, caKey2)
assert.NotNil(t, err)
c, _, priv, _ := NewTestCert(Version2, Curve_P256, ca, caKey, "test", time.Time{}, time.Time{}, nil, nil, nil)
rawPriv, b, curve, err := UnmarshalPrivateKeyFromPEM(priv)
assert.NoError(t, err)
assert.Empty(t, b)
assert.Equal(t, Curve_P256, curve)
err = c.VerifyPrivateKey(Curve_P256, rawPriv)
assert.Nil(t, err)
_, priv2 := P256Keypair()
err = c.VerifyPrivateKey(Curve_P256, priv2)
assert.NotNil(t, err)
}
func TestCertificateV2_Copy(t *testing.T) {
ca, _, caKey, _ := NewTestCaCert(Version2, Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, nil)
c, _, _, _ := NewTestCert(Version2, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, nil, nil)
cc := c.Copy()
test.AssertDeepCopyEqual(t, c, cc)
}
func TestUnmarshalCertificateV2(t *testing.T) {
data := []byte("\x98\x00\x00")
_, err := unmarshalCertificateV2(data, nil, Curve_CURVE25519)
assert.EqualError(t, err, "bad wire format")
}
func TestCertificateV2_marshalForSigningStability(t *testing.T) {
before := time.Date(1996, time.May, 5, 0, 0, 0, 0, time.UTC)
after := before.Add(time.Second * 60).Round(time.Second)
pubKey := []byte("1234567890abcedfghij1234567890ab")
nc := certificateV2{
details: detailsV2{
name: "testing",
networks: []netip.Prefix{
mustParsePrefixUnmapped("10.1.1.2/16"),
mustParsePrefixUnmapped("10.1.1.1/24"),
},
unsafeNetworks: []netip.Prefix{
mustParsePrefixUnmapped("9.1.1.3/16"),
mustParsePrefixUnmapped("9.1.1.2/24"),
},
groups: []string{"test-group1", "test-group2", "test-group3"},
notBefore: before,
notAfter: after,
isCA: false,
issuer: "1234567890abcdef1234567890abcdef",
},
signature: []byte("1234567890abcdef1234567890abcdef"),
publicKey: pubKey,
}
const expectedRawDetailsStr = "a070800774657374696e67a10e04050a0101021004050a01010118a20e0405090101031004050901010218a3270c0b746573742d67726f7570310c0b746573742d67726f7570320c0b746573742d67726f7570338504318bef808604318befbc87101234567890abcdef1234567890abcdef"
expectedRawDetails, err := hex.DecodeString(expectedRawDetailsStr)
require.NoError(t, err)
db, err := nc.details.Marshal()
require.NoError(t, err)
assert.Equal(t, expectedRawDetails, db)
expectedForSigning, err := hex.DecodeString(expectedRawDetailsStr + "00313233343536373839306162636564666768696a313233343536373839306162")
b, err := nc.marshalForSigning()
require.NoError(t, err)
assert.Equal(t, expectedForSigning, b)
}

View File

@ -2,21 +2,24 @@ package cert
import ( import (
"errors" "errors"
"fmt"
) )
var ( var (
ErrBadFormat = errors.New("bad wire format") ErrBadFormat = errors.New("bad wire format")
ErrRootExpired = errors.New("root certificate is expired") ErrRootExpired = errors.New("root certificate is expired")
ErrExpired = errors.New("certificate is expired") ErrExpired = errors.New("certificate is expired")
ErrNotCA = errors.New("certificate is not a CA") ErrNotCA = errors.New("certificate is not a CA")
ErrNotSelfSigned = errors.New("certificate is not self-signed") ErrNotSelfSigned = errors.New("certificate is not self-signed")
ErrBlockListed = errors.New("certificate is in the block list") ErrBlockListed = errors.New("certificate is in the block list")
ErrFingerprintMismatch = errors.New("certificate fingerprint did not match") ErrFingerprintMismatch = errors.New("certificate fingerprint did not match")
ErrSignatureMismatch = errors.New("certificate signature did not match") ErrSignatureMismatch = errors.New("certificate signature did not match")
ErrInvalidPublicKeyLength = errors.New("invalid public key length") ErrInvalidPublicKey = errors.New("invalid public key")
ErrInvalidPrivateKeyLength = errors.New("invalid private key length") ErrInvalidPrivateKey = errors.New("invalid private key")
ErrPublicPrivateCurveMismatch = errors.New("public key does not match private key curve")
ErrPrivateKeyEncrypted = errors.New("private key must be decrypted") ErrPublicPrivateKeyMismatch = errors.New("public key and private key are not a pair")
ErrPrivateKeyEncrypted = errors.New("private key must be decrypted")
ErrCaNotFound = errors.New("could not find ca for the certificate")
ErrInvalidPEMBlock = errors.New("input did not contain a valid PEM encoded block") ErrInvalidPEMBlock = errors.New("input did not contain a valid PEM encoded block")
ErrInvalidPEMCertificateBanner = errors.New("bytes did not contain a proper certificate banner") ErrInvalidPEMCertificateBanner = errors.New("bytes did not contain a proper certificate banner")
@ -24,4 +27,23 @@ var (
ErrInvalidPEMX25519PrivateKeyBanner = errors.New("bytes did not contain a proper X25519 private key banner") ErrInvalidPEMX25519PrivateKeyBanner = errors.New("bytes did not contain a proper X25519 private key banner")
ErrInvalidPEMEd25519PublicKeyBanner = errors.New("bytes did not contain a proper Ed25519 public key banner") ErrInvalidPEMEd25519PublicKeyBanner = errors.New("bytes did not contain a proper Ed25519 public key banner")
ErrInvalidPEMEd25519PrivateKeyBanner = errors.New("bytes did not contain a proper Ed25519 private key banner") ErrInvalidPEMEd25519PrivateKeyBanner = errors.New("bytes did not contain a proper Ed25519 private key banner")
ErrNoPeerStaticKey = errors.New("no peer static key was present")
ErrNoPayload = errors.New("provided payload was empty")
ErrMissingDetails = errors.New("certificate did not contain details")
ErrEmptySignature = errors.New("empty signature")
ErrEmptyRawDetails = errors.New("empty rawDetails not allowed")
) )
type ErrInvalidCertificateProperties struct {
str string
}
func NewErrInvalidCertificateProperties(format string, a ...any) error {
return &ErrInvalidCertificateProperties{fmt.Sprintf(format, a...)}
}
func (e *ErrInvalidCertificateProperties) Error() string {
return e.str
}

141
cert/helper_test.go Normal file
View File

@ -0,0 +1,141 @@
package cert
import (
"crypto/ecdh"
"crypto/ecdsa"
"crypto/elliptic"
"crypto/rand"
"io"
"net/netip"
"time"
"golang.org/x/crypto/curve25519"
"golang.org/x/crypto/ed25519"
)
// NewTestCaCert will create a new ca certificate
func NewTestCaCert(version Version, curve Curve, before, after time.Time, networks, unsafeNetworks []netip.Prefix, groups []string) (Certificate, []byte, []byte, []byte) {
var err error
var pub, priv []byte
switch curve {
case Curve_CURVE25519:
pub, priv, err = ed25519.GenerateKey(rand.Reader)
case Curve_P256:
privk, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
if err != nil {
panic(err)
}
pub = elliptic.Marshal(elliptic.P256(), privk.PublicKey.X, privk.PublicKey.Y)
priv = privk.D.FillBytes(make([]byte, 32))
default:
// There is no default to allow the underlying lib to respond with an error
}
if before.IsZero() {
before = time.Now().Add(time.Second * -60).Round(time.Second)
}
if after.IsZero() {
after = time.Now().Add(time.Second * 60).Round(time.Second)
}
t := &TBSCertificate{
Curve: curve,
Version: version,
Name: "test ca",
NotBefore: time.Unix(before.Unix(), 0),
NotAfter: time.Unix(after.Unix(), 0),
PublicKey: pub,
Networks: networks,
UnsafeNetworks: unsafeNetworks,
Groups: groups,
IsCA: true,
}
c, err := t.Sign(nil, curve, priv)
if err != nil {
panic(err)
}
pem, err := c.MarshalPEM()
if err != nil {
panic(err)
}
return c, pub, priv, pem
}
// NewTestCert will generate a signed certificate with the provided details.
// Expiry times are defaulted if you do not pass them in
func NewTestCert(v Version, curve Curve, ca Certificate, key []byte, name string, before, after time.Time, networks, unsafeNetworks []netip.Prefix, groups []string) (Certificate, []byte, []byte, []byte) {
if before.IsZero() {
before = time.Now().Add(time.Second * -60).Round(time.Second)
}
if after.IsZero() {
after = time.Now().Add(time.Second * 60).Round(time.Second)
}
if len(networks) == 0 {
networks = []netip.Prefix{netip.MustParsePrefix("10.0.0.123/8")}
}
var pub, priv []byte
switch curve {
case Curve_CURVE25519:
pub, priv = X25519Keypair()
case Curve_P256:
pub, priv = P256Keypair()
default:
panic("unknown curve")
}
nc := &TBSCertificate{
Version: v,
Curve: curve,
Name: name,
Networks: networks,
UnsafeNetworks: unsafeNetworks,
Groups: groups,
NotBefore: time.Unix(before.Unix(), 0),
NotAfter: time.Unix(after.Unix(), 0),
PublicKey: pub,
IsCA: false,
}
c, err := nc.Sign(ca, ca.Curve(), key)
if err != nil {
panic(err)
}
pem, err := c.MarshalPEM()
if err != nil {
panic(err)
}
return c, pub, MarshalPrivateKeyToPEM(curve, priv), pem
}
func X25519Keypair() ([]byte, []byte) {
privkey := make([]byte, 32)
if _, err := io.ReadFull(rand.Reader, privkey); err != nil {
panic(err)
}
pubkey, err := curve25519.X25519(privkey, curve25519.Basepoint)
if err != nil {
panic(err)
}
return pubkey, privkey
}
func P256Keypair() ([]byte, []byte) {
privkey, err := ecdh.P256().GenerateKey(rand.Reader)
if err != nil {
panic(err)
}
pubkey := privkey.PublicKey()
return pubkey.Bytes(), privkey.Bytes()
}

View File

@ -30,19 +30,25 @@ func UnmarshalCertificateFromPEM(b []byte) (Certificate, []byte, error) {
return nil, r, ErrInvalidPEMBlock return nil, r, ErrInvalidPEMBlock
} }
var c Certificate
var err error
switch p.Type { switch p.Type {
// Implementations must validate the resulting certificate contains valid information
case CertificateBanner: case CertificateBanner:
c, err := unmarshalCertificateV1(p.Bytes, true) c, err = unmarshalCertificateV1(p.Bytes, nil)
if err != nil {
return nil, nil, err
}
return c, r, nil
case CertificateV2Banner: case CertificateV2Banner:
//TODO c, err = unmarshalCertificateV2(p.Bytes, nil, Curve_CURVE25519)
panic("TODO")
default: default:
return nil, r, ErrInvalidPEMCertificateBanner return nil, r, ErrInvalidPEMCertificateBanner
} }
if err != nil {
return nil, r, err
}
return c, r, nil
} }
func MarshalPublicKeyToPEM(curve Curve, b []byte) []byte { func MarshalPublicKeyToPEM(curve Curve, b []byte) []byte {

View File

@ -1,11 +1,15 @@
package cert package cert
import ( import (
"crypto/ecdsa"
"crypto/ed25519"
"crypto/elliptic"
"crypto/rand"
"crypto/sha256"
"fmt" "fmt"
"math/big"
"net/netip" "net/netip"
"time" "time"
"github.com/slackhq/nebula/pkclient"
) )
// TBSCertificate represents a certificate intended to be signed. // TBSCertificate represents a certificate intended to be signed.
@ -24,28 +28,61 @@ type TBSCertificate struct {
issuer string issuer string
} }
type beingSignedCertificate interface {
// fromTBSCertificate copies the values from the TBSCertificate to this versions internal representation
// Implementations must validate the resulting certificate contains valid information
fromTBSCertificate(*TBSCertificate) error
// marshalForSigning returns the bytes that should be signed
marshalForSigning() ([]byte, error)
// setSignature sets the signature for the certificate that has just been signed. The signature must not be blank.
setSignature([]byte) error
}
type SignerLambda func(certBytes []byte) ([]byte, error)
// Sign will create a sealed certificate using details provided by the TBSCertificate as long as those // Sign will create a sealed certificate using details provided by the TBSCertificate as long as those
// details do not violate constraints of the signing certificate. // details do not violate constraints of the signing certificate.
// If the TBSCertificate is a CA then signer must be nil. // If the TBSCertificate is a CA then signer must be nil.
func (t *TBSCertificate) Sign(signer Certificate, curve Curve, key []byte) (Certificate, error) { func (t *TBSCertificate) Sign(signer Certificate, curve Curve, key []byte) (Certificate, error) {
return t.sign(signer, curve, key, nil) switch t.Curve {
} case Curve_CURVE25519:
pk := ed25519.PrivateKey(key)
func (t *TBSCertificate) SignPkcs11(signer Certificate, curve Curve, client *pkclient.PKClient) (Certificate, error) { sp := func(certBytes []byte) ([]byte, error) {
if curve != Curve_P256 { sig := ed25519.Sign(pk, certBytes)
return nil, fmt.Errorf("only P256 is supported by PKCS#11") return sig, nil
}
return t.SignWith(signer, curve, sp)
case Curve_P256:
pk := &ecdsa.PrivateKey{
PublicKey: ecdsa.PublicKey{
Curve: elliptic.P256(),
},
// ref: https://github.com/golang/go/blob/go1.19/src/crypto/x509/sec1.go#L95
D: new(big.Int).SetBytes(key),
}
// ref: https://github.com/golang/go/blob/go1.19/src/crypto/x509/sec1.go#L119
pk.X, pk.Y = pk.Curve.ScalarBaseMult(key)
sp := func(certBytes []byte) ([]byte, error) {
// We need to hash first for ECDSA
// - https://pkg.go.dev/crypto/ecdsa#SignASN1
hashed := sha256.Sum256(certBytes)
return ecdsa.SignASN1(rand.Reader, pk, hashed[:])
}
return t.SignWith(signer, curve, sp)
default:
return nil, fmt.Errorf("invalid curve: %s", t.Curve)
} }
return t.sign(signer, curve, nil, client)
} }
func (t *TBSCertificate) sign(signer Certificate, curve Curve, key []byte, client *pkclient.PKClient) (Certificate, error) { // SignWith does the same thing as sign, but uses the function in `sp` to calculate the signature.
// You should only use SignWith if you do not have direct access to your private key.
func (t *TBSCertificate) SignWith(signer Certificate, curve Curve, sp SignerLambda) (Certificate, error) {
if curve != t.Curve { if curve != t.Curve {
return nil, fmt.Errorf("curve in cert and private key supplied don't match") return nil, fmt.Errorf("curve in cert and private key supplied don't match")
} }
//TODO: make sure we have all minimum properties to sign, like a public key
if signer != nil { if signer != nil {
if t.IsCA { if t.IsCA {
return nil, fmt.Errorf("can not sign a CA certificate with another") return nil, fmt.Errorf("can not sign a CA certificate with another")
@ -67,10 +104,64 @@ func (t *TBSCertificate) sign(signer Certificate, curve Curve, key []byte, clien
} }
} }
var c beingSignedCertificate
switch t.Version { switch t.Version {
case Version1: case Version1:
return signV1(t, curve, key, client) c = &certificateV1{}
err := c.fromTBSCertificate(t)
if err != nil {
return nil, err
}
case Version2:
c = &certificateV2{}
err := c.fromTBSCertificate(t)
if err != nil {
return nil, err
}
default: default:
return nil, fmt.Errorf("unknown cert version %d", t.Version) return nil, fmt.Errorf("unknown cert version %d", t.Version)
} }
certBytes, err := c.marshalForSigning()
if err != nil {
return nil, err
}
sig, err := sp(certBytes)
if err != nil {
return nil, err
}
err = c.setSignature(sig)
if err != nil {
return nil, err
}
sc, ok := c.(Certificate)
if !ok {
return nil, fmt.Errorf("invalid certificate")
}
return sc, nil
}
func comparePrefix(a, b netip.Prefix) int {
addr := a.Addr().Compare(b.Addr())
if addr == 0 {
return a.Bits() - b.Bits()
}
return addr
}
// findDuplicatePrefix returns an error if there is a duplicate prefix in the pre-sorted input slice sortedPrefixes
func findDuplicatePrefix(sortedPrefixes []netip.Prefix) error {
if len(sortedPrefixes) < 2 {
return nil
}
for i := 1; i < len(sortedPrefixes); i++ {
if comparePrefix(sortedPrefixes[i], sortedPrefixes[i-1]) == 0 {
return NewErrInvalidCertificateProperties("duplicate network detected: %v", sortedPrefixes[i])
}
}
return nil
} }

90
cert/sign_test.go Normal file
View File

@ -0,0 +1,90 @@
package cert
import (
"crypto/ecdsa"
"crypto/ed25519"
"crypto/elliptic"
"crypto/rand"
"net/netip"
"testing"
"time"
"github.com/stretchr/testify/assert"
)
func TestCertificateV1_Sign(t *testing.T) {
before := time.Now().Add(time.Second * -60).Round(time.Second)
after := time.Now().Add(time.Second * 60).Round(time.Second)
pubKey := []byte("1234567890abcedfghij1234567890ab")
tbs := TBSCertificate{
Version: Version1,
Name: "testing",
Networks: []netip.Prefix{
mustParsePrefixUnmapped("10.1.1.1/24"),
mustParsePrefixUnmapped("10.1.1.2/16"),
},
UnsafeNetworks: []netip.Prefix{
mustParsePrefixUnmapped("9.1.1.2/24"),
mustParsePrefixUnmapped("9.1.1.3/24"),
},
Groups: []string{"test-group1", "test-group2", "test-group3"},
NotBefore: before,
NotAfter: after,
PublicKey: pubKey,
IsCA: false,
}
pub, priv, err := ed25519.GenerateKey(rand.Reader)
c, err := tbs.Sign(&certificateV1{details: detailsV1{notBefore: before, notAfter: after}}, Curve_CURVE25519, priv)
assert.Nil(t, err)
assert.NotNil(t, c)
assert.True(t, c.CheckSignature(pub))
b, err := c.Marshal()
assert.Nil(t, err)
uc, err := unmarshalCertificateV1(b, nil)
assert.Nil(t, err)
assert.NotNil(t, uc)
}
func TestCertificateV1_SignP256(t *testing.T) {
before := time.Now().Add(time.Second * -60).Round(time.Second)
after := time.Now().Add(time.Second * 60).Round(time.Second)
pubKey := []byte("01234567890abcedfghij1234567890ab1234567890abcedfghij1234567890ab")
tbs := TBSCertificate{
Version: Version1,
Name: "testing",
Networks: []netip.Prefix{
mustParsePrefixUnmapped("10.1.1.1/24"),
mustParsePrefixUnmapped("10.1.1.2/16"),
},
UnsafeNetworks: []netip.Prefix{
mustParsePrefixUnmapped("9.1.1.2/24"),
mustParsePrefixUnmapped("9.1.1.3/16"),
},
Groups: []string{"test-group1", "test-group2", "test-group3"},
NotBefore: before,
NotAfter: after,
PublicKey: pubKey,
IsCA: false,
Curve: Curve_P256,
}
priv, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
assert.NoError(t, err)
pub := elliptic.Marshal(elliptic.P256(), priv.PublicKey.X, priv.PublicKey.Y)
rawPriv := priv.D.FillBytes(make([]byte, 32))
c, err := tbs.Sign(&certificateV1{details: detailsV1{notBefore: before, notAfter: after}}, Curve_P256, rawPriv)
assert.Nil(t, err)
assert.NotNil(t, c)
assert.True(t, c.CheckSignature(pub))
b, err := c.Marshal()
assert.Nil(t, err)
uc, err := unmarshalCertificateV1(b, nil)
assert.Nil(t, err)
assert.NotNil(t, uc)
}

View File

@ -1,6 +1,9 @@
package e2e package cert_test
import ( import (
"crypto/ecdh"
"crypto/ecdsa"
"crypto/elliptic"
"crypto/rand" "crypto/rand"
"io" "io"
"net/netip" "net/netip"
@ -11,9 +14,26 @@ import (
"golang.org/x/crypto/ed25519" "golang.org/x/crypto/ed25519"
) )
// NewTestCaCert will generate a CA cert // NewTestCaCert will create a new ca certificate
func NewTestCaCert(before, after time.Time, networks, unsafeNetworks []netip.Prefix, groups []string) (cert.Certificate, []byte, []byte, []byte) { func NewTestCaCert(version cert.Version, curve cert.Curve, before, after time.Time, networks, unsafeNetworks []netip.Prefix, groups []string) (cert.Certificate, []byte, []byte, []byte) {
pub, priv, err := ed25519.GenerateKey(rand.Reader) var err error
var pub, priv []byte
switch curve {
case cert.Curve_CURVE25519:
pub, priv, err = ed25519.GenerateKey(rand.Reader)
case cert.Curve_P256:
privk, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
if err != nil {
panic(err)
}
pub = elliptic.Marshal(elliptic.P256(), privk.PublicKey.X, privk.PublicKey.Y)
priv = privk.D.FillBytes(make([]byte, 32))
default:
// There is no default to allow the underlying lib to respond with an error
}
if before.IsZero() { if before.IsZero() {
before = time.Now().Add(time.Second * -60).Round(time.Second) before = time.Now().Add(time.Second * -60).Round(time.Second)
} }
@ -22,7 +42,8 @@ func NewTestCaCert(before, after time.Time, networks, unsafeNetworks []netip.Pre
} }
t := &cert.TBSCertificate{ t := &cert.TBSCertificate{
Version: cert.Version1, Curve: curve,
Version: version,
Name: "test ca", Name: "test ca",
NotBefore: time.Unix(before.Unix(), 0), NotBefore: time.Unix(before.Unix(), 0),
NotAfter: time.Unix(after.Unix(), 0), NotAfter: time.Unix(after.Unix(), 0),
@ -33,7 +54,7 @@ func NewTestCaCert(before, after time.Time, networks, unsafeNetworks []netip.Pre
IsCA: true, IsCA: true,
} }
c, err := t.Sign(nil, cert.Curve_CURVE25519, priv) c, err := t.Sign(nil, curve, priv)
if err != nil { if err != nil {
panic(err) panic(err)
} }
@ -48,7 +69,7 @@ func NewTestCaCert(before, after time.Time, networks, unsafeNetworks []netip.Pre
// NewTestCert will generate a signed certificate with the provided details. // NewTestCert will generate a signed certificate with the provided details.
// Expiry times are defaulted if you do not pass them in // Expiry times are defaulted if you do not pass them in
func NewTestCert(ca cert.Certificate, key []byte, name string, before, after time.Time, networks, unsafeNetworks []netip.Prefix, groups []string) (cert.Certificate, []byte, []byte, []byte) { func NewTestCert(v cert.Version, curve cert.Curve, ca cert.Certificate, key []byte, name string, before, after time.Time, networks, unsafeNetworks []netip.Prefix, groups []string) (cert.Certificate, []byte, []byte, []byte) {
if before.IsZero() { if before.IsZero() {
before = time.Now().Add(time.Second * -60).Round(time.Second) before = time.Now().Add(time.Second * -60).Round(time.Second)
} }
@ -57,9 +78,19 @@ func NewTestCert(ca cert.Certificate, key []byte, name string, before, after tim
after = time.Now().Add(time.Second * 60).Round(time.Second) after = time.Now().Add(time.Second * 60).Round(time.Second)
} }
pub, rawPriv := x25519Keypair() var pub, priv []byte
switch curve {
case cert.Curve_CURVE25519:
pub, priv = X25519Keypair()
case cert.Curve_P256:
pub, priv = P256Keypair()
default:
panic("unknown curve")
}
nc := &cert.TBSCertificate{ nc := &cert.TBSCertificate{
Version: cert.Version1, Version: v,
Curve: curve,
Name: name, Name: name,
Networks: networks, Networks: networks,
UnsafeNetworks: unsafeNetworks, UnsafeNetworks: unsafeNetworks,
@ -80,10 +111,10 @@ func NewTestCert(ca cert.Certificate, key []byte, name string, before, after tim
panic(err) panic(err)
} }
return c, pub, cert.MarshalPrivateKeyToPEM(cert.Curve_CURVE25519, rawPriv), pem return c, pub, cert.MarshalPrivateKeyToPEM(curve, priv), pem
} }
func x25519Keypair() ([]byte, []byte) { func X25519Keypair() ([]byte, []byte) {
privkey := make([]byte, 32) privkey := make([]byte, 32)
if _, err := io.ReadFull(rand.Reader, privkey); err != nil { if _, err := io.ReadFull(rand.Reader, privkey); err != nil {
panic(err) panic(err)
@ -96,3 +127,12 @@ func x25519Keypair() ([]byte, []byte) {
return pubkey, privkey return pubkey, privkey
} }
func P256Keypair() ([]byte, []byte) {
privkey, err := ecdh.P256().GenerateKey(rand.Reader)
if err != nil {
panic(err)
}
pubkey := privkey.PublicKey()
return pubkey.Bytes(), privkey.Bytes()
}

View File

@ -27,34 +27,43 @@ type caFlags struct {
outCertPath *string outCertPath *string
outQRPath *string outQRPath *string
groups *string groups *string
ips *string networks *string
subnets *string unsafeNetworks *string
argonMemory *uint argonMemory *uint
argonIterations *uint argonIterations *uint
argonParallelism *uint argonParallelism *uint
encryption *bool encryption *bool
version *uint
curve *string curve *string
p11url *string p11url *string
// Deprecated options
ips *string
subnets *string
} }
func newCaFlags() *caFlags { func newCaFlags() *caFlags {
cf := caFlags{set: flag.NewFlagSet("ca", flag.ContinueOnError)} cf := caFlags{set: flag.NewFlagSet("ca", flag.ContinueOnError)}
cf.set.Usage = func() {} cf.set.Usage = func() {}
cf.name = cf.set.String("name", "", "Required: name of the certificate authority") cf.name = cf.set.String("name", "", "Required: name of the certificate authority")
cf.version = cf.set.Uint("version", uint(cert.Version2), "Optional: version of the certificate format to use")
cf.duration = cf.set.Duration("duration", time.Duration(time.Hour*8760), "Optional: amount of time the certificate should be valid for. Valid time units are seconds: \"s\", minutes: \"m\", hours: \"h\"") cf.duration = cf.set.Duration("duration", time.Duration(time.Hour*8760), "Optional: amount of time the certificate should be valid for. Valid time units are seconds: \"s\", minutes: \"m\", hours: \"h\"")
cf.outKeyPath = cf.set.String("out-key", "ca.key", "Optional: path to write the private key to") cf.outKeyPath = cf.set.String("out-key", "ca.key", "Optional: path to write the private key to")
cf.outCertPath = cf.set.String("out-crt", "ca.crt", "Optional: path to write the certificate to") cf.outCertPath = cf.set.String("out-crt", "ca.crt", "Optional: path to write the certificate to")
cf.outQRPath = cf.set.String("out-qr", "", "Optional: output a qr code image (png) of the certificate") cf.outQRPath = cf.set.String("out-qr", "", "Optional: output a qr code image (png) of the certificate")
cf.groups = cf.set.String("groups", "", "Optional: comma separated list of groups. This will limit which groups subordinate certs can use") cf.groups = cf.set.String("groups", "", "Optional: comma separated list of groups. This will limit which groups subordinate certs can use")
cf.ips = cf.set.String("ips", "", "Optional: comma separated list of ipv4 address and network in CIDR notation. This will limit which ipv4 addresses and networks subordinate certs can use for ip addresses") cf.networks = cf.set.String("networks", "", "Optional: comma separated list of ip address and network in CIDR notation. This will limit which ip addresses and networks subordinate certs can use in networks")
cf.subnets = cf.set.String("subnets", "", "Optional: comma separated list of ipv4 address and network in CIDR notation. This will limit which ipv4 addresses and networks subordinate certs can use in subnets") cf.unsafeNetworks = cf.set.String("unsafe-networks", "", "Optional: comma separated list of ip address and network in CIDR notation. This will limit which ip addresses and networks subordinate certs can use in unsafe networks")
cf.argonMemory = cf.set.Uint("argon-memory", 2*1024*1024, "Optional: Argon2 memory parameter (in KiB) used for encrypted private key passphrase") cf.argonMemory = cf.set.Uint("argon-memory", 2*1024*1024, "Optional: Argon2 memory parameter (in KiB) used for encrypted private key passphrase")
cf.argonParallelism = cf.set.Uint("argon-parallelism", 4, "Optional: Argon2 parallelism parameter used for encrypted private key passphrase") cf.argonParallelism = cf.set.Uint("argon-parallelism", 4, "Optional: Argon2 parallelism parameter used for encrypted private key passphrase")
cf.argonIterations = cf.set.Uint("argon-iterations", 1, "Optional: Argon2 iterations parameter used for encrypted private key passphrase") cf.argonIterations = cf.set.Uint("argon-iterations", 1, "Optional: Argon2 iterations parameter used for encrypted private key passphrase")
cf.encryption = cf.set.Bool("encrypt", false, "Optional: prompt for passphrase and write out-key in an encrypted format") cf.encryption = cf.set.Bool("encrypt", false, "Optional: prompt for passphrase and write out-key in an encrypted format")
cf.curve = cf.set.String("curve", "25519", "EdDSA/ECDSA Curve (25519, P256)") cf.curve = cf.set.String("curve", "25519", "EdDSA/ECDSA Curve (25519, P256)")
cf.p11url = p11Flag(cf.set) cf.p11url = p11Flag(cf.set)
cf.ips = cf.set.String("ips", "", "Deprecated, see -networks")
cf.subnets = cf.set.String("subnets", "", "Deprecated, see -unsafe-networks")
return &cf return &cf
} }
@ -113,36 +122,51 @@ func ca(args []string, out io.Writer, errOut io.Writer, pr PasswordReader) error
} }
} }
var ips []netip.Prefix version := cert.Version(*cf.version)
if *cf.ips != "" { if version != cert.Version1 && version != cert.Version2 {
for _, rs := range strings.Split(*cf.ips, ",") { return newHelpErrorf("-version must be either %v or %v", cert.Version1, cert.Version2)
}
var networks []netip.Prefix
if *cf.networks == "" && *cf.ips != "" {
// Pull up deprecated -ips flag if needed
*cf.networks = *cf.ips
}
if *cf.networks != "" {
for _, rs := range strings.Split(*cf.networks, ",") {
rs := strings.Trim(rs, " ") rs := strings.Trim(rs, " ")
if rs != "" { if rs != "" {
n, err := netip.ParsePrefix(rs) n, err := netip.ParsePrefix(rs)
if err != nil { if err != nil {
return newHelpErrorf("invalid ip definition: %s", err) return newHelpErrorf("invalid -networks definition: %s", rs)
} }
if !n.Addr().Is4() { if version == cert.Version1 && !n.Addr().Is4() {
return newHelpErrorf("invalid ip definition: can only be ipv4, have %s", rs) return newHelpErrorf("invalid -networks definition: v1 certificates can only be ipv4, have %s", rs)
} }
ips = append(ips, n) networks = append(networks, n)
} }
} }
} }
var subnets []netip.Prefix var unsafeNetworks []netip.Prefix
if *cf.subnets != "" { if *cf.unsafeNetworks == "" && *cf.subnets != "" {
for _, rs := range strings.Split(*cf.subnets, ",") { // Pull up deprecated -subnets flag if needed
*cf.unsafeNetworks = *cf.subnets
}
if *cf.unsafeNetworks != "" {
for _, rs := range strings.Split(*cf.unsafeNetworks, ",") {
rs := strings.Trim(rs, " ") rs := strings.Trim(rs, " ")
if rs != "" { if rs != "" {
n, err := netip.ParsePrefix(rs) n, err := netip.ParsePrefix(rs)
if err != nil { if err != nil {
return newHelpErrorf("invalid subnet definition: %s", err) return newHelpErrorf("invalid -unsafe-networks definition: %s", rs)
} }
if !n.Addr().Is4() { if version == cert.Version1 && !n.Addr().Is4() {
return newHelpErrorf("invalid subnet definition: can only be ipv4, have %s", rs) return newHelpErrorf("invalid -unsafe-networks definition: v1 certificates can only be ipv4, have %s", rs)
} }
subnets = append(subnets, n) unsafeNetworks = append(unsafeNetworks, n)
} }
} }
} }
@ -222,11 +246,11 @@ func ca(args []string, out io.Writer, errOut io.Writer, pr PasswordReader) error
} }
t := &cert.TBSCertificate{ t := &cert.TBSCertificate{
Version: cert.Version1, Version: version,
Name: *cf.name, Name: *cf.name,
Groups: groups, Groups: groups,
Networks: ips, Networks: networks,
UnsafeNetworks: subnets, UnsafeNetworks: unsafeNetworks,
NotBefore: time.Now(), NotBefore: time.Now(),
NotAfter: time.Now().Add(*cf.duration), NotAfter: time.Now().Add(*cf.duration),
PublicKey: pub, PublicKey: pub,
@ -248,7 +272,7 @@ func ca(args []string, out io.Writer, errOut io.Writer, pr PasswordReader) error
var b []byte var b []byte
if isP11 { if isP11 {
c, err = t.SignPkcs11(nil, curve, p11Client) c, err = t.SignWith(nil, curve, p11Client.SignASN1)
if err != nil { if err != nil {
return fmt.Errorf("error while signing with PKCS#11: %w", err) return fmt.Errorf("error while signing with PKCS#11: %w", err)
} }

View File

@ -16,8 +16,6 @@ import (
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
) )
//TODO: test file permissions
func Test_caSummary(t *testing.T) { func Test_caSummary(t *testing.T) {
assert.Equal(t, "ca <flags>: create a self signed certificate authority", caSummary()) assert.Equal(t, "ca <flags>: create a self signed certificate authority", caSummary())
} }
@ -43,9 +41,11 @@ func Test_caHelp(t *testing.T) {
" -groups string\n"+ " -groups string\n"+
" \tOptional: comma separated list of groups. This will limit which groups subordinate certs can use\n"+ " \tOptional: comma separated list of groups. This will limit which groups subordinate certs can use\n"+
" -ips string\n"+ " -ips string\n"+
" \tOptional: comma separated list of ipv4 address and network in CIDR notation. This will limit which ipv4 addresses and networks subordinate certs can use for ip addresses\n"+ " Deprecated, see -networks\n"+
" -name string\n"+ " -name string\n"+
" \tRequired: name of the certificate authority\n"+ " \tRequired: name of the certificate authority\n"+
" -networks string\n"+
" \tOptional: comma separated list of ip address and network in CIDR notation. This will limit which ip addresses and networks subordinate certs can use in networks\n"+
" -out-crt string\n"+ " -out-crt string\n"+
" \tOptional: path to write the certificate to (default \"ca.crt\")\n"+ " \tOptional: path to write the certificate to (default \"ca.crt\")\n"+
" -out-key string\n"+ " -out-key string\n"+
@ -54,7 +54,11 @@ func Test_caHelp(t *testing.T) {
" \tOptional: output a qr code image (png) of the certificate\n"+ " \tOptional: output a qr code image (png) of the certificate\n"+
optionalPkcs11String(" -pkcs11 string\n \tOptional: PKCS#11 URI to an existing private key\n")+ optionalPkcs11String(" -pkcs11 string\n \tOptional: PKCS#11 URI to an existing private key\n")+
" -subnets string\n"+ " -subnets string\n"+
" \tOptional: comma separated list of ipv4 address and network in CIDR notation. This will limit which ipv4 addresses and networks subordinate certs can use in subnets\n", " \tDeprecated, see -unsafe-networks\n"+
" -unsafe-networks string\n"+
" \tOptional: comma separated list of ip address and network in CIDR notation. This will limit which ip addresses and networks subordinate certs can use in unsafe networks\n"+
" -version uint\n"+
" \tOptional: version of the certificate format to use (default 2)\n",
ob.String(), ob.String(),
) )
} }
@ -83,25 +87,25 @@ func Test_ca(t *testing.T) {
// required args // required args
assertHelpError(t, ca( assertHelpError(t, ca(
[]string{"-out-key", "nope", "-out-crt", "nope", "duration", "100m"}, ob, eb, nopw, []string{"-version", "1", "-out-key", "nope", "-out-crt", "nope", "duration", "100m"}, ob, eb, nopw,
), "-name is required") ), "-name is required")
assert.Equal(t, "", ob.String()) assert.Equal(t, "", ob.String())
assert.Equal(t, "", eb.String()) assert.Equal(t, "", eb.String())
// ipv4 only ips // ipv4 only ips
assertHelpError(t, ca([]string{"-name", "ipv6", "-ips", "100::100/100"}, ob, eb, nopw), "invalid ip definition: can only be ipv4, have 100::100/100") assertHelpError(t, ca([]string{"-version", "1", "-name", "ipv6", "-ips", "100::100/100"}, ob, eb, nopw), "invalid -networks definition: v1 certificates can only be ipv4, have 100::100/100")
assert.Equal(t, "", ob.String()) assert.Equal(t, "", ob.String())
assert.Equal(t, "", eb.String()) assert.Equal(t, "", eb.String())
// ipv4 only subnets // ipv4 only subnets
assertHelpError(t, ca([]string{"-name", "ipv6", "-subnets", "100::100/100"}, ob, eb, nopw), "invalid subnet definition: can only be ipv4, have 100::100/100") assertHelpError(t, ca([]string{"-version", "1", "-name", "ipv6", "-subnets", "100::100/100"}, ob, eb, nopw), "invalid -unsafe-networks definition: v1 certificates can only be ipv4, have 100::100/100")
assert.Equal(t, "", ob.String()) assert.Equal(t, "", ob.String())
assert.Equal(t, "", eb.String()) assert.Equal(t, "", eb.String())
// failed key write // failed key write
ob.Reset() ob.Reset()
eb.Reset() eb.Reset()
args := []string{"-name", "test", "-duration", "100m", "-out-crt", "/do/not/write/pleasecrt", "-out-key", "/do/not/write/pleasekey"} args := []string{"-version", "1", "-name", "test", "-duration", "100m", "-out-crt", "/do/not/write/pleasecrt", "-out-key", "/do/not/write/pleasekey"}
assert.EqualError(t, ca(args, ob, eb, nopw), "error while writing out-key: open /do/not/write/pleasekey: "+NoSuchDirError) assert.EqualError(t, ca(args, ob, eb, nopw), "error while writing out-key: open /do/not/write/pleasekey: "+NoSuchDirError)
assert.Equal(t, "", ob.String()) assert.Equal(t, "", ob.String())
assert.Equal(t, "", eb.String()) assert.Equal(t, "", eb.String())
@ -114,7 +118,7 @@ func Test_ca(t *testing.T) {
// failed cert write // failed cert write
ob.Reset() ob.Reset()
eb.Reset() eb.Reset()
args = []string{"-name", "test", "-duration", "100m", "-out-crt", "/do/not/write/pleasecrt", "-out-key", keyF.Name()} args = []string{"-version", "1", "-name", "test", "-duration", "100m", "-out-crt", "/do/not/write/pleasecrt", "-out-key", keyF.Name()}
assert.EqualError(t, ca(args, ob, eb, nopw), "error while writing out-crt: open /do/not/write/pleasecrt: "+NoSuchDirError) assert.EqualError(t, ca(args, ob, eb, nopw), "error while writing out-crt: open /do/not/write/pleasecrt: "+NoSuchDirError)
assert.Equal(t, "", ob.String()) assert.Equal(t, "", ob.String())
assert.Equal(t, "", eb.String()) assert.Equal(t, "", eb.String())
@ -128,7 +132,7 @@ func Test_ca(t *testing.T) {
// test proper cert with removed empty groups and subnets // test proper cert with removed empty groups and subnets
ob.Reset() ob.Reset()
eb.Reset() eb.Reset()
args = []string{"-name", "test", "-duration", "100m", "-groups", "1,, 2 , ,,,3,4,5", "-out-crt", crtF.Name(), "-out-key", keyF.Name()} args = []string{"-version", "1", "-name", "test", "-duration", "100m", "-groups", "1,, 2 , ,,,3,4,5", "-out-crt", crtF.Name(), "-out-key", keyF.Name()}
assert.Nil(t, ca(args, ob, eb, nopw)) assert.Nil(t, ca(args, ob, eb, nopw))
assert.Equal(t, "", ob.String()) assert.Equal(t, "", ob.String())
assert.Equal(t, "", eb.String()) assert.Equal(t, "", eb.String())
@ -161,7 +165,7 @@ func Test_ca(t *testing.T) {
os.Remove(crtF.Name()) os.Remove(crtF.Name())
ob.Reset() ob.Reset()
eb.Reset() eb.Reset()
args = []string{"-encrypt", "-name", "test", "-duration", "100m", "-groups", "1,2,3,4,5", "-out-crt", crtF.Name(), "-out-key", keyF.Name()} args = []string{"-version", "1", "-encrypt", "-name", "test", "-duration", "100m", "-groups", "1,2,3,4,5", "-out-crt", crtF.Name(), "-out-key", keyF.Name()}
assert.Nil(t, ca(args, ob, eb, testpw)) assert.Nil(t, ca(args, ob, eb, testpw))
assert.Equal(t, pwPromptOb, ob.String()) assert.Equal(t, pwPromptOb, ob.String())
assert.Equal(t, "", eb.String()) assert.Equal(t, "", eb.String())
@ -189,7 +193,7 @@ func Test_ca(t *testing.T) {
os.Remove(crtF.Name()) os.Remove(crtF.Name())
ob.Reset() ob.Reset()
eb.Reset() eb.Reset()
args = []string{"-encrypt", "-name", "test", "-duration", "100m", "-groups", "1,2,3,4,5", "-out-crt", crtF.Name(), "-out-key", keyF.Name()} args = []string{"-version", "1", "-encrypt", "-name", "test", "-duration", "100m", "-groups", "1,2,3,4,5", "-out-crt", crtF.Name(), "-out-key", keyF.Name()}
assert.Error(t, ca(args, ob, eb, errpw)) assert.Error(t, ca(args, ob, eb, errpw))
assert.Equal(t, pwPromptOb, ob.String()) assert.Equal(t, pwPromptOb, ob.String())
assert.Equal(t, "", eb.String()) assert.Equal(t, "", eb.String())
@ -199,7 +203,7 @@ func Test_ca(t *testing.T) {
os.Remove(crtF.Name()) os.Remove(crtF.Name())
ob.Reset() ob.Reset()
eb.Reset() eb.Reset()
args = []string{"-encrypt", "-name", "test", "-duration", "100m", "-groups", "1,2,3,4,5", "-out-crt", crtF.Name(), "-out-key", keyF.Name()} args = []string{"-version", "1", "-encrypt", "-name", "test", "-duration", "100m", "-groups", "1,2,3,4,5", "-out-crt", crtF.Name(), "-out-key", keyF.Name()}
assert.EqualError(t, ca(args, ob, eb, nopw), "no passphrase specified, remove -encrypt flag to write out-key in plaintext") assert.EqualError(t, ca(args, ob, eb, nopw), "no passphrase specified, remove -encrypt flag to write out-key in plaintext")
assert.Equal(t, strings.Repeat(pwPromptOb, 5), ob.String()) // prompts 5 times before giving up assert.Equal(t, strings.Repeat(pwPromptOb, 5), ob.String()) // prompts 5 times before giving up
assert.Equal(t, "", eb.String()) assert.Equal(t, "", eb.String())
@ -209,13 +213,13 @@ func Test_ca(t *testing.T) {
os.Remove(crtF.Name()) os.Remove(crtF.Name())
ob.Reset() ob.Reset()
eb.Reset() eb.Reset()
args = []string{"-name", "test", "-duration", "100m", "-groups", "1,, 2 , ,,,3,4,5", "-out-crt", crtF.Name(), "-out-key", keyF.Name()} args = []string{"-version", "1", "-name", "test", "-duration", "100m", "-groups", "1,, 2 , ,,,3,4,5", "-out-crt", crtF.Name(), "-out-key", keyF.Name()}
assert.Nil(t, ca(args, ob, eb, nopw)) assert.Nil(t, ca(args, ob, eb, nopw))
// test that we won't overwrite existing certificate file // test that we won't overwrite existing certificate file
ob.Reset() ob.Reset()
eb.Reset() eb.Reset()
args = []string{"-name", "test", "-duration", "100m", "-groups", "1,, 2 , ,,,3,4,5", "-out-crt", crtF.Name(), "-out-key", keyF.Name()} args = []string{"-version", "1", "-name", "test", "-duration", "100m", "-groups", "1,, 2 , ,,,3,4,5", "-out-crt", crtF.Name(), "-out-key", keyF.Name()}
assert.EqualError(t, ca(args, ob, eb, nopw), "refusing to overwrite existing CA key: "+keyF.Name()) assert.EqualError(t, ca(args, ob, eb, nopw), "refusing to overwrite existing CA key: "+keyF.Name())
assert.Equal(t, "", ob.String()) assert.Equal(t, "", ob.String())
assert.Equal(t, "", eb.String()) assert.Equal(t, "", eb.String())
@ -224,7 +228,7 @@ func Test_ca(t *testing.T) {
os.Remove(keyF.Name()) os.Remove(keyF.Name())
ob.Reset() ob.Reset()
eb.Reset() eb.Reset()
args = []string{"-name", "test", "-duration", "100m", "-groups", "1,, 2 , ,,,3,4,5", "-out-crt", crtF.Name(), "-out-key", keyF.Name()} args = []string{"-version", "1", "-name", "test", "-duration", "100m", "-groups", "1,, 2 , ,,,3,4,5", "-out-crt", crtF.Name(), "-out-key", keyF.Name()}
assert.EqualError(t, ca(args, ob, eb, nopw), "refusing to overwrite existing CA cert: "+crtF.Name()) assert.EqualError(t, ca(args, ob, eb, nopw), "refusing to overwrite existing CA cert: "+crtF.Name())
assert.Equal(t, "", ob.String()) assert.Equal(t, "", ob.String())
assert.Equal(t, "", eb.String()) assert.Equal(t, "", eb.String())

View File

@ -9,8 +9,6 @@ import (
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
) )
//TODO: test file permissions
func Test_keygenSummary(t *testing.T) { func Test_keygenSummary(t *testing.T) {
assert.Equal(t, "keygen <flags>: create a public/private key pair. the public key can be passed to `nebula-cert sign`", keygenSummary()) assert.Equal(t, "keygen <flags>: create a public/private key pair. the public key can be passed to `nebula-cert sign`", keygenSummary())
} }

View File

@ -11,8 +11,6 @@ import (
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
) )
//TODO: all flag parsing continueOnError will print to stderr on its own currently
func Test_help(t *testing.T) { func Test_help(t *testing.T) {
expected := "Usage of " + os.Args[0] + " <global flags> <mode>:\n" + expected := "Usage of " + os.Args[0] + " <global flags> <mode>:\n" +
" Global flags:\n" + " Global flags:\n" +

View File

@ -49,6 +49,8 @@ func printCert(args []string, out io.Writer, errOut io.Writer) error {
var qrBytes []byte var qrBytes []byte
part := 0 part := 0
var jsonCerts []cert.Certificate
for { for {
c, rawCert, err = cert.UnmarshalCertificateFromPEM(rawCert) c, rawCert, err = cert.UnmarshalCertificateFromPEM(rawCert)
if err != nil { if err != nil {
@ -56,13 +58,10 @@ func printCert(args []string, out io.Writer, errOut io.Writer) error {
} }
if *pf.json { if *pf.json {
b, _ := json.Marshal(c) jsonCerts = append(jsonCerts, c)
out.Write(b)
out.Write([]byte("\n"))
} else { } else {
out.Write([]byte(c.String())) _, _ = out.Write([]byte(c.String()))
out.Write([]byte("\n")) _, _ = out.Write([]byte("\n"))
} }
if *pf.outQRPath != "" { if *pf.outQRPath != "" {
@ -80,6 +79,12 @@ func printCert(args []string, out io.Writer, errOut io.Writer) error {
part++ part++
} }
if *pf.json {
b, _ := json.Marshal(jsonCerts)
_, _ = out.Write(b)
_, _ = out.Write([]byte("\n"))
}
if *pf.outQRPath != "" { if *pf.outQRPath != "" {
b, err := qrcode.Encode(string(qrBytes), qrcode.Medium, -5) b, err := qrcode.Encode(string(qrBytes), qrcode.Medium, -5)
if err != nil { if err != nil {

View File

@ -73,7 +73,7 @@ func Test_printCert(t *testing.T) {
tf.Truncate(0) tf.Truncate(0)
tf.Seek(0, 0) tf.Seek(0, 0)
ca, caKey := NewTestCaCert("test ca", nil, nil, time.Time{}, time.Time{}, nil, nil, nil) ca, caKey := NewTestCaCert("test ca", nil, nil, time.Time{}, time.Time{}, nil, nil, nil)
c, _ := NewTestCert(ca, caKey, "test", time.Time{}, time.Time{}, nil, nil, []string{"hi"}) c, _ := NewTestCert(ca, caKey, "test", time.Time{}, time.Time{}, []netip.Prefix{netip.MustParsePrefix("10.0.0.123/8")}, nil, []string{"hi"})
p, _ := c.MarshalPEM() p, _ := c.MarshalPEM()
tf.Write(p) tf.Write(p)
@ -87,7 +87,71 @@ func Test_printCert(t *testing.T) {
assert.Nil(t, err) assert.Nil(t, err)
assert.Equal( assert.Equal(
t, t,
"NebulaCertificate {\n\tDetails {\n\t\tName: test\n\t\tIps: []\n\t\tSubnets: []\n\t\tGroups: [\n\t\t\t\"hi\"\n\t\t]\n\t\tNot before: 0001-01-01 00:00:00 +0000 UTC\n\t\tNot After: 0001-01-01 00:00:00 +0000 UTC\n\t\tIs CA: false\n\t\tIssuer: "+c.Issuer()+"\n\t\tPublic key: "+pk+"\n\t\tCurve: CURVE25519\n\t}\n\tFingerprint: "+fp+"\n\tSignature: "+sig+"\n}\nNebulaCertificate {\n\tDetails {\n\t\tName: test\n\t\tIps: []\n\t\tSubnets: []\n\t\tGroups: [\n\t\t\t\"hi\"\n\t\t]\n\t\tNot before: 0001-01-01 00:00:00 +0000 UTC\n\t\tNot After: 0001-01-01 00:00:00 +0000 UTC\n\t\tIs CA: false\n\t\tIssuer: "+c.Issuer()+"\n\t\tPublic key: "+pk+"\n\t\tCurve: CURVE25519\n\t}\n\tFingerprint: "+fp+"\n\tSignature: "+sig+"\n}\nNebulaCertificate {\n\tDetails {\n\t\tName: test\n\t\tIps: []\n\t\tSubnets: []\n\t\tGroups: [\n\t\t\t\"hi\"\n\t\t]\n\t\tNot before: 0001-01-01 00:00:00 +0000 UTC\n\t\tNot After: 0001-01-01 00:00:00 +0000 UTC\n\t\tIs CA: false\n\t\tIssuer: "+c.Issuer()+"\n\t\tPublic key: "+pk+"\n\t\tCurve: CURVE25519\n\t}\n\tFingerprint: "+fp+"\n\tSignature: "+sig+"\n}\n", //"NebulaCertificate {\n\tDetails {\n\t\tName: test\n\t\tIps: []\n\t\tSubnets: []\n\t\tGroups: [\n\t\t\t\"hi\"\n\t\t]\n\t\tNot before: 0001-01-01 00:00:00 +0000 UTC\n\t\tNot After: 0001-01-01 00:00:00 +0000 UTC\n\t\tIs CA: false\n\t\tIssuer: "+c.Issuer()+"\n\t\tPublic key: "+pk+"\n\t\tCurve: CURVE25519\n\t}\n\tFingerprint: "+fp+"\n\tSignature: "+sig+"\n}\nNebulaCertificate {\n\tDetails {\n\t\tName: test\n\t\tIps: []\n\t\tSubnets: []\n\t\tGroups: [\n\t\t\t\"hi\"\n\t\t]\n\t\tNot before: 0001-01-01 00:00:00 +0000 UTC\n\t\tNot After: 0001-01-01 00:00:00 +0000 UTC\n\t\tIs CA: false\n\t\tIssuer: "+c.Issuer()+"\n\t\tPublic key: "+pk+"\n\t\tCurve: CURVE25519\n\t}\n\tFingerprint: "+fp+"\n\tSignature: "+sig+"\n}\nNebulaCertificate {\n\tDetails {\n\t\tName: test\n\t\tIps: []\n\t\tSubnets: []\n\t\tGroups: [\n\t\t\t\"hi\"\n\t\t]\n\t\tNot before: 0001-01-01 00:00:00 +0000 UTC\n\t\tNot After: 0001-01-01 00:00:00 +0000 UTC\n\t\tIs CA: false\n\t\tIssuer: "+c.Issuer()+"\n\t\tPublic key: "+pk+"\n\t\tCurve: CURVE25519\n\t}\n\tFingerprint: "+fp+"\n\tSignature: "+sig+"\n}\n",
`{
"details": {
"curve": "CURVE25519",
"groups": [
"hi"
],
"isCa": false,
"issuer": "`+c.Issuer()+`",
"name": "test",
"networks": [
"10.0.0.123/8"
],
"notAfter": "0001-01-01T00:00:00Z",
"notBefore": "0001-01-01T00:00:00Z",
"publicKey": "`+pk+`",
"unsafeNetworks": []
},
"fingerprint": "`+fp+`",
"signature": "`+sig+`",
"version": 1
}
{
"details": {
"curve": "CURVE25519",
"groups": [
"hi"
],
"isCa": false,
"issuer": "`+c.Issuer()+`",
"name": "test",
"networks": [
"10.0.0.123/8"
],
"notAfter": "0001-01-01T00:00:00Z",
"notBefore": "0001-01-01T00:00:00Z",
"publicKey": "`+pk+`",
"unsafeNetworks": []
},
"fingerprint": "`+fp+`",
"signature": "`+sig+`",
"version": 1
}
{
"details": {
"curve": "CURVE25519",
"groups": [
"hi"
],
"isCa": false,
"issuer": "`+c.Issuer()+`",
"name": "test",
"networks": [
"10.0.0.123/8"
],
"notAfter": "0001-01-01T00:00:00Z",
"notBefore": "0001-01-01T00:00:00Z",
"publicKey": "`+pk+`",
"unsafeNetworks": []
},
"fingerprint": "`+fp+`",
"signature": "`+sig+`",
"version": 1
}
`,
ob.String(), ob.String(),
) )
assert.Equal(t, "", eb.String()) assert.Equal(t, "", eb.String())
@ -108,7 +172,8 @@ func Test_printCert(t *testing.T) {
assert.Nil(t, err) assert.Nil(t, err)
assert.Equal( assert.Equal(
t, t,
"{\"details\":{\"curve\":\"CURVE25519\",\"groups\":[\"hi\"],\"ips\":[],\"isCa\":false,\"issuer\":\""+c.Issuer()+"\",\"name\":\"test\",\"notAfter\":\"0001-01-01T00:00:00Z\",\"notBefore\":\"0001-01-01T00:00:00Z\",\"publicKey\":\""+pk+"\",\"subnets\":[]},\"fingerprint\":\""+fp+"\",\"signature\":\""+sig+"\"}\n{\"details\":{\"curve\":\"CURVE25519\",\"groups\":[\"hi\"],\"ips\":[],\"isCa\":false,\"issuer\":\""+c.Issuer()+"\",\"name\":\"test\",\"notAfter\":\"0001-01-01T00:00:00Z\",\"notBefore\":\"0001-01-01T00:00:00Z\",\"publicKey\":\""+pk+"\",\"subnets\":[]},\"fingerprint\":\""+fp+"\",\"signature\":\""+sig+"\"}\n{\"details\":{\"curve\":\"CURVE25519\",\"groups\":[\"hi\"],\"ips\":[],\"isCa\":false,\"issuer\":\""+c.Issuer()+"\",\"name\":\"test\",\"notAfter\":\"0001-01-01T00:00:00Z\",\"notBefore\":\"0001-01-01T00:00:00Z\",\"publicKey\":\""+pk+"\",\"subnets\":[]},\"fingerprint\":\""+fp+"\",\"signature\":\""+sig+"\"}\n", `[{"details":{"curve":"CURVE25519","groups":["hi"],"isCa":false,"issuer":"`+c.Issuer()+`","name":"test","networks":["10.0.0.123/8"],"notAfter":"0001-01-01T00:00:00Z","notBefore":"0001-01-01T00:00:00Z","publicKey":"`+pk+`","unsafeNetworks":[]},"fingerprint":"`+fp+`","signature":"`+sig+`","version":1},{"details":{"curve":"CURVE25519","groups":["hi"],"isCa":false,"issuer":"`+c.Issuer()+`","name":"test","networks":["10.0.0.123/8"],"notAfter":"0001-01-01T00:00:00Z","notBefore":"0001-01-01T00:00:00Z","publicKey":"`+pk+`","unsafeNetworks":[]},"fingerprint":"`+fp+`","signature":"`+sig+`","version":1},{"details":{"curve":"CURVE25519","groups":["hi"],"isCa":false,"issuer":"`+c.Issuer()+`","name":"test","networks":["10.0.0.123/8"],"notAfter":"0001-01-01T00:00:00Z","notBefore":"0001-01-01T00:00:00Z","publicKey":"`+pk+`","unsafeNetworks":[]},"fingerprint":"`+fp+`","signature":"`+sig+`","version":1}]
`,
ob.String(), ob.String(),
) )
assert.Equal(t, "", eb.String()) assert.Equal(t, "", eb.String())
@ -153,6 +218,10 @@ func NewTestCert(ca cert.Certificate, signerKey []byte, name string, before, aft
after = ca.NotAfter() after = ca.NotAfter()
} }
if len(networks) == 0 {
networks = []netip.Prefix{netip.MustParsePrefix("10.0.0.123/8")}
}
pub, rawPriv := x25519Keypair() pub, rawPriv := x25519Keypair()
nc := &cert.TBSCertificate{ nc := &cert.TBSCertificate{
Version: cert.Version1, Version: cert.Version1,

View File

@ -3,6 +3,7 @@ package main
import ( import (
"crypto/ecdh" "crypto/ecdh"
"crypto/rand" "crypto/rand"
"errors"
"flag" "flag"
"fmt" "fmt"
"io" "io"
@ -18,36 +19,46 @@ import (
) )
type signFlags struct { type signFlags struct {
set *flag.FlagSet set *flag.FlagSet
caKeyPath *string version *uint
caCertPath *string caKeyPath *string
name *string caCertPath *string
ip *string name *string
duration *time.Duration networks *string
inPubPath *string unsafeNetworks *string
outKeyPath *string duration *time.Duration
outCertPath *string inPubPath *string
outQRPath *string outKeyPath *string
groups *string outCertPath *string
subnets *string outQRPath *string
p11url *string groups *string
p11url *string
// Deprecated options
ip *string
subnets *string
} }
func newSignFlags() *signFlags { func newSignFlags() *signFlags {
sf := signFlags{set: flag.NewFlagSet("sign", flag.ContinueOnError)} sf := signFlags{set: flag.NewFlagSet("sign", flag.ContinueOnError)}
sf.set.Usage = func() {} sf.set.Usage = func() {}
sf.version = sf.set.Uint("version", 0, "Optional: version of the certificate format to use, the default is to create both v1 and v2 certificates.")
sf.caKeyPath = sf.set.String("ca-key", "ca.key", "Optional: path to the signing CA key") sf.caKeyPath = sf.set.String("ca-key", "ca.key", "Optional: path to the signing CA key")
sf.caCertPath = sf.set.String("ca-crt", "ca.crt", "Optional: path to the signing CA cert") sf.caCertPath = sf.set.String("ca-crt", "ca.crt", "Optional: path to the signing CA cert")
sf.name = sf.set.String("name", "", "Required: name of the cert, usually a hostname") sf.name = sf.set.String("name", "", "Required: name of the cert, usually a hostname")
sf.ip = sf.set.String("ip", "", "Required: ipv4 address and network in CIDR notation to assign the cert") sf.networks = sf.set.String("networks", "", "Required: comma separated list of ip address and network in CIDR notation to assign to this cert")
sf.unsafeNetworks = sf.set.String("unsafe-networks", "", "Optional: comma separated list of ip address and network in CIDR notation. Unsafe networks this cert can route for")
sf.duration = sf.set.Duration("duration", 0, "Optional: how long the cert should be valid for. The default is 1 second before the signing cert expires. Valid time units are seconds: \"s\", minutes: \"m\", hours: \"h\"") sf.duration = sf.set.Duration("duration", 0, "Optional: how long the cert should be valid for. The default is 1 second before the signing cert expires. Valid time units are seconds: \"s\", minutes: \"m\", hours: \"h\"")
sf.inPubPath = sf.set.String("in-pub", "", "Optional (if out-key not set): path to read a previously generated public key") sf.inPubPath = sf.set.String("in-pub", "", "Optional (if out-key not set): path to read a previously generated public key")
sf.outKeyPath = sf.set.String("out-key", "", "Optional (if in-pub not set): path to write the private key to") sf.outKeyPath = sf.set.String("out-key", "", "Optional (if in-pub not set): path to write the private key to")
sf.outCertPath = sf.set.String("out-crt", "", "Optional: path to write the certificate to") sf.outCertPath = sf.set.String("out-crt", "", "Optional: path to write the certificate to")
sf.outQRPath = sf.set.String("out-qr", "", "Optional: output a qr code image (png) of the certificate") sf.outQRPath = sf.set.String("out-qr", "", "Optional: output a qr code image (png) of the certificate")
sf.groups = sf.set.String("groups", "", "Optional: comma separated list of groups") sf.groups = sf.set.String("groups", "", "Optional: comma separated list of groups")
sf.subnets = sf.set.String("subnets", "", "Optional: comma separated list of ipv4 address and network in CIDR notation. Subnets this cert can serve for")
sf.p11url = p11Flag(sf.set) sf.p11url = p11Flag(sf.set)
sf.ip = sf.set.String("ip", "", "Deprecated, see -networks")
sf.subnets = sf.set.String("subnets", "", "Deprecated, see -unsafe-networks")
return &sf return &sf
} }
@ -71,13 +82,26 @@ func signCert(args []string, out io.Writer, errOut io.Writer, pr PasswordReader)
if err := mustFlagString("name", sf.name); err != nil { if err := mustFlagString("name", sf.name); err != nil {
return err return err
} }
if err := mustFlagString("ip", sf.ip); err != nil {
return err
}
if !isP11 && *sf.inPubPath != "" && *sf.outKeyPath != "" { if !isP11 && *sf.inPubPath != "" && *sf.outKeyPath != "" {
return newHelpErrorf("cannot set both -in-pub and -out-key") return newHelpErrorf("cannot set both -in-pub and -out-key")
} }
var v4Networks []netip.Prefix
var v6Networks []netip.Prefix
if *sf.networks == "" && *sf.ip != "" {
// Pull up deprecated -ip flag if needed
*sf.networks = *sf.ip
}
if len(*sf.networks) == 0 {
return newHelpErrorf("-networks is required")
}
version := cert.Version(*sf.version)
if version != 0 && version != cert.Version1 && version != cert.Version2 {
return newHelpErrorf("-version must be either %v or %v", cert.Version1, cert.Version2)
}
var curve cert.Curve var curve cert.Curve
var caKey []byte var caKey []byte
@ -91,14 +115,14 @@ func signCert(args []string, out io.Writer, errOut io.Writer, pr PasswordReader)
// naively attempt to decode the private key as though it is not encrypted // naively attempt to decode the private key as though it is not encrypted
caKey, _, curve, err = cert.UnmarshalSigningPrivateKeyFromPEM(rawCAKey) caKey, _, curve, err = cert.UnmarshalSigningPrivateKeyFromPEM(rawCAKey)
if err == cert.ErrPrivateKeyEncrypted { if errors.Is(err, cert.ErrPrivateKeyEncrypted) {
// ask for a passphrase until we get one // ask for a passphrase until we get one
var passphrase []byte var passphrase []byte
for i := 0; i < 5; i++ { for i := 0; i < 5; i++ {
out.Write([]byte("Enter passphrase: ")) out.Write([]byte("Enter passphrase: "))
passphrase, err = pr.ReadPassword() passphrase, err = pr.ReadPassword()
if err == ErrNoTerminal { if errors.Is(err, ErrNoTerminal) {
return fmt.Errorf("ca-key is encrypted and must be decrypted interactively") return fmt.Errorf("ca-key is encrypted and must be decrypted interactively")
} else if err != nil { } else if err != nil {
return fmt.Errorf("error reading password: %s", err) return fmt.Errorf("error reading password: %s", err)
@ -146,12 +170,47 @@ func signCert(args []string, out io.Writer, errOut io.Writer, pr PasswordReader)
*sf.duration = time.Until(caCert.NotAfter()) - time.Second*1 *sf.duration = time.Until(caCert.NotAfter()) - time.Second*1
} }
network, err := netip.ParsePrefix(*sf.ip) if *sf.networks != "" {
if err != nil { for _, rs := range strings.Split(*sf.networks, ",") {
return newHelpErrorf("invalid ip definition: %s", *sf.ip) rs := strings.Trim(rs, " ")
if rs != "" {
n, err := netip.ParsePrefix(rs)
if err != nil {
return newHelpErrorf("invalid -networks definition: %s", rs)
}
if n.Addr().Is4() {
v4Networks = append(v4Networks, n)
} else {
v6Networks = append(v6Networks, n)
}
}
}
} }
if !network.Addr().Is4() {
return newHelpErrorf("invalid ip definition: can only be ipv4, have %s", *sf.ip) var v4UnsafeNetworks []netip.Prefix
var v6UnsafeNetworks []netip.Prefix
if *sf.unsafeNetworks == "" && *sf.subnets != "" {
// Pull up deprecated -subnets flag if needed
*sf.unsafeNetworks = *sf.subnets
}
if *sf.unsafeNetworks != "" {
for _, rs := range strings.Split(*sf.unsafeNetworks, ",") {
rs := strings.Trim(rs, " ")
if rs != "" {
n, err := netip.ParsePrefix(rs)
if err != nil {
return newHelpErrorf("invalid -unsafe-networks definition: %s", rs)
}
if n.Addr().Is4() {
v4UnsafeNetworks = append(v4UnsafeNetworks, n)
} else {
v6UnsafeNetworks = append(v6UnsafeNetworks, n)
}
}
}
} }
var groups []string var groups []string
@ -164,23 +223,6 @@ func signCert(args []string, out io.Writer, errOut io.Writer, pr PasswordReader)
} }
} }
var subnets []netip.Prefix
if *sf.subnets != "" {
for _, rs := range strings.Split(*sf.subnets, ",") {
rs := strings.Trim(rs, " ")
if rs != "" {
s, err := netip.ParsePrefix(rs)
if err != nil {
return newHelpErrorf("invalid subnet definition: %s", rs)
}
if !s.Addr().Is4() {
return newHelpErrorf("invalid subnet definition: can only be ipv4, have %s", rs)
}
subnets = append(subnets, s)
}
}
}
var pub, rawPriv []byte var pub, rawPriv []byte
var p11Client *pkclient.PKClient var p11Client *pkclient.PKClient
@ -218,19 +260,6 @@ func signCert(args []string, out io.Writer, errOut io.Writer, pr PasswordReader)
pub, rawPriv = newKeypair(curve) pub, rawPriv = newKeypair(curve)
} }
t := &cert.TBSCertificate{
Version: cert.Version1,
Name: *sf.name,
Networks: []netip.Prefix{network},
Groups: groups,
UnsafeNetworks: subnets,
NotBefore: time.Now(),
NotAfter: time.Now().Add(*sf.duration),
PublicKey: pub,
IsCA: false,
Curve: curve,
}
if *sf.outKeyPath == "" { if *sf.outKeyPath == "" {
*sf.outKeyPath = *sf.name + ".key" *sf.outKeyPath = *sf.name + ".key"
} }
@ -243,18 +272,85 @@ func signCert(args []string, out io.Writer, errOut io.Writer, pr PasswordReader)
return fmt.Errorf("refusing to overwrite existing cert: %s", *sf.outCertPath) return fmt.Errorf("refusing to overwrite existing cert: %s", *sf.outCertPath)
} }
var c cert.Certificate var crts []cert.Certificate
if p11Client == nil { notBefore := time.Now()
c, err = t.Sign(caCert, curve, caKey) notAfter := notBefore.Add(*sf.duration)
if err != nil {
return fmt.Errorf("error while signing: %w", err) if version == 0 || version == cert.Version1 {
// Make sure we at least have an ip
if len(v4Networks) != 1 {
return newHelpErrorf("invalid -networks definition: v1 certificates can only have a single ipv4 address")
} }
} else {
c, err = t.SignPkcs11(caCert, curve, p11Client) if version == cert.Version1 {
if err != nil { // If we are asked to mint a v1 certificate only then we cant just ignore any v6 addresses
return fmt.Errorf("error while signing with PKCS#11: %w", err) if len(v6Networks) > 0 {
return newHelpErrorf("invalid -networks definition: v1 certificates can only be ipv4")
}
if len(v6UnsafeNetworks) > 0 {
return newHelpErrorf("invalid -unsafe-networks definition: v1 certificates can only be ipv4")
}
} }
t := &cert.TBSCertificate{
Version: cert.Version1,
Name: *sf.name,
Networks: []netip.Prefix{v4Networks[0]},
Groups: groups,
UnsafeNetworks: v4UnsafeNetworks,
NotBefore: notBefore,
NotAfter: notAfter,
PublicKey: pub,
IsCA: false,
Curve: curve,
}
var nc cert.Certificate
if p11Client == nil {
nc, err = t.Sign(caCert, curve, caKey)
if err != nil {
return fmt.Errorf("error while signing: %w", err)
}
} else {
nc, err = t.SignWith(caCert, curve, p11Client.SignASN1)
if err != nil {
return fmt.Errorf("error while signing with PKCS#11: %w", err)
}
}
crts = append(crts, nc)
}
if version == 0 || version == cert.Version2 {
t := &cert.TBSCertificate{
Version: cert.Version2,
Name: *sf.name,
Networks: append(v4Networks, v6Networks...),
Groups: groups,
UnsafeNetworks: append(v4UnsafeNetworks, v6UnsafeNetworks...),
NotBefore: notBefore,
NotAfter: notAfter,
PublicKey: pub,
IsCA: false,
Curve: curve,
}
var nc cert.Certificate
if p11Client == nil {
nc, err = t.Sign(caCert, curve, caKey)
if err != nil {
return fmt.Errorf("error while signing: %w", err)
}
} else {
nc, err = t.SignWith(caCert, curve, p11Client.SignASN1)
if err != nil {
return fmt.Errorf("error while signing with PKCS#11: %w", err)
}
}
crts = append(crts, nc)
} }
if !isP11 && *sf.inPubPath == "" { if !isP11 && *sf.inPubPath == "" {
@ -268,9 +364,13 @@ func signCert(args []string, out io.Writer, errOut io.Writer, pr PasswordReader)
} }
} }
b, err := c.MarshalPEM() var b []byte
if err != nil { for _, c := range crts {
return fmt.Errorf("error while marshalling certificate: %s", err) sb, err := c.MarshalPEM()
if err != nil {
return fmt.Errorf("error while marshalling certificate: %s", err)
}
b = append(b, sb...)
} }
err = os.WriteFile(*sf.outCertPath, b, 0600) err = os.WriteFile(*sf.outCertPath, b, 0600)

View File

@ -16,8 +16,6 @@ import (
"golang.org/x/crypto/ed25519" "golang.org/x/crypto/ed25519"
) )
//TODO: test file permissions
func Test_signSummary(t *testing.T) { func Test_signSummary(t *testing.T) {
assert.Equal(t, "sign <flags>: create and sign a certificate", signSummary()) assert.Equal(t, "sign <flags>: create and sign a certificate", signSummary())
} }
@ -39,9 +37,11 @@ func Test_signHelp(t *testing.T) {
" -in-pub string\n"+ " -in-pub string\n"+
" \tOptional (if out-key not set): path to read a previously generated public key\n"+ " \tOptional (if out-key not set): path to read a previously generated public key\n"+
" -ip string\n"+ " -ip string\n"+
" \tRequired: ipv4 address and network in CIDR notation to assign the cert\n"+ " \tDeprecated, see -networks\n"+
" -name string\n"+ " -name string\n"+
" \tRequired: name of the cert, usually a hostname\n"+ " \tRequired: name of the cert, usually a hostname\n"+
" -networks string\n"+
" \tRequired: comma separated list of ip address and network in CIDR notation to assign to this cert\n"+
" -out-crt string\n"+ " -out-crt string\n"+
" \tOptional: path to write the certificate to\n"+ " \tOptional: path to write the certificate to\n"+
" -out-key string\n"+ " -out-key string\n"+
@ -50,7 +50,11 @@ func Test_signHelp(t *testing.T) {
" \tOptional: output a qr code image (png) of the certificate\n"+ " \tOptional: output a qr code image (png) of the certificate\n"+
optionalPkcs11String(" -pkcs11 string\n \tOptional: PKCS#11 URI to an existing private key\n")+ optionalPkcs11String(" -pkcs11 string\n \tOptional: PKCS#11 URI to an existing private key\n")+
" -subnets string\n"+ " -subnets string\n"+
" \tOptional: comma separated list of ipv4 address and network in CIDR notation. Subnets this cert can serve for\n", " \tDeprecated, see -unsafe-networks\n"+
" -unsafe-networks string\n"+
" \tOptional: comma separated list of ip address and network in CIDR notation. Unsafe networks this cert can route for\n"+
" -version uint\n"+
" \tOptional: version of the certificate format to use, the default is to create both v1 and v2 certificates.\n",
ob.String(), ob.String(),
) )
} }
@ -77,20 +81,20 @@ func Test_signCert(t *testing.T) {
// required args // required args
assertHelpError(t, signCert( assertHelpError(t, signCert(
[]string{"-ca-crt", "./nope", "-ca-key", "./nope", "-ip", "1.1.1.1/24", "-out-key", "nope", "-out-crt", "nope"}, ob, eb, nopw, []string{"-version", "1", "-ca-crt", "./nope", "-ca-key", "./nope", "-ip", "1.1.1.1/24", "-out-key", "nope", "-out-crt", "nope"}, ob, eb, nopw,
), "-name is required") ), "-name is required")
assert.Empty(t, ob.String()) assert.Empty(t, ob.String())
assert.Empty(t, eb.String()) assert.Empty(t, eb.String())
assertHelpError(t, signCert( assertHelpError(t, signCert(
[]string{"-ca-crt", "./nope", "-ca-key", "./nope", "-name", "test", "-out-key", "nope", "-out-crt", "nope"}, ob, eb, nopw, []string{"-version", "1", "-ca-crt", "./nope", "-ca-key", "./nope", "-name", "test", "-out-key", "nope", "-out-crt", "nope"}, ob, eb, nopw,
), "-ip is required") ), "-networks is required")
assert.Empty(t, ob.String()) assert.Empty(t, ob.String())
assert.Empty(t, eb.String()) assert.Empty(t, eb.String())
// cannot set -in-pub and -out-key // cannot set -in-pub and -out-key
assertHelpError(t, signCert( assertHelpError(t, signCert(
[]string{"-ca-crt", "./nope", "-ca-key", "./nope", "-name", "test", "-in-pub", "nope", "-ip", "1.1.1.1/24", "-out-crt", "nope", "-out-key", "nope"}, ob, eb, nopw, []string{"-version", "1", "-ca-crt", "./nope", "-ca-key", "./nope", "-name", "test", "-in-pub", "nope", "-ip", "1.1.1.1/24", "-out-crt", "nope", "-out-key", "nope"}, ob, eb, nopw,
), "cannot set both -in-pub and -out-key") ), "cannot set both -in-pub and -out-key")
assert.Empty(t, ob.String()) assert.Empty(t, ob.String())
assert.Empty(t, eb.String()) assert.Empty(t, eb.String())
@ -98,7 +102,7 @@ func Test_signCert(t *testing.T) {
// failed to read key // failed to read key
ob.Reset() ob.Reset()
eb.Reset() eb.Reset()
args := []string{"-ca-crt", "./nope", "-ca-key", "./nope", "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", "nope", "-out-key", "nope", "-duration", "100m"} args := []string{"-version", "1", "-ca-crt", "./nope", "-ca-key", "./nope", "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", "nope", "-out-key", "nope", "-duration", "100m"}
assert.EqualError(t, signCert(args, ob, eb, nopw), "error while reading ca-key: open ./nope: "+NoSuchFileError) assert.EqualError(t, signCert(args, ob, eb, nopw), "error while reading ca-key: open ./nope: "+NoSuchFileError)
// failed to unmarshal key // failed to unmarshal key
@ -108,7 +112,7 @@ func Test_signCert(t *testing.T) {
assert.Nil(t, err) assert.Nil(t, err)
defer os.Remove(caKeyF.Name()) defer os.Remove(caKeyF.Name())
args = []string{"-ca-crt", "./nope", "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", "nope", "-out-key", "nope", "-duration", "100m"} args = []string{"-version", "1", "-ca-crt", "./nope", "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", "nope", "-out-key", "nope", "-duration", "100m"}
assert.EqualError(t, signCert(args, ob, eb, nopw), "error while parsing ca-key: input did not contain a valid PEM encoded block") assert.EqualError(t, signCert(args, ob, eb, nopw), "error while parsing ca-key: input did not contain a valid PEM encoded block")
assert.Empty(t, ob.String()) assert.Empty(t, ob.String())
assert.Empty(t, eb.String()) assert.Empty(t, eb.String())
@ -120,7 +124,7 @@ func Test_signCert(t *testing.T) {
caKeyF.Write(cert.MarshalSigningPrivateKeyToPEM(cert.Curve_CURVE25519, caPriv)) caKeyF.Write(cert.MarshalSigningPrivateKeyToPEM(cert.Curve_CURVE25519, caPriv))
// failed to read cert // failed to read cert
args = []string{"-ca-crt", "./nope", "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", "nope", "-out-key", "nope", "-duration", "100m"} args = []string{"-version", "1", "-ca-crt", "./nope", "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", "nope", "-out-key", "nope", "-duration", "100m"}
assert.EqualError(t, signCert(args, ob, eb, nopw), "error while reading ca-crt: open ./nope: "+NoSuchFileError) assert.EqualError(t, signCert(args, ob, eb, nopw), "error while reading ca-crt: open ./nope: "+NoSuchFileError)
assert.Empty(t, ob.String()) assert.Empty(t, ob.String())
assert.Empty(t, eb.String()) assert.Empty(t, eb.String())
@ -132,7 +136,7 @@ func Test_signCert(t *testing.T) {
assert.Nil(t, err) assert.Nil(t, err)
defer os.Remove(caCrtF.Name()) defer os.Remove(caCrtF.Name())
args = []string{"-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", "nope", "-out-key", "nope", "-duration", "100m"} args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", "nope", "-out-key", "nope", "-duration", "100m"}
assert.EqualError(t, signCert(args, ob, eb, nopw), "error while parsing ca-crt: input did not contain a valid PEM encoded block") assert.EqualError(t, signCert(args, ob, eb, nopw), "error while parsing ca-crt: input did not contain a valid PEM encoded block")
assert.Empty(t, ob.String()) assert.Empty(t, ob.String())
assert.Empty(t, eb.String()) assert.Empty(t, eb.String())
@ -143,7 +147,7 @@ func Test_signCert(t *testing.T) {
caCrtF.Write(b) caCrtF.Write(b)
// failed to read pub // failed to read pub
args = []string{"-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", "nope", "-in-pub", "./nope", "-duration", "100m"} args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", "nope", "-in-pub", "./nope", "-duration", "100m"}
assert.EqualError(t, signCert(args, ob, eb, nopw), "error while reading in-pub: open ./nope: "+NoSuchFileError) assert.EqualError(t, signCert(args, ob, eb, nopw), "error while reading in-pub: open ./nope: "+NoSuchFileError)
assert.Empty(t, ob.String()) assert.Empty(t, ob.String())
assert.Empty(t, eb.String()) assert.Empty(t, eb.String())
@ -155,7 +159,7 @@ func Test_signCert(t *testing.T) {
assert.Nil(t, err) assert.Nil(t, err)
defer os.Remove(inPubF.Name()) defer os.Remove(inPubF.Name())
args = []string{"-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", "nope", "-in-pub", inPubF.Name(), "-duration", "100m"} args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", "nope", "-in-pub", inPubF.Name(), "-duration", "100m"}
assert.EqualError(t, signCert(args, ob, eb, nopw), "error while parsing in-pub: input did not contain a valid PEM encoded block") assert.EqualError(t, signCert(args, ob, eb, nopw), "error while parsing in-pub: input did not contain a valid PEM encoded block")
assert.Empty(t, ob.String()) assert.Empty(t, ob.String())
assert.Empty(t, eb.String()) assert.Empty(t, eb.String())
@ -169,30 +173,37 @@ func Test_signCert(t *testing.T) {
// bad ip cidr // bad ip cidr
ob.Reset() ob.Reset()
eb.Reset() eb.Reset()
args = []string{"-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "a1.1.1.1/24", "-out-crt", "nope", "-out-key", "nope", "-duration", "100m"} args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "a1.1.1.1/24", "-out-crt", "nope", "-out-key", "nope", "-duration", "100m"}
assertHelpError(t, signCert(args, ob, eb, nopw), "invalid ip definition: a1.1.1.1/24") assertHelpError(t, signCert(args, ob, eb, nopw), "invalid -networks definition: a1.1.1.1/24")
assert.Empty(t, ob.String()) assert.Empty(t, ob.String())
assert.Empty(t, eb.String()) assert.Empty(t, eb.String())
ob.Reset() ob.Reset()
eb.Reset() eb.Reset()
args = []string{"-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "100::100/100", "-out-crt", "nope", "-out-key", "nope", "-duration", "100m"} args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "100::100/100", "-out-crt", "nope", "-out-key", "nope", "-duration", "100m"}
assertHelpError(t, signCert(args, ob, eb, nopw), "invalid ip definition: can only be ipv4, have 100::100/100") assertHelpError(t, signCert(args, ob, eb, nopw), "invalid -networks definition: v1 certificates can only have a single ipv4 address")
assert.Empty(t, ob.String())
assert.Empty(t, eb.String())
ob.Reset()
eb.Reset()
args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24,1.1.1.2/24", "-out-crt", "nope", "-out-key", "nope", "-duration", "100m"}
assertHelpError(t, signCert(args, ob, eb, nopw), "invalid -networks definition: v1 certificates can only have a single ipv4 address")
assert.Empty(t, ob.String()) assert.Empty(t, ob.String())
assert.Empty(t, eb.String()) assert.Empty(t, eb.String())
// bad subnet cidr // bad subnet cidr
ob.Reset() ob.Reset()
eb.Reset() eb.Reset()
args = []string{"-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", "nope", "-out-key", "nope", "-duration", "100m", "-subnets", "a"} args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", "nope", "-out-key", "nope", "-duration", "100m", "-subnets", "a"}
assertHelpError(t, signCert(args, ob, eb, nopw), "invalid subnet definition: a") assertHelpError(t, signCert(args, ob, eb, nopw), "invalid -unsafe-networks definition: a")
assert.Empty(t, ob.String()) assert.Empty(t, ob.String())
assert.Empty(t, eb.String()) assert.Empty(t, eb.String())
ob.Reset() ob.Reset()
eb.Reset() eb.Reset()
args = []string{"-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", "nope", "-out-key", "nope", "-duration", "100m", "-subnets", "100::100/100"} args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", "nope", "-out-key", "nope", "-duration", "100m", "-subnets", "100::100/100"}
assertHelpError(t, signCert(args, ob, eb, nopw), "invalid subnet definition: can only be ipv4, have 100::100/100") assertHelpError(t, signCert(args, ob, eb, nopw), "invalid -unsafe-networks definition: v1 certificates can only be ipv4")
assert.Empty(t, ob.String()) assert.Empty(t, ob.String())
assert.Empty(t, eb.String()) assert.Empty(t, eb.String())
@ -205,7 +216,7 @@ func Test_signCert(t *testing.T) {
ob.Reset() ob.Reset()
eb.Reset() eb.Reset()
args = []string{"-ca-crt", caCrtF.Name(), "-ca-key", caKeyF2.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", "nope", "-out-key", "nope", "-duration", "100m", "-subnets", "a"} args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF2.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", "nope", "-out-key", "nope", "-duration", "100m", "-subnets", "a"}
assert.EqualError(t, signCert(args, ob, eb, nopw), "refusing to sign, root certificate does not match private key") assert.EqualError(t, signCert(args, ob, eb, nopw), "refusing to sign, root certificate does not match private key")
assert.Empty(t, ob.String()) assert.Empty(t, ob.String())
assert.Empty(t, eb.String()) assert.Empty(t, eb.String())
@ -213,7 +224,7 @@ func Test_signCert(t *testing.T) {
// failed key write // failed key write
ob.Reset() ob.Reset()
eb.Reset() eb.Reset()
args = []string{"-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", "/do/not/write/pleasecrt", "-out-key", "/do/not/write/pleasekey", "-duration", "100m", "-subnets", "10.1.1.1/32"} args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", "/do/not/write/pleasecrt", "-out-key", "/do/not/write/pleasekey", "-duration", "100m", "-subnets", "10.1.1.1/32"}
assert.EqualError(t, signCert(args, ob, eb, nopw), "error while writing out-key: open /do/not/write/pleasekey: "+NoSuchDirError) assert.EqualError(t, signCert(args, ob, eb, nopw), "error while writing out-key: open /do/not/write/pleasekey: "+NoSuchDirError)
assert.Empty(t, ob.String()) assert.Empty(t, ob.String())
assert.Empty(t, eb.String()) assert.Empty(t, eb.String())
@ -226,7 +237,7 @@ func Test_signCert(t *testing.T) {
// failed cert write // failed cert write
ob.Reset() ob.Reset()
eb.Reset() eb.Reset()
args = []string{"-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", "/do/not/write/pleasecrt", "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32"} args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", "/do/not/write/pleasecrt", "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32"}
assert.EqualError(t, signCert(args, ob, eb, nopw), "error while writing out-crt: open /do/not/write/pleasecrt: "+NoSuchDirError) assert.EqualError(t, signCert(args, ob, eb, nopw), "error while writing out-crt: open /do/not/write/pleasecrt: "+NoSuchDirError)
assert.Empty(t, ob.String()) assert.Empty(t, ob.String())
assert.Empty(t, eb.String()) assert.Empty(t, eb.String())
@ -240,7 +251,7 @@ func Test_signCert(t *testing.T) {
// test proper cert with removed empty groups and subnets // test proper cert with removed empty groups and subnets
ob.Reset() ob.Reset()
eb.Reset() eb.Reset()
args = []string{"-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32, , 10.2.2.2/32 , , ,, 10.5.5.5/32", "-groups", "1,, 2 , ,,,3,4,5"} args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32, , 10.2.2.2/32 , , ,, 10.5.5.5/32", "-groups", "1,, 2 , ,,,3,4,5"}
assert.Nil(t, signCert(args, ob, eb, nopw)) assert.Nil(t, signCert(args, ob, eb, nopw))
assert.Empty(t, ob.String()) assert.Empty(t, ob.String())
assert.Empty(t, eb.String()) assert.Empty(t, eb.String())
@ -283,7 +294,7 @@ func Test_signCert(t *testing.T) {
os.Remove(crtF.Name()) os.Remove(crtF.Name())
ob.Reset() ob.Reset()
eb.Reset() eb.Reset()
args = []string{"-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-in-pub", inPubF.Name(), "-duration", "100m", "-groups", "1"} args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-in-pub", inPubF.Name(), "-duration", "100m", "-groups", "1"}
assert.Nil(t, signCert(args, ob, eb, nopw)) assert.Nil(t, signCert(args, ob, eb, nopw))
assert.Empty(t, ob.String()) assert.Empty(t, ob.String())
assert.Empty(t, eb.String()) assert.Empty(t, eb.String())
@ -300,7 +311,7 @@ func Test_signCert(t *testing.T) {
eb.Reset() eb.Reset()
os.Remove(keyF.Name()) os.Remove(keyF.Name())
os.Remove(crtF.Name()) os.Remove(crtF.Name())
args = []string{"-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "1000m", "-subnets", "10.1.1.1/32, , 10.2.2.2/32 , , ,, 10.5.5.5/32", "-groups", "1,, 2 , ,,,3,4,5"} args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "1000m", "-subnets", "10.1.1.1/32, , 10.2.2.2/32 , , ,, 10.5.5.5/32", "-groups", "1,, 2 , ,,,3,4,5"}
assert.EqualError(t, signCert(args, ob, eb, nopw), "error while signing: certificate expires after signing certificate") assert.EqualError(t, signCert(args, ob, eb, nopw), "error while signing: certificate expires after signing certificate")
assert.Empty(t, ob.String()) assert.Empty(t, ob.String())
assert.Empty(t, eb.String()) assert.Empty(t, eb.String())
@ -308,14 +319,14 @@ func Test_signCert(t *testing.T) {
// create valid cert/key for overwrite tests // create valid cert/key for overwrite tests
os.Remove(keyF.Name()) os.Remove(keyF.Name())
os.Remove(crtF.Name()) os.Remove(crtF.Name())
args = []string{"-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32, , 10.2.2.2/32 , , ,, 10.5.5.5/32", "-groups", "1,, 2 , ,,,3,4,5"} args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32, , 10.2.2.2/32 , , ,, 10.5.5.5/32", "-groups", "1,, 2 , ,,,3,4,5"}
assert.Nil(t, signCert(args, ob, eb, nopw)) assert.Nil(t, signCert(args, ob, eb, nopw))
// test that we won't overwrite existing key file // test that we won't overwrite existing key file
os.Remove(crtF.Name()) os.Remove(crtF.Name())
ob.Reset() ob.Reset()
eb.Reset() eb.Reset()
args = []string{"-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32, , 10.2.2.2/32 , , ,, 10.5.5.5/32", "-groups", "1,, 2 , ,,,3,4,5"} args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32, , 10.2.2.2/32 , , ,, 10.5.5.5/32", "-groups", "1,, 2 , ,,,3,4,5"}
assert.EqualError(t, signCert(args, ob, eb, nopw), "refusing to overwrite existing key: "+keyF.Name()) assert.EqualError(t, signCert(args, ob, eb, nopw), "refusing to overwrite existing key: "+keyF.Name())
assert.Empty(t, ob.String()) assert.Empty(t, ob.String())
assert.Empty(t, eb.String()) assert.Empty(t, eb.String())
@ -323,14 +334,14 @@ func Test_signCert(t *testing.T) {
// create valid cert/key for overwrite tests // create valid cert/key for overwrite tests
os.Remove(keyF.Name()) os.Remove(keyF.Name())
os.Remove(crtF.Name()) os.Remove(crtF.Name())
args = []string{"-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32, , 10.2.2.2/32 , , ,, 10.5.5.5/32", "-groups", "1,, 2 , ,,,3,4,5"} args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32, , 10.2.2.2/32 , , ,, 10.5.5.5/32", "-groups", "1,, 2 , ,,,3,4,5"}
assert.Nil(t, signCert(args, ob, eb, nopw)) assert.Nil(t, signCert(args, ob, eb, nopw))
// test that we won't overwrite existing certificate file // test that we won't overwrite existing certificate file
os.Remove(keyF.Name()) os.Remove(keyF.Name())
ob.Reset() ob.Reset()
eb.Reset() eb.Reset()
args = []string{"-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32, , 10.2.2.2/32 , , ,, 10.5.5.5/32", "-groups", "1,, 2 , ,,,3,4,5"} args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32, , 10.2.2.2/32 , , ,, 10.5.5.5/32", "-groups", "1,, 2 , ,,,3,4,5"}
assert.EqualError(t, signCert(args, ob, eb, nopw), "refusing to overwrite existing cert: "+crtF.Name()) assert.EqualError(t, signCert(args, ob, eb, nopw), "refusing to overwrite existing cert: "+crtF.Name())
assert.Empty(t, ob.String()) assert.Empty(t, ob.String())
assert.Empty(t, eb.String()) assert.Empty(t, eb.String())
@ -362,7 +373,7 @@ func Test_signCert(t *testing.T) {
caCrtF.Write(b) caCrtF.Write(b)
// test with the proper password // test with the proper password
args = []string{"-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32, , 10.2.2.2/32 , , ,, 10.5.5.5/32", "-groups", "1,, 2 , ,,,3,4,5"} args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32, , 10.2.2.2/32 , , ,, 10.5.5.5/32", "-groups", "1,, 2 , ,,,3,4,5"}
assert.Nil(t, signCert(args, ob, eb, testpw)) assert.Nil(t, signCert(args, ob, eb, testpw))
assert.Equal(t, "Enter passphrase: ", ob.String()) assert.Equal(t, "Enter passphrase: ", ob.String())
assert.Empty(t, eb.String()) assert.Empty(t, eb.String())
@ -372,7 +383,7 @@ func Test_signCert(t *testing.T) {
eb.Reset() eb.Reset()
testpw.password = []byte("invalid password") testpw.password = []byte("invalid password")
args = []string{"-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32, , 10.2.2.2/32 , , ,, 10.5.5.5/32", "-groups", "1,, 2 , ,,,3,4,5"} args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32, , 10.2.2.2/32 , , ,, 10.5.5.5/32", "-groups", "1,, 2 , ,,,3,4,5"}
assert.Error(t, signCert(args, ob, eb, testpw)) assert.Error(t, signCert(args, ob, eb, testpw))
assert.Equal(t, "Enter passphrase: ", ob.String()) assert.Equal(t, "Enter passphrase: ", ob.String())
assert.Empty(t, eb.String()) assert.Empty(t, eb.String())
@ -381,7 +392,7 @@ func Test_signCert(t *testing.T) {
ob.Reset() ob.Reset()
eb.Reset() eb.Reset()
args = []string{"-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32, , 10.2.2.2/32 , , ,, 10.5.5.5/32", "-groups", "1,, 2 , ,,,3,4,5"} args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32, , 10.2.2.2/32 , , ,, 10.5.5.5/32", "-groups", "1,, 2 , ,,,3,4,5"}
assert.Error(t, signCert(args, ob, eb, nopw)) assert.Error(t, signCert(args, ob, eb, nopw))
// normally the user hitting enter on the prompt would add newlines between these // normally the user hitting enter on the prompt would add newlines between these
assert.Equal(t, "Enter passphrase: Enter passphrase: Enter passphrase: Enter passphrase: Enter passphrase: ", ob.String()) assert.Equal(t, "Enter passphrase: Enter passphrase: Enter passphrase: Enter passphrase: Enter passphrase: ", ob.String())
@ -391,7 +402,7 @@ func Test_signCert(t *testing.T) {
ob.Reset() ob.Reset()
eb.Reset() eb.Reset()
args = []string{"-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32, , 10.2.2.2/32 , , ,, 10.5.5.5/32", "-groups", "1,, 2 , ,,,3,4,5"} args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32, , 10.2.2.2/32 , , ,, 10.5.5.5/32", "-groups", "1,, 2 , ,,,3,4,5"}
assert.Error(t, signCert(args, ob, eb, errpw)) assert.Error(t, signCert(args, ob, eb, errpw))
assert.Equal(t, "Enter passphrase: ", ob.String()) assert.Equal(t, "Enter passphrase: ", ob.String())
assert.Empty(t, eb.String()) assert.Empty(t, eb.String())

View File

@ -1,6 +1,7 @@
package main package main
import ( import (
"errors"
"flag" "flag"
"fmt" "fmt"
"io" "io"
@ -41,14 +42,14 @@ func verify(args []string, out io.Writer, errOut io.Writer) error {
rawCACert, err := os.ReadFile(*vf.caPath) rawCACert, err := os.ReadFile(*vf.caPath)
if err != nil { if err != nil {
return fmt.Errorf("error while reading ca: %s", err) return fmt.Errorf("error while reading ca: %w", err)
} }
caPool := cert.NewCAPool() caPool := cert.NewCAPool()
for { for {
rawCACert, err = caPool.AddCAFromPEM(rawCACert) rawCACert, err = caPool.AddCAFromPEM(rawCACert)
if err != nil { if err != nil {
return fmt.Errorf("error while adding ca cert to pool: %s", err) return fmt.Errorf("error while adding ca cert to pool: %w", err)
} }
if rawCACert == nil || len(rawCACert) == 0 || strings.TrimSpace(string(rawCACert)) == "" { if rawCACert == nil || len(rawCACert) == 0 || strings.TrimSpace(string(rawCACert)) == "" {
@ -58,20 +59,30 @@ func verify(args []string, out io.Writer, errOut io.Writer) error {
rawCert, err := os.ReadFile(*vf.certPath) rawCert, err := os.ReadFile(*vf.certPath)
if err != nil { if err != nil {
return fmt.Errorf("unable to read crt; %s", err) return fmt.Errorf("unable to read crt: %w", err)
}
var errs []error
for {
if len(rawCert) == 0 {
break
}
c, extra, err := cert.UnmarshalCertificateFromPEM(rawCert)
if err != nil {
return fmt.Errorf("error while parsing crt: %w", err)
}
rawCert = extra
_, err = caPool.VerifyCertificate(time.Now(), c)
if err != nil {
switch {
case errors.Is(err, cert.ErrCaNotFound):
errs = append(errs, fmt.Errorf("error while verifying certificate v%d %s with issuer %s: %w", c.Version(), c.Name(), c.Issuer(), err))
default:
errs = append(errs, fmt.Errorf("error while verifying certificate %+v: %w", c, err))
}
}
} }
c, _, err := cert.UnmarshalCertificateFromPEM(rawCert) return errors.Join(errs...)
if err != nil {
return fmt.Errorf("error while parsing crt: %s", err)
}
_, err = caPool.VerifyCertificate(time.Now(), c)
if err != nil {
return err
}
return nil
} }
func verifySummary() string { func verifySummary() string {
@ -80,7 +91,7 @@ func verifySummary() string {
func verifyHelp(out io.Writer) { func verifyHelp(out io.Writer) {
vf := newVerifyFlags() vf := newVerifyFlags()
out.Write([]byte("Usage of " + os.Args[0] + " " + verifySummary() + "\n")) _, _ = out.Write([]byte("Usage of " + os.Args[0] + " " + verifySummary() + "\n"))
vf.set.SetOutput(out) vf.set.SetOutput(out)
vf.set.PrintDefaults() vf.set.PrintDefaults()
} }

View File

@ -3,10 +3,12 @@ package main
import ( import (
"bytes" "bytes"
"crypto/rand" "crypto/rand"
"errors"
"os" "os"
"testing" "testing"
"time" "time"
"github.com/slackhq/nebula/cert"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"golang.org/x/crypto/ed25519" "golang.org/x/crypto/ed25519"
) )
@ -76,7 +78,7 @@ func Test_verify(t *testing.T) {
err = verify([]string{"-ca", caFile.Name(), "-crt", "does_not_exist"}, ob, eb) err = verify([]string{"-ca", caFile.Name(), "-crt", "does_not_exist"}, ob, eb)
assert.Equal(t, "", ob.String()) assert.Equal(t, "", ob.String())
assert.Equal(t, "", eb.String()) assert.Equal(t, "", eb.String())
assert.EqualError(t, err, "unable to read crt; open does_not_exist: "+NoSuchFileError) assert.EqualError(t, err, "unable to read crt: open does_not_exist: "+NoSuchFileError)
// invalid crt at path // invalid crt at path
ob.Reset() ob.Reset()
@ -106,7 +108,7 @@ func Test_verify(t *testing.T) {
err = verify([]string{"-ca", caFile.Name(), "-crt", certFile.Name()}, ob, eb) err = verify([]string{"-ca", caFile.Name(), "-crt", certFile.Name()}, ob, eb)
assert.Equal(t, "", ob.String()) assert.Equal(t, "", ob.String())
assert.Equal(t, "", eb.String()) assert.Equal(t, "", eb.String())
assert.EqualError(t, err, "certificate signature did not match") assert.True(t, errors.Is(err, cert.ErrSignatureMismatch))
// verified cert at path // verified cert at path
crt, _ = NewTestCert(ca, caPriv, "test-cert", time.Now().Add(time.Hour*-1), time.Now().Add(time.Hour), nil, nil, nil) crt, _ = NewTestCert(ca, caPriv, "test-cert", time.Now().Add(time.Hour*-1), time.Now().Add(time.Hour), nil, nil, nil)

View File

@ -38,9 +38,6 @@ func TestConfig_Load(t *testing.T) {
"new": "hi", "new": "hi",
} }
assert.Equal(t, expected, c.Settings) assert.Equal(t, expected, c.Settings)
//TODO: test symlinked file
//TODO: test symlinked directory
} }
func TestConfig_Get(t *testing.T) { func TestConfig_Get(t *testing.T) {

View File

@ -183,7 +183,7 @@ func (n *connectionManager) doTrafficCheck(localIndex uint32, p, nb, out []byte,
case deleteTunnel: case deleteTunnel:
if n.hostMap.DeleteHostInfo(hostinfo) { if n.hostMap.DeleteHostInfo(hostinfo) {
// Only clearing the lighthouse cache if this is the last hostinfo for this vpn ip in the hostmap // Only clearing the lighthouse cache if this is the last hostinfo for this vpn ip in the hostmap
n.intf.lightHouse.DeleteVpnIp(hostinfo.vpnIp) n.intf.lightHouse.DeleteVpnAddrs(hostinfo.vpnAddrs)
} }
case closeTunnel: case closeTunnel:
@ -221,7 +221,7 @@ func (n *connectionManager) migrateRelayUsed(oldhostinfo, newhostinfo *HostInfo)
relayFor := oldhostinfo.relayState.CopyAllRelayFor() relayFor := oldhostinfo.relayState.CopyAllRelayFor()
for _, r := range relayFor { for _, r := range relayFor {
existing, ok := newhostinfo.relayState.QueryRelayForByIp(r.PeerIp) existing, ok := newhostinfo.relayState.QueryRelayForByIp(r.PeerAddr)
var index uint32 var index uint32
var relayFrom netip.Addr var relayFrom netip.Addr
@ -235,11 +235,11 @@ func (n *connectionManager) migrateRelayUsed(oldhostinfo, newhostinfo *HostInfo)
index = existing.LocalIndex index = existing.LocalIndex
switch r.Type { switch r.Type {
case TerminalType: case TerminalType:
relayFrom = n.intf.myVpnNet.Addr() relayFrom = n.intf.myVpnAddrs[0]
relayTo = existing.PeerIp relayTo = existing.PeerAddr
case ForwardingType: case ForwardingType:
relayFrom = existing.PeerIp relayFrom = existing.PeerAddr
relayTo = newhostinfo.vpnIp relayTo = newhostinfo.vpnAddrs[0]
default: default:
// should never happen // should never happen
} }
@ -253,45 +253,64 @@ func (n *connectionManager) migrateRelayUsed(oldhostinfo, newhostinfo *HostInfo)
n.relayUsedLock.RUnlock() n.relayUsedLock.RUnlock()
// The relay doesn't exist at all; create some relay state and send the request. // The relay doesn't exist at all; create some relay state and send the request.
var err error var err error
index, err = AddRelay(n.l, newhostinfo, n.hostMap, r.PeerIp, nil, r.Type, Requested) index, err = AddRelay(n.l, newhostinfo, n.hostMap, r.PeerAddr, nil, r.Type, Requested)
if err != nil { if err != nil {
n.l.WithError(err).Error("failed to migrate relay to new hostinfo") n.l.WithError(err).Error("failed to migrate relay to new hostinfo")
continue continue
} }
switch r.Type { switch r.Type {
case TerminalType: case TerminalType:
relayFrom = n.intf.myVpnNet.Addr() relayFrom = n.intf.myVpnAddrs[0]
relayTo = r.PeerIp relayTo = r.PeerAddr
case ForwardingType: case ForwardingType:
relayFrom = r.PeerIp relayFrom = r.PeerAddr
relayTo = newhostinfo.vpnIp relayTo = newhostinfo.vpnAddrs[0]
default: default:
// should never happen // should never happen
} }
} }
//TODO: IPV6-WORK
relayFromB := relayFrom.As4()
relayToB := relayTo.As4()
// Send a CreateRelayRequest to the peer. // Send a CreateRelayRequest to the peer.
req := NebulaControl{ req := NebulaControl{
Type: NebulaControl_CreateRelayRequest, Type: NebulaControl_CreateRelayRequest,
InitiatorRelayIndex: index, InitiatorRelayIndex: index,
RelayFromIp: binary.BigEndian.Uint32(relayFromB[:]),
RelayToIp: binary.BigEndian.Uint32(relayToB[:]),
} }
switch newhostinfo.GetCert().Certificate.Version() {
case cert.Version1:
if !relayFrom.Is4() {
n.l.Error("can not migrate v1 relay with a v6 network because the relay is not running a current nebula version")
continue
}
if !relayTo.Is4() {
n.l.Error("can not migrate v1 relay with a v6 remote network because the relay is not running a current nebula version")
continue
}
b := relayFrom.As4()
req.OldRelayFromAddr = binary.BigEndian.Uint32(b[:])
b = relayTo.As4()
req.OldRelayToAddr = binary.BigEndian.Uint32(b[:])
case cert.Version2:
req.RelayFromAddr = netAddrToProtoAddr(relayFrom)
req.RelayToAddr = netAddrToProtoAddr(relayTo)
default:
newhostinfo.logger(n.l).Error("Unknown certificate version found while attempting to migrate relay")
continue
}
msg, err := req.Marshal() msg, err := req.Marshal()
if err != nil { if err != nil {
n.l.WithError(err).Error("failed to marshal Control message to migrate relay") n.l.WithError(err).Error("failed to marshal Control message to migrate relay")
} else { } else {
n.intf.SendMessageToHostInfo(header.Control, 0, newhostinfo, msg, make([]byte, 12), make([]byte, mtu)) n.intf.SendMessageToHostInfo(header.Control, 0, newhostinfo, msg, make([]byte, 12), make([]byte, mtu))
n.l.WithFields(logrus.Fields{ n.l.WithFields(logrus.Fields{
"relayFrom": req.RelayFromIp, "relayFrom": req.RelayFromAddr,
"relayTo": req.RelayToIp, "relayTo": req.RelayToAddr,
"initiatorRelayIndex": req.InitiatorRelayIndex, "initiatorRelayIndex": req.InitiatorRelayIndex,
"responderRelayIndex": req.ResponderRelayIndex, "responderRelayIndex": req.ResponderRelayIndex,
"vpnIp": newhostinfo.vpnIp}). "vpnAddrs": newhostinfo.vpnAddrs}).
Info("send CreateRelayRequest") Info("send CreateRelayRequest")
} }
} }
@ -313,7 +332,7 @@ func (n *connectionManager) makeTrafficDecision(localIndex uint32, now time.Time
return closeTunnel, hostinfo, nil return closeTunnel, hostinfo, nil
} }
primary := n.hostMap.Hosts[hostinfo.vpnIp] primary := n.hostMap.Hosts[hostinfo.vpnAddrs[0]]
mainHostInfo := true mainHostInfo := true
if primary != nil && primary != hostinfo { if primary != nil && primary != hostinfo {
mainHostInfo = false mainHostInfo = false
@ -407,21 +426,24 @@ func (n *connectionManager) shouldSwapPrimary(current, primary *HostInfo) bool {
// If we are here then we have multiple tunnels for a host pair and neither side believes the same tunnel is primary. // If we are here then we have multiple tunnels for a host pair and neither side believes the same tunnel is primary.
// Let's sort this out. // Let's sort this out.
if current.vpnIp.Compare(n.intf.myVpnNet.Addr()) < 0 { // Only one side should swap because if both swap then we may never resolve to a single tunnel.
// Only one side should flip primary because if both flip then we may never resolve to a single tunnel. // vpn addr is static across all tunnels for this host pair so lets
// vpn ip is static across all tunnels for this host pair so lets use that to determine who is flipping. // use that to determine if we should consider swapping.
// The remotes vpn ip is lower than mine. I will not flip. if current.vpnAddrs[0].Compare(n.intf.myVpnAddrs[0]) < 0 {
// Their primary vpn addr is less than mine. Do not swap.
return false return false
} }
certState := n.intf.pki.GetCertState() crt := n.intf.pki.getCertState().getCertificate(current.ConnectionState.myCert.Version())
return bytes.Equal(current.ConnectionState.myCert.Signature(), certState.Certificate.Signature()) // If this tunnel is using the latest certificate then we should swap it to primary for a bit and see if things
// settle down.
return bytes.Equal(current.ConnectionState.myCert.Signature(), crt.Signature())
} }
func (n *connectionManager) swapPrimary(current, primary *HostInfo) { func (n *connectionManager) swapPrimary(current, primary *HostInfo) {
n.hostMap.Lock() n.hostMap.Lock()
// Make sure the primary is still the same after the write lock. This avoids a race with a rehandshake. // Make sure the primary is still the same after the write lock. This avoids a race with a rehandshake.
if n.hostMap.Hosts[current.vpnIp] == primary { if n.hostMap.Hosts[current.vpnAddrs[0]] == primary {
n.hostMap.unlockedMakePrimary(current) n.hostMap.unlockedMakePrimary(current)
} }
n.hostMap.Unlock() n.hostMap.Unlock()
@ -473,14 +495,17 @@ func (n *connectionManager) sendPunch(hostinfo *HostInfo) {
} }
func (n *connectionManager) tryRehandshake(hostinfo *HostInfo) { func (n *connectionManager) tryRehandshake(hostinfo *HostInfo) {
certState := n.intf.pki.GetCertState() cs := n.intf.pki.getCertState()
if bytes.Equal(hostinfo.ConnectionState.myCert.Signature(), certState.Certificate.Signature()) { curCrt := hostinfo.ConnectionState.myCert
myCrt := cs.getCertificate(curCrt.Version())
if curCrt.Version() >= cs.defaultVersion && bytes.Equal(curCrt.Signature(), myCrt.Signature()) == true {
// The current tunnel is using the latest certificate and version, no need to rehandshake.
return return
} }
n.l.WithField("vpnIp", hostinfo.vpnIp). n.l.WithField("vpnAddrs", hostinfo.vpnAddrs).
WithField("reason", "local certificate is not current"). WithField("reason", "local certificate is not current").
Info("Re-handshaking with remote") Info("Re-handshaking with remote")
n.intf.handshakeManager.StartHandshake(hostinfo.vpnIp, nil) n.intf.handshakeManager.StartHandshake(hostinfo.vpnAddrs[0], nil)
} }

View File

@ -34,20 +34,19 @@ func newTestLighthouse() *LightHouse {
func Test_NewConnectionManagerTest(t *testing.T) { func Test_NewConnectionManagerTest(t *testing.T) {
l := test.NewLogger() l := test.NewLogger()
//_, tuncidr, _ := net.ParseCIDR("1.1.1.1/24") //_, tuncidr, _ := net.ParseCIDR("1.1.1.1/24")
vpncidr := netip.MustParsePrefix("172.1.1.1/24")
localrange := netip.MustParsePrefix("10.1.1.1/24") localrange := netip.MustParsePrefix("10.1.1.1/24")
vpnIp := netip.MustParseAddr("172.1.1.2") vpnIp := netip.MustParseAddr("172.1.1.2")
preferredRanges := []netip.Prefix{localrange} preferredRanges := []netip.Prefix{localrange}
// Very incomplete mock objects // Very incomplete mock objects
hostMap := newHostMap(l, vpncidr) hostMap := newHostMap(l)
hostMap.preferredRanges.Store(&preferredRanges) hostMap.preferredRanges.Store(&preferredRanges)
cs := &CertState{ cs := &CertState{
RawCertificate: []byte{}, defaultVersion: cert.Version1,
PrivateKey: []byte{}, privateKey: []byte{},
Certificate: &dummyCert{}, v1Cert: &dummyCert{version: cert.Version1},
RawCertificateNoKey: []byte{}, v1HandshakeBytes: []byte{},
} }
lh := newTestLighthouse() lh := newTestLighthouse()
@ -74,12 +73,12 @@ func Test_NewConnectionManagerTest(t *testing.T) {
// Add an ip we have established a connection w/ to hostmap // Add an ip we have established a connection w/ to hostmap
hostinfo := &HostInfo{ hostinfo := &HostInfo{
vpnIp: vpnIp, vpnAddrs: []netip.Addr{vpnIp},
localIndexId: 1099, localIndexId: 1099,
remoteIndexId: 9901, remoteIndexId: 9901,
} }
hostinfo.ConnectionState = &ConnectionState{ hostinfo.ConnectionState = &ConnectionState{
myCert: &dummyCert{}, myCert: &dummyCert{version: cert.Version1},
H: &noise.HandshakeState{}, H: &noise.HandshakeState{},
} }
nc.hostMap.unlockedAddHostInfo(hostinfo, ifce) nc.hostMap.unlockedAddHostInfo(hostinfo, ifce)
@ -88,7 +87,7 @@ func Test_NewConnectionManagerTest(t *testing.T) {
nc.Out(hostinfo.localIndexId) nc.Out(hostinfo.localIndexId)
nc.In(hostinfo.localIndexId) nc.In(hostinfo.localIndexId)
assert.NotContains(t, nc.pendingDeletion, hostinfo.localIndexId) assert.NotContains(t, nc.pendingDeletion, hostinfo.localIndexId)
assert.Contains(t, nc.hostMap.Hosts, hostinfo.vpnIp) assert.Contains(t, nc.hostMap.Hosts, hostinfo.vpnAddrs[0])
assert.Contains(t, nc.hostMap.Indexes, hostinfo.localIndexId) assert.Contains(t, nc.hostMap.Indexes, hostinfo.localIndexId)
assert.Contains(t, nc.out, hostinfo.localIndexId) assert.Contains(t, nc.out, hostinfo.localIndexId)
@ -105,32 +104,31 @@ func Test_NewConnectionManagerTest(t *testing.T) {
assert.NotContains(t, nc.out, hostinfo.localIndexId) assert.NotContains(t, nc.out, hostinfo.localIndexId)
assert.NotContains(t, nc.in, hostinfo.localIndexId) assert.NotContains(t, nc.in, hostinfo.localIndexId)
assert.Contains(t, nc.hostMap.Indexes, hostinfo.localIndexId) assert.Contains(t, nc.hostMap.Indexes, hostinfo.localIndexId)
assert.Contains(t, nc.hostMap.Hosts, hostinfo.vpnIp) assert.Contains(t, nc.hostMap.Hosts, hostinfo.vpnAddrs[0])
// Do a final traffic check tick, the host should now be removed // Do a final traffic check tick, the host should now be removed
nc.doTrafficCheck(hostinfo.localIndexId, p, nb, out, time.Now()) nc.doTrafficCheck(hostinfo.localIndexId, p, nb, out, time.Now())
assert.NotContains(t, nc.pendingDeletion, hostinfo.localIndexId) assert.NotContains(t, nc.pendingDeletion, hostinfo.localIndexId)
assert.NotContains(t, nc.hostMap.Hosts, hostinfo.vpnIp) assert.NotContains(t, nc.hostMap.Hosts, hostinfo.vpnAddrs[0])
assert.NotContains(t, nc.hostMap.Indexes, hostinfo.localIndexId) assert.NotContains(t, nc.hostMap.Indexes, hostinfo.localIndexId)
} }
func Test_NewConnectionManagerTest2(t *testing.T) { func Test_NewConnectionManagerTest2(t *testing.T) {
l := test.NewLogger() l := test.NewLogger()
//_, tuncidr, _ := net.ParseCIDR("1.1.1.1/24") //_, tuncidr, _ := net.ParseCIDR("1.1.1.1/24")
vpncidr := netip.MustParsePrefix("172.1.1.1/24")
localrange := netip.MustParsePrefix("10.1.1.1/24") localrange := netip.MustParsePrefix("10.1.1.1/24")
vpnIp := netip.MustParseAddr("172.1.1.2") vpnIp := netip.MustParseAddr("172.1.1.2")
preferredRanges := []netip.Prefix{localrange} preferredRanges := []netip.Prefix{localrange}
// Very incomplete mock objects // Very incomplete mock objects
hostMap := newHostMap(l, vpncidr) hostMap := newHostMap(l)
hostMap.preferredRanges.Store(&preferredRanges) hostMap.preferredRanges.Store(&preferredRanges)
cs := &CertState{ cs := &CertState{
RawCertificate: []byte{}, defaultVersion: cert.Version1,
PrivateKey: []byte{}, privateKey: []byte{},
Certificate: &dummyCert{}, v1Cert: &dummyCert{version: cert.Version1},
RawCertificateNoKey: []byte{}, v1HandshakeBytes: []byte{},
} }
lh := newTestLighthouse() lh := newTestLighthouse()
@ -157,12 +155,12 @@ func Test_NewConnectionManagerTest2(t *testing.T) {
// Add an ip we have established a connection w/ to hostmap // Add an ip we have established a connection w/ to hostmap
hostinfo := &HostInfo{ hostinfo := &HostInfo{
vpnIp: vpnIp, vpnAddrs: []netip.Addr{vpnIp},
localIndexId: 1099, localIndexId: 1099,
remoteIndexId: 9901, remoteIndexId: 9901,
} }
hostinfo.ConnectionState = &ConnectionState{ hostinfo.ConnectionState = &ConnectionState{
myCert: &dummyCert{}, myCert: &dummyCert{version: cert.Version1},
H: &noise.HandshakeState{}, H: &noise.HandshakeState{},
} }
nc.hostMap.unlockedAddHostInfo(hostinfo, ifce) nc.hostMap.unlockedAddHostInfo(hostinfo, ifce)
@ -170,8 +168,8 @@ func Test_NewConnectionManagerTest2(t *testing.T) {
// We saw traffic out to vpnIp // We saw traffic out to vpnIp
nc.Out(hostinfo.localIndexId) nc.Out(hostinfo.localIndexId)
nc.In(hostinfo.localIndexId) nc.In(hostinfo.localIndexId)
assert.NotContains(t, nc.pendingDeletion, hostinfo.vpnIp) assert.NotContains(t, nc.pendingDeletion, hostinfo.vpnAddrs[0])
assert.Contains(t, nc.hostMap.Hosts, hostinfo.vpnIp) assert.Contains(t, nc.hostMap.Hosts, hostinfo.vpnAddrs[0])
assert.Contains(t, nc.hostMap.Indexes, hostinfo.localIndexId) assert.Contains(t, nc.hostMap.Indexes, hostinfo.localIndexId)
// Do a traffic check tick, should not be pending deletion but should not have any in/out packets recorded // Do a traffic check tick, should not be pending deletion but should not have any in/out packets recorded
@ -187,7 +185,7 @@ func Test_NewConnectionManagerTest2(t *testing.T) {
assert.NotContains(t, nc.out, hostinfo.localIndexId) assert.NotContains(t, nc.out, hostinfo.localIndexId)
assert.NotContains(t, nc.in, hostinfo.localIndexId) assert.NotContains(t, nc.in, hostinfo.localIndexId)
assert.Contains(t, nc.hostMap.Indexes, hostinfo.localIndexId) assert.Contains(t, nc.hostMap.Indexes, hostinfo.localIndexId)
assert.Contains(t, nc.hostMap.Hosts, hostinfo.vpnIp) assert.Contains(t, nc.hostMap.Hosts, hostinfo.vpnAddrs[0])
// We saw traffic, should no longer be pending deletion // We saw traffic, should no longer be pending deletion
nc.In(hostinfo.localIndexId) nc.In(hostinfo.localIndexId)
@ -196,7 +194,7 @@ func Test_NewConnectionManagerTest2(t *testing.T) {
assert.NotContains(t, nc.out, hostinfo.localIndexId) assert.NotContains(t, nc.out, hostinfo.localIndexId)
assert.NotContains(t, nc.in, hostinfo.localIndexId) assert.NotContains(t, nc.in, hostinfo.localIndexId)
assert.Contains(t, nc.hostMap.Indexes, hostinfo.localIndexId) assert.Contains(t, nc.hostMap.Indexes, hostinfo.localIndexId)
assert.Contains(t, nc.hostMap.Hosts, hostinfo.vpnIp) assert.Contains(t, nc.hostMap.Hosts, hostinfo.vpnAddrs[0])
} }
// Check if we can disconnect the peer. // Check if we can disconnect the peer.
@ -210,7 +208,7 @@ func Test_NewConnectionManagerTest_DisconnectInvalid(t *testing.T) {
localrange := netip.MustParsePrefix("10.1.1.1/24") localrange := netip.MustParsePrefix("10.1.1.1/24")
vpnIp := netip.MustParseAddr("172.1.1.2") vpnIp := netip.MustParseAddr("172.1.1.2")
preferredRanges := []netip.Prefix{localrange} preferredRanges := []netip.Prefix{localrange}
hostMap := newHostMap(l, vpncidr) hostMap := newHostMap(l)
hostMap.preferredRanges.Store(&preferredRanges) hostMap.preferredRanges.Store(&preferredRanges)
// Generate keys for CA and peer's cert. // Generate keys for CA and peer's cert.
@ -244,10 +242,9 @@ func Test_NewConnectionManagerTest_DisconnectInvalid(t *testing.T) {
cachedPeerCert, err := ncp.VerifyCertificate(now.Add(time.Second), peerCert) cachedPeerCert, err := ncp.VerifyCertificate(now.Add(time.Second), peerCert)
cs := &CertState{ cs := &CertState{
RawCertificate: []byte{}, privateKey: []byte{},
PrivateKey: []byte{}, v1Cert: &dummyCert{},
Certificate: &dummyCert{}, v1HandshakeBytes: []byte{},
RawCertificateNoKey: []byte{},
} }
lh := newTestLighthouse() lh := newTestLighthouse()
@ -273,7 +270,7 @@ func Test_NewConnectionManagerTest_DisconnectInvalid(t *testing.T) {
ifce.connectionManager = nc ifce.connectionManager = nc
hostinfo := &HostInfo{ hostinfo := &HostInfo{
vpnIp: vpnIp, vpnAddrs: []netip.Addr{vpnIp},
ConnectionState: &ConnectionState{ ConnectionState: &ConnectionState{
myCert: &dummyCert{}, myCert: &dummyCert{},
peerCert: cachedPeerCert, peerCert: cachedPeerCert,

View File

@ -3,6 +3,7 @@ package nebula
import ( import (
"crypto/rand" "crypto/rand"
"encoding/json" "encoding/json"
"fmt"
"sync" "sync"
"sync/atomic" "sync/atomic"
@ -26,46 +27,46 @@ type ConnectionState struct {
writeLock sync.Mutex writeLock sync.Mutex
} }
func NewConnectionState(l *logrus.Logger, cipher string, certState *CertState, initiator bool, pattern noise.HandshakePattern, psk []byte, pskStage int) *ConnectionState { func NewConnectionState(l *logrus.Logger, cs *CertState, crt cert.Certificate, initiator bool, pattern noise.HandshakePattern) (*ConnectionState, error) {
var dhFunc noise.DHFunc var dhFunc noise.DHFunc
switch certState.Certificate.Curve() { switch crt.Curve() {
case cert.Curve_CURVE25519: case cert.Curve_CURVE25519:
dhFunc = noise.DH25519 dhFunc = noise.DH25519
case cert.Curve_P256: case cert.Curve_P256:
if certState.pkcs11Backed { if cs.pkcs11Backed {
dhFunc = noiseutil.DHP256PKCS11 dhFunc = noiseutil.DHP256PKCS11
} else { } else {
dhFunc = noiseutil.DHP256 dhFunc = noiseutil.DHP256
} }
default: default:
l.Errorf("invalid curve: %s", certState.Certificate.Curve()) return nil, fmt.Errorf("invalid curve: %s", crt.Curve())
return nil
} }
var cs noise.CipherSuite var ncs noise.CipherSuite
if cipher == "chachapoly" { if cs.cipher == "chachapoly" {
cs = noise.NewCipherSuite(dhFunc, noise.CipherChaChaPoly, noise.HashSHA256) ncs = noise.NewCipherSuite(dhFunc, noise.CipherChaChaPoly, noise.HashSHA256)
} else { } else {
cs = noise.NewCipherSuite(dhFunc, noiseutil.CipherAESGCM, noise.HashSHA256) ncs = noise.NewCipherSuite(dhFunc, noiseutil.CipherAESGCM, noise.HashSHA256)
} }
static := noise.DHKey{Private: certState.PrivateKey, Public: certState.PublicKey} static := noise.DHKey{Private: cs.privateKey, Public: crt.PublicKey()}
b := NewBits(ReplayWindow) b := NewBits(ReplayWindow)
// Clear out bit 0, we never transmit it and we don't want it showing as packet loss // Clear out bit 0, we never transmit it, and we don't want it showing as packet loss
b.Update(l, 0) b.Update(l, 0)
hs, err := noise.NewHandshakeState(noise.Config{ hs, err := noise.NewHandshakeState(noise.Config{
CipherSuite: cs, CipherSuite: ncs,
Random: rand.Reader, Random: rand.Reader,
Pattern: pattern, Pattern: pattern,
Initiator: initiator, Initiator: initiator,
StaticKeypair: static, StaticKeypair: static,
PresharedKey: psk, //NOTE: These should come from CertState (pki.go) when we finally implement it
PresharedKeyPlacement: pskStage, PresharedKey: []byte{},
PresharedKeyPlacement: 0,
}) })
if err != nil { if err != nil {
return nil return nil, fmt.Errorf("NewConnectionState: %s", err)
} }
// The queue and ready params prevent a counter race that would happen when // The queue and ready params prevent a counter race that would happen when
@ -74,12 +75,12 @@ func NewConnectionState(l *logrus.Logger, cipher string, certState *CertState, i
H: hs, H: hs,
initiator: initiator, initiator: initiator,
window: b, window: b,
myCert: certState.Certificate, myCert: crt,
} }
// always start the counter from 2, as packet 1 and packet 2 are handshake packets. // always start the counter from 2, as packet 1 and packet 2 are handshake packets.
ci.messageCounter.Add(2) ci.messageCounter.Add(2)
return ci return ci, nil
} }
func (cs *ConnectionState) MarshalJSON() ([]byte, error) { func (cs *ConnectionState) MarshalJSON() ([]byte, error) {
@ -89,3 +90,7 @@ func (cs *ConnectionState) MarshalJSON() ([]byte, error) {
"message_counter": cs.messageCounter.Load(), "message_counter": cs.messageCounter.Load(),
}) })
} }
func (cs *ConnectionState) Curve() cert.Curve {
return cs.myCert.Curve()
}

View File

@ -19,9 +19,9 @@ import (
type controlEach func(h *HostInfo) type controlEach func(h *HostInfo)
type controlHostLister interface { type controlHostLister interface {
QueryVpnIp(vpnIp netip.Addr) *HostInfo QueryVpnAddr(vpnAddr netip.Addr) *HostInfo
ForEachIndex(each controlEach) ForEachIndex(each controlEach)
ForEachVpnIp(each controlEach) ForEachVpnAddr(each controlEach)
GetPreferredRanges() []netip.Prefix GetPreferredRanges() []netip.Prefix
} }
@ -37,7 +37,7 @@ type Control struct {
} }
type ControlHostInfo struct { type ControlHostInfo struct {
VpnIp netip.Addr `json:"vpnIp"` VpnAddrs []netip.Addr `json:"vpnAddrs"`
LocalIndex uint32 `json:"localIndex"` LocalIndex uint32 `json:"localIndex"`
RemoteIndex uint32 `json:"remoteIndex"` RemoteIndex uint32 `json:"remoteIndex"`
RemoteAddrs []netip.AddrPort `json:"remoteAddrs"` RemoteAddrs []netip.AddrPort `json:"remoteAddrs"`
@ -131,10 +131,13 @@ func (c *Control) ListHostmapIndexes(pendingMap bool) []ControlHostInfo {
// GetCertByVpnIp returns the authenticated certificate of the given vpn IP, or nil if not found // GetCertByVpnIp returns the authenticated certificate of the given vpn IP, or nil if not found
func (c *Control) GetCertByVpnIp(vpnIp netip.Addr) cert.Certificate { func (c *Control) GetCertByVpnIp(vpnIp netip.Addr) cert.Certificate {
if c.f.myVpnNet.Addr() == vpnIp { _, found := c.f.myVpnAddrsTable.Lookup(vpnIp)
return c.f.pki.GetCertState().Certificate.Copy() if found {
// Only returning the default certificate since its impossible
// for any other host but ourselves to have more than 1
return c.f.pki.getCertState().GetDefaultCertificate().Copy()
} }
hi := c.f.hostMap.QueryVpnIp(vpnIp) hi := c.f.hostMap.QueryVpnAddr(vpnIp)
if hi == nil { if hi == nil {
return nil return nil
} }
@ -148,7 +151,7 @@ func (c *Control) CreateTunnel(vpnIp netip.Addr) {
// PrintTunnel creates a new tunnel to the given vpn ip. // PrintTunnel creates a new tunnel to the given vpn ip.
func (c *Control) PrintTunnel(vpnIp netip.Addr) *ControlHostInfo { func (c *Control) PrintTunnel(vpnIp netip.Addr) *ControlHostInfo {
hi := c.f.hostMap.QueryVpnIp(vpnIp) hi := c.f.hostMap.QueryVpnAddr(vpnIp)
if hi == nil { if hi == nil {
return nil return nil
} }
@ -165,9 +168,9 @@ func (c *Control) QueryLighthouse(vpnIp netip.Addr) *CacheMap {
return hi.CopyCache() return hi.CopyCache()
} }
// GetHostInfoByVpnIp returns a single tunnels hostInfo, or nil if not found // GetHostInfoByVpnAddr returns a single tunnels hostInfo, or nil if not found
// Caller should take care to Unmap() any 4in6 addresses prior to calling. // Caller should take care to Unmap() any 4in6 addresses prior to calling.
func (c *Control) GetHostInfoByVpnIp(vpnIp netip.Addr, pending bool) *ControlHostInfo { func (c *Control) GetHostInfoByVpnAddr(vpnAddr netip.Addr, pending bool) *ControlHostInfo {
var hl controlHostLister var hl controlHostLister
if pending { if pending {
hl = c.f.handshakeManager hl = c.f.handshakeManager
@ -175,7 +178,7 @@ func (c *Control) GetHostInfoByVpnIp(vpnIp netip.Addr, pending bool) *ControlHos
hl = c.f.hostMap hl = c.f.hostMap
} }
h := hl.QueryVpnIp(vpnIp) h := hl.QueryVpnAddr(vpnAddr)
if h == nil { if h == nil {
return nil return nil
} }
@ -187,7 +190,7 @@ func (c *Control) GetHostInfoByVpnIp(vpnIp netip.Addr, pending bool) *ControlHos
// SetRemoteForTunnel forces a tunnel to use a specific remote // SetRemoteForTunnel forces a tunnel to use a specific remote
// Caller should take care to Unmap() any 4in6 addresses prior to calling. // Caller should take care to Unmap() any 4in6 addresses prior to calling.
func (c *Control) SetRemoteForTunnel(vpnIp netip.Addr, addr netip.AddrPort) *ControlHostInfo { func (c *Control) SetRemoteForTunnel(vpnIp netip.Addr, addr netip.AddrPort) *ControlHostInfo {
hostInfo := c.f.hostMap.QueryVpnIp(vpnIp) hostInfo := c.f.hostMap.QueryVpnAddr(vpnIp)
if hostInfo == nil { if hostInfo == nil {
return nil return nil
} }
@ -200,7 +203,7 @@ func (c *Control) SetRemoteForTunnel(vpnIp netip.Addr, addr netip.AddrPort) *Con
// CloseTunnel closes a fully established tunnel. If localOnly is false it will notify the remote end as well. // CloseTunnel closes a fully established tunnel. If localOnly is false it will notify the remote end as well.
// Caller should take care to Unmap() any 4in6 addresses prior to calling. // Caller should take care to Unmap() any 4in6 addresses prior to calling.
func (c *Control) CloseTunnel(vpnIp netip.Addr, localOnly bool) bool { func (c *Control) CloseTunnel(vpnIp netip.Addr, localOnly bool) bool {
hostInfo := c.f.hostMap.QueryVpnIp(vpnIp) hostInfo := c.f.hostMap.QueryVpnAddr(vpnIp)
if hostInfo == nil { if hostInfo == nil {
return false return false
} }
@ -224,19 +227,14 @@ func (c *Control) CloseTunnel(vpnIp netip.Addr, localOnly bool) bool {
// CloseAllTunnels is just like CloseTunnel except it goes through and shuts them all down, optionally you can avoid shutting down lighthouse tunnels // CloseAllTunnels is just like CloseTunnel except it goes through and shuts them all down, optionally you can avoid shutting down lighthouse tunnels
// the int returned is a count of tunnels closed // the int returned is a count of tunnels closed
func (c *Control) CloseAllTunnels(excludeLighthouses bool) (closed int) { func (c *Control) CloseAllTunnels(excludeLighthouses bool) (closed int) {
//TODO: this is probably better as a function in ConnectionManager or HostMap directly
lighthouses := c.f.lightHouse.GetLighthouses()
shutdown := func(h *HostInfo) { shutdown := func(h *HostInfo) {
if excludeLighthouses { if excludeLighthouses && c.f.lightHouse.IsAnyLighthouseAddr(h.vpnAddrs) {
if _, ok := lighthouses[h.vpnIp]; ok { return
return
}
} }
c.f.send(header.CloseTunnel, 0, h.ConnectionState, h, []byte{}, make([]byte, 12, 12), make([]byte, mtu)) c.f.send(header.CloseTunnel, 0, h.ConnectionState, h, []byte{}, make([]byte, 12, 12), make([]byte, mtu))
c.f.closeTunnel(h) c.f.closeTunnel(h)
c.l.WithField("vpnIp", h.vpnIp).WithField("udpAddr", h.remote). c.l.WithField("vpnAddrs", h.vpnAddrs).WithField("udpAddr", h.remote).
Debug("Sending close tunnel message") Debug("Sending close tunnel message")
closed++ closed++
} }
@ -246,7 +244,7 @@ func (c *Control) CloseAllTunnels(excludeLighthouses bool) (closed int) {
// Grab the hostMap lock to access the Relays map // Grab the hostMap lock to access the Relays map
c.f.hostMap.Lock() c.f.hostMap.Lock()
for _, relayingHost := range c.f.hostMap.Relays { for _, relayingHost := range c.f.hostMap.Relays {
relayingHosts[relayingHost.vpnIp] = relayingHost relayingHosts[relayingHost.vpnAddrs[0]] = relayingHost
} }
c.f.hostMap.Unlock() c.f.hostMap.Unlock()
@ -254,7 +252,7 @@ func (c *Control) CloseAllTunnels(excludeLighthouses bool) (closed int) {
// Grab the hostMap lock to access the Hosts map // Grab the hostMap lock to access the Hosts map
c.f.hostMap.Lock() c.f.hostMap.Lock()
for _, relayHost := range c.f.hostMap.Indexes { for _, relayHost := range c.f.hostMap.Indexes {
if _, ok := relayingHosts[relayHost.vpnIp]; !ok { if _, ok := relayingHosts[relayHost.vpnAddrs[0]]; !ok {
hostInfos = append(hostInfos, relayHost) hostInfos = append(hostInfos, relayHost)
} }
} }
@ -274,9 +272,8 @@ func (c *Control) Device() overlay.Device {
} }
func copyHostInfo(h *HostInfo, preferredRanges []netip.Prefix) ControlHostInfo { func copyHostInfo(h *HostInfo, preferredRanges []netip.Prefix) ControlHostInfo {
chi := ControlHostInfo{ chi := ControlHostInfo{
VpnIp: h.vpnIp, VpnAddrs: make([]netip.Addr, len(h.vpnAddrs)),
LocalIndex: h.localIndexId, LocalIndex: h.localIndexId,
RemoteIndex: h.remoteIndexId, RemoteIndex: h.remoteIndexId,
RemoteAddrs: h.remotes.CopyAddrs(preferredRanges), RemoteAddrs: h.remotes.CopyAddrs(preferredRanges),
@ -285,6 +282,10 @@ func copyHostInfo(h *HostInfo, preferredRanges []netip.Prefix) ControlHostInfo {
CurrentRemote: h.remote, CurrentRemote: h.remote,
} }
for i, a := range h.vpnAddrs {
chi.VpnAddrs[i] = a
}
if h.ConnectionState != nil { if h.ConnectionState != nil {
chi.MessageCounter = h.ConnectionState.messageCounter.Load() chi.MessageCounter = h.ConnectionState.messageCounter.Load()
} }
@ -299,7 +300,7 @@ func copyHostInfo(h *HostInfo, preferredRanges []netip.Prefix) ControlHostInfo {
func listHostMapHosts(hl controlHostLister) []ControlHostInfo { func listHostMapHosts(hl controlHostLister) []ControlHostInfo {
hosts := make([]ControlHostInfo, 0) hosts := make([]ControlHostInfo, 0)
pr := hl.GetPreferredRanges() pr := hl.GetPreferredRanges()
hl.ForEachVpnIp(func(hostinfo *HostInfo) { hl.ForEachVpnAddr(func(hostinfo *HostInfo) {
hosts = append(hosts, copyHostInfo(hostinfo, pr)) hosts = append(hosts, copyHostInfo(hostinfo, pr))
}) })
return hosts return hosts

View File

@ -13,13 +13,13 @@ import (
) )
func TestControl_GetHostInfoByVpnIp(t *testing.T) { func TestControl_GetHostInfoByVpnIp(t *testing.T) {
//TODO: with multiple certificate versions we have a problem with this test //TODO: CERT-V2 with multiple certificate versions we have a problem with this test
// Some certs versions have different characteristics and each version implements their own Copy() func // Some certs versions have different characteristics and each version implements their own Copy() func
// which means this is not a good place to test for exposing memory // which means this is not a good place to test for exposing memory
l := test.NewLogger() l := test.NewLogger()
// Special care must be taken to re-use all objects provided to the hostmap and certificate in the expectedInfo object // Special care must be taken to re-use all objects provided to the hostmap and certificate in the expectedInfo object
// To properly ensure we are not exposing core memory to the caller // To properly ensure we are not exposing core memory to the caller
hm := newHostMap(l, netip.Prefix{}) hm := newHostMap(l)
hm.preferredRanges.Store(&[]netip.Prefix{}) hm.preferredRanges.Store(&[]netip.Prefix{})
remote1 := netip.MustParseAddrPort("0.0.0.100:4444") remote1 := netip.MustParseAddrPort("0.0.0.100:4444")
@ -35,9 +35,9 @@ func TestControl_GetHostInfoByVpnIp(t *testing.T) {
Mask: net.IPMask{255, 255, 255, 0}, Mask: net.IPMask{255, 255, 255, 0},
} }
remotes := NewRemoteList(nil) remotes := NewRemoteList([]netip.Addr{netip.IPv4Unspecified()}, nil)
remotes.unlockedPrependV4(netip.IPv4Unspecified(), NewIp4AndPortFromNetIP(remote1.Addr(), remote1.Port())) remotes.unlockedPrependV4(netip.IPv4Unspecified(), netAddrToProtoV4AddrPort(remote1.Addr(), remote1.Port()))
remotes.unlockedPrependV6(netip.IPv4Unspecified(), NewIp6AndPortFromNetIP(remote2.Addr(), remote2.Port())) remotes.unlockedPrependV6(netip.IPv4Unspecified(), netAddrToProtoV6AddrPort(remote2.Addr(), remote2.Port()))
vpnIp, ok := netip.AddrFromSlice(ipNet.IP) vpnIp, ok := netip.AddrFromSlice(ipNet.IP)
assert.True(t, ok) assert.True(t, ok)
@ -51,11 +51,11 @@ func TestControl_GetHostInfoByVpnIp(t *testing.T) {
}, },
remoteIndexId: 200, remoteIndexId: 200,
localIndexId: 201, localIndexId: 201,
vpnIp: vpnIp, vpnAddrs: []netip.Addr{vpnIp},
relayState: RelayState{ relayState: RelayState{
relays: map[netip.Addr]struct{}{}, relays: map[netip.Addr]struct{}{},
relayForByIp: map[netip.Addr]*Relay{}, relayForByAddr: map[netip.Addr]*Relay{},
relayForByIdx: map[uint32]*Relay{}, relayForByIdx: map[uint32]*Relay{},
}, },
}, &Interface{}) }, &Interface{})
@ -70,11 +70,11 @@ func TestControl_GetHostInfoByVpnIp(t *testing.T) {
}, },
remoteIndexId: 200, remoteIndexId: 200,
localIndexId: 201, localIndexId: 201,
vpnIp: vpnIp2, vpnAddrs: []netip.Addr{vpnIp2},
relayState: RelayState{ relayState: RelayState{
relays: map[netip.Addr]struct{}{}, relays: map[netip.Addr]struct{}{},
relayForByIp: map[netip.Addr]*Relay{}, relayForByAddr: map[netip.Addr]*Relay{},
relayForByIdx: map[uint32]*Relay{}, relayForByIdx: map[uint32]*Relay{},
}, },
}, &Interface{}) }, &Interface{})
@ -85,10 +85,10 @@ func TestControl_GetHostInfoByVpnIp(t *testing.T) {
l: logrus.New(), l: logrus.New(),
} }
thi := c.GetHostInfoByVpnIp(vpnIp, false) thi := c.GetHostInfoByVpnAddr(vpnIp, false)
expectedInfo := ControlHostInfo{ expectedInfo := ControlHostInfo{
VpnIp: vpnIp, VpnAddrs: []netip.Addr{vpnIp},
LocalIndex: 201, LocalIndex: 201,
RemoteIndex: 200, RemoteIndex: 200,
RemoteAddrs: []netip.AddrPort{remote2, remote1}, RemoteAddrs: []netip.AddrPort{remote2, remote1},
@ -100,13 +100,13 @@ func TestControl_GetHostInfoByVpnIp(t *testing.T) {
} }
// Make sure we don't have any unexpected fields // Make sure we don't have any unexpected fields
assertFields(t, []string{"VpnIp", "LocalIndex", "RemoteIndex", "RemoteAddrs", "Cert", "MessageCounter", "CurrentRemote", "CurrentRelaysToMe", "CurrentRelaysThroughMe"}, thi) assertFields(t, []string{"VpnAddrs", "LocalIndex", "RemoteIndex", "RemoteAddrs", "Cert", "MessageCounter", "CurrentRemote", "CurrentRelaysToMe", "CurrentRelaysThroughMe"}, thi)
assert.EqualValues(t, &expectedInfo, thi) assert.EqualValues(t, &expectedInfo, thi)
test.AssertDeepCopyEqual(t, &expectedInfo, thi) test.AssertDeepCopyEqual(t, &expectedInfo, thi)
// Make sure we don't panic if the host info doesn't have a cert yet // Make sure we don't panic if the host info doesn't have a cert yet
assert.NotPanics(t, func() { assert.NotPanics(t, func() {
thi = c.GetHostInfoByVpnIp(vpnIp2, false) thi = c.GetHostInfoByVpnAddr(vpnIp2, false)
}) })
} }

View File

@ -6,8 +6,6 @@ package nebula
import ( import (
"net/netip" "net/netip"
"github.com/slackhq/nebula/cert"
"github.com/google/gopacket" "github.com/google/gopacket"
"github.com/google/gopacket/layers" "github.com/google/gopacket/layers"
"github.com/slackhq/nebula/header" "github.com/slackhq/nebula/header"
@ -51,15 +49,15 @@ func (c *Control) WaitForTypeByIndex(toIndex uint32, msgType header.MessageType,
// This is necessary if you did not configure static hosts or are not running a lighthouse // This is necessary if you did not configure static hosts or are not running a lighthouse
func (c *Control) InjectLightHouseAddr(vpnIp netip.Addr, toAddr netip.AddrPort) { func (c *Control) InjectLightHouseAddr(vpnIp netip.Addr, toAddr netip.AddrPort) {
c.f.lightHouse.Lock() c.f.lightHouse.Lock()
remoteList := c.f.lightHouse.unlockedGetRemoteList(vpnIp) remoteList := c.f.lightHouse.unlockedGetRemoteList([]netip.Addr{vpnIp})
remoteList.Lock() remoteList.Lock()
defer remoteList.Unlock() defer remoteList.Unlock()
c.f.lightHouse.Unlock() c.f.lightHouse.Unlock()
if toAddr.Addr().Is4() { if toAddr.Addr().Is4() {
remoteList.unlockedPrependV4(vpnIp, NewIp4AndPortFromNetIP(toAddr.Addr(), toAddr.Port())) remoteList.unlockedPrependV4(vpnIp, netAddrToProtoV4AddrPort(toAddr.Addr(), toAddr.Port()))
} else { } else {
remoteList.unlockedPrependV6(vpnIp, NewIp6AndPortFromNetIP(toAddr.Addr(), toAddr.Port())) remoteList.unlockedPrependV6(vpnIp, netAddrToProtoV6AddrPort(toAddr.Addr(), toAddr.Port()))
} }
} }
@ -67,12 +65,12 @@ func (c *Control) InjectLightHouseAddr(vpnIp netip.Addr, toAddr netip.AddrPort)
// This is necessary to inform an initiator of possible relays for communicating with a responder // This is necessary to inform an initiator of possible relays for communicating with a responder
func (c *Control) InjectRelays(vpnIp netip.Addr, relayVpnIps []netip.Addr) { func (c *Control) InjectRelays(vpnIp netip.Addr, relayVpnIps []netip.Addr) {
c.f.lightHouse.Lock() c.f.lightHouse.Lock()
remoteList := c.f.lightHouse.unlockedGetRemoteList(vpnIp) remoteList := c.f.lightHouse.unlockedGetRemoteList([]netip.Addr{vpnIp})
remoteList.Lock() remoteList.Lock()
defer remoteList.Unlock() defer remoteList.Unlock()
c.f.lightHouse.Unlock() c.f.lightHouse.Unlock()
remoteList.unlockedSetRelay(vpnIp, vpnIp, relayVpnIps) remoteList.unlockedSetRelay(vpnIp, relayVpnIps)
} }
// GetFromTun will pull a packet off the tun side of nebula // GetFromTun will pull a packet off the tun side of nebula
@ -99,21 +97,42 @@ func (c *Control) InjectUDPPacket(p *udp.Packet) {
} }
// InjectTunUDPPacket puts a udp packet on the tun interface. Using UDP here because it's a simpler protocol // InjectTunUDPPacket puts a udp packet on the tun interface. Using UDP here because it's a simpler protocol
func (c *Control) InjectTunUDPPacket(toIp netip.Addr, toPort uint16, fromPort uint16, data []byte) { func (c *Control) InjectTunUDPPacket(toAddr netip.Addr, toPort uint16, fromAddr netip.Addr, fromPort uint16, data []byte) {
//TODO: IPV6-WORK serialize := make([]gopacket.SerializableLayer, 0)
ip := layers.IPv4{ var netLayer gopacket.NetworkLayer
Version: 4, if toAddr.Is6() {
TTL: 64, if !fromAddr.Is6() {
Protocol: layers.IPProtocolUDP, panic("Cant send ipv6 to ipv4")
SrcIP: c.f.inside.Cidr().Addr().Unmap().AsSlice(), }
DstIP: toIp.Unmap().AsSlice(), ip := &layers.IPv6{
Version: 6,
NextHeader: layers.IPProtocolUDP,
SrcIP: fromAddr.Unmap().AsSlice(),
DstIP: toAddr.Unmap().AsSlice(),
}
serialize = append(serialize, ip)
netLayer = ip
} else {
if !fromAddr.Is4() {
panic("Cant send ipv4 to ipv6")
}
ip := &layers.IPv4{
Version: 4,
TTL: 64,
Protocol: layers.IPProtocolUDP,
SrcIP: fromAddr.Unmap().AsSlice(),
DstIP: toAddr.Unmap().AsSlice(),
}
serialize = append(serialize, ip)
netLayer = ip
} }
udp := layers.UDP{ udp := layers.UDP{
SrcPort: layers.UDPPort(fromPort), SrcPort: layers.UDPPort(fromPort),
DstPort: layers.UDPPort(toPort), DstPort: layers.UDPPort(toPort),
} }
err := udp.SetNetworkLayerForChecksum(&ip) err := udp.SetNetworkLayerForChecksum(netLayer)
if err != nil { if err != nil {
panic(err) panic(err)
} }
@ -123,7 +142,9 @@ func (c *Control) InjectTunUDPPacket(toIp netip.Addr, toPort uint16, fromPort ui
ComputeChecksums: true, ComputeChecksums: true,
FixLengths: true, FixLengths: true,
} }
err = gopacket.SerializeLayers(buffer, opt, &ip, &udp, gopacket.Payload(data))
serialize = append(serialize, &udp, gopacket.Payload(data))
err = gopacket.SerializeLayers(buffer, opt, serialize...)
if err != nil { if err != nil {
panic(err) panic(err)
} }
@ -131,8 +152,8 @@ func (c *Control) InjectTunUDPPacket(toIp netip.Addr, toPort uint16, fromPort ui
c.f.inside.(*overlay.TestTun).Send(buffer.Bytes()) c.f.inside.(*overlay.TestTun).Send(buffer.Bytes())
} }
func (c *Control) GetVpnIp() netip.Addr { func (c *Control) GetVpnAddrs() []netip.Addr {
return c.f.myVpnNet.Addr() return c.f.myVpnAddrs
} }
func (c *Control) GetUDPAddr() netip.AddrPort { func (c *Control) GetUDPAddr() netip.AddrPort {
@ -140,7 +161,7 @@ func (c *Control) GetUDPAddr() netip.AddrPort {
} }
func (c *Control) KillPendingTunnel(vpnIp netip.Addr) bool { func (c *Control) KillPendingTunnel(vpnIp netip.Addr) bool {
hostinfo := c.f.handshakeManager.QueryVpnIp(vpnIp) hostinfo := c.f.handshakeManager.QueryVpnAddr(vpnIp)
if hostinfo == nil { if hostinfo == nil {
return false return false
} }
@ -153,8 +174,8 @@ func (c *Control) GetHostmap() *HostMap {
return c.f.hostMap return c.f.hostMap
} }
func (c *Control) GetCert() cert.Certificate { func (c *Control) GetCertState() *CertState {
return c.f.pki.GetCertState().Certificate return c.f.pki.getCertState()
} }
func (c *Control) ReHandshake(vpnIp netip.Addr) { func (c *Control) ReHandshake(vpnIp netip.Addr) {

View File

@ -8,6 +8,7 @@ import (
"strings" "strings"
"sync" "sync"
"github.com/gaissmai/bart"
"github.com/miekg/dns" "github.com/miekg/dns"
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
"github.com/slackhq/nebula/config" "github.com/slackhq/nebula/config"
@ -21,24 +22,39 @@ var dnsAddr string
type dnsRecords struct { type dnsRecords struct {
sync.RWMutex sync.RWMutex
dnsMap map[string]string l *logrus.Logger
hostMap *HostMap dnsMap4 map[string]netip.Addr
dnsMap6 map[string]netip.Addr
hostMap *HostMap
myVpnAddrsTable *bart.Table[struct{}]
} }
func newDnsRecords(hostMap *HostMap) *dnsRecords { func newDnsRecords(l *logrus.Logger, cs *CertState, hostMap *HostMap) *dnsRecords {
return &dnsRecords{ return &dnsRecords{
dnsMap: make(map[string]string), l: l,
hostMap: hostMap, dnsMap4: make(map[string]netip.Addr),
dnsMap6: make(map[string]netip.Addr),
hostMap: hostMap,
myVpnAddrsTable: cs.myVpnAddrsTable,
} }
} }
func (d *dnsRecords) Query(data string) string { func (d *dnsRecords) Query(q uint16, data string) netip.Addr {
data = strings.ToLower(data)
d.RLock() d.RLock()
defer d.RUnlock() defer d.RUnlock()
if r, ok := d.dnsMap[strings.ToLower(data)]; ok { switch q {
return r case dns.TypeA:
if r, ok := d.dnsMap4[data]; ok {
return r
}
case dns.TypeAAAA:
if r, ok := d.dnsMap6[data]; ok {
return r
}
} }
return ""
return netip.Addr{}
} }
func (d *dnsRecords) QueryCert(data string) string { func (d *dnsRecords) QueryCert(data string) string {
@ -47,7 +63,7 @@ func (d *dnsRecords) QueryCert(data string) string {
return "" return ""
} }
hostinfo := d.hostMap.QueryVpnIp(ip) hostinfo := d.hostMap.QueryVpnAddr(ip)
if hostinfo == nil { if hostinfo == nil {
return "" return ""
} }
@ -64,38 +80,62 @@ func (d *dnsRecords) QueryCert(data string) string {
return string(b) return string(b)
} }
func (d *dnsRecords) Add(host, data string) { // Add adds the first IPv4 and IPv6 address that appears in `addresses` as the record for `host`
func (d *dnsRecords) Add(host string, addresses []netip.Addr) {
host = strings.ToLower(host)
d.Lock() d.Lock()
defer d.Unlock() defer d.Unlock()
d.dnsMap[strings.ToLower(host)] = data haveV4 := false
haveV6 := false
for _, addr := range addresses {
if addr.Is4() && !haveV4 {
d.dnsMap4[host] = addr
haveV4 = true
} else if addr.Is6() && !haveV6 {
d.dnsMap6[host] = addr
haveV6 = true
}
if haveV4 && haveV6 {
break
}
}
} }
func parseQuery(l *logrus.Logger, m *dns.Msg, w dns.ResponseWriter) { func (d *dnsRecords) isSelfNebulaOrLocalhost(addr string) bool {
a, _, _ := net.SplitHostPort(addr)
b, err := netip.ParseAddr(a)
if err != nil {
return false
}
if b.IsLoopback() {
return true
}
_, found := d.myVpnAddrsTable.Lookup(b)
return found //if we found it in this table, it's good
}
func (d *dnsRecords) parseQuery(m *dns.Msg, w dns.ResponseWriter) {
for _, q := range m.Question { for _, q := range m.Question {
switch q.Qtype { switch q.Qtype {
case dns.TypeA: case dns.TypeA, dns.TypeAAAA:
l.Debugf("Query for A %s", q.Name) qType := dns.TypeToString[q.Qtype]
ip := dnsR.Query(q.Name) d.l.Debugf("Query for %s %s", qType, q.Name)
if ip != "" { ip := d.Query(q.Qtype, q.Name)
rr, err := dns.NewRR(fmt.Sprintf("%s A %s", q.Name, ip)) if ip.IsValid() {
rr, err := dns.NewRR(fmt.Sprintf("%s %s %s", q.Name, qType, ip))
if err == nil { if err == nil {
m.Answer = append(m.Answer, rr) m.Answer = append(m.Answer, rr)
} }
} }
case dns.TypeTXT: case dns.TypeTXT:
a, _, _ := net.SplitHostPort(w.RemoteAddr().String()) // We only answer these queries from nebula nodes or localhost
b, err := netip.ParseAddr(a) if !d.isSelfNebulaOrLocalhost(w.RemoteAddr().String()) {
if err != nil {
return return
} }
d.l.Debugf("Query for TXT %s", q.Name)
// We don't answer these queries from non nebula nodes or localhost ip := d.QueryCert(q.Name)
//l.Debugf("Does %s contain %s", b, dnsR.hostMap.vpnCIDR)
if !dnsR.hostMap.vpnCIDR.Contains(b) && a != "127.0.0.1" {
return
}
l.Debugf("Query for TXT %s", q.Name)
ip := dnsR.QueryCert(q.Name)
if ip != "" { if ip != "" {
rr, err := dns.NewRR(fmt.Sprintf("%s TXT %s", q.Name, ip)) rr, err := dns.NewRR(fmt.Sprintf("%s TXT %s", q.Name, ip))
if err == nil { if err == nil {
@ -110,26 +150,24 @@ func parseQuery(l *logrus.Logger, m *dns.Msg, w dns.ResponseWriter) {
} }
} }
func handleDnsRequest(l *logrus.Logger, w dns.ResponseWriter, r *dns.Msg) { func (d *dnsRecords) handleDnsRequest(w dns.ResponseWriter, r *dns.Msg) {
m := new(dns.Msg) m := new(dns.Msg)
m.SetReply(r) m.SetReply(r)
m.Compress = false m.Compress = false
switch r.Opcode { switch r.Opcode {
case dns.OpcodeQuery: case dns.OpcodeQuery:
parseQuery(l, m, w) d.parseQuery(m, w)
} }
w.WriteMsg(m) w.WriteMsg(m)
} }
func dnsMain(l *logrus.Logger, hostMap *HostMap, c *config.C) func() { func dnsMain(l *logrus.Logger, cs *CertState, hostMap *HostMap, c *config.C) func() {
dnsR = newDnsRecords(hostMap) dnsR = newDnsRecords(l, cs, hostMap)
// attach request handler func // attach request handler func
dns.HandleFunc(".", func(w dns.ResponseWriter, r *dns.Msg) { dns.HandleFunc(".", dnsR.handleDnsRequest)
handleDnsRequest(l, w, r)
})
c.RegisterReloadCallback(func(c *config.C) { c.RegisterReloadCallback(func(c *config.C) {
reloadDns(l, c) reloadDns(l, c)

View File

@ -1,23 +1,38 @@
package nebula package nebula
import ( import (
"net/netip"
"testing" "testing"
"github.com/miekg/dns" "github.com/miekg/dns"
"github.com/sirupsen/logrus"
"github.com/slackhq/nebula/config" "github.com/slackhq/nebula/config"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
) )
func TestParsequery(t *testing.T) { func TestParsequery(t *testing.T) {
//TODO: This test is basically pointless l := logrus.New()
hostMap := &HostMap{} hostMap := &HostMap{}
ds := newDnsRecords(hostMap) ds := newDnsRecords(l, &CertState{}, hostMap)
ds.Add("test.com.com", "1.2.3.4") addrs := []netip.Addr{
netip.MustParseAddr("1.2.3.4"),
netip.MustParseAddr("1.2.3.5"),
netip.MustParseAddr("fd01::24"),
netip.MustParseAddr("fd01::25"),
}
ds.Add("test.com.com", addrs)
m := new(dns.Msg) m := &dns.Msg{}
m.SetQuestion("test.com.com", dns.TypeA) m.SetQuestion("test.com.com", dns.TypeA)
ds.parseQuery(m, nil)
assert.NotNil(t, m.Answer)
assert.Equal(t, "1.2.3.4", m.Answer[0].(*dns.A).A.String())
//parseQuery(m) m = &dns.Msg{}
m.SetQuestion("test.com.com", dns.TypeAAAA)
ds.parseQuery(m, nil)
assert.NotNil(t, m.Answer)
assert.Equal(t, "fd01::24", m.Answer[0].(*dns.AAAA).AAAA.String())
} }
func Test_getDnsServerAddr(t *testing.T) { func Test_getDnsServerAddr(t *testing.T) {

File diff suppressed because it is too large Load Diff

View File

@ -8,6 +8,7 @@ import (
"io" "io"
"net/netip" "net/netip"
"os" "os"
"strings"
"testing" "testing"
"time" "time"
@ -17,6 +18,7 @@ import (
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
"github.com/slackhq/nebula" "github.com/slackhq/nebula"
"github.com/slackhq/nebula/cert" "github.com/slackhq/nebula/cert"
"github.com/slackhq/nebula/cert_test"
"github.com/slackhq/nebula/config" "github.com/slackhq/nebula/config"
"github.com/slackhq/nebula/e2e/router" "github.com/slackhq/nebula/e2e/router"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
@ -26,25 +28,35 @@ import (
type m map[string]interface{} type m map[string]interface{}
// newSimpleServer creates a nebula instance with many assumptions // newSimpleServer creates a nebula instance with many assumptions
func newSimpleServer(caCrt cert.Certificate, caKey []byte, name string, sVpnIpNet string, overrides m) (*nebula.Control, netip.Prefix, netip.AddrPort, *config.C) { func newSimpleServer(v cert.Version, caCrt cert.Certificate, caKey []byte, name string, sVpnNetworks string, overrides m) (*nebula.Control, []netip.Prefix, netip.AddrPort, *config.C) {
l := NewTestLogger() l := NewTestLogger()
vpnIpNet, err := netip.ParsePrefix(sVpnIpNet) var vpnNetworks []netip.Prefix
if err != nil { for _, sn := range strings.Split(sVpnNetworks, ",") {
panic(err) vpnIpNet, err := netip.ParsePrefix(strings.TrimSpace(sn))
if err != nil {
panic(err)
}
vpnNetworks = append(vpnNetworks, vpnIpNet)
}
if len(vpnNetworks) == 0 {
panic("no vpn networks")
} }
var udpAddr netip.AddrPort var udpAddr netip.AddrPort
if vpnIpNet.Addr().Is4() { if vpnNetworks[0].Addr().Is4() {
budpIp := vpnIpNet.Addr().As4() budpIp := vpnNetworks[0].Addr().As4()
budpIp[1] -= 128 budpIp[1] -= 128
udpAddr = netip.AddrPortFrom(netip.AddrFrom4(budpIp), 4242) udpAddr = netip.AddrPortFrom(netip.AddrFrom4(budpIp), 4242)
} else { } else {
budpIp := vpnIpNet.Addr().As16() budpIp := vpnNetworks[0].Addr().As16()
budpIp[13] -= 128 // beef for funsies
budpIp[2] = 190
budpIp[3] = 239
udpAddr = netip.AddrPortFrom(netip.AddrFrom16(budpIp), 4242) udpAddr = netip.AddrPortFrom(netip.AddrFrom16(budpIp), 4242)
} }
_, _, myPrivKey, myPEM := NewTestCert(caCrt, caKey, name, time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{vpnIpNet}, nil, []string{}) _, _, myPrivKey, myPEM := cert_test.NewTestCert(v, cert.Curve_CURVE25519, caCrt, caKey, name, time.Now(), time.Now().Add(5*time.Minute), vpnNetworks, nil, []string{})
caB, err := caCrt.MarshalPEM() caB, err := caCrt.MarshalPEM()
if err != nil { if err != nil {
@ -88,11 +100,16 @@ func newSimpleServer(caCrt cert.Certificate, caKey []byte, name string, sVpnIpNe
} }
if overrides != nil { if overrides != nil {
err = mergo.Merge(&overrides, mc, mergo.WithAppendSlice) final := m{}
err = mergo.Merge(&final, overrides, mergo.WithAppendSlice)
if err != nil { if err != nil {
panic(err) panic(err)
} }
mc = overrides err = mergo.Merge(&final, mc, mergo.WithAppendSlice)
if err != nil {
panic(err)
}
mc = final
} }
cb, err := yaml.Marshal(mc) cb, err := yaml.Marshal(mc)
@ -109,7 +126,7 @@ func newSimpleServer(caCrt cert.Certificate, caKey []byte, name string, sVpnIpNe
panic(err) panic(err)
} }
return control, vpnIpNet, udpAddr, c return control, vpnNetworks, udpAddr, c
} }
type doneCb func() type doneCb func()
@ -132,27 +149,28 @@ func deadline(t *testing.T, seconds time.Duration) doneCb {
func assertTunnel(t *testing.T, vpnIpA, vpnIpB netip.Addr, controlA, controlB *nebula.Control, r *router.R) { func assertTunnel(t *testing.T, vpnIpA, vpnIpB netip.Addr, controlA, controlB *nebula.Control, r *router.R) {
// Send a packet from them to me // Send a packet from them to me
controlB.InjectTunUDPPacket(vpnIpA, 80, 90, []byte("Hi from B")) controlB.InjectTunUDPPacket(vpnIpA, 80, vpnIpB, 90, []byte("Hi from B"))
bPacket := r.RouteForAllUntilTxTun(controlA) bPacket := r.RouteForAllUntilTxTun(controlA)
assertUdpPacket(t, []byte("Hi from B"), bPacket, vpnIpB, vpnIpA, 90, 80) assertUdpPacket(t, []byte("Hi from B"), bPacket, vpnIpB, vpnIpA, 90, 80)
// And once more from me to them // And once more from me to them
controlA.InjectTunUDPPacket(vpnIpB, 80, 90, []byte("Hello from A")) controlA.InjectTunUDPPacket(vpnIpB, 80, vpnIpA, 90, []byte("Hello from A"))
aPacket := r.RouteForAllUntilTxTun(controlB) aPacket := r.RouteForAllUntilTxTun(controlB)
assertUdpPacket(t, []byte("Hello from A"), aPacket, vpnIpA, vpnIpB, 90, 80) assertUdpPacket(t, []byte("Hello from A"), aPacket, vpnIpA, vpnIpB, 90, 80)
} }
func assertHostInfoPair(t *testing.T, addrA, addrB netip.AddrPort, vpnIpA, vpnIpB netip.Addr, controlA, controlB *nebula.Control) { func assertHostInfoPair(t *testing.T, addrA, addrB netip.AddrPort, vpnNetsA, vpnNetsB []netip.Prefix, controlA, controlB *nebula.Control) {
// Get both host infos // Get both host infos
hBinA := controlA.GetHostInfoByVpnIp(vpnIpB, false) //TODO: CERT-V2 we may want to loop over each vpnAddr and assert all the things
assert.NotNil(t, hBinA, "Host B was not found by vpnIp in controlA") hBinA := controlA.GetHostInfoByVpnAddr(vpnNetsB[0].Addr(), false)
assert.NotNil(t, hBinA, "Host B was not found by vpnAddr in controlA")
hAinB := controlB.GetHostInfoByVpnIp(vpnIpA, false) hAinB := controlB.GetHostInfoByVpnAddr(vpnNetsA[0].Addr(), false)
assert.NotNil(t, hAinB, "Host A was not found by vpnIp in controlB") assert.NotNil(t, hAinB, "Host A was not found by vpnAddr in controlB")
// Check that both vpn and real addr are correct // Check that both vpn and real addr are correct
assert.Equal(t, vpnIpB, hBinA.VpnIp, "Host B VpnIp is wrong in control A") assert.EqualValues(t, getAddrs(vpnNetsB), hBinA.VpnAddrs, "Host B VpnIp is wrong in control A")
assert.Equal(t, vpnIpA, hAinB.VpnIp, "Host A VpnIp is wrong in control B") assert.EqualValues(t, getAddrs(vpnNetsA), hAinB.VpnAddrs, "Host A VpnIp is wrong in control B")
assert.Equal(t, addrB, hBinA.CurrentRemote, "Host B remote is wrong in control A") assert.Equal(t, addrB, hBinA.CurrentRemote, "Host B remote is wrong in control A")
assert.Equal(t, addrA, hAinB.CurrentRemote, "Host A remote is wrong in control B") assert.Equal(t, addrA, hAinB.CurrentRemote, "Host A remote is wrong in control B")
@ -160,25 +178,36 @@ func assertHostInfoPair(t *testing.T, addrA, addrB netip.AddrPort, vpnIpA, vpnIp
// Check that our indexes match // Check that our indexes match
assert.Equal(t, hBinA.LocalIndex, hAinB.RemoteIndex, "Host B local index does not match host A remote index") assert.Equal(t, hBinA.LocalIndex, hAinB.RemoteIndex, "Host B local index does not match host A remote index")
assert.Equal(t, hBinA.RemoteIndex, hAinB.LocalIndex, "Host B remote index does not match host A local index") assert.Equal(t, hBinA.RemoteIndex, hAinB.LocalIndex, "Host B remote index does not match host A local index")
//TODO: Would be nice to assert this memory
//checkIndexes := func(name string, hm *HostMap, hi *HostInfo) {
// hBbyIndex := hmA.Indexes[hBinA.localIndexId]
// assert.NotNil(t, hBbyIndex, "Could not host info by local index in %s", name)
// assert.Equal(t, &hBbyIndex, &hBinA, "%s Indexes map did not point to the right host info", name)
//
// //TODO: remote indexes are susceptible to collision
// hBbyRemoteIndex := hmA.RemoteIndexes[hBinA.remoteIndexId]
// assert.NotNil(t, hBbyIndex, "Could not host info by remote index in %s", name)
// assert.Equal(t, &hBbyRemoteIndex, &hBinA, "%s RemoteIndexes did not point to the right host info", name)
//}
//
//// Check hostmap indexes too
//checkIndexes("hmA", hmA, hBinA)
//checkIndexes("hmB", hmB, hAinB)
} }
func assertUdpPacket(t *testing.T, expected, b []byte, fromIp, toIp netip.Addr, fromPort, toPort uint16) { func assertUdpPacket(t *testing.T, expected, b []byte, fromIp, toIp netip.Addr, fromPort, toPort uint16) {
if toIp.Is6() {
assertUdpPacket6(t, expected, b, fromIp, toIp, fromPort, toPort)
} else {
assertUdpPacket4(t, expected, b, fromIp, toIp, fromPort, toPort)
}
}
func assertUdpPacket6(t *testing.T, expected, b []byte, fromIp, toIp netip.Addr, fromPort, toPort uint16) {
packet := gopacket.NewPacket(b, layers.LayerTypeIPv6, gopacket.Lazy)
v6 := packet.Layer(layers.LayerTypeIPv6).(*layers.IPv6)
assert.NotNil(t, v6, "No ipv6 data found")
assert.Equal(t, fromIp.AsSlice(), []byte(v6.SrcIP), "Source ip was incorrect")
assert.Equal(t, toIp.AsSlice(), []byte(v6.DstIP), "Dest ip was incorrect")
udp := packet.Layer(layers.LayerTypeUDP).(*layers.UDP)
assert.NotNil(t, udp, "No udp data found")
assert.Equal(t, fromPort, uint16(udp.SrcPort), "Source port was incorrect")
assert.Equal(t, toPort, uint16(udp.DstPort), "Dest port was incorrect")
data := packet.ApplicationLayer()
assert.NotNil(t, data)
assert.Equal(t, expected, data.Payload(), "Data was incorrect")
}
func assertUdpPacket4(t *testing.T, expected, b []byte, fromIp, toIp netip.Addr, fromPort, toPort uint16) {
packet := gopacket.NewPacket(b, layers.LayerTypeIPv4, gopacket.Lazy) packet := gopacket.NewPacket(b, layers.LayerTypeIPv4, gopacket.Lazy)
v4 := packet.Layer(layers.LayerTypeIPv4).(*layers.IPv4) v4 := packet.Layer(layers.LayerTypeIPv4).(*layers.IPv4)
assert.NotNil(t, v4, "No ipv4 data found") assert.NotNil(t, v4, "No ipv4 data found")
@ -197,6 +226,14 @@ func assertUdpPacket(t *testing.T, expected, b []byte, fromIp, toIp netip.Addr,
assert.Equal(t, expected, data.Payload(), "Data was incorrect") assert.Equal(t, expected, data.Payload(), "Data was incorrect")
} }
func getAddrs(ns []netip.Prefix) []netip.Addr {
var a []netip.Addr
for _, n := range ns {
a = append(a, n.Addr())
}
return a
}
func NewTestLogger() *logrus.Logger { func NewTestLogger() *logrus.Logger {
l := logrus.New() l := logrus.New()

View File

@ -58,8 +58,9 @@ func renderHostmap(c *nebula.Control) (string, []*edge) {
var lines []string var lines []string
var globalLines []*edge var globalLines []*edge
clusterName := strings.Trim(c.GetCert().Name(), " ") crt := c.GetCertState().GetDefaultCertificate()
clusterVpnIp := c.GetCert().Networks()[0].Addr() clusterName := strings.Trim(crt.Name(), " ")
clusterVpnIp := crt.Networks()[0].Addr()
r := fmt.Sprintf("\tsubgraph %s[\"%s (%s)\"]\n", clusterName, clusterName, clusterVpnIp) r := fmt.Sprintf("\tsubgraph %s[\"%s (%s)\"]\n", clusterName, clusterName, clusterVpnIp)
hm := c.GetHostmap() hm := c.GetHostmap()
@ -101,7 +102,7 @@ func renderHostmap(c *nebula.Control) (string, []*edge) {
for _, idx := range indexes { for _, idx := range indexes {
hi, ok := hm.Indexes[idx] hi, ok := hm.Indexes[idx]
if ok { if ok {
r += fmt.Sprintf("\t\t\t%v.%v[\"%v (%v)\"]\n", clusterName, idx, idx, hi.GetVpnIp()) r += fmt.Sprintf("\t\t\t%v.%v[\"%v (%v)\"]\n", clusterName, idx, idx, hi.GetVpnAddrs())
remoteClusterName := strings.Trim(hi.GetCert().Certificate.Name(), " ") remoteClusterName := strings.Trim(hi.GetCert().Certificate.Name(), " ")
globalLines = append(globalLines, &edge{from: fmt.Sprintf("%v.%v", clusterName, idx), to: fmt.Sprintf("%v.%v", remoteClusterName, hi.GetRemoteIndex())}) globalLines = append(globalLines, &edge{from: fmt.Sprintf("%v.%v", clusterName, idx), to: fmt.Sprintf("%v.%v", remoteClusterName, hi.GetRemoteIndex())})
_ = hi _ = hi

View File

@ -10,8 +10,8 @@ import (
"os" "os"
"path/filepath" "path/filepath"
"reflect" "reflect"
"regexp"
"sort" "sort"
"strings"
"sync" "sync"
"testing" "testing"
"time" "time"
@ -136,7 +136,10 @@ func NewR(t testing.TB, controls ...*nebula.Control) *R {
panic("Duplicate listen address: " + addr.String()) panic("Duplicate listen address: " + addr.String())
} }
r.vpnControls[c.GetVpnIp()] = c for _, vpnAddr := range c.GetVpnAddrs() {
r.vpnControls[vpnAddr] = c
}
r.controls[addr] = c r.controls[addr] = c
} }
@ -213,11 +216,11 @@ func (r *R) renderFlow() {
continue continue
} }
participants[addr] = struct{}{} participants[addr] = struct{}{}
sanAddr := strings.Replace(addr.String(), ":", "-", 1) sanAddr := normalizeName(addr.String())
participantsVals = append(participantsVals, sanAddr) participantsVals = append(participantsVals, sanAddr)
fmt.Fprintf( fmt.Fprintf(
f, " participant %s as Nebula: %s<br/>UDP: %s\n", f, " participant %s as Nebula: %s<br/>UDP: %s\n",
sanAddr, e.packet.from.GetVpnIp(), sanAddr, sanAddr, e.packet.from.GetVpnAddrs(), sanAddr,
) )
} }
@ -250,9 +253,9 @@ func (r *R) renderFlow() {
fmt.Fprintf(f, fmt.Fprintf(f,
" %s%s%s: %s(%s), index %v, counter: %v\n", " %s%s%s: %s(%s), index %v, counter: %v\n",
strings.Replace(p.from.GetUDPAddr().String(), ":", "-", 1), normalizeName(p.from.GetUDPAddr().String()),
line, line,
strings.Replace(p.to.GetUDPAddr().String(), ":", "-", 1), normalizeName(p.to.GetUDPAddr().String()),
h.TypeName(), h.SubTypeName(), h.RemoteIndex, h.MessageCounter, h.TypeName(), h.SubTypeName(), h.RemoteIndex, h.MessageCounter,
) )
} }
@ -267,6 +270,11 @@ func (r *R) renderFlow() {
} }
} }
func normalizeName(s string) string {
rx := regexp.MustCompile("[\\[\\]\\:]")
return rx.ReplaceAllLiteralString(s, "_")
}
// IgnoreFlow tells the router to stop recording future flows that matches the provided criteria. // IgnoreFlow tells the router to stop recording future flows that matches the provided criteria.
// messageType and subType will target nebula underlay packets while tun will target nebula overlay packets // messageType and subType will target nebula underlay packets while tun will target nebula overlay packets
// NOTE: This is a very broad system, if you set tun to true then no more tun traffic will be rendered // NOTE: This is a very broad system, if you set tun to true then no more tun traffic will be rendered
@ -303,7 +311,7 @@ func (r *R) RenderHostmaps(title string, controls ...*nebula.Control) {
func (r *R) renderHostmaps(title string) { func (r *R) renderHostmaps(title string) {
c := maps.Values(r.controls) c := maps.Values(r.controls)
sort.SliceStable(c, func(i, j int) bool { sort.SliceStable(c, func(i, j int) bool {
return c[i].GetVpnIp().Compare(c[j].GetVpnIp()) > 0 return c[i].GetVpnAddrs()[0].Compare(c[j].GetVpnAddrs()[0]) > 0
}) })
s := renderHostmaps(c...) s := renderHostmaps(c...)
@ -419,10 +427,11 @@ func (r *R) RouteUntilTxTun(sender *nebula.Control, receiver *nebula.Control) []
// Nope, lets push the sender along // Nope, lets push the sender along
case p := <-udpTx: case p := <-udpTx:
r.Lock() r.Lock()
c := r.getControl(sender.GetUDPAddr(), p.To, p) a := sender.GetUDPAddr()
c := r.getControl(a, p.To, p)
if c == nil { if c == nil {
r.Unlock() r.Unlock()
panic("No control for udp tx") panic("No control for udp tx " + a.String())
} }
fp := r.unlockedInjectFlow(sender, c, p, false) fp := r.unlockedInjectFlow(sender, c, p, false)
c.InjectUDPPacket(p) c.InjectUDPPacket(p)
@ -475,10 +484,11 @@ func (r *R) RouteForAllUntilTxTun(receiver *nebula.Control) []byte {
} else { } else {
// we are a udp tx, route and continue // we are a udp tx, route and continue
p := rx.Interface().(*udp.Packet) p := rx.Interface().(*udp.Packet)
c := r.getControl(cm[x].GetUDPAddr(), p.To, p) a := cm[x].GetUDPAddr()
c := r.getControl(a, p.To, p)
if c == nil { if c == nil {
r.Unlock() r.Unlock()
panic("No control for udp tx") panic(fmt.Sprintf("No control for udp tx %s", p.To))
} }
fp := r.unlockedInjectFlow(cm[x], c, p, false) fp := r.unlockedInjectFlow(cm[x], c, p, false)
c.InjectUDPPacket(p) c.InjectUDPPacket(p)
@ -711,30 +721,42 @@ func (r *R) getControl(fromAddr, toAddr netip.AddrPort, p *udp.Packet) *nebula.C
} }
func (r *R) formatUdpPacket(p *packet) string { func (r *R) formatUdpPacket(p *packet) string {
packet := gopacket.NewPacket(p.packet.Data, layers.LayerTypeIPv4, gopacket.Lazy) var packet gopacket.Packet
v4 := packet.Layer(layers.LayerTypeIPv4).(*layers.IPv4) var srcAddr netip.Addr
if v4 == nil {
panic("not an ipv4 packet") packet = gopacket.NewPacket(p.packet.Data, layers.LayerTypeIPv6, gopacket.Lazy)
if packet.ErrorLayer() == nil {
v6 := packet.Layer(layers.LayerTypeIPv6).(*layers.IPv6)
if v6 == nil {
panic("not an ipv6 packet")
}
srcAddr, _ = netip.AddrFromSlice(v6.SrcIP)
} else {
packet = gopacket.NewPacket(p.packet.Data, layers.LayerTypeIPv4, gopacket.Lazy)
v6 := packet.Layer(layers.LayerTypeIPv4).(*layers.IPv4)
if v6 == nil {
panic("not an ipv6 packet")
}
srcAddr, _ = netip.AddrFromSlice(v6.SrcIP)
} }
from := "unknown" from := "unknown"
srcAddr, _ := netip.AddrFromSlice(v4.SrcIP)
if c, ok := r.vpnControls[srcAddr]; ok { if c, ok := r.vpnControls[srcAddr]; ok {
from = c.GetUDPAddr().String() from = c.GetUDPAddr().String()
} }
udp := packet.Layer(layers.LayerTypeUDP).(*layers.UDP) udpLayer := packet.Layer(layers.LayerTypeUDP).(*layers.UDP)
if udp == nil { if udpLayer == nil {
panic("not a udp packet") panic("not a udp packet")
} }
data := packet.ApplicationLayer() data := packet.ApplicationLayer()
return fmt.Sprintf( return fmt.Sprintf(
" %s-->>%s: src port: %v<br/>dest port: %v<br/>data: \"%v\"\n", " %s-->>%s: src port: %v<br/>dest port: %v<br/>data: \"%v\"\n",
strings.Replace(from, ":", "-", 1), normalizeName(from),
strings.Replace(p.to.GetUDPAddr().String(), ":", "-", 1), normalizeName(p.to.GetUDPAddr().String()),
udp.SrcPort, udpLayer.SrcPort,
udp.DstPort, udpLayer.DstPort,
string(data.Payload()), string(data.Payload()),
) )
} }

View File

@ -13,6 +13,12 @@ pki:
# disconnect_invalid is a toggle to force a client to be disconnected if the certificate is expired or invalid. # disconnect_invalid is a toggle to force a client to be disconnected if the certificate is expired or invalid.
#disconnect_invalid: true #disconnect_invalid: true
# default_version controls which certificate version is used in handshakes.
# This setting only applies if both a v1 and a v2 certificate are configured, in which case it will default to `1`.
# Once all hosts in the mesh are configured with both a v1 and v2 certificate then this should be changed to `2`.
# After all hosts in the mesh are using a v2 certificate then v1 certificates are no longer needed.
# default_version: 1
# The static host map defines a set of hosts with fixed IP addresses on the internet (or any network). # The static host map defines a set of hosts with fixed IP addresses on the internet (or any network).
# A host can have multiple fixed IP addresses defined here, and nebula will try each when establishing a tunnel. # A host can have multiple fixed IP addresses defined here, and nebula will try each when establishing a tunnel.
# The syntax is: # The syntax is:
@ -244,7 +250,6 @@ tun:
# in nebula configuration files. Default false, not reloadable. # in nebula configuration files. Default false, not reloadable.
#use_system_route_table: false #use_system_route_table: false
# TODO
# Configure logging level # Configure logging level
logging: logging:
# panic, fatal, error, warning, info, or debug. Default is info and is reloadable. # panic, fatal, error, warning, info, or debug. Default is info and is reloadable.
@ -336,10 +341,12 @@ firewall:
# host: `any` or a literal hostname, ie `test-host` # host: `any` or a literal hostname, ie `test-host`
# group: `any` or a literal group name, ie `default-group` # group: `any` or a literal group name, ie `default-group`
# groups: Same as group but accepts a list of values. Multiple values are AND'd together and a certificate would have to contain all groups to pass # groups: Same as group but accepts a list of values. Multiple values are AND'd together and a certificate would have to contain all groups to pass
# cidr: a remote CIDR, `0.0.0.0/0` is any. # cidr: a remote CIDR, `0.0.0.0/0` is any ipv4 and `::/0` is any ipv6.
# local_cidr: a local CIDR, `0.0.0.0/0` is any. This could be used to filter destinations when using unsafe_routes. # local_cidr: a local CIDR, `0.0.0.0/0` is any ipv4 and `::/0` is any ipv6. This could be used to filter destinations when using unsafe_routes.
# Default is `any` unless the certificate contains subnets and then the default is the ip issued in the certificate # If no unsafe networks are present in the certificate(s) or `default_local_cidr_any` is true then the default is any ipv4 or ipv6 network.
# if `default_local_cidr_any` is false, otherwise its `any`. # Otherwise the default is any vpn network assigned to via the certificate.
# `default_local_cidr_any` defaults to false and is deprecated, it will be removed in a future release.
# If there are unsafe routes present its best to set `local_cidr` to whatever best fits the situation.
# ca_name: An issuing CA name # ca_name: An issuing CA name
# ca_sha: An issuing CA shasum # ca_sha: An issuing CA shasum

View File

@ -22,7 +22,7 @@ import (
) )
type FirewallInterface interface { type FirewallInterface interface {
AddRule(incoming bool, proto uint8, startPort int32, endPort int32, groups []string, host string, ip, localIp netip.Prefix, caName string, caSha string) error AddRule(incoming bool, proto uint8, startPort int32, endPort int32, groups []string, host string, addr, localAddr netip.Prefix, caName string, caSha string) error
} }
type conn struct { type conn struct {
@ -51,9 +51,12 @@ type Firewall struct {
UDPTimeout time.Duration //linux: 180s max UDPTimeout time.Duration //linux: 180s max
DefaultTimeout time.Duration //linux: 600s DefaultTimeout time.Duration //linux: 600s
// Used to ensure we don't emit local packets for ips we don't own // routableNetworks describes the vpn addresses as well as any unsafe networks issued to us in the certificate.
localIps *bart.Table[struct{}] // The vpn addresses are a full bit match while the unsafe networks only match the prefix
assignedCIDR netip.Prefix routableNetworks *bart.Table[struct{}]
// assignedNetworks is a list of vpn networks assigned to us in the certificate.
assignedNetworks []netip.Prefix
hasUnsafeNetworks bool hasUnsafeNetworks bool
rules string rules string
@ -67,9 +70,9 @@ type Firewall struct {
} }
type firewallMetrics struct { type firewallMetrics struct {
droppedLocalIP metrics.Counter droppedLocalAddr metrics.Counter
droppedRemoteIP metrics.Counter droppedRemoteAddr metrics.Counter
droppedNoRule metrics.Counter droppedNoRule metrics.Counter
} }
type FirewallConntrack struct { type FirewallConntrack struct {
@ -126,84 +129,87 @@ type firewallLocalCIDR struct {
} }
// NewFirewall creates a new Firewall object. A TimerWheel is created for you from the provided timeouts. // NewFirewall creates a new Firewall object. A TimerWheel is created for you from the provided timeouts.
// The certificate provided should be the highest version loaded in memory.
func NewFirewall(l *logrus.Logger, tcpTimeout, UDPTimeout, defaultTimeout time.Duration, c cert.Certificate) *Firewall { func NewFirewall(l *logrus.Logger, tcpTimeout, UDPTimeout, defaultTimeout time.Duration, c cert.Certificate) *Firewall {
//TODO: error on 0 duration //TODO: error on 0 duration
var min, max time.Duration var tmin, tmax time.Duration
if tcpTimeout < UDPTimeout { if tcpTimeout < UDPTimeout {
min = tcpTimeout tmin = tcpTimeout
max = UDPTimeout tmax = UDPTimeout
} else { } else {
min = UDPTimeout tmin = UDPTimeout
max = tcpTimeout tmax = tcpTimeout
} }
if defaultTimeout < min { if defaultTimeout < tmin {
min = defaultTimeout tmin = defaultTimeout
} else if defaultTimeout > max { } else if defaultTimeout > tmax {
max = defaultTimeout tmax = defaultTimeout
} }
localIps := new(bart.Table[struct{}]) routableNetworks := new(bart.Table[struct{}])
var assignedCIDR netip.Prefix var assignedNetworks []netip.Prefix
var assignedSet bool
for _, network := range c.Networks() { for _, network := range c.Networks() {
nprefix := netip.PrefixFrom(network.Addr(), network.Addr().BitLen()) nprefix := netip.PrefixFrom(network.Addr(), network.Addr().BitLen())
localIps.Insert(nprefix, struct{}{}) routableNetworks.Insert(nprefix, struct{}{})
assignedNetworks = append(assignedNetworks, network)
if !assignedSet {
// Only grabbing the first one in the cert since any more than that currently has undefined behavior
assignedCIDR = nprefix
assignedSet = true
}
} }
hasUnsafeNetworks := false hasUnsafeNetworks := false
for _, n := range c.UnsafeNetworks() { for _, n := range c.UnsafeNetworks() {
localIps.Insert(n, struct{}{}) routableNetworks.Insert(n, struct{}{})
hasUnsafeNetworks = true hasUnsafeNetworks = true
} }
return &Firewall{ return &Firewall{
Conntrack: &FirewallConntrack{ Conntrack: &FirewallConntrack{
Conns: make(map[firewall.Packet]*conn), Conns: make(map[firewall.Packet]*conn),
TimerWheel: NewTimerWheel[firewall.Packet](min, max), TimerWheel: NewTimerWheel[firewall.Packet](tmin, tmax),
}, },
InRules: newFirewallTable(), InRules: newFirewallTable(),
OutRules: newFirewallTable(), OutRules: newFirewallTable(),
TCPTimeout: tcpTimeout, TCPTimeout: tcpTimeout,
UDPTimeout: UDPTimeout, UDPTimeout: UDPTimeout,
DefaultTimeout: defaultTimeout, DefaultTimeout: defaultTimeout,
localIps: localIps, routableNetworks: routableNetworks,
assignedCIDR: assignedCIDR, assignedNetworks: assignedNetworks,
hasUnsafeNetworks: hasUnsafeNetworks, hasUnsafeNetworks: hasUnsafeNetworks,
l: l, l: l,
incomingMetrics: firewallMetrics{ incomingMetrics: firewallMetrics{
droppedLocalIP: metrics.GetOrRegisterCounter("firewall.incoming.dropped.local_ip", nil), droppedLocalAddr: metrics.GetOrRegisterCounter("firewall.incoming.dropped.local_addr", nil),
droppedRemoteIP: metrics.GetOrRegisterCounter("firewall.incoming.dropped.remote_ip", nil), droppedRemoteAddr: metrics.GetOrRegisterCounter("firewall.incoming.dropped.remote_addr", nil),
droppedNoRule: metrics.GetOrRegisterCounter("firewall.incoming.dropped.no_rule", nil), droppedNoRule: metrics.GetOrRegisterCounter("firewall.incoming.dropped.no_rule", nil),
}, },
outgoingMetrics: firewallMetrics{ outgoingMetrics: firewallMetrics{
droppedLocalIP: metrics.GetOrRegisterCounter("firewall.outgoing.dropped.local_ip", nil), droppedLocalAddr: metrics.GetOrRegisterCounter("firewall.outgoing.dropped.local_addr", nil),
droppedRemoteIP: metrics.GetOrRegisterCounter("firewall.outgoing.dropped.remote_ip", nil), droppedRemoteAddr: metrics.GetOrRegisterCounter("firewall.outgoing.dropped.remote_addr", nil),
droppedNoRule: metrics.GetOrRegisterCounter("firewall.outgoing.dropped.no_rule", nil), droppedNoRule: metrics.GetOrRegisterCounter("firewall.outgoing.dropped.no_rule", nil),
}, },
} }
} }
func NewFirewallFromConfig(l *logrus.Logger, nc cert.Certificate, c *config.C) (*Firewall, error) { func NewFirewallFromConfig(l *logrus.Logger, cs *CertState, c *config.C) (*Firewall, error) {
certificate := cs.getCertificate(cert.Version2)
if certificate == nil {
certificate = cs.getCertificate(cert.Version1)
}
if certificate == nil {
panic("No certificate available to reconfigure the firewall")
}
fw := NewFirewall( fw := NewFirewall(
l, l,
c.GetDuration("firewall.conntrack.tcp_timeout", time.Minute*12), c.GetDuration("firewall.conntrack.tcp_timeout", time.Minute*12),
c.GetDuration("firewall.conntrack.udp_timeout", time.Minute*3), c.GetDuration("firewall.conntrack.udp_timeout", time.Minute*3),
c.GetDuration("firewall.conntrack.default_timeout", time.Minute*10), c.GetDuration("firewall.conntrack.default_timeout", time.Minute*10),
nc, certificate,
//TODO: max_connections //TODO: max_connections
) )
//TODO: Flip to false after v1.9 release fw.defaultLocalCIDRAny = c.GetBool("firewall.default_local_cidr_any", false)
fw.defaultLocalCIDRAny = c.GetBool("firewall.default_local_cidr_any", true)
inboundAction := c.GetString("firewall.inbound_action", "drop") inboundAction := c.GetString("firewall.inbound_action", "drop")
switch inboundAction { switch inboundAction {
@ -283,7 +289,7 @@ func (f *Firewall) AddRule(incoming bool, proto uint8, startPort int32, endPort
fp = ft.TCP fp = ft.TCP
case firewall.ProtoUDP: case firewall.ProtoUDP:
fp = ft.UDP fp = ft.UDP
case firewall.ProtoICMP: case firewall.ProtoICMP, firewall.ProtoICMPv6:
fp = ft.ICMP fp = ft.ICMP
case firewall.ProtoAny: case firewall.ProtoAny:
fp = ft.AnyProto fp = ft.AnyProto
@ -424,26 +430,24 @@ func (f *Firewall) Drop(fp firewall.Packet, incoming bool, h *HostInfo, caPool *
} }
// Make sure remote address matches nebula certificate // Make sure remote address matches nebula certificate
if remoteCidr := h.remoteCidr; remoteCidr != nil { if h.networks != nil {
//TODO: this would be better if we had a least specific match lookup, could waste time here, need to benchmark since the algo is different _, ok := h.networks.Lookup(fp.RemoteAddr)
_, ok := remoteCidr.Lookup(fp.RemoteIP)
if !ok { if !ok {
f.metrics(incoming).droppedRemoteIP.Inc(1) f.metrics(incoming).droppedRemoteAddr.Inc(1)
return ErrInvalidRemoteIP return ErrInvalidRemoteIP
} }
} else { } else {
// Simple case: Certificate has one IP and no subnets // Simple case: Certificate has one address and no unsafe networks
if fp.RemoteIP != h.vpnIp { if h.vpnAddrs[0] != fp.RemoteAddr {
f.metrics(incoming).droppedRemoteIP.Inc(1) f.metrics(incoming).droppedRemoteAddr.Inc(1)
return ErrInvalidRemoteIP return ErrInvalidRemoteIP
} }
} }
// Make sure we are supposed to be handling this local ip address // Make sure we are supposed to be handling this local ip address
//TODO: this would be better if we had a least specific match lookup, could waste time here, need to benchmark since the algo is different _, ok := f.routableNetworks.Lookup(fp.LocalAddr)
_, ok := f.localIps.Lookup(fp.LocalIP)
if !ok { if !ok {
f.metrics(incoming).droppedLocalIP.Inc(1) f.metrics(incoming).droppedLocalAddr.Inc(1)
return ErrInvalidLocalIP return ErrInvalidLocalIP
} }
@ -629,7 +633,7 @@ func (ft *FirewallTable) match(p firewall.Packet, incoming bool, c *cert.CachedC
if ft.UDP.match(p, incoming, c, caPool) { if ft.UDP.match(p, incoming, c, caPool) {
return true return true
} }
case firewall.ProtoICMP: case firewall.ProtoICMP, firewall.ProtoICMPv6:
if ft.ICMP.match(p, incoming, c, caPool) { if ft.ICMP.match(p, incoming, c, caPool) {
return true return true
} }
@ -859,9 +863,9 @@ func (fr *FirewallRule) match(p firewall.Packet, c *cert.CachedCertificate) bool
} }
matched := false matched := false
prefix := netip.PrefixFrom(p.RemoteIP, p.RemoteIP.BitLen()) prefix := netip.PrefixFrom(p.RemoteAddr, p.RemoteAddr.BitLen())
fr.CIDR.EachLookupPrefix(prefix, func(prefix netip.Prefix, val *firewallLocalCIDR) bool { fr.CIDR.EachLookupPrefix(prefix, func(prefix netip.Prefix, val *firewallLocalCIDR) bool {
if prefix.Contains(p.RemoteIP) && val.match(p, c) { if prefix.Contains(p.RemoteAddr) && val.match(p, c) {
matched = true matched = true
return false return false
} }
@ -877,9 +881,14 @@ func (flc *firewallLocalCIDR) addRule(f *Firewall, localIp netip.Prefix) error {
return nil return nil
} }
localIp = f.assignedCIDR for _, network := range f.assignedNetworks {
flc.LocalCIDR.Insert(network, struct{}{})
}
return nil
} else if localIp.Bits() == 0 { } else if localIp.Bits() == 0 {
flc.Any = true flc.Any = true
return nil
} }
flc.LocalCIDR.Insert(localIp, struct{}{}) flc.LocalCIDR.Insert(localIp, struct{}{})
@ -895,7 +904,7 @@ func (flc *firewallLocalCIDR) match(p firewall.Packet, c *cert.CachedCertificate
return true return true
} }
_, ok := flc.LocalCIDR.Lookup(p.LocalIP) _, ok := flc.LocalCIDR.Lookup(p.LocalAddr)
return ok return ok
} }

View File

@ -9,18 +9,19 @@ import (
type m map[string]interface{} type m map[string]interface{}
const ( const (
ProtoAny = 0 // When we want to handle HOPOPT (0) we can change this, if ever ProtoAny = 0 // When we want to handle HOPOPT (0) we can change this, if ever
ProtoTCP = 6 ProtoTCP = 6
ProtoUDP = 17 ProtoUDP = 17
ProtoICMP = 1 ProtoICMP = 1
ProtoICMPv6 = 58
PortAny = 0 // Special value for matching `port: any` PortAny = 0 // Special value for matching `port: any`
PortFragment = -1 // Special value for matching `port: fragment` PortFragment = -1 // Special value for matching `port: fragment`
) )
type Packet struct { type Packet struct {
LocalIP netip.Addr LocalAddr netip.Addr
RemoteIP netip.Addr RemoteAddr netip.Addr
LocalPort uint16 LocalPort uint16
RemotePort uint16 RemotePort uint16
Protocol uint8 Protocol uint8
@ -29,8 +30,8 @@ type Packet struct {
func (fp *Packet) Copy() *Packet { func (fp *Packet) Copy() *Packet {
return &Packet{ return &Packet{
LocalIP: fp.LocalIP, LocalAddr: fp.LocalAddr,
RemoteIP: fp.RemoteIP, RemoteAddr: fp.RemoteAddr,
LocalPort: fp.LocalPort, LocalPort: fp.LocalPort,
RemotePort: fp.RemotePort, RemotePort: fp.RemotePort,
Protocol: fp.Protocol, Protocol: fp.Protocol,
@ -51,8 +52,8 @@ func (fp Packet) MarshalJSON() ([]byte, error) {
proto = fmt.Sprintf("unknown %v", fp.Protocol) proto = fmt.Sprintf("unknown %v", fp.Protocol)
} }
return json.Marshal(m{ return json.Marshal(m{
"LocalIP": fp.LocalIP.String(), "LocalAddr": fp.LocalAddr.String(),
"RemoteIP": fp.RemoteIP.String(), "RemoteAddr": fp.RemoteAddr.String(),
"LocalPort": fp.LocalPort, "LocalPort": fp.LocalPort,
"RemotePort": fp.RemotePort, "RemotePort": fp.RemotePort,
"Protocol": proto, "Protocol": proto,

View File

@ -13,6 +13,7 @@ import (
"github.com/slackhq/nebula/firewall" "github.com/slackhq/nebula/firewall"
"github.com/slackhq/nebula/test" "github.com/slackhq/nebula/test"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
) )
func TestNewFirewall(t *testing.T) { func TestNewFirewall(t *testing.T) {
@ -128,8 +129,8 @@ func TestFirewall_Drop(t *testing.T) {
l.SetOutput(ob) l.SetOutput(ob)
p := firewall.Packet{ p := firewall.Packet{
LocalIP: netip.MustParseAddr("1.2.3.4"), LocalAddr: netip.MustParseAddr("1.2.3.4"),
RemoteIP: netip.MustParseAddr("1.2.3.4"), RemoteAddr: netip.MustParseAddr("1.2.3.4"),
LocalPort: 10, LocalPort: 10,
RemotePort: 90, RemotePort: 90,
Protocol: firewall.ProtoUDP, Protocol: firewall.ProtoUDP,
@ -149,9 +150,9 @@ func TestFirewall_Drop(t *testing.T) {
InvertedGroups: map[string]struct{}{"default-group": {}}, InvertedGroups: map[string]struct{}{"default-group": {}},
}, },
}, },
vpnIp: netip.MustParseAddr("1.2.3.4"), vpnAddrs: []netip.Addr{netip.MustParseAddr("1.2.3.4")},
} }
h.CreateRemoteCIDR(&c) h.buildNetworks(c.networks, c.unsafeNetworks)
fw := NewFirewall(l, time.Second, time.Minute, time.Hour, &c) fw := NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"any"}, "", netip.Prefix{}, netip.Prefix{}, "", "")) assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"any"}, "", netip.Prefix{}, netip.Prefix{}, "", ""))
@ -166,10 +167,10 @@ func TestFirewall_Drop(t *testing.T) {
assert.NoError(t, fw.Drop(p, false, &h, cp, nil)) assert.NoError(t, fw.Drop(p, false, &h, cp, nil))
// test remote mismatch // test remote mismatch
oldRemote := p.RemoteIP oldRemote := p.RemoteAddr
p.RemoteIP = netip.MustParseAddr("1.2.3.10") p.RemoteAddr = netip.MustParseAddr("1.2.3.10")
assert.Equal(t, fw.Drop(p, false, &h, cp, nil), ErrInvalidRemoteIP) assert.Equal(t, fw.Drop(p, false, &h, cp, nil), ErrInvalidRemoteIP)
p.RemoteIP = oldRemote p.RemoteAddr = oldRemote
// ensure signer doesn't get in the way of group checks // ensure signer doesn't get in the way of group checks
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c) fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
@ -235,7 +236,7 @@ func BenchmarkFirewallTable_match(b *testing.B) {
} }
ip := netip.MustParsePrefix("9.254.254.254/32") ip := netip.MustParsePrefix("9.254.254.254/32")
for n := 0; n < b.N; n++ { for n := 0; n < b.N; n++ {
assert.False(b, ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 100, LocalIP: ip.Addr()}, true, c, cp)) assert.False(b, ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 100, LocalAddr: ip.Addr()}, true, c, cp))
} }
}) })
@ -261,7 +262,7 @@ func BenchmarkFirewallTable_match(b *testing.B) {
InvertedGroups: map[string]struct{}{"nope": {}}, InvertedGroups: map[string]struct{}{"nope": {}},
} }
for n := 0; n < b.N; n++ { for n := 0; n < b.N; n++ {
assert.False(b, ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 100, LocalIP: pfix.Addr()}, true, c, cp)) assert.False(b, ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 100, LocalAddr: pfix.Addr()}, true, c, cp))
} }
}) })
@ -285,7 +286,7 @@ func BenchmarkFirewallTable_match(b *testing.B) {
InvertedGroups: map[string]struct{}{"good-group": {}}, InvertedGroups: map[string]struct{}{"good-group": {}},
} }
for n := 0; n < b.N; n++ { for n := 0; n < b.N; n++ {
assert.True(b, ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 100, LocalIP: pfix.Addr()}, true, c, cp)) assert.True(b, ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 100, LocalAddr: pfix.Addr()}, true, c, cp))
} }
}) })
@ -308,8 +309,8 @@ func TestFirewall_Drop2(t *testing.T) {
l.SetOutput(ob) l.SetOutput(ob)
p := firewall.Packet{ p := firewall.Packet{
LocalIP: netip.MustParseAddr("1.2.3.4"), LocalAddr: netip.MustParseAddr("1.2.3.4"),
RemoteIP: netip.MustParseAddr("1.2.3.4"), RemoteAddr: netip.MustParseAddr("1.2.3.4"),
LocalPort: 10, LocalPort: 10,
RemotePort: 90, RemotePort: 90,
Protocol: firewall.ProtoUDP, Protocol: firewall.ProtoUDP,
@ -329,9 +330,9 @@ func TestFirewall_Drop2(t *testing.T) {
ConnectionState: &ConnectionState{ ConnectionState: &ConnectionState{
peerCert: &c, peerCert: &c,
}, },
vpnIp: network.Addr(), vpnAddrs: []netip.Addr{network.Addr()},
} }
h.CreateRemoteCIDR(c.Certificate) h.buildNetworks(c.Certificate.Networks(), c.Certificate.UnsafeNetworks())
c1 := cert.CachedCertificate{ c1 := cert.CachedCertificate{
Certificate: &dummyCert{ Certificate: &dummyCert{
@ -341,11 +342,12 @@ func TestFirewall_Drop2(t *testing.T) {
InvertedGroups: map[string]struct{}{"default-group": {}, "test-group-not": {}}, InvertedGroups: map[string]struct{}{"default-group": {}, "test-group-not": {}},
} }
h1 := HostInfo{ h1 := HostInfo{
vpnAddrs: []netip.Addr{network.Addr()},
ConnectionState: &ConnectionState{ ConnectionState: &ConnectionState{
peerCert: &c1, peerCert: &c1,
}, },
} }
h1.CreateRemoteCIDR(c1.Certificate) h1.buildNetworks(c1.Certificate.Networks(), c1.Certificate.UnsafeNetworks())
fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate) fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate)
assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group", "test-group"}, "", netip.Prefix{}, netip.Prefix{}, "", "")) assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group", "test-group"}, "", netip.Prefix{}, netip.Prefix{}, "", ""))
@ -364,8 +366,8 @@ func TestFirewall_Drop3(t *testing.T) {
l.SetOutput(ob) l.SetOutput(ob)
p := firewall.Packet{ p := firewall.Packet{
LocalIP: netip.MustParseAddr("1.2.3.4"), LocalAddr: netip.MustParseAddr("1.2.3.4"),
RemoteIP: netip.MustParseAddr("1.2.3.4"), RemoteAddr: netip.MustParseAddr("1.2.3.4"),
LocalPort: 1, LocalPort: 1,
RemotePort: 1, RemotePort: 1,
Protocol: firewall.ProtoUDP, Protocol: firewall.ProtoUDP,
@ -391,9 +393,9 @@ func TestFirewall_Drop3(t *testing.T) {
ConnectionState: &ConnectionState{ ConnectionState: &ConnectionState{
peerCert: &c1, peerCert: &c1,
}, },
vpnIp: network.Addr(), vpnAddrs: []netip.Addr{network.Addr()},
} }
h1.CreateRemoteCIDR(c1.Certificate) h1.buildNetworks(c1.Certificate.Networks(), c1.Certificate.UnsafeNetworks())
c2 := cert.CachedCertificate{ c2 := cert.CachedCertificate{
Certificate: &dummyCert{ Certificate: &dummyCert{
@ -406,9 +408,9 @@ func TestFirewall_Drop3(t *testing.T) {
ConnectionState: &ConnectionState{ ConnectionState: &ConnectionState{
peerCert: &c2, peerCert: &c2,
}, },
vpnIp: network.Addr(), vpnAddrs: []netip.Addr{network.Addr()},
} }
h2.CreateRemoteCIDR(c2.Certificate) h2.buildNetworks(c2.Certificate.Networks(), c2.Certificate.UnsafeNetworks())
c3 := cert.CachedCertificate{ c3 := cert.CachedCertificate{
Certificate: &dummyCert{ Certificate: &dummyCert{
@ -421,9 +423,9 @@ func TestFirewall_Drop3(t *testing.T) {
ConnectionState: &ConnectionState{ ConnectionState: &ConnectionState{
peerCert: &c3, peerCert: &c3,
}, },
vpnIp: network.Addr(), vpnAddrs: []netip.Addr{network.Addr()},
} }
h3.CreateRemoteCIDR(c3.Certificate) h3.buildNetworks(c3.Certificate.Networks(), c3.Certificate.UnsafeNetworks())
fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate) fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate)
assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 1, 1, []string{}, "host1", netip.Prefix{}, netip.Prefix{}, "", "")) assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 1, 1, []string{}, "host1", netip.Prefix{}, netip.Prefix{}, "", ""))
@ -446,8 +448,8 @@ func TestFirewall_DropConntrackReload(t *testing.T) {
l.SetOutput(ob) l.SetOutput(ob)
p := firewall.Packet{ p := firewall.Packet{
LocalIP: netip.MustParseAddr("1.2.3.4"), LocalAddr: netip.MustParseAddr("1.2.3.4"),
RemoteIP: netip.MustParseAddr("1.2.3.4"), RemoteAddr: netip.MustParseAddr("1.2.3.4"),
LocalPort: 10, LocalPort: 10,
RemotePort: 90, RemotePort: 90,
Protocol: firewall.ProtoUDP, Protocol: firewall.ProtoUDP,
@ -468,9 +470,9 @@ func TestFirewall_DropConntrackReload(t *testing.T) {
ConnectionState: &ConnectionState{ ConnectionState: &ConnectionState{
peerCert: &c, peerCert: &c,
}, },
vpnIp: network.Addr(), vpnAddrs: []netip.Addr{network.Addr()},
} }
h.CreateRemoteCIDR(c.Certificate) h.buildNetworks(c.Certificate.Networks(), c.Certificate.UnsafeNetworks())
fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate) fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate)
assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"any"}, "", netip.Prefix{}, netip.Prefix{}, "", "")) assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"any"}, "", netip.Prefix{}, netip.Prefix{}, "", ""))
@ -574,8 +576,6 @@ func BenchmarkLookup(b *testing.B) {
ml(m, a) ml(m, a)
} }
}) })
//TODO: only way array lookup in array will help is if both are sorted, then maybe it's faster
} }
func Test_parsePort(t *testing.T) { func Test_parsePort(t *testing.T) {
@ -622,55 +622,58 @@ func TestNewFirewallFromConfig(t *testing.T) {
l := test.NewLogger() l := test.NewLogger()
// Test a bad rule definition // Test a bad rule definition
c := &dummyCert{} c := &dummyCert{}
cs, err := newCertState(cert.Version2, nil, c, false, cert.Curve_CURVE25519, nil)
require.NoError(t, err)
conf := config.NewC(l) conf := config.NewC(l)
conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": "asdf"} conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": "asdf"}
_, err := NewFirewallFromConfig(l, c, conf) _, err = NewFirewallFromConfig(l, cs, conf)
assert.EqualError(t, err, "firewall.outbound failed to parse, should be an array of rules") assert.EqualError(t, err, "firewall.outbound failed to parse, should be an array of rules")
// Test both port and code // Test both port and code
conf = config.NewC(l) conf = config.NewC(l)
conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"port": "1", "code": "2"}}} conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"port": "1", "code": "2"}}}
_, err = NewFirewallFromConfig(l, c, conf) _, err = NewFirewallFromConfig(l, cs, conf)
assert.EqualError(t, err, "firewall.outbound rule #0; only one of port or code should be provided") assert.EqualError(t, err, "firewall.outbound rule #0; only one of port or code should be provided")
// Test missing host, group, cidr, ca_name and ca_sha // Test missing host, group, cidr, ca_name and ca_sha
conf = config.NewC(l) conf = config.NewC(l)
conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{}}} conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{}}}
_, err = NewFirewallFromConfig(l, c, conf) _, err = NewFirewallFromConfig(l, cs, conf)
assert.EqualError(t, err, "firewall.outbound rule #0; at least one of host, group, cidr, local_cidr, ca_name, or ca_sha must be provided") assert.EqualError(t, err, "firewall.outbound rule #0; at least one of host, group, cidr, local_cidr, ca_name, or ca_sha must be provided")
// Test code/port error // Test code/port error
conf = config.NewC(l) conf = config.NewC(l)
conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"code": "a", "host": "testh"}}} conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"code": "a", "host": "testh"}}}
_, err = NewFirewallFromConfig(l, c, conf) _, err = NewFirewallFromConfig(l, cs, conf)
assert.EqualError(t, err, "firewall.outbound rule #0; code was not a number; `a`") assert.EqualError(t, err, "firewall.outbound rule #0; code was not a number; `a`")
conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"port": "a", "host": "testh"}}} conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"port": "a", "host": "testh"}}}
_, err = NewFirewallFromConfig(l, c, conf) _, err = NewFirewallFromConfig(l, cs, conf)
assert.EqualError(t, err, "firewall.outbound rule #0; port was not a number; `a`") assert.EqualError(t, err, "firewall.outbound rule #0; port was not a number; `a`")
// Test proto error // Test proto error
conf = config.NewC(l) conf = config.NewC(l)
conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"code": "1", "host": "testh"}}} conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"code": "1", "host": "testh"}}}
_, err = NewFirewallFromConfig(l, c, conf) _, err = NewFirewallFromConfig(l, cs, conf)
assert.EqualError(t, err, "firewall.outbound rule #0; proto was not understood; ``") assert.EqualError(t, err, "firewall.outbound rule #0; proto was not understood; ``")
// Test cidr parse error // Test cidr parse error
conf = config.NewC(l) conf = config.NewC(l)
conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"code": "1", "cidr": "testh", "proto": "any"}}} conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"code": "1", "cidr": "testh", "proto": "any"}}}
_, err = NewFirewallFromConfig(l, c, conf) _, err = NewFirewallFromConfig(l, cs, conf)
assert.EqualError(t, err, "firewall.outbound rule #0; cidr did not parse; netip.ParsePrefix(\"testh\"): no '/'") assert.EqualError(t, err, "firewall.outbound rule #0; cidr did not parse; netip.ParsePrefix(\"testh\"): no '/'")
// Test local_cidr parse error // Test local_cidr parse error
conf = config.NewC(l) conf = config.NewC(l)
conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"code": "1", "local_cidr": "testh", "proto": "any"}}} conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"code": "1", "local_cidr": "testh", "proto": "any"}}}
_, err = NewFirewallFromConfig(l, c, conf) _, err = NewFirewallFromConfig(l, cs, conf)
assert.EqualError(t, err, "firewall.outbound rule #0; local_cidr did not parse; netip.ParsePrefix(\"testh\"): no '/'") assert.EqualError(t, err, "firewall.outbound rule #0; local_cidr did not parse; netip.ParsePrefix(\"testh\"): no '/'")
// Test both group and groups // Test both group and groups
conf = config.NewC(l) conf = config.NewC(l)
conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "group": "a", "groups": []string{"b", "c"}}}} conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "group": "a", "groups": []string{"b", "c"}}}}
_, err = NewFirewallFromConfig(l, c, conf) _, err = NewFirewallFromConfig(l, cs, conf)
assert.EqualError(t, err, "firewall.inbound rule #0; only one of group or groups should be defined, both provided") assert.EqualError(t, err, "firewall.inbound rule #0; only one of group or groups should be defined, both provided")
} }

1
go.mod
View File

@ -21,7 +21,6 @@ require (
github.com/rcrowley/go-metrics v0.0.0-20201227073835-cf1acfcdf475 github.com/rcrowley/go-metrics v0.0.0-20201227073835-cf1acfcdf475
github.com/sirupsen/logrus v1.9.3 github.com/sirupsen/logrus v1.9.3
github.com/skip2/go-qrcode v0.0.0-20200617195104-da1b6568686e github.com/skip2/go-qrcode v0.0.0-20200617195104-da1b6568686e
github.com/songgao/water v0.0.0-20200317203138-2b4b6d7c09d8
github.com/stefanberger/go-pkcs11uri v0.0.0-20230803200340-78284954bff6 github.com/stefanberger/go-pkcs11uri v0.0.0-20230803200340-78284954bff6
github.com/stretchr/testify v1.9.0 github.com/stretchr/testify v1.9.0
github.com/vishvananda/netlink v1.3.0 github.com/vishvananda/netlink v1.3.0

2
go.sum
View File

@ -137,8 +137,6 @@ github.com/sirupsen/logrus v1.9.3 h1:dueUQJ1C2q9oE3F7wvmSGAaVtTmUizReu6fjN8uqzbQ
github.com/sirupsen/logrus v1.9.3/go.mod h1:naHLuLoDiP4jHNo9R0sCBMtWGeIprob74mVsIT4qYEQ= github.com/sirupsen/logrus v1.9.3/go.mod h1:naHLuLoDiP4jHNo9R0sCBMtWGeIprob74mVsIT4qYEQ=
github.com/skip2/go-qrcode v0.0.0-20200617195104-da1b6568686e h1:MRM5ITcdelLK2j1vwZ3Je0FKVCfqOLp5zO6trqMLYs0= github.com/skip2/go-qrcode v0.0.0-20200617195104-da1b6568686e h1:MRM5ITcdelLK2j1vwZ3Je0FKVCfqOLp5zO6trqMLYs0=
github.com/skip2/go-qrcode v0.0.0-20200617195104-da1b6568686e/go.mod h1:XV66xRDqSt+GTGFMVlhk3ULuV0y9ZmzeVGR4mloJI3M= github.com/skip2/go-qrcode v0.0.0-20200617195104-da1b6568686e/go.mod h1:XV66xRDqSt+GTGFMVlhk3ULuV0y9ZmzeVGR4mloJI3M=
github.com/songgao/water v0.0.0-20200317203138-2b4b6d7c09d8 h1:TG/diQgUe0pntT/2D9tmUCz4VNwm9MfrtPr0SU2qSX8=
github.com/songgao/water v0.0.0-20200317203138-2b4b6d7c09d8/go.mod h1:P5HUIBuIWKbyjl083/loAegFkfbFNx5i2qEP4CNbm7E=
github.com/stefanberger/go-pkcs11uri v0.0.0-20230803200340-78284954bff6 h1:pnnLyeX7o/5aX8qUQ69P/mLojDqwda8hFOCBTmP/6hw= github.com/stefanberger/go-pkcs11uri v0.0.0-20230803200340-78284954bff6 h1:pnnLyeX7o/5aX8qUQ69P/mLojDqwda8hFOCBTmP/6hw=
github.com/stefanberger/go-pkcs11uri v0.0.0-20230803200340-78284954bff6/go.mod h1:39R/xuhNgVhi+K0/zst4TLrJrVmbm6LVgl4A0+ZFS5M= github.com/stefanberger/go-pkcs11uri v0.0.0-20230803200340-78284954bff6/go.mod h1:39R/xuhNgVhi+K0/zst4TLrJrVmbm6LVgl4A0+ZFS5M=
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=

View File

@ -2,10 +2,12 @@ package nebula
import ( import (
"net/netip" "net/netip"
"slices"
"time" "time"
"github.com/flynn/noise" "github.com/flynn/noise"
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
"github.com/slackhq/nebula/cert"
"github.com/slackhq/nebula/header" "github.com/slackhq/nebula/header"
) )
@ -16,30 +18,60 @@ import (
func ixHandshakeStage0(f *Interface, hh *HandshakeHostInfo) bool { func ixHandshakeStage0(f *Interface, hh *HandshakeHostInfo) bool {
err := f.handshakeManager.allocateIndex(hh) err := f.handshakeManager.allocateIndex(hh)
if err != nil { if err != nil {
f.l.WithError(err).WithField("vpnIp", hh.hostinfo.vpnIp). f.l.WithError(err).WithField("vpnAddrs", hh.hostinfo.vpnAddrs).
WithField("handshake", m{"stage": 0, "style": "ix_psk0"}).Error("Failed to generate index") WithField("handshake", m{"stage": 0, "style": "ix_psk0"}).Error("Failed to generate index")
return false return false
} }
certState := f.pki.GetCertState() // If we're connecting to a v6 address we must use a v2 cert
ci := NewConnectionState(f.l, f.cipher, certState, true, noise.HandshakeIX, []byte{}, 0) cs := f.pki.getCertState()
v := cs.defaultVersion
for _, a := range hh.hostinfo.vpnAddrs {
if a.Is6() {
v = cert.Version2
break
}
}
crt := cs.getCertificate(v)
if crt == nil {
f.l.WithField("vpnAddrs", hh.hostinfo.vpnAddrs).
WithField("handshake", m{"stage": 0, "style": "ix_psk0"}).
WithField("certVersion", v).
Error("Unable to handshake with host because no certificate is available")
return false
}
crtHs := cs.getHandshakeBytes(v)
if crtHs == nil {
f.l.WithField("vpnAddrs", hh.hostinfo.vpnAddrs).
WithField("handshake", m{"stage": 0, "style": "ix_psk0"}).
WithField("certVersion", v).
Error("Unable to handshake with host because no certificate handshake bytes is available")
}
ci, err := NewConnectionState(f.l, cs, crt, true, noise.HandshakeIX)
if err != nil {
f.l.WithError(err).WithField("vpnAddrs", hh.hostinfo.vpnAddrs).
WithField("handshake", m{"stage": 0, "style": "ix_psk0"}).
WithField("certVersion", v).
Error("Failed to create connection state")
return false
}
hh.hostinfo.ConnectionState = ci hh.hostinfo.ConnectionState = ci
hsProto := &NebulaHandshakeDetails{
InitiatorIndex: hh.hostinfo.localIndexId,
Time: uint64(time.Now().UnixNano()),
Cert: certState.RawCertificateNoKey,
}
hsBytes := []byte{}
hs := &NebulaHandshake{ hs := &NebulaHandshake{
Details: hsProto, Details: &NebulaHandshakeDetails{
InitiatorIndex: hh.hostinfo.localIndexId,
Time: uint64(time.Now().UnixNano()),
Cert: crtHs,
CertVersion: uint32(v),
},
} }
hsBytes, err = hs.Marshal()
hsBytes, err := hs.Marshal()
if err != nil { if err != nil {
f.l.WithError(err).WithField("vpnIp", hh.hostinfo.vpnIp). f.l.WithError(err).WithField("vpnAddrs", hh.hostinfo.vpnAddrs).WithField("certVersion", v).
WithField("handshake", m{"stage": 0, "style": "ix_psk0"}).Error("Failed to marshal handshake message") WithField("handshake", m{"stage": 0, "style": "ix_psk0"}).Error("Failed to marshal handshake message")
return false return false
} }
@ -48,7 +80,7 @@ func ixHandshakeStage0(f *Interface, hh *HandshakeHostInfo) bool {
msg, _, _, err := ci.H.WriteMessage(h, hsBytes) msg, _, _, err := ci.H.WriteMessage(h, hsBytes)
if err != nil { if err != nil {
f.l.WithError(err).WithField("vpnIp", hh.hostinfo.vpnIp). f.l.WithError(err).WithField("vpnAddrs", hh.hostinfo.vpnAddrs).
WithField("handshake", m{"stage": 0, "style": "ix_psk0"}).Error("Failed to call noise.WriteMessage") WithField("handshake", m{"stage": 0, "style": "ix_psk0"}).Error("Failed to call noise.WriteMessage")
return false return false
} }
@ -63,30 +95,44 @@ func ixHandshakeStage0(f *Interface, hh *HandshakeHostInfo) bool {
} }
func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet []byte, h *header.H) { func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet []byte, h *header.H) {
certState := f.pki.GetCertState() cs := f.pki.getCertState()
ci := NewConnectionState(f.l, f.cipher, certState, false, noise.HandshakeIX, []byte{}, 0) crt := cs.GetDefaultCertificate()
if crt == nil {
f.l.WithField("udpAddr", addr).
WithField("handshake", m{"stage": 0, "style": "ix_psk0"}).
WithField("certVersion", cs.defaultVersion).
Error("Unable to handshake with host because no certificate is available")
}
ci, err := NewConnectionState(f.l, cs, crt, false, noise.HandshakeIX)
if err != nil {
f.l.WithError(err).WithField("udpAddr", addr).
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).
Error("Failed to create connection state")
return
}
// Mark packet 1 as seen so it doesn't show up as missed // Mark packet 1 as seen so it doesn't show up as missed
ci.window.Update(f.l, 1) ci.window.Update(f.l, 1)
msg, _, _, err := ci.H.ReadMessage(nil, packet[header.Len:]) msg, _, _, err := ci.H.ReadMessage(nil, packet[header.Len:])
if err != nil { if err != nil {
f.l.WithError(err).WithField("udpAddr", addr). f.l.WithError(err).WithField("udpAddr", addr).
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).Error("Failed to call noise.ReadMessage") WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).
Error("Failed to call noise.ReadMessage")
return return
} }
hs := &NebulaHandshake{} hs := &NebulaHandshake{}
err = hs.Unmarshal(msg) err = hs.Unmarshal(msg)
/*
l.Debugln("GOT INDEX: ", hs.Details.InitiatorIndex)
*/
if err != nil || hs.Details == nil { if err != nil || hs.Details == nil {
f.l.WithError(err).WithField("udpAddr", addr). f.l.WithError(err).WithField("udpAddr", addr).
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).Error("Failed unmarshal handshake message") WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).
Error("Failed unmarshal handshake message")
return return
} }
remoteCert, err := RecombineCertAndValidate(ci.H, hs.Details.Cert, f.pki.GetCAPool()) remoteCert, err := cert.RecombineAndValidate(cert.Version(hs.Details.CertVersion), hs.Details.Cert, ci.H.PeerStatic(), ci.Curve(), f.pki.GetCAPool())
if err != nil { if err != nil {
e := f.l.WithError(err).WithField("udpAddr", addr). e := f.l.WithError(err).WithField("udpAddr", addr).
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}) WithField("handshake", m{"stage": 1, "style": "ix_psk0"})
@ -99,6 +145,20 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet
return return
} }
if remoteCert.Certificate.Version() != ci.myCert.Version() {
// We started off using the wrong certificate version, lets see if we can match the version that was sent to us
rc := cs.getCertificate(remoteCert.Certificate.Version())
if rc == nil {
f.l.WithError(err).WithField("udpAddr", addr).
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).WithField("cert", remoteCert).
Info("Unable to handshake with host due to missing certificate version")
return
}
// Record the certificate we are actually using
ci.myCert = rc
}
if len(remoteCert.Certificate.Networks()) == 0 { if len(remoteCert.Certificate.Networks()) == 0 {
e := f.l.WithError(err).WithField("udpAddr", addr). e := f.l.WithError(err).WithField("udpAddr", addr).
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}) WithField("handshake", m{"stage": 1, "style": "ix_psk0"})
@ -111,30 +171,54 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet
return return
} }
vpnIp := remoteCert.Certificate.Networks()[0].Addr().Unmap() var vpnAddrs []netip.Addr
var filteredNetworks []netip.Prefix
certName := remoteCert.Certificate.Name() certName := remoteCert.Certificate.Name()
fingerprint := remoteCert.Fingerprint fingerprint := remoteCert.Fingerprint
issuer := remoteCert.Certificate.Issuer() issuer := remoteCert.Certificate.Issuer()
if vpnIp == f.myVpnNet.Addr() { for _, network := range remoteCert.Certificate.Networks() {
f.l.WithField("vpnIp", vpnIp).WithField("udpAddr", addr). vpnAddr := network.Addr()
_, found := f.myVpnAddrsTable.Lookup(vpnAddr)
if found {
f.l.WithField("vpnAddr", vpnAddr).WithField("udpAddr", addr).
WithField("certName", certName).
WithField("fingerprint", fingerprint).
WithField("issuer", issuer).
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).Error("Refusing to handshake with myself")
return
}
// vpnAddrs outside our vpn networks are of no use to us, filter them out
if _, ok := f.myVpnNetworksTable.Lookup(vpnAddr); !ok {
continue
}
filteredNetworks = append(filteredNetworks, network)
vpnAddrs = append(vpnAddrs, vpnAddr)
}
if len(vpnAddrs) == 0 {
f.l.WithError(err).WithField("udpAddr", addr).
WithField("certName", certName). WithField("certName", certName).
WithField("fingerprint", fingerprint). WithField("fingerprint", fingerprint).
WithField("issuer", issuer). WithField("issuer", issuer).
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).Error("Refusing to handshake with myself") WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).Error("No usable vpn addresses from host, refusing handshake")
return return
} }
if addr.IsValid() { if addr.IsValid() {
if !f.lightHouse.GetRemoteAllowList().Allow(vpnIp, addr.Addr()) { // addr can be invalid when the tunnel is being relayed.
f.l.WithField("vpnIp", vpnIp).WithField("udpAddr", addr).Debug("lighthouse.remote_allow_list denied incoming handshake") // We only want to apply the remote allow list for direct tunnels here
if !f.lightHouse.GetRemoteAllowList().AllowAll(vpnAddrs, addr.Addr()) {
f.l.WithField("vpnAddrs", vpnAddrs).WithField("udpAddr", addr).Debug("lighthouse.remote_allow_list denied incoming handshake")
return return
} }
} }
myIndex, err := generateIndex(f.l) myIndex, err := generateIndex(f.l)
if err != nil { if err != nil {
f.l.WithError(err).WithField("vpnIp", vpnIp).WithField("udpAddr", addr). f.l.WithError(err).WithField("vpnAddrs", vpnAddrs).WithField("udpAddr", addr).
WithField("certName", certName). WithField("certName", certName).
WithField("fingerprint", fingerprint). WithField("fingerprint", fingerprint).
WithField("issuer", issuer). WithField("issuer", issuer).
@ -146,17 +230,17 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet
ConnectionState: ci, ConnectionState: ci,
localIndexId: myIndex, localIndexId: myIndex,
remoteIndexId: hs.Details.InitiatorIndex, remoteIndexId: hs.Details.InitiatorIndex,
vpnIp: vpnIp, vpnAddrs: vpnAddrs,
HandshakePacket: make(map[uint8][]byte, 0), HandshakePacket: make(map[uint8][]byte, 0),
lastHandshakeTime: hs.Details.Time, lastHandshakeTime: hs.Details.Time,
relayState: RelayState{ relayState: RelayState{
relays: map[netip.Addr]struct{}{}, relays: map[netip.Addr]struct{}{},
relayForByIp: map[netip.Addr]*Relay{}, relayForByAddr: map[netip.Addr]*Relay{},
relayForByIdx: map[uint32]*Relay{}, relayForByIdx: map[uint32]*Relay{},
}, },
} }
f.l.WithField("vpnIp", vpnIp).WithField("udpAddr", addr). f.l.WithField("vpnAddrs", vpnAddrs).WithField("udpAddr", addr).
WithField("certName", certName). WithField("certName", certName).
WithField("fingerprint", fingerprint). WithField("fingerprint", fingerprint).
WithField("issuer", issuer). WithField("issuer", issuer).
@ -165,13 +249,26 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet
Info("Handshake message received") Info("Handshake message received")
hs.Details.ResponderIndex = myIndex hs.Details.ResponderIndex = myIndex
hs.Details.Cert = certState.RawCertificateNoKey hs.Details.Cert = cs.getHandshakeBytes(ci.myCert.Version())
if hs.Details.Cert == nil {
f.l.WithField("vpnAddrs", vpnAddrs).WithField("udpAddr", addr).
WithField("certName", certName).
WithField("fingerprint", fingerprint).
WithField("issuer", issuer).
WithField("initiatorIndex", hs.Details.InitiatorIndex).WithField("responderIndex", hs.Details.ResponderIndex).
WithField("remoteIndex", h.RemoteIndex).WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).
WithField("certVersion", ci.myCert.Version()).
Error("Unable to handshake with host because no certificate handshake bytes is available")
return
}
hs.Details.CertVersion = uint32(ci.myCert.Version())
// Update the time in case their clock is way off from ours // Update the time in case their clock is way off from ours
hs.Details.Time = uint64(time.Now().UnixNano()) hs.Details.Time = uint64(time.Now().UnixNano())
hsBytes, err := hs.Marshal() hsBytes, err := hs.Marshal()
if err != nil { if err != nil {
f.l.WithError(err).WithField("vpnIp", hostinfo.vpnIp).WithField("udpAddr", addr). f.l.WithError(err).WithField("vpnAddrs", hostinfo.vpnAddrs).WithField("udpAddr", addr).
WithField("certName", certName). WithField("certName", certName).
WithField("fingerprint", fingerprint). WithField("fingerprint", fingerprint).
WithField("issuer", issuer). WithField("issuer", issuer).
@ -182,14 +279,14 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet
nh := header.Encode(make([]byte, header.Len), header.Version, header.Handshake, header.HandshakeIXPSK0, hs.Details.InitiatorIndex, 2) nh := header.Encode(make([]byte, header.Len), header.Version, header.Handshake, header.HandshakeIXPSK0, hs.Details.InitiatorIndex, 2)
msg, dKey, eKey, err := ci.H.WriteMessage(nh, hsBytes) msg, dKey, eKey, err := ci.H.WriteMessage(nh, hsBytes)
if err != nil { if err != nil {
f.l.WithError(err).WithField("vpnIp", hostinfo.vpnIp).WithField("udpAddr", addr). f.l.WithError(err).WithField("vpnAddrs", hostinfo.vpnAddrs).WithField("udpAddr", addr).
WithField("certName", certName). WithField("certName", certName).
WithField("fingerprint", fingerprint). WithField("fingerprint", fingerprint).
WithField("issuer", issuer). WithField("issuer", issuer).
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).Error("Failed to call noise.WriteMessage") WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).Error("Failed to call noise.WriteMessage")
return return
} else if dKey == nil || eKey == nil { } else if dKey == nil || eKey == nil {
f.l.WithField("vpnIp", hostinfo.vpnIp).WithField("udpAddr", addr). f.l.WithField("vpnAddrs", hostinfo.vpnAddrs).WithField("udpAddr", addr).
WithField("certName", certName). WithField("certName", certName).
WithField("fingerprint", fingerprint). WithField("fingerprint", fingerprint).
WithField("issuer", issuer). WithField("issuer", issuer).
@ -213,9 +310,9 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet
ci.dKey = NewNebulaCipherState(dKey) ci.dKey = NewNebulaCipherState(dKey)
ci.eKey = NewNebulaCipherState(eKey) ci.eKey = NewNebulaCipherState(eKey)
hostinfo.remotes = f.lightHouse.QueryCache(vpnIp) hostinfo.remotes = f.lightHouse.QueryCache(vpnAddrs)
hostinfo.SetRemote(addr) hostinfo.SetRemote(addr)
hostinfo.CreateRemoteCIDR(remoteCert.Certificate) hostinfo.buildNetworks(filteredNetworks, remoteCert.Certificate.UnsafeNetworks())
existing, err := f.handshakeManager.CheckAndComplete(hostinfo, 0, f) existing, err := f.handshakeManager.CheckAndComplete(hostinfo, 0, f)
if err != nil { if err != nil {
@ -225,7 +322,7 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet
if existing.SetRemoteIfPreferred(f.hostMap, addr) { if existing.SetRemoteIfPreferred(f.hostMap, addr) {
// Send a test packet to ensure the other side has also switched to // Send a test packet to ensure the other side has also switched to
// the preferred remote // the preferred remote
f.SendMessageToVpnIp(header.Test, header.TestRequest, vpnIp, []byte(""), make([]byte, 12, 12), make([]byte, mtu)) f.SendMessageToVpnAddr(header.Test, header.TestRequest, vpnAddrs[0], []byte(""), make([]byte, 12, 12), make([]byte, mtu))
} }
msg = existing.HandshakePacket[2] msg = existing.HandshakePacket[2]
@ -233,11 +330,11 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet
if addr.IsValid() { if addr.IsValid() {
err := f.outside.WriteTo(msg, addr) err := f.outside.WriteTo(msg, addr)
if err != nil { if err != nil {
f.l.WithField("vpnIp", existing.vpnIp).WithField("udpAddr", addr). f.l.WithField("vpnAddrs", existing.vpnAddrs).WithField("udpAddr", addr).
WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).WithField("cached", true). WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).WithField("cached", true).
WithError(err).Error("Failed to send handshake message") WithError(err).Error("Failed to send handshake message")
} else { } else {
f.l.WithField("vpnIp", existing.vpnIp).WithField("udpAddr", addr). f.l.WithField("vpnAddrs", existing.vpnAddrs).WithField("udpAddr", addr).
WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).WithField("cached", true). WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).WithField("cached", true).
Info("Handshake message sent") Info("Handshake message sent")
} }
@ -247,16 +344,16 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet
f.l.Error("Handshake send failed: both addr and via are nil.") f.l.Error("Handshake send failed: both addr and via are nil.")
return return
} }
hostinfo.relayState.InsertRelayTo(via.relayHI.vpnIp) hostinfo.relayState.InsertRelayTo(via.relayHI.vpnAddrs[0])
f.SendVia(via.relayHI, via.relay, msg, make([]byte, 12), make([]byte, mtu), false) f.SendVia(via.relayHI, via.relay, msg, make([]byte, 12), make([]byte, mtu), false)
f.l.WithField("vpnIp", existing.vpnIp).WithField("relay", via.relayHI.vpnIp). f.l.WithField("vpnAddrs", existing.vpnAddrs).WithField("relay", via.relayHI.vpnAddrs[0]).
WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).WithField("cached", true). WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).WithField("cached", true).
Info("Handshake message sent") Info("Handshake message sent")
return return
} }
case ErrExistingHostInfo: case ErrExistingHostInfo:
// This means there was an existing tunnel and this handshake was older than the one we are currently based on // This means there was an existing tunnel and this handshake was older than the one we are currently based on
f.l.WithField("vpnIp", vpnIp).WithField("udpAddr", addr). f.l.WithField("vpnAddrs", vpnAddrs).WithField("udpAddr", addr).
WithField("certName", certName). WithField("certName", certName).
WithField("oldHandshakeTime", existing.lastHandshakeTime). WithField("oldHandshakeTime", existing.lastHandshakeTime).
WithField("newHandshakeTime", hostinfo.lastHandshakeTime). WithField("newHandshakeTime", hostinfo.lastHandshakeTime).
@ -267,23 +364,23 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet
Info("Handshake too old") Info("Handshake too old")
// Send a test packet to trigger an authenticated tunnel test, this should suss out any lingering tunnel issues // Send a test packet to trigger an authenticated tunnel test, this should suss out any lingering tunnel issues
f.SendMessageToVpnIp(header.Test, header.TestRequest, vpnIp, []byte(""), make([]byte, 12, 12), make([]byte, mtu)) f.SendMessageToVpnAddr(header.Test, header.TestRequest, vpnAddrs[0], []byte(""), make([]byte, 12, 12), make([]byte, mtu))
return return
case ErrLocalIndexCollision: case ErrLocalIndexCollision:
// This means we failed to insert because of collision on localIndexId. Just let the next handshake packet retry // This means we failed to insert because of collision on localIndexId. Just let the next handshake packet retry
f.l.WithField("vpnIp", vpnIp).WithField("udpAddr", addr). f.l.WithField("vpnAddrs", vpnAddrs).WithField("udpAddr", addr).
WithField("certName", certName). WithField("certName", certName).
WithField("fingerprint", fingerprint). WithField("fingerprint", fingerprint).
WithField("issuer", issuer). WithField("issuer", issuer).
WithField("initiatorIndex", hs.Details.InitiatorIndex).WithField("responderIndex", hs.Details.ResponderIndex). WithField("initiatorIndex", hs.Details.InitiatorIndex).WithField("responderIndex", hs.Details.ResponderIndex).
WithField("remoteIndex", h.RemoteIndex).WithField("handshake", m{"stage": 1, "style": "ix_psk0"}). WithField("remoteIndex", h.RemoteIndex).WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).
WithField("localIndex", hostinfo.localIndexId).WithField("collision", existing.vpnIp). WithField("localIndex", hostinfo.localIndexId).WithField("collision", existing.vpnAddrs).
Error("Failed to add HostInfo due to localIndex collision") Error("Failed to add HostInfo due to localIndex collision")
return return
default: default:
// Shouldn't happen, but just in case someone adds a new error type to CheckAndComplete // Shouldn't happen, but just in case someone adds a new error type to CheckAndComplete
// And we forget to update it here // And we forget to update it here
f.l.WithError(err).WithField("vpnIp", vpnIp).WithField("udpAddr", addr). f.l.WithError(err).WithField("vpnAddrs", vpnAddrs).WithField("udpAddr", addr).
WithField("certName", certName). WithField("certName", certName).
WithField("fingerprint", fingerprint). WithField("fingerprint", fingerprint).
WithField("issuer", issuer). WithField("issuer", issuer).
@ -299,7 +396,7 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet
if addr.IsValid() { if addr.IsValid() {
err = f.outside.WriteTo(msg, addr) err = f.outside.WriteTo(msg, addr)
if err != nil { if err != nil {
f.l.WithField("vpnIp", vpnIp).WithField("udpAddr", addr). f.l.WithField("vpnAddrs", vpnAddrs).WithField("udpAddr", addr).
WithField("certName", certName). WithField("certName", certName).
WithField("fingerprint", fingerprint). WithField("fingerprint", fingerprint).
WithField("issuer", issuer). WithField("issuer", issuer).
@ -307,7 +404,7 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet
WithField("remoteIndex", h.RemoteIndex).WithField("handshake", m{"stage": 2, "style": "ix_psk0"}). WithField("remoteIndex", h.RemoteIndex).WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).
WithError(err).Error("Failed to send handshake") WithError(err).Error("Failed to send handshake")
} else { } else {
f.l.WithField("vpnIp", vpnIp).WithField("udpAddr", addr). f.l.WithField("vpnAddrs", vpnAddrs).WithField("udpAddr", addr).
WithField("certName", certName). WithField("certName", certName).
WithField("fingerprint", fingerprint). WithField("fingerprint", fingerprint).
WithField("issuer", issuer). WithField("issuer", issuer).
@ -320,9 +417,12 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet
f.l.Error("Handshake send failed: both addr and via are nil.") f.l.Error("Handshake send failed: both addr and via are nil.")
return return
} }
hostinfo.relayState.InsertRelayTo(via.relayHI.vpnIp) hostinfo.relayState.InsertRelayTo(via.relayHI.vpnAddrs[0])
// I successfully received a handshake. Just in case I marked this tunnel as 'Disestablished', ensure
// it's correctly marked as working.
via.relayHI.relayState.UpdateRelayForByIdxState(via.remoteIdx, Established)
f.SendVia(via.relayHI, via.relay, msg, make([]byte, 12), make([]byte, mtu), false) f.SendVia(via.relayHI, via.relay, msg, make([]byte, 12), make([]byte, mtu), false)
f.l.WithField("vpnIp", vpnIp).WithField("relay", via.relayHI.vpnIp). f.l.WithField("vpnAddrs", vpnAddrs).WithField("relay", via.relayHI.vpnAddrs[0]).
WithField("certName", certName). WithField("certName", certName).
WithField("fingerprint", fingerprint). WithField("fingerprint", fingerprint).
WithField("issuer", issuer). WithField("issuer", issuer).
@ -349,8 +449,9 @@ func ixHandshakeStage2(f *Interface, addr netip.AddrPort, via *ViaSender, hh *Ha
hostinfo := hh.hostinfo hostinfo := hh.hostinfo
if addr.IsValid() { if addr.IsValid() {
if !f.lightHouse.GetRemoteAllowList().Allow(hostinfo.vpnIp, addr.Addr()) { // The vpnAddr we know about is the one we tried to handshake with, use it to apply the remote allow list.
f.l.WithField("vpnIp", hostinfo.vpnIp).WithField("udpAddr", addr).Debug("lighthouse.remote_allow_list denied incoming handshake") if !f.lightHouse.GetRemoteAllowList().AllowAll(hostinfo.vpnAddrs, addr.Addr()) {
f.l.WithField("vpnAddrs", hostinfo.vpnAddrs).WithField("udpAddr", addr).Debug("lighthouse.remote_allow_list denied incoming handshake")
return false return false
} }
} }
@ -358,7 +459,7 @@ func ixHandshakeStage2(f *Interface, addr netip.AddrPort, via *ViaSender, hh *Ha
ci := hostinfo.ConnectionState ci := hostinfo.ConnectionState
msg, eKey, dKey, err := ci.H.ReadMessage(nil, packet[header.Len:]) msg, eKey, dKey, err := ci.H.ReadMessage(nil, packet[header.Len:])
if err != nil { if err != nil {
f.l.WithError(err).WithField("vpnIp", hostinfo.vpnIp).WithField("udpAddr", addr). f.l.WithError(err).WithField("vpnAddrs", hostinfo.vpnAddrs).WithField("udpAddr", addr).
WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).WithField("header", h). WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).WithField("header", h).
Error("Failed to call noise.ReadMessage") Error("Failed to call noise.ReadMessage")
@ -367,7 +468,7 @@ func ixHandshakeStage2(f *Interface, addr netip.AddrPort, via *ViaSender, hh *Ha
// near future // near future
return false return false
} else if dKey == nil || eKey == nil { } else if dKey == nil || eKey == nil {
f.l.WithField("vpnIp", hostinfo.vpnIp).WithField("udpAddr", addr). f.l.WithField("vpnAddrs", hostinfo.vpnAddrs).WithField("udpAddr", addr).
WithField("handshake", m{"stage": 2, "style": "ix_psk0"}). WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).
Error("Noise did not arrive at a key") Error("Noise did not arrive at a key")
@ -379,16 +480,16 @@ func ixHandshakeStage2(f *Interface, addr netip.AddrPort, via *ViaSender, hh *Ha
hs := &NebulaHandshake{} hs := &NebulaHandshake{}
err = hs.Unmarshal(msg) err = hs.Unmarshal(msg)
if err != nil || hs.Details == nil { if err != nil || hs.Details == nil {
f.l.WithError(err).WithField("vpnIp", hostinfo.vpnIp).WithField("udpAddr", addr). f.l.WithError(err).WithField("vpnAddrs", hostinfo.vpnAddrs).WithField("udpAddr", addr).
WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).Error("Failed unmarshal handshake message") WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).Error("Failed unmarshal handshake message")
// The handshake state machine is complete, if things break now there is no chance to recover. Tear down and start again // The handshake state machine is complete, if things break now there is no chance to recover. Tear down and start again
return true return true
} }
remoteCert, err := RecombineCertAndValidate(ci.H, hs.Details.Cert, f.pki.GetCAPool()) remoteCert, err := cert.RecombineAndValidate(cert.Version(hs.Details.CertVersion), hs.Details.Cert, ci.H.PeerStatic(), ci.Curve(), f.pki.GetCAPool())
if err != nil { if err != nil {
e := f.l.WithError(err).WithField("vpnIp", hostinfo.vpnIp).WithField("udpAddr", addr). e := f.l.WithError(err).WithField("vpnAddrs", hostinfo.vpnAddrs).WithField("udpAddr", addr).
WithField("handshake", m{"stage": 2, "style": "ix_psk0"}) WithField("handshake", m{"stage": 2, "style": "ix_psk0"})
if f.l.Level > logrus.DebugLevel { if f.l.Level > logrus.DebugLevel {
@ -409,11 +510,11 @@ func ixHandshakeStage2(f *Interface, addr netip.AddrPort, via *ViaSender, hh *Ha
e = e.WithField("cert", remoteCert) e = e.WithField("cert", remoteCert)
} }
e.Info("Invalid vpn ip from host") e.Info("Empty networks from host")
return true return true
} }
vpnIp := remoteCert.Certificate.Networks()[0].Addr().Unmap() vpnNetworks := remoteCert.Certificate.Networks()
certName := remoteCert.Certificate.Name() certName := remoteCert.Certificate.Name()
fingerprint := remoteCert.Fingerprint fingerprint := remoteCert.Fingerprint
issuer := remoteCert.Certificate.Issuer() issuer := remoteCert.Certificate.Issuer()
@ -430,12 +531,34 @@ func ixHandshakeStage2(f *Interface, addr netip.AddrPort, via *ViaSender, hh *Ha
if addr.IsValid() { if addr.IsValid() {
hostinfo.SetRemote(addr) hostinfo.SetRemote(addr)
} else { } else {
hostinfo.relayState.InsertRelayTo(via.relayHI.vpnIp) hostinfo.relayState.InsertRelayTo(via.relayHI.vpnAddrs[0])
}
var vpnAddrs []netip.Addr
var filteredNetworks []netip.Prefix
for _, network := range vpnNetworks {
// vpnAddrs outside our vpn networks are of no use to us, filter them out
vpnAddr := network.Addr()
if _, ok := f.myVpnNetworksTable.Lookup(vpnAddr); !ok {
continue
}
filteredNetworks = append(filteredNetworks, network)
vpnAddrs = append(vpnAddrs, vpnAddr)
}
if len(vpnAddrs) == 0 {
f.l.WithError(err).WithField("udpAddr", addr).
WithField("certName", certName).
WithField("fingerprint", fingerprint).
WithField("issuer", issuer).
WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).Error("No usable vpn addresses from host, refusing handshake")
return true
} }
// Ensure the right host responded // Ensure the right host responded
if vpnIp != hostinfo.vpnIp { if !slices.Contains(vpnAddrs, hostinfo.vpnAddrs[0]) {
f.l.WithField("intendedVpnIp", hostinfo.vpnIp).WithField("haveVpnIp", vpnIp). f.l.WithField("intendedVpnAddrs", hostinfo.vpnAddrs).WithField("haveVpnNetworks", vpnNetworks).
WithField("udpAddr", addr).WithField("certName", certName). WithField("udpAddr", addr).WithField("certName", certName).
WithField("handshake", m{"stage": 2, "style": "ix_psk0"}). WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).
Info("Incorrect host responded to handshake") Info("Incorrect host responded to handshake")
@ -444,14 +567,13 @@ func ixHandshakeStage2(f *Interface, addr netip.AddrPort, via *ViaSender, hh *Ha
f.handshakeManager.DeleteHostInfo(hostinfo) f.handshakeManager.DeleteHostInfo(hostinfo)
// Create a new hostinfo/handshake for the intended vpn ip // Create a new hostinfo/handshake for the intended vpn ip
f.handshakeManager.StartHandshake(hostinfo.vpnIp, func(newHH *HandshakeHostInfo) { f.handshakeManager.StartHandshake(hostinfo.vpnAddrs[0], func(newHH *HandshakeHostInfo) {
//TODO: this doesnt know if its being added or is being used for caching a packet
// Block the current used address // Block the current used address
newHH.hostinfo.remotes = hostinfo.remotes newHH.hostinfo.remotes = hostinfo.remotes
newHH.hostinfo.remotes.BlockRemote(addr) newHH.hostinfo.remotes.BlockRemote(addr)
f.l.WithField("blockedUdpAddrs", newHH.hostinfo.remotes.CopyBlockedRemotes()). f.l.WithField("blockedUdpAddrs", newHH.hostinfo.remotes.CopyBlockedRemotes()).
WithField("vpnIp", newHH.hostinfo.vpnIp). WithField("vpnNetworks", vpnNetworks).
WithField("remotes", newHH.hostinfo.remotes.CopyAddrs(f.hostMap.GetPreferredRanges())). WithField("remotes", newHH.hostinfo.remotes.CopyAddrs(f.hostMap.GetPreferredRanges())).
Info("Blocked addresses for handshakes") Info("Blocked addresses for handshakes")
@ -459,11 +581,8 @@ func ixHandshakeStage2(f *Interface, addr netip.AddrPort, via *ViaSender, hh *Ha
newHH.packetStore = hh.packetStore newHH.packetStore = hh.packetStore
hh.packetStore = []*cachedPacket{} hh.packetStore = []*cachedPacket{}
// Get the correct remote list for the host we did handshake with // Finally, put the correct vpn addrs in the host info, tell them to close the tunnel, and return true to tear down
hostinfo.SetRemote(addr) hostinfo.vpnAddrs = vpnAddrs
hostinfo.remotes = f.lightHouse.QueryCache(vpnIp)
// Finally, put the correct vpn ip in the host info, tell them to close the tunnel, and return true to tear down
hostinfo.vpnIp = vpnIp
f.sendCloseTunnel(hostinfo) f.sendCloseTunnel(hostinfo)
}) })
@ -474,7 +593,7 @@ func ixHandshakeStage2(f *Interface, addr netip.AddrPort, via *ViaSender, hh *Ha
ci.window.Update(f.l, 2) ci.window.Update(f.l, 2)
duration := time.Since(hh.startTime).Nanoseconds() duration := time.Since(hh.startTime).Nanoseconds()
f.l.WithField("vpnIp", vpnIp).WithField("udpAddr", addr). f.l.WithField("vpnAddrs", vpnAddrs).WithField("udpAddr", addr).
WithField("certName", certName). WithField("certName", certName).
WithField("fingerprint", fingerprint). WithField("fingerprint", fingerprint).
WithField("issuer", issuer). WithField("issuer", issuer).
@ -485,9 +604,10 @@ func ixHandshakeStage2(f *Interface, addr netip.AddrPort, via *ViaSender, hh *Ha
Info("Handshake message received") Info("Handshake message received")
// Build up the radix for the firewall if we have subnets in the cert // Build up the radix for the firewall if we have subnets in the cert
hostinfo.CreateRemoteCIDR(remoteCert.Certificate) hostinfo.vpnAddrs = vpnAddrs
hostinfo.buildNetworks(filteredNetworks, remoteCert.Certificate.UnsafeNetworks())
// Complete our handshake and update metrics, this will replace any existing tunnels for this vpnIp // Complete our handshake and update metrics, this will replace any existing tunnels for the vpnAddrs here
f.handshakeManager.Complete(hostinfo, f) f.handshakeManager.Complete(hostinfo, f)
f.connectionManager.AddTrafficWatch(hostinfo.localIndexId) f.connectionManager.AddTrafficWatch(hostinfo.localIndexId)

View File

@ -13,6 +13,7 @@ import (
"github.com/rcrowley/go-metrics" "github.com/rcrowley/go-metrics"
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
"github.com/slackhq/nebula/cert"
"github.com/slackhq/nebula/header" "github.com/slackhq/nebula/header"
"github.com/slackhq/nebula/udp" "github.com/slackhq/nebula/udp"
) )
@ -118,18 +119,18 @@ func NewHandshakeManager(l *logrus.Logger, mainHostMap *HostMap, lightHouse *Lig
} }
} }
func (c *HandshakeManager) Run(ctx context.Context) { func (hm *HandshakeManager) Run(ctx context.Context) {
clockSource := time.NewTicker(c.config.tryInterval) clockSource := time.NewTicker(hm.config.tryInterval)
defer clockSource.Stop() defer clockSource.Stop()
for { for {
select { select {
case <-ctx.Done(): case <-ctx.Done():
return return
case vpnIP := <-c.trigger: case vpnIP := <-hm.trigger:
c.handleOutbound(vpnIP, true) hm.handleOutbound(vpnIP, true)
case now := <-clockSource.C: case now := <-clockSource.C:
c.NextOutboundHandshakeTimerTick(now) hm.NextOutboundHandshakeTimerTick(now)
} }
} }
} }
@ -137,7 +138,7 @@ func (c *HandshakeManager) Run(ctx context.Context) {
func (hm *HandshakeManager) HandleIncoming(addr netip.AddrPort, via *ViaSender, packet []byte, h *header.H) { func (hm *HandshakeManager) HandleIncoming(addr netip.AddrPort, via *ViaSender, packet []byte, h *header.H) {
// First remote allow list check before we know the vpnIp // First remote allow list check before we know the vpnIp
if addr.IsValid() { if addr.IsValid() {
if !hm.lightHouse.GetRemoteAllowList().AllowUnknownVpnIp(addr.Addr()) { if !hm.lightHouse.GetRemoteAllowList().AllowUnknownVpnAddr(addr.Addr()) {
hm.l.WithField("udpAddr", addr).Debug("lighthouse.remote_allow_list denied incoming handshake") hm.l.WithField("udpAddr", addr).Debug("lighthouse.remote_allow_list denied incoming handshake")
return return
} }
@ -159,14 +160,14 @@ func (hm *HandshakeManager) HandleIncoming(addr netip.AddrPort, via *ViaSender,
} }
} }
func (c *HandshakeManager) NextOutboundHandshakeTimerTick(now time.Time) { func (hm *HandshakeManager) NextOutboundHandshakeTimerTick(now time.Time) {
c.OutboundHandshakeTimer.Advance(now) hm.OutboundHandshakeTimer.Advance(now)
for { for {
vpnIp, has := c.OutboundHandshakeTimer.Purge() vpnIp, has := hm.OutboundHandshakeTimer.Purge()
if !has { if !has {
break break
} }
c.handleOutbound(vpnIp, false) hm.handleOutbound(vpnIp, false)
} }
} }
@ -208,7 +209,7 @@ func (hm *HandshakeManager) handleOutbound(vpnIp netip.Addr, lighthouseTriggered
// NB ^ This comment doesn't jive. It's how the thing gets initialized. // NB ^ This comment doesn't jive. It's how the thing gets initialized.
// It's the common path. Should it update every time, in case a future LH query/queries give us more info? // It's the common path. Should it update every time, in case a future LH query/queries give us more info?
if hostinfo.remotes == nil { if hostinfo.remotes == nil {
hostinfo.remotes = hm.lightHouse.QueryCache(vpnIp) hostinfo.remotes = hm.lightHouse.QueryCache([]netip.Addr{vpnIp})
} }
remotes := hostinfo.remotes.CopyAddrs(hm.mainHostMap.GetPreferredRanges()) remotes := hostinfo.remotes.CopyAddrs(hm.mainHostMap.GetPreferredRanges())
@ -223,7 +224,7 @@ func (hm *HandshakeManager) handleOutbound(vpnIp netip.Addr, lighthouseTriggered
hh.lastRemotes = remotes hh.lastRemotes = remotes
// TODO: this will generate a load of queries for hosts with only 1 ip // This will generate a load of queries for hosts with only 1 ip
// (such as ones registered to the lighthouse with only a private IP) // (such as ones registered to the lighthouse with only a private IP)
// So we only do it one time after attempting 5 handshakes already. // So we only do it one time after attempting 5 handshakes already.
if len(remotes) <= 1 && hh.counter == 5 { if len(remotes) <= 1 && hh.counter == 5 {
@ -267,59 +268,26 @@ func (hm *HandshakeManager) handleOutbound(vpnIp netip.Addr, lighthouseTriggered
hostinfo.logger(hm.l).WithField("relays", hostinfo.remotes.relays).Info("Attempt to relay through hosts") hostinfo.logger(hm.l).WithField("relays", hostinfo.remotes.relays).Info("Attempt to relay through hosts")
// Send a RelayRequest to all known Relay IP's // Send a RelayRequest to all known Relay IP's
for _, relay := range hostinfo.remotes.relays { for _, relay := range hostinfo.remotes.relays {
// Don't relay to myself, and don't relay through the host I'm trying to connect to // Don't relay to myself
if relay == vpnIp || relay == hm.lightHouse.myVpnNet.Addr() { if relay == vpnIp {
continue continue
} }
relayHostInfo := hm.mainHostMap.QueryVpnIp(relay)
// Don't relay through the host I'm trying to connect to
_, found := hm.f.myVpnAddrsTable.Lookup(relay)
if found {
continue
}
relayHostInfo := hm.mainHostMap.QueryVpnAddr(relay)
if relayHostInfo == nil || !relayHostInfo.remote.IsValid() { if relayHostInfo == nil || !relayHostInfo.remote.IsValid() {
hostinfo.logger(hm.l).WithField("relay", relay.String()).Info("Establish tunnel to relay target") hostinfo.logger(hm.l).WithField("relay", relay.String()).Info("Establish tunnel to relay target")
hm.f.Handshake(relay) hm.f.Handshake(relay)
continue continue
} }
// Check the relay HostInfo to see if we already established a relay through it // Check the relay HostInfo to see if we already established a relay through
if existingRelay, ok := relayHostInfo.relayState.QueryRelayForByIp(vpnIp); ok { existingRelay, ok := relayHostInfo.relayState.QueryRelayForByIp(vpnIp)
switch existingRelay.State { if !ok {
case Established:
hostinfo.logger(hm.l).WithField("relay", relay.String()).Info("Send handshake via relay")
hm.f.SendVia(relayHostInfo, existingRelay, hostinfo.HandshakePacket[0], make([]byte, 12), make([]byte, mtu), false)
case Requested:
hostinfo.logger(hm.l).WithField("relay", relay.String()).Info("Re-send CreateRelay request")
//TODO: IPV6-WORK
myVpnIpB := hm.f.myVpnNet.Addr().As4()
theirVpnIpB := vpnIp.As4()
// Re-send the CreateRelay request, in case the previous one was lost.
m := NebulaControl{
Type: NebulaControl_CreateRelayRequest,
InitiatorRelayIndex: existingRelay.LocalIndex,
RelayFromIp: binary.BigEndian.Uint32(myVpnIpB[:]),
RelayToIp: binary.BigEndian.Uint32(theirVpnIpB[:]),
}
msg, err := m.Marshal()
if err != nil {
hostinfo.logger(hm.l).
WithError(err).
Error("Failed to marshal Control message to create relay")
} else {
// This must send over the hostinfo, not over hm.Hosts[ip]
hm.f.SendMessageToHostInfo(header.Control, 0, relayHostInfo, msg, make([]byte, 12), make([]byte, mtu))
hm.l.WithFields(logrus.Fields{
"relayFrom": hm.f.myVpnNet.Addr(),
"relayTo": vpnIp,
"initiatorRelayIndex": existingRelay.LocalIndex,
"relay": relay}).
Info("send CreateRelayRequest")
}
default:
hostinfo.logger(hm.l).
WithField("vpnIp", vpnIp).
WithField("state", existingRelay.State).
WithField("relay", relayHostInfo.vpnIp).
Errorf("Relay unexpected state")
}
} else {
// No relays exist or requested yet. // No relays exist or requested yet.
if relayHostInfo.remote.IsValid() { if relayHostInfo.remote.IsValid() {
idx, err := AddRelay(hm.l, relayHostInfo, hm.mainHostMap, vpnIp, nil, TerminalType, Requested) idx, err := AddRelay(hm.l, relayHostInfo, hm.mainHostMap, vpnIp, nil, TerminalType, Requested)
@ -327,16 +295,35 @@ func (hm *HandshakeManager) handleOutbound(vpnIp netip.Addr, lighthouseTriggered
hostinfo.logger(hm.l).WithField("relay", relay.String()).WithError(err).Info("Failed to add relay to hostmap") hostinfo.logger(hm.l).WithField("relay", relay.String()).WithError(err).Info("Failed to add relay to hostmap")
} }
//TODO: IPV6-WORK
myVpnIpB := hm.f.myVpnNet.Addr().As4()
theirVpnIpB := vpnIp.As4()
m := NebulaControl{ m := NebulaControl{
Type: NebulaControl_CreateRelayRequest, Type: NebulaControl_CreateRelayRequest,
InitiatorRelayIndex: idx, InitiatorRelayIndex: idx,
RelayFromIp: binary.BigEndian.Uint32(myVpnIpB[:]),
RelayToIp: binary.BigEndian.Uint32(theirVpnIpB[:]),
} }
switch relayHostInfo.GetCert().Certificate.Version() {
case cert.Version1:
if !hm.f.myVpnAddrs[0].Is4() {
hostinfo.logger(hm.l).Error("can not establish v1 relay with a v6 network because the relay is not running a current nebula version")
continue
}
if !vpnIp.Is4() {
hostinfo.logger(hm.l).Error("can not establish v1 relay with a v6 remote network because the relay is not running a current nebula version")
continue
}
b := hm.f.myVpnAddrs[0].As4()
m.OldRelayFromAddr = binary.BigEndian.Uint32(b[:])
b = vpnIp.As4()
m.OldRelayToAddr = binary.BigEndian.Uint32(b[:])
case cert.Version2:
m.RelayFromAddr = netAddrToProtoAddr(hm.f.myVpnAddrs[0])
m.RelayToAddr = netAddrToProtoAddr(vpnIp)
default:
hostinfo.logger(hm.l).Error("Unknown certificate version found while creating relay")
continue
}
msg, err := m.Marshal() msg, err := m.Marshal()
if err != nil { if err != nil {
hostinfo.logger(hm.l). hostinfo.logger(hm.l).
@ -345,13 +332,80 @@ func (hm *HandshakeManager) handleOutbound(vpnIp netip.Addr, lighthouseTriggered
} else { } else {
hm.f.SendMessageToHostInfo(header.Control, 0, relayHostInfo, msg, make([]byte, 12), make([]byte, mtu)) hm.f.SendMessageToHostInfo(header.Control, 0, relayHostInfo, msg, make([]byte, 12), make([]byte, mtu))
hm.l.WithFields(logrus.Fields{ hm.l.WithFields(logrus.Fields{
"relayFrom": hm.f.myVpnNet.Addr(), "relayFrom": hm.f.myVpnAddrs[0],
"relayTo": vpnIp, "relayTo": vpnIp,
"initiatorRelayIndex": idx, "initiatorRelayIndex": idx,
"relay": relay}). "relay": relay}).
Info("send CreateRelayRequest") Info("send CreateRelayRequest")
} }
} }
continue
}
switch existingRelay.State {
case Established:
hostinfo.logger(hm.l).WithField("relay", relay.String()).Info("Send handshake via relay")
hm.f.SendVia(relayHostInfo, existingRelay, hostinfo.HandshakePacket[0], make([]byte, 12), make([]byte, mtu), false)
case Disestablished:
// Mark this relay as 'requested'
relayHostInfo.relayState.UpdateRelayForByIpState(vpnIp, Requested)
fallthrough
case Requested:
hostinfo.logger(hm.l).WithField("relay", relay.String()).Info("Re-send CreateRelay request")
// Re-send the CreateRelay request, in case the previous one was lost.
m := NebulaControl{
Type: NebulaControl_CreateRelayRequest,
InitiatorRelayIndex: existingRelay.LocalIndex,
}
switch relayHostInfo.GetCert().Certificate.Version() {
case cert.Version1:
if !hm.f.myVpnAddrs[0].Is4() {
hostinfo.logger(hm.l).Error("can not establish v1 relay with a v6 network because the relay is not running a current nebula version")
continue
}
if !vpnIp.Is4() {
hostinfo.logger(hm.l).Error("can not establish v1 relay with a v6 remote network because the relay is not running a current nebula version")
continue
}
b := hm.f.myVpnAddrs[0].As4()
m.OldRelayFromAddr = binary.BigEndian.Uint32(b[:])
b = vpnIp.As4()
m.OldRelayToAddr = binary.BigEndian.Uint32(b[:])
case cert.Version2:
m.RelayFromAddr = netAddrToProtoAddr(hm.f.myVpnAddrs[0])
m.RelayToAddr = netAddrToProtoAddr(vpnIp)
default:
hostinfo.logger(hm.l).Error("Unknown certificate version found while creating relay")
continue
}
msg, err := m.Marshal()
if err != nil {
hostinfo.logger(hm.l).
WithError(err).
Error("Failed to marshal Control message to create relay")
} else {
// This must send over the hostinfo, not over hm.Hosts[ip]
hm.f.SendMessageToHostInfo(header.Control, 0, relayHostInfo, msg, make([]byte, 12), make([]byte, mtu))
hm.l.WithFields(logrus.Fields{
"relayFrom": hm.f.myVpnAddrs[0],
"relayTo": vpnIp,
"initiatorRelayIndex": existingRelay.LocalIndex,
"relay": relay}).
Info("send CreateRelayRequest")
}
case PeerRequested:
// PeerRequested only occurs in Forwarding relays, not Terminal relays, and this is a Terminal relay case.
fallthrough
default:
hostinfo.logger(hm.l).
WithField("vpnIp", vpnIp).
WithField("state", existingRelay.State).
WithField("relay", relay).
Errorf("Relay unexpected state")
} }
} }
} }
@ -381,10 +435,10 @@ func (hm *HandshakeManager) GetOrHandshake(vpnIp netip.Addr, cacheCb func(*Hands
} }
// StartHandshake will ensure a handshake is currently being attempted for the provided vpn ip // StartHandshake will ensure a handshake is currently being attempted for the provided vpn ip
func (hm *HandshakeManager) StartHandshake(vpnIp netip.Addr, cacheCb func(*HandshakeHostInfo)) *HostInfo { func (hm *HandshakeManager) StartHandshake(vpnAddr netip.Addr, cacheCb func(*HandshakeHostInfo)) *HostInfo {
hm.Lock() hm.Lock()
if hh, ok := hm.vpnIps[vpnIp]; ok { if hh, ok := hm.vpnIps[vpnAddr]; ok {
// We are already trying to handshake with this vpn ip // We are already trying to handshake with this vpn ip
if cacheCb != nil { if cacheCb != nil {
cacheCb(hh) cacheCb(hh)
@ -394,12 +448,12 @@ func (hm *HandshakeManager) StartHandshake(vpnIp netip.Addr, cacheCb func(*Hands
} }
hostinfo := &HostInfo{ hostinfo := &HostInfo{
vpnIp: vpnIp, vpnAddrs: []netip.Addr{vpnAddr},
HandshakePacket: make(map[uint8][]byte, 0), HandshakePacket: make(map[uint8][]byte, 0),
relayState: RelayState{ relayState: RelayState{
relays: map[netip.Addr]struct{}{}, relays: map[netip.Addr]struct{}{},
relayForByIp: map[netip.Addr]*Relay{}, relayForByAddr: map[netip.Addr]*Relay{},
relayForByIdx: map[uint32]*Relay{}, relayForByIdx: map[uint32]*Relay{},
}, },
} }
@ -407,9 +461,9 @@ func (hm *HandshakeManager) StartHandshake(vpnIp netip.Addr, cacheCb func(*Hands
hostinfo: hostinfo, hostinfo: hostinfo,
startTime: time.Now(), startTime: time.Now(),
} }
hm.vpnIps[vpnIp] = hh hm.vpnIps[vpnAddr] = hh
hm.metricInitiated.Inc(1) hm.metricInitiated.Inc(1)
hm.OutboundHandshakeTimer.Add(vpnIp, hm.config.tryInterval) hm.OutboundHandshakeTimer.Add(vpnAddr, hm.config.tryInterval)
if cacheCb != nil { if cacheCb != nil {
cacheCb(hh) cacheCb(hh)
@ -417,21 +471,21 @@ func (hm *HandshakeManager) StartHandshake(vpnIp netip.Addr, cacheCb func(*Hands
// If this is a static host, we don't need to wait for the HostQueryReply // If this is a static host, we don't need to wait for the HostQueryReply
// We can trigger the handshake right now // We can trigger the handshake right now
_, doTrigger := hm.lightHouse.GetStaticHostList()[vpnIp] _, doTrigger := hm.lightHouse.GetStaticHostList()[vpnAddr]
if !doTrigger { if !doTrigger {
// Add any calculated remotes, and trigger early handshake if one found // Add any calculated remotes, and trigger early handshake if one found
doTrigger = hm.lightHouse.addCalculatedRemotes(vpnIp) doTrigger = hm.lightHouse.addCalculatedRemotes(vpnAddr)
} }
if doTrigger { if doTrigger {
select { select {
case hm.trigger <- vpnIp: case hm.trigger <- vpnAddr:
default: default:
} }
} }
hm.Unlock() hm.Unlock()
hm.lightHouse.QueryServer(vpnIp) hm.lightHouse.QueryServer(vpnAddr)
return hostinfo return hostinfo
} }
@ -452,14 +506,14 @@ var (
// //
// ErrLocalIndexCollision if we already have an entry in the main or pending // ErrLocalIndexCollision if we already have an entry in the main or pending
// hostmap for the hostinfo.localIndexId. // hostmap for the hostinfo.localIndexId.
func (c *HandshakeManager) CheckAndComplete(hostinfo *HostInfo, handshakePacket uint8, f *Interface) (*HostInfo, error) { func (hm *HandshakeManager) CheckAndComplete(hostinfo *HostInfo, handshakePacket uint8, f *Interface) (*HostInfo, error) {
c.mainHostMap.Lock() hm.mainHostMap.Lock()
defer c.mainHostMap.Unlock() defer hm.mainHostMap.Unlock()
c.Lock() hm.Lock()
defer c.Unlock() defer hm.Unlock()
// Check if we already have a tunnel with this vpn ip // Check if we already have a tunnel with this vpn ip
existingHostInfo, found := c.mainHostMap.Hosts[hostinfo.vpnIp] existingHostInfo, found := hm.mainHostMap.Hosts[hostinfo.vpnAddrs[0]]
if found && existingHostInfo != nil { if found && existingHostInfo != nil {
testHostInfo := existingHostInfo testHostInfo := existingHostInfo
for testHostInfo != nil { for testHostInfo != nil {
@ -476,31 +530,31 @@ func (c *HandshakeManager) CheckAndComplete(hostinfo *HostInfo, handshakePacket
return existingHostInfo, ErrExistingHostInfo return existingHostInfo, ErrExistingHostInfo
} }
existingHostInfo.logger(c.l).Info("Taking new handshake") existingHostInfo.logger(hm.l).Info("Taking new handshake")
} }
existingIndex, found := c.mainHostMap.Indexes[hostinfo.localIndexId] existingIndex, found := hm.mainHostMap.Indexes[hostinfo.localIndexId]
if found { if found {
// We have a collision, but for a different hostinfo // We have a collision, but for a different hostinfo
return existingIndex, ErrLocalIndexCollision return existingIndex, ErrLocalIndexCollision
} }
existingPendingIndex, found := c.indexes[hostinfo.localIndexId] existingPendingIndex, found := hm.indexes[hostinfo.localIndexId]
if found && existingPendingIndex.hostinfo != hostinfo { if found && existingPendingIndex.hostinfo != hostinfo {
// We have a collision, but for a different hostinfo // We have a collision, but for a different hostinfo
return existingPendingIndex.hostinfo, ErrLocalIndexCollision return existingPendingIndex.hostinfo, ErrLocalIndexCollision
} }
existingRemoteIndex, found := c.mainHostMap.RemoteIndexes[hostinfo.remoteIndexId] existingRemoteIndex, found := hm.mainHostMap.RemoteIndexes[hostinfo.remoteIndexId]
if found && existingRemoteIndex != nil && existingRemoteIndex.vpnIp != hostinfo.vpnIp { if found && existingRemoteIndex != nil && existingRemoteIndex.vpnAddrs[0] != hostinfo.vpnAddrs[0] {
// We have a collision, but this can happen since we can't control // We have a collision, but this can happen since we can't control
// the remote ID. Just log about the situation as a note. // the remote ID. Just log about the situation as a note.
hostinfo.logger(c.l). hostinfo.logger(hm.l).
WithField("remoteIndex", hostinfo.remoteIndexId).WithField("collision", existingRemoteIndex.vpnIp). WithField("remoteIndex", hostinfo.remoteIndexId).WithField("collision", existingRemoteIndex.vpnAddrs).
Info("New host shadows existing host remoteIndex") Info("New host shadows existing host remoteIndex")
} }
c.mainHostMap.unlockedAddHostInfo(hostinfo, f) hm.mainHostMap.unlockedAddHostInfo(hostinfo, f)
return existingHostInfo, nil return existingHostInfo, nil
} }
@ -518,7 +572,7 @@ func (hm *HandshakeManager) Complete(hostinfo *HostInfo, f *Interface) {
// We have a collision, but this can happen since we can't control // We have a collision, but this can happen since we can't control
// the remote ID. Just log about the situation as a note. // the remote ID. Just log about the situation as a note.
hostinfo.logger(hm.l). hostinfo.logger(hm.l).
WithField("remoteIndex", hostinfo.remoteIndexId).WithField("collision", existingRemoteIndex.vpnIp). WithField("remoteIndex", hostinfo.remoteIndexId).WithField("collision", existingRemoteIndex.vpnAddrs).
Info("New host shadows existing host remoteIndex") Info("New host shadows existing host remoteIndex")
} }
@ -555,31 +609,34 @@ func (hm *HandshakeManager) allocateIndex(hh *HandshakeHostInfo) error {
return errors.New("failed to generate unique localIndexId") return errors.New("failed to generate unique localIndexId")
} }
func (c *HandshakeManager) DeleteHostInfo(hostinfo *HostInfo) { func (hm *HandshakeManager) DeleteHostInfo(hostinfo *HostInfo) {
c.Lock() hm.Lock()
defer c.Unlock() defer hm.Unlock()
c.unlockedDeleteHostInfo(hostinfo) hm.unlockedDeleteHostInfo(hostinfo)
} }
func (c *HandshakeManager) unlockedDeleteHostInfo(hostinfo *HostInfo) { func (hm *HandshakeManager) unlockedDeleteHostInfo(hostinfo *HostInfo) {
delete(c.vpnIps, hostinfo.vpnIp) for _, addr := range hostinfo.vpnAddrs {
if len(c.vpnIps) == 0 { delete(hm.vpnIps, addr)
c.vpnIps = map[netip.Addr]*HandshakeHostInfo{}
} }
delete(c.indexes, hostinfo.localIndexId) if len(hm.vpnIps) == 0 {
if len(c.vpnIps) == 0 { hm.vpnIps = map[netip.Addr]*HandshakeHostInfo{}
c.indexes = map[uint32]*HandshakeHostInfo{}
} }
if c.l.Level >= logrus.DebugLevel { delete(hm.indexes, hostinfo.localIndexId)
c.l.WithField("hostMap", m{"mapTotalSize": len(c.vpnIps), if len(hm.indexes) == 0 {
"vpnIp": hostinfo.vpnIp, "indexNumber": hostinfo.localIndexId, "remoteIndexNumber": hostinfo.remoteIndexId}). hm.indexes = map[uint32]*HandshakeHostInfo{}
}
if hm.l.Level >= logrus.DebugLevel {
hm.l.WithField("hostMap", m{"mapTotalSize": len(hm.vpnIps),
"vpnAddrs": hostinfo.vpnAddrs, "indexNumber": hostinfo.localIndexId, "remoteIndexNumber": hostinfo.remoteIndexId}).
Debug("Pending hostmap hostInfo deleted") Debug("Pending hostmap hostInfo deleted")
} }
} }
func (hm *HandshakeManager) QueryVpnIp(vpnIp netip.Addr) *HostInfo { func (hm *HandshakeManager) QueryVpnAddr(vpnIp netip.Addr) *HostInfo {
hh := hm.queryVpnIp(vpnIp) hh := hm.queryVpnIp(vpnIp)
if hh != nil { if hh != nil {
return hh.hostinfo return hh.hostinfo
@ -608,37 +665,37 @@ func (hm *HandshakeManager) queryIndex(index uint32) *HandshakeHostInfo {
return hm.indexes[index] return hm.indexes[index]
} }
func (c *HandshakeManager) GetPreferredRanges() []netip.Prefix { func (hm *HandshakeManager) GetPreferredRanges() []netip.Prefix {
return c.mainHostMap.GetPreferredRanges() return hm.mainHostMap.GetPreferredRanges()
} }
func (c *HandshakeManager) ForEachVpnIp(f controlEach) { func (hm *HandshakeManager) ForEachVpnAddr(f controlEach) {
c.RLock() hm.RLock()
defer c.RUnlock() defer hm.RUnlock()
for _, v := range c.vpnIps { for _, v := range hm.vpnIps {
f(v.hostinfo) f(v.hostinfo)
} }
} }
func (c *HandshakeManager) ForEachIndex(f controlEach) { func (hm *HandshakeManager) ForEachIndex(f controlEach) {
c.RLock() hm.RLock()
defer c.RUnlock() defer hm.RUnlock()
for _, v := range c.indexes { for _, v := range hm.indexes {
f(v.hostinfo) f(v.hostinfo)
} }
} }
func (c *HandshakeManager) EmitStats() { func (hm *HandshakeManager) EmitStats() {
c.RLock() hm.RLock()
hostLen := len(c.vpnIps) hostLen := len(hm.vpnIps)
indexLen := len(c.indexes) indexLen := len(hm.indexes)
c.RUnlock() hm.RUnlock()
metrics.GetOrRegisterGauge("hostmap.pending.hosts", nil).Update(int64(hostLen)) metrics.GetOrRegisterGauge("hostmap.pending.hosts", nil).Update(int64(hostLen))
metrics.GetOrRegisterGauge("hostmap.pending.indexes", nil).Update(int64(indexLen)) metrics.GetOrRegisterGauge("hostmap.pending.indexes", nil).Update(int64(indexLen))
c.mainHostMap.EmitStats() hm.mainHostMap.EmitStats()
} }
// Utility functions below // Utility functions below

View File

@ -5,6 +5,7 @@ import (
"testing" "testing"
"time" "time"
"github.com/slackhq/nebula/cert"
"github.com/slackhq/nebula/header" "github.com/slackhq/nebula/header"
"github.com/slackhq/nebula/test" "github.com/slackhq/nebula/test"
"github.com/slackhq/nebula/udp" "github.com/slackhq/nebula/udp"
@ -13,21 +14,20 @@ import (
func Test_NewHandshakeManagerVpnIp(t *testing.T) { func Test_NewHandshakeManagerVpnIp(t *testing.T) {
l := test.NewLogger() l := test.NewLogger()
vpncidr := netip.MustParsePrefix("172.1.1.1/24")
localrange := netip.MustParsePrefix("10.1.1.1/24") localrange := netip.MustParsePrefix("10.1.1.1/24")
ip := netip.MustParseAddr("172.1.1.2") ip := netip.MustParseAddr("172.1.1.2")
preferredRanges := []netip.Prefix{localrange} preferredRanges := []netip.Prefix{localrange}
mainHM := newHostMap(l, vpncidr) mainHM := newHostMap(l)
mainHM.preferredRanges.Store(&preferredRanges) mainHM.preferredRanges.Store(&preferredRanges)
lh := newTestLighthouse() lh := newTestLighthouse()
cs := &CertState{ cs := &CertState{
RawCertificate: []byte{}, defaultVersion: cert.Version1,
PrivateKey: []byte{}, privateKey: []byte{},
Certificate: &dummyCert{}, v1Cert: &dummyCert{version: cert.Version1},
RawCertificateNoKey: []byte{}, v1HandshakeBytes: []byte{},
} }
blah := NewHandshakeManager(l, mainHM, lh, &udp.NoopConn{}, defaultHandshakeConfig) blah := NewHandshakeManager(l, mainHM, lh, &udp.NoopConn{}, defaultHandshakeConfig)
@ -41,7 +41,7 @@ func Test_NewHandshakeManagerVpnIp(t *testing.T) {
i2 := blah.StartHandshake(ip, nil) i2 := blah.StartHandshake(ip, nil)
assert.Same(t, i, i2) assert.Same(t, i, i2)
i.remotes = NewRemoteList(nil) i.remotes = NewRemoteList([]netip.Addr{}, nil)
// Adding something to pending should not affect the main hostmap // Adding something to pending should not affect the main hostmap
assert.Len(t, mainHM.Hosts, 0) assert.Len(t, mainHM.Hosts, 0)
@ -79,16 +79,24 @@ func testCountTimerWheelEntries(tw *LockingTimerWheel[netip.Addr]) (c int) {
type mockEncWriter struct { type mockEncWriter struct {
} }
func (mw *mockEncWriter) SendMessageToVpnIp(t header.MessageType, st header.MessageSubType, vpnIp netip.Addr, p, nb, out []byte) { func (mw *mockEncWriter) SendMessageToVpnAddr(_ header.MessageType, _ header.MessageSubType, _ netip.Addr, _, _, _ []byte) {
return return
} }
func (mw *mockEncWriter) SendVia(via *HostInfo, relay *Relay, ad, nb, out []byte, nocopy bool) { func (mw *mockEncWriter) SendVia(_ *HostInfo, _ *Relay, _, _, _ []byte, _ bool) {
return return
} }
func (mw *mockEncWriter) SendMessageToHostInfo(t header.MessageType, st header.MessageSubType, hostinfo *HostInfo, p, nb, out []byte) { func (mw *mockEncWriter) SendMessageToHostInfo(_ header.MessageType, _ header.MessageSubType, _ *HostInfo, _, _, _ []byte) {
return return
} }
func (mw *mockEncWriter) Handshake(vpnIP netip.Addr) {} func (mw *mockEncWriter) Handshake(_ netip.Addr) {}
func (mw *mockEncWriter) GetHostInfo(_ netip.Addr) *HostInfo {
return nil
}
func (mw *mockEncWriter) GetCertState() *CertState {
return &CertState{defaultVersion: cert.Version2}
}

View File

@ -35,6 +35,7 @@ const (
Requested = iota Requested = iota
PeerRequested PeerRequested
Established Established
Disestablished
) )
const ( const (
@ -48,7 +49,7 @@ type Relay struct {
State int State int
LocalIndex uint32 LocalIndex uint32
RemoteIndex uint32 RemoteIndex uint32
PeerIp netip.Addr PeerAddr netip.Addr
} }
type HostMap struct { type HostMap struct {
@ -58,7 +59,6 @@ type HostMap struct {
RemoteIndexes map[uint32]*HostInfo RemoteIndexes map[uint32]*HostInfo
Hosts map[netip.Addr]*HostInfo Hosts map[netip.Addr]*HostInfo
preferredRanges atomic.Pointer[[]netip.Prefix] preferredRanges atomic.Pointer[[]netip.Prefix]
vpnCIDR netip.Prefix
l *logrus.Logger l *logrus.Logger
} }
@ -68,9 +68,12 @@ type HostMap struct {
type RelayState struct { type RelayState struct {
sync.RWMutex sync.RWMutex
relays map[netip.Addr]struct{} // Set of VpnIp's of Hosts to use as relays to access this peer relays map[netip.Addr]struct{} // Set of vpnAddr's of Hosts to use as relays to access this peer
relayForByIp map[netip.Addr]*Relay // Maps VpnIps of peers for which this HostInfo is a relay to some Relay info // For data race avoidance, the contents of a *Relay are treated immutably. To update a *Relay, copy the existing data,
relayForByIdx map[uint32]*Relay // Maps a local index to some Relay info // modify what needs to be updated, and store the new modified copy in the relayForByIp and relayForByIdx maps (with
// the RelayState Lock held)
relayForByAddr map[netip.Addr]*Relay // Maps vpnAddr of peers for which this HostInfo is a relay to some Relay info
relayForByIdx map[uint32]*Relay // Maps a local index to some Relay info
} }
func (rs *RelayState) DeleteRelay(ip netip.Addr) { func (rs *RelayState) DeleteRelay(ip netip.Addr) {
@ -79,6 +82,28 @@ func (rs *RelayState) DeleteRelay(ip netip.Addr) {
delete(rs.relays, ip) delete(rs.relays, ip)
} }
func (rs *RelayState) UpdateRelayForByIpState(vpnIp netip.Addr, state int) {
rs.Lock()
defer rs.Unlock()
if r, ok := rs.relayForByAddr[vpnIp]; ok {
newRelay := *r
newRelay.State = state
rs.relayForByAddr[newRelay.PeerAddr] = &newRelay
rs.relayForByIdx[newRelay.LocalIndex] = &newRelay
}
}
func (rs *RelayState) UpdateRelayForByIdxState(idx uint32, state int) {
rs.Lock()
defer rs.Unlock()
if r, ok := rs.relayForByIdx[idx]; ok {
newRelay := *r
newRelay.State = state
rs.relayForByAddr[newRelay.PeerAddr] = &newRelay
rs.relayForByIdx[newRelay.LocalIndex] = &newRelay
}
}
func (rs *RelayState) CopyAllRelayFor() []*Relay { func (rs *RelayState) CopyAllRelayFor() []*Relay {
rs.RLock() rs.RLock()
defer rs.RUnlock() defer rs.RUnlock()
@ -89,10 +114,10 @@ func (rs *RelayState) CopyAllRelayFor() []*Relay {
return ret return ret
} }
func (rs *RelayState) GetRelayForByIp(ip netip.Addr) (*Relay, bool) { func (rs *RelayState) GetRelayForByAddr(addr netip.Addr) (*Relay, bool) {
rs.RLock() rs.RLock()
defer rs.RUnlock() defer rs.RUnlock()
r, ok := rs.relayForByIp[ip] r, ok := rs.relayForByAddr[addr]
return r, ok return r, ok
} }
@ -115,8 +140,8 @@ func (rs *RelayState) CopyRelayIps() []netip.Addr {
func (rs *RelayState) CopyRelayForIps() []netip.Addr { func (rs *RelayState) CopyRelayForIps() []netip.Addr {
rs.RLock() rs.RLock()
defer rs.RUnlock() defer rs.RUnlock()
currentRelays := make([]netip.Addr, 0, len(rs.relayForByIp)) currentRelays := make([]netip.Addr, 0, len(rs.relayForByAddr))
for relayIp := range rs.relayForByIp { for relayIp := range rs.relayForByAddr {
currentRelays = append(currentRelays, relayIp) currentRelays = append(currentRelays, relayIp)
} }
return currentRelays return currentRelays
@ -135,7 +160,7 @@ func (rs *RelayState) CopyRelayForIdxs() []uint32 {
func (rs *RelayState) CompleteRelayByIP(vpnIp netip.Addr, remoteIdx uint32) bool { func (rs *RelayState) CompleteRelayByIP(vpnIp netip.Addr, remoteIdx uint32) bool {
rs.Lock() rs.Lock()
defer rs.Unlock() defer rs.Unlock()
r, ok := rs.relayForByIp[vpnIp] r, ok := rs.relayForByAddr[vpnIp]
if !ok { if !ok {
return false return false
} }
@ -143,7 +168,7 @@ func (rs *RelayState) CompleteRelayByIP(vpnIp netip.Addr, remoteIdx uint32) bool
newRelay.State = Established newRelay.State = Established
newRelay.RemoteIndex = remoteIdx newRelay.RemoteIndex = remoteIdx
rs.relayForByIdx[r.LocalIndex] = &newRelay rs.relayForByIdx[r.LocalIndex] = &newRelay
rs.relayForByIp[r.PeerIp] = &newRelay rs.relayForByAddr[r.PeerAddr] = &newRelay
return true return true
} }
@ -158,14 +183,14 @@ func (rs *RelayState) CompleteRelayByIdx(localIdx uint32, remoteIdx uint32) (*Re
newRelay.State = Established newRelay.State = Established
newRelay.RemoteIndex = remoteIdx newRelay.RemoteIndex = remoteIdx
rs.relayForByIdx[r.LocalIndex] = &newRelay rs.relayForByIdx[r.LocalIndex] = &newRelay
rs.relayForByIp[r.PeerIp] = &newRelay rs.relayForByAddr[r.PeerAddr] = &newRelay
return &newRelay, true return &newRelay, true
} }
func (rs *RelayState) QueryRelayForByIp(vpnIp netip.Addr) (*Relay, bool) { func (rs *RelayState) QueryRelayForByIp(vpnIp netip.Addr) (*Relay, bool) {
rs.RLock() rs.RLock()
defer rs.RUnlock() defer rs.RUnlock()
r, ok := rs.relayForByIp[vpnIp] r, ok := rs.relayForByAddr[vpnIp]
return r, ok return r, ok
} }
@ -179,7 +204,7 @@ func (rs *RelayState) QueryRelayForByIdx(idx uint32) (*Relay, bool) {
func (rs *RelayState) InsertRelay(ip netip.Addr, idx uint32, r *Relay) { func (rs *RelayState) InsertRelay(ip netip.Addr, idx uint32, r *Relay) {
rs.Lock() rs.Lock()
defer rs.Unlock() defer rs.Unlock()
rs.relayForByIp[ip] = r rs.relayForByAddr[ip] = r
rs.relayForByIdx[idx] = r rs.relayForByIdx[idx] = r
} }
@ -190,10 +215,16 @@ type HostInfo struct {
ConnectionState *ConnectionState ConnectionState *ConnectionState
remoteIndexId uint32 remoteIndexId uint32
localIndexId uint32 localIndexId uint32
vpnIp netip.Addr
recvError atomic.Uint32 // vpnAddrs is a list of vpn addresses assigned to this host that are within our own vpn networks
remoteCidr *bart.Table[struct{}] // The host may have other vpn addresses that are outside our
relayState RelayState // vpn networks but were removed because they are not usable
vpnAddrs []netip.Addr
recvError atomic.Uint32
// networks are both all vpn and unsafe networks assigned to this host
networks *bart.Table[struct{}]
relayState RelayState
// HandshakePacket records the packets used to create this hostinfo // HandshakePacket records the packets used to create this hostinfo
// We need these to avoid replayed handshake packets creating new hostinfos which causes churn // We need these to avoid replayed handshake packets creating new hostinfos which causes churn
@ -241,28 +272,26 @@ type cachedPacketMetrics struct {
dropped metrics.Counter dropped metrics.Counter
} }
func NewHostMapFromConfig(l *logrus.Logger, vpnCIDR netip.Prefix, c *config.C) *HostMap { func NewHostMapFromConfig(l *logrus.Logger, c *config.C) *HostMap {
hm := newHostMap(l, vpnCIDR) hm := newHostMap(l)
hm.reload(c, true) hm.reload(c, true)
c.RegisterReloadCallback(func(c *config.C) { c.RegisterReloadCallback(func(c *config.C) {
hm.reload(c, false) hm.reload(c, false)
}) })
l.WithField("network", hm.vpnCIDR.String()). l.WithField("preferredRanges", hm.GetPreferredRanges()).
WithField("preferredRanges", hm.GetPreferredRanges()).
Info("Main HostMap created") Info("Main HostMap created")
return hm return hm
} }
func newHostMap(l *logrus.Logger, vpnCIDR netip.Prefix) *HostMap { func newHostMap(l *logrus.Logger) *HostMap {
return &HostMap{ return &HostMap{
Indexes: map[uint32]*HostInfo{}, Indexes: map[uint32]*HostInfo{},
Relays: map[uint32]*HostInfo{}, Relays: map[uint32]*HostInfo{},
RemoteIndexes: map[uint32]*HostInfo{}, RemoteIndexes: map[uint32]*HostInfo{},
Hosts: map[netip.Addr]*HostInfo{}, Hosts: map[netip.Addr]*HostInfo{},
vpnCIDR: vpnCIDR,
l: l, l: l,
} }
} }
@ -305,17 +334,6 @@ func (hm *HostMap) EmitStats() {
metrics.GetOrRegisterGauge("hostmap.main.relayIndexes", nil).Update(int64(relaysLen)) metrics.GetOrRegisterGauge("hostmap.main.relayIndexes", nil).Update(int64(relaysLen))
} }
func (hm *HostMap) RemoveRelay(localIdx uint32) {
hm.Lock()
_, ok := hm.Relays[localIdx]
if !ok {
hm.Unlock()
return
}
delete(hm.Relays, localIdx)
hm.Unlock()
}
// DeleteHostInfo will fully unlink the hostinfo and return true if it was the final hostinfo for this vpn ip // DeleteHostInfo will fully unlink the hostinfo and return true if it was the final hostinfo for this vpn ip
func (hm *HostMap) DeleteHostInfo(hostinfo *HostInfo) bool { func (hm *HostMap) DeleteHostInfo(hostinfo *HostInfo) bool {
// Delete the host itself, ensuring it's not modified anymore // Delete the host itself, ensuring it's not modified anymore
@ -335,48 +353,73 @@ func (hm *HostMap) MakePrimary(hostinfo *HostInfo) {
} }
func (hm *HostMap) unlockedMakePrimary(hostinfo *HostInfo) { func (hm *HostMap) unlockedMakePrimary(hostinfo *HostInfo) {
oldHostinfo := hm.Hosts[hostinfo.vpnIp] // Get the current primary, if it exists
oldHostinfo := hm.Hosts[hostinfo.vpnAddrs[0]]
// Every address in the hostinfo gets elevated to primary
for _, vpnAddr := range hostinfo.vpnAddrs {
//NOTE: It is possible that we leave a dangling hostinfo here but connection manager works on
// indexes so it should be fine.
hm.Hosts[vpnAddr] = hostinfo
}
// If we are already primary then we won't bother re-linking
if oldHostinfo == hostinfo { if oldHostinfo == hostinfo {
return return
} }
// Unlink this hostinfo
if hostinfo.prev != nil { if hostinfo.prev != nil {
hostinfo.prev.next = hostinfo.next hostinfo.prev.next = hostinfo.next
} }
if hostinfo.next != nil { if hostinfo.next != nil {
hostinfo.next.prev = hostinfo.prev hostinfo.next.prev = hostinfo.prev
} }
hm.Hosts[hostinfo.vpnIp] = hostinfo // If there wasn't a previous primary then clear out any links
if oldHostinfo == nil { if oldHostinfo == nil {
hostinfo.next = nil
hostinfo.prev = nil
return return
} }
// Relink the hostinfo as primary
hostinfo.next = oldHostinfo hostinfo.next = oldHostinfo
oldHostinfo.prev = hostinfo oldHostinfo.prev = hostinfo
hostinfo.prev = nil hostinfo.prev = nil
} }
func (hm *HostMap) unlockedDeleteHostInfo(hostinfo *HostInfo) { func (hm *HostMap) unlockedDeleteHostInfo(hostinfo *HostInfo) {
primary, ok := hm.Hosts[hostinfo.vpnIp] for _, addr := range hostinfo.vpnAddrs {
h := hm.Hosts[addr]
for h != nil {
if h == hostinfo {
hm.unlockedInnerDeleteHostInfo(h, addr)
}
h = h.next
}
}
}
func (hm *HostMap) unlockedInnerDeleteHostInfo(hostinfo *HostInfo, addr netip.Addr) {
primary, ok := hm.Hosts[addr]
isLastHostinfo := hostinfo.next == nil && hostinfo.prev == nil
if ok && primary == hostinfo { if ok && primary == hostinfo {
// The vpnIp pointer points to the same hostinfo as the local index id, we can remove it // The vpn addr pointer points to the same hostinfo as the local index id, we can remove it
delete(hm.Hosts, hostinfo.vpnIp) delete(hm.Hosts, addr)
if len(hm.Hosts) == 0 { if len(hm.Hosts) == 0 {
hm.Hosts = map[netip.Addr]*HostInfo{} hm.Hosts = map[netip.Addr]*HostInfo{}
} }
if hostinfo.next != nil { if hostinfo.next != nil {
// We had more than 1 hostinfo at this vpnip, promote the next in the list to primary // We had more than 1 hostinfo at this vpn addr, promote the next in the list to primary
hm.Hosts[hostinfo.vpnIp] = hostinfo.next hm.Hosts[addr] = hostinfo.next
// It is primary, there is no previous hostinfo now // It is primary, there is no previous hostinfo now
hostinfo.next.prev = nil hostinfo.next.prev = nil
} }
} else { } else {
// Relink if we were in the middle of multiple hostinfos for this vpn ip // Relink if we were in the middle of multiple hostinfos for this vpn addr
if hostinfo.prev != nil { if hostinfo.prev != nil {
hostinfo.prev.next = hostinfo.next hostinfo.prev.next = hostinfo.next
} }
@ -406,10 +449,16 @@ func (hm *HostMap) unlockedDeleteHostInfo(hostinfo *HostInfo) {
if hm.l.Level >= logrus.DebugLevel { if hm.l.Level >= logrus.DebugLevel {
hm.l.WithField("hostMap", m{"mapTotalSize": len(hm.Hosts), hm.l.WithField("hostMap", m{"mapTotalSize": len(hm.Hosts),
"vpnIp": hostinfo.vpnIp, "indexNumber": hostinfo.localIndexId, "remoteIndexNumber": hostinfo.remoteIndexId}). "vpnAddrs": hostinfo.vpnAddrs, "indexNumber": hostinfo.localIndexId, "remoteIndexNumber": hostinfo.remoteIndexId}).
Debug("Hostmap hostInfo deleted") Debug("Hostmap hostInfo deleted")
} }
if isLastHostinfo {
// I have lost connectivity to my peers. My relay tunnel is likely broken. Mark the next
// hops as 'Requested' so that new relay tunnels are created in the future.
hm.unlockedDisestablishVpnAddrRelayFor(hostinfo)
}
// Clean up any local relay indexes for which I am acting as a relay hop
for _, localRelayIdx := range hostinfo.relayState.CopyRelayForIdxs() { for _, localRelayIdx := range hostinfo.relayState.CopyRelayForIdxs() {
delete(hm.Relays, localRelayIdx) delete(hm.Relays, localRelayIdx)
} }
@ -448,11 +497,11 @@ func (hm *HostMap) QueryReverseIndex(index uint32) *HostInfo {
} }
} }
func (hm *HostMap) QueryVpnIp(vpnIp netip.Addr) *HostInfo { func (hm *HostMap) QueryVpnAddr(vpnIp netip.Addr) *HostInfo {
return hm.queryVpnIp(vpnIp, nil) return hm.queryVpnAddr(vpnIp, nil)
} }
func (hm *HostMap) QueryVpnIpRelayFor(targetIp, relayHostIp netip.Addr) (*HostInfo, *Relay, error) { func (hm *HostMap) QueryVpnAddrsRelayFor(targetIps []netip.Addr, relayHostIp netip.Addr) (*HostInfo, *Relay, error) {
hm.RLock() hm.RLock()
defer hm.RUnlock() defer hm.RUnlock()
@ -460,17 +509,42 @@ func (hm *HostMap) QueryVpnIpRelayFor(targetIp, relayHostIp netip.Addr) (*HostIn
if !ok { if !ok {
return nil, nil, errors.New("unable to find host") return nil, nil, errors.New("unable to find host")
} }
for h != nil { for h != nil {
r, ok := h.relayState.QueryRelayForByIp(targetIp) for _, targetIp := range targetIps {
if ok && r.State == Established { r, ok := h.relayState.QueryRelayForByIp(targetIp)
return h, r, nil if ok && r.State == Established {
return h, r, nil
}
} }
h = h.next h = h.next
} }
return nil, nil, errors.New("unable to find host with relay") return nil, nil, errors.New("unable to find host with relay")
} }
func (hm *HostMap) queryVpnIp(vpnIp netip.Addr, promoteIfce *Interface) *HostInfo { func (hm *HostMap) unlockedDisestablishVpnAddrRelayFor(hi *HostInfo) {
for _, relayHostIp := range hi.relayState.CopyRelayIps() {
if h, ok := hm.Hosts[relayHostIp]; ok {
for h != nil {
h.relayState.UpdateRelayForByIpState(hi.vpnAddrs[0], Disestablished)
h = h.next
}
}
}
for _, rs := range hi.relayState.CopyAllRelayFor() {
if rs.Type == ForwardingType {
if h, ok := hm.Hosts[rs.PeerAddr]; ok {
for h != nil {
h.relayState.UpdateRelayForByIpState(hi.vpnAddrs[0], Disestablished)
h = h.next
}
}
}
}
}
func (hm *HostMap) queryVpnAddr(vpnIp netip.Addr, promoteIfce *Interface) *HostInfo {
hm.RLock() hm.RLock()
if h, ok := hm.Hosts[vpnIp]; ok { if h, ok := hm.Hosts[vpnIp]; ok {
hm.RUnlock() hm.RUnlock()
@ -491,25 +565,30 @@ func (hm *HostMap) queryVpnIp(vpnIp netip.Addr, promoteIfce *Interface) *HostInf
func (hm *HostMap) unlockedAddHostInfo(hostinfo *HostInfo, f *Interface) { func (hm *HostMap) unlockedAddHostInfo(hostinfo *HostInfo, f *Interface) {
if f.serveDns { if f.serveDns {
remoteCert := hostinfo.ConnectionState.peerCert remoteCert := hostinfo.ConnectionState.peerCert
dnsR.Add(remoteCert.Certificate.Name()+".", remoteCert.Certificate.Networks()[0].Addr().String()) dnsR.Add(remoteCert.Certificate.Name()+".", hostinfo.vpnAddrs)
} }
for _, addr := range hostinfo.vpnAddrs {
existing := hm.Hosts[hostinfo.vpnIp] hm.unlockedInnerAddHostInfo(addr, hostinfo, f)
hm.Hosts[hostinfo.vpnIp] = hostinfo
if existing != nil {
hostinfo.next = existing
existing.prev = hostinfo
} }
hm.Indexes[hostinfo.localIndexId] = hostinfo hm.Indexes[hostinfo.localIndexId] = hostinfo
hm.RemoteIndexes[hostinfo.remoteIndexId] = hostinfo hm.RemoteIndexes[hostinfo.remoteIndexId] = hostinfo
if hm.l.Level >= logrus.DebugLevel { if hm.l.Level >= logrus.DebugLevel {
hm.l.WithField("hostMap", m{"vpnIp": hostinfo.vpnIp, "mapTotalSize": len(hm.Hosts), hm.l.WithField("hostMap", m{"vpnAddrs": hostinfo.vpnAddrs, "mapTotalSize": len(hm.Hosts),
"hostinfo": m{"existing": true, "localIndexId": hostinfo.localIndexId, "hostId": hostinfo.vpnIp}}). "hostinfo": m{"existing": true, "localIndexId": hostinfo.localIndexId, "vpnAddrs": hostinfo.vpnAddrs}}).
Debug("Hostmap vpnIp added") Debug("Hostmap vpnIp added")
} }
}
func (hm *HostMap) unlockedInnerAddHostInfo(vpnAddr netip.Addr, hostinfo *HostInfo, f *Interface) {
existing := hm.Hosts[vpnAddr]
hm.Hosts[vpnAddr] = hostinfo
if existing != nil && existing != hostinfo {
hostinfo.next = existing
existing.prev = hostinfo
}
i := 1 i := 1
check := hostinfo check := hostinfo
@ -527,7 +606,7 @@ func (hm *HostMap) GetPreferredRanges() []netip.Prefix {
return *hm.preferredRanges.Load() return *hm.preferredRanges.Load()
} }
func (hm *HostMap) ForEachVpnIp(f controlEach) { func (hm *HostMap) ForEachVpnAddr(f controlEach) {
hm.RLock() hm.RLock()
defer hm.RUnlock() defer hm.RUnlock()
@ -581,7 +660,7 @@ func (i *HostInfo) TryPromoteBest(preferredRanges []netip.Prefix, ifce *Interfac
} }
i.nextLHQuery.Store(now + ifce.reQueryWait.Load()) i.nextLHQuery.Store(now + ifce.reQueryWait.Load())
ifce.lightHouse.QueryServer(i.vpnIp) ifce.lightHouse.QueryServer(i.vpnAddrs[0])
} }
} }
@ -596,7 +675,7 @@ func (i *HostInfo) SetRemote(remote netip.AddrPort) {
// We copy here because we likely got this remote from a source that reuses the object // We copy here because we likely got this remote from a source that reuses the object
if i.remote != remote { if i.remote != remote {
i.remote = remote i.remote = remote
i.remotes.LearnRemote(i.vpnIp, remote) i.remotes.LearnRemote(i.vpnAddrs[0], remote)
} }
} }
@ -647,21 +726,20 @@ func (i *HostInfo) RecvErrorExceeded() bool {
return true return true
} }
func (i *HostInfo) CreateRemoteCIDR(c cert.Certificate) { func (i *HostInfo) buildNetworks(networks, unsafeNetworks []netip.Prefix) {
if len(c.Networks()) == 1 && len(c.UnsafeNetworks()) == 0 { if len(networks) == 1 && len(unsafeNetworks) == 0 {
// Simple case, no CIDRTree needed // Simple case, no CIDRTree needed
return return
} }
remoteCidr := new(bart.Table[struct{}]) i.networks = new(bart.Table[struct{}])
for _, network := range c.Networks() { for _, network := range networks {
remoteCidr.Insert(network, struct{}{}) i.networks.Insert(network, struct{}{})
} }
for _, network := range c.UnsafeNetworks() { for _, network := range unsafeNetworks {
remoteCidr.Insert(network, struct{}{}) i.networks.Insert(network, struct{}{})
} }
i.remoteCidr = remoteCidr
} }
func (i *HostInfo) logger(l *logrus.Logger) *logrus.Entry { func (i *HostInfo) logger(l *logrus.Logger) *logrus.Entry {
@ -669,7 +747,7 @@ func (i *HostInfo) logger(l *logrus.Logger) *logrus.Entry {
return logrus.NewEntry(l) return logrus.NewEntry(l)
} }
li := l.WithField("vpnIp", i.vpnIp). li := l.WithField("vpnAddrs", i.vpnAddrs).
WithField("localIndex", i.localIndexId). WithField("localIndex", i.localIndexId).
WithField("remoteIndex", i.remoteIndexId) WithField("remoteIndex", i.remoteIndexId)
@ -684,9 +762,9 @@ func (i *HostInfo) logger(l *logrus.Logger) *logrus.Entry {
// Utility functions // Utility functions
func localIps(l *logrus.Logger, allowList *LocalAllowList) []netip.Addr { func localAddrs(l *logrus.Logger, allowList *LocalAllowList) []netip.Addr {
//FIXME: This function is pretty garbage //FIXME: This function is pretty garbage
var ips []netip.Addr var finalAddrs []netip.Addr
ifaces, _ := net.Interfaces() ifaces, _ := net.Interfaces()
for _, i := range ifaces { for _, i := range ifaces {
allow := allowList.AllowName(i.Name) allow := allowList.AllowName(i.Name)
@ -698,39 +776,36 @@ func localIps(l *logrus.Logger, allowList *LocalAllowList) []netip.Addr {
continue continue
} }
addrs, _ := i.Addrs() addrs, _ := i.Addrs()
for _, addr := range addrs { for _, rawAddr := range addrs {
var ip net.IP var addr netip.Addr
switch v := addr.(type) { switch v := rawAddr.(type) {
case *net.IPNet: case *net.IPNet:
//continue //continue
ip = v.IP addr, _ = netip.AddrFromSlice(v.IP)
case *net.IPAddr: case *net.IPAddr:
ip = v.IP addr, _ = netip.AddrFromSlice(v.IP)
} }
nip, ok := netip.AddrFromSlice(ip) if !addr.IsValid() {
if !ok {
if l.Level >= logrus.DebugLevel { if l.Level >= logrus.DebugLevel {
l.WithField("localIp", ip).Debug("ip was invalid for netip") l.WithField("localAddr", rawAddr).Debug("addr was invalid")
} }
continue continue
} }
nip = nip.Unmap() addr = addr.Unmap()
//TODO: Filtering out link local for now, this is probably the most correct thing if addr.IsLoopback() == false && addr.IsLinkLocalUnicast() == false {
//TODO: Would be nice to filter out SLAAC MAC based ips as well isAllowed := allowList.Allow(addr)
if nip.IsLoopback() == false && nip.IsLinkLocalUnicast() == false {
allow := allowList.Allow(nip)
if l.Level >= logrus.TraceLevel { if l.Level >= logrus.TraceLevel {
l.WithField("localIp", nip).WithField("allow", allow).Trace("localAllowList.Allow") l.WithField("localAddr", addr).WithField("allowed", isAllowed).Trace("localAllowList.Allow")
} }
if !allow { if !isAllowed {
continue continue
} }
ips = append(ips, nip) finalAddrs = append(finalAddrs, addr)
} }
} }
} }
return ips return finalAddrs
} }

View File

@ -11,17 +11,14 @@ import (
func TestHostMap_MakePrimary(t *testing.T) { func TestHostMap_MakePrimary(t *testing.T) {
l := test.NewLogger() l := test.NewLogger()
hm := newHostMap( hm := newHostMap(l)
l,
netip.MustParsePrefix("10.0.0.1/24"),
)
f := &Interface{} f := &Interface{}
h1 := &HostInfo{vpnIp: netip.MustParseAddr("0.0.0.1"), localIndexId: 1} h1 := &HostInfo{vpnAddrs: []netip.Addr{netip.MustParseAddr("0.0.0.1")}, localIndexId: 1}
h2 := &HostInfo{vpnIp: netip.MustParseAddr("0.0.0.1"), localIndexId: 2} h2 := &HostInfo{vpnAddrs: []netip.Addr{netip.MustParseAddr("0.0.0.1")}, localIndexId: 2}
h3 := &HostInfo{vpnIp: netip.MustParseAddr("0.0.0.1"), localIndexId: 3} h3 := &HostInfo{vpnAddrs: []netip.Addr{netip.MustParseAddr("0.0.0.1")}, localIndexId: 3}
h4 := &HostInfo{vpnIp: netip.MustParseAddr("0.0.0.1"), localIndexId: 4} h4 := &HostInfo{vpnAddrs: []netip.Addr{netip.MustParseAddr("0.0.0.1")}, localIndexId: 4}
hm.unlockedAddHostInfo(h4, f) hm.unlockedAddHostInfo(h4, f)
hm.unlockedAddHostInfo(h3, f) hm.unlockedAddHostInfo(h3, f)
@ -29,7 +26,7 @@ func TestHostMap_MakePrimary(t *testing.T) {
hm.unlockedAddHostInfo(h1, f) hm.unlockedAddHostInfo(h1, f)
// Make sure we go h1 -> h2 -> h3 -> h4 // Make sure we go h1 -> h2 -> h3 -> h4
prim := hm.QueryVpnIp(netip.MustParseAddr("0.0.0.1")) prim := hm.QueryVpnAddr(netip.MustParseAddr("0.0.0.1"))
assert.Equal(t, h1.localIndexId, prim.localIndexId) assert.Equal(t, h1.localIndexId, prim.localIndexId)
assert.Equal(t, h2.localIndexId, prim.next.localIndexId) assert.Equal(t, h2.localIndexId, prim.next.localIndexId)
assert.Nil(t, prim.prev) assert.Nil(t, prim.prev)
@ -44,7 +41,7 @@ func TestHostMap_MakePrimary(t *testing.T) {
hm.MakePrimary(h3) hm.MakePrimary(h3)
// Make sure we go h3 -> h1 -> h2 -> h4 // Make sure we go h3 -> h1 -> h2 -> h4
prim = hm.QueryVpnIp(netip.MustParseAddr("0.0.0.1")) prim = hm.QueryVpnAddr(netip.MustParseAddr("0.0.0.1"))
assert.Equal(t, h3.localIndexId, prim.localIndexId) assert.Equal(t, h3.localIndexId, prim.localIndexId)
assert.Equal(t, h1.localIndexId, prim.next.localIndexId) assert.Equal(t, h1.localIndexId, prim.next.localIndexId)
assert.Nil(t, prim.prev) assert.Nil(t, prim.prev)
@ -59,7 +56,7 @@ func TestHostMap_MakePrimary(t *testing.T) {
hm.MakePrimary(h4) hm.MakePrimary(h4)
// Make sure we go h4 -> h3 -> h1 -> h2 // Make sure we go h4 -> h3 -> h1 -> h2
prim = hm.QueryVpnIp(netip.MustParseAddr("0.0.0.1")) prim = hm.QueryVpnAddr(netip.MustParseAddr("0.0.0.1"))
assert.Equal(t, h4.localIndexId, prim.localIndexId) assert.Equal(t, h4.localIndexId, prim.localIndexId)
assert.Equal(t, h3.localIndexId, prim.next.localIndexId) assert.Equal(t, h3.localIndexId, prim.next.localIndexId)
assert.Nil(t, prim.prev) assert.Nil(t, prim.prev)
@ -74,7 +71,7 @@ func TestHostMap_MakePrimary(t *testing.T) {
hm.MakePrimary(h4) hm.MakePrimary(h4)
// Make sure we go h4 -> h3 -> h1 -> h2 // Make sure we go h4 -> h3 -> h1 -> h2
prim = hm.QueryVpnIp(netip.MustParseAddr("0.0.0.1")) prim = hm.QueryVpnAddr(netip.MustParseAddr("0.0.0.1"))
assert.Equal(t, h4.localIndexId, prim.localIndexId) assert.Equal(t, h4.localIndexId, prim.localIndexId)
assert.Equal(t, h3.localIndexId, prim.next.localIndexId) assert.Equal(t, h3.localIndexId, prim.next.localIndexId)
assert.Nil(t, prim.prev) assert.Nil(t, prim.prev)
@ -88,19 +85,16 @@ func TestHostMap_MakePrimary(t *testing.T) {
func TestHostMap_DeleteHostInfo(t *testing.T) { func TestHostMap_DeleteHostInfo(t *testing.T) {
l := test.NewLogger() l := test.NewLogger()
hm := newHostMap( hm := newHostMap(l)
l,
netip.MustParsePrefix("10.0.0.1/24"),
)
f := &Interface{} f := &Interface{}
h1 := &HostInfo{vpnIp: netip.MustParseAddr("0.0.0.1"), localIndexId: 1} h1 := &HostInfo{vpnAddrs: []netip.Addr{netip.MustParseAddr("0.0.0.1")}, localIndexId: 1}
h2 := &HostInfo{vpnIp: netip.MustParseAddr("0.0.0.1"), localIndexId: 2} h2 := &HostInfo{vpnAddrs: []netip.Addr{netip.MustParseAddr("0.0.0.1")}, localIndexId: 2}
h3 := &HostInfo{vpnIp: netip.MustParseAddr("0.0.0.1"), localIndexId: 3} h3 := &HostInfo{vpnAddrs: []netip.Addr{netip.MustParseAddr("0.0.0.1")}, localIndexId: 3}
h4 := &HostInfo{vpnIp: netip.MustParseAddr("0.0.0.1"), localIndexId: 4} h4 := &HostInfo{vpnAddrs: []netip.Addr{netip.MustParseAddr("0.0.0.1")}, localIndexId: 4}
h5 := &HostInfo{vpnIp: netip.MustParseAddr("0.0.0.1"), localIndexId: 5} h5 := &HostInfo{vpnAddrs: []netip.Addr{netip.MustParseAddr("0.0.0.1")}, localIndexId: 5}
h6 := &HostInfo{vpnIp: netip.MustParseAddr("0.0.0.1"), localIndexId: 6} h6 := &HostInfo{vpnAddrs: []netip.Addr{netip.MustParseAddr("0.0.0.1")}, localIndexId: 6}
hm.unlockedAddHostInfo(h6, f) hm.unlockedAddHostInfo(h6, f)
hm.unlockedAddHostInfo(h5, f) hm.unlockedAddHostInfo(h5, f)
@ -116,7 +110,7 @@ func TestHostMap_DeleteHostInfo(t *testing.T) {
assert.Nil(t, h) assert.Nil(t, h)
// Make sure we go h1 -> h2 -> h3 -> h4 -> h5 // Make sure we go h1 -> h2 -> h3 -> h4 -> h5
prim := hm.QueryVpnIp(netip.MustParseAddr("0.0.0.1")) prim := hm.QueryVpnAddr(netip.MustParseAddr("0.0.0.1"))
assert.Equal(t, h1.localIndexId, prim.localIndexId) assert.Equal(t, h1.localIndexId, prim.localIndexId)
assert.Equal(t, h2.localIndexId, prim.next.localIndexId) assert.Equal(t, h2.localIndexId, prim.next.localIndexId)
assert.Nil(t, prim.prev) assert.Nil(t, prim.prev)
@ -135,7 +129,7 @@ func TestHostMap_DeleteHostInfo(t *testing.T) {
assert.Nil(t, h1.next) assert.Nil(t, h1.next)
// Make sure we go h2 -> h3 -> h4 -> h5 // Make sure we go h2 -> h3 -> h4 -> h5
prim = hm.QueryVpnIp(netip.MustParseAddr("0.0.0.1")) prim = hm.QueryVpnAddr(netip.MustParseAddr("0.0.0.1"))
assert.Equal(t, h2.localIndexId, prim.localIndexId) assert.Equal(t, h2.localIndexId, prim.localIndexId)
assert.Equal(t, h3.localIndexId, prim.next.localIndexId) assert.Equal(t, h3.localIndexId, prim.next.localIndexId)
assert.Nil(t, prim.prev) assert.Nil(t, prim.prev)
@ -153,7 +147,7 @@ func TestHostMap_DeleteHostInfo(t *testing.T) {
assert.Nil(t, h3.next) assert.Nil(t, h3.next)
// Make sure we go h2 -> h4 -> h5 // Make sure we go h2 -> h4 -> h5
prim = hm.QueryVpnIp(netip.MustParseAddr("0.0.0.1")) prim = hm.QueryVpnAddr(netip.MustParseAddr("0.0.0.1"))
assert.Equal(t, h2.localIndexId, prim.localIndexId) assert.Equal(t, h2.localIndexId, prim.localIndexId)
assert.Equal(t, h4.localIndexId, prim.next.localIndexId) assert.Equal(t, h4.localIndexId, prim.next.localIndexId)
assert.Nil(t, prim.prev) assert.Nil(t, prim.prev)
@ -169,7 +163,7 @@ func TestHostMap_DeleteHostInfo(t *testing.T) {
assert.Nil(t, h5.next) assert.Nil(t, h5.next)
// Make sure we go h2 -> h4 // Make sure we go h2 -> h4
prim = hm.QueryVpnIp(netip.MustParseAddr("0.0.0.1")) prim = hm.QueryVpnAddr(netip.MustParseAddr("0.0.0.1"))
assert.Equal(t, h2.localIndexId, prim.localIndexId) assert.Equal(t, h2.localIndexId, prim.localIndexId)
assert.Equal(t, h4.localIndexId, prim.next.localIndexId) assert.Equal(t, h4.localIndexId, prim.next.localIndexId)
assert.Nil(t, prim.prev) assert.Nil(t, prim.prev)
@ -183,7 +177,7 @@ func TestHostMap_DeleteHostInfo(t *testing.T) {
assert.Nil(t, h2.next) assert.Nil(t, h2.next)
// Make sure we only have h4 // Make sure we only have h4
prim = hm.QueryVpnIp(netip.MustParseAddr("0.0.0.1")) prim = hm.QueryVpnAddr(netip.MustParseAddr("0.0.0.1"))
assert.Equal(t, h4.localIndexId, prim.localIndexId) assert.Equal(t, h4.localIndexId, prim.localIndexId)
assert.Nil(t, prim.prev) assert.Nil(t, prim.prev)
assert.Nil(t, prim.next) assert.Nil(t, prim.next)
@ -195,7 +189,7 @@ func TestHostMap_DeleteHostInfo(t *testing.T) {
assert.Nil(t, h4.next) assert.Nil(t, h4.next)
// Make sure we have nil // Make sure we have nil
prim = hm.QueryVpnIp(netip.MustParseAddr("0.0.0.1")) prim = hm.QueryVpnAddr(netip.MustParseAddr("0.0.0.1"))
assert.Nil(t, prim) assert.Nil(t, prim)
} }
@ -203,11 +197,7 @@ func TestHostMap_reload(t *testing.T) {
l := test.NewLogger() l := test.NewLogger()
c := config.NewC(l) c := config.NewC(l)
hm := NewHostMapFromConfig( hm := NewHostMapFromConfig(l, c)
l,
netip.MustParsePrefix("10.0.0.1/24"),
c,
)
toS := func(ipn []netip.Prefix) []string { toS := func(ipn []netip.Prefix) []string {
var s []string var s []string

View File

@ -9,8 +9,8 @@ import (
"net/netip" "net/netip"
) )
func (i *HostInfo) GetVpnIp() netip.Addr { func (i *HostInfo) GetVpnAddrs() []netip.Addr {
return i.vpnIp return i.vpnAddrs
} }
func (i *HostInfo) GetLocalIndex() uint32 { func (i *HostInfo) GetLocalIndex() uint32 {

View File

@ -20,14 +20,18 @@ func (f *Interface) consumeInsidePacket(packet []byte, fwPacket *firewall.Packet
} }
// Ignore local broadcast packets // Ignore local broadcast packets
if f.dropLocalBroadcast && fwPacket.RemoteIP == f.myBroadcastAddr { if f.dropLocalBroadcast {
return _, found := f.myBroadcastAddrsTable.Lookup(fwPacket.RemoteAddr)
if found {
return
}
} }
if fwPacket.RemoteIP == f.myVpnNet.Addr() { _, found := f.myVpnAddrsTable.Lookup(fwPacket.RemoteAddr)
if found {
// Immediately forward packets from self to self. // Immediately forward packets from self to self.
// This should only happen on Darwin-based and FreeBSD hosts, which // This should only happen on Darwin-based and FreeBSD hosts, which
// routes packets from the Nebula IP to the Nebula IP through the Nebula // routes packets from the Nebula addr to the Nebula addr through the Nebula
// TUN device. // TUN device.
if immediatelyForwardToSelf { if immediatelyForwardToSelf {
_, err := f.readers[q].Write(packet) _, err := f.readers[q].Write(packet)
@ -36,25 +40,25 @@ func (f *Interface) consumeInsidePacket(packet []byte, fwPacket *firewall.Packet
} }
} }
// Otherwise, drop. On linux, we should never see these packets - Linux // Otherwise, drop. On linux, we should never see these packets - Linux
// routes packets from the nebula IP to the nebula IP through the loopback device. // routes packets from the nebula addr to the nebula addr through the loopback device.
return return
} }
// Ignore multicast packets // Ignore multicast packets
if f.dropMulticast && fwPacket.RemoteIP.IsMulticast() { if f.dropMulticast && fwPacket.RemoteAddr.IsMulticast() {
return return
} }
hostinfo, ready := f.getOrHandshake(fwPacket.RemoteIP, func(hh *HandshakeHostInfo) { hostinfo, ready := f.getOrHandshake(fwPacket.RemoteAddr, func(hh *HandshakeHostInfo) {
hh.cachePacket(f.l, header.Message, 0, packet, f.sendMessageNow, f.cachedPacketMetrics) hh.cachePacket(f.l, header.Message, 0, packet, f.sendMessageNow, f.cachedPacketMetrics)
}) })
if hostinfo == nil { if hostinfo == nil {
f.rejectInside(packet, out, q) f.rejectInside(packet, out, q)
if f.l.Level >= logrus.DebugLevel { if f.l.Level >= logrus.DebugLevel {
f.l.WithField("vpnIp", fwPacket.RemoteIP). f.l.WithField("vpnAddr", fwPacket.RemoteAddr).
WithField("fwPacket", fwPacket). WithField("fwPacket", fwPacket).
Debugln("dropping outbound packet, vpnIp not in our CIDR or in unsafe routes") Debugln("dropping outbound packet, vpnAddr not in our vpn networks or in unsafe networks")
} }
return return
} }
@ -117,21 +121,22 @@ func (f *Interface) rejectOutside(packet []byte, ci *ConnectionState, hostinfo *
f.sendNoMetrics(header.Message, 0, ci, hostinfo, netip.AddrPort{}, out, nb, packet, q) f.sendNoMetrics(header.Message, 0, ci, hostinfo, netip.AddrPort{}, out, nb, packet, q)
} }
func (f *Interface) Handshake(vpnIp netip.Addr) { func (f *Interface) Handshake(vpnAddr netip.Addr) {
f.getOrHandshake(vpnIp, nil) f.getOrHandshake(vpnAddr, nil)
} }
// getOrHandshake returns nil if the vpnIp is not routable. // getOrHandshake returns nil if the vpnAddr is not routable.
// If the 2nd return var is false then the hostinfo is not ready to be used in a tunnel // If the 2nd return var is false then the hostinfo is not ready to be used in a tunnel
func (f *Interface) getOrHandshake(vpnIp netip.Addr, cacheCallback func(*HandshakeHostInfo)) (*HostInfo, bool) { func (f *Interface) getOrHandshake(vpnAddr netip.Addr, cacheCallback func(*HandshakeHostInfo)) (*HostInfo, bool) {
if !f.myVpnNet.Contains(vpnIp) { _, found := f.myVpnNetworksTable.Lookup(vpnAddr)
vpnIp = f.inside.RouteFor(vpnIp) if !found {
if !vpnIp.IsValid() { vpnAddr = f.inside.RouteFor(vpnAddr)
if !vpnAddr.IsValid() {
return nil, false return nil, false
} }
} }
return f.handshakeManager.GetOrHandshake(vpnIp, cacheCallback) return f.handshakeManager.GetOrHandshake(vpnAddr, cacheCallback)
} }
func (f *Interface) sendMessageNow(t header.MessageType, st header.MessageSubType, hostinfo *HostInfo, p, nb, out []byte) { func (f *Interface) sendMessageNow(t header.MessageType, st header.MessageSubType, hostinfo *HostInfo, p, nb, out []byte) {
@ -156,16 +161,16 @@ func (f *Interface) sendMessageNow(t header.MessageType, st header.MessageSubTyp
f.sendNoMetrics(header.Message, st, hostinfo.ConnectionState, hostinfo, netip.AddrPort{}, p, nb, out, 0) f.sendNoMetrics(header.Message, st, hostinfo.ConnectionState, hostinfo, netip.AddrPort{}, p, nb, out, 0)
} }
// SendMessageToVpnIp handles real ip:port lookup and sends to the current best known address for vpnIp // SendMessageToVpnAddr handles real addr:port lookup and sends to the current best known address for vpnAddr
func (f *Interface) SendMessageToVpnIp(t header.MessageType, st header.MessageSubType, vpnIp netip.Addr, p, nb, out []byte) { func (f *Interface) SendMessageToVpnAddr(t header.MessageType, st header.MessageSubType, vpnAddr netip.Addr, p, nb, out []byte) {
hostInfo, ready := f.getOrHandshake(vpnIp, func(hh *HandshakeHostInfo) { hostInfo, ready := f.getOrHandshake(vpnAddr, func(hh *HandshakeHostInfo) {
hh.cachePacket(f.l, t, st, p, f.SendMessageToHostInfo, f.cachedPacketMetrics) hh.cachePacket(f.l, t, st, p, f.SendMessageToHostInfo, f.cachedPacketMetrics)
}) })
if hostInfo == nil { if hostInfo == nil {
if f.l.Level >= logrus.DebugLevel { if f.l.Level >= logrus.DebugLevel {
f.l.WithField("vpnIp", vpnIp). f.l.WithField("vpnAddr", vpnAddr).
Debugln("dropping SendMessageToVpnIp, vpnIp not in our CIDR or in unsafe routes") Debugln("dropping SendMessageToVpnAddr, vpnAddr not in our vpn networks or in unsafe routes")
} }
return return
} }
@ -258,7 +263,6 @@ func (f *Interface) SendVia(via *HostInfo,
func (f *Interface) sendNoMetrics(t header.MessageType, st header.MessageSubType, ci *ConnectionState, hostinfo *HostInfo, remote netip.AddrPort, p, nb, out []byte, q int) { func (f *Interface) sendNoMetrics(t header.MessageType, st header.MessageSubType, ci *ConnectionState, hostinfo *HostInfo, remote netip.AddrPort, p, nb, out []byte, q int) {
if ci.eKey == nil { if ci.eKey == nil {
//TODO: log warning
return return
} }
useRelay := !remote.IsValid() && !hostinfo.remote.IsValid() useRelay := !remote.IsValid() && !hostinfo.remote.IsValid()
@ -285,14 +289,14 @@ func (f *Interface) sendNoMetrics(t header.MessageType, st header.MessageSubType
f.connectionManager.Out(hostinfo.localIndexId) f.connectionManager.Out(hostinfo.localIndexId)
// Query our LH if we haven't since the last time we've been rebound, this will cause the remote to punch against // Query our LH if we haven't since the last time we've been rebound, this will cause the remote to punch against
// all our IPs and enable a faster roaming. // all our addrs and enable a faster roaming.
if t != header.CloseTunnel && hostinfo.lastRebindCount != f.rebindCount { if t != header.CloseTunnel && hostinfo.lastRebindCount != f.rebindCount {
//NOTE: there is an update hole if a tunnel isn't used and exactly 256 rebinds occur before the tunnel is //NOTE: there is an update hole if a tunnel isn't used and exactly 256 rebinds occur before the tunnel is
// finally used again. This tunnel would eventually be torn down and recreated if this action didn't help. // finally used again. This tunnel would eventually be torn down and recreated if this action didn't help.
f.lightHouse.QueryServer(hostinfo.vpnIp) f.lightHouse.QueryServer(hostinfo.vpnAddrs[0])
hostinfo.lastRebindCount = f.rebindCount hostinfo.lastRebindCount = f.rebindCount
if f.l.Level >= logrus.DebugLevel { if f.l.Level >= logrus.DebugLevel {
f.l.WithField("vpnIp", hostinfo.vpnIp).Debug("Lighthouse update triggered for punch due to rebind counter") f.l.WithField("vpnAddrs", hostinfo.vpnAddrs).Debug("Lighthouse update triggered for punch due to rebind counter")
} }
} }
@ -324,7 +328,7 @@ func (f *Interface) sendNoMetrics(t header.MessageType, st header.MessageSubType
} else { } else {
// Try to send via a relay // Try to send via a relay
for _, relayIP := range hostinfo.relayState.CopyRelayIps() { for _, relayIP := range hostinfo.relayState.CopyRelayIps() {
relayHostInfo, relay, err := f.hostMap.QueryVpnIpRelayFor(hostinfo.vpnIp, relayIP) relayHostInfo, relay, err := f.hostMap.QueryVpnAddrsRelayFor(hostinfo.vpnAddrs, relayIP)
if err != nil { if err != nil {
hostinfo.relayState.DeleteRelay(relayIP) hostinfo.relayState.DeleteRelay(relayIP)
hostinfo.logger(f.l).WithField("relay", relayIP).WithError(err).Info("sendNoMetrics failed to find HostInfo") hostinfo.logger(f.l).WithField("relay", relayIP).WithError(err).Info("sendNoMetrics failed to find HostInfo")

View File

@ -2,17 +2,16 @@ package nebula
import ( import (
"context" "context"
"encoding/binary"
"errors" "errors"
"fmt" "fmt"
"io" "io"
"net"
"net/netip" "net/netip"
"os" "os"
"runtime" "runtime"
"sync/atomic" "sync/atomic"
"time" "time"
"github.com/gaissmai/bart"
"github.com/rcrowley/go-metrics" "github.com/rcrowley/go-metrics"
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
"github.com/slackhq/nebula/config" "github.com/slackhq/nebula/config"
@ -29,7 +28,6 @@ type InterfaceConfig struct {
Outside udp.Conn Outside udp.Conn
Inside overlay.Device Inside overlay.Device
pki *PKI pki *PKI
Cipher string
Firewall *Firewall Firewall *Firewall
ServeDns bool ServeDns bool
HandshakeManager *HandshakeManager HandshakeManager *HandshakeManager
@ -53,25 +51,27 @@ type InterfaceConfig struct {
} }
type Interface struct { type Interface struct {
hostMap *HostMap hostMap *HostMap
outside udp.Conn outside udp.Conn
inside overlay.Device inside overlay.Device
pki *PKI pki *PKI
cipher string firewall *Firewall
firewall *Firewall connectionManager *connectionManager
connectionManager *connectionManager handshakeManager *HandshakeManager
handshakeManager *HandshakeManager serveDns bool
serveDns bool createTime time.Time
createTime time.Time lightHouse *LightHouse
lightHouse *LightHouse myBroadcastAddrsTable *bart.Table[struct{}]
myBroadcastAddr netip.Addr myVpnAddrs []netip.Addr // A list of addresses assigned to us via our certificate
myVpnNet netip.Prefix myVpnAddrsTable *bart.Table[struct{}] // A table of addresses assigned to us via our certificate
dropLocalBroadcast bool myVpnNetworks []netip.Prefix // A list of networks assigned to us via our certificate
dropMulticast bool myVpnNetworksTable *bart.Table[struct{}] // A table of networks assigned to us via our certificate
routines int dropLocalBroadcast bool
disconnectInvalid atomic.Bool dropMulticast bool
closed atomic.Bool routines int
relayManager *relayManager disconnectInvalid atomic.Bool
closed atomic.Bool
relayManager *relayManager
tryPromoteEvery atomic.Uint32 tryPromoteEvery atomic.Uint32
reQueryEvery atomic.Uint32 reQueryEvery atomic.Uint32
@ -103,9 +103,11 @@ type EncWriter interface {
out []byte, out []byte,
nocopy bool, nocopy bool,
) )
SendMessageToVpnIp(t header.MessageType, st header.MessageSubType, vpnIp netip.Addr, p, nb, out []byte) SendMessageToVpnAddr(t header.MessageType, st header.MessageSubType, vpnAddr netip.Addr, p, nb, out []byte)
SendMessageToHostInfo(t header.MessageType, st header.MessageSubType, hostinfo *HostInfo, p, nb, out []byte) SendMessageToHostInfo(t header.MessageType, st header.MessageSubType, hostinfo *HostInfo, p, nb, out []byte)
Handshake(vpnIp netip.Addr) Handshake(vpnAddr netip.Addr)
GetHostInfo(vpnAddr netip.Addr) *HostInfo
GetCertState() *CertState
} }
type sendRecvErrorConfig uint8 type sendRecvErrorConfig uint8
@ -116,10 +118,10 @@ const (
sendRecvErrorPrivate sendRecvErrorPrivate
) )
func (s sendRecvErrorConfig) ShouldSendRecvError(ip netip.AddrPort) bool { func (s sendRecvErrorConfig) ShouldSendRecvError(endpoint netip.AddrPort) bool {
switch s { switch s {
case sendRecvErrorPrivate: case sendRecvErrorPrivate:
return ip.Addr().IsPrivate() return endpoint.Addr().IsPrivate()
case sendRecvErrorAlways: case sendRecvErrorAlways:
return true return true
case sendRecvErrorNever: case sendRecvErrorNever:
@ -156,27 +158,29 @@ func NewInterface(ctx context.Context, c *InterfaceConfig) (*Interface, error) {
return nil, errors.New("no firewall rules") return nil, errors.New("no firewall rules")
} }
certificate := c.pki.GetCertState().Certificate cs := c.pki.getCertState()
ifce := &Interface{ ifce := &Interface{
pki: c.pki, pki: c.pki,
hostMap: c.HostMap, hostMap: c.HostMap,
outside: c.Outside, outside: c.Outside,
inside: c.Inside, inside: c.Inside,
cipher: c.Cipher, firewall: c.Firewall,
firewall: c.Firewall, serveDns: c.ServeDns,
serveDns: c.ServeDns, handshakeManager: c.HandshakeManager,
handshakeManager: c.HandshakeManager, createTime: time.Now(),
createTime: time.Now(), lightHouse: c.lightHouse,
lightHouse: c.lightHouse, dropLocalBroadcast: c.DropLocalBroadcast,
dropLocalBroadcast: c.DropLocalBroadcast, dropMulticast: c.DropMulticast,
dropMulticast: c.DropMulticast, routines: c.routines,
routines: c.routines, version: c.version,
version: c.version, writers: make([]udp.Conn, c.routines),
writers: make([]udp.Conn, c.routines), readers: make([]io.ReadWriteCloser, c.routines),
readers: make([]io.ReadWriteCloser, c.routines), myVpnNetworks: cs.myVpnNetworks,
myVpnNet: certificate.Networks()[0], myVpnNetworksTable: cs.myVpnNetworksTable,
relayManager: c.relayManager, myVpnAddrs: cs.myVpnAddrs,
myVpnAddrsTable: cs.myVpnAddrsTable,
myBroadcastAddrsTable: cs.myVpnBroadcastAddrsTable,
relayManager: c.relayManager,
conntrackCacheTimeout: c.ConntrackCacheTimeout, conntrackCacheTimeout: c.ConntrackCacheTimeout,
@ -190,14 +194,6 @@ func NewInterface(ctx context.Context, c *InterfaceConfig) (*Interface, error) {
l: c.l, l: c.l,
} }
if ifce.myVpnNet.Addr().Is4() {
maskedAddr := certificate.Networks()[0].Masked()
addr := maskedAddr.Addr().As4()
mask := net.CIDRMask(maskedAddr.Bits(), maskedAddr.Addr().BitLen())
binary.BigEndian.PutUint32(addr[:], binary.BigEndian.Uint32(addr[:])|^binary.BigEndian.Uint32(mask))
ifce.myBroadcastAddr = netip.AddrFrom4(addr)
}
ifce.tryPromoteEvery.Store(c.tryPromoteEvery) ifce.tryPromoteEvery.Store(c.tryPromoteEvery)
ifce.reQueryEvery.Store(c.reQueryEvery) ifce.reQueryEvery.Store(c.reQueryEvery)
ifce.reQueryWait.Store(int64(c.reQueryWait)) ifce.reQueryWait.Store(int64(c.reQueryWait))
@ -218,7 +214,7 @@ func (f *Interface) activate() {
f.l.WithError(err).Error("Failed to get udp listen address") f.l.WithError(err).Error("Failed to get udp listen address")
} }
f.l.WithField("interface", f.inside.Name()).WithField("network", f.inside.Cidr().String()). f.l.WithField("interface", f.inside.Name()).WithField("networks", f.myVpnNetworks).
WithField("build", f.version).WithField("udpAddr", addr). WithField("build", f.version).WithField("udpAddr", addr).
WithField("boringcrypto", boringEnabled()). WithField("boringcrypto", boringEnabled()).
Info("Nebula interface is active") Info("Nebula interface is active")
@ -259,16 +255,22 @@ func (f *Interface) listenOut(i int) {
runtime.LockOSThread() runtime.LockOSThread()
var li udp.Conn var li udp.Conn
// TODO clean this up with a coherent interface for each outside connection
if i > 0 { if i > 0 {
li = f.writers[i] li = f.writers[i]
} else { } else {
li = f.outside li = f.outside
} }
ctCache := firewall.NewConntrackCacheTicker(f.conntrackCacheTimeout)
lhh := f.lightHouse.NewRequestHandler() lhh := f.lightHouse.NewRequestHandler()
conntrackCache := firewall.NewConntrackCacheTicker(f.conntrackCacheTimeout) plaintext := make([]byte, udp.MTU)
li.ListenOut(readOutsidePackets(f), lhHandleRequest(lhh, f), conntrackCache, i) h := &header.H{}
fwPacket := &firewall.Packet{}
nb := make([]byte, 12, 12)
li.ListenOut(func(fromUdpAddr netip.AddrPort, payload []byte) {
f.readOutsidePackets(fromUdpAddr, nil, plaintext[:0], payload, h, fwPacket, lhh, nb, i, ctCache.Get(f.l))
})
} }
func (f *Interface) listenIn(reader io.ReadWriteCloser, i int) { func (f *Interface) listenIn(reader io.ReadWriteCloser, i int) {
@ -325,7 +327,7 @@ func (f *Interface) reloadFirewall(c *config.C) {
return return
} }
fw, err := NewFirewallFromConfig(f.l, f.pki.GetCertState().Certificate, c) fw, err := NewFirewallFromConfig(f.l, f.pki.getCertState(), c)
if err != nil { if err != nil {
f.l.WithError(err).Error("Error while creating firewall during reload") f.l.WithError(err).Error("Error while creating firewall during reload")
return return
@ -408,6 +410,8 @@ func (f *Interface) emitStats(ctx context.Context, i time.Duration) {
udpStats := udp.NewUDPStatsEmitter(f.writers) udpStats := udp.NewUDPStatsEmitter(f.writers)
certExpirationGauge := metrics.GetOrRegisterGauge("certificate.ttl_seconds", nil) certExpirationGauge := metrics.GetOrRegisterGauge("certificate.ttl_seconds", nil)
certDefaultVersion := metrics.GetOrRegisterGauge("certificate.default_version", nil)
certMaxVersion := metrics.GetOrRegisterGauge("certificate.max_version", nil)
for { for {
select { select {
@ -417,11 +421,30 @@ func (f *Interface) emitStats(ctx context.Context, i time.Duration) {
f.firewall.EmitStats() f.firewall.EmitStats()
f.handshakeManager.EmitStats() f.handshakeManager.EmitStats()
udpStats() udpStats()
certExpirationGauge.Update(int64(f.pki.GetCertState().Certificate.NotAfter().Sub(time.Now()) / time.Second))
certState := f.pki.getCertState()
defaultCrt := certState.GetDefaultCertificate()
certExpirationGauge.Update(int64(defaultCrt.NotAfter().Sub(time.Now()) / time.Second))
certDefaultVersion.Update(int64(defaultCrt.Version()))
// Report the max certificate version we are capable of using
if certState.v2Cert != nil {
certMaxVersion.Update(int64(certState.v2Cert.Version()))
} else {
certMaxVersion.Update(int64(certState.v1Cert.Version()))
}
} }
} }
} }
func (f *Interface) GetHostInfo(vpnIp netip.Addr) *HostInfo {
return f.hostMap.QueryVpnAddr(vpnIp)
}
func (f *Interface) GetCertState() *CertState {
return f.pki.getCertState()
}
func (f *Interface) Close() error { func (f *Interface) Close() error {
f.closed.Store(true) f.closed.Store(true)

View File

@ -6,8 +6,6 @@ import (
"golang.org/x/net/ipv4" "golang.org/x/net/ipv4"
) )
//TODO: IPV6-WORK can probably delete this
const ( const (
// Need 96 bytes for the largest reject packet: // Need 96 bytes for the largest reject packet:
// - 20 byte ipv4 header // - 20 byte ipv4 header

File diff suppressed because it is too large Load Diff

View File

@ -7,6 +7,8 @@ import (
"net/netip" "net/netip"
"testing" "testing"
"github.com/gaissmai/bart"
"github.com/slackhq/nebula/cert"
"github.com/slackhq/nebula/config" "github.com/slackhq/nebula/config"
"github.com/slackhq/nebula/header" "github.com/slackhq/nebula/header"
"github.com/slackhq/nebula/test" "github.com/slackhq/nebula/test"
@ -14,62 +16,51 @@ import (
"gopkg.in/yaml.v2" "gopkg.in/yaml.v2"
) )
//TODO: Add a test to ensure udpAddr is copied and not reused
func TestOldIPv4Only(t *testing.T) { func TestOldIPv4Only(t *testing.T) {
// This test ensures our new ipv6 enabled LH protobuf IpAndPorts works with the old style to enable backwards compatibility // This test ensures our new ipv6 enabled LH protobuf IpAndPorts works with the old style to enable backwards compatibility
b := []byte{8, 129, 130, 132, 80, 16, 10} b := []byte{8, 129, 130, 132, 80, 16, 10}
var m Ip4AndPort var m V4AddrPort
err := m.Unmarshal(b) err := m.Unmarshal(b)
assert.NoError(t, err) assert.NoError(t, err)
ip := netip.MustParseAddr("10.1.1.1") ip := netip.MustParseAddr("10.1.1.1")
bp := ip.As4() bp := ip.As4()
assert.Equal(t, binary.BigEndian.Uint32(bp[:]), m.GetIp()) assert.Equal(t, binary.BigEndian.Uint32(bp[:]), m.GetAddr())
}
func TestNewLhQuery(t *testing.T) {
myIp, err := netip.ParseAddr("192.1.1.1")
assert.NoError(t, err)
// Generating a new lh query should work
a := NewLhQueryByInt(myIp)
// The result should be a nebulameta protobuf
assert.IsType(t, &NebulaMeta{}, a)
// It should also Marshal fine
b, err := a.Marshal()
assert.Nil(t, err)
// and then Unmarshal fine
n := &NebulaMeta{}
err = n.Unmarshal(b)
assert.Nil(t, err)
} }
func Test_lhStaticMapping(t *testing.T) { func Test_lhStaticMapping(t *testing.T) {
l := test.NewLogger() l := test.NewLogger()
myVpnNet := netip.MustParsePrefix("10.128.0.1/16") myVpnNet := netip.MustParsePrefix("10.128.0.1/16")
nt := new(bart.Table[struct{}])
nt.Insert(myVpnNet, struct{}{})
cs := &CertState{
myVpnNetworks: []netip.Prefix{myVpnNet},
myVpnNetworksTable: nt,
}
lh1 := "10.128.0.2" lh1 := "10.128.0.2"
c := config.NewC(l) c := config.NewC(l)
c.Settings["lighthouse"] = map[interface{}]interface{}{"hosts": []interface{}{lh1}} c.Settings["lighthouse"] = map[interface{}]interface{}{"hosts": []interface{}{lh1}}
c.Settings["static_host_map"] = map[interface{}]interface{}{lh1: []interface{}{"1.1.1.1:4242"}} c.Settings["static_host_map"] = map[interface{}]interface{}{lh1: []interface{}{"1.1.1.1:4242"}}
_, err := NewLightHouseFromConfig(context.Background(), l, c, myVpnNet, nil, nil) _, err := NewLightHouseFromConfig(context.Background(), l, c, cs, nil, nil)
assert.Nil(t, err) assert.Nil(t, err)
lh2 := "10.128.0.3" lh2 := "10.128.0.3"
c = config.NewC(l) c = config.NewC(l)
c.Settings["lighthouse"] = map[interface{}]interface{}{"hosts": []interface{}{lh1, lh2}} c.Settings["lighthouse"] = map[interface{}]interface{}{"hosts": []interface{}{lh1, lh2}}
c.Settings["static_host_map"] = map[interface{}]interface{}{lh1: []interface{}{"100.1.1.1:4242"}} c.Settings["static_host_map"] = map[interface{}]interface{}{lh1: []interface{}{"100.1.1.1:4242"}}
_, err = NewLightHouseFromConfig(context.Background(), l, c, myVpnNet, nil, nil) _, err = NewLightHouseFromConfig(context.Background(), l, c, cs, nil, nil)
assert.EqualError(t, err, "lighthouse 10.128.0.3 does not have a static_host_map entry") assert.EqualError(t, err, "lighthouse 10.128.0.3 does not have a static_host_map entry")
} }
func TestReloadLighthouseInterval(t *testing.T) { func TestReloadLighthouseInterval(t *testing.T) {
l := test.NewLogger() l := test.NewLogger()
myVpnNet := netip.MustParsePrefix("10.128.0.1/16") myVpnNet := netip.MustParsePrefix("10.128.0.1/16")
nt := new(bart.Table[struct{}])
nt.Insert(myVpnNet, struct{}{})
cs := &CertState{
myVpnNetworks: []netip.Prefix{myVpnNet},
myVpnNetworksTable: nt,
}
lh1 := "10.128.0.2" lh1 := "10.128.0.2"
c := config.NewC(l) c := config.NewC(l)
@ -79,7 +70,7 @@ func TestReloadLighthouseInterval(t *testing.T) {
} }
c.Settings["static_host_map"] = map[interface{}]interface{}{lh1: []interface{}{"1.1.1.1:4242"}} c.Settings["static_host_map"] = map[interface{}]interface{}{lh1: []interface{}{"1.1.1.1:4242"}}
lh, err := NewLightHouseFromConfig(context.Background(), l, c, myVpnNet, nil, nil) lh, err := NewLightHouseFromConfig(context.Background(), l, c, cs, nil, nil)
assert.NoError(t, err) assert.NoError(t, err)
lh.ifce = &mockEncWriter{} lh.ifce = &mockEncWriter{}
@ -99,9 +90,15 @@ func TestReloadLighthouseInterval(t *testing.T) {
func BenchmarkLighthouseHandleRequest(b *testing.B) { func BenchmarkLighthouseHandleRequest(b *testing.B) {
l := test.NewLogger() l := test.NewLogger()
myVpnNet := netip.MustParsePrefix("10.128.0.1/0") myVpnNet := netip.MustParsePrefix("10.128.0.1/0")
nt := new(bart.Table[struct{}])
nt.Insert(myVpnNet, struct{}{})
cs := &CertState{
myVpnNetworks: []netip.Prefix{myVpnNet},
myVpnNetworksTable: nt,
}
c := config.NewC(l) c := config.NewC(l)
lh, err := NewLightHouseFromConfig(context.Background(), l, c, myVpnNet, nil, nil) lh, err := NewLightHouseFromConfig(context.Background(), l, c, cs, nil, nil)
if !assert.NoError(b, err) { if !assert.NoError(b, err) {
b.Fatal() b.Fatal()
} }
@ -110,46 +107,47 @@ func BenchmarkLighthouseHandleRequest(b *testing.B) {
hAddr2 := netip.MustParseAddrPort("4.5.6.7:12346") hAddr2 := netip.MustParseAddrPort("4.5.6.7:12346")
vpnIp3 := netip.MustParseAddr("0.0.0.3") vpnIp3 := netip.MustParseAddr("0.0.0.3")
lh.addrMap[vpnIp3] = NewRemoteList(nil) lh.addrMap[vpnIp3] = NewRemoteList([]netip.Addr{vpnIp3}, nil)
lh.addrMap[vpnIp3].unlockedSetV4( lh.addrMap[vpnIp3].unlockedSetV4(
vpnIp3, vpnIp3,
vpnIp3, vpnIp3,
[]*Ip4AndPort{ []*V4AddrPort{
NewIp4AndPortFromNetIP(hAddr.Addr(), hAddr.Port()), netAddrToProtoV4AddrPort(hAddr.Addr(), hAddr.Port()),
NewIp4AndPortFromNetIP(hAddr2.Addr(), hAddr2.Port()), netAddrToProtoV4AddrPort(hAddr2.Addr(), hAddr2.Port()),
}, },
func(netip.Addr, *Ip4AndPort) bool { return true }, func(netip.Addr, *V4AddrPort) bool { return true },
) )
rAddr := netip.MustParseAddrPort("1.2.2.3:12345") rAddr := netip.MustParseAddrPort("1.2.2.3:12345")
rAddr2 := netip.MustParseAddrPort("1.2.2.3:12346") rAddr2 := netip.MustParseAddrPort("1.2.2.3:12346")
vpnIp2 := netip.MustParseAddr("0.0.0.3") vpnIp2 := netip.MustParseAddr("0.0.0.3")
lh.addrMap[vpnIp2] = NewRemoteList(nil) lh.addrMap[vpnIp2] = NewRemoteList([]netip.Addr{vpnIp2}, nil)
lh.addrMap[vpnIp2].unlockedSetV4( lh.addrMap[vpnIp2].unlockedSetV4(
vpnIp3, vpnIp3,
vpnIp3, vpnIp3,
[]*Ip4AndPort{ []*V4AddrPort{
NewIp4AndPortFromNetIP(rAddr.Addr(), rAddr.Port()), netAddrToProtoV4AddrPort(rAddr.Addr(), rAddr.Port()),
NewIp4AndPortFromNetIP(rAddr2.Addr(), rAddr2.Port()), netAddrToProtoV4AddrPort(rAddr2.Addr(), rAddr2.Port()),
}, },
func(netip.Addr, *Ip4AndPort) bool { return true }, func(netip.Addr, *V4AddrPort) bool { return true },
) )
mw := &mockEncWriter{} mw := &mockEncWriter{}
hi := []netip.Addr{vpnIp2}
b.Run("notfound", func(b *testing.B) { b.Run("notfound", func(b *testing.B) {
lhh := lh.NewRequestHandler() lhh := lh.NewRequestHandler()
req := &NebulaMeta{ req := &NebulaMeta{
Type: NebulaMeta_HostQuery, Type: NebulaMeta_HostQuery,
Details: &NebulaMetaDetails{ Details: &NebulaMetaDetails{
VpnIp: 4, OldVpnAddr: 4,
Ip4AndPorts: nil, V4AddrPorts: nil,
}, },
} }
p, err := req.Marshal() p, err := req.Marshal()
assert.NoError(b, err) assert.NoError(b, err)
for n := 0; n < b.N; n++ { for n := 0; n < b.N; n++ {
lhh.HandleRequest(rAddr, vpnIp2, p, mw) lhh.HandleRequest(rAddr, hi, p, mw)
} }
}) })
b.Run("found", func(b *testing.B) { b.Run("found", func(b *testing.B) {
@ -157,15 +155,15 @@ func BenchmarkLighthouseHandleRequest(b *testing.B) {
req := &NebulaMeta{ req := &NebulaMeta{
Type: NebulaMeta_HostQuery, Type: NebulaMeta_HostQuery,
Details: &NebulaMetaDetails{ Details: &NebulaMetaDetails{
VpnIp: 3, OldVpnAddr: 3,
Ip4AndPorts: nil, V4AddrPorts: nil,
}, },
} }
p, err := req.Marshal() p, err := req.Marshal()
assert.NoError(b, err) assert.NoError(b, err)
for n := 0; n < b.N; n++ { for n := 0; n < b.N; n++ {
lhh.HandleRequest(rAddr, vpnIp2, p, mw) lhh.HandleRequest(rAddr, hi, p, mw)
} }
}) })
} }
@ -197,40 +195,49 @@ func TestLighthouse_Memory(t *testing.T) {
c := config.NewC(l) c := config.NewC(l)
c.Settings["lighthouse"] = map[interface{}]interface{}{"am_lighthouse": true} c.Settings["lighthouse"] = map[interface{}]interface{}{"am_lighthouse": true}
c.Settings["listen"] = map[interface{}]interface{}{"port": 4242} c.Settings["listen"] = map[interface{}]interface{}{"port": 4242}
lh, err := NewLightHouseFromConfig(context.Background(), l, c, netip.MustParsePrefix("10.128.0.1/24"), nil, nil)
myVpnNet := netip.MustParsePrefix("10.128.0.1/24")
nt := new(bart.Table[struct{}])
nt.Insert(myVpnNet, struct{}{})
cs := &CertState{
myVpnNetworks: []netip.Prefix{myVpnNet},
myVpnNetworksTable: nt,
}
lh, err := NewLightHouseFromConfig(context.Background(), l, c, cs, nil, nil)
lh.ifce = &mockEncWriter{}
assert.NoError(t, err) assert.NoError(t, err)
lhh := lh.NewRequestHandler() lhh := lh.NewRequestHandler()
// Test that my first update responds with just that // Test that my first update responds with just that
newLHHostUpdate(myUdpAddr0, myVpnIp, []netip.AddrPort{myUdpAddr1, myUdpAddr2}, lhh) newLHHostUpdate(myUdpAddr0, myVpnIp, []netip.AddrPort{myUdpAddr1, myUdpAddr2}, lhh)
r := newLHHostRequest(myUdpAddr0, myVpnIp, myVpnIp, lhh) r := newLHHostRequest(myUdpAddr0, myVpnIp, myVpnIp, lhh)
assertIp4InArray(t, r.msg.Details.Ip4AndPorts, myUdpAddr1, myUdpAddr2) assertIp4InArray(t, r.msg.Details.V4AddrPorts, myUdpAddr1, myUdpAddr2)
// Ensure we don't accumulate addresses // Ensure we don't accumulate addresses
newLHHostUpdate(myUdpAddr0, myVpnIp, []netip.AddrPort{myUdpAddr3}, lhh) newLHHostUpdate(myUdpAddr0, myVpnIp, []netip.AddrPort{myUdpAddr3}, lhh)
r = newLHHostRequest(myUdpAddr0, myVpnIp, myVpnIp, lhh) r = newLHHostRequest(myUdpAddr0, myVpnIp, myVpnIp, lhh)
assertIp4InArray(t, r.msg.Details.Ip4AndPorts, myUdpAddr3) assertIp4InArray(t, r.msg.Details.V4AddrPorts, myUdpAddr3)
// Grow it back to 2 // Grow it back to 2
newLHHostUpdate(myUdpAddr0, myVpnIp, []netip.AddrPort{myUdpAddr1, myUdpAddr4}, lhh) newLHHostUpdate(myUdpAddr0, myVpnIp, []netip.AddrPort{myUdpAddr1, myUdpAddr4}, lhh)
r = newLHHostRequest(myUdpAddr0, myVpnIp, myVpnIp, lhh) r = newLHHostRequest(myUdpAddr0, myVpnIp, myVpnIp, lhh)
assertIp4InArray(t, r.msg.Details.Ip4AndPorts, myUdpAddr1, myUdpAddr4) assertIp4InArray(t, r.msg.Details.V4AddrPorts, myUdpAddr1, myUdpAddr4)
// Update a different host and ask about it // Update a different host and ask about it
newLHHostUpdate(theirUdpAddr0, theirVpnIp, []netip.AddrPort{theirUdpAddr1, theirUdpAddr2, theirUdpAddr3, theirUdpAddr4}, lhh) newLHHostUpdate(theirUdpAddr0, theirVpnIp, []netip.AddrPort{theirUdpAddr1, theirUdpAddr2, theirUdpAddr3, theirUdpAddr4}, lhh)
r = newLHHostRequest(theirUdpAddr0, theirVpnIp, theirVpnIp, lhh) r = newLHHostRequest(theirUdpAddr0, theirVpnIp, theirVpnIp, lhh)
assertIp4InArray(t, r.msg.Details.Ip4AndPorts, theirUdpAddr1, theirUdpAddr2, theirUdpAddr3, theirUdpAddr4) assertIp4InArray(t, r.msg.Details.V4AddrPorts, theirUdpAddr1, theirUdpAddr2, theirUdpAddr3, theirUdpAddr4)
// Have both hosts ask about the other // Have both hosts ask about the other
r = newLHHostRequest(theirUdpAddr0, theirVpnIp, myVpnIp, lhh) r = newLHHostRequest(theirUdpAddr0, theirVpnIp, myVpnIp, lhh)
assertIp4InArray(t, r.msg.Details.Ip4AndPorts, myUdpAddr1, myUdpAddr4) assertIp4InArray(t, r.msg.Details.V4AddrPorts, myUdpAddr1, myUdpAddr4)
r = newLHHostRequest(myUdpAddr0, myVpnIp, theirVpnIp, lhh) r = newLHHostRequest(myUdpAddr0, myVpnIp, theirVpnIp, lhh)
assertIp4InArray(t, r.msg.Details.Ip4AndPorts, theirUdpAddr1, theirUdpAddr2, theirUdpAddr3, theirUdpAddr4) assertIp4InArray(t, r.msg.Details.V4AddrPorts, theirUdpAddr1, theirUdpAddr2, theirUdpAddr3, theirUdpAddr4)
// Make sure we didn't get changed // Make sure we didn't get changed
r = newLHHostRequest(myUdpAddr0, myVpnIp, myVpnIp, lhh) r = newLHHostRequest(myUdpAddr0, myVpnIp, myVpnIp, lhh)
assertIp4InArray(t, r.msg.Details.Ip4AndPorts, myUdpAddr1, myUdpAddr4) assertIp4InArray(t, r.msg.Details.V4AddrPorts, myUdpAddr1, myUdpAddr4)
// Ensure proper ordering and limiting // Ensure proper ordering and limiting
// Send 12 addrs, get 10 back, the last 2 removed, allowing the duplicate to remain (clients dedupe) // Send 12 addrs, get 10 back, the last 2 removed, allowing the duplicate to remain (clients dedupe)
@ -255,7 +262,7 @@ func TestLighthouse_Memory(t *testing.T) {
r = newLHHostRequest(myUdpAddr0, myVpnIp, myVpnIp, lhh) r = newLHHostRequest(myUdpAddr0, myVpnIp, myVpnIp, lhh)
assertIp4InArray( assertIp4InArray(
t, t,
r.msg.Details.Ip4AndPorts, r.msg.Details.V4AddrPorts,
myUdpAddr1, myUdpAddr2, myUdpAddr3, myUdpAddr4, myUdpAddr5, myUdpAddr5, myUdpAddr6, myUdpAddr7, myUdpAddr8, myUdpAddr9, myUdpAddr1, myUdpAddr2, myUdpAddr3, myUdpAddr4, myUdpAddr5, myUdpAddr5, myUdpAddr6, myUdpAddr7, myUdpAddr8, myUdpAddr9,
) )
@ -265,7 +272,7 @@ func TestLighthouse_Memory(t *testing.T) {
good := netip.MustParseAddrPort("1.128.0.99:4242") good := netip.MustParseAddrPort("1.128.0.99:4242")
newLHHostUpdate(myUdpAddr0, myVpnIp, []netip.AddrPort{bad1, bad2, good}, lhh) newLHHostUpdate(myUdpAddr0, myVpnIp, []netip.AddrPort{bad1, bad2, good}, lhh)
r = newLHHostRequest(myUdpAddr0, myVpnIp, myVpnIp, lhh) r = newLHHostRequest(myUdpAddr0, myVpnIp, myVpnIp, lhh)
assertIp4InArray(t, r.msg.Details.Ip4AndPorts, good) assertIp4InArray(t, r.msg.Details.V4AddrPorts, good)
} }
func TestLighthouse_reload(t *testing.T) { func TestLighthouse_reload(t *testing.T) {
@ -273,7 +280,16 @@ func TestLighthouse_reload(t *testing.T) {
c := config.NewC(l) c := config.NewC(l)
c.Settings["lighthouse"] = map[interface{}]interface{}{"am_lighthouse": true} c.Settings["lighthouse"] = map[interface{}]interface{}{"am_lighthouse": true}
c.Settings["listen"] = map[interface{}]interface{}{"port": 4242} c.Settings["listen"] = map[interface{}]interface{}{"port": 4242}
lh, err := NewLightHouseFromConfig(context.Background(), l, c, netip.MustParsePrefix("10.128.0.1/24"), nil, nil)
myVpnNet := netip.MustParsePrefix("10.128.0.1/24")
nt := new(bart.Table[struct{}])
nt.Insert(myVpnNet, struct{}{})
cs := &CertState{
myVpnNetworks: []netip.Prefix{myVpnNet},
myVpnNetworksTable: nt,
}
lh, err := NewLightHouseFromConfig(context.Background(), l, c, cs, nil, nil)
assert.NoError(t, err) assert.NoError(t, err)
nc := map[interface{}]interface{}{ nc := map[interface{}]interface{}{
@ -290,13 +306,16 @@ func TestLighthouse_reload(t *testing.T) {
} }
func newLHHostRequest(fromAddr netip.AddrPort, myVpnIp, queryVpnIp netip.Addr, lhh *LightHouseHandler) testLhReply { func newLHHostRequest(fromAddr netip.AddrPort, myVpnIp, queryVpnIp netip.Addr, lhh *LightHouseHandler) testLhReply {
//TODO: IPV6-WORK
bip := queryVpnIp.As4()
req := &NebulaMeta{ req := &NebulaMeta{
Type: NebulaMeta_HostQuery, Type: NebulaMeta_HostQuery,
Details: &NebulaMetaDetails{ Details: &NebulaMetaDetails{},
VpnIp: binary.BigEndian.Uint32(bip[:]), }
},
if queryVpnIp.Is4() {
bip := queryVpnIp.As4()
req.Details.OldVpnAddr = binary.BigEndian.Uint32(bip[:])
} else {
req.Details.VpnAddr = netAddrToProtoAddr(queryVpnIp)
} }
b, err := req.Marshal() b, err := req.Marshal()
@ -308,23 +327,29 @@ func newLHHostRequest(fromAddr netip.AddrPort, myVpnIp, queryVpnIp netip.Addr, l
w := &testEncWriter{ w := &testEncWriter{
metaFilter: &filter, metaFilter: &filter,
} }
lhh.HandleRequest(fromAddr, myVpnIp, b, w) lhh.HandleRequest(fromAddr, []netip.Addr{myVpnIp}, b, w)
return w.lastReply return w.lastReply
} }
func newLHHostUpdate(fromAddr netip.AddrPort, vpnIp netip.Addr, addrs []netip.AddrPort, lhh *LightHouseHandler) { func newLHHostUpdate(fromAddr netip.AddrPort, vpnIp netip.Addr, addrs []netip.AddrPort, lhh *LightHouseHandler) {
//TODO: IPV6-WORK
bip := vpnIp.As4()
req := &NebulaMeta{ req := &NebulaMeta{
Type: NebulaMeta_HostUpdateNotification, Type: NebulaMeta_HostUpdateNotification,
Details: &NebulaMetaDetails{ Details: &NebulaMetaDetails{},
VpnIp: binary.BigEndian.Uint32(bip[:]),
Ip4AndPorts: make([]*Ip4AndPort, len(addrs)),
},
} }
for k, v := range addrs { if vpnIp.Is4() {
req.Details.Ip4AndPorts[k] = NewIp4AndPortFromNetIP(v.Addr(), v.Port()) bip := vpnIp.As4()
req.Details.OldVpnAddr = binary.BigEndian.Uint32(bip[:])
} else {
req.Details.VpnAddr = netAddrToProtoAddr(vpnIp)
}
for _, v := range addrs {
if v.Addr().Is4() {
req.Details.V4AddrPorts = append(req.Details.V4AddrPorts, netAddrToProtoV4AddrPort(v.Addr(), v.Port()))
} else {
req.Details.V6AddrPorts = append(req.Details.V6AddrPorts, netAddrToProtoV6AddrPort(v.Addr(), v.Port()))
}
} }
b, err := req.Marshal() b, err := req.Marshal()
@ -333,75 +358,9 @@ func newLHHostUpdate(fromAddr netip.AddrPort, vpnIp netip.Addr, addrs []netip.Ad
} }
w := &testEncWriter{} w := &testEncWriter{}
lhh.HandleRequest(fromAddr, vpnIp, b, w) lhh.HandleRequest(fromAddr, []netip.Addr{vpnIp}, b, w)
} }
//TODO: this is a RemoteList test
//func Test_lhRemoteAllowList(t *testing.T) {
// l := NewLogger()
// c := NewConfig(l)
// c.Settings["remoteallowlist"] = map[interface{}]interface{}{
// "10.20.0.0/12": false,
// }
// allowList, err := c.GetAllowList("remoteallowlist", false)
// assert.Nil(t, err)
//
// lh1 := "10.128.0.2"
// lh1IP := net.ParseIP(lh1)
//
// udpServer, _ := NewListener(l, "0.0.0.0", 0, true)
//
// lh := NewLightHouse(l, true, &net.IPNet{IP: net.IP{0, 0, 0, 1}, Mask: net.IPMask{255, 255, 255, 0}}, []uint32{ip2int(lh1IP)}, 10, 10003, udpServer, false, 1, false)
// lh.SetRemoteAllowList(allowList)
//
// // A disallowed ip should not enter the cache but we should end up with an empty entry in the addrMap
// remote1IP := net.ParseIP("10.20.0.3")
// remotes := lh.unlockedGetRemoteList(ip2int(remote1IP))
// remotes.unlockedPrependV4(ip2int(remote1IP), NewIp4AndPort(remote1IP, 4242))
// assert.NotNil(t, lh.addrMap[ip2int(remote1IP)])
// assert.Empty(t, lh.addrMap[ip2int(remote1IP)].CopyAddrs([]*net.IPNet{}))
//
// // Make sure a good ip enters the cache and addrMap
// remote2IP := net.ParseIP("10.128.0.3")
// remote2UDPAddr := NewUDPAddr(remote2IP, uint16(4242))
// lh.addRemoteV4(ip2int(remote2IP), ip2int(remote2IP), NewIp4AndPort(remote2UDPAddr.IP, uint32(remote2UDPAddr.Port)), false, false)
// assertUdpAddrInArray(t, lh.addrMap[ip2int(remote2IP)].CopyAddrs([]*net.IPNet{}), remote2UDPAddr)
//
// // Another good ip gets into the cache, ordering is inverted
// remote3IP := net.ParseIP("10.128.0.4")
// remote3UDPAddr := NewUDPAddr(remote3IP, uint16(4243))
// lh.addRemoteV4(ip2int(remote2IP), ip2int(remote2IP), NewIp4AndPort(remote3UDPAddr.IP, uint32(remote3UDPAddr.Port)), false, false)
// assertUdpAddrInArray(t, lh.addrMap[ip2int(remote2IP)].CopyAddrs([]*net.IPNet{}), remote2UDPAddr, remote3UDPAddr)
//
// // If we exceed the length limit we should only have the most recent addresses
// addedAddrs := []*udpAddr{}
// for i := 0; i < 11; i++ {
// remoteUDPAddr := NewUDPAddr(net.IP{10, 128, 0, 4}, uint16(4243+i))
// lh.addRemoteV4(ip2int(remote2IP), ip2int(remote2IP), NewIp4AndPort(remoteUDPAddr.IP, uint32(remoteUDPAddr.Port)), false, false)
// // The first entry here is a duplicate, don't add it to the assert list
// if i != 0 {
// addedAddrs = append(addedAddrs, remoteUDPAddr)
// }
// }
//
// // We should only have the last 10 of what we tried to add
// assert.True(t, len(addedAddrs) >= 10, "We should have tried to add at least 10 addresses")
// assertUdpAddrInArray(
// t,
// lh.addrMap[ip2int(remote2IP)].CopyAddrs([]*net.IPNet{}),
// addedAddrs[0],
// addedAddrs[1],
// addedAddrs[2],
// addedAddrs[3],
// addedAddrs[4],
// addedAddrs[5],
// addedAddrs[6],
// addedAddrs[7],
// addedAddrs[8],
// addedAddrs[9],
// )
//}
type testLhReply struct { type testLhReply struct {
nebType header.MessageType nebType header.MessageType
nebSubType header.MessageSubType nebSubType header.MessageSubType
@ -410,8 +369,9 @@ type testLhReply struct {
} }
type testEncWriter struct { type testEncWriter struct {
lastReply testLhReply lastReply testLhReply
metaFilter *NebulaMeta_MessageType metaFilter *NebulaMeta_MessageType
protocolVersion cert.Version
} }
func (tw *testEncWriter) SendVia(via *HostInfo, relay *Relay, ad, nb, out []byte, nocopy bool) { func (tw *testEncWriter) SendVia(via *HostInfo, relay *Relay, ad, nb, out []byte, nocopy bool) {
@ -426,7 +386,7 @@ func (tw *testEncWriter) SendMessageToHostInfo(t header.MessageType, st header.M
tw.lastReply = testLhReply{ tw.lastReply = testLhReply{
nebType: t, nebType: t,
nebSubType: st, nebSubType: st,
vpnIp: hostinfo.vpnIp, vpnIp: hostinfo.vpnAddrs[0],
msg: msg, msg: msg,
} }
} }
@ -436,7 +396,7 @@ func (tw *testEncWriter) SendMessageToHostInfo(t header.MessageType, st header.M
} }
} }
func (tw *testEncWriter) SendMessageToVpnIp(t header.MessageType, st header.MessageSubType, vpnIp netip.Addr, p, _, _ []byte) { func (tw *testEncWriter) SendMessageToVpnAddr(t header.MessageType, st header.MessageSubType, vpnIp netip.Addr, p, _, _ []byte) {
msg := &NebulaMeta{} msg := &NebulaMeta{}
err := msg.Unmarshal(p) err := msg.Unmarshal(p)
if tw.metaFilter == nil || msg.Type == *tw.metaFilter { if tw.metaFilter == nil || msg.Type == *tw.metaFilter {
@ -453,17 +413,84 @@ func (tw *testEncWriter) SendMessageToVpnIp(t header.MessageType, st header.Mess
} }
} }
func (tw *testEncWriter) GetHostInfo(vpnIp netip.Addr) *HostInfo {
return nil
}
func (tw *testEncWriter) GetCertState() *CertState {
return &CertState{defaultVersion: tw.protocolVersion}
}
// assertIp4InArray asserts every address in want is at the same position in have and that the lengths match // assertIp4InArray asserts every address in want is at the same position in have and that the lengths match
func assertIp4InArray(t *testing.T, have []*Ip4AndPort, want ...netip.AddrPort) { func assertIp4InArray(t *testing.T, have []*V4AddrPort, want ...netip.AddrPort) {
if !assert.Len(t, have, len(want)) { if !assert.Len(t, have, len(want)) {
return return
} }
for k, w := range want { for k, w := range want {
//TODO: IPV6-WORK h := protoV4AddrPortToNetAddrPort(have[k])
h := AddrPortFromIp4AndPort(have[k])
if !(h == w) { if !(h == w) {
assert.Fail(t, fmt.Sprintf("Response did not contain: %v at %v, found %v", w, k, h)) assert.Fail(t, fmt.Sprintf("Response did not contain: %v at %v, found %v", w, k, h))
} }
} }
} }
func Test_findNetworkUnion(t *testing.T) {
var out netip.Addr
var ok bool
tenDot := netip.MustParsePrefix("10.0.0.0/8")
oneSevenTwo := netip.MustParsePrefix("172.16.0.0/16")
fe80 := netip.MustParsePrefix("fe80::/8")
fc00 := netip.MustParsePrefix("fc00::/7")
a1 := netip.MustParseAddr("10.0.0.1")
afe81 := netip.MustParseAddr("fe80::1")
//simple
out, ok = findNetworkUnion([]netip.Prefix{tenDot}, []netip.Addr{a1})
assert.True(t, ok)
assert.Equal(t, out, a1)
//mixed lengths
out, ok = findNetworkUnion([]netip.Prefix{tenDot}, []netip.Addr{a1, afe81})
assert.True(t, ok)
assert.Equal(t, out, a1)
out, ok = findNetworkUnion([]netip.Prefix{tenDot, oneSevenTwo}, []netip.Addr{a1})
assert.True(t, ok)
assert.Equal(t, out, a1)
//mixed family
out, ok = findNetworkUnion([]netip.Prefix{tenDot, oneSevenTwo, fe80}, []netip.Addr{a1})
assert.True(t, ok)
assert.Equal(t, out, a1)
out, ok = findNetworkUnion([]netip.Prefix{tenDot, oneSevenTwo, fe80}, []netip.Addr{a1, afe81})
assert.True(t, ok)
assert.Equal(t, out, a1)
//ordering
out, ok = findNetworkUnion([]netip.Prefix{tenDot, oneSevenTwo, fe80}, []netip.Addr{afe81, a1})
assert.True(t, ok)
assert.Equal(t, out, a1)
out, ok = findNetworkUnion([]netip.Prefix{fe80, tenDot, oneSevenTwo}, []netip.Addr{afe81, a1})
assert.True(t, ok)
assert.Equal(t, out, afe81)
//some mismatches
out, ok = findNetworkUnion([]netip.Prefix{tenDot, oneSevenTwo, fe80}, []netip.Addr{afe81})
assert.True(t, ok)
assert.Equal(t, out, afe81)
out, ok = findNetworkUnion([]netip.Prefix{oneSevenTwo, fe80}, []netip.Addr{a1, afe81})
assert.True(t, ok)
assert.Equal(t, out, afe81)
//falsey cases
out, ok = findNetworkUnion([]netip.Prefix{oneSevenTwo, fe80}, []netip.Addr{a1})
assert.False(t, ok)
out, ok = findNetworkUnion([]netip.Prefix{fc00, fe80}, []netip.Addr{a1})
assert.False(t, ok)
out, ok = findNetworkUnion([]netip.Prefix{oneSevenTwo, fc00}, []netip.Addr{a1, afe81})
assert.False(t, ok)
out, ok = findNetworkUnion([]netip.Prefix{fc00}, []netip.Addr{a1, afe81})
assert.False(t, ok)
}

29
main.go
View File

@ -2,7 +2,6 @@ package nebula
import ( import (
"context" "context"
"encoding/binary"
"fmt" "fmt"
"net" "net"
"net/netip" "net/netip"
@ -61,15 +60,12 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
return nil, util.ContextualizeIfNeeded("Failed to load PKI from config", err) return nil, util.ContextualizeIfNeeded("Failed to load PKI from config", err)
} }
certificate := pki.GetCertState().Certificate fw, err := NewFirewallFromConfig(l, pki.getCertState(), c)
fw, err := NewFirewallFromConfig(l, certificate, c)
if err != nil { if err != nil {
return nil, util.ContextualizeIfNeeded("Error while loading firewall rules", err) return nil, util.ContextualizeIfNeeded("Error while loading firewall rules", err)
} }
l.WithField("firewallHashes", fw.GetRuleHashes()).Info("Firewall started") l.WithField("firewallHashes", fw.GetRuleHashes()).Info("Firewall started")
tunCidr := certificate.Networks()[0]
ssh, err := sshd.NewSSHServer(l.WithField("subsystem", "sshd")) ssh, err := sshd.NewSSHServer(l.WithField("subsystem", "sshd"))
if err != nil { if err != nil {
return nil, util.ContextualizeIfNeeded("Error while creating SSH server", err) return nil, util.ContextualizeIfNeeded("Error while creating SSH server", err)
@ -132,7 +128,7 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
deviceFactory = overlay.NewDeviceFromConfig deviceFactory = overlay.NewDeviceFromConfig
} }
tun, err = deviceFactory(c, l, tunCidr, routines) tun, err = deviceFactory(c, l, pki.getCertState().myVpnNetworks, routines)
if err != nil { if err != nil {
return nil, util.ContextualizeIfNeeded("Failed to get a tun/tap device", err) return nil, util.ContextualizeIfNeeded("Failed to get a tun/tap device", err)
} }
@ -187,9 +183,9 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
} }
} }
hostMap := NewHostMapFromConfig(l, tunCidr, c) hostMap := NewHostMapFromConfig(l, c)
punchy := NewPunchyFromConfig(l, c) punchy := NewPunchyFromConfig(l, c)
lightHouse, err := NewLightHouseFromConfig(ctx, l, c, tunCidr, udpConns[0], punchy) lightHouse, err := NewLightHouseFromConfig(ctx, l, c, pki.getCertState(), udpConns[0], punchy)
if err != nil { if err != nil {
return nil, util.ContextualizeIfNeeded("Failed to initialize lighthouse handler", err) return nil, util.ContextualizeIfNeeded("Failed to initialize lighthouse handler", err)
} }
@ -232,7 +228,6 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
Inside: tun, Inside: tun,
Outside: udpConns[0], Outside: udpConns[0],
pki: pki, pki: pki,
Cipher: c.GetString("cipher", "aes"),
Firewall: fw, Firewall: fw,
ServeDns: serveDns, ServeDns: serveDns,
HandshakeManager: handshakeManager, HandshakeManager: handshakeManager,
@ -254,15 +249,6 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
l: l, l: l,
} }
switch ifConfig.Cipher {
case "aes":
noiseEndianness = binary.BigEndian
case "chachapoly":
noiseEndianness = binary.LittleEndian
default:
return nil, fmt.Errorf("unknown cipher: %v", ifConfig.Cipher)
}
var ifce *Interface var ifce *Interface
if !configTest { if !configTest {
ifce, err = NewInterface(ctx, ifConfig) ifce, err = NewInterface(ctx, ifConfig)
@ -270,8 +256,6 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
return nil, fmt.Errorf("failed to initialize interface: %s", err) return nil, fmt.Errorf("failed to initialize interface: %s", err)
} }
// TODO: Better way to attach these, probably want a new interface in InterfaceConfig
// I don't want to make this initial commit too far-reaching though
ifce.writers = udpConns ifce.writers = udpConns
lightHouse.ifce = ifce lightHouse.ifce = ifce
@ -283,8 +267,6 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
go handshakeManager.Run(ctx) go handshakeManager.Run(ctx)
} }
// TODO - stats third-party modules start uncancellable goroutines. Update those libs to accept
// a context so that they can exit when the context is Done.
statsStart, err := startStats(l, c, buildVersion, configTest) statsStart, err := startStats(l, c, buildVersion, configTest)
if err != nil { if err != nil {
return nil, util.ContextualizeIfNeeded("Failed to start stats emitter", err) return nil, util.ContextualizeIfNeeded("Failed to start stats emitter", err)
@ -294,7 +276,6 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
return nil, nil return nil, nil
} }
//TODO: check if we _should_ be emitting stats
go ifce.emitStats(ctx, c.GetDuration("stats.interval", time.Second*10)) go ifce.emitStats(ctx, c.GetDuration("stats.interval", time.Second*10))
attachCommands(l, c, ssh, ifce) attachCommands(l, c, ssh, ifce)
@ -303,7 +284,7 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
var dnsStart func() var dnsStart func()
if lightHouse.amLighthouse && serveDns { if lightHouse.amLighthouse && serveDns {
l.Debugln("Starting dns server") l.Debugln("Starting dns server")
dnsStart = dnsMain(l, hostMap, c) dnsStart = dnsMain(l, pki.getCertState(), hostMap, c)
} }
return &Control{ return &Control{

View File

@ -7,8 +7,6 @@ import (
"github.com/slackhq/nebula/header" "github.com/slackhq/nebula/header"
) )
//TODO: this can probably move into the header package
type MessageMetrics struct { type MessageMetrics struct {
rx [][]metrics.Counter rx [][]metrics.Counter
tx [][]metrics.Counter tx [][]metrics.Counter

File diff suppressed because it is too large Load Diff

View File

@ -23,19 +23,28 @@ message NebulaMeta {
} }
message NebulaMetaDetails { message NebulaMetaDetails {
uint32 VpnIp = 1; uint32 OldVpnAddr = 1 [deprecated = true];
repeated Ip4AndPort Ip4AndPorts = 2; Addr VpnAddr = 6;
repeated Ip6AndPort Ip6AndPorts = 4;
repeated uint32 RelayVpnIp = 5; repeated uint32 OldRelayVpnAddrs = 5 [deprecated = true];
repeated Addr RelayVpnAddrs = 7;
repeated V4AddrPort V4AddrPorts = 2;
repeated V6AddrPort V6AddrPorts = 4;
uint32 counter = 3; uint32 counter = 3;
} }
message Ip4AndPort { message Addr {
uint32 Ip = 1; uint64 Hi = 1;
uint64 Lo = 2;
}
message V4AddrPort {
uint32 Addr = 1;
uint32 Port = 2; uint32 Port = 2;
} }
message Ip6AndPort { message V6AddrPort {
uint64 Hi = 1; uint64 Hi = 1;
uint64 Lo = 2; uint64 Lo = 2;
uint32 Port = 3; uint32 Port = 3;
@ -62,6 +71,7 @@ message NebulaHandshakeDetails {
uint32 ResponderIndex = 3; uint32 ResponderIndex = 3;
uint64 Cookie = 4; uint64 Cookie = 4;
uint64 Time = 5; uint64 Time = 5;
uint32 CertVersion = 8;
// reserved for WIP multiport // reserved for WIP multiport
reserved 6, 7; reserved 6, 7;
} }
@ -76,6 +86,10 @@ message NebulaControl {
uint32 InitiatorRelayIndex = 2; uint32 InitiatorRelayIndex = 2;
uint32 ResponderRelayIndex = 3; uint32 ResponderRelayIndex = 3;
uint32 RelayToIp = 4;
uint32 RelayFromIp = 5; uint32 OldRelayToAddr = 4 [deprecated = true];
uint32 OldRelayFromAddr = 5 [deprecated = true];
Addr RelayToAddr = 6;
Addr RelayFromAddr = 7;
} }

View File

@ -3,16 +3,15 @@ package nebula
import ( import (
"encoding/binary" "encoding/binary"
"errors" "errors"
"fmt"
"net/netip" "net/netip"
"time" "time"
"github.com/flynn/noise" "github.com/google/gopacket/layers"
"golang.org/x/net/ipv6"
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
"github.com/slackhq/nebula/cert"
"github.com/slackhq/nebula/firewall" "github.com/slackhq/nebula/firewall"
"github.com/slackhq/nebula/header" "github.com/slackhq/nebula/header"
"github.com/slackhq/nebula/udp"
"golang.org/x/net/ipv4" "golang.org/x/net/ipv4"
) )
@ -20,28 +19,9 @@ const (
minFwPacketLen = 4 minFwPacketLen = 4
) )
// TODO: IPV6-WORK this can likely be removed now func (f *Interface) readOutsidePackets(ip netip.AddrPort, via *ViaSender, out []byte, packet []byte, h *header.H, fwPacket *firewall.Packet, lhf *LightHouseHandler, nb []byte, q int, localCache firewall.ConntrackCache) {
func readOutsidePackets(f *Interface) udp.EncReader {
return func(
addr netip.AddrPort,
out []byte,
packet []byte,
header *header.H,
fwPacket *firewall.Packet,
lhh udp.LightHouseHandlerFunc,
nb []byte,
q int,
localCache firewall.ConntrackCache,
) {
f.readOutsidePackets(addr, nil, out, packet, header, fwPacket, lhh, nb, q, localCache)
}
}
func (f *Interface) readOutsidePackets(ip netip.AddrPort, via *ViaSender, out []byte, packet []byte, h *header.H, fwPacket *firewall.Packet, lhf udp.LightHouseHandlerFunc, nb []byte, q int, localCache firewall.ConntrackCache) {
err := h.Parse(packet) err := h.Parse(packet)
if err != nil { if err != nil {
// TODO: best if we return this and let caller log
// TODO: Might be better to send the literal []byte("holepunch") packet and ignore that?
// Hole punch packets are 0 or 1 byte big, so lets ignore printing those errors // Hole punch packets are 0 or 1 byte big, so lets ignore printing those errors
if len(packet) > 1 { if len(packet) > 1 {
f.l.WithField("packet", packet).Infof("Error while parsing inbound packet from %s: %s", ip, err) f.l.WithField("packet", packet).Infof("Error while parsing inbound packet from %s: %s", ip, err)
@ -51,7 +31,8 @@ func (f *Interface) readOutsidePackets(ip netip.AddrPort, via *ViaSender, out []
//l.Error("in packet ", header, packet[HeaderLen:]) //l.Error("in packet ", header, packet[HeaderLen:])
if ip.IsValid() { if ip.IsValid() {
if f.myVpnNet.Contains(ip.Addr()) { _, found := f.myVpnNetworksTable.Lookup(ip.Addr())
if found {
if f.l.Level >= logrus.DebugLevel { if f.l.Level >= logrus.DebugLevel {
f.l.WithField("udpAddr", ip).Debug("Refusing to process double encrypted packet") f.l.WithField("udpAddr", ip).Debug("Refusing to process double encrypted packet")
} }
@ -108,7 +89,7 @@ func (f *Interface) readOutsidePackets(ip netip.AddrPort, via *ViaSender, out []
if !ok { if !ok {
// The only way this happens is if hostmap has an index to the correct HostInfo, but the HostInfo is missing // The only way this happens is if hostmap has an index to the correct HostInfo, but the HostInfo is missing
// its internal mapping. This should never happen. // its internal mapping. This should never happen.
hostinfo.logger(f.l).WithFields(logrus.Fields{"vpnIp": hostinfo.vpnIp, "remoteIndex": h.RemoteIndex}).Error("HostInfo missing remote relay index") hostinfo.logger(f.l).WithFields(logrus.Fields{"vpnAddrs": hostinfo.vpnAddrs, "remoteIndex": h.RemoteIndex}).Error("HostInfo missing remote relay index")
return return
} }
@ -120,9 +101,9 @@ func (f *Interface) readOutsidePackets(ip netip.AddrPort, via *ViaSender, out []
return return
case ForwardingType: case ForwardingType:
// Find the target HostInfo relay object // Find the target HostInfo relay object
targetHI, targetRelay, err := f.hostMap.QueryVpnIpRelayFor(hostinfo.vpnIp, relay.PeerIp) targetHI, targetRelay, err := f.hostMap.QueryVpnAddrsRelayFor(hostinfo.vpnAddrs, relay.PeerAddr)
if err != nil { if err != nil {
hostinfo.logger(f.l).WithField("relayTo", relay.PeerIp).WithError(err).Info("Failed to find target host info by ip") hostinfo.logger(f.l).WithField("relayTo", relay.PeerAddr).WithError(err).WithField("hostinfo.vpnAddrs", hostinfo.vpnAddrs).Info("Failed to find target host info by ip")
return return
} }
@ -138,7 +119,7 @@ func (f *Interface) readOutsidePackets(ip netip.AddrPort, via *ViaSender, out []
hostinfo.logger(f.l).Error("Unexpected Relay Type of Terminal") hostinfo.logger(f.l).Error("Unexpected Relay Type of Terminal")
} }
} else { } else {
hostinfo.logger(f.l).WithFields(logrus.Fields{"relayTo": relay.PeerIp, "relayFrom": hostinfo.vpnIp, "targetRelayState": targetRelay.State}).Info("Unexpected target relay state") hostinfo.logger(f.l).WithFields(logrus.Fields{"relayTo": relay.PeerAddr, "relayFrom": hostinfo.vpnAddrs[0], "targetRelayState": targetRelay.State}).Info("Unexpected target relay state")
return return
} }
} }
@ -155,13 +136,10 @@ func (f *Interface) readOutsidePackets(ip netip.AddrPort, via *ViaSender, out []
hostinfo.logger(f.l).WithError(err).WithField("udpAddr", ip). hostinfo.logger(f.l).WithError(err).WithField("udpAddr", ip).
WithField("packet", packet). WithField("packet", packet).
Error("Failed to decrypt lighthouse packet") Error("Failed to decrypt lighthouse packet")
//TODO: maybe after build 64 is out? 06/14/2018 - NB
//f.sendRecvError(net.Addr(addr), header.RemoteIndex)
return return
} }
lhf(ip, hostinfo.vpnIp, d) lhf.HandleRequest(ip, hostinfo.vpnAddrs, d, f)
// Fallthrough to the bottom to record incoming traffic // Fallthrough to the bottom to record incoming traffic
@ -176,9 +154,6 @@ func (f *Interface) readOutsidePackets(ip netip.AddrPort, via *ViaSender, out []
hostinfo.logger(f.l).WithError(err).WithField("udpAddr", ip). hostinfo.logger(f.l).WithError(err).WithField("udpAddr", ip).
WithField("packet", packet). WithField("packet", packet).
Error("Failed to decrypt test packet") Error("Failed to decrypt test packet")
//TODO: maybe after build 64 is out? 06/14/2018 - NB
//f.sendRecvError(net.Addr(addr), header.RemoteIndex)
return return
} }
@ -228,14 +203,8 @@ func (f *Interface) readOutsidePackets(ip netip.AddrPort, via *ViaSender, out []
Error("Failed to decrypt Control packet") Error("Failed to decrypt Control packet")
return return
} }
m := &NebulaControl{}
err = m.Unmarshal(d)
if err != nil {
hostinfo.logger(f.l).WithError(err).Error("Failed to unmarshal control message")
break
}
f.relayManager.HandleControlMsg(hostinfo, m, f) f.relayManager.HandleControlMsg(hostinfo, d, f)
default: default:
f.messageMetrics.Rx(h.Type, h.Subtype, 1) f.messageMetrics.Rx(h.Type, h.Subtype, 1)
@ -252,8 +221,8 @@ func (f *Interface) readOutsidePackets(ip netip.AddrPort, via *ViaSender, out []
func (f *Interface) closeTunnel(hostInfo *HostInfo) { func (f *Interface) closeTunnel(hostInfo *HostInfo) {
final := f.hostMap.DeleteHostInfo(hostInfo) final := f.hostMap.DeleteHostInfo(hostInfo)
if final { if final {
// We no longer have any tunnels with this vpn ip, clear learned lighthouse state to lower memory usage // We no longer have any tunnels with this vpn addr, clear learned lighthouse state to lower memory usage
f.lightHouse.DeleteVpnIp(hostInfo.vpnIp) f.lightHouse.DeleteVpnAddrs(hostInfo.vpnAddrs)
} }
} }
@ -262,25 +231,26 @@ func (f *Interface) sendCloseTunnel(h *HostInfo) {
f.send(header.CloseTunnel, 0, h.ConnectionState, h, []byte{}, make([]byte, 12, 12), make([]byte, mtu)) f.send(header.CloseTunnel, 0, h.ConnectionState, h, []byte{}, make([]byte, 12, 12), make([]byte, mtu))
} }
func (f *Interface) handleHostRoaming(hostinfo *HostInfo, ip netip.AddrPort) { func (f *Interface) handleHostRoaming(hostinfo *HostInfo, udpAddr netip.AddrPort) {
if ip.IsValid() && hostinfo.remote != ip { if udpAddr.IsValid() && hostinfo.remote != udpAddr {
if !f.lightHouse.GetRemoteAllowList().Allow(hostinfo.vpnIp, ip.Addr()) { if !f.lightHouse.GetRemoteAllowList().AllowAll(hostinfo.vpnAddrs, udpAddr.Addr()) {
hostinfo.logger(f.l).WithField("newAddr", ip).Debug("lighthouse.remote_allow_list denied roaming") hostinfo.logger(f.l).WithField("newAddr", udpAddr).Debug("lighthouse.remote_allow_list denied roaming")
return return
} }
if !hostinfo.lastRoam.IsZero() && ip == hostinfo.lastRoamRemote && time.Since(hostinfo.lastRoam) < RoamingSuppressSeconds*time.Second {
if !hostinfo.lastRoam.IsZero() && udpAddr == hostinfo.lastRoamRemote && time.Since(hostinfo.lastRoam) < RoamingSuppressSeconds*time.Second {
if f.l.Level >= logrus.DebugLevel { if f.l.Level >= logrus.DebugLevel {
hostinfo.logger(f.l).WithField("udpAddr", hostinfo.remote).WithField("newAddr", ip). hostinfo.logger(f.l).WithField("udpAddr", hostinfo.remote).WithField("newAddr", udpAddr).
Debugf("Suppressing roam back to previous remote for %d seconds", RoamingSuppressSeconds) Debugf("Suppressing roam back to previous remote for %d seconds", RoamingSuppressSeconds)
} }
return return
} }
hostinfo.logger(f.l).WithField("udpAddr", hostinfo.remote).WithField("newAddr", ip). hostinfo.logger(f.l).WithField("udpAddr", hostinfo.remote).WithField("newAddr", udpAddr).
Info("Host roamed to new udp ip/port.") Info("Host roamed to new udp ip/port.")
hostinfo.lastRoam = time.Now() hostinfo.lastRoam = time.Now()
hostinfo.lastRoamRemote = hostinfo.remote hostinfo.lastRoamRemote = hostinfo.remote
hostinfo.SetRemote(ip) hostinfo.SetRemote(udpAddr)
} }
} }
@ -300,24 +270,141 @@ func (f *Interface) handleEncrypted(ci *ConnectionState, addr netip.AddrPort, h
return true return true
} }
var (
ErrPacketTooShort = errors.New("packet is too short")
ErrUnknownIPVersion = errors.New("packet is an unknown ip version")
ErrIPv4InvalidHeaderLength = errors.New("invalid ipv4 header length")
ErrIPv4PacketTooShort = errors.New("ipv4 packet is too short")
ErrIPv6PacketTooShort = errors.New("ipv6 packet is too short")
ErrIPv6CouldNotFindPayload = errors.New("could not find payload in ipv6 packet")
)
// newPacket validates and parses the interesting bits for the firewall out of the ip and sub protocol headers // newPacket validates and parses the interesting bits for the firewall out of the ip and sub protocol headers
func newPacket(data []byte, incoming bool, fp *firewall.Packet) error { func newPacket(data []byte, incoming bool, fp *firewall.Packet) error {
// Do we at least have an ipv4 header worth of data? if len(data) < 1 {
if len(data) < ipv4.HeaderLen { return ErrPacketTooShort
return fmt.Errorf("packet is less than %v bytes", ipv4.HeaderLen)
} }
// Is it an ipv4 packet? version := int((data[0] >> 4) & 0x0f)
if int((data[0]>>4)&0x0f) != 4 { switch version {
return fmt.Errorf("packet is not ipv4, type: %v", int((data[0]>>4)&0x0f)) case ipv4.Version:
return parseV4(data, incoming, fp)
case ipv6.Version:
return parseV6(data, incoming, fp)
}
return ErrUnknownIPVersion
}
func parseV6(data []byte, incoming bool, fp *firewall.Packet) error {
dataLen := len(data)
if dataLen < ipv6.HeaderLen {
return ErrIPv6PacketTooShort
}
if incoming {
fp.RemoteAddr, _ = netip.AddrFromSlice(data[8:24])
fp.LocalAddr, _ = netip.AddrFromSlice(data[24:40])
} else {
fp.LocalAddr, _ = netip.AddrFromSlice(data[8:24])
fp.RemoteAddr, _ = netip.AddrFromSlice(data[24:40])
}
protoAt := 6 // NextHeader is at 6 bytes into the ipv6 header
offset := ipv6.HeaderLen // Start at the end of the ipv6 header
next := 0
for {
if dataLen < offset {
break
}
proto := layers.IPProtocol(data[protoAt])
//fmt.Println(proto, protoAt)
switch proto {
case layers.IPProtocolICMPv6, layers.IPProtocolESP, layers.IPProtocolNoNextHeader:
fp.Protocol = uint8(proto)
fp.RemotePort = 0
fp.LocalPort = 0
fp.Fragment = false
return nil
case layers.IPProtocolTCP, layers.IPProtocolUDP:
if dataLen < offset+4 {
return ErrIPv6PacketTooShort
}
fp.Protocol = uint8(proto)
if incoming {
fp.RemotePort = binary.BigEndian.Uint16(data[offset : offset+2])
fp.LocalPort = binary.BigEndian.Uint16(data[offset+2 : offset+4])
} else {
fp.LocalPort = binary.BigEndian.Uint16(data[offset : offset+2])
fp.RemotePort = binary.BigEndian.Uint16(data[offset+2 : offset+4])
}
fp.Fragment = false
return nil
case layers.IPProtocolIPv6Fragment:
// Fragment header is 8 bytes, need at least offset+4 to read the offset field
if dataLen < offset+8 {
return ErrIPv6PacketTooShort
}
// Check if this is the first fragment
fragmentOffset := binary.BigEndian.Uint16(data[offset+2:offset+4]) &^ uint16(0x7) // Remove the reserved and M flag bits
if fragmentOffset != 0 {
// Non-first fragment, use what we have now and stop processing
fp.Protocol = data[offset]
fp.Fragment = true
fp.RemotePort = 0
fp.LocalPort = 0
return nil
}
// The next loop should be the transport layer since we are the first fragment
next = 8 // Fragment headers are always 8 bytes
case layers.IPProtocolAH:
// Auth headers, used by IPSec, have a different meaning for header length
if dataLen < offset+1 {
break
}
next = int(data[offset+1]+2) << 2
default:
// Normal ipv6 header length processing
if dataLen < offset+1 {
break
}
next = int(data[offset+1]+1) << 3
}
if next <= 0 {
// Safety check, each ipv6 header has to be at least 8 bytes
next = 8
}
protoAt = offset
offset = offset + next
}
return ErrIPv6CouldNotFindPayload
}
func parseV4(data []byte, incoming bool, fp *firewall.Packet) error {
// Do we at least have an ipv4 header worth of data?
if len(data) < ipv4.HeaderLen {
return ErrIPv4PacketTooShort
} }
// Adjust our start position based on the advertised ip header length // Adjust our start position based on the advertised ip header length
ihl := int(data[0]&0x0f) << 2 ihl := int(data[0]&0x0f) << 2
// Well formed ip header length? // Well-formed ip header length?
if ihl < ipv4.HeaderLen { if ihl < ipv4.HeaderLen {
return fmt.Errorf("packet had an invalid header length: %v", ihl) return ErrIPv4InvalidHeaderLength
} }
// Check if this is the second or further fragment of a fragmented packet. // Check if this is the second or further fragment of a fragmented packet.
@ -333,14 +420,13 @@ func newPacket(data []byte, incoming bool, fp *firewall.Packet) error {
minLen += minFwPacketLen minLen += minFwPacketLen
} }
if len(data) < minLen { if len(data) < minLen {
return fmt.Errorf("packet is less than %v bytes, ip header len: %v", minLen, ihl) return ErrIPv4InvalidHeaderLength
} }
// Firewall packets are locally oriented // Firewall packets are locally oriented
if incoming { if incoming {
//TODO: IPV6-WORK fp.RemoteAddr, _ = netip.AddrFromSlice(data[12:16])
fp.RemoteIP, _ = netip.AddrFromSlice(data[12:16]) fp.LocalAddr, _ = netip.AddrFromSlice(data[16:20])
fp.LocalIP, _ = netip.AddrFromSlice(data[16:20])
if fp.Fragment || fp.Protocol == firewall.ProtoICMP { if fp.Fragment || fp.Protocol == firewall.ProtoICMP {
fp.RemotePort = 0 fp.RemotePort = 0
fp.LocalPort = 0 fp.LocalPort = 0
@ -349,9 +435,8 @@ func newPacket(data []byte, incoming bool, fp *firewall.Packet) error {
fp.LocalPort = binary.BigEndian.Uint16(data[ihl+2 : ihl+4]) fp.LocalPort = binary.BigEndian.Uint16(data[ihl+2 : ihl+4])
} }
} else { } else {
//TODO: IPV6-WORK fp.LocalAddr, _ = netip.AddrFromSlice(data[12:16])
fp.LocalIP, _ = netip.AddrFromSlice(data[12:16]) fp.RemoteAddr, _ = netip.AddrFromSlice(data[16:20])
fp.RemoteIP, _ = netip.AddrFromSlice(data[16:20])
if fp.Fragment || fp.Protocol == firewall.ProtoICMP { if fp.Fragment || fp.Protocol == firewall.ProtoICMP {
fp.RemotePort = 0 fp.RemotePort = 0
fp.LocalPort = 0 fp.LocalPort = 0
@ -386,8 +471,6 @@ func (f *Interface) decryptToTun(hostinfo *HostInfo, messageCounter uint64, out
out, err = hostinfo.ConnectionState.dKey.DecryptDanger(out, packet[:header.Len], packet[header.Len:], messageCounter, nb) out, err = hostinfo.ConnectionState.dKey.DecryptDanger(out, packet[:header.Len], packet[header.Len:], messageCounter, nb)
if err != nil { if err != nil {
hostinfo.logger(f.l).WithError(err).Error("Failed to decrypt packet") hostinfo.logger(f.l).WithError(err).Error("Failed to decrypt packet")
//TODO: maybe after build 64 is out? 06/14/2018 - NB
//f.sendRecvError(hostinfo.remote, header.RemoteIndex)
return false return false
} }
@ -434,9 +517,8 @@ func (f *Interface) maybeSendRecvError(endpoint netip.AddrPort, index uint32) {
func (f *Interface) sendRecvError(endpoint netip.AddrPort, index uint32) { func (f *Interface) sendRecvError(endpoint netip.AddrPort, index uint32) {
f.messageMetrics.Tx(header.RecvError, 0, 1) f.messageMetrics.Tx(header.RecvError, 0, 1)
//TODO: this should be a signed message so we can trust that we should drop the index
b := header.Encode(make([]byte, header.Len), header.Version, header.RecvError, 0, index, 0) b := header.Encode(make([]byte, header.Len), header.Version, header.RecvError, 0, index, 0)
f.outside.WriteTo(b, endpoint) _ = f.outside.WriteTo(b, endpoint)
if f.l.Level >= logrus.DebugLevel { if f.l.Level >= logrus.DebugLevel {
f.l.WithField("index", index). f.l.WithField("index", index).
WithField("udpAddr", endpoint). WithField("udpAddr", endpoint).
@ -470,49 +552,3 @@ func (f *Interface) handleRecvError(addr netip.AddrPort, h *header.H) {
// We also delete it from pending hostmap to allow for fast reconnect. // We also delete it from pending hostmap to allow for fast reconnect.
f.handshakeManager.DeleteHostInfo(hostinfo) f.handshakeManager.DeleteHostInfo(hostinfo)
} }
/*
func (f *Interface) sendMeta(ci *ConnectionState, endpoint *net.UDPAddr, meta *NebulaMeta) {
if ci.eKey != nil {
//TODO: log error?
return
}
msg, err := proto.Marshal(meta)
if err != nil {
l.Debugln("failed to encode header")
}
c := ci.messageCounter
b := HeaderEncode(nil, Version, uint8(metadata), 0, hostinfo.remoteIndexId, c)
ci.messageCounter++
msg := ci.eKey.EncryptDanger(b, nil, msg, c)
//msg := ci.eKey.EncryptDanger(b, nil, []byte(fmt.Sprintf("%d", counter)), c)
f.outside.WriteTo(msg, endpoint)
}
*/
func RecombineCertAndValidate(h *noise.HandshakeState, rawCertBytes []byte, caPool *cert.CAPool) (*cert.CachedCertificate, error) {
pk := h.PeerStatic()
if pk == nil {
return nil, errors.New("no peer static key was present")
}
if rawCertBytes == nil {
return nil, errors.New("provided payload was empty")
}
c, err := cert.UnmarshalCertificateFromHandshake(rawCertBytes, pk)
if err != nil {
return nil, fmt.Errorf("error unmarshaling cert: %w", err)
}
cc, err := caPool.VerifyCertificate(time.Now(), c)
if err != nil {
return nil, fmt.Errorf("certificate validation failed: %w", err)
}
return cc, nil
}

View File

@ -1,10 +1,15 @@
package nebula package nebula
import ( import (
"bytes"
"encoding/binary"
"net" "net"
"net/netip" "net/netip"
"testing" "testing"
"github.com/google/gopacket"
"github.com/google/gopacket/layers"
"github.com/slackhq/nebula/firewall" "github.com/slackhq/nebula/firewall"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"golang.org/x/net/ipv4" "golang.org/x/net/ipv4"
@ -13,9 +18,15 @@ import (
func Test_newPacket(t *testing.T) { func Test_newPacket(t *testing.T) {
p := &firewall.Packet{} p := &firewall.Packet{}
// length fail // length fails
err := newPacket([]byte{0, 1}, true, p) err := newPacket([]byte{}, true, p)
assert.EqualError(t, err, "packet is less than 20 bytes") assert.ErrorIs(t, err, ErrPacketTooShort)
err = newPacket([]byte{0x40}, true, p)
assert.ErrorIs(t, err, ErrIPv4PacketTooShort)
err = newPacket([]byte{0x60}, true, p)
assert.ErrorIs(t, err, ErrIPv6PacketTooShort)
// length fail with ip options // length fail with ip options
h := ipv4.Header{ h := ipv4.Header{
@ -28,16 +39,15 @@ func Test_newPacket(t *testing.T) {
b, _ := h.Marshal() b, _ := h.Marshal()
err = newPacket(b, true, p) err = newPacket(b, true, p)
assert.ErrorIs(t, err, ErrIPv4InvalidHeaderLength)
assert.EqualError(t, err, "packet is less than 28 bytes, ip header len: 24")
// not an ipv4 packet // not an ipv4 packet
err = newPacket([]byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}, true, p) err = newPacket([]byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}, true, p)
assert.EqualError(t, err, "packet is not ipv4, type: 0") assert.ErrorIs(t, err, ErrUnknownIPVersion)
// invalid ihl // invalid ihl
err = newPacket([]byte{4<<4 | (8 >> 2 & 0x0f), 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}, true, p) err = newPacket([]byte{4<<4 | (8 >> 2 & 0x0f), 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}, true, p)
assert.EqualError(t, err, "packet had an invalid header length: 8") assert.ErrorIs(t, err, ErrIPv4InvalidHeaderLength)
// account for variable ip header length - incoming // account for variable ip header length - incoming
h = ipv4.Header{ h = ipv4.Header{
@ -54,11 +64,12 @@ func Test_newPacket(t *testing.T) {
err = newPacket(b, true, p) err = newPacket(b, true, p)
assert.Nil(t, err) assert.Nil(t, err)
assert.Equal(t, p.Protocol, uint8(firewall.ProtoTCP)) assert.Equal(t, uint8(firewall.ProtoTCP), p.Protocol)
assert.Equal(t, p.LocalIP, netip.MustParseAddr("10.0.0.2")) assert.Equal(t, netip.MustParseAddr("10.0.0.2"), p.LocalAddr)
assert.Equal(t, p.RemoteIP, netip.MustParseAddr("10.0.0.1")) assert.Equal(t, netip.MustParseAddr("10.0.0.1"), p.RemoteAddr)
assert.Equal(t, p.RemotePort, uint16(3)) assert.Equal(t, uint16(3), p.RemotePort)
assert.Equal(t, p.LocalPort, uint16(4)) assert.Equal(t, uint16(4), p.LocalPort)
assert.False(t, p.Fragment)
// account for variable ip header length - outgoing // account for variable ip header length - outgoing
h = ipv4.Header{ h = ipv4.Header{
@ -75,9 +86,506 @@ func Test_newPacket(t *testing.T) {
err = newPacket(b, false, p) err = newPacket(b, false, p)
assert.Nil(t, err) assert.Nil(t, err)
assert.Equal(t, p.Protocol, uint8(2)) assert.Equal(t, uint8(2), p.Protocol)
assert.Equal(t, p.LocalIP, netip.MustParseAddr("10.0.0.1")) assert.Equal(t, netip.MustParseAddr("10.0.0.1"), p.LocalAddr)
assert.Equal(t, p.RemoteIP, netip.MustParseAddr("10.0.0.2")) assert.Equal(t, netip.MustParseAddr("10.0.0.2"), p.RemoteAddr)
assert.Equal(t, p.RemotePort, uint16(6)) assert.Equal(t, uint16(6), p.RemotePort)
assert.Equal(t, p.LocalPort, uint16(5)) assert.Equal(t, uint16(5), p.LocalPort)
assert.False(t, p.Fragment)
}
func Test_newPacket_v6(t *testing.T) {
p := &firewall.Packet{}
// invalid ipv6
ip := layers.IPv6{
Version: 6,
HopLimit: 128,
SrcIP: net.IPv6linklocalallrouters,
DstIP: net.IPv6linklocalallnodes,
}
buffer := gopacket.NewSerializeBuffer()
opt := gopacket.SerializeOptions{
ComputeChecksums: false,
FixLengths: false,
}
err := gopacket.SerializeLayers(buffer, opt, &ip)
assert.NoError(t, err)
err = newPacket(buffer.Bytes(), true, p)
assert.ErrorIs(t, err, ErrIPv6CouldNotFindPayload)
// A good ICMP packet
ip = layers.IPv6{
Version: 6,
NextHeader: layers.IPProtocolICMPv6,
HopLimit: 128,
SrcIP: net.IPv6linklocalallrouters,
DstIP: net.IPv6linklocalallnodes,
}
icmp := layers.ICMPv6{}
buffer.Clear()
err = gopacket.SerializeLayers(buffer, opt, &ip, &icmp)
if err != nil {
panic(err)
}
err = newPacket(buffer.Bytes(), true, p)
assert.Nil(t, err)
assert.Equal(t, uint8(layers.IPProtocolICMPv6), p.Protocol)
assert.Equal(t, netip.MustParseAddr("ff02::2"), p.RemoteAddr)
assert.Equal(t, netip.MustParseAddr("ff02::1"), p.LocalAddr)
assert.Equal(t, uint16(0), p.RemotePort)
assert.Equal(t, uint16(0), p.LocalPort)
assert.False(t, p.Fragment)
// A good ESP packet
b := buffer.Bytes()
b[6] = byte(layers.IPProtocolESP)
err = newPacket(b, true, p)
assert.Nil(t, err)
assert.Equal(t, uint8(layers.IPProtocolESP), p.Protocol)
assert.Equal(t, netip.MustParseAddr("ff02::2"), p.RemoteAddr)
assert.Equal(t, netip.MustParseAddr("ff02::1"), p.LocalAddr)
assert.Equal(t, uint16(0), p.RemotePort)
assert.Equal(t, uint16(0), p.LocalPort)
assert.False(t, p.Fragment)
// A good None packet
b = buffer.Bytes()
b[6] = byte(layers.IPProtocolNoNextHeader)
err = newPacket(b, true, p)
assert.Nil(t, err)
assert.Equal(t, uint8(layers.IPProtocolNoNextHeader), p.Protocol)
assert.Equal(t, netip.MustParseAddr("ff02::2"), p.RemoteAddr)
assert.Equal(t, netip.MustParseAddr("ff02::1"), p.LocalAddr)
assert.Equal(t, uint16(0), p.RemotePort)
assert.Equal(t, uint16(0), p.LocalPort)
assert.False(t, p.Fragment)
// An unknown protocol packet
b = buffer.Bytes()
b[6] = 255 // 255 is a reserved protocol number
err = newPacket(b, true, p)
assert.ErrorIs(t, err, ErrIPv6CouldNotFindPayload)
// A good UDP packet
ip = layers.IPv6{
Version: 6,
NextHeader: firewall.ProtoUDP,
HopLimit: 128,
SrcIP: net.IPv6linklocalallrouters,
DstIP: net.IPv6linklocalallnodes,
}
udp := layers.UDP{
SrcPort: layers.UDPPort(36123),
DstPort: layers.UDPPort(22),
}
err = udp.SetNetworkLayerForChecksum(&ip)
assert.NoError(t, err)
buffer.Clear()
err = gopacket.SerializeLayers(buffer, opt, &ip, &udp, gopacket.Payload([]byte{0xde, 0xad, 0xbe, 0xef}))
if err != nil {
panic(err)
}
b = buffer.Bytes()
// incoming
err = newPacket(b, true, p)
assert.Nil(t, err)
assert.Equal(t, uint8(firewall.ProtoUDP), p.Protocol)
assert.Equal(t, netip.MustParseAddr("ff02::2"), p.RemoteAddr)
assert.Equal(t, netip.MustParseAddr("ff02::1"), p.LocalAddr)
assert.Equal(t, uint16(36123), p.RemotePort)
assert.Equal(t, uint16(22), p.LocalPort)
assert.False(t, p.Fragment)
// outgoing
err = newPacket(b, false, p)
assert.Nil(t, err)
assert.Equal(t, uint8(firewall.ProtoUDP), p.Protocol)
assert.Equal(t, netip.MustParseAddr("ff02::2"), p.LocalAddr)
assert.Equal(t, netip.MustParseAddr("ff02::1"), p.RemoteAddr)
assert.Equal(t, uint16(36123), p.LocalPort)
assert.Equal(t, uint16(22), p.RemotePort)
assert.False(t, p.Fragment)
// Too short UDP packet
err = newPacket(b[:len(b)-10], false, p) // pull off the last 10 bytes
assert.ErrorIs(t, err, ErrIPv6PacketTooShort)
// A good TCP packet
b[6] = byte(layers.IPProtocolTCP)
// incoming
err = newPacket(b, true, p)
assert.Nil(t, err)
assert.Equal(t, uint8(firewall.ProtoTCP), p.Protocol)
assert.Equal(t, netip.MustParseAddr("ff02::2"), p.RemoteAddr)
assert.Equal(t, netip.MustParseAddr("ff02::1"), p.LocalAddr)
assert.Equal(t, uint16(36123), p.RemotePort)
assert.Equal(t, uint16(22), p.LocalPort)
assert.False(t, p.Fragment)
// outgoing
err = newPacket(b, false, p)
assert.Nil(t, err)
assert.Equal(t, uint8(firewall.ProtoTCP), p.Protocol)
assert.Equal(t, netip.MustParseAddr("ff02::2"), p.LocalAddr)
assert.Equal(t, netip.MustParseAddr("ff02::1"), p.RemoteAddr)
assert.Equal(t, uint16(36123), p.LocalPort)
assert.Equal(t, uint16(22), p.RemotePort)
assert.False(t, p.Fragment)
// Too short TCP packet
err = newPacket(b[:len(b)-10], false, p) // pull off the last 10 bytes
assert.ErrorIs(t, err, ErrIPv6PacketTooShort)
// A good UDP packet with an AH header
ip = layers.IPv6{
Version: 6,
NextHeader: layers.IPProtocolAH,
HopLimit: 128,
SrcIP: net.IPv6linklocalallrouters,
DstIP: net.IPv6linklocalallnodes,
}
ah := layers.IPSecAH{
AuthenticationData: []byte{0xde, 0xad, 0xbe, 0xef, 0xde, 0xad, 0xbe, 0xef},
}
ah.NextHeader = layers.IPProtocolUDP
udpHeader := []byte{
0x8d, 0x1b, // Source port 36123
0x00, 0x16, // Destination port 22
0x00, 0x00, // Length
0x00, 0x00, // Checksum
}
buffer.Clear()
err = ip.SerializeTo(buffer, opt)
if err != nil {
panic(err)
}
b = buffer.Bytes()
ahb := serializeAH(&ah)
b = append(b, ahb...)
b = append(b, udpHeader...)
err = newPacket(b, true, p)
assert.Nil(t, err)
assert.Equal(t, uint8(firewall.ProtoUDP), p.Protocol)
assert.Equal(t, netip.MustParseAddr("ff02::2"), p.RemoteAddr)
assert.Equal(t, netip.MustParseAddr("ff02::1"), p.LocalAddr)
assert.Equal(t, uint16(36123), p.RemotePort)
assert.Equal(t, uint16(22), p.LocalPort)
assert.False(t, p.Fragment)
// Invalid AH header
b = buffer.Bytes()
err = newPacket(b, true, p)
assert.ErrorIs(t, err, ErrIPv6CouldNotFindPayload)
}
func Test_newPacket_ipv6Fragment(t *testing.T) {
p := &firewall.Packet{}
ip := &layers.IPv6{
Version: 6,
NextHeader: layers.IPProtocolIPv6Fragment,
HopLimit: 64,
SrcIP: net.IPv6linklocalallrouters,
DstIP: net.IPv6linklocalallnodes,
}
// First fragment
fragHeader1 := []byte{
uint8(layers.IPProtocolUDP), // Next Header (UDP)
0x00, // Reserved
0x00, // Fragment Offset high byte (0)
0x01, // Fragment Offset low byte & flags (M=1)
0x00, 0x00, 0x00, 0x01, // Identification
}
udpHeader := []byte{
0x8d, 0x1b, // Source port 36123
0x00, 0x16, // Destination port 22
0x00, 0x00, // Length
0x00, 0x00, // Checksum
}
buffer := gopacket.NewSerializeBuffer()
opts := gopacket.SerializeOptions{
ComputeChecksums: true,
FixLengths: true,
}
err := ip.SerializeTo(buffer, opts)
if err != nil {
t.Fatal(err)
}
firstFrag := buffer.Bytes()
firstFrag = append(firstFrag, fragHeader1...)
firstFrag = append(firstFrag, udpHeader...)
firstFrag = append(firstFrag, []byte{0xde, 0xad, 0xbe, 0xef}...)
// Test first fragment incoming
err = newPacket(firstFrag, true, p)
assert.NoError(t, err)
assert.Equal(t, netip.MustParseAddr("ff02::2"), p.RemoteAddr)
assert.Equal(t, netip.MustParseAddr("ff02::1"), p.LocalAddr)
assert.Equal(t, uint8(layers.IPProtocolUDP), p.Protocol)
assert.Equal(t, uint16(36123), p.RemotePort)
assert.Equal(t, uint16(22), p.LocalPort)
assert.False(t, p.Fragment)
// Test first fragment outgoing
err = newPacket(firstFrag, false, p)
assert.NoError(t, err)
assert.Equal(t, netip.MustParseAddr("ff02::2"), p.LocalAddr)
assert.Equal(t, netip.MustParseAddr("ff02::1"), p.RemoteAddr)
assert.Equal(t, uint8(layers.IPProtocolUDP), p.Protocol)
assert.Equal(t, uint16(36123), p.LocalPort)
assert.Equal(t, uint16(22), p.RemotePort)
assert.False(t, p.Fragment)
// Second fragment
fragHeader2 := []byte{
uint8(layers.IPProtocolUDP), // Next Header (UDP)
0x00, // Reserved
0xb9, // Fragment Offset high byte (185)
0x01, // Fragment Offset low byte & flags (M=1)
0x00, 0x00, 0x00, 0x01, // Identification
}
buffer.Clear()
err = ip.SerializeTo(buffer, opts)
if err != nil {
t.Fatal(err)
}
secondFrag := buffer.Bytes()
secondFrag = append(secondFrag, fragHeader2...)
secondFrag = append(secondFrag, []byte{0xde, 0xad, 0xbe, 0xef}...)
// Test second fragment incoming
err = newPacket(secondFrag, true, p)
assert.NoError(t, err)
assert.Equal(t, netip.MustParseAddr("ff02::2"), p.RemoteAddr)
assert.Equal(t, netip.MustParseAddr("ff02::1"), p.LocalAddr)
assert.Equal(t, uint8(layers.IPProtocolUDP), p.Protocol)
assert.Equal(t, uint16(0), p.RemotePort)
assert.Equal(t, uint16(0), p.LocalPort)
assert.True(t, p.Fragment)
// Test second fragment outgoing
err = newPacket(secondFrag, false, p)
assert.NoError(t, err)
assert.Equal(t, netip.MustParseAddr("ff02::2"), p.LocalAddr)
assert.Equal(t, netip.MustParseAddr("ff02::1"), p.RemoteAddr)
assert.Equal(t, uint8(layers.IPProtocolUDP), p.Protocol)
assert.Equal(t, uint16(0), p.LocalPort)
assert.Equal(t, uint16(0), p.RemotePort)
assert.True(t, p.Fragment)
// Too short of a fragment packet
err = newPacket(secondFrag[:len(secondFrag)-10], false, p)
assert.ErrorIs(t, err, ErrIPv6PacketTooShort)
}
func BenchmarkParseV6(b *testing.B) {
// Regular UDP packet
ip := &layers.IPv6{
Version: 6,
NextHeader: layers.IPProtocolUDP,
HopLimit: 64,
SrcIP: net.IPv6linklocalallrouters,
DstIP: net.IPv6linklocalallnodes,
}
udp := &layers.UDP{
SrcPort: layers.UDPPort(36123),
DstPort: layers.UDPPort(22),
}
buffer := gopacket.NewSerializeBuffer()
opts := gopacket.SerializeOptions{
ComputeChecksums: false,
FixLengths: true,
}
err := gopacket.SerializeLayers(buffer, opts, ip, udp)
if err != nil {
b.Fatal(err)
}
normalPacket := buffer.Bytes()
// First Fragment packet
ipFrag := &layers.IPv6{
Version: 6,
NextHeader: layers.IPProtocolIPv6Fragment,
HopLimit: 64,
SrcIP: net.IPv6linklocalallrouters,
DstIP: net.IPv6linklocalallnodes,
}
fragHeader := []byte{
uint8(layers.IPProtocolUDP), // Next Header (UDP)
0x00, // Reserved
0x00, // Fragment Offset high byte (0)
0x01, // Fragment Offset low byte & flags (M=1)
0x00, 0x00, 0x00, 0x01, // Identification
}
udpHeader := []byte{
0x8d, 0x7b, // Source port 36123
0x00, 0x16, // Destination port 22
0x00, 0x00, // Length
0x00, 0x00, // Checksum
}
buffer.Clear()
err = ipFrag.SerializeTo(buffer, opts)
if err != nil {
b.Fatal(err)
}
firstFrag := buffer.Bytes()
firstFrag = append(firstFrag, fragHeader...)
firstFrag = append(firstFrag, udpHeader...)
firstFrag = append(firstFrag, []byte{0xde, 0xad, 0xbe, 0xef}...)
// Second Fragment packet
fragHeader[2] = 0xb9 // offset 185
buffer.Clear()
err = ipFrag.SerializeTo(buffer, opts)
if err != nil {
b.Fatal(err)
}
secondFrag := buffer.Bytes()
secondFrag = append(secondFrag, fragHeader...)
secondFrag = append(secondFrag, []byte{0xde, 0xad, 0xbe, 0xef}...)
fp := &firewall.Packet{}
b.Run("Normal", func(b *testing.B) {
for i := 0; i < b.N; i++ {
if err = parseV6(normalPacket, true, fp); err != nil {
b.Fatal(err)
}
}
})
b.Run("FirstFragment", func(b *testing.B) {
for i := 0; i < b.N; i++ {
if err = parseV6(firstFrag, true, fp); err != nil {
b.Fatal(err)
}
}
})
b.Run("SecondFragment", func(b *testing.B) {
for i := 0; i < b.N; i++ {
if err = parseV6(secondFrag, true, fp); err != nil {
b.Fatal(err)
}
}
})
// Evil packet
evilPacket := &layers.IPv6{
Version: 6,
NextHeader: layers.IPProtocolIPv6HopByHop,
HopLimit: 64,
SrcIP: net.IPv6linklocalallrouters,
DstIP: net.IPv6linklocalallnodes,
}
hopHeader := []byte{
uint8(layers.IPProtocolIPv6HopByHop), // Next Header (HopByHop)
0x00, // Length
0x00, 0x00, // Options and padding
0x00, 0x00, 0x00, 0x00, // More options and padding
}
lastHopHeader := []byte{
uint8(layers.IPProtocolUDP), // Next Header (UDP)
0x00, // Length
0x00, 0x00, // Options and padding
0x00, 0x00, 0x00, 0x00, // More options and padding
}
buffer.Clear()
err = evilPacket.SerializeTo(buffer, opts)
if err != nil {
b.Fatal(err)
}
evilBytes := buffer.Bytes()
for i := 0; i < 200; i++ {
evilBytes = append(evilBytes, hopHeader...)
}
evilBytes = append(evilBytes, lastHopHeader...)
evilBytes = append(evilBytes, udpHeader...)
evilBytes = append(evilBytes, []byte{0xde, 0xad, 0xbe, 0xef}...)
b.Run("200 HopByHop headers", func(b *testing.B) {
for i := 0; i < b.N; i++ {
if err = parseV6(evilBytes, false, fp); err != nil {
b.Fatal(err)
}
}
})
}
// Ensure authentication data is a multiple of 8 bytes by padding if necessary
func padAuthData(authData []byte) []byte {
// Length of Authentication Data must be a multiple of 8 bytes
paddingLength := (8 - (len(authData) % 8)) % 8 // Only pad if necessary
if paddingLength > 0 {
authData = append(authData, make([]byte, paddingLength)...)
}
return authData
}
// Custom function to manually serialize IPSecAH for both IPv4 and IPv6
func serializeAH(ah *layers.IPSecAH) []byte {
buf := new(bytes.Buffer)
// Ensure Authentication Data is a multiple of 8 bytes
ah.AuthenticationData = padAuthData(ah.AuthenticationData)
// Calculate Payload Length (in 32-bit words, minus 2)
payloadLen := uint8((12+len(ah.AuthenticationData))/4) - 2
// Serialize fields
if err := binary.Write(buf, binary.BigEndian, ah.NextHeader); err != nil {
panic(err)
}
if err := binary.Write(buf, binary.BigEndian, payloadLen); err != nil {
panic(err)
}
if err := binary.Write(buf, binary.BigEndian, ah.Reserved); err != nil {
panic(err)
}
if err := binary.Write(buf, binary.BigEndian, ah.SPI); err != nil {
panic(err)
}
if err := binary.Write(buf, binary.BigEndian, ah.Seq); err != nil {
panic(err)
}
if len(ah.AuthenticationData) > 0 {
if err := binary.Write(buf, binary.BigEndian, ah.AuthenticationData); err != nil {
panic(err)
}
}
return buf.Bytes()
} }

View File

@ -8,7 +8,7 @@ import (
type Device interface { type Device interface {
io.ReadWriteCloser io.ReadWriteCloser
Activate() error Activate() error
Cidr() netip.Prefix Networks() []netip.Prefix
Name() string Name() string
RouteFor(netip.Addr) netip.Addr RouteFor(netip.Addr) netip.Addr
NewMultiQueueReader() (io.ReadWriteCloser, error) NewMultiQueueReader() (io.ReadWriteCloser, error)

View File

@ -61,7 +61,7 @@ func makeRouteTree(l *logrus.Logger, routes []Route, allowMTU bool) (*bart.Table
return routeTree, nil return routeTree, nil
} }
func parseRoutes(c *config.C, network netip.Prefix) ([]Route, error) { func parseRoutes(c *config.C, networks []netip.Prefix) ([]Route, error) {
var err error var err error
r := c.Get("tun.routes") r := c.Get("tun.routes")
@ -117,12 +117,20 @@ func parseRoutes(c *config.C, network netip.Prefix) ([]Route, error) {
return nil, fmt.Errorf("entry %v.route in tun.routes failed to parse: %v", i+1, err) return nil, fmt.Errorf("entry %v.route in tun.routes failed to parse: %v", i+1, err)
} }
if !network.Contains(r.Cidr.Addr()) || r.Cidr.Bits() < network.Bits() { found := false
for _, network := range networks {
if network.Contains(r.Cidr.Addr()) && r.Cidr.Bits() >= network.Bits() {
found = true
break
}
}
if !found {
return nil, fmt.Errorf( return nil, fmt.Errorf(
"entry %v.route in tun.routes is not contained within the network attached to the certificate; route: %v, network: %v", "entry %v.route in tun.routes is not contained within the configured vpn networks; route: %v, networks: %v",
i+1, i+1,
r.Cidr.String(), r.Cidr.String(),
network.String(), networks,
) )
} }
@ -132,7 +140,7 @@ func parseRoutes(c *config.C, network netip.Prefix) ([]Route, error) {
return routes, nil return routes, nil
} }
func parseUnsafeRoutes(c *config.C, network netip.Prefix) ([]Route, error) { func parseUnsafeRoutes(c *config.C, networks []netip.Prefix) ([]Route, error) {
var err error var err error
r := c.Get("tun.unsafe_routes") r := c.Get("tun.unsafe_routes")
@ -229,13 +237,15 @@ func parseUnsafeRoutes(c *config.C, network netip.Prefix) ([]Route, error) {
return nil, fmt.Errorf("entry %v.route in tun.unsafe_routes failed to parse: %v", i+1, err) return nil, fmt.Errorf("entry %v.route in tun.unsafe_routes failed to parse: %v", i+1, err)
} }
if network.Contains(r.Cidr.Addr()) { for _, network := range networks {
return nil, fmt.Errorf( if network.Contains(r.Cidr.Addr()) {
"entry %v.route in tun.unsafe_routes is contained within the network attached to the certificate; route: %v, network: %v", return nil, fmt.Errorf(
i+1, "entry %v.route in tun.unsafe_routes is contained within the configured vpn networks; route: %v, network: %v",
r.Cidr.String(), i+1,
network.String(), r.Cidr.String(),
) network.String(),
)
}
} }
routes[i] = r routes[i] = r

View File

@ -17,76 +17,82 @@ func Test_parseRoutes(t *testing.T) {
assert.NoError(t, err) assert.NoError(t, err)
// test no routes config // test no routes config
routes, err := parseRoutes(c, n) routes, err := parseRoutes(c, []netip.Prefix{n})
assert.Nil(t, err) assert.Nil(t, err)
assert.Len(t, routes, 0) assert.Len(t, routes, 0)
// not an array // not an array
c.Settings["tun"] = map[interface{}]interface{}{"routes": "hi"} c.Settings["tun"] = map[interface{}]interface{}{"routes": "hi"}
routes, err = parseRoutes(c, n) routes, err = parseRoutes(c, []netip.Prefix{n})
assert.Nil(t, routes) assert.Nil(t, routes)
assert.EqualError(t, err, "tun.routes is not an array") assert.EqualError(t, err, "tun.routes is not an array")
// no routes // no routes
c.Settings["tun"] = map[interface{}]interface{}{"routes": []interface{}{}} c.Settings["tun"] = map[interface{}]interface{}{"routes": []interface{}{}}
routes, err = parseRoutes(c, n) routes, err = parseRoutes(c, []netip.Prefix{n})
assert.Nil(t, err) assert.Nil(t, err)
assert.Len(t, routes, 0) assert.Len(t, routes, 0)
// weird route // weird route
c.Settings["tun"] = map[interface{}]interface{}{"routes": []interface{}{"asdf"}} c.Settings["tun"] = map[interface{}]interface{}{"routes": []interface{}{"asdf"}}
routes, err = parseRoutes(c, n) routes, err = parseRoutes(c, []netip.Prefix{n})
assert.Nil(t, routes) assert.Nil(t, routes)
assert.EqualError(t, err, "entry 1 in tun.routes is invalid") assert.EqualError(t, err, "entry 1 in tun.routes is invalid")
// no mtu // no mtu
c.Settings["tun"] = map[interface{}]interface{}{"routes": []interface{}{map[interface{}]interface{}{}}} c.Settings["tun"] = map[interface{}]interface{}{"routes": []interface{}{map[interface{}]interface{}{}}}
routes, err = parseRoutes(c, n) routes, err = parseRoutes(c, []netip.Prefix{n})
assert.Nil(t, routes) assert.Nil(t, routes)
assert.EqualError(t, err, "entry 1.mtu in tun.routes is not present") assert.EqualError(t, err, "entry 1.mtu in tun.routes is not present")
// bad mtu // bad mtu
c.Settings["tun"] = map[interface{}]interface{}{"routes": []interface{}{map[interface{}]interface{}{"mtu": "nope"}}} c.Settings["tun"] = map[interface{}]interface{}{"routes": []interface{}{map[interface{}]interface{}{"mtu": "nope"}}}
routes, err = parseRoutes(c, n) routes, err = parseRoutes(c, []netip.Prefix{n})
assert.Nil(t, routes) assert.Nil(t, routes)
assert.EqualError(t, err, "entry 1.mtu in tun.routes is not an integer: strconv.Atoi: parsing \"nope\": invalid syntax") assert.EqualError(t, err, "entry 1.mtu in tun.routes is not an integer: strconv.Atoi: parsing \"nope\": invalid syntax")
// low mtu // low mtu
c.Settings["tun"] = map[interface{}]interface{}{"routes": []interface{}{map[interface{}]interface{}{"mtu": "499"}}} c.Settings["tun"] = map[interface{}]interface{}{"routes": []interface{}{map[interface{}]interface{}{"mtu": "499"}}}
routes, err = parseRoutes(c, n) routes, err = parseRoutes(c, []netip.Prefix{n})
assert.Nil(t, routes) assert.Nil(t, routes)
assert.EqualError(t, err, "entry 1.mtu in tun.routes is below 500: 499") assert.EqualError(t, err, "entry 1.mtu in tun.routes is below 500: 499")
// missing route // missing route
c.Settings["tun"] = map[interface{}]interface{}{"routes": []interface{}{map[interface{}]interface{}{"mtu": "500"}}} c.Settings["tun"] = map[interface{}]interface{}{"routes": []interface{}{map[interface{}]interface{}{"mtu": "500"}}}
routes, err = parseRoutes(c, n) routes, err = parseRoutes(c, []netip.Prefix{n})
assert.Nil(t, routes) assert.Nil(t, routes)
assert.EqualError(t, err, "entry 1.route in tun.routes is not present") assert.EqualError(t, err, "entry 1.route in tun.routes is not present")
// unparsable route // unparsable route
c.Settings["tun"] = map[interface{}]interface{}{"routes": []interface{}{map[interface{}]interface{}{"mtu": "500", "route": "nope"}}} c.Settings["tun"] = map[interface{}]interface{}{"routes": []interface{}{map[interface{}]interface{}{"mtu": "500", "route": "nope"}}}
routes, err = parseRoutes(c, n) routes, err = parseRoutes(c, []netip.Prefix{n})
assert.Nil(t, routes) assert.Nil(t, routes)
assert.EqualError(t, err, "entry 1.route in tun.routes failed to parse: netip.ParsePrefix(\"nope\"): no '/'") assert.EqualError(t, err, "entry 1.route in tun.routes failed to parse: netip.ParsePrefix(\"nope\"): no '/'")
// below network range // below network range
c.Settings["tun"] = map[interface{}]interface{}{"routes": []interface{}{map[interface{}]interface{}{"mtu": "500", "route": "1.0.0.0/8"}}} c.Settings["tun"] = map[interface{}]interface{}{"routes": []interface{}{map[interface{}]interface{}{"mtu": "500", "route": "1.0.0.0/8"}}}
routes, err = parseRoutes(c, n) routes, err = parseRoutes(c, []netip.Prefix{n})
assert.Nil(t, routes) assert.Nil(t, routes)
assert.EqualError(t, err, "entry 1.route in tun.routes is not contained within the network attached to the certificate; route: 1.0.0.0/8, network: 10.0.0.0/24") assert.EqualError(t, err, "entry 1.route in tun.routes is not contained within the configured vpn networks; route: 1.0.0.0/8, networks: [10.0.0.0/24]")
// above network range // above network range
c.Settings["tun"] = map[interface{}]interface{}{"routes": []interface{}{map[interface{}]interface{}{"mtu": "500", "route": "10.0.1.0/24"}}} c.Settings["tun"] = map[interface{}]interface{}{"routes": []interface{}{map[interface{}]interface{}{"mtu": "500", "route": "10.0.1.0/24"}}}
routes, err = parseRoutes(c, n) routes, err = parseRoutes(c, []netip.Prefix{n})
assert.Nil(t, routes) assert.Nil(t, routes)
assert.EqualError(t, err, "entry 1.route in tun.routes is not contained within the network attached to the certificate; route: 10.0.1.0/24, network: 10.0.0.0/24") assert.EqualError(t, err, "entry 1.route in tun.routes is not contained within the configured vpn networks; route: 10.0.1.0/24, networks: [10.0.0.0/24]")
// Not in multiple ranges
c.Settings["tun"] = map[interface{}]interface{}{"routes": []interface{}{map[interface{}]interface{}{"mtu": "500", "route": "192.0.0.0/24"}}}
routes, err = parseRoutes(c, []netip.Prefix{n, netip.MustParsePrefix("192.1.0.0/24")})
assert.Nil(t, routes)
assert.EqualError(t, err, "entry 1.route in tun.routes is not contained within the configured vpn networks; route: 192.0.0.0/24, networks: [10.0.0.0/24 192.1.0.0/24]")
// happy case // happy case
c.Settings["tun"] = map[interface{}]interface{}{"routes": []interface{}{ c.Settings["tun"] = map[interface{}]interface{}{"routes": []interface{}{
map[interface{}]interface{}{"mtu": "9000", "route": "10.0.0.0/29"}, map[interface{}]interface{}{"mtu": "9000", "route": "10.0.0.0/29"},
map[interface{}]interface{}{"mtu": "8000", "route": "10.0.0.1/32"}, map[interface{}]interface{}{"mtu": "8000", "route": "10.0.0.1/32"},
}} }}
routes, err = parseRoutes(c, n) routes, err = parseRoutes(c, []netip.Prefix{n})
assert.Nil(t, err) assert.Nil(t, err)
assert.Len(t, routes, 2) assert.Len(t, routes, 2)
@ -116,31 +122,31 @@ func Test_parseUnsafeRoutes(t *testing.T) {
assert.NoError(t, err) assert.NoError(t, err)
// test no routes config // test no routes config
routes, err := parseUnsafeRoutes(c, n) routes, err := parseUnsafeRoutes(c, []netip.Prefix{n})
assert.Nil(t, err) assert.Nil(t, err)
assert.Len(t, routes, 0) assert.Len(t, routes, 0)
// not an array // not an array
c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": "hi"} c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": "hi"}
routes, err = parseUnsafeRoutes(c, n) routes, err = parseUnsafeRoutes(c, []netip.Prefix{n})
assert.Nil(t, routes) assert.Nil(t, routes)
assert.EqualError(t, err, "tun.unsafe_routes is not an array") assert.EqualError(t, err, "tun.unsafe_routes is not an array")
// no routes // no routes
c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{}} c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{}}
routes, err = parseUnsafeRoutes(c, n) routes, err = parseUnsafeRoutes(c, []netip.Prefix{n})
assert.Nil(t, err) assert.Nil(t, err)
assert.Len(t, routes, 0) assert.Len(t, routes, 0)
// weird route // weird route
c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{"asdf"}} c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{"asdf"}}
routes, err = parseUnsafeRoutes(c, n) routes, err = parseUnsafeRoutes(c, []netip.Prefix{n})
assert.Nil(t, routes) assert.Nil(t, routes)
assert.EqualError(t, err, "entry 1 in tun.unsafe_routes is invalid") assert.EqualError(t, err, "entry 1 in tun.unsafe_routes is invalid")
// no via // no via
c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{}}} c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{}}}
routes, err = parseUnsafeRoutes(c, n) routes, err = parseUnsafeRoutes(c, []netip.Prefix{n})
assert.Nil(t, routes) assert.Nil(t, routes)
assert.EqualError(t, err, "entry 1.via in tun.unsafe_routes is not present") assert.EqualError(t, err, "entry 1.via in tun.unsafe_routes is not present")
@ -149,68 +155,68 @@ func Test_parseUnsafeRoutes(t *testing.T) {
127, false, nil, 1.0, []string{"1", "2"}, 127, false, nil, 1.0, []string{"1", "2"},
} { } {
c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{"via": invalidValue}}} c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{"via": invalidValue}}}
routes, err = parseUnsafeRoutes(c, n) routes, err = parseUnsafeRoutes(c, []netip.Prefix{n})
assert.Nil(t, routes) assert.Nil(t, routes)
assert.EqualError(t, err, fmt.Sprintf("entry 1.via in tun.unsafe_routes is not a string: found %T", invalidValue)) assert.EqualError(t, err, fmt.Sprintf("entry 1.via in tun.unsafe_routes is not a string: found %T", invalidValue))
} }
// unparsable via // unparsable via
c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{"mtu": "500", "via": "nope"}}} c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{"mtu": "500", "via": "nope"}}}
routes, err = parseUnsafeRoutes(c, n) routes, err = parseUnsafeRoutes(c, []netip.Prefix{n})
assert.Nil(t, routes) assert.Nil(t, routes)
assert.EqualError(t, err, "entry 1.via in tun.unsafe_routes failed to parse address: ParseAddr(\"nope\"): unable to parse IP") assert.EqualError(t, err, "entry 1.via in tun.unsafe_routes failed to parse address: ParseAddr(\"nope\"): unable to parse IP")
// missing route // missing route
c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{"via": "127.0.0.1", "mtu": "500"}}} c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{"via": "127.0.0.1", "mtu": "500"}}}
routes, err = parseUnsafeRoutes(c, n) routes, err = parseUnsafeRoutes(c, []netip.Prefix{n})
assert.Nil(t, routes) assert.Nil(t, routes)
assert.EqualError(t, err, "entry 1.route in tun.unsafe_routes is not present") assert.EqualError(t, err, "entry 1.route in tun.unsafe_routes is not present")
// unparsable route // unparsable route
c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{"via": "127.0.0.1", "mtu": "500", "route": "nope"}}} c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{"via": "127.0.0.1", "mtu": "500", "route": "nope"}}}
routes, err = parseUnsafeRoutes(c, n) routes, err = parseUnsafeRoutes(c, []netip.Prefix{n})
assert.Nil(t, routes) assert.Nil(t, routes)
assert.EqualError(t, err, "entry 1.route in tun.unsafe_routes failed to parse: netip.ParsePrefix(\"nope\"): no '/'") assert.EqualError(t, err, "entry 1.route in tun.unsafe_routes failed to parse: netip.ParsePrefix(\"nope\"): no '/'")
// within network range // within network range
c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{"via": "127.0.0.1", "route": "10.0.0.0/24"}}} c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{"via": "127.0.0.1", "route": "10.0.0.0/24"}}}
routes, err = parseUnsafeRoutes(c, n) routes, err = parseUnsafeRoutes(c, []netip.Prefix{n})
assert.Nil(t, routes) assert.Nil(t, routes)
assert.EqualError(t, err, "entry 1.route in tun.unsafe_routes is contained within the network attached to the certificate; route: 10.0.0.0/24, network: 10.0.0.0/24") assert.EqualError(t, err, "entry 1.route in tun.unsafe_routes is contained within the configured vpn networks; route: 10.0.0.0/24, network: 10.0.0.0/24")
// below network range // below network range
c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{"via": "127.0.0.1", "route": "1.0.0.0/8"}}} c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{"via": "127.0.0.1", "route": "1.0.0.0/8"}}}
routes, err = parseUnsafeRoutes(c, n) routes, err = parseUnsafeRoutes(c, []netip.Prefix{n})
assert.Len(t, routes, 1) assert.Len(t, routes, 1)
assert.Nil(t, err) assert.Nil(t, err)
// above network range // above network range
c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{"via": "127.0.0.1", "route": "10.0.1.0/24"}}} c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{"via": "127.0.0.1", "route": "10.0.1.0/24"}}}
routes, err = parseUnsafeRoutes(c, n) routes, err = parseUnsafeRoutes(c, []netip.Prefix{n})
assert.Len(t, routes, 1) assert.Len(t, routes, 1)
assert.Nil(t, err) assert.Nil(t, err)
// no mtu // no mtu
c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{"via": "127.0.0.1", "route": "1.0.0.0/8"}}} c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{"via": "127.0.0.1", "route": "1.0.0.0/8"}}}
routes, err = parseUnsafeRoutes(c, n) routes, err = parseUnsafeRoutes(c, []netip.Prefix{n})
assert.Len(t, routes, 1) assert.Len(t, routes, 1)
assert.Equal(t, 0, routes[0].MTU) assert.Equal(t, 0, routes[0].MTU)
// bad mtu // bad mtu
c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{"via": "127.0.0.1", "mtu": "nope"}}} c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{"via": "127.0.0.1", "mtu": "nope"}}}
routes, err = parseUnsafeRoutes(c, n) routes, err = parseUnsafeRoutes(c, []netip.Prefix{n})
assert.Nil(t, routes) assert.Nil(t, routes)
assert.EqualError(t, err, "entry 1.mtu in tun.unsafe_routes is not an integer: strconv.Atoi: parsing \"nope\": invalid syntax") assert.EqualError(t, err, "entry 1.mtu in tun.unsafe_routes is not an integer: strconv.Atoi: parsing \"nope\": invalid syntax")
// low mtu // low mtu
c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{"via": "127.0.0.1", "mtu": "499"}}} c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{"via": "127.0.0.1", "mtu": "499"}}}
routes, err = parseUnsafeRoutes(c, n) routes, err = parseUnsafeRoutes(c, []netip.Prefix{n})
assert.Nil(t, routes) assert.Nil(t, routes)
assert.EqualError(t, err, "entry 1.mtu in tun.unsafe_routes is below 500: 499") assert.EqualError(t, err, "entry 1.mtu in tun.unsafe_routes is below 500: 499")
// bad install // bad install
c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{"via": "127.0.0.1", "mtu": "9000", "route": "1.0.0.0/29", "install": "nope"}}} c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{"via": "127.0.0.1", "mtu": "9000", "route": "1.0.0.0/29", "install": "nope"}}}
routes, err = parseUnsafeRoutes(c, n) routes, err = parseUnsafeRoutes(c, []netip.Prefix{n})
assert.Nil(t, routes) assert.Nil(t, routes)
assert.EqualError(t, err, "entry 1.install in tun.unsafe_routes is not a boolean: strconv.ParseBool: parsing \"nope\": invalid syntax") assert.EqualError(t, err, "entry 1.install in tun.unsafe_routes is not a boolean: strconv.ParseBool: parsing \"nope\": invalid syntax")
@ -221,7 +227,7 @@ func Test_parseUnsafeRoutes(t *testing.T) {
map[interface{}]interface{}{"via": "127.0.0.1", "mtu": "1500", "metric": 1234, "route": "1.0.0.2/32", "install": 1}, map[interface{}]interface{}{"via": "127.0.0.1", "mtu": "1500", "metric": 1234, "route": "1.0.0.2/32", "install": 1},
map[interface{}]interface{}{"via": "127.0.0.1", "mtu": "1500", "metric": 1234, "route": "1.0.0.2/32"}, map[interface{}]interface{}{"via": "127.0.0.1", "mtu": "1500", "metric": 1234, "route": "1.0.0.2/32"},
}} }}
routes, err = parseUnsafeRoutes(c, n) routes, err = parseUnsafeRoutes(c, []netip.Prefix{n})
assert.Nil(t, err) assert.Nil(t, err)
assert.Len(t, routes, 4) assert.Len(t, routes, 4)
@ -260,7 +266,7 @@ func Test_makeRouteTree(t *testing.T) {
map[interface{}]interface{}{"via": "192.168.0.1", "route": "1.0.0.0/28"}, map[interface{}]interface{}{"via": "192.168.0.1", "route": "1.0.0.0/28"},
map[interface{}]interface{}{"via": "192.168.0.2", "route": "1.0.0.1/32"}, map[interface{}]interface{}{"via": "192.168.0.2", "route": "1.0.0.1/32"},
}} }}
routes, err := parseUnsafeRoutes(c, n) routes, err := parseUnsafeRoutes(c, []netip.Prefix{n})
assert.NoError(t, err) assert.NoError(t, err)
assert.Len(t, routes, 2) assert.Len(t, routes, 2)
routeTree, err := makeRouteTree(l, routes, true) routeTree, err := makeRouteTree(l, routes, true)

View File

@ -11,36 +11,36 @@ import (
const DefaultMTU = 1300 const DefaultMTU = 1300
// TODO: We may be able to remove routines // TODO: We may be able to remove routines
type DeviceFactory func(c *config.C, l *logrus.Logger, tunCidr netip.Prefix, routines int) (Device, error) type DeviceFactory func(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, routines int) (Device, error)
func NewDeviceFromConfig(c *config.C, l *logrus.Logger, tunCidr netip.Prefix, routines int) (Device, error) { func NewDeviceFromConfig(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, routines int) (Device, error) {
switch { switch {
case c.GetBool("tun.disabled", false): case c.GetBool("tun.disabled", false):
tun := newDisabledTun(tunCidr, c.GetInt("tun.tx_queue", 500), c.GetBool("stats.message_metrics", false), l) tun := newDisabledTun(vpnNetworks, c.GetInt("tun.tx_queue", 500), c.GetBool("stats.message_metrics", false), l)
return tun, nil return tun, nil
default: default:
return newTun(c, l, tunCidr, routines > 1) return newTun(c, l, vpnNetworks, routines > 1)
} }
} }
func NewFdDeviceFromConfig(fd *int) DeviceFactory { func NewFdDeviceFromConfig(fd *int) DeviceFactory {
return func(c *config.C, l *logrus.Logger, tunCidr netip.Prefix, routines int) (Device, error) { return func(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, routines int) (Device, error) {
return newTunFromFd(c, l, *fd, tunCidr) return newTunFromFd(c, l, *fd, vpnNetworks)
} }
} }
func getAllRoutesFromConfig(c *config.C, cidr netip.Prefix, initial bool) (bool, []Route, error) { func getAllRoutesFromConfig(c *config.C, vpnNetworks []netip.Prefix, initial bool) (bool, []Route, error) {
if !initial && !c.HasChanged("tun.routes") && !c.HasChanged("tun.unsafe_routes") { if !initial && !c.HasChanged("tun.routes") && !c.HasChanged("tun.unsafe_routes") {
return false, nil, nil return false, nil, nil
} }
routes, err := parseRoutes(c, cidr) routes, err := parseRoutes(c, vpnNetworks)
if err != nil { if err != nil {
return true, nil, util.NewContextualError("Could not parse tun.routes", nil, err) return true, nil, util.NewContextualError("Could not parse tun.routes", nil, err)
} }
unsafeRoutes, err := parseUnsafeRoutes(c, cidr) unsafeRoutes, err := parseUnsafeRoutes(c, vpnNetworks)
if err != nil { if err != nil {
return true, nil, util.NewContextualError("Could not parse tun.unsafe_routes", nil, err) return true, nil, util.NewContextualError("Could not parse tun.unsafe_routes", nil, err)
} }

View File

@ -18,14 +18,14 @@ import (
type tun struct { type tun struct {
io.ReadWriteCloser io.ReadWriteCloser
fd int fd int
cidr netip.Prefix vpnNetworks []netip.Prefix
Routes atomic.Pointer[[]Route] Routes atomic.Pointer[[]Route]
routeTree atomic.Pointer[bart.Table[netip.Addr]] routeTree atomic.Pointer[bart.Table[netip.Addr]]
l *logrus.Logger l *logrus.Logger
} }
func newTunFromFd(c *config.C, l *logrus.Logger, deviceFd int, cidr netip.Prefix) (*tun, error) { func newTunFromFd(c *config.C, l *logrus.Logger, deviceFd int, vpnNetworks []netip.Prefix) (*tun, error) {
// XXX Android returns an fd in non-blocking mode which is necessary for shutdown to work properly. // XXX Android returns an fd in non-blocking mode which is necessary for shutdown to work properly.
// Be sure not to call file.Fd() as it will set the fd to blocking mode. // Be sure not to call file.Fd() as it will set the fd to blocking mode.
file := os.NewFile(uintptr(deviceFd), "/dev/net/tun") file := os.NewFile(uintptr(deviceFd), "/dev/net/tun")
@ -33,7 +33,7 @@ func newTunFromFd(c *config.C, l *logrus.Logger, deviceFd int, cidr netip.Prefix
t := &tun{ t := &tun{
ReadWriteCloser: file, ReadWriteCloser: file,
fd: deviceFd, fd: deviceFd,
cidr: cidr, vpnNetworks: vpnNetworks,
l: l, l: l,
} }
@ -52,7 +52,7 @@ func newTunFromFd(c *config.C, l *logrus.Logger, deviceFd int, cidr netip.Prefix
return t, nil return t, nil
} }
func newTun(_ *config.C, _ *logrus.Logger, _ netip.Prefix, _ bool) (*tun, error) { func newTun(_ *config.C, _ *logrus.Logger, _ []netip.Prefix, _ bool) (*tun, error) {
return nil, fmt.Errorf("newTun not supported in Android") return nil, fmt.Errorf("newTun not supported in Android")
} }
@ -66,7 +66,7 @@ func (t tun) Activate() error {
} }
func (t *tun) reload(c *config.C, initial bool) error { func (t *tun) reload(c *config.C, initial bool) error {
change, routes, err := getAllRoutesFromConfig(c, t.cidr, initial) change, routes, err := getAllRoutesFromConfig(c, t.vpnNetworks, initial)
if err != nil { if err != nil {
return err return err
} }
@ -86,8 +86,8 @@ func (t *tun) reload(c *config.C, initial bool) error {
return nil return nil
} }
func (t *tun) Cidr() netip.Prefix { func (t *tun) Networks() []netip.Prefix {
return t.cidr return t.vpnNetworks
} }
func (t *tun) Name() string { func (t *tun) Name() string {

View File

@ -24,56 +24,62 @@ import (
type tun struct { type tun struct {
io.ReadWriteCloser io.ReadWriteCloser
Device string Device string
cidr netip.Prefix vpnNetworks []netip.Prefix
DefaultMTU int DefaultMTU int
Routes atomic.Pointer[[]Route] Routes atomic.Pointer[[]Route]
routeTree atomic.Pointer[bart.Table[netip.Addr]] routeTree atomic.Pointer[bart.Table[netip.Addr]]
linkAddr *netroute.LinkAddr linkAddr *netroute.LinkAddr
l *logrus.Logger l *logrus.Logger
// cache out buffer since we need to prepend 4 bytes for tun metadata // cache out buffer since we need to prepend 4 bytes for tun metadata
out []byte out []byte
} }
type sockaddrCtl struct {
scLen uint8
scFamily uint8
ssSysaddr uint16
scID uint32
scUnit uint32
scReserved [5]uint32
}
type ifReq struct { type ifReq struct {
Name [16]byte Name [unix.IFNAMSIZ]byte
Flags uint16 Flags uint16
pad [8]byte pad [8]byte
} }
var sockaddrCtlSize uintptr = 32
const ( const (
_SYSPROTO_CONTROL = 2 //define SYSPROTO_CONTROL 2 /* kernel control protocol */ _SIOCAIFADDR_IN6 = 2155899162
_AF_SYS_CONTROL = 2 //#define AF_SYS_CONTROL 2 /* corresponding sub address type */ _UTUN_OPT_IFNAME = 2
_PF_SYSTEM = unix.AF_SYSTEM //#define PF_SYSTEM AF_SYSTEM _IN6_IFF_NODAD = 0x0020
_CTLIOCGINFO = 3227799043 //#define CTLIOCGINFO _IOWR('N', 3, struct ctl_info) _IN6_IFF_SECURED = 0x0400
utunControlName = "com.apple.net.utun_control" utunControlName = "com.apple.net.utun_control"
) )
type ifreqAddr struct {
Name [16]byte
Addr unix.RawSockaddrInet4
pad [8]byte
}
type ifreqMTU struct { type ifreqMTU struct {
Name [16]byte Name [16]byte
MTU int32 MTU int32
pad [8]byte pad [8]byte
} }
func newTun(c *config.C, l *logrus.Logger, cidr netip.Prefix, _ bool) (*tun, error) { type addrLifetime struct {
Expire float64
Preferred float64
Vltime uint32
Pltime uint32
}
type ifreqAlias4 struct {
Name [unix.IFNAMSIZ]byte
Addr unix.RawSockaddrInet4
DstAddr unix.RawSockaddrInet4
MaskAddr unix.RawSockaddrInet4
}
type ifreqAlias6 struct {
Name [unix.IFNAMSIZ]byte
Addr unix.RawSockaddrInet6
DstAddr unix.RawSockaddrInet6
PrefixMask unix.RawSockaddrInet6
Flags uint32
Lifetime addrLifetime
}
func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, _ bool) (*tun, error) {
name := c.GetString("tun.dev", "") name := c.GetString("tun.dev", "")
ifIndex := -1 ifIndex := -1
if name != "" && name != "utun" { if name != "" && name != "utun" {
@ -86,66 +92,41 @@ func newTun(c *config.C, l *logrus.Logger, cidr netip.Prefix, _ bool) (*tun, err
} }
} }
fd, err := unix.Socket(_PF_SYSTEM, unix.SOCK_DGRAM, _SYSPROTO_CONTROL) fd, err := unix.Socket(unix.AF_SYSTEM, unix.SOCK_DGRAM, unix.AF_SYS_CONTROL)
if err != nil { if err != nil {
return nil, fmt.Errorf("system socket: %v", err) return nil, fmt.Errorf("system socket: %v", err)
} }
var ctlInfo = &struct { var ctlInfo = &unix.CtlInfo{}
ctlID uint32 copy(ctlInfo.Name[:], utunControlName)
ctlName [96]byte
}{}
copy(ctlInfo.ctlName[:], utunControlName) err = unix.IoctlCtlInfo(fd, ctlInfo)
err = ioctl(uintptr(fd), uintptr(_CTLIOCGINFO), uintptr(unsafe.Pointer(ctlInfo)))
if err != nil { if err != nil {
return nil, fmt.Errorf("CTLIOCGINFO: %v", err) return nil, fmt.Errorf("CTLIOCGINFO: %v", err)
} }
sc := sockaddrCtl{ err = unix.Connect(fd, &unix.SockaddrCtl{
scLen: uint8(sockaddrCtlSize), ID: ctlInfo.Id,
scFamily: unix.AF_SYSTEM, Unit: uint32(ifIndex) + 1,
ssSysaddr: _AF_SYS_CONTROL, })
scID: ctlInfo.ctlID, if err != nil {
scUnit: uint32(ifIndex) + 1, return nil, fmt.Errorf("SYS_CONNECT: %v", err)
} }
_, _, errno := unix.RawSyscall( name, err = unix.GetsockoptString(fd, unix.AF_SYS_CONTROL, _UTUN_OPT_IFNAME)
unix.SYS_CONNECT, if err != nil {
uintptr(fd), return nil, fmt.Errorf("failed to retrieve tun name: %w", err)
uintptr(unsafe.Pointer(&sc)),
sockaddrCtlSize,
)
if errno != 0 {
return nil, fmt.Errorf("SYS_CONNECT: %v", errno)
} }
var ifName struct { err = unix.SetNonblock(fd, true)
name [16]byte
}
ifNameSize := uintptr(len(ifName.name))
_, _, errno = syscall.Syscall6(syscall.SYS_GETSOCKOPT, uintptr(fd),
2, // SYSPROTO_CONTROL
2, // UTUN_OPT_IFNAME
uintptr(unsafe.Pointer(&ifName)),
uintptr(unsafe.Pointer(&ifNameSize)), 0)
if errno != 0 {
return nil, fmt.Errorf("SYS_GETSOCKOPT: %v", errno)
}
name = string(ifName.name[:ifNameSize-1])
err = syscall.SetNonblock(fd, true)
if err != nil { if err != nil {
return nil, fmt.Errorf("SetNonblock: %v", err) return nil, fmt.Errorf("SetNonblock: %v", err)
} }
file := os.NewFile(uintptr(fd), "")
t := &tun{ t := &tun{
ReadWriteCloser: file, ReadWriteCloser: os.NewFile(uintptr(fd), ""),
Device: name, Device: name,
cidr: cidr, vpnNetworks: vpnNetworks,
DefaultMTU: c.GetInt("tun.mtu", DefaultMTU), DefaultMTU: c.GetInt("tun.mtu", DefaultMTU),
l: l, l: l,
} }
@ -172,7 +153,7 @@ func (t *tun) deviceBytes() (o [16]byte) {
return return
} }
func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ netip.Prefix) (*tun, error) { func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ []netip.Prefix) (*tun, error) {
return nil, fmt.Errorf("newTunFromFd not supported in Darwin") return nil, fmt.Errorf("newTunFromFd not supported in Darwin")
} }
@ -186,16 +167,6 @@ func (t *tun) Close() error {
func (t *tun) Activate() error { func (t *tun) Activate() error {
devName := t.deviceBytes() devName := t.deviceBytes()
var addr, mask [4]byte
if !t.cidr.Addr().Is4() {
//TODO: IPV6-WORK
panic("need ipv6")
}
addr = t.cidr.Addr().As4()
copy(mask[:], prefixToMask(t.cidr))
s, err := unix.Socket( s, err := unix.Socket(
unix.AF_INET, unix.AF_INET,
unix.SOCK_DGRAM, unix.SOCK_DGRAM,
@ -208,66 +179,18 @@ func (t *tun) Activate() error {
fd := uintptr(s) fd := uintptr(s)
ifra := ifreqAddr{
Name: devName,
Addr: unix.RawSockaddrInet4{
Family: unix.AF_INET,
Addr: addr,
},
}
// Set the device ip address
if err = ioctl(fd, unix.SIOCSIFADDR, uintptr(unsafe.Pointer(&ifra))); err != nil {
return fmt.Errorf("failed to set tun address: %s", err)
}
// Set the device network
ifra.Addr.Addr = mask
if err = ioctl(fd, unix.SIOCSIFNETMASK, uintptr(unsafe.Pointer(&ifra))); err != nil {
return fmt.Errorf("failed to set tun netmask: %s", err)
}
// Set the device name
ifrf := ifReq{Name: devName}
if err = ioctl(fd, unix.SIOCGIFFLAGS, uintptr(unsafe.Pointer(&ifrf))); err != nil {
return fmt.Errorf("failed to set tun device name: %s", err)
}
// Set the MTU on the device // Set the MTU on the device
ifm := ifreqMTU{Name: devName, MTU: int32(t.DefaultMTU)} ifm := ifreqMTU{Name: devName, MTU: int32(t.DefaultMTU)}
if err = ioctl(fd, unix.SIOCSIFMTU, uintptr(unsafe.Pointer(&ifm))); err != nil { if err = ioctl(fd, unix.SIOCSIFMTU, uintptr(unsafe.Pointer(&ifm))); err != nil {
return fmt.Errorf("failed to set tun mtu: %v", err) return fmt.Errorf("failed to set tun mtu: %v", err)
} }
/* // Get the device flags
// Set the transmit queue length ifrf := ifReq{Name: devName}
ifrq := ifreqQLEN{Name: devName, Value: int32(t.TXQueueLen)} if err = ioctl(fd, unix.SIOCGIFFLAGS, uintptr(unsafe.Pointer(&ifrf))); err != nil {
if err = ioctl(fd, unix.SIOCSIFTXQLEN, uintptr(unsafe.Pointer(&ifrq))); err != nil { return fmt.Errorf("failed to get tun flags: %s", err)
// If we can't set the queue length nebula will still work but it may lead to packet loss
l.WithError(err).Error("Failed to set tun tx queue length")
}
*/
// Bring up the interface
ifrf.Flags = ifrf.Flags | unix.IFF_UP
if err = ioctl(fd, unix.SIOCSIFFLAGS, uintptr(unsafe.Pointer(&ifrf))); err != nil {
return fmt.Errorf("failed to bring the tun device up: %s", err)
} }
routeSock, err := unix.Socket(unix.AF_ROUTE, unix.SOCK_RAW, unix.AF_UNSPEC)
if err != nil {
return fmt.Errorf("unable to create AF_ROUTE socket: %v", err)
}
defer func() {
unix.Shutdown(routeSock, unix.SHUT_RDWR)
err := unix.Close(routeSock)
if err != nil {
t.l.WithError(err).Error("failed to close AF_ROUTE socket")
}
}()
routeAddr := &netroute.Inet4Addr{}
maskAddr := &netroute.Inet4Addr{}
linkAddr, err := getLinkAddr(t.Device) linkAddr, err := getLinkAddr(t.Device)
if err != nil { if err != nil {
return err return err
@ -277,14 +200,18 @@ func (t *tun) Activate() error {
} }
t.linkAddr = linkAddr t.linkAddr = linkAddr
copy(routeAddr.IP[:], addr[:]) for _, network := range t.vpnNetworks {
copy(maskAddr.IP[:], mask[:]) if network.Addr().Is4() {
err = addRoute(routeSock, routeAddr, maskAddr, linkAddr) err = t.activate4(network)
if err != nil { if err != nil {
if errors.Is(err, unix.EEXIST) { return err
err = fmt.Errorf("unable to add tun route, identical route already exists: %s", t.cidr) }
} else {
err = t.activate6(network)
if err != nil {
return err
}
} }
return err
} }
// Run the interface // Run the interface
@ -297,8 +224,89 @@ func (t *tun) Activate() error {
return t.addRoutes(false) return t.addRoutes(false)
} }
func (t *tun) activate4(network netip.Prefix) error {
s, err := unix.Socket(
unix.AF_INET,
unix.SOCK_DGRAM,
unix.IPPROTO_IP,
)
if err != nil {
return err
}
defer unix.Close(s)
ifr := ifreqAlias4{
Name: t.deviceBytes(),
Addr: unix.RawSockaddrInet4{
Len: unix.SizeofSockaddrInet4,
Family: unix.AF_INET,
Addr: network.Addr().As4(),
},
DstAddr: unix.RawSockaddrInet4{
Len: unix.SizeofSockaddrInet4,
Family: unix.AF_INET,
Addr: network.Addr().As4(),
},
MaskAddr: unix.RawSockaddrInet4{
Len: unix.SizeofSockaddrInet4,
Family: unix.AF_INET,
Addr: prefixToMask(network).As4(),
},
}
if err := ioctl(uintptr(s), unix.SIOCAIFADDR, uintptr(unsafe.Pointer(&ifr))); err != nil {
return fmt.Errorf("failed to set tun v4 address: %s", err)
}
err = addRoute(network, t.linkAddr)
if err != nil {
return err
}
return nil
}
func (t *tun) activate6(network netip.Prefix) error {
s, err := unix.Socket(
unix.AF_INET6,
unix.SOCK_DGRAM,
unix.IPPROTO_IP,
)
if err != nil {
return err
}
defer unix.Close(s)
ifr := ifreqAlias6{
Name: t.deviceBytes(),
Addr: unix.RawSockaddrInet6{
Len: unix.SizeofSockaddrInet6,
Family: unix.AF_INET6,
Addr: network.Addr().As16(),
},
PrefixMask: unix.RawSockaddrInet6{
Len: unix.SizeofSockaddrInet6,
Family: unix.AF_INET6,
Addr: prefixToMask(network).As16(),
},
Lifetime: addrLifetime{
// never expires
Vltime: 0xffffffff,
Pltime: 0xffffffff,
},
//TODO: CERT-V2 should we disable DAD (duplicate address detection) and mark this as a secured address?
Flags: _IN6_IFF_NODAD,
}
if err := ioctl(uintptr(s), _SIOCAIFADDR_IN6, uintptr(unsafe.Pointer(&ifr))); err != nil {
return fmt.Errorf("failed to set tun address: %s", err)
}
return nil
}
func (t *tun) reload(c *config.C, initial bool) error { func (t *tun) reload(c *config.C, initial bool) error {
change, routes, err := getAllRoutesFromConfig(c, t.cidr, initial) change, routes, err := getAllRoutesFromConfig(c, t.vpnNetworks, initial)
if err != nil { if err != nil {
return err return err
} }
@ -343,7 +351,7 @@ func (t *tun) RouteFor(ip netip.Addr) netip.Addr {
} }
// Get the LinkAddr for the interface of the given name // Get the LinkAddr for the interface of the given name
// TODO: Is there an easier way to fetch this when we create the interface? // Is there an easier way to fetch this when we create the interface?
// Maybe SIOCGIFINDEX? but this doesn't appear to exist in the darwin headers. // Maybe SIOCGIFINDEX? but this doesn't appear to exist in the darwin headers.
func getLinkAddr(name string) (*netroute.LinkAddr, error) { func getLinkAddr(name string) (*netroute.LinkAddr, error) {
rib, err := netroute.FetchRIB(unix.AF_UNSPEC, unix.NET_RT_IFLIST, 0) rib, err := netroute.FetchRIB(unix.AF_UNSPEC, unix.NET_RT_IFLIST, 0)
@ -371,38 +379,15 @@ func getLinkAddr(name string) (*netroute.LinkAddr, error) {
} }
func (t *tun) addRoutes(logErrors bool) error { func (t *tun) addRoutes(logErrors bool) error {
routeSock, err := unix.Socket(unix.AF_ROUTE, unix.SOCK_RAW, unix.AF_UNSPEC)
if err != nil {
return fmt.Errorf("unable to create AF_ROUTE socket: %v", err)
}
defer func() {
unix.Shutdown(routeSock, unix.SHUT_RDWR)
err := unix.Close(routeSock)
if err != nil {
t.l.WithError(err).Error("failed to close AF_ROUTE socket")
}
}()
routeAddr := &netroute.Inet4Addr{}
maskAddr := &netroute.Inet4Addr{}
routes := *t.Routes.Load() routes := *t.Routes.Load()
for _, r := range routes { for _, r := range routes {
if !r.Via.IsValid() || !r.Install { if !r.Via.IsValid() || !r.Install {
// We don't allow route MTUs so only install routes with a via // We don't allow route MTUs so only install routes with a via
continue continue
} }
if !r.Cidr.Addr().Is4() { err := addRoute(r.Cidr, t.linkAddr)
//TODO: implement ipv6
panic("Cant handle ipv6 routes yet")
}
routeAddr.IP = r.Cidr.Addr().As4()
//TODO: we could avoid the copy
copy(maskAddr.IP[:], prefixToMask(r.Cidr))
err := addRoute(routeSock, routeAddr, maskAddr, t.linkAddr)
if err != nil { if err != nil {
if errors.Is(err, unix.EEXIST) { if errors.Is(err, unix.EEXIST) {
t.l.WithField("route", r.Cidr). t.l.WithField("route", r.Cidr).
@ -424,36 +409,12 @@ func (t *tun) addRoutes(logErrors bool) error {
} }
func (t *tun) removeRoutes(routes []Route) error { func (t *tun) removeRoutes(routes []Route) error {
routeSock, err := unix.Socket(unix.AF_ROUTE, unix.SOCK_RAW, unix.AF_UNSPEC)
if err != nil {
return fmt.Errorf("unable to create AF_ROUTE socket: %v", err)
}
defer func() {
unix.Shutdown(routeSock, unix.SHUT_RDWR)
err := unix.Close(routeSock)
if err != nil {
t.l.WithError(err).Error("failed to close AF_ROUTE socket")
}
}()
routeAddr := &netroute.Inet4Addr{}
maskAddr := &netroute.Inet4Addr{}
for _, r := range routes { for _, r := range routes {
if !r.Install { if !r.Install {
continue continue
} }
if r.Cidr.Addr().Is6() { err := delRoute(r.Cidr, t.linkAddr)
//TODO: implement ipv6
panic("Cant handle ipv6 routes yet")
}
routeAddr.IP = r.Cidr.Addr().As4()
copy(maskAddr.IP[:], prefixToMask(r.Cidr))
err := delRoute(routeSock, routeAddr, maskAddr, t.linkAddr)
if err != nil { if err != nil {
t.l.WithError(err).WithField("route", r).Error("Failed to remove route") t.l.WithError(err).WithField("route", r).Error("Failed to remove route")
} else { } else {
@ -463,23 +424,39 @@ func (t *tun) removeRoutes(routes []Route) error {
return nil return nil
} }
func addRoute(sock int, addr, mask *netroute.Inet4Addr, link *netroute.LinkAddr) error { func addRoute(prefix netip.Prefix, gateway netroute.Addr) error {
r := netroute.RouteMessage{ sock, err := unix.Socket(unix.AF_ROUTE, unix.SOCK_RAW, unix.AF_UNSPEC)
if err != nil {
return fmt.Errorf("unable to create AF_ROUTE socket: %v", err)
}
defer unix.Close(sock)
route := &netroute.RouteMessage{
Version: unix.RTM_VERSION, Version: unix.RTM_VERSION,
Type: unix.RTM_ADD, Type: unix.RTM_ADD,
Flags: unix.RTF_UP, Flags: unix.RTF_UP,
Seq: 1, Seq: 1,
Addrs: []netroute.Addr{
unix.RTAX_DST: addr,
unix.RTAX_GATEWAY: link,
unix.RTAX_NETMASK: mask,
},
} }
data, err := r.Marshal() if prefix.Addr().Is4() {
route.Addrs = []netroute.Addr{
unix.RTAX_DST: &netroute.Inet4Addr{IP: prefix.Masked().Addr().As4()},
unix.RTAX_NETMASK: &netroute.Inet4Addr{IP: prefixToMask(prefix).As4()},
unix.RTAX_GATEWAY: gateway,
}
} else {
route.Addrs = []netroute.Addr{
unix.RTAX_DST: &netroute.Inet6Addr{IP: prefix.Masked().Addr().As16()},
unix.RTAX_NETMASK: &netroute.Inet6Addr{IP: prefixToMask(prefix).As16()},
unix.RTAX_GATEWAY: gateway,
}
}
data, err := route.Marshal()
if err != nil { if err != nil {
return fmt.Errorf("failed to create route.RouteMessage: %w", err) return fmt.Errorf("failed to create route.RouteMessage: %w", err)
} }
_, err = unix.Write(sock, data[:]) _, err = unix.Write(sock, data[:])
if err != nil { if err != nil {
return fmt.Errorf("failed to write route.RouteMessage to socket: %w", err) return fmt.Errorf("failed to write route.RouteMessage to socket: %w", err)
@ -488,19 +465,34 @@ func addRoute(sock int, addr, mask *netroute.Inet4Addr, link *netroute.LinkAddr)
return nil return nil
} }
func delRoute(sock int, addr, mask *netroute.Inet4Addr, link *netroute.LinkAddr) error { func delRoute(prefix netip.Prefix, gateway netroute.Addr) error {
r := netroute.RouteMessage{ sock, err := unix.Socket(unix.AF_ROUTE, unix.SOCK_RAW, unix.AF_UNSPEC)
if err != nil {
return fmt.Errorf("unable to create AF_ROUTE socket: %v", err)
}
defer unix.Close(sock)
route := netroute.RouteMessage{
Version: unix.RTM_VERSION, Version: unix.RTM_VERSION,
Type: unix.RTM_DELETE, Type: unix.RTM_DELETE,
Seq: 1, Seq: 1,
Addrs: []netroute.Addr{
unix.RTAX_DST: addr,
unix.RTAX_GATEWAY: link,
unix.RTAX_NETMASK: mask,
},
} }
data, err := r.Marshal() if prefix.Addr().Is4() {
route.Addrs = []netroute.Addr{
unix.RTAX_DST: &netroute.Inet4Addr{IP: prefix.Masked().Addr().As4()},
unix.RTAX_NETMASK: &netroute.Inet4Addr{IP: prefixToMask(prefix).As4()},
unix.RTAX_GATEWAY: gateway,
}
} else {
route.Addrs = []netroute.Addr{
unix.RTAX_DST: &netroute.Inet6Addr{IP: prefix.Masked().Addr().As16()},
unix.RTAX_NETMASK: &netroute.Inet6Addr{IP: prefixToMask(prefix).As16()},
unix.RTAX_GATEWAY: gateway,
}
}
data, err := route.Marshal()
if err != nil { if err != nil {
return fmt.Errorf("failed to create route.RouteMessage: %w", err) return fmt.Errorf("failed to create route.RouteMessage: %w", err)
} }
@ -513,7 +505,6 @@ func delRoute(sock int, addr, mask *netroute.Inet4Addr, link *netroute.LinkAddr)
} }
func (t *tun) Read(to []byte) (int, error) { func (t *tun) Read(to []byte) (int, error) {
buf := make([]byte, len(to)+4) buf := make([]byte, len(to)+4)
n, err := t.ReadWriteCloser.Read(buf) n, err := t.ReadWriteCloser.Read(buf)
@ -551,8 +542,8 @@ func (t *tun) Write(from []byte) (int, error) {
return n - 4, err return n - 4, err
} }
func (t *tun) Cidr() netip.Prefix { func (t *tun) Networks() []netip.Prefix {
return t.cidr return t.vpnNetworks
} }
func (t *tun) Name() string { func (t *tun) Name() string {
@ -563,10 +554,12 @@ func (t *tun) NewMultiQueueReader() (io.ReadWriteCloser, error) {
return nil, fmt.Errorf("TODO: multiqueue not implemented for darwin") return nil, fmt.Errorf("TODO: multiqueue not implemented for darwin")
} }
func prefixToMask(prefix netip.Prefix) []byte { func prefixToMask(prefix netip.Prefix) netip.Addr {
pLen := 128 pLen := 128
if prefix.Addr().Is4() { if prefix.Addr().Is4() {
pLen = 32 pLen = 32
} }
return net.CIDRMask(prefix.Bits(), pLen)
addr, _ := netip.AddrFromSlice(net.CIDRMask(prefix.Bits(), pLen))
return addr
} }

View File

@ -12,8 +12,8 @@ import (
) )
type disabledTun struct { type disabledTun struct {
read chan []byte read chan []byte
cidr netip.Prefix vpnNetworks []netip.Prefix
// Track these metrics since we don't have the tun device to do it for us // Track these metrics since we don't have the tun device to do it for us
tx metrics.Counter tx metrics.Counter
@ -21,11 +21,11 @@ type disabledTun struct {
l *logrus.Logger l *logrus.Logger
} }
func newDisabledTun(cidr netip.Prefix, queueLen int, metricsEnabled bool, l *logrus.Logger) *disabledTun { func newDisabledTun(vpnNetworks []netip.Prefix, queueLen int, metricsEnabled bool, l *logrus.Logger) *disabledTun {
tun := &disabledTun{ tun := &disabledTun{
cidr: cidr, vpnNetworks: vpnNetworks,
read: make(chan []byte, queueLen), read: make(chan []byte, queueLen),
l: l, l: l,
} }
if metricsEnabled { if metricsEnabled {
@ -47,8 +47,8 @@ func (*disabledTun) RouteFor(addr netip.Addr) netip.Addr {
return netip.Addr{} return netip.Addr{}
} }
func (t *disabledTun) Cidr() netip.Prefix { func (t *disabledTun) Networks() []netip.Prefix {
return t.cidr return t.vpnNetworks
} }
func (*disabledTun) Name() string { func (*disabledTun) Name() string {

View File

@ -46,12 +46,12 @@ type ifreqDestroy struct {
} }
type tun struct { type tun struct {
Device string Device string
cidr netip.Prefix vpnNetworks []netip.Prefix
MTU int MTU int
Routes atomic.Pointer[[]Route] Routes atomic.Pointer[[]Route]
routeTree atomic.Pointer[bart.Table[netip.Addr]] routeTree atomic.Pointer[bart.Table[netip.Addr]]
l *logrus.Logger l *logrus.Logger
io.ReadWriteCloser io.ReadWriteCloser
} }
@ -78,11 +78,11 @@ func (t *tun) Close() error {
return nil return nil
} }
func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ netip.Prefix) (*tun, error) { func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ []netip.Prefix) (*tun, error) {
return nil, fmt.Errorf("newTunFromFd not supported in FreeBSD") return nil, fmt.Errorf("newTunFromFd not supported in FreeBSD")
} }
func newTun(c *config.C, l *logrus.Logger, cidr netip.Prefix, _ bool) (*tun, error) { func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, _ bool) (*tun, error) {
// Try to open existing tun device // Try to open existing tun device
var file *os.File var file *os.File
var err error var err error
@ -150,7 +150,7 @@ func newTun(c *config.C, l *logrus.Logger, cidr netip.Prefix, _ bool) (*tun, err
t := &tun{ t := &tun{
ReadWriteCloser: file, ReadWriteCloser: file,
Device: deviceName, Device: deviceName,
cidr: cidr, vpnNetworks: vpnNetworks,
MTU: c.GetInt("tun.mtu", DefaultMTU), MTU: c.GetInt("tun.mtu", DefaultMTU),
l: l, l: l,
} }
@ -170,16 +170,16 @@ func newTun(c *config.C, l *logrus.Logger, cidr netip.Prefix, _ bool) (*tun, err
return t, nil return t, nil
} }
func (t *tun) Activate() error { func (t *tun) addIp(cidr netip.Prefix) error {
var err error var err error
// TODO use syscalls instead of exec.Command // TODO use syscalls instead of exec.Command
cmd := exec.Command("/sbin/ifconfig", t.Device, t.cidr.String(), t.cidr.Addr().String()) cmd := exec.Command("/sbin/ifconfig", t.Device, cidr.String(), cidr.Addr().String())
t.l.Debug("command: ", cmd.String()) t.l.Debug("command: ", cmd.String())
if err = cmd.Run(); err != nil { if err = cmd.Run(); err != nil {
return fmt.Errorf("failed to run 'ifconfig': %s", err) return fmt.Errorf("failed to run 'ifconfig': %s", err)
} }
cmd = exec.Command("/sbin/route", "-n", "add", "-net", t.cidr.String(), "-interface", t.Device) cmd = exec.Command("/sbin/route", "-n", "add", "-net", cidr.String(), "-interface", t.Device)
t.l.Debug("command: ", cmd.String()) t.l.Debug("command: ", cmd.String())
if err = cmd.Run(); err != nil { if err = cmd.Run(); err != nil {
return fmt.Errorf("failed to run 'route add': %s", err) return fmt.Errorf("failed to run 'route add': %s", err)
@ -195,8 +195,18 @@ func (t *tun) Activate() error {
return t.addRoutes(false) return t.addRoutes(false)
} }
func (t *tun) Activate() error {
for i := range t.vpnNetworks {
err := t.addIp(t.vpnNetworks[i])
if err != nil {
return err
}
}
return nil
}
func (t *tun) reload(c *config.C, initial bool) error { func (t *tun) reload(c *config.C, initial bool) error {
change, routes, err := getAllRoutesFromConfig(c, t.cidr, initial) change, routes, err := getAllRoutesFromConfig(c, t.vpnNetworks, initial)
if err != nil { if err != nil {
return err return err
} }
@ -237,8 +247,8 @@ func (t *tun) RouteFor(ip netip.Addr) netip.Addr {
return r return r
} }
func (t *tun) Cidr() netip.Prefix { func (t *tun) Networks() []netip.Prefix {
return t.cidr return t.vpnNetworks
} }
func (t *tun) Name() string { func (t *tun) Name() string {

View File

@ -21,20 +21,20 @@ import (
type tun struct { type tun struct {
io.ReadWriteCloser io.ReadWriteCloser
cidr netip.Prefix vpnNetworks []netip.Prefix
Routes atomic.Pointer[[]Route] Routes atomic.Pointer[[]Route]
routeTree atomic.Pointer[bart.Table[netip.Addr]] routeTree atomic.Pointer[bart.Table[netip.Addr]]
l *logrus.Logger l *logrus.Logger
} }
func newTun(_ *config.C, _ *logrus.Logger, _ netip.Prefix, _ bool) (*tun, error) { func newTun(_ *config.C, _ *logrus.Logger, _ []netip.Prefix, _ bool) (*tun, error) {
return nil, fmt.Errorf("newTun not supported in iOS") return nil, fmt.Errorf("newTun not supported in iOS")
} }
func newTunFromFd(c *config.C, l *logrus.Logger, deviceFd int, cidr netip.Prefix) (*tun, error) { func newTunFromFd(c *config.C, l *logrus.Logger, deviceFd int, vpnNetworks []netip.Prefix) (*tun, error) {
file := os.NewFile(uintptr(deviceFd), "/dev/tun") file := os.NewFile(uintptr(deviceFd), "/dev/tun")
t := &tun{ t := &tun{
cidr: cidr, vpnNetworks: vpnNetworks,
ReadWriteCloser: &tunReadCloser{f: file}, ReadWriteCloser: &tunReadCloser{f: file},
l: l, l: l,
} }
@ -59,7 +59,7 @@ func (t *tun) Activate() error {
} }
func (t *tun) reload(c *config.C, initial bool) error { func (t *tun) reload(c *config.C, initial bool) error {
change, routes, err := getAllRoutesFromConfig(c, t.cidr, initial) change, routes, err := getAllRoutesFromConfig(c, t.vpnNetworks, initial)
if err != nil { if err != nil {
return err return err
} }
@ -142,8 +142,8 @@ func (tr *tunReadCloser) Close() error {
return tr.f.Close() return tr.f.Close()
} }
func (t *tun) Cidr() netip.Prefix { func (t *tun) Networks() []netip.Prefix {
return t.cidr return t.vpnNetworks
} }
func (t *tun) Name() string { func (t *tun) Name() string {

View File

@ -11,6 +11,7 @@ import (
"os" "os"
"strings" "strings"
"sync/atomic" "sync/atomic"
"time"
"unsafe" "unsafe"
"github.com/gaissmai/bart" "github.com/gaissmai/bart"
@ -25,7 +26,7 @@ type tun struct {
io.ReadWriteCloser io.ReadWriteCloser
fd int fd int
Device string Device string
cidr netip.Prefix vpnNetworks []netip.Prefix
MaxMTU int MaxMTU int
DefaultMTU int DefaultMTU int
TXQueueLen int TXQueueLen int
@ -40,18 +41,16 @@ type tun struct {
l *logrus.Logger l *logrus.Logger
} }
func (t *tun) Networks() []netip.Prefix {
return t.vpnNetworks
}
type ifReq struct { type ifReq struct {
Name [16]byte Name [16]byte
Flags uint16 Flags uint16
pad [8]byte pad [8]byte
} }
type ifreqAddr struct {
Name [16]byte
Addr unix.RawSockaddrInet4
pad [8]byte
}
type ifreqMTU struct { type ifreqMTU struct {
Name [16]byte Name [16]byte
MTU int32 MTU int32
@ -64,10 +63,10 @@ type ifreqQLEN struct {
pad [8]byte pad [8]byte
} }
func newTunFromFd(c *config.C, l *logrus.Logger, deviceFd int, cidr netip.Prefix) (*tun, error) { func newTunFromFd(c *config.C, l *logrus.Logger, deviceFd int, vpnNetworks []netip.Prefix) (*tun, error) {
file := os.NewFile(uintptr(deviceFd), "/dev/net/tun") file := os.NewFile(uintptr(deviceFd), "/dev/net/tun")
t, err := newTunGeneric(c, l, file, cidr) t, err := newTunGeneric(c, l, file, vpnNetworks)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -77,7 +76,7 @@ func newTunFromFd(c *config.C, l *logrus.Logger, deviceFd int, cidr netip.Prefix
return t, nil return t, nil
} }
func newTun(c *config.C, l *logrus.Logger, cidr netip.Prefix, multiqueue bool) (*tun, error) { func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, multiqueue bool) (*tun, error) {
fd, err := unix.Open("/dev/net/tun", os.O_RDWR, 0) fd, err := unix.Open("/dev/net/tun", os.O_RDWR, 0)
if err != nil { if err != nil {
// If /dev/net/tun doesn't exist, try to create it (will happen in docker) // If /dev/net/tun doesn't exist, try to create it (will happen in docker)
@ -112,7 +111,7 @@ func newTun(c *config.C, l *logrus.Logger, cidr netip.Prefix, multiqueue bool) (
name := strings.Trim(string(req.Name[:]), "\x00") name := strings.Trim(string(req.Name[:]), "\x00")
file := os.NewFile(uintptr(fd), "/dev/net/tun") file := os.NewFile(uintptr(fd), "/dev/net/tun")
t, err := newTunGeneric(c, l, file, cidr) t, err := newTunGeneric(c, l, file, vpnNetworks)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -122,11 +121,11 @@ func newTun(c *config.C, l *logrus.Logger, cidr netip.Prefix, multiqueue bool) (
return t, nil return t, nil
} }
func newTunGeneric(c *config.C, l *logrus.Logger, file *os.File, cidr netip.Prefix) (*tun, error) { func newTunGeneric(c *config.C, l *logrus.Logger, file *os.File, vpnNetworks []netip.Prefix) (*tun, error) {
t := &tun{ t := &tun{
ReadWriteCloser: file, ReadWriteCloser: file,
fd: int(file.Fd()), fd: int(file.Fd()),
cidr: cidr, vpnNetworks: vpnNetworks,
TXQueueLen: c.GetInt("tun.tx_queue", 500), TXQueueLen: c.GetInt("tun.tx_queue", 500),
useSystemRoutes: c.GetBool("tun.use_system_route_table", false), useSystemRoutes: c.GetBool("tun.use_system_route_table", false),
l: l, l: l,
@ -148,7 +147,7 @@ func newTunGeneric(c *config.C, l *logrus.Logger, file *os.File, cidr netip.Pref
} }
func (t *tun) reload(c *config.C, initial bool) error { func (t *tun) reload(c *config.C, initial bool) error {
routeChange, routes, err := getAllRoutesFromConfig(c, t.cidr, initial) routeChange, routes, err := getAllRoutesFromConfig(c, t.vpnNetworks, initial)
if err != nil { if err != nil {
return err return err
} }
@ -190,11 +189,13 @@ func (t *tun) reload(c *config.C, initial bool) error {
} }
if oldDefaultMTU != newDefaultMTU { if oldDefaultMTU != newDefaultMTU {
err := t.setDefaultRoute() for i := range t.vpnNetworks {
if err != nil { err := t.setDefaultRoute(t.vpnNetworks[i])
t.l.Warn(err) if err != nil {
} else { t.l.Warn(err)
t.l.Infof("Set default MTU to %v was %v", t.DefaultMTU, oldDefaultMTU) } else {
t.l.Infof("Set default MTU to %v was %v", t.DefaultMTU, oldDefaultMTU)
}
} }
} }
@ -237,10 +238,10 @@ func (t *tun) RouteFor(ip netip.Addr) netip.Addr {
func (t *tun) Write(b []byte) (int, error) { func (t *tun) Write(b []byte) (int, error) {
var nn int var nn int
max := len(b) maximum := len(b)
for { for {
n, err := unix.Write(t.fd, b[nn:max]) n, err := unix.Write(t.fd, b[nn:maximum])
if n > 0 { if n > 0 {
nn += n nn += n
} }
@ -265,6 +266,58 @@ func (t *tun) deviceBytes() (o [16]byte) {
return return
} }
func hasNetlinkAddr(al []*netlink.Addr, x netlink.Addr) bool {
for i := range al {
if al[i].Equal(x) {
return true
}
}
return false
}
// addIPs uses netlink to add all addresses that don't exist, then it removes ones that should not be there
func (t *tun) addIPs(link netlink.Link) error {
newAddrs := make([]*netlink.Addr, len(t.vpnNetworks))
for i := range t.vpnNetworks {
newAddrs[i] = &netlink.Addr{
IPNet: &net.IPNet{
IP: t.vpnNetworks[i].Addr().AsSlice(),
Mask: net.CIDRMask(t.vpnNetworks[i].Bits(), t.vpnNetworks[i].Addr().BitLen()),
},
Label: t.vpnNetworks[i].Addr().Zone(),
}
}
//add all new addresses
for i := range newAddrs {
//TODO: CERT-V2 do we want to stack errors and try as many ops as possible?
//AddrReplace still adds new IPs, but if their properties change it will change them as well
if err := netlink.AddrReplace(link, newAddrs[i]); err != nil {
return err
}
}
//iterate over remainder, remove whoever shouldn't be there
al, err := netlink.AddrList(link, netlink.FAMILY_ALL)
if err != nil {
return fmt.Errorf("failed to get tun address list: %s", err)
}
for i := range al {
if hasNetlinkAddr(newAddrs, al[i]) {
continue
}
err = netlink.AddrDel(link, &al[i])
if err != nil {
t.l.WithError(err).Error("failed to remove address from tun address list")
} else {
t.l.WithField("removed", al[i].String()).Info("removed address not listed in cert(s)")
}
}
return nil
}
func (t *tun) Activate() error { func (t *tun) Activate() error {
devName := t.deviceBytes() devName := t.deviceBytes()
@ -272,15 +325,8 @@ func (t *tun) Activate() error {
t.watchRoutes() t.watchRoutes()
} }
var addr, mask [4]byte
//TODO: IPV6-WORK
addr = t.cidr.Addr().As4()
tmask := net.CIDRMask(t.cidr.Bits(), 32)
copy(mask[:], tmask)
s, err := unix.Socket( s, err := unix.Socket(
unix.AF_INET, unix.AF_INET, //because everything we use t.ioctlFd for is address family independent, this is fine
unix.SOCK_DGRAM, unix.SOCK_DGRAM,
unix.IPPROTO_IP, unix.IPPROTO_IP,
) )
@ -289,31 +335,19 @@ func (t *tun) Activate() error {
} }
t.ioctlFd = uintptr(s) t.ioctlFd = uintptr(s)
ifra := ifreqAddr{
Name: devName,
Addr: unix.RawSockaddrInet4{
Family: unix.AF_INET,
Addr: addr,
},
}
// Set the device ip address
if err = ioctl(t.ioctlFd, unix.SIOCSIFADDR, uintptr(unsafe.Pointer(&ifra))); err != nil {
return fmt.Errorf("failed to set tun address: %s", err)
}
// Set the device network
ifra.Addr.Addr = mask
if err = ioctl(t.ioctlFd, unix.SIOCSIFNETMASK, uintptr(unsafe.Pointer(&ifra))); err != nil {
return fmt.Errorf("failed to set tun netmask: %s", err)
}
// Set the device name // Set the device name
ifrf := ifReq{Name: devName} ifrf := ifReq{Name: devName}
if err = ioctl(t.ioctlFd, unix.SIOCGIFFLAGS, uintptr(unsafe.Pointer(&ifrf))); err != nil { if err = ioctl(t.ioctlFd, unix.SIOCGIFFLAGS, uintptr(unsafe.Pointer(&ifrf))); err != nil {
return fmt.Errorf("failed to set tun device name: %s", err) return fmt.Errorf("failed to set tun device name: %s", err)
} }
link, err := netlink.LinkByName(t.Device)
if err != nil {
return fmt.Errorf("failed to get tun device link: %s", err)
}
t.deviceIndex = link.Attrs().Index
// Setup our default MTU // Setup our default MTU
t.setMTU() t.setMTU()
@ -324,20 +358,21 @@ func (t *tun) Activate() error {
t.l.WithError(err).Error("Failed to set tun tx queue length") t.l.WithError(err).Error("Failed to set tun tx queue length")
} }
if err = t.addIPs(link); err != nil {
return err
}
// Bring up the interface // Bring up the interface
ifrf.Flags = ifrf.Flags | unix.IFF_UP ifrf.Flags = ifrf.Flags | unix.IFF_UP
if err = ioctl(t.ioctlFd, unix.SIOCSIFFLAGS, uintptr(unsafe.Pointer(&ifrf))); err != nil { if err = ioctl(t.ioctlFd, unix.SIOCSIFFLAGS, uintptr(unsafe.Pointer(&ifrf))); err != nil {
return fmt.Errorf("failed to bring the tun device up: %s", err) return fmt.Errorf("failed to bring the tun device up: %s", err)
} }
link, err := netlink.LinkByName(t.Device) //set route MTU
if err != nil { for i := range t.vpnNetworks {
return fmt.Errorf("failed to get tun device link: %s", err) if err = t.setDefaultRoute(t.vpnNetworks[i]); err != nil {
} return fmt.Errorf("failed to set default route MTU: %w", err)
t.deviceIndex = link.Attrs().Index }
if err = t.setDefaultRoute(); err != nil {
return err
} }
// Set the routes // Set the routes
@ -363,12 +398,10 @@ func (t *tun) setMTU() {
} }
} }
func (t *tun) setDefaultRoute() error { func (t *tun) setDefaultRoute(cidr netip.Prefix) error {
// Default route
dr := &net.IPNet{ dr := &net.IPNet{
IP: t.cidr.Masked().Addr().AsSlice(), IP: cidr.Masked().Addr().AsSlice(),
Mask: net.CIDRMask(t.cidr.Bits(), t.cidr.Addr().BitLen()), Mask: net.CIDRMask(cidr.Bits(), cidr.Addr().BitLen()),
} }
nr := netlink.Route{ nr := netlink.Route{
@ -377,14 +410,27 @@ func (t *tun) setDefaultRoute() error {
MTU: t.DefaultMTU, MTU: t.DefaultMTU,
AdvMSS: t.advMSS(Route{}), AdvMSS: t.advMSS(Route{}),
Scope: unix.RT_SCOPE_LINK, Scope: unix.RT_SCOPE_LINK,
Src: net.IP(t.cidr.Addr().AsSlice()), Src: net.IP(cidr.Addr().AsSlice()),
Protocol: unix.RTPROT_KERNEL, Protocol: unix.RTPROT_KERNEL,
Table: unix.RT_TABLE_MAIN, Table: unix.RT_TABLE_MAIN,
Type: unix.RTN_UNICAST, Type: unix.RTN_UNICAST,
} }
err := netlink.RouteReplace(&nr) err := netlink.RouteReplace(&nr)
if err != nil { if err != nil {
return fmt.Errorf("failed to set mtu %v on the default route %v; %v", t.DefaultMTU, dr, err) t.l.WithError(err).WithField("cidr", cidr).Warn("Failed to set default route MTU, retrying")
//retry twice more -- on some systems there appears to be a race condition where if we set routes too soon, netlink says `invalid argument`
for i := 0; i < 2; i++ {
time.Sleep(100 * time.Millisecond)
err = netlink.RouteReplace(&nr)
if err == nil {
break
} else {
t.l.WithError(err).WithField("cidr", cidr).WithField("mtu", t.DefaultMTU).Warn("Failed to set default route MTU, retrying")
}
}
if err != nil {
return fmt.Errorf("failed to set mtu %v on the default route %v; %v", t.DefaultMTU, dr, err)
}
} }
return nil return nil
@ -463,10 +509,6 @@ func (t *tun) removeRoutes(routes []Route) {
} }
} }
func (t *tun) Cidr() netip.Prefix {
return t.cidr
}
func (t *tun) Name() string { func (t *tun) Name() string {
return t.Device return t.Device
} }
@ -515,7 +557,6 @@ func (t *tun) updateRoutes(r netlink.RouteUpdate) {
return return
} }
//TODO: IPV6-WORK what if not ok?
gwAddr, ok := netip.AddrFromSlice(r.Gw) gwAddr, ok := netip.AddrFromSlice(r.Gw)
if !ok { if !ok {
t.l.WithField("route", r).Debug("Ignoring route update, invalid gateway address") t.l.WithField("route", r).Debug("Ignoring route update, invalid gateway address")
@ -523,15 +564,16 @@ func (t *tun) updateRoutes(r netlink.RouteUpdate) {
} }
gwAddr = gwAddr.Unmap() gwAddr = gwAddr.Unmap()
if !t.cidr.Contains(gwAddr) { withinNetworks := false
// Gateway isn't in our overlay network, ignore for i := range t.vpnNetworks {
t.l.WithField("route", r).Debug("Ignoring route update, not in our network") if t.vpnNetworks[i].Contains(gwAddr) {
return withinNetworks = true
break
}
} }
if !withinNetworks {
if x := r.Dst.IP.To4(); x == nil { // Gateway isn't in our overlay network, ignore
// Nebula only handles ipv4 on the overlay currently t.l.WithField("route", r).Debug("Ignoring route update, not in our networks")
t.l.WithField("route", r).Debug("Ignoring route update, destination is not ipv4")
return return
} }
@ -563,11 +605,11 @@ func (t *tun) Close() error {
} }
if t.ReadWriteCloser != nil { if t.ReadWriteCloser != nil {
t.ReadWriteCloser.Close() _ = t.ReadWriteCloser.Close()
} }
if t.ioctlFd > 0 { if t.ioctlFd > 0 {
os.NewFile(t.ioctlFd, "ioctlFd").Close() _ = os.NewFile(t.ioctlFd, "ioctlFd").Close()
} }
return nil return nil

View File

@ -27,12 +27,12 @@ type ifreqDestroy struct {
} }
type tun struct { type tun struct {
Device string Device string
cidr netip.Prefix vpnNetworks []netip.Prefix
MTU int MTU int
Routes atomic.Pointer[[]Route] Routes atomic.Pointer[[]Route]
routeTree atomic.Pointer[bart.Table[netip.Addr]] routeTree atomic.Pointer[bart.Table[netip.Addr]]
l *logrus.Logger l *logrus.Logger
io.ReadWriteCloser io.ReadWriteCloser
} }
@ -58,13 +58,13 @@ func (t *tun) Close() error {
return nil return nil
} }
func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ netip.Prefix) (*tun, error) { func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ []netip.Prefix) (*tun, error) {
return nil, fmt.Errorf("newTunFromFd not supported in NetBSD") return nil, fmt.Errorf("newTunFromFd not supported in NetBSD")
} }
var deviceNameRE = regexp.MustCompile(`^tun[0-9]+$`) var deviceNameRE = regexp.MustCompile(`^tun[0-9]+$`)
func newTun(c *config.C, l *logrus.Logger, cidr netip.Prefix, _ bool) (*tun, error) { func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, _ bool) (*tun, error) {
// Try to open tun device // Try to open tun device
var file *os.File var file *os.File
var err error var err error
@ -84,7 +84,7 @@ func newTun(c *config.C, l *logrus.Logger, cidr netip.Prefix, _ bool) (*tun, err
t := &tun{ t := &tun{
ReadWriteCloser: file, ReadWriteCloser: file,
Device: deviceName, Device: deviceName,
cidr: cidr, vpnNetworks: vpnNetworks,
MTU: c.GetInt("tun.mtu", DefaultMTU), MTU: c.GetInt("tun.mtu", DefaultMTU),
l: l, l: l,
} }
@ -104,17 +104,17 @@ func newTun(c *config.C, l *logrus.Logger, cidr netip.Prefix, _ bool) (*tun, err
return t, nil return t, nil
} }
func (t *tun) Activate() error { func (t *tun) addIp(cidr netip.Prefix) error {
var err error var err error
// TODO use syscalls instead of exec.Command // TODO use syscalls instead of exec.Command
cmd := exec.Command("/sbin/ifconfig", t.Device, t.cidr.String(), t.cidr.Addr().String()) cmd := exec.Command("/sbin/ifconfig", t.Device, cidr.String(), cidr.Addr().String())
t.l.Debug("command: ", cmd.String()) t.l.Debug("command: ", cmd.String())
if err = cmd.Run(); err != nil { if err = cmd.Run(); err != nil {
return fmt.Errorf("failed to run 'ifconfig': %s", err) return fmt.Errorf("failed to run 'ifconfig': %s", err)
} }
cmd = exec.Command("/sbin/route", "-n", "add", "-net", t.cidr.String(), t.cidr.Addr().String()) cmd = exec.Command("/sbin/route", "-n", "add", "-net", cidr.String(), cidr.Addr().String())
t.l.Debug("command: ", cmd.String()) t.l.Debug("command: ", cmd.String())
if err = cmd.Run(); err != nil { if err = cmd.Run(); err != nil {
return fmt.Errorf("failed to run 'route add': %s", err) return fmt.Errorf("failed to run 'route add': %s", err)
@ -130,8 +130,18 @@ func (t *tun) Activate() error {
return t.addRoutes(false) return t.addRoutes(false)
} }
func (t *tun) Activate() error {
for i := range t.vpnNetworks {
err := t.addIp(t.vpnNetworks[i])
if err != nil {
return err
}
}
return nil
}
func (t *tun) reload(c *config.C, initial bool) error { func (t *tun) reload(c *config.C, initial bool) error {
change, routes, err := getAllRoutesFromConfig(c, t.cidr, initial) change, routes, err := getAllRoutesFromConfig(c, t.vpnNetworks, initial)
if err != nil { if err != nil {
return err return err
} }
@ -172,8 +182,8 @@ func (t *tun) RouteFor(ip netip.Addr) netip.Addr {
return r return r
} }
func (t *tun) Cidr() netip.Prefix { func (t *tun) Networks() []netip.Prefix {
return t.cidr return t.vpnNetworks
} }
func (t *tun) Name() string { func (t *tun) Name() string {
@ -192,7 +202,7 @@ func (t *tun) addRoutes(logErrors bool) error {
continue continue
} }
cmd := exec.Command("/sbin/route", "-n", "add", "-net", r.Cidr.String(), t.cidr.Addr().String()) cmd := exec.Command("/sbin/route", "-n", "add", "-net", r.Cidr.String(), t.vpnNetworks[0].Addr().String())
t.l.Debug("command: ", cmd.String()) t.l.Debug("command: ", cmd.String())
if err := cmd.Run(); err != nil { if err := cmd.Run(); err != nil {
retErr := util.NewContextualError("failed to run 'route add' for unsafe_route", map[string]interface{}{"route": r}, err) retErr := util.NewContextualError("failed to run 'route add' for unsafe_route", map[string]interface{}{"route": r}, err)
@ -213,7 +223,8 @@ func (t *tun) removeRoutes(routes []Route) error {
continue continue
} }
cmd := exec.Command("/sbin/route", "-n", "delete", "-net", r.Cidr.String(), t.cidr.Addr().String()) //TODO: CERT-V2 is this right?
cmd := exec.Command("/sbin/route", "-n", "delete", "-net", r.Cidr.String(), t.vpnNetworks[0].Addr().String())
t.l.Debug("command: ", cmd.String()) t.l.Debug("command: ", cmd.String())
if err := cmd.Run(); err != nil { if err := cmd.Run(); err != nil {
t.l.WithError(err).WithField("route", r).Error("Failed to remove route") t.l.WithError(err).WithField("route", r).Error("Failed to remove route")

View File

@ -21,12 +21,12 @@ import (
) )
type tun struct { type tun struct {
Device string Device string
cidr netip.Prefix vpnNetworks []netip.Prefix
MTU int MTU int
Routes atomic.Pointer[[]Route] Routes atomic.Pointer[[]Route]
routeTree atomic.Pointer[bart.Table[netip.Addr]] routeTree atomic.Pointer[bart.Table[netip.Addr]]
l *logrus.Logger l *logrus.Logger
io.ReadWriteCloser io.ReadWriteCloser
@ -42,13 +42,13 @@ func (t *tun) Close() error {
return nil return nil
} }
func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ netip.Prefix) (*tun, error) { func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ []netip.Prefix) (*tun, error) {
return nil, fmt.Errorf("newTunFromFd not supported in OpenBSD") return nil, fmt.Errorf("newTunFromFd not supported in OpenBSD")
} }
var deviceNameRE = regexp.MustCompile(`^tun[0-9]+$`) var deviceNameRE = regexp.MustCompile(`^tun[0-9]+$`)
func newTun(c *config.C, l *logrus.Logger, cidr netip.Prefix, _ bool) (*tun, error) { func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, _ bool) (*tun, error) {
deviceName := c.GetString("tun.dev", "") deviceName := c.GetString("tun.dev", "")
if deviceName == "" { if deviceName == "" {
return nil, fmt.Errorf("a device name in the format of tunN must be specified") return nil, fmt.Errorf("a device name in the format of tunN must be specified")
@ -66,7 +66,7 @@ func newTun(c *config.C, l *logrus.Logger, cidr netip.Prefix, _ bool) (*tun, err
t := &tun{ t := &tun{
ReadWriteCloser: file, ReadWriteCloser: file,
Device: deviceName, Device: deviceName,
cidr: cidr, vpnNetworks: vpnNetworks,
MTU: c.GetInt("tun.mtu", DefaultMTU), MTU: c.GetInt("tun.mtu", DefaultMTU),
l: l, l: l,
} }
@ -87,7 +87,7 @@ func newTun(c *config.C, l *logrus.Logger, cidr netip.Prefix, _ bool) (*tun, err
} }
func (t *tun) reload(c *config.C, initial bool) error { func (t *tun) reload(c *config.C, initial bool) error {
change, routes, err := getAllRoutesFromConfig(c, t.cidr, initial) change, routes, err := getAllRoutesFromConfig(c, t.vpnNetworks, initial)
if err != nil { if err != nil {
return err return err
} }
@ -123,10 +123,10 @@ func (t *tun) reload(c *config.C, initial bool) error {
return nil return nil
} }
func (t *tun) Activate() error { func (t *tun) addIp(cidr netip.Prefix) error {
var err error var err error
// TODO use syscalls instead of exec.Command // TODO use syscalls instead of exec.Command
cmd := exec.Command("/sbin/ifconfig", t.Device, t.cidr.String(), t.cidr.Addr().String()) cmd := exec.Command("/sbin/ifconfig", t.Device, cidr.String(), cidr.Addr().String())
t.l.Debug("command: ", cmd.String()) t.l.Debug("command: ", cmd.String())
if err = cmd.Run(); err != nil { if err = cmd.Run(); err != nil {
return fmt.Errorf("failed to run 'ifconfig': %s", err) return fmt.Errorf("failed to run 'ifconfig': %s", err)
@ -138,7 +138,7 @@ func (t *tun) Activate() error {
return fmt.Errorf("failed to run 'ifconfig': %s", err) return fmt.Errorf("failed to run 'ifconfig': %s", err)
} }
cmd = exec.Command("/sbin/route", "-n", "add", "-inet", t.cidr.String(), t.cidr.Addr().String()) cmd = exec.Command("/sbin/route", "-n", "add", "-inet", cidr.String(), cidr.Addr().String())
t.l.Debug("command: ", cmd.String()) t.l.Debug("command: ", cmd.String())
if err = cmd.Run(); err != nil { if err = cmd.Run(); err != nil {
return fmt.Errorf("failed to run 'route add': %s", err) return fmt.Errorf("failed to run 'route add': %s", err)
@ -148,6 +148,16 @@ func (t *tun) Activate() error {
return t.addRoutes(false) return t.addRoutes(false)
} }
func (t *tun) Activate() error {
for i := range t.vpnNetworks {
err := t.addIp(t.vpnNetworks[i])
if err != nil {
return err
}
}
return nil
}
func (t *tun) RouteFor(ip netip.Addr) netip.Addr { func (t *tun) RouteFor(ip netip.Addr) netip.Addr {
r, _ := t.routeTree.Load().Lookup(ip) r, _ := t.routeTree.Load().Lookup(ip)
return r return r
@ -160,8 +170,8 @@ func (t *tun) addRoutes(logErrors bool) error {
// We don't allow route MTUs so only install routes with a via // We don't allow route MTUs so only install routes with a via
continue continue
} }
//TODO: CERT-V2 is this right?
cmd := exec.Command("/sbin/route", "-n", "add", "-inet", r.Cidr.String(), t.cidr.Addr().String()) cmd := exec.Command("/sbin/route", "-n", "add", "-inet", r.Cidr.String(), t.vpnNetworks[0].Addr().String())
t.l.Debug("command: ", cmd.String()) t.l.Debug("command: ", cmd.String())
if err := cmd.Run(); err != nil { if err := cmd.Run(); err != nil {
retErr := util.NewContextualError("failed to run 'route add' for unsafe_route", map[string]interface{}{"route": r}, err) retErr := util.NewContextualError("failed to run 'route add' for unsafe_route", map[string]interface{}{"route": r}, err)
@ -181,8 +191,8 @@ func (t *tun) removeRoutes(routes []Route) error {
if !r.Install { if !r.Install {
continue continue
} }
//TODO: CERT-V2 is this right?
cmd := exec.Command("/sbin/route", "-n", "delete", "-inet", r.Cidr.String(), t.cidr.Addr().String()) cmd := exec.Command("/sbin/route", "-n", "delete", "-inet", r.Cidr.String(), t.vpnNetworks[0].Addr().String())
t.l.Debug("command: ", cmd.String()) t.l.Debug("command: ", cmd.String())
if err := cmd.Run(); err != nil { if err := cmd.Run(); err != nil {
t.l.WithError(err).WithField("route", r).Error("Failed to remove route") t.l.WithError(err).WithField("route", r).Error("Failed to remove route")
@ -193,8 +203,8 @@ func (t *tun) removeRoutes(routes []Route) error {
return nil return nil
} }
func (t *tun) Cidr() netip.Prefix { func (t *tun) Networks() []netip.Prefix {
return t.cidr return t.vpnNetworks
} }
func (t *tun) Name() string { func (t *tun) Name() string {

View File

@ -16,19 +16,19 @@ import (
) )
type TestTun struct { type TestTun struct {
Device string Device string
cidr netip.Prefix vpnNetworks []netip.Prefix
Routes []Route Routes []Route
routeTree *bart.Table[netip.Addr] routeTree *bart.Table[netip.Addr]
l *logrus.Logger l *logrus.Logger
closed atomic.Bool closed atomic.Bool
rxPackets chan []byte // Packets to receive into nebula rxPackets chan []byte // Packets to receive into nebula
TxPackets chan []byte // Packets transmitted outside by nebula TxPackets chan []byte // Packets transmitted outside by nebula
} }
func newTun(c *config.C, l *logrus.Logger, cidr netip.Prefix, _ bool) (*TestTun, error) { func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, _ bool) (*TestTun, error) {
_, routes, err := getAllRoutesFromConfig(c, cidr, true) _, routes, err := getAllRoutesFromConfig(c, vpnNetworks, true)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -38,17 +38,17 @@ func newTun(c *config.C, l *logrus.Logger, cidr netip.Prefix, _ bool) (*TestTun,
} }
return &TestTun{ return &TestTun{
Device: c.GetString("tun.dev", ""), Device: c.GetString("tun.dev", ""),
cidr: cidr, vpnNetworks: vpnNetworks,
Routes: routes, Routes: routes,
routeTree: routeTree, routeTree: routeTree,
l: l, l: l,
rxPackets: make(chan []byte, 10), rxPackets: make(chan []byte, 10),
TxPackets: make(chan []byte, 10), TxPackets: make(chan []byte, 10),
}, nil }, nil
} }
func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ netip.Prefix) (*TestTun, error) { func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ []netip.Prefix) (*TestTun, error) {
return nil, fmt.Errorf("newTunFromFd not supported") return nil, fmt.Errorf("newTunFromFd not supported")
} }
@ -95,8 +95,8 @@ func (t *TestTun) Activate() error {
return nil return nil
} }
func (t *TestTun) Cidr() netip.Prefix { func (t *TestTun) Networks() []netip.Prefix {
return t.cidr return t.vpnNetworks
} }
func (t *TestTun) Name() string { func (t *TestTun) Name() string {

View File

@ -1,208 +0,0 @@
package overlay
import (
"fmt"
"io"
"net"
"net/netip"
"os/exec"
"strconv"
"sync/atomic"
"github.com/gaissmai/bart"
"github.com/sirupsen/logrus"
"github.com/slackhq/nebula/config"
"github.com/slackhq/nebula/util"
"github.com/songgao/water"
)
type waterTun struct {
Device string
cidr netip.Prefix
MTU int
Routes atomic.Pointer[[]Route]
routeTree atomic.Pointer[bart.Table[netip.Addr]]
l *logrus.Logger
f *net.Interface
*water.Interface
}
func newWaterTun(c *config.C, l *logrus.Logger, cidr netip.Prefix, _ bool) (*waterTun, error) {
// NOTE: You cannot set the deviceName under Windows, so you must check tun.Device after calling .Activate()
t := &waterTun{
cidr: cidr,
MTU: c.GetInt("tun.mtu", DefaultMTU),
l: l,
}
err := t.reload(c, true)
if err != nil {
return nil, err
}
c.RegisterReloadCallback(func(c *config.C) {
err := t.reload(c, false)
if err != nil {
util.LogWithContextIfNeeded("failed to reload tun device", err, t.l)
}
})
return t, nil
}
func (t *waterTun) Activate() error {
var err error
t.Interface, err = water.New(water.Config{
DeviceType: water.TUN,
PlatformSpecificParams: water.PlatformSpecificParams{
ComponentID: "tap0901",
Network: t.cidr.String(),
},
})
if err != nil {
return fmt.Errorf("activate failed: %v", err)
}
t.Device = t.Interface.Name()
// TODO use syscalls instead of exec.Command
err = exec.Command(
`C:\Windows\System32\netsh.exe`, "interface", "ipv4", "set", "address",
fmt.Sprintf("name=%s", t.Device),
"source=static",
fmt.Sprintf("addr=%s", t.cidr.Addr()),
fmt.Sprintf("mask=%s", net.CIDRMask(t.cidr.Bits(), t.cidr.Addr().BitLen())),
"gateway=none",
).Run()
if err != nil {
return fmt.Errorf("failed to run 'netsh' to set address: %s", err)
}
err = exec.Command(
`C:\Windows\System32\netsh.exe`, "interface", "ipv4", "set", "interface",
t.Device,
fmt.Sprintf("mtu=%d", t.MTU),
).Run()
if err != nil {
return fmt.Errorf("failed to run 'netsh' to set MTU: %s", err)
}
t.f, err = net.InterfaceByName(t.Device)
if err != nil {
return fmt.Errorf("failed to find interface named %s: %v", t.Device, err)
}
err = t.addRoutes(false)
if err != nil {
return err
}
return nil
}
func (t *waterTun) reload(c *config.C, initial bool) error {
change, routes, err := getAllRoutesFromConfig(c, t.cidr, initial)
if err != nil {
return err
}
if !initial && !change {
return nil
}
routeTree, err := makeRouteTree(t.l, routes, false)
if err != nil {
return err
}
// Teach nebula how to handle the routes before establishing them in the system table
oldRoutes := t.Routes.Swap(&routes)
t.routeTree.Store(routeTree)
if !initial {
// Remove first, if the system removes a wanted route hopefully it will be re-added next
t.removeRoutes(findRemovedRoutes(routes, *oldRoutes))
// Ensure any routes we actually want are installed
err = t.addRoutes(true)
if err != nil {
// Catch any stray logs
util.LogWithContextIfNeeded("Failed to set routes", err, t.l)
} else {
for _, r := range findRemovedRoutes(routes, *oldRoutes) {
t.l.WithField("route", r).Info("Removed route")
}
}
}
return nil
}
func (t *waterTun) addRoutes(logErrors bool) error {
// Path routes
routes := *t.Routes.Load()
for _, r := range routes {
if !r.Via.IsValid() || !r.Install {
// We don't allow route MTUs so only install routes with a via
continue
}
err := exec.Command(
"C:\\Windows\\System32\\route.exe", "add", r.Cidr.String(), r.Via.String(), "IF", strconv.Itoa(t.f.Index), "METRIC", strconv.Itoa(r.Metric),
).Run()
if err != nil {
retErr := util.NewContextualError("Failed to add route", map[string]interface{}{"route": r}, err)
if logErrors {
retErr.Log(t.l)
} else {
return retErr
}
} else {
t.l.WithField("route", r).Info("Added route")
}
}
return nil
}
func (t *waterTun) removeRoutes(routes []Route) {
for _, r := range routes {
if !r.Install {
continue
}
err := exec.Command(
"C:\\Windows\\System32\\route.exe", "delete", r.Cidr.String(), r.Via.String(), "IF", strconv.Itoa(t.f.Index), "METRIC", strconv.Itoa(r.Metric),
).Run()
if err != nil {
t.l.WithError(err).WithField("route", r).Error("Failed to remove route")
} else {
t.l.WithField("route", r).Info("Removed route")
}
}
}
func (t *waterTun) RouteFor(ip netip.Addr) netip.Addr {
r, _ := t.routeTree.Load().Lookup(ip)
return r
}
func (t *waterTun) Cidr() netip.Prefix {
return t.cidr
}
func (t *waterTun) Name() string {
return t.Device
}
func (t *waterTun) Close() error {
if t.Interface == nil {
return nil
}
return t.Interface.Close()
}
func (t *waterTun) NewMultiQueueReader() (io.ReadWriteCloser, error) {
return nil, fmt.Errorf("TODO: multiqueue not implemented for windows")
}

View File

@ -4,41 +4,268 @@
package overlay package overlay
import ( import (
"crypto"
"fmt" "fmt"
"io"
"net/netip" "net/netip"
"os" "os"
"path/filepath" "path/filepath"
"runtime" "runtime"
"sync/atomic"
"syscall" "syscall"
"unsafe"
"github.com/gaissmai/bart"
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
"github.com/slackhq/nebula/config" "github.com/slackhq/nebula/config"
"github.com/slackhq/nebula/util"
"github.com/slackhq/nebula/wintun"
"golang.org/x/sys/windows"
"golang.zx2c4.com/wireguard/windows/tunnel/winipcfg"
) )
func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ netip.Prefix) (Device, error) { const tunGUIDLabel = "Fixed Nebula Windows GUID v1"
type winTun struct {
Device string
vpnNetworks []netip.Prefix
MTU int
Routes atomic.Pointer[[]Route]
routeTree atomic.Pointer[bart.Table[netip.Addr]]
l *logrus.Logger
tun *wintun.NativeTun
}
func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ []netip.Prefix) (Device, error) {
return nil, fmt.Errorf("newTunFromFd not supported in Windows") return nil, fmt.Errorf("newTunFromFd not supported in Windows")
} }
func newTun(c *config.C, l *logrus.Logger, cidr netip.Prefix, multiqueue bool) (Device, error) { func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, _ bool) (*winTun, error) {
useWintun := true err := checkWinTunExists()
if err := checkWinTunExists(); err != nil {
l.WithError(err).Warn("Check Wintun driver failed, fallback to wintap driver")
useWintun = false
}
if useWintun {
device, err := newWinTun(c, l, cidr, multiqueue)
if err != nil {
return nil, fmt.Errorf("create Wintun interface failed, %w", err)
}
return device, nil
}
device, err := newWaterTun(c, l, cidr, multiqueue)
if err != nil { if err != nil {
return nil, fmt.Errorf("create wintap driver failed, %w", err) return nil, fmt.Errorf("can not load the wintun driver: %w", err)
} }
return device, nil
deviceName := c.GetString("tun.dev", "")
guid, err := generateGUIDByDeviceName(deviceName)
if err != nil {
return nil, fmt.Errorf("generate GUID failed: %w", err)
}
t := &winTun{
Device: deviceName,
vpnNetworks: vpnNetworks,
MTU: c.GetInt("tun.mtu", DefaultMTU),
l: l,
}
err = t.reload(c, true)
if err != nil {
return nil, err
}
var tunDevice wintun.Device
tunDevice, err = wintun.CreateTUNWithRequestedGUID(deviceName, guid, t.MTU)
if err != nil {
// Windows 10 has an issue with unclean shutdowns not fully cleaning up the wintun device.
// Trying a second time resolves the issue.
l.WithError(err).Debug("Failed to create wintun device, retrying")
tunDevice, err = wintun.CreateTUNWithRequestedGUID(deviceName, guid, t.MTU)
if err != nil {
return nil, fmt.Errorf("create TUN device failed: %w", err)
}
}
t.tun = tunDevice.(*wintun.NativeTun)
c.RegisterReloadCallback(func(c *config.C) {
err := t.reload(c, false)
if err != nil {
util.LogWithContextIfNeeded("failed to reload tun device", err, t.l)
}
})
return t, nil
}
func (t *winTun) reload(c *config.C, initial bool) error {
change, routes, err := getAllRoutesFromConfig(c, t.vpnNetworks, initial)
if err != nil {
return err
}
if !initial && !change {
return nil
}
routeTree, err := makeRouteTree(t.l, routes, false)
if err != nil {
return err
}
// Teach nebula how to handle the routes before establishing them in the system table
oldRoutes := t.Routes.Swap(&routes)
t.routeTree.Store(routeTree)
if !initial {
// Remove first, if the system removes a wanted route hopefully it will be re-added next
err := t.removeRoutes(findRemovedRoutes(routes, *oldRoutes))
if err != nil {
util.LogWithContextIfNeeded("Failed to remove routes", err, t.l)
}
// Ensure any routes we actually want are installed
err = t.addRoutes(true)
if err != nil {
// Catch any stray logs
util.LogWithContextIfNeeded("Failed to add routes", err, t.l)
}
}
return nil
}
func (t *winTun) Activate() error {
luid := winipcfg.LUID(t.tun.LUID())
err := luid.SetIPAddresses(t.vpnNetworks)
if err != nil {
return fmt.Errorf("failed to set address: %w", err)
}
err = t.addRoutes(false)
if err != nil {
return err
}
return nil
}
func (t *winTun) addRoutes(logErrors bool) error {
luid := winipcfg.LUID(t.tun.LUID())
routes := *t.Routes.Load()
foundDefault4 := false
for _, r := range routes {
if !r.Via.IsValid() || !r.Install {
// We don't allow route MTUs so only install routes with a via
continue
}
// Add our unsafe route
err := luid.AddRoute(r.Cidr, r.Via, uint32(r.Metric))
if err != nil {
retErr := util.NewContextualError("Failed to add route", map[string]interface{}{"route": r}, err)
if logErrors {
retErr.Log(t.l)
continue
} else {
return retErr
}
} else {
t.l.WithField("route", r).Info("Added route")
}
if !foundDefault4 {
if r.Cidr.Bits() == 0 && r.Cidr.Addr().BitLen() == 32 {
foundDefault4 = true
}
}
}
ipif, err := luid.IPInterface(windows.AF_INET)
if err != nil {
return fmt.Errorf("failed to get ip interface: %w", err)
}
ipif.NLMTU = uint32(t.MTU)
if foundDefault4 {
ipif.UseAutomaticMetric = false
ipif.Metric = 0
}
if err := ipif.Set(); err != nil {
return fmt.Errorf("failed to set ip interface: %w", err)
}
return nil
}
func (t *winTun) removeRoutes(routes []Route) error {
luid := winipcfg.LUID(t.tun.LUID())
for _, r := range routes {
if !r.Install {
continue
}
err := luid.DeleteRoute(r.Cidr, r.Via)
if err != nil {
t.l.WithError(err).WithField("route", r).Error("Failed to remove route")
} else {
t.l.WithField("route", r).Info("Removed route")
}
}
return nil
}
func (t *winTun) RouteFor(ip netip.Addr) netip.Addr {
r, _ := t.routeTree.Load().Lookup(ip)
return r
}
func (t *winTun) Networks() []netip.Prefix {
return t.vpnNetworks
}
func (t *winTun) Name() string {
return t.Device
}
func (t *winTun) Read(b []byte) (int, error) {
return t.tun.Read(b, 0)
}
func (t *winTun) Write(b []byte) (int, error) {
return t.tun.Write(b, 0)
}
func (t *winTun) NewMultiQueueReader() (io.ReadWriteCloser, error) {
return nil, fmt.Errorf("TODO: multiqueue not implemented for windows")
}
func (t *winTun) Close() error {
// It seems that the Windows networking stack doesn't like it when we destroy interfaces that have active routes,
// so to be certain, just remove everything before destroying.
luid := winipcfg.LUID(t.tun.LUID())
_ = luid.FlushRoutes(windows.AF_INET)
_ = luid.FlushIPAddresses(windows.AF_INET)
_ = luid.FlushRoutes(windows.AF_INET6)
_ = luid.FlushIPAddresses(windows.AF_INET6)
_ = luid.FlushDNS(windows.AF_INET)
_ = luid.FlushDNS(windows.AF_INET6)
return t.tun.Close()
}
func generateGUIDByDeviceName(name string) (*windows.GUID, error) {
// GUID is 128 bit
hash := crypto.MD5.New()
_, err := hash.Write([]byte(tunGUIDLabel))
if err != nil {
return nil, err
}
_, err = hash.Write([]byte(name))
if err != nil {
return nil, err
}
sum := hash.Sum(nil)
return (*windows.GUID)(unsafe.Pointer(&sum[0])), nil
} }
func checkWinTunExists() error { func checkWinTunExists() error {

View File

@ -1,252 +0,0 @@
package overlay
import (
"crypto"
"fmt"
"io"
"net/netip"
"sync/atomic"
"unsafe"
"github.com/gaissmai/bart"
"github.com/sirupsen/logrus"
"github.com/slackhq/nebula/config"
"github.com/slackhq/nebula/util"
"github.com/slackhq/nebula/wintun"
"golang.org/x/sys/windows"
"golang.zx2c4.com/wireguard/windows/tunnel/winipcfg"
)
const tunGUIDLabel = "Fixed Nebula Windows GUID v1"
type winTun struct {
Device string
cidr netip.Prefix
MTU int
Routes atomic.Pointer[[]Route]
routeTree atomic.Pointer[bart.Table[netip.Addr]]
l *logrus.Logger
tun *wintun.NativeTun
}
func generateGUIDByDeviceName(name string) (*windows.GUID, error) {
// GUID is 128 bit
hash := crypto.MD5.New()
_, err := hash.Write([]byte(tunGUIDLabel))
if err != nil {
return nil, err
}
_, err = hash.Write([]byte(name))
if err != nil {
return nil, err
}
sum := hash.Sum(nil)
return (*windows.GUID)(unsafe.Pointer(&sum[0])), nil
}
func newWinTun(c *config.C, l *logrus.Logger, cidr netip.Prefix, _ bool) (*winTun, error) {
deviceName := c.GetString("tun.dev", "")
guid, err := generateGUIDByDeviceName(deviceName)
if err != nil {
return nil, fmt.Errorf("generate GUID failed: %w", err)
}
t := &winTun{
Device: deviceName,
cidr: cidr,
MTU: c.GetInt("tun.mtu", DefaultMTU),
l: l,
}
err = t.reload(c, true)
if err != nil {
return nil, err
}
var tunDevice wintun.Device
tunDevice, err = wintun.CreateTUNWithRequestedGUID(deviceName, guid, t.MTU)
if err != nil {
// Windows 10 has an issue with unclean shutdowns not fully cleaning up the wintun device.
// Trying a second time resolves the issue.
l.WithError(err).Debug("Failed to create wintun device, retrying")
tunDevice, err = wintun.CreateTUNWithRequestedGUID(deviceName, guid, t.MTU)
if err != nil {
return nil, fmt.Errorf("create TUN device failed: %w", err)
}
}
t.tun = tunDevice.(*wintun.NativeTun)
c.RegisterReloadCallback(func(c *config.C) {
err := t.reload(c, false)
if err != nil {
util.LogWithContextIfNeeded("failed to reload tun device", err, t.l)
}
})
return t, nil
}
func (t *winTun) reload(c *config.C, initial bool) error {
change, routes, err := getAllRoutesFromConfig(c, t.cidr, initial)
if err != nil {
return err
}
if !initial && !change {
return nil
}
routeTree, err := makeRouteTree(t.l, routes, false)
if err != nil {
return err
}
// Teach nebula how to handle the routes before establishing them in the system table
oldRoutes := t.Routes.Swap(&routes)
t.routeTree.Store(routeTree)
if !initial {
// Remove first, if the system removes a wanted route hopefully it will be re-added next
err := t.removeRoutes(findRemovedRoutes(routes, *oldRoutes))
if err != nil {
util.LogWithContextIfNeeded("Failed to remove routes", err, t.l)
}
// Ensure any routes we actually want are installed
err = t.addRoutes(true)
if err != nil {
// Catch any stray logs
util.LogWithContextIfNeeded("Failed to add routes", err, t.l)
}
}
return nil
}
func (t *winTun) Activate() error {
luid := winipcfg.LUID(t.tun.LUID())
err := luid.SetIPAddresses([]netip.Prefix{t.cidr})
if err != nil {
return fmt.Errorf("failed to set address: %w", err)
}
err = t.addRoutes(false)
if err != nil {
return err
}
return nil
}
func (t *winTun) addRoutes(logErrors bool) error {
luid := winipcfg.LUID(t.tun.LUID())
routes := *t.Routes.Load()
foundDefault4 := false
for _, r := range routes {
if !r.Via.IsValid() || !r.Install {
// We don't allow route MTUs so only install routes with a via
continue
}
// Add our unsafe route
err := luid.AddRoute(r.Cidr, r.Via, uint32(r.Metric))
if err != nil {
retErr := util.NewContextualError("Failed to add route", map[string]interface{}{"route": r}, err)
if logErrors {
retErr.Log(t.l)
continue
} else {
return retErr
}
} else {
t.l.WithField("route", r).Info("Added route")
}
if !foundDefault4 {
if r.Cidr.Bits() == 0 && r.Cidr.Addr().BitLen() == 32 {
foundDefault4 = true
}
}
}
ipif, err := luid.IPInterface(windows.AF_INET)
if err != nil {
return fmt.Errorf("failed to get ip interface: %w", err)
}
ipif.NLMTU = uint32(t.MTU)
if foundDefault4 {
ipif.UseAutomaticMetric = false
ipif.Metric = 0
}
if err := ipif.Set(); err != nil {
return fmt.Errorf("failed to set ip interface: %w", err)
}
return nil
}
func (t *winTun) removeRoutes(routes []Route) error {
luid := winipcfg.LUID(t.tun.LUID())
for _, r := range routes {
if !r.Install {
continue
}
err := luid.DeleteRoute(r.Cidr, r.Via)
if err != nil {
t.l.WithError(err).WithField("route", r).Error("Failed to remove route")
} else {
t.l.WithField("route", r).Info("Removed route")
}
}
return nil
}
func (t *winTun) RouteFor(ip netip.Addr) netip.Addr {
r, _ := t.routeTree.Load().Lookup(ip)
return r
}
func (t *winTun) Cidr() netip.Prefix {
return t.cidr
}
func (t *winTun) Name() string {
return t.Device
}
func (t *winTun) Read(b []byte) (int, error) {
return t.tun.Read(b, 0)
}
func (t *winTun) Write(b []byte) (int, error) {
return t.tun.Write(b, 0)
}
func (t *winTun) NewMultiQueueReader() (io.ReadWriteCloser, error) {
return nil, fmt.Errorf("TODO: multiqueue not implemented for windows")
}
func (t *winTun) Close() error {
// It seems that the Windows networking stack doesn't like it when we destroy interfaces that have active routes,
// so to be certain, just remove everything before destroying.
luid := winipcfg.LUID(t.tun.LUID())
_ = luid.FlushRoutes(windows.AF_INET)
_ = luid.FlushIPAddresses(windows.AF_INET)
/* We don't support IPV6 yet
_ = luid.FlushRoutes(windows.AF_INET6)
_ = luid.FlushIPAddresses(windows.AF_INET6)
*/
_ = luid.FlushDNS(windows.AF_INET)
return t.tun.Close()
}

View File

@ -8,16 +8,16 @@ import (
"github.com/slackhq/nebula/config" "github.com/slackhq/nebula/config"
) )
func NewUserDeviceFromConfig(c *config.C, l *logrus.Logger, tunCidr netip.Prefix, routines int) (Device, error) { func NewUserDeviceFromConfig(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, routines int) (Device, error) {
return NewUserDevice(tunCidr) return NewUserDevice(vpnNetworks)
} }
func NewUserDevice(tunCidr netip.Prefix) (Device, error) { func NewUserDevice(vpnNetworks []netip.Prefix) (Device, error) {
// these pipes guarantee each write/read will match 1:1 // these pipes guarantee each write/read will match 1:1
or, ow := io.Pipe() or, ow := io.Pipe()
ir, iw := io.Pipe() ir, iw := io.Pipe()
return &UserDevice{ return &UserDevice{
tunCidr: tunCidr, vpnNetworks: vpnNetworks,
outboundReader: or, outboundReader: or,
outboundWriter: ow, outboundWriter: ow,
inboundReader: ir, inboundReader: ir,
@ -26,7 +26,7 @@ func NewUserDevice(tunCidr netip.Prefix) (Device, error) {
} }
type UserDevice struct { type UserDevice struct {
tunCidr netip.Prefix vpnNetworks []netip.Prefix
outboundReader *io.PipeReader outboundReader *io.PipeReader
outboundWriter *io.PipeWriter outboundWriter *io.PipeWriter
@ -38,7 +38,7 @@ type UserDevice struct {
func (d *UserDevice) Activate() error { func (d *UserDevice) Activate() error {
return nil return nil
} }
func (d *UserDevice) Cidr() netip.Prefix { return d.tunCidr } func (d *UserDevice) Networks() []netip.Prefix { return d.vpnNetworks }
func (d *UserDevice) Name() string { return "faketun0" } func (d *UserDevice) Name() string { return "faketun0" }
func (d *UserDevice) RouteFor(ip netip.Addr) netip.Addr { return ip } func (d *UserDevice) RouteFor(ip netip.Addr) netip.Addr { return ip }
func (d *UserDevice) NewMultiQueueReader() (io.ReadWriteCloser, error) { func (d *UserDevice) NewMultiQueueReader() (io.ReadWriteCloser, error) {

414
pki.go
View File

@ -1,13 +1,19 @@
package nebula package nebula
import ( import (
"encoding/binary"
"encoding/json"
"errors" "errors"
"fmt" "fmt"
"net"
"net/netip"
"os" "os"
"slices"
"strings" "strings"
"sync/atomic" "sync/atomic"
"time" "time"
"github.com/gaissmai/bart"
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
"github.com/slackhq/nebula/cert" "github.com/slackhq/nebula/cert"
"github.com/slackhq/nebula/config" "github.com/slackhq/nebula/config"
@ -21,12 +27,22 @@ type PKI struct {
} }
type CertState struct { type CertState struct {
Certificate cert.Certificate v1Cert cert.Certificate
RawCertificate []byte v1HandshakeBytes []byte
RawCertificateNoKey []byte
PublicKey []byte v2Cert cert.Certificate
PrivateKey []byte v2HandshakeBytes []byte
pkcs11Backed bool
defaultVersion cert.Version
privateKey []byte
pkcs11Backed bool
cipher string
myVpnNetworks []netip.Prefix
myVpnNetworksTable *bart.Table[struct{}]
myVpnAddrs []netip.Addr
myVpnAddrsTable *bart.Table[struct{}]
myVpnBroadcastAddrsTable *bart.Table[struct{}]
} }
func NewPKIFromConfig(l *logrus.Logger, c *config.C) (*PKI, error) { func NewPKIFromConfig(l *logrus.Logger, c *config.C) (*PKI, error) {
@ -46,16 +62,16 @@ func NewPKIFromConfig(l *logrus.Logger, c *config.C) (*PKI, error) {
return pki, nil return pki, nil
} }
func (p *PKI) GetCertState() *CertState {
return p.cs.Load()
}
func (p *PKI) GetCAPool() *cert.CAPool { func (p *PKI) GetCAPool() *cert.CAPool {
return p.caPool.Load() return p.caPool.Load()
} }
func (p *PKI) getCertState() *CertState {
return p.cs.Load()
}
func (p *PKI) reload(c *config.C, initial bool) error { func (p *PKI) reload(c *config.C, initial bool) error {
err := p.reloadCert(c, initial) err := p.reloadCerts(c, initial)
if err != nil { if err != nil {
if initial { if initial {
return err return err
@ -74,33 +90,94 @@ func (p *PKI) reload(c *config.C, initial bool) error {
return nil return nil
} }
func (p *PKI) reloadCert(c *config.C, initial bool) *util.ContextualError { func (p *PKI) reloadCerts(c *config.C, initial bool) *util.ContextualError {
cs, err := newCertStateFromConfig(c) newState, err := newCertStateFromConfig(c)
if err != nil { if err != nil {
return util.NewContextualError("Could not load client cert", nil, err) return util.NewContextualError("Could not load client cert", nil, err)
} }
if !initial { if !initial {
//TODO: include check for mask equality as well currentState := p.cs.Load()
if newState.v1Cert != nil {
if currentState.v1Cert == nil {
return util.NewContextualError("v1 certificate was added, restart required", nil, err)
}
// did IP in cert change? if so, don't set // did IP in cert change? if so, don't set
currentCert := p.cs.Load().Certificate if !slices.Equal(currentState.v1Cert.Networks(), newState.v1Cert.Networks()) {
oldIPs := currentCert.Networks() return util.NewContextualError(
newIPs := cs.Certificate.Networks() "Networks in new cert was different from old",
if len(oldIPs) > 0 && len(newIPs) > 0 && oldIPs[0].String() != newIPs[0].String() { m{"new_networks": newState.v1Cert.Networks(), "old_networks": currentState.v1Cert.Networks()},
nil,
)
}
if currentState.v1Cert.Curve() != newState.v1Cert.Curve() {
return util.NewContextualError(
"Curve in new cert was different from old",
m{"new_curve": newState.v1Cert.Curve(), "old_curve": currentState.v1Cert.Curve()},
nil,
)
}
} else if currentState.v1Cert != nil {
//TODO: CERT-V2 we should be able to tear this down
return util.NewContextualError("v1 certificate was removed, restart required", nil, err)
}
if newState.v2Cert != nil {
if currentState.v2Cert == nil {
return util.NewContextualError("v2 certificate was added, restart required", nil, err)
}
// did IP in cert change? if so, don't set
if !slices.Equal(currentState.v2Cert.Networks(), newState.v2Cert.Networks()) {
return util.NewContextualError(
"Networks in new cert was different from old",
m{"new_networks": newState.v2Cert.Networks(), "old_networks": currentState.v2Cert.Networks()},
nil,
)
}
if currentState.v2Cert.Curve() != newState.v2Cert.Curve() {
return util.NewContextualError(
"Curve in new cert was different from old",
m{"new_curve": newState.v2Cert.Curve(), "old_curve": currentState.v2Cert.Curve()},
nil,
)
}
} else if currentState.v2Cert != nil {
return util.NewContextualError("v2 certificate was removed, restart required", nil, err)
}
// Cipher cant be hot swapped so just leave it at what it was before
newState.cipher = currentState.cipher
} else {
newState.cipher = c.GetString("cipher", "aes")
//TODO: this sucks and we should make it not a global
switch newState.cipher {
case "aes":
noiseEndianness = binary.BigEndian
case "chachapoly":
noiseEndianness = binary.LittleEndian
default:
return util.NewContextualError( return util.NewContextualError(
"Networks in new cert was different from old", "unknown cipher",
m{"new_network": newIPs[0], "old_network": oldIPs[0]}, m{"cipher": newState.cipher},
nil, nil,
) )
} }
} }
p.cs.Store(cs) p.cs.Store(newState)
//TODO: CERT-V2 newState needs a stringer that does json
if initial { if initial {
p.l.WithField("cert", cs.Certificate).Debug("Client nebula certificate") p.l.WithField("cert", newState).Debug("Client nebula certificate(s)")
} else { } else {
p.l.WithField("cert", cs.Certificate).Info("Client cert refreshed from disk") p.l.WithField("cert", newState).Info("Client certificate(s) refreshed from disk")
} }
return nil return nil
} }
@ -116,55 +193,65 @@ func (p *PKI) reloadCAPool(c *config.C) *util.ContextualError {
return nil return nil
} }
func newCertState(certificate cert.Certificate, pkcs11backed bool, privateKey []byte) (*CertState, error) { func (cs *CertState) GetDefaultCertificate() cert.Certificate {
// Marshal the certificate to ensure it is valid c := cs.getCertificate(cs.defaultVersion)
rawCertificate, err := certificate.Marshal() if c == nil {
if err != nil { panic("No default certificate found")
return nil, fmt.Errorf("invalid nebula certificate on interface: %s", err)
} }
return c
publicKey := certificate.PublicKey()
cs := &CertState{
RawCertificate: rawCertificate,
Certificate: certificate,
PrivateKey: privateKey,
PublicKey: publicKey,
pkcs11Backed: pkcs11backed,
}
rawCertNoKey, err := cs.Certificate.MarshalForHandshakes()
if err != nil {
return nil, fmt.Errorf("error marshalling certificate no key: %s", err)
}
cs.RawCertificateNoKey = rawCertNoKey
return cs, nil
} }
func loadPrivateKey(privPathOrPEM string) (rawKey []byte, curve cert.Curve, isPkcs11 bool, err error) { func (cs *CertState) getCertificate(v cert.Version) cert.Certificate {
var pemPrivateKey []byte switch v {
if strings.Contains(privPathOrPEM, "-----BEGIN") { case cert.Version1:
pemPrivateKey = []byte(privPathOrPEM) return cs.v1Cert
privPathOrPEM = "<inline>" case cert.Version2:
rawKey, _, curve, err = cert.UnmarshalPrivateKeyFromPEM(pemPrivateKey) return cs.v2Cert
if err != nil {
return nil, curve, false, fmt.Errorf("error while unmarshaling pki.key %s: %s", privPathOrPEM, err)
}
} else if strings.HasPrefix(privPathOrPEM, "pkcs11:") {
rawKey = []byte(privPathOrPEM)
return rawKey, cert.Curve_P256, true, nil
} else {
pemPrivateKey, err = os.ReadFile(privPathOrPEM)
if err != nil {
return nil, curve, false, fmt.Errorf("unable to read pki.key file %s: %s", privPathOrPEM, err)
}
rawKey, _, curve, err = cert.UnmarshalPrivateKeyFromPEM(pemPrivateKey)
if err != nil {
return nil, curve, false, fmt.Errorf("error while unmarshaling pki.key %s: %s", privPathOrPEM, err)
}
} }
return return nil
}
// getHandshakeBytes returns the cached bytes to be used in a handshake message for the requested version.
// Callers must check if the return []byte is nil.
func (cs *CertState) getHandshakeBytes(v cert.Version) []byte {
switch v {
case cert.Version1:
return cs.v1HandshakeBytes
case cert.Version2:
return cs.v2HandshakeBytes
default:
return nil
}
}
func (cs *CertState) String() string {
b, err := cs.MarshalJSON()
if err != nil {
return fmt.Sprintf("error marshaling certificate state: %v", err)
}
return string(b)
}
func (cs *CertState) MarshalJSON() ([]byte, error) {
msg := []json.RawMessage{}
if cs.v1Cert != nil {
b, err := cs.v1Cert.MarshalJSON()
if err != nil {
return nil, err
}
msg = append(msg, b)
}
if cs.v2Cert != nil {
b, err := cs.v2Cert.MarshalJSON()
if err != nil {
return nil, err
}
msg = append(msg, b)
}
return json.Marshal(msg)
} }
func newCertStateFromConfig(c *config.C) (*CertState, error) { func newCertStateFromConfig(c *config.C) (*CertState, error) {
@ -198,24 +285,197 @@ func newCertStateFromConfig(c *config.C) (*CertState, error) {
} }
} }
nebulaCert, _, err := cert.UnmarshalCertificateFromPEM(rawCert) var crt, v1, v2 cert.Certificate
for {
// Load the certificate
crt, rawCert, err = loadCertificate(rawCert)
if err != nil {
return nil, err
}
switch crt.Version() {
case cert.Version1:
if v1 != nil {
return nil, fmt.Errorf("v1 certificate already found in pki.cert")
}
v1 = crt
case cert.Version2:
if v2 != nil {
return nil, fmt.Errorf("v2 certificate already found in pki.cert")
}
v2 = crt
default:
return nil, fmt.Errorf("unknown certificate version %v", crt.Version())
}
if len(rawCert) == 0 || strings.TrimSpace(string(rawCert)) == "" {
break
}
}
if v1 == nil && v2 == nil {
return nil, errors.New("no certificates found in pki.cert")
}
useDefaultVersion := uint32(1)
if v1 == nil {
// The only condition that requires v2 as the default is if only a v2 certificate is present
// We do this to avoid having to configure it specifically in the config file
useDefaultVersion = 2
}
rawDefaultVersion := c.GetUint32("pki.default_version", useDefaultVersion)
var defaultVersion cert.Version
switch rawDefaultVersion {
case 1:
if v1 == nil {
return nil, fmt.Errorf("can not use pki.default_version 1 without a v1 certificate in pki.cert")
}
defaultVersion = cert.Version1
case 2:
defaultVersion = cert.Version2
default:
return nil, fmt.Errorf("unknown pki.default_version: %v", rawDefaultVersion)
}
return newCertState(defaultVersion, v1, v2, isPkcs11, curve, rawKey)
}
func newCertState(dv cert.Version, v1, v2 cert.Certificate, pkcs11backed bool, privateKeyCurve cert.Curve, privateKey []byte) (*CertState, error) {
cs := CertState{
privateKey: privateKey,
pkcs11Backed: pkcs11backed,
myVpnNetworksTable: new(bart.Table[struct{}]),
myVpnAddrsTable: new(bart.Table[struct{}]),
myVpnBroadcastAddrsTable: new(bart.Table[struct{}]),
}
if v1 != nil && v2 != nil {
if !slices.Equal(v1.PublicKey(), v2.PublicKey()) {
return nil, util.NewContextualError("v1 and v2 public keys are not the same, ignoring", nil, nil)
}
if v1.Curve() != v2.Curve() {
return nil, util.NewContextualError("v1 and v2 curve are not the same, ignoring", nil, nil)
}
//TODO: CERT-V2 make sure v2 has v1s address
cs.defaultVersion = dv
}
if v1 != nil {
if pkcs11backed {
//NOTE: We do not currently have a method to verify a public private key pair when the private key is in an hsm
} else {
if err := v1.VerifyPrivateKey(privateKeyCurve, privateKey); err != nil {
return nil, fmt.Errorf("private key is not a pair with public key in nebula cert")
}
}
v1hs, err := v1.MarshalForHandshakes()
if err != nil {
return nil, fmt.Errorf("error marshalling certificate for handshake: %w", err)
}
cs.v1Cert = v1
cs.v1HandshakeBytes = v1hs
if cs.defaultVersion == 0 {
cs.defaultVersion = cert.Version1
}
}
if v2 != nil {
if pkcs11backed {
//NOTE: We do not currently have a method to verify a public private key pair when the private key is in an hsm
} else {
if err := v2.VerifyPrivateKey(privateKeyCurve, privateKey); err != nil {
return nil, fmt.Errorf("private key is not a pair with public key in nebula cert")
}
}
v2hs, err := v2.MarshalForHandshakes()
if err != nil {
return nil, fmt.Errorf("error marshalling certificate for handshake: %w", err)
}
cs.v2Cert = v2
cs.v2HandshakeBytes = v2hs
if cs.defaultVersion == 0 {
cs.defaultVersion = cert.Version2
}
}
var crt cert.Certificate
crt = cs.getCertificate(cert.Version2)
if crt == nil {
// v2 certificates are a superset, only look at v1 if its all we have
crt = cs.getCertificate(cert.Version1)
}
for _, network := range crt.Networks() {
cs.myVpnNetworks = append(cs.myVpnNetworks, network)
cs.myVpnNetworksTable.Insert(network, struct{}{})
cs.myVpnAddrs = append(cs.myVpnAddrs, network.Addr())
cs.myVpnAddrsTable.Insert(netip.PrefixFrom(network.Addr(), network.Addr().BitLen()), struct{}{})
if network.Addr().Is4() {
addr := network.Masked().Addr().As4()
mask := net.CIDRMask(network.Bits(), network.Addr().BitLen())
binary.BigEndian.PutUint32(addr[:], binary.BigEndian.Uint32(addr[:])|^binary.BigEndian.Uint32(mask))
cs.myVpnBroadcastAddrsTable.Insert(netip.PrefixFrom(netip.AddrFrom4(addr), network.Addr().BitLen()), struct{}{})
}
}
return &cs, nil
}
func loadPrivateKey(privPathOrPEM string) (rawKey []byte, curve cert.Curve, isPkcs11 bool, err error) {
var pemPrivateKey []byte
if strings.Contains(privPathOrPEM, "-----BEGIN") {
pemPrivateKey = []byte(privPathOrPEM)
privPathOrPEM = "<inline>"
rawKey, _, curve, err = cert.UnmarshalPrivateKeyFromPEM(pemPrivateKey)
if err != nil {
return nil, curve, false, fmt.Errorf("error while unmarshaling pki.key %s: %s", privPathOrPEM, err)
}
} else if strings.HasPrefix(privPathOrPEM, "pkcs11:") {
rawKey = []byte(privPathOrPEM)
return rawKey, cert.Curve_P256, true, nil
} else {
pemPrivateKey, err = os.ReadFile(privPathOrPEM)
if err != nil {
return nil, curve, false, fmt.Errorf("unable to read pki.key file %s: %s", privPathOrPEM, err)
}
rawKey, _, curve, err = cert.UnmarshalPrivateKeyFromPEM(pemPrivateKey)
if err != nil {
return nil, curve, false, fmt.Errorf("error while unmarshaling pki.key %s: %s", privPathOrPEM, err)
}
}
return
}
func loadCertificate(b []byte) (cert.Certificate, []byte, error) {
c, b, err := cert.UnmarshalCertificateFromPEM(b)
if err != nil { if err != nil {
return nil, fmt.Errorf("error while unmarshaling pki.cert %s: %s", pubPathOrPEM, err) return nil, b, fmt.Errorf("error while unmarshaling pki.cert: %w", err)
} }
if nebulaCert.Expired(time.Now()) { if c.Expired(time.Now()) {
return nil, fmt.Errorf("nebula certificate for this host is expired") return nil, b, fmt.Errorf("nebula certificate for this host is expired")
} }
if len(nebulaCert.Networks()) == 0 { if len(c.Networks()) == 0 {
return nil, fmt.Errorf("no networks encoded in certificate") return nil, b, fmt.Errorf("no networks encoded in certificate")
} }
if err = nebulaCert.VerifyPrivateKey(curve, rawKey); err != nil { if c.IsCA() {
return nil, fmt.Errorf("private key is not a pair with public key in nebula cert") return nil, b, fmt.Errorf("host certificate is a CA certificate")
} }
return newCertState(nebulaCert, isPkcs11, rawKey) return c, b, nil
} }
func loadCAPoolFromConfig(l *logrus.Logger, c *config.C) (*cert.CAPool, error) { func loadCAPoolFromConfig(l *logrus.Logger, c *config.C) (*cert.CAPool, error) {

View File

@ -9,6 +9,7 @@ import (
"sync/atomic" "sync/atomic"
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
"github.com/slackhq/nebula/cert"
"github.com/slackhq/nebula/config" "github.com/slackhq/nebula/config"
"github.com/slackhq/nebula/header" "github.com/slackhq/nebula/header"
) )
@ -72,7 +73,7 @@ func AddRelay(l *logrus.Logger, relayHostInfo *HostInfo, hm *HostMap, vpnIp neti
Type: relayType, Type: relayType,
State: state, State: state,
LocalIndex: index, LocalIndex: index,
PeerIp: vpnIp, PeerAddr: vpnIp,
} }
if remoteIdx != nil { if remoteIdx != nil {
@ -91,40 +92,71 @@ func AddRelay(l *logrus.Logger, relayHostInfo *HostInfo, hm *HostMap, vpnIp neti
func (rm *relayManager) EstablishRelay(relayHostInfo *HostInfo, m *NebulaControl) (*Relay, error) { func (rm *relayManager) EstablishRelay(relayHostInfo *HostInfo, m *NebulaControl) (*Relay, error) {
relay, ok := relayHostInfo.relayState.CompleteRelayByIdx(m.InitiatorRelayIndex, m.ResponderRelayIndex) relay, ok := relayHostInfo.relayState.CompleteRelayByIdx(m.InitiatorRelayIndex, m.ResponderRelayIndex)
if !ok { if !ok {
rm.l.WithFields(logrus.Fields{"relay": relayHostInfo.vpnIp, fields := logrus.Fields{
"relay": relayHostInfo.vpnAddrs[0],
"initiatorRelayIndex": m.InitiatorRelayIndex, "initiatorRelayIndex": m.InitiatorRelayIndex,
"relayFrom": m.RelayFromIp, }
"relayTo": m.RelayToIp}).Info("relayManager failed to update relay")
if m.RelayFromAddr == nil {
fields["relayFrom"] = m.OldRelayFromAddr
} else {
fields["relayFrom"] = m.RelayFromAddr
}
if m.RelayToAddr == nil {
fields["relayTo"] = m.OldRelayToAddr
} else {
fields["relayTo"] = m.RelayToAddr
}
rm.l.WithFields(fields).Info("relayManager failed to update relay")
return nil, fmt.Errorf("unknown relay") return nil, fmt.Errorf("unknown relay")
} }
return relay, nil return relay, nil
} }
func (rm *relayManager) HandleControlMsg(h *HostInfo, m *NebulaControl, f *Interface) { func (rm *relayManager) HandleControlMsg(h *HostInfo, d []byte, f *Interface) {
msg := &NebulaControl{}
switch m.Type { err := msg.Unmarshal(d)
case NebulaControl_CreateRelayRequest: if err != nil {
rm.handleCreateRelayRequest(h, f, m) h.logger(f.l).WithError(err).Error("Failed to unmarshal control message")
case NebulaControl_CreateRelayResponse: return
rm.handleCreateRelayResponse(h, f, m)
} }
var v cert.Version
if msg.OldRelayFromAddr > 0 || msg.OldRelayToAddr > 0 {
v = cert.Version1
b := [4]byte{}
binary.BigEndian.PutUint32(b[:], msg.OldRelayFromAddr)
msg.RelayFromAddr = netAddrToProtoAddr(netip.AddrFrom4(b))
binary.BigEndian.PutUint32(b[:], msg.OldRelayToAddr)
msg.RelayToAddr = netAddrToProtoAddr(netip.AddrFrom4(b))
} else {
v = cert.Version2
}
switch msg.Type {
case NebulaControl_CreateRelayRequest:
rm.handleCreateRelayRequest(v, h, f, msg)
case NebulaControl_CreateRelayResponse:
rm.handleCreateRelayResponse(v, h, f, msg)
}
} }
func (rm *relayManager) handleCreateRelayResponse(h *HostInfo, f *Interface, m *NebulaControl) { func (rm *relayManager) handleCreateRelayResponse(v cert.Version, h *HostInfo, f *Interface, m *NebulaControl) {
rm.l.WithFields(logrus.Fields{ rm.l.WithFields(logrus.Fields{
"relayFrom": m.RelayFromIp, "relayFrom": protoAddrToNetAddr(m.RelayFromAddr),
"relayTo": m.RelayToIp, "relayTo": protoAddrToNetAddr(m.RelayToAddr),
"initiatorRelayIndex": m.InitiatorRelayIndex, "initiatorRelayIndex": m.InitiatorRelayIndex,
"responderRelayIndex": m.ResponderRelayIndex, "responderRelayIndex": m.ResponderRelayIndex,
"vpnIp": h.vpnIp}). "vpnAddrs": h.vpnAddrs}).
Info("handleCreateRelayResponse") Info("handleCreateRelayResponse")
target := m.RelayToIp
//TODO: IPV6-WORK target := m.RelayToAddr
b := [4]byte{} targetAddr := protoAddrToNetAddr(target)
binary.BigEndian.PutUint32(b[:], m.RelayToIp)
targetAddr := netip.AddrFrom4(b)
relay, err := rm.EstablishRelay(h, m) relay, err := rm.EstablishRelay(h, m)
if err != nil { if err != nil {
@ -136,68 +168,88 @@ func (rm *relayManager) handleCreateRelayResponse(h *HostInfo, f *Interface, m *
return return
} }
// I'm the middle man. Let the initiator know that the I've established the relay they requested. // I'm the middle man. Let the initiator know that the I've established the relay they requested.
peerHostInfo := rm.hostmap.QueryVpnIp(relay.PeerIp) peerHostInfo := rm.hostmap.QueryVpnAddr(relay.PeerAddr)
if peerHostInfo == nil { if peerHostInfo == nil {
rm.l.WithField("relayTo", relay.PeerIp).Error("Can't find a HostInfo for peer") rm.l.WithField("relayTo", relay.PeerAddr).Error("Can't find a HostInfo for peer")
return return
} }
peerRelay, ok := peerHostInfo.relayState.QueryRelayForByIp(targetAddr) peerRelay, ok := peerHostInfo.relayState.QueryRelayForByIp(targetAddr)
if !ok { if !ok {
rm.l.WithField("relayTo", peerHostInfo.vpnIp).Error("peerRelay does not have Relay state for relayTo") rm.l.WithField("relayTo", peerHostInfo.vpnAddrs[0]).Error("peerRelay does not have Relay state for relayTo")
return return
} }
if peerRelay.State == PeerRequested { switch peerRelay.State {
//TODO: IPV6-WORK case Requested:
b = peerHostInfo.vpnIp.As4() // I initiated the request to this peer, but haven't heard back from the peer yet. I must wait for this peer
peerRelay.State = Established // to respond to complete the connection.
case PeerRequested, Disestablished, Established:
peerHostInfo.relayState.UpdateRelayForByIpState(targetAddr, Established)
resp := NebulaControl{ resp := NebulaControl{
Type: NebulaControl_CreateRelayResponse, Type: NebulaControl_CreateRelayResponse,
ResponderRelayIndex: peerRelay.LocalIndex, ResponderRelayIndex: peerRelay.LocalIndex,
InitiatorRelayIndex: peerRelay.RemoteIndex, InitiatorRelayIndex: peerRelay.RemoteIndex,
RelayFromIp: binary.BigEndian.Uint32(b[:]),
RelayToIp: uint32(target),
} }
if v == cert.Version1 {
peer := peerHostInfo.vpnAddrs[0]
if !peer.Is4() {
rm.l.WithField("relayFrom", peer).
WithField("relayTo", target).
WithField("initiatorRelayIndex", resp.InitiatorRelayIndex).
WithField("responderRelayIndex", resp.ResponderRelayIndex).
WithField("vpnAddrs", peerHostInfo.vpnAddrs).
Error("Refusing to CreateRelayResponse for a v1 relay with an ipv6 address")
return
}
b := peer.As4()
resp.OldRelayFromAddr = binary.BigEndian.Uint32(b[:])
b = targetAddr.As4()
resp.OldRelayToAddr = binary.BigEndian.Uint32(b[:])
} else {
resp.RelayFromAddr = netAddrToProtoAddr(peerHostInfo.vpnAddrs[0])
resp.RelayToAddr = target
}
msg, err := resp.Marshal() msg, err := resp.Marshal()
if err != nil { if err != nil {
rm.l. rm.l.WithError(err).
WithError(err).Error("relayManager Failed to marshal Control CreateRelayResponse message to create relay") Error("relayManager Failed to marshal Control CreateRelayResponse message to create relay")
} else { } else {
f.SendMessageToHostInfo(header.Control, 0, peerHostInfo, msg, make([]byte, 12), make([]byte, mtu)) f.SendMessageToHostInfo(header.Control, 0, peerHostInfo, msg, make([]byte, 12), make([]byte, mtu))
rm.l.WithFields(logrus.Fields{ rm.l.WithFields(logrus.Fields{
"relayFrom": resp.RelayFromIp, "relayFrom": resp.RelayFromAddr,
"relayTo": resp.RelayToIp, "relayTo": resp.RelayToAddr,
"initiatorRelayIndex": resp.InitiatorRelayIndex, "initiatorRelayIndex": resp.InitiatorRelayIndex,
"responderRelayIndex": resp.ResponderRelayIndex, "responderRelayIndex": resp.ResponderRelayIndex,
"vpnIp": peerHostInfo.vpnIp}). "vpnAddrs": peerHostInfo.vpnAddrs}).
Info("send CreateRelayResponse") Info("send CreateRelayResponse")
} }
} }
} }
func (rm *relayManager) handleCreateRelayRequest(h *HostInfo, f *Interface, m *NebulaControl) { func (rm *relayManager) handleCreateRelayRequest(v cert.Version, h *HostInfo, f *Interface, m *NebulaControl) {
//TODO: IPV6-WORK from := protoAddrToNetAddr(m.RelayFromAddr)
b := [4]byte{} target := protoAddrToNetAddr(m.RelayToAddr)
binary.BigEndian.PutUint32(b[:], m.RelayFromIp)
from := netip.AddrFrom4(b)
binary.BigEndian.PutUint32(b[:], m.RelayToIp)
target := netip.AddrFrom4(b)
logMsg := rm.l.WithFields(logrus.Fields{ logMsg := rm.l.WithFields(logrus.Fields{
"relayFrom": from, "relayFrom": from,
"relayTo": target, "relayTo": target,
"initiatorRelayIndex": m.InitiatorRelayIndex, "initiatorRelayIndex": m.InitiatorRelayIndex,
"vpnIp": h.vpnIp}) "vpnAddrs": h.vpnAddrs})
logMsg.Info("handleCreateRelayRequest") logMsg.Info("handleCreateRelayRequest")
// Is the source of the relay me? This should never happen, but did happen due to // Is the source of the relay me? This should never happen, but did happen due to
// an issue migrating relays over to newly re-handshaked host info objects. // an issue migrating relays over to newly re-handshaked host info objects.
if from == f.myVpnNet.Addr() { _, found := f.myVpnAddrsTable.Lookup(from)
if found {
logMsg.WithField("myIP", from).Error("Discarding relay request from myself") logMsg.WithField("myIP", from).Error("Discarding relay request from myself")
return return
} }
// Is the target of the relay me? // Is the target of the relay me?
if target == f.myVpnNet.Addr() { _, found = f.myVpnAddrsTable.Lookup(target)
if found {
existingRelay, ok := h.relayState.QueryRelayForByIp(from) existingRelay, ok := h.relayState.QueryRelayForByIp(from)
if ok { if ok {
switch existingRelay.State { switch existingRelay.State {
@ -215,6 +267,21 @@ func (rm *relayManager) handleCreateRelayRequest(h *HostInfo, f *Interface, m *N
"existingRemoteIndex": existingRelay.RemoteIndex}).Error("Existing relay mismatch with CreateRelayRequest") "existingRemoteIndex": existingRelay.RemoteIndex}).Error("Existing relay mismatch with CreateRelayRequest")
return return
} }
case Disestablished:
if existingRelay.RemoteIndex != m.InitiatorRelayIndex {
// We got a brand new Relay request, because its index is different than what we saw before.
// This should never happen. The peer should never change an index, once created.
logMsg.WithFields(logrus.Fields{
"existingRemoteIndex": existingRelay.RemoteIndex}).Error("Existing relay mismatch with CreateRelayRequest")
return
}
// Mark the relay as 'Established' because it's safe to use again
h.relayState.UpdateRelayForByIpState(from, Established)
case PeerRequested:
// I should never be in this state, because I am terminal, not forwarding.
logMsg.WithFields(logrus.Fields{
"existingRemoteIndex": existingRelay.RemoteIndex,
"state": existingRelay.State}).Error("Unexpected Relay State found")
} }
} else { } else {
_, err := AddRelay(rm.l, h, f.hostMap, from, &m.InitiatorRelayIndex, TerminalType, Established) _, err := AddRelay(rm.l, h, f.hostMap, from, &m.InitiatorRelayIndex, TerminalType, Established)
@ -226,21 +293,26 @@ func (rm *relayManager) handleCreateRelayRequest(h *HostInfo, f *Interface, m *N
relay, ok := h.relayState.QueryRelayForByIp(from) relay, ok := h.relayState.QueryRelayForByIp(from)
if !ok { if !ok {
logMsg.Error("Relay State not found") logMsg.WithField("from", from).Error("Relay State not found")
return return
} }
//TODO: IPV6-WORK
fromB := from.As4()
targetB := target.As4()
resp := NebulaControl{ resp := NebulaControl{
Type: NebulaControl_CreateRelayResponse, Type: NebulaControl_CreateRelayResponse,
ResponderRelayIndex: relay.LocalIndex, ResponderRelayIndex: relay.LocalIndex,
InitiatorRelayIndex: relay.RemoteIndex, InitiatorRelayIndex: relay.RemoteIndex,
RelayFromIp: binary.BigEndian.Uint32(fromB[:]),
RelayToIp: binary.BigEndian.Uint32(targetB[:]),
} }
if v == cert.Version1 {
b := from.As4()
resp.OldRelayFromAddr = binary.BigEndian.Uint32(b[:])
b = target.As4()
resp.OldRelayToAddr = binary.BigEndian.Uint32(b[:])
} else {
resp.RelayFromAddr = netAddrToProtoAddr(from)
resp.RelayToAddr = netAddrToProtoAddr(target)
}
msg, err := resp.Marshal() msg, err := resp.Marshal()
if err != nil { if err != nil {
logMsg. logMsg.
@ -248,12 +320,11 @@ func (rm *relayManager) handleCreateRelayRequest(h *HostInfo, f *Interface, m *N
} else { } else {
f.SendMessageToHostInfo(header.Control, 0, h, msg, make([]byte, 12), make([]byte, mtu)) f.SendMessageToHostInfo(header.Control, 0, h, msg, make([]byte, 12), make([]byte, mtu))
rm.l.WithFields(logrus.Fields{ rm.l.WithFields(logrus.Fields{
//TODO: IPV6-WORK, this used to use the resp object but I am getting lazy now
"relayFrom": from, "relayFrom": from,
"relayTo": target, "relayTo": target,
"initiatorRelayIndex": resp.InitiatorRelayIndex, "initiatorRelayIndex": resp.InitiatorRelayIndex,
"responderRelayIndex": resp.ResponderRelayIndex, "responderRelayIndex": resp.ResponderRelayIndex,
"vpnIp": h.vpnIp}). "vpnAddrs": h.vpnAddrs}).
Info("send CreateRelayResponse") Info("send CreateRelayResponse")
} }
return return
@ -262,7 +333,7 @@ func (rm *relayManager) handleCreateRelayRequest(h *HostInfo, f *Interface, m *N
if !rm.GetAmRelay() { if !rm.GetAmRelay() {
return return
} }
peer := rm.hostmap.QueryVpnIp(target) peer := rm.hostmap.QueryVpnAddr(target)
if peer == nil { if peer == nil {
// Try to establish a connection to this host. If we get a future relay request, // Try to establish a connection to this host. If we get a future relay request,
// we'll be ready! // we'll be ready!
@ -273,104 +344,69 @@ func (rm *relayManager) handleCreateRelayRequest(h *HostInfo, f *Interface, m *N
// Only create relays to peers for whom I have a direct connection // Only create relays to peers for whom I have a direct connection
return return
} }
sendCreateRequest := false
var index uint32 var index uint32
var err error var err error
targetRelay, ok := peer.relayState.QueryRelayForByIp(from) targetRelay, ok := peer.relayState.QueryRelayForByIp(from)
if ok { if ok {
index = targetRelay.LocalIndex index = targetRelay.LocalIndex
if targetRelay.State == Requested {
sendCreateRequest = true
}
} else { } else {
// Allocate an index in the hostMap for this relay peer // Allocate an index in the hostMap for this relay peer
index, err = AddRelay(rm.l, peer, f.hostMap, from, nil, ForwardingType, Requested) index, err = AddRelay(rm.l, peer, f.hostMap, from, nil, ForwardingType, Requested)
if err != nil { if err != nil {
return return
} }
sendCreateRequest = true
} }
if sendCreateRequest { peer.relayState.UpdateRelayForByIpState(from, Requested)
//TODO: IPV6-WORK // Send a CreateRelayRequest to the peer.
fromB := h.vpnIp.As4() req := NebulaControl{
targetB := target.As4() Type: NebulaControl_CreateRelayRequest,
InitiatorRelayIndex: index,
}
// Send a CreateRelayRequest to the peer. if v == cert.Version1 {
req := NebulaControl{ if !h.vpnAddrs[0].Is4() {
Type: NebulaControl_CreateRelayRequest, rm.l.WithField("relayFrom", h.vpnAddrs[0]).
InitiatorRelayIndex: index, WithField("relayTo", target).
RelayFromIp: binary.BigEndian.Uint32(fromB[:]), WithField("initiatorRelayIndex", req.InitiatorRelayIndex).
RelayToIp: binary.BigEndian.Uint32(targetB[:]), WithField("responderRelayIndex", req.ResponderRelayIndex).
} WithField("vpnAddr", target).
msg, err := req.Marshal() Error("Refusing to CreateRelayRequest for a v1 relay with an ipv6 address")
if err != nil { return
logMsg.
WithError(err).Error("relayManager Failed to marshal Control message to create relay")
} else {
f.SendMessageToHostInfo(header.Control, 0, peer, msg, make([]byte, 12), make([]byte, mtu))
rm.l.WithFields(logrus.Fields{
//TODO: IPV6-WORK another lazy used to use the req object
"relayFrom": h.vpnIp,
"relayTo": target,
"initiatorRelayIndex": req.InitiatorRelayIndex,
"responderRelayIndex": req.ResponderRelayIndex,
"vpnIp": target}).
Info("send CreateRelayRequest")
} }
b := h.vpnAddrs[0].As4()
req.OldRelayFromAddr = binary.BigEndian.Uint32(b[:])
b = target.As4()
req.OldRelayToAddr = binary.BigEndian.Uint32(b[:])
} else {
req.RelayFromAddr = netAddrToProtoAddr(h.vpnAddrs[0])
req.RelayToAddr = netAddrToProtoAddr(target)
} }
msg, err := req.Marshal()
if err != nil {
logMsg.
WithError(err).Error("relayManager Failed to marshal Control message to create relay")
} else {
f.SendMessageToHostInfo(header.Control, 0, peer, msg, make([]byte, 12), make([]byte, mtu))
rm.l.WithFields(logrus.Fields{
"relayFrom": h.vpnAddrs[0],
"relayTo": target,
"initiatorRelayIndex": req.InitiatorRelayIndex,
"responderRelayIndex": req.ResponderRelayIndex,
"vpnAddr": target}).
Info("send CreateRelayRequest")
}
// Also track the half-created Relay state just received // Also track the half-created Relay state just received
relay, ok := h.relayState.QueryRelayForByIp(target) _, ok = h.relayState.QueryRelayForByIp(target)
if !ok { if !ok {
// Add the relay _, err := AddRelay(rm.l, h, f.hostMap, target, &m.InitiatorRelayIndex, ForwardingType, PeerRequested)
state := PeerRequested
if targetRelay != nil && targetRelay.State == Established {
state = Established
}
_, err := AddRelay(rm.l, h, f.hostMap, target, &m.InitiatorRelayIndex, ForwardingType, state)
if err != nil { if err != nil {
logMsg. logMsg.
WithError(err).Error("relayManager Failed to allocate a local index for relay") WithError(err).Error("relayManager Failed to allocate a local index for relay")
return return
} }
} else {
switch relay.State {
case Established:
if relay.RemoteIndex != m.InitiatorRelayIndex {
// We got a brand new Relay request, because its index is different than what we saw before.
// This should never happen. The peer should never change an index, once created.
logMsg.WithFields(logrus.Fields{
"existingRemoteIndex": relay.RemoteIndex}).Error("Existing relay mismatch with CreateRelayRequest")
return
}
//TODO: IPV6-WORK
fromB := h.vpnIp.As4()
targetB := target.As4()
resp := NebulaControl{
Type: NebulaControl_CreateRelayResponse,
ResponderRelayIndex: relay.LocalIndex,
InitiatorRelayIndex: relay.RemoteIndex,
RelayFromIp: binary.BigEndian.Uint32(fromB[:]),
RelayToIp: binary.BigEndian.Uint32(targetB[:]),
}
msg, err := resp.Marshal()
if err != nil {
rm.l.
WithError(err).Error("relayManager Failed to marshal Control CreateRelayResponse message to create relay")
} else {
f.SendMessageToHostInfo(header.Control, 0, h, msg, make([]byte, 12), make([]byte, mtu))
rm.l.WithFields(logrus.Fields{
//TODO: IPV6-WORK more lazy, used to use resp object
"relayFrom": h.vpnIp,
"relayTo": target,
"initiatorRelayIndex": resp.InitiatorRelayIndex,
"responderRelayIndex": resp.ResponderRelayIndex,
"vpnIp": h.vpnIp}).
Info("send CreateRelayResponse")
}
case Requested:
// Keep waiting for the other relay to complete
}
} }
} }
} }

View File

@ -4,6 +4,7 @@ import (
"context" "context"
"net" "net"
"net/netip" "net/netip"
"slices"
"sort" "sort"
"strconv" "strconv"
"sync" "sync"
@ -17,8 +18,8 @@ import (
type forEachFunc func(addr netip.AddrPort, preferred bool) type forEachFunc func(addr netip.AddrPort, preferred bool)
// The checkFuncs here are to simplify bulk importing LH query response logic into a single function (reset slice and iterate) // The checkFuncs here are to simplify bulk importing LH query response logic into a single function (reset slice and iterate)
type checkFuncV4 func(vpnIp netip.Addr, to *Ip4AndPort) bool type checkFuncV4 func(vpnIp netip.Addr, to *V4AddrPort) bool
type checkFuncV6 func(vpnIp netip.Addr, to *Ip6AndPort) bool type checkFuncV6 func(vpnIp netip.Addr, to *V6AddrPort) bool
// CacheMap is a struct that better represents the lighthouse cache for humans // CacheMap is a struct that better represents the lighthouse cache for humans
// The string key is the owners vpnIp // The string key is the owners vpnIp
@ -32,9 +33,6 @@ type Cache struct {
Relay []netip.Addr `json:"relay"` Relay []netip.Addr `json:"relay"`
} }
//TODO: Seems like we should plop static host entries in here too since the are protected by the lighthouse from deletion
// We will never clean learned/reported information for them as it stands today
// cache is an internal struct that splits v4 and v6 addresses inside the cache map // cache is an internal struct that splits v4 and v6 addresses inside the cache map
type cache struct { type cache struct {
v4 *cacheV4 v4 *cacheV4
@ -48,14 +46,14 @@ type cacheRelay struct {
// cacheV4 stores learned and reported ipv4 records under cache // cacheV4 stores learned and reported ipv4 records under cache
type cacheV4 struct { type cacheV4 struct {
learned *Ip4AndPort learned *V4AddrPort
reported []*Ip4AndPort reported []*V4AddrPort
} }
// cacheV4 stores learned and reported ipv6 records under cache // cacheV4 stores learned and reported ipv6 records under cache
type cacheV6 struct { type cacheV6 struct {
learned *Ip6AndPort learned *V6AddrPort
reported []*Ip6AndPort reported []*V6AddrPort
} }
type hostnamePort struct { type hostnamePort struct {
@ -170,7 +168,7 @@ func (hr *hostnamesResults) Cancel() {
} }
} }
func (hr *hostnamesResults) GetIPs() []netip.AddrPort { func (hr *hostnamesResults) GetAddrs() []netip.AddrPort {
var retSlice []netip.AddrPort var retSlice []netip.AddrPort
if hr != nil { if hr != nil {
p := hr.ips.Load() p := hr.ips.Load()
@ -189,6 +187,9 @@ type RemoteList struct {
// Every interaction with internals requires a lock! // Every interaction with internals requires a lock!
sync.RWMutex sync.RWMutex
// The full list of vpn addresses assigned to this host
vpnAddrs []netip.Addr
// A deduplicated set of addresses. Any accessor should lock beforehand. // A deduplicated set of addresses. Any accessor should lock beforehand.
addrs []netip.AddrPort addrs []netip.AddrPort
@ -212,13 +213,16 @@ type RemoteList struct {
} }
// NewRemoteList creates a new empty RemoteList // NewRemoteList creates a new empty RemoteList
func NewRemoteList(shouldAdd func(netip.Addr) bool) *RemoteList { func NewRemoteList(vpnAddrs []netip.Addr, shouldAdd func(netip.Addr) bool) *RemoteList {
return &RemoteList{ r := &RemoteList{
vpnAddrs: make([]netip.Addr, len(vpnAddrs)),
addrs: make([]netip.AddrPort, 0), addrs: make([]netip.AddrPort, 0),
relays: make([]netip.Addr, 0), relays: make([]netip.Addr, 0),
cache: make(map[netip.Addr]*cache), cache: make(map[netip.Addr]*cache),
shouldAdd: shouldAdd, shouldAdd: shouldAdd,
} }
copy(r.vpnAddrs, vpnAddrs)
return r
} }
func (r *RemoteList) unlockedSetHostnamesResults(hr *hostnamesResults) { func (r *RemoteList) unlockedSetHostnamesResults(hr *hostnamesResults) {
@ -268,14 +272,13 @@ func (r *RemoteList) CopyAddrs(preferredRanges []netip.Prefix) []netip.AddrPort
// LearnRemote locks and sets the learned slot for the owner vpn ip to the provided addr // LearnRemote locks and sets the learned slot for the owner vpn ip to the provided addr
// Currently this is only needed when HostInfo.SetRemote is called as that should cover both handshaking and roaming. // Currently this is only needed when HostInfo.SetRemote is called as that should cover both handshaking and roaming.
// It will mark the deduplicated address list as dirty, so do not call it unless new information is available // It will mark the deduplicated address list as dirty, so do not call it unless new information is available
// TODO: this needs to support the allow list list
func (r *RemoteList) LearnRemote(ownerVpnIp netip.Addr, remote netip.AddrPort) { func (r *RemoteList) LearnRemote(ownerVpnIp netip.Addr, remote netip.AddrPort) {
r.Lock() r.Lock()
defer r.Unlock() defer r.Unlock()
if remote.Addr().Is4() { if remote.Addr().Is4() {
r.unlockedSetLearnedV4(ownerVpnIp, NewIp4AndPortFromNetIP(remote.Addr(), remote.Port())) r.unlockedSetLearnedV4(ownerVpnIp, netAddrToProtoV4AddrPort(remote.Addr(), remote.Port()))
} else { } else {
r.unlockedSetLearnedV6(ownerVpnIp, NewIp6AndPortFromNetIP(remote.Addr(), remote.Port())) r.unlockedSetLearnedV6(ownerVpnIp, netAddrToProtoV6AddrPort(remote.Addr(), remote.Port()))
} }
} }
@ -304,21 +307,21 @@ func (r *RemoteList) CopyCache() *CacheMap {
if mc.v4 != nil { if mc.v4 != nil {
if mc.v4.learned != nil { if mc.v4.learned != nil {
c.Learned = append(c.Learned, AddrPortFromIp4AndPort(mc.v4.learned)) c.Learned = append(c.Learned, protoV4AddrPortToNetAddrPort(mc.v4.learned))
} }
for _, a := range mc.v4.reported { for _, a := range mc.v4.reported {
c.Reported = append(c.Reported, AddrPortFromIp4AndPort(a)) c.Reported = append(c.Reported, protoV4AddrPortToNetAddrPort(a))
} }
} }
if mc.v6 != nil { if mc.v6 != nil {
if mc.v6.learned != nil { if mc.v6.learned != nil {
c.Learned = append(c.Learned, AddrPortFromIp6AndPort(mc.v6.learned)) c.Learned = append(c.Learned, protoV6AddrPortToNetAddrPort(mc.v6.learned))
} }
for _, a := range mc.v6.reported { for _, a := range mc.v6.reported {
c.Reported = append(c.Reported, AddrPortFromIp6AndPort(a)) c.Reported = append(c.Reported, protoV6AddrPortToNetAddrPort(a))
} }
} }
@ -379,7 +382,6 @@ func (r *RemoteList) Rebuild(preferredRanges []netip.Prefix) {
defer r.Unlock() defer r.Unlock()
// Only rebuild if the cache changed // Only rebuild if the cache changed
//TODO: shouldRebuild is probably pointless as we don't check for actual change when lighthouse updates come in
if r.shouldRebuild { if r.shouldRebuild {
r.unlockedCollect() r.unlockedCollect()
r.shouldRebuild = false r.shouldRebuild = false
@ -401,14 +403,14 @@ func (r *RemoteList) unlockedIsBad(remote netip.AddrPort) bool {
// unlockedSetLearnedV4 assumes you have the write lock and sets the current learned address for this owner and marks the // unlockedSetLearnedV4 assumes you have the write lock and sets the current learned address for this owner and marks the
// deduplicated address list as dirty // deduplicated address list as dirty
func (r *RemoteList) unlockedSetLearnedV4(ownerVpnIp netip.Addr, to *Ip4AndPort) { func (r *RemoteList) unlockedSetLearnedV4(ownerVpnIp netip.Addr, to *V4AddrPort) {
r.shouldRebuild = true r.shouldRebuild = true
r.unlockedGetOrMakeV4(ownerVpnIp).learned = to r.unlockedGetOrMakeV4(ownerVpnIp).learned = to
} }
// unlockedSetV4 assumes you have the write lock and resets the reported list of ips for this owner to the list provided // unlockedSetV4 assumes you have the write lock and resets the reported list of ips for this owner to the list provided
// and marks the deduplicated address list as dirty // and marks the deduplicated address list as dirty
func (r *RemoteList) unlockedSetV4(ownerVpnIp, vpnIp netip.Addr, to []*Ip4AndPort, check checkFuncV4) { func (r *RemoteList) unlockedSetV4(ownerVpnIp, vpnIp netip.Addr, to []*V4AddrPort, check checkFuncV4) {
r.shouldRebuild = true r.shouldRebuild = true
c := r.unlockedGetOrMakeV4(ownerVpnIp) c := r.unlockedGetOrMakeV4(ownerVpnIp)
@ -423,7 +425,7 @@ func (r *RemoteList) unlockedSetV4(ownerVpnIp, vpnIp netip.Addr, to []*Ip4AndPor
} }
} }
func (r *RemoteList) unlockedSetRelay(ownerVpnIp, vpnIp netip.Addr, to []netip.Addr) { func (r *RemoteList) unlockedSetRelay(ownerVpnIp netip.Addr, to []netip.Addr) {
r.shouldRebuild = true r.shouldRebuild = true
c := r.unlockedGetOrMakeRelay(ownerVpnIp) c := r.unlockedGetOrMakeRelay(ownerVpnIp)
@ -436,12 +438,12 @@ func (r *RemoteList) unlockedSetRelay(ownerVpnIp, vpnIp netip.Addr, to []netip.A
// unlockedPrependV4 assumes you have the write lock and prepends the address in the reported list for this owner // unlockedPrependV4 assumes you have the write lock and prepends the address in the reported list for this owner
// This is only useful for establishing static hosts // This is only useful for establishing static hosts
func (r *RemoteList) unlockedPrependV4(ownerVpnIp netip.Addr, to *Ip4AndPort) { func (r *RemoteList) unlockedPrependV4(ownerVpnIp netip.Addr, to *V4AddrPort) {
r.shouldRebuild = true r.shouldRebuild = true
c := r.unlockedGetOrMakeV4(ownerVpnIp) c := r.unlockedGetOrMakeV4(ownerVpnIp)
// We are doing the easy append because this is rarely called // We are doing the easy append because this is rarely called
c.reported = append([]*Ip4AndPort{to}, c.reported...) c.reported = append([]*V4AddrPort{to}, c.reported...)
if len(c.reported) > MaxRemotes { if len(c.reported) > MaxRemotes {
c.reported = c.reported[:MaxRemotes] c.reported = c.reported[:MaxRemotes]
} }
@ -449,14 +451,14 @@ func (r *RemoteList) unlockedPrependV4(ownerVpnIp netip.Addr, to *Ip4AndPort) {
// unlockedSetLearnedV6 assumes you have the write lock and sets the current learned address for this owner and marks the // unlockedSetLearnedV6 assumes you have the write lock and sets the current learned address for this owner and marks the
// deduplicated address list as dirty // deduplicated address list as dirty
func (r *RemoteList) unlockedSetLearnedV6(ownerVpnIp netip.Addr, to *Ip6AndPort) { func (r *RemoteList) unlockedSetLearnedV6(ownerVpnIp netip.Addr, to *V6AddrPort) {
r.shouldRebuild = true r.shouldRebuild = true
r.unlockedGetOrMakeV6(ownerVpnIp).learned = to r.unlockedGetOrMakeV6(ownerVpnIp).learned = to
} }
// unlockedSetV6 assumes you have the write lock and resets the reported list of ips for this owner to the list provided // unlockedSetV6 assumes you have the write lock and resets the reported list of ips for this owner to the list provided
// and marks the deduplicated address list as dirty // and marks the deduplicated address list as dirty
func (r *RemoteList) unlockedSetV6(ownerVpnIp, vpnIp netip.Addr, to []*Ip6AndPort, check checkFuncV6) { func (r *RemoteList) unlockedSetV6(ownerVpnIp, vpnIp netip.Addr, to []*V6AddrPort, check checkFuncV6) {
r.shouldRebuild = true r.shouldRebuild = true
c := r.unlockedGetOrMakeV6(ownerVpnIp) c := r.unlockedGetOrMakeV6(ownerVpnIp)
@ -473,12 +475,12 @@ func (r *RemoteList) unlockedSetV6(ownerVpnIp, vpnIp netip.Addr, to []*Ip6AndPor
// unlockedPrependV6 assumes you have the write lock and prepends the address in the reported list for this owner // unlockedPrependV6 assumes you have the write lock and prepends the address in the reported list for this owner
// This is only useful for establishing static hosts // This is only useful for establishing static hosts
func (r *RemoteList) unlockedPrependV6(ownerVpnIp netip.Addr, to *Ip6AndPort) { func (r *RemoteList) unlockedPrependV6(ownerVpnIp netip.Addr, to *V6AddrPort) {
r.shouldRebuild = true r.shouldRebuild = true
c := r.unlockedGetOrMakeV6(ownerVpnIp) c := r.unlockedGetOrMakeV6(ownerVpnIp)
// We are doing the easy append because this is rarely called // We are doing the easy append because this is rarely called
c.reported = append([]*Ip6AndPort{to}, c.reported...) c.reported = append([]*V6AddrPort{to}, c.reported...)
if len(c.reported) > MaxRemotes { if len(c.reported) > MaxRemotes {
c.reported = c.reported[:MaxRemotes] c.reported = c.reported[:MaxRemotes]
} }
@ -536,14 +538,14 @@ func (r *RemoteList) unlockedCollect() {
for _, c := range r.cache { for _, c := range r.cache {
if c.v4 != nil { if c.v4 != nil {
if c.v4.learned != nil { if c.v4.learned != nil {
u := AddrPortFromIp4AndPort(c.v4.learned) u := protoV4AddrPortToNetAddrPort(c.v4.learned)
if !r.unlockedIsBad(u) { if !r.unlockedIsBad(u) {
addrs = append(addrs, u) addrs = append(addrs, u)
} }
} }
for _, v := range c.v4.reported { for _, v := range c.v4.reported {
u := AddrPortFromIp4AndPort(v) u := protoV4AddrPortToNetAddrPort(v)
if !r.unlockedIsBad(u) { if !r.unlockedIsBad(u) {
addrs = append(addrs, u) addrs = append(addrs, u)
} }
@ -552,14 +554,14 @@ func (r *RemoteList) unlockedCollect() {
if c.v6 != nil { if c.v6 != nil {
if c.v6.learned != nil { if c.v6.learned != nil {
u := AddrPortFromIp6AndPort(c.v6.learned) u := protoV6AddrPortToNetAddrPort(c.v6.learned)
if !r.unlockedIsBad(u) { if !r.unlockedIsBad(u) {
addrs = append(addrs, u) addrs = append(addrs, u)
} }
} }
for _, v := range c.v6.reported { for _, v := range c.v6.reported {
u := AddrPortFromIp6AndPort(v) u := protoV6AddrPortToNetAddrPort(v)
if !r.unlockedIsBad(u) { if !r.unlockedIsBad(u) {
addrs = append(addrs, u) addrs = append(addrs, u)
} }
@ -573,7 +575,7 @@ func (r *RemoteList) unlockedCollect() {
} }
} }
dnsAddrs := r.hr.GetIPs() dnsAddrs := r.hr.GetAddrs()
for _, addr := range dnsAddrs { for _, addr := range dnsAddrs {
if r.shouldAdd == nil || r.shouldAdd(addr.Addr()) { if r.shouldAdd == nil || r.shouldAdd(addr.Addr()) {
if !r.unlockedIsBad(addr) { if !r.unlockedIsBad(addr) {
@ -589,6 +591,21 @@ func (r *RemoteList) unlockedCollect() {
// unlockedSort assumes you have the write lock and performs the deduping and sorting of the address list // unlockedSort assumes you have the write lock and performs the deduping and sorting of the address list
func (r *RemoteList) unlockedSort(preferredRanges []netip.Prefix) { func (r *RemoteList) unlockedSort(preferredRanges []netip.Prefix) {
// Use a map to deduplicate any relay addresses
dedupedRelays := map[netip.Addr]struct{}{}
for _, relay := range r.relays {
dedupedRelays[relay] = struct{}{}
}
r.relays = r.relays[:0]
for relay := range dedupedRelays {
r.relays = append(r.relays, relay)
}
// Put them in a somewhat consistent order after de-duplication
slices.SortFunc(r.relays, func(a, b netip.Addr) int {
return a.Compare(b)
})
// Now the addrs
n := len(r.addrs) n := len(r.addrs)
if n < 2 { if n < 2 {
return return
@ -687,7 +704,6 @@ func minInt(a, b int) int {
// isPreferred returns true of the ip is contained in the preferredRanges list // isPreferred returns true of the ip is contained in the preferredRanges list
func isPreferred(ip netip.Addr, preferredRanges []netip.Prefix) bool { func isPreferred(ip netip.Addr, preferredRanges []netip.Prefix) bool {
//TODO: this would be better in a CIDR6Tree
for _, p := range preferredRanges { for _, p := range preferredRanges {
if p.Contains(ip) { if p.Contains(ip) {
return true return true

View File

@ -9,11 +9,11 @@ import (
) )
func TestRemoteList_Rebuild(t *testing.T) { func TestRemoteList_Rebuild(t *testing.T) {
rl := NewRemoteList(nil) rl := NewRemoteList([]netip.Addr{netip.MustParseAddr("0.0.0.0")}, nil)
rl.unlockedSetV4( rl.unlockedSetV4(
netip.MustParseAddr("0.0.0.0"), netip.MustParseAddr("0.0.0.0"),
netip.MustParseAddr("0.0.0.0"), netip.MustParseAddr("0.0.0.0"),
[]*Ip4AndPort{ []*V4AddrPort{
newIp4AndPortFromString("70.199.182.92:1475"), // this is duped newIp4AndPortFromString("70.199.182.92:1475"), // this is duped
newIp4AndPortFromString("172.17.0.182:10101"), newIp4AndPortFromString("172.17.0.182:10101"),
newIp4AndPortFromString("172.17.1.1:10101"), // this is duped newIp4AndPortFromString("172.17.1.1:10101"), // this is duped
@ -25,20 +25,30 @@ func TestRemoteList_Rebuild(t *testing.T) {
newIp4AndPortFromString("70.199.182.92:1476"), // almost dupe of 0 with a diff port newIp4AndPortFromString("70.199.182.92:1476"), // almost dupe of 0 with a diff port
newIp4AndPortFromString("70.199.182.92:1475"), // this is a dupe newIp4AndPortFromString("70.199.182.92:1475"), // this is a dupe
}, },
func(netip.Addr, *Ip4AndPort) bool { return true }, func(netip.Addr, *V4AddrPort) bool { return true },
) )
rl.unlockedSetV6( rl.unlockedSetV6(
netip.MustParseAddr("0.0.0.1"), netip.MustParseAddr("0.0.0.1"),
netip.MustParseAddr("0.0.0.1"), netip.MustParseAddr("0.0.0.1"),
[]*Ip6AndPort{ []*V6AddrPort{
newIp6AndPortFromString("[1::1]:1"), // this is duped newIp6AndPortFromString("[1::1]:1"), // this is duped
newIp6AndPortFromString("[1::1]:2"), // almost dupe of 0 with a diff port, also gets duped newIp6AndPortFromString("[1::1]:2"), // almost dupe of 0 with a diff port, also gets duped
newIp6AndPortFromString("[1:100::1]:1"), newIp6AndPortFromString("[1:100::1]:1"),
newIp6AndPortFromString("[1::1]:1"), // this is a dupe newIp6AndPortFromString("[1::1]:1"), // this is a dupe
newIp6AndPortFromString("[1::1]:2"), // this is a dupe newIp6AndPortFromString("[1::1]:2"), // this is a dupe
}, },
func(netip.Addr, *Ip6AndPort) bool { return true }, func(netip.Addr, *V6AddrPort) bool { return true },
)
rl.unlockedSetRelay(
netip.MustParseAddr("0.0.0.1"),
[]netip.Addr{
netip.MustParseAddr("1::1"),
netip.MustParseAddr("1.2.3.4"),
netip.MustParseAddr("1.2.3.4"),
netip.MustParseAddr("1::1"),
},
) )
rl.Rebuild([]netip.Prefix{}) rl.Rebuild([]netip.Prefix{})
@ -76,6 +86,11 @@ func TestRemoteList_Rebuild(t *testing.T) {
assert.Equal(t, "[1::1]:2", rl.addrs[8].String()) assert.Equal(t, "[1::1]:2", rl.addrs[8].String())
assert.Equal(t, "[1:100::1]:1", rl.addrs[9].String()) assert.Equal(t, "[1:100::1]:1", rl.addrs[9].String())
// assert relay deduplicated
assert.Len(t, rl.relays, 2)
assert.Equal(t, "1.2.3.4", rl.relays[0].String())
assert.Equal(t, "1::1", rl.relays[1].String())
// Ensure we can hoist a specific ipv4 range over anything else // Ensure we can hoist a specific ipv4 range over anything else
rl.Rebuild([]netip.Prefix{netip.MustParsePrefix("172.17.0.0/16")}) rl.Rebuild([]netip.Prefix{netip.MustParsePrefix("172.17.0.0/16")})
assert.Len(t, rl.addrs, 10, "addrs contains too many entries") assert.Len(t, rl.addrs, 10, "addrs contains too many entries")
@ -98,11 +113,11 @@ func TestRemoteList_Rebuild(t *testing.T) {
} }
func BenchmarkFullRebuild(b *testing.B) { func BenchmarkFullRebuild(b *testing.B) {
rl := NewRemoteList(nil) rl := NewRemoteList([]netip.Addr{netip.MustParseAddr("0.0.0.0")}, nil)
rl.unlockedSetV4( rl.unlockedSetV4(
netip.MustParseAddr("0.0.0.0"), netip.MustParseAddr("0.0.0.0"),
netip.MustParseAddr("0.0.0.0"), netip.MustParseAddr("0.0.0.0"),
[]*Ip4AndPort{ []*V4AddrPort{
newIp4AndPortFromString("70.199.182.92:1475"), newIp4AndPortFromString("70.199.182.92:1475"),
newIp4AndPortFromString("172.17.0.182:10101"), newIp4AndPortFromString("172.17.0.182:10101"),
newIp4AndPortFromString("172.17.1.1:10101"), newIp4AndPortFromString("172.17.1.1:10101"),
@ -112,19 +127,19 @@ func BenchmarkFullRebuild(b *testing.B) {
newIp4AndPortFromString("172.17.1.1:10101"), // this is a dupe newIp4AndPortFromString("172.17.1.1:10101"), // this is a dupe
newIp4AndPortFromString("70.199.182.92:1476"), // dupe of 0 with a diff port newIp4AndPortFromString("70.199.182.92:1476"), // dupe of 0 with a diff port
}, },
func(netip.Addr, *Ip4AndPort) bool { return true }, func(netip.Addr, *V4AddrPort) bool { return true },
) )
rl.unlockedSetV6( rl.unlockedSetV6(
netip.MustParseAddr("0.0.0.0"), netip.MustParseAddr("0.0.0.0"),
netip.MustParseAddr("0.0.0.0"), netip.MustParseAddr("0.0.0.0"),
[]*Ip6AndPort{ []*V6AddrPort{
newIp6AndPortFromString("[1::1]:1"), newIp6AndPortFromString("[1::1]:1"),
newIp6AndPortFromString("[1::1]:2"), // dupe of 0 with a diff port newIp6AndPortFromString("[1::1]:2"), // dupe of 0 with a diff port
newIp6AndPortFromString("[1:100::1]:1"), newIp6AndPortFromString("[1:100::1]:1"),
newIp6AndPortFromString("[1::1]:1"), // this is a dupe newIp6AndPortFromString("[1::1]:1"), // this is a dupe
}, },
func(netip.Addr, *Ip6AndPort) bool { return true }, func(netip.Addr, *V6AddrPort) bool { return true },
) )
b.Run("no preferred", func(b *testing.B) { b.Run("no preferred", func(b *testing.B) {
@ -160,11 +175,11 @@ func BenchmarkFullRebuild(b *testing.B) {
} }
func BenchmarkSortRebuild(b *testing.B) { func BenchmarkSortRebuild(b *testing.B) {
rl := NewRemoteList(nil) rl := NewRemoteList([]netip.Addr{netip.MustParseAddr("0.0.0.0")}, nil)
rl.unlockedSetV4( rl.unlockedSetV4(
netip.MustParseAddr("0.0.0.0"), netip.MustParseAddr("0.0.0.0"),
netip.MustParseAddr("0.0.0.0"), netip.MustParseAddr("0.0.0.0"),
[]*Ip4AndPort{ []*V4AddrPort{
newIp4AndPortFromString("70.199.182.92:1475"), newIp4AndPortFromString("70.199.182.92:1475"),
newIp4AndPortFromString("172.17.0.182:10101"), newIp4AndPortFromString("172.17.0.182:10101"),
newIp4AndPortFromString("172.17.1.1:10101"), newIp4AndPortFromString("172.17.1.1:10101"),
@ -174,19 +189,19 @@ func BenchmarkSortRebuild(b *testing.B) {
newIp4AndPortFromString("172.17.1.1:10101"), // this is a dupe newIp4AndPortFromString("172.17.1.1:10101"), // this is a dupe
newIp4AndPortFromString("70.199.182.92:1476"), // dupe of 0 with a diff port newIp4AndPortFromString("70.199.182.92:1476"), // dupe of 0 with a diff port
}, },
func(netip.Addr, *Ip4AndPort) bool { return true }, func(netip.Addr, *V4AddrPort) bool { return true },
) )
rl.unlockedSetV6( rl.unlockedSetV6(
netip.MustParseAddr("0.0.0.0"), netip.MustParseAddr("0.0.0.0"),
netip.MustParseAddr("0.0.0.0"), netip.MustParseAddr("0.0.0.0"),
[]*Ip6AndPort{ []*V6AddrPort{
newIp6AndPortFromString("[1::1]:1"), newIp6AndPortFromString("[1::1]:1"),
newIp6AndPortFromString("[1::1]:2"), // dupe of 0 with a diff port newIp6AndPortFromString("[1::1]:2"), // dupe of 0 with a diff port
newIp6AndPortFromString("[1:100::1]:1"), newIp6AndPortFromString("[1:100::1]:1"),
newIp6AndPortFromString("[1::1]:1"), // this is a dupe newIp6AndPortFromString("[1::1]:1"), // this is a dupe
}, },
func(netip.Addr, *Ip6AndPort) bool { return true }, func(netip.Addr, *V6AddrPort) bool { return true },
) )
b.Run("no preferred", func(b *testing.B) { b.Run("no preferred", func(b *testing.B) {
@ -224,19 +239,19 @@ func BenchmarkSortRebuild(b *testing.B) {
}) })
} }
func newIp4AndPortFromString(s string) *Ip4AndPort { func newIp4AndPortFromString(s string) *V4AddrPort {
a := netip.MustParseAddrPort(s) a := netip.MustParseAddrPort(s)
v4Addr := a.Addr().As4() v4Addr := a.Addr().As4()
return &Ip4AndPort{ return &V4AddrPort{
Ip: binary.BigEndian.Uint32(v4Addr[:]), Addr: binary.BigEndian.Uint32(v4Addr[:]),
Port: uint32(a.Port()), Port: uint32(a.Port()),
} }
} }
func newIp6AndPortFromString(s string) *Ip6AndPort { func newIp6AndPortFromString(s string) *V6AddrPort {
a := netip.MustParseAddrPort(s) a := netip.MustParseAddrPort(s)
v6Addr := a.Addr().As16() v6Addr := a.Addr().As16()
return &Ip6AndPort{ return &V6AddrPort{
Hi: binary.BigEndian.Uint64(v6Addr[:8]), Hi: binary.BigEndian.Uint64(v6Addr[:8]),
Lo: binary.BigEndian.Uint64(v6Addr[8:]), Lo: binary.BigEndian.Uint64(v6Addr[8:]),
Port: uint32(a.Port()), Port: uint32(a.Port()),

View File

@ -90,9 +90,9 @@ func New(config *config.C) (*Service, error) {
}, },
}) })
ipNet := device.Cidr() ipNet := device.Networks()
pa := tcpip.ProtocolAddress{ pa := tcpip.ProtocolAddress{
AddressWithPrefix: tcpip.AddrFromSlice(ipNet.Addr().AsSlice()).WithPrefix(), AddressWithPrefix: tcpip.AddrFromSlice(ipNet[0].Addr().AsSlice()).WithPrefix(),
Protocol: ipv4.ProtocolNumber, Protocol: ipv4.ProtocolNumber,
} }
if err := s.ipstack.AddProtocolAddress(nicID, pa, stack.AddressProperties{ if err := s.ipstack.AddProtocolAddress(nicID, pa, stack.AddressProperties{

View File

@ -10,8 +10,8 @@ import (
"dario.cat/mergo" "dario.cat/mergo"
"github.com/slackhq/nebula/cert" "github.com/slackhq/nebula/cert"
"github.com/slackhq/nebula/cert_test"
"github.com/slackhq/nebula/config" "github.com/slackhq/nebula/config"
"github.com/slackhq/nebula/e2e"
"golang.org/x/sync/errgroup" "golang.org/x/sync/errgroup"
"gopkg.in/yaml.v2" "gopkg.in/yaml.v2"
) )
@ -19,7 +19,7 @@ import (
type m map[string]interface{} type m map[string]interface{}
func newSimpleService(caCrt cert.Certificate, caKey []byte, name string, udpIp netip.Addr, overrides m) *Service { func newSimpleService(caCrt cert.Certificate, caKey []byte, name string, udpIp netip.Addr, overrides m) *Service {
_, _, myPrivKey, myPEM := e2e.NewTestCert(caCrt, caKey, "a", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{netip.PrefixFrom(udpIp, 24)}, nil, []string{}) _, _, myPrivKey, myPEM := cert_test.NewTestCert(cert.Version2, cert.Curve_CURVE25519, caCrt, caKey, "a", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{netip.PrefixFrom(udpIp, 24)}, nil, []string{})
caB, err := caCrt.MarshalPEM() caB, err := caCrt.MarshalPEM()
if err != nil { if err != nil {
panic(err) panic(err)
@ -79,7 +79,7 @@ func newSimpleService(caCrt cert.Certificate, caKey []byte, name string, udpIp n
} }
func TestService(t *testing.T) { func TestService(t *testing.T) {
ca, _, caKey, _ := e2e.NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{}) ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version2, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
a := newSimpleService(ca, caKey, "a", netip.MustParseAddr("10.0.0.1"), m{ a := newSimpleService(ca, caKey, "a", netip.MustParseAddr("10.0.0.1"), m{
"static_host_map": m{}, "static_host_map": m{},
"lighthouse": m{ "lighthouse": m{

161
ssh.go
View File

@ -77,9 +77,6 @@ func wireSSHReload(l *logrus.Logger, ssh *sshd.SSHServer, c *config.C) {
// that callers may invoke to run the configured ssh server. On // that callers may invoke to run the configured ssh server. On
// failure, it returns nil, error. // failure, it returns nil, error.
func configSSH(l *logrus.Logger, ssh *sshd.SSHServer, c *config.C) (func(), error) { func configSSH(l *logrus.Logger, ssh *sshd.SSHServer, c *config.C) (func(), error) {
//TODO conntrack list
//TODO print firewall rules or hash?
listen := c.GetString("sshd.listen", "") listen := c.GetString("sshd.listen", "")
if listen == "" { if listen == "" {
return nil, fmt.Errorf("sshd.listen must be provided") return nil, fmt.Errorf("sshd.listen must be provided")
@ -93,7 +90,6 @@ func configSSH(l *logrus.Logger, ssh *sshd.SSHServer, c *config.C) (func(), erro
return nil, fmt.Errorf("sshd.listen can not use port 22") return nil, fmt.Errorf("sshd.listen can not use port 22")
} }
//TODO: no good way to reload this right now
hostKeyPathOrKey := c.GetString("sshd.host_key", "") hostKeyPathOrKey := c.GetString("sshd.host_key", "")
if hostKeyPathOrKey == "" { if hostKeyPathOrKey == "" {
return nil, fmt.Errorf("sshd.host_key must be provided") return nil, fmt.Errorf("sshd.host_key must be provided")
@ -320,7 +316,7 @@ func attachCommands(l *logrus.Logger, c *config.C, ssh *sshd.SSHServer, f *Inter
ssh.RegisterCommand(&sshd.Command{ ssh.RegisterCommand(&sshd.Command{
Name: "print-cert", Name: "print-cert",
ShortDescription: "Prints the current certificate being used or the certificate for the provided vpn ip", ShortDescription: "Prints the current certificate being used or the certificate for the provided vpn addr",
Flags: func() (*flag.FlagSet, interface{}) { Flags: func() (*flag.FlagSet, interface{}) {
fl := flag.NewFlagSet("", flag.ContinueOnError) fl := flag.NewFlagSet("", flag.ContinueOnError)
s := sshPrintCertFlags{} s := sshPrintCertFlags{}
@ -336,7 +332,7 @@ func attachCommands(l *logrus.Logger, c *config.C, ssh *sshd.SSHServer, f *Inter
ssh.RegisterCommand(&sshd.Command{ ssh.RegisterCommand(&sshd.Command{
Name: "print-tunnel", Name: "print-tunnel",
ShortDescription: "Prints json details about a tunnel for the provided vpn ip", ShortDescription: "Prints json details about a tunnel for the provided vpn addr",
Flags: func() (*flag.FlagSet, interface{}) { Flags: func() (*flag.FlagSet, interface{}) {
fl := flag.NewFlagSet("", flag.ContinueOnError) fl := flag.NewFlagSet("", flag.ContinueOnError)
s := sshPrintTunnelFlags{} s := sshPrintTunnelFlags{}
@ -364,7 +360,7 @@ func attachCommands(l *logrus.Logger, c *config.C, ssh *sshd.SSHServer, f *Inter
ssh.RegisterCommand(&sshd.Command{ ssh.RegisterCommand(&sshd.Command{
Name: "change-remote", Name: "change-remote",
ShortDescription: "Changes the remote address used in the tunnel for the provided vpn ip", ShortDescription: "Changes the remote address used in the tunnel for the provided vpn addr",
Flags: func() (*flag.FlagSet, interface{}) { Flags: func() (*flag.FlagSet, interface{}) {
fl := flag.NewFlagSet("", flag.ContinueOnError) fl := flag.NewFlagSet("", flag.ContinueOnError)
s := sshChangeRemoteFlags{} s := sshChangeRemoteFlags{}
@ -378,7 +374,7 @@ func attachCommands(l *logrus.Logger, c *config.C, ssh *sshd.SSHServer, f *Inter
ssh.RegisterCommand(&sshd.Command{ ssh.RegisterCommand(&sshd.Command{
Name: "close-tunnel", Name: "close-tunnel",
ShortDescription: "Closes a tunnel for the provided vpn ip", ShortDescription: "Closes a tunnel for the provided vpn addr",
Flags: func() (*flag.FlagSet, interface{}) { Flags: func() (*flag.FlagSet, interface{}) {
fl := flag.NewFlagSet("", flag.ContinueOnError) fl := flag.NewFlagSet("", flag.ContinueOnError)
s := sshCloseTunnelFlags{} s := sshCloseTunnelFlags{}
@ -392,7 +388,7 @@ func attachCommands(l *logrus.Logger, c *config.C, ssh *sshd.SSHServer, f *Inter
ssh.RegisterCommand(&sshd.Command{ ssh.RegisterCommand(&sshd.Command{
Name: "create-tunnel", Name: "create-tunnel",
ShortDescription: "Creates a tunnel for the provided vpn ip and address", ShortDescription: "Creates a tunnel for the provided vpn address",
Help: "The lighthouses will be queried for real addresses but you can provide one as well.", Help: "The lighthouses will be queried for real addresses but you can provide one as well.",
Flags: func() (*flag.FlagSet, interface{}) { Flags: func() (*flag.FlagSet, interface{}) {
fl := flag.NewFlagSet("", flag.ContinueOnError) fl := flag.NewFlagSet("", flag.ContinueOnError)
@ -407,8 +403,8 @@ func attachCommands(l *logrus.Logger, c *config.C, ssh *sshd.SSHServer, f *Inter
ssh.RegisterCommand(&sshd.Command{ ssh.RegisterCommand(&sshd.Command{
Name: "query-lighthouse", Name: "query-lighthouse",
ShortDescription: "Query the lighthouses for the provided vpn ip", ShortDescription: "Query the lighthouses for the provided vpn address",
Help: "This command is asynchronous. Only currently known udp ips will be printed.", Help: "This command is asynchronous. Only currently known udp addresses will be printed.",
Callback: func(fs interface{}, a []string, w sshd.StringWriter) error { Callback: func(fs interface{}, a []string, w sshd.StringWriter) error {
return sshQueryLighthouse(f, fs, a, w) return sshQueryLighthouse(f, fs, a, w)
}, },
@ -418,7 +414,6 @@ func attachCommands(l *logrus.Logger, c *config.C, ssh *sshd.SSHServer, f *Inter
func sshListHostMap(hl controlHostLister, a interface{}, w sshd.StringWriter) error { func sshListHostMap(hl controlHostLister, a interface{}, w sshd.StringWriter) error {
fs, ok := a.(*sshListHostMapFlags) fs, ok := a.(*sshListHostMapFlags)
if !ok { if !ok {
//TODO: error
return nil return nil
} }
@ -430,7 +425,7 @@ func sshListHostMap(hl controlHostLister, a interface{}, w sshd.StringWriter) er
} }
sort.Slice(hm, func(i, j int) bool { sort.Slice(hm, func(i, j int) bool {
return hm[i].VpnIp.Compare(hm[j].VpnIp) < 0 return hm[i].VpnAddrs[0].Compare(hm[j].VpnAddrs[0]) < 0
}) })
if fs.Json || fs.Pretty { if fs.Json || fs.Pretty {
@ -441,13 +436,12 @@ func sshListHostMap(hl controlHostLister, a interface{}, w sshd.StringWriter) er
err := js.Encode(hm) err := js.Encode(hm)
if err != nil { if err != nil {
//TODO
return nil return nil
} }
} else { } else {
for _, v := range hm { for _, v := range hm {
err := w.WriteLine(fmt.Sprintf("%s: %s", v.VpnIp, v.RemoteAddrs)) err := w.WriteLine(fmt.Sprintf("%s: %s", v.VpnAddrs, v.RemoteAddrs))
if err != nil { if err != nil {
return err return err
} }
@ -460,13 +454,12 @@ func sshListHostMap(hl controlHostLister, a interface{}, w sshd.StringWriter) er
func sshListLighthouseMap(lightHouse *LightHouse, a interface{}, w sshd.StringWriter) error { func sshListLighthouseMap(lightHouse *LightHouse, a interface{}, w sshd.StringWriter) error {
fs, ok := a.(*sshListHostMapFlags) fs, ok := a.(*sshListHostMapFlags)
if !ok { if !ok {
//TODO: error
return nil return nil
} }
type lighthouseInfo struct { type lighthouseInfo struct {
VpnIp string `json:"vpnIp"` VpnAddr string `json:"vpnAddr"`
Addrs *CacheMap `json:"addrs"` Addrs *CacheMap `json:"addrs"`
} }
lightHouse.RLock() lightHouse.RLock()
@ -474,15 +467,15 @@ func sshListLighthouseMap(lightHouse *LightHouse, a interface{}, w sshd.StringWr
x := 0 x := 0
for k, v := range lightHouse.addrMap { for k, v := range lightHouse.addrMap {
addrMap[x] = lighthouseInfo{ addrMap[x] = lighthouseInfo{
VpnIp: k.String(), VpnAddr: k.String(),
Addrs: v.CopyCache(), Addrs: v.CopyCache(),
} }
x++ x++
} }
lightHouse.RUnlock() lightHouse.RUnlock()
sort.Slice(addrMap, func(i, j int) bool { sort.Slice(addrMap, func(i, j int) bool {
return strings.Compare(addrMap[i].VpnIp, addrMap[j].VpnIp) < 0 return strings.Compare(addrMap[i].VpnAddr, addrMap[j].VpnAddr) < 0
}) })
if fs.Json || fs.Pretty { if fs.Json || fs.Pretty {
@ -493,7 +486,6 @@ func sshListLighthouseMap(lightHouse *LightHouse, a interface{}, w sshd.StringWr
err := js.Encode(addrMap) err := js.Encode(addrMap)
if err != nil { if err != nil {
//TODO
return nil return nil
} }
@ -503,7 +495,7 @@ func sshListLighthouseMap(lightHouse *LightHouse, a interface{}, w sshd.StringWr
if err != nil { if err != nil {
return err return err
} }
err = w.WriteLine(fmt.Sprintf("%s: %s", v.VpnIp, string(b))) err = w.WriteLine(fmt.Sprintf("%s: %s", v.VpnAddr, string(b)))
if err != nil { if err != nil {
return err return err
} }
@ -541,20 +533,20 @@ func sshVersion(ifce *Interface, fs interface{}, a []string, w sshd.StringWriter
func sshQueryLighthouse(ifce *Interface, fs interface{}, a []string, w sshd.StringWriter) error { func sshQueryLighthouse(ifce *Interface, fs interface{}, a []string, w sshd.StringWriter) error {
if len(a) == 0 { if len(a) == 0 {
return w.WriteLine("No vpn ip was provided") return w.WriteLine("No vpn address was provided")
} }
vpnIp, err := netip.ParseAddr(a[0]) vpnAddr, err := netip.ParseAddr(a[0])
if err != nil { if err != nil {
return w.WriteLine(fmt.Sprintf("The provided vpn ip could not be parsed: %s", a[0])) return w.WriteLine(fmt.Sprintf("The provided vpn address could not be parsed: %s", a[0]))
} }
if !vpnIp.IsValid() { if !vpnAddr.IsValid() {
return w.WriteLine(fmt.Sprintf("The provided vpn ip could not be parsed: %s", a[0])) return w.WriteLine(fmt.Sprintf("The provided vpn address could not be parsed: %s", a[0]))
} }
var cm *CacheMap var cm *CacheMap
rl := ifce.lightHouse.Query(vpnIp) rl := ifce.lightHouse.Query(vpnAddr)
if rl != nil { if rl != nil {
cm = rl.CopyCache() cm = rl.CopyCache()
} }
@ -564,26 +556,25 @@ func sshQueryLighthouse(ifce *Interface, fs interface{}, a []string, w sshd.Stri
func sshCloseTunnel(ifce *Interface, fs interface{}, a []string, w sshd.StringWriter) error { func sshCloseTunnel(ifce *Interface, fs interface{}, a []string, w sshd.StringWriter) error {
flags, ok := fs.(*sshCloseTunnelFlags) flags, ok := fs.(*sshCloseTunnelFlags)
if !ok { if !ok {
//TODO: error
return nil return nil
} }
if len(a) == 0 { if len(a) == 0 {
return w.WriteLine("No vpn ip was provided") return w.WriteLine("No vpn address was provided")
} }
vpnIp, err := netip.ParseAddr(a[0]) vpnAddr, err := netip.ParseAddr(a[0])
if err != nil { if err != nil {
return w.WriteLine(fmt.Sprintf("The provided vpn ip could not be parsed: %s", a[0])) return w.WriteLine(fmt.Sprintf("The provided vpn address could not be parsed: %s", a[0]))
} }
if !vpnIp.IsValid() { if !vpnAddr.IsValid() {
return w.WriteLine(fmt.Sprintf("The provided vpn ip could not be parsed: %s", a[0])) return w.WriteLine(fmt.Sprintf("The provided vpn address could not be parsed: %s", a[0]))
} }
hostInfo := ifce.hostMap.QueryVpnIp(vpnIp) hostInfo := ifce.hostMap.QueryVpnAddr(vpnAddr)
if hostInfo == nil { if hostInfo == nil {
return w.WriteLine(fmt.Sprintf("Could not find tunnel for vpn ip: %v", a[0])) return w.WriteLine(fmt.Sprintf("Could not find tunnel for vpn address: %v", a[0]))
} }
if !flags.LocalOnly { if !flags.LocalOnly {
@ -605,29 +596,28 @@ func sshCloseTunnel(ifce *Interface, fs interface{}, a []string, w sshd.StringWr
func sshCreateTunnel(ifce *Interface, fs interface{}, a []string, w sshd.StringWriter) error { func sshCreateTunnel(ifce *Interface, fs interface{}, a []string, w sshd.StringWriter) error {
flags, ok := fs.(*sshCreateTunnelFlags) flags, ok := fs.(*sshCreateTunnelFlags)
if !ok { if !ok {
//TODO: error
return nil return nil
} }
if len(a) == 0 { if len(a) == 0 {
return w.WriteLine("No vpn ip was provided") return w.WriteLine("No vpn address was provided")
} }
vpnIp, err := netip.ParseAddr(a[0]) vpnAddr, err := netip.ParseAddr(a[0])
if err != nil { if err != nil {
return w.WriteLine(fmt.Sprintf("The provided vpn ip could not be parsed: %s", a[0])) return w.WriteLine(fmt.Sprintf("The provided vpn address could not be parsed: %s", a[0]))
} }
if !vpnIp.IsValid() { if !vpnAddr.IsValid() {
return w.WriteLine(fmt.Sprintf("The provided vpn ip could not be parsed: %s", a[0])) return w.WriteLine(fmt.Sprintf("The provided vpn address could not be parsed: %s", a[0]))
} }
hostInfo := ifce.hostMap.QueryVpnIp(vpnIp) hostInfo := ifce.hostMap.QueryVpnAddr(vpnAddr)
if hostInfo != nil { if hostInfo != nil {
return w.WriteLine(fmt.Sprintf("Tunnel already exists")) return w.WriteLine(fmt.Sprintf("Tunnel already exists"))
} }
hostInfo = ifce.handshakeManager.QueryVpnIp(vpnIp) hostInfo = ifce.handshakeManager.QueryVpnAddr(vpnAddr)
if hostInfo != nil { if hostInfo != nil {
return w.WriteLine(fmt.Sprintf("Tunnel already handshaking")) return w.WriteLine(fmt.Sprintf("Tunnel already handshaking"))
} }
@ -640,7 +630,7 @@ func sshCreateTunnel(ifce *Interface, fs interface{}, a []string, w sshd.StringW
} }
} }
hostInfo = ifce.handshakeManager.StartHandshake(vpnIp, nil) hostInfo = ifce.handshakeManager.StartHandshake(vpnAddr, nil)
if addr.IsValid() { if addr.IsValid() {
hostInfo.SetRemote(addr) hostInfo.SetRemote(addr)
} }
@ -651,12 +641,11 @@ func sshCreateTunnel(ifce *Interface, fs interface{}, a []string, w sshd.StringW
func sshChangeRemote(ifce *Interface, fs interface{}, a []string, w sshd.StringWriter) error { func sshChangeRemote(ifce *Interface, fs interface{}, a []string, w sshd.StringWriter) error {
flags, ok := fs.(*sshChangeRemoteFlags) flags, ok := fs.(*sshChangeRemoteFlags)
if !ok { if !ok {
//TODO: error
return nil return nil
} }
if len(a) == 0 { if len(a) == 0 {
return w.WriteLine("No vpn ip was provided") return w.WriteLine("No vpn address was provided")
} }
if flags.Address == "" { if flags.Address == "" {
@ -668,18 +657,18 @@ func sshChangeRemote(ifce *Interface, fs interface{}, a []string, w sshd.StringW
return w.WriteLine("Address could not be parsed") return w.WriteLine("Address could not be parsed")
} }
vpnIp, err := netip.ParseAddr(a[0]) vpnAddr, err := netip.ParseAddr(a[0])
if err != nil { if err != nil {
return w.WriteLine(fmt.Sprintf("The provided vpn ip could not be parsed: %s", a[0])) return w.WriteLine(fmt.Sprintf("The provided vpn address could not be parsed: %s", a[0]))
} }
if !vpnIp.IsValid() { if !vpnAddr.IsValid() {
return w.WriteLine(fmt.Sprintf("The provided vpn ip could not be parsed: %s", a[0])) return w.WriteLine(fmt.Sprintf("The provided vpn address could not be parsed: %s", a[0]))
} }
hostInfo := ifce.hostMap.QueryVpnIp(vpnIp) hostInfo := ifce.hostMap.QueryVpnAddr(vpnAddr)
if hostInfo == nil { if hostInfo == nil {
return w.WriteLine(fmt.Sprintf("Could not find tunnel for vpn ip: %v", a[0])) return w.WriteLine(fmt.Sprintf("Could not find tunnel for vpn address: %v", a[0]))
} }
hostInfo.SetRemote(addr) hostInfo.SetRemote(addr)
@ -781,24 +770,23 @@ func sshLogFormat(l *logrus.Logger, fs interface{}, a []string, w sshd.StringWri
func sshPrintCert(ifce *Interface, fs interface{}, a []string, w sshd.StringWriter) error { func sshPrintCert(ifce *Interface, fs interface{}, a []string, w sshd.StringWriter) error {
args, ok := fs.(*sshPrintCertFlags) args, ok := fs.(*sshPrintCertFlags)
if !ok { if !ok {
//TODO: error
return nil return nil
} }
cert := ifce.pki.GetCertState().Certificate cert := ifce.pki.getCertState().GetDefaultCertificate()
if len(a) > 0 { if len(a) > 0 {
vpnIp, err := netip.ParseAddr(a[0]) vpnAddr, err := netip.ParseAddr(a[0])
if err != nil { if err != nil {
return w.WriteLine(fmt.Sprintf("The provided vpn ip could not be parsed: %s", a[0])) return w.WriteLine(fmt.Sprintf("The provided vpn addr could not be parsed: %s", a[0]))
} }
if !vpnIp.IsValid() { if !vpnAddr.IsValid() {
return w.WriteLine(fmt.Sprintf("The provided vpn ip could not be parsed: %s", a[0])) return w.WriteLine(fmt.Sprintf("The provided vpn addr could not be parsed: %s", a[0]))
} }
hostInfo := ifce.hostMap.QueryVpnIp(vpnIp) hostInfo := ifce.hostMap.QueryVpnAddr(vpnAddr)
if hostInfo == nil { if hostInfo == nil {
return w.WriteLine(fmt.Sprintf("Could not find tunnel for vpn ip: %v", a[0])) return w.WriteLine(fmt.Sprintf("Could not find tunnel for vpn addr: %v", a[0]))
} }
cert = hostInfo.GetCert().Certificate cert = hostInfo.GetCert().Certificate
@ -807,7 +795,6 @@ func sshPrintCert(ifce *Interface, fs interface{}, a []string, w sshd.StringWrit
if args.Json || args.Pretty { if args.Json || args.Pretty {
b, err := cert.MarshalJSON() b, err := cert.MarshalJSON()
if err != nil { if err != nil {
//TODO: handle it
return nil return nil
} }
@ -816,7 +803,6 @@ func sshPrintCert(ifce *Interface, fs interface{}, a []string, w sshd.StringWrit
err := json.Indent(buf, b, "", " ") err := json.Indent(buf, b, "", " ")
b = buf.Bytes() b = buf.Bytes()
if err != nil { if err != nil {
//TODO: handle it
return nil return nil
} }
} }
@ -827,7 +813,6 @@ func sshPrintCert(ifce *Interface, fs interface{}, a []string, w sshd.StringWrit
if args.Raw { if args.Raw {
b, err := cert.MarshalPEM() b, err := cert.MarshalPEM()
if err != nil { if err != nil {
//TODO: handle it
return nil return nil
} }
@ -840,7 +825,6 @@ func sshPrintCert(ifce *Interface, fs interface{}, a []string, w sshd.StringWrit
func sshPrintRelays(ifce *Interface, fs interface{}, a []string, w sshd.StringWriter) error { func sshPrintRelays(ifce *Interface, fs interface{}, a []string, w sshd.StringWriter) error {
args, ok := fs.(*sshPrintTunnelFlags) args, ok := fs.(*sshPrintTunnelFlags)
if !ok { if !ok {
//TODO: error
w.WriteLine(fmt.Sprintf("sshPrintRelays failed to convert args type")) w.WriteLine(fmt.Sprintf("sshPrintRelays failed to convert args type"))
return nil return nil
} }
@ -856,15 +840,15 @@ func sshPrintRelays(ifce *Interface, fs interface{}, a []string, w sshd.StringWr
Error error Error error
Type string Type string
State string State string
PeerIp netip.Addr PeerAddr netip.Addr
LocalIndex uint32 LocalIndex uint32
RemoteIndex uint32 RemoteIndex uint32
RelayedThrough []netip.Addr RelayedThrough []netip.Addr
} }
type RelayOutput struct { type RelayOutput struct {
NebulaIp netip.Addr NebulaAddr netip.Addr
RelayForIps []RelayFor RelayForAddrs []RelayFor
} }
type CmdOutput struct { type CmdOutput struct {
@ -880,16 +864,16 @@ func sshPrintRelays(ifce *Interface, fs interface{}, a []string, w sshd.StringWr
} }
for k, v := range relays { for k, v := range relays {
ro := RelayOutput{NebulaIp: v.vpnIp} ro := RelayOutput{NebulaAddr: v.vpnAddrs[0]}
co.Relays = append(co.Relays, &ro) co.Relays = append(co.Relays, &ro)
relayHI := ifce.hostMap.QueryVpnIp(v.vpnIp) relayHI := ifce.hostMap.QueryVpnAddr(v.vpnAddrs[0])
if relayHI == nil { if relayHI == nil {
ro.RelayForIps = append(ro.RelayForIps, RelayFor{Error: errors.New("could not find hostinfo")}) ro.RelayForAddrs = append(ro.RelayForAddrs, RelayFor{Error: errors.New("could not find hostinfo")})
continue continue
} }
for _, vpnIp := range relayHI.relayState.CopyRelayForIps() { for _, vpnAddr := range relayHI.relayState.CopyRelayForIps() {
rf := RelayFor{Error: nil} rf := RelayFor{Error: nil}
r, ok := relayHI.relayState.GetRelayForByIp(vpnIp) r, ok := relayHI.relayState.GetRelayForByAddr(vpnAddr)
if ok { if ok {
t := "" t := ""
switch r.Type { switch r.Type {
@ -913,19 +897,19 @@ func sshPrintRelays(ifce *Interface, fs interface{}, a []string, w sshd.StringWr
rf.LocalIndex = r.LocalIndex rf.LocalIndex = r.LocalIndex
rf.RemoteIndex = r.RemoteIndex rf.RemoteIndex = r.RemoteIndex
rf.PeerIp = r.PeerIp rf.PeerAddr = r.PeerAddr
rf.Type = t rf.Type = t
rf.State = s rf.State = s
if rf.LocalIndex != k { if rf.LocalIndex != k {
rf.Error = fmt.Errorf("hostmap LocalIndex '%v' does not match RelayState LocalIndex", k) rf.Error = fmt.Errorf("hostmap LocalIndex '%v' does not match RelayState LocalIndex", k)
} }
} }
relayedHI := ifce.hostMap.QueryVpnIp(vpnIp) relayedHI := ifce.hostMap.QueryVpnAddr(vpnAddr)
if relayedHI != nil { if relayedHI != nil {
rf.RelayedThrough = append(rf.RelayedThrough, relayedHI.relayState.CopyRelayIps()...) rf.RelayedThrough = append(rf.RelayedThrough, relayedHI.relayState.CopyRelayIps()...)
} }
ro.RelayForIps = append(ro.RelayForIps, rf) ro.RelayForAddrs = append(ro.RelayForAddrs, rf)
} }
} }
err := enc.Encode(co) err := enc.Encode(co)
@ -938,26 +922,25 @@ func sshPrintRelays(ifce *Interface, fs interface{}, a []string, w sshd.StringWr
func sshPrintTunnel(ifce *Interface, fs interface{}, a []string, w sshd.StringWriter) error { func sshPrintTunnel(ifce *Interface, fs interface{}, a []string, w sshd.StringWriter) error {
args, ok := fs.(*sshPrintTunnelFlags) args, ok := fs.(*sshPrintTunnelFlags)
if !ok { if !ok {
//TODO: error
return nil return nil
} }
if len(a) == 0 { if len(a) == 0 {
return w.WriteLine("No vpn ip was provided") return w.WriteLine("No vpn address was provided")
} }
vpnIp, err := netip.ParseAddr(a[0]) vpnAddr, err := netip.ParseAddr(a[0])
if err != nil { if err != nil {
return w.WriteLine(fmt.Sprintf("The provided vpn ip could not be parsed: %s", a[0])) return w.WriteLine(fmt.Sprintf("The provided vpn addr could not be parsed: %s", a[0]))
} }
if !vpnIp.IsValid() { if !vpnAddr.IsValid() {
return w.WriteLine(fmt.Sprintf("The provided vpn ip could not be parsed: %s", a[0])) return w.WriteLine(fmt.Sprintf("The provided vpn addr could not be parsed: %s", a[0]))
} }
hostInfo := ifce.hostMap.QueryVpnIp(vpnIp) hostInfo := ifce.hostMap.QueryVpnAddr(vpnAddr)
if hostInfo == nil { if hostInfo == nil {
return w.WriteLine(fmt.Sprintf("Could not find tunnel for vpn ip: %v", a[0])) return w.WriteLine(fmt.Sprintf("Could not find tunnel for vpn addr: %v", a[0]))
} }
enc := json.NewEncoder(w.GetWriter()) enc := json.NewEncoder(w.GetWriter())
@ -971,13 +954,15 @@ func sshPrintTunnel(ifce *Interface, fs interface{}, a []string, w sshd.StringWr
func sshDeviceInfo(ifce *Interface, fs interface{}, w sshd.StringWriter) error { func sshDeviceInfo(ifce *Interface, fs interface{}, w sshd.StringWriter) error {
data := struct { data := struct {
Name string `json:"name"` Name string `json:"name"`
Cidr string `json:"cidr"` Cidr []netip.Prefix `json:"cidr"`
}{ }{
Name: ifce.inside.Name(), Name: ifce.inside.Name(),
Cidr: ifce.inside.Cidr().String(), Cidr: make([]netip.Prefix, len(ifce.inside.Networks())),
} }
copy(data.Cidr, ifce.inside.Networks())
flags, ok := fs.(*sshDeviceInfoFlags) flags, ok := fs.(*sshDeviceInfoFlags)
if !ok { if !ok {
return fmt.Errorf("internal error: expected flags to be sshDeviceInfoFlags but was %+v", fs) return fmt.Errorf("internal error: expected flags to be sshDeviceInfoFlags but was %+v", fs)

View File

@ -57,7 +57,6 @@ func execCommand(c *Command, args []string, w StringWriter) error {
func dumpCommands(c *radix.Tree, w StringWriter) { func dumpCommands(c *radix.Tree, w StringWriter) {
err := w.WriteLine("Available commands:") err := w.WriteLine("Available commands:")
if err != nil { if err != nil {
//TODO: log
return return
} }
@ -67,10 +66,7 @@ func dumpCommands(c *radix.Tree, w StringWriter) {
} }
sort.Strings(cmds) sort.Strings(cmds)
err = w.Write(strings.Join(cmds, "\n") + "\n\n") _ = w.Write(strings.Join(cmds, "\n") + "\n\n")
if err != nil {
//TODO: log
}
} }
func lookupCommand(c *radix.Tree, sCmd string) (*Command, error) { func lookupCommand(c *radix.Tree, sCmd string) (*Command, error) {
@ -119,8 +115,6 @@ func helpCallback(commands *radix.Tree, a []string, w StringWriter) (err error)
// We are printing a specific commands help text // We are printing a specific commands help text
cmd, err := lookupCommand(commands, a[0]) cmd, err := lookupCommand(commands, a[0])
if err != nil { if err != nil {
//TODO: handle error
//TODO: message the user
return return
} }

View File

@ -80,9 +80,7 @@ func NewSSHServer(l *logrus.Entry) (*SSHServer, error) {
s.config = &ssh.ServerConfig{ s.config = &ssh.ServerConfig{
PublicKeyCallback: cc.Authenticate, PublicKeyCallback: cc.Authenticate,
//TODO: AuthLogCallback: s.authAttempt, ServerVersion: fmt.Sprintf("SSH-2.0-Nebula???"),
//TODO: version string
ServerVersion: fmt.Sprintf("SSH-2.0-Nebula???"),
} }
s.RegisterCommand(&Command{ s.RegisterCommand(&Command{

View File

@ -62,7 +62,6 @@ func (s *session) handleChannels(chans <-chan ssh.NewChannel) {
func (s *session) handleRequests(in <-chan *ssh.Request, channel ssh.Channel) { func (s *session) handleRequests(in <-chan *ssh.Request, channel ssh.Channel) {
for req := range in { for req := range in {
var err error var err error
//TODO: maybe support window sizing?
switch req.Type { switch req.Type {
case "shell": case "shell":
if s.term == nil { if s.term == nil {
@ -89,9 +88,7 @@ func (s *session) handleRequests(in <-chan *ssh.Request, channel ssh.Channel) {
req.Reply(true, nil) req.Reply(true, nil)
s.dispatchCommand(payload.Value, &stringWriter{channel}) s.dispatchCommand(payload.Value, &stringWriter{channel})
//TODO: Fix error handling and report the proper status back
status := struct{ Status uint32 }{uint32(0)} status := struct{ Status uint32 }{uint32(0)}
//TODO: I think this is how we shut down a shell as well?
channel.SendRequest("exit-status", false, ssh.Marshal(status)) channel.SendRequest("exit-status", false, ssh.Marshal(status))
channel.Close() channel.Close()
return return
@ -110,7 +107,6 @@ func (s *session) handleRequests(in <-chan *ssh.Request, channel ssh.Channel) {
} }
func (s *session) createTerm(channel ssh.Channel) *terminal.Terminal { func (s *session) createTerm(channel ssh.Channel) *terminal.Terminal {
//TODO: PS1 with nebula cert name
term := terminal.NewTerminal(channel, s.c.User()+"@nebula > ") term := terminal.NewTerminal(channel, s.c.User()+"@nebula > ")
term.AutoCompleteCallback = func(line string, pos int, key rune) (newLine string, newPos int, ok bool) { term.AutoCompleteCallback = func(line string, pos int, key rune) (newLine string, newPos int, ok bool) {
// key 9 is tab // key 9 is tab
@ -137,7 +133,6 @@ func (s *session) handleInput(channel ssh.Channel) {
for { for {
line, err := s.term.ReadLine() line, err := s.term.ReadLine()
if err != nil { if err != nil {
//TODO: log
break break
} }
@ -148,7 +143,6 @@ func (s *session) handleInput(channel ssh.Channel) {
func (s *session) dispatchCommand(line string, w StringWriter) { func (s *session) dispatchCommand(line string, w StringWriter) {
args, err := shlex.Split(line, true) args, err := shlex.Split(line, true)
if err != nil { if err != nil {
//todo: LOG IT
return return
} }
@ -159,13 +153,11 @@ func (s *session) dispatchCommand(line string, w StringWriter) {
c, err := lookupCommand(s.commands, args[0]) c, err := lookupCommand(s.commands, args[0])
if err != nil { if err != nil {
//TODO: handle the error
return return
} }
if c == nil { if c == nil {
err := w.WriteLine(fmt.Sprintf("did not understand: %s", line)) err := w.WriteLine(fmt.Sprintf("did not understand: %s", line))
//TODO: log error
_ = err _ = err
dumpCommands(s.commands, w) dumpCommands(s.commands, w)
@ -177,10 +169,7 @@ func (s *session) dispatchCommand(line string, w StringWriter) {
return return
} }
err = execCommand(c, args[1:], w) _ = execCommand(c, args[1:], w)
if err != nil {
//TODO: log the error
}
return return
} }

View File

@ -16,8 +16,8 @@ func (NoopTun) Activate() error {
return nil return nil
} }
func (NoopTun) Cidr() netip.Prefix { func (NoopTun) Networks() []netip.Prefix {
return netip.Prefix{} return []netip.Prefix{}
} }
func (NoopTun) Name() string { func (NoopTun) Name() string {

View File

@ -116,10 +116,10 @@ func TestTimerWheel_Purge(t *testing.T) {
assert.Equal(t, 0, tw.current) assert.Equal(t, 0, tw.current)
fps := []firewall.Packet{ fps := []firewall.Packet{
{LocalIP: netip.MustParseAddr("0.0.0.1")}, {LocalAddr: netip.MustParseAddr("0.0.0.1")},
{LocalIP: netip.MustParseAddr("0.0.0.2")}, {LocalAddr: netip.MustParseAddr("0.0.0.2")},
{LocalIP: netip.MustParseAddr("0.0.0.3")}, {LocalAddr: netip.MustParseAddr("0.0.0.3")},
{LocalIP: netip.MustParseAddr("0.0.0.4")}, {LocalAddr: netip.MustParseAddr("0.0.0.4")},
} }
tw.Add(fps[0], time.Second*1) tw.Add(fps[0], time.Second*1)

View File

@ -4,28 +4,19 @@ import (
"net/netip" "net/netip"
"github.com/slackhq/nebula/config" "github.com/slackhq/nebula/config"
"github.com/slackhq/nebula/firewall"
"github.com/slackhq/nebula/header"
) )
const MTU = 9001 const MTU = 9001
type EncReader func( type EncReader func(
addr netip.AddrPort, addr netip.AddrPort,
out []byte, payload []byte,
packet []byte,
header *header.H,
fwPacket *firewall.Packet,
lhh LightHouseHandlerFunc,
nb []byte,
q int,
localCache firewall.ConntrackCache,
) )
type Conn interface { type Conn interface {
Rebind() error Rebind() error
LocalAddr() (netip.AddrPort, error) LocalAddr() (netip.AddrPort, error)
ListenOut(r EncReader, lhf LightHouseHandlerFunc, cache *firewall.ConntrackCacheTicker, q int) ListenOut(r EncReader)
WriteTo(b []byte, addr netip.AddrPort) error WriteTo(b []byte, addr netip.AddrPort) error
ReloadConfig(c *config.C) ReloadConfig(c *config.C)
Close() error Close() error
@ -39,7 +30,7 @@ func (NoopConn) Rebind() error {
func (NoopConn) LocalAddr() (netip.AddrPort, error) { func (NoopConn) LocalAddr() (netip.AddrPort, error) {
return netip.AddrPort{}, nil return netip.AddrPort{}, nil
} }
func (NoopConn) ListenOut(_ EncReader, _ LightHouseHandlerFunc, _ *firewall.ConntrackCacheTicker, _ int) { func (NoopConn) ListenOut(_ EncReader) {
return return
} }
func (NoopConn) WriteTo(_ []byte, _ netip.AddrPort) error { func (NoopConn) WriteTo(_ []byte, _ netip.AddrPort) error {

View File

@ -1,10 +0,0 @@
package udp
import (
"net/netip"
)
//TODO: The items in this file belong in their own packages but doing that in a single PR is a nightmare
// TODO: IPV6-WORK this can likely be removed now
type LightHouseHandlerFunc func(rAddr netip.AddrPort, vpnIp netip.Addr, p []byte)

View File

@ -15,8 +15,6 @@ import (
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
"github.com/slackhq/nebula/config" "github.com/slackhq/nebula/config"
"github.com/slackhq/nebula/firewall"
"github.com/slackhq/nebula/header"
) )
type GenericConn struct { type GenericConn struct {
@ -60,7 +58,7 @@ func (u *GenericConn) LocalAddr() (netip.AddrPort, error) {
} }
func (u *GenericConn) ReloadConfig(c *config.C) { func (u *GenericConn) ReloadConfig(c *config.C) {
// TODO
} }
func NewUDPStatsEmitter(udpConns []Conn) func() { func NewUDPStatsEmitter(udpConns []Conn) func() {
@ -72,12 +70,8 @@ type rawMessage struct {
Len uint32 Len uint32
} }
func (u *GenericConn) ListenOut(r EncReader, lhf LightHouseHandlerFunc, cache *firewall.ConntrackCacheTicker, q int) { func (u *GenericConn) ListenOut(r EncReader) {
plaintext := make([]byte, MTU)
buffer := make([]byte, MTU) buffer := make([]byte, MTU)
h := &header.H{}
fwPacket := &firewall.Packet{}
nb := make([]byte, 12, 12)
for { for {
// Just read one packet at a time // Just read one packet at a time
@ -87,16 +81,6 @@ func (u *GenericConn) ListenOut(r EncReader, lhf LightHouseHandlerFunc, cache *f
return return
} }
r( r(netip.AddrPortFrom(rua.Addr().Unmap(), rua.Port()), buffer[:n])
netip.AddrPortFrom(rua.Addr().Unmap(), rua.Port()),
plaintext[:0],
buffer[:n],
h,
fwPacket,
lhf,
nb,
q,
cache.Get(u.l),
)
} }
} }

Some files were not shown because too many files have changed in this diff Show More