Merge remote-tracking branch 'origin/master' into holepunch-remote-allow-list

This commit is contained in:
Wade Simmons 2025-03-13 09:57:24 -04:00
commit 97f0abdc55
128 changed files with 11044 additions and 7098 deletions

View File

@ -18,7 +18,7 @@ jobs:
- uses: actions/setup-go@v5
with:
go-version: '1.22'
go-version: '1.23'
check-latest: true
- name: Install goimports

View File

@ -14,7 +14,7 @@ jobs:
- uses: actions/setup-go@v5
with:
go-version: '1.22'
go-version: '1.23'
check-latest: true
- name: Build
@ -37,7 +37,7 @@ jobs:
- uses: actions/setup-go@v5
with:
go-version: '1.22'
go-version: '1.23'
check-latest: true
- name: Build
@ -70,7 +70,7 @@ jobs:
- uses: actions/setup-go@v5
with:
go-version: '1.22'
go-version: '1.23'
check-latest: true
- name: Import certificates

View File

@ -27,6 +27,9 @@ jobs:
go-version-file: 'go.mod'
check-latest: true
- name: add hashicorp source
run: wget -O- https://apt.releases.hashicorp.com/gpg | gpg --dearmor | sudo tee /usr/share/keyrings/hashicorp-archive-keyring.gpg && echo "deb [signed-by=/usr/share/keyrings/hashicorp-archive-keyring.gpg] https://apt.releases.hashicorp.com $(lsb_release -cs) main" | sudo tee /etc/apt/sources.list.d/hashicorp.list
- name: install vagrant
run: sudo apt-get update && sudo apt-get install -y vagrant virtualbox

View File

@ -22,7 +22,7 @@ jobs:
- uses: actions/setup-go@v5
with:
go-version: '1.22'
go-version: '1.23'
check-latest: true
- name: build

View File

@ -5,6 +5,10 @@ set -e -x
rm -rf ./build
mkdir ./build
# TODO: Assumes your docker bridge network is a /24, and the first container that launches will be .1
# - We could make this better by launching the lighthouse first and then fetching what IP it is.
NET="$(docker network inspect bridge -f '{{ range .IPAM.Config }}{{ .Subnet }}{{ end }}' | cut -d. -f1-3)"
(
cd build
@ -21,16 +25,16 @@ mkdir ./build
../genconfig.sh >lighthouse1.yml
HOST="host2" \
LIGHTHOUSES="192.168.100.1 172.17.0.2:4242" \
LIGHTHOUSES="192.168.100.1 $NET.2:4242" \
../genconfig.sh >host2.yml
HOST="host3" \
LIGHTHOUSES="192.168.100.1 172.17.0.2:4242" \
LIGHTHOUSES="192.168.100.1 $NET.2:4242" \
INBOUND='[{"port": "any", "proto": "icmp", "group": "lighthouse"}]' \
../genconfig.sh >host3.yml
HOST="host4" \
LIGHTHOUSES="192.168.100.1 172.17.0.2:4242" \
LIGHTHOUSES="192.168.100.1 $NET.2:4242" \
OUTBOUND='[{"port": "any", "proto": "icmp", "group": "lighthouse"}]' \
../genconfig.sh >host4.yml

View File

@ -29,13 +29,13 @@ docker run --name lighthouse1 --rm "$CONTAINER" -config lighthouse1.yml -test
docker run --name host2 --rm "$CONTAINER" -config host2.yml -test
vagrant up
vagrant ssh -c "cd /nebula && /nebula/$1-nebula -config host3.yml -test"
vagrant ssh -c "cd /nebula && /nebula/$1-nebula -config host3.yml -test" -- -T
docker run --name lighthouse1 --device /dev/net/tun:/dev/net/tun --cap-add NET_ADMIN --rm "$CONTAINER" -config lighthouse1.yml 2>&1 | tee logs/lighthouse1 | sed -u 's/^/ [lighthouse1] /' &
sleep 1
docker run --name host2 --device /dev/net/tun:/dev/net/tun --cap-add NET_ADMIN --rm "$CONTAINER" -config host2.yml 2>&1 | tee logs/host2 | sed -u 's/^/ [host2] /' &
sleep 1
vagrant ssh -c "cd /nebula && sudo sh -c 'echo \$\$ >/nebula/pid && exec /nebula/$1-nebula -config host3.yml'" &
vagrant ssh -c "cd /nebula && sudo sh -c 'echo \$\$ >/nebula/pid && exec /nebula/$1-nebula -config host3.yml'" 2>&1 -- -T | tee logs/host3 | sed -u 's/^/ [host3] /' &
sleep 15
# grab tcpdump pcaps for debugging
@ -46,8 +46,8 @@ docker exec host2 tcpdump -i eth0 -q -w - -U 2>logs/host2.outside.log >logs/host
# vagrant ssh -c "tcpdump -i nebula1 -q -w - -U" 2>logs/host3.inside.log >logs/host3.inside.pcap &
# vagrant ssh -c "tcpdump -i eth0 -q -w - -U" 2>logs/host3.outside.log >logs/host3.outside.pcap &
docker exec host2 ncat -nklv 0.0.0.0 2000 &
vagrant ssh -c "ncat -nklv 0.0.0.0 2000" &
#docker exec host2 ncat -nklv 0.0.0.0 2000 &
#vagrant ssh -c "ncat -nklv 0.0.0.0 2000" &
#docker exec host2 ncat -e '/usr/bin/echo host2' -nkluv 0.0.0.0 3000 &
#vagrant ssh -c "ncat -e '/usr/bin/echo host3' -nkluv 0.0.0.0 3000" &
@ -68,11 +68,11 @@ docker exec host2 ping -c1 192.168.100.1
# Should fail because not allowed by host3 inbound firewall
! docker exec host2 ping -c1 192.168.100.3 -w5 || exit 1
set +x
echo
echo " *** Testing ncat from host2"
echo
set -x
#set +x
#echo
#echo " *** Testing ncat from host2"
#echo
#set -x
# Should fail because not allowed by host3 inbound firewall
#! docker exec host2 ncat -nzv -w5 192.168.100.3 2000 || exit 1
#! docker exec host2 ncat -nzuv -w5 192.168.100.3 3000 | grep -q host3 || exit 1
@ -82,18 +82,18 @@ echo
echo " *** Testing ping from host3"
echo
set -x
vagrant ssh -c "ping -c1 192.168.100.1"
vagrant ssh -c "ping -c1 192.168.100.2"
vagrant ssh -c "ping -c1 192.168.100.1" -- -T
vagrant ssh -c "ping -c1 192.168.100.2" -- -T
set +x
echo
echo " *** Testing ncat from host3"
echo
set -x
#set +x
#echo
#echo " *** Testing ncat from host3"
#echo
#set -x
#vagrant ssh -c "ncat -nzv -w5 192.168.100.2 2000"
#vagrant ssh -c "ncat -nzuv -w5 192.168.100.2 3000" | grep -q host2
vagrant ssh -c "sudo xargs kill </nebula/pid"
vagrant ssh -c "sudo xargs kill </nebula/pid" -- -T
docker exec host2 sh -c 'kill 1'
docker exec lighthouse1 sh -c 'kill 1'
sleep 1

View File

@ -22,7 +22,7 @@ jobs:
- uses: actions/setup-go@v5
with:
go-version: '1.22'
go-version: '1.23'
check-latest: true
- name: Build
@ -31,6 +31,11 @@ jobs:
- name: Vet
run: make vet
- name: golangci-lint
uses: golangci/golangci-lint-action@v6
with:
version: v1.64
- name: Test
run: make test
@ -55,7 +60,7 @@ jobs:
- uses: actions/setup-go@v5
with:
go-version: '1.22'
go-version: '1.23'
check-latest: true
- name: Build
@ -65,7 +70,7 @@ jobs:
run: make test-boringcrypto
- name: End 2 end
run: make e2evv GOEXPERIMENT=boringcrypto CGO_ENABLED=1
run: make e2e GOEXPERIMENT=boringcrypto CGO_ENABLED=1 TEST_ENV="TEST_LOGS=1" TEST_FLAGS="-v -ldflags -checklinkname=0"
test-linux-pkcs11:
name: Build and test on linux with pkcs11
@ -97,7 +102,7 @@ jobs:
- uses: actions/setup-go@v5
with:
go-version: '1.22'
go-version: '1.23'
check-latest: true
- name: Build nebula
@ -109,6 +114,11 @@ jobs:
- name: Vet
run: make vet
- name: golangci-lint
uses: golangci/golangci-lint-action@v6
with:
version: v1.64
- name: Test
run: make test

3
.gitignore vendored
View File

@ -5,7 +5,8 @@
/nebula-darwin
/nebula.exe
/nebula-cert.exe
/coverage.out
**/coverage.out
**/cover.out
/cpu.pprof
/build
/*.tar.gz

9
.golangci.yaml Normal file
View File

@ -0,0 +1,9 @@
# yaml-language-server: $schema=https://golangci-lint.run/jsonschema/golangci.jsonschema.json
linters:
# Disable all linters.
# Default: false
disable-all: true
# Enable specific linter
# https://golangci-lint.run/usage/linters/#enabled-by-default
enable:
- testifylint

View File

@ -137,6 +137,8 @@ build/linux-mips-softfloat/%: LDFLAGS += -s -w
# boringcrypto
build/linux-amd64-boringcrypto/%: GOENV += GOEXPERIMENT=boringcrypto CGO_ENABLED=1
build/linux-arm64-boringcrypto/%: GOENV += GOEXPERIMENT=boringcrypto CGO_ENABLED=1
build/linux-amd64-boringcrypto/%: LDFLAGS += -checklinkname=0
build/linux-arm64-boringcrypto/%: LDFLAGS += -checklinkname=0
build/%/nebula: .FORCE
GOOS=$(firstword $(subst -, , $*)) \
@ -170,7 +172,7 @@ test:
go test -v ./...
test-boringcrypto:
GOEXPERIMENT=boringcrypto CGO_ENABLED=1 go test -v ./...
GOEXPERIMENT=boringcrypto CGO_ENABLED=1 go test -ldflags "-checklinkname=0" -v ./...
test-pkcs11:
CGO_ENABLED=1 go test -v -tags pkcs11 ./...
@ -196,7 +198,7 @@ bench-cpu-long:
go test -bench=. -benchtime=60s -cpuprofile=cpu.pprof
go tool pprof go-audit.test cpu.pprof
proto: nebula.pb.go cert/cert.pb.go
proto: nebula.pb.go cert/cert_v1.pb.go
nebula.pb.go: nebula.proto .FORCE
go build github.com/gogo/protobuf/protoc-gen-gogofaster

View File

@ -12,7 +12,7 @@ Further documentation can be found [here](https://nebula.defined.net/docs/).
You can read more about Nebula [here](https://medium.com/p/884110a5579).
You can also join the NebulaOSS Slack group [here](https://join.slack.com/t/nebulaoss/shared_invite/enQtOTA5MDI4NDg3MTg4LTkwY2EwNTI4NzQyMzc0M2ZlODBjNWI3NTY1MzhiOThiMmZlZjVkMTI0NGY4YTMyNjUwMWEyNzNkZTJmYzQxOGU).
You can also join the NebulaOSS Slack group [here](https://join.slack.com/t/nebulaoss/shared_invite/zt-2xqe6e7vn-k_KGi8s13nsr7cvHVvHvuQ).
## Supported Platforms
@ -47,7 +47,7 @@ Check the [releases](https://github.com/slackhq/nebula/releases/latest) page for
$ sudo apk add nebula
```
- [macOS Homebrew](https://github.com/Homebrew/homebrew-core/blob/HEAD/Formula/nebula.rb)
- [macOS Homebrew](https://github.com/Homebrew/homebrew-core/blob/HEAD/Formula/n/nebula.rb)
```
$ brew install nebula
```

View File

@ -128,7 +128,6 @@ func newAllowList(k string, raw interface{}, handleKey func(key string, value in
ipNet = netip.PrefixFrom(ipNet.Addr().Unmap(), ipNet.Bits())
// TODO: should we error on duplicate CIDRs in the config?
tree.Insert(ipNet, value)
maskBits := ipNet.Bits()
@ -251,20 +250,20 @@ func getRemoteAllowRanges(c *config.C, k string) (*bart.Table[*AllowList], error
return remoteAllowRanges, nil
}
func (al *AllowList) Allow(ip netip.Addr) bool {
func (al *AllowList) Allow(addr netip.Addr) bool {
if al == nil {
return true
}
result, _ := al.cidrTree.Lookup(ip)
result, _ := al.cidrTree.Lookup(addr)
return result
}
func (al *LocalAllowList) Allow(ip netip.Addr) bool {
func (al *LocalAllowList) Allow(udpAddr netip.Addr) bool {
if al == nil {
return true
}
return al.AllowList.Allow(ip)
return al.AllowList.Allow(udpAddr)
}
func (al *LocalAllowList) AllowName(name string) bool {
@ -282,23 +281,37 @@ func (al *LocalAllowList) AllowName(name string) bool {
return !al.nameRules[0].Allow
}
func (al *RemoteAllowList) AllowUnknownVpnIp(ip netip.Addr) bool {
func (al *RemoteAllowList) AllowUnknownVpnAddr(vpnAddr netip.Addr) bool {
if al == nil {
return true
}
return al.AllowList.Allow(ip)
return al.AllowList.Allow(vpnAddr)
}
func (al *RemoteAllowList) Allow(vpnIp netip.Addr, ip netip.Addr) bool {
if !al.getInsideAllowList(vpnIp).Allow(ip) {
func (al *RemoteAllowList) Allow(vpnAddr netip.Addr, udpAddr netip.Addr) bool {
if !al.getInsideAllowList(vpnAddr).Allow(udpAddr) {
return false
}
return al.AllowList.Allow(ip)
return al.AllowList.Allow(udpAddr)
}
func (al *RemoteAllowList) getInsideAllowList(vpnIp netip.Addr) *AllowList {
func (al *RemoteAllowList) AllowAll(vpnAddrs []netip.Addr, udpAddr netip.Addr) bool {
if !al.AllowList.Allow(udpAddr) {
return false
}
for _, vpnAddr := range vpnAddrs {
if !al.getInsideAllowList(vpnAddr).Allow(udpAddr) {
return false
}
}
return true
}
func (al *RemoteAllowList) getInsideAllowList(vpnAddr netip.Addr) *AllowList {
if al.insideAllowLists != nil {
inside, ok := al.insideAllowLists.Lookup(vpnIp)
inside, ok := al.insideAllowLists.Lookup(vpnAddr)
if ok {
return inside
}

View File

@ -9,6 +9,7 @@ import (
"github.com/slackhq/nebula/config"
"github.com/slackhq/nebula/test"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestNewAllowListFromConfig(t *testing.T) {
@ -18,21 +19,21 @@ func TestNewAllowListFromConfig(t *testing.T) {
"192.168.0.0": true,
}
r, err := newAllowListFromConfig(c, "allowlist", nil)
assert.EqualError(t, err, "config `allowlist` has invalid CIDR: 192.168.0.0. netip.ParsePrefix(\"192.168.0.0\"): no '/'")
require.EqualError(t, err, "config `allowlist` has invalid CIDR: 192.168.0.0. netip.ParsePrefix(\"192.168.0.0\"): no '/'")
assert.Nil(t, r)
c.Settings["allowlist"] = map[interface{}]interface{}{
"192.168.0.0/16": "abc",
}
r, err = newAllowListFromConfig(c, "allowlist", nil)
assert.EqualError(t, err, "config `allowlist` has invalid value (type string): abc")
require.EqualError(t, err, "config `allowlist` has invalid value (type string): abc")
c.Settings["allowlist"] = map[interface{}]interface{}{
"192.168.0.0/16": true,
"10.0.0.0/8": false,
}
r, err = newAllowListFromConfig(c, "allowlist", nil)
assert.EqualError(t, err, "config `allowlist` contains both true and false rules, but no default set for 0.0.0.0/0")
require.EqualError(t, err, "config `allowlist` contains both true and false rules, but no default set for 0.0.0.0/0")
c.Settings["allowlist"] = map[interface{}]interface{}{
"0.0.0.0/0": true,
@ -42,7 +43,7 @@ func TestNewAllowListFromConfig(t *testing.T) {
"fd00:fd00::/16": false,
}
r, err = newAllowListFromConfig(c, "allowlist", nil)
assert.EqualError(t, err, "config `allowlist` contains both true and false rules, but no default set for ::/0")
require.EqualError(t, err, "config `allowlist` contains both true and false rules, but no default set for ::/0")
c.Settings["allowlist"] = map[interface{}]interface{}{
"0.0.0.0/0": true,
@ -75,7 +76,7 @@ func TestNewAllowListFromConfig(t *testing.T) {
},
}
lr, err := NewLocalAllowListFromConfig(c, "allowlist")
assert.EqualError(t, err, "config `allowlist.interfaces` has invalid value (type string): foo")
require.EqualError(t, err, "config `allowlist.interfaces` has invalid value (type string): foo")
c.Settings["allowlist"] = map[interface{}]interface{}{
"interfaces": map[interface{}]interface{}{
@ -84,7 +85,7 @@ func TestNewAllowListFromConfig(t *testing.T) {
},
}
lr, err = NewLocalAllowListFromConfig(c, "allowlist")
assert.EqualError(t, err, "config `allowlist.interfaces` values must all be the same true/false value")
require.EqualError(t, err, "config `allowlist.interfaces` values must all be the same true/false value")
c.Settings["allowlist"] = map[interface{}]interface{}{
"interfaces": map[interface{}]interface{}{
@ -98,7 +99,7 @@ func TestNewAllowListFromConfig(t *testing.T) {
}
func TestAllowList_Allow(t *testing.T) {
assert.Equal(t, true, ((*AllowList)(nil)).Allow(netip.MustParseAddr("1.1.1.1")))
assert.True(t, ((*AllowList)(nil)).Allow(netip.MustParseAddr("1.1.1.1")))
tree := new(bart.Table[bool])
tree.Insert(netip.MustParsePrefix("0.0.0.0/0"), true)
@ -111,17 +112,17 @@ func TestAllowList_Allow(t *testing.T) {
tree.Insert(netip.MustParsePrefix("::2/128"), false)
al := &AllowList{cidrTree: tree}
assert.Equal(t, true, al.Allow(netip.MustParseAddr("1.1.1.1")))
assert.Equal(t, false, al.Allow(netip.MustParseAddr("10.0.0.4")))
assert.Equal(t, true, al.Allow(netip.MustParseAddr("10.42.42.42")))
assert.Equal(t, false, al.Allow(netip.MustParseAddr("10.42.42.41")))
assert.Equal(t, true, al.Allow(netip.MustParseAddr("10.42.0.1")))
assert.Equal(t, true, al.Allow(netip.MustParseAddr("::1")))
assert.Equal(t, false, al.Allow(netip.MustParseAddr("::2")))
assert.True(t, al.Allow(netip.MustParseAddr("1.1.1.1")))
assert.False(t, al.Allow(netip.MustParseAddr("10.0.0.4")))
assert.True(t, al.Allow(netip.MustParseAddr("10.42.42.42")))
assert.False(t, al.Allow(netip.MustParseAddr("10.42.42.41")))
assert.True(t, al.Allow(netip.MustParseAddr("10.42.0.1")))
assert.True(t, al.Allow(netip.MustParseAddr("::1")))
assert.False(t, al.Allow(netip.MustParseAddr("::2")))
}
func TestLocalAllowList_AllowName(t *testing.T) {
assert.Equal(t, true, ((*LocalAllowList)(nil)).AllowName("docker0"))
assert.True(t, ((*LocalAllowList)(nil)).AllowName("docker0"))
rules := []AllowListNameRule{
{Name: regexp.MustCompile("^docker.*$"), Allow: false},
@ -129,9 +130,9 @@ func TestLocalAllowList_AllowName(t *testing.T) {
}
al := &LocalAllowList{nameRules: rules}
assert.Equal(t, false, al.AllowName("docker0"))
assert.Equal(t, false, al.AllowName("tun0"))
assert.Equal(t, true, al.AllowName("eth0"))
assert.False(t, al.AllowName("docker0"))
assert.False(t, al.AllowName("tun0"))
assert.True(t, al.AllowName("eth0"))
rules = []AllowListNameRule{
{Name: regexp.MustCompile("^eth.*$"), Allow: true},
@ -139,7 +140,7 @@ func TestLocalAllowList_AllowName(t *testing.T) {
}
al = &LocalAllowList{nameRules: rules}
assert.Equal(t, false, al.AllowName("docker0"))
assert.Equal(t, true, al.AllowName("eth0"))
assert.Equal(t, true, al.AllowName("ens5"))
assert.False(t, al.AllowName("docker0"))
assert.True(t, al.AllowName("eth0"))
assert.True(t, al.AllowName("ens5"))
}

View File

@ -21,7 +21,11 @@ type calculatedRemote struct {
port uint32
}
func newCalculatedRemote(maskCidr netip.Prefix, port int) (*calculatedRemote, error) {
func newCalculatedRemote(cidr, maskCidr netip.Prefix, port int) (*calculatedRemote, error) {
if maskCidr.Addr().BitLen() != cidr.Addr().BitLen() {
return nil, fmt.Errorf("invalid mask: %s for cidr: %s", maskCidr, cidr)
}
masked := maskCidr.Masked()
if port < 0 || port > math.MaxUint16 {
return nil, fmt.Errorf("invalid port: %d", port)
@ -38,32 +42,38 @@ func (c *calculatedRemote) String() string {
return fmt.Sprintf("CalculatedRemote(mask=%v port=%d)", c.ipNet, c.port)
}
func (c *calculatedRemote) Apply(ip netip.Addr) *Ip4AndPort {
// Combine the masked bytes of the "mask" IP with the unmasked bytes
// of the overlay IP
if c.ipNet.Addr().Is4() {
return c.apply4(ip)
}
return c.apply6(ip)
}
func (c *calculatedRemote) apply4(ip netip.Addr) *Ip4AndPort {
//TODO: IPV6-WORK this can be less crappy
func (c *calculatedRemote) ApplyV4(addr netip.Addr) *V4AddrPort {
// Combine the masked bytes of the "mask" IP with the unmasked bytes of the overlay IP
maskb := net.CIDRMask(c.mask.Bits(), c.mask.Addr().BitLen())
mask := binary.BigEndian.Uint32(maskb[:])
b := c.mask.Addr().As4()
maskIp := binary.BigEndian.Uint32(b[:])
maskAddr := binary.BigEndian.Uint32(b[:])
b = ip.As4()
intIp := binary.BigEndian.Uint32(b[:])
b = addr.As4()
intAddr := binary.BigEndian.Uint32(b[:])
return &Ip4AndPort{(maskIp & mask) | (intIp & ^mask), c.port}
return &V4AddrPort{(maskAddr & mask) | (intAddr & ^mask), c.port}
}
func (c *calculatedRemote) apply6(ip netip.Addr) *Ip4AndPort {
//TODO: IPV6-WORK
panic("Can not calculate ipv6 remote addresses")
func (c *calculatedRemote) ApplyV6(addr netip.Addr) *V6AddrPort {
mask := net.CIDRMask(c.mask.Bits(), c.mask.Addr().BitLen())
maskAddr := c.mask.Addr().As16()
calcAddr := addr.As16()
ap := V6AddrPort{Port: c.port}
maskb := binary.BigEndian.Uint64(mask[:8])
maskAddrb := binary.BigEndian.Uint64(maskAddr[:8])
calcAddrb := binary.BigEndian.Uint64(calcAddr[:8])
ap.Hi = (maskAddrb & maskb) | (calcAddrb & ^maskb)
maskb = binary.BigEndian.Uint64(mask[8:])
maskAddrb = binary.BigEndian.Uint64(maskAddr[8:])
calcAddrb = binary.BigEndian.Uint64(calcAddr[8:])
ap.Lo = (maskAddrb & maskb) | (calcAddrb & ^maskb)
return &ap
}
func NewCalculatedRemotesFromConfig(c *config.C, k string) (*bart.Table[[]*calculatedRemote], error) {
@ -89,8 +99,7 @@ func NewCalculatedRemotesFromConfig(c *config.C, k string) (*bart.Table[[]*calcu
return nil, fmt.Errorf("config `%s` has invalid CIDR: %s", k, rawCIDR)
}
//TODO: IPV6-WORK this does not verify that rawValue contains the same bits as cidr here
entry, err := newCalculatedRemotesListFromConfig(rawValue)
entry, err := newCalculatedRemotesListFromConfig(cidr, rawValue)
if err != nil {
return nil, fmt.Errorf("config '%s.%s': %w", k, rawCIDR, err)
}
@ -101,7 +110,7 @@ func NewCalculatedRemotesFromConfig(c *config.C, k string) (*bart.Table[[]*calcu
return calculatedRemotes, nil
}
func newCalculatedRemotesListFromConfig(raw any) ([]*calculatedRemote, error) {
func newCalculatedRemotesListFromConfig(cidr netip.Prefix, raw any) ([]*calculatedRemote, error) {
rawList, ok := raw.([]any)
if !ok {
return nil, fmt.Errorf("calculated_remotes entry has invalid type: %T", raw)
@ -109,7 +118,7 @@ func newCalculatedRemotesListFromConfig(raw any) ([]*calculatedRemote, error) {
var l []*calculatedRemote
for _, e := range rawList {
c, err := newCalculatedRemotesEntryFromConfig(e)
c, err := newCalculatedRemotesEntryFromConfig(cidr, e)
if err != nil {
return nil, fmt.Errorf("calculated_remotes entry: %w", err)
}
@ -119,7 +128,7 @@ func newCalculatedRemotesListFromConfig(raw any) ([]*calculatedRemote, error) {
return l, nil
}
func newCalculatedRemotesEntryFromConfig(raw any) (*calculatedRemote, error) {
func newCalculatedRemotesEntryFromConfig(cidr netip.Prefix, raw any) (*calculatedRemote, error) {
rawMap, ok := raw.(map[any]any)
if !ok {
return nil, fmt.Errorf("invalid type: %T", raw)
@ -155,5 +164,5 @@ func newCalculatedRemotesEntryFromConfig(raw any) (*calculatedRemote, error) {
return nil, fmt.Errorf("invalid port (type %T): %v", rawValue, rawValue)
}
return newCalculatedRemote(maskCidr, port)
return newCalculatedRemote(cidr, maskCidr, port)
}

View File

@ -9,17 +9,73 @@ import (
)
func TestCalculatedRemoteApply(t *testing.T) {
ipNet, err := netip.ParsePrefix("192.168.1.0/24")
require.NoError(t, err)
c, err := newCalculatedRemote(ipNet, 4242)
// Test v4 addresses
ipNet := netip.MustParsePrefix("192.168.1.0/24")
c, err := newCalculatedRemote(ipNet, ipNet, 4242)
require.NoError(t, err)
input, err := netip.ParseAddr("10.0.10.182")
assert.NoError(t, err)
require.NoError(t, err)
expected, err := netip.ParseAddr("192.168.1.182")
assert.NoError(t, err)
require.NoError(t, err)
assert.Equal(t, NewIp4AndPortFromNetIP(expected, 4242), c.Apply(input))
assert.Equal(t, netAddrToProtoV4AddrPort(expected, 4242), c.ApplyV4(input))
// Test v6 addresses
ipNet = netip.MustParsePrefix("ffff:ffff:ffff:ffff::0/64")
c, err = newCalculatedRemote(ipNet, ipNet, 4242)
require.NoError(t, err)
input, err = netip.ParseAddr("beef:beef:beef:beef:beef:beef:beef:beef")
require.NoError(t, err)
expected, err = netip.ParseAddr("ffff:ffff:ffff:ffff:beef:beef:beef:beef")
require.NoError(t, err)
assert.Equal(t, netAddrToProtoV6AddrPort(expected, 4242), c.ApplyV6(input))
// Test v6 addresses part 2
ipNet = netip.MustParsePrefix("ffff:ffff:ffff:ffff:ffff::0/80")
c, err = newCalculatedRemote(ipNet, ipNet, 4242)
require.NoError(t, err)
input, err = netip.ParseAddr("beef:beef:beef:beef:beef:beef:beef:beef")
require.NoError(t, err)
expected, err = netip.ParseAddr("ffff:ffff:ffff:ffff:ffff:beef:beef:beef")
require.NoError(t, err)
assert.Equal(t, netAddrToProtoV6AddrPort(expected, 4242), c.ApplyV6(input))
// Test v6 addresses part 2
ipNet = netip.MustParsePrefix("ffff:ffff:ffff::0/48")
c, err = newCalculatedRemote(ipNet, ipNet, 4242)
require.NoError(t, err)
input, err = netip.ParseAddr("beef:beef:beef:beef:beef:beef:beef:beef")
require.NoError(t, err)
expected, err = netip.ParseAddr("ffff:ffff:ffff:beef:beef:beef:beef:beef")
require.NoError(t, err)
assert.Equal(t, netAddrToProtoV6AddrPort(expected, 4242), c.ApplyV6(input))
}
func Test_newCalculatedRemote(t *testing.T) {
c, err := newCalculatedRemote(netip.MustParsePrefix("1::1/128"), netip.MustParsePrefix("1.0.0.0/32"), 4242)
require.EqualError(t, err, "invalid mask: 1.0.0.0/32 for cidr: 1::1/128")
require.Nil(t, c)
c, err = newCalculatedRemote(netip.MustParsePrefix("1.0.0.0/32"), netip.MustParsePrefix("1::1/128"), 4242)
require.EqualError(t, err, "invalid mask: 1::1/128 for cidr: 1.0.0.0/32")
require.Nil(t, c)
c, err = newCalculatedRemote(netip.MustParsePrefix("1.0.0.0/32"), netip.MustParsePrefix("1.0.0.0/32"), 4242)
require.NoError(t, err)
require.NotNil(t, c)
c, err = newCalculatedRemote(netip.MustParsePrefix("1::1/128"), netip.MustParsePrefix("1::1/128"), 4242)
require.NoError(t, err)
require.NotNil(t, c)
}

View File

@ -1,7 +1,7 @@
GO111MODULE = on
export GO111MODULE
cert.pb.go: cert.proto .FORCE
cert_v1.pb.go: cert_v1.proto .FORCE
go build google.golang.org/protobuf/cmd/protoc-gen-go
PATH="$(CURDIR):$(PATH)" protoc --go_out=. --go_opt=paths=source_relative $<
rm protoc-gen-go

View File

@ -2,14 +2,25 @@
This is a library for interacting with `nebula` style certificates and authorities.
A `protobuf` definition of the certificate format is also included
There are now 2 versions of `nebula` certificates:
### Compiling the protobuf definition
## v1
Make sure you have `protoc` installed.
This version is deprecated.
A `protobuf` definition of the certificate format is included at `cert_v1.proto`
To compile the definition you will need `protoc` installed.
To compile for `go` with the same version of protobuf specified in go.mod:
```bash
make
make proto
```
## v2
This is the latest version which uses asn.1 DER encoding. It can support ipv4 and ipv6 and tolerate
future certificate changes better than v1.
`cert_v2.asn1` defines the wire format and can be used to compile marshalers.

52
cert/asn1.go Normal file
View File

@ -0,0 +1,52 @@
package cert
import (
"golang.org/x/crypto/cryptobyte"
"golang.org/x/crypto/cryptobyte/asn1"
)
// readOptionalASN1Boolean reads an asn.1 boolean with a specific tag instead of a asn.1 tag wrapping a boolean with a value
// https://github.com/golang/go/issues/64811#issuecomment-1944446920
func readOptionalASN1Boolean(b *cryptobyte.String, out *bool, tag asn1.Tag, defaultValue bool) bool {
var present bool
var child cryptobyte.String
if !b.ReadOptionalASN1(&child, &present, tag) {
return false
}
if !present {
*out = defaultValue
return true
}
// Ensure we have 1 byte
if len(child) == 1 {
*out = child[0] > 0
return true
}
return false
}
// readOptionalASN1Byte reads an asn.1 uint8 with a specific tag instead of a asn.1 tag wrapping a uint8 with a value
// Similar issue as with readOptionalASN1Boolean
func readOptionalASN1Byte(b *cryptobyte.String, out *byte, tag asn1.Tag, defaultValue byte) bool {
var present bool
var child cryptobyte.String
if !b.ReadOptionalASN1(&child, &present, tag) {
return false
}
if !present {
*out = defaultValue
return true
}
// Ensure we have 1 byte
if len(child) == 1 {
*out = child[0]
return true
}
return false
}

View File

@ -1,140 +0,0 @@
package cert
import (
"errors"
"fmt"
"strings"
"time"
)
type NebulaCAPool struct {
CAs map[string]*NebulaCertificate
certBlocklist map[string]struct{}
}
// NewCAPool creates a CAPool
func NewCAPool() *NebulaCAPool {
ca := NebulaCAPool{
CAs: make(map[string]*NebulaCertificate),
certBlocklist: make(map[string]struct{}),
}
return &ca
}
// NewCAPoolFromBytes will create a new CA pool from the provided
// input bytes, which must be a PEM-encoded set of nebula certificates.
// If the pool contains any expired certificates, an ErrExpired will be
// returned along with the pool. The caller must handle any such errors.
func NewCAPoolFromBytes(caPEMs []byte) (*NebulaCAPool, error) {
pool := NewCAPool()
var err error
var expired bool
for {
caPEMs, err = pool.AddCACertificate(caPEMs)
if errors.Is(err, ErrExpired) {
expired = true
err = nil
}
if err != nil {
return nil, err
}
if len(caPEMs) == 0 || strings.TrimSpace(string(caPEMs)) == "" {
break
}
}
if expired {
return pool, ErrExpired
}
return pool, nil
}
// AddCACertificate verifies a Nebula CA certificate and adds it to the pool
// Only the first pem encoded object will be consumed, any remaining bytes are returned.
// Parsed certificates will be verified and must be a CA
func (ncp *NebulaCAPool) AddCACertificate(pemBytes []byte) ([]byte, error) {
c, pemBytes, err := UnmarshalNebulaCertificateFromPEM(pemBytes)
if err != nil {
return pemBytes, err
}
if !c.Details.IsCA {
return pemBytes, fmt.Errorf("%s: %w", c.Details.Name, ErrNotCA)
}
if !c.CheckSignature(c.Details.PublicKey) {
return pemBytes, fmt.Errorf("%s: %w", c.Details.Name, ErrNotSelfSigned)
}
sum, err := c.Sha256Sum()
if err != nil {
return pemBytes, fmt.Errorf("could not calculate shasum for provided CA; error: %s; %s", err, c.Details.Name)
}
ncp.CAs[sum] = c
if c.Expired(time.Now()) {
return pemBytes, fmt.Errorf("%s: %w", c.Details.Name, ErrExpired)
}
return pemBytes, nil
}
// BlocklistFingerprint adds a cert fingerprint to the blocklist
func (ncp *NebulaCAPool) BlocklistFingerprint(f string) {
ncp.certBlocklist[f] = struct{}{}
}
// ResetCertBlocklist removes all previously blocklisted cert fingerprints
func (ncp *NebulaCAPool) ResetCertBlocklist() {
ncp.certBlocklist = make(map[string]struct{})
}
// NOTE: This uses an internal cache for Sha256Sum() that will not be invalidated
// automatically if you manually change any fields in the NebulaCertificate.
func (ncp *NebulaCAPool) IsBlocklisted(c *NebulaCertificate) bool {
return ncp.isBlocklistedWithCache(c, false)
}
// IsBlocklisted returns true if the fingerprint fails to generate or has been explicitly blocklisted
func (ncp *NebulaCAPool) isBlocklistedWithCache(c *NebulaCertificate, useCache bool) bool {
h, err := c.sha256SumWithCache(useCache)
if err != nil {
return true
}
if _, ok := ncp.certBlocklist[h]; ok {
return true
}
return false
}
// GetCAForCert attempts to return the signing certificate for the provided certificate.
// No signature validation is performed
func (ncp *NebulaCAPool) GetCAForCert(c *NebulaCertificate) (*NebulaCertificate, error) {
if c.Details.Issuer == "" {
return nil, fmt.Errorf("no issuer in certificate")
}
signer, ok := ncp.CAs[c.Details.Issuer]
if ok {
return signer, nil
}
return nil, fmt.Errorf("could not find ca for the certificate")
}
// GetFingerprints returns an array of trusted CA fingerprints
func (ncp *NebulaCAPool) GetFingerprints() []string {
fp := make([]string, len(ncp.CAs))
i := 0
for k := range ncp.CAs {
fp[i] = k
i++
}
return fp
}

296
cert/ca_pool.go Normal file
View File

@ -0,0 +1,296 @@
package cert
import (
"errors"
"fmt"
"net/netip"
"slices"
"strings"
"time"
)
type CAPool struct {
CAs map[string]*CachedCertificate
certBlocklist map[string]struct{}
}
// NewCAPool creates an empty CAPool
func NewCAPool() *CAPool {
ca := CAPool{
CAs: make(map[string]*CachedCertificate),
certBlocklist: make(map[string]struct{}),
}
return &ca
}
// NewCAPoolFromPEM will create a new CA pool from the provided
// input bytes, which must be a PEM-encoded set of nebula certificates.
// If the pool contains any expired certificates, an ErrExpired will be
// returned along with the pool. The caller must handle any such errors.
func NewCAPoolFromPEM(caPEMs []byte) (*CAPool, error) {
pool := NewCAPool()
var err error
var expired bool
for {
caPEMs, err = pool.AddCAFromPEM(caPEMs)
if errors.Is(err, ErrExpired) {
expired = true
err = nil
}
if err != nil {
return nil, err
}
if len(caPEMs) == 0 || strings.TrimSpace(string(caPEMs)) == "" {
break
}
}
if expired {
return pool, ErrExpired
}
return pool, nil
}
// AddCAFromPEM verifies a Nebula CA certificate and adds it to the pool.
// Only the first pem encoded object will be consumed, any remaining bytes are returned.
// Parsed certificates will be verified and must be a CA
func (ncp *CAPool) AddCAFromPEM(pemBytes []byte) ([]byte, error) {
c, pemBytes, err := UnmarshalCertificateFromPEM(pemBytes)
if err != nil {
return pemBytes, err
}
err = ncp.AddCA(c)
if err != nil {
return pemBytes, err
}
return pemBytes, nil
}
// AddCA verifies a Nebula CA certificate and adds it to the pool.
func (ncp *CAPool) AddCA(c Certificate) error {
if !c.IsCA() {
return fmt.Errorf("%s: %w", c.Name(), ErrNotCA)
}
if !c.CheckSignature(c.PublicKey()) {
return fmt.Errorf("%s: %w", c.Name(), ErrNotSelfSigned)
}
sum, err := c.Fingerprint()
if err != nil {
return fmt.Errorf("could not calculate fingerprint for provided CA; error: %w; %s", err, c.Name())
}
cc := &CachedCertificate{
Certificate: c,
Fingerprint: sum,
InvertedGroups: make(map[string]struct{}),
}
for _, g := range c.Groups() {
cc.InvertedGroups[g] = struct{}{}
}
ncp.CAs[sum] = cc
if c.Expired(time.Now()) {
return fmt.Errorf("%s: %w", c.Name(), ErrExpired)
}
return nil
}
// BlocklistFingerprint adds a cert fingerprint to the blocklist
func (ncp *CAPool) BlocklistFingerprint(f string) {
ncp.certBlocklist[f] = struct{}{}
}
// ResetCertBlocklist removes all previously blocklisted cert fingerprints
func (ncp *CAPool) ResetCertBlocklist() {
ncp.certBlocklist = make(map[string]struct{})
}
// IsBlocklisted tests the provided fingerprint against the pools blocklist.
// Returns true if the fingerprint is blocked.
func (ncp *CAPool) IsBlocklisted(fingerprint string) bool {
if _, ok := ncp.certBlocklist[fingerprint]; ok {
return true
}
return false
}
// VerifyCertificate verifies the certificate is valid and is signed by a trusted CA in the pool.
// If the certificate is valid then the returned CachedCertificate can be used in subsequent verification attempts
// to increase performance.
func (ncp *CAPool) VerifyCertificate(now time.Time, c Certificate) (*CachedCertificate, error) {
if c == nil {
return nil, fmt.Errorf("no certificate")
}
fp, err := c.Fingerprint()
if err != nil {
return nil, fmt.Errorf("could not calculate fingerprint to verify: %w", err)
}
signer, err := ncp.verify(c, now, fp, "")
if err != nil {
return nil, err
}
cc := CachedCertificate{
Certificate: c,
InvertedGroups: make(map[string]struct{}),
Fingerprint: fp,
signerFingerprint: signer.Fingerprint,
}
for _, g := range c.Groups() {
cc.InvertedGroups[g] = struct{}{}
}
return &cc, nil
}
// VerifyCachedCertificate is the same as VerifyCertificate other than it operates on a pre-verified structure and
// is a cheaper operation to perform as a result.
func (ncp *CAPool) VerifyCachedCertificate(now time.Time, c *CachedCertificate) error {
_, err := ncp.verify(c.Certificate, now, c.Fingerprint, c.signerFingerprint)
return err
}
func (ncp *CAPool) verify(c Certificate, now time.Time, certFp string, signerFp string) (*CachedCertificate, error) {
if ncp.IsBlocklisted(certFp) {
return nil, ErrBlockListed
}
signer, err := ncp.GetCAForCert(c)
if err != nil {
return nil, err
}
if signer.Certificate.Expired(now) {
return nil, ErrRootExpired
}
if c.Expired(now) {
return nil, ErrExpired
}
// If we are checking a cached certificate then we can bail early here
// Either the root is no longer trusted or everything is fine
if len(signerFp) > 0 {
if signerFp != signer.Fingerprint {
return nil, ErrFingerprintMismatch
}
return signer, nil
}
if !c.CheckSignature(signer.Certificate.PublicKey()) {
return nil, ErrSignatureMismatch
}
err = CheckCAConstraints(signer.Certificate, c)
if err != nil {
return nil, err
}
return signer, nil
}
// GetCAForCert attempts to return the signing certificate for the provided certificate.
// No signature validation is performed
func (ncp *CAPool) GetCAForCert(c Certificate) (*CachedCertificate, error) {
issuer := c.Issuer()
if issuer == "" {
return nil, fmt.Errorf("no issuer in certificate")
}
signer, ok := ncp.CAs[issuer]
if ok {
return signer, nil
}
return nil, ErrCaNotFound
}
// GetFingerprints returns an array of trusted CA fingerprints
func (ncp *CAPool) GetFingerprints() []string {
fp := make([]string, len(ncp.CAs))
i := 0
for k := range ncp.CAs {
fp[i] = k
i++
}
return fp
}
// CheckCAConstraints returns an error if the sub certificate violates constraints present in the signer certificate.
func CheckCAConstraints(signer Certificate, sub Certificate) error {
return checkCAConstraints(signer, sub.NotBefore(), sub.NotAfter(), sub.Groups(), sub.Networks(), sub.UnsafeNetworks())
}
// checkCAConstraints is a very generic function allowing both Certificates and TBSCertificates to be tested.
func checkCAConstraints(signer Certificate, notBefore, notAfter time.Time, groups []string, networks, unsafeNetworks []netip.Prefix) error {
// Make sure this cert isn't valid after the root
if notAfter.After(signer.NotAfter()) {
return fmt.Errorf("certificate expires after signing certificate")
}
// Make sure this cert wasn't valid before the root
if notBefore.Before(signer.NotBefore()) {
return fmt.Errorf("certificate is valid before the signing certificate")
}
// If the signer has a limited set of groups make sure the cert only contains a subset
signerGroups := signer.Groups()
if len(signerGroups) > 0 {
for _, g := range groups {
if !slices.Contains(signerGroups, g) {
return fmt.Errorf("certificate contained a group not present on the signing ca: %s", g)
}
}
}
// If the signer has a limited set of ip ranges to issue from make sure the cert only contains a subset
signingNetworks := signer.Networks()
if len(signingNetworks) > 0 {
for _, certNetwork := range networks {
found := false
for _, signingNetwork := range signingNetworks {
if signingNetwork.Contains(certNetwork.Addr()) && signingNetwork.Bits() <= certNetwork.Bits() {
found = true
break
}
}
if !found {
return fmt.Errorf("certificate contained a network assignment outside the limitations of the signing ca: %s", certNetwork.String())
}
}
}
// If the signer has a limited set of subnet ranges to issue from make sure the cert only contains a subset
signingUnsafeNetworks := signer.UnsafeNetworks()
if len(signingUnsafeNetworks) > 0 {
for _, certUnsafeNetwork := range unsafeNetworks {
found := false
for _, caNetwork := range signingUnsafeNetworks {
if caNetwork.Contains(certUnsafeNetwork.Addr()) && caNetwork.Bits() <= certUnsafeNetwork.Bits() {
found = true
break
}
}
if !found {
return fmt.Errorf("certificate contained an unsafe network assignment outside the limitations of the signing ca: %s", certUnsafeNetwork.String())
}
}
}
return nil
}

560
cert/ca_pool_test.go Normal file
View File

@ -0,0 +1,560 @@
package cert
import (
"net/netip"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestNewCAPoolFromBytes(t *testing.T) {
noNewLines := `
# Current provisional, Remove once everything moves over to the real root.
-----BEGIN NEBULA CERTIFICATE-----
Cj4KDm5lYnVsYSByb290IGNhKM0cMM24zPCvBzogV24YEw5YiqeI/oYo8XXFsoo+
PBmiOafNJhLacf9rsspAARJAz9OAnh8TKAUKix1kKVMyQU4iM3LsFfZRf6ODWXIf
2qWMpB6fpd3PSoVYziPoOt2bIHIFLlgRLPJz3I3xBEdBCQ==
-----END NEBULA CERTIFICATE-----
# root-ca01
-----BEGIN NEBULA CERTIFICATE-----
CkEKEW5lYnVsYSByb290IGNhIDAxKM0cMM24zPCvBzogPzbWTxt8ZgXPQEwup7Br
BrtIt1O0q5AuTRT3+t2x1VJAARJAZ+2ib23qBXjdy49oU1YysrwuKkWWKrtJ7Jye
rFBQpDXikOukhQD/mfkloFwJ+Yjsfru7IpTN4ZfjXL+kN/2sCA==
-----END NEBULA CERTIFICATE-----
`
withNewLines := `
# Current provisional, Remove once everything moves over to the real root.
-----BEGIN NEBULA CERTIFICATE-----
Cj4KDm5lYnVsYSByb290IGNhKM0cMM24zPCvBzogV24YEw5YiqeI/oYo8XXFsoo+
PBmiOafNJhLacf9rsspAARJAz9OAnh8TKAUKix1kKVMyQU4iM3LsFfZRf6ODWXIf
2qWMpB6fpd3PSoVYziPoOt2bIHIFLlgRLPJz3I3xBEdBCQ==
-----END NEBULA CERTIFICATE-----
# root-ca01
-----BEGIN NEBULA CERTIFICATE-----
CkEKEW5lYnVsYSByb290IGNhIDAxKM0cMM24zPCvBzogPzbWTxt8ZgXPQEwup7Br
BrtIt1O0q5AuTRT3+t2x1VJAARJAZ+2ib23qBXjdy49oU1YysrwuKkWWKrtJ7Jye
rFBQpDXikOukhQD/mfkloFwJ+Yjsfru7IpTN4ZfjXL+kN/2sCA==
-----END NEBULA CERTIFICATE-----
`
expired := `
# expired certificate
-----BEGIN NEBULA CERTIFICATE-----
CjMKB2V4cGlyZWQozRwwzRw6ICJSG94CqX8wn5I65Pwn25V6HftVfWeIySVtp2DA
7TY/QAESQMaAk5iJT5EnQwK524ZaaHGEJLUqqbh5yyOHhboIGiVTWkFeH3HccTW8
Tq5a8AyWDQdfXbtEZ1FwabeHfH5Asw0=
-----END NEBULA CERTIFICATE-----
`
p256 := `
# p256 certificate
-----BEGIN NEBULA CERTIFICATE-----
CmQKEG5lYnVsYSBQMjU2IHRlc3QozRwwzbjM8K8HOkEEdrmmg40zQp44AkMq6DZp
k+coOv04r+zh33ISyhbsafnYduN17p2eD7CmHvHuerguXD9f32gcxo/KsFCKEjMe
+0ABoAYBEkcwRQIgVoTg38L7uWku9xQgsr06kxZ/viQLOO/w1Qj1vFUEnhcCIQCq
75SjTiV92kv/1GcbT3wWpAZQQDBiUHVMVmh1822szA==
-----END NEBULA CERTIFICATE-----
`
rootCA := certificateV1{
details: detailsV1{
name: "nebula root ca",
},
}
rootCA01 := certificateV1{
details: detailsV1{
name: "nebula root ca 01",
},
}
rootCAP256 := certificateV1{
details: detailsV1{
name: "nebula P256 test",
},
}
p, err := NewCAPoolFromPEM([]byte(noNewLines))
require.NoError(t, err)
assert.Equal(t, p.CAs["ce4e6c7a596996eb0d82a8875f0f0137a4b53ce22d2421c9fd7150e7a26f6300"].Certificate.Name(), rootCA.details.name)
assert.Equal(t, p.CAs["04c585fcd9a49b276df956a22b7ebea3bf23f1fca5a17c0b56ce2e626631969e"].Certificate.Name(), rootCA01.details.name)
pp, err := NewCAPoolFromPEM([]byte(withNewLines))
require.NoError(t, err)
assert.Equal(t, pp.CAs["ce4e6c7a596996eb0d82a8875f0f0137a4b53ce22d2421c9fd7150e7a26f6300"].Certificate.Name(), rootCA.details.name)
assert.Equal(t, pp.CAs["04c585fcd9a49b276df956a22b7ebea3bf23f1fca5a17c0b56ce2e626631969e"].Certificate.Name(), rootCA01.details.name)
// expired cert, no valid certs
ppp, err := NewCAPoolFromPEM([]byte(expired))
assert.Equal(t, ErrExpired, err)
assert.Equal(t, "expired", ppp.CAs["c39b35a0e8f246203fe4f32b9aa8bfd155f1ae6a6be9d78370641e43397f48f5"].Certificate.Name())
// expired cert, with valid certs
pppp, err := NewCAPoolFromPEM(append([]byte(expired), noNewLines...))
assert.Equal(t, ErrExpired, err)
assert.Equal(t, pppp.CAs["ce4e6c7a596996eb0d82a8875f0f0137a4b53ce22d2421c9fd7150e7a26f6300"].Certificate.Name(), rootCA.details.name)
assert.Equal(t, pppp.CAs["04c585fcd9a49b276df956a22b7ebea3bf23f1fca5a17c0b56ce2e626631969e"].Certificate.Name(), rootCA01.details.name)
assert.Equal(t, "expired", pppp.CAs["c39b35a0e8f246203fe4f32b9aa8bfd155f1ae6a6be9d78370641e43397f48f5"].Certificate.Name())
assert.Len(t, pppp.CAs, 3)
ppppp, err := NewCAPoolFromPEM([]byte(p256))
require.NoError(t, err)
assert.Equal(t, ppppp.CAs["552bf7d99bec1fc775a0e4c324bf6d8f789b3078f1919c7960d2e5e0c351ee97"].Certificate.Name(), rootCAP256.details.name)
assert.Len(t, ppppp.CAs, 1)
}
func TestCertificateV1_Verify(t *testing.T) {
ca, _, caKey, _ := NewTestCaCert(Version1, Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, nil)
c, _, _, _ := NewTestCert(Version1, Curve_CURVE25519, ca, caKey, "test cert", time.Now(), time.Now().Add(5*time.Minute), nil, nil, nil)
caPool := NewCAPool()
require.NoError(t, caPool.AddCA(ca))
f, err := c.Fingerprint()
require.NoError(t, err)
caPool.BlocklistFingerprint(f)
_, err = caPool.VerifyCertificate(time.Now(), c)
require.EqualError(t, err, "certificate is in the block list")
caPool.ResetCertBlocklist()
_, err = caPool.VerifyCertificate(time.Now(), c)
require.NoError(t, err)
_, err = caPool.VerifyCertificate(time.Now().Add(time.Hour*1000), c)
require.EqualError(t, err, "root certificate is expired")
assert.PanicsWithError(t, "certificate is valid before the signing certificate", func() {
NewTestCert(Version1, Curve_CURVE25519, ca, caKey, "test cert2", time.Time{}, time.Time{}, nil, nil, nil)
})
// Test group assertion
ca, _, caKey, _ = NewTestCaCert(Version1, Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{"test1", "test2"})
caPem, err := ca.MarshalPEM()
require.NoError(t, err)
caPool = NewCAPool()
b, err := caPool.AddCAFromPEM(caPem)
require.NoError(t, err)
assert.Empty(t, b)
assert.PanicsWithError(t, "certificate contained a group not present on the signing ca: bad", func() {
NewTestCert(Version1, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, nil, []string{"test1", "bad"})
})
c, _, _, _ = NewTestCert(Version1, Curve_CURVE25519, ca, caKey, "test2", time.Now(), time.Now().Add(5*time.Minute), nil, nil, []string{"test1"})
require.NoError(t, err)
_, err = caPool.VerifyCertificate(time.Now(), c)
require.NoError(t, err)
}
func TestCertificateV1_VerifyP256(t *testing.T) {
ca, _, caKey, _ := NewTestCaCert(Version1, Curve_P256, time.Now(), time.Now().Add(10*time.Minute), nil, nil, nil)
c, _, _, _ := NewTestCert(Version1, Curve_P256, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, nil, nil)
caPool := NewCAPool()
require.NoError(t, caPool.AddCA(ca))
f, err := c.Fingerprint()
require.NoError(t, err)
caPool.BlocklistFingerprint(f)
_, err = caPool.VerifyCertificate(time.Now(), c)
require.EqualError(t, err, "certificate is in the block list")
caPool.ResetCertBlocklist()
_, err = caPool.VerifyCertificate(time.Now(), c)
require.NoError(t, err)
_, err = caPool.VerifyCertificate(time.Now().Add(time.Hour*1000), c)
require.EqualError(t, err, "root certificate is expired")
assert.PanicsWithError(t, "certificate is valid before the signing certificate", func() {
NewTestCert(Version1, Curve_P256, ca, caKey, "test", time.Time{}, time.Time{}, nil, nil, nil)
})
// Test group assertion
ca, _, caKey, _ = NewTestCaCert(Version1, Curve_P256, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{"test1", "test2"})
caPem, err := ca.MarshalPEM()
require.NoError(t, err)
caPool = NewCAPool()
b, err := caPool.AddCAFromPEM(caPem)
require.NoError(t, err)
assert.Empty(t, b)
assert.PanicsWithError(t, "certificate contained a group not present on the signing ca: bad", func() {
NewTestCert(Version1, Curve_P256, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, nil, []string{"test1", "bad"})
})
c, _, _, _ = NewTestCert(Version1, Curve_P256, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, nil, []string{"test1"})
_, err = caPool.VerifyCertificate(time.Now(), c)
require.NoError(t, err)
}
func TestCertificateV1_Verify_IPs(t *testing.T) {
caIp1 := mustParsePrefixUnmapped("10.0.0.0/16")
caIp2 := mustParsePrefixUnmapped("192.168.0.0/24")
ca, _, caKey, _ := NewTestCaCert(Version1, Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), []netip.Prefix{caIp1, caIp2}, nil, []string{"test"})
caPem, err := ca.MarshalPEM()
require.NoError(t, err)
caPool := NewCAPool()
b, err := caPool.AddCAFromPEM(caPem)
require.NoError(t, err)
assert.Empty(t, b)
// ip is outside the network
cIp1 := mustParsePrefixUnmapped("10.1.0.0/24")
cIp2 := mustParsePrefixUnmapped("192.168.0.1/16")
assert.PanicsWithError(t, "certificate contained a network assignment outside the limitations of the signing ca: 10.1.0.0/24", func() {
NewTestCert(Version1, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{cIp1, cIp2}, nil, []string{"test"})
})
// ip is outside the network reversed order of above
cIp1 = mustParsePrefixUnmapped("192.168.0.1/24")
cIp2 = mustParsePrefixUnmapped("10.1.0.0/24")
assert.PanicsWithError(t, "certificate contained a network assignment outside the limitations of the signing ca: 10.1.0.0/24", func() {
NewTestCert(Version1, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{cIp1, cIp2}, nil, []string{"test"})
})
// ip is within the network but mask is outside
cIp1 = mustParsePrefixUnmapped("10.0.1.0/15")
cIp2 = mustParsePrefixUnmapped("192.168.0.1/24")
assert.PanicsWithError(t, "certificate contained a network assignment outside the limitations of the signing ca: 10.0.1.0/15", func() {
NewTestCert(Version1, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{cIp1, cIp2}, nil, []string{"test"})
})
// ip is within the network but mask is outside reversed order of above
cIp1 = mustParsePrefixUnmapped("192.168.0.1/24")
cIp2 = mustParsePrefixUnmapped("10.0.1.0/15")
assert.PanicsWithError(t, "certificate contained a network assignment outside the limitations of the signing ca: 10.0.1.0/15", func() {
NewTestCert(Version1, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{cIp1, cIp2}, nil, []string{"test"})
})
// ip and mask are within the network
cIp1 = mustParsePrefixUnmapped("10.0.1.0/16")
cIp2 = mustParsePrefixUnmapped("192.168.0.1/25")
c, _, _, _ := NewTestCert(Version1, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{cIp1, cIp2}, nil, []string{"test"})
_, err = caPool.VerifyCertificate(time.Now(), c)
require.NoError(t, err)
// Exact matches
c, _, _, _ = NewTestCert(Version1, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{caIp1, caIp2}, nil, []string{"test"})
require.NoError(t, err)
_, err = caPool.VerifyCertificate(time.Now(), c)
require.NoError(t, err)
// Exact matches reversed
c, _, _, _ = NewTestCert(Version1, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{caIp2, caIp1}, nil, []string{"test"})
require.NoError(t, err)
_, err = caPool.VerifyCertificate(time.Now(), c)
require.NoError(t, err)
// Exact matches reversed with just 1
c, _, _, _ = NewTestCert(Version1, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{caIp1}, nil, []string{"test"})
require.NoError(t, err)
_, err = caPool.VerifyCertificate(time.Now(), c)
require.NoError(t, err)
}
func TestCertificateV1_Verify_Subnets(t *testing.T) {
caIp1 := mustParsePrefixUnmapped("10.0.0.0/16")
caIp2 := mustParsePrefixUnmapped("192.168.0.0/24")
ca, _, caKey, _ := NewTestCaCert(Version1, Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, []netip.Prefix{caIp1, caIp2}, []string{"test"})
caPem, err := ca.MarshalPEM()
require.NoError(t, err)
caPool := NewCAPool()
b, err := caPool.AddCAFromPEM(caPem)
require.NoError(t, err)
assert.Empty(t, b)
// ip is outside the network
cIp1 := mustParsePrefixUnmapped("10.1.0.0/24")
cIp2 := mustParsePrefixUnmapped("192.168.0.1/16")
assert.PanicsWithError(t, "certificate contained an unsafe network assignment outside the limitations of the signing ca: 10.1.0.0/24", func() {
NewTestCert(Version1, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, []netip.Prefix{cIp1, cIp2}, []string{"test"})
})
// ip is outside the network reversed order of above
cIp1 = mustParsePrefixUnmapped("192.168.0.1/24")
cIp2 = mustParsePrefixUnmapped("10.1.0.0/24")
assert.PanicsWithError(t, "certificate contained an unsafe network assignment outside the limitations of the signing ca: 10.1.0.0/24", func() {
NewTestCert(Version1, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, []netip.Prefix{cIp1, cIp2}, []string{"test"})
})
// ip is within the network but mask is outside
cIp1 = mustParsePrefixUnmapped("10.0.1.0/15")
cIp2 = mustParsePrefixUnmapped("192.168.0.1/24")
assert.PanicsWithError(t, "certificate contained an unsafe network assignment outside the limitations of the signing ca: 10.0.1.0/15", func() {
NewTestCert(Version1, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, []netip.Prefix{cIp1, cIp2}, []string{"test"})
})
// ip is within the network but mask is outside reversed order of above
cIp1 = mustParsePrefixUnmapped("192.168.0.1/24")
cIp2 = mustParsePrefixUnmapped("10.0.1.0/15")
assert.PanicsWithError(t, "certificate contained an unsafe network assignment outside the limitations of the signing ca: 10.0.1.0/15", func() {
NewTestCert(Version1, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, []netip.Prefix{cIp1, cIp2}, []string{"test"})
})
// ip and mask are within the network
cIp1 = mustParsePrefixUnmapped("10.0.1.0/16")
cIp2 = mustParsePrefixUnmapped("192.168.0.1/25")
c, _, _, _ := NewTestCert(Version1, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, []netip.Prefix{cIp1, cIp2}, []string{"test"})
require.NoError(t, err)
_, err = caPool.VerifyCertificate(time.Now(), c)
require.NoError(t, err)
// Exact matches
c, _, _, _ = NewTestCert(Version1, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, []netip.Prefix{caIp1, caIp2}, []string{"test"})
require.NoError(t, err)
_, err = caPool.VerifyCertificate(time.Now(), c)
require.NoError(t, err)
// Exact matches reversed
c, _, _, _ = NewTestCert(Version1, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, []netip.Prefix{caIp2, caIp1}, []string{"test"})
require.NoError(t, err)
_, err = caPool.VerifyCertificate(time.Now(), c)
require.NoError(t, err)
// Exact matches reversed with just 1
c, _, _, _ = NewTestCert(Version1, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, []netip.Prefix{caIp1}, []string{"test"})
require.NoError(t, err)
_, err = caPool.VerifyCertificate(time.Now(), c)
require.NoError(t, err)
}
func TestCertificateV2_Verify(t *testing.T) {
ca, _, caKey, _ := NewTestCaCert(Version2, Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, nil)
c, _, _, _ := NewTestCert(Version2, Curve_CURVE25519, ca, caKey, "test cert", time.Now(), time.Now().Add(5*time.Minute), nil, nil, nil)
caPool := NewCAPool()
require.NoError(t, caPool.AddCA(ca))
f, err := c.Fingerprint()
require.NoError(t, err)
caPool.BlocklistFingerprint(f)
_, err = caPool.VerifyCertificate(time.Now(), c)
require.EqualError(t, err, "certificate is in the block list")
caPool.ResetCertBlocklist()
_, err = caPool.VerifyCertificate(time.Now(), c)
require.NoError(t, err)
_, err = caPool.VerifyCertificate(time.Now().Add(time.Hour*1000), c)
require.EqualError(t, err, "root certificate is expired")
assert.PanicsWithError(t, "certificate is valid before the signing certificate", func() {
NewTestCert(Version2, Curve_CURVE25519, ca, caKey, "test cert2", time.Time{}, time.Time{}, nil, nil, nil)
})
// Test group assertion
ca, _, caKey, _ = NewTestCaCert(Version2, Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{"test1", "test2"})
caPem, err := ca.MarshalPEM()
require.NoError(t, err)
caPool = NewCAPool()
b, err := caPool.AddCAFromPEM(caPem)
require.NoError(t, err)
assert.Empty(t, b)
assert.PanicsWithError(t, "certificate contained a group not present on the signing ca: bad", func() {
NewTestCert(Version2, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, nil, []string{"test1", "bad"})
})
c, _, _, _ = NewTestCert(Version2, Curve_CURVE25519, ca, caKey, "test2", time.Now(), time.Now().Add(5*time.Minute), nil, nil, []string{"test1"})
require.NoError(t, err)
_, err = caPool.VerifyCertificate(time.Now(), c)
require.NoError(t, err)
}
func TestCertificateV2_VerifyP256(t *testing.T) {
ca, _, caKey, _ := NewTestCaCert(Version2, Curve_P256, time.Now(), time.Now().Add(10*time.Minute), nil, nil, nil)
c, _, _, _ := NewTestCert(Version2, Curve_P256, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, nil, nil)
caPool := NewCAPool()
require.NoError(t, caPool.AddCA(ca))
f, err := c.Fingerprint()
require.NoError(t, err)
caPool.BlocklistFingerprint(f)
_, err = caPool.VerifyCertificate(time.Now(), c)
require.EqualError(t, err, "certificate is in the block list")
caPool.ResetCertBlocklist()
_, err = caPool.VerifyCertificate(time.Now(), c)
require.NoError(t, err)
_, err = caPool.VerifyCertificate(time.Now().Add(time.Hour*1000), c)
require.EqualError(t, err, "root certificate is expired")
assert.PanicsWithError(t, "certificate is valid before the signing certificate", func() {
NewTestCert(Version2, Curve_P256, ca, caKey, "test", time.Time{}, time.Time{}, nil, nil, nil)
})
// Test group assertion
ca, _, caKey, _ = NewTestCaCert(Version2, Curve_P256, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{"test1", "test2"})
caPem, err := ca.MarshalPEM()
require.NoError(t, err)
caPool = NewCAPool()
b, err := caPool.AddCAFromPEM(caPem)
require.NoError(t, err)
assert.Empty(t, b)
assert.PanicsWithError(t, "certificate contained a group not present on the signing ca: bad", func() {
NewTestCert(Version2, Curve_P256, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, nil, []string{"test1", "bad"})
})
c, _, _, _ = NewTestCert(Version2, Curve_P256, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, nil, []string{"test1"})
_, err = caPool.VerifyCertificate(time.Now(), c)
require.NoError(t, err)
}
func TestCertificateV2_Verify_IPs(t *testing.T) {
caIp1 := mustParsePrefixUnmapped("10.0.0.0/16")
caIp2 := mustParsePrefixUnmapped("192.168.0.0/24")
ca, _, caKey, _ := NewTestCaCert(Version2, Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), []netip.Prefix{caIp1, caIp2}, nil, []string{"test"})
caPem, err := ca.MarshalPEM()
require.NoError(t, err)
caPool := NewCAPool()
b, err := caPool.AddCAFromPEM(caPem)
require.NoError(t, err)
assert.Empty(t, b)
// ip is outside the network
cIp1 := mustParsePrefixUnmapped("10.1.0.0/24")
cIp2 := mustParsePrefixUnmapped("192.168.0.1/16")
assert.PanicsWithError(t, "certificate contained a network assignment outside the limitations of the signing ca: 10.1.0.0/24", func() {
NewTestCert(Version2, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{cIp1, cIp2}, nil, []string{"test"})
})
// ip is outside the network reversed order of above
cIp1 = mustParsePrefixUnmapped("192.168.0.1/24")
cIp2 = mustParsePrefixUnmapped("10.1.0.0/24")
assert.PanicsWithError(t, "certificate contained a network assignment outside the limitations of the signing ca: 10.1.0.0/24", func() {
NewTestCert(Version2, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{cIp1, cIp2}, nil, []string{"test"})
})
// ip is within the network but mask is outside
cIp1 = mustParsePrefixUnmapped("10.0.1.0/15")
cIp2 = mustParsePrefixUnmapped("192.168.0.1/24")
assert.PanicsWithError(t, "certificate contained a network assignment outside the limitations of the signing ca: 10.0.1.0/15", func() {
NewTestCert(Version2, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{cIp1, cIp2}, nil, []string{"test"})
})
// ip is within the network but mask is outside reversed order of above
cIp1 = mustParsePrefixUnmapped("192.168.0.1/24")
cIp2 = mustParsePrefixUnmapped("10.0.1.0/15")
assert.PanicsWithError(t, "certificate contained a network assignment outside the limitations of the signing ca: 10.0.1.0/15", func() {
NewTestCert(Version2, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{cIp1, cIp2}, nil, []string{"test"})
})
// ip and mask are within the network
cIp1 = mustParsePrefixUnmapped("10.0.1.0/16")
cIp2 = mustParsePrefixUnmapped("192.168.0.1/25")
c, _, _, _ := NewTestCert(Version2, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{cIp1, cIp2}, nil, []string{"test"})
_, err = caPool.VerifyCertificate(time.Now(), c)
require.NoError(t, err)
// Exact matches
c, _, _, _ = NewTestCert(Version2, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{caIp1, caIp2}, nil, []string{"test"})
require.NoError(t, err)
_, err = caPool.VerifyCertificate(time.Now(), c)
require.NoError(t, err)
// Exact matches reversed
c, _, _, _ = NewTestCert(Version2, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{caIp2, caIp1}, nil, []string{"test"})
require.NoError(t, err)
_, err = caPool.VerifyCertificate(time.Now(), c)
require.NoError(t, err)
// Exact matches reversed with just 1
c, _, _, _ = NewTestCert(Version2, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{caIp1}, nil, []string{"test"})
require.NoError(t, err)
_, err = caPool.VerifyCertificate(time.Now(), c)
require.NoError(t, err)
}
func TestCertificateV2_Verify_Subnets(t *testing.T) {
caIp1 := mustParsePrefixUnmapped("10.0.0.0/16")
caIp2 := mustParsePrefixUnmapped("192.168.0.0/24")
ca, _, caKey, _ := NewTestCaCert(Version2, Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, []netip.Prefix{caIp1, caIp2}, []string{"test"})
caPem, err := ca.MarshalPEM()
require.NoError(t, err)
caPool := NewCAPool()
b, err := caPool.AddCAFromPEM(caPem)
require.NoError(t, err)
assert.Empty(t, b)
// ip is outside the network
cIp1 := mustParsePrefixUnmapped("10.1.0.0/24")
cIp2 := mustParsePrefixUnmapped("192.168.0.1/16")
assert.PanicsWithError(t, "certificate contained an unsafe network assignment outside the limitations of the signing ca: 10.1.0.0/24", func() {
NewTestCert(Version2, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, []netip.Prefix{cIp1, cIp2}, []string{"test"})
})
// ip is outside the network reversed order of above
cIp1 = mustParsePrefixUnmapped("192.168.0.1/24")
cIp2 = mustParsePrefixUnmapped("10.1.0.0/24")
assert.PanicsWithError(t, "certificate contained an unsafe network assignment outside the limitations of the signing ca: 10.1.0.0/24", func() {
NewTestCert(Version2, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, []netip.Prefix{cIp1, cIp2}, []string{"test"})
})
// ip is within the network but mask is outside
cIp1 = mustParsePrefixUnmapped("10.0.1.0/15")
cIp2 = mustParsePrefixUnmapped("192.168.0.1/24")
assert.PanicsWithError(t, "certificate contained an unsafe network assignment outside the limitations of the signing ca: 10.0.1.0/15", func() {
NewTestCert(Version2, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, []netip.Prefix{cIp1, cIp2}, []string{"test"})
})
// ip is within the network but mask is outside reversed order of above
cIp1 = mustParsePrefixUnmapped("192.168.0.1/24")
cIp2 = mustParsePrefixUnmapped("10.0.1.0/15")
assert.PanicsWithError(t, "certificate contained an unsafe network assignment outside the limitations of the signing ca: 10.0.1.0/15", func() {
NewTestCert(Version2, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, []netip.Prefix{cIp1, cIp2}, []string{"test"})
})
// ip and mask are within the network
cIp1 = mustParsePrefixUnmapped("10.0.1.0/16")
cIp2 = mustParsePrefixUnmapped("192.168.0.1/25")
c, _, _, _ := NewTestCert(Version2, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, []netip.Prefix{cIp1, cIp2}, []string{"test"})
require.NoError(t, err)
_, err = caPool.VerifyCertificate(time.Now(), c)
require.NoError(t, err)
// Exact matches
c, _, _, _ = NewTestCert(Version2, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, []netip.Prefix{caIp1, caIp2}, []string{"test"})
require.NoError(t, err)
_, err = caPool.VerifyCertificate(time.Now(), c)
require.NoError(t, err)
// Exact matches reversed
c, _, _, _ = NewTestCert(Version2, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, []netip.Prefix{caIp2, caIp1}, []string{"test"})
require.NoError(t, err)
_, err = caPool.VerifyCertificate(time.Now(), c)
require.NoError(t, err)
// Exact matches reversed with just 1
c, _, _, _ = NewTestCert(Version2, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, []netip.Prefix{caIp1}, []string{"test"})
require.NoError(t, err)
_, err = caPool.VerifyCertificate(time.Now(), c)
require.NoError(t, err)
}

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

489
cert/cert_v1.go Normal file
View File

@ -0,0 +1,489 @@
package cert
import (
"bytes"
"crypto/ecdh"
"crypto/ecdsa"
"crypto/ed25519"
"crypto/elliptic"
"crypto/sha256"
"encoding/binary"
"encoding/hex"
"encoding/json"
"encoding/pem"
"fmt"
"net"
"net/netip"
"time"
"golang.org/x/crypto/curve25519"
"google.golang.org/protobuf/proto"
)
const publicKeyLen = 32
type certificateV1 struct {
details detailsV1
signature []byte
}
type detailsV1 struct {
name string
networks []netip.Prefix
unsafeNetworks []netip.Prefix
groups []string
notBefore time.Time
notAfter time.Time
publicKey []byte
isCA bool
issuer string
curve Curve
}
type m map[string]interface{}
func (c *certificateV1) Version() Version {
return Version1
}
func (c *certificateV1) Curve() Curve {
return c.details.curve
}
func (c *certificateV1) Groups() []string {
return c.details.groups
}
func (c *certificateV1) IsCA() bool {
return c.details.isCA
}
func (c *certificateV1) Issuer() string {
return c.details.issuer
}
func (c *certificateV1) Name() string {
return c.details.name
}
func (c *certificateV1) Networks() []netip.Prefix {
return c.details.networks
}
func (c *certificateV1) NotAfter() time.Time {
return c.details.notAfter
}
func (c *certificateV1) NotBefore() time.Time {
return c.details.notBefore
}
func (c *certificateV1) PublicKey() []byte {
return c.details.publicKey
}
func (c *certificateV1) Signature() []byte {
return c.signature
}
func (c *certificateV1) UnsafeNetworks() []netip.Prefix {
return c.details.unsafeNetworks
}
func (c *certificateV1) Fingerprint() (string, error) {
b, err := c.Marshal()
if err != nil {
return "", err
}
sum := sha256.Sum256(b)
return hex.EncodeToString(sum[:]), nil
}
func (c *certificateV1) CheckSignature(key []byte) bool {
b, err := proto.Marshal(c.getRawDetails())
if err != nil {
return false
}
switch c.details.curve {
case Curve_CURVE25519:
return ed25519.Verify(key, b, c.signature)
case Curve_P256:
x, y := elliptic.Unmarshal(elliptic.P256(), key)
pubKey := &ecdsa.PublicKey{Curve: elliptic.P256(), X: x, Y: y}
hashed := sha256.Sum256(b)
return ecdsa.VerifyASN1(pubKey, hashed[:], c.signature)
default:
return false
}
}
func (c *certificateV1) Expired(t time.Time) bool {
return c.details.notBefore.After(t) || c.details.notAfter.Before(t)
}
func (c *certificateV1) VerifyPrivateKey(curve Curve, key []byte) error {
if curve != c.details.curve {
return fmt.Errorf("curve in cert and private key supplied don't match")
}
if c.details.isCA {
switch curve {
case Curve_CURVE25519:
// the call to PublicKey below will panic slice bounds out of range otherwise
if len(key) != ed25519.PrivateKeySize {
return fmt.Errorf("key was not 64 bytes, is invalid ed25519 private key")
}
if !ed25519.PublicKey(c.details.publicKey).Equal(ed25519.PrivateKey(key).Public()) {
return fmt.Errorf("public key in cert and private key supplied don't match")
}
case Curve_P256:
privkey, err := ecdh.P256().NewPrivateKey(key)
if err != nil {
return fmt.Errorf("cannot parse private key as P256: %w", err)
}
pub := privkey.PublicKey().Bytes()
if !bytes.Equal(pub, c.details.publicKey) {
return fmt.Errorf("public key in cert and private key supplied don't match")
}
default:
return fmt.Errorf("invalid curve: %s", curve)
}
return nil
}
var pub []byte
switch curve {
case Curve_CURVE25519:
var err error
pub, err = curve25519.X25519(key, curve25519.Basepoint)
if err != nil {
return err
}
case Curve_P256:
privkey, err := ecdh.P256().NewPrivateKey(key)
if err != nil {
return err
}
pub = privkey.PublicKey().Bytes()
default:
return fmt.Errorf("invalid curve: %s", curve)
}
if !bytes.Equal(pub, c.details.publicKey) {
return fmt.Errorf("public key in cert and private key supplied don't match")
}
return nil
}
// getRawDetails marshals the raw details into protobuf ready struct
func (c *certificateV1) getRawDetails() *RawNebulaCertificateDetails {
rd := &RawNebulaCertificateDetails{
Name: c.details.name,
Groups: c.details.groups,
NotBefore: c.details.notBefore.Unix(),
NotAfter: c.details.notAfter.Unix(),
PublicKey: make([]byte, len(c.details.publicKey)),
IsCA: c.details.isCA,
Curve: c.details.curve,
}
for _, ipNet := range c.details.networks {
mask := net.CIDRMask(ipNet.Bits(), ipNet.Addr().BitLen())
rd.Ips = append(rd.Ips, addr2int(ipNet.Addr()), ip2int(mask))
}
for _, ipNet := range c.details.unsafeNetworks {
mask := net.CIDRMask(ipNet.Bits(), ipNet.Addr().BitLen())
rd.Subnets = append(rd.Subnets, addr2int(ipNet.Addr()), ip2int(mask))
}
copy(rd.PublicKey, c.details.publicKey[:])
// I know, this is terrible
rd.Issuer, _ = hex.DecodeString(c.details.issuer)
return rd
}
func (c *certificateV1) String() string {
b, err := json.MarshalIndent(c.marshalJSON(), "", "\t")
if err != nil {
return fmt.Sprintf("<error marshalling certificate: %v>", err)
}
return string(b)
}
func (c *certificateV1) MarshalForHandshakes() ([]byte, error) {
pubKey := c.details.publicKey
c.details.publicKey = nil
rawCertNoKey, err := c.Marshal()
if err != nil {
return nil, err
}
c.details.publicKey = pubKey
return rawCertNoKey, nil
}
func (c *certificateV1) Marshal() ([]byte, error) {
rc := RawNebulaCertificate{
Details: c.getRawDetails(),
Signature: c.signature,
}
return proto.Marshal(&rc)
}
func (c *certificateV1) MarshalPEM() ([]byte, error) {
b, err := c.Marshal()
if err != nil {
return nil, err
}
return pem.EncodeToMemory(&pem.Block{Type: CertificateBanner, Bytes: b}), nil
}
func (c *certificateV1) MarshalJSON() ([]byte, error) {
return json.Marshal(c.marshalJSON())
}
func (c *certificateV1) marshalJSON() m {
fp, _ := c.Fingerprint()
return m{
"version": Version1,
"details": m{
"name": c.details.name,
"networks": c.details.networks,
"unsafeNetworks": c.details.unsafeNetworks,
"groups": c.details.groups,
"notBefore": c.details.notBefore,
"notAfter": c.details.notAfter,
"publicKey": fmt.Sprintf("%x", c.details.publicKey),
"isCa": c.details.isCA,
"issuer": c.details.issuer,
"curve": c.details.curve.String(),
},
"fingerprint": fp,
"signature": fmt.Sprintf("%x", c.Signature()),
}
}
func (c *certificateV1) Copy() Certificate {
nc := &certificateV1{
details: detailsV1{
name: c.details.name,
notBefore: c.details.notBefore,
notAfter: c.details.notAfter,
publicKey: make([]byte, len(c.details.publicKey)),
isCA: c.details.isCA,
issuer: c.details.issuer,
curve: c.details.curve,
},
signature: make([]byte, len(c.signature)),
}
if c.details.groups != nil {
nc.details.groups = make([]string, len(c.details.groups))
copy(nc.details.groups, c.details.groups)
}
if c.details.networks != nil {
nc.details.networks = make([]netip.Prefix, len(c.details.networks))
copy(nc.details.networks, c.details.networks)
}
if c.details.unsafeNetworks != nil {
nc.details.unsafeNetworks = make([]netip.Prefix, len(c.details.unsafeNetworks))
copy(nc.details.unsafeNetworks, c.details.unsafeNetworks)
}
copy(nc.signature, c.signature)
copy(nc.details.publicKey, c.details.publicKey)
return nc
}
func (c *certificateV1) fromTBSCertificate(t *TBSCertificate) error {
c.details = detailsV1{
name: t.Name,
networks: t.Networks,
unsafeNetworks: t.UnsafeNetworks,
groups: t.Groups,
notBefore: t.NotBefore,
notAfter: t.NotAfter,
publicKey: t.PublicKey,
isCA: t.IsCA,
curve: t.Curve,
issuer: t.issuer,
}
return c.validate()
}
func (c *certificateV1) validate() error {
// Empty names are allowed
if len(c.details.publicKey) == 0 {
return ErrInvalidPublicKey
}
// Original v1 rules allowed multiple networks to be present but ignored all but the first one.
// Continue to allow this behavior
if !c.details.isCA && len(c.details.networks) == 0 {
return NewErrInvalidCertificateProperties("non-CA certificates must contain exactly one network")
}
for _, network := range c.details.networks {
if !network.IsValid() || !network.Addr().IsValid() {
return NewErrInvalidCertificateProperties("invalid network: %s", network)
}
if network.Addr().Is6() {
return NewErrInvalidCertificateProperties("certificate may not contain IPv6 networks: %v", network)
}
if network.Addr().IsUnspecified() {
return NewErrInvalidCertificateProperties("non-CA certificates must not use the zero address as a network: %s", network)
}
if network.Addr().Zone() != "" {
return NewErrInvalidCertificateProperties("networks may not contain zones: %s", network)
}
}
for _, network := range c.details.unsafeNetworks {
if !network.IsValid() || !network.Addr().IsValid() {
return NewErrInvalidCertificateProperties("invalid unsafe network: %s", network)
}
if network.Addr().Is6() {
return NewErrInvalidCertificateProperties("certificate may not contain IPv6 unsafe networks: %v", network)
}
if network.Addr().Zone() != "" {
return NewErrInvalidCertificateProperties("unsafe networks may not contain zones: %s", network)
}
}
// v1 doesn't bother with sort order or uniqueness of networks or unsafe networks.
// We can't modify the unmarshalled data because verification requires re-marshalling and a re-ordered
// unsafe networks would result in a different signature.
return nil
}
func (c *certificateV1) marshalForSigning() ([]byte, error) {
b, err := proto.Marshal(c.getRawDetails())
if err != nil {
return nil, err
}
return b, nil
}
func (c *certificateV1) setSignature(b []byte) error {
if len(b) == 0 {
return ErrEmptySignature
}
c.signature = b
return nil
}
// unmarshalCertificateV1 will unmarshal a protobuf byte representation of a nebula cert
// if the publicKey is provided here then it is not required to be present in `b`
func unmarshalCertificateV1(b []byte, publicKey []byte) (*certificateV1, error) {
if len(b) == 0 {
return nil, fmt.Errorf("nil byte array")
}
var rc RawNebulaCertificate
err := proto.Unmarshal(b, &rc)
if err != nil {
return nil, err
}
if rc.Details == nil {
return nil, fmt.Errorf("encoded Details was nil")
}
if len(rc.Details.Ips)%2 != 0 {
return nil, fmt.Errorf("encoded IPs should be in pairs, an odd number was found")
}
if len(rc.Details.Subnets)%2 != 0 {
return nil, fmt.Errorf("encoded Subnets should be in pairs, an odd number was found")
}
nc := certificateV1{
details: detailsV1{
name: rc.Details.Name,
groups: make([]string, len(rc.Details.Groups)),
networks: make([]netip.Prefix, len(rc.Details.Ips)/2),
unsafeNetworks: make([]netip.Prefix, len(rc.Details.Subnets)/2),
notBefore: time.Unix(rc.Details.NotBefore, 0),
notAfter: time.Unix(rc.Details.NotAfter, 0),
publicKey: make([]byte, len(rc.Details.PublicKey)),
isCA: rc.Details.IsCA,
curve: rc.Details.Curve,
},
signature: make([]byte, len(rc.Signature)),
}
copy(nc.signature, rc.Signature)
copy(nc.details.groups, rc.Details.Groups)
nc.details.issuer = hex.EncodeToString(rc.Details.Issuer)
if len(publicKey) > 0 {
nc.details.publicKey = publicKey
}
copy(nc.details.publicKey, rc.Details.PublicKey)
var ip netip.Addr
for i, rawIp := range rc.Details.Ips {
if i%2 == 0 {
ip = int2addr(rawIp)
} else {
ones, _ := net.IPMask(int2ip(rawIp)).Size()
nc.details.networks[i/2] = netip.PrefixFrom(ip, ones)
}
}
for i, rawIp := range rc.Details.Subnets {
if i%2 == 0 {
ip = int2addr(rawIp)
} else {
ones, _ := net.IPMask(int2ip(rawIp)).Size()
nc.details.unsafeNetworks[i/2] = netip.PrefixFrom(ip, ones)
}
}
err = nc.validate()
if err != nil {
return nil, err
}
return &nc, nil
}
func ip2int(ip []byte) uint32 {
if len(ip) == 16 {
return binary.BigEndian.Uint32(ip[12:16])
}
return binary.BigEndian.Uint32(ip)
}
func int2ip(nn uint32) net.IP {
ip := make(net.IP, net.IPv4len)
binary.BigEndian.PutUint32(ip, nn)
return ip
}
func addr2int(addr netip.Addr) uint32 {
b := addr.Unmap().As4()
return binary.BigEndian.Uint32(b[:])
}
func int2addr(nn uint32) netip.Addr {
ip := [4]byte{}
binary.BigEndian.PutUint32(ip[:], nn)
return netip.AddrFrom4(ip).Unmap()
}

View File

@ -1,8 +1,8 @@
// Code generated by protoc-gen-go. DO NOT EDIT.
// versions:
// protoc-gen-go v1.30.0
// protoc-gen-go v1.34.2
// protoc v3.21.5
// source: cert.proto
// source: cert_v1.proto
package cert
@ -50,11 +50,11 @@ func (x Curve) String() string {
}
func (Curve) Descriptor() protoreflect.EnumDescriptor {
return file_cert_proto_enumTypes[0].Descriptor()
return file_cert_v1_proto_enumTypes[0].Descriptor()
}
func (Curve) Type() protoreflect.EnumType {
return &file_cert_proto_enumTypes[0]
return &file_cert_v1_proto_enumTypes[0]
}
func (x Curve) Number() protoreflect.EnumNumber {
@ -63,7 +63,7 @@ func (x Curve) Number() protoreflect.EnumNumber {
// Deprecated: Use Curve.Descriptor instead.
func (Curve) EnumDescriptor() ([]byte, []int) {
return file_cert_proto_rawDescGZIP(), []int{0}
return file_cert_v1_proto_rawDescGZIP(), []int{0}
}
type RawNebulaCertificate struct {
@ -78,7 +78,7 @@ type RawNebulaCertificate struct {
func (x *RawNebulaCertificate) Reset() {
*x = RawNebulaCertificate{}
if protoimpl.UnsafeEnabled {
mi := &file_cert_proto_msgTypes[0]
mi := &file_cert_v1_proto_msgTypes[0]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
@ -91,7 +91,7 @@ func (x *RawNebulaCertificate) String() string {
func (*RawNebulaCertificate) ProtoMessage() {}
func (x *RawNebulaCertificate) ProtoReflect() protoreflect.Message {
mi := &file_cert_proto_msgTypes[0]
mi := &file_cert_v1_proto_msgTypes[0]
if protoimpl.UnsafeEnabled && x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
@ -104,7 +104,7 @@ func (x *RawNebulaCertificate) ProtoReflect() protoreflect.Message {
// Deprecated: Use RawNebulaCertificate.ProtoReflect.Descriptor instead.
func (*RawNebulaCertificate) Descriptor() ([]byte, []int) {
return file_cert_proto_rawDescGZIP(), []int{0}
return file_cert_v1_proto_rawDescGZIP(), []int{0}
}
func (x *RawNebulaCertificate) GetDetails() *RawNebulaCertificateDetails {
@ -143,7 +143,7 @@ type RawNebulaCertificateDetails struct {
func (x *RawNebulaCertificateDetails) Reset() {
*x = RawNebulaCertificateDetails{}
if protoimpl.UnsafeEnabled {
mi := &file_cert_proto_msgTypes[1]
mi := &file_cert_v1_proto_msgTypes[1]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
@ -156,7 +156,7 @@ func (x *RawNebulaCertificateDetails) String() string {
func (*RawNebulaCertificateDetails) ProtoMessage() {}
func (x *RawNebulaCertificateDetails) ProtoReflect() protoreflect.Message {
mi := &file_cert_proto_msgTypes[1]
mi := &file_cert_v1_proto_msgTypes[1]
if protoimpl.UnsafeEnabled && x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
@ -169,7 +169,7 @@ func (x *RawNebulaCertificateDetails) ProtoReflect() protoreflect.Message {
// Deprecated: Use RawNebulaCertificateDetails.ProtoReflect.Descriptor instead.
func (*RawNebulaCertificateDetails) Descriptor() ([]byte, []int) {
return file_cert_proto_rawDescGZIP(), []int{1}
return file_cert_v1_proto_rawDescGZIP(), []int{1}
}
func (x *RawNebulaCertificateDetails) GetName() string {
@ -254,7 +254,7 @@ type RawNebulaEncryptedData struct {
func (x *RawNebulaEncryptedData) Reset() {
*x = RawNebulaEncryptedData{}
if protoimpl.UnsafeEnabled {
mi := &file_cert_proto_msgTypes[2]
mi := &file_cert_v1_proto_msgTypes[2]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
@ -267,7 +267,7 @@ func (x *RawNebulaEncryptedData) String() string {
func (*RawNebulaEncryptedData) ProtoMessage() {}
func (x *RawNebulaEncryptedData) ProtoReflect() protoreflect.Message {
mi := &file_cert_proto_msgTypes[2]
mi := &file_cert_v1_proto_msgTypes[2]
if protoimpl.UnsafeEnabled && x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
@ -280,7 +280,7 @@ func (x *RawNebulaEncryptedData) ProtoReflect() protoreflect.Message {
// Deprecated: Use RawNebulaEncryptedData.ProtoReflect.Descriptor instead.
func (*RawNebulaEncryptedData) Descriptor() ([]byte, []int) {
return file_cert_proto_rawDescGZIP(), []int{2}
return file_cert_v1_proto_rawDescGZIP(), []int{2}
}
func (x *RawNebulaEncryptedData) GetEncryptionMetadata() *RawNebulaEncryptionMetadata {
@ -309,7 +309,7 @@ type RawNebulaEncryptionMetadata struct {
func (x *RawNebulaEncryptionMetadata) Reset() {
*x = RawNebulaEncryptionMetadata{}
if protoimpl.UnsafeEnabled {
mi := &file_cert_proto_msgTypes[3]
mi := &file_cert_v1_proto_msgTypes[3]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
@ -322,7 +322,7 @@ func (x *RawNebulaEncryptionMetadata) String() string {
func (*RawNebulaEncryptionMetadata) ProtoMessage() {}
func (x *RawNebulaEncryptionMetadata) ProtoReflect() protoreflect.Message {
mi := &file_cert_proto_msgTypes[3]
mi := &file_cert_v1_proto_msgTypes[3]
if protoimpl.UnsafeEnabled && x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
@ -335,7 +335,7 @@ func (x *RawNebulaEncryptionMetadata) ProtoReflect() protoreflect.Message {
// Deprecated: Use RawNebulaEncryptionMetadata.ProtoReflect.Descriptor instead.
func (*RawNebulaEncryptionMetadata) Descriptor() ([]byte, []int) {
return file_cert_proto_rawDescGZIP(), []int{3}
return file_cert_v1_proto_rawDescGZIP(), []int{3}
}
func (x *RawNebulaEncryptionMetadata) GetEncryptionAlgorithm() string {
@ -367,7 +367,7 @@ type RawNebulaArgon2Parameters struct {
func (x *RawNebulaArgon2Parameters) Reset() {
*x = RawNebulaArgon2Parameters{}
if protoimpl.UnsafeEnabled {
mi := &file_cert_proto_msgTypes[4]
mi := &file_cert_v1_proto_msgTypes[4]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
@ -380,7 +380,7 @@ func (x *RawNebulaArgon2Parameters) String() string {
func (*RawNebulaArgon2Parameters) ProtoMessage() {}
func (x *RawNebulaArgon2Parameters) ProtoReflect() protoreflect.Message {
mi := &file_cert_proto_msgTypes[4]
mi := &file_cert_v1_proto_msgTypes[4]
if protoimpl.UnsafeEnabled && x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
@ -393,7 +393,7 @@ func (x *RawNebulaArgon2Parameters) ProtoReflect() protoreflect.Message {
// Deprecated: Use RawNebulaArgon2Parameters.ProtoReflect.Descriptor instead.
func (*RawNebulaArgon2Parameters) Descriptor() ([]byte, []int) {
return file_cert_proto_rawDescGZIP(), []int{4}
return file_cert_v1_proto_rawDescGZIP(), []int{4}
}
func (x *RawNebulaArgon2Parameters) GetVersion() int32 {
@ -431,87 +431,87 @@ func (x *RawNebulaArgon2Parameters) GetSalt() []byte {
return nil
}
var File_cert_proto protoreflect.FileDescriptor
var File_cert_v1_proto protoreflect.FileDescriptor
var file_cert_proto_rawDesc = []byte{
0x0a, 0x0a, 0x63, 0x65, 0x72, 0x74, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x12, 0x04, 0x63, 0x65,
0x72, 0x74, 0x22, 0x71, 0x0a, 0x14, 0x52, 0x61, 0x77, 0x4e, 0x65, 0x62, 0x75, 0x6c, 0x61, 0x43,
0x65, 0x72, 0x74, 0x69, 0x66, 0x69, 0x63, 0x61, 0x74, 0x65, 0x12, 0x3b, 0x0a, 0x07, 0x44, 0x65,
0x74, 0x61, 0x69, 0x6c, 0x73, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x21, 0x2e, 0x63, 0x65,
0x72, 0x74, 0x2e, 0x52, 0x61, 0x77, 0x4e, 0x65, 0x62, 0x75, 0x6c, 0x61, 0x43, 0x65, 0x72, 0x74,
0x69, 0x66, 0x69, 0x63, 0x61, 0x74, 0x65, 0x44, 0x65, 0x74, 0x61, 0x69, 0x6c, 0x73, 0x52, 0x07,
0x44, 0x65, 0x74, 0x61, 0x69, 0x6c, 0x73, 0x12, 0x1c, 0x0a, 0x09, 0x53, 0x69, 0x67, 0x6e, 0x61,
0x74, 0x75, 0x72, 0x65, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0c, 0x52, 0x09, 0x53, 0x69, 0x67, 0x6e,
0x61, 0x74, 0x75, 0x72, 0x65, 0x22, 0x9c, 0x02, 0x0a, 0x1b, 0x52, 0x61, 0x77, 0x4e, 0x65, 0x62,
0x75, 0x6c, 0x61, 0x43, 0x65, 0x72, 0x74, 0x69, 0x66, 0x69, 0x63, 0x61, 0x74, 0x65, 0x44, 0x65,
0x74, 0x61, 0x69, 0x6c, 0x73, 0x12, 0x12, 0x0a, 0x04, 0x4e, 0x61, 0x6d, 0x65, 0x18, 0x01, 0x20,
0x01, 0x28, 0x09, 0x52, 0x04, 0x4e, 0x61, 0x6d, 0x65, 0x12, 0x10, 0x0a, 0x03, 0x49, 0x70, 0x73,
0x18, 0x02, 0x20, 0x03, 0x28, 0x0d, 0x52, 0x03, 0x49, 0x70, 0x73, 0x12, 0x18, 0x0a, 0x07, 0x53,
0x75, 0x62, 0x6e, 0x65, 0x74, 0x73, 0x18, 0x03, 0x20, 0x03, 0x28, 0x0d, 0x52, 0x07, 0x53, 0x75,
0x62, 0x6e, 0x65, 0x74, 0x73, 0x12, 0x16, 0x0a, 0x06, 0x47, 0x72, 0x6f, 0x75, 0x70, 0x73, 0x18,
0x04, 0x20, 0x03, 0x28, 0x09, 0x52, 0x06, 0x47, 0x72, 0x6f, 0x75, 0x70, 0x73, 0x12, 0x1c, 0x0a,
0x09, 0x4e, 0x6f, 0x74, 0x42, 0x65, 0x66, 0x6f, 0x72, 0x65, 0x18, 0x05, 0x20, 0x01, 0x28, 0x03,
0x52, 0x09, 0x4e, 0x6f, 0x74, 0x42, 0x65, 0x66, 0x6f, 0x72, 0x65, 0x12, 0x1a, 0x0a, 0x08, 0x4e,
0x6f, 0x74, 0x41, 0x66, 0x74, 0x65, 0x72, 0x18, 0x06, 0x20, 0x01, 0x28, 0x03, 0x52, 0x08, 0x4e,
0x6f, 0x74, 0x41, 0x66, 0x74, 0x65, 0x72, 0x12, 0x1c, 0x0a, 0x09, 0x50, 0x75, 0x62, 0x6c, 0x69,
0x63, 0x4b, 0x65, 0x79, 0x18, 0x07, 0x20, 0x01, 0x28, 0x0c, 0x52, 0x09, 0x50, 0x75, 0x62, 0x6c,
0x69, 0x63, 0x4b, 0x65, 0x79, 0x12, 0x12, 0x0a, 0x04, 0x49, 0x73, 0x43, 0x41, 0x18, 0x08, 0x20,
0x01, 0x28, 0x08, 0x52, 0x04, 0x49, 0x73, 0x43, 0x41, 0x12, 0x16, 0x0a, 0x06, 0x49, 0x73, 0x73,
0x75, 0x65, 0x72, 0x18, 0x09, 0x20, 0x01, 0x28, 0x0c, 0x52, 0x06, 0x49, 0x73, 0x73, 0x75, 0x65,
0x72, 0x12, 0x21, 0x0a, 0x05, 0x63, 0x75, 0x72, 0x76, 0x65, 0x18, 0x64, 0x20, 0x01, 0x28, 0x0e,
0x32, 0x0b, 0x2e, 0x63, 0x65, 0x72, 0x74, 0x2e, 0x43, 0x75, 0x72, 0x76, 0x65, 0x52, 0x05, 0x63,
0x75, 0x72, 0x76, 0x65, 0x22, 0x8b, 0x01, 0x0a, 0x16, 0x52, 0x61, 0x77, 0x4e, 0x65, 0x62, 0x75,
0x6c, 0x61, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x65, 0x64, 0x44, 0x61, 0x74, 0x61, 0x12,
0x51, 0x0a, 0x12, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x69, 0x6f, 0x6e, 0x4d, 0x65, 0x74,
0x61, 0x64, 0x61, 0x74, 0x61, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x21, 0x2e, 0x63, 0x65,
0x72, 0x74, 0x2e, 0x52, 0x61, 0x77, 0x4e, 0x65, 0x62, 0x75, 0x6c, 0x61, 0x45, 0x6e, 0x63, 0x72,
0x79, 0x70, 0x74, 0x69, 0x6f, 0x6e, 0x4d, 0x65, 0x74, 0x61, 0x64, 0x61, 0x74, 0x61, 0x52, 0x12,
0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x69, 0x6f, 0x6e, 0x4d, 0x65, 0x74, 0x61, 0x64, 0x61,
0x74, 0x61, 0x12, 0x1e, 0x0a, 0x0a, 0x43, 0x69, 0x70, 0x68, 0x65, 0x72, 0x74, 0x65, 0x78, 0x74,
0x18, 0x02, 0x20, 0x01, 0x28, 0x0c, 0x52, 0x0a, 0x43, 0x69, 0x70, 0x68, 0x65, 0x72, 0x74, 0x65,
0x78, 0x74, 0x22, 0x9c, 0x01, 0x0a, 0x1b, 0x52, 0x61, 0x77, 0x4e, 0x65, 0x62, 0x75, 0x6c, 0x61,
0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x69, 0x6f, 0x6e, 0x4d, 0x65, 0x74, 0x61, 0x64, 0x61,
0x74, 0x61, 0x12, 0x30, 0x0a, 0x13, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x69, 0x6f, 0x6e,
0x41, 0x6c, 0x67, 0x6f, 0x72, 0x69, 0x74, 0x68, 0x6d, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52,
0x13, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x69, 0x6f, 0x6e, 0x41, 0x6c, 0x67, 0x6f, 0x72,
0x69, 0x74, 0x68, 0x6d, 0x12, 0x4b, 0x0a, 0x10, 0x41, 0x72, 0x67, 0x6f, 0x6e, 0x32, 0x50, 0x61,
0x72, 0x61, 0x6d, 0x65, 0x74, 0x65, 0x72, 0x73, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x1f,
0x2e, 0x63, 0x65, 0x72, 0x74, 0x2e, 0x52, 0x61, 0x77, 0x4e, 0x65, 0x62, 0x75, 0x6c, 0x61, 0x41,
0x72, 0x67, 0x6f, 0x6e, 0x32, 0x50, 0x61, 0x72, 0x61, 0x6d, 0x65, 0x74, 0x65, 0x72, 0x73, 0x52,
0x10, 0x41, 0x72, 0x67, 0x6f, 0x6e, 0x32, 0x50, 0x61, 0x72, 0x61, 0x6d, 0x65, 0x74, 0x65, 0x72,
0x73, 0x22, 0xa3, 0x01, 0x0a, 0x19, 0x52, 0x61, 0x77, 0x4e, 0x65, 0x62, 0x75, 0x6c, 0x61, 0x41,
0x72, 0x67, 0x6f, 0x6e, 0x32, 0x50, 0x61, 0x72, 0x61, 0x6d, 0x65, 0x74, 0x65, 0x72, 0x73, 0x12,
0x18, 0x0a, 0x07, 0x76, 0x65, 0x72, 0x73, 0x69, 0x6f, 0x6e, 0x18, 0x01, 0x20, 0x01, 0x28, 0x05,
0x52, 0x07, 0x76, 0x65, 0x72, 0x73, 0x69, 0x6f, 0x6e, 0x12, 0x16, 0x0a, 0x06, 0x6d, 0x65, 0x6d,
0x6f, 0x72, 0x79, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0d, 0x52, 0x06, 0x6d, 0x65, 0x6d, 0x6f, 0x72,
0x79, 0x12, 0x20, 0x0a, 0x0b, 0x70, 0x61, 0x72, 0x61, 0x6c, 0x6c, 0x65, 0x6c, 0x69, 0x73, 0x6d,
0x18, 0x04, 0x20, 0x01, 0x28, 0x0d, 0x52, 0x0b, 0x70, 0x61, 0x72, 0x61, 0x6c, 0x6c, 0x65, 0x6c,
0x69, 0x73, 0x6d, 0x12, 0x1e, 0x0a, 0x0a, 0x69, 0x74, 0x65, 0x72, 0x61, 0x74, 0x69, 0x6f, 0x6e,
0x73, 0x18, 0x03, 0x20, 0x01, 0x28, 0x0d, 0x52, 0x0a, 0x69, 0x74, 0x65, 0x72, 0x61, 0x74, 0x69,
0x6f, 0x6e, 0x73, 0x12, 0x12, 0x0a, 0x04, 0x73, 0x61, 0x6c, 0x74, 0x18, 0x05, 0x20, 0x01, 0x28,
0x0c, 0x52, 0x04, 0x73, 0x61, 0x6c, 0x74, 0x2a, 0x21, 0x0a, 0x05, 0x43, 0x75, 0x72, 0x76, 0x65,
0x12, 0x0e, 0x0a, 0x0a, 0x43, 0x55, 0x52, 0x56, 0x45, 0x32, 0x35, 0x35, 0x31, 0x39, 0x10, 0x00,
0x12, 0x08, 0x0a, 0x04, 0x50, 0x32, 0x35, 0x36, 0x10, 0x01, 0x42, 0x20, 0x5a, 0x1e, 0x67, 0x69,
0x74, 0x68, 0x75, 0x62, 0x2e, 0x63, 0x6f, 0x6d, 0x2f, 0x73, 0x6c, 0x61, 0x63, 0x6b, 0x68, 0x71,
0x2f, 0x6e, 0x65, 0x62, 0x75, 0x6c, 0x61, 0x2f, 0x63, 0x65, 0x72, 0x74, 0x62, 0x06, 0x70, 0x72,
0x6f, 0x74, 0x6f, 0x33,
var file_cert_v1_proto_rawDesc = []byte{
0x0a, 0x0d, 0x63, 0x65, 0x72, 0x74, 0x5f, 0x76, 0x31, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x12,
0x04, 0x63, 0x65, 0x72, 0x74, 0x22, 0x71, 0x0a, 0x14, 0x52, 0x61, 0x77, 0x4e, 0x65, 0x62, 0x75,
0x6c, 0x61, 0x43, 0x65, 0x72, 0x74, 0x69, 0x66, 0x69, 0x63, 0x61, 0x74, 0x65, 0x12, 0x3b, 0x0a,
0x07, 0x44, 0x65, 0x74, 0x61, 0x69, 0x6c, 0x73, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x21,
0x2e, 0x63, 0x65, 0x72, 0x74, 0x2e, 0x52, 0x61, 0x77, 0x4e, 0x65, 0x62, 0x75, 0x6c, 0x61, 0x43,
0x65, 0x72, 0x74, 0x69, 0x66, 0x69, 0x63, 0x61, 0x74, 0x65, 0x44, 0x65, 0x74, 0x61, 0x69, 0x6c,
0x73, 0x52, 0x07, 0x44, 0x65, 0x74, 0x61, 0x69, 0x6c, 0x73, 0x12, 0x1c, 0x0a, 0x09, 0x53, 0x69,
0x67, 0x6e, 0x61, 0x74, 0x75, 0x72, 0x65, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0c, 0x52, 0x09, 0x53,
0x69, 0x67, 0x6e, 0x61, 0x74, 0x75, 0x72, 0x65, 0x22, 0x9c, 0x02, 0x0a, 0x1b, 0x52, 0x61, 0x77,
0x4e, 0x65, 0x62, 0x75, 0x6c, 0x61, 0x43, 0x65, 0x72, 0x74, 0x69, 0x66, 0x69, 0x63, 0x61, 0x74,
0x65, 0x44, 0x65, 0x74, 0x61, 0x69, 0x6c, 0x73, 0x12, 0x12, 0x0a, 0x04, 0x4e, 0x61, 0x6d, 0x65,
0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x04, 0x4e, 0x61, 0x6d, 0x65, 0x12, 0x10, 0x0a, 0x03,
0x49, 0x70, 0x73, 0x18, 0x02, 0x20, 0x03, 0x28, 0x0d, 0x52, 0x03, 0x49, 0x70, 0x73, 0x12, 0x18,
0x0a, 0x07, 0x53, 0x75, 0x62, 0x6e, 0x65, 0x74, 0x73, 0x18, 0x03, 0x20, 0x03, 0x28, 0x0d, 0x52,
0x07, 0x53, 0x75, 0x62, 0x6e, 0x65, 0x74, 0x73, 0x12, 0x16, 0x0a, 0x06, 0x47, 0x72, 0x6f, 0x75,
0x70, 0x73, 0x18, 0x04, 0x20, 0x03, 0x28, 0x09, 0x52, 0x06, 0x47, 0x72, 0x6f, 0x75, 0x70, 0x73,
0x12, 0x1c, 0x0a, 0x09, 0x4e, 0x6f, 0x74, 0x42, 0x65, 0x66, 0x6f, 0x72, 0x65, 0x18, 0x05, 0x20,
0x01, 0x28, 0x03, 0x52, 0x09, 0x4e, 0x6f, 0x74, 0x42, 0x65, 0x66, 0x6f, 0x72, 0x65, 0x12, 0x1a,
0x0a, 0x08, 0x4e, 0x6f, 0x74, 0x41, 0x66, 0x74, 0x65, 0x72, 0x18, 0x06, 0x20, 0x01, 0x28, 0x03,
0x52, 0x08, 0x4e, 0x6f, 0x74, 0x41, 0x66, 0x74, 0x65, 0x72, 0x12, 0x1c, 0x0a, 0x09, 0x50, 0x75,
0x62, 0x6c, 0x69, 0x63, 0x4b, 0x65, 0x79, 0x18, 0x07, 0x20, 0x01, 0x28, 0x0c, 0x52, 0x09, 0x50,
0x75, 0x62, 0x6c, 0x69, 0x63, 0x4b, 0x65, 0x79, 0x12, 0x12, 0x0a, 0x04, 0x49, 0x73, 0x43, 0x41,
0x18, 0x08, 0x20, 0x01, 0x28, 0x08, 0x52, 0x04, 0x49, 0x73, 0x43, 0x41, 0x12, 0x16, 0x0a, 0x06,
0x49, 0x73, 0x73, 0x75, 0x65, 0x72, 0x18, 0x09, 0x20, 0x01, 0x28, 0x0c, 0x52, 0x06, 0x49, 0x73,
0x73, 0x75, 0x65, 0x72, 0x12, 0x21, 0x0a, 0x05, 0x63, 0x75, 0x72, 0x76, 0x65, 0x18, 0x64, 0x20,
0x01, 0x28, 0x0e, 0x32, 0x0b, 0x2e, 0x63, 0x65, 0x72, 0x74, 0x2e, 0x43, 0x75, 0x72, 0x76, 0x65,
0x52, 0x05, 0x63, 0x75, 0x72, 0x76, 0x65, 0x22, 0x8b, 0x01, 0x0a, 0x16, 0x52, 0x61, 0x77, 0x4e,
0x65, 0x62, 0x75, 0x6c, 0x61, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x65, 0x64, 0x44, 0x61,
0x74, 0x61, 0x12, 0x51, 0x0a, 0x12, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x69, 0x6f, 0x6e,
0x4d, 0x65, 0x74, 0x61, 0x64, 0x61, 0x74, 0x61, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x21,
0x2e, 0x63, 0x65, 0x72, 0x74, 0x2e, 0x52, 0x61, 0x77, 0x4e, 0x65, 0x62, 0x75, 0x6c, 0x61, 0x45,
0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x69, 0x6f, 0x6e, 0x4d, 0x65, 0x74, 0x61, 0x64, 0x61, 0x74,
0x61, 0x52, 0x12, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x69, 0x6f, 0x6e, 0x4d, 0x65, 0x74,
0x61, 0x64, 0x61, 0x74, 0x61, 0x12, 0x1e, 0x0a, 0x0a, 0x43, 0x69, 0x70, 0x68, 0x65, 0x72, 0x74,
0x65, 0x78, 0x74, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0c, 0x52, 0x0a, 0x43, 0x69, 0x70, 0x68, 0x65,
0x72, 0x74, 0x65, 0x78, 0x74, 0x22, 0x9c, 0x01, 0x0a, 0x1b, 0x52, 0x61, 0x77, 0x4e, 0x65, 0x62,
0x75, 0x6c, 0x61, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x69, 0x6f, 0x6e, 0x4d, 0x65, 0x74,
0x61, 0x64, 0x61, 0x74, 0x61, 0x12, 0x30, 0x0a, 0x13, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74,
0x69, 0x6f, 0x6e, 0x41, 0x6c, 0x67, 0x6f, 0x72, 0x69, 0x74, 0x68, 0x6d, 0x18, 0x01, 0x20, 0x01,
0x28, 0x09, 0x52, 0x13, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x69, 0x6f, 0x6e, 0x41, 0x6c,
0x67, 0x6f, 0x72, 0x69, 0x74, 0x68, 0x6d, 0x12, 0x4b, 0x0a, 0x10, 0x41, 0x72, 0x67, 0x6f, 0x6e,
0x32, 0x50, 0x61, 0x72, 0x61, 0x6d, 0x65, 0x74, 0x65, 0x72, 0x73, 0x18, 0x02, 0x20, 0x01, 0x28,
0x0b, 0x32, 0x1f, 0x2e, 0x63, 0x65, 0x72, 0x74, 0x2e, 0x52, 0x61, 0x77, 0x4e, 0x65, 0x62, 0x75,
0x6c, 0x61, 0x41, 0x72, 0x67, 0x6f, 0x6e, 0x32, 0x50, 0x61, 0x72, 0x61, 0x6d, 0x65, 0x74, 0x65,
0x72, 0x73, 0x52, 0x10, 0x41, 0x72, 0x67, 0x6f, 0x6e, 0x32, 0x50, 0x61, 0x72, 0x61, 0x6d, 0x65,
0x74, 0x65, 0x72, 0x73, 0x22, 0xa3, 0x01, 0x0a, 0x19, 0x52, 0x61, 0x77, 0x4e, 0x65, 0x62, 0x75,
0x6c, 0x61, 0x41, 0x72, 0x67, 0x6f, 0x6e, 0x32, 0x50, 0x61, 0x72, 0x61, 0x6d, 0x65, 0x74, 0x65,
0x72, 0x73, 0x12, 0x18, 0x0a, 0x07, 0x76, 0x65, 0x72, 0x73, 0x69, 0x6f, 0x6e, 0x18, 0x01, 0x20,
0x01, 0x28, 0x05, 0x52, 0x07, 0x76, 0x65, 0x72, 0x73, 0x69, 0x6f, 0x6e, 0x12, 0x16, 0x0a, 0x06,
0x6d, 0x65, 0x6d, 0x6f, 0x72, 0x79, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0d, 0x52, 0x06, 0x6d, 0x65,
0x6d, 0x6f, 0x72, 0x79, 0x12, 0x20, 0x0a, 0x0b, 0x70, 0x61, 0x72, 0x61, 0x6c, 0x6c, 0x65, 0x6c,
0x69, 0x73, 0x6d, 0x18, 0x04, 0x20, 0x01, 0x28, 0x0d, 0x52, 0x0b, 0x70, 0x61, 0x72, 0x61, 0x6c,
0x6c, 0x65, 0x6c, 0x69, 0x73, 0x6d, 0x12, 0x1e, 0x0a, 0x0a, 0x69, 0x74, 0x65, 0x72, 0x61, 0x74,
0x69, 0x6f, 0x6e, 0x73, 0x18, 0x03, 0x20, 0x01, 0x28, 0x0d, 0x52, 0x0a, 0x69, 0x74, 0x65, 0x72,
0x61, 0x74, 0x69, 0x6f, 0x6e, 0x73, 0x12, 0x12, 0x0a, 0x04, 0x73, 0x61, 0x6c, 0x74, 0x18, 0x05,
0x20, 0x01, 0x28, 0x0c, 0x52, 0x04, 0x73, 0x61, 0x6c, 0x74, 0x2a, 0x21, 0x0a, 0x05, 0x43, 0x75,
0x72, 0x76, 0x65, 0x12, 0x0e, 0x0a, 0x0a, 0x43, 0x55, 0x52, 0x56, 0x45, 0x32, 0x35, 0x35, 0x31,
0x39, 0x10, 0x00, 0x12, 0x08, 0x0a, 0x04, 0x50, 0x32, 0x35, 0x36, 0x10, 0x01, 0x42, 0x20, 0x5a,
0x1e, 0x67, 0x69, 0x74, 0x68, 0x75, 0x62, 0x2e, 0x63, 0x6f, 0x6d, 0x2f, 0x73, 0x6c, 0x61, 0x63,
0x6b, 0x68, 0x71, 0x2f, 0x6e, 0x65, 0x62, 0x75, 0x6c, 0x61, 0x2f, 0x63, 0x65, 0x72, 0x74, 0x62,
0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33,
}
var (
file_cert_proto_rawDescOnce sync.Once
file_cert_proto_rawDescData = file_cert_proto_rawDesc
file_cert_v1_proto_rawDescOnce sync.Once
file_cert_v1_proto_rawDescData = file_cert_v1_proto_rawDesc
)
func file_cert_proto_rawDescGZIP() []byte {
file_cert_proto_rawDescOnce.Do(func() {
file_cert_proto_rawDescData = protoimpl.X.CompressGZIP(file_cert_proto_rawDescData)
func file_cert_v1_proto_rawDescGZIP() []byte {
file_cert_v1_proto_rawDescOnce.Do(func() {
file_cert_v1_proto_rawDescData = protoimpl.X.CompressGZIP(file_cert_v1_proto_rawDescData)
})
return file_cert_proto_rawDescData
return file_cert_v1_proto_rawDescData
}
var file_cert_proto_enumTypes = make([]protoimpl.EnumInfo, 1)
var file_cert_proto_msgTypes = make([]protoimpl.MessageInfo, 5)
var file_cert_proto_goTypes = []interface{}{
var file_cert_v1_proto_enumTypes = make([]protoimpl.EnumInfo, 1)
var file_cert_v1_proto_msgTypes = make([]protoimpl.MessageInfo, 5)
var file_cert_v1_proto_goTypes = []any{
(Curve)(0), // 0: cert.Curve
(*RawNebulaCertificate)(nil), // 1: cert.RawNebulaCertificate
(*RawNebulaCertificateDetails)(nil), // 2: cert.RawNebulaCertificateDetails
@ -519,7 +519,7 @@ var file_cert_proto_goTypes = []interface{}{
(*RawNebulaEncryptionMetadata)(nil), // 4: cert.RawNebulaEncryptionMetadata
(*RawNebulaArgon2Parameters)(nil), // 5: cert.RawNebulaArgon2Parameters
}
var file_cert_proto_depIdxs = []int32{
var file_cert_v1_proto_depIdxs = []int32{
2, // 0: cert.RawNebulaCertificate.Details:type_name -> cert.RawNebulaCertificateDetails
0, // 1: cert.RawNebulaCertificateDetails.curve:type_name -> cert.Curve
4, // 2: cert.RawNebulaEncryptedData.EncryptionMetadata:type_name -> cert.RawNebulaEncryptionMetadata
@ -531,13 +531,13 @@ var file_cert_proto_depIdxs = []int32{
0, // [0:4] is the sub-list for field type_name
}
func init() { file_cert_proto_init() }
func file_cert_proto_init() {
if File_cert_proto != nil {
func init() { file_cert_v1_proto_init() }
func file_cert_v1_proto_init() {
if File_cert_v1_proto != nil {
return
}
if !protoimpl.UnsafeEnabled {
file_cert_proto_msgTypes[0].Exporter = func(v interface{}, i int) interface{} {
file_cert_v1_proto_msgTypes[0].Exporter = func(v any, i int) any {
switch v := v.(*RawNebulaCertificate); i {
case 0:
return &v.state
@ -549,7 +549,7 @@ func file_cert_proto_init() {
return nil
}
}
file_cert_proto_msgTypes[1].Exporter = func(v interface{}, i int) interface{} {
file_cert_v1_proto_msgTypes[1].Exporter = func(v any, i int) any {
switch v := v.(*RawNebulaCertificateDetails); i {
case 0:
return &v.state
@ -561,7 +561,7 @@ func file_cert_proto_init() {
return nil
}
}
file_cert_proto_msgTypes[2].Exporter = func(v interface{}, i int) interface{} {
file_cert_v1_proto_msgTypes[2].Exporter = func(v any, i int) any {
switch v := v.(*RawNebulaEncryptedData); i {
case 0:
return &v.state
@ -573,7 +573,7 @@ func file_cert_proto_init() {
return nil
}
}
file_cert_proto_msgTypes[3].Exporter = func(v interface{}, i int) interface{} {
file_cert_v1_proto_msgTypes[3].Exporter = func(v any, i int) any {
switch v := v.(*RawNebulaEncryptionMetadata); i {
case 0:
return &v.state
@ -585,7 +585,7 @@ func file_cert_proto_init() {
return nil
}
}
file_cert_proto_msgTypes[4].Exporter = func(v interface{}, i int) interface{} {
file_cert_v1_proto_msgTypes[4].Exporter = func(v any, i int) any {
switch v := v.(*RawNebulaArgon2Parameters); i {
case 0:
return &v.state
@ -602,19 +602,19 @@ func file_cert_proto_init() {
out := protoimpl.TypeBuilder{
File: protoimpl.DescBuilder{
GoPackagePath: reflect.TypeOf(x{}).PkgPath(),
RawDescriptor: file_cert_proto_rawDesc,
RawDescriptor: file_cert_v1_proto_rawDesc,
NumEnums: 1,
NumMessages: 5,
NumExtensions: 0,
NumServices: 0,
},
GoTypes: file_cert_proto_goTypes,
DependencyIndexes: file_cert_proto_depIdxs,
EnumInfos: file_cert_proto_enumTypes,
MessageInfos: file_cert_proto_msgTypes,
GoTypes: file_cert_v1_proto_goTypes,
DependencyIndexes: file_cert_v1_proto_depIdxs,
EnumInfos: file_cert_v1_proto_enumTypes,
MessageInfos: file_cert_v1_proto_msgTypes,
}.Build()
File_cert_proto = out.File
file_cert_proto_rawDesc = nil
file_cert_proto_goTypes = nil
file_cert_proto_depIdxs = nil
File_cert_v1_proto = out.File
file_cert_v1_proto_rawDesc = nil
file_cert_v1_proto_goTypes = nil
file_cert_v1_proto_depIdxs = nil
}

218
cert/cert_v1_test.go Normal file
View File

@ -0,0 +1,218 @@
package cert
import (
"fmt"
"net/netip"
"testing"
"time"
"github.com/slackhq/nebula/test"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"google.golang.org/protobuf/proto"
)
func TestCertificateV1_Marshal(t *testing.T) {
before := time.Now().Add(time.Second * -60).Round(time.Second)
after := time.Now().Add(time.Second * 60).Round(time.Second)
pubKey := []byte("1234567890abcedfghij1234567890ab")
nc := certificateV1{
details: detailsV1{
name: "testing",
networks: []netip.Prefix{
mustParsePrefixUnmapped("10.1.1.1/24"),
mustParsePrefixUnmapped("10.1.1.2/16"),
},
unsafeNetworks: []netip.Prefix{
mustParsePrefixUnmapped("9.1.1.2/24"),
mustParsePrefixUnmapped("9.1.1.3/16"),
},
groups: []string{"test-group1", "test-group2", "test-group3"},
notBefore: before,
notAfter: after,
publicKey: pubKey,
isCA: false,
issuer: "1234567890abcedfghij1234567890ab",
},
signature: []byte("1234567890abcedfghij1234567890ab"),
}
b, err := nc.Marshal()
require.NoError(t, err)
//t.Log("Cert size:", len(b))
nc2, err := unmarshalCertificateV1(b, nil)
require.NoError(t, err)
assert.Equal(t, Version1, nc.Version())
assert.Equal(t, Curve_CURVE25519, nc.Curve())
assert.Equal(t, nc.Signature(), nc2.Signature())
assert.Equal(t, nc.Name(), nc2.Name())
assert.Equal(t, nc.NotBefore(), nc2.NotBefore())
assert.Equal(t, nc.NotAfter(), nc2.NotAfter())
assert.Equal(t, nc.PublicKey(), nc2.PublicKey())
assert.Equal(t, nc.IsCA(), nc2.IsCA())
assert.Equal(t, nc.Networks(), nc2.Networks())
assert.Equal(t, nc.UnsafeNetworks(), nc2.UnsafeNetworks())
assert.Equal(t, nc.Groups(), nc2.Groups())
}
func TestCertificateV1_Expired(t *testing.T) {
nc := certificateV1{
details: detailsV1{
notBefore: time.Now().Add(time.Second * -60).Round(time.Second),
notAfter: time.Now().Add(time.Second * 60).Round(time.Second),
},
}
assert.True(t, nc.Expired(time.Now().Add(time.Hour)))
assert.True(t, nc.Expired(time.Now().Add(-time.Hour)))
assert.False(t, nc.Expired(time.Now()))
}
func TestCertificateV1_MarshalJSON(t *testing.T) {
time.Local = time.UTC
pubKey := []byte("1234567890abcedfghij1234567890ab")
nc := certificateV1{
details: detailsV1{
name: "testing",
networks: []netip.Prefix{
mustParsePrefixUnmapped("10.1.1.1/24"),
mustParsePrefixUnmapped("10.1.1.2/16"),
},
unsafeNetworks: []netip.Prefix{
mustParsePrefixUnmapped("9.1.1.2/24"),
mustParsePrefixUnmapped("9.1.1.3/16"),
},
groups: []string{"test-group1", "test-group2", "test-group3"},
notBefore: time.Date(1, 0, 0, 1, 0, 0, 0, time.UTC),
notAfter: time.Date(1, 0, 0, 2, 0, 0, 0, time.UTC),
publicKey: pubKey,
isCA: false,
issuer: "1234567890abcedfghij1234567890ab",
},
signature: []byte("1234567890abcedfghij1234567890ab"),
}
b, err := nc.MarshalJSON()
require.NoError(t, err)
assert.JSONEq(
t,
"{\"details\":{\"curve\":\"CURVE25519\",\"groups\":[\"test-group1\",\"test-group2\",\"test-group3\"],\"isCa\":false,\"issuer\":\"1234567890abcedfghij1234567890ab\",\"name\":\"testing\",\"networks\":[\"10.1.1.1/24\",\"10.1.1.2/16\"],\"notAfter\":\"0000-11-30T02:00:00Z\",\"notBefore\":\"0000-11-30T01:00:00Z\",\"publicKey\":\"313233343536373839306162636564666768696a313233343536373839306162\",\"unsafeNetworks\":[\"9.1.1.2/24\",\"9.1.1.3/16\"]},\"fingerprint\":\"3944c53d4267a229295b56cb2d27d459164c010ac97d655063ba421e0670f4ba\",\"signature\":\"313233343536373839306162636564666768696a313233343536373839306162\",\"version\":1}",
string(b),
)
}
func TestCertificateV1_VerifyPrivateKey(t *testing.T) {
ca, _, caKey, _ := NewTestCaCert(Version1, Curve_CURVE25519, time.Time{}, time.Time{}, nil, nil, nil)
err := ca.VerifyPrivateKey(Curve_CURVE25519, caKey)
require.NoError(t, err)
_, _, caKey2, _ := NewTestCaCert(Version1, Curve_CURVE25519, time.Time{}, time.Time{}, nil, nil, nil)
require.NoError(t, err)
err = ca.VerifyPrivateKey(Curve_CURVE25519, caKey2)
require.Error(t, err)
c, _, priv, _ := NewTestCert(Version1, Curve_CURVE25519, ca, caKey, "test", time.Time{}, time.Time{}, nil, nil, nil)
rawPriv, b, curve, err := UnmarshalPrivateKeyFromPEM(priv)
require.NoError(t, err)
assert.Empty(t, b)
assert.Equal(t, Curve_CURVE25519, curve)
err = c.VerifyPrivateKey(Curve_CURVE25519, rawPriv)
require.NoError(t, err)
_, priv2 := X25519Keypair()
err = c.VerifyPrivateKey(Curve_CURVE25519, priv2)
require.Error(t, err)
}
func TestCertificateV1_VerifyPrivateKeyP256(t *testing.T) {
ca, _, caKey, _ := NewTestCaCert(Version1, Curve_P256, time.Time{}, time.Time{}, nil, nil, nil)
err := ca.VerifyPrivateKey(Curve_P256, caKey)
require.NoError(t, err)
_, _, caKey2, _ := NewTestCaCert(Version1, Curve_P256, time.Time{}, time.Time{}, nil, nil, nil)
require.NoError(t, err)
err = ca.VerifyPrivateKey(Curve_P256, caKey2)
require.Error(t, err)
c, _, priv, _ := NewTestCert(Version1, Curve_P256, ca, caKey, "test", time.Time{}, time.Time{}, nil, nil, nil)
rawPriv, b, curve, err := UnmarshalPrivateKeyFromPEM(priv)
require.NoError(t, err)
assert.Empty(t, b)
assert.Equal(t, Curve_P256, curve)
err = c.VerifyPrivateKey(Curve_P256, rawPriv)
require.NoError(t, err)
_, priv2 := P256Keypair()
err = c.VerifyPrivateKey(Curve_P256, priv2)
require.Error(t, err)
}
// Ensure that upgrading the protobuf library does not change how certificates
// are marshalled, since this would break signature verification
func TestMarshalingCertificateV1Consistency(t *testing.T) {
before := time.Date(1970, time.January, 1, 1, 1, 1, 1, time.UTC)
after := time.Date(9999, time.January, 1, 1, 1, 1, 1, time.UTC)
pubKey := []byte("1234567890abcedfghij1234567890ab")
nc := certificateV1{
details: detailsV1{
name: "testing",
networks: []netip.Prefix{
mustParsePrefixUnmapped("10.1.1.2/16"),
mustParsePrefixUnmapped("10.1.1.1/24"),
},
unsafeNetworks: []netip.Prefix{
mustParsePrefixUnmapped("9.1.1.3/16"),
mustParsePrefixUnmapped("9.1.1.2/24"),
},
groups: []string{"test-group1", "test-group2", "test-group3"},
notBefore: before,
notAfter: after,
publicKey: pubKey,
isCA: false,
issuer: "1234567890abcedfghij1234567890ab",
},
signature: []byte("1234567890abcedfghij1234567890ab"),
}
b, err := nc.Marshal()
require.NoError(t, err)
assert.Equal(t, "0a8e010a0774657374696e671212828284508080fcff0f8182845080feffff0f1a12838284488080fcff0f8282844880feffff0f220b746573742d67726f757031220b746573742d67726f757032220b746573742d67726f75703328cd1c30cdb8ccf0af073a20313233343536373839306162636564666768696a3132333435363738393061624a081234567890abcedf1220313233343536373839306162636564666768696a313233343536373839306162", fmt.Sprintf("%x", b))
b, err = proto.Marshal(nc.getRawDetails())
require.NoError(t, err)
assert.Equal(t, "0a0774657374696e671212828284508080fcff0f8182845080feffff0f1a12838284488080fcff0f8282844880feffff0f220b746573742d67726f757031220b746573742d67726f757032220b746573742d67726f75703328cd1c30cdb8ccf0af073a20313233343536373839306162636564666768696a3132333435363738393061624a081234567890abcedf", fmt.Sprintf("%x", b))
}
func TestCertificateV1_Copy(t *testing.T) {
ca, _, caKey, _ := NewTestCaCert(Version1, Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, nil)
c, _, _, _ := NewTestCert(Version1, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, nil, nil)
cc := c.Copy()
test.AssertDeepCopyEqual(t, c, cc)
}
func TestUnmarshalCertificateV1(t *testing.T) {
// Test that we don't panic with an invalid certificate (#332)
data := []byte("\x98\x00\x00")
_, err := unmarshalCertificateV1(data, nil)
require.EqualError(t, err, "encoded Details was nil")
}
func appendByteSlices(b ...[]byte) []byte {
retSlice := []byte{}
for _, v := range b {
retSlice = append(retSlice, v...)
}
return retSlice
}
func mustParsePrefixUnmapped(s string) netip.Prefix {
prefix := netip.MustParsePrefix(s)
return netip.PrefixFrom(prefix.Addr().Unmap(), prefix.Bits())
}

37
cert/cert_v2.asn1 Normal file
View File

@ -0,0 +1,37 @@
Nebula DEFINITIONS AUTOMATIC TAGS ::= BEGIN
Name ::= UTF8String (SIZE (1..253))
Time ::= INTEGER (0..18446744073709551615) -- Seconds since unix epoch, uint64 maximum
Network ::= OCTET STRING (SIZE (5,17)) -- IP addresses are 4 or 16 bytes + 1 byte for the prefix length
Curve ::= ENUMERATED {
curve25519 (0),
p256 (1)
}
-- The maximum size of a certificate must not exceed 65536 bytes
Certificate ::= SEQUENCE {
details OCTET STRING,
curve Curve DEFAULT curve25519,
publicKey OCTET STRING,
-- signature(details + curve + publicKey) using the appropriate method for curve
signature OCTET STRING
}
Details ::= SEQUENCE {
name Name,
-- At least 1 ipv4 or ipv6 address must be present if isCA is false
networks SEQUENCE OF Network OPTIONAL,
unsafeNetworks SEQUENCE OF Network OPTIONAL,
groups SEQUENCE OF Name OPTIONAL,
isCA BOOLEAN DEFAULT false,
notBefore Time,
notAfter Time,
-- issuer is only required if isCA is false, if isCA is true then it must not be present
issuer OCTET STRING OPTIONAL,
...
-- New fields can be added below here
}
END

730
cert/cert_v2.go Normal file
View File

@ -0,0 +1,730 @@
package cert
import (
"bytes"
"crypto/ecdh"
"crypto/ecdsa"
"crypto/ed25519"
"crypto/elliptic"
"crypto/sha256"
"encoding/hex"
"encoding/json"
"encoding/pem"
"fmt"
"net/netip"
"slices"
"time"
"golang.org/x/crypto/cryptobyte"
"golang.org/x/crypto/cryptobyte/asn1"
"golang.org/x/crypto/curve25519"
)
const (
classConstructed = 0x20
classContextSpecific = 0x80
TagCertDetails = 0 | classConstructed | classContextSpecific
TagCertCurve = 1 | classContextSpecific
TagCertPublicKey = 2 | classContextSpecific
TagCertSignature = 3 | classContextSpecific
TagDetailsName = 0 | classContextSpecific
TagDetailsNetworks = 1 | classConstructed | classContextSpecific
TagDetailsUnsafeNetworks = 2 | classConstructed | classContextSpecific
TagDetailsGroups = 3 | classConstructed | classContextSpecific
TagDetailsIsCA = 4 | classContextSpecific
TagDetailsNotBefore = 5 | classContextSpecific
TagDetailsNotAfter = 6 | classContextSpecific
TagDetailsIssuer = 7 | classContextSpecific
)
const (
// MaxCertificateSize is the maximum length a valid certificate can be
MaxCertificateSize = 65536
// MaxNameLength is limited to a maximum realistic DNS domain name to help facilitate DNS systems
MaxNameLength = 253
// MaxNetworkLength is the maximum length a network value can be.
// 16 bytes for an ipv6 address + 1 byte for the prefix length
MaxNetworkLength = 17
)
type certificateV2 struct {
details detailsV2
// RawDetails contains the entire asn.1 DER encoded Details struct
// This is to benefit forwards compatibility in signature checking.
// signature(RawDetails + Curve + PublicKey) == Signature
rawDetails []byte
curve Curve
publicKey []byte
signature []byte
}
type detailsV2 struct {
name string
networks []netip.Prefix // MUST BE SORTED
unsafeNetworks []netip.Prefix // MUST BE SORTED
groups []string
isCA bool
notBefore time.Time
notAfter time.Time
issuer string
}
func (c *certificateV2) Version() Version {
return Version2
}
func (c *certificateV2) Curve() Curve {
return c.curve
}
func (c *certificateV2) Groups() []string {
return c.details.groups
}
func (c *certificateV2) IsCA() bool {
return c.details.isCA
}
func (c *certificateV2) Issuer() string {
return c.details.issuer
}
func (c *certificateV2) Name() string {
return c.details.name
}
func (c *certificateV2) Networks() []netip.Prefix {
return c.details.networks
}
func (c *certificateV2) NotAfter() time.Time {
return c.details.notAfter
}
func (c *certificateV2) NotBefore() time.Time {
return c.details.notBefore
}
func (c *certificateV2) PublicKey() []byte {
return c.publicKey
}
func (c *certificateV2) Signature() []byte {
return c.signature
}
func (c *certificateV2) UnsafeNetworks() []netip.Prefix {
return c.details.unsafeNetworks
}
func (c *certificateV2) Fingerprint() (string, error) {
if len(c.rawDetails) == 0 {
return "", ErrMissingDetails
}
b := make([]byte, len(c.rawDetails)+1+len(c.publicKey)+len(c.signature))
copy(b, c.rawDetails)
b[len(c.rawDetails)] = byte(c.curve)
copy(b[len(c.rawDetails)+1:], c.publicKey)
copy(b[len(c.rawDetails)+1+len(c.publicKey):], c.signature)
sum := sha256.Sum256(b)
return hex.EncodeToString(sum[:]), nil
}
func (c *certificateV2) CheckSignature(key []byte) bool {
if len(c.rawDetails) == 0 {
return false
}
b := make([]byte, len(c.rawDetails)+1+len(c.publicKey))
copy(b, c.rawDetails)
b[len(c.rawDetails)] = byte(c.curve)
copy(b[len(c.rawDetails)+1:], c.publicKey)
switch c.curve {
case Curve_CURVE25519:
return ed25519.Verify(key, b, c.signature)
case Curve_P256:
x, y := elliptic.Unmarshal(elliptic.P256(), key)
pubKey := &ecdsa.PublicKey{Curve: elliptic.P256(), X: x, Y: y}
hashed := sha256.Sum256(b)
return ecdsa.VerifyASN1(pubKey, hashed[:], c.signature)
default:
return false
}
}
func (c *certificateV2) Expired(t time.Time) bool {
return c.details.notBefore.After(t) || c.details.notAfter.Before(t)
}
func (c *certificateV2) VerifyPrivateKey(curve Curve, key []byte) error {
if curve != c.curve {
return ErrPublicPrivateCurveMismatch
}
if c.details.isCA {
switch curve {
case Curve_CURVE25519:
// the call to PublicKey below will panic slice bounds out of range otherwise
if len(key) != ed25519.PrivateKeySize {
return ErrInvalidPrivateKey
}
if !ed25519.PublicKey(c.publicKey).Equal(ed25519.PrivateKey(key).Public()) {
return ErrPublicPrivateKeyMismatch
}
case Curve_P256:
privkey, err := ecdh.P256().NewPrivateKey(key)
if err != nil {
return ErrInvalidPrivateKey
}
pub := privkey.PublicKey().Bytes()
if !bytes.Equal(pub, c.publicKey) {
return ErrPublicPrivateKeyMismatch
}
default:
return fmt.Errorf("invalid curve: %s", curve)
}
return nil
}
var pub []byte
switch curve {
case Curve_CURVE25519:
var err error
pub, err = curve25519.X25519(key, curve25519.Basepoint)
if err != nil {
return ErrInvalidPrivateKey
}
case Curve_P256:
privkey, err := ecdh.P256().NewPrivateKey(key)
if err != nil {
return ErrInvalidPrivateKey
}
pub = privkey.PublicKey().Bytes()
default:
return fmt.Errorf("invalid curve: %s", curve)
}
if !bytes.Equal(pub, c.publicKey) {
return ErrPublicPrivateKeyMismatch
}
return nil
}
func (c *certificateV2) String() string {
mb, err := c.marshalJSON()
if err != nil {
return fmt.Sprintf("<error marshalling certificate: %v>", err)
}
b, err := json.MarshalIndent(mb, "", "\t")
if err != nil {
return fmt.Sprintf("<error marshalling certificate: %v>", err)
}
return string(b)
}
func (c *certificateV2) MarshalForHandshakes() ([]byte, error) {
if c.rawDetails == nil {
return nil, ErrEmptyRawDetails
}
var b cryptobyte.Builder
// Outermost certificate
b.AddASN1(asn1.SEQUENCE, func(b *cryptobyte.Builder) {
// Add the cert details which is already marshalled
b.AddBytes(c.rawDetails)
// Skipping the curve and public key since those come across in a different part of the handshake
// Add the signature
b.AddASN1(TagCertSignature, func(b *cryptobyte.Builder) {
b.AddBytes(c.signature)
})
})
return b.Bytes()
}
func (c *certificateV2) Marshal() ([]byte, error) {
if c.rawDetails == nil {
return nil, ErrEmptyRawDetails
}
var b cryptobyte.Builder
// Outermost certificate
b.AddASN1(asn1.SEQUENCE, func(b *cryptobyte.Builder) {
// Add the cert details which is already marshalled
b.AddBytes(c.rawDetails)
// Add the curve only if its not the default value
if c.curve != Curve_CURVE25519 {
b.AddASN1(TagCertCurve, func(b *cryptobyte.Builder) {
b.AddBytes([]byte{byte(c.curve)})
})
}
// Add the public key if it is not empty
if c.publicKey != nil {
b.AddASN1(TagCertPublicKey, func(b *cryptobyte.Builder) {
b.AddBytes(c.publicKey)
})
}
// Add the signature
b.AddASN1(TagCertSignature, func(b *cryptobyte.Builder) {
b.AddBytes(c.signature)
})
})
return b.Bytes()
}
func (c *certificateV2) MarshalPEM() ([]byte, error) {
b, err := c.Marshal()
if err != nil {
return nil, err
}
return pem.EncodeToMemory(&pem.Block{Type: CertificateV2Banner, Bytes: b}), nil
}
func (c *certificateV2) MarshalJSON() ([]byte, error) {
b, err := c.marshalJSON()
if err != nil {
return nil, err
}
return json.Marshal(b)
}
func (c *certificateV2) marshalJSON() (m, error) {
fp, err := c.Fingerprint()
if err != nil {
return nil, err
}
return m{
"details": m{
"name": c.details.name,
"networks": c.details.networks,
"unsafeNetworks": c.details.unsafeNetworks,
"groups": c.details.groups,
"notBefore": c.details.notBefore,
"notAfter": c.details.notAfter,
"isCa": c.details.isCA,
"issuer": c.details.issuer,
},
"version": Version2,
"publicKey": fmt.Sprintf("%x", c.publicKey),
"curve": c.curve.String(),
"fingerprint": fp,
"signature": fmt.Sprintf("%x", c.Signature()),
}, nil
}
func (c *certificateV2) Copy() Certificate {
nc := &certificateV2{
details: detailsV2{
name: c.details.name,
notBefore: c.details.notBefore,
notAfter: c.details.notAfter,
isCA: c.details.isCA,
issuer: c.details.issuer,
},
curve: c.curve,
publicKey: make([]byte, len(c.publicKey)),
signature: make([]byte, len(c.signature)),
rawDetails: make([]byte, len(c.rawDetails)),
}
if c.details.groups != nil {
nc.details.groups = make([]string, len(c.details.groups))
copy(nc.details.groups, c.details.groups)
}
if c.details.networks != nil {
nc.details.networks = make([]netip.Prefix, len(c.details.networks))
copy(nc.details.networks, c.details.networks)
}
if c.details.unsafeNetworks != nil {
nc.details.unsafeNetworks = make([]netip.Prefix, len(c.details.unsafeNetworks))
copy(nc.details.unsafeNetworks, c.details.unsafeNetworks)
}
copy(nc.rawDetails, c.rawDetails)
copy(nc.signature, c.signature)
copy(nc.publicKey, c.publicKey)
return nc
}
func (c *certificateV2) fromTBSCertificate(t *TBSCertificate) error {
c.details = detailsV2{
name: t.Name,
networks: t.Networks,
unsafeNetworks: t.UnsafeNetworks,
groups: t.Groups,
isCA: t.IsCA,
notBefore: t.NotBefore,
notAfter: t.NotAfter,
issuer: t.issuer,
}
c.curve = t.Curve
c.publicKey = t.PublicKey
return c.validate()
}
func (c *certificateV2) validate() error {
// Empty names are allowed
if len(c.publicKey) == 0 {
return ErrInvalidPublicKey
}
if !c.details.isCA && len(c.details.networks) == 0 {
return NewErrInvalidCertificateProperties("non-CA certificate must contain at least 1 network")
}
hasV4Networks := false
hasV6Networks := false
for _, network := range c.details.networks {
if !network.IsValid() || !network.Addr().IsValid() {
return NewErrInvalidCertificateProperties("invalid network: %s", network)
}
if network.Addr().IsUnspecified() {
return NewErrInvalidCertificateProperties("non-CA certificates must not use the zero address as a network: %s", network)
}
if network.Addr().Zone() != "" {
return NewErrInvalidCertificateProperties("networks may not contain zones: %s", network)
}
if network.Addr().Is4In6() {
return NewErrInvalidCertificateProperties("4in6 networks are not allowed: %s", network)
}
hasV4Networks = hasV4Networks || network.Addr().Is4()
hasV6Networks = hasV6Networks || network.Addr().Is6()
}
slices.SortFunc(c.details.networks, comparePrefix)
err := findDuplicatePrefix(c.details.networks)
if err != nil {
return err
}
for _, network := range c.details.unsafeNetworks {
if !network.IsValid() || !network.Addr().IsValid() {
return NewErrInvalidCertificateProperties("invalid unsafe network: %s", network)
}
if network.Addr().Zone() != "" {
return NewErrInvalidCertificateProperties("unsafe networks may not contain zones: %s", network)
}
if !c.details.isCA {
if network.Addr().Is6() {
if !hasV6Networks {
return NewErrInvalidCertificateProperties("IPv6 unsafe networks require an IPv6 address assignment: %s", network)
}
} else if network.Addr().Is4() {
if !hasV4Networks {
return NewErrInvalidCertificateProperties("IPv4 unsafe networks require an IPv4 address assignment: %s", network)
}
}
}
}
slices.SortFunc(c.details.unsafeNetworks, comparePrefix)
err = findDuplicatePrefix(c.details.unsafeNetworks)
if err != nil {
return err
}
return nil
}
func (c *certificateV2) marshalForSigning() ([]byte, error) {
d, err := c.details.Marshal()
if err != nil {
return nil, fmt.Errorf("marshalling certificate details failed: %w", err)
}
c.rawDetails = d
b := make([]byte, len(c.rawDetails)+1+len(c.publicKey))
copy(b, c.rawDetails)
b[len(c.rawDetails)] = byte(c.curve)
copy(b[len(c.rawDetails)+1:], c.publicKey)
return b, nil
}
func (c *certificateV2) setSignature(b []byte) error {
if len(b) == 0 {
return ErrEmptySignature
}
c.signature = b
return nil
}
func (d *detailsV2) Marshal() ([]byte, error) {
var b cryptobyte.Builder
var err error
// Details are a structure
b.AddASN1(TagCertDetails, func(b *cryptobyte.Builder) {
// Add the name
b.AddASN1(TagDetailsName, func(b *cryptobyte.Builder) {
b.AddBytes([]byte(d.name))
})
// Add the networks if any exist
if len(d.networks) > 0 {
b.AddASN1(TagDetailsNetworks, func(b *cryptobyte.Builder) {
for _, n := range d.networks {
sb, innerErr := n.MarshalBinary()
if innerErr != nil {
// MarshalBinary never returns an error
err = fmt.Errorf("unable to marshal network: %w", innerErr)
return
}
b.AddASN1OctetString(sb)
}
})
}
// Add the unsafe networks if any exist
if len(d.unsafeNetworks) > 0 {
b.AddASN1(TagDetailsUnsafeNetworks, func(b *cryptobyte.Builder) {
for _, n := range d.unsafeNetworks {
sb, innerErr := n.MarshalBinary()
if innerErr != nil {
// MarshalBinary never returns an error
err = fmt.Errorf("unable to marshal unsafe network: %w", innerErr)
return
}
b.AddASN1OctetString(sb)
}
})
}
// Add groups if any exist
if len(d.groups) > 0 {
b.AddASN1(TagDetailsGroups, func(b *cryptobyte.Builder) {
for _, group := range d.groups {
b.AddASN1(asn1.UTF8String, func(b *cryptobyte.Builder) {
b.AddBytes([]byte(group))
})
}
})
}
// Add IsCA only if true
if d.isCA {
b.AddASN1(TagDetailsIsCA, func(b *cryptobyte.Builder) {
b.AddUint8(0xff)
})
}
// Add not before
b.AddASN1Int64WithTag(d.notBefore.Unix(), TagDetailsNotBefore)
// Add not after
b.AddASN1Int64WithTag(d.notAfter.Unix(), TagDetailsNotAfter)
// Add the issuer if present
if d.issuer != "" {
issuerBytes, innerErr := hex.DecodeString(d.issuer)
if innerErr != nil {
err = fmt.Errorf("failed to decode issuer: %w", innerErr)
return
}
b.AddASN1(TagDetailsIssuer, func(b *cryptobyte.Builder) {
b.AddBytes(issuerBytes)
})
}
})
if err != nil {
return nil, err
}
return b.Bytes()
}
func unmarshalCertificateV2(b []byte, publicKey []byte, curve Curve) (*certificateV2, error) {
l := len(b)
if l == 0 || l > MaxCertificateSize {
return nil, ErrBadFormat
}
input := cryptobyte.String(b)
// Open the envelope
if !input.ReadASN1(&input, asn1.SEQUENCE) || input.Empty() {
return nil, ErrBadFormat
}
// Grab the cert details, we need to preserve the tag and length
var rawDetails cryptobyte.String
if !input.ReadASN1Element(&rawDetails, TagCertDetails) || rawDetails.Empty() {
return nil, ErrBadFormat
}
//Maybe grab the curve
var rawCurve byte
if !readOptionalASN1Byte(&input, &rawCurve, TagCertCurve, byte(curve)) {
return nil, ErrBadFormat
}
curve = Curve(rawCurve)
// Maybe grab the public key
var rawPublicKey cryptobyte.String
if len(publicKey) > 0 {
rawPublicKey = publicKey
} else if !input.ReadOptionalASN1(&rawPublicKey, nil, TagCertPublicKey) {
return nil, ErrBadFormat
}
if len(rawPublicKey) == 0 {
return nil, ErrBadFormat
}
// Grab the signature
var rawSignature cryptobyte.String
if !input.ReadASN1(&rawSignature, TagCertSignature) || rawSignature.Empty() {
return nil, ErrBadFormat
}
// Finally unmarshal the details
details, err := unmarshalDetails(rawDetails)
if err != nil {
return nil, err
}
c := &certificateV2{
details: details,
rawDetails: rawDetails,
curve: curve,
publicKey: rawPublicKey,
signature: rawSignature,
}
err = c.validate()
if err != nil {
return nil, err
}
return c, nil
}
func unmarshalDetails(b cryptobyte.String) (detailsV2, error) {
// Open the envelope
if !b.ReadASN1(&b, TagCertDetails) || b.Empty() {
return detailsV2{}, ErrBadFormat
}
// Read the name
var name cryptobyte.String
if !b.ReadASN1(&name, TagDetailsName) || name.Empty() || len(name) > MaxNameLength {
return detailsV2{}, ErrBadFormat
}
// Read the network addresses
var subString cryptobyte.String
var found bool
if !b.ReadOptionalASN1(&subString, &found, TagDetailsNetworks) {
return detailsV2{}, ErrBadFormat
}
var networks []netip.Prefix
var val cryptobyte.String
if found {
for !subString.Empty() {
if !subString.ReadASN1(&val, asn1.OCTET_STRING) || val.Empty() || len(val) > MaxNetworkLength {
return detailsV2{}, ErrBadFormat
}
var n netip.Prefix
if err := n.UnmarshalBinary(val); err != nil {
return detailsV2{}, ErrBadFormat
}
networks = append(networks, n)
}
}
// Read out any unsafe networks
if !b.ReadOptionalASN1(&subString, &found, TagDetailsUnsafeNetworks) {
return detailsV2{}, ErrBadFormat
}
var unsafeNetworks []netip.Prefix
if found {
for !subString.Empty() {
if !subString.ReadASN1(&val, asn1.OCTET_STRING) || val.Empty() || len(val) > MaxNetworkLength {
return detailsV2{}, ErrBadFormat
}
var n netip.Prefix
if err := n.UnmarshalBinary(val); err != nil {
return detailsV2{}, ErrBadFormat
}
unsafeNetworks = append(unsafeNetworks, n)
}
}
// Read out any groups
if !b.ReadOptionalASN1(&subString, &found, TagDetailsGroups) {
return detailsV2{}, ErrBadFormat
}
var groups []string
if found {
for !subString.Empty() {
if !subString.ReadASN1(&val, asn1.UTF8String) || val.Empty() {
return detailsV2{}, ErrBadFormat
}
groups = append(groups, string(val))
}
}
// Read out IsCA
var isCa bool
if !readOptionalASN1Boolean(&b, &isCa, TagDetailsIsCA, false) {
return detailsV2{}, ErrBadFormat
}
// Read not before and not after
var notBefore int64
if !b.ReadASN1Int64WithTag(&notBefore, TagDetailsNotBefore) {
return detailsV2{}, ErrBadFormat
}
var notAfter int64
if !b.ReadASN1Int64WithTag(&notAfter, TagDetailsNotAfter) {
return detailsV2{}, ErrBadFormat
}
// Read issuer
var issuer cryptobyte.String
if !b.ReadOptionalASN1(&issuer, nil, TagDetailsIssuer) {
return detailsV2{}, ErrBadFormat
}
return detailsV2{
name: string(name),
networks: networks,
unsafeNetworks: unsafeNetworks,
groups: groups,
isCA: isCa,
notBefore: time.Unix(notBefore, 0),
notAfter: time.Unix(notAfter, 0),
issuer: hex.EncodeToString(issuer),
}, nil
}

267
cert/cert_v2_test.go Normal file
View File

@ -0,0 +1,267 @@
package cert
import (
"crypto/ed25519"
"crypto/rand"
"encoding/hex"
"net/netip"
"slices"
"testing"
"time"
"github.com/slackhq/nebula/test"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestCertificateV2_Marshal(t *testing.T) {
before := time.Now().Add(time.Second * -60).Round(time.Second)
after := time.Now().Add(time.Second * 60).Round(time.Second)
pubKey := []byte("1234567890abcedfghij1234567890ab")
nc := certificateV2{
details: detailsV2{
name: "testing",
networks: []netip.Prefix{
mustParsePrefixUnmapped("10.1.1.2/16"),
mustParsePrefixUnmapped("10.1.1.1/24"),
},
unsafeNetworks: []netip.Prefix{
mustParsePrefixUnmapped("9.1.1.3/16"),
mustParsePrefixUnmapped("9.1.1.2/24"),
},
groups: []string{"test-group1", "test-group2", "test-group3"},
notBefore: before,
notAfter: after,
isCA: false,
issuer: "1234567890abcdef1234567890abcdef",
},
signature: []byte("1234567890abcdef1234567890abcdef"),
publicKey: pubKey,
}
db, err := nc.details.Marshal()
require.NoError(t, err)
nc.rawDetails = db
b, err := nc.Marshal()
require.NoError(t, err)
//t.Log("Cert size:", len(b))
nc2, err := unmarshalCertificateV2(b, nil, Curve_CURVE25519)
require.NoError(t, err)
assert.Equal(t, Version2, nc.Version())
assert.Equal(t, Curve_CURVE25519, nc.Curve())
assert.Equal(t, nc.Signature(), nc2.Signature())
assert.Equal(t, nc.Name(), nc2.Name())
assert.Equal(t, nc.NotBefore(), nc2.NotBefore())
assert.Equal(t, nc.NotAfter(), nc2.NotAfter())
assert.Equal(t, nc.PublicKey(), nc2.PublicKey())
assert.Equal(t, nc.IsCA(), nc2.IsCA())
assert.Equal(t, nc.Issuer(), nc2.Issuer())
// unmarshalling will sort networks and unsafeNetworks, we need to do the same
// but first make sure it fails
assert.NotEqual(t, nc.Networks(), nc2.Networks())
assert.NotEqual(t, nc.UnsafeNetworks(), nc2.UnsafeNetworks())
slices.SortFunc(nc.details.networks, comparePrefix)
slices.SortFunc(nc.details.unsafeNetworks, comparePrefix)
assert.Equal(t, nc.Networks(), nc2.Networks())
assert.Equal(t, nc.UnsafeNetworks(), nc2.UnsafeNetworks())
assert.Equal(t, nc.Groups(), nc2.Groups())
}
func TestCertificateV2_Expired(t *testing.T) {
nc := certificateV2{
details: detailsV2{
notBefore: time.Now().Add(time.Second * -60).Round(time.Second),
notAfter: time.Now().Add(time.Second * 60).Round(time.Second),
},
}
assert.True(t, nc.Expired(time.Now().Add(time.Hour)))
assert.True(t, nc.Expired(time.Now().Add(-time.Hour)))
assert.False(t, nc.Expired(time.Now()))
}
func TestCertificateV2_MarshalJSON(t *testing.T) {
time.Local = time.UTC
pubKey := []byte("1234567890abcedf1234567890abcedf")
nc := certificateV2{
details: detailsV2{
name: "testing",
networks: []netip.Prefix{
mustParsePrefixUnmapped("10.1.1.1/24"),
mustParsePrefixUnmapped("10.1.1.2/16"),
},
unsafeNetworks: []netip.Prefix{
mustParsePrefixUnmapped("9.1.1.2/24"),
mustParsePrefixUnmapped("9.1.1.3/16"),
},
groups: []string{"test-group1", "test-group2", "test-group3"},
notBefore: time.Date(1, 0, 0, 1, 0, 0, 0, time.UTC),
notAfter: time.Date(1, 0, 0, 2, 0, 0, 0, time.UTC),
isCA: false,
issuer: "1234567890abcedf1234567890abcedf",
},
publicKey: pubKey,
signature: []byte("1234567890abcedf1234567890abcedf1234567890abcedf1234567890abcedf"),
}
b, err := nc.MarshalJSON()
require.ErrorIs(t, err, ErrMissingDetails)
rd, err := nc.details.Marshal()
require.NoError(t, err)
nc.rawDetails = rd
b, err = nc.MarshalJSON()
require.NoError(t, err)
assert.JSONEq(
t,
"{\"curve\":\"CURVE25519\",\"details\":{\"groups\":[\"test-group1\",\"test-group2\",\"test-group3\"],\"isCa\":false,\"issuer\":\"1234567890abcedf1234567890abcedf\",\"name\":\"testing\",\"networks\":[\"10.1.1.1/24\",\"10.1.1.2/16\"],\"notAfter\":\"0000-11-30T02:00:00Z\",\"notBefore\":\"0000-11-30T01:00:00Z\",\"unsafeNetworks\":[\"9.1.1.2/24\",\"9.1.1.3/16\"]},\"fingerprint\":\"152d9a7400c1e001cb76cffd035215ebb351f69eeb797f7f847dd086e15e56dd\",\"publicKey\":\"3132333435363738393061626365646631323334353637383930616263656466\",\"signature\":\"31323334353637383930616263656466313233343536373839306162636564663132333435363738393061626365646631323334353637383930616263656466\",\"version\":2}",
string(b),
)
}
func TestCertificateV2_VerifyPrivateKey(t *testing.T) {
ca, _, caKey, _ := NewTestCaCert(Version2, Curve_CURVE25519, time.Time{}, time.Time{}, nil, nil, nil)
err := ca.VerifyPrivateKey(Curve_CURVE25519, caKey)
require.NoError(t, err)
err = ca.VerifyPrivateKey(Curve_CURVE25519, caKey[:16])
require.ErrorIs(t, err, ErrInvalidPrivateKey)
_, caKey2, err := ed25519.GenerateKey(rand.Reader)
require.NoError(t, err)
err = ca.VerifyPrivateKey(Curve_CURVE25519, caKey2)
require.ErrorIs(t, err, ErrPublicPrivateKeyMismatch)
c, _, priv, _ := NewTestCert(Version2, Curve_CURVE25519, ca, caKey, "test", time.Time{}, time.Time{}, nil, nil, nil)
rawPriv, b, curve, err := UnmarshalPrivateKeyFromPEM(priv)
require.NoError(t, err)
assert.Empty(t, b)
assert.Equal(t, Curve_CURVE25519, curve)
err = c.VerifyPrivateKey(Curve_CURVE25519, rawPriv)
require.NoError(t, err)
_, priv2 := X25519Keypair()
err = c.VerifyPrivateKey(Curve_P256, priv2)
require.ErrorIs(t, err, ErrPublicPrivateCurveMismatch)
err = c.VerifyPrivateKey(Curve_CURVE25519, priv2)
require.ErrorIs(t, err, ErrPublicPrivateKeyMismatch)
err = c.VerifyPrivateKey(Curve_CURVE25519, priv2[:16])
require.ErrorIs(t, err, ErrInvalidPrivateKey)
ac, ok := c.(*certificateV2)
require.True(t, ok)
ac.curve = Curve(99)
err = c.VerifyPrivateKey(Curve(99), priv2)
require.EqualError(t, err, "invalid curve: 99")
ca2, _, caKey2, _ := NewTestCaCert(Version2, Curve_P256, time.Time{}, time.Time{}, nil, nil, nil)
err = ca.VerifyPrivateKey(Curve_CURVE25519, caKey)
require.NoError(t, err)
err = ca2.VerifyPrivateKey(Curve_P256, caKey2[:16])
require.ErrorIs(t, err, ErrInvalidPrivateKey)
c, _, priv, _ = NewTestCert(Version2, Curve_P256, ca2, caKey2, "test", time.Time{}, time.Time{}, nil, nil, nil)
rawPriv, b, curve, err = UnmarshalPrivateKeyFromPEM(priv)
err = c.VerifyPrivateKey(Curve_P256, priv[:16])
require.ErrorIs(t, err, ErrInvalidPrivateKey)
err = c.VerifyPrivateKey(Curve_P256, priv)
require.ErrorIs(t, err, ErrInvalidPrivateKey)
aCa, ok := ca2.(*certificateV2)
require.True(t, ok)
aCa.curve = Curve(99)
err = aCa.VerifyPrivateKey(Curve(99), priv2)
require.EqualError(t, err, "invalid curve: 99")
}
func TestCertificateV2_VerifyPrivateKeyP256(t *testing.T) {
ca, _, caKey, _ := NewTestCaCert(Version2, Curve_P256, time.Time{}, time.Time{}, nil, nil, nil)
err := ca.VerifyPrivateKey(Curve_P256, caKey)
require.NoError(t, err)
_, _, caKey2, _ := NewTestCaCert(Version2, Curve_P256, time.Time{}, time.Time{}, nil, nil, nil)
require.NoError(t, err)
err = ca.VerifyPrivateKey(Curve_P256, caKey2)
require.Error(t, err)
c, _, priv, _ := NewTestCert(Version2, Curve_P256, ca, caKey, "test", time.Time{}, time.Time{}, nil, nil, nil)
rawPriv, b, curve, err := UnmarshalPrivateKeyFromPEM(priv)
require.NoError(t, err)
assert.Empty(t, b)
assert.Equal(t, Curve_P256, curve)
err = c.VerifyPrivateKey(Curve_P256, rawPriv)
require.NoError(t, err)
_, priv2 := P256Keypair()
err = c.VerifyPrivateKey(Curve_P256, priv2)
require.Error(t, err)
}
func TestCertificateV2_Copy(t *testing.T) {
ca, _, caKey, _ := NewTestCaCert(Version2, Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, nil)
c, _, _, _ := NewTestCert(Version2, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, nil, nil)
cc := c.Copy()
test.AssertDeepCopyEqual(t, c, cc)
}
func TestUnmarshalCertificateV2(t *testing.T) {
data := []byte("\x98\x00\x00")
_, err := unmarshalCertificateV2(data, nil, Curve_CURVE25519)
require.EqualError(t, err, "bad wire format")
}
func TestCertificateV2_marshalForSigningStability(t *testing.T) {
before := time.Date(1996, time.May, 5, 0, 0, 0, 0, time.UTC)
after := before.Add(time.Second * 60).Round(time.Second)
pubKey := []byte("1234567890abcedfghij1234567890ab")
nc := certificateV2{
details: detailsV2{
name: "testing",
networks: []netip.Prefix{
mustParsePrefixUnmapped("10.1.1.2/16"),
mustParsePrefixUnmapped("10.1.1.1/24"),
},
unsafeNetworks: []netip.Prefix{
mustParsePrefixUnmapped("9.1.1.3/16"),
mustParsePrefixUnmapped("9.1.1.2/24"),
},
groups: []string{"test-group1", "test-group2", "test-group3"},
notBefore: before,
notAfter: after,
isCA: false,
issuer: "1234567890abcdef1234567890abcdef",
},
signature: []byte("1234567890abcdef1234567890abcdef"),
publicKey: pubKey,
}
const expectedRawDetailsStr = "a070800774657374696e67a10e04050a0101021004050a01010118a20e0405090101031004050901010218a3270c0b746573742d67726f7570310c0b746573742d67726f7570320c0b746573742d67726f7570338504318bef808604318befbc87101234567890abcdef1234567890abcdef"
expectedRawDetails, err := hex.DecodeString(expectedRawDetailsStr)
require.NoError(t, err)
db, err := nc.details.Marshal()
require.NoError(t, err)
assert.Equal(t, expectedRawDetails, db)
expectedForSigning, err := hex.DecodeString(expectedRawDetailsStr + "00313233343536373839306162636564666768696a313233343536373839306162")
b, err := nc.marshalForSigning()
require.NoError(t, err)
assert.Equal(t, expectedForSigning, b)
}

View File

@ -3,14 +3,28 @@ package cert
import (
"crypto/aes"
"crypto/cipher"
"crypto/ed25519"
"crypto/rand"
"encoding/pem"
"fmt"
"io"
"math"
"golang.org/x/crypto/argon2"
"google.golang.org/protobuf/proto"
)
// KDF factors
type NebulaEncryptedData struct {
EncryptionMetadata NebulaEncryptionMetadata
Ciphertext []byte
}
type NebulaEncryptionMetadata struct {
EncryptionAlgorithm string
Argon2Parameters Argon2Parameters
}
// Argon2Parameters KDF factors
type Argon2Parameters struct {
version rune
Memory uint32 // KiB
@ -19,7 +33,7 @@ type Argon2Parameters struct {
salt []byte
}
// Returns a new Argon2Parameters object with current version set
// NewArgon2Parameters Returns a new Argon2Parameters object with current version set
func NewArgon2Parameters(memory uint32, parallelism uint8, iterations uint32) *Argon2Parameters {
return &Argon2Parameters{
version: argon2.Version,
@ -141,3 +155,146 @@ func splitNonceCiphertext(blob []byte, nonceSize int) ([]byte, []byte, error) {
return blob[:nonceSize], blob[nonceSize:], nil
}
// EncryptAndMarshalSigningPrivateKey is a simple helper to encrypt and PEM encode a private key
func EncryptAndMarshalSigningPrivateKey(curve Curve, b []byte, passphrase []byte, kdfParams *Argon2Parameters) ([]byte, error) {
ciphertext, err := aes256Encrypt(passphrase, kdfParams, b)
if err != nil {
return nil, err
}
b, err = proto.Marshal(&RawNebulaEncryptedData{
EncryptionMetadata: &RawNebulaEncryptionMetadata{
EncryptionAlgorithm: "AES-256-GCM",
Argon2Parameters: &RawNebulaArgon2Parameters{
Version: kdfParams.version,
Memory: kdfParams.Memory,
Parallelism: uint32(kdfParams.Parallelism),
Iterations: kdfParams.Iterations,
Salt: kdfParams.salt,
},
},
Ciphertext: ciphertext,
})
if err != nil {
return nil, err
}
switch curve {
case Curve_CURVE25519:
return pem.EncodeToMemory(&pem.Block{Type: EncryptedEd25519PrivateKeyBanner, Bytes: b}), nil
case Curve_P256:
return pem.EncodeToMemory(&pem.Block{Type: EncryptedECDSAP256PrivateKeyBanner, Bytes: b}), nil
default:
return nil, fmt.Errorf("invalid curve: %v", curve)
}
}
// UnmarshalNebulaEncryptedData will unmarshal a protobuf byte representation of a nebula cert into its
// protobuf-generated struct.
func UnmarshalNebulaEncryptedData(b []byte) (*NebulaEncryptedData, error) {
if len(b) == 0 {
return nil, fmt.Errorf("nil byte array")
}
var rned RawNebulaEncryptedData
err := proto.Unmarshal(b, &rned)
if err != nil {
return nil, err
}
if rned.EncryptionMetadata == nil {
return nil, fmt.Errorf("encoded EncryptionMetadata was nil")
}
if rned.EncryptionMetadata.Argon2Parameters == nil {
return nil, fmt.Errorf("encoded Argon2Parameters was nil")
}
params, err := unmarshalArgon2Parameters(rned.EncryptionMetadata.Argon2Parameters)
if err != nil {
return nil, err
}
ned := NebulaEncryptedData{
EncryptionMetadata: NebulaEncryptionMetadata{
EncryptionAlgorithm: rned.EncryptionMetadata.EncryptionAlgorithm,
Argon2Parameters: *params,
},
Ciphertext: rned.Ciphertext,
}
return &ned, nil
}
func unmarshalArgon2Parameters(params *RawNebulaArgon2Parameters) (*Argon2Parameters, error) {
if params.Version < math.MinInt32 || params.Version > math.MaxInt32 {
return nil, fmt.Errorf("Argon2Parameters Version must be at least %d and no more than %d", math.MinInt32, math.MaxInt32)
}
if params.Memory <= 0 || params.Memory > math.MaxUint32 {
return nil, fmt.Errorf("Argon2Parameters Memory must be be greater than 0 and no more than %d KiB", uint32(math.MaxUint32))
}
if params.Parallelism <= 0 || params.Parallelism > math.MaxUint8 {
return nil, fmt.Errorf("Argon2Parameters Parallelism must be be greater than 0 and no more than %d", math.MaxUint8)
}
if params.Iterations <= 0 || params.Iterations > math.MaxUint32 {
return nil, fmt.Errorf("-argon-iterations must be be greater than 0 and no more than %d", uint32(math.MaxUint32))
}
return &Argon2Parameters{
version: params.Version,
Memory: params.Memory,
Parallelism: uint8(params.Parallelism),
Iterations: params.Iterations,
salt: params.Salt,
}, nil
}
// DecryptAndUnmarshalSigningPrivateKey will try to pem decode and decrypt an Ed25519/ECDSA private key with
// the given passphrase, returning any other bytes b or an error on failure
func DecryptAndUnmarshalSigningPrivateKey(passphrase, b []byte) (Curve, []byte, []byte, error) {
var curve Curve
k, r := pem.Decode(b)
if k == nil {
return curve, nil, r, fmt.Errorf("input did not contain a valid PEM encoded block")
}
switch k.Type {
case EncryptedEd25519PrivateKeyBanner:
curve = Curve_CURVE25519
case EncryptedECDSAP256PrivateKeyBanner:
curve = Curve_P256
default:
return curve, nil, r, fmt.Errorf("bytes did not contain a proper nebula encrypted Ed25519/ECDSA private key banner")
}
ned, err := UnmarshalNebulaEncryptedData(k.Bytes)
if err != nil {
return curve, nil, r, err
}
var bytes []byte
switch ned.EncryptionMetadata.EncryptionAlgorithm {
case "AES-256-GCM":
bytes, err = aes256Decrypt(passphrase, &ned.EncryptionMetadata.Argon2Parameters, ned.Ciphertext)
if err != nil {
return curve, nil, r, err
}
default:
return curve, nil, r, fmt.Errorf("unsupported encryption algorithm: %s", ned.EncryptionMetadata.EncryptionAlgorithm)
}
switch curve {
case Curve_CURVE25519:
if len(bytes) != ed25519.PrivateKeySize {
return curve, nil, r, fmt.Errorf("key was not %d bytes, is invalid ed25519 private key", ed25519.PrivateKeySize)
}
case Curve_P256:
if len(bytes) != 32 {
return curve, nil, r, fmt.Errorf("key was not 32 bytes, is invalid ECDSA P256 private key")
}
}
return curve, bytes, r, nil
}

View File

@ -4,6 +4,7 @@ import (
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"golang.org/x/crypto/argon2"
)
@ -23,3 +24,90 @@ func TestNewArgon2Parameters(t *testing.T) {
Iterations: 1,
}, p)
}
func TestDecryptAndUnmarshalSigningPrivateKey(t *testing.T) {
passphrase := []byte("DO NOT USE THIS KEY")
privKey := []byte(`# A good key
-----BEGIN NEBULA ED25519 ENCRYPTED PRIVATE KEY-----
CjwKC0FFUy0yNTYtR0NNEi0IExCAgIABGAEgBCognnjujd67Vsv99p22wfAjQaDT
oCMW1mdjkU3gACKNW4MSXOWR9Sts4C81yk1RUku2gvGKs3TB9LYoklLsIizSYOLl
+Vs//O1T0I1Xbml2XBAROsb/VSoDln/6LMqR4B6fn6B3GOsLBBqRI8daDl9lRMPB
qrlJ69wer3ZUHFXA
-----END NEBULA ED25519 ENCRYPTED PRIVATE KEY-----
`)
shortKey := []byte(`# A key which, once decrypted, is too short
-----BEGIN NEBULA ED25519 ENCRYPTED PRIVATE KEY-----
CjwKC0FFUy0yNTYtR0NNEi0IExCAgIABGAEgBCoga5h8owMEBWRSMMJKzuUvWce7
k0qlBkQmCxiuLh80MuASW70YcKt8jeEIS2axo2V6zAKA9TSMcCsJW1kDDXEtL/xe
GLF5T7sDl5COp4LU3pGxpV+KoeQ/S3gQCAAcnaOtnJQX+aSDnbO3jCHyP7U9CHbs
rQr3bdH3Oy/WiYU=
-----END NEBULA ED25519 ENCRYPTED PRIVATE KEY-----
`)
invalidBanner := []byte(`# Invalid banner (not encrypted)
-----BEGIN NEBULA ED25519 PRIVATE KEY-----
bWRp2CTVFhW9HD/qCd28ltDgK3w8VXSeaEYczDWos8sMUBqDb9jP3+NYwcS4lURG
XgLvodMXZJuaFPssp+WwtA==
-----END NEBULA ED25519 PRIVATE KEY-----
`)
invalidPem := []byte(`# Not a valid PEM format
-BEGIN NEBULA ED25519 ENCRYPTED PRIVATE KEY-----
CjwKC0FFUy0yNTYtR0NNEi0IExCAgIABGAEgBCognnjujd67Vsv99p22wfAjQaDT
oCMW1mdjkU3gACKNW4MSXOWR9Sts4C81yk1RUku2gvGKs3TB9LYoklLsIizSYOLl
+Vs//O1T0I1Xbml2XBAROsb/VSoDln/6LMqR4B6fn6B3GOsLBBqRI8daDl9lRMPB
qrlJ69wer3ZUHFXA
-END NEBULA ED25519 ENCRYPTED PRIVATE KEY-----
`)
keyBundle := appendByteSlices(privKey, shortKey, invalidBanner, invalidPem)
// Success test case
curve, k, rest, err := DecryptAndUnmarshalSigningPrivateKey(passphrase, keyBundle)
require.NoError(t, err)
assert.Equal(t, Curve_CURVE25519, curve)
assert.Len(t, k, 64)
assert.Equal(t, rest, appendByteSlices(shortKey, invalidBanner, invalidPem))
// Fail due to short key
curve, k, rest, err = DecryptAndUnmarshalSigningPrivateKey(passphrase, rest)
require.EqualError(t, err, "key was not 64 bytes, is invalid ed25519 private key")
assert.Nil(t, k)
assert.Equal(t, rest, appendByteSlices(invalidBanner, invalidPem))
// Fail due to invalid banner
curve, k, rest, err = DecryptAndUnmarshalSigningPrivateKey(passphrase, rest)
require.EqualError(t, err, "bytes did not contain a proper nebula encrypted Ed25519/ECDSA private key banner")
assert.Nil(t, k)
assert.Equal(t, rest, invalidPem)
// Fail due to ivalid PEM format, because
// it's missing the requisite pre-encapsulation boundary.
curve, k, rest, err = DecryptAndUnmarshalSigningPrivateKey(passphrase, rest)
require.EqualError(t, err, "input did not contain a valid PEM encoded block")
assert.Nil(t, k)
assert.Equal(t, rest, invalidPem)
// Fail due to invalid passphrase
curve, k, rest, err = DecryptAndUnmarshalSigningPrivateKey([]byte("invalid passphrase"), privKey)
require.EqualError(t, err, "invalid passphrase or corrupt private key")
assert.Nil(t, k)
assert.Equal(t, []byte{}, rest)
}
func TestEncryptAndMarshalSigningPrivateKey(t *testing.T) {
// Having proved that decryption works correctly above, we can test the
// encryption function produces a value which can be decrypted
passphrase := []byte("passphrase")
bytes := []byte("AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA")
kdfParams := NewArgon2Parameters(64*1024, 4, 3)
key, err := EncryptAndMarshalSigningPrivateKey(Curve_CURVE25519, bytes, passphrase, kdfParams)
require.NoError(t, err)
// Verify the "key" can be decrypted successfully
curve, k, rest, err := DecryptAndUnmarshalSigningPrivateKey(passphrase, key)
assert.Len(t, k, 64)
assert.Equal(t, Curve_CURVE25519, curve)
assert.Equal(t, []byte{}, rest)
require.NoError(t, err)
// EncryptAndMarshalEd25519PrivateKey does not create any errors itself
}

View File

@ -2,13 +2,48 @@ package cert
import (
"errors"
"fmt"
)
var (
ErrBadFormat = errors.New("bad wire format")
ErrRootExpired = errors.New("root certificate is expired")
ErrExpired = errors.New("certificate is expired")
ErrNotCA = errors.New("certificate is not a CA")
ErrNotSelfSigned = errors.New("certificate is not self-signed")
ErrBlockListed = errors.New("certificate is in the block list")
ErrFingerprintMismatch = errors.New("certificate fingerprint did not match")
ErrSignatureMismatch = errors.New("certificate signature did not match")
ErrInvalidPublicKey = errors.New("invalid public key")
ErrInvalidPrivateKey = errors.New("invalid private key")
ErrPublicPrivateCurveMismatch = errors.New("public key does not match private key curve")
ErrPublicPrivateKeyMismatch = errors.New("public key and private key are not a pair")
ErrPrivateKeyEncrypted = errors.New("private key must be decrypted")
ErrCaNotFound = errors.New("could not find ca for the certificate")
ErrInvalidPEMBlock = errors.New("input did not contain a valid PEM encoded block")
ErrInvalidPEMCertificateBanner = errors.New("bytes did not contain a proper certificate banner")
ErrInvalidPEMX25519PublicKeyBanner = errors.New("bytes did not contain a proper X25519 public key banner")
ErrInvalidPEMX25519PrivateKeyBanner = errors.New("bytes did not contain a proper X25519 private key banner")
ErrInvalidPEMEd25519PublicKeyBanner = errors.New("bytes did not contain a proper Ed25519 public key banner")
ErrInvalidPEMEd25519PrivateKeyBanner = errors.New("bytes did not contain a proper Ed25519 private key banner")
ErrNoPeerStaticKey = errors.New("no peer static key was present")
ErrNoPayload = errors.New("provided payload was empty")
ErrMissingDetails = errors.New("certificate did not contain details")
ErrEmptySignature = errors.New("empty signature")
ErrEmptyRawDetails = errors.New("empty rawDetails not allowed")
)
type ErrInvalidCertificateProperties struct {
str string
}
func NewErrInvalidCertificateProperties(format string, a ...any) error {
return &ErrInvalidCertificateProperties{fmt.Sprintf(format, a...)}
}
func (e *ErrInvalidCertificateProperties) Error() string {
return e.str
}

141
cert/helper_test.go Normal file
View File

@ -0,0 +1,141 @@
package cert
import (
"crypto/ecdh"
"crypto/ecdsa"
"crypto/elliptic"
"crypto/rand"
"io"
"net/netip"
"time"
"golang.org/x/crypto/curve25519"
"golang.org/x/crypto/ed25519"
)
// NewTestCaCert will create a new ca certificate
func NewTestCaCert(version Version, curve Curve, before, after time.Time, networks, unsafeNetworks []netip.Prefix, groups []string) (Certificate, []byte, []byte, []byte) {
var err error
var pub, priv []byte
switch curve {
case Curve_CURVE25519:
pub, priv, err = ed25519.GenerateKey(rand.Reader)
case Curve_P256:
privk, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
if err != nil {
panic(err)
}
pub = elliptic.Marshal(elliptic.P256(), privk.PublicKey.X, privk.PublicKey.Y)
priv = privk.D.FillBytes(make([]byte, 32))
default:
// There is no default to allow the underlying lib to respond with an error
}
if before.IsZero() {
before = time.Now().Add(time.Second * -60).Round(time.Second)
}
if after.IsZero() {
after = time.Now().Add(time.Second * 60).Round(time.Second)
}
t := &TBSCertificate{
Curve: curve,
Version: version,
Name: "test ca",
NotBefore: time.Unix(before.Unix(), 0),
NotAfter: time.Unix(after.Unix(), 0),
PublicKey: pub,
Networks: networks,
UnsafeNetworks: unsafeNetworks,
Groups: groups,
IsCA: true,
}
c, err := t.Sign(nil, curve, priv)
if err != nil {
panic(err)
}
pem, err := c.MarshalPEM()
if err != nil {
panic(err)
}
return c, pub, priv, pem
}
// NewTestCert will generate a signed certificate with the provided details.
// Expiry times are defaulted if you do not pass them in
func NewTestCert(v Version, curve Curve, ca Certificate, key []byte, name string, before, after time.Time, networks, unsafeNetworks []netip.Prefix, groups []string) (Certificate, []byte, []byte, []byte) {
if before.IsZero() {
before = time.Now().Add(time.Second * -60).Round(time.Second)
}
if after.IsZero() {
after = time.Now().Add(time.Second * 60).Round(time.Second)
}
if len(networks) == 0 {
networks = []netip.Prefix{netip.MustParsePrefix("10.0.0.123/8")}
}
var pub, priv []byte
switch curve {
case Curve_CURVE25519:
pub, priv = X25519Keypair()
case Curve_P256:
pub, priv = P256Keypair()
default:
panic("unknown curve")
}
nc := &TBSCertificate{
Version: v,
Curve: curve,
Name: name,
Networks: networks,
UnsafeNetworks: unsafeNetworks,
Groups: groups,
NotBefore: time.Unix(before.Unix(), 0),
NotAfter: time.Unix(after.Unix(), 0),
PublicKey: pub,
IsCA: false,
}
c, err := nc.Sign(ca, ca.Curve(), key)
if err != nil {
panic(err)
}
pem, err := c.MarshalPEM()
if err != nil {
panic(err)
}
return c, pub, MarshalPrivateKeyToPEM(curve, priv), pem
}
func X25519Keypair() ([]byte, []byte) {
privkey := make([]byte, 32)
if _, err := io.ReadFull(rand.Reader, privkey); err != nil {
panic(err)
}
pubkey, err := curve25519.X25519(privkey, curve25519.Basepoint)
if err != nil {
panic(err)
}
return pubkey, privkey
}
func P256Keypair() ([]byte, []byte) {
privkey, err := ecdh.P256().GenerateKey(rand.Reader)
if err != nil {
panic(err)
}
pubkey := privkey.PublicKey()
return pubkey.Bytes(), privkey.Bytes()
}

161
cert/pem.go Normal file
View File

@ -0,0 +1,161 @@
package cert
import (
"encoding/pem"
"fmt"
"golang.org/x/crypto/ed25519"
)
const (
CertificateBanner = "NEBULA CERTIFICATE"
CertificateV2Banner = "NEBULA CERTIFICATE V2"
X25519PrivateKeyBanner = "NEBULA X25519 PRIVATE KEY"
X25519PublicKeyBanner = "NEBULA X25519 PUBLIC KEY"
EncryptedEd25519PrivateKeyBanner = "NEBULA ED25519 ENCRYPTED PRIVATE KEY"
Ed25519PrivateKeyBanner = "NEBULA ED25519 PRIVATE KEY"
Ed25519PublicKeyBanner = "NEBULA ED25519 PUBLIC KEY"
P256PrivateKeyBanner = "NEBULA P256 PRIVATE KEY"
P256PublicKeyBanner = "NEBULA P256 PUBLIC KEY"
EncryptedECDSAP256PrivateKeyBanner = "NEBULA ECDSA P256 ENCRYPTED PRIVATE KEY"
ECDSAP256PrivateKeyBanner = "NEBULA ECDSA P256 PRIVATE KEY"
)
// UnmarshalCertificateFromPEM will try to unmarshal the first pem block in a byte array, returning any non consumed
// data or an error on failure
func UnmarshalCertificateFromPEM(b []byte) (Certificate, []byte, error) {
p, r := pem.Decode(b)
if p == nil {
return nil, r, ErrInvalidPEMBlock
}
var c Certificate
var err error
switch p.Type {
// Implementations must validate the resulting certificate contains valid information
case CertificateBanner:
c, err = unmarshalCertificateV1(p.Bytes, nil)
case CertificateV2Banner:
c, err = unmarshalCertificateV2(p.Bytes, nil, Curve_CURVE25519)
default:
return nil, r, ErrInvalidPEMCertificateBanner
}
if err != nil {
return nil, r, err
}
return c, r, nil
}
func MarshalPublicKeyToPEM(curve Curve, b []byte) []byte {
switch curve {
case Curve_CURVE25519:
return pem.EncodeToMemory(&pem.Block{Type: X25519PublicKeyBanner, Bytes: b})
case Curve_P256:
return pem.EncodeToMemory(&pem.Block{Type: P256PublicKeyBanner, Bytes: b})
default:
return nil
}
}
func UnmarshalPublicKeyFromPEM(b []byte) ([]byte, []byte, Curve, error) {
k, r := pem.Decode(b)
if k == nil {
return nil, r, 0, fmt.Errorf("input did not contain a valid PEM encoded block")
}
var expectedLen int
var curve Curve
switch k.Type {
case X25519PublicKeyBanner, Ed25519PublicKeyBanner:
expectedLen = 32
curve = Curve_CURVE25519
case P256PublicKeyBanner:
// Uncompressed
expectedLen = 65
curve = Curve_P256
default:
return nil, r, 0, fmt.Errorf("bytes did not contain a proper public key banner")
}
if len(k.Bytes) != expectedLen {
return nil, r, 0, fmt.Errorf("key was not %d bytes, is invalid %s public key", expectedLen, curve)
}
return k.Bytes, r, curve, nil
}
func MarshalPrivateKeyToPEM(curve Curve, b []byte) []byte {
switch curve {
case Curve_CURVE25519:
return pem.EncodeToMemory(&pem.Block{Type: X25519PrivateKeyBanner, Bytes: b})
case Curve_P256:
return pem.EncodeToMemory(&pem.Block{Type: P256PrivateKeyBanner, Bytes: b})
default:
return nil
}
}
func MarshalSigningPrivateKeyToPEM(curve Curve, b []byte) []byte {
switch curve {
case Curve_CURVE25519:
return pem.EncodeToMemory(&pem.Block{Type: Ed25519PrivateKeyBanner, Bytes: b})
case Curve_P256:
return pem.EncodeToMemory(&pem.Block{Type: ECDSAP256PrivateKeyBanner, Bytes: b})
default:
return nil
}
}
// UnmarshalPrivateKeyFromPEM will try to unmarshal the first pem block in a byte array, returning any non
// consumed data or an error on failure
func UnmarshalPrivateKeyFromPEM(b []byte) ([]byte, []byte, Curve, error) {
k, r := pem.Decode(b)
if k == nil {
return nil, r, 0, fmt.Errorf("input did not contain a valid PEM encoded block")
}
var expectedLen int
var curve Curve
switch k.Type {
case X25519PrivateKeyBanner:
expectedLen = 32
curve = Curve_CURVE25519
case P256PrivateKeyBanner:
expectedLen = 32
curve = Curve_P256
default:
return nil, r, 0, fmt.Errorf("bytes did not contain a proper private key banner")
}
if len(k.Bytes) != expectedLen {
return nil, r, 0, fmt.Errorf("key was not %d bytes, is invalid %s private key", expectedLen, curve)
}
return k.Bytes, r, curve, nil
}
func UnmarshalSigningPrivateKeyFromPEM(b []byte) ([]byte, []byte, Curve, error) {
k, r := pem.Decode(b)
if k == nil {
return nil, r, 0, fmt.Errorf("input did not contain a valid PEM encoded block")
}
var curve Curve
switch k.Type {
case EncryptedEd25519PrivateKeyBanner:
return nil, nil, Curve_CURVE25519, ErrPrivateKeyEncrypted
case EncryptedECDSAP256PrivateKeyBanner:
return nil, nil, Curve_P256, ErrPrivateKeyEncrypted
case Ed25519PrivateKeyBanner:
curve = Curve_CURVE25519
if len(k.Bytes) != ed25519.PrivateKeySize {
return nil, r, 0, fmt.Errorf("key was not %d bytes, is invalid Ed25519 private key", ed25519.PrivateKeySize)
}
case ECDSAP256PrivateKeyBanner:
curve = Curve_P256
if len(k.Bytes) != 32 {
return nil, r, 0, fmt.Errorf("key was not 32 bytes, is invalid ECDSA P256 private key")
}
default:
return nil, r, 0, fmt.Errorf("bytes did not contain a proper Ed25519/ECDSA private key banner")
}
return k.Bytes, r, curve, nil
}

293
cert/pem_test.go Normal file
View File

@ -0,0 +1,293 @@
package cert
import (
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestUnmarshalCertificateFromPEM(t *testing.T) {
goodCert := []byte(`
# A good cert
-----BEGIN NEBULA CERTIFICATE-----
CkAKDm5lYnVsYSByb290IGNhKJfap9AFMJfg1+YGOiCUQGByMuNRhIlQBOyzXWbL
vcKBwDhov900phEfJ5DN3kABEkDCq5R8qBiu8sl54yVfgRcQXEDt3cHr8UTSLszv
bzBEr00kERQxxTzTsH8cpYEgRoipvmExvg8WP8NdAJEYJosB
-----END NEBULA CERTIFICATE-----
`)
badBanner := []byte(`# A bad banner
-----BEGIN NOT A NEBULA CERTIFICATE-----
CkAKDm5lYnVsYSByb290IGNhKJfap9AFMJfg1+YGOiCUQGByMuNRhIlQBOyzXWbL
vcKBwDhov900phEfJ5DN3kABEkDCq5R8qBiu8sl54yVfgRcQXEDt3cHr8UTSLszv
bzBEr00kERQxxTzTsH8cpYEgRoipvmExvg8WP8NdAJEYJosB
-----END NOT A NEBULA CERTIFICATE-----
`)
invalidPem := []byte(`# Not a valid PEM format
-BEGIN NEBULA CERTIFICATE-----
CkAKDm5lYnVsYSByb290IGNhKJfap9AFMJfg1+YGOiCUQGByMuNRhIlQBOyzXWbL
vcKBwDhov900phEfJ5DN3kABEkDCq5R8qBiu8sl54yVfgRcQXEDt3cHr8UTSLszv
bzBEr00kERQxxTzTsH8cpYEgRoipvmExvg8WP8NdAJEYJosB
-END NEBULA CERTIFICATE----`)
certBundle := appendByteSlices(goodCert, badBanner, invalidPem)
// Success test case
cert, rest, err := UnmarshalCertificateFromPEM(certBundle)
assert.NotNil(t, cert)
assert.Equal(t, rest, append(badBanner, invalidPem...))
require.NoError(t, err)
// Fail due to invalid banner.
cert, rest, err = UnmarshalCertificateFromPEM(rest)
assert.Nil(t, cert)
assert.Equal(t, rest, invalidPem)
require.EqualError(t, err, "bytes did not contain a proper certificate banner")
// Fail due to ivalid PEM format, because
// it's missing the requisite pre-encapsulation boundary.
cert, rest, err = UnmarshalCertificateFromPEM(rest)
assert.Nil(t, cert)
assert.Equal(t, rest, invalidPem)
require.EqualError(t, err, "input did not contain a valid PEM encoded block")
}
func TestUnmarshalSigningPrivateKeyFromPEM(t *testing.T) {
privKey := []byte(`# A good key
-----BEGIN NEBULA ED25519 PRIVATE KEY-----
AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA==
-----END NEBULA ED25519 PRIVATE KEY-----
`)
privP256Key := []byte(`# A good key
-----BEGIN NEBULA ECDSA P256 PRIVATE KEY-----
AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA=
-----END NEBULA ECDSA P256 PRIVATE KEY-----
`)
shortKey := []byte(`# A short key
-----BEGIN NEBULA ED25519 PRIVATE KEY-----
AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA
-----END NEBULA ED25519 PRIVATE KEY-----
`)
invalidBanner := []byte(`# Invalid banner
-----BEGIN NOT A NEBULA PRIVATE KEY-----
AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA==
-----END NOT A NEBULA PRIVATE KEY-----
`)
invalidPem := []byte(`# Not a valid PEM format
-BEGIN NEBULA ED25519 PRIVATE KEY-----
AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA==
-END NEBULA ED25519 PRIVATE KEY-----`)
keyBundle := appendByteSlices(privKey, privP256Key, shortKey, invalidBanner, invalidPem)
// Success test case
k, rest, curve, err := UnmarshalSigningPrivateKeyFromPEM(keyBundle)
assert.Len(t, k, 64)
assert.Equal(t, rest, appendByteSlices(privP256Key, shortKey, invalidBanner, invalidPem))
assert.Equal(t, Curve_CURVE25519, curve)
require.NoError(t, err)
// Success test case
k, rest, curve, err = UnmarshalSigningPrivateKeyFromPEM(rest)
assert.Len(t, k, 32)
assert.Equal(t, rest, appendByteSlices(shortKey, invalidBanner, invalidPem))
assert.Equal(t, Curve_P256, curve)
require.NoError(t, err)
// Fail due to short key
k, rest, curve, err = UnmarshalSigningPrivateKeyFromPEM(rest)
assert.Nil(t, k)
assert.Equal(t, rest, appendByteSlices(invalidBanner, invalidPem))
require.EqualError(t, err, "key was not 64 bytes, is invalid Ed25519 private key")
// Fail due to invalid banner
k, rest, curve, err = UnmarshalSigningPrivateKeyFromPEM(rest)
assert.Nil(t, k)
assert.Equal(t, rest, invalidPem)
require.EqualError(t, err, "bytes did not contain a proper Ed25519/ECDSA private key banner")
// Fail due to ivalid PEM format, because
// it's missing the requisite pre-encapsulation boundary.
k, rest, curve, err = UnmarshalSigningPrivateKeyFromPEM(rest)
assert.Nil(t, k)
assert.Equal(t, rest, invalidPem)
require.EqualError(t, err, "input did not contain a valid PEM encoded block")
}
func TestUnmarshalPrivateKeyFromPEM(t *testing.T) {
privKey := []byte(`# A good key
-----BEGIN NEBULA X25519 PRIVATE KEY-----
AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA=
-----END NEBULA X25519 PRIVATE KEY-----
`)
privP256Key := []byte(`# A good key
-----BEGIN NEBULA P256 PRIVATE KEY-----
AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA=
-----END NEBULA P256 PRIVATE KEY-----
`)
shortKey := []byte(`# A short key
-----BEGIN NEBULA X25519 PRIVATE KEY-----
AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA==
-----END NEBULA X25519 PRIVATE KEY-----
`)
invalidBanner := []byte(`# Invalid banner
-----BEGIN NOT A NEBULA PRIVATE KEY-----
AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA=
-----END NOT A NEBULA PRIVATE KEY-----
`)
invalidPem := []byte(`# Not a valid PEM format
-BEGIN NEBULA X25519 PRIVATE KEY-----
AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA=
-END NEBULA X25519 PRIVATE KEY-----`)
keyBundle := appendByteSlices(privKey, privP256Key, shortKey, invalidBanner, invalidPem)
// Success test case
k, rest, curve, err := UnmarshalPrivateKeyFromPEM(keyBundle)
assert.Len(t, k, 32)
assert.Equal(t, rest, appendByteSlices(privP256Key, shortKey, invalidBanner, invalidPem))
assert.Equal(t, Curve_CURVE25519, curve)
require.NoError(t, err)
// Success test case
k, rest, curve, err = UnmarshalPrivateKeyFromPEM(rest)
assert.Len(t, k, 32)
assert.Equal(t, rest, appendByteSlices(shortKey, invalidBanner, invalidPem))
assert.Equal(t, Curve_P256, curve)
require.NoError(t, err)
// Fail due to short key
k, rest, curve, err = UnmarshalPrivateKeyFromPEM(rest)
assert.Nil(t, k)
assert.Equal(t, rest, appendByteSlices(invalidBanner, invalidPem))
require.EqualError(t, err, "key was not 32 bytes, is invalid CURVE25519 private key")
// Fail due to invalid banner
k, rest, curve, err = UnmarshalPrivateKeyFromPEM(rest)
assert.Nil(t, k)
assert.Equal(t, rest, invalidPem)
require.EqualError(t, err, "bytes did not contain a proper private key banner")
// Fail due to ivalid PEM format, because
// it's missing the requisite pre-encapsulation boundary.
k, rest, curve, err = UnmarshalPrivateKeyFromPEM(rest)
assert.Nil(t, k)
assert.Equal(t, rest, invalidPem)
require.EqualError(t, err, "input did not contain a valid PEM encoded block")
}
func TestUnmarshalPublicKeyFromPEM(t *testing.T) {
pubKey := []byte(`# A good key
-----BEGIN NEBULA ED25519 PUBLIC KEY-----
AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA=
-----END NEBULA ED25519 PUBLIC KEY-----
`)
shortKey := []byte(`# A short key
-----BEGIN NEBULA ED25519 PUBLIC KEY-----
AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA==
-----END NEBULA ED25519 PUBLIC KEY-----
`)
invalidBanner := []byte(`# Invalid banner
-----BEGIN NOT A NEBULA PUBLIC KEY-----
AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA=
-----END NOT A NEBULA PUBLIC KEY-----
`)
invalidPem := []byte(`# Not a valid PEM format
-BEGIN NEBULA ED25519 PUBLIC KEY-----
AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA=
-END NEBULA ED25519 PUBLIC KEY-----`)
keyBundle := appendByteSlices(pubKey, shortKey, invalidBanner, invalidPem)
// Success test case
k, rest, curve, err := UnmarshalPublicKeyFromPEM(keyBundle)
assert.Len(t, k, 32)
assert.Equal(t, Curve_CURVE25519, curve)
require.NoError(t, err)
assert.Equal(t, rest, appendByteSlices(shortKey, invalidBanner, invalidPem))
// Fail due to short key
k, rest, curve, err = UnmarshalPublicKeyFromPEM(rest)
assert.Nil(t, k)
assert.Equal(t, Curve_CURVE25519, curve)
assert.Equal(t, rest, appendByteSlices(invalidBanner, invalidPem))
require.EqualError(t, err, "key was not 32 bytes, is invalid CURVE25519 public key")
// Fail due to invalid banner
k, rest, curve, err = UnmarshalPublicKeyFromPEM(rest)
assert.Nil(t, k)
assert.Equal(t, Curve_CURVE25519, curve)
require.EqualError(t, err, "bytes did not contain a proper public key banner")
assert.Equal(t, rest, invalidPem)
// Fail due to ivalid PEM format, because
// it's missing the requisite pre-encapsulation boundary.
k, rest, curve, err = UnmarshalPublicKeyFromPEM(rest)
assert.Nil(t, k)
assert.Equal(t, Curve_CURVE25519, curve)
assert.Equal(t, rest, invalidPem)
require.EqualError(t, err, "input did not contain a valid PEM encoded block")
}
func TestUnmarshalX25519PublicKey(t *testing.T) {
pubKey := []byte(`# A good key
-----BEGIN NEBULA X25519 PUBLIC KEY-----
AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA=
-----END NEBULA X25519 PUBLIC KEY-----
`)
pubP256Key := []byte(`# A good key
-----BEGIN NEBULA P256 PUBLIC KEY-----
AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA
AAAAAAAAAAAAAAAAAAAAAAA=
-----END NEBULA P256 PUBLIC KEY-----
`)
shortKey := []byte(`# A short key
-----BEGIN NEBULA X25519 PUBLIC KEY-----
AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA==
-----END NEBULA X25519 PUBLIC KEY-----
`)
invalidBanner := []byte(`# Invalid banner
-----BEGIN NOT A NEBULA PUBLIC KEY-----
AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA=
-----END NOT A NEBULA PUBLIC KEY-----
`)
invalidPem := []byte(`# Not a valid PEM format
-BEGIN NEBULA X25519 PUBLIC KEY-----
AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA=
-END NEBULA X25519 PUBLIC KEY-----`)
keyBundle := appendByteSlices(pubKey, pubP256Key, shortKey, invalidBanner, invalidPem)
// Success test case
k, rest, curve, err := UnmarshalPublicKeyFromPEM(keyBundle)
assert.Len(t, k, 32)
require.NoError(t, err)
assert.Equal(t, rest, appendByteSlices(pubP256Key, shortKey, invalidBanner, invalidPem))
assert.Equal(t, Curve_CURVE25519, curve)
// Success test case
k, rest, curve, err = UnmarshalPublicKeyFromPEM(rest)
assert.Len(t, k, 65)
require.NoError(t, err)
assert.Equal(t, rest, appendByteSlices(shortKey, invalidBanner, invalidPem))
assert.Equal(t, Curve_P256, curve)
// Fail due to short key
k, rest, curve, err = UnmarshalPublicKeyFromPEM(rest)
assert.Nil(t, k)
assert.Equal(t, rest, appendByteSlices(invalidBanner, invalidPem))
require.EqualError(t, err, "key was not 32 bytes, is invalid CURVE25519 public key")
// Fail due to invalid banner
k, rest, curve, err = UnmarshalPublicKeyFromPEM(rest)
assert.Nil(t, k)
require.EqualError(t, err, "bytes did not contain a proper public key banner")
assert.Equal(t, rest, invalidPem)
// Fail due to ivalid PEM format, because
// it's missing the requisite pre-encapsulation boundary.
k, rest, curve, err = UnmarshalPublicKeyFromPEM(rest)
assert.Nil(t, k)
assert.Equal(t, rest, invalidPem)
require.EqualError(t, err, "input did not contain a valid PEM encoded block")
}

167
cert/sign.go Normal file
View File

@ -0,0 +1,167 @@
package cert
import (
"crypto/ecdsa"
"crypto/ed25519"
"crypto/elliptic"
"crypto/rand"
"crypto/sha256"
"fmt"
"math/big"
"net/netip"
"time"
)
// TBSCertificate represents a certificate intended to be signed.
// It is invalid to use this structure as a Certificate.
type TBSCertificate struct {
Version Version
Name string
Networks []netip.Prefix
UnsafeNetworks []netip.Prefix
Groups []string
IsCA bool
NotBefore time.Time
NotAfter time.Time
PublicKey []byte
Curve Curve
issuer string
}
type beingSignedCertificate interface {
// fromTBSCertificate copies the values from the TBSCertificate to this versions internal representation
// Implementations must validate the resulting certificate contains valid information
fromTBSCertificate(*TBSCertificate) error
// marshalForSigning returns the bytes that should be signed
marshalForSigning() ([]byte, error)
// setSignature sets the signature for the certificate that has just been signed. The signature must not be blank.
setSignature([]byte) error
}
type SignerLambda func(certBytes []byte) ([]byte, error)
// Sign will create a sealed certificate using details provided by the TBSCertificate as long as those
// details do not violate constraints of the signing certificate.
// If the TBSCertificate is a CA then signer must be nil.
func (t *TBSCertificate) Sign(signer Certificate, curve Curve, key []byte) (Certificate, error) {
switch t.Curve {
case Curve_CURVE25519:
pk := ed25519.PrivateKey(key)
sp := func(certBytes []byte) ([]byte, error) {
sig := ed25519.Sign(pk, certBytes)
return sig, nil
}
return t.SignWith(signer, curve, sp)
case Curve_P256:
pk := &ecdsa.PrivateKey{
PublicKey: ecdsa.PublicKey{
Curve: elliptic.P256(),
},
// ref: https://github.com/golang/go/blob/go1.19/src/crypto/x509/sec1.go#L95
D: new(big.Int).SetBytes(key),
}
// ref: https://github.com/golang/go/blob/go1.19/src/crypto/x509/sec1.go#L119
pk.X, pk.Y = pk.Curve.ScalarBaseMult(key)
sp := func(certBytes []byte) ([]byte, error) {
// We need to hash first for ECDSA
// - https://pkg.go.dev/crypto/ecdsa#SignASN1
hashed := sha256.Sum256(certBytes)
return ecdsa.SignASN1(rand.Reader, pk, hashed[:])
}
return t.SignWith(signer, curve, sp)
default:
return nil, fmt.Errorf("invalid curve: %s", t.Curve)
}
}
// SignWith does the same thing as sign, but uses the function in `sp` to calculate the signature.
// You should only use SignWith if you do not have direct access to your private key.
func (t *TBSCertificate) SignWith(signer Certificate, curve Curve, sp SignerLambda) (Certificate, error) {
if curve != t.Curve {
return nil, fmt.Errorf("curve in cert and private key supplied don't match")
}
if signer != nil {
if t.IsCA {
return nil, fmt.Errorf("can not sign a CA certificate with another")
}
err := checkCAConstraints(signer, t.NotBefore, t.NotAfter, t.Groups, t.Networks, t.UnsafeNetworks)
if err != nil {
return nil, err
}
issuer, err := signer.Fingerprint()
if err != nil {
return nil, fmt.Errorf("error computing issuer: %v", err)
}
t.issuer = issuer
} else {
if !t.IsCA {
return nil, fmt.Errorf("self signed certificates must have IsCA set to true")
}
}
var c beingSignedCertificate
switch t.Version {
case Version1:
c = &certificateV1{}
err := c.fromTBSCertificate(t)
if err != nil {
return nil, err
}
case Version2:
c = &certificateV2{}
err := c.fromTBSCertificate(t)
if err != nil {
return nil, err
}
default:
return nil, fmt.Errorf("unknown cert version %d", t.Version)
}
certBytes, err := c.marshalForSigning()
if err != nil {
return nil, err
}
sig, err := sp(certBytes)
if err != nil {
return nil, err
}
err = c.setSignature(sig)
if err != nil {
return nil, err
}
sc, ok := c.(Certificate)
if !ok {
return nil, fmt.Errorf("invalid certificate")
}
return sc, nil
}
func comparePrefix(a, b netip.Prefix) int {
addr := a.Addr().Compare(b.Addr())
if addr == 0 {
return a.Bits() - b.Bits()
}
return addr
}
// findDuplicatePrefix returns an error if there is a duplicate prefix in the pre-sorted input slice sortedPrefixes
func findDuplicatePrefix(sortedPrefixes []netip.Prefix) error {
if len(sortedPrefixes) < 2 {
return nil
}
for i := 1; i < len(sortedPrefixes); i++ {
if comparePrefix(sortedPrefixes[i], sortedPrefixes[i-1]) == 0 {
return NewErrInvalidCertificateProperties("duplicate network detected: %v", sortedPrefixes[i])
}
}
return nil
}

91
cert/sign_test.go Normal file
View File

@ -0,0 +1,91 @@
package cert
import (
"crypto/ecdsa"
"crypto/ed25519"
"crypto/elliptic"
"crypto/rand"
"net/netip"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestCertificateV1_Sign(t *testing.T) {
before := time.Now().Add(time.Second * -60).Round(time.Second)
after := time.Now().Add(time.Second * 60).Round(time.Second)
pubKey := []byte("1234567890abcedfghij1234567890ab")
tbs := TBSCertificate{
Version: Version1,
Name: "testing",
Networks: []netip.Prefix{
mustParsePrefixUnmapped("10.1.1.1/24"),
mustParsePrefixUnmapped("10.1.1.2/16"),
},
UnsafeNetworks: []netip.Prefix{
mustParsePrefixUnmapped("9.1.1.2/24"),
mustParsePrefixUnmapped("9.1.1.3/24"),
},
Groups: []string{"test-group1", "test-group2", "test-group3"},
NotBefore: before,
NotAfter: after,
PublicKey: pubKey,
IsCA: false,
}
pub, priv, err := ed25519.GenerateKey(rand.Reader)
c, err := tbs.Sign(&certificateV1{details: detailsV1{notBefore: before, notAfter: after}}, Curve_CURVE25519, priv)
require.NoError(t, err)
assert.NotNil(t, c)
assert.True(t, c.CheckSignature(pub))
b, err := c.Marshal()
require.NoError(t, err)
uc, err := unmarshalCertificateV1(b, nil)
require.NoError(t, err)
assert.NotNil(t, uc)
}
func TestCertificateV1_SignP256(t *testing.T) {
before := time.Now().Add(time.Second * -60).Round(time.Second)
after := time.Now().Add(time.Second * 60).Round(time.Second)
pubKey := []byte("01234567890abcedfghij1234567890ab1234567890abcedfghij1234567890ab")
tbs := TBSCertificate{
Version: Version1,
Name: "testing",
Networks: []netip.Prefix{
mustParsePrefixUnmapped("10.1.1.1/24"),
mustParsePrefixUnmapped("10.1.1.2/16"),
},
UnsafeNetworks: []netip.Prefix{
mustParsePrefixUnmapped("9.1.1.2/24"),
mustParsePrefixUnmapped("9.1.1.3/16"),
},
Groups: []string{"test-group1", "test-group2", "test-group3"},
NotBefore: before,
NotAfter: after,
PublicKey: pubKey,
IsCA: false,
Curve: Curve_P256,
}
priv, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
require.NoError(t, err)
pub := elliptic.Marshal(elliptic.P256(), priv.PublicKey.X, priv.PublicKey.Y)
rawPriv := priv.D.FillBytes(make([]byte, 32))
c, err := tbs.Sign(&certificateV1{details: detailsV1{notBefore: before, notAfter: after}}, Curve_P256, rawPriv)
require.NoError(t, err)
assert.NotNil(t, c)
assert.True(t, c.CheckSignature(pub))
b, err := c.Marshal()
require.NoError(t, err)
uc, err := unmarshalCertificateV1(b, nil)
require.NoError(t, err)
assert.NotNil(t, uc)
}

138
cert_test/cert.go Normal file
View File

@ -0,0 +1,138 @@
package cert_test
import (
"crypto/ecdh"
"crypto/ecdsa"
"crypto/elliptic"
"crypto/rand"
"io"
"net/netip"
"time"
"github.com/slackhq/nebula/cert"
"golang.org/x/crypto/curve25519"
"golang.org/x/crypto/ed25519"
)
// NewTestCaCert will create a new ca certificate
func NewTestCaCert(version cert.Version, curve cert.Curve, before, after time.Time, networks, unsafeNetworks []netip.Prefix, groups []string) (cert.Certificate, []byte, []byte, []byte) {
var err error
var pub, priv []byte
switch curve {
case cert.Curve_CURVE25519:
pub, priv, err = ed25519.GenerateKey(rand.Reader)
case cert.Curve_P256:
privk, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
if err != nil {
panic(err)
}
pub = elliptic.Marshal(elliptic.P256(), privk.PublicKey.X, privk.PublicKey.Y)
priv = privk.D.FillBytes(make([]byte, 32))
default:
// There is no default to allow the underlying lib to respond with an error
}
if before.IsZero() {
before = time.Now().Add(time.Second * -60).Round(time.Second)
}
if after.IsZero() {
after = time.Now().Add(time.Second * 60).Round(time.Second)
}
t := &cert.TBSCertificate{
Curve: curve,
Version: version,
Name: "test ca",
NotBefore: time.Unix(before.Unix(), 0),
NotAfter: time.Unix(after.Unix(), 0),
PublicKey: pub,
Networks: networks,
UnsafeNetworks: unsafeNetworks,
Groups: groups,
IsCA: true,
}
c, err := t.Sign(nil, curve, priv)
if err != nil {
panic(err)
}
pem, err := c.MarshalPEM()
if err != nil {
panic(err)
}
return c, pub, priv, pem
}
// NewTestCert will generate a signed certificate with the provided details.
// Expiry times are defaulted if you do not pass them in
func NewTestCert(v cert.Version, curve cert.Curve, ca cert.Certificate, key []byte, name string, before, after time.Time, networks, unsafeNetworks []netip.Prefix, groups []string) (cert.Certificate, []byte, []byte, []byte) {
if before.IsZero() {
before = time.Now().Add(time.Second * -60).Round(time.Second)
}
if after.IsZero() {
after = time.Now().Add(time.Second * 60).Round(time.Second)
}
var pub, priv []byte
switch curve {
case cert.Curve_CURVE25519:
pub, priv = X25519Keypair()
case cert.Curve_P256:
pub, priv = P256Keypair()
default:
panic("unknown curve")
}
nc := &cert.TBSCertificate{
Version: v,
Curve: curve,
Name: name,
Networks: networks,
UnsafeNetworks: unsafeNetworks,
Groups: groups,
NotBefore: time.Unix(before.Unix(), 0),
NotAfter: time.Unix(after.Unix(), 0),
PublicKey: pub,
IsCA: false,
}
c, err := nc.Sign(ca, ca.Curve(), key)
if err != nil {
panic(err)
}
pem, err := c.MarshalPEM()
if err != nil {
panic(err)
}
return c, pub, cert.MarshalPrivateKeyToPEM(curve, priv), pem
}
func X25519Keypair() ([]byte, []byte) {
privkey := make([]byte, 32)
if _, err := io.ReadFull(rand.Reader, privkey); err != nil {
panic(err)
}
pubkey, err := curve25519.X25519(privkey, curve25519.Basepoint)
if err != nil {
panic(err)
}
return pubkey, privkey
}
func P256Keypair() ([]byte, []byte) {
privkey, err := ecdh.P256().GenerateKey(rand.Reader)
if err != nil {
panic(err)
}
pubkey := privkey.PublicKey()
return pubkey.Bytes(), privkey.Bytes()
}

View File

@ -4,12 +4,11 @@ import (
"crypto/ecdsa"
"crypto/elliptic"
"crypto/rand"
"flag"
"fmt"
"io"
"math"
"net"
"net/netip"
"os"
"strings"
"time"
@ -28,34 +27,43 @@ type caFlags struct {
outCertPath *string
outQRPath *string
groups *string
ips *string
subnets *string
networks *string
unsafeNetworks *string
argonMemory *uint
argonIterations *uint
argonParallelism *uint
encryption *bool
version *uint
curve *string
p11url *string
// Deprecated options
ips *string
subnets *string
}
func newCaFlags() *caFlags {
cf := caFlags{set: flag.NewFlagSet("ca", flag.ContinueOnError)}
cf.set.Usage = func() {}
cf.name = cf.set.String("name", "", "Required: name of the certificate authority")
cf.version = cf.set.Uint("version", uint(cert.Version2), "Optional: version of the certificate format to use")
cf.duration = cf.set.Duration("duration", time.Duration(time.Hour*8760), "Optional: amount of time the certificate should be valid for. Valid time units are seconds: \"s\", minutes: \"m\", hours: \"h\"")
cf.outKeyPath = cf.set.String("out-key", "ca.key", "Optional: path to write the private key to")
cf.outCertPath = cf.set.String("out-crt", "ca.crt", "Optional: path to write the certificate to")
cf.outQRPath = cf.set.String("out-qr", "", "Optional: output a qr code image (png) of the certificate")
cf.groups = cf.set.String("groups", "", "Optional: comma separated list of groups. This will limit which groups subordinate certs can use")
cf.ips = cf.set.String("ips", "", "Optional: comma separated list of ipv4 address and network in CIDR notation. This will limit which ipv4 addresses and networks subordinate certs can use for ip addresses")
cf.subnets = cf.set.String("subnets", "", "Optional: comma separated list of ipv4 address and network in CIDR notation. This will limit which ipv4 addresses and networks subordinate certs can use in subnets")
cf.networks = cf.set.String("networks", "", "Optional: comma separated list of ip address and network in CIDR notation. This will limit which ip addresses and networks subordinate certs can use in networks")
cf.unsafeNetworks = cf.set.String("unsafe-networks", "", "Optional: comma separated list of ip address and network in CIDR notation. This will limit which ip addresses and networks subordinate certs can use in unsafe networks")
cf.argonMemory = cf.set.Uint("argon-memory", 2*1024*1024, "Optional: Argon2 memory parameter (in KiB) used for encrypted private key passphrase")
cf.argonParallelism = cf.set.Uint("argon-parallelism", 4, "Optional: Argon2 parallelism parameter used for encrypted private key passphrase")
cf.argonIterations = cf.set.Uint("argon-iterations", 1, "Optional: Argon2 iterations parameter used for encrypted private key passphrase")
cf.encryption = cf.set.Bool("encrypt", false, "Optional: prompt for passphrase and write out-key in an encrypted format")
cf.curve = cf.set.String("curve", "25519", "EdDSA/ECDSA Curve (25519, P256)")
cf.p11url = p11Flag(cf.set)
cf.ips = cf.set.String("ips", "", "Deprecated, see -networks")
cf.subnets = cf.set.String("subnets", "", "Deprecated, see -unsafe-networks")
return &cf
}
@ -114,38 +122,51 @@ func ca(args []string, out io.Writer, errOut io.Writer, pr PasswordReader) error
}
}
var ips []*net.IPNet
if *cf.ips != "" {
for _, rs := range strings.Split(*cf.ips, ",") {
rs := strings.Trim(rs, " ")
if rs != "" {
ip, ipNet, err := net.ParseCIDR(rs)
if err != nil {
return newHelpErrorf("invalid ip definition: %s", err)
}
if ip.To4() == nil {
return newHelpErrorf("invalid ip definition: can only be ipv4, have %s", rs)
version := cert.Version(*cf.version)
if version != cert.Version1 && version != cert.Version2 {
return newHelpErrorf("-version must be either %v or %v", cert.Version1, cert.Version2)
}
ipNet.IP = ip
ips = append(ips, ipNet)
var networks []netip.Prefix
if *cf.networks == "" && *cf.ips != "" {
// Pull up deprecated -ips flag if needed
*cf.networks = *cf.ips
}
if *cf.networks != "" {
for _, rs := range strings.Split(*cf.networks, ",") {
rs := strings.Trim(rs, " ")
if rs != "" {
n, err := netip.ParsePrefix(rs)
if err != nil {
return newHelpErrorf("invalid -networks definition: %s", rs)
}
if version == cert.Version1 && !n.Addr().Is4() {
return newHelpErrorf("invalid -networks definition: v1 certificates can only be ipv4, have %s", rs)
}
networks = append(networks, n)
}
}
}
var subnets []*net.IPNet
if *cf.subnets != "" {
for _, rs := range strings.Split(*cf.subnets, ",") {
var unsafeNetworks []netip.Prefix
if *cf.unsafeNetworks == "" && *cf.subnets != "" {
// Pull up deprecated -subnets flag if needed
*cf.unsafeNetworks = *cf.subnets
}
if *cf.unsafeNetworks != "" {
for _, rs := range strings.Split(*cf.unsafeNetworks, ",") {
rs := strings.Trim(rs, " ")
if rs != "" {
_, s, err := net.ParseCIDR(rs)
n, err := netip.ParsePrefix(rs)
if err != nil {
return newHelpErrorf("invalid subnet definition: %s", err)
return newHelpErrorf("invalid -unsafe-networks definition: %s", rs)
}
if s.IP.To4() == nil {
return newHelpErrorf("invalid subnet definition: can only be ipv4, have %s", rs)
if version == cert.Version1 && !n.Addr().Is4() {
return newHelpErrorf("invalid -unsafe-networks definition: v1 certificates can only be ipv4, have %s", rs)
}
subnets = append(subnets, s)
unsafeNetworks = append(unsafeNetworks, n)
}
}
}
@ -224,19 +245,17 @@ func ca(args []string, out io.Writer, errOut io.Writer, pr PasswordReader) error
}
}
nc := cert.NebulaCertificate{
Details: cert.NebulaCertificateDetails{
t := &cert.TBSCertificate{
Version: version,
Name: *cf.name,
Groups: groups,
Ips: ips,
Subnets: subnets,
Networks: networks,
UnsafeNetworks: unsafeNetworks,
NotBefore: time.Now(),
NotAfter: time.Now().Add(*cf.duration),
PublicKey: pub,
IsCA: true,
Curve: curve,
},
Pkcs11Backed: isP11,
}
if !isP11 {
@ -249,15 +268,16 @@ func ca(args []string, out io.Writer, errOut io.Writer, pr PasswordReader) error
return fmt.Errorf("refusing to overwrite existing CA cert: %s", *cf.outCertPath)
}
var c cert.Certificate
var b []byte
if isP11 {
err = nc.SignPkcs11(curve, p11Client)
c, err = t.SignWith(nil, curve, p11Client.SignASN1)
if err != nil {
return fmt.Errorf("error while signing with PKCS#11: %w", err)
}
} else {
err = nc.Sign(curve, rawPriv)
c, err = t.Sign(nil, curve, rawPriv)
if err != nil {
return fmt.Errorf("error while signing: %s", err)
}
@ -268,19 +288,16 @@ func ca(args []string, out io.Writer, errOut io.Writer, pr PasswordReader) error
return fmt.Errorf("error while encrypting out-key: %s", err)
}
} else {
b = cert.MarshalSigningPrivateKey(curve, rawPriv)
b = cert.MarshalSigningPrivateKeyToPEM(curve, rawPriv)
}
err = os.WriteFile(*cf.outKeyPath, b, 0600)
if err != nil {
return fmt.Errorf("error while writing out-key: %s", err)
}
if _, err := os.Stat(*cf.outCertPath); err == nil {
return fmt.Errorf("refusing to overwrite existing CA cert: %s", *cf.outCertPath)
}
}
b, err = nc.MarshalToPEM()
b, err = c.MarshalPEM()
if err != nil {
return fmt.Errorf("error while marshalling certificate: %s", err)
}

View File

@ -14,10 +14,9 @@ import (
"github.com/slackhq/nebula/cert"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
//TODO: test file permissions
func Test_caSummary(t *testing.T) {
assert.Equal(t, "ca <flags>: create a self signed certificate authority", caSummary())
}
@ -43,9 +42,11 @@ func Test_caHelp(t *testing.T) {
" -groups string\n"+
" \tOptional: comma separated list of groups. This will limit which groups subordinate certs can use\n"+
" -ips string\n"+
" \tOptional: comma separated list of ipv4 address and network in CIDR notation. This will limit which ipv4 addresses and networks subordinate certs can use for ip addresses\n"+
" Deprecated, see -networks\n"+
" -name string\n"+
" \tRequired: name of the certificate authority\n"+
" -networks string\n"+
" \tOptional: comma separated list of ip address and network in CIDR notation. This will limit which ip addresses and networks subordinate certs can use in networks\n"+
" -out-crt string\n"+
" \tOptional: path to write the certificate to (default \"ca.crt\")\n"+
" -out-key string\n"+
@ -54,7 +55,11 @@ func Test_caHelp(t *testing.T) {
" \tOptional: output a qr code image (png) of the certificate\n"+
optionalPkcs11String(" -pkcs11 string\n \tOptional: PKCS#11 URI to an existing private key\n")+
" -subnets string\n"+
" \tOptional: comma separated list of ipv4 address and network in CIDR notation. This will limit which ipv4 addresses and networks subordinate certs can use in subnets\n",
" \tDeprecated, see -unsafe-networks\n"+
" -unsafe-networks string\n"+
" \tOptional: comma separated list of ip address and network in CIDR notation. This will limit which ip addresses and networks subordinate certs can use in unsafe networks\n"+
" -version uint\n"+
" \tOptional: version of the certificate format to use (default 2)\n",
ob.String(),
)
}
@ -83,85 +88,86 @@ func Test_ca(t *testing.T) {
// required args
assertHelpError(t, ca(
[]string{"-out-key", "nope", "-out-crt", "nope", "duration", "100m"}, ob, eb, nopw,
[]string{"-version", "1", "-out-key", "nope", "-out-crt", "nope", "duration", "100m"}, ob, eb, nopw,
), "-name is required")
assert.Equal(t, "", ob.String())
assert.Equal(t, "", eb.String())
// ipv4 only ips
assertHelpError(t, ca([]string{"-name", "ipv6", "-ips", "100::100/100"}, ob, eb, nopw), "invalid ip definition: can only be ipv4, have 100::100/100")
assertHelpError(t, ca([]string{"-version", "1", "-name", "ipv6", "-ips", "100::100/100"}, ob, eb, nopw), "invalid -networks definition: v1 certificates can only be ipv4, have 100::100/100")
assert.Equal(t, "", ob.String())
assert.Equal(t, "", eb.String())
// ipv4 only subnets
assertHelpError(t, ca([]string{"-name", "ipv6", "-subnets", "100::100/100"}, ob, eb, nopw), "invalid subnet definition: can only be ipv4, have 100::100/100")
assertHelpError(t, ca([]string{"-version", "1", "-name", "ipv6", "-subnets", "100::100/100"}, ob, eb, nopw), "invalid -unsafe-networks definition: v1 certificates can only be ipv4, have 100::100/100")
assert.Equal(t, "", ob.String())
assert.Equal(t, "", eb.String())
// failed key write
ob.Reset()
eb.Reset()
args := []string{"-name", "test", "-duration", "100m", "-out-crt", "/do/not/write/pleasecrt", "-out-key", "/do/not/write/pleasekey"}
assert.EqualError(t, ca(args, ob, eb, nopw), "error while writing out-key: open /do/not/write/pleasekey: "+NoSuchDirError)
args := []string{"-version", "1", "-name", "test", "-duration", "100m", "-out-crt", "/do/not/write/pleasecrt", "-out-key", "/do/not/write/pleasekey"}
require.EqualError(t, ca(args, ob, eb, nopw), "error while writing out-key: open /do/not/write/pleasekey: "+NoSuchDirError)
assert.Equal(t, "", ob.String())
assert.Equal(t, "", eb.String())
// create temp key file
keyF, err := os.CreateTemp("", "test.key")
assert.Nil(t, err)
os.Remove(keyF.Name())
require.NoError(t, err)
require.NoError(t, os.Remove(keyF.Name()))
// failed cert write
ob.Reset()
eb.Reset()
args = []string{"-name", "test", "-duration", "100m", "-out-crt", "/do/not/write/pleasecrt", "-out-key", keyF.Name()}
assert.EqualError(t, ca(args, ob, eb, nopw), "error while writing out-crt: open /do/not/write/pleasecrt: "+NoSuchDirError)
args = []string{"-version", "1", "-name", "test", "-duration", "100m", "-out-crt", "/do/not/write/pleasecrt", "-out-key", keyF.Name()}
require.EqualError(t, ca(args, ob, eb, nopw), "error while writing out-crt: open /do/not/write/pleasecrt: "+NoSuchDirError)
assert.Equal(t, "", ob.String())
assert.Equal(t, "", eb.String())
// create temp cert file
crtF, err := os.CreateTemp("", "test.crt")
assert.Nil(t, err)
os.Remove(crtF.Name())
os.Remove(keyF.Name())
require.NoError(t, err)
require.NoError(t, os.Remove(crtF.Name()))
require.NoError(t, os.Remove(keyF.Name()))
// test proper cert with removed empty groups and subnets
ob.Reset()
eb.Reset()
args = []string{"-name", "test", "-duration", "100m", "-groups", "1,, 2 , ,,,3,4,5", "-out-crt", crtF.Name(), "-out-key", keyF.Name()}
assert.Nil(t, ca(args, ob, eb, nopw))
args = []string{"-version", "1", "-name", "test", "-duration", "100m", "-groups", "1,, 2 , ,,,3,4,5", "-out-crt", crtF.Name(), "-out-key", keyF.Name()}
require.NoError(t, ca(args, ob, eb, nopw))
assert.Equal(t, "", ob.String())
assert.Equal(t, "", eb.String())
// read cert and key files
rb, _ := os.ReadFile(keyF.Name())
lKey, b, err := cert.UnmarshalEd25519PrivateKey(rb)
assert.Len(t, b, 0)
assert.Nil(t, err)
lKey, b, c, err := cert.UnmarshalSigningPrivateKeyFromPEM(rb)
assert.Equal(t, cert.Curve_CURVE25519, c)
assert.Empty(t, b)
require.NoError(t, err)
assert.Len(t, lKey, 64)
rb, _ = os.ReadFile(crtF.Name())
lCrt, b, err := cert.UnmarshalNebulaCertificateFromPEM(rb)
assert.Len(t, b, 0)
assert.Nil(t, err)
lCrt, b, err := cert.UnmarshalCertificateFromPEM(rb)
assert.Empty(t, b)
require.NoError(t, err)
assert.Equal(t, "test", lCrt.Details.Name)
assert.Len(t, lCrt.Details.Ips, 0)
assert.True(t, lCrt.Details.IsCA)
assert.Equal(t, []string{"1", "2", "3", "4", "5"}, lCrt.Details.Groups)
assert.Len(t, lCrt.Details.Subnets, 0)
assert.Len(t, lCrt.Details.PublicKey, 32)
assert.Equal(t, time.Duration(time.Minute*100), lCrt.Details.NotAfter.Sub(lCrt.Details.NotBefore))
assert.Equal(t, "", lCrt.Details.Issuer)
assert.True(t, lCrt.CheckSignature(lCrt.Details.PublicKey))
assert.Equal(t, "test", lCrt.Name())
assert.Empty(t, lCrt.Networks())
assert.True(t, lCrt.IsCA())
assert.Equal(t, []string{"1", "2", "3", "4", "5"}, lCrt.Groups())
assert.Empty(t, lCrt.UnsafeNetworks())
assert.Len(t, lCrt.PublicKey(), 32)
assert.Equal(t, time.Duration(time.Minute*100), lCrt.NotAfter().Sub(lCrt.NotBefore()))
assert.Equal(t, "", lCrt.Issuer())
assert.True(t, lCrt.CheckSignature(lCrt.PublicKey()))
// test encrypted key
os.Remove(keyF.Name())
os.Remove(crtF.Name())
ob.Reset()
eb.Reset()
args = []string{"-encrypt", "-name", "test", "-duration", "100m", "-groups", "1,2,3,4,5", "-out-crt", crtF.Name(), "-out-key", keyF.Name()}
assert.Nil(t, ca(args, ob, eb, testpw))
args = []string{"-version", "1", "-encrypt", "-name", "test", "-duration", "100m", "-groups", "1,2,3,4,5", "-out-crt", crtF.Name(), "-out-key", keyF.Name()}
require.NoError(t, ca(args, ob, eb, testpw))
assert.Equal(t, pwPromptOb, ob.String())
assert.Equal(t, "", eb.String())
@ -169,7 +175,7 @@ func Test_ca(t *testing.T) {
rb, _ = os.ReadFile(keyF.Name())
k, _ := pem.Decode(rb)
ned, err := cert.UnmarshalNebulaEncryptedData(k.Bytes)
assert.Nil(t, err)
require.NoError(t, err)
// we won't know salt in advance, so just check start of string
assert.Equal(t, uint32(2*1024*1024), ned.EncryptionMetadata.Argon2Parameters.Memory)
assert.Equal(t, uint8(4), ned.EncryptionMetadata.Argon2Parameters.Parallelism)
@ -179,8 +185,8 @@ func Test_ca(t *testing.T) {
var curve cert.Curve
curve, lKey, b, err = cert.DecryptAndUnmarshalSigningPrivateKey(passphrase, rb)
assert.Equal(t, cert.Curve_CURVE25519, curve)
assert.Nil(t, err)
assert.Len(t, b, 0)
require.NoError(t, err)
assert.Empty(t, b)
assert.Len(t, lKey, 64)
// test when reading passsword results in an error
@ -188,8 +194,8 @@ func Test_ca(t *testing.T) {
os.Remove(crtF.Name())
ob.Reset()
eb.Reset()
args = []string{"-encrypt", "-name", "test", "-duration", "100m", "-groups", "1,2,3,4,5", "-out-crt", crtF.Name(), "-out-key", keyF.Name()}
assert.Error(t, ca(args, ob, eb, errpw))
args = []string{"-version", "1", "-encrypt", "-name", "test", "-duration", "100m", "-groups", "1,2,3,4,5", "-out-crt", crtF.Name(), "-out-key", keyF.Name()}
require.Error(t, ca(args, ob, eb, errpw))
assert.Equal(t, pwPromptOb, ob.String())
assert.Equal(t, "", eb.String())
@ -198,8 +204,8 @@ func Test_ca(t *testing.T) {
os.Remove(crtF.Name())
ob.Reset()
eb.Reset()
args = []string{"-encrypt", "-name", "test", "-duration", "100m", "-groups", "1,2,3,4,5", "-out-crt", crtF.Name(), "-out-key", keyF.Name()}
assert.EqualError(t, ca(args, ob, eb, nopw), "no passphrase specified, remove -encrypt flag to write out-key in plaintext")
args = []string{"-version", "1", "-encrypt", "-name", "test", "-duration", "100m", "-groups", "1,2,3,4,5", "-out-crt", crtF.Name(), "-out-key", keyF.Name()}
require.EqualError(t, ca(args, ob, eb, nopw), "no passphrase specified, remove -encrypt flag to write out-key in plaintext")
assert.Equal(t, strings.Repeat(pwPromptOb, 5), ob.String()) // prompts 5 times before giving up
assert.Equal(t, "", eb.String())
@ -208,14 +214,14 @@ func Test_ca(t *testing.T) {
os.Remove(crtF.Name())
ob.Reset()
eb.Reset()
args = []string{"-name", "test", "-duration", "100m", "-groups", "1,, 2 , ,,,3,4,5", "-out-crt", crtF.Name(), "-out-key", keyF.Name()}
assert.Nil(t, ca(args, ob, eb, nopw))
args = []string{"-version", "1", "-name", "test", "-duration", "100m", "-groups", "1,, 2 , ,,,3,4,5", "-out-crt", crtF.Name(), "-out-key", keyF.Name()}
require.NoError(t, ca(args, ob, eb, nopw))
// test that we won't overwrite existing certificate file
ob.Reset()
eb.Reset()
args = []string{"-name", "test", "-duration", "100m", "-groups", "1,, 2 , ,,,3,4,5", "-out-crt", crtF.Name(), "-out-key", keyF.Name()}
assert.EqualError(t, ca(args, ob, eb, nopw), "refusing to overwrite existing CA key: "+keyF.Name())
args = []string{"-version", "1", "-name", "test", "-duration", "100m", "-groups", "1,, 2 , ,,,3,4,5", "-out-crt", crtF.Name(), "-out-key", keyF.Name()}
require.EqualError(t, ca(args, ob, eb, nopw), "refusing to overwrite existing CA key: "+keyF.Name())
assert.Equal(t, "", ob.String())
assert.Equal(t, "", eb.String())
@ -223,8 +229,8 @@ func Test_ca(t *testing.T) {
os.Remove(keyF.Name())
ob.Reset()
eb.Reset()
args = []string{"-name", "test", "-duration", "100m", "-groups", "1,, 2 , ,,,3,4,5", "-out-crt", crtF.Name(), "-out-key", keyF.Name()}
assert.EqualError(t, ca(args, ob, eb, nopw), "refusing to overwrite existing CA cert: "+crtF.Name())
args = []string{"-version", "1", "-name", "test", "-duration", "100m", "-groups", "1,, 2 , ,,,3,4,5", "-out-crt", crtF.Name(), "-out-key", keyF.Name()}
require.EqualError(t, ca(args, ob, eb, nopw), "refusing to overwrite existing CA cert: "+crtF.Name())
assert.Equal(t, "", ob.String())
assert.Equal(t, "", eb.String())
os.Remove(keyF.Name())

View File

@ -82,12 +82,12 @@ func keygen(args []string, out io.Writer, errOut io.Writer) error {
return fmt.Errorf("error while getting public key: %w", err)
}
} else {
err = os.WriteFile(*cf.outKeyPath, cert.MarshalPrivateKey(curve, rawPriv), 0600)
err = os.WriteFile(*cf.outKeyPath, cert.MarshalPrivateKeyToPEM(curve, rawPriv), 0600)
if err != nil {
return fmt.Errorf("error while writing out-key: %s", err)
}
}
err = os.WriteFile(*cf.outPubPath, cert.MarshalPublicKey(curve, pub), 0600)
err = os.WriteFile(*cf.outPubPath, cert.MarshalPublicKeyToPEM(curve, pub), 0600)
if err != nil {
return fmt.Errorf("error while writing out-pub: %s", err)
}

View File

@ -7,10 +7,9 @@ import (
"github.com/slackhq/nebula/cert"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
//TODO: test file permissions
func Test_keygenSummary(t *testing.T) {
assert.Equal(t, "keygen <flags>: create a public/private key pair. the public key can be passed to `nebula-cert sign`", keygenSummary())
}
@ -49,46 +48,48 @@ func Test_keygen(t *testing.T) {
ob.Reset()
eb.Reset()
args := []string{"-out-pub", "/do/not/write/pleasepub", "-out-key", "/do/not/write/pleasekey"}
assert.EqualError(t, keygen(args, ob, eb), "error while writing out-key: open /do/not/write/pleasekey: "+NoSuchDirError)
require.EqualError(t, keygen(args, ob, eb), "error while writing out-key: open /do/not/write/pleasekey: "+NoSuchDirError)
assert.Equal(t, "", ob.String())
assert.Equal(t, "", eb.String())
// create temp key file
keyF, err := os.CreateTemp("", "test.key")
assert.Nil(t, err)
require.NoError(t, err)
defer os.Remove(keyF.Name())
// failed pub write
ob.Reset()
eb.Reset()
args = []string{"-out-pub", "/do/not/write/pleasepub", "-out-key", keyF.Name()}
assert.EqualError(t, keygen(args, ob, eb), "error while writing out-pub: open /do/not/write/pleasepub: "+NoSuchDirError)
require.EqualError(t, keygen(args, ob, eb), "error while writing out-pub: open /do/not/write/pleasepub: "+NoSuchDirError)
assert.Equal(t, "", ob.String())
assert.Equal(t, "", eb.String())
// create temp pub file
pubF, err := os.CreateTemp("", "test.pub")
assert.Nil(t, err)
require.NoError(t, err)
defer os.Remove(pubF.Name())
// test proper keygen
ob.Reset()
eb.Reset()
args = []string{"-out-pub", pubF.Name(), "-out-key", keyF.Name()}
assert.Nil(t, keygen(args, ob, eb))
require.NoError(t, keygen(args, ob, eb))
assert.Equal(t, "", ob.String())
assert.Equal(t, "", eb.String())
// read cert and key files
rb, _ := os.ReadFile(keyF.Name())
lKey, b, err := cert.UnmarshalX25519PrivateKey(rb)
assert.Len(t, b, 0)
assert.Nil(t, err)
lKey, b, curve, err := cert.UnmarshalPrivateKeyFromPEM(rb)
assert.Equal(t, cert.Curve_CURVE25519, curve)
assert.Empty(t, b)
require.NoError(t, err)
assert.Len(t, lKey, 32)
rb, _ = os.ReadFile(pubF.Name())
lPub, b, err := cert.UnmarshalX25519PublicKey(rb)
assert.Len(t, b, 0)
assert.Nil(t, err)
lPub, b, curve, err := cert.UnmarshalPublicKeyFromPEM(rb)
assert.Equal(t, cert.Curve_CURVE25519, curve)
assert.Empty(t, b)
require.NoError(t, err)
assert.Len(t, lPub, 32)
}

View File

@ -9,10 +9,9 @@ import (
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
//TODO: all flag parsing continueOnError will print to stderr on its own currently
func Test_help(t *testing.T) {
expected := "Usage of " + os.Args[0] + " <global flags> <mode>:\n" +
" Global flags:\n" +
@ -81,7 +80,7 @@ func assertHelpError(t *testing.T, err error, msg string) {
t.Fatal(fmt.Sprintf("err was not a helpError: %q, expected %q", err, msg))
}
assert.EqualError(t, err, msg)
require.EqualError(t, err, msg)
}
func optionalPkcs11String(msg string) string {

View File

@ -45,28 +45,27 @@ func printCert(args []string, out io.Writer, errOut io.Writer) error {
return fmt.Errorf("unable to read cert; %s", err)
}
var c *cert.NebulaCertificate
var c cert.Certificate
var qrBytes []byte
part := 0
var jsonCerts []cert.Certificate
for {
c, rawCert, err = cert.UnmarshalNebulaCertificateFromPEM(rawCert)
c, rawCert, err = cert.UnmarshalCertificateFromPEM(rawCert)
if err != nil {
return fmt.Errorf("error while unmarshaling cert: %s", err)
}
if *pf.json {
b, _ := json.Marshal(c)
out.Write(b)
out.Write([]byte("\n"))
jsonCerts = append(jsonCerts, c)
} else {
out.Write([]byte(c.String()))
out.Write([]byte("\n"))
_, _ = out.Write([]byte(c.String()))
_, _ = out.Write([]byte("\n"))
}
if *pf.outQRPath != "" {
b, err := c.MarshalToPEM()
b, err := c.MarshalPEM()
if err != nil {
return fmt.Errorf("error while marshalling cert to PEM: %s", err)
}
@ -80,6 +79,12 @@ func printCert(args []string, out io.Writer, errOut io.Writer) error {
part++
}
if *pf.json {
b, _ := json.Marshal(jsonCerts)
_, _ = out.Write(b)
_, _ = out.Write([]byte("\n"))
}
if *pf.outQRPath != "" {
b, err := qrcode.Encode(string(qrBytes), qrcode.Medium, -5)
if err != nil {

View File

@ -2,12 +2,17 @@ package main
import (
"bytes"
"crypto/ed25519"
"crypto/rand"
"encoding/hex"
"net/netip"
"os"
"testing"
"time"
"github.com/slackhq/nebula/cert"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func Test_printSummary(t *testing.T) {
@ -48,45 +53,106 @@ func Test_printCert(t *testing.T) {
err = printCert([]string{"-path", "does_not_exist"}, ob, eb)
assert.Equal(t, "", ob.String())
assert.Equal(t, "", eb.String())
assert.EqualError(t, err, "unable to read cert; open does_not_exist: "+NoSuchFileError)
require.EqualError(t, err, "unable to read cert; open does_not_exist: "+NoSuchFileError)
// invalid cert at path
ob.Reset()
eb.Reset()
tf, err := os.CreateTemp("", "print-cert")
assert.Nil(t, err)
require.NoError(t, err)
defer os.Remove(tf.Name())
tf.WriteString("-----BEGIN NOPE-----")
err = printCert([]string{"-path", tf.Name()}, ob, eb)
assert.Equal(t, "", ob.String())
assert.Equal(t, "", eb.String())
assert.EqualError(t, err, "error while unmarshaling cert: input did not contain a valid PEM encoded block")
require.EqualError(t, err, "error while unmarshaling cert: input did not contain a valid PEM encoded block")
// test multiple certs
ob.Reset()
eb.Reset()
tf.Truncate(0)
tf.Seek(0, 0)
c := cert.NebulaCertificate{
Details: cert.NebulaCertificateDetails{
Name: "test",
Groups: []string{"hi"},
PublicKey: []byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2},
},
Signature: []byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2},
}
ca, caKey := NewTestCaCert("test ca", nil, nil, time.Time{}, time.Time{}, nil, nil, nil)
c, _ := NewTestCert(ca, caKey, "test", time.Time{}, time.Time{}, []netip.Prefix{netip.MustParsePrefix("10.0.0.123/8")}, nil, []string{"hi"})
p, _ := c.MarshalToPEM()
p, _ := c.MarshalPEM()
tf.Write(p)
tf.Write(p)
tf.Write(p)
err = printCert([]string{"-path", tf.Name()}, ob, eb)
assert.Nil(t, err)
fp, _ := c.Fingerprint()
pk := hex.EncodeToString(c.PublicKey())
sig := hex.EncodeToString(c.Signature())
require.NoError(t, err)
assert.Equal(
t,
"NebulaCertificate {\n\tDetails {\n\t\tName: test\n\t\tIps: []\n\t\tSubnets: []\n\t\tGroups: [\n\t\t\t\"hi\"\n\t\t]\n\t\tNot before: 0001-01-01 00:00:00 +0000 UTC\n\t\tNot After: 0001-01-01 00:00:00 +0000 UTC\n\t\tIs CA: false\n\t\tIssuer: \n\t\tPublic key: 0102030405060708090001020304050607080900010203040506070809000102\n\t\tCurve: CURVE25519\n\t}\n\tFingerprint: cc3492c0e9c48f17547f5987ea807462ebb3451e622590a10bb3763c344c82bd\n\tSignature: 0102030405060708090001020304050607080900010203040506070809000102\n}\nNebulaCertificate {\n\tDetails {\n\t\tName: test\n\t\tIps: []\n\t\tSubnets: []\n\t\tGroups: [\n\t\t\t\"hi\"\n\t\t]\n\t\tNot before: 0001-01-01 00:00:00 +0000 UTC\n\t\tNot After: 0001-01-01 00:00:00 +0000 UTC\n\t\tIs CA: false\n\t\tIssuer: \n\t\tPublic key: 0102030405060708090001020304050607080900010203040506070809000102\n\t\tCurve: CURVE25519\n\t}\n\tFingerprint: cc3492c0e9c48f17547f5987ea807462ebb3451e622590a10bb3763c344c82bd\n\tSignature: 0102030405060708090001020304050607080900010203040506070809000102\n}\nNebulaCertificate {\n\tDetails {\n\t\tName: test\n\t\tIps: []\n\t\tSubnets: []\n\t\tGroups: [\n\t\t\t\"hi\"\n\t\t]\n\t\tNot before: 0001-01-01 00:00:00 +0000 UTC\n\t\tNot After: 0001-01-01 00:00:00 +0000 UTC\n\t\tIs CA: false\n\t\tIssuer: \n\t\tPublic key: 0102030405060708090001020304050607080900010203040506070809000102\n\t\tCurve: CURVE25519\n\t}\n\tFingerprint: cc3492c0e9c48f17547f5987ea807462ebb3451e622590a10bb3763c344c82bd\n\tSignature: 0102030405060708090001020304050607080900010203040506070809000102\n}\n",
//"NebulaCertificate {\n\tDetails {\n\t\tName: test\n\t\tIps: []\n\t\tSubnets: []\n\t\tGroups: [\n\t\t\t\"hi\"\n\t\t]\n\t\tNot before: 0001-01-01 00:00:00 +0000 UTC\n\t\tNot After: 0001-01-01 00:00:00 +0000 UTC\n\t\tIs CA: false\n\t\tIssuer: "+c.Issuer()+"\n\t\tPublic key: "+pk+"\n\t\tCurve: CURVE25519\n\t}\n\tFingerprint: "+fp+"\n\tSignature: "+sig+"\n}\nNebulaCertificate {\n\tDetails {\n\t\tName: test\n\t\tIps: []\n\t\tSubnets: []\n\t\tGroups: [\n\t\t\t\"hi\"\n\t\t]\n\t\tNot before: 0001-01-01 00:00:00 +0000 UTC\n\t\tNot After: 0001-01-01 00:00:00 +0000 UTC\n\t\tIs CA: false\n\t\tIssuer: "+c.Issuer()+"\n\t\tPublic key: "+pk+"\n\t\tCurve: CURVE25519\n\t}\n\tFingerprint: "+fp+"\n\tSignature: "+sig+"\n}\nNebulaCertificate {\n\tDetails {\n\t\tName: test\n\t\tIps: []\n\t\tSubnets: []\n\t\tGroups: [\n\t\t\t\"hi\"\n\t\t]\n\t\tNot before: 0001-01-01 00:00:00 +0000 UTC\n\t\tNot After: 0001-01-01 00:00:00 +0000 UTC\n\t\tIs CA: false\n\t\tIssuer: "+c.Issuer()+"\n\t\tPublic key: "+pk+"\n\t\tCurve: CURVE25519\n\t}\n\tFingerprint: "+fp+"\n\tSignature: "+sig+"\n}\n",
`{
"details": {
"curve": "CURVE25519",
"groups": [
"hi"
],
"isCa": false,
"issuer": "`+c.Issuer()+`",
"name": "test",
"networks": [
"10.0.0.123/8"
],
"notAfter": "0001-01-01T00:00:00Z",
"notBefore": "0001-01-01T00:00:00Z",
"publicKey": "`+pk+`",
"unsafeNetworks": []
},
"fingerprint": "`+fp+`",
"signature": "`+sig+`",
"version": 1
}
{
"details": {
"curve": "CURVE25519",
"groups": [
"hi"
],
"isCa": false,
"issuer": "`+c.Issuer()+`",
"name": "test",
"networks": [
"10.0.0.123/8"
],
"notAfter": "0001-01-01T00:00:00Z",
"notBefore": "0001-01-01T00:00:00Z",
"publicKey": "`+pk+`",
"unsafeNetworks": []
},
"fingerprint": "`+fp+`",
"signature": "`+sig+`",
"version": 1
}
{
"details": {
"curve": "CURVE25519",
"groups": [
"hi"
],
"isCa": false,
"issuer": "`+c.Issuer()+`",
"name": "test",
"networks": [
"10.0.0.123/8"
],
"notAfter": "0001-01-01T00:00:00Z",
"notBefore": "0001-01-01T00:00:00Z",
"publicKey": "`+pk+`",
"unsafeNetworks": []
},
"fingerprint": "`+fp+`",
"signature": "`+sig+`",
"version": 1
}
`,
ob.String(),
)
assert.Equal(t, "", eb.String())
@ -96,26 +162,84 @@ func Test_printCert(t *testing.T) {
eb.Reset()
tf.Truncate(0)
tf.Seek(0, 0)
c = cert.NebulaCertificate{
Details: cert.NebulaCertificateDetails{
Name: "test",
Groups: []string{"hi"},
PublicKey: []byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2},
},
Signature: []byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2},
}
p, _ = c.MarshalToPEM()
tf.Write(p)
tf.Write(p)
tf.Write(p)
err = printCert([]string{"-json", "-path", tf.Name()}, ob, eb)
assert.Nil(t, err)
fp, _ = c.Fingerprint()
pk = hex.EncodeToString(c.PublicKey())
sig = hex.EncodeToString(c.Signature())
require.NoError(t, err)
assert.Equal(
t,
"{\"details\":{\"curve\":\"CURVE25519\",\"groups\":[\"hi\"],\"ips\":[],\"isCa\":false,\"issuer\":\"\",\"name\":\"test\",\"notAfter\":\"0001-01-01T00:00:00Z\",\"notBefore\":\"0001-01-01T00:00:00Z\",\"publicKey\":\"0102030405060708090001020304050607080900010203040506070809000102\",\"subnets\":[]},\"fingerprint\":\"cc3492c0e9c48f17547f5987ea807462ebb3451e622590a10bb3763c344c82bd\",\"signature\":\"0102030405060708090001020304050607080900010203040506070809000102\"}\n{\"details\":{\"curve\":\"CURVE25519\",\"groups\":[\"hi\"],\"ips\":[],\"isCa\":false,\"issuer\":\"\",\"name\":\"test\",\"notAfter\":\"0001-01-01T00:00:00Z\",\"notBefore\":\"0001-01-01T00:00:00Z\",\"publicKey\":\"0102030405060708090001020304050607080900010203040506070809000102\",\"subnets\":[]},\"fingerprint\":\"cc3492c0e9c48f17547f5987ea807462ebb3451e622590a10bb3763c344c82bd\",\"signature\":\"0102030405060708090001020304050607080900010203040506070809000102\"}\n{\"details\":{\"curve\":\"CURVE25519\",\"groups\":[\"hi\"],\"ips\":[],\"isCa\":false,\"issuer\":\"\",\"name\":\"test\",\"notAfter\":\"0001-01-01T00:00:00Z\",\"notBefore\":\"0001-01-01T00:00:00Z\",\"publicKey\":\"0102030405060708090001020304050607080900010203040506070809000102\",\"subnets\":[]},\"fingerprint\":\"cc3492c0e9c48f17547f5987ea807462ebb3451e622590a10bb3763c344c82bd\",\"signature\":\"0102030405060708090001020304050607080900010203040506070809000102\"}\n",
`[{"details":{"curve":"CURVE25519","groups":["hi"],"isCa":false,"issuer":"`+c.Issuer()+`","name":"test","networks":["10.0.0.123/8"],"notAfter":"0001-01-01T00:00:00Z","notBefore":"0001-01-01T00:00:00Z","publicKey":"`+pk+`","unsafeNetworks":[]},"fingerprint":"`+fp+`","signature":"`+sig+`","version":1},{"details":{"curve":"CURVE25519","groups":["hi"],"isCa":false,"issuer":"`+c.Issuer()+`","name":"test","networks":["10.0.0.123/8"],"notAfter":"0001-01-01T00:00:00Z","notBefore":"0001-01-01T00:00:00Z","publicKey":"`+pk+`","unsafeNetworks":[]},"fingerprint":"`+fp+`","signature":"`+sig+`","version":1},{"details":{"curve":"CURVE25519","groups":["hi"],"isCa":false,"issuer":"`+c.Issuer()+`","name":"test","networks":["10.0.0.123/8"],"notAfter":"0001-01-01T00:00:00Z","notBefore":"0001-01-01T00:00:00Z","publicKey":"`+pk+`","unsafeNetworks":[]},"fingerprint":"`+fp+`","signature":"`+sig+`","version":1}]
`,
ob.String(),
)
assert.Equal(t, "", eb.String())
}
// NewTestCaCert will generate a CA cert
func NewTestCaCert(name string, pubKey, privKey []byte, before, after time.Time, networks, unsafeNetworks []netip.Prefix, groups []string) (cert.Certificate, []byte) {
var err error
if pubKey == nil || privKey == nil {
pubKey, privKey, err = ed25519.GenerateKey(rand.Reader)
if err != nil {
panic(err)
}
}
t := &cert.TBSCertificate{
Version: cert.Version1,
Name: name,
NotBefore: time.Unix(before.Unix(), 0),
NotAfter: time.Unix(after.Unix(), 0),
PublicKey: pubKey,
Networks: networks,
UnsafeNetworks: unsafeNetworks,
Groups: groups,
IsCA: true,
}
c, err := t.Sign(nil, cert.Curve_CURVE25519, privKey)
if err != nil {
panic(err)
}
return c, privKey
}
func NewTestCert(ca cert.Certificate, signerKey []byte, name string, before, after time.Time, networks, unsafeNetworks []netip.Prefix, groups []string) (cert.Certificate, []byte) {
if before.IsZero() {
before = ca.NotBefore()
}
if after.IsZero() {
after = ca.NotAfter()
}
if len(networks) == 0 {
networks = []netip.Prefix{netip.MustParsePrefix("10.0.0.123/8")}
}
pub, rawPriv := x25519Keypair()
nc := &cert.TBSCertificate{
Version: cert.Version1,
Name: name,
Networks: networks,
UnsafeNetworks: unsafeNetworks,
Groups: groups,
NotBefore: time.Unix(before.Unix(), 0),
NotAfter: time.Unix(after.Unix(), 0),
PublicKey: pub,
IsCA: false,
}
c, err := nc.Sign(ca, ca.Curve(), signerKey)
if err != nil {
panic(err)
}
return c, rawPriv
}

View File

@ -3,10 +3,11 @@ package main
import (
"crypto/ecdh"
"crypto/rand"
"errors"
"flag"
"fmt"
"io"
"net"
"net/netip"
"os"
"strings"
"time"
@ -19,35 +20,45 @@ import (
type signFlags struct {
set *flag.FlagSet
version *uint
caKeyPath *string
caCertPath *string
name *string
ip *string
networks *string
unsafeNetworks *string
duration *time.Duration
inPubPath *string
outKeyPath *string
outCertPath *string
outQRPath *string
groups *string
subnets *string
p11url *string
// Deprecated options
ip *string
subnets *string
}
func newSignFlags() *signFlags {
sf := signFlags{set: flag.NewFlagSet("sign", flag.ContinueOnError)}
sf.set.Usage = func() {}
sf.version = sf.set.Uint("version", 0, "Optional: version of the certificate format to use, the default is to create both v1 and v2 certificates.")
sf.caKeyPath = sf.set.String("ca-key", "ca.key", "Optional: path to the signing CA key")
sf.caCertPath = sf.set.String("ca-crt", "ca.crt", "Optional: path to the signing CA cert")
sf.name = sf.set.String("name", "", "Required: name of the cert, usually a hostname")
sf.ip = sf.set.String("ip", "", "Required: ipv4 address and network in CIDR notation to assign the cert")
sf.networks = sf.set.String("networks", "", "Required: comma separated list of ip address and network in CIDR notation to assign to this cert")
sf.unsafeNetworks = sf.set.String("unsafe-networks", "", "Optional: comma separated list of ip address and network in CIDR notation. Unsafe networks this cert can route for")
sf.duration = sf.set.Duration("duration", 0, "Optional: how long the cert should be valid for. The default is 1 second before the signing cert expires. Valid time units are seconds: \"s\", minutes: \"m\", hours: \"h\"")
sf.inPubPath = sf.set.String("in-pub", "", "Optional (if out-key not set): path to read a previously generated public key")
sf.outKeyPath = sf.set.String("out-key", "", "Optional (if in-pub not set): path to write the private key to")
sf.outCertPath = sf.set.String("out-crt", "", "Optional: path to write the certificate to")
sf.outQRPath = sf.set.String("out-qr", "", "Optional: output a qr code image (png) of the certificate")
sf.groups = sf.set.String("groups", "", "Optional: comma separated list of groups")
sf.subnets = sf.set.String("subnets", "", "Optional: comma separated list of ipv4 address and network in CIDR notation. Subnets this cert can serve for")
sf.p11url = p11Flag(sf.set)
sf.ip = sf.set.String("ip", "", "Deprecated, see -networks")
sf.subnets = sf.set.String("subnets", "", "Deprecated, see -unsafe-networks")
return &sf
}
@ -71,32 +82,47 @@ func signCert(args []string, out io.Writer, errOut io.Writer, pr PasswordReader)
if err := mustFlagString("name", sf.name); err != nil {
return err
}
if err := mustFlagString("ip", sf.ip); err != nil {
return err
}
if !isP11 && *sf.inPubPath != "" && *sf.outKeyPath != "" {
return newHelpErrorf("cannot set both -in-pub and -out-key")
}
var v4Networks []netip.Prefix
var v6Networks []netip.Prefix
if *sf.networks == "" && *sf.ip != "" {
// Pull up deprecated -ip flag if needed
*sf.networks = *sf.ip
}
if len(*sf.networks) == 0 {
return newHelpErrorf("-networks is required")
}
version := cert.Version(*sf.version)
if version != 0 && version != cert.Version1 && version != cert.Version2 {
return newHelpErrorf("-version must be either %v or %v", cert.Version1, cert.Version2)
}
var curve cert.Curve
var caKey []byte
if !isP11 {
var rawCAKey []byte
rawCAKey, err := os.ReadFile(*sf.caKeyPath)
if err != nil {
return fmt.Errorf("error while reading ca-key: %s", err)
}
// naively attempt to decode the private key as though it is not encrypted
caKey, _, curve, err = cert.UnmarshalSigningPrivateKey(rawCAKey)
if err == cert.ErrPrivateKeyEncrypted {
caKey, _, curve, err = cert.UnmarshalSigningPrivateKeyFromPEM(rawCAKey)
if errors.Is(err, cert.ErrPrivateKeyEncrypted) {
// ask for a passphrase until we get one
var passphrase []byte
for i := 0; i < 5; i++ {
out.Write([]byte("Enter passphrase: "))
passphrase, err = pr.ReadPassword()
if err == ErrNoTerminal {
if errors.Is(err, ErrNoTerminal) {
return fmt.Errorf("ca-key is encrypted and must be decrypted interactively")
} else if err != nil {
return fmt.Errorf("error reading password: %s", err)
@ -124,7 +150,7 @@ func signCert(args []string, out io.Writer, errOut io.Writer, pr PasswordReader)
return fmt.Errorf("error while reading ca-crt: %s", err)
}
caCert, _, err := cert.UnmarshalNebulaCertificateFromPEM(rawCACert)
caCert, _, err := cert.UnmarshalCertificateFromPEM(rawCACert)
if err != nil {
return fmt.Errorf("error while parsing ca-crt: %s", err)
}
@ -135,30 +161,59 @@ func signCert(args []string, out io.Writer, errOut io.Writer, pr PasswordReader)
}
}
issuer, err := caCert.Sha256Sum()
if err != nil {
return fmt.Errorf("error while getting -ca-crt fingerprint: %s", err)
}
if caCert.Expired(time.Now()) {
return fmt.Errorf("ca certificate is expired")
}
// if no duration is given, expire one second before the root expires
if *sf.duration <= 0 {
*sf.duration = time.Until(caCert.Details.NotAfter) - time.Second*1
*sf.duration = time.Until(caCert.NotAfter()) - time.Second*1
}
ip, ipNet, err := net.ParseCIDR(*sf.ip)
if *sf.networks != "" {
for _, rs := range strings.Split(*sf.networks, ",") {
rs := strings.Trim(rs, " ")
if rs != "" {
n, err := netip.ParsePrefix(rs)
if err != nil {
return newHelpErrorf("invalid ip definition: %s", err)
return newHelpErrorf("invalid -networks definition: %s", rs)
}
if ip.To4() == nil {
return newHelpErrorf("invalid ip definition: can only be ipv4, have %s", *sf.ip)
}
ipNet.IP = ip
groups := []string{}
if n.Addr().Is4() {
v4Networks = append(v4Networks, n)
} else {
v6Networks = append(v6Networks, n)
}
}
}
}
var v4UnsafeNetworks []netip.Prefix
var v6UnsafeNetworks []netip.Prefix
if *sf.unsafeNetworks == "" && *sf.subnets != "" {
// Pull up deprecated -subnets flag if needed
*sf.unsafeNetworks = *sf.subnets
}
if *sf.unsafeNetworks != "" {
for _, rs := range strings.Split(*sf.unsafeNetworks, ",") {
rs := strings.Trim(rs, " ")
if rs != "" {
n, err := netip.ParsePrefix(rs)
if err != nil {
return newHelpErrorf("invalid -unsafe-networks definition: %s", rs)
}
if n.Addr().Is4() {
v4UnsafeNetworks = append(v4UnsafeNetworks, n)
} else {
v6UnsafeNetworks = append(v6UnsafeNetworks, n)
}
}
}
}
var groups []string
if *sf.groups != "" {
for _, rg := range strings.Split(*sf.groups, ",") {
g := strings.TrimSpace(rg)
@ -168,23 +223,6 @@ func signCert(args []string, out io.Writer, errOut io.Writer, pr PasswordReader)
}
}
subnets := []*net.IPNet{}
if *sf.subnets != "" {
for _, rs := range strings.Split(*sf.subnets, ",") {
rs := strings.Trim(rs, " ")
if rs != "" {
_, s, err := net.ParseCIDR(rs)
if err != nil {
return newHelpErrorf("invalid subnet definition: %s", err)
}
if s.IP.To4() == nil {
return newHelpErrorf("invalid subnet definition: can only be ipv4, have %s", rs)
}
subnets = append(subnets, s)
}
}
}
var pub, rawPriv []byte
var p11Client *pkclient.PKClient
@ -205,7 +243,8 @@ func signCert(args []string, out io.Writer, errOut io.Writer, pr PasswordReader)
if err != nil {
return fmt.Errorf("error while reading in-pub: %s", err)
}
pub, _, pubCurve, err = cert.UnmarshalPublicKey(rawPub)
pub, _, pubCurve, err = cert.UnmarshalPublicKeyFromPEM(rawPub)
if err != nil {
return fmt.Errorf("error while parsing in-pub: %s", err)
}
@ -221,38 +260,6 @@ func signCert(args []string, out io.Writer, errOut io.Writer, pr PasswordReader)
pub, rawPriv = newKeypair(curve)
}
nc := cert.NebulaCertificate{
Details: cert.NebulaCertificateDetails{
Name: *sf.name,
Ips: []*net.IPNet{ipNet},
Groups: groups,
Subnets: subnets,
NotBefore: time.Now(),
NotAfter: time.Now().Add(*sf.duration),
PublicKey: pub,
IsCA: false,
Issuer: issuer,
Curve: curve,
},
Pkcs11Backed: isP11,
}
if p11Client == nil {
err = nc.Sign(curve, caKey)
if err != nil {
return fmt.Errorf("error while signing: %w", err)
}
} else {
err = nc.SignPkcs11(curve, p11Client)
if err != nil {
return fmt.Errorf("error while signing with PKCS#11: %w", err)
}
}
if err := nc.CheckRootConstrains(caCert); err != nil {
return fmt.Errorf("refusing to sign, root certificate constraints violated: %s", err)
}
if *sf.outKeyPath == "" {
*sf.outKeyPath = *sf.name + ".key"
}
@ -265,21 +272,106 @@ func signCert(args []string, out io.Writer, errOut io.Writer, pr PasswordReader)
return fmt.Errorf("refusing to overwrite existing cert: %s", *sf.outCertPath)
}
var crts []cert.Certificate
notBefore := time.Now()
notAfter := notBefore.Add(*sf.duration)
if version == 0 || version == cert.Version1 {
// Make sure we at least have an ip
if len(v4Networks) != 1 {
return newHelpErrorf("invalid -networks definition: v1 certificates can only have a single ipv4 address")
}
if version == cert.Version1 {
// If we are asked to mint a v1 certificate only then we cant just ignore any v6 addresses
if len(v6Networks) > 0 {
return newHelpErrorf("invalid -networks definition: v1 certificates can only be ipv4")
}
if len(v6UnsafeNetworks) > 0 {
return newHelpErrorf("invalid -unsafe-networks definition: v1 certificates can only be ipv4")
}
}
t := &cert.TBSCertificate{
Version: cert.Version1,
Name: *sf.name,
Networks: []netip.Prefix{v4Networks[0]},
Groups: groups,
UnsafeNetworks: v4UnsafeNetworks,
NotBefore: notBefore,
NotAfter: notAfter,
PublicKey: pub,
IsCA: false,
Curve: curve,
}
var nc cert.Certificate
if p11Client == nil {
nc, err = t.Sign(caCert, curve, caKey)
if err != nil {
return fmt.Errorf("error while signing: %w", err)
}
} else {
nc, err = t.SignWith(caCert, curve, p11Client.SignASN1)
if err != nil {
return fmt.Errorf("error while signing with PKCS#11: %w", err)
}
}
crts = append(crts, nc)
}
if version == 0 || version == cert.Version2 {
t := &cert.TBSCertificate{
Version: cert.Version2,
Name: *sf.name,
Networks: append(v4Networks, v6Networks...),
Groups: groups,
UnsafeNetworks: append(v4UnsafeNetworks, v6UnsafeNetworks...),
NotBefore: notBefore,
NotAfter: notAfter,
PublicKey: pub,
IsCA: false,
Curve: curve,
}
var nc cert.Certificate
if p11Client == nil {
nc, err = t.Sign(caCert, curve, caKey)
if err != nil {
return fmt.Errorf("error while signing: %w", err)
}
} else {
nc, err = t.SignWith(caCert, curve, p11Client.SignASN1)
if err != nil {
return fmt.Errorf("error while signing with PKCS#11: %w", err)
}
}
crts = append(crts, nc)
}
if !isP11 && *sf.inPubPath == "" {
if _, err := os.Stat(*sf.outKeyPath); err == nil {
return fmt.Errorf("refusing to overwrite existing key: %s", *sf.outKeyPath)
}
err = os.WriteFile(*sf.outKeyPath, cert.MarshalPrivateKey(curve, rawPriv), 0600)
err = os.WriteFile(*sf.outKeyPath, cert.MarshalPrivateKeyToPEM(curve, rawPriv), 0600)
if err != nil {
return fmt.Errorf("error while writing out-key: %s", err)
}
}
b, err := nc.MarshalToPEM()
var b []byte
for _, c := range crts {
sb, err := c.MarshalPEM()
if err != nil {
return fmt.Errorf("error while marshalling certificate: %s", err)
}
b = append(b, sb...)
}
err = os.WriteFile(*sf.outCertPath, b, 0600)
if err != nil {

View File

@ -13,11 +13,10 @@ import (
"github.com/slackhq/nebula/cert"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"golang.org/x/crypto/ed25519"
)
//TODO: test file permissions
func Test_signSummary(t *testing.T) {
assert.Equal(t, "sign <flags>: create and sign a certificate", signSummary())
}
@ -39,9 +38,11 @@ func Test_signHelp(t *testing.T) {
" -in-pub string\n"+
" \tOptional (if out-key not set): path to read a previously generated public key\n"+
" -ip string\n"+
" \tRequired: ipv4 address and network in CIDR notation to assign the cert\n"+
" \tDeprecated, see -networks\n"+
" -name string\n"+
" \tRequired: name of the cert, usually a hostname\n"+
" -networks string\n"+
" \tRequired: comma separated list of ip address and network in CIDR notation to assign to this cert\n"+
" -out-crt string\n"+
" \tOptional: path to write the certificate to\n"+
" -out-key string\n"+
@ -50,7 +51,11 @@ func Test_signHelp(t *testing.T) {
" \tOptional: output a qr code image (png) of the certificate\n"+
optionalPkcs11String(" -pkcs11 string\n \tOptional: PKCS#11 URI to an existing private key\n")+
" -subnets string\n"+
" \tOptional: comma separated list of ipv4 address and network in CIDR notation. Subnets this cert can serve for\n",
" \tDeprecated, see -unsafe-networks\n"+
" -unsafe-networks string\n"+
" \tOptional: comma separated list of ip address and network in CIDR notation. Unsafe networks this cert can route for\n"+
" -version uint\n"+
" \tOptional: version of the certificate format to use, the default is to create both v1 and v2 certificates.\n",
ob.String(),
)
}
@ -77,20 +82,20 @@ func Test_signCert(t *testing.T) {
// required args
assertHelpError(t, signCert(
[]string{"-ca-crt", "./nope", "-ca-key", "./nope", "-ip", "1.1.1.1/24", "-out-key", "nope", "-out-crt", "nope"}, ob, eb, nopw,
[]string{"-version", "1", "-ca-crt", "./nope", "-ca-key", "./nope", "-ip", "1.1.1.1/24", "-out-key", "nope", "-out-crt", "nope"}, ob, eb, nopw,
), "-name is required")
assert.Empty(t, ob.String())
assert.Empty(t, eb.String())
assertHelpError(t, signCert(
[]string{"-ca-crt", "./nope", "-ca-key", "./nope", "-name", "test", "-out-key", "nope", "-out-crt", "nope"}, ob, eb, nopw,
), "-ip is required")
[]string{"-version", "1", "-ca-crt", "./nope", "-ca-key", "./nope", "-name", "test", "-out-key", "nope", "-out-crt", "nope"}, ob, eb, nopw,
), "-networks is required")
assert.Empty(t, ob.String())
assert.Empty(t, eb.String())
// cannot set -in-pub and -out-key
assertHelpError(t, signCert(
[]string{"-ca-crt", "./nope", "-ca-key", "./nope", "-name", "test", "-in-pub", "nope", "-ip", "1.1.1.1/24", "-out-crt", "nope", "-out-key", "nope"}, ob, eb, nopw,
[]string{"-version", "1", "-ca-crt", "./nope", "-ca-key", "./nope", "-name", "test", "-in-pub", "nope", "-ip", "1.1.1.1/24", "-out-crt", "nope", "-out-key", "nope"}, ob, eb, nopw,
), "cannot set both -in-pub and -out-key")
assert.Empty(t, ob.String())
assert.Empty(t, eb.String())
@ -98,18 +103,18 @@ func Test_signCert(t *testing.T) {
// failed to read key
ob.Reset()
eb.Reset()
args := []string{"-ca-crt", "./nope", "-ca-key", "./nope", "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", "nope", "-out-key", "nope", "-duration", "100m"}
assert.EqualError(t, signCert(args, ob, eb, nopw), "error while reading ca-key: open ./nope: "+NoSuchFileError)
args := []string{"-version", "1", "-ca-crt", "./nope", "-ca-key", "./nope", "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", "nope", "-out-key", "nope", "-duration", "100m"}
require.EqualError(t, signCert(args, ob, eb, nopw), "error while reading ca-key: open ./nope: "+NoSuchFileError)
// failed to unmarshal key
ob.Reset()
eb.Reset()
caKeyF, err := os.CreateTemp("", "sign-cert.key")
assert.Nil(t, err)
require.NoError(t, err)
defer os.Remove(caKeyF.Name())
args = []string{"-ca-crt", "./nope", "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", "nope", "-out-key", "nope", "-duration", "100m"}
assert.EqualError(t, signCert(args, ob, eb, nopw), "error while parsing ca-key: input did not contain a valid PEM encoded block")
args = []string{"-version", "1", "-ca-crt", "./nope", "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", "nope", "-out-key", "nope", "-duration", "100m"}
require.EqualError(t, signCert(args, ob, eb, nopw), "error while parsing ca-key: input did not contain a valid PEM encoded block")
assert.Empty(t, ob.String())
assert.Empty(t, eb.String())
@ -117,11 +122,11 @@ func Test_signCert(t *testing.T) {
ob.Reset()
eb.Reset()
caPub, caPriv, _ := ed25519.GenerateKey(rand.Reader)
caKeyF.Write(cert.MarshalEd25519PrivateKey(caPriv))
caKeyF.Write(cert.MarshalSigningPrivateKeyToPEM(cert.Curve_CURVE25519, caPriv))
// failed to read cert
args = []string{"-ca-crt", "./nope", "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", "nope", "-out-key", "nope", "-duration", "100m"}
assert.EqualError(t, signCert(args, ob, eb, nopw), "error while reading ca-crt: open ./nope: "+NoSuchFileError)
args = []string{"-version", "1", "-ca-crt", "./nope", "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", "nope", "-out-key", "nope", "-duration", "100m"}
require.EqualError(t, signCert(args, ob, eb, nopw), "error while reading ca-crt: open ./nope: "+NoSuchFileError)
assert.Empty(t, ob.String())
assert.Empty(t, eb.String())
@ -129,30 +134,22 @@ func Test_signCert(t *testing.T) {
ob.Reset()
eb.Reset()
caCrtF, err := os.CreateTemp("", "sign-cert.crt")
assert.Nil(t, err)
require.NoError(t, err)
defer os.Remove(caCrtF.Name())
args = []string{"-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", "nope", "-out-key", "nope", "-duration", "100m"}
assert.EqualError(t, signCert(args, ob, eb, nopw), "error while parsing ca-crt: input did not contain a valid PEM encoded block")
args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", "nope", "-out-key", "nope", "-duration", "100m"}
require.EqualError(t, signCert(args, ob, eb, nopw), "error while parsing ca-crt: input did not contain a valid PEM encoded block")
assert.Empty(t, ob.String())
assert.Empty(t, eb.String())
// write a proper ca cert for later
ca := cert.NebulaCertificate{
Details: cert.NebulaCertificateDetails{
Name: "ca",
NotBefore: time.Now(),
NotAfter: time.Now().Add(time.Minute * 200),
PublicKey: caPub,
IsCA: true,
},
}
b, _ := ca.MarshalToPEM()
ca, _ := NewTestCaCert("ca", caPub, caPriv, time.Now(), time.Now().Add(time.Minute*200), nil, nil, nil)
b, _ := ca.MarshalPEM()
caCrtF.Write(b)
// failed to read pub
args = []string{"-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", "nope", "-in-pub", "./nope", "-duration", "100m"}
assert.EqualError(t, signCert(args, ob, eb, nopw), "error while reading in-pub: open ./nope: "+NoSuchFileError)
args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", "nope", "-in-pub", "./nope", "-duration", "100m"}
require.EqualError(t, signCert(args, ob, eb, nopw), "error while reading in-pub: open ./nope: "+NoSuchFileError)
assert.Empty(t, ob.String())
assert.Empty(t, eb.String())
@ -160,11 +157,11 @@ func Test_signCert(t *testing.T) {
ob.Reset()
eb.Reset()
inPubF, err := os.CreateTemp("", "in.pub")
assert.Nil(t, err)
require.NoError(t, err)
defer os.Remove(inPubF.Name())
args = []string{"-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", "nope", "-in-pub", inPubF.Name(), "-duration", "100m"}
assert.EqualError(t, signCert(args, ob, eb, nopw), "error while parsing in-pub: input did not contain a valid PEM encoded block")
args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", "nope", "-in-pub", inPubF.Name(), "-duration", "100m"}
require.EqualError(t, signCert(args, ob, eb, nopw), "error while parsing in-pub: input did not contain a valid PEM encoded block")
assert.Empty(t, ob.String())
assert.Empty(t, eb.String())
@ -172,116 +169,124 @@ func Test_signCert(t *testing.T) {
ob.Reset()
eb.Reset()
inPub, _ := x25519Keypair()
inPubF.Write(cert.MarshalX25519PublicKey(inPub))
inPubF.Write(cert.MarshalPublicKeyToPEM(cert.Curve_CURVE25519, inPub))
// bad ip cidr
ob.Reset()
eb.Reset()
args = []string{"-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "a1.1.1.1/24", "-out-crt", "nope", "-out-key", "nope", "-duration", "100m"}
assertHelpError(t, signCert(args, ob, eb, nopw), "invalid ip definition: invalid CIDR address: a1.1.1.1/24")
args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "a1.1.1.1/24", "-out-crt", "nope", "-out-key", "nope", "-duration", "100m"}
assertHelpError(t, signCert(args, ob, eb, nopw), "invalid -networks definition: a1.1.1.1/24")
assert.Empty(t, ob.String())
assert.Empty(t, eb.String())
ob.Reset()
eb.Reset()
args = []string{"-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "100::100/100", "-out-crt", "nope", "-out-key", "nope", "-duration", "100m"}
assertHelpError(t, signCert(args, ob, eb, nopw), "invalid ip definition: can only be ipv4, have 100::100/100")
args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "100::100/100", "-out-crt", "nope", "-out-key", "nope", "-duration", "100m"}
assertHelpError(t, signCert(args, ob, eb, nopw), "invalid -networks definition: v1 certificates can only have a single ipv4 address")
assert.Empty(t, ob.String())
assert.Empty(t, eb.String())
ob.Reset()
eb.Reset()
args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24,1.1.1.2/24", "-out-crt", "nope", "-out-key", "nope", "-duration", "100m"}
assertHelpError(t, signCert(args, ob, eb, nopw), "invalid -networks definition: v1 certificates can only have a single ipv4 address")
assert.Empty(t, ob.String())
assert.Empty(t, eb.String())
// bad subnet cidr
ob.Reset()
eb.Reset()
args = []string{"-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", "nope", "-out-key", "nope", "-duration", "100m", "-subnets", "a"}
assertHelpError(t, signCert(args, ob, eb, nopw), "invalid subnet definition: invalid CIDR address: a")
args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", "nope", "-out-key", "nope", "-duration", "100m", "-subnets", "a"}
assertHelpError(t, signCert(args, ob, eb, nopw), "invalid -unsafe-networks definition: a")
assert.Empty(t, ob.String())
assert.Empty(t, eb.String())
ob.Reset()
eb.Reset()
args = []string{"-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", "nope", "-out-key", "nope", "-duration", "100m", "-subnets", "100::100/100"}
assertHelpError(t, signCert(args, ob, eb, nopw), "invalid subnet definition: can only be ipv4, have 100::100/100")
args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", "nope", "-out-key", "nope", "-duration", "100m", "-subnets", "100::100/100"}
assertHelpError(t, signCert(args, ob, eb, nopw), "invalid -unsafe-networks definition: v1 certificates can only be ipv4")
assert.Empty(t, ob.String())
assert.Empty(t, eb.String())
// mismatched ca key
_, caPriv2, _ := ed25519.GenerateKey(rand.Reader)
caKeyF2, err := os.CreateTemp("", "sign-cert-2.key")
assert.Nil(t, err)
require.NoError(t, err)
defer os.Remove(caKeyF2.Name())
caKeyF2.Write(cert.MarshalEd25519PrivateKey(caPriv2))
caKeyF2.Write(cert.MarshalSigningPrivateKeyToPEM(cert.Curve_CURVE25519, caPriv2))
ob.Reset()
eb.Reset()
args = []string{"-ca-crt", caCrtF.Name(), "-ca-key", caKeyF2.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", "nope", "-out-key", "nope", "-duration", "100m", "-subnets", "a"}
assert.EqualError(t, signCert(args, ob, eb, nopw), "refusing to sign, root certificate does not match private key")
args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF2.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", "nope", "-out-key", "nope", "-duration", "100m", "-subnets", "a"}
require.EqualError(t, signCert(args, ob, eb, nopw), "refusing to sign, root certificate does not match private key")
assert.Empty(t, ob.String())
assert.Empty(t, eb.String())
// failed key write
ob.Reset()
eb.Reset()
args = []string{"-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", "/do/not/write/pleasecrt", "-out-key", "/do/not/write/pleasekey", "-duration", "100m", "-subnets", "10.1.1.1/32"}
assert.EqualError(t, signCert(args, ob, eb, nopw), "error while writing out-key: open /do/not/write/pleasekey: "+NoSuchDirError)
args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", "/do/not/write/pleasecrt", "-out-key", "/do/not/write/pleasekey", "-duration", "100m", "-subnets", "10.1.1.1/32"}
require.EqualError(t, signCert(args, ob, eb, nopw), "error while writing out-key: open /do/not/write/pleasekey: "+NoSuchDirError)
assert.Empty(t, ob.String())
assert.Empty(t, eb.String())
// create temp key file
keyF, err := os.CreateTemp("", "test.key")
assert.Nil(t, err)
require.NoError(t, err)
os.Remove(keyF.Name())
// failed cert write
ob.Reset()
eb.Reset()
args = []string{"-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", "/do/not/write/pleasecrt", "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32"}
assert.EqualError(t, signCert(args, ob, eb, nopw), "error while writing out-crt: open /do/not/write/pleasecrt: "+NoSuchDirError)
args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", "/do/not/write/pleasecrt", "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32"}
require.EqualError(t, signCert(args, ob, eb, nopw), "error while writing out-crt: open /do/not/write/pleasecrt: "+NoSuchDirError)
assert.Empty(t, ob.String())
assert.Empty(t, eb.String())
os.Remove(keyF.Name())
// create temp cert file
crtF, err := os.CreateTemp("", "test.crt")
assert.Nil(t, err)
require.NoError(t, err)
os.Remove(crtF.Name())
// test proper cert with removed empty groups and subnets
ob.Reset()
eb.Reset()
args = []string{"-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32, , 10.2.2.2/32 , , ,, 10.5.5.5/32", "-groups", "1,, 2 , ,,,3,4,5"}
assert.Nil(t, signCert(args, ob, eb, nopw))
args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32, , 10.2.2.2/32 , , ,, 10.5.5.5/32", "-groups", "1,, 2 , ,,,3,4,5"}
require.NoError(t, signCert(args, ob, eb, nopw))
assert.Empty(t, ob.String())
assert.Empty(t, eb.String())
// read cert and key files
rb, _ := os.ReadFile(keyF.Name())
lKey, b, err := cert.UnmarshalX25519PrivateKey(rb)
assert.Len(t, b, 0)
assert.Nil(t, err)
lKey, b, curve, err := cert.UnmarshalPrivateKeyFromPEM(rb)
assert.Equal(t, cert.Curve_CURVE25519, curve)
assert.Empty(t, b)
require.NoError(t, err)
assert.Len(t, lKey, 32)
rb, _ = os.ReadFile(crtF.Name())
lCrt, b, err := cert.UnmarshalNebulaCertificateFromPEM(rb)
assert.Len(t, b, 0)
assert.Nil(t, err)
lCrt, b, err := cert.UnmarshalCertificateFromPEM(rb)
assert.Empty(t, b)
require.NoError(t, err)
assert.Equal(t, "test", lCrt.Details.Name)
assert.Equal(t, "1.1.1.1/24", lCrt.Details.Ips[0].String())
assert.Len(t, lCrt.Details.Ips, 1)
assert.False(t, lCrt.Details.IsCA)
assert.Equal(t, []string{"1", "2", "3", "4", "5"}, lCrt.Details.Groups)
assert.Len(t, lCrt.Details.Subnets, 3)
assert.Len(t, lCrt.Details.PublicKey, 32)
assert.Equal(t, time.Duration(time.Minute*100), lCrt.Details.NotAfter.Sub(lCrt.Details.NotBefore))
assert.Equal(t, "test", lCrt.Name())
assert.Equal(t, "1.1.1.1/24", lCrt.Networks()[0].String())
assert.Len(t, lCrt.Networks(), 1)
assert.False(t, lCrt.IsCA())
assert.Equal(t, []string{"1", "2", "3", "4", "5"}, lCrt.Groups())
assert.Len(t, lCrt.UnsafeNetworks(), 3)
assert.Len(t, lCrt.PublicKey(), 32)
assert.Equal(t, time.Duration(time.Minute*100), lCrt.NotAfter().Sub(lCrt.NotBefore()))
sns := []string{}
for _, sn := range lCrt.Details.Subnets {
for _, sn := range lCrt.UnsafeNetworks() {
sns = append(sns, sn.String())
}
assert.Equal(t, []string{"10.1.1.1/32", "10.2.2.2/32", "10.5.5.5/32"}, sns)
issuer, _ := ca.Sha256Sum()
assert.Equal(t, issuer, lCrt.Details.Issuer)
issuer, _ := ca.Fingerprint()
assert.Equal(t, issuer, lCrt.Issuer())
assert.True(t, lCrt.CheckSignature(caPub))
@ -290,53 +295,55 @@ func Test_signCert(t *testing.T) {
os.Remove(crtF.Name())
ob.Reset()
eb.Reset()
args = []string{"-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-in-pub", inPubF.Name(), "-duration", "100m", "-groups", "1"}
assert.Nil(t, signCert(args, ob, eb, nopw))
args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-in-pub", inPubF.Name(), "-duration", "100m", "-groups", "1"}
require.NoError(t, signCert(args, ob, eb, nopw))
assert.Empty(t, ob.String())
assert.Empty(t, eb.String())
// read cert file and check pub key matches in-pub
rb, _ = os.ReadFile(crtF.Name())
lCrt, b, err = cert.UnmarshalNebulaCertificateFromPEM(rb)
assert.Len(t, b, 0)
assert.Nil(t, err)
assert.Equal(t, lCrt.Details.PublicKey, inPub)
lCrt, b, err = cert.UnmarshalCertificateFromPEM(rb)
assert.Empty(t, b)
require.NoError(t, err)
assert.Equal(t, lCrt.PublicKey(), inPub)
// test refuse to sign cert with duration beyond root
ob.Reset()
eb.Reset()
args = []string{"-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "1000m", "-subnets", "10.1.1.1/32, , 10.2.2.2/32 , , ,, 10.5.5.5/32", "-groups", "1,, 2 , ,,,3,4,5"}
assert.EqualError(t, signCert(args, ob, eb, nopw), "refusing to sign, root certificate constraints violated: certificate expires after signing certificate")
os.Remove(keyF.Name())
os.Remove(crtF.Name())
args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "1000m", "-subnets", "10.1.1.1/32, , 10.2.2.2/32 , , ,, 10.5.5.5/32", "-groups", "1,, 2 , ,,,3,4,5"}
require.EqualError(t, signCert(args, ob, eb, nopw), "error while signing: certificate expires after signing certificate")
assert.Empty(t, ob.String())
assert.Empty(t, eb.String())
// create valid cert/key for overwrite tests
os.Remove(keyF.Name())
os.Remove(crtF.Name())
args = []string{"-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32, , 10.2.2.2/32 , , ,, 10.5.5.5/32", "-groups", "1,, 2 , ,,,3,4,5"}
assert.Nil(t, signCert(args, ob, eb, nopw))
args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32, , 10.2.2.2/32 , , ,, 10.5.5.5/32", "-groups", "1,, 2 , ,,,3,4,5"}
require.NoError(t, signCert(args, ob, eb, nopw))
// test that we won't overwrite existing key file
os.Remove(crtF.Name())
ob.Reset()
eb.Reset()
args = []string{"-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32, , 10.2.2.2/32 , , ,, 10.5.5.5/32", "-groups", "1,, 2 , ,,,3,4,5"}
assert.EqualError(t, signCert(args, ob, eb, nopw), "refusing to overwrite existing key: "+keyF.Name())
args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32, , 10.2.2.2/32 , , ,, 10.5.5.5/32", "-groups", "1,, 2 , ,,,3,4,5"}
require.EqualError(t, signCert(args, ob, eb, nopw), "refusing to overwrite existing key: "+keyF.Name())
assert.Empty(t, ob.String())
assert.Empty(t, eb.String())
// create valid cert/key for overwrite tests
os.Remove(keyF.Name())
os.Remove(crtF.Name())
args = []string{"-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32, , 10.2.2.2/32 , , ,, 10.5.5.5/32", "-groups", "1,, 2 , ,,,3,4,5"}
assert.Nil(t, signCert(args, ob, eb, nopw))
args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32, , 10.2.2.2/32 , , ,, 10.5.5.5/32", "-groups", "1,, 2 , ,,,3,4,5"}
require.NoError(t, signCert(args, ob, eb, nopw))
// test that we won't overwrite existing certificate file
os.Remove(keyF.Name())
ob.Reset()
eb.Reset()
args = []string{"-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32, , 10.2.2.2/32 , , ,, 10.5.5.5/32", "-groups", "1,, 2 , ,,,3,4,5"}
assert.EqualError(t, signCert(args, ob, eb, nopw), "refusing to overwrite existing cert: "+crtF.Name())
args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32, , 10.2.2.2/32 , , ,, 10.5.5.5/32", "-groups", "1,, 2 , ,,,3,4,5"}
require.EqualError(t, signCert(args, ob, eb, nopw), "refusing to overwrite existing cert: "+crtF.Name())
assert.Empty(t, ob.String())
assert.Empty(t, eb.String())
@ -349,11 +356,11 @@ func Test_signCert(t *testing.T) {
eb.Reset()
caKeyF, err = os.CreateTemp("", "sign-cert.key")
assert.Nil(t, err)
require.NoError(t, err)
defer os.Remove(caKeyF.Name())
caCrtF, err = os.CreateTemp("", "sign-cert.crt")
assert.Nil(t, err)
require.NoError(t, err)
defer os.Remove(caCrtF.Name())
// generate the encrypted key
@ -362,21 +369,13 @@ func Test_signCert(t *testing.T) {
b, _ = cert.EncryptAndMarshalSigningPrivateKey(cert.Curve_CURVE25519, caPriv, passphrase, kdfParams)
caKeyF.Write(b)
ca = cert.NebulaCertificate{
Details: cert.NebulaCertificateDetails{
Name: "ca",
NotBefore: time.Now(),
NotAfter: time.Now().Add(time.Minute * 200),
PublicKey: caPub,
IsCA: true,
},
}
b, _ = ca.MarshalToPEM()
ca, _ = NewTestCaCert("ca", caPub, caPriv, time.Now(), time.Now().Add(time.Minute*200), nil, nil, nil)
b, _ = ca.MarshalPEM()
caCrtF.Write(b)
// test with the proper password
args = []string{"-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32, , 10.2.2.2/32 , , ,, 10.5.5.5/32", "-groups", "1,, 2 , ,,,3,4,5"}
assert.Nil(t, signCert(args, ob, eb, testpw))
args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32, , 10.2.2.2/32 , , ,, 10.5.5.5/32", "-groups", "1,, 2 , ,,,3,4,5"}
require.NoError(t, signCert(args, ob, eb, testpw))
assert.Equal(t, "Enter passphrase: ", ob.String())
assert.Empty(t, eb.String())
@ -385,8 +384,8 @@ func Test_signCert(t *testing.T) {
eb.Reset()
testpw.password = []byte("invalid password")
args = []string{"-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32, , 10.2.2.2/32 , , ,, 10.5.5.5/32", "-groups", "1,, 2 , ,,,3,4,5"}
assert.Error(t, signCert(args, ob, eb, testpw))
args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32, , 10.2.2.2/32 , , ,, 10.5.5.5/32", "-groups", "1,, 2 , ,,,3,4,5"}
require.Error(t, signCert(args, ob, eb, testpw))
assert.Equal(t, "Enter passphrase: ", ob.String())
assert.Empty(t, eb.String())
@ -394,8 +393,8 @@ func Test_signCert(t *testing.T) {
ob.Reset()
eb.Reset()
args = []string{"-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32, , 10.2.2.2/32 , , ,, 10.5.5.5/32", "-groups", "1,, 2 , ,,,3,4,5"}
assert.Error(t, signCert(args, ob, eb, nopw))
args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32, , 10.2.2.2/32 , , ,, 10.5.5.5/32", "-groups", "1,, 2 , ,,,3,4,5"}
require.Error(t, signCert(args, ob, eb, nopw))
// normally the user hitting enter on the prompt would add newlines between these
assert.Equal(t, "Enter passphrase: Enter passphrase: Enter passphrase: Enter passphrase: Enter passphrase: ", ob.String())
assert.Empty(t, eb.String())
@ -404,8 +403,8 @@ func Test_signCert(t *testing.T) {
ob.Reset()
eb.Reset()
args = []string{"-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32, , 10.2.2.2/32 , , ,, 10.5.5.5/32", "-groups", "1,, 2 , ,,,3,4,5"}
assert.Error(t, signCert(args, ob, eb, errpw))
args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32, , 10.2.2.2/32 , , ,, 10.5.5.5/32", "-groups", "1,, 2 , ,,,3,4,5"}
require.Error(t, signCert(args, ob, eb, errpw))
assert.Equal(t, "Enter passphrase: ", ob.String())
assert.Empty(t, eb.String())
}

View File

@ -1,6 +1,7 @@
package main
import (
"errors"
"flag"
"fmt"
"io"
@ -41,14 +42,14 @@ func verify(args []string, out io.Writer, errOut io.Writer) error {
rawCACert, err := os.ReadFile(*vf.caPath)
if err != nil {
return fmt.Errorf("error while reading ca: %s", err)
return fmt.Errorf("error while reading ca: %w", err)
}
caPool := cert.NewCAPool()
for {
rawCACert, err = caPool.AddCACertificate(rawCACert)
rawCACert, err = caPool.AddCAFromPEM(rawCACert)
if err != nil {
return fmt.Errorf("error while adding ca cert to pool: %s", err)
return fmt.Errorf("error while adding ca cert to pool: %w", err)
}
if rawCACert == nil || len(rawCACert) == 0 || strings.TrimSpace(string(rawCACert)) == "" {
@ -58,20 +59,30 @@ func verify(args []string, out io.Writer, errOut io.Writer) error {
rawCert, err := os.ReadFile(*vf.certPath)
if err != nil {
return fmt.Errorf("unable to read crt; %s", err)
return fmt.Errorf("unable to read crt: %w", err)
}
c, _, err := cert.UnmarshalNebulaCertificateFromPEM(rawCert)
var errs []error
for {
if len(rawCert) == 0 {
break
}
c, extra, err := cert.UnmarshalCertificateFromPEM(rawCert)
if err != nil {
return fmt.Errorf("error while parsing crt: %s", err)
return fmt.Errorf("error while parsing crt: %w", err)
}
rawCert = extra
_, err = caPool.VerifyCertificate(time.Now(), c)
if err != nil {
switch {
case errors.Is(err, cert.ErrCaNotFound):
errs = append(errs, fmt.Errorf("error while verifying certificate v%d %s with issuer %s: %w", c.Version(), c.Name(), c.Issuer(), err))
default:
errs = append(errs, fmt.Errorf("error while verifying certificate %+v: %w", c, err))
}
}
}
good, err := c.Verify(time.Now(), caPool)
if !good {
return err
}
return nil
return errors.Join(errs...)
}
func verifySummary() string {
@ -80,7 +91,7 @@ func verifySummary() string {
func verifyHelp(out io.Writer) {
vf := newVerifyFlags()
out.Write([]byte("Usage of " + os.Args[0] + " " + verifySummary() + "\n"))
_, _ = out.Write([]byte("Usage of " + os.Args[0] + " " + verifySummary() + "\n"))
vf.set.SetOutput(out)
vf.set.PrintDefaults()
}

View File

@ -9,6 +9,7 @@ import (
"github.com/slackhq/nebula/cert"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"golang.org/x/crypto/ed25519"
)
@ -50,34 +51,25 @@ func Test_verify(t *testing.T) {
err := verify([]string{"-ca", "does_not_exist", "-crt", "does_not_exist"}, ob, eb)
assert.Equal(t, "", ob.String())
assert.Equal(t, "", eb.String())
assert.EqualError(t, err, "error while reading ca: open does_not_exist: "+NoSuchFileError)
require.EqualError(t, err, "error while reading ca: open does_not_exist: "+NoSuchFileError)
// invalid ca at path
ob.Reset()
eb.Reset()
caFile, err := os.CreateTemp("", "verify-ca")
assert.Nil(t, err)
require.NoError(t, err)
defer os.Remove(caFile.Name())
caFile.WriteString("-----BEGIN NOPE-----")
err = verify([]string{"-ca", caFile.Name(), "-crt", "does_not_exist"}, ob, eb)
assert.Equal(t, "", ob.String())
assert.Equal(t, "", eb.String())
assert.EqualError(t, err, "error while adding ca cert to pool: input did not contain a valid PEM encoded block")
require.EqualError(t, err, "error while adding ca cert to pool: input did not contain a valid PEM encoded block")
// make a ca for later
caPub, caPriv, _ := ed25519.GenerateKey(rand.Reader)
ca := cert.NebulaCertificate{
Details: cert.NebulaCertificateDetails{
Name: "test-ca",
NotBefore: time.Now().Add(time.Hour * -1),
NotAfter: time.Now().Add(time.Hour * 2),
PublicKey: caPub,
IsCA: true,
},
}
ca.Sign(cert.Curve_CURVE25519, caPriv)
b, _ := ca.MarshalToPEM()
ca, _ := NewTestCaCert("test-ca", caPub, caPriv, time.Now().Add(time.Hour*-1), time.Now().Add(time.Hour*2), nil, nil, nil)
b, _ := ca.MarshalPEM()
caFile.Truncate(0)
caFile.Seek(0, 0)
caFile.Write(b)
@ -86,38 +78,29 @@ func Test_verify(t *testing.T) {
err = verify([]string{"-ca", caFile.Name(), "-crt", "does_not_exist"}, ob, eb)
assert.Equal(t, "", ob.String())
assert.Equal(t, "", eb.String())
assert.EqualError(t, err, "unable to read crt; open does_not_exist: "+NoSuchFileError)
require.EqualError(t, err, "unable to read crt: open does_not_exist: "+NoSuchFileError)
// invalid crt at path
ob.Reset()
eb.Reset()
certFile, err := os.CreateTemp("", "verify-cert")
assert.Nil(t, err)
require.NoError(t, err)
defer os.Remove(certFile.Name())
certFile.WriteString("-----BEGIN NOPE-----")
err = verify([]string{"-ca", caFile.Name(), "-crt", certFile.Name()}, ob, eb)
assert.Equal(t, "", ob.String())
assert.Equal(t, "", eb.String())
assert.EqualError(t, err, "error while parsing crt: input did not contain a valid PEM encoded block")
require.EqualError(t, err, "error while parsing crt: input did not contain a valid PEM encoded block")
// unverifiable cert at path
_, badPriv, _ := ed25519.GenerateKey(rand.Reader)
certPub, _ := x25519Keypair()
signer, _ := ca.Sha256Sum()
crt := cert.NebulaCertificate{
Details: cert.NebulaCertificateDetails{
Name: "test-cert",
NotBefore: time.Now().Add(time.Hour * -1),
NotAfter: time.Now().Add(time.Hour),
PublicKey: certPub,
IsCA: false,
Issuer: signer,
},
crt, _ := NewTestCert(ca, caPriv, "test-cert", time.Now().Add(time.Hour*-1), time.Now().Add(time.Hour), nil, nil, nil)
// Slightly evil hack to modify the certificate after it was sealed to generate an invalid signature
pub := crt.PublicKey()
for i, _ := range pub {
pub[i] = 0
}
crt.Sign(cert.Curve_CURVE25519, badPriv)
b, _ = crt.MarshalToPEM()
b, _ = crt.MarshalPEM()
certFile.Truncate(0)
certFile.Seek(0, 0)
certFile.Write(b)
@ -125,11 +108,11 @@ func Test_verify(t *testing.T) {
err = verify([]string{"-ca", caFile.Name(), "-crt", certFile.Name()}, ob, eb)
assert.Equal(t, "", ob.String())
assert.Equal(t, "", eb.String())
assert.EqualError(t, err, "certificate signature did not match")
require.ErrorIs(t, err, cert.ErrSignatureMismatch)
// verified cert at path
crt.Sign(cert.Curve_CURVE25519, caPriv)
b, _ = crt.MarshalToPEM()
crt, _ = NewTestCert(ca, caPriv, "test-cert", time.Now().Add(time.Hour*-1), time.Now().Add(time.Hour), nil, nil, nil)
b, _ = crt.MarshalPEM()
certFile.Truncate(0)
certFile.Seek(0, 0)
certFile.Write(b)
@ -137,5 +120,5 @@ func Test_verify(t *testing.T) {
err = verify([]string{"-ca", caFile.Name(), "-crt", certFile.Name()}, ob, eb)
assert.Equal(t, "", ob.String())
assert.Equal(t, "", eb.String())
assert.Nil(t, err)
require.NoError(t, err)
}

View File

@ -19,18 +19,18 @@ func TestConfig_Load(t *testing.T) {
// invalid yaml
c := NewC(l)
os.WriteFile(filepath.Join(dir, "01.yaml"), []byte(" invalid yaml"), 0644)
assert.EqualError(t, c.Load(dir), "yaml: unmarshal errors:\n line 1: cannot unmarshal !!str `invalid...` into map[interface {}]interface {}")
require.EqualError(t, c.Load(dir), "yaml: unmarshal errors:\n line 1: cannot unmarshal !!str `invalid...` into map[interface {}]interface {}")
// simple multi config merge
c = NewC(l)
os.RemoveAll(dir)
os.Mkdir(dir, 0755)
assert.Nil(t, err)
require.NoError(t, err)
os.WriteFile(filepath.Join(dir, "01.yaml"), []byte("outer:\n inner: hi"), 0644)
os.WriteFile(filepath.Join(dir, "02.yml"), []byte("outer:\n inner: override\nnew: hi"), 0644)
assert.Nil(t, c.Load(dir))
require.NoError(t, c.Load(dir))
expected := map[interface{}]interface{}{
"outer": map[interface{}]interface{}{
"inner": "override",
@ -38,9 +38,6 @@ func TestConfig_Load(t *testing.T) {
"new": "hi",
}
assert.Equal(t, expected, c.Settings)
//TODO: test symlinked file
//TODO: test symlinked directory
}
func TestConfig_Get(t *testing.T) {
@ -70,28 +67,28 @@ func TestConfig_GetBool(t *testing.T) {
l := test.NewLogger()
c := NewC(l)
c.Settings["bool"] = true
assert.Equal(t, true, c.GetBool("bool", false))
assert.True(t, c.GetBool("bool", false))
c.Settings["bool"] = "true"
assert.Equal(t, true, c.GetBool("bool", false))
assert.True(t, c.GetBool("bool", false))
c.Settings["bool"] = false
assert.Equal(t, false, c.GetBool("bool", true))
assert.False(t, c.GetBool("bool", true))
c.Settings["bool"] = "false"
assert.Equal(t, false, c.GetBool("bool", true))
assert.False(t, c.GetBool("bool", true))
c.Settings["bool"] = "Y"
assert.Equal(t, true, c.GetBool("bool", false))
assert.True(t, c.GetBool("bool", false))
c.Settings["bool"] = "yEs"
assert.Equal(t, true, c.GetBool("bool", false))
assert.True(t, c.GetBool("bool", false))
c.Settings["bool"] = "N"
assert.Equal(t, false, c.GetBool("bool", true))
assert.False(t, c.GetBool("bool", true))
c.Settings["bool"] = "nO"
assert.Equal(t, false, c.GetBool("bool", true))
assert.False(t, c.GetBool("bool", true))
}
func TestConfig_HasChanged(t *testing.T) {
@ -120,11 +117,11 @@ func TestConfig_ReloadConfig(t *testing.T) {
l := test.NewLogger()
done := make(chan bool, 1)
dir, err := os.MkdirTemp("", "config-test")
assert.Nil(t, err)
require.NoError(t, err)
os.WriteFile(filepath.Join(dir, "01.yaml"), []byte("outer:\n inner: hi"), 0644)
c := NewC(l)
assert.Nil(t, c.Load(dir))
require.NoError(t, c.Load(dir))
assert.False(t, c.HasChanged("outer.inner"))
assert.False(t, c.HasChanged("outer"))

View File

@ -183,7 +183,7 @@ func (n *connectionManager) doTrafficCheck(localIndex uint32, p, nb, out []byte,
case deleteTunnel:
if n.hostMap.DeleteHostInfo(hostinfo) {
// Only clearing the lighthouse cache if this is the last hostinfo for this vpn ip in the hostmap
n.intf.lightHouse.DeleteVpnIp(hostinfo.vpnIp)
n.intf.lightHouse.DeleteVpnAddrs(hostinfo.vpnAddrs)
}
case closeTunnel:
@ -221,7 +221,7 @@ func (n *connectionManager) migrateRelayUsed(oldhostinfo, newhostinfo *HostInfo)
relayFor := oldhostinfo.relayState.CopyAllRelayFor()
for _, r := range relayFor {
existing, ok := newhostinfo.relayState.QueryRelayForByIp(r.PeerIp)
existing, ok := newhostinfo.relayState.QueryRelayForByIp(r.PeerAddr)
var index uint32
var relayFrom netip.Addr
@ -235,11 +235,11 @@ func (n *connectionManager) migrateRelayUsed(oldhostinfo, newhostinfo *HostInfo)
index = existing.LocalIndex
switch r.Type {
case TerminalType:
relayFrom = n.intf.myVpnNet.Addr()
relayTo = existing.PeerIp
relayFrom = n.intf.myVpnAddrs[0]
relayTo = existing.PeerAddr
case ForwardingType:
relayFrom = existing.PeerIp
relayTo = newhostinfo.vpnIp
relayFrom = existing.PeerAddr
relayTo = newhostinfo.vpnAddrs[0]
default:
// should never happen
}
@ -253,45 +253,64 @@ func (n *connectionManager) migrateRelayUsed(oldhostinfo, newhostinfo *HostInfo)
n.relayUsedLock.RUnlock()
// The relay doesn't exist at all; create some relay state and send the request.
var err error
index, err = AddRelay(n.l, newhostinfo, n.hostMap, r.PeerIp, nil, r.Type, Requested)
index, err = AddRelay(n.l, newhostinfo, n.hostMap, r.PeerAddr, nil, r.Type, Requested)
if err != nil {
n.l.WithError(err).Error("failed to migrate relay to new hostinfo")
continue
}
switch r.Type {
case TerminalType:
relayFrom = n.intf.myVpnNet.Addr()
relayTo = r.PeerIp
relayFrom = n.intf.myVpnAddrs[0]
relayTo = r.PeerAddr
case ForwardingType:
relayFrom = r.PeerIp
relayTo = newhostinfo.vpnIp
relayFrom = r.PeerAddr
relayTo = newhostinfo.vpnAddrs[0]
default:
// should never happen
}
}
//TODO: IPV6-WORK
relayFromB := relayFrom.As4()
relayToB := relayTo.As4()
// Send a CreateRelayRequest to the peer.
req := NebulaControl{
Type: NebulaControl_CreateRelayRequest,
InitiatorRelayIndex: index,
RelayFromIp: binary.BigEndian.Uint32(relayFromB[:]),
RelayToIp: binary.BigEndian.Uint32(relayToB[:]),
}
switch newhostinfo.GetCert().Certificate.Version() {
case cert.Version1:
if !relayFrom.Is4() {
n.l.Error("can not migrate v1 relay with a v6 network because the relay is not running a current nebula version")
continue
}
if !relayTo.Is4() {
n.l.Error("can not migrate v1 relay with a v6 remote network because the relay is not running a current nebula version")
continue
}
b := relayFrom.As4()
req.OldRelayFromAddr = binary.BigEndian.Uint32(b[:])
b = relayTo.As4()
req.OldRelayToAddr = binary.BigEndian.Uint32(b[:])
case cert.Version2:
req.RelayFromAddr = netAddrToProtoAddr(relayFrom)
req.RelayToAddr = netAddrToProtoAddr(relayTo)
default:
newhostinfo.logger(n.l).Error("Unknown certificate version found while attempting to migrate relay")
continue
}
msg, err := req.Marshal()
if err != nil {
n.l.WithError(err).Error("failed to marshal Control message to migrate relay")
} else {
n.intf.SendMessageToHostInfo(header.Control, 0, newhostinfo, msg, make([]byte, 12), make([]byte, mtu))
n.l.WithFields(logrus.Fields{
"relayFrom": req.RelayFromIp,
"relayTo": req.RelayToIp,
"relayFrom": req.RelayFromAddr,
"relayTo": req.RelayToAddr,
"initiatorRelayIndex": req.InitiatorRelayIndex,
"responderRelayIndex": req.ResponderRelayIndex,
"vpnIp": newhostinfo.vpnIp}).
"vpnAddrs": newhostinfo.vpnAddrs}).
Info("send CreateRelayRequest")
}
}
@ -313,7 +332,7 @@ func (n *connectionManager) makeTrafficDecision(localIndex uint32, now time.Time
return closeTunnel, hostinfo, nil
}
primary := n.hostMap.Hosts[hostinfo.vpnIp]
primary := n.hostMap.Hosts[hostinfo.vpnAddrs[0]]
mainHostInfo := true
if primary != nil && primary != hostinfo {
mainHostInfo = false
@ -407,21 +426,24 @@ func (n *connectionManager) shouldSwapPrimary(current, primary *HostInfo) bool {
// If we are here then we have multiple tunnels for a host pair and neither side believes the same tunnel is primary.
// Let's sort this out.
if current.vpnIp.Compare(n.intf.myVpnNet.Addr()) < 0 {
// Only one side should flip primary because if both flip then we may never resolve to a single tunnel.
// vpn ip is static across all tunnels for this host pair so lets use that to determine who is flipping.
// The remotes vpn ip is lower than mine. I will not flip.
// Only one side should swap because if both swap then we may never resolve to a single tunnel.
// vpn addr is static across all tunnels for this host pair so lets
// use that to determine if we should consider swapping.
if current.vpnAddrs[0].Compare(n.intf.myVpnAddrs[0]) < 0 {
// Their primary vpn addr is less than mine. Do not swap.
return false
}
certState := n.intf.pki.GetCertState()
return bytes.Equal(current.ConnectionState.myCert.Signature, certState.Certificate.Signature)
crt := n.intf.pki.getCertState().getCertificate(current.ConnectionState.myCert.Version())
// If this tunnel is using the latest certificate then we should swap it to primary for a bit and see if things
// settle down.
return bytes.Equal(current.ConnectionState.myCert.Signature(), crt.Signature())
}
func (n *connectionManager) swapPrimary(current, primary *HostInfo) {
n.hostMap.Lock()
// Make sure the primary is still the same after the write lock. This avoids a race with a rehandshake.
if n.hostMap.Hosts[current.vpnIp] == primary {
if n.hostMap.Hosts[current.vpnAddrs[0]] == primary {
n.hostMap.unlockedMakePrimary(current)
}
n.hostMap.Unlock()
@ -436,8 +458,9 @@ func (n *connectionManager) isInvalidCertificate(now time.Time, hostinfo *HostIn
return false
}
valid, err := remoteCert.VerifyWithCache(now, n.intf.pki.GetCAPool())
if valid {
caPool := n.intf.pki.GetCAPool()
err := caPool.VerifyCachedCertificate(now, remoteCert)
if err == nil {
return false
}
@ -446,9 +469,8 @@ func (n *connectionManager) isInvalidCertificate(now time.Time, hostinfo *HostIn
return false
}
fingerprint, _ := remoteCert.Sha256Sum()
hostinfo.logger(n.l).WithError(err).
WithField("fingerprint", fingerprint).
WithField("fingerprint", remoteCert.Fingerprint).
Info("Remote certificate is no longer valid, tearing down the tunnel")
return true
@ -473,14 +495,17 @@ func (n *connectionManager) sendPunch(hostinfo *HostInfo) {
}
func (n *connectionManager) tryRehandshake(hostinfo *HostInfo) {
certState := n.intf.pki.GetCertState()
if bytes.Equal(hostinfo.ConnectionState.myCert.Signature, certState.Certificate.Signature) {
cs := n.intf.pki.getCertState()
curCrt := hostinfo.ConnectionState.myCert
myCrt := cs.getCertificate(curCrt.Version())
if curCrt.Version() >= cs.defaultVersion && bytes.Equal(curCrt.Signature(), myCrt.Signature()) == true {
// The current tunnel is using the latest certificate and version, no need to rehandshake.
return
}
n.l.WithField("vpnIp", hostinfo.vpnIp).
n.l.WithField("vpnAddrs", hostinfo.vpnAddrs).
WithField("reason", "local certificate is not current").
Info("Re-handshaking with remote")
n.intf.handshakeManager.StartHandshake(hostinfo.vpnIp, nil)
n.intf.handshakeManager.StartHandshake(hostinfo.vpnAddrs[0], nil)
}

View File

@ -4,7 +4,6 @@ import (
"context"
"crypto/ed25519"
"crypto/rand"
"net"
"net/netip"
"testing"
"time"
@ -15,6 +14,7 @@ import (
"github.com/slackhq/nebula/test"
"github.com/slackhq/nebula/udp"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func newTestLighthouse() *LightHouse {
@ -35,20 +35,19 @@ func newTestLighthouse() *LightHouse {
func Test_NewConnectionManagerTest(t *testing.T) {
l := test.NewLogger()
//_, tuncidr, _ := net.ParseCIDR("1.1.1.1/24")
vpncidr := netip.MustParsePrefix("172.1.1.1/24")
localrange := netip.MustParsePrefix("10.1.1.1/24")
vpnIp := netip.MustParseAddr("172.1.1.2")
preferredRanges := []netip.Prefix{localrange}
// Very incomplete mock objects
hostMap := newHostMap(l, vpncidr)
hostMap := newHostMap(l)
hostMap.preferredRanges.Store(&preferredRanges)
cs := &CertState{
RawCertificate: []byte{},
PrivateKey: []byte{},
Certificate: &cert.NebulaCertificate{},
RawCertificateNoKey: []byte{},
defaultVersion: cert.Version1,
privateKey: []byte{},
v1Cert: &dummyCert{version: cert.Version1},
v1HandshakeBytes: []byte{},
}
lh := newTestLighthouse()
@ -75,12 +74,12 @@ func Test_NewConnectionManagerTest(t *testing.T) {
// Add an ip we have established a connection w/ to hostmap
hostinfo := &HostInfo{
vpnIp: vpnIp,
vpnAddrs: []netip.Addr{vpnIp},
localIndexId: 1099,
remoteIndexId: 9901,
}
hostinfo.ConnectionState = &ConnectionState{
myCert: &cert.NebulaCertificate{},
myCert: &dummyCert{version: cert.Version1},
H: &noise.HandshakeState{},
}
nc.hostMap.unlockedAddHostInfo(hostinfo, ifce)
@ -89,7 +88,7 @@ func Test_NewConnectionManagerTest(t *testing.T) {
nc.Out(hostinfo.localIndexId)
nc.In(hostinfo.localIndexId)
assert.NotContains(t, nc.pendingDeletion, hostinfo.localIndexId)
assert.Contains(t, nc.hostMap.Hosts, hostinfo.vpnIp)
assert.Contains(t, nc.hostMap.Hosts, hostinfo.vpnAddrs[0])
assert.Contains(t, nc.hostMap.Indexes, hostinfo.localIndexId)
assert.Contains(t, nc.out, hostinfo.localIndexId)
@ -106,32 +105,31 @@ func Test_NewConnectionManagerTest(t *testing.T) {
assert.NotContains(t, nc.out, hostinfo.localIndexId)
assert.NotContains(t, nc.in, hostinfo.localIndexId)
assert.Contains(t, nc.hostMap.Indexes, hostinfo.localIndexId)
assert.Contains(t, nc.hostMap.Hosts, hostinfo.vpnIp)
assert.Contains(t, nc.hostMap.Hosts, hostinfo.vpnAddrs[0])
// Do a final traffic check tick, the host should now be removed
nc.doTrafficCheck(hostinfo.localIndexId, p, nb, out, time.Now())
assert.NotContains(t, nc.pendingDeletion, hostinfo.localIndexId)
assert.NotContains(t, nc.hostMap.Hosts, hostinfo.vpnIp)
assert.NotContains(t, nc.hostMap.Hosts, hostinfo.vpnAddrs[0])
assert.NotContains(t, nc.hostMap.Indexes, hostinfo.localIndexId)
}
func Test_NewConnectionManagerTest2(t *testing.T) {
l := test.NewLogger()
//_, tuncidr, _ := net.ParseCIDR("1.1.1.1/24")
vpncidr := netip.MustParsePrefix("172.1.1.1/24")
localrange := netip.MustParsePrefix("10.1.1.1/24")
vpnIp := netip.MustParseAddr("172.1.1.2")
preferredRanges := []netip.Prefix{localrange}
// Very incomplete mock objects
hostMap := newHostMap(l, vpncidr)
hostMap := newHostMap(l)
hostMap.preferredRanges.Store(&preferredRanges)
cs := &CertState{
RawCertificate: []byte{},
PrivateKey: []byte{},
Certificate: &cert.NebulaCertificate{},
RawCertificateNoKey: []byte{},
defaultVersion: cert.Version1,
privateKey: []byte{},
v1Cert: &dummyCert{version: cert.Version1},
v1HandshakeBytes: []byte{},
}
lh := newTestLighthouse()
@ -158,12 +156,12 @@ func Test_NewConnectionManagerTest2(t *testing.T) {
// Add an ip we have established a connection w/ to hostmap
hostinfo := &HostInfo{
vpnIp: vpnIp,
vpnAddrs: []netip.Addr{vpnIp},
localIndexId: 1099,
remoteIndexId: 9901,
}
hostinfo.ConnectionState = &ConnectionState{
myCert: &cert.NebulaCertificate{},
myCert: &dummyCert{version: cert.Version1},
H: &noise.HandshakeState{},
}
nc.hostMap.unlockedAddHostInfo(hostinfo, ifce)
@ -171,8 +169,8 @@ func Test_NewConnectionManagerTest2(t *testing.T) {
// We saw traffic out to vpnIp
nc.Out(hostinfo.localIndexId)
nc.In(hostinfo.localIndexId)
assert.NotContains(t, nc.pendingDeletion, hostinfo.vpnIp)
assert.Contains(t, nc.hostMap.Hosts, hostinfo.vpnIp)
assert.NotContains(t, nc.pendingDeletion, hostinfo.vpnAddrs[0])
assert.Contains(t, nc.hostMap.Hosts, hostinfo.vpnAddrs[0])
assert.Contains(t, nc.hostMap.Indexes, hostinfo.localIndexId)
// Do a traffic check tick, should not be pending deletion but should not have any in/out packets recorded
@ -188,7 +186,7 @@ func Test_NewConnectionManagerTest2(t *testing.T) {
assert.NotContains(t, nc.out, hostinfo.localIndexId)
assert.NotContains(t, nc.in, hostinfo.localIndexId)
assert.Contains(t, nc.hostMap.Indexes, hostinfo.localIndexId)
assert.Contains(t, nc.hostMap.Hosts, hostinfo.vpnIp)
assert.Contains(t, nc.hostMap.Hosts, hostinfo.vpnAddrs[0])
// We saw traffic, should no longer be pending deletion
nc.In(hostinfo.localIndexId)
@ -197,7 +195,7 @@ func Test_NewConnectionManagerTest2(t *testing.T) {
assert.NotContains(t, nc.out, hostinfo.localIndexId)
assert.NotContains(t, nc.in, hostinfo.localIndexId)
assert.Contains(t, nc.hostMap.Indexes, hostinfo.localIndexId)
assert.Contains(t, nc.hostMap.Hosts, hostinfo.vpnIp)
assert.Contains(t, nc.hostMap.Hosts, hostinfo.vpnAddrs[0])
}
// Check if we can disconnect the peer.
@ -206,55 +204,48 @@ func Test_NewConnectionManagerTest2(t *testing.T) {
func Test_NewConnectionManagerTest_DisconnectInvalid(t *testing.T) {
now := time.Now()
l := test.NewLogger()
ipNet := net.IPNet{
IP: net.IPv4(172, 1, 1, 2),
Mask: net.IPMask{255, 255, 255, 0},
}
vpncidr := netip.MustParsePrefix("172.1.1.1/24")
localrange := netip.MustParsePrefix("10.1.1.1/24")
vpnIp := netip.MustParseAddr("172.1.1.2")
preferredRanges := []netip.Prefix{localrange}
hostMap := newHostMap(l, vpncidr)
hostMap := newHostMap(l)
hostMap.preferredRanges.Store(&preferredRanges)
// Generate keys for CA and peer's cert.
pubCA, privCA, _ := ed25519.GenerateKey(rand.Reader)
caCert := cert.NebulaCertificate{
Details: cert.NebulaCertificateDetails{
tbs := &cert.TBSCertificate{
Version: 1,
Name: "ca",
IsCA: true,
NotBefore: now,
NotAfter: now.Add(1 * time.Hour),
IsCA: true,
PublicKey: pubCA,
},
}
assert.NoError(t, caCert.Sign(cert.Curve_CURVE25519, privCA))
ncp := &cert.NebulaCAPool{
CAs: cert.NewCAPool().CAs,
}
ncp.CAs["ca"] = &caCert
caCert, err := tbs.Sign(nil, cert.Curve_CURVE25519, privCA)
require.NoError(t, err)
ncp := cert.NewCAPool()
require.NoError(t, ncp.AddCA(caCert))
pubCrt, _, _ := ed25519.GenerateKey(rand.Reader)
peerCert := cert.NebulaCertificate{
Details: cert.NebulaCertificateDetails{
tbs = &cert.TBSCertificate{
Version: 1,
Name: "host",
Ips: []*net.IPNet{&ipNet},
Subnets: []*net.IPNet{},
Networks: []netip.Prefix{vpncidr},
NotBefore: now,
NotAfter: now.Add(60 * time.Second),
PublicKey: pubCrt,
IsCA: false,
Issuer: "ca",
},
}
assert.NoError(t, peerCert.Sign(cert.Curve_CURVE25519, privCA))
peerCert, err := tbs.Sign(caCert, cert.Curve_CURVE25519, privCA)
require.NoError(t, err)
cachedPeerCert, err := ncp.VerifyCertificate(now.Add(time.Second), peerCert)
cs := &CertState{
RawCertificate: []byte{},
PrivateKey: []byte{},
Certificate: &cert.NebulaCertificate{},
RawCertificateNoKey: []byte{},
privateKey: []byte{},
v1Cert: &dummyCert{},
v1HandshakeBytes: []byte{},
}
lh := newTestLighthouse()
@ -280,10 +271,10 @@ func Test_NewConnectionManagerTest_DisconnectInvalid(t *testing.T) {
ifce.connectionManager = nc
hostinfo := &HostInfo{
vpnIp: vpnIp,
vpnAddrs: []netip.Addr{vpnIp},
ConnectionState: &ConnectionState{
myCert: &cert.NebulaCertificate{},
peerCert: &peerCert,
myCert: &dummyCert{},
peerCert: cachedPeerCert,
H: &noise.HandshakeState{},
},
}
@ -303,3 +294,114 @@ func Test_NewConnectionManagerTest_DisconnectInvalid(t *testing.T) {
invalid = nc.isInvalidCertificate(nextTick, hostinfo)
assert.True(t, invalid)
}
type dummyCert struct {
version cert.Version
curve cert.Curve
groups []string
isCa bool
issuer string
name string
networks []netip.Prefix
notAfter time.Time
notBefore time.Time
publicKey []byte
signature []byte
unsafeNetworks []netip.Prefix
}
func (d *dummyCert) Version() cert.Version {
return d.version
}
func (d *dummyCert) Curve() cert.Curve {
return d.curve
}
func (d *dummyCert) Groups() []string {
return d.groups
}
func (d *dummyCert) IsCA() bool {
return d.isCa
}
func (d *dummyCert) Issuer() string {
return d.issuer
}
func (d *dummyCert) Name() string {
return d.name
}
func (d *dummyCert) Networks() []netip.Prefix {
return d.networks
}
func (d *dummyCert) NotAfter() time.Time {
return d.notAfter
}
func (d *dummyCert) NotBefore() time.Time {
return d.notBefore
}
func (d *dummyCert) PublicKey() []byte {
return d.publicKey
}
func (d *dummyCert) Signature() []byte {
return d.signature
}
func (d *dummyCert) UnsafeNetworks() []netip.Prefix {
return d.unsafeNetworks
}
func (d *dummyCert) MarshalForHandshakes() ([]byte, error) {
return nil, nil
}
func (d *dummyCert) Sign(curve cert.Curve, key []byte) error {
return nil
}
func (d *dummyCert) CheckSignature(key []byte) bool {
return true
}
func (d *dummyCert) Expired(t time.Time) bool {
return false
}
func (d *dummyCert) CheckRootConstraints(signer cert.Certificate) error {
return nil
}
func (d *dummyCert) VerifyPrivateKey(curve cert.Curve, key []byte) error {
return nil
}
func (d *dummyCert) String() string {
return ""
}
func (d *dummyCert) Marshal() ([]byte, error) {
return nil, nil
}
func (d *dummyCert) MarshalPEM() ([]byte, error) {
return nil, nil
}
func (d *dummyCert) Fingerprint() (string, error) {
return "", nil
}
func (d *dummyCert) MarshalJSON() ([]byte, error) {
return nil, nil
}
func (d *dummyCert) Copy() cert.Certificate {
return d
}

View File

@ -3,6 +3,7 @@ package nebula
import (
"crypto/rand"
"encoding/json"
"fmt"
"sync"
"sync/atomic"
@ -18,54 +19,54 @@ type ConnectionState struct {
eKey *NebulaCipherState
dKey *NebulaCipherState
H *noise.HandshakeState
myCert *cert.NebulaCertificate
peerCert *cert.NebulaCertificate
myCert cert.Certificate
peerCert *cert.CachedCertificate
initiator bool
messageCounter atomic.Uint64
window *Bits
writeLock sync.Mutex
}
func NewConnectionState(l *logrus.Logger, cipher string, certState *CertState, initiator bool, pattern noise.HandshakePattern, psk []byte, pskStage int) *ConnectionState {
func NewConnectionState(l *logrus.Logger, cs *CertState, crt cert.Certificate, initiator bool, pattern noise.HandshakePattern) (*ConnectionState, error) {
var dhFunc noise.DHFunc
switch certState.Certificate.Details.Curve {
switch crt.Curve() {
case cert.Curve_CURVE25519:
dhFunc = noise.DH25519
case cert.Curve_P256:
if certState.Certificate.Pkcs11Backed {
if cs.pkcs11Backed {
dhFunc = noiseutil.DHP256PKCS11
} else {
dhFunc = noiseutil.DHP256
}
default:
l.Errorf("invalid curve: %s", certState.Certificate.Details.Curve)
return nil
return nil, fmt.Errorf("invalid curve: %s", crt.Curve())
}
var cs noise.CipherSuite
if cipher == "chachapoly" {
cs = noise.NewCipherSuite(dhFunc, noise.CipherChaChaPoly, noise.HashSHA256)
var ncs noise.CipherSuite
if cs.cipher == "chachapoly" {
ncs = noise.NewCipherSuite(dhFunc, noise.CipherChaChaPoly, noise.HashSHA256)
} else {
cs = noise.NewCipherSuite(dhFunc, noiseutil.CipherAESGCM, noise.HashSHA256)
ncs = noise.NewCipherSuite(dhFunc, noiseutil.CipherAESGCM, noise.HashSHA256)
}
static := noise.DHKey{Private: certState.PrivateKey, Public: certState.PublicKey}
static := noise.DHKey{Private: cs.privateKey, Public: crt.PublicKey()}
b := NewBits(ReplayWindow)
// Clear out bit 0, we never transmit it and we don't want it showing as packet loss
// Clear out bit 0, we never transmit it, and we don't want it showing as packet loss
b.Update(l, 0)
hs, err := noise.NewHandshakeState(noise.Config{
CipherSuite: cs,
CipherSuite: ncs,
Random: rand.Reader,
Pattern: pattern,
Initiator: initiator,
StaticKeypair: static,
PresharedKey: psk,
PresharedKeyPlacement: pskStage,
//NOTE: These should come from CertState (pki.go) when we finally implement it
PresharedKey: []byte{},
PresharedKeyPlacement: 0,
})
if err != nil {
return nil
return nil, fmt.Errorf("NewConnectionState: %s", err)
}
// The queue and ready params prevent a counter race that would happen when
@ -74,12 +75,12 @@ func NewConnectionState(l *logrus.Logger, cipher string, certState *CertState, i
H: hs,
initiator: initiator,
window: b,
myCert: certState.Certificate,
myCert: crt,
}
// always start the counter from 2, as packet 1 and packet 2 are handshake packets.
ci.messageCounter.Add(2)
return ci
return ci, nil
}
func (cs *ConnectionState) MarshalJSON() ([]byte, error) {
@ -89,3 +90,7 @@ func (cs *ConnectionState) MarshalJSON() ([]byte, error) {
"message_counter": cs.messageCounter.Load(),
})
}
func (cs *ConnectionState) Curve() cert.Curve {
return cs.myCert.Curve()
}

View File

@ -19,9 +19,9 @@ import (
type controlEach func(h *HostInfo)
type controlHostLister interface {
QueryVpnIp(vpnIp netip.Addr) *HostInfo
QueryVpnAddr(vpnAddr netip.Addr) *HostInfo
ForEachIndex(each controlEach)
ForEachVpnIp(each controlEach)
ForEachVpnAddr(each controlEach)
GetPreferredRanges() []netip.Prefix
}
@ -37,11 +37,11 @@ type Control struct {
}
type ControlHostInfo struct {
VpnIp netip.Addr `json:"vpnIp"`
VpnAddrs []netip.Addr `json:"vpnAddrs"`
LocalIndex uint32 `json:"localIndex"`
RemoteIndex uint32 `json:"remoteIndex"`
RemoteAddrs []netip.AddrPort `json:"remoteAddrs"`
Cert *cert.NebulaCertificate `json:"cert"`
Cert cert.Certificate `json:"cert"`
MessageCounter uint64 `json:"messageCounter"`
CurrentRemote netip.AddrPort `json:"currentRemote"`
CurrentRelaysToMe []netip.Addr `json:"currentRelaysToMe"`
@ -130,15 +130,18 @@ func (c *Control) ListHostmapIndexes(pendingMap bool) []ControlHostInfo {
}
// GetCertByVpnIp returns the authenticated certificate of the given vpn IP, or nil if not found
func (c *Control) GetCertByVpnIp(vpnIp netip.Addr) *cert.NebulaCertificate {
if c.f.myVpnNet.Addr() == vpnIp {
return c.f.pki.GetCertState().Certificate
func (c *Control) GetCertByVpnIp(vpnIp netip.Addr) cert.Certificate {
_, found := c.f.myVpnAddrsTable.Lookup(vpnIp)
if found {
// Only returning the default certificate since its impossible
// for any other host but ourselves to have more than 1
return c.f.pki.getCertState().GetDefaultCertificate().Copy()
}
hi := c.f.hostMap.QueryVpnIp(vpnIp)
hi := c.f.hostMap.QueryVpnAddr(vpnIp)
if hi == nil {
return nil
}
return hi.GetCert()
return hi.GetCert().Certificate.Copy()
}
// CreateTunnel creates a new tunnel to the given vpn ip.
@ -148,7 +151,7 @@ func (c *Control) CreateTunnel(vpnIp netip.Addr) {
// PrintTunnel creates a new tunnel to the given vpn ip.
func (c *Control) PrintTunnel(vpnIp netip.Addr) *ControlHostInfo {
hi := c.f.hostMap.QueryVpnIp(vpnIp)
hi := c.f.hostMap.QueryVpnAddr(vpnIp)
if hi == nil {
return nil
}
@ -165,9 +168,9 @@ func (c *Control) QueryLighthouse(vpnIp netip.Addr) *CacheMap {
return hi.CopyCache()
}
// GetHostInfoByVpnIp returns a single tunnels hostInfo, or nil if not found
// GetHostInfoByVpnAddr returns a single tunnels hostInfo, or nil if not found
// Caller should take care to Unmap() any 4in6 addresses prior to calling.
func (c *Control) GetHostInfoByVpnIp(vpnIp netip.Addr, pending bool) *ControlHostInfo {
func (c *Control) GetHostInfoByVpnAddr(vpnAddr netip.Addr, pending bool) *ControlHostInfo {
var hl controlHostLister
if pending {
hl = c.f.handshakeManager
@ -175,7 +178,7 @@ func (c *Control) GetHostInfoByVpnIp(vpnIp netip.Addr, pending bool) *ControlHos
hl = c.f.hostMap
}
h := hl.QueryVpnIp(vpnIp)
h := hl.QueryVpnAddr(vpnAddr)
if h == nil {
return nil
}
@ -187,7 +190,7 @@ func (c *Control) GetHostInfoByVpnIp(vpnIp netip.Addr, pending bool) *ControlHos
// SetRemoteForTunnel forces a tunnel to use a specific remote
// Caller should take care to Unmap() any 4in6 addresses prior to calling.
func (c *Control) SetRemoteForTunnel(vpnIp netip.Addr, addr netip.AddrPort) *ControlHostInfo {
hostInfo := c.f.hostMap.QueryVpnIp(vpnIp)
hostInfo := c.f.hostMap.QueryVpnAddr(vpnIp)
if hostInfo == nil {
return nil
}
@ -200,7 +203,7 @@ func (c *Control) SetRemoteForTunnel(vpnIp netip.Addr, addr netip.AddrPort) *Con
// CloseTunnel closes a fully established tunnel. If localOnly is false it will notify the remote end as well.
// Caller should take care to Unmap() any 4in6 addresses prior to calling.
func (c *Control) CloseTunnel(vpnIp netip.Addr, localOnly bool) bool {
hostInfo := c.f.hostMap.QueryVpnIp(vpnIp)
hostInfo := c.f.hostMap.QueryVpnAddr(vpnIp)
if hostInfo == nil {
return false
}
@ -224,19 +227,14 @@ func (c *Control) CloseTunnel(vpnIp netip.Addr, localOnly bool) bool {
// CloseAllTunnels is just like CloseTunnel except it goes through and shuts them all down, optionally you can avoid shutting down lighthouse tunnels
// the int returned is a count of tunnels closed
func (c *Control) CloseAllTunnels(excludeLighthouses bool) (closed int) {
//TODO: this is probably better as a function in ConnectionManager or HostMap directly
lighthouses := c.f.lightHouse.GetLighthouses()
shutdown := func(h *HostInfo) {
if excludeLighthouses {
if _, ok := lighthouses[h.vpnIp]; ok {
if excludeLighthouses && c.f.lightHouse.IsAnyLighthouseAddr(h.vpnAddrs) {
return
}
}
c.f.send(header.CloseTunnel, 0, h.ConnectionState, h, []byte{}, make([]byte, 12, 12), make([]byte, mtu))
c.f.closeTunnel(h)
c.l.WithField("vpnIp", h.vpnIp).WithField("udpAddr", h.remote).
c.l.WithField("vpnAddrs", h.vpnAddrs).WithField("udpAddr", h.remote).
Debug("Sending close tunnel message")
closed++
}
@ -246,7 +244,7 @@ func (c *Control) CloseAllTunnels(excludeLighthouses bool) (closed int) {
// Grab the hostMap lock to access the Relays map
c.f.hostMap.Lock()
for _, relayingHost := range c.f.hostMap.Relays {
relayingHosts[relayingHost.vpnIp] = relayingHost
relayingHosts[relayingHost.vpnAddrs[0]] = relayingHost
}
c.f.hostMap.Unlock()
@ -254,7 +252,7 @@ func (c *Control) CloseAllTunnels(excludeLighthouses bool) (closed int) {
// Grab the hostMap lock to access the Hosts map
c.f.hostMap.Lock()
for _, relayHost := range c.f.hostMap.Indexes {
if _, ok := relayingHosts[relayHost.vpnIp]; !ok {
if _, ok := relayingHosts[relayHost.vpnAddrs[0]]; !ok {
hostInfos = append(hostInfos, relayHost)
}
}
@ -274,9 +272,8 @@ func (c *Control) Device() overlay.Device {
}
func copyHostInfo(h *HostInfo, preferredRanges []netip.Prefix) ControlHostInfo {
chi := ControlHostInfo{
VpnIp: h.vpnIp,
VpnAddrs: make([]netip.Addr, len(h.vpnAddrs)),
LocalIndex: h.localIndexId,
RemoteIndex: h.remoteIndexId,
RemoteAddrs: h.remotes.CopyAddrs(preferredRanges),
@ -285,12 +282,16 @@ func copyHostInfo(h *HostInfo, preferredRanges []netip.Prefix) ControlHostInfo {
CurrentRemote: h.remote,
}
for i, a := range h.vpnAddrs {
chi.VpnAddrs[i] = a
}
if h.ConnectionState != nil {
chi.MessageCounter = h.ConnectionState.messageCounter.Load()
}
if c := h.GetCert(); c != nil {
chi.Cert = c.Copy()
chi.Cert = c.Certificate.Copy()
}
return chi
@ -299,7 +300,7 @@ func copyHostInfo(h *HostInfo, preferredRanges []netip.Prefix) ControlHostInfo {
func listHostMapHosts(hl controlHostLister) []ControlHostInfo {
hosts := make([]ControlHostInfo, 0)
pr := hl.GetPreferredRanges()
hl.ForEachVpnIp(func(hostinfo *HostInfo) {
hl.ForEachVpnAddr(func(hostinfo *HostInfo) {
hosts = append(hosts, copyHostInfo(hostinfo, pr))
})
return hosts

View File

@ -5,7 +5,6 @@ import (
"net/netip"
"reflect"
"testing"
"time"
"github.com/sirupsen/logrus"
"github.com/slackhq/nebula/cert"
@ -14,10 +13,13 @@ import (
)
func TestControl_GetHostInfoByVpnIp(t *testing.T) {
//TODO: CERT-V2 with multiple certificate versions we have a problem with this test
// Some certs versions have different characteristics and each version implements their own Copy() func
// which means this is not a good place to test for exposing memory
l := test.NewLogger()
// Special care must be taken to re-use all objects provided to the hostmap and certificate in the expectedInfo object
// To properly ensure we are not exposing core memory to the caller
hm := newHostMap(l, netip.Prefix{})
hm := newHostMap(l)
hm.preferredRanges.Store(&[]netip.Prefix{})
remote1 := netip.MustParseAddrPort("0.0.0.100:4444")
@ -33,41 +35,26 @@ func TestControl_GetHostInfoByVpnIp(t *testing.T) {
Mask: net.IPMask{255, 255, 255, 0},
}
crt := &cert.NebulaCertificate{
Details: cert.NebulaCertificateDetails{
Name: "test",
Ips: []*net.IPNet{&ipNet},
Subnets: []*net.IPNet{},
Groups: []string{"default-group"},
NotBefore: time.Unix(1, 0),
NotAfter: time.Unix(2, 0),
PublicKey: []byte{5, 6, 7, 8},
IsCA: false,
Issuer: "the-issuer",
InvertedGroups: map[string]struct{}{"default-group": {}},
},
Signature: []byte{1, 2, 1, 2, 1, 3},
}
remotes := NewRemoteList(nil)
remotes.unlockedPrependV4(netip.IPv4Unspecified(), NewIp4AndPortFromNetIP(remote1.Addr(), remote1.Port()))
remotes.unlockedPrependV6(netip.IPv4Unspecified(), NewIp6AndPortFromNetIP(remote2.Addr(), remote2.Port()))
remotes := NewRemoteList([]netip.Addr{netip.IPv4Unspecified()}, nil)
remotes.unlockedPrependV4(netip.IPv4Unspecified(), netAddrToProtoV4AddrPort(remote1.Addr(), remote1.Port()))
remotes.unlockedPrependV6(netip.IPv4Unspecified(), netAddrToProtoV6AddrPort(remote2.Addr(), remote2.Port()))
vpnIp, ok := netip.AddrFromSlice(ipNet.IP)
assert.True(t, ok)
crt := &dummyCert{}
hm.unlockedAddHostInfo(&HostInfo{
remote: remote1,
remotes: remotes,
ConnectionState: &ConnectionState{
peerCert: crt,
peerCert: &cert.CachedCertificate{Certificate: crt},
},
remoteIndexId: 200,
localIndexId: 201,
vpnIp: vpnIp,
vpnAddrs: []netip.Addr{vpnIp},
relayState: RelayState{
relays: map[netip.Addr]struct{}{},
relayForByIp: map[netip.Addr]*Relay{},
relayForByAddr: map[netip.Addr]*Relay{},
relayForByIdx: map[uint32]*Relay{},
},
}, &Interface{})
@ -83,10 +70,10 @@ func TestControl_GetHostInfoByVpnIp(t *testing.T) {
},
remoteIndexId: 200,
localIndexId: 201,
vpnIp: vpnIp2,
vpnAddrs: []netip.Addr{vpnIp2},
relayState: RelayState{
relays: map[netip.Addr]struct{}{},
relayForByIp: map[netip.Addr]*Relay{},
relayForByAddr: map[netip.Addr]*Relay{},
relayForByIdx: map[uint32]*Relay{},
},
}, &Interface{})
@ -98,10 +85,10 @@ func TestControl_GetHostInfoByVpnIp(t *testing.T) {
l: logrus.New(),
}
thi := c.GetHostInfoByVpnIp(vpnIp, false)
thi := c.GetHostInfoByVpnAddr(vpnIp, false)
expectedInfo := ControlHostInfo{
VpnIp: vpnIp,
VpnAddrs: []netip.Addr{vpnIp},
LocalIndex: 201,
RemoteIndex: 200,
RemoteAddrs: []netip.AddrPort{remote2, remote1},
@ -113,14 +100,13 @@ func TestControl_GetHostInfoByVpnIp(t *testing.T) {
}
// Make sure we don't have any unexpected fields
assertFields(t, []string{"VpnIp", "LocalIndex", "RemoteIndex", "RemoteAddrs", "Cert", "MessageCounter", "CurrentRemote", "CurrentRelaysToMe", "CurrentRelaysThroughMe"}, thi)
assertFields(t, []string{"VpnAddrs", "LocalIndex", "RemoteIndex", "RemoteAddrs", "Cert", "MessageCounter", "CurrentRemote", "CurrentRelaysToMe", "CurrentRelaysThroughMe"}, thi)
assert.EqualValues(t, &expectedInfo, thi)
//TODO: netip.Addr reuses global memory for zone identifiers which breaks our "no reused memory check" here
//test.AssertDeepCopyEqual(t, &expectedInfo, thi)
test.AssertDeepCopyEqual(t, &expectedInfo, thi)
// Make sure we don't panic if the host info doesn't have a cert yet
assert.NotPanics(t, func() {
thi = c.GetHostInfoByVpnIp(vpnIp2, false)
thi = c.GetHostInfoByVpnAddr(vpnIp2, false)
})
}

View File

@ -6,8 +6,6 @@ package nebula
import (
"net/netip"
"github.com/slackhq/nebula/cert"
"github.com/google/gopacket"
"github.com/google/gopacket/layers"
"github.com/slackhq/nebula/header"
@ -51,15 +49,15 @@ func (c *Control) WaitForTypeByIndex(toIndex uint32, msgType header.MessageType,
// This is necessary if you did not configure static hosts or are not running a lighthouse
func (c *Control) InjectLightHouseAddr(vpnIp netip.Addr, toAddr netip.AddrPort) {
c.f.lightHouse.Lock()
remoteList := c.f.lightHouse.unlockedGetRemoteList(vpnIp)
remoteList := c.f.lightHouse.unlockedGetRemoteList([]netip.Addr{vpnIp})
remoteList.Lock()
defer remoteList.Unlock()
c.f.lightHouse.Unlock()
if toAddr.Addr().Is4() {
remoteList.unlockedPrependV4(vpnIp, NewIp4AndPortFromNetIP(toAddr.Addr(), toAddr.Port()))
remoteList.unlockedPrependV4(vpnIp, netAddrToProtoV4AddrPort(toAddr.Addr(), toAddr.Port()))
} else {
remoteList.unlockedPrependV6(vpnIp, NewIp6AndPortFromNetIP(toAddr.Addr(), toAddr.Port()))
remoteList.unlockedPrependV6(vpnIp, netAddrToProtoV6AddrPort(toAddr.Addr(), toAddr.Port()))
}
}
@ -67,12 +65,12 @@ func (c *Control) InjectLightHouseAddr(vpnIp netip.Addr, toAddr netip.AddrPort)
// This is necessary to inform an initiator of possible relays for communicating with a responder
func (c *Control) InjectRelays(vpnIp netip.Addr, relayVpnIps []netip.Addr) {
c.f.lightHouse.Lock()
remoteList := c.f.lightHouse.unlockedGetRemoteList(vpnIp)
remoteList := c.f.lightHouse.unlockedGetRemoteList([]netip.Addr{vpnIp})
remoteList.Lock()
defer remoteList.Unlock()
c.f.lightHouse.Unlock()
remoteList.unlockedSetRelay(vpnIp, vpnIp, relayVpnIps)
remoteList.unlockedSetRelay(vpnIp, relayVpnIps)
}
// GetFromTun will pull a packet off the tun side of nebula
@ -99,21 +97,42 @@ func (c *Control) InjectUDPPacket(p *udp.Packet) {
}
// InjectTunUDPPacket puts a udp packet on the tun interface. Using UDP here because it's a simpler protocol
func (c *Control) InjectTunUDPPacket(toIp netip.Addr, toPort uint16, fromPort uint16, data []byte) {
//TODO: IPV6-WORK
ip := layers.IPv4{
func (c *Control) InjectTunUDPPacket(toAddr netip.Addr, toPort uint16, fromAddr netip.Addr, fromPort uint16, data []byte) {
serialize := make([]gopacket.SerializableLayer, 0)
var netLayer gopacket.NetworkLayer
if toAddr.Is6() {
if !fromAddr.Is6() {
panic("Cant send ipv6 to ipv4")
}
ip := &layers.IPv6{
Version: 6,
NextHeader: layers.IPProtocolUDP,
SrcIP: fromAddr.Unmap().AsSlice(),
DstIP: toAddr.Unmap().AsSlice(),
}
serialize = append(serialize, ip)
netLayer = ip
} else {
if !fromAddr.Is4() {
panic("Cant send ipv4 to ipv6")
}
ip := &layers.IPv4{
Version: 4,
TTL: 64,
Protocol: layers.IPProtocolUDP,
SrcIP: c.f.inside.Cidr().Addr().Unmap().AsSlice(),
DstIP: toIp.Unmap().AsSlice(),
SrcIP: fromAddr.Unmap().AsSlice(),
DstIP: toAddr.Unmap().AsSlice(),
}
serialize = append(serialize, ip)
netLayer = ip
}
udp := layers.UDP{
SrcPort: layers.UDPPort(fromPort),
DstPort: layers.UDPPort(toPort),
}
err := udp.SetNetworkLayerForChecksum(&ip)
err := udp.SetNetworkLayerForChecksum(netLayer)
if err != nil {
panic(err)
}
@ -123,7 +142,9 @@ func (c *Control) InjectTunUDPPacket(toIp netip.Addr, toPort uint16, fromPort ui
ComputeChecksums: true,
FixLengths: true,
}
err = gopacket.SerializeLayers(buffer, opt, &ip, &udp, gopacket.Payload(data))
serialize = append(serialize, &udp, gopacket.Payload(data))
err = gopacket.SerializeLayers(buffer, opt, serialize...)
if err != nil {
panic(err)
}
@ -131,8 +152,8 @@ func (c *Control) InjectTunUDPPacket(toIp netip.Addr, toPort uint16, fromPort ui
c.f.inside.(*overlay.TestTun).Send(buffer.Bytes())
}
func (c *Control) GetVpnIp() netip.Addr {
return c.f.myVpnNet.Addr()
func (c *Control) GetVpnAddrs() []netip.Addr {
return c.f.myVpnAddrs
}
func (c *Control) GetUDPAddr() netip.AddrPort {
@ -140,7 +161,7 @@ func (c *Control) GetUDPAddr() netip.AddrPort {
}
func (c *Control) KillPendingTunnel(vpnIp netip.Addr) bool {
hostinfo := c.f.handshakeManager.QueryVpnIp(vpnIp)
hostinfo := c.f.handshakeManager.QueryVpnAddr(vpnIp)
if hostinfo == nil {
return false
}
@ -153,8 +174,8 @@ func (c *Control) GetHostmap() *HostMap {
return c.f.hostMap
}
func (c *Control) GetCert() *cert.NebulaCertificate {
return c.f.pki.GetCertState().Certificate
func (c *Control) GetCertState() *CertState {
return c.f.pki.getCertState()
}
func (c *Control) ReHandshake(vpnIp netip.Addr) {

View File

@ -8,6 +8,7 @@ import (
"strings"
"sync"
"github.com/gaissmai/bart"
"github.com/miekg/dns"
"github.com/sirupsen/logrus"
"github.com/slackhq/nebula/config"
@ -21,24 +22,39 @@ var dnsAddr string
type dnsRecords struct {
sync.RWMutex
dnsMap map[string]string
l *logrus.Logger
dnsMap4 map[string]netip.Addr
dnsMap6 map[string]netip.Addr
hostMap *HostMap
myVpnAddrsTable *bart.Table[struct{}]
}
func newDnsRecords(hostMap *HostMap) *dnsRecords {
func newDnsRecords(l *logrus.Logger, cs *CertState, hostMap *HostMap) *dnsRecords {
return &dnsRecords{
dnsMap: make(map[string]string),
l: l,
dnsMap4: make(map[string]netip.Addr),
dnsMap6: make(map[string]netip.Addr),
hostMap: hostMap,
myVpnAddrsTable: cs.myVpnAddrsTable,
}
}
func (d *dnsRecords) Query(data string) string {
func (d *dnsRecords) Query(q uint16, data string) netip.Addr {
data = strings.ToLower(data)
d.RLock()
defer d.RUnlock()
if r, ok := d.dnsMap[strings.ToLower(data)]; ok {
switch q {
case dns.TypeA:
if r, ok := d.dnsMap4[data]; ok {
return r
}
return ""
case dns.TypeAAAA:
if r, ok := d.dnsMap6[data]; ok {
return r
}
}
return netip.Addr{}
}
func (d *dnsRecords) QueryCert(data string) string {
@ -47,7 +63,7 @@ func (d *dnsRecords) QueryCert(data string) string {
return ""
}
hostinfo := d.hostMap.QueryVpnIp(ip)
hostinfo := d.hostMap.QueryVpnAddr(ip)
if hostinfo == nil {
return ""
}
@ -57,43 +73,69 @@ func (d *dnsRecords) QueryCert(data string) string {
return ""
}
cert := q.Details
c := fmt.Sprintf("\"Name: %s\" \"Ips: %s\" \"Subnets %s\" \"Groups %s\" \"NotBefore %s\" \"NotAfter %s\" \"PublicKey %x\" \"IsCA %t\" \"Issuer %s\"", cert.Name, cert.Ips, cert.Subnets, cert.Groups, cert.NotBefore, cert.NotAfter, cert.PublicKey, cert.IsCA, cert.Issuer)
return c
b, err := q.Certificate.MarshalJSON()
if err != nil {
return ""
}
return string(b)
}
func (d *dnsRecords) Add(host, data string) {
// Add adds the first IPv4 and IPv6 address that appears in `addresses` as the record for `host`
func (d *dnsRecords) Add(host string, addresses []netip.Addr) {
host = strings.ToLower(host)
d.Lock()
defer d.Unlock()
d.dnsMap[strings.ToLower(host)] = data
haveV4 := false
haveV6 := false
for _, addr := range addresses {
if addr.Is4() && !haveV4 {
d.dnsMap4[host] = addr
haveV4 = true
} else if addr.Is6() && !haveV6 {
d.dnsMap6[host] = addr
haveV6 = true
}
if haveV4 && haveV6 {
break
}
}
}
func parseQuery(l *logrus.Logger, m *dns.Msg, w dns.ResponseWriter) {
func (d *dnsRecords) isSelfNebulaOrLocalhost(addr string) bool {
a, _, _ := net.SplitHostPort(addr)
b, err := netip.ParseAddr(a)
if err != nil {
return false
}
if b.IsLoopback() {
return true
}
_, found := d.myVpnAddrsTable.Lookup(b)
return found //if we found it in this table, it's good
}
func (d *dnsRecords) parseQuery(m *dns.Msg, w dns.ResponseWriter) {
for _, q := range m.Question {
switch q.Qtype {
case dns.TypeA:
l.Debugf("Query for A %s", q.Name)
ip := dnsR.Query(q.Name)
if ip != "" {
rr, err := dns.NewRR(fmt.Sprintf("%s A %s", q.Name, ip))
case dns.TypeA, dns.TypeAAAA:
qType := dns.TypeToString[q.Qtype]
d.l.Debugf("Query for %s %s", qType, q.Name)
ip := d.Query(q.Qtype, q.Name)
if ip.IsValid() {
rr, err := dns.NewRR(fmt.Sprintf("%s %s %s", q.Name, qType, ip))
if err == nil {
m.Answer = append(m.Answer, rr)
}
}
case dns.TypeTXT:
a, _, _ := net.SplitHostPort(w.RemoteAddr().String())
b, err := netip.ParseAddr(a)
if err != nil {
// We only answer these queries from nebula nodes or localhost
if !d.isSelfNebulaOrLocalhost(w.RemoteAddr().String()) {
return
}
// We don't answer these queries from non nebula nodes or localhost
//l.Debugf("Does %s contain %s", b, dnsR.hostMap.vpnCIDR)
if !dnsR.hostMap.vpnCIDR.Contains(b) && a != "127.0.0.1" {
return
}
l.Debugf("Query for TXT %s", q.Name)
ip := dnsR.QueryCert(q.Name)
d.l.Debugf("Query for TXT %s", q.Name)
ip := d.QueryCert(q.Name)
if ip != "" {
rr, err := dns.NewRR(fmt.Sprintf("%s TXT %s", q.Name, ip))
if err == nil {
@ -108,26 +150,24 @@ func parseQuery(l *logrus.Logger, m *dns.Msg, w dns.ResponseWriter) {
}
}
func handleDnsRequest(l *logrus.Logger, w dns.ResponseWriter, r *dns.Msg) {
func (d *dnsRecords) handleDnsRequest(w dns.ResponseWriter, r *dns.Msg) {
m := new(dns.Msg)
m.SetReply(r)
m.Compress = false
switch r.Opcode {
case dns.OpcodeQuery:
parseQuery(l, m, w)
d.parseQuery(m, w)
}
w.WriteMsg(m)
}
func dnsMain(l *logrus.Logger, hostMap *HostMap, c *config.C) func() {
dnsR = newDnsRecords(hostMap)
func dnsMain(l *logrus.Logger, cs *CertState, hostMap *HostMap, c *config.C) func() {
dnsR = newDnsRecords(l, cs, hostMap)
// attach request handler func
dns.HandleFunc(".", func(w dns.ResponseWriter, r *dns.Msg) {
handleDnsRequest(l, w, r)
})
dns.HandleFunc(".", dnsR.handleDnsRequest)
c.RegisterReloadCallback(func(c *config.C) {
reloadDns(l, c)

View File

@ -1,23 +1,38 @@
package nebula
import (
"net/netip"
"testing"
"github.com/miekg/dns"
"github.com/sirupsen/logrus"
"github.com/slackhq/nebula/config"
"github.com/stretchr/testify/assert"
)
func TestParsequery(t *testing.T) {
//TODO: This test is basically pointless
l := logrus.New()
hostMap := &HostMap{}
ds := newDnsRecords(hostMap)
ds.Add("test.com.com", "1.2.3.4")
ds := newDnsRecords(l, &CertState{}, hostMap)
addrs := []netip.Addr{
netip.MustParseAddr("1.2.3.4"),
netip.MustParseAddr("1.2.3.5"),
netip.MustParseAddr("fd01::24"),
netip.MustParseAddr("fd01::25"),
}
ds.Add("test.com.com", addrs)
m := new(dns.Msg)
m := &dns.Msg{}
m.SetQuestion("test.com.com", dns.TypeA)
ds.parseQuery(m, nil)
assert.NotNil(t, m.Answer)
assert.Equal(t, "1.2.3.4", m.Answer[0].(*dns.A).A.String())
//parseQuery(m)
m = &dns.Msg{}
m.SetQuestion("test.com.com", dns.TypeAAAA)
ds.parseQuery(m, nil)
assert.NotNil(t, m.Answer)
assert.Equal(t, "fd01::24", m.Answer[0].(*dns.AAAA).AAAA.String())
}
func Test_getDnsServerAddr(t *testing.T) {

File diff suppressed because it is too large Load Diff

View File

@ -1,125 +0,0 @@
package e2e
import (
"crypto/rand"
"io"
"net"
"net/netip"
"time"
"github.com/slackhq/nebula/cert"
"golang.org/x/crypto/curve25519"
"golang.org/x/crypto/ed25519"
)
// NewTestCaCert will generate a CA cert
func NewTestCaCert(before, after time.Time, ips, subnets []netip.Prefix, groups []string) (*cert.NebulaCertificate, []byte, []byte, []byte) {
pub, priv, err := ed25519.GenerateKey(rand.Reader)
if before.IsZero() {
before = time.Now().Add(time.Second * -60).Round(time.Second)
}
if after.IsZero() {
after = time.Now().Add(time.Second * 60).Round(time.Second)
}
nc := &cert.NebulaCertificate{
Details: cert.NebulaCertificateDetails{
Name: "test ca",
NotBefore: time.Unix(before.Unix(), 0),
NotAfter: time.Unix(after.Unix(), 0),
PublicKey: pub,
IsCA: true,
InvertedGroups: make(map[string]struct{}),
},
}
if len(ips) > 0 {
nc.Details.Ips = make([]*net.IPNet, len(ips))
for i, ip := range ips {
nc.Details.Ips[i] = &net.IPNet{IP: ip.Addr().AsSlice(), Mask: net.CIDRMask(ip.Bits(), ip.Addr().BitLen())}
}
}
if len(subnets) > 0 {
nc.Details.Subnets = make([]*net.IPNet, len(subnets))
for i, ip := range subnets {
nc.Details.Ips[i] = &net.IPNet{IP: ip.Addr().AsSlice(), Mask: net.CIDRMask(ip.Bits(), ip.Addr().BitLen())}
}
}
if len(groups) > 0 {
nc.Details.Groups = groups
}
err = nc.Sign(cert.Curve_CURVE25519, priv)
if err != nil {
panic(err)
}
pem, err := nc.MarshalToPEM()
if err != nil {
panic(err)
}
return nc, pub, priv, pem
}
// NewTestCert will generate a signed certificate with the provided details.
// Expiry times are defaulted if you do not pass them in
func NewTestCert(ca *cert.NebulaCertificate, key []byte, name string, before, after time.Time, ip netip.Prefix, subnets []netip.Prefix, groups []string) (*cert.NebulaCertificate, []byte, []byte, []byte) {
issuer, err := ca.Sha256Sum()
if err != nil {
panic(err)
}
if before.IsZero() {
before = time.Now().Add(time.Second * -60).Round(time.Second)
}
if after.IsZero() {
after = time.Now().Add(time.Second * 60).Round(time.Second)
}
pub, rawPriv := x25519Keypair()
ipb := ip.Addr().AsSlice()
nc := &cert.NebulaCertificate{
Details: cert.NebulaCertificateDetails{
Name: name,
Ips: []*net.IPNet{{IP: ipb[:], Mask: net.CIDRMask(ip.Bits(), ip.Addr().BitLen())}},
//Subnets: subnets,
Groups: groups,
NotBefore: time.Unix(before.Unix(), 0),
NotAfter: time.Unix(after.Unix(), 0),
PublicKey: pub,
IsCA: false,
Issuer: issuer,
InvertedGroups: make(map[string]struct{}),
},
}
err = nc.Sign(ca.Details.Curve, key)
if err != nil {
panic(err)
}
pem, err := nc.MarshalToPEM()
if err != nil {
panic(err)
}
return nc, pub, cert.MarshalX25519PrivateKey(rawPriv), pem
}
func x25519Keypair() ([]byte, []byte) {
privkey := make([]byte, 32)
if _, err := io.ReadFull(rand.Reader, privkey); err != nil {
panic(err)
}
pubkey, err := curve25519.X25519(privkey, curve25519.Basepoint)
if err != nil {
panic(err)
}
return pubkey, privkey
}

View File

@ -8,6 +8,7 @@ import (
"io"
"net/netip"
"os"
"strings"
"testing"
"time"
@ -17,6 +18,7 @@ import (
"github.com/sirupsen/logrus"
"github.com/slackhq/nebula"
"github.com/slackhq/nebula/cert"
"github.com/slackhq/nebula/cert_test"
"github.com/slackhq/nebula/config"
"github.com/slackhq/nebula/e2e/router"
"github.com/stretchr/testify/assert"
@ -26,27 +28,37 @@ import (
type m map[string]interface{}
// newSimpleServer creates a nebula instance with many assumptions
func newSimpleServer(caCrt *cert.NebulaCertificate, caKey []byte, name string, sVpnIpNet string, overrides m) (*nebula.Control, netip.Prefix, netip.AddrPort, *config.C) {
func newSimpleServer(v cert.Version, caCrt cert.Certificate, caKey []byte, name string, sVpnNetworks string, overrides m) (*nebula.Control, []netip.Prefix, netip.AddrPort, *config.C) {
l := NewTestLogger()
vpnIpNet, err := netip.ParsePrefix(sVpnIpNet)
var vpnNetworks []netip.Prefix
for _, sn := range strings.Split(sVpnNetworks, ",") {
vpnIpNet, err := netip.ParsePrefix(strings.TrimSpace(sn))
if err != nil {
panic(err)
}
vpnNetworks = append(vpnNetworks, vpnIpNet)
}
if len(vpnNetworks) == 0 {
panic("no vpn networks")
}
var udpAddr netip.AddrPort
if vpnIpNet.Addr().Is4() {
budpIp := vpnIpNet.Addr().As4()
if vpnNetworks[0].Addr().Is4() {
budpIp := vpnNetworks[0].Addr().As4()
budpIp[1] -= 128
udpAddr = netip.AddrPortFrom(netip.AddrFrom4(budpIp), 4242)
} else {
budpIp := vpnIpNet.Addr().As16()
budpIp[13] -= 128
budpIp := vpnNetworks[0].Addr().As16()
// beef for funsies
budpIp[2] = 190
budpIp[3] = 239
udpAddr = netip.AddrPortFrom(netip.AddrFrom16(budpIp), 4242)
}
_, _, myPrivKey, myPEM := NewTestCert(caCrt, caKey, name, time.Now(), time.Now().Add(5*time.Minute), vpnIpNet, nil, []string{})
_, _, myPrivKey, myPEM := cert_test.NewTestCert(v, cert.Curve_CURVE25519, caCrt, caKey, name, time.Now(), time.Now().Add(5*time.Minute), vpnNetworks, nil, []string{})
caB, err := caCrt.MarshalToPEM()
caB, err := caCrt.MarshalPEM()
if err != nil {
panic(err)
}
@ -88,11 +100,16 @@ func newSimpleServer(caCrt *cert.NebulaCertificate, caKey []byte, name string, s
}
if overrides != nil {
err = mergo.Merge(&overrides, mc, mergo.WithAppendSlice)
final := m{}
err = mergo.Merge(&final, overrides, mergo.WithAppendSlice)
if err != nil {
panic(err)
}
mc = overrides
err = mergo.Merge(&final, mc, mergo.WithAppendSlice)
if err != nil {
panic(err)
}
mc = final
}
cb, err := yaml.Marshal(mc)
@ -109,7 +126,7 @@ func newSimpleServer(caCrt *cert.NebulaCertificate, caKey []byte, name string, s
panic(err)
}
return control, vpnIpNet, udpAddr, c
return control, vpnNetworks, udpAddr, c
}
type doneCb func()
@ -132,27 +149,28 @@ func deadline(t *testing.T, seconds time.Duration) doneCb {
func assertTunnel(t *testing.T, vpnIpA, vpnIpB netip.Addr, controlA, controlB *nebula.Control, r *router.R) {
// Send a packet from them to me
controlB.InjectTunUDPPacket(vpnIpA, 80, 90, []byte("Hi from B"))
controlB.InjectTunUDPPacket(vpnIpA, 80, vpnIpB, 90, []byte("Hi from B"))
bPacket := r.RouteForAllUntilTxTun(controlA)
assertUdpPacket(t, []byte("Hi from B"), bPacket, vpnIpB, vpnIpA, 90, 80)
// And once more from me to them
controlA.InjectTunUDPPacket(vpnIpB, 80, 90, []byte("Hello from A"))
controlA.InjectTunUDPPacket(vpnIpB, 80, vpnIpA, 90, []byte("Hello from A"))
aPacket := r.RouteForAllUntilTxTun(controlB)
assertUdpPacket(t, []byte("Hello from A"), aPacket, vpnIpA, vpnIpB, 90, 80)
}
func assertHostInfoPair(t *testing.T, addrA, addrB netip.AddrPort, vpnIpA, vpnIpB netip.Addr, controlA, controlB *nebula.Control) {
func assertHostInfoPair(t *testing.T, addrA, addrB netip.AddrPort, vpnNetsA, vpnNetsB []netip.Prefix, controlA, controlB *nebula.Control) {
// Get both host infos
hBinA := controlA.GetHostInfoByVpnIp(vpnIpB, false)
assert.NotNil(t, hBinA, "Host B was not found by vpnIp in controlA")
//TODO: CERT-V2 we may want to loop over each vpnAddr and assert all the things
hBinA := controlA.GetHostInfoByVpnAddr(vpnNetsB[0].Addr(), false)
assert.NotNil(t, hBinA, "Host B was not found by vpnAddr in controlA")
hAinB := controlB.GetHostInfoByVpnIp(vpnIpA, false)
assert.NotNil(t, hAinB, "Host A was not found by vpnIp in controlB")
hAinB := controlB.GetHostInfoByVpnAddr(vpnNetsA[0].Addr(), false)
assert.NotNil(t, hAinB, "Host A was not found by vpnAddr in controlB")
// Check that both vpn and real addr are correct
assert.Equal(t, vpnIpB, hBinA.VpnIp, "Host B VpnIp is wrong in control A")
assert.Equal(t, vpnIpA, hAinB.VpnIp, "Host A VpnIp is wrong in control B")
assert.EqualValues(t, getAddrs(vpnNetsB), hBinA.VpnAddrs, "Host B VpnIp is wrong in control A")
assert.EqualValues(t, getAddrs(vpnNetsA), hAinB.VpnAddrs, "Host A VpnIp is wrong in control B")
assert.Equal(t, addrB, hBinA.CurrentRemote, "Host B remote is wrong in control A")
assert.Equal(t, addrA, hAinB.CurrentRemote, "Host A remote is wrong in control B")
@ -160,25 +178,36 @@ func assertHostInfoPair(t *testing.T, addrA, addrB netip.AddrPort, vpnIpA, vpnIp
// Check that our indexes match
assert.Equal(t, hBinA.LocalIndex, hAinB.RemoteIndex, "Host B local index does not match host A remote index")
assert.Equal(t, hBinA.RemoteIndex, hAinB.LocalIndex, "Host B remote index does not match host A local index")
//TODO: Would be nice to assert this memory
//checkIndexes := func(name string, hm *HostMap, hi *HostInfo) {
// hBbyIndex := hmA.Indexes[hBinA.localIndexId]
// assert.NotNil(t, hBbyIndex, "Could not host info by local index in %s", name)
// assert.Equal(t, &hBbyIndex, &hBinA, "%s Indexes map did not point to the right host info", name)
//
// //TODO: remote indexes are susceptible to collision
// hBbyRemoteIndex := hmA.RemoteIndexes[hBinA.remoteIndexId]
// assert.NotNil(t, hBbyIndex, "Could not host info by remote index in %s", name)
// assert.Equal(t, &hBbyRemoteIndex, &hBinA, "%s RemoteIndexes did not point to the right host info", name)
//}
//
//// Check hostmap indexes too
//checkIndexes("hmA", hmA, hBinA)
//checkIndexes("hmB", hmB, hAinB)
}
func assertUdpPacket(t *testing.T, expected, b []byte, fromIp, toIp netip.Addr, fromPort, toPort uint16) {
if toIp.Is6() {
assertUdpPacket6(t, expected, b, fromIp, toIp, fromPort, toPort)
} else {
assertUdpPacket4(t, expected, b, fromIp, toIp, fromPort, toPort)
}
}
func assertUdpPacket6(t *testing.T, expected, b []byte, fromIp, toIp netip.Addr, fromPort, toPort uint16) {
packet := gopacket.NewPacket(b, layers.LayerTypeIPv6, gopacket.Lazy)
v6 := packet.Layer(layers.LayerTypeIPv6).(*layers.IPv6)
assert.NotNil(t, v6, "No ipv6 data found")
assert.Equal(t, fromIp.AsSlice(), []byte(v6.SrcIP), "Source ip was incorrect")
assert.Equal(t, toIp.AsSlice(), []byte(v6.DstIP), "Dest ip was incorrect")
udp := packet.Layer(layers.LayerTypeUDP).(*layers.UDP)
assert.NotNil(t, udp, "No udp data found")
assert.Equal(t, fromPort, uint16(udp.SrcPort), "Source port was incorrect")
assert.Equal(t, toPort, uint16(udp.DstPort), "Dest port was incorrect")
data := packet.ApplicationLayer()
assert.NotNil(t, data)
assert.Equal(t, expected, data.Payload(), "Data was incorrect")
}
func assertUdpPacket4(t *testing.T, expected, b []byte, fromIp, toIp netip.Addr, fromPort, toPort uint16) {
packet := gopacket.NewPacket(b, layers.LayerTypeIPv4, gopacket.Lazy)
v4 := packet.Layer(layers.LayerTypeIPv4).(*layers.IPv4)
assert.NotNil(t, v4, "No ipv4 data found")
@ -197,6 +226,14 @@ func assertUdpPacket(t *testing.T, expected, b []byte, fromIp, toIp netip.Addr,
assert.Equal(t, expected, data.Payload(), "Data was incorrect")
}
func getAddrs(ns []netip.Prefix) []netip.Addr {
var a []netip.Addr
for _, n := range ns {
a = append(a, n.Addr())
}
return a
}
func NewTestLogger() *logrus.Logger {
l := logrus.New()

View File

@ -58,8 +58,9 @@ func renderHostmap(c *nebula.Control) (string, []*edge) {
var lines []string
var globalLines []*edge
clusterName := strings.Trim(c.GetCert().Details.Name, " ")
clusterVpnIp := c.GetCert().Details.Ips[0].IP
crt := c.GetCertState().GetDefaultCertificate()
clusterName := strings.Trim(crt.Name(), " ")
clusterVpnIp := crt.Networks()[0].Addr()
r := fmt.Sprintf("\tsubgraph %s[\"%s (%s)\"]\n", clusterName, clusterName, clusterVpnIp)
hm := c.GetHostmap()
@ -101,8 +102,8 @@ func renderHostmap(c *nebula.Control) (string, []*edge) {
for _, idx := range indexes {
hi, ok := hm.Indexes[idx]
if ok {
r += fmt.Sprintf("\t\t\t%v.%v[\"%v (%v)\"]\n", clusterName, idx, idx, hi.GetVpnIp())
remoteClusterName := strings.Trim(hi.GetCert().Details.Name, " ")
r += fmt.Sprintf("\t\t\t%v.%v[\"%v (%v)\"]\n", clusterName, idx, idx, hi.GetVpnAddrs())
remoteClusterName := strings.Trim(hi.GetCert().Certificate.Name(), " ")
globalLines = append(globalLines, &edge{from: fmt.Sprintf("%v.%v", clusterName, idx), to: fmt.Sprintf("%v.%v", remoteClusterName, hi.GetRemoteIndex())})
_ = hi
}

View File

@ -10,8 +10,8 @@ import (
"os"
"path/filepath"
"reflect"
"regexp"
"sort"
"strings"
"sync"
"testing"
"time"
@ -136,7 +136,10 @@ func NewR(t testing.TB, controls ...*nebula.Control) *R {
panic("Duplicate listen address: " + addr.String())
}
r.vpnControls[c.GetVpnIp()] = c
for _, vpnAddr := range c.GetVpnAddrs() {
r.vpnControls[vpnAddr] = c
}
r.controls[addr] = c
}
@ -213,11 +216,11 @@ func (r *R) renderFlow() {
continue
}
participants[addr] = struct{}{}
sanAddr := strings.Replace(addr.String(), ":", "-", 1)
sanAddr := normalizeName(addr.String())
participantsVals = append(participantsVals, sanAddr)
fmt.Fprintf(
f, " participant %s as Nebula: %s<br/>UDP: %s\n",
sanAddr, e.packet.from.GetVpnIp(), sanAddr,
sanAddr, e.packet.from.GetVpnAddrs(), sanAddr,
)
}
@ -250,9 +253,9 @@ func (r *R) renderFlow() {
fmt.Fprintf(f,
" %s%s%s: %s(%s), index %v, counter: %v\n",
strings.Replace(p.from.GetUDPAddr().String(), ":", "-", 1),
normalizeName(p.from.GetUDPAddr().String()),
line,
strings.Replace(p.to.GetUDPAddr().String(), ":", "-", 1),
normalizeName(p.to.GetUDPAddr().String()),
h.TypeName(), h.SubTypeName(), h.RemoteIndex, h.MessageCounter,
)
}
@ -267,6 +270,11 @@ func (r *R) renderFlow() {
}
}
func normalizeName(s string) string {
rx := regexp.MustCompile("[\\[\\]\\:]")
return rx.ReplaceAllLiteralString(s, "_")
}
// IgnoreFlow tells the router to stop recording future flows that matches the provided criteria.
// messageType and subType will target nebula underlay packets while tun will target nebula overlay packets
// NOTE: This is a very broad system, if you set tun to true then no more tun traffic will be rendered
@ -303,7 +311,7 @@ func (r *R) RenderHostmaps(title string, controls ...*nebula.Control) {
func (r *R) renderHostmaps(title string) {
c := maps.Values(r.controls)
sort.SliceStable(c, func(i, j int) bool {
return c[i].GetVpnIp().Compare(c[j].GetVpnIp()) > 0
return c[i].GetVpnAddrs()[0].Compare(c[j].GetVpnAddrs()[0]) > 0
})
s := renderHostmaps(c...)
@ -419,10 +427,11 @@ func (r *R) RouteUntilTxTun(sender *nebula.Control, receiver *nebula.Control) []
// Nope, lets push the sender along
case p := <-udpTx:
r.Lock()
c := r.getControl(sender.GetUDPAddr(), p.To, p)
a := sender.GetUDPAddr()
c := r.getControl(a, p.To, p)
if c == nil {
r.Unlock()
panic("No control for udp tx")
panic("No control for udp tx " + a.String())
}
fp := r.unlockedInjectFlow(sender, c, p, false)
c.InjectUDPPacket(p)
@ -475,10 +484,11 @@ func (r *R) RouteForAllUntilTxTun(receiver *nebula.Control) []byte {
} else {
// we are a udp tx, route and continue
p := rx.Interface().(*udp.Packet)
c := r.getControl(cm[x].GetUDPAddr(), p.To, p)
a := cm[x].GetUDPAddr()
c := r.getControl(a, p.To, p)
if c == nil {
r.Unlock()
panic("No control for udp tx")
panic(fmt.Sprintf("No control for udp tx %s", p.To))
}
fp := r.unlockedInjectFlow(cm[x], c, p, false)
c.InjectUDPPacket(p)
@ -711,30 +721,42 @@ func (r *R) getControl(fromAddr, toAddr netip.AddrPort, p *udp.Packet) *nebula.C
}
func (r *R) formatUdpPacket(p *packet) string {
packet := gopacket.NewPacket(p.packet.Data, layers.LayerTypeIPv4, gopacket.Lazy)
v4 := packet.Layer(layers.LayerTypeIPv4).(*layers.IPv4)
if v4 == nil {
panic("not an ipv4 packet")
var packet gopacket.Packet
var srcAddr netip.Addr
packet = gopacket.NewPacket(p.packet.Data, layers.LayerTypeIPv6, gopacket.Lazy)
if packet.ErrorLayer() == nil {
v6 := packet.Layer(layers.LayerTypeIPv6).(*layers.IPv6)
if v6 == nil {
panic("not an ipv6 packet")
}
srcAddr, _ = netip.AddrFromSlice(v6.SrcIP)
} else {
packet = gopacket.NewPacket(p.packet.Data, layers.LayerTypeIPv4, gopacket.Lazy)
v6 := packet.Layer(layers.LayerTypeIPv4).(*layers.IPv4)
if v6 == nil {
panic("not an ipv6 packet")
}
srcAddr, _ = netip.AddrFromSlice(v6.SrcIP)
}
from := "unknown"
srcAddr, _ := netip.AddrFromSlice(v4.SrcIP)
if c, ok := r.vpnControls[srcAddr]; ok {
from = c.GetUDPAddr().String()
}
udp := packet.Layer(layers.LayerTypeUDP).(*layers.UDP)
if udp == nil {
udpLayer := packet.Layer(layers.LayerTypeUDP).(*layers.UDP)
if udpLayer == nil {
panic("not a udp packet")
}
data := packet.ApplicationLayer()
return fmt.Sprintf(
" %s-->>%s: src port: %v<br/>dest port: %v<br/>data: \"%v\"\n",
strings.Replace(from, ":", "-", 1),
strings.Replace(p.to.GetUDPAddr().String(), ":", "-", 1),
udp.SrcPort,
udp.DstPort,
normalizeName(from),
normalizeName(p.to.GetUDPAddr().String()),
udpLayer.SrcPort,
udpLayer.DstPort,
string(data.Payload()),
)
}

View File

@ -13,6 +13,12 @@ pki:
# disconnect_invalid is a toggle to force a client to be disconnected if the certificate is expired or invalid.
#disconnect_invalid: true
# default_version controls which certificate version is used in handshakes.
# This setting only applies if both a v1 and a v2 certificate are configured, in which case it will default to `1`.
# Once all hosts in the mesh are configured with both a v1 and v2 certificate then this should be changed to `2`.
# After all hosts in the mesh are using a v2 certificate then v1 certificates are no longer needed.
# default_version: 1
# The static host map defines a set of hosts with fixed IP addresses on the internet (or any network).
# A host can have multiple fixed IP addresses defined here, and nebula will try each when establishing a tunnel.
# The syntax is:
@ -120,8 +126,8 @@ lighthouse:
# Port Nebula will be listening on. The default here is 4242. For a lighthouse node, the port should be defined,
# however using port 0 will dynamically assign a port and is recommended for roaming nodes.
listen:
# To listen on both any ipv4 and ipv6 use "::"
host: 0.0.0.0
# To listen on only ipv4, use "0.0.0.0"
host: "::"
port: 4242
# Sets the max number of packets to pull from the kernel for each syscall (under systems that support recvmmsg)
# default is 64, does not support reload
@ -138,6 +144,11 @@ listen:
# valid values: always, never, private
# This setting is reloadable.
#send_recv_error: always
# The so_sock option is a Linux-specific feature that allows all outgoing Nebula packets to be tagged with a specific identifier.
# This tagging enables IP rule-based filtering. For example, it supports 0.0.0.0/0 unsafe_routes,
# allowing for more precise routing decisions based on the packet tags. Default is 0 meaning no mark is set.
# This setting is reloadable.
#so_mark: 0
# Routines is the number of thread pairs to run that consume from the tun and UDP queues.
# Currently, this defaults to 1 which means we have 1 tun queue reader and 1
@ -244,7 +255,6 @@ tun:
# in nebula configuration files. Default false, not reloadable.
#use_system_route_table: false
# TODO
# Configure logging level
logging:
# panic, fatal, error, warning, info, or debug. Default is info and is reloadable.
@ -336,10 +346,12 @@ firewall:
# host: `any` or a literal hostname, ie `test-host`
# group: `any` or a literal group name, ie `default-group`
# groups: Same as group but accepts a list of values. Multiple values are AND'd together and a certificate would have to contain all groups to pass
# cidr: a remote CIDR, `0.0.0.0/0` is any.
# local_cidr: a local CIDR, `0.0.0.0/0` is any. This could be used to filter destinations when using unsafe_routes.
# Default is `any` unless the certificate contains subnets and then the default is the ip issued in the certificate
# if `default_local_cidr_any` is false, otherwise its `any`.
# cidr: a remote CIDR, `0.0.0.0/0` is any ipv4 and `::/0` is any ipv6.
# local_cidr: a local CIDR, `0.0.0.0/0` is any ipv4 and `::/0` is any ipv6. This could be used to filter destinations when using unsafe_routes.
# If no unsafe networks are present in the certificate(s) or `default_local_cidr_any` is true then the default is any ipv4 or ipv6 network.
# Otherwise the default is any vpn network assigned to via the certificate.
# `default_local_cidr_any` defaults to false and is deprecated, it will be removed in a future release.
# If there are unsafe routes present its best to set `local_cidr` to whatever best fits the situation.
# ca_name: An issuing CA name
# ca_sha: An issuing CA shasum

View File

@ -22,7 +22,7 @@ import (
)
type FirewallInterface interface {
AddRule(incoming bool, proto uint8, startPort int32, endPort int32, groups []string, host string, ip, localIp netip.Prefix, caName string, caSha string) error
AddRule(incoming bool, proto uint8, startPort int32, endPort int32, groups []string, host string, addr, localAddr netip.Prefix, caName string, caSha string) error
}
type conn struct {
@ -51,10 +51,13 @@ type Firewall struct {
UDPTimeout time.Duration //linux: 180s max
DefaultTimeout time.Duration //linux: 600s
// Used to ensure we don't emit local packets for ips we don't own
localIps *bart.Table[struct{}]
assignedCIDR netip.Prefix
hasSubnets bool
// routableNetworks describes the vpn addresses as well as any unsafe networks issued to us in the certificate.
// The vpn addresses are a full bit match while the unsafe networks only match the prefix
routableNetworks *bart.Table[struct{}]
// assignedNetworks is a list of vpn networks assigned to us in the certificate.
assignedNetworks []netip.Prefix
hasUnsafeNetworks bool
rules string
rulesVersion uint16
@ -67,8 +70,8 @@ type Firewall struct {
}
type firewallMetrics struct {
droppedLocalIP metrics.Counter
droppedRemoteIP metrics.Counter
droppedLocalAddr metrics.Counter
droppedRemoteAddr metrics.Counter
droppedNoRule metrics.Counter
}
@ -126,88 +129,87 @@ type firewallLocalCIDR struct {
}
// NewFirewall creates a new Firewall object. A TimerWheel is created for you from the provided timeouts.
func NewFirewall(l *logrus.Logger, tcpTimeout, UDPTimeout, defaultTimeout time.Duration, c *cert.NebulaCertificate) *Firewall {
// The certificate provided should be the highest version loaded in memory.
func NewFirewall(l *logrus.Logger, tcpTimeout, UDPTimeout, defaultTimeout time.Duration, c cert.Certificate) *Firewall {
//TODO: error on 0 duration
var min, max time.Duration
var tmin, tmax time.Duration
if tcpTimeout < UDPTimeout {
min = tcpTimeout
max = UDPTimeout
tmin = tcpTimeout
tmax = UDPTimeout
} else {
min = UDPTimeout
max = tcpTimeout
tmin = UDPTimeout
tmax = tcpTimeout
}
if defaultTimeout < min {
min = defaultTimeout
} else if defaultTimeout > max {
max = defaultTimeout
if defaultTimeout < tmin {
tmin = defaultTimeout
} else if defaultTimeout > tmax {
tmax = defaultTimeout
}
localIps := new(bart.Table[struct{}])
var assignedCIDR netip.Prefix
var assignedSet bool
for _, ip := range c.Details.Ips {
//TODO: IPV6-WORK the unmap is a bit unfortunate
nip, _ := netip.AddrFromSlice(ip.IP)
nip = nip.Unmap()
nprefix := netip.PrefixFrom(nip, nip.BitLen())
localIps.Insert(nprefix, struct{}{})
if !assignedSet {
// Only grabbing the first one in the cert since any more than that currently has undefined behavior
assignedCIDR = nprefix
assignedSet = true
}
routableNetworks := new(bart.Table[struct{}])
var assignedNetworks []netip.Prefix
for _, network := range c.Networks() {
nprefix := netip.PrefixFrom(network.Addr(), network.Addr().BitLen())
routableNetworks.Insert(nprefix, struct{}{})
assignedNetworks = append(assignedNetworks, network)
}
for _, n := range c.Details.Subnets {
nip, _ := netip.AddrFromSlice(n.IP)
ones, _ := n.Mask.Size()
nip = nip.Unmap()
localIps.Insert(netip.PrefixFrom(nip, ones), struct{}{})
hasUnsafeNetworks := false
for _, n := range c.UnsafeNetworks() {
routableNetworks.Insert(n, struct{}{})
hasUnsafeNetworks = true
}
return &Firewall{
Conntrack: &FirewallConntrack{
Conns: make(map[firewall.Packet]*conn),
TimerWheel: NewTimerWheel[firewall.Packet](min, max),
TimerWheel: NewTimerWheel[firewall.Packet](tmin, tmax),
},
InRules: newFirewallTable(),
OutRules: newFirewallTable(),
TCPTimeout: tcpTimeout,
UDPTimeout: UDPTimeout,
DefaultTimeout: defaultTimeout,
localIps: localIps,
assignedCIDR: assignedCIDR,
hasSubnets: len(c.Details.Subnets) > 0,
routableNetworks: routableNetworks,
assignedNetworks: assignedNetworks,
hasUnsafeNetworks: hasUnsafeNetworks,
l: l,
incomingMetrics: firewallMetrics{
droppedLocalIP: metrics.GetOrRegisterCounter("firewall.incoming.dropped.local_ip", nil),
droppedRemoteIP: metrics.GetOrRegisterCounter("firewall.incoming.dropped.remote_ip", nil),
droppedLocalAddr: metrics.GetOrRegisterCounter("firewall.incoming.dropped.local_addr", nil),
droppedRemoteAddr: metrics.GetOrRegisterCounter("firewall.incoming.dropped.remote_addr", nil),
droppedNoRule: metrics.GetOrRegisterCounter("firewall.incoming.dropped.no_rule", nil),
},
outgoingMetrics: firewallMetrics{
droppedLocalIP: metrics.GetOrRegisterCounter("firewall.outgoing.dropped.local_ip", nil),
droppedRemoteIP: metrics.GetOrRegisterCounter("firewall.outgoing.dropped.remote_ip", nil),
droppedLocalAddr: metrics.GetOrRegisterCounter("firewall.outgoing.dropped.local_addr", nil),
droppedRemoteAddr: metrics.GetOrRegisterCounter("firewall.outgoing.dropped.remote_addr", nil),
droppedNoRule: metrics.GetOrRegisterCounter("firewall.outgoing.dropped.no_rule", nil),
},
}
}
func NewFirewallFromConfig(l *logrus.Logger, nc *cert.NebulaCertificate, c *config.C) (*Firewall, error) {
func NewFirewallFromConfig(l *logrus.Logger, cs *CertState, c *config.C) (*Firewall, error) {
certificate := cs.getCertificate(cert.Version2)
if certificate == nil {
certificate = cs.getCertificate(cert.Version1)
}
if certificate == nil {
panic("No certificate available to reconfigure the firewall")
}
fw := NewFirewall(
l,
c.GetDuration("firewall.conntrack.tcp_timeout", time.Minute*12),
c.GetDuration("firewall.conntrack.udp_timeout", time.Minute*3),
c.GetDuration("firewall.conntrack.default_timeout", time.Minute*10),
nc,
certificate,
//TODO: max_connections
)
//TODO: Flip to false after v1.9 release
fw.defaultLocalCIDRAny = c.GetBool("firewall.default_local_cidr_any", true)
fw.defaultLocalCIDRAny = c.GetBool("firewall.default_local_cidr_any", false)
inboundAction := c.GetString("firewall.inbound_action", "drop")
switch inboundAction {
@ -287,7 +289,7 @@ func (f *Firewall) AddRule(incoming bool, proto uint8, startPort int32, endPort
fp = ft.TCP
case firewall.ProtoUDP:
fp = ft.UDP
case firewall.ProtoICMP:
case firewall.ProtoICMP, firewall.ProtoICMPv6:
fp = ft.ICMP
case firewall.ProtoAny:
fp = ft.AnyProto
@ -421,33 +423,31 @@ var ErrNoMatchingRule = errors.New("no matching rule in firewall table")
// Drop returns an error if the packet should be dropped, explaining why. It
// returns nil if the packet should not be dropped.
func (f *Firewall) Drop(fp firewall.Packet, incoming bool, h *HostInfo, caPool *cert.NebulaCAPool, localCache firewall.ConntrackCache) error {
func (f *Firewall) Drop(fp firewall.Packet, incoming bool, h *HostInfo, caPool *cert.CAPool, localCache firewall.ConntrackCache) error {
// Check if we spoke to this tuple, if we did then allow this packet
if f.inConns(fp, h, caPool, localCache) {
return nil
}
// Make sure remote address matches nebula certificate
if remoteCidr := h.remoteCidr; remoteCidr != nil {
//TODO: this would be better if we had a least specific match lookup, could waste time here, need to benchmark since the algo is different
_, ok := remoteCidr.Lookup(fp.RemoteIP)
if h.networks != nil {
_, ok := h.networks.Lookup(fp.RemoteAddr)
if !ok {
f.metrics(incoming).droppedRemoteIP.Inc(1)
f.metrics(incoming).droppedRemoteAddr.Inc(1)
return ErrInvalidRemoteIP
}
} else {
// Simple case: Certificate has one IP and no subnets
if fp.RemoteIP != h.vpnIp {
f.metrics(incoming).droppedRemoteIP.Inc(1)
// Simple case: Certificate has one address and no unsafe networks
if h.vpnAddrs[0] != fp.RemoteAddr {
f.metrics(incoming).droppedRemoteAddr.Inc(1)
return ErrInvalidRemoteIP
}
}
// Make sure we are supposed to be handling this local ip address
//TODO: this would be better if we had a least specific match lookup, could waste time here, need to benchmark since the algo is different
_, ok := f.localIps.Lookup(fp.LocalIP)
_, ok := f.routableNetworks.Lookup(fp.LocalAddr)
if !ok {
f.metrics(incoming).droppedLocalIP.Inc(1)
f.metrics(incoming).droppedLocalAddr.Inc(1)
return ErrInvalidLocalIP
}
@ -492,7 +492,7 @@ func (f *Firewall) EmitStats() {
metrics.GetOrRegisterGauge("firewall.rules.hash", nil).Update(int64(f.GetRuleHashFNV()))
}
func (f *Firewall) inConns(fp firewall.Packet, h *HostInfo, caPool *cert.NebulaCAPool, localCache firewall.ConntrackCache) bool {
func (f *Firewall) inConns(fp firewall.Packet, h *HostInfo, caPool *cert.CAPool, localCache firewall.ConntrackCache) bool {
if localCache != nil {
if _, ok := localCache[fp]; ok {
return true
@ -619,7 +619,7 @@ func (f *Firewall) evict(p firewall.Packet) {
delete(conntrack.Conns, p)
}
func (ft *FirewallTable) match(p firewall.Packet, incoming bool, c *cert.NebulaCertificate, caPool *cert.NebulaCAPool) bool {
func (ft *FirewallTable) match(p firewall.Packet, incoming bool, c *cert.CachedCertificate, caPool *cert.CAPool) bool {
if ft.AnyProto.match(p, incoming, c, caPool) {
return true
}
@ -633,7 +633,7 @@ func (ft *FirewallTable) match(p firewall.Packet, incoming bool, c *cert.NebulaC
if ft.UDP.match(p, incoming, c, caPool) {
return true
}
case firewall.ProtoICMP:
case firewall.ProtoICMP, firewall.ProtoICMPv6:
if ft.ICMP.match(p, incoming, c, caPool) {
return true
}
@ -663,7 +663,7 @@ func (fp firewallPort) addRule(f *Firewall, startPort int32, endPort int32, grou
return nil
}
func (fp firewallPort) match(p firewall.Packet, incoming bool, c *cert.NebulaCertificate, caPool *cert.NebulaCAPool) bool {
func (fp firewallPort) match(p firewall.Packet, incoming bool, c *cert.CachedCertificate, caPool *cert.CAPool) bool {
// We don't have any allowed ports, bail
if fp == nil {
return false
@ -726,7 +726,7 @@ func (fc *FirewallCA) addRule(f *Firewall, groups []string, host string, ip, loc
return nil
}
func (fc *FirewallCA) match(p firewall.Packet, c *cert.NebulaCertificate, caPool *cert.NebulaCAPool) bool {
func (fc *FirewallCA) match(p firewall.Packet, c *cert.CachedCertificate, caPool *cert.CAPool) bool {
if fc == nil {
return false
}
@ -735,18 +735,18 @@ func (fc *FirewallCA) match(p firewall.Packet, c *cert.NebulaCertificate, caPool
return true
}
if t, ok := fc.CAShas[c.Details.Issuer]; ok {
if t, ok := fc.CAShas[c.Certificate.Issuer()]; ok {
if t.match(p, c) {
return true
}
}
s, err := caPool.GetCAForCert(c)
s, err := caPool.GetCAForCert(c.Certificate)
if err != nil {
return false
}
return fc.CANames[s.Details.Name].match(p, c)
return fc.CANames[s.Certificate.Name()].match(p, c)
}
func (fr *FirewallRule) addRule(f *Firewall, groups []string, host string, ip, localCIDR netip.Prefix) error {
@ -826,7 +826,7 @@ func (fr *FirewallRule) isAny(groups []string, host string, ip netip.Prefix) boo
return false
}
func (fr *FirewallRule) match(p firewall.Packet, c *cert.NebulaCertificate) bool {
func (fr *FirewallRule) match(p firewall.Packet, c *cert.CachedCertificate) bool {
if fr == nil {
return false
}
@ -841,7 +841,7 @@ func (fr *FirewallRule) match(p firewall.Packet, c *cert.NebulaCertificate) bool
found := false
for _, g := range sg.Groups {
if _, ok := c.Details.InvertedGroups[g]; !ok {
if _, ok := c.InvertedGroups[g]; !ok {
found = false
break
}
@ -855,42 +855,44 @@ func (fr *FirewallRule) match(p firewall.Packet, c *cert.NebulaCertificate) bool
}
if fr.Hosts != nil {
if flc, ok := fr.Hosts[c.Details.Name]; ok {
if flc, ok := fr.Hosts[c.Certificate.Name()]; ok {
if flc.match(p, c) {
return true
}
}
}
matched := false
prefix := netip.PrefixFrom(p.RemoteIP, p.RemoteIP.BitLen())
fr.CIDR.EachLookupPrefix(prefix, func(prefix netip.Prefix, val *firewallLocalCIDR) bool {
if prefix.Contains(p.RemoteIP) && val.match(p, c) {
matched = true
return false
}
for _, v := range fr.CIDR.Supernets(netip.PrefixFrom(p.RemoteAddr, p.RemoteAddr.BitLen())) {
if v.match(p, c) {
return true
})
return matched
}
}
return false
}
func (flc *firewallLocalCIDR) addRule(f *Firewall, localIp netip.Prefix) error {
if !localIp.IsValid() {
if !f.hasSubnets || f.defaultLocalCIDRAny {
if !f.hasUnsafeNetworks || f.defaultLocalCIDRAny {
flc.Any = true
return nil
}
localIp = f.assignedCIDR
for _, network := range f.assignedNetworks {
flc.LocalCIDR.Insert(network, struct{}{})
}
return nil
} else if localIp.Bits() == 0 {
flc.Any = true
return nil
}
flc.LocalCIDR.Insert(localIp, struct{}{})
return nil
}
func (flc *firewallLocalCIDR) match(p firewall.Packet, c *cert.NebulaCertificate) bool {
func (flc *firewallLocalCIDR) match(p firewall.Packet, c *cert.CachedCertificate) bool {
if flc == nil {
return false
}
@ -899,7 +901,7 @@ func (flc *firewallLocalCIDR) match(p firewall.Packet, c *cert.NebulaCertificate
return true
}
_, ok := flc.LocalCIDR.Lookup(p.LocalIP)
_, ok := flc.LocalCIDR.Lookup(p.LocalAddr)
return ok
}

View File

@ -13,14 +13,15 @@ const (
ProtoTCP = 6
ProtoUDP = 17
ProtoICMP = 1
ProtoICMPv6 = 58
PortAny = 0 // Special value for matching `port: any`
PortFragment = -1 // Special value for matching `port: fragment`
)
type Packet struct {
LocalIP netip.Addr
RemoteIP netip.Addr
LocalAddr netip.Addr
RemoteAddr netip.Addr
LocalPort uint16
RemotePort uint16
Protocol uint8
@ -29,8 +30,8 @@ type Packet struct {
func (fp *Packet) Copy() *Packet {
return &Packet{
LocalIP: fp.LocalIP,
RemoteIP: fp.RemoteIP,
LocalAddr: fp.LocalAddr,
RemoteAddr: fp.RemoteAddr,
LocalPort: fp.LocalPort,
RemotePort: fp.RemotePort,
Protocol: fp.Protocol,
@ -51,8 +52,8 @@ func (fp Packet) MarshalJSON() ([]byte, error) {
proto = fmt.Sprintf("unknown %v", fp.Protocol)
}
return json.Marshal(m{
"LocalIP": fp.LocalIP.String(),
"RemoteIP": fp.RemoteIP.String(),
"LocalAddr": fp.LocalAddr.String(),
"RemoteAddr": fp.RemoteAddr.String(),
"LocalPort": fp.LocalPort,
"RemotePort": fp.RemotePort,
"Protocol": proto,

View File

@ -4,7 +4,6 @@ import (
"bytes"
"errors"
"math"
"net"
"net/netip"
"testing"
"time"
@ -14,11 +13,12 @@ import (
"github.com/slackhq/nebula/firewall"
"github.com/slackhq/nebula/test"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestNewFirewall(t *testing.T) {
l := test.NewLogger()
c := &cert.NebulaCertificate{}
c := &dummyCert{}
fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c)
conntrack := fw.Conntrack
assert.NotNil(t, conntrack)
@ -60,67 +60,67 @@ func TestFirewall_AddRule(t *testing.T) {
ob := &bytes.Buffer{}
l.SetOutput(ob)
c := &cert.NebulaCertificate{}
c := &dummyCert{}
fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c)
assert.NotNil(t, fw.InRules)
assert.NotNil(t, fw.OutRules)
ti, err := netip.ParsePrefix("1.2.3.4/32")
assert.NoError(t, err)
require.NoError(t, err)
assert.Nil(t, fw.AddRule(true, firewall.ProtoTCP, 1, 1, []string{}, "", netip.Prefix{}, netip.Prefix{}, "", ""))
require.NoError(t, fw.AddRule(true, firewall.ProtoTCP, 1, 1, []string{}, "", netip.Prefix{}, netip.Prefix{}, "", ""))
// An empty rule is any
assert.True(t, fw.InRules.TCP[1].Any.Any.Any)
assert.Empty(t, fw.InRules.TCP[1].Any.Groups)
assert.Empty(t, fw.InRules.TCP[1].Any.Hosts)
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
assert.Nil(t, fw.AddRule(true, firewall.ProtoUDP, 1, 1, []string{"g1"}, "", netip.Prefix{}, netip.Prefix{}, "", ""))
require.NoError(t, fw.AddRule(true, firewall.ProtoUDP, 1, 1, []string{"g1"}, "", netip.Prefix{}, netip.Prefix{}, "", ""))
assert.Nil(t, fw.InRules.UDP[1].Any.Any)
assert.Contains(t, fw.InRules.UDP[1].Any.Groups[0].Groups, "g1")
assert.Empty(t, fw.InRules.UDP[1].Any.Hosts)
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
assert.Nil(t, fw.AddRule(true, firewall.ProtoICMP, 1, 1, []string{}, "h1", netip.Prefix{}, netip.Prefix{}, "", ""))
require.NoError(t, fw.AddRule(true, firewall.ProtoICMP, 1, 1, []string{}, "h1", netip.Prefix{}, netip.Prefix{}, "", ""))
assert.Nil(t, fw.InRules.ICMP[1].Any.Any)
assert.Empty(t, fw.InRules.ICMP[1].Any.Groups)
assert.Contains(t, fw.InRules.ICMP[1].Any.Hosts, "h1")
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
assert.Nil(t, fw.AddRule(false, firewall.ProtoAny, 1, 1, []string{}, "", ti, netip.Prefix{}, "", ""))
require.NoError(t, fw.AddRule(false, firewall.ProtoAny, 1, 1, []string{}, "", ti, netip.Prefix{}, "", ""))
assert.Nil(t, fw.OutRules.AnyProto[1].Any.Any)
_, ok := fw.OutRules.AnyProto[1].Any.CIDR.Get(ti)
assert.True(t, ok)
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
assert.Nil(t, fw.AddRule(false, firewall.ProtoAny, 1, 1, []string{}, "", netip.Prefix{}, ti, "", ""))
require.NoError(t, fw.AddRule(false, firewall.ProtoAny, 1, 1, []string{}, "", netip.Prefix{}, ti, "", ""))
assert.NotNil(t, fw.OutRules.AnyProto[1].Any.Any)
_, ok = fw.OutRules.AnyProto[1].Any.Any.LocalCIDR.Get(ti)
assert.True(t, ok)
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
assert.Nil(t, fw.AddRule(true, firewall.ProtoUDP, 1, 1, []string{"g1"}, "", netip.Prefix{}, netip.Prefix{}, "ca-name", ""))
require.NoError(t, fw.AddRule(true, firewall.ProtoUDP, 1, 1, []string{"g1"}, "", netip.Prefix{}, netip.Prefix{}, "ca-name", ""))
assert.Contains(t, fw.InRules.UDP[1].CANames, "ca-name")
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
assert.Nil(t, fw.AddRule(true, firewall.ProtoUDP, 1, 1, []string{"g1"}, "", netip.Prefix{}, netip.Prefix{}, "", "ca-sha"))
require.NoError(t, fw.AddRule(true, firewall.ProtoUDP, 1, 1, []string{"g1"}, "", netip.Prefix{}, netip.Prefix{}, "", "ca-sha"))
assert.Contains(t, fw.InRules.UDP[1].CAShas, "ca-sha")
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
assert.Nil(t, fw.AddRule(false, firewall.ProtoAny, 0, 0, []string{}, "any", netip.Prefix{}, netip.Prefix{}, "", ""))
require.NoError(t, fw.AddRule(false, firewall.ProtoAny, 0, 0, []string{}, "any", netip.Prefix{}, netip.Prefix{}, "", ""))
assert.True(t, fw.OutRules.AnyProto[0].Any.Any.Any)
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
anyIp, err := netip.ParsePrefix("0.0.0.0/0")
assert.NoError(t, err)
require.NoError(t, err)
assert.Nil(t, fw.AddRule(false, firewall.ProtoAny, 0, 0, []string{}, "", anyIp, netip.Prefix{}, "", ""))
require.NoError(t, fw.AddRule(false, firewall.ProtoAny, 0, 0, []string{}, "", anyIp, netip.Prefix{}, "", ""))
assert.True(t, fw.OutRules.AnyProto[0].Any.Any.Any)
// Test error conditions
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
assert.Error(t, fw.AddRule(true, math.MaxUint8, 0, 0, []string{}, "", netip.Prefix{}, netip.Prefix{}, "", ""))
assert.Error(t, fw.AddRule(true, firewall.ProtoAny, 10, 0, []string{}, "", netip.Prefix{}, netip.Prefix{}, "", ""))
require.Error(t, fw.AddRule(true, math.MaxUint8, 0, 0, []string{}, "", netip.Prefix{}, netip.Prefix{}, "", ""))
require.Error(t, fw.AddRule(true, firewall.ProtoAny, 10, 0, []string{}, "", netip.Prefix{}, netip.Prefix{}, "", ""))
}
func TestFirewall_Drop(t *testing.T) {
@ -129,79 +129,74 @@ func TestFirewall_Drop(t *testing.T) {
l.SetOutput(ob)
p := firewall.Packet{
LocalIP: netip.MustParseAddr("1.2.3.4"),
RemoteIP: netip.MustParseAddr("1.2.3.4"),
LocalAddr: netip.MustParseAddr("1.2.3.4"),
RemoteAddr: netip.MustParseAddr("1.2.3.4"),
LocalPort: 10,
RemotePort: 90,
Protocol: firewall.ProtoUDP,
Fragment: false,
}
ipNet := net.IPNet{
IP: net.IPv4(1, 2, 3, 4),
Mask: net.IPMask{255, 255, 255, 0},
}
c := cert.NebulaCertificate{
Details: cert.NebulaCertificateDetails{
Name: "host1",
Ips: []*net.IPNet{&ipNet},
Groups: []string{"default-group"},
InvertedGroups: map[string]struct{}{"default-group": {}},
Issuer: "signer-shasum",
},
c := dummyCert{
name: "host1",
networks: []netip.Prefix{netip.MustParsePrefix("1.2.3.4/24")},
groups: []string{"default-group"},
issuer: "signer-shasum",
}
h := HostInfo{
ConnectionState: &ConnectionState{
peerCert: &c,
peerCert: &cert.CachedCertificate{
Certificate: &c,
InvertedGroups: map[string]struct{}{"default-group": {}},
},
vpnIp: netip.MustParseAddr("1.2.3.4"),
},
vpnAddrs: []netip.Addr{netip.MustParseAddr("1.2.3.4")},
}
h.CreateRemoteCIDR(&c)
h.buildNetworks(c.networks, c.unsafeNetworks)
fw := NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"any"}, "", netip.Prefix{}, netip.Prefix{}, "", ""))
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"any"}, "", netip.Prefix{}, netip.Prefix{}, "", ""))
cp := cert.NewCAPool()
// Drop outbound
assert.Equal(t, ErrNoMatchingRule, fw.Drop(p, false, &h, cp, nil))
// Allow inbound
resetConntrack(fw)
assert.NoError(t, fw.Drop(p, true, &h, cp, nil))
require.NoError(t, fw.Drop(p, true, &h, cp, nil))
// Allow outbound because conntrack
assert.NoError(t, fw.Drop(p, false, &h, cp, nil))
require.NoError(t, fw.Drop(p, false, &h, cp, nil))
// test remote mismatch
oldRemote := p.RemoteIP
p.RemoteIP = netip.MustParseAddr("1.2.3.10")
oldRemote := p.RemoteAddr
p.RemoteAddr = netip.MustParseAddr("1.2.3.10")
assert.Equal(t, fw.Drop(p, false, &h, cp, nil), ErrInvalidRemoteIP)
p.RemoteIP = oldRemote
p.RemoteAddr = oldRemote
// ensure signer doesn't get in the way of group checks
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", netip.Prefix{}, netip.Prefix{}, "", "signer-shasum"))
assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", netip.Prefix{}, netip.Prefix{}, "", "signer-shasum-bad"))
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", netip.Prefix{}, netip.Prefix{}, "", "signer-shasum"))
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", netip.Prefix{}, netip.Prefix{}, "", "signer-shasum-bad"))
assert.Equal(t, fw.Drop(p, true, &h, cp, nil), ErrNoMatchingRule)
// test caSha doesn't drop on match
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", netip.Prefix{}, netip.Prefix{}, "", "signer-shasum-bad"))
assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", netip.Prefix{}, netip.Prefix{}, "", "signer-shasum"))
assert.NoError(t, fw.Drop(p, true, &h, cp, nil))
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", netip.Prefix{}, netip.Prefix{}, "", "signer-shasum-bad"))
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", netip.Prefix{}, netip.Prefix{}, "", "signer-shasum"))
require.NoError(t, fw.Drop(p, true, &h, cp, nil))
// ensure ca name doesn't get in the way of group checks
cp.CAs["signer-shasum"] = &cert.NebulaCertificate{Details: cert.NebulaCertificateDetails{Name: "ca-good"}}
cp.CAs["signer-shasum"] = &cert.CachedCertificate{Certificate: &dummyCert{name: "ca-good"}}
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", netip.Prefix{}, netip.Prefix{}, "ca-good", ""))
assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", netip.Prefix{}, netip.Prefix{}, "ca-good-bad", ""))
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", netip.Prefix{}, netip.Prefix{}, "ca-good", ""))
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", netip.Prefix{}, netip.Prefix{}, "ca-good-bad", ""))
assert.Equal(t, fw.Drop(p, true, &h, cp, nil), ErrNoMatchingRule)
// test caName doesn't drop on match
cp.CAs["signer-shasum"] = &cert.NebulaCertificate{Details: cert.NebulaCertificateDetails{Name: "ca-good"}}
cp.CAs["signer-shasum"] = &cert.CachedCertificate{Certificate: &dummyCert{name: "ca-good"}}
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", netip.Prefix{}, netip.Prefix{}, "ca-good-bad", ""))
assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", netip.Prefix{}, netip.Prefix{}, "ca-good", ""))
assert.NoError(t, fw.Drop(p, true, &h, cp, nil))
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", netip.Prefix{}, netip.Prefix{}, "ca-good-bad", ""))
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", netip.Prefix{}, netip.Prefix{}, "ca-good", ""))
require.NoError(t, fw.Drop(p, true, &h, cp, nil))
}
func BenchmarkFirewallTable_match(b *testing.B) {
@ -217,7 +212,9 @@ func BenchmarkFirewallTable_match(b *testing.B) {
b.Run("fail on proto", func(b *testing.B) {
// This benchmark is showing us the cost of failing to match the protocol
c := &cert.NebulaCertificate{}
c := &cert.CachedCertificate{
Certificate: &dummyCert{},
}
for n := 0; n < b.N; n++ {
assert.False(b, ft.match(firewall.Packet{Protocol: firewall.ProtoUDP}, true, c, cp))
}
@ -225,28 +222,31 @@ func BenchmarkFirewallTable_match(b *testing.B) {
b.Run("pass proto, fail on port", func(b *testing.B) {
// This benchmark is showing us the cost of matching a specific protocol but failing to match the port
c := &cert.NebulaCertificate{}
c := &cert.CachedCertificate{
Certificate: &dummyCert{},
}
for n := 0; n < b.N; n++ {
assert.False(b, ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 1}, true, c, cp))
}
})
b.Run("pass proto, port, fail on local CIDR", func(b *testing.B) {
c := &cert.NebulaCertificate{}
c := &cert.CachedCertificate{
Certificate: &dummyCert{},
}
ip := netip.MustParsePrefix("9.254.254.254/32")
for n := 0; n < b.N; n++ {
assert.False(b, ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 100, LocalIP: ip.Addr()}, true, c, cp))
assert.False(b, ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 100, LocalAddr: ip.Addr()}, true, c, cp))
}
})
b.Run("pass proto, port, any local CIDR, fail all group, name, and cidr", func(b *testing.B) {
_, ip, _ := net.ParseCIDR("9.254.254.254/32")
c := &cert.NebulaCertificate{
Details: cert.NebulaCertificateDetails{
InvertedGroups: map[string]struct{}{"nope": {}},
Name: "nope",
Ips: []*net.IPNet{ip},
c := &cert.CachedCertificate{
Certificate: &dummyCert{
name: "nope",
networks: []netip.Prefix{netip.MustParsePrefix("9.254.254.245/32")},
},
InvertedGroups: map[string]struct{}{"nope": {}},
}
for n := 0; n < b.N; n++ {
assert.False(b, ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 10}, true, c, cp))
@ -254,25 +254,24 @@ func BenchmarkFirewallTable_match(b *testing.B) {
})
b.Run("pass proto, port, specific local CIDR, fail all group, name, and cidr", func(b *testing.B) {
_, ip, _ := net.ParseCIDR("9.254.254.254/32")
c := &cert.NebulaCertificate{
Details: cert.NebulaCertificateDetails{
InvertedGroups: map[string]struct{}{"nope": {}},
Name: "nope",
Ips: []*net.IPNet{ip},
c := &cert.CachedCertificate{
Certificate: &dummyCert{
name: "nope",
networks: []netip.Prefix{netip.MustParsePrefix("9.254.254.245/32")},
},
InvertedGroups: map[string]struct{}{"nope": {}},
}
for n := 0; n < b.N; n++ {
assert.False(b, ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 100, LocalIP: pfix.Addr()}, true, c, cp))
assert.False(b, ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 100, LocalAddr: pfix.Addr()}, true, c, cp))
}
})
b.Run("pass on group on any local cidr", func(b *testing.B) {
c := &cert.NebulaCertificate{
Details: cert.NebulaCertificateDetails{
InvertedGroups: map[string]struct{}{"good-group": {}},
Name: "nope",
c := &cert.CachedCertificate{
Certificate: &dummyCert{
name: "nope",
},
InvertedGroups: map[string]struct{}{"good-group": {}},
}
for n := 0; n < b.N; n++ {
assert.True(b, ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 10}, true, c, cp))
@ -280,82 +279,28 @@ func BenchmarkFirewallTable_match(b *testing.B) {
})
b.Run("pass on group on specific local cidr", func(b *testing.B) {
c := &cert.NebulaCertificate{
Details: cert.NebulaCertificateDetails{
InvertedGroups: map[string]struct{}{"good-group": {}},
Name: "nope",
c := &cert.CachedCertificate{
Certificate: &dummyCert{
name: "nope",
},
InvertedGroups: map[string]struct{}{"good-group": {}},
}
for n := 0; n < b.N; n++ {
assert.True(b, ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 100, LocalIP: pfix.Addr()}, true, c, cp))
assert.True(b, ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 100, LocalAddr: pfix.Addr()}, true, c, cp))
}
})
b.Run("pass on name", func(b *testing.B) {
c := &cert.NebulaCertificate{
Details: cert.NebulaCertificateDetails{
InvertedGroups: map[string]struct{}{"nope": {}},
Name: "good-host",
c := &cert.CachedCertificate{
Certificate: &dummyCert{
name: "good-host",
},
InvertedGroups: map[string]struct{}{"nope": {}},
}
for n := 0; n < b.N; n++ {
ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 10}, true, c, cp)
}
})
//
//b.Run("pass on ip", func(b *testing.B) {
// ip := iputil.Ip2VpnIp(net.IPv4(172, 1, 1, 1))
// c := &cert.NebulaCertificate{
// Details: cert.NebulaCertificateDetails{
// InvertedGroups: map[string]struct{}{"nope": {}},
// Name: "good-host",
// },
// }
// for n := 0; n < b.N; n++ {
// ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 10, RemoteIP: ip}, true, c, cp)
// }
//})
//
//b.Run("pass on local ip", func(b *testing.B) {
// ip := iputil.Ip2VpnIp(net.IPv4(172, 1, 1, 1))
// c := &cert.NebulaCertificate{
// Details: cert.NebulaCertificateDetails{
// InvertedGroups: map[string]struct{}{"nope": {}},
// Name: "good-host",
// },
// }
// for n := 0; n < b.N; n++ {
// ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 10, LocalIP: ip}, true, c, cp)
// }
//})
//
//_ = ft.TCP.addRule(0, 0, []string{"good-group"}, "good-host", n, n, "", "")
//
//b.Run("pass on ip with any port", func(b *testing.B) {
// ip := iputil.Ip2VpnIp(net.IPv4(172, 1, 1, 1))
// c := &cert.NebulaCertificate{
// Details: cert.NebulaCertificateDetails{
// InvertedGroups: map[string]struct{}{"nope": {}},
// Name: "good-host",
// },
// }
// for n := 0; n < b.N; n++ {
// ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 100, RemoteIP: ip}, true, c, cp)
// }
//})
//
//b.Run("pass on local ip with any port", func(b *testing.B) {
// ip := iputil.Ip2VpnIp(net.IPv4(172, 1, 1, 1))
// c := &cert.NebulaCertificate{
// Details: cert.NebulaCertificateDetails{
// InvertedGroups: map[string]struct{}{"nope": {}},
// Name: "good-host",
// },
// }
// for n := 0; n < b.N; n++ {
// ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 100, LocalIP: ip}, true, c, cp)
// }
//})
}
func TestFirewall_Drop2(t *testing.T) {
@ -364,57 +309,55 @@ func TestFirewall_Drop2(t *testing.T) {
l.SetOutput(ob)
p := firewall.Packet{
LocalIP: netip.MustParseAddr("1.2.3.4"),
RemoteIP: netip.MustParseAddr("1.2.3.4"),
LocalAddr: netip.MustParseAddr("1.2.3.4"),
RemoteAddr: netip.MustParseAddr("1.2.3.4"),
LocalPort: 10,
RemotePort: 90,
Protocol: firewall.ProtoUDP,
Fragment: false,
}
ipNet := net.IPNet{
IP: net.IPv4(1, 2, 3, 4),
Mask: net.IPMask{255, 255, 255, 0},
}
network := netip.MustParsePrefix("1.2.3.4/24")
c := cert.NebulaCertificate{
Details: cert.NebulaCertificateDetails{
Name: "host1",
Ips: []*net.IPNet{&ipNet},
InvertedGroups: map[string]struct{}{"default-group": {}, "test-group": {}},
c := cert.CachedCertificate{
Certificate: &dummyCert{
name: "host1",
networks: []netip.Prefix{network},
},
InvertedGroups: map[string]struct{}{"default-group": {}, "test-group": {}},
}
h := HostInfo{
ConnectionState: &ConnectionState{
peerCert: &c,
},
vpnIp: netip.MustParseAddr(ipNet.IP.String()),
vpnAddrs: []netip.Addr{network.Addr()},
}
h.CreateRemoteCIDR(&c)
h.buildNetworks(c.Certificate.Networks(), c.Certificate.UnsafeNetworks())
c1 := cert.NebulaCertificate{
Details: cert.NebulaCertificateDetails{
Name: "host1",
Ips: []*net.IPNet{&ipNet},
InvertedGroups: map[string]struct{}{"default-group": {}, "test-group-not": {}},
c1 := cert.CachedCertificate{
Certificate: &dummyCert{
name: "host1",
networks: []netip.Prefix{network},
},
InvertedGroups: map[string]struct{}{"default-group": {}, "test-group-not": {}},
}
h1 := HostInfo{
vpnAddrs: []netip.Addr{network.Addr()},
ConnectionState: &ConnectionState{
peerCert: &c1,
},
}
h1.CreateRemoteCIDR(&c1)
h1.buildNetworks(c1.Certificate.Networks(), c1.Certificate.UnsafeNetworks())
fw := NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group", "test-group"}, "", netip.Prefix{}, netip.Prefix{}, "", ""))
fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate)
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group", "test-group"}, "", netip.Prefix{}, netip.Prefix{}, "", ""))
cp := cert.NewCAPool()
// h1/c1 lacks the proper groups
assert.Error(t, fw.Drop(p, true, &h1, cp, nil), ErrNoMatchingRule)
require.ErrorIs(t, fw.Drop(p, true, &h1, cp, nil), ErrNoMatchingRule)
// c has the proper groups
resetConntrack(fw)
assert.NoError(t, fw.Drop(p, true, &h, cp, nil))
require.NoError(t, fw.Drop(p, true, &h, cp, nil))
}
func TestFirewall_Drop3(t *testing.T) {
@ -423,84 +366,85 @@ func TestFirewall_Drop3(t *testing.T) {
l.SetOutput(ob)
p := firewall.Packet{
LocalIP: netip.MustParseAddr("1.2.3.4"),
RemoteIP: netip.MustParseAddr("1.2.3.4"),
LocalAddr: netip.MustParseAddr("1.2.3.4"),
RemoteAddr: netip.MustParseAddr("1.2.3.4"),
LocalPort: 1,
RemotePort: 1,
Protocol: firewall.ProtoUDP,
Fragment: false,
}
ipNet := net.IPNet{
IP: net.IPv4(1, 2, 3, 4),
Mask: net.IPMask{255, 255, 255, 0},
}
c := cert.NebulaCertificate{
Details: cert.NebulaCertificateDetails{
Name: "host-owner",
Ips: []*net.IPNet{&ipNet},
network := netip.MustParsePrefix("1.2.3.4/24")
c := cert.CachedCertificate{
Certificate: &dummyCert{
name: "host-owner",
networks: []netip.Prefix{network},
},
}
c1 := cert.NebulaCertificate{
Details: cert.NebulaCertificateDetails{
Name: "host1",
Ips: []*net.IPNet{&ipNet},
Issuer: "signer-sha-bad",
c1 := cert.CachedCertificate{
Certificate: &dummyCert{
name: "host1",
networks: []netip.Prefix{network},
issuer: "signer-sha-bad",
},
}
h1 := HostInfo{
ConnectionState: &ConnectionState{
peerCert: &c1,
},
vpnIp: netip.MustParseAddr(ipNet.IP.String()),
vpnAddrs: []netip.Addr{network.Addr()},
}
h1.CreateRemoteCIDR(&c1)
h1.buildNetworks(c1.Certificate.Networks(), c1.Certificate.UnsafeNetworks())
c2 := cert.NebulaCertificate{
Details: cert.NebulaCertificateDetails{
Name: "host2",
Ips: []*net.IPNet{&ipNet},
Issuer: "signer-sha",
c2 := cert.CachedCertificate{
Certificate: &dummyCert{
name: "host2",
networks: []netip.Prefix{network},
issuer: "signer-sha",
},
}
h2 := HostInfo{
ConnectionState: &ConnectionState{
peerCert: &c2,
},
vpnIp: netip.MustParseAddr(ipNet.IP.String()),
vpnAddrs: []netip.Addr{network.Addr()},
}
h2.CreateRemoteCIDR(&c2)
h2.buildNetworks(c2.Certificate.Networks(), c2.Certificate.UnsafeNetworks())
c3 := cert.NebulaCertificate{
Details: cert.NebulaCertificateDetails{
Name: "host3",
Ips: []*net.IPNet{&ipNet},
Issuer: "signer-sha-bad",
c3 := cert.CachedCertificate{
Certificate: &dummyCert{
name: "host3",
networks: []netip.Prefix{network},
issuer: "signer-sha-bad",
},
}
h3 := HostInfo{
ConnectionState: &ConnectionState{
peerCert: &c3,
},
vpnIp: netip.MustParseAddr(ipNet.IP.String()),
vpnAddrs: []netip.Addr{network.Addr()},
}
h3.CreateRemoteCIDR(&c3)
h3.buildNetworks(c3.Certificate.Networks(), c3.Certificate.UnsafeNetworks())
fw := NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 1, 1, []string{}, "host1", netip.Prefix{}, netip.Prefix{}, "", ""))
assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 1, 1, []string{}, "", netip.Prefix{}, netip.Prefix{}, "", "signer-sha"))
fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate)
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 1, 1, []string{}, "host1", netip.Prefix{}, netip.Prefix{}, "", ""))
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 1, 1, []string{}, "", netip.Prefix{}, netip.Prefix{}, "", "signer-sha"))
cp := cert.NewCAPool()
// c1 should pass because host match
assert.NoError(t, fw.Drop(p, true, &h1, cp, nil))
require.NoError(t, fw.Drop(p, true, &h1, cp, nil))
// c2 should pass because ca sha match
resetConntrack(fw)
assert.NoError(t, fw.Drop(p, true, &h2, cp, nil))
require.NoError(t, fw.Drop(p, true, &h2, cp, nil))
// c3 should fail because no match
resetConntrack(fw)
assert.Equal(t, fw.Drop(p, true, &h3, cp, nil), ErrNoMatchingRule)
// Test a remote address match
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate)
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 1, 1, []string{}, "", netip.MustParsePrefix("1.2.3.4/24"), netip.Prefix{}, "", ""))
require.NoError(t, fw.Drop(p, true, &h1, cp, nil))
}
func TestFirewall_DropConntrackReload(t *testing.T) {
@ -509,60 +453,56 @@ func TestFirewall_DropConntrackReload(t *testing.T) {
l.SetOutput(ob)
p := firewall.Packet{
LocalIP: netip.MustParseAddr("1.2.3.4"),
RemoteIP: netip.MustParseAddr("1.2.3.4"),
LocalAddr: netip.MustParseAddr("1.2.3.4"),
RemoteAddr: netip.MustParseAddr("1.2.3.4"),
LocalPort: 10,
RemotePort: 90,
Protocol: firewall.ProtoUDP,
Fragment: false,
}
network := netip.MustParsePrefix("1.2.3.4/24")
ipNet := net.IPNet{
IP: net.IPv4(1, 2, 3, 4),
Mask: net.IPMask{255, 255, 255, 0},
}
c := cert.NebulaCertificate{
Details: cert.NebulaCertificateDetails{
Name: "host1",
Ips: []*net.IPNet{&ipNet},
Groups: []string{"default-group"},
InvertedGroups: map[string]struct{}{"default-group": {}},
Issuer: "signer-shasum",
c := cert.CachedCertificate{
Certificate: &dummyCert{
name: "host1",
networks: []netip.Prefix{network},
groups: []string{"default-group"},
issuer: "signer-shasum",
},
InvertedGroups: map[string]struct{}{"default-group": {}},
}
h := HostInfo{
ConnectionState: &ConnectionState{
peerCert: &c,
},
vpnIp: netip.MustParseAddr(ipNet.IP.String()),
vpnAddrs: []netip.Addr{network.Addr()},
}
h.CreateRemoteCIDR(&c)
h.buildNetworks(c.Certificate.Networks(), c.Certificate.UnsafeNetworks())
fw := NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"any"}, "", netip.Prefix{}, netip.Prefix{}, "", ""))
fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate)
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"any"}, "", netip.Prefix{}, netip.Prefix{}, "", ""))
cp := cert.NewCAPool()
// Drop outbound
assert.Equal(t, fw.Drop(p, false, &h, cp, nil), ErrNoMatchingRule)
// Allow inbound
resetConntrack(fw)
assert.NoError(t, fw.Drop(p, true, &h, cp, nil))
require.NoError(t, fw.Drop(p, true, &h, cp, nil))
// Allow outbound because conntrack
assert.NoError(t, fw.Drop(p, false, &h, cp, nil))
require.NoError(t, fw.Drop(p, false, &h, cp, nil))
oldFw := fw
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 10, 10, []string{"any"}, "", netip.Prefix{}, netip.Prefix{}, "", ""))
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate)
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 10, 10, []string{"any"}, "", netip.Prefix{}, netip.Prefix{}, "", ""))
fw.Conntrack = oldFw.Conntrack
fw.rulesVersion = oldFw.rulesVersion + 1
// Allow outbound because conntrack and new rules allow port 10
assert.NoError(t, fw.Drop(p, false, &h, cp, nil))
require.NoError(t, fw.Drop(p, false, &h, cp, nil))
oldFw = fw
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 11, 11, []string{"any"}, "", netip.Prefix{}, netip.Prefix{}, "", ""))
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate)
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 11, 11, []string{"any"}, "", netip.Prefix{}, netip.Prefix{}, "", ""))
fw.Conntrack = oldFw.Conntrack
fw.rulesVersion = oldFw.rulesVersion + 1
@ -641,104 +581,105 @@ func BenchmarkLookup(b *testing.B) {
ml(m, a)
}
})
//TODO: only way array lookup in array will help is if both are sorted, then maybe it's faster
}
func Test_parsePort(t *testing.T) {
_, _, err := parsePort("")
assert.EqualError(t, err, "was not a number; ``")
require.EqualError(t, err, "was not a number; ``")
_, _, err = parsePort(" ")
assert.EqualError(t, err, "was not a number; ` `")
require.EqualError(t, err, "was not a number; ` `")
_, _, err = parsePort("-")
assert.EqualError(t, err, "appears to be a range but could not be parsed; `-`")
require.EqualError(t, err, "appears to be a range but could not be parsed; `-`")
_, _, err = parsePort(" - ")
assert.EqualError(t, err, "appears to be a range but could not be parsed; ` - `")
require.EqualError(t, err, "appears to be a range but could not be parsed; ` - `")
_, _, err = parsePort("a-b")
assert.EqualError(t, err, "beginning range was not a number; `a`")
require.EqualError(t, err, "beginning range was not a number; `a`")
_, _, err = parsePort("1-b")
assert.EqualError(t, err, "ending range was not a number; `b`")
require.EqualError(t, err, "ending range was not a number; `b`")
s, e, err := parsePort(" 1 - 2 ")
assert.Equal(t, int32(1), s)
assert.Equal(t, int32(2), e)
assert.Nil(t, err)
require.NoError(t, err)
s, e, err = parsePort("0-1")
assert.Equal(t, int32(0), s)
assert.Equal(t, int32(0), e)
assert.Nil(t, err)
require.NoError(t, err)
s, e, err = parsePort("9919")
assert.Equal(t, int32(9919), s)
assert.Equal(t, int32(9919), e)
assert.Nil(t, err)
require.NoError(t, err)
s, e, err = parsePort("any")
assert.Equal(t, int32(0), s)
assert.Equal(t, int32(0), e)
assert.Nil(t, err)
require.NoError(t, err)
}
func TestNewFirewallFromConfig(t *testing.T) {
l := test.NewLogger()
// Test a bad rule definition
c := &cert.NebulaCertificate{}
c := &dummyCert{}
cs, err := newCertState(cert.Version2, nil, c, false, cert.Curve_CURVE25519, nil)
require.NoError(t, err)
conf := config.NewC(l)
conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": "asdf"}
_, err := NewFirewallFromConfig(l, c, conf)
assert.EqualError(t, err, "firewall.outbound failed to parse, should be an array of rules")
_, err = NewFirewallFromConfig(l, cs, conf)
require.EqualError(t, err, "firewall.outbound failed to parse, should be an array of rules")
// Test both port and code
conf = config.NewC(l)
conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"port": "1", "code": "2"}}}
_, err = NewFirewallFromConfig(l, c, conf)
assert.EqualError(t, err, "firewall.outbound rule #0; only one of port or code should be provided")
_, err = NewFirewallFromConfig(l, cs, conf)
require.EqualError(t, err, "firewall.outbound rule #0; only one of port or code should be provided")
// Test missing host, group, cidr, ca_name and ca_sha
conf = config.NewC(l)
conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{}}}
_, err = NewFirewallFromConfig(l, c, conf)
assert.EqualError(t, err, "firewall.outbound rule #0; at least one of host, group, cidr, local_cidr, ca_name, or ca_sha must be provided")
_, err = NewFirewallFromConfig(l, cs, conf)
require.EqualError(t, err, "firewall.outbound rule #0; at least one of host, group, cidr, local_cidr, ca_name, or ca_sha must be provided")
// Test code/port error
conf = config.NewC(l)
conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"code": "a", "host": "testh"}}}
_, err = NewFirewallFromConfig(l, c, conf)
assert.EqualError(t, err, "firewall.outbound rule #0; code was not a number; `a`")
_, err = NewFirewallFromConfig(l, cs, conf)
require.EqualError(t, err, "firewall.outbound rule #0; code was not a number; `a`")
conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"port": "a", "host": "testh"}}}
_, err = NewFirewallFromConfig(l, c, conf)
assert.EqualError(t, err, "firewall.outbound rule #0; port was not a number; `a`")
_, err = NewFirewallFromConfig(l, cs, conf)
require.EqualError(t, err, "firewall.outbound rule #0; port was not a number; `a`")
// Test proto error
conf = config.NewC(l)
conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"code": "1", "host": "testh"}}}
_, err = NewFirewallFromConfig(l, c, conf)
assert.EqualError(t, err, "firewall.outbound rule #0; proto was not understood; ``")
_, err = NewFirewallFromConfig(l, cs, conf)
require.EqualError(t, err, "firewall.outbound rule #0; proto was not understood; ``")
// Test cidr parse error
conf = config.NewC(l)
conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"code": "1", "cidr": "testh", "proto": "any"}}}
_, err = NewFirewallFromConfig(l, c, conf)
assert.EqualError(t, err, "firewall.outbound rule #0; cidr did not parse; netip.ParsePrefix(\"testh\"): no '/'")
_, err = NewFirewallFromConfig(l, cs, conf)
require.EqualError(t, err, "firewall.outbound rule #0; cidr did not parse; netip.ParsePrefix(\"testh\"): no '/'")
// Test local_cidr parse error
conf = config.NewC(l)
conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"code": "1", "local_cidr": "testh", "proto": "any"}}}
_, err = NewFirewallFromConfig(l, c, conf)
assert.EqualError(t, err, "firewall.outbound rule #0; local_cidr did not parse; netip.ParsePrefix(\"testh\"): no '/'")
_, err = NewFirewallFromConfig(l, cs, conf)
require.EqualError(t, err, "firewall.outbound rule #0; local_cidr did not parse; netip.ParsePrefix(\"testh\"): no '/'")
// Test both group and groups
conf = config.NewC(l)
conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "group": "a", "groups": []string{"b", "c"}}}}
_, err = NewFirewallFromConfig(l, c, conf)
assert.EqualError(t, err, "firewall.inbound rule #0; only one of group or groups should be defined, both provided")
_, err = NewFirewallFromConfig(l, cs, conf)
require.EqualError(t, err, "firewall.inbound rule #0; only one of group or groups should be defined, both provided")
}
func TestAddFirewallRulesFromConfig(t *testing.T) {
@ -747,28 +688,28 @@ func TestAddFirewallRulesFromConfig(t *testing.T) {
conf := config.NewC(l)
mf := &mockFirewall{}
conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "tcp", "host": "a"}}}
assert.Nil(t, AddFirewallRulesFromConfig(l, false, conf, mf))
require.NoError(t, AddFirewallRulesFromConfig(l, false, conf, mf))
assert.Equal(t, addRuleCall{incoming: false, proto: firewall.ProtoTCP, startPort: 1, endPort: 1, groups: nil, host: "a", ip: netip.Prefix{}, localIp: netip.Prefix{}}, mf.lastCall)
// Test adding udp rule
conf = config.NewC(l)
mf = &mockFirewall{}
conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "udp", "host": "a"}}}
assert.Nil(t, AddFirewallRulesFromConfig(l, false, conf, mf))
require.NoError(t, AddFirewallRulesFromConfig(l, false, conf, mf))
assert.Equal(t, addRuleCall{incoming: false, proto: firewall.ProtoUDP, startPort: 1, endPort: 1, groups: nil, host: "a", ip: netip.Prefix{}, localIp: netip.Prefix{}}, mf.lastCall)
// Test adding icmp rule
conf = config.NewC(l)
mf = &mockFirewall{}
conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "icmp", "host": "a"}}}
assert.Nil(t, AddFirewallRulesFromConfig(l, false, conf, mf))
require.NoError(t, AddFirewallRulesFromConfig(l, false, conf, mf))
assert.Equal(t, addRuleCall{incoming: false, proto: firewall.ProtoICMP, startPort: 1, endPort: 1, groups: nil, host: "a", ip: netip.Prefix{}, localIp: netip.Prefix{}}, mf.lastCall)
// Test adding any rule
conf = config.NewC(l)
mf = &mockFirewall{}
conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "host": "a"}}}
assert.Nil(t, AddFirewallRulesFromConfig(l, true, conf, mf))
require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf))
assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, host: "a", ip: netip.Prefix{}, localIp: netip.Prefix{}}, mf.lastCall)
// Test adding rule with cidr
@ -776,49 +717,49 @@ func TestAddFirewallRulesFromConfig(t *testing.T) {
conf = config.NewC(l)
mf = &mockFirewall{}
conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "cidr": cidr.String()}}}
assert.Nil(t, AddFirewallRulesFromConfig(l, true, conf, mf))
require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf))
assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: cidr, localIp: netip.Prefix{}}, mf.lastCall)
// Test adding rule with local_cidr
conf = config.NewC(l)
mf = &mockFirewall{}
conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "local_cidr": cidr.String()}}}
assert.Nil(t, AddFirewallRulesFromConfig(l, true, conf, mf))
require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf))
assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: netip.Prefix{}, localIp: cidr}, mf.lastCall)
// Test adding rule with ca_sha
conf = config.NewC(l)
mf = &mockFirewall{}
conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "ca_sha": "12312313123"}}}
assert.Nil(t, AddFirewallRulesFromConfig(l, true, conf, mf))
require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf))
assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: netip.Prefix{}, localIp: netip.Prefix{}, caSha: "12312313123"}, mf.lastCall)
// Test adding rule with ca_name
conf = config.NewC(l)
mf = &mockFirewall{}
conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "ca_name": "root01"}}}
assert.Nil(t, AddFirewallRulesFromConfig(l, true, conf, mf))
require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf))
assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: netip.Prefix{}, localIp: netip.Prefix{}, caName: "root01"}, mf.lastCall)
// Test single group
conf = config.NewC(l)
mf = &mockFirewall{}
conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "group": "a"}}}
assert.Nil(t, AddFirewallRulesFromConfig(l, true, conf, mf))
require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf))
assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: []string{"a"}, ip: netip.Prefix{}, localIp: netip.Prefix{}}, mf.lastCall)
// Test single groups
conf = config.NewC(l)
mf = &mockFirewall{}
conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "groups": "a"}}}
assert.Nil(t, AddFirewallRulesFromConfig(l, true, conf, mf))
require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf))
assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: []string{"a"}, ip: netip.Prefix{}, localIp: netip.Prefix{}}, mf.lastCall)
// Test multiple AND groups
conf = config.NewC(l)
mf = &mockFirewall{}
conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "groups": []string{"a", "b"}}}}
assert.Nil(t, AddFirewallRulesFromConfig(l, true, conf, mf))
require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf))
assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: []string{"a", "b"}, ip: netip.Prefix{}, localIp: netip.Prefix{}}, mf.lastCall)
// Test Add error
@ -826,7 +767,7 @@ func TestAddFirewallRulesFromConfig(t *testing.T) {
mf = &mockFirewall{}
mf.nextCallReturn = errors.New("test error")
conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "host": "a"}}}
assert.EqualError(t, AddFirewallRulesFromConfig(l, true, conf, mf), "firewall.inbound rule #0; `test error`")
require.EqualError(t, AddFirewallRulesFromConfig(l, true, conf, mf), "firewall.inbound rule #0; `test error`")
}
func TestFirewall_convertRule(t *testing.T) {
@ -841,7 +782,7 @@ func TestFirewall_convertRule(t *testing.T) {
r, err := convertRule(l, c, "test", 1)
assert.Contains(t, ob.String(), "test rule #1; group was an array with a single value, converting to simple value")
assert.Nil(t, err)
require.NoError(t, err)
assert.Equal(t, "group1", r.Group)
// Ensure group array of > 1 is errord
@ -852,7 +793,7 @@ func TestFirewall_convertRule(t *testing.T) {
r, err = convertRule(l, c, "test", 1)
assert.Equal(t, "", ob.String())
assert.Error(t, err, "group should contain a single value, an array with more than one entry was provided")
require.Error(t, err, "group should contain a single value, an array with more than one entry was provided")
// Make sure a well formed group is alright
ob.Reset()
@ -861,7 +802,7 @@ func TestFirewall_convertRule(t *testing.T) {
}
r, err = convertRule(l, c, "test", 1)
assert.Nil(t, err)
require.NoError(t, err)
assert.Equal(t, "group1", r.Group)
}

38
go.mod
View File

@ -1,8 +1,8 @@
module github.com/slackhq/nebula
go 1.22.0
go 1.23.6
toolchain go1.22.2
toolchain go1.23.7
require (
dario.cat/mergo v1.0.1
@ -10,45 +10,45 @@ require (
github.com/armon/go-radix v1.0.0
github.com/cyberdelia/go-metrics-graphite v0.0.0-20161219230853-39f87cc3b432
github.com/flynn/noise v1.1.0
github.com/gaissmai/bart v0.11.1
github.com/gaissmai/bart v0.18.1
github.com/gogo/protobuf v1.3.2
github.com/google/gopacket v1.1.19
github.com/kardianos/service v1.2.2
github.com/miekg/dns v1.1.61
github.com/miekg/dns v1.1.63
github.com/miekg/pkcs11 v1.1.2-0.20231115102856-9078ad6b9d4b
github.com/nbrownus/go-metrics-prometheus v0.0.0-20210712211119-974a6260965f
github.com/prometheus/client_golang v1.19.1
github.com/prometheus/client_golang v1.21.1
github.com/rcrowley/go-metrics v0.0.0-20201227073835-cf1acfcdf475
github.com/sirupsen/logrus v1.9.3
github.com/skip2/go-qrcode v0.0.0-20200617195104-da1b6568686e
github.com/songgao/water v0.0.0-20200317203138-2b4b6d7c09d8
github.com/stefanberger/go-pkcs11uri v0.0.0-20230803200340-78284954bff6
github.com/stretchr/testify v1.9.0
github.com/vishvananda/netlink v1.2.1-beta.2
golang.org/x/crypto v0.26.0
github.com/stretchr/testify v1.10.0
github.com/vishvananda/netlink v1.3.0
golang.org/x/crypto v0.36.0
golang.org/x/exp v0.0.0-20230725093048-515e97ebf090
golang.org/x/net v0.28.0
golang.org/x/sync v0.8.0
golang.org/x/sys v0.24.0
golang.org/x/term v0.23.0
golang.org/x/net v0.37.0
golang.org/x/sync v0.12.0
golang.org/x/sys v0.31.0
golang.org/x/term v0.30.0
golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2
golang.zx2c4.com/wireguard v0.0.0-20230325221338-052af4a8072b
golang.zx2c4.com/wireguard/windows v0.5.3
google.golang.org/protobuf v1.34.2
google.golang.org/protobuf v1.36.5
gopkg.in/yaml.v2 v2.4.0
gvisor.dev/gvisor v0.0.0-20240423190808-9d7a357edefe
)
require (
github.com/beorn7/perks v1.0.1 // indirect
github.com/bits-and-blooms/bitset v1.13.0 // indirect
github.com/cespare/xxhash/v2 v2.2.0 // indirect
github.com/cespare/xxhash/v2 v2.3.0 // indirect
github.com/davecgh/go-spew v1.1.1 // indirect
github.com/google/btree v1.1.2 // indirect
github.com/klauspost/compress v1.17.11 // indirect
github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 // indirect
github.com/pmezard/go-difflib v1.0.0 // indirect
github.com/prometheus/client_model v0.5.0 // indirect
github.com/prometheus/common v0.48.0 // indirect
github.com/prometheus/procfs v0.12.0 // indirect
github.com/prometheus/client_model v0.6.1 // indirect
github.com/prometheus/common v0.62.0 // indirect
github.com/prometheus/procfs v0.15.1 // indirect
github.com/vishvananda/netns v0.0.4 // indirect
golang.org/x/mod v0.18.0 // indirect
golang.org/x/time v0.5.0 // indirect

75
go.sum
View File

@ -14,11 +14,9 @@ github.com/beorn7/perks v0.0.0-20180321164747-3a771d992973/go.mod h1:Dwedo/Wpr24
github.com/beorn7/perks v1.0.0/go.mod h1:KWe93zE9D1o94FZ5RNwFwVgaQK1VOXiVxmqh+CedLV8=
github.com/beorn7/perks v1.0.1 h1:VlbKKnNfV8bJzeqoa4cOKqO6bYr3WgKZxO8Z16+hsOM=
github.com/beorn7/perks v1.0.1/go.mod h1:G2ZrVWU2WbWT9wwq4/hrbKbnv/1ERSJQ0ibhJ6rlkpw=
github.com/bits-and-blooms/bitset v1.13.0 h1:bAQ9OPNFYbGHV6Nez0tmNI0RiEu7/hxlYJRUA0wFAVE=
github.com/bits-and-blooms/bitset v1.13.0/go.mod h1:7hO7Gc7Pp1vODcmWvKMRA9BNmbv6a/7QIWpPxHddWR8=
github.com/cespare/xxhash/v2 v2.1.1/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs=
github.com/cespare/xxhash/v2 v2.2.0 h1:DC2CZ1Ep5Y4k3ZQ899DldepgrayRUGE6BBZ/cd9Cj44=
github.com/cespare/xxhash/v2 v2.2.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs=
github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs=
github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs=
github.com/cyberdelia/go-metrics-graphite v0.0.0-20161219230853-39f87cc3b432 h1:M5QgkYacWj0Xs8MhpIK/5uwU02icXpEoSo9sM2aRCps=
github.com/cyberdelia/go-metrics-graphite v0.0.0-20161219230853-39f87cc3b432/go.mod h1:xwIwAxMvYnVrGJPe2FKx5prTrnAjGOD8zvDOnxnrrkM=
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
@ -26,8 +24,8 @@ github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/flynn/noise v1.1.0 h1:KjPQoQCEFdZDiP03phOvGi11+SVVhBG2wOWAorLsstg=
github.com/flynn/noise v1.1.0/go.mod h1:xbMo+0i6+IGbYdJhF31t2eR1BIU0CYc12+BNAKwUTag=
github.com/gaissmai/bart v0.11.1 h1:5Uv5XwsaFBRo4E5VBcb9TzY8B7zxFf+U7isDxqOrRfc=
github.com/gaissmai/bart v0.11.1/go.mod h1:KHeYECXQiBjTzQz/om2tqn3sZF1J7hw9m6z41ftj3fg=
github.com/gaissmai/bart v0.18.1 h1:bX2j560JC1MJpoEDevBGvXL5OZ1mkls320Vl8Igb5QQ=
github.com/gaissmai/bart v0.18.1/go.mod h1:JJzMAhNF5Rjo4SF4jWBrANuJfqY+FvsFhW7t1UZJ+XY=
github.com/go-kit/kit v0.8.0/go.mod h1:xBxKIO96dXMWWy0MnWVtmwkA9/13aqxPnvrjFYMA2as=
github.com/go-kit/kit v0.9.0/go.mod h1:xBxKIO96dXMWWy0MnWVtmwkA9/13aqxPnvrjFYMA2as=
github.com/go-kit/log v0.1.0/go.mod h1:zbhenjAZHb184qTLMA9ZjW7ThYL0H2mk7Q6pNt4vbaY=
@ -70,6 +68,8 @@ github.com/kardianos/service v1.2.2 h1:ZvePhAHfvo0A7Mftk/tEzqEZ7Q4lgnR8sGz4xu1YX
github.com/kardianos/service v1.2.2/go.mod h1:CIMRFEJVL+0DS1a3Nx06NaMn4Dz63Ng6O7dl0qH0zVM=
github.com/kisielk/errcheck v1.5.0/go.mod h1:pFxgyoBC7bSaBwPgfKdkLd5X25qrDl4LWUI2bnpBCr8=
github.com/kisielk/gotool v1.0.0/go.mod h1:XhKaO+MFFWcvkIS/tQcRk01m1F5IRFswLeQ+oQHNcck=
github.com/klauspost/compress v1.17.11 h1:In6xLpyWOi1+C7tXUUWv2ot1QvBjxevKAaI6IXrJmUc=
github.com/klauspost/compress v1.17.11/go.mod h1:pMDklpSncoRMuLFrf1W9Ss9KT+0rH90U12bZKk7uwG0=
github.com/konsorten/go-windows-terminal-sequences v1.0.1/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ=
github.com/konsorten/go-windows-terminal-sequences v1.0.3/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ=
github.com/kr/logfmt v0.0.0-20140226030751-b84e30acd515/go.mod h1:+0opPa2QZZtGFBFZlji/RkVcI2GknAs/DXo4wKdlNEc=
@ -80,15 +80,19 @@ github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3x
github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ=
github.com/kr/text v0.1.0 h1:45sCR5RtlFHMR4UwH9sdQ5TC8v0qDQCHnXt+kaKSTVE=
github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI=
github.com/kylelemons/godebug v1.1.0 h1:RPNrshWIDI6G2gRW9EHilWtl7Z6Sb1BR0xunSBf0SNc=
github.com/kylelemons/godebug v1.1.0/go.mod h1:9/0rRGxNHcop5bhtWyNeEfOS8JIWk580+fNqagV/RAw=
github.com/matttproud/golang_protobuf_extensions v1.0.1/go.mod h1:D8He9yQNgCq6Z5Ld7szi9bcBfOoFv/3dc6xSMkL2PC0=
github.com/miekg/dns v1.1.61 h1:nLxbwF3XxhwVSm8g9Dghm9MHPaUZuqhPiGL+675ZmEs=
github.com/miekg/dns v1.1.61/go.mod h1:mnAarhS3nWaW+NVP2wTkYVIZyHNJ098SJZUki3eykwQ=
github.com/miekg/dns v1.1.63 h1:8M5aAw6OMZfFXTT7K5V0Eu5YiiL8l7nUAkyN6C9YwaY=
github.com/miekg/dns v1.1.63/go.mod h1:6NGHfjhpmr5lt3XPLuyfDJi5AXbNIPM9PY6H6sF1Nfs=
github.com/miekg/pkcs11 v1.1.2-0.20231115102856-9078ad6b9d4b h1:J/AzCvg5z0Hn1rqZUJjpbzALUmkKX0Zwbc/i4fw7Sfk=
github.com/miekg/pkcs11 v1.1.2-0.20231115102856-9078ad6b9d4b/go.mod h1:XsNlhZGX73bx86s2hdc/FuaLm2CPZJemRLMA+WTFxgs=
github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q=
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q=
github.com/modern-go/reflect2 v0.0.0-20180701023420-4b7aa43c6742/go.mod h1:bx2lNnkwVCuqBIxFjflWJWanXIb3RllmbCylyMrvgv0=
github.com/modern-go/reflect2 v1.0.1/go.mod h1:bx2lNnkwVCuqBIxFjflWJWanXIb3RllmbCylyMrvgv0=
github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 h1:C3w9PqII01/Oq1c1nUAm88MOHcQC9l5mIlSMApZMrHA=
github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822/go.mod h1:+n7T8mK8HuQTcFwEeznm/DIxMOiR9yIdICNftLE1DvQ=
github.com/mwitkow/go-conntrack v0.0.0-20161129095857-cc309e4a2223/go.mod h1:qRWi+5nqEBWmkhHvq77mSJWrCKwh8bxhgT7d/eI7P4U=
github.com/mwitkow/go-conntrack v0.0.0-20190716064945-2f068394615f/go.mod h1:qRWi+5nqEBWmkhHvq77mSJWrCKwh8bxhgT7d/eI7P4U=
github.com/nbrownus/go-metrics-prometheus v0.0.0-20210712211119-974a6260965f h1:8dM0ilqKL0Uzl42GABzzC4Oqlc3kGRILz0vgoff7nwg=
@ -102,24 +106,24 @@ github.com/prometheus/client_golang v0.9.1/go.mod h1:7SWBe2y4D6OKWSNQJUaRYU/AaXP
github.com/prometheus/client_golang v1.0.0/go.mod h1:db9x61etRT2tGnBNRi70OPL5FsnadC4Ky3P0J6CfImo=
github.com/prometheus/client_golang v1.7.1/go.mod h1:PY5Wy2awLA44sXw4AOSfFBetzPP4j5+D6mVACh+pe2M=
github.com/prometheus/client_golang v1.11.0/go.mod h1:Z6t4BnS23TR94PD6BsDNk8yVqroYurpAkEiz0P2BEV0=
github.com/prometheus/client_golang v1.19.1 h1:wZWJDwK+NameRJuPGDhlnFgx8e8HN3XHQeLaYJFJBOE=
github.com/prometheus/client_golang v1.19.1/go.mod h1:mP78NwGzrVks5S2H6ab8+ZZGJLZUq1hoULYBAYBw1Ho=
github.com/prometheus/client_golang v1.21.1 h1:DOvXXTqVzvkIewV/CDPFdejpMCGeMcbGCQ8YOmu+Ibk=
github.com/prometheus/client_golang v1.21.1/go.mod h1:U9NM32ykUErtVBxdvD3zfi+EuFkkaBvMb09mIfe0Zgg=
github.com/prometheus/client_model v0.0.0-20180712105110-5c3871d89910/go.mod h1:MbSGuTsp3dbXC40dX6PRTWyKYBIrTGTE9sqQNg2J8bo=
github.com/prometheus/client_model v0.0.0-20190129233127-fd36f4220a90/go.mod h1:xMI15A0UPsDsEKsMN9yxemIoYk6Tm2C1GtYGdfGttqA=
github.com/prometheus/client_model v0.2.0/go.mod h1:xMI15A0UPsDsEKsMN9yxemIoYk6Tm2C1GtYGdfGttqA=
github.com/prometheus/client_model v0.5.0 h1:VQw1hfvPvk3Uv6Qf29VrPF32JB6rtbgI6cYPYQjL0Qw=
github.com/prometheus/client_model v0.5.0/go.mod h1:dTiFglRmd66nLR9Pv9f0mZi7B7fk5Pm3gvsjB5tr+kI=
github.com/prometheus/client_model v0.6.1 h1:ZKSh/rekM+n3CeS952MLRAdFwIKqeY8b62p8ais2e9E=
github.com/prometheus/client_model v0.6.1/go.mod h1:OrxVMOVHjw3lKMa8+x6HeMGkHMQyHDk9E3jmP2AmGiY=
github.com/prometheus/common v0.4.1/go.mod h1:TNfzLD0ON7rHzMJeJkieUDPYmFC7Snx/y86RQel1bk4=
github.com/prometheus/common v0.10.0/go.mod h1:Tlit/dnDKsSWFlCLTWaA1cyBgKHSMdTB80sz/V91rCo=
github.com/prometheus/common v0.26.0/go.mod h1:M7rCNAaPfAosfx8veZJCuw84e35h3Cfd9VFqTh1DIvc=
github.com/prometheus/common v0.48.0 h1:QO8U2CdOzSn1BBsmXJXduaaW+dY/5QLjfB8svtSzKKE=
github.com/prometheus/common v0.48.0/go.mod h1:0/KsvlIEfPQCQ5I2iNSAWKPZziNCvRs5EC6ILDTlAPc=
github.com/prometheus/common v0.62.0 h1:xasJaQlnWAeyHdUBeGjXmutelfJHWMRr+Fg4QszZ2Io=
github.com/prometheus/common v0.62.0/go.mod h1:vyBcEuLSvWos9B1+CyL7JZ2up+uFzXhkqml0W5zIY1I=
github.com/prometheus/procfs v0.0.0-20181005140218-185b4288413d/go.mod h1:c3At6R/oaqEKCNdg8wHV1ftS6bRYblBhIjjI8uT2IGk=
github.com/prometheus/procfs v0.0.2/go.mod h1:TjEm7ze935MbeOT/UhFTIMYKhuLP4wbCsTZCD3I8kEA=
github.com/prometheus/procfs v0.1.3/go.mod h1:lV6e/gmhEcM9IjHGsFOCxxuZ+z1YqCvr4OA4YeYWdaU=
github.com/prometheus/procfs v0.6.0/go.mod h1:cz+aTbrPOrUb4q7XlbU9ygM+/jj0fzG6c1xBZuNvfVA=
github.com/prometheus/procfs v0.12.0 h1:jluTpSng7V9hY0O2R9DzzJHYb2xULk9VTR1V1R/k6Bo=
github.com/prometheus/procfs v0.12.0/go.mod h1:pcuDEFsWDnvcgNzo4EEweacyhjeA9Zk3cnaOZAZEfOo=
github.com/prometheus/procfs v0.15.1 h1:YagwOFzUgYfKKHX6Dr+sHT7km/hxC76UB0learggepc=
github.com/prometheus/procfs v0.15.1/go.mod h1:fB45yRUv8NstnjriLhBQLuOUt+WW4BsoGhij/e3PBqk=
github.com/rcrowley/go-metrics v0.0.0-20201227073835-cf1acfcdf475 h1:N/ElC8H3+5XpJzTSTfLsJV/mx9Q9g7kxmchpfZyxgzM=
github.com/rcrowley/go-metrics v0.0.0-20201227073835-cf1acfcdf475/go.mod h1:bCqnVzQkZxMG4s8nGwiZ5l3QUCyqpo9Y+/ZMZ9VjZe4=
github.com/rogpeppe/go-internal v1.10.0 h1:TMyTOH3F/DB16zRVcYyreMH6GnZZrwQVAoYjRBZyWFQ=
@ -131,8 +135,6 @@ github.com/sirupsen/logrus v1.9.3 h1:dueUQJ1C2q9oE3F7wvmSGAaVtTmUizReu6fjN8uqzbQ
github.com/sirupsen/logrus v1.9.3/go.mod h1:naHLuLoDiP4jHNo9R0sCBMtWGeIprob74mVsIT4qYEQ=
github.com/skip2/go-qrcode v0.0.0-20200617195104-da1b6568686e h1:MRM5ITcdelLK2j1vwZ3Je0FKVCfqOLp5zO6trqMLYs0=
github.com/skip2/go-qrcode v0.0.0-20200617195104-da1b6568686e/go.mod h1:XV66xRDqSt+GTGFMVlhk3ULuV0y9ZmzeVGR4mloJI3M=
github.com/songgao/water v0.0.0-20200317203138-2b4b6d7c09d8 h1:TG/diQgUe0pntT/2D9tmUCz4VNwm9MfrtPr0SU2qSX8=
github.com/songgao/water v0.0.0-20200317203138-2b4b6d7c09d8/go.mod h1:P5HUIBuIWKbyjl083/loAegFkfbFNx5i2qEP4CNbm7E=
github.com/stefanberger/go-pkcs11uri v0.0.0-20230803200340-78284954bff6 h1:pnnLyeX7o/5aX8qUQ69P/mLojDqwda8hFOCBTmP/6hw=
github.com/stefanberger/go-pkcs11uri v0.0.0-20230803200340-78284954bff6/go.mod h1:39R/xuhNgVhi+K0/zst4TLrJrVmbm6LVgl4A0+ZFS5M=
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
@ -141,11 +143,10 @@ github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXf
github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI=
github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4=
github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg=
github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
github.com/vishvananda/netlink v1.2.1-beta.2 h1:Llsql0lnQEbHj0I1OuKyp8otXp0r3q0mPkuhwHfStVs=
github.com/vishvananda/netlink v1.2.1-beta.2/go.mod h1:twkDnbuQxJYemMlGd4JFIcuhgX83tXhKS2B/PRMpOho=
github.com/vishvananda/netns v0.0.0-20200728191858-db3c7e526aae/go.mod h1:DD4vA1DwXk04H54A1oHXtwZmA0grkVMdPxx/VGLCah0=
github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA=
github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
github.com/vishvananda/netlink v1.3.0 h1:X7l42GfcV4S6E4vHTsw48qbrV+9PVojNfIhZcwQdrZk=
github.com/vishvananda/netlink v1.3.0/go.mod h1:i6NetklAujEcC6fK0JPjT8qSwWyO0HLn4UKG+hGqeJs=
github.com/vishvananda/netns v0.0.4 h1:Oeaw1EM2JMxD51g9uhtC0D7erkIjgmj8+JZc26m1YX8=
github.com/vishvananda/netns v0.0.4/go.mod h1:SpkAiCQRtJ6TvvxPnOSyH3BMl6unz3xZlaprSwhNNJM=
github.com/yuin/goldmark v1.1.27/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74=
@ -155,8 +156,8 @@ golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACk
golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto=
golang.org/x/crypto v0.0.0-20210322153248-0c34fe9e7dc2/go.mod h1:T9bdIzuCu7OtxOm1hfPfRQxPLYneinmdGuTeoZ9dtd4=
golang.org/x/crypto v0.26.0 h1:RrRspgV4mU+YwB4FYnuBoKsUapNIL5cohGAmSH3azsw=
golang.org/x/crypto v0.26.0/go.mod h1:GY7jblb9wI+FOo5y8/S2oY4zWP07AkOJ4+jxCqdqn54=
golang.org/x/crypto v0.36.0 h1:AnAEvhDddvBdpY+uR+MyHmuZzzNqXSe/GvuDeob5L34=
golang.org/x/crypto v0.36.0/go.mod h1:Y4J0ReaxCR1IMaabaSMugxJES1EpwhBHhv2bDHklZvc=
golang.org/x/exp v0.0.0-20230725093048-515e97ebf090 h1:Di6/M8l0O2lCLc6VVRWhgCiApHV8MnQurBnFSHsQtNY=
golang.org/x/exp v0.0.0-20230725093048-515e97ebf090/go.mod h1:FXUEEKJgO7OQYeo8N01OfiKP8RXMtf6e8aTskBGqWdc=
golang.org/x/lint v0.0.0-20200302205851-738671d3881b/go.mod h1:3xt1FjdF8hUf6vQPIChWIBhFzV8gjjsPE/fR3IyQdNY=
@ -175,8 +176,8 @@ golang.org/x/net v0.0.0-20200226121028-0de0cce0169b/go.mod h1:z5CRVTTTmAJ677TzLL
golang.org/x/net v0.0.0-20200625001655-4c5254603344/go.mod h1:/O7V0waA8r7cgGh81Ro3o1hOxt32SMVPicZroKQ2sZA=
golang.org/x/net v0.0.0-20201021035429-f5854403a974/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU=
golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg=
golang.org/x/net v0.28.0 h1:a9JDOJc5GMUJ0+UDqmLT86WiEy7iWyIhz8gz8E4e5hE=
golang.org/x/net v0.28.0/go.mod h1:yqtgsTWOOnlGLG9GFRrK3++bGOUEkNBoHZc8MEDWPNg=
golang.org/x/net v0.37.0 h1:1zLorHbz+LYj7MQlSf1+2tPIIgibq2eL5xkrGk6f+2c=
golang.org/x/net v0.37.0/go.mod h1:ivrbrMbzFq5J41QOQh0siUuly180yBYtLp+CKbEaFx8=
golang.org/x/oauth2 v0.0.0-20190226205417-e64efc72b421/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw=
golang.org/x/sync v0.0.0-20181108010431-42b317875d0f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20181221193216-37e7f081c4d4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
@ -184,30 +185,30 @@ golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJ
golang.org/x/sync v0.0.0-20190911185100-cd5d95a43a6e/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20201020160332-67f06af15bc9/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20201207232520-09787c993a3a/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.8.0 h1:3NFvSEYkUoMifnESzZl15y791HH1qU2xm6eCJU5ZPXQ=
golang.org/x/sync v0.8.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk=
golang.org/x/sync v0.12.0 h1:MHc5BpPuC30uJk597Ri8TV3CNZcTLu6B6z4lJy+g6Jw=
golang.org/x/sync v0.12.0/go.mod h1:1dzgHSNfp02xaA81J2MS99Qcpr2w7fw1gpm99rleRqA=
golang.org/x/sys v0.0.0-20180905080454-ebe1bf3edb33/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20181116152217-5ac8a444bdc5/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20190422165155-953cdadca894/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20200106162015-b016eb3dc98e/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20200217220822-9197077df867/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20200323222414-85ca7c5b95cd/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20200615200032-f1bc736245b1/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20200625212154-ddb9806d33ae/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20200728102440-3e129f6d46b1/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20200930185726-fdedc70b468f/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20201015000850-e3ed0017c211/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20210124154548-22da62e12c0c/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20210603081109-ebe580a85c40/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.24.0 h1:Twjiwq9dn6R1fQcyiK+wQyHWfaz/BJB+YIpzU/Cv3Xg=
golang.org/x/sys v0.24.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
golang.org/x/sys v0.2.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.10.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.31.0 h1:ioabZlmFYtWhL+TRYpcnNlLwhyxaM9kWTDEmfnprqik=
golang.org/x/sys v0.31.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k=
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
golang.org/x/term v0.23.0 h1:F6D4vR+EHoL9/sWAWgAR1H2DcHr4PareCbAaCo1RpuU=
golang.org/x/term v0.23.0/go.mod h1:DgV24QBUrK6jhZXl+20l6UWznPlwAHm1Q1mGHtydmSk=
golang.org/x/term v0.30.0 h1:PQ39fJZ+mfadBm0y5WlL4vlM7Sx1Hgf13sMIY2+QS9Y=
golang.org/x/term v0.30.0/go.mod h1:NYYFdzHoI5wRh/h5tDMdMqCqPJZEuNqVR5xJLd/n67g=
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk=
golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
@ -238,8 +239,8 @@ google.golang.org/protobuf v1.20.1-0.20200309200217-e05f789c0967/go.mod h1:A+miE
google.golang.org/protobuf v1.21.0/go.mod h1:47Nbq4nVaFHyn7ilMalzfO3qCViNmqZ2kzikPIcrTAo=
google.golang.org/protobuf v1.23.0/go.mod h1:EGpADcykh3NcUnDUJcl1+ZksZNG86OlYog2l/sGQquU=
google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp09yW+WbY/TyQbw=
google.golang.org/protobuf v1.34.2 h1:6xV6lTsCfpGD21XK49h7MhtcApnLqkfYgPcdHftf6hg=
google.golang.org/protobuf v1.34.2/go.mod h1:qYOHts0dSfpeUzUFpOMr/WGzszTmLH+DiWniOlNbLDw=
google.golang.org/protobuf v1.36.5 h1:tPhr+woSbjfYvY6/GPufUoYizxw1cF/yFoxJ2fmpwlM=
google.golang.org/protobuf v1.36.5/go.mod h1:9fA7Ob0pmnwhb644+1+CVWFRbNajQ6iRojtC/QF5bRE=
gopkg.in/alecthomas/kingpin.v2 v2.2.6/go.mod h1:FMv+mEhP44yOT+4EoQTLFTRgOQ1FBLkstjWtayDeSgw=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=

View File

@ -2,10 +2,12 @@ package nebula
import (
"net/netip"
"slices"
"time"
"github.com/flynn/noise"
"github.com/sirupsen/logrus"
"github.com/slackhq/nebula/cert"
"github.com/slackhq/nebula/header"
)
@ -16,30 +18,60 @@ import (
func ixHandshakeStage0(f *Interface, hh *HandshakeHostInfo) bool {
err := f.handshakeManager.allocateIndex(hh)
if err != nil {
f.l.WithError(err).WithField("vpnIp", hh.hostinfo.vpnIp).
f.l.WithError(err).WithField("vpnAddrs", hh.hostinfo.vpnAddrs).
WithField("handshake", m{"stage": 0, "style": "ix_psk0"}).Error("Failed to generate index")
return false
}
certState := f.pki.GetCertState()
ci := NewConnectionState(f.l, f.cipher, certState, true, noise.HandshakeIX, []byte{}, 0)
// If we're connecting to a v6 address we must use a v2 cert
cs := f.pki.getCertState()
v := cs.defaultVersion
for _, a := range hh.hostinfo.vpnAddrs {
if a.Is6() {
v = cert.Version2
break
}
}
crt := cs.getCertificate(v)
if crt == nil {
f.l.WithField("vpnAddrs", hh.hostinfo.vpnAddrs).
WithField("handshake", m{"stage": 0, "style": "ix_psk0"}).
WithField("certVersion", v).
Error("Unable to handshake with host because no certificate is available")
return false
}
crtHs := cs.getHandshakeBytes(v)
if crtHs == nil {
f.l.WithField("vpnAddrs", hh.hostinfo.vpnAddrs).
WithField("handshake", m{"stage": 0, "style": "ix_psk0"}).
WithField("certVersion", v).
Error("Unable to handshake with host because no certificate handshake bytes is available")
}
ci, err := NewConnectionState(f.l, cs, crt, true, noise.HandshakeIX)
if err != nil {
f.l.WithError(err).WithField("vpnAddrs", hh.hostinfo.vpnAddrs).
WithField("handshake", m{"stage": 0, "style": "ix_psk0"}).
WithField("certVersion", v).
Error("Failed to create connection state")
return false
}
hh.hostinfo.ConnectionState = ci
hsProto := &NebulaHandshakeDetails{
hs := &NebulaHandshake{
Details: &NebulaHandshakeDetails{
InitiatorIndex: hh.hostinfo.localIndexId,
Time: uint64(time.Now().UnixNano()),
Cert: certState.RawCertificateNoKey,
Cert: crtHs,
CertVersion: uint32(v),
},
}
hsBytes := []byte{}
hs := &NebulaHandshake{
Details: hsProto,
}
hsBytes, err = hs.Marshal()
hsBytes, err := hs.Marshal()
if err != nil {
f.l.WithError(err).WithField("vpnIp", hh.hostinfo.vpnIp).
f.l.WithError(err).WithField("vpnAddrs", hh.hostinfo.vpnAddrs).WithField("certVersion", v).
WithField("handshake", m{"stage": 0, "style": "ix_psk0"}).Error("Failed to marshal handshake message")
return false
}
@ -48,7 +80,7 @@ func ixHandshakeStage0(f *Interface, hh *HandshakeHostInfo) bool {
msg, _, _, err := ci.H.WriteMessage(h, hsBytes)
if err != nil {
f.l.WithError(err).WithField("vpnIp", hh.hostinfo.vpnIp).
f.l.WithError(err).WithField("vpnAddrs", hh.hostinfo.vpnAddrs).
WithField("handshake", m{"stage": 0, "style": "ix_psk0"}).Error("Failed to call noise.WriteMessage")
return false
}
@ -63,62 +95,104 @@ func ixHandshakeStage0(f *Interface, hh *HandshakeHostInfo) bool {
}
func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet []byte, h *header.H) {
certState := f.pki.GetCertState()
ci := NewConnectionState(f.l, f.cipher, certState, false, noise.HandshakeIX, []byte{}, 0)
cs := f.pki.getCertState()
crt := cs.GetDefaultCertificate()
if crt == nil {
f.l.WithField("udpAddr", addr).
WithField("handshake", m{"stage": 0, "style": "ix_psk0"}).
WithField("certVersion", cs.defaultVersion).
Error("Unable to handshake with host because no certificate is available")
}
ci, err := NewConnectionState(f.l, cs, crt, false, noise.HandshakeIX)
if err != nil {
f.l.WithError(err).WithField("udpAddr", addr).
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).
Error("Failed to create connection state")
return
}
// Mark packet 1 as seen so it doesn't show up as missed
ci.window.Update(f.l, 1)
msg, _, _, err := ci.H.ReadMessage(nil, packet[header.Len:])
if err != nil {
f.l.WithError(err).WithField("udpAddr", addr).
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).Error("Failed to call noise.ReadMessage")
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).
Error("Failed to call noise.ReadMessage")
return
}
hs := &NebulaHandshake{}
err = hs.Unmarshal(msg)
/*
l.Debugln("GOT INDEX: ", hs.Details.InitiatorIndex)
*/
if err != nil || hs.Details == nil {
f.l.WithError(err).WithField("udpAddr", addr).
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).Error("Failed unmarshal handshake message")
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).
Error("Failed unmarshal handshake message")
return
}
remoteCert, err := RecombineCertAndValidate(ci.H, hs.Details.Cert, f.pki.GetCAPool())
rc, err := cert.Recombine(cert.Version(hs.Details.CertVersion), hs.Details.Cert, ci.H.PeerStatic(), ci.Curve())
if err != nil {
e := f.l.WithError(err).WithField("udpAddr", addr).
WithField("handshake", m{"stage": 1, "style": "ix_psk0"})
f.l.WithError(err).WithField("udpAddr", addr).
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).
Info("Handshake did not contain a certificate")
return
}
if f.l.Level > logrus.DebugLevel {
e = e.WithField("cert", remoteCert)
remoteCert, err := f.pki.GetCAPool().VerifyCertificate(time.Now(), rc)
if err != nil {
fp, err := rc.Fingerprint()
if err != nil {
fp = "<error generating certificate fingerprint>"
}
e := f.l.WithError(err).WithField("udpAddr", addr).
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).
WithField("certVpnNetworks", rc.Networks()).
WithField("certFingerprint", fp)
if f.l.Level >= logrus.DebugLevel {
e = e.WithField("cert", rc)
}
e.Info("Invalid certificate from host")
return
}
vpnIp, ok := netip.AddrFromSlice(remoteCert.Details.Ips[0].IP)
if !ok {
e := f.l.WithError(err).WithField("udpAddr", addr).
WithField("handshake", m{"stage": 1, "style": "ix_psk0"})
if f.l.Level > logrus.DebugLevel {
e = e.WithField("cert", remoteCert)
}
e.Info("Invalid vpn ip from host")
if remoteCert.Certificate.Version() != ci.myCert.Version() {
// We started off using the wrong certificate version, lets see if we can match the version that was sent to us
rc := cs.getCertificate(remoteCert.Certificate.Version())
if rc == nil {
f.l.WithError(err).WithField("udpAddr", addr).
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).WithField("cert", remoteCert).
Info("Unable to handshake with host due to missing certificate version")
return
}
vpnIp = vpnIp.Unmap()
certName := remoteCert.Details.Name
fingerprint, _ := remoteCert.Sha256Sum()
issuer := remoteCert.Details.Issuer
// Record the certificate we are actually using
ci.myCert = rc
}
if vpnIp == f.myVpnNet.Addr() {
f.l.WithField("vpnIp", vpnIp).WithField("udpAddr", addr).
if len(remoteCert.Certificate.Networks()) == 0 {
f.l.WithError(err).WithField("udpAddr", addr).
WithField("cert", remoteCert).
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).
Info("No networks in certificate")
return
}
var vpnAddrs []netip.Addr
var filteredNetworks []netip.Prefix
certName := remoteCert.Certificate.Name()
fingerprint := remoteCert.Fingerprint
issuer := remoteCert.Certificate.Issuer()
for _, network := range remoteCert.Certificate.Networks() {
vpnAddr := network.Addr()
_, found := f.myVpnAddrsTable.Lookup(vpnAddr)
if found {
f.l.WithField("vpnAddr", vpnAddr).WithField("udpAddr", addr).
WithField("certName", certName).
WithField("fingerprint", fingerprint).
WithField("issuer", issuer).
@ -126,16 +200,36 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet
return
}
// vpnAddrs outside our vpn networks are of no use to us, filter them out
if _, ok := f.myVpnNetworksTable.Lookup(vpnAddr); !ok {
continue
}
filteredNetworks = append(filteredNetworks, network)
vpnAddrs = append(vpnAddrs, vpnAddr)
}
if len(vpnAddrs) == 0 {
f.l.WithError(err).WithField("udpAddr", addr).
WithField("certName", certName).
WithField("fingerprint", fingerprint).
WithField("issuer", issuer).
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).Error("No usable vpn addresses from host, refusing handshake")
return
}
if addr.IsValid() {
if !f.lightHouse.GetRemoteAllowList().Allow(vpnIp, addr.Addr()) {
f.l.WithField("vpnIp", vpnIp).WithField("udpAddr", addr).Debug("lighthouse.remote_allow_list denied incoming handshake")
// addr can be invalid when the tunnel is being relayed.
// We only want to apply the remote allow list for direct tunnels here
if !f.lightHouse.GetRemoteAllowList().AllowAll(vpnAddrs, addr.Addr()) {
f.l.WithField("vpnAddrs", vpnAddrs).WithField("udpAddr", addr).Debug("lighthouse.remote_allow_list denied incoming handshake")
return
}
}
myIndex, err := generateIndex(f.l)
if err != nil {
f.l.WithError(err).WithField("vpnIp", vpnIp).WithField("udpAddr", addr).
f.l.WithError(err).WithField("vpnAddrs", vpnAddrs).WithField("udpAddr", addr).
WithField("certName", certName).
WithField("fingerprint", fingerprint).
WithField("issuer", issuer).
@ -147,17 +241,17 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet
ConnectionState: ci,
localIndexId: myIndex,
remoteIndexId: hs.Details.InitiatorIndex,
vpnIp: vpnIp,
vpnAddrs: vpnAddrs,
HandshakePacket: make(map[uint8][]byte, 0),
lastHandshakeTime: hs.Details.Time,
relayState: RelayState{
relays: map[netip.Addr]struct{}{},
relayForByIp: map[netip.Addr]*Relay{},
relayForByAddr: map[netip.Addr]*Relay{},
relayForByIdx: map[uint32]*Relay{},
},
}
f.l.WithField("vpnIp", vpnIp).WithField("udpAddr", addr).
f.l.WithField("vpnAddrs", vpnAddrs).WithField("udpAddr", addr).
WithField("certName", certName).
WithField("fingerprint", fingerprint).
WithField("issuer", issuer).
@ -166,13 +260,26 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet
Info("Handshake message received")
hs.Details.ResponderIndex = myIndex
hs.Details.Cert = certState.RawCertificateNoKey
hs.Details.Cert = cs.getHandshakeBytes(ci.myCert.Version())
if hs.Details.Cert == nil {
f.l.WithField("vpnAddrs", vpnAddrs).WithField("udpAddr", addr).
WithField("certName", certName).
WithField("fingerprint", fingerprint).
WithField("issuer", issuer).
WithField("initiatorIndex", hs.Details.InitiatorIndex).WithField("responderIndex", hs.Details.ResponderIndex).
WithField("remoteIndex", h.RemoteIndex).WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).
WithField("certVersion", ci.myCert.Version()).
Error("Unable to handshake with host because no certificate handshake bytes is available")
return
}
hs.Details.CertVersion = uint32(ci.myCert.Version())
// Update the time in case their clock is way off from ours
hs.Details.Time = uint64(time.Now().UnixNano())
hsBytes, err := hs.Marshal()
if err != nil {
f.l.WithError(err).WithField("vpnIp", hostinfo.vpnIp).WithField("udpAddr", addr).
f.l.WithError(err).WithField("vpnAddrs", hostinfo.vpnAddrs).WithField("udpAddr", addr).
WithField("certName", certName).
WithField("fingerprint", fingerprint).
WithField("issuer", issuer).
@ -183,14 +290,14 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet
nh := header.Encode(make([]byte, header.Len), header.Version, header.Handshake, header.HandshakeIXPSK0, hs.Details.InitiatorIndex, 2)
msg, dKey, eKey, err := ci.H.WriteMessage(nh, hsBytes)
if err != nil {
f.l.WithError(err).WithField("vpnIp", hostinfo.vpnIp).WithField("udpAddr", addr).
f.l.WithError(err).WithField("vpnAddrs", hostinfo.vpnAddrs).WithField("udpAddr", addr).
WithField("certName", certName).
WithField("fingerprint", fingerprint).
WithField("issuer", issuer).
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).Error("Failed to call noise.WriteMessage")
return
} else if dKey == nil || eKey == nil {
f.l.WithField("vpnIp", hostinfo.vpnIp).WithField("udpAddr", addr).
f.l.WithField("vpnAddrs", hostinfo.vpnAddrs).WithField("udpAddr", addr).
WithField("certName", certName).
WithField("fingerprint", fingerprint).
WithField("issuer", issuer).
@ -214,9 +321,9 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet
ci.dKey = NewNebulaCipherState(dKey)
ci.eKey = NewNebulaCipherState(eKey)
hostinfo.remotes = f.lightHouse.QueryCache(vpnIp)
hostinfo.remotes = f.lightHouse.QueryCache(vpnAddrs)
hostinfo.SetRemote(addr)
hostinfo.CreateRemoteCIDR(remoteCert)
hostinfo.buildNetworks(filteredNetworks, remoteCert.Certificate.UnsafeNetworks())
existing, err := f.handshakeManager.CheckAndComplete(hostinfo, 0, f)
if err != nil {
@ -226,7 +333,7 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet
if existing.SetRemoteIfPreferred(f.hostMap, addr) {
// Send a test packet to ensure the other side has also switched to
// the preferred remote
f.SendMessageToVpnIp(header.Test, header.TestRequest, vpnIp, []byte(""), make([]byte, 12, 12), make([]byte, mtu))
f.SendMessageToVpnAddr(header.Test, header.TestRequest, vpnAddrs[0], []byte(""), make([]byte, 12, 12), make([]byte, mtu))
}
msg = existing.HandshakePacket[2]
@ -234,11 +341,11 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet
if addr.IsValid() {
err := f.outside.WriteTo(msg, addr)
if err != nil {
f.l.WithField("vpnIp", existing.vpnIp).WithField("udpAddr", addr).
f.l.WithField("vpnAddrs", existing.vpnAddrs).WithField("udpAddr", addr).
WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).WithField("cached", true).
WithError(err).Error("Failed to send handshake message")
} else {
f.l.WithField("vpnIp", existing.vpnIp).WithField("udpAddr", addr).
f.l.WithField("vpnAddrs", existing.vpnAddrs).WithField("udpAddr", addr).
WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).WithField("cached", true).
Info("Handshake message sent")
}
@ -248,16 +355,16 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet
f.l.Error("Handshake send failed: both addr and via are nil.")
return
}
hostinfo.relayState.InsertRelayTo(via.relayHI.vpnIp)
hostinfo.relayState.InsertRelayTo(via.relayHI.vpnAddrs[0])
f.SendVia(via.relayHI, via.relay, msg, make([]byte, 12), make([]byte, mtu), false)
f.l.WithField("vpnIp", existing.vpnIp).WithField("relay", via.relayHI.vpnIp).
f.l.WithField("vpnAddrs", existing.vpnAddrs).WithField("relay", via.relayHI.vpnAddrs[0]).
WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).WithField("cached", true).
Info("Handshake message sent")
return
}
case ErrExistingHostInfo:
// This means there was an existing tunnel and this handshake was older than the one we are currently based on
f.l.WithField("vpnIp", vpnIp).WithField("udpAddr", addr).
f.l.WithField("vpnAddrs", vpnAddrs).WithField("udpAddr", addr).
WithField("certName", certName).
WithField("oldHandshakeTime", existing.lastHandshakeTime).
WithField("newHandshakeTime", hostinfo.lastHandshakeTime).
@ -268,23 +375,23 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet
Info("Handshake too old")
// Send a test packet to trigger an authenticated tunnel test, this should suss out any lingering tunnel issues
f.SendMessageToVpnIp(header.Test, header.TestRequest, vpnIp, []byte(""), make([]byte, 12, 12), make([]byte, mtu))
f.SendMessageToVpnAddr(header.Test, header.TestRequest, vpnAddrs[0], []byte(""), make([]byte, 12, 12), make([]byte, mtu))
return
case ErrLocalIndexCollision:
// This means we failed to insert because of collision on localIndexId. Just let the next handshake packet retry
f.l.WithField("vpnIp", vpnIp).WithField("udpAddr", addr).
f.l.WithField("vpnAddrs", vpnAddrs).WithField("udpAddr", addr).
WithField("certName", certName).
WithField("fingerprint", fingerprint).
WithField("issuer", issuer).
WithField("initiatorIndex", hs.Details.InitiatorIndex).WithField("responderIndex", hs.Details.ResponderIndex).
WithField("remoteIndex", h.RemoteIndex).WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).
WithField("localIndex", hostinfo.localIndexId).WithField("collision", existing.vpnIp).
WithField("localIndex", hostinfo.localIndexId).WithField("collision", existing.vpnAddrs).
Error("Failed to add HostInfo due to localIndex collision")
return
default:
// Shouldn't happen, but just in case someone adds a new error type to CheckAndComplete
// And we forget to update it here
f.l.WithError(err).WithField("vpnIp", vpnIp).WithField("udpAddr", addr).
f.l.WithError(err).WithField("vpnAddrs", vpnAddrs).WithField("udpAddr", addr).
WithField("certName", certName).
WithField("fingerprint", fingerprint).
WithField("issuer", issuer).
@ -300,7 +407,7 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet
if addr.IsValid() {
err = f.outside.WriteTo(msg, addr)
if err != nil {
f.l.WithField("vpnIp", vpnIp).WithField("udpAddr", addr).
f.l.WithField("vpnAddrs", vpnAddrs).WithField("udpAddr", addr).
WithField("certName", certName).
WithField("fingerprint", fingerprint).
WithField("issuer", issuer).
@ -308,7 +415,7 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet
WithField("remoteIndex", h.RemoteIndex).WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).
WithError(err).Error("Failed to send handshake")
} else {
f.l.WithField("vpnIp", vpnIp).WithField("udpAddr", addr).
f.l.WithField("vpnAddrs", vpnAddrs).WithField("udpAddr", addr).
WithField("certName", certName).
WithField("fingerprint", fingerprint).
WithField("issuer", issuer).
@ -321,9 +428,12 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet
f.l.Error("Handshake send failed: both addr and via are nil.")
return
}
hostinfo.relayState.InsertRelayTo(via.relayHI.vpnIp)
hostinfo.relayState.InsertRelayTo(via.relayHI.vpnAddrs[0])
// I successfully received a handshake. Just in case I marked this tunnel as 'Disestablished', ensure
// it's correctly marked as working.
via.relayHI.relayState.UpdateRelayForByIdxState(via.remoteIdx, Established)
f.SendVia(via.relayHI, via.relay, msg, make([]byte, 12), make([]byte, mtu), false)
f.l.WithField("vpnIp", vpnIp).WithField("relay", via.relayHI.vpnIp).
f.l.WithField("vpnAddrs", vpnAddrs).WithField("relay", via.relayHI.vpnAddrs[0]).
WithField("certName", certName).
WithField("fingerprint", fingerprint).
WithField("issuer", issuer).
@ -350,8 +460,9 @@ func ixHandshakeStage2(f *Interface, addr netip.AddrPort, via *ViaSender, hh *Ha
hostinfo := hh.hostinfo
if addr.IsValid() {
if !f.lightHouse.GetRemoteAllowList().Allow(hostinfo.vpnIp, addr.Addr()) {
f.l.WithField("vpnIp", hostinfo.vpnIp).WithField("udpAddr", addr).Debug("lighthouse.remote_allow_list denied incoming handshake")
// The vpnAddr we know about is the one we tried to handshake with, use it to apply the remote allow list.
if !f.lightHouse.GetRemoteAllowList().AllowAll(hostinfo.vpnAddrs, addr.Addr()) {
f.l.WithField("vpnAddrs", hostinfo.vpnAddrs).WithField("udpAddr", addr).Debug("lighthouse.remote_allow_list denied incoming handshake")
return false
}
}
@ -359,7 +470,7 @@ func ixHandshakeStage2(f *Interface, addr netip.AddrPort, via *ViaSender, hh *Ha
ci := hostinfo.ConnectionState
msg, eKey, dKey, err := ci.H.ReadMessage(nil, packet[header.Len:])
if err != nil {
f.l.WithError(err).WithField("vpnIp", hostinfo.vpnIp).WithField("udpAddr", addr).
f.l.WithError(err).WithField("vpnAddrs", hostinfo.vpnAddrs).WithField("udpAddr", addr).
WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).WithField("header", h).
Error("Failed to call noise.ReadMessage")
@ -368,7 +479,7 @@ func ixHandshakeStage2(f *Interface, addr netip.AddrPort, via *ViaSender, hh *Ha
// near future
return false
} else if dKey == nil || eKey == nil {
f.l.WithField("vpnIp", hostinfo.vpnIp).WithField("udpAddr", addr).
f.l.WithField("vpnAddrs", hostinfo.vpnAddrs).WithField("udpAddr", addr).
WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).
Error("Noise did not arrive at a key")
@ -380,95 +491,56 @@ func ixHandshakeStage2(f *Interface, addr netip.AddrPort, via *ViaSender, hh *Ha
hs := &NebulaHandshake{}
err = hs.Unmarshal(msg)
if err != nil || hs.Details == nil {
f.l.WithError(err).WithField("vpnIp", hostinfo.vpnIp).WithField("udpAddr", addr).
f.l.WithError(err).WithField("vpnAddrs", hostinfo.vpnAddrs).WithField("udpAddr", addr).
WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).Error("Failed unmarshal handshake message")
// The handshake state machine is complete, if things break now there is no chance to recover. Tear down and start again
return true
}
remoteCert, err := RecombineCertAndValidate(ci.H, hs.Details.Cert, f.pki.GetCAPool())
rc, err := cert.Recombine(cert.Version(hs.Details.CertVersion), hs.Details.Cert, ci.H.PeerStatic(), ci.Curve())
if err != nil {
e := f.l.WithError(err).WithField("vpnIp", hostinfo.vpnIp).WithField("udpAddr", addr).
WithField("handshake", m{"stage": 2, "style": "ix_psk0"})
if f.l.Level > logrus.DebugLevel {
e = e.WithField("cert", remoteCert)
}
e.Error("Invalid certificate from host")
// The handshake state machine is complete, if things break now there is no chance to recover. Tear down and start again
return true
}
vpnIp, ok := netip.AddrFromSlice(remoteCert.Details.Ips[0].IP)
if !ok {
e := f.l.WithError(err).WithField("udpAddr", addr).
WithField("handshake", m{"stage": 2, "style": "ix_psk0"})
if f.l.Level > logrus.DebugLevel {
e = e.WithField("cert", remoteCert)
}
e.Info("Invalid vpn ip from host")
return true
}
vpnIp = vpnIp.Unmap()
certName := remoteCert.Details.Name
fingerprint, _ := remoteCert.Sha256Sum()
issuer := remoteCert.Details.Issuer
// Ensure the right host responded
if vpnIp != hostinfo.vpnIp {
f.l.WithField("intendedVpnIp", hostinfo.vpnIp).WithField("haveVpnIp", vpnIp).
WithField("udpAddr", addr).WithField("certName", certName).
f.l.WithError(err).WithField("udpAddr", addr).
WithField("vpnAddrs", hostinfo.vpnAddrs).
WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).
Info("Incorrect host responded to handshake")
// Release our old handshake from pending, it should not continue
f.handshakeManager.DeleteHostInfo(hostinfo)
// Create a new hostinfo/handshake for the intended vpn ip
f.handshakeManager.StartHandshake(hostinfo.vpnIp, func(newHH *HandshakeHostInfo) {
//TODO: this doesnt know if its being added or is being used for caching a packet
// Block the current used address
newHH.hostinfo.remotes = hostinfo.remotes
newHH.hostinfo.remotes.BlockRemote(addr)
// Get the correct remote list for the host we did handshake with
hostinfo.remotes = f.lightHouse.QueryCache(vpnIp)
f.l.WithField("blockedUdpAddrs", newHH.hostinfo.remotes.CopyBlockedRemotes()).WithField("vpnIp", vpnIp).
WithField("remotes", newHH.hostinfo.remotes.CopyAddrs(f.hostMap.GetPreferredRanges())).
Info("Blocked addresses for handshakes")
// Swap the packet store to benefit the original intended recipient
newHH.packetStore = hh.packetStore
hh.packetStore = []*cachedPacket{}
// Finally, put the correct vpn ip in the host info, tell them to close the tunnel, and return true to tear down
hostinfo.vpnIp = vpnIp
f.sendCloseTunnel(hostinfo)
})
Info("Handshake did not contain a certificate")
return true
}
// Mark packet 2 as seen so it doesn't show up as missed
ci.window.Update(f.l, 2)
remoteCert, err := f.pki.GetCAPool().VerifyCertificate(time.Now(), rc)
if err != nil {
fp, err := rc.Fingerprint()
if err != nil {
fp = "<error generating certificate fingerprint>"
}
duration := time.Since(hh.startTime).Nanoseconds()
f.l.WithField("vpnIp", vpnIp).WithField("udpAddr", addr).
WithField("certName", certName).
WithField("fingerprint", fingerprint).
WithField("issuer", issuer).
WithField("initiatorIndex", hs.Details.InitiatorIndex).WithField("responderIndex", hs.Details.ResponderIndex).
WithField("remoteIndex", h.RemoteIndex).WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).
WithField("durationNs", duration).
WithField("sentCachedPackets", len(hh.packetStore)).
Info("Handshake message received")
e := f.l.WithError(err).WithField("udpAddr", addr).
WithField("vpnAddrs", hostinfo.vpnAddrs).
WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).
WithField("certFingerprint", fp).
WithField("certVpnNetworks", rc.Networks())
if f.l.Level >= logrus.DebugLevel {
e = e.WithField("cert", rc)
}
e.Info("Invalid certificate from host")
return true
}
if len(remoteCert.Certificate.Networks()) == 0 {
f.l.WithError(err).WithField("udpAddr", addr).
WithField("vpnAddrs", hostinfo.vpnAddrs).
WithField("cert", remoteCert).
WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).
Info("No networks in certificate")
return true
}
vpnNetworks := remoteCert.Certificate.Networks()
certName := remoteCert.Certificate.Name()
fingerprint := remoteCert.Fingerprint
issuer := remoteCert.Certificate.Issuer()
hostinfo.remoteIndexId = hs.Details.ResponderIndex
hostinfo.lastHandshakeTime = hs.Details.Time
@ -482,13 +554,83 @@ func ixHandshakeStage2(f *Interface, addr netip.AddrPort, via *ViaSender, hh *Ha
if addr.IsValid() {
hostinfo.SetRemote(addr)
} else {
hostinfo.relayState.InsertRelayTo(via.relayHI.vpnIp)
hostinfo.relayState.InsertRelayTo(via.relayHI.vpnAddrs[0])
}
// Build up the radix for the firewall if we have subnets in the cert
hostinfo.CreateRemoteCIDR(remoteCert)
var vpnAddrs []netip.Addr
var filteredNetworks []netip.Prefix
for _, network := range vpnNetworks {
// vpnAddrs outside our vpn networks are of no use to us, filter them out
vpnAddr := network.Addr()
if _, ok := f.myVpnNetworksTable.Lookup(vpnAddr); !ok {
continue
}
// Complete our handshake and update metrics, this will replace any existing tunnels for this vpnIp
filteredNetworks = append(filteredNetworks, network)
vpnAddrs = append(vpnAddrs, vpnAddr)
}
if len(vpnAddrs) == 0 {
f.l.WithError(err).WithField("udpAddr", addr).
WithField("certName", certName).
WithField("fingerprint", fingerprint).
WithField("issuer", issuer).
WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).Error("No usable vpn addresses from host, refusing handshake")
return true
}
// Ensure the right host responded
if !slices.Contains(vpnAddrs, hostinfo.vpnAddrs[0]) {
f.l.WithField("intendedVpnAddrs", hostinfo.vpnAddrs).WithField("haveVpnNetworks", vpnNetworks).
WithField("udpAddr", addr).WithField("certName", certName).
WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).
Info("Incorrect host responded to handshake")
// Release our old handshake from pending, it should not continue
f.handshakeManager.DeleteHostInfo(hostinfo)
// Create a new hostinfo/handshake for the intended vpn ip
f.handshakeManager.StartHandshake(hostinfo.vpnAddrs[0], func(newHH *HandshakeHostInfo) {
// Block the current used address
newHH.hostinfo.remotes = hostinfo.remotes
newHH.hostinfo.remotes.BlockRemote(addr)
f.l.WithField("blockedUdpAddrs", newHH.hostinfo.remotes.CopyBlockedRemotes()).
WithField("vpnNetworks", vpnNetworks).
WithField("remotes", newHH.hostinfo.remotes.CopyAddrs(f.hostMap.GetPreferredRanges())).
Info("Blocked addresses for handshakes")
// Swap the packet store to benefit the original intended recipient
newHH.packetStore = hh.packetStore
hh.packetStore = []*cachedPacket{}
// Finally, put the correct vpn addrs in the host info, tell them to close the tunnel, and return true to tear down
hostinfo.vpnAddrs = vpnAddrs
f.sendCloseTunnel(hostinfo)
})
return true
}
// Mark packet 2 as seen so it doesn't show up as missed
ci.window.Update(f.l, 2)
duration := time.Since(hh.startTime).Nanoseconds()
f.l.WithField("vpnAddrs", vpnAddrs).WithField("udpAddr", addr).
WithField("certName", certName).
WithField("fingerprint", fingerprint).
WithField("issuer", issuer).
WithField("initiatorIndex", hs.Details.InitiatorIndex).WithField("responderIndex", hs.Details.ResponderIndex).
WithField("remoteIndex", h.RemoteIndex).WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).
WithField("durationNs", duration).
WithField("sentCachedPackets", len(hh.packetStore)).
Info("Handshake message received")
// Build up the radix for the firewall if we have subnets in the cert
hostinfo.vpnAddrs = vpnAddrs
hostinfo.buildNetworks(filteredNetworks, remoteCert.Certificate.UnsafeNetworks())
// Complete our handshake and update metrics, this will replace any existing tunnels for the vpnAddrs here
f.handshakeManager.Complete(hostinfo, f)
f.connectionManager.AddTrafficWatch(hostinfo.localIndexId)

View File

@ -7,14 +7,15 @@ import (
"encoding/binary"
"errors"
"net/netip"
"slices"
"sync"
"time"
"github.com/rcrowley/go-metrics"
"github.com/sirupsen/logrus"
"github.com/slackhq/nebula/cert"
"github.com/slackhq/nebula/header"
"github.com/slackhq/nebula/udp"
"golang.org/x/exp/slices"
)
const (
@ -118,18 +119,18 @@ func NewHandshakeManager(l *logrus.Logger, mainHostMap *HostMap, lightHouse *Lig
}
}
func (c *HandshakeManager) Run(ctx context.Context) {
clockSource := time.NewTicker(c.config.tryInterval)
func (hm *HandshakeManager) Run(ctx context.Context) {
clockSource := time.NewTicker(hm.config.tryInterval)
defer clockSource.Stop()
for {
select {
case <-ctx.Done():
return
case vpnIP := <-c.trigger:
c.handleOutbound(vpnIP, true)
case vpnIP := <-hm.trigger:
hm.handleOutbound(vpnIP, true)
case now := <-clockSource.C:
c.NextOutboundHandshakeTimerTick(now)
hm.NextOutboundHandshakeTimerTick(now)
}
}
}
@ -137,7 +138,7 @@ func (c *HandshakeManager) Run(ctx context.Context) {
func (hm *HandshakeManager) HandleIncoming(addr netip.AddrPort, via *ViaSender, packet []byte, h *header.H) {
// First remote allow list check before we know the vpnIp
if addr.IsValid() {
if !hm.lightHouse.GetRemoteAllowList().AllowUnknownVpnIp(addr.Addr()) {
if !hm.lightHouse.GetRemoteAllowList().AllowUnknownVpnAddr(addr.Addr()) {
hm.l.WithField("udpAddr", addr).Debug("lighthouse.remote_allow_list denied incoming handshake")
return
}
@ -159,14 +160,14 @@ func (hm *HandshakeManager) HandleIncoming(addr netip.AddrPort, via *ViaSender,
}
}
func (c *HandshakeManager) NextOutboundHandshakeTimerTick(now time.Time) {
c.OutboundHandshakeTimer.Advance(now)
func (hm *HandshakeManager) NextOutboundHandshakeTimerTick(now time.Time) {
hm.OutboundHandshakeTimer.Advance(now)
for {
vpnIp, has := c.OutboundHandshakeTimer.Purge()
vpnIp, has := hm.OutboundHandshakeTimer.Purge()
if !has {
break
}
c.handleOutbound(vpnIp, false)
hm.handleOutbound(vpnIp, false)
}
}
@ -208,7 +209,7 @@ func (hm *HandshakeManager) handleOutbound(vpnIp netip.Addr, lighthouseTriggered
// NB ^ This comment doesn't jive. It's how the thing gets initialized.
// It's the common path. Should it update every time, in case a future LH query/queries give us more info?
if hostinfo.remotes == nil {
hostinfo.remotes = hm.lightHouse.QueryCache(vpnIp)
hostinfo.remotes = hm.lightHouse.QueryCache([]netip.Addr{vpnIp})
}
remotes := hostinfo.remotes.CopyAddrs(hm.mainHostMap.GetPreferredRanges())
@ -223,7 +224,7 @@ func (hm *HandshakeManager) handleOutbound(vpnIp netip.Addr, lighthouseTriggered
hh.lastRemotes = remotes
// TODO: this will generate a load of queries for hosts with only 1 ip
// This will generate a load of queries for hosts with only 1 ip
// (such as ones registered to the lighthouse with only a private IP)
// So we only do it one time after attempting 5 handshakes already.
if len(remotes) <= 1 && hh.counter == 5 {
@ -256,7 +257,7 @@ func (hm *HandshakeManager) handleOutbound(vpnIp netip.Addr, lighthouseTriggered
WithField("initiatorIndex", hostinfo.localIndexId).
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).
Info("Handshake message sent")
} else if hm.l.IsLevelEnabled(logrus.DebugLevel) {
} else if hm.l.Level >= logrus.DebugLevel {
hostinfo.logger(hm.l).WithField("udpAddrs", sentTo).
WithField("initiatorIndex", hostinfo.localIndexId).
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).
@ -267,35 +268,118 @@ func (hm *HandshakeManager) handleOutbound(vpnIp netip.Addr, lighthouseTriggered
hostinfo.logger(hm.l).WithField("relays", hostinfo.remotes.relays).Info("Attempt to relay through hosts")
// Send a RelayRequest to all known Relay IP's
for _, relay := range hostinfo.remotes.relays {
// Don't relay to myself, and don't relay through the host I'm trying to connect to
if relay == vpnIp || relay == hm.lightHouse.myVpnNet.Addr() {
// Don't relay to myself
if relay == vpnIp {
continue
}
relayHostInfo := hm.mainHostMap.QueryVpnIp(relay)
// Don't relay through the host I'm trying to connect to
_, found := hm.f.myVpnAddrsTable.Lookup(relay)
if found {
continue
}
relayHostInfo := hm.mainHostMap.QueryVpnAddr(relay)
if relayHostInfo == nil || !relayHostInfo.remote.IsValid() {
hostinfo.logger(hm.l).WithField("relay", relay.String()).Info("Establish tunnel to relay target")
hm.f.Handshake(relay)
continue
}
// Check the relay HostInfo to see if we already established a relay through it
if existingRelay, ok := relayHostInfo.relayState.QueryRelayForByIp(vpnIp); ok {
// Check the relay HostInfo to see if we already established a relay through
existingRelay, ok := relayHostInfo.relayState.QueryRelayForByIp(vpnIp)
if !ok {
// No relays exist or requested yet.
if relayHostInfo.remote.IsValid() {
idx, err := AddRelay(hm.l, relayHostInfo, hm.mainHostMap, vpnIp, nil, TerminalType, Requested)
if err != nil {
hostinfo.logger(hm.l).WithField("relay", relay.String()).WithError(err).Info("Failed to add relay to hostmap")
}
m := NebulaControl{
Type: NebulaControl_CreateRelayRequest,
InitiatorRelayIndex: idx,
}
switch relayHostInfo.GetCert().Certificate.Version() {
case cert.Version1:
if !hm.f.myVpnAddrs[0].Is4() {
hostinfo.logger(hm.l).Error("can not establish v1 relay with a v6 network because the relay is not running a current nebula version")
continue
}
if !vpnIp.Is4() {
hostinfo.logger(hm.l).Error("can not establish v1 relay with a v6 remote network because the relay is not running a current nebula version")
continue
}
b := hm.f.myVpnAddrs[0].As4()
m.OldRelayFromAddr = binary.BigEndian.Uint32(b[:])
b = vpnIp.As4()
m.OldRelayToAddr = binary.BigEndian.Uint32(b[:])
case cert.Version2:
m.RelayFromAddr = netAddrToProtoAddr(hm.f.myVpnAddrs[0])
m.RelayToAddr = netAddrToProtoAddr(vpnIp)
default:
hostinfo.logger(hm.l).Error("Unknown certificate version found while creating relay")
continue
}
msg, err := m.Marshal()
if err != nil {
hostinfo.logger(hm.l).
WithError(err).
Error("Failed to marshal Control message to create relay")
} else {
hm.f.SendMessageToHostInfo(header.Control, 0, relayHostInfo, msg, make([]byte, 12), make([]byte, mtu))
hm.l.WithFields(logrus.Fields{
"relayFrom": hm.f.myVpnAddrs[0],
"relayTo": vpnIp,
"initiatorRelayIndex": idx,
"relay": relay}).
Info("send CreateRelayRequest")
}
}
continue
}
switch existingRelay.State {
case Established:
hostinfo.logger(hm.l).WithField("relay", relay.String()).Info("Send handshake via relay")
hm.f.SendVia(relayHostInfo, existingRelay, hostinfo.HandshakePacket[0], make([]byte, 12), make([]byte, mtu), false)
case Disestablished:
// Mark this relay as 'requested'
relayHostInfo.relayState.UpdateRelayForByIpState(vpnIp, Requested)
fallthrough
case Requested:
hostinfo.logger(hm.l).WithField("relay", relay.String()).Info("Re-send CreateRelay request")
//TODO: IPV6-WORK
myVpnIpB := hm.f.myVpnNet.Addr().As4()
theirVpnIpB := vpnIp.As4()
// Re-send the CreateRelay request, in case the previous one was lost.
m := NebulaControl{
Type: NebulaControl_CreateRelayRequest,
InitiatorRelayIndex: existingRelay.LocalIndex,
RelayFromIp: binary.BigEndian.Uint32(myVpnIpB[:]),
RelayToIp: binary.BigEndian.Uint32(theirVpnIpB[:]),
}
switch relayHostInfo.GetCert().Certificate.Version() {
case cert.Version1:
if !hm.f.myVpnAddrs[0].Is4() {
hostinfo.logger(hm.l).Error("can not establish v1 relay with a v6 network because the relay is not running a current nebula version")
continue
}
if !vpnIp.Is4() {
hostinfo.logger(hm.l).Error("can not establish v1 relay with a v6 remote network because the relay is not running a current nebula version")
continue
}
b := hm.f.myVpnAddrs[0].As4()
m.OldRelayFromAddr = binary.BigEndian.Uint32(b[:])
b = vpnIp.As4()
m.OldRelayToAddr = binary.BigEndian.Uint32(b[:])
case cert.Version2:
m.RelayFromAddr = netAddrToProtoAddr(hm.f.myVpnAddrs[0])
m.RelayToAddr = netAddrToProtoAddr(vpnIp)
default:
hostinfo.logger(hm.l).Error("Unknown certificate version found while creating relay")
continue
}
msg, err := m.Marshal()
if err != nil {
@ -306,52 +390,22 @@ func (hm *HandshakeManager) handleOutbound(vpnIp netip.Addr, lighthouseTriggered
// This must send over the hostinfo, not over hm.Hosts[ip]
hm.f.SendMessageToHostInfo(header.Control, 0, relayHostInfo, msg, make([]byte, 12), make([]byte, mtu))
hm.l.WithFields(logrus.Fields{
"relayFrom": hm.f.myVpnNet.Addr(),
"relayFrom": hm.f.myVpnAddrs[0],
"relayTo": vpnIp,
"initiatorRelayIndex": existingRelay.LocalIndex,
"relay": relay}).
Info("send CreateRelayRequest")
}
case PeerRequested:
// PeerRequested only occurs in Forwarding relays, not Terminal relays, and this is a Terminal relay case.
fallthrough
default:
hostinfo.logger(hm.l).
WithField("vpnIp", vpnIp).
WithField("state", existingRelay.State).
WithField("relay", relayHostInfo.vpnIp).
WithField("relay", relay).
Errorf("Relay unexpected state")
}
} else {
// No relays exist or requested yet.
if relayHostInfo.remote.IsValid() {
idx, err := AddRelay(hm.l, relayHostInfo, hm.mainHostMap, vpnIp, nil, TerminalType, Requested)
if err != nil {
hostinfo.logger(hm.l).WithField("relay", relay.String()).WithError(err).Info("Failed to add relay to hostmap")
}
//TODO: IPV6-WORK
myVpnIpB := hm.f.myVpnNet.Addr().As4()
theirVpnIpB := vpnIp.As4()
m := NebulaControl{
Type: NebulaControl_CreateRelayRequest,
InitiatorRelayIndex: idx,
RelayFromIp: binary.BigEndian.Uint32(myVpnIpB[:]),
RelayToIp: binary.BigEndian.Uint32(theirVpnIpB[:]),
}
msg, err := m.Marshal()
if err != nil {
hostinfo.logger(hm.l).
WithError(err).
Error("Failed to marshal Control message to create relay")
} else {
hm.f.SendMessageToHostInfo(header.Control, 0, relayHostInfo, msg, make([]byte, 12), make([]byte, mtu))
hm.l.WithFields(logrus.Fields{
"relayFrom": hm.f.myVpnNet.Addr(),
"relayTo": vpnIp,
"initiatorRelayIndex": idx,
"relay": relay}).
Info("send CreateRelayRequest")
}
}
}
}
}
@ -381,10 +435,10 @@ func (hm *HandshakeManager) GetOrHandshake(vpnIp netip.Addr, cacheCb func(*Hands
}
// StartHandshake will ensure a handshake is currently being attempted for the provided vpn ip
func (hm *HandshakeManager) StartHandshake(vpnIp netip.Addr, cacheCb func(*HandshakeHostInfo)) *HostInfo {
func (hm *HandshakeManager) StartHandshake(vpnAddr netip.Addr, cacheCb func(*HandshakeHostInfo)) *HostInfo {
hm.Lock()
if hh, ok := hm.vpnIps[vpnIp]; ok {
if hh, ok := hm.vpnIps[vpnAddr]; ok {
// We are already trying to handshake with this vpn ip
if cacheCb != nil {
cacheCb(hh)
@ -394,11 +448,11 @@ func (hm *HandshakeManager) StartHandshake(vpnIp netip.Addr, cacheCb func(*Hands
}
hostinfo := &HostInfo{
vpnIp: vpnIp,
vpnAddrs: []netip.Addr{vpnAddr},
HandshakePacket: make(map[uint8][]byte, 0),
relayState: RelayState{
relays: map[netip.Addr]struct{}{},
relayForByIp: map[netip.Addr]*Relay{},
relayForByAddr: map[netip.Addr]*Relay{},
relayForByIdx: map[uint32]*Relay{},
},
}
@ -407,9 +461,9 @@ func (hm *HandshakeManager) StartHandshake(vpnIp netip.Addr, cacheCb func(*Hands
hostinfo: hostinfo,
startTime: time.Now(),
}
hm.vpnIps[vpnIp] = hh
hm.vpnIps[vpnAddr] = hh
hm.metricInitiated.Inc(1)
hm.OutboundHandshakeTimer.Add(vpnIp, hm.config.tryInterval)
hm.OutboundHandshakeTimer.Add(vpnAddr, hm.config.tryInterval)
if cacheCb != nil {
cacheCb(hh)
@ -417,21 +471,21 @@ func (hm *HandshakeManager) StartHandshake(vpnIp netip.Addr, cacheCb func(*Hands
// If this is a static host, we don't need to wait for the HostQueryReply
// We can trigger the handshake right now
_, doTrigger := hm.lightHouse.GetStaticHostList()[vpnIp]
_, doTrigger := hm.lightHouse.GetStaticHostList()[vpnAddr]
if !doTrigger {
// Add any calculated remotes, and trigger early handshake if one found
doTrigger = hm.lightHouse.addCalculatedRemotes(vpnIp)
doTrigger = hm.lightHouse.addCalculatedRemotes(vpnAddr)
}
if doTrigger {
select {
case hm.trigger <- vpnIp:
case hm.trigger <- vpnAddr:
default:
}
}
hm.Unlock()
hm.lightHouse.QueryServer(vpnIp)
hm.lightHouse.QueryServer(vpnAddr)
return hostinfo
}
@ -452,14 +506,14 @@ var (
//
// ErrLocalIndexCollision if we already have an entry in the main or pending
// hostmap for the hostinfo.localIndexId.
func (c *HandshakeManager) CheckAndComplete(hostinfo *HostInfo, handshakePacket uint8, f *Interface) (*HostInfo, error) {
c.mainHostMap.Lock()
defer c.mainHostMap.Unlock()
c.Lock()
defer c.Unlock()
func (hm *HandshakeManager) CheckAndComplete(hostinfo *HostInfo, handshakePacket uint8, f *Interface) (*HostInfo, error) {
hm.mainHostMap.Lock()
defer hm.mainHostMap.Unlock()
hm.Lock()
defer hm.Unlock()
// Check if we already have a tunnel with this vpn ip
existingHostInfo, found := c.mainHostMap.Hosts[hostinfo.vpnIp]
existingHostInfo, found := hm.mainHostMap.Hosts[hostinfo.vpnAddrs[0]]
if found && existingHostInfo != nil {
testHostInfo := existingHostInfo
for testHostInfo != nil {
@ -476,31 +530,31 @@ func (c *HandshakeManager) CheckAndComplete(hostinfo *HostInfo, handshakePacket
return existingHostInfo, ErrExistingHostInfo
}
existingHostInfo.logger(c.l).Info("Taking new handshake")
existingHostInfo.logger(hm.l).Info("Taking new handshake")
}
existingIndex, found := c.mainHostMap.Indexes[hostinfo.localIndexId]
existingIndex, found := hm.mainHostMap.Indexes[hostinfo.localIndexId]
if found {
// We have a collision, but for a different hostinfo
return existingIndex, ErrLocalIndexCollision
}
existingPendingIndex, found := c.indexes[hostinfo.localIndexId]
existingPendingIndex, found := hm.indexes[hostinfo.localIndexId]
if found && existingPendingIndex.hostinfo != hostinfo {
// We have a collision, but for a different hostinfo
return existingPendingIndex.hostinfo, ErrLocalIndexCollision
}
existingRemoteIndex, found := c.mainHostMap.RemoteIndexes[hostinfo.remoteIndexId]
if found && existingRemoteIndex != nil && existingRemoteIndex.vpnIp != hostinfo.vpnIp {
existingRemoteIndex, found := hm.mainHostMap.RemoteIndexes[hostinfo.remoteIndexId]
if found && existingRemoteIndex != nil && existingRemoteIndex.vpnAddrs[0] != hostinfo.vpnAddrs[0] {
// We have a collision, but this can happen since we can't control
// the remote ID. Just log about the situation as a note.
hostinfo.logger(c.l).
WithField("remoteIndex", hostinfo.remoteIndexId).WithField("collision", existingRemoteIndex.vpnIp).
hostinfo.logger(hm.l).
WithField("remoteIndex", hostinfo.remoteIndexId).WithField("collision", existingRemoteIndex.vpnAddrs).
Info("New host shadows existing host remoteIndex")
}
c.mainHostMap.unlockedAddHostInfo(hostinfo, f)
hm.mainHostMap.unlockedAddHostInfo(hostinfo, f)
return existingHostInfo, nil
}
@ -518,7 +572,7 @@ func (hm *HandshakeManager) Complete(hostinfo *HostInfo, f *Interface) {
// We have a collision, but this can happen since we can't control
// the remote ID. Just log about the situation as a note.
hostinfo.logger(hm.l).
WithField("remoteIndex", hostinfo.remoteIndexId).WithField("collision", existingRemoteIndex.vpnIp).
WithField("remoteIndex", hostinfo.remoteIndexId).WithField("collision", existingRemoteIndex.vpnAddrs).
Info("New host shadows existing host remoteIndex")
}
@ -555,31 +609,34 @@ func (hm *HandshakeManager) allocateIndex(hh *HandshakeHostInfo) error {
return errors.New("failed to generate unique localIndexId")
}
func (c *HandshakeManager) DeleteHostInfo(hostinfo *HostInfo) {
c.Lock()
defer c.Unlock()
c.unlockedDeleteHostInfo(hostinfo)
func (hm *HandshakeManager) DeleteHostInfo(hostinfo *HostInfo) {
hm.Lock()
defer hm.Unlock()
hm.unlockedDeleteHostInfo(hostinfo)
}
func (c *HandshakeManager) unlockedDeleteHostInfo(hostinfo *HostInfo) {
delete(c.vpnIps, hostinfo.vpnIp)
if len(c.vpnIps) == 0 {
c.vpnIps = map[netip.Addr]*HandshakeHostInfo{}
func (hm *HandshakeManager) unlockedDeleteHostInfo(hostinfo *HostInfo) {
for _, addr := range hostinfo.vpnAddrs {
delete(hm.vpnIps, addr)
}
delete(c.indexes, hostinfo.localIndexId)
if len(c.vpnIps) == 0 {
c.indexes = map[uint32]*HandshakeHostInfo{}
if len(hm.vpnIps) == 0 {
hm.vpnIps = map[netip.Addr]*HandshakeHostInfo{}
}
if c.l.Level >= logrus.DebugLevel {
c.l.WithField("hostMap", m{"mapTotalSize": len(c.vpnIps),
"vpnIp": hostinfo.vpnIp, "indexNumber": hostinfo.localIndexId, "remoteIndexNumber": hostinfo.remoteIndexId}).
delete(hm.indexes, hostinfo.localIndexId)
if len(hm.indexes) == 0 {
hm.indexes = map[uint32]*HandshakeHostInfo{}
}
if hm.l.Level >= logrus.DebugLevel {
hm.l.WithField("hostMap", m{"mapTotalSize": len(hm.vpnIps),
"vpnAddrs": hostinfo.vpnAddrs, "indexNumber": hostinfo.localIndexId, "remoteIndexNumber": hostinfo.remoteIndexId}).
Debug("Pending hostmap hostInfo deleted")
}
}
func (hm *HandshakeManager) QueryVpnIp(vpnIp netip.Addr) *HostInfo {
func (hm *HandshakeManager) QueryVpnAddr(vpnIp netip.Addr) *HostInfo {
hh := hm.queryVpnIp(vpnIp)
if hh != nil {
return hh.hostinfo
@ -608,37 +665,37 @@ func (hm *HandshakeManager) queryIndex(index uint32) *HandshakeHostInfo {
return hm.indexes[index]
}
func (c *HandshakeManager) GetPreferredRanges() []netip.Prefix {
return c.mainHostMap.GetPreferredRanges()
func (hm *HandshakeManager) GetPreferredRanges() []netip.Prefix {
return hm.mainHostMap.GetPreferredRanges()
}
func (c *HandshakeManager) ForEachVpnIp(f controlEach) {
c.RLock()
defer c.RUnlock()
func (hm *HandshakeManager) ForEachVpnAddr(f controlEach) {
hm.RLock()
defer hm.RUnlock()
for _, v := range c.vpnIps {
for _, v := range hm.vpnIps {
f(v.hostinfo)
}
}
func (c *HandshakeManager) ForEachIndex(f controlEach) {
c.RLock()
defer c.RUnlock()
func (hm *HandshakeManager) ForEachIndex(f controlEach) {
hm.RLock()
defer hm.RUnlock()
for _, v := range c.indexes {
for _, v := range hm.indexes {
f(v.hostinfo)
}
}
func (c *HandshakeManager) EmitStats() {
c.RLock()
hostLen := len(c.vpnIps)
indexLen := len(c.indexes)
c.RUnlock()
func (hm *HandshakeManager) EmitStats() {
hm.RLock()
hostLen := len(hm.vpnIps)
indexLen := len(hm.indexes)
hm.RUnlock()
metrics.GetOrRegisterGauge("hostmap.pending.hosts", nil).Update(int64(hostLen))
metrics.GetOrRegisterGauge("hostmap.pending.indexes", nil).Update(int64(indexLen))
c.mainHostMap.EmitStats()
hm.mainHostMap.EmitStats()
}
// Utility functions below

View File

@ -14,21 +14,20 @@ import (
func Test_NewHandshakeManagerVpnIp(t *testing.T) {
l := test.NewLogger()
vpncidr := netip.MustParsePrefix("172.1.1.1/24")
localrange := netip.MustParsePrefix("10.1.1.1/24")
ip := netip.MustParseAddr("172.1.1.2")
preferredRanges := []netip.Prefix{localrange}
mainHM := newHostMap(l, vpncidr)
mainHM := newHostMap(l)
mainHM.preferredRanges.Store(&preferredRanges)
lh := newTestLighthouse()
cs := &CertState{
RawCertificate: []byte{},
PrivateKey: []byte{},
Certificate: &cert.NebulaCertificate{},
RawCertificateNoKey: []byte{},
defaultVersion: cert.Version1,
privateKey: []byte{},
v1Cert: &dummyCert{version: cert.Version1},
v1HandshakeBytes: []byte{},
}
blah := NewHandshakeManager(l, mainHM, lh, &udp.NoopConn{}, defaultHandshakeConfig)
@ -42,10 +41,10 @@ func Test_NewHandshakeManagerVpnIp(t *testing.T) {
i2 := blah.StartHandshake(ip, nil)
assert.Same(t, i, i2)
i.remotes = NewRemoteList(nil)
i.remotes = NewRemoteList([]netip.Addr{}, nil)
// Adding something to pending should not affect the main hostmap
assert.Len(t, mainHM.Hosts, 0)
assert.Empty(t, mainHM.Hosts)
// Confirm they are in the pending index list
assert.Contains(t, blah.vpnIps, ip)
@ -80,16 +79,24 @@ func testCountTimerWheelEntries(tw *LockingTimerWheel[netip.Addr]) (c int) {
type mockEncWriter struct {
}
func (mw *mockEncWriter) SendMessageToVpnIp(t header.MessageType, st header.MessageSubType, vpnIp netip.Addr, p, nb, out []byte) {
func (mw *mockEncWriter) SendMessageToVpnAddr(_ header.MessageType, _ header.MessageSubType, _ netip.Addr, _, _, _ []byte) {
return
}
func (mw *mockEncWriter) SendVia(via *HostInfo, relay *Relay, ad, nb, out []byte, nocopy bool) {
func (mw *mockEncWriter) SendVia(_ *HostInfo, _ *Relay, _, _, _ []byte, _ bool) {
return
}
func (mw *mockEncWriter) SendMessageToHostInfo(t header.MessageType, st header.MessageSubType, hostinfo *HostInfo, p, nb, out []byte) {
func (mw *mockEncWriter) SendMessageToHostInfo(_ header.MessageType, _ header.MessageSubType, _ *HostInfo, _, _, _ []byte) {
return
}
func (mw *mockEncWriter) Handshake(vpnIP netip.Addr) {}
func (mw *mockEncWriter) Handshake(_ netip.Addr) {}
func (mw *mockEncWriter) GetHostInfo(_ netip.Addr) *HostInfo {
return nil
}
func (mw *mockEncWriter) GetCertState() *CertState {
return &CertState{defaultVersion: cert.Version2}
}

View File

@ -5,6 +5,7 @@ import (
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
type headerTest struct {
@ -111,7 +112,7 @@ func TestHeader_String(t *testing.T) {
func TestHeader_MarshalJSON(t *testing.T) {
b, err := (&H{100, Test, TestRequest, 99, 98, 97}).MarshalJSON()
assert.Nil(t, err)
require.NoError(t, err)
assert.Equal(
t,
"{\"messageCounter\":97,\"remoteIndex\":98,\"reserved\":99,\"subType\":\"testRequest\",\"type\":\"test\",\"version\":100}",

View File

@ -35,6 +35,7 @@ const (
Requested = iota
PeerRequested
Established
Disestablished
)
const (
@ -48,7 +49,7 @@ type Relay struct {
State int
LocalIndex uint32
RemoteIndex uint32
PeerIp netip.Addr
PeerAddr netip.Addr
}
type HostMap struct {
@ -58,7 +59,6 @@ type HostMap struct {
RemoteIndexes map[uint32]*HostInfo
Hosts map[netip.Addr]*HostInfo
preferredRanges atomic.Pointer[[]netip.Prefix]
vpnCIDR netip.Prefix
l *logrus.Logger
}
@ -68,8 +68,11 @@ type HostMap struct {
type RelayState struct {
sync.RWMutex
relays map[netip.Addr]struct{} // Set of VpnIp's of Hosts to use as relays to access this peer
relayForByIp map[netip.Addr]*Relay // Maps VpnIps of peers for which this HostInfo is a relay to some Relay info
relays map[netip.Addr]struct{} // Set of vpnAddr's of Hosts to use as relays to access this peer
// For data race avoidance, the contents of a *Relay are treated immutably. To update a *Relay, copy the existing data,
// modify what needs to be updated, and store the new modified copy in the relayForByIp and relayForByIdx maps (with
// the RelayState Lock held)
relayForByAddr map[netip.Addr]*Relay // Maps vpnAddr of peers for which this HostInfo is a relay to some Relay info
relayForByIdx map[uint32]*Relay // Maps a local index to some Relay info
}
@ -79,6 +82,28 @@ func (rs *RelayState) DeleteRelay(ip netip.Addr) {
delete(rs.relays, ip)
}
func (rs *RelayState) UpdateRelayForByIpState(vpnIp netip.Addr, state int) {
rs.Lock()
defer rs.Unlock()
if r, ok := rs.relayForByAddr[vpnIp]; ok {
newRelay := *r
newRelay.State = state
rs.relayForByAddr[newRelay.PeerAddr] = &newRelay
rs.relayForByIdx[newRelay.LocalIndex] = &newRelay
}
}
func (rs *RelayState) UpdateRelayForByIdxState(idx uint32, state int) {
rs.Lock()
defer rs.Unlock()
if r, ok := rs.relayForByIdx[idx]; ok {
newRelay := *r
newRelay.State = state
rs.relayForByAddr[newRelay.PeerAddr] = &newRelay
rs.relayForByIdx[newRelay.LocalIndex] = &newRelay
}
}
func (rs *RelayState) CopyAllRelayFor() []*Relay {
rs.RLock()
defer rs.RUnlock()
@ -89,10 +114,10 @@ func (rs *RelayState) CopyAllRelayFor() []*Relay {
return ret
}
func (rs *RelayState) GetRelayForByIp(ip netip.Addr) (*Relay, bool) {
func (rs *RelayState) GetRelayForByAddr(addr netip.Addr) (*Relay, bool) {
rs.RLock()
defer rs.RUnlock()
r, ok := rs.relayForByIp[ip]
r, ok := rs.relayForByAddr[addr]
return r, ok
}
@ -115,8 +140,8 @@ func (rs *RelayState) CopyRelayIps() []netip.Addr {
func (rs *RelayState) CopyRelayForIps() []netip.Addr {
rs.RLock()
defer rs.RUnlock()
currentRelays := make([]netip.Addr, 0, len(rs.relayForByIp))
for relayIp := range rs.relayForByIp {
currentRelays := make([]netip.Addr, 0, len(rs.relayForByAddr))
for relayIp := range rs.relayForByAddr {
currentRelays = append(currentRelays, relayIp)
}
return currentRelays
@ -135,7 +160,7 @@ func (rs *RelayState) CopyRelayForIdxs() []uint32 {
func (rs *RelayState) CompleteRelayByIP(vpnIp netip.Addr, remoteIdx uint32) bool {
rs.Lock()
defer rs.Unlock()
r, ok := rs.relayForByIp[vpnIp]
r, ok := rs.relayForByAddr[vpnIp]
if !ok {
return false
}
@ -143,7 +168,7 @@ func (rs *RelayState) CompleteRelayByIP(vpnIp netip.Addr, remoteIdx uint32) bool
newRelay.State = Established
newRelay.RemoteIndex = remoteIdx
rs.relayForByIdx[r.LocalIndex] = &newRelay
rs.relayForByIp[r.PeerIp] = &newRelay
rs.relayForByAddr[r.PeerAddr] = &newRelay
return true
}
@ -158,14 +183,14 @@ func (rs *RelayState) CompleteRelayByIdx(localIdx uint32, remoteIdx uint32) (*Re
newRelay.State = Established
newRelay.RemoteIndex = remoteIdx
rs.relayForByIdx[r.LocalIndex] = &newRelay
rs.relayForByIp[r.PeerIp] = &newRelay
rs.relayForByAddr[r.PeerAddr] = &newRelay
return &newRelay, true
}
func (rs *RelayState) QueryRelayForByIp(vpnIp netip.Addr) (*Relay, bool) {
rs.RLock()
defer rs.RUnlock()
r, ok := rs.relayForByIp[vpnIp]
r, ok := rs.relayForByAddr[vpnIp]
return r, ok
}
@ -179,7 +204,7 @@ func (rs *RelayState) QueryRelayForByIdx(idx uint32) (*Relay, bool) {
func (rs *RelayState) InsertRelay(ip netip.Addr, idx uint32, r *Relay) {
rs.Lock()
defer rs.Unlock()
rs.relayForByIp[ip] = r
rs.relayForByAddr[ip] = r
rs.relayForByIdx[idx] = r
}
@ -190,9 +215,15 @@ type HostInfo struct {
ConnectionState *ConnectionState
remoteIndexId uint32
localIndexId uint32
vpnIp netip.Addr
// vpnAddrs is a list of vpn addresses assigned to this host that are within our own vpn networks
// The host may have other vpn addresses that are outside our
// vpn networks but were removed because they are not usable
vpnAddrs []netip.Addr
recvError atomic.Uint32
remoteCidr *bart.Table[struct{}]
// networks are both all vpn and unsafe networks assigned to this host
networks *bart.Table[struct{}]
relayState RelayState
// HandshakePacket records the packets used to create this hostinfo
@ -241,28 +272,26 @@ type cachedPacketMetrics struct {
dropped metrics.Counter
}
func NewHostMapFromConfig(l *logrus.Logger, vpnCIDR netip.Prefix, c *config.C) *HostMap {
hm := newHostMap(l, vpnCIDR)
func NewHostMapFromConfig(l *logrus.Logger, c *config.C) *HostMap {
hm := newHostMap(l)
hm.reload(c, true)
c.RegisterReloadCallback(func(c *config.C) {
hm.reload(c, false)
})
l.WithField("network", hm.vpnCIDR.String()).
WithField("preferredRanges", hm.GetPreferredRanges()).
l.WithField("preferredRanges", hm.GetPreferredRanges()).
Info("Main HostMap created")
return hm
}
func newHostMap(l *logrus.Logger, vpnCIDR netip.Prefix) *HostMap {
func newHostMap(l *logrus.Logger) *HostMap {
return &HostMap{
Indexes: map[uint32]*HostInfo{},
Relays: map[uint32]*HostInfo{},
RemoteIndexes: map[uint32]*HostInfo{},
Hosts: map[netip.Addr]*HostInfo{},
vpnCIDR: vpnCIDR,
l: l,
}
}
@ -305,17 +334,6 @@ func (hm *HostMap) EmitStats() {
metrics.GetOrRegisterGauge("hostmap.main.relayIndexes", nil).Update(int64(relaysLen))
}
func (hm *HostMap) RemoveRelay(localIdx uint32) {
hm.Lock()
_, ok := hm.Relays[localIdx]
if !ok {
hm.Unlock()
return
}
delete(hm.Relays, localIdx)
hm.Unlock()
}
// DeleteHostInfo will fully unlink the hostinfo and return true if it was the final hostinfo for this vpn ip
func (hm *HostMap) DeleteHostInfo(hostinfo *HostInfo) bool {
// Delete the host itself, ensuring it's not modified anymore
@ -335,48 +353,73 @@ func (hm *HostMap) MakePrimary(hostinfo *HostInfo) {
}
func (hm *HostMap) unlockedMakePrimary(hostinfo *HostInfo) {
oldHostinfo := hm.Hosts[hostinfo.vpnIp]
// Get the current primary, if it exists
oldHostinfo := hm.Hosts[hostinfo.vpnAddrs[0]]
// Every address in the hostinfo gets elevated to primary
for _, vpnAddr := range hostinfo.vpnAddrs {
//NOTE: It is possible that we leave a dangling hostinfo here but connection manager works on
// indexes so it should be fine.
hm.Hosts[vpnAddr] = hostinfo
}
// If we are already primary then we won't bother re-linking
if oldHostinfo == hostinfo {
return
}
// Unlink this hostinfo
if hostinfo.prev != nil {
hostinfo.prev.next = hostinfo.next
}
if hostinfo.next != nil {
hostinfo.next.prev = hostinfo.prev
}
hm.Hosts[hostinfo.vpnIp] = hostinfo
// If there wasn't a previous primary then clear out any links
if oldHostinfo == nil {
hostinfo.next = nil
hostinfo.prev = nil
return
}
// Relink the hostinfo as primary
hostinfo.next = oldHostinfo
oldHostinfo.prev = hostinfo
hostinfo.prev = nil
}
func (hm *HostMap) unlockedDeleteHostInfo(hostinfo *HostInfo) {
primary, ok := hm.Hosts[hostinfo.vpnIp]
for _, addr := range hostinfo.vpnAddrs {
h := hm.Hosts[addr]
for h != nil {
if h == hostinfo {
hm.unlockedInnerDeleteHostInfo(h, addr)
}
h = h.next
}
}
}
func (hm *HostMap) unlockedInnerDeleteHostInfo(hostinfo *HostInfo, addr netip.Addr) {
primary, ok := hm.Hosts[addr]
isLastHostinfo := hostinfo.next == nil && hostinfo.prev == nil
if ok && primary == hostinfo {
// The vpnIp pointer points to the same hostinfo as the local index id, we can remove it
delete(hm.Hosts, hostinfo.vpnIp)
// The vpn addr pointer points to the same hostinfo as the local index id, we can remove it
delete(hm.Hosts, addr)
if len(hm.Hosts) == 0 {
hm.Hosts = map[netip.Addr]*HostInfo{}
}
if hostinfo.next != nil {
// We had more than 1 hostinfo at this vpnip, promote the next in the list to primary
hm.Hosts[hostinfo.vpnIp] = hostinfo.next
// We had more than 1 hostinfo at this vpn addr, promote the next in the list to primary
hm.Hosts[addr] = hostinfo.next
// It is primary, there is no previous hostinfo now
hostinfo.next.prev = nil
}
} else {
// Relink if we were in the middle of multiple hostinfos for this vpn ip
// Relink if we were in the middle of multiple hostinfos for this vpn addr
if hostinfo.prev != nil {
hostinfo.prev.next = hostinfo.next
}
@ -406,10 +449,16 @@ func (hm *HostMap) unlockedDeleteHostInfo(hostinfo *HostInfo) {
if hm.l.Level >= logrus.DebugLevel {
hm.l.WithField("hostMap", m{"mapTotalSize": len(hm.Hosts),
"vpnIp": hostinfo.vpnIp, "indexNumber": hostinfo.localIndexId, "remoteIndexNumber": hostinfo.remoteIndexId}).
"vpnAddrs": hostinfo.vpnAddrs, "indexNumber": hostinfo.localIndexId, "remoteIndexNumber": hostinfo.remoteIndexId}).
Debug("Hostmap hostInfo deleted")
}
if isLastHostinfo {
// I have lost connectivity to my peers. My relay tunnel is likely broken. Mark the next
// hops as 'Requested' so that new relay tunnels are created in the future.
hm.unlockedDisestablishVpnAddrRelayFor(hostinfo)
}
// Clean up any local relay indexes for which I am acting as a relay hop
for _, localRelayIdx := range hostinfo.relayState.CopyRelayForIdxs() {
delete(hm.Relays, localRelayIdx)
}
@ -448,11 +497,11 @@ func (hm *HostMap) QueryReverseIndex(index uint32) *HostInfo {
}
}
func (hm *HostMap) QueryVpnIp(vpnIp netip.Addr) *HostInfo {
return hm.queryVpnIp(vpnIp, nil)
func (hm *HostMap) QueryVpnAddr(vpnIp netip.Addr) *HostInfo {
return hm.queryVpnAddr(vpnIp, nil)
}
func (hm *HostMap) QueryVpnIpRelayFor(targetIp, relayHostIp netip.Addr) (*HostInfo, *Relay, error) {
func (hm *HostMap) QueryVpnAddrsRelayFor(targetIps []netip.Addr, relayHostIp netip.Addr) (*HostInfo, *Relay, error) {
hm.RLock()
defer hm.RUnlock()
@ -460,17 +509,42 @@ func (hm *HostMap) QueryVpnIpRelayFor(targetIp, relayHostIp netip.Addr) (*HostIn
if !ok {
return nil, nil, errors.New("unable to find host")
}
for h != nil {
for _, targetIp := range targetIps {
r, ok := h.relayState.QueryRelayForByIp(targetIp)
if ok && r.State == Established {
return h, r, nil
}
}
h = h.next
}
return nil, nil, errors.New("unable to find host with relay")
}
func (hm *HostMap) queryVpnIp(vpnIp netip.Addr, promoteIfce *Interface) *HostInfo {
func (hm *HostMap) unlockedDisestablishVpnAddrRelayFor(hi *HostInfo) {
for _, relayHostIp := range hi.relayState.CopyRelayIps() {
if h, ok := hm.Hosts[relayHostIp]; ok {
for h != nil {
h.relayState.UpdateRelayForByIpState(hi.vpnAddrs[0], Disestablished)
h = h.next
}
}
}
for _, rs := range hi.relayState.CopyAllRelayFor() {
if rs.Type == ForwardingType {
if h, ok := hm.Hosts[rs.PeerAddr]; ok {
for h != nil {
h.relayState.UpdateRelayForByIpState(hi.vpnAddrs[0], Disestablished)
h = h.next
}
}
}
}
}
func (hm *HostMap) queryVpnAddr(vpnIp netip.Addr, promoteIfce *Interface) *HostInfo {
hm.RLock()
if h, ok := hm.Hosts[vpnIp]; ok {
hm.RUnlock()
@ -491,25 +565,30 @@ func (hm *HostMap) queryVpnIp(vpnIp netip.Addr, promoteIfce *Interface) *HostInf
func (hm *HostMap) unlockedAddHostInfo(hostinfo *HostInfo, f *Interface) {
if f.serveDns {
remoteCert := hostinfo.ConnectionState.peerCert
dnsR.Add(remoteCert.Details.Name+".", remoteCert.Details.Ips[0].IP.String())
dnsR.Add(remoteCert.Certificate.Name()+".", hostinfo.vpnAddrs)
}
existing := hm.Hosts[hostinfo.vpnIp]
hm.Hosts[hostinfo.vpnIp] = hostinfo
if existing != nil {
hostinfo.next = existing
existing.prev = hostinfo
for _, addr := range hostinfo.vpnAddrs {
hm.unlockedInnerAddHostInfo(addr, hostinfo, f)
}
hm.Indexes[hostinfo.localIndexId] = hostinfo
hm.RemoteIndexes[hostinfo.remoteIndexId] = hostinfo
if hm.l.Level >= logrus.DebugLevel {
hm.l.WithField("hostMap", m{"vpnIp": hostinfo.vpnIp, "mapTotalSize": len(hm.Hosts),
"hostinfo": m{"existing": true, "localIndexId": hostinfo.localIndexId, "hostId": hostinfo.vpnIp}}).
hm.l.WithField("hostMap", m{"vpnAddrs": hostinfo.vpnAddrs, "mapTotalSize": len(hm.Hosts),
"hostinfo": m{"existing": true, "localIndexId": hostinfo.localIndexId, "vpnAddrs": hostinfo.vpnAddrs}}).
Debug("Hostmap vpnIp added")
}
}
func (hm *HostMap) unlockedInnerAddHostInfo(vpnAddr netip.Addr, hostinfo *HostInfo, f *Interface) {
existing := hm.Hosts[vpnAddr]
hm.Hosts[vpnAddr] = hostinfo
if existing != nil && existing != hostinfo {
hostinfo.next = existing
existing.prev = hostinfo
}
i := 1
check := hostinfo
@ -527,7 +606,7 @@ func (hm *HostMap) GetPreferredRanges() []netip.Prefix {
return *hm.preferredRanges.Load()
}
func (hm *HostMap) ForEachVpnIp(f controlEach) {
func (hm *HostMap) ForEachVpnAddr(f controlEach) {
hm.RLock()
defer hm.RUnlock()
@ -581,11 +660,11 @@ func (i *HostInfo) TryPromoteBest(preferredRanges []netip.Prefix, ifce *Interfac
}
i.nextLHQuery.Store(now + ifce.reQueryWait.Load())
ifce.lightHouse.QueryServer(i.vpnIp)
ifce.lightHouse.QueryServer(i.vpnAddrs[0])
}
}
func (i *HostInfo) GetCert() *cert.NebulaCertificate {
func (i *HostInfo) GetCert() *cert.CachedCertificate {
if i.ConnectionState != nil {
return i.ConnectionState.peerCert
}
@ -596,7 +675,7 @@ func (i *HostInfo) SetRemote(remote netip.AddrPort) {
// We copy here because we likely got this remote from a source that reuses the object
if i.remote != remote {
i.remote = remote
i.remotes.LearnRemote(i.vpnIp, remote)
i.remotes.LearnRemote(i.vpnAddrs[0], remote)
}
}
@ -647,29 +726,20 @@ func (i *HostInfo) RecvErrorExceeded() bool {
return true
}
func (i *HostInfo) CreateRemoteCIDR(c *cert.NebulaCertificate) {
if len(c.Details.Ips) == 1 && len(c.Details.Subnets) == 0 {
func (i *HostInfo) buildNetworks(networks, unsafeNetworks []netip.Prefix) {
if len(networks) == 1 && len(unsafeNetworks) == 0 {
// Simple case, no CIDRTree needed
return
}
remoteCidr := new(bart.Table[struct{}])
for _, ip := range c.Details.Ips {
//TODO: IPV6-WORK what to do when ip is invalid?
nip, _ := netip.AddrFromSlice(ip.IP)
nip = nip.Unmap()
bits, _ := ip.Mask.Size()
remoteCidr.Insert(netip.PrefixFrom(nip, bits), struct{}{})
i.networks = new(bart.Table[struct{}])
for _, network := range networks {
i.networks.Insert(network, struct{}{})
}
for _, n := range c.Details.Subnets {
//TODO: IPV6-WORK what to do when ip is invalid?
nip, _ := netip.AddrFromSlice(n.IP)
nip = nip.Unmap()
bits, _ := n.Mask.Size()
remoteCidr.Insert(netip.PrefixFrom(nip, bits), struct{}{})
for _, network := range unsafeNetworks {
i.networks.Insert(network, struct{}{})
}
i.remoteCidr = remoteCidr
}
func (i *HostInfo) logger(l *logrus.Logger) *logrus.Entry {
@ -677,13 +747,13 @@ func (i *HostInfo) logger(l *logrus.Logger) *logrus.Entry {
return logrus.NewEntry(l)
}
li := l.WithField("vpnIp", i.vpnIp).
li := l.WithField("vpnAddrs", i.vpnAddrs).
WithField("localIndex", i.localIndexId).
WithField("remoteIndex", i.remoteIndexId)
if connState := i.ConnectionState; connState != nil {
if peerCert := connState.peerCert; peerCert != nil {
li = li.WithField("certName", peerCert.Details.Name)
li = li.WithField("certName", peerCert.Certificate.Name())
}
}
@ -692,9 +762,9 @@ func (i *HostInfo) logger(l *logrus.Logger) *logrus.Entry {
// Utility functions
func localIps(l *logrus.Logger, allowList *LocalAllowList) []netip.Addr {
func localAddrs(l *logrus.Logger, allowList *LocalAllowList) []netip.Addr {
//FIXME: This function is pretty garbage
var ips []netip.Addr
var finalAddrs []netip.Addr
ifaces, _ := net.Interfaces()
for _, i := range ifaces {
allow := allowList.AllowName(i.Name)
@ -706,39 +776,36 @@ func localIps(l *logrus.Logger, allowList *LocalAllowList) []netip.Addr {
continue
}
addrs, _ := i.Addrs()
for _, addr := range addrs {
var ip net.IP
switch v := addr.(type) {
for _, rawAddr := range addrs {
var addr netip.Addr
switch v := rawAddr.(type) {
case *net.IPNet:
//continue
ip = v.IP
addr, _ = netip.AddrFromSlice(v.IP)
case *net.IPAddr:
ip = v.IP
addr, _ = netip.AddrFromSlice(v.IP)
}
nip, ok := netip.AddrFromSlice(ip)
if !ok {
if !addr.IsValid() {
if l.Level >= logrus.DebugLevel {
l.WithField("localIp", ip).Debug("ip was invalid for netip")
l.WithField("localAddr", rawAddr).Debug("addr was invalid")
}
continue
}
nip = nip.Unmap()
addr = addr.Unmap()
//TODO: Filtering out link local for now, this is probably the most correct thing
//TODO: Would be nice to filter out SLAAC MAC based ips as well
if nip.IsLoopback() == false && nip.IsLinkLocalUnicast() == false {
allow := allowList.Allow(nip)
if addr.IsLoopback() == false && addr.IsLinkLocalUnicast() == false {
isAllowed := allowList.Allow(addr)
if l.Level >= logrus.TraceLevel {
l.WithField("localIp", nip).WithField("allow", allow).Trace("localAllowList.Allow")
l.WithField("localAddr", addr).WithField("allowed", isAllowed).Trace("localAllowList.Allow")
}
if !allow {
if !isAllowed {
continue
}
ips = append(ips, nip)
finalAddrs = append(finalAddrs, addr)
}
}
}
return ips
return finalAddrs
}

View File

@ -11,17 +11,14 @@ import (
func TestHostMap_MakePrimary(t *testing.T) {
l := test.NewLogger()
hm := newHostMap(
l,
netip.MustParsePrefix("10.0.0.1/24"),
)
hm := newHostMap(l)
f := &Interface{}
h1 := &HostInfo{vpnIp: netip.MustParseAddr("0.0.0.1"), localIndexId: 1}
h2 := &HostInfo{vpnIp: netip.MustParseAddr("0.0.0.1"), localIndexId: 2}
h3 := &HostInfo{vpnIp: netip.MustParseAddr("0.0.0.1"), localIndexId: 3}
h4 := &HostInfo{vpnIp: netip.MustParseAddr("0.0.0.1"), localIndexId: 4}
h1 := &HostInfo{vpnAddrs: []netip.Addr{netip.MustParseAddr("0.0.0.1")}, localIndexId: 1}
h2 := &HostInfo{vpnAddrs: []netip.Addr{netip.MustParseAddr("0.0.0.1")}, localIndexId: 2}
h3 := &HostInfo{vpnAddrs: []netip.Addr{netip.MustParseAddr("0.0.0.1")}, localIndexId: 3}
h4 := &HostInfo{vpnAddrs: []netip.Addr{netip.MustParseAddr("0.0.0.1")}, localIndexId: 4}
hm.unlockedAddHostInfo(h4, f)
hm.unlockedAddHostInfo(h3, f)
@ -29,7 +26,7 @@ func TestHostMap_MakePrimary(t *testing.T) {
hm.unlockedAddHostInfo(h1, f)
// Make sure we go h1 -> h2 -> h3 -> h4
prim := hm.QueryVpnIp(netip.MustParseAddr("0.0.0.1"))
prim := hm.QueryVpnAddr(netip.MustParseAddr("0.0.0.1"))
assert.Equal(t, h1.localIndexId, prim.localIndexId)
assert.Equal(t, h2.localIndexId, prim.next.localIndexId)
assert.Nil(t, prim.prev)
@ -44,7 +41,7 @@ func TestHostMap_MakePrimary(t *testing.T) {
hm.MakePrimary(h3)
// Make sure we go h3 -> h1 -> h2 -> h4
prim = hm.QueryVpnIp(netip.MustParseAddr("0.0.0.1"))
prim = hm.QueryVpnAddr(netip.MustParseAddr("0.0.0.1"))
assert.Equal(t, h3.localIndexId, prim.localIndexId)
assert.Equal(t, h1.localIndexId, prim.next.localIndexId)
assert.Nil(t, prim.prev)
@ -59,7 +56,7 @@ func TestHostMap_MakePrimary(t *testing.T) {
hm.MakePrimary(h4)
// Make sure we go h4 -> h3 -> h1 -> h2
prim = hm.QueryVpnIp(netip.MustParseAddr("0.0.0.1"))
prim = hm.QueryVpnAddr(netip.MustParseAddr("0.0.0.1"))
assert.Equal(t, h4.localIndexId, prim.localIndexId)
assert.Equal(t, h3.localIndexId, prim.next.localIndexId)
assert.Nil(t, prim.prev)
@ -74,7 +71,7 @@ func TestHostMap_MakePrimary(t *testing.T) {
hm.MakePrimary(h4)
// Make sure we go h4 -> h3 -> h1 -> h2
prim = hm.QueryVpnIp(netip.MustParseAddr("0.0.0.1"))
prim = hm.QueryVpnAddr(netip.MustParseAddr("0.0.0.1"))
assert.Equal(t, h4.localIndexId, prim.localIndexId)
assert.Equal(t, h3.localIndexId, prim.next.localIndexId)
assert.Nil(t, prim.prev)
@ -88,19 +85,16 @@ func TestHostMap_MakePrimary(t *testing.T) {
func TestHostMap_DeleteHostInfo(t *testing.T) {
l := test.NewLogger()
hm := newHostMap(
l,
netip.MustParsePrefix("10.0.0.1/24"),
)
hm := newHostMap(l)
f := &Interface{}
h1 := &HostInfo{vpnIp: netip.MustParseAddr("0.0.0.1"), localIndexId: 1}
h2 := &HostInfo{vpnIp: netip.MustParseAddr("0.0.0.1"), localIndexId: 2}
h3 := &HostInfo{vpnIp: netip.MustParseAddr("0.0.0.1"), localIndexId: 3}
h4 := &HostInfo{vpnIp: netip.MustParseAddr("0.0.0.1"), localIndexId: 4}
h5 := &HostInfo{vpnIp: netip.MustParseAddr("0.0.0.1"), localIndexId: 5}
h6 := &HostInfo{vpnIp: netip.MustParseAddr("0.0.0.1"), localIndexId: 6}
h1 := &HostInfo{vpnAddrs: []netip.Addr{netip.MustParseAddr("0.0.0.1")}, localIndexId: 1}
h2 := &HostInfo{vpnAddrs: []netip.Addr{netip.MustParseAddr("0.0.0.1")}, localIndexId: 2}
h3 := &HostInfo{vpnAddrs: []netip.Addr{netip.MustParseAddr("0.0.0.1")}, localIndexId: 3}
h4 := &HostInfo{vpnAddrs: []netip.Addr{netip.MustParseAddr("0.0.0.1")}, localIndexId: 4}
h5 := &HostInfo{vpnAddrs: []netip.Addr{netip.MustParseAddr("0.0.0.1")}, localIndexId: 5}
h6 := &HostInfo{vpnAddrs: []netip.Addr{netip.MustParseAddr("0.0.0.1")}, localIndexId: 6}
hm.unlockedAddHostInfo(h6, f)
hm.unlockedAddHostInfo(h5, f)
@ -116,7 +110,7 @@ func TestHostMap_DeleteHostInfo(t *testing.T) {
assert.Nil(t, h)
// Make sure we go h1 -> h2 -> h3 -> h4 -> h5
prim := hm.QueryVpnIp(netip.MustParseAddr("0.0.0.1"))
prim := hm.QueryVpnAddr(netip.MustParseAddr("0.0.0.1"))
assert.Equal(t, h1.localIndexId, prim.localIndexId)
assert.Equal(t, h2.localIndexId, prim.next.localIndexId)
assert.Nil(t, prim.prev)
@ -135,7 +129,7 @@ func TestHostMap_DeleteHostInfo(t *testing.T) {
assert.Nil(t, h1.next)
// Make sure we go h2 -> h3 -> h4 -> h5
prim = hm.QueryVpnIp(netip.MustParseAddr("0.0.0.1"))
prim = hm.QueryVpnAddr(netip.MustParseAddr("0.0.0.1"))
assert.Equal(t, h2.localIndexId, prim.localIndexId)
assert.Equal(t, h3.localIndexId, prim.next.localIndexId)
assert.Nil(t, prim.prev)
@ -153,7 +147,7 @@ func TestHostMap_DeleteHostInfo(t *testing.T) {
assert.Nil(t, h3.next)
// Make sure we go h2 -> h4 -> h5
prim = hm.QueryVpnIp(netip.MustParseAddr("0.0.0.1"))
prim = hm.QueryVpnAddr(netip.MustParseAddr("0.0.0.1"))
assert.Equal(t, h2.localIndexId, prim.localIndexId)
assert.Equal(t, h4.localIndexId, prim.next.localIndexId)
assert.Nil(t, prim.prev)
@ -169,7 +163,7 @@ func TestHostMap_DeleteHostInfo(t *testing.T) {
assert.Nil(t, h5.next)
// Make sure we go h2 -> h4
prim = hm.QueryVpnIp(netip.MustParseAddr("0.0.0.1"))
prim = hm.QueryVpnAddr(netip.MustParseAddr("0.0.0.1"))
assert.Equal(t, h2.localIndexId, prim.localIndexId)
assert.Equal(t, h4.localIndexId, prim.next.localIndexId)
assert.Nil(t, prim.prev)
@ -183,7 +177,7 @@ func TestHostMap_DeleteHostInfo(t *testing.T) {
assert.Nil(t, h2.next)
// Make sure we only have h4
prim = hm.QueryVpnIp(netip.MustParseAddr("0.0.0.1"))
prim = hm.QueryVpnAddr(netip.MustParseAddr("0.0.0.1"))
assert.Equal(t, h4.localIndexId, prim.localIndexId)
assert.Nil(t, prim.prev)
assert.Nil(t, prim.next)
@ -195,7 +189,7 @@ func TestHostMap_DeleteHostInfo(t *testing.T) {
assert.Nil(t, h4.next)
// Make sure we have nil
prim = hm.QueryVpnIp(netip.MustParseAddr("0.0.0.1"))
prim = hm.QueryVpnAddr(netip.MustParseAddr("0.0.0.1"))
assert.Nil(t, prim)
}
@ -203,11 +197,7 @@ func TestHostMap_reload(t *testing.T) {
l := test.NewLogger()
c := config.NewC(l)
hm := NewHostMapFromConfig(
l,
netip.MustParsePrefix("10.0.0.1/24"),
c,
)
hm := NewHostMapFromConfig(l, c)
toS := func(ipn []netip.Prefix) []string {
var s []string

View File

@ -9,8 +9,8 @@ import (
"net/netip"
)
func (i *HostInfo) GetVpnIp() netip.Addr {
return i.vpnIp
func (i *HostInfo) GetVpnAddrs() []netip.Addr {
return i.vpnAddrs
}
func (i *HostInfo) GetLocalIndex() uint32 {

View File

@ -20,14 +20,18 @@ func (f *Interface) consumeInsidePacket(packet []byte, fwPacket *firewall.Packet
}
// Ignore local broadcast packets
if f.dropLocalBroadcast && fwPacket.RemoteIP == f.myBroadcastAddr {
if f.dropLocalBroadcast {
_, found := f.myBroadcastAddrsTable.Lookup(fwPacket.RemoteAddr)
if found {
return
}
}
if fwPacket.RemoteIP == f.myVpnNet.Addr() {
_, found := f.myVpnAddrsTable.Lookup(fwPacket.RemoteAddr)
if found {
// Immediately forward packets from self to self.
// This should only happen on Darwin-based and FreeBSD hosts, which
// routes packets from the Nebula IP to the Nebula IP through the Nebula
// routes packets from the Nebula addr to the Nebula addr through the Nebula
// TUN device.
if immediatelyForwardToSelf {
_, err := f.readers[q].Write(packet)
@ -36,25 +40,25 @@ func (f *Interface) consumeInsidePacket(packet []byte, fwPacket *firewall.Packet
}
}
// Otherwise, drop. On linux, we should never see these packets - Linux
// routes packets from the nebula IP to the nebula IP through the loopback device.
// routes packets from the nebula addr to the nebula addr through the loopback device.
return
}
// Ignore multicast packets
if f.dropMulticast && fwPacket.RemoteIP.IsMulticast() {
if f.dropMulticast && fwPacket.RemoteAddr.IsMulticast() {
return
}
hostinfo, ready := f.getOrHandshake(fwPacket.RemoteIP, func(hh *HandshakeHostInfo) {
hostinfo, ready := f.getOrHandshake(fwPacket.RemoteAddr, func(hh *HandshakeHostInfo) {
hh.cachePacket(f.l, header.Message, 0, packet, f.sendMessageNow, f.cachedPacketMetrics)
})
if hostinfo == nil {
f.rejectInside(packet, out, q)
if f.l.Level >= logrus.DebugLevel {
f.l.WithField("vpnIp", fwPacket.RemoteIP).
f.l.WithField("vpnAddr", fwPacket.RemoteAddr).
WithField("fwPacket", fwPacket).
Debugln("dropping outbound packet, vpnIp not in our CIDR or in unsafe routes")
Debugln("dropping outbound packet, vpnAddr not in our vpn networks or in unsafe networks")
}
return
}
@ -117,21 +121,22 @@ func (f *Interface) rejectOutside(packet []byte, ci *ConnectionState, hostinfo *
f.sendNoMetrics(header.Message, 0, ci, hostinfo, netip.AddrPort{}, out, nb, packet, q)
}
func (f *Interface) Handshake(vpnIp netip.Addr) {
f.getOrHandshake(vpnIp, nil)
func (f *Interface) Handshake(vpnAddr netip.Addr) {
f.getOrHandshake(vpnAddr, nil)
}
// getOrHandshake returns nil if the vpnIp is not routable.
// getOrHandshake returns nil if the vpnAddr is not routable.
// If the 2nd return var is false then the hostinfo is not ready to be used in a tunnel
func (f *Interface) getOrHandshake(vpnIp netip.Addr, cacheCallback func(*HandshakeHostInfo)) (*HostInfo, bool) {
if !f.myVpnNet.Contains(vpnIp) {
vpnIp = f.inside.RouteFor(vpnIp)
if !vpnIp.IsValid() {
func (f *Interface) getOrHandshake(vpnAddr netip.Addr, cacheCallback func(*HandshakeHostInfo)) (*HostInfo, bool) {
_, found := f.myVpnNetworksTable.Lookup(vpnAddr)
if !found {
vpnAddr = f.inside.RouteFor(vpnAddr)
if !vpnAddr.IsValid() {
return nil, false
}
}
return f.handshakeManager.GetOrHandshake(vpnIp, cacheCallback)
return f.handshakeManager.GetOrHandshake(vpnAddr, cacheCallback)
}
func (f *Interface) sendMessageNow(t header.MessageType, st header.MessageSubType, hostinfo *HostInfo, p, nb, out []byte) {
@ -156,16 +161,16 @@ func (f *Interface) sendMessageNow(t header.MessageType, st header.MessageSubTyp
f.sendNoMetrics(header.Message, st, hostinfo.ConnectionState, hostinfo, netip.AddrPort{}, p, nb, out, 0)
}
// SendMessageToVpnIp handles real ip:port lookup and sends to the current best known address for vpnIp
func (f *Interface) SendMessageToVpnIp(t header.MessageType, st header.MessageSubType, vpnIp netip.Addr, p, nb, out []byte) {
hostInfo, ready := f.getOrHandshake(vpnIp, func(hh *HandshakeHostInfo) {
// SendMessageToVpnAddr handles real addr:port lookup and sends to the current best known address for vpnAddr
func (f *Interface) SendMessageToVpnAddr(t header.MessageType, st header.MessageSubType, vpnAddr netip.Addr, p, nb, out []byte) {
hostInfo, ready := f.getOrHandshake(vpnAddr, func(hh *HandshakeHostInfo) {
hh.cachePacket(f.l, t, st, p, f.SendMessageToHostInfo, f.cachedPacketMetrics)
})
if hostInfo == nil {
if f.l.Level >= logrus.DebugLevel {
f.l.WithField("vpnIp", vpnIp).
Debugln("dropping SendMessageToVpnIp, vpnIp not in our CIDR or in unsafe routes")
f.l.WithField("vpnAddr", vpnAddr).
Debugln("dropping SendMessageToVpnAddr, vpnAddr not in our vpn networks or in unsafe routes")
}
return
}
@ -258,7 +263,6 @@ func (f *Interface) SendVia(via *HostInfo,
func (f *Interface) sendNoMetrics(t header.MessageType, st header.MessageSubType, ci *ConnectionState, hostinfo *HostInfo, remote netip.AddrPort, p, nb, out []byte, q int) {
if ci.eKey == nil {
//TODO: log warning
return
}
useRelay := !remote.IsValid() && !hostinfo.remote.IsValid()
@ -285,14 +289,14 @@ func (f *Interface) sendNoMetrics(t header.MessageType, st header.MessageSubType
f.connectionManager.Out(hostinfo.localIndexId)
// Query our LH if we haven't since the last time we've been rebound, this will cause the remote to punch against
// all our IPs and enable a faster roaming.
// all our addrs and enable a faster roaming.
if t != header.CloseTunnel && hostinfo.lastRebindCount != f.rebindCount {
//NOTE: there is an update hole if a tunnel isn't used and exactly 256 rebinds occur before the tunnel is
// finally used again. This tunnel would eventually be torn down and recreated if this action didn't help.
f.lightHouse.QueryServer(hostinfo.vpnIp)
f.lightHouse.QueryServer(hostinfo.vpnAddrs[0])
hostinfo.lastRebindCount = f.rebindCount
if f.l.Level >= logrus.DebugLevel {
f.l.WithField("vpnIp", hostinfo.vpnIp).Debug("Lighthouse update triggered for punch due to rebind counter")
f.l.WithField("vpnAddrs", hostinfo.vpnAddrs).Debug("Lighthouse update triggered for punch due to rebind counter")
}
}
@ -324,7 +328,7 @@ func (f *Interface) sendNoMetrics(t header.MessageType, st header.MessageSubType
} else {
// Try to send via a relay
for _, relayIP := range hostinfo.relayState.CopyRelayIps() {
relayHostInfo, relay, err := f.hostMap.QueryVpnIpRelayFor(hostinfo.vpnIp, relayIP)
relayHostInfo, relay, err := f.hostMap.QueryVpnAddrsRelayFor(hostinfo.vpnAddrs, relayIP)
if err != nil {
hostinfo.relayState.DeleteRelay(relayIP)
hostinfo.logger(f.l).WithField("relay", relayIP).WithError(err).Info("sendNoMetrics failed to find HostInfo")

View File

@ -2,7 +2,6 @@ package nebula
import (
"context"
"encoding/binary"
"errors"
"fmt"
"io"
@ -12,6 +11,7 @@ import (
"sync/atomic"
"time"
"github.com/gaissmai/bart"
"github.com/rcrowley/go-metrics"
"github.com/sirupsen/logrus"
"github.com/slackhq/nebula/config"
@ -28,7 +28,6 @@ type InterfaceConfig struct {
Outside udp.Conn
Inside overlay.Device
pki *PKI
Cipher string
Firewall *Firewall
ServeDns bool
HandshakeManager *HandshakeManager
@ -56,15 +55,17 @@ type Interface struct {
outside udp.Conn
inside overlay.Device
pki *PKI
cipher string
firewall *Firewall
connectionManager *connectionManager
handshakeManager *HandshakeManager
serveDns bool
createTime time.Time
lightHouse *LightHouse
myBroadcastAddr netip.Addr
myVpnNet netip.Prefix
myBroadcastAddrsTable *bart.Table[struct{}]
myVpnAddrs []netip.Addr // A list of addresses assigned to us via our certificate
myVpnAddrsTable *bart.Table[struct{}] // A table of addresses assigned to us via our certificate
myVpnNetworks []netip.Prefix // A list of networks assigned to us via our certificate
myVpnNetworksTable *bart.Table[struct{}] // A table of networks assigned to us via our certificate
dropLocalBroadcast bool
dropMulticast bool
routines int
@ -102,9 +103,11 @@ type EncWriter interface {
out []byte,
nocopy bool,
)
SendMessageToVpnIp(t header.MessageType, st header.MessageSubType, vpnIp netip.Addr, p, nb, out []byte)
SendMessageToVpnAddr(t header.MessageType, st header.MessageSubType, vpnAddr netip.Addr, p, nb, out []byte)
SendMessageToHostInfo(t header.MessageType, st header.MessageSubType, hostinfo *HostInfo, p, nb, out []byte)
Handshake(vpnIp netip.Addr)
Handshake(vpnAddr netip.Addr)
GetHostInfo(vpnAddr netip.Addr) *HostInfo
GetCertState() *CertState
}
type sendRecvErrorConfig uint8
@ -115,10 +118,10 @@ const (
sendRecvErrorPrivate
)
func (s sendRecvErrorConfig) ShouldSendRecvError(ip netip.AddrPort) bool {
func (s sendRecvErrorConfig) ShouldSendRecvError(endpoint netip.AddrPort) bool {
switch s {
case sendRecvErrorPrivate:
return ip.Addr().IsPrivate()
return endpoint.Addr().IsPrivate()
case sendRecvErrorAlways:
return true
case sendRecvErrorNever:
@ -155,34 +158,12 @@ func NewInterface(ctx context.Context, c *InterfaceConfig) (*Interface, error) {
return nil, errors.New("no firewall rules")
}
certificate := c.pki.GetCertState().Certificate
myVpnAddr, ok := netip.AddrFromSlice(certificate.Details.Ips[0].IP)
if !ok {
return nil, fmt.Errorf("invalid ip address in certificate: %s", certificate.Details.Ips[0].IP)
}
myVpnMask, ok := netip.AddrFromSlice(certificate.Details.Ips[0].Mask)
if !ok {
return nil, fmt.Errorf("invalid ip mask in certificate: %s", certificate.Details.Ips[0].Mask)
}
myVpnAddr = myVpnAddr.Unmap()
myVpnMask = myVpnMask.Unmap()
if myVpnAddr.BitLen() != myVpnMask.BitLen() {
return nil, fmt.Errorf("ip address and mask are different lengths in certificate")
}
ones, _ := certificate.Details.Ips[0].Mask.Size()
myVpnNet := netip.PrefixFrom(myVpnAddr, ones)
cs := c.pki.getCertState()
ifce := &Interface{
pki: c.pki,
hostMap: c.HostMap,
outside: c.Outside,
inside: c.Inside,
cipher: c.Cipher,
firewall: c.Firewall,
serveDns: c.ServeDns,
handshakeManager: c.HandshakeManager,
@ -194,7 +175,11 @@ func NewInterface(ctx context.Context, c *InterfaceConfig) (*Interface, error) {
version: c.version,
writers: make([]udp.Conn, c.routines),
readers: make([]io.ReadWriteCloser, c.routines),
myVpnNet: myVpnNet,
myVpnNetworks: cs.myVpnNetworks,
myVpnNetworksTable: cs.myVpnNetworksTable,
myVpnAddrs: cs.myVpnAddrs,
myVpnAddrsTable: cs.myVpnAddrsTable,
myBroadcastAddrsTable: cs.myVpnBroadcastAddrsTable,
relayManager: c.relayManager,
conntrackCacheTimeout: c.ConntrackCacheTimeout,
@ -209,12 +194,6 @@ func NewInterface(ctx context.Context, c *InterfaceConfig) (*Interface, error) {
l: c.l,
}
if myVpnAddr.Is4() {
addr := myVpnNet.Masked().Addr().As4()
binary.BigEndian.PutUint32(addr[:], binary.BigEndian.Uint32(addr[:])|^binary.BigEndian.Uint32(certificate.Details.Ips[0].Mask))
ifce.myBroadcastAddr = netip.AddrFrom4(addr)
}
ifce.tryPromoteEvery.Store(c.tryPromoteEvery)
ifce.reQueryEvery.Store(c.reQueryEvery)
ifce.reQueryWait.Store(int64(c.reQueryWait))
@ -235,7 +214,7 @@ func (f *Interface) activate() {
f.l.WithError(err).Error("Failed to get udp listen address")
}
f.l.WithField("interface", f.inside.Name()).WithField("network", f.inside.Cidr().String()).
f.l.WithField("interface", f.inside.Name()).WithField("networks", f.myVpnNetworks).
WithField("build", f.version).WithField("udpAddr", addr).
WithField("boringcrypto", boringEnabled()).
Info("Nebula interface is active")
@ -276,16 +255,22 @@ func (f *Interface) listenOut(i int) {
runtime.LockOSThread()
var li udp.Conn
// TODO clean this up with a coherent interface for each outside connection
if i > 0 {
li = f.writers[i]
} else {
li = f.outside
}
ctCache := firewall.NewConntrackCacheTicker(f.conntrackCacheTimeout)
lhh := f.lightHouse.NewRequestHandler()
conntrackCache := firewall.NewConntrackCacheTicker(f.conntrackCacheTimeout)
li.ListenOut(readOutsidePackets(f), lhHandleRequest(lhh, f), conntrackCache, i)
plaintext := make([]byte, udp.MTU)
h := &header.H{}
fwPacket := &firewall.Packet{}
nb := make([]byte, 12, 12)
li.ListenOut(func(fromUdpAddr netip.AddrPort, payload []byte) {
f.readOutsidePackets(fromUdpAddr, nil, plaintext[:0], payload, h, fwPacket, lhh, nb, i, ctCache.Get(f.l))
})
}
func (f *Interface) listenIn(reader io.ReadWriteCloser, i int) {
@ -342,7 +327,7 @@ func (f *Interface) reloadFirewall(c *config.C) {
return
}
fw, err := NewFirewallFromConfig(f.l, f.pki.GetCertState().Certificate, c)
fw, err := NewFirewallFromConfig(f.l, f.pki.getCertState(), c)
if err != nil {
f.l.WithError(err).Error("Error while creating firewall during reload")
return
@ -425,6 +410,8 @@ func (f *Interface) emitStats(ctx context.Context, i time.Duration) {
udpStats := udp.NewUDPStatsEmitter(f.writers)
certExpirationGauge := metrics.GetOrRegisterGauge("certificate.ttl_seconds", nil)
certDefaultVersion := metrics.GetOrRegisterGauge("certificate.default_version", nil)
certMaxVersion := metrics.GetOrRegisterGauge("certificate.max_version", nil)
for {
select {
@ -434,9 +421,28 @@ func (f *Interface) emitStats(ctx context.Context, i time.Duration) {
f.firewall.EmitStats()
f.handshakeManager.EmitStats()
udpStats()
certExpirationGauge.Update(int64(f.pki.GetCertState().Certificate.Details.NotAfter.Sub(time.Now()) / time.Second))
certState := f.pki.getCertState()
defaultCrt := certState.GetDefaultCertificate()
certExpirationGauge.Update(int64(defaultCrt.NotAfter().Sub(time.Now()) / time.Second))
certDefaultVersion.Update(int64(defaultCrt.Version()))
// Report the max certificate version we are capable of using
if certState.v2Cert != nil {
certMaxVersion.Update(int64(certState.v2Cert.Version()))
} else {
certMaxVersion.Update(int64(certState.v1Cert.Version()))
}
}
}
}
func (f *Interface) GetHostInfo(vpnIp netip.Addr) *HostInfo {
return f.hostMap.QueryVpnAddr(vpnIp)
}
func (f *Interface) GetCertState() *CertState {
return f.pki.getCertState()
}
func (f *Interface) Close() error {

View File

@ -6,8 +6,6 @@ import (
"golang.org/x/net/ipv4"
)
//TODO: IPV6-WORK can probably delete this
const (
// Need 96 bytes for the largest reject packet:
// - 20 byte ipv4 header

File diff suppressed because it is too large Load Diff

View File

@ -7,69 +7,61 @@ import (
"net/netip"
"testing"
"github.com/gaissmai/bart"
"github.com/slackhq/nebula/cert"
"github.com/slackhq/nebula/config"
"github.com/slackhq/nebula/header"
"github.com/slackhq/nebula/test"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"gopkg.in/yaml.v2"
)
//TODO: Add a test to ensure udpAddr is copied and not reused
func TestOldIPv4Only(t *testing.T) {
// This test ensures our new ipv6 enabled LH protobuf IpAndPorts works with the old style to enable backwards compatibility
b := []byte{8, 129, 130, 132, 80, 16, 10}
var m Ip4AndPort
var m V4AddrPort
err := m.Unmarshal(b)
assert.NoError(t, err)
require.NoError(t, err)
ip := netip.MustParseAddr("10.1.1.1")
bp := ip.As4()
assert.Equal(t, binary.BigEndian.Uint32(bp[:]), m.GetIp())
}
func TestNewLhQuery(t *testing.T) {
myIp, err := netip.ParseAddr("192.1.1.1")
assert.NoError(t, err)
// Generating a new lh query should work
a := NewLhQueryByInt(myIp)
// The result should be a nebulameta protobuf
assert.IsType(t, &NebulaMeta{}, a)
// It should also Marshal fine
b, err := a.Marshal()
assert.Nil(t, err)
// and then Unmarshal fine
n := &NebulaMeta{}
err = n.Unmarshal(b)
assert.Nil(t, err)
assert.Equal(t, binary.BigEndian.Uint32(bp[:]), m.GetAddr())
}
func Test_lhStaticMapping(t *testing.T) {
l := test.NewLogger()
myVpnNet := netip.MustParsePrefix("10.128.0.1/16")
nt := new(bart.Table[struct{}])
nt.Insert(myVpnNet, struct{}{})
cs := &CertState{
myVpnNetworks: []netip.Prefix{myVpnNet},
myVpnNetworksTable: nt,
}
lh1 := "10.128.0.2"
c := config.NewC(l)
c.Settings["lighthouse"] = map[interface{}]interface{}{"hosts": []interface{}{lh1}}
c.Settings["static_host_map"] = map[interface{}]interface{}{lh1: []interface{}{"1.1.1.1:4242"}}
_, err := NewLightHouseFromConfig(context.Background(), l, c, myVpnNet, nil, nil)
assert.Nil(t, err)
_, err := NewLightHouseFromConfig(context.Background(), l, c, cs, nil, nil)
require.NoError(t, err)
lh2 := "10.128.0.3"
c = config.NewC(l)
c.Settings["lighthouse"] = map[interface{}]interface{}{"hosts": []interface{}{lh1, lh2}}
c.Settings["static_host_map"] = map[interface{}]interface{}{lh1: []interface{}{"100.1.1.1:4242"}}
_, err = NewLightHouseFromConfig(context.Background(), l, c, myVpnNet, nil, nil)
assert.EqualError(t, err, "lighthouse 10.128.0.3 does not have a static_host_map entry")
_, err = NewLightHouseFromConfig(context.Background(), l, c, cs, nil, nil)
require.EqualError(t, err, "lighthouse 10.128.0.3 does not have a static_host_map entry")
}
func TestReloadLighthouseInterval(t *testing.T) {
l := test.NewLogger()
myVpnNet := netip.MustParsePrefix("10.128.0.1/16")
nt := new(bart.Table[struct{}])
nt.Insert(myVpnNet, struct{}{})
cs := &CertState{
myVpnNetworks: []netip.Prefix{myVpnNet},
myVpnNetworksTable: nt,
}
lh1 := "10.128.0.2"
c := config.NewC(l)
@ -79,77 +71,82 @@ func TestReloadLighthouseInterval(t *testing.T) {
}
c.Settings["static_host_map"] = map[interface{}]interface{}{lh1: []interface{}{"1.1.1.1:4242"}}
lh, err := NewLightHouseFromConfig(context.Background(), l, c, myVpnNet, nil, nil)
assert.NoError(t, err)
lh, err := NewLightHouseFromConfig(context.Background(), l, c, cs, nil, nil)
require.NoError(t, err)
lh.ifce = &mockEncWriter{}
// The first one routine is kicked off by main.go currently, lets make sure that one dies
assert.NoError(t, c.ReloadConfigString("lighthouse:\n interval: 5"))
require.NoError(t, c.ReloadConfigString("lighthouse:\n interval: 5"))
assert.Equal(t, int64(5), lh.interval.Load())
// Subsequent calls are killed off by the LightHouse.Reload function
assert.NoError(t, c.ReloadConfigString("lighthouse:\n interval: 10"))
require.NoError(t, c.ReloadConfigString("lighthouse:\n interval: 10"))
assert.Equal(t, int64(10), lh.interval.Load())
// If this completes then nothing is stealing our reload routine
assert.NoError(t, c.ReloadConfigString("lighthouse:\n interval: 11"))
require.NoError(t, c.ReloadConfigString("lighthouse:\n interval: 11"))
assert.Equal(t, int64(11), lh.interval.Load())
}
func BenchmarkLighthouseHandleRequest(b *testing.B) {
l := test.NewLogger()
myVpnNet := netip.MustParsePrefix("10.128.0.1/0")
nt := new(bart.Table[struct{}])
nt.Insert(myVpnNet, struct{}{})
cs := &CertState{
myVpnNetworks: []netip.Prefix{myVpnNet},
myVpnNetworksTable: nt,
}
c := config.NewC(l)
lh, err := NewLightHouseFromConfig(context.Background(), l, c, myVpnNet, nil, nil)
if !assert.NoError(b, err) {
b.Fatal()
}
lh, err := NewLightHouseFromConfig(context.Background(), l, c, cs, nil, nil)
require.NoError(b, err)
hAddr := netip.MustParseAddrPort("4.5.6.7:12345")
hAddr2 := netip.MustParseAddrPort("4.5.6.7:12346")
vpnIp3 := netip.MustParseAddr("0.0.0.3")
lh.addrMap[vpnIp3] = NewRemoteList(nil)
lh.addrMap[vpnIp3] = NewRemoteList([]netip.Addr{vpnIp3}, nil)
lh.addrMap[vpnIp3].unlockedSetV4(
vpnIp3,
vpnIp3,
[]*Ip4AndPort{
NewIp4AndPortFromNetIP(hAddr.Addr(), hAddr.Port()),
NewIp4AndPortFromNetIP(hAddr2.Addr(), hAddr2.Port()),
[]*V4AddrPort{
netAddrToProtoV4AddrPort(hAddr.Addr(), hAddr.Port()),
netAddrToProtoV4AddrPort(hAddr2.Addr(), hAddr2.Port()),
},
func(netip.Addr, *Ip4AndPort) bool { return true },
func(netip.Addr, *V4AddrPort) bool { return true },
)
rAddr := netip.MustParseAddrPort("1.2.2.3:12345")
rAddr2 := netip.MustParseAddrPort("1.2.2.3:12346")
vpnIp2 := netip.MustParseAddr("0.0.0.3")
lh.addrMap[vpnIp2] = NewRemoteList(nil)
lh.addrMap[vpnIp2] = NewRemoteList([]netip.Addr{vpnIp2}, nil)
lh.addrMap[vpnIp2].unlockedSetV4(
vpnIp3,
vpnIp3,
[]*Ip4AndPort{
NewIp4AndPortFromNetIP(rAddr.Addr(), rAddr.Port()),
NewIp4AndPortFromNetIP(rAddr2.Addr(), rAddr2.Port()),
[]*V4AddrPort{
netAddrToProtoV4AddrPort(rAddr.Addr(), rAddr.Port()),
netAddrToProtoV4AddrPort(rAddr2.Addr(), rAddr2.Port()),
},
func(netip.Addr, *Ip4AndPort) bool { return true },
func(netip.Addr, *V4AddrPort) bool { return true },
)
mw := &mockEncWriter{}
hi := []netip.Addr{vpnIp2}
b.Run("notfound", func(b *testing.B) {
lhh := lh.NewRequestHandler()
req := &NebulaMeta{
Type: NebulaMeta_HostQuery,
Details: &NebulaMetaDetails{
VpnIp: 4,
Ip4AndPorts: nil,
OldVpnAddr: 4,
V4AddrPorts: nil,
},
}
p, err := req.Marshal()
assert.NoError(b, err)
require.NoError(b, err)
for n := 0; n < b.N; n++ {
lhh.HandleRequest(rAddr, vpnIp2, p, mw)
lhh.HandleRequest(rAddr, hi, p, mw)
}
})
b.Run("found", func(b *testing.B) {
@ -157,15 +154,15 @@ func BenchmarkLighthouseHandleRequest(b *testing.B) {
req := &NebulaMeta{
Type: NebulaMeta_HostQuery,
Details: &NebulaMetaDetails{
VpnIp: 3,
Ip4AndPorts: nil,
OldVpnAddr: 3,
V4AddrPorts: nil,
},
}
p, err := req.Marshal()
assert.NoError(b, err)
require.NoError(b, err)
for n := 0; n < b.N; n++ {
lhh.HandleRequest(rAddr, vpnIp2, p, mw)
lhh.HandleRequest(rAddr, hi, p, mw)
}
})
}
@ -197,40 +194,49 @@ func TestLighthouse_Memory(t *testing.T) {
c := config.NewC(l)
c.Settings["lighthouse"] = map[interface{}]interface{}{"am_lighthouse": true}
c.Settings["listen"] = map[interface{}]interface{}{"port": 4242}
lh, err := NewLightHouseFromConfig(context.Background(), l, c, netip.MustParsePrefix("10.128.0.1/24"), nil, nil)
assert.NoError(t, err)
myVpnNet := netip.MustParsePrefix("10.128.0.1/24")
nt := new(bart.Table[struct{}])
nt.Insert(myVpnNet, struct{}{})
cs := &CertState{
myVpnNetworks: []netip.Prefix{myVpnNet},
myVpnNetworksTable: nt,
}
lh, err := NewLightHouseFromConfig(context.Background(), l, c, cs, nil, nil)
lh.ifce = &mockEncWriter{}
require.NoError(t, err)
lhh := lh.NewRequestHandler()
// Test that my first update responds with just that
newLHHostUpdate(myUdpAddr0, myVpnIp, []netip.AddrPort{myUdpAddr1, myUdpAddr2}, lhh)
r := newLHHostRequest(myUdpAddr0, myVpnIp, myVpnIp, lhh)
assertIp4InArray(t, r.msg.Details.Ip4AndPorts, myUdpAddr1, myUdpAddr2)
assertIp4InArray(t, r.msg.Details.V4AddrPorts, myUdpAddr1, myUdpAddr2)
// Ensure we don't accumulate addresses
newLHHostUpdate(myUdpAddr0, myVpnIp, []netip.AddrPort{myUdpAddr3}, lhh)
r = newLHHostRequest(myUdpAddr0, myVpnIp, myVpnIp, lhh)
assertIp4InArray(t, r.msg.Details.Ip4AndPorts, myUdpAddr3)
assertIp4InArray(t, r.msg.Details.V4AddrPorts, myUdpAddr3)
// Grow it back to 2
newLHHostUpdate(myUdpAddr0, myVpnIp, []netip.AddrPort{myUdpAddr1, myUdpAddr4}, lhh)
r = newLHHostRequest(myUdpAddr0, myVpnIp, myVpnIp, lhh)
assertIp4InArray(t, r.msg.Details.Ip4AndPorts, myUdpAddr1, myUdpAddr4)
assertIp4InArray(t, r.msg.Details.V4AddrPorts, myUdpAddr1, myUdpAddr4)
// Update a different host and ask about it
newLHHostUpdate(theirUdpAddr0, theirVpnIp, []netip.AddrPort{theirUdpAddr1, theirUdpAddr2, theirUdpAddr3, theirUdpAddr4}, lhh)
r = newLHHostRequest(theirUdpAddr0, theirVpnIp, theirVpnIp, lhh)
assertIp4InArray(t, r.msg.Details.Ip4AndPorts, theirUdpAddr1, theirUdpAddr2, theirUdpAddr3, theirUdpAddr4)
assertIp4InArray(t, r.msg.Details.V4AddrPorts, theirUdpAddr1, theirUdpAddr2, theirUdpAddr3, theirUdpAddr4)
// Have both hosts ask about the other
r = newLHHostRequest(theirUdpAddr0, theirVpnIp, myVpnIp, lhh)
assertIp4InArray(t, r.msg.Details.Ip4AndPorts, myUdpAddr1, myUdpAddr4)
assertIp4InArray(t, r.msg.Details.V4AddrPorts, myUdpAddr1, myUdpAddr4)
r = newLHHostRequest(myUdpAddr0, myVpnIp, theirVpnIp, lhh)
assertIp4InArray(t, r.msg.Details.Ip4AndPorts, theirUdpAddr1, theirUdpAddr2, theirUdpAddr3, theirUdpAddr4)
assertIp4InArray(t, r.msg.Details.V4AddrPorts, theirUdpAddr1, theirUdpAddr2, theirUdpAddr3, theirUdpAddr4)
// Make sure we didn't get changed
r = newLHHostRequest(myUdpAddr0, myVpnIp, myVpnIp, lhh)
assertIp4InArray(t, r.msg.Details.Ip4AndPorts, myUdpAddr1, myUdpAddr4)
assertIp4InArray(t, r.msg.Details.V4AddrPorts, myUdpAddr1, myUdpAddr4)
// Ensure proper ordering and limiting
// Send 12 addrs, get 10 back, the last 2 removed, allowing the duplicate to remain (clients dedupe)
@ -255,7 +261,7 @@ func TestLighthouse_Memory(t *testing.T) {
r = newLHHostRequest(myUdpAddr0, myVpnIp, myVpnIp, lhh)
assertIp4InArray(
t,
r.msg.Details.Ip4AndPorts,
r.msg.Details.V4AddrPorts,
myUdpAddr1, myUdpAddr2, myUdpAddr3, myUdpAddr4, myUdpAddr5, myUdpAddr5, myUdpAddr6, myUdpAddr7, myUdpAddr8, myUdpAddr9,
)
@ -265,7 +271,7 @@ func TestLighthouse_Memory(t *testing.T) {
good := netip.MustParseAddrPort("1.128.0.99:4242")
newLHHostUpdate(myUdpAddr0, myVpnIp, []netip.AddrPort{bad1, bad2, good}, lhh)
r = newLHHostRequest(myUdpAddr0, myVpnIp, myVpnIp, lhh)
assertIp4InArray(t, r.msg.Details.Ip4AndPorts, good)
assertIp4InArray(t, r.msg.Details.V4AddrPorts, good)
}
func TestLighthouse_reload(t *testing.T) {
@ -273,8 +279,17 @@ func TestLighthouse_reload(t *testing.T) {
c := config.NewC(l)
c.Settings["lighthouse"] = map[interface{}]interface{}{"am_lighthouse": true}
c.Settings["listen"] = map[interface{}]interface{}{"port": 4242}
lh, err := NewLightHouseFromConfig(context.Background(), l, c, netip.MustParsePrefix("10.128.0.1/24"), nil, nil)
assert.NoError(t, err)
myVpnNet := netip.MustParsePrefix("10.128.0.1/24")
nt := new(bart.Table[struct{}])
nt.Insert(myVpnNet, struct{}{})
cs := &CertState{
myVpnNetworks: []netip.Prefix{myVpnNet},
myVpnNetworksTable: nt,
}
lh, err := NewLightHouseFromConfig(context.Background(), l, c, cs, nil, nil)
require.NoError(t, err)
nc := map[interface{}]interface{}{
"static_host_map": map[interface{}]interface{}{
@ -282,21 +297,24 @@ func TestLighthouse_reload(t *testing.T) {
},
}
rc, err := yaml.Marshal(nc)
assert.NoError(t, err)
require.NoError(t, err)
c.ReloadConfigString(string(rc))
err = lh.reload(c, false)
assert.NoError(t, err)
require.NoError(t, err)
}
func newLHHostRequest(fromAddr netip.AddrPort, myVpnIp, queryVpnIp netip.Addr, lhh *LightHouseHandler) testLhReply {
//TODO: IPV6-WORK
bip := queryVpnIp.As4()
req := &NebulaMeta{
Type: NebulaMeta_HostQuery,
Details: &NebulaMetaDetails{
VpnIp: binary.BigEndian.Uint32(bip[:]),
},
Details: &NebulaMetaDetails{},
}
if queryVpnIp.Is4() {
bip := queryVpnIp.As4()
req.Details.OldVpnAddr = binary.BigEndian.Uint32(bip[:])
} else {
req.Details.VpnAddr = netAddrToProtoAddr(queryVpnIp)
}
b, err := req.Marshal()
@ -308,23 +326,29 @@ func newLHHostRequest(fromAddr netip.AddrPort, myVpnIp, queryVpnIp netip.Addr, l
w := &testEncWriter{
metaFilter: &filter,
}
lhh.HandleRequest(fromAddr, myVpnIp, b, w)
lhh.HandleRequest(fromAddr, []netip.Addr{myVpnIp}, b, w)
return w.lastReply
}
func newLHHostUpdate(fromAddr netip.AddrPort, vpnIp netip.Addr, addrs []netip.AddrPort, lhh *LightHouseHandler) {
//TODO: IPV6-WORK
bip := vpnIp.As4()
req := &NebulaMeta{
Type: NebulaMeta_HostUpdateNotification,
Details: &NebulaMetaDetails{
VpnIp: binary.BigEndian.Uint32(bip[:]),
Ip4AndPorts: make([]*Ip4AndPort, len(addrs)),
},
Details: &NebulaMetaDetails{},
}
for k, v := range addrs {
req.Details.Ip4AndPorts[k] = NewIp4AndPortFromNetIP(v.Addr(), v.Port())
if vpnIp.Is4() {
bip := vpnIp.As4()
req.Details.OldVpnAddr = binary.BigEndian.Uint32(bip[:])
} else {
req.Details.VpnAddr = netAddrToProtoAddr(vpnIp)
}
for _, v := range addrs {
if v.Addr().Is4() {
req.Details.V4AddrPorts = append(req.Details.V4AddrPorts, netAddrToProtoV4AddrPort(v.Addr(), v.Port()))
} else {
req.Details.V6AddrPorts = append(req.Details.V6AddrPorts, netAddrToProtoV6AddrPort(v.Addr(), v.Port()))
}
}
b, err := req.Marshal()
@ -333,75 +357,9 @@ func newLHHostUpdate(fromAddr netip.AddrPort, vpnIp netip.Addr, addrs []netip.Ad
}
w := &testEncWriter{}
lhh.HandleRequest(fromAddr, vpnIp, b, w)
lhh.HandleRequest(fromAddr, []netip.Addr{vpnIp}, b, w)
}
//TODO: this is a RemoteList test
//func Test_lhRemoteAllowList(t *testing.T) {
// l := NewLogger()
// c := NewConfig(l)
// c.Settings["remoteallowlist"] = map[interface{}]interface{}{
// "10.20.0.0/12": false,
// }
// allowList, err := c.GetAllowList("remoteallowlist", false)
// assert.Nil(t, err)
//
// lh1 := "10.128.0.2"
// lh1IP := net.ParseIP(lh1)
//
// udpServer, _ := NewListener(l, "0.0.0.0", 0, true)
//
// lh := NewLightHouse(l, true, &net.IPNet{IP: net.IP{0, 0, 0, 1}, Mask: net.IPMask{255, 255, 255, 0}}, []uint32{ip2int(lh1IP)}, 10, 10003, udpServer, false, 1, false)
// lh.SetRemoteAllowList(allowList)
//
// // A disallowed ip should not enter the cache but we should end up with an empty entry in the addrMap
// remote1IP := net.ParseIP("10.20.0.3")
// remotes := lh.unlockedGetRemoteList(ip2int(remote1IP))
// remotes.unlockedPrependV4(ip2int(remote1IP), NewIp4AndPort(remote1IP, 4242))
// assert.NotNil(t, lh.addrMap[ip2int(remote1IP)])
// assert.Empty(t, lh.addrMap[ip2int(remote1IP)].CopyAddrs([]*net.IPNet{}))
//
// // Make sure a good ip enters the cache and addrMap
// remote2IP := net.ParseIP("10.128.0.3")
// remote2UDPAddr := NewUDPAddr(remote2IP, uint16(4242))
// lh.addRemoteV4(ip2int(remote2IP), ip2int(remote2IP), NewIp4AndPort(remote2UDPAddr.IP, uint32(remote2UDPAddr.Port)), false, false)
// assertUdpAddrInArray(t, lh.addrMap[ip2int(remote2IP)].CopyAddrs([]*net.IPNet{}), remote2UDPAddr)
//
// // Another good ip gets into the cache, ordering is inverted
// remote3IP := net.ParseIP("10.128.0.4")
// remote3UDPAddr := NewUDPAddr(remote3IP, uint16(4243))
// lh.addRemoteV4(ip2int(remote2IP), ip2int(remote2IP), NewIp4AndPort(remote3UDPAddr.IP, uint32(remote3UDPAddr.Port)), false, false)
// assertUdpAddrInArray(t, lh.addrMap[ip2int(remote2IP)].CopyAddrs([]*net.IPNet{}), remote2UDPAddr, remote3UDPAddr)
//
// // If we exceed the length limit we should only have the most recent addresses
// addedAddrs := []*udpAddr{}
// for i := 0; i < 11; i++ {
// remoteUDPAddr := NewUDPAddr(net.IP{10, 128, 0, 4}, uint16(4243+i))
// lh.addRemoteV4(ip2int(remote2IP), ip2int(remote2IP), NewIp4AndPort(remoteUDPAddr.IP, uint32(remoteUDPAddr.Port)), false, false)
// // The first entry here is a duplicate, don't add it to the assert list
// if i != 0 {
// addedAddrs = append(addedAddrs, remoteUDPAddr)
// }
// }
//
// // We should only have the last 10 of what we tried to add
// assert.True(t, len(addedAddrs) >= 10, "We should have tried to add at least 10 addresses")
// assertUdpAddrInArray(
// t,
// lh.addrMap[ip2int(remote2IP)].CopyAddrs([]*net.IPNet{}),
// addedAddrs[0],
// addedAddrs[1],
// addedAddrs[2],
// addedAddrs[3],
// addedAddrs[4],
// addedAddrs[5],
// addedAddrs[6],
// addedAddrs[7],
// addedAddrs[8],
// addedAddrs[9],
// )
//}
type testLhReply struct {
nebType header.MessageType
nebSubType header.MessageSubType
@ -412,6 +370,7 @@ type testLhReply struct {
type testEncWriter struct {
lastReply testLhReply
metaFilter *NebulaMeta_MessageType
protocolVersion cert.Version
}
func (tw *testEncWriter) SendVia(via *HostInfo, relay *Relay, ad, nb, out []byte, nocopy bool) {
@ -426,7 +385,7 @@ func (tw *testEncWriter) SendMessageToHostInfo(t header.MessageType, st header.M
tw.lastReply = testLhReply{
nebType: t,
nebSubType: st,
vpnIp: hostinfo.vpnIp,
vpnIp: hostinfo.vpnAddrs[0],
msg: msg,
}
}
@ -436,7 +395,7 @@ func (tw *testEncWriter) SendMessageToHostInfo(t header.MessageType, st header.M
}
}
func (tw *testEncWriter) SendMessageToVpnIp(t header.MessageType, st header.MessageSubType, vpnIp netip.Addr, p, _, _ []byte) {
func (tw *testEncWriter) SendMessageToVpnAddr(t header.MessageType, st header.MessageSubType, vpnIp netip.Addr, p, _, _ []byte) {
msg := &NebulaMeta{}
err := msg.Unmarshal(p)
if tw.metaFilter == nil || msg.Type == *tw.metaFilter {
@ -453,17 +412,84 @@ func (tw *testEncWriter) SendMessageToVpnIp(t header.MessageType, st header.Mess
}
}
func (tw *testEncWriter) GetHostInfo(vpnIp netip.Addr) *HostInfo {
return nil
}
func (tw *testEncWriter) GetCertState() *CertState {
return &CertState{defaultVersion: tw.protocolVersion}
}
// assertIp4InArray asserts every address in want is at the same position in have and that the lengths match
func assertIp4InArray(t *testing.T, have []*Ip4AndPort, want ...netip.AddrPort) {
func assertIp4InArray(t *testing.T, have []*V4AddrPort, want ...netip.AddrPort) {
if !assert.Len(t, have, len(want)) {
return
}
for k, w := range want {
//TODO: IPV6-WORK
h := AddrPortFromIp4AndPort(have[k])
h := protoV4AddrPortToNetAddrPort(have[k])
if !(h == w) {
assert.Fail(t, fmt.Sprintf("Response did not contain: %v at %v, found %v", w, k, h))
}
}
}
func Test_findNetworkUnion(t *testing.T) {
var out netip.Addr
var ok bool
tenDot := netip.MustParsePrefix("10.0.0.0/8")
oneSevenTwo := netip.MustParsePrefix("172.16.0.0/16")
fe80 := netip.MustParsePrefix("fe80::/8")
fc00 := netip.MustParsePrefix("fc00::/7")
a1 := netip.MustParseAddr("10.0.0.1")
afe81 := netip.MustParseAddr("fe80::1")
//simple
out, ok = findNetworkUnion([]netip.Prefix{tenDot}, []netip.Addr{a1})
assert.True(t, ok)
assert.Equal(t, out, a1)
//mixed lengths
out, ok = findNetworkUnion([]netip.Prefix{tenDot}, []netip.Addr{a1, afe81})
assert.True(t, ok)
assert.Equal(t, out, a1)
out, ok = findNetworkUnion([]netip.Prefix{tenDot, oneSevenTwo}, []netip.Addr{a1})
assert.True(t, ok)
assert.Equal(t, out, a1)
//mixed family
out, ok = findNetworkUnion([]netip.Prefix{tenDot, oneSevenTwo, fe80}, []netip.Addr{a1})
assert.True(t, ok)
assert.Equal(t, out, a1)
out, ok = findNetworkUnion([]netip.Prefix{tenDot, oneSevenTwo, fe80}, []netip.Addr{a1, afe81})
assert.True(t, ok)
assert.Equal(t, out, a1)
//ordering
out, ok = findNetworkUnion([]netip.Prefix{tenDot, oneSevenTwo, fe80}, []netip.Addr{afe81, a1})
assert.True(t, ok)
assert.Equal(t, out, a1)
out, ok = findNetworkUnion([]netip.Prefix{fe80, tenDot, oneSevenTwo}, []netip.Addr{afe81, a1})
assert.True(t, ok)
assert.Equal(t, out, afe81)
//some mismatches
out, ok = findNetworkUnion([]netip.Prefix{tenDot, oneSevenTwo, fe80}, []netip.Addr{afe81})
assert.True(t, ok)
assert.Equal(t, out, afe81)
out, ok = findNetworkUnion([]netip.Prefix{oneSevenTwo, fe80}, []netip.Addr{a1, afe81})
assert.True(t, ok)
assert.Equal(t, out, afe81)
//falsey cases
out, ok = findNetworkUnion([]netip.Prefix{oneSevenTwo, fe80}, []netip.Addr{a1})
assert.False(t, ok)
out, ok = findNetworkUnion([]netip.Prefix{fc00, fe80}, []netip.Addr{a1})
assert.False(t, ok)
out, ok = findNetworkUnion([]netip.Prefix{oneSevenTwo, fc00}, []netip.Addr{a1, afe81})
assert.False(t, ok)
out, ok = findNetworkUnion([]netip.Prefix{fc00}, []netip.Addr{a1, afe81})
assert.False(t, ok)
}

39
main.go
View File

@ -2,7 +2,6 @@ package nebula
import (
"context"
"encoding/binary"
"fmt"
"net"
"net/netip"
@ -61,25 +60,12 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
return nil, util.ContextualizeIfNeeded("Failed to load PKI from config", err)
}
certificate := pki.GetCertState().Certificate
fw, err := NewFirewallFromConfig(l, certificate, c)
fw, err := NewFirewallFromConfig(l, pki.getCertState(), c)
if err != nil {
return nil, util.ContextualizeIfNeeded("Error while loading firewall rules", err)
}
l.WithField("firewallHashes", fw.GetRuleHashes()).Info("Firewall started")
ones, _ := certificate.Details.Ips[0].Mask.Size()
addr, ok := netip.AddrFromSlice(certificate.Details.Ips[0].IP)
if !ok {
err = util.NewContextualError(
"Invalid ip address in certificate",
m{"vpnIp": certificate.Details.Ips[0].IP},
nil,
)
return nil, err
}
tunCidr := netip.PrefixFrom(addr, ones)
ssh, err := sshd.NewSSHServer(l.WithField("subsystem", "sshd"))
if err != nil {
return nil, util.ContextualizeIfNeeded("Error while creating SSH server", err)
@ -142,7 +128,7 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
deviceFactory = overlay.NewDeviceFromConfig
}
tun, err = deviceFactory(c, l, tunCidr, routines)
tun, err = deviceFactory(c, l, pki.getCertState().myVpnNetworks, routines)
if err != nil {
return nil, util.ContextualizeIfNeeded("Failed to get a tun/tap device", err)
}
@ -197,9 +183,9 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
}
}
hostMap := NewHostMapFromConfig(l, tunCidr, c)
hostMap := NewHostMapFromConfig(l, c)
punchy := NewPunchyFromConfig(l, c)
lightHouse, err := NewLightHouseFromConfig(ctx, l, c, tunCidr, udpConns[0], punchy)
lightHouse, err := NewLightHouseFromConfig(ctx, l, c, pki.getCertState(), udpConns[0], punchy)
if err != nil {
return nil, util.ContextualizeIfNeeded("Failed to initialize lighthouse handler", err)
}
@ -242,7 +228,6 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
Inside: tun,
Outside: udpConns[0],
pki: pki,
Cipher: c.GetString("cipher", "aes"),
Firewall: fw,
ServeDns: serveDns,
HandshakeManager: handshakeManager,
@ -264,15 +249,6 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
l: l,
}
switch ifConfig.Cipher {
case "aes":
noiseEndianness = binary.BigEndian
case "chachapoly":
noiseEndianness = binary.LittleEndian
default:
return nil, fmt.Errorf("unknown cipher: %v", ifConfig.Cipher)
}
var ifce *Interface
if !configTest {
ifce, err = NewInterface(ctx, ifConfig)
@ -280,8 +256,6 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
return nil, fmt.Errorf("failed to initialize interface: %s", err)
}
// TODO: Better way to attach these, probably want a new interface in InterfaceConfig
// I don't want to make this initial commit too far-reaching though
ifce.writers = udpConns
lightHouse.ifce = ifce
@ -293,8 +267,6 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
go handshakeManager.Run(ctx)
}
// TODO - stats third-party modules start uncancellable goroutines. Update those libs to accept
// a context so that they can exit when the context is Done.
statsStart, err := startStats(l, c, buildVersion, configTest)
if err != nil {
return nil, util.ContextualizeIfNeeded("Failed to start stats emitter", err)
@ -304,7 +276,6 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
return nil, nil
}
//TODO: check if we _should_ be emitting stats
go ifce.emitStats(ctx, c.GetDuration("stats.interval", time.Second*10))
attachCommands(l, c, ssh, ifce)
@ -313,7 +284,7 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
var dnsStart func()
if lightHouse.amLighthouse && serveDns {
l.Debugln("Starting dns server")
dnsStart = dnsMain(l, hostMap, c)
dnsStart = dnsMain(l, pki.getCertState(), hostMap, c)
}
return &Control{

View File

@ -7,8 +7,6 @@ import (
"github.com/slackhq/nebula/header"
)
//TODO: this can probably move into the header package
type MessageMetrics struct {
rx [][]metrics.Counter
tx [][]metrics.Counter

View File

@ -1,18 +0,0 @@
package nebula
/*
import (
proto "google.golang.org/protobuf/proto"
)
func HandleMetaProto(p []byte) {
m := &NebulaMeta{}
err := proto.Unmarshal(p, m)
if err != nil {
l.Debugf("problem unmarshaling meta message: %s", err)
}
//fmt.Println(m)
}
*/

File diff suppressed because it is too large Load Diff

View File

@ -23,19 +23,28 @@ message NebulaMeta {
}
message NebulaMetaDetails {
uint32 VpnIp = 1;
repeated Ip4AndPort Ip4AndPorts = 2;
repeated Ip6AndPort Ip6AndPorts = 4;
repeated uint32 RelayVpnIp = 5;
uint32 OldVpnAddr = 1 [deprecated = true];
Addr VpnAddr = 6;
repeated uint32 OldRelayVpnAddrs = 5 [deprecated = true];
repeated Addr RelayVpnAddrs = 7;
repeated V4AddrPort V4AddrPorts = 2;
repeated V6AddrPort V6AddrPorts = 4;
uint32 counter = 3;
}
message Ip4AndPort {
uint32 Ip = 1;
message Addr {
uint64 Hi = 1;
uint64 Lo = 2;
}
message V4AddrPort {
uint32 Addr = 1;
uint32 Port = 2;
}
message Ip6AndPort {
message V6AddrPort {
uint64 Hi = 1;
uint64 Lo = 2;
uint32 Port = 3;
@ -62,6 +71,7 @@ message NebulaHandshakeDetails {
uint32 ResponderIndex = 3;
uint64 Cookie = 4;
uint64 Time = 5;
uint32 CertVersion = 8;
// reserved for WIP multiport
reserved 6, 7;
}
@ -76,6 +86,10 @@ message NebulaControl {
uint32 InitiatorRelayIndex = 2;
uint32 ResponderRelayIndex = 3;
uint32 RelayToIp = 4;
uint32 RelayFromIp = 5;
uint32 OldRelayToAddr = 4 [deprecated = true];
uint32 OldRelayFromAddr = 5 [deprecated = true];
Addr RelayToAddr = 6;
Addr RelayFromAddr = 7;
}

View File

@ -3,46 +3,25 @@ package nebula
import (
"encoding/binary"
"errors"
"fmt"
"net/netip"
"time"
"github.com/flynn/noise"
"github.com/google/gopacket/layers"
"golang.org/x/net/ipv6"
"github.com/sirupsen/logrus"
"github.com/slackhq/nebula/cert"
"github.com/slackhq/nebula/firewall"
"github.com/slackhq/nebula/header"
"github.com/slackhq/nebula/udp"
"golang.org/x/net/ipv4"
"google.golang.org/protobuf/proto"
)
const (
minFwPacketLen = 4
)
// TODO: IPV6-WORK this can likely be removed now
func readOutsidePackets(f *Interface) udp.EncReader {
return func(
addr netip.AddrPort,
out []byte,
packet []byte,
header *header.H,
fwPacket *firewall.Packet,
lhh udp.LightHouseHandlerFunc,
nb []byte,
q int,
localCache firewall.ConntrackCache,
) {
f.readOutsidePackets(addr, nil, out, packet, header, fwPacket, lhh, nb, q, localCache)
}
}
func (f *Interface) readOutsidePackets(ip netip.AddrPort, via *ViaSender, out []byte, packet []byte, h *header.H, fwPacket *firewall.Packet, lhf udp.LightHouseHandlerFunc, nb []byte, q int, localCache firewall.ConntrackCache) {
func (f *Interface) readOutsidePackets(ip netip.AddrPort, via *ViaSender, out []byte, packet []byte, h *header.H, fwPacket *firewall.Packet, lhf *LightHouseHandler, nb []byte, q int, localCache firewall.ConntrackCache) {
err := h.Parse(packet)
if err != nil {
// TODO: best if we return this and let caller log
// TODO: Might be better to send the literal []byte("holepunch") packet and ignore that?
// Hole punch packets are 0 or 1 byte big, so lets ignore printing those errors
if len(packet) > 1 {
f.l.WithField("packet", packet).Infof("Error while parsing inbound packet from %s: %s", ip, err)
@ -52,7 +31,8 @@ func (f *Interface) readOutsidePackets(ip netip.AddrPort, via *ViaSender, out []
//l.Error("in packet ", header, packet[HeaderLen:])
if ip.IsValid() {
if f.myVpnNet.Contains(ip.Addr()) {
_, found := f.myVpnNetworksTable.Lookup(ip.Addr())
if found {
if f.l.Level >= logrus.DebugLevel {
f.l.WithField("udpAddr", ip).Debug("Refusing to process double encrypted packet")
}
@ -109,7 +89,7 @@ func (f *Interface) readOutsidePackets(ip netip.AddrPort, via *ViaSender, out []
if !ok {
// The only way this happens is if hostmap has an index to the correct HostInfo, but the HostInfo is missing
// its internal mapping. This should never happen.
hostinfo.logger(f.l).WithFields(logrus.Fields{"vpnIp": hostinfo.vpnIp, "remoteIndex": h.RemoteIndex}).Error("HostInfo missing remote relay index")
hostinfo.logger(f.l).WithFields(logrus.Fields{"vpnAddrs": hostinfo.vpnAddrs, "remoteIndex": h.RemoteIndex}).Error("HostInfo missing remote relay index")
return
}
@ -121,9 +101,9 @@ func (f *Interface) readOutsidePackets(ip netip.AddrPort, via *ViaSender, out []
return
case ForwardingType:
// Find the target HostInfo relay object
targetHI, targetRelay, err := f.hostMap.QueryVpnIpRelayFor(hostinfo.vpnIp, relay.PeerIp)
targetHI, targetRelay, err := f.hostMap.QueryVpnAddrsRelayFor(hostinfo.vpnAddrs, relay.PeerAddr)
if err != nil {
hostinfo.logger(f.l).WithField("relayTo", relay.PeerIp).WithError(err).Info("Failed to find target host info by ip")
hostinfo.logger(f.l).WithField("relayTo", relay.PeerAddr).WithError(err).WithField("hostinfo.vpnAddrs", hostinfo.vpnAddrs).Info("Failed to find target host info by ip")
return
}
@ -139,7 +119,7 @@ func (f *Interface) readOutsidePackets(ip netip.AddrPort, via *ViaSender, out []
hostinfo.logger(f.l).Error("Unexpected Relay Type of Terminal")
}
} else {
hostinfo.logger(f.l).WithFields(logrus.Fields{"relayTo": relay.PeerIp, "relayFrom": hostinfo.vpnIp, "targetRelayState": targetRelay.State}).Info("Unexpected target relay state")
hostinfo.logger(f.l).WithFields(logrus.Fields{"relayTo": relay.PeerAddr, "relayFrom": hostinfo.vpnAddrs[0], "targetRelayState": targetRelay.State}).Info("Unexpected target relay state")
return
}
}
@ -156,13 +136,10 @@ func (f *Interface) readOutsidePackets(ip netip.AddrPort, via *ViaSender, out []
hostinfo.logger(f.l).WithError(err).WithField("udpAddr", ip).
WithField("packet", packet).
Error("Failed to decrypt lighthouse packet")
//TODO: maybe after build 64 is out? 06/14/2018 - NB
//f.sendRecvError(net.Addr(addr), header.RemoteIndex)
return
}
lhf(ip, hostinfo.vpnIp, d)
lhf.HandleRequest(ip, hostinfo.vpnAddrs, d, f)
// Fallthrough to the bottom to record incoming traffic
@ -177,9 +154,6 @@ func (f *Interface) readOutsidePackets(ip netip.AddrPort, via *ViaSender, out []
hostinfo.logger(f.l).WithError(err).WithField("udpAddr", ip).
WithField("packet", packet).
Error("Failed to decrypt test packet")
//TODO: maybe after build 64 is out? 06/14/2018 - NB
//f.sendRecvError(net.Addr(addr), header.RemoteIndex)
return
}
@ -229,14 +203,8 @@ func (f *Interface) readOutsidePackets(ip netip.AddrPort, via *ViaSender, out []
Error("Failed to decrypt Control packet")
return
}
m := &NebulaControl{}
err = m.Unmarshal(d)
if err != nil {
hostinfo.logger(f.l).WithError(err).Error("Failed to unmarshal control message")
break
}
f.relayManager.HandleControlMsg(hostinfo, m, f)
f.relayManager.HandleControlMsg(hostinfo, d, f)
default:
f.messageMetrics.Rx(h.Type, h.Subtype, 1)
@ -253,8 +221,8 @@ func (f *Interface) readOutsidePackets(ip netip.AddrPort, via *ViaSender, out []
func (f *Interface) closeTunnel(hostInfo *HostInfo) {
final := f.hostMap.DeleteHostInfo(hostInfo)
if final {
// We no longer have any tunnels with this vpn ip, clear learned lighthouse state to lower memory usage
f.lightHouse.DeleteVpnIp(hostInfo.vpnIp)
// We no longer have any tunnels with this vpn addr, clear learned lighthouse state to lower memory usage
f.lightHouse.DeleteVpnAddrs(hostInfo.vpnAddrs)
}
}
@ -263,25 +231,26 @@ func (f *Interface) sendCloseTunnel(h *HostInfo) {
f.send(header.CloseTunnel, 0, h.ConnectionState, h, []byte{}, make([]byte, 12, 12), make([]byte, mtu))
}
func (f *Interface) handleHostRoaming(hostinfo *HostInfo, ip netip.AddrPort) {
if ip.IsValid() && hostinfo.remote != ip {
if !f.lightHouse.GetRemoteAllowList().Allow(hostinfo.vpnIp, ip.Addr()) {
hostinfo.logger(f.l).WithField("newAddr", ip).Debug("lighthouse.remote_allow_list denied roaming")
func (f *Interface) handleHostRoaming(hostinfo *HostInfo, udpAddr netip.AddrPort) {
if udpAddr.IsValid() && hostinfo.remote != udpAddr {
if !f.lightHouse.GetRemoteAllowList().AllowAll(hostinfo.vpnAddrs, udpAddr.Addr()) {
hostinfo.logger(f.l).WithField("newAddr", udpAddr).Debug("lighthouse.remote_allow_list denied roaming")
return
}
if !hostinfo.lastRoam.IsZero() && ip == hostinfo.lastRoamRemote && time.Since(hostinfo.lastRoam) < RoamingSuppressSeconds*time.Second {
if !hostinfo.lastRoam.IsZero() && udpAddr == hostinfo.lastRoamRemote && time.Since(hostinfo.lastRoam) < RoamingSuppressSeconds*time.Second {
if f.l.Level >= logrus.DebugLevel {
hostinfo.logger(f.l).WithField("udpAddr", hostinfo.remote).WithField("newAddr", ip).
hostinfo.logger(f.l).WithField("udpAddr", hostinfo.remote).WithField("newAddr", udpAddr).
Debugf("Suppressing roam back to previous remote for %d seconds", RoamingSuppressSeconds)
}
return
}
hostinfo.logger(f.l).WithField("udpAddr", hostinfo.remote).WithField("newAddr", ip).
hostinfo.logger(f.l).WithField("udpAddr", hostinfo.remote).WithField("newAddr", udpAddr).
Info("Host roamed to new udp ip/port.")
hostinfo.lastRoam = time.Now()
hostinfo.lastRoamRemote = hostinfo.remote
hostinfo.SetRemote(ip)
hostinfo.SetRemote(udpAddr)
}
}
@ -301,24 +270,141 @@ func (f *Interface) handleEncrypted(ci *ConnectionState, addr netip.AddrPort, h
return true
}
var (
ErrPacketTooShort = errors.New("packet is too short")
ErrUnknownIPVersion = errors.New("packet is an unknown ip version")
ErrIPv4InvalidHeaderLength = errors.New("invalid ipv4 header length")
ErrIPv4PacketTooShort = errors.New("ipv4 packet is too short")
ErrIPv6PacketTooShort = errors.New("ipv6 packet is too short")
ErrIPv6CouldNotFindPayload = errors.New("could not find payload in ipv6 packet")
)
// newPacket validates and parses the interesting bits for the firewall out of the ip and sub protocol headers
func newPacket(data []byte, incoming bool, fp *firewall.Packet) error {
// Do we at least have an ipv4 header worth of data?
if len(data) < ipv4.HeaderLen {
return fmt.Errorf("packet is less than %v bytes", ipv4.HeaderLen)
if len(data) < 1 {
return ErrPacketTooShort
}
// Is it an ipv4 packet?
if int((data[0]>>4)&0x0f) != 4 {
return fmt.Errorf("packet is not ipv4, type: %v", int((data[0]>>4)&0x0f))
version := int((data[0] >> 4) & 0x0f)
switch version {
case ipv4.Version:
return parseV4(data, incoming, fp)
case ipv6.Version:
return parseV6(data, incoming, fp)
}
return ErrUnknownIPVersion
}
func parseV6(data []byte, incoming bool, fp *firewall.Packet) error {
dataLen := len(data)
if dataLen < ipv6.HeaderLen {
return ErrIPv6PacketTooShort
}
if incoming {
fp.RemoteAddr, _ = netip.AddrFromSlice(data[8:24])
fp.LocalAddr, _ = netip.AddrFromSlice(data[24:40])
} else {
fp.LocalAddr, _ = netip.AddrFromSlice(data[8:24])
fp.RemoteAddr, _ = netip.AddrFromSlice(data[24:40])
}
protoAt := 6 // NextHeader is at 6 bytes into the ipv6 header
offset := ipv6.HeaderLen // Start at the end of the ipv6 header
next := 0
for {
if dataLen < offset {
break
}
proto := layers.IPProtocol(data[protoAt])
//fmt.Println(proto, protoAt)
switch proto {
case layers.IPProtocolICMPv6, layers.IPProtocolESP, layers.IPProtocolNoNextHeader:
fp.Protocol = uint8(proto)
fp.RemotePort = 0
fp.LocalPort = 0
fp.Fragment = false
return nil
case layers.IPProtocolTCP, layers.IPProtocolUDP:
if dataLen < offset+4 {
return ErrIPv6PacketTooShort
}
fp.Protocol = uint8(proto)
if incoming {
fp.RemotePort = binary.BigEndian.Uint16(data[offset : offset+2])
fp.LocalPort = binary.BigEndian.Uint16(data[offset+2 : offset+4])
} else {
fp.LocalPort = binary.BigEndian.Uint16(data[offset : offset+2])
fp.RemotePort = binary.BigEndian.Uint16(data[offset+2 : offset+4])
}
fp.Fragment = false
return nil
case layers.IPProtocolIPv6Fragment:
// Fragment header is 8 bytes, need at least offset+4 to read the offset field
if dataLen < offset+8 {
return ErrIPv6PacketTooShort
}
// Check if this is the first fragment
fragmentOffset := binary.BigEndian.Uint16(data[offset+2:offset+4]) &^ uint16(0x7) // Remove the reserved and M flag bits
if fragmentOffset != 0 {
// Non-first fragment, use what we have now and stop processing
fp.Protocol = data[offset]
fp.Fragment = true
fp.RemotePort = 0
fp.LocalPort = 0
return nil
}
// The next loop should be the transport layer since we are the first fragment
next = 8 // Fragment headers are always 8 bytes
case layers.IPProtocolAH:
// Auth headers, used by IPSec, have a different meaning for header length
if dataLen < offset+1 {
break
}
next = int(data[offset+1]+2) << 2
default:
// Normal ipv6 header length processing
if dataLen < offset+1 {
break
}
next = int(data[offset+1]+1) << 3
}
if next <= 0 {
// Safety check, each ipv6 header has to be at least 8 bytes
next = 8
}
protoAt = offset
offset = offset + next
}
return ErrIPv6CouldNotFindPayload
}
func parseV4(data []byte, incoming bool, fp *firewall.Packet) error {
// Do we at least have an ipv4 header worth of data?
if len(data) < ipv4.HeaderLen {
return ErrIPv4PacketTooShort
}
// Adjust our start position based on the advertised ip header length
ihl := int(data[0]&0x0f) << 2
// Well formed ip header length?
// Well-formed ip header length?
if ihl < ipv4.HeaderLen {
return fmt.Errorf("packet had an invalid header length: %v", ihl)
return ErrIPv4InvalidHeaderLength
}
// Check if this is the second or further fragment of a fragmented packet.
@ -334,14 +420,13 @@ func newPacket(data []byte, incoming bool, fp *firewall.Packet) error {
minLen += minFwPacketLen
}
if len(data) < minLen {
return fmt.Errorf("packet is less than %v bytes, ip header len: %v", minLen, ihl)
return ErrIPv4InvalidHeaderLength
}
// Firewall packets are locally oriented
if incoming {
//TODO: IPV6-WORK
fp.RemoteIP, _ = netip.AddrFromSlice(data[12:16])
fp.LocalIP, _ = netip.AddrFromSlice(data[16:20])
fp.RemoteAddr, _ = netip.AddrFromSlice(data[12:16])
fp.LocalAddr, _ = netip.AddrFromSlice(data[16:20])
if fp.Fragment || fp.Protocol == firewall.ProtoICMP {
fp.RemotePort = 0
fp.LocalPort = 0
@ -350,9 +435,8 @@ func newPacket(data []byte, incoming bool, fp *firewall.Packet) error {
fp.LocalPort = binary.BigEndian.Uint16(data[ihl+2 : ihl+4])
}
} else {
//TODO: IPV6-WORK
fp.LocalIP, _ = netip.AddrFromSlice(data[12:16])
fp.RemoteIP, _ = netip.AddrFromSlice(data[16:20])
fp.LocalAddr, _ = netip.AddrFromSlice(data[12:16])
fp.RemoteAddr, _ = netip.AddrFromSlice(data[16:20])
if fp.Fragment || fp.Protocol == firewall.ProtoICMP {
fp.RemotePort = 0
fp.LocalPort = 0
@ -387,8 +471,6 @@ func (f *Interface) decryptToTun(hostinfo *HostInfo, messageCounter uint64, out
out, err = hostinfo.ConnectionState.dKey.DecryptDanger(out, packet[:header.Len], packet[header.Len:], messageCounter, nb)
if err != nil {
hostinfo.logger(f.l).WithError(err).Error("Failed to decrypt packet")
//TODO: maybe after build 64 is out? 06/14/2018 - NB
//f.sendRecvError(hostinfo.remote, header.RemoteIndex)
return false
}
@ -435,9 +517,8 @@ func (f *Interface) maybeSendRecvError(endpoint netip.AddrPort, index uint32) {
func (f *Interface) sendRecvError(endpoint netip.AddrPort, index uint32) {
f.messageMetrics.Tx(header.RecvError, 0, 1)
//TODO: this should be a signed message so we can trust that we should drop the index
b := header.Encode(make([]byte, header.Len), header.Version, header.RecvError, 0, index, 0)
f.outside.WriteTo(b, endpoint)
_ = f.outside.WriteTo(b, endpoint)
if f.l.Level >= logrus.DebugLevel {
f.l.WithField("index", index).
WithField("udpAddr", endpoint).
@ -471,65 +552,3 @@ func (f *Interface) handleRecvError(addr netip.AddrPort, h *header.H) {
// We also delete it from pending hostmap to allow for fast reconnect.
f.handshakeManager.DeleteHostInfo(hostinfo)
}
/*
func (f *Interface) sendMeta(ci *ConnectionState, endpoint *net.UDPAddr, meta *NebulaMeta) {
if ci.eKey != nil {
//TODO: log error?
return
}
msg, err := proto.Marshal(meta)
if err != nil {
l.Debugln("failed to encode header")
}
c := ci.messageCounter
b := HeaderEncode(nil, Version, uint8(metadata), 0, hostinfo.remoteIndexId, c)
ci.messageCounter++
msg := ci.eKey.EncryptDanger(b, nil, msg, c)
//msg := ci.eKey.EncryptDanger(b, nil, []byte(fmt.Sprintf("%d", counter)), c)
f.outside.WriteTo(msg, endpoint)
}
*/
func RecombineCertAndValidate(h *noise.HandshakeState, rawCertBytes []byte, caPool *cert.NebulaCAPool) (*cert.NebulaCertificate, error) {
pk := h.PeerStatic()
if pk == nil {
return nil, errors.New("no peer static key was present")
}
if rawCertBytes == nil {
return nil, errors.New("provided payload was empty")
}
r := &cert.RawNebulaCertificate{}
err := proto.Unmarshal(rawCertBytes, r)
if err != nil {
return nil, fmt.Errorf("error unmarshaling cert: %s", err)
}
// If the Details are nil, just exit to avoid crashing
if r.Details == nil {
return nil, fmt.Errorf("certificate did not contain any details")
}
r.Details.PublicKey = pk
recombined, err := proto.Marshal(r)
if err != nil {
return nil, fmt.Errorf("error while recombining certificate: %s", err)
}
c, _ := cert.UnmarshalNebulaCertificate(recombined)
isValid, err := c.Verify(time.Now(), caPool)
if err != nil {
return c, fmt.Errorf("certificate validation failed: %s", err)
} else if !isValid {
// This case should never happen but here's to defensive programming!
return c, errors.New("certificate validation failed but did not return an error")
}
return c, nil
}

View File

@ -1,21 +1,33 @@
package nebula
import (
"bytes"
"encoding/binary"
"net"
"net/netip"
"testing"
"github.com/google/gopacket"
"github.com/google/gopacket/layers"
"github.com/slackhq/nebula/firewall"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"golang.org/x/net/ipv4"
)
func Test_newPacket(t *testing.T) {
p := &firewall.Packet{}
// length fail
err := newPacket([]byte{0, 1}, true, p)
assert.EqualError(t, err, "packet is less than 20 bytes")
// length fails
err := newPacket([]byte{}, true, p)
require.ErrorIs(t, err, ErrPacketTooShort)
err = newPacket([]byte{0x40}, true, p)
require.ErrorIs(t, err, ErrIPv4PacketTooShort)
err = newPacket([]byte{0x60}, true, p)
require.ErrorIs(t, err, ErrIPv6PacketTooShort)
// length fail with ip options
h := ipv4.Header{
@ -28,16 +40,15 @@ func Test_newPacket(t *testing.T) {
b, _ := h.Marshal()
err = newPacket(b, true, p)
assert.EqualError(t, err, "packet is less than 28 bytes, ip header len: 24")
require.ErrorIs(t, err, ErrIPv4InvalidHeaderLength)
// not an ipv4 packet
err = newPacket([]byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}, true, p)
assert.EqualError(t, err, "packet is not ipv4, type: 0")
require.ErrorIs(t, err, ErrUnknownIPVersion)
// invalid ihl
err = newPacket([]byte{4<<4 | (8 >> 2 & 0x0f), 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}, true, p)
assert.EqualError(t, err, "packet had an invalid header length: 8")
require.ErrorIs(t, err, ErrIPv4InvalidHeaderLength)
// account for variable ip header length - incoming
h = ipv4.Header{
@ -53,12 +64,13 @@ func Test_newPacket(t *testing.T) {
b = append(b, []byte{0, 3, 0, 4}...)
err = newPacket(b, true, p)
assert.Nil(t, err)
assert.Equal(t, p.Protocol, uint8(firewall.ProtoTCP))
assert.Equal(t, p.LocalIP, netip.MustParseAddr("10.0.0.2"))
assert.Equal(t, p.RemoteIP, netip.MustParseAddr("10.0.0.1"))
assert.Equal(t, p.RemotePort, uint16(3))
assert.Equal(t, p.LocalPort, uint16(4))
require.NoError(t, err)
assert.Equal(t, uint8(firewall.ProtoTCP), p.Protocol)
assert.Equal(t, netip.MustParseAddr("10.0.0.2"), p.LocalAddr)
assert.Equal(t, netip.MustParseAddr("10.0.0.1"), p.RemoteAddr)
assert.Equal(t, uint16(3), p.RemotePort)
assert.Equal(t, uint16(4), p.LocalPort)
assert.False(t, p.Fragment)
// account for variable ip header length - outgoing
h = ipv4.Header{
@ -74,10 +86,507 @@ func Test_newPacket(t *testing.T) {
b = append(b, []byte{0, 5, 0, 6}...)
err = newPacket(b, false, p)
assert.Nil(t, err)
assert.Equal(t, p.Protocol, uint8(2))
assert.Equal(t, p.LocalIP, netip.MustParseAddr("10.0.0.1"))
assert.Equal(t, p.RemoteIP, netip.MustParseAddr("10.0.0.2"))
assert.Equal(t, p.RemotePort, uint16(6))
assert.Equal(t, p.LocalPort, uint16(5))
require.NoError(t, err)
assert.Equal(t, uint8(2), p.Protocol)
assert.Equal(t, netip.MustParseAddr("10.0.0.1"), p.LocalAddr)
assert.Equal(t, netip.MustParseAddr("10.0.0.2"), p.RemoteAddr)
assert.Equal(t, uint16(6), p.RemotePort)
assert.Equal(t, uint16(5), p.LocalPort)
assert.False(t, p.Fragment)
}
func Test_newPacket_v6(t *testing.T) {
p := &firewall.Packet{}
// invalid ipv6
ip := layers.IPv6{
Version: 6,
HopLimit: 128,
SrcIP: net.IPv6linklocalallrouters,
DstIP: net.IPv6linklocalallnodes,
}
buffer := gopacket.NewSerializeBuffer()
opt := gopacket.SerializeOptions{
ComputeChecksums: false,
FixLengths: false,
}
err := gopacket.SerializeLayers(buffer, opt, &ip)
require.NoError(t, err)
err = newPacket(buffer.Bytes(), true, p)
require.ErrorIs(t, err, ErrIPv6CouldNotFindPayload)
// A good ICMP packet
ip = layers.IPv6{
Version: 6,
NextHeader: layers.IPProtocolICMPv6,
HopLimit: 128,
SrcIP: net.IPv6linklocalallrouters,
DstIP: net.IPv6linklocalallnodes,
}
icmp := layers.ICMPv6{}
buffer.Clear()
err = gopacket.SerializeLayers(buffer, opt, &ip, &icmp)
if err != nil {
panic(err)
}
err = newPacket(buffer.Bytes(), true, p)
require.NoError(t, err)
assert.Equal(t, uint8(layers.IPProtocolICMPv6), p.Protocol)
assert.Equal(t, netip.MustParseAddr("ff02::2"), p.RemoteAddr)
assert.Equal(t, netip.MustParseAddr("ff02::1"), p.LocalAddr)
assert.Equal(t, uint16(0), p.RemotePort)
assert.Equal(t, uint16(0), p.LocalPort)
assert.False(t, p.Fragment)
// A good ESP packet
b := buffer.Bytes()
b[6] = byte(layers.IPProtocolESP)
err = newPacket(b, true, p)
require.NoError(t, err)
assert.Equal(t, uint8(layers.IPProtocolESP), p.Protocol)
assert.Equal(t, netip.MustParseAddr("ff02::2"), p.RemoteAddr)
assert.Equal(t, netip.MustParseAddr("ff02::1"), p.LocalAddr)
assert.Equal(t, uint16(0), p.RemotePort)
assert.Equal(t, uint16(0), p.LocalPort)
assert.False(t, p.Fragment)
// A good None packet
b = buffer.Bytes()
b[6] = byte(layers.IPProtocolNoNextHeader)
err = newPacket(b, true, p)
require.NoError(t, err)
assert.Equal(t, uint8(layers.IPProtocolNoNextHeader), p.Protocol)
assert.Equal(t, netip.MustParseAddr("ff02::2"), p.RemoteAddr)
assert.Equal(t, netip.MustParseAddr("ff02::1"), p.LocalAddr)
assert.Equal(t, uint16(0), p.RemotePort)
assert.Equal(t, uint16(0), p.LocalPort)
assert.False(t, p.Fragment)
// An unknown protocol packet
b = buffer.Bytes()
b[6] = 255 // 255 is a reserved protocol number
err = newPacket(b, true, p)
require.ErrorIs(t, err, ErrIPv6CouldNotFindPayload)
// A good UDP packet
ip = layers.IPv6{
Version: 6,
NextHeader: firewall.ProtoUDP,
HopLimit: 128,
SrcIP: net.IPv6linklocalallrouters,
DstIP: net.IPv6linklocalallnodes,
}
udp := layers.UDP{
SrcPort: layers.UDPPort(36123),
DstPort: layers.UDPPort(22),
}
err = udp.SetNetworkLayerForChecksum(&ip)
require.NoError(t, err)
buffer.Clear()
err = gopacket.SerializeLayers(buffer, opt, &ip, &udp, gopacket.Payload([]byte{0xde, 0xad, 0xbe, 0xef}))
if err != nil {
panic(err)
}
b = buffer.Bytes()
// incoming
err = newPacket(b, true, p)
require.NoError(t, err)
assert.Equal(t, uint8(firewall.ProtoUDP), p.Protocol)
assert.Equal(t, netip.MustParseAddr("ff02::2"), p.RemoteAddr)
assert.Equal(t, netip.MustParseAddr("ff02::1"), p.LocalAddr)
assert.Equal(t, uint16(36123), p.RemotePort)
assert.Equal(t, uint16(22), p.LocalPort)
assert.False(t, p.Fragment)
// outgoing
err = newPacket(b, false, p)
require.NoError(t, err)
assert.Equal(t, uint8(firewall.ProtoUDP), p.Protocol)
assert.Equal(t, netip.MustParseAddr("ff02::2"), p.LocalAddr)
assert.Equal(t, netip.MustParseAddr("ff02::1"), p.RemoteAddr)
assert.Equal(t, uint16(36123), p.LocalPort)
assert.Equal(t, uint16(22), p.RemotePort)
assert.False(t, p.Fragment)
// Too short UDP packet
err = newPacket(b[:len(b)-10], false, p) // pull off the last 10 bytes
require.ErrorIs(t, err, ErrIPv6PacketTooShort)
// A good TCP packet
b[6] = byte(layers.IPProtocolTCP)
// incoming
err = newPacket(b, true, p)
require.NoError(t, err)
assert.Equal(t, uint8(firewall.ProtoTCP), p.Protocol)
assert.Equal(t, netip.MustParseAddr("ff02::2"), p.RemoteAddr)
assert.Equal(t, netip.MustParseAddr("ff02::1"), p.LocalAddr)
assert.Equal(t, uint16(36123), p.RemotePort)
assert.Equal(t, uint16(22), p.LocalPort)
assert.False(t, p.Fragment)
// outgoing
err = newPacket(b, false, p)
require.NoError(t, err)
assert.Equal(t, uint8(firewall.ProtoTCP), p.Protocol)
assert.Equal(t, netip.MustParseAddr("ff02::2"), p.LocalAddr)
assert.Equal(t, netip.MustParseAddr("ff02::1"), p.RemoteAddr)
assert.Equal(t, uint16(36123), p.LocalPort)
assert.Equal(t, uint16(22), p.RemotePort)
assert.False(t, p.Fragment)
// Too short TCP packet
err = newPacket(b[:len(b)-10], false, p) // pull off the last 10 bytes
require.ErrorIs(t, err, ErrIPv6PacketTooShort)
// A good UDP packet with an AH header
ip = layers.IPv6{
Version: 6,
NextHeader: layers.IPProtocolAH,
HopLimit: 128,
SrcIP: net.IPv6linklocalallrouters,
DstIP: net.IPv6linklocalallnodes,
}
ah := layers.IPSecAH{
AuthenticationData: []byte{0xde, 0xad, 0xbe, 0xef, 0xde, 0xad, 0xbe, 0xef},
}
ah.NextHeader = layers.IPProtocolUDP
udpHeader := []byte{
0x8d, 0x1b, // Source port 36123
0x00, 0x16, // Destination port 22
0x00, 0x00, // Length
0x00, 0x00, // Checksum
}
buffer.Clear()
err = ip.SerializeTo(buffer, opt)
if err != nil {
panic(err)
}
b = buffer.Bytes()
ahb := serializeAH(&ah)
b = append(b, ahb...)
b = append(b, udpHeader...)
err = newPacket(b, true, p)
require.NoError(t, err)
assert.Equal(t, uint8(firewall.ProtoUDP), p.Protocol)
assert.Equal(t, netip.MustParseAddr("ff02::2"), p.RemoteAddr)
assert.Equal(t, netip.MustParseAddr("ff02::1"), p.LocalAddr)
assert.Equal(t, uint16(36123), p.RemotePort)
assert.Equal(t, uint16(22), p.LocalPort)
assert.False(t, p.Fragment)
// Invalid AH header
b = buffer.Bytes()
err = newPacket(b, true, p)
require.ErrorIs(t, err, ErrIPv6CouldNotFindPayload)
}
func Test_newPacket_ipv6Fragment(t *testing.T) {
p := &firewall.Packet{}
ip := &layers.IPv6{
Version: 6,
NextHeader: layers.IPProtocolIPv6Fragment,
HopLimit: 64,
SrcIP: net.IPv6linklocalallrouters,
DstIP: net.IPv6linklocalallnodes,
}
// First fragment
fragHeader1 := []byte{
uint8(layers.IPProtocolUDP), // Next Header (UDP)
0x00, // Reserved
0x00, // Fragment Offset high byte (0)
0x01, // Fragment Offset low byte & flags (M=1)
0x00, 0x00, 0x00, 0x01, // Identification
}
udpHeader := []byte{
0x8d, 0x1b, // Source port 36123
0x00, 0x16, // Destination port 22
0x00, 0x00, // Length
0x00, 0x00, // Checksum
}
buffer := gopacket.NewSerializeBuffer()
opts := gopacket.SerializeOptions{
ComputeChecksums: true,
FixLengths: true,
}
err := ip.SerializeTo(buffer, opts)
if err != nil {
t.Fatal(err)
}
firstFrag := buffer.Bytes()
firstFrag = append(firstFrag, fragHeader1...)
firstFrag = append(firstFrag, udpHeader...)
firstFrag = append(firstFrag, []byte{0xde, 0xad, 0xbe, 0xef}...)
// Test first fragment incoming
err = newPacket(firstFrag, true, p)
require.NoError(t, err)
assert.Equal(t, netip.MustParseAddr("ff02::2"), p.RemoteAddr)
assert.Equal(t, netip.MustParseAddr("ff02::1"), p.LocalAddr)
assert.Equal(t, uint8(layers.IPProtocolUDP), p.Protocol)
assert.Equal(t, uint16(36123), p.RemotePort)
assert.Equal(t, uint16(22), p.LocalPort)
assert.False(t, p.Fragment)
// Test first fragment outgoing
err = newPacket(firstFrag, false, p)
require.NoError(t, err)
assert.Equal(t, netip.MustParseAddr("ff02::2"), p.LocalAddr)
assert.Equal(t, netip.MustParseAddr("ff02::1"), p.RemoteAddr)
assert.Equal(t, uint8(layers.IPProtocolUDP), p.Protocol)
assert.Equal(t, uint16(36123), p.LocalPort)
assert.Equal(t, uint16(22), p.RemotePort)
assert.False(t, p.Fragment)
// Second fragment
fragHeader2 := []byte{
uint8(layers.IPProtocolUDP), // Next Header (UDP)
0x00, // Reserved
0xb9, // Fragment Offset high byte (185)
0x01, // Fragment Offset low byte & flags (M=1)
0x00, 0x00, 0x00, 0x01, // Identification
}
buffer.Clear()
err = ip.SerializeTo(buffer, opts)
if err != nil {
t.Fatal(err)
}
secondFrag := buffer.Bytes()
secondFrag = append(secondFrag, fragHeader2...)
secondFrag = append(secondFrag, []byte{0xde, 0xad, 0xbe, 0xef}...)
// Test second fragment incoming
err = newPacket(secondFrag, true, p)
require.NoError(t, err)
assert.Equal(t, netip.MustParseAddr("ff02::2"), p.RemoteAddr)
assert.Equal(t, netip.MustParseAddr("ff02::1"), p.LocalAddr)
assert.Equal(t, uint8(layers.IPProtocolUDP), p.Protocol)
assert.Equal(t, uint16(0), p.RemotePort)
assert.Equal(t, uint16(0), p.LocalPort)
assert.True(t, p.Fragment)
// Test second fragment outgoing
err = newPacket(secondFrag, false, p)
require.NoError(t, err)
assert.Equal(t, netip.MustParseAddr("ff02::2"), p.LocalAddr)
assert.Equal(t, netip.MustParseAddr("ff02::1"), p.RemoteAddr)
assert.Equal(t, uint8(layers.IPProtocolUDP), p.Protocol)
assert.Equal(t, uint16(0), p.LocalPort)
assert.Equal(t, uint16(0), p.RemotePort)
assert.True(t, p.Fragment)
// Too short of a fragment packet
err = newPacket(secondFrag[:len(secondFrag)-10], false, p)
require.ErrorIs(t, err, ErrIPv6PacketTooShort)
}
func BenchmarkParseV6(b *testing.B) {
// Regular UDP packet
ip := &layers.IPv6{
Version: 6,
NextHeader: layers.IPProtocolUDP,
HopLimit: 64,
SrcIP: net.IPv6linklocalallrouters,
DstIP: net.IPv6linklocalallnodes,
}
udp := &layers.UDP{
SrcPort: layers.UDPPort(36123),
DstPort: layers.UDPPort(22),
}
buffer := gopacket.NewSerializeBuffer()
opts := gopacket.SerializeOptions{
ComputeChecksums: false,
FixLengths: true,
}
err := gopacket.SerializeLayers(buffer, opts, ip, udp)
if err != nil {
b.Fatal(err)
}
normalPacket := buffer.Bytes()
// First Fragment packet
ipFrag := &layers.IPv6{
Version: 6,
NextHeader: layers.IPProtocolIPv6Fragment,
HopLimit: 64,
SrcIP: net.IPv6linklocalallrouters,
DstIP: net.IPv6linklocalallnodes,
}
fragHeader := []byte{
uint8(layers.IPProtocolUDP), // Next Header (UDP)
0x00, // Reserved
0x00, // Fragment Offset high byte (0)
0x01, // Fragment Offset low byte & flags (M=1)
0x00, 0x00, 0x00, 0x01, // Identification
}
udpHeader := []byte{
0x8d, 0x7b, // Source port 36123
0x00, 0x16, // Destination port 22
0x00, 0x00, // Length
0x00, 0x00, // Checksum
}
buffer.Clear()
err = ipFrag.SerializeTo(buffer, opts)
if err != nil {
b.Fatal(err)
}
firstFrag := buffer.Bytes()
firstFrag = append(firstFrag, fragHeader...)
firstFrag = append(firstFrag, udpHeader...)
firstFrag = append(firstFrag, []byte{0xde, 0xad, 0xbe, 0xef}...)
// Second Fragment packet
fragHeader[2] = 0xb9 // offset 185
buffer.Clear()
err = ipFrag.SerializeTo(buffer, opts)
if err != nil {
b.Fatal(err)
}
secondFrag := buffer.Bytes()
secondFrag = append(secondFrag, fragHeader...)
secondFrag = append(secondFrag, []byte{0xde, 0xad, 0xbe, 0xef}...)
fp := &firewall.Packet{}
b.Run("Normal", func(b *testing.B) {
for i := 0; i < b.N; i++ {
if err = parseV6(normalPacket, true, fp); err != nil {
b.Fatal(err)
}
}
})
b.Run("FirstFragment", func(b *testing.B) {
for i := 0; i < b.N; i++ {
if err = parseV6(firstFrag, true, fp); err != nil {
b.Fatal(err)
}
}
})
b.Run("SecondFragment", func(b *testing.B) {
for i := 0; i < b.N; i++ {
if err = parseV6(secondFrag, true, fp); err != nil {
b.Fatal(err)
}
}
})
// Evil packet
evilPacket := &layers.IPv6{
Version: 6,
NextHeader: layers.IPProtocolIPv6HopByHop,
HopLimit: 64,
SrcIP: net.IPv6linklocalallrouters,
DstIP: net.IPv6linklocalallnodes,
}
hopHeader := []byte{
uint8(layers.IPProtocolIPv6HopByHop), // Next Header (HopByHop)
0x00, // Length
0x00, 0x00, // Options and padding
0x00, 0x00, 0x00, 0x00, // More options and padding
}
lastHopHeader := []byte{
uint8(layers.IPProtocolUDP), // Next Header (UDP)
0x00, // Length
0x00, 0x00, // Options and padding
0x00, 0x00, 0x00, 0x00, // More options and padding
}
buffer.Clear()
err = evilPacket.SerializeTo(buffer, opts)
if err != nil {
b.Fatal(err)
}
evilBytes := buffer.Bytes()
for i := 0; i < 200; i++ {
evilBytes = append(evilBytes, hopHeader...)
}
evilBytes = append(evilBytes, lastHopHeader...)
evilBytes = append(evilBytes, udpHeader...)
evilBytes = append(evilBytes, []byte{0xde, 0xad, 0xbe, 0xef}...)
b.Run("200 HopByHop headers", func(b *testing.B) {
for i := 0; i < b.N; i++ {
if err = parseV6(evilBytes, false, fp); err != nil {
b.Fatal(err)
}
}
})
}
// Ensure authentication data is a multiple of 8 bytes by padding if necessary
func padAuthData(authData []byte) []byte {
// Length of Authentication Data must be a multiple of 8 bytes
paddingLength := (8 - (len(authData) % 8)) % 8 // Only pad if necessary
if paddingLength > 0 {
authData = append(authData, make([]byte, paddingLength)...)
}
return authData
}
// Custom function to manually serialize IPSecAH for both IPv4 and IPv6
func serializeAH(ah *layers.IPSecAH) []byte {
buf := new(bytes.Buffer)
// Ensure Authentication Data is a multiple of 8 bytes
ah.AuthenticationData = padAuthData(ah.AuthenticationData)
// Calculate Payload Length (in 32-bit words, minus 2)
payloadLen := uint8((12+len(ah.AuthenticationData))/4) - 2
// Serialize fields
if err := binary.Write(buf, binary.BigEndian, ah.NextHeader); err != nil {
panic(err)
}
if err := binary.Write(buf, binary.BigEndian, payloadLen); err != nil {
panic(err)
}
if err := binary.Write(buf, binary.BigEndian, ah.Reserved); err != nil {
panic(err)
}
if err := binary.Write(buf, binary.BigEndian, ah.SPI); err != nil {
panic(err)
}
if err := binary.Write(buf, binary.BigEndian, ah.Seq); err != nil {
panic(err)
}
if len(ah.AuthenticationData) > 0 {
if err := binary.Write(buf, binary.BigEndian, ah.AuthenticationData); err != nil {
panic(err)
}
}
return buf.Bytes()
}

View File

@ -8,7 +8,7 @@ import (
type Device interface {
io.ReadWriteCloser
Activate() error
Cidr() netip.Prefix
Networks() []netip.Prefix
Name() string
RouteFor(netip.Addr) netip.Addr
NewMultiQueueReader() (io.ReadWriteCloser, error)

View File

@ -61,7 +61,7 @@ func makeRouteTree(l *logrus.Logger, routes []Route, allowMTU bool) (*bart.Table
return routeTree, nil
}
func parseRoutes(c *config.C, network netip.Prefix) ([]Route, error) {
func parseRoutes(c *config.C, networks []netip.Prefix) ([]Route, error) {
var err error
r := c.Get("tun.routes")
@ -117,12 +117,20 @@ func parseRoutes(c *config.C, network netip.Prefix) ([]Route, error) {
return nil, fmt.Errorf("entry %v.route in tun.routes failed to parse: %v", i+1, err)
}
if !network.Contains(r.Cidr.Addr()) || r.Cidr.Bits() < network.Bits() {
found := false
for _, network := range networks {
if network.Contains(r.Cidr.Addr()) && r.Cidr.Bits() >= network.Bits() {
found = true
break
}
}
if !found {
return nil, fmt.Errorf(
"entry %v.route in tun.routes is not contained within the network attached to the certificate; route: %v, network: %v",
"entry %v.route in tun.routes is not contained within the configured vpn networks; route: %v, networks: %v",
i+1,
r.Cidr.String(),
network.String(),
networks,
)
}
@ -132,7 +140,7 @@ func parseRoutes(c *config.C, network netip.Prefix) ([]Route, error) {
return routes, nil
}
func parseUnsafeRoutes(c *config.C, network netip.Prefix) ([]Route, error) {
func parseUnsafeRoutes(c *config.C, networks []netip.Prefix) ([]Route, error) {
var err error
r := c.Get("tun.unsafe_routes")
@ -229,14 +237,16 @@ func parseUnsafeRoutes(c *config.C, network netip.Prefix) ([]Route, error) {
return nil, fmt.Errorf("entry %v.route in tun.unsafe_routes failed to parse: %v", i+1, err)
}
for _, network := range networks {
if network.Contains(r.Cidr.Addr()) {
return nil, fmt.Errorf(
"entry %v.route in tun.unsafe_routes is contained within the network attached to the certificate; route: %v, network: %v",
"entry %v.route in tun.unsafe_routes is contained within the configured vpn networks; route: %v, network: %v",
i+1,
r.Cidr.String(),
network.String(),
)
}
}
routes[i] = r
}

View File

@ -8,86 +8,93 @@ import (
"github.com/slackhq/nebula/config"
"github.com/slackhq/nebula/test"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func Test_parseRoutes(t *testing.T) {
l := test.NewLogger()
c := config.NewC(l)
n, err := netip.ParsePrefix("10.0.0.0/24")
assert.NoError(t, err)
require.NoError(t, err)
// test no routes config
routes, err := parseRoutes(c, n)
assert.Nil(t, err)
assert.Len(t, routes, 0)
routes, err := parseRoutes(c, []netip.Prefix{n})
require.NoError(t, err)
assert.Empty(t, routes)
// not an array
c.Settings["tun"] = map[interface{}]interface{}{"routes": "hi"}
routes, err = parseRoutes(c, n)
routes, err = parseRoutes(c, []netip.Prefix{n})
assert.Nil(t, routes)
assert.EqualError(t, err, "tun.routes is not an array")
require.EqualError(t, err, "tun.routes is not an array")
// no routes
c.Settings["tun"] = map[interface{}]interface{}{"routes": []interface{}{}}
routes, err = parseRoutes(c, n)
assert.Nil(t, err)
assert.Len(t, routes, 0)
routes, err = parseRoutes(c, []netip.Prefix{n})
require.NoError(t, err)
assert.Empty(t, routes)
// weird route
c.Settings["tun"] = map[interface{}]interface{}{"routes": []interface{}{"asdf"}}
routes, err = parseRoutes(c, n)
routes, err = parseRoutes(c, []netip.Prefix{n})
assert.Nil(t, routes)
assert.EqualError(t, err, "entry 1 in tun.routes is invalid")
require.EqualError(t, err, "entry 1 in tun.routes is invalid")
// no mtu
c.Settings["tun"] = map[interface{}]interface{}{"routes": []interface{}{map[interface{}]interface{}{}}}
routes, err = parseRoutes(c, n)
routes, err = parseRoutes(c, []netip.Prefix{n})
assert.Nil(t, routes)
assert.EqualError(t, err, "entry 1.mtu in tun.routes is not present")
require.EqualError(t, err, "entry 1.mtu in tun.routes is not present")
// bad mtu
c.Settings["tun"] = map[interface{}]interface{}{"routes": []interface{}{map[interface{}]interface{}{"mtu": "nope"}}}
routes, err = parseRoutes(c, n)
routes, err = parseRoutes(c, []netip.Prefix{n})
assert.Nil(t, routes)
assert.EqualError(t, err, "entry 1.mtu in tun.routes is not an integer: strconv.Atoi: parsing \"nope\": invalid syntax")
require.EqualError(t, err, "entry 1.mtu in tun.routes is not an integer: strconv.Atoi: parsing \"nope\": invalid syntax")
// low mtu
c.Settings["tun"] = map[interface{}]interface{}{"routes": []interface{}{map[interface{}]interface{}{"mtu": "499"}}}
routes, err = parseRoutes(c, n)
routes, err = parseRoutes(c, []netip.Prefix{n})
assert.Nil(t, routes)
assert.EqualError(t, err, "entry 1.mtu in tun.routes is below 500: 499")
require.EqualError(t, err, "entry 1.mtu in tun.routes is below 500: 499")
// missing route
c.Settings["tun"] = map[interface{}]interface{}{"routes": []interface{}{map[interface{}]interface{}{"mtu": "500"}}}
routes, err = parseRoutes(c, n)
routes, err = parseRoutes(c, []netip.Prefix{n})
assert.Nil(t, routes)
assert.EqualError(t, err, "entry 1.route in tun.routes is not present")
require.EqualError(t, err, "entry 1.route in tun.routes is not present")
// unparsable route
c.Settings["tun"] = map[interface{}]interface{}{"routes": []interface{}{map[interface{}]interface{}{"mtu": "500", "route": "nope"}}}
routes, err = parseRoutes(c, n)
routes, err = parseRoutes(c, []netip.Prefix{n})
assert.Nil(t, routes)
assert.EqualError(t, err, "entry 1.route in tun.routes failed to parse: netip.ParsePrefix(\"nope\"): no '/'")
require.EqualError(t, err, "entry 1.route in tun.routes failed to parse: netip.ParsePrefix(\"nope\"): no '/'")
// below network range
c.Settings["tun"] = map[interface{}]interface{}{"routes": []interface{}{map[interface{}]interface{}{"mtu": "500", "route": "1.0.0.0/8"}}}
routes, err = parseRoutes(c, n)
routes, err = parseRoutes(c, []netip.Prefix{n})
assert.Nil(t, routes)
assert.EqualError(t, err, "entry 1.route in tun.routes is not contained within the network attached to the certificate; route: 1.0.0.0/8, network: 10.0.0.0/24")
require.EqualError(t, err, "entry 1.route in tun.routes is not contained within the configured vpn networks; route: 1.0.0.0/8, networks: [10.0.0.0/24]")
// above network range
c.Settings["tun"] = map[interface{}]interface{}{"routes": []interface{}{map[interface{}]interface{}{"mtu": "500", "route": "10.0.1.0/24"}}}
routes, err = parseRoutes(c, n)
routes, err = parseRoutes(c, []netip.Prefix{n})
assert.Nil(t, routes)
assert.EqualError(t, err, "entry 1.route in tun.routes is not contained within the network attached to the certificate; route: 10.0.1.0/24, network: 10.0.0.0/24")
require.EqualError(t, err, "entry 1.route in tun.routes is not contained within the configured vpn networks; route: 10.0.1.0/24, networks: [10.0.0.0/24]")
// Not in multiple ranges
c.Settings["tun"] = map[interface{}]interface{}{"routes": []interface{}{map[interface{}]interface{}{"mtu": "500", "route": "192.0.0.0/24"}}}
routes, err = parseRoutes(c, []netip.Prefix{n, netip.MustParsePrefix("192.1.0.0/24")})
assert.Nil(t, routes)
require.EqualError(t, err, "entry 1.route in tun.routes is not contained within the configured vpn networks; route: 192.0.0.0/24, networks: [10.0.0.0/24 192.1.0.0/24]")
// happy case
c.Settings["tun"] = map[interface{}]interface{}{"routes": []interface{}{
map[interface{}]interface{}{"mtu": "9000", "route": "10.0.0.0/29"},
map[interface{}]interface{}{"mtu": "8000", "route": "10.0.0.1/32"},
}}
routes, err = parseRoutes(c, n)
assert.Nil(t, err)
routes, err = parseRoutes(c, []netip.Prefix{n})
require.NoError(t, err)
assert.Len(t, routes, 2)
tested := 0
@ -113,106 +120,106 @@ func Test_parseUnsafeRoutes(t *testing.T) {
l := test.NewLogger()
c := config.NewC(l)
n, err := netip.ParsePrefix("10.0.0.0/24")
assert.NoError(t, err)
require.NoError(t, err)
// test no routes config
routes, err := parseUnsafeRoutes(c, n)
assert.Nil(t, err)
assert.Len(t, routes, 0)
routes, err := parseUnsafeRoutes(c, []netip.Prefix{n})
require.NoError(t, err)
assert.Empty(t, routes)
// not an array
c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": "hi"}
routes, err = parseUnsafeRoutes(c, n)
routes, err = parseUnsafeRoutes(c, []netip.Prefix{n})
assert.Nil(t, routes)
assert.EqualError(t, err, "tun.unsafe_routes is not an array")
require.EqualError(t, err, "tun.unsafe_routes is not an array")
// no routes
c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{}}
routes, err = parseUnsafeRoutes(c, n)
assert.Nil(t, err)
assert.Len(t, routes, 0)
routes, err = parseUnsafeRoutes(c, []netip.Prefix{n})
require.NoError(t, err)
assert.Empty(t, routes)
// weird route
c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{"asdf"}}
routes, err = parseUnsafeRoutes(c, n)
routes, err = parseUnsafeRoutes(c, []netip.Prefix{n})
assert.Nil(t, routes)
assert.EqualError(t, err, "entry 1 in tun.unsafe_routes is invalid")
require.EqualError(t, err, "entry 1 in tun.unsafe_routes is invalid")
// no via
c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{}}}
routes, err = parseUnsafeRoutes(c, n)
routes, err = parseUnsafeRoutes(c, []netip.Prefix{n})
assert.Nil(t, routes)
assert.EqualError(t, err, "entry 1.via in tun.unsafe_routes is not present")
require.EqualError(t, err, "entry 1.via in tun.unsafe_routes is not present")
// invalid via
for _, invalidValue := range []interface{}{
127, false, nil, 1.0, []string{"1", "2"},
} {
c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{"via": invalidValue}}}
routes, err = parseUnsafeRoutes(c, n)
routes, err = parseUnsafeRoutes(c, []netip.Prefix{n})
assert.Nil(t, routes)
assert.EqualError(t, err, fmt.Sprintf("entry 1.via in tun.unsafe_routes is not a string: found %T", invalidValue))
require.EqualError(t, err, fmt.Sprintf("entry 1.via in tun.unsafe_routes is not a string: found %T", invalidValue))
}
// unparsable via
c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{"mtu": "500", "via": "nope"}}}
routes, err = parseUnsafeRoutes(c, n)
routes, err = parseUnsafeRoutes(c, []netip.Prefix{n})
assert.Nil(t, routes)
assert.EqualError(t, err, "entry 1.via in tun.unsafe_routes failed to parse address: ParseAddr(\"nope\"): unable to parse IP")
require.EqualError(t, err, "entry 1.via in tun.unsafe_routes failed to parse address: ParseAddr(\"nope\"): unable to parse IP")
// missing route
c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{"via": "127.0.0.1", "mtu": "500"}}}
routes, err = parseUnsafeRoutes(c, n)
routes, err = parseUnsafeRoutes(c, []netip.Prefix{n})
assert.Nil(t, routes)
assert.EqualError(t, err, "entry 1.route in tun.unsafe_routes is not present")
require.EqualError(t, err, "entry 1.route in tun.unsafe_routes is not present")
// unparsable route
c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{"via": "127.0.0.1", "mtu": "500", "route": "nope"}}}
routes, err = parseUnsafeRoutes(c, n)
routes, err = parseUnsafeRoutes(c, []netip.Prefix{n})
assert.Nil(t, routes)
assert.EqualError(t, err, "entry 1.route in tun.unsafe_routes failed to parse: netip.ParsePrefix(\"nope\"): no '/'")
require.EqualError(t, err, "entry 1.route in tun.unsafe_routes failed to parse: netip.ParsePrefix(\"nope\"): no '/'")
// within network range
c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{"via": "127.0.0.1", "route": "10.0.0.0/24"}}}
routes, err = parseUnsafeRoutes(c, n)
routes, err = parseUnsafeRoutes(c, []netip.Prefix{n})
assert.Nil(t, routes)
assert.EqualError(t, err, "entry 1.route in tun.unsafe_routes is contained within the network attached to the certificate; route: 10.0.0.0/24, network: 10.0.0.0/24")
require.EqualError(t, err, "entry 1.route in tun.unsafe_routes is contained within the configured vpn networks; route: 10.0.0.0/24, network: 10.0.0.0/24")
// below network range
c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{"via": "127.0.0.1", "route": "1.0.0.0/8"}}}
routes, err = parseUnsafeRoutes(c, n)
routes, err = parseUnsafeRoutes(c, []netip.Prefix{n})
assert.Len(t, routes, 1)
assert.Nil(t, err)
require.NoError(t, err)
// above network range
c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{"via": "127.0.0.1", "route": "10.0.1.0/24"}}}
routes, err = parseUnsafeRoutes(c, n)
routes, err = parseUnsafeRoutes(c, []netip.Prefix{n})
assert.Len(t, routes, 1)
assert.Nil(t, err)
require.NoError(t, err)
// no mtu
c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{"via": "127.0.0.1", "route": "1.0.0.0/8"}}}
routes, err = parseUnsafeRoutes(c, n)
routes, err = parseUnsafeRoutes(c, []netip.Prefix{n})
assert.Len(t, routes, 1)
assert.Equal(t, 0, routes[0].MTU)
// bad mtu
c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{"via": "127.0.0.1", "mtu": "nope"}}}
routes, err = parseUnsafeRoutes(c, n)
routes, err = parseUnsafeRoutes(c, []netip.Prefix{n})
assert.Nil(t, routes)
assert.EqualError(t, err, "entry 1.mtu in tun.unsafe_routes is not an integer: strconv.Atoi: parsing \"nope\": invalid syntax")
require.EqualError(t, err, "entry 1.mtu in tun.unsafe_routes is not an integer: strconv.Atoi: parsing \"nope\": invalid syntax")
// low mtu
c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{"via": "127.0.0.1", "mtu": "499"}}}
routes, err = parseUnsafeRoutes(c, n)
routes, err = parseUnsafeRoutes(c, []netip.Prefix{n})
assert.Nil(t, routes)
assert.EqualError(t, err, "entry 1.mtu in tun.unsafe_routes is below 500: 499")
require.EqualError(t, err, "entry 1.mtu in tun.unsafe_routes is below 500: 499")
// bad install
c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{"via": "127.0.0.1", "mtu": "9000", "route": "1.0.0.0/29", "install": "nope"}}}
routes, err = parseUnsafeRoutes(c, n)
routes, err = parseUnsafeRoutes(c, []netip.Prefix{n})
assert.Nil(t, routes)
assert.EqualError(t, err, "entry 1.install in tun.unsafe_routes is not a boolean: strconv.ParseBool: parsing \"nope\": invalid syntax")
require.EqualError(t, err, "entry 1.install in tun.unsafe_routes is not a boolean: strconv.ParseBool: parsing \"nope\": invalid syntax")
// happy case
c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{
@ -221,8 +228,8 @@ func Test_parseUnsafeRoutes(t *testing.T) {
map[interface{}]interface{}{"via": "127.0.0.1", "mtu": "1500", "metric": 1234, "route": "1.0.0.2/32", "install": 1},
map[interface{}]interface{}{"via": "127.0.0.1", "mtu": "1500", "metric": 1234, "route": "1.0.0.2/32"},
}}
routes, err = parseUnsafeRoutes(c, n)
assert.Nil(t, err)
routes, err = parseUnsafeRoutes(c, []netip.Prefix{n})
require.NoError(t, err)
assert.Len(t, routes, 4)
tested := 0
@ -254,38 +261,38 @@ func Test_makeRouteTree(t *testing.T) {
l := test.NewLogger()
c := config.NewC(l)
n, err := netip.ParsePrefix("10.0.0.0/24")
assert.NoError(t, err)
require.NoError(t, err)
c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{
map[interface{}]interface{}{"via": "192.168.0.1", "route": "1.0.0.0/28"},
map[interface{}]interface{}{"via": "192.168.0.2", "route": "1.0.0.1/32"},
}}
routes, err := parseUnsafeRoutes(c, n)
assert.NoError(t, err)
routes, err := parseUnsafeRoutes(c, []netip.Prefix{n})
require.NoError(t, err)
assert.Len(t, routes, 2)
routeTree, err := makeRouteTree(l, routes, true)
assert.NoError(t, err)
require.NoError(t, err)
ip, err := netip.ParseAddr("1.0.0.2")
assert.NoError(t, err)
require.NoError(t, err)
r, ok := routeTree.Lookup(ip)
assert.True(t, ok)
nip, err := netip.ParseAddr("192.168.0.1")
assert.NoError(t, err)
require.NoError(t, err)
assert.Equal(t, nip, r)
ip, err = netip.ParseAddr("1.0.0.1")
assert.NoError(t, err)
require.NoError(t, err)
r, ok = routeTree.Lookup(ip)
assert.True(t, ok)
nip, err = netip.ParseAddr("192.168.0.2")
assert.NoError(t, err)
require.NoError(t, err)
assert.Equal(t, nip, r)
ip, err = netip.ParseAddr("1.1.0.1")
assert.NoError(t, err)
require.NoError(t, err)
r, ok = routeTree.Lookup(ip)
assert.False(t, ok)
}

View File

@ -11,36 +11,36 @@ import (
const DefaultMTU = 1300
// TODO: We may be able to remove routines
type DeviceFactory func(c *config.C, l *logrus.Logger, tunCidr netip.Prefix, routines int) (Device, error)
type DeviceFactory func(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, routines int) (Device, error)
func NewDeviceFromConfig(c *config.C, l *logrus.Logger, tunCidr netip.Prefix, routines int) (Device, error) {
func NewDeviceFromConfig(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, routines int) (Device, error) {
switch {
case c.GetBool("tun.disabled", false):
tun := newDisabledTun(tunCidr, c.GetInt("tun.tx_queue", 500), c.GetBool("stats.message_metrics", false), l)
tun := newDisabledTun(vpnNetworks, c.GetInt("tun.tx_queue", 500), c.GetBool("stats.message_metrics", false), l)
return tun, nil
default:
return newTun(c, l, tunCidr, routines > 1)
return newTun(c, l, vpnNetworks, routines > 1)
}
}
func NewFdDeviceFromConfig(fd *int) DeviceFactory {
return func(c *config.C, l *logrus.Logger, tunCidr netip.Prefix, routines int) (Device, error) {
return newTunFromFd(c, l, *fd, tunCidr)
return func(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, routines int) (Device, error) {
return newTunFromFd(c, l, *fd, vpnNetworks)
}
}
func getAllRoutesFromConfig(c *config.C, cidr netip.Prefix, initial bool) (bool, []Route, error) {
func getAllRoutesFromConfig(c *config.C, vpnNetworks []netip.Prefix, initial bool) (bool, []Route, error) {
if !initial && !c.HasChanged("tun.routes") && !c.HasChanged("tun.unsafe_routes") {
return false, nil, nil
}
routes, err := parseRoutes(c, cidr)
routes, err := parseRoutes(c, vpnNetworks)
if err != nil {
return true, nil, util.NewContextualError("Could not parse tun.routes", nil, err)
}
unsafeRoutes, err := parseUnsafeRoutes(c, cidr)
unsafeRoutes, err := parseUnsafeRoutes(c, vpnNetworks)
if err != nil {
return true, nil, util.NewContextualError("Could not parse tun.unsafe_routes", nil, err)
}

View File

@ -19,13 +19,13 @@ import (
type tun struct {
io.ReadWriteCloser
fd int
cidr netip.Prefix
vpnNetworks []netip.Prefix
Routes atomic.Pointer[[]Route]
routeTree atomic.Pointer[bart.Table[netip.Addr]]
l *logrus.Logger
}
func newTunFromFd(c *config.C, l *logrus.Logger, deviceFd int, cidr netip.Prefix) (*tun, error) {
func newTunFromFd(c *config.C, l *logrus.Logger, deviceFd int, vpnNetworks []netip.Prefix) (*tun, error) {
// XXX Android returns an fd in non-blocking mode which is necessary for shutdown to work properly.
// Be sure not to call file.Fd() as it will set the fd to blocking mode.
file := os.NewFile(uintptr(deviceFd), "/dev/net/tun")
@ -33,7 +33,7 @@ func newTunFromFd(c *config.C, l *logrus.Logger, deviceFd int, cidr netip.Prefix
t := &tun{
ReadWriteCloser: file,
fd: deviceFd,
cidr: cidr,
vpnNetworks: vpnNetworks,
l: l,
}
@ -52,7 +52,7 @@ func newTunFromFd(c *config.C, l *logrus.Logger, deviceFd int, cidr netip.Prefix
return t, nil
}
func newTun(_ *config.C, _ *logrus.Logger, _ netip.Prefix, _ bool) (*tun, error) {
func newTun(_ *config.C, _ *logrus.Logger, _ []netip.Prefix, _ bool) (*tun, error) {
return nil, fmt.Errorf("newTun not supported in Android")
}
@ -66,7 +66,7 @@ func (t tun) Activate() error {
}
func (t *tun) reload(c *config.C, initial bool) error {
change, routes, err := getAllRoutesFromConfig(c, t.cidr, initial)
change, routes, err := getAllRoutesFromConfig(c, t.vpnNetworks, initial)
if err != nil {
return err
}
@ -86,8 +86,8 @@ func (t *tun) reload(c *config.C, initial bool) error {
return nil
}
func (t *tun) Cidr() netip.Prefix {
return t.cidr
func (t *tun) Networks() []netip.Prefix {
return t.vpnNetworks
}
func (t *tun) Name() string {

View File

@ -25,7 +25,7 @@ import (
type tun struct {
io.ReadWriteCloser
Device string
cidr netip.Prefix
vpnNetworks []netip.Prefix
DefaultMTU int
Routes atomic.Pointer[[]Route]
routeTree atomic.Pointer[bart.Table[netip.Addr]]
@ -36,44 +36,50 @@ type tun struct {
out []byte
}
type sockaddrCtl struct {
scLen uint8
scFamily uint8
ssSysaddr uint16
scID uint32
scUnit uint32
scReserved [5]uint32
}
type ifReq struct {
Name [16]byte
Name [unix.IFNAMSIZ]byte
Flags uint16
pad [8]byte
}
var sockaddrCtlSize uintptr = 32
const (
_SYSPROTO_CONTROL = 2 //define SYSPROTO_CONTROL 2 /* kernel control protocol */
_AF_SYS_CONTROL = 2 //#define AF_SYS_CONTROL 2 /* corresponding sub address type */
_PF_SYSTEM = unix.AF_SYSTEM //#define PF_SYSTEM AF_SYSTEM
_CTLIOCGINFO = 3227799043 //#define CTLIOCGINFO _IOWR('N', 3, struct ctl_info)
_SIOCAIFADDR_IN6 = 2155899162
_UTUN_OPT_IFNAME = 2
_IN6_IFF_NODAD = 0x0020
_IN6_IFF_SECURED = 0x0400
utunControlName = "com.apple.net.utun_control"
)
type ifreqAddr struct {
Name [16]byte
Addr unix.RawSockaddrInet4
pad [8]byte
}
type ifreqMTU struct {
Name [16]byte
MTU int32
pad [8]byte
}
func newTun(c *config.C, l *logrus.Logger, cidr netip.Prefix, _ bool) (*tun, error) {
type addrLifetime struct {
Expire float64
Preferred float64
Vltime uint32
Pltime uint32
}
type ifreqAlias4 struct {
Name [unix.IFNAMSIZ]byte
Addr unix.RawSockaddrInet4
DstAddr unix.RawSockaddrInet4
MaskAddr unix.RawSockaddrInet4
}
type ifreqAlias6 struct {
Name [unix.IFNAMSIZ]byte
Addr unix.RawSockaddrInet6
DstAddr unix.RawSockaddrInet6
PrefixMask unix.RawSockaddrInet6
Flags uint32
Lifetime addrLifetime
}
func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, _ bool) (*tun, error) {
name := c.GetString("tun.dev", "")
ifIndex := -1
if name != "" && name != "utun" {
@ -86,66 +92,41 @@ func newTun(c *config.C, l *logrus.Logger, cidr netip.Prefix, _ bool) (*tun, err
}
}
fd, err := unix.Socket(_PF_SYSTEM, unix.SOCK_DGRAM, _SYSPROTO_CONTROL)
fd, err := unix.Socket(unix.AF_SYSTEM, unix.SOCK_DGRAM, unix.AF_SYS_CONTROL)
if err != nil {
return nil, fmt.Errorf("system socket: %v", err)
}
var ctlInfo = &struct {
ctlID uint32
ctlName [96]byte
}{}
var ctlInfo = &unix.CtlInfo{}
copy(ctlInfo.Name[:], utunControlName)
copy(ctlInfo.ctlName[:], utunControlName)
err = ioctl(uintptr(fd), uintptr(_CTLIOCGINFO), uintptr(unsafe.Pointer(ctlInfo)))
err = unix.IoctlCtlInfo(fd, ctlInfo)
if err != nil {
return nil, fmt.Errorf("CTLIOCGINFO: %v", err)
}
sc := sockaddrCtl{
scLen: uint8(sockaddrCtlSize),
scFamily: unix.AF_SYSTEM,
ssSysaddr: _AF_SYS_CONTROL,
scID: ctlInfo.ctlID,
scUnit: uint32(ifIndex) + 1,
err = unix.Connect(fd, &unix.SockaddrCtl{
ID: ctlInfo.Id,
Unit: uint32(ifIndex) + 1,
})
if err != nil {
return nil, fmt.Errorf("SYS_CONNECT: %v", err)
}
_, _, errno := unix.RawSyscall(
unix.SYS_CONNECT,
uintptr(fd),
uintptr(unsafe.Pointer(&sc)),
sockaddrCtlSize,
)
if errno != 0 {
return nil, fmt.Errorf("SYS_CONNECT: %v", errno)
name, err = unix.GetsockoptString(fd, unix.AF_SYS_CONTROL, _UTUN_OPT_IFNAME)
if err != nil {
return nil, fmt.Errorf("failed to retrieve tun name: %w", err)
}
var ifName struct {
name [16]byte
}
ifNameSize := uintptr(len(ifName.name))
_, _, errno = syscall.Syscall6(syscall.SYS_GETSOCKOPT, uintptr(fd),
2, // SYSPROTO_CONTROL
2, // UTUN_OPT_IFNAME
uintptr(unsafe.Pointer(&ifName)),
uintptr(unsafe.Pointer(&ifNameSize)), 0)
if errno != 0 {
return nil, fmt.Errorf("SYS_GETSOCKOPT: %v", errno)
}
name = string(ifName.name[:ifNameSize-1])
err = syscall.SetNonblock(fd, true)
err = unix.SetNonblock(fd, true)
if err != nil {
return nil, fmt.Errorf("SetNonblock: %v", err)
}
file := os.NewFile(uintptr(fd), "")
t := &tun{
ReadWriteCloser: file,
ReadWriteCloser: os.NewFile(uintptr(fd), ""),
Device: name,
cidr: cidr,
vpnNetworks: vpnNetworks,
DefaultMTU: c.GetInt("tun.mtu", DefaultMTU),
l: l,
}
@ -172,7 +153,7 @@ func (t *tun) deviceBytes() (o [16]byte) {
return
}
func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ netip.Prefix) (*tun, error) {
func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ []netip.Prefix) (*tun, error) {
return nil, fmt.Errorf("newTunFromFd not supported in Darwin")
}
@ -186,16 +167,6 @@ func (t *tun) Close() error {
func (t *tun) Activate() error {
devName := t.deviceBytes()
var addr, mask [4]byte
if !t.cidr.Addr().Is4() {
//TODO: IPV6-WORK
panic("need ipv6")
}
addr = t.cidr.Addr().As4()
copy(mask[:], prefixToMask(t.cidr))
s, err := unix.Socket(
unix.AF_INET,
unix.SOCK_DGRAM,
@ -208,66 +179,18 @@ func (t *tun) Activate() error {
fd := uintptr(s)
ifra := ifreqAddr{
Name: devName,
Addr: unix.RawSockaddrInet4{
Family: unix.AF_INET,
Addr: addr,
},
}
// Set the device ip address
if err = ioctl(fd, unix.SIOCSIFADDR, uintptr(unsafe.Pointer(&ifra))); err != nil {
return fmt.Errorf("failed to set tun address: %s", err)
}
// Set the device network
ifra.Addr.Addr = mask
if err = ioctl(fd, unix.SIOCSIFNETMASK, uintptr(unsafe.Pointer(&ifra))); err != nil {
return fmt.Errorf("failed to set tun netmask: %s", err)
}
// Set the device name
ifrf := ifReq{Name: devName}
if err = ioctl(fd, unix.SIOCGIFFLAGS, uintptr(unsafe.Pointer(&ifrf))); err != nil {
return fmt.Errorf("failed to set tun device name: %s", err)
}
// Set the MTU on the device
ifm := ifreqMTU{Name: devName, MTU: int32(t.DefaultMTU)}
if err = ioctl(fd, unix.SIOCSIFMTU, uintptr(unsafe.Pointer(&ifm))); err != nil {
return fmt.Errorf("failed to set tun mtu: %v", err)
}
/*
// Set the transmit queue length
ifrq := ifreqQLEN{Name: devName, Value: int32(t.TXQueueLen)}
if err = ioctl(fd, unix.SIOCSIFTXQLEN, uintptr(unsafe.Pointer(&ifrq))); err != nil {
// If we can't set the queue length nebula will still work but it may lead to packet loss
l.WithError(err).Error("Failed to set tun tx queue length")
}
*/
// Bring up the interface
ifrf.Flags = ifrf.Flags | unix.IFF_UP
if err = ioctl(fd, unix.SIOCSIFFLAGS, uintptr(unsafe.Pointer(&ifrf))); err != nil {
return fmt.Errorf("failed to bring the tun device up: %s", err)
// Get the device flags
ifrf := ifReq{Name: devName}
if err = ioctl(fd, unix.SIOCGIFFLAGS, uintptr(unsafe.Pointer(&ifrf))); err != nil {
return fmt.Errorf("failed to get tun flags: %s", err)
}
routeSock, err := unix.Socket(unix.AF_ROUTE, unix.SOCK_RAW, unix.AF_UNSPEC)
if err != nil {
return fmt.Errorf("unable to create AF_ROUTE socket: %v", err)
}
defer func() {
unix.Shutdown(routeSock, unix.SHUT_RDWR)
err := unix.Close(routeSock)
if err != nil {
t.l.WithError(err).Error("failed to close AF_ROUTE socket")
}
}()
routeAddr := &netroute.Inet4Addr{}
maskAddr := &netroute.Inet4Addr{}
linkAddr, err := getLinkAddr(t.Device)
if err != nil {
return err
@ -277,15 +200,19 @@ func (t *tun) Activate() error {
}
t.linkAddr = linkAddr
copy(routeAddr.IP[:], addr[:])
copy(maskAddr.IP[:], mask[:])
err = addRoute(routeSock, routeAddr, maskAddr, linkAddr)
for _, network := range t.vpnNetworks {
if network.Addr().Is4() {
err = t.activate4(network)
if err != nil {
if errors.Is(err, unix.EEXIST) {
err = fmt.Errorf("unable to add tun route, identical route already exists: %s", t.cidr)
}
return err
}
} else {
err = t.activate6(network)
if err != nil {
return err
}
}
}
// Run the interface
ifrf.Flags = ifrf.Flags | unix.IFF_UP | unix.IFF_RUNNING
@ -297,8 +224,89 @@ func (t *tun) Activate() error {
return t.addRoutes(false)
}
func (t *tun) activate4(network netip.Prefix) error {
s, err := unix.Socket(
unix.AF_INET,
unix.SOCK_DGRAM,
unix.IPPROTO_IP,
)
if err != nil {
return err
}
defer unix.Close(s)
ifr := ifreqAlias4{
Name: t.deviceBytes(),
Addr: unix.RawSockaddrInet4{
Len: unix.SizeofSockaddrInet4,
Family: unix.AF_INET,
Addr: network.Addr().As4(),
},
DstAddr: unix.RawSockaddrInet4{
Len: unix.SizeofSockaddrInet4,
Family: unix.AF_INET,
Addr: network.Addr().As4(),
},
MaskAddr: unix.RawSockaddrInet4{
Len: unix.SizeofSockaddrInet4,
Family: unix.AF_INET,
Addr: prefixToMask(network).As4(),
},
}
if err := ioctl(uintptr(s), unix.SIOCAIFADDR, uintptr(unsafe.Pointer(&ifr))); err != nil {
return fmt.Errorf("failed to set tun v4 address: %s", err)
}
err = addRoute(network, t.linkAddr)
if err != nil {
return err
}
return nil
}
func (t *tun) activate6(network netip.Prefix) error {
s, err := unix.Socket(
unix.AF_INET6,
unix.SOCK_DGRAM,
unix.IPPROTO_IP,
)
if err != nil {
return err
}
defer unix.Close(s)
ifr := ifreqAlias6{
Name: t.deviceBytes(),
Addr: unix.RawSockaddrInet6{
Len: unix.SizeofSockaddrInet6,
Family: unix.AF_INET6,
Addr: network.Addr().As16(),
},
PrefixMask: unix.RawSockaddrInet6{
Len: unix.SizeofSockaddrInet6,
Family: unix.AF_INET6,
Addr: prefixToMask(network).As16(),
},
Lifetime: addrLifetime{
// never expires
Vltime: 0xffffffff,
Pltime: 0xffffffff,
},
//TODO: CERT-V2 should we disable DAD (duplicate address detection) and mark this as a secured address?
Flags: _IN6_IFF_NODAD,
}
if err := ioctl(uintptr(s), _SIOCAIFADDR_IN6, uintptr(unsafe.Pointer(&ifr))); err != nil {
return fmt.Errorf("failed to set tun address: %s", err)
}
return nil
}
func (t *tun) reload(c *config.C, initial bool) error {
change, routes, err := getAllRoutesFromConfig(c, t.cidr, initial)
change, routes, err := getAllRoutesFromConfig(c, t.vpnNetworks, initial)
if err != nil {
return err
}
@ -343,7 +351,7 @@ func (t *tun) RouteFor(ip netip.Addr) netip.Addr {
}
// Get the LinkAddr for the interface of the given name
// TODO: Is there an easier way to fetch this when we create the interface?
// Is there an easier way to fetch this when we create the interface?
// Maybe SIOCGIFINDEX? but this doesn't appear to exist in the darwin headers.
func getLinkAddr(name string) (*netroute.LinkAddr, error) {
rib, err := netroute.FetchRIB(unix.AF_UNSPEC, unix.NET_RT_IFLIST, 0)
@ -371,38 +379,15 @@ func getLinkAddr(name string) (*netroute.LinkAddr, error) {
}
func (t *tun) addRoutes(logErrors bool) error {
routeSock, err := unix.Socket(unix.AF_ROUTE, unix.SOCK_RAW, unix.AF_UNSPEC)
if err != nil {
return fmt.Errorf("unable to create AF_ROUTE socket: %v", err)
}
defer func() {
unix.Shutdown(routeSock, unix.SHUT_RDWR)
err := unix.Close(routeSock)
if err != nil {
t.l.WithError(err).Error("failed to close AF_ROUTE socket")
}
}()
routeAddr := &netroute.Inet4Addr{}
maskAddr := &netroute.Inet4Addr{}
routes := *t.Routes.Load()
for _, r := range routes {
if !r.Via.IsValid() || !r.Install {
// We don't allow route MTUs so only install routes with a via
continue
}
if !r.Cidr.Addr().Is4() {
//TODO: implement ipv6
panic("Cant handle ipv6 routes yet")
}
routeAddr.IP = r.Cidr.Addr().As4()
//TODO: we could avoid the copy
copy(maskAddr.IP[:], prefixToMask(r.Cidr))
err := addRoute(routeSock, routeAddr, maskAddr, t.linkAddr)
err := addRoute(r.Cidr, t.linkAddr)
if err != nil {
if errors.Is(err, unix.EEXIST) {
t.l.WithField("route", r.Cidr).
@ -424,36 +409,12 @@ func (t *tun) addRoutes(logErrors bool) error {
}
func (t *tun) removeRoutes(routes []Route) error {
routeSock, err := unix.Socket(unix.AF_ROUTE, unix.SOCK_RAW, unix.AF_UNSPEC)
if err != nil {
return fmt.Errorf("unable to create AF_ROUTE socket: %v", err)
}
defer func() {
unix.Shutdown(routeSock, unix.SHUT_RDWR)
err := unix.Close(routeSock)
if err != nil {
t.l.WithError(err).Error("failed to close AF_ROUTE socket")
}
}()
routeAddr := &netroute.Inet4Addr{}
maskAddr := &netroute.Inet4Addr{}
for _, r := range routes {
if !r.Install {
continue
}
if r.Cidr.Addr().Is6() {
//TODO: implement ipv6
panic("Cant handle ipv6 routes yet")
}
routeAddr.IP = r.Cidr.Addr().As4()
copy(maskAddr.IP[:], prefixToMask(r.Cidr))
err := delRoute(routeSock, routeAddr, maskAddr, t.linkAddr)
err := delRoute(r.Cidr, t.linkAddr)
if err != nil {
t.l.WithError(err).WithField("route", r).Error("Failed to remove route")
} else {
@ -463,23 +424,39 @@ func (t *tun) removeRoutes(routes []Route) error {
return nil
}
func addRoute(sock int, addr, mask *netroute.Inet4Addr, link *netroute.LinkAddr) error {
r := netroute.RouteMessage{
func addRoute(prefix netip.Prefix, gateway netroute.Addr) error {
sock, err := unix.Socket(unix.AF_ROUTE, unix.SOCK_RAW, unix.AF_UNSPEC)
if err != nil {
return fmt.Errorf("unable to create AF_ROUTE socket: %v", err)
}
defer unix.Close(sock)
route := &netroute.RouteMessage{
Version: unix.RTM_VERSION,
Type: unix.RTM_ADD,
Flags: unix.RTF_UP,
Seq: 1,
Addrs: []netroute.Addr{
unix.RTAX_DST: addr,
unix.RTAX_GATEWAY: link,
unix.RTAX_NETMASK: mask,
},
}
data, err := r.Marshal()
if prefix.Addr().Is4() {
route.Addrs = []netroute.Addr{
unix.RTAX_DST: &netroute.Inet4Addr{IP: prefix.Masked().Addr().As4()},
unix.RTAX_NETMASK: &netroute.Inet4Addr{IP: prefixToMask(prefix).As4()},
unix.RTAX_GATEWAY: gateway,
}
} else {
route.Addrs = []netroute.Addr{
unix.RTAX_DST: &netroute.Inet6Addr{IP: prefix.Masked().Addr().As16()},
unix.RTAX_NETMASK: &netroute.Inet6Addr{IP: prefixToMask(prefix).As16()},
unix.RTAX_GATEWAY: gateway,
}
}
data, err := route.Marshal()
if err != nil {
return fmt.Errorf("failed to create route.RouteMessage: %w", err)
}
_, err = unix.Write(sock, data[:])
if err != nil {
return fmt.Errorf("failed to write route.RouteMessage to socket: %w", err)
@ -488,19 +465,34 @@ func addRoute(sock int, addr, mask *netroute.Inet4Addr, link *netroute.LinkAddr)
return nil
}
func delRoute(sock int, addr, mask *netroute.Inet4Addr, link *netroute.LinkAddr) error {
r := netroute.RouteMessage{
func delRoute(prefix netip.Prefix, gateway netroute.Addr) error {
sock, err := unix.Socket(unix.AF_ROUTE, unix.SOCK_RAW, unix.AF_UNSPEC)
if err != nil {
return fmt.Errorf("unable to create AF_ROUTE socket: %v", err)
}
defer unix.Close(sock)
route := netroute.RouteMessage{
Version: unix.RTM_VERSION,
Type: unix.RTM_DELETE,
Seq: 1,
Addrs: []netroute.Addr{
unix.RTAX_DST: addr,
unix.RTAX_GATEWAY: link,
unix.RTAX_NETMASK: mask,
},
}
data, err := r.Marshal()
if prefix.Addr().Is4() {
route.Addrs = []netroute.Addr{
unix.RTAX_DST: &netroute.Inet4Addr{IP: prefix.Masked().Addr().As4()},
unix.RTAX_NETMASK: &netroute.Inet4Addr{IP: prefixToMask(prefix).As4()},
unix.RTAX_GATEWAY: gateway,
}
} else {
route.Addrs = []netroute.Addr{
unix.RTAX_DST: &netroute.Inet6Addr{IP: prefix.Masked().Addr().As16()},
unix.RTAX_NETMASK: &netroute.Inet6Addr{IP: prefixToMask(prefix).As16()},
unix.RTAX_GATEWAY: gateway,
}
}
data, err := route.Marshal()
if err != nil {
return fmt.Errorf("failed to create route.RouteMessage: %w", err)
}
@ -513,7 +505,6 @@ func delRoute(sock int, addr, mask *netroute.Inet4Addr, link *netroute.LinkAddr)
}
func (t *tun) Read(to []byte) (int, error) {
buf := make([]byte, len(to)+4)
n, err := t.ReadWriteCloser.Read(buf)
@ -551,8 +542,8 @@ func (t *tun) Write(from []byte) (int, error) {
return n - 4, err
}
func (t *tun) Cidr() netip.Prefix {
return t.cidr
func (t *tun) Networks() []netip.Prefix {
return t.vpnNetworks
}
func (t *tun) Name() string {
@ -563,10 +554,12 @@ func (t *tun) NewMultiQueueReader() (io.ReadWriteCloser, error) {
return nil, fmt.Errorf("TODO: multiqueue not implemented for darwin")
}
func prefixToMask(prefix netip.Prefix) []byte {
func prefixToMask(prefix netip.Prefix) netip.Addr {
pLen := 128
if prefix.Addr().Is4() {
pLen = 32
}
return net.CIDRMask(prefix.Bits(), pLen)
addr, _ := netip.AddrFromSlice(net.CIDRMask(prefix.Bits(), pLen))
return addr
}

View File

@ -13,7 +13,7 @@ import (
type disabledTun struct {
read chan []byte
cidr netip.Prefix
vpnNetworks []netip.Prefix
// Track these metrics since we don't have the tun device to do it for us
tx metrics.Counter
@ -21,9 +21,9 @@ type disabledTun struct {
l *logrus.Logger
}
func newDisabledTun(cidr netip.Prefix, queueLen int, metricsEnabled bool, l *logrus.Logger) *disabledTun {
func newDisabledTun(vpnNetworks []netip.Prefix, queueLen int, metricsEnabled bool, l *logrus.Logger) *disabledTun {
tun := &disabledTun{
cidr: cidr,
vpnNetworks: vpnNetworks,
read: make(chan []byte, queueLen),
l: l,
}
@ -47,8 +47,8 @@ func (*disabledTun) RouteFor(addr netip.Addr) netip.Addr {
return netip.Addr{}
}
func (t *disabledTun) Cidr() netip.Prefix {
return t.cidr
func (t *disabledTun) Networks() []netip.Prefix {
return t.vpnNetworks
}
func (*disabledTun) Name() string {

View File

@ -47,7 +47,7 @@ type ifreqDestroy struct {
type tun struct {
Device string
cidr netip.Prefix
vpnNetworks []netip.Prefix
MTU int
Routes atomic.Pointer[[]Route]
routeTree atomic.Pointer[bart.Table[netip.Addr]]
@ -78,11 +78,11 @@ func (t *tun) Close() error {
return nil
}
func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ netip.Prefix) (*tun, error) {
func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ []netip.Prefix) (*tun, error) {
return nil, fmt.Errorf("newTunFromFd not supported in FreeBSD")
}
func newTun(c *config.C, l *logrus.Logger, cidr netip.Prefix, _ bool) (*tun, error) {
func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, _ bool) (*tun, error) {
// Try to open existing tun device
var file *os.File
var err error
@ -150,7 +150,7 @@ func newTun(c *config.C, l *logrus.Logger, cidr netip.Prefix, _ bool) (*tun, err
t := &tun{
ReadWriteCloser: file,
Device: deviceName,
cidr: cidr,
vpnNetworks: vpnNetworks,
MTU: c.GetInt("tun.mtu", DefaultMTU),
l: l,
}
@ -170,16 +170,16 @@ func newTun(c *config.C, l *logrus.Logger, cidr netip.Prefix, _ bool) (*tun, err
return t, nil
}
func (t *tun) Activate() error {
func (t *tun) addIp(cidr netip.Prefix) error {
var err error
// TODO use syscalls instead of exec.Command
cmd := exec.Command("/sbin/ifconfig", t.Device, t.cidr.String(), t.cidr.Addr().String())
cmd := exec.Command("/sbin/ifconfig", t.Device, cidr.String(), cidr.Addr().String())
t.l.Debug("command: ", cmd.String())
if err = cmd.Run(); err != nil {
return fmt.Errorf("failed to run 'ifconfig': %s", err)
}
cmd = exec.Command("/sbin/route", "-n", "add", "-net", t.cidr.String(), "-interface", t.Device)
cmd = exec.Command("/sbin/route", "-n", "add", "-net", cidr.String(), "-interface", t.Device)
t.l.Debug("command: ", cmd.String())
if err = cmd.Run(); err != nil {
return fmt.Errorf("failed to run 'route add': %s", err)
@ -195,8 +195,18 @@ func (t *tun) Activate() error {
return t.addRoutes(false)
}
func (t *tun) Activate() error {
for i := range t.vpnNetworks {
err := t.addIp(t.vpnNetworks[i])
if err != nil {
return err
}
}
return nil
}
func (t *tun) reload(c *config.C, initial bool) error {
change, routes, err := getAllRoutesFromConfig(c, t.cidr, initial)
change, routes, err := getAllRoutesFromConfig(c, t.vpnNetworks, initial)
if err != nil {
return err
}
@ -237,8 +247,8 @@ func (t *tun) RouteFor(ip netip.Addr) netip.Addr {
return r
}
func (t *tun) Cidr() netip.Prefix {
return t.cidr
func (t *tun) Networks() []netip.Prefix {
return t.vpnNetworks
}
func (t *tun) Name() string {

View File

@ -21,20 +21,20 @@ import (
type tun struct {
io.ReadWriteCloser
cidr netip.Prefix
vpnNetworks []netip.Prefix
Routes atomic.Pointer[[]Route]
routeTree atomic.Pointer[bart.Table[netip.Addr]]
l *logrus.Logger
}
func newTun(_ *config.C, _ *logrus.Logger, _ netip.Prefix, _ bool) (*tun, error) {
func newTun(_ *config.C, _ *logrus.Logger, _ []netip.Prefix, _ bool) (*tun, error) {
return nil, fmt.Errorf("newTun not supported in iOS")
}
func newTunFromFd(c *config.C, l *logrus.Logger, deviceFd int, cidr netip.Prefix) (*tun, error) {
func newTunFromFd(c *config.C, l *logrus.Logger, deviceFd int, vpnNetworks []netip.Prefix) (*tun, error) {
file := os.NewFile(uintptr(deviceFd), "/dev/tun")
t := &tun{
cidr: cidr,
vpnNetworks: vpnNetworks,
ReadWriteCloser: &tunReadCloser{f: file},
l: l,
}
@ -59,7 +59,7 @@ func (t *tun) Activate() error {
}
func (t *tun) reload(c *config.C, initial bool) error {
change, routes, err := getAllRoutesFromConfig(c, t.cidr, initial)
change, routes, err := getAllRoutesFromConfig(c, t.vpnNetworks, initial)
if err != nil {
return err
}
@ -142,8 +142,8 @@ func (tr *tunReadCloser) Close() error {
return tr.f.Close()
}
func (t *tun) Cidr() netip.Prefix {
return t.cidr
func (t *tun) Networks() []netip.Prefix {
return t.vpnNetworks
}
func (t *tun) Name() string {

View File

@ -11,6 +11,7 @@ import (
"os"
"strings"
"sync/atomic"
"time"
"unsafe"
"github.com/gaissmai/bart"
@ -25,7 +26,7 @@ type tun struct {
io.ReadWriteCloser
fd int
Device string
cidr netip.Prefix
vpnNetworks []netip.Prefix
MaxMTU int
DefaultMTU int
TXQueueLen int
@ -40,18 +41,16 @@ type tun struct {
l *logrus.Logger
}
func (t *tun) Networks() []netip.Prefix {
return t.vpnNetworks
}
type ifReq struct {
Name [16]byte
Flags uint16
pad [8]byte
}
type ifreqAddr struct {
Name [16]byte
Addr unix.RawSockaddrInet4
pad [8]byte
}
type ifreqMTU struct {
Name [16]byte
MTU int32
@ -64,10 +63,10 @@ type ifreqQLEN struct {
pad [8]byte
}
func newTunFromFd(c *config.C, l *logrus.Logger, deviceFd int, cidr netip.Prefix) (*tun, error) {
func newTunFromFd(c *config.C, l *logrus.Logger, deviceFd int, vpnNetworks []netip.Prefix) (*tun, error) {
file := os.NewFile(uintptr(deviceFd), "/dev/net/tun")
t, err := newTunGeneric(c, l, file, cidr)
t, err := newTunGeneric(c, l, file, vpnNetworks)
if err != nil {
return nil, err
}
@ -77,7 +76,7 @@ func newTunFromFd(c *config.C, l *logrus.Logger, deviceFd int, cidr netip.Prefix
return t, nil
}
func newTun(c *config.C, l *logrus.Logger, cidr netip.Prefix, multiqueue bool) (*tun, error) {
func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, multiqueue bool) (*tun, error) {
fd, err := unix.Open("/dev/net/tun", os.O_RDWR, 0)
if err != nil {
// If /dev/net/tun doesn't exist, try to create it (will happen in docker)
@ -112,7 +111,7 @@ func newTun(c *config.C, l *logrus.Logger, cidr netip.Prefix, multiqueue bool) (
name := strings.Trim(string(req.Name[:]), "\x00")
file := os.NewFile(uintptr(fd), "/dev/net/tun")
t, err := newTunGeneric(c, l, file, cidr)
t, err := newTunGeneric(c, l, file, vpnNetworks)
if err != nil {
return nil, err
}
@ -122,11 +121,11 @@ func newTun(c *config.C, l *logrus.Logger, cidr netip.Prefix, multiqueue bool) (
return t, nil
}
func newTunGeneric(c *config.C, l *logrus.Logger, file *os.File, cidr netip.Prefix) (*tun, error) {
func newTunGeneric(c *config.C, l *logrus.Logger, file *os.File, vpnNetworks []netip.Prefix) (*tun, error) {
t := &tun{
ReadWriteCloser: file,
fd: int(file.Fd()),
cidr: cidr,
vpnNetworks: vpnNetworks,
TXQueueLen: c.GetInt("tun.tx_queue", 500),
useSystemRoutes: c.GetBool("tun.use_system_route_table", false),
l: l,
@ -148,7 +147,7 @@ func newTunGeneric(c *config.C, l *logrus.Logger, file *os.File, cidr netip.Pref
}
func (t *tun) reload(c *config.C, initial bool) error {
routeChange, routes, err := getAllRoutesFromConfig(c, t.cidr, initial)
routeChange, routes, err := getAllRoutesFromConfig(c, t.vpnNetworks, initial)
if err != nil {
return err
}
@ -190,13 +189,15 @@ func (t *tun) reload(c *config.C, initial bool) error {
}
if oldDefaultMTU != newDefaultMTU {
err := t.setDefaultRoute()
for i := range t.vpnNetworks {
err := t.setDefaultRoute(t.vpnNetworks[i])
if err != nil {
t.l.Warn(err)
} else {
t.l.Infof("Set default MTU to %v was %v", t.DefaultMTU, oldDefaultMTU)
}
}
}
// Remove first, if the system removes a wanted route hopefully it will be re-added next
t.removeRoutes(findRemovedRoutes(routes, *oldRoutes))
@ -237,10 +238,10 @@ func (t *tun) RouteFor(ip netip.Addr) netip.Addr {
func (t *tun) Write(b []byte) (int, error) {
var nn int
max := len(b)
maximum := len(b)
for {
n, err := unix.Write(t.fd, b[nn:max])
n, err := unix.Write(t.fd, b[nn:maximum])
if n > 0 {
nn += n
}
@ -265,6 +266,58 @@ func (t *tun) deviceBytes() (o [16]byte) {
return
}
func hasNetlinkAddr(al []*netlink.Addr, x netlink.Addr) bool {
for i := range al {
if al[i].Equal(x) {
return true
}
}
return false
}
// addIPs uses netlink to add all addresses that don't exist, then it removes ones that should not be there
func (t *tun) addIPs(link netlink.Link) error {
newAddrs := make([]*netlink.Addr, len(t.vpnNetworks))
for i := range t.vpnNetworks {
newAddrs[i] = &netlink.Addr{
IPNet: &net.IPNet{
IP: t.vpnNetworks[i].Addr().AsSlice(),
Mask: net.CIDRMask(t.vpnNetworks[i].Bits(), t.vpnNetworks[i].Addr().BitLen()),
},
Label: t.vpnNetworks[i].Addr().Zone(),
}
}
//add all new addresses
for i := range newAddrs {
//TODO: CERT-V2 do we want to stack errors and try as many ops as possible?
//AddrReplace still adds new IPs, but if their properties change it will change them as well
if err := netlink.AddrReplace(link, newAddrs[i]); err != nil {
return err
}
}
//iterate over remainder, remove whoever shouldn't be there
al, err := netlink.AddrList(link, netlink.FAMILY_ALL)
if err != nil {
return fmt.Errorf("failed to get tun address list: %s", err)
}
for i := range al {
if hasNetlinkAddr(newAddrs, al[i]) {
continue
}
err = netlink.AddrDel(link, &al[i])
if err != nil {
t.l.WithError(err).Error("failed to remove address from tun address list")
} else {
t.l.WithField("removed", al[i].String()).Info("removed address not listed in cert(s)")
}
}
return nil
}
func (t *tun) Activate() error {
devName := t.deviceBytes()
@ -272,15 +325,8 @@ func (t *tun) Activate() error {
t.watchRoutes()
}
var addr, mask [4]byte
//TODO: IPV6-WORK
addr = t.cidr.Addr().As4()
tmask := net.CIDRMask(t.cidr.Bits(), 32)
copy(mask[:], tmask)
s, err := unix.Socket(
unix.AF_INET,
unix.AF_INET, //because everything we use t.ioctlFd for is address family independent, this is fine
unix.SOCK_DGRAM,
unix.IPPROTO_IP,
)
@ -289,31 +335,19 @@ func (t *tun) Activate() error {
}
t.ioctlFd = uintptr(s)
ifra := ifreqAddr{
Name: devName,
Addr: unix.RawSockaddrInet4{
Family: unix.AF_INET,
Addr: addr,
},
}
// Set the device ip address
if err = ioctl(t.ioctlFd, unix.SIOCSIFADDR, uintptr(unsafe.Pointer(&ifra))); err != nil {
return fmt.Errorf("failed to set tun address: %s", err)
}
// Set the device network
ifra.Addr.Addr = mask
if err = ioctl(t.ioctlFd, unix.SIOCSIFNETMASK, uintptr(unsafe.Pointer(&ifra))); err != nil {
return fmt.Errorf("failed to set tun netmask: %s", err)
}
// Set the device name
ifrf := ifReq{Name: devName}
if err = ioctl(t.ioctlFd, unix.SIOCGIFFLAGS, uintptr(unsafe.Pointer(&ifrf))); err != nil {
return fmt.Errorf("failed to set tun device name: %s", err)
}
link, err := netlink.LinkByName(t.Device)
if err != nil {
return fmt.Errorf("failed to get tun device link: %s", err)
}
t.deviceIndex = link.Attrs().Index
// Setup our default MTU
t.setMTU()
@ -324,20 +358,21 @@ func (t *tun) Activate() error {
t.l.WithError(err).Error("Failed to set tun tx queue length")
}
if err = t.addIPs(link); err != nil {
return err
}
// Bring up the interface
ifrf.Flags = ifrf.Flags | unix.IFF_UP
if err = ioctl(t.ioctlFd, unix.SIOCSIFFLAGS, uintptr(unsafe.Pointer(&ifrf))); err != nil {
return fmt.Errorf("failed to bring the tun device up: %s", err)
}
link, err := netlink.LinkByName(t.Device)
if err != nil {
return fmt.Errorf("failed to get tun device link: %s", err)
//set route MTU
for i := range t.vpnNetworks {
if err = t.setDefaultRoute(t.vpnNetworks[i]); err != nil {
return fmt.Errorf("failed to set default route MTU: %w", err)
}
t.deviceIndex = link.Attrs().Index
if err = t.setDefaultRoute(); err != nil {
return err
}
// Set the routes
@ -363,12 +398,10 @@ func (t *tun) setMTU() {
}
}
func (t *tun) setDefaultRoute() error {
// Default route
func (t *tun) setDefaultRoute(cidr netip.Prefix) error {
dr := &net.IPNet{
IP: t.cidr.Masked().Addr().AsSlice(),
Mask: net.CIDRMask(t.cidr.Bits(), t.cidr.Addr().BitLen()),
IP: cidr.Masked().Addr().AsSlice(),
Mask: net.CIDRMask(cidr.Bits(), cidr.Addr().BitLen()),
}
nr := netlink.Route{
@ -377,15 +410,28 @@ func (t *tun) setDefaultRoute() error {
MTU: t.DefaultMTU,
AdvMSS: t.advMSS(Route{}),
Scope: unix.RT_SCOPE_LINK,
Src: net.IP(t.cidr.Addr().AsSlice()),
Src: net.IP(cidr.Addr().AsSlice()),
Protocol: unix.RTPROT_KERNEL,
Table: unix.RT_TABLE_MAIN,
Type: unix.RTN_UNICAST,
}
err := netlink.RouteReplace(&nr)
if err != nil {
t.l.WithError(err).WithField("cidr", cidr).Warn("Failed to set default route MTU, retrying")
//retry twice more -- on some systems there appears to be a race condition where if we set routes too soon, netlink says `invalid argument`
for i := 0; i < 2; i++ {
time.Sleep(100 * time.Millisecond)
err = netlink.RouteReplace(&nr)
if err == nil {
break
} else {
t.l.WithError(err).WithField("cidr", cidr).WithField("mtu", t.DefaultMTU).Warn("Failed to set default route MTU, retrying")
}
}
if err != nil {
return fmt.Errorf("failed to set mtu %v on the default route %v; %v", t.DefaultMTU, dr, err)
}
}
return nil
}
@ -463,10 +509,6 @@ func (t *tun) removeRoutes(routes []Route) {
}
}
func (t *tun) Cidr() netip.Prefix {
return t.cidr
}
func (t *tun) Name() string {
return t.Device
}
@ -515,7 +557,6 @@ func (t *tun) updateRoutes(r netlink.RouteUpdate) {
return
}
//TODO: IPV6-WORK what if not ok?
gwAddr, ok := netip.AddrFromSlice(r.Gw)
if !ok {
t.l.WithField("route", r).Debug("Ignoring route update, invalid gateway address")
@ -523,15 +564,16 @@ func (t *tun) updateRoutes(r netlink.RouteUpdate) {
}
gwAddr = gwAddr.Unmap()
if !t.cidr.Contains(gwAddr) {
// Gateway isn't in our overlay network, ignore
t.l.WithField("route", r).Debug("Ignoring route update, not in our network")
return
withinNetworks := false
for i := range t.vpnNetworks {
if t.vpnNetworks[i].Contains(gwAddr) {
withinNetworks = true
break
}
if x := r.Dst.IP.To4(); x == nil {
// Nebula only handles ipv4 on the overlay currently
t.l.WithField("route", r).Debug("Ignoring route update, destination is not ipv4")
}
if !withinNetworks {
// Gateway isn't in our overlay network, ignore
t.l.WithField("route", r).Debug("Ignoring route update, not in our networks")
return
}
@ -563,11 +605,11 @@ func (t *tun) Close() error {
}
if t.ReadWriteCloser != nil {
t.ReadWriteCloser.Close()
_ = t.ReadWriteCloser.Close()
}
if t.ioctlFd > 0 {
os.NewFile(t.ioctlFd, "ioctlFd").Close()
_ = os.NewFile(t.ioctlFd, "ioctlFd").Close()
}
return nil

View File

@ -28,7 +28,7 @@ type ifreqDestroy struct {
type tun struct {
Device string
cidr netip.Prefix
vpnNetworks []netip.Prefix
MTU int
Routes atomic.Pointer[[]Route]
routeTree atomic.Pointer[bart.Table[netip.Addr]]
@ -58,13 +58,13 @@ func (t *tun) Close() error {
return nil
}
func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ netip.Prefix) (*tun, error) {
func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ []netip.Prefix) (*tun, error) {
return nil, fmt.Errorf("newTunFromFd not supported in NetBSD")
}
var deviceNameRE = regexp.MustCompile(`^tun[0-9]+$`)
func newTun(c *config.C, l *logrus.Logger, cidr netip.Prefix, _ bool) (*tun, error) {
func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, _ bool) (*tun, error) {
// Try to open tun device
var file *os.File
var err error
@ -84,7 +84,7 @@ func newTun(c *config.C, l *logrus.Logger, cidr netip.Prefix, _ bool) (*tun, err
t := &tun{
ReadWriteCloser: file,
Device: deviceName,
cidr: cidr,
vpnNetworks: vpnNetworks,
MTU: c.GetInt("tun.mtu", DefaultMTU),
l: l,
}
@ -104,17 +104,17 @@ func newTun(c *config.C, l *logrus.Logger, cidr netip.Prefix, _ bool) (*tun, err
return t, nil
}
func (t *tun) Activate() error {
func (t *tun) addIp(cidr netip.Prefix) error {
var err error
// TODO use syscalls instead of exec.Command
cmd := exec.Command("/sbin/ifconfig", t.Device, t.cidr.String(), t.cidr.Addr().String())
cmd := exec.Command("/sbin/ifconfig", t.Device, cidr.String(), cidr.Addr().String())
t.l.Debug("command: ", cmd.String())
if err = cmd.Run(); err != nil {
return fmt.Errorf("failed to run 'ifconfig': %s", err)
}
cmd = exec.Command("/sbin/route", "-n", "add", "-net", t.cidr.String(), t.cidr.Addr().String())
cmd = exec.Command("/sbin/route", "-n", "add", "-net", cidr.String(), cidr.Addr().String())
t.l.Debug("command: ", cmd.String())
if err = cmd.Run(); err != nil {
return fmt.Errorf("failed to run 'route add': %s", err)
@ -130,8 +130,18 @@ func (t *tun) Activate() error {
return t.addRoutes(false)
}
func (t *tun) Activate() error {
for i := range t.vpnNetworks {
err := t.addIp(t.vpnNetworks[i])
if err != nil {
return err
}
}
return nil
}
func (t *tun) reload(c *config.C, initial bool) error {
change, routes, err := getAllRoutesFromConfig(c, t.cidr, initial)
change, routes, err := getAllRoutesFromConfig(c, t.vpnNetworks, initial)
if err != nil {
return err
}
@ -172,8 +182,8 @@ func (t *tun) RouteFor(ip netip.Addr) netip.Addr {
return r
}
func (t *tun) Cidr() netip.Prefix {
return t.cidr
func (t *tun) Networks() []netip.Prefix {
return t.vpnNetworks
}
func (t *tun) Name() string {
@ -192,7 +202,7 @@ func (t *tun) addRoutes(logErrors bool) error {
continue
}
cmd := exec.Command("/sbin/route", "-n", "add", "-net", r.Cidr.String(), t.cidr.Addr().String())
cmd := exec.Command("/sbin/route", "-n", "add", "-net", r.Cidr.String(), t.vpnNetworks[0].Addr().String())
t.l.Debug("command: ", cmd.String())
if err := cmd.Run(); err != nil {
retErr := util.NewContextualError("failed to run 'route add' for unsafe_route", map[string]interface{}{"route": r}, err)
@ -213,7 +223,8 @@ func (t *tun) removeRoutes(routes []Route) error {
continue
}
cmd := exec.Command("/sbin/route", "-n", "delete", "-net", r.Cidr.String(), t.cidr.Addr().String())
//TODO: CERT-V2 is this right?
cmd := exec.Command("/sbin/route", "-n", "delete", "-net", r.Cidr.String(), t.vpnNetworks[0].Addr().String())
t.l.Debug("command: ", cmd.String())
if err := cmd.Run(); err != nil {
t.l.WithError(err).WithField("route", r).Error("Failed to remove route")

Some files were not shown because too many files have changed in this diff Show More